7 #ifndef MSHADOW_EXTENSION_PACK_COL2PATCH_H_ 8 #define MSHADOW_EXTENSION_PACK_COL2PATCH_H_ 10 #include "../extension.h" 21 template<
typename SrcExp,
typename DType,
int dstdim>
23 public MakeTensorExp<PackColToPatchXExp<SrcExp, DType, dstdim>,
24 SrcExp, dstdim, DType> {
42 :src_(src), psize_y_(psize_y), psize_x_(psize_x),
43 pstride_y_(pstride_y), pstride_x_(pstride_x),
44 pdilate_y_(pdilate_y), pdilate_x_(pdilate_x){
46 const index_t o_height = (imshape[dstdim - 2] -
47 (pdilate_y * (psize_y - 1)+ 1))/pstride_y + 1;
48 const index_t o_width = (imshape[dstdim - 1] -
49 (pdilate_x * (psize_x - 1) + 1)) / pstride_x + 1;
51 CHECK_EQ(sshape[1], o_height * o_width * imshape.
ProdShape(0, dstdim - 3))
52 <<
"PackColToPatchExp: src.size(1) mismatch";
53 CHECK_EQ(sshape[0], psize_y * psize_x * imshape[dstdim - 3])
54 <<
"PackColToPatchExp: src.size(0) mismatch";
70 template<
typename SrcExp,
typename DType,
int dstdim,
int etype>
76 ::Error_Expression_Does_Not_Meet_Dimension_Req();
77 CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
78 <<
"PackColToPatch:image shape smaller than patch size";
80 psize_y, psize_x, pstride, pstride,
86 template<
typename SrcExp,
typename DType,
int dstdim,
int etype>
93 ::Error_Expression_Does_Not_Meet_Dimension_Req();
94 CHECK(imshape[dstdim - 1] >= psize_x && imshape[dstdim - 2] >= psize_y)
95 <<
"PackColToPatch:image shape smaller than patch size";
97 psize_y, psize_x, pstride_y, pstride_x,
98 pdilate_y, pdilate_x);
104 template<
typename SrcExp,
typename DType,
int dstdim>
111 i_height_(e.
shape_[dstdim - 2]),
120 const index_t y = i % i_height_;
121 const index_t idivh = i / i_height_;
122 const index_t c = idivh % i_channel_;
123 const index_t n = idivh / i_channel_;
135 DType res =
static_cast<DType
>(0);
140 (n * o_height_ + py) * o_width_ + px);
150 const index_t i_height_, o_height_, o_width_;
154 #endif // MSHADOW_EXTENSION_PACK_COL2PATCH_H_ PackColToPatchXExp< SrcExp, DType, dstdim > pack_col2patch(const expr::Exp< SrcExp, DType, etype > &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride, index_t pdilate)
reverse operation of pack_col2patch, can be used to implement deconvolution
Definition: pack_col2patch.h:72
index_t pdilate_x_
Definition: pack_col2patch.h:36
Definition: expr_engine-inl.h:40
used to help static type check
Definition: expr_engine-inl.h:312
shape of a tensor
Definition: tensor.h:35
const SrcExp & src_
source operand
Definition: pack_col2patch.h:26
Definition: optional.h:241
index_t pstride_x_
Definition: pack_col2patch.h:33
index_t psize_y_
patch height
Definition: pack_col2patch.h:28
reverse operation of UnpackPatchToCol, used to backprop gradient back this is a version supporting mu...
Definition: pack_col2patch.h:22
index_t pstride_y_
patch stride
Definition: pack_col2patch.h:32
index_t psize_x_
patch height
Definition: pack_col2patch.h:30
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:204
int32_t index_t
type that will be used for index
Definition: base.h:291
PackColToPatchXExp(const SrcExp &src, Shape< dstdim > imshape, index_t psize_y, index_t psize_x, index_t pstride_y, index_t pstride_x, index_t pdilate_y, index_t pdilate_x)
constructor
Definition: pack_col2patch.h:38
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:139
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: pack_col2patch.h:118
index_t pdilate_y_
patch dilate
Definition: pack_col2patch.h:35
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:25
namespace for mshadow
Definition: base.h:282
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:29
Plan(const PackColToPatchXExp< SrcExp, DType, dstdim > &e)
Definition: pack_col2patch.h:107