TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
algorithms.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include <algorithm>
11 #include <math.h>
12 #include <torchcraft/client.h>
13 #include <type_traits>
14 #include <vector>
15 
16 #include <autogradpp/autograd.h>
17 
18 #include "buildtype.h"
19 #include "cherrypi.h"
20 #include "filter.h"
21 #include "gamemechanics.h"
22 #include "state.h"
23 #include "unitsinfo.h"
24 
25 namespace cherrypi {
26 namespace utils {
27 
/// Sentinel type used as the default for the optional comparison arguments of
/// getBestScore (invalidScore / bestPossibleScore). A NoValue never compares
/// equal to anything, so passing it disables the corresponding check.
struct NoValue {};
static constexpr NoValue noValue{};

/// Returns `a == b`, except that it returns false (without evaluating any
/// comparison) whenever either argument is a NoValue.
template <typename A, typename B>
constexpr bool isEqualButNotNoValue(A&& a, B&& b) {
  return a == b;
}
template <typename A>
constexpr bool isEqualButNotNoValue(A&& /* a */, NoValue) {
  return false;
}
template <typename B>
constexpr bool isEqualButNotNoValue(NoValue, B&& /* b */) {
  return false;
}
// Without this overload, comparing two NoValues would be ambiguous between the
// two templates above. (constexpr functions are implicitly inline.)
constexpr bool isEqualButNotNoValue(NoValue, NoValue) {
  return false;
}

/// This function iterates from begin to end, passing each value to the
/// provided score function (each element is scored at most once).
/// It returns the corresponding iterator for the value whose score function
/// returned the lowest value (using operator <). On ties, the earliest element
/// wins.
/// If invalidScore is provided, then any score which compares equal to it will
/// never be returned.
/// If bestPossibleScore is provided, then any score which compares equal to it
/// (and also is the best score thus far) will cause an immediate return,
/// without iterating to the end. This includes the first valid element.
/// If the range is empty or no value can be returned (due to invalidScore),
/// then the end iterator is returned.
template <
    typename Iterator,
    typename Score,
    typename InvalidScore = NoValue,
    typename BestPossibleScore = NoValue>
auto getBestScore(
    Iterator begin,
    Iterator end,
    Score&& score,
    InvalidScore&& invalidScore = InvalidScore(),
    BestPossibleScore&& bestPossibleScore = BestPossibleScore()) {
  // Look for the first element with a valid score; it becomes the initial
  // best candidate. The outer loop only advances past invalid elements: once a
  // valid one is found, the function returns from inside the loop body.
  for (auto i = begin; i != end; ++i) {
    auto bestScore = score(*i);
    if (isEqualButNotNoValue(bestScore, invalidScore)) {
      continue;
    }
    auto best = i;
    if (isEqualButNotNoValue(bestScore, bestPossibleScore)) {
      // Nothing can beat the best possible score; stop right away.
      return best;
    }
    // Scan the remaining elements, keeping the first strictly lower score.
    for (++i; i != end; ++i) {
      auto s = score(*i);
      if (isEqualButNotNoValue(s, invalidScore)) {
        continue;
      }
      if (s < bestScore) {
        best = i;
        bestScore = s;
        if (isEqualButNotNoValue(s, bestPossibleScore)) {
          break;
        }
      }
    }
    return best;
  }
  // Empty range, or every element had an invalid score.
  return end;
}
102 
103 /// This function is equivalent to getBestScore, but it can be passed a range
104 /// or container instead of two iterators.
105 /// The return value is still an iterator.
106 template <
107  typename Range,
108  typename Score,
109  typename InvalidScore = NoValue,
110  typename BestPossibleScore = NoValue>
112  Range&& range,
113  Score&& score,
114  InvalidScore&& invalidScore = InvalidScore(),
115  BestPossibleScore&& bestPossibleScore = BestPossibleScore()) {
116  return getBestScore(
117  range.begin(),
118  range.end(),
119  std::forward<Score>(score),
120  std::forward<InvalidScore>(invalidScore),
121  std::forward<BestPossibleScore>(bestPossibleScore));
122 }
123 
124 /// This function is the same as getBestScore, but it returns a copy of the
125 /// value retrieved by dereferencing the returned iterator (using auto type
126 /// semantics; it's a copy, not a reference).
127 /// If the end iterator would be returned, a value initialized object is
128 /// returned as if by T{}.
129 template <
130  typename Range,
131  typename Score,
132  typename InvalidScore = NoValue,
133  typename BestPossibleScore = NoValue>
135  Range&& range,
136  Score&& score,
137  InvalidScore&& invalidScore = InvalidScore(),
138  BestPossibleScore&& bestPossibleScore = BestPossibleScore()) {
139  auto i = getBestScore(
140  range.begin(),
141  range.end(),
142  std::forward<Score>(score),
143  std::forward<InvalidScore>(invalidScore),
144  std::forward<BestPossibleScore>(bestPossibleScore));
145  if (i == range.end()) {
146  return typename std::remove_reference<decltype(*i)>::type{};
147  }
148  return *i;
149 }
150 
151 /// This function is the same as getBestScore, but it returns a pointer
152 /// to the value of the dereferenced result iterator, or nullptr if the
153 /// end iterator would be returned.
154 template <
155  typename Range,
156  typename Score,
157  typename InvalidScore = NoValue,
158  typename BestPossibleScore = NoValue>
160  Range&& range,
161  Score&& score,
162  InvalidScore&& invalidScore = InvalidScore(),
163  BestPossibleScore&& bestPossibleScore = BestPossibleScore()) {
164  auto i = getBestScore(
165  range.begin(),
166  range.end(),
167  std::forward<Score>(score),
168  std::forward<InvalidScore>(invalidScore),
169  std::forward<BestPossibleScore>(bestPossibleScore));
170  if (i == range.end()) {
171  return (decltype(&*i)) nullptr;
172  }
173  return &*i;
174 }
175 
176 inline std::string buildTypeString(BuildType const* buildType) {
177  return (buildType ? buildType->name : "null");
178 }
179 
180 template <typename Units>
181 inline Position centerOfUnits(Units&& units) {
182  Position p;
183  if (units.size() != 0) {
184  for (Unit const* unit : units) {
185  p += Position(unit);
186  }
187  p /= units.size();
188  } else {
189  VLOG(2) << "Center of no units is (0, 0)";
190  return Position(0, 0);
191  }
192  return p;
193 }
194 
195 template <typename InputIterator>
196 inline Position centerOfUnits(InputIterator start, InputIterator end) {
197  Position p(0, 0);
198  auto size = 0U;
199  if (start == end) {
200  VLOG(2) << "Center of no units is (0, 0)";
201  return Position(0, 0);
202  }
203  for (; start != end; start++, size++) {
204  p += Position(*start);
205  }
206  p /= size;
207  return p;
208 }
209 
210 inline bool isWithinRadius(Unit* unit, int32_t x, int32_t y, float radius) {
211  return distance(unit->x, unit->y, x, y) <= radius;
212 }
213 
214 template <typename Units>
215 inline std::vector<Unit*>
216 filterUnitsByDistance(Units&& units, int32_t x, int32_t y, float radius) {
217  return filterUnits(
218  units, [=](Unit* u) { return isWithinRadius(u, x, y, radius); });
219 }
220 
/// Determines the element of [first, last) closest to the position (x, y).
/// Elements are accessed through `it->x` / `it->y`; squared distances are
/// compared in float. On ties the earliest element wins. Returns `last` if
/// the range is empty.
template <typename It>
inline It getClosest(int x, int y, It first, It last) {
  It closest = last;
  float minSquaredDist = FLT_MAX;
  for (; first != last; ++first) {
    float const dx = float(x - first->x);
    float const dy = float(y - first->y);
    float const squaredDist = dx * dx + dy * dy;
    if (squaredDist < minSquaredDist) {
      minSquaredDist = squaredDist;
      closest = first;
    }
  }
  return closest;
}
237 
238 template <typename Units>
239 std::unordered_set<Unit*> findNearbyEnemyUnits(State* state, Units&& units) {
240  auto& enemyUnits = state->unitsInfo().enemyUnits();
241  std::unordered_set<Unit*> nearby;
242  for (auto unit : units) {
243  // from UAlbertaBot
244  auto wRange = 75;
245  for (auto enemy :
246  filterUnitsByDistance(enemyUnits, unit->x, unit->y, wRange)) {
247  // XXX What if it's gone??
248  if (!enemy->gone) {
249  nearby.insert(enemy);
250  }
251  }
252  }
253  return nearby;
254 }
255 
256 // Returns argmax (x,y) and value in walktiles
257 inline std::tuple<int, int, float> argmax(torch::Tensor const& pos, int scale) {
258  if (!pos.defined() || pos.dim() != 2) {
259  throw std::runtime_error("Two-dimensional tensor expected");
260  }
261  // ATen needs a const accessor...
262  auto acc = const_cast<torch::Tensor&>(pos).accessor<float, 2>();
263  int xmax = 0;
264  int ymax = 0;
265  float max = kfLowest;
266  for (int y = 0; y < acc.size(0); y++) {
267  for (int x = 0; x < acc.size(1); x++) {
268  auto el = acc[y][x];
269  if (el > max) {
270  max = el;
271  xmax = x;
272  ymax = y;
273  }
274  }
275  }
276 
277  return std::make_tuple(xmax * scale, ymax * scale, max);
278 }
279 
/// Element-wise in-place addition: in[i] += add[i].
/// Intended for vectors of equal length. Only the first
/// min(in.size(), add.size()) elements are updated, so a shorter `add` can
/// never cause an out-of-bounds read; any extra elements of `in` are left
/// unchanged.
template <typename T>
inline void inplace_flat_vector_add(
    std::vector<T>& in,
    const std::vector<T>& add) {
  auto const n = std::min(in.size(), add.size());
  for (size_t i = 0; i < n; ++i) {
    in[i] += add[i];
  }
}
288 
/// Element-wise in-place fused multiply-add: in[i] += mul1[i] * mul2[i].
/// Intended for vectors of equal length. Only the common prefix (the shortest
/// of the three lengths) is updated, so shorter `mul1` / `mul2` vectors can
/// never cause an out-of-bounds read; any extra elements of `in` are left
/// unchanged.
template <typename T>
inline void inplace_flat_vector_addcmul(
    std::vector<T>& in,
    const std::vector<T>& mul1,
    const std::vector<T>& mul2) {
  auto const n = std::min(in.size(), std::min(mul1.size(), mul2.size()));
  for (size_t i = 0; i < n; ++i) {
    in[i] += mul1[i] * mul2[i];
  }
}
/// Element-wise in-place multiply-add with a scalar: in[i] += mul1[i] * mul2.
/// Intended for vectors of equal length. Only the first
/// min(in.size(), mul1.size()) elements are updated, so a shorter `mul1` can
/// never cause an out-of-bounds read; any extra elements of `in` are left
/// unchanged.
template <typename T>
inline void inplace_flat_vector_addcmul(
    std::vector<T>& in,
    const std::vector<T>& mul1,
    T mul2) {
  auto const n = std::min(in.size(), mul1.size());
  for (size_t i = 0; i < n; ++i) {
    in[i] += mul1[i] * mul2;
  }
}
307 
/// Divides every element of `in` by `div`, in place (in[i] /= div).
template <typename T>
inline void inplace_flat_vector_div(std::vector<T>& in, T div) {
  for (auto& element : in) {
    element /= div;
  }
}
314 
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of
/// pow(element, 2), accumulated in T. An empty vector yields sqrt(0).
template <typename T>
inline T l2_norm_vector(const std::vector<T>& v) {
  T sumOfSquares = 0;
  for (auto const& element : v) {
    sumOfSquares += pow(element, 2);
  }
  return sqrt(sumOfSquares);
}
322 
/// Returns the index of the first largest element of `v` (per
/// std::max_element). For an empty vector this is 0, which is not a valid
/// index — callers must check for emptiness themselves.
template <typename T>
inline size_t argmax(const std::vector<T>& v) {
  auto const best = std::max_element(v.cbegin(), v.cend());
  return static_cast<size_t>(best - v.cbegin());
}
327 
/// Returns true if `collection.find(key)` finds an element, i.e. it compares
/// unequal to `collection.end()`.
/// Both arguments are forwarding references so temporaries, literals and
/// const objects are accepted (the previous `T&` parameters rejected them);
/// they are only ever used as lvalues inside, exactly as before.
template <typename TCollection, typename TKey>
bool contains(TCollection&& collection, TKey&& key) {
  return collection.find(key) != collection.end();
}
332 
namespace detail {
// Add overload for specific container below if needed.
// Containers without a dedicated overload fall through to the last
// (sequence-container) overload, which calls
// dest.insert(dest.end(), first, last); e.g. std::set has no such insert.

// std::map with any comparator/allocator. insert() keeps existing keys.
template <typename U, typename... Ts>
void cmergeInto(std::map<Ts...>& dest, U&& src) {
  dest.insert(src.begin(), src.end());
}
// std::unordered_map with any hash/equality/allocator. insert() keeps
// existing keys.
template <typename U, typename... Ts>
void cmergeInto(std::unordered_map<Ts...>& dest, U&& src) {
  dest.insert(src.begin(), src.end());
}
// Sequence containers (e.g. std::vector): append src at the end of dest.
template <typename T, typename U>
void cmergeInto(T& dest, U&& src) {
  dest.insert(dest.end(), src.begin(), src.end());
}
} // namespace detail

/// Merges two or more STL containers into a new container whose type is the
/// first argument's type with references and cv-qualifiers removed (so const
/// arguments are accepted too).
/// For associative containers, values will not be overwritten during merge,
/// i.e. for duplicate keys, the value of the *first* argument containing that
/// key will be used.
template <typename T, typename... Args>
inline typename std::decay<T>::type cmerge(T&& c1, Args&&... cs) {
  typename std::decay<T>::type m;
  detail::cmergeInto(m, c1);
  // Braced-init-list elements are evaluated left to right, which is what
  // makes the "first argument wins" rule hold for the remaining arguments.
  (void)std::initializer_list<int>{(detail::cmergeInto(m, cs), 0)...};
  return m;
}
360 
361 } // namespace utils
362 } // namespace cherrypi
Game state.
Definition: state.h:42
void inplace_flat_vector_div(std::vector< T > &in, T div)
Definition: algorithms.h:309
std::unordered_set< Unit * > findNearbyEnemyUnits(State *state, Units &&units)
Definition: algorithms.h:239
std::remove_reference< T >::type cmerge(T &&c1, Args &&...cs)
Merges two or more STL containers.
Definition: algorithms.h:354
bool isWithinRadius(Unit *unit, int32_t x, int32_t y, float radius)
Definition: algorithms.h:210
std::tuple< int, int, float > argmax(torch::Tensor const &pos, int scale)
Definition: algorithms.h:257
It getClosest(int x, int y, It first, It last)
Definition: algorithms.h:223
void cmergeInto(std::map< K, V > &dest, U &&src)
Definition: algorithms.h:336
Represents and holds information about buildable types (units, upgrades, techs).
Definition: buildtype.h:36
constexpr bool isEqualButNotNoValue(A &&a, B &&b)
Definition: algorithms.h:31
T l2_norm_vector(const std::vector< T > &v)
Definition: algorithms.h:316
UnitsInfo & unitsInfo()
Definition: state.h:116
bool contains(TCollection &collection, TKey &key)
Definition: algorithms.h:329
Definition: algorithms.h:28
std::string name
Definition: buildtype.h:44
auto filterUnits(Units &&units, UnaryPredicate pred)
Definition: filter.h:15
float constexpr kfLowest
Definition: basetypes.h:30
std::vector< Unit * > filterUnitsByDistance(Units &&units, int32_t x, int32_t y, float radius)
Definition: algorithms.h:216
auto getBestScore(Iterator begin, Iterator end, Score &&score, InvalidScore &&invalidScore=InvalidScore(), BestPossibleScore &&bestPossibleScore=BestPossibleScore())
This function iterates from begin to end, passing each value to the provided score function...
Definition: algorithms.h:59
Represents a unit in the game.
Definition: unitsinfo.h:35
float distance(int x1, int y1, int x2, int y2)
Walktile distance.
Definition: gamemechanics.h:49
void inplace_flat_vector_addcmul(std::vector< T > &in, const std::vector< T > &mul1, const std::vector< T > &mul2)
Definition: algorithms.h:290
std::string buildTypeString(BuildType const *buildType)
Definition: algorithms.h:176
const Units & enemyUnits()
All enemy units that are not dead (includes gone units).
Definition: unitsinfo.h:350
auto getBestScoreCopy(Range &&range, Score &&score, InvalidScore &&invalidScore=InvalidScore(), BestPossibleScore &&bestPossibleScore=BestPossibleScore())
This function is the same as getBestScore, but it returns a copy of the value retrieved by dereferenc...
Definition: algorithms.h:134
auto getBestScorePointer(Range &&range, Score &&score, InvalidScore &&invalidScore=InvalidScore(), BestPossibleScore &&bestPossibleScore=BestPossibleScore())
This function is the same as getBestScore, but it returns a pointer to the value of the dereferenced ...
Definition: algorithms.h:159
Main namespace for bot-related code.
Definition: areainfo.cpp:17
Position centerOfUnits(Units &&units)
Definition: algorithms.h:181
int x
Definition: unitsinfo.h:37
void inplace_flat_vector_add(std::vector< T > &in, const std::vector< T > &add)
Definition: algorithms.h:281
int y
Definition: unitsinfo.h:38
Vec2T< int > Position
Definition: basetypes.h:178