mxnet
unpack_patch2col.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
8 #define MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
9 #include "../extension.h"
10 namespace mshadow {
11 namespace expr {
20 template<typename SrcExp, typename DType, int srcdim>
22  public MakeTensorExp<UnpackPatchToColXExp<SrcExp, DType, srcdim>,
23  SrcExp, 2, DType>{
25  const SrcExp &img_;
43  UnpackPatchToColXExp(const SrcExp &img,
44  index_t psize_y,
45  index_t psize_x,
46  index_t pstride_y,
47  index_t pstride_x,
48  index_t pdilate_y,
49  index_t pdilate_x)
50  : img_(img), psize_y_(psize_y), psize_x_(psize_x),
51  pstride_y_(pstride_y), pstride_x_(pstride_x),
52  pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
54  CHECK(imshape[srcdim - 1] >= psize_x && imshape[srcdim - 2] >= psize_y)
55  << "UnpackPatchToCol:image shape smaller than patch size";
56  this->i_channel_ = imshape[srcdim - 3];
57  this->i_height_ = imshape[srcdim - 2];
58  this->i_width_ = imshape[srcdim - 1];
59  // calculate number of batches
60  const index_t num = imshape.ProdShape(0, srcdim - 3);
61  const index_t o_height = (i_height_ -
62  (pdilate_y * (psize_y - 1) + 1)) / pstride_y + 1;
63  const index_t o_width = (i_width_ -
64  (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
65  this->shape_[1] = o_height * o_width * num;
66  this->shape_[0] = psize_y * psize_x * i_channel_;
67  }
68 };
69 
89 template<typename SrcExp, typename DType, int etype>
92  index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate) {
94  ::Error_Expression_Does_Not_Meet_Dimension_Req();
96  (img.self(), psize_y, psize_x, pstride, pstride, pdilate, pdilate);
97 }
98 
102 template<typename SrcExp, typename DType, int etype>
105  index_t psize_y, index_t psize_x, index_t pstride_y_, index_t pstride_x_,
108  ::Error_Expression_Does_Not_Meet_Dimension_Req();
110  (img.self(), psize_y, psize_x, pstride_y_, pstride_x_, pdilate_y_, pdilate_x_);
111 }
112 //----------------------
113 // Execution plan
114 //----------------------
115 template<typename SrcExp, typename DType, int srcdim>
116 struct Plan<UnpackPatchToColXExp<SrcExp, DType, srcdim>, DType> {
117  public:
119  :src_(MakePlan(e.img_)),
124  o_height_((i_height_ - (pdilate_y_ * (psize_y_ - 1) + 1)) / pstride_y_ + 1),
125  o_width_((i_width_ - (pdilate_x_ * (psize_x_ - 1) + 1)) / pstride_x_ + 1) {}
126  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
127  const index_t x_offset = i % psize_x_ * pdilate_x_;
128  const index_t idivp = i / psize_x_;
129  const index_t y_offset = idivp % psize_y_ * pdilate_y_;
130  const index_t c = idivp / psize_y_;
131  const index_t x = (j % o_width_) * pstride_x_ + x_offset;
132  const index_t jdivw = j / o_width_;
133  const index_t y = (jdivw % o_height_) * pstride_y_ + y_offset;
134  const index_t n = jdivw / o_height_;
135 
136  if (x < i_width_ && y < i_height_) {
137  return src_.Eval((n * i_channel_ + c) * i_height_ + y, x);
138  } else {
139  return DType(0.0f);
140  }
141  }
142 
143  private:
144  Plan<SrcExp, DType> src_;
147  const index_t i_height_, i_width_, o_height_, o_width_;
148 };
149 } // namespace expr
150 } // namespace mshadow
151 #endif // MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
Definition: expr_engine-inl.h:40
used to help static type check
Definition: expr_engine-inl.h:312
UnpackPatchToColXExp(const SrcExp &img, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x, index_t pdilate_y, index_t pdilate_x)
constructor
Definition: unpack_patch2col.h:43
index_t psize_x_
patch width
Definition: unpack_patch2col.h:29
unpack local (overlap) patches of image to column of mat, can be used to implement convolution...
Definition: unpack_patch2col.h:21
index_t pstride_y_
patch stride
Definition: unpack_patch2col.h:31
UnpackPatchToColXExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > unpack_patch2col(const Exp< SrcExp, DType, etype > &img, index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
unpack local (overlap) patches of image to column of mat, can be used to implement convolution after ...
Definition: unpack_patch2col.h:91
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:204
int32_t index_t
type that will be used for index
Definition: base.h:291
index_t i_height_
height of img
Definition: unpack_patch2col.h:39
index_t i_width_
width of img
Definition: unpack_patch2col.h:41
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:139
index_t pdilate_x_
Definition: unpack_patch2col.h:35
const SrcExp & img_
source operand
Definition: unpack_patch2col.h:25
index_t i_channel_
number of input channel
Definition: unpack_patch2col.h:37
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: unpack_patch2col.h:126
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
Plan(const UnpackPatchToColXExp< SrcExp, DType, srcdim > &e)
Definition: unpack_patch2col.h:118
index_t pdilate_y_
patch dilate
Definition: unpack_patch2col.h:34
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:25
namespace for mshadow
Definition: base.h:282
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:29
index_t pstride_x_
Definition: unpack_patch2col.h:32
index_t psize_y_
patch height
Definition: unpack_patch2col.h:27