TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
Public Member Functions | Protected Member Functions | Protected Attributes | List of all members
cpid::CentralCpid2kTrainer Class Reference

A trainer that sends episodes to one or more central instances. More...

#include <centralcpid2ktrainer.h>

Inherits cpid::CentralTrainer.

Public Member Functions

 CentralCpid2kTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr, std::string serverRole="train")
 
virtual bool update () override
 
void updateDone ()
 
distributed::Context & context ()
 
distributed::Context & serverContext ()
 
int numUpdates ()
 
- Public Member Functions inherited from cpid::CentralTrainer
 CentralTrainer (bool isServer, ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
virtual ~CentralTrainer ()
 
bool isServer () const
 
virtual void stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
 
virtual void stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
 
ag::Variant forward (ag::Variant inp, EpisodeHandle const &) override
 
virtual std::shared_ptr< ReplayBufferFrame > makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override
 
virtual EpisodeHandle startEpisode () override
 Returns true if succeeded to register an episode, and false otherwise. More...
 
virtual void forceStopEpisode (EpisodeHandle const &) override
 
std::shared_lock< std::shared_timed_mutex > modelReadLock ()
 
std::unique_lock< std::shared_timed_mutex > modelWriteLock ()
 
- Public Member Functions inherited from cpid::Trainer
 Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
ag::Variant forwardUnbatched (ag::Variant in, ag::Container model=nullptr)
 Convenience function for when one needs to forward a single input. More...
 
virtual ~Trainer ()=default
 
void setTrain (bool=true)
 
bool isTrain () const
 
ag::Variant sample (ag::Variant in)
 Sample using the class' sampler. More...
 
ag::Container model () const
 
ag::Optimizer optim () const
 
ReplayBuffer & replayBuffer ()
 
virtual std::shared_ptr< Evaluator > makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler)
 
void setDone (bool=true)
 
bool isDone () const
 
virtual void step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false)
 
bool isActive (EpisodeHandle const &)
 
virtual void reset ()
 Releases all the worker threads so that they can be joined. More...
 
template<class Archive >
void save (Archive &ar) const
 
template<class Archive >
void load (Archive &ar)
 
template<typename T >
bool is () const
 
Trainer & setMetricsContext (std::shared_ptr< MetricsContext > context)
 
std::shared_ptr< MetricsContext > metricsContext () const
 
 TORCH_ARG (float, noiseStd)
 
 TORCH_ARG (bool, continuousActions)
 
void setBatcher (std::unique_ptr< AsyncBatcher > batcher)
 

Protected Member Functions

void bcastWeights ()
 
void recvWeights (void const *data, size_t len, int64_t numUpdates)
 
- Protected Member Functions inherited from cpid::CentralTrainer
virtual void receivedFrames (GameUID const &, std::string const &)=0
 Callback for new episodes. More...
 
virtual bool episodeClientEnqueue (EpisodeData const &)
 Callback for locally generated episode data that can be sent out. More...
 
virtual uint32_t getMaxBatchLength () const
 Allows implementing trainers to send partial episodes. TODO: set up synchronization so the partial episodes end up on the same node. More...
 
virtual uint32_t getSendInterval () const
 
virtual bool serveContinuously () const
 Allows implementing trainers to decide whether to serve end-of-frame in the middle of the episode or not. More...
 
 CentralTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 A protected constructor that doesn't set up a server. More...
 
void dequeueEpisodes ()
 
- Protected Member Functions inherited from cpid::Trainer
virtual void stepGame (GameUID const &game)
 
template<typename T >
std::vector< T const * > cast (ReplayBuffer::Episode const &e)
 

Protected Attributes

std::string serverRole_
 
std::unique_ptr< Cpid2kWorker > worker_
 
std::mutex makeClientMutex_
 
std::vector< std::string > endpoints_
 
std::shared_ptr< zmq::context_t > zmqContext_
 
std::shared_ptr< BlobPublisher > modelPub_
 
std::shared_ptr< BlobSubscriber > modelSub_
 
std::atomic< int64_t > numUpdates_ {-1}
 
- Protected Attributes inherited from cpid::CentralTrainer
std::shared_ptr< EpisodeServer > server_
 
std::shared_ptr< EpisodeClient > client_
 
std::thread dequeueEpisodes_
 
std::mutex newGamesMutex_
 
std::queue< EpisodeTuple > newBatches_
 
std::shared_timed_mutex modelMutex_
 
std::atomic< bool > stop_ {false}
 
std::unique_ptr< BufferPool > bufferPool_
 
- Protected Attributes inherited from cpid::Trainer
ag::Container model_
 
ag::Optimizer optim_
 
std::shared_ptr< MetricsContext > metricsContext_
 
ReplayBuffer replayer_
 
bool train_ = true
 
std::atomic< bool > done_ {false}
 
std::mutex modelWriteMutex_
 
std::shared_timed_mutex activeMapMutex_
 
std::unique_ptr< BaseSampler > sampler_
 
std::unique_ptr< AsyncBatcher > batcher_
 
std::shared_ptr< HandleGuard > epGuard_
 
ReplayBuffer::UIDKeyStore actives_
 

Additional Inherited Members

- Protected Types inherited from cpid::Trainer
using ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)>
 
- Static Protected Member Functions inherited from cpid::Trainer
static std::shared_ptr< Evaluator > evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func)
 
- Static Protected Attributes inherited from cpid::Trainer
static constexpr float kFwdMetricsSubsampling = 0.1
 We subsample kFwdMetricsSubsampling of the forward() events when measuring their duration. More...
 

Detailed Description

A trainer that sends episodes to one or more central instances.

This is like CentralTrainer but uses Redis for figuring out servers and clients.

Constructor & Destructor Documentation

cpid::CentralCpid2kTrainer::CentralCpid2kTrainer ( ag::Container  model,
ag::Optimizer  optim,
std::unique_ptr< BaseSampler > sampler,
std::unique_ptr< AsyncBatcher > batcher = nullptr,
std::string  serverRole = "train" 
)

Member Function Documentation

void cpid::CentralCpid2kTrainer::bcastWeights ( )
protected
dist::Context & cpid::CentralCpid2kTrainer::context ( )
int cpid::CentralCpid2kTrainer::numUpdates ( )
inline
void cpid::CentralCpid2kTrainer::recvWeights ( void const *  data,
size_t  len,
int64_t  numUpdates 
)
protected
dist::Context & cpid::CentralCpid2kTrainer::serverContext ( )
bool cpid::CentralCpid2kTrainer::update ( )
overridevirtual

Reimplemented from cpid::CentralTrainer.

void cpid::CentralCpid2kTrainer::updateDone ( )

Member Data Documentation

std::vector<std::string> cpid::CentralCpid2kTrainer::endpoints_
protected
std::mutex cpid::CentralCpid2kTrainer::makeClientMutex_
protected
std::shared_ptr<BlobPublisher> cpid::CentralCpid2kTrainer::modelPub_
protected
std::shared_ptr<BlobSubscriber> cpid::CentralCpid2kTrainer::modelSub_
protected
std::atomic<int64_t> cpid::CentralCpid2kTrainer::numUpdates_ {-1}
protected
std::string cpid::CentralCpid2kTrainer::serverRole_
protected
std::unique_ptr<Cpid2kWorker> cpid::CentralCpid2kTrainer::worker_
protected
std::shared_ptr<zmq::context_t> cpid::CentralCpid2kTrainer::zmqContext_
protected

The documentation for this class was generated from the following files: