mxnet
reduce_with_axis.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
26 #define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
27 
28 #include "../extension.h"
29 
30 namespace mshadow {
31 namespace expr {
32 
38 template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst>
40  public MakeTensorExp<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>,
41  SrcExp, dimdst, DType> {
43  const SrcExp &src_;
53  explicit ReduceWithAxisExp(const SrcExp &src, int axis)
54  : src_(src) {
55  bool keepdim = (dimsrc == dimdst);
56  CHECK(dimsrc > axis) << "reduce axis out of bound";
58  for (int i = 0; i < axis; ++i) {
59  this->shape_[i] = src_shape[i];
60  }
61  this->size_ = src_shape[axis];
62  this->trailing_ = 1;
63  if (!keepdim) {
64  for (int i = axis + 1; i < dimsrc; ++i) {
65  this->trailing_ *= src_shape[i];
66  this->shape_[i - 1] = src_shape[i];
67  }
68  } else {
69  this->shape_[axis] = 1;
70  for (index_t i = axis + 1; i < dimsrc; ++i) {
71  this->trailing_ *= src_shape[i];
72  this->shape_[i] = src_shape[i];
73  }
74  }
75 
76  this->last_ = src_shape[dimsrc - 1];
77  this->last_dst_dim_ = this->shape_[dimdst - 1];
78  }
79 }; // struct ReduceWithAxisExp
80 
89 template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
90 inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
94  ExpInfo<SrcExp>::kDim- 1>(src.self(), axis);
95 }
96 
105 template<typename Reducer, bool mask, typename SrcExp, typename DType, int etype>
106 inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim, mask,
110  ExpInfo<SrcExp>::kDim>(src.self(), axis);
111 }
112 
113 //----------------------
114 // Execution plan
115 //----------------------
116 template<typename Reducer, typename SrcExp, typename DType, int dimsrc, bool mask, int dimdst>
117 struct Plan<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>, DType> {
118  public:
120  : src_(MakePlan(e.src_)), last_dst_dim_(e.last_dst_dim_), trailing_(e.trailing_),
121  size_(e.size_), last_(e.last_) {}
122  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
123  index_t x = (i*last_dst_dim_ + j)/trailing_;
124  index_t y = (i*last_dst_dim_ + j)%trailing_;
125 
126  if (mask) {
127  index_t idx = 0;
128  DType res; Reducer::SetInitValue(res);
129  for (index_t k = 0; k < size_; ++k) {
130  index_t z = (x*size_+k)*trailing_+y;
131  DType tmp = res;
132  Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
133  if (tmp != res && !isnan_typed::IsNan(tmp)) {
134  idx = k;
135  }
136  }
137  return static_cast<DType>(static_cast<int>(idx));
138  } else {
139  DType res; Reducer::SetInitValue(res);
140  for (index_t k = 0; k < size_; ++k) {
141  index_t z = (x*size_+k)*trailing_+y;
142  Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
143  }
144  return res;
145  }
146  }
147 
148  private:
149  Plan<SrcExp, DType> src_;
150  const index_t last_dst_dim_, trailing_, size_, last_;
151 };
152 } // namespace expr
153 } // namespace mshadow
154 #endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
mshadow::expr::ReduceWithAxisExp::src_
const SrcExp & src_
source operand
Definition: reduce_with_axis.h:43
mshadow::expr::Exp::self
const SubType & self(void) const
Definition: expression.h:82
mshadow::expr::reduce_keepdim
ReduceWithAxisExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim, mask, ExpInfo< SrcExp >::kDim > reduce_keepdim(const Exp< SrcExp, DType, etype > &src, int axis)
reduce out the dimension of src labeled by axis, keepdim turned on.
Definition: reduce_with_axis.h:108
mshadow::expr::reduce_with_axis
ReduceWithAxisExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim, mask, ExpInfo< SrcExp >::kDim - 1 > reduce_with_axis(const Exp< SrcExp, DType, etype > &src, int axis)
reduce out the dimension of src labeled by axis.
Definition: reduce_with_axis.h:92
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::isnan_typed::IsNan
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:753
mshadow::expr::ReduceWithAxisExp::last_
index_t last_
size of last src dimension
Definition: reduce_with_axis.h:51
mshadow::expr::ShapeCheck::Check
static Shape< dim > Check(const E &t)
mshadow::expr::ExpInfo
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:262
mshadow::expr::MakePlan
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
mshadow::expr::ReduceWithAxisExp
reduce out the dimension of src labeled by axis.
Definition: reduce_with_axis.h:39
mshadow::expr::mask
MaskExp< IndexExp, SrcExp, DType > mask(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src)
Definition: mask.h:57
mshadow::expr::MakeTensorExp< ReduceWithAxisExp< Reducer, SrcExp, DType, dimsrc, mask, dimdst >, SrcExp, dimdst, DType >::shape_
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
mshadow::expr::ReduceWithAxisExp::last_dst_dim_
index_t last_dst_dim_
size of last destination dimension
Definition: reduce_with_axis.h:45
mshadow::expr::Plan< ReduceWithAxisExp< Reducer, SrcExp, DType, dimsrc, mask, dimdst >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: reduce_with_axis.h:122
mshadow::expr::ReduceWithAxisExp::ReduceWithAxisExp
ReduceWithAxisExp(const SrcExp &src, int axis)
Definition: reduce_with_axis.h:53
mshadow::expr::ReduceWithAxisExp::size_
index_t size_
size of axis dimension
Definition: reduce_with_axis.h:49
mshadow::expr::ExpInfo::kDim
static const int kDim
Definition: expr_engine-inl.h:263
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::expr::Plan
Definition: expr_engine-inl.h:58
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::expr::MakeTensorExp
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
mshadow::Shape< dimsrc >
mshadow::expr::Plan< ReduceWithAxisExp< Reducer, SrcExp, DType, dimsrc, mask, dimdst >, DType >::Plan
Plan(const ReduceWithAxisExp< Reducer, SrcExp, DType, dimsrc, mask, dimdst > &e)
Definition: reduce_with_axis.h:119
mshadow::expr::ReduceWithAxisExp::trailing_
index_t trailing_
size of trailing dimensions
Definition: reduce_with_axis.h:47