tensor.h
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef MSHADOW_TENSOR_H_
#define MSHADOW_TENSOR_H_
#include <string>
#include <iostream>
#include <vector>  // std::vector, used by getTranspAxes below
#include "./base.h"
#include "./expression.h"

namespace mshadow {
/*! \brief device name CPU */
struct cpu {
  /*! \brief whether this device is CPU or not */
  static const bool kDevCPU = true;
  /*! \brief device flag number, identifies this device */
  static const int kDevMask = 1 << 0;
};
/*! \brief device name GPU */
struct gpu {
  /*! \brief whether this device is CPU or not */
  static const bool kDevCPU = false;
  /*! \brief device flag number, identifies this device */
  static const int kDevMask = 1 << 1;
};

template <typename xpu>
struct LapackIndex {
  using IndexT = lapack_index_t;
};

template <>
struct LapackIndex<gpu> {
  using IndexT = int;
};

template<int ndim>
struct Shape;

/*! \brief allow string printing of the shape */
template<int ndim>
inline std::ostream &operator<<(std::ostream &os, const Shape<ndim> &shape);  // NOLINT(*)

/*! \brief shape of a tensor */
template<int dimension>
struct Shape {
  /*! \brief dimension of current shape */
  static const int kDimension = dimension;
  /*! \brief dimension of current shape minus one */
  static const int kSubdim = dimension - 1;
  /*! \brief storing the dimension information */
  index_t shape_[kDimension];
  /*! \brief default constructor, do nothing */
  MSHADOW_XINLINE Shape(void) {}
  /*! \brief constructor from another shape */
  MSHADOW_XINLINE Shape(const Shape<kDimension> &s) {
    #pragma unroll
    for (int i = 0; i < kDimension; ++i) {
      this->shape_[i] = s[i];
    }
  }
  /*! \brief get corresponding index */
  MSHADOW_XINLINE index_t &operator[](int idx) {
    return shape_[idx];
  }
  /*! \brief get corresponding index */
  MSHADOW_XINLINE const index_t &operator[](int idx) const {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
    return shape_[idx];
#pragma GCC diagnostic pop
  }
  /*! \brief whether two shapes are equal */
  MSHADOW_XINLINE bool operator==(const Shape<kDimension> &s) const {
    #pragma unroll
    for (int i = 0; i < kDimension; ++i) {
      if (s.shape_[i] != this->shape_[i]) return false;
    }
    return true;
  }
  /*! \brief whether two shapes are not equal */
  MSHADOW_XINLINE bool operator!=(const Shape<kDimension> &s) const {
    return !(*this == s);
  }
  /*! \brief flatten the shape into one dimension */
  MSHADOW_XINLINE Shape<1> FlatTo1D(void) const {
    Shape<1> s;
    s[0] = this->Size();
    return s;
  }
  /*! \brief flatten the higher dimensions into the first axis, returning a 2D shape */
  MSHADOW_XINLINE Shape<2> FlatTo2D(void) const {
    Shape<2> s;
    s.shape_[1] = this->shape_[kDimension - 1];
    index_t ymax = 1;
    #pragma unroll
    for (int i = 0; i < kDimension - 1; ++i) {
      ymax *= this->shape_[i];
    }
    s.shape_[0] = ymax;
    return s;
  }
  /*! \brief total number of elements in the shape */
  MSHADOW_XINLINE index_t Size(void) const {
    index_t size = this->shape_[0];
    #pragma unroll
    for (int i = 1; i < kDimension; ++i) {
      size *= this->shape_[i];
    }
    return size;
  }
  /*! \brief product of the shape over the range [dimstart, dimend) */
  MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const {
    index_t num = 1;
    #pragma unroll
    for (int i = dimstart; i < dimend; ++i) {
      num *= this->shape_[i];
    }
    return num;
  }
  /*! \brief the shape with dimension 0 removed */
  MSHADOW_XINLINE Shape<kSubdim> SubShape(void) const {
    Shape<kSubdim> s;
    // for cuda
    #pragma unroll
    for (int i = 0; i < kSubdim; ++i) {
      s.shape_[i] = this->shape_[i + 1];
    }
    return s;
  }
  /*! \brief slice the shape from start to end */
  template<int dimstart, int dimend>
  MSHADOW_XINLINE Shape<dimend - dimstart> Slice(void) const {
    Shape<dimend - dimstart> s;
    #pragma unroll
    for (int i = dimstart; i < dimend; ++i) {
      s[i - dimstart] = this->shape_[i];
    }
    return s;
  }
  template<int dim>
  friend std::ostream &operator<<(std::ostream &os, const Shape<dim> &shape);  // NOLINT(*)
};  // Shape
//------------------------------------------------
// useful construction functions to generate shape
//-------------------------------------------------
/*! \brief construct a one dimension shape, stride will equal s0 */
MSHADOW_XINLINE Shape<1> Shape1(index_t s0) {
  Shape<1> s; s[0] = s0;
  return s;
}
/*! \brief construct a two dimension shape, stride will equal s0 */
MSHADOW_XINLINE Shape<2> Shape2(index_t s0, index_t s1) {
  Shape<2> s; s[0] = s0; s[1] = s1;
  return s;
}
/*! \brief construct a three dimension shape, stride will equal s0 */
MSHADOW_XINLINE Shape<3> Shape3(index_t s0, index_t s1, index_t s2) {
  Shape<3> s;
  s[0] = s0; s[1] = s1; s[2] = s2;
  return s;
}
/*! \brief construct a four dimension shape, stride will equal s0 */
MSHADOW_XINLINE Shape<4> Shape4(index_t s0, index_t s1,
                                index_t s2, index_t s3) {
  Shape<4> s;
  s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3;
  return s;
}
/*! \brief construct a five dimension shape, stride will equal s0 */
MSHADOW_XINLINE Shape<5> Shape5(index_t s0, index_t s1, index_t s2,
                                index_t s3, index_t s4) {
  Shape<5> s;
  s[0] = s0; s[1] = s1; s[2] = s2; s[3] = s3; s[4] = s4;
  return s;
}
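// Usage sketch (illustrative): building a shape and querying it with the
// members defined above.
//
//   Shape<3> s = Shape3(2, 3, 4);
//   index_t total = s.Size();        // 24
//   Shape<2> flat = s.FlatTo2D();    // (6, 4): higher dims collapsed into dim 0
//   Shape<2> sub  = s.SubShape();    // (3, 4): drops dimension 0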

/*! \brief Convert shape in src_layout to shape in dst_layout */
inline Shape<3> ConvertLayout(const Shape<3>& src, int src_layout, int dst_layout) {
  Shape<3> dst;
  switch (src_layout) {
  case kNCW:
    dst = src;
    break;
  case kNWC:
    dst[0] = src[0];
    dst[1] = src[2];
    dst[2] = src[1];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 3d shape " << src_layout;
  }
  switch (dst_layout) {
  case kNCW:
    return dst;
  case kNWC:
    {
      index_t tmp = dst[1];
      dst[1] = dst[2];
      dst[2] = tmp;
    }
    break;
  default:
    LOG(FATAL) << "Invalid layout for 3d shape " << dst_layout;
  }
  return dst;
}

/*! \brief Convert shape in src_layout to shape in dst_layout */
inline Shape<4> ConvertLayout(const Shape<4>& src, int src_layout, int dst_layout) {
  Shape<4> dst;
  switch (src_layout) {
  case kNCHW:
    dst = src;
    break;
  case kNHWC:
    dst[0] = src[0];
    dst[2] = src[1];
    dst[3] = src[2];
    dst[1] = src[3];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 4d shape " << src_layout;
    dst = src;  // fixes compiler warning
  }
  Shape<4> dst2;
  switch (dst_layout) {
  case kNCHW:
    return dst;
  case kNHWC:
    dst2[0] = dst[0];
    dst2[1] = dst[2];
    dst2[2] = dst[3];
    dst2[3] = dst[1];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 4d shape " << dst_layout;
    dst2 = src;  // fixes compiler warning
  }
  return dst2;
}

/*! \brief Convert shape in src_layout to shape in dst_layout */
inline Shape<5> ConvertLayout(const Shape<5>& src, int src_layout, int dst_layout) {
  Shape<5> dst;
  switch (src_layout) {
  case kNCDHW:
    dst = src;
    break;
  case kNDHWC:
    dst[0] = src[0];
    dst[2] = src[1];
    dst[3] = src[2];
    dst[4] = src[3];
    dst[1] = src[4];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 5d shape " << src_layout;
  }
  Shape<5> dst2;
  switch (dst_layout) {
  case kNCDHW:
    return dst;
  case kNDHWC:
    dst2[0] = dst[0];
    dst2[1] = dst[2];
    dst2[2] = dst[3];
    dst2[3] = dst[4];
    dst2[4] = dst[1];
    break;
  default:
    LOG(FATAL) << "Invalid layout for 5d shape " << dst_layout;
  }
  return dst2;
}
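// Usage sketch (illustrative): converting a 4d shape between layouts.
//
//   Shape<4> nchw = Shape4(1, 3, 32, 32);               // N=1, C=3, H=32, W=32
//   Shape<4> nhwc = ConvertLayout(nchw, kNCHW, kNHWC);  // (1, 32, 32, 3)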

/*!
 * \brief returns axes of the transpose operation that needs to be performed
 *        between src layout and dst layout
 */
template <typename dim_t>
inline std::vector<dim_t> getTranspAxes(const LayoutFlag src_layout, const LayoutFlag dst_layout) {
  auto apply = [](const std::vector<dim_t>& v, const std::vector<dim_t>& op) {
    CHECK_EQ(v.size(), op.size()) << "Layout ndims does not match";
    std::vector<dim_t> ret(v.size());
    for (size_t i = 0; i < v.size(); i++) {
      ret[i] = v[op[i]];
    }
    return ret;
  };
  std::vector<dim_t> axes;
  // transpose from `case` to ND?H?WC
  switch (src_layout) {
    case kUNKNOWN:
      LOG(FATAL) << "Unknown source layout";
      break;
    case kNHWC:
      axes = std::vector<dim_t>({0, 1, 2, 3});
      break;
    case kNCHW:
      axes = std::vector<dim_t>({0, 2, 3, 1});
      break;
    case kCHWN:
      axes = std::vector<dim_t>({3, 1, 2, 0});
      break;
    case kNWC:
      axes = std::vector<dim_t>({0, 1, 2});
      break;
    case kNCW:
      axes = std::vector<dim_t>({0, 2, 1});
      break;
    case kCWN:
      axes = std::vector<dim_t>({2, 1, 0});
      break;
    case kNDHWC:
      axes = std::vector<dim_t>({0, 1, 2, 3, 4});
      break;
    case kNCDHW:
      axes = std::vector<dim_t>({0, 2, 3, 4, 1});
      break;
    case kCDHWN:
      axes = std::vector<dim_t>({4, 1, 2, 3, 0});
      break;
    default:
      LOG(FATAL) << "Invalid source layout " << src_layout;
  }
  // transpose from ND?H?WC to `case`
  switch (dst_layout) {
    case kUNKNOWN:
      LOG(FATAL) << "Unknown destination layout";
      break;
    case kNHWC:
      axes = apply(axes, {0, 1, 2, 3});
      break;
    case kNCHW:
      axes = apply(axes, {0, 3, 1, 2});
      break;
    case kCHWN:
      axes = apply(axes, {3, 1, 2, 0});
      break;
    case kNWC:
      axes = apply(axes, {0, 1, 2});
      break;
    case kNCW:
      axes = apply(axes, {0, 2, 1});
      break;
    case kCWN:
      axes = apply(axes, {2, 1, 0});
      break;
    case kNDHWC:
      axes = apply(axes, {0, 1, 2, 3, 4});
      break;
    case kNCDHW:
      axes = apply(axes, {0, 4, 1, 2, 3});
      break;
    case kCDHWN:
      axes = apply(axes, {4, 1, 2, 3, 0});
      break;
    default:
      LOG(FATAL) << "Invalid destination layout " << dst_layout;
  }
  return axes;
}
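// Usage sketch (illustrative): axes needed to transpose NCHW data into NHWC.
//
//   std::vector<index_t> axes = getTranspAxes<index_t>(kNCHW, kNHWC);
//   // axes == {0, 2, 3, 1}: output dimension i is taken from input dimension axes[i]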

/*! \brief computation stream structure, used for asynchronous computations */
template<typename Device>
struct Stream {
  // this is only a dummy implementation for CPU
  // for GPU, the actual implementation will be specialized in tensor_gpu-inl.h
  /*! \brief wait for all the computations associated with this stream to complete */
  inline void Wait(void) {}
  /*! \brief query whether the stream is idle */
  inline bool CheckIdle(void) {
    return true;
  }
  /*! \brief create a blas handle */
  inline void CreateBlasHandle() {}
};
/*! \brief Tensor RValue, this is the super type of all kinds of possible tensors */
template<typename Container, typename Device, int dimension, typename DType>
struct TRValue: public expr::RValueExp<Container, DType> {
};
// more compact template
/*! \brief general tensor */
template<typename Device, int dimension,
         typename DType MSHADOW_DEFAULT_DTYPE>
struct Tensor: public TRValue<Tensor<Device, dimension, DType>,
                              Device, dimension, DType> {
 public:
  //--------------------------------
  // struct members
  //--------------------------------
  /*! \brief whether current type lies in cpu */
  static const bool kDevCPU = Device::kDevCPU;
  /*! \brief dimension of subtype */
  static const int kSubdim = dimension - 1;
  //--------------------------------
  // struct members
  //--------------------------------
  /*! \brief pointer to the data */
  DType *dptr_ = nullptr;
  /*! \brief shape of the tensor */
  Shape<dimension> shape_;
  /*!
   * \brief storing the stride information in x dimension;
   *        used to deal with pitch allocation on gpu or with sse alignment
   */
  index_t stride_;
  /*!
   * \brief stream where the computation lies;
   *        a device dependency concept where each computation is bound to a stream
   */
  Stream<Device> *stream_;
  //--------------------------------
  // functions
  //--------------------------------
  /*! \brief default constructor */
  MSHADOW_XINLINE Tensor(void) : stream_(NULL) {}
  /*! \brief constructor from shape */
  MSHADOW_XINLINE Tensor(const Shape<dimension> &shape)
      : shape_(shape), stream_(NULL) {}
  /*! \brief constructor from data pointer and shape, without stride */
  MSHADOW_XINLINE Tensor(DType *dptr, const Shape<dimension> &shape)
      : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(NULL) {}
  /*! \brief constructor from data pointer, shape and stream, without stride */
  MSHADOW_XINLINE Tensor(DType *dptr, const Shape<dimension> &shape,
                         Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(shape[kSubdim]), stream_(stream) {}
  /*! \brief constructor from data pointer, shape and stride */
  MSHADOW_XINLINE Tensor(DType *dptr,
                         const Shape<dimension> &shape,
                         index_t stride, Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {}
  /*! \brief set the stream to do computation of current tensor */
  inline void set_stream(Stream<Device> *stream) {
    this->stream_ = stream;
  }
  /*! \brief memory cost of the tensor starting from startdim, including the aligned x dimension */
  template<int startdim>
  MSHADOW_XINLINE index_t MemSize(void) const {
    index_t memsz = this->stride_;
    #pragma unroll
    for (int i = startdim; i < kSubdim; ++i) {
      memsz *= this->shape_[i];
    }
    return memsz;
  }
  /*! \brief whether the memory of this tensor is contiguous */
  MSHADOW_XINLINE bool CheckContiguous(void) const {
    return this->shape_[dimension - 1] == stride_;
  }
  /*! \brief memory cost of the whole tensor, including the aligned x dimension */
  MSHADOW_XINLINE index_t MSize(void) const {
    return this->MemSize<0>();
  }
  /*! \brief return size of i-th dimension, start counting from highest dimension */
  MSHADOW_XINLINE index_t size(int idx) const {
    return shape_[idx];
  }
  /*! \brief flatten the tensor to 1 dimension */
  MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const {
    return Tensor<Device, 1, DType>(dptr_, shape_.FlatTo1D(), stride_, stream_);
  }
  /*! \brief flatten the tensor to 2 dimension, collapse the higher dimensions together */
  MSHADOW_XINLINE Tensor<Device, 2, DType> FlatTo2D(void) const {
    return Tensor<Device, 2, DType>(dptr_, shape_.FlatTo2D(), stride_, stream_);
  }
  /*! \brief get an element of dimension - 1 */
  MSHADOW_XINLINE Tensor<Device, kSubdim, DType> operator[](index_t idx) const {
    return Tensor<Device, kSubdim, DType>(dptr_ + this->MemSize<1>() * idx,
                                          shape_.SubShape(), stride_, stream_);
  }
  /*! \brief slice the tensor in highest dimension [begin,end) */
  MSHADOW_XINLINE Tensor<Device, dimension, DType>
  Slice(index_t begin, index_t end) const {
    Shape<dimension> s = this->shape_;
    s[0] = end - begin;
    return Tensor<Device, dimension, DType>(dptr_ + this->MemSize<1>() * begin,
                                            s, stride_, stream_);
  }
  /*! \brief implement the assignment of same type */
  inline Tensor<Device, dimension, DType> &
  operator=(const Tensor<Device, dimension, DType> &exp) {
    dptr_ = exp.dptr_;
    shape_ = exp.shape_;
    stride_ = exp.stride_;
    stream_ = exp.stream_;
    return *this;
  }
  /*! \brief functions to fit expression template */
  template<typename E, int etype>
  inline Tensor<Device, dimension, DType> &
  operator=(const expr::Exp<E, DType, etype> &exp) {
    return this->__assign(exp);
  }
  /*! \brief functions to fit expression template */
  inline Tensor<Device, dimension, DType> &operator=(const DType &exp) {
    return this->__assign(exp);
  }
};
/*
 * respecialized class Tensor1D, this is due to different implementation in operator[]
 */
template<typename Device, typename DType>
struct Tensor<Device, 1, DType>:
      public TRValue<Tensor<Device, 1, DType>, Device, 1, DType> {
 public:
  DType *dptr_;
  Shape<1> shape_;
  index_t stride_;
  Stream<Device> *stream_;
  // constructor
  MSHADOW_XINLINE Tensor(void) : stream_(NULL) {}
  MSHADOW_XINLINE Tensor(const Shape<1> &shape)
      : shape_(shape), stream_(NULL) {}
  MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape)
      : dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(NULL) {}
  MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape, Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(shape[0]), stream_(stream) {}
  MSHADOW_XINLINE Tensor(DType *dptr, Shape<1> shape,
                         index_t stride, Stream<Device> *stream)
      : dptr_(dptr), shape_(shape), stride_(stride), stream_(stream) {}
  inline void set_stream(Stream<Device> *stream) {
    this->stream_ = stream;
  }
  MSHADOW_XINLINE Tensor<Device, 1, DType> FlatTo1D(void) const {
    return *this;
  }
  MSHADOW_XINLINE Tensor<Device, 2, DType> FlatTo2D(void) const {
    return Tensor<Device, 2, DType>(dptr_, shape_.FlatTo2D(), stride_, stream_);
  }
  MSHADOW_XINLINE Tensor<Device, 1, DType> Slice(index_t begin, index_t end) const {
    Shape<1> s;
    s[0] = end - begin;
    return Tensor<Device, 1, DType>(dptr_ + begin, s, s[0], stream_);
  }
  MSHADOW_XINLINE bool CheckContiguous(void) const {
    return true;
  }
  MSHADOW_XINLINE index_t MSize(void) const {
    return shape_[0];
  }
  MSHADOW_XINLINE index_t size(index_t i) const {
    return shape_[0];
  }
  MSHADOW_XINLINE DType &operator[](index_t idx) {
    return dptr_[idx];
  }
  MSHADOW_XINLINE const DType &operator[](index_t idx) const {
    return dptr_[idx];
  }
  /*! \brief implement the assignment of same type */
  inline Tensor<Device, 1, DType> &
  operator=(const Tensor<Device, 1, DType> &exp) {
    dptr_ = exp.dptr_;
    shape_ = exp.shape_;
    stride_ = exp.stride_;
    stream_ = exp.stream_;
    return *this;
  }
  template<typename E, int etype>
  inline Tensor<Device, 1, DType> &
  operator=(const expr::Exp<E, DType, etype> &exp) {
    return this->__assign(exp);
  }
  inline Tensor<Device, 1, DType> &operator=(const DType &exp) {
    return this->__assign(exp);
  }
};
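// Usage sketch (illustrative): wrapping an existing buffer and slicing it.
//
//   float data[6] = {0, 1, 2, 3, 4, 5};
//   Tensor<cpu, 2, float> t(data, Shape2(2, 3));   // stride_ == 3
//   Tensor<cpu, 1, float> row = t[1];              // second row: {3, 4, 5}
//   float v = row[2];                              // 5, via the 1D specialization above
//   Tensor<cpu, 2, float> top = t.Slice(0, 1);     // rows [0, 1): a 1 x 3 view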
//------------------------
// Function Declarations
//-----------------------
/*!
 * \brief initialize tensor engine, used to call initialization functions of dependent libs;
 *        this function should be called before all GPU tensor operations
 */
template<typename Device>
inline void InitTensorEngine(int device_id = 0);
/*!
 * \brief shutdown tensor engine on current device;
 *        this function should be called after all GPU tensor operations
 */
template<typename Device>
inline void ShutdownTensorEngine(void);
/*! \brief set the device of current thread to work on */
template<typename Device>
inline void SetDevice(int devid);
/*! \brief create a new stream from system */
template<typename Device>
inline Stream<Device> *NewStream(bool create_blas_handle,
                                 bool create_dnn_handle,
                                 int dev_id = -1);
/*! \brief create a new stream with a blas handle but no dnn handle */
template<typename Device>
inline Stream<Device> *NewStream(int dev_id) {
  return NewStream<Device>(true, false, dev_id);
}
/*! \brief delete the computing stream */
template<typename Device>
inline void DeleteStream(Stream<Device> *stream);
/*!
 * \brief CPU/GPU: allocate space for the tensor according to the shape in the obj;
 *        this function is responsible for setting the stride_ in obj
 */
template<int dim, typename DType>
inline void AllocSpace(Tensor<cpu, dim, DType> *obj,
                       bool pad = MSHADOW_ALLOC_PAD);
/*!
 * \brief CPU/GPU: allocate space for the tensor according to the shape in the obj;
 *        this function is responsible for setting the stride_ in obj
 */
template<int dim, typename DType>
inline void AllocSpace(Tensor<gpu, dim, DType> *obj,
                       bool pad = MSHADOW_ALLOC_PAD);
/*! \brief CPU/GPU: free the space of tensor, will set obj.dptr_ to NULL */
template<int dim, typename DType>
inline void FreeSpace(Tensor<cpu, dim, DType> *obj);
/*! \brief CPU/GPU: free the space of tensor, will set obj.dptr_ to NULL */
template<int dim, typename DType>
inline void FreeSpace(Tensor<gpu, dim, DType> *obj);
/*! \brief CPU/GPU: short cut to allocate and initialize a Tensor */
template<typename Device, typename DType, int dim>
inline Tensor<Device, dim, DType> NewTensor(const Shape<dim> &shape,
                                            DType initv,
                                            bool pad = MSHADOW_ALLOC_PAD,
                                            Stream<Device> *stream = NULL);
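// Usage sketch (illustrative): the typical allocation lifecycle on CPU.
//
//   Tensor<cpu, 2, float> t = NewTensor<cpu>(Shape2(4, 5), 0.0f);  // allocated, zero-filled
//   t[2][3] = 1.0f;
//   FreeSpace(&t);                                                 // dptr_ set back to NULL
//
// or, with an uninitialized tensor:
//
//   Tensor<cpu, 2, float> u(Shape2(4, 5));
//   AllocSpace(&u);   // sets u.dptr_ and u.stride_
//   FreeSpace(&u);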
/*! \brief copy data from one tensor to another, with same shape (CPU to CPU) */
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
                 const Tensor<cpu, dim, DType> &src,
                 Stream<cpu> *stream = NULL);
/*! \brief copy data from one tensor to another, with same shape (GPU to CPU) */
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream = NULL);
/*! \brief copy data from one tensor to another, with same shape (CPU to GPU) */
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<cpu, dim, DType> &src,
                 Stream<gpu> *stream = NULL);
/*! \brief copy data from one tensor to another, with same shape (GPU to GPU) */
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
                 const Tensor<gpu, dim, DType> &src,
                 Stream<gpu> *stream = NULL);
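// Usage sketch (illustrative): copying between two same-shape tensors.
//
//   Tensor<cpu, 2, float> a = NewTensor<cpu>(Shape2(2, 2), 1.0f);
//   Tensor<cpu, 2, float> b = NewTensor<cpu>(Shape2(2, 2), 0.0f);
//   Copy(b, a);   // b now holds the contents of a
//   FreeSpace(&a); FreeSpace(&b);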
/*! \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) / (sum_j exp(energy[i][j])) */
template<typename DType>
inline void Softmax(Tensor<cpu, 2, DType> dst, const Tensor<cpu, 2, DType> &energy);
/*! \brief CPU/GPU: normalize softmax: dst[i][j] = exp(energy[i][j]) / (sum_j exp(energy[i][j])) */
template<typename DType>
inline void Softmax(Tensor<gpu, 2, DType> dst, const Tensor<gpu, 2, DType> &energy);
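// Usage sketch (illustrative): row-wise softmax on CPU.
//
//   Tensor<cpu, 2, float> energy = NewTensor<cpu>(Shape2(2, 3), 1.0f);
//   Tensor<cpu, 2, float> prob   = NewTensor<cpu>(Shape2(2, 3), 0.0f);
//   Softmax(prob, energy);   // each row of prob now sums to 1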

/*! \brief CPU/GPU: softmax gradient */
template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
                        const Tensor<cpu, 2, DType> &src,
                        const Tensor<cpu, 1, DType> &label);
/*! \brief CPU/GPU: softmax gradient */
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
                        const Tensor<gpu, 2, DType> &src,
                        const Tensor<gpu, 1, DType> &label);
/*! \brief CPU/GPU: Gradient accumulate of embedding matrix. dst[index[i]] += src[i] */
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
                        const Tensor<cpu, 1, IndexType>& index,
                        const Tensor<cpu, 2, DType> &src);
/*! \brief CPU/GPU: Gradient accumulate of embedding matrix with safe accumulation. dst[index[i]] += src[i] */
template<bool clip = true, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<cpu, 2, DType> dst,
                        Tensor<cpu, 2, AType> temp,
                        const Tensor<cpu, 1, IndexType>& index,
                        const Tensor<cpu, 2, DType> &src);
/*! \brief CPU/GPU: Gradient accumulate of embedding matrix. dst[index[i]] += src[i] */
template<bool clip = true, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                        const Tensor<gpu, 1, IndexType>& index,
                        const Tensor<gpu, 2, DType> &src);
/*! \brief CPU/GPU: Gradient accumulate of embedding matrix with safe accumulation. dst[index[i]] += src[i] */
template<bool clip = true, typename IndexType, typename DType, typename AType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
                        Tensor<gpu, 2, AType> temp,
                        const Tensor<gpu, 1, IndexType>& index,
                        const Tensor<gpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: Gradient accumulate of embedding matrix for large batches,
 *        using an index that has been pre-sorted
 */
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<cpu, 2, DType> dst,
                                  const Tensor<cpu, 1, IndexType>& sorted,
                                  const Tensor<cpu, 1, IndexType>& index,
                                  const Tensor<cpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: Gradient accumulate of embedding matrix for large batches,
 *        using an index that has been pre-sorted
 */
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
                                  const Tensor<gpu, 1, IndexType>& sorted,
                                  const Tensor<gpu, 1, IndexType>& index,
                                  const Tensor<gpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: Fill specific rows of the destination matrix with values
 *        from the source matrix: dst[index[i]] = src[i]
 */
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<cpu, 2, DType> dst,
                      const Tensor<cpu, 1, IndexType>& index,
                      const Tensor<cpu, 2, DType> &src);
/*!
 * \brief CPU/GPU: Fill specific rows of the destination matrix with values
 *        from the source matrix: dst[index[i]] = src[i]
 */
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<gpu, 2, DType> dst,
                      const Tensor<gpu, 1, IndexType>& index,
                      const Tensor<gpu, 2, DType> &src);
/*! \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!) */
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<cpu, 1, KDType> keys, Tensor<cpu, 1, VDType> values,
                      bool is_ascend = true);
/*! \brief CPU/GPU: Sort key-value pairs stored in separate places. (Stable sort is performed!) */
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
                      bool is_ascend = true);
/*!
 * \brief CPU/GPU: Sort the keys within each segment. (Stable sort is performed!)
 *        Segments are defined as an ascending sequence of segment ids.
 */
template<typename Device, typename VDType, typename SDType>
inline void VectorizedSort(Tensor<Device, 1, VDType> values, Tensor<Device, 1, SDType> segments);

// function declarations to support expression, no need to understand them
// these functions do not need to be directly used
/*! \brief CPU/GPU: map an expression to a tensor, this function calls MapPlan */
template<typename Saver, typename R, int dim,
         typename DType, typename E, int etype>
inline void MapExp(TRValue<R, cpu, dim, DType> *dst,
                   const expr::Exp<E, DType, etype> &exp);
/*! \brief CPU/GPU: map an expression to a tensor, this function calls MapPlan */
template<typename Saver, typename R, int dim,
         typename DType, typename E, int etype>
inline void MapExp(TRValue<R, gpu, dim, DType> *dst,
                   const expr::Exp<E, DType, etype> &exp);
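// These are dispatched automatically by Tensor::operator= on an expression;
// a user-level sketch (illustrative):
//
//   Tensor<cpu, 2, float> a = NewTensor<cpu>(Shape2(2, 2), 1.0f);
//   Tensor<cpu, 2, float> b = NewTensor<cpu>(Shape2(2, 2), 0.0f);
//   b = a * 2.0f + 1.0f;   // element-wise, evaluated through MapExp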
/*! \brief CPU/GPU: map an expression, do reduction to 1D Tensor in lowest dimension (dimension 0) */
template<typename Saver, typename Reducer,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, cpu, 1, DType> *dst,
                                const expr::Exp<E, DType, etype> &exp,
                                DType scale = 1);
/*! \brief CPU/GPU: map an expression, do reduction to 1D Tensor in lowest dimension (dimension 0) */
template<typename Saver, typename Reducer, typename R,
         typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, gpu, 1, DType> *dst,
                                const expr::Exp<E, DType, etype> &exp,
                                DType scale = 1);
/*! \brief CPU/GPU: map an expression, do reduction to 1D Tensor in third dimension (dimension 2) */
template<typename Saver, typename Reducer, int dimkeep,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, cpu, 1, DType> *dst,
                                 const expr::Exp<E, DType, etype> &exp,
                                 DType scale = 1);
/*! \brief CPU/GPU: map an expression, do reduction to 1D Tensor in third dimension (dimension 2) */
template<typename Saver, typename Reducer, int dimkeep,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst,
                                 const expr::Exp<E, DType, etype> &exp,
                                 DType scale = 1);
/*! \brief CPU/GPU: 1 dimension vector dot */
template<typename Device, typename DType>
inline void VectorDot(Tensor<Device, 1, DType> dst,
                      const Tensor<Device, 1, DType> &lhs,
                      const Tensor<Device, 1, DType> &rhs);
/*! \brief CPU/GPU: dst = alpha * op(lhs) op(rhs) + beta * dst */
template<bool transpose_left, bool transpose_right, typename Device, typename DType>
inline void BatchGEMM(Tensor<Device, 3, DType> dst,
                      const Tensor<Device, 3, DType> &lhs,
                      const Tensor<Device, 3, DType> &rhs,
                      DType alpha,
                      DType beta,
                      Tensor<Device, 1, DType*> workspace);
}  // namespace mshadow
// include headers
#include "./stream_gpu-inl.h"
#include "./extension.h"
#include "./expr_engine-inl.h"
#include "./tensor_cpu-inl.h"
#include "./tensor_gpu-inl.h"
#include "./io.h"
#include "./tensor_container.h"
#include "./random.h"
// add definition of scalar related operators
#ifdef MSHADOW_SCALAR_
  #error "MSHADOW_SCALAR_ must not be defined"
#endif
// enumerate all the scalar data types we aim to be good at
#define MSHADOW_SCALAR_ float
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ double
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ int32_t
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ int64_t
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#define MSHADOW_SCALAR_ mshadow::half::half_t
#include "./expr_scalar-inl.h"
#undef MSHADOW_SCALAR_
#endif  // MSHADOW_TENSOR_H_