Go to the documentation of this file.
24 #ifndef MXNET_ENGINE_H_
25 #define MXNET_ENGINE_H_
50 other.event_ =
nullptr;
58 inline std::weak_ptr<cudaEvent_t>
GetEvent() noexcept {
63 std::shared_ptr<cudaEvent_t> event_;
70 for (
size_t i = 0; i < kPoolSize; ++i) {
71 events_.emplace_back(ctx);
75 inline std::weak_ptr<cudaEvent_t>
GetEvent(
size_t i) noexcept {
76 return events_.at(i).GetEvent();
79 inline std::pair<std::weak_ptr<cudaEvent_t>, uint64_t>
GetNextEvent() noexcept {
80 uint64_t c = counter_++;
81 return {events_.at((c) % kPoolSize).GetEvent(), c};
85 return counter_.load();
89 static constexpr
size_t kPoolSize = 64;
90 std::vector<CUDAEvent> events_;
91 std::atomic<uint64_t> counter_;
96 std::weak_ptr<cudaEvent_t>
event;
115 virtual ~Var() =
default;
121 template <
typename T>
150 inline void operator()(
const dmlc::Error* error =
nullptr)
const {
151 if (callback_ !=
nullptr)
152 (*callback_)(engine_, param_, error);
157 friend class ::mxnet::Engine;
159 void (*callback_)(
Engine*,
void*,
const dmlc::Error*);
173 inline void operator()(
const dmlc::Error* error =
nullptr)
const {
174 (*callback_)(engine_, param_, error);
179 friend class ::mxnet::Engine;
181 void (*callback_)(
Engine*,
void*,
const dmlc::Error*);
234 virtual void NotifyShutdown() = 0;
239 LOG(FATAL) <<
"Engine cannot be stopped";
245 LOG(FATAL) <<
"Engine cannot be restarted";
266 virtual OprHandle NewOperator(AsyncFn fn,
267 std::vector<VarHandle>
const& const_vars,
268 std::vector<VarHandle>
const& mutable_vars,
270 const char* opr_name =
nullptr,
271 bool wait =
false) = 0;
279 virtual void DeleteOperator(
OprHandle op) = 0;
287 virtual void Push(
OprHandle op,
Context exec_ctx,
int priority = 0,
bool profiling =
false) = 0;
302 virtual void PushAsync(AsyncFn exec_fun,
304 std::vector<VarHandle>
const& const_vars,
305 std::vector<VarHandle>
const& mutable_vars,
308 const char* opr_name =
nullptr,
309 bool wait =
false) = 0;
321 virtual void DeleteVariable(SyncFn delete_fn,
Context exec_ctx,
VarHandle var) = 0;
327 virtual void WaitForVar(
VarHandle var) = 0;
331 virtual void WaitForAll() = 0;
348 static const std::shared_ptr<Engine>& _GetSharedRef();
363 std::vector<VarHandle>
const& const_vars,
364 std::vector<VarHandle>
const& mutable_vars,
367 const char* opr_name =
nullptr) {
390 ret.callback_ = callback;
404 ret.callback_ = callback;
412 std::vector<engine::VarHandle>* write_vars) {
413 std::sort(write_vars->begin(), write_vars->end());
414 write_vars->resize(std::unique(write_vars->begin(), write_vars->end()) - write_vars->begin());
415 std::sort(read_vars->begin(), read_vars->end());
416 read_vars->resize(std::unique(read_vars->begin(), read_vars->end()) - read_vars->begin());
417 auto wit = write_vars->begin();
418 auto rtop = read_vars->begin();
419 for (
auto rit = read_vars->begin(); rit != read_vars->end(); ++rit) {
420 while (wit != write_vars->end() && *wit < *rit)
422 if (wit == write_vars->end() || *wit != *rit) {
427 read_vars->resize(rtop - read_vars->begin());
438 #endif // DMLC_USE_CXX11
440 #endif // MXNET_ENGINE_H_
namespace of mxnet
Definition: api_registry.h:33
engine::CallbackOnStart CallbackOnStart
callback on start
Definition: engine.h:216
void operator=(const CUDAEvent &other)=delete
OnComplete Callback to the engine, called by AsyncFn when action completes.
Definition: engine.h:169
OnStart callback to the engine, called by AsyncFn before the action.
Definition: engine.h:146
@ kNormal
Normal operation.
size_t version_
version number of the var. Every time the object it is associated with is modified,...
Definition: engine.h:127
std::weak_ptr< cudaEvent_t > GetEvent() noexcept
Definition: engine.h:58
engine::CallbackOnComplete CallbackOnComplete
callback on complete
Definition: engine.h:218
cudaStream_t stream
Definition: engine.h:97
full event info for the sync object.
Definition: engine.h:95
uint64_t pool_index
Definition: engine.h:98
@ kDeleteVar
Delete variable call.
CUDAEventPool(Context const &ctx)
Definition: engine.h:69
execution time context. The information needed in runtime for actual execution.
Definition: base.h:343
@ kAsync
Asynchronous function call.
std::pair< std::weak_ptr< cudaEvent_t >, uint64_t > GetNextEvent() noexcept
Definition: engine.h:79
CUDAEvent(CUDAEvent &&other)
Definition: engine.h:49
virtual void Stop()
Stop all workers in the engine.
Definition: engine.h:238
virtual ~Engine() noexcept(false)
virtual destructor
Definition: engine.h:335
@ kNoSkip
Operation not to be skipped even with associated exception.
std::function< void(RunContext)> SyncFn
Synchronous operation to pass to engine.
Definition: engine.h:220
std::function< void(RunContext, CallbackOnStart, CallbackOnComplete)> AsyncFn
Asynchronous operation to pass to engine.
Definition: engine.h:222
virtual void Start()
Restart all workers in the engine.
Definition: engine.h:244
Opr * OprHandle
Operator pointer type, usually held by the user.
Definition: engine.h:141
CUDAEvent(Context const &ctx)
@ kGPUPrioritized
Prioritized sync operation on GPU.
engine::OprHandle OprHandle
Operator pointer.
Definition: engine.h:226
Dependency engine that schedules operations.
Definition: engine.h:213
void operator()(const dmlc::Error *error=nullptr) const
invoke the callback
Definition: engine.h:173
@ kCPUPrioritized
Prioritized sync operation on CPU.
Context information about the execution environment.
Definition: base.h:90
virtual void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector< VarHandle > const &const_vars, std::vector< VarHandle > const &mutable_vars, FnProperty prop=FnProperty::kNormal, int priority=0, const char *opr_name=nullptr)
Push a synchronous operation to the engine.
Definition: engine.h:361
std::mutex mutex
Definition: engine.h:106
#define MXNET_API
define dllexport for Visual Studio
Definition: base.h:49
virtual int bulk_size() const
query current limit for bulk size
Definition: engine.h:430
Var * VarHandle
Variable pointer type, usually held by the user and used to specify dependencies.
Definition: engine.h:137
base class of engine variables.
Definition: engine.h:111
std::weak_ptr< cudaEvent_t > GetEvent(size_t i) noexcept
Definition: engine.h:75
std::weak_ptr< cudaEvent_t > event
Definition: engine.h:96
std::vector< EventInfo > writer_event
Definition: engine.h:105
void DeduplicateVarHandle(std::vector< engine::VarHandle > *read_vars, std::vector< engine::VarHandle > *write_vars)
Definition: engine.h:411
CallbackOnStart CreateOnStart(void(*callback)(Engine *, void *, const dmlc::Error *), void *param)
factory function to create OnStart callback.
Definition: engine.h:387
engine::VarHandle VarHandle
Variable pointer.
Definition: engine.h:224
std::vector< EventInfo > reader_events
Definition: engine.h:103
virtual int set_bulk_size(int)
set maximum limit for bulk size
Definition: engine.h:434
T * Cast()
cast variable to derived type T
CallbackOnComplete CreateCallback(void(*callback)(Engine *, void *, const dmlc::Error *), void *param)
factory function to create OnComplete callback.
Definition: engine.h:401
@ kCopyToGPU
Copy operation from CPU to other devices.
virtual size_t version()
Definition: engine.h:112
configuration of MXNet as well as basic data structure.
struct containing cuda events and variables needed for the dependencies.
Definition: engine.h:101
@ kCopyFromGPU
Copy operation from GPU to other devices.
FnProperty
Function property, used to hint what action is pushed to engine.
Definition: engine.h:191
SyncObject sync_object
struct containing cuda events and variables needed for the dependencies.
Definition: engine.h:132
uint64_t GetCounterValue() noexcept
Definition: engine.h:84
void operator()(const dmlc::Error *error=nullptr) const
invoke the callback
Definition: engine.h:150