tensor.h
#ifndef MSHADOW_TENSOR_H_
#define MSHADOW_TENSOR_H_
#include <string>
#include <iostream>
#include "./base.h"
#include "./expression.h"

namespace mshadow {
/*! \brief device name CPU */
struct cpu {
  /*! \brief whether this device is CPU or not */
  static const bool kDevCPU = true;
  /*! \brief device flag number, identifies this device */
  static const int kDevMask = 1 << 0;
};
/*! \brief device name GPU */
struct gpu {
  /*! \brief whether this device is CPU or not */
  static const bool kDevCPU = false;
  /*! \brief device flag number, identifies this device */
  static const int kDevMask = 1 << 1;
};
template<int ndim>
struct Shape;

/*!
 * \brief allow string printing of the shape
 * \param os the output stream
 * \param shape the shape
 * \return the ostream
 */
template<int ndim>
inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape);  // NOLINT(*)

/*!
 * \brief shape of a tensor
 * \tparam dimension dimension of tensor
 */
template<int dimension>
struct Shape {
  /*! \brief dimension of current shape */
  static const int kDimension = dimension;
  /*! \brief dimension of current shape minus one */
  static const int kSubdim = dimension - 1;
  /*! \brief storing the dimension information */
  index_t shape_[kDimension];
  /*! \brief default constructor, do nothing */
  MSHADOW_XINLINE Shape(void) {}
  /*! \brief constructor, copy the dimensions of an existing shape */
  MSHADOW_XINLINE Shape(const Shape<kDimension> &s) {
    #pragma unroll
    for (int i = 0; i < kDimension; ++i) {
      this->shape_[i] = s[i];
    }
  }
  /*! \brief get corresponding index */
  MSHADOW_XINLINE index_t &operator[](int idx) {
    return shape_[idx];
  }
  /*! \brief get corresponding index */
  MSHADOW_XINLINE const index_t &operator[](int idx) const {
    return shape_[idx];
  }
  /*! \brief whether two shapes are equal */
  MSHADOW_XINLINE bool operator==(const Shape<kDimension> &s) const {
    #pragma unroll
    for (int i = 0; i < kDimension; ++i) {
      if (s.shape_[i] != this->shape_[i]) return false;
    }
    return true;
  }
  /*! \brief whether two shapes are not equal */
  MSHADOW_XINLINE bool operator!=(const Shape<kDimension> &s) const {
    return !(*this == s);
  }
  /*! \brief flatten the shape into one dimension, whose size is the total element count */
  MSHADOW_XINLINE Shape<1> FlatTo1D(void) const {
    Shape<1> s;
    s[0] = this->Size();
    return s;
  }
  /*! \brief flatten the shape to two dimensions, collapsing all leading dimensions into the first */
  MSHADOW_XINLINE Shape<2> FlatTo2D(void) const {
    Shape<2> s;
    s.shape_[1] = this->shape_[kDimension - 1];
    index_t ymax = 1;
    #pragma unroll
    for (int i = 0; i < kDimension - 1; ++i) {
      ymax *= this->shape_[i];
    }
    s.shape_[0] = ymax;
    return s;
  }
  /*! \brief total number of elements in the shape */
  MSHADOW_XINLINE index_t Size(void) const {
    index_t size = this->shape_[0];
    #pragma unroll
    for (int i = 1; i < kDimension; ++i) {
      size *= this->shape_[i];
    }
    return size;
  }
  /*! \brief product of the dimensions in the range [dimstart, dimend) */
  MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const {
    index_t num = 1;
    #pragma unroll
    for (int i = dimstart; i < dimend; ++i) {
      num *= this->shape_[i];
    }
    return num;
  }
  /*! \brief get subshape that takes off the largest dimension */
  MSHADOW_XINLINE Shape<kSubdim> SubShape(void) const {
    Shape<kSubdim> s;
    // for cuda
    #pragma unroll
    for (int i = 0; i < kSubdim; ++i) {
      s.shape_[i] = this->shape_[i + 1];
    }
    return s;
  }
  /*! \brief slice the shape, keeping the dimensions in [dimstart, dimend) */
  template<int dimstart, int dimend>
  MSHADOW_XINLINE Shape<dimend - dimstart> Slice(void) const {
    Shape<dimend - dimstart> s;
    #pragma unroll
    for (int i = dimstart; i < dimend; ++i) {
      s[i - dimstart] = this->shape_[i];
    }
    return s;
  }
  template<int dim>
  friend std::ostream &operator<<(std::ostream &os, const Shape<dim> &shape);  // NOLINT(*)
};  // Shape
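// A minimal usage sketch of Shape (illustrative; values chosen arbitrarily):
//
//   mshadow::Shape<3> s;
//   s[0] = 2; s[1] = 3; s[2] = 4;
//   s.Size();           // 24, total number of elements
//   s.FlatTo2D();       // (6, 4): leading dimensions collapsed into the first
//   s.ProdShape(1, 3);  // 12, product of dimensions [1, 3)
//   s.SubShape();       // (3, 4): largest dimension dropped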
//------------------------------------------------
// useful construction functions to generate shape
//------------------------------------------------
/*! \brief construct a one dimension shape, stride will equal s0 */
MSHADOW_XINLINE Shape<1> Shape1(index_t s0) {
  Shape<1> s; s[0] = s0;
  return s;
}
/*! \brief construct a two dimension shape, stride will equal s1 */
MSHADOW_XINLINE Shape<2> Shape2(index_t s0, index_t s1) {
  Shape<2> s; s[0] = s0; s[1] = s1;
  return s;
}
/*! \brief construct a three dimension shape, stride will equal s2 */
MSHADOW_XINLINE Shape<3> Shape3(index_t s0, index_t s1, index_t s2) {
  Shape<3> s;
  s[0] = s0; s[1] = s1; s[2] = s2;
  return s;
}
/*! \brief construct a four dimension shape, stride will equal s3 */
MSHADOW_XINLINE Shape<4> Shape4(index_t s0, index_t s1,
                                index_t s2, index_t s3) {
  Shape<4> s;
  s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3;
  return s;
}
/*! \brief construct a five dimension shape, stride will equal s4 */
MSHADOW_XINLINE Shape<5> Shape5(index_t s0, index_t s1, index_t s2,
                                index_t s3, index_t s4) {
  Shape<5> s;
  s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; s[4] = s4;
  return s;
}
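// The helpers above are the usual way to build shapes; a sketch (illustrative):
//
//   mshadow::Shape<2> mat = mshadow::Shape2(6, 4);    // same as setting s[0], s[1] by hand
//   mshadow::Shape<4> batch = mshadow::Shape4(8, 3, 224, 224);
//   batch.FlatTo2D();                                 // (8 * 3 * 224, 224)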

/*!
 * \brief Convert shape in src_layout to shape in dst_layout
 * \param src original shape
 * \param src_layout layout of the original shape
 * \param dst_layout target layout
 * \return shape in the target layout
 */
inline Shape<3> ConvertLayout(const Shape<3>& src, int src_layout, int dst_layout) {
  Shape<3> dst;
  switch (src_layout) {
  case kNCW:
    dst = src;
    break;
  case kNWC:
    dst[0] = src[0];
    dst[1] = src[2];
    dst[2] = src[1];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 3d shape " << src_layout;
  }
  switch (dst_layout) {
  case kNCW:
    return dst;
  case kNWC:
    {
      index_t tmp = dst[1];
      dst[1] = dst[2];
      dst[2] = tmp;
    }
    break;
  default:
    LOG(FATAL) << "Invalid layout for 3d shape " << dst_layout;
  }
  return dst;
}

/*!
 * \brief Convert shape in src_layout to shape in dst_layout
 * \param src original shape
 * \param src_layout layout of the original shape
 * \param dst_layout target layout
 * \return shape in the target layout
 */
inline Shape<4> ConvertLayout(const Shape<4>& src, int src_layout, int dst_layout) {
  Shape<4> dst;
  switch (src_layout) {
  case kNCHW:
    dst = src;
    break;
  case kNHWC:
    dst[0] = src[0];
    dst[2] = src[1];
    dst[3] = src[2];
    dst[1] = src[3];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 4d shape " << src_layout;
    dst = src;  // fixes compiler warning
  }
  Shape<4> dst2;
  switch (dst_layout) {
  case kNCHW:
    return dst;
  case kNHWC:
    dst2[0] = dst[0];
    dst2[1] = dst[2];
    dst2[2] = dst[3];
    dst2[3] = dst[1];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 4d shape " << dst_layout;
    dst2 = src;  // fixes compiler warning
  }
  return dst2;
}

/*!
 * \brief Convert shape in src_layout to shape in dst_layout
 * \param src original shape
 * \param src_layout layout of the original shape
 * \param dst_layout target layout
 * \return shape in the target layout
 */
inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layout) {
  Shape<5> dst;
  switch (src_layout) {
  case kNCDHW:
    dst = src;
    break;
  case kNDHWC:
    dst[0] = src[0];
    dst[2] = src[1];
    dst[3] = src[2];
    dst[4] = src[3];
    dst[1] = src[4];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 5d shape " << src_layout;
  }
  Shape<5> dst2;
  switch (dst_layout) {
  case kNCDHW:
    return dst;
  case kNDHWC:
    dst2[0] = dst[0];
    dst2[1] = dst[2];
    dst2[2] = dst[3];
    dst2[3] = dst[4];
    dst2[4] = dst[1];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 5d shape " << dst_layout;
  }
  return dst2;
}
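// Usage sketch for ConvertLayout (illustrative): the layout flags kNCW, kNWC,
// kNCHW, kNHWC, kNCDHW and kNDHWC are enumerated in base.h.
//
//   mshadow::Shape<4> nhwc = mshadow::Shape4(8, 224, 224, 3);
//   mshadow::Shape<4> nchw =
//       mshadow::ConvertLayout(nhwc, mshadow::kNHWC, mshadow::kNCHW);
//   // nchw == (8, 3, 224, 224)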

/*!
 * \brief computation stream structure, used for asynchronous computations
 */
template<typename Device>
struct Stream {
  // this is only a dummy implementation for CPU
  // for GPU, the actual implementation will be specialized in tensor_gpu-inl.h
  /*! \brief wait for all the computations associated with this stream to complete */
  inline void Wait(void) {}
  /*! \brief query whether the stream is idle, returns true if all computations have completed */
  inline bool CheckIdle(void) {
    return true;
  }
  /*! \brief create a blas handle */
  inline void CreateBlasHandle() {}
};
/*!
 * \brief Tensor RValue, this is the super type of all kinds of possible tensors
 * \tparam Container the tensor type
 * \tparam Device which device the tensor is on
 * \tparam dimension dimension of the tensor
 * \tparam DType the type of elements in the tensor
 */
template<typename Container, typename Device, int dimension, typename DType>
struct TRValue: public expr::RValueExp<Container, DType> {
};
// more compact template
/*!
 * \brief general tensor
 * \tparam Device which device the tensor is on
 * \tparam dimension dimension of the tensor
 * \tparam DType the type of elements in the tensor
 */
template<typename Device, int dimension,
         typename DType MSHADOW_DEFAULT_DTYPE>
struct Tensor: public TRValue<Tensor<Device, dimension, DType>,
                              Device, dimension, DType> {
 public:
  //--------------------------------
  // struct members
  //--------------------------------
  /*! \brief whether current type lies in cpu */
  static const bool kDevCPU = Device::kDevCPU;
  /*! \brief dimension of subtype */
  static const int kSubdim = dimension - 1;
  //--------------------------------
  // struct members
  //--------------------------------
  /*! \brief pointer to the data */
  DType *dptr_ = nullptr;
  /*! \brief shape of the tensor */
  Shape<dimension> shape_;
  /*!
   * \brief storing the stride information in x dimension,
   *  used to deal with pitch allocation in gpu or sse (align x dimension to 64 bit) for efficiency
   */
  index_t stride_;
  /*!
   * \brief stream where the computation lies;
   *  a stream is a device-dependency concept that orders computations
   */
  Stream<Device> *stream_;
  //--------------------------------
  // functions
  //--------------------------------
  /*! \brief default constructor */
  MSHADOW_XINLINE Tensor(void) : stream_(NULL) {}
  /*! \brief constructor from shape */
  MSHADOW_XINLINE Tensor(const Shape<dimension> &shape)
      : shape_(shape), stream_(NULL) {}
  /*! \brief constructor from data pointer and shape, without stride */
  MSHADOW_XINLINE Tensor(DType *dptr, const Shape<dimension> &shape)
      : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(NULL) {}
  /*! \brief constructor from data pointer, shape and stream, without stride */
  MSHADOW_XINLINE Tensor(DType *dptr, const Shape<dimension> &shape,
                         Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(stream) {}
  /*! \brief constructor from data pointer, shape, stride and stream */
  MSHADOW_XINLINE Tensor(DType *dptr,
                         const Shape<dimension> &shape,
                         index_t stride, Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {}
  /*! \brief set the stream to do computation of current tensor */
  inline void set_stream(Stream<Device> *stream) {
    this->stream_ = stream;
  }
  /*!
   * \return memory cost of the tensor in the dimensions [startdim, dimension),
   *  including the aligned x dimension
   * \tparam startdim the starting dimension
   */
  template<int startdim>
  MSHADOW_XINLINE index_t MemSize(void) const {
    index_t memsz = this->stride_;
    #pragma unroll
    for (int i = startdim; i < kSubdim; ++i) {
      memsz *= this->shape_[i];
    }
    return memsz;
  }
  /*! \brief whether the memory of the tensor is contiguous, i.e. stride equals the last dimension */
  MSHADOW_XINLINE bool CheckContiguous(void) const {
    return this->shape_[dimension - 1] == stride_;
  }
  /*! \return memory cost of the whole tensor, including the aligned x dimension */
  MSHADOW_XINLINE index_t MSize(void) const {
    return this->MemSize<0>();
  }
  /*!
   * \brief return size of i-th dimension, start counting from highest dimension
   * \param idx the dimension count from the highest dimension
   * \return the size
   */
  MSHADOW_XINLINE index_t size(int idx) const {
    return shape_[idx];
  }
  /*! \brief flatten the tensor to 1 dimension */
  MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const {
    return Tensor<Device, 1, DType>(dptr_, shape_.FlatTo1D(), stride_, stream_);
  }
  /*! \brief flatten the tensor to 2 dimensions, collapse the higher dimensions together */
  MSHADOW_XINLINE Tensor<Device, 2, DType> FlatTo2D(void) const {
    return Tensor<Device, 2, DType>(dptr_, shape_.FlatTo2D(), stride_, stream_);
  }
  /*!
   * \brief get the idx-th subtensor along the highest dimension
   * \param idx index
   * \return the (dimension - 1)-d result tensor
   */
  MSHADOW_XINLINE Tensor<Device, kSubdim, DType> operator[](index_t idx) const {
    return Tensor<Device, kSubdim, DType>(dptr_ + this->MemSize<1>() * idx,
                                          shape_.SubShape(), stride_, stream_);
  }
  /*!
   * \brief slice the tensor in the highest dimension [begin, end)
   * \param begin begin position of the slice
   * \param end end position of the slice
   * \return tensor after the slice
   */
  MSHADOW_XINLINE Tensor<Device, dimension, DType>
  Slice(index_t begin, index_t end) const {
    Shape<dimension> s = this->shape_;
    s[0] = end - begin;
    return Tensor<Device, dimension, DType>(dptr_ + this->MemSize<1>() * begin,
                                            s, stride_, stream_);
  }
  /*! \brief implement the assignment of same type */
  inline Tensor<Device, dimension, DType> &
  operator=(const Tensor<Device, dimension, DType> &exp) {
    dptr_ = exp.dptr_;
    shape_ = exp.shape_;
    stride_ = exp.stride_;
    stream_ = exp.stream_;
    return *this;
  }
  /*! \brief functions to fit expression template */
  template<typename E, int etype>
  inline Tensor<Device, dimension, DType> &
  operator=(const expr::Exp<E, DType, etype> &exp) {
    return this->__assign(exp);
  }
  /*! \brief functions to fit expression template */
  inline Tensor<Device, dimension, DType> &operator=(const DType &exp) {
    return this->__assign(exp);
  }
};
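// Usage sketch (illustrative): Tensor is a handle over existing memory; no
// allocation happens when constructing or slicing one.
//
//   float data[2 * 3] = {0, 1, 2, 3, 4, 5};
//   mshadow::Tensor<mshadow::cpu, 2, float> t(data, mshadow::Shape2(2, 3));
//   t[1][2];        // 5: row 1 via operator[], then element 2
//   t.Slice(0, 1);  // 1 x 3 view of the first row, still backed by `data`
//   t.FlatTo1D();   // 1-d view of all six elements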
/*
 * respecialized class Tensor1D, this is due to a different implementation of operator[]
 */
template<typename Device, typename DType>
struct Tensor<Device, 1, DType>:
      public TRValue<Tensor<Device, 1, DType>, Device, 1, DType> {
 public:
  DType *dptr_;
  Shape<1> shape_;
  index_t stride_;
  Stream<Device> *stream_;
  // constructor
  MSHADOW_XINLINE Tensor(void) : stream_(NULL) {}
  MSHADOW_XINLINE Tensor(const Shape<1> &shape)
      : shape_(shape), stream_(NULL) {}
  MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape)
      : dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(NULL) {}
  MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(stream) {}
  MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape,
                         index_t stride, Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {}
  inline void set_stream(Stream<Device> *stream) {
    this->stream_ = stream;
  }
  MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const {
    return *this;
  }
  MSHADOW_XINLINE Tensor<Device, 2, DType> FlatTo2D(void) const {
    return Tensor<Device, 2, DType>(dptr_, shape_.FlatTo2D(), stride_, stream_);
  }
  MSHADOW_XINLINE Tensor<Device, 1, DType> Slice(index_t begin, index_t end) const {
    Shape<1> s;
    s[0] = end - begin;
    return Tensor<Device, 1, DType>(dptr_ + begin, s, s[0], stream_);
  }
  MSHADOW_XINLINE bool CheckContiguous(void) const {
    return true;
  }
  MSHADOW_XINLINE index_t MSize(void) const {
    return shape_[0];
  }
  MSHADOW_XINLINE index_t size(index_t i) const {
    return shape_[0];
  }
  MSHADOW_XINLINE DType &operator[](index_t idx) {
    return dptr_[idx];
  }
  MSHADOW_XINLINE const DType &operator[](index_t idx) const {
    return dptr_[idx];
  }
  /*! \brief implement the assignment of same type */
  inline Tensor<Device, 1, DType> &
  operator=(const Tensor<Device, 1, DType> &exp) {
    dptr_ = exp.dptr_;
    shape_ = exp.shape_;
    stride_ = exp.stride_;
    stream_ = exp.stream_;
    return *this;
  }
  template<typename E, int etype>
  inline Tensor<Device, 1, DType> &
  operator=(const expr::Exp<E, DType, etype> &exp) {
    return this->__assign(exp);
  }
  inline Tensor<Device, 1, DType> &operator=(const DType &exp) {
    return this->__assign(exp);
  }
};
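// Note the difference from the general template: for the 1-d specialization,
// operator[] returns an element reference rather than a subtensor. A sketch
// (illustrative):
//
//   float buf[4] = {0.f, 0.f, 0.f, 0.f};
//   mshadow::Tensor<mshadow::cpu, 1, float> v(buf, mshadow::Shape1(4));
//   v[2] = 3.5f;  // direct element access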
//------------------------
// Function Declarations
//------------------------
/*!
 * \brief initialize tensor engine, used to call initialization functions of dependent libs;
 *  this function should be called before all GPU tensor operations
 */
template<typename Device>
inline void InitTensorEngine(int device_id = 0);
/*!
 * \brief shutdown tensor engine on the current device;
 *  this function should be called after all GPU tensor operations
 */
template<typename Device>
inline void ShutdownTensorEngine(void);
/*! \brief set the device of the current thread to work on */
template<typename Device>
inline void SetDevice(int devid);
/*!
 * \brief create a new stream from the system
 * \param create_blas_handle whether to create a blas handle in the stream
 * \param create_dnn_handle whether to create a cudnn handle in the stream
 * \param dev_id device id
 */
template<typename Device>
inline Stream<Device> *NewStream(bool create_blas_handle,
                                 bool create_dnn_handle,
                                 int dev_id = -1);
/*! \brief default behavior: create a stream with a blas handle */
template<typename Device>
inline Stream<Device> *NewStream(int dev_id) {
  return NewStream<Device>(true, false, dev_id);
}
/*! \brief delete the computing stream */
template<typename Device>
inline void DeleteStream(Stream<Device> *stream);
/*!
 * \brief CPU/GPU: allocate space for the tensor, according to the shape in the obj;
 *  this function is responsible for setting the stride_ field in obj.shape
 */
template<int dim, typename DType>
inline void AllocSpace(Tensor<cpu, dim, DType> *obj,
                       bool pad = MSHADOW_ALLOC_PAD);
template<int dim, typename DType>
inline void AllocSpace(Tensor<gpu, dim, DType> *obj,
                       bool pad = MSHADOW_ALLOC_PAD);
/*! \brief CPU/GPU: free the space of the tensor, will set obj.dptr to NULL */
template<int dim, typename DType>
inline void FreeSpace(Tensor<cpu, dim, DType> *obj);
template<int dim, typename DType>
inline void FreeSpace(Tensor<gpu, dim, DType> *obj);
/*! \brief CPU/GPU: short cut to allocate and initialize a Tensor */
template<typename Device, typename DType, int dim>
inline Tensor<Device, dim, DType> NewTensor(const Shape<dim> &shape,
                                            DType initv,
                                            bool pad = MSHADOW_ALLOC_PAD,
                                            Stream<Device> *stream = NULL);
/*!
 * \brief copy data from one tensor to another, with the same shape
 * \param dst target tensor
 * \param src source tensor
 * \param stream the stream; when specified, the copy can be asynchronous
 */
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
                 const Tensor<cpu, dim, DType> &src,
                 Stream<cpu> *stream = NULL);
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream = NULL);
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<cpu, dim, DType> &src,
                 Stream<gpu> *stream = NULL);
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream = NULL);
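// Lifecycle sketch (illustrative, assuming a CUDA-enabled build): explicit
// allocation, host-to-device copy, and cleanup; plain Tensor handles are never
// freed automatically.
//
//   mshadow::InitTensorEngine<mshadow::gpu>();
//   mshadow::Stream<mshadow::gpu> *s = mshadow::NewStream<mshadow::gpu>(0);
//   mshadow::Tensor<mshadow::cpu, 2, float> h(mshadow::Shape2(4, 8));
//   mshadow::Tensor<mshadow::gpu, 2, float> d(mshadow::Shape2(4, 8));
//   mshadow::AllocSpace(&h);
//   mshadow::AllocSpace(&d);
//   d.set_stream(s);
//   mshadow::Copy(d, h, s);  // host -> device, asynchronous on stream s
//   mshadow::FreeSpace(&d);
//   mshadow::FreeSpace(&h);
//   mshadow::DeleteStream(s);
//   mshadow::ShutdownTensorEngine<mshadow::gpu>();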
/*!
 * \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) / (sum_j exp(energy[i][j]))
 * \param dst destination
 * \param energy input energy
 */
template<typename DType>
inline void Softmax(Tensor<cpu, 2, DType> dst, const Tensor<cpu, 2, DType> &energy);
template<typename DType>
inline void Softmax(Tensor<gpu, 2, DType> dst, const Tensor<gpu, 2, DType> &energy);

/*!
 * \brief CPU/GPU: softmax gradient
 * \param dst destination
 * \param src source output
 * \param label label info
 */
template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
                        const Tensor<cpu, 2, DType> &src,
                        const Tensor<cpu, 1, DType> &label);
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                        const Tensor<gpu, 2, DType> &src,
                        const Tensor<gpu, 1, DType> &label);
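// Usage sketch (illustrative):
//
//   mshadow::Tensor<mshadow::cpu, 2, float> energy(mshadow::Shape2(16, 10));
//   mshadow::Tensor<mshadow::cpu, 2, float> prob(mshadow::Shape2(16, 10));
//   mshadow::AllocSpace(&energy);
//   mshadow::AllocSpace(&prob);
//   // ... fill energy with unnormalized scores ...
//   mshadow::Softmax(prob, energy);  // row-wise softmax over the 10 classes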
/*!
 * \brief CPU/GPU: gradient accumulation of an embedding matrix: dst[index[i]] += src[i];
 *  called when the feature dimension of src is much larger than the batch size
 * \tparam clip how to handle out-of-bound indices
 */
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
                        const Tensor<cpu, 1, IndexType>& index,
                        const Tensor<cpu, 2, DType> &src);
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                        const Tensor<gpu, 1, IndexType>& index,
                        const Tensor<gpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: gradient accumulation of an embedding matrix: dst[sorted[i]] += src[index[i]];
 *  called when the batch size of src is larger than the feature dimension
 */
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
                                  const Tensor<cpu, 1, IndexType>& sorted,
                                  const Tensor<cpu, 1, IndexType>& index,
                                  const Tensor<cpu, 2, DType> &src);
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
                                  const Tensor<gpu, 1, IndexType>& sorted,
                                  const Tensor<gpu, 1, IndexType>& index,
                                  const Tensor<gpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: fill specific rows of the destination matrix with values from
 *  the source matrix: dst[index[i]] = src[i]
 */
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<cpu, 2, DType> dst,
                      const Tensor<cpu, 1, IndexType>& index,
                      const Tensor<cpu, 2, DType> &src);
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<gpu, 2, DType> dst,
                      const Tensor<gpu, 1, IndexType>& index,
                      const Tensor<gpu, 2, DType> &src);
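// Semantics sketch (illustrative pseudo-code): AddTakeGrad scatter-adds rows of
// src into dst by index, i.e. the backward pass of an embedding lookup:
//
//   for (index_t i = 0; i < index.size(0); ++i) {
//     dst[index[i]] += src[i];
//   }
//
// AddTakeGradLargeBatch computes the same result from a pre-sorted index to get
// better locality when the batch is large.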
/*! \brief CPU/GPU: sort key-value pairs stored in separate places (a stable sort is performed) */
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
                      bool is_ascend = true);
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
                      bool is_ascend = true);
/*!
 * \brief CPU/GPU: sort the values within each segment (a stable sort is performed);
 *  segments are defined by an ascending ordered segment-id vector
 */
template<typename Device, typename VDType, typename SDType>
inline void VectorizedSort(Tensor<Device, 1, VDType> values,
                           Tensor<Device, 1, SDType> segments);

// function declarations to support expression templates;
// these functions do not need to be used directly
/*!
 * \brief CPU/GPU: map an expression to a tensor, this function calls MapPlan
 * \tparam Saver specify storage method
 * \param dst destination
 * \param exp expression
 */
template<typename Saver, typename R, int dim,
         typename DType, typename E, int etype>
inline void MapExp(TRValue<R, cpu, dim, DType> *dst,
                   const expr::Exp<E, DType, etype> &exp);
template<typename Saver, typename R, int dim,
         typename DType, typename E, int etype>
inline void MapExp(TRValue<R, gpu, dim, DType> *dst,
                   const expr::Exp<E, DType, etype> &exp);
/*!
 * \brief CPU/GPU: map an expression and reduce it to a 1D tensor in the lowest dimension (dimension 0)
 * \tparam Saver specify storage method
 * \tparam Reducer specify the reduction operation
 */
template<typename Saver, typename Reducer,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst,
                                const expr::Exp<E, DType, etype> &exp,
                                DType scale = 1);
template<typename Saver, typename Reducer, typename R,
         typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, gpu, 1, DType> *dst,
                                const expr::Exp<E, DType, etype> &exp,
                                DType scale = 1);
/*!
 * \brief CPU/GPU: map an expression and reduce it to a 1D tensor, keeping dimension dimkeep
 */
template<typename Saver, typename Reducer, int dimkeep,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst,
                                 const expr::Exp<E, DType, etype> &exp,
                                 DType scale = 1);
template<typename Saver, typename Reducer, int dimkeep,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst,
                                 const expr::Exp<E, DType, etype> &exp,
                                 DType scale = 1);
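// These MapExp / MapReduce* declarations are the machinery behind
// expression-template assignment; user code normally just writes the
// expression. A sketch (illustrative, assuming a, b, c are already-allocated
// CPU tensors of the same shape):
//
//   using namespace mshadow;
//   Tensor<cpu, 2, float> a, b, c;
//   // ... allocate a, b, c ...
//   a = b * 2.0f + c;  // lazily built expression, evaluated here via MapExp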
/*!
 * \brief CPU/GPU: 1 dimension vector dot
 * \param dst length-1 vector, used to hold the result
 * \param lhs left operand vector
 * \param rhs right operand vector
 */
template<typename Device, typename DType>
inline void VectorDot(Tensor<Device, 1, DType> dst,
                      const Tensor<Device, 1, DType> &lhs,
                      const Tensor<Device, 1, DType> &rhs);
/*!
 * \brief CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst
 * \tparam transpose_left whether to transpose lhs
 * \tparam transpose_right whether to transpose rhs
 */
template<bool transpose_left, bool transpose_right, typename Device, typename DType>
inline void BatchGEMM(Tensor<Device, 3, DType> dst,
                      const Tensor<Device, 3, DType> &lhs,
                      const Tensor<Device, 3, DType> &rhs,
                      DType alpha,
                      DType beta,
                      Tensor<Device, 1, DType*> workspace);
}  // namespace mshadow
// include headers
#include "./stream_gpu-inl.h"
#include "./extension.h"
#include "./expr_engine-inl.h"
#include "./tensor_cpu-inl.h"
#include "./tensor_gpu-inl.h"
#include "./io.h"
#include "./tensor_container.h"
#include "./random.h"
// add definition of scalar related operators
#ifdef MSHADOW_SCALAR_
  #error "MSHADOW_SCALAR_ must not be defined"
#endif
// enumerate all the scalar data types we aim to be good at
#define MSHADOW_SCALAR_ float
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ double
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ int32_t
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ int64_t
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ mshadow::half::half_t
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#endif  // MSHADOW_TENSOR_H_