27 #ifndef MSHADOW_BASE_H_ 28 #define MSHADOW_BASE_H_ 30 #ifndef _CRT_SECURE_NO_WARNINGS 31 #define _CRT_SECURE_NO_WARNINGS 33 #ifndef _CRT_SECURE_NO_DEPRECATE 34 #define _CRT_SECURE_NO_DEPRECATE 50 typedef signed char int8_t;
52 typedef __int16 int16_t;
53 typedef __int32 int32_t;
54 typedef __int64 int64_t;
55 typedef unsigned char uint8_t;
56 typedef unsigned __int16 uint16_t;
57 typedef unsigned __int32 uint32_t;
58 typedef unsigned __int64 uint64_t;
68 #ifndef MSHADOW_STAND_ALONE 69 #define MSHADOW_STAND_ALONE 0 72 #ifndef MSHADOW_ALLOC_PAD 73 #define MSHADOW_ALLOC_PAD true 83 #ifndef MSHADOW_MIN_PAD_RATIO 84 #define MSHADOW_MIN_PAD_RATIO 2 87 #if MSHADOW_STAND_ALONE 88 #define MSHADOW_USE_CBLAS 0 89 #define MSHADOW_USE_MKL 0 90 #define MSHADOW_USE_CUDA 0 97 #ifndef MSHADOW_FORCE_STREAM 98 #define MSHADOW_FORCE_STREAM 1 102 #ifndef MSHADOW_USE_CBLAS 103 #define MSHADOW_USE_CBLAS 0 106 #ifndef MSHADOW_USE_MKL 107 #define MSHADOW_USE_MKL 1 114 #ifndef MSHADOW_USE_CUDA 115 #define MSHADOW_USE_CUDA 1 121 #ifndef MSHADOW_USE_CUDNN 122 #define MSHADOW_USE_CUDNN 0 128 #ifndef MSHADOW_USE_CUSOLVER 129 #define MSHADOW_USE_CUSOLVER MSHADOW_USE_CUDA 136 #ifndef MSHADOW_OLD_CUDA 137 #define MSHADOW_OLD_CUDA 0 143 #ifndef MSHADOW_IN_CXX11 144 #if (defined(__GXX_EXPERIMENTAL_CXX0X__) ||\ 145 __cplusplus >= 201103L || defined(_MSC_VER)) 146 #define MSHADOW_IN_CXX11 1 148 #define MSHADOW_IN_CXX11 0 153 #ifndef MSHADOW_USE_SSE 154 #define MSHADOW_USE_SSE 1 158 #ifndef MSHADOW_USE_F16C 159 #if defined(_MSC_VER) || defined(__CUDACC__) 160 #define MSHADOW_USE_F16C 0 161 #elif defined(__clang__) && \ 162 ((__clang_major__ < 8) || ((__clang_major__ == 8) && (__clang_minor__ < 1))) 163 #define MSHADOW_USE_F16C 0 165 #define MSHADOW_USE_F16C 1 170 #ifndef MSHADOW_USE_NVML 171 #define MSHADOW_USE_NVML 0 175 #undef MSHADOW_USE_SSE 176 #define MSHADOW_USE_SSE 0 179 #if MSHADOW_USE_CBLAS 183 #elif MSHADOW_USE_MKL 184 #include <mkl_blas.h> 185 #include <mkl_cblas.h> 187 #include <mkl_vsl_functions.h> 188 #include <mkl_version.h> 193 #include <cublas_v2.h> 197 #if MSHADOW_USE_CUDNN == 1 201 #if MSHADOW_USE_CUSOLVER == 1 202 #include <cusolverDn.h> 211 #ifdef MSHADOW_XINLINE 212 #error "MSHADOW_XINLINE must not be defined" 215 #define MSHADOW_FORCE_INLINE __forceinline 216 #pragma warning(disable : 4068) 218 #define MSHADOW_FORCE_INLINE inline __attribute__((always_inline)) 221 #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ 
__host__ 223 #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE 226 #define MSHADOW_CINLINE MSHADOW_FORCE_INLINE 228 #if defined(__GXX_EXPERIMENTAL_CXX0X) ||\ 229 defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L 230 #define MSHADOW_CONSTEXPR constexpr 232 #define MSHADOW_CONSTEXPR const 241 #ifndef MSHADOW_DEFAULT_DTYPE 242 #define MSHADOW_DEFAULT_DTYPE = ::mshadow::default_real_t 248 #ifndef MSHADOW_USE_GLOG 249 #define MSHADOW_USE_GLOG DMLC_USE_GLOG 250 #endif // MSHADOW_USE_GLOG 253 #define MSHADOW_THROW_EXCEPTION noexcept(false) 254 #define MSHADOW_NO_EXCEPTION noexcept(true) 256 #define MSHADOW_THROW_EXCEPTION 257 #define MSHADOW_NO_EXCEPTION 260 #if defined(_MSC_VER) 261 #define MSHADOW_ALIGNED(x) __declspec(align(x)) 263 #define MSHADOW_ALIGNED(x) __attribute__ ((aligned(x))) 271 #define MSHADOW_CUDA_CALL(func) \ 273 cudaError_t e = (func); \ 274 if (e == cudaErrorCudartUnloading) { \ 275 throw dmlc::Error(cudaGetErrorString(e)); \ 277 CHECK(e == cudaSuccess) \ 278 << "CUDA: " << cudaGetErrorString(e); \ 285 #define MSHADOW_CATCH_ERROR(func) \ 289 } catch (const dmlc::Error &e) { \ 290 std::string what = e.what(); \ 291 if (what.find("driver shutting down") == std::string::npos) { \ 292 LOG(ERROR) << "Ignore CUDA Error " << what; \ 300 #define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \ 301 MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \ 302 return float(a) OP float(b); \ 304 MSHADOW_XINLINE RTYPE operator OP(mshadow::bfloat::bf16_t a, mshadow::half::half_t b) { \ 305 return float(a) OP float(b); \ 325 #include "./logging.h" 331 const float kPi = 3.1415926f;
333 #if MSHADOW_INT64_TENSOR_SIZE == 1 367 template<
typename DType>
373 static const int kLanes = 1;
375 #if (CUDA_VERSION >= 8000) 376 static const cudaDataType_t kCudaFlag = CUDA_R_32F;
378 #if MSHADOW_USE_CUDNN 379 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT;
380 typedef float ScaleType;
387 static const int kLanes = 1;
389 #if (CUDA_VERSION >= 8000) 390 static const cudaDataType_t kCudaFlag = CUDA_R_64F;
392 #if MSHADOW_USE_CUDNN 393 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE;
394 typedef double ScaleType;
401 static const int kLanes = 1;
403 #if (CUDA_VERSION >= 8000) 404 static const cudaDataType_t kCudaFlag = CUDA_R_16F;
406 #if MSHADOW_USE_CUDNN 407 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF;
408 typedef float ScaleType;
415 static const int kLanes = 2;
420 static const int kLanes = 1;
425 static const int kLanes = 1;
427 #if (CUDA_VERSION >= 8000) 428 static const cudaDataType_t kCudaFlag = CUDA_R_8U;
430 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) 432 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
433 typedef uint8_t ScaleType;
440 static const int kLanes = 1;
442 #if (CUDA_VERSION >= 8000) 443 static const cudaDataType_t kCudaFlag = CUDA_R_8I;
445 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) 446 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
447 typedef int8_t ScaleType;
454 static const int kLanes = 1;
456 #if (CUDA_VERSION >= 8000) 457 static const cudaDataType_t kCudaFlag = CUDA_R_32I;
459 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6) 460 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT32;
461 typedef int32_t ScaleType;
468 static const int kLanes = 1;
473 static const int kLanes = 1;
499 static const index_t kNdim = 4;
500 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) 501 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
503 static const int kCudnnFlag = -1;
509 static const index_t kNdim = 4;
510 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) 511 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
513 static const int kCudnnFlag = -1;
522 static const index_t kNdim = 5;
523 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) 524 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
526 static const int kCudnnFlag = -1;
532 static const index_t kNdim = 5;
533 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4) 534 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
536 static const int kCudnnFlag = -1;
549 template<
typename DType>
557 template<
typename DType>
565 template<
typename DType>
573 template<
typename DType>
581 template<
typename DType>
593 template<
typename DType>
604 template<
typename DType>
/*!
 * \brief helper constant for BLAS calls: the alpha value, always 1.0f.
 * NOTE(review): listing line 609 falls inside the `=` saver (base.h:602) -- confirm.
 * \return 1.0f as default_real_t
 */
