mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
executor.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_EXECUTOR_H_
27 #define MXNET_EXECUTOR_H_
28 
#include <dmlc/base.h>
#include <functional>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "./base.h"
#include "./c_api.h"
#include "./ndarray.h"
#include "./operator.h"
39 
40 // check c++11
41 #if DMLC_USE_CXX11 == 0
42 #error "CXX11 was required for symbolic module"
43 #endif
44 
45 namespace mxnet {
47 using nnvm::Symbol;
48 
53 class Executor {
54  public:
56  virtual ~Executor() {}
61  virtual void Forward(bool is_train) = 0;
70  virtual void PartialForward(bool is_train, int step, int *step_left) = 0;
80  virtual void Backward(const std::vector<NDArray> &head_grads, bool is_train = true) = 0;
85  virtual void Print(std::ostream &os) const {} // NOLINT(*)
90  virtual const std::vector<NDArray> &outputs() const = 0;
95  virtual const std::unordered_map<std::string, NDArray>& in_arg_map() const = 0;
100  virtual const std::unordered_map<std::string, NDArray>& arg_grad_map() const = 0;
105  virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
120  static Executor *Bind(nnvm::Symbol symbol,
121  const Context& default_ctx,
122  const std::map<std::string, Context>& group2ctx,
123  const std::vector<NDArray> &in_args,
124  const std::vector<NDArray> &arg_grad_store,
125  const std::vector<OpReqType> &grad_req_type,
126  const std::vector<NDArray> &aux_states,
127  Executor* shared_exec = NULL);
128 
129  static Executor* SimpleBind(nnvm::Symbol symbol,
130  const Context& default_ctx,
131  const std::map<std::string, Context>& group2ctx,
132  const std::vector<Context>& in_arg_ctxes,
133  const std::vector<Context>& arg_grad_ctxes,
134  const std::vector<Context>& aux_state_ctxes,
135  const std::unordered_map<std::string, TShape>& arg_shape_map,
136  const std::unordered_map<std::string, int>& arg_dtype_map,
137  const std::unordered_map<std::string, int>& arg_stype_map,
138  const std::vector<OpReqType>& grad_req_types,
139  const std::unordered_set<std::string>& param_names,
140  std::vector<NDArray>* in_args,
141  std::vector<NDArray>* arg_grads,
142  std::vector<NDArray>* aux_states,
143  std::unordered_map<std::string, NDArray>*
144  shared_data_arrays = nullptr,
145  Executor* shared_exec = nullptr);
149  typedef std::function<void(const char*, void*)> MonitorCallback;
153  virtual void SetMonitorCallback(const MonitorCallback& callback) {}
154 }; // class executor
155 } // namespace mxnet
156 #endif // MXNET_EXECUTOR_H_
Executor of a computation graph. An Executor can be created by binding a symbol.
Definition: executor.h:53
C API of mxnet.
virtual ~Executor()
destructor
Definition: executor.h:56
static Executor * SimpleBind(nnvm::Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< Context > &in_arg_ctxes, const std::vector< Context > &arg_grad_ctxes, const std::vector< Context > &aux_state_ctxes, const std::unordered_map< std::string, TShape > &arg_shape_map, const std::unordered_map< std::string, int > &arg_dtype_map, const std::unordered_map< std::string, int > &arg_stype_map, const std::vector< OpReqType > &grad_req_types, const std::unordered_set< std::string > &param_names, std::vector< NDArray > *in_args, std::vector< NDArray > *arg_grads, std::vector< NDArray > *aux_states, std::unordered_map< std::string, NDArray > *shared_data_arrays=nullptr, Executor *shared_exec=nullptr)
std::function< void(const char *, void *)> MonitorCallback
the prototype of the user-defined monitor callback
Definition: executor.h:149
virtual void Backward(const std::vector< NDArray > &head_grads, bool is_train=true)=0
Perform a Backward operation of the Operator. This must be called after Forward. After this operation...
static Executor * Bind(nnvm::Symbol symbol, const Context &default_ctx, const std::map< std::string, Context > &group2ctx, const std::vector< NDArray > &in_args, const std::vector< NDArray > &arg_grad_store, const std::vector< OpReqType > &grad_req_type, const std::vector< NDArray > &aux_states, Executor *shared_exec=NULL)
Create an executor by binding a symbol with context and arguments. If the user does not want to compute the grad...
virtual void Print(std::ostream &os) const
print the execution plan info to output stream.
Definition: executor.h:85
virtual const std::unordered_map< std::string, NDArray > & in_arg_map() const =0
get input argument map, key is arg name, value is arg's NDArray.
virtual const std::vector< NDArray > & outputs() const =0
get array of outputs in the executor.
virtual void SetMonitorCallback(const MonitorCallback &callback)
Install a callback to notify the completion of an operation.
Definition: executor.h:153
NDArray interface that handles array arithmetic.
virtual void Forward(bool is_train)=0
Perform a Forward operation of the Operator. After this operation, the user can get the result by using functi...
Operator interface of mxnet.
virtual void PartialForward(bool is_train, int step, int *step_left)=0
Perform a partial Forward operation of the Operator. Only issues the operation specified by step...
configuration of mxnet as well as basic data structures.
virtual const std::unordered_map< std::string, NDArray > & arg_grad_map() const =0
get input argument gradient map, key is arg name, value is gradient's NDArray.
virtual const std::unordered_map< std::string, NDArray > & aux_state_map() const =0
get aux state map, key is arg name, value is aux state's NDArray.
Context information about the execution environment.
Definition: base.h:142