mxnet
pass_functions.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
28 #ifndef NNVM_PASS_FUNCTIONS_H_
29 #define NNVM_PASS_FUNCTIONS_H_
30 
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
35 
36 #include "base.h"
37 #include "graph_attr_types.h"
38 #include "pass.h"
39 
40 namespace nnvm {
41 namespace pass {
42 
48 inline Graph LoadJSON(const std::string& json_str) {
49  Graph ret;
50  ret.attrs["json"] = std::make_shared<any>(json_str);
51  return ApplyPass(ret, "LoadJSON");
52 }
53 
59 inline std::string SaveJSON(Graph graph) {
60  Graph ret = ApplyPass(std::move(graph), "SaveJSON");
61  return ret.GetAttr<std::string>("json");
62 }
63 
69 inline std::string PrintGraphIR(Graph graph) {
70  Graph ret = ApplyPass(std::move(graph), "PrintGraphIR");
71  return ret.GetAttr<std::string>("graphir");
72 }
73 
84 inline Graph OrderMutation(Graph src) { return ApplyPass(std::move(src), "OrderMutation"); }
85 
95 inline Graph InferShape(Graph graph, ShapeVector shape_inputs, std::string shape_attr_key = "") {
96  if (shape_inputs.size() != 0) {
97  graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs));
98  }
99  if (shape_attr_key.length() != 0) {
100  graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
101  }
102  return ApplyPass(std::move(graph), "InferShape");
103 }
104 
114 inline Graph InferType(Graph graph, DTypeVector dtype_inputs, std::string dtype_attr_key = "") {
115  if (dtype_inputs.size() != 0) {
116  graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs));
117  }
118  if (dtype_attr_key.length() != 0) {
119  graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
120  }
121  return ApplyPass(std::move(graph), "InferType");
122 }
123 
138 inline Graph PlaceDevice(Graph graph, std::string device_group_attr_key,
139  DeviceAssignMap device_assign_map, std::string device_copy_op) {
140  graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
141  graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
142  graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
143  return ApplyPass(std::move(graph), "PlaceDevice");
144 }
145 
161  Graph graph, std::vector<NodeEntry> ys, std::vector<NodeEntry> xs,
162  std::vector<NodeEntry> ys_out_grad,
163  std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
164  std::function<int(const Node& node)> mirror_fun = nullptr,
165  std::function<NodeEntry(const NodeEntry& src, const NodeEntry& like)> attr_hint_fun = nullptr,
166  std::vector<const Op*> zero_ops = std::vector<const Op*>(),
167  std::string copy_op_str = std::string()) {
168  graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
169 
170  graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
171  graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
172  if (aggregate_fun != nullptr) {
173  graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
174  }
175 
176  if (mirror_fun != nullptr) {
177  graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
178  }
179 
180  if (attr_hint_fun != nullptr) {
181  graph.attrs["attr_hint_fun"] = std::make_shared<any>(attr_hint_fun);
182  }
183 
184  if (zero_ops.size()) {
185  graph.attrs["zero_ops"] = std::make_shared<any>(std::move(zero_ops));
186  }
187 
188  if (copy_op_str != std::string()) {
189  graph.attrs["copy_op"] = std::make_shared<any>(std::move(copy_op_str));
190  }
191 
192  return ApplyPass(std::move(graph), "Gradient");
193 }
194 
195 } // namespace pass
196 } // namespace nnvm
197 #endif // NNVM_PASS_FUNCTIONS_H_
nnvm::pass::PlaceDevice
Graph PlaceDevice(Graph graph, std::string device_group_attr_key, DeviceAssignMap device_assign_map, std::string device_copy_op)
Place the devices for each operator in the graph.
Definition: pass_functions.h:138
nnvm::pass::LoadJSON
Graph LoadJSON(const std::string &json_str)
Load a graph from a JSON string; redirects to the "LoadJSON" pass.
Definition: pass_functions.h:48
nnvm::pass::PrintGraphIR
std::string PrintGraphIR(Graph graph)
Print the graph IR.
Definition: pass_functions.h:69
nnvm::Node
Node represents an operation in a computation graph.
Definition: node.h:143
nnvm::Graph
Symbolic computation graph. This is the intermediate representation for optimization passes.
Definition: graph.h:47
nnvm::ShapeVector
std::vector< TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: graph_attr_types.h:61
base.h
Configuration of nnvm as well as basic data structure.
nnvm::Graph::attrs
std::unordered_map< std::string, std::shared_ptr< any > > attrs
Attributes of a graph. Note that each attribute is a shared pointer and can be shared across graphs.
Definition: graph.h:61
nnvm::pass::InferType
Graph InferType(Graph graph, DTypeVector dtype_inputs, std::string dtype_attr_key="")
Infer types in the graph given the information.
Definition: pass_functions.h:114
nnvm::pass::InferShape
Graph InferShape(Graph graph, ShapeVector shape_inputs, std::string shape_attr_key="")
Infer shapes in the graph given the information.
Definition: pass_functions.h:95
graph_attr_types.h
Data structures that can appear in graph attributes.
nnvm::ApplyPass
Graph ApplyPass(Graph src, const std::string &pass)
Apply one pass to the graph.
Definition: pass.h:62
nnvm::pass::Gradient
Graph Gradient(Graph graph, std::vector< NodeEntry > ys, std::vector< NodeEntry > xs, std::vector< NodeEntry > ys_out_grad, std::function< NodeEntry(std::vector< NodeEntry > &&inputs)> aggregate_fun=nullptr, std::function< int(const Node &node)> mirror_fun=nullptr, std::function< NodeEntry(const NodeEntry &src, const NodeEntry &like)> attr_hint_fun=nullptr, std::vector< const Op * > zero_ops=std::vector< const Op * >(), std::string copy_op_str=std::string())
Get the gradient graph whose outputs are the gradients of ys with respect to xs.
Definition: pass_functions.h:160
nnvm::pass::SaveJSON
std::string SaveJSON(Graph graph)
Save a graph to JSON; redirects to the "SaveJSON" pass.
Definition: pass_functions.h:59
nnvm::pass::OrderMutation
Graph OrderMutation(Graph src)
Add control flow dependencies between nodes.
Definition: pass_functions.h:84
nnvm::Graph::GetAttr
const T & GetAttr(const std::string &attr_name) const
Get the immutable attribute from attrs.
Definition: graph.h:230
nnvm::NodeEntry
an entry that represents output data from a node
Definition: node.h:52
nnvm::DeviceAssignMap
std::unordered_map< std::string, int > DeviceAssignMap
The result holder of device of each operator in the graph.
Definition: graph_attr_types.h:112
nnvm::DTypeVector
std::vector< int > DTypeVector
The result holder of type of each NodeEntry in the graph.
Definition: graph_attr_types.h:76
nnvm
Definition: base.h:35
pass.h
Pass that can be applied to a graph.