TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
Public Types | Public Member Functions | Protected Member Functions | Protected Attributes | List of all members
cpid::ESTrainer Class Reference

#include <estrainer.h>

Inherits cpid::Trainer.

Public Types

enum  RewardTransform { kNone = 0, kRankTransform, kStdNormalize }
 

Public Member Functions

 ESTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, float std, size_t batchSize, size_t historyLength, bool antithetic, RewardTransform transform, bool onPolicy)
 
ag::Container getGameModel (GameUID const &gameIUID, EpisodeKey const &key)
 
void forceStopEpisode (EpisodeHandle const &) override
 
EpisodeHandle startEpisode () override
 Returns true if an episode was successfully registered, and false otherwise. More...
 
bool update () override
 
virtual ag::Variant forward (ag::Variant inp, EpisodeHandle const &) override
 
std::shared_ptr< Evaluator > makeEvaluator (size_t n, std::unique_ptr< BaseSampler > sampler=std::make_unique< DiscreteMaxSampler >()) override
 
torch::Tensor rewardTransform (torch::Tensor const &rewards, RewardTransform transform)
 
void reset () override
 Releases all the worker threads so that they can be joined. More...
 
virtual std::shared_ptr< ReplayBufferFrame > makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override
 
 TORCH_ARG (bool, waitUpdate)
 
- Public Member Functions inherited from cpid::Trainer
 Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
ag::Variant forwardUnbatched (ag::Variant in, ag::Container model=nullptr)
 Convenience function for when one needs to forward a single input. More...
 
virtual ~Trainer ()=default
 
void setTrain (bool=true)
 
bool isTrain () const
 
ag::Variant sample (ag::Variant in)
 Sample using the class' sampler. More...
 
ag::Container model () const
 
ag::Optimizer optim () const
 
ReplayBuffer & replayBuffer ()
 
void setDone (bool=true)
 
bool isDone () const
 
virtual void step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false)
 
bool isActive (EpisodeHandle const &)
 
template<class Archive >
void save (Archive &ar) const
 
template<class Archive >
void load (Archive &ar)
 
template<typename T >
bool is () const
 
Trainer & setMetricsContext (std::shared_ptr< MetricsContext > context)
 
std::shared_ptr< MetricsContext > metricsContext () const
 
 TORCH_ARG (float, noiseStd)
 
 TORCH_ARG (bool, continuousActions)
 
void setBatcher (std::unique_ptr< AsyncBatcher > batcher)
 

Protected Member Functions

virtual void stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
 
ag::Container generateModel (int generation, int64_t seed)
 Re-creates model based on its seed and the generation it was produced from. More...
 
void populateSeedQueue ()
 
- Protected Member Functions inherited from cpid::Trainer
virtual void stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
 
virtual void stepGame (GameUID const &game)
 
template<typename T >
std::vector< T const * > cast (ReplayBuffer::Episode const &e)
 

Protected Attributes

float std_
 
size_t batchSize_
 
std::unordered_map< std::pair< int, int64_t >, ag::Container, pairhash > modelCache_
 
std::unordered_map< std::pair< GameUID, EpisodeKey >, std::pair< int, int64_t >, pairhash > gameToGenerationSeed_
 
std::shared_timed_mutex modelStorageMutex_
 
size_t historyLength_
 
std::deque< std::pair< int, ag::Container > > modelsHistory_
 
std::shared_timed_mutex currentModelMutex_
 
bool antithetic_
 
RewardTransform transform_
 
bool onPolicy_
 
std::shared_timed_mutex insertionMutex_
 
std::deque< std::pair< GameUID, EpisodeKey > > newGames_
 
std::mutex seedQueueMutex_
 
std::vector< int64_t > seedQueue_
 
std::mutex updateMutex_
 
size_t gamesStarted_ = 0
 
std::condition_variable batchBarrier_
 
size_t gatherSize_
 
std::vector< float > allRewards_
 
std::vector< int > allGenerations_
 
std::vector< int64_t > allSeeds_
 
std::vector< float > rewards_
 
std::vector< int > generations_
 
std::vector< int64_t > seeds_
 
- Protected Attributes inherited from cpid::Trainer
ag::Container model_
 
ag::Optimizer optim_
 
std::shared_ptr< MetricsContext > metricsContext_
 
ReplayBuffer replayer_
 
bool train_ = true
 
std::atomic< bool > done_ {false}
 
std::mutex modelWriteMutex_
 
std::shared_timed_mutex activeMapMutex_
 
std::unique_ptr< BaseSampler > sampler_
 
std::unique_ptr< AsyncBatcher > batcher_
 
std::shared_ptr< HandleGuard > epGuard_
 
ReplayBuffer::UIDKeyStore actives_
 

Additional Inherited Members

- Protected Types inherited from cpid::Trainer
using ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)>
 
- Static Protected Member Functions inherited from cpid::Trainer
static std::shared_ptr< Evaluator > evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func)
 
