26 #ifndef MSHADOW_EXPR_ENGINE_INL_H_ 27 #define MSHADOW_EXPR_ENGINE_INL_H_ 30 #include "./logging.h" 43 template<
typename SubType,
typename SrcExp,
int dim,
typename DType>
45 :
public Exp<MakeTensorExp<SubType, SrcExp, dim, DType>,
46 DType, type::kChainer> {
51 return *
static_cast<const SubType*
>(
this);
58 template<
typename ExpType,
typename DType>
68 template <
typename Device,
int dim,
typename DType>
72 : dptr_(t.dptr_), stride_(t.stride_) {}
75 return dptr_[y * stride_ + x];
79 return dptr_[y * stride_ + x];
87 template <
typename Device,
typename DType>
102 template<
typename DType>
114 template<
typename DstDType,
typename SrcDType,
115 typename EType,
int etype>
120 return DstDType(src_.Eval(y, x));
128 template<
typename OP,
typename TA,
typename TB,
typename TC,
int etype,
typename DType>
133 : item1_(item1), item2_(item2), item3_(item3) {}
135 return OP::Map(item1_.Eval(y, x), item2_.Eval(y, x), item3_.Eval(y, x));
144 template<
typename OP,
typename TA,
typename TB,
int etype,
typename DType>
148 : lhs_(lhs), rhs_(rhs) {}
150 return OP::Map(lhs_.Eval(y, x), rhs_.Eval(y, x));
158 template<
typename OP,
typename TA,
int etype,
typename DType>
163 return OP::Map(src_.Eval(y, x));
170 template<
typename SubType,
typename SrcExp,
int dim,
typename DType>
175 return src_.Eval(y, x);
182 template<
typename EType,
typename DType>
187 return src_.Eval(x, y);
196 template<
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
200 template<
typename OP,
typename TA,
typename TB,
typename TC,
typename DType,
int etype>
204 template<
typename DType>
209 template<
typename DstDType,
typename SrcDType,
typename EType,
int etype>
212 return Plan<TypecastExp<DstDType, SrcDType, EType, etype>, DstDType>(
MakePlan(e.
exp));
215 template<
typename T,
typename DType>
220 template<
typename T,
typename DType>
223 return Plan<TransposeExp<T, DType>, DType>(
MakePlan(e.
exp));
226 template<
typename T,
typename SrcExp,
int dim,
typename DType>
232 template<
typename OP,
typename TA,
typename DType,
int etype>
235 return Plan<UnaryMapExp<OP, TA, DType, etype>, DType>(
MakePlan(e.
src_));
238 template<
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
239 inline Plan<BinaryMapExp<OP, TA, TB, DType, etype>, DType>
241 return Plan<BinaryMapExp<OP, TA, TB, DType, etype>,
246 template<
typename OP,
typename TA,
typename TB,
typename TC,
typename DType,
int etype>
247 inline Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>, DType>
249 return Plan<TernaryMapExp<OP, TA, TB, TC, DType, etype>,
264 static const int kDim = -1;
265 static const int kDevMask = 0;
267 template<
typename DType>
269 static const int kDim = 0;
270 static const int kDevMask = 0xffff;
272 template<
typename E,
typename DType>
277 template<
typename DstDType,
typename SrcDType,
typename EType,
int etype>
282 template<
typename Device,
int dim,
typename DType>
284 static const int kDim = dim;
285 static const int kDevMask = Device::kDevMask;
287 template<
typename T,
typename SrcExp,
int dim,
typename DType>
290 static const int kDim = kDimSrc >= 0 ? dim : -1;
293 template<
typename OP,
typename TA,
typename DType,
int etype>
298 template<
typename OP,
typename TA,
typename TB,
typename DType,
int etype>
302 static const int kDim = (kDimLhs >= 0 && kDimRhs >= 0) ?\
305 ((kDimRhs == 0 || kDimLhs == kDimRhs) ? kDimLhs : -1)) : -1;
308 template<
typename OP,
typename TA,
typename TB,
typename TC,
typename DType,
int etype>
313 static const int kDim = kDimItem1;
318 template<
typename Device,
int dim,
typename DType,
typename E>
325 static const bool kMapPass = (kExpDim == 0 || kExpDim == dim) && kDevPass;
327 static const bool kRedPass = (kExpDim > dim) && kDevPass;
345 template<
typename Device,
typename E>
349 template<
int dim,
typename Device,
typename DType>
364 template<
int dim,
typename E>
368 template<
int dim,
typename DType>
373 for (
int i = 0; i < dim; ++i) {
379 template<
int dim,
typename DstDType,
typename SrcDType,
typename EType,
int etype>
386 template<
int dim,
typename E,
typename DType>
391 std::swap(s[0], s[1]);
395 template<
int dim,
typename Device,
typename DType>
401 template<
int dim,
typename SrcExp,
typename T,
typename DType>
408 template<
int dim,
typename OP,
typename TA,
typename DType,
int etype>
416 template<
int dim,
typename OP,
typename TA,
typename TB,
417 typename DType,
int etype>
423 if (shape1[0] == 0)
return shape2;
424 if (shape2[0] == 0)
return shape1;
425 CHECK_EQ(shape1, shape2) <<
"BinaryMapExp: Shapes of operands are not the same, " <<
426 "Shape1=" << shape1 <<
", Shape2=" << shape2;
431 template<
int dim,
typename OP,
typename TA,
typename TB,
typename TC,
432 typename DType,
int etype>
439 bool same = (shape1 == shape2) && (shape2 == shape3);
440 CHECK(same) <<
"TernaryMapExp: Shapes of operands are not the same, " <<
441 "Shape1=" << shape1 <<
", Shape2=" << shape2 <<
", Shape3=" << shape3;
455 template<
typename SV,
typename RV,
typename E,
typename DType>
457 inline static void Eval(RV *dst,
const E &exp);
460 template<
typename SV,
typename RV,
typename DType>
463 inline static void Eval(RV *dst,
465 MapExp<SV>(dst, exp);
468 inline static void Eval(RV *dst,
470 MapExp<SV>(dst, exp);
473 inline static void Eval(RV *dst,
475 MapExp<SV>(dst, exp);
478 inline static void Eval(RV *dst,
483 template<
typename SV,
typename Device,
int dim,
int ldim,
484 int rdim,
bool ltrans,
bool rtrans,
typename DType>
486 Tensor<Device, dim, DType>,
488 Tensor<Device, rdim, DType>,
489 ltrans, rtrans, DType>,
494 ltrans, rtrans, DType> &exp) {
496 ltrans, rtrans, DType>::Eval(dst, exp.lhs_, exp.rhs_, exp.scale_);
501 #endif // MSHADOW_EXPR_ENGINE_INL_H_ Plan(const Plan< TA, DType > &src)
Definition: expr_engine-inl.h:161
static Shape< dim > Check(const UnaryMapExp< OP, TA, DType, etype > &t)
Definition: expr_engine-inl.h:410
static void Eval(RV *dst, const Exp< E, DType, type::kRValue > &exp)
Definition: expr_engine-inl.h:473
static Shape< dim > Check(const MakeTensorExp< T, SrcExp, dim, DType > &t)
Definition: expr_engine-inl.h:404
static Shape< dim > Check(const Tensor< Device, dim, DType > &t)
Definition: expr_engine-inl.h:397
ScalarExp< DType > scalar(DType s)
create a scalar expression
Definition: expression.h:104
static Shape< dim > Check(const BinaryMapExp< OP, TA, TB, DType, etype > &t)
Definition: expr_engine-inl.h:420
Definition: expr_engine-inl.h:59
Plan(const Tensor< Device, 1, DType > &t)
Definition: expr_engine-inl.h:90
used to help static type check
Definition: expr_engine-inl.h:331
template to do type check
Definition: expr_engine-inl.h:319
const TB & rhs_
right operand
Definition: expression.h:340
Plan(const Tensor< Device, dim, DType > &t)
Definition: expr_engine-inl.h:71
Shape< dimension > shape_
shape of the tensor
Definition: tensor.h:437
ternary map expression
Definition: expression.h:280
static void Error_All_Tensor_in_Exp_Must_Have_Same_Type(void)
Definition: expr_engine-inl.h:337
binary map expression lhs [op] rhs
Definition: expression.h:335
Plan(DType scalar)
Definition: expr_engine-inl.h:105
mshadow::expr::ExpComplexEngine< SV, Tensor< Device, dim, DType >, DotExp< Tensor< Device, ldim, DType >, Tensor< Device, rdim, DType >, ltrans, rtrans, DType >, DType >::Eval static void Eval(Tensor< Device, dim, DType > *dst, const DotExp< Tensor< Device, ldim, DType >, Tensor< Device, rdim, DType >, ltrans, rtrans, DType > &exp)
Definition: expr_engine-inl.h:491
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:174
static void Eval(RV *dst, const E &exp)
base class of all rvalues
Definition: expression.h:149
Definition: dot_engine-inl.h:71
DType scalar_
scalar value
Definition: expression.h:98
MSHADOW_XINLINE DType & REval(index_t y, index_t x)
Definition: expr_engine-inl.h:74
const EType & exp
expression to be transposed
Definition: expression.h:135
static void Error_TypeCheck_Not_Pass_For_Reduce_Exp(void)
Definition: expr_engine-inl.h:338
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:134
MSHADOW_XINLINE const DType & Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:78
static Shape< dim > Check(const E &t)
header file of tensor data structure and functions. This lib requires explicit memory allocation and d...
static void Eval(RV *dst, const Exp< E, DType, type::kChainer > &exp)
Definition: expr_engine-inl.h:468
#define MSHADOW_XINLINE
Definition: base.h:223
const TB & item2_
second operand
Definition: expression.h:285
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:263
Definition: expr_engine-inl.h:346
definitions of abstract expressions and expression templates
static void Eval(RV *dst, const Exp< E, DType, type::kComplex > &exp)
Definition: expr_engine-inl.h:478
static Shape< dim > Check(const TernaryMapExp< OP, TA, TB, TC, DType, etype > &t)
Definition: expr_engine-inl.h:435
static void Error_Expression_Does_Not_Meet_Dimension_Req(void)
Definition: expr_engine-inl.h:339
MSHADOW_XINLINE const DType & Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:94
int32_t index_t
type that will be used for index
Definition: base.h:336
Plan(const Plan< EType, DType > &src)
Definition: expr_engine-inl.h:185
static Shape< dim > Check(const TypecastExp< DstDType, SrcDType, EType, etype > &exp)
Definition: expr_engine-inl.h:382
Plan(const Plan< TA, DType > &item1, const Plan< TB, DType > &item2, const Plan< TC, DType > &item3)
Definition: expr_engine-inl.h:131
const TA & item1_
first operand
Definition: expression.h:283
typecast expression, cast the type of elements
Definition: expression.h:115
static Shape< dim > Check(const ScalarExp< DType > &exp)
Definition: expr_engine-inl.h:370
runtime shape checking template: get the shape of an expression, report error if shape mismatch ...
Definition: expr_engine-inl.h:365
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:162
represent a transpose expression of a container
Definition: expression.h:132
some engine that evaluates complex expressions
Definition: expr_engine-inl.h:456
const TA & src_
source expression
Definition: expression.h:408
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:186
unary map expression op(src)
Definition: expression.h:405
matrix multiplication expression dot(lhs[.T], rhs[.T])
Definition: expression.h:225
Plan(const Plan< EType, SrcDType > &src)
Definition: expr_engine-inl.h:118
scalar expression
Definition: expression.h:96
Plan(const Plan< SubType, DType > &src)
Definition: expr_engine-inl.h:173
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:106
const EType & exp
expression to be typecasted
Definition: expression.h:119
const Container & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
const TC & item3_
third operand
Definition: expression.h:287
MSHADOW_XINLINE DstDType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:119
const SubType & real_self(void) const
true self of subtype
Definition: expr_engine-inl.h:50
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: expr_engine-inl.h:149
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:44
const TA & lhs_
left operand
Definition: expression.h:338
overloaded + operator between half_t and bf16_t
Definition: base.h:327
static Shape< dim > Check(const TransposeExp< E, DType > &e)
Definition: expr_engine-inl.h:388
the engine that dispatches simple operations
Definition: expr_engine-inl.h:461
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48
general tensor
Definition: tensor.h:421
static void Eval(RV *dst, const Exp< E, DType, type::kMapper > &exp)
Definition: expr_engine-inl.h:463
Stream< Device > * stream_
stream where the computation lies; stream is a device dependency concept where each computation ...
Definition: tensor.h:447
definitions of how Matrix Multiplications can be evaluated
Plan(const Plan< TA, DType > &lhs, const Plan< TB, DType > &rhs)
Definition: expr_engine-inl.h:147
MSHADOW_XINLINE DType & REval(index_t y, index_t x)
Definition: expr_engine-inl.h:91
static Stream< Device > * Get(const Tensor< Device, dim, DType > &t)
Definition: expr_engine-inl.h:351
computation stream structure, used for asynchronous computations
Definition: tensor.h:384