The Trainer should be shared amongst multiple different nodes, and attached to a single Module.
#include <trainer.h>
Inherited by cpid::BatchedPGTrainer, cpid::CentralTrainer, cpid::ESTrainer, cpid::Evaluator, and cpid::OnlineZORBTrainer.
Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher = nullptr)
virtual ag::Variant forward (ag::Variant inp, EpisodeHandle const &)
ag::Variant forwardUnbatched (ag::Variant in, ag::Container model = nullptr)
    Convenience function for forwarding a single input.
virtual bool update () = 0
virtual ~Trainer () = default
void setTrain (bool = true)
bool isTrain () const
ag::Variant sample (ag::Variant in)
    Sample using the class' sampler.
ag::Container model () const
ag::Optimizer optim () const
ReplayBuffer & replayBuffer ()
virtual std::shared_ptr< Evaluator > makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler)
void setDone (bool = true)
bool isDone () const
virtual void step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone = false)
virtual std::shared_ptr< ReplayBufferFrame > makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) = 0
virtual EpisodeHandle startEpisode ()
    Returns a handle that is valid if an episode was successfully registered, and invalid otherwise.
virtual void forceStopEpisode (EpisodeHandle const &)
bool isActive (EpisodeHandle const &)
virtual void reset ()
    Releases all the worker threads so that they can be joined.
template<class Archive > void save (Archive & ar) const
template<class Archive > void load (Archive & ar)
template<typename T > bool is () const
Trainer & setMetricsContext (std::shared_ptr< MetricsContext > context)
std::shared_ptr< MetricsContext > metricsContext () const
TORCH_ARG (float, noiseStd)
TORCH_ARG (bool, continuousActions)
void setBatcher (std::unique_ptr< AsyncBatcher > batcher)
The Trainer should be shared amongst multiple different nodes, and attached to a single Module.
It consists of:
- An algorithm, built into the trainer and subclassed by implementing the stepFrame and stepGame functions. The trainer itself might also start a separate thread or otherwise provide functionality for syncing weights.
- A model, defined externally and subject to the algorithm's specifications. There is no good way of enforcing the model's output format; an incorrect output spec will simply cause the algorithm to fail. Thus, please document each new algorithm with its expected input specification.
- An optimizer, possibly defined externally. Many of the multi-node algorithms will likely force their own optimizer.
- A sampler, responsible for transforming the model output into an action.
- A replay buffer, as implemented below and required by subclassing the algorithm. Each algorithm will expect its own replay buffer. A minimal usage sketch follows this list.
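As a rough illustration of how these pieces fit together, the sketch below drives a hypothetical Trainer subclass through one episode and an update. MyTrainer, MyModel, MySampler, makeOptimizer and the env object are assumptions made for illustration only; the Trainer member calls match the signatures documented on this page.

// Hedged sketch: names marked "hypothetical" are not part of cpid::Trainer.
ag::Container model = MyModel().make();              // hypothetical model container
ag::Optimizer optim = makeOptimizer(model);          // hypothetical, e.g. Adam over the model's parameters
auto trainer = std::make_shared<MyTrainer>(          // hypothetical Trainer subclass
    model, optim, std::make_unique<MySampler>());    // MySampler: hypothetical BaseSampler

auto handle = trainer->startEpisode();               // valid handle if registration succeeded
while (handle && !env.finished()) {                  // env: hypothetical environment wrapper
  ag::Variant state = env.observe();
  ag::Variant out = trainer->forward(state, handle); // run the model (batching handled by the trainer)
  ag::Variant action = trainer->sample(out);         // sampler turns model output into an action
  float reward = env.act(action);
  trainer->step(handle, trainer->makeFrame(out, state, reward), env.finished());
}
trainer->update();                                   // run a learning step once enough data is buffered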
cpid::Trainer::Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, std::unique_ptr< AsyncBatcher > batcher = nullptr)
virtual cpid::Trainer::~Trainer () = default

virtual void cpid::Trainer::forceStopEpisode (EpisodeHandle const & handle)

virtual ag::Variant cpid::Trainer::forward (ag::Variant inp, EpisodeHandle const & handle)

ag::Variant cpid::Trainer::forwardUnbatched (ag::Variant in, ag::Container model = nullptr)
Convenience function for forwarding a single input.
This will make the input look batched and forward it, so that the model has no problem handling it. If no model is given, the trainer's own model is used.
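For instance (a sketch; the trainer and obs variables are assumptions):

// Sketch: forward one unbatched observation; 'trainer' and 'obs' are assumed.
ag::Variant out = trainer->forwardUnbatched(obs);
// Equivalent to wrapping obs as a batch of size 1, forwarding it through the
// model, and unwrapping the single result.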
template<typename T > bool cpid::Trainer::is () const [inline]

bool cpid::Trainer::isDone () const [inline]

bool cpid::Trainer::isTrain () const [inline]

template<class Archive > void cpid::Trainer::load (Archive & ar) [inline]
virtual std::shared_ptr< Evaluator > cpid::Trainer::makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler)

virtual std::shared_ptr< ReplayBufferFrame > cpid::Trainer::makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) = 0 [pure virtual]

std::shared_ptr< MetricsContext > cpid::Trainer::metricsContext () const [inline]

ag::Container cpid::Trainer::model () const

ag::Optimizer cpid::Trainer::optim () const
virtual void cpid::Trainer::reset ()
Releases all the worker threads so that they can be joined.
For the off-policy trainers, labels all games as inactive. For the on-policy trainers, additionally un-blocks all threads that could be waiting at the batch barrier.
Reimplemented in cpid::ESTrainer, and cpid::Evaluator.
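A plausible shutdown sequence, assuming the surrounding code owns the worker threads (the ordering here is an illustration, not something mandated by the API):

trainer->setDone();   // signal that training is finished
trainer->reset();     // release/unblock worker threads so they can be joined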
ag::Variant cpid::Trainer::sample (ag::Variant in)
Sample using the class' sampler.

template<class Archive > void cpid::Trainer::save (Archive & ar) const [inline]
void cpid::Trainer::setBatcher (std::unique_ptr< AsyncBatcher > batcher)

void cpid::Trainer::setDone (bool done = true)

void cpid::Trainer::setTrain (bool train = true)

virtual void cpid::Trainer::stepGame (GameUID const & game) [inline, protected]
cpid::Trainer::TORCH_ARG (float, noiseStd)

cpid::Trainer::TORCH_ARG (bool, continuousActions)

virtual bool cpid::Trainer::update () = 0 [pure virtual]
std::shared_timed_mutex cpid::Trainer::activeMapMutex_ [protected]

std::atomic< bool > cpid::Trainer::done_ {false} [protected]

static constexpr float cpid::Trainer::kFwdMetricsSubsampling = 0.1 [static, protected]
We subsample kFwdMetricsSubsampling of the forward() events when measuring their duration.
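Illustratively, the idea is to time only a fraction of forward calls (this is a sketch of the idea, not the actual implementation; rng and recordForwardDuration are hypothetical):

// Sketch: time roughly kFwdMetricsSubsampling (10%) of forward() calls.
ag::Variant out;
if (std::uniform_real_distribution<float>(0.f, 1.f)(rng) < kFwdMetricsSubsampling) {
  auto start = std::chrono::steady_clock::now();
  out = model_->forward(inp);
  recordForwardDuration(std::chrono::steady_clock::now() - start);  // hypothetical metrics hook
} else {
  out = model_->forward(inp);   // untimed path
}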
ag::Container cpid::Trainer::model_ [protected]

std::mutex cpid::Trainer::modelWriteMutex_ [protected]

ag::Optimizer cpid::Trainer::optim_ [protected]

bool cpid::Trainer::train_ = true [protected]