11 #include "serialization.h" 12 #include "zstdstream.h" 14 #include <glog/logging.h> 51 static size_t constexpr kMaxBatchesInQueue = 4;
55 std::vector<std::string> paths,
58 std::string prefix = std::string(),
60 : paths_(
std::move(paths)),
61 prefix_(
std::move(prefix)),
62 batchSize_(batchSize),
63 numThreads_(numThreads),
67 maxQueueSize_ = kMaxBatchesInQueue * batchSize_;
70 for (
size_t i = 0; i < numThreads_; i++) {
71 threads_.emplace_back(&DataReaderIterator::read,
this);
79 std::vector<T>
next();
86 std::vector<std::string> paths_;
92 std::map<size_t, std::future<T>> dataQueue_;
95 std::condition_variable prodCV_;
96 std::condition_variable consumerCV_;
97 std::vector<std::thread> threads_;
98 std::unordered_set<std::thread::id> threadsDone_;
108 template <
typename T,
typename F>
111 using Result =
typename std::result_of<F(std::vector<T>
const&)>::type;
112 static size_t constexpr kMaxResultsInQueue = 4;
118 : it_(
std::move(it)), fn_(function), init_(init) {
119 thread_ = std::thread(&DataReaderTransform::run,
this);
132 std::unique_ptr<DataReaderIterator<T>> it_;
135 std::queue<Result> queue_;
137 std::condition_variable prodCV_;
138 std::condition_variable consumerCV_;
143 template <
typename T,
typename F>
148 return std::make_unique<DataReaderTransform<T, F>>(
149 std::move(it),
function, init);
181 template <
typename T,
typename F = DataReader_NoTransform>
186 std::vector<std::string> paths,
189 std::string pathPrefix = std::string(),
191 : paths_(
std::move(paths)),
192 pathPrefix_(pathPrefix),
193 batchSize_(batchSize),
194 numThreads_(numThreads),
199 std::vector<std::string> paths,
203 std::string pathPrefix = std::string(),
205 : paths_(
std::move(paths)),
206 pathPrefix_(pathPrefix),
207 batchSize_(batchSize),
208 numThreads_(numThreads),
215 paths_.begin(), paths_.end(), Rand::makeRandEngine<std::mt19937>());
220 template <
typename FF = F>
222 typename std::enable_if_t<
223 std::is_same<FF, DataReader_NoTransform>::value,
225 return std::make_unique<DataReaderIterator<T>>(
226 paths_, numThreads_, batchSize_, pathPrefix_, init_);
231 template <
typename FF = F>
232 std::unique_ptr<DataReaderTransform<T, F&>>
iterator(
233 typename std::enable_if_t<
234 !std::is_same<FF, DataReader_NoTransform>::value,
238 paths_, numThreads_, batchSize_, pathPrefix_, init_),
252 template <
typename T>
254 std::vector<std::string> paths,
257 std::string pathPrefix = std::string(),
259 return DataReader<T>(paths, numThreads, batchSize, pathPrefix, init);
263 template <
typename T,
typename F>
265 std::vector<std::string> paths,
269 std::string pathPrefix = std::string(),
272 paths, numThreads, batchSize, transform, pathPrefix, init);
277 template <
typename T>
282 std::lock_guard<std::mutex> lock(mutex_);
283 maxQueueSize_ = std::max(maxQueueSize_, paths_.size() - threadPos_);
285 threadPos_ = paths_.size();
287 prodCV_.notify_all();
289 for (
auto& thread : threads_) {
294 template <
typename T>
296 std::unique_lock<std::mutex> lock(mutex_);
297 return (!dataQueue_.empty() || threadPos_ < paths_.size());
300 template <
typename T>
302 std::unique_lock<std::mutex> lock(mutex_);
303 if (dataQueue_.empty() && threadPos_ >= paths_.size()) {
304 throw std::runtime_error(
"Data iterator is already at end");
307 std::vector<T> batch;
308 while (pos_ < paths_.size() && batch.size() < batchSize_) {
309 auto curPos = pos_++;
312 auto it = dataQueue_.end();
314 consumerCV_.wait(lock, [&] {
315 it = dataQueue_.find(curPos);
316 if (it == dataQueue_.end()) {
320 if ((++numAttempts % 5) == 0) {
321 maxQueueSize_ *= 1.5;
330 batch.push_back(it->second.get());
331 }
catch (std::exception
const& e) {
332 LOG(WARNING) <<
"Cannot query result for datum " << pos_ <<
", skipping (" 335 dataQueue_.erase(it);
339 maxQueueSize_ = kMaxBatchesInQueue * batchSize_;
341 prodCV_.notify_all();
345 template <
typename T>
351 std::promise<T> dataPromise;
355 std::unique_lock<std::mutex> lock(mutex_);
357 if (threadPos_ >= paths_.size()) {
361 if (dataQueue_.size() < maxQueueSize_) {
362 curPos = threadPos_++;
363 dataQueue_[curPos] = dataPromise.get_future();
370 consumerCV_.notify_one();
379 std::string filePath;
380 auto const& curPath = paths_[curPos];
381 if (!prefix_.empty() && !curPath.empty() && curPath[0] !=
'/') {
382 filePath = prefix_ +
"/" + curPath;
389 VLOG(4) <<
"Reading data from " << filePath;
391 cereal::BinaryInputArchive archive(is);
394 dataPromise.set_value(std::move(d));
395 }
catch (std::exception
const& e) {
396 VLOG(0) <<
"Invalid data file " << filePath <<
", skipping (" << e.what()
399 dataPromise.set_exception(std::current_exception());
400 }
catch (std::exception
const& e2) {
401 LOG(ERROR) <<
"Cannot propagate exception: " << e2.what();
407 consumerCV_.notify_all();
410 std::unique_lock<std::mutex> lock(mutex_);
411 threadsDone_.insert(std::this_thread::get_id());
414 template <
typename T,
typename F>
417 std::lock_guard<std::mutex> lock(mutex_);
420 prodCV_.notify_all();
424 template <
typename T,
typename F>
426 std::lock_guard<std::mutex> lock(mutex_);
427 return !(queue_.empty() && done_);
430 template <
typename T,
typename F>
432 std::unique_lock<std::mutex> lock(mutex_);
433 if (queue_.empty() && done_) {
434 throw std::runtime_error(
"Data iterator is already at end");
437 consumerCV_.wait(lock, [&] {
return !queue_.empty(); });
438 auto result = std::move(queue_.front());
440 prodCV_.notify_all();
444 template <
typename T,
typename F>
448 std::unique_lock<std::mutex> lock(mutex_);
452 while (it_->hasNext()) {
454 auto result = fn_(it_->next());
458 lock, [&] {
return done_ || queue_.size() < kMaxResultsInQueue; });
463 queue_.push(std::move(result));
464 consumerCV_.notify_one();
Input file stream for Zstd-compressed data.
Definition: zstdstream.h:180
void shuffle()
Shuffle the list of paths.
Definition: datareader.h:213
std::unique_ptr< DataReaderIterator< T > > iterator(typename std::enable_if_t< std::is_same< FF, DataReader_NoTransform >::value, bool >=true)
Create an iterator that provides multi-threaded data access.
Definition: datareader.h:221
auto const DataReader_NoopF
Definition: datareader.h:25
size_t numThreads_
Definition: datareader.h:247
DataReaderIterator(std::vector< std::string > paths, size_t numThreads, size_t batchSize, std::string prefix=std::string(), DataReaderThreadInitF init=DataReader_NoopF)
Definition: datareader.h:54
F fn_
Definition: datareader.h:248
A multi-threaded reader for cerealized data.
Definition: datareader.h:182
std::vector< std::string > paths_
Definition: datareader.h:244
std::unique_ptr< DataReaderTransform< T, F > > makeDataReaderTransform(std::unique_ptr< DataReaderIterator< T >> &&it, F &&function, DataReaderThreadInitF init=DataReader_NoopF)
Definition: datareader.h:144
std::unique_ptr< DataReaderTransform< T, F & > > iterator(typename std::enable_if_t< !std::is_same< FF, DataReader_NoTransform >::value, bool >=true)
Create an iterator that provides multi-threaded data access.
Definition: datareader.h:232
General utilities.
Definition: assert.cpp:7
size_t batchSize_
Definition: datareader.h:246
std::function< void()> DataReaderThreadInitF
Definition: datareader.h:24
~DataReaderIterator()
Definition: datareader.h:278
std::string pathPrefix_
Definition: datareader.h:245
DataReader(std::vector< std::string > paths, size_t numThreads, size_t batchSize, std::string pathPrefix=std::string(), DataReaderThreadInitF init=DataReader_NoopF)
Please use makeDataReader() instead.
Definition: datareader.h:185
auto makeDataReader(std::vector< std::string > paths, size_t numThreads, size_t batchSize, std::string pathPrefix=std::string(), DataReaderThreadInitF init=DataReader_NoopF)
Definition: datareader.h:253
std::vector< T > next()
Definition: datareader.h:301
DataReaderThreadInitF init_
Definition: datareader.h:249
DataReader(std::vector< std::string > paths, size_t numThreads, size_t batchSize, F transform, std::string pathPrefix=std::string(), DataReaderThreadInitF init=DataReader_NoopF)
Please use makeDataReader() instead.
Definition: datareader.h:198
bool hasNext()
Definition: datareader.h:295
A multi-threaded iterator that performs decerealization of objects and returns data in batches...
Definition: datareader.h:50