TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
movefilters.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "cherrypi.h"
11 #include "utils.h"
12 
13 namespace cherrypi {
14 namespace movefilters {
15 
17 
18 /**
19  * Assigns positions a score and validity, based on some criteria.
20  * Used for threat-aware pathfinding.
21  */
23  public:
24  virtual ~PositionFilter() = default;
25  virtual bool isValid(Unit*, Position const&) = 0;
26  virtual float score(Unit*, Position const&) = 0;
27  virtual bool blocking() = 0;
28 };
29 
30 typedef std::shared_ptr<PositionFilter> PPositionFilter;
31 typedef std::vector<PPositionFilter> PositionFilters;
32 
33 /**
34  * Input:
35  * - getter: for the list of objects to compare to
36  * - valid: f(agent, position, obj) => bool for whether this position is valid
37  * given the obj
38  * - score: f(agent, position, obj) => float to score this position
39  *
40  * The filter will decide if a position is valid by combining the valid func
41  * and score function according to the PositionFilterPolicy.
42  */
43 template <typename T, typename Container>
45  public:
47  std::function<Container const && (Unit*)> getter,
48  std::function<bool(Unit*, Position const&, T)> valid,
49  std::function<float(Unit*, Position const&, T)> scoreFunc,
51  bool blocking = false)
52  : getter_(std::move(getter)),
53  valid_(std::move(valid)),
54  score_(std::move(scoreFunc)),
55  policy_(policy),
56  blocking_(blocking) {}
57 
58  bool isValid(Unit* agent, Position const& pos) override {
59  VLOG(4) << "PositionFilter: in filter valid";
60  auto objects = getter_(agent);
61  VLOG(4) << "PositionFilter: in valid non empty set? size is: "
62  << objects.size();
63  if (policy_ == PositionFilterPolicy::ACCEPT_IF_ALL) {
64  for (auto obj : objects) {
65  if (!valid_(agent, pos, obj)) {
66  return false;
67  }
68  }
69  return true;
70  } else if (policy_ == PositionFilterPolicy::ACCEPT_IF_ANY) {
71  for (auto obj : objects) {
72  if (valid_(agent, pos, obj)) {
73  return true;
74  }
75  }
76  return false; // false if empty: must be valid for some
77  } else {
78  LOG(DFATAL) << "MoveHelpers: incorrect position filter policy";
79  }
80  return false;
81  }
82 
83  float score(Unit* agent, Position const& pos) override {
84  if (policy_ == PositionFilterPolicy::ACCEPT_IF_ANY) {
85  auto best_score = kfInfty;
86  for (auto obj : getter_(agent)) {
87  auto s = score_(agent, pos, obj);
88  if (s < best_score) {
89  best_score = s;
90  }
91  }
92  return best_score;
93  } else if (policy_ == PositionFilterPolicy::ACCEPT_IF_ALL) {
94  auto best_score = -kfInfty;
95  for (auto obj : getter_(agent)) {
96  auto s = score_(agent, pos, obj);
97  if (s > best_score) {
98  best_score = s;
99  }
100  }
101  return best_score;
102  } else {
103  LOG(DFATAL) << "MoveHelpers: incorrect position filter policy";
104  return 0;
105  }
106  }
107 
108  bool blocking() override {
109  return blocking_;
110  }
111 
112  protected:
113  std::function<Container const && (Unit*)> getter_;
114  std::function<bool(Unit*, Position const&, T)> valid_;
115  std::function<float(Unit*, Position const&, T)> score_;
117  bool blocking_;
118 };
119 
120 /**
121  * This filter uses the score of the base filter, only of all subfilters
122  * return that the position is valid.
123  */
125  public:
127  PPositionFilter base,
128  PositionFilters l,
129  bool blocking = false)
130  : base_(base), allFilters_(l), blocking_(blocking) {}
131 
132  bool isValid(Unit* agent, Position const& pos) override {
133  if (!base_->isValid(agent, pos)) {
134  return false;
135  }
136  for (auto& filt : allFilters_) {
137  if (!filt->isValid(agent, pos)) {
138  return false;
139  }
140  }
141  return true;
142  }
143 
144  float score(Unit* agent, Position const& pos) override {
145  return base_->score(agent, pos);
146  }
147 
148  bool blocking() override {
149  return blocking_;
150  }
151 
152  protected:
153  PPositionFilter base_;
154  PositionFilters allFilters_;
155  bool blocking_;
156 };
157 
158 /**
159  * class to combine filters
160  */
162  public:
164  PositionFilters l,
166  bool blocking = false)
167  : allFilters_(l), policy_(policy), blocking_(blocking) {}
168 
169  bool isValid(Unit* agent, Position const& pos) override {
170  if (policy_ == PositionFilterPolicy::ACCEPT_IF_ALL) {
171  for (auto& filt : allFilters_) {
172  if (!filt->isValid(agent, pos)) {
173  return false;
174  }
175  }
176  return true;
177  } else if (policy_ == PositionFilterPolicy::ACCEPT_IF_ANY) {
178  for (auto& filt : allFilters_) {
179  if (filt->isValid(agent, pos)) {
180  return true;
181  }
182  }
183  return false;
184  }
185  LOG(ERROR) << "invalid moveFilter policy";
186  return false;
187  }
188 
189  float score(Unit* agent, Position const& pos) override {
190  if (policy_ == PositionFilterPolicy::ACCEPT_IF_ALL) {
191  float max_score = -kfInfty;
192  for (auto& filt : allFilters_) {
193  auto score = filt->score(agent, pos);
194  if (score > max_score) {
195  max_score = score;
196  }
197  }
198  return max_score;
199  } else if (policy_ == PositionFilterPolicy::ACCEPT_IF_ANY) {
200  float min_score = kfInfty;
201  for (auto& filt : allFilters_) {
202  auto score = filt->score(agent, pos);
203  if (score < min_score) {
204  min_score = score;
205  }
206  }
207  return min_score;
208  }
209  LOG(ERROR) << "invalid moveFilter policy";
210  return -kfInfty;
211  }
212 
213  bool blocking() override {
214  return blocking_;
215  }
216 
217  protected:
218  PositionFilters allFilters_;
220  bool blocking_;
221 };
222 
224  public:
225  ConstantGetter(std::vector<Position> values) : storage_(values) {}
226  std::vector<Position> const& operator()(Unit* agent) {
227  return storage_;
228  }
229 
230  private:
231  std::vector<Position> storage_;
232 };
233 
// Builds a position filter from `getter`, `valid` and `scoreFunc` with
// std::make_shared. It returns the filter as a PPositionFilter (a shared_ptr
// to the PositionFilter base). Presumably the filter is a FuncPositionFilter,
// since the make_shared arguments match its constructor.
//
// NOTE(review): T appears only inside the std::function parameter types, so it
// cannot be deduced from a lambda or plain function argument. Presumably
// callers name it explicitly (e.g. makePositionFilter<Unit*>(...)) — confirm.
// NOTE(review): `policy` is forwarded below, but neither its parameter
// declaration nor the filter's template arguments are visible here.
// Presumably the parameter is
// `PositionFilterPolicy policy = PositionFilterPolicy::ACCEPT_IF_ALL`, as in
// FuncPositionFilter — confirm.
234 // Makes a functional position filter of an arbitrary container reference
235 template <typename T, typename UnaryFunctionReturnsContainerTType>
236 PPositionFilter makePositionFilter(
237  UnaryFunctionReturnsContainerTType getter,
238  std::function<bool(Unit*, Position const&, T)> valid,
239  std::function<float(Unit*, Position const&, T)> scoreFunc,
241  bool blocking = false) {
242  auto filter = std::make_shared<
244  std::move(getter),
245  std::move(valid),
246  std::move(scoreFunc),
247  policy,
248  blocking);
249  return std::static_pointer_cast<PositionFilter>(filter);
250 }
251 
252 PPositionFilter makePositionFilter(
253  PPositionFilter base,
254  PositionFilters l,
255  bool blocking = false);
256 
// Overload taking only a list of filters.
// NOTE(review): one parameter line of this declaration is not visible here.
// By analogy with UnionPositionFilter's constructor, it is presumably
// `PositionFilterPolicy policy = PositionFilterPolicy::ACCEPT_IF_ALL`, and the
// overload presumably builds a UnionPositionFilter — confirm against the header
// and movefilters.cpp.
257 PPositionFilter makePositionFilter(
258  PositionFilters l,
260  bool blocking = false);
261 
262 bool insideSpecificUnit(Position const& pos, Unit* bldg, int margin = 0);
263 bool insideSpecificUnit(Unit* unit, Position const& pos, Unit* bldg);
264 bool unitTouch(Unit* unit, Unit* v, int dirX = 0, int dirY = 0);
265 bool insideAnyUnit(Unit* unit, Position const& pos, std::vector<Unit*> units);
266 bool positionAvoids(Unit* agent, Position const& pos, Unit* nmy);
267 bool dangerousAttack(Unit* unit, Unit* tgt);
268 
269 inline std::vector<Unit*>& threateningEnemiesGetter(Unit* u) {
270  return u->threateningEnemies;
271 }
272 inline std::vector<Unit*>& beingAttackedByEnemiesGetter(Unit* u) {
273  return u->beingAttackedByEnemies;
274 }
275 inline std::vector<Unit*>& unitsInSightRangeGetter(Unit* u) {
276  return u->unitsInSightRange;
277 }
278 inline std::vector<Unit*>& obstaclesInSightRangeGetter(Unit* u) {
279  return u->obstaclesInSightRange;
280 }
281 inline std::vector<Unit*>& enemyUnitsInSightRangeGetter(Unit* u) {
282  return u->enemyUnitsInSightRange;
283 }
284 inline std::vector<Unit*>& allyUnitsInSightRangeGetter(Unit* u) {
285  return u->allyUnitsInSightRange;
286 }
287 inline auto negDistanceScore(Unit* unit, Position const& pos, Unit* nmy) {
288  return -pos.distanceTo(nmy);
289 }
290 inline auto distanceScore(Unit* unit, Position const& pos, Unit* nmy) {
291  return pos.distanceTo(nmy);
292 }
293 inline auto zeroScore(Unit* unit, Position const& pos, Unit* nmy) {
294  return 0;
295 }
296 
297 PPositionFilter fleeAttackers();
298 PPositionFilter fleeThreatening();
299 PPositionFilter avoidAttackers();
300 PPositionFilter avoidThreatening();
301 PPositionFilter avoidEnemyUnitsInRange(float range);
302 PPositionFilter getCloserTo(std::vector<Position> coordinates);
303 PPositionFilter getCloserTo(Unit* bldg);
304 PPositionFilter getCloserTo(Position const& pos);
305 bool walkable(State* state, Position const& pos);
306 
// Default move-search parameters. kMoveLength, kNumberPossibleMoves and
// kMoveLOSStepSize are the defaults of getValidMovePositions / smartMove.
constexpr int kMoveLength = 16;
constexpr int kNumberPossibleMoves = 64;
constexpr int kMinMoveLength = 8;
// Discretization step used to check line of sight.
constexpr int kMoveLOSStepSize = 4;
// Targets closer than this are too close to be considered as a direction.
constexpr int kMinDistToTargetPos = 8;
constexpr int kTimeUpdateMove = 7;
314 
315 bool moveIsPossible(
316  State* state,
317  Position const& pos,
318  std::vector<Unit*> const& obstacles,
319  bool outOFBoundsInvalid);
320 
321 std::vector<std::vector<std::pair<float, Position>>> getValidMovePositions(
322  State* state,
323  Unit* unit,
324  PositionFilters const& filters,
325  int moveLength = kMoveLength,
326  int nbPossibleMoves = kNumberPossibleMoves,
327  int stepSize = kMoveLOSStepSize,
328  bool outOfBoundsInvalid = true);
329 
330 // Get the best position to move under the position filters. The best position
331 // is defined by the minimum score of the first filter in filters with minimum
332 // score. filters can be though of as a list of policies, where we follow them
333 // in order until we find a policy that gives a valid position.
334 template <typename T>
335 Position safeDirectionTo(State* state, Unit* unit, T tgt);
336 Position safeMoveTo(State* state, Unit* unit, Position const& pos);
337 Position pathMoveTo(State* state, Unit* unit, Position const& pos);
338 
340  State* state,
341  Unit* unit,
342  PositionFilters const& filters,
343  int moveLength = kMoveLength,
344  int nbPossibleMoves = kNumberPossibleMoves,
345  int stepSize = kMoveLOSStepSize,
346  bool outOfBoundsInvalid = true);
347 
// NOTE(review): the return type and function name of this declaration are not
// visible here. Its parameters mirror the PositionFilters-based smartMove
// overload, but it takes a single PPositionFilter — confirm the declarator
// against the header.
349  State* state,
350  Unit* unit,
351  PPositionFilter const& filter,
352  int moveLength = kMoveLength,
353  int nbPossibleMoves = kNumberPossibleMoves,
354  int stepSize = kMoveLOSStepSize,
355  bool outOfBoundsInvalid = true);
356 
357 Position smartMove(State* state, Unit* unit, Position tgt);
358 
359 } // namespace movefilters
360 } // namespace cherrypi
PPositionFilter fleeThreatening()
Definition: movefilters.cpp:172
Game state.
Definition: state.h:42
Position safeDirectionTo(State *state, Unit *unit, T tgt)
Definition: movefilters.cpp:280
bool dangerousAttack(Unit *unit, Unit *tgt)
Definition: movefilters.cpp:132
PPositionFilter avoidAttackers()
Definition: movefilters.cpp:184
std::vector< Unit * > & beingAttackedByEnemiesGetter(Unit *u)
Definition: movefilters.h:272
auto negDistanceScore(Unit *unit, Position const &pos, Unit *nmy)
Definition: movefilters.h:287
bool unitTouch(Unit *unit, Unit *v, int dirX, int dirY)
Definition: movefilters.cpp:39
Position safeMoveTo(State *state, Unit *unit, Position const &pos)
Definition: movefilters.cpp:307
PPositionFilter makePositionFilter(PPositionFilter base, PositionFilters l, bool blocking)
Definition: movefilters.cpp:15
This filter uses the score of the base filter, only if all subfilters return that the position is val...
Definition: movefilters.h:124
std::vector< Unit * > & unitsInSightRangeGetter(Unit *u)
Definition: movefilters.h:275
std::vector< Position > const & operator()(Unit *agent)
Definition: movefilters.h:226
Position pathMoveTo(State *state, Unit *unit, Position const &pos)
Definition: movefilters.cpp:324
float score(Unit *agent, Position const &pos) override
Definition: movefilters.h:189
bool isValid(Unit *agent, Position const &pos) override
Definition: movefilters.h:169
int constexpr kMoveLength
Definition: movefilters.h:307
PositionFilters allFilters_
Definition: movefilters.h:154
std::vector< Unit * > obstaclesInSightRange
Definition: unitsinfo.h:100
FuncPositionFilter(std::function< Container const &&(Unit *)> getter, std::function< bool(Unit *, Position const &, T)> valid, std::function< float(Unit *, Position const &, T)> scoreFunc, PositionFilterPolicy policy=PositionFilterPolicy::ACCEPT_IF_ALL, bool blocking=false)
Definition: movefilters.h:46
std::vector< Unit * > allyUnitsInSightRange
Definition: unitsinfo.h:103
bool blocking_
Definition: movefilters.h:117
UnionPositionFilter(PositionFilters l, PositionFilterPolicy policy=PositionFilterPolicy::ACCEPT_IF_ALL, bool blocking=false)
Definition: movefilters.h:163
double distanceTo(Vec2T const &other) const
Definition: basetypes.h:122
STL namespace.
int constexpr kMinDistToTargetPos
Definition: movefilters.h:311
std::vector< PPositionFilter > PositionFilters
Definition: movefilters.h:31
int constexpr kTimeUpdateMove
Definition: movefilters.h:313
int constexpr kNumberPossibleMoves
Definition: movefilters.h:308
std::function< Container const &&(Unit *)> getter_
Definition: movefilters.h:113
ConstantGetter(std::vector< Position > values)
Definition: movefilters.h:225
std::vector< Unit * > enemyUnitsInSightRange
Definition: unitsinfo.h:102
int constexpr kMoveLOSStepSize
Definition: movefilters.h:310
PositionFilters allFilters_
Definition: movefilters.h:218
bool blocking() override
Definition: movefilters.h:213
PPositionFilter avoidThreatening()
Definition: movefilters.cpp:189
bool isValid(Unit *agent, Position const &pos) override
Definition: movefilters.h:132
bool isValid(Unit *agent, Position const &pos) override
Definition: movefilters.h:58
PositionFilterPolicy
Definition: movefilters.h:16
PPositionFilter getCloserTo(std::vector< Position > coordinates)
Definition: movefilters.cpp:204
bool blocking() override
Definition: movefilters.h:108
bool positionAvoids(Unit *unit, Position const &pos, Unit *nmy)
Definition: movefilters.cpp:74
bool moveIsPossible(State *state, Position const &pos, std::vector< Unit * > const &obstacles, bool outOfBoundsInvalid)
Definition: movefilters.cpp:262
std::shared_ptr< PositionFilter > PPositionFilter
Definition: movefilters.h:30
Represents a unit in the game.
Definition: unitsinfo.h:35
bool blocking_
Definition: movefilters.h:220
bool walkable(State *state, Position const &pos)
Definition: movefilters.cpp:257
std::vector< Unit * > & threateningEnemiesGetter(Unit *u)
Definition: movefilters.h:269
bool insideAnyUnit(Unit *unit, Position const &pos, std::vector< Unit * > units)
Definition: movefilters.cpp:68
Assigns positions a score and validity, based on some criteria.
Definition: movefilters.h:22
std::vector< Unit * > threateningEnemies
Definition: unitsinfo.h:93
auto distanceScore(Unit *unit, Position const &pos, Unit *nmy)
Definition: movefilters.h:290
std::function< float(Unit *, Position const &, T)> score_
Definition: movefilters.h:115
auto zeroScore(Unit *unit, Position const &pos, Unit *nmy)
Definition: movefilters.h:293
std::vector< std::vector< std::pair< float, Position > > > getValidMovePositions(State *state, Unit *unit, PositionFilters const &filters, int moveLength=kMoveLength, int nbPossibleMoves=kNumberPossibleMoves, int stepSize=kMoveLOSStepSize, bool outOfBoundsInvalid=true)
PositionFilterPolicy policy_
Definition: movefilters.h:116
PPositionFilter avoidEnemyUnitsInRange(float range)
Definition: movefilters.cpp:194
float constexpr kfInfty
Definition: basetypes.h:29
float score(Unit *agent, Position const &pos) override
Definition: movefilters.h:83
Definition: movefilters.h:223
bool insideSpecificUnit(Position const &pos, Unit *bldg, int margin)
Definition: movefilters.cpp:28
bool blocking_
Definition: movefilters.h:155
class to combine filters
Definition: movefilters.h:161
MultiPositionFilter(PPositionFilter base, PositionFilters l, bool blocking=false)
Definition: movefilters.h:126
Position smartMove(State *state, Unit *unit, PositionFilters const &filters, int moveLength, int nbPossibleMoves, int stepSize, bool outOfBoundsInvalid)
Definition: movefilters.cpp:355
PPositionFilter fleeAttackers()
Definition: movefilters.cpp:160
PPositionFilter base_
Definition: movefilters.h:153
Main namespace for bot-related code.
Definition: areainfo.cpp:17
float score(Unit *agent, Position const &pos) override
Definition: movefilters.h:144
bool blocking() override
Definition: movefilters.h:148
std::vector< Unit * > unitsInSightRange
Definition: unitsinfo.h:98
int constexpr kMinMoveLength
Definition: movefilters.h:309
std::vector< Unit * > beingAttackedByEnemies
Definition: unitsinfo.h:95
std::vector< Unit * > & enemyUnitsInSightRangeGetter(Unit *u)
Definition: movefilters.h:281
PositionFilterPolicy policy_
Definition: movefilters.h:219
std::vector< Unit * > & obstaclesInSightRangeGetter(Unit *u)
Definition: movefilters.h:278
std::vector< Unit * > & allyUnitsInSightRangeGetter(Unit *u)
Definition: movefilters.h:284
std::function< bool(Unit *, Position const &, T)> valid_
Definition: movefilters.h:114
Input:
Definition: movefilters.h:44