mxnet
base.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_BASE_H_
27 #define MSHADOW_BASE_H_
// MSVC-only switches: silence the CRT "unsafe function" deprecation warnings
// and ask the Windows headers not to define min/max macros (NOMINMAX).
// Each macro is only defined if the build has not already set it.
#if defined(_MSC_VER)
#if !defined(_CRT_SECURE_NO_WARNINGS)
#define _CRT_SECURE_NO_WARNINGS
#endif
#if !defined(_CRT_SECURE_NO_DEPRECATE)
#define _CRT_SECURE_NO_DEPRECATE
#endif
#if !defined(NOMINMAX)
#define NOMINMAX
#endif
#endif  // defined(_MSC_VER)
39 #include <algorithm>
40 #include <cfloat>
41 #include <climits>
42 #include <cmath>
43 #include <cstdio>
44 #include <functional>
45 #include <limits>
46 #include <sstream>
47 #include <string>
48 
// Fixed-width integer types. MSVC gets explicit typedefs built from its sized
// built-in integer types; every other compiler takes them from <inttypes.h>.
#if defined(_MSC_VER)
typedef signed char int8_t;
typedef unsigned char uint8_t;
typedef __int16 int16_t;
typedef unsigned __int16 uint16_t;
typedef __int32 int32_t;
typedef unsigned __int32 uint32_t;
typedef __int64 int64_t;
typedef unsigned __int64 uint64_t;
#else
#include <inttypes.h>
#endif
// ---------------------------------------------------------------------------
// Build-configuration macro definitions. Each may be predefined by the build;
// the values here are fallbacks only.
// ---------------------------------------------------------------------------

// MSHADOW_STAND_ALONE (default 0): when non-zero, the CBLAS, MKL and CUDA
// backends are all switched off below.
#if !defined(MSHADOW_STAND_ALONE)
#define MSHADOW_STAND_ALONE 0
#endif

// MSHADOW_ALLOC_PAD (default true).
#if !defined(MSHADOW_ALLOC_PAD)
#define MSHADOW_ALLOC_PAD true
#endif

// MSHADOW_MIN_PAD_RATIO (default 2).
#if !defined(MSHADOW_MIN_PAD_RATIO)
#define MSHADOW_MIN_PAD_RATIO 2
#endif

// Stand-alone builds: force the optional backends off. This runs before the
// backends' own #ifndef defaults, so these values win.
#if MSHADOW_STAND_ALONE
#define MSHADOW_USE_CBLAS 0
#define MSHADOW_USE_MKL 0
#define MSHADOW_USE_CUDA 0
#endif
92 
// MSHADOW_FORCE_STREAM (default 1).
#if !defined(MSHADOW_FORCE_STREAM)
#define MSHADOW_FORCE_STREAM 1
#endif

// Optional math / accelerator backends. Defaults: CBLAS off, MKL on, CUDA on,
// cuDNN off, cuTENSOR off, and cuSOLVER follows MSHADOW_USE_CUDA.
#if !defined(MSHADOW_USE_CBLAS)
#define MSHADOW_USE_CBLAS 0
#endif

#if !defined(MSHADOW_USE_MKL)
#define MSHADOW_USE_MKL 1
#endif

#if !defined(MSHADOW_USE_CUDA)
#define MSHADOW_USE_CUDA 1
#endif

#if !defined(MSHADOW_USE_CUDNN)
#define MSHADOW_USE_CUDNN 0
#endif

#if !defined(MSHADOW_USE_CUTENSOR)
#define MSHADOW_USE_CUTENSOR 0
#endif

#if !defined(MSHADOW_USE_CUSOLVER)
#define MSHADOW_USE_CUSOLVER MSHADOW_USE_CUDA
#endif

// MSHADOW_OLD_CUDA (default 0).
#if !defined(MSHADOW_OLD_CUDA)
#define MSHADOW_OLD_CUDA 0
#endif

// SSE code paths (default on; turned off again when compiling under nvcc).
#if !defined(MSHADOW_USE_SSE)
#define MSHADOW_USE_SSE 1
#endif
151 
// MSHADOW_USE_F16C (presumably the x86 F16C half-precision conversion
// instructions). Off for MSVC, for nvcc, and for clang older than 8.1;
// on for everything else.
#if !defined(MSHADOW_USE_F16C)
#if defined(_MSC_VER) || defined(__CUDACC__)
#define MSHADOW_USE_F16C 0
#elif defined(__clang__) && \
    (__clang_major__ < 8 || (__clang_major__ == 8 && __clang_minor__ < 1))
#define MSHADOW_USE_F16C 0
#else
#define MSHADOW_USE_F16C 1
#endif
#endif  // !defined(MSHADOW_USE_F16C)

// MSHADOW_USE_NVML (default 0).
#if !defined(MSHADOW_USE_NVML)
#define MSHADOW_USE_NVML 0
#endif

// SSE conflicts with nvcc, so it is forced off whenever __CUDACC__ is set,
// overriding any earlier value.
#if defined(__CUDACC__)
#undef MSHADOW_USE_SSE
#define MSHADOW_USE_SSE 0
#endif
173 
174 #if MSHADOW_USE_CBLAS
175 extern "C" {
176  #include <cblas.h>
177 }
178 #elif MSHADOW_USE_MKL
179  #if MSHADOW_INT64_TENSOR_SIZE == 1
180  // Define MKL_INT here to use exactly the same 64bits integer type definitions.
181  // If MKL_INT will not be defined here, the mkl header defines it as long long int.
182  #define MKL_INT int64_t
183  #define MKL_UINT uint64_t
184  #endif
185  #include <mkl_blas.h>
186  #include <mkl_cblas.h>
187  #include <mkl_vsl.h>
188  #include <mkl_vsl_functions.h>
189  #include <mkl_version.h>
190 #endif
191 
192 #if MSHADOW_USE_CUDA
193  #include <cuda.h>
194  #include <cublas_v2.h>
195  #include <curand.h>
196 #endif
197 
198 #if MSHADOW_USE_CUDNN == 1
199  #include <cudnn.h>
200 #endif
201 
202 #if MSHADOW_USE_CUTENSOR == 1
203  #include <cutensor.h>
204 #endif
205 
206 #if MSHADOW_USE_CUSOLVER == 1
207  #include <cusolverDn.h>
208 #endif
209 
210 #if MSHADOW_USE_NVML
211  #include <nvml.h>
212 #endif
213 
// --------------------------------
// Inlining macros. MSHADOW_XINLINE marks template code that must inline on
// both the CPU and (under nvcc) the GPU; this header owns the name, so a
// pre-existing definition is a hard error.
#if defined(MSHADOW_XINLINE)
#error "MSHADOW_XINLINE must not be defined"
#endif

// Compiler-specific spelling of "always inline".
#if defined(_MSC_VER)
#define MSHADOW_FORCE_INLINE __forceinline
#pragma warning(disable : 4068)  // C4068: unknown pragma
#else
#define MSHADOW_FORCE_INLINE inline __attribute__((always_inline))
#endif

// Under nvcc, XINLINE functions are additionally compiled for device and host.
#if !defined(__CUDACC__)
#define MSHADOW_XINLINE MSHADOW_FORCE_INLINE
#else
#define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__
#endif

// Forced inline without any __device__ qualifier (host-only code).
#define MSHADOW_CINLINE MSHADOW_FORCE_INLINE
232 
// Default-template-argument clause for element types: expands to
// "= ::mshadow::default_real_t" (presumably written as
// `template<typename DType MSHADOW_DEFAULT_DTYPE>`; confirm at use sites).
#if !defined(MSHADOW_DEFAULT_DTYPE)
#define MSHADOW_DEFAULT_DTYPE = ::mshadow::default_real_t
#endif

