mxnet
expression.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_EXPRESSION_H_
27 #define MSHADOW_EXPRESSION_H_
28 #include "./base.h"
29 
30 namespace mshadow {
37 namespace expr {
namespace type {
// Expression types are encoded as bitmasks so that OR-ing the types of the
// sub-expressions yields the most general type of the result
// (e.g. MakeExp computes (ta|tb|type::kMapper)).
// Subtype relationship: kRValue < kMapper < kChainer < kComplex
//   kRValue = 0 (0b000), kMapper = 1 (0b001),
//   kChainer = 3 (0b011), kComplex = 7 (0b111)
/*!
 * \brief this expression directly corresponds to a data class,
 *   and can be used to assign data
 */
const int kRValue = 0;
/*!
 * \brief expression contains element-wise tensor operations,
 *   maps an expression to one of the same shape
 */
const int kMapper = 1;
/*!
 * \brief expression that can be chained with other expressions.
 *   Usually it has a function Eval(i, j) defined, which pulls the result (i, j)
 *   from the input expression and outputs the result at a certain position.
 */
const int kChainer = 3;
/*! \brief other cases, e.g. dot product */
const int kComplex = 7;
}  // namespace type
/*!
 * \brief expression engine that interprets an expression and stores its
 *   result into a destination. It is only declared here; the definition
 *   (presumably in expr_engine-inl.h, per the generated cross-references)
 *   is expected to provide, roughly:
 *
 *     template<typename EType>
 *     inline static void Eval(RValue *dst, const EType &exp);
 *
 * \tparam Saver  how the result is stored (the sv::* tags used by RValueExp)
 * \tparam RValue type of the destination container
 * \tparam DType  element data type
 */
template<class Saver, class RValue, class DType>
struct ExpEngine;
/*!
 * \brief base class of every expression (CRTP): a concrete expression type
 *   derives from Exp<itself, ...> so generic code can recover it via self().
 * \tparam SubType  the concrete expression type deriving from this class
 * \tparam DType    element data type of the expression
 * \tparam exp_type expression category, one of the type::k* bitmasks
 */
template<class SubType, class DType, int exp_type>
struct Exp {
  /*! \return this object viewed as its concrete (derived) type */
  inline const SubType &self(void) const {
    const SubType *derived = static_cast<const SubType *>(this);
    return *derived;
  }
  /*! \return pointer to this object as its concrete (derived) type */
  inline SubType *ptrself(void) {
    SubType *derived = static_cast<SubType *>(this);
    return derived;
  }
};
95 template<typename DType>
96 struct ScalarExp: public Exp<ScalarExp<DType>, DType, type::kMapper> {
98  DType scalar_;
100  ScalarExp(DType scalar) : scalar_(scalar) {} // NOLINT(*)
101 };
103 template<typename DType>
104 inline ScalarExp<DType> scalar(DType s) {
105  return ScalarExp<DType>(s);
106 }
114 template<typename DstDType, typename SrcDType, typename EType, int etype>
115 struct TypecastExp:
116  public Exp<TypecastExp<DstDType, SrcDType, EType, etype>,
117  DstDType, etype> {
119  const EType &exp;
121  explicit TypecastExp(const EType &e) : exp(e) {}
122 };
124 template<typename DstDType, typename SrcDType,
125  typename EType, int etype>
129 }
131 template<typename EType, typename DType>
132 struct TransposeExp: public Exp<TransposeExp<EType, DType>,
133  DType, type::kChainer> {
135  const EType &exp;
137  explicit TransposeExp(const EType &e) : exp(e) {}
139  inline const EType &T(void) const {
140  return exp;
141  }
142 };
148 template<typename Container, typename DType>
149 class RValueExp: public Exp<Container, DType, type::kRValue> {
150  public:
155  inline const TransposeExp<Container, DType> T(void) const {
156  return TransposeExp<Container, DType>(this->self());
157  }
159  inline Container &operator+=(DType s) {
160  ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
161  return *(this->ptrself());
162  }
164  inline Container &operator-=(DType s) {
165  ExpEngine<sv::minusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
166  return *(this->ptrself());
167  }
169  inline Container &operator*=(DType s) {
170  ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
171  return *(this->ptrself());
172  }
174  inline Container &operator/=(DType s) {
175  ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
176  return *(this->ptrself());
177  }
179  inline Container &__assign(DType s) {
180  ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
181  return *(this->ptrself());
182  }
184  template<typename E, int etype>
185  inline Container &__assign(const Exp<E, DType, etype> &exp) {
186  ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), exp.self());
187  return *(this->ptrself());
188  }
190  inline Container &__assign(const Exp<Container, DType, type::kRValue> &exp);
192  template<typename E, int etype>
193  inline Container &operator+=(const Exp<E, DType, etype> &exp) {
194  ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), exp.self());
195  return *(this->ptrself());
196  }
198  template<typename E, int etype>
199  inline Container &operator-=(const Exp<E, DType, etype> &exp) {
201  return *(this->ptrself());
202  }
204  template<typename E, int etype>
205  inline Container &operator*=(const Exp<E, DType, etype> &exp) {
206  ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), exp.self());
207  return *(this->ptrself());
208  }
210  template<typename E, int etype>
211  inline Container &operator/=(const Exp<E, DType, etype> &exp) {
212  ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), exp.self());
213  return *(this->ptrself());
214  }
215 };
224 template<typename TA, typename TB, bool ltrans, bool rtrans, typename DType>
225 struct DotExp: public Exp<DotExp<TA, TB, ltrans, rtrans, DType>,
226  DType, type::kComplex> {
228  const TA &lhs_;
230  const TB &rhs_;
232  DType scale_;
234  explicit DotExp(const TA &lhs, const TB &rhs, DType scale)
235  : lhs_(lhs), rhs_(rhs), scale_(scale) {}
236 };
237 // definition of dot expression
239 template<typename TA, typename TB, typename DType>
242  return DotExp<TA, TB, false, false, DType>(lhs.self(), rhs.self(), DType(1.0f));
243 }
245 template<typename TA, typename TB, typename DType>
248  return DotExp<TA, TB, true, false, DType>(lhs.exp, rhs.self(), DType(1.0f));
249 }
251 template<typename TA, typename TB, typename DType>
254  return DotExp<TA, TB, false, true, DType>(lhs.self(), rhs.exp, DType(1.0f));
255 }
257 template<typename TA, typename TB, typename DType>
260  return DotExp<TA, TB, true, true, DType>(lhs.exp, rhs.exp, DType(1.0f));
261 }
263 template<bool transpose_left, bool transpose_right, typename TA, typename TB, typename DType>
267  lhs.self(), rhs.self(), DType(1.0f));
268 }
269 //---------------
270 // TernaryMapExp
271 // --------------
279 template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
280 struct TernaryMapExp: public Exp<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
281  DType, etype> {
283  const TA &item1_;
285  const TB &item2_;
287  const TC &item3_;
289  explicit TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
290  :item1_(item1), item2_(item2), item3_(item3) {}
291 };
292 
294 template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
296 MakeExp(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
297  const Exp<TC, DType, tc> &item3) {
298  return TernaryMapExp<OP, TA, TB, TC, DType,
299  (ta|tb|tc|type::kMapper)>(item1.self(), item2.self(), item3.self());
300 }
317 // Ternary
318 template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
320 F(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
321  const Exp<TC, DType, tc> &item3) {
322  return MakeExp<OP>(item1, item2, item3);
323 }
324 //---------------
325 // BinaryMapExp
326 // --------------
334 template<typename OP, typename TA, typename TB, typename DType, int etype>
335 struct BinaryMapExp: public Exp<BinaryMapExp<OP, TA, TB, DType, etype>,
336  DType, etype> {
338  const TA &lhs_;
340  const TB &rhs_;
342  explicit BinaryMapExp(const TA &lhs, const TB &rhs)
343  :lhs_(lhs), rhs_(rhs) {}
344 };
345 
347 template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
350  return BinaryMapExp<OP, TA, TB, DType,
351  (ta|tb|type::kMapper)>(lhs.self(), rhs.self());
352 }
365 template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
367 F(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
368  return MakeExp<OP>(lhs, rhs);
369 }
370 // operator rules
372 template<typename TA, typename TB, typename DType, int ta, int tb>
375  return MakeExp<op::plus>(lhs, rhs);
376 }
378 template<typename TA, typename TB, typename DType, int ta, int tb>
381  return MakeExp<op::minus>(lhs, rhs);
382 }
384 template<typename TA, typename TB, typename DType, int ta, int tb>
387  return MakeExp<op::mul>(lhs, rhs);
388 }
390 template<typename TA, typename TB, typename DType, int ta, int tb>
393  return MakeExp<op::div>(lhs, rhs);
394 }
395 //---------------
396 // UnaryMapExp
397 // --------------
404 template<typename OP, typename TA, typename DType, int etype>
405 struct UnaryMapExp: public Exp<UnaryMapExp<OP, TA, DType, etype>,
406  DType, etype> {
408  const TA &src_;
410  explicit UnaryMapExp(const TA &src) : src_(src) {}
411 };
412 
414 template<typename OP, typename TA, typename DType, int ta>
418 }
428 template<typename OP, typename TA, typename DType, int ta>
430 F(const Exp<TA, DType, ta> &src) {
431  return MakeExp<OP>(src);
432 }
433 } // namespace expr
434 } // namespace mshadow
435 #endif // MSHADOW_EXPRESSION_H_
const int kRValue
this expression directly corresponds to a data class and can be used to assign data
Definition: expression.h:46
const int kMapper
expression contains element-wise tensor operations; maps an expression to one of the same shape ...
Definition: expression.h:51
DotExp< TA, TB, transpose_left, transpose_right, DType > batch_dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
batch_dot operator def
Definition: expression.h:265
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:104
BinaryMapExp< op::minus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator-(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:102
BinaryMapExp< op::plus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator+(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:94
UnaryMapExp(const TA &src)
constructor
Definition: expression.h:410
const TB & rhs_
right operand
Definition: expression.h:340
SubType * ptrself(void)
Definition: expression.h:87
Container & operator/=(const Exp< E, DType, etype > &exp)
implementation of operator/=
Definition: expression.h:211
const int kComplex
other cases, e.g. dot product
Definition: expression.h:59
ternary map expression
Definition: expression.h:280
binary map expression lhs [op] rhs
Definition: expression.h:335
DotExp< TA, TB, false, false, DType > dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
dot operator def
Definition: expression.h:241
TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
constructor
Definition: expression.h:289
TypecastExp(const EType &e)
constructor
Definition: expression.h:121
BinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: expression.h:342
TypecastExp< DstDType, SrcDType, EType,(etype|type::kMapper)> tcast(const Exp< EType, SrcDType, etype > &exp)
create a typecast expression
Definition: expression.h:127
Container & operator*=(const Exp< E, DType, etype > &exp)
implementation of operator*=
Definition: expression.h:205
base class of all rvalues
Definition: expression.h:149
DType scalar_
scalar value
Definition: expression.h:98
const TB & rhs_
right operand
Definition: expression.h:230
const EType & exp
expression to be transposed
Definition: expression.h:135
const int kChainer
expression that can be chained with other expressions. Usually it has a function Eval(i, j) defined, which pulls the result (i, j) from the input expression and outputs the result at a certain position.
Definition: expression.h:57
const TB & item2_
second operand
Definition: expression.h:285
ScalarExp(DType scalar)
implicit constructor, MUST NOT BE explicit
Definition: expression.h:100
Container & operator-=(DType s)
operator overload
Definition: expression.h:164
Container & operator-=(const Exp< E, DType, etype > &exp)
implementation of operator-=
Definition: expression.h:199
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:72
DType scale_
scale over result
Definition: expression.h:232
DotExp(const TA &lhs, const TB &rhs, DType scale)
constructor
Definition: expression.h:234
Container & operator/=(DType s)
operator overload
Definition: expression.h:174
const EType & T(void) const
transpose expression
Definition: expression.h:139
const TA & item1_
first operand
Definition: expression.h:283
typecast expression, cast the type of elements
Definition: expression.h:115
represent a transpose expression of a container
Definition: expression.h:132
Container & operator+=(DType s)
operator overload
Definition: expression.h:159
const TA & src_
source expression
Definition: expression.h:408
unary map expression op(src)
Definition: expression.h:405
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:225
scalar expression
Definition: expression.h:96
TernaryMapExp< OP, TA, TB, TC, DType,(ta|tb|tc|type::kMapper)> MakeExp(const Exp< TA, DType, ta > &item1, const Exp< TB, DType, tb > &item2, const Exp< TC, DType, tc > &item3)
make expression
Definition: expression.h:296
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const EType & exp
expression to be typecasted
Definition: expression.h:119
const SubType & self(void) const
Definition: expression.h:83
const TC & item3_
third operand
Definition: expression.h:287
Container & operator*=(DType s)
operator overload
Definition: expression.h:169
BinaryMapExp< op::div, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator/(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:118
const TA & lhs_
left operand
Definition: expression.h:228
const TA & lhs_
left operand
Definition: expression.h:338
overloaded + operator between half_t and bf16_t
Definition: base.h:327
DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > operator*(const DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > &lhs, MSHADOW_SCALAR_ rhs)
dot operator def
Definition: expr_scalar-inl.h:41
the engine that dispatches simple operations
Definition: expr_engine-inl.h:461
Container & __assign(DType s)
operator overload
Definition: expression.h:179
Container & operator+=(const Exp< E, DType, etype > &exp)
implementation of operator+=
Definition: expression.h:193
const TransposeExp< Container, DType > T(void) const
transpose of a matrix
Definition: expression.h:155
static void Eval(RV *dst, const Exp< E, DType, type::kMapper > &exp)
Definition: expr_engine-inl.h:463
TransposeExp(const EType &e)
constructor
Definition: expression.h:137
Container & __assign(const Exp< E, DType, etype > &exp)
we cannot define container = container
Definition: expression.h:185