26 #ifndef MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ 27 #define MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ 29 #include "../extension.h" 39 template<
typename Reducer,
typename SrcExp,
typename DType,
int dimsrc,
bool mask,
int dimdst>
41 public MakeTensorExp<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>,
42 SrcExp, dimdst, DType> {
56 bool keepdim = (dimsrc == dimdst);
57 CHECK(dimsrc > axis) <<
"reduce axis out of bound";
59 for (
int i = 0; i < axis; ++i) {
60 this->
shape_[i] = src_shape[i];
62 this->size_ = src_shape[axis];
65 for (
int i = axis + 1; i < dimsrc; ++i) {
66 this->trailing_ *= src_shape[i];
67 this->
shape_[i - 1] = src_shape[i];
71 for (
index_t i = axis + 1; i < dimsrc; ++i) {
72 this->trailing_ *= src_shape[i];
73 this->
shape_[i] = src_shape[i];
77 this->last_ = src_shape[dimsrc - 1];
78 this->last_dst_dim_ = this->
shape_[dimdst - 1];
90 template<
typename Reducer,
bool mask,
typename SrcExp,
typename DType,
int etype>
94 return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim,
mask,
106 template<
typename Reducer,
bool mask,
typename SrcExp,
typename DType,
int etype>
107 inline ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim,
mask,
110 return ReduceWithAxisExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim,
mask,
117 template<
typename Reducer,
typename SrcExp,
typename DType,
int dimsrc,
bool mask,
int dimdst>
129 DType res; Reducer::SetInitValue(res);
138 return static_cast<DType
>(
static_cast<int>(idx));
140 DType res; Reducer::SetInitValue(res);
155 #endif // MSHADOW_EXTENSION_REDUCE_WITH_AXIS_H_ index_t last_
size of last src dimension
Definition: reduce_with_axis.h:52
Definition: expr_engine-inl.h:59
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:223
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:263
MSHADOW_XINLINE bool IsNan(volatile DType val)
Definition: base.h:675
MaskExp< IndexExp, SrcExp, DType > mask(const Exp< IndexExp, DType, e1 > &index, const Exp< SrcExp, DType, e2 > &src)
Definition: mask.h:58
const SrcExp & src_
source operand
Definition: reduce_with_axis.h:44
int32_t index_t
type that will be used for index
Definition: base.h:336
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: reduce_with_axis.h:123
index_t size_
size of axis dimension
Definition: reduce_with_axis.h:50
ReduceWithAxisExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim, mask, ExpInfo< SrcExp >::kDim-1 > reduce_with_axis(const Exp< SrcExp, DType, etype > &src, int axis)
reduce out the dimension of src labeled by axis.
Definition: reduce_with_axis.h:93
ReduceWithAxisExp(const SrcExp &src, int axis)
Definition: reduce_with_axis.h:54
reduce out the dimension of src labeled by axis.
Definition: reduce_with_axis.h:40
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
ReduceWithAxisExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim, mask, ExpInfo< SrcExp >::kDim > reduce_keepdim(const Exp< SrcExp, DType, etype > &src, int axis)
reduce out the dimension of src labeled by axis, keepdim turned on.
Definition: reduce_with_axis.h:109
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:44
overloaded + operator between half_t and bf16_t
Definition: base.h:327
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48
Plan(const ReduceWithAxisExp< Reducer, SrcExp, DType, dimsrc, mask, dimdst > &e)
Definition: reduce_with_axis.h:120
index_t trailing_
size of trailing dimensions
Definition: reduce_with_axis.h:48
index_t last_dst_dim_
size of last destination dimension
Definition: reduce_with_axis.h:46