TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
estrainer.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "metrics.h"
11 #include "policygradienttrainer.h"
12 #include "sampler.h"
13 #include <shared_mutex>
14 
15 #include "distributed.h"
16 #include <stack>
17 
18 namespace cpid {
19 
20 class ESTrainer : public Trainer {
21  public:
23  kNone = 0,
24  // Transforms a vector of elements into a vector of floats
25  // uniformly distributed within [-0.5,+0.5] according to their ranks.
26  // Used in https://arxiv.org/pdf/1703.03864.pdf
28  // Divides by the std of the rewards.
29  // Defined in https://arxiv.org/pdf/1803.07055.pdf
31  };
32 
33  protected:
34  float std_;
35  size_t batchSize_;
36 
37  // (generation, seed) => model mapping used to speed up forward in the active
38  // local models
39  std::unordered_map<std::pair<int, int64_t>, ag::Container, pairhash>
41  // (GameUID, Key) => (generation, seed) for the active local models
42  std::unordered_map<
43  std::pair<GameUID, EpisodeKey>,
44  std::pair<int, int64_t>,
45  pairhash>
47  std::shared_timed_mutex modelStorageMutex_;
48 
49  // the max number of historical models to store for the off-policy mode
51  // historyLength_ pairs of (generationId, model) stored in the order of
52  // sequentially increasing generations (front() is the oldest, back() is
53  // the newest)
54  std::deque<std::pair<int, ag::Container>> modelsHistory_;
55  std::shared_timed_mutex currentModelMutex_;
56 
59 
60  bool onPolicy_;
61 
62  std::shared_timed_mutex insertionMutex_;
63  std::deque<std::pair<GameUID, EpisodeKey>> newGames_;
64 
65  std::mutex seedQueueMutex_;
66  std::vector<int64_t> seedQueue_;
67 
68  std::mutex updateMutex_;
69  size_t gamesStarted_ = 0;
70  std::condition_variable batchBarrier_;
71 
72  size_t gatherSize_;
73  std::vector<float> allRewards_;
74  std::vector<int> allGenerations_;
75  std::vector<int64_t> allSeeds_;
76 
77  std::vector<float> rewards_;
78  std::vector<int> generations_;
79  std::vector<int64_t> seeds_;
80 
81  virtual void stepEpisode(
82  GameUID const&,
83  EpisodeKey const&,
84  ReplayBuffer::Episode&) override;
85 
86  ag::Container generateModel(int generation, int64_t seed);
87  void populateSeedQueue();
88 
89  public:
90  ESTrainer(
91  ag::Container model,
92  ag::Optimizer optim,
93  std::unique_ptr<BaseSampler> sampler,
94  float std,
95  size_t batchSize,
96  size_t historyLength,
97  bool antithetic,
98  RewardTransform transform,
99  bool onPolicy);
100 
101  ag::Container getGameModel(GameUID const& gameIUID, EpisodeKey const& key);
102  void forceStopEpisode(EpisodeHandle const&) override;
103  EpisodeHandle startEpisode() override;
104  bool update() override;
105  virtual ag::Variant forward(ag::Variant inp, EpisodeHandle const&) override;
106  std::shared_ptr<Evaluator> makeEvaluator(
107  size_t n,
108  std::unique_ptr<BaseSampler> sampler =
109  std::make_unique<DiscreteMaxSampler>()) override;
110  torch::Tensor rewardTransform(
111  torch::Tensor const& rewards,
112  RewardTransform transform);
113  void reset() override;
114  virtual std::shared_ptr<ReplayBufferFrame> makeFrame(
115  ag::Variant trainerOutput,
116  ag::Variant state,
117  float reward) override;
118 
119  // If set to true, after successful update() worker threads would remain
120  // blocked until the next update() call.
121  TORCH_ARG(bool, waitUpdate) = false;
122 };
123 
124 } // namespace cpid
torch::Tensor rewardTransform(torch::Tensor const &rewards, RewardTransform transform)
Definition: estrainer.cpp:458
std::string GameUID
Definition: trainer.h:31
ag::Container model() const
Definition: trainer.cpp:231
size_t gamesStarted_
Definition: estrainer.h:69
std::shared_timed_mutex currentModelMutex_
Definition: estrainer.h:55
std::vector< int64_t > seedQueue_
Definition: estrainer.h:66
Definition: trainer.h:158
virtual std::shared_ptr< ReplayBufferFrame > makeFrame(ag::Variant trainerOutput, ag::Variant state, float reward) override
Definition: estrainer.cpp:505
Definition: estrainer.h:30
std::deque< std::pair< int, ag::Container > > modelsHistory_
Definition: estrainer.h:54
void reset() override
Releases all the worker threads so that they can be joined.
Definition: estrainer.cpp:493
std::mutex seedQueueMutex_
Definition: estrainer.h:65
STL namespace.
bool onPolicy_
Definition: estrainer.h:60
std::shared_timed_mutex insertionMutex_
Definition: estrainer.h:62
std::condition_variable batchBarrier_
Definition: estrainer.h:70
std::vector< float > rewards_
Definition: estrainer.h:77
std::vector< int > allGenerations_
Definition: estrainer.h:74
virtual void stepEpisode(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override
Definition: estrainer.cpp:81
RewardTransform
Definition: estrainer.h:22
kNone
Definition: estrainer.h:23
The Trainer should be shared amongst multiple different nodes, and attached to a single Module...
Definition: trainer.h:156
void populateSeedQueue()
Definition: estrainer.cpp:485
bool antithetic_
Definition: estrainer.h:57
std::deque< std::pair< GameUID, EpisodeKey > > newGames_
Definition: estrainer.h:63
std::vector< float > allRewards_
Definition: estrainer.h:73
EpisodeHandle startEpisode() override
Returns true if the episode was successfully registered, and false otherwise.
Definition: estrainer.cpp:318
virtual ag::Variant forward(ag::Variant inp, EpisodeHandle const &) override
Definition: estrainer.cpp:437
Definition: estrainer.h:27
bool update() override
Definition: estrainer.cpp:100
void forceStopEpisode(EpisodeHandle const &) override
Definition: estrainer.cpp:306
std::vector< int64_t > seeds_
Definition: estrainer.h:79
std::shared_timed_mutex modelStorageMutex_
Definition: estrainer.h:47
size_t gatherSize_
Definition: estrainer.h:72
std::string EpisodeKey
Definition: trainer.h:32
ag::Container getGameModel(GameUID const &gameIUID, EpisodeKey const &key)
Definition: estrainer.cpp:371
std::vector< int64_t > allSeeds_
Definition: estrainer.h:75
float std_
Definition: estrainer.h:34
size_t historyLength_
Definition: estrainer.h:50
std::shared_ptr< Evaluator > makeEvaluator(size_t n, std::unique_ptr< BaseSampler > sampler=std::make_unique< DiscreteMaxSampler >()) override
Definition: estrainer.cpp:445
ag::Container generateModel(int generation, int64_t seed)
Re-creates model based on its seed and the generation it was produced from.
Definition: estrainer.cpp:394
RewardTransform transform_
Definition: estrainer.h:58
ag::Optimizer optim() const
Definition: trainer.cpp:235
The TorchCraftAI training library.
Definition: batcher.cpp:15
TORCH_ARG(bool, waitUpdate)
size_t batchSize_
Definition: estrainer.h:35
std::unordered_map< std::pair< GameUID, EpisodeKey >, std::pair< int, int64_t >, pairhash > gameToGenerationSeed_
Definition: estrainer.h:46
ESTrainer(ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, float std, size_t batchSize, size_t historyLength, bool antithetic, RewardTransform transform, bool onPolicy)
Definition: estrainer.cpp:53
Definition: trainer.h:44
std::vector< std::shared_ptr< ReplayBufferFrame >> Episode
Definition: trainer.h:89
std::vector< int > generations_
Definition: estrainer.h:78
std::unordered_map< std::pair< int, int64_t >, ag::Container, pairhash > modelCache_
Definition: estrainer.h:40
std::mutex updateMutex_
Definition: estrainer.h:68
ESTrainer
Definition: estrainer.h:20