mxnet
utils.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_COMMON_CUDA_UTILS_H_
25 #define MXNET_COMMON_CUDA_UTILS_H_
26 
27 #include <dmlc/logging.h>
28 #include <dmlc/parameter.h>
29 #include <dmlc/optional.h>
30 #include <mshadow/base.h>
31 #include <mxnet/libinfo.h>
32 
34 #ifdef __JETBRAINS_IDE__
35 #define __CUDACC__ 1
36 #define __host__
37 #define __device__
38 #define __global__
39 #define __forceinline__
40 #define __shared__
41 inline void __syncthreads() {}
42 inline void __threadfence_block() {}
43 template <class T>
44 inline T __clz(const T val) {
45  return val;
46 }
47 struct __cuda_fake_struct {
48  int x;
49  int y;
50  int z;
51 };
52 extern __cuda_fake_struct blockDim;
53 extern __cuda_fake_struct threadIdx;
54 extern __cuda_fake_struct blockIdx;
55 #endif
56 
57 #define QUOTE(x) #x
58 #define QUOTEVALUE(x) QUOTE(x)
59 
60 #if MXNET_USE_CUDA
61 
62 #include <cuda_runtime.h>
63 #include <cublas_v2.h>
64 #include <curand.h>
65 #if MXNET_USE_NVML
66 #include <nvml.h>
67 #endif // MXNET_USE_NVML
68 
69 #include <vector>
70 
71 #define STATIC_ASSERT_CUDA_VERSION_GE(min_version) \
72  static_assert(CUDA_VERSION >= min_version, "Compiled-against CUDA version " \
73  QUOTEVALUE(CUDA_VERSION) " is too old, please upgrade system to version " \
74  QUOTEVALUE(min_version) " or later.")
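// Usage sketch (the version value is illustrative): a translation unit that requires
// at least CUDA 10.1 can fail the build early with a readable message:
//   STATIC_ASSERT_CUDA_VERSION_GE(10010);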
75 
80 #ifdef __CUDACC__
81 inline __device__ bool __is_supported_cuda_architecture() {
82 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
83 #error "Fermi and earlier GPU architectures are not supported (architecture versions less than 3.0)"
84  return false;
85 #else
86  return true;
87 #endif // __CUDA_ARCH__ < 300
88 }
89 #endif // __CUDACC__
90 
95 #define CHECK_CUDA_ERROR(msg) \
96  { \
97  cudaError_t e = cudaGetLastError(); \
98  CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
99  }
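// Usage sketch (`my_kernel`, `grid`, `block` and `args` are hypothetical): kernel
// launches return no status, so check the last error right after launching:
//   my_kernel<<<grid, block>>>(args);
//   CHECK_CUDA_ERROR("my_kernel launch failed.");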
100 
107 #define CUDA_CALL(func) \
108  { \
109  cudaError_t e = (func); \
110  CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) << "CUDA: " << cudaGetErrorString(e); \
111  }
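// Usage sketch (`dptr` and `nbytes` are hypothetical): wrap every CUDA runtime call
// so a failure aborts with the decoded error string:
//   void* dptr = nullptr;
//   CUDA_CALL(cudaMalloc(&dptr, nbytes));
//   CUDA_CALL(cudaMemset(dptr, 0, nbytes));
//   CUDA_CALL(cudaFree(dptr));
// The CUBLAS_CALL, CUSOLVER_CALL, CURAND_CALL and NVRTC_CALL macros below follow the
// same pattern for their respective libraries.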
112 
119 #define CUBLAS_CALL(func) \
120  { \
121  cublasStatus_t e = (func); \
122  CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
123  << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
124  }
125 
132 #define CUSOLVER_CALL(func) \
133  { \
134  cusolverStatus_t e = (func); \
135  CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \
136  << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
137  }
138 
145 #define CURAND_CALL(func) \
146  { \
147  curandStatus_t e = (func); \
148  CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
149  << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
150  }
151 
158 #define NVRTC_CALL(x) \
159  { \
160  nvrtcResult result = x; \
161  CHECK_EQ(result, NVRTC_SUCCESS) << #x " failed with error " << nvrtcGetErrorString(result); \
162  }
163 
170 #define CUDA_DRIVER_CALL(func) \
171  { \
172  CUresult e = (func); \
173  if (e != CUDA_SUCCESS) { \
174  char const* err_msg = nullptr; \
175  if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \
176  LOG(FATAL) << "CUDA Driver: Unknown error " << e; \
177  } else { \
178  LOG(FATAL) << "CUDA Driver: " << e << " " << err_msg; \
179  } \
180  } \
181  }
182 
183 #if MXNET_USE_NVML
184 
190 #define NVML_CALL(func) \
191  { \
192  nvmlReturn_t result = (func); \
193  CHECK_EQ(result, NVML_SUCCESS) << #func " failed with error " << nvmlErrorString(result); \
194  }
195 #endif // MXNET_USE_NVML
196 
197 #if !defined(_MSC_VER)
198 #define CUDA_UNROLL _Pragma("unroll")
199 #define CUDA_NOUNROLL _Pragma("nounroll")
200 #else
201 #define CUDA_UNROLL
202 #define CUDA_NOUNROLL
203 #endif
204 
205 namespace mxnet {
206 namespace common {
208 namespace cuda {
212 template <typename DType>
213 struct CublasType;
214 
215 // With CUDA v8, cuBLAS adopted use of cudaDataType_t instead of its own
216 // datatype cublasDataType_t. The older cudaDataType_t values could be
217 // included below, but since this class was introduced to support the cuBLAS v8
218 // call cublasGemmEx(), burdening the class with the legacy type values
219 // was not needed.
220 
221 template <>
222 struct CublasType<float> {
223  static const int kFlag = mshadow::kFloat32;
224 #if CUDA_VERSION >= 8000
225  static const cudaDataType_t kCudaFlag = CUDA_R_32F;
226 #endif
227  typedef float ScaleType;
228  static const float one;
229  static const float zero;
230 };
231 template <>
232 struct CublasType<double> {
233  static const int kFlag = mshadow::kFloat64;
234 #if CUDA_VERSION >= 8000
235  static const cudaDataType_t kCudaFlag = CUDA_R_64F;
236 #endif
237  typedef double ScaleType;
238  static const double one;
239  static const double zero;
240 };
241 template <>
242 struct CublasType<mshadow::half::half_t> {
243  static const int kFlag = mshadow::kFloat16;
244 #if CUDA_VERSION >= 8000
245  static const cudaDataType_t kCudaFlag = CUDA_R_16F;
246 #endif
247  typedef float ScaleType;
248  static const mshadow::half::half_t one;
249  static const mshadow::half::half_t zero;
250 };
251 template <>
252 struct CublasType<uint8_t> {
253  static const int kFlag = mshadow::kUint8;
254 #if CUDA_VERSION >= 8000
255  static const cudaDataType_t kCudaFlag = CUDA_R_8I;
256 #endif
257  typedef uint8_t ScaleType;
258  static const uint8_t one = 1;
259  static const uint8_t zero = 0;
260 };
261 template <>
262 struct CublasType<int32_t> {
263  static const int kFlag = mshadow::kInt32;
264 #if CUDA_VERSION >= 8000
265  static const cudaDataType_t kCudaFlag = CUDA_R_32I;
266 #endif
267  typedef int32_t ScaleType;
268  static const int32_t one = 1;
269  static const int32_t zero = 0;
270 };
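// Usage sketch (hypothetical helper `GemmConstants`): code templated on DType can look
// up the matching scaling constants and, for CUDA >= 8.0, the cudaDataType_t flag to
// pass to calls such as cublasGemmEx():
//   template <typename DType>
//   void GemmConstants() {
//     const DType alpha = mxnet::common::cuda::CublasType<DType>::one;
//     const DType beta  = mxnet::common::cuda::CublasType<DType>::zero;
//     cudaDataType_t type_flag = mxnet::common::cuda::CublasType<DType>::kCudaFlag;
//   }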
271 
277 inline const char* CublasGetErrorString(cublasStatus_t error) {
278  switch (error) {
279  case CUBLAS_STATUS_SUCCESS:
280  return "CUBLAS_STATUS_SUCCESS";
281  case CUBLAS_STATUS_NOT_INITIALIZED:
282  return "CUBLAS_STATUS_NOT_INITIALIZED";
283  case CUBLAS_STATUS_ALLOC_FAILED:
284  return "CUBLAS_STATUS_ALLOC_FAILED";
285  case CUBLAS_STATUS_INVALID_VALUE:
286  return "CUBLAS_STATUS_INVALID_VALUE";
287  case CUBLAS_STATUS_ARCH_MISMATCH:
288  return "CUBLAS_STATUS_ARCH_MISMATCH";
289  case CUBLAS_STATUS_MAPPING_ERROR:
290  return "CUBLAS_STATUS_MAPPING_ERROR";
291  case CUBLAS_STATUS_EXECUTION_FAILED:
292  return "CUBLAS_STATUS_EXECUTION_FAILED";
293  case CUBLAS_STATUS_INTERNAL_ERROR:
294  return "CUBLAS_STATUS_INTERNAL_ERROR";
295  case CUBLAS_STATUS_NOT_SUPPORTED:
296  return "CUBLAS_STATUS_NOT_SUPPORTED";
297  default:
298  break;
299  }
300  return "Unknown cuBLAS status";
301 }
302 
303 #if CUDA_VERSION >= 8000
304 
309 inline cublasOperation_t CublasTransposeOp(bool transpose) {
310  return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
311 }
312 #endif
313 
319 inline const char* CusolverGetErrorString(cusolverStatus_t error) {
320  switch (error) {
321  case CUSOLVER_STATUS_SUCCESS:
322  return "CUSOLVER_STATUS_SUCCESS";
323  case CUSOLVER_STATUS_NOT_INITIALIZED:
324  return "CUSOLVER_STATUS_NOT_INITIALIZED";
325  case CUSOLVER_STATUS_ALLOC_FAILED:
326  return "CUSOLVER_STATUS_ALLOC_FAILED";
327  case CUSOLVER_STATUS_INVALID_VALUE:
328  return "CUSOLVER_STATUS_INVALID_VALUE";
329  case CUSOLVER_STATUS_ARCH_MISMATCH:
330  return "CUSOLVER_STATUS_ARCH_MISMATCH";
331  case CUSOLVER_STATUS_EXECUTION_FAILED:
332  return "CUSOLVER_STATUS_EXECUTION_FAILED";
333  case CUSOLVER_STATUS_INTERNAL_ERROR:
334  return "CUSOLVER_STATUS_INTERNAL_ERROR";
335  case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
336  return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
337  default:
338  break;
339  }
340  return "Unknown cuSOLVER status";
341 }
342 
348 inline const char* CurandGetErrorString(curandStatus_t status) {
349  switch (status) {
350  case CURAND_STATUS_SUCCESS:
351  return "CURAND_STATUS_SUCCESS";
352  case CURAND_STATUS_VERSION_MISMATCH:
353  return "CURAND_STATUS_VERSION_MISMATCH";
354  case CURAND_STATUS_NOT_INITIALIZED:
355  return "CURAND_STATUS_NOT_INITIALIZED";
356  case CURAND_STATUS_ALLOCATION_FAILED:
357  return "CURAND_STATUS_ALLOCATION_FAILED";
358  case CURAND_STATUS_TYPE_ERROR:
359  return "CURAND_STATUS_TYPE_ERROR";
360  case CURAND_STATUS_OUT_OF_RANGE:
361  return "CURAND_STATUS_OUT_OF_RANGE";
362  case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
363  return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
364  case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
365  return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
366  case CURAND_STATUS_LAUNCH_FAILURE:
367  return "CURAND_STATUS_LAUNCH_FAILURE";
368  case CURAND_STATUS_PREEXISTING_FAILURE:
369  return "CURAND_STATUS_PREEXISTING_FAILURE";
370  case CURAND_STATUS_INITIALIZATION_FAILED:
371  return "CURAND_STATUS_INITIALIZATION_FAILED";
372  case CURAND_STATUS_ARCH_MISMATCH:
373  return "CURAND_STATUS_ARCH_MISMATCH";
374  case CURAND_STATUS_INTERNAL_ERROR:
375  return "CURAND_STATUS_INTERNAL_ERROR";
376  }
377  return "Unknown cuRAND status";
378 }
379 
380 template <typename DType>
381 inline DType __device__ CudaMax(DType a, DType b) {
382  return a > b ? a : b;
383 }
384 
385 template <typename DType>
386 inline DType __device__ CudaMin(DType a, DType b) {
387  return a < b ? a : b;
388 }
389 
390 class DeviceStore {
391  public:
393  explicit DeviceStore(int requested_device = -1, bool restore = true)
394  : restore_device_(-1), current_device_(requested_device), restore_(restore) {
395  if (restore_)
396  CUDA_CALL(cudaGetDevice(&restore_device_));
397  if (requested_device != restore_device_) {
398  SetDevice(requested_device);
399  }
400  }
401 
402  ~DeviceStore() {
403  if (restore_ && current_device_ != restore_device_ && current_device_ != -1 &&
404  restore_device_ != -1)
405  CUDA_CALL(cudaSetDevice(restore_device_));
406  }
407 
408  void SetDevice(int device) {
409  if (device != -1) {
410  CUDA_CALL(cudaSetDevice(device));
411  current_device_ = device;
412  }
413  }
414 
415  private:
416  int restore_device_;
417  int current_device_;
418  bool restore_;
419 };
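// Usage sketch (`worker_gpu_id` is hypothetical): switch devices for the duration of a
// scope; the previously current device is restored when the DeviceStore is destroyed:
//   {
//     mxnet::common::cuda::DeviceStore ds(worker_gpu_id);
//     // ... allocations and kernel launches now target worker_gpu_id ...
//   }  // previous device restored here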
420 
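// Get the largest datatype suitable to read the requested number of bytes
// (returned as an mshadow type flag).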
429 int get_load_type(size_t N);
430 
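// Determine how many rows of a 2D matrix a block of threads should handle, based on
// the row size and the number of threads per block.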
441 int get_rows_per_block(size_t row_size, int num_threads_per_block);
442 
443 } // namespace cuda
444 } // namespace common
445 } // namespace mxnet
446 
448 constexpr size_t kMaxNumGpus = 64;
449 
450 // The implementations below assume that accesses of 32-bit ints are inherently atomic and
451 // can be read/written by multiple threads without locks. The values held should be < 2^31.
452 
461 inline int cudaAttributeLookup(int device_id,
462  std::vector<int32_t>* cached_values,
463  cudaDeviceAttr attr,
464  const char* attr_name) {
465  if (device_id < 0 || device_id >= static_cast<int>(cached_values->size())) {
466  LOG(FATAL) << attr_name << "(device_id) called with invalid id: " << device_id;
467  } else if ((*cached_values)[device_id] < 0) {
468  int temp = -1;
469  CUDA_CALL(cudaDeviceGetAttribute(&temp, attr, device_id));
470  (*cached_values)[device_id] = static_cast<int32_t>(temp);
471  }
472  return (*cached_values)[device_id];
473 }
474 
480 inline int ComputeCapabilityMajor(int device_id) {
481  static std::vector<int32_t> capability_major(kMaxNumGpus, -1);
482  return cudaAttributeLookup(
483  device_id, &capability_major, cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor");
484 }
485 
491 inline int ComputeCapabilityMinor(int device_id) {
492  static std::vector<int32_t> capability_minor(kMaxNumGpus, -1);
493  return cudaAttributeLookup(
494  device_id, &capability_minor, cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor");
495 }
496 
502 inline int SMArch(int device_id) {
503  auto major = ComputeCapabilityMajor(device_id);
504  auto minor = ComputeCapabilityMinor(device_id);
505  return 10 * major + minor;
506 }
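// Example: a Volta GPU reporting compute capability 7.0 yields 10 * 7 + 0 = 70.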
507 
513 inline int MultiprocessorCount(int device_id) {
514  static std::vector<int32_t> sm_counts(kMaxNumGpus, -1);
515  return cudaAttributeLookup(
516  device_id, &sm_counts, cudaDevAttrMultiProcessorCount, "MultiprocessorCount");
517 }
518 
524 inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
525  static std::vector<int32_t> max_smem_per_mutiprocessor(kMaxNumGpus, -1);
526  return cudaAttributeLookup(device_id,
527  &max_smem_per_mutiprocessor,
528  cudaDevAttrMaxSharedMemoryPerMultiprocessor,
529  "MaxSharedMemoryPerMultiprocessor");
530 }
531 
537 inline bool SupportsCooperativeLaunch(int device_id) {
538  static std::vector<int32_t> coop_launch(kMaxNumGpus, -1);
539  return cudaAttributeLookup(
540  device_id, &coop_launch, cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch");
541 }
542 
549 inline bool SupportsFloat16Compute(int device_id) {
550  if (device_id < 0) {
551  return false;
552  } else {
553  // Kepler and most Maxwell GPUs do not support fp16 compute
554  int computeCapabilityMajor = ComputeCapabilityMajor(device_id);
555  return (computeCapabilityMajor > 5) ||
556  (computeCapabilityMajor == 5 && ComputeCapabilityMinor(device_id) >= 3);
557  }
558 }
559 
566 inline bool SupportsTensorCore(int device_id) {
567  // Volta (sm_70) supports TensorCore algos
568  return device_id >= 0 && ComputeCapabilityMajor(device_id) >= 7;
569 }
570 
571 // The policy if the user hasn't set the environment variable MXNET_CUDA_ALLOW_TENSOR_CORE
572 #define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true
573 
578 inline bool GetEnvAllowTensorCore() {
579  // Since these statics are in the '.h' file, they will exist and will be set
580  // separately in each compilation unit. Not ideal, but cleaner than creating a
581  // cuda_utils.cc solely to have a single instance and initialization.
582  static bool allow_tensor_core = false;
583  static bool is_set = false;
584  if (!is_set) {
585  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
586  bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
587  allow_tensor_core =
588  dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE", dmlc::optional<bool>(default_value)).value();
589  is_set = true;
590  }
591  return allow_tensor_core;
592 }
593 
594 // The policy if the user hasn't set the environment variable
595 // MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION
596 #define MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT false
597 
601 inline bool GetEnvAllowTensorCoreConversion() {
602  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be
603  // legal.
604  bool default_value = MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT;
605  return dmlc::GetEnv("MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION",
606  dmlc::optional<bool>(default_value))
607  .value();
608 }
609 
610 #if CUDA_VERSION >= 9000
611 // Sets the cuBLAS math mode that determines the 'allow TensorCore' policy. Returns previous.
612 inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) {
613  auto handle_math_mode = CUBLAS_DEFAULT_MATH;
614  CUBLAS_CALL(cublasGetMathMode(blas_handle, &handle_math_mode));
615  CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
616  return handle_math_mode;
617 }
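// Usage sketch (`blas_handle` is hypothetical): temporarily allow TensorCore math
// around a single call, then restore the handle's previous policy:
//   cublasMath_t saved = SetCublasMathMode(blas_handle, CUBLAS_TENSOR_OP_MATH);
//   // ... e.g. cublasGemmEx(...) ...
//   SetCublasMathMode(blas_handle, saved);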
618 #endif
619 
620 #endif // MXNET_USE_CUDA
621 
622 #if MXNET_USE_CUDNN
623 
624 #include <cudnn.h>
625 
626 // Creating CUDNN_VERSION_AS_STRING as follows avoids a static_assert error message that shows
627 // the formula for CUDNN_VERSION, i.e. "1000 * 7 + 100 * 6 + 0" rather than number "7600".
628 static_assert(CUDNN_PATCHLEVEL < 100 && CUDNN_MINOR < 10,
629  "CUDNN_VERSION_AS_STRING macro assumptions violated.");
630 #if CUDNN_PATCHLEVEL >= 10
631 #define CUDNN_VERSION_AS_STRING \
632  QUOTEVALUE(CUDNN_MAJOR) \
633  QUOTEVALUE(CUDNN_MINOR) \
634  QUOTEVALUE(CUDNN_PATCHLEVEL)
635 #else
636 #define CUDNN_VERSION_AS_STRING \
637  QUOTEVALUE(CUDNN_MAJOR) \
638  QUOTEVALUE(CUDNN_MINOR) \
639  "0" QUOTEVALUE(CUDNN_PATCHLEVEL)
640 #endif
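// Example: with cuDNN 7.6.5 (CUDNN_MAJOR 7, CUDNN_MINOR 6, CUDNN_PATCHLEVEL 5) the
// macro expands to the string literal "7605", matching CUDNN_VERSION == 7605.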
641 
642 #define STATIC_ASSERT_CUDNN_VERSION_GE(min_version) \
643  static_assert( \
644  CUDNN_VERSION >= min_version, \
645  "Compiled-against cuDNN version " CUDNN_VERSION_AS_STRING \
646  " is too old, please upgrade system to version " QUOTEVALUE(min_version) " or later.")
647 
648 #define CUDNN_CALL_S(f, s) \
649  { \
650  cudnnStatus_t unclash_cxx_e = (f); \
651  if (unclash_cxx_e != CUDNN_STATUS_SUCCESS) \
652  LOG(s) << "cuDNN: " << cudnnGetErrorString(unclash_cxx_e); \
653  }
654 
655 #define CUDNN_CALL(f) CUDNN_CALL_S(f, FATAL)
656 #define CUDNN_CALL_NONFATAL(f) CUDNN_CALL_S(f, WARNING)
657 
658 #define CUTENSOR_CALL(func) \
659  { \
660  cutensorStatus_t e = (func); \
661  CHECK_EQ(e, CUTENSOR_STATUS_SUCCESS) << "cuTensor: " << cutensorGetErrorString(e); \
662  }
663 
671 inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
672  STATIC_ASSERT_CUDNN_VERSION_GE(7000);
673  int max_algos = 0;
674  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos));
675  return max_algos;
676 }
677 
685 inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
686  STATIC_ASSERT_CUDNN_VERSION_GE(7000);
687  int max_algos = 0;
688  CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos));
689  return max_algos;
690 }
691 
699 inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
700  STATIC_ASSERT_CUDNN_VERSION_GE(7000);
701  int max_algos = 0;
702  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos));
703  return max_algos;
704 }
705 
706 #endif // MXNET_USE_CUDNN
707 
708 // Overload atomicAdd to work for doubles on architectures without native double atomicAdd (sm < 60)
709 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
710 // From CUDA Programming Guide
711 static inline __device__ void atomicAdd(double* address, double val) {
712  unsigned long long* address_as_ull = // NOLINT(*)
713  reinterpret_cast<unsigned long long*>(address); // NOLINT(*)
714  unsigned long long old = *address_as_ull; // NOLINT(*)
715  unsigned long long assumed; // NOLINT(*)
716 
717  do {
718  assumed = old;
719  old = atomicCAS(
720  address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
721 
722  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
723  } while (assumed != old);
724 }
725 #endif
726 
727 // Overload atomicAdd for half precision
728 // Taken from:
729 // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
730 #ifdef __CUDACC__
731 static inline __device__ void atomicAdd(mshadow::half::half_t* address, mshadow::half::half_t val) {
732  unsigned int* address_as_ui = reinterpret_cast<unsigned int*>(
733  reinterpret_cast<char*>(address) - (reinterpret_cast<size_t>(address) & 2));
734  unsigned int old = *address_as_ui;
735  unsigned int assumed;
736 
737  do {
738  assumed = old;
739  mshadow::half::half_t hsum;
740  hsum.half_ = reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
741  hsum += val;
742  old = reinterpret_cast<size_t>(address) & 2 ? (old & 0xffff) | (hsum.half_ << 16) :
743  (old & 0xffff0000) | hsum.half_;
744  old = atomicCAS(address_as_ui, assumed, old);
745  } while (assumed != old);
746 }
747 
748 static inline __device__ void atomicAdd(uint8_t* address, uint8_t val) {
749  unsigned int* address_as_ui = (unsigned int*)(address - ((size_t)address & 0x3));
750  unsigned int old = *address_as_ui;
751  unsigned int shift = (((size_t)address & 0x3) << 3);
752  unsigned int sum;
753  unsigned int assumed;
754 
755  do {
756  assumed = old;
757  sum = val + static_cast<uint8_t>((old >> shift) & 0xff);
758  old = (old & ~(0x000000ff << shift)) | (sum << shift);
759  old = atomicCAS(address_as_ui, assumed, old);
760  } while (assumed != old);
761 }
762 
763 static inline __device__ void atomicAdd(int8_t* address, int8_t val) {
764  unsigned int* address_as_ui = (unsigned int*)(address - ((size_t)address & 0x3));
765  unsigned int old = *address_as_ui;
766  unsigned int shift = (((size_t)address & 0x3) << 3);
767  unsigned int sum;
768  unsigned int assumed;
769 
770  do {
771  assumed = old;
772  sum = val + static_cast<int8_t>((old >> shift) & 0xff);
773  old = (old & ~(0x000000ff << shift)) | (sum << shift);
774  old = atomicCAS(address_as_ui, assumed, old);
775  } while (assumed != old);
776 }
777 
778 // Overload atomicAdd to work for signed int64 on all architectures
779 static inline __device__ void atomicAdd(int64_t* address, int64_t val) {
780  atomicAdd(reinterpret_cast<unsigned long long*>(address), // NOLINT
781  static_cast<unsigned long long>(val)); // NOLINT
782 }
783 
784 template <typename DType>
785 __device__ inline DType ldg(const DType* address) {
786 #if __CUDA_ARCH__ >= 350
787  return __ldg(address);
788 #else
789  return *address;
790 #endif
791 }
792 
793 namespace mxnet {
794 namespace common {
796 namespace cuda {
797 
798 static constexpr const int warp_size = 32;
799 
806 template <int NVALUES = warp_size, typename OP, typename T>
807 __device__ inline T warp_reduce(T value, OP redfun) {
808 #pragma unroll
809  for (int i = warp_size / 2; i >= 1; i /= 2) {
810  if (NVALUES > i)
811  value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
812  }
813  return value;
814 }
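// Usage sketch (inside device code; `val` is hypothetical): sum one value per lane of
// a full warp:
//   float warp_sum = warp_reduce(val, [](float a, float b) { return a + b; });
// Only lane 0 is guaranteed to hold the complete result afterwards.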
815 
816 template <typename OP, typename T>
817 __device__ inline T grouped_warp_allreduce(T value, OP redfun, const int group_size) {
818  for (int i = 1; i < group_size; i *= 2) {
819  value = redfun(value, __shfl_down_sync(0xffffffff, value, i));
820  }
821  return __shfl_sync(0xffffffff, value, 0, group_size);
822 }
823 
824 template <int NValues = warp_size, typename OP>
825 __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
826  float v = static_cast<float>(value);
827 #pragma unroll
828  for (int i = warp_size / 2; i >= 1; i /= 2) {
829  if (NValues > i)
830  v = redfun(v, __shfl_down_sync(0xffffffff, v, i));
831  }
832  return mshadow::half::half_t(v);
833 }
834 
847 template <int NTHREADS, bool all_reduce = true, typename OP, typename T>
848 __device__ inline T reduce(const T& value, OP redfun) {
849  static_assert(NTHREADS <= warp_size * warp_size, "Number of threads too large for reduction");
850  __shared__ T scratch[NTHREADS / warp_size];
851  const int thread_idx_in_warp = threadIdx.x % warp_size;
852  const int warp_id = threadIdx.x / warp_size;
853  const T my_val = warp_reduce<warp_size>(value, redfun);
854  if (thread_idx_in_warp == 0) {
855  scratch[warp_id] = my_val;
856  }
857  __syncthreads();
858  T ret = 0;
859  if (warp_id == 0) {
860  const T prev_val = threadIdx.x < (NTHREADS / warp_size) ? scratch[threadIdx.x] : 0;
861  const T my_val = warp_reduce<NTHREADS / warp_size>(prev_val, redfun);
862  if (all_reduce) {
863  scratch[threadIdx.x] = my_val;
864  } else {
865  ret = my_val;
866  }
867  }
868  // Necessary to synchronize in order to use this function again
869  // as the shared memory scratch space is reused between calls
870  __syncthreads();
871  if (all_reduce) {
872  ret = scratch[0];
873  __syncthreads();
874  }
875  return ret;
876 }
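// Usage sketch (inside a kernel launched with 128 threads per block; `my_val` is
// hypothetical): block-wide sum, with every thread receiving the result since
// all_reduce defaults to true:
//   float block_sum = reduce<128>(my_val, [](float a, float b) { return a + b; });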
877 
878 } // namespace cuda
879 } // namespace common
880 } // namespace mxnet
881 
882 #endif // __CUDACC__
883 
884 #endif // MXNET_COMMON_CUDA_UTILS_H_
Definition: utils.h:229