mxnet
take_grad.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_TAKE_GRAD_H_
8 #define MSHADOW_EXTENSION_TAKE_GRAD_H_
9 
10 #include "../extension.h"
11 
12 namespace mshadow {
13 namespace expr {
14 
// Expression node for the embedding gradient (the rendered docs describe it
// as "Calculate embedding gradient").
//   IndexExp - type of the index expression
//   SrcExp   - type of the source (out gradient) expression
//   DType    - element type of the expression
// NOTE(review): the declaration of input_dim_ (original line 29) is absent
// from this listing even though the constructor below initializes it.
21 template<typename IndexExp, typename SrcExp, typename DType>
22 struct TakeGradExp : public Exp<TakeGradExp<IndexExp, SrcExp, DType>,
23  DType, type::kChainer> {
 // index operand, stored as a const reference
25  const IndexExp &index_;
 // out gradient operand, stored as a const reference
27  const SrcExp &src_;
 // constructor: binds both operands and records input_dim in input_dim_
31  TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim)
32  : index_(index), src_(src), input_dim_(input_dim) {}
33 }; // struct TakeGradExp
34 
35 
// Factory helper take_grad(index, src, input_dim) returning a
// TakeGradExp<IndexExp, SrcExp, DType>; the visible part of the call forwards
// src.self() and input_dim.
// NOTE(review): original lines 40-41 and 44 (return type, function name, the
// `index` parameter and the start of the return statement) are missing from
// this listing. The name and parameters above are taken from the signature the
// rendered docs give for take_grad.h:41; presumably index.self() is the first
// argument of the constructed expression - confirm against the real header.
36 template<typename IndexExp,
37  typename SrcExp,
38  typename DType,
39  int e1, int e2>
42  const Exp<SrcExp, DType, e2> &src,
43  const index_t input_dim) {
45  src.self(),
46  input_dim);
47 }
48 
49 //----------------------
50 // Execution plan
51 //----------------------
52 
// Execution plan specialization that evaluates a TakeGradExp element by element.
// NOTE(review): the constructor signature (original line 56) and the
// declarations of index_ and src_ (original lines 75-76) are missing from this
// listing; the rendered docs give the constructor as
// Plan(const TakeGradExp<IndexExp, SrcExp, DType> &e).
53 template<typename IndexExp, typename SrcExp, typename DType>
54 struct Plan<TakeGradExp<IndexExp, SrcExp, DType>, DType> {
55  public:
 // Builds sub-plans for both operands and caches element 0 of the 1-D shape
 // check of the index expression as batch_size_.
57  : index_(MakePlan(e.index_)),
58  src_(MakePlan(e.src_)),
59  batch_size_(ShapeCheck<1, IndexExp>::Check(e.index_)[0]) {
60  }
61 
62  // now return shape: in * out
 // Computes output element (y, x): the sum of src_.Eval(i, x) over every i in
 // [0, batch_size_) whose index entry equals y. Returns 0 when no entry matches.
 // Each call scans all batch_size_ index entries.
63  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
 // running sum for this output element
64  DType ret = 0.f;
65  for (index_t i = 0; i < batch_size_; ++i) {
 // i-th index entry, read at row 0 and converted to index_t
66  index_t idx = static_cast<index_t>(index_.Eval(0, i));
 // only entries that select row y contribute to that row
67  if (idx == y) {
68  ret += static_cast<DType>(src_.Eval(i, x));
69  }
70  }
71  return ret;
72  }
73 
74  private:
 // number of index entries scanned by Eval; set once in the constructor
77  const index_t batch_size_;
78 }; // struct Plan
79 
80 
81 template<typename IndexExp, typename SrcExp, typename DType>
85 }
86 
// Runtime shape inference for TakeGradExp: the result has
// shape[0] = t.input_dim_ and shape[1] = gshape[1].
// NOTE(review): the Check signature (original line 90) and the declaration of
// gshape (original line 94) are missing from this listing. The rendered docs
// give the signature as
// static Shape<dim> Check(const TakeGradExp<IndexExp, SrcExp, DType> &t);
// gshape is presumably the shape of t.src_ - confirm against the real header.
87 template<int dim, typename IndexExp, typename SrcExp, typename DType>
88 struct ShapeCheck<dim, TakeGradExp<IndexExp, SrcExp, DType> > {
89  inline static Shape<dim>
 // the expression is only defined for 2-D results
91  CHECK(dim == 2)
92  << "TakeGradExp only support 2D output";
 // the shape of the index expression is not consulted here (left disabled)
93  // Shape<1> dshape = ShapeCheck<1, IndexExp>::Check(t.index_);
95  Shape<dim> ret;
 // first axis comes from the caller-supplied input_dim
96  ret[0] = t.input_dim_;
 // second axis is copied from axis 1 of gshape
97  ret[1] = gshape[1];
98  return ret;
99  }
100 }; // struct ShapeCheck
101 
// Static type information for TakeGradExp.
102 template<typename IndexExp, typename SrcExp, typename DType>
103 struct ExpInfo<TakeGradExp<IndexExp, SrcExp, DType> > {
 // the expression is always two-dimensional
104  static const int kDim = 2;
 // device mask is taken from the index expression only; SrcExp is not consulted
105  static const int kDevMask = ExpInfo<IndexExp>::kDevMask;
106 };
107 
108 } // namespace expr
109 } // namespace mshadow
110 
111 #endif // MSHADOW_EXTENSION_TAKE_GRAD_H_
const IndexExp & index_
index operand
Definition: take_grad.h:25
Definition: expr_engine-inl.h:40
const SrcExp & src_
out gradient operand
Definition: take_grad.h:27
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:204
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
static Shape< dim > Check(const TakeGradExp< IndexExp, SrcExp, DType > &t)
Definition: take_grad.h:90
int32_t index_t
type that will be used for index
Definition: base.h:291
TakeGradExp< IndexExp, SrcExp, DType > take_grad(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src, const index_t input_dim)
Definition: take_grad.h:41
runtime shape checking template get the shape of an expression, report error if shape mismatch ...
Definition: expr_engine-inl.h:346
Calculate embedding gradient.
Definition: take_grad.h:22
const index_t input_dim_
input dimension (size of the first axis of the output shape; not the batch size)
Definition: take_grad.h:29
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
TakeGradExp(const IndexExp &index, const SrcExp &src, const index_t input_dim)
constructor
Definition: take_grad.h:31
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: take_grad.h:63
namespace for mshadow
Definition: base.h:282
Plan(const TakeGradExp< IndexExp, SrcExp, DType > &e)
Definition: take_grad.h:56