mxnet
pack_col2patch.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_
27 #define MSHADOW_EXTENSION_PACK_COL2PATCH_H_
28 #include <algorithm>
29 #include "../extension.h"
30 namespace mshadow {
31 namespace expr {
// Expression node for packing a column matrix back into an image-shaped tensor.
// NOTE(review): this listing omits some original lines (the struct declaration line,
// most member declarations and the definition of `sshape`), so those parts are not
// visible here and are not described.
//
// Template parameters:
//   SrcExp - type of the source expression
//   DType  - element type
//   dstdim - number of dimensions of the resulting (image) shape
40 template<typename SrcExp, typename DType, int dstdim>
42  public MakeTensorExp<PackColToPatchXExp<SrcExp, DType, dstdim>,
43  SrcExp, dstdim, DType> {
// Source operand, held by reference (not copied).
45  const SrcExp &src_;
// Constructor: stores the source and the patch geometry, sets the expression shape to
// `imshape`, and verifies that the source shape is consistent with that geometry.
//   src                  - source expression
//   imshape              - target shape; its last two entries are used as height and width
//   psize_y, psize_x     - patch size along y and x
//   pstride_y, pstride_x - patch stride along y and x
//   pdilate_y, pdilate_x - patch dilation along y and x
57  PackColToPatchXExp(const SrcExp &src, Shape<dstdim> imshape,
58  index_t psize_y, index_t psize_x,
59  index_t pstride_y, index_t pstride_x,
60  index_t pdilate_y, index_t pdilate_x)
61  :src_(src), psize_y_(psize_y), psize_x_(psize_x),
62  pstride_y_(pstride_y), pstride_x_(pstride_x),
63  pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
64  this->shape_ = imshape;
// Number of patch positions along y: the dilated patch extent is
// pdilate_y * (psize_y - 1) + 1, and the count is
// (height - dilated extent) / stride + 1 using integer division.
65  const index_t o_height = (imshape[dstdim - 2] -
66  (pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1;
// Same computation along x, using the last dimension of imshape.
67  const index_t o_width = (imshape[dstdim - 1] -
68  (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
// `sshape` is defined on a line not shown in this listing; presumably it is the
// checked shape of `src` - TODO confirm.
// sshape[1] must equal the number of patch positions times
// imshape.ProdShape(0, dstdim - 3) (presumably the product of the leading
// dimensions - confirm against Shape::ProdShape).
70  CHECK_EQ(sshape[1], o_height * o_width * imshape.ProdShape(0, dstdim - 3))
71  << "PackColToPatchExp: src.size(1) mismatch";
// sshape[0] must equal patch area times the extent of dimension dstdim - 3.
72  CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3])
73  << "PackColToPatchExp: src.size(0) mismatch";
74  }
75 };
// Factory overload taking one stride and one dilation value that are applied to both
// the y and x axes (each is passed twice to the PackColToPatchXExp constructor below).
// NOTE(review): the lines carrying the return type and function name are not shown in
// this listing.
//   imshape          - target image shape
//   psize_y, psize_x - patch size along y and x
//   pstride          - patch stride used for both axes
//   pdilate          - patch dilation used for both axes
89 template<typename SrcExp, typename DType, int dstdim, int etype>
92  Shape<dstdim> imshape, index_t psize_y,
93  index_t psize_x, index_t pstride, index_t pdilate) {
// Tail of a statement whose first line is not shown in this listing; presumably a
// compile-time check on the dimension of the source expression - TODO confirm.
95  ::Error_Expression_Does_Not_Meet_Dimension_Req();
// Runtime check: the last two image dimensions must be at least the patch size.
// NOTE(review): this compares against the raw patch size, while the constructor's
// output-size formula subtracts the dilated extent pdilate * (psize - 1) + 1; a
// dilation > 1 can pass this check with an image smaller than the dilated patch -
// confirm whether that case is intended to be rejected.
96  CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
97  << "PackColToPatch:image shape smaller than patch size";
98  return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape,
99  psize_y, psize_x, pstride, pstride,
100  pdilate, pdilate);
101 }
// Factory overload taking separate stride and dilation values for the y and x axes,
// forwarded unchanged to the PackColToPatchXExp constructor below.
// NOTE(review): the lines carrying the return type and function name are not shown in
// this listing.
//   imshape              - target image shape
//   psize_y, psize_x     - patch size along y and x
//   pstride_y, pstride_x - patch stride along y and x
//   pdilate_y, pdilate_x - patch dilation along y and x
105 template<typename SrcExp, typename DType, int dstdim, int etype>
108  Shape<dstdim> imshape, index_t psize_y,
109  index_t psize_x, index_t pstride_y, index_t pstride_x,
110  index_t pdilate_y, index_t pdilate_x) {
// Tail of a statement whose first line is not shown in this listing; presumably a
// compile-time check on the dimension of the source expression - TODO confirm.
112  ::Error_Expression_Does_Not_Meet_Dimension_Req();
// Runtime check: the last two image dimensions must be at least the patch size.
// NOTE(review): the raw patch size is compared here, not the dilated extent
// pdilate * (psize - 1) + 1 used by the constructor's output-size formula - confirm
// whether dilated patches larger than the image should be rejected.
113  CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
114  << "PackColToPatch:image shape smaller than patch size";
115  return PackColToPatchXExp<SrcExp, DType, dstdim>(src.self(), imshape,
116  psize_y, psize_x, pstride_y, pstride_x,
117  pdilate_y, pdilate_x);
118 }
119 
120 //----------------------
121 // Execution plan
122 //----------------------
// Execution plan specialization for PackColToPatchXExp: Eval(i, j) computes one
// element of the packed result by summing entries of the source plan.
// NOTE(review): the constructor's first lines and some member declarations are not
// shown in this listing.
123 template<typename SrcExp, typename DType, int dstdim>
124 struct Plan<PackColToPatchXExp<SrcExp, DType, dstdim>, DType> {
125  public:
// Remainder of the constructor's member-initializer list (its beginning is not shown).
// o_height_ / o_width_ are computed from the member values psize_*_, pstride_*_ and
// pdilate_*_, so they rely on those members being initialized first; they use the
// same formula as the expression constructor:
// (extent - (dilate * (size - 1) + 1)) / stride + 1.
129  i_channel_(e.shape_[dstdim - 3]), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_),
130  i_height_(e.shape_[dstdim - 2]),
131  o_height_((e.shape_[dstdim - 2] - (pdilate_y_ * (psize_y_ - 1) + 1)) /
132  pstride_y_ + 1),
133  o_width_((e.shape_[dstdim - 1] - (pdilate_x_ * (psize_x_ - 1) + 1)) /
134  pstride_x_ + 1) {
135  // note: i/o conventions are the same as unpack
136  }
// Evaluates the result at row i, column j.
// The row index i is decomposed as ((n * i_channel_) + c) * i_height_ + y, and the
// column index j is used directly as x.
// Returns the sum of the source-plan entries selected by the (py, px) ranges below.
137  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
// Lets the unqualified `min` calls below resolve to std::min.
138  using namespace std;
139  const index_t y = i % i_height_;
140  const index_t idivh = i / i_height_;
141  const index_t c = idivh % i_channel_;
142  const index_t n = idivh / i_channel_;
143  const index_t x = j;
144 
// Extent covered by a dilated patch along each axis.
145  const index_t psize_y_dilate = (pdilate_y_ * (psize_y_ - 1) + 1);
146  const index_t psize_x_dilate = (pdilate_x_ * (psize_x_ - 1) + 1);
147 
// First patch position to visit along each axis: y % pdilate_y_ when y lies inside
// the first dilated patch extent, otherwise
// (y - psize_y_dilate + pstride_y_) / pstride_y_ (integer division); same for x.
148  const index_t py_min =
149  y < psize_y_dilate ? y % pdilate_y_ : (y-psize_y_dilate + pstride_y_) / pstride_y_;
150  const index_t px_min =
151  x < psize_x_dilate ? x % pdilate_x_ : (x-psize_x_dilate + pstride_x_) / pstride_x_;
// Exclusive upper bound on patch positions, clamped to the number of positions.
152  const index_t py_max = min((y + pstride_y_) / pstride_y_, o_height_);
153  const index_t px_max = min((x + pstride_x_) / pstride_x_, o_width_);
154  DType res = static_cast<DType>(0);
// Patch positions are visited in steps of the dilation along each axis.
155  for (index_t py = py_min; py < py_max; py += pdilate_y_) {
156  for (index_t px = px_min; px < px_max; px += pdilate_x_) {
// Source row: (c * psize_y_ + ky) * psize_x_ + kx, where
//   ky = (y - py * pstride_y_) / pdilate_y_ and kx = (x - px * pstride_x_) / pdilate_x_.
// Source column: (n * o_height_ + py) * o_width_ + px.
157  res += src_.Eval(((c * psize_y_ + (y - py*pstride_y_) / pdilate_y_) * psize_x_ +
158  (x - px * pstride_x_) / pdilate_x_),
159  (n * o_height_ + py) * o_width_ + px);
160  }
161  }
162  return res;
163  }
164 
165  private:
// Patch size and stride per axis, and the extent of dimension dstdim - 3 of the
// expression shape. (The declarations of src_ and pdilate_*_ are not shown here.)
167  const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_;
// Extent of dimension dstdim - 2, and the number of patch positions along y and x.
169  const index_t i_height_, o_height_, o_width_;
170 };
171 } // namespace expr
172 } // namespace mshadow
173 #endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_
PackColToPatchXExp< SrcExp, DType, dstdim > pack_col2patch(const expr::Exp< SrcExp, DType, etype > &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
reverse operation of UnpackPatchToCol (unpacking patches to columns), can be used to implement deconvolution
Definition: pack_col2patch.h:91
index_t pdilate_x_
Definition: pack_col2patch.h:55
Definition: expr_engine-inl.h:59
used to help static type check
Definition: expr_engine-inl.h:331
shape of a tensor
Definition: tensor.h:54
const SrcExp & src_
source operand
Definition: pack_col2patch.h:45
Definition: optional.h:241
index_t pstride_x_
Definition: pack_col2patch.h:52
index_t psize_y_
patch height
Definition: pack_col2patch.h:47
reverse operation of UnpackPatchToCol, used to backpropagate the gradient; this is a version supporting mu...
Definition: pack_col2patch.h:41
index_t pstride_y_
patch stride
Definition: pack_col2patch.h:51
index_t psize_x_
patch width
Definition: pack_col2patch.h:49
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:223
int32_t index_t
type that will be used for index
Definition: base.h:336
PackColToPatchXExp(const SrcExp &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x, index_t pdilate_y, index_t pdilate_x)
constructor
Definition: pack_col2patch.h:57
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:158
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: pack_col2patch.h:137
index_t pdilate_y_
patch dilate
Definition: pack_col2patch.h:54
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:44
overloaded + operator between half_t and bf16_t
Definition: base.h:327
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48
Plan(const PackColToPatchXExp< SrcExp, DType, dstdim > &e)
Definition: pack_col2patch.h:126