// Excerpt from policygradienttrainer.h; elided lines are marked /* ... */.
#include <autogradpp/autograd.h>

struct BatchedPGReplayBufferFrame : ReplayBufferFrame {
  BatchedPGReplayBufferFrame(
      ag::Variant state,
      torch::Tensor action,
      float pAction,
      double reward)
      : state(state), action(action), pAction(pAction), reward(reward) {}

  ag::Variant state;
  torch::Tensor action;
  /// Probability of action according to the policy that was used to obtain this frame.
  float pAction;
  /// Reward observed since taking the previous action.
  double reward;
};

/// Off-policy policy gradient with a critic.
class OffPolicyPolicyGradientTrainer : public Trainer {
  std::size_t maxBatchSize_;
  bool onlineUpdates_ = false;
  std::shared_timed_mutex updateMutex_;
  std::deque<std::pair<GameUID, EpisodeKey>> newGames_;
  std::queue<std::pair<GameUID, EpisodeKey>> seenGames_;
  std::mutex newGamesMutex_;
  bool enoughEpisodes_ = false;
  /* ... */

  ag::Variant forward(ag::Variant inp, EpisodeHandle const&) override;
  bool update() override;
  void doOnlineUpdatesInstead();
  int episodes();

  OffPolicyPolicyGradientTrainer(
      /* ... */
      std::unique_ptr<BaseSampler> sampler,
      /* ... */
      std::size_t maxBatchSize = 50,
      std::unique_ptr<AsyncBatcher> batcher = nullptr);

  virtual std::shared_ptr<ReplayBufferFrame> makeFrame(
      ag::Variant trainerOutput,
      /* ... */) override;

  std::shared_ptr<Evaluator> makeEvaluator(
      /* ... */
      std::unique_ptr<BaseSampler> sampler =
          std::make_unique<DiscreteMaxSampler>()) override;
};
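The excerpt above only shows declarations; how the trainer is driven is not part of it. As a minimal sketch, the loop below polls update() from a dedicated thread. It assumes, beyond what the excerpt states, that update() is callable through the Trainer base class and returns true when a model update was performed (false when not enough episodes are available yet); the include path, polling interval, and stop flag are placeholders.

#include <atomic>
#include <chrono>
#include <thread>

#include "cpid/trainer.h"  // include path is illustrative

// Poll the trainer for updates until asked to stop.
void runUpdates(cpid::Trainer& trainer, std::atomic<bool>& stop) {
  while (!stop.load()) {
    if (!trainer.update()) {
      // Assumed semantics: no update was possible yet, so back off briefly
      // while worker threads finish more episodes.
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}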
Referenced symbols:
- cpid: the TorchCraftAI training library (batcher.cpp:15).
- GameUID, EpisodeKey: std::string aliases (trainer.h:31, trainer.h:32).
- Episode: std::vector<std::shared_ptr<ReplayBufferFrame>> (trainer.h:89).
- ReplayBufferFrame: stub base class for replay buffer frames (trainer.h:69).
- Trainer: should be shared amongst multiple different nodes, and attached to a single Module (trainer.h:156).
- BatchedPGReplayBufferFrame: replay buffer frame used by this trainer (policygradienttrainer.h:19).
- OffPolicyPolicyGradientTrainer: off-policy policy gradient with a critic (policygradienttrainer.h:59).
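For illustration, the sketch below assembles a single-step Episode from the BatchedPGReplayBufferFrame constructor shown in the excerpt and the Episode alias listed above. The values, the include path, and the assumption that ag::Variant can be constructed directly from a torch::Tensor are illustrative, not taken from the source.

#include <memory>

#include <autogradpp/autograd.h>
#include "cpid/policygradienttrainer.h"  // include path is illustrative

// Build a one-frame episode with placeholder data.
cpid::Episode makeToyEpisode() {
  cpid::Episode episode;  // std::vector<std::shared_ptr<ReplayBufferFrame>>

  ag::Variant state(torch::randn({1, 10}));   // observation fed to the model
  torch::Tensor action = torch::tensor({2});  // index of the sampled action
  float pAction = 0.25f;  // probability of the action under the acting policy
  double reward = 1.0;    // reward observed since taking the previous action

  episode.push_back(std::make_shared<cpid::BatchedPGReplayBufferFrame>(
      state, action, pAction, reward));
  return episode;
}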