TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
forkserver.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
#include <cstdlib>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <cereal/archives/binary.hpp>
18 
19 namespace cherrypi {
20 
/// One environment-variable entry, as passed (in a vector) to
/// ForkServer::execute. Serializable with cereal through serialize().
struct EnvVar {
  std::string key; ///< Variable name.
  std::string value; ///< Value to assign to `key`.
  /// NOTE(review): presumably mirrors setenv(3)'s `overwrite` flag (replace an
  /// already-set variable only when true) — confirm in forkserver.cpp.
  bool overwrite{false};

  /// cereal hook: (de)serializes key, value and overwrite, in that order.
  template <typename Ar>
  void serialize(Ar& ar) {
    ar(key, value, overwrite);
  }
};
30 
32  public:
 /// NOTE(review): presumably seeds environ_ from the current process
 /// environment when copyEnv is true — confirm in forkserver.cpp.
 /// Not `explicit`, so a bool converts implicitly to an EnvironmentBuilder.
33  EnvironmentBuilder(bool copyEnv = true);
35 
 /// Records `name` = `value`.
 /// NOTE(review): presumably mirrors setenv(3), i.e. an existing entry is
 /// replaced only when `overwrite` is true — confirm in forkserver.cpp.
36  void setenv(
37  std::string const& name,
38  std::string const& value,
39  bool overwrite = false);
 /// Returns the environment as a `char*` array (presumably a NULL-terminated,
 /// envp-style array stored in env_ — confirm).
 /// NOTE(review): the trailing top-level `const` on the return type has no
 /// effect (and warns under -Wignored-qualifiers); dropping it must be done
 /// together with the out-of-line definition, or the declarations mismatch.
40  char* const* const getEnv();
41 
42  private:
 // Variable name -> value map backing this builder.
43  std::unordered_map<std::string, std::string> environ_;
 // Raw array handed out by getEnv(); nullptr until it is built.
 // NOTE(review): raw owning pointer — make sure copying an instance cannot
 // lead to two objects releasing the same array (delete/define copy ops).
44  char** env_ = nullptr;
45 
 // Presumably releases env_ and the strings it points to — confirm.
46  void freeEnv();
47 };
48 
/**
 * File descriptors passed as arguments to ForkServer::fork *must* be wrapped
 * in this class. The actual file descriptor number will not be the same in
 * the forked process: ForkServer::fork transfers wrapped descriptors with
 * sendfd over a socket instead of serializing their numeric value.
 */
class FileDescriptor {
 public:
  /// Wraps descriptor `fd`. Intentionally implicit: on the receiving side the
  /// descriptor arrives as a plain int and converts back into this wrapper.
  FileDescriptor(int fd) : fd_(fd) {}

  /// Returns the wrapped descriptor number. `const` so that const wrappers
  /// (e.g. a `const FileDescriptor&`) can be converted as well.
  operator int() const {
    return fd_;
  }

