TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
zmqbufferedproducer.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "distributed.h"
11 #include "reqrepserver.h"
12 
13 #include <common/parallel.h>
14 #include <common/serialization.h>
15 #include <common/zstdstream.h>
16 
17 namespace cpid {
namespace detail {
// Payloads handed to the reply callback after a message arrives:
//   kConfirm - the message was placed in the queue,
//   kDeny    - the queue was full and the message was rejected.
// Both are defined in zmqbufferedproducer.cpp.
extern std::string kConfirm;
extern std::string kDeny;
} // namespace detail
22 
23 /**
24  * A buffered producer that obtains data via ZeroMQ.
25  *
26  * The intended use-case is for this class to be used together with
27  * ZeroMQBufferedConsumer to implement distributed producer-consumer setups.
28  * Suppose you have an existing setup that looks like this, with sections of
29  * your code producing items of type T and other sections consuming them:
30  *
31  * [Producer] -> [Consumer]
32  *
33  * Then, assuming that items can be serialized with cereal, the
34  * ZeroMQBufferedConsumer/Producer classes enable the following design:
35  *
36  * [Producer] -> [ZeroMQBufferedConsumer]
37  * |
38  * TCP
39  * |
40  * [ZeroMQBufferedProducer] -> [Consumer]
41  *
42  * As in common::BufferedProducer you specify a number of threads in the
43  * constructor which will be used to deserialize data. Calling get() returns
44  * data. Destructing the object will stop all threads.
45  *
46  * Make sure that you're calling get() fast enough; if you expect delays for
47  * consumption set maxQueueSize accordingly. If the queue runs full the server
48  * will not accept new data from the network.
49  */
50 template <typename T>
52  public:
54  uint8_t nthreads,
55  size_t maxQueueSize,
56  std::string endpoint = std::string());
58 
59  std::optional<T> get();
60  std::string endpoint() const {
61  return rrs_->endpoint();
62  }
63  void stop();
64 
65  protected:
66  void handleRequest(void const* buf, size_t len, ReqRepServer::ReplyFn reply);
67 
68  private:
69  std::optional<T> produce();
70 
71  std::mutex mutex_;
72  std::condition_variable cv_;
73  std::queue<std::vector<char>> queue_;
74  size_t const maxInQueue_;
75  std::atomic<bool> stop_{false};
76  std::unique_ptr<common::BufferedProducer<T>> bprod_;
77  std::unique_ptr<ReqRepServer> rrs_;
78 };
79 
80 template <typename T>
82  uint8_t nthreads,
83  size_t maxQueueSize,
84  std::string endpoint)
85  : maxInQueue_(maxQueueSize) {
86  bprod_ = std::make_unique<common::BufferedProducer<T>>(
87  nthreads, maxQueueSize, [this] { return produce(); });
88  rrs_ = std::make_unique<ReqRepServer>(
89  [this](void const* buf, size_t len, ReqRepServer::ReplyFn reply) {
90  handleRequest(buf, len, reply);
91  },
92  1,
93  std::move(endpoint));
94 }
95 
96 template <typename T>
98  stop();
99 }
100 
101 template <typename T>
102 std::optional<T> ZeroMQBufferedProducer<T>::get() {
103  return bprod_->get();
104 }
105 
106 template <typename T>
108  stop_.store(true);
109  cv_.notify_all();
110 }
111 
112 template <typename T>
114  void const* buf,
115  size_t len,
116  ReqRepServer::ReplyFn reply) {
117  VLOG(2) << "ZeroMQBufferedProducer: received " << len << " bytes";
118  {
119  std::lock_guard<std::mutex> lock(mutex_);
120  if (queue_.size() >= maxInQueue_) {
121  VLOG(0) << "ZeroMQBufferedProducer: queue is full, cannot accept message";
122  reply(detail::kDeny.c_str(), detail::kDeny.size());
123  return;
124  } else if (queue_.size() > 0) {
125  VLOG(1) << "ZeroMQBufferedProducer: queue size " << queue_.size();
126  }
127  // Place in queue
128  queue_.emplace(
129  static_cast<char const*>(buf), static_cast<char const*>(buf) + len);
130  }
131 
132  // Notify client that we received the message
133  reply(detail::kConfirm.c_str(), detail::kConfirm.size());
134  cv_.notify_one();
135 }
136 
137 template <typename T>
138 std::optional<T> ZeroMQBufferedProducer<T>::produce() {
139  std::unique_lock<std::mutex> lock(mutex_);
140  cv_.wait(lock, [&] { return stop_ || !queue_.empty(); });
141  if (stop_) {
142  return {};
143  }
144 
145  auto data = std::move(queue_.front());
146  queue_.pop();
147  lock.unlock();
148 
149  common::IMembuf buf(data);
150  common::zstd::istream is(&buf);
151  cereal::BinaryInputArchive ar(is);
152  T item;
153  ar(item);
154  return item;
155 }
156 
157 } // namespace cpid
std::string kDeny
Definition: zmqbufferedproducer.cpp:13
std::function< void(void const *buf, size_t len)> ReplyFn
Definition: reqrepserver.h:40
std::optional< T > get()
Definition: zmqbufferedproducer.h:102
~ZeroMQBufferedProducer()
Definition: zmqbufferedproducer.h:97
std::string kConfirm
Definition: zmqbufferedproducer.cpp:12
A buffered producer that obtains data via ZeroMQ.
Definition: zmqbufferedproducer.h:51
The TorchCraftAI training library.
Definition: batcher.cpp:15
std::string endpoint() const
Definition: zmqbufferedproducer.h:60
void stop()
Definition: zmqbufferedproducer.h:107
A stream buffer for reading from a vector of bytes.
Definition: serialization.h:33
ZeroMQBufferedProducer(uint8_t nthreads, size_t maxQueueSize, std::string endpoint=std::string())
Definition: zmqbufferedproducer.h:81
Definition: zstdstream.h:130
void handleRequest(void const *buf, size_t len, ReqRepServer::ReplyFn reply)
Definition: zmqbufferedproducer.h:113