mxnet
dot_engine-inl.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_DOT_ENGINE_INL_H_
27 #define MSHADOW_DOT_ENGINE_INL_H_
28 
29 #include <vector>
30 #include "./base.h"
31 #include "./extension/implicit_gemm.h"
32
33 #ifdef __CUDACC__
34 #include "./cuda/tensor_gpu-inl.cuh"
35 #endif // #ifdef __CUDACC__
36 
37 namespace mshadow {
46 template<typename Device, typename DType>
47 inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
48  Stream<Device> *stream);
49 template<typename DType>
50 inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
51  Stream<cpu> *stream) {
52  for (int i = 0; i < num; i++) {
53  dst[i] = src + i * stride;
54  }
55 }
56 #ifdef __CUDACC__
57 namespace cuda {};
58 template<typename DType>
59 inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
60  Stream<gpu> *stream) {
61  cuda::GetBatchedView(dst, src, num, stride, stream);
62 }
63 #endif // #ifdef __CUDACC__
64 
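// Editor's illustrative sketch (hypothetical, not part of the original header):
// typical use of GetBatchedView on the CPU to build the pointer table that the
// batched BLAS wrappers below consume.  The name ExampleBatchedView and the
// layout assumption (batch contiguous rows x cols matrices in 'data') are ours.
inline void ExampleBatchedView(float *data, int batch, int rows, int cols) {
  std::vector<float*> views(batch, nullptr);
  // the cpu overload simply sets views[i] = data + i * rows * cols
  GetBatchedView(views.data(), data, batch, rows * cols,
                 static_cast<Stream<cpu>*>(nullptr));
}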
65 namespace expr {
66 //---------------------------------------------------------------------
67 // Matrix multiplications; these depend on the BLAS engine
68 //---------------------------------------------------------------------
69 template<typename SV, typename Device, int ddim, int ldim,
70  int rdim, bool ltrans, bool rtrans, typename DType>
71 struct DotEngine {
72  inline static void Eval(Tensor<Device, ddim, DType> *p_dst,
73  const Tensor<Device, ldim, DType> &lhs,
74  const Tensor<Device, rdim, DType> &rhs,
75  DType scale);
76 };
78 // handles the dot; uses the CblasColMajor convention
78 template<typename Device, typename DType = default_real_t>
79 struct BLASEngine {
80  inline static bool GetT(bool t) {
81  return t ? true : false;
82  }
83  inline static void SetStream(Stream<Device> *stream) {
84  }
85  inline static void gemm(Stream<Device> *stream,
86  bool transa, bool transb,
87  int m, int n, int k, DType alpha,
88  const DType *A, int lda, const DType *B, int ldb,
89  DType beta, DType *C, int ldc) {
90  LOG(FATAL) << "Not implemented!";
91  }
92  inline static void batched_gemm(Stream<Device> *stream,
93  bool transa, bool transb,
94  int m, int n, int k, DType alpha,
95  const DType *A, int lda, const DType *B, int ldb,
96  DType beta, DType *C, int ldc, int batch_count,
97  DType **workspace) {
98  LOG(FATAL) << "Not implemented!";
99  }
100  inline static void gemv(Stream<Device> *stream,
101  bool trans, int m, int n,
102  DType alpha, const DType *A, int lda,
103  const DType *X, int incX,
104  DType beta, DType *Y, int incY) {
105  LOG(FATAL) << "Not implemented!";
106  }
107  inline static void batched_gemv(Stream<Device> *stream,
108  bool trans, int m, int n,
109  DType alpha, const DType *A, int lda,
110  const DType *X, int incX,
111  DType beta, DType *Y, int incY, int batch_count) {
112  LOG(FATAL) << "Not implemented!";
113  }
114  inline static void ger(Stream<Device> *stream,
115  int m, int n, DType alpha,
116  const DType *X, int incX,
117  const DType *Y, int incY, DType *A, int lda) {
118  LOG(FATAL) << "Not implemented!";
119  }
120  inline static void batched_ger(Stream<Device> *stream,
121  int m, int n, DType alpha,
122  const DType *X, int incX,
123  const DType *Y, int incY, DType *A, int lda, int batch_count) {
124  LOG(FATAL) << "Not implemented!";
125  }
126  inline static void dot(Stream<Device> *stream,
127  int n,
128  const DType* X, int incX,
129  const DType* Y, int incY,
130  DType* ret) {
131  LOG(FATAL) << "Not implemented!";
132  }
133 };
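// Note: the generic BLASEngine above only documents the interface; every routine
// aborts with LOG(FATAL).  The specializations below supply the real
// implementations: <cpu, float/double> via the stand-alone expression-template
// fallback or CBLAS/MKL, and <gpu, half_t/float/double> via cuBLAS.  All of them
// follow the column-major (CblasColMajor) BLAS convention.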
134 
135 #if MSHADOW_STAND_ALONE
136 template<>
137 struct BLASEngine<cpu, float> {
138  inline static bool GetT(bool t) {
139  return t ? true : false;
140  }
141  inline static void SetStream(Stream<cpu> *stream) {
142  }
143  inline static void gemm(Stream<cpu> *stream,
144  bool transa, bool transb,
145  int m, int n, int k, float alpha,
146  const float *A, int lda, const float *B, int ldb,
147  float beta, float *C, int ldc) {
148  if (alpha == 1.0f && beta == 0.0f) {
149  bool transpose_left = transb;
150  bool transpose_right = transa;
151  Tensor<cpu, 2, float> lhs((float*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*)
152  Tensor<cpu, 2, float> rhs((float*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*)
153  Tensor<cpu, 2, float> dst(C, Shape2(m, n));
154  if (!transpose_left && !transpose_right) {
155  dst = expr::implicit_dot(lhs, rhs); return;
156  } else if (!transpose_left && transpose_right) {
157  dst = expr::implicit_dot(lhs, rhs.T()); return;
158  } else if (transpose_left && !transpose_right) {
159  dst = expr::implicit_dot(lhs.T(), rhs); return;
160  } else {
161  LOG(FATAL) << "Not implemented!";
162  }
163  } else {
164  LOG(FATAL) << "Not implemented!";
165  }
166  }
167  inline static void batched_gemm(Stream<cpu> *stream,
168  bool transa, bool transb,
169  int m, int n, int k, float alpha,
170  const float *A, int lda, const float *B, int ldb,
171  float beta, float *C, int ldc, int batch_count,
172  float **workspace) {
173  for (int i = 0; i < batch_count; ++i) {
174  gemm(stream, transa, transb, m, n, k, alpha,
175  A + i * m * k, lda, B + i * k * n, ldb,
176  beta, C + i * m * n, ldc);
177  }
178  }
179  inline static void gemv(Stream<cpu> *stream,
180  bool trans, int m, int n,
181  float alpha, const float *A, int lda,
182  const float *X, int incX,
183  float beta, float *Y, int incY) {
184  LOG(FATAL) << "Not implemented!";
185  }
186  inline static void batched_gemv(Stream<cpu> *stream,
187  bool trans, int m, int n,
188  float alpha, const float *A, int lda,
189  const float *X, int incX,
190  float beta, float *Y, int incY, int batch_count) {
191  LOG(FATAL) << "Not implemented!";
192  }
193  inline static void ger(Stream<cpu> *stream,
194  int m, int n, float alpha,
195  const float *X, int incX,
196  const float *Y, int incY, float *A, int lda) {
197  LOG(FATAL) << "Not implemented!";
198  }
199  inline static void batched_ger(Stream<cpu> *stream,
200  int m, int n, float alpha,
201  const float *X, int incX,
202  const float *Y, int incY, float *A, int lda, int batch_count) {
203  LOG(FATAL) << "Not implemented!";
204  }
205  inline static void dot(Stream<cpu> *stream,
206  int n,
207  const float* X, int incX,
208  const float* Y, int incY,
209  float* ret) {
210  LOG(FATAL) << "Not implemented!";
211  }
212 };
213 
214 template<>
215 struct BLASEngine<cpu, double> {
216  inline static bool GetT(bool t) {
217  return t ? true : false;
218  }
219  inline static void SetStream(Stream<cpu> *stream) {
220  }
221  inline static void gemm(Stream<cpu> *stream,
222  bool transa, bool transb,
223  int m, int n, int k, double alpha,
224  const double *A, int lda, const double *B, int ldb,
225  double beta, double *C, int ldc) {
226  if (alpha == 1.0f && beta == 0.0f) {
227  bool transpose_left = transb;
228  bool transpose_right = transa;
229  Tensor<cpu, 2, double> lhs((double*)B, Shape2(transpose_left ? k : n, transpose_left ? n : k)); // NOLINT(*)
230  Tensor<cpu, 2, double> rhs((double*)A, Shape2(transpose_right ? m : k, transpose_right ? k : m)); // NOLINT(*)
231  Tensor<cpu, 2, double> dst(C, Shape2(m, n));
232  if (!transpose_left && !transpose_right) {
233  dst = expr::implicit_dot(lhs, rhs); return;
234  } else if (!transpose_left && transpose_right) {
235  dst = expr::implicit_dot(lhs, rhs.T()); return;
236  } else if (transpose_left && !transpose_right) {
237  dst = expr::implicit_dot(lhs.T(), rhs); return;
238  } else {
239  LOG(FATAL) << "Not implemented!";
240  }
241  } else {
242  LOG(FATAL) << "Not implemented!";
243  }
244  }
245  inline static void batched_gemm(Stream<cpu> *stream,
246  bool transa, bool transb,
247  int m, int n, int k, double alpha,
248  const double *A, int lda, const double *B, int ldb,
249  double beta, double *C, int ldc, int batch_count,
250  double **workspace) {
251  for (int i = 0; i < batch_count; ++i) {
252  gemm(stream, transa, transb, m, n, k, alpha,
253  A + i * m * k, lda, B + i * k * n, ldb,
254  beta, C + i * m * n, ldc);
255  }
256  }
257  inline static void gemv(Stream<cpu> *stream,
258  bool trans, int m, int n,
259  double alpha, const double *A, int lda,
260  const double *X, int incX,
261  double beta, double *Y, int incY) {
262  LOG(FATAL) << "Not implemented!";
263  }
264  inline static void batched_gemv(Stream<cpu> *stream,
265  bool trans, int m, int n,
266  double alpha, const double *A, int lda,
267  const double *X, int incX,
268  double beta, double *Y, int incY, int batch_count) {
269  LOG(FATAL) << "Not implemented!";
270  }
271  inline static void ger(Stream<cpu> *stream,
272  int m, int n, double alpha,
273  const double *X, int incX,
274  const double *Y, int incY, double *A, int lda) {
275  LOG(FATAL) << "Not implemented!";
276  }
277  inline static void batched_ger(Stream<cpu> *stream,
278  int m, int n, double alpha,
279  const double *X, int incX,
280  const double *Y, int incY, double *A, int lda, int batch_count) {
281  LOG(FATAL) << "Not implemented!";
282  }
283  inline static void dot(Stream<cpu> *stream,
284  int n,
285  const double* X, int incX,
286  const double* Y, int incY,
287  double* ret) {
288  LOG(FATAL) << "Not implemented!";
289  }
290 };
291 
292 #elif (MSHADOW_USE_MKL || MSHADOW_USE_CBLAS) // NOLINT(*)
293 template<>
294 struct BLASEngine<cpu, float> {
295  inline static CBLAS_TRANSPOSE GetT(bool t) {
296  return t ? CblasTrans : CblasNoTrans;
297  }
298  inline static void SetStream(Stream<cpu> *stream) {
299  }
300  inline static void gemm(Stream<cpu> *stream,
301  bool transa, bool transb,
302  int m, int n, int k, float alpha,
303  const float *A, int lda, const float *B, int ldb,
304  float beta, float *C, int ldc) {
305  cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb),
306  m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
307  }
308  inline static void batched_gemm(Stream<cpu> *stream,
309  bool transa, bool transb,
310  int m, int n, int k, float alpha,
311  const float *A, int lda, const float *B, int ldb,
312  float beta, float *C, int ldc, int batch_count,
313  float **workspace) {
314 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
315  // the same m/n/k is used for all the single gemms, so we put them all into one group
316  const int GROUP_SIZE = 1;
317  MKL_INT p_m[GROUP_SIZE] = {m};
318  MKL_INT p_n[GROUP_SIZE] = {n};
319  MKL_INT p_k[GROUP_SIZE] = {k};
320  MKL_INT p_lda[GROUP_SIZE] = {lda};
321  MKL_INT p_ldb[GROUP_SIZE] = {ldb};
322  MKL_INT p_ldc[GROUP_SIZE] = {ldc};
323 
324  float p_alpha[GROUP_SIZE] = {alpha};
325  float p_beta[GROUP_SIZE] = {beta};
326 
327  CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
328  CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
329 
330  MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
331  CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
332  CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
333 
334  std::vector<const float*> pp_A(batch_count, nullptr);
335  std::vector<const float*> pp_B(batch_count, nullptr);
336  std::vector<float*> pp_C(batch_count, nullptr);
337 
338  auto m_k = m * k;
339  auto k_n = k * n;
340  auto m_n = m * n;
341 
342  for (int i = 0; i < batch_count; i++) {
343  pp_A[i] = A + i * m_k;
344  pp_B[i] = B + i * k_n;
345  pp_C[i] = C + i * m_n;
346  }
347 
348  cblas_sgemm_batch(CblasColMajor, p_transa, p_transb,
349  p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
350  p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
351 #else
352  for (int i = 0; i < batch_count; ++i) {
353  gemm(stream, transa, transb, m, n, k, alpha,
354  A + i * m * k, lda, B + i * k * n, ldb,
355  beta, C + i * m * n, ldc);
356  }
357 #endif
358  }
359  inline static void gemv(Stream<cpu> *stream,
360  bool trans, int m, int n,
361  float alpha, const float *A, int lda,
362  const float *X, int incX,
363  float beta, float *Y, int incY) {
364  cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha,
365  A, lda, X, incX, beta, Y, incY);
366  }
367  inline static void batched_gemv(Stream<cpu> *stream,
368  bool trans, int m, int n,
369  float alpha, const float *A, int lda,
370  const float *X, int incX,
371  float beta, float *Y, int incY, int batch_count) {
372  for (int i = 0; i < batch_count; ++i) {
373  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
374  X + i * (trans ? m : n) * incX, incX,
375  beta, Y + i * (trans ? n : m) * incY, incY);
376  }
377  }
378  inline static void ger(Stream<cpu> *stream,
379  int m, int n, float alpha,
380  const float *X, int incX,
381  const float *Y, int incY, float *A, int lda) {
382  cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
383  }
384  inline static void batched_ger(Stream<cpu> *stream,
385  int m, int n, float alpha,
386  const float *X, int incX,
387  const float *Y, int incY, float *A, int lda, int batch_count) {
388  for (int i = 0; i < batch_count; ++i) {
389  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
390  A + i * lda * n, lda);
391  }
392  }
393  inline static void dot(Stream<cpu> *stream,
394  int n,
395  const float* X, int incX,
396  const float* Y, int incY,
397  float* ret) {
398  *ret = cblas_sdot(n, X, incX, Y, incY);
399  }
400 };
401 
402 template<>
403 struct BLASEngine<cpu, double> {
404  inline static CBLAS_TRANSPOSE GetT(bool t) {
405  return t ? CblasTrans : CblasNoTrans;
406  }
407  inline static void SetStream(Stream<cpu> *stream) {
408  }
409  inline static void gemm(Stream<cpu> *stream,
410  bool transa, bool transb,
411  int m, int n, int k, double alpha,
412  const double *A, int lda, const double *B, int ldb,
413  double beta, double *C, int ldc) {
414  cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb),
415  m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
416  }
417  inline static void batched_gemm(Stream<cpu> *stream,
418  bool transa, bool transb,
419  int m, int n, int k, double alpha,
420  const double *A, int lda, const double *B, int ldb,
421  double beta, double *C, int ldc, int batch_count,
422  double **workspace) {
423 #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
424  // the same m/n/k is used for all the single gemms, so we put them all into one group
425  const int GROUP_SIZE = 1;
426  MKL_INT p_m[GROUP_SIZE] = {m};
427  MKL_INT p_n[GROUP_SIZE] = {n};
428  MKL_INT p_k[GROUP_SIZE] = {k};
429  MKL_INT p_lda[GROUP_SIZE] = {lda};
430  MKL_INT p_ldb[GROUP_SIZE] = {ldb};
431  MKL_INT p_ldc[GROUP_SIZE] = {ldc};
432 
433  double p_alpha[GROUP_SIZE] = {alpha};
434  double p_beta[GROUP_SIZE] = {beta};
435 
436  CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
437  CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
438 
439  MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count};
440  CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans};
441  CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans};
442 
443  std::vector<const double*> pp_A(batch_count, nullptr);
444  std::vector<const double*> pp_B(batch_count, nullptr);
445  std::vector<double*> pp_C(batch_count, nullptr);
446 
447  auto m_k = m * k;
448  auto k_n = k * n;
449  auto m_n = m * n;
450 
451  for (int i = 0; i < batch_count; i++) {
452  pp_A[i] = A + i * m_k;
453  pp_B[i] = B + i * k_n;
454  pp_C[i] = C + i * m_n;
455  }
456 
457  cblas_dgemm_batch(CblasColMajor, p_transa, p_transb,
458  p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(),
459  p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb);
460 #else
461  for (int i = 0; i < batch_count; ++i) {
462  gemm(stream, transa, transb, m, n, k, alpha,
463  A + i * m * k, lda, B + i * k * n, ldb,
464  beta, C + i * m * n, ldc);
465  }
466 #endif
467  }
468  inline static void gemv(Stream<cpu> *stream,
469  bool trans, int m, int n, double alpha,
470  const double *A, int lda,
471  const double *X, int incX,
472  double beta, double *Y, int incY) {
473  cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha,
474  A, lda, X, incX, beta, Y, incY);
475  }
476  inline static void batched_gemv(Stream<cpu> *stream,
477  bool trans, int m, int n,
478  double alpha, const double *A, int lda,
479  const double *X, int incX,
480  double beta, double *Y, int incY, int batch_count) {
481  for (int i = 0; i < batch_count; ++i) {
482  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
483  X + i * (trans ? m : n) * incX, incX,
484  beta, Y + i * (trans ? n : m) * incY, incY);
485  }
486  }
487  inline static void ger(Stream<cpu> *stream,
488  int m, int n, double alpha,
489  const double *X, int incX,
490  const double *Y, int incY, double *A, int lda) {
491  cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda);
492  }
493  inline static void batched_ger(Stream<cpu> *stream,
494  int m, int n, double alpha,
495  const double *X, int incX,
496  const double *Y, int incY, double *A, int lda, int batch_count) {
497  for (int i = 0; i < batch_count; ++i) {
498  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
499  A + i * lda * n, lda);
500  }
501  }
502  inline static void dot(Stream<cpu> *stream,
503  int n,
504  const double* X, int incX,
505  const double* Y, int incY,
506  double* ret) {
507  *ret = cblas_ddot(n, X, incX, Y, incY);
508  }
509 };
510 #endif // MSHADOW_USE_CBLAS || MSHADOW_USE_MKL || MSHADOW_STAND_ALONE
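// Editor's illustrative sketch (hypothetical, not part of the original header):
// calling the CPU engine directly.  Arguments follow the column-major BLAS
// convention: op(A) is m x k, op(B) is k x n, C is m x n, and lda/ldb/ldc are the
// column strides.  The name ExampleCpuGemm and the literal data are ours.
inline void ExampleCpuGemm(Stream<cpu> *stream) {
  const int m = 2, n = 2, k = 3;
  float A[m * k] = {1, 4, 2, 5, 3, 6};   // 2x3 matrix, stored column by column
  float B[k * n] = {1, 0, 0, 0, 1, 1};   // 3x2 matrix, stored column by column
  float C[m * n] = {0, 0, 0, 0};
  BLASEngine<cpu, float>::SetStream(stream);
  BLASEngine<cpu, float>::gemm(stream, false, false,   // no transposes
                               m, n, k, 1.0f, A, m, B, k, 0.0f, C, m);
  // C now holds A * B (2x2) in column-major order.
}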
511 // CuBLAS redirect code
512 #if MSHADOW_USE_CUDA
513 // All cuBLAS calls are routed here; uses the legacy API, which is not thread-safe
514 template<>
515 struct BLASEngine<gpu, half::half_t> {
516  inline static cublasOperation_t GetT(bool t) {
517  return t ? CUBLAS_OP_T : CUBLAS_OP_N;
518  }
519  inline static void SetStream(Stream<gpu> *stream) {
520  cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
521  Stream<gpu>::GetStream(stream));
522  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
523  }
524  inline static void gemm(Stream<gpu> *stream,
525  bool transa, bool transb,
526  int m, int n, int k, half::half_t alpha,
527  const half::half_t *A, int lda,
528  const half::half_t *B, int ldb, half::half_t beta,
529  half::half_t *C, int ldc) {
530 #if defined(CUDA_VERSION) && CUDA_VERSION >= 7050
531  // Always use pseudo-fp16: fp32 compute with fp16 I/O.
532  float alpha_f = float(alpha); // NOLINT(*)
533  float beta_f = float(beta); // NOLINT(*)
534  #if CUDA_VERSION >= 8000
535  cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
536  GetT(transa), GetT(transb), m, n, k, &alpha_f,
537  A, CUDA_R_16F, lda, B, CUDA_R_16F,
538  ldb, &beta_f, C, CUDA_R_16F, ldc);
539  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
540  #else
541  cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
542  GetT(transa), GetT(transb), m, n, k, &alpha_f,
543  A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF,
544  ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
545  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
546  #endif // CUDA_VERSION >= 8000
547 #else
548  LOG(FATAL) << "Require CUDA version >= 7.5!";
549 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 7050
550  }
551  inline static void batched_gemm(Stream<gpu> *stream,
552  bool transa, bool transb,
553  int m, int n, int k, half::half_t alpha,
554  const half::half_t *A, int lda, const half::half_t *B, int ldb,
555  half::half_t beta, half::half_t *C, int ldc, int batch_count,
556  half::half_t **workspace) {
557 #if defined(__CUDACC__) && CUDA_VERSION >= 9000
558  int major = stream->prop.major;
559  int minor = stream->prop.minor;
560  // fp16 math is not supported before compute capability 5.3 (sm_53)
561  if ((major > 5) || (major == 5 && minor >= 3)) {
562  const __half* A_h = reinterpret_cast<const __half*>(A);
563  const __half* B_h = reinterpret_cast<const __half*>(B);
564  __half* alpha_h = reinterpret_cast<__half*>(&alpha);
565  __half* beta_h = reinterpret_cast<__half*>(&beta);
566  __half* C_h = reinterpret_cast<__half*>(C);
567  cublasStatus_t err = cublasHgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
568  GetT(transa), GetT(transb), m, n, k, alpha_h,
569  A_h, lda, m * k,
570  B_h, ldb, k * n,
571  beta_h, C_h, ldc, m * n,
572  batch_count);
573  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: HgemmStridedBatched fail";
574  return;
575  }
576 #endif
577  for (int i = 0; i < batch_count; ++i) {
578  gemm(stream, transa, transb, m, n, k, alpha,
579  A + i * m * k, lda, B + i * k * n, ldb,
580  beta, C + i * m * n, ldc);
581  }
582  }
583  inline static void gemv(Stream<gpu> *stream,
584  bool trans, int m, int n, half::half_t alpha,
585  const half::half_t *A, int lda,
586  const half::half_t *X, int incX, half::half_t beta,
587  half::half_t *Y, int incY) {
588  LOG(FATAL) << "Not implemented!";
589  }
590  inline static void batched_gemv(Stream<gpu> *stream,
591  bool trans, int m, int n,
592  half::half_t alpha, const half::half_t *A, int lda,
593  const half::half_t *X, int incX,
594  half::half_t beta, half::half_t *Y, int incY, int batch_count) {
595  LOG(FATAL) << "Not implemented!";
596  }
597  inline static void ger(Stream<gpu> *stream,
598  int m, int n, half::half_t alpha,
599  const half::half_t *X, int incX,
600  const half::half_t *Y, int incY, half::half_t *A, int lda) {
601  LOG(FATAL) << "Not implemented!";
602  }
603  inline static void batched_ger(Stream<gpu> *stream,
604  int m, int n, half::half_t alpha,
605  const half::half_t *X, int incX, const half::half_t *Y, int incY,
606  half::half_t *A, int lda, int batch_count) {
607  LOG(FATAL) << "Not implemented!";
608  }
609  inline static void dot(Stream<gpu> *stream,
610  int n,
611  const half::half_t* X, int incX,
612  const half::half_t* Y, int incY,
613  half::half_t *ret) {
614  LOG(FATAL) << "Not implemented!";
615  }
616 };
617 
618 template<>
619 struct BLASEngine<gpu, float> {
620  inline static cublasOperation_t GetT(bool t) {
621  return t ? CUBLAS_OP_T : CUBLAS_OP_N;
622  }
623  inline static void SetStream(Stream<gpu> *stream) {
624  cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
625  Stream<gpu>::GetStream(stream));
626  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
627  }
628  inline static void gemm(Stream<gpu> *stream,
629  bool transa, bool transb,
630  int m, int n, int k, float alpha,
631  const float *A, int lda,
632  const float *B, int ldb, float beta,
633  float *C, int ldc) {
634  cublasStatus_t err = cublasSgemm(Stream<gpu>::GetBlasHandle(stream),
635  GetT(transa), GetT(transb), m, n, k, &alpha,
636  A, lda, B, ldb, &beta, C, ldc);
637  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemm fail";
638  }
639  inline static void batched_gemm(Stream<gpu> *stream,
640  bool transa, bool transb,
641  int m, int n, int k, float alpha,
642  const float *A, int lda, const float *B, int ldb,
643  float beta, float *C, int ldc, int batch_count,
644  float **workspace) {
645 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000
646  // Cast DType* to DType** using workspace as a buffer
647  bool alloc_workspace = false;
648  if (workspace == NULL) {
649  // Allocate the workspace if it's NULL.
650  // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe.
651  cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(float*));
652  alloc_workspace = true;
653  }
654  GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
655  GetBatchedView(workspace + batch_count,
656  const_cast<float*>(B), batch_count, k * n, stream);
657  GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
658  cublasStatus_t err = cublasSgemmBatched(Stream<gpu>::GetBlasHandle(stream),
659  GetT(transa), GetT(transb), m, n, k, &alpha,
660  (const float**)workspace, lda,
661  (const float**)(workspace + batch_count), ldb,
662  &beta, workspace + 2 * batch_count, ldc, batch_count);
663  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail";
664  if (alloc_workspace) {
665  cudaFree(workspace);
666  }
667 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000
668  cublasStatus_t err = cublasSgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
669  GetT(transa), GetT(transb), m, n, k, &alpha,
670  A, lda, m * k,
671  B, ldb, k * n,
672  &beta, C, ldc, m * n,
673  batch_count);
674  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmStridedBatched fail";
675 #else
676  for (int i = 0; i < batch_count; ++i) {
677  gemm(stream, transa, transb, m, n, k, alpha,
678  A + i * m * k, lda, B + i * k * n, ldb,
679  beta, C + i * m * n, ldc);
680  }
681 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
682  }
683  inline static void gemv(Stream<gpu> *stream,
684  bool trans, int m, int n, float alpha,
685  const float *A, int lda,
686  const float *X, int incX, float beta,
687  float *Y, int incY) {
688  cublasStatus_t err = cublasSgemv(Stream<gpu>::GetBlasHandle(stream),
689  GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
690  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail";
691  }
692  inline static void batched_gemv(Stream<gpu> *stream,
693  bool trans, int m, int n,
694  float alpha, const float *A, int lda,
695  const float *X, int incX,
696  float beta, float *Y, int incY, int batch_count) {
697  for (int i = 0; i < batch_count; ++i) {
698  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
699  X + i * (trans ? m : n) * incX, incX,
700  beta, Y + i * (trans ? n : m) * incY, incY);
701  }
702  }
703  inline static void ger(Stream<gpu> *stream,
704  int m, int n, float alpha,
705  const float *X, int incX,
706  const float *Y, int incY, float *A, int lda) {
707  cublasStatus_t err = cublasSger(Stream<gpu>::GetBlasHandle(stream),
708  m, n, &alpha, X, incX, Y, incY, A, lda);
709  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail";
710  }
711  inline static void batched_ger(Stream<gpu> *stream,
712  int m, int n, float alpha,
713  const float *X, int incX,
714  const float *Y, int incY, float *A, int lda, int batch_count) {
715  for (int i = 0; i < batch_count; ++i) {
716  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
717  A + i * lda * n, lda);
718  }
719  }
720  inline static void dot(Stream<gpu> *stream,
721  int n,
722  const float* X, int incX,
723  const float* Y, int incY,
724  float *ret) {
725  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
726  CUBLAS_POINTER_MODE_DEVICE);
727  cublasStatus_t err = cublasSdot(Stream<gpu>::GetBlasHandle(stream),
728  n, X, incX, Y, incY, ret);
729  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
730  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
731  CUBLAS_POINTER_MODE_HOST);
732  }
733 };
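// Editor's illustrative sketch (hypothetical, not part of the original header):
// a batched GEMM on the GPU.  A, B and C are assumed to be device pointers holding
// batch_count contiguous m x k, k x n and m x n column-major matrices.  Passing
// workspace = nullptr is allowed: on CUDA < 8.0 the engine allocates the
// 3 * batch_count pointer table itself; on CUDA >= 8.0 the strided-batched kernel
// does not need it.  The name ExampleGpuBatchedGemm is ours.
inline void ExampleGpuBatchedGemm(Stream<gpu> *stream,
                                  const float *A, const float *B, float *C,
                                  int m, int n, int k, int batch_count) {
  BLASEngine<gpu, float>::SetStream(stream);
  BLASEngine<gpu, float>::batched_gemm(stream, false, false, m, n, k, 1.0f,
                                       A, m, B, k, 0.0f, C, m,
                                       batch_count, nullptr);
}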
734 
735 template<>
736 struct BLASEngine<gpu, double> {
737  inline static cublasOperation_t GetT(bool t) {
738  return t ? CUBLAS_OP_T : CUBLAS_OP_N;
739  }
740  inline static void SetStream(Stream<gpu> *stream) {
741  cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
742  Stream<gpu>::GetStream(stream));
743  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: set stream fail";
744  }
745  inline static void gemm(Stream<gpu> *stream,
746  bool transa, bool transb,
747  int m, int n, int k, double alpha,
748  const double *A, int lda,
749  const double *B, int ldb,
750  double beta, double *C, int ldc) {
751  cublasStatus_t err = cublasDgemm(Stream<gpu>::GetBlasHandle(stream),
752  GetT(transa), GetT(transb), m, n, k, &alpha,
753  A, lda, B, ldb, &beta, C, ldc);
754  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemm fail";
755  }
756  inline static void batched_gemm(Stream<gpu> *stream,
757  bool transa, bool transb,
758  int m, int n, int k, double alpha,
759  const double *A, int lda, const double *B, int ldb,
760  double beta, double *C, int ldc, int batch_count,
761  double **workspace) {
762 #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000
763  // Cast DType* to DType** using workspace as a buffer
764  bool alloc_workspace = false;
765  if (workspace == NULL) {
766  // Allocate the workspace if it's NULL.
767  // TODO(sxjscience) Try to move the allocation inside Tensor, which is thread-safe.
768  cudaMalloc(reinterpret_cast<void**>(&workspace), 3 * batch_count * sizeof(double*));
769  alloc_workspace = true;
770  }
771  GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
772  GetBatchedView(workspace + batch_count,
773  const_cast<double*>(B), batch_count, k * n, stream);
774  GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
775  cublasStatus_t err = cublasDgemmBatched(Stream<gpu>::GetBlasHandle(stream),
776  GetT(transa), GetT(transb), m, n, k, &alpha,
777  (const double**)workspace, lda,
778  (const double**)(workspace + batch_count), ldb,
779  &beta, workspace + 2 * batch_count, ldc, batch_count);
780  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail";
781  if (alloc_workspace) {
782  cudaFree(workspace);
783  }
784 #elif defined(__CUDACC__) && CUDA_VERSION >= 8000
785  cublasStatus_t err = cublasDgemmStridedBatched(Stream<gpu>::GetBlasHandle(stream),
786  GetT(transa), GetT(transb), m, n, k, &alpha,
787  A, lda, m * k,
788  B, ldb, k * n,
789  &beta, C, ldc, m * n,
790  batch_count);
791  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmStridedBatched fail";
792 #else
793  for (int i = 0; i < batch_count; ++i) {
794  gemm(stream, transa, transb, m, n, k, alpha,
795  A + i * m * k, lda, B + i * k * n, ldb,
796  beta, C + i * m * n, ldc);
797  }
798 #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
799  }
800  inline static void gemv(Stream<gpu> *stream,
801  bool trans, int m, int n, double alpha,
802  const double *A, int lda,
803  const double *X, int incX,
804  double beta, double *Y, int incY) {
805  cublasStatus_t err = cublasDgemv(Stream<gpu>::GetBlasHandle(stream),
806  GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY);
807  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail";
808  }
809  inline static void batched_gemv(Stream<gpu> *stream,
810  bool trans, int m, int n,
811  double alpha, const double *A, int lda,
812  const double *X, int incX,
813  double beta, double *Y, int incY, int batch_count) {
814  for (int i = 0; i < batch_count; ++i) {
815  gemv(stream, trans, m, n, alpha, A + i * m * n, lda,
816  X + i * (trans ? m : n) * incX, incX,
817  beta, Y + i * (trans ? n : m) * incY, incY);
818  }
819  }
820  inline static void ger(Stream<gpu> *stream,
821  int m, int n, double alpha,
822  const double *X, int incX,
823  const double *Y, int incY, double *A, int lda) {
824  cublasStatus_t err = cublasDger(Stream<gpu>::GetBlasHandle(stream),
825  m, n, &alpha, X, incX, Y, incY, A, lda);
826  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail";
827  }
828  inline static void batched_ger(Stream<gpu> *stream,
829  int m, int n, double alpha,
830  const double *X, int incX,
831  const double *Y, int incY, double *A, int lda, int batch_count) {
832  for (int i = 0; i < batch_count; ++i) {
833  ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY,
834  A + i * lda * n, lda);
835  }
836  }
837  inline static void dot(Stream<gpu> *stream,
838  int n,
839  const double* X, int incX,
840  const double* Y, int incY,
841  double *ret) {
842  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
843  CUBLAS_POINTER_MODE_DEVICE);
844  cublasStatus_t err = cublasDdot(Stream<gpu>::GetBlasHandle(stream),
845  n, X, incX, Y, incY, ret);
846  CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dot fail";
847  cublasSetPointerMode(Stream<gpu>::GetBlasHandle(stream),
848  CUBLAS_POINTER_MODE_HOST);
849  }
850 };
851 #endif // MSHADOW_USE_CUDA
852 // helper function that returns the effective shape after an optional transpose
853 inline Shape<2> GetShape(const Shape<2> &shape, bool transpose) {
854  return transpose ? Shape2(shape[1], shape[0]) : shape;
855 }
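// For example, GetShape(Shape2(3, 4), true) yields Shape2(4, 3), while
// GetShape(Shape2(3, 4), false) returns the shape unchanged.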
856 // dst = dot(lhs[.T], rhs[.T])
857 template<typename SV, typename xpu,
858  bool transpose_left, bool transpose_right, typename DType>
859 struct DotEngine<SV, xpu, 2, 2, 2, transpose_left, transpose_right, DType> {
860  inline static void Eval(Tensor<xpu, 2, DType> *p_dst,
861  const Tensor<xpu, 2, DType> &lhs,
862  const Tensor<xpu, 2, DType> &rhs,
863  DType scale) {
864  Tensor<xpu, 2, DType> &dst = *p_dst;
865 #if MSHADOW_STAND_ALONE
866  if (xpu::kDevMask == cpu::kDevMask && scale == 1.0f) {
867  if (!transpose_left && !transpose_right) {
868  dst = expr::implicit_dot(lhs, rhs); return;
869  } else if (!transpose_left && transpose_right) {
870  dst = expr::implicit_dot(lhs, rhs.T()); return;
871  } else if (transpose_left && !transpose_right) {
872  dst = expr::implicit_dot(lhs.T(), rhs); return;
873  }
874  }
875 #endif
876  // set kernel stream
877  // if there is no stream, crash
878  BLASEngine<xpu, DType>::SetStream(dst.stream_);
879  Shape<2> sleft = GetShape(lhs.shape_, transpose_left);
880  Shape<2> sright = GetShape(rhs.shape_, transpose_right);
881  CHECK(dst.size(0) == sleft[0] && dst.size(1) == sright[1] && sleft[1] == sright[0])
882  << "dot-gemm: matrix shape mismatch";
883  // use column-major arguments to be compatible with most BLAS
884  BLASEngine<xpu, DType>::gemm
885  (dst.stream_,
886  transpose_right , transpose_left,
887  transpose_right ? rhs.size(0) : rhs.size(1),
888  transpose_left ? lhs.size(1) : lhs.size(0),
889  transpose_right ? rhs.size(1) : rhs.size(0),
890  DType(scale * SV::AlphaBLAS()),
891  rhs.dptr_, rhs.stride_,
892  lhs.dptr_, lhs.stride_,
893  DType(SV::BetaBLAS()),
894  dst.dptr_, dst.stride_);
895  }
896 };
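// Editor's note: mshadow tensors are row-major while the BLAS engines above use
// the column-major convention.  A row-major matrix reinterpreted as column-major
// data is its transpose, so dst = lhs * rhs is evaluated as dst^T = rhs^T * lhs^T;
// this is why Eval() passes rhs before lhs and swaps the transpose flags in the
// gemm call.  A 2-D assignment such as dst = dot(lhs, rhs) with lhs of shape
// (m, k), rhs of shape (k, n) and dst of shape (m, n) dispatches here.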
897 template<typename SV, typename xpu, bool transpose_right, typename DType>
898 struct DotEngine<SV, xpu, 1, 1, 2, false, transpose_right, DType> {
899  inline static void Eval(Tensor<xpu, 1, DType> *p_dst,
900  const Tensor<xpu, 1, DType> &lhs,
901  const Tensor<xpu, 2, DType> &rhs,
902  DType scale) {
903  Tensor<xpu, 1, DType> &dst = *p_dst;
904  // set kernel stream
905  // if there is no stream, crash
906  BLASEngine<xpu, DType>::SetStream(dst.stream_);
907  Shape<2> sright = GetShape(rhs.shape_, transpose_right);
908  CHECK(dst.size(0) == sright[1] && lhs.size(0) == sright[0])
909  << "dot-gemv: matrix shape mismatch"
910  << "dst: " << dst.shape_ << "\n"
911  << "lhs: " << lhs.shape_ << "\n"
912  << "rhs: " << sright << "\n";
913  BLASEngine<xpu, DType>::gemv
914  (dst.stream_,
915  transpose_right,
916  rhs.size(1), rhs.size(0), scale * SV::AlphaBLAS(),
917  rhs.dptr_, rhs.stride_,
918  lhs.dptr_, 1, SV::BetaBLAS(),
919  dst.dptr_, 1);
920  }
921 };
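// Editor's note: this specialization handles dst = dot(lhs, rhs) for a length-k
// vector lhs and a (k, n) matrix rhs.  It is mapped onto gemv over the
// column-major view of rhs (i.e. rhs^T), which is why the dimensions are passed
// as rhs.size(1), rhs.size(0) above.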
922 template<typename SV, typename xpu, typename DType>
923 struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
924  inline static void Eval(Tensor<xpu, 2, DType> *p_dst,
925  const Tensor<xpu, 1, DType> &lhs,
926  const Tensor<xpu, 1, DType> &rhs,
927  DType scale) {
928  Tensor<xpu, 2, DType> &dst = *p_dst;
929  // set kernel stream
930  // if there is no stream, crash
931  BLASEngine<xpu, DType>::SetStream(dst.stream_);
932  CHECK(dst.size(0) == lhs.size(0) && dst.size(1) == rhs.size(0))
933  << "dot-ger: matrix shape mismatch"
934  << "dst: " << dst.shape_ << "\n"
935  << "lhs: " << lhs.shape_ << "\n"
936  << "rhs: " << rhs.shape_;
937  if (SV::BetaBLAS() == 0.0f) {
938  BLASEngine<xpu, DType>::ger
939  (dst.stream_, rhs.size(0), lhs.size(0), scale * SV::AlphaBLAS(),
940  rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_);
941  } else {
942  DotEngine<SV, xpu, 2, 2, 2, true, false,
943  DType>::Eval(p_dst, lhs.FlatTo2D(), rhs.FlatTo2D(), scale);
944  }
945  }
946 };
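// Editor's note: this specialization evaluates the scaled outer product of two
// vectors, dst(i, j) = scale * lhs[i] * rhs[j].  When the save rule's BetaBLAS()
// is zero it uses the rank-1 update ger directly; otherwise it falls back to the
// general 2-D gemm path by flattening both vectors with FlatTo2D().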
947 } // namespace expr
948 } // namespace mshadow
949 #endif // MSHADOW_DOT_ENGINE_INL_H_