 private:
  int fd_;
};
64 
65 /**
66  * This class lets us fork when using MPI.
67  * You must call ForkServer::startForkServer() before mpi/gloo is initialized,
68  * and before any threads are created. The best place to call it is in main,
69  * after parsing command line arguments, and before cherrypi::init().
70  * Example usage:
71  *
72  * ForkServer::startForkServer();
73  * int pid = ForkServer::instance().fork([](std::string str) {
74  * VLOG(0) << "This is a new process with message: " << str;
75  * }, std::string("hello world"));
76  * ForkServer::instance().waitpid(pid);
77  *
78  */
79 class ForkServer {
80  public:
81  ForkServer();
82  ~ForkServer();
83 
84  static ForkServer& instance();
85  static void startForkServer();
86  static void endForkServer();
87 
88  /// Execute command with environment.
89  /// returns rfd, wfd, pid, where rfd and wfd is the read and write descriptor
90  /// to stdout of the new process.
91  std::tuple<int, int, int> execute(
92  std::vector<std::string> const& command,
93  std::vector<EnvVar> const& env);
94 
95  /// fork and call f with the specified arguments. f must be trivially
96  /// copyable, args must be cereal serializable.
97  /// You should not pass any pointers or references (either through
98  /// argument or lambda capture), except to globals, since they will
99  /// not be valid in the new process.
100  /// There are no restrictions on what code can be executed in the function,
101  /// but keep in mind that it runs in a new process with a single thread, as-if
102  /// running from the point in the program where startForkServer was called.
103  /// It is highly recommended to call waitpid at some point with the returned
104  /// pid, because linux requires it in order to reap children and avoid a
105  /// defunct process for every fork.
106  /// returns pid
107  template <typename F, typename... Args>
108  int fork(F&& f, Args&&... args) {
109  static_assert(
110  std::is_trivially_copyable<F>::value, "f must be trivially copyable!");
111  std::lock_guard<std::mutex> lock(mutex_);
112  std::stringstream oss;
113  cereal::BinaryOutputArchive ar(oss);
114  std::vector<int> (*ptrReadFds)(int sock) = &forkReadFds<F, Args...>;
115  ar.saveBinary(&ptrReadFds, sizeof(ptrReadFds));
116  void (*ptr)(cereal::BinaryInputArchive & ar, const std::vector<int>& fds) =
117  &forkEntry<F, Args...>;
118  ar.saveBinary(&ptr, sizeof(ptr));
119  ar.saveBinary(&f, sizeof(f));
120  forkSerialize(ar, std::forward<Args>(args)...);
121  return forkSendCommand(oss.str());
122  }
123 
124  // Blocks and waits until pid exits. Linux will not release process resources
125  // until either this is called or the parent exits.
126  int waitpid(int pid);
127 
128  private:
129  std::mutex mutex_;
130  int forkServerRFd_ = -1; // Socket to read data from the fork server
131  int forkServerWFd_ = -1; // Socket to send data to the fork server
132  int forkServerSock_ = -1; // UNIX domain socket to recv file descriptors
133 
134  static void sendfd(int sock, int fd);
135  static int recvfd(int sock);
136 
137  int forkSendCommand(const std::string& data);
138 
139  template <typename T>
140  void forkSerialize(T& ar) {}
141  template <typename T, typename A, typename... Args>
142  void forkSerialize(T& ar, A&& a, Args&&... args) {
143  ar(std::forward<A>(a));
144  forkSerialize(ar, std::forward<Args>(args)...);
145  }
146  template <typename T, typename... Args>
147  void forkSerialize(T& ar, FileDescriptor fd, Args&&... args) {
148  sendfd(forkServerSock_, (int)fd);
149  forkSerialize(ar, std::forward<Args>(args)...);
150  }
151 
152  template <typename T>
153  struct typeResolver {};
154 
155  template <typename T>
156  static T forkDeserialize(
157  typeResolver<T>,
158  cereal::BinaryInputArchive& ar,
159  const std::vector<int>& fds,
160  size_t& fdsIndex) {
161  T r;
162  ar(r);
163  return r;
164  }
165  static int forkDeserialize(
166  typeResolver<FileDescriptor>,
167  cereal::BinaryInputArchive& ar,
168  const std::vector<int>& fds,
169  size_t& fdsIndex) {
170  return fds.at(fdsIndex++);
171  }
172 
173  template <typename T>
174  static int
175  forkDeserializeFds(typeResolver<T>, int sock, std::vector<int>& result) {
176  return 0;
177  }
178  static int forkDeserializeFds(
179  typeResolver<FileDescriptor>,
180  int sock,
181  std::vector<int>& result) {
182  result.push_back(recvfd(sock));
183  return 0;
184  }
185 
186  template <typename F, typename... Args>
187  static std::vector<int> forkReadFds(int sock) {
188  std::vector<int> r;
189  // Force evaluation order with {}
190  auto x = {
191  forkDeserializeFds(typeResolver<std::decay_t<Args>>{}, sock, r)...};
192  (void)x;
193  return r;
194  }
195 
196  template <typename F, typename Tuple, size_t... I>
197  static void applyImpl(F&& f, Tuple tuple, std::index_sequence<I...>) {
198  std::forward<F>(f)(std::move(std::get<I>(tuple))...);
199  }
200 
201  template <typename F, typename Tuple>
202  static void apply(F&& f, Tuple tuple) {
203  applyImpl(
204  std::forward<F>(f),
205  std::move(tuple),
206  std::make_index_sequence<std::tuple_size<Tuple>::value>{});
207  }
208 
209  template <typename F, typename... Args>
210  static void forkEntry(
211  cereal::BinaryInputArchive& ar,
212  const std::vector<int>& fds) {
213  typename std::aligned_storage<sizeof(F), alignof(F)>::type buf;
214  ar.loadBinary(&buf, sizeof(buf));
215  F& f = (F&)buf;
216  size_t fdsIndex = 0;
217  apply(
218  f,
219  std::tuple<std::decay_t<Args>...>{forkDeserialize(
220  typeResolver<std::decay_t<Args>>{}, ar, fds, fdsIndex)...});
221  std::_Exit(0);
222  }
223 };
224 } // namespace cherrypi
std::string value
Definition: forkserver.h:23
void serialize(Archive &archive)
Definition: forkserver.h:26
FileDescriptor(int fd)
Definition: forkserver.h:56
Definition: forkserver.h:31
This class lets us fork when using MPI.
Definition: forkserver.h:79
Definition: forkserver.h:21
File descriptors passed as arguments to ForkServer::fork must be wrapped in this class.
Definition: forkserver.h:54
int fork(F &&f, Args &&...args)
fork and call f with the specified arguments.
Definition: forkserver.h:108
std::string key
Definition: forkserver.h:22
int recvfd(int socket)
Definition: forkserver.cpp:157
Main namespace for bot-related code.
Definition: areainfo.cpp:17
bool overwrite
Definition: forkserver.h:24
void sendfd(int socket, int fd)
Definition: forkserver.cpp:127