Go to the documentation of this file.
26 #ifndef MSHADOW_BASE_H_
27 #define MSHADOW_BASE_H_
29 #ifndef _CRT_SECURE_NO_WARNINGS
30 #define _CRT_SECURE_NO_WARNINGS
32 #ifndef _CRT_SECURE_NO_DEPRECATE
33 #define _CRT_SECURE_NO_DEPRECATE
50 typedef signed char int8_t;
52 typedef __int16 int16_t;
53 typedef __int32 int32_t;
54 typedef __int64 int64_t;
55 typedef unsigned char uint8_t;
56 typedef unsigned __int16 uint16_t;
57 typedef unsigned __int32 uint32_t;
58 typedef unsigned __int64 uint64_t;
68 #ifndef MSHADOW_STAND_ALONE
69 #define MSHADOW_STAND_ALONE 0
72 #ifndef MSHADOW_ALLOC_PAD
73 #define MSHADOW_ALLOC_PAD true
83 #ifndef MSHADOW_MIN_PAD_RATIO
84 #define MSHADOW_MIN_PAD_RATIO 2
87 #if MSHADOW_STAND_ALONE
88 #define MSHADOW_USE_CBLAS 0
89 #define MSHADOW_USE_MKL 0
90 #define MSHADOW_USE_CUDA 0
97 #ifndef MSHADOW_FORCE_STREAM
98 #define MSHADOW_FORCE_STREAM 1
102 #ifndef MSHADOW_USE_CBLAS
103 #define MSHADOW_USE_CBLAS 0
106 #ifndef MSHADOW_USE_MKL
107 #define MSHADOW_USE_MKL 1
114 #ifndef MSHADOW_USE_CUDA
115 #define MSHADOW_USE_CUDA 1
121 #ifndef MSHADOW_USE_CUDNN
122 #define MSHADOW_USE_CUDNN 0
128 #ifndef MSHADOW_USE_CUTENSOR
129 #define MSHADOW_USE_CUTENSOR 0
135 #ifndef MSHADOW_USE_CUSOLVER
136 #define MSHADOW_USE_CUSOLVER MSHADOW_USE_CUDA
143 #ifndef MSHADOW_OLD_CUDA
144 #define MSHADOW_OLD_CUDA 0
148 #ifndef MSHADOW_USE_SSE
149 #define MSHADOW_USE_SSE 1
153 #ifndef MSHADOW_USE_F16C
154 #if defined(_MSC_VER) || defined(__CUDACC__)
155 #define MSHADOW_USE_F16C 0
156 #elif defined(__clang__) && \
157 ((__clang_major__ < 8) || ((__clang_major__ == 8) && (__clang_minor__ < 1)))
158 #define MSHADOW_USE_F16C 0
160 #define MSHADOW_USE_F16C 1
165 #ifndef MSHADOW_USE_NVML
166 #define MSHADOW_USE_NVML 0
170 #undef MSHADOW_USE_SSE
171 #define MSHADOW_USE_SSE 0
174 #if MSHADOW_USE_CBLAS
178 #elif MSHADOW_USE_MKL
179 #if MSHADOW_INT64_TENSOR_SIZE == 1
182 #define MKL_INT int64_t
183 #define MKL_UINT uint64_t
185 #include <mkl_blas.h>
186 #include <mkl_cblas.h>
188 #include <mkl_vsl_functions.h>
189 #include <mkl_version.h>
194 #include <cublas_v2.h>
198 #if MSHADOW_USE_CUDNN == 1
202 #if MSHADOW_USE_CUTENSOR == 1
203 #include <cutensor.h>
206 #if MSHADOW_USE_CUSOLVER == 1
207 #include <cusolverDn.h>
216 #ifdef MSHADOW_XINLINE
217 #error "MSHADOW_XINLINE must not be defined"
220 #define MSHADOW_FORCE_INLINE __forceinline
221 #pragma warning(disable : 4068)
223 #define MSHADOW_FORCE_INLINE inline __attribute__((always_inline))
226 #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE __device__ __host__
228 #define MSHADOW_XINLINE MSHADOW_FORCE_INLINE
231 #define MSHADOW_CINLINE MSHADOW_FORCE_INLINE
239 #ifndef MSHADOW_DEFAULT_DTYPE
240 #define MSHADOW_DEFAULT_DTYPE = ::mshadow::default_real_t
246 #ifndef MSHADOW_USE_GLOG
247 #define MSHADOW_USE_GLOG DMLC_USE_GLOG
248 #endif // MSHADOW_USE_GLOG
250 #define MSHADOW_THROW_EXCEPTION noexcept(false)
251 #define MSHADOW_NO_EXCEPTION noexcept(true)
253 #if defined(_MSC_VER)
254 #define MSHADOW_ALIGNED(x) __declspec(align(x))
256 #define MSHADOW_ALIGNED(x) __attribute__ ((aligned(x)))
264 #define MSHADOW_CUDA_CALL(func) \
266 cudaError_t e = (func); \
267 if (e == cudaErrorCudartUnloading) { \
268 throw dmlc::Error(cudaGetErrorString(e)); \
270 CHECK_EQ(e, cudaSuccess) \
271 << "CUDA: " << cudaGetErrorString(e); \
278 #define MSHADOW_CATCH_ERROR(func) \
282 } catch (const dmlc::Error &e) { \
283 std::string what = e.what(); \
284 if (what.find("driver shutting down") == std::string::npos) { \
285 LOG(ERROR) << "Ignore CUDA Error " << what; \
292 #define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \
293 MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \
294 return float(a) OP float(b); \
296 MSHADOW_XINLINE RTYPE operator OP(mshadow::bfloat::bf16_t a, mshadow::half::half_t b) { \
297 return float(a) OP float(b); \
317 #include "dmlc/logging.h"
323 const float kPi = 3.1415926f;
325 #if MSHADOW_INT64_TENSOR_SIZE == 1
340 #if (MSHADOW_USE_MKL && MXNET_USE_LAPACK) || MXNET_USE_ILP64_LAPACKE
367 template<
typename DType>
373 static const int kLanes = 1;
375 #if (CUDA_VERSION >= 8000)
376 static const cudaDataType_t kCudaFlag = CUDA_R_32F;
378 #if MSHADOW_USE_CUDNN
379 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT;
380 typedef float ScaleType;
387 static const int kLanes = 1;
389 #if (CUDA_VERSION >= 8000)
390 static const cudaDataType_t kCudaFlag = CUDA_R_64F;
392 #if MSHADOW_USE_CUDNN
393 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE;
394 typedef double ScaleType;
401 static const int kLanes = 1;
403 #if (CUDA_VERSION >= 8000)
404 static const cudaDataType_t kCudaFlag = CUDA_R_16F;
406 #if MSHADOW_USE_CUDNN
407 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF;
408 typedef float ScaleType;
415 static const int kLanes = 1;
420 static const int kLanes = 1;
422 #if (CUDA_VERSION >= 8000)
423 static const cudaDataType_t kCudaFlag = CUDA_R_8U;
425 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
427 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
428 typedef uint8_t ScaleType;
435 static const int kLanes = 1;
437 #if (CUDA_VERSION >= 8000)
438 static const cudaDataType_t kCudaFlag = CUDA_R_8I;
440 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
441 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT8;
442 typedef int8_t ScaleType;
449 static const int kLanes = 1;
451 #if (CUDA_VERSION >= 8000)
452 static const cudaDataType_t kCudaFlag = CUDA_R_32I;
454 #if (MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 6)
455 static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_INT32;
456 typedef int32_t ScaleType;
463 static const int kLanes = 1;
468 static const int kLanes = 1;
473 static const int kLanes = 1;
478 static const int kLanes = 1;
483 static const int kLanes = 1;
488 static const int kLanes = 1;
515 switch (layoutstr.length()) {
517 if (layoutstr ==
"NHWC")
519 if (layoutstr ==
"NCHW")
521 if (layoutstr ==
"CHWN")
525 if (layoutstr ==
"NWC")
527 if (layoutstr ==
"NCW")
529 if (layoutstr ==
"CWN")
533 if (layoutstr ==
"NDHWC")
535 if (layoutstr ==
"NCDHW")
537 if (layoutstr ==
"CDHWN")
578 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
579 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
581 static const int kCudnnFlag = -1;
588 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
589 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
591 static const int kCudnnFlag = -1;
601 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
602 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NCHW;
604 static const int kCudnnFlag = -1;
611 #if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1 && CUDNN_MAJOR >= 4)
612 static const cudnnTensorFormat_t kCudnnFlag = CUDNN_TENSOR_NHWC;
614 static const int kCudnnFlag = -1;
627 template<
typename DType>
635 template<
typename DType>
643 template<
typename DType>
651 template<
typename DType>
659 template<
typename DType>
671 template<
typename DType>
682 template<
typename DType>
696 template<
typename DType>
710 template<
typename DType>
724 template<
typename DType>
734 template<
typename DType>
743 #ifndef __CUDA_ARCH__
751 namespace isnan_typed {
752 template<
typename DType>
781 namespace isinf_typed {
782 template<
typename DType>
815 template<
typename DType>
872 template<
typename DType>
874 return MinValue<DType>();
889 return half::half_t::Binary(
902 template<
typename DType>
952 return std::numeric_limits<uint32_t>::max();
959 template<
typename DType>
961 return MaxValue<DType>();
989 template<
typename DType>
994 template<
typename DType>
996 DType y = src - residual;
1001 residual = (t - dst) - y;
1006 template<
typename DType>
1008 Reduce(dst_val, src_val);
1011 template<
typename DType>
1012 MSHADOW_XINLINE static void Merge(
volatile DType& dst_val,
volatile DType& dst_residual,
volatile DType& src_val,
volatile DType& src_residual) {
1013 DType t1 = dst_val + src_val;
1018 DType e = t1 - dst_val;
1019 DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
1021 dst_residual = t2 - (dst_val - t1);
1025 template<
typename DType>
1028 template<
typename DType>
1034 template<
typename DType>
1041 template<
typename DType>
1048 template<
typename DType>
1057 template<
typename DType>
1060 if (!(dst >= src)) dst = src;
1064 template<
typename DType>
1069 template<
typename DType>
1071 Reduce(dst_val, src_val);
1074 template<
typename DType>
1075 MSHADOW_XINLINE static void Merge(
volatile DType& dst_val,
volatile DType& dst_residual,
volatile DType& src_val,
volatile DType& src_residual) {
1076 Reduce(dst_val, src_val);
1079 template<
typename DType>
1082 template<
typename DType>
1088 template<
typename DType>
1090 return redres == redsrc ? 1: 0;
1095 template<
typename DType>
1097 initv = limits::NegInfValue<DType>();
1102 template<
typename DType>
1110 template<
typename DType>
1113 if (!(dst <= src)) dst = src;
1117 template<
typename DType>
1122 template<
typename DType>
1124 Reduce(dst_val, src_val);
1127 template<
typename DType>
1128 MSHADOW_XINLINE static void Merge(
volatile DType& dst_val,
volatile DType& dst_residual,
volatile DType& src_val,
volatile DType& src_residual) {
1129 Reduce(dst_val, src_val);
1132 template<
typename DType>
1135 template<
typename DType>
1141 template<
typename DType>
1143 return redres == redsrc ? 1: 0;
1148 template<
typename DType>
1150 initv = limits::PosInfValue<DType>();
1155 template<
typename DType>
1163 #define MSHADOW_TYPE_SWITCH(type, DType, ...) \
1165 case mshadow::kFloat32: \
1167 typedef float DType; \
1171 case mshadow::kFloat64: \
1173 typedef double DType; \
1177 case mshadow::kFloat16: \
1179 typedef mshadow::half::half_t DType; \
1183 case mshadow::kBfloat16: \
1185 typedef mshadow::bfloat::bf16_t DType; \
1189 case mshadow::kUint8: \
1191 typedef uint8_t DType; \
1195 case mshadow::kInt8: \
1197 typedef int8_t DType; \
1201 case mshadow::kInt32: \
1203 typedef int32_t DType; \
1207 case mshadow::kInt64: \
1209 typedef int64_t DType; \
1213 case mshadow::kBool: \
1214 LOG(FATAL) << "This operation does not " \
1215 "support bool type"; \
1217 case mshadow::kInt16: \
1218 LOG(FATAL) << "This operation does not " \
1219 "support int16 type"; \
1221 case mshadow::kUint16: \
1222 LOG(FATAL) << "This operation does not " \
1223 "support uint16 type"; \
1225 case mshadow::kUint32: \
1226 LOG(FATAL) << "This operation does not " \
1227 "support uint32 type"; \
1229 case mshadow::kUint64: \
1230 LOG(FATAL) << "This operation does not " \
1231 "support uint64 type"; \
1234 LOG(FATAL) << "Unknown type enum " << type; \
1237 #define MSHADOW_TYPE_SWITCH(type, DType, ...) \
1239 case mshadow::kFloat32: \
1241 typedef float DType; \
1245 case mshadow::kFloat64: \
1247 typedef double DType; \
1251 case mshadow::kFloat16: \
1253 typedef mshadow::half::half_t DType; \
1257 case mshadow::kUint8: \
1259 typedef uint8_t DType; \
1263 case mshadow::kInt8: \
1265 typedef int8_t DType; \
1269 case mshadow::kInt32: \
1271 typedef int32_t DType; \
1275 case mshadow::kInt64: \
1277 typedef int64_t DType; \
1281 case mshadow::kBool: \
1282 LOG(FATAL) << "This operation does not " \
1283 "support bool type"; \
1285 case mshadow::kInt16: \
1286 LOG(FATAL) << "This operation does not " \
1287 "support int16 type"; \
1289 case mshadow::kUint16: \
1290 LOG(FATAL) << "This operation does not " \
1291 "support uint16 type"; \
1293 case mshadow::kUint32: \
1294 LOG(FATAL) << "This operation does not " \
1295 "support uint32 type"; \
1297 case mshadow::kUint64: \
1298 LOG(FATAL) << "This operation does not " \
1299 "support uint64 type"; \
1302 LOG(FATAL) << "Unknown type enum " << type; \
1306 #define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \
1308 case mshadow::kFloat32: \
1310 typedef float DType; \
1314 case mshadow::kFloat64: \
1316 typedef double DType; \
1321 LOG(FATAL) << "This operation only supports " \
1322 "32-bit and 64-bit floating point"; \
1325 #define MSHADOW_REAL_TYPE_SWITCH(type, DType, ...) \
1327 case mshadow::kFloat32: \
1329 typedef float DType; \
1333 case mshadow::kFloat64: \
1335 typedef double DType; \
1339 case mshadow::kFloat16: \
1341 typedef mshadow::half::half_t DType; \
1345 case mshadow::kUint8: \
1346 LOG(FATAL) << "This operation only support " \
1347 "floating point types not uint8"; \
1349 case mshadow::kInt8: \
1350 LOG(FATAL) << "This operation only support " \
1351 "floating point types not int8"; \
1353 case mshadow::kInt32: \
1354 LOG(FATAL) << "This operation only support " \
1355 "floating point types, not int32";\
1357 case mshadow::kInt64: \
1358 LOG(FATAL) << "This operation only support " \
1359 "floating point types, not int64";\
1361 case mshadow::kBool: \
1362 LOG(FATAL) << "This operation only support " \
1363 "floating point types, not bool"; \
1365 case mshadow::kInt16: \
1366 LOG(FATAL) << "This operation only support " \
1367 "floating point types, not int16";\
1369 case mshadow::kUint16: \
1370 LOG(FATAL) << "This operation only support " \
1371 "floating point types not uint16";\
1373 case mshadow::kUint32: \
1374 LOG(FATAL) << "This operation only support " \
1375 "floating point types not uint32";\
1377 case mshadow::kUint64: \
1378 LOG(FATAL) << "This operation only support " \
1379 "floating point types not uint64";\
1382 LOG(FATAL) << "Unknown type enum " << type; \
1386 #define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
1388 case mshadow::kFloat32: \
1390 typedef float DType$; \
1391 typedef float DLargeType$; \
1395 case mshadow::kFloat64: \
1397 typedef double DType$; \
1398 typedef double DLargeType$; \
1402 case mshadow::kFloat16: \
1404 typedef mshadow::half::half_t DType$; \
1405 typedef float DLargeType$; \
1409 case mshadow::kBfloat16: \
1411 typedef mshadow::bfloat::bf16_t DType$; \
1412 typedef float DLargeType$; \
1416 case mshadow::kUint8: \
1417 LOG(FATAL) << "This operation only support " \
1418 "floating point types not uint8"; \
1420 case mshadow::kInt8: \
1421 LOG(FATAL) << "This operation only support " \
1422 "floating point types not int8"; \
1424 case mshadow::kInt32: \
1425 LOG(FATAL) << "This operation only support " \
1426 "floating point types, not int32";\
1428 case mshadow::kInt64: \
1429 LOG(FATAL) << "This operation only support " \
1430 "floating point types, not int64";\
1432 case mshadow::kBool: \
1433 LOG(FATAL) << "This operation only support " \
1434 "floating point types, not bool"; \
1436 case mshadow::kInt16: \
1437 LOG(FATAL) << "This operation only support " \
1438 "floating point types, not int16";\
1440 case mshadow::kUint16: \
1441 LOG(FATAL) << "This operation only support " \
1442 "floating point types not uint16";\
1444 case mshadow::kUint32: \
1445 LOG(FATAL) << "This operation only support " \
1446 "floating point types not uint32";\
1448 case mshadow::kUint64: \
1449 LOG(FATAL) << "This operation only support " \
1450 "floating point types not uint64";\
1453 LOG(FATAL) << "Unknown type enum " << type$; \
1456 #define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \
1458 case mshadow::kFloat32: \
1460 typedef float DType$; \
1461 typedef float DLargeType$; \
1465 case mshadow::kFloat64: \
1467 typedef double DType$; \
1468 typedef double DLargeType$; \
1472 case mshadow::kFloat16: \
1474 typedef mshadow::half::half_t DType$; \
1475 typedef float DLargeType$; \
1479 case mshadow::kUint8: \
1480 LOG(FATAL) << "This operation only support " \
1481 "floating point types not uint8"; \
1483 case mshadow::kInt8: \
1484 LOG(FATAL) << "This operation only support " \
1485 "floating point types not int8"; \
1487 case mshadow::kInt32: \
1488 LOG(FATAL) << "This operation only support " \
1489 "floating point types, not int32";\
1491 case mshadow::kInt64: \
1492 LOG(FATAL) << "This operation only support " \
1493 "floating point types, not int64";\
1495 case mshadow::kBool: \
1496 LOG(FATAL) << "This operation only support " \
1497 "floating point types, not bool"; \
1499 case mshadow::kInt16: \
1500 LOG(FATAL) << "This operation only support " \
1501 "floating point types, not int16";\
1503 case mshadow::kUint16: \
1504 LOG(FATAL) << "This operation only support " \
1505 "floating point types not uint16";\
1507 case mshadow::kUint32: \
1508 LOG(FATAL) << "This operation only support " \
1509 "floating point types not uint32";\
1511 case mshadow::kUint64: \
1512 LOG(FATAL) << "This operation only support " \
1513 "floating point types not uint64";\
1516 LOG(FATAL) << "Unknown type enum " << type$; \
1519 #define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \
1521 case mshadow::kNCHW: \
1523 const int Layout = kNCHW; \
1527 case mshadow::kNHWC: \
1529 const int Layout = kNHWC; \
1533 case mshadow::kNCDHW: \
1535 const int Layout = kNCDHW; \
1539 case mshadow::kNDHWC: \
1541 const int Layout = kNDHWC; \
1546 LOG(FATAL) << "Unknown layout enum " << layout; \
1553 #define MSHADOW_IDX_TYPE_SWITCH(type, DType, ...) \
1555 case mshadow::kInt64: \
1557 typedef int64_t DType; \
1562 LOG(FATAL) << "Unknown type enum " << type; \
1565 #define MSHADOW_TYPE_SWITCH_WITH_BOOL(type, DType, ...) \
1567 case mshadow::kFloat32: \
1569 typedef float DType; \
1573 case mshadow::kFloat64: \
1575 typedef double DType; \
1579 case mshadow::kFloat16: \
1581 typedef mshadow::half::half_t DType; \
1585 case mshadow::kBfloat16: \
1587 typedef mshadow::bfloat::bf16_t DType; \
1591 case mshadow::kUint8: \
1593 typedef uint8_t DType; \
1597 case mshadow::kInt8: \
1599 typedef int8_t DType; \
1603 case mshadow::kInt32: \
1605 typedef int32_t DType; \
1609 case mshadow::kInt64: \
1611 typedef int64_t DType; \
1615 case mshadow::kBool: \
1617 typedef bool DType; \
1621 case mshadow::kInt16: \
1622 LOG(FATAL) << "This operation does not " \
1623 "support int16 type"; \
1625 case mshadow::kUint16: \
1626 LOG(FATAL) << "This operation does not " \
1627 "support uint16 type"; \
1629 case mshadow::kUint32: \
1630 LOG(FATAL) << "This operation does not " \
1631 "support uint32 type"; \
1633 case mshadow::kUint64: \
1634 LOG(FATAL) << "This operation does not " \
1635 "support uint64 type"; \
1638 LOG(FATAL) << "Unknown type enum " << type; \
1641 #define MSHADOW_TYPE_SWITCH_EXT(type, DType, ...) \
1643 case mshadow::kFloat32: \
1645 typedef float DType; \
1649 case mshadow::kFloat64: \
1651 typedef double DType; \
1655 case mshadow::kFloat16: \
1657 typedef mshadow::half::half_t DType; \
1661 case mshadow::kBfloat16: \
1663 typedef mshadow::bfloat::bf16_t DType; \
1667 case mshadow::kUint8: \
1669 typedef uint8_t DType; \
1673 case mshadow::kInt8: \
1675 typedef int8_t DType; \
1679 case mshadow::kInt32: \
1681 typedef int32_t DType; \
1685 case mshadow::kInt64: \
1687 typedef int64_t DType; \
1691 case mshadow::kInt16: \
1693 typedef int16_t DType; \
1697 case mshadow::kUint16: \
1699 typedef uint16_t DType; \
1703 case mshadow::kUint32: \
1705 typedef uint32_t DType; \
1709 case mshadow::kUint64: \
1711 typedef uint64_t DType; \
1716 LOG(FATAL) << "Unknown type enum " << type; \
1719 #define MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(type, DType, ...) \
1721 case mshadow::kFloat32: \
1723 typedef float DType; \
1727 case mshadow::kFloat64: \
1729 typedef double DType; \
1733 case mshadow::kFloat16: \
1735 typedef mshadow::half::half_t DType; \
1739 case mshadow::kBfloat16: \
1741 typedef mshadow::bfloat::bf16_t DType; \
1745 case mshadow::kUint8: \
1747 typedef uint8_t DType; \
1751 case mshadow::kInt8: \
1753 typedef int8_t DType; \
1757 case mshadow::kInt32: \
1759 typedef int32_t DType; \
1763 case mshadow::kInt64: \
1765 typedef int64_t DType; \
1769 case mshadow::kBool: \
1771 typedef bool DType; \
1775 case mshadow::kInt16: \
1777 typedef int16_t DType; \
1781 case mshadow::kUint16: \
1783 typedef uint16_t DType; \
1787 case mshadow::kUint32: \
1789 typedef uint32_t DType; \
1793 case mshadow::kUint64: \
1795 typedef uint64_t DType; \
1800 LOG(FATAL) << "Unknown type enum " << type; \
1820 return "unsigned char";
1832 return "unsigned short";
1834 return "unsigned int";
1836 return "unsigned long long";
1838 LOG(FATAL) <<
"Unknown type enum " << dtype;
1844 #endif // MSHADOW_BASE_H_
op::right OPType
corresponding binary operator type
Definition: base.h:691
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:715
@ kNCHW
Definition: base.h:501
maximum reducer
Definition: base.h:1055
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1136
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1128
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1029
MSHADOW_XINLINE DType NegInfValue(void)
negative infinity of certain types
Definition: base.h:873
MSHADOW_XINLINE uint32_t MaxValue< uint32_t >(void)
maximum value of uint32_t
Definition: base.h:951
MSHADOW_XINLINE bool MaxValue< bool >(void)
maximum value of bool
Definition: base.h:946
index_t openmp_index_t
OpenMP index type for Linux
Definition: base.h:336
MSHADOW_XINLINE uint8_t MinValue< uint8_t >(void)
minimum value of uint8_t
Definition: base.h:839
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:628
const int default_type_flag
type enum value for default real type
Definition: base.h:492
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1103
MSHADOW_XINLINE float PosInfValue< float >(void)
positive infinity value of float
Definition: base.h:965
MSHADOW_XINLINE uint8_t MaxValue< uint8_t >(void)
maximum value of uint8_t
Definition: base.h:926
@ kUint16
Definition: base.h:361
@ kUint64
Definition: base.h:363
@ kNWC
Definition: base.h:506
@ kNCDHW
Definition: base.h:509
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1096
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:711
MSHADOW_XINLINE int8_t MaxValue< int8_t >(void)
maximum value of int8_t
Definition: base.h:931
op::div OPType
corresponding binary operator type
Definition: base.h:739
MSHADOW_XINLINE float MinValue< float >(void)
minimum value of float
Definition: base.h:819
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1075
@ kInt8
Definition: base.h:357
@ kUint32
Definition: base.h:362
op::mul OPType
corresponding binary operator type
Definition: base.h:729
minus operator
Definition: base.h:641
static MSHADOW_XINLINE void Finalize(volatile DType &dst, volatile DType &residual)
finalize reduction
Definition: base.h:1083
#define MSHADOW_BF16_MAX
Definition: bfloat.h:182
std::string toString(LayoutFlag layout)
Definition: base.h:545
#define MSHADOW_XINLINE
Definition: base.h:228
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &dst_residual, volatile DType &src_val, volatile DType &src_residual)
combine the results of two reducers
Definition: base.h:1012
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:687
static MSHADOW_XINLINE DType Map(DType a)
map a to result using defined operation
Definition: base.h:672
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:753
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:703
divide operator
Definition: base.h:649
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:683
#define MSHADOW_BF16_SIGN_BIT
Definition: bfloat.h:183
sum reducer
Definition: base.h:987
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:717
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1080
#define MSHADOW_BF16_EXPONENT_BITS
Definition: bfloat.h:184
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:735
const int default_layout
default layout for 4d tensor
Definition: base.h:596
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:725
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:1118
const char limits[]
Definition: util-inl.h:594
MSHADOW_XINLINE int MaxValue< int32_t >(void)
maximum value of int32_t
Definition: base.h:936
@ kBool
Definition: base.h:359
MSHADOW_XINLINE int8_t MinValue< int8_t >(void)
minimum value of int8_t
Definition: base.h:844
static default_real_t BetaBLAS(void)
helper constant to use BLAS, beta
Definition: base.h:689
@ kNHWC
Definition: base.h:502
const int default_layout_5d
default layout for 5d tensor
Definition: base.h:619
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:1058
@ kCWN
Definition: base.h:507
@ kCHWN
Definition: base.h:503
@ kFloat64
Definition: base.h:353
plus operator
Definition: base.h:633
static default_real_t AlphaBLAS(void)
helper constant to use BLAS, alpha
Definition: base.h:701
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1007
MSHADOW_XINLINE double MinValue< double >(void)
minimum value of double
Definition: base.h:824
save to saver: +=
Definition: base.h:694
LayoutFlag
Definition: base.h:498
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:636
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:990
@ kInt16
Definition: base.h:360
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1142
MSHADOW_XINLINE DType PosInfValue(void)
positive infinity of certain types
Definition: base.h:960
#define MSHADOW_HALF_MAX
Definition: half.h:369
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1123
save to saver: =
Definition: base.h:680
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1026
@ kCDHWN
Definition: base.h:511
MSHADOW_XINLINE double NegInfValue< double >(void)
negative infinity value of double
Definition: base.h:883
get rhs
Definition: base.h:657
MSHADOW_XINLINE DType MinValue(void)
minimum value of certain types
const unsigned kRandBufferSize
buffer size for each random number generator
Definition: base.h:321
MSHADOW_XINLINE bool IsInf(volatile DType val)
Definition: base.h:783
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:644
static MSHADOW_XINLINE void Merge(volatile DType &dst_val, volatile DType &src_val)
combine the results of two reducers
Definition: base.h:1070
multiply to saver: *=
Definition: base.h:722
@ kInt64
Definition: base.h:358
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1035
int32_t index_t
type that will be used for index
Definition: base.h:328
#define MSHADOW_HALF_SIGN_BIT
Definition: half.h:370
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src)
do reduction into dst
Definition: base.h:1111
const float kPi
pi
Definition: base.h:323
@ kInt32
Definition: base.h:356
MSHADOW_XINLINE bool MinValue< bool >(void)
minimum value of bool
Definition: base.h:859
TypeFlag
data type flag
Definition: base.h:351
int lapack_index_t
Definition: base.h:344
MSHADOW_XINLINE int64_t MinValue< int64_t >(void)
minimum value of int64_t
Definition: base.h:854
MSHADOW_XINLINE double MaxValue< double >(void)
maximum value of double
Definition: base.h:911
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1042
size_t mshadow_sizeof(int type)
get data type size from type enum
Definition: base.h:1804
@ kNDHWC
Definition: base.h:510
op::plus OPType
corresponding binary operator type
Definition: base.h:705
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &none)
do reduction into dst
Definition: base.h:1065
#define MSHADOW_BF16_MIN
overloaded + operator for bf16_t
Definition: bfloat.h:181
std::string dtype_string(const int dtype)
Definition: base.h:1811
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &residual)
set the initial value during reduction
Definition: base.h:1049
overloaded + operator between half_t and bf16_t
Definition: base.h:319
MSHADOW_XINLINE float MaxValue< float >(void)
maximum value of float
Definition: base.h:906
float default_real_t
floating-point type that mshadow uses by default
Definition: base.h:348
identity function that maps a real number to itself
Definition: base.h:669
op::minus OPType
corresponding binary operator type
Definition: base.h:719
definition of half (float16) type.
minus to saver: -=
Definition: base.h:708
@ kUNKNOWN
Definition: base.h:499
static MSHADOW_XINLINE void SetInitValue(DType &initv)
set the initial value during reduction
Definition: base.h:1149
static MSHADOW_XINLINE void Finalize(volatile DType &dst)
finalize reduction
Definition: base.h:1133
#define MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(type, DType,...)
Definition: base.h:1719
MSHADOW_XINLINE int64_t MaxValue< int64_t >(void)
maximum value of int64_t
Definition: base.h:941
#define MSHADOW_HALF_MIN
overloaded + operator for half_t
Definition: half.h:368
static MSHADOW_XINLINE void SetInitValue(DType &initv, DType &none)
set the initial value during reduction
Definition: base.h:1156
static MSHADOW_XINLINE void Save(DType &a, DType b)
save b to a using save method
Definition: base.h:697
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)
Definition: base.h:292
@ kUint8
Definition: base.h:355
@ kBfloat16
Definition: base.h:364
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:652
@ kNCW
Definition: base.h:505
static MSHADOW_XINLINE void Reduce(volatile DType &dst, volatile DType src, volatile DType &residual)
do stable reduction into dst
Definition: base.h:995
MSHADOW_XINLINE double PosInfValue< double >(void)
positive infinity value of double
Definition: base.h:970
mul operator
Definition: base.h:625
LayoutFlag layoutFlag(std::string layoutstr)
Definition: base.h:514
const int index_type_flag
TypeFlag value for type of indexes.
Definition: base.h:495
MSHADOW_XINLINE unsigned int MinValue< unsigned int >(void)
minimum value of unsigned int
Definition: base.h:864
MSHADOW_XINLINE DType MaxValue(void)
maximum value of certain types
divide to saver: /=
Definition: base.h:732
@ kFloat16
Definition: base.h:354
MSHADOW_XINLINE float NegInfValue< float >(void)
negative infinity value of float
Definition: base.h:878
minimum reducer
Definition: base.h:1108
MSHADOW_XINLINE int MinValue< int32_t >(void)
minimum value of int32_t
Definition: base.h:849
static MSHADOW_XINLINE DType PartialGrad(DType redres, DType redsrc)
calculate gradient of redres with respect to redsrc, redres: reduced result, redsrc: one of reduction...
Definition: base.h:1089
@ kFloat32
Definition: base.h:352
static MSHADOW_XINLINE DType Map(DType a, DType b)
map a, b to result using defined operation
Definition: base.h:660
definition of bfloat type.
#define MSHADOW_HALF_EXPONENT_BITS
Definition: half.h:371