TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cherryvisdumper.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 #include "blackboard.h"
10 #include "buildtype.h"
11 #include "cherrypi.h"
12 #include "module.h"
13 #include "state.h"
14 #include "unitsinfo.h"
15 
16 #include <fmt/format.h>
17 #include <nlohmann/json.hpp>
18 #include <thread>
19 #include <variant>
20 
21 namespace cherrypi {
22 
class CherryVisLogSink;

/**
 * Streams one message into the global CherryVis logger of `state`, e.g.
 *   CVIS_LOG(state) << "value: " << x;
 * When the blackboard has no trace dumper, nothing is logged and the streamed
 * operands are not evaluated.
 *
 * The `if (init; !dumper) {} else ...` shape makes the macro safe inside an
 * unbraced if/else written by the caller: the caller's `else` binds to the
 * caller's `if`, not to the macro's hidden one.
 */
#define CVIS_LOG(state)                                      \
  if (auto _cvisDumper = (state)->board()->getTraceDumper(); \
      !_cvisDumper) {                                        \
  } else                                                     \
    _cvisDumper->getGlobalLogger().getStream((state), __FILE__, __LINE__)

/**
 * Like CVIS_LOG, but logs into the per-unit logger of `unit`.
 */
#define CVIS_LOG_UNIT(state, unit)                           \
  if (auto _cvisDumper = (state)->board()->getTraceDumper(); \
      !_cvisDumper) {                                        \
  } else                                                     \
    _cvisDumper->getUnitLogger((unit)).getStream(            \
        (state), __FILE__, __LINE__)
32 
33 /**
34  * Records bot internal state and dumps it to a file readable by CherryVis.
35  */
36 class CherryVisDumperModule : public Module {
37  public:
38  class Logger;
39  using Dumpable = std::variant<Unit*, Position, int, float, std::string>;
40 
41  class Logger {
42  public:
43  template <typename T, typename _U>
44  struct TypeInfo {
45  static constexpr const char* kName = nullptr;
46  };
47 #define ACCEPT_TYPE(TYPE) \
48  template <typename _U> \
49  struct TypeInfo<TYPE, _U> { \
50  static constexpr const char* kName = #TYPE; \
51  };
52  ACCEPT_TYPE(Unit*);
54  ACCEPT_TYPE(int);
55  ACCEPT_TYPE(float);
56  ACCEPT_TYPE(std::string);
58 
59  class LoggerStream {
60  public:
61  friend class Logger;
62  LoggerStream(State* state, Logger& logger, const char* filename, int line)
63  : state_(state), logger_(logger), filename_(filename), line_(line) {}
65  logger_.addMessage(
66  state_, message_.str(), std::move(attachments_), filename_, line_);
67  }
68 
69  template <typename T>
70  LoggerStream& operator<<(T const& m) {
71  message_ << m;
72  return *this;
73  }
74 
75  template <typename K, typename V>
76  LoggerStream& operator<<(std::unordered_map<K, V> m) {
77  static_assert(TypeInfo<K, void>::kName, "Map Key has invalid type");
78  static_assert(TypeInfo<V, void>::kName, "Map Value has invalid type");
79  attachments_.push_back({
80  {"map", std::move(m)},
81  {"key_type", TypeInfo<K, void>::kName},
82  {"value_type", TypeInfo<V, void>::kName},
83  });
84  return *this;
85  }
86 
87  protected:
88  std::stringstream message_;
89  std::vector<nlohmann::json> attachments_;
92  const char* filename_;
93  int line_;
94  };
95 
96  LoggerStream getStream(State* state, const char* filename, int line) {
97  return LoggerStream(state, *this, filename, line);
98  }
99 
100  void addMessage(
101  State* state,
102  std::string message,
103  std::vector<nlohmann::json> attachments = {},
104  const char* full_filename = "",
105  int line = 0,
106  google::LogSeverity severity = 0);
107 
108  nlohmann::json to_json() const {
109  return logs_;
110  }
111 
112  protected:
113  std::vector<nlohmann::json> logs_;
114  };
115 
116  virtual ~CherryVisDumperModule() = default;
117 
118  void setReplayFile(std::string const& replayFile) {
119  replayFileName_ = replayFile;
120  }
121 
122  virtual void step(State* s) override;
123  virtual void onGameStart(State* s) override;
124  virtual void onGameEnd(State* s) override;
125 
127  return trace_.logs_;
128  }
129 
131  return trace_.unitsLogs_[std::to_string(u->id)];
132  }
133 
134  void onDrawCommand(State* s, tc::Client::Command const& command);
135 
136  template <typename T, typename V, typename W>
137  void addTree(
138  State* s,
139  std::string const& name,
140  T dump_node,
141  V get_children,
142  W root);
143  void dumpTensorsSummary(
144  State* s,
145  std::unordered_map<std::string, ag::Variant> const& tensors);
146 
147  /**
148  * Dumps a heatmap that can be displayed as overlay of the terrain.
149  * \param{tensors} is expected to be a dict of 2 dimensionnal at::Tensor
150  * in a [y, x] order.
151  */
152  void dumpTerrainHeatmaps(
153  State* s,
154  std::unordered_map<std::string, ag::Variant> const& tensors,
155  std::array<int, 2> const& topLeftPixel,
156  std::array<float, 2> const& scalingToPixels);
157 
158  void dumpGameValue(State* s, std::string const& key, float value) {
159  trace_.gameValues_[std::to_string(s->currentFrame())][key] = value;
160  }
161 
162  static void writeGameSummary(State* final_state, std::string const& file);
163  static std::string parseReplayFileName(std::string name);
164 
165  class TreeNode : public std::stringstream {
166  public:
167  std::vector<std::shared_ptr<TreeNode>> children;
168 
169  void setModule(std::string name) {
170  node["module"] = name;
171  }
172  void setFrame(FrameNum f) {
173  node["frame"] = f;
174  }
175  void setId(int32_t id, std::string prefix = "") {
176  node["id"] = id;
177  node["type_prefix"] = prefix;
178  }
179  void addUnitWithProb(Unit* unit, float proba) {
180  prob_distr.push_back({{"type_prefix", "i"},
181  {"id", unit ? unit->id : -1},
182  {"proba", proba}});
183  }
184  nlohmann::json to_json() {
185  if (!str().empty()) {
186  node["description"] = str();
187  }
188  if (!prob_distr.empty()) {
189  node["distribution"] = prob_distr;
190  }
191  return node;
192  }
193 
194  protected:
195  std::unordered_map<std::string, nlohmann::json> node;
196  std::vector<nlohmann::json> prob_distr;
197  };
198 
199  protected:
200  struct UnitData {
201  int32_t lastSeenTask = -1;
202  int32_t lastSeenType = -1;
203  };
204  struct TreeData {
205  nlohmann::json metadata;
206  std::shared_ptr<TreeNode> graph;
207  std::vector<std::shared_ptr<TreeNode>> allNodes;
208  };
209  struct TraceData {
210  // We assign IDs to tasks so we can store them only once
211  std::unordered_map<std::shared_ptr<Task>, int32_t> taskToId_;
212  std::vector<nlohmann::json> tasks_;
213 
214  // Blackboard
215  std::unordered_map<std::string, std::string> boardKnownValues_;
216  nlohmann::json boardUpdates_;
217 
218  // Units
219  std::unordered_map<UnitId, UnitData> unitsInfos_;
220  std::unordered_map<std::string /* unit_id */, nlohmann::json> unitsUpdates_;
221  std::unordered_map<std::string /* frame_id */, std::vector<nlohmann::json>>
223 
224  // Logs
226  std::unordered_map<std::string /* unit_id */, Logger> unitsLogs_;
227 
228  // Draw commands
229  std::unordered_map<std::string /* frame_id */, std::vector<nlohmann::json>>
231 
232  // Graphs
233  std::unordered_map<std::string /* filename */, TreeData> trees_;
234  std::vector<nlohmann::json> treesMetadata_;
235 
236  // Tensors
237  std::unordered_map<
238  std::string /* filename */,
239  std::unordered_map<std::string /* frame */, nlohmann::json>>
241  std::unordered_map<std::string /* name */, std::string /* filename */>
243  std::vector<nlohmann::json> heatmapsMetadata_;
244  std::unordered_map<std::string /* frame_id */, std::vector<nlohmann::json>>
246 
247  // Game values
248  std::unordered_map<std::string /* frame_id */, nlohmann::json> gameValues_;
249  };
250 
251  int32_t getUnitTaskId(State* s, Unit* unit);
252  std::string getBoardValueAsString(Blackboard::Data const& value);
253  void dumpGameUpcs(State* s);
254  void writeTrees(std::string const& dumpDirectory);
255  void writeTensors(std::string const& dumpDirectory);
256  nlohmann::json getTensorSummary(std::string const& name, at::Tensor const& t);
257  nlohmann::json getTensor1d(at::Tensor const& t);
258 
259  std::string replayFileName_;
261  std::unique_ptr<CherryVisLogSink> logSink_;
262 };
263 
264 class CherryVisLogSink : public google::LogSink {
265  public:
267  : module_(module), state_(state), threadId_(std::this_thread::get_id()) {
268  google::AddLogSink(this);
269  }
270  virtual ~CherryVisLogSink() {
271  google::RemoveLogSink(this);
272  }
273 
274  virtual void send(
275  google::LogSeverity severity,
276  const char* full_filename,
277  const char* base_filename,
278  int line,
279  const struct ::tm* tm_time,
280  const char* message,
281  size_t message_len) override {
282  // TODO: In case of self-play, the opponents share the same thread_id
283  if (threadId_ == std::this_thread::get_id()) {
284  module_->getGlobalLogger().addMessage(
285  state_,
286  std::string(message, message_len),
287  {},
288  full_filename,
289  line,
290  severity);
291  }
292  }
293 
294  protected:
297  std::thread::id threadId_;
298 };
299 
300 template <typename T, typename V, typename W>
302  State* s,
303  std::string const& name,
304  T dump_node,
305  V get_children,
306  W root) {
307  std::string filename = "tree__" + std::to_string(trace_.trees_.size()) +
308  "__f" + std::to_string(s->currentFrame()) + ".json.zstd";
309  TreeData& g = trace_.trees_[filename];
310  g.graph = std::make_shared<TreeNode>();
311  uint32_t nodesCount = 0;
312  std::vector<std::pair<W /* node */, std::shared_ptr<TreeNode>>> todo;
313 
314  // Process root
315  dump_node(root, g.graph);
316  ++nodesCount;
317  g.allNodes.push_back(g.graph);
318  std::vector<W> childs = get_children(root);
319  for (auto c : childs) {
320  todo.push_back(std::make_pair(c, g.graph));
321  }
322 
323  // Process queue
324  while (!todo.empty()) {
325  auto doing = todo.back();
326  todo.pop_back();
327 
328  // Node
329  doing.second->children.emplace_back(std::make_shared<TreeNode>());
330  dump_node(doing.first, doing.second->children.back());
331  ++nodesCount;
332  g.allNodes.push_back(doing.second->children.back());
333  std::vector<W> childs = get_children(doing.first);
334  for (auto c : childs) {
335  todo.push_back(std::make_pair(c, doing.second->children.back()));
336  }
337  }
338  g.metadata = {
339  {"frame", s->currentFrame()},
340  {"name", name},
341  {"nodes", nodesCount},
342  {"filename", filename},
343  };
344  trace_.treesMetadata_.push_back(g.metadata);
345 }
346 
347 void to_json(nlohmann::json& json, CherryVisDumperModule::Logger const& logger);
348 void to_json(
349  nlohmann::json& json,
350  CherryVisDumperModule::Dumpable const& logger);
351 void to_json(nlohmann::json& json, Unit const* unit);
352 void to_json(nlohmann::json& json, Position const& p);
353 } // namespace cherrypi
Game state.
Definition: state.h:42
std::unordered_map< std::string, nlohmann::json > gameValues_
Definition: cherryvisdumper.h:248
std::vector< nlohmann::json > prob_distr
Definition: cherryvisdumper.h:196
int FrameNum
Definition: basetypes.h:22
LoggerStream getStream(State *state, const char *filename, int line)
Definition: cherryvisdumper.h:96
static constexpr const char * kName
Definition: cherryvisdumper.h:45
Definition: cherryvisdumper.h:200
std::unordered_map< std::string, std::unordered_map< std::string, nlohmann::json > > tensors_
Definition: cherryvisdumper.h:240
std::shared_ptr< TreeNode > graph
Definition: cherryvisdumper.h:206
std::unordered_map< std::string, std::vector< nlohmann::json > > unitsFirstSeen_
Definition: cherryvisdumper.h:222
virtual void step(State *s) override
Definition: cherryvisdumper.cpp:34
Definition: cherryvisdumper.h:264
Logger & getUnitLogger(Unit *u)
Definition: cherryvisdumper.h:130
void dumpGameUpcs(State *s)
Definition: cherryvisdumper.cpp:297
TraceData trace_
Definition: cherryvisdumper.h:260
FrameNum currentFrame() const
Definition: state.h:57
std::unordered_map< std::string, std::vector< nlohmann::json > > tensorsSummary_
Definition: cherryvisdumper.h:245
int32_t getUnitTaskId(State *s, Unit *unit)
Definition: cherryvisdumper.cpp:212
void addUnitWithProb(Unit *unit, float proba)
Definition: cherryvisdumper.h:179
Definition: cherryvisdumper.h:209
std::unordered_map< std::string, std::string > tensorNameToFile_
Definition: cherryvisdumper.h:242
std::unordered_map< std::string, nlohmann::json > unitsUpdates_
Definition: cherryvisdumper.h:220
void dumpTerrainHeatmaps(State *s, std::unordered_map< std::string, ag::Variant > const &tensors, std::array< int, 2 > const &topLeftPixel, std::array< float, 2 > const &scalingToPixels)
Dumps a heatmap that can be displayed as overlay of the terrain.
Definition: cherryvisdumper.cpp:416
const char * filename_
Definition: cherryvisdumper.h:92
nlohmann::json to_json()
Definition: cherryvisdumper.h:184
CherryVisDumperModule * module_
Definition: cherryvisdumper.h:295
STL namespace.
State * state_
Definition: cherryvisdumper.h:296
void addTree(State *s, std::string const &name, T dump_node, V get_children, W root)
Definition: cherryvisdumper.h:301
std::unordered_map< std::string, std::string > boardKnownValues_
Definition: cherryvisdumper.h:215
std::vector< std::shared_ptr< TreeNode > > allNodes
Definition: cherryvisdumper.h:207
std::unordered_map< std::shared_ptr< Task >, int32_t > taskToId_
Definition: cherryvisdumper.h:211
static void writeGameSummary(State *final_state, std::string const &file)
Definition: cherryvisdumper.cpp:232
std::variant< Unit *, Position, int, float, std::string > Dumpable
Definition: cherryvisdumper.h:39
void onDrawCommand(State *s, tc::Client::Command const &command)
Definition: cherryvisdumper.cpp:189
std::unordered_map< std::string, Logger > unitsLogs_
Definition: cherryvisdumper.h:226
LoggerStream(State *state, Logger &logger, const char *filename, int line)
Definition: cherryvisdumper.h:62
virtual void send(google::LogSeverity severity, const char *full_filename, const char *base_filename, int line, const struct::tm *tm_time, const char *message, size_t message_len) override
Definition: cherryvisdumper.h:274
Definition: cherryvisdumper.h:165
std::string name()
Definition: module.cpp:41
nlohmann::json boardUpdates_
Definition: cherryvisdumper.h:216
virtual void onGameStart(State *s) override
Definition: cherryvisdumper.cpp:87
std::stringstream message_
Definition: cherryvisdumper.h:88
int line_
Definition: cherryvisdumper.h:93
mapbox::util::variant< bool, int, float, double, std::string, Position, std::shared_ptr< SharedController >, std::unordered_map< int, int >> Data
A variant of types that are allowed in the Blackboard's key-value storage.
Definition: blackboard.h:99
UnitId id
Definition: unitsinfo.h:36
virtual ~CherryVisDumperModule()=default
Represents a unit in the game.
Definition: unitsinfo.h:35
void setId(int32_t id, std::string prefix="")
Definition: cherryvisdumper.h:175
std::string getBoardValueAsString(Blackboard::Data const &value)
Definition: cherryvisdumper.cpp:292
std::vector< nlohmann::json > treesMetadata_
Definition: cherryvisdumper.h:234
nlohmann::json to_json() const
Definition: cherryvisdumper.h:108
void setReplayFile(std::string const &replayFile)
Definition: cherryvisdumper.h:118
void setModule(std::string name)
Definition: cherryvisdumper.h:169
std::vector< nlohmann::json > logs_
Definition: cherryvisdumper.h:113
~LoggerStream()
Definition: cherryvisdumper.h:64
void dumpGameValue(State *s, std::string const &key, float value)
Definition: cherryvisdumper.h:158
void dumpTensorsSummary(State *s, std::unordered_map< std::string, ag::Variant > const &tensors)
Definition: cherryvisdumper.cpp:397
Records bot internal state and dumps it to a file readable by CherryVis.
Definition: cherryvisdumper.h:36
void writeTrees(std::string const &dumpDirectory)
Definition: cherryvisdumper.cpp:330
State * state_
Definition: cherryvisdumper.h:90
Logger & getGlobalLogger()
Definition: cherryvisdumper.h:126
std::unique_ptr< CherryVisLogSink > logSink_
Definition: cherryvisdumper.h:261
Logger logs_
Definition: cherryvisdumper.h:225
static std::string parseReplayFileName(std::string name)
Definition: cherryvisdumper.cpp:96
std::thread::id threadId_
Definition: cherryvisdumper.h:297
std::unordered_map< std::string, std::vector< nlohmann::json > > drawCommands_
Definition: cherryvisdumper.h:230
Logger & logger_
Definition: cherryvisdumper.h:91
std::unordered_map< UnitId, UnitData > unitsInfos_
Definition: cherryvisdumper.h:219
nlohmann::json metadata
Definition: cherryvisdumper.h:205
virtual ~CherryVisLogSink()
Definition: cherryvisdumper.h:270
std::vector< nlohmann::json > tasks_
Definition: cherryvisdumper.h:212
std::vector< std::shared_ptr< TreeNode > > children
Definition: cherryvisdumper.h:167
void setFrame(FrameNum f)
Definition: cherryvisdumper.h:172
CherryVisLogSink(CherryVisDumperModule *module, State *state)
Definition: cherryvisdumper.h:266
Main namespace for bot-related code.
Definition: areainfo.cpp:17
nlohmann::json getTensor1d(at::Tensor const &t)
Definition: cherryvisdumper.cpp:387
std::vector< nlohmann::json > heatmapsMetadata_
Definition: cherryvisdumper.h:243
nlohmann::json getTensorSummary(std::string const &name, at::Tensor const &t)
Definition: cherryvisdumper.cpp:355
std::string replayFileName_
Definition: cherryvisdumper.h:259
void addMessage(State *state, std::string message, std::vector< nlohmann::json > attachments={}, const char *full_filename="", int line=0, google::LogSeverity severity=0)
Definition: cherryvisdumper.cpp:478
std::unordered_map< std::string, nlohmann::json > node
Definition: cherryvisdumper.h:195
void writeTensors(std::string const &dumpDirectory)
Definition: cherryvisdumper.cpp:469
Definition: cherryvisdumper.h:41
std::vector< nlohmann::json > attachments_
Definition: cherryvisdumper.h:89
Interface for bot modules.
Definition: module.h:30
virtual void onGameEnd(State *s) override
Definition: cherryvisdumper.cpp:130
std::unordered_map< std::string, TreeData > trees_
Definition: cherryvisdumper.h:233
LoggerStream & operator<<(T const &m)
Definition: cherryvisdumper.h:70
Definition: cherryvisdumper.h:204