Go to the documentation of this file.
26 #ifndef MXNET_CPP_EXECUTOR_H_
27 #define MXNET_CPP_EXECUTOR_H_
51 const std::vector<OpReqType>& grad_reqs,
53 const std::map<std::string, Context>& group_to_ctx = std::map<std::string, Context>(),
63 std::vector<NDArrayHandle> arg_handles;
65 arg_handles.push_back(array.GetHandle());
67 int prev_is_record = 0;
68 int prev_train_mode = 0;
70 if (is_train ==
true) {
73 std::vector<NDArrayHandle> output_handles;
91 for (
mx_uint i = 0; i < out_size; ++i) {
94 int cur_train_mode = prev_train_mode;
95 int cur_is_record = prev_is_record;
96 if (is_train ==
true) {
111 void Backward(
const std::vector<NDArray>& head_grads = std::vector<NDArray>()) {
116 std::vector<NDArrayHandle> out_handles;
117 for (
const auto& array :
outputs) {
118 out_handles.push_back(array.GetHandle());
120 std::vector<NDArrayHandle> head_grads_;
121 for (
auto d : head_grads) {
122 head_grads_.push_back(d.GetHandle());
124 if (head_grads_.size() > 0) {
193 std::map<std::string, NDArray> GetDict(
const std::vector<std::string>& names,
194 const std::vector<NDArray>& arrays) {
195 std::map<std::string, NDArray> ret;
196 std::set<std::string> name_set;
197 for (
const auto& s : names) {
198 CHECK(name_set.find(s) == name_set.end()) <<
"Duplicate names detected, " << s;
201 CHECK_EQ(name_set.size(), arrays.size()) <<
"names size not equal to arrays size";
202 for (
size_t i = 0; i < names.size(); ++i) {
203 ret[names[i]] = arrays[i];
210 #endif // MXNET_CPP_EXECUTOR_H_
namespace of mxnet
Definition: api_registry.h:33
MXNET_DLL int MXAutogradSetIsTraining(int is_training, int *prev)
set whether autograd is in training mode (is_training: 1 for training, 0 otherwise)
void * CachedOpHandle
handle to cached operator
Definition: c_api.h:80
std::map< std::string, NDArray > grad_dict()
Definition: executor.h:181
void Backward(const std::vector< NDArray > &head_grads=std::vector< NDArray >())
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
Definition: executor.h:111
int device_type
Definition: executor.h:171
std::vector< NDArray > outputs
arrays store the outputs of forward
Definition: executor.h:177
MXNET_DLL int MXAutogradBackwardEx(uint32_t num_output, NDArrayHandle *output_handles, NDArrayHandle *ograd_handles, uint32_t num_variables, NDArrayHandle *var_handles, int retain_graph, int create_graph, int is_train, NDArrayHandle **grad_handles, int **grad_stypes)
compute the gradient of outputs w.r.t. variables
std::vector< std::string > ListArguments() const
List the arguments names.
~Executor()
destructor, free the handle
Definition: executor.h:164
std::map< std::string, NDArray > aux_dict()
Definition: executor.h:184
void Forward(bool is_train)
Perform a Forward operation of Operator. After this operation, user can get the result by using functi...
Definition: executor.h:62
NDArray interface.
Definition: ndarray.h:122
Context interface.
Definition: ndarray.h:45
std::vector< std::string > ListAuxiliaryStates() const
std::vector< NDArray > grad_arrays
Definition: executor.h:168
Executor(const CachedOpHandle &h)
Definition: executor.h:55
MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle)
free cached operator
std::vector< NDArray > combined_arrays
Definition: executor.h:170
MXNET_DLL int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out)
return gradient buffer attached to this NDArray
Definition: ndarray_handle.h:40
Executor interface.
Definition: executor.h:45
std::map< std::string, NDArray > arg_dict()
Definition: executor.h:178
MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle, int num_inputs, NDArrayHandle *inputs, int default_dev_type, int default_dev_id, int *num_outputs, NDArrayHandle **outputs, const int **out_stypes)
invoke a cached op
base definitions for mxnetcpp
std::vector< NDArray > arg_arrays
Definition: executor.h:167
int device_id
Definition: executor.h:172
MXNET_DLL int MXAutogradSetIsRecording(int is_recording, int *prev)
set whether to record operator for autograd
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:65
bool require_grad
Definition: executor.h:173
std::vector< NDArray > aux_arrays
Definition: executor.h:169
Executor(const Symbol &symbol, Context context, const std::vector< NDArray > &arg_arrays, const std::vector< NDArray > &grad_arrays, const std::vector< OpReqType > &grad_reqs, const std::vector< NDArray > &aux_arrays, const std::map< std::string, Context > &group_to_ctx=std::map< std::string, Context >(), Executor *shared_exec=nullptr)
Symbol interface.
Definition: symbol.h:73