11 #include "policygradienttrainer.h" 13 #include <shared_mutex> 15 #include "distributed.h" 39 std::unordered_map<std::pair<int, int64_t>, ag::Container,
pairhash>
43 std::pair<GameUID, EpisodeKey>,
44 std::pair<int, int64_t>,
63 std::deque<std::pair<GameUID, EpisodeKey>>
newGames_;
93 std::unique_ptr<BaseSampler> sampler,
108 std::unique_ptr<BaseSampler> sampler =
109 std::make_unique<DiscreteMaxSampler>())
override;
111 torch::Tensor
const& rewards,
113 void reset()
override;
114 virtual std::shared_ptr<ReplayBufferFrame>
makeFrame(
115 ag::Variant trainerOutput,
117 float reward)
override;
torch::Tensor rewardTransform(torch::Tensor const &rewards, RewardTransform transform)
Definition: estrainer.cpp:458
std::string GameUID
Definition: trainer.h:31
ag::Container model() const
Definition: trainer.cpp:231
size_t gamesStarted_
Definition: estrainer.h:69
std::shared_timed_mutex currentModelMutex_
Definition: estrainer.h:55
std::vector< int64_t > seedQueue_
Definition: estrainer.h:66
Definition: trainer.h:158
virtual std::shared_ptr< ReplayBufferFrame > makeFrame(ag::Variant trainerOutput, ag::Variant state, float reward) override
Definition: estrainer.cpp:505
Definition: estrainer.h:30
std::deque< std::pair< int, ag::Container > > modelsHistory_
Definition: estrainer.h:54
void reset() override
Releases all the worker threads so that they can be joined.
Definition: estrainer.cpp:493
std::mutex seedQueueMutex_
Definition: estrainer.h:65
bool onPolicy_
Definition: estrainer.h:60
std::shared_timed_mutex insertionMutex_
Definition: estrainer.h:62
std::condition_variable batchBarrier_
Definition: estrainer.h:70
std::vector< float > rewards_
Definition: estrainer.h:77
std::vector< int > allGenerations_
Definition: estrainer.h:74
virtual void stepEpisode(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
Definition: estrainer.cpp:81
RewardTransform
Definition: estrainer.h:22
Definition: estrainer.h:23
The Trainer should be shared amongst multiple different nodes, and attached to a single Module...
Definition: trainer.h:156
void populateSeedQueue()
Definition: estrainer.cpp:485
bool antithetic_
Definition: estrainer.h:57
std::deque< std::pair< GameUID, EpisodeKey > > newGames_
Definition: estrainer.h:63
std::vector< float > allRewards_
Definition: estrainer.h:73
EpisodeHandle startEpisode() override
Returns true if it succeeded in registering an episode, and false otherwise.
Definition: estrainer.cpp:318
virtual ag::Variant forward(ag::Variant inp, EpisodeHandle const &) override
Definition: estrainer.cpp:437
Definition: estrainer.h:27
bool update() override
Definition: estrainer.cpp:100
void forceStopEpisode(EpisodeHandle const &) override
Definition: estrainer.cpp:306
std::vector< int64_t > seeds_
Definition: estrainer.h:79
std::shared_timed_mutex modelStorageMutex_
Definition: estrainer.h:47
size_t gatherSize_
Definition: estrainer.h:72
std::string EpisodeKey
Definition: trainer.h:32
ag::Container getGameModel(GameUID const &gameIUID, EpisodeKey const &key)
Definition: estrainer.cpp:371
std::vector< int64_t > allSeeds_
Definition: estrainer.h:75
float std_
Definition: estrainer.h:34
size_t historyLength_
Definition: estrainer.h:50
std::shared_ptr< Evaluator > makeEvaluator(size_t n, std::unique_ptr< BaseSampler > sampler=std::make_unique< DiscreteMaxSampler >()) override
Definition: estrainer.cpp:445
ag::Container generateModel(int generation, int64_t seed)
Re-creates a model from its seed and the generation it was produced from.
Definition: estrainer.cpp:394
RewardTransform transform_
Definition: estrainer.h:58
ag::Optimizer optim() const
Definition: trainer.cpp:235
The TorchCraftAI training library.
Definition: batcher.cpp:15
TORCH_ARG(bool, waitUpdate)
size_t batchSize_
Definition: estrainer.h:35
std::unordered_map< std::pair< GameUID, EpisodeKey >, std::pair< int, int64_t >, pairhash > gameToGenerationSeed_
Definition: estrainer.h:46
ESTrainer(ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, float std, size_t batchSize, size_t historyLength, bool antithetic, RewardTransform transform, bool onPolicy)
Definition: estrainer.cpp:53
std::vector< std::shared_ptr< ReplayBufferFrame >> Episode
Definition: trainer.h:89
std::vector< int > generations_
Definition: estrainer.h:78
std::unordered_map< std::pair< int, int64_t >, ag::Container, pairhash > modelCache_
Definition: estrainer.h:40
std::mutex updateMutex_
Definition: estrainer.h:68
Definition: estrainer.h:20