TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cherrypi::bos::ModelRunner Struct Reference

Helper class for running BOS models.

#include <runner.h>

Public Types

using EpisodeHandle = std::string
 

Public Member Functions

 ModelRunner (ag::Container model)
 
virtual ~ModelRunner ()=default
 
Sample takeSample (State *state) const
 
virtual ag::Variant makeInput (Sample const &sample) const
 
ag::Variant forward (Sample const &sample, EpisodeHandle const &handle=EpisodeHandle())
 
ag::Variant forward (ag::Variant input, Sample const &sample, EpisodeHandle const &=EpisodeHandle())
 
void blacklistBuildOrder (std::string buildOrder)
 

Public Attributes

std::shared_ptr< StaticData > staticData = nullptr
 
ag::Container model
 
std::unordered_map< int64_t, std::string > indexToBo
 
std::string modelType
 
torch::Tensor boMask
 

Protected Member Functions

virtual ag::Variant modelForward (ag::Variant input)
 

Detailed Description

Helper class for running BOS models.

Once instantiated, the runner is valid for the current game only.

Member Typedef Documentation

using cherrypi::bos::ModelRunner::EpisodeHandle = std::string

Constructor & Destructor Documentation

cherrypi::bos::ModelRunner::ModelRunner ( ag::Container  model)
virtual cherrypi::bos::ModelRunner::~ModelRunner ( )
virtual, default

Member Function Documentation

void cherrypi::bos::ModelRunner::blacklistBuildOrder ( std::string  buildOrder)
ag::Variant cherrypi::bos::ModelRunner::forward ( Sample const &  sample,
EpisodeHandle const &  handle = EpisodeHandle() 
)
ag::Variant cherrypi::bos::ModelRunner::forward ( ag::Variant  input,
Sample const &  sample,
EpisodeHandle const &  handle = EpisodeHandle() 
)
ag::Variant cherrypi::bos::ModelRunner::makeInput ( Sample const &  sample) const
virtual
ag::Variant cherrypi::bos::ModelRunner::modelForward ( ag::Variant  input)
protected, virtual
Sample cherrypi::bos::ModelRunner::takeSample ( State * state) const

Member Data Documentation

torch::Tensor cherrypi::bos::ModelRunner::boMask
std::unordered_map<int64_t, std::string> cherrypi::bos::ModelRunner::indexToBo
ag::Container cherrypi::bos::ModelRunner::model
std::string cherrypi::bos::ModelRunner::modelType
std::shared_ptr<StaticData> cherrypi::bos::ModelRunner::staticData = nullptr

The documentation for this struct was generated from the following files: