mxnet
half2.h
Go to the documentation of this file.
1 
8 #ifndef MSHADOW_HALF2_H_
9 #define MSHADOW_HALF2_H_
10 
11 #if (defined(__CUDACC__) && __CUDA_ARCH__ >= 530 && MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
12  #define MSHADOW_CUDA_HALF2 1
13  #include <cuda_fp16.h>
14 #else
15  #define MSHADOW_CUDA_HALF2 0
16 #endif
17 
18 #include<math.h>
19 
21 namespace mshadow {
22 /* \brief namespace for host/device portable half-precision floats */
23 namespace half {
24 
/*!
 * \brief Defines a compound-assignment member operator (e.g. +=) for half2_t
 *        in terms of the matching binary operator (e.g. +).
 * \param AOP compound-assignment operator token, e.g. +=
 * \param OP  binary operator token, e.g. +
 *
 * Meant to be expanded inside the half2_t class body (it uses *this and
 * half2_t). The generated operator assigns the result to *this and returns a
 * reference to *this, following the usual compound-assignment convention.
 * The last line deliberately has no trailing backslash so the definition does
 * not swallow the following source line.
 */
#define MSHADOW_HALF2_ASSIGNOP(AOP, OP) \
  template<typename T> \
  MSHADOW_XINLINE half2_t& operator AOP (const T& a) { \
    *this = half2_t(*this OP a); /* NOLINT(*)*/ \
    return *this; \
  }
30 
/*!
 * \brief A pair of half-precision floats packed into 4 bytes.
 *
 * When MSHADOW_CUDA_HALF2 is 1 (CUDA compilation with __CUDA_ARCH__ >= 530,
 * MSHADOW_USE_CUDA and CUDA_VERSION >= 7050), the payload is a native CUDA
 * `half2`. Otherwise it is an array of two mshadow `half_t` values.
 */
class MSHADOW_ALIGNED(4) half2_t {
 public:
#if MSHADOW_CUDA_HALF2
  /*! \brief native CUDA packed pair of halves */
  half2 half2_;
#else
  /*! \brief portable fallback: lane 0 and lane 1 */
  half_t half_t2[2];
#endif

  /*! \brief default constructor; does not itself initialize the lanes */
  MSHADOW_XINLINE half2_t() {}

#if MSHADOW_CUDA_HALF2
  /*! \brief wrap an existing CUDA half2 value */
  MSHADOW_XINLINE explicit half2_t(half2 a) : half2_(a) {}
#else
  /*! \brief build from two lanes: a goes to lane 0, b goes to lane 1 */
  MSHADOW_XINLINE explicit half2_t(half_t a, half_t b) {
    half_t2[0] = a;
    half_t2[1] = b;
  }
#endif

  /*! \brief broadcast an integer into both lanes */
  MSHADOW_XINLINE explicit half2_t(int a) {
#if MSHADOW_CUDA_HALF2
    // __int2half_rz rounds toward zero; __half2half2 duplicates it into both lanes.
    half2_ = __half2half2(__int2half_rz(a));
#else
    // NOTE(review): rounding here is whatever half_t's int conversion does
    // (defined in half.h), which may differ from the device path's rz.
    half_t2[0] = (half_t)a;
    half_t2[1] = (half_t)a;
#endif
  }

  /*! \brief unary plus: returns a copy of this value */
  MSHADOW_XINLINE half2_t operator+() {
    return *this;
  }

  /*! \brief unary minus: negates both lanes */
  MSHADOW_XINLINE half2_t operator-() {
#if MSHADOW_CUDA_HALF2
    return half2_t(__hneg2(half2_));
#else
    return half2_t(-half_t2[0], -half_t2[1]);
#endif
  }

  /*!
   * \brief copy-assign both lanes from a.
   * \return a reference to *this, so chained and compound assignments act on
   *         this object rather than on a temporary copy.
   */
  MSHADOW_XINLINE half2_t& operator=(const half2_t& a) {
#if MSHADOW_CUDA_HALF2
    half2_ = a.half2_;
#else
    half_t2[0] = a.half_t2[0];
    half_t2[1] = a.half_t2[1];
#endif
    return *this;
  }

  // Compound assignments, generated from the binary free operators below.
  MSHADOW_HALF2_ASSIGNOP(+=, +)
  MSHADOW_HALF2_ASSIGNOP(-=, -)
  MSHADOW_HALF2_ASSIGNOP(*=, *)
  MSHADOW_HALF2_ASSIGNOP(/=, /)
};
86 
/*! \brief overloaded + operator for half2_t: lane-wise sum of a and b */
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  // Compute each lane in float, then pack both results back into a half2
  // using round-to-nearest (__floats2half2_rn).
  const float lo = __low2float(a.half2_) + __low2float(b.half2_);
  const float hi = __high2float(a.half2_) + __high2float(b.half2_);
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lane0 = a.half_t2[0] + b.half_t2[0];
  const half_t lane1 = a.half_t2[1] + b.half_t2[1];
  return half2_t(lane0, lane1);
#endif
}
/*! \brief overloaded - operator for half2_t: lane-wise difference a - b */
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  // Compute each lane in float, then pack both results back into a half2
  // using round-to-nearest (__floats2half2_rn).
  const float lo = __low2float(a.half2_) - __low2float(b.half2_);
  const float hi = __high2float(a.half2_) - __high2float(b.half2_);
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lane0 = a.half_t2[0] - b.half_t2[0];
  const half_t lane1 = a.half_t2[1] - b.half_t2[1];
  return half2_t(lane0, lane1);
#endif
}
/*! \brief overloaded * operator for half2_t: lane-wise product of a and b */
MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  // Compute each lane in float, then pack both results back into a half2
  // using round-to-nearest (__floats2half2_rn).
  const float lo = __low2float(a.half2_) * __low2float(b.half2_);
  const float hi = __high2float(a.half2_) * __high2float(b.half2_);
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lane0 = a.half_t2[0] * b.half_t2[0];
  const half_t lane1 = a.half_t2[1] * b.half_t2[1];
  return half2_t(lane0, lane1);
#endif
}
/*! \brief overloaded / operator for half2_t: lane-wise quotient a / b */
MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  // Compute each lane in float, then pack both results back into a half2
  // using round-to-nearest (__floats2half2_rn).
  const float lo = __low2float(a.half2_) / __low2float(b.half2_);
  const float hi = __high2float(a.half2_) / __high2float(b.half2_);
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lane0 = a.half_t2[0] / b.half_t2[0];
  const half_t lane1 = a.half_t2[1] / b.half_t2[1];
  return half2_t(lane0, lane1);
#endif
}
/*! \brief overloaded % operator for half2_t: lane-wise ::fmod(a, b) */
MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  // Remainder of each lane computed in float, then packed back into a half2
  // using round-to-nearest (__floats2half2_rn).
  const float lo = ::fmod(__low2float(a.half2_), __low2float(b.half2_));
  const float hi = ::fmod(__high2float(a.half2_), __high2float(b.half2_));
  return half2_t(__floats2half2_rn(lo, hi));
#else
  const half_t lane0 = ::fmod(a.half_t2[0], b.half_t2[0]);
  const half_t lane1 = ::fmod(a.half_t2[1], b.half_t2[1]);
  return half2_t(lane0, lane1);
#endif
}
/*! \brief overloaded == operator for half2_t: true only if both lanes are equal */
MSHADOW_XINLINE bool operator==(half2_t a, half2_t b) {
#if MSHADOW_CUDA_HALF2
  // __hbeq2 yields true only when both the low and high halves compare equal.
  return __hbeq2(a.half2_, b.half2_);
#else
  // Lane 0 mismatch settles the answer; lane 1 is only checked afterwards.
  if (!(a.half_t2[0] == b.half_t2[0])) {
    return false;
  }
  return a.half_t2[1] == b.half_t2[1];
#endif
}
140 
141 } // namespace half
142 } // namespace mshadow
143 #endif // MSHADOW_HALF2_H_
class MSHADOW_ALIGNED(2) half_t
Definition: half.h:94
#define MSHADOW_HALF2_ASSIGNOP(AOP, OP)
Definition: half2.h:25
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b)
overloaded + operator for half2_t
Definition: half2.h:88
#define MSHADOW_XINLINE
Definition: base.h:204
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b)
overloaded - operator for half2_t
Definition: half2.h:97
MSHADOW_XINLINE bool operator==(half2_t a, half2_t b)
overloaded == operator for half2_t
Definition: half2.h:133
MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b)
overloaded * operator for half2_t
Definition: half2.h:106
MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b)
overloaded / operator for half2_t
Definition: half2.h:115
namespace for mshadow
Definition: base.h:282
MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b)
overloaded % operator for half2_t
Definition: half2.h:124