mxnet
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
op_attr_types.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_OP_ATTR_TYPES_H_
27 #define MXNET_OP_ATTR_TYPES_H_
28 
29 #include <mshadow/tensor.h>
30 #include <nnvm/op_attr_types.h>
31 
32 #include <vector>
33 #include <functional>
34 
35 #include "./base.h"
36 #include "./ndarray.h"
37 #include "./engine.h"
38 #include "./resource.h"
39 
40 namespace mxnet {
41 
42 using nnvm::NodeAttrs;
43 
/*! \brief operation request type to Forward and Backward */
enum OpReqType {
  /*! \brief no operation, do not write anything */
  kNullOp,
  /*! \brief write gradient to provided space */
  kWriteTo,
  /*!
   * \brief perform an inplace write.
   *  Target shares memory with one of input arguments.
   */
  kWriteInplace,
  /*! \brief add to the provided space */
  kAddTo
};
59 
66 struct OpContext {
68  int is_train;
74  std::vector<Resource> requested;
80  template<typename xpu>
81  inline mshadow::Stream<xpu>* get_stream() const {
82  return run_ctx.get_stream<xpu>();
83  }
84 };
85 
/*! \brief the execution type of the operator */
enum class ExecType {
  /*! \brief Forward/Backward are synchronous calls */
  kSync,
  /*!
   * \brief Asynchronous function call.
   *  Completion is signalled through OpContext::async_on_complete.
   */
  kAsync,
  /*! \brief Run this operator on the scheduling thread without pushing to engine. */
  kLocal,
  /*!
   * \brief Cross device copy operation, this is a special operator
   *  that indicates copy across devices.
   */
  kCrossDeviceCopy
};
105 
/*! \brief the dispatch mode of the operator */
enum class DispatchMode {
  kUndefined = -1,
  // dispatch on FCompute or FStatefulCompute
  kFCompute,
  // dispatch on FComputeEx or FStatefulComputeEx, if available
  kFComputeEx,
  // dispatch on FCompute or FStatefulCompute, and performs storage fallback
  kFComputeFallback,
  // special dispatch mode for variables
  kVariable,
};
118 
123 class OpStatePtr {
124  public:
125  /* \brief Create a OpStatePtr with state of type T.
126  * \param args Arguments passed to T's constructor.
127  */
128  template<typename T, typename... Args>
129  static OpStatePtr Create(Args&&... args) {
130  OpStatePtr ret;
131  ret.ptr_ = std::make_shared<OpState>();
132  ret.ptr_->var_ = Engine::Get()->NewVariable();
133  ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
134 
135  return ret;
136  }
137  /* \brief Get engine variable associated with this state */
139  return ptr_->var_;
140  }
141  /* \brief Get state of type T */
142  template<typename T>
143  T& get_state() const {
144  return dmlc::get<T>(ptr_->state_);
145  }
146  /* \brief clear state */
147  void reset() {
148  ptr_.reset();
149  }
150  /* \brief Whether state is empty */
151  explicit operator bool() const {
152  return ptr_ ? true : false;
153  }
154 
155  private:
156  /* \brief state structure */
157  struct OpState {
158  OpState() {}
159  OpState(const OpState& other) = delete;
160  OpState& operator=(const OpState& other) = delete;
161 
162  ~OpState() {
163  Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
164  }
165 
166  engine::VarHandle var_;
167  dmlc::any state_;
168  };
169  /* \brief shared pointer to state */
170  std::shared_ptr<OpState> ptr_;
171 };
172 
185 using FCreateOpState = std::function<OpStatePtr (const NodeAttrs& attrs,
186  Context ctx,
187  const std::vector<TShape>& in_shape,
188  const std::vector<int>& in_type)>;
192 using FExecType = std::function<ExecType (const NodeAttrs& attrs)>;
200 using FStatefulCompute = std::function<void (const OpStatePtr& state,
201  const OpContext& ctx,
202  const std::vector<TBlob>& inputs,
203  const std::vector<OpReqType>& req,
204  const std::vector<TBlob>& outputs)>;
212 using FStatefulComputeEx = std::function<void (const OpStatePtr& state,
213  const OpContext& ctx,
214  const std::vector<NDArray>& inputs,
215  const std::vector<OpReqType>& req,
216  const std::vector<NDArray>& outputs)>;
222 using FResourceRequest = std::function<
223  std::vector<ResourceRequest> (const NodeAttrs& n)>;
229 using FNDArrayFunction = std::function<void (const nnvm::NodeAttrs& attrs,
230  const std::vector<NDArray>& inputs,
231  std::vector<NDArray>* outputs)>;
237 using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
238  const OpContext& ctx,
239  const std::vector<TBlob>& inputs,
240  const std::vector<OpReqType>& req,
241  const std::vector<TBlob>& outputs)>;
248 using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
249  const OpContext& ctx,
250  const std::vector<NDArray>& inputs,
251  const std::vector<OpReqType>& req,
252  const std::vector<NDArray>& outputs)>;
253 
260 using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
261  const int dev_mask,
262  DispatchMode* dispatch_mode,
263  std::vector<int>* in_attrs,
264  std::vector<int>* out_attrs)>;
265 
266 } // namespace mxnet
267 
268 #endif // MXNET_OP_ATTR_TYPES_H_
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FStatefulCompute
Register a compute function for a stateful operator. OpStatePtr is a pointer type; its content is mutable even if OpStatePtr is const.
Definition: op_attr_types.h:204
void reset()
Definition: op_attr_types.h:147
Forward/Backward are synchronous calls.
Engine that schedules all the operations according to dependency.
no operation, do not write anything
Definition: op_attr_types.h:47
write gradient to provided space
Definition: op_attr_types.h:49
std::function< std::vector< ResourceRequest >(const NodeAttrs &n)> FResourceRequest
The resource request from the operator.
Definition: op_attr_types.h:223
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FStatefulComputeEx
Register a compute function for a stateful operator using the NDArray interface. OpStatePtr is a pointer type; its content is mutable even if OpStatePtr is const.
Definition: op_attr_types.h:216
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: base.h:266
Asynchronous function call.
engine::VarHandle get_var() const
Definition: op_attr_types.h:138
Cross-device copy operation; this is a special operator that indicates a copy across devices...
execution time context. The information needed in runtime for actual execution.
Definition: base.h:253
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:107
Run this operator on the scheduling thread without pushing to engine.
T & get_state() const
Definition: op_attr_types.h:143
engine::CallbackOnComplete async_on_complete
the callback when the operation completes, used by asynchronous ops
Definition: op_attr_types.h:72
std::function< OpStatePtr(const NodeAttrs &attrs, Context ctx, const std::vector< TShape > &in_shape, const std::vector< int > &in_type)> FCreateOpState
Create a layer-style forward/backward operator. This makes it easy to write code that contains state...
Definition: op_attr_types.h:188
All the possible information needed by Operator.Forward and Backward. This is the superset of RunContext.
Definition: op_attr_types.h:66
NDArray interface that handles array arithmetic.
int is_train
whether it is training phase
Definition: op_attr_types.h:68
std::function< ExecType(const NodeAttrs &attrs)> FExecType
Execution mode of this operator.
Definition: op_attr_types.h:192
virtual VarHandle NewVariable()=0
Allocate a new variable; the variable can then be used to schedule the operation concurrently via dep...
static OpStatePtr Create(Args &&...args)
Definition: op_attr_types.h:129
std::function< void(const nnvm::NodeAttrs &attrs, const std::vector< NDArray > &inputs, std::vector< NDArray > *outputs)> FNDArrayFunction
Register an operator called as a NDArray function.
Definition: op_attr_types.h:231
Global resource allocation handling.
virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var)=0
Schedule the deletion of a variable.
Var * VarHandle
Variable pointer type, usually held by the user to specify dependencies.
Definition: engine.h:47
perform an in-place write: the target shares memory with one of the input arguments. This option only happens when...
Definition: op_attr_types.h:55
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FComputeEx
Register an NDArray compute function for a simple stateless forward-only operator.
Definition: op_attr_types.h:252
std::vector< Resource > requested
Resources requested by the operator.
Definition: op_attr_types.h:74
RunContext run_ctx
RunContext related resources.
Definition: op_attr_types.h:70
static Context CPU(int32_t dev_id=0)
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FCompute
Register a compute function for a simple stateless forward-only operator.
Definition: op_attr_types.h:241
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:56
configuration of mxnet as well as basic data structures.
static Engine * Get()
std::function< bool(const NodeAttrs &attrs, const int dev_mask, DispatchMode *dispatch_mode, std::vector< int > *in_attrs, std::vector< int > *out_attrs)> FInferStorageType
Register a storage and dispatch mode inference function based on the storage types of the inputs and outputs...
Definition: op_attr_types.h:264
add to the provided space
Definition: op_attr_types.h:57
ExecType
the execution type of the operator
Definition: op_attr_types.h:87
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: op_attr_types.h:81
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:123