Go to the documentation of this file.
25 #ifndef MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_
26 #define MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_
28 #include "../extension.h"
39 template<
typename Reducer,
typename SrcExp,
typename DType,
int srcdim>
41 public MakeTensorExp<ChannelUnpoolingExp<Reducer, SrcExp, DType, srcdim>,
42 SrcExp, srcdim, DType> {
59 const SrcExp &data_pooled,
60 const SrcExp &grad_pooled,
67 CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled))
68 <<
"ChannelUnPoolingExp: data and grad shape mismatch";
70 for (
int k = 0; k < srcdim; ++k) {
74 CHECK_EQ(pshape[k], sshape[k])
75 <<
"ChannelUnPoolingExp: pooled tensor and src tensor shape mismatch"
96 template<
typename Reducer,
typename SrcExp,
typename DType,
int etype>
97 inline ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
103 ::Error_Expression_Does_Not_Meet_Dimension_Req();
105 (data_src.
self(), data_pooled.
self(), grad_pooled.
self(), nsize, stride,
pad);
108 template<
typename Reducer,
typename SrcExp,
typename DType,
int etype>
109 inline ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
113 return ch_unpool(data_src, data_pooled, grad_pooled, nsize, 1, nsize / 2);
120 template<
typename Reducer,
typename SrcExp,
typename DType,
int srcdim>
124 : data_src_(e.data_src_), data_pooled_(e.data_pooled_),
125 grad_pooled_(e.grad_pooled_), channel_(e.shape_[srcdim - 3]),
126 height_(e.shape_[srcdim - 2]), pchannel_(e.pchannel_),
127 hnsize_(e.nsize_), stride_(e.kstride_), pad_(e.pad_) {}
130 const DType vsrc = data_src_.Eval(i, j);
133 const index_t c = i % channel_;
134 const index_t n = i / channel_;
136 const index_t cstart = c < hnsize_ - pad_ ? 0
137 : (c - (hnsize_ - pad_) + stride_) / stride_;
138 const index_t cend = min((c + pad_ + stride_) / stride_, channel_);
139 DType val =
static_cast<DType
>(0);
140 for (
index_t cc = cstart; cc < cend; ++cc) {
141 val += Reducer::PartialGrad(vsrc,
142 data_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x)) *
143 grad_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x);
150 const index_t channel_, height_, pchannel_, hnsize_, stride_, pad_;
154 #endif // MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_
const SubType & self(void) const
Definition: expression.h:82
used to help static type check
Definition: expr_engine-inl.h:330
const SrcExp & grad_pooled_
gradient of the pooled output, to be propagated down
Definition: channel_unpool.h:48
index_t pad_
amount of padding along the channel dimension
Definition: channel_unpool.h:56
#define MSHADOW_XINLINE
Definition: base.h:228
runtime shape-checking template: gets the shape of an expression and reports an error if the shapes mismatch
Definition: expr_engine-inl.h:364
channel unpooling expression: propagates gradients back through channel pooling (a reduction over local nearby channels); used to implement local response normalization
Definition: channel_unpool.h:40
Plan(const ChannelUnpoolingExp< Reducer, SrcExp, DType, srcdim > &e)
Definition: channel_unpool.h:123
static Shape< dim > Check(const E &t)
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: channel_unpool.h:128
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
const SrcExp & data_pooled_
pooled data, i.e. the output of the forward pooling
Definition: channel_unpool.h:46
index_t pchannel_
number of channels in the pooled expression
Definition: channel_unpool.h:50
int32_t index_t
type used for indexing
Definition: base.h:328
Definition: expr_engine-inl.h:58
const SrcExp & data_src_
source input; corresponds to src in the forward pooling
Definition: channel_unpool.h:44
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
PaddingExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > pad(const Exp< SrcExp, DType, etype > &src, index_t pad)
padding expression: pads an image with zeros on its boundaries; padding affects shape[0],...
Definition: pad.h:71
index_t nsize_
kernel (window) size along the channel dimension
Definition: channel_unpool.h:52
overloaded + operator between half_t and bf16_t
Definition: base.h:319
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
index_t kstride_
stride of the kernel along the channel dimension
Definition: channel_unpool.h:54
Definition: optional.h:251
ChannelUnpoolingExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim > ch_unpool(const Exp< SrcExp, DType, etype > &data_src, const Exp< SrcExp, DType, etype > &data_pooled, const Exp< SrcExp, DType, etype > &grad_pooled, index_t nsize, index_t stride, index_t pad)
channel unpooling: unrolls the gradient over (local nearby) channels
Definition: channel_unpool.h:98
ChannelUnpoolingExp(const SrcExp &data_src, const SrcExp &data_pooled, const SrcExp &grad_pooled, index_t nsize, index_t kstride, index_t pad)
constructor
Definition: channel_unpool.h:58