26 #ifndef MSHADOW_HALF_H_
27 #define MSHADOW_HALF_H_
31 #include <x86intrin.h>
32 #endif // MSHADOW_USE_F16C
36 #ifndef MSHADOW_HALF_ROUND_TO_NEAREST
37 #define MSHADOW_HALF_ROUND_TO_NEAREST 1
40 #if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
41 #define MSHADOW_CUDA_HALF 1
42 #include <cuda_fp16.h>
43 #if defined(__CUDA_ARCH__)
47 #if CUDA_VERSION >= 9000
48 val =
const_cast<__half&
>(h);
52 return __half2float(val);
56 #define MSHADOW_CUDA_HALF 0
// Generates two-operand overloads of `operator OP` that return RTYPE:
//   (half_t, half_t), (half_t, T) and (T, half_t), the last two templated on T.
// In every overload both operands are converted with float(...), OP is applied
// to the two floats, and the result is passed to RTYPE(...).
// T therefore has to be convertible to float.
63 #define MSHADOW_HALF_OPERATOR(RTYPE, OP) \
64 MSHADOW_XINLINE RTYPE operator OP (half_t a, half_t b) { \
65 return RTYPE(float(a) OP float(b)); \
67 template<typename T> \
68 MSHADOW_XINLINE RTYPE operator OP (half_t a, T b) { \
69 return RTYPE(float(a) OP float(b)); \
71 template<typename T> \
72 MSHADOW_XINLINE RTYPE operator OP (T a, half_t b) { \
73 return RTYPE(float(a) OP float(b)); \
// Generates the member compound-assignment operator `operator AOP` in two
// templated forms: one for ordinary objects taking `const T&`, and one
// volatile-qualified form taking `const volatile T&`.
// Both compute float(*this) OP float(a), wrap the result in half_t, assign it
// to *this, and return the assigned value by value (return type is half_t,
// not half_t&).
76 #define MSHADOW_HALF_ASSIGNOP(AOP, OP) \
77 template<typename T> \
78 MSHADOW_XINLINE half_t operator AOP (const T& a) { \
79 return *this = half_t(float(*this) OP float(a)); \
81 template<typename T> \
82 MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile { \
83 return *this = half_t(float(*this) OP float(a)); \
// MSHADOW_HALF_CONVERSIONOP(T) generates `operator T()` for const and for
// const volatile objects. Each variant first converts the stored half to
// float and then constructs T from that float. Which half-to-float routine
// is used is chosen by the preprocessor:
//   - MSHADOW_CUDA_HALF while compiling device code (__CUDA_ARCH__ defined):
//     __half2float(cuhalf_), and __half2float_warp(cuhalf_) for the volatile form;
//   - MSHADOW_USE_F16C: _cvtsh_ss(half_);
//   - the remaining definition: half2float(half_).
86 #if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
// Device-side variant: reads the cuhalf_ member.
87 #define MSHADOW_HALF_CONVERSIONOP(T) \
88 MSHADOW_XINLINE operator T() const { \
89 return T(__half2float(cuhalf_)); \
91 MSHADOW_XINLINE operator T() const volatile { \
92 return T(__half2float_warp(cuhalf_)); \
94 #elif(MSHADOW_USE_F16C)
// F16C variant: reads the half_ member through the _cvtsh_ss intrinsic.
95 #define MSHADOW_HALF_CONVERSIONOP(T) \
96 MSHADOW_XINLINE operator T() const { \
97 return T(_cvtsh_ss(half_)); \
99 MSHADOW_XINLINE operator T() const volatile { \
100 return T(_cvtsh_ss(half_)); \
103 #define MSHADOW_HALF_CONVERSIONOP(T) \
104 MSHADOW_XINLINE operator T() const { \
105 return T(half2float(half_)); \
107 MSHADOW_XINLINE operator T() const volatile { \
108 return T(half2float(half_)); \
110 #endif // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
116 #if MSHADOW_CUDA_HALF
118 #endif // MSHADOW_CUDA_HALF
129 #if MSHADOW_CUDA_HALF
133 #endif // MSHADOW_CUDA_HALF
// Explicit converting constructors from built-in arithmetic types (double and
// the 8/32/64-bit signed and unsigned integers). Each one only forwards its
// argument to constructor(value); being `explicit`, none of them takes part in
// implicit conversions to half_t.
136 MSHADOW_XINLINE explicit half_t(
const double& value) { constructor(value); }
137 MSHADOW_XINLINE explicit half_t(
const int8_t& value) { constructor(value); }
138 MSHADOW_XINLINE explicit half_t(
const uint8_t& value) { constructor(value); }
139 MSHADOW_XINLINE explicit half_t(
const int32_t& value) { constructor(value); }
140 MSHADOW_XINLINE explicit half_t(
const uint32_t& value) { constructor(value); }
141 MSHADOW_XINLINE explicit half_t(
const int64_t& value) { constructor(value); }
142 MSHADOW_XINLINE explicit half_t(
const uint64_t& value) { constructor(value); }
156 return half_t(-
float(*
this));
166 return *
this = half_t(a);
176 return *
this = half_t(a);
// Constants used by the software float2half / half2float routines.
// Suffix N marks values expressed as binary32 bit patterns; suffix C marks
// values obtained by shifting such a pattern down (by `shift`, or by
// `shiftSign` for the sign mask).

// Number of fraction bits in the 16-bit and 32-bit formats.
186 static int const fp16FractionBits = 10;
187 static int const fp32FractionBits = 23;
// Low 23 bits set (0x007FFFFF): selects the binary32 fraction field.
188 static int32_t
const fp32FractionMask = ~(~0u << fp32FractionBits);
// Bit 23 (0x00800000): OR-ed onto the fraction to form the full significand.
189 static int32_t
const fp32HiddenBit = 1 << fp32FractionBits;
// 23 - 10 = 13: difference between the two fraction widths.
190 static int const shift = fp32FractionBits - fp16FractionBits;
// Shift distance applied to signN to derive signC.
191 static int const shiftSign = 16;
// 112: difference between the binary32 exponent bias (127) and the binary16 bias (15).
192 static int32_t
const expAdjust = 127 - 15;
// Bit pattern of binary32 +infinity.
194 static int32_t
const infN = 0x7F800000;
// Largest binary32 pattern below 2^16; upper bound of the normal-result branch in float2half.
195 static int32_t
const maxN = 0x477FFFFF;
// Bit pattern of 2^-14; values below it take the subnormal-result branch in float2half.
196 static int32_t
const minN = 0x38800000;
// Bit pattern of 2^-25. Its use is on lines not shown here — presumably the
// threshold at or below which inputs become zero; TODO confirm.
197 static int32_t
const maxZ = 0x33000000;
// binary32 sign-bit mask.
// NOTE(review): 0x80000000 does not fit in a positive int32_t, so this relies
// on the implementation's unsigned-to-signed conversion — confirm this is intended.
198 static int32_t
const signN = 0x80000000;
// infN shifted down by `shift` (0x3FC00).
200 static int32_t
const infC = infN >> shift;
// (infC + 1) shifted back up (0x7F802000); used as an exclusive upper bound in float2half.
201 static int32_t
const nanN = (infC + 1) << shift;
// maxN shifted down by `shift` (0x23BFF); compared against in half2float.
202 static int32_t
const maxC = maxN >> shift;
// minN shifted down by `shift` (0x1C400).
203 static int32_t
const minC = minN >> shift;
// Sign mask used by half2float.
// NOTE(review): signN is negative where int32_t wraps, so this is a right
// shift of a negative value; with an arithmetic shift the result is
// 0xFFFF8000 rather than 0x8000 — confirm this is intended.
204 static int32_t
const signC = signN >> shiftSign;
// Bit patterns of 2^37 and 2^-24 as binary32. They are consumed on lines not
// shown here — presumably as scale factors for subnormal handling; TODO confirm.
206 static int32_t
const mulN = 0x52000000;
207 static int32_t
const mulC = 0x33800000;
// 0x3FF: all ten fraction bits set with a zero exponent field.
// half2float applies the minD adjustment only to values above this.
209 static int32_t
const subC = 0x003FF;
// 0x400: lowest pattern with a non-zero exponent field.
// half2float selects the alternate result `s` for values below this.
210 static int32_t
const norC = 0x00400;
// Offsets added conditionally in half2float; both evaluate to 0x1C000,
// i.e. expAdjust << fp16FractionBits.
212 static int32_t
const maxD = infC - maxC - 1;
213 static int32_t
const minD = minC - subC - 1;
218 uint32_t sign = v.si & signN;
225 }
else if (v.si < minN) {
227 uint32_t exp32 = v.ui >> fp32FractionBits;
228 int32_t exp16 = exp32 - expAdjust;
231 uint32_t vshift = 1 - exp16;
232 uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
233 v.ui = significand >> vshift;
241 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
243 v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
245 }
else if (v.si <= maxN) {
247 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
249 v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
251 v.ui -= expAdjust << fp32FractionBits;
252 }
else if (v.si <= infN) {
254 }
else if (v.si < nanN) {
259 return sign | (v.ui & 0x7fff);
// Software conversion of a float to a 16-bit half bit pattern; this is the
// overload callable on volatile objects with a volatile argument.
// `v` is declared on lines not shown here; it is read through both a signed
// (.si) and an unsigned (.ui) member — presumably a union over the float's
// bits; TODO confirm.
263 MSHADOW_XINLINE uint16_t float2half(
const volatile float& value)
const volatile {
// Isolate the binary32 sign bit.
266 uint32_t sign = v.si & signN;
273 }
// Magnitude below minN (2^-14): the result has no room for a normal exponent.
else if (v.si < minN) {
// Exponent field of v, then the same exponent re-expressed with the 16-bit bias.
275 uint32_t exp32 = v.ui >> fp32FractionBits;
276 int32_t exp16 = exp32 - expAdjust;
// Right-shift distance needed to bring the significand down to exponent 1.
279 uint32_t vshift = 1 - exp16;
// Full 24-bit significand: fraction field plus the hidden leading bit.
280 uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
281 v.ui = significand >> vshift;
282 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// Adds 0x1000 (bit 12) unless the low 14 bits of v.ui are exactly 0x1000 AND
// the low 11 bits of the unshifted significand are all zero.
// Presumably round-half-to-even ahead of a later 13-bit downshift that is
// not visible here; TODO confirm.
284 v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
286 }
// Magnitude up to maxN: normal-range result.
else if (v.si <= maxN) {
288 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
// Adds 0x1000 (bit 12) unless the low 14 bits of v.ui are exactly 0x1000.
290 v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
// Re-bias the exponent field from 127 to 15.
292 v.ui -= expAdjust << fp32FractionBits;
293 }
// Magnitude up to the +infinity pattern (branch body not visible here).
else if (v.si <= infN) {
295 }
// Patterns above infinity but below nanN (branch body not visible here).
else if (v.si < nanN) {
// Combine the sign with the low 15 bits of the converted magnitude.
300 return sign | (v.ui & 0x7fff);
306 int32_t sign = v.si & signC;
309 v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
310 v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
314 int32_t
mask = -(norC > v.si);
316 v.si ^= (s.si ^ v.si) &
mask;
// Software conversion of a 16-bit half bit pattern to float; this is the
// overload callable on volatile objects with a volatile argument.
// The visible steps are branch-free: -(condition) is 0 when the condition is
// false and all-ones when it is true, and `x ^= ((x + d) ^ x) & m` leaves x
// unchanged for m == 0 and replaces it with x + d for m == all-ones.
// `v` and `s` are declared on lines not shown here.
321 MSHADOW_XINLINE float half2float(
const volatile uint16_t& value)
const volatile {
// Isolate the sign using the shifted-down sign mask.
324 int32_t sign = v.si & signC;
// Add minD when v.si is greater than subC.
327 v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
// Add maxD as well when v.si is (now) greater than maxC.
328 v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
// All-ones when v.si is below norC, otherwise zero.
332 int32_t
mask = -(norC > v.si);
// Under that mask, replace v.si with the alternate value held in s.si.
334 v.si ^= (s.si ^ v.si) &
mask;
341 #if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
342 cuhalf_ = __float2half(
float(value));
343 #elif(MSHADOW_USE_F16C)
344 half_ = _cvtss_sh(
static_cast<float>(value), 0);
346 half_ = float2half(
float(value));
// Limits and bit masks for the 16-bit half format.
//
// MSHADOW_HALF_MIN / MSHADOW_HALF_MAX build a half_t directly from a bit
// pattern through half_t::Binary:
//   0xFBFF - sign bit set, exponent field 0x1E, all fraction bits set (-65504);
//   0x7BFF - the same magnitude with the sign bit clear (+65504).
// Both are deliberately defined WITHOUT a trailing semicolon so they expand to
// a plain expression: they can be used as function arguments, in comparisons
// and in initializer lists, and `x = MSHADOW_HALF_MIN;` still works as usual.
#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF)
#define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF)
// Mask selecting the sign bit (bit 15) of a half bit pattern.
#define MSHADOW_HALF_SIGN_BIT 0x8000
// Mask selecting the five exponent bits (bits 10-14) of a half bit pattern.
#define MSHADOW_HALF_EXPONENT_BITS 0x7c00
374 #endif // MSHADOW_HALF_H_