/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef MXNET_IMPERATIVE_H_
#define MXNET_IMPERATIVE_H_

#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include <mxnet/c_api.h>
#include <nnvm/symbolic.h>
#include <nnvm/op.h>
#include <nnvm/graph.h>
#include <vector>
#include <atomic>
#include <utility>
#include <string>
#include <unordered_map>

#include "./ndarray.h"

namespace mxnet {
/*! \brief there are three numpy shape flags based on priority.
 * GlobalOn
 *   turn on the numpy shape flag globally; it includes thread local.
 *   The flag can be seen in any thread.
 * ThreadLocalOn
 *   only turn on the thread-local numpy shape flag; it cannot be seen
 *   in other threads.
 * Off
 *   turn off the numpy shape flag globally.
 */
enum NumpyShape { Off, ThreadLocalOn, GlobalOn };
/*! \brief runtime functions for NDArray */
class Imperative {
 public:
  /*! \brief per-node autograd bookkeeping attached to graph nodes. */
  class AGInfo {
   public:
    Context ctx;
    OpReqType grad_req;
    OpStatePtr state;
    std::vector<NDArray> outputs;
    std::vector<NDArray> out_grads;
    bool fresh_out_grad;

    AGInfo() :
        grad_req(kNullOp), fresh_out_grad(false) {}

    static void Clear(const nnvm::ObjectPtr& node) {
      if (node == nullptr || node->info.empty()) return;
      AGInfo& info = Get(node);
      if (info.grad_req != kNullOp) return;
      node->info.clear();
    }

    static AGInfo& Get(const nnvm::ObjectPtr& node) {
      return dmlc::get<AGInfo>(node->info);
    }

    static AGInfo& Create(const nnvm::ObjectPtr& node) {
      node->info.construct<AGInfo>();
      return Get(node);
    }

    static bool IsNone(const NDArray& arr) {
      return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
    }

    static bool IsVariable(const nnvm::ObjectPtr& node) {
      AGInfo& info = Get(node);
      return info.grad_req != kNullOp && info.outputs.size() == 1
             && info.out_grads.size() == 1;
    }
  };
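
  /*
   * Illustrative only: a minimal sketch of the AGInfo lifecycle on a graph
   * node. The bare node below is created just for demonstration; in real use
   * RecordOp and MarkVariables manage AGInfo behind the scenes.
   *
   * \code
   * nnvm::ObjectPtr n = nnvm::Node::Create();
   * Imperative::AGInfo& info = Imperative::AGInfo::Create(n);  // constructs AGInfo in n->info
   * info.grad_req = kWriteTo;        // node now participates in autograd
   * Imperative::AGInfo::Clear(n);    // no-op: grad_req != kNullOp
   * info.grad_req = kNullOp;
   * Imperative::AGInfo::Clear(n);    // now actually clears n->info
   * \endcode
   */
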
  /*! \brief whether training mode is on for autograd. */
  bool is_training() const {
    return is_train_;
  }
  /*! \brief turn training mode on or off; returns the previous value. */
  bool set_is_training(bool is_train) {
    bool old = is_train_;
    is_train_ = is_train;
    return old;
  }
  /*! \brief whether operator recording is on for autograd. */
  bool is_recording() const {
    return is_recording_;
  }
  /*! \brief turn operator recording on or off; returns the previous value. */
  bool set_is_recording(bool is_recording) {
    bool old = is_recording_;
    is_recording_ = is_recording;
    return old;
  }
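
  /*
   * Illustrative only: the save/restore pattern the toggles above are designed
   * for, mirroring the scoped autograd contexts in the frontends.
   * `forward_pass` is a hypothetical stand-in for user code.
   *
   * \code
   * Imperative* imp = Imperative::Get();
   * bool old_train = imp->set_is_training(true);
   * bool old_rec = imp->set_is_recording(true);
   * forward_pass();                   // hypothetical; ops get recorded for autograd
   * imp->set_is_recording(old_rec);   // restore the previous state
   * imp->set_is_training(old_train);
   * \endcode
   */
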
  /*! \brief return the current numpy compatibility status:
   *  GlobalOn(2), ThreadLocalOn(1), Off(0).
   */
  int is_np_shape() const {
    if (is_np_shape_global_) {
      return 2;
    }
    return is_np_shape_thread_local_ ? 1 : 0;
  }

  /*! \brief return the current numpy default dtype compatibility status. */
  bool is_np_default_dtype() const {
    return is_np_default_dtype_global_;
  }

  /*! \brief specify numpy compatibility: off, thread-local on, or global on.
   *  Returns the previous status as a bool.
   */
  bool set_is_np_shape(int is_np_shape) {
    NumpyShape flag = static_cast<NumpyShape>(is_np_shape);
    bool old = this->is_np_shape();
    switch (flag) {
      case GlobalOn:
        is_np_shape_global_ = true;
        is_np_shape_thread_local_ = true;
        break;
      case ThreadLocalOn:
        is_np_shape_thread_local_ = true;
        break;
      case Off:
        is_np_shape_global_ = false;
        is_np_shape_thread_local_ = false;
        break;
    }
    return old;
  }
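
  /*
   * Illustrative only: how the three NumpyShape states interact. GlobalOn sets
   * both underlying flags, ThreadLocalOn only the thread-local one, and Off
   * clears both; note that ThreadLocalOn does not clear an earlier GlobalOn.
   *
   * \code
   * Imperative* imp = Imperative::Get();
   * imp->set_is_np_shape(GlobalOn);       // is_np_shape() == 2 in every thread
   * imp->set_is_np_shape(ThreadLocalOn);  // still 2: the global flag remains set
   * imp->set_is_np_shape(Off);            // is_np_shape() == 0
   * \endcode
   */
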
  /*! \brief record an imperatively executed operator for autograd. */
  void RecordOp(nnvm::NodeAttrs&& attrs,
                const std::vector<NDArray*>& inputs,
                const std::vector<NDArray*>& outputs,
                const OpStatePtr& state = OpStatePtr(),
                std::vector<bool>* p_save_inputs = nullptr,
                std::vector<bool>* p_save_outputs = nullptr);
  /*! \brief invoke an operator imperatively. */
  OpStatePtr Invoke(const Context& default_ctx,
                    const nnvm::NodeAttrs& attrs,
                    const std::vector<NDArray*>& inputs,
                    const std::vector<NDArray*>& outputs);
  /*! \brief invoke an operator with explicit req types and dispatch mode. */
  OpStatePtr InvokeOp(const Context& ctx,
                      const nnvm::NodeAttrs& attrs,
                      const std::vector<NDArray*>& inputs,
                      const std::vector<NDArray*>& outputs,
                      const std::vector<OpReqType>& req,
                      const DispatchMode dispatch_mode,
                      OpStatePtr state = OpStatePtr());
  /*! \brief mark variables for computing gradients. */
  void MarkVariables(const std::vector<NDArray*>& variables,
                     const std::vector<uint32_t>& grad_reqs,
                     const std::vector<NDArray*>& gradients);
  /*! \brief compute the gradient of outputs w.r.t. variables. */
  std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
                                 const std::vector<NDArray*>& ograds,
                                 const std::vector<NDArray*>& variables,
                                 bool is_train, bool retain_graph,
                                 bool create_graph);
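
  /*
   * Illustrative only: a sketch of the autograd round trip these methods
   * implement, as a frontend might drive it. The NDArrays are hypothetical;
   * the omitted middle step stands for any sequence of Invoke calls made
   * while recording is on, each captured via RecordOp. Passing an empty
   * variables list to Backward is assumed to route gradients into the
   * buffers registered via MarkVariables, as the Python autograd.backward
   * does, with head_grad supplying the initial output gradient.
   *
   * \code
   * Imperative* imp = Imperative::Get();
   * NDArray x(TShape({2, 2}), Context::CPU());          // hypothetical input
   * NDArray grad_x(TShape({2, 2}), Context::CPU());     // gradient buffer for x
   * NDArray head_grad(TShape({2, 2}), Context::CPU());  // assumed filled with ones
   * NDArray y;                                          // output, written by Invoke
   * imp->MarkVariables({&x}, {kWriteTo}, {&grad_x});    // request dy/dx into grad_x
   * bool old = imp->set_is_recording(true);
   * // ... compute y from x via Invoke; each call is recorded for autograd ...
   * imp->set_is_recording(old);
   * imp->Backward({&y}, {&head_grad}, {}, true, false, false);
   * // is_train=true, retain_graph=false, create_graph=false;
   * // grad_x now holds the gradient of y w.r.t. x
   * \endcode
   */
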
  /*! \brief return the singleton Imperative instance. */
  static Imperative* Get();
  /*! \brief should op execution bulking be employed during inference. */
  static bool PreferBulkExecInference() {
    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
  }
  /*! \brief should op execution bulking be employed during training. */
  static bool PreferBulkExecTrain() {
    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", true);
  }
  /*! \brief the max number of op nodes in a bulk during the forward pass of training. */
  static int BulkExecMaxNodeTrainFwd() {
    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD",
                        dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
  }
  /*! \brief the max number of op nodes in a bulk during the backward pass of training. */
  static int BulkExecMaxNodeTrainBwd() {
    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD",
                        dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
  }
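
  /*
   * Illustrative only: how the cascading environment defaults above resolve.
   * With nothing set, both getters return 15; MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN
   * overrides both directions at once, while the _FWD/_BWD variants override
   * each direction individually.
   *
   * \code
   * int fwd = Imperative::BulkExecMaxNodeTrainFwd();  // 15 unless overridden
   * int bwd = Imperative::BulkExecMaxNodeTrainBwd();  // 15 unless overridden
   * \endcode
   */
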

 private:
  friend class NDArray;
  /*! \brief private constructor: obtain the instance via Get(). */
  Imperative() {
    if (PreferBulkExecTrain())
      backward_bulk_size_ = BulkExecMaxNodeTrainBwd();
  }
  /*! \brief find which inputs/outputs must be saved for the backward pass. */
  void GetBackwardDependency(
      const nnvm::ObjectPtr& node,
      uint32_t num_inputs, uint32_t num_outputs,
      std::vector<bool>* p_save_inputs,
      std::vector<bool>* p_save_outputs);
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local bool is_train_;
  static thread_local bool is_recording_;
  // TODO(junwu): Added numpy compatibility switch for backward compatibility.
  // Delete it in the next major release.
  static thread_local bool is_np_shape_thread_local_;
#else
  static MX_THREAD_LOCAL bool is_train_;
  static MX_THREAD_LOCAL bool is_recording_;
  // TODO(junwu): Added numpy compatibility switch for backward compatibility.
  // Delete it in the next major release.
  static MX_THREAD_LOCAL bool is_np_shape_thread_local_;
#endif
  bool is_np_shape_global_{false};
  bool is_np_default_dtype_global_{false};
  /*! \brief node count used for naming */
  std::atomic<uint64_t> node_count_{0};
  /*! \brief variable count used for naming */
  std::atomic<uint64_t> variable_count_{0};
  /*! \brief default backward bulk size */
  int backward_bulk_size_{0};
};

}  // namespace mxnet
#endif  // MXNET_IMPERATIVE_H_