27 #ifndef MSHADOW_HALF_H_ 28 #define MSHADOW_HALF_H_ 32 #include <x86intrin.h> 33 #endif // MSHADOW_USE_F16C 37 #ifndef MSHADOW_HALF_ROUND_TO_NEAREST 38 #define MSHADOW_HALF_ROUND_TO_NEAREST 1 41 #if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) 42 #define MSHADOW_CUDA_HALF 1 43 #include <cuda_fp16.h> 44 #if defined(__CUDA_ARCH__) 46 __host__ __device__
float __half2float_warp(
const volatile __half& h) {
48 #if CUDA_VERSION >= 9000 49 val =
const_cast<__half&
>(h);
53 return __half2float(val);
57 #define MSHADOW_CUDA_HALF 0 64 #define MSHADOW_HALF_OPERATOR(RTYPE, OP) \ 65 MSHADOW_XINLINE RTYPE operator OP (half_t a, half_t b) { \ 66 return RTYPE(float(a) OP float(b)); \ 68 template<typename T> \ 69 MSHADOW_XINLINE RTYPE operator OP (half_t a, T b) { \ 70 return RTYPE(float(a) OP float(b)); \ 72 template<typename T> \ 73 MSHADOW_XINLINE RTYPE operator OP (T a, half_t b) { \ 74 return RTYPE(float(a) OP float(b)); \ 77 #define MSHADOW_HALF_ASSIGNOP(AOP, OP) \ 78 template<typename T> \ 79 MSHADOW_XINLINE half_t operator AOP (const T& a) { \ 80 return *this = half_t(float(*this) OP float(a)); \ 82 template<typename T> \ 83 MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile { \ 84 return *this = half_t(float(*this) OP float(a)); \ 87 #if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) 88 #define MSHADOW_HALF_CONVERSIONOP(T) \ 89 MSHADOW_XINLINE operator T() const { \ 90 return T(__half2float(cuhalf_)); \ 92 MSHADOW_XINLINE operator T() const volatile { \ 93 return T(__half2float_warp(cuhalf_)); \ 95 #elif(MSHADOW_USE_F16C) 96 #define MSHADOW_HALF_CONVERSIONOP(T) \ 97 MSHADOW_XINLINE operator T() const { \ 98 return T(_cvtsh_ss(half_)); \ 100 MSHADOW_XINLINE operator T() const volatile { \ 101 return T(_cvtsh_ss(half_)); \ 104 #define MSHADOW_HALF_CONVERSIONOP(T) \ 105 MSHADOW_XINLINE operator T() const { \ 106 return T(half2float(half_)); \ 108 MSHADOW_XINLINE operator T() const volatile { \ 109 return T(half2float(half_)); \ 111 #endif // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) 117 #if MSHADOW_CUDA_HALF 119 #endif // MSHADOW_CUDA_HALF 130 #if MSHADOW_CUDA_HALF 134 #endif // MSHADOW_CUDA_HALF 137 MSHADOW_XINLINE explicit half_t(
const double& value) { constructor(value); }
// Explicit converting constructors from the fixed-width integer types
// int8_t, uint8_t, int32_t, uint32_t, int64_t and uint64_t.
// Each overload takes its argument by const reference and forwards it
// unchanged to constructor(value). Being `explicit`, none of them takes part
// in implicit conversions, so an integer must be wrapped as half_t(x) by hand.
// NOTE(review): constructor() is only partly visible in this excerpt; the
// visible fragments cast the value to float before converting to half, so
// large 32/64-bit inputs presumably lose precision on the way - confirm.
138 MSHADOW_XINLINE explicit half_t(
const int8_t& value) { constructor(value); }
139 MSHADOW_XINLINE explicit half_t(
const uint8_t& value) { constructor(value); }
140 MSHADOW_XINLINE explicit half_t(
const int32_t& value) { constructor(value); }
141 MSHADOW_XINLINE explicit half_t(
const uint32_t& value) { constructor(value); }
142 MSHADOW_XINLINE explicit half_t(
const int64_t& value) { constructor(value); }
143 MSHADOW_XINLINE explicit half_t(
const uint64_t& value) { constructor(value); }
157 return half_t(-
float(*
this));
167 return *
this = half_t(a);
177 return *
this = half_t(a);
// Compile-time constants shared by the software float2half / half2float
// routines of this class. Suffix convention, as far as the visible uses show:
// names ending in N are compared against or combined with the 32-bit word
// (v.si / v.ui) in float2half; names ending in C are used on the 16-bit-side
// value in half2float. Derived values below are worked out from the literals.
// Number of fraction (mantissa) bits kept in the 16-bit format.
187 static int const fp16FractionBits = 10;
// Number of fraction (mantissa) bits in the 32-bit format.
188 static int const fp32FractionBits = 23;
// Low 23 bits set (0x007FFFFF): selects the fraction field of a 32-bit word.
189 static int32_t
const fp32FractionMask = ~(~0u << fp32FractionBits);
// Bit 23 (0x00800000): OR-ed onto the masked fraction to form `significand`
// in the `v.si < minN` branch of float2half.
190 static int32_t
const fp32HiddenBit = 1 << fp32FractionBits;
// 23 - 10 = 13: difference in fraction width between the two formats.
191 static int const shift = fp32FractionBits - fp16FractionBits;
// Shift distance used to derive signC from signN (32-bit sign -> bit 15).
192 static int const shiftSign = 16;
// 127 - 15 = 112: subtracted from the 32-bit exponent field to obtain exp16,
// and removed (shifted into exponent position) in the `v.si <= maxN` branch.
193 static int32_t
const expAdjust = 127 - 15;
// 0x7F800000: read as a binary32 bit pattern this is +infinity.
195 static int32_t
const infN = 0x7F800000;
// 0x477FFFFF: read as binary32, the largest value below 2^16. Upper bound of
// the branch that rounds and re-biases the exponent.
196 static int32_t
const maxN = 0x477FFFFF;
// 0x38800000: read as binary32, 2^-14. Inputs below it take the branch that
// shifts the significand right by (1 - exp16).
197 static int32_t
const minN = 0x38800000;
// 0x33000000: read as binary32, 2^-25.
// NOTE(review): its use is outside the visible lines; presumably the
// threshold below which the result flushes to zero - confirm.
198 static int32_t
const maxZ = 0x33000000;
// Bit 31 mask; `v.si & signN` extracts the sign in float2half.
// NOTE(review): 0x80000000 does not fit a positive int32_t, so the stored
// value relies on the usual wrap to INT32_MIN.
199 static int32_t
const signN = 0x80000000;
// infN >> 13 = 0x3FC00.
201 static int32_t
const infC = infN >> shift;
// (0x3FC00 + 1) << 13 = 0x7F802000; upper bound tested by `v.si < nanN`.
202 static int32_t
const nanN = (infC + 1) << shift;
// maxN >> 13 = 0x23BFF; threshold for the maxD adjustment in half2float.
203 static int32_t
const maxC = maxN >> shift;
// minN >> 13 = 0x1C400.
204 static int32_t
const minC = minN >> shift;
// signN >> 16. With signN negative this is a right shift of a negative
// value: 0xFFFF8000 under an arithmetic shift (implementation-defined before
// C++20). half2float masks the input with it to pick out the sign bit.
205 static int32_t
const signC = signN >> shiftSign;
// 0x52000000: read as binary32, 2^37.
// NOTE(review): use not visible here; presumably a scale factor for the
// subnormal path - confirm.
207 static int32_t
const mulN = 0x52000000;
// 0x33800000: read as binary32, 2^-24.
// NOTE(review): use not visible here; presumably the inverse scale factor
// for the subnormal path - confirm.
208 static int32_t
const mulC = 0x33800000;
// 0x3FF: all ten fraction bits set; values above it get minD added.
210 static int32_t
const subC = 0x003FF;
// 0x400 = 1 << 10: `norC > v.si` builds the select mask in half2float.
211 static int32_t
const norC = 0x00400;
// 0x3FC00 - 0x23BFF - 1 = 0x1C000: added when v.si > maxC.
213 static int32_t
const maxD = infC - maxC - 1;
// 0x1C400 - 0x3FF - 1 = 0x1C000: added when v.si > subC.
214 static int32_t
const minD = minC - subC - 1;
219 uint32_t sign = v.si & signN;
226 }
else if (v.si < minN) {
228 uint32_t exp32 = v.ui >> fp32FractionBits;
229 int32_t exp16 = exp32 - expAdjust;
232 uint32_t vshift = 1 - exp16;
233 uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
234 v.ui = significand >> vshift;
242 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1 244 v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
246 }
else if (v.si <= maxN) {
248 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1 250 v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
252 v.ui -= expAdjust << fp32FractionBits;
253 }
else if (v.si <= infN) {
255 }
else if (v.si < nanN) {
260 return sign | (v.ui & 0x7fff);
264 MSHADOW_XINLINE uint16_t float2half(
const volatile float& value)
const volatile {
267 uint32_t sign = v.si & signN;
274 }
else if (v.si < minN) {
276 uint32_t exp32 = v.ui >> fp32FractionBits;
277 int32_t exp16 = exp32 - expAdjust;
280 uint32_t vshift = 1 - exp16;
281 uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
282 v.ui = significand >> vshift;
283 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1 285 v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
287 }
else if (v.si <= maxN) {
289 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1 291 v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
293 v.ui -= expAdjust << fp32FractionBits;
294 }
else if (v.si <= infN) {
296 }
else if (v.si < nanN) {
301 return sign | (v.ui & 0x7fff);
307 int32_t sign = v.si & signC;
310 v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
311 v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
315 int32_t
mask = -(norC > v.si);
317 v.si ^= (s.si ^ v.si) & mask;
322 MSHADOW_XINLINE float half2float(
const volatile uint16_t& value)
const volatile {
325 int32_t sign = v.si & signC;
328 v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
329 v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
333 int32_t
mask = -(norC > v.si);
335 v.si ^= (s.si ^ v.si) & mask;
342 #if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__)) 343 cuhalf_ = __float2half(
float(value));
344 #elif(MSHADOW_USE_F16C) 345 half_ = _cvtss_sh(static_cast<float>(value), 0);
347 half_ = float2half(
float(value));
369 #define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF); 370 #define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF); 371 #define MSHADOW_HALF_SIGN_BIT 0x8000 372 #define MSHADOW_HALF_EXPONENT_BITS 0x7c00 375 #endif // MSHADOW_HALF_H_ class MSHADOW_ALIGNED(2) half_t
Definition: half.h:113
#define MSHADOW_HALF_CONVERSIONOP(T)
Definition: half.h:96
#define MSHADOW_HALF_ASSIGNOP(AOP, OP)
Definition: half.h:77
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b)
overloaded + operator for half2_t
Definition: half2.h:107
#define MSHADOW_XINLINE
Definition: base.h:223
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b)
overloaded - operator for half2_t
Definition: half2.h:116
MaskExp< IndexExp, SrcExp, DType > mask(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src)
Definition: mask.h:58
#define MSHADOW_HALF_OPERATOR(RTYPE, OP)
Definition: half.h:64
overloaded + operator between half_t and bf16_t
Definition: base.h:327