mxnet
cuda_utils.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_COMMON_CUDA_UTILS_H_
25 #define MXNET_COMMON_CUDA_UTILS_H_
26 
27 #include <dmlc/logging.h>
28 #include <dmlc/parameter.h>
29 #include <dmlc/optional.h>
30 #include <mshadow/base.h>
31 
// Stubs that make CUDA built-ins visible to the JetBrains (CLion) code parser
// so device code can be indexed and highlighted. Never active under a real
// nvcc/host compile -- only when the IDE defines __JETBRAINS_IDE__.
#ifdef __JETBRAINS_IDE__
#define __CUDACC__ 1
#define __host__
#define __device__
#define __global__
#define __forceinline__
#define __shared__
inline void __syncthreads() {}
inline void __threadfence_block() {}
// Identity stand-in for the parser only; the real __clz counts leading zeros.
template<class T> inline T __clz(const T val) { return val; }
// Fake dim3-like type so blockDim/threadIdx/blockIdx resolve in the IDE.
struct __cuda_fake_struct { int x; int y; int z; };
extern __cuda_fake_struct blockDim;
extern __cuda_fake_struct threadIdx;
extern __cuda_fake_struct blockIdx;
#endif
48 
49 #if MXNET_USE_CUDA
50 
51 #include <cuda_runtime.h>
52 #include <cublas_v2.h>
53 #include <curand.h>
54 
55 namespace mxnet {
56 namespace common {
58 namespace cuda {
/*!
 * \brief Get string representation of cuBLAS errors.
 * \param error The cuBLAS status code.
 * \return String representation (the enumerator's own name).
 */
inline const char* CublasGetErrorString(cublasStatus_t error) {
  switch (error) {
    case CUBLAS_STATUS_SUCCESS:          return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:  return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:     return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:    return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:    return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:    return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:   return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED:    return "CUBLAS_STATUS_NOT_SUPPORTED";
    default:                             break;
  }
  return "Unknown cuBLAS status";
}
89 
/*!
 * \brief Get string representation of cuSOLVER errors.
 * \param error The cuSOLVER status code.
 * \return String representation (the enumerator's own name).
 */
inline const char* CusolverGetErrorString(cusolverStatus_t error) {
  switch (error) {
    case CUSOLVER_STATUS_SUCCESS:          return "CUSOLVER_STATUS_SUCCESS";
    case CUSOLVER_STATUS_NOT_INITIALIZED:  return "CUSOLVER_STATUS_NOT_INITIALIZED";
    case CUSOLVER_STATUS_ALLOC_FAILED:     return "CUSOLVER_STATUS_ALLOC_FAILED";
    case CUSOLVER_STATUS_INVALID_VALUE:    return "CUSOLVER_STATUS_INVALID_VALUE";
    case CUSOLVER_STATUS_ARCH_MISMATCH:    return "CUSOLVER_STATUS_ARCH_MISMATCH";
    case CUSOLVER_STATUS_EXECUTION_FAILED: return "CUSOLVER_STATUS_EXECUTION_FAILED";
    case CUSOLVER_STATUS_INTERNAL_ERROR:   return "CUSOLVER_STATUS_INTERNAL_ERROR";
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    default:                               break;
  }
  return "Unknown cuSOLVER status";
}
118 
/*!
 * \brief Get string representation of cuRAND errors.
 * \param status The cuRAND status code.
 * \return String representation (the enumerator's own name).
 */
inline const char* CurandGetErrorString(curandStatus_t status) {
  // Deliberately no default case: an exhaustive switch lets the compiler
  // warn if a future CUDA toolkit adds a curandStatus_t value.
  switch (status) {
    case CURAND_STATUS_SUCCESS:              return "CURAND_STATUS_SUCCESS";
    case CURAND_STATUS_VERSION_MISMATCH:     return "CURAND_STATUS_VERSION_MISMATCH";
    case CURAND_STATUS_NOT_INITIALIZED:      return "CURAND_STATUS_NOT_INITIALIZED";
    case CURAND_STATUS_ALLOCATION_FAILED:    return "CURAND_STATUS_ALLOCATION_FAILED";
    case CURAND_STATUS_TYPE_ERROR:           return "CURAND_STATUS_TYPE_ERROR";
    case CURAND_STATUS_OUT_OF_RANGE:         return "CURAND_STATUS_OUT_OF_RANGE";
    case CURAND_STATUS_LENGTH_NOT_MULTIPLE:  return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
    case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
      return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
    case CURAND_STATUS_LAUNCH_FAILURE:       return "CURAND_STATUS_LAUNCH_FAILURE";
    case CURAND_STATUS_PREEXISTING_FAILURE:  return "CURAND_STATUS_PREEXISTING_FAILURE";
    case CURAND_STATUS_INITIALIZATION_FAILED:
      return "CURAND_STATUS_INITIALIZATION_FAILED";
    case CURAND_STATUS_ARCH_MISMATCH:        return "CURAND_STATUS_ARCH_MISMATCH";
    case CURAND_STATUS_INTERNAL_ERROR:       return "CURAND_STATUS_INTERNAL_ERROR";
  }
  return "Unknown cuRAND status";
}
155 
/*!
 * \brief Device-side maximum of two values.
 *
 * Equivalent to `a > b ? a : b`; when the comparison is false (including a
 * NaN operand for floating types) the second argument is returned.
 */
template <typename DType>
inline DType __device__ CudaMax(DType a, DType b) {
  if (a > b) {
    return a;
  }
  return b;
}
160 
/*!
 * \brief Device-side minimum of two values.
 *
 * Equivalent to `a < b ? a : b`; when the comparison is false (including a
 * NaN operand for floating types) the second argument is returned.
 */
template <typename DType>
inline DType __device__ CudaMin(DType a, DType b) {
  if (a < b) {
    return a;
  }
  return b;
}
165 
166 } // namespace cuda
167 } // namespace common
168 } // namespace mxnet
169 
/*!
 * \brief Check for a pending CUDA error from earlier launches/API calls.
 * \param msg Message printed (ahead of the CUDA error string) on failure.
 *
 * Wrapped in do { } while (0) so the expansion is a single statement: the
 * previous bare { } form plus the caller's trailing ';' left an empty
 * statement behind, which breaks `if (c) CHECK_CUDA_ERROR(m); else ...`.
 */
