25 #ifndef MSHADOW_DOT_ENGINE_INL_H_
26 #define MSHADOW_DOT_ENGINE_INL_H_
33 #include "./cuda/tensor_gpu-inl.cuh"
34 #endif // #ifdef __CUDACC__
// Forward declaration of GetBatchedView, parameterised on the device tag and
// the element type. No body is attached to this declaration.
//   dst    - array of DType* entries
//   src    - base DType pointer
//   num    - entry count
//   stride - integer stride between entries
//            (presumably the element distance between consecutive views --
//            TODO confirm against the device overloads)
//   stream - stream of the target device
45 template<
typename Device,
typename DType>
46 inline void GetBatchedView(DType **dst, DType *src,
int num,
int stride,
47 Stream<Device> *stream);
48 template<
typename DType>
// Fill loop: entry i of dst receives src advanced by i * stride elements,
// for i in [0, num).
51 for (
int i = 0; i < num; i++) {
52 dst[i] = src + i * stride;
57 template<
typename DType>
58 inline void GetBatchedView(DType **dst, DType *src,
int num,
int stride,
59 Stream<gpu> *stream) {
62 #endif // #ifdef __CUDACC__
68 template<
typename SV,
typename Device,
int ddim,
int ldim,
69 int rdim,
bool ltrans,
bool rtrans,
typename DType>
77 template<
typename Device,
typename DType = default_real_t>
79 inline static bool GetT(
bool t) {
80 return t ? true :
false;
85 bool transa,
bool transb,
86 int m,
int n,
int k, DType alpha,
87 const DType *A,
int lda,
const DType *B,
int ldb,
88 DType beta, DType *C,
int ldc) {
89 LOG(FATAL) <<
"Not implmented!";
92 bool transa,
bool transb,
93 int m,
int n,
int k, DType alpha,
94 const DType *A,
int lda,
const DType *B,
int ldb,
95 DType beta, DType *C,
int ldc,
int batch_count,
97 LOG(FATAL) <<
"Not implmented!";
100 bool trans,
int m,
int n,
101 DType alpha,
const DType *A,
int lda,
102 const DType *X,
int incX,
103 DType beta, DType *Y,
int incY) {
104 LOG(FATAL) <<
"Not implmented!";
107 bool trans,
int m,
int n,
108 DType alpha,
const DType *A,
int lda,
109 const DType *X,
int incX,
110 DType beta, DType *Y,
int incY,
int batch_count) {
111 LOG(FATAL) <<
"Not implmented!";
114 int m,
int n, DType alpha,
115 const DType *X,
int incX,
116 const DType *Y,
int incY, DType *A,
int lda) {
117 LOG(FATAL) <<
"Not implmented!";
120 int m,
int n, DType alpha,
121 const DType *X,
int incX,
122 const DType *Y,
int incY, DType *A,
int lda,
int batch_count) {
123 LOG(FATAL) <<
"Not implmented!";
127 const DType* X,
int incX,
128 const DType* Y,
int incY,
130 LOG(FATAL) <<
"Not implmented!";
134 #if MSHADOW_STAND_ALONE
136 struct BLASEngine<
cpu, float> {
137 inline static bool GetT(
bool t) {
138 return t ? true :
false;
140 inline static void SetStream(Stream<cpu> *stream) {
// gemm for the MSHADOW_STAND_ALONE build, single precision.
// Only the combination alpha == 1 and beta == 0 enters the tensor path below.
// lda, ldb and ldc are not referenced in the lines visible here.
142 inline static void gemm(Stream<cpu> *stream,
143 bool transa,
bool transb,
144 int m,
int n,
int k,
float alpha,
145 const float *A,
int lda,
const float *B,
int ldb,
146 float beta,
float *C,
int ldc) {
147 if (alpha == 1.0f && beta == 0.0f) {
// The operands are swapped: B is wrapped as lhs (using transb) and A as rhs
// (using transa). The C-style casts drop const only to build Tensor views.
148 bool transpose_left = transb;
149 bool transpose_right = transa;
150 Tensor<cpu, 2, float> lhs((
float*)B,
Shape2(transpose_left ? k : n, transpose_left ? n : k));
151 Tensor<cpu, 2, float> rhs((
float*)A,
Shape2(transpose_right ? m : k, transpose_right ? k : m));
152 Tensor<cpu, 2, float> dst(C,
Shape2(m, n));
// Dispatch on the two transpose flags.
// NOTE(review): the branch bodies are not visible here; the two LOG(FATAL)
// calls below presumably cover the both-transposed case and the unsupported
// alpha/beta case -- confirm.
153 if (!transpose_left && !transpose_right) {
155 }
else if (!transpose_left && transpose_right) {
157 }
else if (transpose_left && !transpose_right) {
160 LOG(FATAL) <<
"Not implmented!";
163 LOG(FATAL) <<
"Not implmented!";
167 bool transa,
bool transb,
168 int m,
int n,
int k,
float alpha,
169 const float *A,
int lda,
const float *B,
int ldb,
170 float beta,
float *C,
int ldc,
int batch_count,
// Batched product implemented as batch_count independent gemm calls.
// Entry i reads A at offset i * m * k and B at offset i * k * n, and writes C
// at offset i * m * n; lda/ldb/ldc are forwarded unchanged.
// NOTE(review): the offsets do not involve lda/ldb/ldc, so this assumes the
// batch entries are tightly packed -- confirm with callers.
// NOTE(review): the offsets are computed in int and could overflow for very
// large batches.
172 for (
int i = 0; i < batch_count; ++i) {
173 gemm(stream, transa, transb, m, n, k, alpha,
174 A + i * m * k, lda, B + i * k * n, ldb,
175 beta, C + i * m * n, ldc);
178 inline static void gemv(Stream<cpu> *stream,
179 bool trans,
int m,
int n,
180 float alpha,
const float *A,
int lda,
181 const float *X,
int incX,
182 float beta,
float *Y,
int incY) {
183 LOG(FATAL) <<
"Not implmented!";
186 bool trans,
int m,
int n,
187 float alpha,
const float *A,
int lda,
188 const float *X,
int incX,
189 float beta,
float *Y,
int incY,
int batch_count) {
190 LOG(FATAL) <<
"Not implmented!";
192 inline static void ger(Stream<cpu> *stream,
193 int m,
int n,
float alpha,
194 const float *X,
int incX,
195 const float *Y,
int incY,
float *A,
int lda) {
196 LOG(FATAL) <<
"Not implmented!";
198 inline static void batched_ger(Stream<cpu> *stream,
199 int m,
int n,
float alpha,
200 const float *X,
int incX,
201 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
202 LOG(FATAL) <<
"Not implmented!";
204 inline static void dot(Stream<cpu> *stream,
206 const float* X,
int incX,
207 const float* Y,
int incY,
209 LOG(FATAL) <<
"Not implmented!";
214 struct BLASEngine<
cpu, double> {
215 inline static bool GetT(
bool t) {
216 return t ? true :
false;
218 inline static void SetStream(Stream<cpu> *stream) {
220 inline static void gemm(Stream<cpu> *stream,
221 bool transa,
bool transb,
222 int m,
int n,
int k,
double alpha,
223 const double *A,
int lda,
const double *B,
int ldb,
224 double beta,
double *C,
int ldc) {
225 if (alpha == 1.0f && beta == 0.0f) {
226 bool transpose_left = transb;
227 bool transpose_right = transa;
228 Tensor<cpu, 2, double> lhs((
double*)B,
Shape2(transpose_left ? k : n, transpose_left ? n : k));
229 Tensor<cpu, 2, double> rhs((
double*)A,
Shape2(transpose_right ? m : k, transpose_right ? k : m));
230 Tensor<cpu, 2, double> dst(C,
Shape2(m, n));
231 if (!transpose_left && !transpose_right) {
233 }
else if (!transpose_left && transpose_right) {
235 }
else if (transpose_left && !transpose_right) {
238 LOG(FATAL) <<
"Not implmented!";
241 LOG(FATAL) <<
"Not implmented!";
245 bool transa,
bool transb,
246 int m,
int n,
int k,
double alpha,
247 const double *A,
int lda,
const double *B,
int ldb,
248 double beta,
double *C,
int ldc,
int batch_count,
249 double **workspace) {
250 for (
int i = 0; i < batch_count; ++i) {
251 gemm(stream, transa, transb, m, n, k, alpha,
252 A + i * m * k, lda, B + i * k * n, ldb,
253 beta, C + i * m * n, ldc);
256 inline static void gemv(Stream<cpu> *stream,
257 bool trans,
int m,
int n,
258 double alpha,
const double *A,
int lda,
259 const double *X,
int incX,
260 double beta,
double *Y,
int incY) {
261 LOG(FATAL) <<
"Not implmented!";
264 bool trans,
int m,
int n,
265 double alpha,
const double *A,
int lda,
266 const double *X,
int incX,
267 double beta,
double *Y,
int incY,
int batch_count) {
268 LOG(FATAL) <<
"Not implmented!";
270 inline static void ger(Stream<cpu> *stream,
271 int m,
int n,
double alpha,
272 const double *X,
int incX,
273 const double *Y,
int incY,
double *A,
int lda) {
274 LOG(FATAL) <<
"Not implmented!";
276 inline static void batched_ger(Stream<cpu> *stream,
277 int m,
int n,
double alpha,
278 const double *X,
int incX,
279 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
280 LOG(FATAL) <<
"Not implmented!";
282 inline static void dot(Stream<cpu> *stream,
284 const double* X,
int incX,
285 const double* Y,
int incY,
287 LOG(FATAL) <<
"Not implmented!";
291 #elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*)
294 inline static CBLAS_TRANSPOSE
GetT(
bool t) {
295 return t ? CblasTrans : CblasNoTrans;
300 bool transa,
bool transb,
303 float beta,
float *C,
index_t ldc) {
304 cblas_sgemm(CblasColMajor,
GetT(transa),
GetT(transb),
305 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
308 bool transa,
bool transb,
313 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
315 const int GROUP_SIZE = 1;
316 MKL_INT p_m[GROUP_SIZE] = {
static_cast<MKL_INT
>(m)};
317 MKL_INT p_n[GROUP_SIZE] = {
static_cast<MKL_INT
>(n)};
318 MKL_INT p_k[GROUP_SIZE] = {
static_cast<MKL_INT
>(k)};
319 MKL_INT p_lda[GROUP_SIZE] = {
static_cast<MKL_INT
>(lda)};
320 MKL_INT p_ldb[GROUP_SIZE] = {
static_cast<MKL_INT
>(ldb)};
321 MKL_INT p_ldc[GROUP_SIZE] = {
static_cast<MKL_INT
>(ldc)};
323 float p_alpha[GROUP_SIZE] = {alpha};
324 float p_beta[GROUP_SIZE] = {beta};
326 CBLAS_TRANSPOSE cblas_a_trans =
GetT(transa);
327 CBLAS_TRANSPOSE cblas_b_trans =
GetT(transb);
329 MKL_INT p_group_sizeb[GROUP_SIZE] = {
static_cast<MKL_INT
>(batch_count)};
330 CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
331 CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
333 std::vector<const float*> pp_A(batch_count,
nullptr);
334 std::vector<const float*> pp_B(batch_count,
nullptr);
335 std::vector<float*> pp_C(batch_count,
nullptr);
341 for (
int i = 0; i < batch_count; i++) {
342 pp_A[i] = A + i * m_k;
343 pp_B[i] = B + i * k_n;
344 pp_C[i] = C + i * m_n;
347 cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
348 p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
349 p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
351 for (
int i = 0; i < batch_count; ++i) {
352 gemm(stream, transa, transb, m, n, k, alpha,
353 A + i * m * k, lda, B + i * k * n, ldb,
354 beta, C + i * m * n, ldc);
359 bool trans,
int m,
int n,
360 float alpha,
const float *A,
int lda,
361 const float *X,
int incX,
362 float beta,
float *Y,
int incY) {
363 cblas_sgemv(CblasColMajor,
GetT(trans), m, n, alpha,
364 A, lda, X, incX, beta, Y, incY);
367 bool trans,
int m,
int n,
368 float alpha,
const float *A,
int lda,
369 const float *X,
int incX,
370 float beta,
float *Y,
int incY,
int batch_count) {
// One gemv call per batch entry. Per entry, X advances by
// (trans ? m : n) * incX elements and Y by (trans ? n : m) * incY elements.
// NOTE(review): A advances by m * n per entry while lda is forwarded
// unchanged; batched_ger in this file advances A by lda * n instead. The two
// agree only when lda == m -- confirm which stride callers rely on.
371 for (
int i = 0; i < batch_count; ++i) {
372 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
373 X + i * (trans ? m : n) * incX, incX,
374 beta, Y + i * (trans ? n : m) * incY, incY);
378 int m,
int n,
float alpha,
379 const float *X,
int incX,
380 const float *Y,
int incY,
float *A,
int lda) {
381 cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
384 int m,
int n,
float alpha,
385 const float *X,
int incX,
386 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
387 for (
int i = 0; i < batch_count; ++i) {
388 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
389 A + i * lda * n, lda);
394 const float* X,
int incX,
395 const float* Y,
int incY,
397 *ret = cblas_sdot(n, X, incX, Y, incY);
403 inline static CBLAS_TRANSPOSE
GetT(
bool t) {
404 return t ? CblasTrans : CblasNoTrans;
409 bool transa,
bool transb,
412 double beta,
double *C,
index_t ldc) {
413 cblas_dgemm(CblasColMajor,
GetT(transa),
GetT(transb),
414 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
417 bool transa,
bool transb,
421 double **workspace) {
422 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
424 const int GROUP_SIZE = 1;
425 MKL_INT p_m[GROUP_SIZE] = {
static_cast<MKL_INT
>(m)};
426 MKL_INT p_n[GROUP_SIZE] = {
static_cast<MKL_INT
>(n)};
427 MKL_INT p_k[GROUP_SIZE] = {
static_cast<MKL_INT
>(k)};
428 MKL_INT p_lda[GROUP_SIZE] = {
static_cast<MKL_INT
>(lda)};
429 MKL_INT p_ldb[GROUP_SIZE] = {
static_cast<MKL_INT
>(ldb)};
430 MKL_INT p_ldc[GROUP_SIZE] = {
static_cast<MKL_INT
>(ldc)};
432 double p_alpha[GROUP_SIZE] = {alpha};
433 double p_beta[GROUP_SIZE] = {beta};
435 CBLAS_TRANSPOSE cblas_a_trans =
GetT(transa);
436 CBLAS_TRANSPOSE cblas_b_trans =
GetT(transb);
438 MKL_INT p_group_sizeb[GROUP_SIZE] = {
static_cast<MKL_INT
>(batch_count)};
439 CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
440 CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
442 std::vector<const double*> pp_A(batch_count,
nullptr);
443 std::vector<const double*> pp_B(batch_count,
nullptr);
444 std::vector<double*> pp_C(batch_count,
nullptr);
450 for (
int i = 0; i < batch_count; i++) {
451 pp_A[i] = A + i * m_k;
452 pp_B[i] = B + i * k_n;
453 pp_C[i] = C + i * m_n;
456 cblas_dgemm_batch(CblasColMajor, p_transa, p_transb,
457 p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
458 p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
460 for (
int i = 0; i < batch_count; ++i) {
461 gemm(stream, transa, transb, m, n, k, alpha,
462 A + i * m * k, lda, B + i * k * n, ldb,
463 beta, C + i * m * n, ldc);
468 bool trans,
int m,
int n,
double alpha,
469 const double *A,
int lda,
470 const double *X,
int incX,
471 double beta,
double *Y,
int incY) {
472 cblas_dgemv(CblasColMajor,
GetT(trans), m, n, alpha,
473 A, lda, X, incX, beta, Y, incY);
476 bool trans,
int m,
int n,
477 double alpha,
const double *A,
int lda,
478 const double *X,
int incX,
479 double beta,
double *Y,
int incY,
int batch_count) {
480 for (
int i = 0; i < batch_count; ++i) {
481 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
482 X + i * (trans ? m : n) * incX, incX,
483 beta, Y + i * (trans ? n : m) * incY, incY);
487 int m,
int n,
double alpha,
488 const double *X,
int incX,
489 const double *Y,
int incY,
double *A,
int lda) {
490 cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
493 int m,
int n,
double alpha,
494 const double *X,
int incX,
495 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
496 for (
int i = 0; i < batch_count; ++i) {
497 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
498 A + i * lda * n, lda);
503 const double* X,
int incX,
504 const double* Y,
int incY,
506 *ret = cblas_ddot(n, X, incX, Y, incY);
509 #endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE
515 inline static cublasOperation_t
GetT(
bool t) {
516 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
521 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas set stream fail";
524 bool transa,
bool transb,
525 int m,
int n,
int k, half::half_t alpha,
526 const half::half_t *A,
int lda,
527 const half::half_t *B,
int ldb, half::half_t beta,
528 half::half_t *C,
int ldc) {
// Half-precision gemm. The implementation is selected by CUDA_VERSION:
//   >= 8000         : operands tagged CUDA_R_16F
//   >= 7050, < 8000 : operands tagged CUBLAS_DATA_HALF
//   otherwise       : fatal error, CUDA 7.5 is required
529 #if defined(CUDA_VERSION) && CUDA_VERSION >= 7050
// alpha and beta are converted to float and handed over by address, while
// A, B and C stay 16-bit.
531 float alpha_f = float(alpha);
532 float beta_f = float(beta);
533 #if CUDA_VERSION >= 8000
535 GetT(transa),
GetT(transb), m, n, k, &alpha_f,
536 A, CUDA_R_16F, lda, B, CUDA_R_16F,
537 ldb, &beta_f, C, CUDA_R_16F, ldc);
538 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas SgemmEx fail";
541 GetT(transa),
GetT(transb), m, n, k, &alpha_f,
542 A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF,
543 ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
544 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas SgemmEx fail";
545 #endif // CUDA_VERSION >= 8000
547 LOG(FATAL) <<
"Require CUDA version >= 7.5!";
548 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050
// Half-precision batched gemm.
// workspace is not referenced in the lines visible here.
551 bool transa,
bool transb,
552 int m,
int n,
int k, half::half_t alpha,
553 const half::half_t *A,
int lda,
const half::half_t *B,
int ldb,
554 half::half_t beta, half::half_t *C,
int ldc,
int batch_count,
555 half::half_t **workspace) {
556 #if defined(__CUDACC__) && CUDA_VERSION >= 9000
// Version pair read from the stream's device properties
// (presumably the compute capability -- TODO confirm the type of prop).
557 int major = stream->
prop.major;
558 int minor = stream->
prop.minor;
// The strided-batched half path is taken only for version 5.3 or newer.
560 if ((major > 5) || (major == 5 && minor >= 3)) {
// Reinterpret the half::half_t buffers and scalars as __half.
// NOTE(review): alpha_h and beta_h point at the by-value parameters; this
// relies on half::half_t and __half sharing a representation -- confirm.
561 const __half* A_h =
reinterpret_cast<const __half*
>(A);
562 const __half* B_h =
reinterpret_cast<const __half*
>(B);
563 __half* alpha_h =
reinterpret_cast<__half*
>(&alpha);
564 __half* beta_h =
reinterpret_cast<__half*
>(&beta);
565 __half* C_h =
reinterpret_cast<__half*
>(C);
// C entries are spaced m * n elements apart in the batched call.
567 GetT(transa),
GetT(transb), m, n, k, alpha_h,
570 beta_h, C_h, ldc, m * n,
572 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: HgemmStridedBatched fail";
// Fallback: one gemm call per batch entry, using the same packed offsets
// (i * m * k, i * k * n, i * m * n).
576 for (
int i = 0; i < batch_count; ++i) {
577 gemm(stream, transa, transb, m, n, k, alpha,
578 A + i * m * k, lda, B + i * k * n, ldb,
579 beta, C + i * m * n, ldc);
583 bool trans,
int m,
int n, half::half_t alpha,
584 const half::half_t *A,
int lda,
585 const half::half_t *X,
int incX, half::half_t beta,
586 half::half_t *Y,
int incY) {
587 LOG(FATAL) <<
"Not implmented!";
590 bool trans,
int m,
int n,
591 half::half_t alpha,
const half::half_t *A,
int lda,
592 const half::half_t *X,
int incX,
593 half::half_t beta, half::half_t *Y,
int incY,
int batch_count) {
594 LOG(FATAL) <<
"Not implmented!";
597 int m,
int n, half::half_t alpha,
598 const half::half_t *X,
int incX,
599 const half::half_t *Y,
int incY, half::half_t *A,
int lda) {
600 LOG(FATAL) <<
"Not implmented!";
603 int m,
int n, half::half_t alpha,
604 const half::half_t *X,
int incX,
const half::half_t *Y,
int incY,
605 half::half_t *A,
int lda,
int batch_count) {
606 LOG(FATAL) <<
"Not implmented!";
610 const half::half_t* X,
int incX,
611 const half::half_t* Y,
int incY,
613 LOG(FATAL) <<
"Not implmented!";
619 inline static cublasOperation_t
GetT(
bool t) {
620 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
625 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: set stream fail";
628 bool transa,
bool transb,
629 int m,
int n,
int k,
float alpha,
630 const float *A,
int lda,
631 const float *B,
int ldb,
float beta,
634 GetT(transa),
GetT(transb), m, n, k, &alpha,
635 A, lda, B, ldb, &beta, C, ldc);
636 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sgemm fail";
639 bool transa,
bool transb,
640 int m,
int n,
int k,
float alpha,
641 const float *A,
int lda,
const float *B,
int ldb,
642 float beta,
float *C,
int ldc,
int batch_count,
// Single-precision batched gemm; the implementation is chosen at compile time:
//   CUDA [4010, 8000) : pointer-array call (SgemmBatched) over a workspace
//   CUDA >= 8000      : strided call (SgemmStridedBatched)
//   otherwise         : presumably the per-entry gemm loop at the end
//                       (the #else line is not visible here -- confirm)
644 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000
646 bool alloc_workspace =
false;
647 if (workspace == NULL) {
// No caller-provided workspace: allocate room for 3 * batch_count pointers.
// NOTE(review): the status returned by cudaMalloc is discarded here.
650 cudaMalloc(
reinterpret_cast<void**
>(&workspace), 3 * batch_count *
sizeof(
float*));
651 alloc_workspace =
true;
// Workspace layout: A views in [0, batch_count), B views in
// [batch_count, 2 * batch_count), C views in [2 * batch_count, 3 * batch_count).
// Consecutive views are m * k, k * n and m * n elements apart respectively.
653 GetBatchedView(workspace,
const_cast<float*
>(A), batch_count, m * k, stream);
655 const_cast<float*
>(B), batch_count, k * n, stream);
656 GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
658 GetT(transa),
GetT(transb), m, n, k, &alpha,
659 (
const float**)workspace, lda,
660 (
const float**)(workspace + batch_count), ldb,
661 &beta, workspace + 2 * batch_count, ldc, batch_count);
662 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: SgemmBatched fail";
663 if (alloc_workspace) {
666 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000
668 GetT(transa),
GetT(transb), m, n, k, &alpha,
671 &beta, C, ldc, m * n,
673 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: SgemmStridedBatched fail";
// Per-entry fallback with packed offsets i * m * k, i * k * n and i * m * n.
675 for (
int i = 0; i < batch_count; ++i) {
676 gemm(stream, transa, transb, m, n, k, alpha,
677 A + i * m * k, lda, B + i * k * n, ldb,
678 beta, C + i * m * n, ldc);
680 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
683 bool trans,
int m,
int n,
float alpha,
684 const float *A,
int lda,
685 const float *X,
int incX,
float beta,
686 float *Y,
int incY) {
688 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
689 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sgemv fail";
692 bool trans,
int m,
int n,
693 float alpha,
const float *A,
int lda,
694 const float *X,
int incX,
695 float beta,
float *Y,
int incY,
int batch_count) {
696 for (
int i = 0; i < batch_count; ++i) {
697 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
698 X + i * (trans ? m : n) * incX, incX,
699 beta, Y + i * (trans ? n : m) * incY, incY);
703 int m,
int n,
float alpha,
704 const float *X,
int incX,
705 const float *Y,
int incY,
float *A,
int lda) {
707 m, n, &alpha, X, incX, Y, incY, A, lda);
708 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sger fail";
711 int m,
int n,
float alpha,
712 const float *X,
int incX,
713 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
714 for (
int i = 0; i < batch_count; ++i) {
715 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
716 A + i * lda * n, lda);
721 const float* X,
int incX,
722 const float* Y,
int incY,
725 CUBLAS_POINTER_MODE_DEVICE);
727 n, X, incX, Y, incY, ret);
728 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dot fail";
730 CUBLAS_POINTER_MODE_HOST);
736 inline static cublasOperation_t
GetT(
bool t) {
737 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
742 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: set stream fail";
745 bool transa,
bool transb,
746 int m,
int n,
int k,
double alpha,
747 const double *A,
int lda,
748 const double *B,
int ldb,
749 double beta,
double *C,
int ldc) {
751 GetT(transa),
GetT(transb), m, n, k, &alpha,
752 A, lda, B, ldb, &beta, C, ldc);
753 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dgemm fail";
756 bool transa,
bool transb,
757 int m,
int n,
int k,
double alpha,
758 const double *A,
int lda,
const double *B,
int ldb,
759 double beta,
double *C,
int ldc,
int batch_count,
760 double **workspace) {
761 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000
763 bool alloc_workspace =
false;
764 if (workspace == NULL) {
767 cudaMalloc(
reinterpret_cast<void**
>(&workspace), 3 * batch_count *
sizeof(
double*));
768 alloc_workspace =
true;
770 GetBatchedView(workspace,
const_cast<double*
>(A), batch_count, m * k, stream);
772 const_cast<double*
>(B), batch_count, k * n, stream);
773 GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
775 GetT(transa),
GetT(transb), m, n, k, &alpha,
776 (
const double**)workspace, lda,
777 (
const double**)(workspace + batch_count), ldb,
778 &beta, workspace + 2 * batch_count, ldc, batch_count);
779 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: DgemmBatched fail";
780 if (alloc_workspace) {
783 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000
785 GetT(transa),
GetT(transb), m, n, k, &alpha,
788 &beta, C, ldc, m * n,
790 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: DgemmStridedBatched fail";
792 for (
int i = 0; i < batch_count; ++i) {
793 gemm(stream, transa, transb, m, n, k, alpha,
794 A + i * m * k, lda, B + i * k * n, ldb,
795 beta, C + i * m * n, ldc);
797 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
800 bool trans,
int m,
int n,
double alpha,
801 const double *A,
int lda,
802 const double *X,
int incX,
803 double beta,
double *Y,
int incY) {
805 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
806 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dgemv fail";
809 bool trans,
int m,
int n,
810 double alpha,
const double *A,
int lda,
811 const double *X,
int incX,
812 double beta,
double *Y,
int incY,
int batch_count) {
813 for (
int i = 0; i < batch_count; ++i) {
814 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
815 X + i * (trans ? m : n) * incX, incX,
816 beta, Y + i * (trans ? n : m) * incY, incY);
820 int m,
int n,
double alpha,
821 const double *X,
int incX,
822 const double *Y,
int incY,
double *A,
int lda) {
824 m, n, &alpha, X, incX, Y, incY, A, lda);
825 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dger fail";
828 int m,
int n,
double alpha,
829 const double *X,
int incX,
830 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
831 for (
int i = 0; i < batch_count; ++i) {
832 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
833 A + i * lda * n, lda);
838 const double* X,
int incX,
839 const double* Y,
int incY,
842 CUBLAS_POINTER_MODE_DEVICE);
844 n, X, incX, Y, incY, ret);
845 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dot fail";
847 CUBLAS_POINTER_MODE_HOST);
850 #endif // MSHADOW_USE_CUDA
// DotEngine specialisation with the three dimension arguments fixed to 2:
// a 2-D destination computed from two 2-D operands. Both transpose flags
// remain template parameters.
856 template<
typename SV,
typename xpu,
857 bool transpose_left,
bool transpose_right,
typename DType>
858 struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> {
// The stand-alone build dispatches on the transpose flags directly
// (branch bodies are not visible here).
864 #if MSHADOW_STAND_ALONE
866 if (!transpose_left && !transpose_right) {
868 }
else if (!transpose_left && transpose_right) {
870 }
else if (transpose_left && !transpose_right) {
// Shape check: dst must be sleft[0] x sright[1] and the inner dimensions of
// the two operands must match.
880 CHECK(dst.
size(0) == sleft[0] && dst.
size(1) == sright[1] && sleft[1] == sright[0])
881 <<
"dot-gemm: matrix shape mismatch";
// Arguments are passed with the right operand first and the transpose flags
// swapped (presumably to map row-major tensors onto a column-major BLAS
// interface -- TODO confirm). The scalars are scale * SV::AlphaBLAS() and
// SV::BetaBLAS(), converted to DType.
885 transpose_right , transpose_left,
886 transpose_right ? rhs.
size(0) : rhs.
size(1),
887 transpose_left ? lhs.
size(1) : lhs.
size(0),
888 transpose_right ? rhs.
size(1) : rhs.
size(0),
889 DType(scale * SV::AlphaBLAS()),
892 DType(SV::BetaBLAS()),
896 template<
typename SV,
typename xpu,
bool transpose_right,
typename DType>
897 struct DotEngine<SV, xpu, 1, 1, 2, false, transpose_right, DType> {
907 CHECK(dst.
size(0) == sright[1] && lhs.
size(0) == sright[0])
908 <<
"dot-gemv: matrix shape mismatch"
909 <<
"dst: " << dst.
shape_ <<
"\n"
910 <<
"lhs: " << lhs.
shape_ <<
"\n"
911 <<
"rhs: " << sright <<
"\n";
915 rhs.
size(1), rhs.
size(0), scale * SV::AlphaBLAS(),
917 lhs.
dptr_, 1, SV::BetaBLAS(),
921 template<
typename SV,
typename xpu,
typename DType>
922 struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
932 <<
"dot-ger: matrix shape mismatch"
933 <<
"dst: " << dst.
shape_ <<
"\n"
934 <<
"lhs: " << lhs.
shape_ <<
"\n"
936 if (SV::BetaBLAS() == 0.0f) {
948 #endif // MSHADOW_DOT_ENGINE_INL_H_