mxnet
cuda_utils.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_COMMON_CUDA_UTILS_H_
26 #define MXNET_COMMON_CUDA_UTILS_H_
27 
28 #include <dmlc/logging.h>
29 #include <dmlc/parameter.h>
30 #include <dmlc/optional.h>
31 #include <mshadow/base.h>
32 #include <mxnet/libinfo.h>
33 
35 #ifdef __JETBRAINS_IDE__
36 #define __CUDACC__ 1
37 #define __host__
38 #define __device__
39 #define __global__
40 #define __forceinline__
41 #define __shared__
42 inline void __syncthreads() {}
43 inline void __threadfence_block() {}
44 template<class T> inline T __clz(const T val) { return val; }
45 struct __cuda_fake_struct { int x; int y; int z; };
46 extern __cuda_fake_struct blockDim;
47 extern __cuda_fake_struct threadIdx;
48 extern __cuda_fake_struct blockIdx;
49 #endif
50 
51 #define QUOTE(x) #x
52 #define QUOTEVALUE(x) QUOTE(x)
53 
54 #if MXNET_USE_CUDA
55 
56 #include <cuda_runtime.h>
57 #include <cublas_v2.h>
58 #include <curand.h>
59 
60 #include <vector>
61 
62 #define STATIC_ASSERT_CUDA_VERSION_GE(min_version) \
63  static_assert(CUDA_VERSION >= min_version, "Compiled-against CUDA version " \
64  QUOTEVALUE(CUDA_VERSION) " is too old, please upgrade system to version " \
65  QUOTEVALUE(min_version) " or later.")
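// Usage sketch (illustrative, not part of the original header): a translation
// unit that requires, for example, the CUDA 9 toolkit can fail fast at compile
// time with a readable message:
//   STATIC_ASSERT_CUDA_VERSION_GE(9000);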
66 
71 #ifdef __CUDACC__
72 inline __device__ bool __is_supported_cuda_architecture() {
73 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
74 #error "Fermi and earlier GPU architectures are not supported (architecture versions less than 3.0)"
75  return false;
76 #else
77  return true;
78 #endif // __CUDA_ARCH__ < 300
79 }
80 #endif // __CUDACC__
81 
86 #define CHECK_CUDA_ERROR(msg) \
87  { \
88  cudaError_t e = cudaGetLastError(); \
89  CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
90  }
91 
98 #define CUDA_CALL(func) \
99  { \
100  cudaError_t e = (func); \
101  CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
102  << "CUDA: " << cudaGetErrorString(e); \
103  }
104 
111 #define CUBLAS_CALL(func) \
112  { \
113  cublasStatus_t e = (func); \
114  CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \
115  << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \
116  }
117 
124 #define CUSOLVER_CALL(func) \
125  { \
126  cusolverStatus_t e = (func); \
127  CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \
128  << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \
129  }
130 
137 #define CURAND_CALL(func) \
138  { \
139  curandStatus_t e = (func); \
140  CHECK_EQ(e, CURAND_STATUS_SUCCESS) \
141  << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \
142  }
143 
150 #define NVRTC_CALL(x) \
151  { \
152  nvrtcResult result = x; \
153  CHECK_EQ(result, NVRTC_SUCCESS) \
154  << #x " failed with error " \
155  << nvrtcGetErrorString(result); \
156  }
157 
164 #define CUDA_DRIVER_CALL(func) \
165  { \
166  CUresult e = (func); \
167  if (e != CUDA_SUCCESS) { \
168  char const * err_msg = nullptr; \
169  if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \
170  LOG(FATAL) << "CUDA Driver: Unknown error " << e; \
171  } else { \
172  LOG(FATAL) << "CUDA Driver: " << err_msg; \
173  } \
174  } \
175  }
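// A minimal usage sketch of the call-checking macros above (illustrative only;
// `n`, `stream` and `blas_handle` are assumed to exist in the caller):
//
//   float *d_buf = nullptr;
//   CUDA_CALL(cudaMalloc(&d_buf, n * sizeof(float)));
//   CUDA_CALL(cudaMemsetAsync(d_buf, 0, n * sizeof(float), stream));
//   CUBLAS_CALL(cublasSetStream(blas_handle, stream));
//   // ... launch kernels ...
//   CHECK_CUDA_ERROR("after kernel launch");
//   CUDA_CALL(cudaFree(d_buf));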
176 
177 
178 #if !defined(_MSC_VER)
179 #define CUDA_UNROLL _Pragma("unroll")
180 #define CUDA_NOUNROLL _Pragma("nounroll")
181 #else
182 #define CUDA_UNROLL
183 #define CUDA_NOUNROLL
184 #endif
185 
186 namespace mxnet {
187 namespace common {
189 namespace cuda {
193 template<typename DType>
194 struct CublasType;
195 
196 // With CUDA v8, cuBLAS adopted use of cudaDataType_t instead of its own
197 // datatype cublasDataType_t. The older cudaDataType_t values could be
198 // included below, but since this class was introduced to support the cuBLAS v8
199 // call cublasGemmEx(), burdening the class with the legacy type values
200 // was not needed.
201 
202 template<>
203 struct CublasType<float> {
204  static const int kFlag = mshadow::kFloat32;
205 #if CUDA_VERSION >= 8000
206  static const cudaDataType_t kCudaFlag = CUDA_R_32F;
207 #endif
208  typedef float ScaleType;
209  static const float one;
210  static const float zero;
211 };
212 template<>
213 struct CublasType<double> {
214  static const int kFlag = mshadow::kFloat64;
215 #if CUDA_VERSION >= 8000
216  static const cudaDataType_t kCudaFlag = CUDA_R_64F;
217 #endif
218  typedef double ScaleType;
219  static const double one;
220  static const double zero;
221 };
222 template<>
223 struct CublasType<mshadow::half::half_t> {
224  static const int kFlag = mshadow::kFloat16;
225 #if CUDA_VERSION >= 8000
226  static const cudaDataType_t kCudaFlag = CUDA_R_16F;
227 #endif
228  typedef float ScaleType;
229  static const mshadow::half::half_t one;
230  static const mshadow::half::half_t zero;
231 };
232 template<>
233 struct CublasType<uint8_t> {
234  static const int kFlag = mshadow::kUint8;
235 #if CUDA_VERSION >= 8000
236  static const cudaDataType_t kCudaFlag = CUDA_R_8I;
237 #endif
238  typedef uint8_t ScaleType;
239  static const uint8_t one = 1;
240  static const uint8_t zero = 0;
241 };
242 template<>
243 struct CublasType<int32_t> {
244  static const int kFlag = mshadow::kInt32;
245 #if CUDA_VERSION >= 8000
246  static const cudaDataType_t kCudaFlag = CUDA_R_32I;
247 #endif
248  typedef int32_t ScaleType;
249  static const int32_t one = 1;
250  static const int32_t zero = 0;
251 };
252 
258 inline const char* CublasGetErrorString(cublasStatus_t error) {
259  switch (error) {
260  case CUBLAS_STATUS_SUCCESS:
261  return "CUBLAS_STATUS_SUCCESS";
262  case CUBLAS_STATUS_NOT_INITIALIZED:
263  return "CUBLAS_STATUS_NOT_INITIALIZED";
264  case CUBLAS_STATUS_ALLOC_FAILED:
265  return "CUBLAS_STATUS_ALLOC_FAILED";
266  case CUBLAS_STATUS_INVALID_VALUE:
267  return "CUBLAS_STATUS_INVALID_VALUE";
268  case CUBLAS_STATUS_ARCH_MISMATCH:
269  return "CUBLAS_STATUS_ARCH_MISMATCH";
270  case CUBLAS_STATUS_MAPPING_ERROR:
271  return "CUBLAS_STATUS_MAPPING_ERROR";
272  case CUBLAS_STATUS_EXECUTION_FAILED:
273  return "CUBLAS_STATUS_EXECUTION_FAILED";
274  case CUBLAS_STATUS_INTERNAL_ERROR:
275  return "CUBLAS_STATUS_INTERNAL_ERROR";
276  case CUBLAS_STATUS_NOT_SUPPORTED:
277  return "CUBLAS_STATUS_NOT_SUPPORTED";
278  default:
279  break;
280  }
281  return "Unknown cuBLAS status";
282 }
283 
284 #if CUDA_VERSION >= 8000
285 
290 inline cublasOperation_t CublasTransposeOp(bool transpose) {
291  return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
292 }
293 #endif
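// Sketch of how the CublasType<> traits and CublasTransposeOp() are typically
// consumed together (illustrative only; `handle`, the device pointers and the
// m/n/k/lda/ldb/ldc values are assumed to exist in the caller):
//
// #if CUDA_VERSION >= 8000
//   float alpha = CublasType<float>::one;
//   float beta  = CublasType<float>::zero;
//   CUBLAS_CALL(cublasGemmEx(handle,
//                            CublasTransposeOp(false), CublasTransposeOp(false),
//                            m, n, k, &alpha,
//                            d_a, CublasType<float>::kCudaFlag, lda,
//                            d_b, CublasType<float>::kCudaFlag, ldb, &beta,
//                            d_c, CublasType<float>::kCudaFlag, ldc,
//                            CublasType<float>::kCudaFlag, CUBLAS_GEMM_DFALT));
// #endif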
294 
300 inline const char* CusolverGetErrorString(cusolverStatus_t error) {
301  switch (error) {
302  case CUSOLVER_STATUS_SUCCESS:
303  return "CUSOLVER_STATUS_SUCCESS";
304  case CUSOLVER_STATUS_NOT_INITIALIZED:
305  return "CUSOLVER_STATUS_NOT_INITIALIZED";
306  case CUSOLVER_STATUS_ALLOC_FAILED:
307  return "CUSOLVER_STATUS_ALLOC_FAILED";
308  case CUSOLVER_STATUS_INVALID_VALUE:
309  return "CUSOLVER_STATUS_INVALID_VALUE";
310  case CUSOLVER_STATUS_ARCH_MISMATCH:
311  return "CUSOLVER_STATUS_ARCH_MISMATCH";
312  case CUSOLVER_STATUS_EXECUTION_FAILED:
313  return "CUSOLVER_STATUS_EXECUTION_FAILED";
314  case CUSOLVER_STATUS_INTERNAL_ERROR:
315  return "CUSOLVER_STATUS_INTERNAL_ERROR";
316  case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
317  return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
318  default:
319  break;
320  }
321  return "Unknown cuSOLVER status";
322 }
323 
329 inline const char* CurandGetErrorString(curandStatus_t status) {
330  switch (status) {
331  case CURAND_STATUS_SUCCESS:
332  return "CURAND_STATUS_SUCCESS";
333  case CURAND_STATUS_VERSION_MISMATCH:
334  return "CURAND_STATUS_VERSION_MISMATCH";
335  case CURAND_STATUS_NOT_INITIALIZED:
336  return "CURAND_STATUS_NOT_INITIALIZED";
337  case CURAND_STATUS_ALLOCATION_FAILED:
338  return "CURAND_STATUS_ALLOCATION_FAILED";
339  case CURAND_STATUS_TYPE_ERROR:
340  return "CURAND_STATUS_TYPE_ERROR";
341  case CURAND_STATUS_OUT_OF_RANGE:
342  return "CURAND_STATUS_OUT_OF_RANGE";
343  case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
344  return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
345  case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
346  return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
347  case CURAND_STATUS_LAUNCH_FAILURE:
348  return "CURAND_STATUS_LAUNCH_FAILURE";
349  case CURAND_STATUS_PREEXISTING_FAILURE:
350  return "CURAND_STATUS_PREEXISTING_FAILURE";
351  case CURAND_STATUS_INITIALIZATION_FAILED:
352  return "CURAND_STATUS_INITIALIZATION_FAILED";
353  case CURAND_STATUS_ARCH_MISMATCH:
354  return "CURAND_STATUS_ARCH_MISMATCH";
355  case CURAND_STATUS_INTERNAL_ERROR:
356  return "CURAND_STATUS_INTERNAL_ERROR";
357  }
358  return "Unknown cuRAND status";
359 }
360 
361 template <typename DType>
362 inline DType __device__ CudaMax(DType a, DType b) {
363  return a > b ? a : b;
364 }
365 
366 template <typename DType>
367 inline DType __device__ CudaMin(DType a, DType b) {
368  return a < b ? a : b;
369 }
370 
371 class DeviceStore {
372  public:
374  explicit DeviceStore(int requested_device = -1, bool restore = true) :
375  restore_device_(-1),
376  current_device_(requested_device),
377  restore_(restore) {
378  if (restore_)
379  CUDA_CALL(cudaGetDevice(&restore_device_));
380  if (requested_device != restore_device_) {
381  SetDevice(requested_device);
382  }
383  }
384 
385  ~DeviceStore() {
386  if (restore_ &&
387  current_device_ != restore_device_ &&
388  current_device_ != -1 &&
389  restore_device_ != -1)
390  CUDA_CALL(cudaSetDevice(restore_device_));
391  }
392 
393  void SetDevice(int device) {
394  if (device != -1) {
395  CUDA_CALL(cudaSetDevice(device));
396  current_device_ = device;
397  }
398  }
399 
400  private:
401  int restore_device_;
402  int current_device_;
403  bool restore_;
404 };
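// RAII usage sketch (illustrative): switch to a requested device for the scope
// of a block and restore the caller's device on exit; `ctx.dev_id` is assumed.
//
//   {
//     mxnet::common::cuda::DeviceStore device_guard(ctx.dev_id);
//     // ... allocate / launch on ctx.dev_id ...
//   }  // previous device restored here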
405 
414 int get_load_type(size_t N);
415 
426 int get_rows_per_block(size_t row_size, int num_threads_per_block);
427 
428 } // namespace cuda
429 } // namespace common
430 } // namespace mxnet
431 
433 constexpr size_t kMaxNumGpus = 64;
434 
435 // The implementations below assume that accesses of 32-bit ints are inherently atomic and
436 // can be read/written by multiple threads without locks. The values held should be < 2^31.
437 
446 inline int cudaAttributeLookup(int device_id, std::vector<int32_t> *cached_values,
447  cudaDeviceAttr attr, const char *attr_name) {
448  if (device_id < 0 || device_id >= static_cast<int>(cached_values->size())) {
449  LOG(FATAL) << attr_name << "(device_id) called with invalid id: " << device_id;
450  } else if ((*cached_values)[device_id] < 0) {
451  int temp = -1;
452  CUDA_CALL(cudaDeviceGetAttribute(&temp, attr, device_id));
453  (*cached_values)[device_id] = static_cast<int32_t>(temp);
454  }
455  return (*cached_values)[device_id];
456 }
457 
463 inline int ComputeCapabilityMajor(int device_id) {
464  static std::vector<int32_t> capability_major(kMaxNumGpus, -1);
465  return cudaAttributeLookup(device_id, &capability_major,
466  cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor");
467 }
468 
474 inline int ComputeCapabilityMinor(int device_id) {
475  static std::vector<int32_t> capability_minor(kMaxNumGpus, -1);
476  return cudaAttributeLookup(device_id, &capability_minor,
477  cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor");
478 }
479 
485 inline int SMArch(int device_id) {
486  auto major = ComputeCapabilityMajor(device_id);
487  auto minor = ComputeCapabilityMinor(device_id);
488  return 10 * major + minor;
489 }
490 
496 inline int MultiprocessorCount(int device_id) {
497  static std::vector<int32_t> sm_counts(kMaxNumGpus, -1);
498  return cudaAttributeLookup(device_id, &sm_counts,
499  cudaDevAttrMultiProcessorCount, "MultiprocessorCount");
500 }
501 
507 inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
508  static std::vector<int32_t> max_smem_per_mutiprocessor(kMaxNumGpus, -1);
509  return cudaAttributeLookup(device_id, &max_smem_per_mutiprocessor,
510  cudaDevAttrMaxSharedMemoryPerMultiprocessor,
511  "MaxSharedMemoryPerMultiprocessor");
512 }
513 
519 inline bool SupportsCooperativeLaunch(int device_id) {
520  static std::vector<int32_t> coop_launch(kMaxNumGpus, -1);
521  return cudaAttributeLookup(device_id, &coop_launch,
522  cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch");
523 }
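// Sketch of combining the cached device queries above when sizing a launch
// (illustrative only; `MyKernel`, `args` and `stream` are assumptions):
//
//   int dev = 0;
//   int grid = 4 * MultiprocessorCount(dev);                  // a few blocks per SM
//   size_t smem = MaxSharedMemoryPerMultiprocessor(dev) / 4;  // stay well under the limit
//   if (SupportsCooperativeLaunch(dev)) {
//     CUDA_CALL(cudaLaunchCooperativeKernel(reinterpret_cast<void*>(&MyKernel),
//                                           grid, 256, args, smem, stream));
//   }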
524 
531 inline bool SupportsFloat16Compute(int device_id) {
532  if (device_id < 0) {
533  return false;
534  } else {
535  // Kepler and most Maxwell GPUs do not support fp16 compute
536  int computeCapabilityMajor = ComputeCapabilityMajor(device_id);
537  return (computeCapabilityMajor > 5) ||
538  (computeCapabilityMajor == 5 && ComputeCapabilityMinor(device_id) >= 3);
539  }
540 }
541 
548 inline bool SupportsTensorCore(int device_id) {
549  // Volta (sm_70) supports TensorCore algos
550  return device_id >= 0 &&
551  ComputeCapabilityMajor(device_id) >= 7;
552 }
553 
554 // The policy if the user hasn't set the environment variable MXNET_CUDA_ALLOW_TENSOR_CORE
555 #define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true
556 
561 inline bool GetEnvAllowTensorCore() {
562  // Since these statics are in the '.h' file, they will exist and will be set
563  // separately in each compilation unit. Not ideal, but cleaner than creating a
564  // cuda_utils.cc solely to have a single instance and initialization.
565  static bool allow_tensor_core = false;
566  static bool is_set = false;
567  if (!is_set) {
568  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
569  bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT;
570  allow_tensor_core = dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
571  dmlc::optional<bool>(default_value)).value();
572  is_set = true;
573  }
574  return allow_tensor_core;
575 }
576 
577 // The policy if the user hasn't set the environment variable
578 // MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION
579 #define MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT false
580 
584 inline bool GetEnvAllowTensorCoreConversion() {
585  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be
586  // legal.
587  bool default_value = MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION_DEFAULT;
588  return dmlc::GetEnv("MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION",
589  dmlc::optional<bool>(default_value))
590  .value();
591 }
592 
593 #if CUDA_VERSION >= 9000
594 // Sets the cuBLAS math mode that determines the 'allow TensorCore' policy. Returns previous.
595 inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t new_math_type) {
596  auto handle_math_mode = CUBLAS_DEFAULT_MATH;
597  CUBLAS_CALL(cublasGetMathMode(blas_handle, &handle_math_mode));
598  CUBLAS_CALL(cublasSetMathMode(blas_handle, new_math_type));
599  return handle_math_mode;
600 }
601 #endif
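// Putting the TensorCore policy helpers together (illustrative sketch;
// `blas_handle` and `dev_id` are assumed to exist in the caller):
//
// #if CUDA_VERSION >= 9000
//   if (SupportsTensorCore(dev_id) && GetEnvAllowTensorCore()) {
//     cublasMath_t saved = SetCublasMathMode(blas_handle, CUBLAS_TENSOR_OP_MATH);
//     // ... cuBLAS calls eligible for TensorCore algos ...
//     SetCublasMathMode(blas_handle, saved);  // restore the previous policy
//   }
// #endif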
602 
603 #endif // MXNET_USE_CUDA
604 
605 #if MXNET_USE_CUDNN
606 
607 #include <cudnn.h>
608 
609 // Creating CUDNN_VERSION_AS_STRING as follows avoids a static_assert error message that shows
610 // the formula for CUDNN_VERSION, i.e. "1000 * 7 + 100 * 6 + 0" rather than number "7600".
611 static_assert(CUDNN_PATCHLEVEL < 100 && CUDNN_MINOR < 10,
612  "CUDNN_VERSION_AS_STRING macro assumptions violated.");
613 #if CUDNN_PATCHLEVEL >= 10
614 #define CUDNN_VERSION_AS_STRING QUOTEVALUE(CUDNN_MAJOR) \
615  QUOTEVALUE(CUDNN_MINOR) \
616  QUOTEVALUE(CUDNN_PATCHLEVEL)
617 #else
618 #define CUDNN_VERSION_AS_STRING QUOTEVALUE(CUDNN_MAJOR) \
619  QUOTEVALUE(CUDNN_MINOR) \
620  "0" QUOTEVALUE(CUDNN_PATCHLEVEL)
621 #endif
622 
623 #define STATIC_ASSERT_CUDNN_VERSION_GE(min_version) \
624  static_assert(CUDNN_VERSION >= min_version, "Compiled-against cuDNN version " \
625  CUDNN_VERSION_AS_STRING " is too old, please upgrade system to version " \
626  QUOTEVALUE(min_version) " or later.")
627 
628 #define CUDNN_CALL(func) \
629  { \
630  cudnnStatus_t e = (func); \
631  CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \
632  }
633 
641 inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
642  STATIC_ASSERT_CUDNN_VERSION_GE(7000);
643  int max_algos = 0;
644  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_algos));
645  return max_algos;
646 }
647 
655 inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
656  STATIC_ASSERT_CUDNN_VERSION_GE(7000);
657  int max_algos = 0;
658  CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &max_algos));
659  return max_algos;
660 }
661 
669 inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
670  STATIC_ASSERT_CUDNN_VERSION_GE(7000);
671  int max_algos = 0;
672  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &max_algos));
673  return max_algos;
674 }
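// Sketch of how the Max*Algos() helpers above are typically used to size the
// result buffer for cuDNN's algorithm enumeration (illustrative; the tensor,
// filter and convolution descriptors are assumed to be set up already):
//
//   int max_algos = MaxForwardAlgos(cudnn_handle);
//   std::vector<cudnnConvolutionFwdAlgoPerf_t> perf(max_algos);
//   int returned = 0;
//   CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm_v7(cudnn_handle,
//              x_desc, w_desc, conv_desc, y_desc,
//              max_algos, &returned, perf.data()));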
675 
676 #endif // MXNET_USE_CUDNN
677 
678 // Overload atomicAdd to work for floats on all architectures
679 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
680 // From CUDA Programming Guide
681 static inline __device__ void atomicAdd(double *address, double val) {
682  unsigned long long* address_as_ull = // NOLINT(*)
683  reinterpret_cast<unsigned long long*>(address); // NOLINT(*)
684  unsigned long long old = *address_as_ull; // NOLINT(*)
685  unsigned long long assumed; // NOLINT(*)
686 
687  do {
688  assumed = old;
689  old = atomicCAS(address_as_ull, assumed,
690  __double_as_longlong(val +
691  __longlong_as_double(assumed)));
692 
693  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
694  } while (assumed != old);
695 }
696 #endif
697 
698 // Overload atomicAdd for half precision
699 // Taken from:
700 // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
701 #ifdef __CUDACC__
702 static inline __device__ void atomicAdd(mshadow::half::half_t *address,
703  mshadow::half::half_t val) {
704  unsigned int *address_as_ui =
705  reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) -
706  (reinterpret_cast<size_t>(address) & 2));
707  unsigned int old = *address_as_ui;
708  unsigned int assumed;
709 
710  do {
711  assumed = old;
712  mshadow::half::half_t hsum;
713  hsum.half_ =
714  reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
715  hsum += val;
716  old = reinterpret_cast<size_t>(address) & 2
717  ? (old & 0xffff) | (hsum.half_ << 16)
718  : (old & 0xffff0000) | hsum.half_;
719  old = atomicCAS(address_as_ui, assumed, old);
720  } while (assumed != old);
721 }
722 
723 static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
724  unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
725  unsigned int old = *address_as_ui;
726  unsigned int shift = (((size_t)address & 0x3) << 3);
727  unsigned int sum;
728  unsigned int assumed;
729 
730  do {
731  assumed = old;
732  sum = val + static_cast<uint8_t>((old >> shift) & 0xff);
733  old = (old & ~(0x000000ff << shift)) | (sum << shift);
734  old = atomicCAS(address_as_ui, assumed, old);
735  } while (assumed != old);
736 }
737 
738 static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
739  unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
740  unsigned int old = *address_as_ui;
741  unsigned int shift = (((size_t)address & 0x3) << 3);
742  unsigned int sum;
743  unsigned int assumed;
744 
745  do {
746  assumed = old;
747  sum = val + static_cast<int8_t>((old >> shift) & 0xff);
748  old = (old & ~(0x000000ff << shift)) | (sum << shift);
749  old = atomicCAS(address_as_ui, assumed, old);
750  } while (assumed != old);
751 }
752 
753 // Overload atomicAdd to work for signed int64 on all architectures
754 static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
755  atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
756 }
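// Usage sketch for the atomicAdd() overloads above (illustrative): accumulating
// per-thread contributions of any supported element type into global memory.
//
//   template <typename DType>
//   __global__ void AccumulateKernel(DType* out, const DType* in, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) atomicAdd(out, in[i]);   // resolves to the proper overload
//   }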
757 
758 template <typename DType>
759 __device__ inline DType ldg(const DType* address) {
760 #if __CUDA_ARCH__ >= 350
761  return __ldg(address);
762 #else
763  return *address;
764 #endif
765 }
766 
767 template <typename OP, typename T>
768 __device__ inline T warp_reduce(T value, OP redfun) {
769  value = redfun(value, __shfl_down_sync(0xffffffff, value, 16));
770  value = redfun(value, __shfl_down_sync(0xffffffff, value, 8));
771  value = redfun(value, __shfl_down_sync(0xffffffff, value, 4));
772  value = redfun(value, __shfl_down_sync(0xffffffff, value, 2));
773  value = redfun(value, __shfl_down_sync(0xffffffff, value, 1));
774  return value;
775 }
776 
777 template <typename OP>
778 __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, OP redfun) {
779  float v = static_cast<float>(value);
780  v = redfun(v, __shfl_down_sync(0xffffffff, v, 16));
781  v = redfun(v, __shfl_down_sync(0xffffffff, v, 8));
782  v = redfun(v, __shfl_down_sync(0xffffffff, v, 4));
783  v = redfun(v, __shfl_down_sync(0xffffffff, v, 2));
784  v = redfun(v, __shfl_down_sync(0xffffffff, v, 1));
785  return mshadow::half::half_t(v);
786 }
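// Usage sketch for warp_reduce() (illustrative): summing a value across the 32
// lanes of a warp with a caller-supplied binary functor; lane 0 ends up holding
// the warp-wide result.
//
//   struct SumOp {
//     __device__ float operator()(float a, float b) const { return a + b; }
//   };
//   __device__ float WarpSum(float v) {
//     return warp_reduce(v, SumOp());
//   }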
787 
788 #endif // __CUDACC__
789 
790 #endif // MXNET_COMMON_CUDA_UTILS_H_