TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
unitsfeatures.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "features/features.h"
11 #include "features/jitter.h"
12 
13 namespace cherrypi {
14 
15 /**
16  * Abstract base class for featurizing unit attributes in a sparse manner.
17  *
18  * General usage of sub-classes for actual feature extraction boils down to
19  * calling extract() with a desired subset of units to featurize. The resulting
20  * data is sparse wrt positions, i.e. it contains a tensor of positions for each
21  * unit and the accompanying data as defined by a featurizer implementation.
22  *
23  * toSpatialFeature() will transform the given data to a `FeatureData` object,
24  * i.e. it will place the feature data at the respective positions.
25  *
26  * Optionally, users can set a jittering method that will be accounted for in
27  * extract().
28  */
30  using UnitFilter = std::function<bool(Unit*)>;
31  using TensorDest = torch::TensorAccessor<float, 1>;
32 
33  struct Data {
35  // TODO Undefined position and data tensors currently represent an empty set
36  // of units. Are tensors with a 0-sized dimension possible?
37  torch::Tensor positions; ///< #units X 2 (y, x)
38  torch::Tensor data; ///< #units X nchannels
39 
40  template <typename Archive>
41  void serialize(Archive& ar) {
42  ar(CEREAL_NVP(boundingBox), CEREAL_NVP(positions), CEREAL_NVP(data));
43  }
44  };
45 
46  /// Optional jittering of unit positions
47  std::shared_ptr<BaseJitter> jitter = std::make_shared<NoJitter>();
48 
49  // This information will be used for converting this featurizer's data into
50  // a spatial feature. Subclasses should customize this in their constructors.
52  std::string name;
54 
55  virtual ~UnitAttributeFeaturizer() = default;
56 
57  /// Extract unit features for a given set of units
58  virtual Data extract(
59  State* state,
60  UnitsInfo::Units const& units,
61  Rect const& boundingBox = Rect());
62  /// Extract unit features for all live units
63  Data extract(State* state, Rect const& boundingBox = Rect());
64  /// Extract unit features for all live units that pass the given filter
65  Data
66  extract(State* state, UnitFilter filter, Rect const& boundingBox = Rect());
67 
68  /// Embeds the unit attribute data into a spatial feature
70  Data const& data,
71  SubsampleMethod pooling = SubsampleMethod::Sum) const;
72 
73  /// Embeds the unit attribute data into a spatial feature.
74  /// This version will re-use the tensor memory of the given feature data
75  /// instance.
76  void toSpatialFeature(
77  FeatureData* dest,
78  Data const& data,
79  SubsampleMethod pooling = SubsampleMethod::Sum) const;
80 
81  protected:
82  /// Reimplement this in actual featurizers.
83  /// This function is expected to set acc[0], ..., acc[numChannels-1]
84  virtual void extractUnit(TensorDest acc, Unit* unit) = 0;
85 };
86 
87 /**
88  * Sparse featurizer for unit presence.
89  *
90  * This will produce a binary feature with a single channel: 0 if there is no
91  * unit, 1 if there is a unit.
92  */
96  name = "UnitPresence";
97  numChannels = 1;
98  }
99 
100  protected:
101  virtual void extractUnit(TensorDest acc, Unit*) override {
102  // Simply mark this unit as being present
103  acc[0] = 1;
104  }
105 };
106 
107 /**
108  * Sparse featurizer for numeric unit types.
109  *
110  * This will produce a single-channel feature that contains a unit type ID for
111  * each unit. Unit IDs are mutually exclusive for allied (0-232), enemy
112  * (233-465) and neutral units (466-698).
113  *
114  * The resulting sparse feature is suitable for embedding via lookup tables.
115  */
117  static int constexpr kNumUnitTypes = 233 * 3;
118 
121  name = "UnitType";
122  numChannels = 1;
123  }
124 
125  FeatureData toOneHotSpatialFeature(
126  Data const& data,
127  int unitValueOffset,
128  std::unordered_map<int, int> const& channelValues) const;
129 
130  protected:
131  virtual void extractUnit(TensorDest acc, Unit* unit) override {
132  if (unit->isMine) {
133  acc[0] = unit->type->unit + 233 * 0;
134  } else if (unit->isEnemy) {
135  acc[0] = unit->type->unit + 233 * 1;
136  } else if (unit->isNeutral) {
137  acc[0] = unit->type->unit + 233 * 2;
138  }
139  }
140 };
141 
142 /**
143  * Sparse featurizer for unit types, defogger-style.
144  *
145  * This featurizer maps unit types to 118 IDs (instead of the 234 possible IDs)
146  * and assigns valid IDs to allied and enemy units only -- neutral units will be
147  * mapped to -1.
148  *
149  * `toDefoggerFeature()` supports pooling with a given resolution and
150  * stride so that the result contains accumulated unit counts per type for each
151  * "cell". It ignores neutral units.
152  */
154  static int constexpr kNumUnitTypes = 118 * 2;
155 
157 
158  int mapType(int unitType) const {
159  return typemap_->at(unitType);
160  }
161  int unmapType(int mappedType) const {
162  return itypemap_->at(mappedType);
163  }
164 
165  FeatureData toDefoggerFeature(Data const& data, int res, int stride) const;
166 
167  protected:
168  virtual void extractUnit(TensorDest acc, Unit* unit) override {
169  if (unit->isMine) {
170  acc[0] = mapType(unit->type->unit) + 118 * 0;
171  } else if (unit->isEnemy) {
172  acc[0] = mapType(unit->type->unit) + 118 * 1;
173  } else {
174  acc[0] = -1;
175  }
176  }
177 
178  std::array<int, 234>* typemap_;
179  std::array<int, 234>* itypemap_;
180 };
181 
182 /**
183  * A variant of UnitTypeDefoggerFeaturizer that stores the target type of
184  * morphing units.
185  *
186  * Morphing Zerglings will be featurized as two units.
187  */
189  static int constexpr kNumUnitTypes = 118 * 2;
190 
192 
193  /// Extract unit features for a given set of units
194  virtual Data extract(
195  State* state,
196  UnitsInfo::Units const& units,
197  Rect const& boundingBox = Rect());
198 };
199 
200 /**
201  * Sparse featurizer for unit flags.
202  *
203  * This will produce a feature with 52 channels, where each channel corresponds
204  * to a flag of torchcraft::replayer::Unit. Each channel is binary, i.e. it's 1
205  * if the flag is set and 0 otherwise.
206  */
208  static int constexpr kNumUnitFlags = 52;
209 
212  name = "UnitFlags";
213  numChannels = kNumUnitFlags;
214  }
215 
216  protected:
217  virtual void extractUnit(TensorDest acc, Unit* unit) override {
218  for (auto flag = 0; flag < kNumUnitFlags; flag++) {
219  acc[flag] = (unit->unit.flags & (1 << flag)) ? 1 : 0;
220  }
221  }
222 };
223 
225  static constexpr int kNumChannels =
227 
230  name = "UnitStat";
231  numChannels = kNumChannels;
232  }
233 
234  protected:
235  virtual void extractUnit(TensorDest, cherrypi::Unit*) override;
236 };
237 
/**
 * Generates a sparse, single-channel featurizer named Unit<NAME>Featurizer.
 *
 * For each unit, the only channel holds the value of the torchcraft unit
 * field `ATTR` (read as `unit->unit.ATTR`). The generated featurizer reports
 * CustomFeatureType::Unit<NAME> as its type and "Unit<NAME>" as its name.
 */
