TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
zeroordertrainer.h
/*
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include "metrics.h"
#include "trainer.h"
#include <autogradpp/autograd.h>
#include <stack>
#include <torch/torch.h>

namespace cpid {

/**
 * See trainer for output format of these models
 */
class OnlineZORBModel {
 public:
  // To stay true to the paper, the noise should be on the unit sphere, i.e.
  // randn(size) / norm
  virtual std::vector<torch::Tensor> generateNoise() = 0;
};
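
// Illustration (not part of this header): a minimal sketch of a model
// satisfying the generateNoise() contract above, drawing one unit-sphere
// noise vector per action head. The struct name, embedding size and number
// of heads are assumptions made up for this sketch.
//
//   struct UnitSphereNoiseModel : OnlineZORBModel {
//     int64_t embedSize = 64; // must match the embed_size used by \phi and w
//     int numActionHeads = 2; // N distinct actions per frame
//     std::vector<torch::Tensor> generateNoise() override {
//       std::vector<torch::Tensor> noise;
//       for (int i = 0; i < numActionHeads; ++i) {
//         auto u = torch::randn({embedSize});
//         noise.push_back(u / u.norm()); // randn(size) / norm, as above
//       }
//       return noise;
//     }
//   };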

/**
 * State, action taken, reward.
 * Storing the action that was actually taken (rather than recomputing an
 * argmax) lets you use your own inference strategy, for example when some
 * actions are invalid.
 */
struct OnlineZORBReplayBufferFrame : ReplayBufferFrame {
  OnlineZORBReplayBufferFrame(
      std::vector<torch::Tensor> state,
      std::vector<long> actions,
      double reward)
      : state(state), actions(actions), reward(reward) {}
  std::vector<torch::Tensor> state;
  std::vector<long> actions;
  double reward;
};
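
// Illustration (not part of this header): building a frame from an action
// chosen with a validity mask instead of a plain argmax, which is what the
// explicit `actions` field enables. The `scores`, `invalid`, `stateTensors`
// and `reward` variables are assumed to come from your own inference code.
//
//   torch::Tensor scores = ...;  // [M_i] scores for one action head
//   torch::Tensor invalid = ...; // [M_i] boolean mask of invalid actions
//   scores.masked_fill_(invalid, -1e9); // rule out invalid actions
//   long taken = scores.argmax().item<long>();
//   auto frame = std::make_shared<OnlineZORBReplayBufferFrame>(
//       stateTensors, std::vector<long>{taken}, reward);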

/**
 * This is an OnlineZORBTrainer that works with multiple actions per frame.
 * Contract:
 * - Expect to make N distinct actions, with M_{i=1,..,N} possible actions
 *   for each
 * - The model inherits from OnlineZORBModel and returns a vector of noises
 * - Input: the replay buffer frame state
 * - Output: the model should take in a state and generate an array of
 *   [\phi(s, A_{1,i}), w_1, ind_1, v_1, ..., \phi(s, A_{N,i}), w_N, ind_N, v_N]
 *   - \phi is a matrix of size [M_i, embed_size] for each action i
 *   - w has size [embed_size]
 *   - ind is an index into the noise vectors generated by the model
 *   - v is a critic used for variance reduction. Optional (use an empty
 *     torch::Tensor() if unused)
 * Because of the particular way this trainer works, the sampler (inference
 * procedure) is part of the forward function.
 * - The inference procedure the trainer provides through trainer->forward is
 *   argmax_i \phi * (w + \delta * noise[ind])
 * - The critic is used in training by computing (return - critic), where the
 *   critic is trained on the return, as in actor-critic. This works because
 *     G = E_u [ f(x + d u) u ]
 *       = E_u [ f(x + d u) u ] - E_u [ v u ]   (u is Gaussian, so this term is 0)
 *       = E_u [ ( f(x + d u) - v ) u ]
 *
 * The trainer->forward function will return
 * [action_i, action_scores_i, ...]
 *
 * TODO: ON MULTIPLE NODES, THIS IS UNTESTED AND DOESN'T WORK,
 * even though the logic is mostly there
 */
class OnlineZORBTrainer : public Trainer {
  int64_t episodes_ = 0;
  std::mutex updateLock_;
  std::mutex noiseLock_;
  std::unordered_map<
      GameUID,
      std::unordered_map<EpisodeKey, std::vector<torch::Tensor>>>
      noiseStash_;
  std::vector<torch::Tensor> lastNoise_;

  std::atomic<int> nEpisodes_;

 public:
  void stepEpisode(GameUID const&, EpisodeKey const&, ReplayBuffer::Episode&)
      override;
  bool update() override;
  EpisodeHandle startEpisode() override;
  ag::Variant forward(ag::Variant inp, EpisodeHandle const&) override;

  OnlineZORBTrainer(ag::Container model, ag::Optimizer optim);
  // Contract: TrainerOutput is a map with a key "action" containing the taken
  // action
  virtual std::shared_ptr<ReplayBufferFrame> makeFrame(
      ag::Variant trainerOutput,
      ag::Variant state,
      float reward) override;

  // Set this to non-0 to not use a critic
  TORCH_ARG(float, valueLambda) = 0;
  TORCH_ARG(float, delta) = 1e-3;
  TORCH_ARG(int, batchSize) = 10;
  // Use antithetic sampling for the noise
  TORCH_ARG(bool, antithetic) = false;
};
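
// Illustration (not part of this header): the inference rule described in the
// class comment above, spelled out for a single action head i. Here `phi`,
// `w`, `ind` and `noise` come from the model output and generateNoise(), and
// `delta` corresponds to the delta() parameter; all variable names are
// assumptions made for this sketch.
//
//   // phi: [M_i, embed_size], w: [embed_size], noise[ind]: [embed_size]
//   torch::Tensor direction = w + delta * noise[ind];
//   torch::Tensor scores = phi.matmul(direction); // [M_i]
//   long action = scores.argmax().item<long>();   // argmax_i \phi * (w + \delta * u)
//   // At training time the gradient estimate weights the noise by
//   // (return - critic), which stays unbiased since E_u[critic * u] = 0.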

} // namespace cpid
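
// Example usage (illustration only): constructing the trainer and setting its
// TORCH_ARG parameters via the generated setters. `myModel` (an ag::Container
// whose module also implements OnlineZORBModel), `myOptim` (an ag::Optimizer),
// `state` (the ag::Variant frame state) and `reward` are assumed to be created
// elsewhere, and the hyper-parameter values are arbitrary.
//
//   auto trainer =
//       std::make_shared<cpid::OnlineZORBTrainer>(myModel, myOptim);
//   trainer->delta(1e-2)   // perturbation scale \delta
//       .batchSize(32)     // batch size used for updates
//       .antithetic(true); // sample noise in mirrored pairs
//   auto handle = trainer->startEpisode();
//   auto out = trainer->forward(state, handle); // [action_i, action_scores_i, ...]
//   auto frame = trainer->makeFrame(out, state, reward);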