Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
29 #include <mshadow/tensor.h>
30 #include <nnvm/op_attr_types.h>
32 #include <vector>
33 #include <functional>
34 #include <string>
36 #include "./base.h"
37 #include "./ndarray.h"
38 #include "./engine.h"
39 #include "./resource.h"
41 namespace mxnet {
43 using nnvm::NodeAttrs;
46 enum OpReqType {
59 };
67 struct OpContext {
69  bool need_grad;
71  bool is_train;
77  std::vector<Resource> requested;
83  template<typename xpu>
84  inline mshadow::Stream<xpu>* get_stream() const {
85  return run_ctx.get_stream<xpu>();
86  }
93  return run_ctx.get_gpu_aux_stream();
94  }
95 #endif
96 };
99 enum class ExecType {
101  kSync,
106  kAsync,
120 };
123 enum class DispatchMode {
124  kUndefined = -1,
125  // dispatch on FCompute or FStatefulCompute
126  kFCompute,
127  // dispatch on FComputeEx or FStatefulComputeEx, if available
128  kFComputeEx,
129  // dispatch on FCompute or FStatefulCompute, and performs storage fallback
131  // special dispatch mode for variables
132  kVariable,
133 };
/*! \brief the quantization type of the operator */
136 enum class QuantizeType {
137  // This operator does not support quantization
138  kNone = 0,
139  // This operator benefits greatly from quantization and therefore must be quantized
140  kMust,
141  // This operator supports quantization; whether it is quantized is decided depending on the connection
142  kSupport,
143 };
149 class OpStatePtr {
150  public:
151  /* \brief Create a OpStatePtr with state of type T.
152  * \param args Arguments passed to T's constructor.
153  */
154  template<typename T, typename... Args>
155  static OpStatePtr Create(Args&&... args) {
156  OpStatePtr ret;
157  auto state = new T(std::forward<Args>(args)...);
158  auto var = Engine::Get()->NewVariable();
159  ret.ptr_.reset(
160  new OpState(var, state),
161  [](OpState* p) {
162  Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
163  delete reinterpret_cast<T*>(p->state);
164  delete p;
165  });
167  return ret;
168  }
169  /* \brief Get engine variable associated with this state */
171  return ptr_->var;
172  }
173  /* \brief Get state of type T */
174  template<typename T>
175  T& get_state() const {
176  return *reinterpret_cast<T*>(ptr_->state);
177  }
178  /* \brief clear state */
179  void reset() {
180  ptr_.reset();
181  }
182  /* \brief checks whether the managed object is managed only by the current
183  OpStatePtr instance */
184  bool unique() const {
185  return ptr_.unique();
186  }
187  /* \brief Whether state is empty */
188  explicit operator bool() const {
189  return ptr_ ? true : false;
190  }
192  private:
193  /* \brief state structure */
194  struct OpState {
195  engine::VarHandle var;
196  void* state;
198  OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
199  OpState(const OpState& other) = delete;
200  OpState& operator=(const OpState& other) = delete;
201  };
202  /* \brief shared pointer to state */
203  std::shared_ptr<OpState> ptr_;
204 };
/*!
 * \brief Callable that creates the state (OpStatePtr) of a layer-style,
 *  forward/backward operator.
 *  It receives the node attributes, the context, the input shapes and the
 *  input types, and returns the newly created OpStatePtr.
 */
218 using FCreateOpState = std::function<OpStatePtr (const NodeAttrs& attrs,
219  Context ctx,
220  const mxnet::ShapeVector& in_shape,
221  const std::vector<int>& in_type)>;
/*! \brief Callable that returns the execution mode (ExecType) of an operator given its node attributes. */
236 using FExecType = std::function<ExecType (const NodeAttrs& attrs)>;
/*!
 * \brief Compute function for a stateful operator using the TBlob interface.
 *  OpStatePtr is a pointer type, so the state it refers to is mutable even
 *  though the OpStatePtr itself is passed as const.
 *  `req` carries one OpReqType per entry (presumably one per output - confirm).
 */
244 using FStatefulCompute = std::function<void (const OpStatePtr& state,
245  const OpContext& ctx,
246  const std::vector<TBlob>& inputs,
247  const std::vector<OpReqType>& req,
248  const std::vector<TBlob>& outputs)>;
/*!
 * \brief Compute function for a stateful operator using the NDArray interface.
 *  OpStatePtr is a pointer type, so the state it refers to is mutable even
 *  though the OpStatePtr itself is passed as const.
 */
256 using FStatefulComputeEx = std::function<void (const OpStatePtr& state,
257  const OpContext& ctx,
258  const std::vector<NDArray>& inputs,
259  const std::vector<OpReqType>& req,
260  const std::vector<NDArray>& outputs)>;
/*!
 * \brief The resource request from the operator, computed from the node attributes.
 *  An operator could register ResourceRequestEx, or ResourceRequest, or neither.
 */
