TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cpid::OnlineZORBTrainer Class Reference

This is an OnlineZORBTrainer that works with multiple actions per frame.

#include <zeroordertrainer.h>

Inherits cpid::Trainer.

Public Member Functions

void stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
 
bool update () override
 
EpisodeHandle startEpisode () override
 Returns true if it succeeded in registering an episode, and false otherwise.
 
ag::Variant forward (ag::Variant inp, EpisodeHandle const &) override
 
 OnlineZORBTrainer (ag::Container model, ag::Optimizer optim)
 
virtual std::shared_ptr< ReplayBufferFrame > makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override
 
 TORCH_ARG (float, valueLambda)=0
 
 TORCH_ARG (float, delta)
 
 TORCH_ARG (int, batchSize)
 
 TORCH_ARG (bool, antithetic)
 
- Public Member Functions inherited from cpid::Trainer
 Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr)
 
ag::Variant forwardUnbatched (ag::Variant in, ag::Container model=nullptr)
 Convenience function for when one needs to forward a single input.
 
virtual ~Trainer ()=default
 
void setTrain (bool=true)
 
bool isTrain () const
 
ag::Variant sample (ag::Variant in)
 Samples using the class's sampler.
 
ag::Container model () const
 
ag::Optimizer optim () const
 
ReplayBuffer & replayBuffer ()
 
virtual std::shared_ptr< Evaluator > makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler)
 
void setDone (bool=true)
 
bool isDone () const
 
virtual void step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false)
 
virtual void forceStopEpisode (EpisodeHandle const &)
 
bool isActive (EpisodeHandle const &)
 
virtual void reset ()
 Releases all the worker threads so that they can be joined.
 
template<class Archive >
void save (Archive &ar) const
 
template<class Archive >
void load (Archive &ar)
 
template<typename T >
bool is () const
 
Trainer & setMetricsContext (std::shared_ptr< MetricsContext > context)
 
std::shared_ptr< MetricsContext > metricsContext () const
 
 TORCH_ARG (float, noiseStd)
 
 TORCH_ARG (bool, continuousActions)
 
void setBatcher (std::unique_ptr< AsyncBatcher > batcher)
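forwardUnbatched() is described only as a convenience for forwarding a single input. A minimal sketch of that idea follows, assuming a batched forward function that maps a list of inputs to a list of outputs in the same order. The types Input, Output, BatchedForward and the function forwardUnbatchedSketch are hypothetical stand-ins; the real method takes and returns ag::Variant and accepts an optional ag::Container.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-ins: the real Trainer passes ag::Variant, not vectors.
using Input = std::vector<float>;   // one unbatched model input
using Output = std::vector<float>;  // one unbatched model output

// A batched forward pass: returns one Output per Input, in the same order.
using BatchedForward =
    std::function<std::vector<Output>(const std::vector<Input>&)>;

// Sketch of the idea behind forwardUnbatched(): wrap the single input in a
// batch of size one, run the batched forward pass, then unwrap the result.
Output forwardUnbatchedSketch(const Input& in, const BatchedForward& forward) {
  std::vector<Output> batch = forward(std::vector<Input>{in});
  assert(batch.size() == 1 && "expected exactly one output for one input");
  return batch.front();
}
```

The caller never builds a batch itself; the batch-of-one wrapping and unwrapping stays inside the helper.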
 

Additional Inherited Members

- Protected Types inherited from cpid::Trainer
using ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)>
 
- Protected Member Functions inherited from cpid::Trainer
virtual void stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
 
virtual void stepGame (GameUID const &game)
 
template<typename T >
std::vector< T const * > cast (ReplayBuffer::Episode const &e)
 
- Static Protected Member Functions inherited from cpid::Trainer
static std::shared_ptr< Evaluator > evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func)
 
- Protected Attributes inherited from cpid::Trainer
ag::Container model_
 
ag::Optimizer optim_
 
std::shared_ptr< MetricsContext > metricsContext_
 
ReplayBuffer replayer_
 
bool train_ = true
 
std::atomic< bool > done_ {false}
 
std::mutex modelWriteMutex_
 
std::shared_timed_mutex activeMapMutex_
 
std::unique_ptr< BaseSampler > sampler_
 
std::unique_ptr< AsyncBatcher > batcher_
 
std::shared_ptr< HandleGuard > epGuard_
 
ReplayBuffer::UIDKeyStore actives_
 
- Static Protected Attributes inherited from cpid::Trainer
static constexpr float kFwdMetricsSubsampling = 0.1
 Only a fraction kFwdMetricsSubsampling of forward() events is sampled when measuring their duration.
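The listing states only that a fraction kFwdMetricsSubsampling (0.1) of forward() events is timed. One way to implement such subsampling, sketched here under that assumption, is to draw a Bernoulli sample on every call and time only the selected calls. timedForward and its parameters are hypothetical; the trainer's real code records into its MetricsContext, whose API is not shown on this page.

```cpp
#include <cassert>
#include <chrono>
#include <random>
#include <vector>

// Same fraction as Trainer::kFwdMetricsSubsampling in the listing above.
constexpr float kFwdMetricsSubsampling = 0.1f;

// Hypothetical timing wrapper: always runs `fn`, but records its duration
// (in microseconds) for only about kFwdMetricsSubsampling of the calls.
template <typename Fn>
void timedForward(Fn&& fn, std::mt19937& rng, std::vector<double>& durationsUs) {
  std::bernoulli_distribution measure(kFwdMetricsSubsampling);
  if (!measure(rng)) {
    fn();  // not sampled: run without timing overhead
    return;
  }
  auto start = std::chrono::steady_clock::now();
  fn();
  auto end = std::chrono::steady_clock::now();
  durationsUs.push_back(
      std::chrono::duration<double, std::micro>(end - start).count());
}
```

Subsampling keeps the cost of reading the clock and storing a metric off most forward() calls while still yielding a representative duration distribution.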
 

Detailed Description

This is an OnlineZORBTrainer that works with multiple actions per frame.

Contract:

The trainer->forward function returns [action_i, action_scores_i, ...].
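The contract lists the forward() output only as [action_i, action_scores_i, ...]. The sketch below shows how a caller might unpack it, assuming a flat list in which element 2*i holds the i-th action and element 2*i + 1 holds its scores. That layout, the Tensor alias, ActionOutput, and unpackForwardOutput are all hypothetical; the real output is an ag::Variant.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one entry of the ag::Variant list returned by
// forward(): a flat vector of numbers.
using Tensor = std::vector<float>;

struct ActionOutput {
  Tensor action;        // action_i
  Tensor actionScores;  // action_scores_i
};

// ASSUMPTION: the output is laid out as
// [action_0, action_scores_0, action_1, action_scores_1, ...],
// i.e. element 2*i is the i-th action and element 2*i + 1 its scores.
std::vector<ActionOutput> unpackForwardOutput(const std::vector<Tensor>& out) {
  assert(out.size() % 2 == 0 && "expected (action, scores) pairs");
  std::vector<ActionOutput> actions;
  actions.reserve(out.size() / 2);
  for (std::size_t i = 0; i + 1 < out.size(); i += 2) {
    actions.push_back({out[i], out[i + 1]});
  }
  return actions;
}
```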

TODO: On multiple nodes this is untested and does not work, even though the logic is mostly there.

Constructor & Destructor Documentation

cpid::OnlineZORBTrainer::OnlineZORBTrainer (ag::Container model, ag::Optimizer optim)

Member Function Documentation

ag::Variant cpid::OnlineZORBTrainer::forward (ag::Variant inp, EpisodeHandle const & handle)
override, virtual

Reimplemented from cpid::Trainer.

std::shared_ptr< ReplayBufferFrame > cpid::OnlineZORBTrainer::makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward)
override, virtual

Implements cpid::Trainer.

Trainer::EpisodeHandle cpid::OnlineZORBTrainer::startEpisode ()
override, virtual

Returns true if it succeeded in registering an episode, and false otherwise.

After receiving false, a worker thread should check its stopping conditions and retry.

Reimplemented from cpid::Trainer.
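The retry rule above can be sketched as a worker-side loop. This is a minimal sketch under stated assumptions: EpisodeHandleSketch is a hypothetical handle that converts to false when registration failed, MockTrainer is a hypothetical trainer whose registration fails a fixed number of times, and a single atomic flag stands in for the worker's stopping conditions.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical handle that tests false when registration failed.
struct EpisodeHandleSketch {
  bool valid = false;
  explicit operator bool() const { return valid; }
};

// Hypothetical trainer: registration fails `failuresLeft` times, then succeeds.
struct MockTrainer {
  int failuresLeft;
  EpisodeHandleSketch startEpisode() {
    if (failuresLeft > 0) {
      --failuresLeft;
      return {false};
    }
    return {true};
  }
};

// Worker-side pattern: on failure, re-check the stopping condition, then retry.
// Returns an invalid handle if asked to stop before registration succeeded.
template <typename TrainerT>
EpisodeHandleSketch startEpisodeWithRetry(TrainerT& trainer,
                                          const std::atomic<bool>& stop) {
  while (!stop.load()) {
    EpisodeHandleSketch handle = trainer.startEpisode();
    if (handle) {
      return handle;
    }
    // Registration failed; a real worker might also sleep or back off here.
  }
  return {};
}
```

Checking the flag before every attempt lets a worker exit promptly once training is done instead of spinning on a trainer that no longer accepts episodes.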

void cpid::OnlineZORBTrainer::stepEpisode (GameUID const & id, EpisodeKey const & k, ReplayBuffer::Episode & ep)
override, virtual

Reimplemented from cpid::Trainer.

cpid::OnlineZORBTrainer::TORCH_ARG (float, valueLambda) = 0

The = 0 initializes valueLambda to 0; it is a default value, not a pure-virtual specifier.

cpid::OnlineZORBTrainer::TORCH_ARG (float, delta)

cpid::OnlineZORBTrainer::TORCH_ARG (int, batchSize)

cpid::OnlineZORBTrainer::TORCH_ARG (bool, antithetic)

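TORCH_ARG appears to be the macro from PyTorch's C++ frontend (torch/arg.h), which expands to a chainable setter, a getter, and a data member. Because the expansion ends with the member declaration, a trailing = 0 becomes that member's default initializer, which is most likely why Doxygen renders valueLambda as "= 0". The sketch below uses a simplified, hypothetical SKETCH_ARG macro that follows the same pattern; ZorbOptionsSketch is hypothetical too, and apart from valueLambda = 0 it uses plain zero-initialization rather than any real defaults.

```cpp
#include <cassert>
#include <utility>

// Simplified, hypothetical stand-in for TORCH_ARG(T, name): a chainable
// setter name(value), a getter name(), and a private member name##_.
// A trailing "= value" or "{}" after the macro initializes that member.
#define SKETCH_ARG(T, name)                  \
 public:                                     \
  auto name(T new_##name)->decltype(*this) { \
    this->name##_ = std::move(new_##name);   \
    return *this;                            \
  }                                          \
  const T& name() const noexcept {           \
    return this->name##_;                    \
  }                                          \
                                             \
 private:                                    \
  T name##_

// Hypothetical options holder mirroring OnlineZORBTrainer's arguments.
class ZorbOptionsSketch {
  SKETCH_ARG(float, valueLambda) = 0;  // default 0, as in the listing
  SKETCH_ARG(float, delta){};          // no default listed; zero here
  SKETCH_ARG(int, batchSize){};        // no default listed; zero here
  SKETCH_ARG(bool, antithetic){};      // no default listed; false here
};
```

Each setter returns a reference to the object, so options can be chained, for example opts.delta(0.05f).batchSize(32).antithetic(true).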
bool cpid::OnlineZORBTrainer::update ()
override, virtual

Implements cpid::Trainer.
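This page does not describe what update() computes. The header name zeroordertrainer.h and the arguments delta, antithetic, and batchSize suggest a zero-order (random-search) gradient estimate, so the sketch below shows that generic technique only; it is not this class's implementation and does not model valueLambda. zeroOrderGradient, Params, and Objective are hypothetical. For each of batchSize Gaussian directions u, it estimates the directional derivative with a central difference (f(theta + delta*u) - f(theta - delta*u)) / (2*delta) when antithetic is true, or a forward difference (f(theta + delta*u) - f(theta)) / delta otherwise, and averages the resulting difference * u.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <random>
#include <vector>

using Params = std::vector<double>;
using Objective = std::function<double(const Params&)>;

// Generic zero-order gradient estimate of f at theta (hypothetical helper).
// delta: perturbation scale; batchSize: number of random directions;
// antithetic: evaluate the mirrored pair theta + delta*u and theta - delta*u.
Params zeroOrderGradient(const Objective& f, const Params& theta, double delta,
                         int batchSize, bool antithetic, std::mt19937& rng) {
  assert(delta > 0 && batchSize > 0);
  std::normal_distribution<double> gauss(0.0, 1.0);
  const std::size_t n = theta.size();
  Params grad(n, 0.0);
  // Baseline for one-sided differences; unused in the antithetic case.
  const double fTheta = antithetic ? 0.0 : f(theta);
  for (int b = 0; b < batchSize; ++b) {
    Params u(n), plus(theta), minus(theta);
    for (std::size_t i = 0; i < n; ++i) {
      u[i] = gauss(rng);
      plus[i] += delta * u[i];
      minus[i] -= delta * u[i];
    }
    const double diff = antithetic
        ? (f(plus) - f(minus)) / (2.0 * delta)  // central (mirrored) difference
        : (f(plus) - fTheta) / delta;           // forward difference
    for (std::size_t i = 0; i < n; ++i) grad[i] += diff * u[i] / batchSize;
  }
  return grad;
}
```

Antithetic (mirrored) sampling cancels the even-order terms of the Taylor expansion in each pair, which typically lowers the variance of the estimate at the cost of two evaluations per direction. A caller maximizing reward might then step theta[i] += lr * grad[i]; that update rule is likewise an assumption.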


The documentation for this class was generated from the following files: