25 #ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_
26 #define MSHADOW_EXTENSION_PACK_COL2PATCH_H_
28 #include "../extension.h"
/*!
 * \brief expression that packs a 2-D column matrix back into a tensor of
 *   rank dstdim (per its name, the reverse of a patch-to-column unpack —
 *   NOTE(review): confirm against the matching unpack expression).
 * \tparam SrcExp source expression; the CHECKs below require its shape to be
 *   [psize_y * psize_x * C, o_height * o_width * B], where C = imshape[dstdim - 3]
 *   and B is the product of the leading dstdim - 3 axes of imshape
 * \tparam DType element type
 * \tparam dstdim rank of the destination tensor; axes dstdim - 3, dstdim - 2 and
 *   dstdim - 1 are used as channel, height and width respectively
 */
39 template<
typename SrcExp,
typename DType,
int dstdim>
41 public MakeTensorExp<PackColToPatchXExp<SrcExp, DType, dstdim>,
42 SrcExp, dstdim, DType> {
// Number of patch positions along each spatial axis: a dilated patch spans
// pdilate * (psize - 1) + 1 input pixels and is advanced by pstride pixels.
// NOTE(review): if the image is smaller than that dilated span, the
// subtraction goes negative (or wraps, if index_t is unsigned). The CHECKs in
// pack_col2patch compare the image only against the undilated psize — confirm
// that callers guarantee the dilated span fits.
64 const index_t o_height = (imshape[dstdim - 2] -
65 (pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1;
66 const index_t o_width = (imshape[dstdim - 1] -
67 (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
// Shape contract for the column-matrix source:
//  - sshape[1] must equal the number of patch positions (o_height * o_width)
//    times the product of the leading dstdim - 3 axes of imshape;
//  - sshape[0] must equal the patch area (psize_y * psize_x) times the
//    channel count imshape[dstdim - 3].
69 CHECK_EQ(sshape[1], o_height * o_width * imshape.
ProdShape(0, dstdim - 3))
70 <<
"PackColToPatchExp: src.size(1) mismatch";
71 CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3])
72 <<
"PackColToPatchExp: src.size(0) mismatch";
/*!
 * \brief build a PackColToPatchXExp from a column-matrix expression, using a
 *   single stride value (pstride) for both the vertical and horizontal axes.
 *   Compilation fails with Error_Expression_Does_Not_Meet_Dimension_Req when
 *   the source expression does not meet the dimension requirement.
 * \tparam SrcExp source expression type
 * \tparam DType element type
 * \tparam dstdim rank of the destination tensor
 * \tparam etype expression type tag of the source
 */
88 template<
typename SrcExp,
typename DType,
int dstdim,
int etype>
89 inline PackColToPatchXExp<SrcExp, DType, dstdim>
94 ::Error_Expression_Does_Not_Meet_Dimension_Req();
// Runtime guard: the destination image must be at least one patch tall and
// one patch wide.
// NOTE(review): this compares against the undilated patch size only. If a
// dilation > 1 is forwarded to the expression, the span that must fit is
// pdilate * (psize - 1) + 1 — confirm.
95 CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
96 <<
"PackColToPatch:image shape smaller than patch size";
98 psize_y, psize_x, pstride, pstride,
/*!
 * \brief build a PackColToPatchXExp from a column-matrix expression with
 *   independent strides (pstride_y, pstride_x) and dilations
 *   (pdilate_y, pdilate_x) for the vertical and horizontal axes.
 *   Compilation fails with Error_Expression_Does_Not_Meet_Dimension_Req when
 *   the source expression does not meet the dimension requirement.
 * \tparam SrcExp source expression type
 * \tparam DType element type
 * \tparam dstdim rank of the destination tensor
 * \tparam etype expression type tag of the source
 */
104 template<
typename SrcExp,
typename DType,
int dstdim,
int etype>
105 inline PackColToPatchXExp<SrcExp, DType, dstdim>
111 ::Error_Expression_Does_Not_Meet_Dimension_Req();
// Runtime guard: the destination image must be at least one patch tall and
// one patch wide.
// NOTE(review): pdilate_y / pdilate_x are forwarded below, but this guard
// compares against the undilated psize. An image with
// psize <= size < pdilate * (psize - 1) + 1 passes the check, even though the
// output-size formula (size - (pdilate * (psize - 1) + 1)) / pstride + 1 then
// underflows. Consider checking the dilated span instead.
112 CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
113 <<
"PackColToPatch:image shape smaller than patch size";
115 psize_y, psize_x, pstride_y, pstride_x,
116 pdilate_y, pdilate_x);
/*!
 * \brief evaluation plan for PackColToPatchXExp.
 *   Eval(i, j) accumulates every src_ entry that maps back onto one
 *   destination element. Row index i encodes (n, c, y) with y fastest, and x
 *   is the column coordinate. The sum runs over the patch positions (py, px)
 *   whose dilated kernel covers (y, x): a col2im-style accumulation written
 *   as a per-output gather.
 * \tparam SrcExp source (column matrix) expression
 * \tparam DType element type
 * \tparam dstdim rank of the destination tensor
 */
122 template<
typename SrcExp,
typename DType,
int dstdim>
126 :src_(
MakePlan(e.src_)), psize_y_(e.psize_y_),
127 psize_x_(e.psize_x_), pstride_y_(e.pstride_y_), pstride_x_(e.pstride_x_),
128 i_channel_(e.shape_[dstdim - 3]), pdilate_y_(e.pdilate_y_), pdilate_x_(e.pdilate_x_),
129 i_height_(e.shape_[dstdim - 2]),
// o_height_ / o_width_ read the psize_*_ and pdilate_*_ members. That is only
// safe because those members are declared, and hence initialized, before
// o_height_ / o_width_. Keep the member declaration order in sync.
130 o_height_((e.shape_[dstdim - 2] - (pdilate_y_ * (psize_y_ - 1) + 1)) /
132 o_width_((e.shape_[dstdim - 1] - (pdilate_x_ * (psize_x_ - 1) + 1)) /
// Split the destination row index: i == (n * i_channel_ + c) * i_height_ + y.
138 const index_t y = i % i_height_;
139 const index_t idivh = i / i_height_;
140 const index_t c = idivh % i_channel_;
141 const index_t n = idivh / i_channel_;
// Number of input pixels spanned by one dilated patch along each axis.
144 const index_t psize_y_dilate = (pdilate_y_ * (psize_y_ - 1) + 1);
145 const index_t psize_x_dilate = (pdilate_x_ * (psize_x_ - 1) + 1);
148 y < psize_y_dilate ? y % pdilate_y_ : (y-psize_y_dilate + pstride_y_) / pstride_y_;
150 x < psize_x_dilate ? x % pdilate_x_ : (x-psize_x_dilate + pstride_x_) / pstride_x_;
151 const index_t py_max = min((y + pstride_y_) / pstride_y_, o_height_);
152 const index_t px_max = min((x + pstride_x_) / pstride_x_, o_width_);
153 DType res =
static_cast<DType
>(0);
// Every visited py satisfies y - psize_y_dilate < py * pstride_y_ <= y and
// py < o_height_; the same holds for px along the other axis. Each visited
// (py, px) contributes the src_ entry at kernel offset
// ((y - py * pstride_y_) / pdilate_y_, (x - px * pstride_x_) / pdilate_x_).
// NOTE(review): starting at y % pdilate_y_ (or at the else-branch bound) and
// stepping by pdilate_y_ enumerates exactly the contributing rows only when
// pstride_y_ == 1 or pdilate_y_ == 1. Counterexamples, with psize_y_ == 2:
//  - pstride_y_ == 2, pdilate_y_ == 2, y == 4, o_height_ >= 3: only py == 1
//    is visited; py == 2 (kernel row 0) is missed.
//  - pstride_y_ == 3, pdilate_y_ == 2, y == 4, o_height_ >= 2: py == 1 is
//    added although y - py * pstride_y_ == 1 is not a multiple of
//    pdilate_y_ (truncating division maps it to kernel row 0).
// The px loop has the same structure. Verify before combining dilation with
// stride > 1.
154 for (
index_t py = py_min; py < py_max; py += pdilate_y_) {
155 for (
index_t px = px_min; px < px_max; px += pdilate_x_) {
// src_ row: (c, kernel row, kernel column) flattened.
// src_ column: (n, py, px) flattened.
156 res += src_.Eval(((c * psize_y_ + (y - py*pstride_y_) / pdilate_y_) * psize_x_ +
157 (x - px * pstride_x_) / pdilate_x_),
158 (n * o_height_ + py) * o_width_ + px);
// Cached patch geometry and destination/output sizes. Declaration order
// determines initialization order, and o_height_ / o_width_ depend on the
// members declared before them.
166 const index_t psize_y_, psize_x_, pstride_y_, pstride_x_, i_channel_;
167 const index_t pdilate_y_, pdilate_x_;
168 const index_t i_height_, o_height_, o_width_;
172 #endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_