mxnet
slice_ex.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MSHADOW_EXTENSION_SLICE_EX_H_
25 #define MSHADOW_EXTENSION_SLICE_EX_H_
26 
27 #include "../extension.h"
28 
29 namespace mshadow {
30 namespace expr {
// NOTE(review): this text is a Doxygen source-browser listing, not compilable
// source. Each line keeps its original header line number; gaps in the
// numbering mark lines absent from the listing (stripped doc comments or lines
// whose text was lost), so several definitions below are incomplete as shown.
//
// SliceExExp: lazy expression viewing the sub-range [begin, end) of every
// dimension of `src`. Since shape_[i] = end_[i] - begin_[i] and Plan maps slice
// coordinate c to source coordinate c + begin_[k], begin is inclusive and end
// is exclusive.
// NOTE(review): src_ is a reference member, so a SliceExExp must not outlive
// the source expression it was built from.
38 template<typename SrcExp, typename Device,
39  typename DType, int srcdim>
40 struct SliceExExp : public TRValue<SliceExExp<SrcExp,
41  Device, DType,
42  srcdim>,
43  Device, srcdim, DType> {
44  const SrcExp &src_;
// NOTE(review): header lines 45-48 are missing from this listing. Per the
// symbol index on this page they declare src_shape_ and shape_
// (Shape<srcdim>) and begin_ and end_ (const Shape<srcdim>).
49  SliceExExp(const SrcExp &src, Shape<srcdim> begin, Shape<srcdim> end)
50  : src_(src), begin_(begin), end_(end) {
// NOTE(review): header line 51 is missing. src_shape_ is not set in the
// initializer list but Plan reads it, so line 51 presumably initializes it
// from src -- confirm against the real header.
// Compute the extent of the slice in each dimension.
// NOTE(review): nothing visible here checks begin_[i] <= end_[i] or that
// end_[i] lies within the source shape; a reversed range yields a negative
// extent (index_t is listed as int32_t in this page's symbol index).
52  for (int i = 0; i < srcdim; ++i) {
53  shape_[i] = end_[i] - begin_[i];
54  }
55  }
// Assignment into the slice: both overloads forward to this->__assign, which
// the symbol index lists as a member of RValueExp.
56  template<typename E, int etype>
57  inline void
// NOTE(review): header line 58 is missing; per the symbol index it is
// `operator=(const expr::Exp<E, DType, etype> &exp) {`.
59  this->__assign(exp);
60  }
61  inline void
62  operator=(const DType &exp) {
63  this->__assign(exp);
64  }
65 }; // struct SliceEx
66 
// Factory: wraps src.self() in a SliceExExp with the given begin/end shapes.
78 template<typename SrcExp, typename Device,
79  typename DType, int srcdim>
80 inline SliceExExp<SrcExp, Device, DType, srcdim>
// NOTE(review): header lines 81-82 are missing. The symbol index links a
// `slice` function and TypeCheckPass from this page, so they presumably hold
// the function signature (parameters src, begin, end per line 84) and the
// start of a compile-time TypeCheckPass dimension check that ends on the next
// line -- confirm against the real header.
83  ::Error_Expression_Does_Not_Meet_Dimension_Req();
84  return SliceExExp<SrcExp, Device, DType, srcdim>(src.self(), begin, end);
85 }
86 //------------------------
87 // engine plugin
88 //------------------------
89 // runtime shapecheck
// Shape of a slice: returns the extents precomputed in the constructor;
// no consistency check against the source shape is made here.
90 template<typename SrcExp, typename Device,
91  typename DType, int srcdim>
92 struct ShapeCheck<srcdim, SliceExExp<SrcExp, Device, DType, srcdim> >{
93  inline static Shape<srcdim> Check(const SliceExExp<SrcExp,
94  Device, DType, srcdim> &t) {
95  return t.shape_;
96  }
97 };
98 
99 template<typename SrcExp, typename Device,
100  typename DType, int srcdim>
101 struct StreamInfo<Device, SliceExExp<SrcExp, Device, DType, srcdim> >{
102  inline static Stream<Device> *
// NOTE(review): header lines 103-104 are missing. Per the symbol index, line
// 103 is `Get(const SliceExExp<SrcExp, Device, DType, srcdim> &t) {`; the
// returned stream (line 104) is not shown -- presumably the source's stream.
105  }
106 };
107 // static typecheck
// The slice reports the same dimensionality as its source expression.
108 template<typename SrcExp, typename Device,
109  typename DType, int srcdim>
110 struct ExpInfo<SliceExExp<SrcExp, Device, DType, srcdim> >{
111  static const int kDim = ExpInfo<SrcExp>::kDim;
// NOTE(review): header line 112 is missing; the symbol index links
// ExpInfo::kDevMask from this page, so it presumably declares kDevMask.
113 };
114 //----------------------
115 // Execution plan
116 //---------------------
117 template<typename SrcExp, typename Device,
118  typename DType, int srcdim>
119 struct Plan<SliceExExp<SrcExp, Device, DType, srcdim>, DType> {
120  public:
// NOTE(review): header line 121 is missing; per the symbol index it is the
// constructor `Plan(const SliceExExp<SrcExp, Device, DType, srcdim> &e)`.
// It builds a plan for the source and copies begin_, src_shape_ and shape_
// out of the expression.
122  : src_(MakePlan(e.src_)), begin_(e.begin_),
123  src_shape_(e.src_shape_), shape_(e.shape_) {}
// Reads element (i, j) of the slice. j indexes the last dimension and is
// shifted by begin_[srcdim-1]. i is the flattened index over the leading
// srcdim-1 slice dimensions: the loop peels it into per-dimension coordinates
// (dimension srcdim-2 first) using shape_, offsets each by begin_[k], and
// re-linearizes them with strides taken from src_shape_ to get the source
// row idx. For srcdim == 1 the loop does not run and idx stays 0.
// NOTE(review): idx and stride are index_t (int32_t per the symbol index);
// the product of source extents could overflow for very large tensors.
124  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
125  index_t idx = 0;
126  index_t stride = 1;
127  #pragma unroll
128  for (int k = srcdim-2; k >= 0; --k) {
129  idx += stride * (i%shape_[k] + begin_[k]);
130  i /= shape_[k];
131  stride *= src_shape_[k];
132  }
133  return src_.Eval(idx, j + begin_[srcdim-1]);
134  }
// NOTE(review): header line 135 is missing; per the symbol index it is
// `MSHADOW_XINLINE DType &REval(index_t i, index_t j) {`. The body maps
// (i, j) exactly like Eval but returns a writable reference from the source
// plan. The duplicated index mapping could live in one shared helper so the
// two paths cannot drift apart.
136  index_t idx = 0;
137  index_t stride = 1;
138  #pragma unroll
139  for (int k = srcdim-2; k >= 0; --k) {
140  idx += stride * (i%shape_[k] + begin_[k]);
141  i /= shape_[k];
142  stride *= src_shape_[k];
143  }
144  return src_.REval(idx, j + begin_[srcdim-1]);
145  }
146 
147  private:
// Plan for the source expression plus the slice geometry copied from it.
148  Plan<SrcExp, DType> src_;
149  const Shape<srcdim> begin_, src_shape_, shape_;
150 }; // struct Plan
151 } // namespace expr
152 } // namespace mshadow
153 #endif // MSHADOW_EXTENSION_SLICE_EX_H_
mshadow::expr::ExpInfo::kDevMask
static const int kDevMask
Definition: expr_engine-inl.h:264
mshadow::expr::Exp< Container, DType, type::kRValue >::self
const Container & self(void) const
Definition: expression.h:82
mshadow::Stream
computation stream structure, used for asynchronous computations
Definition: tensor.h:488
mshadow::expr::slice
SliceExp< SrcExp, Device, DType, srcdim, srcdim - sdim > slice(const TRValue< SrcExp, Device, srcdim, DType > &src, index_t begin, index_t end)
Slice a Tensor.
Definition: slice.h:83
mshadow::expr::Plan< SliceExExp< SrcExp, Device, DType, srcdim >, DType >::Plan
Plan(const SliceExExp< SrcExp, Device, DType, srcdim > &e)
Definition: slice_ex.h:121
mshadow::TRValue
Tensor RValue, this is the super type of all kinds of possible tensors.
Definition: tensor.h:514
mshadow::expr::TypeCheckPass
used to help static type check
Definition: expr_engine-inl.h:330
mshadow::expr::SliceExExp::src_shape_
Shape< srcdim > src_shape_
Definition: slice_ex.h:45
mshadow::expr::SliceExExp::operator=
void operator=(const DType &exp)
Definition: slice_ex.h:62
mshadow::expr::SliceExExp::SliceExExp
SliceExExp(const SrcExp &src, Shape< srcdim > begin, Shape< srcdim > end)
Definition: slice_ex.h:49
mshadow::expr::Plan< SliceExExp< SrcExp, Device, DType, srcdim >, DType >::REval
MSHADOW_XINLINE DType & REval(index_t i, index_t j)
Definition: slice_ex.h:135
mshadow::expr::StreamInfo::Get
static Stream< Device > * Get(const E &t)
MSHADOW_XINLINE
#define MSHADOW_XINLINE
Definition: base.h:228
mshadow::expr::ShapeCheck
runtime shape-checking template: gets the shape of an expression and reports an error if the shapes mismatch
Definition: expr_engine-inl.h:364
mshadow::expr::StreamInfo
Definition: expr_engine-inl.h:345
mshadow::expr::ShapeCheck::Check
static Shape< dim > Check(const E &t)
mshadow::expr::ExpInfo
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:262
mshadow::expr::MakePlan
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:239
mshadow::expr::SliceExExp::shape_
Shape< srcdim > shape_
Definition: slice_ex.h:46
mshadow::Shape::shape_
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:86
mshadow::expr::StreamInfo< Device, SliceExExp< SrcExp, Device, DType, srcdim > >::Get
static Stream< Device > * Get(const SliceExExp< SrcExp, Device, DType, srcdim > &t)
Definition: slice_ex.h:103
mshadow::expr::SliceExExp::end_
const Shape< srcdim > end_
Definition: slice_ex.h:48
mshadow::expr::ExpInfo::kDim
static const int kDim
Definition: expr_engine-inl.h:263
mshadow::index_t
int32_t index_t
type that will be used for index
Definition: base.h:328
mshadow::expr::Plan
Definition: expr_engine-inl.h:58
mshadow::expr::Exp
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:79
mshadow::expr::SliceExExp::src_
const SrcExp & src_
Definition: slice_ex.h:44
mshadow
overloaded + operator between half_t and bf16_t
Definition: base.h:319
mshadow::expr::ShapeCheck< srcdim, SliceExExp< SrcExp, Device, DType, srcdim > >::Check
static Shape< srcdim > Check(const SliceExExp< SrcExp, Device, DType, srcdim > &t)
Definition: slice_ex.h:93
mshadow::Shape< srcdim >
mshadow::expr::SliceExExp
slice expression: slices a tensor to the range [begin, end) in every dimension
Definition: slice_ex.h:40
mshadow::expr::RValueExp< SliceExExp< SrcExp, Device, DType, srcdim >, DType >::__assign
SliceExExp< SrcExp, Device, DType, srcdim > & __assign(DType s)
operator overload
Definition: expression.h:178
mshadow::expr::SliceExExp::operator=
void operator=(const expr::Exp< E, DType, etype > &exp)
Definition: slice_ex.h:58
mshadow::expr::Plan< SliceExExp< SrcExp, Device, DType, srcdim >, DType >::Eval
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: slice_ex.h:124
mshadow::expr::SliceExExp::begin_
const Shape< srcdim > begin_
Definition: slice_ex.h:47