mxnet
slice_ex.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MSHADOW_EXTENSION_SLICE_EX_H_
26 #define MSHADOW_EXTENSION_SLICE_EX_H_
27 
28 #include "../extension.h"
29 
30 namespace mshadow {
31 namespace expr {
// SliceExExp: expression for the axis-aligned sub-box [begin, end) of an
// srcdim-dimensional source expression, taken along every dimension (the
// constructor sets shape_[i] = end_[i] - begin_[i] for every i). It derives
// from TRValue and defines operator= overloads, so it can also be the target
// of an assignment.
// NOTE(review): listing lines 32-38 (doc comment) and 46-49 are missing from
// this view. Per the definition index at the end of this page, lines 46-49
// declare `Shape<srcdim> src_shape_`, `Shape<srcdim> shape_`,
// `const Shape<srcdim> begin_` and `const Shape<srcdim> end_`.
39 template<typename SrcExp, typename Device,
40  typename DType, int srcdim>
41 struct SliceExExp : public TRValue<SliceExExp<SrcExp,
42  Device, DType,
43  srcdim>,
44  Device, srcdim, DType> {
// Source expression, held by const reference (not copied): the referenced
// object must stay alive for as long as this slice is used.
45  const SrcExp &src_;
// Builds the slice [begin, end) of src. Caches the source's shape via
// ShapeCheck and stores the slice extent end - begin for each dimension.
// NOTE(review): nothing here checks begin_[i] <= end_[i] <= src_shape_[i];
// an inverted or out-of-range box is accepted as-is. Confirm whether callers
// validate the bounds.
50  SliceExExp(const SrcExp &src, Shape<srcdim> begin, Shape<srcdim> end)
51  : src_(src), begin_(begin), end_(end) {
52  src_shape_ = ShapeCheck<srcdim, SrcExp>::Check(src_);
53  for (int i = 0; i < srcdim; ++i) {
54  shape_[i] = end_[i] - begin_[i];
55  }
56  }
// Assigns an expression to the sliced region by forwarding to __assign.
// NOTE(review): listing line 59 (operator name and parameter) is missing;
// per the definition index it is
// `operator=(const expr::Exp<E, DType, etype> &exp)`.
57  template<typename E, int etype>
58  inline void
60  this->__assign(exp);
61  }
// Assigns the scalar exp to the sliced region by forwarding to __assign.
62  inline void
63  operator=(const DType &exp) {
64  this->__assign(exp);
65  }
66 }; // struct SliceExExp
67 
// Factory that wraps src in a SliceExExp covering the box [begin, end).
// NOTE(review): listing lines 68-78 (doc comment) and 81-83 (return type,
// function name, parameter list and the start of the dimension check) are
// missing from this view.
79 template<typename SrcExp, typename Device,
80  typename DType, int srcdim>
// Tail of what appears to be a compile-time dimension check whose condition
// is on the missing lines; presumably it requires
// ExpInfo<SrcExp>::kDim == srcdim — TODO confirm.
84  ::Error_Expression_Does_Not_Meet_Dimension_Req();
// src.self() is the concrete source expression. SliceExExp keeps only a
// const reference to it, so it must outlive the returned expression.
85  return SliceExExp<SrcExp, Device, DType, srcdim>(src.self(), begin, end);
86 }
87 //------------------------
88 // engine plugin
89 //------------------------
90 // runtime shapecheck
// Runtime shape check for SliceExExp. The shape of a slice is its own extent
// shape_ (end - begin per dimension, computed in the SliceExExp constructor),
// not the shape of the source expression.
91 template<typename SrcExp, typename Device,
92  typename DType, int srcdim>
93 struct ShapeCheck<srcdim, SliceExExp<SrcExp, Device, DType, srcdim> >{
94  inline static Shape<srcdim> Check(const SliceExExp<SrcExp,
95  Device, DType, srcdim> &t) {
96  return t.shape_;
97  }
98 };
99 
// StreamInfo specialization: reports the Stream<Device> associated with a
// SliceExExp.
// NOTE(review): listing lines 104-105 (Get's name, parameter and body) are
// missing. Per the definition index, the member is
// `static Stream<Device> *Get(const SliceExExp<SrcExp, Device, DType, srcdim> &t)`.
// Presumably it forwards to the source expression's stream — confirm.
100 template<typename SrcExp, typename Device,
101  typename DType, int srcdim>
102 struct StreamInfo<Device, SliceExExp<SrcExp, Device, DType, srcdim> >{
103  inline static Stream<Device> *
106  }
107 };
108 // static typecheck
// Static type info for SliceExExp. The dimension (kDim) and device mask
// (kDevMask) are taken unchanged from the source expression, because a slice
// has the same rank (Shape<srcdim>) and lives on the same device as its
// source.
109 template<typename SrcExp, typename Device,
110  typename DType, int srcdim>
111 struct ExpInfo<SliceExExp<SrcExp, Device, DType, srcdim> >{
112  static const int kDim = ExpInfo<SrcExp>::kDim;
113  static const int kDevMask = ExpInfo<SrcExp>::kDevMask;
114 };
115 //----------------------
116 // Execution plan
117 //---------------------
// Execution plan for SliceExExp. Eval/REval take a coordinate (i, j) in the
// slice, translate it to the matching coordinate in the source, and delegate
// to the source's plan src_.
// NOTE(review): listing lines 122 (constructor signature), 124 (rest of the
// initializer list and constructor body), 136 (REval signature) and 149-150
// (private members) are missing from this view. Per the definition index,
// line 122 is `Plan(const SliceExExp<SrcExp, Device, DType, srcdim> &e)` and
// line 136 is `MSHADOW_XINLINE DType &REval(index_t i, index_t j)`.
118 template<typename SrcExp, typename Device,
119  typename DType, int srcdim>
120 struct Plan<SliceExExp<SrcExp, Device, DType, srcdim>, DType> {
121  public:
// Initializes the source plan with MakePlan(e.src_) and begin_ from e.begin_.
// The remaining initializers are on missing line 124.
123  : src_(MakePlan(e.src_)), begin_(e.begin_),
// Reads slice element (i, j).
// - i is the flattened index over the slice's first srcdim-1 dimensions,
//   with dimension srcdim-2 varying fastest.
// - j is the index along the last dimension.
// Each loop step takes the coordinate for dimension k out of i using the
// slice extent shape_[k], and offsets it by begin_[k]. It then re-linearizes
// that coordinate with the source extents src_shape_[k]. The column is
// shifted by begin_[srcdim-1].
// The trip count depends only on the template parameter srcdim, so it is
// known at compile time for the #pragma unroll hint.
// NOTE(review): the definition index lists index_t as int32_t. idx and stride
// are index_t, so very large sources could overflow — confirm.
125  MSHADOW_XINLINE DType Eval(index_t i, index_t j) const {
126  index_t idx = 0;
127  index_t stride = 1;
128  #pragma unroll
129  for (int k = srcdim-2; k >= 0; --k) {
130  idx += stride * (i%shape_[k] + begin_[k]);
131  i /= shape_[k];
132  stride *= src_shape_[k];
133  }
134  return src_.Eval(idx, j + begin_[srcdim-1]);
135  }
// REval (signature on missing line 136): writable counterpart of Eval. It uses
// the same index mapping but returns the reference produced by src_.REval,
// presumably for when the slice is an assignment target.
// NOTE(review): the mapping is duplicated from Eval and must be kept in sync
// by hand.
137  index_t idx = 0;
138  index_t stride = 1;
139  #pragma unroll
140  for (int k = srcdim-2; k >= 0; --k) {
141  idx += stride * (i%shape_[k] + begin_[k]);
142  i /= shape_[k];
143  stride *= src_shape_[k];
144  }
145  return src_.REval(idx, j + begin_[srcdim-1]);
146  }
147 
148  private:
151 }; // struct Plan
152 } // namespace expr
153 } // namespace mshadow
154 #endif // MSHADOW_EXTENSION_SLICE_EX_H_
slice expression, slice a tensor's channel
Definition: slice_ex.h:41
Tensor RValue, this is the super type of all kinds of possible tensors.
Definition: tensor.h:410
const SrcExp & src_
Definition: slice_ex.h:45
Definition: expr_engine-inl.h:59
used to help static type check
Definition: expr_engine-inl.h:331
const Shape< srcdim > end_
Definition: slice_ex.h:49
const Shape< srcdim > begin_
Definition: slice_ex.h:48
Plan(const SliceExExp< SrcExp, Device, DType, srcdim > &e)
Definition: slice_ex.h:122
MSHADOW_XINLINE DType Eval(index_t i, index_t j) const
Definition: slice_ex.h:125
void operator=(const DType &exp)
Definition: slice_ex.h:63
static Shape< dim > Check(const E &t)
#define MSHADOW_XINLINE
Definition: base.h:223
SliceExp< SrcExp, Device, DType, srcdim, srcdim-sdim > slice(const TRValue< SrcExp, Device, srcdim, DType > &src, index_t begin, index_t end)
Slice a Tensor.
Definition: slice.h:84
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: expr_engine-inl.h:263
Definition: expr_engine-inl.h:346
int32_t index_t
type that will be used for index
Definition: base.h:336
static Stream< Device > * Get(const SliceExExp< SrcExp, Device, DType, srcdim > &t)
Definition: slice_ex.h:104
MSHADOW_XINLINE DType & REval(index_t i, index_t j)
Definition: slice_ex.h:136
Shape< srcdim > src_shape_
Definition: slice_ex.h:46
index_t shape_[kDimension]
storing the dimension information
Definition: tensor.h:76
runtime shape checking template: get the shape of an expression, report an error if shapes mismatch ...
Definition: expr_engine-inl.h:365
static Stream< Device > * Get(const E &t)
static Shape< srcdim > Check(const SliceExExp< SrcExp, Device, DType, srcdim > &t)
Definition: slice_ex.h:94
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const Container & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
SliceExExp(const SrcExp &src, Shape< srcdim > begin, Shape< srcdim > end)
Definition: slice_ex.h:50
overloaded + operator between half_t and bf16_t
Definition: base.h:327
SliceExExp< SrcExp, Device, DType, srcdim > & __assign(DType s)
operator overload
Definition: expression.h:179
void operator=(const expr::Exp< E, DType, etype > &exp)
Definition: slice_ex.h:59
Shape< srcdim > shape_
Definition: slice_ex.h:47
computation stream structure, used for asynchronous computations
Definition: tensor.h:384