TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
datareader.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "rand.h"
11 #include "serialization.h"
12 #include "zstdstream.h"
13 
14 #include <glog/logging.h>
15 
16 #include <chrono>
17 #include <future>
18 #include <map>
19 #include <queue>
20 #include <thread>
21 
22 namespace common {
23 
/// Type of the optional per-thread initialization hook; reader and transform
/// worker threads invoke it once before doing any work.
using DataReaderThreadInitF = std::function<void()>;
/// Default initialization hook: does nothing.
const auto DataReader_NoopF = []() -> void {};
26 
27 /**
28  * A multi-threaded iterator that performs decerealization of objects and
29  * returns data in batches.
30  *
31  * Here, batches means instances of std::vector<T>.
32  *
33  * The iterator will read data from all files specified by the list of paths
34  * passed to the constructor *once*. The order of files is retained and will be
35  * reflected in the resulting batches.
36  *
37  * Call `next()` to retrieve the next batch of data. It will block until enough
38  * data is available to form a batch with the requested size. If the iterator
39  * cannot advance anymore (`hasNext()` returns false), the call will throw an
40  * exception. The last batch returned by this function may have fewer than
41  * `batchSize` elements or be empty, depending on the actual data.
42  *
43  * For data files that cannot be decerealized (e.g. because the file cannot be
44  * accessed, or because decerealization failed), the iterator will print a
45  * message via glog but otherwise resume operation as usual.
46  *
47  * You will probably want to use this via DataReader<T>.
48  */
49 template <typename T>
51  static size_t constexpr kMaxBatchesInQueue = 4;
52 
53  public:
55  std::vector<std::string> paths,
56  size_t numThreads,
57  size_t batchSize,
58  std::string prefix = std::string(),
60  : paths_(std::move(paths)),
61  prefix_(std::move(prefix)),
62  batchSize_(batchSize),
63  numThreads_(numThreads),
64  init_(init) {
65  pos_ = 0;
66  threadPos_ = 0;
67  maxQueueSize_ = kMaxBatchesInQueue * batchSize_;
68 
69  // Start reader threads
70  for (size_t i = 0; i < numThreads_; i++) {
71  threads_.emplace_back(&DataReaderIterator::read, this);
72  }
73  }
74 
76 
77  bool hasNext();
78 
79  std::vector<T> next();
80 
81  private:
82  /// Read data from files (to be run in a thread)
83  void read();
84 
85  private:
86  std::vector<std::string> paths_;
87  std::string prefix_;
88  size_t batchSize_;
89  size_t numThreads_;
90  size_t pos_;
91  size_t threadPos_;
92  std::map<size_t, std::future<T>> dataQueue_; // key is offset in paths_
93  size_t maxQueueSize_;
94  std::mutex mutex_;
95  std::condition_variable prodCV_;
96  std::condition_variable consumerCV_;
97  std::vector<std::thread> threads_;
98  std::unordered_set<std::thread::id> threadsDone_;
100 }; // namespace common
101 
102 /**
103  * Wrapper for DataReaderIterator that applies an additional transform to the
104  * resulting batches.
105  *
106  * The transform function will be run in a dedicated thread.
107  */
108 template <typename T, typename F>
110  public:
111  using Result = typename std::result_of<F(std::vector<T> const&)>::type;
112  static size_t constexpr kMaxResultsInQueue = 4;
113 
115  std::unique_ptr<DataReaderIterator<T>>&& it,
116  F function,
118  : it_(std::move(it)), fn_(function), init_(init) {
119  thread_ = std::thread(&DataReaderTransform::run, this);
120  }
121 
123 
124  bool hasNext();
125 
126  Result next();
127 
128  private:
129  void run();
130 
131  private:
132  std::unique_ptr<DataReaderIterator<T>> it_;
133  F fn_;
134  DataReaderThreadInitF init_;
135  std::queue<Result> queue_;
136  std::mutex mutex_;
137  std::condition_variable prodCV_;
138  std::condition_variable consumerCV_;
139  std::thread thread_;
140  bool done_ = false;
141 };
142 
143 template <typename T, typename F>
144 std::unique_ptr<DataReaderTransform<T, F>> makeDataReaderTransform(
145  std::unique_ptr<DataReaderIterator<T>>&& it,
146  F&& function,
148  return std::make_unique<DataReaderTransform<T, F>>(
149  std::move(it), function, init);
150 }
151 
153 
154 /**
155  * A multi-threaded reader for cerealized data.
156  *
157  * This class merely holds a list of paths pointing to files that contain
158  * cerealized versions of T. zstd decompression will transparently work. The
159  * actual multi-threaded data reading will happen in an iterator object that
160  * can be obtained by calling `iterator()`.
161  *
162  * Optionally, a `pathPrefix` can be passed to the constructor which will be
163  * prepended to every element in `paths` before accessing the respective file.
164  *
165  * If a transform function is provided, the iterator will run batches through
166  * the function before returning them (in a dedicated thread).
167  *
168  * Usage example with 4 threads and batch size 32:
169  ```
170 auto reader = makeDataReader<MyDatumType>(fileList, 4, 32);
171 while (training) {
172  reader.shuffle();
173  auto it = reader.iterator();
174  while (it->hasNext()) {
175  auto batch = it->next();
176  // Do work
177  }
178 }
179  ```
180  */
181 template <typename T, typename F = DataReader_NoTransform>
182 class DataReader {
183  public:
184  /// Please use `makeDataReader()` instead
186  std::vector<std::string> paths,
187  size_t numThreads,
188  size_t batchSize,
189  std::string pathPrefix = std::string(),
191  : paths_(std::move(paths)),
192  pathPrefix_(pathPrefix),
193  batchSize_(batchSize),
194  numThreads_(numThreads),
195  init_(init) {}
196 
197  /// Please use `makeDataReader` instead
199  std::vector<std::string> paths,
200  size_t numThreads,
201  size_t batchSize,
202  F transform,
203  std::string pathPrefix = std::string(),
205  : paths_(std::move(paths)),
206  pathPrefix_(pathPrefix),
207  batchSize_(batchSize),
208  numThreads_(numThreads),
209  fn_(transform),
210  init_(init) {}
211 
212  /// Shuffle the list of paths.
213  void shuffle() {
214  std::shuffle(
215  paths_.begin(), paths_.end(), Rand::makeRandEngine<std::mt19937>());
216  }
217 
218  /// Create an iterator that provides multi-threaded data access.
219  /// This function will be available for data readers without transforms.
220  template <typename FF = F>
221  std::unique_ptr<DataReaderIterator<T>> iterator(
222  typename std::enable_if_t<
223  std::is_same<FF, DataReader_NoTransform>::value,
224  bool> = true) {
225  return std::make_unique<DataReaderIterator<T>>(
226  paths_, numThreads_, batchSize_, pathPrefix_, init_);
227  }
228 
229  /// Create an iterator that provides multi-threaded data access.
230  /// This function will be available for data readers with a transform.
231  template <typename FF = F>
232  std::unique_ptr<DataReaderTransform<T, F&>> iterator(
233  typename std::enable_if_t<
234  !std::is_same<FF, DataReader_NoTransform>::value,
235  bool> = true) {
237  std::make_unique<DataReaderIterator<T>>(
238  paths_, numThreads_, batchSize_, pathPrefix_, init_),
239  fn_,
240  init_);
241  }
242 
243  protected:
244  std::vector<std::string> paths_;
245  std::string pathPrefix_;
246  size_t batchSize_;
247  size_t numThreads_;
248  F fn_;
250 };
251 
252 template <typename T>
254  std::vector<std::string> paths,
255  size_t numThreads,
256  size_t batchSize,
257  std::string pathPrefix = std::string(),
259  return DataReader<T>(paths, numThreads, batchSize, pathPrefix, init);
260 }
261 
262 // Desperately waiting for C++17 so we get automatic class template deduction
263 template <typename T, typename F>
265  std::vector<std::string> paths,
266  size_t numThreads,
267  size_t batchSize,
268  F transform,
269  std::string pathPrefix = std::string(),
271  return DataReader<T, F>(
272  paths, numThreads, batchSize, transform, pathPrefix, init);
273 }
274 
275 /**************** IMPLEMENTATIONS ********************/
276 
277 template <typename T>
279  {
280  // Move iterator to end and clear results queue so that threads can
281  // finish their current operation.
282  std::lock_guard<std::mutex> lock(mutex_);
283  maxQueueSize_ = std::max(maxQueueSize_, paths_.size() - threadPos_);
284  dataQueue_.clear();
285  threadPos_ = paths_.size();
286  }
287  prodCV_.notify_all();
288 
289  for (auto& thread : threads_) {
290  thread.join();
291  }
292 }
293 
294 template <typename T>
296  std::unique_lock<std::mutex> lock(mutex_);
297  return (!dataQueue_.empty() || threadPos_ < paths_.size());
298 }
299 
300 template <typename T>
301 std::vector<T> DataReaderIterator<T>::next() {
302  std::unique_lock<std::mutex> lock(mutex_);
303  if (dataQueue_.empty() && threadPos_ >= paths_.size()) {
304  throw std::runtime_error("Data iterator is already at end");
305  }
306 
307  std::vector<T> batch;
308  while (pos_ < paths_.size() && batch.size() < batchSize_) {
309  auto curPos = pos_++;
310 
311  int numAttempts = 0;
312  auto it = dataQueue_.end();
313 
314  consumerCV_.wait(lock, [&] {
315  it = dataQueue_.find(curPos);
316  if (it == dataQueue_.end()) {
317  // The requested datum is not in the queue yet -- wait a while.
318  // If we waited for a while already, increase queue size and hope
319  // that we'll eventually get our requested datum's future.
320  if ((++numAttempts % 5) == 0) {
321  maxQueueSize_ *= 1.5;
322  }
323  return false;
324  }
325 
326  return true;
327  });
328 
329  try {
330  batch.push_back(it->second.get());
331  } catch (std::exception const& e) {
332  LOG(WARNING) << "Cannot query result for datum " << pos_ << ", skipping ("
333  << e.what() << ")";
334  }
335  dataQueue_.erase(it);
336  }
337 
338  // Back to normal
339  maxQueueSize_ = kMaxBatchesInQueue * batchSize_;
340 
341  prodCV_.notify_all();
342  return batch;
343 }
344 
345 template <typename T>
347  init_();
348 
349  while (true) {
350  size_t curPos;
351  std::promise<T> dataPromise;
352  bool done = false;
353 
354  { // Critical section interacting with producer and consumer queue
355  std::unique_lock<std::mutex> lock(mutex_);
356  while (true) {
357  if (threadPos_ >= paths_.size()) {
358  done = true;
359  break;
360  }
361  if (dataQueue_.size() < maxQueueSize_) {
362  curPos = threadPos_++;
363  dataQueue_[curPos] = dataPromise.get_future();
364  break;
365  }
366 
367  // No space in queue, return to waiting. In case the specific datum in
368  // question is blocked here, we notify the consumer (next()) to
369  // provide them a chance for increasing the queue size
370  consumerCV_.notify_one();
371 
372  prodCV_.wait(lock);
373  }
374  }
375  if (done) {
376  break;
377  }
378 
379  std::string filePath;
380  auto const& curPath = paths_[curPos];
381  if (!prefix_.empty() && !curPath.empty() && curPath[0] != '/') {
382  filePath = prefix_ + "/" + curPath;
383  } else {
384  filePath = curPath;
385  }
386 
387  // Read data, fulfill promise from above
388  try {
389  VLOG(4) << "Reading data from " << filePath;
390  zstd::ifstream is(filePath);
391  cereal::BinaryInputArchive archive(is);
392  T d;
393  archive(d);
394  dataPromise.set_value(std::move(d));
395  } catch (std::exception const& e) {
396  VLOG(0) << "Invalid data file " << filePath << ", skipping (" << e.what()
397  << ")";
398  try {
399  dataPromise.set_exception(std::current_exception());
400  } catch (std::exception const& e2) {
401  LOG(ERROR) << "Cannot propagate exception: " << e2.what();
402  }
403  }
404 
405  // Use notify_all() for the (unlikely) case where multiple threads are
406  // waiting in next().
407  consumerCV_.notify_all();
408  }
409 
410  std::unique_lock<std::mutex> lock(mutex_);
411  threadsDone_.insert(std::this_thread::get_id());
412 }
413 
414 template <typename T, typename F>
416  {
417  std::lock_guard<std::mutex> lock(mutex_);
418  done_ = true;
419  }
420  prodCV_.notify_all();
421  thread_.join();
422 }
423 
424 template <typename T, typename F>
426  std::lock_guard<std::mutex> lock(mutex_);
427  return !(queue_.empty() && done_);
428 }
429 
430 template <typename T, typename F>
432  std::unique_lock<std::mutex> lock(mutex_);
433  if (queue_.empty() && done_) {
434  throw std::runtime_error("Data iterator is already at end");
435  }
436 
437  consumerCV_.wait(lock, [&] { return !queue_.empty(); });
438  auto result = std::move(queue_.front());
439  queue_.pop();
440  prodCV_.notify_all();
441  return result;
442 }
443 
444 template <typename T, typename F>
446  init_();
447 
448  std::unique_lock<std::mutex> lock(mutex_);
449  // We want this check to be locked so that when we are at the end of the
450  // iterator, we do not leave the critical section before setting done_ to
451  // true.
452  while (it_->hasNext()) {
453  lock.unlock();
454  auto result = fn_(it_->next());
455  lock.lock();
456 
457  prodCV_.wait(
458  lock, [&] { return done_ || queue_.size() < kMaxResultsInQueue; });
459  if (done_) {
460  break;
461  }
462 
463  queue_.push(std::move(result));
464  consumerCV_.notify_one();
465  }
466  done_ = true;
467 }
468 
469 } // namespace common
Input file stream for Zstd-compressed data.
Definition: zstdstream.h:180
void shuffle()
Shuffle the list of paths.
Definition: datareader.h:213
std::unique_ptr< DataReaderIterator< T > > iterator(typename std::enable_if_t< std::is_same< FF, DataReader_NoTransform >::value, bool >=true)
Create an iterator that provides multi-threaded data access.
Definition: datareader.h:221
auto const DataReader_NoopF
Definition: datareader.h:25
STL namespace.
~DataReaderTransform()
Definition: datareader.h:415
size_t numThreads_
Definition: datareader.h:247
DataReaderIterator(std::vector< std::string > paths, size_t numThreads, size_t batchSize, std::string prefix=std::string(), DataReaderThreadInitF init=DataReader_NoopF)
Definition: datareader.h:54
F fn_
Definition: datareader.h:248
DataReaderTransform(std::unique_ptr< DataReaderIterator< T >> &&it, F function, DataReaderThreadInitF init)
Definition: datareader.h:114
A multi-threaded reader for cerealized data.
Definition: datareader.h:182
std::vector< std::string > paths_
Definition: datareader.h:244
Result next()
Definition: datareader.h:431
typename std::result_of< F(std::vector< T > const &)>::type Result
Definition: datareader.h:111
Definition: datareader.h:152
std::unique_ptr< DataReaderTransform< T, F > > makeDataReaderTransform(std::unique_ptr< DataReaderIterator< T >> &&it, F &&function, DataReaderThreadInitF init=DataReader_NoopF)
Definition: datareader.h:144
std::unique_ptr< DataReaderTransform< T, F & > > iterator(typename std::enable_if_t< !std::is_same< FF, DataReader_NoTransform >::value, bool >=true)
Create an iterator that provides multi-threaded data access.
Definition: datareader.h:232
General utilities.
Definition: assert.cpp:7
size_t batchSize_
Definition: datareader.h:246
std::function< void()> DataReaderThreadInitF
Definition: datareader.h:24
~DataReaderIterator()
Definition: datareader.h:278
std::string pathPrefix_
Definition: datareader.h:245
Wrapper for DataReaderIterator that applies an additional transform to the resulting batches...
Definition: datareader.h:109
DataReader(std::vector< std::string > paths, size_t numThreads, size_t batchSize, std::string pathPrefix=std::string(), DataReaderThreadInitF init=DataReader_NoopF)
Please use makeDataReader() instead.
Definition: datareader.h:185
auto makeDataReader(std::vector< std::string > paths, size_t numThreads, size_t batchSize, std::string pathPrefix=std::string(), DataReaderThreadInitF init=DataReader_NoopF)
Definition: datareader.h:253
std::vector< T > next()
Definition: datareader.h:301
DataReaderThreadInitF init_
Definition: datareader.h:249
DataReader(std::vector< std::string > paths, size_t numThreads, size_t batchSize, F transform, std::string pathPrefix=std::string(), DataReaderThreadInitF init=DataReader_NoopF)
Please use makeDataReader instead.
Definition: datareader.h:198
bool hasNext()
Definition: datareader.h:295
A multi-threaded iterator that performs decerealization of objects and returns data in batches...
Definition: datareader.h:50
bool hasNext()
Definition: datareader.h:425