26 #ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ 27 #define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ 30 #include "../extension.h" 42 template<
typename SrcExp,
typename DType,
int dimsrc,
int dimdst>
44 public MakeTensorExp<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>,
45 SrcExp, dimdst, DType> {
58 : src_(src), size_(size) {
59 bool keepdim = (dimsrc == dimdst);
64 CHECK(dimsrc > axis && axis >= -1) <<
"broadcast axis (no keepdim) out of bound, " <<
65 "axis must be between -1 and" << dimsrc - 1 <<
", given=" << axis <<
".";
66 for (
int i = 0; i <= axis; ++i) {
67 this->
shape_[i] = src_shape[i];
70 for (
int i = axis + 1; i < dimsrc; ++i) {
71 this->trailing_ *= src_shape[i];
72 this->
shape_[i + 1] = src_shape[i];
75 CHECK(dimdst > axis && axis >= 0) <<
"broadcast axis (keepdim) out of bound, " <<
76 "axis must be between 0 and" << dimdst - 1 <<
", given=" << axis <<
".";
77 CHECK_EQ(src_shape[axis], 1U) <<
"Size of the dimension of the broadcasting axis must be 1" <<
78 " when keepdim is on, src_shape[" << axis <<
"]=" << src_shape[axis] <<
".";
79 for (
int i = 0; i <= axis - 1; ++i) {
80 this->
shape_[i] = src_shape[i];
83 for (
int i = axis + 1; i < dimdst; ++i) {
84 this->trailing_ *= src_shape[i];
85 this->
shape_[i] = src_shape[i];
89 this->last_ = src_shape[dimsrc - 1];
90 this->dst_last_ = this->
shape_[dimdst - 1];
100 template<
typename SrcExp,
typename DType,
int etype>
104 return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
114 template<
typename SrcExp,
typename DType,
int etype>
115 inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
118 return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
130 template<
typename SrcExp,
typename DType,
int dimsrc>
132 public MakeTensorExp<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>,
133 SrcExp, dimsrc, DType> {
147 template<
typename TShape>
151 CHECK(axes.ndim() == sizes.ndim()) <<
"ndim of axes and sizes must be equal.";
152 this->axesnum_ = axes.ndim();
153 CHECK(this->axesnum_ <= dimsrc) <<
"Number of broadcasting axes must be smaller than" 154 "the source ndim, number of axes=" << this->axesnum_ <<
" dimsrc=" << dimsrc;
155 for (
index_t i = 0; i < this->axesnum_; i++) {
156 CHECK(dimsrc > axes[i]) <<
"broadcast axis (keepdim) out of bound, " <<
157 "all axes must be between 0 and" << dimsrc - 1 <<
", given axes[" << i <<
"] = " << axes[i]
159 CHECK_EQ(src_shape[axes[i]], 1U) <<
"Size of the dimension of the broadcasting axis must be 1" 160 <<
", src_shape[" << axes[i] <<
"]=" << src_shape[axes[i]] <<
".";
161 if (i < this->axesnum_ - 1) {
162 CHECK(axes[i] < axes[i + 1]) <<
"The given axes must be in increasing order.";
165 for (
index_t i = 0; i < dimsrc; i++) {
166 this->
shape_[i] = src_shape[i];
168 this->trailings_[i] = 1;
170 for (
index_t i = 0; i < this->axesnum_; i++) {
171 this->
shape_[axes[i]] = sizes[i];
172 this->sizes_[i] = sizes[i];
174 for (
index_t i = 0; i < this->axesnum_; i++) {
175 this->trailings_[i] = 1;
176 for (
index_t j = axes[i] + 1; j < dimsrc; ++j) {
177 this->trailings_[i] *= this->
shape_[j];
180 this->last_ = src_shape[dimsrc - 1];
181 this->dst_last_ = this->
shape_[dimsrc - 1];
195 template<
typename SrcExp,
typename DType,
int etype,
typename TShape>
198 const TShape &axes,
const TShape &sizes) {
212 template<
typename SrcExp,
typename DType,
int etype,
typename TShape>
216 CHECK_EQ(target_shape.ndim(), dimsrc);
217 std::vector<index_t> axes_vec, sizes_vec;
219 for (
size_t i = 0; i < dimsrc; ++i) {
220 if (src_shape[i] != target_shape[i]) {
221 CHECK_EQ(src_shape[i], 1U) <<
"broadcasting axis must have size 1, received shape=" 222 << src_shape <<
" target_shape=" << target_shape;
223 axes_vec.push_back(i);
224 sizes_vec.push_back(target_shape[i]);
227 TShape axes = TShape(axes_vec.begin(), axes_vec.end());
228 TShape sizes = TShape(sizes_vec.begin(), sizes_vec.end());
235 template<
typename SrcExp,
typename DType,
int dimsrc,
int dimdst>
253 template<
typename SrcExp,
typename DType,
int dimsrc>
258 trailings_(e.trailings_), sizes_(e.sizes_) {}
261 for (
index_t p = 0; p < dimsrc; ++p) {
265 indx = (indx / trailings_[p] / sizes_[p]) * trailings_[p] + (indx % trailings_[p]);
277 #endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ Plan(const BroadcastWithAxisExp< SrcExp, DType, dimsrc, dimdst > &e)
Definition: broadcast_with_axis.h:238
Definition: expr_engine-inl.h:59
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: broadcast_with_axis.h:241
index_t size_
new dimension of the broadcasting axis
Definition: broadcast_with_axis.h:53
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: broadcast_with_axis.h:259
BroadcastWithMultiAxesExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > broadcast_to(const Exp< SrcExp, DType, etype > &src, const TShape &target_shape)
Broadcasting the tensor to the target shape, dimension of different sizes must be 1 in the original t...
Definition: broadcast_with_axis.h:214
BroadcastWithAxisExp< SrcExp, DType, ExpInfo< SrcExp >::kDim, ExpInfo< SrcExp >::kDim+1 > broadcast_with_axis(const Exp< SrcExp, DType, etype > &src, const int axis, const index_t size)
Broadcasting the tensor after the given axis.
Definition: broadcast_with_axis.h:103
BroadcastWithMultiAxesExp(const SrcExp &src, const TShape &axes, const TShape &sizes)
Definition: broadcast_with_axis.h:148
index_t axesnum_
number of broadcasting axes
Definition: broadcast_with_axis.h:139
Shape< dimsrc > sizes_
new dimensions of the broadcasting axes
Definition: broadcast_with_axis.h:143
static Shape< dim > Check(const E &t)
const SrcExp & src_
data operand
Definition: broadcast_with_axis.h:47
index_t last_
size of the last dimension of src
Definition: broadcast_with_axis.h:55
#define MSHADOW_XINLINE
Definition: base.h:223
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:263
const SrcExp & src_
data operand
Definition: broadcast_with_axis.h:135
int32_t index_t
type that will be used for index
Definition: base.h:336
BroadcastWithAxisExp< SrcExp, DType, ExpInfo< SrcExp >::kDim, ExpInfo< SrcExp >::kDim > broadcast_keepdim(const Exp< SrcExp, DType, etype > &src, const int axis, const index_t size)
Broadcasting the tensor in the given axis (keepdim turned on)
Definition: broadcast_with_axis.h:117
Broadcasting the tensor in the given axis. If keepdim is off, insert the broadcasting dim after axis...
Definition: broadcast_with_axis.h:43
Broadcasting the tensor in multiple axes. The dimension of the source tensor in the given axes must b...
Definition: broadcast_with_axis.h:131
Plan(const BroadcastWithMultiAxesExp< SrcExp, DType, dimsrc > &e)
Definition: broadcast_with_axis.h:256
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size)
Definition: broadcast_with_axis.h:57
index_t dst_last_
size of the last dimension of dst
Definition: broadcast_with_axis.h:49
index_t dst_last_
size of the last dimension of dst
Definition: broadcast_with_axis.h:137
index_t trailing_
product of the dimensions after the broadcasting axis
Definition: broadcast_with_axis.h:51
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:44
overloaded + operator between half_t and bf16_t
Definition: base.h:327
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48
BroadcastWithMultiAxesExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > broadcast_multi_axes(const Exp< SrcExp, DType, etype > &src, const TShape &axes, const TShape &sizes)
Broadcasting the tensor in multiple given axes (keepdim turned on)
Definition: broadcast_with_axis.h:197
Shape< dimsrc > trailings_
product of the dimensions after the broadcasting axes
Definition: broadcast_with_axis.h:141
index_t last_
size of the last dimension of src
Definition: broadcast_with_axis.h:145