mxnet
op_attr_types.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
26 #ifndef MXNET_OP_ATTR_TYPES_H_
27 #define MXNET_OP_ATTR_TYPES_H_
28 
29 #include <mshadow/tensor.h>
30 #include <nnvm/op_attr_types.h>
31 
32 #include <vector>
33 #include <functional>
34 
35 #include "./base.h"
36 #include "./ndarray.h"
37 #include "./engine.h"
38 #include "./resource.h"
39 
40 namespace mxnet {
41 
42 using nnvm::NodeAttrs;
43 
/* \brief Operation request type to Forward and Backward. */
// NOTE(review): this listing omits the enumerator lines (source lines 46-57).
// The cross-reference text of this page describes four request kinds:
// no operation (do not write anything), write to the provided space,
// in-place write, and add to the provided space.
45 enum OpReqType {
58 };
59 
/* \brief All the information needed by Operator.Forward and Backward. */
// NOTE(review): this listing omits source lines 72 and 74. According to the
// cross-reference index of this page they declare the members run_ctx
// (RunContext related resources) and async_on_complete (completion callback).
// run_ctx is used by both accessors below.
66 struct OpContext {
 /* \brief whether there is a backward phase to compute gradients */
68  bool need_grad;
 /* \brief whether it is the training phase */
70  bool is_train;
 /* \brief resources requested by the operator */
76  std::vector<Resource> requested;
 /* \brief get the mshadow stream by forwarding to run_ctx
  * \tparam xpu device type of the requested stream
  * \return the pointer returned by run_ctx.get_stream<xpu>()
  */
82  template<typename xpu>
83  inline mshadow::Stream<xpu>* get_stream() const {
84  return run_ctx.get_stream<xpu>();
85  }
86 #if MXNET_USE_CUDA
87 
 // NOTE(review): the signature line of this accessor (source line 91) is
 // missing from the listing; the cross-reference index gives it as
 // SyncedGPUAuxStream get_gpu_aux_stream() const. It forwards to run_ctx and
 // is only compiled when MXNET_USE_CUDA is nonzero.
92  return run_ctx.get_gpu_aux_stream();
93  }
94 #endif
95 };
96 
/* \brief The execution type of the operator. */
98 enum class ExecType {
 // Forward/Backward are synchronous calls.
100  kSync,
 // Asynchronous function call.
105  kAsync,
 // NOTE(review): the embedded numbering jumps from 105 to 119, so further
 // enumerators appear to be missing from this listing. The cross-reference
 // text also describes a cross-device copy mode and a subgraph execution mode
 // that runs in the main thread instead of the execution engine.
119 };
120 
/* \brief The dispatch mode of the operator. */
122 enum class DispatchMode {
 // Only enumerator with an explicit value (-1); the ones after it count up from 0.
123  kUndefined = -1,
124  // dispatch on FCompute or FStatefulCompute
125  kFCompute,
126  // dispatch on FComputeEx or FStatefulComputeEx, if available
127  kFComputeEx,
128  // dispatch on FCompute or FStatefulCompute, and performs storage fallback
 // NOTE(review): the enumerator documented by the comment above (source line
 // 129) is missing from this listing; that comment does not describe kVariable.
130  // special dispatch mode for variables
131  kVariable,
132 };
133 
/* \brief The quantization type of the operator. */
135 enum class QuantizeType {
136  // This operator does not support quantization
137  kNone = 0,
138  // This operator benefits greatly from quantization and therefore must be quantized
139  kMust,
140  // This operator supports quantization; whether it is quantized is decided depending on the connection
141  kSupport,
142 };
143 
/* \brief Operator state handle. This is a pointer type: its content is
 * mutable even if the OpStatePtr itself is const. Internally it is a
 * std::shared_ptr to a small record holding an engine variable and a
 * type-erased (void*) state object.
 */
