mxnet
base.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
27 #ifndef MSHADOW_BASE_H_
28 #define MSHADOW_BASE_H_
29 #ifdef _MSC_VER
30 #ifndef _CRT_SECURE_NO_WARNINGS
31 #define _CRT_SECURE_NO_WARNINGS
32 #endif
33 #ifndef _CRT_SECURE_NO_DEPRECATE
34 #define _CRT_SECURE_NO_DEPRECATE
35 #endif
36 #ifndef NOMINMAX
37 #define NOMINMAX
38 #endif
39 #endif
40 #include <cmath>
41 #include <cstdio>
42 #include <cfloat>
43 #include <climits>
44 #include <algorithm>
45 #include <functional>
46 #include <sstream>
47 #include <string>
48 
49 #ifdef _MSC_VER
50 typedef signed char int8_t;
52 typedef __int16 int16_t;
53 typedef __int32 int32_t;
54 typedef __int64 int64_t;
55 typedef unsigned char uint8_t;
56 typedef unsigned __int16 uint16_t;
57 typedef unsigned __int32 uint32_t;
58 typedef unsigned __int64 uint64_t;
60 #else
61 #include <inttypes.h>
62 #endif
63 // macro definitions
68 #ifndef MSHADOW_STAND_ALONE
69 #define MSHADOW_STAND_ALONE 0
70 #endif
71 
72 #ifndef MSHADOW_ALLOC_PAD
73 #define MSHADOW_ALLOC_PAD true
74 #endif
75 
83 #ifndef MSHADOW_MIN_PAD_RATIO
84  #define MSHADOW_MIN_PAD_RATIO 2
85 #endif
86 
87 #if MSHADOW_STAND_ALONE
88  #define MSHADOW_USE_CBLAS 0
89  #define MSHADOW_USE_MKL 0
90  #define MSHADOW_USE_CUDA 0
91 #endif
92 
97 #ifndef MSHADOW_FORCE_STREAM
98 #define MSHADOW_FORCE_STREAM 1
99 #endif
100 
102 #ifndef MSHADOW_USE_CBLAS
103  #define MSHADOW_USE_CBLAS 0
104 #endif
105 
106 #ifndef MSHADOW_USE_MKL
107  #define MSHADOW_USE_MKL 1
108 #endif
109 
114 #ifndef MSHADOW_USE_CUDA
115  #define MSHADOW_USE_CUDA 1
116 #endif
117 
121 #ifndef MSHADOW_USE_CUDNN
122  #define MSHADOW_USE_CUDNN 0
123 #endif
124 
128 #ifndef MSHADOW_USE_CUSOLVER
129  #define MSHADOW_USE_CUSOLVER MSHADOW_USE_CUDA
130 #endif
131 
136 #ifndef MSHADOW_OLD_CUDA
137 #define MSHADOW_OLD_CUDA 0
138 #endif
139 
143 #ifndef MSHADOW_IN_CXX11
144  #if (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\
145  __cplusplus >= 201103L || defined(_MSC_VER))
146  #define MSHADOW_IN_CXX11 1
147  #else
148  #define MSHADOW_IN_CXX11 0
149  #endif
150 #endif
151 
153 #ifndef MSHADOW_USE_SSE
154  #define MSHADOW_USE_SSE 1
155 #endif
156 
158 #ifndef MSHADOW_USE_F16C
159  #if defined(_MSC_VER) || defined(__CUDACC__)
160  #define MSHADOW_USE_F16C 0
161  #elif defined(__clang__) && \
162  ((__clang_major__ < 8) || ((__clang_major__ == 8) && (__clang_minor__ < 1)))
163  #define MSHADOW_USE_F16C 0
164  #else
165  #define MSHADOW_USE_F16C 1
166  #endif
167 #endif
168 
170 #ifndef MSHADOW_USE_NVML
171  #define MSHADOW_USE_NVML 0
172 #endif
173 // SSE conflicts with cudacc, so it is switched off whenever __CUDACC__ is defined
174 #ifdef __CUDACC__
175  #undef MSHADOW_USE_SSE
176  #define MSHADOW_USE_SSE 0
177 #endif
178 
179 #if MSHADOW_USE_CBLAS
180 extern "C" {
181  #include <cblas.h>
182 }
183 #elif MSHADOW_USE_MKL
184  #include <mkl_blas.h>
185  #include <mkl_cblas.h>
186  #include <mkl_vsl.h>
187  #include <mkl_vsl_functions.h>
188  #include <mkl_version.h>
189 #endif
190 
191 #if MSHADOW_USE_CUDA
192  #include <cuda.h>
193  #include <cublas_v2.h>
194  #include <curand.h>
195 #endif
196 
197 #if MSHADOW_USE_CUDNN == 1
198  #include <cudnn.h>
199 #endif
200 
201 #if MSHADOW_USE_CUSOLVER == 1
202  #include <cusolverDn.h>
203 #endif
204 
205 #if MSHADOW_USE_NVML
206  #include <nvml.h>
207 #endif
208 
209 // --------------------------------
210 // MSHADOW_XINLINE is used for inlining template code for both CUDA and CPU code
211 #ifdef MSHADOW_XINLINE
212  #error "MSHADOW_XINLINE must not be defined"
213 #endif
214 #ifdef _MSC_VER
215 #define MSHADOW_FORCE_INLINE __forceinline
216 #pragma warning(disable : 4068)
217 #else
218 #define MSHADOW_FORCE_INLINE inline __attribute__((always_inline))
219 #endif
220 #ifdef __CUDACC__
221  #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__
222 #else
223  #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE
224 #endif
225 
226 #define MSHADOW_CINLINE MSHADOW_FORCE_INLINE
227 
228 #if defined(__GXX_EXPERIMENTAL_CXX0X) ||\
229  defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
230  #define MSHADOW_CONSTEXPR constexpr
231 #else
232  #define MSHADOW_CONSTEXPR const
233 #endif
234 
241 #ifndef MSHADOW_DEFAULT_DTYPE
242 #define MSHADOW_DEFAULT_DTYPE = ::mshadow::default_real_t
243 #endif
244 
248 #ifndef MSHADOW_USE_GLOG
249 #define MSHADOW_USE_GLOG DMLC_USE_GLOG
250 #endif // MSHADOW_USE_GLOG
251 
252 #if DMLC_USE_CXX11
253 #define MSHADOW_THROW_EXCEPTION noexcept(false)
254 #define MSHADOW_NO_EXCEPTION noexcept(true)
255 #else
256 #define MSHADOW_THROW_EXCEPTION
257 #define MSHADOW_NO_EXCEPTION
258 #endif
259 
260 #if defined(_MSC_VER)
261 #define MSHADOW_ALIGNED(x) __declspec(align(x))
262 #else
263 #define MSHADOW_ALIGNED(x) __attribute__ ((aligned(x)))
264 #endif
265 
/*!
 * \brief Run a CUDA runtime call and check its status.
 *  The expression is evaluated exactly once. A result equal to
 *  cudaErrorCudartUnloading is raised as dmlc::Error; any other result that
 *  is not cudaSuccess fails the CHECK, with the CUDA error string appended.
 *  The expansion is a bare brace block (not do { } while (0)), so a trailing
 *  semicolon at the call site forms an extra empty statement.
 * \param func Expression yielding a cudaError_t
 */
#define MSHADOW_CUDA_CALL(func)                     \
  {                                                 \
    cudaError_t e = (func);                         \
    if (e == cudaErrorCudartUnloading) {            \
      throw dmlc::Error(cudaGetErrorString(e));     \
    }                                               \
    CHECK(e == cudaSuccess)                         \
        << "CUDA: " << cudaGetErrorString(e);       \
  }
280 
/*!
 * \brief Evaluate an expression and swallow any dmlc::Error it throws.
 *  The error text is logged at ERROR level unless it contains
 *  "driver shutting down", in which case it is dropped without logging.
 *  The exception is never rethrown; other exception types are not caught.
 * \param func Expression to evaluate
 */
#define MSHADOW_CATCH_ERROR(func)                                     \
  {                                                                   \
    try {                                                             \
      (func);                                                         \
    } catch (const dmlc::Error &e) {                                  \
      std::string what = e.what();                                    \
      if (what.find("driver shutting down") == std::string::npos) {   \
        LOG(ERROR) << "Ignore CUDA Error " << what;                   \
      }                                                               \
    }                                                                 \
  }
296 
297 #include "./half.h"
298 #include "./half2.h"
299 #include "./bfloat.h"
/*!
 * \brief Define a binary operator between half_t and bf16_t in both operand orders.
 *  Each generated overload converts both operands to float and applies OP
 *  to the two floats.
 * \param RTYPE Return type of the generated overloads
 * \param OP Operator token to overload
 */
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)                          \
  MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a,          \
                                    mshadow::bfloat::bf16_t b) {      \
    return float(a) OP float(b);  /* NOLINT(*) */                     \
  }                                                                   \
  MSHADOW_XINLINE RTYPE operator OP(mshadow::bfloat::bf16_t a,        \
                                    mshadow::half::half_t b) {        \
    return float(a) OP float(b);  /* NOLINT(*) */                     \
  }
