Go to the documentation of this file.
25 #ifndef MXNET_EXECUTOR_H_
26 #define MXNET_EXECUTOR_H_
40 #if DMLC_USE_CXX11 == 0
41 #error "CXX11 was required for symbolic module"
/*!
 * \brief Perform a Forward operation of the Operator.
 *        Pure virtual: every concrete Executor must implement it.
 * \param is_train flag passed to the implementation; presumably selects
 *        training vs. inference behaviour -- NOTE(review): confirm.
 */
60 virtual void Forward(
bool is_train) = 0;
/*!
 * \brief Perform a Partial Forward operation of the Operator.
 *        Only the operation specified by \p step is issued.
 * \param is_train passed through to the implementation (see Forward()).
 * \param step selects which operation to issue.
 * \param step_left output pointer written by the implementation; presumably
 *        receives the number of remaining steps -- NOTE(review): confirm.
 */
69 virtual void PartialForward(
bool is_train,
int step,
int* step_left) = 0;
/*!
 * \brief Perform a Backward operation of the Operator.
 *        This must be called after Forward().
 * \param head_grads gradients fed into the backward pass; presumably the
 *        gradients w.r.t. the outputs -- NOTE(review): confirm ordering contract.
 * \param is_train defaults to true.
 */
79 virtual void Backward(
const std::vector<NDArray>& head_grads,
bool is_train =
true) = 0;
/*!
 * \brief Print the execution plan info to an output stream.
 *        The base-class implementation has an empty body and writes
 *        nothing; subclasses may override it.
 * \param os the stream to print to.
 */
84 virtual void Print(std::ostream& os)
const {}
/*!
 * \brief Get the array of outputs in the executor.
 * \return const reference to the vector of output NDArrays.
 */
89 virtual const std::vector<NDArray>&
outputs()
const = 0;
/*!
 * \brief Get the input argument map.
 * \return const reference to a map whose key is the argument name and whose
 *         value is that argument's NDArray.
 */
94 virtual const std::unordered_map<std::string, NDArray>&
in_arg_map()
const = 0;
/*!
 * \brief Get the input argument gradient map.
 * \return const reference to a map whose key is the argument name and whose
 *         value is the gradient's NDArray.
 */
99 virtual const std::unordered_map<std::string, NDArray>&
arg_grad_map()
const = 0;
/*!
 * \brief Get the aux state map.
 * \return const reference to a map whose key is the arg name and whose
 *         value is the aux state's NDArray.
 */
104 virtual const std::unordered_map<std::string, NDArray>&
aux_state_map()
const = 0;
121 const bool partial_shaping,
122 const bool allow_up_sizing,
124 const std::map<std::string, Context>& ctx_map,
125 const std::unordered_map<std::string, mxnet::TShape>& provided_arg_shapes,
126 std::vector<NDArray>* in_args,
127 std::vector<NDArray>* arg_grads,
128 std::vector<NDArray>* aux_states) = 0;
148 const std::map<std::string, Context>& group2ctx,
149 const std::vector<NDArray>& in_args,
150 const std::vector<NDArray>& arg_grad_store,
151 const std::vector<OpReqType>& grad_req_type,
152 const std::vector<NDArray>& aux_states,
158 const std::map<std::string, Context>& group2ctx,
159 const std::vector<Context>& in_arg_ctxes,
160 const std::vector<Context>& arg_grad_ctxes,
161 const std::vector<Context>& aux_state_ctxes,
162 const std::unordered_map<std::string, mxnet::TShape>& arg_shape_map,
163 const std::unordered_map<std::string, int>& arg_dtype_map,
164 const std::unordered_map<std::string, int>& arg_stype_map,
165 const std::vector<OpReqType>& grad_req_types,
166 const std::unordered_set<std::string>& param_names,
167 std::vector<NDArray>* in_args,
168 std::vector<NDArray>* arg_grads,
169 std::vector<NDArray>* aux_states,
170 std::unordered_map<std::string, NDArray>* shared_data_arrays =
nullptr,
183 #endif // MXNET_EXECUTOR_H_
namespace of mxnet
Definition: api_registry.h:33
Executor of a computation graph. An Executor can be created by binding a symbol.
Definition: executor.h:52
std::function< void(const char *, void *)> MonitorCallback
the prototype of a user-defined monitor callback
Definition: executor.h:176
virtual void Print(std::ostream &os) const
print the execution plan info to output stream.
Definition: executor.h:84
virtual void Backward(const std::vector< NDArray > &head_grads, bool is_train=true)=0
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
static Executor * SimpleBind(nnvm::Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< Context > &in_arg_ctxes, const std::vector< Context > &arg_grad_ctxes, const std::vector< Context > &aux_state_ctxes, const std::unordered_map< std::string, mxnet::TShape > &arg_shape_map, const std::unordered_map< std::string, int > &arg_dtype_map, const std::unordered_map< std::string, int > &arg_stype_map, const std::vector< OpReqType > &grad_req_types, const std::unordered_set< std::string > &param_names, std::vector< NDArray > *in_args, std::vector< NDArray > *arg_grads, std::vector< NDArray > *aux_states, std::unordered_map< std::string, NDArray > *shared_data_arrays=nullptr, Executor *shared_exec=nullptr)
virtual ~Executor()
destructor
Definition: executor.h:55
defines configuration macros
virtual const std::vector< NDArray > & outputs() const =0
get array of outputs in the executor.
virtual const std::unordered_map< std::string, NDArray > & aux_state_map() const =0
get aux state map, key is arg name, value is aux state's NDArray.
virtual const std::unordered_map< std::string, NDArray > & arg_grad_map() const =0
get input argument gradient map, key is arg name, value is gradient's NDArray.
virtual void SetMonitorCallback(const MonitorCallback &callback, bool monitor_all=false)
Install a callback to notify the completion of operation.
Definition: executor.h:180
virtual const std::unordered_map< std::string, NDArray > & in_arg_map() const =0
get input argument map, key is arg name, value is arg's NDArray.
Symbol is a helper class used to represent an operator node in a graph.
Definition: symbolic.h:50
Context information about the execution environment.
Definition: base.h:90
virtual void PartialForward(bool is_train, int step, int *step_left)=0
Perform a Partial Forward operation of the Operator. Only issue the operation specified by step....
virtual void Forward(bool is_train)=0
Perform a Forward operation of the Operator. After this operation, the user can get the result by using functi...
virtual Executor * Reshape(const bool partial_shaping, const bool allow_up_sizing, const Context &default_ctx, const std::map< std::string, Context > &ctx_map, const std::unordered_map< std::string, mxnet::TShape > &provided_arg_shapes, std::vector< NDArray > *in_args, std::vector< NDArray > *arg_grads, std::vector< NDArray > *aux_states)=0
Return a new executor with the same symbol and shared memory, but different input/output shapes.
Operator interface of mxnet.
NDArray interface that handles array arithmetic.
static Executor * Bind(nnvm::Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< NDArray > &in_args, const std::vector< NDArray > &arg_grad_store, const std::vector< OpReqType > &grad_req_type, const std::vector< NDArray > &aux_states, Executor *shared_exec=nullptr)
Create an operator by binding a symbol with context and arguments. If the user does not want to compute the grad...