TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
trainer.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "batcher.h"
11 #include "common/rand.h"
12 #include "metrics.h"
13 
14 #include <shared_mutex>
15 
16 #include <glog/logging.h>
17 
18 #include <autogradpp/autograd.h>
19 
20 /**
21  * The TorchCraftAI training library.
22  */
23 namespace cpid {
// I'm unsure whether this is needed; some higher power needs to decide how to
// hash individual "games" so they can be seen as separate. Games are separate
// from episodes because you can, for example, have multiple "episodes" of a
// Builder (starts when it tries to build something, ends when we know whether
// it succeeded or not) within the same game... For now we can probably just
// ignore it or pass in "" for all cases.
// Identifiers for games and for episodes within a game; both are plain strings.
using GameUID = std::string;
using EpisodeKey = std::string;
// Episode key used when a game only has a single episode.
constexpr auto kDefaultEpisodeKey = "";
class Evaluator;

// Well-known keys for entries in model outputs / sampler results.
inline const std::string kValueKey{"V"};
inline const std::string kQKey{"Q"};
inline const std::string kPiKey{"Pi"};
inline const std::string kSigmaKey{"std"};
inline const std::string kActionQKey{"actionQ"};
inline const std::string kActionKey{"action"};
inline const std::string kPActionKey{"pAction"};
43 
/// Hash functor for std::pair, suitable as the Hash parameter of unordered
/// containers keyed by pairs.
///
/// The element hashes are mixed with the boost::hash_combine formula rather
/// than a plain XOR: XOR is symmetric, so (a, b) and (b, a) would collide, and
/// it maps every (x, x) pair to 0, which degrades bucket distribution.
struct pairhash {
  template <typename T, typename U>
  std::size_t operator()(const std::pair<T, U>& x) const {
    std::size_t seed = std::hash<T>()(x.first);
    // 0x9e3779b9 is the golden-ratio constant used by boost::hash_combine;
    // the shifts make the result depend on element order.
    seed ^= std::hash<U>()(x.second) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};
51 
52 struct EpisodeTuple {
55 };
56 
57 // Creates UIDs for each rank, uint64_t unique ids (will wrap when exhausted)
59 
60 /**
61  * Stub base class for replay buffer frames.
62  *
63  * Frames that are serializable need to implement serialize() or load()/save()
64  * and register themselves in global scope with
65  * `CEREAL_REGISTER_TYPE(MyReplayBufferFrame)` and (depending on how you
66  * serialize) `CEREAL_REGISTER_POLYMORPHIC_RELATION(ReplayBufferFrame,
67  * MyReplayBufferFrame)`
68  */
70  virtual ~ReplayBufferFrame() = default;
71 
72  template <class Archive>
73  void serialize(Archive& ar) {}
74 };
75 
77  explicit RewardBufferFrame(float reward) : reward(reward) {}
78  float reward;
79 };
80 
/**
 * Stores an unordered_map[GameUID] = unordered_map[EpisodeKey, Episode].
 * All the public functions here should be autolocking and therefore relatively
 * thread safe. However, values such as size() may not accurately reflect the
 * contents of the buffer in a multithreaded environment.
 */
87 class ReplayBuffer {
88  public:
89  using Episode = std::vector<std::shared_ptr<ReplayBufferFrame>>;
90  using Store =
91  std::unordered_map<GameUID, std::unordered_map<EpisodeKey, Episode>>;
92  using UIDKeyStore =
93  std::unordered_map<GameUID, std::unordered_set<EpisodeKey>>;
94  using SampleOutput = std::pair<EpisodeTuple, std::reference_wrapper<Episode>>;
95 
96  Episode& append(
97  GameUID uid,
98  EpisodeKey key,
99  std::shared_ptr<ReplayBufferFrame> value,
100  bool isDone = false);
101 
102  std::size_t size() const;
103  std::size_t size(GameUID const&) const;
104  std::size_t sizeDone() const;
105  std::size_t sizeDone(GameUID const&) const;
106  void clear();
107  void erase(GameUID const&, EpisodeKey const& = kDefaultEpisodeKey);
108  std::vector<SampleOutput> getAllEpisodes();
109 
110  // Stupid sample for now, samples uniformly over games and then over episodes
111  // No guarantee of uniqueness :)
112  template <typename RandomGenerator>
113  std::vector<SampleOutput> sample(RandomGenerator& g, uint32_t num = 1);
114  std::vector<SampleOutput> sample(uint32_t num = 1);
115 
116  Episode& get(GameUID const&, EpisodeKey const& = kDefaultEpisodeKey);
117  bool has(GameUID const&, EpisodeKey const& = kDefaultEpisodeKey);
118  bool isDone(GameUID const&, EpisodeKey const& = kDefaultEpisodeKey);
119 
120  protected:
123 
124  template <typename RandomGenerator>
125  SampleOutput sample_(RandomGenerator& g);
126 
127  // Can be replaced with shared_mutex in C++17
128  mutable std::shared_timed_mutex replayerRWMutex_;
129 };
130 
// Forward declarations: this header only refers to these through
// std::unique_ptr.
class BaseSampler;
class AsyncBatcher;

