TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
sampler.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 #include "trainer.h"
10 #include <autogradpp/autograd.h>
11 
12 namespace cpid {
13 
14 /**
15  * A sampler takes the output of the model, and outputs an action accordingly.
16  * The exact shape of the action is dependent on the rest of the training loop.
17  * For convenience, the base sampling function is the identity.
18  */
19 class BaseSampler {
20  public:
22 
23  virtual ~BaseSampler() = default;
24  virtual ag::Variant sample(ag::Variant in) {
25  return in;
26  };
27 
28  virtual ag::Variant computeProba(
29  const ag::Variant& in,
30  const ag::Variant& action) {
31  throw std::runtime_error("Proba computation not implemented...");
32  return ag::Variant(0);
33  }
34 };
35 
36 /**
37  * This sampler expects as input an unordered_map<string, Variant>, which
38  * contains an entry policyKey, which is a tensor of size [b, n]. It outputs the
39  * same map, with a new key actionKey, a tensor of size [b] where each entry is
40  * in [0,n-1], sampled from the multinomial distribution given by the policy.
41  * It also adds a key pActionKey holding the probability of the sampled action.
42  */
44  public:
46  const std::string& policyKey = kPiKey,
47  const std::string& actionKey = kActionKey,
48  const std::string& pActionKey = kPActionKey);
49  ag::Variant sample(ag::Variant in) override;
50  ag::Variant computeProba(const ag::Variant& in, const ag::Variant& action)
51  override;
52 
53  protected:
54  std::string policyKey_, actionKey_, pActionKey_;
55 };
56 
57 /**
58  * This sampler expects as input an unordered_map<string, Variant>, containing
59  * an entry policyKey, which is a tensor of size [b, n]. It outputs the same
60  * map, with a new key actionKey, a tensor of size [b] where each entry is in
61  * [0,n-1] and corresponds to the action with the highest score.
62  */
64  public:
66  const std::string& policyKey = kPiKey,
67  const std::string& actionKey = kActionKey);
68  ag::Variant sample(ag::Variant in) override;
69 
70  protected:
71  std::string policyKey_, actionKey_;
72 };
73 
74 /**
75  * This sampler expects as input an unordered_map<string, Variant>, containing
76  * an entry policyKey, which is a tensor of size [b, n]. It outputs the same
77  * map, with a new key actionKey, a tensor of size [b] where each entry
78  * action[i] is sampled from a normal distribution centered at policy[i]. It
79  * also expects stdKey to be set; its value is used as the standard deviation
80  * of the normal. It can be either a float/double, in which case the deviation
81  * will be the same for the whole batch, or a tensor of the same shape as the
82  * policy, for finer control. It also adds a key pActionKey which is the
83  * probability of the sampled action.
84  */
86  public:
88  const std::string& policyKey = kPiKey,
89  const std::string& stdKey = kSigmaKey,
90  const std::string& actionKey = kActionKey,
91  const std::string& pActionKey = kPActionKey);
92  ag::Variant sample(ag::Variant in) override;
93  ag::Variant computeProba(const ag::Variant& in, const ag::Variant& action)
94  override;
95 
96  protected:
97  std::string policyKey_, stdKey_;
98  std::string actionKey_, pActionKey_;
99 };
100 
101 /**
102  * This sampler expects as input an unordered_map<string, Variant> containing an
103  * entry policyKey, which is a tensor of size [b, n]. It outputs the same map,
104  * with a new key actionKey, holding a clone of the policy.
105  */
107  public:
109  const std::string& policyKey = kPiKey,
110  const std::string& actionKey = kActionKey);
111  ag::Variant sample(ag::Variant in) override;
112 
113  protected:
114  std::string policyKey_;
115  std::string actionKey_;
116 };
117 
118 /**
119  * This sampler expects as input an unordered_map<string, Variant> containing an
120  * entry QKey, which is a tensor of size [b, n]. It outputs the same map, with a
121  * new key actionKey, which contains the best action with probability 1-eps,
122  * and a random action with probability eps.
123  */
125  public:
127  double eps = 0.07,
128  const std::string& QKey = kQKey,
129  const std::string& actionKey = kActionKey);
130 
131  ag::Variant sample(ag::Variant in) override;
132 
133  double eps_;
134  std::string QKey_, actionKey_;
135 };
136 } // namespace cpid
const std::string kQKey
Definition: trainer.h:37
This sampler expects as input an unordered_map<string, Variant>, containing an entry QKey...
Definition: sampler.h:63
std::string policyKey_
Definition: sampler.h:114
This sampler expects as input an unordered_map<string, Variant> containing an entry QKey...
Definition: sampler.h:124
virtual ag::Variant computeProba(const ag::Variant &in, const ag::Variant &action)
Definition: sampler.h:28
const std::string kActionKey
Definition: trainer.h:41
const std::string kPiKey
Definition: trainer.h:38
std::string actionKey_
Definition: sampler.h:115
std::string QKey_
Definition: sampler.h:134
std::string pActionKey_
Definition: sampler.h:98
const std::string kPActionKey
Definition: trainer.h:42
std::string policyKey_
Definition: sampler.h:71
BaseSampler()
Definition: sampler.h:21
std::string stdKey_
Definition: sampler.h:97
A sampler takes the output of the model, and outputs an action accordingly.
Definition: sampler.h:19
virtual ~BaseSampler()=default
This sampler expects as input an unordered_map<string, Variant>, which contains an entry policyKey...
Definition: sampler.h:43
The TorchCraftAI training library.
Definition: batcher.cpp:15
virtual ag::Variant sample(ag::Variant in)
Definition: sampler.h:24
double eps_
Definition: sampler.h:133
This sampler expects as input an unordered_map<string, Variant>, containing an entry policyKey...
Definition: sampler.h:85
This sampler expects as input an unordered_map<string, Variant> containing an entry policyKey...
Definition: sampler.h:106
const std::string kSigmaKey
Definition: trainer.h:39
std::string policyKey_
Definition: sampler.h:54