mxnet
transpose.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_TRANSPOSE_H_
8 #define MSHADOW_EXTENSION_TRANSPOSE_H_
9 #include <algorithm>
10 #include "../extension.h"
11 namespace mshadow {
12 namespace expr {
24 template<typename SrcExp, typename DType, int dimsrc>
26  public MakeTensorExp<TransposeExExp<SrcExp, DType, dimsrc>,
27  SrcExp, dimsrc, DType> {
29  const SrcExp &src_;
31  Shape<dimsrc> dst_in_src_stride_; // Holds the corresponding stride of the dst axes in src
34  explicit TransposeExExp(const SrcExp &src, Shape<dimsrc> axes) : src_(src), axes_(axes) {
36  src_stride_ = src_shape[dimsrc - 1];
37  Shape<dimsrc> src_stride;
38  src_stride[dimsrc-1] = 1;
39  for (int i = dimsrc-2; i >= 0; --i) src_stride[i] = src_shape[i+1]*src_stride[i+1];
40  for (int i = 0; i < dimsrc; ++i) {
41  dst_in_src_stride_[i] = src_stride[axes[i]];
42  this->shape_[i] = src_shape[axes[i]];
43  }
44  }
45 };
56 template<typename SrcExp, typename DType, int etype>
60 }
61 
62 template<typename SrcExp, typename DType, int dimsrc>
63 struct Plan<TransposeExExp<SrcExp, DType, dimsrc>, DType> {
64  public:
66  : src_(MakePlan(e.src_)),
69  dst_shape_(e.shape_) {}
70  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
71  index_t idx = j * dst_in_src_stride_[dimsrc - 1];
72  #pragma unroll
73  for (int k = dimsrc-2; k >= 0; --k) {
74  idx += (i % dst_shape_[k]) * dst_in_src_stride_[k];
75  i /= dst_shape_[k];
76  }
77  return src_.Eval(idx/src_stride_, idx%src_stride_);
78  }
79 
80  private:
82  const index_t src_stride_;
83  const Shape<dimsrc> dst_in_src_stride_, dst_shape_;
84 };
85 
96 template<typename SrcExp, typename DType, int dimsrc, int etype>
98  public Exp<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType, etype> {
100  const SrcExp &src_indices_; // Expression of the source indices
101  Shape<dimsrc> src_shape_; // Holds the corresponding stride of the source axes in dst
102  const Shape<dimsrc> axes_; // The transpose axes
103  Shape<dimsrc> src_in_dst_stride_; // Holds the corresponding stride of the source axes in dst
105  explicit TransposeIndicesExp(const SrcExp &src_indices,
106  Shape<dimsrc> src_shape,
107  Shape<dimsrc> axes) : src_indices_(src_indices),
108  src_shape_(src_shape), axes_(axes) {
109  Shape<dimsrc> dst_shape_;
110  Shape<dimsrc> dst_stride_;
111  bool axes_checking_flag[dimsrc] = { 0 };
112  for (int i = 0; i < dimsrc; ++i) {
113  CHECK_LT(static_cast<int>(axes[i]), dimsrc)
114  << "Invalid axes input! All elements of axes must be between 0 and " << dimsrc
115  << ", find axes=" << axes;
116  dst_shape_[i] = src_shape[axes[i]];
117  axes_checking_flag[axes[i]] = true;
118  }
119  // check if the input axes is valid
120  for (int i = 0; i < dimsrc; ++i) {
121  CHECK_EQ(axes_checking_flag[i], true)
122  << "Invalid axes input! All elements of axes must be between 0 and " << dimsrc
123  << ", find axes=" << axes;
124  }
125  dst_stride_[dimsrc - 1] = 1;
126  for (int i = dimsrc - 2; i >= 0; --i) dst_stride_[i] = dst_shape_[i+1] * dst_stride_[i+1];
127  for (int i = 0; i < dimsrc; ++i) {
128  src_in_dst_stride_[axes[i]] = dst_stride_[i];
129  }
130  }
131 };
132 
143 template<typename SrcExp, typename DType, int dimsrc, int etype>
146  Shape<dimsrc> src_shape,
147  Shape<dimsrc> axes) {
148  return TransposeIndicesExp<SrcExp, DType, dimsrc, etype>(src_indices.self(), src_shape, axes);
149 }
150 
151 template<typename SrcExp, typename DType, int dimsrc, int etype>
152 struct Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType> {
153  public:
155  : src_indices_(MakePlan(e.src_indices_)),
156  src_in_dst_stride_(e.src_in_dst_stride_),
157  src_shape_(e.src_shape_) {}
158  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
159  index_t src_idx = static_cast<index_t>(src_indices_.Eval(i, j));
160  index_t dst_idx = 0;
161  #pragma unroll
162  for (int k = dimsrc - 1; k >= 0; --k) {
163  dst_idx += (src_idx % src_shape_[k]) * src_in_dst_stride_[k];
164  src_idx /= src_shape_[k];
165  }
166  return static_cast<DType>(dst_idx);
167  }
168 
169  private:
170  Plan<SrcExp, DType> src_indices_;
171  const Shape<dimsrc> src_in_dst_stride_, src_shape_;
172 };
173 
174 //----------------------
175 // Execution plan
176 //----------------------
178 template<typename SrcExp, typename DType, int dimsrc, int etype>
181  return Plan<TransposeIndicesExp<SrcExp, DType, dimsrc, etype>, DType>(e);
182 }
183 
184 template<int dim, typename SrcExp, typename DType, int dimsrc, int etype>
185 struct ShapeCheck<dim, TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > {
186  inline static Shape<dim>
189  return s;
190  }
191 };
192 
193 template<typename SrcExp, typename DType, int dimsrc, int etype>
194 struct ExpInfo<TransposeIndicesExp<SrcExp, DType, dimsrc, etype> > {
195  static const int kDim = ExpInfo<SrcExp>::kDim;
196  static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
197 };
198 } // namespace expr
199 } // namespace mshadow
200 #endif // MSHADOW_EXTENSION_TRANSPOSE_H_
TransposeExExp(const SrcExp &src, Shape< dimsrc > axes)
constructor
Definition: transpose.h:34
Definition: expr_engine-inl.h:40
static Shape< dim > Check(const TransposeIndicesExp< SrcExp, DType, dimsrc, etype > &t)
Definition: transpose.h:187
Plan(const TransposeExExp< SrcExp, DType, dimsrc > &e)
Definition: transpose.h:65
const Shape< dimsrc > axes_
Definition: transpose.h:30
transpose (permute) the axes of a tensor. input: Tensor<Device,dim>: ishape; output: Tensor<Device,dim>: oshape, where oshape[i] = ishape[axes[i]]
Definition: transpose.h:25
index_t src_stride_
Definition: transpose.h:32
Shape< dimsrc > dst_in_src_stride_
Definition: transpose.h:31
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:204
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
const SrcExp & src_
source expression
Definition: transpose.h:29
int32_t index_t
type that will be used for index
Definition: base.h:291
Shape< dimsrc > src_in_dst_stride_
Definition: transpose.h:103
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: transpose.h:70
runtime shape checking template get the shape of an expression, report error if shape mismatch ...
Definition: expr_engine-inl.h:346
Plan(const TransposeIndicesExp< SrcExp, DType, dimsrc, etype > &e)
Definition: transpose.h:154
TransposeIndicesExp< SrcExp, DType, dimsrc, etype > transpose_indices(const Exp< SrcExp, DType, etype > &src_indices, Shape< dimsrc > src_shape, Shape< dimsrc > axes)
an expression that maps contiguous indices of the source tensor to the corresponding indices of the transposed tensor
Definition: transpose.h:145
transform contiguous (flattened) indices of the source tensor into indices of the transposed tensor. input: Tensor<Device, k>: ishape; output: Tensor<Device, k>: oshape = ishape
Definition: transpose.h:97
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
TransposeIndicesExp(const SrcExp &src_indices, Shape< dimsrc > src_shape, Shape< dimsrc > axes)
constructor
Definition: transpose.h:105
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:25
const Shape< dimsrc > axes_
Definition: transpose.h:102
namespace for mshadow
Definition: base.h:282
Shape< dimsrc > src_shape_
Definition: transpose.h:101
TransposeExExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > transpose(const Exp< SrcExp, DType, etype > &src, Shape< ExpInfo< SrcExp >::kDim > axes)
an expression that permutes the axes of a tensor according to axes
Definition: transpose.h:58
const SrcExp & src_indices_
source expression
Definition: transpose.h:100
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:29
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: transpose.h:158