mxnet
half.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MSHADOW_HALF_H_
28 #define MSHADOW_HALF_H_
29 #include "./base.h"
30 
31 #if MSHADOW_USE_F16C
32  #include <x86intrin.h>
33 #endif // MSHADOW_USE_F16C
34 
// MSHADOW_HALF_ROUND_TO_NEAREST controls rounding in half_t's portable
// float2half() routine only (the fallback generally used on Windows). The F16C
// intrinsics and CUDA v7.5+ conversions are unaffected and always round to
// nearest-even. When the flag equals 1 (the default), float2half() rounds to
// nearest-even; any other value makes it truncate the discarded bits.
#if !defined(MSHADOW_HALF_ROUND_TO_NEAREST)
#define MSHADOW_HALF_ROUND_TO_NEAREST 1
#endif  // !defined(MSHADOW_HALF_ROUND_TO_NEAREST)
40 
// MSHADOW_CUDA_HALF is 1 when building with CUDA >= 7.5 (which provides
// <cuda_fp16.h> and the __half type), and 0 otherwise.
#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
  #define MSHADOW_CUDA_HALF 1
  #include <cuda_fp16.h>
  #if defined(__CUDA_ARCH__)
  /*!
   * \brief Convert a volatile __half to float.
   *
   * Used by half_t's volatile conversion operator in device code. The value is
   * first copied into a non-volatile __half, then passed to __half2float.
   * Declared inline because it is defined in a header: without it, every
   * translation unit including this file emits its own external definition.
   */
  inline __host__ __device__ float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */
    __half val;
#if CUDA_VERSION >= 9000
    // NOTE(review): presumably needed because CUDA 9+ __half offers no
    // volatile-qualified copy; the volatile qualifier is cast away for the copy.
    val = const_cast<__half&>(h);
#else
    // Pre-CUDA 9 __half exposes its raw bits as member x.
    val.x = h.x;
#endif
    return __half2float(val);
  }
  #endif  // defined(__CUDA_ARCH__)
#else
  #define MSHADOW_CUDA_HALF 0
#endif  // (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
59 
61 namespace mshadow {
/* \brief namespace for host/device portable half-precision floats */
63 namespace half {
/*!
 * \brief Generates the binary operator OP for half_t operands.
 *
 * Three overloads are produced: (half_t, half_t), (half_t, T) and (T, half_t).
 * Both operands are converted to float, OP is evaluated in float, and the
 * outcome is converted to RTYPE.
 */
#define MSHADOW_HALF_OPERATOR(RET_TYPE, BIN_OP)                                        \
  MSHADOW_XINLINE RET_TYPE operator BIN_OP (half_t lhs, half_t rhs) {                  \
    return static_cast<RET_TYPE>(static_cast<float>(lhs) BIN_OP static_cast<float>(rhs)); \
  }                                                                                    \
  template<typename Other>                                                             \
  MSHADOW_XINLINE RET_TYPE operator BIN_OP (half_t lhs, Other rhs) {                   \
    return static_cast<RET_TYPE>(static_cast<float>(lhs) BIN_OP static_cast<float>(rhs)); \
  }                                                                                    \
  template<typename Other>                                                             \
  MSHADOW_XINLINE RET_TYPE operator BIN_OP (Other lhs, half_t rhs) {                   \
    return static_cast<RET_TYPE>(static_cast<float>(lhs) BIN_OP static_cast<float>(rhs)); \
  }
76 
/*!
 * \brief Generates a compound assignment operator (e.g. +=) as a half_t member.
 *
 * The current value and the right-hand side are both converted to float and
 * combined with BIN_OP. The result is rounded to half_t and stored in *this.
 * A second overload is generated for volatile objects. Each overload returns,
 * by value, the half_t that was stored.
 */
#define MSHADOW_HALF_ASSIGNOP(ASSIGN_OP, BIN_OP)                                        \
  template<typename Rhs>                                                                \
  MSHADOW_XINLINE half_t operator ASSIGN_OP (const Rhs& rhs) {                          \
    const half_t updated(static_cast<float>(*this) BIN_OP static_cast<float>(rhs));     \
    return *this = updated; /* NOLINT(*)*/                                              \
  }                                                                                     \
  template<typename Rhs>                                                                \
  MSHADOW_XINLINE half_t operator ASSIGN_OP (const volatile Rhs& rhs) volatile {        \
    const half_t updated(static_cast<float>(*this) BIN_OP static_cast<float>(rhs));     \
    return *this = updated; /* NOLINT(*)*/                                              \
  }
86 
// MSHADOW_HALF_CONVERSIONOP(T) expands, inside half_t, to a const and a
// const-volatile conversion operator to T. The stored half is widened to float,
// and that float is then converted to T. The widening step is:
//   - CUDA device code: __half2float (the volatile overload uses __half2float_warp)
//   - MSHADOW_USE_F16C: the _cvtsh_ss intrinsic
//   - otherwise:        half_t's portable half2float routine
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
#define MSHADOW_HALF_CONVERSIONOP(T)                                      \
  MSHADOW_XINLINE operator T() const {                                    \
    return static_cast<T>(__half2float(cuhalf_)); /* NOLINT(*)*/          \
  }                                                                       \
  MSHADOW_XINLINE operator T() const volatile {                           \
    return static_cast<T>(__half2float_warp(cuhalf_)); /* NOLINT(*)*/     \
  }
#elif MSHADOW_USE_F16C
#define MSHADOW_HALF_CONVERSIONOP(T)                                      \
  MSHADOW_XINLINE operator T() const {                                    \
    return static_cast<T>(_cvtsh_ss(half_)); /* NOLINT(*)*/               \
  }                                                                       \
  MSHADOW_XINLINE operator T() const volatile {                           \
    return static_cast<T>(_cvtsh_ss(half_)); /* NOLINT(*)*/               \
  }
#else
#define MSHADOW_HALF_CONVERSIONOP(T)                                      \
  MSHADOW_XINLINE operator T() const {                                    \
    return static_cast<T>(half2float(half_)); /* NOLINT(*)*/              \
  }                                                                       \
  MSHADOW_XINLINE operator T() const volatile {                           \
    return static_cast<T>(half2float(half_)); /* NOLINT(*)*/              \
  }
#endif  // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
112 
113 class MSHADOW_ALIGNED(2) half_t {
114  public:
115  union {
116  uint16_t half_;
117 #if MSHADOW_CUDA_HALF
118  __half cuhalf_;
119 #endif // MSHADOW_CUDA_HALF
120  };
121 
122  static MSHADOW_XINLINE half_t Binary(uint16_t value) {
123  half_t res;
124  res.half_ = value;
125  return res;
126  }
127 
128  MSHADOW_XINLINE half_t() {}
129 
130 #if MSHADOW_CUDA_HALF
131  MSHADOW_XINLINE explicit half_t(const __half& value) {
132  cuhalf_ = value;
133  }
134 #endif // MSHADOW_CUDA_HALF
135 
136  MSHADOW_XINLINE half_t(const float& value) { constructor(value); }
137  MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); }
138  MSHADOW_XINLINE explicit half_t(const int8_t& value) { constructor(value); }
139  MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); }
140  MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); }
141  MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); }
142  MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); }
143  MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); }
144 
146 
147  MSHADOW_HALF_ASSIGNOP(+=, +)
148  MSHADOW_HALF_ASSIGNOP(-=, -)
149  MSHADOW_HALF_ASSIGNOP(*=, *)
150  MSHADOW_HALF_ASSIGNOP(/=, /)
151 
152  MSHADOW_XINLINE half_t operator+() {
153  return *this;
154  }
155 
156  MSHADOW_XINLINE half_t operator-() {
157  return half_t(-float(*this)); // NOLINT(*)
158  }
159 
160  MSHADOW_XINLINE half_t operator=(const half_t& a) {
161  half_ = a.half_;
162  return a;
163  }
164 
165  template<typename T>
166  MSHADOW_XINLINE half_t operator=(const T& a) {
167  return *this = half_t(a); /* NOLINT(*)*/
168  }
169 
170  MSHADOW_XINLINE half_t operator=(const half_t& a) volatile {
171  half_ = a.half_;
172  return a;
173  }
174 
175  template<typename T>
176  MSHADOW_XINLINE half_t operator=(const T& a) volatile {
177  return *this = half_t(a); /* NOLINT(*)*/
178  }
179 
180  private:
181  union Bits {
182  float f;
183  int32_t si;
184  uint32_t ui;
185  };
186 
187  static int const fp16FractionBits = 10;
188  static int const fp32FractionBits = 23;
189  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
190  static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
191  static int const shift = fp32FractionBits - fp16FractionBits; // == 13
192  static int const shiftSign = 16;
193  static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
194 
195  static int32_t const infN = 0x7F800000; // flt32 infinity
196  static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift
197  static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
198  static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16
199  static int32_t const signN = 0x80000000; // flt32 sign bit
200 
201  static int32_t const infC = infN >> shift;
202  static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32
203  static int32_t const maxC = maxN >> shift;
204  static int32_t const minC = minN >> shift;
205  static int32_t const signC = signN >> shiftSign; // flt16 sign bit
206 
207  static int32_t const mulN = 0x52000000; // (1 << 23) / minN
208  static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))
209 
210  static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted
211  static int32_t const norC = 0x00400; // min flt32 normal down shifted
212 
213  static int32_t const maxD = infC - maxC - 1;
214  static int32_t const minD = minC - subC - 1;
215 
216  MSHADOW_XINLINE uint16_t float2half(const float& value) const {
217  Bits v;
218  v.f = value;
219  uint32_t sign = v.si & signN; // grab sign bit
220  v.si ^= sign; // clear sign bit from v
221  sign >>= shiftSign; // logical shift sign to fp16 position
222 
223  if (v.si <= maxZ) {
224  // Handle eventual zeros here to ensure vshift will not exceed 32 below.
225  v.ui = 0;
226  } else if (v.si < minN) {
227  // Handle denorms
228  uint32_t exp32 = v.ui >> fp32FractionBits;
229  int32_t exp16 = exp32 - expAdjust;
230  // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
231  // Smaller (so negative) exp16 values should result in greater right shifts.
232  uint32_t vshift = 1 - exp16;
233  uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
234  v.ui = significand >> vshift;
235  // The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
236  // when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
237  // bits to the right of the lsb are 1000... (including flt32 significand bits
238  // that may be lost during the above vshift). The first term below will always
239  // be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
240  // right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
241  // the proper test of the flt32 significand bits, including those lost during the vshift.
242 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
243  // Rounding may increase the exponent to 1, but that's OK.
244  v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
245 #endif
246  } else if (v.si <= maxN) {
247  // Handle norms
248 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
249  // Rounding may increase the exponent, possibly creating an inf, but that's OK.
250  v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
251 #endif
252  v.ui -= expAdjust << fp32FractionBits;
253  } else if (v.si <= infN) {
254  v.si = infN;
255  } else if (v.si < nanN) {
256  v.si = nanN;
257  }
258 
259  v.ui >>= shift;
260  return sign | (v.ui & 0x7fff);
261  }
262 
263  // Same as above routine, except for addition of volatile keyword
264  MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*)
265  Bits v;
266  v.f = value;
267  uint32_t sign = v.si & signN; // grab sign bit
268  v.si ^= sign; // clear sign bit from v
269  sign >>= shiftSign; // logical shift sign to fp16 position
270 
271  if (v.si <= maxZ) {
272  // Handle eventual zeros here to ensure vshift will not exceed 32 below.
273  v.ui = 0;
274  } else if (v.si < minN) {
275  // Handle denorms
276  uint32_t exp32 = v.ui >> fp32FractionBits;
277  int32_t exp16 = exp32 - expAdjust;
278  // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
279  // Smaller (so negative) exp16 values should result in greater right shifts.
280  uint32_t vshift = 1 - exp16;
281  uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
282  v.ui = significand >> vshift;
283 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
284  // Rounding may increase the exponent to 1, but that's OK.
285  v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
286 #endif
287  } else if (v.si <= maxN) {
288  // Handle norms
289 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
290  // Rounding may increase the exponent, possibly creating an inf, but that's OK.
291  v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
292 #endif
293  v.ui -= expAdjust << fp32FractionBits;
294  } else if (v.si <= infN) {
295  v.si = infN;
296  } else if (v.si < nanN) {
297  v.si = nanN;
298  }
299 
300  v.ui >>= shift;
301  return sign | (v.ui & 0x7fff);
302  }
303 
304  MSHADOW_XINLINE float half2float(const uint16_t& value) const {
305  Bits v;
306  v.ui = value;
307  int32_t sign = v.si & signC;
308  v.si ^= sign;
309  sign <<= shiftSign;
310  v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
311  v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
312  Bits s;
313  s.si = mulC;
314  s.f *= v.si;
315  int32_t mask = -(norC > v.si);
316  v.si <<= shift;
317  v.si ^= (s.si ^ v.si) & mask;
318  v.si |= sign;
319  return v.f;
320  }
321 
322  MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile { // NOLINT(*)
323  Bits v;
324  v.ui = value;
325  int32_t sign = v.si & signC;
326  v.si ^= sign;
327  sign <<= shiftSign;
328  v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
329  v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
330  Bits s;
331  s.si = mulC;
332  s.f *= v.si;
333  int32_t mask = -(norC > v.si);
334  v.si <<= shift;
335  v.si ^= (s.si ^ v.si) & mask;
336  v.si |= sign;
337  return v.f;
338  }
339 
340  template<typename T>
341  MSHADOW_XINLINE void constructor(const T& value) {
342 #if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
343  cuhalf_ = __float2half(float(value)); // NOLINT(*)
344 #elif(MSHADOW_USE_F16C)
345  half_ = _cvtss_sh(static_cast<float>(value), 0);
346 #else /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
347  half_ = float2half(float(value)); // NOLINT(*)
348 #endif /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
349  }
350 };
351 
353 MSHADOW_HALF_OPERATOR(half_t, +)
355 MSHADOW_HALF_OPERATOR(half_t, -)
357 MSHADOW_HALF_OPERATOR(half_t, *)
359 MSHADOW_HALF_OPERATOR(half_t, /)
361 MSHADOW_HALF_OPERATOR(bool, >)
363 MSHADOW_HALF_OPERATOR(bool, <)
365 MSHADOW_HALF_OPERATOR(bool, >=)
367 MSHADOW_HALF_OPERATOR(bool, <=)
368 
// Limits and bit masks for half_t. The MIN/MAX macros expand to plain
// expressions (no trailing semicolon) so they can be used anywhere a value is
// expected, e.g. as a function argument.
/*! \brief Most negative finite half_t (-65504), bit pattern 0xFBFF. */
#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF)
/*! \brief Largest finite half_t (65504), bit pattern 0x7BFF. */
#define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF)
/*! \brief Mask selecting the sign bit of a half_t bit pattern. */
#define MSHADOW_HALF_SIGN_BIT 0x8000
/*! \brief Mask selecting the five exponent bits of a half_t bit pattern. */
#define MSHADOW_HALF_EXPONENT_BITS 0x7c00
373 } // namespace half
374 } // namespace mshadow
375 #endif // MSHADOW_HALF_H_
class MSHADOW_ALIGNED(2) half_t
Definition: half.h:113
#define MSHADOW_HALF_CONVERSIONOP(T)
Definition: half.h:96
#define MSHADOW_HALF_ASSIGNOP(AOP, OP)
Definition: half.h:77
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b)
overloaded + operator for half2_t
Definition: half2.h:107
#define MSHADOW_XINLINE
Definition: base.h:223
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b)
overloaded - operator for half2_t
Definition: half2.h:116
MaskExp< IndexExp, SrcExp, DType > mask(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src)
Definition: mask.h:58
#define MSHADOW_HALF_OPERATOR(RTYPE, OP)
Definition: half.h:64
overloaded + operator between half_t and bf16_t
Definition: base.h:327