mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
imperative.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_IMPERATIVE_H_
21 #define MXNET_IMPERATIVE_H_
22 
23 #include <mxnet/op_attr_types.h>
24 #include <mxnet/graph_attr_types.h>
25 #include <mxnet/c_api.h>
26 #include <nnvm/symbolic.h>
27 #include <nnvm/op.h>
28 #include <nnvm/graph.h>
29 #include <vector>
30 #include <atomic>
31 #include <utility>
32 #include <string>
33 #include <unordered_map>
34 
35 #include "./ndarray.h"
36 
37 namespace mxnet {
39 struct CachedOpParam : public dmlc::Parameter<CachedOpParam> {
40  uint32_t inline_limit;
44  DMLC_DECLARE_FIELD(inline_limit)
45  .set_default(2)
46  .describe("Maximum number of operators that can be inlined.");
47  DMLC_DECLARE_FIELD(forward_bulk_size)
48  .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
49  .describe("Segment size of bulk execution during forward pass.");
50  DMLC_DECLARE_FIELD(backward_bulk_size)
51  .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
52  .describe("Segment size of bulk execution during backward pass.");
53  }
54 };
56 class Imperative {
57  public:
59  class AGInfo {
60  public:
64  std::vector<NDArray> outputs;
65  std::vector<NDArray> out_grads;
67 
68  AGInfo() :
69  grad_req(kNullOp), fresh_out_grad(false) {}
70 
71  static void Clear(const nnvm::NodePtr& node) {
72  if (node == nullptr || node->info.empty()) return;
73  AGInfo& info = Get(node);
74  if (info.grad_req != kNullOp) return;
75  node->info.clear();
76  }
77 
78  static AGInfo& Get(const nnvm::NodePtr& node) {
79  return dmlc::get<AGInfo>(node->info);
80  }
81 
82  static AGInfo& Create(const nnvm::NodePtr& node) {
83  node->info.construct<AGInfo>();
84  return Get(node);
85  }
86 
87  static bool IsNone(const NDArray& arr) {
88  return arr.entry_.node == nullptr || arr.entry_.node->info.empty();
89  }
90 
91  static bool IsVariable(const nnvm::NodePtr& node) {
92  AGInfo& info = Get(node);
93  return info.grad_req != kNullOp && info.outputs.size() == 1
94  && info.out_grads.size() == 1;
95  }
96  };
97  class CachedOp {
98  public:
99  CachedOp(const nnvm::Symbol& sym,
100  const std::vector<std::pair<std::string, std::string> >& kwargs);
101  uint32_t num_inputs() {
102  return fwd_graph_.indexed_graph().input_nodes().size();
103  }
104  uint32_t num_outputs() {
105  return fwd_graph_.outputs.size();
106  }
107  uint32_t num_backward_inputs() {
108  return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
109  }
110  std::vector<bool>& save_inputs() {
111  return save_inputs_;
112  }
113  std::vector<bool>& save_outputs() {
114  return save_outputs_;
115  }
116  const std::unordered_set<uint32_t>& mutable_input_nodes() {
117  return fwd_graph_.indexed_graph().mutable_input_nodes();
118  }
119  nnvm::Graph GetForwardGraph(const bool recording,
120  const std::vector<NDArray*>& inputs);
121  nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
122  const std::vector<OpReqType>& reqs,
123  const std::vector<NDArray*>& inputs);
124  std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
125  const std::vector<nnvm::NodeEntry>& ograds);
126  void Forward(const std::shared_ptr<CachedOp>& op_ptr,
127  const std::vector<NDArray*>& inputs,
128  const std::vector<NDArray*>& outputs);
129  void Backward(const bool retain_graph,
130  const OpStatePtr& state,
131  const std::vector<NDArray*>& inputs,
132  const std::vector<OpReqType>& reqs,
133  const std::vector<NDArray*>& outputs);
134 
135  private:
136  struct CachedOpState {
137  std::vector<NDArray> buff;
138  std::vector<OpStatePtr> states;
139  };
140  std::mutex mutex_;
141  CachedOpParam param_;
142  nnvm::Graph fwd_graph_;
143  nnvm::Graph grad_graph_;
144  nnvm::Graph full_graph_;
145  bool inlining_;
146  std::vector<nnvm::NodeEntry> ograd_entries_;
147  std::vector<bool> curr_grad_req_;
148  std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
149  std::vector<uint32_t> bwd_input_eid_;
150  std::vector<bool> save_inputs_, save_outputs_;
151  };
153  bool is_training() const {
154  return is_train_;
155  }
157  bool set_is_training(bool is_train) {
158  bool old = is_train_;
159  is_train_ = is_train;
160  return old;
161  }
163  bool is_recording() const {
164  return is_recording_;
165  }
168  bool old = is_recording_;
169  is_recording_ = is_recording;
170  return old;
171  }
173  void RecordOp(nnvm::NodeAttrs&& attrs,
174  const std::vector<NDArray*>& inputs,
175  const std::vector<NDArray*>& outputs,
176  const OpStatePtr& state = OpStatePtr(),
177  std::vector<bool>* p_save_inputs = nullptr,
178  std::vector<bool>* p_save_outputs = nullptr);
180  OpStatePtr Invoke(const Context& default_ctx,
181  const nnvm::NodeAttrs& attrs,
182  const std::vector<NDArray*>& inputs,
183  const std::vector<NDArray*>& outputs);
185  OpStatePtr InvokeOp(const Context& ctx,
186  const nnvm::NodeAttrs& attrs,
187  const std::vector<NDArray*>& inputs,
188  const std::vector<NDArray*>& outputs,
189  const std::vector<OpReqType>& req,
190  const DispatchMode dispatch_mode,
191  OpStatePtr state = OpStatePtr());
193  void MarkVariables(const std::vector<NDArray*>& variables,
194  const std::vector<mx_uint>& grad_reqs,
195  const std::vector<NDArray*>& gradients);
197  std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
198  const std::vector<NDArray*>& ograds,
199  const std::vector<NDArray*>& variables,
200  bool is_train, bool retain_graph,
201  bool create_graph);
203  static Imperative* Get();
204 
205  private:
206  friend class NDArray;
208  Imperative() {
209  if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
210  backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
211  }
212  }
214  void GetBackwardDependency(
215  const nnvm::NodePtr& node,
216  uint32_t num_inputs, uint32_t num_outputs,
217  std::vector<bool> *p_save_inputs,
218  std::vector<bool> *p_save_outputs);
219  void RunGraph(
220  const bool retain_graph,
221  const nnvm::IndexedGraph& idx,
222  const std::vector<NDArray*> arrays,
223  size_t node_start, size_t node_end,
224  std::vector<OpReqType>&& array_reqs,
225  std::vector<uint32_t>&& ref_count,
226  std::vector<OpStatePtr> *p_states,
227  const DispatchModeVector& dispatch_modes);
229 #if DMLC_CXX11_THREAD_LOCAL
230  static thread_local bool is_train_;
231  static thread_local bool is_recording_;
232 #else
233  static MX_THREAD_LOCAL bool is_train_;
234  static MX_THREAD_LOCAL bool is_recording_;
235 #endif
236 
237  std::atomic<uint64_t> node_count_{0};
239  std::atomic<uint64_t> variable_count_{0};
241  int backward_bulk_size_{0};
242 };
243 
244 using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
245 
246 } // namespace mxnet
247 #endif // MXNET_IMPERATIVE_H_
std::vector< nnvm::NodeEntry > Gradient(const nnvm::NodePtr &node, const std::vector< nnvm::NodeEntry > &ograds)
C API of mxnet.
bool is_recording() const
whether operator recording is on.
Definition: imperative.h:163
static bool IsNone(const NDArray &arr)
Definition: imperative.h:87
uint32_t num_outputs()
Definition: imperative.h:104
uint32_t backward_bulk_size
Definition: imperative.h:42
static AGInfo & Create(const nnvm::NodePtr &node)
Definition: imperative.h:82
void Forward(const std::shared_ptr< CachedOp > &op_ptr, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
std::vector< DispatchMode > DispatchModeVector
The result holder of dispatch mode of each Node in the graph.
Definition: graph_attr_types.h:60
nnvm::Graph GetBackwardGraph(const OpStatePtr &state, const std::vector< OpReqType > &reqs, const std::vector< NDArray * > &inputs)
bool is_training() const
whether operator training mode is on.
Definition: imperative.h:153
no operation, do not write anything
Definition: op_attr_types.h:47
bool set_is_training(bool is_train)
turn on or turn off operator recording for autograd.
Definition: imperative.h:157
uint32_t inline_limit
Definition: imperative.h:40
CachedOp Parameters.
Definition: imperative.h:39
std::vector< NDArray * > Backward(const std::vector< NDArray * > &outputs, const std::vector< NDArray * > &ograds, const std::vector< NDArray * > &variables, bool is_train, bool retain_graph, bool create_graph)
compute the gradient of outputs w.r.t variables.
static void Clear(const nnvm::NodePtr &node)
Definition: imperative.h:71
Additional operator attributes beside the ones provided by NNVM.
OpStatePtr Invoke(const Context &default_ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs)
AGInfo()
Definition: imperative.h:68
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:107
CachedOp(const nnvm::Symbol &sym, const std::vector< std::pair< std::string, std::string > > &kwargs)
std::vector< NDArray > outputs
Definition: imperative.h:64
NDArray interface that handles array arithmetic.
Definition: imperative.h:59
uint32_t num_backward_inputs()
Definition: imperative.h:107
bool set_is_recording(bool is_recording)
turn on or turn off operator recording for autograd.
Definition: imperative.h:167
Definition: imperative.h:97
void Backward(const bool retain_graph, const OpStatePtr &state, const std::vector< NDArray * > &inputs, const std::vector< OpReqType > &reqs, const std::vector< NDArray * > &outputs)
bool fresh_out_grad
Definition: imperative.h:66
OpStatePtr state
Definition: imperative.h:63
std::vector< NDArray > out_grads
Definition: imperative.h:65
Data structures that can appear in graph attributes.
void RecordOp(nnvm::NodeAttrs &&attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const OpStatePtr &state=OpStatePtr(), std::vector< bool > *p_save_inputs=nullptr, std::vector< bool > *p_save_outputs=nullptr)
to record operator, return corresponding node.
std::shared_ptr< Imperative::CachedOp > CachedOpPtr
Definition: imperative.h:244
OpReqType grad_req
Definition: imperative.h:62
uint32_t forward_bulk_size
Definition: imperative.h:41
DMLC_DECLARE_PARAMETER(CachedOpParam)
Definition: imperative.h:43
Context ctx
Definition: imperative.h:61
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
std::vector< bool > & save_inputs()
Definition: imperative.h:110
OpStatePtr InvokeOp(const Context &ctx, const nnvm::NodeAttrs &attrs, const std::vector< NDArray * > &inputs, const std::vector< NDArray * > &outputs, const std::vector< OpReqType > &req, const DispatchMode dispatch_mode, OpStatePtr state=OpStatePtr())
static bool IsVariable(const nnvm::NodePtr &node)
Definition: imperative.h:91
static Imperative * Get()
runtime functions for NDArray
Definition: imperative.h:56
static AGInfo & Get(const nnvm::NodePtr &node)
Definition: imperative.h:78
const std::unordered_set< uint32_t > & mutable_input_nodes()
Definition: imperative.h:116
void MarkVariables(const std::vector< NDArray * > &variables, const std::vector< mx_uint > &grad_reqs, const std::vector< NDArray * > &gradients)
mark variables for computing gradients.
uint32_t num_inputs()
Definition: imperative.h:101
Context information about the execution environment.
Definition: base.h:142
nnvm::Graph GetForwardGraph(const bool recording, const std::vector< NDArray * > &inputs)
ndarray interface
Definition: ndarray.h:79
std::vector< bool > & save_outputs()
Definition: imperative.h:113
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:123