mxnet
take.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_TAKE_H_
8 #define MSHADOW_EXTENSION_TAKE_H_
9 
10 #include "../extension.h"
11 
12 namespace mshadow {
13 namespace expr {
14 
20 template<typename IndexExp, typename SrcExp, typename DType>
21 struct TakeExp: public Exp<TakeExp<IndexExp, SrcExp, DType>,
22  DType, type::kChainer> {
24  const IndexExp &index_;
26  const SrcExp &src_;
28  TakeExp(const IndexExp &index, const SrcExp &src)
29  : index_(index), src_(src) {}
30 }; // struct TakeExp
31 
32 
33 
34 template<typename IndexExp,
35  typename SrcExp,
36  typename DType,
37  int e1, int e2>
40  const Exp<SrcExp, DType, e2> &src) {
41  return TakeExp<IndexExp, SrcExp, DType>(index.self(), src.self());
42 }
43 
44 
45 //----------------------
46 // Execution plan
47 //----------------------
48 
// Execution plan for TakeExp. Element (y, x) of the result is produced by
// reading the y-th index value and returning src(index, x). Row y of the
// output is therefore the selected row of src.
// NOTE(review): this listing omits source line 52 (the constructor
// declarator; the Doxygen tooltip gives
// `Plan(const TakeExp<IndexExp, SrcExp, DType> &e)`) and lines 64-65 (the
// declarations of the index_ and src_ sub-plans) -- confirm against the
// real header.
49 template<typename IndexExp, typename SrcExp, typename DType>
50 struct Plan<TakeExp<IndexExp, SrcExp, DType>, DType> {
51  public:
// Build sub-plans for the index operand and the source operand.
53  : index_(MakePlan(e.index_)), src_(MakePlan(e.src_)) {
54  }
55 
// The source is read as src(row, col). Its rows are chosen by the index
// values and its columns follow the output column x.
56  // TODO(xx): discuss W shape: in * out or out * in
57  // Now I use in * out
58  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
// The index operand is read as row 0 of a 2-D view, so its y-th element is
// fetched with Eval(0, y). If that value is floating point, the cast to
// index_t truncates it toward zero.
59  index_t idx = static_cast<index_t>(index_.Eval(0, y));
// NOTE(review): idx is not bounds-checked against the source's row count
// (or for negative values) before it is used as a row index.
60  return static_cast<DType>(src_.Eval(idx, x));
61  }
62 
63  private:
66 }; // struct Plan
67 
// Factory for the TakeExp execution plan: wraps `exp` in
// Plan<TakeExp<IndexExp, SrcExp, DType>, DType> and returns it by value.
// NOTE(review): the return-type/declarator lines (source lines 69-70) are
// missing from this listing. From the body, this is presumably the MakePlan
// overload taking `const TakeExp<IndexExp, SrcExp, DType> &exp` -- confirm.
68 template<typename IndexExp, typename SrcExp, typename DType>
71  return Plan<TakeExp<IndexExp, SrcExp, DType>, DType>(exp);
72 }
73 
// Runtime shape inference for TakeExp. Only 2-D results are supported.
// NOTE(review): this listing omits source line 77 (the declarator; the
// Doxygen tooltip gives
// `static Shape<dim> Check(const TakeExp<IndexExp, SrcExp, DType> &t)`)
// and lines 80-81. Those presumably compute dshape from t.index_ and wshape
// from t.src_ -- confirm against the real header.
74 template<int dim, typename IndexExp, typename SrcExp, typename DType>
75 struct ShapeCheck<dim, TakeExp<IndexExp, SrcExp, DType> > {
76  inline static Shape<dim>
// dim is a template parameter, but this CHECK still runs at runtime.
78  CHECK(dim == 2)
79  << "TakeExp only support 2D output";
82  Shape<dim> ret;
// Result rows: first extent of dshape (presumably the number of indices).
83  ret[0] = dshape[0];
// Result columns: second extent of wshape (the source's column count).
84  ret[1] = wshape[1];
85  return ret;
86  }
87 };
88 
89 
90 template<typename IndexExp, typename SrcExp, typename DType>
91 struct ExpInfo<TakeExp<IndexExp, SrcExp, DType> > {
92  static const int kDim = 2;
93  static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
94 };
95 
96 } // namespace expr
97 } // namespace mshadow
98 
99 #endif // MSHADOW_EXTENSION_TAKE_H_
static Shape< dim > Check(const TakeExp< IndexExp, SrcExp, DType > &t)
Definition: take.h:77
Definition: expr_engine-inl.h:40
TakeExp< IndexExp, SrcExp, DType > take(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src)
Definition: take.h:39
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: take.h:58
static Shape< dim > Check(const E &t)
TakeExp(const IndexExp &index, const SrcExp &src)
Definition: take.h:28
#define MSHADOW_XINLINE
Definition: base.h:204
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
int32_t index_t
type that will be used for index
Definition: base.h:291
Take rows from a matrix, selected by an index vector.
Definition: take.h:21
Plan(const TakeExp< IndexExp, SrcExp, DType > &e)
Definition: take.h:52
runtime shape checking template get the shape of an expression, report error if shape mismatch ...
Definition: expr_engine-inl.h:346
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
namespace for mshadow
Definition: base.h:282
const IndexExp & index_
index operand
Definition: take.h:24
const SrcExp & src_
embedding operand
Definition: take.h:26