1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
29 #include <mshadow/tensor.h>
30 #include <nnvm/op_attr_types.h>
32 #include <vector>
33 #include <functional>
34 #include <string>
36 #include "./base.h"
37 #include "./ndarray.h"
38 #include "./engine.h"
39 #include "./resource.h"
41 namespace mxnet {
43 using nnvm::NodeAttrs;
// Request-type enumeration. Vectors of OpReqType are passed as the `req`
// argument of the compute-function signatures declared later in this file
// (FCompute, FComputeEx, FStatefulCompute, FStatefulComputeEx).
// NOTE(review): no enumerators are visible in this view of the file; the
// enumerator list appears to be missing. Confirm against the complete header.
46 enum OpReqType {
59 };
// Context object handed to operator compute functions: it is the type of the
// `ctx` parameter in the FCompute / FComputeEx / FStatefulCompute /
// FStatefulComputeEx signatures declared later in this file.
// NOTE(review): this view of the struct looks incomplete. `run_ctx` is used
// below but its declaration is not visible, and the second `return`/`}` pair
// and the `#endif` have no visible function header or matching `#if`.
// Confirm against the complete header before relying on this listing.
67 struct OpContext {
 // presumably: whether gradients need to be computed -- TODO confirm
69  bool need_grad;
 // presumably: whether the operator runs as part of training -- TODO confirm
71  bool is_train;
 // Resources available to the operator; presumably the ones asked for through
 // FResourceRequest / FResourceRequestEx -- verify against the callers.
77  std::vector<Resource> requested;
 // Returns the mshadow stream for device type `xpu` by forwarding to
 // run_ctx.get_stream<xpu>().
83  template<typename xpu>
84  inline mshadow::Stream<xpu>* get_stream() const {
85  return run_ctx.get_stream<xpu>();
86  }
 // NOTE(review): body of an accessor that forwards to
 // run_ctx.get_gpu_aux_stream(); its signature and the opening `#if` guard
 // are not visible here.
93  return run_ctx.get_gpu_aux_stream();
94  }
95 #endif
96 };
// Execution type of an operator; this is the return type of the FExecType
// callable declared later in this file.
// NOTE(review): only kSync and kAsync are visible in this view; confirm
// whether the complete header defines further enumerators after kAsync.
99 enum class ExecType {
 // presumably: the operator completes synchronously inside the call -- TODO confirm
101  kSync,
 // presumably: the operator completes asynchronously -- TODO confirm
106  kAsync,
120 };
// Dispatch mode of an operator: per the enumerator comments below, it selects
// which family of compute functions (FCompute vs. FComputeEx, or their
// stateful variants) an operator is dispatched to. It is passed to
// FResourceRequestEx and written through the `dispatch_mode` pointer of
// FInferStorageType, both declared later in this file.
123 enum class DispatchMode {
 // sentinel: no dispatch mode has been chosen (the only enumerator with an
 // explicit value)
124  kUndefined = -1,
125  // dispatch on FCompute or FStatefulCompute
126  kFCompute,
127  // dispatch on FComputeEx or FStatefulComputeEx, if available
128  kFComputeEx,
129  // dispatch on FCompute or FStatefulCompute, and performs storage fallback
 // NOTE(review): the enumerator that the comment above describes is not
 // visible in this view; confirm against the complete header, because it
 // affects the implicit value of kVariable.
131  // special dispatch mode for variables
132  kVariable,
133 };
/* \brief Quantization support level of an operator; returned by FQuantizable. */
enum class QuantizeType {
  /* \brief the operator does not support quantization */
  kNone = 0,
  /* \brief the operator benefits greatly from quantization and must be quantized */
  kMust = 1,
  /* \brief the operator supports quantization; whether it is actually
   *        quantized is decided from how it is connected */
  kSupport = 2,
};
149 class OpStatePtr {
150  public:
151  /* \brief Create a OpStatePtr with state of type T.
152  * \param args Arguments passed to T's constructor.
153  */
154  template<typename T, typename... Args>
155  static OpStatePtr Create(Args&&... args) {
156  OpStatePtr ret;
157  auto state = new T(std::forward<Args>(args)...);
158  auto var = Engine::Get()->NewVariable();
159  ret.ptr_.reset(
160  new OpState(var, state),
161  [](OpState* p) {
162  Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
163  delete reinterpret_cast<T*>(p->state);
164  delete p;
165  });
167  return ret;
168  }
169  /* \brief Get engine variable associated with this state */
171  return ptr_->var;
172  }
173  /* \brief Get state of type T */
174  template<typename T>
175  T& get_state() const {
176  return *reinterpret_cast<T*>(ptr_->state);
177  }
178  /* \brief clear state */
179  void reset() {
180  ptr_.reset();
181  }
182  /* \brief checks whether the managed object is managed only by the current
183  OpStatePtr instance */
184  bool unique() const {
185  return ptr_.unique();
186  }
187  /* \brief Whether state is empty */
188  explicit operator bool() const {
189  return ptr_ ? true : false;
190  }
192  private:
193  /* \brief state structure */
194  struct OpState {
195  engine::VarHandle var;
196  void* state;
198  OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
199  OpState(const OpState& other) = delete;
200  OpState& operator=(const OpState& other) = delete;
201  };
202  /* \brief shared pointer to state */
203  std::shared_ptr<OpState> ptr_;
204 };
218 using FCreateOpState = std::function<OpStatePtr (const NodeAttrs& attrs,
219  Context ctx,
220  const mxnet::ShapeVector& in_shape,
221  const std::vector<int>& in_type)>;
236 using FExecType = std::function<ExecType (const NodeAttrs& attrs)>;
244 using FStatefulCompute = std::function<void (const OpStatePtr& state,
245  const OpContext& ctx,
246  const std::vector<TBlob>& inputs,
247  const std::vector<OpReqType>& req,
248  const std::vector<TBlob>& outputs)>;
256 using FStatefulComputeEx = std::function<void (const OpStatePtr& state,
257  const OpContext& ctx,
258  const std::vector<NDArray>& inputs,
259  const std::vector<OpReqType>& req,
260  const std::vector<NDArray>& outputs)>;
267 using FResourceRequest = std::function<
268  std::vector<ResourceRequest> (const NodeAttrs& n)>;
277 using FResourceRequestEx = std::function<
278  std::vector<ResourceRequest> (const NodeAttrs& n,
279  const int dev_mask,
280  const DispatchMode dispatch_mode)>;
286 using FNDArrayFunction = std::function<void (const nnvm::NodeAttrs& attrs,
287  const std::vector<NDArray>& inputs,
288  std::vector<NDArray>* outputs)>;
294 using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
295  const OpContext& ctx,
296  const std::vector<TBlob>& inputs,
297  const std::vector<OpReqType>& req,
298  const std::vector<TBlob>& outputs)>;
304 using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
305  const OpContext& ctx,
306  const std::vector<NDArray>& inputs,
307  const std::vector<OpReqType>& req,
308  const std::vector<NDArray>& outputs)>;
316 using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
317  const int dev_mask,
318  DispatchMode* dispatch_mode,
319  std::vector<int>* in_attrs,
320  std::vector<int>* out_attrs)>;
326 using FQuantizable = std::function<QuantizeType (const NodeAttrs& attrs)>;
332 using FQuantizedOp = std::function<nnvm::ObjectPtr (const NodeAttrs& attrs)>;
340 using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
347 using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
348  const size_t index,
349  const std::string quantize_granularity)>;
356 using FNeedCalibrateInput = std::function<std::vector<int> (const NodeAttrs& attrs)>;
363 using FNeedCalibrateOutput = std::function<std::vector<int> (const NodeAttrs& attrs)>;
365 } // namespace mxnet
367 #endif // MXNET_OP_ATTR_TYPES_H_
