TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cpid::SubBatchAsyncBatcher Class Reference

A batcher that can operate on (already) batched data. Should be used when features have a variable batch dimension, for instance the number of units controlled. More...

#include <batcher.h>

Inherits cpid::AsyncBatcher.

Public Member Functions

 SubBatchAsyncBatcher (int batchSize, ag::Container model=nullptr)
 
 ~SubBatchAsyncBatcher ()
 
std::vector< ag::Variant > unBatch (const ag::Variant &out, bool stripOutput, double stripValue) override
 
ag::Variant makeBatch (const std::vector< ag::Variant > &queries, double) override
 
void allowPadding (bool allowPadding)
 
torch::Tensor makeBatchTensors (std::vector< torch::Tensor > const &tensors, double padValue)
 
- Public Member Functions inherited from cpid::AsyncBatcher
 AsyncBatcher (ag::Container model, int batchsize, int padValue=-1, bool stripOutput=true, double stripValue=-1.)
 Construct a batcher. More...
 
virtual ag::Variant batchedForward (ag::Variant state)
 This function queues up the state for a forward. More...
 
virtual ~AsyncBatcher ()
 
void setModel (ag::Container newModel)
 Changes the model to be used for forwarding. More...
 
std::unique_lock< std::shared_mutex > lockModel ()
 Get a lock on the model. More...
 
virtual std::vector< ag::Variant > unBatch (const ag::Variant &out)
 Given an output of the model, retrieve the replies for all the element of the batch. More...
 
virtual ag::Variant makeBatch (const std::vector< ag::Variant > &queries)
 Given a vector of queries, create the batch that is going to be passed to the model. More...
 
virtual bool shouldConsume ()
 This function should return true when the batch is ready to be consumed. More...
 

Static Public Member Functions

static std::vector< torch::Tensor > unBatchTensor (const torch::Tensor &out, std::vector< int64_t > const &batchSizes)
 
static std::vector< int64_t > findBatchInfo (ag::Variant const &batchInfoVar, std::string const &variableName)
 
static std::vector< torch::Tensor > forEachSubbatch (ag::Variant const &input, std::string const &inputName, torch::Tensor batchedInput, std::function< torch::Tensor(torch::Tensor)> do_fn=[](torch::Tensor t){return t;})
 

Static Public Attributes

static constexpr const char * kBatchInfoKey = "batch_info"
 

Protected Attributes

bool allowPadding_ = false
 
- Protected Attributes inherited from cpid::AsyncBatcher
ag::Container model_
 
bool consumeThreadStarted_ = false
 
int batchSize_
 
int padValue_
 
bool stripOutput_
 
double stripValue_
 
std::condition_variable_any batchReadyCV_
 
std::mutex batchReadyMutex_
 
priority_mutex accessMutex_
 
std::shared_mutex modelMutex_
 
std::atomic_size_t querySize_
 
std::vector< ag::Variant > queries_
 
std::vector< std::shared_ptr< std::promise< ag::Variant > > > replies_
 
std::thread consumeThread_
 
std::atomic< bool > shouldStop_ {false}
 

Additional Inherited Members

- Protected Member Functions inherited from cpid::AsyncBatcher
void startBatching (int batchSize)
 
void stopBatching ()
 
void consumeThread ()
 

Detailed Description

A batcher that can operate on (already) batched data. It should be used when features have a variable batch dimension, for instance the number of units controlled.

More specifically, tensors with sizes [b1, ft], [b2, ft], ... are batched into a single tensor of size [b1 + b2 + ..., ft].

Unlike AsyncBatcher, SubBatchAsyncBatcher expects input tensor shapes to differ only in the first dimension, and it will not pad input tensors unless explicitly authorized via allowPadding().
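
For illustration, a minimal sketch of this batching behavior (not taken from the TorchCraftAI sources). It assumes that constructing the batcher without a model, as the default argument permits, is sufficient to call makeBatchTensors() directly; the batch size and all tensor sizes below are arbitrary.

  #include <batcher.h>
  #include <torch/torch.h>
  #include <vector>

  void subBatchSketch() {
    // Hypothetical standalone use: no model attached (the constructor default),
    // batchSize chosen purely for illustration.
    cpid::SubBatchAsyncBatcher batcher(/*batchSize=*/2);

    // Two queries controlling different numbers of units (b1 = 3, b2 = 5),
    // sharing the feature dimension ft = 4.
    auto q1 = torch::rand({3, 4});
    auto q2 = torch::rand({5, 4});

    // Concatenated along the first dimension: [3 + 5, 4] = [8, 4].
    torch::Tensor batched = batcher.makeBatchTensors({q1, q2}, /*padValue=*/-1);
  }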

Constructor & Destructor Documentation

cpid::SubBatchAsyncBatcher::SubBatchAsyncBatcher (int batchSize, ag::Container model = nullptr)

cpid::SubBatchAsyncBatcher::~SubBatchAsyncBatcher ()

Member Function Documentation

void cpid::SubBatchAsyncBatcher::allowPadding (bool allowPadding)  [inline]
std::vector< int64_t > cpid::SubBatchAsyncBatcher::findBatchInfo (ag::Variant const & batchInfoVar, std::string const & variableName)  [static]
std::vector< torch::Tensor > cpid::SubBatchAsyncBatcher::forEachSubbatch (ag::Variant const & input, std::string const & inputName, torch::Tensor batchedInput, std::function< torch::Tensor(torch::Tensor) > do_fn = [](torch::Tensor t) { return t; })  [static]
ag::Variant cpid::SubBatchAsyncBatcher::makeBatch (const std::vector< ag::Variant > & queries, double padValue)  [override, virtual]

Reimplemented from cpid::AsyncBatcher.

torch::Tensor cpid::SubBatchAsyncBatcher::makeBatchTensors (std::vector< torch::Tensor > const & tensors, double padValue)
std::vector< ag::Variant > cpid::SubBatchAsyncBatcher::unBatch (const ag::Variant & out, bool stripOutput, double stripValue)  [override, virtual]

Reimplemented from cpid::AsyncBatcher.

std::vector< torch::Tensor > cpid::SubBatchAsyncBatcher::unBatchTensor (const torch::Tensor & out, std::vector< int64_t > const & batchSizes)  [static]
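
A hedged usage sketch (not from the TorchCraftAI sources) of recovering per-query outputs with this static helper. It assumes the model produced one output row per unit for two queries with 3 and 5 units respectively; the output dimension is illustrative.

  #include <batcher.h>
  #include <torch/torch.h>
  #include <cstdint>
  #include <vector>

  void unBatchSketch() {
    // Batched model output of size [b1 + b2, d] = [8, 16]; d = 16 is arbitrary.
    torch::Tensor out = torch::rand({8, 16});
    std::vector<int64_t> batchSizes = {3, 5};

    // Split back into the per-query sub-batches: [3, 16] and [5, 16].
    auto perQuery = cpid::SubBatchAsyncBatcher::unBatchTensor(out, batchSizes);
  }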

Member Data Documentation

bool cpid::SubBatchAsyncBatcher::allowPadding_ = false  [protected]

constexpr const char * cpid::SubBatchAsyncBatcher::kBatchInfoKey = "batch_info"  [static]

The documentation for this class was generated from the following files: