Go to the documentation of this file.
20 #ifndef MXNET_COMMON_CUDA_RTC_UTIL_INL_H_
21 #define MXNET_COMMON_CUDA_RTC_UTIL_INL_H_
33 using float32 = float;
34 using float64 = double;
36 using uint8 = unsigned char;
39 using int64 = long long;
41 using uint16 = unsigned short;
42 using uint32 = unsigned int;
43 using uint64 = unsigned long long;
45 static_assert(sizeof(float32) == 4, "Size of float32 is expected to be 4B");
46 static_assert(sizeof(float64) == 8, "Size of float64 is expected to be 8B");
47 static_assert(sizeof(float16) == 2, "Size of float16 is expected to be 2B");
48 static_assert(sizeof(uint8) == 1, "Size of uint8 is expected to be 1B");
49 static_assert(sizeof(int8) == 1, "Size of int8 is expected to be 1B");
50 static_assert(sizeof(int32) == 4, "Size of int32 is expected to be 4B");
51 static_assert(sizeof(int64) == 8, "Size of int64 is expected to be 8B");
52 static_assert(sizeof(int16) == 2, "Size of int16 is expected to be 2B");
53 static_assert(sizeof(uint16) == 2, "Size of uint16 is expected to be 2B");
54 static_assert(sizeof(uint32) == 4, "Size of uint32 is expected to be 4B");
55 static_assert(sizeof(uint64) == 8, "Size of uint64 is expected to be 8B");
58 #if MSHADOW_INT64_TENSOR_SIZE == 1
59 "typedef int64 index_t;\n"
61 "typedef int32 index_t;\n"
64 // bool and int8 need to be accumulated in index_t
65 // but bool needs to be treated in the special way
66 // for ops like bitwise_not
70 __device__ inline bool_t(const index_t& v) : value(v) {}
71 __device__ inline bool_t(const volatile index_t& v) : value(v) {}
72 __device__ inline bool_t() : value(0) {}
74 __device__ inline operator index_t() const volatile { return value; }
75 __device__ inline bool_t& operator= (const index_t& v) {
79 __device__ inline volatile bool_t& operator= (const index_t& v) volatile {
83 __device__ inline bool_t& operator= (const volatile index_t& v) {
89 struct AccType<bool> {
92 __device__ static inline type from(const bool& val) {
96 __device__ static inline bool to(type val) {
102 struct AccType<int8> {
103 using type = index_t;
105 __device__ static inline type from(const int8& val) {
109 __device__ static inline int8 to(type val) {
115 struct AccType<uint8> {
116 using type = index_t;
118 __device__ static inline type from(const uint8& val) {
122 __device__ static inline uint8 to(type val) {
127 namespace type_util {
130 static constexpr bool value = false;
134 static constexpr bool value = true;
138 template <typename T> struct is_integral : false_type {};
139 template <> struct is_integral<uint8> : true_type {};
140 template <> struct is_integral<uint16> : true_type {};
141 template <> struct is_integral<uint32> : true_type {};
142 template <> struct is_integral<uint64> : true_type {};
143 template <> struct is_integral<int8> : true_type {};
144 template <> struct is_integral<int16> : true_type {};
145 template <> struct is_integral<int32> : true_type {};
146 template <> struct is_integral<int64> : true_type {};
147 template <> struct is_integral<bool> : true_type {};
148 template <> struct is_integral<bool_t> : true_type {};
151 template <typename T> struct is_unsigned : false_type {};
152 template <> struct is_unsigned<uint8> : true_type {};
153 template <> struct is_unsigned<uint16> : true_type {};
154 template <> struct is_unsigned<uint32> : true_type {};
155 template <> struct is_unsigned<uint64> : true_type {};
156 template <> struct is_unsigned<bool> : true_type {};
157 template <> struct is_unsigned<bool_t> : true_type {};
160 template <typename T, typename U>
161 struct is_same : false_type {};
162 template <typename T> struct is_same<T, T> : true_type {};
165 template <typename... T> struct has_double : false_type {};
167 template <typename A, typename... B>
168 struct has_double<A, B...> {
169 static constexpr bool value = is_same<A, double>::value ||
170 has_double<B...>::value;
173 // has_double_or_integral
174 template <typename... T> struct has_double_or_integral : false_type {};
176 template <typename A, typename... B>
177 struct has_double_or_integral<A, B...> {
178 static constexpr bool value = is_same<A, double>::value ||
179 is_integral<A>::value ||
180 has_double_or_integral<B...>::value;
187 struct enable_if<true> {
191 template <typename T, typename U, class Enable = void>
192 struct mixed_type_helper;
194 template <typename T>
195 struct mixed_type_helper<T, float64, typename enable_if<!is_same<float64, T>::value>::type> {
196 using type = float64;
199 template <typename T>
200 struct mixed_type_helper<float64, T> {
201 using type = float64;
204 template <typename T>
205 struct mixed_type_helper<T, float32, typename enable_if<!is_same<float64, T>::value &&
206 !is_same<float32, T>::value>::type> {
207 using type = float32;
210 template <typename T>
211 struct mixed_type_helper<float32, T, typename enable_if<!is_same<float64, T>::value>::type> {
212 using type = float32;
215 template <typename T>
216 struct mixed_type_helper<T, float16, typename enable_if<is_same<float16, T>::value ||
217 is_integral<T>::value>::type> {
218 using type = float16;
221 template <typename T>
222 struct mixed_type_helper<float16, T, typename enable_if<is_integral<T>::value>::type> {
223 using type = float16;
226 template <typename T, typename U>
227 struct mixed_type_helper<T, U, typename enable_if<is_integral<T>::value &&
228 is_integral<U>::value &&
229 is_unsigned<T>::value &&
230 is_unsigned<U>::value &&
231 !is_same<U, bool_t>::value &&
232 sizeof(T) < sizeof(U)>::type> {
236 template <typename T, typename U>
237 struct mixed_type_helper<T, U, typename enable_if<is_integral<T>::value &&
238 is_integral<U>::value &&
239 !is_unsigned<T>::value &&
240 !is_unsigned<U>::value &&
241 !is_same<U, bool_t>::value &&
242 sizeof(T) < sizeof(U)>::type> {
246 template <typename T, typename U>
247 struct mixed_type_helper<T, U, typename enable_if<is_integral<T>::value &&
248 is_integral<U>::value &&
249 is_unsigned<T>::value &&
250 !is_unsigned<U>::value &&
251 !is_same<U, bool_t>::value &&
252 sizeof(T) < sizeof(U)>::type> {
256 template <typename T, typename U>
257 struct mixed_type_helper<U, T, typename enable_if<is_integral<T>::value &&
258 is_integral<U>::value &&
259 is_unsigned<T>::value &&
260 is_unsigned<U>::value &&
261 !is_same<U, bool_t>::value &&
262 sizeof(T) < sizeof(U)>::type> {
266 template <typename T, typename U>
267 struct mixed_type_helper<U, T, typename enable_if<is_integral<T>::value &&
268 is_integral<U>::value &&
269 !is_unsigned<T>::value &&
270 !is_unsigned<U>::value &&
271 !is_same<U, bool_t>::value &&
272 sizeof(T) < sizeof(U)>::type> {
276 template <typename T, typename U>
277 struct mixed_type_helper<U, T, typename enable_if<is_integral<T>::value &&
278 is_integral<U>::value &&
279 is_unsigned<T>::value &&
280 !is_unsigned<U>::value &&
281 !is_same<U, bool_t>::value &&
282 sizeof(T) < sizeof(U)>::type> {
286 template <typename T, typename U>
287 struct mixed_type_helper<T, U, typename enable_if<is_integral<T>::value &&
288 is_integral<U>::value &&
289 !is_same<U, bool_t>::value &&
290 is_same<T, U>::value>::type> {
295 struct mixed_type_helper<int8, uint8> {
300 struct mixed_type_helper<uint8, int8> {
305 struct mixed_type_helper<int8, uint16> {
310 struct mixed_type_helper<uint16, int8> {
315 struct mixed_type_helper<int8, uint32> {
320 struct mixed_type_helper<uint32, int8> {
325 struct mixed_type_helper<int16, uint16> {
330 struct mixed_type_helper<uint16, int16> {
335 struct mixed_type_helper<int16, uint32> {
340 struct mixed_type_helper<uint32, int16> {
345 struct mixed_type_helper<int32, uint32> {
350 struct mixed_type_helper<uint32, int32> {
355 struct mixed_type_helper<uint64, index_t> {
356 using type = index_t;
360 struct mixed_type_helper<index_t, uint64> {
361 using type = index_t;
364 template <typename T>
365 struct mixed_type_helper<T, bool_t, typename enable_if<is_integral<T>::value &&
366 sizeof(T) < sizeof(bool_t)>::type> {
367 using type = index_t;
370 template <typename T>
371 struct mixed_type_helper<bool_t, T, typename enable_if<is_integral<T>::value &&
372 sizeof(T) < sizeof(bool_t)>::type> {
373 using type = index_t;
376 template <typename T>
377 struct mixed_type_helper<T, bool_t, typename enable_if<is_integral<T>::value &&
378 sizeof(T) == sizeof(bool_t)>::type> {
382 template <typename T>
383 struct mixed_type_helper<bool_t, T, typename enable_if<is_integral<T>::value &&
384 !is_same<T, bool_t>::value &&
385 sizeof(T) == sizeof(bool_t)>::type> {
389 template <typename... Ts>
390 struct multi_mixed_type_helper;
393 struct multi_mixed_type_helper<> {
397 template <typename T>
398 struct multi_mixed_type_helper<T> {
402 template <typename T, typename U, typename... Ts>
403 struct multi_mixed_type_helper<T, U, Ts...> {
404 using type = typename mixed_type_helper<T,
405 typename multi_mixed_type_helper<U,
409 template <typename... Ts>
410 using mixed_type = typename multi_mixed_type_helper<Ts...>::type;
412 } // namespace type_util
416 enum class OpReqType {
423 constexpr int kRTCMaxThreadsPerBlock = 512;
424 constexpr int warp_size = 32;
428 constexpr int MAX_DIM = 5;
431 __device__ inline void unravel_dot(const index_t idx, const index_t (&shape)[MAX_DIM],
432 const index_t (&stridej)[MAX_DIM], const index_t (&stridek)[MAX_DIM], index_t* j, index_t* k) {
436 for (index_t i = ndim-1, idx_t = idx; i >=0; --i) {
437 const auto tmp = idx_t / shape[i];
438 const auto coord = idx_t - tmp*shape[i];
439 *j += coord*stridej[i];
440 *k += coord*stridek[i];
446 __device__ inline index_t unravel_dot(const index_t idx, const index_t (&shape)[MAX_DIM],
447 const index_t (&stride)[MAX_DIM]) {
450 for (index_t i = ndim-1, j = idx; i >=0; --i) {
451 auto tmp = j / shape[i];
452 ret += (j - tmp*shape[i])*stride[i];
459 __device__ inline index_t unravel_ravel(const index_t idx, const index_t (&shape1)[MAX_DIM],
460 const index_t (&shape2)[MAX_DIM]) {
462 index_t total_shape = 1;
464 for (index_t i = ndim-1, j = idx; i >=0; --i) {
466 total_shape *= shape2[i + 1];
468 auto tmp = j / shape1[i];
469 const index_t coord = j - tmp*shape1[i];
470 ret += total_shape * (shape2[i] > coord) * coord;
476 template<int ndim, int ndim2>
477 __device__ inline index_t ravel(const index_t (&coord)[ndim], const index_t (&shape)[ndim2]) {
480 for (int i = 0; i < ndim; ++i) {
481 ret = ret * shape[i] + (shape[i] > coord[i]) * coord[i];
486 template<int ndim, int ndim2>
487 __device__ inline void unravel(const index_t idx,
488 const index_t (&shape)[ndim2],
489 index_t (&coord)[ndim]) {
491 for (index_t i = ndim-1, j = idx; i >=0; --i) {
492 auto tmp = j / shape[i];
493 coord[i] = j - tmp*shape[i];
498 template <typename DType>
499 __device__ inline bool isinf(volatile const DType &val) {
504 __device__ inline bool isinf(volatile const float &val) {
509 __device__ inline bool isinf(volatile const double &val) {
514 __device__ inline bool isinf(volatile const long double &val) {
519 __device__ inline bool isinf(volatile const float16 &val) {
520 return ::isinf(__half2float(const_cast<const float16&>(val)));
523 template <typename DType>
524 __device__ inline bool isnan(volatile const DType &val) {
529 __device__ inline bool isnan(volatile const float &val) {
534 __device__ inline bool isnan(volatile const double &val) {
539 __device__ inline bool isnan(volatile const long double &val) {
544 __device__ inline bool isnan(volatile const float16 &val) {
545 return ::isnan(__half2float(const_cast<const float16&>(val)));
548 template <int NVALUES = warp_size, typename OP, typename T>
549 __device__ inline T warp_reduce(T value, OP redfun) {
551 for (int i = warp_size / 2; i >= 1; i /= 2) {
552 if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
557 template <typename OP, typename T>
558 __device__ inline T grouped_warp_reduce(T value, OP redfun, const int group_size) {
559 for (int i = 1; i < group_size; i *= 2) {
560 value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
565 template <typename OP, typename T>
566 __device__ inline T grouped_warp_allreduce(T value, OP redfun, const int group_size) {
567 value = grouped_warp_reduce(value, redfun, group_size);
568 return __shfl_sync(0xffffffff, value, 0, group_size);
571 template <typename OP, typename T>
572 __device__ inline T strided_grouped_warp_reduce(T value, OP redfun, const int group_size) {
573 for (int i = warp_size / 2; i >= group_size; i /= 2) {
574 value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
579 template <typename OP, typename T>
580 __device__ inline T strided_grouped_warp_allreduce(T value, OP redfun, const int group_size) {
581 value = strided_grouped_warp_reduce(value, redfun, group_size);
582 for (int i = group_size; i < warp_size; i *= 2) {
583 T tmp = __shfl_up_sync(0xffffffff, value, i);
584 if (threadIdx.x % warp_size >= i) {
595 constexpr double DBL_MAX = 1.7976931348623157081e+308;
596 constexpr float FLT_MAX = 3.4028234663852885981e+38;
597 #define inf ((float)1e50)
598 #define nan (inf - inf)
602 template<typename DType>
603 __device__ inline DType MinValue(void);
606 __device__ inline float MinValue<float>(void) {
611 __device__ inline double MinValue<double>(void) {
616 __device__ inline uint8 MinValue<uint8>(void) {
621 __device__ inline uint16 MinValue<uint16>(void) {
626 __device__ inline uint32 MinValue<uint32>(void) {
631 __device__ inline uint64 MinValue<uint64>(void) {
636 __device__ inline int8 MinValue<int8>(void) {
641 __device__ inline int16 MinValue<int16>(void) {
646 __device__ inline int32 MinValue<int32>(void) {
651 __device__ inline int64 MinValue<int64>(void) {
652 return -9223372036854775807LL - 1;  // INT64_MIN: 9223372036854775808 is not representable as long long, so it cannot be written as a negated literal
656 __device__ inline bool MinValue<bool>(void) {
661 __device__ inline bool_t MinValue<bool_t>(void) {
662 return MinValue<index_t>();
669 template<typename DType>
670 __device__ inline DType NegInfValue(void) {
671 return MinValue<DType>();
675 __device__ inline float NegInfValue<float>(void) {
680 __device__ inline double NegInfValue<double>(void) {
688 template<typename DType>
689 __device__ inline DType MaxValue(void);
692 __device__ inline float MaxValue<float>(void) {
697 __device__ inline double MaxValue<double>(void) {
702 __device__ inline uint8 MaxValue<uint8>(void) {
707 __device__ inline uint16 MaxValue<uint16>(void) {
712 __device__ inline uint32 MaxValue<uint32>(void) {
717 __device__ inline uint64 MaxValue<uint64>(void) {
718 return 18446744073709551615ULL;  // UINT64_MAX: needs the unsigned ULL suffix, the value exceeds LLONG_MAX
722 __device__ inline int8 MaxValue<int8>(void) {
727 __device__ inline int16 MaxValue<int16>(void) {
732 __device__ inline int32 MaxValue<int32>(void) {
737 __device__ inline int64 MaxValue<int64>(void) {
738 return 9223372036854775807LL;
742 __device__ inline bool MaxValue<bool>(void) {
747 __device__ inline bool_t MaxValue<bool_t>(void) {
748 return MaxValue<index_t>();
754 template<typename DType>
755 __device__ inline DType PosInfValue(void) {
756 return MaxValue<DType>();
760 __device__ inline float PosInfValue<float>(void) {
765 __device__ inline double PosInfValue<double>(void) {
769 } // namespace limits
776 #endif // MXNET_USE_CUDA
778 #endif // MXNET_COMMON_CUDA_RTC_UTIL_INL_H_
namespace of mxnet
Definition: api_registry.h:33
const char type_support_string[]
Definition: util-inl.h:32
const char limits[]
Definition: util-inl.h:594
Configuration of MXNet as well as basic data structures.
const char util_string[]
Definition: util-inl.h:415