mxnet
stream_gpu-inl.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_STREAM_GPU_INL_H_
8 #define MSHADOW_STREAM_GPU_INL_H_
9 #include <memory>
10 #include "./base.h"
11 #include "./tensor.h"
12 #include "./logging.h"
13 
14 namespace mshadow {
15 #if MSHADOW_USE_CUDA == 1
16 // Stream allocation
17 // actual implementation of GPU stream in CUDA
18 template<>
19 struct Stream<gpu> {
21  enum HandleState {
22  NoHandle = 0,
23  OwnHandle = 1,
24  };
26  cudaStream_t stream_;
28  cublasHandle_t blas_handle_;
30  #if MSHADOW_USE_CUSOLVER == 1
31  cusolverDnHandle_t solver_handle_;
32  #endif
33 
34  #if MSHADOW_USE_CUDNN == 1
35  cudnnHandle_t dnn_handle_;
36  #endif
37 
44  cudaDeviceProp prop;
46  int dev_id;
47 
48  Stream(void)
49  : stream_(0)
50  , blas_handle_(0)
51 #if MSHADOW_USE_CUDNN == 1
52  , dnn_handle_(0)
53 #endif
54  , blas_handle_ownership_(NoHandle)
55  , solver_handle_ownership_(NoHandle)
56  , dnn_handle_ownership_(NoHandle) {}
61  inline void Wait(void) {
62  MSHADOW_CUDA_CALL(cudaStreamSynchronize(stream_));
63  }
68  inline bool CheckIdle(void) {
69  cudaError_t err = cudaStreamQuery(stream_);
70  if (err == cudaSuccess) return true;
71  if (err == cudaErrorNotReady) return false;
72  LOG(FATAL) << cudaGetErrorString(err);
73  return false;
74  }
79  inline static cudaStream_t GetStream(Stream<gpu> *stream) {
80  if (stream == NULL) {
81 #if MSHADOW_FORCE_STREAM
82  LOG(FATAL) << "Default GPU stream was used when MSHADOW_FORCE_STREAM was on";
83 #endif
84  return 0;
85  } else {
86  return stream->stream_;
87  }
88  }
93  inline static cublasHandle_t GetBlasHandle(Stream<gpu> *stream) {
94  if (stream == NULL) {
95  return 0;
96  } else {
97  CHECK_NE(stream->blas_handle_ownership_, NoHandle)
98  << "No handle exist in source stream";
99  return stream->blas_handle_;
100  }
101  }
103  inline void DestroyBlasHandle() {
104  if (blas_handle_ownership_ == OwnHandle) {
105  cublasStatus_t err = cublasDestroy(blas_handle_);
106  blas_handle_ownership_ = NoHandle;
107  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Destory cublas handle failed";
108  }
109  }
111  inline void CreateBlasHandle() {
112  this->DestroyBlasHandle();
113  cublasStatus_t err = cublasCreate(&blas_handle_);
114  blas_handle_ownership_ = OwnHandle;
115  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Create cublas handle failed";
116  err = cublasSetStream(blas_handle_, stream_);
117  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Setting cublas stream failed";
118  }
119 #if MSHADOW_USE_CUSOLVER == 1
120  inline static cusolverDnHandle_t GetSolverHandle(Stream<gpu> *stream) {
121  if (stream == NULL) {
122  return 0;
123  } else {
124  CHECK_NE(stream->solver_handle_ownership_, NoHandle) << "No handle exist in source stream";
125  return stream->solver_handle_;
126  }
127  }
128 #endif
129  inline void DestroySolverHandle() {
130 #if MSHADOW_USE_CUSOLVER == 1
131  if (solver_handle_ownership_ == OwnHandle) {
132  cusolverStatus_t err = cusolverDnDestroy(solver_handle_);
133  CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) << "Destory cusolver handle failed";
134  }
135 #endif
136  }
137  inline void CreateSolverHandle() {
138 #if MSHADOW_USE_CUSOLVER == 1
139  this->DestroySolverHandle();
140  cusolverStatus_t err = cusolverDnCreate(&solver_handle_);
141  CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) << "Create cusolver handle failed";
142  err = cusolverDnSetStream(solver_handle_, stream_);
143  CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) << "Setting cusolver stream failed";
144  this->solver_handle_ownership_ = OwnHandle;
145 #endif
146  }
147 // #if MSHADOW_USE_CUDNN && defined(__CUDACC__)
148 #if MSHADOW_USE_CUDNN == 1
149  inline static cudnnHandle_t GetDnnHandle(Stream<gpu> *stream) {
150  if (stream == NULL) {
151  return 0;
152  } else {
153  CHECK_NE(stream->dnn_handle_ownership_, NoHandle) << "No handle exist in source stream";
154  return stream->dnn_handle_;
155  }
156  }
157 #endif
158  inline void DestroyDnnHandle() {
159 // #if MSHADOW_USE_CUDNN && defined(__CUDACC__)
160 #if MSHADOW_USE_CUDNN == 1
161  if (dnn_handle_ownership_ == OwnHandle) {
162  cudnnStatus_t err = cudnnDestroy(dnn_handle_);
163  this->dnn_handle_ownership_ = NoHandle;
164  CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err);
165  }
166 #endif
167  }
168  inline void CreateDnnHandle() {
169 // #if MSHADOW_USE_CUDNN == 1 && defined(__CUDACC__)
170 #if MSHADOW_USE_CUDNN == 1
171  this->DestroyDnnHandle();
172  cudnnStatus_t err = cudnnCreate(&dnn_handle_);
173  CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err);
174  // At this point, we have the resource which may need to be freed
175  this->dnn_handle_ownership_ = OwnHandle;
176  err = cudnnSetStream(dnn_handle_, stream_);
177  CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err);
178 #endif
179  }
180 };
181 template<>
182 inline void DeleteStream<gpu>(Stream<gpu> *stream) {
183  if (stream) {
184  MSHADOW_CUDA_CALL(cudaStreamDestroy(stream->stream_));
185  stream->DestroyBlasHandle();
186  stream->DestroySolverHandle();
187  stream->DestroyDnnHandle();
188  delete stream;
189  }
190 }
191 template<>
192 inline Stream<gpu> *NewStream<gpu>(bool create_blas_handle,
193  bool create_dnn_handle,
194  int dev_id) {
195  // RAII on Cuda exception
196  struct StreamDeleter { void operator()(Stream<gpu> *ptr) const { DeleteStream<gpu>(ptr); } };
197  std::unique_ptr<Stream<gpu>, StreamDeleter> st(new Stream<gpu>());
198  MSHADOW_CUDA_CALL(cudaStreamCreate(&st->stream_));
199  if (create_blas_handle) {
200  st->CreateBlasHandle();
201  st->CreateSolverHandle();
202  }
203  if (create_dnn_handle) {
204  st->CreateDnnHandle();
205  }
206  st->dev_id = dev_id;
207  if (dev_id != -1) {
208  MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&st->prop, dev_id));
209  }
210  return st.release();
211 }
212 #endif
213 } // namespace mshadow
214 #endif // MSHADOW_STREAM_GPU_INL_H_
static cudaStream_t GetStream(Stream< gpu > *stream)
returns actual cudaStream_t given an input GPU stream pointer
Definition: stream_gpu-inl.h:79
HandleState dnn_handle_ownership_
cudnn handle ownership
Definition: stream_gpu-inl.h:42
Definition: stream_gpu-inl.h:19
bool CheckIdle(void)
query whether the stream is idle
Definition: stream_gpu-inl.h:68
static cusolverDnHandle_t GetSolverHandle(Stream< gpu > *stream)
Definition: stream_gpu-inl.h:120
HandleState
handle state
Definition: stream_gpu-inl.h:21
Stream(void)
Definition: stream_gpu-inl.h:48
Stream< gpu > * NewStream< gpu >(bool create_blas_handle, bool create_dnn_handle, int dev_id)
Definition: stream_gpu-inl.h:192
void DestroySolverHandle()
Definition: stream_gpu-inl.h:129
#define MSHADOW_CUDA_CALL(func)
Protected cuda call in mshadow.
Definition: base.h:252
cudaDeviceProp prop
cudaDeviceProp
Definition: stream_gpu-inl.h:44
device name GPU
Definition: tensor.h:28
HandleState blas_handle_ownership_
cublas handle ownership
Definition: stream_gpu-inl.h:38
int dev_id
dev id
Definition: stream_gpu-inl.h:46
HandleState solver_handle_ownership_
cusolver handle ownership
Definition: stream_gpu-inl.h:40
void CreateBlasHandle()
Destroy original blas handle and create a new one.
Definition: stream_gpu-inl.h:111
cudaStream_t stream_
cudaStream
Definition: stream_gpu-inl.h:26
cublasHandle_t blas_handle_
cublas handle
Definition: stream_gpu-inl.h:28
void DestroyDnnHandle()
Definition: stream_gpu-inl.h:158
void Wait(void)
wait for all the computation associated with this stream to complete
Definition: stream_gpu-inl.h:61
static cublasHandle_t GetBlasHandle(Stream< gpu > *stream)
return actual cublasHandle
Definition: stream_gpu-inl.h:93
namespace for mshadow
Definition: base.h:282
void CreateDnnHandle()
Definition: stream_gpu-inl.h:168
void DestroyBlasHandle()
Destroy the cublas handle if this stream owns it.
Definition: stream_gpu-inl.h:103
cusolverDnHandle_t solver_handle_
cusolver handle
Definition: stream_gpu-inl.h:31
#define MSHADOW_USE_CUDNN
use CUDNN support, must ensure that the cudnn include path is correct
Definition: base.h:103
void CreateSolverHandle()
Definition: stream_gpu-inl.h:137
void DeleteStream< gpu >(Stream< gpu > *stream)
Definition: stream_gpu-inl.h:182
computation stream structure, used for asynchronous computations
Definition: tensor.h:365