mxnet
channel_unpool.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_
26 #define MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_
27 #include <algorithm>
28 #include "../extension.h"
29 namespace mshadow {
30 namespace expr {
/*!
 * \brief lazy expression for the gradient of a channel pooling with respect
 *        to the pooling input.
 * \tparam Reducer pooling reducer; the execution plan calls its
 *                 PartialGrad(src, pooled)
 * \tparam SrcExp  expression type shared by the three operands
 * \tparam DType   element type
 * \tparam srcdim  rank of the operands and of the result
 * NOTE(review): the operands are held by const reference only, so they must
 * outlive this expression.
 */
39 template<typename Reducer, typename SrcExp, typename DType, int srcdim>
41  public MakeTensorExp<ChannelUnpoolingExp<Reducer, SrcExp, DType, srcdim>,
42  SrcExp, srcdim, DType> {
// source input, corresponds to src in pooling
44  const SrcExp &data_src_;
// result of pooled data, corresponds to the result of pooling
46  const SrcExp &data_pooled_;
// gradient of the pooled data, to be propagated down
48  const SrcExp &grad_pooled_;
// nsize: window size along the channel axis; kstride: window stride along
// the channel axis; pad: padding along the channel axis.
58  ChannelUnpoolingExp(const SrcExp &data_src,
59  const SrcExp &data_pooled,
60  const SrcExp &grad_pooled,
61  index_t nsize, index_t kstride, index_t pad)
62  : data_src_(data_src), data_pooled_(data_pooled),
63  grad_pooled_(grad_pooled),
64  nsize_(nsize), kstride_(kstride), pad_(pad) {
// NOTE(review): pshape and sshape are declared on listing lines not shown
// here; presumably the shapes of grad_pooled and data_src -- confirm.
// data_pooled must have exactly the shape pshape.
66  typedef ShapeCheck<srcdim, SrcExp> ShapeCheckSrcDimSrcExp;
67  CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled))
68  << "ChannelUnPoolingExp: data and grad shape mismatch";
// All axes except axis 1 must agree between the pooled and source shapes;
// axis 1 may differ and its pooled extent is recorded in pchannel_.
// NOTE(review): the execution plan treats shape_[srcdim - 3] as the channel
// axis, which coincides with axis 1 only when srcdim == 4 -- confirm the
// intended behaviour for other ranks.
// NOTE(review): the message below is followed by pshape[k] with no
// separating space.
70  for (int k = 0; k < srcdim; ++k) {
71  if (k == 1) {
72  continue;
73  }
74  CHECK_EQ(pshape[k], sshape[k])
75  << "ChannelUnPoolingExp: pooled tensor and src tensor shape mismatch"
76  << pshape[k]
77  << " vs "
78  << sshape[k];
79  }
80  pchannel_ = pshape[1];
// the expression takes the shape sshape (presumably that of data_src)
81  this->shape_ = sshape;
82  }
83 };
/*!
 * \brief channel unpooling: builds the lazy expression that propagates the
 *        gradient of a channel pooling back to the pooling input.
 * \param data_src    pooling input (declared on a listing line not shown here)
 * \param data_pooled pooling output
 * \param grad_pooled gradient w.r.t. the pooling output
 * \param nsize  window size along the channel axis
 * \param stride window stride along the channel axis
 * \param pad    padding along the channel axis
 * \tparam Reducer must be given explicitly by the caller: it appears only in
 *                 the return type and cannot be deduced from the arguments.
 * \return a ChannelUnpoolingExp over the operands' underlying expressions
 *         (obtained via self()).
 * NOTE(review): the dimension check that triggers
 * Error_Expression_Does_Not_Meet_Dimension_Req starts on a listing line not
 * shown here; presumably it requires rank >= 3 because the plan reads
 * shape_[srcdim - 3] -- confirm.
 */