267 using FResourceRequest = std::function<
268  std::vector<ResourceRequest> (const NodeAttrs& n)>;
/*!
 * \brief The resource request from the operator, computed from the node
 *  attributes, the device mask and the dispatch mode.
 *  An operator could register ResourceRequestEx, or ResourceRequest, or neither.
 *  If an operator registers both ResourceRequestEx and ResourceRequest,
 *  ResourceRequest is ignored.
 */
277 using FResourceRequestEx = std::function<
278  std::vector<ResourceRequest> (const NodeAttrs& n,
279  const int dev_mask,
280  const DispatchMode dispatch_mode)>;
/*!
 * \brief An operator called as a NDArray function.
 *  Inputs are passed by const reference; outputs are passed as a pointer to
 *  a vector of NDArray.
 */
286 using FNDArrayFunction = std::function<void (const nnvm::NodeAttrs& attrs,
287  const std::vector<NDArray>& inputs,
288  std::vector<NDArray>* outputs)>;
/*!
 * \brief Compute function for a simple stateless, forward-only operator
 *  using the TBlob interface.
 */
294 using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
295  const OpContext& ctx,
296  const std::vector<TBlob>& inputs,
297  const std::vector<OpReqType>& req,
298  const std::vector<TBlob>& outputs)>;
/*!
 * \brief NDArray compute function for a simple stateless, forward-only operator.
 */
304 using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
305  const OpContext& ctx,
306  const std::vector<NDArray>& inputs,
307  const std::vector<OpReqType>& req,
308  const std::vector<NDArray>& outputs)>;
/*!
 * \brief Storage type and dispatch mode inference function based on the
 *  storage types of the inputs and outputs and on the device mask.
 *  `dispatch_mode`, `in_attrs` and `out_attrs` are passed by pointer.
 *  NOTE(review): the bool result presumably signals whether inference
 *  succeeded - confirm against callers.
 */
316 using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
317  const int dev_mask,
318  DispatchMode* dispatch_mode,
319  std::vector<int>* in_attrs,
320  std::vector<int>* out_attrs)>;
/*! \brief Callable that returns the QuantizeType of an operator based on the attrs of the node. */
326 using FQuantizable = std::function<QuantizeType (const NodeAttrs& attrs)>;
/*! \brief Quantized node creation function based on the attrs of the node; returns the created node as nnvm::ObjectPtr. */
332 using FQuantizedOp = std::function<nnvm::ObjectPtr (const NodeAttrs& attrs)>;
/*! \brief Function to determine if the output of a quantized operator needs to be requantized. */
340 using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
/*!
 * \brief Function to determine if the input of a quantized operator, selected
 *  by `index`, needs to be quantized for the given quantize granularity.
 *  NOTE(review): going by the name, a true result presumably means the input
 *  should NOT be quantized - confirm against callers.
 */
347 using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
348  const size_t index,
349  const std::string quantize_granularity)>;
/*!
 * \brief Function to determine if the input of a quantized operator needs to be calibrated.
 *  NOTE(review): the returned std::vector<int> presumably lists the indices of
 *  the inputs to calibrate - confirm.
 */
356 using FNeedCalibrateInput = std::function<std::vector<int> (const NodeAttrs& attrs)>;
/*!
 * \brief Function to determine if the output of a quantized operator needs to be calibrated.
 *  NOTE(review): the returned std::vector<int> presumably lists the indices of
 *  the outputs to calibrate - confirm.
 */
363 using FNeedCalibrateOutput = std::function<std::vector<int> (const NodeAttrs& attrs)>;
365 } // namespace mxnet
367 #endif // MXNET_OP_ATTR_TYPES_H_
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FStatefulCompute
Register a compute function for stateful operator. OpStatePtr is a pointer type, its content is mutab...
Definition: op_attr_types.h:248
void reset()
Definition: op_attr_types.h:179
Forward/Backward are synchronous calls.
Engine that schedules all the operations according to dependency.
bool THasDeterministicOutput
Whether the operator always produces the same output given the same input. This enables certain optim...
Definition: op_attr_types.h:231
no operation, do not write anything
Definition: op_attr_types.h:48
The attributes of the current operation node. Usually are additional parameters like axis...
Definition: node.h:119
write gradient to provided space
Definition: op_attr_types.h:50
namespace of mxnet
Definition: api_registry.h:33
std::function< std::vector< ResourceRequest >(const NodeAttrs &n)> FResourceRequest
The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither.
Definition: op_attr_types.h:268
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FStatefulComputeEx
Register a compute function for stateful operator using NDArray interface. OpStatePtr is a pointer typ...
Definition: op_attr_types.h:260
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: base.h:371
std::function< bool(const NodeAttrs &attrs)> FNeedRequantize
Register a function to determine if the output of a quantized operator needs to be requantized...
Definition: op_attr_types.h:340
Asynchronous function call.
bool is_train
whether it is training phase
Definition: op_attr_types.h:71
engine::VarHandle get_var() const
Definition: op_attr_types.h:170
Cross device copy operation, this is a special operator that indicates it will copy across devices...
execution time context. The information needed in runtime for actual execution.
Definition: base.h:350
the dispatch mode of the operator
Definition: op_attr_types.h:123
base class of engine variables.
Definition: engine.h:44
Provides automatic coordination of an auxiliary stream with a primary one. This object, upon construction, prepares an aux stream for use by syncing it with enqueued primary-stream work. Object destruction will sync again so future primary-stream work will wait on enqueued aux-stream work. If MXNET_GPU_WORKER_NSTREAMS == 1, then this defaults simply: the primary stream will equal the aux stream and the syncs will be executed as nops. See ./src/operator/cudnn/cudnn_convolution-inl.h for a usage example.
Definition: base.h:315
T & get_state() const
Definition: op_attr_types.h:175
engine::CallbackOnComplete async_on_complete
the callback when operation completes, used by asynchronous ops
Definition: op_attr_types.h:75
header file of tensor data structure and functions This lib requires explicit memory allocation and d...
All the possible information needed by Operator.Forward and Backward This is the superset of RunConte...
Definition: op_attr_types.h:67
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:820
the quantization type of the operator
Definition: op_attr_types.h:136
std::function< ExecType(const NodeAttrs &attrs)> FExecType
Execution mode of this operator.
Definition: op_attr_types.h:236
virtual VarHandle NewVariable()=0
Allocate a new variable, the variable can then be used to schedule the operation concurrently via dep...
static OpStatePtr Create(Args &&...args)
Definition: op_attr_types.h:155
bool need_grad
whether there is a backward phase to compute gradients.
Definition: op_attr_types.h:69
SyncedGPUAuxStream get_gpu_aux_stream() const
get auxiliary gpu stream auto-syncing object from Context
Definition: op_attr_types.h:92
std::function< void(const nnvm::NodeAttrs &attrs, const std::vector< NDArray > &inputs, std::vector< NDArray > *outputs)> FNDArrayFunction
Register an operator called as a NDArray function.
Definition: op_attr_types.h:288
std::function< nnvm::ObjectPtr(const NodeAttrs &attrs)> FQuantizedOp
Register a quantized node creation function based on the attrs of the node.
Definition: op_attr_types.h:332
Global resource allocation handling.
virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var)=0
Schedule the deletion of a variable.
std::function< OpStatePtr(const NodeAttrs &attrs, Context ctx, const mxnet::ShapeVector &in_shape, const std::vector< int > &in_type)> FCreateOpState
Create a Layer style, forward/backward operator. This is easy to write code that contains state...
Definition: op_attr_types.h:221
perform an inplace write. This option only happens when Target shares memory with one of input argumen...
Definition: op_attr_types.h:56
A subgraph execution should happen in the main thread, instead of in the execution engine...
bool unique() const
Definition: op_attr_types.h:184
operation request type to Forward and Backward
Definition: op_attr_types.h:46
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FComputeEx
Register an NDArray compute function for simple stateless forward only operator.
Definition: op_attr_types.h:308
std::vector< Resource > requested
Resources requested by the operator.
Definition: op_attr_types.h:77
RunContext run_ctx
RunContext related resources.
Definition: op_attr_types.h:73
static Context CPU(int32_t dev_id=0)
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FCompute
Register a compute function for simple stateless forward only operator.
Definition: op_attr_types.h:298
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:73
SyncedGPUAuxStream get_gpu_aux_stream() const
get an RAII object that transparently handles the syncing of the auxiliary stream.
Definition: base.h:379
static Engine * Get()
std::function< bool(const NodeAttrs &attrs, const int dev_mask, DispatchMode *dispatch_mode, std::vector< int > *in_attrs, std::vector< int > *out_attrs)> FInferStorageType
Register a storage and dispatch mode inference function based on storage types of the inputs and outp...
Definition: op_attr_types.h:320
std::function< std::vector< ResourceRequest >(const NodeAttrs &n, const int dev_mask, const DispatchMode dispatch_mode)> FResourceRequestEx
The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither. If an operator registers both ResourceRequestEx and ResourceRequest, ResourceRequest is ignored.
Definition: op_attr_types.h:280
std::function< std::vector< int >(const NodeAttrs &attrs)> FNeedCalibrateInput
Register a function to determine if the input of a quantized operator needs to be calibrated...
Definition: op_attr_types.h:356
add to the provided space
Definition: op_attr_types.h:58
std::function< bool(const NodeAttrs &attrs, const size_t index, const std::string quantize_granularity)> FAvoidQuantizeInput
Register a function to determine if the input of a quantized operator needs to be quantized...
Definition: op_attr_types.h:349
std::function< std::vector< int >(const NodeAttrs &attrs)> FNeedCalibrateOutput
Register a function to determine if the output of a quantized operator needs to be calibrated...
Definition: op_attr_types.h:363
the execution type of the operator
Definition: op_attr_types.h:99
Context information about the execution environment.
Definition: base.h:102
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: op_attr_types.h:84
std::function< QuantizeType(const NodeAttrs &attrs)> FQuantizable
Register a function to determine the quantization type (QuantizeType) of an operator based on the attrs of the node.
Definition: op_attr_types.h:326
Data structures that can appear in operator attributes.
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:149
computation stream structure, used for asynchronous computations
Definition: tensor.h:384