mxnet
slice_ex.h
Go to the documentation of this file.
1 
6 #ifndef MSHADOW_EXTENSION_SLICE_EX_H_
7 #define MSHADOW_EXTENSION_SLICE_EX_H_
8 
9 #include "../extension.h"
10 
11 namespace mshadow {
12 namespace expr {
/*!
 * \brief slice expression: a view of the sub-tensor [begin, end) of src,
 *  taken along every one of the srcdim dimensions (not just one axis).
 *  The extent along dimension i is end[i] - begin[i] (see the constructor).
 *  Because it derives from TRValue it may be assigned to; both operator=
 *  overloads forward to __assign.
 * \tparam SrcExp type of the source expression being sliced
 * \tparam Device device type
 * \tparam DType element data type
 * \tparam srcdim dimension of the source expression (and of the slice)
 * NOTE(review): in this rendering the member declarations (src_shape_,
 *  shape_, begin_, end_; header lines 27-30) and the signature of the
 *  templated operator= (header line 40) are elided.
 */
20 template<typename SrcExp, typename Device,
21  typename DType, int srcdim>
22 struct SliceExExp : public TRValue<SliceExExp<SrcExp,
23  Device, DType,
24  srcdim>,
25  Device, srcdim, DType> {
  /*! \brief source expression; held by reference, so it must outlive this slice */
26  const SrcExp &src_;
  /*!
   * \brief constructor
   * \param src source expression to slice
   * \param begin first index (inclusive) along each dimension
   * \param end one-past-last index (exclusive) along each dimension
   * NOTE(review): there is no check that begin[i] <= end[i] <= source
   *  shape[i]. index_t is signed, so begin > end yields a negative
   *  extent in shape_. Confirm that callers validate the bounds.
   */
31  SliceExExp(const SrcExp &src, Shape<srcdim> begin, Shape<srcdim> end)
32  : src_(src), begin_(begin), end_(end) {
    // remember the full source shape; Plan needs it to compute strides
33  src_shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
    // the extent of the slice along each dimension
34  for (int i = 0; i < srcdim; ++i) {
35  shape_[i] = end_[i] - begin_[i];
36  }
37  }
  /*! \brief assign an expression to the sliced region; forwards to __assign */
38  template<typename E, int etype>
39  inline void
41  this->__assign(exp);
42  }
  /*! \brief assign a scalar to the sliced region (presumably to every element); forwards to __assign */
43  inline void
44  operator=(const DType &exp) {
45  this->__assign(exp);
46  }
47 }; // struct SliceExExp
48 
/*!
 * \brief factory that builds a SliceExExp over src.self() for the
 *  per-dimension range [begin, end).
 * NOTE(review): the function signature and the opening of the static
 *  dimension check (header lines 62-64) are elided in this rendering.
 *  Judging by its name, the Error_Expression_Does_Not_Meet_Dimension_Req
 *  call is a compile-time check that SrcExp has the expected dimension.
 *  Confirm this against the full header.
 * The returned expression keeps a reference to src.self(), so it must not
 * outlive src.
 */
60 template<typename SrcExp, typename Device,
61  typename DType, int srcdim>
65  ::Error_Expression_Does_Not_Meet_Dimension_Req();
66  return SliceExExp<SrcExp, Device, DType, srcdim>(src.self(), begin, end);
67 }
68 //------------------------
69 // engine plugin
70 //------------------------
71 // runtime shapecheck
/*!
 * \brief ShapeCheck specialization for SliceExExp: reports the shape of the
 *  slice, i.e. the per-dimension extents end - begin that the
 *  SliceExExp constructor precomputed into shape_. No further checking
 *  happens here; the source shape was already queried at construction.
 */
72 template<typename SrcExp, typename Device,
73  typename DType, int srcdim>
74 struct ShapeCheck<srcdim, SliceExExp<SrcExp, Device, DType, srcdim> >{
75  inline static Shape<srcdim> Check(const SliceExExp<SrcExp,
76  Device, DType, srcdim> &t) {
77  return t.shape_;
78  }
79 };
80 
/*!
 * \brief StreamInfo specialization for SliceExExp: Get(t) returns the
 *  Stream<Device> associated with the slice expression.
 * NOTE(review): the Get signature and body (header lines 85-86) are elided
 *  in this rendering. It presumably forwards to the source expression's
 *  stream (t.src_); confirm this against the full header.
 */
81 template<typename SrcExp, typename Device,
82  typename DType, int srcdim>
83 struct StreamInfo<Device, SliceExExp<SrcExp, Device, DType, srcdim> >{
84  inline static Stream<Device> *
87  }
88 };
89 // static typecheck
/*!
 * \brief ExpInfo specialization for SliceExExp: slicing keeps the rank and
 *  the device of the source, so both kDim and kDevMask are inherited
 *  directly from SrcExp.
 */
