The Trainer should be shared amongst multiple different nodes, and attached to a single Module.
More...
#include <trainer.h>
Inherited by cpid::BatchedPGTrainer, cpid::CentralTrainer, cpid::ESTrainer, cpid::Evaluator, and cpid::OnlineZORBTrainer.
|
| | Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr) |
| |
| virtual ag::Variant | forward (ag::Variant inp, EpisodeHandle const &) |
| |
| ag::Variant | forwardUnbatched (ag::Variant in, ag::Container model=nullptr) |
| | Convenience function for when one needs to forward a single input. More...
|
| |
| virtual bool | update ()=0 |
| |
| virtual | ~Trainer ()=default |
| |
| void | setTrain (bool=true) |
| |
| bool | isTrain () const |
| |
| ag::Variant | sample (ag::Variant in) |
| | Sample using the class' sampler. More...
|
| |
| ag::Container | model () const |
| |
| ag::Optimizer | optim () const |
| |
| ReplayBuffer & | replayBuffer () |
| |
| virtual std::shared_ptr< Evaluator > | makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler) |
| |
| void | setDone (bool=true) |
| |
| bool | isDone () const |
| |
| virtual void | step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false) |
| |
| virtual std::shared_ptr< ReplayBufferFrame > | makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward)=0 |
| |
| virtual EpisodeHandle | startEpisode () |
| | Returns true if an episode was successfully registered, and false otherwise. More...
|
| |
| virtual void | forceStopEpisode (EpisodeHandle const &) |
| |
| bool | isActive (EpisodeHandle const &) |
| |
| virtual void | reset () |
| | Releases all the worker threads so that they can be joined. More...
|
| |
| template<class Archive > |
| void | save (Archive &ar) const |
| |
| template<class Archive > |
| void | load (Archive &ar) |
| |
| template<typename T > |
| bool | is () const |
| |
| Trainer & | setMetricsContext (std::shared_ptr< MetricsContext > context) |
| |
| std::shared_ptr< MetricsContext > | metricsContext () const |
| |
| | TORCH_ARG (float, noiseStd) |
| |
| | TORCH_ARG (bool, continuousActions) |
| |
| void | setBatcher (std::unique_ptr< AsyncBatcher > batcher) |
| |
The Trainer should be shared amongst multiple different nodes, and attached to a single Module.
It consists of:
- An algorithm, built into the trainer and subclassed by implementing the stepFrame and stepGame functions. The trainer itself might also start a separate thread or otherwise have functionality for syncing weights.
- Model, defined externally subject to algorithm specifications. I don't know of a good way of enforcing model output, except that an incorrect output spec will cause the algorithm to fail. Thus, please comment new algorithms with their input specification.
- An optimizer, perhaps defined externally. Many of the multinode algorithms will probably force their own optimizer.
- A sampler, responsible for transforming the model output into an action
- A replay buffer, as implemented below and forced by subclassing the algorithm. Each algorithm will expect its own replay buffer.
| cpid::Trainer::Trainer |
( |
ag::Container |
model, |
|
|
ag::Optimizer |
optim, |
|
|
std::unique_ptr< BaseSampler > |
sampler, |
|
|
std::unique_ptr< AsyncBatcher > |
batcher = nullptr |
|
) |
| |
| virtual cpid::Trainer::~Trainer |
( |
| ) |
|
|
virtual default
| void cpid::Trainer::forceStopEpisode |
( |
EpisodeHandle const & |
handle | ) |
|
|
virtual |
| ag::Variant cpid::Trainer::forward |
( |
ag::Variant |
inp, |
|
|
EpisodeHandle const & |
handle |
|
) |
| |
|
virtual |
| ag::Variant cpid::Trainer::forwardUnbatched |
( |
ag::Variant |
in, |
|
|
ag::Container |
model = nullptr |
|
) |
| |
Convenience function for when one needs to forward a single input.
This will make it look batched and forward it, so that the model has no problem handling it. If
- Parameters
-
template<typename T >
| bool cpid::Trainer::is |
( |
| ) |
const |
|
inline |
| bool cpid::Trainer::isDone |
( |
| ) |
const |
|
inline |
| bool cpid::Trainer::isTrain |
( |
| ) |
const |
|
inline |
template<class Archive >
| void cpid::Trainer::load |
( |
Archive & |
ar | ) |
|
|
inline |
| std::shared_ptr< Evaluator > cpid::Trainer::makeEvaluator |
( |
size_t |
, |
|
|
std::unique_ptr< BaseSampler > |
sampler |
|
) |
| |
|
virtual |
| virtual std::shared_ptr<ReplayBufferFrame> cpid::Trainer::makeFrame |
( |
ag::Variant |
trainerOutput, |
|
|
ag::Variant |
state, |
|
|
float |
reward |
|
) |
| |
|
pure virtual |
| std::shared_ptr< MetricsContext > cpid::Trainer::metricsContext |
( |
| ) |
const |
|
inline |
| ag::Container cpid::Trainer::model |
( |
| ) |
const |
| ag::Optimizer cpid::Trainer::optim |
( |
| ) |
const |
| void cpid::Trainer::reset |
( |
| ) |
|
|
virtual |
Releases all the worker threads so that they can be joined.
For the off-policy trainers, labels all games as inactive. For the on-policy trainers, additionally un-blocks all threads that could be waiting at the batch barrier.
Reimplemented in cpid::ESTrainer, and cpid::Evaluator.
| ag::Variant cpid::Trainer::sample |
( |
ag::Variant |
in | ) |
|
Sample using the class' sampler.
template<class Archive >
| void cpid::Trainer::save |
( |
Archive & |
ar | ) |
const |
|
inline |
| void cpid::Trainer::setBatcher |
( |
std::unique_ptr< AsyncBatcher > |
batcher | ) |
|
| void cpid::Trainer::setDone |
( |
bool |
done = true | ) |
|
| void cpid::Trainer::setTrain |
( |
bool |
train = true | ) |
|
| virtual void cpid::Trainer::stepGame |
( |
GameUID const & |
game | ) |
|
|
inline protected virtual
| cpid::Trainer::TORCH_ARG |
( |
float |
, |
|
|
noiseStd |
|
|
) |
| |
| cpid::Trainer::TORCH_ARG |
( |
bool |
, |
|
|
continuousActions |
|
|
) |
| |
| virtual bool cpid::Trainer::update |
( |
| ) |
|
|
pure virtual |
| std::shared_timed_mutex cpid::Trainer::activeMapMutex_ |
|
protected |
| std::atomic<bool> cpid::Trainer::done_ {false} |
|
protected |
| constexpr float cpid::Trainer::kFwdMetricsSubsampling = 0.1 |
|
static protected
Only a fraction kFwdMetricsSubsampling (0.1) of the forward() events is sampled when measuring their duration.
| ag::Container cpid::Trainer::model_ |
|
protected |
| std::mutex cpid::Trainer::modelWriteMutex_ |
|
protected |
| ag::Optimizer cpid::Trainer::optim_ |
|
protected |
| bool cpid::Trainer::train_ = true |
|
protected |
The documentation for this class was generated from the following files: