TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
parallel.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
#include <common/assert.h>

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <future>
#include <iostream>
#include <mutex>
#include <optional>
#include <queue>
#include <stdexcept>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
19 
20 namespace common {
21 
/**
 * A simple producer/consumer class.
 *
 * This class is dead-simple, but sometimes useful. You specify the element type
 * for the queue in the type, and then instantiate it with a functor which will
 * run in a separate thread. The main function of the class is enqueue(), which,
 * well, adds stuff to the queue. You also specify a maximum queue size on
 * construction; if that size is reached, enqueue() will block.
 *
 * As a special case, you can use this class with 0 threads. This means that the
 * supplied functor will be called directly in the thread calling enqueue().
 * Items will be buffered implicitly by enqueue() blocking until consumption.
 *
 * If you want to wait for the consumers to finish, call wait(). If you want to
 * stop the consumer threads, destruct the object.
 *
 * The implementation assumes that objects of type T are in a valid state (i.e.
 * can be destructed) after moving. If that's not the case for your type, go fix
 * your type.
 */
template <typename T>
class BufferedConsumer {
  /// Signature of the consumer callback: receives each item by value.
  using Function = std::function<void(T)>;

 public:
  using type = T;

  /// Starts `nthreads` consumer threads that call `fn` on queued items.
  /// Throws std::runtime_error if nthreads > 0 and maxQueueSize == 0.
  BufferedConsumer(uint8_t nthreads, size_t maxQueueSize, Function&& fn);

  /// Stops the consumers, discarding any items in the queue
  ~BufferedConsumer();

  /// Blocks until the queue is empty or the consumers are stopped
  void wait();

  /// Adds another item to the work queue, possibly blocking
  /// If the number of threads is zero, execute directly in the calling thread's
  /// context (and thus block).
  /// Throws std::runtime_error if the consumer has been stopped.
  void enqueue(T arg);

  /// Same as 'enqueue', except that if the queue is full, the oldest element
  /// will be removed before inserting
  /// Only works for nthreads > 0
  void enqueueOrReplaceOldest(T arg);

  /// Consumer loop; this is the function each worker thread executes.
  void run();

 protected:
  /// Upper bound on queue_.size() that enqueue() waits for.
  size_t const maxQueueSize_;
  /// Set by the destructor to make workers and blocked callers bail out.
  bool stop_ = false;
  /// Number of items currently being processed by fn_.
  int64_t consuming_ = 0;
  /// The user-supplied consumer callback.
  Function fn_;
  /// Worker threads, each running run(); empty in zero-thread mode.
  std::vector<std::thread> threads_;
  /// Items waiting to be consumed.
  std::queue<T> queue_;
  /// Guards stop_, consuming_ and queue_.
  std::mutex mutex_;
  /// Signalled when an item has been pushed (or on stop).
  std::condition_variable itemReady_;
  /// Signalled when an item has been consumed (or on stop).
  std::condition_variable itemDone_;
};
80 
/**
 * A simple producer class.
 *
 * You specify a function that will generate data for you somehow, ending when
 * it returns an optional without a value, and this producer will multithread it
 * for you automatically. The function should be thread-safe, and data is not
 * guaranteed to arrive in the same order it was generated in, unless you do it
 * yourself. If you want to stop the producer threads, destruct the object. If
 * you try destructing the object while get() is still being called, it will
 * result in a runtime error.
 */
template <typename T>
class BufferedProducer {
  /// Signature of the generator: an empty optional signals "no more data".
  using Function = std::function<std::optional<T>()>;

 public:
  using type = T;

  // Use uint8 because we don't expect more than 255 threads
  /// Starts `nthreads` producer threads, each running run() on a copy of `fn`.
  /// Throws std::runtime_error if nthreads == 0 or maxQueueSize == 0.
  BufferedProducer(uint8_t nthreads, size_t maxQueueSize, Function&& fn);

  /// Stops the producers, discarding any items in the queue
  ~BufferedProducer();

  /// Blocks until an item is available and returns it; returns an empty
  /// optional once all producers have finished and the queue is drained.
  /// Throws std::runtime_error if the producer has been stopped.
  std::optional<T> get();

  /// Producer loop; this is the function each worker thread executes.
  void run(Function fn);

 private:
  /// Upper bound on queued plus in-flight items.
  size_t const maxQueueSize_;
  /// Set by the destructor to make workers and blocked callers bail out.
  bool stop_ = false;
  /// Number of threads currently inside a call to the generator.
  int working_ = 0;
  /// Number of producer threads requested at construction.
  uint8_t nThreads_;
  /// Number of producer threads that have not yet seen an empty optional.
  /// Value-initialised so it never holds an indeterminate value.
  std::atomic_int running_{0};
  /// Worker threads, each running run().
  std::vector<std::thread> threads_;
  /// Produced items waiting to be handed out by get().
  std::queue<std::future<std::optional<T>>> queue_;
  /// Guards stop_, working_ and queue_.
  std::mutex mutex_;
  /// Signalled whenever the queue or the producer state changes.
  std::condition_variable queueCV_;
};
120 
121 /************************ IMPLEMENTATION ***********************/
122 
123 template <typename T>
125  uint8_t nthreads,
126  size_t maxQueueSize,
127  Function&& fn)
128  : maxQueueSize_(maxQueueSize), fn_(fn) {
129  if (maxQueueSize_ == 0 && nthreads > 0) {
130  throw std::runtime_error(
131  "Cannot construct BufferedConsumer with > 0 threads but zero-sized "
132  "queue");
133  }
134  for (auto i = nthreads; i > 0; i--) {
135  threads_.emplace_back(&BufferedConsumer::run, this);
136  }
137 }
138 
139 /// Stops the consumers, discarding any items in the queue
140 template <typename T>
142  {
143  std::lock_guard<std::mutex> lock(mutex_);
144  stop_ = true;
145  }
146  itemReady_.notify_all();
147  itemDone_.notify_all();
148  for (auto& th : threads_) {
149  th.join();
150  }
151 }
152 
153 /// Blocks until the queue is empty or the consumers are stopped
154 template <typename T>
156  std::unique_lock<std::mutex> lock(mutex_);
157  itemDone_.wait(
158  lock, [&] { return stop_ || (queue_.empty() && consuming_ == 0); });
159 }
160 
161 template <typename T>
163  if (!threads_.empty()) {
164  {
165  std::unique_lock<std::mutex> lock(mutex_);
166  itemDone_.wait(
167  lock, [&] { return stop_ || queue_.size() < maxQueueSize_; });
168  if (stop_) {
169  throw std::runtime_error("BufferedConsumer not active");
170  }
171  queue_.push(std::move(arg));
172  }
173  itemReady_.notify_one();
174  } else {
175  {
176  std::unique_lock<std::mutex> lock(mutex_);
177  if (stop_) {
178  throw std::runtime_error("BufferedConsumer not active");
179  }
180  consuming_++;
181  fn_(std::move(arg));
182  consuming_--;
183  }
184  itemDone_.notify_all();
185  }
186 }
187 
188 template <typename T>
190  ASSERT(
191  !threads_.empty(),
192  "Please use BufferedConsumer::enqueue when not using threads");
193  {
194  std::unique_lock<std::mutex> lock(mutex_);
195  if (stop_) {
196  throw std::runtime_error("BufferedConsumer not active");
197  }
198  if (queue_.size() >= maxQueueSize_) {
199  queue_.pop();
200  }
201  queue_.push(std::move(arg));
202  }
203  itemReady_.notify_one();
204 }
205 
206 template <typename T>
208  std::unique_lock<std::mutex> lock(mutex_);
209  while (true) {
210  itemReady_.wait(lock, [&] { return stop_ || !queue_.empty(); });
211  if (stop_) {
212  break;
213  }
214  if (queue_.empty()) {
215  continue;
216  }
217 
218  T item = std::move(queue_.front());
219  queue_.pop();
220 
221  consuming_++;
222  lock.unlock();
223  fn_(std::move(item));
224  lock.lock();
225  consuming_--;
226 
227  // Only remove the item from the queue once it has been consumed
228  // Ideally we'd do the notification without holding the lock, but doing so
229  // we save one lock/unlock cycle. Let's trust implementations to recognize
230  // this scenario (cf.
231  // https://en.cppreference.com/w/cpp/thread/condition_variable/notify_one)
232  itemDone_.notify_all();
233  }
234 }
235 
236 template <typename T>
238  uint8_t nThreads,
239  size_t maxQueueSize,
240  Function&& fn)
241  : maxQueueSize_(maxQueueSize), nThreads_(nThreads) {
242  if (nThreads_ == 0) {
243  throw std::runtime_error("Cannot use a buffered producer with no threads");
244  }
245  if (maxQueueSize == 0) {
246  throw std::runtime_error(
247  "Cannot consturct a BufferedProducer with 0 queue size");
248  }
249  for (auto i = nThreads; i > 0; i--) {
250  threads_.emplace_back(&BufferedProducer::run, this, fn);
251  }
252  running_ = nThreads_;
253 }
254 
255 /// Stops the producers, discarding any items in the queue
256 template <typename T>
258  {
259  std::lock_guard<std::mutex> lock(mutex_);
260  stop_ = true;
261  queueCV_.notify_all();
262  }
263  for (auto& th : threads_) {
264  th.join();
265  }
266 }
267 
268 template <typename T>
269 std::optional<T> BufferedProducer<T>::get() {
270  std::unique_lock<std::mutex> lock(mutex_);
271  queueCV_.wait(
272  lock, [&] { return stop_ || !queue_.empty() || running_ == 0; });
273  if (stop_) {
274  throw std::runtime_error("BufferedProducer not active");
275  }
276  if (running_ == 0 && queue_.empty()) {
277  return std::optional<T>();
278  }
279  auto ret = queue_.front().get();
280  queue_.pop();
281  queueCV_.notify_all();
282  return ret;
283 }
284 
285 template <typename T>
286 void BufferedProducer<T>::run(Function fn) {
287  while (true) {
288  std::unique_lock<std::mutex> lock(mutex_);
289  queueCV_.wait(lock, [&] {
290  return stop_ || queue_.size() + working_ < maxQueueSize_;
291  });
292  if (stop_) {
293  break;
294  }
295 
296  std::promise<std::optional<T>> dataPromise;
297  working_++;
298 
299  lock.unlock();
300  auto result = fn();
301  bool done = (!result.has_value());
302  lock.lock();
303 
304  working_--;
305  if (done) {
306  running_--;
307  queueCV_.notify_all();
308  break;
309  }
310  dataPromise.set_value(std::move(result));
311  queue_.push(dataPromise.get_future());
312  queueCV_.notify_all();
313  }
314 }
315 } // namespace common
void run()
Definition: parallel.h:207
~BufferedConsumer()
Stops the consumers, discarding any items in the queue.
Definition: parallel.h:141
std::optional< T > get()
Definition: parallel.h:269
Function fn_
Definition: parallel.h:73
void wait()
Blocks until the queue is empty or the consumers are stopped.
Definition: parallel.h:155
int64_t consuming_
Definition: parallel.h:72
std::condition_variable itemReady_
Definition: parallel.h:77
std::queue< T > queue_
Definition: parallel.h:75
std::condition_variable itemDone_
Definition: parallel.h:78
void enqueue(T arg)
Adds another item to the work queue, possibly blocking. If the number of threads is zero...
Definition: parallel.h:162
std::mutex mutex_
Definition: parallel.h:76
size_t const maxQueueSize_
Definition: parallel.h:70
void run(Function fn)
Definition: parallel.h:286
General utilities.
Definition: assert.cpp:7
A simple producer class.
Definition: parallel.h:93
T type
Definition: parallel.h:97
BufferedProducer(uint8_t nthreads, size_t maxQueueSize, Function &&fn)
Definition: parallel.h:237
bool stop_
Definition: parallel.h:71
std::vector< std::thread > threads_
Definition: parallel.h:74
A simple producer/consumer class.
Definition: parallel.h:43
BufferedConsumer(uint8_t nthreads, size_t maxQueueSize, Function &&fn)
Definition: parallel.h:124
Request type
Definition: parallel.h:47
~BufferedProducer()
Stops the producers, discarding any items in the queue.
Definition: parallel.h:257
void enqueueOrReplaceOldest(T arg)
Same as 'enqueue', except that if the queue is full, the oldest element will be removed before insert...
Definition: parallel.h:189