mxnet
bfloat.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_BFLOAT_H_
27 #define MSHADOW_BFLOAT_H_
28 #include "./base.h"
29 
31 namespace mshadow {
32 /*! \brief namespace for host/device portable bfloats */
33 namespace bfloat {
34 
/*!
 * \brief Defines the mixed-type binary operator OP between bf16_t and ITYPE,
 *        in both operand orders.
 *
 * The bf16_t operand is widened to float, OP is applied against the ITYPE
 * operand, and the result is converted to RTYPE.
 */
#define MSHADOW_BF16_OPERATOR_TYPE(RTYPE, ITYPE, OP) \
  MSHADOW_XINLINE RTYPE operator OP (bf16_t lhs, ITYPE rhs) { \
    return static_cast<RTYPE>(static_cast<float>(lhs) OP rhs); /* NOLINT(*) */ \
  } \
  MSHADOW_XINLINE RTYPE operator OP (ITYPE lhs, bf16_t rhs) { \
    return static_cast<RTYPE>(lhs OP static_cast<float>(rhs)); /* NOLINT(*) */ \
  }
42 
/*!
 * \brief Defines binary operator OP for bf16_t.
 *
 * Emits bf16_t OP bf16_t (computed in float, converted to RTYPE) plus the
 * mixed forms, in both operand orders, with double, float and 8/32/64-bit
 * signed and unsigned integers.
 *
 * NOTE(review): the mixed-type overloads always return float (double for the
 * double overload) regardless of RTYPE, so e.g. `bf16_t > int` yields a float
 * 0/1 rather than bool -- confirm this is intended.
 */
#define MSHADOW_BF16_OPERATOR(RTYPE, OP) \
  MSHADOW_XINLINE RTYPE operator OP (bf16_t lhs, bf16_t rhs) { \
    return static_cast<RTYPE>( \
        static_cast<float>(lhs) OP static_cast<float>(rhs)); /* NOLINT(*) */ \
  } \
  MSHADOW_BF16_OPERATOR_TYPE(double, double, OP) \
  MSHADOW_BF16_OPERATOR_TYPE(float, float, OP) \
  MSHADOW_BF16_OPERATOR_TYPE(float, int8_t, OP) \
  MSHADOW_BF16_OPERATOR_TYPE(float, int32_t, OP) \
  MSHADOW_BF16_OPERATOR_TYPE(float, int64_t, OP) \
  MSHADOW_BF16_OPERATOR_TYPE(float, uint8_t, OP) \
  MSHADOW_BF16_OPERATOR_TYPE(float, uint32_t, OP) \
  MSHADOW_BF16_OPERATOR_TYPE(float, uint64_t, OP)
55 
/*!
 * \brief Defines the compound assignment operator AOP (e.g. +=) for bf16_t,
 *        for both plain and volatile objects.
 *
 * Both operands are converted to float, OP is applied, and the float result
 * is converted back to bf16_t and stored. The new value is returned by value.
 */
#define MSHADOW_BF16_ASSIGNOP(AOP, OP) \
  template<typename T> \
  MSHADOW_XINLINE bf16_t operator AOP (const T& rhs) { \
    const bf16_t updated(static_cast<float>(*this) OP static_cast<float>(rhs)); \
    return *this = updated; /* NOLINT(*)*/ \
  } \
  template<typename T> \
  MSHADOW_XINLINE bf16_t operator AOP (const volatile T& rhs) volatile { \
    const bf16_t updated(static_cast<float>(*this) OP static_cast<float>(rhs)); \
    return *this = updated; /* NOLINT(*)*/ \
  }
65 
/*!
 * \brief Defines implicit conversion operators from bf16_t to T, for both
 *        const and const volatile objects.
 *
 * The stored bits are first expanded to float via BF16ToFloat and then
 * converted to T.
 */
#define MSHADOW_BF16_CONVERSIONOP(T) \
  MSHADOW_XINLINE operator T() const { \
    const float widened = BF16ToFloat(bf16_); \
    return static_cast<T>(widened); /* NOLINT(*)*/ \
  } \
  MSHADOW_XINLINE operator T() const volatile { \
    const float widened = BF16ToFloat(bf16_); \
    return static_cast<T>(widened); /* NOLINT(*)*/ \
  }
73 
74 class MSHADOW_ALIGNED(2) bf16_t {
75  public:
76  uint16_t bf16_;
77 
78 static MSHADOW_XINLINE bf16_t Binary(uint16_t value) {
79  bf16_t res;
80  res.bf16_ = value;
81  return res;
82  }
83 
84  MSHADOW_XINLINE bf16_t() {}
85 
86  MSHADOW_XINLINE bf16_t(const float& value) { constructor(value); }
87  MSHADOW_XINLINE explicit bf16_t(const double& value) { constructor(value); }
88  MSHADOW_XINLINE explicit bf16_t(const int8_t& value) { constructor(value); }
89  MSHADOW_XINLINE explicit bf16_t(const uint8_t& value) { constructor(value); }
90  MSHADOW_XINLINE explicit bf16_t(const int32_t& value) { constructor(value); }
91  MSHADOW_XINLINE explicit bf16_t(const uint32_t& value) { constructor(value); }
92  MSHADOW_XINLINE explicit bf16_t(const int64_t& value) { constructor(value); }
93  MSHADOW_XINLINE explicit bf16_t(const uint64_t& value) { constructor(value); }
94 
96 
100  MSHADOW_BF16_ASSIGNOP(/=, /)
101 
102  MSHADOW_XINLINE bf16_t operator+() {
103  return *this;
104  }
105 
106  MSHADOW_XINLINE bf16_t operator-() {
107  return bf16_t(-float(*this)); // NOLINT(*)
108  }
109 
110  MSHADOW_XINLINE bf16_t operator=(const bf16_t& a) {
111  bf16_ = a.bf16_;
112  return a;
113  }
114 
115  template<typename T>
116  MSHADOW_XINLINE bf16_t operator=(const T& a) {
117  return *this = bf16_t(a); /* NOLINT(*)*/
118  }
119 
120  MSHADOW_XINLINE bf16_t operator=(const bf16_t& a) volatile {
121  bf16_ = a.bf16_;
122  return a;
123  }
124 
125  template<typename T>
126  MSHADOW_XINLINE bf16_t operator=(const T& a) volatile {
127  return *this = bf16_t(a); /* NOLINT(*)*/
128  }
129 
130  private:
131  union Bits {
132  float f;
133  int32_t si;
134  uint32_t ui;
135  };
136 
137  MSHADOW_XINLINE uint16_t FloatToBF16(const float& value) const {
138  return reinterpret_cast<const uint16_t*>(&value)[1];
139  }
140 
141  // Same as above routine, except for addition of volatile keyword
142  MSHADOW_XINLINE uint16_t FloatToBF16(const volatile float& value) const volatile { // NOLINT (*)
143  return reinterpret_cast<const volatile uint16_t*>(&value)[1];
144  }
145 
146  MSHADOW_XINLINE float BF16ToFloat(const uint16_t& value) const {
147  float ret = 0.f;
148  reinterpret_cast<uint16_t*>(&ret)[1] = value;
149  return ret;
150  }
151 
152  MSHADOW_XINLINE float BF16ToFloat(const volatile uint16_t& value) const volatile { // NOLINT(*)
153  float ret = 0.f;
154  reinterpret_cast<uint16_t*>(&ret)[1] = value;
155  return ret;
156  }
157 
158  template<typename T>
159  MSHADOW_XINLINE void constructor(const T& value) {
160  bf16_ = FloatToBF16(float(value)); // NOLINT(*)
161  }
162 };
163 
165 MSHADOW_BF16_OPERATOR(bf16_t, +)
167 MSHADOW_BF16_OPERATOR(bf16_t, -)
169 MSHADOW_BF16_OPERATOR(bf16_t, *)
171 MSHADOW_BF16_OPERATOR(bf16_t, /)
173 MSHADOW_BF16_OPERATOR(bool, >)
175 MSHADOW_BF16_OPERATOR(bool, <)
177 MSHADOW_BF16_OPERATOR(bool, >=)
179 MSHADOW_BF16_OPERATOR(bool, <=)
180 
181 #define MSHADOW_BF16_MIN mshadow::bfloat::bf16_t::Binary(0xFF7F);
182 #define MSHADOW_BF16_MAX mshadow::bfloat::bf16_t::Binary(0x7F7F);
183 #define MSHADOW_BF16_SIGN_BIT 0x8000
184 #define MSHADOW_BF16_EXPONENT_BITS 0x7f80
185 } // namespace bfloat
186 } // namespace mshadow
187 #endif // MSHADOW_BFLOAT_H_
MSHADOW_BF16_OPERATOR
#define MSHADOW_BF16_OPERATOR(RTYPE, OP)
Definition: bfloat.h:43
mshadow::bfloat::MSHADOW_ALIGNED
class MSHADOW_ALIGNED(2) bf16_t
Definition: bfloat.h:74
MSHADOW_BF16_ASSIGNOP
#define MSHADOW_BF16_ASSIGNOP(AOP, OP)
Definition: bfloat.h:56
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::expr::operator+
BinaryMapExp< op::plus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator+(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:93
MSHADOW_BF16_CONVERSIONOP
#define MSHADOW_BF16_CONVERSIONOP(T)
Definition: bfloat.h:66
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::expr::operator-
BinaryMapExp< op::minus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator-(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:101
base.h
definitions of base types, operators, macros functions