TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cpid::CentralTrainer Class Reference (abstract)

A trainer that sends episodes to one or more central instances.

#include <centraltrainer.h>

Inherits cpid::Trainer.

Inherited by cpid::CentralCpid2kTrainer.

Classes

class  BufferPool
 

Public Member Functions

 CentralTrainer (bool isServer, ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
virtual ~CentralTrainer ()
 
bool isServer () const
 
virtual void stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
 
virtual void stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
 
ag::Variant forward (ag::Variant inp, EpisodeHandle const &) override
 
virtual bool update () override
 
virtual std::shared_ptr< ReplayBufferFrame > makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override
 
virtual EpisodeHandle startEpisode () override
 Returns true if an episode was registered successfully, and false otherwise.
 
virtual void forceStopEpisode (EpisodeHandle const &) override
 
std::shared_lock< std::shared_timed_mutex > modelReadLock ()
 
std::unique_lock< std::shared_timed_mutex > modelWriteLock ()
 
- Public Member Functions inherited from cpid::Trainer
 Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
ag::Variant forwardUnbatched (ag::Variant in, ag::Container model=nullptr)
 Convenience function for forwarding a single input.
 
virtual ~Trainer ()=default
 
void setTrain (bool=true)
 
bool isTrain () const
 
ag::Variant sample (ag::Variant in)
 Sample using the class' sampler.
 
ag::Container model () const
 
ag::Optimizer optim () const
 
ReplayBuffer & replayBuffer ()
 
virtual std::shared_ptr< Evaluator > makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler)
 
void setDone (bool=true)
 
bool isDone () const
 
virtual void step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false)
 
bool isActive (EpisodeHandle const &)
 
virtual void reset ()
 Releases all the worker threads so that they can be joined.
 
template<class Archive >
void save (Archive &ar) const
 
template<class Archive >
void load (Archive &ar)
 
template<typename T >
bool is () const
 
Trainer & setMetricsContext (std::shared_ptr< MetricsContext > context)
 
std::shared_ptr< MetricsContext > metricsContext () const
 
 TORCH_ARG (float, noiseStd)
 
 TORCH_ARG (bool, continuousActions)
 
void setBatcher (std::unique_ptr< AsyncBatcher > batcher)
 

Protected Member Functions

virtual void receivedFrames (GameUID const &, std::string const &)=0
 Callback for new episodes.
 
virtual bool episodeClientEnqueue (EpisodeData const &)
 Callback for locally generated episode data that can be sent out.
 
virtual uint32_t getMaxBatchLength () const
 Allows implementing trainers to send partial episodes. TODO: set up synchronization so that the partial episodes end up on the same node.
 
virtual uint32_t getSendInterval () const
 
virtual bool serveContinuously () const
 Allows implementing trainers to decide whether to serve end-of-frame in the middle of the episode or not.
 
 CentralTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 A protected constructor that doesn't set up a server.
 
void dequeueEpisodes ()
 
- Protected Member Functions inherited from cpid::Trainer
virtual void stepGame (GameUID const &game)
 
template<typename T >
std::vector< T const * > cast (ReplayBuffer::Episode const &e)
 

Protected Attributes

std::shared_ptr< EpisodeServer > server_
 
std::shared_ptr< EpisodeClient > client_
 
std::thread dequeueEpisodes_
 
std::mutex newGamesMutex_
 
std::queue< EpisodeTuple > newBatches_
 
std::shared_timed_mutex modelMutex_
 
std::atomic< bool > stop_ {false}
 
std::unique_ptr< BufferPool > bufferPool_
 
- Protected Attributes inherited from cpid::Trainer
ag::Container model_
 
ag::Optimizer optim_
 
std::shared_ptr< MetricsContext > metricsContext_
 
ReplayBuffer replayer_
 
bool train_ = true
 
std::atomic< bool > done_ {false}
 
std::mutex modelWriteMutex_
 
std::shared_timed_mutex activeMapMutex_
 
std::unique_ptr< BaseSampler > sampler_
 
std::unique_ptr< AsyncBatcher > batcher_
 
std::shared_ptr< HandleGuard > epGuard_
 
ReplayBuffer::UIDKeyStore actives_
 

Additional Inherited Members

- Protected Types inherited from cpid::Trainer
using ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)>
 
- Static Protected Member Functions inherited from cpid::Trainer
static std::shared_ptr< Evaluator > evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func)
 
- Static Protected Attributes inherited from cpid::Trainer
static constexpr float kFwdMetricsSubsampling = 0.1
 We subsample a fraction kFwdMetricsSubsampling of the forward() events when measuring their duration.
 

Detailed Description

A trainer that sends episodes to one or more central instances.

In this trainer, one or more "server" instances collect episode data from "client" instances. Users must subclass this trainer and override receivedFrames(), which is called on server instances whenever a new sequence of frames arrives. The trainer can be used like any other trainer, but ideally there should be no calls to sleep() between update() calls, so that collected episode data is processed quickly.

Implementation details: The trainer spawns dedicated threads for servers and clients. The data that goes over the network (serialized episodes) is compressed with Zstandard, so there is no need to add compression to your custom replay buffer frame structure. Episode (de)serialization, including (de)compression, is performed in the thread that calls stepEpisode() (client) or update() (server).
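
The subclassing contract described above can be illustrated with a minimal, self-contained sketch. It does not use the real cpid headers: MockCentralTrainer, enqueue() and CountingTrainer are hypothetical stand-ins, GameUID is assumed to be string-like, and networking and compression are left out entirely. The only point shown is that update() hands each pending sequence of frames to the overridden receivedFrames().

```cpp
#include <mutex>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Stand-in for cpid::GameUID (assumed here to be string-like).
using GameUID = std::string;

// Hypothetical, simplified model of the server side of a central trainer.
// Only the callback contract is modelled: no networking, no compression.
class MockCentralTrainer {
 public:
  virtual ~MockCentralTrainer() = default;

  // Hypothetical helper standing in for frames arriving from a client.
  void enqueue(GameUID gameId, std::string batchId) {
    std::lock_guard<std::mutex> lock(mutex_);
    newBatches_.emplace(std::move(gameId), std::move(batchId));
  }

  // Drains the pending batches and invokes the callback once per batch.
  // Returns true if at least one batch was processed.
  bool update() {
    std::queue<std::pair<GameUID, std::string>> pending;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      std::swap(pending, newBatches_);
    }
    bool const processed = !pending.empty();
    while (!pending.empty()) {
      receivedFrames(pending.front().first, pending.front().second);
      pending.pop();
    }
    return processed;
  }

 protected:
  // Subclasses must override this, mirroring the pure virtual callback.
  virtual void receivedFrames(GameUID const&, std::string const&) = 0;

 private:
  std::mutex mutex_;
  std::queue<std::pair<GameUID, std::string>> newBatches_;
};

// Example subclass: records which frame sequences it has been handed.
class CountingTrainer : public MockCentralTrainer {
 public:
  std::vector<std::string> seen;

 protected:
  void receivedFrames(GameUID const& gameId,
                      std::string const& batchId) override {
    seen.push_back(gameId + "/" + batchId);
  }
};
```

Because the callback only runs inside update() in this sketch, a caller that pauses between update() calls delays processing of everything queued in the meantime.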

TODO: Extend this so that it can be used in RL settings.

Constructor & Destructor Documentation

cpid::CentralTrainer::CentralTrainer ( bool  isServer,
ag::Container  model,
ag::Optimizer  optim,
std::unique_ptr< BaseSampler >  sampler,
std::unique_ptr< AsyncBatcher >  batcher = nullptr 
)
cpid::CentralTrainer::~CentralTrainer ( )
virtual
cpid::CentralTrainer::CentralTrainer ( ag::Container  model,
ag::Optimizer  optim,
std::unique_ptr< BaseSampler >  sampler,
std::unique_ptr< AsyncBatcher >  batcher = nullptr 
)
protected

A protected constructor that doesn't set up a server.

Member Function Documentation

void cpid::CentralTrainer::dequeueEpisodes ( )
protected
bool cpid::CentralTrainer::episodeClientEnqueue ( EpisodeData const &  epData)
protected virtual

Callback for locally generated episode data that can be sent out.

If this returns false, the data will be put into the local replay buffer.
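
As a rough illustration of this fallback rule, the sketch below models it with stand-in types. MockEpisodeData, MockClientTrainer, SelectiveClient and the connected flag are all hypothetical and are not part of cpid. The mock leaves the callback pure virtual because the default return value of the real method is not stated here.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for cpid::EpisodeData.
struct MockEpisodeData {
  std::string gameId;
  int numFrames;
};

// Hypothetical model of the client-side decision: offer the episode to the
// callback first, and keep it locally if the callback declines it.
class MockClientTrainer {
 public:
  virtual ~MockClientTrainer() = default;

  void stepEpisode(MockEpisodeData const& ep) {
    if (!episodeClientEnqueue(ep)) {
      localBuffer.push_back(ep);  // declined: stays in the local buffer
    }
  }

  std::vector<MockEpisodeData> localBuffer;

 protected:
  // Returns true if the episode was accepted for sending.
  virtual bool episodeClientEnqueue(MockEpisodeData const& ep) = 0;
};

// Example subclass: sends only while a (hypothetical) connection is up.
class SelectiveClient : public MockClientTrainer {
 public:
  bool connected = true;
  std::vector<MockEpisodeData> sent;

 protected:
  bool episodeClientEnqueue(MockEpisodeData const& ep) override {
    if (!connected) {
      return false;
    }
    sent.push_back(ep);
    return true;
  }
};
```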

void cpid::CentralTrainer::forceStopEpisode ( EpisodeHandle const &  handle)
override virtual

Reimplemented from cpid::Trainer.

ag::Variant cpid::CentralTrainer::forward ( ag::Variant  inp,
EpisodeHandle const &  handle 
)
override virtual

Reimplemented from cpid::Trainer.

uint32_t cpid::CentralTrainer::getMaxBatchLength ( ) const
protected virtual

Allows implementing trainers to send partial episodes. TODO: set up synchronization so that the partial episodes end up on the same node.

uint32_t cpid::CentralTrainer::getSendInterval ( ) const
protected virtual
bool cpid::CentralTrainer::isServer ( ) const
inline
std::shared_ptr< ReplayBufferFrame > cpid::CentralTrainer::makeFrame ( ag::Variant  trainerOutput,
ag::Variant  state,
float  reward 
)
override virtual

Implements cpid::Trainer.

std::shared_lock< std::shared_timed_mutex > cpid::CentralTrainer::modelReadLock ( )
std::unique_lock< std::shared_timed_mutex > cpid::CentralTrainer::modelWriteLock ( )
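
The return types of these two accessors suggest a reader/writer pattern around modelMutex_: shared locks for code that only reads the model, an exclusive lock for code that modifies it. The sketch below shows that pattern with the standard library alone. GuardedModel, sum() and addToAll() are hypothetical names, and how CentralTrainer itself uses the locks is not documented here.

```cpp
#include <mutex>
#include <shared_mutex>
#include <vector>

// Hypothetical stand-in for a trainer that guards model parameters.
class GuardedModel {
 public:
  // Shared lock: several reader threads may hold one at the same time.
  std::shared_lock<std::shared_timed_mutex> modelReadLock() {
    return std::shared_lock<std::shared_timed_mutex>(modelMutex_);
  }

  // Exclusive lock: blocks until all readers and writers have released.
  std::unique_lock<std::shared_timed_mutex> modelWriteLock() {
    return std::unique_lock<std::shared_timed_mutex>(modelMutex_);
  }

  // Read-only access, comparable to a forward pass.
  float sum() {
    auto lock = modelReadLock();
    float total = 0.0f;
    for (float w : weights_) {
      total += w;
    }
    return total;
  }

  // Mutating access, comparable to a parameter update.
  void addToAll(float delta) {
    auto lock = modelWriteLock();
    for (float& w : weights_) {
      w += delta;
    }
  }

 private:
  std::shared_timed_mutex modelMutex_;
  std::vector<float> weights_{1.0f, 2.0f, 3.0f};
};
```

Both lock types release on destruction, so holding the returned object in a local variable scopes the critical section to the enclosing block. A thread should not request a second lock on the same mutex while it still holds one.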
virtual void cpid::CentralTrainer::receivedFrames ( GameUID const &  ,
std::string const &   
)
protected pure virtual

Callback for new episodes.

This will be called from update() for both locally and remotely generated episodes. The second argument is a unique ID for each sequence of frames; it should be removed once episodes can be made to go to the same node, which depends on the deprecation of EpisodeKey.

bool cpid::CentralTrainer::serveContinuously ( ) const
protected virtual

Allows implementing trainers to decide whether to serve end-of-frame in the middle of the episode or not.

EpisodeHandle cpid::CentralTrainer::startEpisode ( )
override virtual

Returns true if an episode was registered successfully, and false otherwise.

After receiving false, a worker thread should check its stopping conditions and retry.

Reimplemented from cpid::Trainer.
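
The check-and-retry behaviour described for worker threads can be sketched as a small loop. MockHandle, MockTrainer and acquireEpisode() are hypothetical stand-ins rather than cpid types. The sketch assumes that the returned handle can be tested for validity like a bool, and the 1 ms back-off is an arbitrary choice.

```cpp
#include <atomic>
#include <chrono>
#include <thread>

// Hypothetical stand-in for cpid::EpisodeHandle: converts to false when
// registration failed.
struct MockHandle {
  bool valid;
  explicit operator bool() const { return valid; }
};

// Hypothetical trainer whose registration succeeds on the Nth attempt.
struct MockTrainer {
  std::atomic<bool> done{false};
  int attemptsUntilSuccess = 3;
  int attempts = 0;

  bool isDone() const { return done.load(); }

  MockHandle startEpisode() {
    ++attempts;
    return MockHandle{attempts >= attemptsUntilSuccess};
  }
};

// Retries registration until it succeeds or the trainer is done.
// Returns an invalid handle if the stopping condition was hit first.
MockHandle acquireEpisode(MockTrainer& trainer) {
  while (!trainer.isDone()) {
    MockHandle handle = trainer.startEpisode();
    if (handle) {
      return handle;
    }
    // Registration failed: back off briefly, then re-check isDone().
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
  return MockHandle{false};
}
```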

void cpid::CentralTrainer::stepEpisode ( GameUID const &  gameId,
EpisodeKey const &  key,
ReplayBuffer::Episode &  episode 
)
override virtual

Reimplemented from cpid::Trainer.

void cpid::CentralTrainer::stepFrame ( GameUID const &  gameId,
EpisodeKey const &  ,
ReplayBuffer::Episode &  episode 
)
override virtual

Reimplemented from cpid::Trainer.

bool cpid::CentralTrainer::update ( )
override virtual

Implements cpid::Trainer.

Reimplemented in cpid::CentralCpid2kTrainer.

Member Data Documentation

std::unique_ptr<BufferPool> cpid::CentralTrainer::bufferPool_
protected
std::shared_ptr<EpisodeClient> cpid::CentralTrainer::client_
protected
std::thread cpid::CentralTrainer::dequeueEpisodes_
protected
std::shared_timed_mutex cpid::CentralTrainer::modelMutex_
protected
std::queue<EpisodeTuple> cpid::CentralTrainer::newBatches_
protected
std::mutex cpid::CentralTrainer::newGamesMutex_
protected
std::shared_ptr<EpisodeServer> cpid::CentralTrainer::server_
protected
std::atomic<bool> cpid::CentralTrainer::stop_ {false}
protected