// Logging backend switch; mirrors DMLC_USE_GLOG unless overridden.
#if !defined(MSHADOW_USE_GLOG)
#define MSHADOW_USE_GLOG DMLC_USE_GLOG
#endif  // MSHADOW_USE_GLOG

// Exception specifications for functions that may / may not throw.
#define MSHADOW_THROW_EXCEPTION noexcept(false)
#define MSHADOW_NO_EXCEPTION noexcept(true)

// Portable alignment attribute: MSHADOW_ALIGNED(16) aligns a type to 16 bytes.
#if !defined(_MSC_VER)
#define MSHADOW_ALIGNED(x) __attribute__ ((aligned(x)))
#else
#define MSHADOW_ALIGNED(x) __declspec(align(x))
#endif
258 
// Evaluates a CUDA runtime call exactly once and checks its cudaError_t.
// cudaErrorCudartUnloading is thrown as dmlc::Error instead of failing the
// CHECK; every other non-success code fails CHECK_EQ with the CUDA error
// string. The status variable has a deliberately unusual name. With a plain
// `e`, an argument that mentions a caller variable `e` (for example
// MSHADOW_CUDA_CALL(e)) would expand to `cudaError_t e = (e);`, which reads the
// new, uninitialised local instead of the caller's value.
#define MSHADOW_CUDA_CALL(func)                                           \
  {                                                                       \
    cudaError_t mshadow_cuda_call_status_ = (func);                       \
    if (mshadow_cuda_call_status_ == cudaErrorCudartUnloading) {          \
      throw dmlc::Error(cudaGetErrorString(mshadow_cuda_call_status_));   \
    }                                                                     \
    CHECK_EQ(mshadow_cuda_call_status_, cudaSuccess)                      \
        << "CUDA: " << cudaGetErrorString(mshadow_cuda_call_status_);     \
  }
273 
// Runs `func` and swallows any dmlc::Error it throws. The error is logged at
// ERROR level unless its message contains "driver shutting down", in which
// case it is dropped silently.
#define MSHADOW_CATCH_ERROR(func)                                              \
  {                                                                            \
    try {                                                                      \
      (func);                                                                  \
    } catch (const dmlc::Error &mshadow_caught_err_) {                         \
      const std::string mshadow_err_msg_ = mshadow_caught_err_.what();         \
      const bool mshadow_driver_down_ =                                        \
          mshadow_err_msg_.find("driver shutting down") != std::string::npos;  \
      if (!mshadow_driver_down_) {                                             \
        LOG(ERROR) << "Ignore CUDA Error " << mshadow_err_msg_;                \
      }                                                                        \
    }                                                                          \
  }
289 
290 #include "./half.h"
291 #include "./bfloat.h"
292 #define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \
293  MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \
294  return float(a) OP float(b); /* NOLINT(*) */ \
295  } \
296  MSHADOW_XINLINE RTYPE operator OP(mshadow::bfloat::bf16_t a, mshadow::half::half_t b) { \
297  return float(a) OP float(b); /* NOLINT(*) */ \
298  }
299 
301 MSHADOW_HALF_BF_OPERATOR(float, +)
303 MSHADOW_HALF_BF_OPERATOR(float, -)
305 MSHADOW_HALF_BF_OPERATOR(float, *)
307 MSHADOW_HALF_BF_OPERATOR(float, /)
313 MSHADOW_HALF_BF_OPERATOR(bool, >=)
315 MSHADOW_HALF_BF_OPERATOR(bool, <=)
316 
317 #include "dmlc/logging.h"
318 
319 namespace mshadow {
/*! \brief random-number buffer size constant */
const unsigned kRandBufferSize = 1000000;
/*!
 * \brief pi as a float. The full-precision literal rounds to the float closest
 *  to pi (3.14159274...); the previous literal 3.1415926f rounded one ULP low
 *  (3.14159250...).
 */
const float kPi = 3.14159265358979323846f;
325 #if MSHADOW_INT64_TENSOR_SIZE == 1
326  typedef int64_t index_t;
327 #else
328  typedef int32_t index_t;
329 #endif
330 
331 #ifdef _WIN32
332 
333  typedef int64_t openmp_index_t;
334 #else
335 
337 #endif
338 
339 
// Integer type handed to LAPACK. It follows index_t only for MKL builds with
// MXNET_USE_LAPACK or for ILP64 LAPACKE builds, and is a plain int otherwise.
// lapack_index_t could be replaced by index_t and removed once every BLAS
// library supports large tensors.
#if !((MSHADOW_USE_MKL && MXNET_USE_LAPACK) || MXNET_USE_ILP64_LAPACKE)
typedef int lapack_index_t;
#else
typedef index_t lapack_index_t;
#endif
346 
/*! \brief default floating-point element type */
typedef float default_real_t;

/*! \brief numeric codes identifying tensor element types */
enum TypeFlag {
  kFloat32 = 0,
  kFloat64 = 1,
  kFloat16 = 2,
  kUint8 = 3,
  kInt32 = 4,
  kInt8 = 5,
  kInt64 = 6,
  kBool = 7,
  kInt16 = 8,
  kUint16 = 9,
  kUint32 = 10,
  kUint64 = 11,
  // DataType<bfloat::bf16_t> and the type-switch macros refer to kBfloat16,
  // so it must be declared; it takes the next free code.
  kBfloat16 = 12,
};
366 
367 template<typename DType>
368 struct DataType;
369 
370 template<>
371 struct DataType<float> {
372  static const int kFlag = kFloat32;
373  static const int kLanes = 1;
374 #if MSHADOW_USE_CUDA
375 #if (CUDA_VERSION >= 8000)
376  static const cudaDataType_t kCudaFlag = CUDA_R_32F;
377 #endif
378 #if MSHADOW_USE_CUDNN
379  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT;
380  typedef float ScaleType;
381 #endif
382 #endif
383 };
384 template<>
385 struct DataType<double> {
386  static const int kFlag = kFloat64;
387  static const int kLanes = 1;
388 #if MSHADOW_USE_CUDA
389 #if (CUDA_VERSION >= 8000)
390  static const cudaDataType_t kCudaFlag = CUDA_R_64F;
391 #endif
392 #if MSHADOW_USE_CUDNN
393  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE;
394  typedef double ScaleType;
395 #endif
396 #endif
397 };
398 template<>
399 struct DataType<half::half_t> {
400  static const int kFlag = kFloat16;
401  static const int kLanes = 1;
402 #if MSHADOW_USE_CUDA
403 #if (CUDA_VERSION >= 8000)
404  static const cudaDataType_t kCudaFlag = CUDA_R_16F;
405 #endif
406 #if MSHADOW_USE_CUDNN
407  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF;
408  typedef float ScaleType;
409 #endif
410 #endif
411 };
412 template <>
413 struct DataType<bfloat::bf16_t> {
414  static const int kFlag = kBfloat16;
415  static const int kLanes = 1;
416 };
417 template<>
418 struct DataType<uint8_t> {
419  static const int kFlag = kUint8;
420  static const int kLanes = 1;
421 #if MSHADOW_USE_CUDA
422 #if (CUDA_VERSION >= 8000)
423  static const cudaDataType_t kCudaFlag = CUDA_R_8U;
424 #endif
425 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
426  // no uint8 in cudnn for now
427  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
428  typedef uint8_t ScaleType;
429 #endif
430 #endif
431 };
432 template<>
433 struct DataType<int8_t> {
434  static const int kFlag = kInt8;
435  static const int kLanes = 1;
436 #if MSHADOW_USE_CUDA
437 #if (CUDA_VERSION >= 8000)
438  static const cudaDataType_t kCudaFlag = CUDA_R_8I;
439 #endif
440 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
441  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
442  typedef int8_t ScaleType;
443 #endif
444 #endif
445 };
446 template<>
447 struct DataType<int32_t> {
448  static const int kFlag = kInt32;
449  static const int kLanes = 1;
450 #if MSHADOW_USE_CUDA
451 #if (CUDA_VERSION >= 8000)
452  static const cudaDataType_t kCudaFlag = CUDA_R_32I;
453 #endif
454 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
455  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT32;
456  typedef int32_t ScaleType;
457 #endif
458 #endif
459 };
460 template<>
461 struct DataType<int64_t> {
462  static const int kFlag = kInt64;
463  static const int kLanes = 1;
464 };
465 template<>
466 struct DataType<bool> {
467  static const int kFlag = kBool;
468  static const int kLanes = 1;
469 };
470 template<>
471 struct DataType<int16_t> {
472  static const int kFlag = kInt16;
473  static const int kLanes = 1;
474 };
475 template<>
476 struct DataType<uint16_t> {
477  static const int kFlag = kUint16;
478  static const int kLanes = 1;
479 };
480 template<>
481 struct DataType<uint32_t> {
482  static const int kFlag = kUint32;
483  static const int kLanes = 1;
484 };
485 template<>
486 struct DataType<uint64_t> {
487  static const int kFlag = kUint64;
488  static const int kLanes = 1;
489 };
490 
/*!
 * \brief codes for tensor memory layouts, grouped by rank: 4-d codes start at
 *  0, 3-d codes at 1 << 3 and 5-d codes at 1 << 5.
 */
enum LayoutFlag {
  /*! \brief unrecognised layout */
  kUNKNOWN = -1,

  /*! \brief 4-d layouts */
  kNCHW = 0,
  kNHWC,
  kCHWN,

  /*! \brief 3-d layouts */
  kNCW = 1 << 3,
  kNWC,
  kCWN,

  /*! \brief 5-d layouts */
  kNCDHW = 1 << 5,
  kNDHWC,
  kCDHWN
};

/*!
 * \brief parse a layout name such as "NCHW" into its LayoutFlag.
 * \param layoutstr layout name; matching is exact and case-sensitive.
 * \return the matching flag, or kUNKNOWN for any other string.
 */
inline LayoutFlag layoutFlag(const std::string& layoutstr) {
  // The length selects the candidate group; then compare against each name.
  switch (layoutstr.length()) {
  case 4:
    if (layoutstr == "NHWC")
      return kNHWC;
    if (layoutstr == "NCHW")
      return kNCHW;
    if (layoutstr == "CHWN")
      return kCHWN;
    return kUNKNOWN;
  case 3:
    if (layoutstr == "NWC")
      return kNWC;
    if (layoutstr == "NCW")
      return kNCW;
    if (layoutstr == "CWN")
      return kCWN;
    return kUNKNOWN;
  case 5:
    if (layoutstr == "NDHWC")
      return kNDHWC;
    if (layoutstr == "NCDHW")
      return kNCDHW;
    if (layoutstr == "CDHWN")
      return kCDHWN;
    return kUNKNOWN;
  default:
    return kUNKNOWN;
  }
}

/*!
 * \brief inverse of layoutFlag: the name of a layout code.
 * \return the layout name, or "" for kUNKNOWN and any unrecognised code.
 */
inline std::string toString(LayoutFlag layout) {
  switch (layout) {
  case kUNKNOWN:
    return "";
  case kNCHW:
    return "NCHW";
  case kNHWC:
    return "NHWC";
  case kCHWN:
    return "CHWN";
  case kNCW:
    return "NCW";
  case kNWC:
    return "NWC";
  case kCWN:
    return "CWN";
  case kNCDHW:
    return "NCDHW";
  case kNDHWC:
    return "NDHWC";
  case kCDHWN:
    return "CDHWN";
  default:
    return "";
  }
}
571 
572 template<int layout>
573 struct LayoutType;
574 
575 template<>
576 struct LayoutType<kNCHW> {
577  static const index_t kNdim = 4;
578 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
579  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
580 #else
581  static const int kCudnnFlag = -1;
582 #endif
583 };
584 
585 template<>
586 struct LayoutType<kNHWC> {
587  static const index_t kNdim = 4;
588 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
589  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
590 #else
591  static const int kCudnnFlag = -1;
592 #endif
593 };
594 
596 const int default_layout = kNCHW;
597 
598 template<>
600  static const index_t kNdim = 5;
601 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
602  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
603 #else
604  static const int kCudnnFlag = -1;
605 #endif
606 };
607 
608 template<>
610  static const index_t kNdim = 5;
611 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
612  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
613 #else
614  static const int kCudnnFlag = -1;
615 #endif
616 };
617 
620 
622 namespace op {
623 // binary operator
625 struct mul{
627  template<typename DType>
628  MSHADOW_XINLINE static DType Map(DType a, DType b) {
629  return a * b;
630  }
631 };
633 struct plus {
635  template<typename DType>
636  MSHADOW_XINLINE static DType Map(DType a, DType b) {
637  return a + b;
638  }
639 };
641 struct minus {
643  template<typename DType>
644  MSHADOW_XINLINE static DType Map(DType a, DType b) {
645  return a - b;
646  }
647 };
649 struct div {
651  template<typename DType>
652  MSHADOW_XINLINE static DType Map(DType a, DType b) {
653  return a / b;
654  }
655 };
657 struct right {
659  template<typename DType>
660  MSHADOW_XINLINE static DType Map(DType a, DType b) {
661  return b;
662  }
663 };
664 // unary operator/ function: example
665 // these operators can be defined by user,
666 // in the same style as binary and unary operator
667 // to use, simply write F<op::identity>( src )
669 struct identity{
671  template<typename DType>
672  MSHADOW_XINLINE static DType Map(DType a) {
673  return a;
674  }
675 };
676 } // namespace op
678 namespace sv {
680 struct saveto {
682  template<typename DType>
683  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
684  a = b;
685  }
687  inline static default_real_t AlphaBLAS(void) { return 1.0f; }
689  inline static default_real_t BetaBLAS(void) { return 0.0f; }
691  typedef op::right OPType;
692 };
694 struct plusto {
696  template<typename DType>
697  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
698  a += b;
699  }
701  inline static default_real_t AlphaBLAS(void) { return 1.0f; }
703  inline static default_real_t BetaBLAS(void) { return 1.0f; }
705  typedef op::plus OPType;
706 };
708 struct minusto {
710  template<typename DType>
711  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
712  a -= b;
713  }
715  inline static default_real_t AlphaBLAS(void) { return -1.0f; }
717  inline static default_real_t BetaBLAS(void) { return 1.0f; }
719  typedef op::minus OPType;
720 };
722 struct multo {
724  template<typename DType>
725  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
726  a *= b;
727  }
729  typedef op::mul OPType;
730 };
732 struct divto {
734  template<typename DType>
735  MSHADOW_XINLINE static void Save(DType& a, DType b) { // NOLINT(*)
736  a /= b;
737  }
739  typedef op::div OPType;
740 };
741 } // namespace sv
742 
743 #ifndef __CUDA_ARCH__
744 using std::isnan;
745 using std::isinf;
746 #endif
747 
751 namespace isnan_typed {
752  template<typename DType>
753  MSHADOW_XINLINE bool IsNan(volatile DType val) {
754  return false;
755  }
756  template<>
757  MSHADOW_XINLINE bool IsNan(volatile float val) {
758  return isnan(val);
759  }
760  template<>
761  MSHADOW_XINLINE bool IsNan(volatile double val) {
762  return isnan(val);
763  }
764  template<>
765  MSHADOW_XINLINE bool IsNan(volatile long double val) {
766  return isnan(val);
767  }
768  template<>
769  MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
770  return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) > MSHADOW_HALF_EXPONENT_BITS;
771  }
772  template <>
773  MSHADOW_XINLINE bool IsNan(volatile mshadow::bfloat::bf16_t val) {
774  return (val.bf16_ & (~MSHADOW_BF16_SIGN_BIT)) > MSHADOW_BF16_EXPONENT_BITS;
775  }
776 } // namespace isnan_typed
777 
781 namespace isinf_typed {
782  template<typename DType>
783  MSHADOW_XINLINE bool IsInf(volatile DType val) {
784  return false;
785  }
786  template<>
787  MSHADOW_XINLINE bool IsInf(volatile float val) {
788  return isinf(val);
789  }
790  template<>
791  MSHADOW_XINLINE bool IsInf(volatile double val) {
792  return isinf(val);
793  }
794  template<>
795  MSHADOW_XINLINE bool IsInf(volatile long double val) {
796  return isinf(val);
797  }
798  template<>
799  MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) {
800  return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) == MSHADOW_HALF_EXPONENT_BITS;
801  }
802  template <>
803  MSHADOW_XINLINE bool IsInf(volatile mshadow::bfloat::bf16_t val) {
804  return (val.bf16_ & (~MSHADOW_BF16_SIGN_BIT)) == MSHADOW_BF16_EXPONENT_BITS;
805  }
806 } // namespace isinf_typed
807 
809 namespace red {
810 namespace limits {
815 template<typename DType>
816 MSHADOW_XINLINE DType MinValue(void);
818 template<>
820  return -FLT_MAX;
821 }
823 template<>
825  return -DBL_MAX;
826 }
828 template<>
829 MSHADOW_XINLINE half::half_t MinValue<half::half_t>(void) {
830  return MSHADOW_HALF_MIN;
831 }
833 template<>
834 MSHADOW_XINLINE bfloat::bf16_t MinValue<bfloat::bf16_t>(void) {
835  return MSHADOW_BF16_MIN;
836 }
838 template<>
840  return 0;
841 }
843 template<>
845  return SCHAR_MIN;
846 }
848 template<>
850  return INT_MIN;
851 }
853 template<>
855  return LLONG_MIN;
856 }
858 template<>
860  return false;
861 }
863 template<>
865  return 0;
866 }
867 
872 template<typename DType>
874  return MinValue<DType>();
875 }
877 template<>
879  return -HUGE_VALF;
880 }
882 template<>
884  return -HUGE_VAL;
885 }
887 template<>
888 MSHADOW_XINLINE half::half_t NegInfValue<half::half_t>(void) {
889  return half::half_t::Binary(
891 }
893 template <>
894 MSHADOW_XINLINE bfloat::bf16_t NegInfValue<bfloat::bf16_t>(void) {
895  return bfloat::bf16_t::Binary(MSHADOW_BF16_SIGN_BIT | MSHADOW_BF16_EXPONENT_BITS);
896 }
897 
902 template<typename DType>
903 MSHADOW_XINLINE DType MaxValue(void);
905 template<>
907  return FLT_MAX;
908 }
910 template<>
912  return DBL_MAX;
913 }
915 template<>
916 MSHADOW_XINLINE half::half_t MaxValue<half::half_t>(void) {
917  return MSHADOW_HALF_MAX;
918 }
920 template<>
921 MSHADOW_XINLINE bfloat::bf16_t MaxValue<bfloat::bf16_t>(void) {
922  return MSHADOW_BF16_MAX;
923 }
925 template<>
927  return UCHAR_MAX;
928 }
930 template<>
932  return SCHAR_MAX;
933 }
935 template<>
937  return INT_MAX;
938 }
940 template<>
942  return LLONG_MAX;
943 }
945 template<>
947  return true;
948 }
950 template<>
952  return std::numeric_limits<uint32_t>::max();
953 }
954 
959 template<typename DType>
961  return MaxValue<DType>();
962 }
964 template<>
966  return HUGE_VALF;
967 }
969 template<>
971  return HUGE_VAL;
972 }
974 template<>
975 MSHADOW_XINLINE half::half_t PosInfValue<half::half_t>(void) {
976  return half::half_t::Binary(MSHADOW_HALF_EXPONENT_BITS);
977 }
979 template <>
980 MSHADOW_XINLINE bfloat::bf16_t PosInfValue<bfloat::bf16_t>(void) {
981  return bfloat::bf16_t::Binary(MSHADOW_BF16_EXPONENT_BITS);
982 }
983 
984 } // namespace limits
985 
987 struct sum {
989  template<typename DType>
990  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
991  dst += src;
992  }
994  template<typename DType>
995  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
996  DType y = src - residual;
997  DType t = dst + y;
998  if (isinf_typed::IsInf(t)) {
999  residual = 0;
1000  } else {
1001  residual = (t - dst) - y;
1002  }
1003  dst = t;
1004  }
1006  template<typename DType>
1007  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
1008  Reduce(dst_val, src_val);
1009  }
1011  template<typename DType>
1012  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
1013  DType t1 = dst_val + src_val;
1014  if (isinf_typed::IsInf(t1)) {
1015  dst_val = t1;
1016  dst_residual = 0;
1017  } else {
1018  DType e = t1 - dst_val;
1019  DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
1020  dst_val = t1 + t2;
1021  dst_residual = t2 - (dst_val - t1);
1022  }
1023  }
1025  template<typename DType>
1026  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
1028  template<typename DType>
1029  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
1034  template<typename DType>
1035  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
1036  return 1;
1037  }
1041  template<typename DType>
1042  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
1043  initv = 0;
1044  }
1048  template<typename DType>
1049  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
1050  SetInitValue(initv);
1051  residual = 0;
1052  }
1053 };
1055 struct maximum {
1057  template<typename DType>
1058  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
1059  if (!isnan_typed::IsNan(dst)) {
1060  if (!(dst >= src)) dst = src;
1061  }
1062  }
1064  template<typename DType>
1065  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*)
1066  Reduce(dst, src);
1067  }
1069  template<typename DType>
1070  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
1071  Reduce(dst_val, src_val);
1072  }
1074  template<typename DType>
1075  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
1076  Reduce(dst_val, src_val);
1077  }
1079  template<typename DType>
1080  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
1082  template<typename DType>
1083  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
1088  template<typename DType>
1089  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
1090  return redres == redsrc ? 1: 0;
1091  }
1095  template<typename DType>
1096  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
1097  initv = limits::NegInfValue<DType>();
1098  }
1102  template<typename DType>
1103  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
1104  SetInitValue(initv);
1105  }
1106 };
1108 struct minimum {
1110  template<typename DType>
1111  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
1112  if (!isnan_typed::IsNan(dst)) {
1113  if (!(dst <= src)) dst = src;
1114  }
1115  }
1117  template<typename DType>
1118  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*)
1119  Reduce(dst, src);
1120  }
1122  template<typename DType>
1123  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
1124  Reduce(dst_val, src_val);
1125  }
1127  template<typename DType>
1128  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
1129  Reduce(dst_val, src_val);
1130  }
1132  template<typename DType>
1133  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
1135  template<typename DType>
1136  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
1141  template<typename DType>
1142  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
1143  return redres == redsrc ? 1: 0;
1144  }
1148  template<typename DType>
1149  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
1150  initv = limits::PosInfValue<DType>();
1151  }
1155  template<typename DType>
1156  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
1157  SetInitValue(initv);
1158  }
1159 };
1160 } // namespace red
1161 
#ifndef __NVCC__
/*!
 * \brief run __VA_ARGS__ with DType bound to the element type named by the
 *  runtime TypeFlag `type`. Handles float32/64, float16, bfloat16, uint8, int8,
 *  int32 and int64. bool, int16, uint16, uint32, uint64 and unknown codes abort
 *  through LOG(FATAL).
 */
#define MSHADOW_TYPE_SWITCH(type, DType, ...)                                  \
  switch (type) {                                                              \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break;      \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break;     \
    case mshadow::kFloat16: {                                                  \
      typedef mshadow::half::half_t DType; {__VA_ARGS__}                       \
    } break;                                                                   \
    case mshadow::kBfloat16: {                                                 \
      typedef mshadow::bfloat::bf16_t DType; {__VA_ARGS__}                     \
    } break;                                                                   \
    case mshadow::kUint8: { typedef uint8_t DType; {__VA_ARGS__} } break;      \
    case mshadow::kInt8: { typedef int8_t DType; {__VA_ARGS__} } break;        \
    case mshadow::kInt32: { typedef int32_t DType; {__VA_ARGS__} } break;      \
    case mshadow::kInt64: { typedef int64_t DType; {__VA_ARGS__} } break;      \
    case mshadow::kBool:                                                       \
      LOG(FATAL) << "This operation does not support bool type";               \
      break;                                                                   \
    case mshadow::kInt16:                                                      \
      LOG(FATAL) << "This operation does not support int16 type";              \
      break;                                                                   \
    case mshadow::kUint16:                                                     \
      LOG(FATAL) << "This operation does not support uint16 type";             \
      break;                                                                   \
    case mshadow::kUint32:                                                     \
      LOG(FATAL) << "This operation does not support uint32 type";             \
      break;                                                                   \
    case mshadow::kUint64:                                                     \
      LOG(FATAL) << "This operation does not support uint64 type";             \
      break;                                                                   \
    default:                                                                   \
      LOG(FATAL) << "Unknown type enum " << type;                              \
  }
