25 #ifndef MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
26 #define MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_
27 #include "../extension.h"
/*!
 * \brief Fragment of the UnpackPatchToColXExp expression type; only part of the
 *  definition (template header, base class and some constructor statements)
 *  is visible in this excerpt.
 * \tparam SrcExp presumably the type of the source expression -- confirm
 * \tparam DType presumably the element data type -- confirm
 * \tparam srcdim used below to index the image shape at srcdim - 1,
 *  srcdim - 2 and srcdim - 3
 */
38 template<
typename SrcExp,
typename DType,
int srcdim>
// Publicly derives from MakeTensorExp, passing its own type as the first template argument.
40 public MakeTensorExp<UnpackPatchToColXExp<SrcExp, DType, srcdim>,
// Reject inputs whose last two shape entries are smaller than the patch extents.
// NOTE(review): this compares against the raw patch size, while the output-size
// expressions below use the dilated footprint pdilate * (psize - 1) + 1. If index_t
// is unsigned and the dilated footprint exceeds the image, the subtraction could
// wrap -- confirm whether that case is guarded elsewhere.
72 CHECK(imshape[srcdim - 1] >= psize_x && imshape[srcdim - 2] >= psize_y)
73 <<
"UnpackPatchToCol:image shape smaller than patch size";
// Cache channel count, height and width from the last three entries of imshape.
74 this->i_channel_ = imshape[srcdim - 3];
75 this->i_height_ = imshape[srcdim - 2];
76 this->i_width_ = imshape[srcdim - 1];
// Tails of two expressions whose beginnings are outside this excerpt: each ends with
// a dilated-footprint term, a division by the stride of one axis, and a final + 1.
80 (pdilate_y * (psize_y - 1) + 1)) / pstride_y + 1;
82 (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
// Entry 1 of the result shape is o_height * o_width * num; those three values are
// defined on lines not shown here.
83 this->
shape_[1] = o_height * o_width * num;
/*!
 * \brief Overload that applies one stride and one dilation value to both axes:
 *  pstride and pdilate are each forwarded twice, together with psize_y and
 *  psize_x, when building the returned UnpackPatchToColXExp from img.self().
 * \note The function name and parameter list are on lines outside this excerpt.
 */
107 template<
typename SrcExp,
typename DType,
int etype>
// The dimension argument of the returned expression is taken from ExpInfo<SrcExp>::kDim.
108 inline UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
// Tail of a dimension-requirement check whose head is not visible here
// (presumably a compile-time check -- confirm).
112 ::Error_Expression_Does_Not_Meet_Dimension_Req();
// Argument order: psize_y, psize_x, stride (y, x), dilation (y, x), going by the
// per-axis overload's argument names -- confirm against the constructor.
114 (img.
self(), psize_y, psize_x, pstride, pstride, pdilate, pdilate);
/*!
 * \brief Overload that takes separate stride and dilation values per axis:
 *  pstride_y_, pstride_x_, pdilate_y_ and pdilate_x_ are forwarded individually,
 *  together with psize_y and psize_x, when building the returned
 *  UnpackPatchToColXExp from img.self().
 * \note The function name and parameter list are on lines outside this excerpt.
 */
120 template<
typename SrcExp,
typename DType,
int etype>
// The dimension argument of the returned expression is taken from ExpInfo<SrcExp>::kDim.
121 inline UnpackPatchToColXExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
// Tail of a dimension-requirement check whose head is not visible here
// (presumably a compile-time check -- confirm).
126 ::Error_Expression_Does_Not_Meet_Dimension_Req();
// Arguments are passed y-before-x for patch size, stride and dilation.
128 (img.
self(), psize_y, psize_x, pstride_y_, pstride_x_, pdilate_y_, pdilate_x_);
/*!
 * \brief Evaluation-side fragment for the unpack-patch-to-column expression: it
 *  copies the patch geometry from an expression object e, derives the output
 *  height and width, and maps a coordinate pair (i, j) to one element of src_.
 * \note The enclosing type's name, the constructor signature and the signature of
 *  the evaluating method are on lines outside this excerpt.
 */
133 template<
typename SrcExp,
typename DType,
int srcdim>
// Copy patch size, stride, channel count, dilation and image extents from e.
138 psize_y_(e.psize_y_), psize_x_(e.psize_x_),
139 pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_),
140 i_channel_(e.i_channel_), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_),
141 i_height_(e.i_height_), i_width_(e.i_width_),
// Output extent per axis: (image extent - dilated footprint) / stride + 1, where the
// dilated footprint is pdilate * (psize - 1) + 1. These read members declared
// earlier in the member list, so they are already initialized at this point.
142 o_height_((i_height_ - (pdilate_y_ * (psize_y_ - 1) + 1)) / pstride_y_ + 1),
143 o_width_((i_width_ - (pdilate_x_ * (psize_x_ - 1) + 1)) / pstride_x_ + 1) {}
// Decompose i as (c * psize_y_ + ky) * psize_x_ + kx. `%` and `*` share precedence and
// associate left to right, so each offset is (index % psize) * dilation.
145 const index_t x_offset = i % psize_x_ * pdilate_x_;
146 const index_t idivp = i / psize_x_;
147 const index_t y_offset = idivp % psize_y_ * pdilate_y_;
148 const index_t c = idivp / psize_y_;
// Decompose j as (n * o_height_ + oy) * o_width_ + ox. The window origin along each
// axis is the output position times the stride; the in-patch offset is added to it.
149 const index_t x = (j % o_width_) * pstride_x_ + x_offset;
150 const index_t jdivw = j / o_width_;
151 const index_t y = (jdivw % o_height_) * pstride_y_ + y_offset;
152 const index_t n = jdivw / o_height_;
// Only coordinates inside the image read from the source; the result for the other
// case is on lines not shown in this excerpt.
154 if (x < i_width_ && y < i_height_) {
// First argument flattens (n, c, y) as (n * i_channel_ + c) * i_height_ + y; second is x.
155 return src_.Eval((n * i_channel_ + c) * i_height_ + y, x);
// Patch size (y, x), stride (y, x) and the channel count.
163 const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_;
// Dilation (y, x).
164 const index_t pdilate_y_, pdilate_x_;
// Image extents and the output extents derived from them in the initializer list.
165 const index_t i_height_, i_width_, o_height_, o_width_;
169 #endif // MSHADOW_EXTENSION_UNPACK_PATCH2COL_H_