10 #include "models/bos/sample.h" 13 #include <cpid/trainer.h> 16 #include <unordered_map> 18 #include <autogradpp/autograd.h> 36 std::shared_ptr<cpid::Trainer> trainer;
44 ModelRunner(std::shared_ptr<cpid::Trainer> trainer);
64 virtual ag::Variant trainerForward(ag::Variant input,
EpisodeHandle const&);
70 std::shared_ptr<cpid::Trainer> trainer,
71 std::string modelType);
75 std::string modelType);
Game state.
Definition: state.h:42
Trainer::EpisodeHandle EpisodeHandle
Definition: trainer.h:303
ag::Container model
Definition: runner.h:38
std::string EpisodeHandle
Definition: runner.h:31
ModelRunner(ag::Container model)
Definition: runner.cpp:151
Sample takeSample(State *state) const
Definition: runner.cpp:156
virtual ag::Variant makeInput(Sample const &sample) const
Definition: runner.cpp:162
std::unordered_map< int64_t, std::string > indexToBo
Definition: runner.h:39
std::shared_ptr< StaticData > staticData
Definition: runner.h:34
virtual ~ModelRunner()=default
std::unique_ptr< ModelRunner > makeModelRunner(ag::Container model, std::string modelType)
Definition: runner.cpp:250
ag::Variant forward(Sample const &sample, EpisodeHandle const &handle=EpisodeHandle())
Definition: runner.cpp:166
void blacklistBuildOrder(std::string buildOrder)
Definition: runner.cpp:222
virtual ag::Variant modelForward(ag::Variant input)
Definition: runner.cpp:212
Helper class for running BOS models.
Definition: runner.h:27
std::string modelType
Definition: runner.h:40
Main namespace for bot-related code.
Definition: areainfo.cpp:17
torch::Tensor boMask
Definition: runner.h:41