609 inline static default_real_t
AlphaBLAS(
void) {
return 1.0f; }
/*!
 * \brief helper constant for BLAS calls: the beta value, always 0.0f.
 * NOTE(review): listing line 611 falls inside the `=` saver (base.h:602) -- confirm.
 * \return 0.0f as default_real_t
 */
611 inline static default_real_t
BetaBLAS(
void) {
return 0.0f; }
618 template<
typename DType>
/*!
 * \brief helper constant for BLAS calls: the alpha value, always 1.0f.
 * NOTE(review): listing line 623 falls inside the `+=` saver (base.h:616) -- confirm.
 * \return 1.0f as default_real_t
 */
623 inline static default_real_t
AlphaBLAS(
void) {
return 1.0f; }
/*!
 * \brief helper constant for BLAS calls: the beta value, always 1.0f.
 * NOTE(review): listing line 625 falls inside the `+=` saver (base.h:616) -- confirm.
 * \return 1.0f as default_real_t
 */
625 inline static default_real_t
BetaBLAS(
void) {
return 1.0f; }
632 template<
typename DType>
/*!
 * \brief helper constant for BLAS calls: the alpha value, always -1.0f.
 * NOTE(review): listing line 637 falls inside the `-=` saver (base.h:630) -- confirm.
 * \return -1.0f as default_real_t
 */
