7 #ifndef MSHADOW_DOT_ENGINE_INL_H_ 8 #define MSHADOW_DOT_ENGINE_INL_H_ 15 #include "./cuda/tensor_gpu-inl.cuh" 16 #endif // #ifdef __CUDACC__ 27 template<
typename Device,
typename DType>
28 inline void GetBatchedView(DType **dst, DType *src,
int num,
int stride,
29 Stream<Device> *stream);
30 template<
typename DType>
33 for (
int i = 0; i < num; i++) {
34 dst[i] = src + i * stride;
39 template<
typename DType>
40 inline void GetBatchedView(DType **dst, DType *src,
int num,
int stride,
44 #endif // #ifdef __CUDACC__ 50 template<
typename SV,
typename Device,
int ddim,
int ldim,
51 int rdim,
bool ltrans,
bool rtrans,
typename DType>
59 template<
typename Device,
typename DType = default_real_t>
61 inline static bool GetT(
bool t) {
62 return t ?
true :
false;
67 bool transa,
bool transb,
68 int m,
int n,
int k, DType alpha,
69 const DType *A,
int lda,
const DType *B,
int ldb,
70 DType beta, DType *C,
int ldc) {
71 LOG(FATAL) <<
"Not implmented!";
74 bool transa,
bool transb,
75 int m,
int n,
int k, DType alpha,
76 const DType *A,
int lda,
const DType *B,
int ldb,
77 DType beta, DType *C,
int ldc,
int batch_count,
79 LOG(FATAL) <<
"Not implmented!";
82 bool trans,
int m,
int n,
83 DType alpha,
const DType *A,
int lda,
84 const DType *X,
int incX,
85 DType beta, DType *Y,
int incY) {
86 LOG(FATAL) <<
"Not implmented!";
89 bool trans,
int m,
int n,
90 DType alpha,
const DType *A,
int lda,
91 const DType *X,
int incX,
92 DType beta, DType *Y,
int incY,
int batch_count) {
93 LOG(FATAL) <<
"Not implmented!";
96 int m,
int n, DType alpha,
97 const DType *X,
int incX,
98 const DType *Y,
int incY, DType *A,
int lda) {
99 LOG(FATAL) <<
"Not implmented!";
102 int m,
int n, DType alpha,
103 const DType *X,
int incX,
104 const DType *Y,
int incY, DType *A,
int lda,
int batch_count) {
105 LOG(FATAL) <<
"Not implmented!";
109 const DType* X,
int incX,
110 const DType* Y,
int incY,
112 LOG(FATAL) <<
"Not implmented!";
116 #if MSHADOW_STAND_ALONE 119 inline static bool GetT(
bool t) {
120 return t ?
true :
false;
122 inline static void SetStream(
Stream<cpu> *stream) {
125 bool transa,
bool transb,
126 int m,
int n,
int k,
float alpha,
127 const float *A,
int lda,
const float *B,
int ldb,
128 float beta,
float *C,
int ldc) {
129 if (alpha == 1.0f && beta == 0.0f) {
130 bool transpose_left = transb;
131 bool transpose_right = transa;
135 if (!transpose_left && !transpose_right) {
137 }
else if (!transpose_left && transpose_right) {
139 }
else if (transpose_left && !transpose_right) {
142 LOG(FATAL) <<
"Not implmented!";
145 LOG(FATAL) <<
"Not implmented!";
148 inline static void batched_gemm(
Stream<cpu> *stream,
149 bool transa,
bool transb,
150 int m,
int n,
int k,
float alpha,
151 const float *A,
int lda,
const float *B,
int ldb,
152 float beta,
float *C,
int ldc,
int batch_count,
154 for (
int i = 0; i < batch_count; ++i) {
155 gemm(stream, transa, transb, m, n, k, alpha,
156 A + i * m * k, lda, B + i * k * n, ldb,
157 beta, C + i * m * n, ldc);
161 bool trans,
int m,
int n,
162 float alpha,
const float *A,
int lda,
163 const float *X,
int incX,
164 float beta,
float *Y,
int incY) {
165 LOG(FATAL) <<
"Not implmented!";
167 inline static void batched_gemv(
Stream<cpu> *stream,
168 bool trans,
int m,
int n,
169 float alpha,
const float *A,
int lda,
170 const float *X,
int incX,
171 float beta,
float *Y,
int incY,
int batch_count) {
172 LOG(FATAL) <<
"Not implmented!";
175 int m,
int n,
float alpha,
176 const float *X,
int incX,
177 const float *Y,
int incY,
float *A,
int lda) {
178 LOG(FATAL) <<
"Not implmented!";
180 inline static void batched_ger(
Stream<cpu> *stream,
181 int m,
int n,
float alpha,
182 const float *X,
int incX,
183 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
184 LOG(FATAL) <<
"Not implmented!";
188 const float* X,
int incX,
189 const float* Y,
int incY,
191 LOG(FATAL) <<
"Not implmented!";
197 inline static bool GetT(
bool t) {
198 return t ?
true :
false;
200 inline static void SetStream(
Stream<cpu> *stream) {
203 bool transa,
bool transb,
204 int m,
int n,
int k,
double alpha,
205 const double *A,
int lda,
const double *B,
int ldb,
206 double beta,
double *C,
int ldc) {
207 if (alpha == 1.0f && beta == 0.0f) {
208 bool transpose_left = transb;
209 bool transpose_right = transa;
213 if (!transpose_left && !transpose_right) {
215 }
else if (!transpose_left && transpose_right) {
217 }
else if (transpose_left && !transpose_right) {
220 LOG(FATAL) <<
"Not implmented!";
223 LOG(FATAL) <<
"Not implmented!";
226 inline static void batched_gemm(
Stream<cpu> *stream,
227 bool transa,
bool transb,
228 int m,
int n,
int k,
double alpha,
229 const double *A,
int lda,
const double *B,
int ldb,
230 double beta,
double *C,
int ldc,
int batch_count,
231 double **workspace) {
232 for (
int i = 0; i < batch_count; ++i) {
233 gemm(stream, transa, transb, m, n, k, alpha,
234 A + i * m * k, lda, B + i * k * n, ldb,
235 beta, C + i * m * n, ldc);
239 bool trans,
int m,
int n,
240 double alpha,
const double *A,
int lda,
241 const double *X,
int incX,
242 double beta,
double *Y,
int incY) {
243 LOG(FATAL) <<
"Not implmented!";
245 inline static void batched_gemv(
Stream<cpu> *stream,
246 bool trans,
int m,
int n,
247 double alpha,
const double *A,
int lda,
248 const double *X,
int incX,
249 double beta,
double *Y,
int incY,
int batch_count) {
250 LOG(FATAL) <<
"Not implmented!";
253 int m,
int n,
double alpha,
254 const double *X,
int incX,
255 const double *Y,
int incY,
double *A,
int lda) {
256 LOG(FATAL) <<
"Not implmented!";
258 inline static void batched_ger(
Stream<cpu> *stream,
259 int m,
int n,
double alpha,
260 const double *X,
int incX,
261 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
262 LOG(FATAL) <<
"Not implmented!";
266 const double* X,
int incX,
267 const double* Y,
int incY,
269 LOG(FATAL) <<
"Not implmented!";
273 #elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*) 276 inline static CBLAS_TRANSPOSE
GetT(
bool t) {
277 return t ? CblasTrans : CblasNoTrans;
282 bool transa,
bool transb,
283 int m,
int n,
int k,
float alpha,
284 const float *A,
int lda,
const float *B,
int ldb,
285 float beta,
float *C,
int ldc) {
286 cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
287 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
290 bool transa,
bool transb,
291 int m,
int n,
int k,
float alpha,
292 const float *A,
int lda,
const float *B,
int ldb,
293 float beta,
float *C,
int ldc,
int batch_count,
295 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) 297 const int GROUP_SIZE = 1;
298 MKL_INT p_m[GROUP_SIZE] = {m};
299 MKL_INT p_n[GROUP_SIZE] = {n};
300 MKL_INT p_k[GROUP_SIZE] = {k};
301 MKL_INT p_lda[GROUP_SIZE] = {lda};
302 MKL_INT p_ldb[GROUP_SIZE] = {ldb};
303 MKL_INT p_ldc[GROUP_SIZE] = {ldc};
305 float p_alpha[GROUP_SIZE] = {alpha};
306 float p_beta[GROUP_SIZE] = {beta};
308 CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
309 CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
311 MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
312 CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
313 CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
315 std::vector<const float*> pp_A;
316 std::vector<const float*> pp_B;
317 std::vector<float*> pp_C;
318 pp_A.reserve(batch_count);
319 pp_B.reserve(batch_count);
320 pp_C.reserve(batch_count);
326 for (
int i = 0; i < batch_count; i++) {
327 pp_A[i] = A + i * m_k;
328 pp_B[i] = B + i * k_n;
329 pp_C[i] = C + i * m_n;
332 cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
333 p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
334 p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
336 for (
int i = 0; i < batch_count; ++i) {
337 gemm(stream, transa, transb, m, n, k, alpha,
338 A + i * m * k, lda, B + i * k * n, ldb,
339 beta, C + i * m * n, ldc);
344 bool trans,
int m,
int n,
345 float alpha,
const float *A,
int lda,
346 const float *X,
int incX,
347 float beta,
float *Y,
int incY) {
348 cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha,
349 A, lda, X, incX, beta, Y, incY);
352 bool trans,
int m,
int n,
353 float alpha,
const float *A,
int lda,
354 const float *X,
int incX,
355 float beta,
float *Y,
int incY,
int batch_count) {
356 for (
int i = 0; i < batch_count; ++i) {
357 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
358 X + i * (trans ? m : n) * incX, incX,
359 beta, Y + i * (trans ? n : m) * incY, incY);
363 int m,
int n,
float alpha,
364 const float *X,
int incX,
365 const float *Y,
int incY,
float *A,
int lda) {
366 cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
369 int m,
int n,
float alpha,
370 const float *X,
int incX,
371 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
372 for (
int i = 0; i < batch_count; ++i) {
373 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
374 A + i * lda * n, lda);
379 const float* X,
int incX,
380 const float* Y,
int incY,
382 *ret = cblas_sdot(n, X, incX, Y, incY);
388 inline static CBLAS_TRANSPOSE
GetT(
bool t) {
389 return t ? CblasTrans : CblasNoTrans;
394 bool transa,
bool transb,
395 int m,
int n,
int k,
double alpha,
396 const double *A,
int lda,
const double *B,
int ldb,
397 double beta,
double *C,
int ldc) {
398 cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb),
399 m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
402 bool transa,
bool transb,
403 int m,
int n,
int k,
double alpha,
404 const double *A,
int lda,
const double *B,
int ldb,
405 double beta,
double *C,
int ldc,
int batch_count,
406 double **workspace) {
407 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) 409 const int GROUP_SIZE = 1;
410 MKL_INT p_m[GROUP_SIZE] = {m};
411 MKL_INT p_n[GROUP_SIZE] = {n};
412 MKL_INT p_k[GROUP_SIZE] = {k};
413 MKL_INT p_lda[GROUP_SIZE] = {lda};
414 MKL_INT p_ldb[GROUP_SIZE] = {ldb};
415 MKL_INT p_ldc[GROUP_SIZE] = {ldc};
417 double p_alpha[GROUP_SIZE] = {alpha};
418 double p_beta[GROUP_SIZE] = {beta};
420 CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
421 CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
423 MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
424 CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
425 CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
427 std::vector<const double*> pp_A;
428 std::vector<const double*> pp_B;
429 std::vector<double*> pp_C;
430 pp_A.reserve(batch_count);
431 pp_B.reserve(batch_count);
432 pp_C.reserve(batch_count);
438 for (
int i = 0; i < batch_count; i++) {
439 pp_A[i] = A + i * m_k;
440 pp_B[i] = B + i * k_n;
441 pp_C[i] = C + i * m_n;
444 cblas_dgemm_batch(CblasColMajor, p_transa, p_transb,
445 p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
446 p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
448 for (
int i = 0; i < batch_count; ++i) {
449 gemm(stream, transa, transb, m, n, k, alpha,
450 A + i * m * k, lda, B + i * k * n, ldb,
451 beta, C + i * m * n, ldc);
456 bool trans,
int m,
int n,
double alpha,
457 const double *A,
int lda,
458 const double *X,
int incX,
459 double beta,
double *Y,
int incY) {
460 cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha,
461 A, lda, X, incX, beta, Y, incY);
464 bool trans,
int m,
int n,
465 double alpha,
const double *A,
int lda,
466 const double *X,
int incX,
467 double beta,
double *Y,
int incY,
int batch_count) {
468 for (
int i = 0; i < batch_count; ++i) {
469 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
470 X + i * (trans ? m : n) * incX, incX,
471 beta, Y + i * (trans ? n : m) * incY, incY);
475 int m,
int n,
double alpha,
476 const double *X,
int incX,
477 const double *Y,
int incY,
double *A,
int lda) {
478 cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
481 int m,
int n,
double alpha,
482 const double *X,
int incX,
483 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
484 for (
int i = 0; i < batch_count; ++i) {
485 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
486 A + i * lda * n, lda);
491 const double* X,
int incX,
492 const double* Y,
int incY,
494 *ret = cblas_ddot(n, X, incX, Y, incY);
497 #endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE 503 inline static cublasOperation_t
GetT(
bool t) {
504 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
509 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas set stream fail";
512 bool transa,
bool transb,
513 int m,
int n,
int k, half::half_t alpha,
514 const half::half_t *A,
int lda,
515 const half::half_t *B,
int ldb, half::half_t beta,
516 half::half_t *C,
int ldc) {
517 #if defined(CUDA_VERSION) && CUDA_VERSION >= 7050 519 float alpha_f = float(alpha);
520 float beta_f = float(beta);
521 #if CUDA_VERSION >= 8000 523 GetT(transa), GetT(transb), m, n, k, &alpha_f,
524 A, CUDA_R_16F, lda, B, CUDA_R_16F,
525 ldb, &beta_f, C, CUDA_R_16F, ldc);
526 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas SgemmEx fail";
529 GetT(transa), GetT(transb), m, n, k, &alpha_f,
530 A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF,
531 ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
532 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas SgemmEx fail";
533 #endif // CUDA_VERSION >= 8000 535 LOG(FATAL) <<
"Require CUDA version >= 7.5!";
536 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050 539 bool transa,
bool transb,
540 int m,
int n,
int k, half::half_t alpha,
541 const half::half_t *A,
int lda,
const half::half_t *B,
int ldb,
542 half::half_t beta, half::half_t *C,
int ldc,
int batch_count,
543 half::half_t **workspace) {
544 #if defined(__CUDACC__) && CUDA_VERSION >= 9000 545 int major = stream->
prop.major;
546 int minor = stream->
prop.minor;
548 if ((major > 5) || (major == 5 && minor >= 3)) {
549 const __half* A_h =
reinterpret_cast<const __half*
>(A);
550 const __half* B_h =
reinterpret_cast<const __half*
>(B);
551 __half* alpha_h =
reinterpret_cast<__half*
>(&alpha);
552 __half* beta_h =
reinterpret_cast<__half*
>(&beta);
553 __half* C_h =
reinterpret_cast<__half*
>(C);
555 GetT(transa), GetT(transb), m, n, k, alpha_h,
558 beta_h, C_h, ldc, m * n,
560 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: HgemmStridedBatched fail";
564 for (
int i = 0; i < batch_count; ++i) {
565 gemm(stream, transa, transb, m, n, k, alpha,
566 A + i * m * k, lda, B + i * k * n, ldb,
567 beta, C + i * m * n, ldc);
571 bool trans,
int m,
int n, half::half_t alpha,
572 const half::half_t *A,
int lda,
573 const half::half_t *X,
int incX, half::half_t beta,
574 half::half_t *Y,
int incY) {
575 LOG(FATAL) <<
"Not implmented!";
578 bool trans,
int m,
int n,
579 half::half_t alpha,
const half::half_t *A,
int lda,
580 const half::half_t *X,
int incX,
581 half::half_t beta, half::half_t *Y,
int incY,
int batch_count) {
582 LOG(FATAL) <<
"Not implmented!";
585 int m,
int n, half::half_t alpha,
586 const half::half_t *X,
int incX,
587 const half::half_t *Y,
int incY, half::half_t *A,
int lda) {
588 LOG(FATAL) <<
"Not implmented!";
591 int m,
int n, half::half_t alpha,
592 const half::half_t *X,
int incX,
const half::half_t *Y,
int incY,
593 half::half_t *A,
int lda,
int batch_count) {
594 LOG(FATAL) <<
"Not implmented!";
598 const half::half_t* X,
int incX,
599 const half::half_t* Y,
int incY,
601 LOG(FATAL) <<
"Not implmented!";
607 inline static cublasOperation_t
GetT(
bool t) {
608 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
613 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: set stream fail";
616 bool transa,
bool transb,
617 int m,
int n,
int k,
float alpha,
618 const float *A,
int lda,
619 const float *B,
int ldb,
float beta,
622 GetT(transa), GetT(transb), m, n, k, &alpha,
623 A, lda, B, ldb, &beta, C, ldc);
624 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sgemm fail";
627 bool transa,
bool transb,
628 int m,
int n,
int k,
float alpha,
629 const float *A,
int lda,
const float *B,
int ldb,
630 float beta,
float *C,
int ldc,
int batch_count,
632 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 634 bool alloc_workspace =
false;
635 if (workspace == NULL) {
638 cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count *
sizeof(
float*));
639 alloc_workspace =
true;
641 GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
643 const_cast<float*>(B), batch_count, k * n, stream);
644 GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
646 GetT(transa), GetT(transb), m, n, k, &alpha,
647 (
const float**)workspace, lda,
648 (
const float**)(workspace + batch_count), ldb,
649 &beta, workspace + 2 * batch_count, ldc, batch_count);
650 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: SgemmBatched fail";
651 if (alloc_workspace) {
654 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000 656 GetT(transa), GetT(transb), m, n, k, &alpha,
659 &beta, C, ldc, m * n,
661 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: SgemmStridedBatched fail";
663 for (
int i = 0; i < batch_count; ++i) {
664 gemm(stream, transa, transb, m, n, k, alpha,
665 A + i * m * k, lda, B + i * k * n, ldb,
666 beta, C + i * m * n, ldc);
668 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 671 bool trans,
int m,
int n,
float alpha,
672 const float *A,
int lda,
673 const float *X,
int incX,
float beta,
674 float *Y,
int incY) {
676 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
677 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sgemv fail";
680 bool trans,
int m,
int n,
681 float alpha,
const float *A,
int lda,
682 const float *X,
int incX,
683 float beta,
float *Y,
int incY,
int batch_count) {
684 for (
int i = 0; i < batch_count; ++i) {
685 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
686 X + i * (trans ? m : n) * incX, incX,
687 beta, Y + i * (trans ? n : m) * incY, incY);
691 int m,
int n,
float alpha,
692 const float *X,
int incX,
693 const float *Y,
int incY,
float *A,
int lda) {
695 m, n, &alpha, X, incX, Y, incY, A, lda);
696 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Sger fail";
699 int m,
int n,
float alpha,
700 const float *X,
int incX,
701 const float *Y,
int incY,
float *A,
int lda,
int batch_count) {
702 for (
int i = 0; i < batch_count; ++i) {
703 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
704 A + i * lda * n, lda);
709 const float* X,
int incX,
710 const float* Y,
int incY,
713 CUBLAS_POINTER_MODE_DEVICE);
715 n, X, incX, Y, incY, ret);
716 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dot fail";
718 CUBLAS_POINTER_MODE_HOST);
724 inline static cublasOperation_t
GetT(
bool t) {
725 return t ? CUBLAS_OP_T : CUBLAS_OP_N;
730 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: set stream fail";
733 bool transa,
bool transb,
734 int m,
int n,
int k,
double alpha,
735 const double *A,
int lda,
736 const double *B,
int ldb,
737 double beta,
double *C,
int ldc) {
739 GetT(transa), GetT(transb), m, n, k, &alpha,
740 A, lda, B, ldb, &beta, C, ldc);
741 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dgemm fail";
744 bool transa,
bool transb,
745 int m,
int n,
int k,
double alpha,
746 const double *A,
int lda,
const double *B,
int ldb,
747 double beta,
double *C,
int ldc,
int batch_count,
748 double **workspace) {
749 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 751 bool alloc_workspace =
false;
752 if (workspace == NULL) {
755 cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count *
sizeof(
double*));
756 alloc_workspace =
true;
758 GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
760 const_cast<double*>(B), batch_count, k * n, stream);
761 GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
763 GetT(transa), GetT(transb), m, n, k, &alpha,
764 (
const double**)workspace, lda,
765 (
const double**)(workspace + batch_count), ldb,
766 &beta, workspace + 2 * batch_count, ldc, batch_count);
767 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: DgemmBatched fail";
768 if (alloc_workspace) {
771 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000 773 GetT(transa), GetT(transb), m, n, k, &alpha,
776 &beta, C, ldc, m * n,
778 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: DgemmStridedBatched fail";
780 for (
int i = 0; i < batch_count; ++i) {
781 gemm(stream, transa, transb, m, n, k, alpha,
782 A + i * m * k, lda, B + i * k * n, ldb,
783 beta, C + i * m * n, ldc);
785 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 788 bool trans,
int m,
int n,
double alpha,
789 const double *A,
int lda,
790 const double *X,
int incX,
791 double beta,
double *Y,
int incY) {
793 GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
794 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dgemv fail";
797 bool trans,
int m,
int n,
798 double alpha,
const double *A,
int lda,
799 const double *X,
int incX,
800 double beta,
double *Y,
int incY,
int batch_count) {
801 for (
int i = 0; i < batch_count; ++i) {
802 gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
803 X + i * (trans ? m : n) * incX, incX,
804 beta, Y + i * (trans ? n : m) * incY, incY);
808 int m,
int n,
double alpha,
809 const double *X,
int incX,
810 const double *Y,
int incY,
double *A,
int lda) {
812 m, n, &alpha, X, incX, Y, incY, A, lda);
813 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dger fail";
816 int m,
int n,
double alpha,
817 const double *X,
int incX,
818 const double *Y,
int incY,
double *A,
int lda,
int batch_count) {
819 for (
int i = 0; i < batch_count; ++i) {
820 ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
821 A + i * lda * n, lda);
826 const double* X,
int incX,
827 const double* Y,
int incY,
830 CUBLAS_POINTER_MODE_DEVICE);
832 n, X, incX, Y, incY, ret);
833 CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) <<
"Cublas: Dot fail";
835 CUBLAS_POINTER_MODE_HOST);
838 #endif // MSHADOW_USE_CUDA 841 return transpose ?
Shape2(shape[1], shape[0]) : shape;
844 template<
typename SV,
typename xpu,
845 bool transpose_left,
bool transpose_right,
typename DType>
846 struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> {
852 #if MSHADOW_STAND_ALONE 854 if (!transpose_left && !transpose_right) {
856 }
else if (!transpose_left && transpose_right) {
858 }
else if (transpose_left && !transpose_right) {
868 CHECK(dst.
size(0) == sleft[0] && dst.
size(1) == sright[1] && sleft[1] == sright[0])
869 <<
"dot-gemm: matrix shape mismatch";
873 transpose_right , transpose_left,
874 transpose_right ? rhs.
size(0) : rhs.
size(1),
875 transpose_left ? lhs.
size(1) : lhs.
size(0),
876 transpose_right ? rhs.
size(1) : rhs.
size(0),
877 DType(scale * SV::AlphaBLAS()),
880 DType(SV::BetaBLAS()),
884 template<
typename SV,
typename xpu,
bool transpose_right,
typename DType>
885 struct DotEngine<SV, xpu, 1, 1, 2, false, transpose_right, DType> {
895 CHECK(dst.
size(0) == sright[1] && lhs.
size(0) == sright[0])
896 <<
"dot-gemv: matrix shape mismatch" 897 <<
"dst: " << dst.
shape_ <<
"\n" 898 <<
"lhs: " << lhs.
shape_ <<
"\n" 899 <<
"rhs: " << sright <<
"\n";
903 rhs.
size(1), rhs.
size(0), scale * SV::AlphaBLAS(),
905 lhs.
dptr_, 1, SV::BetaBLAS(),
909 template<
typename SV,
typename xpu,
typename DType>
910 struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
920 <<
"dot-ger: matrix shape mismatch" 921 <<
"dst: " << dst.
shape_ <<
"\n" 922 <<
"lhs: " << lhs.
shape_ <<
"\n" 924 if (SV::BetaBLAS() == 0.0f) {
936 #endif // MSHADOW_DOT_ENGINE_INL_H_ static void ger(Stream< gpu > *stream, int m, int n, half::half_t alpha, const half::half_t *X, int incX, const half::half_t *Y, int incY, half::half_t *A, int lda)
Definition: dot_engine-inl.h:584
static void batched_gemv(Stream< Device > *stream, bool trans, int m, int n, DType alpha, const DType *A, int lda, const DType *X, int incX, DType beta, DType *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:88
static void batched_gemm(Stream< Device > *stream, bool transa, bool transb, int m, int n, int k, DType alpha, const DType *A, int lda, const DType *B, int ldb, DType beta, DType *C, int ldc, int batch_count, DType **workspace)
Definition: dot_engine-inl.h:73
static void SetStream(Stream< gpu > *stream)
Definition: dot_engine-inl.h:727
static void SetStream(Stream< cpu > *stream)
Definition: dot_engine-inl.h:391
ImplicitGEMMExp< LhsExp, RhsExp, DType > implicit_dot(const Exp< LhsExp, DType, e1 > &lhs, const Exp< RhsExp, DType, e2 > &rhs)
Definition: implicit_gemm.h:46
static void SetStream(Stream< cpu > *stream)
Definition: dot_engine-inl.h:279
static void SetStream(Stream< gpu > *stream)
Definition: dot_engine-inl.h:610
static void batched_gemv(Stream< gpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:679
DType * dptr_
pointer to the data
Definition: tensor.h:416
static void gemv(Stream< gpu > *stream, bool trans, int m, int n, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *X, int incX, half::half_t beta, half::half_t *Y, int incY)
Definition: dot_engine-inl.h:570
Shape< 2 > GetShape(const Shape< 2 > &shape, bool transpose)
Definition: dot_engine-inl.h:840
static void gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
Definition: dot_engine-inl.h:393
static void gemv(Stream< cpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY)
Definition: dot_engine-inl.h:343
static void gemv(Stream< gpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY)
Definition: dot_engine-inl.h:787
static void batched_ger(Stream< cpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda, int batch_count)
Definition: dot_engine-inl.h:480
static void batched_gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc, int batch_count, double **workspace)
Definition: dot_engine-inl.h:743
static void gemv(Stream< cpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY)
Definition: dot_engine-inl.h:455
Definition: stream_gpu-inl.h:19
static void batched_gemv(Stream< gpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:796
static void batched_ger(Stream< cpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda, int batch_count)
Definition: dot_engine-inl.h:368
support for implicit GEMM operation
static void batched_ger(Stream< gpu > *stream, int m, int n, half::half_t alpha, const half::half_t *X, int incX, const half::half_t *Y, int incY, half::half_t *A, int lda, int batch_count)
Definition: dot_engine-inl.h:590
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:418
static void gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc)
Definition: dot_engine-inl.h:281
static cublasOperation_t GetT(bool t)
Definition: dot_engine-inl.h:607
static void batched_gemv(Stream< gpu > *stream, bool trans, int m, int n, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *X, int incX, half::half_t beta, half::half_t *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:577
static void ger(Stream< gpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda)
Definition: dot_engine-inl.h:690
static void batched_ger(Stream< gpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda, int batch_count)
Definition: dot_engine-inl.h:815
DotExp< TA, TB, false, false, DType > dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
dot operator def
Definition: expression.h:222
static void batched_gemv(Stream< cpu > *stream, bool trans, int m, int n, double alpha, const double *A, int lda, const double *X, int incX, double beta, double *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:463
static void batched_gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, int batch_count, float **workspace)
Definition: dot_engine-inl.h:626
Definition: dot_engine-inl.h:52
static void batched_ger(Stream< gpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda, int batch_count)
Definition: dot_engine-inl.h:698
cudaDeviceProp prop
cudaDeviceProp
Definition: stream_gpu-inl.h:44
device name GPU
Definition: tensor.h:28
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to 2 dimension, collapse the higher dimensions together
Definition: tensor.h:501
static void batched_ger(Stream< Device > *stream, int m, int n, DType alpha, const DType *X, int incX, const DType *Y, int incY, DType *A, int lda, int batch_count)
Definition: dot_engine-inl.h:101
static cublasOperation_t GetT(bool t)
Definition: dot_engine-inl.h:724
static void Eval(Tensor< xpu, 2, DType > *p_dst, const Tensor< xpu, 2, DType > &lhs, const Tensor< xpu, 2, DType > &rhs, DType scale)
Definition: dot_engine-inl.h:847
static void gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *B, int ldb, half::half_t beta, half::half_t *C, int ldc)
Definition: dot_engine-inl.h:511
static void batched_gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, half::half_t alpha, const half::half_t *A, int lda, const half::half_t *B, int ldb, half::half_t beta, half::half_t *C, int ldc, int batch_count, half::half_t **workspace)
Definition: dot_engine-inl.h:538
static void batched_gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc, int batch_count, double **workspace)
Definition: dot_engine-inl.h:401
static void gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, double alpha, const double *A, int lda, const double *B, int ldb, double beta, double *C, int ldc)
Definition: dot_engine-inl.h:732
static CBLAS_TRANSPOSE GetT(bool t)
Definition: dot_engine-inl.h:388
static void dot(Stream< Device > *stream, int n, const DType *X, int incX, const DType *Y, int incY, DType *ret)
Definition: dot_engine-inl.h:107
static void batched_gemm(Stream< cpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, int batch_count, float **workspace)
Definition: dot_engine-inl.h:289
static cublasOperation_t GetT(bool t)
Definition: dot_engine-inl.h:503
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two dimension shape, stride will equal s0
Definition: tensor.h:198
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:25
static void ger(Stream< cpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda)
Definition: dot_engine-inl.h:474
static void ger(Stream< cpu > *stream, int m, int n, float alpha, const float *X, int incX, const float *Y, int incY, float *A, int lda)
Definition: dot_engine-inl.h:362
void GetBatchedView(DType **dst, DType *src, int num, int stride, Stream< Device > *stream)
CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride.
static void dot(Stream< gpu > *stream, int n, const float *X, int incX, const float *Y, int incY, float *ret)
Definition: dot_engine-inl.h:707
static void Eval(Tensor< xpu, 2, DType > *p_dst, const Tensor< xpu, 1, DType > &lhs, const Tensor< xpu, 1, DType > &rhs, DType scale)
Definition: dot_engine-inl.h:911
static void dot(Stream< cpu > *stream, int n, const float *X, int incX, const float *Y, int incY, float *ret)
Definition: dot_engine-inl.h:377
static void ger(Stream< Device > *stream, int m, int n, DType alpha, const DType *X, int incX, const DType *Y, int incY, DType *A, int lda)
Definition: dot_engine-inl.h:95
static void dot(Stream< gpu > *stream, int n, const double *X, int incX, const double *Y, int incY, double *ret)
Definition: dot_engine-inl.h:824
static void batched_gemv(Stream< cpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY, int batch_count)
Definition: dot_engine-inl.h:351
static void Eval(Tensor< xpu, 1, DType > *p_dst, const Tensor< xpu, 1, DType > &lhs, const Tensor< xpu, 2, DType > &rhs, DType scale)
Definition: dot_engine-inl.h:886
static void dot(Stream< cpu > *stream, int n, const double *X, int incX, const double *Y, int incY, double *ret)
Definition: dot_engine-inl.h:489
static void ger(Stream< gpu > *stream, int m, int n, double alpha, const double *X, int incX, const double *Y, int incY, double *A, int lda)
Definition: dot_engine-inl.h:807
namespace for mshadow
Definition: base.h:282
MSHADOW_XINLINE index_t size(int idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor.h:487
static void gemv(Stream< gpu > *stream, bool trans, int m, int n, float alpha, const float *A, int lda, const float *X, int incX, float beta, float *Y, int incY)
Definition: dot_engine-inl.h:670
static void gemm(Stream< Device > *stream, bool transa, bool transb, int m, int n, int k, DType alpha, const DType *A, int lda, const DType *B, int ldb, DType beta, DType *C, int ldc)
Definition: dot_engine-inl.h:66
TransposeExExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > transpose(const Exp< SrcExp, DType, etype > &src, Shape< ExpInfo< SrcExp >::kDim > axes)
an expression that reshapes a tensor to another shape
Definition: transpose.h:58
index_t stride_
storing the stride information in the x dimension; this is used to deal with pitch allocation in gpu or ss...
Definition: tensor.h:423
static void SetStream(Stream< gpu > *stream)
Definition: dot_engine-inl.h:506
Definition: dot_engine-inl.h:60
static void dot(Stream< gpu > *stream, int n, const half::half_t *X, int incX, const half::half_t *Y, int incY, half::half_t *ret)
Definition: dot_engine-inl.h:596
general tensor
Definition: tensor.h:402
static void gemv(Stream< Device > *stream, bool trans, int m, int n, DType alpha, const DType *A, int lda, const DType *X, int incX, DType beta, DType *Y, int incY)
Definition: dot_engine-inl.h:81
static void gemm(Stream< gpu > *stream, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc)
Definition: dot_engine-inl.h:615
static void SetStream(Stream< Device > *stream)
Definition: dot_engine-inl.h:64
static bool GetT(bool t)
Definition: dot_engine-inl.h:61
const TransposeExp< Tensor< Device, dimension, DType >, DType > T(void) const
transpose of a matrix
Definition: expression.h:136
Stream< Device > * stream_
stream where the computation lies stream is a device dependency concept where each computation ...
Definition: tensor.h:428
void GetBatchedView(DType **dst, DType *src, int num, int stride, Stream< cpu > *stream)
Definition: dot_engine-inl.h:31
computation stream structure, used for asynchronous computations
Definition: tensor.h:365
static CBLAS_TRANSPOSE GetT(bool t)
Definition: dot_engine-inl.h:276