90 template<typename SrcExp, typename Device,
91  typename DType, int srcdim>
92 struct ExpInfo<SliceExExp<SrcExp, Device, DType, srcdim> >{
93  static const int kDim = ExpInfo<SrcExp>::kDim;
94  static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
95 };
96 //----------------------
97 // Execution plan
98 //---------------------
/*!
 * \brief execution plan for SliceExExp.
 *  The slice is addressed as a 2-D array (i, j): j is the index along the
 *  last dimension, and i is the flattened index over the leading srcdim-1
 *  dimensions of the slice. Each (i, j) is mapped to the corresponding
 *  (row, column) of the source plan by adding begin_ per dimension and
 *  re-flattening the row with the source extents (src_shape_).
 * NOTE(review): the constructor's opening line (103), the rest of its
 *  initializer list (105), the REval signature (117) and the private
 *  members (130-131) are elided in this rendering.
 * NOTE(review): index arithmetic is done in index_t (int32_t in this
 *  build), so idx/stride can overflow when the source has more than
 *  2^31 rows.
 * NOTE(review): Eval and REval duplicate the same index mapping; a shared
 *  helper would keep them in sync.
 */
99 template<typename SrcExp, typename Device,
100  typename DType, int srcdim>
101 struct Plan<SliceExExp<SrcExp, Device, DType, srcdim>, DType> {
102  public:
104  : src_(MakePlan(e.src_)), begin_(e.begin_),
  /*! \brief read element (i, j) of the slice from the source plan */
106  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
107  index_t idx = 0;
108  index_t stride = 1;
109  #pragma unroll
    // Peel i into per-dimension coordinates, innermost leading dim first.
    // Shift each coordinate by begin_[k] and re-flatten with source strides.
110  for (int k = srcdim-2; k >= 0; --k) {
111  idx += stride * (i%shape_[k] + begin_[k]);
112  i /= shape_[k];
113  stride *= src_shape_[k];
114  }
    // the last dimension maps directly: column offset by begin_[srcdim-1]
115  return src_.Eval(idx, j + begin_[srcdim-1]);
116  }
  // REval: same mapping as Eval, but returns a writable reference via src_.REval
118  index_t idx = 0;
119  index_t stride = 1;
120  #pragma unroll
121  for (int k = srcdim-2; k >= 0; --k) {
122  idx += stride * (i%shape_[k] + begin_[k]);
123  i /= shape_[k];
124  stride *= src_shape_[k];
125  }
126  return src_.REval(idx, j + begin_[srcdim-1]);
127  }
128 
129  private:
132 }; // struct Plan
133 } // namespace expr
134 } // namespace mshadow
135 #endif // MSHADOW_EXTENSION_SLICE_EX_H_
slice expression: extracts the sub-tensor [begin, end) along every dimension of a tensor
Definition: slice_ex.h:22
Tensor RValue, this is the super type of all kinds of possible tensors.
Definition: tensor.h:391
const SrcExp & src_
Definition: slice_ex.h:26
Definition: expr_engine-inl.h:40
used to help static type check
Definition: expr_engine-inl.h:312
const Shape< srcdim > end_
Definition: slice_ex.h:30
const Shape< srcdim > begin_
Definition: slice_ex.h:29
Plan(const SliceExExp< SrcExp, Device, DType, srcdim > &e)
Definition: slice_ex.h:103
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: slice_ex.h:106
void operator=(const DType &exp)
Definition: slice_ex.h:44
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:204
SliceExp< SrcExp, Device, DType, srcdim, srcdim-sdim > slice(const TRValue< SrcExp, Device, srcdim, DType > &src, index_t begin, index_t end)
Slice a Tensor.
Definition: slice.h:65
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
Definition: expr_engine-inl.h:327
int32_t index_t
type that will be used for index
Definition: base.h:291
static Stream< Device > * Get(const SliceExExp< SrcExp, Device, DType, srcdim > &t)
Definition: slice_ex.h:85
MSHADOW_XINLINE DType & REval(index_t i, index_t j)
Definition: slice_ex.h:117
Shape< srcdim > src_shape_
Definition: slice_ex.h:27
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:57
runtime shape-checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:346
static Stream< Device > * Get(const E &t)
static Shape< srcdim > Check(const SliceExExp< SrcExp, Device, DType, srcdim > &t)
Definition: slice_ex.h:75
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const Container & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
SliceExExp(const SrcExp &src, Shape< srcdim > begin, Shape< srcdim > end)
Definition: slice_ex.h:31
namespace for mshadow
Definition: base.h:282
SliceExExp< SrcExp, Device, DType, srcdim > & __assign(DType s)
operator overload
Definition: expression.h:160
void operator=(const expr::Exp< E, DType, etype > &exp)
Definition: slice_ex.h:40
Shape< srcdim > shape_
Definition: slice_ex.h:28
computation stream structure, used for asynchronous computations
Definition: tensor.h:365