mxnet
expression.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXPRESSION_H_
26 #define MSHADOW_EXPRESSION_H_
27 #include "./base.h"
28 
29 namespace mshadow {
36 namespace expr {
// Expression-type tags used as the third template argument of Exp.
38 namespace type {
39 // expression types are defined as bitmasks: each value contains all bits of the previous one
40 // subtype relationship: kRValue < kMapper < kChainer < kComplex
45 const int kRValue = 0;  // directly corresponds to a data class; can be used to assign data
50 const int kMapper = 1;  // contains element-wise tensor operations; maps an expression to the same shape
56 const int kChainer = 3;  // expression that can be chained with other expressions
58 const int kComplex = 7;  // other cases, e.g. dot product
59 } // namespace type
// Forward declaration of the engine that dispatches simple operations.
// Per the index at the end of this listing it is defined in expr_engine-inl.h.
// The RValueExp operators below call its static member as
// ExpEngine<Saver, Container, DType>::Eval(dst, exp).
67 template<typename Saver, typename RValue, typename DType>
68 struct ExpEngine;
// Commented-out sketch of the Eval member's shape, kept for reference only.
70 // template<typename EType>
71 // inline static void Eval(RValue *dst, const EType &exp);
// Base class template of every expression in this file, written in the
// curiously recurring template pattern: each concrete expression derives from
// Exp<ItsOwnType, DType, exp_type>.
//   SubType  - the concrete (derived) expression type
//   DType    - element data type carried by the expression
//   exp_type - one of the type::k* bitmask tags
78 template<typename SubType, typename DType, int exp_type>
79 struct Exp {
80  public:
// Returns this object viewed as a const reference to the concrete SubType.
82  inline const SubType& self(void) const {
83  return *static_cast<const SubType*>(this);
84  }
// Returns this object viewed as a mutable pointer to the concrete SubType.
86  inline SubType* ptrself(void) {
87  return static_cast<SubType*>(this);
88  }
89 };
// Scalar expression: wraps a single DType value as a kMapper expression.
94 template<typename DType>
95 struct ScalarExp: public Exp<ScalarExp<DType>, DType, type::kMapper> {
// The wrapped scalar value, stored by value.
97  DType scalar_;
// Implicit converting constructor; per the listing index it MUST NOT be made
// explicit, which is why the lint warning is suppressed on this line.
99  ScalarExp(DType scalar) : scalar_(scalar) {} // NOLINT(*)
100 };
// Creates a scalar expression holding a copy of s.
// RValueExp's scalar operators call this as scalar<DType>(s).
102 template<typename DType>
103 inline ScalarExp<DType> scalar(DType s) {
104  return ScalarExp<DType>(s);
105 }
// Typecast expression: casts the element type of a wrapped expression.
//   DstDType - element type exposed through the Exp base
//   SrcDType - element type of the wrapped expression (not used inside this struct)
//   EType    - concrete type of the wrapped expression
//   etype    - expression-type tag forwarded to the Exp base
113 template<typename DstDType, typename SrcDType, typename EType, int etype>
114 struct TypecastExp:
115  public Exp<TypecastExp<DstDType, SrcDType, EType, etype>,
116  DstDType, etype> {
// Expression to be typecasted. Held by reference, not copied, so the referenced
// object has to stay alive for as long as this wrapper is used.
118  const EType &exp;
// Binds the wrapper to e; explicit, so no implicit conversion from EType.
120  explicit TypecastExp(const EType &e) : exp(e) {}
121 };
// Factory for TypecastExp; the result tag is the source tag OR-ed with kMapper.
// NOTE(review): listing lines 126-127 (signature and return statement) are absent
// from this rendering. The index gives the signature as
// tcast(const Exp<EType, SrcDType, etype> &exp); restore the body from the real
// header rather than from this listing.
123 template<typename DstDType, typename SrcDType,
124  typename EType, int etype>
125 inline TypecastExp<DstDType, SrcDType, EType, (etype|type::kMapper)>
128 }
// Represents the transpose of a container as a kChainer expression.
// No data is moved here; the struct only records which expression is transposed.
130 template<typename EType, typename DType>
131 struct TransposeExp: public Exp<TransposeExp<EType, DType>,
132  DType, type::kChainer> {
// Expression to be transposed, held by reference (must outlive this wrapper).
134  const EType &exp;
// Binds the wrapper to e.
136  explicit TransposeExp(const EType &e) : exp(e) {}
// Transposing again simply hands back the wrapped expression.
138  inline const EType &T(void) const {
139  return exp;
140  }
141 };
// Base class of all rvalues: a kRValue expression whose concrete Container can
// be the destination of an evaluation.
// The scalar overloads below wrap the scalar with scalar<DType>(s) and hand it
// to ExpEngine<Saver, Container, DType>::Eval with this container as the
// destination; every operator returns the container itself.
// NOTE(review): the sv::* saver tags are not declared in this listing —
// presumably they come from base.h; confirm.
147 template<typename Container, typename DType>
148 class RValueExp: public Exp<Container, DType, type::kRValue> {
149  public:
// Transpose of a matrix: returns a TransposeExp wrapping this container.
154  inline const TransposeExp<Container, DType> T(void) const {
155  return TransposeExp<Container, DType>(this->self());
156  }
// container += scalar, dispatched with the sv::plusto saver.
158  inline Container &operator+=(DType s) {
159  ExpEngine<sv::plusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
160  return *(this->ptrself());
161  }
// container -= scalar, dispatched with the sv::minusto saver.
163  inline Container &operator-=(DType s) {
164  ExpEngine<sv::minusto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
165  return *(this->ptrself());
166  }
// container *= scalar, dispatched with the sv::multo saver.
168  inline Container &operator*=(DType s) {
169  ExpEngine<sv::multo, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
170  return *(this->ptrself());
171  }
// container /= scalar, dispatched with the sv::divto saver.
173  inline Container &operator/=(DType s) {
174  ExpEngine<sv::divto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
175  return *(this->ptrself());
176  }
// Assigns a scalar to the container, dispatched with the sv::saveto saver.
178  inline Container &__assign(DType s) {
179  ExpEngine<sv::saveto, Container, DType>::Eval(this->ptrself(), scalar<DType>(s));
180  return *(this->ptrself());
181  }
// Assigns an arbitrary expression to the container.
// NOTE(review): as listed, this body only returns the container. The listing
// skips line 185 here, so the evaluation statement (presumably an ExpEngine
// Eval call mirroring the scalar overload) appears to be missing from this
// rendering — confirm against the real header. The same applies to the
// expression overloads of +=, -=, *= and /= below (lines 193, 199, 205, 211).
183  template<typename E, int etype>
184  inline Container &__assign(const Exp<E, DType, etype> &exp) {
186  return *(this->ptrself());
187  }
// Assignment from another rvalue of the same Container type.
// Declared only; no definition is visible in this listing.
189  inline Container &__assign(const Exp<Container, DType, type::kRValue> &exp);
// container += expression (see NOTE(review) on __assign above).
191  template<typename E, int etype>
192  inline Container &operator+=(const Exp<E, DType, etype> &exp) {
194  return *(this->ptrself());
195  }
// container -= expression (see NOTE(review) on __assign above).
197  template<typename E, int etype>
198  inline Container &operator-=(const Exp<E, DType, etype> &exp) {
200  return *(this->ptrself());
201  }
// container *= expression (see NOTE(review) on __assign above).
203  template<typename E, int etype>
204  inline Container &operator*=(const Exp<E, DType, etype> &exp) {
206  return *(this->ptrself());
207  }
// container /= expression (see NOTE(review) on __assign above).
209  template<typename E, int etype>
210  inline Container &operator/=(const Exp<E, DType, etype> &exp) {
212  return *(this->ptrself());
213  }
214 };
// Matrix multiplication expression dot(lhs[.T], rhs[.T]), tagged kComplex.
//   TA, TB         - concrete types of the left and right operands
//   ltrans, rtrans - compile-time flags carried in the type; the dot() overloads
//                    set them to true when the operand is taken from a `.exp`
//                    member (presumably a TransposeExp — confirm)
//   DType          - element data type
223 template<typename TA, typename TB, bool ltrans, bool rtrans, typename DType>
224 struct DotExp: public Exp<DotExp<TA, TB, ltrans, rtrans, DType>,
225  DType, type::kComplex> {
// Left operand, held by reference (must outlive this expression).
227  const TA &lhs_;
// Right operand, held by reference (must outlive this expression).
229  const TB &rhs_;
// Scale over the result, stored by value.
231  DType scale_;
// Records both operands and the scale; performs no computation.
233  explicit DotExp(const TA &lhs, const TB &rhs, DType scale)
234  : lhs_(lhs), rhs_(rhs), scale_(scale) {}
235 };
236 // definition of dot expression
// dot(lhs, rhs): neither operand transposed, scale fixed to DType(1.0f).
// NOTE(review): listing line 240 (the signature) is absent from this rendering.
// The index gives it as
// dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs).
238 template<typename TA, typename TB, typename DType>
239 inline DotExp<TA, TB, false, false, DType>
241  return DotExp<TA, TB, false, false, DType>(lhs.self(), rhs.self(), DType(1.0f));
242 }
// dot overload with the left operand transposed: passes lhs.exp (the expression
// wrapped by lhs) and sets ltrans = true; scale fixed to DType(1.0f).
// NOTE(review): listing line 246 (the signature) is absent from this rendering;
// lhs is presumably a TransposeExp — confirm against the real header.
244 template<typename TA, typename TB, typename DType>
245 inline DotExp<TA, TB, true, false, DType>
247  return DotExp<TA, TB, true, false, DType>(lhs.exp, rhs.self(), DType(1.0f));
248 }
// dot overload with the right operand transposed: passes rhs.exp (the expression
// wrapped by rhs) and sets rtrans = true; scale fixed to DType(1.0f).
// NOTE(review): listing line 252 (the signature) is absent from this rendering;
// rhs is presumably a TransposeExp — confirm against the real header.
250 template<typename TA, typename TB, typename DType>
251 inline DotExp<TA, TB, false, true, DType>
253  return DotExp<TA, TB, false, true, DType>(lhs.self(), rhs.exp, DType(1.0f));
254 }
// dot overload with both operands transposed: passes lhs.exp and rhs.exp and
// sets both transpose flags; scale fixed to DType(1.0f).
// NOTE(review): listing line 258 (the signature) is absent from this rendering;
// both operands are presumably TransposeExp — confirm against the real header.
256 template<typename TA, typename TB, typename DType>
257 inline DotExp<TA, TB, true, true, DType>
259  return DotExp<TA, TB, true, true, DType>(lhs.exp, rhs.exp, DType(1.0f));
260 }
// batch_dot: builds a DotExp whose transpose flags are explicit template
// arguments rather than deduced from the operand types.
// NOTE(review): listing lines 264-265 (the signature and the start of the return
// statement) are absent from this rendering. The index gives the signature as
// batch_dot(const RValueExp<TA, DType> &lhs, const RValueExp<TB, DType> &rhs).
262 template<bool transpose_left, bool transpose_right, typename TA, typename TB, typename DType>
263 inline DotExp<TA, TB, transpose_left, transpose_right, DType>
266  lhs.self(), rhs.self(), DType(1.0f));
267 }
268 //---------------
269 // TernaryMapExp
270 // --------------
// Ternary map expression: records an operator tag OP and three operands.
//   OP         - operator tag type (not used inside this struct)
//   TA, TB, TC - concrete types of the three operands
//   DType      - element data type
//   etype      - expression-type tag forwarded to the Exp base
278 template<typename OP, typename TA, typename TB, typename TC, typename DType, int etype>
279 struct TernaryMapExp: public Exp<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
280  DType, etype> {
// First operand, held by reference (must outlive this expression).
282  const TA &item1_;
// Second operand, held by reference (must outlive this expression).
284  const TB &item2_;
// Third operand, held by reference (must outlive this expression).
286  const TC &item3_;
// Records the three operands; performs no computation.
288  explicit TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
289  :item1_(item1), item2_(item2), item3_(item3) {}
290 };
291 
// Makes a ternary map expression from three operand expressions.
// OP must be given explicitly by the caller; the remaining template arguments
// are deduced. The result tag is the bitwise OR of the three operand tags and
// kMapper. Each operand is passed on as its concrete type via self().
293 template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
294 inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)>
295 MakeExp(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
296  const Exp<TC, DType, tc> &item3) {
297  return TernaryMapExp<OP, TA, TB, TC, DType,
298  (ta|tb|tc|type::kMapper)>(item1.self(), item2.self(), item3.self());
299 }
316 // Ternary
// Short alias for the ternary MakeExp<OP>: F<OP>(a, b, c) forwards its three
// operands unchanged and returns the same TernaryMapExp type.
317 template<typename OP, typename TA, typename TB, typename TC, typename DType, int ta, int tb, int tc>
318 inline TernaryMapExp<OP, TA, TB, TC, DType, (ta|tb|tc|type::kMapper)>
319 F(const Exp<TA, DType, ta> &item1, const Exp<TB, DType, tb> &item2,
320  const Exp<TC, DType, tc> &item3) {
321  return MakeExp<OP>(item1, item2, item3);
322 }
323 //---------------
324 // BinaryMapExp
325 // --------------
// Binary map expression lhs [op] rhs: records an operator tag OP and two operands.
//   OP     - operator tag type (not used inside this struct)
//   TA, TB - concrete types of the left and right operands
//   DType  - element data type
//   etype  - expression-type tag forwarded to the Exp base
333 template<typename OP, typename TA, typename TB, typename DType, int etype>
334 struct BinaryMapExp: public Exp<BinaryMapExp<OP, TA, TB, DType, etype>,
335  DType, etype> {
// Left operand, held by reference (must outlive this expression).
337  const TA &lhs_;
// Right operand, held by reference (must outlive this expression).
339  const TB &rhs_;
// Records both operands; performs no computation.
341  explicit BinaryMapExp(const TA &lhs, const TB &rhs)
342  :lhs_(lhs), rhs_(rhs) {}
343 };
344 
// Makes a binary map expression; the result tag is the bitwise OR of both
// operand tags and kMapper.
// NOTE(review): listing line 348 (the signature) is absent from this rendering.
// The body uses lhs and rhs, and F / operator+ below call this as
// MakeExp<OP>(lhs, rhs) with Exp<TA, DType, ta> / Exp<TB, DType, tb> arguments;
// restore the signature from the real header.
346 template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
347 inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
349  return BinaryMapExp<OP, TA, TB, DType,
350  (ta|tb|type::kMapper)>(lhs.self(), rhs.self());
351 }
// Short alias for the binary MakeExp<OP>: F<OP>(lhs, rhs) forwards both operands
// unchanged and returns the same BinaryMapExp type.
364 template<typename OP, typename TA, typename TB, typename DType, int ta, int tb>
365 inline BinaryMapExp<OP, TA, TB, DType, (ta|tb|type::kMapper)>
366 F(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
367  return MakeExp<OP>(lhs, rhs);
368 }
369 // operator rules
// lhs + rhs on two expressions with the same DType: builds a BinaryMapExp tagged
// with op::plus. Only the expression object is created; nothing is evaluated here.
371 template<typename TA, typename TB, typename DType, int ta, int tb>
372 inline BinaryMapExp<op::plus, TA, TB, DType, (ta|tb|type::kMapper)>
373 operator+(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
374  return MakeExp<op::plus>(lhs, rhs);
375 }
// lhs - rhs on two expressions with the same DType: builds a BinaryMapExp tagged
// with op::minus. Only the expression object is created; nothing is evaluated here.
377 template<typename TA, typename TB, typename DType, int ta, int tb>
378 inline BinaryMapExp<op::minus, TA, TB, DType, (ta|tb|type::kMapper)>
379 operator-(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
380  return MakeExp<op::minus>(lhs, rhs);
381 }
// lhs * rhs on two expressions with the same DType: builds a BinaryMapExp tagged
// with op::mul. Only the expression object is created; nothing is evaluated here.
383 template<typename TA, typename TB, typename DType, int ta, int tb>
384 inline BinaryMapExp<op::mul, TA, TB, DType, (ta|tb|type::kMapper)>
385 operator*(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
386  return MakeExp<op::mul>(lhs, rhs);
387 }
// lhs / rhs on two expressions with the same DType: builds a BinaryMapExp tagged
// with op::div. Only the expression object is created; nothing is evaluated here.
389 template<typename TA, typename TB, typename DType, int ta, int tb>
390 inline BinaryMapExp<op::div, TA, TB, DType, (ta|tb|type::kMapper)>
391 operator/(const Exp<TA, DType, ta> &lhs, const Exp<TB, DType, tb> &rhs) {
392  return MakeExp<op::div>(lhs, rhs);
393 }
394 //---------------
395 // UnaryMapExp
396 // --------------
// Unary map expression op(src): records an operator tag OP and one operand.
//   OP    - operator tag type (not used inside this struct)
//   TA    - concrete type of the source expression
//   DType - element data type
//   etype - expression-type tag forwarded to the Exp base
403 template<typename OP, typename TA, typename DType, int etype>
404 struct UnaryMapExp: public Exp<UnaryMapExp<OP, TA, DType, etype>,
405  DType, etype> {
// Source expression, held by reference (must outlive this expression).
407  const TA &src_;
// Records the operand; performs no computation.
409  explicit UnaryMapExp(const TA &src) : src_(src) {}
410 };
411 
// Makes a unary map expression; the result tag is the operand tag OR-ed with kMapper.
// NOTE(review): listing lines 415-416 (the signature and the return statement)
// are absent from this rendering. F below calls this as MakeExp<OP>(src) with an
// Exp<TA, DType, ta> argument; restore the body from the real header.
413 template<typename OP, typename TA, typename DType, int ta>
414 inline UnaryMapExp<OP, TA, DType, (ta|type::kMapper)>
417 }
// Short alias for the unary MakeExp<OP>: F<OP>(src) forwards the operand
// unchanged and returns the same UnaryMapExp type.
427 template<typename OP, typename TA, typename DType, int ta>
428 inline UnaryMapExp<OP, TA, DType, (ta|type::kMapper)>
429 F(const Exp<TA, DType, ta> &src) {
430  return MakeExp<OP>(src);
431 }
432 } // namespace expr
433 } // namespace mshadow
434 #endif // MSHADOW_EXPRESSION_H_
mshadow::expr::Exp::self
const SubType & self(void) const
Definition: expression.h:82
mshadow::expr::scalar
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:103
mshadow::expr::TypecastExp::exp
const EType & exp
expression to be typecasted
Definition: expression.h:118
mshadow::expr::MakeExp
TernaryMapExp< OP, TA, TB, TC, DType,(ta|tb|tc|type::kMapper)> MakeExp(const Exp< TA, DType, ta > &item1, const Exp< TB, DType, tb > &item2, const Exp< TC, DType, tc > &item3)
make expression
Definition: expression.h:295
mshadow::expr::RValueExp::operator/=
Container & operator/=(const Exp< E, DType, etype > &exp)
implementation of operator/=
Definition: expression.h:210
mshadow::expr::RValueExp::T
const TransposeExp< Container, DType > T(void) const
transpose of a matrix
Definition: expression.h:154
mshadow::expr::TypecastExp::TypecastExp
TypecastExp(const EType &e)
constructor
Definition: expression.h:120
mshadow::expr::Exp::ptrself
SubType * ptrself(void)
Definition: expression.h:86
mshadow::expr::TypecastExp
typecast expression, cast the type of elements
Definition: expression.h:114
mshadow::expr::batch_dot
DotExp< TA, TB, transpose_left, transpose_right, DType > batch_dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
batch_dot operator def
Definition: expression.h:264
mshadow::expr::RValueExp::operator-=
Container & operator-=(DType s)
operator overload
Definition: expression.h:163
mshadow::op::minus
minus operator
Definition: base.h:641
mshadow::expr::DotExp::DotExp
DotExp(const TA &lhs, const TB &rhs, DType scale)
constructor
Definition: expression.h:233
mshadow::expr::DotExp::lhs_
const TA & lhs_
left operand
Definition: expression.h:227
mshadow::op::div
divide operator
Definition: base.h:649
mshadow::expr::TernaryMapExp::item3_
const TC & item3_
third operand
Definition: expression.h:286
mshadow::expr::ExpEngine
the engine that dispatches simple operations
Definition: expr_engine-inl.h:460
mshadow::expr::TernaryMapExp
ternary map expression
Definition: expression.h:279
mshadow::expr::TernaryMapExp::item1_
const TA & item1_
first operand
Definition: expression.h:282
mshadow::expr::RValueExp::operator*=
Container & operator*=(const Exp< E, DType, etype > &exp)
implementation of operator*=
Definition: expression.h:204
mshadow::expr::BinaryMapExp::BinaryMapExp
BinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: expression.h:341
mshadow::expr::F
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:71
mshadow::expr::tcast
TypecastExp< DstDType, SrcDType, EType,(etype|type::kMapper)> tcast(const Exp< EType, SrcDType, etype > &exp)
create a typecast expression
Definition: expression.h:126
mshadow::expr::type::kMapper
const int kMapper
expression contains element-wise tensor operations; maps an expression to the same shape
Definition: expression.h:50
mshadow::expr::BinaryMapExp
binary map expression lhs [op] rhs
Definition: expression.h:334
mshadow::expr::RValueExp::operator-=
Container & operator-=(const Exp< E, DType, etype > &exp)
implementation of operator-=
Definition: expression.h:198
mshadow::expr::TransposeExp::T
const EType & T(void) const
transpose expression
Definition: expression.h:138
mshadow::expr::DotExp::scale_
DType scale_
scale over result
Definition: expression.h:231
mshadow::op::plus
plus operator
Definition: base.h:633
mshadow::expr::DotExp::rhs_
const TB & rhs_
right operand
Definition: expression.h:229
mshadow::expr::RValueExp::operator/=
Container & operator/=(DType s)
operator overload
Definition: expression.h:173
mshadow::expr::type::kComplex
const int kComplex
other cases, e.g. dot product
Definition: expression.h:58
mshadow::expr::UnaryMapExp::src_
const TA & src_
source expression
Definition: expression.h:407
mshadow::expr::TernaryMapExp::item2_
const TB & item2_
second operand
Definition: expression.h:284
mshadow::expr::RValueExp::operator+=
Container & operator+=(DType s)
operator overload
Definition: expression.h:158
mshadow::expr::ScalarExp::scalar_
DType scalar_
scalar value
Definition: expression.h:97
mshadow::expr::ScalarExp::ScalarExp
ScalarExp(DType scalar)
implicit constructor, MUST NOT BE explicit
Definition: expression.h:99
mshadow::expr::UnaryMapExp::UnaryMapExp
UnaryMapExp(const TA &src)
constructor
Definition: expression.h:409
mshadow::expr::BinaryMapExp::rhs_
const TB & rhs_
right operand
Definition: expression.h:339
mshadow::expr::RValueExp
base class of all rvalues
Definition: expression.h:148
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow::expr::RValueExp::__assign
Container & __assign(const Exp< E, DType, etype > &exp)
we can not define container = container
Definition: expression.h:184
mshadow::expr::TransposeExp::TransposeExp
TransposeExp(const EType &e)
constructor
Definition: expression.h:136
mshadow::expr::TransposeExp
represent a transpose expression of a container
Definition: expression.h:131
mshadow::expr::TransposeExp::exp
const EType & exp
expression to be transposed
Definition: expression.h:134
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::expr::BinaryMapExp::lhs_
const TA & lhs_
left operand
Definition: expression.h:337
mshadow::expr::ExpEngine::Eval
static void Eval(RV *dst, const Exp< E, DType, type::kMapper > &exp)
Definition: expr_engine-inl.h:462
mshadow::expr::RValueExp::__assign
Container & __assign(DType s)
operator overload
Definition: expression.h:178
mshadow::expr::TernaryMapExp::TernaryMapExp
TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
constructor
Definition: expression.h:288
mshadow::op::mul
mul operator
Definition: base.h:625
mshadow::expr::dot
DotExp< TA, TB, false, false, DType > dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
dot operator def
Definition: expression.h:240
mshadow::expr::RValueExp::operator*=
Container & operator*=(DType s)
operator overload
Definition: expression.h:168
mshadow::expr::UnaryMapExp
unary map expression op(src)
Definition: expression.h:404
base.h
definitions of base types, operators, macros functions
mshadow::expr::ScalarExp
scalar expression
Definition: expression.h:95
mshadow::expr::type::kRValue
const int kRValue
this expression directly corresponds to a data class, can be used to assign data
Definition: expression.h:45
mshadow::expr::type::kChainer
const int kChainer
expression that can be chained with other expressions. Usually it has function Eval(i,...
Definition: expression.h:56
mshadow::expr::RValueExp::operator+=
Container & operator+=(const Exp< E, DType, etype > &exp)
implementation of operator+=
Definition: expression.h:192
mshadow::expr::DotExp
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:224