#define CHECK_CUDA_ERROR(msg)                                                  \
  do {                                                                         \
    cudaError_t e = cudaGetLastError();                                        \
    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e);   \
  } while (0)
179 
/*!
 * \brief Protected CUDA call: abort (via CHECK) if the call fails.
 * \param func Expression yielding a cudaError_t.
 *
 * cudaErrorCudartUnloading is tolerated so calls made while the CUDA runtime
 * is being torn down at process exit do not abort cleanup.
 * do { } while (0) makes the macro a single statement, safe under if/else.
 */
#define CUDA_CALL(func)                                        \
  do {                                                         \
    cudaError_t e = (func);                                    \
    CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading)   \
        << "CUDA: " << cudaGetErrorString(e);                  \
  } while (0)
192 
/*!
 * \brief Protected cuBLAS call: abort (via CHECK) if the call fails.
 * \param func Expression yielding a cublasStatus_t.
 *
 * do { } while (0) makes the macro a single statement, safe under if/else.
 */
#define CUBLAS_CALL(func)                                      \
  do {                                                         \
    cublasStatus_t e = (func);                                 \
    CHECK_EQ(e, CUBLAS_STATUS_SUCCESS)                         \
        << "cuBLAS: " << common::cuda::CublasGetErrorString(e); \
  } while (0)
205 
/*!
 * \brief Protected cuSOLVER call: abort (via CHECK) if the call fails.
 * \param func Expression yielding a cusolverStatus_t.
 *
 * do { } while (0) makes the macro a single statement, safe under if/else.
 */
#define CUSOLVER_CALL(func)                                        \
  do {                                                             \
    cusolverStatus_t e = (func);                                   \
    CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS)                           \
        << "cuSolver: " << common::cuda::CusolverGetErrorString(e); \
  } while (0)
218 
/*!
 * \brief Protected cuRAND call: abort (via CHECK) if the call fails.
 * \param func Expression yielding a curandStatus_t.
 *
 * do { } while (0) makes the macro a single statement, safe under if/else.
 */
#define CURAND_CALL(func)                                        \
  do {                                                           \
    curandStatus_t e = (func);                                   \
    CHECK_EQ(e, CURAND_STATUS_SUCCESS)                           \
        << "cuRAND: " << common::cuda::CurandGetErrorString(e);  \
  } while (0)
231 
// Loop-unrolling hints for device loops. MSVC's preprocessor does not accept
// the C99 _Pragma operator, so on MSVC the hints expand to nothing.
#if !defined(_MSC_VER)
#define CUDA_UNROLL _Pragma("unroll")
#define CUDA_NOUNROLL _Pragma("nounroll")
#else
#define CUDA_UNROLL
#define CUDA_NOUNROLL
#endif
239 
/*!
 * \brief Determine major version number of the gpu's cuda compute architecture.
 * \param device_id The device index of the cuda-capable gpu of interest.
 * \return the major version number of the gpu's cuda compute architecture.
 */
