TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
replayer.h
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree. An additional grant
7  * of patent rights can be found in the PATENTS file in the same directory.
8  */
9 
10 #pragma once
11 
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <torchcraft/frame.h>
#include <torchcraft/refcount.h>
#include <torchcraft/state.h>
20 
21 namespace torchcraft {
22 namespace replayer {
23 
/// Plain aggregate holding the raw map blob stored in a replay together
/// with its dimensions. Member order (height, width, data) is part of the
/// interface because the struct can be aggregate-initialised.
struct Map {
  uint32_t height; ///< Map height, as passed to Replayer::setRawMap().
  uint32_t width; ///< Map width, as passed to Replayer::setRawMap().
  std::vector<uint8_t> data; ///< Raw map bytes; layout is defined by the writer.
};
28 
29 class Replayer : public RefCounted {
30  private:
31  std::vector<Frame*> frames;
32  std::unordered_map<int32_t, int32_t> numUnits;
33  Map map;
34  // If keyframe = 0, every frame is a frame.
35  // Otherwise, every keyframe is a frame, and all others are diffs.
36  // Only affects saving/loading (replays are still large in memory)
37  uint32_t keyframe;
38 
39  public:
41  for (auto f : frames) {
42  if (f)
43  f->decref();
44  }
45  }
46 
47  Frame* getFrame(size_t i) {
48  if (i >= frames.size())
49  return nullptr;
50  return frames[i];
51  }
52  void push(Frame* f) {
53  auto new_frame = new Frame(f);
54  frames.push_back(new_frame);
55  }
56  void setKeyFrame(int32_t x) {
57  keyframe = x < 0 ? frames.size() + 1 : (uint32_t)x;
58  }
59  uint32_t getKeyFrame() {
60  return keyframe;
61  }
62  size_t size() const {
63  return frames.size();
64  }
65  int32_t mapHeight() const {
66  return map.height;
67  }
68  int32_t mapWidth() const {
69  return map.width;
70  }
71 
72  void setNumUnits() {
73  for (const auto f : frames) {
74  for (auto u : f->units) {
75  auto s = u.second.size();
76  auto i = u.first;
77  if (numUnits.count(i) == 0) {
78  numUnits[i] = s;
79  } else if (s > static_cast<size_t>(numUnits[i])) {
80  numUnits[i] = s;
81  }
82  }
83  }
84  }
85 
86  int32_t getNumUnits(const int32_t& key) const {
87  if (numUnits.find(key) == numUnits.end())
88  return -1;
89  return numUnits.at(key);
90  }
91 
92  void setMapFromState(torchcraft::State const* state);
93 
94  void setMap(
95  int32_t h,
96  int32_t w,
97  std::vector<uint8_t> const& walkability,
98  std::vector<uint8_t> const& ground_height,
99  std::vector<uint8_t> const& buildability,
100  std::vector<int> const& start_loc_x,
101  std::vector<int> const& start_loc_y);
102 
103  void setMap(
104  int32_t h,
105  int32_t w,
106  uint8_t const* const walkability,
107  uint8_t const* const ground_height,
108  uint8_t const* const buildability,
109  std::vector<int> const& start_loc_x,
110  std::vector<int> const& start_loc_y);
111 
112  void setRawMap(uint32_t h, uint32_t w, uint8_t const* d) {
113  // free existing map if needed
114  map.data.resize(h * w);
115  map.data.assign(d, d + h * w);
116  map.height = h;
117  map.width = w;
118  }
119 
120  const std::vector<uint8_t>& getRawMap() {
121  return map.data;
122  }
123 
124  std::pair<int32_t, int32_t> getMap(
125  std::vector<uint8_t>& walkability,
126  std::vector<uint8_t>& ground_height,
127  std::vector<uint8_t>& buildability,
128  std::vector<int>& start_loc_x,
129  std::vector<int>& start_loc_y) const;
130 
131  friend std::ostream& operator<<(std::ostream& out, const Replayer& o);
132  friend std::istream& operator>>(std::istream& in, Replayer& o);
133 
134  void load(const std::string& path);
135  void save(const std::string& path, bool compressed = false);
136 };
137 
138 } // namespace replayer
139 } // namespace torchcraft
void setNumUnits()
Definition: replayer.h:72
size_t size() const
Definition: replayer.h:62
Copyright (c) 2015-present, Facebook, Inc.
Definition: openbwprocess.h:17
int32_t mapWidth() const
Definition: replayer.h:68
Definition: frame.h:306
int32_t getNumUnits(const int32_t &key) const
Definition: replayer.h:86
uint32_t getKeyFrame()
Definition: replayer.h:59
std::istream & operator>>(std::istream &in, Frame &o)
void setRawMap(uint32_t h, uint32_t w, uint8_t const *d)
Definition: replayer.h:112
std::ostream & operator<<(std::ostream &out, const Frame &o)
const std::vector< uint8_t > & getRawMap()
Definition: replayer.h:120
replayer::Frame Frame
Definition: state.h:41
Copyright (c) 2015-present, Facebook, Inc.
Definition: refcount.h:23
std::vector< uint8_t > data
Definition: replayer.h:26
~Replayer()
Definition: replayer.h:40
Frame * getFrame(size_t i)
Definition: replayer.h:47
Definition: state.h:43
void push(Frame *f)
Definition: replayer.h:52
Definition: replayer.h:29
uint32_t height
Definition: replayer.h:25
int32_t mapHeight() const
Definition: replayer.h:65
uint32_t width
Definition: replayer.h:25
void setKeyFrame(int32_t x)
Definition: replayer.h:56
Definition: replayer.h:24