/// Empty tag type. Its lifetime is tracked through smart pointers: the
/// Trainer owns a shared_ptr<HandleGuard> and each EpisodeHandle observes it
/// through a weak_ptr.
struct HandleGuard {};
135 
136 /**
137  * The Trainer should be shared amongst multiple different nodes, and
138  * attached to a single Module.
139  * It consists of an:
140  * - Algorithm, built into the trainer and subclassed via implementing
141  * the stepFrame and stepGame functions. The trainer itself might also
 *   start a separate thread or otherwise have functionality for
143  * syncing weights.
144  * - Model, defined externally subject to algorithm specifications
145  * I don't know of a good way of enforcing model output, except an
146  * incorrect output spec will cause the algorithm to fail. Thus,
 *   please comment new algorithms with their input specification.
148  * - An optimizer, perhaps defined externally. Many of the multinode
149  * algorithms will probably force its own optimizer
150  * - A sampler, responsible for transforming the model output into an
151  * action
152  * - A replay buffer, as implemented below and forced by subclassing the
153  * algorithm. Each algorithm will expect its own replay buffer.
154  *
155  */
156 class Trainer {
157  public:
158  struct EpisodeHandle {
159  EpisodeHandle(Trainer* trainer, GameUID id, EpisodeKey k);
160 
161  // No trainer = no handle
162  EpisodeHandle() = default;
163  explicit operator bool() const;
164 
165  GameUID const& gameID() const;
166  EpisodeKey const& episodeKey() const;
167 
168  ~EpisodeHandle();
169 
170  // Can't copy or assign
171  EpisodeHandle(EpisodeHandle&) = delete;
172  EpisodeHandle(EpisodeHandle const&) = delete;
173  EpisodeHandle& operator=(EpisodeHandle&) = delete;
174  EpisodeHandle& operator=(EpisodeHandle const&) = delete;
175 
176  // Move is possible, and it will invalidate other episode
178  // In move assignment, if we have a existing episode, it's force stopped
179  EpisodeHandle& operator=(EpisodeHandle&&);
180 
181  friend std::ostream& operator<<(std::ostream&, EpisodeHandle const&);
182 
183  private:
184  Trainer* trainer_;
185  GameUID gameID_;
186  EpisodeKey episodeKey_;
187  std::weak_ptr<HandleGuard> guard_;
188  };
189  Trainer(
190  ag::Container model,
191  ag::Optimizer optim,
192  std::unique_ptr<BaseSampler>,
193  std::unique_ptr<AsyncBatcher> batcher = nullptr);
194  virtual ag::Variant forward(ag::Variant inp, EpisodeHandle const&);
195 
196  /// Convenience function when one need to forward a single input. This will
197  /// make it look batched and forward it, so that the model has no problem
198  /// handling it.
199  /// If \param{model} is not provided, will use the trainer's model_.
200  ag::Variant forwardUnbatched(ag::Variant in, ag::Container model = nullptr);
201 
  // Runs the training loop once. The return value is whether the model
  // successfully updated. Sometimes, algorithms will be blocked while waiting
  // for new episodes, in which case update() will return false.
205  virtual bool update() = 0;
206  virtual ~Trainer() = default;
207 
208  void setTrain(bool = true);
209  bool isTrain() const {
210  return train_;
211  }
212  /// Sample using the class' sampler
213  ag::Variant sample(ag::Variant in);
214  ag::Container model() const;
215  ag::Optimizer optim() const;
216  ReplayBuffer& replayBuffer();
217  virtual std::shared_ptr<Evaluator> makeEvaluator(
218  size_t /* how many to run */,
219  std::unique_ptr<BaseSampler> sampler);
220 
221  // These helper functions expose an atomic<bool> that can be convenient
222  // for coordinating training threads. The Trainer itself is not affected by
223  // this.
224  void setDone(bool = true);
225  bool isDone() const {
226  return done_.load();
227  }
228 
229  virtual void step(
230  EpisodeHandle const&,
231  std::shared_ptr<ReplayBufferFrame> v,
232  bool isDone = false);
233  virtual std::shared_ptr<ReplayBufferFrame>
234  makeFrame(ag::Variant trainerOutput, ag::Variant state, float reward) = 0;
  /// Starts a new episode. Returns a handle that converts to true if an
  /// episode was successfully registered, and an empty (false) handle
  /// otherwise. After receiving an empty handle, a worker thread should check
  /// stopping conditions and re-try.
239  virtual EpisodeHandle startEpisode();
240  // For when an episode is not done and you want to remove it from training
241  // because it is corrupted or for some other reason.
242  virtual void forceStopEpisode(EpisodeHandle const&);
243  bool isActive(EpisodeHandle const&);
244  /// Releases all the worker threads so that they can be joined.
245  /// For the off-policy trainers, labels all games as inactive. For the
246  /// on-policy trainers, additionally un-blocks all threads that could be
247  /// waiting at the batch barrier.
248  virtual void reset();
249 
250  template <class Archive>
251  void save(Archive& ar) const;
252  template <class Archive>
253  void load(Archive& ar);
254  template <typename T>
255  bool is() const;
256 
257  Trainer& setMetricsContext(std::shared_ptr<MetricsContext> context);
258  std::shared_ptr<MetricsContext> metricsContext() const;
259 
260  TORCH_ARG(float, noiseStd) = 1e-2;
261  TORCH_ARG(bool, continuousActions) = false;
262 
263  void setBatcher(std::unique_ptr<AsyncBatcher> batcher);
264 
265  protected:
266  virtual void
268  virtual void
270  // Currently, stepGame is not called from anywhere. TODO
271  virtual void stepGame(GameUID const& game){};
272 
273  ag::Container model_;
274  ag::Optimizer optim_;
275  std::shared_ptr<MetricsContext> metricsContext_;
277  bool train_ = true;
278  std::atomic<bool> done_{false};
279  std::mutex modelWriteMutex_;
280  std::shared_timed_mutex activeMapMutex_;
281 
282  std::unique_ptr<BaseSampler> sampler_;
283  std::unique_ptr<AsyncBatcher> batcher_;
284 
285  std::shared_ptr<HandleGuard> epGuard_;
286  template <typename T>
287  std::vector<T const*> cast(ReplayBuffer::Episode const& e);
288 
289  using ForwardFunction =
290  std::function<ag::Variant(ag::Variant, EpisodeHandle const&)>;
  // Protected helper for trainers to use if they want to support evaluation
292  static std::shared_ptr<Evaluator> evaluatorFactory(
293  ag::Container model,
294  std::unique_ptr<BaseSampler> s,
295  size_t n,
296  ForwardFunction func);
297 
299  /// We subsample kFwdMetricsSubsampling of the forward() events
300  /// when measuring their duration
301  static constexpr float kFwdMetricsSubsampling = 0.1;
302 };
304 
305 /********************* IMPLEMENTATIONS *************************/
306 
307 template <typename RandomGenerator>
308 inline std::vector<ReplayBuffer::SampleOutput> ReplayBuffer::sample(
309  RandomGenerator& g,
310  uint32_t num) {
311  std::vector<SampleOutput> samples;
312  for (uint32_t i = 0; i < num; i++) {
313  samples.push_back(sample_(g));
314  }
315  return samples;
316 }
317 
318 template <typename RandomGenerator>
320  std::shared_lock<std::shared_timed_mutex> lock(replayerRWMutex_);
321  if (dones_.size() == 0) {
322  throw std::runtime_error("No finished episodes yet...");
323  }
324  auto& game = *common::select_randomly(dones_.begin(), dones_.end(), g);
325  if (game.second.size() == 0) {
326  LOG(FATAL) << "no episodes in game"; // This shouldn't ever happen...
327  }
328  auto& ep =
329  *common::select_randomly(game.second.begin(), game.second.end(), g);
330  return std::make_pair(
331  EpisodeTuple{game.first, ep}, std::ref(storage_[game.first][ep]));
332 }
333 
334 template <typename T>
335 inline std::vector<T const*> Trainer::cast(ReplayBuffer::Episode const& e) {
336  // TODO might be better to return an iterator instead
337  std::vector<T const*> ret;
338  ret.reserve(e.size());
339  for (auto& elem : e) {
340  ret.push_back(static_cast<T const*>(elem.get()));
341  }
342  return ret;
343 }
344 
345 template <class Archive>
346 inline void Trainer::save(Archive& ar) const {
347  ar(CEREAL_NVP(*model_));
348  ar(CEREAL_NVP(optim_));
349 }
350 
351 template <class Archive>
352 inline void Trainer::load(Archive& ar) {
353  ar(CEREAL_NVP(*model_));
354  ar(CEREAL_NVP(optim_));
355  optim_->add_parameters(model_->parameters());
356 }
357 
358 template <typename T>
359 inline bool Trainer::is() const {
360  return dynamic_cast<const T*>(this) != nullptr;
361 }
362 
364  std::shared_ptr<MetricsContext> context) {
365  metricsContext_ = context;
366  return *this;
367 }
368 
369 inline std::shared_ptr<MetricsContext> Trainer::metricsContext() const {
370  return metricsContext_;
371 }
372 } // namespace cpid
ag::Optimizer optim_
Definition: trainer.h:274
std::string GameUID
Definition: trainer.h:31
bool is() const
Definition: trainer.h:359
std::size_t operator()(const std::pair< T, U > &x) const
Definition: trainer.h:47
Trainer & setMetricsContext(std::shared_ptr< MetricsContext > context)
Definition: trainer.h:363
const std::string kQKey
Definition: trainer.h:37
Trainer::EpisodeHandle EpisodeHandle
Definition: trainer.h:303
Definition: trainer.h:134
Definition: batcher.h:22
std::shared_ptr< MetricsContext > metricsContext_
Definition: trainer.h:275
Definition: trainer.h:158
std::pair< EpisodeTuple, std::reference_wrapper< Episode >> SampleOutput
Definition: trainer.h:94
Definition: evaluator.h:19
Store storage_
Definition: trainer.h:121
float reward
Definition: trainer.h:78
std::shared_ptr< HandleGuard > epGuard_
Definition: trainer.h:285
const std::string kActionKey
Definition: trainer.h:41
std::unique_ptr< AsyncBatcher > batcher_
Definition: trainer.h:283
virtual void stepEpisode(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
Definition: trainer.h:269
SampleOutput sample_(RandomGenerator &g)
Definition: trainer.h:319
std::unordered_map< GameUID, std::unordered_set< EpisodeKey >> UIDKeyStore
Definition: trainer.h:93
const constexpr auto kDefaultEpisodeKey
Definition: trainer.h:33
const std::string kPiKey
Definition: trainer.h:38
std::shared_ptr< MetricsContext > metricsContext() const
Definition: trainer.h:369
EpisodeKey episodeKey
Definition: trainer.h:54
const std::string kValueKey
Definition: trainer.h:36
std::vector< T const * > cast(ReplayBuffer::Episode const &e)
Definition: trainer.h:335
The Trainer should be shared amongst multiple different nodes, and attached to a single Module...
Definition: trainer.h:156
void serialize(Archive &ar)
Definition: trainer.h:73
const std::string kPActionKey
Definition: trainer.h:42
void load(Archive &ar)
Definition: trainer.h:352
ReplayBuffer::UIDKeyStore actives_
Definition: trainer.h:298
Definition: trainer.h:76
const std::string kActionQKey
Definition: trainer.h:40
RewardBufferFrame(float reward)
Definition: trainer.h:77
A sampler takes the output of the model, and outputs an action accordingly.
Definition: sampler.h:19
GameUID genGameUID()
Definition: trainer.cpp:41
std::string EpisodeKey
Definition: trainer.h:32
UIDKeyStore dones_
Definition: trainer.h:122
std::shared_timed_mutex replayerRWMutex_
Definition: trainer.h:128
std::function< ag::Variant(ag::Variant, EpisodeHandle const &)> ForwardFunction
Definition: trainer.h:290
std::unique_ptr< BaseSampler > sampler_
Definition: trainer.h:282
std::vector< SampleOutput > sample(RandomGenerator &g, uint32_t num=1)
Definition: trainer.h:308
bool isTrain() const
Definition: trainer.h:209
The TorchCraftAI training library.
Definition: batcher.cpp:15
Stub base class for replay buffer frames.
Definition: trainer.h:69
Definition: trainer.h:52
std::ostream & operator<<(std::ostream &os, EpisodeHandle const &handle)
Definition: trainer.cpp:379
std::unordered_map< GameUID, std::unordered_map< EpisodeKey, Episode >> Store
Definition: trainer.h:91
std::mutex modelWriteMutex_
Definition: trainer.h:279
GameUID gameID
Definition: trainer.h:53
Definition: trainer.h:44
virtual void stepFrame(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
Definition: trainer.h:267
std::vector< std::shared_ptr< ReplayBufferFrame >> Episode
Definition: trainer.h:89
virtual void stepGame(GameUID const &game)
Definition: trainer.h:271
ReplayBuffer replayer_
Definition: trainer.h:276
bool isDone() const
Definition: trainer.h:225
Stores an unordered_map[GameUID] = unordered_map[int, Episode] All the public functions here should b...
Definition: trainer.h:87
void save(Archive &ar) const
Definition: trainer.h:346
const std::string kSigmaKey
Definition: trainer.h:39
Iter select_randomly(Iter start, Iter end, RandomGenerator &g)
Definition: rand.h:78
std::shared_timed_mutex activeMapMutex_
Definition: trainer.h:280