307 
309 MSHADOW_HALF_BF_OPERATOR(float, +)
311 MSHADOW_HALF_BF_OPERATOR(float, -)
313 MSHADOW_HALF_BF_OPERATOR(float, *)
315 MSHADOW_HALF_BF_OPERATOR(float, /)
321 MSHADOW_HALF_BF_OPERATOR(bool, >=)
323 MSHADOW_HALF_BF_OPERATOR(bool, <=)
324 
325 #include "./logging.h"
326 
327 namespace mshadow {
/*!
 * \brief size constant for random-number buffers
 *  (presumably an element count; the code that uses it is outside this view)
 */
const unsigned kRandBufferSize = 1000000;
/*!
 * \brief pi as a single-precision constant.
 *  The literal carries more digits than float can hold so that it rounds to
 *  the float nearest to pi. The shorter spelling 3.1415926f lies below the
 *  midpoint of the two neighbouring floats and therefore rounded to the value
 *  one ULP too small.
 */
const float kPi = 3.14159265358979323846f;
/*!
 * \brief integer type used for indexing:
 *  64-bit when MSHADOW_INT64_TENSOR_SIZE is 1, otherwise 32-bit
 */
#if MSHADOW_INT64_TENSOR_SIZE == 1
typedef int64_t index_t;
#else
typedef int32_t index_t;
#endif

#ifdef _WIN32
/*! \brief index type for OpenMP loops (per its name); fixed to int64_t on Windows */
typedef int64_t openmp_index_t;
#else
/*! \brief index type for OpenMP loops (per its name); same as index_t off Windows */
typedef index_t openmp_index_t;
#endif

/*! \brief floating-point type used when no other type is requested */
typedef float default_real_t;
349 
/*!
 * \brief runtime flag identifying an element type.
 *  The numeric values are explicit; keep them stable when adding entries.
 */
enum TypeFlag {
  kFloat32 = 0,
  kFloat64 = 1,
  kFloat16 = 2,
  kUint8 = 3,
  kInt32 = 4,
  kInt8 = 5,
  kInt64 = 6,
  kBool = 7,
  kInt16 = 8,
  kUint16 = 9,
  kUint32 = 10,
  kUint64 = 11,
  // Required by DataType<bfloat::bf16_t> and by the type-switch macros below,
  // which name mshadow::kBfloat16; takes the next free value.
  kBfloat16 = 12
};
366 
367 template<typename DType>
368 struct DataType;
369 
370 template<>
371 struct DataType<float> {
372  static const int kFlag = kFloat32;
373  static const int kLanes = 1;
374 #if MSHADOW_USE_CUDA
375 #if (CUDA_VERSION >= 8000)
376  static const cudaDataType_t kCudaFlag = CUDA_R_32F;
377 #endif
378 #if MSHADOW_USE_CUDNN
379  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT;
380  typedef float ScaleType;
381 #endif
382 #endif
383 };
384 template<>
385 struct DataType<double> {
386  static const int kFlag = kFloat64;
387  static const int kLanes = 1;
388 #if MSHADOW_USE_CUDA
389 #if (CUDA_VERSION >= 8000)
390  static const cudaDataType_t kCudaFlag = CUDA_R_64F;
391 #endif
392 #if MSHADOW_USE_CUDNN
393  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE;
394  typedef double ScaleType;
395 #endif
396 #endif
397 };
398 template<>
399 struct DataType<half::half_t> {
400  static const int kFlag = kFloat16;
401  static const int kLanes = 1;
402 #if MSHADOW_USE_CUDA
403 #if (CUDA_VERSION >= 8000)
404  static const cudaDataType_t kCudaFlag = CUDA_R_16F;
405 #endif
406 #if MSHADOW_USE_CUDNN
407  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF;
408  typedef float ScaleType;
409 #endif
410 #endif
411 };
412 template<>
413 struct DataType<half::half2_t> {
414  static const int kFlag = kFloat16;
415  static const int kLanes = 2;
416 };
417 template<>
418 struct DataType<bfloat::bf16_t> {
419  static const int kFlag = kBfloat16;
420  static const int kLanes = 1;
421 };
422 template<>
423 struct DataType<uint8_t> {
424  static const int kFlag = kUint8;
425  static const int kLanes = 1;
426 #if MSHADOW_USE_CUDA
427 #if (CUDA_VERSION >= 8000)
428  static const cudaDataType_t kCudaFlag = CUDA_R_8U;
429 #endif
430 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
431  // no uint8 in cudnn for now
432  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
433  typedef uint8_t ScaleType;
434 #endif
435 #endif
436 };
437 template<>
438 struct DataType<int8_t> {
439  static const int kFlag = kInt8;
440  static const int kLanes = 1;
441 #if MSHADOW_USE_CUDA
442 #if (CUDA_VERSION >= 8000)
443  static const cudaDataType_t kCudaFlag = CUDA_R_8I;
444 #endif
445 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
446  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
447  typedef int8_t ScaleType;
448 #endif
449 #endif
450 };
451 template<>
452 struct DataType<int32_t> {
453  static const int kFlag = kInt32;
454  static const int kLanes = 1;
455 #if MSHADOW_USE_CUDA
456 #if (CUDA_VERSION >= 8000)
457  static const cudaDataType_t kCudaFlag = CUDA_R_32I;
458 #endif
459 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
460  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT32;
461  typedef int32_t ScaleType;
462 #endif
463 #endif
464 };
465 template<>
466 struct DataType<int64_t> {
467  static const int kFlag = kInt64;
468  static const int kLanes = 1;
469 };
470 template<>
471 struct DataType<bool> {
472  static const int kFlag = kBool;
473  static const int kLanes = 1;
474 };
475 
478 
481  kNCHW = 0,
484 
485  kNCW = 1 << 3,
488 
489  kNCDHW = 1 << 5,
492 };
493 
494 template<int layout>
495 struct LayoutType;
496 
497 template<>
498 struct LayoutType<kNCHW> {
499  static const index_t kNdim = 4;
500 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
501  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
502 #else
503  static const int kCudnnFlag = -1;
504 #endif
505 };
506 
507 template<>
508 struct LayoutType<kNHWC> {
509  static const index_t kNdim = 4;
510 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
511  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
512 #else
513  static const int kCudnnFlag = -1;
514 #endif
515 };
516 
518 const int default_layout = kNCHW;
519 
520 template<>
522  static const index_t kNdim = 5;
523 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
524  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
525 #else
526  static const int kCudnnFlag = -1;
527 #endif
528 };
529 
530 template<>
532  static const index_t kNdim = 5;
533 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
534  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
535 #else
536  static const int kCudnnFlag = -1;
537 #endif
538 };
539 
542 
544 namespace op {
545 // binary operator
547 struct mul{
549  template<typename DType>
550  MSHADOW_XINLINE static DType Map(DType a, DType b) {
551  return a * b;
552  }
553 };
555 struct plus {
557  template<typename DType>
558  MSHADOW_XINLINE static DType Map(DType a, DType b) {
559  return a + b;
560  }
561 };
563 struct minus {
565  template<typename DType>
566  MSHADOW_XINLINE static DType Map(DType a, DType b) {
567  return a - b;
568  }
569 };
571 struct div {
573  template<typename DType>
574  MSHADOW_XINLINE static DType Map(DType a, DType b) {
575  return a / b;
576  }
577 };
579 struct right {
581  template<typename DType>
582  MSHADOW_XINLINE static DType Map(DType a, DType b) {
583  return b;
584  }
585 };
586 // unary operator/ function: example
587 // these operators can be defined by user,
588 // in the same style as binary and unary operator
589 // to use, simply write F<op::identity>( src )
591 struct identity{
593  template<typename DType>
594  MSHADOW_XINLINE static DType Map(DType a) {
595  return a;
596  }
597 };
598 } // namespace op
600 namespace sv {
602 struct saveto {
604  template<typename DType>
605  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
606  a = b;
607  }
609  inline static default_real_t AlphaBLAS(void) { return 1.0f; }
611  inline static default_real_t BetaBLAS(void) { return 0.0f; }
613  typedef op::right OPType;
614 };
616 struct plusto {
618  template<typename DType>
619  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
620  a += b;
621  }
623  inline static default_real_t AlphaBLAS(void) { return 1.0f; }
625  inline static default_real_t BetaBLAS(void) { return 1.0f; }
627  typedef op::plus OPType;
628 };
630 struct minusto {
632  template<typename DType>
633  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
634  a -= b;
635  }
637  inline static default_real_t AlphaBLAS(void) { return -1.0f; }
639  inline static default_real_t BetaBLAS(void) { return 1.0f; }
641  typedef op::minus OPType;
642 };
644 struct multo {
646  template<typename DType>
647  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
648  a *= b;
649  }
651  typedef op::mul OPType;
652 };
654 struct divto {
656  template<typename DType>
657  MSHADOW_XINLINE static void Save(DType& a, DType b) { // NOLINT(*)
658  a /= b;
659  }
661  typedef op::div OPType;
662 };
663 } // namespace sv
664 
665 #ifndef __CUDA_ARCH__
666 using std::isnan;
667 using std::isinf;
668 #endif
669 
673 namespace isnan_typed {
674  template<typename DType>
675  MSHADOW_XINLINE bool IsNan(volatile DType val) {
676  return false;
677  }
678  template<>
679  MSHADOW_XINLINE bool IsNan(volatile float val) {
680  return isnan(val);
681  }
682  template<>
683  MSHADOW_XINLINE bool IsNan(volatile double val) {
684  return isnan(val);
685  }
686  template<>
687  MSHADOW_XINLINE bool IsNan(volatile long double val) {
688  return isnan(val);
689  }
690  template<>
691  MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
692  return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) > MSHADOW_HALF_EXPONENT_BITS;
693  }
694 } // namespace isnan_typed
695 
699 namespace isinf_typed {
700  template<typename DType>
701  MSHADOW_XINLINE bool IsInf(volatile DType val) {
702  return false;
703  }
704  template<>
705  MSHADOW_XINLINE bool IsInf(volatile float val) {
706  return isinf(val);
707  }
708  template<>
709  MSHADOW_XINLINE bool IsInf(volatile double val) {
710  return isinf(val);
711  }
712  template<>
713  MSHADOW_XINLINE bool IsInf(volatile long double val) {
714  return isinf(val);
715  }
716  template<>
717  MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) {
718  return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) == MSHADOW_HALF_EXPONENT_BITS;
719  }
720 } // namespace isinf_typed
721 
723 namespace red {
724 namespace limits {
729 template<typename DType>
730 MSHADOW_XINLINE DType MinValue(void);
732 template<>
734  return -FLT_MAX;
735 }
737 template<>
739  return -DBL_MAX;
740 }
742 template<>
743 MSHADOW_XINLINE half::half_t MinValue<half::half_t>(void) {
744  return MSHADOW_HALF_MIN;
745 }
747 template<>
748 MSHADOW_XINLINE bfloat::bf16_t MinValue<bfloat::bf16_t>(void) {
749  return MSHADOW_BF16_MIN;
750 }
752 template<>
754  return 0;
755 }
757 template<>
759  return SCHAR_MIN;
760 }
762 template<>
764  return INT_MIN;
765 }
767 template<>
769  return LLONG_MIN;
770 }
772 template<>
774  return false;
775 }
777 template<>
779  return 0;
780 }
781 
786 template<typename DType>
788  return MinValue<DType>();
789 }
791 template<>
793  return -HUGE_VALF;
794 }
796 template<>
798  return -HUGE_VAL;
799 }
801 template<>
802 MSHADOW_XINLINE half::half_t NegInfValue<half::half_t>(void) {
803  return half::half_t::Binary(
805 }
806 
811 template<typename DType>
812 MSHADOW_XINLINE DType MaxValue(void);
814 template<>
816  return FLT_MAX;
817 }
819 template<>
821  return DBL_MAX;
822 }
824 template<>
825 MSHADOW_XINLINE half::half_t MaxValue<half::half_t>(void) {
826  return MSHADOW_HALF_MAX;
827 }
829 template<>
830 MSHADOW_XINLINE bfloat::bf16_t MaxValue<bfloat::bf16_t>(void) {
831  return MSHADOW_BF16_MAX;
832 }
834 template<>
836  return UCHAR_MAX;
837 }
839 template<>
841  return SCHAR_MAX;
842 }
844 template<>
846  return INT_MAX;
847 }
849 template<>
851  return LLONG_MAX;
852 }
854 template<>
856  return true;
857 }
859 template<>
861  return -1;
862 }
863 
868 template<typename DType>
870  return MaxValue<DType>();
871 }
873 template<>
875  return HUGE_VALF;
876 }
878 template<>
880  return HUGE_VAL;
881 }
883 template<>
884 MSHADOW_XINLINE half::half_t PosInfValue<half::half_t>(void) {
885  return half::half_t::Binary(MSHADOW_HALF_EXPONENT_BITS);
886 }
887 
888 } // namespace limits
889 
891 struct sum {
893  template<typename DType>
894  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
895  dst += src;
896  }
898  template<typename DType>
899  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
900  DType y = src - residual;
901  DType t = dst + y;
902  if (isinf_typed::IsInf(t)) {
903  residual = 0;
904  } else {
905  residual = (t - dst) - y;
906  }
907  dst = t;
908  }
910  template<typename DType>
911  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
912  Reduce(dst_val, src_val);
913  }
915  template<typename DType>
916  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
917  DType t1 = dst_val + src_val;
918  if (isinf_typed::IsInf(t1)) {
919  dst_val = t1;
920  dst_residual = 0;
921  } else {
922  DType e = t1 - dst_val;
923  DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
924  dst_val = t1 + t2;
925  dst_residual = t2 - (dst_val - t1);
926  }
927  }
929  template<typename DType>
930  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
932  template<typename DType>
933  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
938  template<typename DType>
939  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
940  return 1;
941  }
945  template<typename DType>
946  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
947  initv = 0;
948  }
952  template<typename DType>
953  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
954  SetInitValue(initv);
955  residual = 0;
956  }
957 };
959 struct maximum {
961  template<typename DType>
962  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
963  if (!isnan_typed::IsNan(dst)) {
964  if (!(dst >= src)) dst = src;
965  }
966  }
968  template<typename DType>
969  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*)
970  Reduce(dst, src);
971  }
973  template<typename DType>
974  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
975  Reduce(dst_val, src_val);
976  }
978  template<typename DType>
979  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
980  Reduce(dst_val, src_val);
981  }
983  template<typename DType>
984  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
986  template<typename DType>
987  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
992  template<typename DType>
993  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
994  return redres == redsrc ? 1: 0;
995  }
999  template<typename DType>
1000  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
1001  initv = limits::NegInfValue<DType>();
1002  }
1006  template<typename DType>
1007  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
1008  SetInitValue(initv);
1009  }
1010 };
1012 struct minimum {
1014  template<typename DType>
1015  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
1016  if (!isnan_typed::IsNan(dst)) {
1017  if (!(dst <= src)) dst = src;
1018  }
1019  }
1021  template<typename DType>
1022  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*)
1023  Reduce(dst, src);
1024  }
1026  template<typename DType>
1027  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
1028  Reduce(dst_val, src_val);
1029  }
1031  template<typename DType>
1032  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
1033  Reduce(dst_val, src_val);
1034  }
1036  template<typename DType>
1037  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
1039  template<typename DType>
1040  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
1045  template<typename DType>
1046  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
1047  return redres == redsrc ? 1: 0;
1048  }
1052  template<typename DType>
1053  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
1054  initv = limits::PosInfValue<DType>();
1055  }
1059  template<typename DType>
1060  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
1061  SetInitValue(initv);
1062  }
1063 };
1064 } // namespace red
1065 
/*!
 * \brief Dispatch on a runtime type flag and run the given statements with
 *  DType typedef'ed to the matching C++ type.
 *  Handles kFloat32, kFloat64, kFloat16, kUint8, kInt8, kInt32 and kInt64;
 *  the variant used when __NVCC__ is not defined also handles kBfloat16.
 *  Any other flag reaches LOG(FATAL).
 * \param type Type flag to switch on
 * \param DType Name to bind the selected type to
 * \param ... Statements to run with DType defined
 */
