mxnet
Functions |
Graph | LoadJSON (const std::string &json_str) |
Load a graph from a JSON string; redirects to the "LoadJSON" pass. |
std::string | SaveJSON (Graph graph) |
Save a graph to JSON; redirects to the "SaveJSON" pass. |
std::string | PrintGraphIR (Graph graph) |
Print the graph IR. |
Graph | OrderMutation (Graph src) |
Add control flow dependencies between nodes. |
Graph | InferShape (Graph graph, ShapeVector shape_inputs, std::string shape_attr_key="") |
Infer the shapes in the graph from the given information. |
Graph | InferType (Graph graph, DTypeVector dtype_inputs, std::string dtype_attr_key="") |
Infer the types in the graph from the given information. |
Graph | PlaceDevice (Graph graph, std::string device_group_attr_key, DeviceAssignMap device_assign_map, std::string device_copy_op) |
Place the devices for each operator in the graph. |
Graph | Gradient (Graph graph, std::vector< NodeEntry > ys, std::vector< NodeEntry > xs, std::vector< NodeEntry > ys_out_grad, std::function< NodeEntry(std::vector< NodeEntry > &&inputs)> aggregate_fun=nullptr, std::function< int(const Node &node)> mirror_fun=nullptr, std::function< NodeEntry(const NodeEntry &src, const NodeEntry &like)> attr_hint_fun=nullptr, std::vector< const Op * > zero_ops=std::vector< const Op * >(), std::string copy_op_str=std::string()) |
Get the gradient graph whose outputs are the gradients of ys with respect to xs. |
Gradient() | inline |
Get the gradient graph whose outputs are the gradients of ys with respect to xs.
graph | The input graph. |
ys | The entries we want to take gradient from. |
xs | The input to take gradient with respect to. |
ys_out_grad | The symbol for the additional gradient to be propagated back to y. |
aggregate_fun | Function used to aggregate multiple gradient inputs into a single entry. |
mirror_fun | Optional mirror function used for mirror optimization to save memory. |
attr_hint_fun | Optional hint function that outputs a node like src, but with the same attributes as like. |
zero_ops | Optional list of operators that output a single zero array. The first one must be zeros_like. |
copy_op_str | Optional name of the copy operation required to handle duplicates on the edge of the graph. |
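The reverse-mode accumulation that Gradient performs, including the role of aggregate_fun when one value feeds several consumers, can be sketched on a toy scalar graph. This is a minimal sketch under stated assumptions: OpKind, ToyNode and ToyGradient are hypothetical names, not nnvm's Node/NodeEntry API; nodes are assumed to be stored in topological order (inputs before consumers), and contributions are summed when no aggregate function is given.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical toy graph (not nnvm's API): each node is an input,
// or an add/multiply of two earlier nodes.
enum class OpKind { kInput, kAdd, kMul };

struct ToyNode {
  OpKind op;
  int lhs = -1;  // index of first input node (unused for kInput)
  int rhs = -1;  // index of second input node (unused for kInput)
};

// Returns d(y)/d(node) for every node, given forward values in topological order.
// When a node feeds several consumers, its gradient contributions are combined
// with `aggregate` (a plain sum by default), playing the role of aggregate_fun.
std::vector<double> ToyGradient(
    const std::vector<ToyNode>& nodes, const std::vector<double>& values,
    int y, double y_out_grad,
    std::function<double(const std::vector<double>&)> aggregate = nullptr) {
  if (!aggregate) {
    aggregate = [](const std::vector<double>& parts) {
      double sum = 0.0;
      for (double p : parts) sum += p;
      return sum;
    };
  }
  // Collect every incoming gradient contribution per node before aggregating.
  std::vector<std::vector<double>> contributions(nodes.size());
  contributions[y].push_back(y_out_grad);
  std::vector<double> grads(nodes.size(), 0.0);
  // Walk backwards: all consumers of node i have larger indices, so their
  // contributions are complete by the time node i is visited.
  for (int i = static_cast<int>(nodes.size()) - 1; i >= 0; --i) {
    if (contributions[i].empty()) continue;  // node does not influence y
    grads[i] = aggregate(contributions[i]);
    const ToyNode& n = nodes[i];
    if (n.op == OpKind::kAdd) {
      contributions[n.lhs].push_back(grads[i]);
      contributions[n.rhs].push_back(grads[i]);
    } else if (n.op == OpKind::kMul) {
      contributions[n.lhs].push_back(grads[i] * values[n.rhs]);
      contributions[n.rhs].push_back(grads[i] * values[n.lhs]);
    }
  }
  return grads;
}
```

For y = x * x + x at x = 3, node 0 receives three contributions (1, 3 and 3), and the default sum yields dy/dx = 2x + 1 = 7.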
InferShape() | inline |
Infer shapes in the graph given the information.
graph | The input graph. |
shape_inputs | The shapes of input symbols to the graph. |
shape_attr_key | The key of the node attribute that can indicate a shape. Manual shape hints can be injected through this attribute. |
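Shape inference can be pictured as a forward sweep that takes input shapes from shape_inputs, falls back to a manual hint stored under shape_attr_key, and applies a per-operator rule to derive each output shape. The sketch below is a hypothetical toy, not nnvm's implementation: ShapeNode, ParseShapeHint, ToyInferShape, the "add"/"matmul" rules and the "2,3" hint encoding are all assumptions, and an empty Shape is used to mean "unknown".

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

using Shape = std::vector<int>;  // empty means "unknown" in this sketch

// Hypothetical toy node; NOT nnvm's data structure.
struct ShapeNode {
  std::string op;                            // "input", "add", or "matmul"
  std::vector<int> inputs;                   // indices of earlier nodes
  std::map<std::string, std::string> attrs;  // free-form attributes
};

// Parses a hint such as "2,3" into {2, 3} (hypothetical hint encoding).
Shape ParseShapeHint(const std::string& text) {
  Shape s;
  size_t pos = 0;
  while (pos < text.size()) {
    size_t comma = text.find(',', pos);
    if (comma == std::string::npos) comma = text.size();
    s.push_back(std::stoi(text.substr(pos, comma - pos)));
    pos = comma + 1;
  }
  return s;
}

// Forward shape propagation over nodes stored in topological order.
// Input shapes come from shape_inputs (in order of appearance) or, failing
// that, from a manual hint stored under shape_attr_key.
std::vector<Shape> ToyInferShape(const std::vector<ShapeNode>& nodes,
                                 const std::vector<Shape>& shape_inputs,
                                 const std::string& shape_attr_key) {
  std::vector<Shape> shapes(nodes.size());
  size_t next_input = 0;
  for (size_t i = 0; i < nodes.size(); ++i) {
    const ShapeNode& n = nodes[i];
    if (n.op == "input") {
      Shape given = next_input < shape_inputs.size() ? shape_inputs[next_input] : Shape{};
      ++next_input;
      auto hint = n.attrs.find(shape_attr_key);
      if (given.empty() && hint != n.attrs.end()) given = ParseShapeHint(hint->second);
      if (given.empty()) throw std::runtime_error("shape of input unknown");
      shapes[i] = given;
    } else if (n.op == "add") {  // elementwise: both shapes must match
      const Shape& a = shapes[n.inputs[0]];
      const Shape& b = shapes[n.inputs[1]];
      if (a != b) throw std::runtime_error("add: shape mismatch");
      shapes[i] = a;
    } else if (n.op == "matmul") {  // (m, k) x (k, n) -> (m, n)
      const Shape& a = shapes[n.inputs[0]];
      const Shape& b = shapes[n.inputs[1]];
      if (a.size() != 2 || b.size() != 2 || a[1] != b[0])
        throw std::runtime_error("matmul: incompatible shapes");
      shapes[i] = {a[0], b[1]};
    } else {
      throw std::runtime_error("unknown op " + n.op);
    }
  }
  return shapes;
}
```

In this toy, only the first input's shape is passed explicitly; the second comes from its hint attribute, and the matmul output is derived as (2, 4).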
InferType() | inline |
Infer types in the graph given the information.
graph | The input graph. |
dtype_inputs | The types of input symbols to the graph. |
dtype_attr_key | The key of the node attribute that can indicate a type. Manual type hints can be injected through this attribute. |
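Type inference can also let information flow backwards: if an elementwise operator requires its inputs and output to share one type, a known output type can fill in an unknown input. The sketch below illustrates that idea with repeated sweeps until nothing changes. It is a hypothetical toy, not nnvm's implementation: TypeNode, Unify, ToyInferType, the integer type codes, and the rule that "add" unifies all of its types are assumptions.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical dtype codes for this sketch; kUnknown means "not yet known".
constexpr int kUnknown = -1;
constexpr int kFloat32 = 0;
constexpr int kFloat16 = 2;

struct TypeNode {
  std::string op;                            // "input" or "add"
  std::vector<int> inputs;                   // indices of earlier nodes
  std::map<std::string, std::string> attrs;  // may carry a dtype hint
};

// Merges a candidate dtype into slot; returns true if slot changed.
bool Unify(int& slot, int candidate) {
  if (candidate == kUnknown || slot == candidate) return false;
  if (slot != kUnknown) throw std::runtime_error("dtype conflict");
  slot = candidate;
  return true;
}

// "add" requires its inputs and output to share one dtype, so information
// flows forward (inputs -> output) and backward (output -> inputs).
// Sweeps repeat until a fixed point is reached.
std::vector<int> ToyInferType(const std::vector<TypeNode>& nodes,
                              const std::vector<int>& dtype_inputs,
                              const std::string& dtype_attr_key) {
  std::vector<int> dtypes(nodes.size(), kUnknown);
  size_t next_input = 0;
  // Seed input types from dtype_inputs and from manual hints.
  for (size_t i = 0; i < nodes.size(); ++i) {
    if (nodes[i].op != "input") continue;
    if (next_input < dtype_inputs.size()) Unify(dtypes[i], dtype_inputs[next_input]);
    ++next_input;
    auto hint = nodes[i].attrs.find(dtype_attr_key);
    if (hint != nodes[i].attrs.end()) Unify(dtypes[i], std::stoi(hint->second));
  }
  bool changed = true;
  while (changed) {
    changed = false;
    for (size_t i = 0; i < nodes.size(); ++i) {
      if (nodes[i].op != "add") continue;
      for (int in : nodes[i].inputs) changed |= Unify(dtypes[i], dtypes[in]);
      for (int in : nodes[i].inputs) changed |= Unify(dtypes[in], dtypes[i]);
    }
  }
  return dtypes;
}
```

With only the first input typed as float16, the second input is resolved backwards through the add node.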
LoadJSON() | inline |
Load a graph from a JSON string; redirects to the "LoadJSON" pass.
json_str | The JSON string. |
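"Redirects to the LoadJSON pass" means the function does no parsing itself: it packs its argument into the graph and dispatches to a pass looked up by name. The sketch below shows only that redirect pattern, with a hypothetical string-keyed registry. ToyGraph, ToyPass, PassRegistry, ApplyToyPass, ToyLoadJSON and the "json" attribute key are assumptions for illustration, not nnvm's registry API.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for a graph: just a bag of string attributes.
struct ToyGraph {
  std::map<std::string, std::string> attrs;
};

using ToyPass = std::function<ToyGraph(ToyGraph)>;

// Global name -> pass table, standing in for a real pass registry.
std::map<std::string, ToyPass>& PassRegistry() {
  static std::map<std::string, ToyPass> registry;
  return registry;
}

// Looks up a pass by name and applies it to the graph.
ToyGraph ApplyToyPass(ToyGraph g, const std::string& name) {
  auto it = PassRegistry().find(name);
  if (it == PassRegistry().end()) throw std::runtime_error("pass not registered: " + name);
  return it->second(std::move(g));
}

// The convenience wrapper only packs its argument into a graph attribute and
// dispatches by name; all real work lives in the registered pass.
ToyGraph ToyLoadJSON(const std::string& json_str) {
  ToyGraph g;
  g.attrs["json"] = json_str;
  return ApplyToyPass(std::move(g), "LoadJSON");
}
```

SaveJSON is described the same way, so in this toy it would differ only in the pass name it dispatches to.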
OrderMutation() | inline |
Add control flow dependencies between nodes.
This function enforces the correct order between writes (mutable operators) and reads (immutable operators) to resolve write-after-read and read-after-write hazards.
src | The input graph. |
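The ordering rule can be illustrated by scanning operators in program order and tracking, per variable, the last writer and the readers seen since that write. The sketch below is a hypothetical toy, not nnvm's implementation: ToyOp and ToyOrderMutation are illustrative names, and it also adds write-after-write edges, which the description above does not mention but which follow from the same bookkeeping.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical op record: which variables it reads and which it mutates.
struct ToyOp {
  std::vector<std::string> reads;
  std::vector<std::string> writes;
};

// Returns, for each op (in program order), the set of earlier ops it must wait for.
std::vector<std::set<int>> ToyOrderMutation(const std::vector<ToyOp>& ops) {
  std::vector<std::set<int>> deps(ops.size());
  std::map<std::string, int> last_writer;                 // variable -> op index
  std::map<std::string, std::vector<int>> readers_since;  // readers since last write
  for (int i = 0; i < static_cast<int>(ops.size()); ++i) {
    for (const auto& var : ops[i].reads) {
      auto w = last_writer.find(var);
      if (w != last_writer.end()) deps[i].insert(w->second);  // read-after-write
      readers_since[var].push_back(i);
    }
    for (const auto& var : ops[i].writes) {
      auto w = last_writer.find(var);
      if (w != last_writer.end()) deps[i].insert(w->second);  // write-after-write
      for (int r : readers_since[var])
        if (r != i) deps[i].insert(r);  // write-after-read
      readers_since[var].clear();
      last_writer[var] = i;
    }
  }
  return deps;
}
```

For the sequence "init w, read w, read w, update w, read w", the update must wait for the initial write and both reads, and the final read must wait for the update.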
PlaceDevice() | inline |
Place the devices for each operator in the graph.
The current device placement is quite simple. Each operator is assigned to a "group" (stored in the device_group_attr_key attribute). Each group is assigned to a device (stored in the device_assign_map attribute). Each operator is placed on the device assigned to its group, and copy operators are injected whenever a cross-device reference occurs.
graph | The input graph. |
device_group_attr_key | The attribute name for hints of device group. |
device_assign_map | The assignment map of device. |
device_copy_op | The name of the copy operator to insert when a cross-device copy is needed. |
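The group-to-device placement and copy injection described above can be sketched on a toy graph. This is a hypothetical illustration, not nnvm's implementation: PlaceNode, Placed and ToyPlaceDevice are assumed names, every node is assumed to carry a group that appears in the map, and reusing one copy per (source, target device) pair is a design choice of the sketch.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy node: op name, device-group hint, and input indices.
struct PlaceNode {
  std::string op;
  std::string group;        // value of the device-group attribute
  std::vector<int> inputs;  // indices of earlier nodes
};

struct Placed {
  std::vector<PlaceNode> nodes;  // original nodes plus inserted copies
  std::vector<int> device;       // device id per node in `nodes`
};

// Assigns each node the device mapped to its group, inserting a copy node
// whenever an input lives on a different device than its consumer.
Placed ToyPlaceDevice(const std::vector<PlaceNode>& graph,
                      const std::map<std::string, int>& device_assign_map,
                      const std::string& device_copy_op) {
  Placed out;
  std::vector<int> new_index(graph.size());    // original index -> index in out.nodes
  std::map<std::pair<int, int>, int> copies;   // (source, device) -> copy node index
  for (size_t i = 0; i < graph.size(); ++i) {
    PlaceNode node = graph[i];
    int dev = device_assign_map.at(node.group);
    for (int& in : node.inputs) {
      int src = new_index[in];
      if (out.device[src] == dev) { in = src; continue; }
      auto key = std::make_pair(src, dev);
      auto it = copies.find(key);
      if (it == copies.end()) {
        // Cross-device reference: materialize one copy per (source, device) pair.
        // The copy takes the consumer's group (a choice of this sketch).
        out.nodes.push_back({device_copy_op, node.group, {src}});
        out.device.push_back(dev);
        it = copies.emplace(key, static_cast<int>(out.nodes.size()) - 1).first;
      }
      in = it->second;
    }
    out.nodes.push_back(node);
    out.device.push_back(dev);
    new_index[i] = static_cast<int>(out.nodes.size()) - 1;
  }
  return out;
}
```

With one input in a group mapped to device 0 and an add in a group mapped to device 1, a single copy node is inserted in front of the add for the cross-device input.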
PrintGraphIR() | inline |
Print the graph IR.
graph | The graph to be printed. |
SaveJSON() | inline |
Save a graph to JSON; redirects to the "SaveJSON" pass.
graph | The graph to be saved in JSON format. |