mxnet
thread_group.h
Go to the documentation of this file.
1 
6 #ifndef DMLC_THREAD_GROUP_H_
7 #define DMLC_THREAD_GROUP_H_
8 
#include <dmlc/concurrentqueue.h>
// Declares dmlc::moodycamel::BlockingConcurrentQueue (used by BlockingQueueThread)
#include <dmlc/blockingconcurrentqueue.h>
#include <dmlc/logging.h>
#include <atomic>
#include <cerrno>
#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#if defined(DMLC_USE_CXX14) || __cplusplus > 201103L /* C++14 */
#include <shared_mutex>
#endif
#ifdef __linux__
#include <unistd.h>
#include <sys/syscall.h>
#endif
28 
29 namespace dmlc {
30 
/*!
 * \brief Simple manual-reset event gate which remains open after signalled
 *
 * Once signal() is called, every current and future wait() passes through
 * until reset() is called again.
 */
class ManualEvent {
 public:
  ManualEvent() : signaled_(false) {}

  /*!
   * \brief Wait for the object to become signaled.
   *
   * Returns immediately if the object is already signaled and reset() has not
   * been called since. Uses a predicate wait so that spurious wakeups of the
   * condition variable cannot release the caller early.
   */
  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    condition_variable_.wait(lock, [this] { return signaled_.load(); });
  }

  /*!
   * \brief Set this object's state to signaled (wait() will release or pass through)
   *
   * The flag is written while holding the mutex so a waiter cannot observe a
   * stale value between its predicate check and going to sleep. The notify is
   * also issued under the lock so no waiter can return (and possibly destroy
   * this object) before signal() has finished touching it.
   */
  void signal() {
    std::lock_guard<std::mutex> lock(mutex_);
    signaled_ = true;
    condition_variable_.notify_all();
  }

  /*!
   * \brief Manually reset this object's state to unsignaled (wait() will block)
   */
  void reset() {
    std::lock_guard<std::mutex> lock(mutex_);
    signaled_ = false;
  }

 private:
  /*! \brief Mutex protecting the signaled state for the condition variable */
  std::mutex mutex_;
  /*! \brief Condition variable that waiters block on */
  std::condition_variable condition_variable_;
  /*! \brief Whether the event is currently signaled */
  std::atomic<bool> signaled_;
};
74 
#if defined(DMLC_USE_CXX14) || __cplusplus > 201103L /* C++14 */
/*! \brief Reader/writer mutex: many concurrent readers or a single writer */
using SharedMutex = std::shared_timed_mutex;
/*! \brief Shared (read) lock on a SharedMutex */
using ReadLock = std::shared_lock<SharedMutex>;
#else
/*!
 * \brief Pre-C++14 fallback: a recursive mutex. Readers are serialized, but the
 *        owning thread may re-acquire it.
 */
using SharedMutex = std::recursive_mutex;
/*! \brief Without shared locking available, a "read" lock is an exclusive lock */
using ReadLock = std::unique_lock<SharedMutex>;
#endif
/*! \brief Exclusive (write) lock on a SharedMutex; the same type in both configurations */
using WriteLock = std::unique_lock<SharedMutex>;
90 
95 class ThreadGroup {
96  public:
101  class Thread {
102  public:
104  using SharedPtr = std::shared_ptr<Thread>;
105 
112  Thread(std::string threadName, ThreadGroup *owner, std::thread *thrd = nullptr)
113  : name_(std::move(threadName))
114  , thread_(thrd)
115  , ready_event_(std::make_shared<ManualEvent>())
116  , start_event_(std::make_shared<ManualEvent>())
117  , owner_(owner)
118  , shutdown_requested_(false)
119  , auto_remove_(false) {
120  CHECK_NOTNULL(owner);
121  }
122 
126  virtual ~Thread() {
127  const bool self_delete = is_current_thread();
128  if (!self_delete) {
129  request_shutdown();
130  internal_join(true);
131  }
132  WriteLock guard(thread_mutex_);
133  if (thread_.load()) {
134  std::thread *thrd = thread_.load();
135  thread_ = nullptr;
136  if (self_delete) {
137  thrd->detach();
138  }
139  delete thrd;
140  }
141  }
142 
149  const char *name() const {
150  return name_.c_str();
151  }
152 
167  template<typename StartFunction, typename ...Args>
168  static bool launch(std::shared_ptr<Thread> pThis,
169  bool autoRemove,
170  StartFunction start_function,
171  Args ...args);
172 
177  bool is_current_thread() const {
178  ReadLock guard(thread_mutex_);
179  return thread_.load() ? (thread_.load()->get_id() == std::this_thread::get_id()) : false;
180  }
181 
187  virtual void request_shutdown() {
188  shutdown_requested_ = true;
189  }
190 
197  virtual bool is_shutdown_requested() const {
198  return shutdown_requested_.load();
199  }
200 
207  bool is_auto_remove() const {
208  return auto_remove_;
209  }
210 
216  void make_joinable() {
217  auto_remove_ = false;
218  }
219 
224  bool joinable() const {
225  if (thread_.load()) {
226  CHECK_EQ(auto_remove_, false);
227  // be checked by searching the group or exit event.
228  return thread_.load()->joinable();
229  }
230  return false;
231  }
232 
237  void join() {
238  internal_join(false);
239  }
240 
245  std::thread::id get_id() const {
246  return thread_.load()->get_id();
247  }
248 
249  private:
254  void internal_join(bool auto_remove_ok) {
255  ReadLock guard(thread_mutex_);
256  // should be careful calling (or any function externally) this when in
257  // auto-remove mode
258  if (thread_.load() && thread_.load()->get_id() != std::thread::id()) {
259  std::thread::id someId;
260  if (!auto_remove_ok) {
261  CHECK_EQ(auto_remove_, false);
262  }
263  CHECK_NOTNULL(thread_.load());
264  if (thread_.load()->joinable()) {
265  thread_.load()->join();
266  } else {
267  LOG(WARNING) << "Thread " << name_ << " ( "
268  << thread_.load()->get_id() << " ) not joinable";
269  }
270  }
271  }
272 
282  template <typename StartFunction, typename ...Args>
283  static int entry_and_exit_f(std::shared_ptr<Thread> pThis,
284  StartFunction start_function,
285  Args... args);
287  std::string name_;
289  mutable SharedMutex thread_mutex_;
291  std::atomic<std::thread *> thread_;
293  std::shared_ptr<ManualEvent> ready_event_;
295  std::shared_ptr<ManualEvent> start_event_;
297  ThreadGroup *owner_;
299  std::atomic<bool> shutdown_requested_;
304  std::atomic<bool> auto_remove_;
305  };
306 
310  inline ThreadGroup()
311  : evEmpty_(std::make_shared<ManualEvent>()) {
312  evEmpty_->signal(); // Starts out empty
313  }
314 
319  virtual ~ThreadGroup() {
320  request_shutdown_all();
321  join_all();
322  }
323 
330  inline bool is_this_thread_in() const {
331  std::thread::id id = std::this_thread::get_id();
332  ReadLock guard(m_);
333  for (auto it = threads_.begin(), end = threads_.end(); it != end; ++it) {
334  std::shared_ptr<Thread> thrd = *it;
335  if (thrd->get_id() == id)
336  return true;
337  }
338  return false;
339  }
340 
346  inline bool is_thread_in(std::shared_ptr<Thread> thrd) const {
347  if (thrd) {
348  std::thread::id id = thrd->get_id();
349  ReadLock guard(m_);
350  for (auto it = threads_.begin(), end = threads_.end(); it != end; ++it) {
351  std::shared_ptr<Thread> thrd = *it;
352  if (thrd->get_id() == id)
353  return true;
354  }
355  return false;
356  } else {
357  return false;
358  }
359  }
360 
366  inline bool add_thread(std::shared_ptr<Thread> thrd) {
367  if (thrd) {
368  WriteLock guard(m_);
369  auto iter = name_to_thread_.find(thrd->name());
370  if (iter == name_to_thread_.end()) {
371  name_to_thread_.emplace(std::make_pair(thrd->name(), thrd));
372  CHECK_EQ(threads_.insert(thrd).second, true);
373  evEmpty_->reset();
374  return true;
375  }
376  }
377  return false;
378  }
379 
385  inline bool remove_thread(std::shared_ptr<Thread> thrd) {
386  if (thrd) {
387  WriteLock guard(m_);
388  auto iter = threads_.find(thrd);
389  if (iter != threads_.end()) {
390  name_to_thread_.erase(thrd->name());
391  threads_.erase(iter);
392  if (threads_.empty()) {
393  evEmpty_->signal();
394  }
395  return true;
396  }
397  }
398  return false;
399  }
400 
406  inline void join_all() {
407  CHECK_EQ(!is_this_thread_in(), true);
408  do {
409  std::unique_lock<std::mutex> lk(join_all_mtx_);
410  std::unordered_set<std::shared_ptr<Thread>> working_set;
411  {
412  ReadLock guard(m_);
413  for (auto iter = threads_.begin(), e_iter = threads_.end(); iter != e_iter; ++iter) {
414  if (!(*iter)->is_auto_remove()) {
415  working_set.emplace(*iter);
416  }
417  }
418  }
419  // Where possible, prefer to do a proper join rather than simply waiting for empty
420  // (easier to troubleshoot)
421  while (!working_set.empty()) {
422  std::shared_ptr<Thread> thrd;
423  thrd = *working_set.begin();
424  if (thrd->joinable()) {
425  thrd->join();
426  }
427  remove_thread(thrd);
428  working_set.erase(working_set.begin());
429  thrd.reset();
430  }
431  // Wait for auto-remove threads (if any) to complete
432  } while (0);
433  evEmpty_->wait();
434  CHECK_EQ(threads_.size(), 0);
435  }
436 
441  inline void request_shutdown_all(const bool make_all_joinable = true) {
442  std::unique_lock<std::mutex> lk(join_all_mtx_);
443  ReadLock guard(m_);
444  for (auto &thread : threads_) {
445  if (make_all_joinable) {
446  thread->make_joinable();
447  }
448  thread->request_shutdown();
449  }
450  }
451 
456  inline size_t size() const {
457  ReadLock guard(m_);
458  return threads_.size();
459  }
460 
465  inline bool empty() const {
466  ReadLock guard(m_);
467  return threads_.size() == 0;
468  }
469 
485  template<typename StartFunction, typename ThreadType = Thread, typename ...Args>
486  inline bool create(const std::string &threadName,
487  bool auto_remove,
488  StartFunction start_function,
489  Args... args) {
490  typename ThreadType::SharedPtr newThread(new ThreadType(threadName, this));
491  return Thread::launch(newThread, auto_remove, start_function, args...);
492  }
493 
499  inline std::shared_ptr<Thread> thread_by_name(const std::string& name) {
500  ReadLock guard(m_);
501  auto iter = name_to_thread_.find(name);
502  if (iter != name_to_thread_.end()) {
503  return iter->second;
504  }
505  return nullptr;
506  }
507 
508  private:
510  mutable SharedMutex m_;
512  mutable std::mutex join_all_mtx_;
514  std::unordered_set<std::shared_ptr<Thread>> threads_;
516  std::shared_ptr<ManualEvent> evEmpty_;
518  std::unordered_map<std::string, std::shared_ptr<Thread>> name_to_thread_;
519 };
520 
527 template<typename ObjectType, ObjectType quit_item>
530 
531  public:
538  BlockingQueueThread(const std::string& name,
539  dmlc::ThreadGroup *owner,
540  std::thread *thrd = nullptr)
541  : ThreadGroup::Thread(std::move(name), owner, thrd)
542  , shutdown_in_progress_(false) {
543  }
544 
545 
549  ~BlockingQueueThread() override {
550  // Call to parent first because we don't want to wait for the queue to empty
552  request_shutdown();
553  }
554 
562  void request_shutdown() override {
563  shutdown_in_progress_ = true;
564  while (queue_->size_approx() > 0 && !ThreadGroup::Thread::is_shutdown_requested()) {
565  std::this_thread::sleep_for(std::chrono::milliseconds(1));
566  }
568  queue_->enqueue(quit_item);
569  }
570 
575  void enqueue(const ObjectType& item) {
576  if (!shutdown_in_progress_) {
577  queue_->enqueue(item);
578  }
579  }
580 
585  size_t size_approx() const { return queue_->size_approx(); }
586 
597  template<typename SecondaryFunction>
598  static bool launch_run(std::shared_ptr<BQT> pThis,
599  SecondaryFunction secondary_function) {
600  return ThreadGroup::Thread::launch(pThis, true, [](std::shared_ptr<BQT> pThis,
601  SecondaryFunction secondary_function) {
602  return pThis->run(secondary_function);
603  },
604  pThis, secondary_function);
605  }
606 
613  template<typename OnItemFunction>
614  inline int run(OnItemFunction on_item_function) {
615  int rc = 0;
616  do {
617  ObjectType item;
618  queue_->wait_dequeue(item);
619  if (item == quit_item) {
620  break;
621  }
622  rc = on_item_function(item);
623  if (rc) {
624  break;
625  }
626  } while (true);
627  return rc;
628  }
629 
630  private:
632  std::shared_ptr<dmlc::moodycamel::BlockingConcurrentQueue<ObjectType>> queue_ =
633  std::make_shared<dmlc::moodycamel::BlockingConcurrentQueue<ObjectType>>();
635  std::atomic<bool> shutdown_in_progress_;
636 };
637 
642 template<typename Duration>
645 
646  public:
652  TimerThread(const std::string& name, ThreadGroup *owner)
653  : Thread(name, owner) {
654  }
655 
659  ~TimerThread() override {
660  request_shutdown();
661  }
662 
673  template<typename SecondaryFunction>
674  static bool launch_run(std::shared_ptr<TimerThread<Duration>> pThis,
675  SecondaryFunction secondary_function) {
676  return ThreadGroup::Thread::launch(pThis, true, [](std::shared_ptr<TimerThread<Duration>> pThis,
677  SecondaryFunction secondary_function) {
678  return pThis->run(secondary_function);
679  },
680  pThis, secondary_function);
681  }
682 
692  template<typename Function>
693  static void start(std::shared_ptr<TimerThread> timer_thread,
694  Duration duration,
695  Function function) {
696  timer_thread->duration_ = duration;
697  launch_run(timer_thread, function);
698  }
699 
706  template<typename OnTimerFunction>
707  inline int run(OnTimerFunction on_timer_function) {
708  int rc = 0;
709  while (!is_shutdown_requested()) {
710  std::this_thread::sleep_for(duration_);
711  if (!is_shutdown_requested()) {
712  rc = on_timer_function();
713  }
714  }
715  return rc;
716  }
717 
718  private:
719  Duration duration_;
720 };
721 
722 /*
723  * Inline functions - see declarations for usage
724  */
725 template <typename StartFunction, typename ...Args>
726 inline int ThreadGroup::Thread::entry_and_exit_f(std::shared_ptr<Thread> pThis,
727  StartFunction start_function,
728  Args... args) {
729  int rc;
730  if (pThis) {
731  // Signal launcher that we're up and running
732  pThis->ready_event_->signal();
733  // Wait for launcher to be ready for us to start
734  pThis->start_event_->wait();
735  // Reset start_event_ for possible reuse
736  pThis->start_event_->reset(); // Reset in case it needs to be reused
737  // If we haven't been requested to shut down prematurely, then run the desired function
738  if (!pThis->is_shutdown_requested()) {
739  rc = start_function(args...);
740  } else {
741  rc = -1;
742  }
743  // If we're set up as auto-remove, then remove this thread from the thread group
744  if (pThis->is_auto_remove()) {
745  pThis->owner_->remove_thread(pThis);
746  }
747  // Release this thread shared pinter. May or may not be the last reference.
748  pThis.reset();
749  } else {
750  LOG(ERROR) << "Null pThis thread pointer";
751  rc = EINVAL;
752  }
753  return rc;
754 }
755 
756 template<typename StartFunction, typename ...Args>
757 inline bool ThreadGroup::Thread::launch(std::shared_ptr<Thread> pThis,
758  bool autoRemove,
759  StartFunction start_function,
760  Args ...args) {
761  WriteLock guard(pThis->thread_mutex_);
762  CHECK_EQ(!pThis->thread_.load(), true);
763  CHECK_NOTNULL(pThis->owner_);
764  // Set auto remove
765  pThis->auto_remove_ = autoRemove;
766  // Create the actual stl thread object
767  pThis->thread_ = new std::thread(Thread::template entry_and_exit_f<
768  StartFunction, Args...>,
769  pThis,
770  start_function,
771  args...);
772  // Attempt to add the thread to the thread group (after started, since in case
773  // something goes wrong, there's not a zombie thread in the thread group)
774  if (!pThis->owner_->add_thread(pThis)) {
775  pThis->request_shutdown();
776  LOG(ERROR) << "Duplicate thread name within the same thread group is not allowed";
777  }
778  // Wait for the thread to spin up
779  pThis->ready_event_->wait();
780  // Signal the thgread to continue (it will check its shutdown status)
781  pThis->start_event_->signal();
782  // Return if successful
783  return pThis->thread_.load() != nullptr;
784 }
785 
796 template<typename Duration, typename TimerFunction>
797 inline bool CreateTimer(const std::string& timer_name,
798  const Duration& duration,
799  ThreadGroup *owner,
800  TimerFunction timer_function) {
801  std::shared_ptr<dmlc::TimerThread<Duration>> timer_thread =
802  std::make_shared<dmlc::TimerThread<Duration>>(timer_name, owner);
803  dmlc::TimerThread<Duration>::start(timer_thread, duration, timer_function);
804  return timer_thread != nullptr;
805 }
806 } // namespace dmlc
807 
808 #endif // DMLC_THREAD_GROUP_H_
void enqueue(const ObjectType &item)
Enqueue an item.
Definition: thread_group.h:575
static bool launch_run(std::shared_ptr< BQT > pThis, SecondaryFunction secondary_function)
Launch to the 'run' function, which will, in turn, call the class' 'run' function, passing it the given secondary_function.
Definition: thread_group.h:598
std::shared_ptr< Thread > thread_by_name(const std::string &name)
Lookup Thread object by name.
Definition: thread_group.h:499
static bool launch(std::shared_ptr< Thread > pThis, bool autoRemove, StartFunction start_function, Args...args)
Launch the given Thread object.
Definition: thread_group.h:757
BlockingQueueThread(const std::string &name, dmlc::ThreadGroup *owner, std::thread *thrd=nullptr)
Constructor.
Definition: thread_group.h:538
std::thread::id get_id() const
Get this thread's id.
Definition: thread_group.h:245
void join()
Thread join.
Definition: thread_group.h:237
void request_shutdown_all(const bool make_all_joinable=true)
Call request_shutdown() on all threads in this ThreadGroup.
Definition: thread_group.h:441
bool empty() const
Check if the ThreadGroup is empty.
Definition: thread_group.h:465
virtual bool is_shutdown_requested() const
Check whether shutdown has been requested (request_shutdown() was called)
Definition: thread_group.h:197
Thread lifecycle management group.
Definition: thread_group.h:95
ThreadGroup()
Constructor.
Definition: thread_group.h:310
Definition: optional.h:241
std::recursive_mutex SharedMutex
Standard mutex for C++ < 14.
Definition: thread_group.h:84
bool is_auto_remove() const
Check whether the thread is set to auto-remove itself from the ThreadGroup owner when exiting...
Definition: thread_group.h:207
Thread(std::string threadName, ThreadGroup *owner, std::thread *thrd=nullptr)
Constructor.
Definition: thread_group.h:112
void signal()
Set this object's state to signaled (wait() will release or pass through)
Definition: thread_group.h:52
void join_all()
Join all threads in this ThreadGroup.
Definition: thread_group.h:406
bool add_thread(std::shared_ptr< Thread > thrd)
Add a Thread object to this thread group.
Definition: thread_group.h:366
Lifecycle-managed thread (used by ThreadGroup)
Definition: thread_group.h:101
size_t size() const
Return the number of threads in this thread group.
Definition: thread_group.h:456
Blocking queue thread class.
Definition: thread_group.h:528
namespace for dmlc
Definition: array_view.h:12
bool create(const std::string &threadName, bool auto_remove, StartFunction start_function, Args...args)
Create and launch a new Thread object which will be owned by this ThreadGroup.
Definition: thread_group.h:486
void wait()
Wait for the object to become signaled. If the object is already in the signaled state and reset() has not been called since, return immediately.
Definition: thread_group.h:42
bool CreateTimer(const std::string &timer_name, const Duration &duration, ThreadGroup *owner, TimerFunction timer_function)
Utility function to easily create a timer.
Definition: thread_group.h:797
const char * name() const
Name of the thread.
Definition: thread_group.h:149
void make_joinable()
Make the thread joinable (by removing the auto_remove flag)
Definition: thread_group.h:216
bool remove_thread(std::shared_ptr< Thread > thrd)
Remove a Thread object from this thread group.
Definition: thread_group.h:385
Simple manual-reset event gate which remains open after signalled.
Definition: thread_group.h:34
virtual ~ThreadGroup()
Destructor, perform cleanup. All child threads will be exited when this destructor completes...
Definition: thread_group.h:319
size_t size_approx() const
Get the approximate size of the queue.
Definition: thread_group.h:585
std::unique_lock< SharedMutex > WriteLock
Standard unique lock for C++ < 14.
Definition: thread_group.h:86
std::shared_ptr< Thread > SharedPtr
Shared pointer type for readability.
Definition: thread_group.h:104
Managed timer thread.
Definition: thread_group.h:643
void request_shutdown() override
Signal the thread that a shutdown is desired.
Definition: thread_group.h:562
static bool launch_run(std::shared_ptr< TimerThread< Duration >> pThis, SecondaryFunction secondary_function)
Launch to the 'run' function, which will, in turn, call the class' 'run' function, passing it the given secondary_function.
Definition: thread_group.h:674
std::unique_lock< SharedMutex > ReadLock
Standard unique lock for C++ < 14.
Definition: thread_group.h:88
virtual ~Thread()
Destructor with cleanup.
Definition: thread_group.h:126
bool is_thread_in(std::shared_ptr< Thread > thrd) const
Check if the given thread is a member of this ThreadGroup.
Definition: thread_group.h:346
int run(OnTimerFunction on_timer_function)
Internal timer execution function.
Definition: thread_group.h:707
TimerThread(const std::string &name, ThreadGroup *owner)
Constructor.
Definition: thread_group.h:652
bool is_current_thread() const
Check if this class represents the currently running thread (self)
Definition: thread_group.h:177
int run(OnItemFunction on_item_function)
Thread's main queue processing function.
Definition: thread_group.h:614
bool joinable() const
Check whether the thread is joinable.
Definition: thread_group.h:224
virtual void request_shutdown()
Signal to this thread that a thread shutdown/exit is requested.
Definition: thread_group.h:187
ManualEvent()
Definition: thread_group.h:36
bool is_this_thread_in() const
Check if the current thread is a member of this ThreadGroup.
Definition: thread_group.h:330
void reset()
Manually reset this object's state to unsignaled (wait() will block)
Definition: thread_group.h:61
~TimerThread() override
Destructor.
Definition: thread_group.h:659
~BlockingQueueThread() override
Destructor.
Definition: thread_group.h:549
static void start(std::shared_ptr< TimerThread > timer_thread, Duration duration, Function function)
Start a given timer thread.
Definition: thread_group.h:693