mxnet
reduce_with_axis.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
8 #define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
9 
10 #include "../extension.h"
11 
12 namespace mshadow {
13 namespace expr {
14 
20 template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst>
22  public MakeTensorExp<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>,
23  SrcExp, dimdst, DType> {
25  const SrcExp &src_;
35  explicit ReduceWithAxisExp(const SrcExp &src, int axis)
36  : src_(src) {
37  bool keepdim = (dimsrc == dimdst);
38  CHECK(dimsrc > axis) << "reduce axis out of bound";
40  for (int i = 0; i < axis; ++i) {
41  this->shape_[i] = src_shape[i];
42  }
43  this->size_ = src_shape[axis];
44  this->trailing_ = 1;
45  if (!keepdim) {
46  for (int i = axis + 1; i < dimsrc; ++i) {
47  this->trailing_ *= src_shape[i];
48  this->shape_[i - 1] = src_shape[i];
49  }
50  } else {
51  this->shape_[axis] = 1;
52  for (index_t i = axis + 1; i < dimsrc; ++i) {
53  this->trailing_ *= src_shape[i];
54  this->shape_[i] = src_shape[i];
55  }
56  }
57 
58  this->last_ = src_shape[dimsrc - 1];
59  this->last_dst_dim_ = this->shape_[dimdst - 1];
60  }
61 }; // struct ReduceWithAxisExp
62 
71 template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
75  return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
76  ExpInfo<SrcExp>::kDim- 1>(src.self(), axis);
77 }
78 
87 template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
88 inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
90  reduce_keepdim(const Exp<SrcExp, DType, etype> &src, int axis) {
91  return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
92  ExpInfo<SrcExp>::kDim>(src.self(), axis);
93 }
94 
95 //----------------------
96 // Execution plan
97 //----------------------
98 template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst>
99 struct Plan<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>, DType> {
100  public:
103  size_(e.size_), last_(e.last_) {}
104  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
105  index_t x = (i*last_dst_dim_ + j)/trailing_;
106  index_t y = (i*last_dst_dim_ + j)%trailing_;
107 
108  if (mask) {
109  index_t idx = 0;
110  DType res; Reducer::SetInitValue(res);
111  for (index_t k = 0; k < size_; ++k) {
112  index_t z = (x*size_+k)*trailing_+y;
113  DType tmp = res;
114  Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
115  if (tmp != res && !isnan_typed::IsNan(tmp)) {
116  idx = k;
117  }
118  }
119  return static_cast<DType>(static_cast<int>(idx));
120  } else {
121  DType res; Reducer::SetInitValue(res);
122  for (index_t k = 0; k < size_; ++k) {
123  index_t z = (x*size_+k)*trailing_+y;
124  Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
125  }
126  return res;
127  }
128  }
129 
130  private:
133 };
134 } // namespace expr
135 } // namespace mshadow
136 #endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
index_t last_
size of last src dimension
Definition: reduce_with_axis.h:33
Definition: expr_engine-inl.h:40
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:204
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:620
MaskExp< IndexExp, SrcExp, DType > mask(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src)
Definition: mask.h:39
const SrcExp & src_
source operand
Definition: reduce_with_axis.h:25
int32_t index_t
type that will be used for index
Definition: base.h:291
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: reduce_with_axis.h:104
index_t size_
size of axis dimension
Definition: reduce_with_axis.h:31
ReduceWithAxisExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim, mask, ExpInfo< SrcExp >::kDim-1 > reduce_with_axis(const Exp< SrcExp, DType, etype > &src, int axis)
reduce out the dimension of src labeled by axis.
Definition: reduce_with_axis.h:74
ReduceWithAxisExp(const SrcExp &src, int axis)
Definition: reduce_with_axis.h:35
reduce out the dimension of src labeled by axis.
Definition: reduce_with_axis.h:21
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
ReduceWithAxisExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim, mask, ExpInfo< SrcExp >::kDim > reduce_keepdim(const Exp< SrcExp, DType, etype > &src, int axis)
reduce over the dimension of src labeled by axis, keeping that dimension in the result with size 1 (keepdim).
Definition: reduce_with_axis.h:90
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:25
namespace for mshadow
Definition: base.h:282
Plan(const ReduceWithAxisExp< Reducer, SrcExp, DType, dimsrc, mask, dimdst > &e)
Definition: reduce_with_axis.h:101
index_t trailing_
size of trailing dimensions
Definition: reduce_with_axis.h:29
index_t last_dst_dim_
size of last destination dimension
Definition: reduce_with_axis.h:27