mxnet
flip.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_FLIP_H_
8 #define MSHADOW_EXTENSION_FLIP_H_
9 
10 #include "../extension.h"
11 
12 namespace mshadow {
13 namespace expr {
21 template<typename SrcExp, typename Device,
22  typename DType, int srcdim>
23 struct FlipExp : public TRValue<FlipExp<SrcExp,
24  Device, DType,
25  srcdim>,
26  Device, srcdim, DType> {
27  const SrcExp &src_;
32  FlipExp(const SrcExp &src, int dim)
33  : src_(src) {
34  shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
35  stride_ = shape_[dim];
36  stride_j_ = shape_[srcdim-1];
37  trailing_ = 1;
38  for (int i = dim + 1; i < srcdim; ++i) {
39  trailing_ *= shape_[i];
40  }
41  }
42  template<typename E, int etype>
43  inline void
45  this->__assign(exp);
46  }
47  inline void
48  operator=(const DType &exp) {
49  this->__assign(exp);
50  }
51 }; // struct Flip
52 
// Helper that builds a FlipExp from a tensor expression.
// NOTE(review): original lines 66-68 are missing from this documentation
// scrape. Per the Doxygen index, the declaration is
//   FlipExp<SrcExp, Device, DType, srcdim>
//   flip(const TRValue<SrcExp, Device, srcdim, DType> &src, int dim)
// with brief "Flip a Tensor." Its body is not visible here; confirm it
// against the full source before relying on this listing.
64 template<typename SrcExp, typename Device,
65  typename DType, int srcdim>
69 }
70 //------------------------
71 // engine plugin
72 //------------------------
73 // runtime shapecheck
74 template<typename SrcExp, typename Device,
75  typename DType, int srcdim>
76 struct ShapeCheck<srcdim, FlipExp<SrcExp, Device, DType, srcdim> >{
77  inline static Shape<srcdim> Check(const FlipExp<SrcExp,
78  Device, DType, srcdim> &t) {
79  return t.shape_;
80  }
81 };
// Stream lookup specialization for FlipExp.
// NOTE(review): original lines 86-87 are missing from this documentation
// scrape. Per the Doxygen index, the member is
//   static Stream<Device> *Get(const FlipExp<SrcExp, Device, DType, srcdim> &t)
// Its body is not visible here. Presumably it forwards to the stream of
// the source expression; confirm against the full source.
82 template<typename SrcExp, typename Device,
83  typename DType, int srcdim>
84 struct StreamInfo<Device, FlipExp<SrcExp, Device, DType, srcdim> >{
85  inline static Stream<Device> *
88  }
89 };
90 // static typecheck
91 template<typename SrcExp, typename Device,
92  typename DType, int srcdim>
93 struct ExpInfo<FlipExp<SrcExp, Device, DType, srcdim> >{
94  static const int kDim = ExpInfo<SrcExp>::kDim;
95  static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
96 };
97 //----------------------
98 // Execution plan
99 //---------------------
// Execution plan for FlipExp. Element (i, j) of the flipped 2-D view is
// produced by reading the mirrored element from the source plan.
// NOTE(review): this listing is a documentation scrape. Original lines
// 104-106 (the constructor Plan(const FlipExp<...> &e), per the Doxygen
// index) and 127-128 (the private members) are missing from this view.
// The members src_, stride_j_, trailing_ and stride_ presumably mirror the
// FlipExp fields of the same name; confirm against the full source.
100 template<typename SrcExp, typename Device,
101  typename DType, int srcdim>
102 struct Plan<FlipExp<SrcExp, Device, DType, srcdim>, DType> {
103  public:
107  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
  // Flatten (i, j) into a linear index, using stride_j_ as the row length.
108  index_t idx = i*stride_j_+j;
  // Decompose idx as ((high * stride_) + x) * trailing_ + low, where x is
  // the coordinate along the flipped dimension.
109  const index_t low = idx%trailing_;
110  index_t high = idx/trailing_;
111  const index_t x = high%stride_;
112  high /= stride_;
  // Mirror x to stride_-1-x and recombine; high and low are unchanged.
113  idx = (high*stride_+stride_-1-x)*trailing_+low;
  // Map the mirrored linear index back to (row, col) of the source.
  // NOTE(review): index_t is int32_t in this build's base.h, so the linear
  // index can overflow for tensors with more than 2^31-1 elements.
114  return src_.Eval(idx/stride_j_, idx%stride_j_);
115  }
  // Same index mapping as Eval, but returns a writable reference so the
  // flipped view can be assigned to. NOTE(review): the mapping is duplicated
  // from Eval; a shared private helper would keep the two in sync.
116  MSHADOW_XINLINE DType &REval(index_t i, index_t j) const {
117  index_t idx = i*stride_j_+j;
118  const index_t low = idx%trailing_;
119  index_t high = idx/trailing_;
120  const index_t x = high%stride_;
121  high /= stride_;
122  idx = (high*stride_+stride_-1-x)*trailing_+low;
123  return src_.REval(idx/stride_j_, idx%stride_j_);
124  }
125 
126  private:
129 }; // struct Plan
130 } // namespace expr
131 } // namespace mshadow
132 #endif // MSHADOW_EXTENSION_FLIP_H_
index_t stride_
Definition: flip.h:29
MSHADOW_XINLINE DType & REval(index_t i, index_t j) const
Definition: flip.h:116
Tensor RValue, this is the super type of all kinds of possible tensors.
Definition: tensor.h:391
index_t trailing_
Definition: flip.h:28
Definition: expr_engine-inl.h:40
static Stream< Device > * Get(const FlipExp< SrcExp, Device, DType, srcdim > &t)
Definition: flip.h:86
FlipExp< SrcExp, Device, DType, srcdim > flip(const TRValue< SrcExp, Device, srcdim, DType > &src, int dim)
Flip a Tensor.
Definition: flip.h:67
static Shape< dim > Check(const E &t)
FlipExp(const SrcExp &src, int dim)
Definition: flip.h:32
#define MSHADOW_XINLINE
Definition: base.h:204
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
Definition: expr_engine-inl.h:327
int32_t index_t
type that will be used for index
Definition: base.h:291
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: flip.h:107
const SrcExp & src_
Definition: flip.h:27
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:57
runtime shape checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:346
void operator=(const expr::Exp< E, DType, etype > &exp)
Definition: flip.h:44
static Stream< Device > * Get(const E &t)
Plan(const FlipExp< SrcExp, Device, DType, srcdim > &e)
Definition: flip.h:104
void operator=(const DType &exp)
Definition: flip.h:48
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const Container & self(void) const
Definition: expression.h:64
flip expression: reverses a tensor's elements along one dimension
Definition: flip.h:23
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
namespace for mshadow
Definition: base.h:282
FlipExp< SrcExp, Device, DType, srcdim > & __assign(DType s)
operator overload
Definition: expression.h:160
Shape< srcdim > shape_
Definition: flip.h:31
static Shape< srcdim > Check(const FlipExp< SrcExp, Device, DType, srcdim > &t)
Definition: flip.h:77
index_t stride_j_
Definition: flip.h:30
computation stream structure, used for asynchronous computations
Definition: tensor.h:365