9 #ifndef DMLC_THREADEDITER_H_
10 #define DMLC_THREADEDITER_H_
14 #if DMLC_ENABLE_STD_THREAD
15 #include <condition_variable>
24 #include "./logging.h"
39 : thread_(
std::move(thread)) {
40 if (!thread_.joinable()) {
41 throw std::logic_error(
"No thread");
77 template<
typename DType>
106 virtual bool Next(DType **inout_dptr) = 0;
113 : producer_(nullptr),
114 producer_thread_(nullptr),
115 max_capacity_(max_capacity),
135 max_capacity_ = max_capacity;
142 inline void Init(std::shared_ptr<Producer> producer);
151 inline void Init(std::function<
bool(DType **)> next,
152 std::function<
void()> beforefirst = NotImplemented);
162 inline bool Next(DType **out_dptr);
169 inline void Recycle(DType **inout_dptr);
188 if (out_data_ != NULL) {
191 if (
Next(&out_data_)) {
202 virtual const DType &
Value(
void)
const {
203 CHECK(out_data_ != NULL) <<
"Calling Value at beginning or end?";
209 std::unique_lock<std::mutex>
lock(mutex_);
210 if (out_data_ != NULL) {
211 free_cells_.push(out_data_);
214 if (producer_sig_.load(std::memory_order_acquire) == kDestroy)
return;
216 producer_sig_.store(kBeforeFirst, std::memory_order_release);
217 CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
218 if (nwait_producer_ != 0) {
219 producer_cond_.notify_one();
221 CHECK(!producer_sig_processed_.load(std::memory_order_acquire));
223 consumer_cond_.wait(
lock, [
this]() {
224 return producer_sig_processed_.load(std::memory_order_acquire);
226 producer_sig_processed_.store(
false, std::memory_order_release);
227 bool notify = nwait_producer_ != 0 && !produce_end_;
230 if (notify) producer_cond_.notify_one();
236 inline static void NotImplemented(
void) {
237 LOG(FATAL) <<
"BeforeFirst is not supported";
247 std::shared_ptr<Producer> producer_;
250 std::atomic<Signal> producer_sig_;
252 std::atomic<bool> producer_sig_processed_;
254 std::unique_ptr<ScopedThread> producer_thread_;
256 std::atomic<bool> produce_end_;
258 size_t max_capacity_;
262 std::mutex mutex_exception_;
264 unsigned nwait_consumer_;
266 unsigned nwait_producer_;
268 std::condition_variable producer_cond_;
270 std::condition_variable consumer_cond_;
274 std::queue<DType*> queue_;
276 std::queue<DType*> free_cells_;
278 std::exception_ptr iter_exception_{
nullptr};
283 if (producer_thread_) {
286 std::lock_guard<std::mutex>
lock(mutex_);
288 producer_sig_.store(kDestroy, std::memory_order_release);
289 if (nwait_producer_ != 0) {
290 producer_cond_.notify_one();
293 producer_thread_.reset(
nullptr);
297 while (free_cells_.size() != 0) {
298 delete free_cells_.front();
301 while (queue_.size() != 0) {
302 delete queue_.front();
305 if (producer_ != NULL) {
308 if (out_data_ != NULL) {
314 template<
typename DType>
316 Init(std::shared_ptr<Producer> producer) {
317 CHECK(producer_ == NULL) <<
"can only call Init once";
318 auto next = [producer](DType **dptr) {
319 return producer->Next(dptr);
321 auto beforefirst = [producer]() {
322 producer->BeforeFirst();
324 this->Init(next, beforefirst);
327 template <
typename DType>
329 std::function<
void()> beforefirst) {
330 producer_sig_.store(kProduce, std::memory_order_release);
331 producer_sig_processed_.store(
false, std::memory_order_release);
332 produce_end_.store(
false, std::memory_order_release);
336 auto producer_fun = [
this, next, beforefirst]() {
342 std::unique_lock<std::mutex>
lock(mutex_);
343 ++this->nwait_producer_;
344 producer_cond_.wait(
lock, [
this]() {
345 if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
346 bool ret = !produce_end_.load(std::memory_order_acquire)
347 && (queue_.size() < max_capacity_ ||
348 free_cells_.size() != 0);
354 --this->nwait_producer_;
355 if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
356 if (free_cells_.size() != 0) {
357 cell = free_cells_.front();
360 }
else if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
364 while (queue_.size() != 0) {
365 free_cells_.push(queue_.front());
369 produce_end_.store(
false, std::memory_order_release);
370 producer_sig_processed_.store(
true, std::memory_order_release);
371 producer_sig_.store(kProduce, std::memory_order_release);
374 consumer_cond_.notify_all();
378 DCHECK(producer_sig_.load(std::memory_order_acquire) == kDestroy);
379 producer_sig_processed_.store(
true, std::memory_order_release);
380 produce_end_.store(
true, std::memory_order_release);
382 consumer_cond_.notify_all();
387 produce_end_.store(!next(&cell), std::memory_order_release);
388 DCHECK(cell != NULL || produce_end_.load(std::memory_order_acquire));
392 std::lock_guard<std::mutex>
lock(mutex_);
393 if (!produce_end_.load(std::memory_order_acquire)) {
397 free_cells_.push(cell);
400 notify = nwait_consumer_ != 0;
403 consumer_cond_.notify_all();
404 }
catch (std::exception &e) {
406 DCHECK(producer_sig_.load(std::memory_order_acquire) != kDestroy);
408 std::lock_guard<std::mutex>
lock(mutex_exception_);
409 if (!iter_exception_) {
410 iter_exception_ = std::current_exception();
413 bool next_notify =
false;
415 std::unique_lock<std::mutex>
lock(mutex_);
416 if (producer_sig_.load(std::memory_order_acquire) == kBeforeFirst) {
417 while (queue_.size() != 0) {
418 free_cells_.push(queue_.front());
421 produce_end_.store(
true, std::memory_order_release);
422 producer_sig_processed_.store(
true, std::memory_order_release);
424 consumer_cond_.notify_all();
425 }
else if (producer_sig_.load(std::memory_order_acquire) == kProduce) {
426 produce_end_.store(
true, std::memory_order_release);
427 next_notify = nwait_consumer_ != 0;
430 consumer_cond_.notify_all();
437 producer_thread_.reset(
new ScopedThread{std::thread(producer_fun)});
440 template <
typename DType>
442 if (producer_sig_.load(std::memory_order_acquire) == kDestroy)
444 ThrowExceptionIfSet();
445 std::unique_lock<std::mutex>
lock(mutex_);
446 CHECK(producer_sig_.load(std::memory_order_acquire) == kProduce)
447 <<
"Make sure you call BeforeFirst not inconcurrent with Next!";
449 consumer_cond_.wait(
lock,
450 [
this]() {
return queue_.size() != 0
451 || produce_end_.load(std::memory_order_acquire); });
453 if (queue_.size() != 0) {
454 *out_dptr = queue_.front();
456 bool notify = nwait_producer_ != 0
457 && !produce_end_.load(std::memory_order_acquire);
460 producer_cond_.notify_one();
462 ThrowExceptionIfSet();
465 CHECK(produce_end_.load(std::memory_order_acquire));
468 ThrowExceptionIfSet();
473 template <
typename DType>
476 ThrowExceptionIfSet();
478 std::lock_guard<std::mutex>
lock(mutex_);
479 free_cells_.push(*inout_dptr);
481 notify = nwait_producer_ != 0 && !produce_end_.load(std::memory_order_acquire);
484 producer_cond_.notify_one();
485 ThrowExceptionIfSet();
489 std::exception_ptr tmp_exception{
nullptr};
491 std::lock_guard<std::mutex>
lock(mutex_exception_);
492 if (iter_exception_) {
493 tmp_exception = iter_exception_;
498 std::rethrow_exception(tmp_exception);
499 }
catch (std::exception& exc) {
500 LOG(FATAL) << exc.what();
506 std::lock_guard<std::mutex>
lock(mutex_exception_);
507 iter_exception_ =
nullptr;
511 #endif // DMLC_USE_CXX11
512 #endif // DMLC_THREADEDITER_H_