executor.h
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_CPP_EXECUTOR_H_
27 #define MXNET_CPP_EXECUTOR_H_
28 
29 #include <vector>
30 #include <map>
31 #include <set>
32 #include <string>
33 #include <algorithm>
34 #include "mxnet-cpp/base.h"
35 #include "mxnet-cpp/symbol.h"
36 
37 namespace mxnet {
38 namespace cpp {
39 
40 class Optimizer;
41 
45 class Executor {
46  public:
47  Executor(const Symbol& symbol,
48  Context context,
49  const std::vector<NDArray>& arg_arrays,
50  const std::vector<NDArray>& grad_arrays,
51  const std::vector<OpReqType>& grad_reqs,
52  const std::vector<NDArray>& aux_arrays,
53  const std::map<std::string, Context>& group_to_ctx = std::map<std::string, Context>(),
54  Executor* shared_exec = nullptr);
55  explicit Executor(const CachedOpHandle& h) {
56  handle_ = h;
57  }
62  void Forward(bool is_train) {
63  std::vector<NDArrayHandle> arg_handles;
64  for (const auto& array : combined_arrays) {
65  arg_handles.push_back(array.GetHandle());
66  }
67  int prev_is_record = 0;
68  int prev_train_mode = 0;
69  CHECK_EQ(MXAutogradSetIsRecording(1, &prev_is_record), 0);
70  if (is_train == true) {
71  CHECK_EQ(MXAutogradSetIsTraining(1, &prev_train_mode), 0);
72  }
73  std::vector<NDArrayHandle> output_handles;
74  std::transform(
75  outputs.begin(), outputs.end(), std::back_inserter(output_handles), [](NDArray& a) {
76  return a.GetHandle();
77  });
78  int out_size = 0;
79  NDArrayHandle* out_array = nullptr;
80  CHECK_EQ(MXInvokeCachedOp(handle_,
81  arg_handles.size(),
 82  arg_handles.data(),
 83  device_type,
84  device_id,
85  &out_size,
86  &out_array,
87  nullptr),
88  0);
89  outputs.clear();
90  outputs.reserve(out_size);
91  for (mx_uint i = 0; i < out_size; ++i) {
92  outputs.push_back(NDArray(out_array[i]));
93  }
94  int cur_train_mode = prev_train_mode;
95  int cur_is_record = prev_is_record;
96  if (is_train == true) {
97  CHECK_EQ(MXAutogradSetIsTraining(cur_train_mode, &prev_train_mode), 0);
98  }
99  CHECK_EQ(MXAutogradSetIsRecording(cur_is_record, &prev_is_record), 0);
100  }
111  void Backward(const std::vector<NDArray>& head_grads = std::vector<NDArray>()) {
112  if (require_grad == true) {
113  if (outputs.size() == 0) {
114  Forward(false);
115  }
116  std::vector<NDArrayHandle> out_handles;
117  for (const auto& array : outputs) {
118  out_handles.push_back(array.GetHandle());
119  }
120  std::vector<NDArrayHandle> head_grads_;
121  for (auto d : head_grads) {
122  head_grads_.push_back(d.GetHandle());
123  }
124  if (head_grads_.size() > 0) {
125  CHECK_EQ(MXAutogradBackwardEx(out_handles.size(),
126  out_handles.data(),
127  head_grads_.data(),
128  0,
129  nullptr,
130  0,
131  0,
132  1,
133  nullptr,
134  nullptr),
135  0);
136  } else {
137  CHECK_EQ(MXAutogradBackwardEx(out_handles.size(),
138  out_handles.data(),
139  nullptr,
140  0,
141  nullptr,
142  0,
143  0,
144  1,
145  nullptr,
146  nullptr),
147  0);
148  }
149  grad_arrays.clear();
150  grad_arrays.reserve(arg_arrays.size());
151  for (const auto& array : arg_arrays) {
152  NDArrayHandle grad;
153  CHECK_EQ(MXNDArrayGetGrad(array.GetHandle(), &grad), 0);
154  grad_arrays.push_back(NDArray(grad));
155  }
156  }
157  }
158  // TODO(zhangchen-qinyinghua)
159  // To implement reshape function
 160  void Reshape();
 164  ~Executor() {
165  MXFreeCachedOp(handle_);
166  }
167  std::vector<NDArray> arg_arrays;
168  std::vector<NDArray> grad_arrays;
169  std::vector<NDArray> aux_arrays;
 170  std::vector<NDArray> combined_arrays;
 171  int device_type;
 172  int device_id;
 173  bool require_grad;
177  std::vector<NDArray> outputs;
178  std::map<std::string, NDArray> arg_dict() {
179  return GetDict(symbol_.ListArguments(), arg_arrays);
180  }
181  std::map<std::string, NDArray> grad_dict() {
182  return GetDict(symbol_.ListArguments(), grad_arrays);
183  }
184  std::map<std::string, NDArray> aux_dict() {
185  return GetDict(symbol_.ListAuxiliaryStates(), aux_arrays);
186  }
187 
188  private:
189  Executor(const Executor& e);
190  Executor& operator=(const Executor& e);
191  CachedOpHandle handle_;
192  Symbol symbol_;
193  std::map<std::string, NDArray> GetDict(const std::vector<std::string>& names,
194  const std::vector<NDArray>& arrays) {
195  std::map<std::string, NDArray> ret;
196  std::set<std::string> name_set;
197  for (const auto& s : names) {
198  CHECK(name_set.find(s) == name_set.end()) << "Duplicate names detected, " << s;
199  name_set.insert(s);
200  }
201  CHECK_EQ(name_set.size(), arrays.size()) << "names size not equal to arrays size";
202  for (size_t i = 0; i < names.size(); ++i) {
203  ret[names[i]] = arrays[i];
204  }
205  return ret;
206  }
207 };
208 } // namespace cpp
209 } // namespace mxnet
210 #endif // MXNET_CPP_EXECUTOR_H_
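A minimal usage sketch of the class above, assuming the caller has already built the network Symbol and allocated argument, gradient and auxiliary NDArrays that match net.ListArguments() / net.ListAuxiliaryStates(). The umbrella header mxnet-cpp/MxNetCpp.h is used for brevity; TrainOneStep and all variable names are illustrative, not part of the library.

#include <map>
#include <string>
#include <vector>
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

// Run one forward/backward pass through `net` with caller-supplied arrays.
void TrainOneStep(const Symbol& net,
                  const Context& ctx,
                  const std::vector<NDArray>& args,
                  const std::vector<NDArray>& grads,
                  const std::vector<OpReqType>& grad_reqs,
                  const std::vector<NDArray>& aux) {
  // Bind the symbol to concrete arrays; the Executor owns the cached-op
  // handle and frees it in its destructor.
  Executor exec(net, ctx, args, grads, grad_reqs, aux);

  exec.Forward(true);   // record the graph in training mode
  exec.Backward();      // gradients are fetched into exec.grad_arrays

  // Forward outputs and per-argument gradients are now available.
  const std::vector<NDArray>& outs = exec.outputs;
  std::map<std::string, NDArray> grads_by_name = exec.grad_dict();
  (void)outs;
  (void)grads_by_name;
}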
mxnet
namespace of mxnet
Definition: api_registry.h:33
MXAutogradSetIsTraining
MXNET_DLL int MXAutogradSetIsTraining(int is_training, int *prev)
set whether the autograd pass runs in training mode
CachedOpHandle
void * CachedOpHandle
handle to cached operator
Definition: c_api.h:80
mxnet::cpp::Executor::grad_dict
std::map< std::string, NDArray > grad_dict()
Definition: executor.h:181
mxnet::cpp::Executor::Backward
void Backward(const std::vector< NDArray > &head_grads=std::vector< NDArray >())
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
Definition: executor.h:111
mxnet::cpp::Executor::device_type
int device_type
Definition: executor.h:171
mxnet::cpp::Executor::outputs
std::vector< NDArray > outputs
arrays store the outputs of forward
Definition: executor.h:177
MXAutogradBackwardEx
MXNET_DLL int MXAutogradBackwardEx(uint32_t num_output, NDArrayHandle *output_handles, NDArrayHandle *ograd_handles, uint32_t num_variables, NDArrayHandle *var_handles, int retain_graph, int create_graph, int is_train, NDArrayHandle **grad_handles, int **grad_stypes)
compute the gradient of outputs w.r.t. variables
mxnet::cpp::Symbol::ListArguments
std::vector< std::string > ListArguments() const
List the argument names.
mxnet::cpp::Executor::~Executor
~Executor()
destructor, free the handle
Definition: executor.h:164
mxnet::cpp::Executor::aux_dict
std::map< std::string, NDArray > aux_dict()
Definition: executor.h:184
mxnet::cpp::Executor::Forward
void Forward(bool is_train)
Perform a Forward operation of the Operator. After this operation, user can get the result by using functi...
Definition: executor.h:62
mxnet::cpp::NDArray
NDArray interface.
Definition: ndarray.h:122
mxnet::cpp::Context
Context interface.
Definition: ndarray.h:45
mxnet::cpp::Symbol::ListAuxiliaryStates
std::vector< std::string > ListAuxiliaryStates() const
mxnet::cpp::Executor::grad_arrays
std::vector< NDArray > grad_arrays
Definition: executor.h:168
mxnet::cpp::Executor::Executor
Executor(const CachedOpHandle &h)
Definition: executor.h:55
mxnet::cpp::Executor::Reshape
void Reshape()
symbol.h
definition of symbol
MXFreeCachedOp
MXNET_DLL int MXFreeCachedOp(CachedOpHandle handle)
free cached operator
mxnet::cpp::Executor::combined_arrays
std::vector< NDArray > combined_arrays
Definition: executor.h:170
MXNDArrayGetGrad
MXNET_DLL int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out)
return gradient buffer attached to this NDArray
mxnet::NDArrayHandle
Definition: ndarray_handle.h:40
mxnet::cpp::Executor
Executor interface.
Definition: executor.h:45
mxnet::cpp::Executor::arg_dict
std::map< std::string, NDArray > arg_dict()
Definition: executor.h:178
MXInvokeCachedOp
MXNET_DLL int MXInvokeCachedOp(CachedOpHandle handle, int num_inputs, NDArrayHandle *inputs, int default_dev_type, int default_dev_id, int *num_outputs, NDArrayHandle **outputs, const int **out_stypes)
invoke a cached op
base.h
base definitions for mxnet-cpp
mxnet::cpp::Executor::arg_arrays
std::vector< NDArray > arg_arrays
Definition: executor.h:167
mxnet::cpp::Executor::device_id
int device_id
Definition: executor.h:172
MXAutogradSetIsRecording
MXNET_DLL int MXAutogradSetIsRecording(int is_recording, int *prev)
set whether to record operator for autograd
mx_uint
uint32_t mx_uint
manually define unsigned int
Definition: c_api.h:65
mxnet::cpp::Executor::require_grad
bool require_grad
Definition: executor.h:173
mxnet::cpp::Executor::aux_arrays
std::vector< NDArray > aux_arrays
Definition: executor.h:169
mxnet::cpp::Executor::Executor
Executor(const Symbol &symbol, Context context, const std::vector< NDArray > &arg_arrays, const std::vector< NDArray > &grad_arrays, const std::vector< OpReqType > &grad_reqs, const std::vector< NDArray > &aux_arrays, const std::map< std::string, Context > &group_to_ctx=std::map< std::string, Context >(), Executor *shared_exec=nullptr)
mxnet::cpp::Symbol
Symbol interface.
Definition: symbol.h:73
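The Forward() body above enables autograd recording (and, when is_train is true, training mode), invokes the cached op, and then restores the previous modes. The same save/restore sequence can be written as a small RAII guard; the sketch below is a hypothetical helper (AutogradModeGuard is not part of mxnet-cpp) built only on the C API calls listed above.

#include <dmlc/logging.h>
#include <mxnet/c_api.h>

// Hypothetical helper mirroring Executor::Forward: turn autograd recording
// (and optionally training mode) on for the lifetime of the guard, then
// restore whatever modes were active before.
class AutogradModeGuard {
 public:
  explicit AutogradModeGuard(bool is_train) : is_train_(is_train) {
    CHECK_EQ(MXAutogradSetIsRecording(1, &prev_is_record_), 0);
    if (is_train_) {
      CHECK_EQ(MXAutogradSetIsTraining(1, &prev_train_mode_), 0);
    }
  }
  ~AutogradModeGuard() {
    int ignored = 0;
    if (is_train_) {
      CHECK_EQ(MXAutogradSetIsTraining(prev_train_mode_, &ignored), 0);
    }
    CHECK_EQ(MXAutogradSetIsRecording(prev_is_record_, &ignored), 0);
  }

 private:
  bool is_train_;
  int prev_is_record_ = 0;
  int prev_train_mode_ = 0;
};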