mxnet
unpack_patch2col.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
26 #define MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
27 #include "../extension.h"
28 namespace mshadow {
29 namespace expr {
// Expression node that unpacks local (possibly overlapping, possibly dilated)
// patches of an image expression into a 2-D matrix; it derives from
// MakeTensorExp with a result dimension of 2.
// NOTE(review): this listing elides several original lines (the struct header
// on line 39, the member declarations on lines 44-60 and the statement on
// line 71 that defines `imshape`), so the text below is not compilable as-is.
38 template<typename SrcExp, typename DType, int srcdim>
40  public MakeTensorExp<UnpackPatchToColXExp<SrcExp, DType, srcdim>,
41  SrcExp, 2, DType>{
 // source operand, stored as a const reference (not a copy)
43  const SrcExp &img_;
 // Constructor: records the patch geometry and computes the shape of the
 // unpacked matrix.
 //   img        source image expression
 //   psize_y    patch height
 //   psize_x    patch width
 //   pstride_y  patch stride along y
 //   pstride_x  patch stride along x
 //   pdilate_y  patch dilation along y
 //   pdilate_x  patch dilation along x
61  UnpackPatchToColXExp(const SrcExp &img,
62  index_t psize_y,
63  index_t psize_x,
64  index_t pstride_y,
65  index_t pstride_x,
66  index_t pdilate_y,
67  index_t pdilate_x)
68  : img_(img), psize_y_(psize_y), psize_x_(psize_x),
69  pstride_y_(pstride_y), pstride_x_(pstride_x),
70  pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
 // `imshape` is defined on the elided line 71; presumably it is the shape of
 // img_ obtained through ShapeCheck -- TODO confirm against the real header.
 // The check compares the last two dimensions with the raw patch size only.
 // NOTE(review): dilation is not part of this check, so a dilated patch
 // (pdilate * (psize - 1) + 1) may still be larger than the image.
72  CHECK(imshape[srcdim - 1] >= psize_x && imshape[srcdim - 2] >= psize_y)
73  << "UnpackPatchToCol:image shape smaller than patch size";
 // the three trailing dimensions are read as channel, height, width
74  this->i_channel_ = imshape[srcdim - 3];
75  this->i_height_ = imshape[srcdim - 2];
76  this->i_width_ = imshape[srcdim - 1];
77  // calculate number of batches
 // ProdShape(0, srcdim - 3): presumably the product of every dimension that
 // precedes the channel axis -- confirm the range convention in tensor.h.
78  const index_t num = imshape.ProdShape(0, srcdim - 3);
 // Patch positions per axis: (extent - dilated patch extent) / stride + 1,
 // using integer division.
 // NOTE(review): no visible check that pstride_y / pstride_x are non-zero.
79  const index_t o_height = (i_height_ -
80  (pdilate_y * (psize_y - 1) + 1)) / pstride_y + 1;
81  const index_t o_width = (i_width_ -
82  (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
 // shape_[1]: one entry per patch position per batch item
83  this->shape_[1] = o_height * o_width * num;
 // shape_[0]: one entry per (channel, patch row, patch column)
84  this->shape_[0] = psize_y * psize_x * i_channel_;
85  }
86 };
87 
// Factory overload taking a single stride and a single dilation: the visible
// call passes pstride for both the y and x stride and pdilate for both the
// y and x dilation of UnpackPatchToColXExp.
// NOTE(review): original lines 109 (function name and `img` parameter), 111
// (the condition of the compile-time dimension check) and 113 (start of the
// return statement) are elided in this listing.
107 template<typename SrcExp, typename DType, int etype>
108 inline UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
110  index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate) {
 // tail of a static dimension-requirement check; its condition is on the
 // elided line 111
112  ::Error_Expression_Does_Not_Meet_Dimension_Req();
 // constructor arguments: same stride and same dilation on both axes
114  (img.self(), psize_y, psize_x, pstride, pstride, pdilate, pdilate);
115 }
116 
// Factory overload taking separate y/x strides and separate y/x dilations,
// which are forwarded unchanged to the UnpackPatchToColXExp constructor.
// Note that the parameters here carry a trailing underscore although they are
// plain function parameters, not members.
// NOTE(review): original lines 122 (function name and `img` parameter), 125
// (the condition of the compile-time dimension check) and 127 (start of the
// return statement) are elided in this listing.
120 template<typename SrcExp, typename DType, int etype>
121 inline UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
123  index_t psize_y, index_t psize_x, index_t pstride_y_, index_t pstride_x_,
124  index_t pdilate_y_, index_t pdilate_x_) {
 // tail of a static dimension-requirement check; its condition is on the
 // elided line 125
126  ::Error_Expression_Does_Not_Meet_Dimension_Req();
 // constructor arguments in the order: size y/x, stride y/x, dilation y/x
128  (img.self(), psize_y, psize_x, pstride_y_, pstride_x_, pdilate_y_, pdilate_x_);
129 }
130 //----------------------
131 // Execution plan
132 //----------------------
// Execution plan for UnpackPatchToColXExp: Eval(i, j) maps one element of the
// unpacked 2-D matrix back to an element of the source plan, or to zero when
// the computed pixel lies outside the image.
// NOTE(review): the constructor signature (original line 136, taking the
// expression as `e`) is elided in this listing.
133 template<typename SrcExp, typename DType, int srcdim>
134 struct Plan<UnpackPatchToColXExp<SrcExp, DType, srcdim>, DType> {
135  public:
 // Constructor initializer list: builds the source plan from e.img_ and copies
 // the patch geometry out of the expression. o_height_ / o_width_ are
 // recomputed from the already-initialized members (they are declared last
 // below) using (extent - dilated patch extent) / stride + 1.
137  :src_(MakePlan(e.img_)),
138  psize_y_(e.psize_y_), psize_x_(e.psize_x_),
139  pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_),
140  i_channel_(e.i_channel_), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_),
141  i_height_(e.i_height_), i_width_(e.i_width_),
142  o_height_((i_height_ - (pdilate_y_ * (psize_y_ - 1) + 1)) / pstride_y_ + 1),
143  o_width_((i_width_ - (pdilate_x_ * (psize_x_ - 1) + 1)) / pstride_x_ + 1) {}
 // Evaluate element (i, j) of the unpacked matrix.
 //   i is decomposed as (c * psize_y_ + ky) * psize_x_ + kx
 //   j is decomposed as (n * o_height_ + oy) * o_width_ + ox