#else
// nvcc variant: identical except that it has no bfloat16 case, so kBfloat16
// falls through to the "Unknown type enum" default.
#define MSHADOW_TYPE_SWITCH(type, DType, ...)                                  \
  switch (type) {                                                              \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break;      \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break;     \
    case mshadow::kFloat16: {                                                  \
      typedef mshadow::half::half_t DType; {__VA_ARGS__}                       \
    } break;                                                                   \
    case mshadow::kUint8: { typedef uint8_t DType; {__VA_ARGS__} } break;      \
    case mshadow::kInt8: { typedef int8_t DType; {__VA_ARGS__} } break;        \
    case mshadow::kInt32: { typedef int32_t DType; {__VA_ARGS__} } break;      \
    case mshadow::kInt64: { typedef int64_t DType; {__VA_ARGS__} } break;      \
    case mshadow::kBool:                                                       \
      LOG(FATAL) << "This operation does not support bool type";               \
      break;                                                                   \
    case mshadow::kInt16:                                                      \
      LOG(FATAL) << "This operation does not support int16 type";              \
      break;                                                                   \
    case mshadow::kUint16:                                                     \
      LOG(FATAL) << "This operation does not support uint16 type";             \
      break;                                                                   \
    case mshadow::kUint32:                                                     \
      LOG(FATAL) << "This operation does not support uint32 type";             \
      break;                                                                   \
    case mshadow::kUint64:                                                     \
      LOG(FATAL) << "This operation does not support uint64 type";             \
      break;                                                                   \
    default:                                                                   \
      LOG(FATAL) << "Unknown type enum " << type;                              \
  }
#endif  // __NVCC__
1305 
/*!
 * \brief like MSHADOW_TYPE_SWITCH but restricted to float32 and float64; any
 *  other code aborts through LOG(FATAL).
 */
#define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...)                          \
  switch (type) {                                                              \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break;      \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break;     \
    default:                                                                   \
      LOG(FATAL) << "This operation only supports 32-bit and 64-bit floating point"; \
  }
1324 
/*!
 * \brief floating-point-only type switch: float32, float64 and float16 bind
 *  DType. Integer and bool codes abort with a type-specific message. Any
 *  other code, including kBfloat16, reaches the "Unknown type enum" default.
 */
#define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...)                             \
  switch (type) {                                                              \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break;      \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break;     \
    case mshadow::kFloat16: {                                                  \
      typedef mshadow::half::half_t DType; {__VA_ARGS__}                       \
    } break;                                                                   \
    case mshadow::kUint8:                                                      \
      LOG(FATAL) << "This operation only support floating point types not uint8"; \
      break;                                                                   \
    case mshadow::kInt8:                                                       \
      LOG(FATAL) << "This operation only support floating point types not int8"; \
      break;                                                                   \
    case mshadow::kInt32:                                                      \
      LOG(FATAL) << "This operation only support floating point types, not int32"; \
      break;                                                                   \
    case mshadow::kInt64:                                                      \
      LOG(FATAL) << "This operation only support floating point types, not int64"; \
      break;                                                                   \
    case mshadow::kBool:                                                       \
      LOG(FATAL) << "This operation only support floating point types, not bool"; \
      break;                                                                   \
    case mshadow::kInt16:                                                      \
      LOG(FATAL) << "This operation only support floating point types, not int16"; \
      break;                                                                   \
    case mshadow::kUint16:                                                     \
      LOG(FATAL) << "This operation only support floating point types not uint16"; \
      break;                                                                   \
    case mshadow::kUint32:                                                     \
      LOG(FATAL) << "This operation only support floating point types not uint32"; \
      break;                                                                   \
    case mshadow::kUint64:                                                     \
      LOG(FATAL) << "This operation only support floating point types not uint64"; \
      break;                                                                   \
    default:                                                                   \
      LOG(FATAL) << "Unknown type enum " << type;                              \
  }
1384 
1385 #ifndef __NVCC__
/*!
 * \brief floating-point type switch that also binds DLargeType_ to a wider
 *  accumulation type: float for float32 / float16 / bfloat16, double for
 *  float64. Integer and bool codes abort with a type-specific message, as do
 *  unknown codes.
 */
#define MSHADOW_REAL_TYPE_SWITCH_EX(type_code, DType_, DLargeType_, ...)       \
  switch (type_code) {                                                         \
    case mshadow::kFloat32: {                                                  \
      typedef float DType_; typedef float DLargeType_; {__VA_ARGS__}           \
    } break;                                                                   \
    case mshadow::kFloat64: {                                                  \
      typedef double DType_; typedef double DLargeType_; {__VA_ARGS__}         \
    } break;                                                                   \
    case mshadow::kFloat16: {                                                  \
      typedef mshadow::half::half_t DType_; typedef float DLargeType_;         \
      {__VA_ARGS__}                                                            \
    } break;                                                                   \
    case mshadow::kBfloat16: {                                                 \
      typedef mshadow::bfloat::bf16_t DType_; typedef float DLargeType_;       \
      {__VA_ARGS__}                                                            \
    } break;                                                                   \
    case mshadow::kUint8:                                                      \
      LOG(FATAL) << "This operation only support floating point types not uint8"; \
      break;                                                                   \
    case mshadow::kInt8:                                                       \
      LOG(FATAL) << "This operation only support floating point types not int8"; \
      break;                                                                   \
    case mshadow::kInt32:                                                      \
      LOG(FATAL) << "This operation only support floating point types, not int32"; \
      break;                                                                   \
    case mshadow::kInt64:                                                      \
      LOG(FATAL) << "This operation only support floating point types, not int64"; \
      break;                                                                   \
    case mshadow::kBool:                                                       \
      LOG(FATAL) << "This operation only support floating point types, not bool"; \
      break;                                                                   \
    case mshadow::kInt16:                                                      \
      LOG(FATAL) << "This operation only support floating point types, not int16"; \
      break;                                                                   \
    case mshadow::kUint16:                                                     \
      LOG(FATAL) << "This operation only support floating point types not uint16"; \
      break;                                                                   \
    case mshadow::kUint32:                                                     \
      LOG(FATAL) << "This operation only support floating point types not uint32"; \
      break;                                                                   \
    case mshadow::kUint64:                                                     \
      LOG(FATAL) << "This operation only support floating point types not uint64"; \
      break;                                                                   \
    default:                                                                   \
      LOG(FATAL) << "Unknown type enum " << type_code;                         \
  }
1455 #else
1456 #define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
1457  switch (type$) { \
1458  case mshadow::kFloat32: \
1459  { \
1460  typedef float DType$; \
1461  typedef float DLargeType$; \
1462  {__VA_ARGS__} \
1463  } \
1464  break; \
1465  case mshadow::kFloat64: \
1466  { \
1467  typedef double DType$; \
1468  typedef double DLargeType$; \
1469  {__VA_ARGS__} \
1470  } \
1471  break; \
1472  case mshadow::kFloat16: \
1473  { \
1474  typedef mshadow::half::half_t DType$; \
1475  typedef float DLargeType$; \
1476  {__VA_ARGS__} \
1477  } \
1478  break; \
1479  case mshadow::kUint8: \
1480  LOG(FATAL) << "This operation only support " \
1481  "floating point types not uint8"; \
1482  break; \
1483  case mshadow::kInt8: \
1484  LOG(FATAL) << "This operation only support " \
1485  "floating point types not int8"; \
1486  break; \
1487  case mshadow::kInt32: \
1488  LOG(FATAL) << "This operation only support " \
1489  "floating point types, not int32";\
1490  break; \
1491  case mshadow::kInt64: \
1492  LOG(FATAL) << "This operation only support " \
1493  "floating point types, not int64";\
1494  break; \
1495  case mshadow::kBool: \
1496  LOG(FATAL) << "This operation only support " \
1497  "floating point types, not bool"; \
1498  break; \
1499  case mshadow::kInt16: \
1500  LOG(FATAL) << "This operation only support " \
1501  "floating point types, not int16";\
1502  break; \
1503  case mshadow::kUint16: \
1504  LOG(FATAL) << "This operation only support " \
1505  "floating point types not uint16";\
1506  break; \
1507  case mshadow::kUint32: \
1508  LOG(FATAL) << "This operation only support " \
1509  "floating point types not uint32";\
1510  break; \
1511  case mshadow::kUint64: \
1512  LOG(FATAL) << "This operation only support " \
1513  "floating point types not uint64";\
1514  break; \
1515  default: \
1516  LOG(FATAL) << "Unknown type enum " << type$; \
1517  }
1518 #endif
/*!
 * \brief Dispatch on a layout flag.
 *
 *  For kNCHW, kNHWC, kNCDHW and kNDHWC, declares `const int Layout` equal to
 *  the matched flag and runs the statements in __VA_ARGS__. Any other value
 *  reports LOG(FATAL) "Unknown layout enum".
 *
 *  The bound constants are namespace-qualified, matching the case labels, so
 *  the macro also expands correctly outside namespace mshadow.
 */
