mxnet
broadcast.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_EXTENSION_BROADCAST_H_
27 #define MSHADOW_EXTENSION_BROADCAST_H_
28 #include "../extension.h"
29 namespace mshadow {
30 namespace expr {
// Broadcast1DExp: tensor expression that broadcasts a 1-D source expression
// into a dimdst-dimensional expression (constructed by broadcast() / repmat()).
// Template parameters:
//   SrcExp        - type of the source expression; broadcast() reads its
//                   extent through ShapeCheck<1, SrcExp>.
//   DType         - element type.
//   dimdst        - rank of the resulting expression.
//   dimdst_m_cast - dimdst minus the broadcast axis index dimcast; the Plan
//                   recovers dimcast as dimdst - dimdst_m_cast.
// NOTE(review): original line 41 (the struct header) is missing from this listing.
40 template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast>
42  public MakeTensorExp<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>,
43  SrcExp, dimdst, DType> {
// Source operand, held by const reference: the referenced expression must
// outlive this Broadcast1DExp.
45  const SrcExp &src_;
// Constructor: keeps a reference to src and records the output shape.
47  Broadcast1DExp(const SrcExp &src, Shape<dimdst> shape)
48  : src_(src) {
49  this->shape_ = shape;
50  }
51 };
52 
// BroadcastScalarExp: tensor expression that replicates a single-element
// source expression across a dimdst-dimensional shape (constructed by
// broadcast_scalar(), which requires the source length to be exactly 1).
// NOTE(review): original line 62 (the struct header) is missing from this listing.
61 template<typename SrcExp, typename DType, int dimdst>
63  public MakeTensorExp<BroadcastScalarExp<SrcExp, DType, dimdst>,
64  SrcExp, dimdst, DType> {
// Source operand, held by const reference: the referenced expression must
// outlive this BroadcastScalarExp.
66  const SrcExp &src_;
// Constructor: keeps a reference to src and records the output shape.
68  BroadcastScalarExp(const SrcExp &src, Shape<dimdst> shape)
69  : src_(src) {
70  this->shape_ = shape;
71  }
72 };
73 
// broadcast<dimcast>(src, shape): builds a Broadcast1DExp of shape `shape`
// from the 1-D expression src. Each output element takes the value of src at
// the element's coordinate on axis dimcast.
// Requires: src's extent (read via ShapeCheck<1, SrcExp>) equals shape[dimcast];
// otherwise CHECK_EQ fails with "broadcast, shape mismatch".
// Returns: Broadcast1DExp<SrcExp, DType, dimdst, dimdst - dimcast>.
// NOTE(review): original lines 88-89 (function name, parameters and the start
// of the compile-time dimension check ending on line 90) are missing here.
85 template<int dimcast, typename SrcExp, typename DType,
86  int etype, int dimdst>
87 inline Broadcast1DExp<SrcExp, DType, dimdst, dimdst - dimcast>
90  ::Error_Expression_Does_Not_Meet_Dimension_Req();
91  typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp;
// The single extent of src must match the target extent on the broadcast axis.
92  CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], shape[dimcast])
93  << "broadcast, shape mismatch";
94  return Broadcast1DExp<SrcExp, DType, dimdst,
95  dimdst - dimcast>(src.self(), shape);
96 }
97 
// broadcast_scalar(src, shape): builds a BroadcastScalarExp of shape `shape`
// that replicates the single element of src.
// Requires: src's extent (read via ShapeCheck<1, SrcExp>) is exactly 1;
// otherwise CHECK_EQ fails with
// "broadcast_scalar, source need to be scalar expression".
// NOTE(review): original lines 109-111 (return type, signature and the start
// of the compile-time dimension check ending on line 112) are missing here.
108 template<typename SrcExp, typename DType, int etype, int dimdst>
112  ::Error_Expression_Does_Not_Meet_Dimension_Req();
113  typedef ShapeCheck<1, SrcExp> ShapeCheckDim1SrcExp;
114  CHECK_EQ(ShapeCheckDim1SrcExp::Check(src.self())[0], 1U)
115  << "broadcast_scalar, source need to be scalar expression";
116  return BroadcastScalarExp<SrcExp, DType, dimdst>(src.self(), shape);
117 }
118 // shortcut functions
// repmat(src, nrow): 2-D expression with nrow rows, each row a copy of the
// 1-D expression src. Equivalent to broadcast<1>(src, Shape2(nrow, len(src))),
// i.e. src is laid along axis 1 (the last axis) and replicated along axis 0.
// NOTE(review): original lines 119-125 and 127-128 (documentation and the
// function signature) are missing from this listing.
126 template<typename SrcExp, typename DType, int etype>
129  return broadcast<1>
130  (src, Shape2(nrow, ShapeCheck<1, SrcExp>::Check(src.self())[0]));
131 }
132 //----------------------
133 // Execution plan
134 //----------------------
// Execution plan for Broadcast1DExp in the general case dimdst_m_cast != 1,
// i.e. the broadcast axis dimcast is not the last axis (the last-axis case,
// dimdst_m_cast == 1, has its own specialization).
// NOTE(review): original lines 139 (constructor signature), 143-144 and 151
// (presumably the src_ plan member) are missing from this listing.
135 template<typename SrcExp, typename DType, int dimdst, int dimdst_m_cast>
136 struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, dimdst_m_cast>, DType> {
137  public:
// Index of the broadcast axis, recovered from the template parameters.
138  static const int dimcast = dimdst - dimdst_m_cast;
140  : src_(MakePlan(e.src_)),
// Product of the extents of the axes after dimcast, excluding the last axis.
// Assumes ProdShape(begin, end) multiplies over the half-open range
// [begin, end) — TODO confirm against Shape::ProdShape.
141  ystride_(e.shape_.ProdShape(dimcast + 1, dimdst - 1)),
142  length_(e.shape_[dimcast]) {
145  }
// Maps output coordinate (y, x) to src element (y / ystride_) % length_.
// x is not used here because the broadcast axis is not the last axis.
// Presumably y is the index flattened over the leading dimdst - 1 axes — TODO confirm.
146  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
147  return src_.Eval(0, (y / ystride_) % length_);
148  }
149 
150  private:
// ystride_: see constructor; length_: extent of the broadcast axis dimcast.
152  const index_t ystride_, length_;
153 };
154 
// Execution plan for Broadcast1DExp when dimdst_m_cast == 1, i.e. the
// broadcast axis is the last axis (dimcast == dimdst - 1).
// NOTE(review): original lines 159 (constructor signature) and 166
// (presumably the src_ plan member) are missing from this listing.
156 template<typename SrcExp, typename DType, int dimdst>
157 struct Plan<Broadcast1DExp<SrcExp, DType, dimdst, 1>, DType>{
158  public:
160  : src_(MakePlan(e.src_)) {}
// The src element is selected by the column index x alone; y is ignored.
161  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
162  return src_.Eval(0, x);
163  }
164 
165  private:
167 };
168 
// Execution plan for BroadcastScalarExp: every output coordinate evaluates to
// the source element at (0, 0).
// NOTE(review): original lines 173 (constructor signature) and 180
// (presumably the src_ plan member) are missing from this listing.
170 template<typename SrcExp, typename DType, int dimdst>
171 struct Plan<BroadcastScalarExp<SrcExp, DType, dimdst>, DType>{
172  public:
174  : src_(MakePlan(e.src_)) {}
// Both y and x are ignored; the single source element is returned.
175  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
176  return src_.Eval(0, 0);
177  }
178 
179  private:
181 };
182 } // namespace expr
183 } // namespace mshadow
184 #endif // MSHADOW_EXTENSION_BROADCAST_H_
Broadcast1DExp< SrcExp, DType, 2, 1 > repmat(const expr::Exp< SrcExp, DType, etype > &src, index_t nrow)
an expression that replicates a 1-dimensional tensor nrow times
Definition: broadcast.h:128
broadcast a Tensor1D into a higher-dimensional Tensor. input: Tensor<Device,1>: ishape[0]; output: Tensor<D...
Definition: broadcast.h:41
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:161
Broadcast1DExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:47
Definition: expr_engine-inl.h:59
used to help static type checking
Definition: expr_engine-inl.h:331
broadcast a scalar into a higher-dimensional Tensor. input: Tensor<Device,1>: ishape = {1}; output: Tensor<...
Definition: broadcast.h:62
shape of a tensor
Definition: tensor.h:54
BroadcastScalarExp< SrcExp, DType, dimdst > broadcast_scalar(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a scalar tensor to the target dimension.
Definition: broadcast.h:110
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, 1 > &e)
Definition: broadcast.h:159
#define MSHADOW_XINLINE
Definition: base.h:223
int32_t index_t
type that will be used for index
Definition: base.h:336
const SrcExp & src_
source operand
Definition: broadcast.h:45
Plan(const BroadcastScalarExp< SrcExp, DType, dimdst > &e)
Definition: broadcast.h:173
const SrcExp & src_
source operand
Definition: broadcast.h:66
runtime shape-checking template: gets the shape of an expression and reports an error on shape mismatch ...
Definition: expr_engine-inl.h:365
Plan(const Broadcast1DExp< SrcExp, DType, dimdst, dimdst_m_cast > &e)
Definition: broadcast.h:139
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s0, index_t s1)
construct a two-dimensional shape; stride will equal s0
Definition: tensor.h:217
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:175
BroadcastScalarExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: broadcast.h:68
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
a general class that allows extensions that produce tensors of a given shape
Definition: expr_engine-inl.h:44
overloaded + operator between half_t and bf16_t
Definition: base.h:327
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: broadcast.h:146
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48
Broadcast1DExp< SrcExp, DType, dimdst, dimdst-dimcast > broadcast(const expr::Exp< SrcExp, DType, etype > &src, Shape< dimdst > shape)
an expression that replicates a 1-dimensional tensor in dimension dimcast
Definition: broadcast.h:88