mxnet
slice.h
Go to the documentation of this file.
1 
6 #ifndef MSHADOW_EXTENSION_SLICE_H_
7 #define MSHADOW_EXTENSION_SLICE_H_
8 
9 #include "../extension.h"
10 
11 namespace mshadow {
12 namespace expr {
20 template<typename SrcExp,
21  typename Device, typename DType,
22  int srcdim, int dimsrc_m_slice>
23 struct SliceExp : public TRValue<SliceExp<SrcExp,
24  Device, DType,
25  srcdim, dimsrc_m_slice>,
26  Device, srcdim, DType> {
27  static const int dimslice = srcdim - dimsrc_m_slice;
28  const SrcExp &src_;
32  SliceExp(const SrcExp &src, index_t begin, index_t end)
33  : src_(src), ch_begin_(begin) {
34  shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
35  ch_old_ = shape_[dimslice];
36  CHECK(begin <= shape_[dimslice] && end <= shape_[dimslice])
37  << "The slice went out of range. ";
38  shape_[dimslice] = end - begin;
39  }
40  template<typename E, int etype>
41  inline void
43  this->__assign(exp);
44  }
45  inline void
46  operator=(const DType &exp) {
47  this->__assign(exp);
48  }
49 }; // struct Slice
50 
/*!
 * \brief slice dimension sdim of a tensor expression to the half-open range
 *        [begin, end); returns a lazy SliceExp whose last template argument
 *        is srcdim - sdim (so its sliced dimension is sdim).
 * \tparam sdim index of the dimension to slice
 * \param src source tensor expression
 * \param begin first index kept (inclusive)
 * \param end one past the last index kept
 * \return slice expression referencing src
 * NOTE(review): source lines 65-66 are missing from this extracted copy.
 * Line 65 is the signature, listed in the Doxygen index as
 *   slice(const TRValue<SrcExp, Device, srcdim, DType> &src,
 *         index_t begin, index_t end)
 * Line 66 is the start of the compile-time dimension check that ends on the
 * next line; its exact condition is not recoverable here.
 */
62 template<int sdim, typename SrcExp,
63  typename Device, typename DType, int srcdim>
64 inline SliceExp<SrcExp, Device, DType, srcdim, srcdim - sdim>
67  ::Error_Expression_Does_Not_Meet_Dimension_Req();
68  return SliceExp<SrcExp, Device, DType, srcdim, srcdim - sdim>(src.self(), begin, end);
69 }
70 //------------------------
71 // engine plugin
72 //------------------------
73 // runtime shapecheck
74 template<typename SrcExp,
75  typename Device, typename DType,
76  int srcdim, int dimsrc_m_slice>
77 struct ShapeCheck<srcdim, SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{
78  inline static Shape<srcdim> Check(const SliceExp<SrcExp,
79  Device, DType, srcdim, dimsrc_m_slice> &t) {
80  return t.shape_;
81  }
82 };
/*!
 * \brief StreamInfo specialization for SliceExp: provides the
 *        Stream<Device> associated with a slice expression.
 * NOTE(review): source lines 88-89 are missing from this extracted copy.
 * Line 88 is the signature, listed in the Doxygen index as
 *   Get(const SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> &t)
 * Line 89 is the body; it presumably forwards to the stream of t.src_ —
 * confirm against upstream mshadow.
 */
83 template<typename SrcExp,
84  typename Device, typename DType,
85  int srcdim, int dimsrc_m_slice>
86 struct StreamInfo<Device, SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{
87  inline static Stream<Device> *
90  }
91 };
92 // static typecheck
93 template<typename SrcExp,
94  typename Device, typename DType,
95  int srcdim, int dimsrc_m_slice>
96 struct ExpInfo<SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> >{
97  static const int kDim = ExpInfo<SrcExp>::kDim;
98  static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
99 };
100 //----------------------
101 // Execution plan
102 //---------------------
/*!
 * \brief execution plan for SliceExp in the general case. The partial
 *        specialization below handles dimsrc_m_slice == 1, i.e. slicing the
 *        last dimension. Evaluation is 2D: column j is passed through
 *        unchanged, and row i is remapped into the source's row space.
 * NOTE(review): source lines 109, 121 and 131 are missing from this
 * extracted copy:
 *   - line 109 is the constructor signature, listed in the Doxygen index as
 *       Plan(const SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice> &e)
 *   - line 121 is the REval signature, listed as
 *       MSHADOW_XINLINE DType &REval(index_t i, index_t j)
 *   - line 131 is the declaration of src_
 */
