mxnet
nnvm::pass Namespace Reference

Functions

Graph LoadJSON (const std::string &json_str)
 Load a graph from a JSON string; redirects to the "LoadJSON" pass.
 
std::string SaveJSON (Graph graph)
 Save a graph to JSON; redirects to the "SaveJSON" pass.
 
std::string PrintGraphIR (Graph graph)
 Print the graph IR.
 
Graph OrderMutation (Graph src)
 Add control flow dependencies between nodes.
 
Graph InferShape (Graph graph, ShapeVector shape_inputs, std::string shape_attr_key="")
 Infer shapes in the graph from the given information.
 
Graph InferType (Graph graph, DTypeVector dtype_inputs, std::string dtype_attr_key="")
 Infer types in the graph from the given information.
 
Graph PlaceDevice (Graph graph, std::string device_group_attr_key, DeviceAssignMap device_assign_map, std::string device_copy_op)
 Assign a device to each operator in the graph.
 
Graph Gradient (Graph graph, std::vector< NodeEntry > ys, std::vector< NodeEntry > xs, std::vector< NodeEntry > ys_out_grad, std::function< NodeEntry(std::vector< NodeEntry > &&inputs)> aggregate_fun=nullptr, std::function< int(const Node &node)> mirror_fun=nullptr, std::function< NodeEntry(const NodeEntry &src, const NodeEntry &like)> attr_hint_fun=nullptr, std::vector< const Op * > zero_ops=std::vector< const Op * >(), std::string copy_op_str=std::string())
 Get the gradient graph whose outputs are the gradients of ys with respect to xs.
 

Function Documentation

Graph nnvm::pass::Gradient ( Graph  graph,
std::vector< NodeEntry >  ys,
std::vector< NodeEntry >  xs,
std::vector< NodeEntry >  ys_out_grad,
std::function< NodeEntry(std::vector< NodeEntry > &&inputs)>  aggregate_fun = nullptr,
std::function< int(const Node &node)>  mirror_fun = nullptr,
std::function< NodeEntry(const NodeEntry &src, const NodeEntry &like)>  attr_hint_fun = nullptr,
std::vector< const Op * >  zero_ops = std::vector<const Op*>(),
std::string  copy_op_str = std::string() 
)
inline

Get the gradient graph whose outputs are the gradients of ys with respect to xs.

Parameters
graph: The input graph.
ys: The entries to take the gradient from.
xs: The inputs to take the gradient with respect to.
ys_out_grad: The symbols for the additional gradients to be propagated back to ys.
aggregate_fun: Optional; function used to aggregate multiple gradient inputs that reach the same entry.
mirror_fun: Optional; mirror function used to perform mirror optimization and save memory.
attr_hint_fun: Optional; hint function that outputs a node like src, but whose attributes are the same as those of like.
zero_ops: Optional; list of operators that output a single zero array. The first one must be zeros_like.
copy_op_str: Optional; name of the copy operation required to handle duplicates on the edge of the graph.
Returns
A new graph whose outputs correspond to xs: one gradient output per entry of xs.
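The reverse sweep behind a gradient pass can be illustrated numerically. The sketch below is a standalone toy, not nnvm code: ToyNode, Op, Forward and ToyGradient are hypothetical names, every node is assumed to produce a single scalar, and gradients are computed as numbers rather than built as new graph nodes. The parameters y_out_grad and agg are named by analogy with ys_out_grad and aggregate_fun: when an entry feeds several consumers, their gradient contributions are collected and then combined (summed by default).

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical toy IR: every node produces exactly one scalar output.
enum class Op { kVar, kAdd, kMul };

struct ToyNode {
  Op op;
  std::vector<int> inputs;  // indices of earlier nodes (topological order)
  double value;             // only meaningful for kVar
};

// Combines the gradient contributions that reach one node.
using Aggregator = std::function<double(const std::vector<double>&)>;

// Forward evaluation; nodes are assumed to be topologically sorted.
std::vector<double> Forward(const std::vector<ToyNode>& g) {
  std::vector<double> v(g.size());
  for (std::size_t i = 0; i < g.size(); ++i) {
    const ToyNode& n = g[i];
    switch (n.op) {
      case Op::kVar: v[i] = n.value; break;
      case Op::kAdd: v[i] = v[n.inputs[0]] + v[n.inputs[1]]; break;
      case Op::kMul: v[i] = v[n.inputs[0]] * v[n.inputs[1]]; break;
    }
  }
  return v;
}

// Reverse sweep from output y: returns dy/dx for each x in xs.
// y_out_grad seeds the sweep; agg defaults to summation.
std::vector<double> ToyGradient(const std::vector<ToyNode>& g, int y,
                                const std::vector<int>& xs,
                                double y_out_grad = 1.0,
                                Aggregator agg = nullptr) {
  if (!agg) {
    agg = [](const std::vector<double>& parts) {
      return std::accumulate(parts.begin(), parts.end(), 0.0);
    };
  }
  const std::vector<double> v = Forward(g);
  std::vector<std::vector<double>> parts(g.size());  // pending contributions
  std::vector<double> grad(g.size(), 0.0);
  parts[y].push_back(y_out_grad);
  // Inputs always precede their consumers, so a descending sweep sees every
  // contribution to a node before that node is processed.
  for (int i = y; i >= 0; --i) {
    if (parts[i].empty()) continue;  // node does not influence y
    grad[i] = agg(parts[i]);
    const ToyNode& n = g[i];
    if (n.op == Op::kAdd) {
      parts[n.inputs[0]].push_back(grad[i]);
      parts[n.inputs[1]].push_back(grad[i]);
    } else if (n.op == Op::kMul) {
      parts[n.inputs[0]].push_back(grad[i] * v[n.inputs[1]]);
      parts[n.inputs[1]].push_back(grad[i] * v[n.inputs[0]]);
    }
  }
  std::vector<double> out;
  for (int x : xs) out.push_back(grad[x]);
  return out;
}
```

For y = x * x + x at x = 3, the three contributions reaching x (1 from the add, 3 and 3 from the two multiply inputs) are summed into 7, which matches dy/dx = 2x + 1.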
Graph nnvm::pass::InferShape ( Graph  graph,
ShapeVector  shape_inputs,
std::string  shape_attr_key = "" 
)
inline

Infer shapes in the graph from the given information.

Parameters
graph: The input graph.
shape_inputs: The shapes of the input symbols to the graph.
shape_attr_key: The key of the node attribute that can indicate shape. Manual shape hints can be injected through this attribute.
Returns
A graph with a new attribute "shape" containing the inferred shape of each NodeEntry. The index into the ShapeVector is given by graph.indexed_graph().entry_id.
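As a rough illustration of how inferred shapes end up in a vector indexed by entry id, here is a forward-only sketch over a hypothetical toy graph (ShapeNode, Kind, ShapeVec and ToyInferShape are illustrative names, not nnvm types). It assumes one output per node, so the entry id equals the node index, and it hard-codes two shape rules; a real pass would instead rely on shape-inference functions registered per operator.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int>;       // empty vector means "unknown"
using ShapeVec = std::vector<Shape>;  // indexed by entry id

enum class Kind { kInput, kAdd, kMatMul };

struct ShapeNode {
  Kind kind;
  std::vector<int> inputs;  // entry ids of the inputs
};

// Forward-only inference. Each node has one output, so entry id == node
// index. Graph inputs take their shapes from shape_inputs in order of
// appearance; entries whose inputs are still unknown stay unknown.
ShapeVec ToyInferShape(const std::vector<ShapeNode>& g,
                       const ShapeVec& shape_inputs) {
  ShapeVec shapes(g.size());
  std::size_t next_input = 0;
  for (std::size_t i = 0; i < g.size(); ++i) {
    const ShapeNode& n = g[i];
    if (n.kind == Kind::kInput) {
      if (next_input < shape_inputs.size()) shapes[i] = shape_inputs[next_input];
      ++next_input;
      continue;
    }
    const Shape& a = shapes[n.inputs[0]];
    const Shape& b = shapes[n.inputs[1]];
    if (a.empty() || b.empty()) continue;
    if (n.kind == Kind::kAdd) {  // elementwise: shapes must match
      if (a != b) throw std::runtime_error("add: shape mismatch");
      shapes[i] = a;
    } else {  // kMatMul: [m, k] x [k, n] -> [m, n]
      if (a.size() != 2 || b.size() != 2 || a[1] != b[0])
        throw std::runtime_error("matmul: incompatible shapes");
      shapes[i] = Shape{a[0], b[1]};
    }
  }
  return shapes;
}
```

An empty Shape marks an entry that could not be inferred: if only the first input shape is supplied, the matmul and add outputs stay unknown instead of raising an error.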
Graph nnvm::pass::InferType ( Graph  graph,
DTypeVector  dtype_inputs,
std::string  dtype_attr_key = "" 
)
inline

Infer types in the graph from the given information.

Parameters
graph: The input graph.
dtype_inputs: The types of the input symbols to the graph.
dtype_attr_key: The key of the node attribute that can indicate types. Manual type hints can be injected through this attribute.
Returns
A graph with a new attribute "dtype" containing the inferred type of each NodeEntry. The index into the DTypeVector is given by graph.indexed_graph().entry_id.
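In a sketch like the one below, a known type can also flow backwards, for example from a manual hint on an output to inputs whose type was not given. This is a hypothetical fixed-point propagation (TypeNode, ToyInferType and the dtype codes are assumptions for illustration, not nnvm definitions): every non-input node is treated as an elementwise op whose inputs and output must share one dtype, and the initial per-entry vector stands in for both the input types and any manual hints.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical dtype codes; kUnknown marks an entry whose type is not known yet.
constexpr int kUnknown = -1;
constexpr int kFloat32 = 0;
constexpr int kFloat16 = 2;

struct TypeNode {
  std::vector<int> inputs;  // entry ids of the inputs; empty for graph inputs
};

// Every non-input node is treated as an elementwise op: its output and all of
// its inputs must share one dtype. `dtypes` holds per-entry starting types
// (input types plus any manual hints); known types are pushed forward and
// backward until a full sweep changes nothing.
std::vector<int> ToyInferType(const std::vector<TypeNode>& g, std::vector<int> dtypes) {
  dtypes.resize(g.size(), kUnknown);
  bool changed = true;
  while (changed) {
    changed = false;
    for (std::size_t i = 0; i < g.size(); ++i) {
      const std::vector<int>& ins = g[i].inputs;
      if (ins.empty()) continue;
      int t = dtypes[i];  // pick any known type among output and inputs
      for (int in : ins) {
        if (t == kUnknown) t = dtypes[in];
      }
      if (t == kUnknown) continue;
      auto assign = [&](std::size_t e) {
        if (dtypes[e] == kUnknown) {
          dtypes[e] = t;
          changed = true;
        } else if (dtypes[e] != t) {
          throw std::runtime_error("dtype mismatch");
        }
      };
      assign(i);
      for (int in : ins) assign(static_cast<std::size_t>(in));
    }
  }
  return dtypes;
}
```

Each sweep can only turn unknown entries into known ones, so the loop terminates; conflicting known types are reported instead of being silently overwritten.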
Graph nnvm::pass::LoadJSON ( const std::string &  json_str)
inline

Load a graph from a JSON string; redirects to the "LoadJSON" pass.

Parameters
json_str: The JSON string.
Returns
Loaded graph.
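"Redirects to the LoadJSON pass" describes a thin-wrapper pattern: the function packs its argument into a graph attribute and hands the graph to a named pass looked up in a registry. A minimal sketch of that pattern, assuming a string-valued attribute map and a hypothetical registry (ToyGraph, PassRegistry, ApplyToyPass and ToyLoadJSON are illustrative names, and the attribute name "json" is an assumption of this sketch):

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for a graph; this toy keeps every attribute as a string.
struct ToyGraph {
  std::map<std::string, std::string> attrs;
};

using ToyPass = std::function<ToyGraph(ToyGraph)>;

// Process-wide table of passes, keyed by pass name.
std::map<std::string, ToyPass>& PassRegistry() {
  static std::map<std::string, ToyPass> registry;
  return registry;
}

// Looks up a pass by name and runs it on the graph.
ToyGraph ApplyToyPass(ToyGraph g, const std::string& name) {
  auto it = PassRegistry().find(name);
  if (it == PassRegistry().end()) {
    throw std::runtime_error("unknown pass: " + name);
  }
  return it->second(std::move(g));
}

// Wrapper in the redirecting style: stash the argument in a graph attribute
// (named "json" here by assumption) and let the registered pass do the work.
ToyGraph ToyLoadJSON(const std::string& json_str) {
  ToyGraph g;
  g.attrs["json"] = json_str;
  return ApplyToyPass(std::move(g), "LoadJSON");
}
```

Nothing is registered by default in this sketch; a caller first registers a "LoadJSON" entry, for example a lambda that reads g.attrs["json"] and fills in the parsed graph. Keeping the wrapper this thin means the parsing logic lives in exactly one place, the registered pass.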
Graph nnvm::pass::OrderMutation ( Graph  src)
inline

Add control flow dependencies between nodes.

This function enforces the correct order between writes (mutable operators) and reads (immutable operators) to solve write-after-read and read-after-write problems.

Parameters
src: The input graph.
Returns
A graph with proper control flow dependencies added.
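The ordering rules described above can be sketched on a flat list of operations in program order. MutOp and OrderDeps are hypothetical names: each op simply lists the variables it reads and mutates, and the result is, for every op, the set of earlier ops it must wait for. The pass documented here adds such dependencies between graph nodes; this sketch only computes them.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical op record: which variables an op reads and which it mutates.
struct MutOp {
  std::vector<std::string> reads;
  std::vector<std::string> writes;
};

// For ops listed in program order, returns for each op the set of earlier ops
// it must wait for:
//   read-after-write : a read waits for the last writer of that variable;
//   write-after-read : a write waits for every reader since the last write;
//   write-after-write: a write also waits for the previous writer.
std::vector<std::set<int>> OrderDeps(const std::vector<MutOp>& ops) {
  std::map<std::string, int> last_write;
  std::map<std::string, std::vector<int>> readers_since_write;
  std::vector<std::set<int>> deps(ops.size());
  for (int i = 0; i < static_cast<int>(ops.size()); ++i) {
    for (const std::string& v : ops[i].reads) {
      auto w = last_write.find(v);
      if (w != last_write.end()) deps[i].insert(w->second);
      readers_since_write[v].push_back(i);
    }
    for (const std::string& v : ops[i].writes) {
      auto w = last_write.find(v);
      if (w != last_write.end()) deps[i].insert(w->second);
      for (int r : readers_since_write[v]) {
        if (r != i) deps[i].insert(r);  // an op never waits for itself
      }
      readers_since_write[v].clear();
      last_write[v] = i;
    }
  }
  return deps;
}
```

An op that both reads and writes x (an in-place update) waits for the last writer of x and for every other reader since that write, but not for itself.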
Graph nnvm::pass::PlaceDevice ( Graph  graph,
std::string  device_group_attr_key,
DeviceAssignMap  device_assign_map,
std::string  device_copy_op 
)
inline

Assign a device to each operator in the graph.

The current device placement is simple. Each operator is assigned to a "group" (stored in the device_group_attr_key attribute), and each group is assigned to a device (stored in the device_assign_map attribute). Each operator is placed on the device assigned to its group, and copy operators are injected wherever a cross-device reference occurs.

Parameters
graph: The input graph.
device_group_attr_key: The attribute name for device group hints.
device_assign_map: The group-to-device assignment map.
device_copy_op: The name of the copy op to insert when a cross-device copy is needed.
Returns
A graph with a new attribute "device" containing the device information of each node.
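The group-to-device scheme can be sketched as follows. PNode, Placed and ToyPlaceDevice are hypothetical names; device ids are plain ints looked up through a group-to-device map, and a node named by copy_op is injected on each edge that crosses devices. Sharing one copy per (source, target device) pair is a design choice of this sketch, not something the documentation specifies.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical node: op name, device group, and input node indices
// (nodes are in topological order).
struct PNode {
  std::string op;
  std::string group;
  std::vector<int> inputs;
};

struct Placed {
  std::vector<PNode> nodes;  // original nodes plus injected copies
  std::vector<int> device;   // device id of every node in `nodes`
};

// Places each node on the device of its group and injects a `copy_op` node on
// every edge whose endpoints live on different devices.
Placed ToyPlaceDevice(const std::vector<PNode>& g,
                      const std::map<std::string, int>& device_assign_map,
                      const std::string& copy_op) {
  Placed out;
  std::vector<int> new_id(g.size());          // old index -> new index
  std::map<std::pair<int, int>, int> copies;  // (new src, device) -> copy node
  for (std::size_t i = 0; i < g.size(); ++i) {
    PNode n = g[i];
    const int dev = device_assign_map.at(n.group);  // throws if group unmapped
    for (int& in : n.inputs) {
      int src = new_id[in];
      if (out.device[src] != dev) {  // cross-device reference
        auto key = std::make_pair(src, dev);
        auto it = copies.find(key);
        if (it == copies.end()) {
          out.nodes.push_back({copy_op, n.group, {src}});
          out.device.push_back(dev);
          it = copies.emplace(key, static_cast<int>(out.nodes.size()) - 1).first;
        }
        src = it->second;
      }
      in = src;
    }
    out.nodes.push_back(n);
    out.device.push_back(dev);
    new_id[i] = static_cast<int>(out.nodes.size()) - 1;
  }
  return out;
}
```

With x in group "cpu" (device 0) feeding two consumers in group "gpu" (device 1), a single copy node is inserted and both consumers read x through it.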
std::string nnvm::pass::PrintGraphIR ( Graph  graph)
inline

Print the graph IR.

Parameters
graph: The graph to be printed.
Returns
The graph IR string.
std::string nnvm::pass::SaveJSON ( Graph  graph)
inline

Save a graph to JSON; redirects to the "SaveJSON" pass.

Parameters
graph: The graph to be saved in JSON format.
Returns
The json string.
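To make "save a graph to JSON" concrete, here is a toy serializer that writes nodes and output heads to a compact JSON string. The layout (a "nodes" array with op, name and inputs, plus a "heads" array of [node, output] pairs) is an illustrative assumption, not a specification of what the SaveJSON pass emits; JNode and ToySaveJSON are hypothetical names, and names are assumed to need no JSON escaping.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical node: op name ("null" for variables here), node name, and
// inputs given as (source node index, output index) pairs.
struct JNode {
  std::string op;
  std::string name;
  std::vector<std::pair<int, int>> inputs;
};

// Serializes nodes (in topological order) and output heads into compact JSON.
std::string ToySaveJSON(const std::vector<JNode>& nodes,
                        const std::vector<std::pair<int, int>>& heads) {
  // Writes a list of entries as [[node,index],...].
  auto entries = [](const std::vector<std::pair<int, int>>& es) {
    std::ostringstream os;
    os << "[";
    for (std::size_t i = 0; i < es.size(); ++i) {
      if (i) os << ",";
      os << "[" << es[i].first << "," << es[i].second << "]";
    }
    os << "]";
    return os.str();
  };
  std::ostringstream os;
  os << "{\"nodes\":[";
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    if (i) os << ",";
    os << "{\"op\":\"" << nodes[i].op << "\",\"name\":\"" << nodes[i].name
       << "\",\"inputs\":" << entries(nodes[i].inputs) << "}";
  }
  os << "],\"heads\":" << entries(heads) << "}";
  return os.str();
}
```

Storing edges as index pairs rather than pointers is what makes the graph serializable: a loader can create all nodes first and then rebuild every edge by index.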