TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
policygradienttrainer.h
/*
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "metrics.h"
#include "sampler.h"
#include "trainer.h"
#include <autogradpp/autograd.h>

#include <queue>

namespace cpid {

struct BatchedPGReplayBufferFrame : ReplayBufferFrame {
  BatchedPGReplayBufferFrame(
      ag::Variant state,
      torch::Tensor action,
      float pAction,
      double reward)
      : state(state), action(action), pAction(pAction), reward(reward) {}

  ag::Variant state;
  torch::Tensor action;
  /// Probability of action according to the policy that was used to obtain
  /// this frame
  float pAction;
  /// Reward observed since taking previous action
  double reward;
};
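
// For illustration, a single frame as stored in the replay buffer could be
// assembled like this (a minimal sketch; the state shape, action index,
// probability and reward are made-up values):
//
//   auto frame = std::make_shared<BatchedPGReplayBufferFrame>(
//       ag::Variant(torch::randn({128})), // observed state
//       torch::tensor({2}),               // sampled action
//       0.25f,                            // p(action) under the behavior policy
//       1.0);                             // reward since the previous action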

/**
 * Off-policy policy gradient with a critic.
 * This Trainer implements two modes:
 * - Online:
 *   It does one update with the given batch size per node whenever it gets an
 *   episode. Therefore, one episode will always be new and the others will
 *   be from the replay buffer. THIS MODE IS UNTESTED
 * - Offline:
 *   Many threads are assumed to generate episodes in the background, and
 *   it does updates in a separate background thread.
 * In both modes, it will first update on new episodes at least once before
 * moving on to sampling from the replay buffer. If more episodes are generated
 * than it can update on, it will block until the next update. When the replay
 * buffer of episodes it has already updated on reaches maxBatchSize, it will
 * remove the oldest episode it has seen.
 *
 * Replayer format:
 * state, action, p(action), reward
 *
 * Model output:
 * Probability vector over actions: 1-dim Vector
 * Critic's value estimate: Double
 * (a sketch of a forward() producing this output appears after the class
 * declaration below)
 */
class BatchedPGTrainer : public Trainer {
  int batchSize_;
  std::size_t maxBatchSize_;
  double gamma_;
  bool onlineUpdates_ = false;

  std::shared_timed_mutex updateMutex_;
  // Games that were not used for updating the model yet
  std::deque<std::pair<GameUID, EpisodeKey>> newGames_;
  // Games that were already used for updating the model but which are still in
  // the replay buffer. This will be kept <= maxBatchSize_; older games will be
  // removed first.
  std::queue<std::pair<GameUID, EpisodeKey>> seenGames_;
  std::mutex newGamesMutex_;
  bool enoughEpisodes_ = false;
  int episodes_ = 0;

  void updateModel();

 protected:
  void stepEpisode(GameUID const&, EpisodeKey const&, ReplayBuffer::Episode&)
      override;

 public:
  ag::Variant forward(ag::Variant inp, EpisodeHandle const&) override;
  bool update() override;
  void doOnlineUpdatesInstead();

  inline int episodes() {
    return episodes_;
  }

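  // A trainer could be constructed roughly as follows (an illustrative
  // sketch; "MyPolicyModel" is a hypothetical ag::Container, the sampler
  // choice is arbitrary, and the hyper-parameters simply echo the defaults
  // below):
  //
  //   auto model = MyPolicyModel().make();
  //   ag::Optimizer optim = std::make_shared<torch::optim::SGD>(
  //       model->parameters(), torch::optim::SGDOptions(1e-2));
  //   auto trainer = std::make_shared<BatchedPGTrainer>(
  //       model,
  //       optim,
  //       std::make_unique<MultinomialSampler>(),
  //       /*gamma=*/0.99,
  //       /*batchSize=*/10,
  //       /*maxBatchSize=*/50);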
  BatchedPGTrainer(
      ag::Container model,
      ag::Optimizer optim,
      std::unique_ptr<BaseSampler> sampler,
      double gamma = 0.99,
      int batchSize = 10,
      std::size_t maxBatchSize = 50,
      std::unique_ptr<AsyncBatcher> batcher = nullptr);

  /**
   * Contract: the trainer output should be a map with keys: "action" for the
   * taken action, "V" for the state value, and "pAction" for the action
   * probability.
   */
  virtual std::shared_ptr<ReplayBufferFrame> makeFrame(
      ag::Variant trainerOutput,
      ag::Variant state,
      float reward) override;
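
  // For example (a sketch; "action" and "V" are from the contract above, and
  // "pAction" matches the frame field, assuming the sampler stored the action
  // probability under that key):
  //
  //   ag::Variant trainerOutput = ag::VariantDict{
  //       {"action", torch::tensor({2})},
  //       {"pAction", torch::tensor(0.25f)},
  //       {"V", torch::tensor(0.7f)}};
  //   auto frame = trainer->makeFrame(trainerOutput, state, /*reward=*/1.0f);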
  std::shared_ptr<Evaluator> makeEvaluator(
      size_t,
      std::unique_ptr<BaseSampler> sampler =
          std::make_unique<DiscreteMaxSampler>()) override;
};
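
// A model plugged into this trainer is expected to produce the output
// described in the class comment: a 1-dim probability vector over actions
// plus the critic's scalar value estimate. A minimal sketch of such a
// forward() result, assuming the sampler's default key names ("Pi" for the
// distribution, "V" for the value) and hypothetical intermediate tensors
// `logits` and `valueLogit`:
//
//   torch::Tensor probs = torch::softmax(logits, 0); // 1-dim vector over actions
//   torch::Tensor value = valueLogit.squeeze();      // critic's scalar estimate
//   return ag::VariantDict{{"Pi", probs}, {"V", value}};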
} // namespace cpid