103 template<typename SrcExp,
104  typename Device, typename DType,
105  int srcdim, int dimsrc_m_slice>
106 struct Plan<SliceExp<SrcExp, Device, DType, srcdim, dimsrc_m_slice>, DType> {
107  public:
108  static const int dimslice = srcdim - dimsrc_m_slice;
110  : src_(MakePlan(e.src_)),
// NOTE(review): height_ is presumably the product of the extents between
// the sliced dimension and the last one — confirm ProdShape's range
// convention.
111  height_(e.shape_.ProdShape(dimslice + 1, srcdim - 1)),
112  ch_begin_(e.ch_begin_), ch_old_(e.ch_old_), ch_(e.shape_[dimslice]) {}
// Slice row i decomposes as i = (b * ch_ + c') * height_ + y. The source
// row keeps b and y, but uses channel c = c' + ch_begin_ and the source
// extent ch_old_ in place of ch_.
113  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
114  const index_t y = i % height_;
115  i /= height_;
116  const index_t c = i % ch_ + ch_begin_;
117  const index_t b = i / ch_;
118  const index_t x = j;
119  return src_.Eval((b * ch_old_ + c) * height_ + y, x);
120  }
// Writable counterpart of Eval (its signature line is missing above):
// same index mapping, forwarding to src_.REval.
122  const index_t y = i % height_;
123  i /= height_;
124  const index_t c = i % ch_ + ch_begin_;
125  const index_t b = i / ch_;
126  const index_t x = j;
127  return src_.REval((b * ch_old_ + c) * height_ + y, x);
128  }
129 
130  private:
132  const index_t height_, ch_begin_, ch_old_, ch_;
133 }; // struct Plan
134 
/*!
 * \brief execution plan for SliceExp when dimsrc_m_slice == 1, i.e. the
 *        sliced dimension is the last one (srcdim - 1). Rows map one-to-one
 *        and only the column index is shifted by ch_begin_.
 * NOTE(review): source lines 140, 146 and 151 are missing from this
 * extracted copy:
 *   - line 140 is the constructor signature, listed in the Doxygen index as
 *       Plan(const SliceExp<SrcExp, Device, DType, srcdim, 1> &e)
 *   - line 146 is the REval signature, listed as
 *       MSHADOW_XINLINE DType &REval(index_t y, index_t x)
 *   - line 151 is the declaration of src_
 */
135 template<typename SrcExp,
136  typename Device, typename DType,
137  int srcdim>
138 struct Plan<SliceExp<SrcExp, Device, DType, srcdim, 1>, DType> {
139  public:
141  : src_(MakePlan(e.src_)),
142  ch_begin_(e.ch_begin_) {}
143  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
144  return src_.Eval(y, x + ch_begin_);
145  }
// Writable counterpart of Eval (its signature line is missing above).
147  return src_.REval(y, x + ch_begin_);
148  }
149 
150  private:
152  const index_t ch_begin_;
153 };
154 } // namespace expr
155 } // namespace mshadow
156 #endif // MSHADOW_EXTENSION_SLICE_H_
Tensor RValue, this is the super type of all kinds of possible tensors.
Definition: tensor.h:391
Definition: expr_engine-inl.h:40
used to help static type checking
Definition: expr_engine-inl.h:312
static const int dimslice
Definition: slice.h:27
void operator=(const expr::Exp< E, DType, etype > &exp)
Definition: slice.h:42
Plan(const SliceExp< SrcExp, Device, DType, srcdim, dimsrc_m_slice > &e)
Definition: slice.h:109
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: slice.h:113
static Shape< dim > Check(const E &t)
index_t ch_old_
Definition: slice.h:30
MSHADOW_XINLINE DType & REval(index_t y, index_t x)
Definition: slice.h:146
#define MSHADOW_XINLINE
Definition: base.h:204
SliceExp< SrcExp, Device, DType, srcdim, srcdim-sdim > slice(const TRValue< SrcExp, Device, srcdim, DType > &src, index_t begin, index_t end)
Slice a Tensor.
Definition: slice.h:65
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
Definition: expr_engine-inl.h:327
static Shape< srcdim > Check(const SliceExp< SrcExp, Device, DType, srcdim, dimsrc_m_slice > &t)
Definition: slice.h:78
int32_t index_t
type that will be used for index
Definition: base.h:291
void operator=(const DType &exp)
Definition: slice.h:46
const SrcExp & src_
Definition: slice.h:28
Plan(const SliceExp< SrcExp, Device, DType, srcdim, 1 > &e)
Definition: slice.h:140
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:57
runtime shape checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:346
Shape< srcdim > shape_
Definition: slice.h:31
static Stream< Device > * Get(const E &t)
MSHADOW_XINLINE DType & REval(index_t i, index_t j)
Definition: slice.h:121
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: slice.h:143
SliceExp(const SrcExp &src, index_t begin, index_t end)
Definition: slice.h:32
const Container & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
namespace for mshadow
Definition: base.h:282
index_t ch_begin_
Definition: slice.h:29
slice expression, slices a tensor's channel
Definition: slice.h:23
SliceExp< SrcExp, Device, DType, srcdim, dimsrc_m_slice > & __assign(DType s)
operator overload
Definition: expression.h:160
static Stream< Device > * Get(const SliceExp< SrcExp, Device, DType, srcdim, dimsrc_m_slice > &t)
Definition: slice.h:88
computation stream structure, used for asynchronous computations
Definition: tensor.h:365