dot_engine-inl.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_DOT_ENGINE_INL_H_
26 #define MSHADOW_DOT_ENGINE_INL_H_
27 
28 #include <vector>
 29 #include "./base.h"
 30 #include "./extension/implicit_gemm.h"
 31 
32 #ifdef __CUDACC__
33 #include "./cuda/tensor_gpu-inl.cuh"
34 #endif // #ifdef __CUDACC__
35 
 36 namespace mshadow {
 37 /*!
 38  * \brief CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride.
 39  * \param dst 2D pointer
 40  * \param src 1D pointer
 41  * \param num number of batches
 42  * \param stride size of each batch
 43  * \param stream stream used to run the copy (unused by the cpu version)
 44  */
45 template<typename Device, typename DType>
46 inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
47  Stream<Device> *stream);
48 template<typename DType>
49 inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
50  Stream<cpu> *stream) {
51  for (int i = 0; i < num; i++) {
52  dst[i] = src + i * stride;
53  }
54 }
55 #ifdef __CUDACC__
56 namespace cuda {};
57 template<typename DType>
58 inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
59  Stream<gpu> *stream) {
60  cuda::GetBatchedView(dst, src, num, stride, stream);
61 }
62 #endif // #ifdef __CUDACC__
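A minimal sketch of how the cpu overload of GetBatchedView is used: given a contiguous buffer holding several matrices back to back, it fills a table of per-batch pointers (the layout consumed by the pointer-array batched GEMM APIs further down). The sizes below are illustrative, and the nullptr stream is acceptable only because the cpu overload ignores it.

  // Build per-batch views over a contiguous buffer of 3 matrices of shape 2x2.
  float data[3 * 2 * 2] = {0};
  float *views[3];
  mshadow::GetBatchedView(views, data, /*num=*/3, /*stride=*/2 * 2,
                          static_cast<mshadow::Stream<mshadow::cpu>*>(nullptr));
  // views[i] now equals data + i * 4, i.e. the start of the i-th matrix.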
63 
64 namespace expr {
65 //---------------------------------------------------------------------
66 // Matrix Multiplications, depends on BLAS Engine
67 //---------------------------------------------------------------------
68 template<typename SV, typename Device, int ddim, int ldim,
69  int rdim, bool ltrans, bool rtrans, typename DType>
70 struct DotEngine {
71  inline static void Eval(Tensor<Device, ddim, DType> *p_dst,
72  const Tensor<Device, ldim, DType> &lhs,
73  const Tensor<Device, rdim, DType> &rhs,
74  DType scale);
75 };
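The DotEngine template parameters encode the dimensionality of the destination (ddim) and of the two operands (ldim, rdim) together with their transpose flags; the specializations below cover the 2-D/2-D (gemm), 1-D/2-D (gemv) and 1-D/1-D outer-product (ger) cases. Purely as an illustration, the instantiation that a 2-D product with a transposed right operand resolves to looks like this (sv::saveto is mshadow's plain-assignment saver; the element type and device are hypothetical):

  // Illustrative only: engine selected for dst = dot(lhs, rhs.T()) on cpu float tensors.
  using Engine = mshadow::expr::DotEngine<
      mshadow::sv::saveto, mshadow::cpu,
      /*ddim=*/2, /*ldim=*/2, /*rdim=*/2,
      /*ltrans=*/false, /*rtrans=*/true, float>;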
 76 // handles the dot operation; uses CblasColMajor conventions
77 template<typename Device, typename DType = default_real_t>
78 struct BLASEngine {
79  inline static bool GetT(bool t) {
80  return t ? true : false;
81  }
82  inline static void SetStream(Stream<Device> *stream) {
83  }
84  inline static void gemm(Stream<Device> *stream,
85  bool transa, bool transb,
86  int m, int n, int k, DType alpha,
87  const DType *A, int lda, const DType *B, int ldb,
88  DType beta, DType *C, int ldc) {
 89  LOG(FATAL) << "Not implemented!";
90  }
91  inline static void batched_gemm(Stream<Device> *stream,
92  bool transa, bool transb,
93  int m, int n, int k, DType alpha,
94  const DType *A, int lda, const DType *B, int ldb,
95  DType beta, DType *C, int ldc, int batch_count,
96  DType **workspace) {
 97  LOG(FATAL) << "Not implemented!";
98  }
99  inline static void gemv(Stream<Device> *stream,
100  bool trans, int m, int n,
101  DType alpha, const DType *A, int lda,
102  const DType *X, int incX,
103  DType beta, DType *Y, int incY) {
 104  LOG(FATAL) << "Not implemented!";
105  }
106  inline static void batched_gemv(Stream<Device> *stream,
107  bool trans, int m, int n,
108  DType alpha, const DType *A, int lda,
109  const DType *X, int incX,
110  DType beta, DType *Y, int incY, int batch_count) {
 111  LOG(FATAL) << "Not implemented!";
112  }
113  inline static void ger(Stream<Device> *stream,
114  int m, int n, DType alpha,
115  const DType *X, int incX,
116  const DType *Y, int incY, DType *A, int lda) {
 117  LOG(FATAL) << "Not implemented!";
118  }
119  inline static void batched_ger(Stream<Device> *stream,
120  int m, int n, DType alpha,
121  const DType *X, int incX,
122  const DType *Y, int incY, DType *A, int lda, int batch_count) {
 123  LOG(FATAL) << "Not implemented!";
124  }
125  inline static void dot(Stream<Device> *stream,
126  int n,
127  const DType* X, int incX,
128  const DType* Y, int incY,
129  DType* ret) {
 130  LOG(FATAL) << "Not implemented!";
131  }
132 };
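The wrappers above follow column-major (CblasColMajor) conventions, whereas mshadow tensors are row-major. The two are reconciled through the identity (A*B)^T = B^T * A^T: a row-major matrix reinterpreted as column-major storage is exactly its transpose, so the DotEngine specializations at the end of this file pass the operands in swapped order (rhs as the first matrix, lhs as the second, with swapped transpose flags) and obtain the row-major product without any copies. A minimal sketch of that call pattern, assuming a build with MSHADOW_USE_CBLAS or MSHADOW_USE_MKL; the shapes and the nullptr stream (ignored by the cpu wrapper) are illustrative:

  // Row-major C(2x3) = A(2x4) * B(4x3), computed through the column-major gemm
  // wrapper by swapping the operands -- the same trick DotEngine uses below.
  const int m = 2, k = 4, n = 3;
  std::vector<float> A(m * k, 1.0f), B(k * n, 1.0f), C(m * n, 0.0f);
  mshadow::expr::BLASEngine<mshadow::cpu, float>::gemm(
      /*stream=*/nullptr,
      /*transa=*/false, /*transb=*/false,
      /*m=*/n, /*n=*/m, /*k=*/k,      // dimensions of C^T
      1.0f,
      B.data(), /*lda=*/n,            // row-major B viewed as column-major is B^T
      A.data(), /*ldb=*/k,            // row-major A viewed as column-major is A^T
      0.0f,
      C.data(), /*ldc=*/n);
  // C now holds the row-major product; every entry equals k = 4.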
133 
134 #if MSHADOW_STAND_ALONE
135 template<>
136 struct BLASEngine<cpu, float> {
137  inline static bool GetT(bool t) {
138  return t ? true : false;
139  }
140  inline static void SetStream(Stream<cpu> *stream) {
141  }
142  inline static void gemm(Stream<cpu> *stream,
143  bool transa, bool transb,
144  int m, int n, int k, float alpha,
145  const float *A, int lda, const float *B, int ldb,
146  float beta, float *C, int ldc) {
147  if (alpha == 1.0f && beta == 0.0f) {
148  bool transpose_left = transb;
149  bool transpose_right = transa;
150  Tensor<cpu, 2, float> lhs((float*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*)
151  Tensor<cpu, 2, float> rhs((float*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*)
152  Tensor<cpu, 2, float> dst(C, Shape2(m, n));
153  if (!transpose_left && !transpose_right) {
154  dst = expr::implicit_dot(lhs, rhs); return;
155  } else if (!transpose_left && transpose_right) {
156  dst = expr::implicit_dot(lhs, rhs.T()); return;
157  } else if (transpose_left && !transpose_right) {
158  dst = expr::implicit_dot(lhs.T(), rhs); return;
159  } else {
 160  LOG(FATAL) << "Not implemented!";
161  }
162  } else {
 163  LOG(FATAL) << "Not implemented!";
164  }
165  }
166  inline static void batched_gemm(Stream<cpu> *stream,
167  bool transa, bool transb,
168  int m, int n, int k, float alpha,
169  const float *A, int lda, const float *B, int ldb,
170  float beta, float *C, int ldc, int batch_count,
171  float **workspace) {
172  for (int i = 0; i < batch_count; ++i) {
173  gemm(stream, transa, transb, m, n, k, alpha,
174  A + i * m * k, lda, B + i * k * n, ldb,
175  beta, C + i * m * n, ldc);
176  }
177  }
178  inline static void gemv(Stream<cpu> *stream,
179  bool trans, int m, int n,
180  float alpha, const float *A, int lda,
181  const float *X, int incX,
182  float beta, float *Y, int incY) {
 183  LOG(FATAL) << "Not implemented!";
184  }
185  inline static void batched_gemv(Stream<cpu> *stream,
186  bool trans, int m, int n,
187  float alpha, const float *A, int lda,
188  const float *X, int incX,
189  float beta, float *Y, int incY, int batch_count) {
 190  LOG(FATAL) << "Not implemented!";
191  }
192  inline static void ger(Stream<cpu> *stream,
193  int m, int n, float alpha,
194  const float *X, int incX,
195  const float *Y, int incY, float *A, int lda) {
 196  LOG(FATAL) << "Not implemented!";
197  }
198  inline static void batched_ger(Stream<cpu> *stream,
199  int m, int n, float alpha,
200  const float *X, int incX,
201  const float *Y, int incY, float *A, int lda, int batch_count) {
 202  LOG(FATAL) << "Not implemented!";
203  }
204  inline static void dot(Stream<cpu> *stream,
205  int n,
206  const float* X, int incX,
207  const float* Y, int incY,
208  float* ret) {
 209  LOG(FATAL) << "Not implemented!";
210  }
211 };
212 
213 template<>
214 struct BLASEngine<cpu, double> {
215  inline static bool GetT(bool t) {
216  return t ? true : false;
217  }
218  inline static void SetStream(Stream<cpu> *stream) {
219  }
220  inline static void gemm(Stream<cpu> *stream,
221  bool transa, bool transb,
222  int m, int n, int k, double alpha,
223  const double *A, int lda, const double *B, int ldb,
224  double beta, double *C, int ldc) {
225  if (alpha == 1.0f && beta == 0.0f) {
226  bool transpose_left = transb;
227  bool transpose_right = transa;
228  Tensor<cpu, 2, double> lhs((double*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*)
229  Tensor<cpu, 2, double> rhs((double*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*)
230  Tensor<cpu, 2, double> dst(C, Shape2(m, n));
231  if (!transpose_left && !transpose_right) {
232  dst = expr::implicit_dot(lhs, rhs); return;
233  } else if (!transpose_left && transpose_right) {
234  dst = expr::implicit_dot(lhs, rhs.T()); return;
235  } else if (transpose_left && !transpose_right) {
236  dst = expr::implicit_dot(lhs.T(), rhs); return;
237  } else {
 238  LOG(FATAL) << "Not implemented!";
239  }
240  } else {
 241  LOG(FATAL) << "Not implemented!";
242  }
243  }
244  inline static void batched_gemm(Stream<cpu> *stream,
245  bool transa, bool transb,
246  int m, int n, int k, double alpha,
247  const double *A, int lda, const double *B, int ldb,
248  double beta, double *C, int ldc, int batch_count,
249  double **workspace) {
250  for (int i = 0; i < batch_count; ++i) {
251  gemm(stream, transa, transb, m, n, k, alpha,
252  A + i * m * k, lda, B + i * k * n, ldb,
253  beta, C + i * m * n, ldc);
254  }
255  }
256  inline static void gemv(Stream<cpu> *stream,
257  bool trans, int m, int n,
258  double alpha, const double *A, int lda,
259  const double *X, int incX,
260  double beta, double *Y, int incY) {
 261  LOG(FATAL) << "Not implemented!";
262  }
263  inline static void batched_gemv(Stream<cpu> *stream,
264  bool trans, int m, int n,
265  double alpha, const double *A, int lda,
266  const double *X, int incX,
267  double beta, double *Y, int incY, int batch_count) {
 268  LOG(FATAL) << "Not implemented!";
269  }
270  inline static void ger(Stream<cpu> *stream,
271  int m, int n, double alpha,
272  const double *X, int incX,
273  const double *Y, int incY, double *A, int lda) {
 274  LOG(FATAL) << "Not implemented!";
275  }
276  inline static void batched_ger(Stream<cpu> *stream,
277  int m, int n, double alpha,
278  const double *X, int incX,
279  const double *Y, int incY, double *A, int lda, int batch_count) {
 280  LOG(FATAL) << "Not implemented!";
281  }
282  inline static void dot(Stream<cpu> *stream,
283  int n,
284  const double* X, int incX,
285  const double* Y, int incY,
286  double* ret) {
 287  LOG(FATAL) << "Not implemented!";
288  }
289 };
290 
291 #elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*)
292 template<>
293 struct BLASEngine<cpu, float> {
294  inline static CBLAS_TRANSPOSE GetT(bool t) {
295  return t ? CblasTrans : CblasNoTrans;
296  }
297  inline static void SetStream(Stream<cpu> *stream) {
298  }
299  inline static void gemm(Stream<cpu> *stream,
300  bool transa, bool transb,
301  index_t m, index_t n, index_t k, float alpha,
302  const float *A, index_t lda, const float *B, index_t ldb,
303  float beta, float *C, index_t ldc) {
304  cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
305  m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
306  }
307  inline static void batched_gemm(Stream<cpu> *stream,
308  bool transa, bool transb,
309  index_t m, index_t n, index_t k, float alpha,
310  const float *A, index_t lda, const float *B, index_t ldb,
311  float beta, float *C, index_t ldc, index_t batch_count,
312  float **workspace) {
313 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
 314  // since the same m/n/k is used for all of the single gemms, we put them all into one group
315  const int GROUP_SIZE = 1;
316  MKL_INT p_m[GROUP_SIZE] = {static_cast<MKL_INT>(m)};
317  MKL_INT p_n[GROUP_SIZE] = {static_cast<MKL_INT>(n)};
318  MKL_INT p_k[GROUP_SIZE] = {static_cast<MKL_INT>(k)};
319  MKL_INT p_lda[GROUP_SIZE] = {static_cast<MKL_INT>(lda)};
320  MKL_INT p_ldb[GROUP_SIZE] = {static_cast<MKL_INT>(ldb)};
321  MKL_INT p_ldc[GROUP_SIZE] = {static_cast<MKL_INT>(ldc)};
322 
323  float p_alpha[GROUP_SIZE] = {alpha};
324  float p_beta[GROUP_SIZE] = {beta};
325 
326  CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
327  CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
328 
329  MKL_INT p_group_sizeb[GROUP_SIZE] = {static_cast<MKL_INT>(batch_count)};
330  CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
331  CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
332 
333  std::vector<const float*> pp_A(batch_count, nullptr);
334  std::vector<const float*> pp_B(batch_count, nullptr);
335  std::vector<float*> pp_C(batch_count, nullptr);
336 
337  auto m_k = m * k;
338  auto k_n = k * n;
339  auto m_n = m * n;
340 
341  for (int i = 0; i < batch_count; i++) {
342  pp_A[i] = A + i * m_k;
343  pp_B[i] = B + i * k_n;
344  pp_C[i] = C + i * m_n;
345  }
346 
347  cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
348  p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
349  p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
350 #else
351  for (int i = 0; i < batch_count; ++i) {
352  gemm(stream, transa, transb, m, n, k, alpha,
353  A + i * m * k, lda, B + i * k * n, ldb,
354  beta, C + i * m * n, ldc);
355  }
356 #endif
357  }
358  inline static void gemv(Stream<cpu> *stream,
359  bool trans, int m, int n,
360  float alpha, const float *A, int lda,
361  const float *X, int incX,
362  float beta, float *Y, int incY) {
363  cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha,
364  A, lda, X, incX, beta, Y, incY);
365  }
366  inline static void batched_gemv(Stream<cpu> *stream,
367  bool trans, int m, int n,
368  float alpha, const float *A, int lda,
369  const float *X, int incX,
370  float beta, float *Y, int incY, int batch_count) {
371  for (int i = 0; i < batch_count; ++i) {
372  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
373  X + i * (trans ? m : n) * incX, incX,
374  beta, Y + i * (trans ? n : m) * incY, incY);
375  }
376  }
377  inline static void ger(Stream<cpu> *stream,
378  int m, int n, float alpha,
379  const float *X, int incX,
380  const float *Y, int incY, float *A, int lda) {
381  cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
382  }
383  inline static void batched_ger(Stream<cpu> *stream,
384  int m, int n, float alpha,
385  const float *X, int incX,
386  const float *Y, int incY, float *A, int lda, int batch_count) {
387  for (int i = 0; i < batch_count; ++i) {
388  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
389  A + i * lda * n, lda);
390  }
391  }
392  inline static void dot(Stream<cpu> *stream,
393  int n,
394  const float* X, int incX,
395  const float* Y, int incY,
396  float* ret) {
397  *ret = cblas_sdot(n, X, incX, Y, incY);
398  }
399 };
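A minimal sketch of the level-1 wrapper under the same CBLAS/MKL build assumption (the stream argument is unused on the cpu):

  // Inner product through the wrapper; forwards to cblas_sdot.
  float x[3] = {1.f, 2.f, 3.f};
  float y[3] = {4.f, 5.f, 6.f};
  float out = 0.f;
  mshadow::expr::BLASEngine<mshadow::cpu, float>::dot(
      /*stream=*/nullptr, /*n=*/3, x, /*incX=*/1, y, /*incY=*/1, &out);
  // out == 32.0f  (1*4 + 2*5 + 3*6)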
400 
401 template<>
402 struct BLASEngine<cpu, double> {
403  inline static CBLAS_TRANSPOSE GetT(bool t) {
404  return t ? CblasTrans : CblasNoTrans;
405  }
406  inline static void SetStream(Stream<cpu> *stream) {
407  }
408  inline static void gemm(Stream<cpu> *stream,
409  bool transa, bool transb,
410  index_t m, index_t n, index_t k, double alpha,
411  const double *A, index_t lda, const double *B, index_t ldb,
412  double beta, double *C, index_t ldc) {
413  cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb),
414  m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
415  }
416  inline static void batched_gemm(Stream<cpu> *stream,
417  bool transa, bool transb,
418  index_t m, index_t n, index_t k, double alpha,
419  const double *A, index_t lda, const double *B, index_t ldb,
420  double beta, double *C, index_t ldc, index_t batch_count,
421  double **workspace) {
422 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
 423  // since the same m/n/k is used for all of the single gemms, we put them all into one group
424  const int GROUP_SIZE = 1;
425  MKL_INT p_m[GROUP_SIZE] = {static_cast<MKL_INT>(m)};
426  MKL_INT p_n[GROUP_SIZE] = {static_cast<MKL_INT>(n)};
427  MKL_INT p_k[GROUP_SIZE] = {static_cast<MKL_INT>(k)};
428  MKL_INT p_lda[GROUP_SIZE] = {static_cast<MKL_INT>(lda)};
429  MKL_INT p_ldb[GROUP_SIZE] = {static_cast<MKL_INT>(ldb)};
430  MKL_INT p_ldc[GROUP_SIZE] = {static_cast<MKL_INT>(ldc)};
431 
432  double p_alpha[GROUP_SIZE] = {alpha};
433  double p_beta[GROUP_SIZE] = {beta};
434 
435  CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
436  CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
437 
438  MKL_INT p_group_sizeb[GROUP_SIZE] = {static_cast<MKL_INT>(batch_count)};
439  CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
440  CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
441 
442  std::vector<const double*> pp_A(batch_count, nullptr);
443  std::vector<const double*> pp_B(batch_count, nullptr);
444  std::vector<double*> pp_C(batch_count, nullptr);
445 
446  auto m_k = m * k;
447  auto k_n = k * n;
448  auto m_n = m * n;
449 
450  for (int i = 0; i < batch_count; i++) {
451  pp_A[i] = A + i * m_k;
452  pp_B[i] = B + i * k_n;
453  pp_C[i] = C + i * m_n;
454  }
455 
456  cblas_dgemm_batch(CblasColMajor, p_transa, p_transb,
457  p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
458  p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
459 #else
460  for (int i = 0; i < batch_count; ++i) {
461  gemm(stream, transa, transb, m, n, k, alpha,
462  A + i * m * k, lda, B + i * k * n, ldb,
463  beta, C + i * m * n, ldc);
464  }
465 #endif
466  }
467  inline static void gemv(Stream<cpu> *stream,
468  bool trans, int m, int n, double alpha,
469  const double *A, int lda,
470  const double *X, int incX,
471  double beta, double *Y, int incY) {
472  cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha,
473  A, lda, X, incX, beta, Y, incY);
474  }
475  inline static void batched_gemv(Stream<cpu> *stream,
476  bool trans, int m, int n,
477  double alpha, const double *A, int lda,
478  const double *X, int incX,
479  double beta, double *Y, int incY, int batch_count) {
480  for (int i = 0; i < batch_count; ++i) {
481  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
482  X + i * (trans ? m : n) * incX, incX,
483  beta, Y + i * (trans ? n : m) * incY, incY);
484  }
485  }
486  inline static void ger(Stream<cpu> *stream,
487  int m, int n, double alpha,
488  const double *X, int incX,
489  const double *Y, int incY, double *A, int lda) {
490  cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
491  }
492  inline static void batched_ger(Stream<cpu> *stream,
493  int m, int n, double alpha,
494  const double *X, int incX,
495  const double *Y, int incY, double *A, int lda, int batch_count) {
496  for (int i = 0; i < batch_count; ++i) {
497  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
498  A + i * lda * n, lda);
499  }
500  }
501  inline static void dot(Stream<cpu> *stream,
502  int n,
503  const double* X, int incX,
504  const double* Y, int incY,
505  double* ret) {
506  *ret = cblas_ddot(n, X, incX, Y, incY);
507  }
508 };
509 #endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE
510 // CuBLAS redirect code
511 #if MSHADOW_USE_CUDA
 512 // All cuBLAS calls go through here; uses the legacy API: not thread-safe
513 template<>
514 struct BLASEngine<gpu, half::half_t> {
515  inline static cublasOperation_t GetT(bool t) {
516  return t ? CUBLAS_OP_T : CUBLAS_OP_N;
517  }
518  inline static void SetStream(Stream<gpu> *stream) {
519  cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
520  Stream<gpu>::GetStream(stream));
521  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail";
522  }
523  inline static void gemm(Stream<gpu> *stream,
524  bool transa, bool transb,
525  int m, int n, int k, half::half_t alpha,
526  const half::half_t *A, int lda,
527  const half::half_t *B, int ldb, half::half_t beta,
528  half::half_t *C, int ldc) {
529 #if defined(CUDA_VERSION) && CUDA_VERSION >= 7050
530  // Always use pseudo-fp16: fp32 compute with fp16 I/O.
531  float alpha_f = float(alpha); // NOLINT(*)
532  float beta_f = float(beta); // NOLINT(*)
533  #if CUDA_VERSION >= 8000
534  cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
535  GetT(transa), GetT(transb), m, n, k, &alpha_f,
536  A, CUDA_R_16F, lda, B, CUDA_R_16F,
537  ldb, &beta_f, C, CUDA_R_16F, ldc);
538  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
539  #else
540  cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
541  GetT(transa), GetT(transb), m, n, k, &alpha_f,
542  A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF,
543  ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
544  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
545  #endif // CUDA_VERSION >= 8000
546 #else
547  LOG(FATAL) << "Require CUDA version >= 7.5!";
548 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050
549  }
550  inline static void batched_gemm(Stream<gpu> *stream,
551  bool transa, bool transb,
552  int m, int n, int k, half::half_t alpha,
553  const half::half_t *A, int lda, const half::half_t *B, int ldb,
554  half::half_t beta, half::half_t *C, int ldc, int batch_count,
555  half::half_t **workspace) {
556 #if defined(__CUDACC__) && CUDA_VERSION >= 9000
557  int major = stream->prop.major;
558  int minor = stream->prop.minor;
559  // fp16 is not supported before ARCH 53
560  if ((major > 5) || (major == 5 && minor >= 3)) {
561  const __half* A_h = reinterpret_cast<const __half*>(A);
562  const __half* B_h = reinterpret_cast<const __half*>(B);
563  __half* alpha_h = reinterpret_cast<__half*>(&alpha);
564  __half* beta_h = reinterpret_cast<__half*>(&beta);
565  __half* C_h = reinterpret_cast<__half*>(C);
566  cublasStatus_t err = cublasHgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
567  GetT(transa), GetT(transb), m, n, k, alpha_h,
568  A_h, lda, m * k,
569  B_h, ldb, k * n,
570  beta_h, C_h, ldc, m * n,
571  batch_count);
572  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: HgemmStridedBatched fail";
573  return;
574  }
575 #endif
576  for (int i = 0; i < batch_count; ++i) {
577  gemm(stream, transa, transb, m, n, k, alpha,
578  A + i * m * k, lda, B + i * k * n, ldb,
579  beta, C + i * m * n, ldc);
580  }
581  }
582  inline static void gemv(Stream<gpu> *stream,
583  bool trans, int m, int n, half::half_t alpha,
584  const half::half_t *A, int lda,
585  const half::half_t *X, int incX, half::half_t beta,
586  half::half_t *Y, int incY) {
 587  LOG(FATAL) << "Not implemented!";
588  }
589  inline static void batched_gemv(Stream<gpu> *stream,
590  bool trans, int m, int n,
591  half::half_t alpha, const half::half_t *A, int lda,
592  const half::half_t *X, int incX,
593  half::half_t beta, half::half_t *Y, int incY, int batch_count) {
 594  LOG(FATAL) << "Not implemented!";
595  }
596  inline static void ger(Stream<gpu> *stream,
597  int m, int n, half::half_t alpha,
598  const half::half_t *X, int incX,
599  const half::half_t *Y, int incY, half::half_t *A, int lda) {
 600  LOG(FATAL) << "Not implemented!";
601  }
602  inline static void batched_ger(Stream<gpu> *stream,
603  int m, int n, half::half_t alpha,
604  const half::half_t *X, int incX, const half::half_t *Y, int incY,
605  half::half_t *A, int lda, int batch_count) {
 606  LOG(FATAL) << "Not implemented!";
607  }
608  inline static void dot(Stream<gpu> *stream,
609  int n,
610  const half::half_t* X, int incX,
611  const half::half_t* Y, int incY,
612  half::half_t *ret) {
 613  LOG(FATAL) << "Not implemented!";
614  }
615 };
616 
617 template<>
618 struct BLASEngine<gpu, float> {
619  inline static cublasOperation_t GetT(bool t) {
620  return t ? CUBLAS_OP_T : CUBLAS_OP_N;
621  }
622  inline static void SetStream(Stream<gpu> *stream) {
623  cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
624  Stream<gpu>::GetStream(stream));
625  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
626  }
627  inline static void gemm(Stream<gpu> *stream,
628  bool transa, bool transb,
629  int m, int n, int k, float alpha,
630  const float *A, int lda,
631  const float *B, int ldb, float beta,
632  float *C, int ldc) {
633  cublasStatus_t err = cublasSgemm(Stream<gpu>::GetBlasHandle(stream),
634  GetT(transa), GetT(transb), m, n, k, &alpha,
635  A, lda, B, ldb, &beta, C, ldc);
636  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemm fail";
637  }
638  inline static void batched_gemm(Stream<gpu> *stream,
639  bool transa, bool transb,
640  int m, int n, int k, float alpha,
641  const float *A, int lda, const float *B, int ldb,
642  float beta, float *C, int ldc, int batch_count,
643  float **workspace) {
644 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000
645  // Cast DType* to DType** using workspace as a buffer
646  bool alloc_workspace = false;
647  if (workspace == NULL) {
648  // Allocate the workspace if it's NULL.
649  // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe.
650  cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(float*));
651  alloc_workspace = true;
652  }
653  GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
654  GetBatchedView(workspace + batch_count,
655  const_cast<float*>(B), batch_count, k * n, stream);
656  GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
657  cublasStatus_t err = cublasSgemmBatched(Stream<gpu>::GetBlasHandle(stream),
658  GetT(transa), GetT(transb), m, n, k, &alpha,
659  (const float**)workspace, lda,
660  (const float**)(workspace + batch_count), ldb,
661  &beta, workspace + 2 * batch_count, ldc, batch_count);
662  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail";
663  if (alloc_workspace) {
664  cudaFree(workspace);
665  }
666 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000
667  cublasStatus_t err = cublasSgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
668  GetT(transa), GetT(transb), m, n, k, &alpha,
669  A, lda, m * k,
670  B, ldb, k * n,
671  &beta, C, ldc, m * n,
672  batch_count);
673  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmStridedBatched fail";
674 #else
675  for (int i = 0; i < batch_count; ++i) {
676  gemm(stream, transa, transb, m, n, k, alpha,
677  A + i * m * k, lda, B + i * k * n, ldb,
678  beta, C + i * m * n, ldc);
679  }
680 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
681  }
682  inline static void gemv(Stream<gpu> *stream,
683  bool trans, int m, int n, float alpha,
684  const float *A, int lda,
685  const float *X, int incX, float beta,
686  float *Y, int incY) {
687  cublasStatus_t err = cublasSgemv(Stream<gpu>::GetBlasHandle(stream),
688  GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
689  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail";
690  }
691  inline static void batched_gemv(Stream<gpu> *stream,
692  bool trans, int m, int n,
693  float alpha, const float *A, int lda,
694  const float *X, int incX,
695  float beta, float *Y, int incY, int batch_count) {
696  for (int i = 0; i < batch_count; ++i) {
697  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
698  X + i * (trans ? m : n) * incX, incX,
699  beta, Y + i * (trans ? n : m) * incY, incY);
700  }
701  }
702  inline static void ger(Stream<gpu> *stream,
703  int m, int n, float alpha,
704  const float *X, int incX,
705  const float *Y, int incY, float *A, int lda) {
706  cublasStatus_t err = cublasSger(Stream<gpu>::GetBlasHandle(stream),
707  m, n, &alpha, X, incX, Y, incY, A, lda);
708  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail";
709  }
710  inline static void batched_ger(Stream<gpu> *stream,
711  int m, int n, float alpha,
712  const float *X, int incX,
713  const float *Y, int incY, float *A, int lda, int batch_count) {
714  for (int i = 0; i < batch_count; ++i) {
715  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
716  A + i * lda * n, lda);
717  }
718  }
719  inline static void dot(Stream<gpu> *stream,
720  int n,
721  const float* X, int incX,
722  const float* Y, int incY,
723  float *ret) {
724  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
725  CUBLAS_POINTER_MODE_DEVICE);
726  cublasStatus_t err = cublasSdot(Stream<gpu>::GetBlasHandle(stream),
727  n, X, incX, Y, incY, ret);
728  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
729  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
730  CUBLAS_POINTER_MODE_HOST);
731  }
732 };
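Note that the dot() wrapper above switches the cuBLAS handle to device pointer mode for the duration of the call, so ret must point to GPU memory. A hedged host-side sketch, assuming stream is a valid Stream<gpu>* with a BLAS handle and d_x, d_y are device arrays of length n (error checking omitted):

  float *d_ret = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&d_ret), sizeof(float));
  mshadow::expr::BLASEngine<mshadow::gpu, float>::dot(stream, n, d_x, 1, d_y, 1, d_ret);
  cudaStreamSynchronize(mshadow::Stream<mshadow::gpu>::GetStream(stream));
  float h_ret = 0.f;
  cudaMemcpy(&h_ret, d_ret, sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_ret);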
733 
734 template<>
735 struct BLASEngine<gpu, double> {
736  inline static cublasOperation_t GetT(bool t) {
737  return t ? CUBLAS_OP_T : CUBLAS_OP_N;
738  }
739  inline static void SetStream(Stream<gpu> *stream) {
740  cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
741  Stream<gpu>::GetStream(stream));
742  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
743  }
744  inline static void gemm(Stream<gpu> *stream,
745  bool transa, bool transb,
746  int m, int n, int k, double alpha,
747  const double *A, int lda,
748  const double *B, int ldb,
749  double beta, double *C, int ldc) {
750  cublasStatus_t err = cublasDgemm(Stream<gpu>::GetBlasHandle(stream),
751  GetT(transa), GetT(transb), m, n, k, &alpha,
752  A, lda, B, ldb, &beta, C, ldc);
753  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemm fail";
754  }
755  inline static void batched_gemm(Stream<gpu> *stream,
756  bool transa, bool transb,
757  int m, int n, int k, double alpha,
758  const double *A, int lda, const double *B, int ldb,
759  double beta, double *C, int ldc, int batch_count,
760  double **workspace) {
761 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000
762  // Cast DType* to DType** using workspace as a buffer
763  bool alloc_workspace = false;
764  if (workspace == NULL) {
765  // Allocate the workspace if it's NULL.
766  // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe.
767  cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(double*));
768  alloc_workspace = true;
769  }
770  GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
771  GetBatchedView(workspace + batch_count,
772  const_cast<double*>(B), batch_count, k * n, stream);
773  GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
774  cublasStatus_t err = cublasDgemmBatched(Stream<gpu>::GetBlasHandle(stream),
775  GetT(transa), GetT(transb), m, n, k, &alpha,
776  (const double**)workspace, lda,
777  (const double**)(workspace + batch_count), ldb,
778  &beta, workspace + 2 * batch_count, ldc, batch_count);
779  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail";
780  if (alloc_workspace) {
781  cudaFree(workspace);
782  }
783 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000
784  cublasStatus_t err = cublasDgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
785  GetT(transa), GetT(transb), m, n, k, &alpha,
786  A, lda, m * k,
787  B, ldb, k * n,
788  &beta, C, ldc, m * n,
789  batch_count);
790  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmStridedBatched fail";
791 #else
792  for (int i = 0; i < batch_count; ++i) {
793  gemm(stream, transa, transb, m, n, k, alpha,
794  A + i * m * k, lda, B + i * k * n, ldb,
795  beta, C + i * m * n, ldc);
796  }
797 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
798  }
799  inline static void gemv(Stream<gpu> *stream,
800  bool trans, int m, int n, double alpha,
801  const double *A, int lda,
802  const double *X, int incX,
803  double beta, double *Y, int incY) {
804  cublasStatus_t err = cublasDgemv(Stream<gpu>::GetBlasHandle(stream),
805  GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
806  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail";
807  }
808  inline static void batched_gemv(Stream<gpu> *stream,
809  bool trans, int m, int n,
810  double alpha, const double *A, int lda,
811  const double *X, int incX,
812  double beta, double *Y, int incY, int batch_count) {
813  for (int i = 0; i < batch_count; ++i) {
814  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
815  X + i * (trans ? m : n) * incX, incX,
816  beta, Y + i * (trans ? n : m) * incY, incY);
817  }
818  }
819  inline static void ger(Stream<gpu> *stream,
820  int m, int n, double alpha,
821  const double *X, int incX,
822  const double *Y, int incY, double *A, int lda) {
823  cublasStatus_t err = cublasDger(Stream<gpu>::GetBlasHandle(stream),
824  m, n, &alpha, X, incX, Y, incY, A, lda);
825  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail";
826  }
827  inline static void batched_ger(Stream<gpu> *stream,
828  int m, int n, double alpha,
829  const double *X, int incX,
830  const double *Y, int incY, double *A, int lda, int batch_count) {
831  for (int i = 0; i < batch_count; ++i) {
832  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
833  A + i * lda * n, lda);
834  }
835  }
836  inline static void dot(Stream<gpu> *stream,
837  int n,
838  const double* X, int incX,
839  const double* Y, int incY,
840  double *ret) {
841  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
842  CUBLAS_POINTER_MODE_DEVICE);
843  cublasStatus_t err = cublasDdot(Stream<gpu>::GetBlasHandle(stream),
844  n, X, incX, Y, incY, ret);
845  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
846  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
847  CUBLAS_POINTER_MODE_HOST);
848  }
849 };
850 #endif // MSHADOW_USE_CUDA
 851 // helper function deciding the effective shape after an optional transpose
852 inline Shape<2> GetShape(const Shape<2> &shape, bool transpose) {
853  return transpose ? Shape2(shape[1], shape[0]) : shape;
854 }
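A one-line illustration of the helper (shapes are hypothetical):

  mshadow::Shape<2> s = mshadow::expr::GetShape(mshadow::Shape2(2, 3), /*transpose=*/true);
  // s[0] == 3, s[1] == 2; with transpose == false the input shape is returned unchanged.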
855 // dst = dot(lhs[.T], rhs[.T])
856 template<typename SV, typename xpu,
857  bool transpose_left, bool transpose_right, typename DType>
858 struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> {
859  inline static void Eval(Tensor<xpu, 2, DType> *p_dst,
860  const Tensor<xpu, 2, DType> &lhs,
861  const Tensor<xpu, 2, DType> &rhs,
862  DType scale) {
863  Tensor<xpu, 2, DType> &dst = *p_dst;
864 #if MSHADOW_STAND_ALONE
865  if (xpu::kDevMask == cpu::kDevMask && scale == 1.0f) {
866  if (!transpose_left && !transpose_right) {
867  dst = expr::implicit_dot(lhs, rhs); return;
868  } else if (!transpose_left && transpose_right) {
869  dst = expr::implicit_dot(lhs, rhs.T()); return;
870  } else if (transpose_left && !transpose_right) {
871  dst = expr::implicit_dot(lhs.T(), rhs); return;
872  }
873  }
874 #endif
875  // set kernel stream
 876  // if there is no stream, crash
 877  BLASEngine<xpu, DType>::SetStream(dst.stream_);
 878  Shape<2> sleft = GetShape(lhs.shape_, transpose_left);
879  Shape<2> sright = GetShape(rhs.shape_, transpose_right);
880  CHECK(dst.size(0) == sleft[0] && dst.size(1) == sright[1] && sleft[1] == sright[0])
881  << "dot-gemm: matrix shape mismatch";
 882  // use column major arguments to be compatible with most BLAS
 883  BLASEngine<xpu, DType>::gemm
 884  (dst.stream_,
885  transpose_right , transpose_left,
886  transpose_right ? rhs.size(0) : rhs.size(1),
887  transpose_left ? lhs.size(1) : lhs.size(0),
888  transpose_right ? rhs.size(1) : rhs.size(0),
889  DType(scale * SV::AlphaBLAS()),
890  rhs.dptr_, rhs.stride_,
891  lhs.dptr_, lhs.stride_,
892  DType(SV::BetaBLAS()),
893  dst.dptr_, dst.stride_);
894  }
895 };
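A minimal end-to-end sketch of the path that reaches this specialization, assuming a cpu build with CBLAS or MKL; the tensor allocation helpers and the dot() expression come from the rest of mshadow, and the shapes are illustrative:

  mshadow::Tensor<mshadow::cpu, 2, float> A(mshadow::Shape2(2, 4));
  mshadow::Tensor<mshadow::cpu, 2, float> B(mshadow::Shape2(4, 3));
  mshadow::Tensor<mshadow::cpu, 2, float> C(mshadow::Shape2(2, 3));
  mshadow::AllocSpace(&A); mshadow::AllocSpace(&B); mshadow::AllocSpace(&C);
  A = 1.0f; B = 1.0f;
  C = mshadow::expr::dot(A, B);  // resolves to DotEngine<sv::saveto, cpu, 2, 2, 2, false, false, float>
  // Every element of C is now 4.0f; dot(A.T(), B) etc. select the transposed variants.
  mshadow::FreeSpace(&A); mshadow::FreeSpace(&B); mshadow::FreeSpace(&C);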
896 template<typename SV, typename xpu, bool transpose_right, typename DType>
897 struct DotEngine<SV, xpu, 1, 1, 2, false, transpose_right, DType> {
898  inline static void Eval(Tensor<xpu, 1, DType> *p_dst,
899  const Tensor<xpu, 1, DType> &lhs,
900  const Tensor<xpu, 2, DType> &rhs,
901  DType scale) {
902  Tensor<xpu, 1, DType> &dst = *p_dst;
903  // set kernel stream
 904  // if there is no stream, crash
 905  BLASEngine<xpu, DType>::SetStream(dst.stream_);
 906  Shape<2> sright = GetShape(rhs.shape_, transpose_right);
907  CHECK(dst.size(0) == sright[1] && lhs.size(0) == sright[0])
908  << "dot-gemv: matrix shape mismatch"
909  << "dst: " << dst.shape_ << "\n"
910  << "lhs: " << lhs.shape_ << "\n"
911  << "rhs: " << sright << "\n";
 912  BLASEngine<xpu, DType>::gemv
 913  (dst.stream_,
914  transpose_right,
915  rhs.size(1), rhs.size(0), scale * SV::AlphaBLAS(),
916  rhs.dptr_, rhs.stride_,
917  lhs.dptr_, 1, SV::BetaBLAS(),
918  dst.dptr_, 1);
919  }
920 };
921 template<typename SV, typename xpu, typename DType>
922 struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
923  inline static void Eval(Tensor<xpu, 2, DType> *p_dst,
924  const Tensor<xpu, 1, DType> &lhs,
925  const Tensor<xpu, 1, DType> &rhs,
926  DType scale) {
927  Tensor<xpu, 2, DType> &dst = *p_dst;
928  // set kernel stream
 929  // if there is no stream, crash
 930  BLASEngine<xpu, DType>::SetStream(dst.stream_);
 931  CHECK(dst.size(0) == lhs.size(0) && dst.size(1) == rhs.size(0))
932  << "dot-ger: matrix shape mismatch"
933  << "dst: " << dst.shape_ << "\n"
934  << "lhs: " << lhs.shape_ << "\n"
935  << "rhs: " << rhs.shape_;
936  if (SV::BetaBLAS() == 0.0f) {
 937  BLASEngine<xpu, DType>::ger
 938  (dst.stream_, rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(),
939  rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_);
940  } else {
941  DotEngine<SV, xpu, 2, 2, 2, true, false,
942  DType>::Eval(p_dst, lhs.FlatTo2D(), rhs.FlatTo2D(), scale);
943  }
944  }
945 };
946 } // namespace expr
947 } // namespace mshadow
948 #endif // MSHADOW_DOT_ENGINE_INL_H_