Go to the documentation of this file.
20 #ifndef MXNET_IMPERATIVE_H_
21 #define MXNET_IMPERATIVE_H_
33 #include <unordered_map>
77 if (node ==
nullptr || node->info.empty())
86 return dmlc::get<AGInfo>(node->info);
90 node->info.construct<
AGInfo>();
95 return arr.autograd_entry_.
node ==
nullptr || arr.autograd_entry_.
node->info.empty();
107 explicit DCInfo(
const std::vector<NDArray*>& inputs,
const std::vector<NDArray*>& outputs);
113 return dmlc::get<DCInfo>(node->info);
117 return arr.deferredcompute_entry_.
node ==
nullptr ||
118 arr.deferredcompute_entry_.
node->info.empty();
122 return IsNone(arr) || dmlc::get<DCInfo>(arr.deferredcompute_entry_.
node->info).is_computed_;
126 const std::vector<NDArray*>& inputs,
127 const std::vector<NDArray*>& outputs);
130 if (node ==
nullptr || node->info.empty())
146 std::vector<NDArray> inputs_;
158 std::vector<const NDArray*> input_handles_;
168 std::vector<NDArray> outputs_;
171 bool is_computed_ =
false;
180 bool old = is_train_;
181 is_train_ = is_train;
186 return is_recording_;
190 bool old = is_recording_;
196 return is_deferred_compute_;
200 bool old = is_deferred_compute_;
208 if (is_np_shape_global_) {
219 is_np_shape_global_ =
true;
220 is_np_shape_thread_local_ =
true;
223 is_np_shape_thread_local_ =
true;
226 is_np_shape_global_ =
false;
227 is_np_shape_thread_local_ =
false;
235 if (is_np_default_dtype_global_) {
243 if (is_np_default_dtype) {
244 is_np_default_dtype_global_ =
true;
246 is_np_default_dtype_global_ =
false;
252 return opt_constraints_;
257 opt_constraints_ = constraints;
262 const std::vector<NDArray*>& inputs,
263 const std::vector<NDArray*>& outputs,
265 std::vector<bool>* p_save_inputs =
nullptr,
266 std::vector<bool>* p_save_outputs =
nullptr);
269 const std::vector<NDArray*>& inputs,
270 const std::vector<NDArray*>& outputs);
280 const std::vector<NDArray*>& inputs,
281 const std::vector<NDArray*>& outputs);
285 const std::vector<NDArray*>& inputs,
286 const std::vector<NDArray*>& outputs,
287 const std::vector<OpReqType>& req,
292 const std::vector<uint32_t>& grad_reqs,
293 const std::vector<NDArray*>& gradients);
295 void DropGrads(
const std::vector<NDArray*>& variables);
297 std::vector<NDArray*>
Backward(
const std::vector<NDArray*>& outputs,
298 const std::vector<NDArray*>& ograds,
299 const std::vector<NDArray*>& variables,
309 return dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_INFERENCE",
true);
313 return dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_TRAIN",
true);
317 return dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD",
318 dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
322 return dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD",
323 dmlc::GetEnv(
"MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
336 uint32_t num_outputs,
337 std::vector<bool>* p_save_inputs,
338 std::vector<bool>* p_save_outputs);
340 #if DMLC_CXX11_THREAD_LOCAL
341 static thread_local
bool is_train_;
342 static thread_local
bool is_recording_;
343 static thread_local
bool is_deferred_compute_;
347 static thread_local
bool is_np_shape_thread_local_;
349 static MX_THREAD_LOCAL
bool is_train_;
350 static MX_THREAD_LOCAL
bool is_recording_;
351 static MX_THREAD_LOCAL
bool is_deferred_compute_;
355 static MX_THREAD_LOCAL
bool is_np_shape_thread_local_;
357 bool is_np_shape_global_{
false};
358 bool is_np_default_dtype_global_{
false};
360 std::atomic<uint64_t> node_count_{0};
362 std::atomic<uint64_t> variable_count_{0};
364 int backward_bulk_size_{0};
368 #endif // MXNET_IMPERATIVE_H_
namespace of mxnet
Definition: api_registry.h:33
@ ThreadLocalOn
Definition: imperative.h:57
static int BulkExecMaxNodeTrainFwd()
The max number of op nodes in a bulk during forward pass of training.
Definition: imperative.h:316
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const.
Definition: op_attr_types.h:148
void DropGrads(const std::vector< NDArray * > &variables)
unmark nonleaf variables to free the memory.
OptConstraint set_opt_constraints(OptConstraint constraints)
set optimization constraints.
Definition: imperative.h:255
static AGInfo & Get(const nnvm::ObjectPtr &node)
Definition: imperative.h:85
bool set_is_np_shape(int is_np_shape)
specify numpy compatibility off, thread local on or global on.
Definition: imperative.h:214
OpStatePtr Invoke(const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
void MarkVariables(const std::vector< NDArray * > &variables, const std::vector< uint32_t > &grad_reqs, const std::vector< NDArray * > &gradients)
mark variables for computing gradients.
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
std::vector< nnvm::ObjectPtr > ListNonleafVariables(const nnvm::Symbol &sym) const
Return the marked nonleaf nodes.
bool is_training() const
whether training mode is on.
Definition: imperative.h:175
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd.
Definition: imperative.h:189
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:122
void DeferredComputeClear(NDArrayHandle *arrays, const int num)
clear info node associated with array
Operator information structure.
std::vector< NDArray * > Backward(const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
compute the gradient of outputs w.r.t. variables.
@ kNullOp
no operation, do not write anything
Definition: op_attr_types.h:47
static AGInfo & Create(const nnvm::ObjectPtr &node)
Definition: imperative.h:89
static bool IsComputed(const NDArray &arr)
Definition: imperative.h:121
int is_np_shape() const
return current numpy compatibility status, GlobalOn(2), ThreadLocalOn(1), Off(0).
Definition: imperative.h:207
static DCInfo & Create(const nnvm::ObjectPtr &node, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
static bool PreferBulkExecInference()
Should op execution bulking be employed during inference.
Definition: imperative.h:308
static void Clear(const nnvm::ObjectPtr &node)
Definition: imperative.h:76
nnvm::Symbol GetDeferredComputeSymbol(const std::vector< NDArray * > &outputs)
obtain symbol representation of deferred compute session.
bool set_is_training(bool is_train)
turn on or turn off training mode.
Definition: imperative.h:179
Symbol is a helper class used to represent the operator node in a Graph.
Definition: symbolic.h:50
OptConstraint
Definition: imperative.h:40
NumpyShape NumpyDefaultDtype
Definition: imperative.h:58
bool is_np_default_dtype() const
return current numpy default dtype compatibility status.
Definition: imperative.h:234
static bool IsNone(const NDArray &arr)
Definition: imperative.h:116
@ Off
Definition: imperative.h:57
The attributes of the current operation node. Usually are additional parameters like axis,...
Definition: node.h:107
void RecordDeferredCompute(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
to record operator, return corresponding node.
bool is_recording() const
whether operator recording is on.
Definition: imperative.h:185
std::shared_ptr< Node > ObjectPtr
We always use ObjectPtr as the reference pointer type to a node, so this alias can be changed in one place if needed.
Definition: node.h:49
OpStatePtr InvokeOp(const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
static bool PreferBulkExecTrain()
Should op execution bulking be employed during training.
Definition: imperative.h:312
Symbolic graph construction API.
ndarray interface
Definition: ndarray.h:82
static void Compute(const NDArray &arr)
Compute the outputs of the associated operator.
AGInfo()
Definition: imperative.h:74
static Imperative * Get()
void SetDeferredComputeVariable(NDArrayHandle *arrays, SymbolHandle *variables, const int num)
associate arrays with variables for deferred compute
OptConstraint get_opt_constraints() const
return current optimization constraints.
Definition: imperative.h:251
@ GlobalOn
Definition: imperative.h:57
Context information about the execution environment.
Definition: base.h:90
bool set_is_deferred_compute(bool is_deferred_compute)
turn on or turn off deferred compute mode.
Definition: imperative.h:199
std::underlying_type_t< OptConstraint > OptConstraint_int_t
Definition: imperative.h:45
runtime functions for NDArray
Definition: imperative.h:61
bool is_deferred_compute() const
whether deferred compute mode is on.
Definition: imperative.h:195
DCInfo datastructure to enable deferred computation.
Definition: imperative.h:105
Definition: imperative.h:64
Additional operator attributes beside the ones provided by NNVM.
Data structures that can appear in graph attributes.
Definition: ndarray_handle.h:40
Configuration of nnvm as well as basic data structures.
std::vector< NDArray > outputs
Definition: imperative.h:69
void * SymbolHandle
handle to a symbol that can be bind as operator
Definition: c_api.h:82
static DCInfo & Get(const nnvm::ObjectPtr &node)
Definition: imperative.h:112
constexpr char OPT_CONSTRAINT_ATTR[]
Definition: imperative.h:39
void RecordOp(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
to record operator, return corresponding node.
static bool IsVariable(const nnvm::ObjectPtr &node)
Definition: imperative.h:98
Context ctx
Definition: imperative.h:66
bool fresh_out_grad
Definition: imperative.h:72
DCInfo(const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
NDArray interface that handles array arithmetic.
OpReqType grad_req
Definition: imperative.h:67
ObjectPtr node
the source node of this data
Definition: node.h:65
bool set_is_np_default_dtype(bool is_np_default_dtype)
specify numpy default dtype off or global on.
Definition: imperative.h:241
std::vector< NDArray > out_grads
Definition: imperative.h:70
static int BulkExecMaxNodeTrainBwd()
The max number of op nodes in a bulk during backward pass of training.
Definition: imperative.h:321
static bool IsNone(const NDArray &arr)
Definition: imperative.h:94
OpStatePtr state
Definition: imperative.h:68
NumpyShape
there are three numpy shape flags based on priority. GlobalOn turns on the numpy shape flag globally,...
Definition: imperative.h:57
static void Clear(const nnvm::ObjectPtr &node)
Definition: imperative.h:129