mxnet
bfloat.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MSHADOW_BFLOAT_H_
28 #define MSHADOW_BFLOAT_H_
29 #include "./base.h"
30 
32 namespace mshadow {
33 /* \brief name space for host/device portable bfloats */
34 namespace bfloat {
35 
// Defines the mixed-type binary operator OP between bf16_t and the builtin
// type ITYPE, for both operand orders. The bf16_t side is widened to float,
// the builtin expression is evaluated, and its result is converted to RTYPE.
#define MSHADOW_BF16_OPERATOR_TYPE(RTYPE, ITYPE, OP)                          \
  MSHADOW_XINLINE RTYPE operator OP (bf16_t a, ITYPE b) {                     \
    return RTYPE(static_cast<float>(a) OP b); /* NOLINT(*) */                 \
  }                                                                           \
  MSHADOW_XINLINE RTYPE operator OP (ITYPE a, bf16_t b) {                     \
    return RTYPE(a OP static_cast<float>(b)); /* NOLINT(*) */                 \
  }
43 
/*!
 * \brief define the binary operator OP for bf16_t
 *
 * bf16_t OP bf16_t computes in float and returns RTYPE. The mixed
 * bf16_t/builtin overloads return the type the same expression has on plain
 * floating-point operands, computed with decltype: float for float and
 * integer operands, double for double operands, and bool whenever OP is a
 * comparison (so `x < 1.0f` is a bool rather than a float holding 0 or 1).
 */
#define MSHADOW_BF16_OPERATOR(RTYPE, OP)                                          \
  MSHADOW_XINLINE RTYPE operator OP (bf16_t a, bf16_t b) {                        \
    return RTYPE(static_cast<float>(a) OP static_cast<float>(b)); /* NOLINT(*) */ \
  }                                                                               \
  MSHADOW_BF16_OPERATOR_TYPE(decltype(float() OP float()), float, OP)             \
  MSHADOW_BF16_OPERATOR_TYPE(decltype(double() OP float()), double, OP)           \
  MSHADOW_BF16_OPERATOR_TYPE(decltype(float() OP float()), int8_t, OP)            \
  MSHADOW_BF16_OPERATOR_TYPE(decltype(float() OP float()), uint8_t, OP)           \
  MSHADOW_BF16_OPERATOR_TYPE(decltype(float() OP float()), int32_t, OP)           \
  MSHADOW_BF16_OPERATOR_TYPE(decltype(float() OP float()), uint32_t, OP)          \
  MSHADOW_BF16_OPERATOR_TYPE(decltype(float() OP float()), int64_t, OP)           \
  MSHADOW_BF16_OPERATOR_TYPE(decltype(float() OP float()), uint64_t, OP)
56 
// Declares the compound assignment `operator AOP` for bf16_t, templated on
// the right-hand operand type, in a non-volatile and a volatile form. Each
// widens both operands to float, applies OP, assigns the bf16_t result back
// to *this, and returns the value produced by that assignment.
#define MSHADOW_BF16_ASSIGNOP(AOP, OP)                                        \
  template<typename T>                                                        \
  MSHADOW_XINLINE bf16_t operator AOP (const T& a) {                          \
    const float lhs = static_cast<float>(*this);                              \
    const float rhs = static_cast<float>(a);                                  \
    return *this = bf16_t(lhs OP rhs); /* NOLINT(*)*/                         \
  }                                                                           \
  template<typename T>                                                        \
  MSHADOW_XINLINE bf16_t operator AOP (const volatile T& a) volatile {        \
    const float lhs = static_cast<float>(*this);                              \
    const float rhs = static_cast<float>(a);                                  \
    return *this = bf16_t(lhs OP rhs); /* NOLINT(*)*/                         \
  }
66 
// Declares `operator T()` for bf16_t in const volatile and const forms. Both
// widen the stored bits to float through BF16ToFloat, then convert to T.
#define MSHADOW_BF16_CONVERSIONOP(T)                                          \
  MSHADOW_XINLINE operator T() const volatile {                               \
    return static_cast<T>(BF16ToFloat(bf16_)); /* NOLINT(*)*/                 \
  }                                                                           \
  MSHADOW_XINLINE operator T() const {                                        \
    return static_cast<T>(BF16ToFloat(bf16_)); /* NOLINT(*)*/                 \
  }
74 
75 class MSHADOW_ALIGNED(2) bf16_t {
76  public:
77  uint16_t bf16_;
78 
79 static MSHADOW_XINLINE bf16_t Binary(uint16_t value) {
80  bf16_t res;
81  res.bf16_ = value;
82  return res;
83  }
84 
85  MSHADOW_XINLINE bf16_t() {}
86 
87  MSHADOW_XINLINE bf16_t(const float& value) { constructor(value); }
88  MSHADOW_XINLINE explicit bf16_t(const double& value) { constructor(value); }
89  MSHADOW_XINLINE explicit bf16_t(const int8_t& value) { constructor(value); }
90  MSHADOW_XINLINE explicit bf16_t(const uint8_t& value) { constructor(value); }
91  MSHADOW_XINLINE explicit bf16_t(const int32_t& value) { constructor(value); }
92  MSHADOW_XINLINE explicit bf16_t(const uint32_t& value) { constructor(value); }
93  MSHADOW_XINLINE explicit bf16_t(const int64_t& value) { constructor(value); }
94  MSHADOW_XINLINE explicit bf16_t(const uint64_t& value) { constructor(value); }
95 
97 
100  MSHADOW_BF16_ASSIGNOP(*=, *)
101  MSHADOW_BF16_ASSIGNOP(/=, /)
102 
103  MSHADOW_XINLINE bf16_t operator+() {
104  return *this;
105  }
106 
107  MSHADOW_XINLINE bf16_t operator-() {
108  return bf16_t(-float(*this)); // NOLINT(*)
109  }
110 
111  MSHADOW_XINLINE bf16_t operator=(const bf16_t& a) {
112  bf16_ = a.bf16_;
113  return a;
114  }
115 
116  template<typename T>
117  MSHADOW_XINLINE bf16_t operator=(const T& a) {
118  return *this = bf16_t(a); /* NOLINT(*)*/
119  }
120 
121  MSHADOW_XINLINE bf16_t operator=(const bf16_t& a) volatile {
122  bf16_ = a.bf16_;
123  return a;
124  }
125 
126  template<typename T>
127  MSHADOW_XINLINE bf16_t operator=(const T& a) volatile {
128  return *this = bf16_t(a); /* NOLINT(*)*/
129  }
130 
131  private:
132  union Bits {
133  float f;
134  int32_t si;
135  uint32_t ui;
136  };
137 
138  MSHADOW_XINLINE uint16_t FloatToBF16(const float& value) const {
139  return reinterpret_cast<const uint16_t*>(&value)[1];
140  }
141 
142  // Same as above routine, except for addition of volatile keyword
143  MSHADOW_XINLINE uint16_t FloatToBF16(const volatile float& value) const volatile { // NOLINT (*)
144  return reinterpret_cast<const volatile uint16_t*>(&value)[1];
145  }
146 
147  MSHADOW_XINLINE float BF16ToFloat(const uint16_t& value) const {
148  float ret = 0.f;
149  reinterpret_cast<uint16_t*>(&ret)[1] = value;
150  return ret;
151  }
152 
153  MSHADOW_XINLINE float BF16ToFloat(const volatile uint16_t& value) const volatile { // NOLINT(*)
154  float ret = 0.f;
155  reinterpret_cast<uint16_t*>(&ret)[1] = value;
156  return ret;
157  }
158 
159  template<typename T>
160  MSHADOW_XINLINE void constructor(const T& value) {
161  bf16_ = FloatToBF16(float(value)); // NOLINT(*)
162  }
163 };
164 
166 MSHADOW_BF16_OPERATOR(bf16_t, +)
168 MSHADOW_BF16_OPERATOR(bf16_t, -)
170 MSHADOW_BF16_OPERATOR(bf16_t, *)
172 MSHADOW_BF16_OPERATOR(bf16_t, /)
174 MSHADOW_BF16_OPERATOR(bool, >)
176 MSHADOW_BF16_OPERATOR(bool, <)
178 MSHADOW_BF16_OPERATOR(bool, >=)
180 MSHADOW_BF16_OPERATOR(bool, <=)
181 
// Lowest and largest finite bfloat16 values (bit patterns 0xFF7F / 0x7F7F,
// roughly -/+3.39e38). The expansions carry no trailing semicolon, so the
// macros can be used inside expressions and argument lists.
#define MSHADOW_BF16_MIN (mshadow::bfloat::bf16_t::Binary(0xFF7F))
#define MSHADOW_BF16_MAX (mshadow::bfloat::bf16_t::Binary(0x7F7F))
184 } // namespace bfloat
185 } // namespace mshadow
186 #endif // MSHADOW_BFLOAT_H_
BinaryMapExp< op::minus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator-(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:102
BinaryMapExp< op::plus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator+(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:94
#define MSHADOW_BF16_ASSIGNOP(AOP, OP)
Definition: bfloat.h:57
#define MSHADOW_XINLINE
Definition: base.h:223
#define MSHADOW_BF16_CONVERSIONOP(T)
Definition: bfloat.h:67
#define MSHADOW_BF16_OPERATOR(RTYPE, OP)
Definition: bfloat.h:44
class MSHADOW_ALIGNED(2) bf16_t
Definition: bfloat.h:75
overloaded + operator between half_t and bf16_t
Definition: base.h:327