1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
#include <mshadow/tensor.h>
#include <nnvm/op_attr_types.h>

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "./base.h"
#include "./ndarray.h"
#include "./engine.h"
#include "./resource.h"
40 namespace mxnet {
42 using nnvm::NodeAttrs;
/*!
 * \brief Write request handed to compute functions as the `req` argument of
 *  FCompute, FComputeEx, FStatefulCompute and FStatefulComputeEx.
 *  NOTE(review): presumably one entry per output, describing how that output
 *  should be written — confirm against callers.
 *  NOTE(review): no enumerators are declared in this view of the file, so no
 *  request value can actually be named; confirm the enumerator list was not
 *  lost.
 */
enum OpReqType {
};
66 struct OpContext {
68  bool need_grad;
70  bool is_train;
76  std::vector<Resource> requested;
82  template<typename xpu>
83  inline mshadow::Stream<xpu>* get_stream() const {
84  return run_ctx.get_stream<xpu>();
85  }
86 };
/*!
 * \brief Execution type of an operator, as reported by FExecType.
 *  NOTE(review): only kSync and kAsync are visible in this view of the file;
 *  confirm no further enumerators are expected.
 */
enum class ExecType {
  // NOTE(review): name suggests the operator finishes before its call returns — confirm.
  kSync,
  // NOTE(review): name suggests completion is signalled after the call returns — confirm.
  kAsync,
};
/*!
 * \brief Which compute function an operator is dispatched to.
 *
 * FInferStorageType receives a DispatchMode* it may write to, and
 * FResourceRequestEx receives the mode as an argument.
 */
enum class DispatchMode {
  kUndefined = -1,
  // dispatch on FCompute or FStatefulCompute
  kFCompute,
  // dispatch on FComputeEx or FStatefulComputeEx, if available
  kFComputeEx,
  // dispatch on FCompute or FStatefulCompute, and performs storage fallback
  kFComputeFallback,
  // special dispatch mode for variables
  kVariable,
};
129 class OpStatePtr {
130  public:
131  /* \brief Create a OpStatePtr with state of type T.
132  * \param args Arguments passed to T's constructor.
133  */
134  template<typename T, typename... Args>
135  static OpStatePtr Create(Args&&... args) {
136  OpStatePtr ret;
137  auto state = new T(std::forward<Args>(args)...);
138  auto var = Engine::Get()->NewVariable();
139  ret.ptr_.reset(
140  new OpState(var, state),
141  [](OpState* p) {
142  Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
143  delete reinterpret_cast<T*>(p->state);
144  delete p;
145  });
147  return ret;
148  }
149  /* \brief Get engine variable associated with this state */
151  return ptr_->var;
152  }
153  /* \brief Get state of type T */
154  template<typename T>
155  T& get_state() const {
156  return *reinterpret_cast<T*>(ptr_->state);
157  }
158  /* \brief clear state */
159  void reset() {
160  ptr_.reset();
161  }
162  /* \brief checks whether the managed object is managed only by the current
163  OpStatePtr instance */
164  bool unique() const {
165  return ptr_.unique();
166  }
167  /* \brief Whether state is empty */
168  explicit operator bool() const {
169  return ptr_ ? true : false;
170  }
172  private:
173  /* \brief state structure */
174  struct OpState {
175  engine::VarHandle var;
176  void* state;
178  OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
179  OpState(const OpState& other) = delete;
180  OpState& operator=(const OpState& other) = delete;
181  };
182  /* \brief shared pointer to state */
183  std::shared_ptr<OpState> ptr_;
184 };
198 using FCreateOpState = std::function<OpStatePtr (const NodeAttrs& attrs,
199  Context ctx,
200  const std::vector<TShape>& in_shape,
201  const std::vector<int>& in_type)>;
205 using FExecType = std::function<ExecType (const NodeAttrs& attrs)>;
213 using FStatefulCompute = std::function<void (const OpStatePtr& state,
214  const OpContext& ctx,
215  const std::vector<TBlob>& inputs,
216  const std::vector<OpReqType>& req,
217  const std::vector<TBlob>& outputs)>;
225 using FStatefulComputeEx = std::function<void (const OpStatePtr& state,
226  const OpContext& ctx,
227  const std::vector<NDArray>& inputs,
228  const std::vector<OpReqType>& req,
229  const std::vector<NDArray>& outputs)>;
236 using FResourceRequest = std::function<
237  std::vector<ResourceRequest> (const NodeAttrs& n)>;
244 using FResourceRequestEx = std::function<
245  std::vector<ResourceRequest> (const NodeAttrs& n,
246  const int dev_mask,
247  const DispatchMode dispatch_mode)>;
253 using FNDArrayFunction = std::function<void (const nnvm::NodeAttrs& attrs,
254  const std::vector<NDArray>& inputs,
255  std::vector<NDArray>* outputs)>;
261 using FCompute = std::function<void (const nnvm::NodeAttrs& attrs,
262  const OpContext& ctx,
263  const std::vector<TBlob>& inputs,
264  const std::vector<OpReqType>& req,
265  const std::vector<TBlob>& outputs)>;
271 using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
272  const OpContext& ctx,
273  const std::vector<NDArray>& inputs,
274  const std::vector<OpReqType>& req,
275  const std::vector<NDArray>& outputs)>;
283 using FInferStorageType = std::function<bool (const NodeAttrs& attrs,
284  const int dev_mask,
285  DispatchMode* dispatch_mode,
286  std::vector<int>* in_attrs,
287  std::vector<int>* out_attrs)>;
293 using FQuantizedOp = std::function<nnvm::NodePtr (const NodeAttrs& attrs)>;
301 using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
308 using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
309  size_t index)>;
311 } // namespace mxnet
313 #endif // MXNET_OP_ATTR_TYPES_H_
