mxnet
tensor_cpu-inl.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *  Copyright (c) 2014 by Contributors
22  * \file tensor_cpu-inl.h
23  * \brief implementation of CPU host code
24  * \author Bing Xu, Tianqi Chen
25  */
26 #ifndef MSHADOW_TENSOR_CPU_INL_H_
27 #define MSHADOW_TENSOR_CPU_INL_H_
28 #include <cstring>
29 #include <functional>
30 #include <utility>
31 #include <vector>
32 #include "./base.h"
33 #include "./tensor.h"
34 #include "./packet-inl.h"
35 #include "./dot_engine-inl.h"
36 
37 namespace mshadow {
38 template<>
39 inline void InitTensorEngine<cpu>(int dev_id) {
40 }
41 template<>
42 inline void ShutdownTensorEngine<cpu>(void) {
43 }
44 
45 template<>
46 inline void SetDevice<cpu>(int devid) {
47 }
48 template<>
49 inline Stream<cpu> *NewStream<cpu>(bool create_blas_handle,
50  bool create_dnn_handle,
51  int dev_id) {
52  return new Stream<cpu>();
53 }
54 template<>
55 inline void DeleteStream<cpu>(Stream<cpu> *stream) {
56  delete stream;
57 }
58 
59 template<int ndim>
60 inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape) { // NOLINT(*)
61  os << '(';
62  for (int i = 0; i < ndim; ++i) {
63  if (i != 0) os << ',';
64  os << shape[i];
65  }
66  // python style tuple
67  if (ndim == 1) os << ',';
68  os << ')';
69  return os;
70 }
71 
72 template<typename xpu>
73 inline void *AllocHost_(size_t size);
74 template<typename xpu>
75 inline void FreeHost_(void * dptr);
76 
77 #ifdef __CUDACC__
78 template<>
79 inline void *AllocHost_<gpu>(size_t size) {
80  void *dptr;
81  MSHADOW_CUDA_CALL(cudaMallocHost(&dptr, size, cudaHostAllocPortable));
82  return dptr;
83 }
84 template<>
85 inline void FreeHost_<gpu>(void *dptr) {
86  MSHADOW_CUDA_CALL(cudaFreeHost(dptr));
87 }
88 #endif
89 
90 template<>
91 inline void *AllocHost_<cpu>(size_t size) {
92  size_t pitch;
93  return packet::AlignedMallocPitch(&pitch, size, 1);
94 }
95 template<>
96 inline void FreeHost_<cpu>(void *dptr) {
97  packet::AlignedFree(dptr);
98 }
99 
100 template<typename xpu, int dim, typename DType>
101 inline void AllocHost(Tensor<cpu, dim, DType> *obj) {
102  obj->stride_ = obj->size(dim - 1);
103  CHECK_EQ(obj->CheckContiguous(), true) << "AllocHost";
104  void *dptr = AllocHost_<xpu>(obj->MSize() * sizeof(DType));
105  obj->dptr_ = reinterpret_cast<DType*>(dptr);
106 }
107 template<typename xpu, int dim, typename DType>
108 inline void FreeHost(Tensor<cpu, dim, DType> *obj) {
109  if (obj->dptr_ == NULL) {
110  LOG(FATAL) << "FreeHost:: double free";
111  }
112  FreeHost_<xpu>(obj->dptr_);
113  obj->dptr_ = NULL;
114 }
115 
116 template<int dim, typename DType>
117 inline void AllocSpace(Tensor<cpu, dim, DType> *obj, bool pad) {
118  size_t pitch;
119  void *dptr;
120  if (pad) {
121  dptr = packet::AlignedMallocPitch
122  (&pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0]);
123  obj->stride_ = static_cast<index_t>(pitch / sizeof(DType));
124  } else {
125  obj->stride_ = obj->size(dim - 1);
126  dptr = packet::AlignedMallocPitch
127  (&pitch, obj->shape_.Size() * sizeof(DType), 1);
128  }
129  obj->dptr_ = reinterpret_cast<DType*>(dptr);
130 }
131 template<typename Device, typename DType, int dim>
132 inline Tensor<Device, dim, DType>
133 NewTensor(const Shape<dim> &shape, DType initv, bool pad, Stream<Device> *stream_) {
134  Tensor<Device, dim, DType> obj(shape);
135  obj.stream_ = stream_;
136  AllocSpace(&obj, pad);
137  MapExp<sv::saveto>(&obj, expr::ScalarExp<DType>(initv));
138  return obj;
139 }
140 template<int dim, typename DType>
141 inline void FreeSpace(Tensor<cpu, dim, DType> *obj) {
142  packet::AlignedFree(obj->dptr_);
143  obj->dptr_ = NULL;
144 }
145 template<int dim, typename DType>
146 inline void Copy(Tensor<cpu, dim, DType> _dst,
147  const Tensor<cpu, dim, DType> &_src,
148  Stream<cpu> *stream) {
149  CHECK_EQ(_dst.shape_, _src.shape_)
150  << "Copy:shape mismatch:" << _dst.shape_ << " vs " << _src.shape_;
151  if (_dst.CheckContiguous() && _src.CheckContiguous()) {
152  memcpy(_dst.dptr_, _src.dptr_, sizeof(DType) * _dst.shape_.Size());
153  } else {
154  Tensor<cpu, 2, DType> dst = _dst.FlatTo2D();
155  Tensor<cpu, 2, DType> src = _src.FlatTo2D();
156  for (index_t y = 0; y < dst.size(0); ++y) {
157  memcpy(dst[y].dptr_, src[y].dptr_, sizeof(DType) * dst.size(1));
158  }
159  }
160 }
161 
162 template<typename Saver, typename R, int dim,
163  typename DType, typename E>
164 inline void MapPlan(TRValue<R, cpu, dim, DType> *dst,
165  const expr::Plan<E, DType> &plan) {
166  Shape<2> shape = expr::ShapeCheck<dim, R>::Check(dst->self()).FlatTo2D();
167  expr::Plan<R, DType> dplan = expr::MakePlan(dst->self());
168 #ifndef __CUDACC__
169  #pragma omp parallel for
170 #endif
171  // temp remove openmp, as default setting throttles CPU
172  for (openmp_index_t y = 0; y < shape[0]; ++y) {
173  for (index_t x = 0; x < shape[1]; ++x) {
174  // trust your compiler! -_- they will optimize it
175  Saver::template Save<DType>(dplan.REval(y, x), plan.Eval(y, x));
176  }
177  }
178 }
179 // code to handle SSE optimization
180 template<bool pass_check, typename Saver,
181  typename R, int dim,
182  typename DType, typename E, int etype>
183 struct MapExpCPUEngine {
184  inline static void Map(TRValue<R, cpu, dim, DType> *dst,
185  const expr::Exp<E, DType, etype> &exp) {
186  MapPlan<Saver>(dst, MakePlan(exp.self()));
187  }
188 };
189 
190 template<typename SV, int dim, typename DType, typename E, int etype>
191 struct MapExpCPUEngine<true, SV, Tensor<cpu, dim, DType>,
192  dim, DType, E, etype> {
193  inline static void Map(Tensor<cpu, dim, DType> *dst,
194  const expr::Exp<E, DType, etype> &exp) {
195  if (expr::PacketAlignCheck<dim, E, MSHADOW_DEFAULT_PACKET>::Check(exp.self()) &&
196  expr::PacketAlignCheck<dim, Tensor<cpu, dim, DType>, MSHADOW_DEFAULT_PACKET>::Check(*dst)) {
197  expr::MapPacketPlan<SV>(dst->self(),
198  expr::MakePacketPlan<MSHADOW_DEFAULT_PACKET>(exp.self()));
199  } else {
200  MapPlan<SV>(dst, MakePlan(exp.self()));
201  }
202  }
203 };
204 
205 
206 template<typename Saver, typename R, int dim,
207  typename DType, typename E, int etype>
208 inline void MapExp(TRValue<R, cpu, dim, DType> *dst,
209  const expr::Exp<E, DType, etype> &exp) {
210  expr::TypeCheckPass<expr::TypeCheck<cpu, dim, DType, E>::kMapPass>
211  ::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
212  Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self());
213  Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self());
214  CHECK(eshape[0] == 0 || eshape == dshape)
215  << "Assignment: Shape of Tensors are not consistent with target, "
216  << "eshape: " << eshape << " dshape:" << dshape;
217  MapExpCPUEngine<expr::PacketCheck<E, MSHADOW_DEFAULT_PACKET>::kPass,
218  Saver, R, dim, DType, E, etype>
219  ::Map(dst->ptrself(), exp);
220 }
221 
222 template<typename Saver, typename Reducer,
223  typename R, typename DType, typename E, int etype>
224 inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst,
225  const expr::Exp<E, DType, etype> &exp,
226  DType scale) {
227  expr::TypeCheckPass<expr::TypeCheck<cpu, 1, DType, E>::kRedPass>
228  ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
229  Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
230  ::Check(exp.self()).FlatTo2D();
231  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
232  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match";
233  CHECK_NE(eshape[0], 0U) << "can not reduce over empty tensor";
234  // execution
235  expr::Plan<R, DType> dplan = MakePlan(dst->self());
236  expr::Plan<E, DType> splan = MakePlan(exp.self());
237 #ifndef __CUDACC__
238  #pragma omp parallel for
239 #endif
240  for (openmp_index_t x = 0; x < eshape[1]; ++x) {
241  DType res = splan.Eval(0, x);
242  for (index_t y = 1; y < eshape[0]; ++y) {
243  Reducer::Reduce(res, splan.Eval(y, x));
244  }
245  Saver::template Save<DType>(dplan.REval(0, x), res * scale);
246  }
247 }
248 
249 template<typename Saver, typename Reducer, int dimkeep,
250  typename R, typename DType, typename E, int etype>
251 inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst,
252  const expr::Exp<E, DType, etype> &exp,
253  DType scale) {
254  expr::TypeCheckPass<expr::TypeCheck<cpu, dimkeep, DType, E>::kRedPass>
255  ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
256  typedef Shape<expr::ExpInfo<E>::kDim> EShape;
257  EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
258  ::Check(exp.self());
259  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
260  CHECK_EQ(eshape[dimkeep], dshape[0])
261  << "MapReduceKeepHighDim::reduction dimension do not match";
262  // use equivalent form
263  Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep),
264  eshape[dimkeep],
265  eshape.ProdShape(dimkeep + 1, EShape::kSubdim),
266  eshape[EShape::kSubdim]);
267  // execution
268  expr::Plan<R, DType> dplan = MakePlan(dst->self());
269  expr::Plan<E, DType> splan = MakePlan(exp.self());
270 #ifndef __CUDACC__
271  #pragma omp parallel for
272 #endif
273  for (openmp_index_t c = 0; c < pshape[1]; ++c) {
274  DType res; Reducer::SetInitValue(res);
275  for (index_t n = 0; n < pshape[0]; ++n) {
276  DType tres; Reducer::SetInitValue(tres);
277  for (index_t y = 0; y < pshape[2]; ++y) {
278  for (index_t x = 0; x < pshape[3]; ++x) {
279  Reducer::Reduce(tres,
280  splan.Eval((n * pshape[1] + c) * pshape[2] + y, x));
281  }
282  }
283  Reducer::Reduce(res, tres);
284  }
285  Saver::template Save<DType>(dplan.REval(0, c), DType(res * scale));
286  }
287 }
288 
289 template<typename DType>
290 inline void Softmax(Tensor<cpu, 1, DType> dst,
291  const Tensor<cpu, 1, DType> &energy) {
292  DType mmax = energy[0];
293  for (index_t x = 1; x < dst.size(0); ++x) {
294  if (mmax < energy[x]) mmax = energy[x];
295  }
296  DType sum = DType(0.0f);
297  for (index_t x = 0; x < dst.size(0); ++x) {
298  dst[x] = std::exp(energy[x] - mmax);
299  sum += dst[x];
300  }
301  for (index_t x = 0; x < dst.size(0); ++x) {
302  dst[x] /= sum;
303  }
304 }
305 
306 template<typename DType>
307 inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
308  const Tensor<cpu, 2, DType> &src,
309  const Tensor<cpu, 1, DType> &label) {
310 #pragma omp parallel for
311  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
312  const index_t k = static_cast<int>(label[y]);
313  for (index_t x = 0; x < dst.size(1); ++x) {
314  if (x == k) {
315  dst[y][k] = src[y][k] - 1.0f;
316  } else {
317  dst[y][x] = src[y][x];
318  }
319  }
320  }
321 }
322 
323 template<typename DType>
324 inline void SmoothSoftmaxGrad(Tensor<cpu, 2, DType> dst,
325  const Tensor<cpu, 2, DType> &src,
326  const Tensor<cpu, 1, DType> &label,
327  const float alpha) {
328  const float smooth_grad = (alpha / (dst.size(1) - 1));
329 #pragma omp parallel for
330  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
331  const index_t k = static_cast<int>(label[y]);
332  for (index_t x = 0; x < dst.size(1); ++x) {
333  if (x == k) {
334  dst[y][k] = src[y][k] - 1.0f + alpha;
335  } else {
336  dst[y][x] = src[y][x] - smooth_grad;
337  }
338  }
339  }
340 }
341 
342 
343 template<typename DType>
344 inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
345  const Tensor<cpu, 2, DType> &src,
346  const Tensor<cpu, 1, DType> &label,
347  const DType &ignore_label) {
348 #pragma omp parallel for
349  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
350  const int k = static_cast<int>(label[y]);
351  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
352  if (static_cast<int>(ignore_label) == k) {
353  dst[y][x] = 0.0f;
354  } else {
355  if (x == k) {
356  dst[y][k] = src[y][k] - 1.0f;
357  } else {
358  dst[y][x] = src[y][x];
359  }
360  }
361  }
362  }
363 }
364 
365 template<typename DType>
366 inline void SmoothSoftmaxGrad(Tensor<cpu, 2, DType> dst,
367  const Tensor<cpu, 2, DType> &src,
368  const Tensor<cpu, 1, DType> &label,
369  const DType &ignore_label,
370  const float alpha) {
371  const float smooth_grad = (alpha / (dst.size(1) - 1));
372 #pragma omp parallel for
373  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
374  const int k = static_cast<int>(label[y]);
375  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
376  if (static_cast<int>(ignore_label) == k) {
377  dst[y][x] = 0.0f;
378  } else {
379  if (x == k) {
380  dst[y][k] = src[y][k] - 1.0f + alpha;
381  } else {
382  dst[y][x] = src[y][x] - smooth_grad;
383  }
384  }
385  }
386  }
387 }
388 
389 template<typename DType>
390 inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
391  const Tensor<cpu, 3, DType> &src,
392  const Tensor<cpu, 2, DType> &label) {
393 #pragma omp parallel for
394  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
395  for (index_t y = 0; y < dst.size(0); ++y) {
396  const int k = static_cast<int>(label[y][n]);
397  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
398  if (x == k) {
399  dst[y][k][n] = src[y][k][n] - 1.0f;
400  } else {
401  dst[y][x][n] = src[y][x][n];
402  }
403  }
404  }
405  }
406 }
407 
408 template<typename DType>
409 inline void SmoothSoftmaxGrad(Tensor<cpu, 3, DType> dst,
410  const Tensor<cpu, 3, DType> &src,
411  const Tensor<cpu, 2, DType> &label,
412  const float alpha) {
413  const float smooth_grad = (alpha / (dst.size(1) - 1));
414 #pragma omp parallel for
415  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
416  for (index_t y = 0; y < dst.size(0); ++y) {
417  const int k = static_cast<int>(label[y][n]);
418  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
419  if (x == k) {
420  dst[y][k][n] = src[y][k][n] - 1.0f + alpha;
421  } else {
422  dst[y][x][n] = src[y][x][n] - smooth_grad;
423  }
424  }
425  }
426  }
427 }
428 
429 template<typename DType>
430 inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
431  const Tensor<cpu, 3, DType> &src,
432  const Tensor<cpu, 2, DType> &label,
433  const DType &ignore_label) {
434 #pragma omp parallel for
435  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
436  for (index_t y = 0; y < dst.size(0); ++y) {
437  const int k = static_cast<int>(label[y][n]);
438  if (k == static_cast<int>(ignore_label)) {
439  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
440  dst[y][x][n] = DType(0.0f);
441  }
442  } else {
443  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
444  if (x == k) {
445  dst[y][k][n] = src[y][k][n] - 1.0f;
446  } else {
447  dst[y][x][n] = src[y][x][n];
448  }
449  }
450  }
451  }
452  }
453 }
454 
455 template<typename DType>
456 inline void SmoothSoftmaxGrad(Tensor<cpu, 3, DType> dst,
457  const Tensor<cpu, 3, DType> &src,
458  const Tensor<cpu, 2, DType> &label,
459  const DType &ignore_label,
460  const float alpha) {
461  const float smooth_grad = (alpha / (dst.size(1) - 1));
462 #pragma omp parallel for
463  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
464  for (index_t y = 0; y < dst.size(0); ++y) {
465  const int k = static_cast<int>(label[y][n]);
466  if (k == static_cast<int>(ignore_label)) {
467  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
468  dst[y][x][n] = DType(0.0f);
469  }
470  } else {
471  for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
472  if (x == k) {
473  dst[y][k][n] = src[y][k][n] - 1.0f + alpha;
474  } else {
475  dst[y][x][n] = src[y][x][n] - smooth_grad;
476  }
477  }
478  }
479  }
480  }
481 }
482 
483 template<typename DType>
484 inline void Softmax(Tensor<cpu, 2, DType> dst,
485  const Tensor<cpu, 2, DType> &energy) {
486  CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
487 #pragma omp parallel for
488  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
489  Softmax(dst[y], energy[y]);
490  }
491 }
492 
493 template<typename DType>
494 inline void Softmax(Tensor<cpu, 3, DType> dst,
495  const Tensor<cpu, 3, DType> &energy) {
496  CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
497 #pragma omp parallel for
498  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
499  for (index_t n = 0; n < dst.size(2); ++n) {
500  DType mmax = energy[y][0][n];
501  for (index_t x = 1; x < dst.size(1); ++x) {
502  if (mmax < energy[y][x][n]) mmax = energy[y][x][n];
503  }
504  DType sum = DType(0.0f);
505  for (index_t x = 0; x < dst.size(1); ++x) {
506  dst[y][x][n] = std::exp(energy[y][x][n] - mmax);
507  sum += dst[y][x][n];
508  }
509  for (index_t x = 0; x < dst.size(1); ++x) {
510  dst[y][x][n] /= sum;
511  }
512  }
513  }
514 }
515 
516 template<bool clip, typename IndexType, typename DType>
517 inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
518  const Tensor<cpu, 1, IndexType>& index,
519  const Tensor<cpu, 2, DType> &src) {
520  const index_t K = dst.shape_[0];
521  const index_t C = dst.shape_[1];
522  for (index_t y = 0; y < index.size(0); ++y) {
523  index_t j = index[y];
524  if (clip) {
525  if (j <= 0) j = 0;
526  else if (j >= K) j = K - 1;
527  } else {
528  j %= K;
529  if (j < 0) j += K;
530  }
531  for (index_t i = 0; i < C; ++i) {
532  dst[j][i] += src[y][i];
533  }
534  }
535 }
536 
537 template<typename IndexType, typename DType>
538 inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
539  const Tensor<cpu, 1, IndexType>& sorted,
540  const Tensor<cpu, 1, IndexType>& index,
541  const Tensor<cpu, 2, DType> &src) {
542  for (index_t y = 0; y < sorted.size(0); ++y) {
543  dst[sorted[y]] += src[index[y]];
544  }
545 }
546 
547 template<typename IndexType, typename DType>
548 inline void IndexFill(Tensor<cpu, 2, DType> dst,
549  const Tensor<cpu, 1, IndexType>& index,
550  const Tensor<cpu, 2, DType> &src) {
551  for (index_t y = 0; y < index.size(0); ++y) {
552  for (index_t j = 0; j < src.size(1); j++) {
553  dst[index[y]][j] = src[y][j];
554  }
555  }
556 }
557 
558 template<typename KDType, typename VDType>
559 inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
560  bool is_ascend) {
561  CHECK_EQ(keys.CheckContiguous(), true);
562  CHECK_EQ(values.CheckContiguous(), true);
563  CHECK_EQ(keys.size(0), values.size(0))
564  << "The sizes of key/value are not equal! keys_size: " << keys.size(0)
565  << "values_size: " << values.size(0);
566  std::vector<size_t> idx(keys.size(0));
567  std::vector<KDType> keys_vec(keys.size(0));
568  std::vector<VDType> values_vec(values.size(0));
569  for (int i = 0; i < keys.size(0); i++) {
570  idx[i] = i;
571  keys_vec[i] = keys[i];
572  values_vec[i] = values[i];
573  }
574  if (is_ascend) {
575  std::stable_sort(idx.begin(), idx.end(),
576  [&keys_vec](size_t i1, size_t i2)
577  {return keys_vec[i1] < keys_vec[i2]; });
578  } else {
579  std::stable_sort(idx.begin(), idx.end(),
580  [&keys_vec](size_t i1, size_t i2)
581  {return keys_vec[i1] > keys_vec[i2]; });
582  }
583  for (index_t i = 0; i < values.size(0); i++) {
584  keys[i] = keys_vec[idx[i]];
585  values[i] = values_vec[idx[i]];
586  }
587 }
588 
589 template<typename Device, typename VDType, typename SDType>
590 inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments) {
591  // We can sort each segment using two stable sorts
592  SortByKey(values, segments, true);
593  SortByKey(segments, values, true);
594 }
595 
596 // blas related
597 template<typename Device, typename DType>
598 inline void VectorDot(Tensor<Device, 1, DType> dst,
599  const Tensor<Device, 1, DType> &lhs,
600  const Tensor<Device, 1, DType> &rhs) {
601  CHECK_EQ(lhs.size(0), rhs.size(0))
602  << "VectorDot: Shape mismatch";
603  CHECK_EQ(dst.size(0), 1U)
604  << "VectorDot: expect dst to be scalar";
605  expr::BLASEngine<Device, DType>::SetStream(lhs.stream_);
606  mshadow::expr::BLASEngine<Device, DType>::dot(
607  lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_);
608 }
609 
610 template<bool transpose_left, bool transpose_right, typename Device, typename DType>
611 inline void BatchGEMM(Tensor<Device, 3, DType> dst,
612  const Tensor<Device, 3, DType> &lhs,
613  const Tensor<Device, 3, DType> &rhs,
614  DType alpha,
615  DType beta,
616  Tensor<Device, 1, DType*> workspace) {
617  index_t batch_size = dst.shape_[0];
618  expr::BLASEngine<Device, DType>::SetStream(dst.stream_);
619  Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1])
620  : lhs.shape_;
621  Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1])
622  : rhs.shape_;
623  CHECK_EQ(dst.CheckContiguous(), true);
624  CHECK_EQ(lhs.CheckContiguous(), true);
625  CHECK_EQ(rhs.CheckContiguous(), true);
626  CHECK(sleft[0] == batch_size && sright[0] == batch_size)
627  << "BatchGEMM: batchsize must be equal."
628  << "dst: " << dst.shape_ << "\n"
629  << "lhs: " << sleft << "\n"
630  << "rhs: " << sright << "\n";
631  CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
632  << "BatchGEMM: matrix shape mismatch"
633  << "dst: " << dst.shape_ << "\n"
634  << "lhs: " << sleft << "\n"
635  << "rhs: " << sright << "\n";
636  CHECK(workspace.size(0) >= 3 * batch_size)
637  << "Workspace Size must be bigger than " << 3 * batch_size;
638  CHECK_EQ(workspace.CheckContiguous(), true);
639  // use column major arguments to be compatible with most BLAS
640  expr::BLASEngine<Device, DType>::batched_gemm
641  (dst.stream_,
642  transpose_right, transpose_left,
643  transpose_right ? rhs.size(1) : rhs.size(2),
644  transpose_left ? lhs.size(2) : lhs.size(1),
645  transpose_right ? rhs.size(2) : rhs.size(1),
646  alpha,
647  rhs.dptr_, rhs.stride_,
648  lhs.dptr_, lhs.stride_,
649  beta,
650  dst.dptr_, dst.stride_, batch_size,
651  workspace.dptr_);
652 }
653 } // namespace mshadow
654 #endif // MSHADOW_TENSOR_CPU_INL_H_
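
Usage note (not part of the header above): the following standalone sketch shows one way the CPU routines defined in this file are typically driven. It is a minimal, hypothetical example, assuming the mshadow headers are on the include path and the library is configured for CPU use with a BLAS backend selected at build time (e.g. MSHADOW_USE_CBLAS); the tensor names are illustrative only.

#include "mshadow/tensor.h"  // tensor.h includes tensor_cpu-inl.h

int main() {
  using namespace mshadow;
  InitTensorEngine<cpu>(0);  // no-op on CPU, kept for symmetry with the GPU path
  // NewTensor = AllocSpace + MapExp<sv::saveto>(ScalarExp): allocate and fill with 0.
  Tensor<cpu, 2, float> energy = NewTensor<cpu>(Shape2(4, 10), 0.0f);
  Tensor<cpu, 2, float> prob = NewTensor<cpu>(Shape2(4, 10), 0.0f);
  energy = 1.0f;            // scalar assignment is routed through MapExp
  Softmax(prob, energy);    // row-wise softmax defined in this file
  Copy(energy, prob);       // memcpy fast path when both tensors are contiguous
  FreeSpace(&energy);
  FreeSpace(&prob);
  ShutdownTensorEngine<cpu>();
  return 0;
}

Because AllocSpace pads the lowest dimension by default, stride_ can exceed size(1); the routines above (Copy, MapPlan, the softmax kernels) account for this through FlatTo2D and stride_, so the example works with or without padding.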