TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cpid::Trainer Class Reference (abstract)

The Trainer should be shared among multiple nodes and attached to a single Module.

#include <trainer.h>

Inherited by cpid::BatchedPGTrainer, cpid::CentralTrainer, cpid::ESTrainer, cpid::Evaluator, and cpid::OnlineZORBTrainer.

Classes

struct  EpisodeHandle
 

Public Member Functions

 Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
virtual ag::Variant forward (ag::Variant inp, EpisodeHandle const &)
 
ag::Variant forwardUnbatched (ag::Variant in, ag::Container model=nullptr)
 Convenience function for when one needs to forward a single input.
 
virtual bool update ()=0
 
virtual ~Trainer ()=default
 
void setTrain (bool=true)
 
bool isTrain () const
 
ag::Variant sample (ag::Variant in)
 Sample using the class' sampler.
 
ag::Container model () const
 
ag::Optimizer optim () const
 
ReplayBuffer & replayBuffer ()
 
virtual std::shared_ptr< Evaluator > makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler)
 
void setDone (bool=true)
 
bool isDone () const
 
virtual void step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false)
 
virtual std::shared_ptr< ReplayBufferFrame > makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward)=0
 
virtual EpisodeHandle startEpisode ()
 Returns a handle that evaluates to true if an episode was successfully registered, and to false otherwise.
 
virtual void forceStopEpisode (EpisodeHandle const &)
 
bool isActive (EpisodeHandle const &)
 
virtual void reset ()
 Releases all the worker threads so that they can be joined.
 
template<class Archive >
void save (Archive &ar) const
 
template<class Archive >
void load (Archive &ar)
 
template<typename T >
bool is () const
 
Trainer & setMetricsContext (std::shared_ptr< MetricsContext > context)
 
std::shared_ptr< MetricsContext > metricsContext () const
 
 TORCH_ARG (float, noiseStd)
 
 TORCH_ARG (bool, continuousActions)
 
void setBatcher (std::unique_ptr< AsyncBatcher > batcher)
 

Protected Types

using ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)>
 

Protected Member Functions

virtual void stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
 
virtual void stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
 
virtual void stepGame (GameUID const &game)
 
template<typename T >
std::vector< T const * > cast (ReplayBuffer::Episode const &e)
 

Static Protected Member Functions

static std::shared_ptr< Evaluator > evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func)
 

Protected Attributes

ag::Container model_
 
ag::Optimizer optim_
 
std::shared_ptr< MetricsContext > metricsContext_
 
ReplayBuffer replayer_
 
bool train_ = true
 
std::atomic< bool > done_ {false}
 
std::mutex modelWriteMutex_
 
std::shared_timed_mutex activeMapMutex_
 
std::unique_ptr< BaseSampler > sampler_

std::unique_ptr< AsyncBatcher > batcher_

std::shared_ptr< HandleGuard > epGuard_
 
ReplayBuffer::UIDKeyStore actives_
 

Static Protected Attributes

static constexpr float kFwdMetricsSubsampling = 0.1
 We subsample a fraction kFwdMetricsSubsampling (0.1, i.e. 10%) of the forward() events when measuring their duration.
 

Detailed Description

The Trainer should be shared among multiple nodes and attached to a single Module.

It consists of an:

Member Typedef Documentation

using cpid::Trainer::ForwardFunction = std::function<ag::Variant(ag::Variant, EpisodeHandle const&)>
protected

Constructor & Destructor Documentation

cpid::Trainer::Trainer ( ag::Container  model,
ag::Optimizer  optim,
std::unique_ptr< BaseSampler >  sampler,
std::unique_ptr< AsyncBatcher >  batcher = nullptr 
)
virtual cpid::Trainer::~Trainer ( )
virtual, default

Member Function Documentation

template<typename T >
std::vector< T const * > cpid::Trainer::cast ( ReplayBuffer::Episode const &  e)
inline, protected
std::shared_ptr< Evaluator > cpid::Trainer::evaluatorFactory ( ag::Container  model,
std::unique_ptr< BaseSampler >  s,
size_t  n,
ForwardFunction  func 
)
static, protected
void cpid::Trainer::forceStopEpisode ( EpisodeHandle const &  handle)
virtual
ag::Variant cpid::Trainer::forward ( ag::Variant  inp,
EpisodeHandle const &  handle 
)
virtual
ag::Variant cpid::Trainer::forwardUnbatched ( ag::Variant  in,
ag::Container  model = nullptr 
)

Convenience function for when one needs to forward a single input.

This makes the input look batched before forwarding it, so that the model can handle it. If

Parameters
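The batching step described for forwardUnbatched() can be sketched as follows. This is a minimal illustration, not the cpid implementation: it assumes a sample is a flat `std::vector<float>` and a batch is a vector of samples, and `Sample`, `Batch`, `Model`, and `forwardSingle` are hypothetical stand-ins for `ag::Variant` and `ag::Container`.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-ins: one sample is a flat feature vector and a batch is
// a list of samples. These are NOT the ag::Variant / ag::Container types.
using Sample = std::vector<float>;
using Batch = std::vector<Sample>;
using Model = std::function<Batch(Batch const&)>;

// Wrap a single sample into a batch of size 1, run the batched model, and
// unwrap the single output, so the model only ever sees batched input.
Sample forwardSingle(Model const& model, Sample const& in) {
  Batch batch{in};  // add a leading batch dimension of size 1
  Batch out = model(batch);
  assert(out.size() == 1 && "batched model must return one output per input");
  return out.at(0);  // drop the batch dimension again (throws if empty)
}
```

Because the model always receives a batch, one batched code path serves both batched and single inputs; in this sketch the price is one copy of the input.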
template<typename T >
bool cpid::Trainer::is ( ) const
inline
bool cpid::Trainer::isActive ( EpisodeHandle const &  handle)
bool cpid::Trainer::isDone ( ) const
inline
bool cpid::Trainer::isTrain ( ) const
inline
template<class Archive >
void cpid::Trainer::load ( Archive &  ar)
inline
std::shared_ptr< Evaluator > cpid::Trainer::makeEvaluator ( size_t  ,
std::unique_ptr< BaseSampler >  sampler 
)
virtual

Reimplemented in cpid::BatchedPGTrainer, and cpid::ESTrainer.

virtual std::shared_ptr<ReplayBufferFrame> cpid::Trainer::makeFrame ( ag::Variant  trainerOutput,
ag::Variant  state,
float  reward 
)
pure virtual
std::shared_ptr< MetricsContext > cpid::Trainer::metricsContext ( ) const
inline
ag::Container cpid::Trainer::model ( ) const
ag::Optimizer cpid::Trainer::optim ( ) const
ReplayBuffer & cpid::Trainer::replayBuffer ( )
void cpid::Trainer::reset ( )
virtual

Releases all the worker threads so that they can be joined.

For off-policy trainers, it labels all games as inactive. For on-policy trainers, it additionally unblocks all threads that may be waiting at the batch barrier.

Reimplemented in cpid::ESTrainer, and cpid::Evaluator.
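A minimal sketch of the on-policy part of reset(), assuming the batch barrier is a condition-variable barrier that workers wait on until a batch fills up. `BatchBarrier`, `arriveAndWait()`, and `workerReleasedByReset()` are hypothetical names; the source does not describe how cpid implements its barrier.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <thread>

// Hypothetical batch barrier: worker threads block in arriveAndWait() until
// a full batch has arrived, or until reset() releases everyone so that the
// threads can be joined.
class BatchBarrier {
 public:
  explicit BatchBarrier(std::size_t batchSize) : batchSize_(batchSize) {
    assert(batchSize > 0 && "a batch of size 0 would never fill");
  }

  // Returns true if released because the batch filled up, false if released
  // by reset().
  bool arriveAndWait() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (released_) return false;  // already reset: never block again
    std::size_t const gen = generation_;
    if (++arrived_ == batchSize_) {  // last arrival completes the batch
      arrived_ = 0;
      ++generation_;
      cv_.notify_all();
      return true;
    }
    cv_.wait(lock, [&] { return generation_ != gen || released_; });
    return generation_ != gen;
  }

  // Unblocks every waiting thread; later arrivals return immediately.
  void reset() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      released_ = true;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::size_t const batchSize_;
  std::size_t arrived_ = 0;
  std::size_t generation_ = 0;
  bool released_ = false;
};

// Usage sketch: one worker waits on a batch of 2 that never fills; reset()
// releases it so it can be joined. Returns true if the worker was released.
bool workerReleasedByReset() {
  BatchBarrier barrier(2);
  bool filled = true;
  std::thread worker([&] { filled = barrier.arriveAndWait(); });
  barrier.reset();
  worker.join();
  return !filled;
}
```

Using a sticky `released_` flag, rather than only notifying, means a worker that reaches the barrier after reset() returns immediately instead of blocking forever, so every thread can still be joined.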

ag::Variant cpid::Trainer::sample ( ag::Variant  in)

Sample using the class' sampler.

template<class Archive >
void cpid::Trainer::save ( Archive &  ar) const
inline
void cpid::Trainer::setBatcher ( std::unique_ptr< AsyncBatcher >  batcher)
void cpid::Trainer::setDone ( bool  done = true)
Trainer & cpid::Trainer::setMetricsContext ( std::shared_ptr< MetricsContext >  context)
inline
void cpid::Trainer::setTrain ( bool  train = true)
Trainer::EpisodeHandle cpid::Trainer::startEpisode ( )
virtual

Returns a handle that evaluates to true if an episode was successfully registered, and to false otherwise.

After receiving false, a worker thread should check its stopping conditions and retry.

Reimplemented in cpid::ESTrainer, cpid::OnlineZORBTrainer, cpid::CentralTrainer, and cpid::Evaluator.
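The worker-side retry protocol described for startEpisode() might look like this. The sketch assumes that registration fails because a hypothetical cap on concurrent episodes has been reached (the source does not say why startEpisode() can fail); `EpisodeRegistry`, `tryStart()`, `finish()`, and `registerOrStop()` are illustrative names, not cpid API.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <mutex>
#include <thread>

// Hypothetical registry admitting at most `capacity` concurrent episodes,
// standing in for Trainer::startEpisode(); not the cpid implementation.
class EpisodeRegistry {
 public:
  explicit EpisodeRegistry(std::size_t capacity) : capacity_(capacity) {}

  // Returns true if an episode was registered, false if the registry is full.
  bool tryStart() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (active_ >= capacity_) return false;
    ++active_;
    return true;
  }

  // Marks one registered episode as finished, freeing a slot.
  void finish() {
    std::lock_guard<std::mutex> lock(mutex_);
    assert(active_ > 0 && "finish() without a matching tryStart()");
    if (active_ > 0) --active_;
  }

 private:
  std::mutex mutex_;
  std::size_t const capacity_;
  std::size_t active_ = 0;
};

// Worker loop: after a failed registration, check the stopping condition
// before retrying. Returns the attempt number that succeeded, or 0 if the
// worker stopped without registering an episode.
int registerOrStop(EpisodeRegistry& registry, std::atomic<bool> const& stop) {
  for (int attempt = 1;; ++attempt) {
    if (registry.tryStart()) return attempt;
    if (stop.load()) return 0;  // stopping condition met: give up
    std::this_thread::yield();  // back off briefly before retrying
  }
}
```

Checking the stop flag only after a failed attempt keeps the common path (registration succeeds) free of extra work, while still letting a full trainer drain its workers at shutdown.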

void cpid::Trainer::step ( EpisodeHandle const &  handle,
std::shared_ptr< ReplayBufferFrame >  v,
bool  isDone = false 
)
virtual
virtual void cpid::Trainer::stepEpisode ( GameUID const &  ,
EpisodeKey const &  ,
ReplayBuffer::Episode &  
)
inline, protected, virtual
virtual void cpid::Trainer::stepFrame ( GameUID const &  ,
EpisodeKey const &  ,
ReplayBuffer::Episode &  
)
inline, protected, virtual

Reimplemented in cpid::CentralTrainer.

virtual void cpid::Trainer::stepGame ( GameUID const &  game)
inline, protected, virtual
cpid::Trainer::TORCH_ARG ( float  ,
noiseStd   
)
cpid::Trainer::TORCH_ARG ( bool  ,
continuousActions   
)
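TORCH_ARG is, as far as we know, the argument macro from the PyTorch C++ frontend (`torch/arg.h`). Roughly, `TORCH_ARG(float, noiseStd)` generates a chainable setter `noiseStd(value)`, a getter `noiseStd()`, and a private member `noiseStd_`. The simplified `SKETCH_ARG` macro and the `TrainerOptions` class below are hypothetical illustrations of that pattern; the real macro differs in details such as its overload set.

```cpp
#include <cassert>
#include <utility>

// Hypothetical, simplified stand-in for TORCH_ARG(T, name): declares a
// chainable setter name(value) returning *this, a const getter name(), and
// a private member name_ (which can be given a default initializer).
#define SKETCH_ARG(T, name)                \
 public:                                   \
  auto& name(T new_##name) {               \
    this->name##_ = std::move(new_##name); \
    return *this;                          \
  }                                        \
  const T& name() const noexcept {         \
    return this->name##_;                  \
  }                                        \
 private:                                  \
  T name##_

// Hypothetical options holder mirroring the two Trainer arguments.
class TrainerOptions {
  SKETCH_ARG(float, noiseStd) = 0.0f;
  SKETCH_ARG(bool, continuousActions) = false;
};
```

Because each setter returns `*this`, options can be chained, e.g. `opts.noiseStd(0.5f).continuousActions(true)`.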
virtual bool cpid::Trainer::update ( )
pure virtual

Member Data Documentation

std::shared_timed_mutex cpid::Trainer::activeMapMutex_
protected
ReplayBuffer::UIDKeyStore cpid::Trainer::actives_
protected
std::unique_ptr<AsyncBatcher> cpid::Trainer::batcher_
protected
std::atomic<bool> cpid::Trainer::done_ {false}
protected
std::shared_ptr<HandleGuard> cpid::Trainer::epGuard_
protected
constexpr float cpid::Trainer::kFwdMetricsSubsampling = 0.1
static, protected

We subsample a fraction kFwdMetricsSubsampling (0.1, i.e. 10%) of the forward() events when measuring their duration.
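A sketch of timing only a fraction of calls, assuming each forward() call is timed independently with probability kFwdMetricsSubsampling. Random (Bernoulli) sampling is our assumption; a deterministic every-Nth scheme would match the description equally well. `SubsampledTimer` and its members are hypothetical and unrelated to cpid's MetricsContext API.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <random>

// Hypothetical timer that measures only a random fraction of calls, in the
// spirit of kFwdMetricsSubsampling = 0.1 (time roughly 1 in 10 calls).
class SubsampledTimer {
 public:
  SubsampledTimer(double fraction, unsigned seed)
      : coin_(fraction), rng_(seed) {
    assert(fraction >= 0.0 && fraction <= 1.0 && "fraction must be in [0, 1]");
  }

  // Runs fn(); records its wall-clock duration with probability `fraction`.
  template <typename Fn>
  void run(Fn&& fn) {
    if (!coin_(rng_)) {
      fn();  // not sampled: no clock reads, no timing overhead
      return;
    }
    auto const start = std::chrono::steady_clock::now();
    fn();
    auto const end = std::chrono::steady_clock::now();
    totalSeconds_ += std::chrono::duration<double>(end - start).count();
    ++measured_;
  }

  std::size_t measured() const { return measured_; }

  // Mean duration over the sampled calls only (0 if none were sampled).
  double meanSeconds() const {
    return measured_ == 0 ? 0.0 : totalSeconds_ / measured_;
  }

 private:
  std::bernoulli_distribution coin_;
  std::mt19937 rng_;
  double totalSeconds_ = 0.0;
  std::size_t measured_ = 0;
};
```

Skipping the clock reads on roughly 90% of calls keeps the measurement overhead small, while the mean over the sampled calls still estimates the average forward() duration.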

std::shared_ptr<MetricsContext> cpid::Trainer::metricsContext_
protected
ag::Container cpid::Trainer::model_
protected
std::mutex cpid::Trainer::modelWriteMutex_
protected
ag::Optimizer cpid::Trainer::optim_
protected
ReplayBuffer cpid::Trainer::replayer_
protected
std::unique_ptr<BaseSampler> cpid::Trainer::sampler_
protected
bool cpid::Trainer::train_ = true
protected
