TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cpid::AsyncBatcher Class Reference

#include <batcher.h>

Inherited by cpid::SubBatchAsyncBatcher.

Public Member Functions

 AsyncBatcher (ag::Container model, int batchsize, int padValue=-1, bool stripOutput=true, double stripValue=-1.)
 Construct a batcher. More...
 
virtual ag::Variant batchedForward (ag::Variant state)
 This function queues up the state for a forward. More...
 
virtual ~AsyncBatcher ()
 
void setModel (ag::Container newModel)
 Changes the model to be used for forwarding. More...
 
std::unique_lock< std::shared_mutex > lockModel ()
 Get a lock on the model. More...
 
virtual std::vector< ag::Variant > unBatch (const ag::Variant &out)
 Given an output of the model, retrieve the replies for all the elements of the batch. More...
 
virtual std::vector< ag::Variant > unBatch (const ag::Variant &out, bool stripOutput, double stripValue)
 
virtual ag::Variant makeBatch (const std::vector< ag::Variant > &queries)
 Given a vector of queries, create the batch that is going to be passed to the model. More...
 
virtual ag::Variant makeBatch (const std::vector< ag::Variant > &queries, double padValue)
 
virtual bool shouldConsume ()
 This function should return true when the batch is ready to be consumed. More...
 

Protected Member Functions

void startBatching (int batchSize)
 
void stopBatching ()
 
void consumeThread ()
 

Protected Attributes

ag::Container model_
 
bool consumeThreadStarted_ = false
 
int batchSize_
 
int padValue_
 
bool stripOutput_
 
double stripValue_
 
std::condition_variable_any batchReadyCV_
 
std::mutex batchReadyMutex_
 
priority_mutex accessMutex_
 
std::shared_mutex modelMutex_
 
std::atomic_size_t querySize_
 
std::vector< ag::Variant > queries_
 
std::vector< std::shared_ptr< std::promise< ag::Variant > > > replies_
 
std::thread consumeThread_
 
std::atomic< bool > shouldStop_ {false}
 

Constructor & Destructor Documentation

cpid::AsyncBatcher::AsyncBatcher ( ag::Container  model,
int  batchsize,
int  padValue = -1,
bool  stripOutput = true,
double  stripValue = -1. 
)

Construct a batcher.

Parameters
model  is the model used for forwarding
batchsize  is the maximal size of a batch. A forward will occur when that many inputs have been collected. If the consumer waits for more than 200ms for a batch, it will try to forward an incomplete batch.
padValue  is the value used to pad the inputs to the same size
stripOutput  when true, any negative value in the output tensors will be masked out
cpid::AsyncBatcher::~AsyncBatcher ( )
virtual

Member Function Documentation

ag::Variant cpid::AsyncBatcher::batchedForward ( ag::Variant  state)
virtual

This function queues up the state for a forward.

This function blocks until the batch is full, then executes a forward and returns the output corresponding to this state. After a given timeout, the forward is executed anyway, even if the batch is not full. WARNING: this function only executes the forward WITHOUT gradient (and ignores any torch::grad_guard).
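The blocking behavior described above can be sketched with standard-library primitives alone. The class and the doubling "model" below are illustrative stand-ins, not the cpid API: callers deposit a query and a promise, and a consumer thread forwards a full batch (or, after a timeout, an incomplete one), mirroring the 200ms rule from the constructor documentation.

```cpp
#include <chrono>
#include <condition_variable>
#include <future>
#include <mutex>
#include <thread>
#include <vector>

// Toy stand-in for batchedForward: collects `batchSize` queries, then
// "forwards" them all at once and fulfills each caller's promise.
class ToyBatcher {
 public:
  explicit ToyBatcher(int batchSize) : batchSize_(batchSize) {
    consumer_ = std::thread([this] { consume(); });
  }
  ~ToyBatcher() {
    {
      std::lock_guard<std::mutex> g(m_);
      stop_ = true;
    }
    cv_.notify_all();
    consumer_.join();
  }

  // Blocks until the batch containing `x` has been forwarded.
  int batchedForward(int x) {
    std::promise<int> p;
    auto fut = p.get_future();
    {
      std::lock_guard<std::mutex> g(m_);
      queries_.push_back(x);
      replies_.push_back(std::move(p));
    }
    cv_.notify_all();
    return fut.get();
  }

 private:
  void consume() {
    while (true) {
      std::unique_lock<std::mutex> lk(m_);
      // Wake when the batch is full or we must stop; the timeout lets an
      // incomplete batch through, like the real batcher's 200ms rule.
      cv_.wait_for(lk, std::chrono::milliseconds(200), [this] {
        return stop_ || static_cast<int>(queries_.size()) >= batchSize_;
      });
      if (queries_.empty()) {
        if (stop_) return;
        continue;
      }
      auto qs = std::move(queries_);
      auto rs = std::move(replies_);
      queries_.clear();
      replies_.clear();
      lk.unlock(); // forward outside the lock so new queries can queue up
      for (size_t i = 0; i < qs.size(); ++i)
        rs[i].set_value(qs[i] * 2); // the toy "model": double each input
    }
  }

  int batchSize_;
  bool stop_ = false;
  std::mutex m_;
  std::condition_variable cv_;
  std::vector<int> queries_;
  std::vector<std::promise<int>> replies_;
  std::thread consumer_;
};
```

The promise/future pair is what lets each blocked caller receive exactly its own slice of the batched output.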

void cpid::AsyncBatcher::consumeThread ( )
protected
std::unique_lock< std::shared_mutex > cpid::AsyncBatcher::lockModel ( )

Get a lock on the model.

This allows updating the model while ensuring that no forward is being executed.

ag::Variant cpid::AsyncBatcher::makeBatch ( const std::vector< ag::Variant > &  queries)
virtual

Given a vector of queries, create the batch that is going to be passed to the model.

This default implementation finds the max size of each tensor across the batch and resizes all the queries to that size, padding the extra space with -1.
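The padding scheme can be illustrated on plain vectors (a simplified sketch; the real makeBatch operates on torch tensors inside ag::Variant, and the function name here is illustrative):

```cpp
#include <algorithm>
#include <vector>

// Illustrative analogue of the default makeBatch: find the max length
// across queries and pad every query to it with padValue.
std::vector<std::vector<double>> padBatch(
    const std::vector<std::vector<double>>& queries, double padValue = -1.) {
  size_t maxLen = 0;
  for (const auto& q : queries) maxLen = std::max(maxLen, q.size());
  std::vector<std::vector<double>> batch;
  for (auto q : queries) {
    q.resize(maxLen, padValue); // extra space filled with padValue
    batch.push_back(std::move(q));
  }
  return batch;
}
```

Padding to a common size is what makes it possible to stack variable-length queries into a single tensor for one forward pass.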

ag::Variant cpid::AsyncBatcher::makeBatch ( const std::vector< ag::Variant > &  queries,
double  padValue 
)
virtual

Reimplemented in cpid::SubBatchAsyncBatcher.

void cpid::AsyncBatcher::setModel ( ag::Container  newModel)

Changes the model to be used for forwarding.

This operation has high priority, but if a forward is about to be executed with the old model, it may still run before the model switch takes effect.

bool cpid::AsyncBatcher::shouldConsume ( )
virtual

This function should return true when the batch is ready to be consumed.

void cpid::AsyncBatcher::startBatching ( int  batchSize)
protected
void cpid::AsyncBatcher::stopBatching ( )
protected
std::vector< ag::Variant > cpid::AsyncBatcher::unBatch ( const ag::Variant &  out)
virtual

Given an output of the model, retrieve the replies for all the elements of the batch.

It will mask out any negative element of the reply tensor (this allows batching replies even when they don't have the same size).
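Un-batching is the inverse of the padding above; the sketch below (again on plain vectors, with an illustrative function name) splits the batched output into per-query replies and drops the pad entries, standing in for the negative-value masking the class performs on tensors:

```cpp
#include <vector>

// Illustrative analogue of the default unBatch: one reply per batch row,
// with padding entries (equal to stripValue) removed.
std::vector<std::vector<double>> unBatchStrip(
    const std::vector<std::vector<double>>& out, double stripValue = -1.) {
  std::vector<std::vector<double>> replies;
  for (const auto& row : out) {
    std::vector<double> r;
    for (double v : row)
      if (v != stripValue) r.push_back(v); // drop padding
    replies.push_back(std::move(r));
  }
  return replies;
}
```

This is why masked-out values must never be legitimate outputs: the caller cannot distinguish a real stripValue from padding.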

std::vector< ag::Variant > cpid::AsyncBatcher::unBatch ( const ag::Variant &  out,
bool  stripOutput,
double  stripValue 
)
virtual

Reimplemented in cpid::SubBatchAsyncBatcher.

Member Data Documentation

priority_mutex cpid::AsyncBatcher::accessMutex_
protected
std::condition_variable_any cpid::AsyncBatcher::batchReadyCV_
protected
std::mutex cpid::AsyncBatcher::batchReadyMutex_
protected
int cpid::AsyncBatcher::batchSize_
protected
std::thread cpid::AsyncBatcher::consumeThread_
protected
bool cpid::AsyncBatcher::consumeThreadStarted_ = false
protected
ag::Container cpid::AsyncBatcher::model_
protected
std::shared_mutex cpid::AsyncBatcher::modelMutex_
protected
int cpid::AsyncBatcher::padValue_
protected
std::vector<ag::Variant> cpid::AsyncBatcher::queries_
protected
std::atomic_size_t cpid::AsyncBatcher::querySize_
protected
std::vector<std::shared_ptr<std::promise<ag::Variant> > > cpid::AsyncBatcher::replies_
protected
std::atomic<bool> cpid::AsyncBatcher::shouldStop_ {false}
protected
bool cpid::AsyncBatcher::stripOutput_
protected
double cpid::AsyncBatcher::stripValue_
protected

The documentation for this class was generated from the following files: