TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
A trainer that sends episodes to one or more central instances.
#include <centralcpid2ktrainer.h>
Inherits cpid::CentralTrainer.
Public Member Functions
| CentralCpid2kTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr, std::string serverRole="train") | |
| virtual bool | update () override |
| void | updateDone () |
| distributed::Context & | context () |
| distributed::Context & | serverContext () |
| int | numUpdates () |
Public Member Functions inherited from cpid::CentralTrainer
| CentralTrainer (bool isServer, ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr) | |
| virtual | ~CentralTrainer () |
| bool | isServer () const |
| virtual void | stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override |
| virtual void | stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override |
| ag::Variant | forward (ag::Variant inp, EpisodeHandle const &) override |
| virtual std::shared_ptr< ReplayBufferFrame > | makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override |
| virtual EpisodeHandle | startEpisode () override |
| Returns true if the episode was registered successfully, and false otherwise. | |
| virtual void | forceStopEpisode (EpisodeHandle const &) override |
| std::shared_lock< std::shared_timed_mutex > | modelReadLock () |
| std::unique_lock< std::shared_timed_mutex > | modelWriteLock () |
Public Member Functions inherited from cpid::Trainer
| Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr) | |
| ag::Variant | forwardUnbatched (ag::Variant in, ag::Container model=nullptr) |
| Convenience function for forwarding a single input. | |
| virtual | ~Trainer ()=default |
| void | setTrain (bool=true) |
| bool | isTrain () const |
| ag::Variant | sample (ag::Variant in) |
| Samples using the class's sampler. | |
| ag::Container | model () const |
| ag::Optimizer | optim () const |
| ReplayBuffer & | replayBuffer () |
| virtual std::shared_ptr< Evaluator > | makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler) |
| void | setDone (bool=true) |
| bool | isDone () const |
| virtual void | step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false) |
| bool | isActive (EpisodeHandle const &) |
| virtual void | reset () |
| Releases all the worker threads so that they can be joined. | |
| template<class Archive > | |
| void | save (Archive &ar) const |
| template<class Archive > | |
| void | load (Archive &ar) |
| template<typename T > | |
| bool | is () const |
| Trainer & | setMetricsContext (std::shared_ptr< MetricsContext > context) |
| std::shared_ptr< MetricsContext > | metricsContext () const |
| TORCH_ARG (float, noiseStd) | |
| TORCH_ARG (bool, continuousActions) | |
| void | setBatcher (std::unique_ptr< AsyncBatcher > batcher) |
Protected Member Functions
| void | bcastWeights () |
| void | recvWeights (void const *data, size_t len, int64_t numUpdates) |
Protected Member Functions inherited from cpid::CentralTrainer
| virtual void | receivedFrames (GameUID const &, std::string const &)=0 |
| Callback for new episodes. | |
| virtual bool | episodeClientEnqueue (EpisodeData const &) |
| Callback for locally generated episode data that can be sent out. | |
| virtual uint32_t | getMaxBatchLength () const |
| Allows implementing trainers to send partial episodes. TODO: set up synchronization so that partial episodes end up on the same node. | |
| virtual uint32_t | getSendInterval () const |
| virtual bool | serveContinuously () const |
| Allows implementing trainers to decide whether to serve end-of-frame in the middle of the episode or not. | |
| CentralTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher=nullptr) | |
| A protected constructor that doesn't set up a server. | |
| void | dequeueEpisodes () |
Protected Member Functions inherited from cpid::Trainer
| virtual void | stepGame (GameUID const &game) |
| template<typename T > | |
| std::vector< T const * > | cast (ReplayBuffer::Episode const &e) |
Protected Attributes
| std::string | serverRole_ |
| std::unique_ptr< Cpid2kWorker > | worker_ |
| std::mutex | makeClientMutex_ |
| std::vector< std::string > | endpoints_ |
| std::shared_ptr< zmq::context_t > | zmqContext_ |
| std::shared_ptr< BlobPublisher > | modelPub_ |
| std::shared_ptr< BlobSubscriber > | modelSub_ |
| std::atomic< int64_t > | numUpdates_ {-1} |
Protected Attributes inherited from cpid::CentralTrainer
| std::shared_ptr< EpisodeServer > | server_ |
| std::shared_ptr< EpisodeClient > | client_ |
| std::thread | dequeueEpisodes_ |
| std::mutex | newGamesMutex_ |
| std::queue< EpisodeTuple > | newBatches_ |
| std::shared_timed_mutex | modelMutex_ |
| std::atomic< bool > | stop_ {false} |
| std::unique_ptr< BufferPool > | bufferPool_ |
Protected Attributes inherited from cpid::Trainer
| ag::Container | model_ |
| ag::Optimizer | optim_ |
| std::shared_ptr< MetricsContext > | metricsContext_ |
| ReplayBuffer | replayer_ |
| bool | train_ = true |
| std::atomic< bool > | done_ {false} |
| std::mutex | modelWriteMutex_ |
| std::shared_timed_mutex | activeMapMutex_ |
| std::unique_ptr< BaseSampler > | sampler_ |
| std::unique_ptr< AsyncBatcher > | batcher_ |
| std::shared_ptr< HandleGuard > | epGuard_ |
| ReplayBuffer::UIDKeyStore | actives_ |
Additional Inherited Members
Protected Types inherited from cpid::Trainer
| using | ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)> |
Static Protected Member Functions inherited from cpid::Trainer
| static std::shared_ptr< Evaluator > | evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func) |
Static Protected Attributes inherited from cpid::Trainer
| static constexpr float | kFwdMetricsSubsampling = 0.1 |
| We subsample kFwdMetricsSubsampling of the forward() events when measuring their duration. | |
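The constant kFwdMetricsSubsampling = 0.1 means that only about one in ten forward() calls has its duration measured, which keeps the timing overhead low. The following is a minimal sketch of that idea using an independent Bernoulli draw per call. SubsampledTimer is a hypothetical helper, not a TorchCraftAI class, and the assumption that the library selects calls at random (rather than by some other rule) is not confirmed by this page.

```cpp
#include <cassert>
#include <chrono>
#include <random>
#include <vector>

// Hypothetical helper (not part of TorchCraftAI): times a random fraction of
// the calls it wraps, in the spirit of kFwdMetricsSubsampling = 0.1.
class SubsampledTimer {
 public:
  explicit SubsampledTimer(float fraction, unsigned seed = 0) : rng_(seed) {
    assert(fraction >= 0.0f && fraction <= 1.0f);
    coin_ = std::bernoulli_distribution(fraction);
  }

  // Always runs fn(); records its elapsed time (steady_clock, microseconds)
  // only when the per-call coin flip succeeds, so the timing cost is paid on
  // roughly `fraction` of the calls.
  template <typename Fn>
  void run(Fn&& fn) {
    if (!coin_(rng_)) {
      fn();
      return;
    }
    auto start = std::chrono::steady_clock::now();
    fn();
    auto end = std::chrono::steady_clock::now();
    samplesUs_.push_back(
        std::chrono::duration<double, std::micro>(end - start).count());
  }

  std::vector<double> const& samples() const { return samplesUs_; }

 private:
  std::mt19937 rng_;
  std::bernoulli_distribution coin_;
  std::vector<double> samplesUs_;
};
```

Passing a fixed seed makes the timed subset reproducible in tests; a long-running process would more typically seed from std::random_device.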
A trainer that sends episodes to one or more central instances.
It works like CentralTrainer, but uses Redis to discover servers and clients.
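On the receiving side, the inherited CentralTrainer members newGamesMutex_, newBatches_, dequeueEpisodes_ and stop_ suggest that incoming episodes are handed to a mutex-guarded queue and drained by a dedicated thread. The following is a minimal, self-contained sketch of that hand-off pattern under this assumption. It uses no TorchCraftAI, ZeroMQ or Redis APIs, and EpisodeBlob and EpisodeInbox are hypothetical names.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-in for one serialized episode received from a client.
struct EpisodeBlob {
  std::string gameId;
  std::string payload;
};

// A mutex-guarded FIFO drained by one consumer thread and stopped via an
// atomic flag. Names loosely echo CentralTrainer's members; this is a sketch,
// not the library's code.
class EpisodeInbox {
 public:
  EpisodeInbox() : consumer_([this] { drain(); }) {}
  ~EpisodeInbox() { shutdown(); }

  // Producer side, e.g. called from a network receive callback.
  void push(EpisodeBlob ep) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      assert(!stop_ && "push() after shutdown()");
      queue_.push(std::move(ep));
    }
    cv_.notify_one();
  }

  // Sets the stop flag under the lock (so the consumer cannot miss the
  // wake-up), then waits for the consumer to drain the queue and exit.
  void shutdown() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      stop_ = true;
    }
    cv_.notify_all();
    if (consumer_.joinable()) consumer_.join();
  }

  // Game ids handled so far, in arrival order.
  std::vector<std::string> processed() {
    std::lock_guard<std::mutex> lock(mutex_);
    return processed_;
  }

 private:
  void drain() {
    std::unique_lock<std::mutex> lock(mutex_);
    while (true) {
      cv_.wait(lock, [this] { return stop_ || !queue_.empty(); });
      while (!queue_.empty()) {
        EpisodeBlob ep = std::move(queue_.front());
        queue_.pop();
        // A real trainer would deserialize ep.payload into its replay buffer
        // here (ideally outside the lock); we only record the game id.
        processed_.push_back(ep.gameId);
      }
      if (stop_) return;  // queue is empty at this point, nothing is dropped
    }
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<EpisodeBlob> queue_;
  std::vector<std::string> processed_;
  std::atomic<bool> stop_{false};
  std::thread consumer_;  // declared last: starts after the members above exist
};
```

Setting the flag while holding the mutex matters: if it were flipped outside the lock, the consumer could evaluate the wait predicate, miss the notification, and block forever.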
cpid::CentralCpid2kTrainer::CentralCpid2kTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher = nullptr, std::string serverRole = "train")
void cpid::CentralCpid2kTrainer::bcastWeights () [protected]
dist::Context & cpid::CentralCpid2kTrainer::context () [inline]
void cpid::CentralCpid2kTrainer::recvWeights (void const * data, size_t len, int64_t numUpdates) [protected]
dist::Context & cpid::CentralCpid2kTrainer::serverContext ()
bool cpid::CentralCpid2kTrainer::update () [override, virtual]
Reimplemented from cpid::CentralTrainer.
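The protected bcastWeights() and recvWeights(data, len, numUpdates), together with numUpdates_ starting at -1 and the inherited modelReadLock()/modelWriteLock() pair, suggest that model weights are broadcast as raw blobs tagged with an update count. The exact protocol is not documented on this page. The following is a minimal sketch of the receiving side under two assumptions: a blob is a flat array of floats, and stale or duplicate broadcasts should be ignored. VersionedWeights, read() and receive() are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <mutex>
#include <shared_mutex>
#include <vector>

// Hypothetical receiver for weight broadcasts. Weights are a flat float array
// guarded by a reader/writer lock; numUpdates_ starts at -1 ("nothing received
// yet"), mirroring the documented initial value.
class VersionedWeights {
 public:
  explicit VersionedWeights(std::size_t n) : weights_(n, 0.0f) {}

  // Many readers (e.g. concurrent forward passes) may hold the shared lock,
  // analogous to modelReadLock().
  float read(std::size_t i) const {
    std::shared_lock<std::shared_timed_mutex> lock(mutex_);
    return weights_.at(i);
  }

  // In the spirit of recvWeights(data, len, numUpdates): takes the exclusive
  // lock (analogous to modelWriteLock()) and applies the blob only if it has
  // the expected size and is newer than the current version.
  // Returns true if the weights were replaced.
  bool receive(void const* data, std::size_t len, std::int64_t numUpdates) {
    assert(data != nullptr);
    std::unique_lock<std::shared_timed_mutex> lock(mutex_);
    if (len != weights_.size() * sizeof(float)) {
      return false;  // malformed blob: wrong size
    }
    if (numUpdates <= numUpdates_) {
      return false;  // stale or duplicate broadcast
    }
    std::memcpy(weights_.data(), data, len);
    numUpdates_ = numUpdates;
    return true;
  }

  std::int64_t numUpdates() const {
    std::shared_lock<std::shared_timed_mutex> lock(mutex_);
    return numUpdates_;
  }

 private:
  mutable std::shared_timed_mutex mutex_;
  std::vector<float> weights_;
  std::int64_t numUpdates_ = -1;
};
```

Rejecting any broadcast whose update count is not strictly greater than the current one makes the receiver safe against duplicated or reordered messages. Note that the real numUpdates_ is declared as a std::atomic, whereas this sketch guards a plain integer with the same lock as the weights.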
void cpid::CentralCpid2kTrainer::updateDone ()
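std::vector< std::string > cpid::CentralCpid2kTrainer::endpoints_ [protected]
std::mutex cpid::CentralCpid2kTrainer::makeClientMutex_ [protected]
std::shared_ptr< BlobPublisher > cpid::CentralCpid2kTrainer::modelPub_ [protected]
std::shared_ptr< BlobSubscriber > cpid::CentralCpid2kTrainer::modelSub_ [protected]
std::atomic< int64_t > cpid::CentralCpid2kTrainer::numUpdates_ {-1} [protected]
std::string cpid::CentralCpid2kTrainer::serverRole_ [protected]
std::unique_ptr< Cpid2kWorker > cpid::CentralCpid2kTrainer::worker_ [protected]
std::shared_ptr< zmq::context_t > cpid::CentralCpid2kTrainer::zmqContext_ [protected]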