Go to the documentation of this file.
25 #ifndef MSHADOW_EXTENSION_BROADCAST_H_
26 #define MSHADOW_EXTENSION_BROADCAST_H_
27 #include "../extension.h"
39 template<
typename SrcExp,
typename DType,
int dimdst,
int dimdst_m_cast>
41 public MakeTensorExp<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>,
42 SrcExp, dimdst, DType> {
60 template<
typename SrcExp,
typename DType,
int dimdst>
62 public MakeTensorExp<BroadcastScalarExp<SrcExp, DType, dimdst>,
63 SrcExp, dimdst, DType> {
84 template<
int dimcast,
typename SrcExp,
typename DType,
85 int etype,
int dimdst>
86 inline Broadcast1DExp<SrcExp, DType, dimdst, dimdst - dimcast>
89 ::Error_Expression_Does_Not_Meet_Dimension_Req();
91 CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.
self())[0], shape[dimcast])
92 <<
"broadcast, shape mismatch";
94 dimdst - dimcast>(src.
self(), shape);
107 template<
typename SrcExp,
typename DType,
int etype,
int dimdst>
108 inline BroadcastScalarExp<SrcExp, DType, dimdst>
111 ::Error_Expression_Does_Not_Meet_Dimension_Req();
113 CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.
self())[0], 1U)
114 <<
"broadcast_scalar, source need to be scalar expression";
125 template<
typename SrcExp,
typename DType,
int etype>
126 inline Broadcast1DExp<SrcExp, DType, 2, 1>
134 template<
typename SrcExp,
typename DType,
int dimdst,
int dimdst_m_cast>
137 static const int dimcast = dimdst - dimdst_m_cast;
140 ystride_(e.shape_.ProdShape(dimcast + 1, dimdst - 1)),
141 length_(e.shape_[dimcast]) {
146 return src_.Eval(0, (y / ystride_) % length_);
151 const index_t ystride_, length_;
155 template<
typename SrcExp,
typename DType,
int dimdst>
161 return src_.Eval(0, x);
169 template<
typename SrcExp,
typename DType,
int dimdst>
175 return src_.Eval(0, 0);
183 #endif // MSHADOW_EXTENSION_BROADCAST_H_
BroadcastScalarExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:67
const SrcExp & src_
source operand
Definition: broadcast.h:65
const SubType & self(void) const
Definition: expression.h:82
Broadcast1DExp< SrcExp, DType, dimdst, dimdst - dimcast > broadcast(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a 1-dimensional tensor along dimension dimcast
Definition: broadcast.h:87
used to help static type checking
Definition: expr_engine-inl.h:330
const SrcExp & src_
source operand
Definition: broadcast.h:44
#define MSHADOW_XINLINE
Definition: base.h:228
runtime shape-checking template: gets the shape of an expression and reports an error if the shapes mismatch
Definition: expr_engine-inl.h:364
Broadcast1DExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:46
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, 1 > &e)
Definition: broadcast.h:158
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
BroadcastScalarExp< SrcExp, DType, dimdst > broadcast_scalar(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a scalar tensor to the target dimension.
Definition: broadcast.h:109
broadcast a scalar into a higher-dimension Tensor. input: Tensor<Device,1>: ishape = {1}; output: Tensor<...
Definition: broadcast.h:61
int32_t index_t
type used for indexing
Definition: base.h:328
Broadcast1DExp< SrcExp, DType, 2, 1 > repmat(const expr::Exp< SrcExp, DType, etype > &src, index_t nrow)
an expression that replicates a 1-dimensional tensor nrow times
Definition: broadcast.h:127
Definition: expr_engine-inl.h:58
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:160
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
overloaded + operator between half_t and bf16_t
Definition: base.h:319
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:174
a general class that allows extensions that produce tensors of a given shape
Definition: expr_engine-inl.h:43
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two-dimensional shape; the stride will equal s0
Definition: tensor.h:230
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:145
shape of a tensor
Definition: tensor.h:64
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, dimdst_m_cast > &e)
Definition: broadcast.h:138
Plan(const BroadcastScalarExp< SrcExp, DType, dimdst > &e)
Definition: broadcast.h:172
broadcast a Tensor1D into a higher-dimension Tensor. input: Tensor<Device,1>: ishape[0]; output: Tensor<D...
Definition: broadcast.h:40