#include <autogradpp/autograd.h>
#include <gflags/gflags.h>
// (further includes elided in this excerpt)

#define VLOG_MASTER(lvl) \
  VLOG_IF(lvl, cpid::distributed::globalContext()->rank == 0)
#define VLOG_ALL(lvl) \
  VLOG(lvl) << "w" << cpid::distributed::globalContext()->rank << ": "

/// The TorchCraftAI training library.
namespace cpid {
namespace distributed {
using namespace ::c10d;
class Work {
 public:
  Work(std::function<void()> onFinish);
  ~Work() noexcept(false);

  bool isCompleted();
  bool isSuccess();
  void synchronize();
  void wait();
  const std::exception_ptr exception() const;

 private:
  friend class Context;

  Work(std::vector<std::shared_ptr<ProcessGroup::Work>>&&);
  void add(std::shared_ptr<ProcessGroup::Work>);

  std::vector<std::shared_ptr<ProcessGroup::Work>> works_;
  std::function<void()> onFinish_ = nullptr;
};
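/// Usage sketch (hypothetical example, assuming a Context `ctx` and a
/// std::vector<float> buffer `buf`): collectives return a Work handle that
/// the caller waits on. Since ~Work() is noexcept(false), an error from an
/// unfinished operation may also surface as an exception on destruction.
/// \code
///   auto work = ctx->allreduce(buf.data(), buf.size(), ReduceOp::SUM);
///   work.wait();
///   if (auto eptr = work.exception()) {
///     std::rethrow_exception(eptr);
///   }
/// \endcode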
template <typename T>
using IsTorchDType = typename std::enable_if_t<
    std::is_same<T, uint8_t>::value || std::is_same<T, char>::value ||
    std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
    std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
    std::is_same<T, float>::value || std::is_same<T, double>::value>;
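/// Illustration (not from this header) of how IsTorchDType gates the
/// collectives below: for supported scalar types the alias names void, so
/// the extra template parameter compiles away; for any other type the
/// enable_if substitution fails and the overload is discarded. The helper
/// name here is hypothetical.
/// \code
///   template <typename T, IsTorchDType<T>* = nullptr>
///   void onlyForTorchScalars(T* /*buf*/) {}
///
///   float f[8];
///   onlyForTorchScalars(f);     // OK: float is a supported dtype
///   // std::string s;
///   // onlyForTorchScalars(&s); // error: no matching overload
/// \endcode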
class Context {
 public:
  Context(
      std::shared_ptr<Store> store,
      int rank,
      int size,
      std::chrono::milliseconds timeout = std::chrono::seconds(1000));

  int rank;
  int size;
  template <typename T, IsTorchDType<T>* = nullptr>
  Work allreduce(T* ptr, int64_t s, ReduceOp op);
  template <typename T, IsTorchDType<T>* = nullptr>
  Work allreduce(std::vector<T>& v, ReduceOp op);
  Work allreduceGradients(ag::Container const& x, ReduceOp op);

  template <typename T, IsTorchDType<T>* = nullptr>
  Work broadcast(T* ptr, int64_t s, int root);
  template <typename T, IsTorchDType<T>* = nullptr>
  Work broadcast(std::vector<T>& v, int root);

  template <typename T, IsTorchDType<T>* = nullptr>
  Work allgather(T* out, T* in, int64_t s);
  template <typename T, IsTorchDType<T>* = nullptr>
  Work allgather(T* out, torch::Tensor in);
  // (torch::Tensor overloads elided in this excerpt)

  Work barrier();
 private:
  std::shared_ptr<ProcessGroup> glooPG_;
  std::shared_ptr<ProcessGroup> ncclPG_;
  // Picks the process group matching the tensor's device (Gloo or NCCL).
  std::shared_ptr<ProcessGroup> devicePG(torch::Tensor x);
};
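/// End-to-end usage sketch (hypothetical example; assumes the job was
/// launched with a rendezvous mechanism that init() understands):
/// \code
///   cpid::distributed::init();
///   cpid::distributed::setGPUToLocalRank();
///   auto ctx = cpid::distributed::globalContext();
///
///   std::vector<float> grads(1024, 1.0f);
///   ctx->allreduce(grads.data(), grads.size(), ReduceOp::SUM).wait();
///   // Every rank now holds the element-wise sum over all ctx->size workers.
///   ctx->barrier().wait();
/// \endcode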
void init();
/// Sets CUDA device to the local (if available) or MPI rank, both modulo the
/// number of available devices.
void setGPUToLocalRank();
std::shared_ptr<Context> globalContext();

// Namespace-level counterparts of the Context collectives above.
template <typename T, IsTorchDType<T>* = nullptr>
Work allreduce(T* ptr, int64_t s, ReduceOp op);
template <typename T, IsTorchDType<T>* = nullptr>
Work allreduce(std::vector<T>& v, ReduceOp op);
template <typename T, IsTorchDType<T>* = nullptr>
Work broadcast(T* ptr, int64_t s, int root);
template <typename T, IsTorchDType<T>* = nullptr>
Work broadcast(std::vector<T>& v, int root);
template <typename T, IsTorchDType<T>* = nullptr>
Work allgather(T* out, T* in, int64_t s);
template <typename T, IsTorchDType<T>* = nullptr>
Work allgather(T* out, torch::Tensor in);

} // namespace distributed
} // namespace cpid
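/// Usage sketch for the VLOG_MASTER / VLOG_ALL macros defined at the top of
/// this header (hypothetical example): VLOG_MASTER logs only on rank 0,
/// while VLOG_ALL prefixes each worker's message with its rank.
/// \code
///   cpid::distributed::init();
///   VLOG_MASTER(0) << "run configured for "
///                  << cpid::distributed::globalContext()->size << " workers";
///   VLOG_ALL(1) << "worker initialized";
/// \endcode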