TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
Off-policy policy gradient with a critic. More...
#include <policygradienttrainer.h>
Inherits cpid::Trainer.
Public Member Functions | |
ag::Variant | forward (ag::Variant inp, EpisodeHandle const &) override |
bool | update () override |
void | doOnlineUpdatesInstead () |
int | episodes () |
BatchedPGTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, double gamma=0.99, int batchSize=10, std::size_t maxBatchSize=50, std::unique_ptr< AsyncBatcher > batcher=nullptr) | |
virtual std::shared_ptr< ReplayBufferFrame > | makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override |
Contract: the trainer output should be a map with keys: "action" for the taken action, "V" for the state value, and the probability of the taken action. More... | |
std::shared_ptr< Evaluator > | makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler=std::make_unique< DiscreteMaxSampler >()) override |
Public Member Functions inherited from cpid::Trainer | |
Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr) | |
ag::Variant | forwardUnbatched (ag::Variant in, ag::Container model=nullptr) |
Convenience function for when one needs to forward a single input. More... | |
virtual | ~Trainer ()=default |
void | setTrain (bool=true) |
bool | isTrain () const |
ag::Variant | sample (ag::Variant in) |
Sample using the class' sampler. More... | |
ag::Container | model () const |
ag::Optimizer | optim () const |
ReplayBuffer & | replayBuffer () |
void | setDone (bool=true) |
bool | isDone () const |
virtual void | step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false) |
virtual EpisodeHandle | startEpisode () |
Returns a valid handle if an episode was successfully registered, and an invalid handle otherwise. More... | |
virtual void | forceStopEpisode (EpisodeHandle const &) |
bool | isActive (EpisodeHandle const &) |
virtual void | reset () |
Releases all the worker threads so that they can be joined. More... | |
template<class Archive > | |
void | save (Archive &ar) const |
template<class Archive > | |
void | load (Archive &ar) |
template<typename T > | |
bool | is () const |
Trainer & | setMetricsContext (std::shared_ptr< MetricsContext > context) |
std::shared_ptr< MetricsContext > | metricsContext () const |
TORCH_ARG (float, noiseStd) | |
TORCH_ARG (bool, continuousActions) | |
void | setBatcher (std::unique_ptr< AsyncBatcher > batcher) |
Protected Member Functions | |
void | stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override |
Protected Member Functions inherited from cpid::Trainer | |
virtual void | stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) |
virtual void | stepGame (GameUID const &game) |
template<typename T > | |
std::vector< T const * > | cast (ReplayBuffer::Episode const &e) |
Additional Inherited Members | |
Public Types inherited from cpid::Trainer | |
using | ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)> |
Static Public Member Functions inherited from cpid::Trainer | |
static std::shared_ptr< Evaluator > | evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func) |
Protected Attributes inherited from cpid::Trainer | |
ag::Container | model_ |
ag::Optimizer | optim_ |
std::shared_ptr< MetricsContext > | metricsContext_ |
ReplayBuffer | replayer_ |
bool | train_ = true |
std::atomic< bool > | done_ {false} |
std::mutex | modelWriteMutex_ |
std::shared_timed_mutex | activeMapMutex_ |
std::unique_ptr< BaseSampler > | sampler_ |
std::unique_ptr< AsyncBatcher > | batcher_ |
std::shared_ptr< HandleGuard > | epGuard_ |
ReplayBuffer::UIDKeyStore | actives_ |
Static Protected Attributes inherited from cpid::Trainer | |
static constexpr float | kFwdMetricsSubsampling = 0.1 |
We subsample kFwdMetricsSubsampling of the forward() events when measuring their duration. More... | |
Off-policy policy gradient with a critic.
This Trainer implements two modes: batched updates over buffered episodes (the default), and online updates (enabled via doOnlineUpdatesInstead()).
Replayer format: state, action, p(action), reward
Model output: probability vector over actions (1-dim vector) and the critic's value estimate (double)
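The sketch below illustrates how these pieces are typically wired together during one episode. It is an illustrative example rather than code from the TorchCraftAI sources: getObservation(), getReward() and episodeFinished() are hypothetical environment hooks, and only trainer calls documented on this page are used.

```cpp
#include <policygradienttrainer.h>

// Hedged usage sketch, not code from the TorchCraftAI sources.
ag::Variant getObservation();  // hypothetical: current state as an ag::Variant
float getReward();             // hypothetical: reward for the last step
bool episodeFinished();        // hypothetical: end-of-episode flag

void runOneEpisode(cpid::BatchedPGTrainer& trainer) {
  auto handle = trainer.startEpisode();
  bool done = false;
  while (trainer.isActive(handle) && !done) {
    ag::Variant state = getObservation();
    // forward() evaluates the model for this state (through the sampler and
    // optional batcher) and returns the trainer output consumed below.
    ag::Variant out = trainer.forward(state, handle);
    done = episodeFinished();
    // makeFrame() packages (trainer output, state, reward) into a replay
    // buffer frame; step() appends it to this episode.
    trainer.step(handle, trainer.makeFrame(out, state, getReward()), done);
  }
  // update() attempts a policy-gradient update; in the default (batched)
  // mode this presumably happens once enough complete episodes are buffered.
  trainer.update();
}
```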
cpid::BatchedPGTrainer::BatchedPGTrainer (
        ag::Container model,
        ag::Optimizer optim,
        std::unique_ptr< BaseSampler > sampler,
        double gamma = 0.99,
        int batchSize = 10,
        std::size_t maxBatchSize = 50,
        std::unique_ptr< AsyncBatcher > batcher = nullptr )
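For illustration, a hedged construction sketch (not code from the sources): 'model' and 'optim' are assumed to be created elsewhere, DiscreteMaxSampler (the one sampler type named on this page) is assumed to live in the cpid namespace, and the comments on batchSize and maxBatchSize are inferences from the parameter names.

```cpp
#include <memory>
#include <policygradienttrainer.h>

// Hedged construction sketch, not code from the TorchCraftAI sources.
std::unique_ptr<cpid::BatchedPGTrainer> makeTrainer(
    ag::Container model,
    ag::Optimizer optim) {
  // The AsyncBatcher argument is left at its default (nullptr).
  return std::make_unique<cpid::BatchedPGTrainer>(
      model,
      optim,
      std::make_unique<cpid::DiscreteMaxSampler>(),
      /*gamma=*/0.99,        // discount factor for returns
      /*batchSize=*/10,      // presumably the number of episodes per update
      /*maxBatchSize=*/50);  // presumably a cap on buffered episodes
}
```

DiscreteMaxSampler picks the highest-probability action deterministically; for actual policy-gradient training a stochastic sampler would presumably be passed instead, but it is the only concrete sampler this page names.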
void cpid::BatchedPGTrainer::doOnlineUpdatesInstead ( )
inline
ag::Variant cpid::BatchedPGTrainer::forward ( ag::Variant inp, EpisodeHandle const & )
overridevirtual
Reimplemented from cpid::Trainer.
std::shared_ptr< Evaluator > cpid::BatchedPGTrainer::makeEvaluator ( size_t, std::unique_ptr< BaseSampler > sampler = std::make_unique< DiscreteMaxSampler >() )
overridevirtual
Reimplemented from cpid::Trainer.
std::shared_ptr< ReplayBufferFrame > cpid::BatchedPGTrainer::makeFrame ( ag::Variant trainerOutput, ag::Variant state, float reward )
overridevirtual
Contract: the trainer output should be a map with keys: "action" for the taken action, "V" for the state value, and the probability of the taken action.
Implements cpid::Trainer.
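As a rough illustration of this contract, the assumption-laden sketch below builds such an output and turns it into a replay-buffer frame. The keys "action" and "V" come from the contract above; "pAction" is a hypothetical placeholder, since the page does not name the action-probability key unambiguously; ag::VariantDict is assumed to be the string-keyed map type accepted by ag::Variant, and ReplayBufferFrame is assumed to live in the cpid namespace.

```cpp
#include <torch/torch.h>
#include <policygradienttrainer.h>

// Hedged sketch of the makeFrame() contract, not code from the sources.
std::shared_ptr<cpid::ReplayBufferFrame> frameFromModelOutput(
    cpid::BatchedPGTrainer& trainer,
    ag::Variant state,
    torch::Tensor action,      // taken action, assumed to be an integer index
    torch::Tensor actionProbs, // 1-dim probability vector from the policy head
    torch::Tensor value,       // critic's value estimate
    float reward) {
  ag::Variant out(ag::VariantDict{
      {"action", action},                               // documented key
      {"V", value},                                     // documented key
      {"pAction", actionProbs[action.item<int64_t>()]}  // hypothetical key
  });
  return trainer.makeFrame(out, state, reward);
}
```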
void cpid::BatchedPGTrainer::stepEpisode ( GameUID const &, EpisodeKey const &, ReplayBuffer::Episode & )
overrideprotectedvirtual
Reimplemented from cpid::Trainer.
bool cpid::BatchedPGTrainer::update ( )
overridevirtual
Implements cpid::Trainer.