148 class OpStatePtr {
149  public:
150  /* \brief Create an OpStatePtr with state of type T.
 * \tparam T type of the state object to allocate
151  * \param args Arguments forwarded to the constructor of T.
 * \return a handle that owns the new state and a new engine variable
152  */
153  template<typename T, typename... Args>
154  static OpStatePtr Create(Args&&... args) {
155  OpStatePtr ret;
156  auto state = new T(std::forward<Args>(args)...);
157  auto var = Engine::Get()->NewVariable();
 // Custom deleter run when the last owner goes away: it asks the engine to
 // schedule deletion of the variable (with an empty callback on the CPU
 // context), then frees the state through its real type T, because it is
 // stored as void*, and finally frees the OpState record itself.
158  ret.ptr_.reset(
159  new OpState(var, state),
160  [](OpState* p) {
161  Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
162  delete reinterpret_cast<T*>(p->state);
163  delete p;
164  });
165 
166  return ret;
167  }
168  /* \brief Get engine variable associated with this state */
 // NOTE(review): the signature line of this accessor (source line 169) is
 // missing from the listing; the cross-reference index gives it as
 // engine::VarHandle get_var() const. ptr_ is dereferenced without a null check.
170  return ptr_->var;
171  }
172  /* \brief Get state of type T.
 * The cast from void* is unchecked, so T is expected to be the type that was
 * used in Create. ptr_ is dereferenced without a null check.
 */
173  template<typename T>
174  T& get_state() const {
175  return *reinterpret_cast<T*>(ptr_->state);
176  }
177  /* \brief clear state: drop the reference held by this handle */
178  void reset() {
179  ptr_.reset();
180  }
181  /* \brief checks whether the managed object is managed only by the current
182  OpStatePtr instance */
 // NOTE(review): std::shared_ptr::unique() is deprecated in C++17 and removed
 // in C++20 - confirm the language standard this header is built with.
183  bool unique() const {
184  return ptr_.unique();
185  }
186  /* \brief Whether this handle holds a state: true when non-empty, false when empty */
187  explicit operator bool() const {
188  return ptr_ ? true : false;
189  }
190 
191  private:
192  /* \brief state structure: pairs the engine variable with the type-erased state; not copyable */
193  struct OpState {
 // engine variable created for this state in Create
194  engine::VarHandle var;
 // state object allocated in Create, stored without its static type
195  void* state;
196 
197  OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
198  OpState(const OpState& other) = delete;
199  OpState& operator=(const OpState& other) = delete;
200  };
201  /* \brief shared pointer to state */
202  std::shared_ptr<OpState> ptr_;
203 };
204 
/* \brief Create a layer-style forward/backward operator state. Per the
 * cross-reference text of this page, this makes it easy to write code that
 * contains state.
 * \param attrs attributes of the node
 * \param ctx context information about the execution environment
 * \param in_shape input shapes
 * \param in_type input types, given as ints
 * \return the created OpStatePtr
 */
217 using FCreateOpState = std::function<OpStatePtr (const NodeAttrs& attrs,
218  Context ctx,
219  const mxnet::ShapeVector& in_shape,
220  const std::vector<int>& in_type)>;
221 
231 
/* \brief Execution mode of this operator: maps the node attributes to an ExecType. */
// NOTE(review): the cross-reference index places a bool alias named
// THasDeterministicOutput at source line 230, which this listing does not show.
235 using FExecType = std::function<ExecType (const NodeAttrs& attrs)>;
/* \brief Register a compute function for a stateful operator.
 * OpStatePtr is a pointer type; its content is mutable even though the
 * handle is passed by const reference.
 * \param state operator state handle
 * \param ctx operator context
 * \param inputs input TBlobs
 * \param req request types (OpReqType values)
 * \param outputs output TBlobs
 */
243 using FStatefulCompute = std::function<void (const OpStatePtr& state,
244  const OpContext& ctx,
245  const std::vector<TBlob>& inputs,
246  const std::vector<OpReqType>& req,
247  const std::vector<TBlob>& outputs)>;
/* \brief Register a compute function for a stateful operator using the
 * NDArray interface. Same parameter layout as FStatefulCompute, with NDArray
 * inputs and outputs instead of TBlob.
 */
255 using FStatefulComputeEx = std::function<void (const OpStatePtr& state,
256  const OpContext& ctx,
257  const std::vector<NDArray>& inputs,
258  const std::vector<OpReqType>& req,
259  const std::vector<NDArray>& outputs)>;
/* \brief The resource request from the operator.
 * An operator could register ResourceRequestEx, or ResourceRequest, or neither.
 * \param n attributes of the node
 * \return the list of resource requests
 */
266 using FResourceRequest = std::function<
267  std::vector<ResourceRequest> (const NodeAttrs& n)>;
/* \brief The resource request from the operator, additionally given the
 * device mask and the dispatch mode.
 * If an operator registers both ResourceRequestEx and ResourceRequest,
 * ResourceRequest is ignored.
 * \param n attributes of the node
 * \param dev_mask device mask
 * \param dispatch_mode dispatch mode of the operator
 * \return the list of resource requests
 */
276 using FResourceRequestEx = std::function<
277  std::vector<ResourceRequest> (const NodeAttrs& n,
278  const int dev_mask,
279  const DispatchMode dispatch_mode)>;
/* \brief Register an operator called as an NDArray function.
 * \param attrs attributes of the node
 * \param inputs input NDArrays
 * \param outputs pointer to the vector of output NDArrays
 */
285 using FNDArrayFunction = std::function<void (const nnvm::NodeAttrs& attrs,
286  const std::vector<NDArray>& inputs,
287  std::vector<NDArray>* outputs)>;
/* \brief Register a compute function for a simple stateless forward-only operator.
 * \param attrs attributes of the node
 * \param ctx operator context
 * \param inputs input TBlobs
 * \param req request types (OpReqType values)
 * \param outputs output TBlobs
 */
293 using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
294  const OpContext& ctx,
295  const std::vector<TBlob>& inputs,
296  const std::vector<OpReqType>& req,
297  const std::vector<TBlob>& outputs)>;
/* \brief Register an NDArray compute function for a simple stateless
 * forward-only operator. Same parameter layout as FCompute, with NDArray
 * inputs and outputs instead of TBlob.
 */
303 using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
304  const OpContext& ctx,
305  const std::vector<NDArray>& inputs,
306  const std::vector<OpReqType>& req,
307  const std::vector<NDArray>& outputs)>;
308 
/* \brief Register a storage and dispatch mode inference function based on
 * the storage types of the inputs and outputs.
 * \param attrs attributes of the node
 * \param dev_mask device mask
 * \param dispatch_mode pointer to the dispatch mode; passed by pointer,
 *        presumably so the function can set it - TODO confirm
 * \param in_attrs pointer to the input storage types, given as ints
 * \param out_attrs pointer to the output storage types, given as ints
 * \return bool; presumably true when inference succeeded - TODO confirm
 */
315 using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
316  const int dev_mask,
317  DispatchMode* dispatch_mode,
318  std::vector<int>* in_attrs,
319  std::vector<int>* out_attrs)>;
320 
/* \brief Register a function that reports the QuantizeType of an operator
 * based on the attrs of the node.
 */
325 using FQuantizable = std::function<QuantizeType (const NodeAttrs& attrs)>;
326 
/* \brief Register a quantized node creation function based on the attrs of the node. */
331 using FQuantizedOp = std::function<nnvm::NodePtr (const NodeAttrs& attrs)>;
332 
/* \brief Register a function to determine if the output of a quantized
 * operator needs to be requantized.
 */
339 using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
340 
/* \brief Register a function to determine if the input of a quantized
 * operator needs to be quantized.
 * \param attrs attributes of the node
 * \param index presumably the position of the input in question - TODO confirm
 */
346 using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
347  size_t index)>;
348 
/* \brief Register a function to determine if the input of a quantized
 * operator needs to be calibrated.
 * \return a vector of ints; presumably the indices of the inputs to
 *         calibrate - TODO confirm
 */
354 using FNeedCalibrateInput = std::function<std::vector<int> (const NodeAttrs& attrs)>;
355 
/* \brief Register a function to determine if the output of a quantized
 * operator needs to be calibrated.
 * \return a vector of ints; presumably the indices of the outputs to
 *         calibrate - TODO confirm
 */
361 using FNeedCalibrateOutput = std::function<std::vector<int> (const NodeAttrs& attrs)>;
362 
363 } // namespace mxnet
364 
365 #endif // MXNET_OP_ATTR_TYPES_H_
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FStatefulCompute
Register a compute function for stateful operator. OpStatePtr is a pointer type, its content is mutab...
Definition: op_attr_types.h:247
void reset()
Definition: op_attr_types.h:178
Forward/Backward are synchronous calls.
Engine that schedules all the operations according to dependency.
bool THasDeterministicOutput
Whether the operator always produces the same output given the same input. This enables certain optim...
Definition: op_attr_types.h:230
no operation, do not write anything
Definition: op_attr_types.h:47
The attributes of the current operation node. Usually are additional parameters like axis...
Definition: node.h:120
write gradient to provided space
Definition: op_attr_types.h:49
namespace of mxnet
Definition: base.h:89
std::function< std::vector< ResourceRequest >(const NodeAttrs &n)> FResourceRequest
The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither.
Definition: op_attr_types.h:267
std::function< void(const OpStatePtr &state, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FStatefulComputeEx
Register a compute function for stateful operator using NDArray interface. OpStatePtr is a pointer typ...
Definition: op_attr_types.h:259
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: base.h:371
std::function< bool(const NodeAttrs &attrs)> FNeedRequantize
Register a function to determine if the output of a quantized operator needs to be requantized...
Definition: op_attr_types.h:339
Asynchronous function call.
bool is_train
whether it is training phase
Definition: op_attr_types.h:70
engine::VarHandle get_var() const
Definition: op_attr_types.h:169
Cross device copy operation, this is a special operator that indicates it will copy across devices...
execution time context. The information needed in runtime for actual execution.
Definition: base.h:350
DispatchMode
the dispatch mode of the operator
Definition: op_attr_types.h:122
base class of engine variables.
Definition: engine.h:44
Provides automatic coordination of an auxiliary stream with a primary one. This object, upon construction, prepares an aux stream for use by syncing it with enqueued primary-stream work. Object destruction will sync again so future primary-stream work will wait on enqueued aux-stream work. If MXNET_GPU_WORKER_NSTREAMS == 1, then this defaults simply: the primary stream will equal the aux stream and the syncs will be executed as nops. See ./src/operator/cudnn/cudnn_convolution-inl.h for a usage example.
Definition: base.h:315
T & get_state() const
Definition: op_attr_types.h:174
engine::CallbackOnComplete async_on_complete
the callback when operation completes, used by asynchronous ops
Definition: op_attr_types.h:74
All the possible information needed by Operator.Forward and Backward. This is the superset of RunConte...
Definition: op_attr_types.h:66
std::vector< mxnet::TShape > ShapeVector
The result holder of shape of each NodeEntry in the graph.
Definition: tuple.h:793
QuantizeType
the quantization type of the operator
Definition: op_attr_types.h:135
std::function< ExecType(const NodeAttrs &attrs)> FExecType
Execution mode of this operator.
Definition: op_attr_types.h:235
virtual VarHandle NewVariable()=0
Allocate a new variable, the variable can then be used to schedule the operation concurrently via dep...
static OpStatePtr Create(Args &&...args)
Definition: op_attr_types.h:154
std::function< nnvm::NodePtr(const NodeAttrs &attrs)> FQuantizedOp
Register a quantized node creation function based on the attrs of the node.
Definition: op_attr_types.h:331
bool need_grad
whether there is a backward phase to compute gradients.
Definition: op_attr_types.h:68
SyncedGPUAuxStream get_gpu_aux_stream() const
get auxiliary gpu stream auto-syncing object from Context
Definition: op_attr_types.h:91
std::function< void(const nnvm::NodeAttrs &attrs, const std::vector< NDArray > &inputs, std::vector< NDArray > *outputs)> FNDArrayFunction
Register an operator called as a NDArray function.
Definition: op_attr_types.h:287
Global resource allocation handling.
virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var)=0
Schedule the deletion of a variable.
std::function< OpStatePtr(const NodeAttrs &attrs, Context ctx, const mxnet::ShapeVector &in_shape, const std::vector< int > &in_type)> FCreateOpState
Create a Layer style, forward/backward operator. This is easy to write code that contains state...
Definition: op_attr_types.h:220
perform an inplace write. This option only happens when Target shares memory with one of input argumen...
Definition: op_attr_types.h:55
A subgraph execution should happen in the main thread, instead of in the execution engine...
bool unique() const
Definition: op_attr_types.h:183
OpReqType
operation request type to Forward and Backward
Definition: op_attr_types.h:45
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< NDArray > &inputs, const std::vector< OpReqType > &req, const std::vector< NDArray > &outputs)> FComputeEx
Register an NDArray compute function for simple stateless forward only operator.
Definition: op_attr_types.h:307
std::vector< Resource > requested
Resources requested by the operator.
Definition: op_attr_types.h:76
RunContext run_ctx
RunContext related resources.
Definition: op_attr_types.h:72
static Context CPU(int32_t dev_id=0)
std::function< void(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector< TBlob > &inputs, const std::vector< OpReqType > &req, const std::vector< TBlob > &outputs)> FCompute
Register a compute function for simple stateless forward only operator.
Definition: op_attr_types.h:297
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:73
SyncedGPUAuxStream get_gpu_aux_stream() const
get an RAII object that transparently handles the syncing of the auxiliary stream.
Definition: base.h:379
static Engine * Get()
std::function< bool(const NodeAttrs &attrs, const int dev_mask, DispatchMode *dispatch_mode, std::vector< int > *in_attrs, std::vector< int > *out_attrs)> FInferStorageType
Register a storage and dispatch mode inference function based on storage types of the inputs and outp...
Definition: op_attr_types.h:319
std::function< std::vector< ResourceRequest >(const NodeAttrs &n, const int dev_mask, const DispatchMode dispatch_mode)> FResourceRequestEx
The resource request from the operator. An operator could register ResourceRequestEx, or ResourceRequest, or neither. If an operator registers both ResourceRequestEx and ResourceRequest, ResourceRequest is ignored.
Definition: op_attr_types.h:279
std::function< std::vector< int >(const NodeAttrs &attrs)> FNeedCalibrateInput
Register a function to determine if the input of a quantized operator needs to be calibrated...
Definition: op_attr_types.h:354
add to the provided space
Definition: op_attr_types.h:57
std::function< bool(const NodeAttrs &attrs, size_t index)> FAvoidQuantizeInput
Register a function to determine if the input of a quantized operator needs to be quantized...
Definition: op_attr_types.h:347
std::function< std::vector< int >(const NodeAttrs &attrs)> FNeedCalibrateOutput
Register a function to determine if the output of a quantized operator needs to be calibrated...
Definition: op_attr_types.h:361
ExecType
the execution type of the operator
Definition: op_attr_types.h:98
Context information about the execution environment.
Definition: base.h:102
mshadow::Stream< xpu > * get_stream() const
get mshadow stream from Context
Definition: op_attr_types.h:83
std::function< QuantizeType(const NodeAttrs &attrs)> FQuantizable
Register a function that reports the quantization type (QuantizeType) of an operator based on the attrs of the node.
Definition: op_attr_types.h:325
Data structures that can appear in operator attributes.
Operator state. This is a pointer type, its content is mutable even if OpStatePtr is const...
Definition: op_attr_types.h:148
computation stream structure, used for asynchronous computations
Definition: tensor.h:365