10 #include "episodeserver.h" 12 #include <shared_mutex> 43 std::unique_ptr<BaseSampler> sampler,
44 std::unique_ptr<AsyncBatcher> batcher =
nullptr);
58 virtual bool update()
override;
59 virtual std::shared_ptr<ReplayBufferFrame>
makeFrame(
60 ag::Variant trainerOutput,
62 float reward)
override;
97 std::unique_ptr<BaseSampler> sampler,
98 std::unique_ptr<AsyncBatcher> batcher =
nullptr);
std::atomic< bool > stop_
Definition: centraltrainer.h:110
std::string GameUID
Definition: trainer.h:31
ag::Container model() const
Definition: trainer.cpp:231
Definition: episodeserver.h:23
void dequeueEpisodes()
Definition: centraltrainer.cpp:237
virtual void stepFrame(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
Definition: centraltrainer.cpp:118
std::shared_lock< std::shared_timed_mutex > modelReadLock()
Definition: centraltrainer.cpp:229
std::shared_timed_mutex modelMutex_
Definition: centraltrainer.h:109
virtual bool serveContinuously() const
Allows implementing trainers to decide whether to serve end-of-frame in the middle of the episode or ...
Definition: centraltrainer.cpp:280
Definition: trainer.h:158
virtual uint32_t getMaxBatchLength() const
Allows implementing trainers to send partial episodes TODO: set up synchronization so the partial epi...
Definition: centraltrainer.cpp:272
virtual ~CentralTrainer()
Definition: centraltrainer.cpp:110
A trainer that sends episodes to one or more central instances.
Definition: centraltrainer.h:37
CentralTrainer(bool isServer, ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr)
Definition: centraltrainer.cpp:67
std::shared_ptr< EpisodeServer > server_
Definition: centraltrainer.h:101
std::unique_ptr< BufferPool > bufferPool_
Definition: centraltrainer.h:113
ag::Variant forward(ag::Variant inp, EpisodeHandle const &) override
Definition: centraltrainer.cpp:191
std::shared_ptr< EpisodeClient > client_
Definition: centraltrainer.h:102
std::unique_lock< std::shared_timed_mutex > modelWriteLock()
Definition: centraltrainer.cpp:233
The Trainer should be shared amongst multiple different nodes, and attached to a single Module...
Definition: trainer.h:156
virtual std::shared_ptr< ReplayBufferFrame > makeFrame(ag::Variant trainerOutput, ag::Variant state, float reward) override
Definition: centraltrainer.cpp:222
std::mutex newGamesMutex_
Definition: centraltrainer.h:106
std::queue< EpisodeTuple > newBatches_
Definition: centraltrainer.h:107
virtual bool episodeClientEnqueue(EpisodeData const &)
Callback for locally generated episode data that can be sent out.
Definition: centraltrainer.cpp:264
std::string EpisodeKey
Definition: trainer.h:32
virtual void stepEpisode(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
Definition: centraltrainer.cpp:182
virtual uint32_t getSendInterval() const
Definition: centraltrainer.cpp:276
virtual EpisodeHandle startEpisode() override
Returns true if succeeded to register an episode, and false otherwise.
Definition: centraltrainer.cpp:167
ag::Optimizer optim() const
Definition: trainer.cpp:235
The TorchCraftAI training library.
Definition: batcher.cpp:15
std::thread dequeueEpisodes_
Definition: centraltrainer.h:103
virtual bool update() override
Definition: centraltrainer.cpp:198
std::vector< std::shared_ptr< ReplayBufferFrame >> Episode
Definition: trainer.h:89
Definition: centraltrainer.cpp:25
virtual void receivedFrames(GameUID const &, std::string const &)=0
Callback for new episodes.
virtual void forceStopEpisode(EpisodeHandle const &) override
Definition: centraltrainer.cpp:175
bool isServer() const
Definition: centraltrainer.h:47