11 #include "common/rand.h" 14 #include <shared_mutex> 16 #include <glog/logging.h> 18 #include <autogradpp/autograd.h> 37 inline const std::string
kQKey =
"Q";
38 inline const std::string
kPiKey =
"Pi";
46 template <
typename T,
typename U>
48 return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
72 template <
class Archive>
89 using Episode = std::vector<std::shared_ptr<ReplayBufferFrame>>;
91 std::unordered_map<GameUID, std::unordered_map<EpisodeKey, Episode>>;
93 std::unordered_map<GameUID, std::unordered_set<EpisodeKey>>;
94 using SampleOutput = std::pair<EpisodeTuple, std::reference_wrapper<Episode>>;
99 std::shared_ptr<ReplayBufferFrame> value,
100 bool isDone =
false);
102 std::size_t size()
const;
103 std::size_t size(
GameUID const&)
const;
104 std::size_t sizeDone()
const;
105 std::size_t sizeDone(
GameUID const&)
const;
108 std::vector<SampleOutput> getAllEpisodes();
112 template <
typename RandomGenerator>
113 std::vector<SampleOutput> sample(RandomGenerator& g, uint32_t num = 1);
114 std::vector<SampleOutput> sample(uint32_t num = 1);
117 bool has(GameUID
const&, EpisodeKey
const& = kDefaultEpisodeKey);
118 bool isDone(GameUID
const&, EpisodeKey
const& = kDefaultEpisodeKey);
124 template <
typename RandomGenerator>
163 explicit operator bool()
const;
187 std::weak_ptr<HandleGuard> guard_;
192 std::unique_ptr<BaseSampler>,
193 std::unique_ptr<AsyncBatcher> batcher =
nullptr);
194 virtual ag::Variant forward(ag::Variant inp,
EpisodeHandle const&);
200 ag::Variant forwardUnbatched(ag::Variant in, ag::Container model =
nullptr);
205 virtual bool update() = 0;
208 void setTrain(
bool =
true);
213 ag::Variant sample(ag::Variant in);
214 ag::Container model()
const;
215 ag::Optimizer optim()
const;
217 virtual std::shared_ptr<Evaluator> makeEvaluator(
219 std::unique_ptr<BaseSampler> sampler);
224 void setDone(
bool =
true);
231 std::shared_ptr<ReplayBufferFrame> v,
232 bool isDone =
false);
233 virtual std::shared_ptr<ReplayBufferFrame>
234 makeFrame(ag::Variant trainerOutput, ag::Variant state,
float reward) = 0;
248 virtual void reset();
250 template <
class Archive>
251 void save(Archive& ar)
const;
252 template <
class Archive>
253 void load(Archive& ar);
254 template <
typename T>
257 Trainer& setMetricsContext(std::shared_ptr<MetricsContext> context);
258 std::shared_ptr<MetricsContext> metricsContext()
const;
260 TORCH_ARG(
float, noiseStd) = 1e-2;
261 TORCH_ARG(
bool, continuousActions) =
false;
263 void setBatcher(std::unique_ptr<AsyncBatcher> batcher);
273 ag::Container model_;
278 std::atomic<bool> done_{
false};
286 template <
typename T>
290 std::function<ag::Variant(ag::Variant, EpisodeHandle const&)>;
292 static std::shared_ptr<Evaluator> evaluatorFactory(
294 std::unique_ptr<BaseSampler> s,
301 static constexpr
float kFwdMetricsSubsampling = 0.1;
307 template <
typename RandomGenerator>
311 std::vector<SampleOutput> samples;
312 for (uint32_t i = 0; i < num; i++) {
313 samples.push_back(sample_(g));
318 template <
typename RandomGenerator>
320 std::shared_lock<std::shared_timed_mutex> lock(replayerRWMutex_);
321 if (dones_.size() == 0) {
322 throw std::runtime_error(
"No finished episodes yet...");
325 if (game.second.size() == 0) {
326 LOG(FATAL) <<
"no episodes in game";
330 return std::make_pair(
331 EpisodeTuple{game.first, ep}, std::ref(storage_[game.first][ep]));
334 template <
typename T>
337 std::vector<T const*> ret;
338 ret.reserve(e.size());
339 for (
auto& elem : e) {
340 ret.push_back(static_cast<T const*>(elem.get()));
345 template <
class Archive>
347 ar(CEREAL_NVP(*model_));
348 ar(CEREAL_NVP(optim_));
351 template <
class Archive>
353 ar(CEREAL_NVP(*model_));
354 ar(CEREAL_NVP(optim_));
355 optim_->add_parameters(model_->parameters());
358 template <
typename T>
360 return dynamic_cast<const T*
>(
this) !=
nullptr;
364 std::shared_ptr<MetricsContext> context) {
365 metricsContext_ = context;
370 return metricsContext_;
ag::Optimizer optim_
Definition: trainer.h:274
std::string GameUID
Definition: trainer.h:31
bool is() const
Definition: trainer.h:359
std::size_t operator()(const std::pair< T, U > &x) const
Definition: trainer.h:47
Trainer & setMetricsContext(std::shared_ptr< MetricsContext > context)
Definition: trainer.h:363
const std::string kQKey
Definition: trainer.h:37
Trainer::EpisodeHandle EpisodeHandle
Definition: trainer.h:303
Definition: trainer.h:134
std::shared_ptr< MetricsContext > metricsContext_
Definition: trainer.h:275
Definition: trainer.h:158
std::pair< EpisodeTuple, std::reference_wrapper< Episode >> SampleOutput
Definition: trainer.h:94
Definition: evaluator.h:19
Store storage_
Definition: trainer.h:121
float reward
Definition: trainer.h:78
std::shared_ptr< HandleGuard > epGuard_
Definition: trainer.h:285
const std::string kActionKey
Definition: trainer.h:41
std::unique_ptr< AsyncBatcher > batcher_
Definition: trainer.h:283
virtual void stepEpisode(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
Definition: trainer.h:269
SampleOutput sample_(RandomGenerator &g)
Definition: trainer.h:319
std::unordered_map< GameUID, std::unordered_set< EpisodeKey >> UIDKeyStore
Definition: trainer.h:93
const constexpr auto kDefaultEpisodeKey
Definition: trainer.h:33
const std::string kPiKey
Definition: trainer.h:38
std::shared_ptr< MetricsContext > metricsContext() const
Definition: trainer.h:369
EpisodeKey episodeKey
Definition: trainer.h:54
const std::string kValueKey
Definition: trainer.h:36
std::vector< T const * > cast(ReplayBuffer::Episode const &e)
Definition: trainer.h:335
The Trainer should be shared amongst multiple different nodes, and attached to a single Module...
Definition: trainer.h:156
void serialize(Archive &ar)
Definition: trainer.h:73
const std::string kPActionKey
Definition: trainer.h:42
void load(Archive &ar)
Definition: trainer.h:352
ReplayBuffer::UIDKeyStore actives_
Definition: trainer.h:298
const std::string kActionQKey
Definition: trainer.h:40
RewardBufferFrame(float reward)
Definition: trainer.h:77
A sampler takes the output of the model, and outputs an action accordingly.
Definition: sampler.h:19
GameUID genGameUID()
Definition: trainer.cpp:41
std::string EpisodeKey
Definition: trainer.h:32
UIDKeyStore dones_
Definition: trainer.h:122
std::shared_timed_mutex replayerRWMutex_
Definition: trainer.h:128
std::function< ag::Variant(ag::Variant, EpisodeHandle const &)> ForwardFunction
Definition: trainer.h:290
std::unique_ptr< BaseSampler > sampler_
Definition: trainer.h:282
std::vector< SampleOutput > sample(RandomGenerator &g, uint32_t num=1)
Definition: trainer.h:308
bool isTrain() const
Definition: trainer.h:209
The TorchCraftAI training library.
Definition: batcher.cpp:15
Stub base class for replay buffer frames.
Definition: trainer.h:69
std::ostream & operator<<(std::ostream &os, EpisodeHandle const &handle)
Definition: trainer.cpp:379
std::unordered_map< GameUID, std::unordered_map< EpisodeKey, Episode >> Store
Definition: trainer.h:91
std::mutex modelWriteMutex_
Definition: trainer.h:279
GameUID gameID
Definition: trainer.h:53
virtual void stepFrame(GameUID const &, EpisodeKey const &, ReplayBuffer::Episode &)
Definition: trainer.h:267
std::vector< std::shared_ptr< ReplayBufferFrame >> Episode
Definition: trainer.h:89
virtual void stepGame(GameUID const &game)
Definition: trainer.h:271
ReplayBuffer replayer_
Definition: trainer.h:276
bool isDone() const
Definition: trainer.h:225
Stores an unordered_map[GameUID] = unordered_map[EpisodeKey, Episode]. All the public functions here should b...
Definition: trainer.h:87
void save(Archive &ar) const
Definition: trainer.h:346
const std::string kSigmaKey
Definition: trainer.h:39
Iter select_randomly(Iter start, Iter end, RandomGenerator &g)
Definition: rand.h:78
std::shared_timed_mutex activeMapMutex_
Definition: trainer.h:280