Go to the documentation of this file.
25 #ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
26 #define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
28 #include "../extension.h"
// ReduceWithAxisExp: expression that reduces a dimsrc-dimensional source
// expression along one axis with Reducer, producing a dimdst-dimensional
// result. keepdim mode is selected purely by rank (dimsrc == dimdst).
// NOTE(review): this is a documentation-page listing; several source lines
// (e.g. the struct declaration and the constructor signature) are absent here.
38 template<
typename Reducer,
typename SrcExp,
typename DType,
int dimsrc,
bool mask,
int dimdst>
40 public MakeTensorExp<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>,
41 SrcExp, dimdst, DType> {
// Same rank in and out means the reduced axis is kept in the output shape.
55 bool keepdim = (dimsrc == dimdst);
// NOTE(review): only the upper bound is checked here; as far as the visible
// lines show, a negative axis is not rejected and would index src_shape out of
// range -- confirm callers normalize axis before constructing this expression.
56 CHECK(dimsrc > axis) <<
"reduce axis out of bound";
// Dimensions before the reduced axis are copied to the output unchanged.
58 for (
int i = 0; i < axis; ++i) {
59 this->
shape_[i] = src_shape[i];
// size_ is the extent of the reduced axis, i.e. how many source elements are
// folded into each output element.
61 this->size_ = src_shape[axis];
// Dimensions after the axis shift down by one (output has one fewer dimension),
// and trailing_ accumulates their product. trailing_ is presumably initialized
// to 1 on a line not shown in this listing.
64 for (
int i = axis + 1; i < dimsrc; ++i) {
65 this->trailing_ *= src_shape[i];
66 this->
shape_[i - 1] = src_shape[i];
// keepdim case: dimensions after the axis keep their positions; trailing_
// accumulates the same product. shape_[axis] itself is not assigned in the
// visible lines -- presumably set to 1 on a line missing here; confirm.
70 for (
index_t i = axis + 1; i < dimsrc; ++i) {
71 this->trailing_ *= src_shape[i];
72 this->
shape_[i] = src_shape[i];
// Cache the innermost extents of source and destination; the Plan uses them to
// convert between (row, column) coordinates and flat indices.
76 this->last_ = src_shape[dimsrc - 1];
77 this->last_dst_dim_ = this->
shape_[dimdst - 1];
// reduce_with_axis: reduce out the dimension of src labeled by axis. The
// resulting expression has rank ExpInfo<SrcExp>::kDim - 1. The remainder of the
// signature and the function body are not present in this listing.
89 template<
typename Reducer,
bool mask,
typename SrcExp,
typename DType,
int etype>
90 inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim,
mask,
// reduce_keepdim: reduce out the dimension of src labeled by axis with keepdim
// turned on. The resulting expression keeps rank ExpInfo<SrcExp>::kDim. The
// remainder of the signature and the function body are not present in this listing.
105 template<
typename Reducer,
bool mask,
typename SrcExp,
typename DType,
int etype>
106 inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim,
mask,
// Plan specialization for ReduceWithAxisExp: Eval(i, j) computes one destination
// element by folding Reducer over the size_ source elements lying along the
// reduced axis. Some lines of the original header are missing from this listing.
116 template<
typename Reducer,
typename SrcExp,
typename DType,
int dimsrc,
bool mask,
int dimdst>
// Constructor: builds a plan for the source expression and copies the geometry
// (last_dst_dim_, trailing_, size_, last_) cached by the expression.
120 : src_(
MakePlan(e.src_)), last_dst_dim_(e.last_dst_dim_), trailing_(e.trailing_),
121 size_(e.size_), last_(e.last_) {}
// Flatten destination coordinate (i, j), then split it into an outer index x
// (dimensions before the reduced axis) and an offset y within the trailing_
// block (dimensions after it).
123 index_t x = (i*last_dst_dim_ + j)/trailing_;
124 index_t y = (i*last_dst_dim_ + j)%trailing_;
// Presumably the `mask` branch, since it ends by returning `idx`. The lines that
// declare and update idx are not in this listing. It looks like idx records the
// position k of the selected element along the axis (argmax-style); confirm.
128 DType res; Reducer::SetInitValue(res);
129 for (
index_t k = 0; k < size_; ++k) {
// z is the flat source index with the reduced-axis coordinate k inserted between
// x and y. It is mapped back to a (row, column) pair for src_.Eval.
130 index_t z = (x*size_+k)*trailing_+y;
132 Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
// The index is returned as a DType value, converted through int first.
137 return static_cast<DType
>(
static_cast<int>(idx));
// Presumably the non-mask branch: a plain reduction of the size_ elements along
// the axis, using the same flat-index mapping as above.
139 DType res; Reducer::SetInitValue(res);
140 for (
index_t k = 0; k < size_; ++k) {
141 index_t z = (x*size_+k)*trailing_+y;
142 Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
// Geometry copied from the expression; const, so fixed for the plan's lifetime.
150 const index_t last_dst_dim_, trailing_, size_, last_;
154 #endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_
const SrcExp & src_
source operand
Definition: reduce_with_axis.h:43
const SubType & self(void) const
Definition: expression.h:82
ReduceWithAxisExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim, mask, ExpInfo< SrcExp >::kDim > reduce_keepdim(const Exp< SrcExp, DType, etype > &src, int axis)
Reduce out the dimension of src labeled by axis, with keepdim turned on: the result keeps the same number of dimensions as src.
Definition: reduce_with_axis.h:108
ReduceWithAxisExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim, mask, ExpInfo< SrcExp >::kDim - 1 > reduce_with_axis(const Exp< SrcExp, DType, etype > &src, int axis)
Reduce out the dimension of src labeled by axis; the result has one fewer dimension than src.
Definition: reduce_with_axis.h:92
#define MSHADOW_XINLINE
Definition: base.h:228
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:753
index_t last_
size of last src dimension
Definition: reduce_with_axis.h:51
static Shape< dim > Check(const E &t)
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:262
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
reduce out the dimension of src labeled by axis.
Definition: reduce_with_axis.h:39
MaskExp< IndexExp, SrcExp, DType > mask(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src)
Definition: mask.h:57
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
index_t last_dst_dim_
size of last destination dimension
Definition: reduce_with_axis.h:45
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: reduce_with_axis.h:122
ReduceWithAxisExp(const SrcExp &src, int axis)
Definition: reduce_with_axis.h:53
index_t size_
size of axis dimension
Definition: reduce_with_axis.h:49
static const int kDim
Definition: expr_engine-inl.h:263
int32_t index_t
type that will be used for index
Definition: base.h:328
Definition: expr_engine-inl.h:58
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
overloaded + operator between half_t and bf16_t
Definition: base.h:319
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
Plan(const ReduceWithAxisExp< Reducer, SrcExp, DType, dimsrc, mask, dimdst > &e)
Definition: reduce_with_axis.h:119
index_t trailing_
product of the sizes of all dimensions after the reduced axis
Definition: reduce_with_axis.h:47