mxnet
broadcast.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXTENSION_BROADCAST_H_
26 #define MSHADOW_EXTENSION_BROADCAST_H_
27 #include "../extension.h"
28 namespace mshadow {
29 namespace expr {
// Broadcast1DExp: expression node that broadcasts a 1-D source expression
// into a result of rank dimdst with the shape passed to the constructor.
// Template parameters:
//   SrcExp        - type of the 1-D source expression
//   DType         - element type
//   dimdst        - rank of the broadcast result
//   dimdst_m_cast - dimdst - dimcast (the `broadcast` factory instantiates it
//                   as `dimdst - dimcast`), where dimcast is the result
//                   dimension that is filled from the source.
// NOTE(review): source lines 40, 43 and 45 are absent from this extracted
// listing. Per the Doxygen index, the definition starts at line 40. The
// block is therefore not compilable as shown.
39 template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast>
41  public MakeTensorExp<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>,
42  SrcExp, dimdst, DType> {
// Source operand. It is held by const reference, not copied, so the source
// expression must outlive this node.
44  const SrcExp &src_;
// Records the source and sets the inherited shape_ to the requested output
// shape.
46  Broadcast1DExp(const SrcExp &src, Shape<dimdst> shape)
47  : src_(src) {
48  this->shape_ = shape;
49  }
50 };
51 
// BroadcastScalarExp: expression node that broadcasts a single-element source
// expression (its 1-D extent is checked to be 1 by `broadcast_scalar`) to a
// result of rank dimdst with the shape passed to the constructor.
// NOTE(review): source lines 61, 64 and 66 are absent from this extracted
// listing. Per the Doxygen index, the definition starts at line 61. The
// block is therefore not compilable as shown.
60 template<typename SrcExp, typename DType, int dimdst>
62  public MakeTensorExp<BroadcastScalarExp<SrcExp, DType, dimdst>,
63  SrcExp, dimdst, DType> {
// Source operand. It is held by const reference, not copied, so the source
// expression must outlive this node.
65  const SrcExp &src_;
// Records the source and sets the inherited shape_ to the requested output
// shape.
67  BroadcastScalarExp(const SrcExp &src, Shape<dimdst> shape)
68  : src_(src) {
69  this->shape_ = shape;
70  }
71 };
72 
// broadcast<dimcast>(src, shape): builds a Broadcast1DExp of rank dimdst
// whose dimension `dimcast` is filled from the 1-D expression `src`.
// Runtime check: the length of `src` must equal shape[dimcast]; otherwise
// CHECK_EQ reports "broadcast, shape mismatch".
// Returns Broadcast1DExp<SrcExp, DType, dimdst, dimdst - dimcast>.
// NOTE(review): source lines 87-88 are missing from this listing. Per the
// Doxygen index, line 87 holds the signature
//   broadcast(const expr::Exp<SrcExp, DType, etype> &src, Shape<dimdst> shape)
// Line 89 below is the tail of a compile-time dimension check whose
// condition is not visible here.
84 template<int dimcast, typename SrcExp, typename DType,
85  int etype, int dimdst>
86 inline Broadcast1DExp<SrcExp, DType, dimdst, dimdst - dimcast>
89  ::Error_Expression_Does_Not_Meet_Dimension_Req();
// Treat the source as rank 1 and compare its only extent with the target
// extent of the broadcast dimension.
90  typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp;
91  CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], shape[dimcast])
92  << "broadcast, shape mismatch";
93  return Broadcast1DExp<SrcExp, DType, dimdst,
94  dimdst - dimcast>(src.self(), shape);
95 }
96 
// broadcast_scalar(src, shape): builds a BroadcastScalarExp that replicates
// the single element of `src` over a result of rank dimdst with the given
// shape.
// Runtime check: the 1-D extent of `src` must be exactly 1; otherwise
// CHECK_EQ reports "broadcast_scalar, source need to be scalar expression".
// NOTE(review): source lines 109-110 are missing from this listing. Per the
// Doxygen index, line 109 holds the signature
//   broadcast_scalar(const expr::Exp<SrcExp, DType, etype> &src, Shape<dimdst> shape)
// Line 111 below is the tail of a compile-time dimension check whose
// condition is not visible here.
107 template<typename SrcExp, typename DType, int etype, int dimdst>
108 inline BroadcastScalarExp<SrcExp, DType, dimdst>
111  ::Error_Expression_Does_Not_Meet_Dimension_Req();
112  typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp;
// NOTE(review): the shape extent is presumably index_t, which the Doxygen
// index lists as int32_t. Comparing it with the unsigned literal 1U may
// trigger sign-compare warnings, depending on how CHECK_EQ is implemented.
// Confirm.
113  CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], 1U)
114  << "broadcast_scalar, source need to be scalar expression";
115  return BroadcastScalarExp<SrcExp, DType, dimdst>(src.self(), shape);
116 }
117 // short cut functions
// repmat(src, nrow): shorthand for
//   broadcast<1>(src, Shape2(nrow, <length of src>))
// The result is a 2-D expression with nrow rows. Dimension 1 is filled from
// the 1-D `src`, so every row repeats `src`. Because shape[1] is taken from
// `src` itself, the length check inside `broadcast` always holds here.
// NOTE(review): source line 127 is missing from this listing. Per the
// Doxygen index, it holds the signature
//   repmat(const expr::Exp<SrcExp, DType, etype> &src, index_t nrow)
125 template<typename SrcExp, typename DType, int etype>
126 inline Broadcast1DExp<SrcExp, DType, 2, 1>
128  return broadcast<1>
129  (src, Shape2(nrow, ShapeCheck<1, SrcExp>::Check(src.self())[0]));
130 }
131 //----------------------
132 // Execution plan
133 //----------------------
// Execution plan for Broadcast1DExp in the general case, where the broadcast
// dimension is not the last one. The dimdst_m_cast == 1 case (dimcast ==
// dimdst - 1) is served by the more specialized partial specialization
// below.
// Eval(y, x) ignores x and maps the row coordinate y back to an index along
// dimension dimcast: (y / ystride_) % length_.
// NOTE(review): source lines 138 (constructor signature
// `Plan(const Broadcast1DExp<...> &e)` per the Doxygen index), 142-143
// (constructor body) and 150 (declaration of src_) are missing from this
// listing.
134 template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast>
135 struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>, DType> {
136  public:
// Recover the broadcast dimension from the template parameter
// dimdst_m_cast == dimdst - dimcast.
137  static const int dimcast = dimdst - dimdst_m_cast;
139  : src_(MakePlan(e.src_)),
// ystride_: number of consecutive rows that share one source index.
// Assumes ProdShape(a, b) multiplies the extents of dimensions [a, b);
// confirm in tensor.h. The last dimension is presumably excluded because
// Eval's x indexes it while y indexes the flattened leading dimensions.
140  ystride_(e.shape_.ProdShape(dimcast + 1, dimdst - 1)),
// length_: extent of the broadcast dimension, i.e. the source length.
141  length_(e.shape_[dimcast]) {
144  }
// NOTE(review): ystride_ or length_ can be 0 only when the output has zero
// elements, in which case Eval is presumably never called. Confirm.
145  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
146  return src_.Eval(0, (y / ystride_) % length_);
147  }
148 
149  private:
151  const index_t ystride_, length_;
152 };
153 
// Execution plan for Broadcast1DExp when dimdst_m_cast == 1, i.e. the
// broadcast dimension is the last dimension (dimcast == dimdst - 1). This is
// the case `repmat` produces.
// Eval(y, x) ignores the row coordinate y and reads element x of the 1-D
// source.
// NOTE(review): source lines 158 (constructor signature
// `Plan(const Broadcast1DExp<SrcExp, DType, dimdst, 1> &e)` per the Doxygen
// index) and 165 (declaration of src_) are missing from this listing.
155 template<typename SrcExp, typename DType, int dimdst>
156 struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, 1>, DType>{
157  public:
159  : src_(MakePlan(e.src_)) {}
160  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
161  return src_.Eval(0, x);
162  }
163 
164  private:
166 };
167 
// Execution plan for BroadcastScalarExp: every output coordinate evaluates
// to the source's single element at (0, 0). Both y and x are ignored.
// NOTE(review): source lines 172 (constructor signature
// `Plan(const BroadcastScalarExp<SrcExp, DType, dimdst> &e)` per the Doxygen
// index) and 179 (declaration of src_) are missing from this listing.
169 template<typename SrcExp, typename DType, int dimdst>
170 struct Plan<BroadcastScalarExp<SrcExp, DType, dimdst>, DType>{
171  public:
173  : src_(MakePlan(e.src_)) {}
174  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
175  return src_.Eval(0, 0);
176  }
177 
178  private:
180 };
181 } // namespace expr
182 } // namespace mshadow
183 #endif // MSHADOW_EXTENSION_BROADCAST_H_
mshadow::expr::BroadcastScalarExp::BroadcastScalarExp
BroadcastScalarExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:67
mshadow::expr::BroadcastScalarExp::src_
const SrcExp & src_
source operand
Definition: broadcast.h:65
mshadow::expr::Exp::self
const SubType & self(void) const
Definition: expression.h:82
mshadow::expr::broadcast
Broadcast1DExp< SrcExp, DType, dimdst, dimdst - dimcast > broadcast(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a 1-dimensional tensor along dimension dimcast
Definition: broadcast.h:87
mshadow::expr::TypeCheckPass
used to help static type checking
Definition: expr_engine-inl.h:330
mshadow::expr::Broadcast1DExp::src_
const SrcExp & src_
source operand
Definition: broadcast.h:44
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::expr::ShapeCheck
runtime shape-checking template: gets the shape of an expression and reports an error on shape mismatch
Definition: expr_engine-inl.h:364
mshadow::expr::Broadcast1DExp::Broadcast1DExp
Broadcast1DExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:46
mshadow::expr::Plan< Broadcast1DExp< SrcExp, DType, dimdst, 1 >, DType >::Plan
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, 1 > &e)
Definition: broadcast.h:158
mshadow::expr::MakePlan
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
mshadow::expr::MakeTensorExp< Broadcast1DExp< SrcExp, DType, dimdst, dimdst_m_cast >, SrcExp, dimdst, DType >::shape_
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
mshadow::expr::broadcast_scalar
BroadcastScalarExp< SrcExp, DType, dimdst > broadcast_scalar(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a scalar tensor to the target dimension.
Definition: broadcast.h:109
mshadow::expr::BroadcastScalarExp
broadcast scalar into a higher dimension Tensor input: Tensor<Device,1>: ishape = {1} output: Tensor<...
Definition: broadcast.h:61
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::expr::repmat
Broadcast1DExp< SrcExp, DType, 2, 1 > repmat(const expr::Exp< SrcExp, DType, etype > &src, index_t nrow)
an expression that replicates a 1-dimensional tensor nrow times
Definition: broadcast.h:127
mshadow::expr::Plan
Definition: expr_engine-inl.h:58
mshadow::expr::Plan< Broadcast1DExp< SrcExp, DType, dimdst, 1 >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:160
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::expr::Plan< BroadcastScalarExp< SrcExp, DType, dimdst >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:174
mshadow::expr::MakeTensorExp
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
mshadow::Shape2
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two dimension shape, stride will equal s0
Definition: tensor.h:230
mshadow::expr::Plan< Broadcast1DExp< SrcExp, DType, dimdst, dimdst_m_cast >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:145
mshadow::Shape
shape of a tensor
Definition: tensor.h:64
mshadow::expr::Plan< Broadcast1DExp< SrcExp, DType, dimdst, dimdst_m_cast >, DType >::Plan
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, dimdst_m_cast > &e)
Definition: broadcast.h:138
mshadow::expr::Plan< BroadcastScalarExp< SrcExp, DType, dimdst >, DType >::Plan
Plan(const BroadcastScalarExp< SrcExp, DType, dimdst > &e)
Definition: broadcast.h:172
mshadow::expr::Broadcast1DExp
broadcast Tensor1D into a higher dimension Tensor input: Tensor<Device,1>: ishape[0] output: Tensor<D...
Definition: broadcast.h:40