TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
buildingplacer.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "cherrypi.h"
11 #include "features/features.h"
12 #include "features/unitsfeatures.h"
13 
14 #include <autogradpp/autograd.h>
15 
16 namespace cherrypi {
17 
18 class State;
19 struct UPCTuple;
20 
21 /**
22  * Describes a sample that can be used to learn the BuildingPlacerModel.
23  */
25  using UnitType = int; // tc::BW::UnitType is a bit too much hassle
26  static int constexpr kMapSize = 512; // walk tiles
27  static int kNumMapChannels;
28 
29  /// Game-dependent input features
30  struct StaticData {
31  StaticData(State* state);
33  };
34 
35  /// State-dependent input features
36  struct {
37  /// Various map features (plus UPC probabilities), build tile resolution
39  /// Unit type IDs that are present
41  /// Requested building type
42  UnitType type = (+tc::BW::UnitType::MAX)._to_integral();
43  /// Float tensor that contains all valid build locations wrt the input UPC
44  /// (1 = valid, 0 = invalid). This is intended to be used as a mask for the
45  /// model output.
46  torch::Tensor validLocations;
47  } features;
48 
50 
51  /// Frame number of this sample
53 
54  /// Map name (file name for replays); optional
55  std::string mapName;
56  /// Player name; optional
57  std::string playerName;
58  /// Area ID; optional, for easier baseline computations
59  int areaId = -1;
60 
61  /// Model target output: a single position (in walk tiles)
63 
64  BuildingPlacerSample() = default;
65  virtual ~BuildingPlacerSample() = default;
66 
67  /// Constructs a new sample with features extracted from the given state
69  State* state,
70  std::shared_ptr<UPCTuple> upc,
71  StaticData* staticData = nullptr);
72 
73  /// Constructs a new sample with features extracted from the given state,
74  /// along with a target action.
76  State* state,
77  Position action,
78  std::shared_ptr<UPCTuple> upc);
79 
80  /// Assemble network input
81  std::vector<torch::Tensor> networkInput() const;
82 
83  /// Maps an action (position) in walktiles to offset in flattened output or
84  /// target tensor.
85  /// "scale" will be accounted for in addition to the scale of the extracted
86  /// features.
87  int64_t actionToOffset(Position pos, int scale = 1) const;
88 
89  /// Maps offset in flattened output or target tensor to a walktile position.
90  /// "scale" will be accounted for in addtion to the scale of the extracted
91  /// "scale" will be accounted for in addition to the scale of the extracted
92  Position offsetToAction(int64_t offset, int scale = 1) const;
93 
94  template <class Archive>
95  void serialize(Archive& ar, uint32_t const version) {
96  ar(CEREAL_NVP(features.map),
97  CEREAL_NVP(features.units),
98  CEREAL_NVP(features.type),
99  CEREAL_NVP(features.validLocations),
100  CEREAL_NVP(frame),
101  CEREAL_NVP(mapName),
102  CEREAL_NVP(playerName),
103  CEREAL_NVP(areaId),
104  CEREAL_NVP(action));
105  }
106 };
107 
108 /**
109  * A CNN model for determining building positions.
110  *
111  * This is a relatively simple feature pyramid model. The input is 128x128
112  * (build tile resolution); after two conv layers + max pooling we are at 32x32.
113  * Then, use a series of conv layers at that scale (4 by default). Afterwards,
114  * convolutions and upsampling to go back to 128x128. There are skip connections
115  * from the first two conv layers to the deconv layers.
116  *
117  * The following properties will alter the model output:
118  * - `masked`: Softmax masking to eliminiate zero-probability positions from
119  * input UPC. If this is set to false, forward() will return return an
120  * all-ones mask.
121  * - `flatten`: Output flat tensors instead of 2-dimensional ones
122  * - `logprobs`: Output log-probabilities instead of probabilities
123  *
124  * The following properties will alter the model structure:
125  * - `num_top_channels`: The number of channels in the top-level
126  * (lowest-resolution) convolutions.
127  * - `num_top_convs`: The number of convolutional layers at the top level.
128  */
129 AUTOGRAD_CONTAINER_CLASS(BuildingPlacerModel) {
130  public:
131  TORCH_ARG(bool, masked) = false;
132  TORCH_ARG(bool, flatten) = true;
133  TORCH_ARG(bool, logprobs) = false;
134  TORCH_ARG(int, num_top_channels) = 64;
135  TORCH_ARG(int, num_top_convs) = 4;
136 
137  void reset() override;
138 
139  /// Build network input from a batch of samples.
140  ag::Variant makeInputBatch(
141  std::vector<BuildingPlacerSample> const& samples, torch::Device) const;
142  ag::Variant makeInputBatch(std::vector<BuildingPlacerSample> const& samples)
143  const;
144 
145  /**
146  * Build network input and target from a batch of samples.
147  *
148  * The first element of the resulting pair is the network input, the second
149  * element are the targets for this batch.
150  */
151  std::pair<ag::Variant, ag::Variant> makeBatch(
152  std::vector<BuildingPlacerSample> const& samples, torch::Device) const;
153  std::pair<ag::Variant, ag::Variant> makeBatch(
154  std::vector<BuildingPlacerSample> const& samples) const;
155 
156  /**
157  * Network forward.
158  *
159  * Expected input, with batch dimension as first dimension:
160  * - `maps`: map features
161  * - `units_pos`: 2D coordinates for entries in `units_data``:w
162  * - `units_data`: unit type IDs
163  * - `type`: requested building type
164  * - `valid_mask`: mask set to 1 at valid build locations, 0 otherwise
165  * Use makeBatch() to generate inputs from one or more samples.
166  *
167  * Output (batched):
168  * - `output`: probability distribution over the whole map
169  * - `mask`: effective mask that was applied to the output
170  */
171  ag::Variant forward(ag::Variant input) override;
172 
173  protected:
174  ag::Container embedU;
175  ag::Container embedT;
176  ag::Container conv1;
177  ag::Container conv2;
178  ag::Container conv3;
179  std::vector<ag::Container> convS;
180  ag::Container dconv2;
181  ag::Container skip2;
182  ag::Container postskip2;
183  ag::Container dconv1;
184  ag::Container skip1;
185  ag::Container postskip1;
186  ag::Container out;
187 };
188 
189 } // namespace cherrypi
190 
// Registers version 1 of BuildingPlacerSample's cereal serialization format.
// cereal passes this number to BuildingPlacerSample::serialize() as its
// `version` argument. If the serialized field list changes, bump it and branch
// on `version` in serialize() to keep older archives loadable.
CEREAL_CLASS_VERSION(cherrypi::BuildingPlacerSample, 1);
Game state.
Definition: state.h:42
torch::Tensor validLocations
Float tensor that contains all valid build locations wrt the input UPC (1 = valid, 0 = invalid).
Definition: buildingplacer.h:46
int FrameNum
Definition: basetypes.h:22
FrameNum frame
Frame number of this sample.
Definition: buildingplacer.h:52
int areaId
Area ID; optional, for easier baseline computations.
Definition: buildingplacer.h:59
Position offsetToAction(int64_t offset, int scale=1) const
Maps offset in flattened output or target tensor to a walktile position.
Definition: buildingplacer.cpp:188
Position action
Model target output: a single position (in walk tiles)
Definition: buildingplacer.h:62
FeatureData map
Various map features (plus UPC probabilities), build tile resolution.
Definition: buildingplacer.h:38
FeatureData smap
Definition: buildingplacer.h:32
void serialize(Archive &ar, uint32_t const version)
Definition: buildingplacer.h:95
struct cherrypi::BuildingPlacerSample::@0 features
State-dependent input features.
Game-dependent input features.
Definition: buildingplacer.h:30
virtual ~BuildingPlacerSample()=default
static int kNumMapChannels
Definition: buildingplacer.h:27
StaticData(State *state)
Definition: buildingplacer.cpp:70
std::vector< torch::Tensor > networkInput() const
Assemble network input.
Definition: buildingplacer.cpp:168
Sparse featurizer for numeric unit types.
Definition: unitsfeatures.h:116
Represents a collection of spatial feature data.
Definition: features.h:190
static int constexpr kMapSize
Definition: buildingplacer.h:26
UnitTypeFeaturizer::Data units
Unit type IDs that are present.
Definition: buildingplacer.h:40
Describes a sample that can be used to learn the BuildingPlacerModel.
Definition: buildingplacer.h:24
std::string playerName
Player name; optional.
Definition: buildingplacer.h:57
AUTOGRAD_CONTAINER_CLASS(BuildingPlacerModel)
A CNN model for determining building positions.
Definition: buildingplacer.h:129
Definition: unitsfeatures.h:33
int UnitType
Definition: buildingplacer.h:25
int64_t actionToOffset(Position pos, int scale=1) const
Maps an action (position) in walktiles to offset in flattened output or target tensor.
Definition: buildingplacer.cpp:180
std::string mapName
Map name (file name for replays); optional.
Definition: buildingplacer.h:55
Main namespace for bot-related code.
Definition: areainfo.cpp:17
UnitTypeFeaturizer unitFeaturizer
Definition: buildingplacer.h:49
UnitType type
Requested building type.
Definition: buildingplacer.h:42