TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cpid2kworker.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "distributed.h"
11 #include <nlohmann/json.hpp>
12 
13 #include <mutex>
14 #include <thread>
15 
16 namespace cpid {
17 
18 class RedisClient;
19 
23 
24  bool roleIs(std::string_view role);
25 
/// Passes the `id`, `host` and `services` members, in that order, to the
/// supplied archive as name-value pairs (via CEREAL_NVP).
/// NOTE(review): `Archive` is presumably a cereal archive type -- confirm.
26  template <typename Archive>
27  void serialize(Archive& ar) {
28  ar(CEREAL_NVP(id));
29  ar(CEREAL_NVP(host));
30  ar(CEREAL_NVP(services));
31  }
32 
33  /// Worker ID
34  std::string id;
35  /// IP address of the machine this process is running on.
36  std::string host;
37  /// Services offered by this worker (name to port number)
38  std::map<std::string, int> services;
39 };
40 
41 /**
42  * Periodically sends out heartbeats to a Redis instance.
43  *
44  * The supplied Cpid2kWorkerInfo will be sent as the heartbeat value to the
45  * database. In addition, during construction this class will ensure that
46  * startup can be performed according to the scheduler (see the supplied Redis
47  * schema). If not, the constructor will throw.
48  */
50  public:
52  Cpid2kWorkerInfo info,
53  std::string prefix,
54  std::string_view host,
55  int port,
56  int64_t intevalMs = 10 * 1000);
58 
59  /// Returns true if the worker is considered dead by the scheduler.
60  /// In this case, the worker should abort its execution.
61  bool consideredDead() const;
62 
/// Returns the stored `intervalMs_` value.
/// NOTE(review): presumably the period between heartbeats, in milliseconds
/// (per the name and the constructor's interval argument) -- confirm.
63  int64_t intervalMs() const {
64  return intervalMs_;
65  }
66 
67  private:
68  std::string bootKey() const;
69  std::string deadKey() const;
70  std::string heartBeatKey() const;
71  std::string heartBeatData() const;
72  void boot();
73  void run();
74 
75  Cpid2kWorkerInfo info_;
76  std::string prefix_;
77  int64_t intervalMs_;
78  std::unique_ptr<RedisClient> redis_;
79  std::thread th_;
80  std::atomic<bool> stop_{false};
81  std::atomic<bool> consideredDead_{false};
82 };
83 
84 /**
85  * Encapsulates information about the participating peers in a cpid2k job.
86  */
88  public:
89  using Clock = std::chrono::steady_clock;
90 
91  Cpid2kGlobalState(std::string prefix, int64_t updateIntervalMs = 5 * 1000);
92 
93  void update(RedisClient& client);
94 
/// Returns a non-owning view of the `prefix_` member. The view refers to
/// storage owned by this object, so it must not be used after the object is
/// destroyed.
95  std::string_view prefix() const {
96  return prefix_;
97  }
98  bool isDone();
99  std::vector<Cpid2kWorkerInfo> peers(std::string_view role);
100  std::vector<std::string> serviceEndpoints(std::string const& serviceName);
101 
102  private:
103  void tryUpdate(RedisClient& client);
104 
105  std::string prefix_;
106  std::mutex mutex_;
107  int64_t peerv_ = -1; // Version number for peer information
108  Clock::time_point lastPeersCheck_;
109  std::chrono::milliseconds pcInterval_;
110  std::vector<Cpid2kWorkerInfo> peers_;
111  bool isDone_ = false;
112 };
113 
114 /**
115  * Helper class for job coordination via a central Redis instance.
116  *
117  * In a nutshell, the Cpid2kWorker class does the following:
118  * - Communicate local job status to the scheduler via Cpid2kHeartbeater (ctor
119  * will throw if job does not have start permission).
120  * - Provide basic information about global job status (`peers()`, `isDone()`,
121  * etc.) and local status as seen by the scheduler (`consideredDead()`).
122  * - Convenience functions for common operations (`dcontext()`,
123  * `waitForOne/All()`, etc.)
124  *
125  * For manual operations on the Redis database, use `threadLocalClient()` to
126  * obtain a RedisClient instance for the current thread. Note that these will be
127  * re-used.
128  *
129  * All public functions are thread-safe, i.e. it's alright to call them from
130  * several trainer or game threads.
131  */
133  public:
134  using Clock = std::chrono::steady_clock;
135  static std::string const kAnyRole;
136  static std::chrono::milliseconds const kNoTimeout;
137  static std::chrono::milliseconds const kDefaultTimeout;
138 
139  Cpid2kWorker(
140  Cpid2kWorkerInfo info,
141  std::string prefix,
142  std::string host,
143  int port = 6379,
144  int64_t hbIntervalMs = 10 * 1000);
145  ~Cpid2kWorker();
146 
147  static std::unique_ptr<Cpid2kWorker> fromEnvVars(Cpid2kWorkerInfo const&);
148  static std::unique_ptr<Cpid2kWorker> fromEnvVars();
149 
150  Cpid2kWorkerInfo const& info() const;
151  std::string_view prefix() const;
152  bool consideredDead() const;
153  bool isDone();
154  std::string redisKey(std::string_view key) const;
155  std::shared_ptr<RedisClient> threadLocalClient();
157  return hb_;
158  }
159 
160  distributed::Context& dcontext(
161  std::string const& role = kAnyRole,
162  std::chrono::milliseconds timeout = kDefaultTimeout);
163  void discardDContext(std::string const& role = kAnyRole);
164 
165  std::vector<Cpid2kWorkerInfo> peers(std::string_view role = kAnyRole);
166  std::vector<std::string> serviceEndpoints(std::string const& serviceName);
167  bool waitForOne(
168  std::string_view role,
169  std::chrono::milliseconds timeout = kNoTimeout);
170  bool waitForAll(
171  std::string_view role,
172  std::chrono::milliseconds timeout = kNoTimeout);
173  void appendMetrics(std::string_view metricsName, nlohmann::json const& json);
174 
175  private:
176  std::shared_ptr<RedisClient> redisClient(std::thread::id id);
177  int numWorkersWithRoleInSpec(std::string_view role);
178 
179  std::mutex mutex_; // General mutex to make functions thread-safe
180  Cpid2kWorkerInfo info_;
181  std::string prefix_;
182  std::string host_;
183  int port_;
184  Cpid2kHeartBeater hb_;
185  Cpid2kGlobalState gs_;
186  std::chrono::milliseconds pcInterval_;
187  std::unordered_map<std::string, std::unique_ptr<distributed::Context>>
188  dcontexts_;
189  std::unordered_map<std::string, std::vector<std::string>> dcontextIds_;
190  std::unordered_map<std::thread::id, std::shared_ptr<RedisClient>>
191  threadClients_;
192 };
193 
194 /**
195  * Helper class to aggregate metrics locally, and send them regularly as events
196  * in the redis database as key 'prefix:metricEvents'
197  */
199  public:
206  };
207  struct EventMetric {
208  template <typename T>
210  std::string n,
211  T v,
212  AggregationType a = AggregateMean,
213  typename std::enable_if_t<std::is_arithmetic<T>::value>* = 0)
214  : name(std::move(n)), value(float(v)), aggregation(a) {}
215  virtual ~EventMetric() = default;
216  std::string name;
217  float value;
219  };
/// Abstract interface for accumulating float samples: `add()` feeds one
/// sample in, `value()` reports the current result as JSON.
/// NOTE(review): presumably one instance is kept per metric name -- confirm.
220  struct Aggregator {
221  virtual ~Aggregator() = default;
/// Incorporates a single sample into the aggregate.
222  virtual void add(float value) = 0;
/// Returns the current aggregate encoded as JSON.
223  virtual nlohmann::json value() const = 0;
/// Non-owning label. NOTE(review): presumably names the aggregation kind and
/// must refer to storage that outlives this object -- confirm.
224  std::string_view type;
225  };
226 
228  std::shared_ptr<Cpid2kWorker> worker,
229  std::chrono::milliseconds sendInterval = std::chrono::seconds(30));
230  ~Cpid2kMetrics();
231  void push(std::vector<EventMetric> const& metrics);
232 
233  protected:
234  void run();
235  using Clock = std::chrono::steady_clock;
236 
237  std::shared_ptr<Cpid2kWorker> worker_;
238  std::chrono::milliseconds sendInterval_;
239 
240  std::thread thr_;
241  std::atomic<bool> stop_;
242 
243  std::mutex aggregatorsMutex_;
244  std::unordered_map<std::string, std::unique_ptr<Aggregator>> aggregators_;
245 };
246 
247 } // namespace cpid
std::chrono::steady_clock Clock
Definition: cpid2kworker.h:89
Definition: cpid2kworker.h:203
std::chrono::steady_clock Clock
Definition: cpid2kworker.h:134
std::thread thr_
Definition: cpid2kworker.h:240
Cpid2kHeartBeater & heartBeater()
Definition: cpid2kworker.h:156
bool roleIs(std::string_view role)
Definition: cpid2kworker.cpp:62
std::shared_ptr< Cpid2kWorker > worker_
Definition: cpid2kworker.h:237
static std::string const kAnyRole
Definition: cpid2kworker.h:135
std::chrono::milliseconds sendInterval_
Definition: cpid2kworker.h:238
static std::chrono::milliseconds const kDefaultTimeout
Definition: cpid2kworker.h:137
Definition: distributed.h:108
STL namespace.
Definition: cpid2kworker.h:205
std::string host
IP address of the machine this process is running on.
Definition: cpid2kworker.h:36
static std::chrono::milliseconds const kNoTimeout
Definition: cpid2kworker.h:136
AggregationType
Definition: cpid2kworker.h:200
std::string_view prefix() const
Definition: cpid2kworker.h:95
Definition: cpid2kworker.h:204
static Cpid2kWorkerInfo withLocalIp()
Definition: cpid2kworker.cpp:50
std::string name
Definition: cpid2kworker.h:216
Helper class to aggregate metrics locally, and send them regularly as events in the redis database as ...
Definition: cpid2kworker.h:198
std::atomic< bool > stop_
Definition: cpid2kworker.h:241
Definition: cpid2kworker.h:20
Periodically sends out heartbeats to a Redis instance.
Definition: cpid2kworker.h:49
std::chrono::steady_clock Clock
Definition: cpid2kworker.h:235
Definition: cpid2kworker.h:207
std::string id
Worker ID.
Definition: cpid2kworker.h:34
Simple, synchronous C++ wrapper for the Hiredis Redis client.
Definition: redisclient.h:30
The TorchCraftAI training library.
Definition: batcher.cpp:15
Definition: cpid2kworker.h:220
Definition: cpid2kworker.h:201
void serialize(Archive &ar)
Definition: cpid2kworker.h:27
std::map< std::string, int > services
Services offered by this worker (name to port number)
Definition: cpid2kworker.h:38
std::mutex aggregatorsMutex_
Definition: cpid2kworker.h:243
Encapsulates information about the participating peers in a cpid2k job.
Definition: cpid2kworker.h:87
static Cpid2kWorkerInfo withLocalIpFromEnvVars()
Definition: cpid2kworker.cpp:56
EventMetric(std::string n, T v, AggregationType a=AggregateMean, typename std::enable_if_t< std::is_arithmetic< T >::value > *=0)
Definition: cpid2kworker.h:209
AggregationType aggregation
Definition: cpid2kworker.h:218
float value
Definition: cpid2kworker.h:217
int64_t intervalMs() const
Definition: cpid2kworker.h:63
Helper class for job coordination via a central Redis instance.
Definition: cpid2kworker.h:132
std::string_view type
Definition: cpid2kworker.h:224
Definition: cpid2kworker.h:202
std::unordered_map< std::string, std::unique_ptr< Aggregator > > aggregators_
Definition: cpid2kworker.h:244