TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
baseplayer.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include <condition_variable>
11 #include <mutex>
12 #include <thread>
13 #include <vector>
14 
15 #include "cherrypi.h"
16 #include "module.h"
17 #include "state.h"
18 
19 namespace cherrypi {
20 
21 /**
22  * The main bot object.
23  *
24  * This class is used to play StarCraft Broodwar (TM) via the TorchCraft bridge.
25  * The behavior and actions of the player are determined by a user-supplied list
26  * of bot modules.
27  */
28 class BasePlayer {
29  using ClientCommands = std::vector<tc::Client::Command>;
30 
31  public:
32  BasePlayer(std::shared_ptr<tc::Client> client);
33  virtual ~BasePlayer();
34  BasePlayer(const BasePlayer&) = delete;
35  BasePlayer& operator=(const BasePlayer&) = delete;
36 
37  State* state() {
38  return state_;
39  }
40 
41  std::shared_ptr<Module> getTopModule() const;
42  void addModule(std::shared_ptr<Module> module);
43  void addModules(std::vector<std::shared_ptr<Module>> const& modules);
44 
45  template <typename T>
46  std::shared_ptr<T> findModule() {
47  for (auto& module : modules_) {
48  auto m = std::dynamic_pointer_cast<T>(module);
49  if (m != nullptr) {
50  return m;
51  }
52  }
53  return nullptr;
54  };
55 
56  /// Add some commands to the queue, they will be executed on next step()
57  void queueCmds(const std::vector<tc::Client::Command>& cmds);
58 
59  /// Log a warning if step() exceeds a maximum duration.
60  /// Defaults to false.
61  void setWarnIfSlow(bool warn);
62 
63  /// Delay step() to make the game run in approx. factor*fastest speed.
64  void setRealtimeFactor(float factor);
65 
66  /// Set whether to perform consistency checks during the game.
67  void setCheckConsistency(bool check);
68 
69  /// Set whether to gather timing statistics during the game.
70  void setCollectTimers(bool collect);
71 
72  /// Set whether to log failed commands (via VLOG(0)).
73  void setLogFailedCommands(bool log);
74 
75  /// Set whether to post drawing commands (if any are posted).
76  /// Defaults to true.
77  void setDraw(bool draw);
78 
79  virtual void stepModule(std::shared_ptr<Module> module);
80  void stepModules();
81  void step();
82  size_t steps() const {
83  return steps_;
84  }
85 
86  virtual void init(){};
87 
88  void leave();
89 
90  void dumpTraceAlongReplay(std::string const& replayFile);
91 
92  protected:
93  using commandStartEndFrame =
94  std::pair<tc::BW::UnitCommandType, std::pair<FrameNum, FrameNum>>;
95  virtual void preStep();
96  virtual void postStep();
97  void logFailedCommands();
98 
99  std::shared_ptr<tc::Client> client_;
100  int frameskip_ = 1;
101  int combineFrames_ = 3;
102  bool warnIfSlow_ = false;
103  bool nonBlocking_ = false;
104  bool checkConsistency_ = false;
105  bool collectTimers_ = false;
106  bool logFailedCommands_ = false;
108  int framesDropped_ = 0;
109  float realtimeFactor_ = -1.0f;
110  std::vector<std::shared_ptr<Module>> modules_;
112  std::shared_ptr<Module> top_;
113  std::unordered_map<std::shared_ptr<Module>, Duration> moduleTimeSpent_;
114  std::unordered_map<std::shared_ptr<Module>, Duration> moduleTimeSpentAgg_;
117  size_t steps_ = 0;
118  bool initialized_ = false;
119  bool firstStepDone_ = false;
120  hires_clock::time_point lastStep_;
121  bool draw_ = true;
122 
123  std::vector<tc::Client::Command> pendingCmds_;
124 
125  static const decltype(std::chrono::milliseconds(50)) kMaxStepDuration;
126  static const decltype(std::chrono::seconds(9)) kMaxInitialStepDuration;
127  static const decltype(std::chrono::milliseconds(42)) kStepDurationAtFastest;
128 
129  ClientCommands doStep();
130 };
131 
132 } // namespace cherrypi
virtual void preStep()
Definition: baseplayer.cpp:261
Game state.
Definition: state.h:42
void logFailedCommands()
Definition: baseplayer.cpp:316
void setRealtimeFactor(float factor)
Delay step() to make the game run in approx. factor*fastest speed.
Definition: baseplayer.cpp:92
virtual void init()
Definition: baseplayer.h:86
std::vector< tc::Client::Command > pendingCmds_
Definition: baseplayer.h:123
std::pair< tc::BW::UnitCommandType, std::pair< FrameNum, FrameNum >> commandStartEndFrame
Definition: baseplayer.h:94
std::unordered_map< std::shared_ptr< Module >, Duration > moduleTimeSpentAgg_
Definition: baseplayer.h:114
void addModule(std::shared_ptr< Module > module)
Definition: baseplayer.cpp:55
static decltype(std::chrono::milliseconds(42)) const kStepDurationAtFastest
Definition: baseplayer.h:127
int framesDropped_
Definition: baseplayer.h:108
bool initialized_
Definition: baseplayer.h:118
Duration stateUpdateTimeSpent_
Definition: baseplayer.h:115
STL namespace.
BasePlayer & operator=(const BasePlayer &)=delete
bool checkConsistency_
Definition: baseplayer.h:104
std::chrono::nanoseconds Duration
Definition: cherrypi.h:36
virtual void stepModule(std::shared_ptr< Module > module)
Definition: baseplayer.cpp:122
int frameskip_
Definition: baseplayer.h:100
std::shared_ptr< Module > top_
Definition: baseplayer.h:112
void dumpTraceAlongReplay(std::string const &replayFile)
Definition: baseplayer.cpp:357
void leave()
Definition: baseplayer.cpp:256
size_t steps() const
Definition: baseplayer.h:82
static decltype(std::chrono::milliseconds(50)) const kMaxStepDuration
Definition: baseplayer.h:125
bool firstStepDone_
Definition: baseplayer.h:119
void setWarnIfSlow(bool warn)
Log a warning if step() exceeds a maximum duration.
Definition: baseplayer.cpp:88
int lastFrameStepped_
Definition: baseplayer.h:107
The main bot object.
Definition: baseplayer.h:28
Duration stateUpdateTimeSpentAgg_
Definition: baseplayer.h:116
bool draw_
Definition: baseplayer.h:121
bool logFailedCommands_
Definition: baseplayer.h:106
virtual ~BasePlayer()
Definition: baseplayer.cpp:47
ClientCommands doStep()
Do the actual per-step work.
Definition: baseplayer.cpp:292
void setDraw(bool draw)
Set whether to post drawing commands (if any are posted).
Definition: baseplayer.cpp:110
State * state()
Definition: baseplayer.h:37
State * state_
Definition: baseplayer.h:111
float realtimeFactor_
Definition: baseplayer.h:109
void setLogFailedCommands(bool log)
Set whether to log failed commands (via VLOG(0)).
Definition: baseplayer.cpp:106
std::shared_ptr< Module > getTopModule() const
Definition: baseplayer.cpp:51
hires_clock::time_point lastStep_
Definition: baseplayer.h:120
void step()
Definition: baseplayer.cpp:135
void setCheckConsistency(bool check)
Set whether to perform consistency checks during the game.
Definition: baseplayer.cpp:96
BasePlayer(std::shared_ptr< tc::Client > client)
Definition: baseplayer.cpp:32
int combineFrames_
Definition: baseplayer.h:101
std::shared_ptr< T > findModule()
Definition: baseplayer.h:46
void stepModules()
Definition: baseplayer.cpp:114
void queueCmds(const std::vector< tc::Client::Command > &cmds)
Add some commands to the queue, they will be executed on next step()
Definition: baseplayer.cpp:353
std::shared_ptr< tc::Client > client_
Definition: baseplayer.h:99
std::vector< std::shared_ptr< Module > > modules_
Definition: baseplayer.h:110
Main namespace for bot-related code.
Definition: areainfo.cpp:17
bool collectTimers_
Definition: baseplayer.h:105
std::unordered_map< std::shared_ptr< Module >, Duration > moduleTimeSpent_
Definition: baseplayer.h:113
static decltype(std::chrono::seconds(9)) const kMaxInitialStepDuration
Definition: baseplayer.h:126
void addModules(std::vector< std::shared_ptr< Module >> const &modules)
Definition: baseplayer.cpp:81
size_t steps_
Definition: baseplayer.h:117
bool nonBlocking_
Definition: baseplayer.h:103
bool warnIfSlow_
Definition: baseplayer.h:102
void setCollectTimers(bool collect)
Set whether to gather timing statistics during the game.
Definition: baseplayer.cpp:100
virtual void postStep()
Definition: baseplayer.cpp:274