TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
|
#include <estrainer.h>
Inherits cpid::Trainer.
Public Types | |
enum | RewardTransform { kNone = 0, kRankTransform, kStdNormalize } |
Public Member Functions | |
ESTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, float std, size_t batchSize, size_t historyLength, bool antithetic, RewardTransform transform, bool onPolicy) | |
ag::Container | getGameModel (GameUID const &gameIUID, EpisodeKey const &key) |
void | forceStopEpisode (EpisodeHandle const &) override |
EpisodeHandle | startEpisode () override |
Returns true if an episode was successfully registered, and false otherwise. More... | |
bool | update () override |
virtual ag::Variant | forward (ag::Variant inp, EpisodeHandle const &) override |
std::shared_ptr< Evaluator > | makeEvaluator (size_t n, std::unique_ptr< BaseSampler > sampler=std::make_unique< DiscreteMaxSampler >()) override |
torch::Tensor | rewardTransform (torch::Tensor const &rewards, RewardTransform transform) |
void | reset () override |
Releases all the worker threads so that they can be joined. More... | |
virtual std::shared_ptr< ReplayBufferFrame > | makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override |
TORCH_ARG (bool, waitUpdate) | |
Public Member Functions inherited from cpid::Trainer | |
Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr) | |
ag::Variant | forwardUnbatched (ag::Variant in, ag::Container model=nullptr) |
Convenience function for when one needs to forward a single input. More... | |
virtual | ~Trainer ()=default |
void | setTrain (bool=true) |
bool | isTrain () const |
ag::Variant | sample (ag::Variant in) |
Sample using the class' sampler. More... | |
ag::Container | model () const |
ag::Optimizer | optim () const |
ReplayBuffer & | replayBuffer () |
void | setDone (bool=true) |
bool | isDone () const |
virtual void | step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false) |
bool | isActive (EpisodeHandle const &) |
template<class Archive > | |
void | save (Archive &ar) const |
template<class Archive > | |
void | load (Archive &ar) |
template<typename T > | |
bool | is () const |
Trainer & | setMetricsContext (std::shared_ptr< MetricsContext > context) |
std::shared_ptr< MetricsContext > | metricsContext () const |
TORCH_ARG (float, noiseStd) | |
TORCH_ARG (bool, continuousActions) | |
void | setBatcher (std::unique_ptr< AsyncBatcher > batcher) |
Protected Member Functions | |
virtual void | stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override |
ag::Container | generateModel (int generation, int64_t seed) |
Re-creates model based on its seed and the generation it was produced from. More... | |
void | populateSeedQueue () |
Protected Member Functions inherited from cpid::Trainer | |
virtual void | stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) |
virtual void | stepGame (GameUID const &game) |
template<typename T > | |
std::vector< T const * > | cast (ReplayBuffer::Episode const &e) |
Protected Attributes | |
float | std_ |
size_t | batchSize_ |
std::unordered_map< std::pair< int, int64_t >, ag::Container, pairhash > | modelCache_ |
std::unordered_map< std::pair< GameUID, EpisodeKey >, std::pair< int, int64_t >, pairhash > | gameToGenerationSeed_ |
std::shared_timed_mutex | modelStorageMutex_ |
size_t | historyLength_ |
std::deque< std::pair< int, ag::Container > > | modelsHistory_ |
std::shared_timed_mutex | currentModelMutex_ |
bool | antithetic_ |
RewardTransform | transform_ |
bool | onPolicy_ |
std::shared_timed_mutex | insertionMutex_ |
std::deque< std::pair< GameUID, EpisodeKey > > | newGames_ |
std::mutex | seedQueueMutex_ |
std::vector< int64_t > | seedQueue_ |
std::mutex | updateMutex_ |
size_t | gamesStarted_ = 0 |
std::condition_variable | batchBarrier_ |
size_t | gatherSize_ |
std::vector< float > | allRewards_ |
std::vector< int > | allGenerations_ |
std::vector< int64_t > | allSeeds_ |
std::vector< float > | rewards_ |
std::vector< int > | generations_ |
std::vector< int64_t > | seeds_ |
Protected Attributes inherited from cpid::Trainer | |
ag::Container | model_ |
ag::Optimizer | optim_ |
std::shared_ptr< MetricsContext > | metricsContext_ |
ReplayBuffer | replayer_ |
bool | train_ = true |
std::atomic< bool > | done_ {false} |
std::mutex | modelWriteMutex_ |
std::shared_timed_mutex | activeMapMutex_ |
std::unique_ptr< BaseSampler > | sampler_ |
std::unique_ptr< AsyncBatcher > | batcher_ |
std::shared_ptr< HandleGuard > | epGuard_ |
ReplayBuffer::UIDKeyStore | actives_ |
Additional Inherited Members | |
Protected Types inherited from cpid::Trainer | |
using | ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)> |
Static Protected Member Functions inherited from cpid::Trainer | |
static std::shared_ptr< Evaluator > | evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func) |
Static Protected Attributes inherited from cpid::Trainer | |
static constexpr float | kFwdMetricsSubsampling = 0.1 |
We subsample kFwdMetricsSubsampling of the forward() events when measuring their duration. More... | |
cpid::ESTrainer::ESTrainer | ( | ag::Container | model, |
ag::Optimizer | optim, | ||
std::unique_ptr< BaseSampler > | sampler, | ||
float | std, | ||
size_t | batchSize, | ||
size_t | historyLength, | ||
bool | antithetic, | ||
RewardTransform | transform, | ||
bool | onPolicy | ||
) |
|
overridevirtual |
Reimplemented from cpid::Trainer.
|
overridevirtual |
Reimplemented from cpid::Trainer.
|
protected |
Re-creates model based on its seed and the generation it was produced from.
The absolute value of the seed is used for seeding, and its sign indicates if we add or subtract the noise, which is used to implement antithetic variates.
ag::Container cpid::ESTrainer::getGameModel | ( | GameUID const & | gameIUID, |
EpisodeKey const & | key | ||
) |
|
overridevirtual |
Reimplemented from cpid::Trainer.
|
overridevirtual |
Implements cpid::Trainer.
|
protected |
|
overridevirtual |
Releases all the worker threads so that they can be joined.
For the off-policy trainers, labels all games as inactive. For the on-policy trainers, additionally un-blocks all threads that could be waiting at the batch barrier.
Reimplemented from cpid::Trainer.
torch::Tensor cpid::ESTrainer::rewardTransform | ( | torch::Tensor const & | rewards, |
ESTrainer::RewardTransform | transform | ||
) |
|
overridevirtual |
Returns true if an episode was successfully registered, and false otherwise.
After receiving false, a worker thread should check the stopping conditions and re-try.
Reimplemented from cpid::Trainer.
|
overrideprotectedvirtual |
Reimplemented from cpid::Trainer.
cpid::ESTrainer::TORCH_ARG | ( | bool | , |
waitUpdate | |||
) |
|
overridevirtual |
Implements cpid::Trainer.
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |
|
protected |