inline int ComputeCapabilityMajor(int device_id) {
  int cc_major = 0;
  CUDA_CALL(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor,
                                   device_id));
  return cc_major;
}
251 
/*!
 * \brief Determine minor version number of the gpu's cuda compute architecture.
 * \param device_id The device index of the cuda-capable gpu of interest.
 * \return the minor version number of the gpu's cuda compute architecture.
 */
inline int ComputeCapabilityMinor(int device_id) {
  int cc_minor = 0;
  CUDA_CALL(cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor,
                                   device_id));
  return cc_minor;
}
263 
/*!
 * \brief Return the integer SM architecture (e.g. Pascal GTX 1080 = 61).
 * \param device_id The device index of the cuda-capable gpu of interest.
 * \return the gpu's cuda compute architecture as an int.
 */
inline int SMArch(int device_id) {
  const int major = ComputeCapabilityMajor(device_id);
  const int minor = ComputeCapabilityMinor(device_id);
  return major * 10 + minor;
}
274 
/*!
 * \brief Determine whether a cuda-capable gpu's architecture supports float16 math.
 * \param device_id The device index of the cuda-capable gpu of interest.
 * \return whether the gpu's architecture supports float16 math.
 */
inline bool SupportsFloat16Compute(int device_id) {
  const int major = ComputeCapabilityMajor(device_id);
  const int minor = ComputeCapabilityMinor(device_id);
  // Kepler and most Maxwell GPUs lack fp16 compute; it appears at sm_53.
  return major > 5 || (major == 5 && minor >= 3);
}
287 
/*!
 * \brief Determine whether a cuda-capable gpu's architecture supports Tensor Core math.
 * \param device_id The device index of the cuda-capable gpu of interest.
 * \return whether the gpu's architecture supports Tensor Core math.
 */
inline bool SupportsTensorCore(int device_id) {
  // TensorCore algos are available starting with Volta (sm_70).
  return ComputeCapabilityMajor(device_id) >= 7;
}
298 
299 // The policy if the user hasn't set the environment variable MXNET_CUDA_ALLOW_TENSOR_CORE
300 #define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true
301 
/*!
 * \brief Returns global policy for TensorCore algo use.
 * \return whether to allow TensorCore algo (if not specified by the user).
 */
inline bool GetEnvAllowTensorCore() {
  // Use of optional<bool> here permits: "0", "1", "true" and "false" to all be legal.
  const dmlc::optional<bool> fallback(MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT);
  return dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE", fallback).value();
}
312 #endif // MXNET_USE_CUDA
313 
314 #if MXNET_USE_CUDNN
315 
316 #include <cudnn.h>
317 
/*!
 * \brief Protected cuDNN call: abort (via CHECK) if the call fails.
 * \param func Expression yielding a cudnnStatus_t.
 *
 * do { } while (0) makes the macro a single statement: the previous bare
 * { } form plus the caller's trailing ';' broke if/else composition.
 */
#define CUDNN_CALL(func)                                                       \
  do {                                                                         \
    cudnnStatus_t e = (func);                                                  \
    CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e);  \
  } while (0)
323 
/*!
 * \brief Return max number of perf structs cudnnFindConvolutionForwardAlgorithm()
 *        may want to populate.
 * \param cudnn_handle cuDNN handle to query.
 * \return the max number of forward-convolution algos.
 */
inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR >= 7
  // cuDNN 7+ can report the exact count.
  int algo_count = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &algo_count));
  return algo_count;
#else
  // Older cuDNN has no query API; use a generous fixed upper bound.
  return 10;
#endif
}
340 
/*!
 * \brief Return max number of perf structs cudnnFindConvolutionBackwardFilterAlgorithm()
 *        may want to populate.
 * \param cudnn_handle cuDNN handle to query.
 * \return the max number of backward-filter convolution algos.
 */
inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR >= 7
  // cuDNN 7+ can report the exact count.
  int algo_count = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &algo_count));
  return algo_count;
#else
  // Older cuDNN has no query API; use a generous fixed upper bound.
  return 10;
#endif
}
357 
/*!
 * \brief Return max number of perf structs cudnnFindConvolutionBackwardDataAlgorithm()
 *        may want to populate.
 * \param cudnn_handle cuDNN handle to query.
 * \return the max number of backward-data convolution algos.
 */
inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
#if CUDNN_MAJOR >= 7
  // cuDNN 7+ can report the exact count.
  int algo_count = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &algo_count));
  return algo_count;
#else
  // Older cuDNN has no query API; use a generous fixed upper bound.
  return 10;
