TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
reqrepserver.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
#include <atomic>
#include <chrono>
#include <cstddef>
#include <functional>
#include <future>
#include <limits>
#include <memory>
#include <mutex>
#include <queue>
#include <shared_mutex>
#include <string>
#include <thread>
#include <vector>
19 
// Forward declarations only: the ZeroMQ types below are held exclusively
// through smart pointers in this header, which lets it avoid including the
// ZeroMQ C++ headers.
namespace zmq {
class socket_t;
class context_t;
} // namespace zmq
24 
25 namespace cpid {
26 
27 /**
28  * A request-reply server backed by ZeroMQ.
29  *
30  * This server will listen for messages in a dedicated thread and call the
31  * supplied callback function for every incoming request. Note that if
32  * `numThreads` is greater than one, the callback function maybe be called
33  * concurrently from multiple threads.
34  * The callback function will be supplied with a reply function; this function
35  * *must* be called before returning. Failure to do so will return in a fatal
36  * error (i.e. program abort).
37  */
38 class ReqRepServer final {
39  public:
40  using ReplyFn = std::function<void(void const* buf, size_t len)>;
41  using CallbackFn =
42  std::function<void(void const* buf, size_t len, ReplyFn reply)>;
43 
44  /// Constructor.
45  /// This instance will handle up to numThreads replies concurrently.
46  /// If endpoint is an empty string, bind to local IP with automatic port
47  /// selection.
48  /// The callback will be called from dedicated threads.
50  CallbackFn callback,
51  size_t numThreads = 1,
52  std::string endpoint = std::string());
53  ~ReqRepServer();
54 
55  std::string endpoint() const;
56 
57  private:
58  void listen(std::string endpoint, std::promise<std::string>&& endpointP);
59  void runWorker(std::string const& endpoint);
60 
61  CallbackFn callback_;
62  size_t numThreads_;
63  std::shared_ptr<zmq::context_t> context_;
64  std::mutex contextM_;
65  mutable std::string endpoint_;
66  mutable std::future<std::string> endpointF_;
67  mutable std::mutex endpointM_;
68  std::thread thread_;
69 };
70 
71 /**
72  * A request-reply client backed by ZeroMQ.
73  *
74  * This class provides a futures-based interface to the request-reply pattern.
75  * You call request() and get a future fo your (future) reply. Note that
76  * requests() will always happily accept the request and move into in internal
77  * queue. This queue is *unbounded* -- if this is a concern you should add some
78  * manual blocking logic; see ZeroMQBufferedConsumer() for an example.
79  *
80  * The client can be connected to multiple ReqRepServers and will send out
81  * requests in a round-robin fashion. The number of concurrent replies that can
82  * be sent is controlled with the `maxConcurrentRequests` parameter. There are
83  * some basic robustness guarantees regarding slow or crashing servers: if a
84  * server does not send a reply in time, retries will be attempted. The number
85  * of retries can be limited; in this case, the future will be fulfilled with an
86  * exception. The server list can be updated loss of messages.
87  */
88 class ReqRepClient final {
89  public:
90  using Clock = std::chrono::steady_clock;
91  using TimePoint = std::chrono::time_point<Clock>;
92  using Blob = std::vector<char>;
93 
95  size_t maxConcurrentRequests,
96  std::vector<std::string> endpoints,
97  std::shared_ptr<zmq::context_t> context = nullptr);
98  ~ReqRepClient();
99 
100  std::future<std::vector<char>> request(std::vector<char> msg);
101  /// Returns true if the endpoints changed
102  bool updateEndpoints(std::vector<std::string> endpoints);
103 
104  void setReplyTimeout(std::chrono::milliseconds timeout) {
105  setReplyTimeoutMs(timeout.count());
106  }
107  void setReplyTimeoutMs(size_t timeoutMs);
108  void setMaxRetries(size_t count);
109 
110  private:
111  void run();
112 
113  struct QueueItem {
114  Blob msg;
115  std::promise<Blob> promise;
116  size_t retries = 0;
117  QueueItem(Blob msg) : msg(std::move(msg)) {}
118  QueueItem() = default;
119  QueueItem(QueueItem&&) = default;
120  QueueItem& operator=(QueueItem&&) = default;
121  QueueItem(QueueItem const&) = delete;
122  QueueItem& operator=(QueueItem const&) = delete;
123  };
124 
125  std::shared_ptr<zmq::context_t> context_;
126 
127  std::shared_mutex epM_;
128  std::vector<std::string> endpoints_;
129  bool endpointsChanged_ = false;
130 
131  std::mutex queueM_;
132  std::queue<QueueItem> queue_;
133  size_t const maxConcurrentRequests_;
134  std::atomic<size_t> replyTimeoutMs_{10 * 1000};
135  std::atomic<size_t> maxRetries_{std::numeric_limits<size_t>::max()};
136  std::thread thread_;
137  std::atomic<bool> stop_{false};
138  std::string signalEndpoint_;
139  std::unique_ptr<zmq::socket_t> signalSocket_;
140 };
141 
142 } // namespace cpid
A request-reply client backed by ZeroMQ.
Definition: reqrepserver.h:88
std::function< void(void const *buf, size_t len)> ReplyFn
Definition: reqrepserver.h:40
A request-reply server backed by ZeroMQ.
Definition: reqrepserver.h:38
std::function< void(void const *buf, size_t len, ReplyFn reply)> CallbackFn
Definition: reqrepserver.h:42
std::chrono::time_point< Clock > TimePoint
Definition: reqrepserver.h:91
The TorchCraftAI training library.
Definition: batcher.cpp:15
Definition: episodeserver.h:16
std::chrono::steady_clock Clock
Definition: reqrepserver.h:90
void setReplyTimeout(std::chrono::milliseconds timeout)
Definition: reqrepserver.h:104
std::vector< char > Blob
Definition: reqrepserver.h:92