mxnet
pass.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef NNVM_PASS_H_
25 #define NNVM_PASS_H_
26 
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include "base.h"
#include "graph.h"
32 
33 namespace nnvm {
34 
46 typedef std::function<Graph(Graph src)> PassFunction;
47 
54 Graph ApplyPasses(Graph src, const std::vector<std::string>& passes);
55 
62 inline Graph ApplyPass(Graph src, const std::string& pass) { return ApplyPasses(src, {pass}); }
63 
67 struct PassFunctionReg : public dmlc::FunctionRegEntryBase<PassFunctionReg, PassFunction> {
72  bool change_graph{false};
74  std::vector<std::string> op_attr_dependency;
76  std::vector<std::string> graph_attr_dependency;
78  std::vector<std::string> graph_attr_targets;
84  PassFunctionReg& set_change_graph(bool v) { // NOLINT(*)
85  change_graph = v;
86  return *this;
87  }
94  PassFunctionReg& provide_graph_attr(const std::string& attr_name) { // NOLINT(*)
95  graph_attr_targets.push_back(attr_name);
96  return *this;
97  }
104  PassFunctionReg& depend_op_attr(const std::string& attr_name) { // NOLINT(*)
105  op_attr_dependency.push_back(attr_name);
106  return *this;
107  }
114  PassFunctionReg& depend_graph_attr(const std::string& attr_name) { // NOLINT(*)
115  graph_attr_dependency.push_back(attr_name);
116  return *this;
117  }
118 };
119 
/*!
 * \def NNVM_REGISTER_PASS
 * \brief Register a pass entry of type ::nnvm::PassFunctionReg under `name`.
 *
 * Expands to DMLC_REGISTRY_REGISTER for the PassFunctionReg registry.
 * NOTE(review): the expansion presumably yields a PassFunctionReg& so the
 * fluent setters (set_change_graph, provide_graph_attr, ...) can be chained
 * onto it -- confirm against dmlc registry.h.
 *
 * \param name Identifier of the pass being registered.
 */
#define NNVM_REGISTER_PASS(name) DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name)
138 
139 } // namespace nnvm
140 
141 #endif // NNVM_PASS_H_
dmlc::FunctionRegEntryBase
Common base class for function registry.
Definition: registry.h:151
nnvm::ApplyPasses
Graph ApplyPasses(Graph src, const std::vector< std::string > &passes)
Apply a series of pass transformations on the input graph.
nnvm::PassFunctionReg::depend_graph_attr
PassFunctionReg & depend_graph_attr(const std::string &attr_name)
Declare that this pass requires the given graph attribute to be available before being applied on the graph.
Definition: pass.h:114
nnvm::Graph
Symbolic computation graph. This is the intermediate representation for optimization passes.
Definition: graph.h:47
nnvm::PassFunctionReg::graph_attr_dependency
std::vector< std::string > graph_attr_dependency
dependencies on attributes in the graph
Definition: pass.h:76
nnvm::PassFunctionReg::provide_graph_attr
PassFunctionReg & provide_graph_attr(const std::string &attr_name)
Declare that this pass will generate the given graph attribute name once it is applied on the graph.
Definition: pass.h:94
nnvm::PassFunctionReg::change_graph
bool change_graph
Whether the pass will change graph structure. If this is false, the pass will only change attributes.
Definition: pass.h:72
base.h
Configuration of nnvm as well as basic data structure.
nnvm::PassFunctionReg::set_change_graph
PassFunctionReg & set_change_graph(bool v)
Set whether this pass will change graph structure.
Definition: pass.h:84
nnvm::PassFunctionReg::depend_op_attr
PassFunctionReg & depend_op_attr(const std::string &attr_name)
Declare that this pass requires the given operator attribute to be available before being applied on the graph.
Definition: pass.h:104
nnvm::ApplyPass
Graph ApplyPass(Graph src, const std::string &pass)
Apply one pass to the graph.
Definition: pass.h:62
nnvm::PassFunctionReg::graph_attr_targets
std::vector< std::string > graph_attr_targets
generated targets of graph attributes
Definition: pass.h:78
nnvm::PassFunctionReg
Registry entry for pass functions.
Definition: pass.h:67
nnvm::PassFunction
std::function< Graph(Graph src)> PassFunction
A PassFunction is an "Operator on Graph". It takes a source graph and returns a graph that may or may ...
Definition: pass.h:46
graph.h
Configuration of nnvm as well as basic data structure.
nnvm::PassFunctionReg::op_attr_dependency
std::vector< std::string > op_attr_dependency
dependencies on operator attributes
Definition: pass.h:74
nnvm
Definition: base.h:35