mxnet
pack_col2patch.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_
26 #define MSHADOW_EXTENSION_PACK_COL2PATCH_H_
27 #include <algorithm>
28 #include "../extension.h"
29 namespace mshadow {
30 namespace expr {
// PackColToPatchXExp: expression node whose value has shape `imshape` and is
// computed from a 2-D source expression `src`. Per the doxygen summary it is
// the reverse of UnpackPatchToCol (used to backprop the gradient). The
// CHECK_EQs below require:
//   src.size(0) == psize_y * psize_x * imshape[dstdim - 3]
//   src.size(1) == o_height * o_width * imshape.ProdShape(0, dstdim - 3)
// NOTE(review): this listing omits the `struct PackColToPatchXExp :` head
// (source line 40), the member fields (45-55) and line 68, which presumably
// defines `sshape` as the checked shape of `src_` (the doxygen tooltips
// reference ShapeCheck::Check) -- confirm against the real header.
39 template<typename SrcExp, typename DType, int dstdim>
41  public MakeTensorExp<PackColToPatchXExp<SrcExp, DType, dstdim>,
42  SrcExp, dstdim, DType> {
44  const SrcExp &src_;
56  PackColToPatchXExp(const SrcExp &src, Shape<dstdim> imshape,
57  index_t psize_y, index_t psize_x,
58  index_t pstride_y, index_t pstride_x,
59  index_t pdilate_y, index_t pdilate_x)
60  :src_(src), psize_y_(psize_y), psize_x_(psize_x),
61  pstride_y_(pstride_y), pstride_x_(pstride_x),
62  pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
63  this->shape_ = imshape;
// Number of patch positions along each axis:
//   (image extent - dilated patch extent) / stride + 1,
// where the dilated extent is pdilate * (psize - 1) + 1.
// NOTE(review): if the image is smaller than the dilated extent the numerator
// is negative and these counts are meaningless; the factories only check the
// undilated patch size.
64  const index_t o_height = (imshape[dstdim - 2] -
65  (pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1;
66  const index_t o_width = (imshape[dstdim - 1] -
67  (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
// Source columns: one per (leading-dims index, patch row, patch column).
69  CHECK_EQ(sshape[1], o_height * o_width * imshape.ProdShape(0, dstdim - 3))
70  << "PackColToPatchExp: src.size(1) mismatch";
// Source rows: one per (channel, kernel tap y, kernel tap x).
71  CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3])
72  << "PackColToPatchExp: src.size(0) mismatch";
73  }
74 };
// pack_col2patch (uniform overload): builds a PackColToPatchXExp over `src`,
// using the same `pstride` and `pdilate` for both the y and x axes (each is
// passed twice to the constructor below).
// NOTE(review): this listing omits source line 90 (per the doxygen tooltip,
// `pack_col2patch(const expr::Exp<SrcExp, DType, etype> &src,`) and line 93,
// presumably the `TypeCheckPass<...>` head of the dimension check on line 94.
88 template<typename SrcExp, typename DType, int dstdim, int etype>
89 inline PackColToPatchXExp<SrcExp, DType, dstdim>
91  Shape<dstdim> imshape, index_t psize_y,
92  index_t psize_x, index_t pstride, index_t pdilate) {
94  ::Error_Expression_Does_Not_Meet_Dimension_Req();
// NOTE(review): this only compares against the undilated patch size. With
// pdilate > 1 an image smaller than pdilate * (psize - 1) + 1 still passes,
// and the o_height/o_width computed in PackColToPatchXExp's constructor then
// start from a negative numerator; consider checking the dilated extent.
95  CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
96  << "PackColToPatch:image shape smaller than patch size";
97  return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape,
98  psize_y, psize_x, pstride, pstride,
99  pdilate, pdilate);
100 }
// Per-axis factory overload: builds a PackColToPatchXExp over `src` with
// separate y/x strides and dilations, forwarded unchanged to the constructor.
// NOTE(review): this listing omits source line 106 (the function name and
// `src` parameter; presumably `pack_col2patch(const expr::Exp<...> &src,`)
// and line 110, presumably the `TypeCheckPass<...>` head of the dimension
// check on line 111.
104 template<typename SrcExp, typename DType, int dstdim, int etype>
105 inline PackColToPatchXExp<SrcExp, DType, dstdim>
107  Shape<dstdim> imshape, index_t psize_y,
108  index_t psize_x, index_t pstride_y, index_t pstride_x,
109  index_t pdilate_y, index_t pdilate_x) {
111  ::Error_Expression_Does_Not_Meet_Dimension_Req();
// NOTE(review): this only compares against the undilated patch size; with
// pdilate_y/pdilate_x > 1 the dilated extent pdilate * (psize - 1) + 1 can
// exceed the image and still pass, giving a negative numerator when the
// constructor computes o_height/o_width.
112  CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
113  << "PackColToPatch:image shape smaller than patch size";
114  return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape,
115  psize_y, psize_x, pstride_y, pstride_x,
116  pdilate_y, pdilate_x);
117 }
118 
119 //----------------------
120 // Execution plan
121 //----------------------
// Execution plan for PackColToPatchXExp: holds a plan for the source
// expression plus the patch geometry that Eval needs.
// NOTE(review): source line 125, the constructor declarator (per the doxygen
// tooltip: `Plan(const PackColToPatchXExp< SrcExp, DType, dstdim > &e)`), is
// missing from this listing.
122 template<typename SrcExp, typename DType, int dstdim>
123 struct Plan<PackColToPatchXExp<SrcExp, DType, dstdim>, DType> {
124  public:
126  :src_(MakePlan(e.src_)), psize_y_(e.psize_y_),
127  psize_x_(e.psize_x_), pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_),
128  i_channel_(e.shape_[dstdim - 3]), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_),
129  i_height_(e.shape_[dstdim - 2]),
// o_height_/o_width_ use the same formula as PackColToPatchXExp's constructor.
// They read the psize_*_, pstride_*_ and pdilate_*_ members, which is safe
// only because those members are declared (and thus initialized) before
// o_height_/o_width_ in this class.
130  o_height_((e.shape_[dstdim - 2] - (pdilate_y_ * (psize_y_ - 1) + 1)) /
131  pstride_y_ + 1),
132  o_width_((e.shape_[dstdim - 1] - (pdilate_x_ * (psize_x_ - 1) + 1)) /
133  pstride_x_ + 1) {
134  // note: i/o conventions are the same as for unpack (unpack_patch2col)
135  }
136  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
137  using namespace std;
138  const index_t y = i % i_height_;
139  const index_t idivh = i / i_height_;
140  const index_t c = idivh % i_channel_;
141  const index_t n = idivh / i_channel_;
142  const index_t x = j;
143 
144  const index_t psize_y_dilate = (pdilate_y_ * (psize_y_ - 1) + 1);
145  const index_t psize_x_dilate = (pdilate_x_ * (psize_x_ - 1) + 1);
146 
147  const index_t py_min =
148  y < psize_y_dilate ? y % pdilate_y_ : (y-psize_y_dilate + pstride_y_) / pstride_y_;
149  const index_t px_min =
150  x < psize_x_dilate ? x % pdilate_x_ : (x-psize_x_dilate + pstride_x_) / pstride_x_;
151  const index_t py_max = min((y + pstride_y_) / pstride_y_, o_height_);
152  const index_t px_max = min((x + pstride_x_) / pstride_x_, o_width_);
153  DType res = static_cast<DType>(0);
154  for (index_t py = py_min; py < py_max; py += pdilate_y_) {
155  for (index_t px = px_min; px < px_max; px += pdilate_x_) {
156  res += src_.Eval(((c * psize_y_ + (y - py*pstride_y_) / pdilate_y_) * psize_x_ +
157  (x - px * pstride_x_) / pdilate_x_),
158  (n * o_height_ + py) * o_width_ + px);
159  }
160  }
161  return res;
162  }
163 
164  private:
// Plan of the 2-D source expression that Eval reads from.
165  Plan<SrcExp, DType> src_;
// Patch size and stride per axis, plus the image channel count
// (taken from shape_[dstdim - 3] in the constructor).
166  const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_;
// Patch dilation per axis.
167  const index_t pdilate_y_, pdilate_x_;
// Image height (shape_[dstdim - 2]) and the number of patch positions along
// y and x. Keep these declared after the members above: the constructor
// computes o_height_/o_width_ from them, and members initialize in
// declaration order.
168  const index_t i_height_, o_height_, o_width_;
169 };
170 } // namespace expr
171 } // namespace mshadow
172 #endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_
mshadow::expr::PackColToPatchXExp::PackColToPatchXExp
PackColToPatchXExp(const SrcExp &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x, index_t pdilate_y, index_t pdilate_x)
constructor
Definition: pack_col2patch.h:56
mshadow::expr::Exp::self
const SubType & self(void) const
Definition: expression.h:82
mshadow::expr::TypeCheckPass
used to help static type check
Definition: expr_engine-inl.h:330
mshadow::expr::Plan< PackColToPatchXExp< SrcExp, DType, dstdim >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: pack_col2patch.h:136
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::expr::PackColToPatchXExp::pdilate_x_
index_t pdilate_x_
Definition: pack_col2patch.h:54
mshadow::expr::PackColToPatchXExp::psize_x_
index_t psize_x_
patch width
Definition: pack_col2patch.h:48
mshadow::expr::ShapeCheck::Check
static Shape< dim > Check(const E &t)
mshadow::expr::MakePlan
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
mshadow::expr::MakeTensorExp< PackColToPatchXExp< SrcExp, DType, dstdim >, SrcExp, dstdim, DType >::shape_
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
mshadow::expr::PackColToPatchXExp::psize_y_
index_t psize_y_
patch height
Definition: pack_col2patch.h:46
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::expr::pack_col2patch
PackColToPatchXExp< SrcExp, DType, dstdim > pack_col2patch(const expr::Exp< SrcExp, DType, etype > &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
reverse operation of unpack_patch2col; can be used to implement deconvolution
Definition: pack_col2patch.h:90
mshadow::expr::Plan
Definition: expr_engine-inl.h:58
mshadow::expr::Plan< PackColToPatchXExp< SrcExp, DType, dstdim >, DType >::Plan
Plan(const PackColToPatchXExp< SrcExp, DType, dstdim > &e)
Definition: pack_col2patch.h:125
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow::expr::PackColToPatchXExp::pdilate_y_
index_t pdilate_y_
patch dilate
Definition: pack_col2patch.h:53
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::expr::MakeTensorExp
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
mshadow::expr::PackColToPatchXExp
reverse operation of UnpackPatchToCol, used to backprop the gradient; this is a version supporting mu...
Definition: pack_col2patch.h:40
std
Definition: optional.h:251
mshadow::expr::PackColToPatchXExp::pstride_y_
index_t pstride_y_
patch stride
Definition: pack_col2patch.h:50
mshadow::Shape
shape of a tensor
Definition: tensor.h:64
mshadow::expr::PackColToPatchXExp::src_
const SrcExp & src_
source operand
Definition: pack_col2patch.h:44
mshadow::expr::PackColToPatchXExp::pstride_x_
index_t pstride_x_
Definition: pack_col2patch.h:51
mshadow::Shape::ProdShape
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:171