TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
centralcpid2ktrainer.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "blobpubsub.h"
11 #include "centraltrainer.h"
12 #include "episodeserver.h"
13 
14 #include <shared_mutex>
15 
16 namespace cpid {
17 class Cpid2kWorker;
18 namespace distributed {
19 class Context;
20 }
21 
22 /**
23  * A trainer that sends episodes to one or more central instances.
24  *
25  * This is like CentralTrainer but uses Redis for figuring out servers and
26  * clients.
27  */
29  public:
31  ag::Container model,
32  ag::Optimizer optim,
33  std::unique_ptr<BaseSampler> sampler,
34  std::unique_ptr<AsyncBatcher> batcher = nullptr,
35  std::string serverRole = "train");
36 
37  virtual bool update() override;
38 
39  void updateDone();
40 
41  distributed::Context& context();
42  distributed::Context& serverContext();
43 
44  int numUpdates() {
45  return numUpdates_.load();
46  }
47 
48  protected:
49  void bcastWeights();
50  void recvWeights(void const* data, size_t len, int64_t numUpdates);
51 
52  std::string serverRole_;
53  std::unique_ptr<Cpid2kWorker> worker_;
54  std::mutex makeClientMutex_;
55  std::vector<std::string> endpoints_;
56 
57  std::shared_ptr<zmq::context_t> zmqContext_;
58  // Models are pushed from server to clients
59  std::shared_ptr<BlobPublisher> modelPub_;
60  std::shared_ptr<BlobSubscriber> modelSub_;
61 
62  std::atomic<int64_t> numUpdates_{-1};
63 };
64 
65 } // namespace cpid
std::vector< std::string > endpoints_
Definition: centralcpid2ktrainer.h:55
std::mutex makeClientMutex_
Definition: centralcpid2ktrainer.h:54
Definition: distributed.h:108
A trainer that sends episodes to one or more central instances.
Definition: centraltrainer.h:37
std::unique_ptr< Cpid2kWorker > worker_
Definition: centralcpid2ktrainer.h:53
std::shared_ptr< BlobPublisher > modelPub_
Definition: centralcpid2ktrainer.h:59
int numUpdates()
Definition: centralcpid2ktrainer.h:44
std::string serverRole_
Definition: centralcpid2ktrainer.h:52
A trainer that sends episodes to one or more central instances.
Definition: centralcpid2ktrainer.h:28
std::shared_ptr< zmq::context_t > zmqContext_
Definition: centralcpid2ktrainer.h:57
The TorchCraftAI training library.
Definition: batcher.cpp:15
std::shared_ptr< BlobSubscriber > modelSub_
Definition: centralcpid2ktrainer.h:60