mxnet
concurrency.h
Go to the documentation of this file.
1 
7 #ifndef DMLC_CONCURRENCY_H_
8 #define DMLC_CONCURRENCY_H_
9 // this code depends on c++11
10 #if DMLC_USE_CXX11
11 #include <atomic>
12 #include <deque>
13 #include <queue>
14 #include <mutex>
15 #include <vector>
16 #include <utility>
17 #include <condition_variable>
18 #include "dmlc/base.h"
19 
20 namespace dmlc {
21 
/*!
 * \brief Minimal user-space spinlock built on std::atomic_flag.
 *
 * Provides lock()/unlock(), so it can be used with std::lock_guard.
 * Instances can be neither copied nor moved.
 */
class Spinlock {
 public:
  /*! \brief Construct the spinlock in the unlocked state. */
  Spinlock() {
    // Clearing the flag explicitly works identically on MSVC, GCC and Clang
    // and avoids ATOMIC_FLAG_INIT in a mem-initializer, whose brace expansion
    // triggers clang's -Wbraced-scalar-init.
    lock_.clear();
  }
  ~Spinlock() = default;
  /*! \brief Acquire the lock, spinning until it is free. */
  inline void lock() noexcept(true);
  /*! \brief Release the lock. */
  inline void unlock() noexcept(true);

  // The flag itself is the lock's identity: forbid copies and moves.
  Spinlock(const Spinlock&) = delete;
  Spinlock(Spinlock&&) = delete;
  Spinlock& operator=(const Spinlock&) = delete;
  Spinlock& operator=(Spinlock&&) = delete;

 private:
  /*! \brief Set while the lock is held, clear while it is free. */
  std::atomic_flag lock_;
};
59 
/*! \brief Ordering policy used by ConcurrentBlockingQueue. */
enum class ConcurrentQueueType {
  /*! \brief First-in, first-out ordering. */
  kFIFO = 0,
  /*! \brief Ordering by an integer priority supplied on push. */
  kPriority = 1
};
67 
// NOTE(review): this rendered listing is missing source lines 72-73 (the rest
// of the template parameter list and the class head; the cross-reference
// table places the class definition at concurrency.h:73), presumably 75 (the
// constructor declaration) and 141-144. Restore them from the original header
// before compiling.
/*!
 * \brief Thread-safe blocking queue. Every operation locks mutex_; consumers
 *        block in Pop() on cv_ until data arrives or the queue is killed.
 *        `type` selects FIFO order (fifo_queue_) or priority order (the heap
 *        kept in priority_queue_).
 */
71 template <typename T,
74  public:
76  ~ConcurrentBlockingQueue() = default;
  // Push e to the back (kFIFO) or into the heap with `priority` (kPriority).
87  template <typename E>
88  void Push(E&& e, int priority = 0);
89 
  // Push e to the front (kFIFO); for kPriority this behaves like Push().
101  template <typename E>
102  void PushFront(E&& e, int priority = 0);
  // Block until an element is available or SignalForKill() was called.
  // Returns false after a kill, even if elements are still queued.
110  bool Pop(T* rv);
  // Set exit_now_ and wake every consumer blocked in Pop().
117  void SignalForKill();
  // Number of queued elements, read while holding mutex_.
122  size_t Size();
123 
124  private:
  // Heap element. operator< compares priorities, so std::push_heap /
  // std::pop_heap keep the highest priority at the top of the heap.
125  struct Entry {
126  T data;
127  int priority;
128  inline bool operator<(const Entry &b) const {
129  return priority < b.priority;
130  }
131  };
132 
  // Guards both containers and nwait_consumer_.
133  std::mutex mutex_;
  // Notified by Push/PushFront (only when consumers wait) and SignalForKill.
134  std::condition_variable cv_;
  // Set to true by SignalForKill(); makes Pop() return false.
135  std::atomic<bool> exit_now_;
  // Number of consumers currently blocked in Pop(); producers skip
  // notify_one() when it is zero.
136  int nwait_consumer_;
137  // a priority queue
138  std::vector<Entry> priority_queue_;
139  // a FIFO queue
140  std::deque<T> fifo_queue_;
145 };
146 
147 inline void Spinlock::lock() noexcept(true) {
148  while (lock_.test_and_set(std::memory_order_acquire)) {
149  }
150 }
151 
152 inline void Spinlock::unlock() noexcept(true) {
153  lock_.clear(std::memory_order_release);
154 }
155 
156 template <typename T, ConcurrentQueueType type>
158  : exit_now_{false}, nwait_consumer_{0} {}
159 
160 template <typename T, ConcurrentQueueType type>
161 template <typename E>
162 void ConcurrentBlockingQueue<T, type>::Push(E&& e, int priority) {
163  static_assert(std::is_same<typename std::remove_cv<
164  typename std::remove_reference<E>::type>::type,
165  T>::value,
166  "Types must match.");
167  bool notify;
168  {
169  std::lock_guard<std::mutex> lock{mutex_};
170  if (type == ConcurrentQueueType::kFIFO) {
171  fifo_queue_.emplace_back(std::forward<E>(e));
172  notify = nwait_consumer_ != 0;
173  } else {
174  Entry entry;
175  entry.data = std::move(e);
176  entry.priority = priority;
177  priority_queue_.push_back(std::move(entry));
178  std::push_heap(priority_queue_.begin(), priority_queue_.end());
179  notify = nwait_consumer_ != 0;
180  }
181  }
182  if (notify) cv_.notify_one();
183 }
184 
185 template <typename T, ConcurrentQueueType type>
186 template <typename E>
188  static_assert(std::is_same<typename std::remove_cv<
189  typename std::remove_reference<E>::type>::type,
190  T>::value,
191  "Types must match.");
192  bool notify;
193  {
194  std::lock_guard<std::mutex> lock{mutex_};
195  if (type == ConcurrentQueueType::kFIFO) {
196  fifo_queue_.emplace_front(std::forward<E>(e));
197  notify = nwait_consumer_ != 0;
198  } else {
199  Entry entry;
200  entry.data = std::move(e);
201  entry.priority = priority;
202  priority_queue_.push_back(std::move(entry));
203  std::push_heap(priority_queue_.begin(), priority_queue_.end());
204  notify = nwait_consumer_ != 0;
205  }
206  }
207  if (notify) cv_.notify_one();
208 }
209 
210 template <typename T, ConcurrentQueueType type>
212  std::unique_lock<std::mutex> lock{mutex_};
213  if (type == ConcurrentQueueType::kFIFO) {
214  ++nwait_consumer_;
215  cv_.wait(lock, [this] {
216  return !fifo_queue_.empty() || exit_now_.load();
217  });
218  --nwait_consumer_;
219  if (!exit_now_.load()) {
220  *rv = std::move(fifo_queue_.front());
221  fifo_queue_.pop_front();
222  return true;
223  } else {
224  return false;
225  }
226  } else {
227  ++nwait_consumer_;
228  cv_.wait(lock, [this] {
229  return !priority_queue_.empty() || exit_now_.load();
230  });
231  --nwait_consumer_;
232  if (!exit_now_.load()) {
233  std::pop_heap(priority_queue_.begin(), priority_queue_.end());
234  *rv = std::move(priority_queue_.back().data);
235  priority_queue_.pop_back();
236  return true;
237  } else {
238  return false;
239  }
240  }
241 }
242 
243 template <typename T, ConcurrentQueueType type>
245  {
246  std::lock_guard<std::mutex> lock{mutex_};
247  exit_now_.store(true);
248  }
249  cv_.notify_all();
250 }
251 
252 template <typename T, ConcurrentQueueType type>
254  std::lock_guard<std::mutex> lock{mutex_};
255  if (type == ConcurrentQueueType::kFIFO) {
256  return fifo_queue_.size();
257  } else {
258  return priority_queue_.size();
259  }
260 }
261 } // namespace dmlc
262 #endif // DMLC_USE_CXX11
263 #endif // DMLC_CONCURRENCY_H_
dmlc::ConcurrentBlockingQueue::PushFront
void PushFront(E &&e, int priority=0)
Push element to the front of the queue. Only works for FIFO queue. For priority queue it is the same ...
Definition: concurrency.h:187
dmlc
namespace for dmlc
Definition: array_view.h:12
dmlc::ConcurrentBlockingQueue::Push
void Push(E &&e, int priority=0)
Push element to the end of the queue.
Definition: concurrency.h:162
base.h
defines configuration macros
dmlc::ConcurrentQueueType::kFIFO
@ kFIFO
FIFO queue.
dmlc::ConcurrentBlockingQueue::SignalForKill
void SignalForKill()
Signal the queue for destruction.
Definition: concurrency.h:244
dmlc::ConcurrentBlockingQueue::Pop
bool Pop(T *rv)
Pop element from the queue.
Definition: concurrency.h:211
mxnet::common::cuda::rtc::lock
std::mutex lock
DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(T)
Disable copy constructor and assignment operator.
Definition: base.h:174
dmlc::ConcurrentBlockingQueue::Size
size_t Size()
Get the size of the queue.
Definition: concurrency.h:253
dmlc::Spinlock::Spinlock
Spinlock()
Definition: concurrency.h:36
dmlc::ConcurrentQueueType::kPriority
@ kPriority
queue with priority
dmlc::Spinlock::lock
void lock() noexcept(true)
Acquire lock.
Definition: concurrency.h:147
dmlc::ConcurrentQueueType
ConcurrentQueueType
type of concurrent queue
Definition: concurrency.h:61
mxnet::runtime::operator<
bool operator<(const String &lhs, const std::string &rhs)
Definition: container_ext.h:748
dmlc::ConcurrentBlockingQueue
Concurrent blocking queue.
Definition: concurrency.h:73
dmlc::ConcurrentBlockingQueue::ConcurrentBlockingQueue
ConcurrentBlockingQueue()
Definition: concurrency.h:157
std
Definition: optional.h:251
dmlc::Spinlock
Simple userspace spinlock implementation.
Definition: concurrency.h:25
dmlc::Spinlock::~Spinlock
~Spinlock()=default
dmlc::Spinlock::unlock
void unlock() noexcept(true)
Release lock.
Definition: concurrency.h:152