mxnet
expr_engine-inl.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXPR_ENGINE_INL_H_
8 #define MSHADOW_EXPR_ENGINE_INL_H_
9 #include <utility>
10 #include <algorithm>
11 #include "./logging.h"
12 #include "./expression.h"
13 #include "./tensor.h"
14 
15 namespace mshadow {
16 namespace expr {
24 template<typename SubType, typename SrcExp, int dim, typename DType>
26  : public Exp<MakeTensorExp<SubType, SrcExp, dim, DType>,
27  DType, type::kChainer> {
31  inline const SubType& real_self(void) const{
32  return *static_cast<const SubType*>(this);
33  }
34 };
35 //----------------------------------------------------------------------
36 // This part of code gives plan that can be used to carry out execution
37 //---------------------------------------------------------------------
38 // Declarations of plans
39 template<typename ExpType, typename DType>
40 class Plan {
41  public:
46  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const;
47 };
48 // tensor plan
49 template <typename Device, int dim, typename DType>
50 class Plan<Tensor<Device, dim, DType>, DType> {
51  public:
52  explicit Plan(const Tensor<Device, dim, DType> &t)
53  : dptr_(t.dptr_), stride_(t.stride_) {}
54  // for RValue, the return type should be reference
56  return dptr_[y * stride_ + x];
57  }
58  // const evaluation
59  MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const {
60  return dptr_[y * stride_ + x];
61  }
62 
63  private:
64  DType *dptr_;
65  index_t stride_;
66 };
67 // special evaluation case for 1d tensor, no stride
68 template <typename Device, typename DType>
69 class Plan<Tensor<Device, 1, DType>, DType> {
70  public:
71  explicit Plan(const Tensor<Device, 1, DType> &t) : dptr_(t.dptr_) {}
73  return dptr_[x];
74  }
75  MSHADOW_XINLINE const DType &Eval(index_t y, index_t x) const {
76  return dptr_[x];
77  }
78 
79  private:
80  DType *dptr_;
81 };
82 // scalar
83 template<typename DType>
84 class Plan<ScalarExp<DType>, DType> {
85  public:
86  explicit Plan(DType scalar) : scalar_(scalar) {}
87  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
88  return scalar_;
89  }
90 
91  private:
92  DType scalar_;
93 };
94 // unary expression
95 template<typename DstDType, typename SrcDType,
96  typename EType, int etype>
97 class Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType> {
98  public:
99  explicit Plan(const Plan<EType, SrcDType> &src) : src_(src) {}
100  MSHADOW_XINLINE DstDType Eval(index_t y, index_t x) const {
101  return DstDType(src_.Eval(y, x)); // NOLINT(*)
102  }
103 
104  private:
106 };
107 
108 // ternary expression
109 template<typename OP, typename TA, typename TB, typename TC, int etype, typename DType>
110 class Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType> {
111  public:
112  explicit Plan(const Plan<TA, DType> &item1, const Plan<TB, DType> &item2,
113  const Plan<TC, DType> &item3)
114  : item1_(item1), item2_(item2), item3_(item3) {}
115  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
116  return OP::Map(item1_.Eval(y, x), item2_.Eval(y, x), item3_.Eval(y, x));
117  }
118 
119  private:
120  Plan<TA, DType> item1_;
121  Plan<TB, DType> item2_;
122  Plan<TC, DType> item3_;
123 };
124 // binary expression
125 template<typename OP, typename TA, typename TB, int etype, typename DType>
126 class Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType> {
127  public:
128  explicit Plan(const Plan<TA, DType> &lhs, const Plan<TB, DType> &rhs)
129  : lhs_(lhs), rhs_(rhs) {}
130  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
131  return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x));
132  }
133 
134  private:
135  Plan<TA, DType> lhs_;
136  Plan<TB, DType> rhs_;
137 };
138 // unary expression
139 template<typename OP, typename TA, int etype, typename DType>
140 class Plan<UnaryMapExp<OP, TA, DType, etype>, DType> {
141  public:
142  explicit Plan(const Plan<TA, DType> &src) : src_(src) {}
143  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
144  return OP::Map(src_.Eval(y, x));
145  }
146 
147  private:
148  Plan<TA, DType> src_;
149 };
150 // remaps map tensor expression to subtype's plan
151 template<typename SubType, typename SrcExp, int dim, typename DType>
152 struct Plan<MakeTensorExp<SubType, SrcExp, dim, DType>, DType> {
153  public:
154  Plan(const Plan<SubType, DType> &src) : src_(src) {}
155  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
156  return src_.Eval(y, x);
157  }
158 
159  private:
161 };
162 // tranpsoe
163 template<typename EType, typename DType>
164 class Plan<TransposeExp<EType, DType>, DType> {
165  public:
166  explicit Plan(const Plan<EType, DType> &src) : src_(src) {}
167  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
168  return src_.Eval(x, y);
169  }
170 
171  private:
172  Plan<EType, DType> src_;
173 };
174 //----------------------------------------------------------------------
175 // Mappings from expression to plans
176 //---------------------------------------------------------------------
177 template<typename OP, typename TA, typename TB, typename DType, int etype>
180 
181 template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
184 
185 template<typename DType>
187  return Plan<ScalarExp<DType>, DType>(e.scalar_);
188 }
189 
190 template<typename DstDType, typename SrcDType, typename EType, int etype>
193  return Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType>(MakePlan(e.exp));
194 }
195 
196 template<typename T, typename DType>
198  return Plan<T, DType>(e.self());
199 }
200 
201 template<typename T, typename DType>
202 inline Plan<TransposeExp<T, DType>, DType>
204  return Plan<TransposeExp<T, DType>, DType>(MakePlan(e.exp));
205 }
206 
207 template<typename T, typename SrcExp, int dim, typename DType>
208 inline Plan<T, DType>
210  return Plan<T, DType>(e.real_self());
211 }
212 
213 template<typename OP, typename TA, typename DType, int etype>
216  return Plan<UnaryMapExp<OP, TA, DType, etype>, DType>(MakePlan(e.src_));
217 }
218 
219 template<typename OP, typename TA, typename TB, typename DType, int etype>
220 inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType>
222  return Plan<BinaryMapExp<OP, TA, TB, DType, etype>,
223  DType>(MakePlan(e.lhs_), MakePlan(e.rhs_));
224 }
225 
226 // Ternary
227 template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
228 inline Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType>
230  return Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
231  DType>(MakePlan(e.item1_), MakePlan(e.item2_), MakePlan(e.item3_));
232 }
233 //----------------------------------------------------------------
234 // Static Type inference and Type Checking
235 //----------------------------------------------------------------
243 template<typename E>
244 struct ExpInfo {
245  static const int kDim = -1;
246  static const int kDevMask = 0;
247 };
248 template<typename DType>
249 struct ExpInfo< ScalarExp<DType> > {
250  static const int kDim = 0;
251  static const int kDevMask = 0xffff;
252 };
253 template<typename E, typename DType>
254 struct ExpInfo<TransposeExp<E, DType> > {
255  static const int kDim = ExpInfo<E>::kDim;
256  static const int kDevMask = ExpInfo<E>::kDevMask;
257 };
258 template<typename DstDType, typename SrcDType, typename EType, int etype>
259 struct ExpInfo<TypecastExp<DstDType, SrcDType, EType, etype> > {
260  static const int kDim = ExpInfo<EType>::kDim;
261  static const int kDevMask = ExpInfo<EType>::kDevMask;
262 };
263 template<typename Device, int dim, typename DType>
264 struct ExpInfo<Tensor<Device, dim, DType> > {
265  static const int kDim = dim;
266  static const int kDevMask = Device::kDevMask;
267 };
268 template<typename T, typename SrcExp, int dim, typename DType>
269 struct ExpInfo<MakeTensorExp<T, SrcExp, dim, DType> > {
270  static const int kDimSrc = ExpInfo<SrcExp>::kDim;
271  static const int kDim = kDimSrc >= 0 ? dim : -1;
272  static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
273 };
274 template<typename OP, typename TA, typename DType, int etype>
275 struct ExpInfo<UnaryMapExp<OP, TA, DType, etype> > {
276  static const int kDim = ExpInfo<TA>::kDim;
277  static const int kDevMask = ExpInfo<TA>::kDevMask;
278 };
279 template<typename OP, typename TA, typename TB, typename DType, int etype>
280 struct ExpInfo<BinaryMapExp<OP, TA, TB, DType, etype> > {
281  static const int kDimLhs = ExpInfo<TA>::kDim;
282  static const int kDimRhs = ExpInfo<TB>::kDim;
283  static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\
284  (kDimLhs == 0 ?\
285  kDimRhs :\
286  ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
287  static const int kDevMask = ExpInfo<TA>::kDevMask & ExpInfo<TB>::kDevMask;
288 };
289 template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
290 struct ExpInfo<TernaryMapExp<OP, TA, TB, TC, DType, etype> > {
291  static const int kDimItem1 = ExpInfo<TA>::kDim;
292  static const int kDimItem2 = ExpInfo<TB>::kDim;
293  static const int kDimItem3 = ExpInfo<TC>::kDim;
294  static const int kDim = kDimItem1;
296 };
297 
299 template<typename Device, int dim, typename DType, typename E>
300 struct TypeCheck {
302  static const int kExpDim = ExpInfo<E>::kDim;
304  static const bool kDevPass = (ExpInfo<E>::kDevMask & Device::kDevMask) != 0;
306  static const bool kMapPass = (kExpDim == 0 || kExpDim == dim) && kDevPass;
308  static const bool kRedPass = (kExpDim > dim) && kDevPass;
309 };
/*!
 * \brief used to help static type check: calling one of the Error_* hooks
 *  only compiles for TypeCheckPass<true>, so a failed check surfaces as a
 *  compile error naming the hook
 */
template<bool kPass>
struct TypeCheckPass;
// TODO: replace with static_assert once C++11 can be required
template<>
struct TypeCheckPass<false> {};
template<>
struct TypeCheckPass<true> {
  inline static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(void) {}
  inline static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(void) {}
  inline static void Error_Expression_Does_Not_Meet_Dimension_Req(void) {}
};
322 
323 //----------------------------------------------------------------
324 // Runtime Stream Getting
325 //----------------------------------------------------------------
326 template<typename Device, typename E>
327 struct StreamInfo {
328  inline static Stream<Device> *Get(const E &t);
329 };
330 template<int dim, typename Device, typename DType>
331 struct StreamInfo<Device, Tensor<Device, dim, DType> > {
332  inline static Stream<Device> *Get(const Tensor<Device, dim, DType> &t) {
333  return t.stream_;
334  }
335 };
336 //----------------------------------------------------------------
337 // Runtime Shape Checking
338 //----------------------------------------------------------------
345 template<int dim, typename E>
346 struct ShapeCheck {
347  inline static Shape<dim> Check(const E &t);
348 };
349 template<int dim, typename DType>
350 struct ShapeCheck<dim, ScalarExp<DType> > {
351  inline static Shape<dim> Check(const ScalarExp<DType> &exp) {
352  // use lowest dimension to mark scalar exp
353  Shape<dim> shape;
354  for (int i = 0; i < dim; ++i) {
355  shape[i] = 0;
356  }
357  return shape;
358  }
359 };
360 template<int dim, typename DstDType, typename SrcDType, typename EType, int etype>
361 struct ShapeCheck<dim, TypecastExp<DstDType, SrcDType, EType, etype> > {
362  inline static Shape<dim>
365  }
366 };
367 template<int dim, typename E, typename DType>
368 struct ShapeCheck<dim, TransposeExp<E, DType> > {
369  inline static Shape<dim> Check(const TransposeExp<E, DType> &e) {
370  // swap the lowest two dimensions
372  std::swap(s[0], s[1]);
373  return s;
374  }
375 };
376 template<int dim, typename Device, typename DType>
377 struct ShapeCheck<dim, Tensor<Device, dim, DType> > {
378  inline static Shape<dim> Check(const Tensor<Device, dim, DType> &t) {
379  return t.shape_;
380  }
381 };
382 template<int dim, typename SrcExp, typename T, typename DType>
383 struct ShapeCheck<dim, MakeTensorExp<T, SrcExp, dim, DType> > {
384  inline static Shape<dim>
386  return t.shape_;
387  }
388 };
389 template<int dim, typename OP, typename TA, typename DType, int etype>
390 struct ShapeCheck<dim, UnaryMapExp<OP, TA, DType, etype> > {
393  return s;
394  }
395 };
396 
397 template<int dim, typename OP, typename TA, typename TB,
398  typename DType, int etype>
399 struct ShapeCheck<dim, BinaryMapExp<OP, TA, TB, DType, etype> > {
400  inline static Shape<dim>
404  if (shape1[0] == 0) return shape2;
405  if (shape2[0] == 0) return shape1;
406  CHECK_EQ(shape1, shape2) << "BinaryMapExp: Shapes of operands are not the same, " <<
407  "Shape1=" << shape1 << ", Shape2=" << shape2;
408  return shape1;
409  }
410 };
411 
412 template<int dim, typename OP, typename TA, typename TB, typename TC,
413  typename DType, int etype>
414 struct ShapeCheck<dim, TernaryMapExp<OP, TA, TB, TC, DType, etype> > {
415  inline static Shape<dim>
420  bool same = (shape1 == shape2) && (shape2 == shape3);
421  CHECK(same) << "TernaryMapExp: Shapes of operands are not the same, " <<
422  "Shape1=" << shape1 << ", Shape2=" << shape2 << ", Shape3=" << shape3;
423 
424  return shape1;
425  }
426 };
427 } // namespace expr
428 
429 } // namespace mshadow
430 // include definition of dot engine
431 #include "./dot_engine-inl.h"
432 
433 namespace mshadow {
434 namespace expr {
/*!
 * \brief engine that evaluates complex expressions; only declared in
 *  general, each supported complex expression provides a specialization
 */
template<typename SV, typename RV, typename E, typename DType>
struct ExpComplexEngine {
  inline static void Eval(RV *dst, const E &exp);
};
441 template<typename SV, typename RV, typename DType>
442 struct ExpEngine {
443  template<typename E>
444  inline static void Eval(RV *dst,
445  const Exp<E, DType, type::kMapper> &exp) {
446  MapExp<SV>(dst, exp);
447  }
448  template<typename E>
449  inline static void Eval(RV *dst,
450  const Exp<E, DType, type::kChainer> &exp) {
451  MapExp<SV>(dst, exp);
452  }
453  template<typename E>
454  inline static void Eval(RV *dst,
455  const Exp<E, DType, type::kRValue> &exp) {
456  MapExp<SV>(dst, exp);
457  }
458  template<typename E>
459  inline static void Eval(RV *dst,
460  const Exp<E, DType, type::kComplex> &exp) {
461  ExpComplexEngine<SV, RV, E, DType>::Eval(dst->ptrself(), exp.self());
462  }
463 };
464 template<typename SV, typename Device, int dim, int ldim,
465  int rdim, bool ltrans, bool rtrans, typename DType>
467  Tensor<Device, dim, DType>,
468  DotExp<Tensor<Device, ldim, DType>,
469  Tensor<Device, rdim, DType>,
470  ltrans, rtrans, DType>,
471  DType> {
472  inline static void Eval(Tensor<Device, dim, DType> *dst,
475  ltrans, rtrans, DType> &exp) {
476  DotEngine<SV, Device, dim, ldim, rdim,
477  ltrans, rtrans, DType>::Eval(dst, exp.lhs_, exp.rhs_, exp.scale_);
478  }
479 };
480 } // namespace expr
481 } // namespace mshadow
482 #endif // MSHADOW_EXPR_ENGINE_INL_H_
Plan(const Plan< TA, DType > &src)
Definition: expr_engine-inl.h:142
static Shape< dim > Check(const UnaryMapExp< OP, TA, DType, etype > &t)
Definition: expr_engine-inl.h:391
static void Eval(RV *dst, const Exp< E, DType, type::kRValue > &exp)
Definition: expr_engine-inl.h:454
static Shape< dim > Check(const MakeTensorExp< T, SrcExp, dim, DType > &t)
Definition: expr_engine-inl.h:385
static Shape< dim > Check(const Tensor< Device, dim, DType > &t)
Definition: expr_engine-inl.h:378
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:85
static Shape< dim > Check(const BinaryMapExp< OP, TA, TB, DType, etype > &t)
Definition: expr_engine-inl.h:401
Definition: expr_engine-inl.h:40
Plan(const Tensor< Device, 1, DType > &t)
Definition: expr_engine-inl.h:71
used to help static type check
Definition: expr_engine-inl.h:312
template to do type check
Definition: expr_engine-inl.h:300
const TB & rhs_
right operand
Definition: expression.h:321
Plan(const Tensor< Device, dim, DType > &t)
Definition: expr_engine-inl.h:52
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:418
ternary map expression
Definition: expression.h:261
static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(void)
Definition: expr_engine-inl.h:318
binary map expression lhs [op] rhs
Definition: expression.h:316
Plan(DType scalar)
Definition: expr_engine-inl.h:86
static void Eval(Tensor< Device, dim, DType > *dst, const DotExp< Tensor< Device, ldim, DType >, Tensor< Device, rdim, DType >, ltrans, rtrans, DType > &exp)
Definition: expr_engine-inl.h:472
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:155
static void Eval(RV *dst, const E &exp)
base class of all rvalues
Definition: expression.h:130
Definition: dot_engine-inl.h:52
DType scalar_
scalar value
Definition: expression.h:79
MSHADOW_XINLINE DType & REval(index_t y, index_t x)
Definition: expr_engine-inl.h:55
const EType & exp
expression to be transposed
Definition: expression.h:116
static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(void)
Definition: expr_engine-inl.h:319
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:115
MSHADOW_XINLINE const DType & Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:59
static Shape< dim > Check(const E &t)
static void Eval(RV *dst, const Exp< E, DType, type::kChainer > &exp)
Definition: expr_engine-inl.h:449
#define MSHADOW_XINLINE
Definition: base.h:204
const TB & item2_
second operand
Definition: expression.h:266
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
Definition: expr_engine-inl.h:327
definitions of abstract expressions and expressions template
static void Eval(RV *dst, const Exp< E, DType, type::kComplex > &exp)
Definition: expr_engine-inl.h:459
static Shape< dim > Check(const TernaryMapExp< OP, TA, TB, TC, DType, etype > &t)
Definition: expr_engine-inl.h:416
static void Error_Expression_Does_Not_Meet_Dimension_Req(void)
Definition: expr_engine-inl.h:320
MSHADOW_XINLINE const DType & Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:75
int32_t index_t
type that will be used for index
Definition: base.h:291
Plan(const Plan< EType, DType > &src)
Definition: expr_engine-inl.h:166
static Shape< dim > Check(const TypecastExp< DstDType, SrcDType, EType, etype > &exp)
Definition: expr_engine-inl.h:363
Plan(const Plan< TA, DType > &item1, const Plan< TB, DType > &item2, const Plan< TC, DType > &item3)
Definition: expr_engine-inl.h:112
const TA & item1_
first operand
Definition: expression.h:264
typecast expression, cast the type of elements
Definition: expression.h:96
static Shape< dim > Check(const ScalarExp< DType > &exp)
Definition: expr_engine-inl.h:351
runtime shape checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:346
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:143
represents a transpose expression of a container
Definition: expression.h:113
engine that evaluates complex expressions
Definition: expr_engine-inl.h:437
const TA & src_
source expression
Definition: expression.h:389
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:167
unary map expression op(src)
Definition: expression.h:386
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:206
Plan(const Plan< EType, SrcDType > &src)
Definition: expr_engine-inl.h:99
scalar expression
Definition: expression.h:77
Plan(const Plan< SubType, DType > &src)
Definition: expr_engine-inl.h:154
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:87
const EType & exp
expression to be typecasted
Definition: expression.h:100
const Container & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
const TC & item3_
third operand
Definition: expression.h:268
MSHADOW_XINLINE DstDType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:100
const SubType & real_self(void) const
true self of subtype
Definition: expr_engine-inl.h:31
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:130
Definition: tensor.h:550
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:25
const TA & lhs_
left operand
Definition: expression.h:319
namespace for mshadow
Definition: base.h:282
static Shape< dim > Check(const TransposeExp< E, DType > &e)
Definition: expr_engine-inl.h:369
the engine that dispatches simple operations
Definition: expr_engine-inl.h:442
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:29
general tensor
Definition: tensor.h:402
static void Eval(RV *dst, const Exp< E, DType, type::kMapper > &exp)
Definition: expr_engine-inl.h:444
Stream< Device > * stream_
stream where the computation lies; a stream is a device-dependent concept where each computation ...
Definition: tensor.h:428
definitions of how Matrix Multiplications can be evaluated
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: expr_engine-inl.h:128
MSHADOW_XINLINE DType & REval(index_t y, index_t x)
Definition: expr_engine-inl.h:72
static Stream< Device > * Get(const Tensor< Device, dim, DType > &t)
Definition: expr_engine-inl.h:332
computation stream structure, used for asynchronous computations
Definition: tensor.h:365