mxnet
expression.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXPRESSION_H_
8 #define MSHADOW_EXPRESSION_H_
9 #include "./base.h"
10 
11 namespace mshadow {
18 namespace expr {
namespace type {
// Expression types are defined as nested bitmasks:
//   kRValue (0) subset-of kMapper (1) subset-of kChainer (3) subset-of kComplex (7)
// so the subtype relationship is kRValue < kMapper < kChainer < kComplex, and
// OR-ing the tags of several operands yields the most general of them.
/*! \brief this expression directly corresponds to a data class and can be used to assign data */
const int kRValue = 0;
/*! \brief expression contains element-wise tensor operations, mapping an expression to one of the same shape */
const int kMapper = 1;
/*!
 * \brief expression that can be chained with other expressions.
 *  Usually it has a function Eval(i, j) defined, which pulls the result (i, j)
 *  from the input expression and outputs the result at a certain position.
 */
const int kChainer = 3;
/*! \brief other cases, e.g. dot product */
const int kComplex = 7;
}  // namespace type
/*!
 * \brief forward declaration of the engine used by RValueExp to evaluate an
 *  expression into a container.
 * \tparam SV saver policy (e.g. sv::saveto, sv::plusto)
 * \tparam RV destination rvalue/container type
 * \tparam DT element data type
 *
 *  Its specializations are invoked as ExpEngine<SV, RV, DT>::Eval(dst, exp),
 *  i.e. they provide something of the form
 *    template<typename EType>
 *    inline static void Eval(RV *dst, const EType &exp);
 */
template<typename SV, typename RV, typename DT>
struct ExpEngine;
/*!
 * \brief base class of every expression (curiously recurring template pattern).
 * \tparam SubType the concrete expression type that derives from this class
 * \tparam DType data type of the expression
 * \tparam exp_type expression type tag (one of the type:: bitmask constants)
 */
template<typename SubType, typename DType, int exp_type>
struct Exp {
 public:
  /*! \return a const reference to the concrete (derived) expression */
  inline const SubType& self(void) const {
    const SubType *derived = static_cast<const SubType*>(this);
    return *derived;
  }
  /*! \return a mutable pointer to the concrete (derived) expression */
  inline SubType* ptrself(void) {
    SubType *derived = static_cast<SubType*>(this);
    return derived;
  }
};
76 template<typename DType>
77 struct ScalarExp: public Exp<ScalarExp<DType>, DType, type::kMapper> {
79  DType scalar_;
81  ScalarExp(DType scalar) : scalar_(scalar) {} // NOLINT(*)
82 };
84 template<typename DType>
85 inline ScalarExp<DType> scalar(DType s) {
86  return ScalarExp<DType>(s);
87 }
95 template<typename DstDType, typename SrcDType, typename EType, int etype>
96 struct TypecastExp:
97  public Exp<TypecastExp<DstDType, SrcDType, EType, etype>,
98  DstDType, etype> {
100  const EType &exp;
102  explicit TypecastExp(const EType &e) : exp(e) {}
103 };
// tcast: create a typecast expression that casts the elements of exp to DstDType.
// NOTE(review): lines 107-109 (return type, signature and body) are elided in
// this listing. Per the generated index, the full declaration is:
//   TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)>
//   tcast(const Exp<EType, SrcDType, etype> &exp)
105 template<typename DstDType, typename SrcDType,
106  typename EType, int etype>
110 }
112 template<typename EType, typename DType>
113 struct TransposeExp: public Exp<TransposeExp<EType, DType>,
114  DType, type::kChainer> {
116  const EType &exp;
118  explicit TransposeExp(const EType &e) : exp(e) {}
120  inline const EType &T(void) const {
121  return exp;
122  }
123 };
129 template<typename Container, typename DType>
130 class RValueExp: public Exp<Container, DType, type::kRValue> {
131  public:
136  inline const TransposeExp<Container, DType> T(void) const {
137  return TransposeExp<Container, DType>(this->self());
138  }
140  inline Container &operator+=(DType s) {
141  ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
142  return *(this->ptrself());
143  }
145  inline Container &operator-=(DType s) {
146  ExpEngine<sv::minusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
147  return *(this->ptrself());
148  }
150  inline Container &operator*=(DType s) {
151  ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
152  return *(this->ptrself());
153  }
155  inline Container &operator/=(DType s) {
156  ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
157  return *(this->ptrself());
158  }
160  inline Container &__assign(DType s) {
161  ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
162  return *(this->ptrself());
163  }
165  template<typename E, int etype>
166  inline Container &__assign(const Exp<E, DType, etype> &exp) {
167  ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), exp.self());
168  return *(this->ptrself());
169  }
171  inline Container &__assign(const Exp<Container, DType, type::kRValue> &exp);
173  template<typename E, int etype>
174  inline Container &operator+=(const Exp<E, DType, etype> &exp) {
175  ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), exp.self());
176  return *(this->ptrself());
177  }
179  template<typename E, int etype>
180  inline Container &operator-=(const Exp<E, DType, etype> &exp) {
182  return *(this->ptrself());
183  }
185  template<typename E, int etype>
186  inline Container &operator*=(const Exp<E, DType, etype> &exp) {
187  ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), exp.self());
188  return *(this->ptrself());
189  }
191  template<typename E, int etype>
192  inline Container &operator/=(const Exp<E, DType, etype> &exp) {
193  ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), exp.self());
194  return *(this->ptrself());
195  }
196 };
205 template<typename TA, typename TB, bool ltrans, bool rtrans, typename DType>
206 struct DotExp: public Exp<DotExp<TA, TB, ltrans, rtrans, DType>,
207  DType, type::kComplex> {
209  const TA &lhs_;
211  const TB &rhs_;
213  DType scale_;
215  explicit DotExp(const TA &lhs, const TB &rhs, DType scale)
216  : lhs_(lhs), rhs_(rhs), scale_(scale) {}
217 };
218 // definition of dot expression
// dot(lhs, rhs): neither operand transposed, scale 1. Per the generated index:
//   DotExp<TA, TB, false, false, DType>
//   dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs)
// NOTE(review): signature lines 221-222 are elided in this listing.
220 template<typename TA, typename TB, typename DType>
223  return DotExp<TA, TB, false, false, DType>(lhs.self(), rhs.self(), DType(1.0f));
224 }
// Left operand transposed (ltrans = true, unwraps lhs.exp), scale 1.
// Presumably dot(lhs.T(), rhs) with lhs a TransposeExp; signature lines 227-228 elided — confirm.
226 template<typename TA, typename TB, typename DType>
229  return DotExp<TA, TB, true, false, DType>(lhs.exp, rhs.self(), DType(1.0f));
230 }
// Right operand transposed (rtrans = true, unwraps rhs.exp), scale 1.
// Presumably dot(lhs, rhs.T()); signature lines 233-234 elided — confirm.
232 template<typename TA, typename TB, typename DType>
235  return DotExp<TA, TB, false, true, DType>(lhs.self(), rhs.exp, DType(1.0f));
236 }
// Both operands transposed (unwraps lhs.exp and rhs.exp), scale 1.
// Presumably dot(lhs.T(), rhs.T()); signature lines 239-240 elided — confirm.
238 template<typename TA, typename TB, typename DType>
241  return DotExp<TA, TB, true, true, DType>(lhs.exp, rhs.exp, DType(1.0f));
242 }
// batch_dot: transposition chosen by template arguments, scale 1. Per the generated index:
//   DotExp<TA, TB, transpose_left, transpose_right, DType>
//   batch_dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs)
// NOTE(review): lines 245-247 are elided in this listing.
244 template<bool transpose_left, bool transpose_right, typename TA, typename TB, typename DType>
248  lhs.self(), rhs.self(), DType(1.0f));
249 }
250 //---------------
251 // TernaryMapExp
252 // --------------
260 template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
261 struct TernaryMapExp: public Exp<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
262  DType, etype> {
264  const TA &item1_;
266  const TB &item2_;
268  const TC &item3_;
270  explicit TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
271  :item1_(item1), item2_(item2), item3_(item3) {}
272 };
273 
// MakeExp<OP>(item1, item2, item3): build a TernaryMapExp. The result tag is
// ta|tb|tc|kMapper; the type tags are nested bitmasks, so this is the most
// general operand tag and at least kMapper.
// NOTE(review): return-type line 276 is elided in this listing.
275 template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
277 MakeExp(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
278  const Exp<TC, DType, tc> &item3) {
279  return TernaryMapExp<OP, TA, TB, TC, DType,
280  (ta|tb|tc|type::kMapper)>(item1.self(), item2.self(), item3.self());
281 }
298 // Ternary
// F<OP>(item1, item2, item3): shorthand that forwards to MakeExp<OP>.
// NOTE(review): return-type line 300 is elided in this listing.
299 template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
301 F(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
302  const Exp<TC, DType, tc> &item3) {
303  return MakeExp<OP>(item1, item2, item3);
304 }
305 //---------------
306 // BinaryMapExp
307 // --------------
315 template<typename OP, typename TA, typename TB, typename DType, int etype>
316 struct BinaryMapExp: public Exp<BinaryMapExp<OP, TA, TB, DType, etype>,
317  DType, etype> {
319  const TA &lhs_;
321  const TB &rhs_;
323  explicit BinaryMapExp(const TA &lhs, const TB &rhs)
324  :lhs_(lhs), rhs_(rhs) {}
325 };
326 
// MakeExp<OP>(lhs, rhs): build a BinaryMapExp tagged ta|tb|kMapper, which is the
// most general operand tag (the tags are nested bitmasks) and at least kMapper.
// NOTE(review): signature lines 329-330 are elided in this listing.
328 template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
331  return BinaryMapExp<OP, TA, TB, DType,
332  (ta|tb|type::kMapper)>(lhs.self(), rhs.self());
333 }
// F<OP>(lhs, rhs): shorthand that forwards to MakeExp<OP>.
// NOTE(review): return-type line 347 is elided in this listing.
346 template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
348 F(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
349  return MakeExp<OP>(lhs, rhs);
350 }
351 // operator rules
// The four overloads below build BinaryMapExp via MakeExp with op::plus,
// op::minus, op::mul and op::div respectively. Presumably these are
// operator+, operator-, operator*, operator/ on two Exp operands; their
// signature lines are elided in this listing — confirm.
353 template<typename TA, typename TB, typename DType, int ta, int tb>
356  return MakeExp<op::plus>(lhs, rhs);
357 }
359 template<typename TA, typename TB, typename DType, int ta, int tb>
362  return MakeExp<op::minus>(lhs, rhs);
363 }
365 template<typename TA, typename TB, typename DType, int ta, int tb>
368  return MakeExp<op::mul>(lhs, rhs);
369 }
371 template<typename TA, typename TB, typename DType, int ta, int tb>
374  return MakeExp<op::div>(lhs, rhs);
375 }
376 //---------------
377 // UnaryMapExp
378 // --------------
385 template<typename OP, typename TA, typename DType, int etype>
386 struct UnaryMapExp: public Exp<UnaryMapExp<OP, TA, DType, etype>,
387  DType, etype> {
389  const TA &src_;
391  explicit UnaryMapExp(const TA &src) : src_(src) {}
392 };
393 
// MakeExp<OP>(src): build a UnaryMapExp for OP.
// NOTE(review): lines 396-398 (return type, parameter list and body) are
// elided in this listing.
395 template<typename OP, typename TA, typename DType, int ta>
399 }
// F<OP>(src): shorthand that forwards to MakeExp<OP>(src).
// NOTE(review): return-type line 410 is elided in this listing.
409 template<typename OP, typename TA, typename DType, int ta>
411 F(const Exp<TA, DType, ta> &src) {
412  return MakeExp<OP>(src);
413 }
414 } // namespace expr
415 } // namespace mshadow
416 #endif // MSHADOW_EXPRESSION_H_
const int kRValue
this expression directly corresponds to a data class and can be used to assign data
Definition: expression.h:27
const int kMapper
expression contains element-wise tensor operations, mapping an expression to one of the same shape ...
Definition: expression.h:32
DotExp< TA, TB, transpose_left, transpose_right, DType > batch_dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
batch_dot operator def
Definition: expression.h:246
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:85
BinaryMapExp< op::minus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator-(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:83
BinaryMapExp< op::plus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator+(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:75
UnaryMapExp(const TA &src)
constructor
Definition: expression.h:391
const TB & rhs_
right operand
Definition: expression.h:321
SubType * ptrself(void)
Definition: expression.h:68
Container & operator/=(const Exp< E, DType, etype > &exp)
implementation of operator/=
Definition: expression.h:192
const int kComplex
other cases, e.g. dot product
Definition: expression.h:40
ternary map expression
Definition: expression.h:261
binary map expression lhs [op] rhs
Definition: expression.h:316
DotExp< TA, TB, false, false, DType > dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
dot operator def
Definition: expression.h:222
TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
constructor
Definition: expression.h:270
TypecastExp(const EType &e)
constructor
Definition: expression.h:102
BinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: expression.h:323
TypecastExp< DstDType, SrcDType, EType,(etype|type::kMapper)> tcast(const Exp< EType, SrcDType, etype > &exp)
create a typecast expression
Definition: expression.h:108
Container & operator*=(const Exp< E, DType, etype > &exp)
implementation of operator*=
Definition: expression.h:186
base class of all rvalues
Definition: expression.h:130
DType scalar_
scalar value
Definition: expression.h:79
const TB & rhs_
right operand
Definition: expression.h:211
const EType & exp
expression to be transposed
Definition: expression.h:116
const int kChainer
expression that can be chained with other expressions. Usually it has a function Eval(i, j) defined, which pulls the result (i, j) from the input expression and outputs the result at a certain position.
Definition: expression.h:38
const TB & item2_
second operand
Definition: expression.h:266
ScalarExp(DType scalar)
implicit constructor, MUST NOT BE explicit
Definition: expression.h:81
Container & operator-=(DType s)
operator overload
Definition: expression.h:145
Container & operator-=(const Exp< E, DType, etype > &exp)
implementation of operator-=
Definition: expression.h:180
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:53
DType scale_
scale over result
Definition: expression.h:213
DotExp(const TA &lhs, const TB &rhs, DType scale)
constructor
Definition: expression.h:215
Container & operator/=(DType s)
operator overload
Definition: expression.h:155
const EType & T(void) const
transpose expression
Definition: expression.h:120
const TA & item1_
first operand
Definition: expression.h:264
typecast expression; casts the type of elements
Definition: expression.h:96
represent a transpose expression of a container
Definition: expression.h:113
Container & operator+=(DType s)
operator overload
Definition: expression.h:140
const TA & src_
source expression
Definition: expression.h:389
unary map expression op(src)
Definition: expression.h:386
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:206
scalar expression
Definition: expression.h:77
TernaryMapExp< OP, TA, TB, TC, DType,(ta|tb|tc|type::kMapper)> MakeExp(const Exp< TA, DType, ta > &item1, const Exp< TB, DType, tb > &item2, const Exp< TC, DType, tc > &item3)
make expression
Definition: expression.h:277
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const EType & exp
expression to be typecasted
Definition: expression.h:100
const SubType & self(void) const
Definition: expression.h:64
const TC & item3_
third operand
Definition: expression.h:268
Container & operator*=(DType s)
operator overload
Definition: expression.h:150
BinaryMapExp< op::div, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator/(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:99
const TA & lhs_
left operand
Definition: expression.h:209
const TA & lhs_
left operand
Definition: expression.h:319
namespace for mshadow
Definition: base.h:282
DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > operator*(const DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > &lhs, MSHADOW_SCALAR_ rhs)
dot operator def
Definition: expr_scalar-inl.h:22
the engine that dispatches simple operations
Definition: expr_engine-inl.h:442
Container & __assign(DType s)
operator overload
Definition: expression.h:160
Container & operator+=(const Exp< E, DType, etype > &exp)
implementation of operator+=
Definition: expression.h:174
const TransposeExp< Container, DType > T(void) const
transpose of a matrix
Definition: expression.h:136
static void Eval(RV *dst, const Exp< E, DType, type::kMapper > &exp)
Definition: expr_engine-inl.h:444
TransposeExp(const EType &e)
constructor
Definition: expression.h:118
Container & __assign(const Exp< E, DType, etype > &exp)
we cannot define container = container
Definition: expression.h:166