mxnet
broadcast_with_axis.h
Go to the documentation of this file.
1 
7 #ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
8 #define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
9 
10 #include <vector>
11 #include "../extension.h"
12 
13 namespace mshadow {
14 namespace expr {
15 
// Expression that broadcasts a source expression along a single axis. With dimdst == dimsrc + 1
// (keepdim off) a new dimension of extent `size` is inserted after `axis`; with dimdst == dimsrc
// (keepdim on) the existing extent-1 dimension `axis` is stretched to `size`.
// shape_/trailing_/size_/last_/dst_last_ are consumed by the Plan specialisation below.
23 template<typename SrcExp, typename DType, int dimsrc, int dimdst>
25  public MakeTensorExp<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>,
26  SrcExp, dimdst, DType> {
28  const SrcExp &src_;
38  BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size)
39  : src_(src), size_(size) {
40  bool keepdim = (dimsrc == dimdst);  // equal ranks => broadcast the axis in place
42  this->trailing_ = 1;  // accumulates the product of source dims after the broadcast position
43 
44  if (!keepdim) {  // insert a new dim of extent size_ right after `axis` (axis == -1 => in front)
45  CHECK(dimsrc > axis && axis >= -1) << "broadcast axis (no keepdim) out of bound, " <<
46  "axis must be between -1 and " << dimsrc - 1 << ", given=" << axis << ".";
47  for (int i = 0; i <= axis; ++i) {
48  this->shape_[i] = src_shape[i];
49  }
50  this->shape_[axis + 1] = size_;
51  for (int i = axis + 1; i < dimsrc; ++i) {  // source dims after axis shift right by one
52  this->trailing_ *= src_shape[i];
53  this->shape_[i + 1] = src_shape[i];
54  }
55  } else {  // keepdim: src_shape[axis] must be 1 and is replaced by size_
56  CHECK(dimdst > axis && axis >= 0) << "broadcast axis (keepdim) out of bound, " <<
57  "axis must be between 0 and " << dimdst - 1 << ", given=" << axis << ".";
58  CHECK_EQ(src_shape[axis], 1U) << "Size of the dimension of the broadcasting axis must be 1" <<
59  " when keepdim is on, src_shape[" << axis << "]=" << src_shape[axis] << ".";
60  for (int i = 0; i <= axis - 1; ++i) {
61  this->shape_[i] = src_shape[i];
62  }
63  this->shape_[axis] = size_;
64  for (int i = axis + 1; i < dimdst; ++i) {
65  this->trailing_ *= src_shape[i];
66  this->shape_[i] = src_shape[i];
67  }
68  }
69 
70  this->last_ = src_shape[dimsrc - 1];  // innermost source extent
71  this->dst_last_ = this->shape_[dimdst - 1];  // innermost destination extent
72  }
73 }; // struct BroadcastWithAxisExp
74 
// Non-keepdim broadcast: builds a BroadcastWithAxisExp whose destination rank is the source
// rank + 1, inserting a new dimension of extent `size` after `axis` (axis == -1 inserts it in
// front, as enforced by the BroadcastWithAxisExp constructor's bound check).
81 template<typename SrcExp, typename DType, int etype>
84 broadcast_with_axis(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) {
85  return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
86  ExpInfo<SrcExp>::kDim + 1>(src.self(), axis, size);
87 }
88 
// Keepdim broadcast: source and destination ranks are equal. The source extent at `axis`
// must be 1 (checked by the BroadcastWithAxisExp constructor) and is stretched to `size`.
95 template<typename SrcExp, typename DType, int etype>
96 inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
98  broadcast_keepdim(const Exp<SrcExp, DType, etype> &src, const int axis, const index_t size) {
99  return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
100  ExpInfo<SrcExp>::kDim>(src.self(), axis, size);
101 }
102 
// Expression that broadcasts a source expression along several axes at once while keeping the
// rank (dimsrc) unchanged. Each listed axis must have extent 1 in the source, and the axes must
// be strictly increasing. sizes_[i] / trailings_[i] are indexed by position in `axes` (the i-th
// broadcast axis), not by dimension; they are consumed by the Plan specialisation below.
111 template<typename SrcExp, typename DType, int dimsrc>
113  public MakeTensorExp<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>,
114  SrcExp, dimsrc, DType> {
116  const SrcExp &src_;
128  template<typename TShape>
129  BroadcastWithMultiAxesExp(const SrcExp &src, const TShape& axes, const TShape& sizes)
130  : src_(src) {
132  CHECK(axes.ndim() == sizes.ndim()) << "ndim of axes and sizes must be equal.";
133  this->axesnum_ = axes.ndim();
134  CHECK(this->axesnum_ <= dimsrc) << "Number of broadcasting axes must not exceed "
135  "the source ndim, number of axes=" << this->axesnum_ << " dimsrc=" << dimsrc;
136  for (index_t i = 0; i < this->axesnum_; i++) {  // NOTE(review): only upper bound of axes[i] checked
137  CHECK(dimsrc > axes[i]) << "broadcast axis (keepdim) out of bound, " <<
138  "all axes must be between 0 and " << dimsrc - 1 << ", given axes[" << i << "] = " << axes[i]
139  << ".";
140  CHECK_EQ(src_shape[axes[i]], 1U) << "Size of the dimension of the broadcasting axis must be 1"
141  << ", src_shape[" << axes[i] << "]=" << src_shape[axes[i]] << ".";
142  if (i < this->axesnum_ - 1) {
143  CHECK(axes[i] < axes[i + 1]) << "The given axes must be in increasing order.";
144  }
145  }
146  for (index_t i = 0; i < dimsrc; i++) {  // start from the source shape with no broadcasting
147  this->shape_[i] = src_shape[i];
148  this->sizes_[i] = 1;
149  this->trailings_[i] = 1;
150  }
151  for (index_t i = 0; i < this->axesnum_; i++) {  // apply the requested target extents
152  this->shape_[axes[i]] = sizes[i];
153  this->sizes_[i] = sizes[i];
154  }
155  for (index_t i = 0; i < this->axesnum_; i++) {  // trailings_[i]: product of dst dims after axes[i]
156  this->trailings_[i] = 1;
157  for (index_t j = axes[i] + 1; j < dimsrc; ++j) {
158  this->trailings_[i] *= this->shape_[j];
159  }
160  }
161  this->last_ = src_shape[dimsrc - 1];  // innermost source extent
162  this->dst_last_ = this->shape_[dimsrc - 1];  // innermost destination extent
163  }
164 }; // struct BroadcastWithMultiAxesExp
165 
// Factory for a multi-axis broadcast: per its declared return type it yields a
// BroadcastWithMultiAxesExp of the same rank as `src`, broadcasting each axis in `axes`
// to the matching entry of `sizes`.
// NOTE(review): the return statement is elided from this listing; presumably it forwards
// (src.self(), axes, sizes) to the BroadcastWithMultiAxesExp constructor — confirm.
176 template<typename SrcExp, typename DType, int etype, typename TShape>
179 const TShape &axes, const TShape &sizes) {
181 }
182 
// Broadcasts `src` to `target_shape`, which must have the same rank as `src`. Every dimension
// whose extent differs between source and target must have extent 1 in the source. Those
// dimensions (in increasing order) and their target extents become the axes/sizes of a
// multi-axis broadcast.
// NOTE(review): the src_shape declaration (original line 199) and the return statement
// (original line 210) are elided from this listing.
193 template<typename SrcExp, typename DType, int etype, typename TShape>
195 broadcast_to(const Exp<SrcExp, DType, etype> &src, const TShape &target_shape) {
196  static const size_t dimsrc = ExpInfo<SrcExp>::kDim;
197  CHECK_EQ(target_shape.ndim(), dimsrc);
198  std::vector<index_t> axes_vec, sizes_vec;  // collected broadcast axes and their target extents
200  for (size_t i = 0; i < dimsrc; ++i) {
201  if (src_shape[i] != target_shape[i]) {
202  CHECK_EQ(src_shape[i], 1U) << "broadcasting axis must have size 1, received shape="
203  << src_shape << " target_shape=" << target_shape;
204  axes_vec.push_back(i);
205  sizes_vec.push_back(target_shape[i]);
206  }
207  }
208  TShape axes = TShape(axes_vec.begin(), axes_vec.end());
209  TShape sizes = TShape(sizes_vec.begin(), sizes_vec.end());
211 }
212 
213 //----------------------
214 // Execution plan
215 //----------------------
// Execution plan for BroadcastWithAxisExp: maps a destination (row, col) coordinate back to
// the source element it was broadcast from, by removing the broadcast axis from the flattened
// destination index.
216 template<typename SrcExp, typename DType, int dimsrc, int dimdst>
217 struct Plan<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>, DType> {
218  public:
221  trailing_(e.trailing_), size_(e.size_), last_(e.last_) {}
222  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
223  index_t x = (i * dst_last_ + j) / trailing_ / size_;  // index over the dims before the broadcast axis
224  index_t y = (i * dst_last_ + j) % trailing_;  // offset within the dims after the broadcast axis
225  index_t z = x * trailing_ + y;  // flat source index (broadcast axis dropped)
226  return src_.Eval(z / last_, z % last_);  // convert back to the source's (row, col) form
227  }
228 
229  private:
232 };
233 
// Execution plan for BroadcastWithMultiAxesExp: starting from the flattened destination index,
// collapses each broadcast axis in turn (in the increasing order enforced by the expression's
// constructor) and evaluates the source at the resulting flat index.
234 template<typename SrcExp, typename DType, int dimsrc>
235 struct Plan<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>, DType> {
236  public:
238  : src_(MakePlan(e.src_)), dst_last_(e.dst_last_), last_(e.last_), axesnum_(e.axesnum_),
239  trailings_(e.trailings_), sizes_(e.sizes_) {}
240  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
241  index_t indx = i * dst_last_ + j;  // flattened destination index
242  for (index_t p = 0; p < dimsrc; ++p) {  // bounded by dimsrc; stops after axesnum_ axes
243  if (p >= axesnum_) {
244  break;
245  }
246  indx = (indx / trailings_[p] / sizes_[p]) * trailings_[p] + (indx % trailings_[p]);  // drop axis p
247  }
248  return src_.Eval(indx / last_, indx % last_);
249  }
250 
251  private:
253  const index_t dst_last_, last_, axesnum_;
254  const Shape<dimsrc> trailings_, sizes_;
255 };
256 } // namespace expr
257 } // namespace mshadow
258 #endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
Plan(const BroadcastWithAxisExp< SrcExp, DType, dimsrc, dimdst > &e)
Definition: broadcast_with_axis.h:219
Definition: expr_engine-inl.h:40
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: broadcast_with_axis.h:222
index_t size_
new dimension of the broadcasting axis
Definition: broadcast_with_axis.h:34
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: broadcast_with_axis.h:240
BroadcastWithMultiAxesExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > broadcast_to(const Exp< SrcExp, DType, etype > &src, const TShape &target_shape)
Broadcasting the tensor to the target shape; every dimension whose size differs from the target must be 1 in the original tensor.
Definition: broadcast_with_axis.h:195
BroadcastWithAxisExp< SrcExp, DType, ExpInfo< SrcExp >::kDim, ExpInfo< SrcExp >::kDim+1 > broadcast_with_axis(const Exp< SrcExp, DType, etype > &src, const int axis, const index_t size)
Broadcasting the tensor after given axis.
Definition: broadcast_with_axis.h:84
BroadcastWithMultiAxesExp(const SrcExp &src, const TShape &axes, const TShape &sizes)
Definition: broadcast_with_axis.h:129
index_t axesnum_
number of broadcasting axes
Definition: broadcast_with_axis.h:120
Shape< dimsrc > sizes_
new dimension of the broadcasting axes
Definition: broadcast_with_axis.h:124
static Shape< dim > Check(const E &t)
const SrcExp & src_
data operand
Definition: broadcast_with_axis.h:28
index_t last_
size of the last dimension of src
Definition: broadcast_with_axis.h:36
#define MSHADOW_XINLINE
Definition: base.h:204
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
const SrcExp & src_
data operand
Definition: broadcast_with_axis.h:116
int32_t index_t
type that will be used for index
Definition: base.h:291
BroadcastWithAxisExp< SrcExp, DType, ExpInfo< SrcExp >::kDim, ExpInfo< SrcExp >::kDim > broadcast_keepdim(const Exp< SrcExp, DType, etype > &src, const int axis, const index_t size)
Broadcasting the tensor in the given axis (keepdim turned on)
Definition: broadcast_with_axis.h:98
Broadcasting the tensor in the given axis. If keepdim is off, insert the broadcasting dim after axis; otherwise, broadcast the size-1 dim at axis in place.
Definition: broadcast_with_axis.h:24
Broadcasting the tensor in multiple axes. The dimension of the source tensor in the given axes must be 1.
Definition: broadcast_with_axis.h:112
Plan(const BroadcastWithMultiAxesExp< SrcExp, DType, dimsrc > &e)
Definition: broadcast_with_axis.h:237
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size)
Definition: broadcast_with_axis.h:38
index_t dst_last_
size of the last dimension of dst
Definition: broadcast_with_axis.h:30
index_t dst_last_
size of the last dimension of dst
Definition: broadcast_with_axis.h:118
index_t trailing_
product of the dimensions after the broadcasting axis
Definition: broadcast_with_axis.h:32
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:25
namespace for mshadow
Definition: base.h:282
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:29
BroadcastWithMultiAxesExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > broadcast_multi_axes(const Exp< SrcExp, DType, etype > &src, const TShape &axes, const TShape &sizes)
Broadcasting the tensor in the given axes (keepdim turned on)
Definition: broadcast_with_axis.h:178
Shape< dimsrc > trailings_
product of the dimensions after the broadcasting axes
Definition: broadcast_with_axis.h:122
index_t last_
size of the last dimension of src
Definition: broadcast_with_axis.h:126