7 #ifndef MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ 8 #define MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ 11 #include "../extension.h" 23 template<
typename SrcExp,
typename DType,
int dimsrc,
int dimdst>
25 public MakeTensorExp<BroadcastWithAxisExp<SrcExp, DType, dimsrc, dimdst>,
26 SrcExp, dimdst, DType> {
39 : src_(src), size_(size) {
40 bool keepdim = (dimsrc == dimdst);
45 CHECK(dimsrc > axis && axis >= -1) <<
"broadcast axis (no keepdim) out of bound, " <<
46 "axis must be between -1 and" << dimsrc - 1 <<
", given=" << axis <<
".";
47 for (
int i = 0; i <= axis; ++i) {
48 this->
shape_[i] = src_shape[i];
51 for (
int i = axis + 1; i < dimsrc; ++i) {
52 this->trailing_ *= src_shape[i];
53 this->
shape_[i + 1] = src_shape[i];
56 CHECK(dimdst > axis && axis >= 0) <<
"broadcast axis (keepdim) out of bound, " <<
57 "axis must be between 0 and" << dimdst - 1 <<
", given=" << axis <<
".";
58 CHECK_EQ(src_shape[axis], 1U) <<
"Size of the dimension of the broadcasting axis must be 1" <<
59 " when keepdim is on, src_shape[" << axis <<
"]=" << src_shape[axis] <<
".";
60 for (
int i = 0; i <= axis - 1; ++i) {
61 this->
shape_[i] = src_shape[i];
64 for (
int i = axis + 1; i < dimdst; ++i) {
65 this->trailing_ *= src_shape[i];
66 this->
shape_[i] = src_shape[i];
70 this->last_ = src_shape[dimsrc - 1];
71 this->dst_last_ = this->
shape_[dimdst - 1];
81 template<
typename SrcExp,
typename DType,
int etype>
85 return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
95 template<
typename SrcExp,
typename DType,
int etype>
96 inline BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
99 return BroadcastWithAxisExp<SrcExp, DType, ExpInfo<SrcExp>::kDim,
111 template<
typename SrcExp,
typename DType,
int dimsrc>
113 public MakeTensorExp<BroadcastWithMultiAxesExp<SrcExp, DType, dimsrc>,
114 SrcExp, dimsrc, DType> {
128 template<
typename TShape>
132 CHECK(axes.ndim() == sizes.ndim()) <<
"ndim of axes and sizes must be equal.";
133 this->axesnum_ = axes.ndim();
134 CHECK(this->axesnum_ <= dimsrc) <<
"Number of broadcasting axes must be smaller than" 135 "the source ndim, number of axes=" << this->axesnum_ <<
" dimsrc=" << dimsrc;
136 for (
index_t i = 0; i < this->axesnum_; i++) {
137 CHECK(dimsrc > axes[i]) <<
"broadcast axis (keepdim) out of bound, " <<
138 "all axes must be between 0 and" << dimsrc - 1 <<
", given axes[" << i <<
"] = " << axes[i]
140 CHECK_EQ(src_shape[axes[i]], 1U) <<
"Size of the dimension of the broadcasting axis must be 1" 141 <<
", src_shape[" << axes[i] <<
"]=" << src_shape[axes[i]] <<
".";
142 if (i < this->axesnum_ - 1) {
143 CHECK(axes[i] < axes[i + 1]) <<
"The given axes must be in increasing order.";
146 for (
index_t i = 0; i < dimsrc; i++) {
147 this->
shape_[i] = src_shape[i];
149 this->trailings_[i] = 1;
151 for (
index_t i = 0; i < this->axesnum_; i++) {
152 this->
shape_[axes[i]] = sizes[i];
153 this->sizes_[i] = sizes[i];
155 for (
index_t i = 0; i < this->axesnum_; i++) {
156 this->trailings_[i] = 1;
157 for (
index_t j = axes[i] + 1; j < dimsrc; ++j) {
158 this->trailings_[i] *= this->
shape_[j];
161 this->last_ = src_shape[dimsrc - 1];
162 this->dst_last_ = this->
shape_[dimsrc - 1];
176 template<
typename SrcExp,
typename DType,
int etype,
typename TShape>
179 const TShape &axes,
const TShape &sizes) {
193 template<
typename SrcExp,
typename DType,
int etype,
typename TShape>
197 CHECK_EQ(target_shape.ndim(), dimsrc);
198 std::vector<index_t> axes_vec, sizes_vec;
200 for (
size_t i = 0; i < dimsrc; ++i) {
201 if (src_shape[i] != target_shape[i]) {
202 CHECK_EQ(src_shape[i], 1U) <<
"broadcasting axis must have size 1, received shape=" 203 << src_shape <<
" target_shape=" << target_shape;
204 axes_vec.push_back(i);
205 sizes_vec.push_back(target_shape[i]);
208 TShape axes = TShape(axes_vec.begin(), axes_vec.end());
209 TShape sizes = TShape(sizes_vec.begin(), sizes_vec.end());
216 template<
typename SrcExp,
typename DType,
int dimsrc,
int dimdst>
234 template<
typename SrcExp,
typename DType,
int dimsrc>
239 trailings_(e.trailings_), sizes_(e.sizes_) {}
242 for (
index_t p = 0; p < dimsrc; ++p) {
246 indx = (indx / trailings_[p] / sizes_[p]) * trailings_[p] + (indx % trailings_[p]);
258 #endif // MSHADOW_EXTENSION_BROADCAST_WITH_AXIS_H_ Plan(const BroadcastWithAxisExp< SrcExp, DType, dimsrc, dimdst > &e)
Definition: broadcast_with_axis.h:219
Definition: expr_engine-inl.h:40
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: broadcast_with_axis.h:222
index_t size_
new size of the broadcasting axis
Definition: broadcast_with_axis.h:34
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: broadcast_with_axis.h:240
BroadcastWithMultiAxesExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > broadcast_to(const Exp< SrcExp, DType, etype > &src, const TShape &target_shape)
Broadcasting the tensor to the target shape; every dimension whose size differs from the target must be 1 in the original tensor.
Definition: broadcast_with_axis.h:195
BroadcastWithAxisExp< SrcExp, DType, ExpInfo< SrcExp >::kDim, ExpInfo< SrcExp >::kDim+1 > broadcast_with_axis(const Exp< SrcExp, DType, etype > &src, const int axis, const index_t size)
Broadcasting the tensor after the given axis.
Definition: broadcast_with_axis.h:84
BroadcastWithMultiAxesExp(const SrcExp &src, const TShape &axes, const TShape &sizes)
Definition: broadcast_with_axis.h:129
index_t axesnum_
number of broadcasting axes
Definition: broadcast_with_axis.h:120
Shape< dimsrc > sizes_
new sizes of the broadcasting axes
Definition: broadcast_with_axis.h:124
static Shape< dim > Check(const E &t)
const SrcExp & src_
data operand
Definition: broadcast_with_axis.h:28
index_t last_
size of the last dimension of src
Definition: broadcast_with_axis.h:36
#define MSHADOW_XINLINE
Definition: base.h:204
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:244
const SrcExp & src_
data operand
Definition: broadcast_with_axis.h:116
int32_t index_t
type that will be used for index
Definition: base.h:291
BroadcastWithAxisExp< SrcExp, DType, ExpInfo< SrcExp >::kDim, ExpInfo< SrcExp >::kDim > broadcast_keepdim(const Exp< SrcExp, DType, etype > &src, const int axis, const index_t size)
Broadcasting the tensor in the given axis (keepdim turned on)
Definition: broadcast_with_axis.h:98
Broadcasting the tensor in the given axis. If keepdim is off, insert the broadcasting dim after axis; if keepdim is on, the source dimension at axis must be 1.
Definition: broadcast_with_axis.h:24
Broadcasting the tensor in multiple axes. The dimension of the source tensor in each of the given axes must be 1.
Definition: broadcast_with_axis.h:112
Plan(const BroadcastWithMultiAxesExp< SrcExp, DType, dimsrc > &e)
Definition: broadcast_with_axis.h:237
BroadcastWithAxisExp(const SrcExp &src, const int axis, const index_t size)
Definition: broadcast_with_axis.h:38
index_t dst_last_
size of the last dimension of dst
Definition: broadcast_with_axis.h:30
index_t dst_last_
size of the last dimension of dst
Definition: broadcast_with_axis.h:118
index_t trailing_
product of the dimensions after the broadcasting axis
Definition: broadcast_with_axis.h:32
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:61
const SubType & self(void) const
Definition: expression.h:64
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:221
a general class that allows extension that makes tensors of some shape
Definition: expr_engine-inl.h:25
namespace for mshadow
Definition: base.h:282
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:29
BroadcastWithMultiAxesExp< SrcExp, DType, ExpInfo< SrcExp >::kDim > broadcast_multi_axes(const Exp< SrcExp, DType, etype > &src, const TShape &axes, const TShape &sizes)
Broadcasting the tensor in multiple given axes (keepdim turned on).
Definition: broadcast_with_axis.h:178
Shape< dimsrc > trailings_
product of the dimensions after each broadcasting axis
Definition: broadcast_with_axis.h:122
index_t last_
size of the last dimension of src
Definition: broadcast_with_axis.h:126