9 #ifndef DMLC_THREADEDITER_H_ 10 #define DMLC_THREADEDITER_H_ 14 #if DMLC_ENABLE_STD_THREAD 15 #include <condition_variable> 24 #include "./logging.h" 39 : thread_(
std::move(thread)) {
40 if (!thread_.joinable()) {
41 throw std::logic_error(
"No thread");
77 template<
typename DType>
106 virtual bool Next(DType **inout_dptr) = 0;
113 : producer_(nullptr),
114 producer_thread_(nullptr),
115 max_capacity_(max_capacity),
129 inline void Destroy(
void);
135 max_capacity_ = max_capacity;
142 inline void Init(std::shared_ptr<Producer> producer);
151 inline void Init(std::function<
bool(DType **)> next,
152 std::function<
void()> beforefirst = NotImplemented);
162 inline bool Next(DType **out_dptr);
169 inline void Recycle(DType **inout_dptr);
174 inline void ThrowExceptionIfSet(
void);
179 inline void ClearException(
void);
188 if (out_data_ != NULL) {
189 this->Recycle(&out_data_);
191 if (Next(&out_data_)) {
202 virtual const DType &
Value(
void)
const {
203 CHECK(out_data_ != NULL) <<
"Calling Value at beginning or end?";
208 ThrowExceptionIfSet();
209 std::unique_lock<std::mutex> lock(mutex_);
210 if (out_data_ != NULL) {
211 free_cells_.push(out_data_);
214 if (producer_sig_.load(std::memory_order_acquire) == kDestroy)
return;
216 producer_sig_.store(kBeforeFirst, std::memory_order_release);
217 CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
218 if (nwait_producer_ != 0) {
219 producer_cond_.notify_one();
221 CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
223 consumer_cond_.wait(lock, [
this]() {
224 return producer_sig_processed_.load(std::memory_order_acquire);
226 producer_sig_processed_.store(
false, std::memory_order_release);
227 bool notify = nwait_producer_ != 0 && !produce_end_;
230 if (notify) producer_cond_.notify_one();
231 ThrowExceptionIfSet();
236 inline static void NotImplemented(
void) {
237 LOG(FATAL) <<
"BeforeFirst is not supported";
247 std::shared_ptr<Producer> producer_;
250 std::atomic<Signal> producer_sig_;
252 std::atomic<bool> producer_sig_processed_;
254 std::unique_ptr<ScopedThread> producer_thread_;
256 std::atomic<bool> produce_end_;
258 size_t max_capacity_;
262 std::mutex mutex_exception_;
264 unsigned nwait_consumer_;
266 unsigned nwait_producer_;
268 std::condition_variable producer_cond_;
270 std::condition_variable consumer_cond_;
274 std::queue<DType*> queue_;
276 std::queue<DType*> free_cells_;
278 std::exception_ptr iter_exception_{
nullptr};
283 if (producer_thread_) {
286 std::lock_guard<std::mutex> lock(mutex_);
288 producer_sig_.store(kDestroy, std::memory_order_release);
289 if (nwait_producer_ != 0) {
290 producer_cond_.notify_one();
293 producer_thread_.reset(
nullptr);
297 while (free_cells_.size() != 0) {
298 delete free_cells_.front();
301 while (queue_.size() != 0) {
302 delete queue_.front();
305 if (producer_ != NULL) {
308 if (out_data_ != NULL) {
314 template<
typename DType>
316 Init(std::shared_ptr<Producer> producer) {
317 CHECK(producer_ == NULL) <<
"can only call Init once";
318 auto next = [producer](DType **dptr) {
319 return producer->Next(dptr);
321 auto beforefirst = [producer]() {
322 producer->BeforeFirst();
324 this->Init(next, beforefirst);
327 template <
typename DType>
329 std::function<
void()> beforefirst) {
330 producer_sig_.store(kProduce, std::memory_order_release);
331 producer_sig_processed_.store(
false, std::memory_order_release);
332 produce_end_.store(
false, std::memory_order_release);
336 auto producer_fun = [
this, next, beforefirst]() {
342 std::unique_lock<std::mutex> lock(mutex_);
343 ++this->nwait_producer_;
344 producer_cond_.wait(lock, [
this]() {
345 if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
346 bool ret = !produce_end_.load(std::memory_order_acquire)
347 && (queue_.size() < max_capacity_ ||
348 free_cells_.size() != 0);
354 --this->nwait_producer_;
355 if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
356 if (free_cells_.size() != 0) {
357 cell = free_cells_.front();
360 }
else if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
364 while (queue_.size() != 0) {
365 free_cells_.push(queue_.front());
369 produce_end_.store(
false, std::memory_order_release);
370 producer_sig_processed_.store(
true, std::memory_order_release);
371 producer_sig_.store(kProduce, std::memory_order_release);
374 consumer_cond_.notify_all();
378 DCHECK(producer_sig_.load(std::memory_order_acquire) == kDestroy);
379 producer_sig_processed_.store(
true, std::memory_order_release);
380 produce_end_.store(
true, std::memory_order_release);
382 consumer_cond_.notify_all();
387 produce_end_.store(!next(&cell), std::memory_order_release);
388 DCHECK(cell != NULL || produce_end_.load(std::memory_order_acquire));
392 std::lock_guard<std::mutex> lock(mutex_);
393 if (!produce_end_.load(std::memory_order_acquire)) {
397 free_cells_.push(cell);
400 notify = nwait_consumer_ != 0;
403 consumer_cond_.notify_all();
404 }
catch (std::exception &e) {
406 DCHECK(producer_sig_.load(std::memory_order_acquire) != kDestroy);
408 std::lock_guard<std::mutex> lock(mutex_exception_);
409 if (!iter_exception_) {
410 iter_exception_ = std::current_exception();
413 bool next_notify =
false;
415 std::unique_lock<std::mutex> lock(mutex_);
416 if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
417 while (queue_.size() != 0) {
418 free_cells_.push(queue_.front());
421 produce_end_.store(
true, std::memory_order_release);
422 producer_sig_processed_.store(
true, std::memory_order_release);
424 consumer_cond_.notify_all();
425 }
else if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
426 produce_end_.store(
true, std::memory_order_release);
427 next_notify = nwait_consumer_ != 0;
430 consumer_cond_.notify_all();
437 producer_thread_.reset(
new ScopedThread{std::thread(producer_fun)});
440 template <
typename DType>
442 if (producer_sig_.load(std::memory_order_acquire) == kDestroy)
444 ThrowExceptionIfSet();
445 std::unique_lock<std::mutex> lock(mutex_);
446 CHECK(producer_sig_.load(std::memory_order_acquire) == kProduce)
447 <<
"Make sure you call BeforeFirst not inconcurrent with Next!";
449 consumer_cond_.wait(lock,
450 [
this]() {
return queue_.size() != 0
451 || produce_end_.load(std::memory_order_acquire); });
453 if (queue_.size() != 0) {
454 *out_dptr = queue_.front();
456 bool notify = nwait_producer_ != 0
457 && !produce_end_.load(std::memory_order_acquire);
460 producer_cond_.notify_one();
462 ThrowExceptionIfSet();
465 CHECK(produce_end_.load(std::memory_order_acquire));
468 ThrowExceptionIfSet();
473 template <
typename DType>
476 ThrowExceptionIfSet();
478 std::lock_guard<std::mutex> lock(mutex_);
479 free_cells_.push(*inout_dptr);
481 notify = nwait_producer_ != 0 && !produce_end_.load(std::memory_order_acquire);
484 producer_cond_.notify_one();
485 ThrowExceptionIfSet();
489 std::exception_ptr tmp_exception{
nullptr};
491 std::lock_guard<std::mutex> lock(mutex_exception_);
492 if (iter_exception_) {
493 tmp_exception = iter_exception_;
498 std::rethrow_exception(tmp_exception);
499 }
catch (std::exception& exc) {
500 LOG(FATAL) << exc.what();
506 std::lock_guard<std::mutex> lock(mutex_exception_);
507 iter_exception_ =
nullptr;
511 #endif // DMLC_USE_CXX11 512 #endif // DMLC_THREADEDITER_H_ void Init(std::shared_ptr< Producer > producer)
Initialize the producer and start the thread; can only be called once.
Definition: threadediter.h:316
void ThrowExceptionIfSet(void)
Rethrows the exception set by the producer, if one has been set.
Definition: threadediter.h:488
virtual ~ScopedThread()
Definition: threadediter.h:45
Data iterator interface. This is not a C++-style iterator, but it is convenient for pulling data. This interface ...
Definition: data.h:56
virtual void BeforeFirst(void)
Reset the producer to the beginning.
Definition: threadediter.h:90
Wrapper class that manages a std::thread; uses the RAII pattern to automatically join the std::thread upon destruc...
Definition: threadediter.h:32
Defines common input data structures and an interface for handling input data.
Definition: optional.h:241
ThreadedIter(size_t max_capacity=8)
constructor
Definition: threadediter.h:112
ScopedThread(std::thread thread)
constructor
Definition: threadediter.h:38
virtual void BeforeFirst(void)
Set the iterator to before the first location.
Definition: threadediter.h:207
namespace for dmlc
Definition: array_view.h:12
void Destroy(void)
Destroy all related resources. This is equivalent to the destructor and can be used to destroy the thread...
Definition: threadediter.h:282
An iterator backed by a thread that eagerly pulls data from a single producer into a bounded buf...
Definition: threadediter.h:78
virtual ~ThreadedIter(void)
destructor
Definition: threadediter.h:120
void ClearException(void)
Clears the stored exception_ptr; called from Init.
Definition: threadediter.h:505
Producer class interface that ThreadedIter uses as the source to produce content.
Definition: threadediter.h:85
void Recycle(DType **inout_dptr)
Recycle the data cell. This function is thread-safe; the ThreadedIter can reuse the data cell for future ...
Definition: threadediter.h:474
void set_max_capacity(size_t max_capacity)
Set the maximum capacity of the queue.
Definition: threadediter.h:134
virtual bool Next(void)
Adapts the iterator interface's Next. NOTE: calls to this function are not thread-safe; use the other N...
Definition: threadediter.h:187
ScopedThread & operator=(ScopedThread const &)=delete
virtual const DType & Value(void) const
Adapts the iterator interface's Value. NOTE: calls to this function are not thread-safe; use the other ...
Definition: threadediter.h:202