mxnet
half2.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MSHADOW_HALF2_H_
28 #define MSHADOW_HALF2_H_
29 
/*
 * MSHADOW_CUDA_HALF2 selects the storage/arithmetic backend used below.
 * It is 1 only when all of these hold: __CUDACC__ is defined,
 * MSHADOW_USE_CUDA is non-zero, __CUDA_ARCH__ >= 530 and
 * CUDA_VERSION >= 7050. In that case <cuda_fp16.h> is included.
 * In every other configuration it is 0.
 */
#if !defined(__CUDACC__) || !(MSHADOW_USE_CUDA) || __CUDA_ARCH__ < 530 || CUDA_VERSION < 7050
  #define MSHADOW_CUDA_HALF2 0
#else
  #define MSHADOW_CUDA_HALF2 1
  #include <cuda_fp16.h>
#endif
36 
37 #include<math.h>
38 
40 namespace mshadow {
41 /* \brief namespace for host/device portable half-precision floats */
42 namespace half {
43 
/*!
 * \brief generates a templated member compound-assignment operator of half2_t.
 * \param AOP operator token that names the generated member (used as `operator AOP`)
 * \param OP binary operator applied as `*this OP a` to compute the new value
 *
 * The generated operator stores half2_t(*this OP a) into *this and returns
 * the updated value by value.
 */
#define MSHADOW_HALF2_ASSIGNOP(AOP, OP)                      \
  template<typename T>                                       \
  MSHADOW_XINLINE half2_t operator AOP (const T& a) {        \
    *this = half2_t(*this OP a);  /* NOLINT(*)*/             \
    return *this;                                            \
  }
49 
50 class MSHADOW_ALIGNED(4) half2_t {
51  public:
52 #if MSHADOW_CUDA_HALF2
53  half2 half2_;
54 #else
55  half_t half_t2[2];
56 #endif
57 
58  MSHADOW_XINLINE half2_t() {}
59 
60 #if MSHADOW_CUDA_HALF2
61  MSHADOW_XINLINE explicit half2_t(half2 a) : half2_(a) {}
62 #else
63  MSHADOW_XINLINE explicit half2_t(half_t a, half_t b) {
64  half_t2[0] = a;
65  half_t2[1] = b;
66  }
67 #endif
68 
69  MSHADOW_XINLINE explicit half2_t(int a) {
70 #if MSHADOW_CUDA_HALF2
71  half2_ = __half2half2(__int2half_rz(a));
72 #else
73  half_t2[0] = (half_t)a;
74  half_t2[1] = (half_t)a;
75 #endif
76  }
77 
78  MSHADOW_XINLINE half2_t operator+() {
79  return *this;
80  }
81 
82  MSHADOW_XINLINE half2_t operator-() {
83 #if MSHADOW_CUDA_HALF2
84  return half2_t(__hneg2(half2_));
85 #else
86  return half2_t(-half_t2[0], -half_t2[1]);
87 #endif
88  }
89 
90  MSHADOW_XINLINE half2_t operator=(const half2_t& a) {
91 #if MSHADOW_CUDA_HALF2
92  half2_ = a.half2_;
93 #else
94  half_t2[0] = a.half_t2[0];
95  half_t2[1] = a.half_t2[1];
96 #endif
97  return a;
98  }
99 
104 };
105 
107 MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b) {
108 #if MSHADOW_CUDA_HALF2
109  return half2_t(__floats2half2_rn(__low2float(a.half2_) + __low2float(b.half2_),
110  __high2float(a.half2_) + __high2float(b.half2_)));
111 #else
112  return half2_t(a.half_t2[0] + b.half_t2[0], a.half_t2[1] + b.half_t2[1]);
113 #endif
114 }
116 MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b) {
117 #if MSHADOW_CUDA_HALF2
118  return half2_t(__floats2half2_rn(__low2float(a.half2_) - __low2float(b.half2_),
119  __high2float(a.half2_) - __high2float(b.half2_)));
120 #else
121  return half2_t(a.half_t2[0] - b.half_t2[0], a.half_t2[1] - b.half_t2[1]);
122 #endif
123 }
125 MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b) {
126 #if MSHADOW_CUDA_HALF2
127  return half2_t(__floats2half2_rn(__low2float(a.half2_) * __low2float(b.half2_),
128  __high2float(a.half2_) * __high2float(b.half2_)));
129 #else
130  return half2_t(a.half_t2[0] * b.half_t2[0], a.half_t2[1] * b.half_t2[1]);
131 #endif
132 }
134 MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b) {
135 #if MSHADOW_CUDA_HALF2
136  return half2_t(__floats2half2_rn(__low2float(a.half2_) / __low2float(b.half2_),
137  __high2float(a.half2_) / __high2float(b.half2_)));
138 #else
139  return half2_t(a.half_t2[0] / b.half_t2[0], a.half_t2[1] / b.half_t2[1]);
140 #endif
141 }
143 MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b) {
144 #if MSHADOW_CUDA_HALF2
145  return half2_t(__floats2half2_rn(::fmod(__low2float(a.half2_), __low2float(b.half2_)),
146  ::fmod(__high2float(a.half2_), __high2float(b.half2_))));
147 #else
148  return half2_t(::fmod(a.half_t2[0], b.half_t2[0]), ::fmod(a.half_t2[1], b.half_t2[1]));
149 #endif
150 }
152 MSHADOW_XINLINE bool operator==(half2_t a, half2_t b) {
153 #if MSHADOW_CUDA_HALF2
154  return __hbeq2(a.half2_, b.half2_);
155 #else
156  return (a.half_t2[0] == b.half_t2[0] && a.half_t2[1] == b.half_t2[1]);
157 #endif
158 }
159 
160 } // namespace half
161 } // namespace mshadow
162 #endif // MSHADOW_HALF2_H_
class MSHADOW_ALIGNED(2) half_t
Definition: half.h:113
#define MSHADOW_HALF2_ASSIGNOP(AOP, OP)
Definition: half2.h:44
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b)
overloaded + operator for half2_t
Definition: half2.h:107
#define MSHADOW_XINLINE
Definition: base.h:223
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b)
overloaded - operator for half2_t
Definition: half2.h:116
MSHADOW_XINLINE bool operator==(half2_t a, half2_t b)
overloaded == operator for half2_t
Definition: half2.h:152
MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b)
overloaded * operator for half2_t
Definition: half2.h:125
MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b)
overloaded / operator for half2_t
Definition: half2.h:134
overloaded + operator between half_t and bf16_t
Definition: base.h:327
MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b)
overloaded % operator for half2_t
Definition: half2.h:143