12 #include <shared_mutex> 14 #include "distributed.h" 28 std::deque<std::pair<GameUID, EpisodeKey>>
newGames_;
39 std::unique_ptr<BaseSampler> sampler,
48 void reset()
override;
49 std::shared_ptr<ReplayBufferFrame>
makeFrame(
52 float reward)
override;
66 std::unique_ptr<BaseSampler> s,
std::string GameUID
Definition: trainer.h:31
ag::Container model() const
Definition: trainer.cpp:231
std::mutex updateMutex_
Definition: evaluator.h:26
Definition: trainer.h:158
ForwardFunction forwardFunction_
Definition: evaluator.h:35
Definition: evaluator.h:19
EpisodeHandle startEpisode() override
Returns true if succeeded to register an episode, and false otherwise.
Definition: evaluator.cpp:82
std::shared_ptr< ReplayBufferFrame > makeFrame(ag::Variant, ag::Variant, float reward) override
Definition: evaluator.cpp:123
The Trainer should be shared amongst multiple different nodes, and attached to a single Module...
Definition: trainer.h:156
size_t gamesStarted_
Definition: evaluator.h:23
void reset() override
Releases all the worker threads so that they can be joined.
Definition: evaluator.cpp:116
void forceStopEpisode(EpisodeHandle const &) override
Definition: evaluator.cpp:102
Definition: evaluator.h:62
std::condition_variable batchBarrier_
Definition: evaluator.h:24
std::shared_timed_mutex insertionMutex_
Definition: evaluator.h:27
virtual void stepEpisode(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
Definition: evaluator.cpp:26
std::string EpisodeKey
Definition: trainer.h:32
std::function< ag::Variant(ag::Variant, EpisodeHandle const &)> ForwardFunction
Definition: trainer.h:290
The TorchCraftAI training library.
Definition: batcher.cpp:15
Evaluator(ag::Container model, std::unique_ptr< BaseSampler > sampler, size_t batchSize, ForwardFunction func)
Definition: evaluator.cpp:15
size_t batchSize_
Definition: evaluator.h:21
make_shared_enabler(ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction f)
Definition: evaluator.h:64
std::vector< std::shared_ptr< ReplayBufferFrame >> Episode
Definition: trainer.h:89
virtual ag::Variant forward(ag::Variant inp, EpisodeHandle const &) override
Definition: evaluator.cpp:110
bool update() override
Definition: evaluator.cpp:34
std::deque< std::pair< GameUID, EpisodeKey > > newGames_
Definition: evaluator.h:28