Go to the documentation of this file.
30 #ifndef MSHADOW_TENSOR_H_
31 #define MSHADOW_TENSOR_H_
53 template <
typename xpu>
79 template<
int dimension>
93 this->shape_[i] = s[i];
110 #pragma GCC diagnostic push
111 #pragma GCC diagnostic ignored "-Warray-bounds"
113 #pragma GCC diagnostic pop
122 if (s.
shape_[i] != this->shape_[i])
return false;
131 return !(*
this == s);
152 ymax *= this->shape_[i];
159 index_t size = this->shape_[0];
162 size *= this->shape_[i];
174 for (
int i = dimstart; i < dimend; ++i) {
175 num *= this->shape_[i];
187 for (
int i = 0; i <
kSubdim; ++i) {
188 s.
shape_[i] = this->shape_[i + 1];
198 template<
int dimstart,
int dimend>
200 Shape<dimend - dimstart> s;
202 for (
int i = dimstart; i < dimend; ++i) {
203 s[i - dimstart] = this->shape_[i];
243 s[0] = s0; s[1] = s1; s[2] = s2;
257 s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3;
272 s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; s[4] = s4;
285 switch (src_layout) {
295 LOG(FATAL) <<
"Invalid layout for 3d shape " << src_layout;
297 switch (dst_layout) {
308 LOG(FATAL) <<
"Invalid layout for 3d shape " << src_layout;
322 switch (src_layout) {
333 LOG(FATAL) <<
"Invalid layout for 4d shape " << src_layout;
337 switch (dst_layout) {
347 LOG(FATAL) <<
"Invalid layout for 4d shape " << src_layout;
362 switch (src_layout) {
374 LOG(FATAL) <<
"Invalid layout for 5d shape " << src_layout;
377 switch (dst_layout) {
388 LOG(FATAL) <<
"Invalid layout for 5d shape " << src_layout;
400 template <
typename dim_t>
402 auto apply = [](
const std::vector<dim_t>& v,
const std::vector<dim_t>& op) {
403 CHECK_EQ(v.size(), op.size()) <<
"Layout ndims does not match";
404 std::vector<dim_t> ret(v.size());
405 for (
size_t i = 0; i < v.size(); i++) {
410 std::vector<dim_t> axes;
412 switch (src_layout) {
414 LOG(FATAL) <<
"Unknown source layout";
417 axes = std::vector<dim_t>({0, 1, 2, 3});
420 axes = std::vector<dim_t>({0, 2, 3, 1});
423 axes = std::vector<dim_t>({3, 1, 2, 0});
426 axes = std::vector<dim_t>({0, 1, 2});
429 axes = std::vector<dim_t>({0, 2, 1});
432 axes = std::vector<dim_t>({2, 1, 0});
435 axes = std::vector<dim_t>({0, 1, 2, 3, 4});
438 axes = std::vector<dim_t>({0, 2, 3, 4, 1});
441 axes = std::vector<dim_t>({4, 1, 2, 3, 0});
444 LOG(FATAL) <<
"Invalid source layout " << src_layout;
447 switch (dst_layout) {
449 LOG(FATAL) <<
"Unknown destination layout";
452 axes = apply(axes, {0, 1, 2, 3});
455 axes = apply(axes, {0, 3, 1, 2});
458 axes = apply(axes, {3, 1, 2, 0});
461 axes = apply(axes, {0, 1, 2});
464 axes = apply(axes, {0, 2, 1});
467 axes = apply(axes, {2, 1, 0});
470 axes = apply(axes, {0, 1, 2, 3, 4});
473 axes = apply(axes, {0, 4, 1, 2, 3});
476 axes = apply(axes, {4, 1, 2, 3, 0});
479 LOG(FATAL) <<
"Invalid destination layout " << src_layout;
487 template<
typename Device>
513 template<
typename Container,
typename Device,
int dimension,
typename DType>
523 template<
typename Device,
int dimension,
526 Device, dimension, DType> {
577 this->stream_ = stream;
583 template<
int startdim>
587 for (
int i = startdim; i <
kSubdim; ++i) {
588 memsz *= this->shape_[i];
597 return this->shape_[dimension - 1] ==
stride_;
603 return this->MemSize<0>();
659 template<
typename E,
int etype>
672 template<
typename Device,
typename DType>
674 public TRValue<Tensor<Device, 1, DType>, Device, 1, DType> {
692 this->stream_ = stream;
729 template<
typename E,
int etype>
748 template<
typename Device>
756 template<
typename Device>
763 template<
typename Device>
773 template<
typename Device>
774 inline Stream<Device> *
NewStream(
bool create_blas_handle,
775 bool create_dnn_handle,
781 template<
typename Device>
783 return NewStream<Device>(
true,
false, dev_id);
789 template<
typename Device>
802 template<
int dim,
typename DType>
803 inline void AllocSpace(Tensor<cpu, dim, DType> *obj,
816 template<
int dim,
typename DType>
817 inline void AllocSpace(Tensor<gpu, dim, DType> *obj,
// CPU overload of FreeSpace (declaration only).
// Per this page's symbol summary: frees the space of the tensor and sets
// obj.dptr to NULL. The summary locates the definition in tensor_cpu-inl.h.
// obj: pointer to the tensor whose storage is released.
825 template<
int dim,
typename DType>
826 inline void FreeSpace(Tensor<cpu, dim, DType> *obj);
// GPU overload of FreeSpace (declaration only; definition is not visible here).
// The symbol summary describes FreeSpace as "CPU/GPU: free the space of
// tensor, will set obj.dptr to NULL" -- presumably the same contract applies
// to this overload; confirm against the GPU implementation.
833 template<
int dim,
typename DType>
834 inline void FreeSpace(Tensor<gpu, dim, DType> *obj);
847 template<
typename Device,
typename DType,
int dim>
848 inline Tensor<Device, dim, DType>
NewTensor(
const Shape<dim> &shape,
851 Stream<Device> *stream = NULL);
// Copy, cpu <- cpu (declaration only).
// Per the symbol summary: copies data from one tensor to another with the
// same shape. The summary locates the definition in tensor_cpu-inl.h.
// dst:    destination tensor (taken by value).
// src:    source tensor.
// stream: optional stream; defaults to NULL.
860 template<
int dim,
typename DType>
861 inline void Copy(Tensor<cpu, dim, DType> dst,
862 const Tensor<cpu, dim, DType> &src,
863 Stream<cpu> *stream = NULL);
// Copy, cpu <- gpu (declaration only; definition is not visible here).
// Destination is a cpu tensor, source is a gpu tensor, and the optional
// stream is a Stream<gpu> that defaults to NULL.
// Presumably follows the same "same shape" contract as the cpu <- cpu
// overload described in the symbol summary -- confirm in the implementation.
872 template<
int dim,
typename DType>
873 inline void Copy(Tensor<cpu, dim, DType> dst,
874 const Tensor<gpu, dim, DType> &src,
875 Stream<gpu> *stream = NULL);
// Copy, gpu <- cpu (declaration only; definition is not visible here).
// Destination is a gpu tensor, source is a cpu tensor, and the optional
// stream is a Stream<gpu> that defaults to NULL.
// Presumably follows the same "same shape" contract as the cpu <- cpu
// overload described in the symbol summary -- confirm in the implementation.
884 template<
int dim,
typename DType>
885 inline void Copy(Tensor<gpu, dim, DType> dst,
886 const Tensor<cpu, dim, DType> &src,
887 Stream<gpu> *stream = NULL);
// Copy, gpu <- gpu (declaration only; definition is not visible here).
// Both tensors are gpu tensors; the optional Stream<gpu> defaults to NULL.
// Presumably follows the same "same shape" contract as the cpu <- cpu
// overload described in the symbol summary -- confirm in the implementation.
896 template<
int dim,
typename DType>
897 inline void Copy(Tensor<gpu, dim, DType> dst,
898 const Tensor<gpu, dim, DType> &src,
899 Stream<gpu> *stream = NULL);
// Softmax on 2-D cpu tensors (declaration only).
// Per the symbol summary:
//   dst[i][j] = exp(energy[i][j]) / (sum_j exp(energy[i][j]))
// The summary locates the definition in tensor_cpu-inl.h.
// dst:    output tensor (taken by value).
// energy: input tensor.
905 template<
typename DType>
906 inline void Softmax(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 2, DType> &energy);
// Softmax on 2-D gpu tensors (declaration only; definition is not visible here).
// The symbol summary labels Softmax "CPU/GPU" with
//   dst[i][j] = exp(energy[i][j]) / (sum_j exp(energy[i][j]))
// so this overload presumably computes the same normalization -- confirm.
912 template<
typename DType>
913 inline void Softmax(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 2, DType> &energy);
921 template<
typename DType>
923 const Tensor<cpu, 2, DType> &src,
924 const Tensor<cpu, 1, DType> &label);
// SoftmaxGrad on gpu tensors (declaration only; definition is not visible here).
// The symbol summary describes SoftmaxGrad as "CPU/GPU: softmax gradient."
// NOTE(review): this overload takes dst by const reference, whereas the cpu
// signature shown in the symbol summary takes dst by value.
// dst:   2-D output tensor.
// src:   2-D input tensor.
// label: 1-D label tensor.
931 template<
typename DType>
932 inline void SoftmaxGrad(
const Tensor<gpu, 2, DType> &dst,
933 const Tensor<gpu, 2, DType> &src,
934 const Tensor<gpu, 1, DType> &label);
943 template<
bool clip = true,
typename IndexType,
typename DType>
945 const Tensor<cpu, 1, IndexType>& index,
946 const Tensor<cpu, 2, DType> &src);
955 template<
bool clip = true,
typename IndexType,
typename DType,
typename AType>
957 Tensor<cpu, 2, AType> temp,
958 const Tensor<cpu, 1, IndexType>& index,
959 const Tensor<cpu, 2, DType> &src);
968 template<
bool clip = true,
typename IndexType,
typename DType>
970 const Tensor<gpu, 1, IndexType>& index,
971 const Tensor<gpu, 2, DType> &src);
981 template<
bool clip = true,
typename IndexType,
typename DType,
typename AType>
983 Tensor<gpu, 2, AType> temp,
984 const Tensor<gpu, 1, IndexType>& index,
985 const Tensor<gpu, 2, DType> &src);
994 template<
typename IndexType,
typename DType>
996 const Tensor<cpu, 1, IndexType>& sorted,
997 const Tensor<cpu, 1, IndexType>& index,
998 const Tensor<cpu, 2, DType> &src);
1008 template<
typename IndexType,
typename DType>
1010 const Tensor<gpu, 1, IndexType>& sorted,
1011 const Tensor<gpu, 1, IndexType>& index,
1012 const Tensor<gpu, 2, DType> &src);
// IndexFill on cpu tensors (declaration only).
// The symbol summary (truncated on this page) reads: "CPU/GPU: Fill the
// values of the destination matrix to specific rows in the source matrix".
// The summary locates the definition in tensor_cpu-inl.h.
// dst:   2-D destination tensor (taken by value).
// index: 1-D tensor of row indices.
// src:   2-D source tensor.
1021 template<
typename IndexType,
typename DType>
1022 inline void IndexFill(Tensor<cpu, 2, DType> dst,
1023 const Tensor<cpu, 1, IndexType>& index,
1024 const Tensor<cpu, 2, DType> &src);
// IndexFill on gpu tensors (declaration only; definition is not visible here).
// Same parameter layout as the cpu overload: 2-D dst, 1-D index, 2-D src.
// The symbol summary labels IndexFill "CPU/GPU", so this overload presumably
// has the same semantics -- confirm against the GPU implementation.
1033 template<
typename IndexType,
typename DType>
1034 inline void IndexFill(Tensor<gpu, 2, DType> dst,
1035 const Tensor<gpu, 1, IndexType>& index,
1036 const Tensor<gpu, 2, DType> &src);
// SortByKey on cpu tensors (declaration only).
// Per the symbol summary (truncated on this page): sorts key-value pairs
// stored in separate places; a stable sort is performed.
// The summary locates the definition in tensor_cpu-inl.h.
// keys:      1-D tensor of keys.
// values:    1-D tensor of values.
// is_ascend: sort direction flag; defaults to true.
1043 template<
typename KDType,
typename VDType>
1044 inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
1045 bool is_ascend =
true);
// SortByKey on gpu tensors (declaration only; definition is not visible here).
// Same parameter layout as the cpu overload; is_ascend defaults to true.
// The symbol summary labels SortByKey "CPU/GPU", so this overload presumably
// has the same semantics -- confirm against the GPU implementation.
1052 template<
typename KDType,
typename VDType>
1053 inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
1054 bool is_ascend =
true);
// VectorizedSort, templated on Device (declaration only).
// Per the symbol summary (truncated on this page): sorts the keys within
// each segment; a stable sort is performed.
// The summary locates a definition in tensor_cpu-inl.h.
// values:   1-D tensor to sort.
// segments: 1-D tensor describing the segments (exact definition is cut off
//           in the summary -- see the original header documentation).
1063 template<
typename Device,
typename VDType,
typename SDType>
1064 inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments);
// MapExp for cpu destinations (declaration only).
// Per the symbol summary: maps an expression to a tensor; this function
// calls MapPlan. The summary locates the definition in tensor_cpu-inl.h.
// dst: destination TRValue (pointer).
// exp: expression to map.
1080 template<
typename Saver,
typename R,
int dim,
1081 typename DType,
typename E,
int etype>
1082 inline void MapExp(TRValue<R, cpu, dim, DType> *dst,
1083 const expr::Exp<E, DType, etype> &exp);
// MapExp for gpu destinations (declaration only; definition is not visible here).
// Same parameter layout as the cpu overload. The symbol summary labels
// MapExp "CPU/GPU: map a expression to a tensor", so this overload
// presumably has the same semantics -- confirm against the GPU implementation.
1096 template<
typename Saver,
typename R,
int dim,
1097 typename DType,
typename E,
int etype>
1098 inline void MapExp(TRValue<R, gpu, dim, DType> *dst,
1099 const expr::Exp<E, DType, etype> &exp);
1113 template<
typename Saver,
typename Reducer,
1114 typename R,
typename DType,
typename E,
int etype>
1116 const expr::Exp<E, DType, etype> &exp,
1131 template<
typename Saver,
typename Reducer,
typename R,
1132 typename DType,
typename E,
int etype>
1134 const expr::Exp<E, DType, etype> &exp,
1150 template<
typename Saver,
typename Reducer,
int dimkeep,
1151 typename R,
typename DType,
typename E,
int etype>
1153 const expr::Exp<E, DType, etype> &exp,
1169 template<
typename Saver,
typename Reducer,
int dimkeep,
1170 typename R,
typename DType,
typename E,
int etype>
1172 const expr::Exp<E, DType, etype> &exp,
// VectorDot, templated on Device (declaration only).
// Per the symbol summary: "CPU/GPU: 1 dimension vector dot."
// The summary locates a definition in tensor_cpu-inl.h.
// dst: 1-D output tensor (taken by value).
// lhs: 1-D left operand.
// rhs: 1-D right operand.
1180 template<
typename Device,
typename DType>
1181 inline void VectorDot(Tensor<Device, 1, DType> dst,
1182 const Tensor<Device, 1, DType> &lhs,
1183 const Tensor<Device, 1, DType> &rhs);
1193 template<
bool transpose_left,
bool transpose_right,
typename Device,
typename DType>
1194 inline void BatchGEMM(Tensor<Device, 3, DType> dst,
1195 const Tensor<Device, 3, DType> &lhs,
1196 const Tensor<Device, 3, DType> &rhs,
1199 Tensor<Device, 1, DType*> workspace);
1211 #ifdef MSHADOW_SCALAR_
1212 #error "MSHADOW_SCALAR_ must not be defined"
1215 #define MSHADOW_SCALAR_ float
1217 #undef MSHADOW_SCALAR_
1218 #define MSHADOW_SCALAR_ double
1220 #undef MSHADOW_SCALAR_
1221 #define MSHADOW_SCALAR_ int32_t
1223 #undef MSHADOW_SCALAR_
1224 #define MSHADOW_SCALAR_ int64_t
1226 #undef MSHADOW_SCALAR_
1227 #define MSHADOW_SCALAR_ mshadow::half::half_t
1229 #undef MSHADOW_SCALAR_
1230 #endif // MSHADOW_TENSOR_H_
definitions of abstract expressions and expression templates
@ kNCHW
Definition: base.h:501
static const bool kDevCPU
whether this device is CPU or not
Definition: tensor.h:48
implementation of GPU code
MSHADOW_XINLINE Tensor< Device, 1, DType > Slice(index_t begin, index_t end) const
Definition: tensor.h:700
MSHADOW_XINLINE Shape< 4 > Shape4(index_t s0, index_t s1, index_t s2, index_t s3)
construct a four dimension shape, stride will equal s0
Definition: tensor.h:254
void set_stream(Stream< Device > *stream)
Definition: tensor.h:691
void set_stream(Stream< Device > *stream)
set the stream to do computation of current tensor
Definition: tensor.h:576
MSHADOW_XINLINE index_t Size(void) const
Definition: tensor.h:158
MSHADOW_XINLINE Tensor< Device, 1, DType > FlatTo1D(void) const
Definition: tensor.h:694
void SortByKey(Tensor< cpu, 1, KDType > keys, Tensor< cpu, 1, VDType > values, bool is_ascend=true)
CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!...
Definition: tensor_cpu-inl.h:596
MSHADOW_XINLINE index_t MSize(void) const
Definition: tensor.h:602
computation stream structure, used for asynchronous computations
Definition: tensor.h:488
#define MSHADOW_DEFAULT_DTYPE
default data type for tensor string in code release, change it to default_real_t during development,...
Definition: base.h:240
MSHADOW_XINLINE index_t MSize(void) const
Definition: tensor.h:708
const MSHADOW_XINLINE index_t & operator[](int idx) const
get corresponding index
Definition: tensor.h:109
static const int kSubdim
dimension of subtype
Definition: tensor.h:534
static const int kDimension
dimension of current shape
Definition: tensor.h:82
@ kNWC
Definition: base.h:506
@ kNCDHW
Definition: base.h:509
static const int kSubdim
dimension of current shape minus one
Definition: tensor.h:84
Tensor< Device, 1, DType > & operator=(const expr::Exp< E, DType, etype > &exp)
Definition: tensor.h:731
Tensor RValue, this is the super type of all kinds of possible tensors.
Definition: tensor.h:514
definitions of I/O functions for mshadow tensor
void Copy(Tensor< cpu, dim, DType > dst, const Tensor< cpu, dim, DType > &src, Stream< cpu > *stream=NULL)
copy data from one tensor to another, with same shape
Definition: tensor_cpu-inl.h:145
Shape< 1 > shape_
Definition: tensor.h:677
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
Definition: tensor.h:697
void FreeSpace(Tensor< cpu, dim, DType > *obj)
CPU/GPU: free the space of tensor, will set obj.dptr to NULL.
Definition: tensor_cpu-inl.h:140
void BatchGEMM(Tensor< Device, 3, DType > dst, const Tensor< Device, 3, DType > &lhs, const Tensor< Device, 3, DType > &rhs, DType alpha, DType beta, Tensor< Device, 1, DType * > workspace)
CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst.
Definition: tensor_cpu-inl.h:648
void IndexFill(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 1, IndexType > &index, const Tensor< cpu, 2, DType > &src)
CPU/GPU: Fill the values of the destination matrix to specific rows in the source matrix....
Definition: tensor_cpu-inl.h:585
lapack_index_t IndexT
Definition: tensor.h:55
#define MSHADOW_XINLINE
Definition: base.h:228
MSHADOW_XINLINE Tensor(const Shape< dimension > &shape)
constructor from shape
Definition: tensor.h:558
void SoftmaxGrad(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 2, DType > &src, const Tensor< cpu, 1, DType > &label)
CPU/GPU: softmax gradient.
Definition: tensor_cpu-inl.h:311
MSHADOW_XINLINE Tensor(DType *dptr, const Shape< dimension > &shape, Stream< Device > *stream)
constructor from data pointer and shape, without stride
Definition: tensor.h:564
void DeleteStream(Stream< Device > *stream)
delete the computing stream
Stream< Device > * stream_
Definition: tensor.h:679
MSHADOW_XINLINE Shape< 1 > FlatTo1D(void) const
Definition: tensor.h:137
void MapReduceKeepLowest(TRValue< R, cpu, 1, DType > *dst, const expr::Exp< E, DType, etype > &exp, DType scale=1)
CPU/GPU: map an expression, do reduction to 1D Tensor in lowest dimension (dimension 0)
Definition: tensor_cpu-inl.h:228
MSHADOW_XINLINE Tensor< Device, 1, DType > FlatTo1D(void) const
flatten the tensor to 1 dimension
Definition: tensor.h:617
general tensor
Definition: tensor.h:525
MSHADOW_XINLINE Tensor(DType *dptr, Shape< 1 > shape, index_t stride, Stream< Device > *stream)
Definition: tensor.h:688
void VectorizedSort(Tensor< Device, 1, VDType > values, Tensor< Device, 1, SDType > segments)
CPU/GPU: Sort the keys within each segment. (Stable sort is performed!) Segments is defined as an asc...
Definition: tensor_cpu-inl.h:627
MSHADOW_XINLINE Tensor< Device, dimension, DType > Slice(index_t begin, index_t end) const
slice the tensor in highest dimension [begin,end)
Definition: tensor.h:643
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:43
void ShutdownTensorEngine(void)
Shutdown tensor engine on current device this function should be called after all GPU tensor operatio...
Stream< Device > * NewStream(bool create_blas_handle, bool create_dnn_handle, int dev_id=-1)
create a new stream from system
static const bool kDevCPU
whether this device is CPU or not
Definition: tensor.h:41
MSHADOW_XINLINE Shape< kSubdim > SubShape(void) const
Definition: tensor.h:183
std::vector< dim_t > getTranspAxes(const LayoutFlag src_layout, const LayoutFlag dst_layout)
returns axes of transpose operation that needs to be performed between src layout and dst
Definition: tensor.h:401
std::ostream & operator<<(std::ostream &os, const Shape< ndim > &shape)
allow string printing of the shape
Definition: tensor_cpu-inl.h:59
MSHADOW_XINLINE Shape< dimend - dimstart > Slice(void) const
slice the shape from start to end
Definition: tensor.h:199
@ kNHWC
Definition: base.h:502
Shape< 3 > ConvertLayout(const Shape< 3 > &src, int src_layout, int dst_layout)
Convert shape in src_layout to shape in dst_layout.
Definition: tensor.h:283
MSHADOW_XINLINE index_t size(index_t i) const
Definition: tensor.h:711
device name GPU
Definition: tensor.h:46
@ kCWN
Definition: base.h:507
implementation of GPU host code
device name CPU
Definition: tensor.h:39
@ kCHWN
Definition: base.h:503
Random inline functions for tensor.
void Softmax(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 2, DType > &energy)
CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) /(sum_j exp(energy[i][j]))
Definition: tensor_cpu-inl.h:488
MSHADOW_XINLINE index_t MemSize(void) const
Definition: tensor.h:584
LayoutFlag
Definition: base.h:498
Tensor< Device, dimension, DType > & operator=(const expr::Exp< E, DType, etype > &exp)
functions to fit expression template
Definition: tensor.h:661
Tensor< Device, 1, DType > & operator=(const Tensor< Device, 1, DType > &exp)
implement the assignment of same type
Definition: tensor.h:722
MSHADOW_XINLINE Tensor(DType *dptr, Shape< 1 > shape, Stream< Device > *stream)
Definition: tensor.h:686
Tensor< Device, dimension, DType > & operator=(const DType &exp)
functions to fit expression template
Definition: tensor.h:665
MSHADOW_XINLINE Tensor(DType *dptr, const Shape< dimension > &shape)
constructor from data pointer and shape, without stride
Definition: tensor.h:561
index_t stride_
Definition: tensor.h:678
MSHADOW_XINLINE Shape< 5 > Shape5(index_t s0, index_t s1, index_t s2, index_t s3, index_t s4)
construct a five dimension shape, stride will equal s0
Definition: tensor.h:269
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:86
Stream< Device > * stream_
stream where the computation lies stream is a device dependency concept where each computation
Definition: tensor.h:551
void SetDevice(int devid)
set the device of current thread to work on
Tensor< Device, 1, DType > & operator=(const DType &exp)
Definition: tensor.h:734
implementation of CPU host code
MSHADOW_XINLINE Shape< 2 > FlatTo2D(void) const
Definition: tensor.h:146
some extension of expressions, used to support something beyond elementwise op
@ kCDHWN
Definition: base.h:511
MSHADOW_XINLINE bool CheckContiguous(void) const
Definition: tensor.h:596
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:541
MSHADOW_XINLINE bool operator!=(const Shape< kDimension > &s) const
Definition: tensor.h:130
int32_t index_t
type that will be used for index
Definition: base.h:328
#define MSHADOW_ALLOC_PAD
whether do padding during allocation
Definition: base.h:73
MSHADOW_XINLINE Shape(void)
default constructor, do nothing
Definition: tensor.h:88
DType * dptr_
Definition: tensor.h:676
base class of all rvalues
Definition: expression.h:148
void AddTakeGradLargeBatch(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 1, IndexType > &sorted, const Tensor< cpu, 1, IndexType > &index, const Tensor< cpu, 2, DType > &src)
CPU/GPU: Gradient accumulate of embedding matrix with safe accumulation. dst[index[i]] += src[i].
Definition: tensor_cpu-inl.h:575
int lapack_index_t
Definition: base.h:344
void Wait(void)
wait for all the computations associated with this stream to complete
Definition: tensor.h:495
@ kNDHWC
Definition: base.h:510
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
PaddingExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > pad(const Exp< SrcExp, DType, etype > &src, index_t pad)
padding expression, pad an image with zeros on boundaries, padding affects shape[0],...
Definition: pad.h:71
MSHADOW_XINLINE Tensor(const Shape< 1 > &shape)
Definition: tensor.h:682
void CreateBlasHandle()
create a blas handle
Definition: tensor.h:504
MSHADOW_XINLINE DType & operator[](index_t idx)
Definition: tensor.h:714
void AllocSpace(Tensor< cpu, dim, DType > *obj, bool pad=MSHADOW_ALLOC_PAD)
CPU/GPU: allocate space for CTensor, according to the shape in the obj; this function is responsible t...
Definition: tensor_cpu-inl.h:116
MSHADOW_XINLINE Tensor(DType *dptr, Shape< 1 > shape)
Definition: tensor.h:684
MSHADOW_XINLINE bool CheckContiguous(void) const
Definition: tensor.h:705
MSHADOW_XINLINE Tensor< Device, kSubdim, DType > operator[](index_t idx) const
get an element of dimension - 1
Definition: tensor.h:632
void MapReduceKeepHighDim(TRValue< R, cpu, 1, DType > *dst, const expr::Exp< E, DType, etype > &exp, DType scale=1)
CPU/GPU: map an expression, do reduction to 1D Tensor in third dimension (dimension 2)
Definition: tensor_cpu-inl.h:255
overloaded + operator between half_t and bf16_t
Definition: base.h:319
MSHADOW_XINLINE Tensor< Device, 2, DType > FlatTo2D(void) const
flatten the tensor to 2 dimension, collapse the higher dimensions together
Definition: tensor.h:624
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two dimension shape, stride will equal s0
Definition: tensor.h:230
static const bool kDevCPU
whether current type lies in cpu
Definition: tensor.h:532
int IndexT
Definition: tensor.h:60
@ kUNKNOWN
Definition: base.h:499
MSHADOW_XINLINE Shape(const Shape< kDimension > &s)
constructor
Definition: tensor.h:90
shape of a tensor
Definition: tensor.h:64
DType * dptr_
pointer to the data
Definition: tensor.h:539
tensor container that does memory allocation and resize like STL
@ kNCW
Definition: base.h:505
MSHADOW_XINLINE Shape< 1 > Shape1(index_t s0)
construct a one dimension shape, stride will equal s0
Definition: tensor.h:220
void InitTensorEngine(int device_id=0)
initialize tensor engine, used to call initialization functions of dependent libs; this function should...
void VectorDot(Tensor< Device, 1, DType > dst, const Tensor< Device, 1, DType > &lhs, const Tensor< Device, 1, DType > &rhs)
CPU/GPU: 1 dimension vector dot.
Definition: tensor_cpu-inl.h:635
MSHADOW_XINLINE index_t size(int idx) const
return size of i-th dimension, start counting from highest dimension
Definition: tensor.h:610
Tensor< Device, dimension, DType > & __assign(DType s)
operator overload
Definition: expression.h:178
Tensor< Device, dim, DType > NewTensor(const Shape< dim > &shape, DType initv, bool pad=MSHADOW_ALLOC_PAD, Stream< Device > *stream=NULL)
CPU/GPU: short cut to allocate and initialize a Tensor.
Definition: tensor_cpu-inl.h:132
void AddTakeGrad(Tensor< cpu, 2, DType > dst, const Tensor< cpu, 1, IndexType > &index, const Tensor< cpu, 2, DType > &src)
CPU/GPU: Gradient accumulate of embedding matrix. dst[index[i]] += src[i] Called when the featuredim ...
Definition: tensor_cpu-inl.h:521
MSHADOW_XINLINE index_t & operator[](int idx)
get corresponding index
Definition: tensor.h:101
bool CheckIdle(void)
query whether the stream is idle
Definition: tensor.h:500
MSHADOW_XINLINE Tensor(void)
Definition: tensor.h:681
static const int kDevMask
device flag number, identifies this device
Definition: tensor.h:50
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:171
MSHADOW_XINLINE Tensor(DType *dptr, const Shape< dimension > &shape, index_t stride, Stream< Device > *stream)
constructor from data pointer and shape
Definition: tensor.h:568
definitions of how expressions should be evaluated
definitions of operators in expression with respect to scalar this file will be included several time...
MSHADOW_XINLINE Tensor(void)
default constructor
Definition: tensor.h:556
definitions of base types, operators, macros, functions
MSHADOW_XINLINE bool operator==(const Shape< kDimension > &s) const
Definition: tensor.h:119
MSHADOW_XINLINE Shape< 3 > Shape3(index_t s0, index_t s1, index_t s2)
construct a three dimension shape, stride will equal s0
Definition: tensor.h:241
const MSHADOW_XINLINE DType & operator[](index_t idx) const
Definition: tensor.h:717
Tensor< Device, dimension, DType > & operator=(const Tensor< Device, dimension, DType > &exp)
implement the assignment of same type
Definition: tensor.h:651
index_t stride_
storing the stride information in x dimension this is used to deal with pitch allocation in gpu or ss...
Definition: tensor.h:546
void MapExp(TRValue< R, cpu, dim, DType > *dst, const expr::Exp< E, DType, etype > &exp)
CPU/GPU: map an expression to a tensor, this function calls MapPlan.
Definition: tensor_cpu-inl.h:212