mxnet
engine.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_ENGINE_H_
25 #define MXNET_ENGINE_H_
26 
27 #include <dmlc/base.h>
28 #if DMLC_USE_CXX11
29 #include <algorithm>
30 #include <memory>
31 #include <functional>
32 #endif
33 #include <vector>
34 #include "./base.h"
35 
36 namespace mxnet {
37 
38 // forward declare engine
39 class Engine;
40 
42 namespace engine {
/*! \brief Opaque engine variable; only forward-declared in this header. */
struct Var;
/*! \brief Opaque engine operator; only forward-declared in this header. */
struct Opr;
/*! \brief Variable pointer type, usually held by the user and used to specify dependencies. */
typedef Var* VarHandle;
/*! \brief Operator pointer type, usually held by the user. */
typedef Opr* OprHandle;
56  public:
57  // use implicit copy and assign
59  inline void operator()() const {
60  (*callback_)(engine_, param_);
61  }
62 
63  private:
65  friend class ::mxnet::Engine;
67  void (*callback_)(Engine *, void *);
69  Engine* engine_;
71  void* param_;
72 };
73 } // namespace engine
74 
75 #if DMLC_USE_CXX11
76 
/*!
 * \brief Function property, used to hint what action is pushed to engine.
 *
 * NOTE(review): the listing this enum was recovered from skips one line after
 * each of kNormal, kCopyToGPU and kAsync, and its symbol index also describes
 * "Copy operation from GPU to other devices", "Prioritized sync operation on
 * CPU" and "Delete variable call". Enumerators for those are probably missing
 * here; confirm against the upstream header before relying on the numeric
 * values.
 */
enum class FnProperty {
  /*! \brief Normal operation. */
  kNormal,
  /*! \brief Copy operation from CPU to other devices. */
  kCopyToGPU,
  /*! \brief Asynchronous function call. */
  kAsync,
};  // enum class FnProperty
91 
96  public:
100  typedef std::function<void(RunContext)> SyncFn;
102  typedef std::function<void(RunContext, CallbackOnComplete)> AsyncFn;
114  virtual void NotifyShutdown() = 0;
121  virtual VarHandle NewVariable() = 0;
133  virtual OprHandle NewOperator(AsyncFn fn,
134  std::vector<VarHandle> const& const_vars,
135  std::vector<VarHandle> const& mutable_vars,
137  const char* opr_name = nullptr) = 0;
145  virtual void DeleteOperator(OprHandle op) = 0;
153  virtual void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) = 0;
167  virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
168  std::vector<VarHandle> const& const_vars,
169  std::vector<VarHandle> const& mutable_vars,
171  int priority = 0,
172  const char* opr_name = nullptr) = 0;
184  virtual void DeleteVariable(SyncFn delete_fn,
185  Context exec_ctx,
186  VarHandle var) = 0;
192  virtual void WaitForVar(VarHandle var) = 0;
196  virtual void WaitForAll() = 0;
198  virtual ~Engine() noexcept(false) {}
202  static Engine* Get();
211  static std::shared_ptr<Engine> _GetSharedRef();
224  inline void PushSync(SyncFn exec_fn, Context exec_ctx,
225  std::vector<VarHandle> const& const_vars,
226  std::vector<VarHandle> const& mutable_vars,
228  int priority = 0,
229  const char* opr_name = nullptr) {
230  this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
231  exec_fn(ctx);
232  on_complete();
233  }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
234  }
235 
241  inline CallbackOnComplete CreateCallback(
242  void (*callback)(Engine *, void *), void *param) {
243  CallbackOnComplete ret;
244  ret.callback_ = callback;
245  ret.engine_ = this;
246  ret.param_ = param;
247  return ret;
248  }
249  // For each var vector, sort it and remove the duplicated vars.
250  // Also remove vars from read_vars if it also appears in write_vars
251  inline void DeduplicateVarHandle(std::vector<engine::VarHandle> *read_vars,
252  std::vector<engine::VarHandle> *write_vars) {
253  std::sort(write_vars->begin(), write_vars->end());
254  write_vars->resize(std::unique(write_vars->begin(), write_vars->end()) -
255  write_vars->begin());
256  std::sort(read_vars->begin(), read_vars->end());
257  read_vars->resize(std::unique(read_vars->begin(), read_vars->end()) -
258  read_vars->begin());
259  auto wit = write_vars->begin();
260  auto rtop = read_vars->begin();
261  for (auto rit = read_vars->begin(); rit != read_vars->end(); ++rit) {
262  while (wit != write_vars->end() && *wit < *rit) ++wit;
263  if (wit == write_vars->end() || *wit != *rit) {
264  *rtop = *rit;
265  ++rtop;
266  }
267  }
268  read_vars->resize(rtop - read_vars->begin());
269  }
270 
274  virtual int num_omp_threads_per_worker() const = 0;
275 
279  virtual void set_num_omp_threads_per_worker(int num_omp_threads_per_worker) = 0;
280 }; // class Engine
281 #endif // DMLC_USE_CXX11
282 } // namespace mxnet
283 #endif // MXNET_ENGINE_H_
void DeduplicateVarHandle(std::vector< engine::VarHandle > *read_vars, std::vector< engine::VarHandle > *write_vars)
Definition: engine.h:251
FnProperty
Function property, used to hint what action is pushed to engine.
Definition: engine.h:77
std::function< void(RunContext)> SyncFn
Synchronous operation to pass to engine.
Definition: engine.h:100
std::function< void(RunContext, CallbackOnComplete)> AsyncFn
Asynchronous operation to pass to engine.
Definition: engine.h:102
Symbol sort(const std::string &symbol_name, Symbol data, dmlc::optional< int > axis=dmlc::optional< int >(-1), bool is_ascend=1)
Definition: op.h:2487
namespace of mxnet
Definition: base.h:126
Asynchronous function call.
CallbackOnComplete CreateCallback(void(*callback)(Engine *, void *), void *param)
factory function to create OnComplete callback.
Definition: engine.h:241
void operator()() const
invoke the callback
Definition: engine.h:59
Execution-time context: the information needed at runtime for actual execution.
Definition: base.h:238
Delete variable call.
Normal operation.
virtual ~Engine() noexcept(false)
virtual destructor
Definition: engine.h:198
Copy operation from GPU to other devices.
engine::OprHandle OprHandle
Operator pointer.
Definition: engine.h:106
engine::VarHandle VarHandle
Variable pointer.
Definition: engine.h:104
Var * VarHandle
Variable pointer type, usually held by the user and used to specify dependencies.
Definition: engine.h:46
Prioritized sync operation on CPU.
engine::CallbackOnComplete CallbackOnComplete
callback on complete
Definition: engine.h:98
void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr)
Push a synchronous operation to the engine.
Definition: engine.h:224
Dependency engine that schedules operations.
Definition: engine.h:95
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:55
Context information about the execution environment.
Definition: base.h:141
#define MXNET_API
Define compatible keywords in g++. Used to support g++-4.6 and g++-4.7.
Definition: base.h:91
Copy operation from CPU to other devices.
Opr * OprHandle
Operator pointer type, usually held by the user.
Definition: engine.h:50