|
| | make_shared_enabler (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction f) |
| |
| EpisodeHandle | startEpisode () override |
| | Returns true if an episode was successfully registered, and false otherwise. More...
|
| |
| void | forceStopEpisode (EpisodeHandle const &) override |
| |
| bool | update () override |
| |
| virtual ag::Variant | forward (ag::Variant inp, EpisodeHandle const &) override |
| |
| void | reset () override |
| | Releases all the worker threads so that they can be joined. More...
|
| |
| std::shared_ptr< ReplayBufferFrame > | makeFrame (ag::Variant, ag::Variant, float reward) override |
| |
| | Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr) |
| |
| ag::Variant | forwardUnbatched (ag::Variant in, ag::Container model=nullptr) |
| | Convenience function for when one needs to forward a single input. More...
|
| |
| virtual | ~Trainer ()=default |
| |
| void | setTrain (bool=true) |
| |
| bool | isTrain () const |
| |
| ag::Variant | sample (ag::Variant in) |
| | Samples using the class's sampler. More...
|
| |
| ag::Container | model () const |
| |
| ag::Optimizer | optim () const |
| |
| ReplayBuffer & | replayBuffer () |
| |
| virtual std::shared_ptr< Evaluator > | makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler) |
| |
| void | setDone (bool=true) |
| |
| bool | isDone () const |
| |
| virtual void | step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false) |
| |
| bool | isActive (EpisodeHandle const &) |
| |
| template<class Archive > |
| void | save (Archive &ar) const |
| |
| template<class Archive > |
| void | load (Archive &ar) |
| |
| template<typename T > |
| bool | is () const |
| |
| Trainer & | setMetricsContext (std::shared_ptr< MetricsContext > context) |
| |
| std::shared_ptr< MetricsContext > | metricsContext () const |
| |
| | TORCH_ARG (float, noiseStd) |
| |
| | TORCH_ARG (bool, continuousActions) |
| |
| void | setBatcher (std::unique_ptr< AsyncBatcher > batcher) |
| |
|
| using | ForwardFunction = std::function< ag::Variant(ag::Variant, EpisodeHandle const &)> |
| |
| virtual void | stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override |
| |
| | Evaluator (ag::Container model, std::unique_ptr< BaseSampler > sampler, size_t batchSize, ForwardFunction func) |
| |
| virtual void | stepFrame (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) |
| |
| virtual void | stepGame (GameUID const &game) |
| |
| template<typename T > |
| std::vector< T const * > | cast (ReplayBuffer::Episode const &e) |
| |
| static std::shared_ptr< Evaluator > | evaluatorFactory (ag::Container model, std::unique_ptr< BaseSampler > s, size_t n, ForwardFunction func) |
| |
| size_t | batchSize_ |
| |
| size_t | gamesStarted_ = 0 |
| |
| std::condition_variable | batchBarrier_ |
| |
| std::mutex | updateMutex_ |
| |
| std::shared_timed_mutex | insertionMutex_ |
| |
| std::deque< std::pair< GameUID, EpisodeKey > > | newGames_ |
| |
| ForwardFunction | forwardFunction_ |
| |
| ag::Container | model_ |
| |
| ag::Optimizer | optim_ |
| |
| std::shared_ptr< MetricsContext > | metricsContext_ |
| |
| ReplayBuffer | replayer_ |
| |
| bool | train_ = true |
| |
| std::atomic< bool > | done_ {false} |
| |
| std::mutex | modelWriteMutex_ |
| |
| std::shared_timed_mutex | activeMapMutex_ |
| |
| std::unique_ptr< BaseSampler > | sampler_ |
| |
| std::unique_ptr< AsyncBatcher > | batcher_ |
| |
| std::shared_ptr< HandleGuard > | epGuard_ |
| |
| ReplayBuffer::UIDKeyStore | actives_ |
| |
| static constexpr float | kFwdMetricsSubsampling = 0.1 |
| | Only a fraction kFwdMetricsSubsampling of the forward() events is timed when measuring their duration. More...
|
| |