TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
evaluator.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
#include <condition_variable>
#include <deque>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <utility>

#include "metrics.h"
#include "trainer.h"

#include "distributed.h"
#include "sampler.h"
16 
17 namespace cpid {
18 
19 class Evaluator : public Trainer {
20  protected:
21  size_t batchSize_;
22 
23  size_t gamesStarted_ = 0;
24  std::condition_variable batchBarrier_;
25 
26  std::mutex updateMutex_;
27  std::shared_timed_mutex insertionMutex_;
28  std::deque<std::pair<GameUID, EpisodeKey>> newGames_;
29 
30  virtual void stepEpisode(
31  GameUID const&,
32  EpisodeKey const&,
33  ReplayBuffer::Episode&) override;
34 
36 
37  Evaluator(
38  ag::Container model,
39  std::unique_ptr<BaseSampler> sampler,
40  size_t batchSize,
41  ForwardFunction func);
42 
43  public:
44  EpisodeHandle startEpisode() override;
45  void forceStopEpisode(EpisodeHandle const&) override;
46  bool update() override;
47  virtual ag::Variant forward(ag::Variant inp, EpisodeHandle const&) override;
48  void reset() override;
49  std::shared_ptr<ReplayBufferFrame> makeFrame(
50  ag::Variant /*trainerOutput*/,
51  ag::Variant /*state*/,
52  float reward) override;
53 
54  struct make_shared_enabler;
55 };
56 
57 // IMPLEMENTATION DETAIL:
58 // This guy is just here to help Trainer make a shared pointer when the
59 // constructor to Evaluator is private. Nobody else can call that constructor,
60 // not even make_shared, but this enabler + friending Trainer helps Trainer make
61 // a shared_ptr<Evaluator>.
63  friend class Trainer;
65  ag::Container model,
66  std::unique_ptr<BaseSampler> s,
67  size_t n,
69  : Evaluator(model, std::move(s), n, f) {}
70 };
71 
72 } // namespace cpid
std::string GameUID
Definition: trainer.h:31
ag::Container model() const
Definition: trainer.cpp:231
std::mutex updateMutex_
Definition: evaluator.h:26
Definition: trainer.h:158
ForwardFunction forwardFunction_
Definition: evaluator.h:35
class Evaluator
Definition: evaluator.h:19
EpisodeHandle startEpisode() override
Returns true if it succeeded in registering an episode, and false otherwise.
Definition: evaluator.cpp:82
STL namespace.
std::shared_ptr< ReplayBufferFrame > makeFrame(ag::Variant, ag::Variant, float reward) override
Definition: evaluator.cpp:123
The Trainer should be shared amongst multiple different nodes, and attached to a single Module...
Definition: trainer.h:156
size_t gamesStarted_
Definition: evaluator.h:23
void reset() override
Releases all the worker threads so that they can be joined.
Definition: evaluator.cpp:116
void forceStopEpisode(EpisodeHandle const &) override
Definition: evaluator.cpp:102
struct Evaluator::make_shared_enabler
Definition: evaluator.h:62
std::condition_variable batchBarrier_
Definition: evaluator.h:24
std::shared_timed_mutex insertionMutex_
Definition: evaluator.h:27
virtual void stepEpisode(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
Definition: evaluator.cpp:26
std::string EpisodeKey
Definition: trainer.h:32
std::function< ag::Variant(ag::Variant, EpisodeHandle const &)> ForwardFunction
Definition: trainer.h:290
The TorchCraftAI training library.
Definition: batcher.cpp:15
Evaluator(ag::Container model, std::unique_ptr< BaseSampler > sampler, size_t batchSize, ForwardFunction func)
Definition: evaluator.cpp:15
size_t batchSize_
Definition: evaluator.h:21
make_shared_enabler(ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction f)
Definition: evaluator.h:64
std::vector< std::shared_ptr< ReplayBufferFrame >> Episode
Definition: trainer.h:89
virtual ag::Variant forward(ag::Variant inp, EpisodeHandle const &) override
Definition: evaluator.cpp:110
bool update() override
Definition: evaluator.cpp:34
std::deque< std::pair< GameUID, EpisodeKey > > newGames_
Definition: evaluator.h:28