mxnet
half.h
Go to the documentation of this file.
1 
8 #ifndef MSHADOW_HALF_H_
9 #define MSHADOW_HALF_H_
10 #include "./base.h"
11 
12 #if MSHADOW_USE_F16C
13  #include <x86intrin.h>
14 #endif // MSHADOW_USE_F16C
15 
16 // This flag dictates rounding for the float2half() routine only (used generally on Windows),
17 // not the f16c lib or cuda v7.5 (or later) behavior which is fixed at round-to-nearest-even.
18 #ifndef MSHADOW_HALF_ROUND_TO_NEAREST
19 #define MSHADOW_HALF_ROUND_TO_NEAREST 1
20 #endif
21 
22 #if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
23  #define MSHADOW_CUDA_HALF 1
24  #include <cuda_fp16.h>
25  #if defined(__CUDA_ARCH__)
26 
/*!
 * \brief Convert a volatile-qualified __half to float.
 *
 * __half2float() is called on a plain local copy, so the value is first copied
 * out of the volatile reference.
 *
 * Declared `inline` because this is a header: without it every translation
 * unit that includes this file would emit its own external definition of the
 * function.
 *
 * \param h half value to convert, accessed through a const volatile reference
 * \return the value of h as a float
 */
inline __host__ __device__ float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */
  __half val;
#if CUDA_VERSION >= 9000
  // CUDA >= 9.0 path: strip the cv-qualifiers and copy the whole object.
  val = const_cast<__half&>(h);
#else
  // Older CUDA path: copy the `x` member directly.
  val.x = h.x;
