mxnet
threadediter.h
Go to the documentation of this file.
1 
9 #ifndef DMLC_THREADEDITER_H_
10 #define DMLC_THREADEDITER_H_
11 // defines DMLC_USE_CXX11
12 #include "./base.h"
13 // this code depends on c++11
14 #if DMLC_ENABLE_STD_THREAD
15 #include <condition_variable>
16 #include <functional>
17 #include <mutex>
18 #include <queue>
19 #include <atomic>
20 #include <thread>
21 #include <utility>
22 #include <memory>
23 #include "./data.h"
24 #include "./logging.h"
25 
26 namespace dmlc {
27 
/*!
 * \brief Wrapper that owns a std::thread and joins it when the wrapper is
 *  destroyed, so the owned thread can never outlive its owner.
 */
class ScopedThread {
 public:
  // The wrapper is the unique owner of the thread: copying is disabled.
  ScopedThread(const ScopedThread&) = delete;
  ScopedThread& operator=(const ScopedThread&) = delete;

  /*!
   * \brief Take ownership of a running thread.
   * \param thread the thread to manage; it is moved into the wrapper
   * \throws std::logic_error if the given thread is not joinable
   */
  explicit ScopedThread(std::thread thread) : thread_(std::move(thread)) {
    const bool has_thread = thread_.joinable();
    if (has_thread) {
      return;
    }
    throw std::logic_error("No thread");
  }

  /*! \brief Block until the owned thread has finished. */
  virtual ~ScopedThread() {
    // The constructor guarantees the thread is joinable.
    thread_.join();
  }

 private:
  /*! \brief the managed thread */
  std::thread thread_;
};
55 
77 template<typename DType>
78 class ThreadedIter : public DataIter<DType> {
79  public:
85  class Producer {
86  public:
87  // virtual destructor
88  virtual ~Producer() = default;
90  virtual void BeforeFirst(void) {
91  NotImplemented();
92  }
106  virtual bool Next(DType **inout_dptr) = 0;
107  };
112  explicit ThreadedIter(size_t max_capacity = 8)
113  : producer_(nullptr),
114  producer_thread_(nullptr),
115  max_capacity_(max_capacity),
116  nwait_consumer_(0),
117  nwait_producer_(0),
118  out_data_(NULL) {}
120  virtual ~ThreadedIter(void) {
121  this->Destroy();
122  }
129  inline void Destroy(void);
134  inline void set_max_capacity(size_t max_capacity) {
135  max_capacity_ = max_capacity;
136  }
142  inline void Init(std::shared_ptr<Producer> producer);
151  inline void Init(std::function<bool(DType **)> next,
152  std::function<void()> beforefirst = NotImplemented);
162  inline bool Next(DType **out_dptr);
169  inline void Recycle(DType **inout_dptr);
170 
174  inline void ThrowExceptionIfSet(void);
175 
179  inline void ClearException(void);
180 
187  virtual bool Next(void) {
188  if (out_data_ != NULL) {
189  this->Recycle(&out_data_);
190  }
191  if (Next(&out_data_)) {
192  return true;
193  } else {
194  return false;
195  }
196  }
202  virtual const DType &Value(void) const {
203  CHECK(out_data_ != NULL) << "Calling Value at beginning or end?";
204  return *out_data_;
205  }
207  virtual void BeforeFirst(void) {
208  ThrowExceptionIfSet();
209  std::unique_lock<std::mutex> lock(mutex_);
210  if (out_data_ != NULL) {
211  free_cells_.push(out_data_);
212  out_data_ = NULL;
213  }
214  if (producer_sig_.load(std::memory_order_acquire) == kDestroy) return;
215 
216  producer_sig_.store(kBeforeFirst, std::memory_order_release);
217  CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
218  if (nwait_producer_ != 0) {
219  producer_cond_.notify_one();
220  }
221  CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
222  // wait until the request has been processed
223  consumer_cond_.wait(lock, [this]() {
224  return producer_sig_processed_.load(std::memory_order_acquire);
225  });
226  producer_sig_processed_.store(false, std::memory_order_release);
227  bool notify = nwait_producer_ != 0 && !produce_end_;
228  lock.unlock();
229  // notify producer, in case they are waiting for the condition.
230  if (notify) producer_cond_.notify_one();
231  ThrowExceptionIfSet();
232  }
233 
234  private:
236  inline static void NotImplemented(void) {
237  LOG(FATAL) << "BeforeFirst is not supported";
238  }
240  enum Signal {
241  kProduce,
242  kBeforeFirst,
243  kDestroy
244  };
246  // Producer *producer_owned_;
247  std::shared_ptr<Producer> producer_;
248 
250  std::atomic<Signal> producer_sig_;
252  std::atomic<bool> producer_sig_processed_;
254  std::unique_ptr<ScopedThread> producer_thread_;
256  std::atomic<bool> produce_end_;
258  size_t max_capacity_;
260  std::mutex mutex_;
262  std::mutex mutex_exception_;
264  unsigned nwait_consumer_;
266  unsigned nwait_producer_;
268  std::condition_variable producer_cond_;
270  std::condition_variable consumer_cond_;
272  DType *out_data_;
274  std::queue<DType*> queue_;
276  std::queue<DType*> free_cells_;
278  std::exception_ptr iter_exception_{nullptr};
279 };
280 
281 // implementation of functions
282 template <typename DType> inline void ThreadedIter<DType>::Destroy(void) {
283  if (producer_thread_) {
284  {
285  // lock the mutex
286  std::lock_guard<std::mutex> lock(mutex_);
287  // send destroy signal
288  producer_sig_.store(kDestroy, std::memory_order_release);
289  if (nwait_producer_ != 0) {
290  producer_cond_.notify_one();
291  }
292  }
293  producer_thread_.reset(nullptr);
294  }
295  // end of critical region
296  // now the slave thread should exit
297  while (free_cells_.size() != 0) {
298  delete free_cells_.front();
299  free_cells_.pop();
300  }
301  while (queue_.size() != 0) {
302  delete queue_.front();
303  queue_.pop();
304  }
305  if (producer_ != NULL) {
306  producer_.reset();
307  }
308  if (out_data_ != NULL) {
309  delete out_data_;
310  out_data_ = NULL;
311  }
312 }
313 
314 template<typename DType>
315 inline void ThreadedIter<DType>::
316 Init(std::shared_ptr<Producer> producer) {
317  CHECK(producer_ == NULL) << "can only call Init once";
318  auto next = [producer](DType **dptr) {
319  return producer->Next(dptr);
320  };
321  auto beforefirst = [producer]() {
322  producer->BeforeFirst();
323  };
324  this->Init(next, beforefirst);
325 }
326 
327 template <typename DType>
328 inline void ThreadedIter<DType>::Init(std::function<bool(DType **)> next,
329  std::function<void()> beforefirst) {
330  producer_sig_.store(kProduce, std::memory_order_release);
331  producer_sig_processed_.store(false, std::memory_order_release);
332  produce_end_.store(false, std::memory_order_release);
333  ClearException();
334  // procedure running in prodcuer
335  // run producer thread
336  auto producer_fun = [this, next, beforefirst]() {
337  while (true) {
338  try {
339  DType *cell = NULL;
340  {
341  // lockscope
342  std::unique_lock<std::mutex> lock(mutex_);
343  ++this->nwait_producer_;
344  producer_cond_.wait(lock, [this]() {
345  if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
346  bool ret = !produce_end_.load(std::memory_order_acquire)
347  && (queue_.size() < max_capacity_ ||
348  free_cells_.size() != 0);
349  return ret;
350  } else {
351  return true;
352  }
353  });
354  --this->nwait_producer_;
355  if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
356  if (free_cells_.size() != 0) {
357  cell = free_cells_.front();
358  free_cells_.pop();
359  }
360  } else if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
361  // reset the producer
362  beforefirst();
363  // cleanup the queue
364  while (queue_.size() != 0) {
365  free_cells_.push(queue_.front());
366  queue_.pop();
367  }
368  // reset the state
369  produce_end_.store(false, std::memory_order_release);
370  producer_sig_processed_.store(true, std::memory_order_release);
371  producer_sig_.store(kProduce, std::memory_order_release);
372  // notify consumer that all the process as been done.
373  lock.unlock();
374  consumer_cond_.notify_all();
375  continue;
376  } else {
377  // destroy the thread
378  DCHECK(producer_sig_.load(std::memory_order_acquire) == kDestroy);
379  producer_sig_processed_.store(true, std::memory_order_release);
380  produce_end_.store(true, std::memory_order_release);
381  lock.unlock();
382  consumer_cond_.notify_all();
383  return;
384  }
385  } // end of lock scope
386  // now without lock
387  produce_end_.store(!next(&cell), std::memory_order_release);
388  DCHECK(cell != NULL || produce_end_.load(std::memory_order_acquire));
389  bool notify;
390  {
391  // lockscope
392  std::lock_guard<std::mutex> lock(mutex_);
393  if (!produce_end_.load(std::memory_order_acquire)) {
394  queue_.push(cell);
395  } else {
396  if (cell != NULL)
397  free_cells_.push(cell);
398  }
399  // put things into queue
400  notify = nwait_consumer_ != 0;
401  }
402  if (notify)
403  consumer_cond_.notify_all();
404  } catch (std::exception &e) {
405  // Shouldn't throw exception in destructor
406  DCHECK(producer_sig_.load(std::memory_order_acquire) != kDestroy);
407  {
408  std::lock_guard<std::mutex> lock(mutex_exception_);
409  if (!iter_exception_) {
410  iter_exception_ = std::current_exception();
411  }
412  }
413  bool next_notify = false;
414  {
415  std::unique_lock<std::mutex> lock(mutex_);
416  if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
417  while (queue_.size() != 0) {
418  free_cells_.push(queue_.front());
419  queue_.pop();
420  }
421  produce_end_.store(true, std::memory_order_release);
422  producer_sig_processed_.store(true, std::memory_order_release);
423  lock.unlock();
424  consumer_cond_.notify_all();
425  } else if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
426  produce_end_.store(true, std::memory_order_release);
427  next_notify = nwait_consumer_ != 0;
428  lock.unlock();
429  if (next_notify)
430  consumer_cond_.notify_all();
431  }
432  }
433  return;
434  }
435  }
436  };
437  producer_thread_.reset(new ScopedThread{std::thread(producer_fun)});
438 }
439 
440 template <typename DType>
441 inline bool ThreadedIter<DType>::Next(DType **out_dptr) {
442  if (producer_sig_.load(std::memory_order_acquire) == kDestroy)
443  return false;
444  ThrowExceptionIfSet();
445  std::unique_lock<std::mutex> lock(mutex_);
446  CHECK(producer_sig_.load(std::memory_order_acquire) == kProduce)
447  << "Make sure you call BeforeFirst not inconcurrent with Next!";
448  ++nwait_consumer_;
449  consumer_cond_.wait(lock,
450  [this]() { return queue_.size() != 0
451  || produce_end_.load(std::memory_order_acquire); });
452  --nwait_consumer_;
453  if (queue_.size() != 0) {
454  *out_dptr = queue_.front();
455  queue_.pop();
456  bool notify = nwait_producer_ != 0
457  && !produce_end_.load(std::memory_order_acquire);
458  lock.unlock();
459  if (notify)
460  producer_cond_.notify_one();
461 
462  ThrowExceptionIfSet();
463  return true;
464  } else {
465  CHECK(produce_end_.load(std::memory_order_acquire));
466  lock.unlock();
467 
468  ThrowExceptionIfSet();
469  return false;
470  }
471 }
472 
473 template <typename DType>
474 inline void ThreadedIter<DType>::Recycle(DType **inout_dptr) {
475  bool notify;
476  ThrowExceptionIfSet();
477  {
478  std::lock_guard<std::mutex> lock(mutex_);
479  free_cells_.push(*inout_dptr);
480  *inout_dptr = NULL;
481  notify = nwait_producer_ != 0 && !produce_end_.load(std::memory_order_acquire);
482  }
483  if (notify)
484  producer_cond_.notify_one();
485  ThrowExceptionIfSet();
486 }
487 
488 template <typename DType> inline void ThreadedIter<DType>::ThrowExceptionIfSet(void) {
489  std::exception_ptr tmp_exception{nullptr};
490  {
491  std::lock_guard<std::mutex> lock(mutex_exception_);
492  if (iter_exception_) {
493  tmp_exception = iter_exception_;
494  }
495  }
496  if (tmp_exception) {
497  try {
498  std::rethrow_exception(tmp_exception);
499  } catch (std::exception& exc) {
500  LOG(FATAL) << exc.what();
501  }
502  }
503 }
504 
505 template <typename DType> inline void ThreadedIter<DType>::ClearException(void) {
506  std::lock_guard<std::mutex> lock(mutex_exception_);
507  iter_exception_ = nullptr;
508 }
509 
510 } // namespace dmlc
511 #endif // DMLC_ENABLE_STD_THREAD
512 #endif // DMLC_THREADEDITER_H_
void Init(std::shared_ptr< Producer > producer)
initialize the producer and start the thread can only be called once
Definition: threadediter.h:316
void ThrowExceptionIfSet(void)
Rethrows exception which is set by the producer.
Definition: threadediter.h:488
virtual ~ScopedThread()
Definition: threadediter.h:45
data iterator interface this is not a C++ style iterator, but nice for data pulling:) This interface ...
Definition: data.h:56
virtual void BeforeFirst(void)
reset the producer to beginning
Definition: threadediter.h:90
Wrapper class to manage std::thread; uses RAII pattern to automatically join std::thread upon destruc...
Definition: threadediter.h:32
defines common input data structure, and interface for handling the input data
Definition: optional.h:241
ThreadedIter(size_t max_capacity=8)
constructor
Definition: threadediter.h:112
ScopedThread(std::thread thread)
constructor
Definition: threadediter.h:38
virtual void BeforeFirst(void)
set the iterator before first location
Definition: threadediter.h:207
namespace for dmlc
Definition: array_view.h:12
void Destroy(void)
destroy all the related resources this is equivalent to destructor, can be used to destroy the thread...
Definition: threadediter.h:282
An iterator that is backed by a thread to pull data eagerly from a single producer into a bounded buf...
Definition: threadediter.h:78
virtual ~ThreadedIter(void)
destructor
Definition: threadediter.h:120
void ClearException(void)
clears exception_ptr, called from Init
Definition: threadediter.h:505
producer class interface that the threaded iterator uses as the source to produce the content
Definition: threadediter.h:85
void Recycle(DType **inout_dptr)
recycle the data cell, this function is threadsafe the threaditer can reuse the data cell for future ...
Definition: threadediter.h:474
void set_max_capacity(size_t max_capacity)
set maximum capacity of the queue
Definition: threadediter.h:134
virtual bool Next(void)
adapt the iterator interface's Next NOTE: the call to this function is not threadsafe use the other N...
Definition: threadediter.h:187
ScopedThread & operator=(ScopedThread const &)=delete
virtual const DType & Value(void) const
adapt the iterator interface's Value NOTE: the call to this function is not threadsafe use the other ...
Definition: threadediter.h:202