637 inline static default_real_t
AlphaBLAS(
void) {
return -1.0f; }
/*!
 * \brief helper constant for BLAS calls: the beta value, always 1.0f.
 * NOTE(review): listing line 639 falls inside the `-=` saver (base.h:630) -- confirm.
 * \return 1.0f as default_real_t
 */
639 inline static default_real_t
BetaBLAS(
void) {
return 1.0f; }
646 template<
typename DType>
656 template<
typename DType>
665 #ifndef __CUDA_ARCH__ 673 namespace isnan_typed {
674 template<
typename DType>
699 namespace isinf_typed {
700 template<
typename DType>
729 template<
typename DType>
786 template<
typename DType>
788 return MinValue<DType>();
803 return half::half_t::Binary(
811 template<
typename DType>
868 template<
typename DType>
870 return MaxValue<DType>();
893 template<
typename DType>
898 template<
typename DType>
900 DType y = src - residual;
905 residual = (t - dst) - y;
910 template<
typename DType>
912 Reduce(dst_val, src_val);
915 template<
typename DType>
916 MSHADOW_XINLINE static void Merge(
volatile DType& dst_val,
volatile DType& dst_residual,
volatile DType& src_val,
volatile DType& src_residual) {
917 DType t1 = dst_val + src_val;
922 DType e = t1 - dst_val;
923 DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
925 dst_residual = t2 - (dst_val - t1);
929 template<
typename DType>
932 template<
typename DType>
938 template<
typename DType>
945 template<
typename DType>
952 template<
typename DType>
961 template<
typename DType>
964 if (!(dst >= src)) dst = src;
968 template<
typename DType>
973 template<
typename DType>
975 Reduce(dst_val, src_val);
978 template<
typename DType>
979 MSHADOW_XINLINE static void Merge(
volatile DType& dst_val,
volatile DType& dst_residual,
volatile DType& src_val,
volatile DType& src_residual) {
980 Reduce(dst_val, src_val);
983 template<
typename DType>
986 template<
typename DType>
992 template<
typename DType>
994 return redres == redsrc ? 1: 0;
999 template<
typename DType>
1001 initv = limits::NegInfValue<DType>();
1006 template<
typename DType>
1008 SetInitValue(initv);
1014 template<
typename DType>
1017 if (!(dst <= src)) dst = src;
1021 template<
typename DType>
1026 template<
typename DType>
1028 Reduce(dst_val, src_val);
1031 template<
typename DType>
1032 MSHADOW_XINLINE static void Merge(
volatile DType& dst_val,
volatile DType& dst_residual,
volatile DType& src_val,
volatile DType& src_residual) {
1033 Reduce(dst_val, src_val);
1036 template<
typename DType>
1039 template<
typename DType>
1045 template<
typename DType>
1047 return redres == redsrc ? 1: 0;
1052 template<
typename DType>
1054 initv = limits::PosInfValue<DType>();
1059 template<
typename DType>
1061 SetInitValue(initv);
1067 #define MSHADOW_TYPE_SWITCH(type, DType, ...) \ 1069 case mshadow::kFloat32: \ 1071 typedef float DType; \ 1075 case mshadow::kFloat64: \ 1077 typedef double DType; \ 1081 case mshadow::kFloat16: \ 1083 typedef mshadow::half::half_t DType; \ 1087 case mshadow::kBfloat16: \ 1089 typedef mshadow::bfloat::bf16_t DType; \ 1093 case mshadow::kUint8: \ 1095 typedef uint8_t DType; \ 1099 case mshadow::kInt8: \ 1101 typedef int8_t DType; \ 1105 case mshadow::kInt32: \ 1107 typedef int32_t DType; \ 1111 case mshadow::kInt64: \ 1113 typedef int64_t DType; \ 1118 LOG(FATAL) << "Unknown type enum " << type; \ 1121 #define MSHADOW_TYPE_SWITCH(type, DType, ...) \ 1123 case mshadow::kFloat32: \ 1125 typedef float DType; \ 1129 case mshadow::kFloat64: \ 1131 typedef double DType; \ 1135 case mshadow::kFloat16: \ 1137 typedef mshadow::half::half_t DType; \ 1141 case mshadow::kUint8: \ 1143 typedef uint8_t DType; \ 1147 case mshadow::kInt8: \ 1149 typedef int8_t DType; \ 1153 case mshadow::kInt32: \ 1155 typedef int32_t DType; \ 1159 case mshadow::kInt64: \ 1161 typedef int64_t DType; \ 1166 LOG(FATAL) << "Unknown type enum " << type; \ 1170 #define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \ 1172 case mshadow::kFloat32: \ 1174 typedef float DType; \ 1178 case mshadow::kFloat64: \ 1180 typedef double DType; \ 1184 case mshadow::kFloat16: \ 1186 typedef mshadow::half::half2_t DType; \ 1190 case mshadow::kUint8: \ 1192 typedef uint8_t DType; \ 1196 case mshadow::kInt32: \ 1198 typedef int32_t DType; \ 1202 case mshadow::kInt64: \ 1204 typedef int64_t DType; \ 1209 LOG(FATAL) << "Unknown type enum " << type; \ 1212 #define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \ 1214 case mshadow::kFloat32: \ 1216 typedef float DType; \ 1220 case mshadow::kFloat64: \ 1222 typedef double DType; \ 1227 LOG(FATAL) << "This operation only supports " \ 1228 "32-bit and 64-bit floating point"; \ 1231 #define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...) 
\ 1233 case mshadow::kFloat32: \ 1235 typedef float DType; \ 1239 case mshadow::kFloat64: \ 1241 typedef double DType; \ 1245 case mshadow::kFloat16: \ 1247 typedef mshadow::half::half_t DType; \ 1251 case mshadow::kUint8: \ 1252 LOG(FATAL) << "This operation only support " \ 1253 "floating point types not uint8"; \ 1255 case mshadow::kInt8: \ 1256 LOG(FATAL) << "This operation only support " \ 1257 "floating point types not int8"; \ 1259 case mshadow::kInt32: \ 1260 LOG(FATAL) << "This operation only support " \ 1261 "floating point types, not int32";\ 1263 case mshadow::kInt64: \ 1264 LOG(FATAL) << "This operation only support " \ 1265 "floating point types, not int64";\ 1268 LOG(FATAL) << "Unknown type enum " << type; \ 1272 #define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \ 1274 case mshadow::kFloat32: \ 1276 typedef float DType$; \ 1277 typedef float DLargeType$; \ 1281 case mshadow::kFloat64: \ 1283 typedef double DType$; \ 1284 typedef double DLargeType$; \ 1288 case mshadow::kFloat16: \ 1290 typedef mshadow::half::half_t DType$; \ 1291 typedef float DLargeType$; \ 1295 case mshadow::kBfloat16: \ 1297 typedef mshadow::bfloat::bf16_t DType$; \ 1298 typedef float DLargeType$; \ 1302 case mshadow::kUint8: \ 1303 LOG(FATAL) << "This operation only support " \ 1304 "floating point types not uint8"; \ 1306 case mshadow::kInt8: \ 1307 LOG(FATAL) << "This operation only support " \ 1308 "floating point types not int8"; \ 1310 case mshadow::kInt32: \ 1311 LOG(FATAL) << "This operation only support " \ 1312 "floating point types, not int32";\ 1314 case mshadow::kInt64: \ 1315 LOG(FATAL) << "This operation only support " \ 1316 "floating point types, not int64";\ 1319 LOG(FATAL) << "Unknown type enum " << type$; \ 1322 #define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) 
\ 1324 case mshadow::kFloat32: \ 1326 typedef float DType$; \ 1327 typedef float DLargeType$; \ 1331 case mshadow::kFloat64: \ 1333 typedef double DType$; \ 1334 typedef double DLargeType$; \ 1338 case mshadow::kFloat16: \ 1340 typedef mshadow::half::half_t DType$; \ 1341 typedef float DLargeType$; \ 1345 case mshadow::kUint8: \ 1346 LOG(FATAL) << "This operation only support " \ 1347 "floating point types not uint8"; \ 1349 case mshadow::kInt8: \ 1350 LOG(FATAL) << "This operation only support " \ 1351 "floating point types not int8"; \ 1353 case mshadow::kInt32: \ 1354 LOG(FATAL) << "This operation only support " \ 1355 "floating point types, not int32";\ 1357 case mshadow::kInt64: \ 1358 LOG(FATAL) << "This operation only support " \ 1359 "floating point types, not int64";\ 1362 LOG(FATAL) << "Unknown type enum " << type$; \ 1365 #define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \ 1367 case mshadow::kNCHW: \ 1369 const int Layout = kNCHW; \ 1373 case mshadow::kNHWC: \ 1375 const int Layout = kNHWC; \ 1379 case mshadow::kNCDHW: \ 1381 const int Layout = kNCDHW; \ 1385 case mshadow::kNDHWC: \ 1387 const int Layout = kNDHWC; \ 1392 LOG(FATAL) << "Unknown layout enum " << layout; \ 1399 #define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...) \ 1401 case mshadow::kInt64: \ 1403 typedef int64_t DType; \ 1408 LOG(FATAL) << "Unknown type enum " << type; \ 1411 #define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) 
\ 1413 case mshadow::kFloat32: \ 1415 typedef float DType; \ 1419 case mshadow::kFloat64: \ 1421 typedef double DType; \ 1425 case mshadow::kFloat16: \ 1427 typedef mshadow::half::half_t DType; \ 1431 case mshadow::kBfloat16: \ 1433 typedef mshadow::bfloat::bf16_t DType; \ 1437 case mshadow::kUint8: \ 1439 typedef uint8_t DType; \ 1443 case mshadow::kInt8: \ 1445 typedef int8_t DType; \ 1449 case mshadow::kInt32: \ 1451 typedef int32_t DType; \ 1455 case mshadow::kInt64: \ 1457 typedef int64_t DType; \ 1461 case mshadow::kBool: \ 1463 typedef bool DType; \ 1468 LOG(FATAL) << "Unknown type enum " << type; \ 1488 return "unsigned char";
1498 LOG(FATAL) <<
"Unknown type enum " << dtype;
1504 #endif // MSHADOW_BASE_H_ static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1040
#define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType,...)
Definition: base.h:1411
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1046
const int default_type_flag
type enum value for default real type
Definition: base.h:477
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:987
MSHADOW_XINLINE DType MaxValue(void)
maximum value of certain types
MSHADOW_XINLINE int8_t MaxValue< int8_t >(void)
maximum value of int8_t
Definition: base.h:840
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:639
MSHADOW_XINLINE float NegInfValue< float >(void)
negative infinity value of float
Definition: base.h:792
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:619
MSHADOW_XINLINE int MinValue< int32_t >(void)
minimum value of int32_t
Definition: base.h:763
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:633
MSHADOW_XINLINE float MinValue< float >(void)
minimum value of float
Definition: base.h:733
definition of vector float16, half2 type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:939
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &residual)
set the initial value during reduction
Definition: base.h:953
MSHADOW_XINLINE uint32_t MaxValue< uint32_t >(void)
maximum value of uint32_t
Definition: base.h:860
save to saver: =
Definition: base.h:602
MSHADOW_XINLINE uint8_t MinValue< uint8_t >(void)
minimum value of uint8_t
Definition: base.h:753
MSHADOW_XINLINE bool IsInf(volatile DType val)
Definition: base.h:701
definition of bfloat type.
definition of half (float16) type.
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:993
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:969
#define MSHADOW_HALF_EXPONENT_BITS
Definition: half.h:372
divide operator
Definition: base.h:571
MSHADOW_XINLINE DType NegInfValue(void)
negative infinity of certain types
Definition: base.h:787
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:647
MSHADOW_XINLINE float PosInfValue< float >(void)
positive infinity value of float
Definition: base.h:874
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:946
op::right OPType
corresponding binary operator type
Definition: base.h:613
op::minus OPType
corresponding binary operator type
Definition: base.h:641
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:574
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val)
Definition: base.h:691
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1007
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:1015
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:962
identity function that maps a real number to itself
Definition: base.h:591
op::mul OPType
corresponding binary operator type
Definition: base.h:651
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:605
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:637
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:933
MSHADOW_XINLINE int8_t MinValue< int8_t >(void)
minimum value of int8_t
Definition: base.h:758
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:984
MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val)
Definition: base.h:717
MSHADOW_XINLINE bool MinValue< bool >(void)
minimum value of bool
Definition: base.h:773
#define MSHADOW_XINLINE
Definition: base.h:223
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &residual)
do stable reduction into dst
Definition: base.h:899
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:657
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:329
MSHADOW_XINLINE double MaxValue< double >(void)
maximum value of double
Definition: base.h:820
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:675
const int default_layout
default layout for 4d tensor
Definition: base.h:518
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1060
LayoutFlag
Definition: base.h:480
get rhs
Definition: base.h:579
#define MSHADOW_BF16_MAX
Definition: bfloat.h:183
std::string dtype_string(const int dtype)
Definition: base.h:1479
minus to saver: -=
Definition: base.h:630
MSHADOW_XINLINE int MaxValue< int32_t >(void)
maximum value of int32_t
Definition: base.h:845
int32_t index_t
type that will be used for index
Definition: base.h:336
multiply to saver: *=
Definition: base.h:644
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:558
const float kPi
pi
Definition: base.h:331
op::plus OPType
corresponding binary operator type
Definition: base.h:627
MSHADOW_XINLINE float MaxValue< float >(void)
maximum value of float
Definition: base.h:815
float default_real_t
floating-point type that mshadow uses by default
Definition: base.h:348
const int default_layout_5d
default layout for 5d tensor
Definition: base.h:541
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:1022
MSHADOW_XINLINE int64_t MinValue< int64_t >(void)
minimum value of int64_t
Definition: base.h:768
minimum reducer
Definition: base.h:1012
MSHADOW_XINLINE double MinValue< double >(void)
minimum value of double
Definition: base.h:738
#define MSHADOW_HALF_MIN
overloaded + operator for half_t
Definition: half.h:369
MSHADOW_XINLINE DType MinValue(void)
minimum value of certain types
divide to saver: /=
Definition: base.h:654
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:550
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)
Definition: base.h:300
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1037
MSHADOW_XINLINE DType PosInfValue(void)
positive infinity of certain types
Definition: base.h:869
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1472
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:974
MSHADOW_XINLINE double NegInfValue< double >(void)
negative infinity value of double
Definition: base.h:797
static MSHADOW_XINLINE DType Map(DType a)
map a to result using defined operation
Definition: base.h:594
maximum reducer
Definition: base.h:959
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:916
TypeFlag
data type flag
Definition: base.h:351
MSHADOW_XINLINE int64_t MaxValue< int64_t >(void)
maximum value of int64_t
Definition: base.h:850
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:566
plus operator
Definition: base.h:555
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:609
add to saver: +=
Definition: base.h:616
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1032
sum reducer
Definition: base.h:891
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:911
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:625
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1000
MSHADOW_XINLINE unsigned int MinValue< unsigned int >(void)
minimum value of unsigned int
Definition: base.h:778
mul operator
Definition: base.h:547
#define MSHADOW_HALF_SIGN_BIT
Definition: half.h:371
#define MSHADOW_HALF_MAX
Definition: half.h:370
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:611
overloaded + operator between half_t and bf16_t
Definition: base.h:327
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1053
MSHADOW_XINLINE bool MaxValue< bool >(void)
maximum value of bool
Definition: base.h:855
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:894
MSHADOW_XINLINE uint8_t MaxValue< uint8_t >(void)
maximum value of uint8_t
Definition: base.h:835
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:582
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1027
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:979
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:930
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:623
index_t openmp_index_t
OpenMP index type for Linux
Definition: base.h:344
minus operator
Definition: base.h:563
MSHADOW_XINLINE double PosInfValue< double >(void)
positive infinity value of double
Definition: base.h:879
op::div OPType
corresponding binary operator type
Definition: base.h:661
#define MSHADOW_BF16_MIN
overloaded + operator for bf16_t
Definition: bfloat.h:182