mxnet
tensor_cpu-inl.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tensor_cpu-inl.h
22  * \brief implementation of CPU host code
23  * \author Bing Xu, Tianqi Chen
24  */
25 #ifndef MSHADOW_TENSOR_CPU_INL_H_
26 #define MSHADOW_TENSOR_CPU_INL_H_
27 #include <cstring>
28 #include <functional>
29 #include <utility>
30 #include <vector>
31 #include "./base.h"
32 #include "./tensor.h"
33 #include "./packet-inl.h"
34 #include "./dot_engine-inl.h"
35 
36 namespace mshadow {
37 template<>
38 inline void InitTensorEngine<cpu>(int dev_id) {
39 }
40 template<>
41 inline void ShutdownTensorEngine<cpu>(void) {
42 }
43 
44 template<>
45 inline void SetDevice<cpu>(int devid) {
46 }
47 template<>
48 inline Stream<cpu> *NewStream<cpu>(bool create_blas_handle,
49  bool create_dnn_handle,
50  int dev_id) {
51  return new Stream<cpu>();
52 }
53 template<>
54 inline void DeleteStream<cpu>(Stream<cpu> *stream) {
55  delete stream;
56 }
57 
58 template<int ndim>
59 inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape) { // NOLINT(*)
60  os << '(';
61  for (int i = 0; i < ndim; ++i) {
62  if (i != 0) os << ',';
63  os << shape[i];
64  }
65  // python style tuple
66  if (ndim == 1) os << ',';
67  os << ')';
68  return os;
69 }
70 
71 template<typename xpu>
72 inline void *AllocHost_(size_t size);
73 template<typename xpu>
74 inline void FreeHost_(void * dptr);
75 
76 #ifdef __CUDACC__
77 template<>
78 inline void *AllocHost_<gpu>(size_t size) {
79  void *dptr;
80  MSHADOW_CUDA_CALL(cudaMallocHost(&dptr, size, cudaHostAllocPortable));
81  return dptr;
82 }
83 template<>
84 inline void FreeHost_<gpu>(void *dptr) {
85  MSHADOW_CUDA_CALL(cudaFreeHost(dptr));
86 }
87 #endif
88 
89 template<>
90 inline void *AllocHost_<cpu>(size_t size) {
91  size_t pitch;
92  return packet::AlignedMallocPitch(&pitch, size, 1);
93 }
94 template<>
95 inline void FreeHost_<cpu>(void *dptr) {
96  packet::AlignedFree(dptr);
97 }
98 
99 template<typename xpu, int dim, typename DType>
100 inline void AllocHost(Tensor<cpu, dim, DType> *obj) {
101  obj->stride_ = obj->size(dim - 1);
102  CHECK_EQ(obj->CheckContiguous(), true) << "AllocHost";
103  void *dptr = AllocHost_<xpu>(obj->MSize() * sizeof(DType));
104  obj->dptr_ = reinterpret_cast<DType*>(dptr);
105 }
106 template<typename xpu, int dim, typename DType>
107 inline void FreeHost(Tensor<cpu, dim, DType> *obj) {
108  if (obj->dptr_ == NULL) {
109  LOG(FATAL) << "FreeHost:: double free";
110  }
111  FreeHost_<xpu>(obj->dptr_);
112  obj->dptr_ = NULL;
113 }
114 
115 template<int dim, typename DType>
116 inline void AllocSpace(Tensor<cpu, dim, DType> *obj, bool pad) {
117  size_t pitch;
118  void *dptr;
119  if (pad) {
120  dptr = packet::AlignedMallocPitch
121  (&pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0]);
122  obj->stride_ = static_cast<index_t>(pitch / sizeof(DType));
123  } else {
124  obj->stride_ = obj->size(dim - 1);
125  dptr = packet::AlignedMallocPitch
126  (&pitch, obj->shape_.Size() * sizeof(DType), 1);
127  }
128  obj->dptr_ = reinterpret_cast<DType*>(dptr);
129 }
130 template<typename Device, typename DType, int dim>
131 inline Tensor<Device, dim, DType>
132 NewTensor(const Shape<dim> &shape, DType initv, bool pad, Stream<Device> *stream_) {
133  Tensor<Device, dim, DType> obj(shape);
134  obj.stream_ = stream_;
135  AllocSpace(&obj, pad);
136  MapExp<sv::saveto>(&obj, expr::ScalarExp<DType>(initv));
137  return obj;
138 }
139 template<int dim, typename DType>
140 inline void FreeSpace(Tensor<cpu, dim, DType> *obj) {
141  packet::AlignedFree(obj->dptr_);
142  obj->dptr_ = NULL;
143 }
144 template<int dim, typename DType>
145 inline void Copy(Tensor<cpu, dim, DType> _dst,
146  const Tensor<cpu, dim, DType> &_src,
147  Stream<cpu> *stream) {
148 #pragma GCC diagnostic push
149 #if __GNUC__ >= 8
150 #pragma GCC diagnostic ignored "-Wclass-memaccess"
151 #endif
152  CHECK_EQ(_dst.shape_, _src.shape_)
153  << "Copy:shape mismatch:" << _dst.shape_ << " vs " << _src.shape_;
154  if (_dst.CheckContiguous() && _src.CheckContiguous()) {
155  memcpy(_dst.dptr_, _src.dptr_, sizeof(DType) * _dst.shape_.Size());
156  } else {
157  Tensor<cpu, 2, DType> dst = _dst.FlatTo2D();
158  Tensor<cpu, 2, DType> src = _src.FlatTo2D();
159  for (index_t y = 0; y < dst.size(0); ++y) {
160  memcpy(dst[y].dptr_, src[y].dptr_, sizeof(DType) * dst.size(1));
161  }
162  }
163 #pragma GCC diagnostic pop
164 }
165 
166 template<typename Saver, typename R, int dim,
167  typename DType, typename E>
168 inline void MapPlan(TRValue<R, cpu, dim, DType> *dst,
169  const expr::Plan<E, DType> &plan) {
170  Shape<2> shape = expr::ShapeCheck<dim, R>::Check(dst->self()).FlatTo2D();
171  expr::Plan<R, DType> dplan = expr::MakePlan(dst->self());
172 #ifndef __CUDACC__
173  #pragma omp parallel for
174 #endif
175  // temp remove openmp, as default setting throttles CPU
176  for (openmp_index_t y = 0; y < shape[0]; ++y) {
177  for (index_t x = 0; x < shape[1]; ++x) {
178  // trust your compiler! -_- they will optimize it
179  Saver::template Save<DType>(dplan.REval(y, x), plan.Eval(y, x));
180  }
181  }
182 }
183 // code to handle SSE optimization
184 template<bool pass_check, typename Saver,
185  typename R, int dim,
186  typename DType, typename E, int etype>
187 struct MapExpCPUEngine {
188  inline static void Map(TRValue<R, cpu, dim, DType> *dst,
189  const expr::Exp<E, DType, etype> &exp) {
190  MapPlan<Saver>(dst, MakePlan(exp.self()));
191  }
192 };
193 
194 template<typename SV, int dim, typename DType, typename E, int etype>
195 struct MapExpCPUEngine<true, SV, Tensor<cpu, dim, DType>,
196  dim, DType, E, etype> {
197  inline static void Map(Tensor<cpu, dim, DType> *dst,
198  const expr::Exp<E, DType, etype> &exp) {
199  if (expr::PacketAlignCheck<dim, E, MSHADOW_DEFAULT_PACKET>::Check(exp.self()) &&
200  expr::PacketAlignCheck<dim, Tensor<cpu, dim, DType>, MSHADOW_DEFAULT_PACKET>::Check(*dst)) {
201  expr::MapPacketPlan<SV>(dst->self(),
202  expr::MakePacketPlan<MSHADOW_DEFAULT_PACKET>(exp.self()));
203  } else {
204  MapPlan<SV>(dst, MakePlan(exp.self()));
205  }
206  }
207 };
208 
209 
210 template<typename Saver, typename R, int dim,
211  typename DType, typename E, int etype>
212 inline void MapExp(TRValue<R, cpu, dim, DType> *dst,
213  const expr::Exp<E, DType, etype> &exp) {
214  expr::TypeCheckPass<expr::TypeCheck<cpu, dim, DType, E>::kMapPass>
215  ::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
216  Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self());
217  Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self());
218  CHECK(eshape[0] == 0 || eshape == dshape)
219  << "Assignment: Shape of Tensors are not consistent with target, "
220  << "eshape: " << eshape << " dshape:" << dshape;
221  MapExpCPUEngine<expr::PacketCheck<E, MSHADOW_DEFAULT_PACKET>::kPass,
222  Saver, R, dim, DType, E, etype>
223  ::Map(dst->ptrself(), exp);
224 }
225 
226 template<typename Saver, typename Reducer,
227  typename R, typename DType, typename E, int etype>
228 inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst,
229  const expr::Exp<E, DType, etype> &exp,
230  DType scale) {
231  expr::TypeCheckPass<expr::TypeCheck<cpu, 1, DType, E>::kRedPass>
232  ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
233  Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
234  ::Check(exp.self()).FlatTo2D();
235  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
236  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match";
237  CHECK_NE(eshape[0], 0U) << "can not reduce over empty tensor";
238  // execution
239  expr::Plan<R, DType> dplan = MakePlan(dst->self());
240  expr::Plan<E, DType> splan = MakePlan(exp.self());
241 #ifndef __CUDACC__
242  #pragma omp parallel for
243 #endif
244  for (openmp_index_t x = 0; x < eshape[1]; ++x) {
245  DType res = splan.Eval(0, x);
246  for (index_t y = 1; y < eshape[0]; ++y) {
247  Reducer::Reduce(res, splan.Eval(y, x));
248  }
249  Saver::template Save<DType>(dplan.REval(0, x), res * scale);
250  }
251 }
252 
253 template<typename Saver, typename Reducer, int dimkeep,
254  typename R, typename DType, typename E, int etype>
255 inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst,
256  const expr::Exp<E, DType, etype> &exp,
257  DType scale) {
258  expr::TypeCheckPass<expr::TypeCheck<cpu, dimkeep, DType, E>::kRedPass>
259  ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
260  typedef Shape<expr::ExpInfo<E>::kDim> EShape;
261  EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
262  ::Check(exp.self());
263  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
264  CHECK_EQ(eshape[dimkeep], dshape[0])
265  << "MapReduceKeepHighDim::reduction dimension do not match";
266  // use equivalent form
267  Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep),
268  eshape[dimkeep],
269  eshape.ProdShape(dimkeep + 1, EShape::kSubdim),
270  eshape[EShape::kSubdim]);
271  // execution
272  expr::Plan<R, DType> dplan = MakePlan(dst->self());
273  expr::Plan<E, DType> splan = MakePlan(exp.self());
274 #ifndef __CUDACC__
275  #pragma omp parallel for
276 #endif
277  for (openmp_index_t c = 0; c < pshape[1]; ++c) {
278  DType res; Reducer::SetInitValue(res);
279  for (index_t n = 0; n < pshape[0]; ++n) {
280  DType tres; Reducer::SetInitValue(tres);
281  for (index_t y = 0; y < pshape[2]; ++y) {
282  for (index_t x = 0; x < pshape[3]; ++x) {
283  Reducer::Reduce(tres,
284  splan.Eval((n * pshape[1] + c) * pshape[2] + y, x));
285  }
286  }
287  Reducer::Reduce(res, tres);
288  }
289  Saver::template Save<DType>(dplan.REval(0, c), DType(res * scale));
290  }
291 }
292 
293 template<typename DType>
294 inline void Softmax(Tensor<cpu, 1, DType> dst,
295  const Tensor<cpu, 1, DType> &energy) {
296  DType mmax = energy[0];
297  for (index_t x = 1; x < dst.size(0); ++x) {
298  if (mmax < energy[x]) mmax = energy[x];
299  }
300  DType sum = DType(0.0f);
301  for (index_t x = 0; x < dst.size(0); ++x) {
302  dst[x] = std::exp(energy[x] - mmax);
303  sum += dst[x];
304  }
305  for (index_t x = 0; x < dst.size(0); ++x) {
306  dst[x] /= sum;
307  }
308 }
309 
310 template<typename DType>
311 inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
312  const Tensor<cpu, 2, DType> &src,
313  const Tensor<cpu, 1, DType> &label) {
314 #pragma omp parallel for
315  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
316  const index_t k = static_cast<int>(label[y]);
317  for (index_t x = 0; x < dst.size(1); ++x) {
318  if (x == k) {
319  dst[y][k] = src[y][k] - 1.0f;
320  } else {
321  dst[y][x] = src[y][x];
322  }
323  }
324  }
325 }
326 
327 template<typename DType>
328 inline void SmoothSoftmaxGrad(Tensor<cpu, 2, DType> dst,
329  const Tensor<cpu, 2, DType> &src,
330  const Tensor<cpu, 1, DType> &label,
331  const float alpha) {
332  const float smooth_grad = (alpha / (dst.size(1) - 1));
333 #pragma omp parallel for
334  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
335  const index_t k = static_cast<int>(label[y]);
336  for (index_t x = 0; x < dst.size(1); ++x) {
337  if (x == k) {
338  dst[y][k] = src[y][k] - 1.0f + alpha;
339  } else {
340  dst[y][x] = src[y][x] - smooth_grad;
341  }
342  }
343  }
344 }
345 
346 
347 template<typename DType>
348 inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
349  const Tensor<cpu, 2, DType> &src,
350  const Tensor<cpu, 1, DType> &label,
351  const DType &ignore_label) {
352 #pragma omp parallel for
353  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
354  const int k = static_cast<int>(label[y]);
355  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
356  if (static_cast<int>(ignore_label) == k) {
357  dst[y][x] = 0.0f;
358  } else {
359  if (x == k) {
360  dst[y][k] = src[y][k] - 1.0f;
361  } else {
362  dst[y][x] = src[y][x];
363  }
364  }
365  }
366  }
367 }
368 
369 template<typename DType>
370 inline void SmoothSoftmaxGrad(Tensor<cpu, 2, DType> dst,
371  const Tensor<cpu, 2, DType> &src,
372  const Tensor<cpu, 1, DType> &label,
373  const DType &ignore_label,
374  const float alpha) {
375  const float smooth_grad = (alpha / (dst.size(1) - 1));
376 #pragma omp parallel for
377  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
378  const int k = static_cast<int>(label[y]);
379  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
380  if (static_cast<int>(ignore_label) == k) {
381  dst[y][x] = 0.0f;
382  } else {
383  if (x == k) {
384  dst[y][k] = src[y][k] - 1.0f + alpha;
385  } else {
386  dst[y][x] = src[y][x] - smooth_grad;
387  }
388  }
389  }
390  }
391 }
392 
393 template<typename DType>
394 inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
395  const Tensor<cpu, 3, DType> &src,
396  const Tensor<cpu, 2, DType> &label) {
397 #pragma omp parallel for
398  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
399  for (index_t y = 0; y < dst.size(0); ++y) {
400  const int k = static_cast<int>(label[y][n]);
401  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
402  if (x == k) {
403  dst[y][k][n] = src[y][k][n] - 1.0f;
404  } else {
405  dst[y][x][n] = src[y][x][n];
406  }
407  }
408  }
409  }
410 }
411 
412 template<typename DType>
413 inline void SmoothSoftmaxGrad(Tensor<cpu, 3, DType> dst,
414  const Tensor<cpu, 3, DType> &src,
415  const Tensor<cpu, 2, DType> &label,
416  const float alpha) {
417  const float smooth_grad = (alpha / (dst.size(1) - 1));
418 #pragma omp parallel for
419  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
420  for (index_t y = 0; y < dst.size(0); ++y) {
421  const int k = static_cast<int>(label[y][n]);
422  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
423  if (x == k) {
424  dst[y][k][n] = src[y][k][n] - 1.0f + alpha;
425  } else {
426  dst[y][x][n] = src[y][x][n] - smooth_grad;
427  }
428  }
429  }
430  }
431 }
432 
433 template<typename DType>
434 inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
435  const Tensor<cpu, 3, DType> &src,
436  const Tensor<cpu, 2, DType> &label,
437  const DType &ignore_label) {
438 #pragma omp parallel for
439  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
440  for (index_t y = 0; y < dst.size(0); ++y) {
441  const int k = static_cast<int>(label[y][n]);
442  if (k == static_cast<int>(ignore_label)) {
443  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
444  dst[y][x][n] = DType(0.0f);
445  }
446  } else {
447  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
448  if (x == k) {
449  dst[y][k][n] = src[y][k][n] - 1.0f;
450  } else {
451  dst[y][x][n] = src[y][x][n];
452  }
453  }
454  }
455  }
456  }
457 }
458 
459 template<typename DType>
460 inline void SmoothSoftmaxGrad(Tensor<cpu, 3, DType> dst,
461  const Tensor<cpu, 3, DType> &src,
462  const Tensor<cpu, 2, DType> &label,
463  const DType &ignore_label,
464  const float alpha) {
465  const float smooth_grad = (alpha / (dst.size(1) - 1));
466 #pragma omp parallel for
467  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
468  for (index_t y = 0; y < dst.size(0); ++y) {
469  const int k = static_cast<int>(label[y][n]);
470  if (k == static_cast<int>(ignore_label)) {
471  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
472  dst[y][x][n] = DType(0.0f);
473  }
474  } else {
475  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
476  if (x == k) {
477  dst[y][k][n] = src[y][k][n] - 1.0f + alpha;
478  } else {
479  dst[y][x][n] = src[y][x][n] - smooth_grad;
480  }
481  }
482  }
483  }
484  }
485 }
486 
487 template<typename DType>
488 inline void Softmax(Tensor<cpu, 2, DType> dst,
489  const Tensor<cpu, 2, DType> &energy) {
490  CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
491 #pragma omp parallel for
492  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
493  Softmax(dst[y], energy[y]);
494  }
495 }
496 
497 template<typename DType>
498 inline void Softmax(Tensor<cpu, 3, DType> dst,
499  const Tensor<cpu, 3, DType> &energy) {
500  CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
501 #pragma omp parallel for
502  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
503  for (index_t n = 0; n < dst.size(2); ++n) {
504  DType mmax = energy[y][0][n];
505  for (index_t x = 1; x < dst.size(1); ++x) {
506  if (mmax < energy[y][x][n]) mmax = energy[y][x][n];
507  }
508  DType sum = DType(0.0f);
509  for (index_t x = 0; x < dst.size(1); ++x) {
510  dst[y][x][n] = std::exp(energy[y][x][n] - mmax);
511  sum += dst[y][x][n];
512  }
513  for (index_t x = 0; x < dst.size(1); ++x) {
514  dst[y][x][n] /= sum;
515  }
516  }
517  }
518 }
519 
520 template<bool clip, typename IndexType, typename DType>
521 inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
522  const Tensor<cpu, 1, IndexType>& index,
523  const Tensor<cpu, 2, DType> &src) {
524  const index_t K = dst.shape_[0];
525  const index_t C = dst.shape_[1];
526  for (index_t y = 0; y < index.size(0); ++y) {
527  index_t j = index[y];
528  if (clip) {
529  if (j <= 0) j = 0;
530  else if (j >= K) j = K - 1;
531  } else {
532  j %= K;
533  if (j < 0) j += K;
534  }
535  for (index_t i = 0; i < C; ++i) {
536  dst[j][i] += src[y][i];
537  }
538  }
539 }
540 
541 // safe accumulation
542 template<bool clip, typename IndexType, typename DType, typename AType>
543 inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
544  Tensor<cpu, 2, AType> temp,
545  const Tensor<cpu, 1, IndexType>& index,
546  const Tensor<cpu, 2, DType> &src) {
547  const index_t K = dst.shape_[0];
548  const index_t C = dst.shape_[1];
549  for (index_t j = 0; j < K; ++j) {
550  for (index_t i = 0; i < C; ++i) {
551  temp[j][i] = dst[j][i];
552  }
553  }
554  for (index_t y = 0; y < index.size(0); ++y) {
555  index_t j = index[y];
556  if (clip) {
557  if (j <= 0) j = 0;
558  else if (j >= K) j = K - 1;
559  } else {
560  j %= K;
561  if (j < 0) j += K;
562  }
563  for (index_t i = 0; i < C; ++i) {
564  temp[j][i] += src[y][i];
565  }
566  }
567  for (index_t j = 0; j < K; ++j) {
568  for (index_t i = 0; i < C; ++i) {
569  dst[j][i] = temp[j][i];
570  }
571  }
572 }
573 
574 template<typename IndexType, typename DType>
575 inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
576  const Tensor<cpu, 1, IndexType>& sorted,
577  const Tensor<cpu, 1, IndexType>& index,
578  const Tensor<cpu, 2, DType> &src) {
579  for (index_t y = 0; y < sorted.size(0); ++y) {
580  dst[sorted[y]] += src[index[y]];
581  }
582 }
583 
584 template<typename IndexType, typename DType>
585 inline void IndexFill(Tensor<cpu, 2, DType> dst,
586  const Tensor<cpu, 1, IndexType>& index,
587  const Tensor<cpu, 2, DType> &src) {
588  for (index_t y = 0; y < index.size(0); ++y) {
589  for (index_t j = 0; j < src.size(1); j++) {
590  dst[index[y]][j] = src[y][j];
591  }
592  }
593 }
594 
595 template<typename KDType, typename VDType>
596 inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
597  bool is_ascend) {
598  CHECK_EQ(keys.CheckContiguous(), true);
599  CHECK_EQ(values.CheckContiguous(), true);
600  CHECK_EQ(keys.size(0), values.size(0))
601  << "The sizes of key/value are not equal! keys_size: " << keys.size(0)
602  << "values_size: " << values.size(0);
603  std::vector<size_t> idx(keys.size(0));
604  std::vector<KDType> keys_vec(keys.size(0));
605  std::vector<VDType> values_vec(values.size(0));
606  for (int i = 0; i < keys.size(0); i++) {
607  idx[i] = i;
608  keys_vec[i] = keys[i];
609  values_vec[i] = values[i];
610  }
611  if (is_ascend) {
612  std::stable_sort(idx.begin(), idx.end(),
613  [&keys_vec](size_t i1, size_t i2)
614  {return keys_vec[i1] < keys_vec[i2]; });
615  } else {
616  std::stable_sort(idx.begin(), idx.end(),
617  [&keys_vec](size_t i1, size_t i2)
618  {return keys_vec[i1] > keys_vec[i2]; });
619  }
620  for (index_t i = 0; i < values.size(0); i++) {
621  keys[i] = keys_vec[idx[i]];
622  values[i] = values_vec[idx[i]];
623  }
624 }
625 
626 template<typename Device, typename VDType, typename SDType>
627 inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments) {
628  // We can sort each segment using two stable sorts
629  SortByKey(values, segments, true);
630  SortByKey(segments, values, true);
631 }
632 
633 // blas related
634 template<typename Device, typename DType>
635 inline void VectorDot(Tensor<Device, 1, DType> dst,
636  const Tensor<Device, 1, DType> &lhs,
637  const Tensor<Device, 1, DType> &rhs) {
638  CHECK_EQ(lhs.size(0), rhs.size(0))
639  << "VectorDot: Shape mismatch";
640  CHECK_EQ(dst.size(0), 1U)
641  << "VectorDot: expect dst to be scalar";
642  expr::BLASEngine<Device, DType>::SetStream(lhs.stream_);
643  mshadow::expr::BLASEngine<Device, DType>::dot(
644  lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_);
645 }
646 
647 template<bool transpose_left, bool transpose_right, typename Device, typename DType>
648 inline void BatchGEMM(Tensor<Device, 3, DType> dst,
649  const Tensor<Device, 3, DType> &lhs,
650  const Tensor<Device, 3, DType> &rhs,
651  DType alpha,
652  DType beta,
653  Tensor<Device, 1, DType*> workspace) {
654  index_t batch_size = dst.shape_[0];
655  expr::BLASEngine<Device, DType>::SetStream(dst.stream_);
656  Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1])
657  : lhs.shape_;
658  Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1])
659  : rhs.shape_;
660  CHECK_EQ(dst.CheckContiguous(), true);
661  CHECK_EQ(lhs.CheckContiguous(), true);
662  CHECK_EQ(rhs.CheckContiguous(), true);
663  CHECK(sleft[0] == batch_size && sright[0] == batch_size)
664  << "BatchGEMM: batchsize must be equal."
665  << "dst: " << dst.shape_ << "\n"
666  << "lhs: " << sleft << "\n"
667  << "rhs: " << sright << "\n";
668  CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
669  << "BatchGEMM: matrix shape mismatch"
670  << "dst: " << dst.shape_ << "\n"
671  << "lhs: " << sleft << "\n"
672  << "rhs: " << sright << "\n";
673  CHECK(workspace.size(0) >= 3 * batch_size)
674  << "Workspace Size must be bigger than " << 3 * batch_size;
675  CHECK_EQ(workspace.CheckContiguous(), true);
676  // use column major argument to be compatible with most BLAS
677  expr::BLASEngine<Device, DType>::batched_gemm
678  (dst.stream_,
679  transpose_right, transpose_left,
680  transpose_right ? rhs.size(1) : rhs.size(2),
681  transpose_left ? lhs.size(2) : lhs.size(1),
682  transpose_right ? rhs.size(2) : rhs.size(1),
683  alpha,
684  rhs.dptr_, rhs.stride_,
685  lhs.dptr_, lhs.stride_,
686  beta,
687  dst.dptr_, dst.stride_, batch_size,
688  workspace.dptr_);
689 }
690 } // namespace mshadow
691 #endif // MSHADOW_TENSOR_CPU_INL_H_
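For orientation, here is a minimal usage sketch of the CPU-side API defined in this header (allocation, copy, softmax, deallocation). It is illustrative only and not part of the file; the include path and a CPU-only build are assumptions.

 #include "mshadow/tensor.h"   // brings in tensor_cpu-inl.h on CPU builds
 using namespace mshadow;

 int main() {
   InitTensorEngine<cpu>(0);
   // NewTensor allocates (optionally padded) space and fills it with the initial value
   Tensor<cpu, 2, float> energy = NewTensor<cpu>(Shape2(2, 3), 1.0f);
   Tensor<cpu, 2, float> prob = NewTensor<cpu>(Shape2(2, 3), 0.0f);
   Copy(prob, energy);      // contiguous tensors take the single-memcpy path
   Softmax(prob, energy);   // row-wise softmax, rows processed in parallel via OpenMP
   FreeSpace(&energy);
   FreeSpace(&prob);
   ShutdownTensorEngine<cpu>();
   return 0;
 }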
mshadow::Shape4
MSHADOW_XINLINE Shape< 4 > Shape4(index_t s0, index_t s1, index_t s2, index_t s3)
construct a four dimension shape, stride will equal s0
Definition: tensor.h:254
mshadow::openmp_index_t
index_t openmp_index_t
openmp index for linux
Definition: base.h:336
mshadow::SortByKey
void SortByKey(Tensor< cpu, 1, KDType > keys, Tensor< cpu, 1, VDType > values, bool is_ascend=true)
CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!...
Definition: tensor_cpu-inl.h:596
mshadow::expr::Exp< Container, DType, type::kRValue >::self
const Container & self(void) const
Definition: expression.h:82
mshadow::Tensor::MSize
MSHADOW_XINLINE index_t MSize(void) const
Definition: tensor.h:602
mshadow::Stream
computation stream structure, used for asynchronous computations
Definition: tensor.h:488
mshadow::expr::BLASEngine::SetStream
static void SetStream(Stream< Device > *stream)
Definition: dot_engine-inl.h:82
mshadow::TRValue
Tensor RValue, this is the super type of all kinds of possible tensors.
Definition: tensor.h:514
mshadow::expr::TypeCheckPass
used to help static type check
Definition: expr_engine-inl.h:330
mshadow::expr::BLASEngine::dot
static void dot(Stream< Device > *stream, int n, const DType *X, int incX, const DType *Y, int incY, DType *ret)
Definition: dot_engine-inl.h:125
mshadow::Copy
void Copy(Tensor< cpu, dim, DType > dst, const Tensor< cpu, dim, DType > &src, Stream< cpu > *stream=NULL)
copy data from one tensor to another, with same shape
Definition: tensor_cpu-inl.h:145
mshadow::FreeHost_
void FreeHost_(void *dptr)
mshadow::FreeSpace
void FreeSpace(Tensor< cpu, dim, DType > *obj)
CPU/GPU: free the space of tensor, will set obj.dptr to NULL.
Definition: tensor_cpu-inl.h:140
mshadow::expr::Exp< Container, DType, type::kRValue >::ptrself
Container * ptrself(void)
Definition: expression.h:86
mshadow::BatchGEMM
void BatchGEMM(Tensor< Device, 3, DType > dst, const Tensor< Device, 3, DType > &lhs, const Tensor< Device, 3, DType > &rhs, DType alpha, DType beta, Tensor< Device, 1, DType * > workspace)
CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst.
Definition: tensor_cpu-inl.h:648
MSHADOW_CUDA_CALL
#define MSHADOW_CUDA_CALL(func)
Protected cuda call in mshadow.
Definition: base.h:264
mshadow::IndexFill
void IndexFill(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 1, IndexType > &index, const Tensor< cpu, 2, DType > &src)
CPU/GPU: Fill the values of the destination matrix to specific rows in the source matrix....
Definition: tensor_cpu-inl.h:585
mshadow::SetDevice< cpu >
void SetDevice< cpu >(int devid)
Definition: tensor_cpu-inl.h:45
dot_engine-inl.h
definitions of how Matrix Multiplications can be evaluated
mshadow::SoftmaxGrad
void SoftmaxGrad(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 2, DType > &src, const Tensor< cpu, 1, DType > &label)
CPU/GPU: softmax gradient.
Definition: tensor_cpu-inl.h:311
mshadow::Tensor< Device, 1, DType >::stream_
Stream< Device > * stream_
Definition: tensor.h:679
mshadow::MapReduceKeepLowest
void MapReduceKeepLowest(TRValue< R, cpu, 1, DType > *dst, const expr::Exp< E, DType, etype > &exp, DType scale=1)
CPU/GPU: map an expression, do reduction to a 1D Tensor in the lowest dimension (dimension 0)
Definition: tensor_cpu-inl.h:228
mshadow::Tensor
general tensor
Definition: tensor.h:525
mshadow::packet::AlignedMallocPitch
void * AlignedMallocPitch(size_t *out_pitch, size_t lspace, size_t num_line)
analogous to cudaMallocPitch: allocate an aligned space with num_line * lspace cells
Definition: packet-inl.h:77
mshadow::VectorizedSort
void VectorizedSort(Tensor< Device, 1, VDType > values, Tensor< Device, 1, SDType > segments)
CPU/GPU: Sort the keys within each segment. (Stable sort is performed!) Segments is defined as an asc...
Definition: tensor_cpu-inl.h:627
mshadow::expr::ShapeCheck
runtime shape checking template; get the shape of an expression and report an error on shape mismatch
Definition: expr_engine-inl.h:364
mshadow::FreeHost_< cpu >
void FreeHost_< cpu >(void *dptr)
Definition: tensor_cpu-inl.h:95
mshadow::operator<<
std::ostream & operator<<(std::ostream &os, const Shape< ndim > &shape)
allow string printing of the shape
Definition: tensor_cpu-inl.h:59
mshadow::expr::BLASEngine::batched_gemm
static void batched_gemm(Stream< Device > *stream, bool transa, bool transb, int m, int n, int k, DType alpha, const DType *A, int lda, const DType *B, int ldb, DType beta, DType *C, int ldc, int batch_count, DType **workspace)
Definition: dot_engine-inl.h:91
mshadow::Tensor< Device, 1, DType >::size
MSHADOW_XINLINE index_t size(index_t i) const
Definition: tensor.h:711
mshadow::expr::ShapeCheck::Check
static Shape< dim > Check(const E &t)
mshadow::cpu
device name CPU
Definition: tensor.h:39
mshadow::Softmax
void Softmax(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 2, DType > &energy)
CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j]))
Definition: tensor_cpu-inl.h:488
tensor.h
header file of tensor data structure and functions This lib requires explicit memory allocation and d...
mshadow::expr::MakePlan
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
mshadow::AllocHost
void AllocHost(Tensor< cpu, dim, DType > *obj)
Definition: tensor_cpu-inl.h:100
mshadow::Tensor::stream_
Stream< Device > * stream_
stream where the computation lies stream is a device dependency concept where each computation
Definition: tensor.h:551
mshadow::AllocHost_
void * AllocHost_(size_t size)
mshadow::Shape::FlatTo2D
MSHADOW_XINLINE Shape< 2 > FlatTo2D(void) const
Definition: tensor.h:146
mshadow::expr::PacketAlignCheck
Definition: packet-inl.h:379
packet-inl.h
Generic packet vectorization code.
mshadow::MapExpCPUEngine
Definition: tensor_cpu-inl.h:187
mshadow::Tensor::CheckContiguous
MSHADOW_XINLINE bool CheckContiguous(void) const
Definition: tensor.h:596
mshadow::Tensor::shape_
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:541
mshadow::expr::Plan::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
evaluate the expression at index [y][x] to be implemented by SubType, for RValue, the return type wil...
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::Tensor< Device, 1, DType >::dptr_
DType * dptr_
Definition: tensor.h:676
mshadow::expr::Plan
Definition: expr_engine-inl.h:58
mshadow::AddTakeGradLargeBatch
void AddTakeGradLargeBatch(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 1, IndexType > &sorted, const Tensor< cpu, 1, IndexType > &index, const Tensor< cpu, 2, DType > &src)
CPU/GPU: Gradient accumulate of embedding matrix with safe accumulation. dst[index[i]] += src[i].
Definition: tensor_cpu-inl.h:575
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow::expr::pad
PaddingExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > pad(const Exp< SrcExp, DType, etype > &src, index_t pad)
padding expression, pad an image with zeros on boundaries, padding affects shape[0],...
Definition: pad.h:71
mshadow::AllocSpace
void AllocSpace(Tensor< cpu, dim, DType > *obj, bool pad=MSHADOW_ALLOC_PAD)
CPU/GPU: allocate space for CTensor, according to the shape in the obj; this function is responsible t...
Definition: tensor_cpu-inl.h:116
mshadow::MapReduceKeepHighDim
void MapReduceKeepHighDim(TRValue< R, cpu, 1, DType > *dst, const expr::Exp< E, DType, etype > &exp, DType scale=1)
CPU/GPU: map an expression, do reduction to a 1D Tensor in the third dimension (dimension 2)
Definition: tensor_cpu-inl.h:255
mshadow::InitTensorEngine< cpu >
void InitTensorEngine< cpu >(int dev_id)
Definition: tensor_cpu-inl.h:38
mshadow::AllocHost_< cpu >
void * AllocHost_< cpu >(size_t size)
Definition: tensor_cpu-inl.h:90
mshadow::Tensor< Device, 1, DType >
Definition: tensor.h:673
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::MapPlan
void MapPlan(TRValue< R, cpu, dim, DType > *dst, const expr::Plan< E, DType > &plan)
Definition: tensor_cpu-inl.h:168
mshadow::Tensor::FlatTo2D
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to 2 dimensions, collapsing the higher dimensions together
Definition: tensor.h:624
mshadow::Shape
shape of a tensor
Definition: tensor.h:64
mshadow::Tensor::dptr_
DType * dptr_
pointer to the data
Definition: tensor.h:539
MSHADOW_DEFAULT_PACKET
#define MSHADOW_DEFAULT_PACKET
Definition: packet-inl.h:47
mshadow::VectorDot
void VectorDot(Tensor< Device, 1, DType > dst, const Tensor< Device, 1, DType > &lhs, const Tensor< Device, 1, DType > &rhs)
CPU/GPU: 1 dimension vector dot.
Definition: tensor_cpu-inl.h:635
mshadow::Tensor::size
MSHADOW_XINLINE index_t size(int idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor.h:610
mshadow::NewTensor
Tensor< Device, dim, DType > NewTensor(const Shape< dim > &shape, DType initv, bool pad=MSHADOW_ALLOC_PAD, Stream< Device > *stream=NULL)
CPU/GPU: shortcut to allocate and initialize a Tensor.
Definition: tensor_cpu-inl.h:132
mshadow::AddTakeGrad
void AddTakeGrad(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 1, IndexType > &index, const Tensor< cpu, 2, DType > &src)
CPU/GPU: Gradient accumulate of embedding matrix. dst[index[i]] += src[i] Called when the featuredim ...
Definition: tensor_cpu-inl.h:521
mshadow::NewStream< cpu >
Stream< cpu > * NewStream< cpu >(bool create_blas_handle, bool create_dnn_handle, int dev_id)
Definition: tensor_cpu-inl.h:48
mshadow::SmoothSoftmaxGrad
void SmoothSoftmaxGrad(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 2, DType > &src, const Tensor< cpu, 1, DType > &label, const float alpha)
Definition: tensor_cpu-inl.h:328
base.h
definitions of base types, operators, macros, functions
mshadow::expr::ScalarExp
scalar expression
Definition: expression.h:95
mshadow::MapExpCPUEngine< true, SV, Tensor< cpu, dim, DType >, dim, DType, E, etype >::Map
static void Map(Tensor< cpu, dim, DType > *dst, const expr::Exp< E, DType, etype > &exp)
Definition: tensor_cpu-inl.h:197
mshadow::MapExpCPUEngine::Map
static void Map(TRValue< R, cpu, dim, DType > *dst, const expr::Exp< E, DType, etype > &exp)
Definition: tensor_cpu-inl.h:188
mshadow::Shape3
MSHADOW_XINLINE Shape< 3 > Shape3(index_t s0, index_t s1, index_t s2)
construct a three dimension shape, stride will equal s0
Definition: tensor.h:241
mshadow::DeleteStream< cpu >
void DeleteStream< cpu >(Stream< cpu > *stream)
Definition: tensor_cpu-inl.h:54
mshadow::FreeHost
void FreeHost(Tensor< cpu, dim, DType > *obj)
Definition: tensor_cpu-inl.h:107
mshadow::ShutdownTensorEngine< cpu >
void ShutdownTensorEngine< cpu >(void)
Definition: tensor_cpu-inl.h:41
mshadow::Tensor::stride_
index_t stride_
storing the stride information in x dimension this is used to deal with pitch allocation in gpu or ss...
Definition: tensor.h:546
mshadow::packet::AlignedFree
void AlignedFree(void *ptr)
free aligned space
Definition: packet-inl.h:106
mshadow::MapExp
void MapExp(TRValue< R, cpu, dim, DType > *dst, const expr::Exp< E, DType, etype > &exp)
CPU/GPU: map an expression to a tensor; this function calls MapPlan.
Definition: tensor_cpu-inl.h:212
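As a complement to the identifier summaries above, a small illustrative sketch of the index-based helpers AddTakeGrad and SortByKey; the shapes, indices and values below are made up for the example and are not taken from the file.

 #include "mshadow/tensor.h"
 using namespace mshadow;

 void EmbeddingGradExample() {
   // weight gradient: 4 embedding rows of feature dim 3; output gradient for 2 lookups
   Tensor<cpu, 2, float> grad_w = NewTensor<cpu>(Shape2(4, 3), 0.0f);
   Tensor<cpu, 2, float> grad_out = NewTensor<cpu>(Shape2(2, 3), 1.0f);
   Tensor<cpu, 1, int> idx = NewTensor<cpu>(Shape1(2), 0);
   idx[0] = 1;
   idx[1] = 3;
   // grad_w[idx[i]] += grad_out[i]; clip=true clamps out-of-range indices
   AddTakeGrad<true>(grad_w, idx, grad_out);

   // stable CPU sort of (key, value) pairs in ascending key order
   Tensor<cpu, 1, float> keys = NewTensor<cpu>(Shape1(2), 0.0f);
   Tensor<cpu, 1, int> vals = NewTensor<cpu>(Shape1(2), 0);
   keys[0] = 0.7f; keys[1] = 0.2f;
   vals[0] = 10;   vals[1] = 20;
   SortByKey(keys, vals, true);   // keys -> {0.2, 0.7}, vals -> {20, 10}

   FreeSpace(&grad_w); FreeSpace(&grad_out); FreeSpace(&idx);
   FreeSpace(&keys); FreeSpace(&vals);
 }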