mxnet
util-inl.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_COMMON_CUDA_RTC_UTIL_INL_H_
21 #define MXNET_COMMON_CUDA_RTC_UTIL_INL_H_
22 
23 #include <mxnet/base.h>
24 
25 #if MXNET_USE_CUDA
26 
27 namespace mxnet {
28 namespace common {
29 namespace cuda {
30 namespace rtc {
31 
32 const char type_support_string[] = R"code(
// Fixed-width scalar names used by the runtime-compiled kernels. Every alias
// is immediately followed by a compile-time check of the byte width that its
// name promises.

// Floating-point types.
typedef float float32;
static_assert(sizeof(float32) == 4, "Size of float32 is expected to be 4B");
typedef double float64;
static_assert(sizeof(float64) == 8, "Size of float64 is expected to be 8B");
typedef half float16;
static_assert(sizeof(float16) == 2, "Size of float16 is expected to be 2B");

// Signed integer types.
typedef char int8;
static_assert(sizeof(int8) == 1, "Size of int8 is expected to be 1B");
typedef short int16;
static_assert(sizeof(int16) == 2, "Size of int16 is expected to be 2B");
typedef int int32;
static_assert(sizeof(int32) == 4, "Size of int32 is expected to be 4B");
typedef long long int64;
static_assert(sizeof(int64) == 8, "Size of int64 is expected to be 8B");

// Unsigned integer types.
typedef unsigned char uint8;
static_assert(sizeof(uint8) == 1, "Size of uint8 is expected to be 1B");
typedef unsigned short uint16;
static_assert(sizeof(uint16) == 2, "Size of uint16 is expected to be 2B");
typedef unsigned int uint32;
static_assert(sizeof(uint32) == 4, "Size of uint32 is expected to be 4B");
typedef unsigned long long uint64;
static_assert(sizeof(uint64) == 8, "Size of uint64 is expected to be 8B");
56 
57 )code"
58 #if MSHADOW_INT64_TENSOR_SIZE == 1
59  "typedef int64 index_t;\n"
60 #else
61  "typedef int32 index_t;\n"
62 #endif
63  R"code(
// bool and int8 are both accumulated in index_t, but bool additionally needs
// its own distinct type so that ops such as bitwise_not can treat it
// specially. bool_t is that type: a thin wrapper around a single index_t.
struct bool_t {
  index_t value;

  // Constructors: zero by default, otherwise copy the given index_t
  // (plain or volatile source).
  __device__ inline bool_t() : value(0) {}
  __device__ inline bool_t(const index_t& src) : value(src) {}
  __device__ inline bool_t(const volatile index_t& src) : value(src) {}

  // Implicit read-back of the stored value; usable on volatile objects too.
  __device__ inline operator index_t() const volatile { return value; }

  // Assignment from index_t, covering non-volatile and volatile targets as
  // well as a volatile source.
  __device__ inline bool_t& operator=(const index_t& rhs) {
    this->value = rhs;
    return *this;
  }
  __device__ inline volatile bool_t& operator=(const index_t& rhs) volatile {
    this->value = rhs;
    return *this;
  }
  __device__ inline bool_t& operator=(const volatile index_t& rhs) {
    this->value = rhs;
    return *this;
  }
};
// AccType specialization for bool: values are accumulated as bool_t.
template <>
struct AccType<bool> {
  typedef bool_t type;

  // Widen a bool into the accumulation type.
  __device__ static inline type from(const bool& val) {
    return type(val);
  }

  // Convert an accumulated value back to bool.
  __device__ static inline bool to(type val) {
    return static_cast<bool>(val);
  }
};
100 
// AccType specialization for int8: values are accumulated as index_t.
template <>
struct AccType<int8> {
  typedef index_t type;

  // Widen an int8 into the accumulation type.
  __device__ static inline type from(const int8& val) {
    return static_cast<type>(val);
  }

  // Narrow an accumulated value back to int8.
  __device__ static inline int8 to(type val) {
    return static_cast<int8>(val);
  }
};
113 
// AccType specialization for uint8: values are accumulated as index_t.
template <>
struct AccType<uint8> {
  typedef index_t type;

  // Widen a uint8 into the accumulation type.
  __device__ static inline type from(const uint8& val) {
    return static_cast<type>(val);
  }

  // Narrow an accumulated value back to uint8.
  __device__ static inline uint8 to(type val) {
    return static_cast<uint8>(val);
  }
};
126 
127 namespace type_util {
128 
// Base classes for the traits below: each exposes a compile-time boolean
// through the static member `value`.
struct true_type {
  static constexpr bool value = true;
};

struct false_type {
  static constexpr bool value = false;
};
136 
// is_integral: true for the fixed-width integer aliases, bool and bool_t;
// false for every other type.
template <class T> struct is_integral : false_type {};

// signed integers
template <> struct is_integral<int8> : true_type {};
template <> struct is_integral<int16> : true_type {};
template <> struct is_integral<int32> : true_type {};
template <> struct is_integral<int64> : true_type {};

// unsigned integers
template <> struct is_integral<uint8> : true_type {};
template <> struct is_integral<uint16> : true_type {};
template <> struct is_integral<uint32> : true_type {};
template <> struct is_integral<uint64> : true_type {};

// boolean types
template <> struct is_integral<bool> : true_type {};
template <> struct is_integral<bool_t> : true_type {};
149 
// is_unsigned: true for the unsigned integer aliases, bool and bool_t;
// false for every other type.
template <class T> struct is_unsigned : false_type {};

template <> struct is_unsigned<bool> : true_type {};
template <> struct is_unsigned<bool_t> : true_type {};
template <> struct is_unsigned<uint8> : true_type {};
template <> struct is_unsigned<uint16> : true_type {};
template <> struct is_unsigned<uint32> : true_type {};
template <> struct is_unsigned<uint64> : true_type {};
158 
// is_same: true only when both template arguments name the same type.
template <class Lhs, class Rhs>
struct is_same : false_type {};

// Matching case: both arguments deduce to a single type.
template <class Same>
struct is_same<Same, Same> : true_type {};
163 
// has_double: true when at least one type in the pack is exactly double.
// The empty pack yields false.
template <typename... Ts> struct has_double : false_type {};

// Recursive case: test the first type, then the remainder of the pack.
template <typename Head, typename... Tail>
struct has_double<Head, Tail...> {
  static constexpr bool value =
      is_same<Head, double>::value || has_double<Tail...>::value;
};
172 
// has_double_or_integral: true when at least one type in the pack is double
// or satisfies is_integral. The empty pack yields false.
template <typename... Ts> struct has_double_or_integral : false_type {};

// Recursive case: test the first type, then the remainder of the pack.
template <typename Head, typename... Tail>
struct has_double_or_integral<Head, Tail...> {
  static constexpr bool value =
      is_same<Head, double>::value ||
      is_integral<Head>::value ||
      has_double_or_integral<Tail...>::value;
};
182 
// enable_if: exposes a nested `type` (always void) only when the condition is
// true, so it can gate partial specializations.
template <bool Cond>
struct enable_if {};

template <>
struct enable_if<true> {
  typedef void type;
};
190 
// mixed_type_helper<T, U>::type is the result type for combining a T with a U.
// The primary template is only declared; each supported pairing is supplied
// by a specialization. The third parameter exists solely to carry enable_if
// conditions.
template <typename T, typename U, typename Enable = void>
struct mixed_type_helper;

// float64 on the right, anything except float64 on the left -> float64.
template <typename Other>
struct mixed_type_helper<Other, float64,
                         typename enable_if<!is_same<float64, Other>::value>::type> {
  typedef float64 type;
};

// float64 on the left, anything on the right -> float64.
template <typename Other>
struct mixed_type_helper<float64, Other> {
  typedef float64 type;
};

// float32 on the right, left is neither float64 nor float32 -> float32.
template <typename Other>
struct mixed_type_helper<Other, float32,
                         typename enable_if<!is_same<float64, Other>::value &&
                                            !is_same<float32, Other>::value>::type> {
  typedef float32 type;
};

// float32 on the left, anything except float64 on the right -> float32.
template <typename Other>
struct mixed_type_helper<float32, Other,
                         typename enable_if<!is_same<float64, Other>::value>::type> {
  typedef float32 type;
};

// float16 on the right, left is float16 or an integral type -> float16.
template <typename Other>
struct mixed_type_helper<Other, float16,
                         typename enable_if<is_same<float16, Other>::value ||
                                            is_integral<Other>::value>::type> {
  typedef float16 type;
};

// float16 on the left, integral type on the right -> float16.
template <typename Other>
struct mixed_type_helper<float16, Other,
                         typename enable_if<is_integral<Other>::value>::type> {
  typedef float16 type;
};
225 
// Integral pairings where one type (Narrow) is strictly smaller in size than
// the other (Wide) and Wide is not bool_t: the result is Wide. Three
// signedness combinations are covered for each argument order.

// (Narrow, Wide), both unsigned.
template <typename Narrow, typename Wide>
struct mixed_type_helper<Narrow, Wide,
                         typename enable_if<is_integral<Narrow>::value &&
                                            is_integral<Wide>::value &&
                                            is_unsigned<Narrow>::value &&
                                            is_unsigned<Wide>::value &&
                                            !is_same<Wide, bool_t>::value &&
                                            sizeof(Narrow) < sizeof(Wide)>::type> {
  typedef Wide type;
};

// (Narrow, Wide), both signed.
template <typename Narrow, typename Wide>
struct mixed_type_helper<Narrow, Wide,
                         typename enable_if<is_integral<Narrow>::value &&
                                            is_integral<Wide>::value &&
                                            !is_unsigned<Narrow>::value &&
                                            !is_unsigned<Wide>::value &&
                                            !is_same<Wide, bool_t>::value &&
                                            sizeof(Narrow) < sizeof(Wide)>::type> {
  typedef Wide type;
};

// (Narrow, Wide), unsigned Narrow with signed Wide.
template <typename Narrow, typename Wide>
struct mixed_type_helper<Narrow, Wide,
                         typename enable_if<is_integral<Narrow>::value &&
                                            is_integral<Wide>::value &&
                                            is_unsigned<Narrow>::value &&
                                            !is_unsigned<Wide>::value &&
                                            !is_same<Wide, bool_t>::value &&
                                            sizeof(Narrow) < sizeof(Wide)>::type> {
  typedef Wide type;
};

// (Wide, Narrow), both unsigned.
template <typename Narrow, typename Wide>
struct mixed_type_helper<Wide, Narrow,
                         typename enable_if<is_integral<Narrow>::value &&
                                            is_integral<Wide>::value &&
                                            is_unsigned<Narrow>::value &&
                                            is_unsigned<Wide>::value &&
                                            !is_same<Wide, bool_t>::value &&
                                            sizeof(Narrow) < sizeof(Wide)>::type> {
  typedef Wide type;
};

// (Wide, Narrow), both signed.
template <typename Narrow, typename Wide>
struct mixed_type_helper<Wide, Narrow,
                         typename enable_if<is_integral<Narrow>::value &&
                                            is_integral<Wide>::value &&
                                            !is_unsigned<Narrow>::value &&
                                            !is_unsigned<Wide>::value &&
                                            !is_same<Wide, bool_t>::value &&
                                            sizeof(Narrow) < sizeof(Wide)>::type> {
  typedef Wide type;
};

// (Wide, Narrow), unsigned Narrow with signed Wide.
template <typename Narrow, typename Wide>
struct mixed_type_helper<Wide, Narrow,
                         typename enable_if<is_integral<Narrow>::value &&
                                            is_integral<Wide>::value &&
                                            is_unsigned<Narrow>::value &&
                                            !is_unsigned<Wide>::value &&
                                            !is_same<Wide, bool_t>::value &&
                                            sizeof(Narrow) < sizeof(Wide)>::type> {
  typedef Wide type;
};

// Identical integral types (other than bool_t) combine to that same type.
template <typename Lhs, typename Rhs>
struct mixed_type_helper<Lhs, Rhs,
                         typename enable_if<is_integral<Lhs>::value &&
                                            is_integral<Rhs>::value &&
                                            !is_same<Rhs, bool_t>::value &&
                                            is_same<Lhs, Rhs>::value>::type> {
  typedef Rhs type;
};
293 
// Explicit pairings of a signed type with an unsigned type of equal or larger
// size. Each pairing is listed in both argument orders.

// int8 with uint8 -> int16
template <> struct mixed_type_helper<int8, uint8> { typedef int16 type; };
template <> struct mixed_type_helper<uint8, int8> { typedef int16 type; };

// int8 with uint16 -> int32
template <> struct mixed_type_helper<int8, uint16> { typedef int32 type; };
template <> struct mixed_type_helper<uint16, int8> { typedef int32 type; };

// int8 with uint32 -> int64
template <> struct mixed_type_helper<int8, uint32> { typedef int64 type; };
template <> struct mixed_type_helper<uint32, int8> { typedef int64 type; };

// int16 with uint16 -> int32
template <> struct mixed_type_helper<int16, uint16> { typedef int32 type; };
template <> struct mixed_type_helper<uint16, int16> { typedef int32 type; };

// int16 with uint32 -> int64
template <> struct mixed_type_helper<int16, uint32> { typedef int64 type; };
template <> struct mixed_type_helper<uint32, int16> { typedef int64 type; };

// int32 with uint32 -> int64
template <> struct mixed_type_helper<int32, uint32> { typedef int64 type; };
template <> struct mixed_type_helper<uint32, int32> { typedef int64 type; };

// uint64 with index_t -> index_t
template <> struct mixed_type_helper<uint64, index_t> { typedef index_t type; };
template <> struct mixed_type_helper<index_t, uint64> { typedef index_t type; };
363 
// Pairings that involve bool_t.

// Integral type smaller than bool_t, bool_t on the right -> index_t.
template <typename Int>
struct mixed_type_helper<Int, bool_t,
                         typename enable_if<is_integral<Int>::value &&
                                            sizeof(Int) < sizeof(bool_t)>::type> {
  typedef index_t type;
};

// Integral type smaller than bool_t, bool_t on the left -> index_t.
template <typename Int>
struct mixed_type_helper<bool_t, Int,
                         typename enable_if<is_integral<Int>::value &&
                                            sizeof(Int) < sizeof(bool_t)>::type> {
  typedef index_t type;
};

// Integral type of the same size as bool_t, bool_t on the right -> that type.
// This case also covers the (bool_t, bool_t) pairing.
template <typename Int>
struct mixed_type_helper<Int, bool_t,
                         typename enable_if<is_integral<Int>::value &&
                                            sizeof(Int) == sizeof(bool_t)>::type> {
  typedef Int type;
};

// Integral type of the same size as bool_t, bool_t on the left -> that type.
// bool_t itself is excluded here because the specialization above already
// matches (bool_t, bool_t).
template <typename Int>
struct mixed_type_helper<bool_t, Int,
                         typename enable_if<is_integral<Int>::value &&
                                            !is_same<Int, bool_t>::value &&
                                            sizeof(Int) == sizeof(bool_t)>::type> {
  typedef Int type;
};
388 
// multi_mixed_type_helper extends mixed_type_helper from two types to an
// arbitrary pack by folding from the right.
template <typename... Types>
struct multi_mixed_type_helper;

// Empty pack -> void.
template <>
struct multi_mixed_type_helper<> {
  typedef void type;
};

// A single type is its own result.
template <typename Only>
struct multi_mixed_type_helper<Only> {
  typedef Only type;
};

// Two or more types: combine the first type with the result for the rest.
template <typename First, typename Second, typename... Rest>
struct multi_mixed_type_helper<First, Second, Rest...> {
  typedef typename mixed_type_helper<
      First,
      typename multi_mixed_type_helper<Second, Rest...>::type>::type type;
};

// Convenience alias for the combined type of a pack.
template <typename... Types>
using mixed_type = typename multi_mixed_type_helper<Types...>::type;
411 
412 } // namespace type_util
413 )code";
414 
415 const char util_string[] = R"code(
// Scoped enumeration of output request kinds; the enumerators take the values
// 0 through 3 in declaration order.
// NOTE(review): presumably mirrors a host-side enum of the same name; confirm.
enum class OpReqType {
  kNullOp = 0,
  kWriteTo = 1,
  kWriteInplace = 2,
  kAddTo = 3
};
422 
// NOTE(review): presumably the upper bound on threads per block for the
// runtime-compiled kernels; no use is visible in this file, confirm at the
// launch sites.
423 constexpr int kRTCMaxThreadsPerBlock = 512;
// Number of lanes assumed per warp; the warp shuffle helpers in namespace util
// use it as their loop bound and as the modulus applied to threadIdx.x.
424 constexpr int warp_size = 32;
425 
426 namespace util {
427 
// Fixed extent of the shape and stride arrays accepted by unravel_dot and
// unravel_ravel below.
428 constexpr int MAX_DIM = 5;
429 
// Splits the flat index `idx` into coordinates over the first `ndim` entries
// of `shape` (last dimension varies fastest) and accumulates two dot products
// of those coordinates: with `stridej` into *j and with `stridek` into *k.
// Both outputs are reset to 0 before accumulation.
template <int ndim>
__device__ inline void unravel_dot(const index_t idx, const index_t (&shape)[MAX_DIM],
                                   const index_t (&stridej)[MAX_DIM],
                                   const index_t (&stridek)[MAX_DIM],
                                   index_t* j, index_t* k) {
  *j = 0;
  *k = 0;
  index_t remaining = idx;
#pragma unroll
  for (index_t dim = ndim - 1; dim >= 0; --dim) {
    const index_t quotient = remaining / shape[dim];
    // Remainder of the division, i.e. the coordinate along `dim`.
    const index_t coord = remaining - quotient * shape[dim];
    *j += coord * stridej[dim];
    *k += coord * stridek[dim];
    remaining = quotient;
  }
}
444 
// Splits the flat index `idx` into coordinates over the first `ndim` entries
// of `shape` (last dimension varies fastest) and returns the dot product of
// those coordinates with `stride`.
template <int ndim>
__device__ inline index_t unravel_dot(const index_t idx, const index_t (&shape)[MAX_DIM],
                                      const index_t (&stride)[MAX_DIM]) {
  index_t offset = 0;
  index_t remaining = idx;
#pragma unroll
  for (index_t dim = ndim - 1; dim >= 0; --dim) {
    const index_t quotient = remaining / shape[dim];
    // Remainder of the division, i.e. the coordinate along `dim`.
    const index_t coord = remaining - quotient * shape[dim];
    offset += coord * stride[dim];
    remaining = quotient;
  }
  return offset;
}
457 
// Splits the flat index `idx` into coordinates over `shape1` and re-flattens
// them over `shape2` in one pass, walking from the last dimension to the
// first. A coordinate that is not below the matching extent of `shape2`
// contributes 0 to the result.
template <int ndim>
__device__ inline index_t unravel_ravel(const index_t idx, const index_t (&shape1)[MAX_DIM],
                                        const index_t (&shape2)[MAX_DIM]) {
  index_t flat = 0;
  // Product of the shape2 extents to the right of the current dimension.
  index_t trailing_extent = 1;
  index_t remaining = idx;
#pragma unroll
  for (index_t dim = ndim - 1; dim >= 0; --dim) {
    if (dim != ndim - 1) {
      trailing_extent *= shape2[dim + 1];
    }
    const index_t quotient = remaining / shape1[dim];
    const index_t coord = remaining - quotient * shape1[dim];
    flat += trailing_extent * (shape2[dim] > coord ? coord : 0);
    remaining = quotient;
  }
  return flat;
}
475 
// Flattens the `ndim` coordinates in `coord` into a single index over `shape`,
// with the last dimension varying fastest. A coordinate that is not below the
// matching extent of `shape` contributes 0.
//
// The loop reads shape[i] for every i < ndim, so `shape` must hold at least
// `ndim` entries; this is enforced at compile time.
template<int ndim, int ndim2>
__device__ inline index_t ravel(const index_t (&coord)[ndim], const index_t (&shape)[ndim2]) {
  static_assert(ndim <= ndim2, "ravel: shape must have at least as many entries as coord");
  index_t ret = 0;
#pragma unroll
  for (int i = 0; i < ndim; ++i) {
    ret = ret * shape[i] + (shape[i] > coord[i]) * coord[i];
  }
  return ret;
}
485 
// Splits the flat index `idx` into `ndim` coordinates over `shape`, walking
// from the last dimension to the first, and stores them in `coord`.
//
// The loop reads shape[i] for every i < ndim, so `shape` must hold at least
// `ndim` entries; this is enforced at compile time.
template<int ndim, int ndim2>
__device__ inline void unravel(const index_t idx,
                               const index_t (&shape)[ndim2],
                               index_t (&coord)[ndim]) {
  static_assert(ndim <= ndim2, "unravel: shape must have at least as many entries as coord");
#pragma unroll
  for (index_t i = ndim-1, j = idx; i >=0; --i) {
    auto tmp = j / shape[i];
    // Remainder of the division, i.e. the coordinate along dimension i.
    coord[i] = j - tmp*shape[i];
    j = tmp;
  }
}
497 
// isinf: reports whether a value is infinite. The generic version returns
// false for every type; the floating-point specializations defer to the
// global ::isinf.
template <typename DType>
__device__ inline bool isinf(volatile const DType &) {
  return false;
}

template <>
__device__ inline bool isinf<float>(volatile const float &val) {
  const float loaded = val;
  return ::isinf(loaded);
}

template <>
__device__ inline bool isinf<double>(volatile const double &val) {
  const double loaded = val;
  return ::isinf(loaded);
}

template <>
__device__ inline bool isinf<long double>(volatile const long double &val) {
  const long double loaded = val;
  return ::isinf(loaded);
}

// float16 is widened to float first; the volatile qualifier is cast away to
// perform the conversion.
template <>
__device__ inline bool isinf<float16>(volatile const float16 &val) {
  const float widened = __half2float(const_cast<const float16&>(val));
  return ::isinf(widened);
}
522 
// isnan: reports whether a value is not-a-number. The generic version returns
// false for every type; the floating-point specializations defer to the
// global ::isnan.
template <typename DType>
__device__ inline bool isnan(volatile const DType &) {
  return false;
}

template <>
__device__ inline bool isnan<float>(volatile const float &val) {
  const float loaded = val;
  return ::isnan(loaded);
}

template <>
__device__ inline bool isnan<double>(volatile const double &val) {
  const double loaded = val;
  return ::isnan(loaded);
}

template <>
__device__ inline bool isnan<long double>(volatile const long double &val) {
  const long double loaded = val;
  return ::isnan(loaded);
}

// float16 is widened to float first; the volatile qualifier is cast away to
// perform the conversion.
template <>
__device__ inline bool isnan<float16>(volatile const float16 &val) {
  const float widened = __half2float(const_cast<const float16&>(val));
  return ::isnan(widened);
}
547 
// Combines `value` across the lanes of a warp with `redfun`, using
// __shfl_down_sync at offsets warp_size/2, warp_size/4, ..., 1 under the full
// mask 0xffffffff. Steps whose offset is not below NVALUES are skipped.
// Returns this lane's accumulated value.
template <int NVALUES = warp_size, typename OP, typename T>
__device__ inline T warp_reduce(T value, OP redfun) {
#pragma unroll
  for (int offset = warp_size / 2; offset > 0; offset /= 2) {
    if (offset < NVALUES) {
      value = redfun(value, __shfl_down_sync(0xffffffff, value, offset));
    }
  }
  return value;
}
556 
// Combines `value` with `redfun` using __shfl_down_sync at offsets
// 1, 2, 4, ... while the offset is below `group_size`, under the full mask
// 0xffffffff. Returns this lane's accumulated value.
template <typename OP, typename T>
__device__ inline T grouped_warp_reduce(T value, OP redfun, const int group_size) {
  int offset = 1;
  while (offset < group_size) {
    value = redfun(value, __shfl_down_sync(0xffffffff, value, offset));
    offset *= 2;
  }
  return value;
}
564 
// Runs grouped_warp_reduce and then hands every lane the value held by
// source lane 0 of its `group_size`-wide segment (the width argument of
// __shfl_sync).
template <typename OP, typename T>
__device__ inline T grouped_warp_allreduce(T value, OP redfun, const int group_size) {
  const T reduced = grouped_warp_reduce(value, redfun, group_size);
  value = reduced;
  return __shfl_sync(0xffffffff, value, 0, group_size);
}
570 
// Combines `value` with `redfun` using __shfl_down_sync at offsets
// warp_size/2, warp_size/4, ... down to and including `group_size`, under the
// full mask 0xffffffff. Returns this lane's accumulated value.
template <typename OP, typename T>
__device__ inline T strided_grouped_warp_reduce(T value, OP redfun, const int group_size) {
  int offset = warp_size / 2;
  while (offset >= group_size) {
    value = redfun(value, __shfl_down_sync(0xffffffff, value, offset));
    offset /= 2;
  }
  return value;
}
578 
// Runs strided_grouped_warp_reduce and then copies the results upward with
// __shfl_up_sync at offsets group_size, 2*group_size, ... below warp_size.
// A thread adopts the shuffled value only when threadIdx.x % warp_size is at
// least the current offset.
template <typename OP, typename T>
__device__ inline T strided_grouped_warp_allreduce(T value, OP redfun, const int group_size) {
  value = strided_grouped_warp_reduce(value, redfun, group_size);
  int offset = group_size;
  while (offset < warp_size) {
    // Every thread takes part in the shuffle; only some keep its result.
    T from_below = __shfl_up_sync(0xffffffff, value, offset);
    if (threadIdx.x % warp_size >= offset) {
      value = from_below;
    }
    offset *= 2;
  }
  return value;
}
590 
591 } // namespace util
592 )code";
593 
594 const char limits[] = R"code(
// Largest finite double / float values; MaxValue returns them and MinValue
// returns their negation.
595 constexpr double DBL_MAX = 1.7976931348623157081e+308;
596 constexpr float FLT_MAX = 3.4028234663852885981e+38;
// Object-like macros: every later identifier spelled `inf` or `nan` is
// replaced by these expressions. PosInfValue and NegInfValue return inf / -inf.
// NOTE(review): 1e50 is above FLT_MAX, so the cast presumably yields positive
// infinity; confirm that the compiler in use guarantees this.
597 #define inf ((float)1e50)
// inf minus inf.
598 #define nan (inf - inf)
599 
600 namespace limits {
601 
// MinValue<DType>() returns the smallest finite value of DType. Only the
// specializations are defined; the primary template is declared without a
// body.
template <typename DType>
__device__ inline DType MinValue();

// Floating point: the negated largest finite value.
template <>
__device__ inline float MinValue<float>() {
  return -FLT_MAX;
}

template <>
__device__ inline double MinValue<double>() {
  return -DBL_MAX;
}

// Unsigned integers: zero.
template <>
__device__ inline uint8 MinValue<uint8>() {
  return static_cast<uint8>(0);
}

template <>
__device__ inline uint16 MinValue<uint16>() {
  return static_cast<uint16>(0);
}

template <>
__device__ inline uint32 MinValue<uint32>() {
  return static_cast<uint32>(0);
}

template <>
__device__ inline uint64 MinValue<uint64>() {
  return static_cast<uint64>(0);
}

// Signed integers.
template <>
__device__ inline int8 MinValue<int8>() {
  return static_cast<int8>(-128);
}

template <>
__device__ inline int16 MinValue<int16>() {
  return static_cast<int16>(-32768);
}

template <>
__device__ inline int32 MinValue<int32>() {
  return -2147483647 - 1;
}
// Smallest int64 value. It is written as (-max - 1) because the magnitude
// 9223372036854775808 does not fit in long long: the literal would already be
// out of range before the unary minus is applied.
template<>
__device__ inline int64 MinValue<int64>(void) {
  return -9223372036854775807LL - 1;
}
// Smallest bool value.
template <>
__device__ inline bool MinValue<bool>() {
  return false;
}

// bool_t wraps an index_t, so its minimum is the minimum of index_t.
template <>
__device__ inline bool_t MinValue<bool_t>() {
  return bool_t(MinValue<index_t>());
}
664 
// NegInfValue<DType>() returns negative infinity for float and double, and
// falls back to MinValue<DType>() for every other type.
template <typename DType>
__device__ inline DType NegInfValue() {
  return MinValue<DType>();
}

template <>
__device__ inline float NegInfValue<float>() {
  return -inf;
}

template <>
__device__ inline double NegInfValue<double>() {
  return -inf;
}
683 
// MaxValue<DType>() returns the largest finite value of DType. Only the
// specializations are defined; the primary template is declared without a
// body.
template <typename DType>
__device__ inline DType MaxValue();

// Floating point.
template <>
__device__ inline float MaxValue<float>() {
  return FLT_MAX;
}

template <>
__device__ inline double MaxValue<double>() {
  return DBL_MAX;
}

// Unsigned integers up to 32 bits.
template <>
__device__ inline uint8 MaxValue<uint8>() {
  return static_cast<uint8>(255);
}

template <>
__device__ inline uint16 MaxValue<uint16>() {
  return static_cast<uint16>(65535);
}

template <>
__device__ inline uint32 MaxValue<uint32>() {
  return 4294967295U;
}
// Largest uint64 value. The literal carries an unsigned suffix because
// 18446744073709551615 does not fit in a signed long long.
template<>
__device__ inline uint64 MaxValue<uint64>(void) {
  return 18446744073709551615ULL;
}
// Signed integers.
template <>
__device__ inline int8 MaxValue<int8>() {
  return static_cast<int8>(127);
}

template <>
__device__ inline int16 MaxValue<int16>() {
  return static_cast<int16>(32767);
}

template <>
__device__ inline int32 MaxValue<int32>() {
  return 2147483647;
}

template <>
__device__ inline int64 MaxValue<int64>() {
  return 9223372036854775807LL;
}

// Largest bool value.
template <>
__device__ inline bool MaxValue<bool>() {
  return true;
}

// bool_t wraps an index_t, so its maximum is the maximum of index_t.
template <>
__device__ inline bool_t MaxValue<bool_t>() {
  return bool_t(MaxValue<index_t>());
}
// PosInfValue<DType>() returns positive infinity for float and double, and
// falls back to MaxValue<DType>() for every other type.
template <typename DType>
__device__ inline DType PosInfValue() {
  return MaxValue<DType>();
}

template <>
__device__ inline float PosInfValue<float>() {
  return inf;
}

template <>
__device__ inline double PosInfValue<double>() {
  return inf;
}
768 
769 } // namespace limits
770 )code";
771 } // namespace rtc
772 } // namespace cuda
773 } // namespace common
774 } // namespace mxnet
775 
776 #endif // MXNET_USE_CUDA
777 
778 #endif // MXNET_COMMON_CUDA_RTC_UTIL_INL_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::common::cuda::rtc::type_support_string
const char type_support_string[]
Definition: util-inl.h:32
mxnet::common::cuda::rtc::limits
const char limits[]
Definition: util-inl.h:594
base.h
configuration of MXNet as well as basic data structure.
mxnet::common::cuda::rtc::util_string
const char util_string[]
Definition: util-inl.h:415