TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
Public Member Functions | Protected Member Functions | List of all members
cpid::BatchedPGTrainer Class Reference

Off policy policy gradient with a critic. More...

#include <policygradienttrainer.h>

Inherits cpid::Trainer.

Public Member Functions

ag::Variant forward (ag::Variant inp, EpisodeHandle const &) override
 
bool update () override
 
void doOnlineUpdatesInstead ()
 
int episodes ()
 
 BatchedPGTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, double gamma=0.99, int batchSize=10, std::size_t maxBatchSize=50, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
virtual std::shared_ptr< ReplayBufferFrame > makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override
 Contract: the trainer output should be a map with keys: "action" for the taken action, "V" for the state value, and "action" for the action probability. More...
 
std::shared_ptr< Evaluator > makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler=std::make_unique< DiscreteMaxSampler >()) override
 
- Public Member Functions inherited from cpid::Trainer
 Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
ag::Variant forwardUnbatched (ag::Variant in, ag::Container model=nullptr)
 Convenience function when one needs to forward a single input. More...
 
virtual ~Trainer ()=default
 
void setTrain (bool=true)
 
bool isTrain () const
 
ag::Variant sample (ag::Variant in)
 Sample using the class' sampler. More...
 
ag::Container model () const
 
ag::Optimizer optim () const
 
ReplayBuffer & replayBuffer ()
 
void setDone (bool=true)
 
bool isDone () const
 
virtual void step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false)
 
virtual EpisodeHandle startEpisode ()
 Returns true if it succeeded in registering an episode, and false otherwise. More...
 
virtual void forceStopEpisode (EpisodeHandle const &)
 
bool isActive (EpisodeHandle const &)
 
virtual void reset ()
 Releases all the worker threads so that they can be joined. More...
 
template<class Archive >
void save (Archive &ar) const
 
template<class Archive >
void load (Archive &ar)
 
template<typename T >
bool is () const
 
Trainer & setMetricsContext (std::shared_ptr< MetricsContext > context)
 
std::shared_ptr< MetricsContext > metricsContext () const
 
 TORCH_ARG (float, noiseStd)
 
 TORCH_ARG (bool, continuousActions)
 
void setBatcher (std::unique_ptr< AsyncBatcher > batcher)
 

Protected Member Functions

void stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
 
- Protected Member Functions inherited from cpid::Trainer
virtual void stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
 
virtual void stepGame (GameUID const &game)
 
template<typename T >
std::vector< T const * > cast (ReplayBuffer::Episode const &e)
 

Additional Inherited Members

- Protected Types inherited from cpid::Trainer
using ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)>
 
- Static Protected Member Functions inherited from cpid::Trainer
static std::shared_ptr< Evaluator > evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func)
 
- Protected Attributes inherited from cpid::Trainer
ag::Container model_
 
ag::Optimizer optim_
 
std::shared_ptr< MetricsContext > metricsContext_
 
ReplayBuffer replayer_
 
bool train_ = true
 
std::atomic< bool > done_ {false}
 
std::mutex modelWriteMutex_
 
std::shared_timed_mutex activeMapMutex_
 
std::unique_ptr< BaseSampler > sampler_
 
std::unique_ptr< AsyncBatcher > batcher_
 
std::shared_ptr< HandleGuard > epGuard_
 
ReplayBuffer::UIDKeyStore actives_
 
- Static Protected Attributes inherited from cpid::Trainer
static constexpr float kFwdMetricsSubsampling = 0.1
 We subsample kFwdMetricsSubsampling of the forward() events when measuring their duration. More...
 

Detailed Description

Off policy policy gradient with a critic.

This Trainer implements two modes: batched updates triggered via update() (the default), and online updates (see doOnlineUpdatesInstead()).

Replayer format: state, action, p(action), reward

Model output: a probability vector over actions (1-dim vector), and the critic's value estimate (double).

Constructor & Destructor Documentation

cpid::BatchedPGTrainer::BatchedPGTrainer ( ag::Container  model,
ag::Optimizer  optim,
std::unique_ptr< BaseSampler >  sampler,
double  gamma = 0.99,
int  batchSize = 10,
std::size_t  maxBatchSize = 50,
std::unique_ptr< AsyncBatcher >  batcher = nullptr 
)

Member Function Documentation

void cpid::BatchedPGTrainer::doOnlineUpdatesInstead ( )
int cpid::BatchedPGTrainer::episodes ( )
inline
ag::Variant cpid::BatchedPGTrainer::forward ( ag::Variant  inp,
EpisodeHandle const &  handle 
)
overridevirtual

Reimplemented from cpid::Trainer.

std::shared_ptr< Evaluator > cpid::BatchedPGTrainer::makeEvaluator ( size_t  n,
std::unique_ptr< BaseSampler >  sampler = std::make_unique<DiscreteMaxSampler>() 
)
overridevirtual

Reimplemented from cpid::Trainer.

std::shared_ptr< ReplayBufferFrame > cpid::BatchedPGTrainer::makeFrame ( ag::Variant  trainerOutput,
ag::Variant  state,
float  reward 
)
overridevirtual

Contract: the trainer output should be a map with keys: "action" for the taken action, "V" for the state value, and "action" for the action probability.

Implements cpid::Trainer.

void cpid::BatchedPGTrainer::stepEpisode ( GameUID const &  id,
EpisodeKey const &  k,
ReplayBuffer::Episode  
)
overrideprotectedvirtual

Reimplemented from cpid::Trainer.

bool cpid::BatchedPGTrainer::update ( )
overridevirtual

Implements cpid::Trainer.


The documentation for this class was generated from the following files: