mxnet
engine.h
Go to the documentation of this file.
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
24 #ifndef MXNET_ENGINE_H_
25 #define MXNET_ENGINE_H_
26 
27 #if DMLC_USE_CXX11
28 #include <algorithm>
29 #include <memory>
30 #include <functional>
31 #endif
32 #include <utility>
33 #include <vector>
34 #include "./base.h"
35 
36 namespace mxnet {
37 
38 // forward declare engine
39 class Engine;
40 
42 namespace engine {
43 #if MXNET_USE_CUDA
44 /* \brief The class wrapping CUDA event with timing disabled. */
45 class CUDAEvent final {
46  public:
47  explicit CUDAEvent(Context const& ctx);
48 
49  CUDAEvent(CUDAEvent&& other) : event_(other.event_), dev_id_(other.dev_id_) {
50  other.event_ = nullptr;
51  }
52 
53  CUDAEvent(const CUDAEvent& other) = delete;
54  void operator=(const CUDAEvent& other) = delete;
55 
56  ~CUDAEvent();
57 
58  inline std::weak_ptr<cudaEvent_t> GetEvent() noexcept {
59  return event_;
60  }
61 
62  private:
63  std::shared_ptr<cudaEvent_t> event_;
64  int dev_id_;
65 };
66 
67 class CUDAEventPool final {
68  public:
69  explicit CUDAEventPool(Context const& ctx) : counter_(0) {
70  for (size_t i = 0; i < kPoolSize; ++i) {
71  events_.emplace_back(ctx);
72  }
73  }
74 
75  inline std::weak_ptr<cudaEvent_t> GetEvent(size_t i) noexcept {
76  return events_.at(i).GetEvent();
77  }
78 
79  inline std::pair<std::weak_ptr<cudaEvent_t>, uint64_t> GetNextEvent() noexcept {
80  uint64_t c = counter_++;
81  return {events_.at((c) % kPoolSize).GetEvent(), c};
82  }
83 
84  inline uint64_t GetCounterValue() noexcept {
85  return counter_.load();
86  }
87 
88  private:
89  static constexpr size_t kPoolSize = 64;
90  std::vector<CUDAEvent> events_;
91  std::atomic<uint64_t> counter_;
92 };
93 
95 struct EventInfo {
96  std::weak_ptr<cudaEvent_t> event;
97  cudaStream_t stream;
98  uint64_t pool_index;
99 };
101 struct SyncObject {
102  // vector can carry multiple reader events
103  std::vector<EventInfo> reader_events;
104  // vector should carry only 1 writer event
105  std::vector<EventInfo> writer_event;
106  std::mutex mutex;
107 };
108 #endif
109 
/*! \brief Base class of engine variables. */
struct Var {
  /*! \brief Current value of version_ (overridable by derived variables). */
  virtual size_t version() {
    return version_;
  }
  virtual ~Var() = default;
  /*!
   * \brief Cast variable to derived type T (defined out of line).
   * \tparam T the type to cast into.
   * \return the variable viewed as T.
   */
  template <typename T>
  inline T* Cast();
  /*!
   * \brief Version number of the var; updated every time the object it is
   *        associated with is modified. Starts at 0.
   */
  size_t version_{0};
#if MXNET_USE_CUDA
  /*! \brief Struct containing cuda events and variables needed for the dependencies. */
  SyncObject sync_object;
#endif
};  // struct Var

/*! \brief Operator, defined by the engine implementation. */
struct Opr;
/*! \brief Variable pointer type, usually held by the user and used to specify dependencies. */
using VarHandle = Var*;
/*! \brief Operator pointer type, usually held by the user. */
using OprHandle = Opr*;
147  public:
148  // use implicit copy and assign
150  inline void operator()(const dmlc::Error* error = nullptr) const {
151  if (callback_ != nullptr)
152  (*callback_)(engine_, param_, error);
153  }
154 
155  private:
157  friend class ::mxnet::Engine;
159  void (*callback_)(Engine*, void*, const dmlc::Error*);
161  Engine* engine_;
163  void* param_;
164 };
170  public:
171  // use implicit copy and assign
173  inline void operator()(const dmlc::Error* error = nullptr) const {
174  (*callback_)(engine_, param_, error);
175  }
176 
177  private:
179  friend class ::mxnet::Engine;
181  void (*callback_)(Engine*, void*, const dmlc::Error*);
183  Engine* engine_;
185  void* param_;
186 };
187 } // namespace engine
188 
189 #if DMLC_USE_CXX11
190 
/*! \brief Function property, used to hint what action is pushed to engine. */
enum class FnProperty {
  /*! \brief Normal operation. */
  kNormal,
  /*! \brief Copy operation from GPU to other devices. */
  kCopyFromGPU,
  /*! \brief Copy operation from CPU to other devices. */
  kCopyToGPU,
  /*! \brief Prioritized sync operation on CPU. */
  kCPUPrioritized,
  /*! \brief Asynchronous function call. */
  kAsync,
  /*! \brief Delete variable call. */
  kDeleteVar,
  /*! \brief Prioritized sync operation on GPU. */
  kGPUPrioritized,
  /*! \brief Operation not to be skipped even with associated exception. */
  kNoSkip
};  // enum class FnProperty
209 
214  public:
220  typedef std::function<void(RunContext)> SyncFn;
222  typedef std::function<void(RunContext, CallbackOnStart, CallbackOnComplete)> AsyncFn;
234  virtual void NotifyShutdown() = 0;
238  virtual void Stop() {
239  LOG(FATAL) << "Engine cannot be stopped";
240  }
244  virtual void Start() {
245  LOG(FATAL) << "Engine cannot be restarted";
246  }
253  virtual VarHandle NewVariable() = 0;
266  virtual OprHandle NewOperator(AsyncFn fn,
267  std::vector<VarHandle> const& const_vars,
268  std::vector<VarHandle> const& mutable_vars,
270  const char* opr_name = nullptr,
271  bool wait = false) = 0;
279  virtual void DeleteOperator(OprHandle op) = 0;
287  virtual void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) = 0;
302  virtual void PushAsync(AsyncFn exec_fun,
303  Context exec_ctx,
304  std::vector<VarHandle> const& const_vars,
305  std::vector<VarHandle> const& mutable_vars,
307  int priority = 0,
308  const char* opr_name = nullptr,
309  bool wait = false) = 0;
321  virtual void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) = 0;
327  virtual void WaitForVar(VarHandle var) = 0;
331  virtual void WaitForAll() = 0;
333  virtual void Throw(VarHandle var) = 0;
335  virtual ~Engine() noexcept(false) {}
339  static Engine* Get();
348  static const std::shared_ptr<Engine>& _GetSharedRef();
361  virtual void PushSync(SyncFn exec_fn,
362  Context exec_ctx,
363  std::vector<VarHandle> const& const_vars,
364  std::vector<VarHandle> const& mutable_vars,
366  int priority = 0,
367  const char* opr_name = nullptr) {
368  this->PushAsync(
369  [exec_fn](RunContext ctx, CallbackOnStart on_start, CallbackOnComplete on_complete) {
370  on_start();
371  exec_fn(ctx);
372  on_complete();
373  },
374  exec_ctx,
375  const_vars,
376  mutable_vars,
377  prop,
378  priority,
379  opr_name);
380  }
381 
387  inline CallbackOnStart CreateOnStart(void (*callback)(Engine*, void*, const dmlc::Error*),
388  void* param) {
389  CallbackOnStart ret;
390  ret.callback_ = callback;
391  ret.engine_ = this;
392  ret.param_ = param;
393  return ret;
394  }
395 
401  inline CallbackOnComplete CreateCallback(void (*callback)(Engine*, void*, const dmlc::Error*),
402  void* param) {
403  CallbackOnComplete ret;
404  ret.callback_ = callback;
405  ret.engine_ = this;
406  ret.param_ = param;
407  return ret;
408  }
409  // For each var vector, sort it and remove the duplicated vars.
410  // Also remove vars from read_vars if it also appears in write_vars
411  inline void DeduplicateVarHandle(std::vector<engine::VarHandle>* read_vars,
412  std::vector<engine::VarHandle>* write_vars) {
413  std::sort(write_vars->begin(), write_vars->end());
414  write_vars->resize(std::unique(write_vars->begin(), write_vars->end()) - write_vars->begin());
415  std::sort(read_vars->begin(), read_vars->end());
416  read_vars->resize(std::unique(read_vars->begin(), read_vars->end()) - read_vars->begin());
417  auto wit = write_vars->begin();
418  auto rtop = read_vars->begin();
419  for (auto rit = read_vars->begin(); rit != read_vars->end(); ++rit) {
420  while (wit != write_vars->end() && *wit < *rit)
421  ++wit;
422  if (wit == write_vars->end() || *wit != *rit) {
423  *rtop = *rit;
424  ++rtop;
425  }
426  }
427  read_vars->resize(rtop - read_vars->begin());
428  }
430  virtual int bulk_size() const {
431  return 0;
432  }
434  virtual int set_bulk_size(int) {
435  return 0;
436  }
437 }; // class Engine
438 #endif // DMLC_USE_CXX11
439 } // namespace mxnet
440 #endif // MXNET_ENGINE_H_
mxnet
namespace of mxnet
Definition: api_registry.h:33
mxnet::Engine::CallbackOnStart
engine::CallbackOnStart CallbackOnStart
on start
Definition: engine.h:216
mxnet::engine::CUDAEvent::operator=
void operator=(const CUDAEvent &other)=delete
mxnet::engine::CallbackOnComplete
OnComplete callback to the engine, called by AsyncFn when the action completes.
Definition: engine.h:169
mxnet::engine::CallbackOnStart
OnStart callback to the engine, called by AsyncFn before the action.
Definition: engine.h:146
mxnet::FnProperty::kNormal
@ kNormal
Normal operation.
mxnet::engine::Var::version_
size_t version_
version number of the var. Every time the object it is associated with is modified,...
Definition: engine.h:127
mxnet::engine::CUDAEvent::GetEvent
std::weak_ptr< cudaEvent_t > GetEvent() noexcept
Definition: engine.h:58
mxnet::Engine::CallbackOnComplete
engine::CallbackOnComplete CallbackOnComplete
callback on complete
Definition: engine.h:218
mxnet::engine::EventInfo::stream
cudaStream_t stream
Definition: engine.h:97
mxnet::engine::EventInfo
full event info for the sync object.
Definition: engine.h:95
mxnet::engine::EventInfo::pool_index
uint64_t pool_index
Definition: engine.h:98
mxnet::FnProperty::kDeleteVar
@ kDeleteVar
Delete variable call.
mxnet::engine::CUDAEventPool::CUDAEventPool
CUDAEventPool(Context const &ctx)
Definition: engine.h:69
mxnet::RunContext
execution time context. The information needed in runtime for actual execution.
Definition: base.h:343
mxnet::FnProperty::kAsync
@ kAsync
Asynchronous function call.
mxnet::engine::CUDAEventPool::GetNextEvent
std::pair< std::weak_ptr< cudaEvent_t >, uint64_t > GetNextEvent() noexcept
Definition: engine.h:79
mxnet::engine::CUDAEvent::CUDAEvent
CUDAEvent(CUDAEvent &&other)
Definition: engine.h:49
mxnet::Engine::Stop
virtual void Stop()
Stop all workers in the engine.
Definition: engine.h:238
mxnet::Engine::~Engine
virtual ~Engine() noexcept(false)
virtual destructor
Definition: engine.h:335
mxnet::FnProperty::kNoSkip
@ kNoSkip
Operation not to be skipped even with associated exception.
mxnet::Engine::SyncFn
std::function< void(RunContext)> SyncFn
Synchronous operation to pass to engine.
Definition: engine.h:220
mxnet::Engine::AsyncFn
std::function< void(RunContext, CallbackOnStart, CallbackOnComplete)> AsyncFn
Asynchronous operation to pass to engine.
Definition: engine.h:222
mxnet::engine::Var::~Var
virtual ~Var()=default
mxnet::Engine::Start
virtual void Start()
Restart all workers in the engine.
Definition: engine.h:244
mxnet::engine::OprHandle
Opr * OprHandle
Operator pointer type, usually held by the user.
Definition: engine.h:141
mxnet::engine::CUDAEvent::CUDAEvent
CUDAEvent(Context const &ctx)
mxnet::FnProperty::kGPUPrioritized
@ kGPUPrioritized
Prioritized sync operation on GPU.
mxnet::Engine::OprHandle
engine::OprHandle OprHandle
Operator pointer.
Definition: engine.h:226
mxnet::Engine
Dependency engine that schedules operations.
Definition: engine.h:213
mxnet::engine::CallbackOnComplete::operator()
void operator()(const dmlc::Error *error=nullptr) const
invoke the callback
Definition: engine.h:173
mxnet::FnProperty::kCPUPrioritized
@ kCPUPrioritized
Prioritized sync operation on CPU.
mxnet::Context
Context information about the execution environment.
Definition: base.h:90
mxnet::Engine::PushSync
virtual void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr)
Push a synchronous operation to the engine.
Definition: engine.h:361
mxnet::engine::SyncObject::mutex
std::mutex mutex
Definition: engine.h:106
MXNET_API
#define MXNET_API
define dllexport for Visual Studio
Definition: base.h:49
mxnet::Engine::bulk_size
virtual int bulk_size() const
query current limit for bulk size
Definition: engine.h:430
mxnet::engine::CUDAEvent::~CUDAEvent
~CUDAEvent()
mxnet::engine::VarHandle
Var * VarHandle
Variable pointer type, usually held by the user and used to specify dependencies.
Definition: engine.h:137
mxnet::engine::Var
base class of engine variables.
Definition: engine.h:111
mxnet::engine::CUDAEventPool::GetEvent
std::weak_ptr< cudaEvent_t > GetEvent(size_t i) noexcept
Definition: engine.h:75
mxnet::engine::EventInfo::event
std::weak_ptr< cudaEvent_t > event
Definition: engine.h:96
mxnet::engine::SyncObject::writer_event
std::vector< EventInfo > writer_event
Definition: engine.h:105
mxnet::Engine::DeduplicateVarHandle
void DeduplicateVarHandle(std::vector< engine::VarHandle > *read_vars, std::vector< engine::VarHandle > *write_vars)
Definition: engine.h:411
mxnet::Engine::CreateOnStart
CallbackOnStart CreateOnStart(void(*callback)(Engine *, void *, const dmlc::Error *), void *param)
factory function to create OnStart callback.
Definition: engine.h:387
mxnet::engine::CUDAEventPool
Definition: engine.h:67
mxnet::Engine::VarHandle
engine::VarHandle VarHandle
Variable pointer.
Definition: engine.h:224
mxnet::engine::SyncObject::reader_events
std::vector< EventInfo > reader_events
Definition: engine.h:103
mxnet::Engine::set_bulk_size
virtual int set_bulk_size(int)
set maximum limit for bulk size
Definition: engine.h:434
mxnet::engine::Var::Cast
T * Cast()
cast variable to derived type T
mxnet::Engine::CreateCallback
CallbackOnComplete CreateCallback(void(*callback)(Engine *, void *, const dmlc::Error *), void *param)
factory function to create OnComplete callback.
Definition: engine.h:401
mxnet::FnProperty::kCopyToGPU
@ kCopyToGPU
Copy operation from CPU to other devices.
mxnet::engine::Var::version
virtual size_t version()
Definition: engine.h:112
base.h
configuration of MXNet as well as basic data structure.
mxnet::engine::SyncObject
struct containing cuda events and variables needed for the dependencies.
Definition: engine.h:101
mxnet::FnProperty::kCopyFromGPU
@ kCopyFromGPU
Copy operation from GPU to other devices.
mxnet::FnProperty
FnProperty
Function property, used to hint what action is pushed to engine.
Definition: engine.h:191
mxnet::engine::Var::sync_object
SyncObject sync_object
struct containing cuda events and variables needed for the dependencies.
Definition: engine.h:132
mxnet::engine::CUDAEvent
Definition: engine.h:45
mxnet::engine::CUDAEventPool::GetCounterValue
uint64_t GetCounterValue() noexcept
Definition: engine.h:84
mxnet::engine::CallbackOnStart::operator()
void operator()(const dmlc::Error *error=nullptr) const
invoke the callback
Definition: engine.h:150