|
ag::Variant | forward (ag::Variant inp, EpisodeHandle const &) override |
|
bool | update () override |
|
void | doOnlineUpdatesInstead () |
|
int | episodes () |
|
| BatchedPGTrainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler > sampler, double gamma=0.99, int batchSize=10, std::size_t maxBatchSize=50, std::unique_ptr< AsyncBatcher > batcher=nullptr) |
|
virtual std::shared_ptr< ReplayBufferFrame > | makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override |
| Contract: the trainer output should be a map with keys: "action" for the taken action, "V" for the state value, and "action" for the action probability. More...
|
|
std::shared_ptr< Evaluator > | makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler=std::make_unique< DiscreteMaxSampler >()) override |
|
| Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr) |
|
ag::Variant | forwardUnbatched (ag::Variant in, ag::Container model=nullptr) |
| Convenience function for when one needs to forward a single input. More...
|
|
virtual | ~Trainer ()=default |
|
void | setTrain (bool=true) |
|
bool | isTrain () const |
|
ag::Variant | sample (ag::Variant in) |
| Sample using the class' sampler. More...
|
|
ag::Container | model () const |
|
ag::Optimizer | optim () const |
|
ReplayBuffer & | replayBuffer () |
|
void | setDone (bool=true) |
|
bool | isDone () const |
|
virtual void | step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false) |
|
virtual EpisodeHandle | startEpisode () |
| Returns true if an episode was successfully registered, and false otherwise. More...
|
|
virtual void | forceStopEpisode (EpisodeHandle const &) |
|
bool | isActive (EpisodeHandle const &) |
|
virtual void | reset () |
| Releases all the worker threads so that they can be joined. More...
|
|
template<class Archive > |
void | save (Archive &ar) const |
|
template<class Archive > |
void | load (Archive &ar) |
|
template<typename T > |
bool | is () const |
|
Trainer & | setMetricsContext (std::shared_ptr< MetricsContext > context) |
|
std::shared_ptr< MetricsContext > | metricsContext () const |
|
| TORCH_ARG (float, noiseStd) |
|
| TORCH_ARG (bool, continuousActions) |
|
void | setBatcher (std::unique_ptr< AsyncBatcher > batcher) |
|
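A minimal construction sketch following the constructor listed above; the model definition, the Adam settings, and the choice of sampler are assumptions rather than part of this API:

// Sketch only: MyModel is a hypothetical policy/value network and
// MultinomialSampler is assumed to be available as a stochastic sampler.
ag::Container model = MyModel().make();
ag::Optimizer optim = std::make_shared<torch::optim::Adam>(
    model->parameters(), torch::optim::AdamOptions(1e-3));

auto trainer = std::make_shared<cpid::BatchedPGTrainer>(
    model,
    optim,
    std::make_unique<cpid::MultinomialSampler>(),
    /*gamma=*/0.99,
    /*batchSize=*/10,
    /*maxBatchSize=*/50);
trainer->setTrain(true);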
Off-policy policy gradient with a critic.
This Trainer implements two modes:
- Online: it performs one update with the given batch size per node whenever it receives an episode, so one episode in each batch is always new and the rest are drawn from the replay buffer. THIS MODE IS UNTESTED.
- Offline: many threads are assumed to generate episodes in the background, and updates are performed in a separate background thread.

In both modes, the trainer updates on each new episode at least once before sampling from the replay buffer. If more episodes are generated than it can update on, it blocks until the next update. When the replay buffer of episodes it has already updated on reaches maxBatchSize, the oldest episode seen is removed.
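As a rough sketch of the offline flow in terms of the Trainer interface listed above (the game-side helpers makeStateVariant, applyAction, and gameFinished are hypothetical placeholders, and error handling is omitted):

// Worker thread: plays one episode and feeds frames to the trainer.
auto episode = trainer->startEpisode();
if (episode) {
  while (!gameFinished()) {
    ag::Variant state = makeStateVariant();  // hypothetical helper
    ag::Variant out = trainer->sample(trainer->forward(state, episode));
    float reward = applyAction(out);         // hypothetical helper
    trainer->step(episode, trainer->makeFrame(out, state, reward), gameFinished());
  }
}

// Trainer thread: keeps updating on new and replayed episodes until stopped.
while (!trainer->isDone()) {
  trainer->update();
}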
Replayer format: state, action, p(action), reward
Model output:
- Probability vector over actions: 1-dim vector
- Critic's value estimate: double
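Pictured as a plain struct, one replayer entry in the format above would carry the following fields; this is only an illustration, not the actual ReplayBufferFrame subclass the trainer stores:

// Illustrative layout of one replayer entry (state, action, p(action), reward).
struct PGFrame {
  torch::Tensor state; // model input for this step
  int64_t action;      // action taken at this step
  float pAction;       // probability the policy assigned to that action
  float reward;        // reward received after taking the action
};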