mxnet.visualization
Visualization module
Functions
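plot_network(symbol[, title, save_format, ...]) | Creates a visualization (Graphviz digraph object) of the given computation graph.
print_summary(symbol[, shape, line_length, ...]) | Prints a layer-by-layer summary of the given symbol.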
mxnet.visualization.plot_network(symbol, title='plot', save_format='pdf', shape=None, dtype=None, node_attrs={}, hide_weights=True)
Creates a visualization (Graphviz digraph object) of the given computation graph. Graphviz must be installed for this function to work.
- Parameters
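symbol (Symbol) – A symbol from the computation graph. The generated digraph will visualize the part of the computation graph required to compute symbol.
title (str, optional) – Title of the generated visualization.
save_format (str, optional) – Output file format (default 'pdf') used when the returned digraph is rendered.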
shape (dict, optional) – Specifies the shape of the input tensors. If specified, the visualization will include the shape of the tensors between the nodes. shape is a dictionary mapping input symbol names (str) to the corresponding tensor shape (tuple).
dtype (dict, optional) – Specifies the type of the input tensors. If specified, the visualization will include the type of the tensors between the nodes. dtype is a dictionary mapping input symbol names (str) to the corresponding tensor type (e.g. numpy.float32).
node_attrs (dict, optional) – Specifies the attributes for nodes in the generated visualization. node_attrs is a dictionary of Graphviz attribute names and values. For example:
node_attrs={"shape":"oval","fixedsize":"false"}
uses an oval shape for nodes and allows variable-sized nodes in the visualization.
hide_weights (bool, optional) – If True (default), inputs whose names end in _weight (weight tensors) or _bias (bias vectors) are hidden for a cleaner visualization.
- Returns
dot – A Graphviz digraph object visualizing the computation graph to compute symbol.
- Return type
Digraph
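The hide_weights and node_attrs behavior described above can be illustrated without MXNet or the graphviz package. The sketch below is hypothetical: the helpers is_hidden and to_dot, the suffix tuple, and the edge list are illustrative assumptions that only mimic the two documented rules, and the digraph that plot_network actually returns (node labels, colors, layout) differs from this DOT text.

```python
# Hypothetical, standalone sketch: it does NOT call MXNet or the graphviz
# package. It only mimics two documented behaviors of plot_network:
#   * hide_weights drops inputs whose names end in _weight or _bias
#   * node_attrs become Graphviz attributes applied to every node
HIDDEN_SUFFIXES = ("_weight", "_bias")  # assumption: only the documented cases


def is_hidden(name, hide_weights=True):
    """Return True if an input with this name would be left out of the plot."""
    return hide_weights and name.endswith(HIDDEN_SUFFIXES)


def to_dot(edges, node_attrs, hide_weights=True, title="plot"):
    """Build DOT source for (src, dst) edges, skipping hidden weight inputs."""
    attrs = ", ".join(f'{key}="{value}"' for key, value in node_attrs.items())
    lines = [f"digraph {title} {{", f"  node [{attrs}];"]
    for src, dst in edges:
        if is_hidden(src, hide_weights):
            continue  # weight/bias inputs are omitted when hide_weights=True
        lines.append(f"  {src} -> {dst};")
    lines.append("}")
    return "\n".join(lines)


# Edges of a small two-layer network (illustrative names).
edges = [("data", "fc1"), ("fc1_weight", "fc1"), ("fc1_bias", "fc1"),
         ("fc1", "relu1"), ("relu1", "fc2"), ("fc2_weight", "fc2"),
         ("fc2_bias", "fc2"), ("fc2", "out"), ("out_label", "out")]
print(to_dot(edges, {"shape": "oval", "fixedsize": "false"}))
```

Passing hide_weights=False to to_dot keeps the fc1_weight, fc1_bias, fc2_weight and fc2_bias edges, which is the cluttered view that the default avoids.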
Example
>>> import mxnet as mx
>>> net = mx.sym.Variable('data')
>>> net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)
>>> net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
>>> net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=10)
>>> net = mx.sym.SoftmaxOutput(data=net, name='out')
>>> digraph = mx.viz.plot_network(net, shape={'data':(100,200)},
...                               node_attrs={"fixedsize":"false"})
>>> digraph.view()
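For the example network above, the tensor shapes implied by shape={'data':(100,200)} and the FullyConnected parameter counts can be worked out by hand. The following standalone sketch does not import MXNet; the layers table and the propagate helper are hypothetical, and it assumes the usual FullyConnected layout of in_features × num_hidden weights plus num_hidden biases.

```python
# Standalone arithmetic sketch (no MXNet import). It propagates the input
# shape (batch=100, features=200) through the example network and counts
# FullyConnected parameters as in_features * num_hidden + num_hidden.
layers = [
    # (name, operator, num_hidden) -- hypothetical description of the example
    ("fc1", "FullyConnected", 128),
    ("relu1", "Activation", None),
    ("fc2", "FullyConnected", 10),
    ("out", "SoftmaxOutput", None),
]


def propagate(data_shape, layers):
    """Return (name, output_shape, num_params) for each layer in order."""
    batch, features = data_shape
    rows = []
    for name, op, num_hidden in layers:
        params = 0
        if op == "FullyConnected":
            params = features * num_hidden + num_hidden  # weights + biases
            features = num_hidden
        # Activation and SoftmaxOutput leave the shape unchanged.
        rows.append((name, (batch, features), params))
    return rows


rows = propagate((100, 200), layers)
for name, shape, params in rows:
    print(f"{name:<6} {str(shape):<12} {params}")
print("Total params:", sum(params for _, _, params in rows))
```

Under these assumptions fc1 has 200 × 128 + 128 = 25728 parameters and fc2 has 128 × 10 + 10 = 1290, for 27018 in total; the activation and softmax layers add none.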
Notes
If mxnet is imported, the visualization module can be used in its short form. For example, if we import mxnet as follows:
import mxnet
this method in the visualization module can be used in its short form as:
mxnet.viz.plot_network(...)
mxnet.visualization.print_summary(symbol, shape=None, line_length=120, positions=[0.44, 0.64, 0.74, 1.0])
Prints a layer-by-layer summary of the given symbol: each layer's name and type, output shape, parameter count, and the layer(s) feeding into it.
- Parameters
symbol (Symbol) – Symbol to be visualized.
shape (dict, optional) – A dictionary mapping input names (str) to the given input shapes (tuple).
line_length (int) – Total length of the printed lines.
positions (list) – Relative (fractions of line_length) or absolute positions of the column boundaries in each printed line.
- Returns
None; the summary is printed to standard output.
- Return type
None
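The effect of line_length and positions can be sketched in plain Python. This is an assumption-based illustration, not MXNet's code: it assumes that a positions list whose last value is at most 1 is scaled by line_length, and that each field is truncated or padded so that it ends at its column boundary. The helpers column_bounds and format_row and the header names are illustrative.

```python
# Hypothetical sketch of how line_length and positions could lay out one row
# of a summary table. Assumptions: positions whose last value is <= 1 are
# fractions of line_length; each field is cut or padded to end at its bound.
def column_bounds(line_length=120, positions=(0.44, 0.64, 0.74, 1.0)):
    """Convert relative positions into absolute column end offsets."""
    if positions[-1] <= 1:
        return [int(line_length * p) for p in positions]
    return list(positions)  # already absolute character offsets


def format_row(fields, bounds):
    """Left-align each field so that it ends exactly at its column bound."""
    line = ""
    for field, end in zip(fields, bounds):
        line += str(field)
        line = line[:end]                # truncate an overly long field
        line += " " * (end - len(line))  # pad up to the column boundary
    return line


bounds = column_bounds()
print(bounds)
print(format_row(["Layer (type)", "Output Shape", "Param #", "Previous Layer"], bounds))
```

With the defaults, the relative positions map to column ends at characters 52, 76, 88 and 120, so every printed row is exactly line_length characters wide.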
Notes
If mxnet is imported, the visualization module can be used in its short form. For example, if we import mxnet as follows:
import mxnet
this method in the visualization module can be used in its short form as:
mxnet.viz.print_summary(...)