TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
|
A trainer that sends episodes to one or more central instances. More...
#include <centraltrainer.h>
Inherits cpid::Trainer.
Inherited by cpid::CentralCpid2kTrainer.
Classes | |
class | BufferPool |
Public Member Functions | |
CentralTrainer (bool isServer, ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr) | |
virtual | ~CentralTrainer () |
bool | isServer () const |
virtual void | stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override |
virtual void | stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override |
ag::Variant | forward (ag::Variant inp, EpisodeHandle const &) override |
virtual bool | update () override |
virtual std::shared_ptr< ReplayBufferFrame > | makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override |
virtual EpisodeHandle | startEpisode () override |
Returns true if an episode was successfully registered, and false otherwise. More... | |
virtual void | forceStopEpisode (EpisodeHandle const &) override |
std::shared_lock< std::shared_timed_mutex > | modelReadLock () |
std::unique_lock< std::shared_timed_mutex > | modelWriteLock () |
Public Member Functions inherited from cpid::Trainer | |
Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr) | |
ag::Variant | forwardUnbatched (ag::Variant in, ag::Container model=nullptr) |
Convenience function when one needs to forward a single input. More... | |
virtual | ~Trainer ()=default |
void | setTrain (bool=true) |
bool | isTrain () const |
ag::Variant | sample (ag::Variant in) |
Sample using the class' sampler. More... | |
ag::Container | model () const |
ag::Optimizer | optim () const |
ReplayBuffer & | replayBuffer () |
virtual std::shared_ptr< Evaluator > | makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler) |
void | setDone (bool=true) |
bool | isDone () const |
virtual void | step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false) |
bool | isActive (EpisodeHandle const &) |
virtual void | reset () |
Releases all the worker threads so that they can be joined. More... | |
template<class Archive > | |
void | save (Archive &ar) const |
template<class Archive > | |
void | load (Archive &ar) |
template<typename T > | |
bool | is () const |
Trainer & | setMetricsContext (std::shared_ptr< MetricsContext > context) |
std::shared_ptr< MetricsContext > | metricsContext () const |
TORCH_ARG (float, noiseStd) | |
TORCH_ARG (bool, continuousActions) | |
void | setBatcher (std::unique_ptr< AsyncBatcher > batcher) |
Protected Member Functions | |
virtual void | receivedFrames (GameUID const &, std::string const &)=0 |
Callback for new episodes. More... | |
virtual bool | episodeClientEnqueue (EpisodeData const &) |
Callback for locally generated episode data that can be sent out. More... | |
virtual uint32_t | getMaxBatchLength () const |
Allows implementing trainers to send partial episodes. TODO: set up synchronization so the partial episodes end up on the same node. More... | |
virtual uint32_t | getSendInterval () const |
virtual bool | serveContinuously () const |
Allows implementing trainers to decide whether to serve end-of-frame in the middle of the episode or not. More... | |
CentralTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr) | |
A protected constructor that doesn't set up a server. More... | |
void | dequeueEpisodes () |
Protected Member Functions inherited from cpid::Trainer | |
virtual void | stepGame (GameUID const &game) |
template<typename T > | |
std::vector< T const * > | cast (ReplayBuffer::Episode const &e) |
Protected Attributes | |
std::shared_ptr< EpisodeServer > | server_ |
std::shared_ptr< EpisodeClient > | client_ |
std::thread | dequeueEpisodes_ |
std::mutex | newGamesMutex_ |
std::queue< EpisodeTuple > | newBatches_ |
std::shared_timed_mutex | modelMutex_ |
std::atomic< bool > | stop_ {false} |
std::unique_ptr< BufferPool > | bufferPool_ |
Protected Attributes inherited from cpid::Trainer | |
ag::Container | model_ |
ag::Optimizer | optim_ |
std::shared_ptr< MetricsContext > | metricsContext_ |
ReplayBuffer | replayer_ |
bool | train_ = true |
std::atomic< bool > | done_ {false} |
std::mutex | modelWriteMutex_ |
std::shared_timed_mutex | activeMapMutex_ |
std::unique_ptr< BaseSampler > | sampler_ |
std::unique_ptr< AsyncBatcher > | batcher_ |
std::shared_ptr< HandleGuard > | epGuard_ |
ReplayBuffer::UIDKeyStore | actives_ |
Additional Inherited Members | |
Protected Types inherited from cpid::Trainer | |
using | ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)> |
Static Protected Member Functions inherited from cpid::Trainer | |
static std::shared_ptr< Evaluator > | evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func) |
Static Protected Attributes inherited from cpid::Trainer | |
static constexpr float | kFwdMetricsSubsampling = 0.1 |
We subsample kFwdMetricsSubsampling of the forward() events when measuring their duration. More... | |
A trainer that sends episodes to one or more central instances.
In this trainer, several "server" instances will collect episode data from "client" instances. Users are required to subclass this and override receivedFrames(), which will be called on server instances whenever a new sequence of frames arrives. The trainer can be used like any other trainer, but ideally there should be no calls to sleep() between update() calls to ensure fast processing of collected episode data.
Implementation details: The trainer spawns dedicated threads for servers and clients. The data that goes over the network (serialized episodes) will be compressed using Zstandard, so there's no need to add compression to your custom replay buffer frame structure. Episode (de)serialization (including (de)compression) will be performed in the respective thread calling stepEpisode() (client) and update() (server).
TODO: Extend this so that it can be used in RL settings.
cpid::CentralTrainer::CentralTrainer | ( | bool | isServer, |
ag::Container | model, | ||
ag::Optimizer | optim, | ||
std::unique_ptr< BaseSampler > | sampler, | ||
std::unique_ptr< AsyncBatcher > | batcher = nullptr |
||
) |
|
virtual |
|
protected |
A protected constructor that doesn't set up a server.
|
protected |
|
protectedvirtual |
Callback for locally generated episode data that can be sent out.
If this returns false, the data will be put into the local replay buffer.
|
overridevirtual |
Reimplemented from cpid::Trainer.
|
overridevirtual |
Reimplemented from cpid::Trainer.
|
protectedvirtual |
Allows implementing trainers to send partial episodes. TODO: set up synchronization so the partial episodes end up on the same node.
|
protectedvirtual |
|
inline |
|
overridevirtual |
Implements cpid::Trainer.
std::shared_lock< std::shared_timed_mutex > cpid::CentralTrainer::modelReadLock | ( | ) |
std::unique_lock< std::shared_timed_mutex > cpid::CentralTrainer::modelWriteLock | ( | ) |
|
protectedpure virtual |
Callback for new episodes.
This will be called from update() for locally and remotely generated episodes. The second argument is a unique ID per sequence of frames, and should be removed when we get episodes to send to the same node. This depends on the deprecation of EpisodeKey.
|
protectedvirtual |
Allows implementing trainers to decide whether to serve end-of-frame in the middle of the episode or not.
|
overridevirtual |
Returns true if an episode was successfully registered, and false otherwise.
After receiving false, a worker thread should check stopping conditions and retry.
Reimplemented from cpid::Trainer.
|
overridevirtual |
Reimplemented from cpid::Trainer.
|
overridevirtual |
Reimplemented from cpid::Trainer.
|
overridevirtual |
Implements cpid::Trainer.
Reimplemented in cpid::CentralCpid2kTrainer.
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |