mxnet::Executor Class Reference (abstract)

Executor of a computation graph. An Executor is created by binding a symbol.

#include <executor.h>


Public Types

typedef std::function< void(const char *, void *)> MonitorCallback
 the prototype of the user-defined monitor callback
 

Public Member Functions

virtual ~Executor ()
 destructor
 
virtual void Forward (bool is_train)=0
 Perform a forward operation of the operator. After this operation, the user can get the result through outputs().
 
virtual void PartialForward (bool is_train, int step, int *step_left)=0
 Perform a partial forward operation of the operator. Only the operation specified by step is issued. The caller must keep calling PartialForward with increasing step values until step_left is 0.
 
virtual void Backward (const std::vector< NDArray > &head_grads, bool is_train=true)=0
 Perform a backward operation of the operator. This must be called after Forward. After this operation, the NDArrays in arg_grad_store (the gradient store passed to Bind) will be updated accordingly. The user may pass an empty array if the head node is a loss function and the head gradient is not needed.
 
virtual void Print (std::ostream &os) const
 print the execution plan info to an output stream.
 
virtual const std::vector< NDArray > & outputs () const =0
 get array of outputs in the executor.
 
virtual const std::unordered_map< std::string, NDArray > & in_arg_map () const =0
 get input argument map; key is the argument name, value is the argument's NDArray.
 
virtual const std::unordered_map< std::string, NDArray > & arg_grad_map () const =0
 get input argument gradient map; key is the argument name, value is the gradient's NDArray.
 
virtual const std::unordered_map< std::string, NDArray > & aux_state_map () const =0
 get aux state map; key is the aux state's name, value is the aux state's NDArray.
 
virtual void SetMonitorCallback (const MonitorCallback &callback)
 Install a callback that is notified when an operation completes.
 

Static Public Member Functions

static Executor * Bind (nnvm::Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< NDArray > &in_args, const std::vector< NDArray > &arg_grad_store, const std::vector< OpReqType > &grad_req_type, const std::vector< NDArray > &aux_states, Executor *shared_exec=NULL)
 Create an executor by binding a symbol with a context and arguments. If the user does not want to compute the gradient of the i-th argument, grad_req_type[i] can be kNullOp.
 
static Executor * SimpleBind (nnvm::Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< Context > &in_arg_ctxes, const std::vector< Context > &arg_grad_ctxes, const std::vector< Context > &aux_state_ctxes, const std::unordered_map< std::string, TShape > &arg_shape_map, const std::unordered_map< std::string, int > &arg_dtype_map, const std::vector< OpReqType > &grad_req_types, const std::unordered_set< std::string > &param_names, std::vector< NDArray > *in_args, std::vector< NDArray > *arg_grads, std::vector< NDArray > *aux_states, std::unordered_map< std::string, NDArray > *shared_data_arrays=nullptr, Executor *shared_exec=nullptr)
 

Detailed Description

Executor of a computation graph. An Executor is created by binding a symbol.

Member Typedef Documentation

typedef std::function<void(const char*, void*)> mxnet::Executor::MonitorCallback

the prototype of the user-defined monitor callback
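The typedef fixes only the callable's signature, so any function object taking `(const char *, void *)` qualifies. The following is a minimal sketch of a conforming callback, not MXNet code: the alias is redeclared locally so the snippet compiles without MXNet headers, `MakeNameRecorder` and `FakeNotify` are hypothetical helpers, and treating the `const char *` as a name and the `void *` as an opaque handle is an assumption, since this page does not document the arguments.

```cpp
#include <functional>
#include <string>
#include <vector>

// Local stand-in with the same signature as mxnet::Executor::MonitorCallback.
using MonitorCallback = std::function<void(const char *, void *)>;

// Hypothetical helper: builds a callback that records every name it receives.
// Assumption: the first argument is a printable name; the second is ignored.
MonitorCallback MakeNameRecorder(std::vector<std::string> *seen) {
  return [seen](const char *name, void * /*handle*/) {
    seen->push_back(name);
  };
}

// Hypothetical stand-in for an executor invoking the installed callback.
void FakeNotify(const MonitorCallback &cb, const char *name) {
  if (cb) cb(name, nullptr);  // an empty std::function must not be called
}
```

With a real executor, a callback built this way would be passed to SetMonitorCallback; the recorder must outlive the executor because the lambda holds a raw pointer to it.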

Constructor & Destructor Documentation

virtual mxnet::Executor::~Executor ( )
inline virtual

destructor

Member Function Documentation

virtual const std::unordered_map<std::string, NDArray>& mxnet::Executor::arg_grad_map ( ) const
pure virtual

get input argument gradient map; key is the argument name, value is the gradient's NDArray.

Returns
input argument gradient map in the executor.

virtual const std::unordered_map<std::string, NDArray>& mxnet::Executor::aux_state_map ( ) const
pure virtual

get aux state map; key is the aux state's name, value is the aux state's NDArray.

Returns
aux state map in the executor.

virtual void mxnet::Executor::Backward ( const std::vector< NDArray > &  head_grads,
bool  is_train = true 
)
pure virtual

Perform a backward operation of the operator. This must be called after Forward. After this operation, the NDArrays in arg_grad_store (the gradient store passed to Bind) will be updated accordingly. The user may pass an empty array if the head node is a loss function and the head gradient is not needed.

Parameters
head_grads: the gradient of head nodes to be backpropagated.

static Executor* mxnet::Executor::Bind ( nnvm::Symbol  symbol,
const Context &  default_ctx,
const std::map< std::string, Context > &  group2ctx,
const std::vector< NDArray > &  in_args,
const std::vector< NDArray > &  arg_grad_store,
const std::vector< OpReqType > &  grad_req_type,
const std::vector< NDArray > &  aux_states,
Executor *  shared_exec = NULL 
)
static

Create an executor by binding a symbol with a context and arguments. If the user does not want to compute the gradient of the i-th argument, grad_req_type[i] can be kNullOp.

Parameters
default_ctx: the default context of binding.
group2ctx: mapping from group name to context.
symbol: the symbol that specifies the output of the Forward pass.
in_args: the NDArrays that store the input arguments to the symbol.
arg_grad_store: NDArrays used to store the gradient output of the input arguments.
grad_req_type: requirement type for saving each gradient. Must be one of {kNullOp, kAddTo, kWriteTo}.
aux_states: NDArrays used as internal state in operators.
shared_exec: input executor to share memory with.
Returns
a new executor.
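The three request types named above decide how a freshly computed gradient is combined with whatever arg_grad_store already holds. The sketch below illustrates the usual reading (kNullOp skips the write, kWriteTo overwrites, kAddTo accumulates) under stated assumptions: the enum is a local, hypothetical stand-in rather than MXNet's own definition, `std::vector<float>` stands in for NDArray, and `ApplyGradient` is an illustrative helper that does not exist in the library.

```cpp
#include <cstddef>
#include <vector>

// Local, hypothetical stand-in for the three request types listed above.
enum OpReqType { kNullOp, kWriteTo, kAddTo };

// Combine a newly computed gradient with the stored one according to req.
// Plain float vectors stand in for NDArray in this sketch.
void ApplyGradient(OpReqType req, const std::vector<float> &grad,
                   std::vector<float> *store) {
  if (req == kNullOp) return;  // gradient not requested: leave store untouched
  for (std::size_t i = 0; i < grad.size() && i < store->size(); ++i) {
    if (req == kWriteTo) {
      (*store)[i] = grad[i];   // overwrite the previous contents
    } else {
      (*store)[i] += grad[i];  // kAddTo: accumulate into the existing values
    }
  }
}
```

Under this reading, accumulation across several Backward calls requires the caller to zero the store first, since kAddTo never resets it.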
virtual void mxnet::Executor::Forward ( bool  is_train)
pure virtual

Perform a forward operation of the operator. After this operation, the user can get the result through outputs().

virtual const std::unordered_map<std::string, NDArray>& mxnet::Executor::in_arg_map ( ) const
pure virtual

get input argument map, key is arg name, value is arg's NDArray.

Returns
input argument map in the executor.

virtual const std::vector<NDArray>& mxnet::Executor::outputs ( ) const
pure virtual

get array of outputs in the executor.

Returns
array of outputs in the executor.

virtual void mxnet::Executor::PartialForward ( bool  is_train,
int  step,
int *  step_left 
)
pure virtual

Perform a partial forward operation of the operator. Only the operation specified by step is issued. The caller must keep calling PartialForward with increasing step values until step_left is 0.

Parameters
is_train: whether this is the training phase.
step: the current step; the user can always start from 0.
step_left: number of steps left to finish the forward pass.
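The calling contract described above (start at step 0, increase step on each call, stop once step_left is 0) can be sketched as a small driver loop. This is an illustration only: `MockExecutor` is a hypothetical stand-in that mimics just the PartialForward signature, and `RunForwardStepwise` is an assumed helper name, not part of MXNet.

```cpp
#include <vector>

// Hypothetical stand-in exposing only the PartialForward contract.
class MockExecutor {
 public:
  explicit MockExecutor(int num_steps) : num_steps_(num_steps) {}

  // Issues the operation for `step` and reports how many steps remain.
  void PartialForward(bool /*is_train*/, int step, int *step_left) {
    if (step < num_steps_) {
      issued_.push_back(step);
      *step_left = num_steps_ - step - 1;
    } else {
      *step_left = 0;  // nothing left to issue
    }
  }

  const std::vector<int> &issued() const { return issued_; }

 private:
  int num_steps_;
  std::vector<int> issued_;
};

// Drives a full forward pass one step at a time.
// Returns the number of PartialForward calls that were made.
template <typename Exec>
int RunForwardStepwise(Exec *exec, bool is_train) {
  int step = 0;
  int step_left = 0;
  do {
    exec->PartialForward(is_train, step, &step_left);
    ++step;
  } while (step_left > 0);
  return step;
}
```

The driver is a template so that the same loop would accept any type with a matching PartialForward member; other work can be interleaved between iterations, which is the usual reason to step manually instead of calling Forward.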
virtual void mxnet::Executor::Print ( std::ostream &  os) const
inline virtual

print the execution plan info to an output stream.

Parameters
os: the output stream to print to.

virtual void mxnet::Executor::SetMonitorCallback ( const MonitorCallback &  callback)
inline virtual

Install a callback that is notified when an operation completes.

static Executor* mxnet::Executor::SimpleBind ( nnvm::Symbol  symbol,
const Context &  default_ctx,
const std::map< std::string, Context > &  group2ctx,
const std::vector< Context > &  in_arg_ctxes,
const std::vector< Context > &  arg_grad_ctxes,
const std::vector< Context > &  aux_state_ctxes,
const std::unordered_map< std::string, TShape > &  arg_shape_map,
const std::unordered_map< std::string, int > &  arg_dtype_map,
const std::vector< OpReqType > &  grad_req_types,
const std::unordered_set< std::string > &  param_names,
std::vector< NDArray > *  in_args,
std::vector< NDArray > *  arg_grads,
std::vector< NDArray > *  aux_states,
std::unordered_map< std::string, NDArray > *  shared_data_arrays = nullptr,
Executor *  shared_exec = nullptr 
)
static

The documentation for this class was generated from the following file: