25 #ifndef MXNET_OP_ATTR_TYPES_H_
26 #define MXNET_OP_ATTR_TYPES_H_
82 template <typename xpu>
153 template <typename T, typename... Args>
156 auto state = new T(std::forward<Args>(args)...);
158 ret.ptr_.reset(new OpState(var, state), [](OpState* p) {
160 delete reinterpret_cast<T*>(p->state);
171 template <typename T>
173 return *reinterpret_cast<T*>(ptr_->state);
182 return ptr_.unique();
185 explicit operator bool() const {
186 return ptr_ ? true : false;
196 OpState(const OpState& other) = delete;
197 OpState& operator=(const OpState& other) = delete;
200 std::shared_ptr<OpState> ptr_;
215 using FCreateOpState = std::function<OpStatePtr(const NodeAttrs& attrs,
218 const std::vector<int>& in_type)>;
243 const std::vector<TBlob>& inputs,
244 const std::vector<OpReqType>& req,
245 const std::vector<TBlob>& outputs)>;
255 const std::vector<NDArray>& inputs,
256 const std::vector<OpReqType>& req,
257 const std::vector<NDArray>& outputs)>;
274 std::function<std::vector<ResourceRequest>(const NodeAttrs& n,
283 const std::vector<NDArray>& inputs,
284 std::vector<NDArray>* outputs)>;
292 const std::vector<TBlob>& inputs,
293 const std::vector<OpReqType>& req,
294 const std::vector<TBlob>& outputs)>;
302 const std::vector<NDArray>& inputs,
303 const std::vector<OpReqType>& req,
304 const std::vector<NDArray>& outputs)>;
315 std::vector<int>* in_attrs,
316 std::vector<int>* out_attrs)>;
344 bool(const NodeAttrs& attrs, const size_t index, const std::string quantize_granularity)>;
388 #endif // MXNET_OP_ATTR_TYPES_H_
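The listing shows only fragments of OpStatePtr: Create allocates a T, stores it behind an untyped pointer and installs a deleter that casts back to T, and get_state casts the pointer back on access. The sketch below reproduces that idea in a self-contained class. It is a simplified, hypothetical stand-in (the names StatePtr and Holder are not MXNet's); judging from the Engine::NewVariable and Engine::DeleteVariable tooltips, the real class also manages an engine variable, which is omitted here.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Simplified, hypothetical stand-in for OpStatePtr: type-erased, shared,
// and deleted through a deleter that remembers the concrete type T.
class StatePtr {
 public:
  template <typename T, typename... Args>
  static StatePtr Create(Args&&... args) {
    StatePtr ret;
    T* state = new T(std::forward<Args>(args)...);
    // The lambda is instantiated per T, so it knows which destructor to run.
    ret.ptr_.reset(new Holder{state}, [](Holder* h) {
      delete static_cast<T*>(h->state);
      delete h;
    });
    return ret;
  }

  // The caller must name the same T that was passed to Create (unchecked).
  template <typename T>
  T& get_state() const {
    assert(ptr_ && "get_state() called on an empty StatePtr");
    return *static_cast<T*>(ptr_->state);
  }

  bool unique() const { return ptr_.use_count() == 1; }
  void reset() { ptr_.reset(); }
  explicit operator bool() const { return ptr_ != nullptr; }

 private:
  struct Holder {
    void* state;  // untyped pointer to the user's state object
  };
  std::shared_ptr<Holder> ptr_;
};
```

Because the deleter is chosen inside Create<T>, copies of the pointer can be passed around without knowing T, while get_state<T> trusts the caller to name the right type, much like the unchecked reinterpret_cast in the listing.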
namespace of mxnet
Definition: api_registry.h:33
@ kWriteInplace
perform an in-place write. This option only happens when the target shares memory with one of the input argumen...
Definition: op_attr_types.h:55
bool THasDeterministicOutput
Whether the operator always produces the same output given the same input. This enables certain optim...
Definition: op_attr_types.h:228
Operator state. This is a pointer type; its content is mutable even if the OpStatePtr itself is const.
Definition: op_attr_types.h:148
computation stream structure, used for asynchronous computations
Definition: tensor.h:488
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FCompute
Register a compute function for a simple, stateless, forward-only operator.
Definition: op_attr_types.h:294
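FCompute is a std::function, so any free function or lambda with the matching signature can be registered as an operator's compute routine. The sketch below shows a stateless, forward-only scaling operator written against that signature. NodeAttrs, OpContext, TBlob and OpReqType here are minimal hypothetical stand-ins (TBlob is reduced to a non-owning float pointer plus length), and the map-based ComputeRegistry is hypothetical as well; in MXNet the function is attached to an operator with set_attr<FCompute>("FCompute<cpu>", ...).

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins; only the FCompute signature mirrors op_attr_types.h.
struct NodeAttrs {
  std::map<std::string, std::string> dict;  // parsed operator parameters
};
struct OpContext {
  bool is_train = false;
};
struct TBlob {  // non-owning view of a float buffer
  float* dptr;
  std::size_t size;
};
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

using FCompute = std::function<void(const NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<TBlob>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<TBlob>& outputs)>;

// Stateless, forward-only operator: out = scalar * in.
inline void ScaleCompute(const NodeAttrs& attrs, const OpContext& /*ctx*/,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
  assert(inputs.size() == 1 && outputs.size() == 1 && req.size() == 1);
  const TBlob& in = inputs[0];
  const TBlob& out = outputs[0];  // const view, but the data is writable
  assert(in.size == out.size);
  if (req[0] == kNullOp) return;
  const float scalar = std::stof(attrs.dict.at("scalar"));
  for (std::size_t i = 0; i < in.size; ++i) {
    const float v = scalar * in.dptr[i];
    out.dptr[i] = (req[0] == kAddTo) ? out.dptr[i] + v : v;
  }
}

// Hypothetical registry keyed by operator name.
inline std::map<std::string, FCompute>& ComputeRegistry() {
  static std::map<std::string, FCompute> registry;
  return registry;
}
```

Passing outputs as const views mirrors the real signature: the vector of TBlobs is immutable, but the memory each TBlob points to is not.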
Data structures that can appear in operator attributes.
Provides automatic coordination of an auxiliary stream with a primary one. This object,...
Definition: base.h:308
QuantizeType
the quantization type of the operator
Definition: op_attr_types.h:135
OnComplete Callback to the engine, called by AsyncFn when the action completes.
Definition: engine.h:169
virtual VarHandle NewVariable()=0
Allocate a new variable; the variable can then be used to schedule the operation concurrently via dep...
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FStatefulCompute
Register a compute function for a stateful operator. OpStatePtr is a pointer type; its content is mutab...
Definition: op_attr_types.h:245
std::function< nnvm::ObjectPtr(const NodeAttrs &attrs)> FQuantizedOp
Register a quantized node creation function based on the attrs of the node.
Definition: op_attr_types.h:328
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
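The four OpReqType values tell an operator how to store each output. A common pattern, similar in spirit to the KERNEL_ASSIGN macro used in MXNet's operator sources, is to route every write through one switch. The sketch below is a hypothetical, self-contained version of that pattern; Assign and Double are illustrative names, and the enum only mirrors the documented values.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors the four request types documented for OpReqType.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Store `value` into `*out` according to `req`.
inline void Assign(float* out, OpReqType req, float value) {
  switch (req) {
    case kNullOp:
      break;          // write nothing
    case kWriteTo:
    case kWriteInplace:
      *out = value;   // overwrite; for in-place, out may alias an input
      break;
    case kAddTo:
      *out += value;  // accumulate, e.g. when summing gradients
      break;
  }
}

// Element-wise doubling that respects the request type.
inline void Double(const std::vector<float>& in, OpReqType req,
                   std::vector<float>* out) {
  assert(in.size() == out->size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    Assign(&(*out)[i], req, 2.0f * in[i]);
  }
}
```

Each element is read before it is written, so calling Double with the same vector as input and output is safe, which is the situation kWriteInplace describes.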
virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var)=0
Schedule the deletion of a variable.
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:122
@ kNullOp
no operation, do not write anything
Definition: op_attr_types.h:47
execution time context. The information needed in runtime for actual execution.
Definition: base.h:343
std::function< std::vector< ResourceRequest >(const NodeAttrs &n, const int dev_mask, const DispatchMode dispatch_mode)> FResourceRequestEx
The resource request from the operator. An operator could register ResourceRequestEx,...
Definition: op_attr_types.h:276
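Unlike FResourceRequest, which only receives the node attributes, FResourceRequestEx also receives dev_mask and the dispatch mode, so an operator can ask for different resources per device. The sketch below is hypothetical throughout: the ResourceType enum, the mask values and the dropout-like policy are assumptions for illustration, not MXNet's definitions, and the NodeAttrs argument is omitted.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for mxnet::ResourceRequest, DispatchMode and the
// device masks; the real definitions live in the MXNet/mshadow headers.
enum class ResourceType { kRandom, kTempSpace, kParallelRandom };
struct ResourceRequest {
  ResourceType type;
};
enum class DispatchMode { kUndefined = -1, kFCompute, kFComputeEx, kFComputeFallback };
constexpr int kCPUMask = 1;  // assumed value of the CPU dev_mask
constexpr int kGPUMask = 2;  // assumed value of the GPU dev_mask

// FResourceRequestEx-style callback for a dropout-like op: always ask for
// scratch space, and pick the random-number resource by device, something a
// plain FResourceRequest (attributes only) cannot do.
inline std::vector<ResourceRequest> DropoutResources(int dev_mask,
                                                     DispatchMode /*mode*/) {
  std::vector<ResourceRequest> reqs;
  reqs.push_back(ResourceRequest{ResourceType::kTempSpace});
  if (dev_mask == kGPUMask) {
    reqs.push_back(ResourceRequest{ResourceType::kParallelRandom});
  } else {
    reqs.push_back(ResourceRequest{ResourceType::kRandom});
  }
  return reqs;
}
```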
std::function< std::vector< int >(const NodeAttrs &attrs)> FNeedCalibrateInput
Register a function to determine if the input of a quantized operator needs to be calibrated....
Definition: op_attr_types.h:364
All the possible information needed by an operator. This is the superset of RunContext....
Definition: op_attr_types.h:66
@ kSubgraphExec
A subgraph execution should happen in the main thread, instead of in the execution engine.
@ kAsync
Forward/Backward are asynchronous, will call OpContext.async_on_complete when operation finishes.
std::function< std::vector< int >(const NodeAttrs &attrs)> FNeedCalibrateOutput
Register a function to determine if the output of a quantized operator needs to be calibrated....
Definition: op_attr_types.h:371
header file of tensor data structure and functions. This lib requires explicit memory allocation and d...
engine::CallbackOnComplete async_on_complete
the callback invoked when the operation completes, used by asynchronous ops
Definition: op_attr_types.h:74
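With ExecType kAsync, Forward/Backward may return before the work is done and must later invoke OpContext::async_on_complete. The sketch below models that contract without threads: a deque of tasks stands in for the worker or device stream that would run the work, and async_on_complete is a plain std::function rather than engine::CallbackOnComplete. AsyncOpContext, WorkQueue, AsyncSumForward and RunAll are hypothetical names.

```cpp
#include <cassert>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-ins: MXNet's OpContext::async_on_complete is an
// engine::CallbackOnComplete; here it is a plain std::function, and a deque
// of tasks plays the role of the worker that actually runs the work.
struct AsyncOpContext {
  std::function<void()> async_on_complete;
};
using WorkQueue = std::deque<std::function<void()>>;

// kAsync-style forward: it only schedules the work and returns at once.
// The scheduled task writes the output, then calls async_on_complete once.
inline void AsyncSumForward(const AsyncOpContext& ctx, WorkQueue* queue,
                            const std::vector<int>& input, int* output) {
  queue->push_back([ctx, input, output]() {
    int sum = 0;
    for (int v : input) sum += v;
    *output = sum;
    ctx.async_on_complete();  // outputs are now safe for dependent ops
  });
}

// Run queued tasks in FIFO order, as a worker thread would.
inline void RunAll(WorkQueue* queue) {
  while (!queue->empty()) {
    std::function<void()> task = std::move(queue->front());
    queue->pop_front();
    task();
  }
}
```

The task copies the input vector, because the caller's buffers may go out of scope before the queued work runs.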
The attributes of the current operation node. These are usually additional parameters like axis,...
Definition: node.h:107
std::shared_ptr< Node > ObjectPtr
We always use ObjectPtr as a reference pointer to the node, so this alias can be changed if needed.
Definition: node.h:49
std::function< QuantizeType(const NodeAttrs &attrs)> FQuantizable
Register a quantized node creation function based on the attrs of the node.
Definition: op_attr_types.h:322
Global resource allocation handling.
SyncedGPUAuxStream get_gpu_aux_stream() const
get auxiliary gpu stream auto-syncing object from Context
Definition: op_attr_types.h:91
std::function< bool(const NodeAttrs &attrs, const size_t index, const std::string quantize_granularity)> FAvoidQuantizeInput
Register a function to determine if the input of a quantized operator needs to be quantized....
Definition: op_attr_types.h:344
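FAvoidQuantizeInput receives the node attributes, the index of an input and the requested quantization granularity, and returns true when that input should stay unquantized. The callback below is a hypothetical policy for an operator whose inputs are (data, weight, bias); the NodeAttrs stand-in, the input ordering and the granularity strings "channel-wise" and "tensor-wise" are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>

struct NodeAttrs {};  // hypothetical stand-in for nnvm::NodeAttrs

// Same parameter list as the FAvoidQuantizeInput alias in the listing.
using FAvoidQuantizeInput = std::function<bool(
    const NodeAttrs& attrs, const std::size_t index,
    const std::string quantize_granularity)>;

// Hypothetical policy: keep the bias (input 2) in float when weights are
// quantized channel-wise; quantize every other input.
const FAvoidQuantizeInput kAvoidBiasQuantization =
    [](const NodeAttrs& /*attrs*/, const std::size_t index,
       const std::string quantize_granularity) {
      return index == 2 && quantize_granularity == "channel-wise";
    };
```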
ExecType
the execution type of the operator
Definition: op_attr_types.h:98
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FComputeEx
Register an NDArray compute function for a simple, stateless, forward-only operator.
Definition: op_attr_types.h:304
std::function< bool(const NodeAttrs &attrs, const int dev_mask, DispatchMode *dispatch_mode, std::vector< int > *in_attrs, std::vector< int > *out_attrs)> FInferStorageType
Register a storage and dispatch mode inference function based on storage types of the inputs and outp...
Definition: op_attr_types.h:316
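An FInferStorageType function fills in the output storage types and picks a dispatch mode from the input storage types, returning whether inference succeeded. The sketch below shows one plausible policy for an element-wise unary operator. The enums are reduced stand-ins for MXNet's DispatchMode and storage-type constants (the real ones have more members), the NodeAttrs argument is omitted, and the policy itself is an illustrative assumption.

```cpp
#include <cassert>
#include <vector>

// Reduced, hypothetical stand-ins for mxnet::DispatchMode and the
// NDArray storage-type constants.
enum class DispatchMode { kUndefined = -1, kFCompute, kFComputeEx, kFComputeFallback };
enum StorageType { kUndefinedStorage = -1, kDefaultStorage, kRowSparseStorage, kCSRStorage };

// Dense in -> dense out via FCompute; row_sparse in -> row_sparse out via
// FComputeEx; any other input falls back to dense computation.
inline bool UnaryInferStorage(int /*dev_mask*/, DispatchMode* dispatch_mode,
                              std::vector<int>* in_attrs,
                              std::vector<int>* out_attrs) {
  assert(in_attrs->size() == 1 && out_attrs->size() == 1);
  const int in_stype = (*in_attrs)[0];
  if (in_stype == kDefaultStorage) {
    (*out_attrs)[0] = kDefaultStorage;
    *dispatch_mode = DispatchMode::kFCompute;
  } else if (in_stype == kRowSparseStorage) {
    (*out_attrs)[0] = kRowSparseStorage;
    *dispatch_mode = DispatchMode::kFComputeEx;
  } else {
    (*out_attrs)[0] = kDefaultStorage;
    *dispatch_mode = DispatchMode::kFComputeFallback;
  }
  return true;  // this simple policy always finds an assignment
}
```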
bool need_grad
whether there is a backward phase to compute gradients.
Definition: op_attr_types.h:68
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: base.h:364
@ kSync
Forward/Backward are synchronous calls.
@ kWriteTo
write gradient to provided space
Definition: op_attr_types.h:49
T & get_state() const
Definition: op_attr_types.h:172
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FStatefulComputeEx
Register a compute function for a stateful operator using the NDArray interface. OpStatePtr is a pointer typ...
Definition: op_attr_types.h:257
std::function< bool(const NodeAttrs &attrs, const size_t index)> FAvoidDequantizeOutput
Register a function to determine if the output of a quantized operator needs to be dequantized....
Definition: op_attr_types.h:357
std::function< bool(const NodeAttrs &attrs)> FNeedRequantize
Register a function to determine if the output of a quantized operator needs to be requantized....
Definition: op_attr_types.h:336
static OpStatePtr Create(Args &&... args)
Definition: op_attr_types.h:154
base class of engine variables.
Definition: engine.h:111
std::function< std::vector< ResourceRequest >(const NodeAttrs &n)> FResourceRequest
The resource request from the operator. An operator could register ResourceRequestEx,...
Definition: op_attr_types.h:264
SyncedGPUAuxStream get_gpu_aux_stream() const
get an RAII object that transparently handles the syncing of the auxiliary stream.
Definition: base.h:372
static Context CPU(int32_t dev_id=0)
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:830
engine::VarHandle get_var() const
Definition: op_attr_types.h:167
std::function< OpStatePtr(const NodeAttrs &attrs, Context ctx, const mxnet::ShapeVector &in_shape, const std::vector< int > &in_type)> FCreateOpState
Create a layer-style forward/backward operator. This makes it easy to write code that contains state....
Definition: op_attr_types.h:218
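FCreateOpState builds the state once from the inferred shapes and types, and the stateful compute functions (FStatefulCompute, FStatefulComputeEx) then receive that state on every call; the pointer is const but the state it refers to is mutable. The sketch below illustrates that split with a running-mean operator. To stay self-contained it replaces OpStatePtr with std::shared_ptr<void> and TShape with a vector of dimension sizes, and drops the NodeAttrs, Context and OpContext arguments; all names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical simplifications of OpStatePtr and mxnet::TShape.
using StateHandle = std::shared_ptr<void>;
using Shape = std::vector<std::size_t>;

// State kept across calls: a running sum of every input seen so far.
struct RunningSumState {
  explicit RunningSumState(std::size_t n) : sum(n, 0.0f) {}
  std::vector<float> sum;
  int steps = 0;
};

// FCreateOpState-like factory: size the state from the first input shape.
inline StateHandle CreateRunningSum(const std::vector<Shape>& in_shape) {
  std::size_t n = 1;
  for (std::size_t d : in_shape.at(0)) n *= d;
  return std::make_shared<RunningSumState>(n);  // deleter keeps the real type
}

// FStatefulCompute-like forward: the handle is const, the state is mutable.
inline void RunningMeanForward(const StateHandle& state,
                               const std::vector<float>& in,
                               std::vector<float>* out) {
  auto* s = static_cast<RunningSumState*>(state.get());
  assert(in.size() == s->sum.size() && out->size() == in.size());
  ++s->steps;
  for (std::size_t i = 0; i < in.size(); ++i) {
    s->sum[i] += in[i];
    (*out)[i] = s->sum[i] / static_cast<float>(s->steps);
  }
}
```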
Engine that schedules all the operations according to dependency.
@ kCrossDeviceCopy
Cross device copy operation, this is a special operator that indicates it will copy across devices....
std::function< bool(const NodeAttrs &attrs, const bool is_train)> FIsCUDAGraphsCompatible
Register a function to determine if the operator implementation is compatible with CUDA graphs....
Definition: op_attr_types.h:382
RunContext run_ctx
RunContext related resources.
Definition: op_attr_types.h:72
void reset()
Definition: op_attr_types.h:176
std::vector< Resource > requested
Resources requested by the operator.
Definition: op_attr_types.h:76
NDArray interface that handles array arithmetic.
@ kAddTo
add to the provided space
Definition: op_attr_types.h:57
std::function< bool(const NodeAttrs &attrs, const size_t index)> FNeedAsymQuantizeInput
Register a function to determine if the input of a quantized operator needs to be quantized asymmetri...
Definition: op_attr_types.h:350
std::function< ExecType(const NodeAttrs &attrs)> FExecType
Execution mode of this operator.
Definition: op_attr_types.h:233
bool is_train
whether it is training phase
Definition: op_attr_types.h:70
bool unique() const
Definition: op_attr_types.h:181
configuration of MXNet as well as basic data structures.
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: op_attr_types.h:83
std::function< void(const nnvm::NodeAttrs &attrs, const std::vector< NDArray > &inputs, std::vector< NDArray > *outputs)> FNDArrayFunction
Register an operator called as an NDArray function.
Definition: op_attr_types.h:284