TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
blobpubsub.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
#include <common/flags.h>

#include <zmq.hpp>

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
18 
19 namespace cpid {
20 
21 /**
22  * Publisher for ZeroMQ PUB-SUB pattern.
23  *
24  * This server will publish binary blobs at `endpoint()`. The last published
25  * blob will be cached and re-published if new subscribers are joining.
26  *
27  * Published data consists of both a tag and binary data. The tag can be used to
28  * disambiguiate blobs on the subscriber side but does not affect transport.
29  */
30 class BlobPublisher final {
31  public:
32  enum DataFlags {
33  None = 0,
34  HasData = 1 << 0,
35  NewData = 1 << 1,
36  };
37 
39  std::string endpoint = std::string(),
40  std::shared_ptr<zmq::context_t> context = nullptr);
42 
43  std::string endpoint() const;
44 
45  void publish(void const* data, size_t len, int64_t tag);
46  void publish(std::vector<char>&& data, int64_t tag);
47 
48  private:
49  void run(std::string endpoint, std::promise<std::string>&& endpointP);
50 
51  std::shared_ptr<zmq::context_t> context_;
52  mutable std::string endpoint_;
53  mutable std::future<std::string> endpointF_;
54  mutable std::mutex endpointM_;
55  std::thread thread_;
56  std::atomic<bool> stop_{false};
57  int64_t tag_;
58  std::vector<char> data_;
59  DataFlags dflags_ = DataFlags::None;
60  std::mutex dataM_;
61  std::condition_variable dataCV_;
62 };
63 
65 
66 /**
67  * Subscriber for ZeroMQ PUB-SUB pattern.
68  *
69  * This client will subscribe to *one* of the BlobPublisher endpoints specified
70  * and listen for incoming mesages. For each received blob, a user-defined
71  * callback will be called (in the context of the dedicated listening thread).
72  *
73  * Note that due to last-value-caching, the callback might be called multiple
74  * times for the same data and tag as they might be broadcasted multiple times.
75  *
76  * Changing the endpoints via `updateEndpoints()` will trigger endpoint
77  * re-selection, which in turn might trigger re-subscription to a new publisher
78  * endpoint and which in turn will trigger re-broadcasts.
79  */
80 class BlobSubscriber final {
81  public:
82  using CallbackFn =
83  std::function<void(void const* data, size_t len, int64_t tag)>;
84 
85  public:
87  CallbackFn callback,
88  std::vector<std::string> endpoints,
89  std::shared_ptr<zmq::context_t> context = nullptr);
90  ~BlobSubscriber();
91 
92  void updateEndpoints(std::vector<std::string> endpoints);
93 
94  private:
95  void listen();
96 
97  CallbackFn callback_;
98  std::shared_ptr<zmq::context_t> context_;
99  std::vector<std::string> endpoints_;
100  std::mutex endpointsM_;
101  std::thread thread_;
102  std::atomic<bool> stop_{false};
103  std::atomic<bool> endpointsChanged_{false};
104 };
105 
106 } // namespace cpid
std::string endpoint() const
Definition: blobpubsub.cpp:52
NewData
Definition: blobpubsub.h:35
std::function< void(void const *data, size_t len, int64_t tag)> CallbackFn
Definition: blobpubsub.h:83
~BlobPublisher()
Definition: blobpubsub.cpp:46
DEFINE_FLAG_OPERATORS(BlobPublisher::DataFlags)
Publisher for ZeroMQ PUB-SUB pattern.
Definition: blobpubsub.h:30
The TorchCraftAI training library.
Definition: batcher.cpp:15
HasData
Definition: blobpubsub.h:34
BlobPublisher(std::string endpoint=std::string(), std::shared_ptr< zmq::context_t > context=nullptr)
Definition: blobpubsub.cpp:35
None
Definition: blobpubsub.h:33
void publish(void const *data, size_t len, int64_t tag)
Definition: blobpubsub.cpp:61
DataFlags
Definition: blobpubsub.h:32
Subscriber for ZeroMQ PUB-SUB pattern.
Definition: blobpubsub.h:80