#endif
  return __half2float(val);
}
36  #endif
37 #else
38  #define MSHADOW_CUDA_HALF 0
39 #endif
40 
42 namespace mshadow {
43 /* \brief name space for host/device portable half-precision floats */
44 namespace half {
/*
 * Emits three free-function overloads of the binary operator OP:
 *   (half_t, half_t), (half_t, T) and (T, half_t).
 * Each overload converts both operands to float, applies OP on the floats and
 * converts the outcome to RESULT_T.
 */
#define MSHADOW_HALF_OPERATOR(RESULT_T, OP)                                  \
  MSHADOW_XINLINE RESULT_T operator OP (half_t lhs, half_t rhs) {            \
    return RESULT_T(static_cast<float>(lhs) OP static_cast<float>(rhs));     \
  }                                                                          \
  template<typename T>                                                       \
  MSHADOW_XINLINE RESULT_T operator OP (half_t lhs, T rhs) {                 \
    return RESULT_T(static_cast<float>(lhs) OP static_cast<float>(rhs));     \
  }                                                                          \
  template<typename T>                                                       \
  MSHADOW_XINLINE RESULT_T operator OP (T lhs, half_t rhs) {                 \
    return RESULT_T(static_cast<float>(lhs) OP static_cast<float>(rhs));     \
  }
57 
/*
 * Emits the compound-assignment member operator AOP (e.g. +=) for half_t, in a
 * plain and a volatile flavour. The current value and the right-hand operand
 * are converted to float, combined with OP, and the float result is stored
 * back into *this as a half_t. The operator returns half_t by value.
 * Must be expanded inside the half_t class body.
 */
#define MSHADOW_HALF_ASSIGNOP(AOP, OP)                                          \
  template<typename T>                                                          \
  MSHADOW_XINLINE half_t operator AOP (const T& rhs) {                          \
    const float updated = static_cast<float>(*this) OP static_cast<float>(rhs); \
    return *this = half_t(updated);                                             \
  }                                                                             \
  template<typename T>                                                          \
  MSHADOW_XINLINE half_t operator AOP (const volatile T& rhs) volatile {        \
    const float updated = static_cast<float>(*this) OP static_cast<float>(rhs); \
    return *this = half_t(updated);                                             \
  }
67 
/*
 * MSHADOW_HALF_CONVERSIONOP(T) emits `operator T()` for half_t, once for const
 * objects and once for const volatile objects. The stored half is first
 * widened to float and the float is then converted to T. Which routine does
 * the widening depends on the build:
 *   - CUDA device pass with native half support: __half2float on cuhalf_
 *     (__half2float_warp for the volatile flavour),
 *   - F16C enabled: the _cvtsh_ss intrinsic on half_,
 *   - otherwise: the half2float() member routine on half_.
 * Must be expanded inside the half_t class body.
 */
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
#define MSHADOW_HALF_CONVERSIONOP(T)                          \
  MSHADOW_XINLINE operator T() const {                        \
    return static_cast<T>(__half2float(cuhalf_));             \
  }                                                           \
  MSHADOW_XINLINE operator T() const volatile {               \
    return static_cast<T>(__half2float_warp(cuhalf_));        \
  }
#elif MSHADOW_USE_F16C
#define MSHADOW_HALF_CONVERSIONOP(T)                          \
  MSHADOW_XINLINE operator T() const {                        \
    return static_cast<T>(_cvtsh_ss(half_));                  \
  }                                                           \
  MSHADOW_XINLINE operator T() const volatile {               \
    return static_cast<T>(_cvtsh_ss(half_));                  \
  }
#else
#define MSHADOW_HALF_CONVERSIONOP(T)                          \
  MSHADOW_XINLINE operator T() const {                        \
    return static_cast<T>(half2float(half_));                 \
  }                                                           \
  MSHADOW_XINLINE operator T() const volatile {               \
    return static_cast<T>(half2float(half_));                 \
  }
#endif  // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
93 
94 class MSHADOW_ALIGNED(2) half_t {
95  public:
96  union {
97  uint16_t half_;
98 #if MSHADOW_CUDA_HALF
99  __half cuhalf_;
100 #endif // MSHADOW_CUDA_HALF
101  };
102 
103  static MSHADOW_XINLINE half_t Binary(uint16_t value) {
104  half_t res;
105  res.half_ = value;
106  return res;
107  }
108 
109  MSHADOW_XINLINE half_t() {}
110 
111 #if MSHADOW_CUDA_HALF
112  MSHADOW_XINLINE explicit half_t(const __half& value) {
113  cuhalf_ = value;
114  }
115 #endif // MSHADOW_CUDA_HALF
116 
117  MSHADOW_XINLINE half_t(const float& value) { constructor(value); }
118  MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); }
119  MSHADOW_XINLINE explicit half_t(const int8_t& value) { constructor(value); }
120  MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); }
121  MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); }
122  MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); }
123  MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); }
124  MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); }
125 
127 
128  MSHADOW_HALF_ASSIGNOP(+=, +)
129  MSHADOW_HALF_ASSIGNOP(-=, -)
130  MSHADOW_HALF_ASSIGNOP(*=, *)
131  MSHADOW_HALF_ASSIGNOP(/=, /)
132 
133  MSHADOW_XINLINE half_t operator+() {
134  return *this;
135  }
136 
137  MSHADOW_XINLINE half_t operator-() {
138  return half_t(-float(*this)); // NOLINT(*)
139  }
140 
141  MSHADOW_XINLINE half_t operator=(const half_t& a) {
142  half_ = a.half_;
143  return a;
144  }
145 
146  template<typename T>
147  MSHADOW_XINLINE half_t operator=(const T& a) {
148  return *this = half_t(a); /* NOLINT(*)*/
149  }
150 
151  MSHADOW_XINLINE half_t operator=(const half_t& a) volatile {
152  half_ = a.half_;
153  return a;
154  }
155 
156  template<typename T>
157  MSHADOW_XINLINE half_t operator=(const T& a) volatile {
158  return *this = half_t(a); /* NOLINT(*)*/
159  }
160 
161  private:
162  union Bits {
163  float f;
164  int32_t si;
165  uint32_t ui;
166  };
167 
168  static int const fp16FractionBits = 10;
169  static int const fp32FractionBits = 23;
170  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff
171  static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000
172  static int const shift = fp32FractionBits - fp16FractionBits; // == 13
173  static int const shiftSign = 16;
174  static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
175 
176  static int32_t const infN = 0x7F800000; // flt32 infinity
177  static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift
178  static int32_t const minN = 0x38800000; // min flt16 normal as a flt32
179  static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16
180  static int32_t const signN = 0x80000000; // flt32 sign bit
181 
182  static int32_t const infC = infN >> shift;
183  static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32
184  static int32_t const maxC = maxN >> shift;
185  static int32_t const minC = minN >> shift;
186  static int32_t const signC = signN >> shiftSign; // flt16 sign bit
187 
188  static int32_t const mulN = 0x52000000; // (1 << 23) / minN
189  static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))
190 
191  static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted
192  static int32_t const norC = 0x00400; // min flt32 normal down shifted
193 
194  static int32_t const maxD = infC - maxC - 1;
195  static int32_t const minD = minC - subC - 1;
196 
197  MSHADOW_XINLINE uint16_t float2half(const float& value) const {
198  Bits v;
199  v.f = value;
200  uint32_t sign = v.si & signN; // grab sign bit
201  v.si ^= sign; // clear sign bit from v
202  sign >>= shiftSign; // logical shift sign to fp16 position
203 
204  if (v.si <= maxZ) {
205  // Handle eventual zeros here to ensure vshift will not exceed 32 below.
206  v.ui = 0;
207  } else if (v.si < minN) {
208  // Handle denorms
209  uint32_t exp32 = v.ui >> fp32FractionBits;
210  int32_t exp16 = exp32 - expAdjust;
211  // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
212  // Smaller (so negative) exp16 values should result in greater right shifts.
213  uint32_t vshift = 1 - exp16;
214  uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
215  v.ui = significand >> vshift;
216  // The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
217  // when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
218  // bits to the right of the lsb are 1000... (including flt32 significand bits
219  // that may be lost during the above vshift). The first term below will always
220  // be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
221  // right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
222  // the proper test of the flt32 significand bits, including those lost during the vshift.
223 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
224  // Rounding may increase the exponent to 1, but that's OK.
225  v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
226 #endif
227  } else if (v.si <= maxN) {
228  // Handle norms
229 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
230  // Rounding may increase the exponent, possibly creating an inf, but that's OK.
231  v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
232 #endif
233  v.ui -= expAdjust << fp32FractionBits;
234  } else if (v.si <= infN) {
235  v.si = infN;
236  } else if (v.si < nanN) {
237  v.si = nanN;
238  }
239 
240  v.ui >>= shift;
241  return sign | (v.ui & 0x7fff);
242  }
243 
244  // Same as above routine, except for addition of volatile keyword
245  MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*)
246  Bits v;
247  v.f = value;
248  uint32_t sign = v.si & signN; // grab sign bit
249  v.si ^= sign; // clear sign bit from v
250  sign >>= shiftSign; // logical shift sign to fp16 position
251 
252  if (v.si <= maxZ) {
253  // Handle eventual zeros here to ensure vshift will not exceed 32 below.
254  v.ui = 0;
255  } else if (v.si < minN) {
256  // Handle denorms
257  uint32_t exp32 = v.ui >> fp32FractionBits;
258  int32_t exp16 = exp32 - expAdjust;
259  // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
260  // Smaller (so negative) exp16 values should result in greater right shifts.
261  uint32_t vshift = 1 - exp16;
262  uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
263  v.ui = significand >> vshift;
264 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
265  // Rounding may increase the exponent to 1, but that's OK.
266  v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
267 #endif
268  } else if (v.si <= maxN) {
269  // Handle norms
270 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
271  // Rounding may increase the exponent, possibly creating an inf, but that's OK.
272  v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
273 #endif
274  v.ui -= expAdjust << fp32FractionBits;
275  } else if (v.si <= infN) {
276  v.si = infN;
277  } else if (v.si < nanN) {
278  v.si = nanN;
279  }
280 
281  v.ui >>= shift;
282  return sign | (v.ui & 0x7fff);
283  }
284 
285  MSHADOW_XINLINE float half2float(const uint16_t& value) const {
286  Bits v;
287  v.ui = value;
288  int32_t sign = v.si & signC;
289  v.si ^= sign;
290  sign <<= shiftSign;
291  v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
292  v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
293  Bits s;
294  s.si = mulC;
295  s.f *= v.si;
296  int32_t mask = -(norC > v.si);
297  v.si <<= shift;
298  v.si ^= (s.si ^ v.si) & mask;
299  v.si |= sign;
300  return v.f;
301  }
302 
303  MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile { // NOLINT(*)
304  Bits v;
305  v.ui = value;
306  int32_t sign = v.si & signC;
307  v.si ^= sign;
308  sign <<= shiftSign;
309  v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
310  v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
311  Bits s;
312  s.si = mulC;
313  s.f *= v.si;
314  int32_t mask = -(norC > v.si);
315  v.si <<= shift;
316  v.si ^= (s.si ^ v.si) & mask;
317  v.si |= sign;
318  return v.f;
319  }
320 
321  template<typename T>
322  MSHADOW_XINLINE void constructor(const T& value) {
323 #if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
324  cuhalf_ = __float2half(float(value)); // NOLINT(*)
325 #elif(MSHADOW_USE_F16C)
326  half_ = _cvtss_sh(static_cast<float>(value), 0);
327 #else /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
328  half_ = float2half(float(value)); // NOLINT(*)
329 #endif /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
330  }
331 };
332 
/* Arithmetic operators: operands are converted to float, the result is a half_t. */
334 MSHADOW_HALF_OPERATOR(half_t, +)
336 MSHADOW_HALF_OPERATOR(half_t, -)
338 MSHADOW_HALF_OPERATOR(half_t, *)
340 MSHADOW_HALF_OPERATOR(half_t, /)
/* Ordering comparisons: operands are compared as floats, the result is a bool. */
342 MSHADOW_HALF_OPERATOR(bool, >)
344 MSHADOW_HALF_OPERATOR(bool, <)
346 MSHADOW_HALF_OPERATOR(bool, >=)
348 MSHADOW_HALF_OPERATOR(bool, <=)
349 
/*
 * Constants for half_t. MSHADOW_HALF_MIN and MSHADOW_HALF_MAX deliberately
 * carry no trailing semicolon, so they can be used inside larger expressions
 * and as function arguments as well as in plain statements.
 */
