mxnet
op_attr_types.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
25 #ifndef MXNET_OP_ATTR_TYPES_H_
26 #define MXNET_OP_ATTR_TYPES_H_
27 
28 #include <mshadow/tensor.h>
29 #include <nnvm/op_attr_types.h>
30 
31 #include <vector>
32 #include <functional>
33 #include <string>
34 
35 #include "./base.h"
36 #include "./ndarray.h"
37 #include "./engine.h"
38 #include "./resource.h"
39 
40 namespace mxnet {
41 
42 using nnvm::NodeAttrs;
43 
/*! \brief operation request type to Forward and Backward */
enum OpReqType {
  /*! \brief no operation, do not write anything */
  kNullOp,
  /*! \brief write gradient to provided space */
  kWriteTo,
  /*!
   * \brief perform an in-place write.
   *  This option only happens when the target shares memory with one of the input arguments.
   */
  kWriteInplace,
  /*! \brief add to the provided space */
  kAddTo
};
59 
66 struct OpContext {
68  bool need_grad;
70  bool is_train;
76  std::vector<Resource> requested;
82  template <typename xpu>
83  inline mshadow::Stream<xpu>* get_stream() const {
84  return run_ctx.get_stream<xpu>();
85  }
86 #if MXNET_USE_CUDA
87 
92  return run_ctx.get_gpu_aux_stream();
93  }
94 #endif
95 };
96 
/*! \brief the execution type of the operator */
enum class ExecType {
  /*! \brief Forward/Backward are synchronous calls. */
  kSync,
  /*!
   * \brief Forward/Backward are asynchronous,
   *  will call OpContext.async_on_complete when operation finishes.
   */
  kAsync,
  /*!
   * \brief Cross-device copy operation; this is a special operator that
   *  indicates it will copy across devices.
   */
  kCrossDeviceCopy,
  /*!
   * \brief A subgraph execution should happen in the main thread, instead of
   *  in the execution engine.
   */
  kSubgraphExec,
};
120 
/*! \brief the dispatch mode of the operator */
enum class DispatchMode {
  kUndefined = -1,
  // dispatch on FCompute or FStatefulCompute
  kFCompute,
  // dispatch on FComputeEx or FStatefulComputeEx, if available
  kFComputeEx,
  // dispatch on FCompute or FStatefulCompute, and performs storage fallback
  kFComputeFallback,
  // special dispatch mode for variables
  kVariable,
};
133 
/*! \brief the quantization type of the operator */
enum class QuantizeType {
  kNone = 0,    ///< the operator does not support quantization
  kMust = 1,    ///< quantizing the operator gives a large benefit, so it must be quantized
  kSupport = 2  ///< quantization is supported; whether it is applied depends on the connection
};
143 
148 class OpStatePtr {
149  public:
150  /* \brief Create a OpStatePtr with state of type T.
151  * \param args Arguments passed to T's constructor.
152  */
153  template <typename T, typename... Args>
154  static OpStatePtr Create(Args&&... args) {
155  OpStatePtr ret;
156  auto state = new T(std::forward<Args>(args)...);
157  auto var = Engine::Get()->NewVariable();
158  ret.ptr_.reset(new OpState(var, state), [](OpState* p) {
159  Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
160  delete reinterpret_cast<T*>(p->state);
161  delete p;
162  });
163 
164  return ret;
165  }
166  /* \brief Get engine variable associated with this state */
168  return ptr_->var;
169  }
170  /* \brief Get state of type T */
171  template <typename T>
172  T& get_state() const {
173  return *reinterpret_cast<T*>(ptr_->state);
174  }
175  /* \brief clear state */
176  void reset() {
177  ptr_.reset();
178  }
179  /* \brief checks whether the managed object is managed only by the current
180  OpStatePtr instance */
181  bool unique() const {
182  return ptr_.unique();
183  }
184  /* \brief Whether state is empty */
185  explicit operator bool() const {
186  return ptr_ ? true : false;
187  }
188 
189  private:
190  /* \brief state structure */
191  struct OpState {
192  engine::VarHandle var;
193  void* state;
194 
195  OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
196  OpState(const OpState& other) = delete;
197  OpState& operator=(const OpState& other) = delete;
198  };
199  /* \brief shared pointer to state */
200  std::shared_ptr<OpState> ptr_;
201 };
202 
215 using FCreateOpState = std::function<OpStatePtr(const NodeAttrs& attrs,
216  Context ctx,
217  const mxnet::ShapeVector& in_shape,
218  const std::vector<int>& in_type)>;
219 
229 
233 using FExecType = std::function<ExecType(const NodeAttrs& attrs)>;
241 using FStatefulCompute = std::function<void(const OpStatePtr& state,
242  const OpContext& ctx,
243  const std::vector<TBlob>& inputs,
244  const std::vector<OpReqType>& req,
245  const std::vector<TBlob>& outputs)>;
253 using FStatefulComputeEx = std::function<void(const OpStatePtr& state,
254  const OpContext& ctx,
255  const std::vector<NDArray>& inputs,
256  const std::vector<OpReqType>& req,
257  const std::vector<NDArray>& outputs)>;
264 using FResourceRequest = std::function<std::vector<ResourceRequest>(const NodeAttrs& n)>;
273 using FResourceRequestEx =
274  std::function<std::vector<ResourceRequest>(const NodeAttrs& n,
275  const int dev_mask,
276  const DispatchMode dispatch_mode)>;
282 using FNDArrayFunction = std::function<void(const nnvm::NodeAttrs& attrs,
283  const std::vector<NDArray>& inputs,
284  std::vector<NDArray>* outputs)>;
290 using FCompute = std::function<void(const nnvm::NodeAttrs& attrs,
291  const OpContext& ctx,
292  const std::vector<TBlob>& inputs,
293  const std::vector<OpReqType>& req,
294  const std::vector<TBlob>& outputs)>;
300 using FComputeEx = std::function<void(const nnvm::NodeAttrs& attrs,
301  const OpContext& ctx,
302  const std::vector<NDArray>& inputs,
303  const std::vector<OpReqType>& req,
304  const std::vector<NDArray>& outputs)>;
305 
312 using FInferStorageType = std::function<bool(const NodeAttrs& attrs,
313  const int dev_mask,
314  DispatchMode* dispatch_mode,
315  std::vector<int>* in_attrs,
316  std::vector<int>* out_attrs)>;
317 
322 using FQuantizable = std::function<QuantizeType(const NodeAttrs& attrs)>;
323 
328 using FQuantizedOp = std::function<nnvm::ObjectPtr(const NodeAttrs& attrs)>;
329 
336 using FNeedRequantize = std::function<bool(const NodeAttrs& attrs)>;
337 
343 using FAvoidQuantizeInput = std::function<
344  bool(const NodeAttrs& attrs, const size_t index, const std::string quantize_granularity)>;
345 
350 using FNeedAsymQuantizeInput = std::function<bool(const NodeAttrs& attrs, const size_t index)>;
351 
357 using FAvoidDequantizeOutput = std::function<bool(const NodeAttrs& attrs, const size_t index)>;
358 
364 using FNeedCalibrateInput = std::function<std::vector<int>(const NodeAttrs& attrs)>;
365 
371 using FNeedCalibrateOutput = std::function<std::vector<int>(const NodeAttrs& attrs)>;
372 
373 #if MXNET_USE_CUDA
374 
382 using FIsCUDAGraphsCompatible = std::function<bool(const NodeAttrs& attrs, const bool is_train)>;
383 
384 #endif
385 
386 } // namespace mxnet
387 
388 #endif // MXNET_OP_ATTR_TYPES_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::kWriteInplace
@ kWriteInplace
Perform an in-place write. This option only happens when the target shares memory with one of the input arguments.
Definition: op_attr_types.h:55
mxnet::THasDeterministicOutput
bool THasDeterministicOutput
Whether the operator always produces the same output given the same input. This enables certain optim...
Definition: op_attr_types.h:228
mxnet::OpStatePtr
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const.
Definition: op_attr_types.h:148
mshadow::Stream
computation stream structure, used for asynchronous computations
Definition: tensor.h:488
mxnet::DispatchMode::kVariable
@ kVariable
mxnet::FCompute
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FCompute
Register a compute function for simple stateless forward only operator.
Definition: op_attr_types.h:294
op_attr_types.h
Data structures that can appear in operator attributes.
mxnet::SyncedGPUAuxStream
Provides automatic coordination of an auxiliary stream with a primary one. This object,...
Definition: base.h:308
mxnet::QuantizeType
QuantizeType
the quantization type of the operator
Definition: op_attr_types.h:135
mxnet::engine::CallbackOnComplete
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:169
mxnet::Engine::NewVariable
virtual VarHandle NewVariable()=0
Allocate a new variable, the variable can then be used to schedule the operation concurrently via dep...
mxnet::FStatefulCompute
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FStatefulCompute
Register a compute function for a stateful operator. OpStatePtr is a pointer type; its content is mutable even if OpStatePtr is const.
Definition: op_attr_types.h:245
mxnet::FQuantizedOp
std::function< nnvm::ObjectPtr(const NodeAttrs &attrs)> FQuantizedOp
Register a quantized node creation function based on the attrs of the node.
Definition: op_attr_types.h:328
mxnet::OpReqType
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
mxnet::Engine::DeleteVariable
virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var)=0
Schedule the deletion of a variable.
mxnet::DispatchMode
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:122
mxnet::kNullOp
@ kNullOp
no operation, do not write anything
Definition: op_attr_types.h:47
mxnet::RunContext
execution time context. The information needed in runtime for actual execution.
Definition: base.h:343
mxnet::FResourceRequestEx
std::function< std::vector< ResourceRequest >(const NodeAttrs &n, const int dev_mask, const DispatchMode dispatch_mode)> FResourceRequestEx
The resource request from the operator. An operator could register ResourceRequestEx,...
Definition: op_attr_types.h:276
mxnet::FNeedCalibrateInput
std::function< std::vector< int >(const NodeAttrs &attrs)> FNeedCalibrateInput
Register a function to determine if the input of a quantized operator needs to be calibrated....
Definition: op_attr_types.h:364
mxnet::OpContext
All the possible information needed by Operator. This is the superset of RunContext....
Definition: op_attr_types.h:66
mxnet::Engine::Get
static Engine * Get()
mxnet::ExecType::kSubgraphExec
@ kSubgraphExec
A subgraph execution should happen in the main thread, instead of in the execution engine.
mxnet::ExecType::kAsync
@ kAsync
Forward/Backward are asynchronous, will call OpContext.async_on_complete when operation finishes.
mxnet::DispatchMode::kFComputeFallback
@ kFComputeFallback
mxnet::FNeedCalibrateOutput
std::function< std::vector< int >(const NodeAttrs &attrs)> FNeedCalibrateOutput
Register a function to determine if the output of a quantized operator needs to be calibrated....
Definition: op_attr_types.h:371
mxnet::QuantizeType::kSupport
@ kSupport
tensor.h
header file of tensor data structure and functions This lib requires explicit memory allocation and d...
mxnet::OpContext::async_on_complete
engine::CallbackOnComplete async_on_complete
the callback when operation completes, used by asynchronous ops
Definition: op_attr_types.h:74
nnvm::NodeAttrs
The attributes of the current operation node. Usually are additional parameters like axis,...
Definition: node.h:107
nnvm::ObjectPtr
std::shared_ptr< Node > ObjectPtr
We always use ObjectPtr as a reference pointer to the node, so this alias can be changed if needed.
Definition: node.h:49
mxnet::FQuantizable
std::function< QuantizeType(const NodeAttrs &attrs)> FQuantizable
Register a function that returns the quantization type (QuantizeType) of the operator based on the attrs of the node.
Definition: op_attr_types.h:322
resource.h
Global resource allocation handling.
mxnet::OpContext::get_gpu_aux_stream
SyncedGPUAuxStream get_gpu_aux_stream() const
get auxiliary GPU stream auto-syncing object from Context
Definition: op_attr_types.h:91
mxnet::FAvoidQuantizeInput
std::function< bool(const NodeAttrs &attrs, const size_t index, const std::string quantize_granularity)> FAvoidQuantizeInput
Register a function to determine if the input of a quantized operator needs to be quantized....
Definition: op_attr_types.h:344
mxnet::ExecType
ExecType
the execution type of the operator
Definition: op_attr_types.h:98
mxnet::FComputeEx
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FComputeEx
Register an NDArray compute function for simple stateless forward only operator.
Definition: op_attr_types.h:304
mxnet::FInferStorageType
std::function< bool(const NodeAttrs &attrs, const int dev_mask, DispatchMode *dispatch_mode, std::vector< int > *in_attrs, std::vector< int > *out_attrs)> FInferStorageType
Register a storage and dispatch mode inference function based on storage types of the inputs and outp...
Definition: op_attr_types.h:316
mxnet::OpContext::need_grad
bool need_grad
whether there is a backward phase to compute gradients.
Definition: op_attr_types.h:68
mxnet::RunContext::get_stream
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: base.h:364
mxnet::ExecType::kSync
@ kSync
Forward/Backward are synchronous calls.
mxnet::QuantizeType::kMust
@ kMust
mxnet::kWriteTo
@ kWriteTo
write gradient to provided space
Definition: op_attr_types.h:49
mxnet::OpStatePtr::get_state
T & get_state() const
Definition: op_attr_types.h:172
mxnet::FStatefulComputeEx
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FStatefulComputeEx
Register a compute function for a stateful operator using the NDArray interface. OpStatePtr is a pointer type; its content is mutable even if OpStatePtr is const.
Definition: op_attr_types.h:257
mxnet::FAvoidDequantizeOutput
std::function< bool(const NodeAttrs &attrs, const size_t index)> FAvoidDequantizeOutput
Register a function to determine if the output of a quantized operator needs to be dequantized....
Definition: op_attr_types.h:357
mxnet::FNeedRequantize
std::function< bool(const NodeAttrs &attrs)> FNeedRequantize
Register a function to determine if the output of a quantized operator needs to be requantized....
Definition: op_attr_types.h:336
mxnet::OpStatePtr::Create
static OpStatePtr Create(Args &&... args)
Definition: op_attr_types.h:154
mxnet::engine::Var
base class of engine variables.
Definition: engine.h:111
mxnet::DispatchMode::kFComputeEx
@ kFComputeEx
mxnet::FResourceRequest
std::function< std::vector< ResourceRequest >(const NodeAttrs &n)> FResourceRequest
The resource request from the operator. An operator could register ResourceRequestEx,...
Definition: op_attr_types.h:264
mxnet::RunContext::get_gpu_aux_stream
SyncedGPUAuxStream get_gpu_aux_stream() const
get an RAII object that transparently handles the syncing of the auxiliary stream.
Definition: base.h:372
mxnet::Context::CPU
static Context CPU(int32_t dev_id=0)
mxnet::ShapeVector
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:830
mxnet::QuantizeType::kNone
@ kNone
mxnet::OpStatePtr::get_var
engine::VarHandle get_var() const
Definition: op_attr_types.h:167
mxnet::FCreateOpState
std::function< OpStatePtr(const NodeAttrs &attrs, Context ctx, const mxnet::ShapeVector &in_shape, const std::vector< int > &in_type)> FCreateOpState
Create a layer-style forward/backward operator. This makes it easy to write code that contains state....
Definition: op_attr_types.h:218
engine.h
Engine that schedules all the operations according to dependency.
mxnet::DispatchMode::kUndefined
@ kUndefined
mxnet::ExecType::kCrossDeviceCopy
@ kCrossDeviceCopy
Cross-device copy operation; this is a special operator that indicates it will copy across devices....
mxnet::FIsCUDAGraphsCompatible
std::function< bool(const NodeAttrs &attrs, const bool is_train)> FIsCUDAGraphsCompatible
Register a function to determine if the operator implementation is compatible with CUDA graphs....
Definition: op_attr_types.h:382
mxnet::OpContext::run_ctx
RunContext run_ctx
RunContext related resources.
Definition: op_attr_types.h:72
mxnet::OpStatePtr::reset
void reset()
Definition: op_attr_types.h:176
mxnet::OpContext::requested
std::vector< Resource > requested
Resources requested by the operator.
Definition: op_attr_types.h:76
ndarray.h
NDArray interface that handles array arithmetic.
mxnet::kAddTo
@ kAddTo
add to the provided space
Definition: op_attr_types.h:57
mxnet::FNeedAsymQuantizeInput
std::function< bool(const NodeAttrs &attrs, const size_t index)> FNeedAsymQuantizeInput
Register a function to determine if the input of a quantized operator needs to be quantized asymmetri...
Definition: op_attr_types.h:350
mxnet::FExecType
std::function< ExecType(const NodeAttrs &attrs)> FExecType
Execution mode of this operator.
Definition: op_attr_types.h:233
mxnet::OpContext::is_train
bool is_train
whether it is training phase
Definition: op_attr_types.h:70
mxnet::OpStatePtr::unique
bool unique() const
Definition: op_attr_types.h:181
base.h
configuration of MXNet as well as basic data structure.
mxnet::OpContext::get_stream
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: op_attr_types.h:83
mxnet::FNDArrayFunction
std::function< void(const nnvm::NodeAttrs &attrs, const std::vector< NDArray > &inputs, std::vector< NDArray > *outputs)> FNDArrayFunction
Register an operator called as a NDArray function.
Definition: op_attr_types.h:284
mxnet::DispatchMode::kFCompute
@ kFCompute