144  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
 // kx scaled by the x dilation: horizontal offset inside the patch
145  const index_t x_offset = i % psize_x_ * pdilate_x_;
146  const index_t idivp = i / psize_x_;
 // ky scaled by the y dilation: vertical offset inside the patch
147  const index_t y_offset = idivp % psize_y_ * pdilate_y_;
 // remaining quotient is the channel index
148  const index_t c = idivp / psize_y_;
 // patch origin (ox * stride) plus the in-patch offset gives the source x
149  const index_t x = (j % o_width_) * pstride_x_ + x_offset;
150  const index_t jdivw = j / o_width_;
 // same for the source y
151  const index_t y = (jdivw % o_height_) * pstride_y_ + y_offset;
 // remaining quotient is the batch index
152  const index_t n = jdivw / o_height_;
153 
 // Only the upper bounds are tested; positions past the image edge yield 0.
 // The source plan is addressed with two indices: row
 // (n * i_channel_ + c) * i_height_ + y and column x.
154  if (x < i_width_ && y < i_height_) {
155  return src_.Eval((n * i_channel_ + c) * i_height_ + y, x);
156  } else {
157  return DType(0.0f);
158  }
159  }
160 
161  private:
 // plan of the source expression
162  Plan<SrcExp, DType> src_;
 // patch size, patch stride and input channel count
163  const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_;
 // patch dilation
164  const index_t pdilate_y_, pdilate_x_;
 // input extent and number of patch positions along each axis
165  const index_t i_height_, i_width_, o_height_, o_width_;
166 };
167 } // namespace expr
168 } // namespace mshadow
169 #endif // MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
mshadow::expr::Plan< UnpackPatchToColXExp< SrcExp, DType, srcdim >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: unpack_patch2col.h:144
mshadow::expr::UnpackPatchToColXExp::psize_x_
index_t psize_x_
patch width
Definition: unpack_patch2col.h:47
mshadow::expr::UnpackPatchToColXExp::pstride_y_
index_t pstride_y_
patch stride
Definition: unpack_patch2col.h:49
mshadow::expr::Exp::self
const SubType & self(void) const
Definition: expression.h:82
mshadow::expr::UnpackPatchToColXExp
unpack local (overlap) patches of image to column of mat, can be used to implement convolution,...
Definition: unpack_patch2col.h:39
mshadow::expr::TypeCheckPass
used to help static type check
Definition: expr_engine-inl.h:330
mshadow::expr::UnpackPatchToColXExp::psize_y_
index_t psize_y_
patch height
Definition: unpack_patch2col.h:45
mshadow::expr::UnpackPatchToColXExp::i_height_
index_t i_height_
height of img
Definition: unpack_patch2col.h:57
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::expr::UnpackPatchToColXExp::pstride_x_
index_t pstride_x_
Definition: unpack_patch2col.h:50
mshadow::expr::ShapeCheck::Check
static Shape< dim > Check(const E &t)
mshadow::expr::Plan< UnpackPatchToColXExp< SrcExp, DType, srcdim >, DType >::Plan
Plan(const UnpackPatchToColXExp< SrcExp, DType, srcdim > &e)
Definition: unpack_patch2col.h:136
mshadow::expr::UnpackPatchToColXExp::pdilate_x_
index_t pdilate_x_
Definition: unpack_patch2col.h:53
mshadow::expr::UnpackPatchToColXExp::UnpackPatchToColXExp
UnpackPatchToColXExp(const SrcExp &img, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x, index_t pdilate_y, index_t pdilate_x)
constructor
Definition: unpack_patch2col.h:61
mshadow::expr::MakePlan
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
mshadow::expr::MakeTensorExp< UnpackPatchToColXExp< SrcExp, DType, srcdim >, SrcExp, 2, DType >::shape_
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::expr::Plan
Definition: expr_engine-inl.h:58
mshadow::expr::UnpackPatchToColXExp::pdilate_y_
index_t pdilate_y_
patch dilate
Definition: unpack_patch2col.h:52
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow
namespace for mshadow
Definition: base.h:319
mshadow::expr::MakeTensorExp
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
mshadow::Shape< srcdim >
mshadow::expr::UnpackPatchToColXExp::i_width_
index_t i_width_
width of img
Definition: unpack_patch2col.h:59
mshadow::Shape::ProdShape
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:171
mshadow::expr::UnpackPatchToColXExp::img_
const SrcExp & img_
source operand
Definition: unpack_patch2col.h:43
mshadow::expr::unpack_patch2col
UnpackPatchToColXExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > unpack_patch2col(const Exp< SrcExp, DType, etype > &img, index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
unpack local (overlap) patches of image to column of mat, can be used to implement convolution after ...
Definition: unpack_patch2col.h:109
mshadow::expr::UnpackPatchToColXExp::i_channel_
index_t i_channel_
number of input channel
Definition: unpack_patch2col.h:55