mxnet
broadcast.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_BROADCAST_H_
8 #define MSHADOW_EXTENSION_BROADCAST_H_
9 #include "../extension.h"
10 namespace mshadow {
11 namespace expr {
21 template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast>
23  public MakeTensorExp<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>,
24  SrcExp, dimdst, DType> {
26  const SrcExp &src_;
28  Broadcast1DExp(const SrcExp &src, Shape<dimdst> shape)
29  : src_(src) {
30  this->shape_ = shape;
31  }
32 };
33 
42 template<typename SrcExp, typename DType, int dimdst>
44  public MakeTensorExp<BroadcastScalarExp<SrcExp, DType, dimdst>,
45  SrcExp, dimdst, DType> {
47  const SrcExp &src_;
49  BroadcastScalarExp(const SrcExp &src, Shape<dimdst> shape)
50  : src_(src) {
51  this->shape_ = shape;
52  }
53 };
54 
66 template<int dimcast, typename SrcExp, typename DType,
67  int etype, int dimdst>
68 inline Broadcast1DExp<SrcExp, DType, dimdst, dimdst - dimcast>
71  ::Error_Expression_Does_Not_Meet_Dimension_Req();
72  typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp;
73  CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], shape[dimcast])
74  << "broadcast, shape mismatch";
75  return Broadcast1DExp<SrcExp, DType, dimdst,
76  dimdst - dimcast>(src.self(), shape);
77 }
78 
89 template<typename SrcExp, typename DType, int etype, int dimdst>
93  ::Error_Expression_Does_Not_Meet_Dimension_Req();
94  typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp;
95  CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], 1U)
96  << "broadcast_scalar, source need to be scalar expression";
98 }
99 // short cut functions
107 template<typename SrcExp, typename DType, int etype>
110  return broadcast<1>
111  (src, Shape2(nrow, ShapeCheck<1, SrcExp>::Check(src.self())[0]));
112 }
113 //----------------------
114 // Execution plan
115 //----------------------
116 template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast>
117 struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>, DType> {
118  public:
119  static const int dimcast = dimdst - dimdst_m_cast;
121  : src_(MakePlan(e.src_)),
122  ystride_(e.shape_.ProdShape(dimcast + 1, dimdst - 1)),
123  length_(e.shape_[dimcast]) {
126  }
127  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
128  return src_.Eval(0, (y / ystride_) % length_);
129  }
130 
131  private:
133  const index_t ystride_, length_;
134 };
135 
137 template<typename SrcExp, typename DType, int dimdst>
138 struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, 1>, DType>{
139  public:
141  : src_(MakePlan(e.src_)) {}
142  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
143  return src_.Eval(0, x);
144  }
145 
146  private:
148 };
149 
151 template<typename SrcExp, typename DType, int dimdst>
152 struct Plan<BroadcastScalarExp<SrcExp, DType, dimdst>, DType>{
153  public:
155  : src_(MakePlan(e.src_)) {}
156  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
157  return src_.Eval(0, 0);
158  }
159 
160  private:
162 };
163 } // namespace expr
164 } // namespace mshadow
165 #endif // MSHADOW_EXTENSION_BROADCAST_H_
Broadcast1DExp< SrcExp, DType, 2, 1 > repmat(const expr::Exp< SrcExp, DType, etype > &src, index_t nrow)
an expression that replicates a 1-dimensional tensor nrow times
Definition: broadcast.h:109
Broadcast a Tensor1D into a higher-dimensional Tensor. Input: Tensor<Device,1> with shape ishape[0]; output: Tensor<D...
Definition: broadcast.h:22
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:142
Broadcast1DExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:28
Definition: expr_engine-inl.h:40
used to help with static type checking
Definition: expr_engine-inl.h:312
Broadcast a scalar into a higher-dimensional Tensor. Input: Tensor<Device,1> with ishape = {1}; output: Tensor<...
Definition: broadcast.h:43
shape of a tensor
Definition: tensor.h:35
BroadcastScalarExp< SrcExp, DType, dimdst > broadcast_scalar(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a scalar tensor to the target dimension.
Definition: broadcast.h:91
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, 1 > &e)
Definition: broadcast.h:140
#define MSHADOW_XINLINE
Definition: base.h:204
int32_t index_t
type used for indices
Definition: base.h:291
const SrcExp & src_
source operand
Definition: broadcast.h:26
Plan(const BroadcastScalarExp< SrcExp, DType, dimdst > &e)
Definition: broadcast.h:154
const SrcExp & src_
source operand
Definition: broadcast.h:47
runtime shape-checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:346
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, dimdst_m_cast > &e)
Definition: broadcast.h:120
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two dimension shape, stride will equal s0
Definition: tensor.h:198
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:156
BroadcastScalarExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:49
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
a general base class that allows extensions that produce tensors of a given shape
Definition: expr_engine-inl.h:25
namespace for mshadow
Definition: base.h:282
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:127
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:29
Broadcast1DExp< SrcExp, DType, dimdst, dimdst-dimcast > broadcast(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a 1-dimensional tensor along dimension dimcast
Definition: broadcast.h:69