TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cherrypi::bos::ModelRunner Class Reference
Helper class for running BOS models.
#include <runner.h>
Public Types
  using EpisodeHandle = std::string
Public Member Functions
  ModelRunner(ag::Container model)
  virtual ~ModelRunner() = default
  Sample takeSample(State* state) const
  virtual ag::Variant makeInput(Sample const& sample) const
  ag::Variant forward(Sample const& sample, EpisodeHandle const& handle = EpisodeHandle())
  ag::Variant forward(ag::Variant input, Sample const& sample, EpisodeHandle const& = EpisodeHandle())
  void blacklistBuildOrder(std::string buildOrder)
Public Attributes
  std::shared_ptr<StaticData> staticData = nullptr
  ag::Container model
  std::unordered_map<int64_t, std::string> indexToBo
  std::string modelType
  torch::Tensor boMask
Protected Member Functions
  virtual ag::Variant modelForward(ag::Variant input)
Detailed Description
Helper class for running BOS models.
Once instantiated, a runner is valid only for the current game.
Member Typedef Documentation
using cherrypi::bos::ModelRunner::EpisodeHandle = std::string
Constructor & Destructor Documentation
cherrypi::bos::ModelRunner::ModelRunner(ag::Container model)
virtual cherrypi::bos::ModelRunner::~ModelRunner() = default
Member Function Documentation
void cherrypi::bos::ModelRunner::blacklistBuildOrder(std::string buildOrder)
ag::Variant cherrypi::bos::ModelRunner::forward(Sample const& sample, EpisodeHandle const& handle = EpisodeHandle())
ag::Variant cherrypi::bos::ModelRunner::forward(ag::Variant input, Sample const& sample, EpisodeHandle const& handle = EpisodeHandle())
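virtual ag::Variant cherrypi::bos::ModelRunner::makeInput(Sample const& sample) const
virtual ag::Variant cherrypi::bos::ModelRunner::modelForward(ag::Variant input) [protected]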
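The reference does not state how the two forward overloads relate to the virtual makeInput hook and the protected virtual modelForward hook. One plausible reading of the signatures is a template-method design: forward(sample, handle) builds the input with makeInput and delegates to forward(input, sample, handle), which calls modelForward. The sketch below illustrates that assumed structure only. It replaces Sample and ag::Variant with hypothetical stand-ins (a struct and std::vector<float>), and the classes RunnerSketch and DoublingRunner are invented for illustration; this is not the TorchCraftAI implementation.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: the real class uses cherrypi's Sample and
// autogradpp's ag::Variant; plain vectors keep this sketch self-contained.
struct Sample {
  std::vector<float> features;
};
using Variant = std::vector<float>;
using EpisodeHandle = std::string;

// Assumed structure (not taken from the TorchCraftAI sources): the
// sample-only forward() builds the model input via the virtual makeInput()
// hook and delegates to the input-taking overload, which calls the
// protected virtual modelForward() hook.
class RunnerSketch {
 public:
  virtual ~RunnerSketch() = default;

  // Hook: turn a sample into model input. Default passes features through.
  virtual Variant makeInput(Sample const& sample) const {
    return sample.features;
  }

  Variant forward(Sample const& sample,
                  EpisodeHandle const& handle = EpisodeHandle()) {
    return forward(makeInput(sample), sample, handle);
  }

  // Callers that already built an input can skip makeInput().
  Variant forward(Variant input, Sample const& /*sample*/,
                  EpisodeHandle const& /*handle*/ = EpisodeHandle()) {
    return modelForward(std::move(input));
  }

 protected:
  // Hook: run the wrapped model. Default is the identity.
  virtual Variant modelForward(Variant input) { return input; }
};

// A subclass customizes behaviour by overriding only the hooks.
class DoublingRunner : public RunnerSketch {
 protected:
  Variant modelForward(Variant input) override {
    for (auto& x : input) {
      x *= 2.0f;  // toy "model": scale every input value
    }
    return input;
  }
};
```

With these definitions, `DoublingRunner{}.forward(Sample{{1.0f, 2.5f}})` yields {2, 5}. Under this reading, a subclass only overrides makeInput or modelForward, while the public forward entry points stay fixed.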
Member Data Documentation
torch::Tensor cherrypi::bos::ModelRunner::boMask
std::unordered_map<int64_t, std::string> cherrypi::bos::ModelRunner::indexToBo |
ag::Container cherrypi::bos::ModelRunner::model |
std::string cherrypi::bos::ModelRunner::modelType |
std::shared_ptr<StaticData> cherrypi::bos::ModelRunner::staticData = nullptr |
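The reference lists blacklistBuildOrder, indexToBo and boMask but does not document how they interact. One plausible scheme, assumed here purely for illustration, is that indexToBo maps contiguous output indices 0..N-1 to build-order names, boMask holds one entry per index (1 = allowed, 0 = blacklisted), and the mask is applied when choosing among the model's build-order scores. The sketch models boMask as a std::vector<float> instead of a torch::Tensor; the struct BoMaskSketch and its bestAllowed helper are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical sketch: boMask is a std::vector<float> rather than a
// torch::Tensor, with one entry per build-order index (1 = allowed,
// 0 = blacklisted). Indices are assumed to be contiguous from 0.
struct BoMaskSketch {
  std::unordered_map<int64_t, std::string> indexToBo;
  std::vector<float> boMask;

  // indexToBo is declared first, so it is initialized before boMask.
  explicit BoMaskSketch(std::unordered_map<int64_t, std::string> names)
      : indexToBo(std::move(names)), boMask(indexToBo.size(), 1.0f) {}

  // Mark every index whose name equals buildOrder as disallowed.
  void blacklistBuildOrder(std::string const& buildOrder) {
    for (auto const& [index, name] : indexToBo) {
      if (name == buildOrder) {
        boMask.at(static_cast<std::size_t>(index)) = 0.0f;
      }
    }
  }

  // Return the highest-scoring index that is not blacklisted,
  // or -1 if every build order is masked out.
  int64_t bestAllowed(std::vector<float> const& scores) const {
    int64_t best = -1;
    float bestScore = 0.0f;
    for (std::size_t i = 0; i < scores.size() && i < boMask.size(); ++i) {
      if (boMask[i] > 0.0f && (best == -1 || scores[i] > bestScore)) {
        best = static_cast<int64_t>(i);
        bestScore = scores[i];
      }
    }
    return best;
  }
};
```

For example, with scores {0.2, 0.9, 0.5} the best index is 1; after blacklisting the name stored at index 1, the choice falls back to index 2. In a real libtorch setting the same idea would more likely be expressed with tensor operations on the model output rather than a loop.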