tensor_cpu-inl.h
/*!
 *  Copyright (c) 2014 by Contributors
 * \file tensor_cpu-inl.h
 * \brief implementation of CPU host code
 */
#ifndef MSHADOW_TENSOR_CPU_INL_H_
#define MSHADOW_TENSOR_CPU_INL_H_
#include <cstring>
#include <functional>
#include <utility>
#include <vector>
#include <algorithm>  // for std::stable_sort used by SortByKey below
#include "./base.h"
#include "./tensor.h"
#include "./packet-inl.h"
#include "./dot_engine-inl.h"

namespace mshadow {
template<>
inline void InitTensorEngine<cpu>(int dev_id) {
}
template<>
inline void ShutdownTensorEngine<cpu>(void) {
}

template<>
inline void SetDevice<cpu>(int devid) {
}
template<>
inline Stream<cpu> *NewStream<cpu>(bool create_blas_handle,
                                   bool create_dnn_handle,
                                   int dev_id) {
  return new Stream<cpu>();
}
template<>
inline void DeleteStream<cpu>(Stream<cpu> *stream) {
  delete stream;
}

template<int ndim>
inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape) { // NOLINT(*)
  os << '(';
  for (int i = 0; i < ndim; ++i) {
    if (i != 0) os << ',';
    os << shape[i];
  }
  // python style tuple
  if (ndim == 1) os << ',';
  os << ')';
  return os;
}

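// Example (illustrative sketch, not part of the original header): streaming a
// Shape uses the operator<< above and prints a Python-style tuple:
//   Shape<3> s = Shape3(2, 3, 4);
//   std::cout << s << std::endl;          // prints (2,3,4)
//   std::cout << Shape1(5) << std::endl;  // prints (5,) -- trailing comma for 1-D
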
template<typename xpu>
inline void *AllocHost_(size_t size);
template<typename xpu>
inline void FreeHost_(void * dptr);

#ifdef __CUDACC__
template<>
inline void *AllocHost_<gpu>(size_t size) {
  void *dptr;
  MSHADOW_CUDA_CALL(cudaMallocHost(&dptr, size, cudaHostAllocPortable));
  return dptr;
}
template<>
inline void FreeHost_<gpu>(void *dptr) {
  MSHADOW_CUDA_CALL(cudaFreeHost(dptr));
}
#endif

template<>
inline void *AllocHost_<cpu>(size_t size) {
  size_t pitch;
  return packet::AlignedMallocPitch(&pitch, size, 1);
}
template<>
inline void FreeHost_<cpu>(void *dptr) {
  packet::AlignedFree(dptr);
}

template<typename xpu, int dim, typename DType>
inline void AllocHost(Tensor<cpu, dim, DType> *obj) {
  obj->stride_ = obj->size(dim - 1);
  CHECK_EQ(obj->CheckContiguous(), true) << "AllocHost";
  void *dptr = AllocHost_<xpu>(obj->MSize() * sizeof(DType));
  obj->dptr_ = reinterpret_cast<DType*>(dptr);
}
template<typename xpu, int dim, typename DType>
inline void FreeHost(Tensor<cpu, dim, DType> *obj) {
  if (obj->dptr_ == NULL) {
    LOG(FATAL) << "FreeHost:: double free";
  }
  FreeHost_<xpu>(obj->dptr_);
  obj->dptr_ = NULL;
}

template<int dim, typename DType>
inline void AllocSpace(Tensor<cpu, dim, DType> *obj, bool pad) {
  size_t pitch;
  void *dptr;
  if (pad) {
    dptr = packet::AlignedMallocPitch
        (&pitch, obj->size(dim - 1) * sizeof(DType), obj->shape_.FlatTo2D()[0]);
    obj->stride_ = static_cast<index_t>(pitch / sizeof(DType));
  } else {
    obj->stride_ = obj->size(dim - 1);
    dptr = packet::AlignedMallocPitch
        (&pitch, obj->shape_.Size() * sizeof(DType), 1);
  }
  obj->dptr_ = reinterpret_cast<DType*>(dptr);
}
template<typename Device, typename DType, int dim>
inline Tensor<Device, dim, DType>
NewTensor(const Shape<dim> &shape, DType initv, bool pad, Stream<Device> *stream_) {
  Tensor<Device, dim, DType> obj(shape);
  obj.stream_ = stream_;
  AllocSpace(&obj, pad);
  MapExp<sv::saveto>(&obj, expr::ScalarExp<DType>(initv));
  return obj;
}
template<int dim, typename DType>
inline void FreeSpace(Tensor<cpu, dim, DType> *obj) {
  packet::AlignedFree(obj->dptr_);
  obj->dptr_ = NULL;
}
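
// Example (illustrative sketch, not part of the original header): the usual
// CPU tensor lifecycle is AllocSpace/NewTensor followed by FreeSpace.
//   Tensor<cpu, 2, float> mat = NewTensor<cpu>(Shape2(2, 5), 0.0f);  // allocated and zero-filled
//   mat[1][3] = 1.0f;
//   FreeSpace(&mat);  // releases the buffer and sets mat.dptr_ to NULL
// NewTensor<cpu>(shape, initv) deduces DType from initv; the pad and stream
// arguments default to MSHADOW_ALLOC_PAD and NULL respectively.
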
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> _dst,
                 const Tensor<cpu, dim, DType> &_src,
                 Stream<cpu> *stream) {
  CHECK_EQ(_dst.shape_, _src.shape_)
      << "Copy:shape mismatch:" << _dst.shape_ << " vs " << _src.shape_;
  if (_dst.CheckContiguous() && _src.CheckContiguous()) {
    memcpy(_dst.dptr_, _src.dptr_, sizeof(DType) * _dst.shape_.Size());
  } else {
    Tensor<cpu, 2, DType> dst = _dst.FlatTo2D();
    Tensor<cpu, 2, DType> src = _src.FlatTo2D();
    for (index_t y = 0; y < dst.size(0); ++y) {
      memcpy(dst[y].dptr_, src[y].dptr_, sizeof(DType) * dst.size(1));
    }
  }
}

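// Example (illustrative sketch, not part of the original header): Copy works
// on same-shaped CPU tensors and falls back to a row-wise memcpy when either
// side is padded (non-contiguous stride).
//   Tensor<cpu, 2, float> a = NewTensor<cpu>(Shape2(3, 4), 1.0f);
//   Tensor<cpu, 2, float> b = NewTensor<cpu>(Shape2(3, 4), 0.0f);
//   Copy(b, a);  // the stream argument defaults to NULL on CPU
//   FreeSpace(&a); FreeSpace(&b);
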
template<typename Saver, typename R, int dim,
         typename DType, typename E>
inline void MapPlan(TRValue<R, cpu, dim, DType> *dst,
                    const expr::Plan<E, DType> &plan) {
  Shape<2> shape = expr::ShapeCheck<dim, R>::Check(dst->self()).FlatTo2D();
  expr::Plan<R, DType> dplan = expr::MakePlan(dst->self());
#ifndef __CUDACC__
  #pragma omp parallel for
#endif
  // temp remove openmp, as default setting throttles CPU
  for (openmp_index_t y = 0; y < shape[0]; ++y) {
    for (index_t x = 0; x < shape[1]; ++x) {
      // trust your compiler! -_- they will optimize it
      Saver::template Save<DType>(dplan.REval(y, x), plan.Eval(y, x));
    }
  }
}
// code to handle SSE optimization
template<bool pass_check, typename Saver,
         typename R, int dim,
         typename DType, typename E, int etype>
struct MapExpCPUEngine {
  inline static void Map(TRValue<R, cpu, dim, DType> *dst,
                         const expr::Exp<E, DType, etype> &exp) {
    MapPlan<Saver>(dst, MakePlan(exp.self()));
  }
};

template<typename SV, int dim, typename DType, typename E, int etype>
struct MapExpCPUEngine<true, SV, Tensor<cpu, dim, DType>,
                       dim, DType, E, etype> {
  inline static void Map(Tensor<cpu, dim, DType> *dst,
                         const expr::Exp<E, DType, etype> &exp) {
    if (expr::PacketAlignCheck<dim, E, MSHADOW_DEFAULT_PACKET>::Check(exp.self()) &&
        expr::PacketAlignCheck<dim, Tensor<cpu, dim, DType>, MSHADOW_DEFAULT_PACKET>::Check(*dst)) {
      expr::MapPacketPlan<SV>(dst->self(),
                              expr::MakePacketPlan<MSHADOW_DEFAULT_PACKET>(exp.self()));
    } else {
      MapPlan<SV>(dst, MakePlan(exp.self()));
    }
  }
};

template<typename Saver, typename R, int dim,
         typename DType, typename E, int etype>
inline void MapExp(TRValue<R, cpu, dim, DType> *dst,
                   const expr::Exp<E, DType, etype> &exp) {
  expr::TypeCheckPass<expr::TypeCheck<cpu, dim, DType, E>::kMapPass>
      ::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
  Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self());
  Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self());
  CHECK(eshape[0] == 0 || eshape == dshape)
      << "Assignment: Shape of Tensors are not consistent with target, "
      << "eshape: " << eshape << " dshape:" << dshape;
  MapExpCPUEngine<expr::PacketCheck<E, MSHADOW_DEFAULT_PACKET>::kPass,
                  Saver, R, dim, DType, E, etype>
      ::Map(dst->ptrself(), exp);
}

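// Example (illustrative sketch, not part of the original header): MapExp is
// what expression-template assignment ultimately calls on CPU.
//   Tensor<cpu, 2, float> a = NewTensor<cpu>(Shape2(2, 3), 1.0f);
//   Tensor<cpu, 2, float> b = NewTensor<cpu>(Shape2(2, 3), 2.0f);
//   Tensor<cpu, 2, float> c = NewTensor<cpu>(Shape2(2, 3), 0.0f);
//   c = a * 2.0f + b;   // ultimately dispatches to MapExp<sv::saveto>(&c, ...)
//   c += a;             // ultimately dispatches to MapExp<sv::plusto>(&c, a)
// MapExpCPUEngine then picks the packet (SSE) path when the expression and the
// destination are packet-aligned, and falls back to the scalar MapPlan loop
// otherwise.
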
template<typename Saver, typename Reducer,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst,
                                const expr::Exp<E, DType, etype> &exp,
                                DType scale) {
  expr::TypeCheckPass<expr::TypeCheck<cpu, 1, DType, E>::kRedPass>
      ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
  Shape<2> eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
      ::Check(exp.self()).FlatTo2D();
  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(eshape[1], dshape[0]) << "MapReduceKeepLowest::reduction dimension do not match";
  CHECK_NE(eshape[0], 0U) << "can not reduce over empty tensor";
  // execution
  expr::Plan<R, DType> dplan = MakePlan(dst->self());
  expr::Plan<E, DType> splan = MakePlan(exp.self());
#ifndef __CUDACC__
  #pragma omp parallel for
#endif
  for (openmp_index_t x = 0; x < eshape[1]; ++x) {
    DType res = splan.Eval(0, x);
    for (index_t y = 1; y < eshape[0]; ++y) {
      Reducer::Reduce(res, splan.Eval(y, x));
    }
    Saver::template Save<DType>(dplan.REval(0, x), res * scale);
  }
}

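// Example (illustrative sketch, not part of the original header): this is the
// CPU backend of reductions such as sum_rows. Reducing a 2-D tensor over its
// rows into a 1-D tensor can be written as
//   Tensor<cpu, 2, float> mat = NewTensor<cpu>(Shape2(4, 3), 1.0f);
//   Tensor<cpu, 1, float> col = NewTensor<cpu>(Shape1(3), 0.0f);
//   col = sum_rows(mat);  // dispatches to MapReduceKeepLowest<sv::saveto, red::sum>
// The expression is flattened to 2-D, dimension 0 is reduced with Reducer, and
// the result is multiplied by `scale` before being stored through Saver.
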
template<typename Saver, typename Reducer, int dimkeep,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst,
                                 const expr::Exp<E, DType, etype> &exp,
                                 DType scale) {
  expr::TypeCheckPass<expr::TypeCheck<cpu, dimkeep, DType, E>::kRedPass>
      ::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
  typedef Shape<expr::ExpInfo<E>::kDim> EShape;
  EShape eshape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
      ::Check(exp.self());
  Shape<1> dshape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(eshape[dimkeep], dshape[0])
      << "MapReduceKeepHighDim::reduction dimension do not match";
  // use equivalent form
  Shape<4> pshape = Shape4(eshape.ProdShape(0, dimkeep),
                           eshape[dimkeep],
                           eshape.ProdShape(dimkeep + 1, EShape::kSubdim),
                           eshape[EShape::kSubdim]);
  // execution
  expr::Plan<R, DType> dplan = MakePlan(dst->self());
  expr::Plan<E, DType> splan = MakePlan(exp.self());
#ifndef __CUDACC__
  #pragma omp parallel for
#endif
  for (openmp_index_t c = 0; c < pshape[1]; ++c) {
    DType res; Reducer::SetInitValue(res);
    for (index_t n = 0; n < pshape[0]; ++n) {
      DType tres; Reducer::SetInitValue(tres);
      for (index_t y = 0; y < pshape[2]; ++y) {
        for (index_t x = 0; x < pshape[3]; ++x) {
          Reducer::Reduce(tres,
                          splan.Eval((n * pshape[1] + c) * pshape[2] + y, x));
        }
      }
      Reducer::Reduce(res, tres);
    }
    Saver::template Save<DType>(dplan.REval(0, c), DType(res * scale));
  }
}

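// Example (illustrative sketch, not part of the original header): this is the
// CPU backend of sumall_except_dim<dimkeep>, e.g. summing a 4-D activation over
// every axis except the channel axis:
//   Tensor<cpu, 4, float> data      = NewTensor<cpu>(Shape4(2, 8, 5, 5), 1.0f);
//   Tensor<cpu, 1, float> bias_grad = NewTensor<cpu>(Shape1(8), 0.0f);
//   bias_grad = sumall_except_dim<1>(data);
//   // dispatches to MapReduceKeepHighDim<sv::saveto, red::sum, 1>
// The expression shape is viewed in the equivalent 4-D form
// (prod(dims before dimkeep), eshape[dimkeep], prod(dims after dimkeep), last dim)
// and everything except the kept axis is reduced.
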
template<typename DType>
inline void Softmax(Tensor<cpu, 1, DType> dst,
                    const Tensor<cpu, 1, DType> &energy) {
  DType mmax = energy[0];
  for (index_t x = 1; x < dst.size(0); ++x) {
    if (mmax < energy[x]) mmax = energy[x];
  }
  DType sum = DType(0.0f);
  for (index_t x = 0; x < dst.size(0); ++x) {
    dst[x] = std::exp(energy[x] - mmax);
    sum += dst[x];
  }
  for (index_t x = 0; x < dst.size(0); ++x) {
    dst[x] /= sum;
  }
}

template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
                        const Tensor<cpu, 2, DType> &src,
                        const Tensor<cpu, 1, DType> &label) {
#pragma omp parallel for
  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
    const index_t k = static_cast<int>(label[y]);
    for (index_t x = 0; x < dst.size(1); ++x) {
      if (x == k) {
        dst[y][k] = src[y][k] - 1.0f;
      } else {
        dst[y][x] = src[y][x];
      }
    }
  }
}

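// Example (illustrative sketch, not part of the original header): given the
// softmax output `prob` and class labels stored as DType, SoftmaxGrad computes
// the cross-entropy gradient prob - one_hot(label):
//   Tensor<cpu, 2, float> prob  = NewTensor<cpu>(Shape2(batch, nclass), 0.0f);
//   Tensor<cpu, 2, float> grad  = NewTensor<cpu>(Shape2(batch, nclass), 0.0f);
//   Tensor<cpu, 1, float> label = NewTensor<cpu>(Shape1(batch), 0.0f);
//   Softmax(prob, logits);          // logits: [batch, nclass] scores
//   SoftmaxGrad(grad, prob, label); // grad[y][x] = prob[y][x] - (x == label[y])
// `batch`, `nclass` and `logits` are placeholders used for illustration only.
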
template<typename DType>
inline void SmoothSoftmaxGrad(Tensor<cpu, 2, DType> dst,
                              const Tensor<cpu, 2, DType> &src,
                              const Tensor<cpu, 1, DType> &label,
                              const float alpha) {
  const float smooth_grad = (alpha / (dst.size(1) - 1));
#pragma omp parallel for
  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
    const index_t k = static_cast<int>(label[y]);
    for (index_t x = 0; x < dst.size(1); ++x) {
      if (x == k) {
        dst[y][k] = src[y][k] - 1.0f + alpha;
      } else {
        dst[y][x] = src[y][x] - smooth_grad;
      }
    }
  }
}

template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
                        const Tensor<cpu, 2, DType> &src,
                        const Tensor<cpu, 1, DType> &label,
                        const DType &ignore_label) {
#pragma omp parallel for
  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
    const int k = static_cast<int>(label[y]);
    for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
      if (static_cast<int>(ignore_label) == k) {
        dst[y][x] = 0.0f;
      } else {
        if (x == k) {
          dst[y][k] = src[y][k] - 1.0f;
        } else {
          dst[y][x] = src[y][x];
        }
      }
    }
  }
}

template<typename DType>
inline void SmoothSoftmaxGrad(Tensor<cpu, 2, DType> dst,
                              const Tensor<cpu, 2, DType> &src,
                              const Tensor<cpu, 1, DType> &label,
                              const DType &ignore_label,
                              const float alpha) {
  const float smooth_grad = (alpha / (dst.size(1) - 1));
#pragma omp parallel for
  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
    const int k = static_cast<int>(label[y]);
    for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
      if (static_cast<int>(ignore_label) == k) {
        dst[y][x] = 0.0f;
      } else {
        if (x == k) {
          dst[y][k] = src[y][k] - 1.0f + alpha;
        } else {
          dst[y][x] = src[y][x] - smooth_grad;
        }
      }
    }
  }
}

template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
                        const Tensor<cpu, 3, DType> &src,
                        const Tensor<cpu, 2, DType> &label) {
#pragma omp parallel for
  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
    for (index_t y = 0; y < dst.size(0); ++y) {
      const int k = static_cast<int>(label[y][n]);
      for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
        if (x == k) {
          dst[y][k][n] = src[y][k][n] - 1.0f;
        } else {
          dst[y][x][n] = src[y][x][n];
        }
      }
    }
  }
}

template<typename DType>
inline void SmoothSoftmaxGrad(Tensor<cpu, 3, DType> dst,
                              const Tensor<cpu, 3, DType> &src,
                              const Tensor<cpu, 2, DType> &label,
                              const float alpha) {
  const float smooth_grad = (alpha / (dst.size(1) - 1));
#pragma omp parallel for
  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
    for (index_t y = 0; y < dst.size(0); ++y) {
      const int k = static_cast<int>(label[y][n]);
      for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
        if (x == k) {
          dst[y][k][n] = src[y][k][n] - 1.0f + alpha;
        } else {
          dst[y][x][n] = src[y][x][n] - smooth_grad;
        }
      }
    }
  }
}

template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
                        const Tensor<cpu, 3, DType> &src,
                        const Tensor<cpu, 2, DType> &label,
                        const DType &ignore_label) {
#pragma omp parallel for
  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
    for (index_t y = 0; y < dst.size(0); ++y) {
      const int k = static_cast<int>(label[y][n]);
      if (k == static_cast<int>(ignore_label)) {
        for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
          dst[y][x][n] = DType(0.0f);
        }
      } else {
        for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
          if (x == k) {
            dst[y][k][n] = src[y][k][n] - 1.0f;
          } else {
            dst[y][x][n] = src[y][x][n];
          }
        }
      }
    }
  }
}

template<typename DType>
inline void SmoothSoftmaxGrad(Tensor<cpu, 3, DType> dst,
                              const Tensor<cpu, 3, DType> &src,
                              const Tensor<cpu, 2, DType> &label,
                              const DType &ignore_label,
                              const float alpha) {
  const float smooth_grad = (alpha / (dst.size(1) - 1));
#pragma omp parallel for
  for (openmp_index_t n = 0; n < dst.size(2); ++n) {
    for (index_t y = 0; y < dst.size(0); ++y) {
      const int k = static_cast<int>(label[y][n]);
      if (k == static_cast<int>(ignore_label)) {
        for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
          dst[y][x][n] = DType(0.0f);
        }
      } else {
        for (int x = 0; x < static_cast<int>(dst.size(1)); ++x) {
          if (x == k) {
            dst[y][k][n] = src[y][k][n] - 1.0f + alpha;
          } else {
            dst[y][x][n] = src[y][x][n] - smooth_grad;
          }
        }
      }
    }
  }
}

template<typename DType>
inline void Softmax(Tensor<cpu, 2, DType> dst,
                    const Tensor<cpu, 2, DType> &energy) {
  CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
#pragma omp parallel for
  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
    Softmax(dst[y], energy[y]);
  }
}

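// Example (illustrative sketch, not part of the original header): the 2-D
// overload normalizes each row independently, reusing the max-subtraction
// trick of the 1-D version above for numerical stability.
//   Tensor<cpu, 2, float> energy = NewTensor<cpu>(Shape2(2, 3), 0.0f);
//   Tensor<cpu, 2, float> prob   = NewTensor<cpu>(Shape2(2, 3), 0.0f);
//   energy[0][2] = 4.0f;    // make one logit dominant
//   Softmax(prob, energy);  // every row of prob now sums to 1
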
template<typename DType>
inline void Softmax(Tensor<cpu, 3, DType> dst,
                    const Tensor<cpu, 3, DType> &energy) {
  CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
#pragma omp parallel for
  for (openmp_index_t y = 0; y < dst.size(0); ++y) {
    for (index_t n = 0; n < dst.size(2); ++n) {
      DType mmax = energy[y][0][n];
      for (index_t x = 1; x < dst.size(1); ++x) {
        if (mmax < energy[y][x][n]) mmax = energy[y][x][n];
      }
      DType sum = DType(0.0f);
      for (index_t x = 0; x < dst.size(1); ++x) {
        dst[y][x][n] = std::exp(energy[y][x][n] - mmax);
        sum += dst[y][x][n];
      }
      for (index_t x = 0; x < dst.size(1); ++x) {
        dst[y][x][n] /= sum;
      }
    }
  }
}

template<bool clip, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
                        const Tensor<cpu, 1, IndexType>& index,
                        const Tensor<cpu, 2, DType> &src) {
  const int K = dst.shape_[0];
  for (index_t y = 0; y < index.size(0); ++y) {
    int j = index[y];
    if (clip) {
      if (j <= 0) j = 0;
      else if (j >= K) j = K - 1;
    } else {
      j %= K;
      if (j < 0) j += K;
    }
    dst[j] += src[y];
  }
}

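// Example (illustrative sketch, not part of the original header): AddTakeGrad
// accumulates the gradient of an embedding lookup, dst[index[i]] += src[i].
//   Tensor<cpu, 2, float> weight_grad = NewTensor<cpu>(Shape2(vocab, dim), 0.0f);
//   Tensor<cpu, 1, int>   idx         = NewTensor<cpu>(Shape1(batch), 0);
//   Tensor<cpu, 2, float> out_grad    = NewTensor<cpu>(Shape2(batch, dim), 1.0f);
//   AddTakeGrad<true>(weight_grad, idx, out_grad);   // clip out-of-range indices
//   AddTakeGrad<false>(weight_grad, idx, out_grad);  // wrap them (modulo) instead
// `vocab`, `dim` and `batch` are placeholders used for illustration only.
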
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
                                  const Tensor<cpu, 1, IndexType>& sorted,
                                  const Tensor<cpu, 1, IndexType>& index,
                                  const Tensor<cpu, 2, DType> &src) {
  for (index_t y = 0; y < sorted.size(0); ++y) {
    dst[sorted[y]] += src[index[y]];
  }
}

template<typename IndexType, typename DType>
inline void IndexFill(Tensor<cpu, 2, DType> dst,
                      const Tensor<cpu, 1, IndexType>& index,
                      const Tensor<cpu, 2, DType> &src) {
  for (index_t y = 0; y < index.size(0); ++y) {
    for (index_t j = 0; j < src.size(1); j++) {
      dst[index[y]][j] = src[y][j];
    }
  }
}

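// Example (illustrative sketch, not part of the original header): IndexFill
// overwrites selected rows of dst with the rows of src, dst[index[i]] = src[i].
//   Tensor<cpu, 2, float> table = NewTensor<cpu>(Shape2(10, 4), 0.0f);
//   Tensor<cpu, 1, int>   rows  = NewTensor<cpu>(Shape1(2), 0);
//   Tensor<cpu, 2, float> vals  = NewTensor<cpu>(Shape2(2, 4), 1.0f);
//   rows[0] = 3; rows[1] = 7;
//   IndexFill(table, rows, vals);  // rows 3 and 7 of table become all ones
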
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
                      bool is_ascend) {
  CHECK_EQ(keys.CheckContiguous(), true);
  CHECK_EQ(values.CheckContiguous(), true);
  CHECK_EQ(keys.size(0), values.size(0))
      << "The sizes of key/value are not equal! keys_size: " << keys.size(0)
      << " values_size: " << values.size(0);
  std::vector<size_t> idx(keys.size(0));
  std::vector<KDType> keys_vec(keys.size(0));
  std::vector<VDType> values_vec(values.size(0));
  for (int i = 0; i < keys.size(0); i++) {
    idx[i] = i;
    keys_vec[i] = keys[i];
    values_vec[i] = values[i];
  }
  if (is_ascend) {
    std::stable_sort(idx.begin(), idx.end(),
                     [&keys_vec](size_t i1, size_t i2)
                     {return keys_vec[i1] < keys_vec[i2]; });
  } else {
    std::stable_sort(idx.begin(), idx.end(),
                     [&keys_vec](size_t i1, size_t i2)
                     {return keys_vec[i1] > keys_vec[i2]; });
  }
  for (index_t i = 0; i < values.size(0); i++) {
    keys[i] = keys_vec[idx[i]];
    values[i] = values_vec[idx[i]];
  }
}

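// Example (illustrative sketch, not part of the original header): SortByKey
// performs a stable key-value sort in place on two parallel 1-D tensors.
//   Tensor<cpu, 1, float> keys = NewTensor<cpu>(Shape1(4), 0.0f);
//   Tensor<cpu, 1, int>   vals = NewTensor<cpu>(Shape1(4), 0);
//   keys[0] = 3; keys[1] = 1; keys[2] = 2; keys[3] = 1;
//   vals[0] = 0; vals[1] = 1; vals[2] = 2; vals[3] = 3;
//   SortByKey(keys, vals);        // ascending by default: keys = 1,1,2,3; vals = 1,3,2,0
//   SortByKey(keys, vals, false); // descending order
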
template<typename Device, typename VDType, typename SDType>
inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments) {
  // We can sort each segment using two stable sorts
  SortByKey(values, segments, true);
  SortByKey(segments, values, true);
}

// blas related
template<typename Device, typename DType>
inline void VectorDot(Tensor<Device, 1, DType> dst,
                      const Tensor<Device, 1, DType> &lhs,
                      const Tensor<Device, 1, DType> &rhs) {
  CHECK_EQ(lhs.size(0), rhs.size(0))
      << "VectorDot: Shape mismatch";
  CHECK_EQ(dst.size(0), 1U)
      << "VectorDot: expect dst to be scalar";
  expr::BLASEngine<Device, DType>::SetStream(lhs.stream_);
  mshadow::expr::BLASEngine<Device, DType>::dot(
      lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_);
}

template<bool transpose_left, bool transpose_right, typename Device, typename DType>
inline void BatchGEMM(Tensor<Device, 3, DType> dst,
                      const Tensor<Device, 3, DType> &lhs,
                      const Tensor<Device, 3, DType> &rhs,
                      DType alpha,
                      DType beta,
                      Tensor<Device, 1, DType*> workspace) {
  index_t batch_size = dst.shape_[0];
  expr::BLASEngine<Device, DType>::SetStream(dst.stream_);
  Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1])
      : lhs.shape_;
  Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1])
      : rhs.shape_;
  CHECK_EQ(dst.CheckContiguous(), true);
  CHECK_EQ(lhs.CheckContiguous(), true);
  CHECK_EQ(rhs.CheckContiguous(), true);
  CHECK(sleft[0] == batch_size && sright[0] == batch_size)
      << "BatchGEMM: batchsize must be equal. "
      << "dst: " << dst.shape_ << "\n"
      << "lhs: " << sleft << "\n"
      << "rhs: " << sright << "\n";
  CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
      << "BatchGEMM: matrix shape mismatch. "
      << "dst: " << dst.shape_ << "\n"
      << "lhs: " << sleft << "\n"
      << "rhs: " << sright << "\n";
  CHECK(workspace.size(0) >= 3 * batch_size)
      << "Workspace Size must be bigger than " << 3 * batch_size;
  CHECK_EQ(workspace.CheckContiguous(), true);
  // use column major argument to be compatible with most BLAS
  expr::BLASEngine<Device, DType>::batched_gemm
      (dst.stream_,
       transpose_right, transpose_left,
       transpose_right ? rhs.size(1) : rhs.size(2),
       transpose_left ? lhs.size(2) : lhs.size(1),
       transpose_right ? rhs.size(2) : rhs.size(1),
       alpha,
       rhs.dptr_, rhs.stride_,
       lhs.dptr_, lhs.stride_,
       beta,
       dst.dptr_, dst.stride_, batch_size,
       workspace.dptr_);
}
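
// Example (illustrative sketch, not part of the original header): BatchGEMM
// computes dst = alpha * op(lhs) * op(rhs) + beta * dst for every matrix in the
// batch; the workspace must hold at least 3 * batch_size pointers for the
// batched BLAS call.
//   const index_t batch = 4;
//   Tensor<cpu, 3, float> A = NewTensor<cpu>(Shape3(batch, 2, 3), 1.0f);
//   Tensor<cpu, 3, float> B = NewTensor<cpu>(Shape3(batch, 3, 5), 1.0f);
//   Tensor<cpu, 3, float> C = NewTensor<cpu>(Shape3(batch, 2, 5), 0.0f);
//   Tensor<cpu, 1, float*> ws(Shape1(3 * batch));
//   AllocSpace(&ws);
//   BatchGEMM<false, false>(C, A, B, 1.0f, 0.0f, ws);  // C[i] = A[i] * B[i]
//   FreeSpace(&A); FreeSpace(&B); FreeSpace(&C); FreeSpace(&ws);
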
}  // namespace mshadow
#endif  // MSHADOW_TENSOR_CPU_INL_H_