26 #ifndef MSHADOW_EXTENSION_SPATIAL_POOL_H_ 27 #define MSHADOW_EXTENSION_SPATIAL_POOL_H_ 29 #include "../extension.h" 39 template<
typename Reducer,
typename SrcExp,
typename DType,
int srcdim>
41 public MakeTensorExp<PoolingExp<Reducer, SrcExp, DType, srcdim>,
42 SrcExp, srcdim, DType> {
60 : src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x),
61 kstride_y_(kstride_y), kstride_x_(kstride_x) {
63 CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y)
64 <<
"PoolingExp: kernel must be smaller than image";
65 this->src_height_ = sshape[srcdim - 2];
66 this->src_width_ = sshape[srcdim - 1];
68 this->
shape_[srcdim - 2] = (src_height_ - ksize_y) / kstride_y + 1;
69 this->
shape_[srcdim - 1] = (src_width_ - ksize_x) / kstride_x + 1;
74 : src_(src), ksize_y_(ksize_y), ksize_x_(ksize_x),
75 kstride_y_(kstride_y), kstride_x_(kstride_x) {
77 CHECK(sshape[srcdim - 1] >= ksize_x && sshape[srcdim - 2] >= ksize_y)
78 <<
"PoolingExp: kernel must be smaller than image";
79 this->src_height_ = sshape[srcdim - 2];
80 this->src_width_ = sshape[srcdim - 1];
82 this->
shape_[srcdim - 2] = pshape[0];
83 this->
shape_[srcdim - 1] = pshape[1];
99 template<
typename Reducer,
typename SrcExp,
typename DType,
int etype>
104 ::Error_Expression_Does_Not_Meet_Dimension_Req();
106 (src.
self(), ksize_y, ksize_x, kstride_y, kstride_x);
122 template<
typename Reducer,
typename SrcExp,
123 typename DType,
int etype>
128 ::Error_Expression_Does_Not_Meet_Dimension_Req();
130 (src.
self(), pshape, ksize_y, ksize_x, kstride_y, kstride_x);
135 template<
typename Reducer,
typename SrcExp,
typename DType,
int srcdim>
143 new_height_(e.
shape_[srcdim - 2]) {}
146 const index_t py = i % new_height_;
152 const index_t c = i / new_height_;
154 DType res; Reducer::SetInitValue(res);
155 for (
index_t y = y_start; y < y_end; ++y) {
156 for (
index_t x = x_start; x < x_end; ++x) {
171 #endif // MSHADOW_EXTENSION_SPATIAL_POOL_H_ index_t ksize_x_
kernel size in width
Definition: spatial_pool.h:48
Definition: expr_engine-inl.h:59
index_t src_height_
source height, i.e. shape[srcdim - 2] of the source expression
Definition: spatial_pool.h:54
used to help static type check
Definition: expr_engine-inl.h:331
index_t ksize_y_
kernel size in height
Definition: spatial_pool.h:46
Definition: optional.h:241
PoolingExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim > pool(const Exp< SrcExp, DType, etype > &src, index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
pooling subregion results together
Definition: spatial_pool.h:101
PoolingExp(const SrcExp &src, index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
constructor
Definition: spatial_pool.h:58
PoolingExp(const SrcExp &src, Shape< 2 > pshape, index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
constructor with an explicitly specified output shape (pshape)
Definition: spatial_pool.h:72
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:223
int32_t index_t
type used for indexing
Definition: base.h:336
index_t kstride_y_
kernel stride in the y direction
Definition: spatial_pool.h:50
index_t kstride_x_
kernel stride in the x direction
Definition: spatial_pool.h:52
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
index_t src_width_
source width, i.e. shape[srcdim - 1] of the source expression
Definition: spatial_pool.h:56
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
a general base class for extensions that produce tensors of a given shape
Definition: expr_engine-inl.h:44
overloaded + operator between half_t and bf16_t
Definition: base.h:327
const SrcExp & src_
source operand
Definition: spatial_pool.h:44
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: spatial_pool.h:144
Plan(const PoolingExp< Reducer, SrcExp, DType, srcdim > &e)
Definition: spatial_pool.h:138
pooling expression: performs a reduction over local patches of an image
Definition: spatial_pool.h:40