20 #ifndef MXNET_IMPERATIVE_H_ 21 #define MXNET_IMPERATIVE_H_ 33 #include <unordered_map> 63 grad_req(
kNullOp), fresh_out_grad(false) {}
66 if (node ==
nullptr || node->info.empty())
return;
73 return dmlc::get<AGInfo>(node->info);
77 node->info.construct<
AGInfo>();
82 return arr.entry_.
node ==
nullptr || arr.entry_.
node->info.empty();
103 return is_recording_;
107 bool old = is_recording_;
115 if (is_np_shape_global_) {
118 return is_np_shape_thread_local_ ? 1 : 0;
124 if (is_np_default_dtype_global_) {
136 is_np_shape_global_ =
true;
137 is_np_shape_thread_local_ =
true;
140 is_np_shape_thread_local_ =
true;
143 is_np_shape_global_ =
false;
144 is_np_shape_thread_local_ =
false;
151 const std::vector<NDArray*>& inputs,
152 const std::vector<NDArray*>&
outputs,
154 std::vector<bool>* p_save_inputs =
nullptr,
155 std::vector<bool>* p_save_outputs =
nullptr);
159 const std::vector<NDArray*>& inputs,
160 const std::vector<NDArray*>& outputs);
164 const std::vector<NDArray*>& inputs,
165 const std::vector<NDArray*>& outputs,
166 const std::vector<OpReqType>& req,
171 const std::vector<uint32_t>& grad_reqs,
172 const std::vector<NDArray*>& gradients);
174 std::vector<NDArray*>
Backward(
const std::vector<NDArray*>& outputs,
175 const std::vector<NDArray*>& ograds,
176 const std::vector<NDArray*>& variables,
177 bool is_train,
bool retain_graph,
183 return dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_INFERENCE",
true);
187 return dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_TRAIN",
true);
191 return dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD",
192 dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
196 return dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD",
197 dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
208 void GetBackwardDependency(
210 uint32_t num_inputs, uint32_t num_outputs,
211 std::vector<bool> *p_save_inputs,
212 std::vector<bool> *p_save_outputs);
214 #if DMLC_CXX11_THREAD_LOCAL 215 static thread_local
bool is_train_;
216 static thread_local
bool is_recording_;
219 static thread_local
bool is_np_shape_thread_local_;
221 static MX_THREAD_LOCAL
bool is_train_;
222 static MX_THREAD_LOCAL
bool is_recording_;
225 static MX_THREAD_LOCAL
bool is_np_shape_thread_local_;
227 bool is_np_shape_global_{
false};
228 bool is_np_default_dtype_global_{
false};
230 std::atomic<uint64_t> node_count_{0};
232 std::atomic<uint64_t> variable_count_{0};
234 int backward_bulk_size_{0};
238 #endif // MXNET_IMPERATIVE_H_ Definition: imperative.h:48
bool is_recording() const
whether operator recording is on.
Definition: imperative.h:102
static bool IsNone(const NDArray &arr)
Definition: imperative.h:81
static int BulkExecMaxNodeTrainFwd()
The max number of op nodes in a bulk during forward pass of training.
Definition: imperative.h:190
Definition: imperative.h:48
bool is_training() const
whether training mode is on.
Definition: imperative.h:92
no operation, do not write anything
Definition: op_attr_types.h:48
bool set_is_training(bool is_train)
turn on or turn off training mode for autograd.
Definition: imperative.h:96
The attributes of the current operation node. Usually are additional parameters like axis...
Definition: node.h:119
namespace of mxnet
Definition: api_registry.h:33
std::vector< NDArray * > Backward(const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
compute the gradient of outputs w.r.t variables.
OpStatePtr Invoke(const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
AGInfo()
Definition: imperative.h:62
ObjectPtr node
the source node of this data
Definition: node.h:75
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:123
std::vector< NDArray > outputs
Definition: imperative.h:58
int is_np_shape() const
return current numpy compatibility status, GlobalOn(2), ThreadLocalOn(1), Off(0). ...
Definition: imperative.h:114
bool is_np_default_dtype() const
return current numpy default dtype compatibility status.
Definition: imperative.h:123
Definition: imperative.h:53
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd.
Definition: imperative.h:106
bool fresh_out_grad
Definition: imperative.h:60
OpStatePtr state
Definition: imperative.h:57
std::vector< NDArray > out_grads
Definition: imperative.h:59
void RecordOp(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
to record operator, return corresponding node.
OpReqType grad_req
Definition: imperative.h:56
Context ctx
Definition: imperative.h:55
Configuration of nnvm as well as basic data structure.
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:46
OpStatePtr InvokeOp(const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
runtime functions for NDArray
Definition: imperative.h:50
Definition: imperative.h:48
static AGInfo & Create(const nnvm::ObjectPtr &node)
Definition: imperative.h:76
static bool IsVariable(const nnvm::ObjectPtr &node)
Definition: imperative.h:85
Operator information structure.
static bool PreferBulkExecTrain()
Should op execution bulking be employed during training.
Definition: imperative.h:186
Symbolic graph construction API.
static int BulkExecMaxNodeTrainBwd()
The max number of op nodes in a bulk during backward pass of training.
Definition: imperative.h:195
bool set_is_np_shape(int is_np_shape)
specify numpy compatibility off, thread local on or global on.
Definition: imperative.h:131
static bool PreferBulkExecInference()
Should op execution bulking be employed during inference.
Definition: imperative.h:182
void MarkVariables(const std::vector< NDArray * > &variables, const std::vector< uint32_t > &grad_reqs, const std::vector< NDArray * > &gradients)
mark variables for computing gradients.
Context information about the execution environment.
Definition: base.h:102
static void Clear(const nnvm::ObjectPtr &node)
Definition: imperative.h:65
static AGInfo & Get(const nnvm::ObjectPtr &node)
Definition: imperative.h:72
ndarray interface
Definition: ndarray.h:82
NumpyShape
there are three numpy shape flags based on priority. GlobalOn turns on the numpy shape flag globally...
Definition: imperative.h:48
std::shared_ptr< Node > ObjectPtr
we always use ObjectPtr for a reference pointer to the node, so this alias can be changed in case...
Definition: node.h:48
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:149