#ifndef __NVCC__
#define MSHADOW_TYPE_SWITCH(type, DType, ...) \
  switch (type) { \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat16: { typedef mshadow::half::half_t DType; {__VA_ARGS__} } break; \
    case mshadow::kBfloat16: { typedef mshadow::bfloat::bf16_t DType; {__VA_ARGS__} } break; \
    case mshadow::kUint8: { typedef uint8_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt8: { typedef int8_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt32: { typedef int32_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt64: { typedef int64_t DType; {__VA_ARGS__} } break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
#else
#define MSHADOW_TYPE_SWITCH(type, DType, ...) \
  switch (type) { \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat16: { typedef mshadow::half::half_t DType; {__VA_ARGS__} } break; \
    case mshadow::kUint8: { typedef uint8_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt8: { typedef int8_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt32: { typedef int32_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt64: { typedef int64_t DType; {__VA_ARGS__} } break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
#endif
1169 
/*!
 * \brief Type dispatch that binds kFloat16 to half::half2_t instead of half_t.
 *  Handles kFloat32, kFloat64, kFloat16, kUint8, kInt32 and kInt64 (there is
 *  no kInt8 case); any other flag reaches LOG(FATAL).
 * \param type Type flag to switch on
 * \param DType Name to bind the selected type to
 * \param ... Statements to run with DType defined
 */
#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \
  switch (type) { \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat16: { typedef mshadow::half::half2_t DType; {__VA_ARGS__} } break; \
    case mshadow::kUint8: { typedef uint8_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt32: { typedef int32_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt64: { typedef int64_t DType; {__VA_ARGS__} } break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
1211 
/*!
 * \brief Type dispatch restricted to kFloat32 and kFloat64;
 *  any other flag reaches LOG(FATAL).
 * \param type Type flag to switch on
 * \param DType Name to bind the selected type to
 * \param ... Statements to run with DType defined
 */
#define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \
  switch (type) { \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break; \
    default: \
      LOG(FATAL) << "This operation only supports " \
                    "32-bit and 64-bit floating point"; \
  }
1230 
/*!
 * \brief Type dispatch for floating-point flags only.
 *  kFloat32, kFloat64 and kFloat16 run the given statements; kUint8, kInt8,
 *  kInt32 and kInt64 each reach LOG(FATAL) with a type-specific message, and
 *  every other flag reaches LOG(FATAL) as an unknown enum.
 * \param type Type flag to switch on
 * \param DType Name to bind the selected type to
 * \param ... Statements to run with DType defined
 */
#define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...) \
  switch (type) { \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat16: { typedef mshadow::half::half_t DType; {__VA_ARGS__} } break; \
    case mshadow::kUint8: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types not uint8"; \
      break; \
    case mshadow::kInt8: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types not int8"; \
      break; \
    case mshadow::kInt32: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types, not int32"; \
      break; \
    case mshadow::kInt64: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types, not int64"; \
      break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
1270 
/*!
 * \brief Floating-point type dispatch that binds two types per case.
 *  DType$ is the element type and DLargeType$ a companion type: float for
 *  kFloat32, kFloat16 and kBfloat16, double for kFloat64. The kBfloat16 case
 *  exists only in the variant used when __NVCC__ is not defined. Integer
 *  flags and unknown flags reach LOG(FATAL).
 * \param type$ Type flag to switch on
 * \param DType$ Name to bind the element type to
 * \param DLargeType$ Name to bind the companion type to
 * \param ... Statements to run with both names defined
 */
#ifndef __NVCC__
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
  switch (type$) { \
    case mshadow::kFloat32: { \
      typedef float DType$; typedef float DLargeType$; {__VA_ARGS__} \
    } break; \
    case mshadow::kFloat64: { \
      typedef double DType$; typedef double DLargeType$; {__VA_ARGS__} \
    } break; \
    case mshadow::kFloat16: { \
      typedef mshadow::half::half_t DType$; typedef float DLargeType$; {__VA_ARGS__} \
    } break; \
    case mshadow::kBfloat16: { \
      typedef mshadow::bfloat::bf16_t DType$; typedef float DLargeType$; {__VA_ARGS__} \
    } break; \
    case mshadow::kUint8: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types not uint8"; \
      break; \
    case mshadow::kInt8: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types not int8"; \
      break; \
    case mshadow::kInt32: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types, not int32"; \
      break; \
    case mshadow::kInt64: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types, not int64"; \
      break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type$; \
  }
#else
#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
  switch (type$) { \
    case mshadow::kFloat32: { \
      typedef float DType$; typedef float DLargeType$; {__VA_ARGS__} \
    } break; \
    case mshadow::kFloat64: { \
      typedef double DType$; typedef double DLargeType$; {__VA_ARGS__} \
    } break; \
    case mshadow::kFloat16: { \
      typedef mshadow::half::half_t DType$; typedef float DLargeType$; {__VA_ARGS__} \
    } break; \
    case mshadow::kUint8: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types not uint8"; \
      break; \
    case mshadow::kInt8: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types not int8"; \
      break; \
    case mshadow::kInt32: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types, not int32"; \
      break; \
    case mshadow::kInt64: \
      LOG(FATAL) << "This operation only support " \
                    "floating point types, not int64"; \
      break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type$; \
  }
#endif
/*!
 * \brief Dispatch on a runtime layout flag and run the given statements with
 *  Layout defined as a const int holding the matching flag.
 *  Handles kNCHW, kNHWC, kNCDHW and kNDHWC; any other value reaches
 *  LOG(FATAL).
 *  The flags bound to Layout are qualified with mshadow::, like the case
 *  labels, so the macro also expands correctly outside namespace mshadow.
 * \param layout Layout flag to switch on
 * \param Layout Name of the const int made available to the statements
 * \param ... Statements to run with Layout defined
 */
#define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \
  switch (layout) { \
    case mshadow::kNCHW: { const int Layout = mshadow::kNCHW; {__VA_ARGS__} } break; \
    case mshadow::kNHWC: { const int Layout = mshadow::kNHWC; {__VA_ARGS__} } break; \
    case mshadow::kNCDHW: { const int Layout = mshadow::kNCDHW; {__VA_ARGS__} } break; \
    case mshadow::kNDHWC: { const int Layout = mshadow::kNDHWC; {__VA_ARGS__} } break; \
    default: \
      LOG(FATAL) << "Unknown layout enum " << layout; \
  }
1394 
/*!
 * \brief Type dispatch for index types. Only kInt64 is handled;
 *  any other flag reaches LOG(FATAL).
 * \param type Type flag to switch on
 * \param DType Name to bind the selected type to
 * \param ... Statements to run with DType defined
 */
#define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...) \
  switch (type) { \
    case mshadow::kInt64: { typedef int64_t DType; {__VA_ARGS__} } break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
1410 
/*!
 * \brief Type dispatch that also covers kBool.
 *  Handles kFloat32, kFloat64, kFloat16, kBfloat16, kUint8, kInt8, kInt32,
 *  kInt64 and kBool; any other flag reaches LOG(FATAL).
 * \param type Type flag to switch on
 * \param DType Name to bind the selected type to
 * \param ... Statements to run with DType defined
 */
#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \
  switch (type) { \
    case mshadow::kFloat32: { typedef float DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat64: { typedef double DType; {__VA_ARGS__} } break; \
    case mshadow::kFloat16: { typedef mshadow::half::half_t DType; {__VA_ARGS__} } break; \
    case mshadow::kBfloat16: { typedef mshadow::bfloat::bf16_t DType; {__VA_ARGS__} } break; \
    case mshadow::kUint8: { typedef uint8_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt8: { typedef int8_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt32: { typedef int32_t DType; {__VA_ARGS__} } break; \
    case mshadow::kInt64: { typedef int64_t DType; {__VA_ARGS__} } break; \
    case mshadow::kBool: { typedef bool DType; {__VA_ARGS__} } break; \
    default: \
      LOG(FATAL) << "Unknown type enum " << type; \
  }
1470 
1472 inline size_t mshadow_sizeof(int type) {
1473  int size = 0;
1474  MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, size = sizeof(DType););
1475  return size;
1476 }
1477 
1478 /*/ \brief get string with the type name from type enum */
1479 inline std::string dtype_string(const int dtype) {
1480  switch (dtype) {
1481  case mshadow::kFloat32:
1482  return "float";
1483  case mshadow::kFloat64:
1484  return "double";
1485  case mshadow::kFloat16:
1486  return "half";
1487  case mshadow::kUint8:
1488  return "unsigned char";
1489  case mshadow::kInt8:
1490  return "char";
1491  case mshadow::kInt32:
1492  return "int";
1493  case mshadow::kInt64:
1494  return "long long";
1495  case mshadow::kBool:
1496  return "bool";
1497  default:
1498  LOG(FATAL) << "Unknown type enum " << dtype;
1499  }
1500  return "unknown";
1501 }
1502 
1503 } // namespace mshadow
1504 #endif // MSHADOW_BASE_H_
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1040
#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType,...)
Definition: base.h:1411
Definition: base.h:352
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1046
const int default_type_flag
type enum value for default real type
Definition: base.h:477
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:987
MSHADOW_XINLINE DType MaxValue(void)
maximum value of certain types
Definition: base.h:489
MSHADOW_XINLINE int8_t MaxValue< int8_t >(void)
maximum value of int8_t
Definition: base.h:840
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:639
MSHADOW_XINLINE float NegInfValue< float >(void)
negative infinity value of float
Definition: base.h:792
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:619
MSHADOW_XINLINE int MinValue< int32_t >(void)
minimum value of int32_t
Definition: base.h:763
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:633
MSHADOW_XINLINE float MinValue< float >(void)
minimum value of float
Definition: base.h:733
definition of vector float16, half2 type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:939
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &residual)
set the initial value during reduction
Definition: base.h:953
MSHADOW_XINLINE uint32_t MaxValue< uint32_t >(void)
maximum value of uint32_t
Definition: base.h:860
save to saver: =
Definition: base.h:602
MSHADOW_XINLINE uint8_t MinValue< uint8_t >(void)
minimum value of uint8_t
Definition: base.h:753
MSHADOW_XINLINE bool IsInf(volatile DType val)
Definition: base.h:701
definition of bfloat type.
definition of half (float16) type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:993
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:969
#define MSHADOW_HALF_EXPONENT_BITS
Definition: half.h:372
divide operator
Definition: base.h:571
MSHADOW_XINLINE DType NegInfValue(void)
negative infinity of certain types
Definition: base.h:787
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:647
MSHADOW_XINLINE float PosInfValue< float >(void)
positive infinity value of float
Definition: base.h:874
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:946
op::right OPType
corresponding binary operator type
Definition: base.h:613
op::minus OPType
corresponding binary operator type
Definition: base.h:641
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:574
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val)
Definition: base.h:691
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1007
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:1015
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:962
Definition: base.h:481
Definition: base.h:360
identity function that maps a real number to itself
Definition: base.h:591
op::mul OPType
corresponding binary operator type
Definition: base.h:651
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:605
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:637
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:933
Definition: base.h:495
MSHADOW_XINLINE int8_t MinValue< int8_t >(void)
minimum value of int8_t
Definition: base.h:758
Definition: base.h:482
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:984
MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val)
Definition: base.h:717
MSHADOW_XINLINE bool MinValue< bool >(void)
minimum value of bool
Definition: base.h:773
#define MSHADOW_XINLINE
Definition: base.h:223
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &residual)
do stable reduction into dst
Definition: base.h:899
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:657
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:329
MSHADOW_XINLINE double MaxValue< double >(void)
maximum value of double
Definition: base.h:820
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:675
const int default_layout
default layout for 4d tensor
Definition: base.h:518
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1060
LayoutFlag
Definition: base.h:480
get rhs
Definition: base.h:579
#define MSHADOW_BF16_MAX
Definition: bfloat.h:183
std::string dtype_string(const int dtype)
Definition: base.h:1479
minus to saver: -=
Definition: base.h:630
MSHADOW_XINLINE int MaxValue< int32_t >(void)
maximum value of int32_t
Definition: base.h:845
int32_t index_t
type that will be used for index
Definition: base.h:336
multiply to saver: *=
Definition: base.h:644
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:558
const float kPi
pi
Definition: base.h:331
Definition: base.h:357
op::plus OPType
corresponding binary operator type
Definition: base.h:627
MSHADOW_XINLINE float MaxValue< float >(void)
maximum value of float
Definition: base.h:815
float default_real_t
floating-point type that mshadow uses by default
Definition: base.h:348
const int default_layout_5d
default layout for 5d tensor
Definition: base.h:541
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:1022
MSHADOW_XINLINE int64_t MinValue< int64_t >(void)
minimum value of int64_t
Definition: base.h:768
minimum reducer
Definition: base.h:1012
MSHADOW_XINLINE double MinValue< double >(void)
minimum value of double
Definition: base.h:738
#define MSHADOW_HALF_MIN
overloaded + operator for half_t
Definition: half.h:369
Definition: base.h:359
Definition: base.h:353
MSHADOW_XINLINE DType MinValue(void)
minimum value of certain types
divide to saver: /=
Definition: base.h:654
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:550
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)
Definition: base.h:300
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1037
Definition: base.h:363
MSHADOW_XINLINE DType PosInfValue(void)
positive infinity of certain types
Definition: base.h:869
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1472
Definition: base.h:356
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:974
MSHADOW_XINLINE double NegInfValue< double >(void)
negative infinity value of double
Definition: base.h:797
static MSHADOW_XINLINE DType Map(DType a)
map a to result using defined operation
Definition: base.h:594
Definition: base.h:490
Definition: base.h:362
maximum reducer
Definition: base.h:959
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:916
TypeFlag
data type flag
Definition: base.h:351
MSHADOW_XINLINE int64_t MaxValue< int64_t >(void)
maximum value of int64_t
Definition: base.h:850
Definition: base.h:485
Definition: base.h:491
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:566
plus operator
Definition: base.h:555
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:609
Definition: base.h:355
save to saver: +=
Definition: base.h:616
Definition: base.h:487
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1032
sum reducer
Definition: base.h:891
Definition: base.h:486
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:911
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:625
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1000
MSHADOW_XINLINE unsigned int MinValue< unsigned int >(void)
minimum value of unsigned int
Definition: base.h:778
mul operator
Definition: base.h:547
#define MSHADOW_HALF_SIGN_BIT
Definition: half.h:371
#define MSHADOW_HALF_MAX
Definition: half.h:370
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:611
Definition: base.h:368
overloaded + operator between half_t and bf16_t
Definition: base.h:327
Definition: base.h:361
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1053
Definition: base.h:354
Definition: base.h:364
MSHADOW_XINLINE bool MaxValue< bool >(void)
maximum value of bool
Definition: base.h:855
Definition: base.h:358
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:894
MSHADOW_XINLINE uint8_t MaxValue< uint8_t >(void)
maximum value of uint8_t
Definition: base.h:835
Definition: base.h:483
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:582
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1027
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:979
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:930
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:623
index_t openmp_index_t
openmp index for linux
Definition: base.h:344
minus operator
Definition: base.h:563
MSHADOW_XINLINE double PosInfValue< double >(void)
positive infinity value of double
Definition: base.h:879
op::div OPType
corresponding binary operator type
Definition: base.h:661
#define MSHADOW_BF16_MIN
overloaded + operator for bf16_t
Definition: bfloat.h:182