96 template<typename Reducer, typename SrcExp, typename DType, int etype>
97 inline ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
99  const Exp<SrcExp, DType, etype> &data_pooled,
100  const Exp<SrcExp, DType, etype> &grad_pooled,
101  index_t nsize, index_t stride, index_t pad) {
103  ::Error_Expression_Does_Not_Meet_Dimension_Req();
105  (data_src.self(), data_pooled.self(), grad_pooled.self(), nsize, stride, pad);
106 }
107 
108 template<typename Reducer, typename SrcExp, typename DType, int etype>
109 inline ChannelUnpoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
111  const Exp<SrcExp, DType, etype> &data_pooled,
112  const Exp<SrcExp, DType, etype> &grad_pooled, index_t nsize) {
113  return ch_unpool(data_src, data_pooled, grad_pooled, nsize, 1, nsize / 2);
114 }
115 
116 
117 //----------------------
118 // Execution plan
119 //----------------------
120 template<typename Reducer, typename SrcExp, typename DType, int srcdim>
121 struct Plan<ChannelUnpoolingExp<Reducer, SrcExp, DType, srcdim>, DType> {
122  public:
124  : data_src_(e.data_src_), data_pooled_(e.data_pooled_),
125  grad_pooled_(e.grad_pooled_), channel_(e.shape_[srcdim - 3]),
126  height_(e.shape_[srcdim - 2]), pchannel_(e.pchannel_),
127  hnsize_(e.nsize_), stride_(e.kstride_), pad_(e.pad_) {}
128  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
129  using namespace std;
130  const DType vsrc = data_src_.Eval(i, j);
131  const index_t y = i % height_;
132  i /= height_;
133  const index_t c = i % channel_;
134  const index_t n = i / channel_;
135  const index_t x = j;
136  const index_t cstart = c < hnsize_ - pad_ ? 0
137  : (c - (hnsize_ - pad_) + stride_) / stride_;
138  const index_t cend = min((c + pad_ + stride_) / stride_, channel_);
139  DType val = static_cast<DType>(0);
140  for (index_t cc = cstart; cc < cend; ++cc) {
141  val += Reducer::PartialGrad(vsrc,
142  data_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x)) *
143  grad_pooled_.Eval((n * pchannel_ + cc) * height_ + y, x);
144  }
145  return val;
146  }
147 
148  private:
149  Plan<SrcExp, DType> data_src_, data_pooled_, grad_pooled_;
150  const index_t channel_, height_, pchannel_, hnsize_, stride_, pad_;
151 };
152 } // namespace expr
153 } // namespace mshadow
154 #endif // MSHADOW_EXTENSION_CHANNEL_UNPOOL_H_
155 
mshadow::expr::Exp::self
const SubType & self(void) const
Definition: expression.h:82
mshadow::expr::TypeCheckPass
used to help static type check
Definition: expr_engine-inl.h:330
mshadow::expr::ChannelUnpoolingExp::grad_pooled_
const SrcExp & grad_pooled_
gradient data of pooled part, to be propagated down
Definition: channel_unpool.h:48
mshadow::expr::ChannelUnpoolingExp::pad_
index_t pad_
pad
Definition: channel_unpool.h:56
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::expr::ShapeCheck
runtime shape checking template get the shape of an expression, report error if shape mismatch
Definition: expr_engine-inl.h:364
mshadow::expr::ChannelUnpoolingExp
channel unpooling expression, the gradient counterpart of channel pooling over (local nearby) channels, used to implement local response normalization
Definition: channel_unpool.h:40
mshadow::expr::Plan< ChannelUnpoolingExp< Reducer, SrcExp, DType, srcdim >, DType >::Plan
Plan(const ChannelUnpoolingExp< Reducer, SrcExp, DType, srcdim > &e)
Definition: channel_unpool.h:123
mshadow::expr::ShapeCheck::Check
static Shape< dim > Check(const E &t)
mshadow::expr::Plan< ChannelUnpoolingExp< Reducer, SrcExp, DType, srcdim >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: channel_unpool.h:128
mshadow::expr::MakeTensorExp< ChannelUnpoolingExp< Reducer, SrcExp, DType, srcdim >, SrcExp, srcdim, DType >::shape_
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
mshadow::expr::ChannelUnpoolingExp::data_pooled_
const SrcExp & data_pooled_
result of pooled data, corresponds to result of pooling
Definition: channel_unpool.h:46
mshadow::expr::ChannelUnpoolingExp::pchannel_
index_t pchannel_
channel of pooled expression
Definition: channel_unpool.h:50
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::expr::Plan
Definition: expr_engine-inl.h:58
mshadow::expr::ChannelUnpoolingExp::data_src_
const SrcExp & data_src_
source input, corresponds to src in pooling
Definition: channel_unpool.h:44
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow::expr::pad
PaddingExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > pad(const Exp< SrcExp, DType, etype > &src, index_t pad)
padding expression, pad an image with zeros on boundaries, padding affects shape[0],...
Definition: pad.h:71
mshadow::expr::ChannelUnpoolingExp::nsize_
index_t nsize_
kernel (window) size along the channel axis
Definition: channel_unpool.h:52
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::expr::MakeTensorExp
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
mshadow::expr::ChannelUnpoolingExp::kstride_
index_t kstride_
kernel stride along the channel axis
Definition: channel_unpool.h:54
std
Definition: optional.h:251
mshadow::Shape< srcdim >
mshadow::expr::ch_unpool
ChannelUnpoolingExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim > ch_unpool(const Exp< SrcExp, DType, etype > &data_src, const Exp< SrcExp, DType, etype > &data_pooled, const Exp< SrcExp, DType, etype > &grad_pooled, index_t nsize, index_t stride, index_t pad)
channel unpooling: propagates gradients back over (local nearby) channels
Definition: channel_unpool.h:98
mshadow::expr::ChannelUnpoolingExp::ChannelUnpoolingExp
ChannelUnpoolingExp(const SrcExp &data_src, const SrcExp &data_pooled, const SrcExp &grad_pooled, index_t nsize, index_t kstride, index_t pad)
constructor
Definition: channel_unpool.h:58