/*! \brief half_t with bit pattern 0xFBFF: sign bit set, otherwise identical to MSHADOW_HALF_MAX */
#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF)
/*! \brief half_t with bit pattern 0x7BFF: largest exponent below the all-ones one, all fraction bits set */
#define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF)
/*! \brief mask selecting the sign bit of a half bit pattern */
#define MSHADOW_HALF_SIGN_BIT 0x8000
/*! \brief mask selecting the exponent bits of a half bit pattern */
#define MSHADOW_HALF_EXPONENT_BITS 0x7c00
354 } // namespace half
355 } // namespace mshadow
356 #endif // MSHADOW_HALF_H_
class MSHADOW_ALIGNED(2) half_t
Definition: half.h:94
#define MSHADOW_HALF_CONVERSIONOP(T)
Definition: half.h:77
#define MSHADOW_HALF_ASSIGNOP(AOP, OP)
Definition: half.h:58
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b)
overloaded + operator for half2_t
Definition: half2.h:88
#define MSHADOW_XINLINE
Definition: base.h:204
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b)
overloaded - operator for half2_t
Definition: half2.h:97
MaskExp< IndexExp, SrcExp, DType > mask(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src)
Definition: mask.h:39
#define MSHADOW_HALF_OPERATOR(RTYPE, OP)
Definition: half.h:45
namespace for mshadow
Definition: base.h:282