mxnet
reshape.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MSHADOW_EXTENSION_RESHAPE_H_
27 #define MSHADOW_EXTENSION_RESHAPE_H_
28 #include "../extension.h"
29 namespace mshadow {
30 namespace expr {
39 template<typename SrcExp, typename DType, int dimdst, int dimsrc>
40 struct ReshapeExp:
41  public MakeTensorExp<ReshapeExp<SrcExp, DType, dimdst, dimsrc>,
42  SrcExp, dimdst, DType> {
44  const SrcExp &src_;
48  ReshapeExp(const SrcExp &src, Shape<dimdst> shape)
49  : src_(src) {
51  CHECK_EQ(ishape.Size(), shape.Size()) << "reshape size must match";
52  ishapex_ = ishape[dimsrc - 1];
53  this->shape_ = shape;
54  }
55 };
65 template<typename SrcExp, typename DType, int etype, int dimdst>
69  (src.self(), oshape);
70 }
71 //----------------------
72 // Execution plan
73 //----------------------
74 template<typename SrcExp, typename DType, int dimdst, int dimsrc>
75 struct Plan<ReshapeExp<SrcExp, DType, dimdst, dimsrc>, DType> {
76  public:
78  : src_(MakePlan(e.src_)),
79  oshapex_(e.shape_[dimdst - 1]), ishapex_(e.ishapex_) {}
80  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
81  const index_t idx = y * oshapex_ + x;
82  return src_.Eval(idx / ishapex_, idx % ishapex_);
83  }
84 
85  private:
87  const index_t oshapex_, ishapex_;
88 };
89 // special work plan for 1 dimensional data
90 template<typename SrcExp, typename DType, int dimdst>
91 struct Plan<ReshapeExp<SrcExp, DType, dimdst, 1>, DType> {
92  public:
94  : src_(MakePlan(e.src_)), oshapex_(e.shape_[dimdst - 1]) {
95  }
96  MSHADOW_XINLINE DType Eval(index_t y, index_t x) const {
97  return src_.Eval(0, y * oshapex_ + x);
98  }
99 
100  private:
102  const index_t oshapex_;
103 };
104 } // namespace expr
105 } // namespace mshadow
106 #endif // MSHADOW_EXTENSION_RESHAPE_H_
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: reshape.h:80
MSHADOW_XINLINE index_t Size(void) const
Definition: tensor.h:145
Definition: expr_engine-inl.h:59
shape of a tensor
Definition: tensor.h:54
Plan(const ReshapeExp< SrcExp, DType, dimdst, dimsrc > &e)
Definition: reshape.h:77
Plan(const ReshapeExp< SrcExp, DType, dimdst, 1 > &e)
Definition: reshape.h:93
static Shape< dim > Check(const E &t)
const SrcExp & src_
source expression
Definition: reshape.h:44
#define MSHADOW_XINLINE
Definition: base.h:223
index_t ishapex_
size of the last (innermost) dimension of the input
Definition: reshape.h:46
int32_t index_t
type used for indexing
Definition: base.h:336
ReshapeExp< SrcExp, DType, dimdst, ExpInfo< SrcExp >::kDim > reshape(const Exp< SrcExp, DType, etype > &src, Shape< dimdst > oshape)
an expression that reshapes a tensor to another shape
Definition: reshape.h:67
reshape the content to another shape; input: Tensor<Device, dimsrc> with shape ishape; output: Tensor<Device, dimdst> with the same total size (ishape.Size() == oshape.Size())
Definition: reshape.h:40
defines how expression exp can be evaluated and stored into dst
Definition: expression.h:80
const SubType & self(void) const
Definition: expression.h:83
Plan< BinaryMapExp< OP, TA, TB, DType, etype >, DType > MakePlan(const BinaryMapExp< OP, TA, TB, DType, etype > &e)
Definition: expr_engine-inl.h:240
ReshapeExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: reshape.h:48
a general base class for extensions that produce tensors of a given shape
Definition: expr_engine-inl.h:44
overloaded + operator between half_t and bf16_t
Definition: base.h:327
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const
Definition: reshape.h:96
Shape< dim > shape_
the shape of this expression
Definition: expr_engine-inl.h:48