28 #ifndef NNVM_PASS_FUNCTIONS_H_ 29 #define NNVM_PASS_FUNCTIONS_H_ 49 ret.
attrs[
"json"] = std::make_shared<any>(json_str);
60 return ret.
GetAttr<std::string>(
"json");
71 return ret.
GetAttr<std::string>(
"graphir");
85 return ApplyPass(std::move(src),
"OrderMutation");
99 std::string shape_attr_key =
"") {
100 if (shape_inputs.size() != 0) {
101 graph.
attrs[
"shape_inputs"] = std::make_shared<any>(std::move(shape_inputs));
103 if (shape_attr_key.length() != 0) {
104 graph.
attrs[
"shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
106 return ApplyPass(std::move(graph),
"InferShape");
120 std::string dtype_attr_key =
"") {
121 if (dtype_inputs.size() != 0) {
122 graph.
attrs[
"dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs));
124 if (dtype_attr_key.length() != 0) {
125 graph.
attrs[
"dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
127 return ApplyPass(std::move(graph),
"InferType");
145 std::string device_group_attr_key,
147 std::string device_copy_op) {
148 graph.
attrs[
"device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
149 graph.
attrs[
"device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
150 graph.
attrs[
"device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
151 return ApplyPass(std::move(graph),
"PlaceDevice");
171 std::vector<NodeEntry> ys,
172 std::vector<NodeEntry> xs,
173 std::vector<NodeEntry> ys_out_grad,
174 std::function<
NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun =
nullptr,
175 std::function<
int(
const Node& node)> mirror_fun =
nullptr,
177 attr_hint_fun =
nullptr,
178 std::vector<const Op*> zero_ops = std::vector<const Op*>(),
179 std::string copy_op_str = std::string()) {
180 graph.
attrs[
"grad_ys"] = std::make_shared<any>(std::move(ys));
182 graph.
attrs[
"grad_xs"] = std::make_shared<any>(std::move(xs));
183 graph.
attrs[
"grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
184 if (aggregate_fun !=
nullptr) {
185 graph.
attrs[
"grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
188 if (mirror_fun !=
nullptr) {
189 graph.
attrs[
"grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
192 if (attr_hint_fun !=
nullptr) {
193 graph.
attrs[
"attr_hint_fun"] = std::make_shared<any>(attr_hint_fun);
196 if (zero_ops.size()) {
197 graph.
attrs[
"zero_ops"] = std::make_shared<any>(std::move(zero_ops));
200 if (copy_op_str != std::string()) {
201 graph.
attrs[
"copy_op"] = std::make_shared<any>(std::move(copy_op_str));
204 return ApplyPass(std::move(graph),
"Gradient");
209 #endif // NNVM_PASS_FUNCTIONS_H_
Graph InferType(Graph graph, DTypeVector dtype_inputs, std::string dtype_attr_key="")
Infer types in the graph given the information.
Definition: pass_functions.h:118
Graph Gradient(Graph graph, std::vector< NodeEntry > ys, std::vector< NodeEntry > xs, std::vector< NodeEntry > ys_out_grad, std::function< NodeEntry(std::vector< NodeEntry > &&inputs)> aggregate_fun=nullptr, std::function< int(const Node &node)> mirror_fun=nullptr, std::function< NodeEntry(const NodeEntry &src, const NodeEntry &like)> attr_hint_fun=nullptr, std::vector< const Op * > zero_ops=std::vector< const Op * >(), std::string copy_op_str=std::string())
Get the gradient graph whose outputs are the gradients of ys with respect to xs.
Definition: pass_functions.h:169
Graph OrderMutation(Graph src)
Add control flow dependencies between nodes.
Definition: pass_functions.h:84
Graph ApplyPass(Graph src, const std::string &pass)
Apply one pass to the graph.
Definition: pass.h:62
std::unordered_map< std::string, int > DeviceAssignMap
The result holder for the device assigned to each operator in the graph.
Definition: graph_attr_types.h:111
std::unordered_map< std::string, std::shared_ptr< any > > attrs
Attributes of a graph. Note that each attribute is a shared pointer and can be shared across graphs...
Definition: graph.h:60
Data structures that can appear in graph attributes.
Node represents an operation in a computation graph.
Definition: node.h:155
Graph PlaceDevice(Graph graph, std::string device_group_attr_key, DeviceAssignMap device_assign_map, std::string device_copy_op)
Place the devices for each operator in the graph.
Definition: pass_functions.h:144
Symbolic computation graph. This is the intermediate representation for optimization pass...
Definition: graph.h:46
std::string PrintGraphIR(Graph graph)
Print graph ir.
Definition: pass_functions.h:69
std::vector< TShape > ShapeVector
The result holder for the shape of each NodeEntry in the graph.
Definition: graph_attr_types.h:60
An entry that represents output data from a node.
Definition: node.h:51
std::vector< int > DTypeVector
The result holder for the type of each NodeEntry in the graph.
Definition: graph_attr_types.h:75
Pass that can be applied to a graph.
const T & GetAttr(const std::string &attr_name) const
Get the immutable attribute from attrs.
Definition: graph.h:247
Graph LoadJSON(const std::string &json_str)
Load a graph from a JSON string; redirects to the "LoadJSON" pass.
Definition: pass_functions.h:47
Graph InferShape(Graph graph, ShapeVector shape_inputs, std::string shape_attr_key="")
Infer shapes in the graph given the information.
Definition: pass_functions.h:97
Configuration of nnvm as well as basic data structure.
std::string SaveJSON(Graph graph)
Save a graph to JSON; redirects to the "SaveJSON" pass.
Definition: pass_functions.h:58