Go to the documentation of this file.
25 #ifndef MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_
26 #define MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_
28 #include "../extension.h"
38 template<
typename Reducer,
typename SrcExp,
typename DType,
int srcdim>
40 public MakeTensorExp<UnPoolingExp<Reducer, SrcExp, DType, srcdim>,
41 SrcExp, srcdim, DType> {
62 const SrcExp &data_pooled,
63 const SrcExp &grad_pooled,
71 CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled))
72 <<
"UnPoolingExp: pooled shape mismatch";
74 for (
int k = 0; k < srcdim - 2; ++k) {
75 CHECK_EQ(pshape[k], sshape[k]) <<
"UnPoolingExp: pool and src shape mismatch";
98 template<
typename Reducer,
typename SrcExp,
typename DType,
int etype>
99 inline UnPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
105 (data_src.
self(), data_pooled.
self(), grad_pooled.
self(),
106 ksize_y, ksize_x, kstride_y, kstride_x);
111 template<
typename Reducer,
typename SrcExp,
typename DType,
int srcdim>
116 grad_pooled_(
MakePlan(e.grad_pooled_)), sshape_y_(e.shape_[srcdim - 2]),
117 pshape_y_(e.pshape_y_), pshape_x_(e.pshape_x_),
118 ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_),
119 kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_) {}
123 const index_t y = i % sshape_y_;
124 const index_t c = i / sshape_y_;
125 const DType vsrc = data_src_.Eval(i, j);
127 y < ksize_y_ ? 0 : (y - ksize_y_ + kstride_y_) / kstride_y_;
129 x < ksize_x_ ? 0 : (x - ksize_x_ + kstride_x_) / kstride_x_;
130 const index_t py_max = min((y + kstride_y_) / kstride_y_, pshape_y_);
131 const index_t px_max = min((x + kstride_x_) / kstride_x_, pshape_x_);
133 DType val =
static_cast<DType
>(0);
134 for (
index_t py = py_min; py < py_max; ++py) {
135 for (
index_t px = px_min; px < px_max; ++px) {
136 val += Reducer::PartialGrad(vsrc,
137 data_pooled_.Eval(c * pshape_y_ + py, px)) *
138 grad_pooled_.Eval(c * pshape_y_ + py, px);
147 const index_t sshape_y_, pshape_y_, pshape_x_;
148 const index_t ksize_y_, ksize_x_;
149 const index_t kstride_y_, kstride_x_;
153 #endif // MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_
const SrcExp & data_pooled_
result of pooled data, corresponds to result of pooling
Definition: spatial_unpool.h:45
const SubType & self(void) const
Definition: expression.h:82
UnPoolingExp(const SrcExp &data_src, const SrcExp &data_pooled, const SrcExp &grad_pooled, index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
constructor
Definition: spatial_unpool.h:61
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: spatial_unpool.h:120
index_t ksize_y_
kernel size in height
Definition: spatial_unpool.h:53
#define MSHADOW_XINLINE
Definition: base.h:228
runtime shape checking template: gets the shape of an expression and reports an error on shape mismatch
Definition: expr_engine-inl.h:364
index_t pshape_y_
shape of pooled expression in the y dimension (pooled height)
Definition: spatial_unpool.h:49
unpooling expression: the reverse operation of pooling, used to pass the gradient back
Definition: spatial_unpool.h:39
index_t kstride_y_
kernel stride in y direction
Definition: spatial_unpool.h:57
static Shape< dim > Check(const E &t)
Plan(const UnPoolingExp< Reducer, SrcExp, DType, srcdim > &e)
Definition: spatial_unpool.h:114
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
index_t kstride_x_
kernel stride in x direction
Definition: spatial_unpool.h:59
index_t pshape_x_
shape of pooled expression in the x dimension (pooled width)
Definition: spatial_unpool.h:51
UnPoolingExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim > unpool(const Exp< SrcExp, DType, etype > &data_src, const Exp< SrcExp, DType, etype > &data_pooled, const Exp< SrcExp, DType, etype > &grad_pooled, index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
unpooling gradient for 4D: backpropagates the gradient value; the reverse operation of pooling,...
Definition: spatial_unpool.h:100
int32_t index_t
type that will be used for index
Definition: base.h:328
Definition: expr_engine-inl.h:58
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
const SrcExp & grad_pooled_
gradient data of the pooled part, to be propagated down
Definition: spatial_unpool.h:47
overloaded + operator between half_t and bf16_t
Definition: base.h:319
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
Definition: optional.h:251
const SrcExp & data_src_
source input, corresponds to src in pooling
Definition: spatial_unpool.h:43
index_t ksize_x_
kernel size in width
Definition: spatial_unpool.h:55