TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
Classes | Public Member Functions | Protected Member Functions | Protected Attributes | List of all members
cpid::Evaluator Class Reference

#include <evaluator.h>

Inherits cpid::Trainer.

Inherited by cpid::Evaluator::make_shared_enabler.

Classes

struct  make_shared_enabler
 

Public Member Functions

EpisodeHandle startEpisode () override
 Returns true if an episode was successfully registered, and false otherwise. More...
 
void forceStopEpisode (EpisodeHandle const &) override
 
bool update () override
 
virtual ag::Variant forward (ag::Variant inp, EpisodeHandle const &) override
 
void reset () override
 Releases all the worker threads so that they can be joined. More...
 
std::shared_ptr< ReplayBufferFrame > makeFrame (ag::Variant, ag::Variant, float reward) override
 
- Public Member Functions inherited from cpid::Trainer
 Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
ag::Variant forwardUnbatched (ag::Variant in, ag::Container model=nullptr)
 Convenience function for when one needs to forward a single input. More...
 
virtual ~Trainer ()=default
 
void setTrain (bool=true)
 
bool isTrain () const
 
ag::Variant sample (ag::Variant in)
 Sample using the class' sampler. More...
 
ag::Container model () const
 
ag::Optimizer optim () const
 
ReplayBuffer & replayBuffer ()
 
virtual std::shared_ptr< Evaluator > makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler)
 
void setDone (bool=true)
 
bool isDone () const
 
virtual void step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false)
 
bool isActive (EpisodeHandle const &)
 
template<class Archive >
void save (Archive &ar) const
 
template<class Archive >
void load (Archive &ar)
 
template<typename T >
bool is () const
 
Trainer & setMetricsContext (std::shared_ptr< MetricsContext > context)
 
std::shared_ptr< MetricsContext > metricsContext () const
 
 TORCH_ARG (float, noiseStd)
 
 TORCH_ARG (bool, continuousActions)
 
void setBatcher (std::unique_ptr< AsyncBatcher > batcher)
 

Protected Member Functions

virtual void stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
 
 Evaluator (ag::Container model, std::unique_ptr< BaseSampler > sampler, size_t batchSize, ForwardFunction func)
 
- Protected Member Functions inherited from cpid::Trainer
virtual void stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
 
virtual void stepGame (GameUID const &game)
 
template<typename T >
std::vector< T const * > cast (ReplayBuffer::Episode const &e)
 

Protected Attributes

size_t batchSize_
 
size_t gamesStarted_ = 0
 
std::condition_variable batchBarrier_
 
std::mutex updateMutex_
 
std::shared_timed_mutex insertionMutex_
 
std::deque< std::pair< GameUID, EpisodeKey > > newGames_
 
ForwardFunction forwardFunction_
 
- Protected Attributes inherited from cpid::Trainer
ag::Container model_
 
ag::Optimizer optim_
 
std::shared_ptr< MetricsContext > metricsContext_
 
ReplayBuffer replayer_
 
bool train_ = true
 
std::atomic< bool > done_ {false}
 
std::mutex modelWriteMutex_
 
std::shared_timed_mutex activeMapMutex_
 
std::unique_ptr< BaseSampler > sampler_
 
std::unique_ptr< AsyncBatcher > batcher_
 
std::shared_ptr< HandleGuard > epGuard_
 
ReplayBuffer::UIDKeyStore actives_
 

Additional Inherited Members

- Protected Types inherited from cpid::Trainer
using ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)>
 
- Static Protected Member Functions inherited from cpid::Trainer
static std::shared_ptr< Evaluator > evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func)
 
- Static Protected Attributes inherited from cpid::Trainer
static constexpr float kFwdMetricsSubsampling = 0.1
 We subsample kFwdMetricsSubsampling of the forward() events when measuring their duration. More...
 

Constructor & Destructor Documentation

cpid::Evaluator::Evaluator ( ag::Container  model,
std::unique_ptr< BaseSampler >  sampler,
size_t  batchSize,
ForwardFunction  func 
)
protected

Member Function Documentation

void cpid::Evaluator::forceStopEpisode ( EpisodeHandle const &  handle)
override virtual

Reimplemented from cpid::Trainer.

ag::Variant cpid::Evaluator::forward ( ag::Variant  inp,
EpisodeHandle const &  handle 
)
override virtual

Reimplemented from cpid::Trainer.

std::shared_ptr< ReplayBufferFrame > cpid::Evaluator::makeFrame ( ag::Variant  ,
ag::Variant  ,
float  reward 
)
override virtual

Implements cpid::Trainer.

void cpid::Evaluator::reset ( )
override virtual

Releases all the worker threads so that they can be joined.

For the off-policy trainers, labels all games as inactive. For the on-policy trainers, additionally un-blocks all threads that could be waiting at the batch barrier.

Reimplemented from cpid::Trainer.

Trainer::EpisodeHandle cpid::Evaluator::startEpisode ( )
override virtual

Returns true if an episode was successfully registered, and false otherwise.

After receiving false, a worker thread should check its stopping conditions and retry.

Reimplemented from cpid::Trainer.

void cpid::Evaluator::stepEpisode ( GameUID const &  gameUID,
EpisodeKey const &  key,
ReplayBuffer::Episode &  
)
override protected virtual

Reimplemented from cpid::Trainer.

bool cpid::Evaluator::update ( )
override virtual

Implements cpid::Trainer.

Member Data Documentation

std::condition_variable cpid::Evaluator::batchBarrier_
protected
size_t cpid::Evaluator::batchSize_
protected
ForwardFunction cpid::Evaluator::forwardFunction_
protected
size_t cpid::Evaluator::gamesStarted_ = 0
protected
std::shared_timed_mutex cpid::Evaluator::insertionMutex_
protected
std::deque<std::pair<GameUID, EpisodeKey> > cpid::Evaluator::newGames_
protected
std::mutex cpid::Evaluator::updateMutex_
protected

The documentation for this class was generated from the following files: