TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
sample.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "cherrypi.h"
11 #include "features/features.h"
12 #include "features/unitsfeatures.h"
13 #include "modules/autobuild.h"
14 
15 #ifdef HAVE_CPID
16 #include <cpid/trainer.h>
17 #endif // HAVE_CPID
18 
19 #include <autogradpp/autograd.h>
20 
21 #include <regex>
22 
23 #include <cereal/archives/binary.hpp>
24 #include <cereal/types/vector.hpp>
25 
26 namespace cereal {
27 template <class Archive>
28 void serialize(Archive& archive, torchcraft::Resources& resource) {
29  archive(
30  resource.ore,
31  resource.gas,
32  resource.used_psi,
33  resource.total_psi,
34  resource.upgrades,
35  resource.upgrades_level,
36  resource.techs);
37 }
38 } // namespace cereal
39 
40 namespace cherrypi {
41 namespace bos {
42 
/// Map from build order string to integer id; buildOrderId() performs its
/// lookups in this map.
std::map<std::string, int64_t> const& buildOrderMap();
/// NOTE(review): defined outside this header -- presumably the list of build
/// orders that may be selected as targets; confirm against the definition.
std::vector<std::string> const& targetBuilds();
/// NOTE(review): defined outside this header -- presumably the allowed target
/// build orders formatted as a command-line flag value; confirm.
std::string allowedTargetsAsFlag();
/// NOTE(review): defined outside this header -- presumably the allowed opening
/// build orders formatted as a command-line flag value; confirm.
std::string allowedOpeningsAsFlag();
47 
/// Extracts the race letter from an opponent string.
///
/// Searches `opponent` for the first occurrence of `_P_`, `_Z_` or `_T_` and
/// returns the letter between the two underscores.
///
/// @param opponent string containing a race tag such as "_Z_"
/// @return 'P', 'Z' or 'T'
/// @throws std::runtime_error if `opponent` contains no race tag
inline char getOpponentRace(std::string const& opponent) {
  // Compiling a std::regex is expensive, so build it once instead of on every
  // call. Initialization of a function-local static is thread-safe, and
  // searching with a const regex does not modify it.
  static std::regex const kRaceRegex("_([PZT]{1})_");
  std::smatch match;
  if (!std::regex_search(opponent, match, kRaceRegex)) {
    throw std::runtime_error("Opponent string is not correct!");
  }
  // Capture group 1 is exactly the race letter
  return *match[1].first;
}
56 
/// Prepends a race letter and a dash to a build order name, e.g.
/// ("10hatchling", 'Z') becomes "Z-10hatchling".
///
/// @param buildOrder build order name without prefix
/// @param prefix race letter to put in front
/// @return `prefix`, followed by '-', followed by `buildOrder`
inline std::string addRacePrefix(std::string buildOrder, char prefix) {
  std::string prefixed;
  // Two extra characters: the race letter and the dash
  prefixed.reserve(buildOrder.size() + 2);
  prefixed += prefix;
  prefixed += '-';
  prefixed += buildOrder;
  return prefixed;
}
60 
61 inline std::string addRacePrefix(std::string buildOrder, int race) {
62  return addRacePrefix(
63  std::move(buildOrder),
64  tc::BW::Race::_from_integral(race)._to_string()[0]);
65 }
66 
/// Removes the two-character race prefix (race letter and dash) from a build
/// order name, e.g. "Z-10hatchling" becomes "10hatchling".
///
/// The first two characters are dropped unconditionally; they are not checked
/// for actually being a race prefix.
///
/// @param prefixedBo build order name including its prefix
/// @return `prefixedBo` without its first two characters
/// @throws std::out_of_range if `prefixedBo` has fewer than two characters
inline std::string stripRacePrefix(std::string prefixedBo) {
  return prefixedBo.substr(2);
}
70 
71 inline int64_t buildOrderId(std::string const& bo) {
72  auto const& boMap = buildOrderMap();
73  auto it = boMap.find(bo);
74  if (it == boMap.end()) {
75  throw std::runtime_error("Unknown build order: " + bo);
76  }
77  return it->second;
78 }
79 
80 torch::Tensor getBuildOrderMaskByRace(char race);
81 torch::Tensor getBuildOrderMaskByRace(int race);
82 
/// A list of possible features that can be extracted from a Sample
enum class BosFeature {
  Undef,
  /// Map features from StaticData
  Map,
  /// Map "ID" based on sum of map features
  MapId,
  /// 2-dimensional: our and their race
  Race,
  /// Defogger-style pooled unit types
  Units,
  // NOTE(review): the enumerators belonging to the five descriptions below
  // are missing from the extracted listing this file was taken from
  // (presumably one enumerator per description). Their names cannot be
  // recovered from this file. Restore them from the original header before
  // use: without them TimeAsFrame, ActiveBo and NextBo have different
  // underlying values than in the original enum.
  //   - Bag-of-words unit type counts
  //   - Bag-of-words unit type counts in future autobuild states (ours only)
  //   - Ore/Gas/UsedPsi/TotalPsi: log(x / 5 + 1)
  //   - 142-dim tech/upgrade vector: one bit for each upgrade/level/tech
  //   - 142-dim vector of pending upgrades/techs
  /// Numerical frame value
  TimeAsFrame,
  /// Id of active build order
  ActiveBo,
  /// Id of next build order
  NextBo,
};
111 
112 // Features that don't change throughout the game
113 struct StaticData {
114  static int constexpr kMapSize = 512; // walk tiles
115  static int constexpr kNumMapChannels = 4;
116 
117  /// Various map features
119  /// Probability of having taken a random switch (per sample).
120  /// Only used during supervised data generation.
121  float switchProba = 0.0f;
122  /// Race for our player (0) and the opponent (1)
123  int race[2];
124  /// Player name of opponent
125  std::string opponentName;
126  /// Did we win this game?
127  bool won = false;
128  /// Game Id (optional)
129  std::string gameId;
130 
131  StaticData() = default;
132  StaticData(State* state);
133 
134  template <class Archive>
135  void serialize(Archive& ar, uint32_t const version) {
136  ar(CEREAL_NVP(map),
137  CEREAL_NVP(switchProba),
138  CEREAL_NVP(race[0]),
139  CEREAL_NVP(race[1]),
140  CEREAL_NVP(opponentName),
141  CEREAL_NVP(won));
142  if (version > 0) {
143  ar(CEREAL_NVP(gameId));
144  }
145  }
146 };
147 
148 struct Sample {
149  std::shared_ptr<StaticData> staticData;
150  /// Defogger style unit types in spatial representation
152  /// Frame number of this sample
154  /// Our resources
156  /// Current build order
157  std::string buildOrder;
158  /// Build order until next sample
159  std::string nextBuildOrder;
160  /// Whether we've switched the build order after taking this sample
161  bool switched = false;
162  /// Upgrades that are currently being researched
163  uint64_t pendingUpgrades = 0;
164  /// Levels for upgrades that are currently researched
165  uint64_t pendingUpgradesLevel = 0;
166  /// Techs that are currently being researched
167  uint64_t pendingTechs = 0;
168  /// Future autobuild states for given frame offsets
169  std::map<int, autobuild::BuildState> nextAbboStates;
170 
171  Sample() = default;
172  Sample(
173  State* state,
174  int res,
175  int sride,
176  std::shared_ptr<StaticData> sd = nullptr);
177  virtual ~Sample() = default;
178 
179  template <class Archive>
180  void serialize(Archive& ar, uint32_t const version) {
181  if (version < 2) {
182  throw std::runtime_error("Unsupported version");
183  }
184  // Note that Cereal will serialized shared_ptr instances only once per
185  // archive
186  ar(CEREAL_NVP(staticData),
187  CEREAL_NVP(units),
188  CEREAL_NVP(frame),
189  CEREAL_NVP(resources),
190  CEREAL_NVP(buildOrder),
191  CEREAL_NVP(nextBuildOrder),
192  CEREAL_NVP(switched),
193  CEREAL_NVP(pendingUpgrades),
194  CEREAL_NVP(pendingUpgradesLevel),
195  CEREAL_NVP(pendingTechs),
196  CEREAL_NVP(nextAbboStates));
197  if (version == 2) {
198  // Unit features saved with /= 10 instead of gscore
199  renormV2Features();
200  }
201  }
202 
203  torch::Tensor featurize(
204  BosFeature feature,
205  torch::Tensor buffer = torch::Tensor()) const;
206  ag::tensor_list featurize(std::vector<BosFeature> features) const;
207  void renormV2Features();
208 
209  static std::map<int, autobuild::BuildState> simulateAbbo(
210  State* state,
211  std::string const& buildOrder,
212  std::vector<int> const& frameOffsets);
213 };
214 
215 #ifdef HAVE_CPID
216 struct ReplayBufferFrame : cpid::ReplayBufferFrame {
217  Sample sample;
218 
219  ReplayBufferFrame() = default;
220  ReplayBufferFrame(Sample sample) : sample(std::move(sample)) {}
221  virtual ~ReplayBufferFrame() = default;
222 
223  template <class Archive>
224  void serialize(Archive& ar, uint32_t const version) {
225  ar(cereal::base_class<cpid::ReplayBufferFrame>(this), sample);
226  }
227 };
228 
229 struct EpisodeData {
230  cpid::GameUID gameId;
231  cpid::EpisodeKey episodeKey;
232  std::vector<std::shared_ptr<cpid::ReplayBufferFrame>> frames;
233 
234  template <class Archive>
235  void serialize(Archive& ar) {
236  ar(gameId);
237  ar(episodeKey);
238  uint32_t size = frames.size();
239  ar(size);
240  frames.resize(size);
241  for (auto i = 0U; i < size; i++) {
242  ar(frames[i]);
243  }
244  }
245 };
246 #endif // HAVE_CPID
247 
248 } // namespace bos
249 
250 // Backwards compatibility with previous naming
251 // using BosStaticData = bos::StaticData;
252 // using BosSample = bos::Sample;
255 #ifdef HAVE_CPID
256 typedef bos::ReplayBufferFrame BosReplayBufferFrame;
257 typedef bos::EpisodeData BosEpisodeData;
258 #endif // HAVE_CPID
259 
260 } // namespace cherrypi
261 
262 CEREAL_CLASS_VERSION(cherrypi::BosStaticData, 2);
263 CEREAL_CLASS_VERSION(cherrypi::BosSample, 3);
264 #ifdef HAVE_CPID
265 CEREAL_REGISTER_TYPE(cherrypi::BosReplayBufferFrame);
266 CEREAL_CLASS_VERSION(cherrypi::BosReplayBufferFrame, 0);
267 #endif // HAVE_CPID
Map "ID" based on sum of map features.
Game state.
Definition: state.h:42
int FrameNum
Definition: basetypes.h:22
std::string GameUID
Definition: trainer.h:31
uint64_t upgrades_level
Definition: frame.h:172
int32_t used_psi
Definition: frame.h:169
std::string stripRacePrefix(std::string prefixedBo)
Definition: sample.h:67
std::string gameId
Game Id (optional)
Definition: sample.h:129
FeatureData units
Defogger style unit types in spatial representation.
Definition: sample.h:151
int32_t gas
Definition: frame.h:168
BosFeature
A list of possible features that can be extracted from a Sample.
Definition: sample.h:84
uint64_t techs
Definition: frame.h:173
std::string nextBuildOrder
Build order until next sample.
Definition: sample.h:159
void serialize(Archive &archive, torchcraft::Resources &resource)
Definition: sample.h:28
int64_t buildOrderId(std::string const &bo)
Definition: sample.h:71
std::string buildOrder
Current build order.
Definition: sample.h:157
char getOpponentRace(std::string const &opponent)
Definition: sample.h:48
Definition: frame.h:166
torch::Tensor getBuildOrderMaskByRace(char race)
Definition: sample.cpp:175
Ore/Gas/UsedPsi/TotalPsi: log(x / 5 + 1)
int32_t ore
Definition: frame.h:167
FrameNum frame
Frame number of this sample.
Definition: sample.h:153
void serialize(Archive &ar, uint32_t const version)
Definition: sample.h:135
std::shared_ptr< StaticData > staticData
Definition: sample.h:149
std::string opponentName
Player name of opponent.
Definition: sample.h:125
Map features from StaticData.
Definition: sample.h:26
Represents a collection of spatial feature data.
Definition: features.h:190
std::string addRacePrefix(std::string buildOrder, char prefix)
Definition: sample.h:57
2-dimensional: our and their race
std::string EpisodeKey
Definition: trainer.h:32
tc::Resources resources
Our resources.
Definition: sample.h:155
FeatureData map
Various map features.
Definition: sample.h:118
Id of active build order.
std::map< int, autobuild::BuildState > nextAbboStates
Future autobuild states for given frame offsets.
Definition: sample.h:169
uint64_t upgrades
Definition: frame.h:171
Stub base class for replay buffer frames.
Definition: trainer.h:69
Bag-of-words unit type counts.
int32_t total_psi
Definition: frame.h:170
Bag-of-words unit type counts in future autobuild states (ours only)
void serialize(Archive &ar, uint32_t const version)
Definition: sample.h:180
Id of next build order.
Definition: sample.h:148
Main namespace for bot-related code.
Definition: areainfo.cpp:17
Defogger-style pooled unit types.
Definition: sample.h:113
142-dim tech/upgrade vector: one bit for each upgrade/level/tech
142-dim vector of pending upgrades/techs