11 #include "features/features.h" 12 #include "features/unitsfeatures.h" 13 #include "modules/autobuild.h" 16 #include <cpid/trainer.h> 19 #include <autogradpp/autograd.h> 23 #include <cereal/archives/binary.hpp> 24 #include <cereal/types/vector.hpp> 27 template <
class Archive>
// Table mapping a build-order name to its integer ID. buildOrderId() looks
// names up here and throws std::runtime_error("Unknown build order: ...")
// when the name is absent.
43 std::map<std::string, int64_t>
const& buildOrderMap();
// NOTE(review): presumably the list of build-order names the model may
// target -- confirm against the definition in sample.cpp.
44 std::vector<std::string>
const& targetBuilds();
// NOTE(review): these return the allowed targets / openings packed into a
// single string, presumably for use as a command-line flag value -- confirm
// the exact format in sample.cpp.
45 std::string allowedTargetsAsFlag();
46 std::string allowedOpeningsAsFlag();
49 std::regex race_regex(
"_([PZT]{1})_");
50 std::smatch race_match;
51 if (!std::regex_search(opponent, race_match, race_regex)) {
52 throw std::runtime_error(
"Opponent string is not correct!");
54 return race_match[0].str()[1];
58 return std::string(1, prefix) +
"-" + buildOrder;
63 std::move(buildOrder),
64 tc::BW::Race::_from_integral(race)._to_string()[0]);
68 return prefixedBo.substr(2);
72 auto const& boMap = buildOrderMap();
73 auto it = boMap.find(bo);
74 if (it == boMap.end()) {
75 throw std::runtime_error(
"Unknown build order: " + bo);
// Returns a tensor for the given race; presumably a mask over build orders
// that are usable by that race (TODO confirm in sample.cpp).
// NOTE(review): this declaration takes `int race`, but the generated
// documentation lists the definition (sample.cpp:175) as taking `char race`.
// Verify that the two signatures agree and are not silently distinct overloads.
81 torch::Tensor getBuildOrderMaskByRace(
int race);
114 static int constexpr kMapSize = 512;
115 static int constexpr kNumMapChannels = 4;
121 float switchProba = 0.0f;
134 template <
class Archive>
137 CEREAL_NVP(switchProba),
140 CEREAL_NVP(opponentName),
143 ar(CEREAL_NVP(gameId));
161 bool switched =
false;
163 uint64_t pendingUpgrades = 0;
165 uint64_t pendingUpgradesLevel = 0;
167 uint64_t pendingTechs = 0;
176 std::shared_ptr<StaticData> sd =
nullptr);
177 virtual ~
Sample() =
default;
179 template <
class Archive>
182 throw std::runtime_error(
"Unsupported version");
186 ar(CEREAL_NVP(staticData),
189 CEREAL_NVP(resources),
190 CEREAL_NVP(buildOrder),
191 CEREAL_NVP(nextBuildOrder),
192 CEREAL_NVP(switched),
193 CEREAL_NVP(pendingUpgrades),
194 CEREAL_NVP(pendingUpgradesLevel),
195 CEREAL_NVP(pendingTechs),
196 CEREAL_NVP(nextAbboStates));
203 torch::Tensor featurize(
205 torch::Tensor buffer = torch::Tensor())
const;
206 ag::tensor_list featurize(std::vector<BosFeature> features)
const;
207 void renormV2Features();
209 static std::map<int, autobuild::BuildState> simulateAbbo(
211 std::string
const& buildOrder,
212 std::vector<int>
const& frameOffsets);
219 ReplayBufferFrame() =
default;
220 ReplayBufferFrame(
Sample sample) : sample(std::move(sample)) {}
221 virtual ~ReplayBufferFrame() =
default;
223 template <
class Archive>
224 void serialize(Archive& ar, uint32_t
const version) {
225 ar(cereal::base_class<cpid::ReplayBufferFrame>(
this), sample);
232 std::vector<std::shared_ptr<cpid::ReplayBufferFrame>> frames;
234 template <
class Archive>
238 uint32_t size = frames.size();
241 for (
auto i = 0U; i < size; i++) {
// Short aliases for the bos:: replay-buffer types; the cereal macros below
// refer to the frame type through the cherrypi::BosReplayBufferFrame alias.
256 typedef bos::ReplayBufferFrame BosReplayBufferFrame;
257 typedef bos::EpisodeData BosEpisodeData;
// Register the frame type with cereal. Frames are stored and serialized
// through std::shared_ptr<cpid::ReplayBufferFrame> in EpisodeData::frames.
// The derived type therefore has to be registered so cereal can restore it
// from a base-class pointer. The class serialization version is pinned to 0.
265 CEREAL_REGISTER_TYPE(cherrypi::BosReplayBufferFrame);
266 CEREAL_CLASS_VERSION(cherrypi::BosReplayBufferFrame, 0);
Map "ID" based on sum of map features.
Game state.
Definition: state.h:42
int FrameNum
Definition: basetypes.h:22
std::string GameUID
Definition: trainer.h:31
uint64_t upgrades_level
Definition: frame.h:172
int32_t used_psi
Definition: frame.h:169
std::string stripRacePrefix(std::string prefixedBo)
Definition: sample.h:67
std::string gameId
Game Id (optional)
Definition: sample.h:129
FeatureData units
Defogger style unit types in spatial representation.
Definition: sample.h:151
int32_t gas
Definition: frame.h:168
BosFeature
A list of possible features that can be extracted from a Sample.
Definition: sample.h:84
uint64_t techs
Definition: frame.h:173
std::string nextBuildOrder
Build order until next sample.
Definition: sample.h:159
void serialize(Archive &archive, torchcraft::Resources &resource)
Definition: sample.h:28
int64_t buildOrderId(std::string const &bo)
Definition: sample.h:71
std::string buildOrder
Current build order.
Definition: sample.h:157
char getOpponentRace(std::string const &opponent)
Definition: sample.h:48
torch::Tensor getBuildOrderMaskByRace(char race)
Definition: sample.cpp:175
Ore, gas, used psi and total psi, each scaled as log(x / 5 + 1).
int32_t ore
Definition: frame.h:167
FrameNum frame
Frame number of this sample.
Definition: sample.h:153
void serialize(Archive &ar, uint32_t const version)
Definition: sample.h:135
std::shared_ptr< StaticData > staticData
Definition: sample.h:149
std::string opponentName
Player name of opponent.
Definition: sample.h:125
Map features from StaticData.
Represents a collection of spatial feature data.
Definition: features.h:190
std::string addRacePrefix(std::string buildOrder, char prefix)
Definition: sample.h:57
2-dimensional: our race and the opponent's race.
std::string EpisodeKey
Definition: trainer.h:32
tc::Resources resources
Our resources.
Definition: sample.h:155
FeatureData map
Various map features.
Definition: sample.h:118
Id of active build order.
std::map< int, autobuild::BuildState > nextAbboStates
Future autobuild states for given frame offsets.
Definition: sample.h:169
uint64_t upgrades
Definition: frame.h:171
Stub base class for replay buffer frames.
Definition: trainer.h:69
Bag-of-words unit type counts.
int32_t total_psi
Definition: frame.h:170
Bag-of-words unit type counts in future autobuild states (ours only).
void serialize(Archive &ar, uint32_t const version)
Definition: sample.h:180
Main namespace for bot-related code.
Definition: areainfo.cpp:17
Defogger-style pooled unit types.
142-dimensional tech/upgrade vector with one bit for each upgrade, upgrade level, or tech.
142-dimensional vector of pending upgrades and techs.