25 #ifndef MSHADOW_STREAM_GPU_INL_H_
26 #define MSHADOW_STREAM_GPU_INL_H_
30 #include "dmlc/logging.h"
33 #if MSHADOW_USE_CUDA == 1
#if MSHADOW_USE_CUSOLVER == 1
  /*! \brief cusolver handle, valid only when solver_handle_ownership_ == OwnHandle */
  cusolverDnHandle_t solver_handle_;
#endif
#if MSHADOW_USE_CUDNN == 1
  /*! \brief cudnn handle, valid only when dnn_handle_ownership_ == OwnHandle */
  cudnnHandle_t dnn_handle_;
#endif
#if MSHADOW_USE_CUTENSOR == 1
  /*! \brief cutensor handle, valid only when cutensor_handle_ownership_ == OwnHandle */
  cutensorHandle_t cutensor_handle_;
#endif
  /*! \brief malloc'd storage for the cutensor plan cache; owned, freed in
   *  DestroyCuTensorHandle(). nullptr when no cachelines are attached. */
  void* cutensor_cachelines_ = nullptr;
80 , blas_handle_ownership_(NoHandle)
81 , solver_handle_ownership_(NoHandle)
82 , dnn_handle_ownership_(NoHandle)
83 , cutensor_handle_ownership_(NoHandle)
84 , cutensor_cachelines_(nullptr){}
97 cudaError_t err = cudaStreamQuery(stream_);
98 if (err == cudaSuccess)
return true;
99 if (err == cudaErrorNotReady)
return false;
100 LOG(FATAL) << cudaGetErrorString(err);
108 if (stream == NULL) {
109 #if MSHADOW_FORCE_STREAM
110 LOG(FATAL) <<
"Default GPU stream was used when MSHADOW_FORCE_STREAM was on";
122 if (stream == NULL) {
126 <<
"No handle exist in source stream";
132 if (blas_handle_ownership_ == OwnHandle) {
133 cublasStatus_t err = cublasDestroy(blas_handle_);
134 blas_handle_ownership_ = NoHandle;
135 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Destory cublas handle failed";
140 this->DestroyBlasHandle();
141 cublasStatus_t err = cublasCreate(&blas_handle_);
142 blas_handle_ownership_ = OwnHandle;
143 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Create cublas handle failed";
144 err = cublasSetStream(blas_handle_, stream_);
145 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Setting cublas stream failed";
#if MSHADOW_USE_CUSOLVER == 1
  /*!
   * \brief Return the cusolver handle owned by a stream.
   * \param stream GPU stream pointer; NULL yields a null handle.
   */
  // NOTE(review): body below the NULL check was not visible in this chunk;
  // reconstructed parallel to GetBlasHandle/GetDnnHandle — confirm upstream.
  inline static cusolverDnHandle_t GetSolverHandle(Stream<gpu> *stream) {
    if (stream == NULL) {
      return 0;
    } else {
      CHECK_NE(stream->solver_handle_ownership_, NoHandle)
          << "No handle exist in source stream";
      return stream->solver_handle_;
    }
  }
#endif
158 #if MSHADOW_USE_CUSOLVER == 1
159 if (solver_handle_ownership_ == OwnHandle) {
160 cusolverStatus_t err = cusolverDnDestroy(solver_handle_);
161 CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) <<
"Destory cusolver handle failed";
166 #if MSHADOW_USE_CUSOLVER == 1
167 this->DestroySolverHandle();
168 cusolverStatus_t err = cusolverDnCreate(&solver_handle_);
169 CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) <<
"Create cusolver handle failed";
170 err = cusolverDnSetStream(solver_handle_, stream_);
171 CHECK_EQ(err, CUSOLVER_STATUS_SUCCESS) <<
"Setting cusolver stream failed";
172 this->solver_handle_ownership_ = OwnHandle;
#if MSHADOW_USE_CUDNN == 1
  /*!
   * \brief Return the cudnn handle owned by a stream.
   * \param stream GPU stream pointer; NULL yields a null handle.
   */
  // NOTE(review): the branch between the NULL check and the return was not
  // visible in this chunk; reconstructed parallel to GetBlasHandle.
  inline static cudnnHandle_t GetDnnHandle(Stream<gpu> *stream) {
    if (stream == NULL) {
      return 0;
    } else {
      CHECK_NE(stream->dnn_handle_ownership_, NoHandle)
          << "No handle exist in source stream";
      return stream->dnn_handle_;
    }
  }
#endif
188 #if MSHADOW_USE_CUDNN == 1
189 if (dnn_handle_ownership_ == OwnHandle) {
190 cudnnStatus_t err = cudnnDestroy(dnn_handle_);
191 this->dnn_handle_ownership_ = NoHandle;
192 CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err);
198 #if MSHADOW_USE_CUDNN == 1
199 this->DestroyDnnHandle();
200 cudnnStatus_t err = cudnnCreate(&dnn_handle_);
201 CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err);
203 this->dnn_handle_ownership_ = OwnHandle;
204 err = cudnnSetStream(dnn_handle_, stream_);
205 CHECK_EQ(err, CUDNN_STATUS_SUCCESS) << cudnnGetErrorString(err);
209 #if MSHADOW_USE_CUTENSOR == 1
210 if (cutensor_handle_ownership_ == OwnHandle) {
212 if (cutensor_cachelines_ !=
nullptr) {
213 cutensorStatus_t err;
214 const char* cacheFilename = getenv(
"MXNET_CUTENSOR_CACHEFILE");
215 if (cacheFilename !=
nullptr) {
216 err = cutensorHandleWriteCacheToFile(&cutensor_handle_, cacheFilename);
217 CHECK_EQ(err, CUTENSOR_STATUS_SUCCESS) << cutensorGetErrorString(err);
219 err = cutensorHandleDetachPlanCachelines(&cutensor_handle_);
220 CHECK_EQ(err, CUTENSOR_STATUS_SUCCESS) << cutensorGetErrorString(err);
221 free(cutensor_cachelines_);
222 cutensor_cachelines_ =
nullptr;
224 this->cutensor_handle_ownership_ = NoHandle;
229 #if MSHADOW_USE_CUTENSOR == 1
230 this->DestroyCuTensorHandle();
231 cutensorStatus_t err = cutensorInit(&cutensor_handle_);
232 CHECK_EQ(err, CUTENSOR_STATUS_SUCCESS) << cutensorGetErrorString(err);
233 const char* cacheFilename = getenv(
"MXNET_CUTENSOR_CACHEFILE");
234 if (cacheFilename !=
nullptr) {
235 constexpr int32_t numCachelines = 1024;
236 size_t sizeCache = numCachelines *
sizeof(cutensorPlanCacheline_t);
237 cutensor_cachelines_ = malloc(sizeCache);
238 err = cutensorHandleAttachPlanCachelines(&cutensor_handle_, (cutensorPlanCacheline_t*) cutensor_cachelines_, numCachelines);
239 CHECK_EQ(err, CUTENSOR_STATUS_SUCCESS) << cutensorGetErrorString(err);
241 uint32_t numCachelinesRead = 0;
242 cutensorStatus_t status = cutensorHandleReadCacheFromFile(&cutensor_handle_, cacheFilename, &numCachelinesRead);
243 if (status == CUTENSOR_STATUS_IO_ERROR) {
244 printf(
"File (%s) doesn't seem to exist.\n", cacheFilename);
245 }
else if (status == CUTENSOR_STATUS_INSUFFICIENT_WORKSPACE) {
246 printf(
"Cannot read cache: Please attach at least %d cachelines to the handle.\n", numCachelinesRead);
250 this->cutensor_handle_ownership_ = OwnHandle;
// NOTE(review): fragment of NewStream<gpu> — the opening of the signature and
// the trailing return/cleanup are not visible in this chunk, so the code is
// left byte-identical. Visible logic: the new Stream<gpu> is held by a
// unique_ptr with a custom deleter (exception-safe until released), BLAS and
// solver handles are created together under create_blas_handle, the cudnn
// handle under create_dnn_handle, and the cutensor handle unconditionally
// when cutensor support is compiled in.
267 bool create_dnn_handle,
271 std::unique_ptr<Stream<gpu>, StreamDeleter> st(
new Stream<gpu>());
273 if (create_blas_handle) {
274 st->CreateBlasHandle();
// presumably solver handles are tied to BLAS creation by design — confirm upstream
275 st->CreateSolverHandle();
277 if (create_dnn_handle) {
278 st->CreateDnnHandle();
280 #if MSHADOW_USE_CUTENSOR == 1
281 st->CreateCuTensorHandle();
291 #endif // MSHADOW_STREAM_GPU_INL_H_