TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
distributed.h
/*
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <autogradpp/autograd.h>
#include <gflags/gflags.h>
#include <memory>

#include "c10d.h"

// Distributed logging utilities.
#define VLOG_MASTER(lvl) \
  VLOG_IF(lvl, cpid::distributed::globalContext()->rank == 0)
#define VLOG_ALL(lvl) \
  VLOG(lvl) << "w" << cpid::distributed::globalContext()->rank << ": "

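// Usage sketch (illustrative only; assumes glog verbose logging is enabled):
//   VLOG_MASTER(1) << "logged by rank 0 only";
//   VLOG_ALL(1) << "logged by every rank, prefixed with 'w<rank>: '";
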
namespace cpid {
namespace distributed {

using namespace ::c10d;
/**
 * This is a wrapper around ProcessGroup::Work.
 * We wait for the work to be finished on destruction, to ensure the transfer
 * completes. Additionally, this subclass provides support for waiting on
 * multiple pieces of work to finish, in the case of syncing entire models.
 *
 * The comments for this class are copied from ProcessGroup::Work and may be
 * out of date when PyTorch updates...
 **/
class Context;
class Work {
 public:
  Work(std::function<void()> onFinish);
  Work(Work const&) = delete;
  Work(Work&&);
  ~Work() noexcept(false);

  // Checks if the request has completed. Non-blocking operation.
  bool isCompleted();

  // Returns whether the work completed successfully.
  // If false, exception() can be called to get details.
  bool isSuccess();

  // Ensures that operations on the output tensors that are invoked
  // after this function returns are correctly sequenced after the
  // asynchronous completion of this work.
  //
  // For CUDA tensors, it inserts stream synchronization such that
  // the streams of the caller wait for completion of the
  // asynchronous operations on the destination tensors.
  //
  // For CPU tensors, it is currently a nop.
  //
  // This function should only be used if the caller polls for completion
  // through the `isCompleted` function, that call has returned true, and
  // the `isSuccess` function has also returned true.
  //
  void synchronize();

  // Waits until the request completes. Blocking operation.
  // Throws if the work completed with an exception.
  //
  // Functionally equivalent to:
  //
  //   while (!isCompleted()) { /* nop */ }
  //   auto success = isSuccess();
  //   if (!success) { std::rethrow_exception(exception()); }
  //   return success;
  //
  void wait();

  // Returns the exception encountered, if the work did not complete
  // successfully. This will return the first exception encountered.
  const std::exception_ptr exception() const;

 private:
  Work() = default;
  Work(std::vector<std::shared_ptr<ProcessGroup::Work>>&&);
  void add(std::shared_ptr<ProcessGroup::Work>);
  void add(Work&&);

  std::vector<std::shared_ptr<ProcessGroup::Work>> works_;
  std::function<void()> onFinish_ = nullptr;
  friend class Context;
};
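
// Illustrative sketch (comment-only, not part of the API): overlapping local
// computation with an in-flight reduction by polling a Work object. `t` is a
// placeholder for any tensor handed to an operation such as allreduce().
//
//   auto work = globalContext()->allreduce(t);
//   while (!work.isCompleted()) {
//     // ... do useful local work while the transfer is in flight ...
//   }
//   if (!work.isSuccess()) {
//     std::rethrow_exception(work.exception());
//   }
//   work.synchronize(); // sequence subsequent CUDA ops after the transfer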

// Let's provide some type-safety. We can only send types that have a
// torch::Dtype.
template <typename T>
using IsTorchDType = typename std::enable_if_t<
    std::is_same<T, uint8_t>::value || // Byte
    std::is_same<T, char>::value || // Char
    std::is_same<T, int8_t>::value || // Char
    std::is_same<T, int16_t>::value || // Short
    std::is_same<T, int32_t>::value || // Int
    std::is_same<T, int64_t>::value || // Long
    std::is_same<T, float>::value || std::is_same<T, double>::value>;

// The Context holds two instantiations of the C10D ProcessGroup and
// automatically routes tensors to NCCL or Gloo, depending on whether they
// live on GPU or CPU.
class Context {
 public:
  Context(
      std::shared_ptr<Store> store,
      int rank,
      int size,
      std::chrono::milliseconds timeout = std::chrono::seconds(1000));

  int rank;
  int size;

  template <typename T, IsTorchDType<T>* = nullptr>
  Work allreduce(T* ptr, int64_t s, ReduceOp = ReduceOp::SUM);
  template <typename T, IsTorchDType<T>* = nullptr>
  Work allreduce(std::vector<T>& v, ReduceOp = ReduceOp::SUM);
  Work allreduce(torch::Tensor, ReduceOp = ReduceOp::SUM);
  Work allreduceGradients(ag::Container const&, ReduceOp = ReduceOp::SUM);

  template <typename T, IsTorchDType<T>* = nullptr>
  Work broadcast(T* ptr, int64_t s, int root = 0);
  template <typename T, IsTorchDType<T>* = nullptr>
  Work broadcast(std::vector<T>& v, int root = 0);
  Work broadcast(torch::Tensor, int root = 0);
  Work broadcast(ag::Container const&, int root = 0);

  template <typename T, IsTorchDType<T>* = nullptr>
  Work allgather(T* out, T* in, int64_t s);
  template <typename T, IsTorchDType<T>* = nullptr>
  Work allgather(T* out, torch::Tensor in);
  Work allgather(torch::Tensor, torch::Tensor);

  Work barrier();

 private:
  std::shared_ptr<ProcessGroup> glooPG_;
  std::shared_ptr<ProcessGroup> ncclPG_;
  std::shared_ptr<ProcessGroup> devicePG(torch::Tensor x);
};
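
// Illustrative sketch (comment-only): the same call works for CPU and CUDA
// tensors; the context routes CPU tensors through the Gloo process group and
// CUDA tensors through the NCCL one. `cpuTensor`/`cudaTensor` are placeholders.
//
//   auto ctx = globalContext();
//   ctx->allreduce(cpuTensor).wait();  // Gloo
//   ctx->allreduce(cudaTensor).wait(); // NCCL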

// Here are some functions that will automatically use the global context.
template <typename T, IsTorchDType<T>* = nullptr>
Work allreduce(T* ptr, int64_t s, ReduceOp = ReduceOp::SUM);
template <typename T, IsTorchDType<T>* = nullptr>
Work allreduce(std::vector<T>& v, ReduceOp = ReduceOp::SUM);
Work allreduce(torch::Tensor, ReduceOp = ReduceOp::SUM);
Work allreduceGradients(ag::Container const&, ReduceOp = ReduceOp::SUM);

template <typename T, IsTorchDType<T>* = nullptr>
Work broadcast(T* ptr, int64_t s, int root = 0);
template <typename T, IsTorchDType<T>* = nullptr>
Work broadcast(std::vector<T>& v, int root = 0);
Work broadcast(torch::Tensor, int root = 0);
Work broadcast(ag::Container const&, int root = 0);

template <typename T, IsTorchDType<T>* = nullptr>
Work allgather(T* out, T* in, int64_t s);
template <typename T, IsTorchDType<T>* = nullptr>
Work allgather(T* out, torch::Tensor in);
Work allgather(torch::Tensor, torch::Tensor);

Work barrier();

void init();
/// Sets the CUDA device to the local (if available) or MPI rank, both modulo
/// the number of available devices. Does nothing when CUDA is not available.
/// init() already calls setGPUToLocalRank(), but since the result is
/// thread-local, it must also be called from any thread that is spawned.
void setGPUToLocalRank();
std::shared_ptr<Context> globalContext();
} // namespace distributed
} // namespace cpid
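
Below is a minimal usage sketch, not part of distributed.h itself. It assumes
the process was launched so that init() can rendezvous with its peers (for
example via MPI or the training setup's environment variables); runWorker, the
elided model-building and forward/backward code, and the include path are
placeholders for illustration.

#include "distributed.h" // adjust the include path to your build setup

#include <glog/logging.h>

#include <thread>
#include <vector>

namespace dist = cpid::distributed;

void runWorker(ag::Container const& model) {
  // One-time setup per process; init() also calls setGPUToLocalRank() for the
  // calling thread.
  dist::init();

  // The CUDA device selection is thread-local, so spawned threads re-apply it.
  std::thread worker([] {
    dist::setGPUToLocalRank();
    // ... per-thread work now targets this worker's GPU ...
  });
  worker.join();

  // Make every rank start from rank 0's parameters. The returned Work is
  // waited upon explicitly here; its destructor would wait as well.
  dist::broadcast(model, /*root=*/0).wait();

  // ... forward/backward on a local batch ...

  // Sum gradients across workers before the optimizer step; divide by
  // globalContext()->size afterwards to average them.
  dist::allreduceGradients(model).wait();

  // Plain buffers and vectors can be reduced, too.
  std::vector<float> stats = {0.0f};
  dist::allreduce(stats).wait();

  VLOG_MASTER(0) << "aggregated stat: " << stats[0];
  dist::barrier().wait();
}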