#define GEN_SPARSE_UNIT_ATTRIBUTE_FEATURIZER(NAME, ATTR)        \
  struct Unit##NAME##Featurizer : UnitAttributeFeaturizer {     \
    Unit##NAME##Featurizer() {                                  \
      numChannels = 1;                                          \
      name = "Unit" #NAME;                                      \
      type = CustomFeatureType::Unit##NAME;                     \
    }                                                           \
                                                                \
   protected:                                                   \
    void extractUnit(TensorDest acc, Unit* unit) override {     \
      acc[0] = static_cast<float>(unit->unit.ATTR);             \
    }                                                           \
  };
257 
260 GEN_SPARSE_UNIT_ATTRIBUTE_FEATURIZER(GroundCD, groundCD);
262 
263 #undef GEN_SPARSE_UNIT_ATTRIBUTE_FEATURIZER
264 } // namespace cherrypi
Game state.
Definition: state.h:42
int numChannels
Definition: unitsfeatures.h:53
torch::Tensor positions
Definition: unitsfeatures.h:37
UnitTypeFeaturizer()
Definition: unitsfeatures.h:119
UnitPresenceFeaturizer()
Definition: unitsfeatures.h:94
static int constexpr kNumUnitFlags
Definition: unitsfeatures.h:208
virtual void extractUnit(TensorDest acc, Unit *unit) override
Reimplement this in actual featurizers.
Definition: unitsfeatures.h:168
Rect boundingBox
Definition: unitsfeatures.h:34
UnitStatFeaturizer()
Definition: unitsfeatures.h:228
SubsampleMethod
Various methods for spatial subsampling.
Definition: features.h:218
virtual void extractUnit(TensorDest acc, Unit *unit) override
Reimplement this in actual featurizers.
Definition: unitsfeatures.h:217
Sparse featurizer for unit presence.
Definition: unitsfeatures.h:93
Rect2T< int > Rect
Definition: basetypes.h:309
Abstract base class for featurizing unit attributes in a sparse manner.
Definition: unitsfeatures.h:29
int unmapType(int mappedType) const
Definition: unitsfeatures.h:161
bool isNeutral
Definition: unitsinfo.h:55
std::function< bool(Unit *)> UnitFilter
Definition: unitsfeatures.h:30
Definition: unitsfeatures.h:224
virtual Data extract(State *state, UnitsInfo::Units const &units, Rect const &boundingBox=Rect())
Extract unit features for a given set of units.
Definition: unitsfeatures.cpp:48
virtual void extractUnit(TensorDest acc, Unit *) override
Reimplement this in actual featurizers.
Definition: unitsfeatures.h:101
Sparse featurizer for numeric unit types.
Definition: unitsfeatures.h:116
Sparse featurizer for unit flags.
Definition: unitsfeatures.h:207
CustomFeatureType type
Definition: unitsfeatures.h:51
torch::TensorAccessor< float, 1 > TensorDest
Definition: unitsfeatures.h:31
int mapType(int unitType) const
Definition: unitsfeatures.h:158
CustomFeatureType
Defines custom features.
Definition: features.h:87
Represents a collection of spatial feature data.
Definition: features.h:190
Represents a unit in the game.
Definition: unitsinfo.h:35
bool isEnemy
Definition: unitsinfo.h:54
virtual ~UnitAttributeFeaturizer()=default
GEN_SPARSE_UNIT_ATTRIBUTE_FEATURIZER(HP, health)
std::array< int, 234 > * itypemap_
Definition: unitsfeatures.h:179
Definition: unitsfeatures.h:33
tc::Unit unit
A copy of the torchcraft unit data.
Definition: unitsinfo.h:81
FeatureData toSpatialFeature(Data const &data, SubsampleMethod pooling=SubsampleMethod::Sum) const
Embeds the unit attribute data into a spatial feature.
Definition: unitsfeatures.cpp:110
bool isMine
Definition: unitsinfo.h:53
int unit
Definition: buildtype.h:37
Main namespace for bot-related code.
Definition: areainfo.cpp:17
std::array< int, 234 > * typemap_
Definition: unitsfeatures.h:178
UnitFlagsFeaturizer()
Definition: unitsfeatures.h:210
const BuildType * type
Definition: unitsinfo.h:56
void serialize(Archive &ar)
#units X nchannels
Definition: unitsfeatures.h:41
Sparse featurizer for unit types, defogger-style.
Definition: unitsfeatures.h:153
torch::Tensor data
#units X 2 (y, x)
Definition: unitsfeatures.h:38
virtual void extractUnit(TensorDest acc, Unit *unit) override
Reimplement this in actual featurizers.
Definition: unitsfeatures.h:131
std::vector< Unit * > Units
Definition: unitsinfo.h:296
uint64_t flags
Definition: frame.h:85
std::shared_ptr< BaseJitter > jitter
Optional jittering of unit positions.
Definition: unitsfeatures.h:47
A variant of UnitTypeDefoggerFeaturizer that stores the target type of morphing units.
Definition: unitsfeatures.h:188
std::string name
Definition: unitsfeatures.h:52
virtual void extractUnit(TensorDest acc, Unit *unit)=0
Reimplement this in actual featurizers.