TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
basetypes.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include <cereal/cereal.hpp>
11 #include <fmt/format.h>
12 
13 #include <cmath>
14 #include <functional>
15 #include <iostream>
16 #include <limits>
17 #include <utility>
18 
19 namespace cherrypi {
20 
/// Integer identifier types used throughout the bot.
typedef int PlayerId;
typedef int FrameNum;
typedef int UpcId;

/// Reserved UPC ids. NOTE(review): meaning inferred from names only (root of
/// the UPC tree / invalid / filtered out) — confirm against their users.
UpcId constexpr kRootUpcId = 0;
UpcId constexpr kInvalidUpcId = -1;
UpcId constexpr kFilteredUpcId = -2;

/// Degrees per radian (180 / pi). The value of pi is spelled out because
/// M_PI is a POSIX extension, not standard C++ (e.g. MSVC only provides it
/// when _USE_MATH_DEFINES is set); the literal is the one M_PI expands to, so
/// the resulting double is identical.
double constexpr kDegPerRad = 180 / 3.14159265358979323846;

/// Shorthands for std::numeric_limits extremes.
float constexpr kfInfty = std::numeric_limits<float>::infinity();
float constexpr kfLowest = std::numeric_limits<float>::lowest();
float constexpr kfMax = std::numeric_limits<float>::max();
float constexpr kfEpsilon = std::numeric_limits<float>::epsilon();
double constexpr kdInfty = std::numeric_limits<double>::infinity();
double constexpr kdLowest = std::numeric_limits<double>::lowest();
double constexpr kdMax = std::numeric_limits<double>::max();
double constexpr kdEpsilon = std::numeric_limits<double>::epsilon();

/// 24 * (seconds in a week) = 14515200. NOTE(review): presumably "one week of
/// frames at 24 frames per second", used as an effectively infinite duration.
int constexpr kForever = 24 * 60 * 60 * 24 * 7;
/// NOTE(review): presumably the larva spawn interval in frames — confirm.
constexpr int kLarvaFrames = 342;
39 
40 template <typename T>
41 class Vec2T {
42  public:
43  T x;
44  T y;
45 
46  constexpr Vec2T() : x(0), y(0) {}
47  constexpr Vec2T(T x, T y) : x(x), y(y) {}
48  template <typename U>
49  explicit Vec2T(U const& other) : x(other.x), y(other.y) {}
50  template <typename U>
51  explicit Vec2T(U* other) : x(other->x), y(other->y) {}
52  template <typename U, typename V>
53  explicit Vec2T(std::pair<U, V> const& other)
54  : x(other.first), y(other.second) {}
55 
56  Vec2T& operator=(Vec2T const& other) {
57  x = other.x;
58  y = other.y;
59  return *this;
60  }
61 
62  bool operator==(Vec2T const& other) const {
63  return x == other.x && y == other.y;
64  }
65  bool operator!=(Vec2T const& other) const {
66  return x != other.x || y != other.y;
67  }
68  // For use in std::map as a key, since it's ordered and keeps a compare
69  bool operator<(Vec2T<T> const& other) const {
70  return x < other.x || (x == other.x && y < other.y);
71  }
72 
73  Vec2T operator+(T scalar) const {
74  return Vec2T(x + scalar, y + scalar);
75  }
76  Vec2T operator-(T scalar) const {
77  return Vec2T(x - scalar, y - scalar);
78  }
79  Vec2T operator*(T scalar) const {
80  return Vec2T(x * scalar, y * scalar);
81  }
82  Vec2T operator/(T scalar) const {
83  return Vec2T(x / scalar, y / scalar);
84  }
85  Vec2T operator+(Vec2T const& other) const {
86  return Vec2T(x + other.x, y + other.y);
87  }
88  Vec2T operator-(Vec2T const& other) const {
89  return Vec2T(x - other.x, y - other.y);
90  }
91  Vec2T& operator+=(T scalar) {
92  x += scalar;
93  y += scalar;
94  return *this;
95  }
96  Vec2T& operator-=(T scalar) {
97  x -= scalar;
98  y -= scalar;
99  return *this;
100  }
101  Vec2T& operator*=(T scalar) {
102  x *= scalar;
103  y *= scalar;
104  return *this;
105  }
106  Vec2T& operator/=(T scalar) {
107  x /= scalar;
108  y /= scalar;
109  return *this;
110  }
111  Vec2T& operator+=(Vec2T const& other) {
112  x += other.x;
113  y += other.y;
114  return *this;
115  }
116  Vec2T& operator-=(Vec2T const& other) {
117  x -= other.x;
118  y -= other.y;
119  return *this;
120  }
121 
122  double distanceTo(Vec2T const& other) const {
123  return Vec2T(other.x - x, other.y - y).length();
124  }
125  template <typename U>
126  double distanceTo(U* other) const {
127  return Vec2T(other->x - x, other->y - y).length();
128  }
129  double length() const {
130  return std::sqrt(x * x + y * y);
131  }
133  // Warning -- using this on Vec2<int> is a bad idea
134  // because *= casts the multiplier as an int
135  double len = length();
136  *this *= len == 0 ? 1. : 1. / len;
137  return *this;
138  }
139  Vec2T& rotateDegrees(double degrees) {
140  double radians = degrees / kDegPerRad;
141  double sine = std::sin(radians), cosine = std::cos(radians);
142  double xNew = x * cosine - y * sine;
143  double yNew = x * sine + y * cosine;
144  x = xNew;
145  y = yNew;
146  return *this;
147  }
148  Vec2T project(Vec2T towards, T distance) {
149  if (distance == 0)
150  return Vec2T(this);
151  auto separation = distanceTo(towards);
152  if (separation == 0)
153  return Vec2T(this);
154  return (towards - *this) * (distance / separation) + (*this);
155  }
156 
157  static double cos(Vec2T const& a, Vec2T const& b) {
158  double denominator = a.length() * b.length();
159  return denominator == 0 ? 0 : Vec2T::dot(a, b) / denominator;
160  }
161  static T dot(Vec2T const& a, Vec2T const& b) {
162  return a.x * b.x + a.y * b.y;
163  }
164  T dot(Vec2T const& other) {
165  return dot(*this, other);
166  }
167  static T cross(Vec2T const& a, Vec2T const& b) {
168  return (a.x * b.y) - (a.y * b.x);
169  }
170 
171  template <class Archive>
172  void serialize(Archive& ar) {
173  ar(CEREAL_NVP(x), CEREAL_NVP(y));
174  }
175 };
176 
179 constexpr Position kInvalidPosition{-1, -1};
180 
181 template <typename T>
182 class Rect2T {
183  public:
184  T x;
185  T y;
186  T w;
187  T h;
188 
189  Rect2T() : x(0), y(0), w(0), h(0) {}
190  Rect2T(T x, T y, T width, T height) : x(x), y(y), w(width), h(height) {}
191  Rect2T(Vec2T<T> const& topLeft, Vec2T<T> const& bottomRight)
192  : x(topLeft.x),
193  y(topLeft.y),
194  w(bottomRight.x - topLeft.x),
195  h(bottomRight.y - topLeft.y) {}
196  Rect2T(Vec2T<T> const& topLeft, T width, T height)
197  : x(topLeft.x), y(topLeft.y), w(width), h(height) {}
198  template <typename T2>
199  Rect2T(Rect2T<T2> const& r) : x(r.x), y(r.y), w(r.w), h(r.h) {}
200  template <typename T2>
202  x = r.x;
203  y = r.y;
204  w = r.w;
205  h = r.h;
206  }
207 
208  bool operator==(Rect2T<T> const& r) const {
209  return x == r.x && y == r.y && w == r.w && h == r.w;
210  }
211 
212  T left() const {
213  return x;
214  }
215  T right() const {
216  return x + w;
217  }
218  T top() const {
219  return y;
220  }
221  T bottom() const {
222  return y + h;
223  }
224  T width() const {
225  return w;
226  }
227  T height() const {
228  return h;
229  }
230 
231  Vec2T<T> center() const {
232  return Vec2T<T>(x + w / 2, y + h / 2);
233  }
234  static Rect2T<T> centeredWithSize(Vec2T<T> const& center, T width, T height) {
235  Rect2T<T> t;
236  t.x = center.x - width / 2;
237  t.y = center.y - height / 2;
238  t.w = width;
239  t.h = height;
240  return t;
241  }
242 
243  bool null() const {
244  return w == T(0) && h == T(0);
245  }
246  bool empty() const {
247  return w <= T(0) && h <= T(0);
248  }
249 
251  if (empty()) {
252  return r;
253  }
254  if (r.empty()) {
255  return *this;
256  }
257 
258  Rect2T<T> t;
259  t.x = std::min(x, r.x);
260  t.y = std::min(y, r.y);
261  t.w = std::max(x + w, r.x + r.w) - t.x;
262  t.h = std::max(y + h, r.y + r.h) - t.y;
263  return t;
264  }
265 
266  Rect2T<T> intersected(Rect2T<T> const& r) const {
267  if (empty() || r.empty()) {
268  return Rect2T<T>();
269  }
270 
271  T left1 = x;
272  T right1 = x + w;
273  T left2 = r.x;
274  T right2 = r.x + r.w;
275 
276  if (left1 >= right2 || left2 >= right1) {
277  // No intersection
278  return Rect2T<T>();
279  }
280 
281  T top1 = y;
282  T bottom1 = y + h;
283  T top2 = r.y;
284  T bottom2 = r.y + r.h;
285 
286  if (top1 >= bottom2 || top2 >= bottom1) {
287  // No intersection
288  return Rect2T<T>();
289  }
290 
291  Rect2T<T> t;
292  t.x = std::max(left1, left2);
293  t.y = std::max(top1, top2);
294  t.w = std::min(right1, right2) - t.x;
295  t.h = std::min(bottom1, bottom2) - t.y;
296  return t;
297  }
298 
299  bool contains(const Vec2T<T>& pt) const {
300  return pt.x >= left() && pt.x < right() && pt.y >= top() && pt.y < bottom();
301  }
302 
303  template <class Archive>
304  void serialize(Archive& ar) {
305  ar(CEREAL_NVP(x), CEREAL_NVP(y), CEREAL_NVP(w), CEREAL_NVP(h));
306  }
307 };
308 
310 
/**
 * Abstract "meta" commands for UPCTuples.
 *
 * Every command except None is a distinct power of two, i.e. a single bit.
 * MAX is not a command: it is the first bit position past the commands and
 * is used by numUpcCommands().
 */
enum Command : uint64_t {
  // clang-format off
  None        = 0,
  Create      = 1 << 0,
  Move        = 1 << 1,
  Delete      = 1 << 2, // Kill things (ie. attack)
  Gather      = 1 << 3,
  Scout       = 1 << 4,
  Cancel      = 1 << 5,
  Harass      = 1 << 6,
  Flee        = 1 << 7,
  // NOTE(review): no enumerator for 1 << 8 appears in this listing; confirm
  // against upstream whether one is missing.
  ReturnCargo = 1 << 9,
  Cast        = 1 << 10,
  MAX         = 1 << 11,
  // clang-format on
};

/// Number of bit positions below Command::MAX, i.e. log2(MAX).
/// Does not count the "None" command. Because it counts bit positions rather
/// than declared enumerators, an unassigned bit below MAX is still counted.
int constexpr numUpcCommands() {
  static_assert(Command::MAX > 0, "Invalid MAX command value");
#if defined(__GNUC__) || defined(__clang__)
  // Use the 64-bit builtin to match the enum's uint64_t underlying type;
  // __builtin_ctzl takes unsigned long, which is only 32 bits on LLP64
  // targets (e.g. clang on Windows).
  return __builtin_ctzll(Command::MAX);
#else
  std::underlying_type<Command>::type n = Command::MAX;
  int count = 0;
  while (n % 2 == 0) {
    count++;
    n >>= 1;
  }
  return count;
#endif
}
347 } // namespace cherrypi
348 
349 namespace fmt {
350 template <typename T>
351 struct formatter<cherrypi::Vec2T<T>> {
352  template <typename ParseContext>
353  constexpr auto parse(ParseContext& ctx) {
354  return ctx.begin();
355  }
356 
357  template <typename FormatContext>
358  auto format(const cherrypi::Vec2T<T>& p, FormatContext& ctx) {
359  return format_to(ctx.begin(), "({},{})", p.x, p.y);
360  }
361 };
362 } // namespace fmt
363 
364 namespace std {
365 template <typename T>
366 inline ostream& operator<<(ostream& strm, cherrypi::Vec2T<T> const& p) {
367  return strm << "(" << p.x << "," << p.y << ")";
368 }
369 
370 template <typename T>
371 inline ostream& operator<<(ostream& strm, cherrypi::Rect2T<T> const& r) {
372  return strm << "(" << r.x << "," << r.y << " " << r.w << "x" << r.h << ")";
373 }
374 
375 template <typename T>
376 struct hash<cherrypi::Vec2T<T>> {
377  size_t operator()(cherrypi::Vec2T<T> const& pos) const {
378  return hashT(pos.x * pos.y) ^ hashT(pos.y);
379  }
380 
381  private:
382  function<size_t(T)> hashT = hash<T>();
383 };
384 
385 template <>
386 struct hash<cherrypi::Command> {
387  using utype = std::underlying_type<cherrypi::Command>::type;
388  size_t operator()(cherrypi::Command const& cmd) const {
389  return h(static_cast<utype>(cmd));
390  }
391 
392  private:
393  hash<utype> h;
394 };
395 } // namespace std
double distanceTo(U *other) const
Definition: basetypes.h:126
Rect2T(Vec2T< T > const &topLeft, T width, T height)
Definition: basetypes.h:196
int FrameNum
Definition: basetypes.h:22
Definition: basetypes.h:325
T bottom() const
Definition: basetypes.h:221
constexpr Vec2T()
Definition: basetypes.h:46
Vec2T< T > center() const
Definition: basetypes.h:231
bool contains(const Vec2T< T > &pt) const
Definition: basetypes.h:299
T dot(Vec2T const &other)
Definition: basetypes.h:164
Command
Abstract "meta" commands for UPCTuples.
Definition: basetypes.h:314
double constexpr kdInfty
Definition: basetypes.h:33
Vec2T operator+(Vec2T const &other) const
Definition: basetypes.h:85
Vec2T & operator=(Vec2T const &other)
Definition: basetypes.h:56
T width() const
Definition: basetypes.h:224
Vec2T & operator/=(T scalar)
Definition: basetypes.h:106
T right() const
Definition: basetypes.h:215
Vec2T & rotateDegrees(double degrees)
Definition: basetypes.h:139
bool operator==(Vec2T const &other) const
Definition: basetypes.h:62
Rect2T< T > united(Rect2T< T > const &r)
Definition: basetypes.h:250
T w
Definition: basetypes.h:186
Rect2T< int > Rect
Definition: basetypes.h:309
Vec2T & operator*=(T scalar)
Definition: basetypes.h:101
Definition: basetypes.h:318
double distanceTo(Vec2T const &other) const
Definition: basetypes.h:122
static T cross(Vec2T const &a, Vec2T const &b)
Definition: basetypes.h:167
STL namespace.
Rect2T< T > & operator=(Rect2T< T2 > const &r)
Definition: basetypes.h:201
T y
Definition: basetypes.h:44
Definition: basetypes.h:322
Vec2T operator-(Vec2T const &other) const
Definition: basetypes.h:88
double constexpr kDegPerRad
Definition: basetypes.h:28
Definition: basetypes.h:320
Definition: basetypes.h:328
Vec2T & normalize()
Definition: basetypes.h:132
double length() const
Definition: basetypes.h:129
Rect2T(Rect2T< T2 > const &r)
Definition: basetypes.h:199
Rect2T(Vec2T< T > const &topLeft, Vec2T< T > const &bottomRight)
Definition: basetypes.h:191
Definition: basetypes.h:321
static Rect2T< T > centeredWithSize(Vec2T< T > const &center, T width, T height)
Definition: basetypes.h:234
UpcId constexpr kRootUpcId
Definition: basetypes.h:24
bool operator==(Rect2T< T > const &r) const
Definition: basetypes.h:208
double constexpr kdEpsilon
Definition: basetypes.h:36
Rect2T< T > intersected(Rect2T< T > const &r) const
Definition: basetypes.h:266
double constexpr kdMax
Definition: basetypes.h:35
Definition: basetypes.h:327
Vec2T operator-(T scalar) const
Definition: basetypes.h:76
T left() const
Definition: basetypes.h:212
float constexpr kfLowest
Definition: basetypes.h:30
Rect2T()
Definition: basetypes.h:189
T y
Definition: basetypes.h:185
bool null() const
Definition: basetypes.h:243
T top() const
Definition: basetypes.h:218
Definition: basetypes.h:317
void serialize(Archive &ar)
Definition: basetypes.h:304
Definition: basetypes.h:324
std::underlying_type< cherrypi::Command >::type utype
Definition: basetypes.h:387
double constexpr kdLowest
Definition: basetypes.h:34
Vec2T & operator-=(Vec2T const &other)
Definition: basetypes.h:116
Vec2T project(Vec2T towards, T distance)
Definition: basetypes.h:148
constexpr int kLarvaFrames
Definition: basetypes.h:38
T x
Definition: basetypes.h:184
Vec2T & operator+=(Vec2T const &other)
Definition: basetypes.h:111
Definition: basetypes.h:319
constexpr Vec2T(T x, T y)
Definition: basetypes.h:47
constexpr Position kInvalidPosition
Definition: basetypes.h:179
Vec2T & operator-=(T scalar)
Definition: basetypes.h:96
float constexpr kfMax
Definition: basetypes.h:31
float constexpr kfInfty
Definition: basetypes.h:29
Definition: basetypes.h:316
Vec2T(std::pair< U, V > const &other)
Definition: basetypes.h:53
Definition: basetypes.h:182
Vec2T(U const &other)
Definition: basetypes.h:49
int constexpr kForever
Definition: basetypes.h:37
bool empty() const
Definition: basetypes.h:246
Definition: basetypes.h:326
Vec2T operator+(T scalar) const
Definition: basetypes.h:73
int PlayerId
Definition: basetypes.h:21
size_t operator()(cherrypi::Command const &cmd) const
Definition: basetypes.h:388
T h
Definition: basetypes.h:187
Definition: basetypes.h:323
T x
Definition: basetypes.h:43
Main namespace for bot-related code.
Definition: areainfo.cpp:17
Definition: basetypes.h:349
constexpr auto parse(ParseContext &ctx)
Definition: basetypes.h:353
float constexpr kfEpsilon
Definition: basetypes.h:32
Vec2T< float > Vec2
Definition: basetypes.h:177
Definition: basetypes.h:41
static double cos(Vec2T const &a, Vec2T const &b)
Definition: basetypes.h:157
void serialize(Archive &ar)
Definition: basetypes.h:172
int UpcId
Definition: basetypes.h:23
Rect2T(T x, T y, T width, T height)
Definition: basetypes.h:190
bool operator!=(Vec2T const &other) const
Definition: basetypes.h:65
Vec2T(U *other)
Definition: basetypes.h:51
Vec2T & operator+=(T scalar)
Definition: basetypes.h:91
T height() const
Definition: basetypes.h:227
Vec2T operator/(T scalar) const
Definition: basetypes.h:82
Vec2T operator*(T scalar) const
Definition: basetypes.h:79
int constexpr numUpcCommands()
Does not count the "None" command.
Definition: basetypes.h:333
auto format(const cherrypi::Vec2T< T > &p, FormatContext &ctx)
Definition: basetypes.h:358
size_t operator()(cherrypi::Vec2T< T > const &pos) const
Definition: basetypes.h:377
UpcId constexpr kFilteredUpcId
Definition: basetypes.h:26
static T dot(Vec2T const &a, Vec2T const &b)
Definition: basetypes.h:161
Vec2T< int > Position
Definition: basetypes.h:178
UpcId constexpr kInvalidUpcId
Definition: basetypes.h:25