mxnet
alm.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_COMMON_ALM_H_
27 #define MXNET_COMMON_ALM_H_
28 
29 #include <mxnet/base.h>
30 #include <nnvm/graph.h>
31 #include <nnvm/node.h>
32 #include <functional>
33 #include <string>
34 #include <unordered_map>
35 #include <vector>
36 
37 namespace mxnet {
38 namespace alm {
39 
/*!
 * \brief A singleton flag, set and read by MXSetOptimizeLayout and MXGetOptimizeLayout.
 *
 * All access goes through get(), which always hands back the same instance.
 * NOTE(review): `optimize` is a plain bool with no synchronization; concurrent
 * writers/readers are not protected here.
 */
struct ALMParams {
  /*! \brief Layout-optimization switch; starts out false. */
  bool optimize{false};

  /*!
   * \brief Accessor for the one shared ALMParams object.
   * \return Reference to a function-local static instance, constructed on the
   *         first call (C++11 guarantees that construction itself is thread-safe).
   */
  static ALMParams& get() {
    static ALMParams instance;
    return instance;
  }
};
51 
56 
/*!
 * \brief Transpose, represented as a permutation of axes.
 */
using Transpose = std::vector<size_t>;

/*!
 * \brief Predicate on a transpose; presumably true when \p t is the identity
 *        permutation (confirm against the definition).
 */
bool IsIdentity(const Transpose& t);

/*!
 * \brief Returns a transpose derived from \p axes; presumably its inverse
 *        permutation (confirm against the definition).
 */
Transpose Reverse(const Transpose& axes);

/*!
 * \brief Combines two transposes into one.
 * NOTE(review): the application order of \p lhs vs \p rhs is fixed by the
 * implementation, not by this declaration; check it before relying on it.
 */
Transpose Compose(const Transpose& lhs, const Transpose& rhs);

/*!
 * \brief Applies \p axes to a layout string and returns the resulting layout
 *        string (layout assumed to be something like "NCHW"; confirm).
 */
std::string ApplyTranspose(const std::string& layout, const Transpose& axes);
72 
74 
86 using FChangeLayout = std::function<bool(nnvm::NodeAttrs*,
87  mshadow::LayoutFlag target_layout,
88  std::vector<Transpose>* in_axes,
89  std::vector<Transpose>* out_axes)>;
90 
95 Transpose FactorCommonTranspose(std::vector<Transpose>* axes);
96 
97 } // namespace alm
98 } // namespace mxnet
99 
100 #endif // MXNET_COMMON_ALM_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::alm::FactorCommonTranspose
Transpose FactorCommonTranspose(std::vector< Transpose > *axes)
Factors out and returns a common transpose, or default-constructed Transpose if all axes (in/out para...
mxnet::alm::ApplyTranspose
mshadow::LayoutFlag ApplyTranspose(mshadow::LayoutFlag layout, const Transpose &axes)
mxnet::alm::FChangeLayout
std::function< bool(nnvm::NodeAttrs *, mshadow::LayoutFlag target_layout, std::vector< Transpose > *in_axes, std::vector< Transpose > *out_axes)> FChangeLayout
May change an operator's layout. Used in LayoutOptimization.
Definition: alm.h:89
nnvm::Graph
Symbolic computation graph. This is the intermediate representation for optimization passes.
Definition: graph.h:47
mxnet::alm::OptimizeLayout
nnvm::Graph OptimizeLayout(nnvm::Graph &&g)
mshadow::LayoutFlag
LayoutFlag
Definition: base.h:498
mxnet::alm::FromTShape
Transpose FromTShape(const mxnet::TShape &s)
mxnet::alm::ALMParams
A singleton flag, set and read by MXSetOptimizeLayout and MXGetOptimizeLayout.
Definition: alm.h:43
nnvm::NodeAttrs
The attributes of the current operation node, usually additional parameters such as axis,...
Definition: node.h:107
mxnet::alm::Reverse
Transpose Reverse(const Transpose &axes)
mxnet::alm::Transpose
std::vector< size_t > Transpose
Transpose, represented as a permutation of axes.
Definition: alm.h:60
mxnet::alm::Compose
Transpose Compose(const Transpose &lhs, const Transpose &rhs)
graph.h
Configuration of nnvm as well as basic data structures.
mxnet::TShape
A Shape class that is used to represent the shape of each tensor.
Definition: tuple.h:440
mxnet::alm::ALMParams::optimize
bool optimize
Definition: alm.h:44
node.h
Graph node data structure.
mxnet::alm::IsIdentity
bool IsIdentity(const Transpose &t)
base.h
Configuration of MXNet as well as basic data structures.
mxnet::alm::ALMParams::get
static ALMParams & get()
Definition: alm.h:46