#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \
  switch (layout) { \
    case mshadow::kNCHW: \
      { \
        const int Layout = mshadow::kNCHW; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kNHWC: \
      { \
        const int Layout = mshadow::kNHWC; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kNCDHW: \
      { \
        const int Layout = mshadow::kNCDHW; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kNDHWC: \
      { \
        const int Layout = mshadow::kNDHWC; \
        {__VA_ARGS__} \
      } \
      break; \
    default: \
      LOG(FATAL) << "Unknown layout enum " << layout; \
  }
1548 
/*!
 * \brief Dispatch on an index TypeFlag.
 *
 *  Only mshadow::kInt64 is accepted: DType becomes int64_t and the statements
 *  in __VA_ARGS__ run. Every other value reports "Unknown type enum".
 */
#define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...) \
  switch (type) { \
    case mshadow::kInt64: { using DType = int64_t; {__VA_ARGS__} } break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
1564 
/*!
 * \brief Dispatch on a TypeFlag, including bool but excluding the 16-bit
 *  integer and unsigned 16/32/64-bit types.
 *
 *  Supported: float, double, half_t, bf16_t, uint8_t, int8_t, int32_t,
 *  int64_t and bool. DType is bound to the matching C++ type and the
 *  statements in __VA_ARGS__ run. kInt16 / kUint16 / kUint32 / kUint64 are
 *  rejected with LOG(FATAL); values outside the TypeFlag set report
 *  "Unknown type enum".
 */
#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \
  switch (type) { \
    /* floating-point types */ \
    case mshadow::kFloat32:  { using DType = float;                   {__VA_ARGS__} } break; \
    case mshadow::kFloat64:  { using DType = double;                  {__VA_ARGS__} } break; \
    case mshadow::kFloat16:  { using DType = mshadow::half::half_t;   {__VA_ARGS__} } break; \
    case mshadow::kBfloat16: { using DType = mshadow::bfloat::bf16_t; {__VA_ARGS__} } break; \
    /* integer and boolean types */ \
    case mshadow::kUint8:    { using DType = uint8_t;                 {__VA_ARGS__} } break; \
    case mshadow::kInt8:     { using DType = int8_t;                  {__VA_ARGS__} } break; \
    case mshadow::kInt32:    { using DType = int32_t;                 {__VA_ARGS__} } break; \
    case mshadow::kInt64:    { using DType = int64_t;                 {__VA_ARGS__} } break; \
    case mshadow::kBool:     { using DType = bool;                    {__VA_ARGS__} } break; \
    /* known but unsupported types */ \
    case mshadow::kInt16: \
      LOG(FATAL) << "This operation does not " \
                    "support int16 type"; \
      break; \
    case mshadow::kUint16: \
      LOG(FATAL) << "This operation does not " \
                    "support uint16 type"; \
      break; \
    case mshadow::kUint32: \
      LOG(FATAL) << "This operation does not " \
                    "support uint32 type"; \
      break; \
    case mshadow::kUint64: \
      LOG(FATAL) << "This operation does not " \
                    "support uint64 type"; \
      break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
1640 
/*!
 * \brief Dispatch on any non-bool TypeFlag.
 *
 *  Binds DType to float, double, half_t, bf16_t or one of the signed/unsigned
 *  8/16/32/64-bit integer types, then runs the statements in __VA_ARGS__.
 *  kBool is a known TypeFlag that this switch does not support. It is
 *  rejected with an explicit message rather than being reported as an
 *  unknown enum. Values outside the TypeFlag set report "Unknown type enum".
 */
#define MSHADOW_TYPE_SWITCH_EXT(type, DType, ...) \
  switch (type) { \
    case mshadow::kFloat32: \
      { \
        typedef float DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kFloat64: \
      { \
        typedef double DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kFloat16: \
      { \
        typedef mshadow::half::half_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kBfloat16: \
      { \
        typedef mshadow::bfloat::bf16_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kUint8: \
      { \
        typedef uint8_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kInt8: \
      { \
        typedef int8_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kInt32: \
      { \
        typedef int32_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kInt64: \
      { \
        typedef int64_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kInt16: \
      { \
        typedef int16_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kUint16: \
      { \
        typedef uint16_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kUint32: \
      { \
        typedef uint32_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kUint64: \
      { \
        typedef uint64_t DType; \
        {__VA_ARGS__} \
      } \
      break; \
    case mshadow::kBool: \
      LOG(FATAL) << "This operation does not " \
                    "support bool type"; \
      break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
1718 
/*!
 * \brief Dispatch on any TypeFlag, bool included.
 *
 *  Binds DType to the C++ type matching \a type, which is one of float,
 *  double, half_t, bf16_t, the signed/unsigned 8/16/32/64-bit integers, or
 *  bool. It then runs the statements in __VA_ARGS__. Values outside the
 *  TypeFlag set report LOG(FATAL) "Unknown type enum".
 */
#define MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(type, DType, ...) \
  switch (type) { \
    /* floating-point types */ \
    case mshadow::kFloat32:  { using DType = float;                   {__VA_ARGS__} } break; \
    case mshadow::kFloat64:  { using DType = double;                  {__VA_ARGS__} } break; \
    case mshadow::kFloat16:  { using DType = mshadow::half::half_t;   {__VA_ARGS__} } break; \
    case mshadow::kBfloat16: { using DType = mshadow::bfloat::bf16_t; {__VA_ARGS__} } break; \
    /* signed integer types */ \
    case mshadow::kInt8:     { using DType = int8_t;                  {__VA_ARGS__} } break; \
    case mshadow::kInt16:    { using DType = int16_t;                 {__VA_ARGS__} } break; \
    case mshadow::kInt32:    { using DType = int32_t;                 {__VA_ARGS__} } break; \
    case mshadow::kInt64:    { using DType = int64_t;                 {__VA_ARGS__} } break; \
    /* unsigned integer types */ \
    case mshadow::kUint8:    { using DType = uint8_t;                 {__VA_ARGS__} } break; \
    case mshadow::kUint16:   { using DType = uint16_t;                {__VA_ARGS__} } break; \
    case mshadow::kUint32:   { using DType = uint32_t;                {__VA_ARGS__} } break; \
    case mshadow::kUint64:   { using DType = uint64_t;                {__VA_ARGS__} } break; \
    /* boolean */ \
    case mshadow::kBool:     { using DType = bool;                    {__VA_ARGS__} } break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
1802 
1804 inline size_t mshadow_sizeof(int type) {
1805  int size = 0;
1806  MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(type, DType, size = sizeof(DType););
1807  return size;
1808 }
1809 
1810 /*/ \brief get string with the type name from type enum */
1811 inline std::string dtype_string(const int dtype) {
1812  switch (dtype) {
1813  case mshadow::kFloat32:
1814  return "float";
1815  case mshadow::kFloat64:
1816  return "double";
1817  case mshadow::kFloat16:
1818  return "half";
1819  case mshadow::kUint8:
1820  return "unsigned char";
1821  case mshadow::kInt8:
1822  return "char";
1823  case mshadow::kInt32:
1824  return "int";
1825  case mshadow::kInt64:
1826  return "long long";
1827  case mshadow::kBool:
1828  return "bool";
1829  case mshadow::kInt16:
1830  return "short";
1831  case mshadow::kUint16:
1832  return "unsigned short";
1833  case mshadow::kUint32:
1834  return "unsigned int";
1835  case mshadow::kUint64:
1836  return "unsigned long long";
1837  default:
1838  LOG(FATAL) << "Unknown type enum " << dtype;
1839  }
1840  return "unknown";
1841 }
1842 
1843 } // namespace mshadow
1844 #endif // MSHADOW_BASE_H_
mshadow::sv::saveto::OPType
op::right OPType
corresponding binary operator type
Definition: base.h:691
mshadow::sv::minusto::AlphaBLAS
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:715
mshadow::kNCHW
@ kNCHW
Definition: base.h:501
mshadow::red::maximum
maximum reducer
Definition: base.h:1055
mshadow::red::minimum::Finalize
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1136
mshadow::red::minimum::Merge
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1128
mshadow::red::sum::Finalize
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1029
mshadow::red::limits::NegInfValue
MSHADOW_XINLINE DType NegInfValue(void)
negative infinity of certain types
Definition: base.h:873
mshadow::red::limits::MaxValue< uint32_t >
MSHADOW_XINLINE uint32_t MaxValue< uint32_t >(void)
maximum value of uint32_t
Definition: base.h:951
mshadow::red::limits::MaxValue< bool >
MSHADOW_XINLINE bool MaxValue< bool >(void)
maximum value of bool
Definition: base.h:946
mshadow::openmp_index_t
index_t openmp_index_t
openmp index for linux
Definition: base.h:336
mshadow::red::limits::MinValue< uint8_t >
MSHADOW_XINLINE uint8_t MinValue< uint8_t >(void)
minimum value of uint8_t
Definition: base.h:839
mshadow::op::mul::Map
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:628
mshadow::default_type_flag
const int default_type_flag
type enum value for default real type
Definition: base.h:492
mshadow::red::maximum::SetInitValue
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1103
mshadow::red::limits::PosInfValue< float >
MSHADOW_XINLINE float PosInfValue< float >(void)
positive infinity value of float
Definition: base.h:965
mshadow::red::limits::MaxValue< uint8_t >
MSHADOW_XINLINE uint8_t MaxValue< uint8_t >(void)
maximum value of uint8_t
Definition: base.h:926
mshadow::kUint16
@ kUint16
Definition: base.h:361
mshadow::kUint64
@ kUint64
Definition: base.h:363
mshadow::kNWC
@ kNWC
Definition: base.h:506
mshadow::kNCDHW
@ kNCDHW
Definition: base.h:509
mshadow::red::maximum::SetInitValue
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1096
mshadow::sv::minusto::Save
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:711
mshadow::red::limits::MaxValue< int8_t >
MSHADOW_XINLINE int8_t MaxValue< int8_t >(void)
maximum value of int8_t
Definition: base.h:931
mshadow::sv::divto::OPType
op::div OPType
corresponding binary operator type
Definition: base.h:739
mshadow::red::limits::MinValue< float >
MSHADOW_XINLINE float MinValue< float >(void)
minimum value of float
Definition: base.h:819
mshadow::red::maximum::Merge
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1075
mshadow::LayoutType
Definition: base.h:573
mshadow::kInt8
@ kInt8
Definition: base.h:357
mshadow::kUint32
@ kUint32
Definition: base.h:362
mshadow::sv::multo::OPType
op::mul OPType
corresponding binary operator type
Definition: base.h:729
mshadow::op::minus
minus operator
Definition: base.h:641
mshadow::red::maximum::Finalize
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1083
MSHADOW_BF16_MAX
#define MSHADOW_BF16_MAX
Definition: bfloat.h:182
mshadow::toString
std::string toString(LayoutFlag layout)
Definition: base.h:545
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::red::sum::Merge
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1012
mshadow::sv::saveto::AlphaBLAS
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:687
mshadow::op::identity::Map
static MSHADOW_XINLINE DType Map(DType a)
map a to result using defined operation
Definition: base.h:672
mshadow::isnan_typed::IsNan
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:753
mshadow::sv::plusto::BetaBLAS
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:703
mshadow::op::div
divide operator
Definition: base.h:649
mshadow::sv::saveto::Save
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:683
MSHADOW_BF16_SIGN_BIT
#define MSHADOW_BF16_SIGN_BIT
Definition: bfloat.h:183
mshadow::red::sum
sum reducer
Definition: base.h:987
mshadow::sv::minusto::BetaBLAS
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:717
mshadow::red::maximum::Finalize
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1080
MSHADOW_BF16_EXPONENT_BITS
#define MSHADOW_BF16_EXPONENT_BITS
Definition: bfloat.h:184
mshadow::sv::divto::Save
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:735
mshadow::default_layout
const int default_layout
default layout for 4d tensor
Definition: base.h:596
mshadow::sv::multo::Save
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:725
mshadow::red::minimum::Reduce
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:1118
mxnet::common::cuda::rtc::limits
const char limits[]
Definition: util-inl.h:594
mshadow::red::limits::MaxValue< int32_t >
MSHADOW_XINLINE int MaxValue< int32_t >(void)
maximum value of int32_t
Definition: base.h:936
mshadow::kBool
@ kBool
Definition: base.h:359
mshadow::red::limits::MinValue< int8_t >
MSHADOW_XINLINE int8_t MinValue< int8_t >(void)
minimum value of int8_t
Definition: base.h:844
mshadow::sv::saveto::BetaBLAS
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:689
mshadow::kNHWC
@ kNHWC
Definition: base.h:502
mshadow::default_layout_5d
const int default_layout_5d
default layout for 5d tensor
Definition: base.h:619
mshadow::red::maximum::Reduce
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:1058
mshadow::kCWN
@ kCWN
Definition: base.h:507
mshadow::kCHWN
@ kCHWN
Definition: base.h:503
mshadow::kFloat64
@ kFloat64
Definition: base.h:353
mshadow::op::plus
plus operator
Definition: base.h:633
mshadow::sv::plusto::AlphaBLAS
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:701
mshadow::red::sum::Merge
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1007
mshadow::red::limits::MinValue< double >
MSHADOW_XINLINE double MinValue< double >(void)
minimum value of double
Definition: base.h:824
mshadow::sv::plusto
save to saver: +=
Definition: base.h:694
mshadow::LayoutFlag
LayoutFlag
Definition: base.h:498
mshadow::op::plus::Map
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:636
mshadow::red::sum::Reduce
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:990
mshadow::kInt16
@ kInt16
Definition: base.h:360
mshadow::red::minimum::PartialGrad
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1142
mshadow::red::limits::PosInfValue
MSHADOW_XINLINE DType PosInfValue(void)
positive infinity of certain types
Definition: base.h:960
MSHADOW_HALF_MAX
#define MSHADOW_HALF_MAX
Definition: half.h:369
mshadow::red::minimum::Merge
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1123
mshadow::sv::saveto
save to saver: =
Definition: base.h:680
mshadow::red::sum::Finalize
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1026
mshadow::kCDHWN
@ kCDHWN
Definition: base.h:511
mshadow::red::limits::NegInfValue< double >
MSHADOW_XINLINE double NegInfValue< double >(void)
negative infinity value of double
Definition: base.h:883
mshadow::op::right
get rhs
Definition: base.h:657
mshadow::red::limits::MinValue
MSHADOW_XINLINE DType MinValue(void)
minimum value of certain types
mshadow::kRandBufferSize
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:321
mshadow::isinf_typed::IsInf
MSHADOW_XINLINE bool IsInf(volatile DType val)
Definition: base.h:783
mshadow::op::minus::Map
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:644
mshadow::red::maximum::Merge
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1070
mshadow::sv::multo
multiply to saver: *=
Definition: base.h:722
mshadow::kInt64
@ kInt64
Definition: base.h:358
mshadow::red::sum::PartialGrad
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1035
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
MSHADOW_HALF_SIGN_BIT
#define MSHADOW_HALF_SIGN_BIT
Definition: half.h:370
mshadow::DataType
Definition: base.h:368
mshadow::red::minimum::Reduce
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:1111
mshadow::kPi
const float kPi
pi
Definition: base.h:323
mshadow::kInt32
@ kInt32
Definition: base.h:356
mshadow::red::limits::MinValue< bool >
MSHADOW_XINLINE bool MinValue< bool >(void)
minimum value of bool
Definition: base.h:859
mshadow::TypeFlag
TypeFlag
data type flag
Definition: base.h:351
mshadow::lapack_index_t
int lapack_index_t
Definition: base.h:344
mshadow::red::limits::MinValue< int64_t >
MSHADOW_XINLINE int64_t MinValue< int64_t >(void)
minimum value of int64_t
Definition: base.h:854
mshadow::red::limits::MaxValue< double >
MSHADOW_XINLINE double MaxValue< double >(void)
maximum value of double
Definition: base.h:911
mshadow::red::sum::SetInitValue
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1042
mshadow::mshadow_sizeof
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1804
mshadow::kNDHWC
@ kNDHWC
Definition: base.h:510
mshadow::sv::plusto::OPType
op::plus OPType
corresponding binary operator type
Definition: base.h:705
mshadow::red::maximum::Reduce
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:1065
MSHADOW_BF16_MIN
#define MSHADOW_BF16_MIN
overloaded + operator for bf16_t
Definition: bfloat.h:181
mshadow::dtype_string
std::string dtype_string(const int dtype)
get string with the type name from type enum
Definition: base.h:1811
mshadow::red::sum::SetInitValue
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &residual)
set the initial value during reduction
Definition: base.h:1049
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::red::limits::MaxValue< float >
MSHADOW_XINLINE float MaxValue< float >(void)
maximum value of float
Definition: base.h:906
mshadow::default_real_t
float default_real_t
float point type that will be used in default by mshadow
Definition: base.h:348
mshadow::op::identity
identity function that maps a real number to itself
Definition: base.h:669
mshadow::sv::minusto::OPType
op::minus OPType
corresponding binary operator type
Definition: base.h:719
half.h
definition of half (float16) type.
mshadow::sv::minusto
minus to saver: -=
Definition: base.h:708
mshadow::kUNKNOWN
@ kUNKNOWN
Definition: base.h:499
mshadow::red::minimum::SetInitValue
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1149
mshadow::red::minimum::Finalize
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1133
MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL
#define MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(type, DType,...)
Definition: base.h:1719
mshadow::red::limits::MaxValue< int64_t >
MSHADOW_XINLINE int64_t MaxValue< int64_t >(void)
maximum value of int64_t
Definition: base.h:941
MSHADOW_HALF_MIN
#define MSHADOW_HALF_MIN
overloaded + operator for half_t
Definition: half.h:368
mshadow::red::minimum::SetInitValue
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1156
mshadow::sv::plusto::Save
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:697
MSHADOW_HALF_BF_OPERATOR
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)
Definition: base.h:292
mshadow::kUint8
@ kUint8
Definition: base.h:355
mshadow::kBfloat16
@ kBfloat16
Definition: base.h:364
mshadow::op::div::Map
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:652
mshadow::kNCW
@ kNCW
Definition: base.h:505
mshadow::red::sum::Reduce
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &residual)
do stable reduction into dst
Definition: base.h:995
mshadow::red::limits::PosInfValue< double >
MSHADOW_XINLINE double PosInfValue< double >(void)
positive infinity value of double
Definition: base.h:970
mshadow::op::mul
mul operator
Definition: base.h:625
mshadow::layoutFlag
LayoutFlag layoutFlag(std::string layoutstr)
Definition: base.h:514
mshadow::index_type_flag
const int index_type_flag
TypeFlag value for type of indexes.
Definition: base.h:495
mshadow::red::limits::MinValue< unsigned int >
MSHADOW_XINLINE unsigned int MinValue< unsigned int >(void)
minimum value of unsigned int
Definition: base.h:864
mshadow::red::limits::MaxValue
MSHADOW_XINLINE DType MaxValue(void)
maximum value of certain types
mshadow::sv::divto
divide to saver: /=
Definition: base.h:732
mshadow::kFloat16
@ kFloat16
Definition: base.h:354
mshadow::red::limits::NegInfValue< float >
MSHADOW_XINLINE float NegInfValue< float >(void)
negative infinity value of float
Definition: base.h:878
mshadow::red::minimum
minimum reducer
Definition: base.h:1108
mshadow::red::limits::MinValue< int32_t >
MSHADOW_XINLINE int MinValue< int32_t >(void)
minimum value of int32_t
Definition: base.h:849
mshadow::red::maximum::PartialGrad
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1089
mshadow::kFloat32
@ kFloat32
Definition: base.h:352
mshadow::op::right::Map
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:660
bfloat.h
definition of bfloat type.
MSHADOW_HALF_EXPONENT_BITS
#define MSHADOW_HALF_EXPONENT_BITS
Definition: half.h:371