26 #ifndef MSHADOW_DOT_ENGINE_INL_H_ 27 #define MSHADOW_DOT_ENGINE_INL_H_ 34 #include "./cuda/tensor_gpu-inl.cuh" 35 #endif // #ifdef __CUDACC__ 46 template<
typename Device,
typename DType>
47 inline void GetBatchedView(DType **dst, DType *src,
int num,
int stride,
48 Stream<Device> *stream);
49 template<
typename DType>
52 for (
int i = 0; i < num; i++) {
53 dst[i] = src + i * stride;
58 template<
typename DType>
59 inline void GetBatchedView(DType **dst, DType *src,
int num,
int stride,
63 #endif // #ifdef __CUDACC__ 69 template<
typename SV,
typename Device,
int ddim,
int ldim,
70 int rdim,
bool ltrans,
bool rtrans,
typename DType>
78 template<
typename Device,
typename DType = default_real_t>
80 inline static bool GetT(
bool t) {
81 return t ?
true :
false;
86 bool transa,
bool transb,
87 int m,
int n,
int k, DType alpha,
88 const DType *A,
int lda,
const DType *B,
int ldb,
89 DType beta, DType *C,
int ldc) {
90 LOG(FATAL) <<
"Not implmented!";
93 bool transa,
bool transb,
94 int m,
int n,
int k, DType alpha,
95 const DType *A,
int lda,
const DType *B,
int ldb,
96 DType beta, DType *C,
int ldc,
int batch_count,
98 LOG(FATAL) <<
"Not implmented!";
101 bool trans,
int m,
int n,
102 DType alpha,
const DType *A,
int lda,
103 const DType *X,
int incX,
104 DType beta, DType *Y,
int incY) {
105 LOG(FATAL) <<
"Not implmented!";
108 bool trans,
int m,
int n,
109 DType alpha,
const DType *A,
int lda,
110 const DType *X,
int incX,
111 DType beta, DType *Y,
int incY,
int batch_count) {
112 LOG(FATAL) <<
"Not implmented!";
115 int m,
int n, DType alpha,
116 const DType *X,
int incX,
117 const DType *Y,
int incY, DType *A,
int lda) {
118 LOG(FATAL) <<
"Not implmented!";
121 int m,
int n, DType alpha,
122 const DType *X,
int incX,
123 const DType *Y,
int incY, DType *A,
int lda,
int batch_count) {
124 LOG(FATAL) <<
"Not implmented!";
128 const DType* X,
int incX,
129 const DType* Y,
int incY,
131 LOG(FATAL) <<
"Not implmented!";
135 #if MSHADOW_STAND_ALONE 138 inline static bool GetT(
bool t) {
139 return t ?
true :
false;
141 inline static void SetStream(
Stream<cpu> *stream) {
144 bool transa,
bool transb,
145 int m,
int n,
int k,
float alpha,
146 const float *A,
int lda,
const float *B,
int ldb,
147 float beta,
float *C,
int ldc) {
148 if (alpha == 1.0f && beta == 0.0f) {
149 bool transpose_left = transb;
150 bool transpose_right = transa;
154 if (!transpose_left && !transpose_right) {
156 }
else if (!transpose_left && transpose_right) {
158 }
else if (transpose_left && !transpose_right) {
161 LOG(FATAL) <<
"Not implmented!";
164 LOG(FATAL) <<
"Not implmented!";
167 inline static void batched_gemm(
Stream<cpu> *stream,
168 bool transa,
bool transb,
169 int m,
int n,
int k,
float alpha,
170 const float *A,
int lda,
const float *B,
int ldb,
171 float beta,
float *C,
int ldc,
int batch_count,
173 for (
int i = 0; i < batch_count; ++i) {
174 gemm(stream, transa, transb, m, n, k, alpha,
175 A + i * m * k, lda, B + i * k * n, ldb,
176 beta, C + i * m * n, ldc);
180 bool trans,
int m,
int n,
181 float alpha,
const float *A,
int lda,
182 const float *X,
int incX,
183 float beta,
float *Y,
int incY) {
184 LOG(FATAL) <<
"Not implmented!";
186 inline static void batched_gemv(
Stream<cpu> *stream,
187 bool trans,
int m,
int n,
188 float alpha,
const float *A,
int lda,
189 const float *X,
int incX,
190 float beta,
float *Y,
int incY,
int batch_count) {
191 LOG(FATAL) <<
"Not implmented!";
194 int m,
int n,
float alpha,
195 const float *X,
int incX,
196 const float *Y,
int incY,
float *A,
int lda) {
197 LOG(FATAL) <<
"Not implmented!";
199 inline static void batched_ger(
Stream<cpu> *stream,
200 int m,
int n,
float alpha,
201 const float *X,
int incX,
202 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
203 LOG(FATAL) <<
"Not implmented!";
207 const float* X,
int incX,
208 const float* Y,
int incY,
210 LOG(FATAL) <<
"Not implmented!";
216 inline static bool GetT(
bool t) {
217 return t ?
true :
false;
219 inline static void SetStream(
Stream<cpu> *stream) {
222 bool transa,
bool transb,
223 int m,
int n,
int k,
double alpha,
224 const double *A,
int lda,
const double *B,
int ldb,
225 double beta,
double *C,
int ldc) {
226 if (alpha == 1.0f && beta == 0.0f) {
227 bool transpose_left = transb;
228 bool transpose_right = transa;
232 if (!transpose_left && !transpose_right) {
234 }
else if (!transpose_left && transpose_right) {
236 }
else if (transpose_left && !transpose_right) {
239 LOG(FATAL) <<
"Not implmented!";
242 LOG(FATAL) <<
"Not implmented!";
245 inline static void batched_gemm(
Stream<cpu> *stream,
246 bool transa,
bool transb,
247 int m,
int n,
int k,
double alpha,
248 const double *A,
int lda,
const double *B,
int ldb,
249 double beta,
double *C,
int ldc,
int batch_count,
250 double **workspace) {
251 for (
int i = 0; i < batch_count; ++i) {
252 gemm(stream, transa, transb, m, n, k, alpha,
253 A + i * m * k, lda, B + i * k * n, ldb,
254 beta, C + i * m * n, ldc);
258 bool trans,
int m,
int n,
259 double alpha,
const double *A,
int lda,
260 const double *X,
int incX,
261 double beta,
double *Y,
int incY) {
262 LOG(FATAL) <<
"Not implmented!";
264 inline static void batched_gemv(
Stream<cpu> *stream,
265 bool trans,
int m,
int n,
266 double alpha,
const double *A,
int lda,
267 const double *X,
int incX,
268 double beta,
double *Y,
int incY,
int batch_count) {
269 LOG(FATAL) <<
"Not implmented!";
272 int m,
int n,
double alpha,
273 const double *X,
int incX,
274 const double *Y,
int incY,
double *A,
int lda) {
275 LOG(FATAL) <<
"Not implmented!";
277 inline static void batched_ger(
Stream<cpu> *stream,
278 int m,
int n,
double alpha,
279 const double *X,
int incX,
280 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
281 LOG(FATAL) <<
"Not implmented!";
285 const double* X,
int incX,
286 const double* Y,
int incY,
288 LOG(FATAL) <<
"Not implmented!";
292 #elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*) 295 inline static CBLAS_TRANSPOSE
GetT(
bool t) {
296 return t ? CblasTrans : CblasNoTrans;
301 bool transa,
bool transb,
302 int m,
int n,
int k,
float alpha,
303 const float *A,
int lda,
const float *B,
int ldb,
304 float beta,
float *C,
int ldc) {
305 cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
306 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
309 bool transa,
bool transb,
310 int m,
int n,
int k,
float alpha,
311 const float *A,
int lda,
const float *B,
int ldb,
312 float beta,
float *C,
int ldc,
int batch_count,
314 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) 316 const int GROUP_SIZE = 1;
317 MKL_INT p_m[GROUP_SIZE] = {m};
318 MKL_INT p_n[GROUP_SIZE] = {n};
319 MKL_INT p_k[GROUP_SIZE] = {k};
320 MKL_INT p_lda[GROUP_SIZE] = {lda};
321 MKL_INT p_ldb[GROUP_SIZE] = {ldb};
322 MKL_INT p_ldc[GROUP_SIZE] = {ldc};
324 float p_alpha[GROUP_SIZE] = {alpha};
325 float p_beta[GROUP_SIZE] = {beta};
327 CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
328 CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
330 MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
331 CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
332 CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
334 std::vector<const float*> pp_A(batch_count,
nullptr);
335 std::vector<const float*> pp_B(batch_count,
nullptr);
336 std::vector<float*> pp_C(batch_count,
nullptr);
342 for (
int i = 0; i < batch_count; i++) {
343 pp_A[i] = A + i * m_k;
344 pp_B[i] = B + i * k_n;
345 pp_C[i] = C + i * m_n;
348 cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
349 p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
350 p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
352 for (
int i = 0; i < batch_count; ++i) {
353 gemm(stream, transa, transb, m, n, k, alpha,
354 A + i * m * k, lda, B + i * k * n, ldb,
355 beta, C + i * m * n, ldc);
360 bool trans,
int m,
int n,
361 float alpha,
const float *A,
int lda,
362 const float *X,
int incX,
363 float beta,
float *Y,
int incY) {
364 cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha,
365 A, lda, X, incX, beta, Y, incY);
368 bool trans,
int m,
int n,
369 float alpha,
const float *A,
int lda,
370 const float *X,
int incX,
371 float beta,
float *Y,
int incY,
int batch_count) {
372 for (
int i = 0; i < batch_count; ++i) {
373 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
374 X + i * (trans ? m : n) * incX, incX,
375 beta, Y + i * (trans ? n : m) * incY, incY);
379 int m,
int n,
float alpha,
380 const float *X,
int incX,
381 const float *Y,
int incY,
float *A,
int lda) {
382 cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
385 int m,
int n,
float alpha,
386 const float *X,
int incX,
387 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
388 for (
int i = 0; i < batch_count; ++i) {
389 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
390 A + i * lda * n, lda);
395 const float* X,
int incX,
396 const float* Y,
int incY,
398 *ret = cblas_sdot(n, X, incX, Y, incY);
404 inline static CBLAS_TRANSPOSE
GetT(
bool t) {
405 return t ? CblasTrans : CblasNoTrans;
410 bool transa,
bool transb,
411 int m,
int n,
int k,
double alpha,
412 const double *A,
int lda,
const double *B,
int ldb,
413 double beta,
double *C,
int ldc) {
414 cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb),
415 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
418 bool transa,
bool transb,
419 int m,
int n,
int k,
double alpha,
420 const double *A,
int lda,
const double *B,
int ldb,
421 double beta,
double *C,
int ldc,
int batch_count,
422 double **workspace) {
423 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) 425 const int GROUP_SIZE = 1;
426 MKL_INT p_m[GROUP_SIZE] = {m};
427 MKL_INT p_n[GROUP_SIZE] = {n};
428 MKL_INT p_k[GROUP_SIZE] = {k};
429 MKL_INT p_lda[GROUP_SIZE] = {lda};
430 MKL_INT p_ldb[GROUP_SIZE] = {ldb};
431 MKL_INT p_ldc[GROUP_SIZE] = {ldc};
433 double p_alpha[GROUP_SIZE] = {alpha};
434 double p_beta[GROUP_SIZE] = {beta};
436 CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
437 CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
439 MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
440 CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
441 CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
443 std::vector<const double*> pp_A(batch_count,
nullptr);
444 std::vector<const double*> pp_B(batch_count,
nullptr);
445 std::vector<double*> pp_C(batch_count,
nullptr);
451 for (
int i = 0; i < batch_count; i++) {
452 pp_A[i] = A + i * m_k;
453 pp_B[i] = B + i * k_n;
454 pp_C[i] = C + i * m_n;
457 cblas_dgemm_batch(CblasColMajor, p_transa, p_transb,
458 p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
459 p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
461 for (
int i = 0; i < batch_count; ++i) {
462 gemm(stream, transa, transb, m, n, k, alpha,
463 A + i * m * k, lda, B + i * k * n, ldb,
464 beta, C + i * m * n, ldc);
469 bool trans,
int m,
int n,
double alpha,
470 const double *A,
int lda,
471 const double *X,
int incX,
472 double beta,
double *Y,
int incY) {
473 cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha,
474 A, lda, X, incX, beta, Y, incY);
477 bool trans,
int m,
int n,
478 double alpha,
const double *A,
int lda,
479 const double *X,
int incX,
480 double beta,
double *Y,
int incY,
int batch_count) {
481 for (
int i = 0; i < batch_count; ++i) {
482 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
483 X + i * (trans ? m : n) * incX, incX,
484 beta, Y + i * (trans ? n : m) * incY, incY);
488 int m,
int n,
double alpha,
489 const double *X,
int incX,
490 const double *Y,
int incY,
double *A,
int lda) {
491 cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
494 int m,
int n,
double alpha,
495 const double *X,
int incX,
496 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
497 for (
int i = 0; i < batch_count; ++i) {
498 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
499 A + i * lda * n, lda);
504 const double* X,
int incX,
505 const double* Y,
int incY,
507 *ret = cblas_ddot(n, X, incX, Y, incY);
510 #endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE 516 inline static cublasOperation_t
GetT(
bool t) {
517 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
522 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas set stream fail";
525 bool transa,
bool transb,
526 int m,
int n,
int k, half::half_t alpha,
527 const half::half_t *A,
int lda,
528 const half::half_t *B,
int ldb, half::half_t beta,
529 half::half_t *C,
int ldc) {
530 #if defined(CUDA_VERSION) && CUDA_VERSION >= 7050 532 float alpha_f = float(alpha);
533 float beta_f = float(beta);
534 #if CUDA_VERSION >= 8000 536 GetT(transa), GetT(transb), m, n, k, &alpha_f,
537 A, CUDA_R_16F, lda, B, CUDA_R_16F,
538 ldb, &beta_f, C, CUDA_R_16F, ldc);
539 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas SgemmEx fail";
542 GetT(transa), GetT(transb), m, n, k, &alpha_f,
543 A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF,
544 ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
545 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas SgemmEx fail";
546 #endif // CUDA_VERSION >= 8000 548 LOG(FATAL) <<
"Require CUDA version >= 7.5!";
549 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050 552 bool transa,
bool transb,
553 int m,
int n,
int k, half::half_t alpha,
554 const half::half_t *A,
int lda,
const half::half_t *B,
int ldb,
555 half::half_t beta, half::half_t *C,
int ldc,
int batch_count,
556 half::half_t **workspace) {
557 #if defined(__CUDACC__) && CUDA_VERSION >= 9000 558 int major = stream->
prop.major;
559 int minor = stream->
prop.minor;
561 if ((major > 5) || (major == 5 && minor >= 3)) {
562 const __half* A_h =
reinterpret_cast<const __half*
>(A);
563 const __half* B_h =
reinterpret_cast<const __half*
>(B);
564 __half* alpha_h =
reinterpret_cast<__half*
>(&alpha);
565 __half* beta_h =
reinterpret_cast<__half*
>(&beta);
566 __half* C_h =
reinterpret_cast<__half*
>(C);
568 GetT(transa), GetT(transb), m, n, k, alpha_h,
571 beta_h, C_h, ldc, m * n,
573 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: HgemmStridedBatched fail";
577 for (
int i = 0; i < batch_count; ++i) {
578 gemm(stream, transa, transb, m, n, k, alpha,
579 A + i * m * k, lda, B + i * k * n, ldb,
580 beta, C + i * m * n, ldc);
584 bool trans,
int m,
int n, half::half_t alpha,
585 const half::half_t *A,
int lda,
586 const half::half_t *X,
int incX, half::half_t beta,
587 half::half_t *Y,
int incY) {
588 LOG(FATAL) <<
"Not implmented!";
591 bool trans,
int m,
int n,
592 half::half_t alpha,
const half::half_t *A,
int lda,
593 const half::half_t *X,
int incX,
594 half::half_t beta, half::half_t *Y,
int incY,
int batch_count) {
595 LOG(FATAL) <<
"Not implmented!";
598 int m,
int n, half::half_t alpha,
599 const half::half_t *X,
int incX,
600 const half::half_t *Y,
int incY, half::half_t *A,
int lda) {
601 LOG(FATAL) <<
"Not implmented!";
604 int m,
int n, half::half_t alpha,
605 const half::half_t *X,
int incX,
const half::half_t *Y,
int incY,
606 half::half_t *A,
int lda,
int batch_count) {
607 LOG(FATAL) <<
"Not implmented!";
611 const half::half_t* X,
int incX,
612 const half::half_t* Y,
int incY,
614 LOG(FATAL) <<
"Not implmented!";
620 inline static cublasOperation_t
GetT(
bool t) {
621 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
626 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: set stream fail";
629 bool transa,
bool transb,
630 int m,
int n,
int k,
float alpha,
631 const float *A,
int lda,
632 const float *B,
int ldb,
float beta,
635 GetT(transa), GetT(transb), m, n, k, &alpha,
636 A, lda, B, ldb, &beta, C, ldc);
637 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sgemm fail";
640 bool transa,
bool transb,
641 int m,
int n,
int k,
float alpha,
642 const float *A,
int lda,
const float *B,
int ldb,
643 float beta,
float *C,
int ldc,
int batch_count,
645 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 647 bool alloc_workspace =
false;
648 if (workspace == NULL) {
651 cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count *
sizeof(
float*));
652 alloc_workspace =
true;
654 GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
656 const_cast<float*>(B), batch_count, k * n, stream);
657 GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
659 GetT(transa), GetT(transb), m, n, k, &alpha,
660 (
const float**)workspace, lda,
661 (
const float**)(workspace + batch_count), ldb,
662 &beta, workspace + 2 * batch_count, ldc, batch_count);
663 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: SgemmBatched fail";
664 if (alloc_workspace) {
667 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000 669 GetT(transa), GetT(transb), m, n, k, &alpha,
672 &beta, C, ldc, m * n,
674 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: SgemmStridedBatched fail";
676 for (
int i = 0; i < batch_count; ++i) {
677 gemm(stream, transa, transb, m, n, k, alpha,
678 A + i * m * k, lda, B + i * k * n, ldb,
679 beta, C + i * m * n, ldc);
681 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 684 bool trans,
int m,
int n,
float alpha,
685 const float *A,
int lda,
686 const float *X,
int incX,
float beta,
687 float *Y,
int incY) {
689 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
690 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sgemv fail";
693 bool trans,
int m,
int n,
694 float alpha,
const float *A,
int lda,
695 const float *X,
int incX,
696 float beta,
float *Y,
int incY,
int batch_count) {
697 for (
int i = 0; i < batch_count; ++i) {
698 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
699 X + i * (trans ? m : n) * incX, incX,
700 beta, Y + i * (trans ? n : m) * incY, incY);
704 int m,
int n,
float alpha,
705 const float *X,
int incX,
706 const float *Y,
int incY,
float *A,
int lda) {
708 m, n, &alpha, X, incX, Y, incY, A, lda);
709 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sger fail";
712 int m,
int n,
float alpha,
713 const float *X,
int incX,
714 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
715 for (
int i = 0; i < batch_count; ++i) {
716 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
717 A + i * lda * n, lda);
722 const float* X,
int incX,
723 const float* Y,
int incY,
726 CUBLAS_POINTER_MODE_DEVICE);
728 n, X, incX, Y, incY, ret);
729 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dot fail";
731 CUBLAS_POINTER_MODE_HOST);
737 inline static cublasOperation_t
GetT(
bool t) {
738 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
743 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: set stream fail";
746 bool transa,
bool transb,
747 int m,
int n,
int k,
double alpha,
748 const double *A,
int lda,
749 const double *B,
int ldb,
750 double beta,
double *C,
int ldc) {
752 GetT(transa), GetT(transb), m, n, k, &alpha,
753 A, lda, B, ldb, &beta, C, ldc);
754 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dgemm fail";
757 bool transa,
bool transb,
758 int m,
int n,
int k,
double alpha,
759 const double *A,
int lda,
const double *B,
int ldb,
760 double beta,
double *C,
int ldc,
int batch_count,
761 double **workspace) {
762 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 764 bool alloc_workspace =
false;
765 if (workspace == NULL) {
768 cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count *
sizeof(
double*));
769 alloc_workspace =
true;
771 GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
773 const_cast<double*>(B), batch_count, k * n, stream);
774 GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
776 GetT(transa), GetT(transb), m, n, k, &alpha,
777 (
const double**)workspace, lda,
778 (
const double**)(workspace + batch_count), ldb,
779 &beta, workspace + 2 * batch_count, ldc, batch_count);
780 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: DgemmBatched fail";
781 if (alloc_workspace) {
784 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000 786 GetT(transa), GetT(transb), m, n, k, &alpha,
789 &beta, C, ldc, m * n,
791 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: DgemmStridedBatched fail";
793 for (
int i = 0; i < batch_count; ++i) {
794 gemm(stream, transa, transb, m, n, k, alpha,
795 A + i * m * k, lda, B + i * k * n, ldb,
796 beta, C + i * m * n, ldc);
798 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 801 bool trans,
int m,
int n,
double alpha,
802 const double *A,
int lda,
803 const double *X,
int incX,
804 double beta,
double *Y,
int incY) {
806 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
807 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dgemv fail";
810 bool trans,
int m,
int n,
811 double alpha,
const double *A,
int lda,
812 const double *X,
int incX,
813 double beta,
double *Y,
int incY,
int batch_count) {
814 for (
int i = 0; i < batch_count; ++i) {
815 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
816 X + i * (trans ? m : n) * incX, incX,
817 beta, Y + i * (trans ? n : m) * incY, incY);
821 int m,
int n,
double alpha,
822 const double *X,
int incX,
823 const double *Y,
int incY,
double *A,
int lda) {
825 m, n, &alpha, X, incX, Y, incY, A, lda);
826 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dger fail";
829 int m,
int n,
double alpha,
830 const double *X,
int incX,
831 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
832 for (
int i = 0; i < batch_count; ++i) {
833 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
834 A + i * lda * n, lda);
839 const double* X,
int incX,
840 const double* Y,
int incY,
843 CUBLAS_POINTER_MODE_DEVICE);
845 n, X, incX, Y, incY, ret);
846 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dot fail";
848 CUBLAS_POINTER_MODE_HOST);
851 #endif // MSHADOW_USE_CUDA 854 return transpose ?
Shape2(shape[1], shape[0]) : shape;
857 template<
typename SV,
typename xpu,
858 bool transpose_left,
bool transpose_right,
typename DType>
859 struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> {
865 #if MSHADOW_STAND_ALONE 867 if (!transpose_left && !transpose_right) {
869 }
else if (!transpose_left && transpose_right) {
871 }
else if (transpose_left && !transpose_right) {
881 CHECK(dst.
size(0) == sleft[0] && dst.
size(1) == sright[1] && sleft[1] == sright[0])
882 <<
"dot-gemm: matrix shape mismatch";
886 transpose_right , transpose_left,
887 transpose_right ? rhs.
size(0) : rhs.
size(1),
888 transpose_left ? lhs.
size(1) : lhs.
size(0),
889 transpose_right ? rhs.
size(1) : rhs.
size(0),
890 DType(scale * SV::AlphaBLAS()),
893 DType(SV::BetaBLAS()),
897 template<
typename SV,
typename xpu,
bool transpose_right,
typename DType>
898 struct DotEngine<SV, xpu, 1, 1, 2, false, transpose_right, DType> {
908 CHECK(dst.
size(0) == sright[1] && lhs.
size(0) == sright[0])
909 <<
"dot-gemv: matrix shape mismatch" 910 <<
"dst: " << dst.
shape_ <<
"\n" 911 <<
"lhs: " << lhs.
shape_ <<
"\n" 912 <<
"rhs: " << sright <<
"\n";
916 rhs.
size(1), rhs.
size(0), scale * SV::AlphaBLAS(),
918 lhs.
dptr_, 1, SV::BetaBLAS(),
922 template<
typename SV,
typename xpu,
typename DType>
923 struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
933 <<
"dot-ger: matrix shape mismatch" 934 <<
"dst: " << dst.
shape_ <<
"\n" 935 <<
"lhs: " << lhs.
shape_ <<
"\n" 937 if (SV::BetaBLAS() == 0.0f) {
949 #endif // MSHADOW_DOT_ENGINE_INL_H_ static void ger(Stream< gpu > *stream, int m, int n, half::half_t alpha, const half::half_t *X, int incX, const half::half_t *Y, int incY, half::half_t *A, int lda)
Definition: dot_engine-inl.h:597
static void batched_gemv(Stream< Device > *stream, bool trans, int m, int n, DType alpha, const DType *A, int lda, const DType *X, int incX, DType beta, DType *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:107
static void batched_gemm(Stream< Device > *stream, bool transa, bool transb, int m, int n, int k, DType alpha, const DType *A, int lda, const DType *B, int ldb, DType beta, DType *C, int ldc, int batch_count, DType **workspace)
Definition: dot_engine-inl.h:92
static void SetStream(Stream< gpu > *stream)
Definition: dot_engine-inl.h:740
static void SetStream(Stream< cpu > *stream)
Definition: dot_engine-inl.h:407
ImplicitGEMMExp< LhsExp, RhsExp, DType > implicit_dot(const Exp< LhsExp, DType, e1 > &lhs, const Exp< RhsExp, DType, e2 > &rhs)
Definition: implicit_gemm.h:65
static void SetStream(Stream< cpu > *stream)
Definition: dot_engine-inl.h:298
static void SetStream(Stream< gpu > *stream)
Definition: dot_engine-inl.h:623
static void batched_gemv(Stream< gpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:692
DType * dptr_
pointer to the data
Definition: tensor.h:435
static void gemv(Stream< gpu > *stream, bool trans, int m, int n, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *X, int incX, half::half_t beta, half::half_t *Y, int incY)
Definition: dot_engine-inl.h:583
Shape< 2 > GetShape(const Shape< 2 > &shape, bool transpose)
Definition: dot_engine-inl.h:853
static void gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
Definition: dot_engine-inl.h:409
static void gemv(Stream< cpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY)
Definition: dot_engine-inl.h:359
static void gemv(Stream< gpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY)
Definition: dot_engine-inl.h:800
static void batched_ger(Stream< cpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda, int batch_count)
Definition: dot_engine-inl.h:493
static void batched_gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc, int batch_count, double **workspace)
Definition: dot_engine-inl.h:756
static void gemv(Stream< cpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY)
Definition: dot_engine-inl.h:468
Definition: stream_gpu-inl.h:38
static void batched_gemv(Stream< gpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:809
static void batched_ger(Stream< cpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda, int batch_count)
Definition: dot_engine-inl.h:384
support for implicit GEMM operation
static void batched_ger(Stream< gpu > *stream, int m, int n, half::half_t alpha, const half::half_t *X, int incX, const half::half_t *Y, int incY, half::half_t *A, int lda, int batch_count)
Definition: dot_engine-inl.h:603
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:437
static void gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc)
Definition: dot_engine-inl.h:300
static cublasOperation_t GetT(bool t)
Definition: dot_engine-inl.h:620
static void batched_gemv(Stream< gpu > *stream, bool trans, int m, int n, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *X, int incX, half::half_t beta, half::half_t *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:590
static void ger(Stream< gpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda)
Definition: dot_engine-inl.h:703
static void batched_ger(Stream< gpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda, int batch_count)
Definition: dot_engine-inl.h:828
DotExp< TA, TB, false, false, DType > dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
dot operator def
Definition: expression.h:241
static void batched_gemv(Stream< cpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:476
static void batched_gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, int batch_count, float **workspace)
Definition: dot_engine-inl.h:639
Definition: dot_engine-inl.h:71
static void batched_ger(Stream< gpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda, int batch_count)
Definition: dot_engine-inl.h:711
cudaDeviceProp prop
cudaDeviceProp
Definition: stream_gpu-inl.h:63
device name GPU
Definition: tensor.h:47
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to 2 dimension, collapse the higher dimensions together
Definition: tensor.h:520
static void batched_ger(Stream< Device > *stream, int m, int n, DType alpha, const DType *X, int incX, const DType *Y, int incY, DType *A, int lda, int batch_count)
Definition: dot_engine-inl.h:120
static cublasOperation_t GetT(bool t)
Definition: dot_engine-inl.h:737
static void Eval(Tensor< xpu, 2, DType > *p_dst, const Tensor< xpu, 2, DType > &lhs, const Tensor< xpu, 2, DType > &rhs, DType scale)
Definition: dot_engine-inl.h:860
static void gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *B, int ldb, half::half_t beta, half::half_t *C, int ldc)
Definition: dot_engine-inl.h:524
static void batched_gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *B, int ldb, half::half_t beta, half::half_t *C, int ldc, int batch_count, half::half_t **workspace)
Definition: dot_engine-inl.h:551
static void batched_gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc, int batch_count, double **workspace)
Definition: dot_engine-inl.h:417
static void gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
Definition: dot_engine-inl.h:745
static CBLAS_TRANSPOSE GetT(bool t)
Definition: dot_engine-inl.h:404
static void dot(Stream< Device > *stream, int n, const DType *X, int incX, const DType *Y, int incY, DType *ret)
Definition: dot_engine-inl.h:126
static void batched_gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, int batch_count, float **workspace)
Definition: dot_engine-inl.h:308
static cublasOperation_t GetT(bool t)
Definition: dot_engine-inl.h:516
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two dimension shape, stride will equal s0
Definition: tensor.h:217
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:44
static void ger(Stream< cpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda)
Definition: dot_engine-inl.h:487
static void ger(Stream< cpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda)
Definition: dot_engine-inl.h:378
void GetBatchedView(DType **dst, DType *src, int num, int stride, Stream< Device > *stream)
CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride.
static void dot(Stream< gpu > *stream, int n, const float *X, int incX, const float *Y, int incY, float *ret)
Definition: dot_engine-inl.h:720
static void Eval(Tensor< xpu, 2, DType > *p_dst, const Tensor< xpu, 1, DType > &lhs, const Tensor< xpu, 1, DType > &rhs, DType scale)
Definition: dot_engine-inl.h:924
static void dot(Stream< cpu > *stream, int n, const float *X, int incX, const float *Y, int incY, float *ret)
Definition: dot_engine-inl.h:393
static void ger(Stream< Device > *stream, int m, int n, DType alpha, const DType *X, int incX, const DType *Y, int incY, DType *A, int lda)
Definition: dot_engine-inl.h:114
static void dot(Stream< gpu > *stream, int n, const double *X, int incX, const double *Y, int incY, double *ret)
Definition: dot_engine-inl.h:837
static void batched_gemv(Stream< cpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:367
static void Eval(Tensor< xpu, 1, DType > *p_dst, const Tensor< xpu, 1, DType > &lhs, const Tensor< xpu, 2, DType > &rhs, DType scale)
Definition: dot_engine-inl.h:899
static void dot(Stream< cpu > *stream, int n, const double *X, int incX, const double *Y, int incY, double *ret)
Definition: dot_engine-inl.h:502
static void ger(Stream< gpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda)
Definition: dot_engine-inl.h:820
overloaded + operator between half_t and bf16_t
Definition: base.h:327
MSHADOW_XINLINE index_t size(int idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor.h:506
static void gemv(Stream< gpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY)
Definition: dot_engine-inl.h:683
static void gemm(Stream< Device > *stream, bool transa, bool transb, int m, int n, int k, DType alpha, const DType *A, int lda, const DType *B, int ldb, DType beta, DType *C, int ldc)
Definition: dot_engine-inl.h:85
TransposeExExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > transpose(const Exp< SrcExp, DType, etype > &src, Shape< ExpInfo< SrcExp >::kDim > axes)
a expression that reshapes a tensor to another shape
Definition: transpose.h:77
index_t stride_
storing the stride information in x dimension this is used to deal with pitch allocation in gpu or ss...
Definition: tensor.h:442
static void SetStream(Stream< gpu > *stream)
Definition: dot_engine-inl.h:519
Definition: dot_engine-inl.h:79
static void dot(Stream< gpu > *stream, int n, const half::half_t *X, int incX, const half::half_t *Y, int incY, half::half_t *ret)
Definition: dot_engine-inl.h:609
general tensor
Definition: tensor.h:421
static void gemv(Stream< Device > *stream, bool trans, int m, int n, DType alpha, const DType *A, int lda, const DType *X, int incX, DType beta, DType *Y, int incY)
Definition: dot_engine-inl.h:100
static void gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc)
Definition: dot_engine-inl.h:628
static void SetStream(Stream< Device > *stream)
Definition: dot_engine-inl.h:83
static bool GetT(bool t)
Definition: dot_engine-inl.h:80
const TransposeExp< Tensor< Device, dimension, DType >, DType > T(void) const
transpose of a matrix
Definition: expression.h:155
Stream< Device > * stream_
stream where the computation lies stream is a device dependency concept where each computation ...
Definition: tensor.h:447
void GetBatchedView(DType **dst, DType *src, int num, int stride, Stream< cpu > *stream)
Definition: dot_engine-inl.h:50
computaion stream structure, used for asynchronous computations
Definition: tensor.h:384
static CBLAS_TRANSPOSE GetT(bool t)
Definition: dot_engine-inl.h:295