26 #ifndef MSHADOW_EXTENSION_BROADCAST_H_ 27 #define MSHADOW_EXTENSION_BROADCAST_H_ 28 #include "../extension.h" 40 template<
typename SrcExp,
typename DType,
int dimdst,
int dimdst_m_cast>
42 public MakeTensorExp<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>,
43 SrcExp, dimdst, DType> {
61 template<
typename SrcExp,
typename DType,
int dimdst>
63 public MakeTensorExp<BroadcastScalarExp<SrcExp, DType, dimdst>,
64 SrcExp, dimdst, DType> {
85 template<
int dimcast,
typename SrcExp,
typename DType,
86 int etype,
int dimdst>
90 ::Error_Expression_Does_Not_Meet_Dimension_Req();
92 CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.
self())[0], shape[dimcast])
93 <<
"broadcast, shape mismatch";
95 dimdst - dimcast>(src.
self(), shape);
108 template<
typename SrcExp,
typename DType,
int etype,
int dimdst>
112 ::Error_Expression_Does_Not_Meet_Dimension_Req();
114 CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.
self())[0], 1U)
115 <<
"broadcast_scalar, source need to be scalar expression";
126 template<
typename SrcExp,
typename DType,
int etype>
135 template<
typename SrcExp,
typename DType,
int dimdst,
int dimdst_m_cast>
138 static const int dimcast = dimdst - dimdst_m_cast;
141 ystride_(e.
shape_.ProdShape(dimcast + 1, dimdst - 1)),
142 length_(e.
shape_[dimcast]) {
147 return src_.Eval(0, (y / ystride_) % length_);
152 const index_t ystride_, length_;
156 template<
typename SrcExp,
typename DType,
int dimdst>
162 return src_.Eval(0, x);
170 template<
typename SrcExp,
typename DType,
int dimdst>
176 return src_.Eval(0, 0);
184 #endif // MSHADOW_EXTENSION_BROADCAST_H_ Broadcast1DExp< SrcExp, DType, 2, 1 > repmat(const expr::Exp< SrcExp, DType, etype > &src, index_t nrow)
an expression that replicates a 1-dimensional tensor nrow times
Definition: broadcast.h:128
Broadcast a Tensor1D into a higher-dimensional Tensor. Input: Tensor<Device,1>: ishape[0]; output: Tensor<D...
Definition: broadcast.h:41
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:161
Broadcast1DExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:47
Definition: expr_engine-inl.h:59
used to help with static type checking
Definition: expr_engine-inl.h:331
Broadcast a scalar into a higher-dimensional Tensor. Input: Tensor<Device,1>: ishape = {1}; output: Tensor<...
Definition: broadcast.h:62
shape of a tensor
Definition: tensor.h:54
BroadcastScalarExp< SrcExp, DType, dimdst > broadcast_scalar(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a scalar tensor to the target dimension.
Definition: broadcast.h:110
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, 1 > &e)
Definition: broadcast.h:159
#define MSHADOW_XINLINE
Definition: base.h:223
int32_t index_t
type that will be used for index
Definition: base.h:336
const SrcExp & src_
source operand
Definition: broadcast.h:45
Plan(const BroadcastScalarExp< SrcExp, DType, dimdst > &e)
Definition: broadcast.h:173
const SrcExp & src_
source operand
Definition: broadcast.h:66
runtime shape-checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:365
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, dimdst_m_cast > &e)
Definition: broadcast.h:139
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two-dimensional shape; the stride will equal s0
Definition: tensor.h:217
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:175
BroadcastScalarExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:68
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:44
overloaded + operator between half_t and bf16_t
Definition: base.h:327
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:146
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48
Broadcast1DExp< SrcExp, DType, dimdst, dimdst-dimcast > broadcast(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a 1-dimensional tensor along dimension dimcast
Definition: broadcast.h:88