26 #ifndef MSHADOW_EXPRESSION_H_ 27 #define MSHADOW_EXPRESSION_H_ 68 template<
typename Saver,
typename RValue,
typename DType>
79 template<
typename SubType,
typename DType,
int exp_type>
83 inline const SubType&
self(void)
const {
84 return *
static_cast<const SubType*
>(
this);
88 return static_cast<SubType*
>(
this);
95 template<
typename DType>
96 struct ScalarExp:
public Exp<ScalarExp<DType>, DType, type::kMapper> {
103 template<
typename DType>
114 template<
typename DstDType,
typename SrcDType,
typename EType,
int etype>
116 public Exp<TypecastExp<DstDType, SrcDType, EType, etype>,
124 template<
typename DstDType,
typename SrcDType,
125 typename EType,
int etype>
131 template<
typename EType,
typename DType>
133 DType, type::kChainer> {
139 inline const EType &
T(
void)
const {
148 template<
typename Container,
typename DType>
161 return *(this->ptrself());
166 return *(this->ptrself());
171 return *(this->ptrself());
176 return *(this->ptrself());
181 return *(this->ptrself());
184 template<
typename E,
int etype>
187 return *(this->ptrself());
192 template<
typename E,
int etype>
195 return *(this->ptrself());
198 template<
typename E,
int etype>
201 return *(this->ptrself());
204 template<
typename E,
int etype>
207 return *(this->ptrself());
210 template<
typename E,
int etype>
213 return *(this->ptrself());
224 template<
typename TA,
typename TB,
bool ltrans,
bool rtrans,
typename DType>
225 struct DotExp:
public Exp<DotExp<TA, TB, ltrans, rtrans, DType>,
226 DType, type::kComplex> {
234 explicit DotExp(
const TA &lhs,
const TB &rhs, DType scale)
235 : lhs_(lhs), rhs_(rhs), scale_(scale) {}
239 template<
typename TA,
typename TB,
typename DType>
245 template<
typename TA,
typename TB,
typename DType>
251 template<
typename TA,
typename TB,
typename DType>
257 template<
typename TA,
typename TB,
typename DType>
263 template<
bool transpose_left,
bool transpose_right,
typename TA,
typename TB,
typename DType>
267 lhs.
self(), rhs.
self(), DType(1.0f));
279 template<
typename OP,
typename TA,
typename TB,
typename TC,
typename DType,
int etype>
290 :item1_(item1), item2_(item2), item3_(item3) {}
294 template<
typename OP,
typename TA,
typename TB,
typename TC,
typename DType,
int ta,
int tb,
int tc>
318 template<
typename OP,
typename TA,
typename TB,
typename TC,
typename DType,
int ta,
int tb,
int tc>
322 return MakeExp<OP>(item1, item2, item3);
334 template<
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
343 :lhs_(lhs), rhs_(rhs) {}
347 template<
typename OP,
typename TA,
typename TB,
typename DType,
int ta,
int tb>
365 template<
typename OP,
typename TA,
typename TB,
typename DType,
int ta,
int tb>
368 return MakeExp<OP>(lhs, rhs);
372 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
375 return MakeExp<op::plus>(lhs, rhs);
378 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
381 return MakeExp<op::minus>(lhs, rhs);
384 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
387 return MakeExp<op::mul>(lhs, rhs);
390 template<
typename TA,
typename TB,
typename DType,
int ta,
int tb>
393 return MakeExp<op::div>(lhs, rhs);
404 template<
typename OP,
typename TA,
typename DType,
int etype>
414 template<
typename OP,
typename TA,
typename DType,
int ta>
428 template<
typename OP,
typename TA,
typename DType,
int ta>
431 return MakeExp<OP>(src);
435 #endif // MSHADOW_EXPRESSION_H_ const int kRValue
this expression directly corresponds to a data class and can be used to assign data
Definition: expression.h:46
const int kMapper
expression contains element-wise tensor operations; maps an expression to the same shape ...
Definition: expression.h:51
DotExp< TA, TB, transpose_left, transpose_right, DType > batch_dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
batch_dot operator def
Definition: expression.h:265
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:104
BinaryMapExp< op::minus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator-(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:102
BinaryMapExp< op::plus, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator+(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:94
UnaryMapExp(const TA &src)
constructor
Definition: expression.h:410
const TB & rhs_
right operand
Definition: expression.h:340
SubType * ptrself(void)
Definition: expression.h:87
Container & operator/=(const Exp< E, DType, etype > &exp)
implementation of operator/=
Definition: expression.h:211
const int kComplex
other cases: e.g. dot product
Definition: expression.h:59
ternary map expression
Definition: expression.h:280
binary map expression lhs [op] rhs
Definition: expression.h:335
DotExp< TA, TB, false, false, DType > dot(const RValueExp< TA, DType > &lhs, const RValueExp< TB, DType > &rhs)
dot operator def
Definition: expression.h:241
TernaryMapExp(const TA &item1, const TB &item2, const TC &item3)
constructor
Definition: expression.h:289
TypecastExp(const EType &e)
constructor
Definition: expression.h:121
BinaryMapExp(const TA &lhs, const TB &rhs)
constructor
Definition: expression.h:342
TypecastExp< DstDType, SrcDType, EType,(etype|type::kMapper)> tcast(const Exp< EType, SrcDType, etype > &exp)
create a typecast expression
Definition: expression.h:127
Container & operator*=(const Exp< E, DType, etype > &exp)
implementation of operator*=
Definition: expression.h:205
base class of all rvalues
Definition: expression.h:149
DType scalar_
scalar value
Definition: expression.h:98
const TB & rhs_
right operand
Definition: expression.h:230
const EType & exp
expression to be transposed
Definition: expression.h:135
const int kChainer
expression that can be chained with other expressions. Usually it has a function Eval(i, j) defined, which pulls the result (i, j) from the input expression and outputs the result at a certain position.
Definition: expression.h:57
const TB & item2_
second operand
Definition: expression.h:285
ScalarExp(DType scalar)
implicit constructor, MUST NOT BE explicit
Definition: expression.h:100
Container & operator-=(DType s)
operator overload
Definition: expression.h:164
Container & operator-=(const Exp< E, DType, etype > &exp)
implementation of operator-=
Definition: expression.h:199
BinaryMapExp< OP, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> F(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload for const
Definition: expr_scalar-inl.h:72
DType scale_
scale over result
Definition: expression.h:232
DotExp(const TA &lhs, const TB &rhs, DType scale)
constructor
Definition: expression.h:234
Container & operator/=(DType s)
operator overload
Definition: expression.h:174
const EType & T(void) const
transpose expression
Definition: expression.h:139
const TA & item1_
first operand
Definition: expression.h:283
typecast expression, cast the type of elements
Definition: expression.h:115
represent a transpose expression of a container
Definition: expression.h:132
Container & operator+=(DType s)
operator overload
Definition: expression.h:159
const TA & src_
source expression
Definition: expression.h:408
unary map expression op(src)
Definition: expression.h:405
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:225
scalar expression
Definition: expression.h:96
TernaryMapExp< OP, TA, TB, TC, DType,(ta|tb|tc|type::kMapper)> MakeExp(const Exp< TA, DType, ta > &item1, const Exp< TB, DType, tb > &item2, const Exp< TC, DType, tc > &item3)
make expression
Definition: expression.h:296
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const EType & exp
expression to be typecasted
Definition: expression.h:119
const SubType & self(void) const
Definition: expression.h:83
const TC & item3_
third operand
Definition: expression.h:287
Container & operator*=(DType s)
operator overload
Definition: expression.h:169
BinaryMapExp< op::div, TA, ScalarExp< MSHADOW_SCALAR_ >, MSHADOW_SCALAR_,(ta|type::kMapper)> operator/(const Exp< TA, MSHADOW_SCALAR_, ta > &lhs, const ScalarExp< MSHADOW_SCALAR_ > &rhs)
operator overload
Definition: expr_scalar-inl.h:118
const TA & lhs_
left operand
Definition: expression.h:228
const TA & lhs_
left operand
Definition: expression.h:338
overloaded + operator between half_t and bf16_t
Definition: base.h:327
DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > operator*(const DotExp< TA, TB, ltrans, rtrans, MSHADOW_SCALAR_ > &lhs, MSHADOW_SCALAR_ rhs)
dot operator def
Definition: expr_scalar-inl.h:41
the engine that dispatches simple operations
Definition: expr_engine-inl.h:461
Container & __assign(DType s)
operator overload
Definition: expression.h:179
Container & operator+=(const Exp< E, DType, etype > &exp)
implementation of operator+=
Definition: expression.h:193
const TransposeExp< Container, DType > T(void) const
transpose of a matrix
Definition: expression.h:155
static void Eval(RV *dst, const Exp< E, DType, type::kMapper > &exp)
Definition: expr_engine-inl.h:463
TransposeExp(const EType &e)
constructor
Definition: expression.h:137
Container & __assign(const Exp< E, DType, etype > &exp)
we cannot define container = container
Definition: expression.h:185