|
| void | stepEpisode (GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &) override |
| |
| bool | update () override |
| |
| EpisodeHandle | startEpisode () override |
| | Returns a valid handle if an episode was successfully registered, and an empty handle otherwise.
|
| ag::Variant | forward (ag::Variant inp, EpisodeHandle const &) override |
| |
| | OnlineZORBTrainer (ag::Container model, ag::Optimizer optim) |
| |
| virtual std::shared_ptr< ReplayBufferFrame > | makeFrame (ag::Variant trainerOutput, ag::Variant state, float reward) override |
| |
| | TORCH_ARG (float, valueLambda)=0 |
| |
| | TORCH_ARG (float, delta) |
| |
| | TORCH_ARG (int, batchSize) |
| |
| | TORCH_ARG (bool, antithetic) |
| |
| | Trainer (ag::Container model, ag::Optimizer optim, std::unique_ptr< BaseSampler >, std::unique_ptr< AsyncBatcher > batcher=nullptr) |
| |
| ag::Variant | forwardUnbatched (ag::Variant in, ag::Container model=nullptr) |
| | Convenience function when one needs to forward a single input.
|
| virtual | ~Trainer ()=default |
| |
| void | setTrain (bool=true) |
| |
| bool | isTrain () const |
| |
| ag::Variant | sample (ag::Variant in) |
| | Sample using the class' sampler.
|
| ag::Container | model () const |
| |
| ag::Optimizer | optim () const |
| |
| ReplayBuffer & | replayBuffer () |
| |
| virtual std::shared_ptr< Evaluator > | makeEvaluator (size_t, std::unique_ptr< BaseSampler > sampler) |
| |
| void | setDone (bool=true) |
| |
| bool | isDone () const |
| |
| virtual void | step (EpisodeHandle const &, std::shared_ptr< ReplayBufferFrame > v, bool isDone=false) |
| |
| virtual void | forceStopEpisode (EpisodeHandle const &) |
| |
| bool | isActive (EpisodeHandle const &) |
| |
| virtual void | reset () |
| | Releases all the worker threads so that they can be joined.
|
| template<class Archive > |
| void | save (Archive &ar) const |
| |
| template<class Archive > |
| void | load (Archive &ar) |
| |
| template<typename T > |
| bool | is () const |
| |
| Trainer & | setMetricsContext (std::shared_ptr< MetricsContext > context) |
| |
| std::shared_ptr< MetricsContext > | metricsContext () const |
| |
| | TORCH_ARG (float, noiseStd) |
| |
| | TORCH_ARG (bool, continuousActions) |
| |
| void | setBatcher (std::unique_ptr< AsyncBatcher > batcher) |
| |
This is an OnlineZORBTrainer that works with multiple actions per frame.
Contract:
- Expects to make N distinct actions, with M_i possible choices for each action i = 1, ..., N
- The model inherits from OnlineZORBModel and returns a vector of noises
- Input: the replay buffer frame state
- Output: the model should take in a state and generate an array of [(s, A_1), w_1, ind_1, v_1, ..., (s, A_N), w_N, ind_N, v_N] (sketched below), where:
  - (s, A_i) is a matrix of size [M_i, embed_size] for each action i
  - w_i has size [embed_size]
  - ind_i is an index into the noise vector generated by the model
  - v_i is a critic used for variance reduction. It is optional; pass an empty torch::Tensor() to omit it.
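For illustration, here is a minimal sketch of a forward pass producing output of the expected shape. It is not the real OnlineZORBModel interface: the choice of N = 2, the sizes M1, M2 and embedSize, the scalar-tensor encoding of ind_i, and the construction of an ag::Variant from a tensor vector are all assumptions made for this example.

    // Hypothetical sketch of a conforming model output; all sizes are
    // illustrative, and this is not the actual OnlineZORBModel interface.
    ag::Variant forwardSketch(torch::Tensor state) {
      (void)state; // a real model would compute the embeddings from the state
      int64_t const M1 = 4, M2 = 8, embedSize = 16;
      // (s, A_i): one row per candidate, embedding the state and the candidate.
      auto A1 = torch::randn({M1, embedSize});
      auto A2 = torch::randn({M2, embedSize});
      // w_i: scoring vector of size [embed_size].
      auto w1 = torch::randn({embedSize});
      auto w2 = torch::randn({embedSize});
      // ind_i: index into the model's noise vector, encoded here as a scalar tensor.
      auto ind1 = torch::zeros({1}, torch::kLong);
      auto ind2 = torch::ones({1}, torch::kLong);
      // v_i: optional critic; an empty tensor means "no critic".
      auto v1 = torch::Tensor();
      auto v2 = torch::Tensor();
      return ag::Variant(
          std::vector<torch::Tensor>{A1, w1, ind1, v1, A2, w2, ind2, v2});
    }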
Because of the particular way this trainer works, the sampler (inference procedure) is part of the forward function. The inference procedure the trainer provides through trainer->forward is, for each action i, argmax_j [ A_i (w_i + delta * noise[ind_i]) ]_j: each of the M_i candidate rows is scored against the noise-perturbed weight vector, and the highest-scoring candidate is selected.
The critic is used in training by computing (return - critic), where the critic is trained to predict the return, as in actor-critic. Subtracting it leaves the gradient estimate unchanged:
  G = E_u[ f(x + d u) u ]
    = E_u[ f(x + d u) u ] - E_u[ v u ]    (u is zero-mean Gaussian and v does not depend on u, so E_u[ v u ] = 0)
    = E_u[ (f(x + d u) - v) u ]
The trainer->forward function returns [action_i, action_scores_i, ...], i.e. the chosen action and the candidate scores for each of the N actions.
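To make the call flow concrete, here is a minimal, hypothetical usage sketch based on the member listing above. The hyperparameter values, the getState() helper, and the single-step episode are illustrative assumptions, not recommendations; the chained setters are the ones generated by the TORCH_ARG macros listed above.

    // Hypothetical usage sketch; `model` and `optim` are assumed to exist,
    // and getState() is a placeholder for obtaining the frame state.
    auto trainer = std::make_shared<OnlineZORBTrainer>(model, optim);
    trainer->delta(1e-2f).batchSize(32).antithetic(true);  // TORCH_ARG setters

    auto handle = trainer->startEpisode();
    if (handle) {  // assumes the handle is empty when registration failed
      ag::Variant state = getState();
      auto out = trainer->forward(state, handle);  // [action_i, action_scores_i, ...]
      auto frame = trainer->makeFrame(out, state, /*reward=*/1.0f);
      trainer->step(handle, frame, /*isDone=*/true);
    }
    trainer->update();  // assumed to report whether a training step occurred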
TODO: on multiple nodes, this is UNTESTED AND DOESN'T WORK, even though the logic is mostly there.