mxnet.executor_manager
Executor manager.
Classes
DataParallelExecutorGroup – A group of executors living on different devices, for data parallelization.
DataParallelExecutorManager – Helper class to manage multiple executors for data parallelism.
class mxnet.executor_manager.DataParallelExecutorGroup(sym, arg_names, param_names, ctx, slices, train_data, shared_group=None)
Bases: object
A group of executors living on different devices, for data parallelization.
Parameters
sym (Symbol) – The network configuration.
arg_names (list of str) – Equals sym.list_arguments().
param_names (list of str) – List of names of all trainable parameters.
ctx (list of Context) – List of devices for training (data parallelization).
slices (list of int) – Describes how the data parallelization splits data into different devices.
train_data (DataIter (or DataBatch)) – The dataset for training. It can be any object with provide_data and provide_label properties. The actual data does not need to be loaded at this stage.
shared_group (DataParallelExecutorGroup, optional) – An existing executor group to share parameters with, if any.
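To illustrate what slices does, here is a minimal plain-Python sketch (not MXNet code) of cutting one batch into per-device pieces. It assumes each device's share is represented as a Python slice object over the batch axis; the helper split_batch is hypothetical and not part of mxnet.executor_manager.

```python
def split_batch(batch, slices):
    """Return one sub-list of `batch` per device, cut by the given slices.

    Hypothetical helper: each element of `slices` is assumed to be a Python
    slice object selecting a contiguous range along the batch axis.
    """
    return [batch[s] for s in slices]


# Example: a batch of 6 samples split across two devices (4 + 2 samples).
batch = ["x0", "x1", "x2", "x3", "x4", "x5"]
slices = [slice(0, 4), slice(4, 6)]
per_device = split_batch(batch, slices)
print(per_device)  # [['x0', 'x1', 'x2', 'x3'], ['x4', 'x5']]
```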
Methods
backward() – Perform a backward pass on each executor.
forward([is_train]) – Perform a forward pass on each executor.
load_data_batch(data_batch) – Load data and labels into arrays.
update_metric(metric, labels[, pre_sliced]) – Update the evaluation metric with the labels and the current outputs.
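The pre_sliced flag of update_metric decides whether the labels still need to be cut per device. The following plain-Python sketch assumes labels is either one full-batch list (pre_sliced=False) or one list per device (pre_sliced=True), and uses a single label list instead of several label arrays. The Accuracy class and the update_metric function here are hypothetical stand-ins, not MXNet's metric classes or the method itself.

```python
class Accuracy:
    """Toy accuracy metric: a hypothetical stand-in for an MXNet metric."""

    def __init__(self):
        self.num_correct = 0
        self.num_total = 0

    def update(self, labels, preds):
        # Count exact matches between true labels and predicted classes.
        for label, pred in zip(labels, preds):
            self.num_correct += int(label == pred)
            self.num_total += 1

    def get(self):
        return self.num_correct / self.num_total


def update_metric(metric, labels, outputs_per_device, slices, pre_sliced=False):
    """Feed every device's outputs, with the matching labels, into `metric`."""
    for i, (outputs, islice) in enumerate(zip(outputs_per_device, slices)):
        # Full-batch labels are cut per device; pre-sliced labels are indexed.
        device_labels = labels[i] if pre_sliced else labels[islice]
        metric.update(device_labels, outputs)


slices = [slice(0, 2), slice(2, 4)]
outputs = [[1, 0], [1, 1]]  # predicted classes on device 0 and device 1
metric = Accuracy()
update_metric(metric, [1, 0, 0, 1], outputs, slices)
print(metric.get())  # 0.75
```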
class mxnet.executor_manager.DataParallelExecutorManager(symbol, ctx, train_data, arg_names, param_names, aux_names, work_load_list=None, logger=None, sym_gen=None)
Bases: object
Helper class to manage multiple executors for data parallelism.
Parameters
symbol (Symbol) – Output symbol.
ctx (list of Context) – Devices to run on.
param_names (list of str) – Names of all trainable parameters of the network.
arg_names (list of str) – Names of all arguments of the network.
aux_names (list of str) – Names of all auxiliary states of the network.
train_data (DataIter) – Training data iterator.
work_load_list (list of float or int, optional) – The workload of each device, in the same order as ctx.
logger (logger, optional) – The logger to use. When not specified, the default logger is used.
sym_gen (function, optional) – A function that generates new Symbols depending on different input shapes. Used only for bucketing.
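work_load_list determines how large each device's share of a batch is. Below is a minimal sketch of one way to turn it into contiguous slices proportional to each device's workload. The function split_input_slice and its rounding policy (absorbing any rounding mismatch in the last device) are assumptions made for illustration, not a documented part of this class.

```python
def split_input_slice(batch_size, work_load_list):
    """Split range(batch_size) into contiguous slices proportional to work load.

    Hypothetical helper: any mismatch caused by rounding is absorbed by the
    last device.
    """
    total = sum(work_load_list)
    counts = [round(w * batch_size / total) for w in work_load_list]
    # Absorb any rounding mismatch in the last device.
    counts[-1] += batch_size - sum(counts)
    slices = []
    begin = 0
    for count in counts:
        end = min(begin + count, batch_size)
        if begin >= end:
            raise ValueError("Too many slices: some device would get no data.")
        slices.append(slice(begin, end))
        begin = end
    return slices


print(split_input_slice(10, [1, 1]))  # [slice(0, 5, None), slice(5, 10, None)]
print(split_input_slice(9, [2, 1]))   # [slice(0, 6, None), slice(6, 9, None)]
```

Giving one device twice the workload of another, as in the second call, gives it twice as many samples per batch.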
Attributes
aux_arrays – Shared aux states.
grad_arrays – Shared gradient arrays.
param_arrays – Shared parameter arrays.
Methods
backward() – Run backward on the current executor.
copy_to(arg_params, aux_params) – Copy data from each executor to arg_params and aux_params.
forward([is_train]) – Run forward on the current executor.
install_monitor(monitor) – Install a monitor on all executors.
load_data_batch(data_batch) – Load data and labels into arrays.
set_params(arg_params, aux_params) – Set parameter and aux values.
update_metric(metric, labels[, pre_sliced]) – Update the metric with the current executor.
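Conceptually, set_params writes the same values into every device's copy of each parameter. The sketch below models that broadcast with plain Python lists. The nested layout of param_arrays (one entry per parameter, each holding one copy per device) and the function set_params_sketch are assumptions; aux states would be handled the same way.

```python
def set_params_sketch(param_names, param_arrays, arg_params):
    """Broadcast each named value in `arg_params` to every device's copy.

    Assumed layout: param_arrays[i] is a list with one copy (a plain list of
    floats standing in for an NDArray) per device for param_names[i].
    """
    for name, device_copies in zip(param_names, param_arrays):
        for device_copy in device_copies:
            # Slice assignment overwrites the existing list in place.
            device_copy[:] = arg_params[name]


param_names = ["fc_weight"]
param_arrays = [[[0.0, 0.0], [0.0, 0.0]]]  # one parameter, two devices
set_params_sketch(param_names, param_arrays, {"fc_weight": [0.5, -1.0]})
print(param_arrays)  # [[[0.5, -1.0], [0.5, -1.0]]]
```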
property aux_arrays
Shared aux states.
copy_to(arg_params, aux_params)
Copy data from each executor to arg_params and aux_params.
Parameters
arg_params (list of NDArray) – Target parameter arrays.
aux_params (list of NDArray) – Target aux arrays.
Notes
This function updates the NDArrays in arg_params and aux_params in place.
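copy_to goes the other way, gathering the per-device copies into the caller's target arrays in place. The sketch below assumes the copies are combined by averaging and uses plain lists in place of NDArrays. average_into_targets is a hypothetical name, and both the averaging and the positional pairing of parameters with targets are assumptions.

```python
def average_into_targets(param_arrays, targets):
    """Average each parameter's per-device copies into its target, in place.

    Assumed layout: param_arrays[i] holds one list of floats per device for
    parameter i, and targets[i] is the list that receives the result.
    """
    for device_copies, target in zip(param_arrays, targets):
        num_devices = len(device_copies)
        averaged = [sum(values) / num_devices for values in zip(*device_copies)]
        target[:] = averaged  # mutate the caller's list instead of rebinding


param_arrays = [[[1.0, 2.0], [3.0, 4.0]]]  # one parameter on two devices
targets = [[0.0, 0.0]]
average_into_targets(param_arrays, targets)
print(targets)  # [[2.0, 3.0]]
```

Writing through slice assignment mirrors the in-place behaviour described in the Notes: the caller's target objects keep their identity and only their contents change.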
property grad_arrays
Shared gradient arrays.
property param_arrays
Shared parameter arrays.
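Assuming grad_arrays uses a nested layout in which grad_arrays[i][d] is the gradient of parameter i on device d, combining data-parallel gradients amounts to an element-wise sum over devices. This plain-Python sketch is illustrative only; sum_gradients is hypothetical and not part of this module.

```python
def sum_gradients(grad_arrays):
    """Sum each parameter's per-device gradients element-wise.

    Assumed layout: grad_arrays[i][d] is the gradient of parameter i on
    device d, stored as a plain list of floats.
    """
    return [[sum(values) for values in zip(*per_device)] for per_device in grad_arrays]


grad_arrays = [
    [[0.5, 1.0], [1.5, 2.0]],  # parameter 0 on two devices
    [[1.0], [2.0]],            # parameter 1 on two devices
]
print(sum_gradients(grad_arrays))  # [[2.0, 3.0], [3.0]]
```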