Go to the documentation of this file.
25 #ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
26 #define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
29 #include "../extension.h"
41 template<
typename SrcExp,
typename DType,
int dimsrc,
int dimdst>
43 public MakeTensorExp<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>,
44 SrcExp, dimdst, DType> {
58 bool keepdim = (dimsrc == dimdst);
63 CHECK(dimsrc > axis && axis >= -1) <<
"broadcast axis (no keepdim) out of bound, " <<
64 "axis must be between -1 and" << dimsrc - 1 <<
", given=" << axis <<
".";
65 for (
int i = 0; i <= axis; ++i) {
66 this->
shape_[i] = src_shape[i];
69 for (
int i = axis + 1; i < dimsrc; ++i) {
70 this->trailing_ *= src_shape[i];
71 this->
shape_[i + 1] = src_shape[i];
74 CHECK(dimdst > axis && axis >= 0) <<
"broadcast axis (keepdim) out of bound, " <<
75 "axis must be between 0 and" << dimdst - 1 <<
", given=" << axis <<
".";
76 CHECK_EQ(src_shape[axis], 1U) <<
"Size of the dimension of the broadcasting axis must be 1" <<
77 " when keepdim is on, src_shape[" << axis <<
"]=" << src_shape[axis] <<
".";
78 for (
int i = 0; i <= axis - 1; ++i) {
79 this->
shape_[i] = src_shape[i];
82 for (
int i = axis + 1; i < dimdst; ++i) {
83 this->trailing_ *= src_shape[i];
84 this->
shape_[i] = src_shape[i];
88 this->last_ = src_shape[dimsrc - 1];
89 this->dst_last_ = this->
shape_[dimdst - 1];
99 template<
typename SrcExp,
typename DType,
int etype>
100 inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
113 template<
typename SrcExp,
typename DType,
int etype>
114 inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
129 template<
typename SrcExp,
typename DType,
int dimsrc>
131 public MakeTensorExp<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>,
132 SrcExp, dimsrc, DType> {
146 template<
typename TShape>
150 CHECK(axes.ndim() == sizes.ndim()) <<
"ndim of axes and sizes must be equal.";
151 this->axesnum_ = axes.ndim();
152 CHECK(this->axesnum_ <= dimsrc) <<
"Number of broadcasting axes must be smaller than"
153 "the source ndim, number of axes=" << this->axesnum_ <<
" dimsrc=" << dimsrc;
155 CHECK(dimsrc > axes[i]) <<
"broadcast axis (keepdim) out of bound, " <<
156 "all axes must be between 0 and" << dimsrc - 1 <<
", given axes[" << i <<
"] = " << axes[i]
158 CHECK_EQ(src_shape[axes[i]], 1U) <<
"Size of the dimension of the broadcasting axis must be 1"
159 <<
", src_shape[" << axes[i] <<
"]=" << src_shape[axes[i]] <<
".";
161 CHECK(axes[i] < axes[i + 1]) <<
"The given axes must be in increasing order.";
164 for (
index_t i = 0; i < dimsrc; i++) {
165 this->
shape_[i] = src_shape[i];
167 this->trailings_[i] = 1;
170 this->
shape_[axes[i]] = sizes[i];
171 this->sizes_[i] = sizes[i];
174 this->trailings_[i] = 1;
175 for (
index_t j = axes[i] + 1; j < dimsrc; ++j) {
176 this->trailings_[i] *= this->
shape_[j];
179 this->last_ = src_shape[dimsrc - 1];
180 this->dst_last_ = this->
shape_[dimsrc - 1];
194 template<
typename SrcExp,
typename DType,
int etype,
typename TShape>
195 inline BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
197 const TShape &axes,
const TShape &sizes) {
211 template<
typename SrcExp,
typename DType,
int etype,
typename TShape>
212 inline BroadcastWithMultiAxesExp<SrcExp, DType, ExpInfo<SrcExp>::kDim>
215 CHECK_EQ(target_shape.ndim(), dimsrc);
216 std::vector<index_t> axes_vec, sizes_vec;
218 for (
size_t i = 0; i < dimsrc; ++i) {
219 if (src_shape[i] != target_shape[i]) {
220 CHECK_EQ(src_shape[i], 1U) <<
"broadcasting axis must have size 1, received shape="
221 << src_shape <<
" target_shape=" << target_shape;
222 axes_vec.push_back(i);
223 sizes_vec.push_back(target_shape[i]);
226 TShape axes = TShape(axes_vec.begin(), axes_vec.end());
227 TShape sizes = TShape(sizes_vec.begin(), sizes_vec.end());
234 template<
typename SrcExp,
typename DType,
int dimsrc,
int dimdst>
238 : src_(
MakePlan(e.src_)), dst_last_(e.dst_last_),
239 trailing_(e.trailing_), size_(e.size_), last_(e.last_) {}
241 index_t x = (i * dst_last_ + j) / trailing_ / size_;
242 index_t y = (i * dst_last_ + j) % trailing_;
244 return src_.Eval(z / last_, z % last_);
249 const index_t dst_last_, trailing_, size_, last_;
252 template<
typename SrcExp,
typename DType,
int dimsrc>
256 : src_(
MakePlan(e.src_)), dst_last_(e.dst_last_), last_(e.last_), axesnum_(e.axesnum_),
257 trailings_(e.trailings_), sizes_(e.sizes_) {}
259 index_t indx = i * dst_last_ + j;
260 for (
index_t p = 0; p < dimsrc; ++p) {
264 indx = (indx / trailings_[p] / sizes_[p]) * trailings_[p] + (indx % trailings_[p]);
266 return src_.Eval(indx / last_, indx % last_);
271 const index_t dst_last_, last_, axesnum_;
276 #endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_
index_t dst_last_
size of the last dimension of dst
Definition: broadcast_with_axis.h:136
Plan(const BroadcastWithAxisExp< SrcExp, DType, dimsrc, dimdst > &e)
Definition: broadcast_with_axis.h:237
Broadcasting the tensor in the given axis. If keepdim is off, insert the broadcasting dim after axis....
Definition: broadcast_with_axis.h:42
const SubType & self(void) const
Definition: expression.h:82
index_t size_
new dimension of the broadcasting axis
Definition: broadcast_with_axis.h:52
BroadcastWithMultiAxesExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > broadcast_to(const Exp< SrcExp, DType, etype > &src, const TShape &target_shape)
Broadcasting the tensor to the target shape; every dimension whose size differs from the target must be 1 in the original tensor.
Definition: broadcast_with_axis.h:213
#define MSHADOW_XINLINE
Definition: base.h:228
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size)
Definition: broadcast_with_axis.h:56
static Shape< dim > Check(const E &t)
BroadcastWithMultiAxesExp(const SrcExp &src, const TShape &axes, const TShape &sizes)
Definition: broadcast_with_axis.h:147
index_t last_
size of the last dimension of src
Definition: broadcast_with_axis.h:54
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:262
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:47
const SrcExp & src_
data operand
Definition: broadcast_with_axis.h:134
index_t last_
size of the last dimension of src
Definition: broadcast_with_axis.h:144
Broadcasting the tensor in multiple axes. The dimension of the source tensor in each of the given axes must be 1.
Definition: broadcast_with_axis.h:130
static const int kDim
Definition: expr_engine-inl.h:263
int32_t index_t
type that will be used for index
Definition: base.h:328
const SrcExp & src_
data operand
Definition: broadcast_with_axis.h:46
Definition: expr_engine-inl.h:58
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: broadcast_with_axis.h:240
index_t trailing_
product of the dimensions after the broadcasting axis
Definition: broadcast_with_axis.h:50
Shape< dimsrc > trailings_
product of the dimensions after each broadcasting axis
Definition: broadcast_with_axis.h:140
overloaded + operator between half_t and bf16_t
Definition: base.h:319
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:43
BroadcastWithAxisExp< SrcExp, DType, ExpInfo< SrcExp >::kDim, ExpInfo< SrcExp >::kDim+1 > broadcast_with_axis(const Exp< SrcExp, DType, etype > &src, const int axis, const index_t size)
Broadcasting the tensor after the given axis.
Definition: broadcast_with_axis.h:102
Shape< dimsrc > sizes_
new dimension of the broadcasting axes
Definition: broadcast_with_axis.h:142
BroadcastWithAxisExp< SrcExp, DType, ExpInfo< SrcExp >::kDim, ExpInfo< SrcExp >::kDim > broadcast_keepdim(const Exp< SrcExp, DType, etype > &src, const int axis, const index_t size)
Broadcasting the tensor in the given axis (keepdim turned on)
Definition: broadcast_with_axis.h:116
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: broadcast_with_axis.h:258
index_t dst_last_
size of the last dimension of dst
Definition: broadcast_with_axis.h:48
index_t axesnum_
number of broadcasting axes
Definition: broadcast_with_axis.h:138
Plan(const BroadcastWithMultiAxesExp< SrcExp, DType, dimsrc > &e)
Definition: broadcast_with_axis.h:255
BroadcastWithMultiAxesExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > broadcast_multi_axes(const Exp< SrcExp, DType, etype > &src, const TShape &axes, const TShape &sizes)
Broadcasting the tensor in multiple given axes (keepdim turned on)
Definition: broadcast_with_axis.h:196