TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
centraltrainer.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
#include "episodeserver.h"

#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <queue>
#include <shared_mutex>
#include <string>
#include <thread>
13 
14 namespace cpid {
15 
16 /**
17  * A trainer that sends episodes to one or more central instances.
18  *
19  * In this trainer, several "server" instances will collect episode data from
20  * "client" instances. Users are required to subclass this and override
21  * `receivedFrames()`, which will be called on server instances whenever a new
22  * sequence of frames arrives. The trainer can be used like any other trainer,
23  * but ideally there should be no calls to sleep() between `update()` calls to
24  * ensure fast processing of collected episode data.
25  *
26  * Implementation details:
27  * The trainer spawns dedicated threads for servers and clients.
28  * The data that goes over the network (serialized episodes) will be compressed
29  * using Zstandard, so there's no need to use compression to your custom replay
30  * buffer frame structure.
31  * Episode (de)serialization (including (de)compression) will be performed in
32  * the respective thread calling `stepEpisode()` (client) and `update()`
33  * (server).
34  *
35  * TODO: Extend this so that it can be used in RL settings.
36  */
37 class CentralTrainer : public Trainer {
38  public:
40  bool isServer,
41  ag::Container model,
42  ag::Optimizer optim,
43  std::unique_ptr<BaseSampler> sampler,
44  std::unique_ptr<AsyncBatcher> batcher = nullptr);
45  virtual ~CentralTrainer();
46 
47  bool isServer() const {
48  return server_ != nullptr;
49  }
50 
51  virtual void
52  stepFrame(GameUID const&, EpisodeKey const&, ReplayBuffer::Episode&) override;
53  virtual void stepEpisode(
54  GameUID const&,
55  EpisodeKey const&,
56  ReplayBuffer::Episode&) override;
57  ag::Variant forward(ag::Variant inp, EpisodeHandle const&) override;
58  virtual bool update() override;
59  virtual std::shared_ptr<ReplayBufferFrame> makeFrame(
60  ag::Variant trainerOutput,
61  ag::Variant state,
62  float reward) override;
63 
64  virtual EpisodeHandle startEpisode() override;
65  virtual void forceStopEpisode(EpisodeHandle const&) override;
66 
67  std::shared_lock<std::shared_timed_mutex> modelReadLock();
68  std::unique_lock<std::shared_timed_mutex> modelWriteLock();
69 
70  protected:
71  /// Callback for new episodes.
72  /// This will be called from update() for locally and remotely generated
73  /// episodes. The second argument is a unique ID per sequence of frames,
74  /// and should be removed when we get episodes to send to the same
75  /// node. This _depends_ on the deprecation of EpisodeKey
76  virtual void receivedFrames(GameUID const&, std::string const&) = 0;
77 
78  /// Callback for locally generated episode data that can be sent out.
79  /// If this returns `false`, the data will be put into the local replay
80  /// buffer.
81  virtual bool episodeClientEnqueue(EpisodeData const&);
82 
83  /// Allows implementing trainers to send partial episodes
84  /// TODO: set up synchronization so the partial episodes end up
85  /// on the same node.
86  virtual uint32_t getMaxBatchLength() const;
87  virtual uint32_t getSendInterval() const;
88 
89  /// Allows implementing trainers to decide whether to serve end-of-frame
90  /// in the middle of the episode or not.
91  virtual bool serveContinuously() const;
92 
93  /// A protected constructor that doesn't set up a server
95  ag::Container model,
96  ag::Optimizer optim,
97  std::unique_ptr<BaseSampler> sampler,
98  std::unique_ptr<AsyncBatcher> batcher = nullptr);
99 
100  // Each instance either has a server or client
101  std::shared_ptr<EpisodeServer> server_;
102  std::shared_ptr<EpisodeClient> client_;
103  std::thread dequeueEpisodes_;
104  void dequeueEpisodes();
105 
106  std::mutex newGamesMutex_;
107  std::queue<EpisodeTuple> newBatches_;
108 
109  std::shared_timed_mutex modelMutex_;
110  std::atomic<bool> stop_{false};
111 
112  // Some state variables for when we serve continuously
113  class BufferPool;
114  std::unique_ptr<BufferPool> bufferPool_;
115 };
116 
117 } // namespace cpid
std::atomic< bool > stop_
Definition: centraltrainer.h:110
std::string GameUID
Definition: trainer.h:31
ag::Container model() const
Definition: trainer.cpp:231
Definition: episodeserver.h:23
void dequeueEpisodes()
Definition: centraltrainer.cpp:237
virtual void stepFrame(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
Definition: centraltrainer.cpp:118
std::shared_lock< std::shared_timed_mutex > modelReadLock()
Definition: centraltrainer.cpp:229
std::shared_timed_mutex modelMutex_
Definition: centraltrainer.h:109
virtual bool serveContinuously() const
Allows implementing trainers to decide whether to serve end-of-frame in the middle of the episode or ...
Definition: centraltrainer.cpp:280
Definition: trainer.h:158
virtual uint32_t getMaxBatchLength() const
Allows implementing trainers to send partial episodes TODO: set up synchronization so the partial epi...
Definition: centraltrainer.cpp:272
virtual ~CentralTrainer()
Definition: centraltrainer.cpp:110
A trainer that sends episodes to one or more central instances.
Definition: centraltrainer.h:37
CentralTrainer(bool isServer, ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr)
Definition: centraltrainer.cpp:67
std::shared_ptr< EpisodeServer > server_
Definition: centraltrainer.h:101
std::unique_ptr< BufferPool > bufferPool_
Definition: centraltrainer.h:113
ag::Variant forward(ag::Variant inp, EpisodeHandle const &) override
Definition: centraltrainer.cpp:191
std::shared_ptr< EpisodeClient > client_
Definition: centraltrainer.h:102
std::unique_lock< std::shared_timed_mutex > modelWriteLock()
Definition: centraltrainer.cpp:233
The Trainer should be shared amongst multiple different nodes, and attached to a single Module...
Definition: trainer.h:156
virtual std::shared_ptr< ReplayBufferFrame > makeFrame(ag::Variant trainerOutput, ag::Variant state, float reward) override
Definition: centraltrainer.cpp:222
std::mutex newGamesMutex_
Definition: centraltrainer.h:106
std::queue< EpisodeTuple > newBatches_
Definition: centraltrainer.h:107
virtual bool episodeClientEnqueue(EpisodeData const &)
Callback for locally generated episode data that can be sent out.
Definition: centraltrainer.cpp:264
std::string EpisodeKey
Definition: trainer.h:32
virtual void stepEpisode(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
Definition: centraltrainer.cpp:182
virtual uint32_t getSendInterval() const
Definition: centraltrainer.cpp:276
virtual EpisodeHandle startEpisode() override
Returns true if succeeded to register an episode, and false otherwise.
Definition: centraltrainer.cpp:167
ag::Optimizer optim() const
Definition: trainer.cpp:235
The TorchCraftAI training library.
Definition: batcher.cpp:15
std::thread dequeueEpisodes_
Definition: centraltrainer.h:103
virtual bool update() override
Definition: centraltrainer.cpp:198
std::vector< std::shared_ptr< ReplayBufferFrame >> Episode
Definition: trainer.h:89
Definition: centraltrainer.cpp:25
virtual void receivedFrames(GameUID const &, std::string const &)=0
Callback for new episodes.
virtual void forceStopEpisode(EpisodeHandle const &) override
Definition: centraltrainer.cpp:175
bool isServer() const
Definition: centraltrainer.h:47