- Static Protected Attributes inherited from cpid::Trainer
static constexpr float kFwdMetricsSubsampling = 0.1
 We subsample kFwdMetricsSubsampling of the forward() events when measuring their duration. More...
 

Member Enumeration Documentation

Enumerator
kNone 
kRankTransform 
kStdNormalize 

Constructor & Destructor Documentation

cpid::ESTrainer::ESTrainer ( ag::Container  model,
ag::Optimizer  optim,
std::unique_ptr< BaseSampler >  sampler,
float  std,
size_t  batchSize,
size_t  historyLength,
bool  antithetic,
RewardTransform  transform,
bool  onPolicy 
)

Member Function Documentation

void cpid::ESTrainer::forceStopEpisode ( EpisodeHandle const &  handle)
overridevirtual

Reimplemented from cpid::Trainer.

ag::Variant cpid::ESTrainer::forward ( ag::Variant  inp,
EpisodeHandle const &  handle 
)
overridevirtual

Reimplemented from cpid::Trainer.

ag::Container cpid::ESTrainer::generateModel ( int  generation,
int64_t  seed 
)
protected

Re-creates model based on its seed and the generation it was produced from.

The absolute value of the seed is used for seeding, and its sign indicates if we add or subtract the noise, which is used to implement antithetic variates.

ag::Container cpid::ESTrainer::getGameModel ( GameUID const &  gameIUID,
EpisodeKey const &  key 
)
std::shared_ptr< Evaluator > cpid::ESTrainer::makeEvaluator ( size_t  n,
std::unique_ptr< BaseSampler >  sampler = std::make_unique<DiscreteMaxSampler>() 
)
overridevirtual

Reimplemented from cpid::Trainer.

std::shared_ptr< ReplayBufferFrame > cpid::ESTrainer::makeFrame ( ag::Variant  trainerOutput,
ag::Variant  state,
float  reward 
)
overridevirtual

Implements cpid::Trainer.

void cpid::ESTrainer::populateSeedQueue ( )
protected
void cpid::ESTrainer::reset ( )
overridevirtual

Releases all the worker threads so that they can be joined.

For the off-policy trainers, labels all games as inactive. For the on-policy trainers, additionally un-blocks all threads that could be waiting at the batch barrier.

Reimplemented from cpid::Trainer.

torch::Tensor cpid::ESTrainer::rewardTransform ( torch::Tensor const &  rewards,
ESTrainer::RewardTransform  transform 
)
Trainer::EpisodeHandle cpid::ESTrainer::startEpisode ( )
overridevirtual

Returns true if an episode was successfully registered, and false otherwise.

After receiving false, a worker thread should check the stopping conditions and retry.

Reimplemented from cpid::Trainer.

void cpid::ESTrainer::stepEpisode ( GameUID const &  gameUID,
EpisodeKey const &  key,
ReplayBuffer::Episode &  
)
overrideprotectedvirtual

Reimplemented from cpid::Trainer.

cpid::ESTrainer::TORCH_ARG ( bool  ,
waitUpdate   
)
bool cpid::ESTrainer::update ( )
overridevirtual

Implements cpid::Trainer.

Member Data Documentation

std::vector<int> cpid::ESTrainer::allGenerations_
protected
std::vector<float> cpid::ESTrainer::allRewards_
protected
std::vector<int64_t> cpid::ESTrainer::allSeeds_
protected
bool cpid::ESTrainer::antithetic_
protected
std::condition_variable cpid::ESTrainer::batchBarrier_
protected
size_t cpid::ESTrainer::batchSize_
protected
std::shared_timed_mutex cpid::ESTrainer::currentModelMutex_
protected
size_t cpid::ESTrainer::gamesStarted_ = 0
protected
std::unordered_map< std::pair<GameUID, EpisodeKey>, std::pair<int, int64_t>, pairhash> cpid::ESTrainer::gameToGenerationSeed_
protected
size_t cpid::ESTrainer::gatherSize_
protected
std::vector<int> cpid::ESTrainer::generations_
protected
size_t cpid::ESTrainer::historyLength_
protected
std::shared_timed_mutex cpid::ESTrainer::insertionMutex_
protected
std::unordered_map<std::pair<int, int64_t>, ag::Container, pairhash> cpid::ESTrainer::modelCache_
protected
std::deque<std::pair<int, ag::Container> > cpid::ESTrainer::modelsHistory_
protected
std::shared_timed_mutex cpid::ESTrainer::modelStorageMutex_
protected
std::deque<std::pair<GameUID, EpisodeKey> > cpid::ESTrainer::newGames_
protected
bool cpid::ESTrainer::onPolicy_
protected
std::vector<float> cpid::ESTrainer::rewards_
protected
std::vector<int64_t> cpid::ESTrainer::seedQueue_
protected
std::mutex cpid::ESTrainer::seedQueueMutex_
protected
std::vector<int64_t> cpid::ESTrainer::seeds_
protected
float cpid::ESTrainer::std_
protected
RewardTransform cpid::ESTrainer::transform_
protected
std::mutex cpid::ESTrainer::updateMutex_
protected

The documentation for this class was generated from the following files: