mxnet
half.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_HALF_H_
27 #define MSHADOW_HALF_H_
28 #include "./base.h"
29 
30 #if MSHADOW_USE_F16C
31  #include <x86intrin.h>
32 #endif // MSHADOW_USE_F16C
33 
// Rounding control for the portable software float2half() conversion only (the path
// generally used on Windows). It has no effect on the F16C intrinsics or on the
// CUDA >= 7.5 conversions, whose behaviour is fixed at round-to-nearest-even.
// With the default value 1, float2half() adds a round-to-nearest-even increment;
// any other value skips that increment, so surplus mantissa bits are truncated.
#if !defined(MSHADOW_HALF_ROUND_TO_NEAREST)
#define MSHADOW_HALF_ROUND_TO_NEAREST 1
#endif  // !defined(MSHADOW_HALF_ROUND_TO_NEAREST)
39 
// MSHADOW_CUDA_HALF is 1 when building with CUDA >= 7.5 (which ships cuda_fp16.h and
// the native __half type) and 0 otherwise.
#if !(MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
#define MSHADOW_CUDA_HALF 0
#else
#define MSHADOW_CUDA_HALF 1
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__)
/*!
 * \brief Convert a volatile __half to float (device code only).
 *
 * __half2float() takes a non-volatile __half, so the volatile value is first copied
 * into a local. The "_warp" suffix suggests use on warp-shared volatile data
 * (assumption based on the name; TODO confirm against callers).
 */
MSHADOW_XINLINE float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */
  __half local_copy;
#if CUDA_VERSION < 9000
  // Pre-9.0 CUDA: copy the raw 16-bit payload through the public member x.
  local_copy.x = h.x;
#else
  // CUDA >= 9.0: copy the whole object through a reference with volatile removed.
  // NOTE(review): reading a volatile object through a non-volatile glvalue is
  // formally undefined behaviour; kept as-is pending a CUDA-version-safe alternative.
  local_copy = const_cast<__half&>(h);
#endif
  return __half2float(local_copy);
}
#endif  // defined(__CUDA_ARCH__)
#endif  // MSHADOW_USE_CUDA && CUDA_VERSION >= 7050
58 
60 namespace mshadow {
61 /* \brief name space for host/device portable half-precision floats */
62 namespace half {
63 #define MSHADOW_HALF_OPERATOR(RTYPE, OP) \
64  MSHADOW_XINLINE RTYPE operator OP (half_t a, half_t b) { \
65  return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \
66  } \
67  template<typename T> \
68  MSHADOW_XINLINE RTYPE operator OP (half_t a, T b) { \
69  return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \
70  } \
71  template<typename T> \
72  MSHADOW_XINLINE RTYPE operator OP (T a, half_t b) { \
73  return RTYPE(float(a) OP float(b)); /* NOLINT(*) */ \
74  }
75 
76 #define MSHADOW_HALF_ASSIGNOP(AOP, OP) \
77  template<typename T> \
78  MSHADOW_XINLINE half_t operator AOP (const T& a) { \
79  return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \
80  } \
81  template<typename T> \
82  MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile { \
83  return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \
84  }
85 
86 #if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
87 #define MSHADOW_HALF_CONVERSIONOP(T) \
88  MSHADOW_XINLINE operator T() const { \
89  return T(__half2float(cuhalf_)); /* NOLINT(*)*/ \
90  } \
91  MSHADOW_XINLINE operator T() const volatile { \
92  return T(__half2float_warp(cuhalf_)); /* NOLINT(*)*/ \
93  }
94 #elif(MSHADOW_USE_F16C)
95 #define MSHADOW_HALF_CONVERSIONOP(T) \
96  MSHADOW_XINLINE operator T() const { \
97  return T(_cvtsh_ss(half_)); /* NOLINT(*)*/ \
98  } \
99  MSHADOW_XINLINE operator T() const volatile { \
100  return T(_cvtsh_ss(half_)); /* NOLINT(*)*/ \
101  }
102 #else
103 #define MSHADOW_HALF_CONVERSIONOP(T) \
104  MSHADOW_XINLINE operator T() const { \
105  return T(half2float(half_)); /* NOLINT(*)*/ \
106  } \
107  MSHADOW_XINLINE operator T() const volatile { \
108  return T(half2float(half_)); /* NOLINT(*)*/ \
109  }
110 #endif // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
111 
112 class MSHADOW_ALIGNED(2) half_t {
113  public:
114  union {
115  uint16_t half_;
116 #if MSHADOW_CUDA_HALF
117  __half cuhalf_;
118 #endif // MSHADOW_CUDA_HALF
119  };
120 
121  static MSHADOW_XINLINE half_t Binary(uint16_t value) {
122  half_t res;
123  res.half_ = value;
124  return res;
125  }
126 
127  MSHADOW_XINLINE half_t() {}
128 
129 #if MSHADOW_CUDA_HALF
130  MSHADOW_XINLINE explicit half_t(const __half& value) {
131  cuhalf_ = value;
132  }
133 #endif // MSHADOW_CUDA_HALF
134 
135  MSHADOW_XINLINE half_t(const float& value) { constructor(value); }
136  MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); }
137  MSHADOW_XINLINE explicit half_t(const int8_t& value) { constructor(value); }
138  MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); }
139  MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); }
140  MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); }
141  MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); }
142  MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); }
143 
145 
146  MSHADOW_HALF_ASSIGNOP(+=, +)
147  MSHADOW_HALF_ASSIGNOP(-=, -)
148  MSHADOW_HALF_ASSIGNOP(*=, *)
149  MSHADOW_HALF_ASSIGNOP(/=, /)
150 
151  MSHADOW_XINLINE half_t operator+() {
152  return *this;
153  }
154 
155  MSHADOW_XINLINE half_t operator-() {
156  return half_t(-float(*this)); // NOLINT(*)
157  }
158 
159  MSHADOW_XINLINE half_t operator=(const half_t& a) {
160  half_ = a.half_;
161  return a;
162  }
163 
164  template<typename T>
165  MSHADOW_XINLINE half_t operator=(const T& a) {
166  return *this = half_t(a); /* NOLINT(*)*/
167  }
168 
169  MSHADOW_XINLINE half_t operator=(const half_t& a) volatile {
170  half_ = a.half_;
171  return a;
172  }
173 
174  template<typename T>
175  MSHADOW_XINLINE half_t operator=(const T& a) volatile {
176  return *this = half_t(a); /* NOLINT(*)*/
177  }
178 
179  private:
180  union Bits {
181  float f;
182  int32_t si;
183  uint32_t ui;
184  };
185 
186  static int const fp16FractionBits = 10;
187  static int const fp32FractionBits = 23;
188  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
189  static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
190  static int const shift = fp32FractionBits - fp16FractionBits; // == 13
191  static int const shiftSign = 16;
192  static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
193 
194  static int32_t const infN = 0x7F800000; // flt32 infinity
195  static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift
196  static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
197  static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16
198  static int32_t const signN = 0x80000000; // flt32 sign bit
199 
200  static int32_t const infC = infN >> shift;
201  static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32
202  static int32_t const maxC = maxN >> shift;
203  static int32_t const minC = minN >> shift;
204  static int32_t const signC = signN >> shiftSign; // flt16 sign bit
205 
206  static int32_t const mulN = 0x52000000; // (1 << 23) / minN
207  static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))
208 
209  static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted
210  static int32_t const norC = 0x00400; // min flt32 normal down shifted
211 
212  static int32_t const maxD = infC - maxC - 1;
213  static int32_t const minD = minC - subC - 1;
214 
215  MSHADOW_XINLINE uint16_t float2half(const float& value) const {
216  Bits v;
217  v.f = value;
218  uint32_t sign = v.si & signN; // grab sign bit
219  v.si ^= sign; // clear sign bit from v
220  sign >>= shiftSign; // logical shift sign to fp16 position
221 
222  if (v.si <= maxZ) {
223  // Handle eventual zeros here to ensure vshift will not exceed 32 below.
224  v.ui = 0;
225  } else if (v.si < minN) {
226  // Handle denorms
227  uint32_t exp32 = v.ui >> fp32FractionBits;
228  int32_t exp16 = exp32 - expAdjust;
229  // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
230  // Smaller (so negative) exp16 values should result in greater right shifts.
231  uint32_t vshift = 1 - exp16;
232  uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
233  v.ui = significand >> vshift;
234  // The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
235  // when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
236  // bits to the right of the lsb are 1000... (including flt32 significand bits
237  // that may be lost during the above vshift). The first term below will always
238  // be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
239  // right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
240  // the proper test of the flt32 significand bits, including those lost during the vshift.
241 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
242  // Rounding may increase the exponent to 1, but that's OK.
243  v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
244 #endif
245  } else if (v.si <= maxN) {
246  // Handle norms
247 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
248  // Rounding may increase the exponent, possibly creating an inf, but that's OK.
249  v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
250 #endif
251  v.ui -= expAdjust << fp32FractionBits;
252  } else if (v.si <= infN) {
253  v.si = infN;
254  } else if (v.si < nanN) {
255  v.si = nanN;
256  }
257 
258  v.ui >>= shift;
259  return sign | (v.ui & 0x7fff);
260  }
261 
262  // Same as above routine, except for addition of volatile keyword
263  MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*)
264  Bits v;
265  v.f = value;
266  uint32_t sign = v.si & signN; // grab sign bit
267  v.si ^= sign; // clear sign bit from v
268  sign >>= shiftSign; // logical shift sign to fp16 position
269 
270  if (v.si <= maxZ) {
271  // Handle eventual zeros here to ensure vshift will not exceed 32 below.
272  v.ui = 0;
273  } else if (v.si < minN) {
274  // Handle denorms
275  uint32_t exp32 = v.ui >> fp32FractionBits;
276  int32_t exp16 = exp32 - expAdjust;
277  // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
278  // Smaller (so negative) exp16 values should result in greater right shifts.
279  uint32_t vshift = 1 - exp16;
280  uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
281  v.ui = significand >> vshift;
282 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
283  // Rounding may increase the exponent to 1, but that's OK.
284  v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
285 #endif
286  } else if (v.si <= maxN) {
287  // Handle norms
288 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
289  // Rounding may increase the exponent, possibly creating an inf, but that's OK.
290  v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
291 #endif
292  v.ui -= expAdjust << fp32FractionBits;
293  } else if (v.si <= infN) {
294  v.si = infN;
295  } else if (v.si < nanN) {
296  v.si = nanN;
297  }
298 
299  v.ui >>= shift;
300  return sign | (v.ui & 0x7fff);
301  }
302 
303  MSHADOW_XINLINE float half2float(const uint16_t& value) const {
304  Bits v;
305  v.ui = value;
306  int32_t sign = v.si & signC;
307  v.si ^= sign;
308  sign <<= shiftSign;
309  v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
310  v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
311  Bits s;
312  s.si = mulC;
313  s.f *= v.si;
314  int32_t mask = -(norC > v.si);
315  v.si <<= shift;
316  v.si ^= (s.si ^ v.si) & mask;
317  v.si |= sign;
318  return v.f;
319  }
320 
321  MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile { // NOLINT(*)
322  Bits v;
323  v.ui = value;
324  int32_t sign = v.si & signC;
325  v.si ^= sign;
326  sign <<= shiftSign;
327  v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
328  v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
329  Bits s;
330  s.si = mulC;
331  s.f *= v.si;
332  int32_t mask = -(norC > v.si);
333  v.si <<= shift;
334  v.si ^= (s.si ^ v.si) & mask;
335  v.si |= sign;
336  return v.f;
337  }
338 
339  template<typename T>
340  MSHADOW_XINLINE void constructor(const T& value) {
341 #if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
342  cuhalf_ = __float2half(float(value)); // NOLINT(*)
343 #elif(MSHADOW_USE_F16C)
344  half_ = _cvtss_sh(static_cast<float>(value), 0);
345 #else /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
346  half_ = float2half(float(value)); // NOLINT(*)
347 #endif /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
348  }
349 };
350 
352 MSHADOW_HALF_OPERATOR(half_t, +)
354 MSHADOW_HALF_OPERATOR(half_t, -)
356 MSHADOW_HALF_OPERATOR(half_t, *)
358 MSHADOW_HALF_OPERATOR(half_t, /)
360 MSHADOW_HALF_OPERATOR(bool, >)
362 MSHADOW_HALF_OPERATOR(bool, <)
364 MSHADOW_HALF_OPERATOR(bool, >=)
366 MSHADOW_HALF_OPERATOR(bool, <=)
367 
368 #define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF);
369 #define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF);
370 #define MSHADOW_HALF_SIGN_BIT 0x8000
371 #define MSHADOW_HALF_EXPONENT_BITS 0x7c00
372 } // namespace half
373 } // namespace mshadow
374 #endif // MSHADOW_HALF_H_
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::expr::operator+
BinaryMapExp< op::plus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator+(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:93
mshadow::expr::mask
MaskExp< IndexExp, SrcExp, DType > mask(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src)
Definition: mask.h:57
MSHADOW_HALF_OPERATOR
#define MSHADOW_HALF_OPERATOR(RTYPE, OP)
Definition: half.h:63
mshadow::half::MSHADOW_ALIGNED
class MSHADOW_ALIGNED(2) half_t
Definition: half.h:112
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
MSHADOW_HALF_ASSIGNOP
#define MSHADOW_HALF_ASSIGNOP(AOP, OP)
Definition: half.h:76
MSHADOW_HALF_CONVERSIONOP
#define MSHADOW_HALF_CONVERSIONOP(T)
Definition: half.h:95
mshadow::expr::operator-
BinaryMapExp< op::minus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator-(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:101
base.h
definitions of base types, operators, macros functions