TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
features.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "cherrypi.h"
11 #include "unitsinfo.h"
12 
13 #include <autogradpp/autograd.h>
14 #include <cereal/cereal.hpp>
15 #include <glog/logging.h>
16 #include <mapbox/variant.hpp>
17 
18 #include <deque>
19 #include <functional>
20 #include <vector>
21 
22 namespace cherrypi {
23 
24 class State;
25 
26 namespace features {
27 void initialize(); // Defined in features.cpp; presumably one-time module setup (TODO confirm)
28 } // namespace features
29 
30 /**
31  * Defines a family of "plain" features.
32  *
33  * These features can directly be extracted from the bot State into a spatial
34  * FeatureData instance. Use featurizePlain() to extract these features.
35  */
36 enum class PlainFeatureType {
37  Invalid = -1,
38  /// Ground height: 0 (low), 1 (high) or 2 (very high); -1 outside of map.
39  GroundHeight = 1,
40  /// Whether units can walk here or not: 0 or 1; -1 outside of map
42  /// Whether buildings can be placed here or not: 0 or 1; -1 outside of map
44  /// Whether this position is under the fog of war: 0 or 1; -1 outside of map
45  FogOfWar,
46  /// Whether there is creep here: 0 or 1; -1 outside of map
47  Creep,
48  /// Whether the enemy starts from this position: 0 or 1; -1 outside of map
50  /// Whether the corresponding buildtile is reserved
52  /// Whether this walktile contains a doodad that alters the ground height and
53  /// thus affects visibility and attack miss rates.
54  TallDoodad,
55  /// One-hot ground height: one channel each for heights 0, 2 and 4, plus one
56  /// channel marking positions on the map (4 channels total)
58  /// Whether this position is a starting location
60  /// Grid of X/Y coordinates from (0,0) top left to (N,M) bottom right. One
61  /// channel for Y, one channel for X. -1 outside of map. N is map_width/512, M
62  /// map_height/512 (all in walktiles).
63  XYGrid,
64  /// 1 if there is a resource tile at this location, 0 otherwise
65  Resources,
66  /// This map tile has a structure on it, so it's not passable.
67  /// Since this works at the walktile level and structures are on pixels,
68  /// it will mark a walktile as impassable as long as the walktile is at
69  /// all partially impassable.
71 
72  /// User-defined single-channel feature
73  UserFeature1 = 1001,
74  /// User-defined two-channel feature
75  UserFeature2 = 1002,
76 };
77 
78 /**
79  * Defines custom features.
80  *
81  * These features are extracted using various custom feature extractors. They're
82  * defined explicitly so that they can be referred to easily in feature
83  * descriptors.
84  *
85  * Use this enum as a central "registry" for your feature type. Values start at 10001.
86  */
87 enum class CustomFeatureType {
88  UnitPresence = 10001,
89  UnitType,
90  UnitFlags,
91  UnitHP,
92  UnitShield,
94  UnitAirCD,
95  UnitStat,
97  UnitTypeMDefogger, // featurizes morphing units with target type
98 
99  Other = 1 << 30, // sentinel far above all other values; the former `1 >> 30` evaluated to 0, so older archives store Other as 0
100 };
101 
102 using AnyFeatureType =
103  mapbox::util::variant<PlainFeatureType, CustomFeatureType>; ///< Either a plain or a custom feature type
104 
105 /**
106  * Describes a specific feature within FeatureData: its type, name and number of channels.
107  */
110  std::string name;
112 
113  private:
114  static int constexpr kPlainType = 0;
115  static int constexpr kCustomType = 1;
116 
117  public:
119  FeatureDescriptor(PlainFeatureType type, std::string name, int numChannels)
120  : type(type), name(std::move(name)), numChannels(numChannels) {}
121  FeatureDescriptor(CustomFeatureType type, std::string name, int numChannels)
122  : type(type), name(std::move(name)), numChannels(numChannels) {}
124  : type(other.type), name(other.name), numChannels(other.numChannels) {}
126  : type(std::move(other.type)),
127  name(std::move(other.name)),
128  numChannels(other.numChannels) {}
129  ~FeatureDescriptor() = default;
131  type = other.type;
132  name = other.name;
133  numChannels = other.numChannels;
134  return *this;
135  }
137  type = std::move(other.type);
138  name = std::move(other.name);
139  numChannels = other.numChannels;
140  return *this;
141  }
142 
143  // Mostly for tests...
144  bool operator==(FeatureDescriptor const& other) const;
145 
146  template <typename Archive>
147  void save(Archive& ar) const {
148  int kind = -1; // becomes kPlainType or kCustomType depending on the variant's active type
149  int value = -1; // the enumerator's underlying integer value
150  type.match(
151  [&](PlainFeatureType t) {
152  kind = kPlainType;
153  value = static_cast<int>(t);
154  },
155  [&](CustomFeatureType t) {
156  kind = kCustomType;
157  value = static_cast<int>(t);
158  });
159  ar(CEREAL_NVP(kind),
160  CEREAL_NVP(value),
161  CEREAL_NVP(name),
162  CEREAL_NVP(numChannels));
163  }
164 
165  template <typename Archive>
166  void load(Archive& ar) {
167  int kind = -1;
168  int value = -1;
169  ar(CEREAL_NVP(kind),
170  CEREAL_NVP(value),
171  CEREAL_NVP(name),
172  CEREAL_NVP(numChannels));
173  switch (kind) { // kind picks the enum to cast to; value itself is not validated
174  case kPlainType:
175  type = static_cast<PlainFeatureType>(value);
176  break;
177  case kCustomType:
178  type = static_cast<CustomFeatureType>(value);
179  break;
180  default: // neither kPlainType nor kCustomType: corrupt or unsupported archive
181  throw std::runtime_error(
182  "Unknown feature kind: " + std::to_string(kind));
183  };
184  }
185 };
186 
187 /**
188  * Represents a collection of spatial feature data: a tensor plus descriptors of the features it holds.
189  */
190 struct FeatureData {
191  /// Tensor layout is [c][y][x] (channel, y, x)
192  torch::Tensor tensor;
193  std::vector<FeatureDescriptor> desc; ///< Descriptors of the features stored in `tensor`
194  /// Decimation factor with respect to walktile resolution
195  int scale;
196  /// [0][0] of tensor corresponds to this point (walktiles)
198 
199  /// Number of channels in tensor
200  int numChannels() const;
201  /// Bounding box in walktiles
202  Rect boundingBox() const;
203  /// Bounding box at the current scale
204  Rect boundingBoxAtScale() const;
205 
206  template <typename Archive>
207  void serialize(Archive& ar) {
208  ar(CEREAL_NVP(tensor),
209  CEREAL_NVP(desc),
210  CEREAL_NVP(scale),
211  CEREAL_NVP(offset));
212  }
213 };
214 
215 /**
216  * Spatial subsampling methods, as accepted by subsampleFeature().
217  */
218 enum class SubsampleMethod {
219  Sum,
220  Average,
221  Max,
222 };
223 
224 /**
225  * Extracts the given plain feature `types` from the current state.
226  * `boundingBox` defaults to all available data, but can also be larger, e.g. to
227  * get constant-size features irrespective of the actual map size.
228  */
230  State* state,
231  std::vector<PlainFeatureType> types,
232  Rect boundingBox = Rect());
233 
234 /**
235  * Combines multiple features by concatenating them along the channel dimension.
236  * Ensures they all have the same scale and performs zero-padding according to
237  * feature offsets.
238  */
239 FeatureData combineFeatures(std::vector<FeatureData> const& feats);
240 
241 /**
242  * Selects a subset of features.
243  * Assumes that `types` is a subset of the feature types present in `feat`;
244  * otherwise, an exception is thrown. The result follows the order of `types`.
245  */
247  FeatureData const& feat,
248  std::vector<AnyFeatureType> types);
249 
250 /**
251  * Applies a spatial subsampling method to a feature.
252  * The scale of the resulting feature is the original scale times `factor`.
253  * `stride` defaults to -1 (presumably meaning stride == factor; TODO confirm).
254  */
256  FeatureData const& feat,
257  SubsampleMethod method,
258  int factor,
259  int stride = -1);
260 
261 /**
262  * Maps walktile positions to feature positions for a given bounding box.
263  *
264  * This is mostly useful for actual featurizer implementations. Call the mapper
265  * on a position to map it. For invalid positions (outside of the intersection of
266  * bounding box and map rectangle), `kInvalidPosition` is returned (expected to be `(-1, -1)`).
267  */
269  FeaturePositionMapper(Rect const& boundingBox, Rect const& mapRect) {
270  Rect ir = boundingBox.intersected(mapRect);
271  irx1 = ir.left(); // inclusive bounds of the valid region
272  iry1 = ir.top();
273  irx2 = ir.right() - 1; // right()/bottom() are treated as exclusive here
274  iry2 = ir.bottom() - 1;
275  offx = mapRect.x - boundingBox.x; // NOTE(review): mapped x is pos.x - boundingBox.x + mapRect.x; equals pos.x - boundingBox.x only if mapRect starts at 0 — confirm
276  offy = mapRect.y - boundingBox.y;
277  }
278  int irx1;
279  int irx2;
280  int iry1;
281  int iry2;
282  int offx;
283  int offy;
284 
285  Position operator()(Position const& pos) const {
286  if (pos.x < irx1 || pos.y < iry1 || pos.x > irx2 || pos.y > iry2) {
287  return kInvalidPosition;
288  }
289  return Position(pos.x + offx, pos.y + offy);
290  }
291 };
292 
293 } // namespace cherrypi
Game state.
Definition: state.h:42
void initialize()
Definition: features.cpp:67
T bottom() const
Definition: basetypes.h:221
FeatureData subsampleFeature(FeatureData const &feat, SubsampleMethod method, int factor, int stride)
Applies a spatial subsampling method to a feature.
Definition: features.cpp:259
int offx
Definition: features.h:282
FeatureDescriptor(PlainFeatureType type, std::string name, int numChannels)
Definition: features.h:119
FeatureData featurizePlain(State *state, std::vector< PlainFeatureType > types, Rect boundingBox)
Extracts plain features from the current state.
Definition: features.cpp:131
FeatureDescriptor()
Definition: features.h:118
T right() const
Definition: basetypes.h:215
FeatureDescriptor(CustomFeatureType type, std::string name, int numChannels)
Definition: features.h:121
mapbox::util::variant< PlainFeatureType, CustomFeatureType > AnyFeatureType
Definition: features.h:103
SubsampleMethod
Various methods for spatial subsampling.
Definition: features.h:218
Whether units can walk here or not: 0 or 1; -1 outside of map.
One-hot ground height: one channel each for heights 0, 2 and 4, plus one channel marking positions on the map (4 channels total)
Whether the corresponding buildtile is reserved.
Grid of X/Y coordinates from (0,0) top left to (N,M) bottom right.
Rect2T< int > Rect
Definition: basetypes.h:309
int scale
Decimation factor wrt walktile resolution.
Definition: features.h:195
FeatureDescriptor(FeatureDescriptor const &other)
Definition: features.h:123
STL namespace.
Whether the enemy starts from this position: 0 or 1; -1 outside of map.
T y
Definition: basetypes.h:44
std::vector< FeatureDescriptor > desc
Definition: features.h:193
FeatureData combineFeatures(std::vector< FeatureData > const &feats)
Combines multiple features along channels.
Definition: features.cpp:177
int irx2
Definition: features.h:279
int irx1
Definition: features.h:278
void load(Archive &ar)
Definition: features.h:166
torch::Tensor tensor
Format is [c][y][x].
Definition: features.h:192
User-defined two-channel feature.
Whether buildings can be placed here or not: 0 or 1; -1 outside of map.
Rect2T< T > intersected(Rect2T< T > const &r) const
Definition: basetypes.h:266
Whether there is creep here: 0 or 1; -1 outside of map.
T left() const
Definition: basetypes.h:212
Ground height: 0 (low), 1 (high) or 2 (very high); -1 outside of map.
T y
Definition: basetypes.h:185
CustomFeatureType
Defines custom features.
Definition: features.h:87
int numChannels
Definition: features.h:111
AnyFeatureType type
Definition: features.h:109
Represents a collection of spatial feature data.
Definition: features.h:190
void serialize(Archive &ar)
Definition: features.h:207
Whether this walktile contains a doodad that alters the ground height and thus affects visibility and...
PlainFeatureType
Defines a family of "plain" features.
Definition: features.h:36
T top() const
Definition: basetypes.h:218
int offy
Definition: features.h:283
Describes a specific feature within FeatureData.
Definition: features.h:108
int iry1
Definition: features.h:280
T x
Definition: basetypes.h:184
constexpr Position kInvalidPosition
Definition: basetypes.h:179
Position offset
[0][0] of tensor corresponds to this point (walktiles)
Definition: features.h:197
Calculates which tiles should be revealed by a unit's vision.
Definition: fogofwar.h:23
FeatureData selectFeatures(FeatureData const &feat, std::vector< AnyFeatureType > types)
Selects a subset of features.
Definition: features.cpp:221
std::string name
Definition: features.h:110
Position operator()(Position const &pos) const
Definition: features.h:285
void save(Archive &ar) const
Definition: features.h:147
int iry2
Definition: features.h:281
This map tile has a structure on it, so it's not passable.
FeaturePositionMapper(Rect const &boundingBox, Rect const &mapRect)
Definition: features.h:269
FeatureDescriptor & operator=(FeatureDescriptor const &other)
Definition: features.h:130
FeatureDescriptor(FeatureDescriptor &&other)
Definition: features.h:125
Maps walktile positions to feature positions for a given bounding box.
Definition: features.h:268
T x
Definition: basetypes.h:43
Main namespace for bot-related code.
Definition: areainfo.cpp:17
Whether this position is a starting location.
User-defined single-channel feature.
Vec2T< int > Position
Definition: basetypes.h:178
FeatureDescriptor & operator=(FeatureDescriptor &&other)
Definition: features.h:136