mxnet
imperative.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_IMPERATIVE_H_
21 #define MXNET_IMPERATIVE_H_
22 
23 #include <mxnet/op_attr_types.h>
24 #include <mxnet/graph_attr_types.h>
25 #include <mxnet/c_api.h>
26 #include <nnvm/symbolic.h>
27 #include <nnvm/op.h>
28 #include <nnvm/graph.h>
29 #include <vector>
30 #include <atomic>
31 #include <utility>
32 #include <string>
33 #include <unordered_map>
34 
35 #include "./ndarray.h"
36 
37 namespace mxnet {
38 
// Attribute key under which optimization constraints are stored —
// presumably attached to graph nodes/symbols; TODO confirm against users
// of OPT_CONSTRAINT_ATTR elsewhere in the codebase.
39 constexpr char OPT_CONSTRAINT_ATTR[] = "__opt_constraint__";
// Bit-flag set of optimizations that must be disabled. Values are powers
// of two so they can be combined bitwise.
40 enum class OptConstraint : unsigned int {
41  None = 0,
42  DisableAMP = 1 << 0
43  // DisableQuantization = 1 << 1
44 };
// Underlying integer type of OptConstraint, for bitwise flag arithmetic.
45 using OptConstraint_int_t = std::underlying_type_t<OptConstraint>;
46 
59 
// Runtime state for imperative (define-by-run) execution of NDArray ops:
// autograd recording, deferred compute, and numpy-compatibility flags.
61 class Imperative {
62  public:
// Per-node autograd bookkeeping, stored in the node's `info` any-slot.
// NOTE(review): member declarations at original lines 66-68 and 74
// (`ctx`, `grad_req`, `state`, and the default constructor per the
// documentation index) are missing from this extraction.
64  class AGInfo {
65  public:
// Forward outputs recorded for this node.
69  std::vector<NDArray> outputs;
70  std::vector<NDArray> out_grads; // used to hold gradient arrays the user is
71  // interested in (marked variables)
73 
75 
// Drop the autograd info attached to `node`, unless the node still
// requires gradients (grad_req != kNullOp keeps the info alive).
76  static void Clear(const nnvm::ObjectPtr& node) {
77  if (node == nullptr || node->info.empty())
78  return;
79  AGInfo& info = Get(node);
80  if (info.grad_req != kNullOp)
81  return;
82  node->info.clear();
83  }
84 
// Fetch the AGInfo stored in the node's `info` slot; the slot must
// already contain an AGInfo (dmlc::get enforces this).
85  static AGInfo& Get(const nnvm::ObjectPtr& node) {
86  return dmlc::get<AGInfo>(node->info);
87  }
88 
// Default-construct an AGInfo inside the node's `info` slot and return it.
89  static AGInfo& Create(const nnvm::ObjectPtr& node) {
90  node->info.construct<AGInfo>();
91  return Get(node);
92  }
93 
// True when the array carries no autograd history.
94  static bool IsNone(const NDArray& arr) {
95  return arr.autograd_entry_.node == nullptr || arr.autograd_entry_.node->info.empty();
96  }
97 
// A "variable" node requires gradients and has exactly one output and
// one out-grad slot.
98  static bool IsVariable(const nnvm::ObjectPtr& node) {
99  AGInfo& info = Get(node);
100  return info.grad_req != kNullOp && info.outputs.size() == 1 && info.out_grads.size() == 1;
101  }
102  };
103 
// Per-node bookkeeping that enables deferred computation: records the
// inputs/outputs of a not-yet-executed operator so it can be run later.
105  class DCInfo {
106  public:
// Construct from the operator's input and output arrays.
107  explicit DCInfo(const std::vector<NDArray*>& inputs, const std::vector<NDArray*>& outputs);
108 
// Compute the outputs of the operator associated with `arr` (per the
// documentation index); defined out of line.
110  static void Compute(const NDArray& arr);
111 
// Fetch the DCInfo stored in the node's `info` slot; the slot must
// already contain a DCInfo.
112  static DCInfo& Get(const nnvm::ObjectPtr& node) {
113  return dmlc::get<DCInfo>(node->info);
114  }
115 
// True when the array has no deferred-compute history.
116  static bool IsNone(const NDArray& arr) {
117  return arr.deferredcompute_entry_.node == nullptr ||
118  arr.deferredcompute_entry_.node->info.empty();
119  }
120 
// Arrays without deferred-compute history count as already computed.
121  static bool IsComputed(const NDArray& arr) {
122  return IsNone(arr) || dmlc::get<DCInfo>(arr.deferredcompute_entry_.node->info).is_computed_;
123  }
124 
// Construct a DCInfo inside the node's `info` slot; defined out of line.
125  static DCInfo& Create(const nnvm::ObjectPtr& node,
126  const std::vector<NDArray*>& inputs,
127  const std::vector<NDArray*>& outputs);
128 
// Unconditionally drop the deferred-compute info attached to `node`.
129  static void Clear(const nnvm::ObjectPtr& node) {
130  if (node == nullptr || node->info.empty())
131  return;
132  node->info.clear();
133  }
134 
135  private:
136  friend class Imperative;
137 
// Copies of the input arrays (doc comments at original lines 138-145
// are missing from this extraction).
146  std::vector<NDArray> inputs_;
147 
// Non-owning handles to the inputs (doc comments at original lines
// 148-157 are missing from this extraction).
158  std::vector<const NDArray*> input_handles_;
159 
// Copies of the output arrays (doc comments at original lines 160-167
// are missing from this extraction).
168  std::vector<NDArray> outputs_;
169 
// Set once the deferred operator has actually been executed.
171  bool is_computed_ = false;
172  };
173 
// Whether training mode is on for the current thread (is_train_ is
// declared thread_local below).
175  bool is_training() const {
176  return is_train_;
177  }
// Set training mode; returns the previous value so callers can restore it.
179  bool set_is_training(bool is_train) {
180  bool old = is_train_;
181  is_train_ = is_train;
182  return old;
183  }
// Whether autograd operator recording is on for the current thread.
185  bool is_recording() const {
186  return is_recording_;
187  }
// NOTE(review): the opening line `bool set_is_recording(bool is_recording) {`
// (original line 189, per the documentation index) is missing from this
// extraction. The body sets the recording flag and returns the old value.
190  bool old = is_recording_;
191  is_recording_ = is_recording;
192  return old;
193  }
// Whether deferred compute mode is on for the current thread.
195  bool is_deferred_compute() const {
196  return is_deferred_compute_;
197  }
// NOTE(review): the opening line
// `bool set_is_deferred_compute(bool is_deferred_compute) {` (original
// line 199, per the documentation index) is missing from this extraction.
// The body sets the deferred-compute flag and returns the old value.
200  bool old = is_deferred_compute_;
201  is_deferred_compute_ = is_deferred_compute;
202  return old;
203  }
// Current numpy-shape compatibility status: GlobalOn(2), ThreadLocalOn(1)
// or Off(0). The global flag takes priority over the thread-local one.
207  int is_np_shape() const {
208  if (is_np_shape_global_) {
209  return NumpyShape::GlobalOn;
210  }
211  return is_np_shape_thread_local_ ? NumpyShape::ThreadLocalOn : NumpyShape::Off;
212  }
// NOTE(review): the opening line `bool set_is_np_shape(int is_np_shape) {`
// (original line 214, per the documentation index) is missing from this
// extraction. Switches numpy-shape mode Off / ThreadLocalOn / GlobalOn.
215  NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
// NOTE(review): `old` narrows the int from is_np_shape() to bool, so both
// GlobalOn(2) and ThreadLocalOn(1) are reported as `true` — verify callers
// only need an on/off answer.
216  bool old = this->is_np_shape();
217  switch (flag) {
218  case GlobalOn:
219  is_np_shape_global_ = true;
220  is_np_shape_thread_local_ = true;
221  break;
222  case ThreadLocalOn:
223  is_np_shape_thread_local_ = true;
224  break;
225  case Off:
226  is_np_shape_global_ = false;
227  is_np_shape_thread_local_ = false;
228  break;
229  }
230  return old;
231  }
// Whether numpy default-dtype compatibility is globally enabled.
234  bool is_np_default_dtype() const {
235  if (is_np_default_dtype_global_) {
236  return true;
237  }
238  return false;
239  }
// NOTE(review): the opening line
// `bool set_is_np_default_dtype(bool is_np_default_dtype) {` (original
// line 241, per the documentation index) is missing from this extraction.
// Sets the global default-dtype flag; returns the previous value.
242  bool old = this->is_np_default_dtype();
243  if (is_np_default_dtype) {
244  is_np_default_dtype_global_ = true;
245  } else {
246  is_np_default_dtype_global_ = false;
247  }
248  return old;
249  }
// NOTE(review): the opening line `OptConstraint get_opt_constraints() const {`
// (original line 251, per the documentation index) is missing from this
// extraction. Returns the current optimization constraints.
252  return opt_constraints_;
253  }
// NOTE(review): the opening line
// `OptConstraint set_opt_constraints(OptConstraint constraints) {`
// (original line 255, per the documentation index) is missing from this
// extraction. Replaces the constraints; returns the previous value.
256  OptConstraint old = opt_constraints_;
257  opt_constraints_ = constraints;
258  return old;
259  }
// Record an executed operator into the autograd graph. `p_save_inputs` /
// `p_save_outputs` optionally control which arrays are retained for the
// backward pass; defined out of line.
261  void RecordOp(nnvm::NodeAttrs&& attrs,
262  const std::vector<NDArray*>& inputs,
263  const std::vector<NDArray*>& outputs,
264  const OpStatePtr& state = OpStatePtr(),
265  std::vector<bool>* p_save_inputs = nullptr,
266  std::vector<bool>* p_save_outputs = nullptr);
// NOTE(review): the opening line
// `void RecordDeferredCompute(nnvm::NodeAttrs&& attrs,` (original line
// 268, per the documentation index) is missing from this extraction;
// the two lines below are its remaining parameters.
269  const std::vector<NDArray*>& inputs,
270  const std::vector<NDArray*>& outputs);
// Obtain a symbol representation of the deferred-compute session.
272  nnvm::Symbol GetDeferredComputeSymbol(const std::vector<NDArray*>& outputs);
// Associate arrays with variables for deferred compute.
274  void SetDeferredComputeVariable(NDArrayHandle* arrays, SymbolHandle* variables, const int num);
// Clear the deferred-compute info nodes associated with the arrays.
276  void DeferredComputeClear(NDArrayHandle* arrays, const int num);
// Invoke an operator imperatively; returns its (possibly empty) state.
278  OpStatePtr Invoke(const Context& default_ctx,
279  const nnvm::NodeAttrs& attrs,
280  const std::vector<NDArray*>& inputs,
281  const std::vector<NDArray*>& outputs);
// Lower-level invocation with explicit write requests and dispatch mode.
283  OpStatePtr InvokeOp(const Context& ctx,
284  const nnvm::NodeAttrs& attrs,
285  const std::vector<NDArray*>& inputs,
286  const std::vector<NDArray*>& outputs,
287  const std::vector<OpReqType>& req,
288  const DispatchMode dispatch_mode,
289  OpStatePtr state = OpStatePtr());
// Mark variables for gradient computation; `grad_reqs` parallels
// `variables` and `gradients` receives the results.
291  void MarkVariables(const std::vector<NDArray*>& variables,
292  const std::vector<uint32_t>& grad_reqs,
293  const std::vector<NDArray*>& gradients);
// Unmark nonleaf variables to free their memory.
295  void DropGrads(const std::vector<NDArray*>& variables);
// Compute gradients of `outputs` w.r.t. `variables`; `ograds` supplies
// head gradients. Returns the gradient arrays.
297  std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
298  const std::vector<NDArray*>& ograds,
299  const std::vector<NDArray*>& variables,
300  bool is_train,
301  bool retain_graph,
302  bool create_graph);
// Return the marked nonleaf nodes of `sym`.
304  std::vector<nnvm::ObjectPtr> ListNonleafVariables(const nnvm::Symbol& sym) const;
// Singleton accessor; defined out of line.
306  static Imperative* Get();
// Whether op-execution bulking should be used during inference
// (env var MXNET_EXEC_BULK_EXEC_INFERENCE, default true).
308  static bool PreferBulkExecInference() {
309  return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
310  }
// Whether op-execution bulking should be used during training
// (env var MXNET_EXEC_BULK_EXEC_TRAIN, default true).
312  static bool PreferBulkExecTrain() {
313  return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", true);
314  }
// Max op nodes per bulk in the training forward pass; the FWD-specific
// env var overrides the generic one, which defaults to 15.
316  static int BulkExecMaxNodeTrainFwd() {
317  return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD",
318  dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
319  }
// Max op nodes per bulk in the training backward pass; the BWD-specific
// env var overrides the generic one, which defaults to 15.
321  static int BulkExecMaxNodeTrainBwd() {
322  return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD",
323  dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
324  }
325 
326  private:
327  friend class NDArray;
// Private constructor (singleton — obtained via Get()); seeds the
// backward bulk size from the environment when bulking is enabled.
329  Imperative() {
330  if (PreferBulkExecTrain())
331  backward_bulk_size_ = BulkExecMaxNodeTrainBwd();
332  }
// Determine which inputs/outputs of `node` must be saved for backward;
// results are written into the two out-parameters. Defined out of line.
334  void GetBackwardDependency(const nnvm::ObjectPtr& node,
335  uint32_t num_inputs,
336  uint32_t num_outputs,
337  std::vector<bool>* p_save_inputs,
338  std::vector<bool>* p_save_outputs);
// Per-thread mode flags; MX_THREAD_LOCAL fallback when the compiler
// lacks C++11 thread_local support.
340 #if DMLC_CXX11_THREAD_LOCAL
341  static thread_local bool is_train_;
342  static thread_local bool is_recording_;
343  static thread_local bool is_deferred_compute_;
344  static thread_local OptConstraint opt_constraints_;
345  // TODO(junwu): Added numpy compatibility switch for backward compatibility.
346  // Delete it in the next major release.
347  static thread_local bool is_np_shape_thread_local_;
348 #else
349  static MX_THREAD_LOCAL bool is_train_;
350  static MX_THREAD_LOCAL bool is_recording_;
351  static MX_THREAD_LOCAL bool is_deferred_compute_;
352  static MX_THREAD_LOCAL OptConstraint opt_constraints_;
353  // TODO(junwu): Added numpy compatibility switch for backward compatibility.
354  // Delete it in the next major release.
355  static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
356 #endif
// Process-wide flags (not thread-local, unlike the ones above).
357  bool is_np_shape_global_{false};
358  bool is_np_default_dtype_global_{false};
// Monotonic counters; atomic since they may be bumped from many threads.
360  std::atomic<uint64_t> node_count_{0};
362  std::atomic<uint64_t> variable_count_{0};
// Max nodes per backward bulk; 0 until set in the constructor.
364  int backward_bulk_size_{0};
365 };
366 
367 } // namespace mxnet
368 #endif // MXNET_IMPERATIVE_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::ThreadLocalOn
@ ThreadLocalOn
Definition: imperative.h:57
mxnet::Imperative::BulkExecMaxNodeTrainFwd
static int BulkExecMaxNodeTrainFwd()
The max number of op nodes in a bulk during forward pass of training.
Definition: imperative.h:316
mxnet::OpStatePtr
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const.
Definition: op_attr_types.h:148
mxnet::Imperative::DropGrads
void DropGrads(const std::vector< NDArray * > &variables)
unmark nonleaf variables to free the memory.
mxnet::OptConstraint::None
@ None
mxnet::Imperative::set_opt_constraints
OptConstraint set_opt_constraints(OptConstraint constraints)
set optimization constraints.
Definition: imperative.h:255
mxnet::Imperative::AGInfo::Get
static AGInfo & Get(const nnvm::ObjectPtr &node)
Definition: imperative.h:85
mxnet::Imperative::set_is_np_shape
bool set_is_np_shape(int is_np_shape)
specify numpy compatibility off, thread local on or global on.
Definition: imperative.h:214
mxnet::Imperative::Invoke
OpStatePtr Invoke(const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
mxnet::Imperative::MarkVariables
void MarkVariables(const std::vector< NDArray * > &variables, const std::vector< uint32_t > &grad_reqs, const std::vector< NDArray * > &gradients)
mark variables for computing gradients.
mxnet::OpReqType
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
mxnet::Imperative::ListNonleafVariables
std::vector< nnvm::ObjectPtr > ListNonleafVariables(const nnvm::Symbol &sym) const
Return the marked nonleaf nodes.
mxnet::Imperative::is_training
bool is_training() const
whether training mode is on.
Definition: imperative.h:175
mxnet::Imperative::set_is_recording
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd.
Definition: imperative.h:189
mxnet::DispatchMode
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:122
mxnet::Imperative::DeferredComputeClear
void DeferredComputeClear(NDArrayHandle *arrays, const int num)
clear info node associated with array
op.h
Operator information structure.
mxnet::Imperative::Backward
std::vector< NDArray * > Backward(const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
compute the gradient of outputs w.r.t variables.
mxnet::kNullOp
@ kNullOp
no operation, do not write anything
Definition: op_attr_types.h:47
mxnet::Imperative::AGInfo::Create
static AGInfo & Create(const nnvm::ObjectPtr &node)
Definition: imperative.h:89
mxnet::Imperative::DCInfo::IsComputed
static bool IsComputed(const NDArray &arr)
Definition: imperative.h:121
mxnet::Imperative::is_np_shape
int is_np_shape() const
return current numpy compatibility status, GlobalOn(2), ThreadLocalOn(1), Off(0).
Definition: imperative.h:207
mxnet::Imperative::DCInfo::Create
static DCInfo & Create(const nnvm::ObjectPtr &node, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
mxnet::Imperative::PreferBulkExecInference
static bool PreferBulkExecInference()
Should op execution bulking be employed during inference.
Definition: imperative.h:308
mxnet::Imperative::AGInfo::Clear
static void Clear(const nnvm::ObjectPtr &node)
Definition: imperative.h:76
mxnet::Imperative::GetDeferredComputeSymbol
nnvm::Symbol GetDeferredComputeSymbol(const std::vector< NDArray * > &outputs)
obtain symbol representation of deferred compute session.
mxnet::Imperative::set_is_training
bool set_is_training(bool is_train)
turn on or turn off training mode.
Definition: imperative.h:179
nnvm::Symbol
Symbol is help class used to represent the operator node in Graph.
Definition: symbolic.h:50
mxnet::OptConstraint
OptConstraint
Definition: imperative.h:40
mxnet::NumpyDefaultDtype
NumpyShape NumpyDefaultDtype
Definition: imperative.h:58
mxnet::Imperative::is_np_default_dtype
bool is_np_default_dtype() const
return current numpy default dtype compatibility status.
Definition: imperative.h:234
mxnet::Imperative::DCInfo::IsNone
static bool IsNone(const NDArray &arr)
Definition: imperative.h:116
mxnet::Off
@ Off
Definition: imperative.h:57
nnvm::NodeAttrs
The attributes of the current operation node. Usually are additional parameters like axis,...
Definition: node.h:107
mxnet::Imperative::RecordDeferredCompute
void RecordDeferredCompute(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
record a deferred-compute operator for the given inputs and outputs.
mxnet::Imperative::is_recording
bool is_recording() const
whether operator recording is on.
Definition: imperative.h:185
nnvm::ObjectPtr
std::shared_ptr< Node > ObjectPtr
we always used ObjectPtr for a reference pointer to the node, so this alias can be changed in case.
Definition: node.h:49
mxnet::Imperative::InvokeOp
OpStatePtr InvokeOp(const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
mxnet::Imperative::PreferBulkExecTrain
static bool PreferBulkExecTrain()
Should op execution bulking be employed during training.
Definition: imperative.h:312
symbolic.h
Symbolic graph construction API.
mxnet::NDArray
ndarray interface
Definition: ndarray.h:82
mxnet::Imperative::DCInfo::Compute
static void Compute(const NDArray &arr)
Compute the outputs of the associated operator.
c_api.h
C API of mxnet.
mxnet::Imperative::AGInfo::AGInfo
AGInfo()
Definition: imperative.h:74
mxnet::Imperative::Get
static Imperative * Get()
mxnet::Imperative::SetDeferredComputeVariable
void SetDeferredComputeVariable(NDArrayHandle *arrays, SymbolHandle *variables, const int num)
associate arrays with variables for deferred compute
mxnet::Imperative::get_opt_constraints
OptConstraint get_opt_constraints() const
return current optimization constraints.
Definition: imperative.h:251
mxnet::GlobalOn
@ GlobalOn
Definition: imperative.h:57
mxnet::Context
Context information about the execution environment.
Definition: base.h:90
mxnet::Imperative::set_is_deferred_compute
bool set_is_deferred_compute(bool is_deferred_compute)
turn on or turn off deferred compute mode.
Definition: imperative.h:199
mxnet::OptConstraint_int_t
std::underlying_type_t< OptConstraint > OptConstraint_int_t
Definition: imperative.h:45
mxnet::Imperative
runtime functions for NDArray
Definition: imperative.h:61
mxnet::Imperative::is_deferred_compute
bool is_deferred_compute() const
whether deferred compute mode is on.
Definition: imperative.h:195
mxnet::Imperative::DCInfo
DCInfo datastructure to enable deferred computation.
Definition: imperative.h:105
mxnet::Imperative::AGInfo
Definition: imperative.h:64
op_attr_types.h
Additional operator attributes beside the ones provided by NNVM.
graph_attr_types.h
Data structures that can appear in graph attributes.
mxnet::NDArrayHandle
Definition: ndarray_handle.h:40
graph.h
Configuration of nnvm as well as basic data structures.
mxnet::Imperative::AGInfo::outputs
std::vector< NDArray > outputs
Definition: imperative.h:69
SymbolHandle
void * SymbolHandle
handle to a symbol that can be bind as operator
Definition: c_api.h:82
mxnet::Imperative::DCInfo::Get
static DCInfo & Get(const nnvm::ObjectPtr &node)
Definition: imperative.h:112
mxnet::OPT_CONSTRAINT_ATTR
constexpr char OPT_CONSTRAINT_ATTR[]
Definition: imperative.h:39
mxnet::OptConstraint::DisableAMP
@ DisableAMP
mxnet::Imperative::RecordOp
void RecordOp(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
record an operator invocation into the autograd graph.
mxnet::Imperative::AGInfo::IsVariable
static bool IsVariable(const nnvm::ObjectPtr &node)
Definition: imperative.h:98
mxnet::Imperative::AGInfo::ctx
Context ctx
Definition: imperative.h:66
mxnet::Imperative::AGInfo::fresh_out_grad
bool fresh_out_grad
Definition: imperative.h:72
mxnet::Imperative::DCInfo::DCInfo
DCInfo(const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
ndarray.h
NDArray interface that handles array arithmetic.
mxnet::Imperative::AGInfo::grad_req
OpReqType grad_req
Definition: imperative.h:67
nnvm::NodeEntry::node
ObjectPtr node
the source node of this data
Definition: node.h:65
mxnet::Imperative::set_is_np_default_dtype
bool set_is_np_default_dtype(bool is_np_default_dtype)
specify numpy default dtype off or global on.
Definition: imperative.h:241
mxnet::Imperative::AGInfo::out_grads
std::vector< NDArray > out_grads
Definition: imperative.h:70
mxnet::Imperative::BulkExecMaxNodeTrainBwd
static int BulkExecMaxNodeTrainBwd()
The max number of op nodes in a bulk during backward pass of training.
Definition: imperative.h:321
mxnet::Imperative::AGInfo::IsNone
static bool IsNone(const NDArray &arr)
Definition: imperative.h:94
mxnet::Imperative::AGInfo::state
OpStatePtr state
Definition: imperative.h:68
mxnet::NumpyShape
NumpyShape
there are three numpy shape flags based on priority. GlobalOn turn on numpy shape flag globally,...
Definition: imperative.h:57
mxnet::Imperative::DCInfo::Clear
static void Clear(const nnvm::ObjectPtr &node)
Definition: imperative.h:129