mxnet
spatial_unpool.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_
26 #define MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_
27 #include <algorithm>
28 #include "../extension.h"
29 namespace mshadow {
30 namespace expr {
// UnPoolingExp: expression node (derived from MakeTensorExp) that bundles the
// three expressions needed to push a pooling gradient back to the input,
// together with the pooling kernel size and stride.
// Template parameters: Reducer is the pooling reducer type, SrcExp the type
// shared by all three source expressions, DType the element type and srcdim
// the number of dimensions of the source expression.
// NOTE(review): this listing omits some original lines -- the declarations of
// ksize_y_/ksize_x_/kstride_y_/kstride_x_/pshape_y_/pshape_x_ and the
// definitions of the locals `pshape` and `sshape` are not shown here; consult
// the real header before relying on the details below.
38 template<typename Reducer, typename SrcExp, typename DType, int srcdim>
39 struct UnPoolingExp:
40  public MakeTensorExp<UnPoolingExp<Reducer, SrcExp, DType, srcdim>,
41  SrcExp, srcdim, DType> {
  // source input, corresponds to src in pooling (held by reference, not copied)
43  const SrcExp &data_src_;
  // result of pooled data, corresponds to result of pooling (held by reference)
45  const SrcExp &data_pooled_;
  // gradient data of the pooled part, to be propagated down (held by reference)
47  const SrcExp &grad_pooled_;
  // Constructor: stores references to the three expressions, records the
  // kernel size/stride, validates shapes and sets the shape of this expression.
61  UnPoolingExp(const SrcExp &data_src,
62  const SrcExp &data_pooled,
63  const SrcExp &grad_pooled,
64  index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
65  : data_src_(data_src), data_pooled_(data_pooled),
66  grad_pooled_(grad_pooled),
67  ksize_y_(ksize_y), ksize_x_(ksize_x),
68  kstride_y_(kstride_y), kstride_x_(kstride_x) {
70  typedef ShapeCheck<srcdim, SrcExp> ShapeCheckSrcDimSrcExp;
  // data_pooled must have exactly the shape held in `pshape`
  // (presumably the shape of grad_pooled -- defined on an omitted line, TODO confirm).
71  CHECK_EQ(pshape, ShapeCheckSrcDimSrcExp::Check(data_pooled))
72  << "UnPoolingExp: pooled shape mismatch";
  // All dimensions except the last two must agree between `pshape` and
  // `sshape` (presumably the shape of data_src -- TODO confirm).
74  for (int k = 0; k < srcdim - 2; ++k) {
75  CHECK_EQ(pshape[k], sshape[k]) << "UnPoolingExp: pool and src shape mismatch";
76  }
  // Last dimension is stored as x, second-to-last as y.
77  pshape_x_ = pshape[srcdim - 1];
78  pshape_y_ = pshape[srcdim - 2];
  // The expression itself takes the shape `sshape`.
79  this->shape_ = sshape;
80  }
81 };
// unpool: convenience factory producing an
// UnPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim> from three
// expressions of the same SrcExp type plus the kernel size and stride.
// Parameters (per the visible lines): data_pooled, grad_pooled, ksize_y,
// ksize_x, kstride_y, kstride_x; a first parameter data_src is used below.
// NOTE(review): this listing drops the line carrying the function name and
// the data_src parameter, as well as the line that begins the return
// statement -- consult the real header for the complete definition.
98 template<typename Reducer, typename SrcExp, typename DType, int etype>
99 inline UnPoolingExp<Reducer, SrcExp, DType, ExpInfo<SrcExp>::kDim>
101  const Exp<SrcExp, DType, etype> &data_pooled,
102  const Exp<SrcExp, DType, etype> &grad_pooled,
103  index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x) {
  // Each Exp wrapper is unwrapped to its concrete SrcExp through self();
  // kernel size and stride are passed through unchanged.
105  (data_src.self(), data_pooled.self(), grad_pooled.self(),
106  ksize_y, ksize_x, kstride_y, kstride_x);
107 }
108 //----------------------
109 // Execution plan
110 //----------------------
// Plan specialization for UnPoolingExp: evaluates one output element at a
// time by summing gradient contributions from a range of pooled positions.
// NOTE(review): the line holding the constructor head is missing from this
// listing; only its member-initializer list is visible below.
111 template<typename Reducer, typename SrcExp, typename DType, int srcdim>
112 struct Plan<UnPoolingExp<Reducer, SrcExp, DType, srcdim>, DType> {
113  public:
  // Initializers: build sub-plans for the three source expressions with
  // MakePlan and copy the shape / kernel parameters out of the expression `e`.
  // sshape_y_ is the second-to-last dimension of the expression's shape.
115  : data_src_(MakePlan(e.data_src_)), data_pooled_(MakePlan(e.data_pooled_)),
116  grad_pooled_(MakePlan(e.grad_pooled_)), sshape_y_(e.shape_[srcdim - 2]),
117  pshape_y_(e.pshape_y_), pshape_x_(e.pshape_x_),
118  ksize_y_(e.ksize_y_), ksize_x_(e.ksize_x_),
119  kstride_y_(e.kstride_y_), kstride_x_(e.kstride_x_) {}
  // Eval(i, j): returns the accumulated value for source element (i, j).
  // The row index i is split as i = c * sshape_y_ + y; j is used directly as x.
120  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
  // Makes std::min reachable for the unqualified min() calls below.
121  using namespace std;
122  const index_t x = j;
123  const index_t y = i % sshape_y_;
124  const index_t c = i / sshape_y_;
125  const DType vsrc = data_src_.Eval(i, j);
  // Lower bounds of the pooled range: 0 while the coordinate is smaller than
  // the kernel size (this also keeps the subtraction from going negative),
  // otherwise (coord - ksize + kstride) / kstride with integer division.
  // NOTE(review): assuming pooled index p covers source coordinates
  // [p * kstride, p * kstride + ksize), this is the first p whose window
  // still contains the coordinate -- confirm against the pooling expression.
126  const index_t py_min =
127  y < ksize_y_ ? 0 : (y - ksize_y_ + kstride_y_) / kstride_y_;
128  const index_t px_min =
129  x < ksize_x_ ? 0 : (x - ksize_x_ + kstride_x_) / kstride_x_;
  // Exclusive upper bounds, clamped to the pooled extent.
130  const index_t py_max = min((y + kstride_y_) / kstride_y_, pshape_y_);
131  const index_t px_max = min((x + kstride_x_) / kstride_x_, pshape_x_);
132 
  // Sum Reducer::PartialGrad(source value, pooled value) * pooled gradient
  // over every pooled position (py, px) in the computed range; pooled rows
  // are addressed as c * pshape_y_ + py.
133  DType val = static_cast<DType>(0);
134  for (index_t py = py_min; py < py_max; ++py) {
135  for (index_t px = px_min; px < px_max; ++px) {
136  val += Reducer::PartialGrad(vsrc,
137  data_pooled_.Eval(c * pshape_y_ + py, px)) *
138  grad_pooled_.Eval(c * pshape_y_ + py, px);
139  }
140  }
141 
142  return val;
143  }
144 
145  private:
  // Sub-plans for the source input, the pooled data and the pooled gradient.
146  Plan<SrcExp, DType> data_src_, data_pooled_, grad_pooled_;
  // Source height (second-to-last dim) and pooled height / width.
147  const index_t sshape_y_, pshape_y_, pshape_x_;
  // Kernel size in y and x.
148  const index_t ksize_y_, ksize_x_;
  // Kernel stride in y and x.
149  const index_t kstride_y_, kstride_x_;
150 };
151 } // namespace expr
152 } // namespace mshadow
153 #endif // MSHADOW_EXTENSION_SPATIAL_UNPOOL_H_
mshadow::expr::UnPoolingExp::data_pooled_
const SrcExp & data_pooled_
result of pooled data, corresponds to result of pooling
Definition: spatial_unpool.h:45
mshadow::expr::Exp::self
const SubType & self(void) const
Definition: expression.h:82
mshadow::expr::UnPoolingExp::UnPoolingExp
UnPoolingExp(const SrcExp &data_src, const SrcExp &data_pooled, const SrcExp &grad_pooled, index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
constructor
Definition: spatial_unpool.h:61
mshadow::expr::Plan< UnPoolingExp< Reducer, SrcExp, DType, srcdim >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: spatial_unpool.h:120
mshadow::expr::UnPoolingExp::ksize_y_
index_t ksize_y_
kernel size in height
Definition: spatial_unpool.h:53
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::expr::ShapeCheck
runtime shape-checking template: gets the shape of an expression and reports an error if the shapes mismatch
Definition: expr_engine-inl.h:364
mshadow::expr::UnPoolingExp::pshape_y_
index_t pshape_y_
shape of pooled expression
Definition: spatial_unpool.h:49
mshadow::expr::UnPoolingExp
unpooling expression: the reverse operation of pooling, used to pass the gradient back
Definition: spatial_unpool.h:39
mshadow::expr::UnPoolingExp::kstride_y_
index_t kstride_y_
kernel stride in y direction
Definition: spatial_unpool.h:57
mshadow::expr::ShapeCheck::Check
static Shape< dim > Check(const E &t)
mshadow::expr::Plan< UnPoolingExp< Reducer, SrcExp, DType, srcdim >, DType >::Plan
Plan(const UnPoolingExp< Reducer, SrcExp, DType, srcdim > &e)
Definition: spatial_unpool.h:114
mshadow::expr::MakePlan
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
mshadow::expr::MakeTensorExp< UnPoolingExp< Reducer, SrcExp, DType, srcdim >, SrcExp, srcdim, DType >::shape_
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
mshadow::expr::UnPoolingExp::kstride_x_
index_t kstride_x_
kernel stride in x direction
Definition: spatial_unpool.h:59
mshadow::expr::UnPoolingExp::pshape_x_
index_t pshape_x_
shape of pooled expression
Definition: spatial_unpool.h:51
mshadow::expr::unpool
UnPoolingExp< Reducer, SrcExp, DType, ExpInfo< SrcExp >::kDim > unpool(const Exp< SrcExp, DType, etype > &data_src, const Exp< SrcExp, DType, etype > &data_pooled, const Exp< SrcExp, DType, etype > &grad_pooled, index_t ksize_y, index_t ksize_x, index_t kstride_y, index_t kstride_x)
unpooling gradient for 4D, backprop gradient value back, reverse operation of pooling,...
Definition: spatial_unpool.h:100
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::expr::Plan
Definition: expr_engine-inl.h:58
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow::expr::UnPoolingExp::grad_pooled_
const SrcExp & grad_pooled_
gradient data of the pooled part, to be propagated down
Definition: spatial_unpool.h:47
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::expr::MakeTensorExp
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
std
Definition: optional.h:251
mshadow::Shape< srcdim >
mshadow::expr::UnPoolingExp::data_src_
const SrcExp & data_src_
source input, corresponds to src in pooling
Definition: spatial_unpool.h:43
mshadow::expr::UnPoolingExp::ksize_x_
index_t ksize_x_
kernel size in width
Definition: spatial_unpool.h:55