dot_engine-inl.h
1 
7 #ifndef MSHADOW_DOT_ENGINE_INL_H_
8 #define MSHADOW_DOT_ENGINE_INL_H_
9 
10 #include <vector>
11 #include "./base.h"
12 #include "./extension/implicit_gemm.h"
13 
14 #ifdef __CUDACC__
15 #include "./cuda/tensor_gpu-inl.cuh"
16 #endif // #ifdef __CUDACC__
17 
18 namespace mshadow {
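19 /*!
20  * \brief CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride.
21  * \param dst destination array of per-batch pointers
22  * \param src source pointer holding the packed batches
23  * \param num number of batches
24  * \param stride number of elements between consecutive batches
25  * \param stream stream on which the (possibly asynchronous) fill is performed
26  */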
27 template<typename Device, typename DType>
28 inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
29  Stream<Device> *stream);
30 template<typename DType>
31 inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
32  Stream<cpu> *stream) {
33  for (int i = 0; i < num; i++) {
34  dst[i] = src + i * stride;
35  }
36 }
37 #ifdef __CUDACC__
38 namespace cuda {};
39 template<typename DType>
40 inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
41  Stream<gpu> *stream) {
42  cuda::GetBatchedView(dst, src, num, stride, stream);
43 }
44 #endif // #ifdef __CUDACC__
45 
46 namespace expr {
47 //---------------------------------------------------------------------
48 // Matrix Multiplications, depends on BLAS Engine
49 //---------------------------------------------------------------------
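// Generic matrix-product dispatcher: SV is the saver type (e.g. sv::saveto, sv::plusto),
// ddim/ldim/rdim are the dimensions of the destination/lhs/rhs tensors, ltrans/rtrans
// indicate whether lhs/rhs are transposed, and DType is the element type.  Only the
// (2,2,2), (1,1,2) and (2,1,1) combinations are specialized further below.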
50 template<typename SV, typename Device, int ddim, int ldim,
51  int rdim, bool ltrans, bool rtrans, typename DType>
52 struct DotEngine {
53  inline static void Eval(Tensor<Device, ddim, DType> *p_dst,
54  const Tensor<Device, ldim, DType> &lhs,
55  const Tensor<Device, rdim, DType> &rhs,
56  DType scale);
57 };
58 // handles the dot product; uses the CblasColMajor convention
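// The generic BLASEngine below is only a stub whose computational members abort with
// LOG(FATAL); the per-device / per-type specializations that follow forward to
// CBLAS/MKL on the CPU and to cuBLAS on the GPU.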
59 template<typename Device, typename DType = default_real_t>
60 struct BLASEngine {
61  inline static bool GetT(bool t) {
62  return t ? true : false;
63  }
64  inline static void SetStream(Stream<Device> *stream) {
65  }
66  inline static void gemm(Stream<Device> *stream,
67  bool transa, bool transb,
68  int m, int n, int k, DType alpha,
69  const DType *A, int lda, const DType *B, int ldb,
70  DType beta, DType *C, int ldc) {
71  LOG(FATAL) << "Not implemented!";
72  }
73  inline static void batched_gemm(Stream<Device> *stream,
74  bool transa, bool transb,
75  int m, int n, int k, DType alpha,
76  const DType *A, int lda, const DType *B, int ldb,
77  DType beta, DType *C, int ldc, int batch_count,
78  DType **workspace) {
79  LOG(FATAL) << "Not implemented!";
80  }
81  inline static void gemv(Stream<Device> *stream,
82  bool trans, int m, int n,
83  DType alpha, const DType *A, int lda,
84  const DType *X, int incX,
85  DType beta, DType *Y, int incY) {
86  LOG(FATAL) << "Not implemented!";
87  }
88  inline static void batched_gemv(Stream<Device> *stream,
89  bool trans, int m, int n,
90  DType alpha, const DType *A, int lda,
91  const DType *X, int incX,
92  DType beta, DType *Y, int incY, int batch_count) {
93  LOG(FATAL) << "Not implemented!";
94  }
95  inline static void ger(Stream<Device> *stream,
96  int m, int n, DType alpha,
97  const DType *X, int incX,
98  const DType *Y, int incY, DType *A, int lda) {
99  LOG(FATAL) << "Not implemented!";
100  }
101  inline static void batched_ger(Stream<Device> *stream,
102  int m, int n, DType alpha,
103  const DType *X, int incX,
104  const DType *Y, int incY, DType *A, int lda, int batch_count) {
105  LOG(FATAL) << "Not implemented!";
106  }
107  inline static void dot(Stream<Device> *stream,
108  int n,
109  const DType* X, int incX,
110  const DType* Y, int incY,
111  DType* ret) {
112  LOG(FATAL) << "Not implemented!";
113  }
114 };
115 
116 #if MSHADOW_STAND_ALONE
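// Stand-alone build: no external BLAS is linked, so gemm is emulated through
// expr::implicit_dot; only alpha == 1, beta == 0 and at most one transposed operand
// is supported, and everything else aborts.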
117 template<>
118 struct BLASEngine<cpu, float> {
119  inline static bool GetT(bool t) {
120  return t ? true : false;
121  }
122  inline static void SetStream(Stream<cpu> *stream) {
123  }
124  inline static void gemm(Stream<cpu> *stream,
125  bool transa, bool transb,
126  int m, int n, int k, float alpha,
127  const float *A, int lda, const float *B, int ldb,
128  float beta, float *C, int ldc) {
129  if (alpha == 1.0f && beta == 0.0f) {
130  bool transpose_left = transb;
131  bool transpose_right = transa;
132  Tensor<cpu, 2, float> lhs((float*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*)
133  Tensor<cpu, 2, float> rhs((float*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*)
134  Tensor<cpu, 2, float> dst(C, Shape2(m, n));
135  if (!transpose_left && !transpose_right) {
136  dst = expr::implicit_dot(lhs, rhs); return;
137  } else if (!transpose_left && transpose_right) {
138  dst = expr::implicit_dot(lhs, rhs.T()); return;
139  } else if (transpose_left && !transpose_right) {
140  dst = expr::implicit_dot(lhs.T(), rhs); return;
141  } else {
142  LOG(FATAL) << "Not implemented!";
143  }
144  } else {
145  LOG(FATAL) << "Not implemented!";
146  }
147  }
148  inline static void batched_gemm(Stream<cpu> *stream,
149  bool transa, bool transb,
150  int m, int n, int k, float alpha,
151  const float *A, int lda, const float *B, int ldb,
152  float beta, float *C, int ldc, int batch_count,
153  float **workspace) {
154  for (int i = 0; i < batch_count; ++i) {
155  gemm(stream, transa, transb, m, n, k, alpha,
156  A + i * m * k, lda, B + i * k * n, ldb,
157  beta, C + i * m * n, ldc);
158  }
159  }
160  inline static void gemv(Stream<cpu> *stream,
161  bool trans, int m, int n,
162  float alpha, const float *A, int lda,
163  const float *X, int incX,
164  float beta, float *Y, int incY) {
165  LOG(FATAL) << "Not implemented!";
166  }
167  inline static void batched_gemv(Stream<cpu> *stream,
168  bool trans, int m, int n,
169  float alpha, const float *A, int lda,
170  const float *X, int incX,
171  float beta, float *Y, int incY, int batch_count) {
172  LOG(FATAL) << "Not implemented!";
173  }
174  inline static void ger(Stream<cpu> *stream,
175  int m, int n, float alpha,
176  const float *X, int incX,
177  const float *Y, int incY, float *A, int lda) {
178  LOG(FATAL) << "Not implemented!";
179  }
180  inline static void batched_ger(Stream<cpu> *stream,
181  int m, int n, float alpha,
182  const float *X, int incX,
183  const float *Y, int incY, float *A, int lda, int batch_count) {
184  LOG(FATAL) << "Not implemented!";
185  }
186  inline static void dot(Stream<cpu> *stream,
187  int n,
188  const float* X, int incX,
189  const float* Y, int incY,
190  float* ret) {
191  LOG(FATAL) << "Not implemented!";
192  }
193 };
194 
195 template<>
196 struct BLASEngine<cpu, double> {
197  inline static bool GetT(bool t) {
198  return t ? true : false;
199  }
200  inline static void SetStream(Stream<cpu> *stream) {
201  }
202  inline static void gemm(Stream<cpu> *stream,
203  bool transa, bool transb,
204  int m, int n, int k, double alpha,
205  const double *A, int lda, const double *B, int ldb,
206  double beta, double *C, int ldc) {
207  if (alpha == 1.0f && beta == 0.0f) {
208  bool transpose_left = transb;
209  bool transpose_right = transa;
210  Tensor<cpu, 2, double> lhs((double*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*)
211  Tensor<cpu, 2, double> rhs((double*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*)
212  Tensor<cpu, 2, double> dst(C, Shape2(m, n));
213  if (!transpose_left && !transpose_right) {
214  dst = expr::implicit_dot(lhs, rhs); return;
215  } else if (!transpose_left && transpose_right) {
216  dst = expr::implicit_dot(lhs, rhs.T()); return;
217  } else if (transpose_left && !transpose_right) {
218  dst = expr::implicit_dot(lhs.T(), rhs); return;
219  } else {
220  LOG(FATAL) << "Not implemented!";
221  }
222  } else {
223  LOG(FATAL) << "Not implemented!";
224  }
225  }
226  inline static void batched_gemm(Stream<cpu> *stream,
227  bool transa, bool transb,
228  int m, int n, int k, double alpha,
229  const double *A, int lda, const double *B, int ldb,
230  double beta, double *C, int ldc, int batch_count,
231  double **workspace) {
232  for (int i = 0; i < batch_count; ++i) {
233  gemm(stream, transa, transb, m, n, k, alpha,
234  A + i * m * k, lda, B + i * k * n, ldb,
235  beta, C + i * m * n, ldc);
236  }
237  }
238  inline static void gemv(Stream<cpu> *stream,
239  bool trans, int m, int n,
240  double alpha, const double *A, int lda,
241  const double *X, int incX,
242  double beta, double *Y, int incY) {
243  LOG(FATAL) << "Not implemented!";
244  }
245  inline static void batched_gemv(Stream<cpu> *stream,
246  bool trans, int m, int n,
247  double alpha, const double *A, int lda,
248  const double *X, int incX,
249  double beta, double *Y, int incY, int batch_count) {
250  LOG(FATAL) << "Not implemented!";
251  }
252  inline static void ger(Stream<cpu> *stream,
253  int m, int n, double alpha,
254  const double *X, int incX,
255  const double *Y, int incY, double *A, int lda) {
256  LOG(FATAL) << "Not implemented!";
257  }
258  inline static void batched_ger(Stream<cpu> *stream,
259  int m, int n, double alpha,
260  const double *X, int incX,
261  const double *Y, int incY, double *A, int lda, int batch_count) {
262  LOG(FATAL) << "Not implemented!";
263  }
264  inline static void dot(Stream<cpu> *stream,
265  int n,
266  const double* X, int incX,
267  const double* Y, int incY,
268  double* ret) {
269  LOG(FATAL) << "Not implemented!";
270  }
271 };
272 
273 #elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*)
274 template<>
275 struct BLASEngine<cpu, float> {
276  inline static CBLAS_TRANSPOSE GetT(bool t) {
277  return t ? CblasTrans : CblasNoTrans;
278  }
279  inline static void SetStream(Stream<cpu> *stream) {
280  }
281  inline static void gemm(Stream<cpu> *stream,
282  bool transa, bool transb,
283  int m, int n, int k, float alpha,
284  const float *A, int lda, const float *B, int ldb,
285  float beta, float *C, int ldc) {
286  cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
287  m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
288  }
289  inline static void batched_gemm(Stream<cpu> *stream,
290  bool transa, bool transb,
291  int m, int n, int k, float alpha,
292  const float *A, int lda, const float *B, int ldb,
293  float beta, float *C, int ldc, int batch_count,
294  float **workspace) {
295 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
296  // since the same m/n/k is used for all the individual gemms, we put them into one group
297  const int GROUP_SIZE = 1;
298  MKL_INT p_m[GROUP_SIZE] = {m};
299  MKL_INT p_n[GROUP_SIZE] = {n};
300  MKL_INT p_k[GROUP_SIZE] = {k};
301  MKL_INT p_lda[GROUP_SIZE] = {lda};
302  MKL_INT p_ldb[GROUP_SIZE] = {ldb};
303  MKL_INT p_ldc[GROUP_SIZE] = {ldc};
304 
305  float p_alpha[GROUP_SIZE] = {alpha};
306  float p_beta[GROUP_SIZE] = {beta};
307 
308  CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
309  CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
310 
311  MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
312  CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
313  CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
314 
315  std::vector<const float*> pp_A;
316  std::vector<const float*> pp_B;
317  std::vector<float*> pp_C;
318  pp_A.resize(batch_count);
319  pp_B.resize(batch_count);
320  pp_C.resize(batch_count);
321 
322  auto m_k = m * k;
323  auto k_n = k * n;
324  auto m_n = m * n;
325 
326  for (int i = 0; i < batch_count; i++) {
327  pp_A[i] = A + i * m_k;
328  pp_B[i] = B + i * k_n;
329  pp_C[i] = C + i * m_n;
330  }
331 
332  cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
333  p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
334  p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
335 #else
336  for (int i = 0; i < batch_count; ++i) {
337  gemm(stream, transa, transb, m, n, k, alpha,
338  A + i * m * k, lda, B + i * k * n, ldb,
339  beta, C + i * m * n, ldc);
340  }
341 #endif
342  }
343  inline static void gemv(Stream<cpu> *stream,
344  bool trans, int m, int n,
345  float alpha, const float *A, int lda,
346  const float *X, int incX,
347  float beta, float *Y, int incY) {
348  cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha,
349  A, lda, X, incX, beta, Y, incY);
350  }
351  inline static void batched_gemv(Stream<cpu> *stream,
352  bool trans, int m, int n,
353  float alpha, const float *A, int lda,
354  const float *X, int incX,
355  float beta, float *Y, int incY, int batch_count) {
356  for (int i = 0; i < batch_count; ++i) {
357  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
358  X + i * (trans ? m : n) * incX, incX,
359  beta, Y + i * (trans ? n : m) * incY, incY);
360  }
361  }
362  inline static void ger(Stream<cpu> *stream,
363  int m, int n, float alpha,
364  const float *X, int incX,
365  const float *Y, int incY, float *A, int lda) {
366  cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
367  }
368  inline static void batched_ger(Stream<cpu> *stream,
369  int m, int n, float alpha,
370  const float *X, int incX,
371  const float *Y, int incY, float *A, int lda, int batch_count) {
372  for (int i = 0; i < batch_count; ++i) {
373  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
374  A + i * lda * n, lda);
375  }
376  }
377  inline static void dot(Stream<cpu> *stream,
378  int n,
379  const float* X, int incX,
380  const float* Y, int incY,
381  float* ret) {
382  *ret = cblas_sdot(n, X, incX, Y, incY);
383  }
384 };
385 
386 template<>
387 struct BLASEngine<cpu, double> {
388  inline static CBLAS_TRANSPOSE GetT(bool t) {
389  return t ? CblasTrans : CblasNoTrans;
390  }
391  inline static void SetStream(Stream<cpu> *stream) {
392  }
393  inline static void gemm(Stream<cpu> *stream,
394  bool transa, bool transb,
395  int m, int n, int k, double alpha,
396  const double *A, int lda, const double *B, int ldb,
397  double beta, double *C, int ldc) {
398  cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb),
399  m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
400  }
401  inline static void batched_gemm(Stream<cpu> *stream,
402  bool transa, bool transb,
403  int m, int n, int k, double alpha,
404  const double *A, int lda, const double *B, int ldb,
405  double beta, double *C, int ldc, int batch_count,
406  double **workspace) {
407 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
408  // since the same m/n/k is used for all the individual gemms, we put them into one group
409  const int GROUP_SIZE = 1;
410  MKL_INT p_m[GROUP_SIZE] = {m};
411  MKL_INT p_n[GROUP_SIZE] = {n};
412  MKL_INT p_k[GROUP_SIZE] = {k};
413  MKL_INT p_lda[GROUP_SIZE] = {lda};
414  MKL_INT p_ldb[GROUP_SIZE] = {ldb};
415  MKL_INT p_ldc[GROUP_SIZE] = {ldc};
416 
417  double p_alpha[GROUP_SIZE] = {alpha};
418  double p_beta[GROUP_SIZE] = {beta};
419 
420  CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
421  CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
422 
423  MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
424  CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
425  CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
426 
427  std::vector<const double*> pp_A;
428  std::vector<const double*> pp_B;
429  std::vector<double*> pp_C;
430  pp_A.resize(batch_count);
431  pp_B.resize(batch_count);
432  pp_C.resize(batch_count);
433 
434  auto m_k = m * k;
435  auto k_n = k * n;
436  auto m_n = m * n;
437 
438  for (int i = 0; i < batch_count; i++) {
439  pp_A[i] = A + i * m_k;
440  pp_B[i] = B + i * k_n;
441  pp_C[i] = C + i * m_n;
442  }
443 
444  cblas_dgemm_batch(CblasColMajor, p_transa, p_transb,
445  p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
446  p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
447 #else
448  for (int i = 0; i < batch_count; ++i) {
449  gemm(stream, transa, transb, m, n, k, alpha,
450  A + i * m * k, lda, B + i * k * n, ldb,
451  beta, C + i * m * n, ldc);
452  }
453 #endif
454  }
455  inline static void gemv(Stream<cpu> *stream,
456  bool trans, int m, int n, double alpha,
457  const double *A, int lda,
458  const double *X, int incX,
459  double beta, double *Y, int incY) {
460  cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha,
461  A, lda, X, incX, beta, Y, incY);
462  }
463  inline static void batched_gemv(Stream<cpu> *stream,
464  bool trans, int m, int n,
465  double alpha, const double *A, int lda,
466  const double *X, int incX,
467  double beta, double *Y, int incY, int batch_count) {
468  for (int i = 0; i < batch_count; ++i) {
469  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
470  X + i * (trans ? m : n) * incX, incX,
471  beta, Y + i * (trans ? n : m) * incY, incY);
472  }
473  }
474  inline static void ger(Stream<cpu> *stream,
475  int m, int n, double alpha,
476  const double *X, int incX,
477  const double *Y, int incY, double *A, int lda) {
478  cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
479  }
480  inline static void batched_ger(Stream<cpu> *stream,
481  int m, int n, double alpha,
482  const double *X, int incX,
483  const double *Y, int incY, double *A, int lda, int batch_count) {
484  for (int i = 0; i < batch_count; ++i) {
485  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
486  A + i * lda * n, lda);
487  }
488  }
489  inline static void dot(Stream<cpu> *stream,
490  int n,
491  const double* X, int incX,
492  const double* Y, int incY,
493  double* ret) {
494  *ret = cblas_ddot(n, X, incX, Y, incY);
495  }
496 };
497 #endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE
498 // CuBLAS redirect code
499 #if MSHADOW_USE_CUDA
500 // All cuBLAS calls go through here; uses the legacy API, which is not thread-safe
501 template<>
502 struct BLASEngine<gpu, half::half_t> {
503  inline static cublasOperation_t GetT(bool t) {
504  return t ? CUBLAS_OP_T : CUBLAS_OP_N;
505  }
506  inline static void SetStream(Stream<gpu> *stream) {
507  cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
508  Stream<gpu>::GetStream(stream));
509  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail";
510  }
511  inline static void gemm(Stream<gpu> *stream,
512  bool transa, bool transb,
513  int m, int n, int k, half::half_t alpha,
514  const half::half_t *A, int lda,
515  const half::half_t *B, int ldb, half::half_t beta,
516  half::half_t *C, int ldc) {
517 #if defined(CUDA_VERSION) && CUDA_VERSION >= 7050
518  // Always use pseudo-fp16: fp32 compute with fp16 I/O.
519  float alpha_f = float(alpha); // NOLINT(*)
520  float beta_f = float(beta); // NOLINT(*)
521  #if CUDA_VERSION >= 8000
522  cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
523  GetT(transa), GetT(transb), m, n, k, &alpha_f,
524  A, CUDA_R_16F, lda, B, CUDA_R_16F,
525  ldb, &beta_f, C, CUDA_R_16F, ldc);
526  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
527  #else
528  cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
529  GetT(transa), GetT(transb), m, n, k, &alpha_f,
530  A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF,
531  ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
532  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
533  #endif // CUDA_VERSION >= 8000
534 #else
535  LOG(FATAL) << "Require CUDA version >= 7.5!";
536 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050
537  }
538  inline static void batched_gemm(Stream<gpu> *stream,
539  bool transa, bool transb,
540  int m, int n, int k, half::half_t alpha,
541  const half::half_t *A, int lda, const half::half_t *B, int ldb,
542  half::half_t beta, half::half_t *C, int ldc, int batch_count,
543  half::half_t **workspace) {
544 #if defined(__CUDACC__) && CUDA_VERSION >= 9000
545  int major = stream->prop.major;
546  int minor = stream->prop.minor;
547  // fp16 arithmetic is not supported on GPU architectures before sm_53
548  if ((major > 5) || (major == 5 && minor >= 3)) {
549  const __half* A_h = reinterpret_cast<const __half*>(A);
550  const __half* B_h = reinterpret_cast<const __half*>(B);
551  __half* alpha_h = reinterpret_cast<__half*>(&alpha);
552  __half* beta_h = reinterpret_cast<__half*>(&beta);
553  __half* C_h = reinterpret_cast<__half*>(C);
554  cublasStatus_t err = cublasHgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
555  GetT(transa), GetT(transb), m, n, k, alpha_h,
556  A_h, lda, m * k,
557  B_h, ldb, k * n,
558  beta_h, C_h, ldc, m * n,
559  batch_count);
560  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: HgemmStridedBatched fail";
561  return;
562  }
563 #endif
564  for (int i = 0; i < batch_count; ++i) {
565  gemm(stream, transa, transb, m, n, k, alpha,
566  A + i * m * k, lda, B + i * k * n, ldb,
567  beta, C + i * m * n, ldc);
568  }
569  }
570  inline static void gemv(Stream<gpu> *stream,
571  bool trans, int m, int n, half::half_t alpha,
572  const half::half_t *A, int lda,
573  const half::half_t *X, int incX, half::half_t beta,
574  half::half_t *Y, int incY) {
575  LOG(FATAL) << "Not implemented!";
576  }
577  inline static void batched_gemv(Stream<gpu> *stream,
578  bool trans, int m, int n,
579  half::half_t alpha, const half::half_t *A, int lda,
580  const half::half_t *X, int incX,
581  half::half_t beta, half::half_t *Y, int incY, int batch_count) {
582  LOG(FATAL) << "Not implemented!";
583  }
584  inline static void ger(Stream<gpu> *stream,
585  int m, int n, half::half_t alpha,
586  const half::half_t *X, int incX,
587  const half::half_t *Y, int incY, half::half_t *A, int lda) {
588  LOG(FATAL) << "Not implemented!";
589  }
590  inline static void batched_ger(Stream<gpu> *stream,
591  int m, int n, half::half_t alpha,
592  const half::half_t *X, int incX, const half::half_t *Y, int incY,
593  half::half_t *A, int lda, int batch_count) {
594  LOG(FATAL) << "Not implemented!";
595  }
596  inline static void dot(Stream<gpu> *stream,
597  int n,
598  const half::half_t* X, int incX,
599  const half::half_t* Y, int incY,
600  half::half_t *ret) {
601  LOG(FATAL) << "Not implemented!";
602  }
603 };
604 
605 template<>
606 struct BLASEngine<gpu, float> {
607  inline static cublasOperation_t GetT(bool t) {
608  return t ? CUBLAS_OP_T : CUBLAS_OP_N;
609  }
610  inline static void SetStream(Stream<gpu> *stream) {
611  cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
612  Stream<gpu>::GetStream(stream));
613  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
614  }
615  inline static void gemm(Stream<gpu> *stream,
616  bool transa, bool transb,
617  int m, int n, int k, float alpha,
618  const float *A, int lda,
619  const float *B, int ldb, float beta,
620  float *C, int ldc) {
621  cublasStatus_t err = cublasSgemm(Stream<gpu>::GetBlasHandle(stream),
622  GetT(transa), GetT(transb), m, n, k, &alpha,
623  A, lda, B, ldb, &beta, C, ldc);
624  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemm fail";
625  }
626  inline static void batched_gemm(Stream<gpu> *stream,
627  bool transa, bool transb,
628  int m, int n, int k, float alpha,
629  const float *A, int lda, const float *B, int ldb,
630  float beta, float *C, int ldc, int batch_count,
631  float **workspace) {
632 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000
633  // Cast DType* to DType** using workspace as a buffer
634  bool alloc_workspace = false;
635  if (workspace == NULL) {
636  // Allocate the workspace if it's NULL.
637  // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe.
638  cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(float*));
639  alloc_workspace = true;
640  }
641  GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
642  GetBatchedView(workspace + batch_count,
643  const_cast<float*>(B), batch_count, k * n, stream);
644  GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
645  cublasStatus_t err = cublasSgemmBatched(Stream<gpu>::GetBlasHandle(stream),
646  GetT(transa), GetT(transb), m, n, k, &alpha,
647  (const float**)workspace, lda,
648  (const float**)(workspace + batch_count), ldb,
649  &beta, workspace + 2 * batch_count, ldc, batch_count);
650  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail";
651  if (alloc_workspace) {
652  cudaFree(workspace);
653  }
654 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000
655  cublasStatus_t err = cublasSgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
656  GetT(transa), GetT(transb), m, n, k, &alpha,
657  A, lda, m * k,
658  B, ldb, k * n,
659  &beta, C, ldc, m * n,
660  batch_count);
661  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmStridedBatched fail";
662 #else
663  for (int i = 0; i < batch_count; ++i) {
664  gemm(stream, transa, transb, m, n, k, alpha,
665  A + i * m * k, lda, B + i * k * n, ldb,
666  beta, C + i * m * n, ldc);
667  }
668 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
669  }
670  inline static void gemv(Stream<gpu> *stream,
671  bool trans, int m, int n, float alpha,
672  const float *A, int lda,
673  const float *X, int incX, float beta,
674  float *Y, int incY) {
675  cublasStatus_t err = cublasSgemv(Stream<gpu>::GetBlasHandle(stream),
676  GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
677  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail";
678  }
679  inline static void batched_gemv(Stream<gpu> *stream,
680  bool trans, int m, int n,
681  float alpha, const float *A, int lda,
682  const float *X, int incX,
683  float beta, float *Y, int incY, int batch_count) {
684  for (int i = 0; i < batch_count; ++i) {
685  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
686  X + i * (trans ? m : n) * incX, incX,
687  beta, Y + i * (trans ? n : m) * incY, incY);
688  }
689  }
690  inline static void ger(Stream<gpu> *stream,
691  int m, int n, float alpha,
692  const float *X, int incX,
693  const float *Y, int incY, float *A, int lda) {
694  cublasStatus_t err = cublasSger(Stream<gpu>::GetBlasHandle(stream),
695  m, n, &alpha, X, incX, Y, incY, A, lda);
696  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail";
697  }
698  inline static void batched_ger(Stream<gpu> *stream,
699  int m, int n, float alpha,
700  const float *X, int incX,
701  const float *Y, int incY, float *A, int lda, int batch_count) {
702  for (int i = 0; i < batch_count; ++i) {
703  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
704  A + i * lda * n, lda);
705  }
706  }
707  inline static void dot(Stream<gpu> *stream,
708  int n,
709  const float* X, int incX,
710  const float* Y, int incY,
711  float *ret) {
712  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
713  CUBLAS_POINTER_MODE_DEVICE);
714  cublasStatus_t err = cublasSdot(Stream<gpu>::GetBlasHandle(stream),
715  n, X, incX, Y, incY, ret);
716  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
717  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
718  CUBLAS_POINTER_MODE_HOST);
719  }
720 };
721 
722 template<>
723 struct BLASEngine<gpu, double> {
724  inline static cublasOperation_t GetT(bool t) {
725  return t ? CUBLAS_OP_T : CUBLAS_OP_N;
726  }
727  inline static void SetStream(Stream<gpu> *stream) {
728  cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
729  Stream<gpu>::GetStream(stream));
730  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
731  }
732  inline static void gemm(Stream<gpu> *stream,
733  bool transa, bool transb,
734  int m, int n, int k, double alpha,
735  const double *A, int lda,
736  const double *B, int ldb,
737  double beta, double *C, int ldc) {
738  cublasStatus_t err = cublasDgemm(Stream<gpu>::GetBlasHandle(stream),
739  GetT(transa), GetT(transb), m, n, k, &alpha,
740  A, lda, B, ldb, &beta, C, ldc);
741  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemm fail";
742  }
743  inline static void batched_gemm(Stream<gpu> *stream,
744  bool transa, bool transb,
745  int m, int n, int k, double alpha,
746  const double *A, int lda, const double *B, int ldb,
747  double beta, double *C, int ldc, int batch_count,
748  double **workspace) {
749 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000
750  // Cast DType* to DType** using workspace as a buffer
751  bool alloc_workspace = false;
752  if (workspace == NULL) {
753  // Allocate the workspace if it's NULL.
754  // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe.
755  cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(double*));
756  alloc_workspace = true;
757  }
758  GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
759  GetBatchedView(workspace + batch_count,
760  const_cast<double*>(B), batch_count, k * n, stream);
761  GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
762  cublasStatus_t err = cublasDgemmBatched(Stream<gpu>::GetBlasHandle(stream),
763  GetT(transa), GetT(transb), m, n, k, &alpha,
764  (const double**)workspace, lda,
765  (const double**)(workspace + batch_count), ldb,
766  &beta, workspace + 2 * batch_count, ldc, batch_count);
767  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail";
768  if (alloc_workspace) {
769  cudaFree(workspace);
770  }
771 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000
772  cublasStatus_t err = cublasDgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
773  GetT(transa), GetT(transb), m, n, k, &alpha,
774  A, lda, m * k,
775  B, ldb, k * n,
776  &beta, C, ldc, m * n,
777  batch_count);
778  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmStridedBatched fail";
779 #else
780  for (int i = 0; i < batch_count; ++i) {
781  gemm(stream, transa, transb, m, n, k, alpha,
782  A + i * m * k, lda, B + i * k * n, ldb,
783  beta, C + i * m * n, ldc);
784  }
785 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
786  }
787  inline static void gemv(Stream<gpu> *stream,
788  bool trans, int m, int n, double alpha,
789  const double *A, int lda,
790  const double *X, int incX,
791  double beta, double *Y, int incY) {
792  cublasStatus_t err = cublasDgemv(Stream<gpu>::GetBlasHandle(stream),
793  GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
794  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail";
795  }
796  inline static void batched_gemv(Stream<gpu> *stream,
797  bool trans, int m, int n,
798  double alpha, const double *A, int lda,
799  const double *X, int incX,
800  double beta, double *Y, int incY, int batch_count) {
801  for (int i = 0; i < batch_count; ++i) {
802  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
803  X + i * (trans ? m : n) * incX, incX,
804  beta, Y + i * (trans ? n : m) * incY, incY);
805  }
806  }
807  inline static void ger(Stream<gpu> *stream,
808  int m, int n, double alpha,
809  const double *X, int incX,
810  const double *Y, int incY, double *A, int lda) {
811  cublasStatus_t err = cublasDger(Stream<gpu>::GetBlasHandle(stream),
812  m, n, &alpha, X, incX, Y, incY, A, lda);
813  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail";
814  }
815  inline static void batched_ger(Stream<gpu> *stream,
816  int m, int n, double alpha,
817  const double *X, int incX,
818  const double *Y, int incY, double *A, int lda, int batch_count) {
819  for (int i = 0; i < batch_count; ++i) {
820  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
821  A + i * lda * n, lda);
822  }
823  }
824  inline static void dot(Stream<gpu> *stream,
825  int n,
826  const double* X, int incX,
827  const double* Y, int incY,
828  double *ret) {
829  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
830  CUBLAS_POINTER_MODE_DEVICE);
831  cublasStatus_t err = cublasDdot(Stream<gpu>::GetBlasHandle(stream),
832  n, X, incX, Y, incY, ret);
833  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
834  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
835  CUBLAS_POINTER_MODE_HOST);
836  }
837 };
838 #endif // MSHADOW_USE_CUDA
839 // helper function that returns the effective shape after an optional transpose
840 inline Shape<2> GetShape(const Shape<2> &shape, bool transpose) {
841  return transpose ? Shape2(shape[1], shape[0]) : shape;
842 }
843 // dst = dot(lhs[.T], rhs[.T])
844 template<typename SV, typename xpu,
845  bool transpose_left, bool transpose_right, typename DType>
846 struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> {
847  inline static void Eval(Tensor<xpu, 2, DType> *p_dst,
848  const Tensor<xpu, 2, DType> &lhs,
849  const Tensor<xpu, 2, DType> &rhs,
850  DType scale) {
851  Tensor<xpu, 2, DType> &dst = *p_dst;
852 #if MSHADOW_STAND_ALONE
853  if (xpu::kDevMask == cpu::kDevMask && scale == 1.0f) {
854  if (!transpose_left && !transpose_right) {
855  dst = expr::implicit_dot(lhs, rhs); return;
856  } else if (!transpose_left && transpose_right) {
857  dst = expr::implicit_dot(lhs, rhs.T()); return;
858  } else if (transpose_left && !transpose_right) {
859  dst = expr::implicit_dot(lhs.T(), rhs); return;
860  }
861  }
862 #endif
863  // set kernel stream
864  // if there is no stream, crash
865  BLASEngine<xpu, DType>::SetStream(dst.stream_);
866  Shape<2> sleft = GetShape(lhs.shape_, transpose_left);
867  Shape<2> sright = GetShape(rhs.shape_, transpose_right);
868  CHECK(dst.size(0) == sleft[0] && dst.size(1) == sright[1] && sleft[1] == sright[0])
869  << "dot-gemm: matrix shape mismatch";
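// mshadow tensors are row-major while the BLAS call below uses column-major storage,
// so dst = op(lhs) * op(rhs) is computed as dst^T = op(rhs)^T * op(lhs)^T: rhs is passed
// as the BLAS "A" operand and lhs as "B", with each operand keeping its own transpose
// flag (transa = transpose_right, transb = transpose_left).  The gemv/ger based
// specializations below rely on the same convention.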
870  // use column-major arguments to be compatible with most BLAS implementations
871  BLASEngine<xpu, DType>::gemm
872  (dst.stream_,
873  transpose_right , transpose_left,
874  transpose_right ? rhs.size(0) : rhs.size(1),
875  transpose_left ? lhs.size(1) : lhs.size(0),
876  transpose_right ? rhs.size(1) : rhs.size(0),
877  DType(scale * SV::AlphaBLAS()),
878  rhs.dptr_, rhs.stride_,
879  lhs.dptr_, lhs.stride_,
880  DType(SV::BetaBLAS()),
881  dst.dptr_, dst.stride_);
882  }
883 };
884 template<typename SV, typename xpu, bool transpose_right, typename DType>
885 struct DotEngine<SV, xpu, 1, 1, 2, false, transpose_right, DType> {
886  inline static void Eval(Tensor<xpu, 1, DType> *p_dst,
887  const Tensor<xpu, 1, DType> &lhs,
888  const Tensor<xpu, 2, DType> &rhs,
889  DType scale) {
890  Tensor<xpu, 1, DType> &dst = *p_dst;
891  // set kernel stream
892  // if there is no stream, crash
893  BLASEngine<xpu, DType>::SetStream(dst.stream_);
894  Shape<2> sright = GetShape(rhs.shape_, transpose_right);
895  CHECK(dst.size(0) == sright[1] && lhs.size(0) == sright[0])
896  << "dot-gemv: matrix shape mismatch"
897  << "dst: " << dst.shape_ << "\n"
898  << "lhs: " << lhs.shape_ << "\n"
899  << "rhs: " << sright << "\n";
900  BLASEngine<xpu, DType>::gemv
901  (dst.stream_,
902  transpose_right,
903  rhs.size(1), rhs.size(0), scale * SV::AlphaBLAS(),
904  rhs.dptr_, rhs.stride_,
905  lhs.dptr_, 1, SV::BetaBLAS(),
906  dst.dptr_, 1);
907  }
908 };
909 template<typename SV, typename xpu, typename DType>
910 struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
911  inline static void Eval(Tensor<xpu, 2, DType> *p_dst,
912  const Tensor<xpu, 1, DType> &lhs,
913  const Tensor<xpu, 1, DType> &rhs,
914  DType scale) {
915  Tensor<xpu, 2, DType> &dst = *p_dst;
916  // set kernel stream
917  // if there is no stream, crash
918  BLASEngine<xpu, DType>::SetStream(dst.stream_);
919  CHECK(dst.size(0) == lhs.size(0) && dst.size(1) == rhs.size(0))
920  << "dot-ger: matrix shape mismatch"
921  << "dst: " << dst.shape_ << "\n"
922  << "lhs: " << lhs.shape_ << "\n"
923  << "rhs: " << rhs.shape_;
924  if (SV::BetaBLAS() == 0.0f) {
925  BLASEngine<xpu, DType>::ger
926  (dst.stream_, rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(),
927  rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_);
928  } else {
929  DotEngine<SV, xpu, 2, 2, 2, true, false,
930  DType>::Eval(p_dst, lhs.FlatTo2D(), rhs.FlatTo2D(), scale);
931  }
932  }
933 };
934 } // namespace expr
935 } // namespace mshadow
936 #endif // MSHADOW_DOT_ENGINE_INL_H_
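
Usage sketch (not part of the header): the DotEngine specializations above are reached
through the dot() expression when a dot product is assigned to a tensor. A minimal CPU
example, assuming mshadow is built with CBLAS or MKL so that the BLASEngine<cpu, float>
specialization above is selected:

#include "mshadow/tensor.h"
using namespace mshadow;
using namespace mshadow::expr;

int main() {
  Tensor<cpu, 2, float> A(Shape2(3, 4)), B(Shape2(5, 4)), C(Shape2(3, 5));
  AllocSpace(&A); AllocSpace(&B); AllocSpace(&C);
  A = 1.0f; B = 2.0f;
  // C = A * B^T lowers to DotEngine<sv::saveto, cpu, 2, 2, 2, false, true, float>,
  // which calls BLASEngine<cpu, float>::gemm with alpha = 1 and beta = 0.
  C = dot(A, B.T());
  // += selects sv::plusto, i.e. BetaBLAS() == 1, so gemm accumulates into C.
  C += dot(A, B.T());
  FreeSpace(&A); FreeSpace(&B); FreeSpace(&C);
  return 0;
}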