#endif
}
374 
375 #endif // MXNET_USE_CUDNN
376 
377 // Overload atomicAdd to work for floats on all architectures
378 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
379 // From CUDA Programming Guide
// Software atomicAdd for double: hardware support only exists on sm_60+,
// so emulate with a 64-bit compare-and-swap loop (per the guard above, this
// overload is compiled only for __CUDA_ARCH__ < 600).
// Note: unlike the built-in atomicAdd, this overload returns void.
static inline __device__ void atomicAdd(double *address, double val) {
  // Reinterpret the double's bits as a 64-bit integer, since atomicCAS
  // operates on integer words.
  unsigned long long* address_as_ull =              // NOLINT(*)
    reinterpret_cast<unsigned long long*>(address); // NOLINT(*)
  unsigned long long old = *address_as_ull;         // NOLINT(*)
  unsigned long long assumed;                       // NOLINT(*)

  do {
    assumed = old;
    // Try to swap in (assumed + val); atomicCAS returns the word it found,
    // so a mismatch means another thread won the race and we retry.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val +
                    __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);
}
395 #endif
396 
397 // Overload atomicAdd for half precision
398 // Taken from:
399 // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
400 #if defined(__CUDA_ARCH__)
// Software atomicAdd for 16-bit half: atomicCAS only works on 32-bit words,
// so read-modify-write the aligned 4-byte word containing the half and
// splice the updated 16 bits back into the correct half of that word.
// Note: unlike the built-in atomicAdd overloads, this returns void.
static inline __device__ void atomicAdd(mshadow::half::half_t *address,
                                        mshadow::half::half_t val) {
  // Round the address down to its enclosing 4-byte boundary:
  // (address & 2) is the byte offset of the half within the 32-bit word.
  unsigned int *address_as_ui =
    reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) -
                                     (reinterpret_cast<size_t>(address) & 2));
  unsigned int old = *address_as_ui;
  unsigned int assumed;

  do {
    assumed = old;
    mshadow::half::half_t hsum;
    // Extract the target half: high 16 bits if the address had offset 2,
    // else the low 16 bits.
    hsum.half_ =
      reinterpret_cast<size_t>(address) & 2 ? (old >> 16) : (old & 0xffff);
    hsum += val;
    // Rebuild the full 32-bit word with the updated half in place and the
    // neighbouring half untouched.
    old = reinterpret_cast<size_t>(address) & 2
          ? (old & 0xffff) | (hsum.half_ << 16)
          : (old & 0xffff0000) | hsum.half_;
    // Retry until no other thread has modified the word in between.
    old = atomicCAS(address_as_ui, assumed, old);
  } while (assumed != old);
}
421 
// Load through the read-only data cache where available: __ldg requires
// sm_35+, otherwise fall back to a plain dereference.
template <typename DType>
__device__ inline DType ldg(const DType* address) {
#if __CUDA_ARCH__ >= 350
  return __ldg(address);
#else
  return *address;
#endif
}
430 #endif
431 
432 #endif // MXNET_COMMON_CUDA_UTILS_H_
int ComputeCapabilityMajor(int device_id)
Determine major version number of the gpu's cuda compute architecture.
Definition: cuda_utils.h:245
namespace of mxnet
Definition: base.h:126
bool GetEnvAllowTensorCore()
Returns global policy for TensorCore algo use.
Definition: cuda_utils.h:306
int SMArch(int device_id)
Return the integer SM architecture (e.g. Volta = 70).
Definition: cuda_utils.h:269
DType __device__ CudaMin(DType a, DType b)
Definition: cuda_utils.h:162
bool SupportsFloat16Compute(int device_id)
Determine whether a cuda-capable gpu's architecture supports float16 math.
Definition: cuda_utils.h:280
DType __device__ CudaMax(DType a, DType b)
Definition: cuda_utils.h:157
bool SupportsTensorCore(int device_id)
Determine whether a cuda-capable gpu's architecture supports Tensor Core math.
Definition: cuda_utils.h:293
const char * CusolverGetErrorString(cusolverStatus_t error)
Get string representation of cuSOLVER errors.
Definition: cuda_utils.h:95
#define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT
Definition: cuda_utils.h:300
const char * CurandGetErrorString(curandStatus_t status)
Get string representation of cuRAND errors.
Definition: cuda_utils.h:124
int ComputeCapabilityMinor(int device_id)
Determine minor version number of the gpu's cuda compute architecture.
Definition: cuda_utils.h:257
#define CUDA_CALL(func)
Protected CUDA call.
Definition: cuda_utils.h:186
const char * CublasGetErrorString(cublasStatus_t error)
Get string representation of cuBLAS errors.
Definition: cuda_utils.h:64