mxnet
base.h
Go to the documentation of this file.
1 
8 #ifndef MSHADOW_BASE_H_
9 #define MSHADOW_BASE_H_
10 #ifdef _MSC_VER
11 #ifndef _CRT_SECURE_NO_WARNINGS
12 #define _CRT_SECURE_NO_WARNINGS
13 #endif
14 #ifndef _CRT_SECURE_NO_DEPRECATE
15 #define _CRT_SECURE_NO_DEPRECATE
16 #endif
17 #ifndef NOMINMAX
18 #define NOMINMAX
19 #endif
20 #endif
21 #include <cmath>
22 #include <cstdio>
23 #include <cfloat>
24 #include <climits>
25 #include <algorithm>
26 #include <functional>
27 #include <sstream>
28 #include <string>
29 
30 #ifdef _MSC_VER
31 typedef signed char int8_t;
33 typedef __int16 int16_t;
34 typedef __int32 int32_t;
35 typedef __int64 int64_t;
36 typedef unsigned char uint8_t;
37 typedef unsigned __int16 uint16_t;
38 typedef unsigned __int32 uint32_t;
39 typedef unsigned __int64 uint64_t;
41 #else
42 #include <inttypes.h>
43 #endif
44 // macro definitions
49 #ifndef MSHADOW_STAND_ALONE
50 #define MSHADOW_STAND_ALONE 0
51 #endif
52 
53 #ifndef MSHADOW_ALLOC_PAD
54 #define MSHADOW_ALLOC_PAD true
55 #endif
56 
64 #ifndef MSHADOW_MIN_PAD_RATIO
65  #define MSHADOW_MIN_PAD_RATIO 2
66 #endif
67 
68 #if MSHADOW_STAND_ALONE
69  #define MSHADOW_USE_CBLAS 0
70  #define MSHADOW_USE_MKL 0
71  #define MSHADOW_USE_CUDA 0
72 #endif
73 
78 #ifndef MSHADOW_FORCE_STREAM
79 #define MSHADOW_FORCE_STREAM 1
80 #endif
81 
83 #ifndef MSHADOW_USE_CBLAS
84  #define MSHADOW_USE_CBLAS 0
85 #endif
86 
87 #ifndef MSHADOW_USE_MKL
88  #define MSHADOW_USE_MKL 1
89 #endif
90 
95 #ifndef MSHADOW_USE_CUDA
96  #define MSHADOW_USE_CUDA 1
97 #endif
98 
102 #ifndef MSHADOW_USE_CUDNN
103  #define MSHADOW_USE_CUDNN 0
104 #endif
105 
109 #ifndef MSHADOW_USE_CUSOLVER
110  #define MSHADOW_USE_CUSOLVER MSHADOW_USE_CUDA
111 #endif
112 
117 #ifndef MSHADOW_OLD_CUDA
118 #define MSHADOW_OLD_CUDA 0
119 #endif
120 
124 #ifndef MSHADOW_IN_CXX11
125  #if (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\
126  __cplusplus >= 201103L || defined(_MSC_VER))
127  #define MSHADOW_IN_CXX11 1
128  #else
129  #define MSHADOW_IN_CXX11 0
130  #endif
131 #endif
132 
134 #ifndef MSHADOW_USE_SSE
135  #define MSHADOW_USE_SSE 1
136 #endif
137 
139 #ifndef MSHADOW_USE_F16C
140  #if defined(_MSC_VER) || defined(__CUDACC__)
141  #define MSHADOW_USE_F16C 0
142  #elif defined(__clang__) && \
143  ((__clang_major__ < 8) || ((__clang_major__ == 8) && (__clang_minor__ < 1)))
144  #define MSHADOW_USE_F16C 0
145  #else
146  #define MSHADOW_USE_F16C 1
147  #endif
148 #endif
149 
151 #ifndef MSHADOW_USE_NVML
152  #define MSHADOW_USE_NVML 0
153 #endif
154 // SSE conflicts with nvcc (__CUDACC__), so SSE is disabled when compiling with it
155 #ifdef __CUDACC__
156  #undef MSHADOW_USE_SSE
157  #define MSHADOW_USE_SSE 0
158 #endif
159 
160 #if MSHADOW_USE_CBLAS
161 extern "C" {
162  #include <cblas.h>
163 }
164 #elif MSHADOW_USE_MKL
165  #include <mkl_blas.h>
166  #include <mkl_cblas.h>
167  #include <mkl_vsl.h>
168  #include <mkl_vsl_functions.h>
169  #include <mkl_version.h>
170 #endif
171 
172 #if MSHADOW_USE_CUDA
173  #include <cuda.h>
174  #include <cublas_v2.h>
175  #include <curand.h>
176 #endif
177 
178 #if MSHADOW_USE_CUDNN == 1
179  #include <cudnn.h>
180 #endif
181 
182 #if MSHADOW_USE_CUSOLVER == 1
183  #include <cusolverDn.h>
184 #endif
185 
186 #if MSHADOW_USE_NVML
187  #include <nvml.h>
188 #endif
189 
190 // --------------------------------
191 // MSHADOW_XINLINE is used for inlining template code for both CUDA and CPU code
192 #ifdef MSHADOW_XINLINE
193  #error "MSHADOW_XINLINE must not be defined"
194 #endif
195 #ifdef _MSC_VER
196 #define MSHADOW_FORCE_INLINE __forceinline
197 #pragma warning(disable : 4068)
198 #else
199 #define MSHADOW_FORCE_INLINE inline __attribute__((always_inline))
200 #endif
201 #ifdef __CUDACC__
202  #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__
203 #else
204  #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE
205 #endif
206 
207 #define MSHADOW_CINLINE MSHADOW_FORCE_INLINE
208 
209 #if defined(__GXX_EXPERIMENTAL_CXX0X) ||\
210  defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
211  #define MSHADOW_CONSTEXPR constexpr
212 #else
213  #define MSHADOW_CONSTEXPR const
214 #endif
215 
222 #ifndef MSHADOW_DEFAULT_DTYPE
223 #define MSHADOW_DEFAULT_DTYPE = ::mshadow::default_real_t
224 #endif
225 
229 #ifndef MSHADOW_USE_GLOG
230 #define MSHADOW_USE_GLOG DMLC_USE_GLOG
231 #endif // MSHADOW_USE_GLOG
232 
233 #if DMLC_USE_CXX11
234 #define MSHADOW_THROW_EXCEPTION noexcept(false)
235 #define MSHADOW_NO_EXCEPTION noexcept(true)
236 #else
237 #define MSHADOW_THROW_EXCEPTION
238 #define MSHADOW_NO_EXCEPTION
239 #endif
240 
241 #if defined(_MSC_VER)
242 #define MSHADOW_ALIGNED(x) __declspec(align(x))
243 #else
244 #define MSHADOW_ALIGNED(x) __attribute__ ((aligned(x)))
245 #endif
246 
252 #define MSHADOW_CUDA_CALL(func) \
253  { \
254  cudaError_t e = (func); \
255  if (e == cudaErrorCudartUnloading) { \
256  throw dmlc::Error(cudaGetErrorString(e)); \
257  } \
258  CHECK(e == cudaSuccess) \
259  << "CUDA: " << cudaGetErrorString(e); \
260  }
261 
266 #define MSHADOW_CATCH_ERROR(func) \
267  { \
268  try { \
269  (func); \
270  } catch (const dmlc::Error &e) { \
271  std::string what = e.what(); \
272  if (what.find("driver shutting down") == std::string::npos) { \
273  LOG(ERROR) << "Ignore CUDA Error " << what; \
274  } \
275  } \
276  }
277 
278 #include "./half.h"
279 #include "./half2.h"
280 #include "./logging.h"
282 namespace mshadow {
/*! \brief buffer size for each random number generator */
const unsigned kRandBufferSize = 1000000;
/*!
 * \brief pi as a single-precision constant.
 * Enough digits are spelled out that the literal rounds to the float nearest
 * to pi (3.14159274f); a 7-digit literal such as 3.1415926f rounds one ulp
 * low (to 3.14159250f).
 */
const float kPi = 3.14159265358979323846f;
288 #if MSHADOW_INT64_TENSOR_SIZE == 1
289  typedef int64_t index_t;
290 #else
291  typedef int32_t index_t;
292 #endif
293 
294 #ifdef _WIN32
295 
296  typedef int64_t openmp_index_t;
297 #else
298 
299  typedef index_t openmp_index_t;
300 #endif
301 
303 typedef float default_real_t;
304 
306 enum TypeFlag {
307  kFloat32 = 0,
308  kFloat64 = 1,
309  kFloat16 = 2,
310  kUint8 = 3,
311  kInt32 = 4,
312  kInt8 = 5,
313  kInt64 = 6,
314  kBool = 7,
315 };
316 
317 template<typename DType>
318 struct DataType;
319 
320 template<>
321 struct DataType<float> {
322  static const int kFlag = kFloat32;
323  static const int kLanes = 1;
324 #if MSHADOW_USE_CUDA
325 #if (CUDA_VERSION >= 8000)
326  static const cudaDataType_t kCudaFlag = CUDA_R_32F;
327 #endif
328 #if MSHADOW_USE_CUDNN
329  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT;
330  typedef float ScaleType;
331 #endif
332 #endif
333 };
334 template<>
335 struct DataType<double> {
336  static const int kFlag = kFloat64;
337  static const int kLanes = 1;
338 #if MSHADOW_USE_CUDA
339 #if (CUDA_VERSION >= 8000)
340  static const cudaDataType_t kCudaFlag = CUDA_R_64F;
341 #endif
342 #if MSHADOW_USE_CUDNN
343  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE;
344  typedef double ScaleType;
345 #endif
346 #endif
347 };
348 template<>
349 struct DataType<half::half_t> {
350  static const int kFlag = kFloat16;
351  static const int kLanes = 1;
352 #if MSHADOW_USE_CUDA
353 #if (CUDA_VERSION >= 8000)
354  static const cudaDataType_t kCudaFlag = CUDA_R_16F;
355 #endif
356 #if MSHADOW_USE_CUDNN
357  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF;
358  typedef float ScaleType;
359 #endif
360 #endif
361 };
362 template<>
363 struct DataType<half::half2_t> {
364  static const int kFlag = kFloat16;
365  static const int kLanes = 2;
366 };
367 template<>
368 struct DataType<uint8_t> {
369  static const int kFlag = kUint8;
370  static const int kLanes = 1;
371 #if MSHADOW_USE_CUDA
372 #if (CUDA_VERSION >= 8000)
373  static const cudaDataType_t kCudaFlag = CUDA_R_8U;
374 #endif
375 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
376  // no uint8 in cudnn for now
377  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
378  typedef uint8_t ScaleType;
379 #endif
380 #endif
381 };
382 template<>
383 struct DataType<int8_t> {
384  static const int kFlag = kInt8;
385  static const int kLanes = 1;
386 #if MSHADOW_USE_CUDA
387 #if (CUDA_VERSION >= 8000)
388  static const cudaDataType_t kCudaFlag = CUDA_R_8I;
389 #endif
390 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
391  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
392  typedef int8_t ScaleType;
393 #endif
394 #endif
395 };
396 template<>
397 struct DataType<int32_t> {
398  static const int kFlag = kInt32;
399  static const int kLanes = 1;
400 #if MSHADOW_USE_CUDA
401 #if (CUDA_VERSION >= 8000)
402  static const cudaDataType_t kCudaFlag = CUDA_R_32I;
403 #endif
404 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
405  static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT32;
406  typedef int32_t ScaleType;
407 #endif
408 #endif
409 };
410 template<>
411 struct DataType<int64_t> {
412  static const int kFlag = kInt64;
413  static const int kLanes = 1;
414 };
415 template<>
416 struct DataType<bool> {
417  static const int kFlag = kBool;
418  static const int kLanes = 1;
419 };
420 
423 
426  kNCHW = 0,
429 
430  kNCW = 1 << 3,
433 
434  kNCDHW = 1 << 5,
437 };
438 
439 template<int layout>
440 struct LayoutType;
441 
442 template<>
443 struct LayoutType<kNCHW> {
444  static const index_t kNdim = 4;
445 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
446  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
447 #else
448  static const int kCudnnFlag = -1;
449 #endif
450 };
451 
452 template<>
453 struct LayoutType<kNHWC> {
454  static const index_t kNdim = 4;
455 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
456  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
457 #else
458  static const int kCudnnFlag = -1;
459 #endif
460 };
461 
463 const int default_layout = kNCHW;
464 
465 template<>
467  static const index_t kNdim = 5;
468 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
469  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
470 #else
471  static const int kCudnnFlag = -1;
472 #endif
473 };
474 
475 template<>
477  static const index_t kNdim = 5;
478 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
479  static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
480 #else
481  static const int kCudnnFlag = -1;
482 #endif
483 };
484 
487 
489 namespace op {
490 // binary operator
492 struct mul{
494  template<typename DType>
495  MSHADOW_XINLINE static DType Map(DType a, DType b) {
496  return a * b;
497  }
498 };
500 struct plus {
502  template<typename DType>
503  MSHADOW_XINLINE static DType Map(DType a, DType b) {
504  return a + b;
505  }
506 };
508 struct minus {
510  template<typename DType>
511  MSHADOW_XINLINE static DType Map(DType a, DType b) {
512  return a - b;
513  }
514 };
516 struct div {
518  template<typename DType>
519  MSHADOW_XINLINE static DType Map(DType a, DType b) {
520  return a / b;
521  }
522 };
524 struct right {
526  template<typename DType>
527  MSHADOW_XINLINE static DType Map(DType a, DType b) {
528  return b;
529  }
530 };
531 // unary operator/ function: example
532 // these operators can be defined by user,
533 // in the same style as binary and unary operator
534 // to use, simply write F<op::identity>( src )
536 struct identity{
538  template<typename DType>
539  MSHADOW_XINLINE static DType Map(DType a) {
540  return a;
541  }
542 };
543 } // namespace op
545 namespace sv {
547 struct saveto {
549  template<typename DType>
550  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
551  a = b;
552  }
554  inline static default_real_t AlphaBLAS(void) { return 1.0f; }
556  inline static default_real_t BetaBLAS(void) { return 0.0f; }
558  typedef op::right OPType;
559 };
561 struct plusto {
563  template<typename DType>
564  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
565  a += b;
566  }
568  inline static default_real_t AlphaBLAS(void) { return 1.0f; }
570  inline static default_real_t BetaBLAS(void) { return 1.0f; }
572  typedef op::plus OPType;
573 };
575 struct minusto {
577  template<typename DType>
578  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
579  a -= b;
580  }
582  inline static default_real_t AlphaBLAS(void) { return -1.0f; }
584  inline static default_real_t BetaBLAS(void) { return 1.0f; }
586  typedef op::minus OPType;
587 };
589 struct multo {
591  template<typename DType>
592  MSHADOW_XINLINE static void Save(DType &a, DType b) { // NOLINT(*)
593  a *= b;
594  }
596  typedef op::mul OPType;
597 };
599 struct divto {
601  template<typename DType>
602  MSHADOW_XINLINE static void Save(DType& a, DType b) { // NOLINT(*)
603  a /= b;
604  }
606  typedef op::div OPType;
607 };
608 } // namespace sv
609 
610 #ifndef __CUDA_ARCH__
611 using std::isnan;
612 using std::isinf;
613 #endif
614 
618 namespace isnan_typed {
619  template<typename DType>
620  MSHADOW_XINLINE bool IsNan(volatile DType val) {
621  return false;
622  }
623  template<>
624  MSHADOW_XINLINE bool IsNan(volatile float val) {
625  return isnan(val);
626  }
627  template<>
628  MSHADOW_XINLINE bool IsNan(volatile double val) {
629  return isnan(val);
630  }
631  template<>
632  MSHADOW_XINLINE bool IsNan(volatile long double val) {
633  return isnan(val);
634  }
635  template<>
636  MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
637  return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) > MSHADOW_HALF_EXPONENT_BITS;
638  }
639 } // namespace isnan_typed
640 
644 namespace isinf_typed {
645  template<typename DType>
646  MSHADOW_XINLINE bool IsInf(volatile DType val) {
647  return false;
648  }
649  template<>
650  MSHADOW_XINLINE bool IsInf(volatile float val) {
651  return isinf(val);
652  }
653  template<>
654  MSHADOW_XINLINE bool IsInf(volatile double val) {
655  return isinf(val);
656  }
657  template<>
658  MSHADOW_XINLINE bool IsInf(volatile long double val) {
659  return isinf(val);
660  }
661  template<>
662  MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) {
663  return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) == MSHADOW_HALF_EXPONENT_BITS;
664  }
665 } // namespace isinf_typed
666 
668 namespace red {
669 namespace limits {
674 template<typename DType>
675 MSHADOW_XINLINE DType MinValue(void);
677 template<>
679  return -FLT_MAX;
680 }
682 template<>
684  return -DBL_MAX;
685 }
687 template<>
688 MSHADOW_XINLINE half::half_t MinValue<half::half_t>(void) {
689  return MSHADOW_HALF_MIN;
690 }
692 template<>
694  return 0;
695 }
697 template<>
699  return SCHAR_MIN;
700 }
702 template<>
704  return INT_MIN;
705 }
707 template<>
709  return LLONG_MIN;
710 }
712 template<>
714  return false;
715 }
716 
721 template<typename DType>
723  return MinValue<DType>();
724 }
726 template<>
728  return -HUGE_VALF;
729 }
731 template<>
733  return -HUGE_VAL;
734 }
736 template<>
737 MSHADOW_XINLINE half::half_t NegInfValue<half::half_t>(void) {
738  return half::half_t::Binary(
740 }
741 
746 template<typename DType>
747 MSHADOW_XINLINE DType MaxValue(void);
749 template<>
751  return FLT_MAX;
752 }
754 template<>
756  return DBL_MAX;
757 }
759 template<>
760 MSHADOW_XINLINE half::half_t MaxValue<half::half_t>(void) {
761  return MSHADOW_HALF_MAX;
762 }
764 template<>
766  return UCHAR_MAX;
767 }
769 template<>
771  return SCHAR_MAX;
772 }
774 template<>
776  return INT_MAX;
777 }
779 template<>
781  return LLONG_MAX;
782 }
784 template<>
786  return true;
787 }
788 
793 template<typename DType>
795  return MaxValue<DType>();
796 }
798 template<>
800  return HUGE_VALF;
801 }
803 template<>
805  return HUGE_VAL;
806 }
808 template<>
809 MSHADOW_XINLINE half::half_t PosInfValue<half::half_t>(void) {
810  return half::half_t::Binary(MSHADOW_HALF_EXPONENT_BITS);
811 }
812 
813 } // namespace limits
814 
816 struct sum {
818  template<typename DType>
819  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
820  dst += src;
821  }
823  template<typename DType>
824  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
825  DType y = src - residual;
826  DType t = dst + y;
827  if (isinf_typed::IsInf(t)) {
828  residual = 0;
829  } else {
830  residual = (t - dst) - y;
831  }
832  dst = t;
833  }
835  template<typename DType>
836  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
837  Reduce(dst_val, src_val);
838  }
840  template<typename DType>
841  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
842  DType t1 = dst_val + src_val;
843  if (isinf_typed::IsInf(t1)) {
844  dst_val = t1;
845  dst_residual = 0;
846  } else {
847  DType e = t1 - dst_val;
848  DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
849  dst_val = t1 + t2;
850  dst_residual = t2 - (dst_val - t1);
851  }
852  }
854  template<typename DType>
855  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
857  template<typename DType>
858  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
863  template<typename DType>
864  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
865  return 1;
866  }
870  template<typename DType>
871  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
872  initv = 0;
873  }
877  template<typename DType>
878  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &residual) { // NOLINT(*)
879  SetInitValue(initv);
880  residual = 0;
881  }
882 };
884 struct maximum {
886  template<typename DType>
887  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
888  if (!isnan_typed::IsNan(dst)) {
889  if (!(dst >= src)) dst = src;
890  }
891  }
893  template<typename DType>
894  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*)
895  Reduce(dst, src);
896  }
898  template<typename DType>
899  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
900  Reduce(dst_val, src_val);
901  }
903  template<typename DType>
904  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
905  Reduce(dst_val, src_val);
906  }
908  template<typename DType>
909  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
911  template<typename DType>
912  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
917  template<typename DType>
918  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
919  return redres == redsrc ? 1: 0;
920  }
924  template<typename DType>
925  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
926  initv = limits::NegInfValue<DType>();
927  }
931  template<typename DType>
932  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
933  SetInitValue(initv);
934  }
935 };
937 struct minimum {
939  template<typename DType>
940  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
941  if (!isnan_typed::IsNan(dst)) {
942  if (!(dst <= src)) dst = src;
943  }
944  }
946  template<typename DType>
947  MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType &none) { // NOLINT(*)
948  Reduce(dst, src);
949  }
951  template<typename DType>
952  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
953  Reduce(dst_val, src_val);
954  }
956  template<typename DType>
957  MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
958  Reduce(dst_val, src_val);
959  }
961  template<typename DType>
962  MSHADOW_XINLINE static void Finalize(volatile DType& dst) {} // NOLINT(*)
964  template<typename DType>
965  MSHADOW_XINLINE static void Finalize(volatile DType& dst, volatile DType& residual) {} // NOLINT(*)
970  template<typename DType>
971  MSHADOW_XINLINE static DType PartialGrad(DType redres, DType redsrc) {
972  return redres == redsrc ? 1: 0;
973  }
977  template<typename DType>
978  MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
979  initv = limits::PosInfValue<DType>();
980  }
984  template<typename DType>
985  MSHADOW_XINLINE static void SetInitValue(DType &initv, DType &none) { // NOLINT(*)
986  SetInitValue(initv);
987  }
988 };
989 } // namespace red
990 
991 #define MSHADOW_TYPE_SWITCH(type, DType, ...) \
992  switch (type) { \
993  case mshadow::kFloat32: \
994  { \
995  typedef float DType; \
996  {__VA_ARGS__} \
997  } \
998  break; \
999  case mshadow::kFloat64: \
1000  { \
1001  typedef double DType; \
1002  {__VA_ARGS__} \
1003  } \
1004  break; \
1005  case mshadow::kFloat16: \
1006  { \
1007  typedef mshadow::half::half_t DType; \
1008  {__VA_ARGS__} \
1009  } \
1010  break; \
1011  case mshadow::kUint8: \
1012  { \
1013  typedef uint8_t DType; \
1014  {__VA_ARGS__} \
1015  } \
1016  break; \
1017  case mshadow::kInt8: \
1018  { \
1019  typedef int8_t DType; \
1020  {__VA_ARGS__} \
1021  } \
1022  break; \
1023  case mshadow::kInt32: \
1024  { \
1025  typedef int32_t DType; \
1026  {__VA_ARGS__} \
1027  } \
1028  break; \
1029  case mshadow::kInt64: \
1030  { \
1031  typedef int64_t DType; \
1032  {__VA_ARGS__} \
1033  } \
1034  break; \
1035  default: \
1036  LOG(FATAL) << "Unknown type enum " << type; \
1037  }
1038 
1039 #define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \
1040  switch (type) { \
1041  case mshadow::kFloat32: \
1042  { \
1043  typedef float DType; \
1044  {__VA_ARGS__} \
1045  } \
1046  break; \
1047  case mshadow::kFloat64: \
1048  { \
1049  typedef double DType; \
1050  {__VA_ARGS__} \
1051  } \
1052  break; \
1053  case mshadow::kFloat16: \
1054  { \
1055  typedef mshadow::half::half2_t DType; \
1056  {__VA_ARGS__} \
1057  } \
1058  break; \
1059  case mshadow::kUint8: \
1060  { \
1061  typedef uint8_t DType; \
1062  {__VA_ARGS__} \
1063  } \
1064  break; \
1065  case mshadow::kInt32: \
1066  { \
1067  typedef int32_t DType; \
1068  {__VA_ARGS__} \
1069  } \
1070  break; \
1071  case mshadow::kInt64: \
1072  { \
1073  typedef int64_t DType; \
1074  {__VA_ARGS__} \
1075  } \
1076  break; \
1077  default: \
1078  LOG(FATAL) << "Unknown type enum " << type; \
1079  }
1080 
1081 #define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \
1082  switch (type) { \
1083  case mshadow::kFloat32: \
1084  { \
1085  typedef float DType; \
1086  {__VA_ARGS__} \
1087  } \
1088  break; \
1089  case mshadow::kFloat64: \
1090  { \
1091  typedef double DType; \
1092  {__VA_ARGS__} \
1093  } \
1094  break; \
1095  default: \
1096  LOG(FATAL) << "This operation only supports " \
1097  "32-bit and 64-bit floating point"; \
1098  }
1099 
1100 #define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...) \
1101  switch (type) { \
1102  case mshadow::kFloat32: \
1103  { \
1104  typedef float DType; \
1105  {__VA_ARGS__} \
1106  } \
1107  break; \
1108  case mshadow::kFloat64: \
1109  { \
1110  typedef double DType; \
1111  {__VA_ARGS__} \
1112  } \
1113  break; \
1114  case mshadow::kFloat16: \
1115  { \
1116  typedef mshadow::half::half_t DType; \
1117  {__VA_ARGS__} \
1118  } \
1119  break; \
1120  case mshadow::kUint8: \
1121  LOG(FATAL) << "This operation only support " \
1122  "floating point types not uint8"; \
1123  break; \
1124  case mshadow::kInt8: \
1125  LOG(FATAL) << "This operation only support " \
1126  "floating point types not int8"; \
1127  break; \
1128  case mshadow::kInt32: \
1129  LOG(FATAL) << "This operation only support " \
1130  "floating point types, not int32";\
1131  break; \
1132  case mshadow::kInt64: \
1133  LOG(FATAL) << "This operation only support " \
1134  "floating point types, not int64";\
1135  break; \
1136  default: \
1137  LOG(FATAL) << "Unknown type enum " << type; \
1138  }
1139 
1140 #define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
1141  switch (type$) { \
1142  case mshadow::kFloat32: \
1143  { \
1144  typedef float DType$; \
1145  typedef float DLargeType$; \
1146  {__VA_ARGS__} \
1147  } \
1148  break; \
1149  case mshadow::kFloat64: \
1150  { \
1151  typedef double DType$; \
1152  typedef double DLargeType$; \
1153  {__VA_ARGS__} \
1154  } \
1155  break; \
1156  case mshadow::kFloat16: \
1157  { \
1158  typedef mshadow::half::half_t DType$; \
1159  typedef float DLargeType$; \
1160  {__VA_ARGS__} \
1161  } \
1162  break; \
1163  case mshadow::kUint8: \
1164  LOG(FATAL) << "This operation only support " \
1165  "floating point types not uint8"; \
1166  break; \
1167  case mshadow::kInt8: \
1168  LOG(FATAL) << "This operation only support " \
1169  "floating point types not int8"; \
1170  break; \
1171  case mshadow::kInt32: \
1172  LOG(FATAL) << "This operation only support " \
1173  "floating point types, not int32";\
1174  break; \
1175  case mshadow::kInt64: \
1176  LOG(FATAL) << "This operation only support " \
1177  "floating point types, not int64";\
1178  break; \
1179  default: \
1180  LOG(FATAL) << "Unknown type enum " << type$; \
1181  }
1182 
1183 #define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \
1184  switch (layout) { \
1185  case mshadow::kNCHW: \
1186  { \
1187  const int Layout = kNCHW; \
1188  {__VA_ARGS__} \
1189  } \
1190  break; \
1191  case mshadow::kNHWC: \
1192  { \
1193  const int Layout = kNHWC; \
1194  {__VA_ARGS__} \
1195  } \
1196  break; \
1197  case mshadow::kNCDHW: \
1198  { \
1199  const int Layout = kNCDHW; \
1200  {__VA_ARGS__} \
1201  } \
1202  break; \
1203  case mshadow::kNDHWC: \
1204  { \
1205  const int Layout = kNDHWC; \
1206  {__VA_ARGS__} \
1207  } \
1208  break; \
1209  default: \
1210  LOG(FATAL) << "Unknown layout enum " << layout; \
1211  }
1212 
1217 #define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...) \
1218  switch (type) { \
1219  case mshadow::kInt64: \
1220  { \
1221  typedef int64_t DType; \
1222  {__VA_ARGS__} \
1223  } \
1224  break; \
1225  default: \
1226  LOG(FATAL) << "Unknown type enum " << type; \
1227  }
1228 
1229 #define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \
1230  switch (type) { \
1231  case mshadow::kFloat32: \
1232  { \
1233  typedef float DType; \
1234  {__VA_ARGS__} \
1235  } \
1236  break; \
1237  case mshadow::kFloat64: \
1238  { \
1239  typedef double DType; \
1240  {__VA_ARGS__} \
1241  } \
1242  break; \
1243  case mshadow::kFloat16: \
1244  { \
1245  typedef mshadow::half::half_t DType; \
1246  {__VA_ARGS__} \
1247  } \
1248  break; \
1249  case mshadow::kUint8: \
1250  { \
1251  typedef uint8_t DType; \
1252  {__VA_ARGS__} \
1253  } \
1254  break; \
1255  case mshadow::kInt8: \
1256  { \
1257  typedef int8_t DType; \
1258  {__VA_ARGS__} \
1259  } \
1260  break; \
1261  case mshadow::kInt32: \
1262  { \
1263  typedef int32_t DType; \
1264  {__VA_ARGS__} \
1265  } \
1266  break; \
1267  case mshadow::kInt64: \
1268  { \
1269  typedef int64_t DType; \
1270  {__VA_ARGS__} \
1271  } \
1272  break; \
1273  case mshadow::kBool: \
1274  { \
1275  typedef bool DType; \
1276  {__VA_ARGS__} \
1277  } \
1278  break; \
1279  default: \
1280  LOG(FATAL) << "Unknown type enum " << type; \
1281  }
1282 
1284 inline size_t mshadow_sizeof(int type) {
1285  int size = 0;
1286  MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, size = sizeof(DType););
1287  return size;
1288 }
1289 
1290 } // namespace mshadow
1291 #endif // MSHADOW_BASE_H_
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:965
#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType,...)
Definition: base.h:1229
Definition: base.h:307
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:971
const int default_type_flag
type enum value for default real type
Definition: base.h:422
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:912
MSHADOW_XINLINE DType MaxValue(void)
maximum value of certain types
Definition: base.h:434
MSHADOW_XINLINE int8_t MaxValue< int8_t >(void)
maximum value of int8_t
Definition: base.h:770
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:584
MSHADOW_XINLINE float NegInfValue< float >(void)
negative infinity value of float
Definition: base.h:727
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:564
MSHADOW_XINLINE int MinValue< int32_t >(void)
minimum value of int32_t
Definition: base.h:703
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:578
MSHADOW_XINLINE float MinValue< float >(void)
minimum value of float
Definition: base.h:678
definition of vector float16, half2 type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:864
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &residual)
set the initial value during reduction
Definition: base.h:878
save to saver: =
Definition: base.h:547
MSHADOW_XINLINE uint8_t MinValue< uint8_t >(void)
minimum value of uint8_t
Definition: base.h:693
MSHADOW_XINLINE bool IsInf(volatile DType val)
Definition: base.h:646
definition of half (float16) type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:918
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:894
#define MSHADOW_HALF_EXPONENT_BITS
Definition: half.h:353
divide operator
Definition: base.h:516
MSHADOW_XINLINE DType NegInfValue(void)
negative infinity of certain types
Definition: base.h:722
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:592
MSHADOW_XINLINE float PosInfValue< float >(void)
positive infinity value of float
Definition: base.h:799
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:871
op::right OPType
corresponding binary operator type
Definition: base.h:558
op::minus OPType
corresponding binary operator type
Definition: base.h:586
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:519
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val)
Definition: base.h:636
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:932
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:940
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:887
Definition: base.h:426
identity function that maps a real number to itself
Definition: base.h:536
op::mul OPType
corresponding binary operator type
Definition: base.h:596
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:550
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:582
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:858
Definition: base.h:440
MSHADOW_XINLINE int8_t MinValue< int8_t >(void)
minimum value of int8_t
Definition: base.h:698
Definition: base.h:427
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:909
MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val)
Definition: base.h:662
MSHADOW_XINLINE bool MinValue< bool >(void)
minimum value of bool
Definition: base.h:713
#define MSHADOW_XINLINE
Definition: base.h:204
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &residual)
do stable reduction into dst
Definition: base.h:824
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:602
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:284
MSHADOW_XINLINE double MaxValue< double >(void)
maximum value of double
Definition: base.h:755
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:620
const int default_layout
default layout for 4d tensor
Definition: base.h:463
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:985
LayoutFlag
Definition: base.h:425
get rhs
Definition: base.h:524
minus to saver: -=
Definition: base.h:575
MSHADOW_XINLINE int MaxValue< int32_t >(void)
maximum value of int32_t
Definition: base.h:775
int32_t index_t
type used for indices
Definition: base.h:291
multiply to saver: *=
Definition: base.h:589
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:503
const float kPi
pi
Definition: base.h:286
Definition: base.h:312
op::plus OPType
corresponding binary operator type
Definition: base.h:572
MSHADOW_XINLINE float MaxValue< float >(void)
maximum value of float
Definition: base.h:750
float default_real_t
floating-point type used by default in mshadow
Definition: base.h:303
const int default_layout_5d
default layout for 5d tensor
Definition: base.h:486
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:947
MSHADOW_XINLINE int64_t MinValue< int64_t >(void)
minimum value of int64_t
Definition: base.h:708
minimum reducer
Definition: base.h:937
MSHADOW_XINLINE double MinValue< double >(void)
minimum value of double
Definition: base.h:683
#define MSHADOW_HALF_MIN
overloaded + operator for half_t
Definition: half.h:350
Definition: base.h:314
Definition: base.h:308
MSHADOW_XINLINE DType MinValue(void)
minimum value of certain types
divide to saver: /=
Definition: base.h:599
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:495
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:962
MSHADOW_XINLINE DType PosInfValue(void)
positive infinity of certain types
Definition: base.h:794
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1284
Definition: base.h:311
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:899
MSHADOW_XINLINE double NegInfValue< double >(void)
negative infinity value of double
Definition: base.h:732
static MSHADOW_XINLINE DType Map(DType a)
map a to result using defined operation
Definition: base.h:539
Definition: base.h:435
maximum reducer
Definition: base.h:884
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:841
TypeFlag
data type flag
Definition: base.h:306
MSHADOW_XINLINE int64_t MaxValue< int64_t >(void)
maximum value of int64_t
Definition: base.h:780
Definition: base.h:430
Definition: base.h:436
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:511
plus operator
Definition: base.h:500
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:554
Definition: base.h:310
save to saver: +=
Definition: base.h:561
Definition: base.h:432
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:957
sum reducer
Definition: base.h:816
Definition: base.h:431
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:836
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:570
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:925
mul operator
Definition: base.h:492
#define MSHADOW_HALF_SIGN_BIT
Definition: half.h:352
#define MSHADOW_HALF_MAX
Definition: half.h:351
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:556
Definition: base.h:318
namespace for mshadow
Definition: base.h:282
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:978
Definition: base.h:309
MSHADOW_XINLINE bool MaxValue< bool >(void)
maximum value of bool
Definition: base.h:785
Definition: base.h:313
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:819
MSHADOW_XINLINE uint8_t MaxValue< uint8_t >(void)
maximum value of uint8_t
Definition: base.h:765
Definition: base.h:428
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:527
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:952
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:904
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:855
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:568
index_t openmp_index_t
OpenMP index type for Linux
Definition: base.h:299
minus operator
Definition: base.h:508
MSHADOW_XINLINE double PosInfValue< double >(void)
positive infinity value of double
Definition: base.h:804
op::div OPType
corresponding binary operator type
Definition: base.h:606