TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
zmqbufferedconsumer.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the LICENSE file
5  * in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "zmqbufferedproducer.h"
11 
12 #include <atomic>
13 
14 namespace cpid {
15 
16 /** A buffered consumer that sends data via ZeroMQ.
17  *
18  * The intended use-case is for this class to be used together with
19  * ZeroMQBufferedConsumer to implement distributed producer-consumer setups.
20  * Suppose you have an existing setup that looks like this, with sections of
21  * your code producing items of type T and other sections consuming them:
22  *
23  * [Producer] -> [Consumer]
24  *
25  * Then, assuming that items can be serialized with cereal, the
26  * ZeroMQBufferedConsumer/Producer classes enable the following design:
27  *
28  * [Producer] -> [ZeroMQBufferedConsumer]
29  * |
30  * TCP
31  * |
32  * [ZeroMQBufferedProducer] -> [Consumer]
33  *
34  * As in common::BufferedConsumer, you specify the number of threads and a queue
35  * size. In addition, you supply a list of end points that
36  * ZeroMQBufferedProducer instances have been bound to. Data will be send to
37  * end-points in a round-robin fashion. If producer endpoints don't accept new
38  * data (because their queue is full and items are not consumed fast enough),
39  * `enqueue()` will eventually block and perform retries.
40  */
41 template <typename T>
43  using Request = std::vector<char>;
44  using Reply = std::vector<char>;
45 
46  public:
48  uint8_t nthreads,
49  size_t maxQueueSize,
50  std::vector<std::string> endpoints,
51  std::shared_ptr<zmq::context_t> context = nullptr);
53 
54  void enqueue(T arg);
55  void updateEndpoints(std::vector<std::string> endpoints);
56 
57  private:
58  size_t const maxConcurrentRequests_;
59  std::list<std::pair<Request, std::future<Reply>>> pending_;
60  std::atomic<bool> stop_{false};
61  ReqRepClient client_;
62  std::unique_ptr<common::BufferedConsumer<Request>> bcsend_;
63  std::unique_ptr<common::BufferedConsumer<T>> bcser_;
64 };
65 
66 template <typename T>
68  uint8_t nthreads,
69  size_t maxQueueSize,
70  std::vector<std::string> endpoints,
71  std::shared_ptr<zmq::context_t> context)
72  : maxConcurrentRequests_(std::min(maxQueueSize, size_t(64))),
73  client_(maxConcurrentRequests_, endpoints, context) {
74  // BufferedConsumer for sending out data. With a single thread, this will
75  // simply run in the calling thread (protected by a mutex).
76  bcsend_ = std::make_unique<common::BufferedConsumer<Request>>(
77  0, 1, [this](Request ca) {
78  // Check pending requests. We'll keep on retrying with an exponential
79  // (bounded) backoff. For this implementation we're painfully reminded
80  // of the rawness of C++11's thread support library: we can't attach
81  // callbacks to futures (i.e. future::then()) or wait for multiple
82  // futures at once.
83  // This is the reason we limit the maximum queue size for ReqRepClient
84  // to 64 (tuned by manual monkeying).
85  int ntry = 0;
86  while (pending_.size() >= maxConcurrentRequests_ && !stop_.load()) {
87  if (ntry++ > 0) {
88  std::this_thread::sleep_for(std::chrono::milliseconds(
89  int(10 * std::pow(2, std::min(ntry, 5)))));
90  }
91 
92  for (auto it = pending_.begin(); it != pending_.end();) {
93  auto& [req, fut] = *it;
94  if (fut.wait_for(std::chrono::seconds(0)) !=
95  std::future_status::ready) {
96  ++it;
97  continue;
98  }
99  Reply reply;
100  try {
101  reply = fut.get();
102  } catch (std::exception const& ex) {
103  // Something failed -- need to resend
104  VLOG(1)
105  << "ZeroMQBufferedConsumer: got exception instead of reply: "
106  << ex.what();
107  auto copy = req;
108  *it = std::make_pair(
109  std::move(copy), client_.request(std::move(req)));
110  ++it;
111  continue;
112  }
113 
114  // Recepient confirmed?
115  if (detail::kConfirm.compare(
116  0, detail::kConfirm.size(), reply.data(), reply.size()) ==
117  0) {
118  it = pending_.erase(it);
119  } else {
120  VLOG(0) << "ZeroMQBufferedConsumer: got non-affirmative "
121  "reply of size "
122  << reply.size() << ", retrying";
123  auto copy = req;
124  *it = std::make_pair(
125  std::move(copy), client_.request(std::move(req)));
126  ++it;
127  }
128  }
129  }
130 
131  auto copy = ca;
132  pending_.emplace_back(std::move(copy), client_.request(std::move(ca)));
133  });
134 
135  // BufferedConsumer for data serialization
136  bcser_ = std::make_unique<common::BufferedConsumer<T>>(
137  nthreads, maxQueueSize, [this](T data) {
138  common::OMembuf buf;
139  {
140  common::zstd::ostream os(&buf);
141  cereal::BinaryOutputArchive ar(os);
142  ar(data);
143  }
144  bcsend_->enqueue(buf.takeData());
145  });
146 }
147 
148 template <typename T>
150  bcser_.reset();
151  stop_.store(true);
152  bcsend_.reset();
153 }
154 
155 template <typename T>
157  bcser_->enqueue(std::move(arg));
158 }
159 
160 template <typename T>
162  std::vector<std::string> endpoints) {
163  client_.updateEndpoints(std::move(endpoints));
164 }
165 
166 } // namespace cpid
A request-reply client backed by ZeroMQ.
Definition: reqrepserver.h:88
bool updateEndpoints(std::vector< std::string > endpoints)
Returns true if the endpoints changed.
Definition: reqrepserver.cpp:268
A stream buffer for writing to an accessible vector of bytes.
Definition: serialization.h:51
void enqueue(T arg)
Definition: zmqbufferedconsumer.h:156
STL namespace.
std::string kConfirm
Definition: zmqbufferedproducer.cpp:12
std::future< std::vector< char > > request(std::vector< char > msg)
Definition: reqrepserver.cpp:261
Output stream for Zstd-compressed data.
Definition: zstdstream.h:140
A buffered consumer that sends data via ZeroMQ.
Definition: zmqbufferedconsumer.h:42
~ZeroMQBufferedConsumer()
Definition: zmqbufferedconsumer.h:149
void updateEndpoints(std::vector< std::string > endpoints)
Definition: zmqbufferedconsumer.h:161
The TorchCraftAI training library.
Definition: batcher.cpp:15
ZeroMQBufferedConsumer(uint8_t nthreads, size_t maxQueueSize, std::vector< std::string > endpoints, std::shared_ptr< zmq::context_t > context=nullptr)
Definition: zmqbufferedconsumer.h:67
std::vector< char > takeData()
Definition: serialization.cpp:26