TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
A batcher that can operate on (already) batched data. It should be used when features have a variable batch dimension, for instance the number of units controlled.
#include <batcher.h>
Inherits cpid::AsyncBatcher.
Public Member Functions
SubBatchAsyncBatcher(int batchSize, ag::Container model = nullptr)
~SubBatchAsyncBatcher()
std::vector<ag::Variant> unBatch(const ag::Variant &out, bool stripOutput, double stripValue) override
ag::Variant makeBatch(const std::vector<ag::Variant> &queries, double) override
void allowPadding(bool allowPadding)
torch::Tensor makeBatchTensors(std::vector<torch::Tensor> const &tensors, double padValue)
Public Member Functions inherited from cpid::AsyncBatcher
AsyncBatcher(ag::Container model, int batchsize, int padValue = -1, bool stripOutput = true, double stripValue = -1.)
Construct a batcher.
virtual ag::Variant batchedForward(ag::Variant state)
This function queues up the state for a forward.
virtual ~AsyncBatcher()
void setModel(ag::Container newModel)
Changes the model to be used for forwarding.
std::unique_lock<std::shared_mutex> lockModel()
Get a lock on the model.
virtual std::vector<ag::Variant> unBatch(const ag::Variant &out)
Given an output of the model, retrieve the replies for all the elements of the batch.
virtual ag::Variant makeBatch(const std::vector<ag::Variant> &queries)
Given a vector of queries, create the batch that is going to be passed to the model.
virtual bool shouldConsume()
This function should return true when the batch is ready to be consumed.
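lockModel() returns a std::unique_lock on a std::shared_mutex, and setModel() swaps the model. Together these suggest a reader/writer discipline: forwards share the model and swaps get exclusive access. The sketch below shows that pattern under this assumption. ModelHolder, forward() and the std::function "model" are hypothetical stand-ins for ag::Container and are not part of cpid.

```cpp
#include <cassert>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <utility>

// Hypothetical reader/writer split around a swappable model: forwards take a
// shared lock, while setModel() and lockModel() take an exclusive lock.
class ModelHolder {
 public:
  using Model = std::function<int(int)>;

  explicit ModelHolder(Model m) : model_(std::move(m)) {}

  // Many threads may forward concurrently under the shared lock.
  int forward(int x) {
    std::shared_lock<std::shared_mutex> lock(modelMutex_);
    assert(model_ && "no model set");
    return model_(x);
  }

  // Swapping the model waits for in-flight forwards to finish.
  void setModel(Model m) {
    std::unique_lock<std::shared_mutex> lock(modelMutex_);
    model_ = std::move(m);
  }

  // Same return type as lockModel(): the caller holds exclusive access
  // until the returned lock goes out of scope.
  std::unique_lock<std::shared_mutex> lockModel() {
    return std::unique_lock<std::shared_mutex>(modelMutex_);
  }

 private:
  std::shared_mutex modelMutex_;
  Model model_;
};
```

In this sketch, calling forward() from the thread that holds the lock returned by lockModel() would deadlock. The exclusive lock is meant for mutating the model while no forward is running.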
Static Public Member Functions
static std::vector<torch::Tensor> unBatchTensor(const torch::Tensor &out, std::vector<int64_t> const &batchSizes)
static std::vector<int64_t> findBatchInfo(ag::Variant const &batchInfoVar, std::string const &variableName)
static std::vector<torch::Tensor> forEachSubbatch(ag::Variant const &input, std::string const &inputName, torch::Tensor batchedInput, std::function<torch::Tensor(torch::Tensor)> do_fn = [](torch::Tensor t) { return t; })
Static Public Attributes
static constexpr const char *kBatchInfoKey = "batch_info"
Protected Attributes
bool allowPadding_ = false
Protected Attributes inherited from cpid::AsyncBatcher
ag::Container model_
bool consumeThreadStarted_ = false
int batchSize_
int padValue_
bool stripOutput_
double stripValue_
std::condition_variable_any batchReadyCV_
std::mutex batchReadyMutex_
priority_mutex accessMutex_
std::shared_mutex modelMutex_
std::atomic_size_t querySize_
std::vector<ag::Variant> queries_
std::vector<std::shared_ptr<std::promise<ag::Variant>>> replies_
std::thread consumeThread_
std::atomic<bool> shouldStop_ {false}
Additional Inherited Members
Protected Member Functions inherited from cpid::AsyncBatcher
void startBatching(int batchSize)
void stopBatching()
void consumeThread()
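batchedForward(), shouldConsume(), consumeThread() and the queries_/replies_ members together describe a queue-then-batch design. Each caller's state is queued alongside a std::promise, and a consumer thread forwards the whole batch once it is ready. The sketch below reproduces that flow in a single thread so it can be checked deterministically.

The sketch rests on several assumptions:
- ToyBatcher, enqueue() and consume() are hypothetical names.
- Integers stand in for ag::Variant, and doubling stands in for the model.
- Readiness here simply means "batchSize queries queued".

The real batchedForward() returns an ag::Variant rather than a future, so it presumably waits on the reply itself.

```cpp
#include <cassert>
#include <cstddef>
#include <future>
#include <memory>
#include <mutex>
#include <vector>

// Hypothetical single-threaded illustration of the queue-then-batch pattern.
class ToyBatcher {
 public:
  explicit ToyBatcher(std::size_t batchSize) : batchSize_(batchSize) {}

  // Analogue of batchedForward(): queue the state, hand back a future reply.
  std::future<int> enqueue(int state) {
    std::lock_guard<std::mutex> lock(mutex_);
    queries_.push_back(state);
    replies_.push_back(std::make_shared<std::promise<int>>());
    return replies_.back()->get_future();
  }

  // Analogue of shouldConsume(): ready once batchSize queries are queued.
  bool shouldConsume() {
    std::lock_guard<std::mutex> lock(mutex_);
    return queries_.size() >= batchSize_;
  }

  // Stand-in for the consumer thread's work: "forward" every queued query
  // in one pass (the toy model doubles its input) and fulfill each promise.
  void consume() {
    std::lock_guard<std::mutex> lock(mutex_);
    assert(queries_.size() == replies_.size());
    for (std::size_t i = 0; i < queries_.size(); ++i) {
      replies_[i]->set_value(queries_[i] * 2);
    }
    queries_.clear();
    replies_.clear();
  }

 private:
  std::size_t batchSize_;
  std::mutex mutex_;
  std::vector<int> queries_;
  std::vector<std::shared_ptr<std::promise<int>>> replies_;
};
```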
Detailed Description
A batcher that can operate on (already) batched data. It should be used when features have a variable batch dimension, for instance the number of units controlled.
More specifically, tensors with sizes [b1, ft], [b2, ft], ... are batched into a single tensor of size [b1 + b2 + ..., ft].
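The concatenation rule above can be sketched without libtorch. Mat2D and concatFirstDim are hypothetical: a row-major buffer stands in for torch::Tensor. The sketch also records the sub-batch sizes b1, b2, ... while batching, since an output has to be split back per query later (compare the batchSizes parameter of unBatchTensor()).

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a 2-D torch::Tensor of shape [rows, cols].
struct Mat2D {
  int64_t rows = 0;
  int64_t cols = 0;
  std::vector<float> data;  // row-major, rows * cols elements
};

// Stacks [b1, ft], [b2, ft], ... into [b1 + b2 + ..., ft] and records each
// input's first dimension in batchSizes so the result can be split again.
Mat2D concatFirstDim(const std::vector<Mat2D>& inputs,
                     std::vector<int64_t>& batchSizes) {
  Mat2D out;
  batchSizes.clear();
  if (inputs.empty()) {
    return out;
  }
  out.cols = inputs.front().cols;
  for (const Mat2D& m : inputs) {
    assert(static_cast<int64_t>(m.data.size()) == m.rows * m.cols);
    if (m.cols != out.cols) {
      // Only the first (sub-batch) dimension is allowed to differ.
      throw std::invalid_argument("feature dimensions differ");
    }
    // Row-major layout: appending whole buffers concatenates along dim 0.
    out.data.insert(out.data.end(), m.data.begin(), m.data.end());
    out.rows += m.rows;
    batchSizes.push_back(m.rows);
  }
  return out;
}
```

For example, inputs of shape [2, 3] and [1, 3] yield a [3, 3] batch with recorded sizes {2, 1}.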
Unlike AsyncBatcher, SubBatchAsyncBatcher expects the shapes of the input tensors to differ in the first dimension only, and it will not pad input tensors unless explicitly authorized with allowPadding.
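One plausible reading of allowPadding is sketched below. It assumes that padding means right-padding the non-first dimension to the widest input with padValue before concatenating; the real makeBatchTensors() may pad differently. padAndConcat and Mat2D are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a 2-D torch::Tensor of shape [rows, cols].
struct Mat2D {
  int64_t rows = 0;
  int64_t cols = 0;
  std::vector<float> data;  // row-major, rows * cols elements
};

// Assumed padding policy: right-pad every row to the widest input with
// padValue, then stack along dim 0. Without allowPadding, any difference
// outside the first dimension is an error.
Mat2D padAndConcat(const std::vector<Mat2D>& inputs, bool allowPadding,
                   float padValue) {
  Mat2D out;
  for (const Mat2D& m : inputs) {
    out.cols = std::max(out.cols, m.cols);
  }
  for (const Mat2D& m : inputs) {
    assert(static_cast<int64_t>(m.data.size()) == m.rows * m.cols);
    if (m.cols != out.cols && !allowPadding) {
      throw std::invalid_argument("shapes differ beyond the first dimension");
    }
    for (int64_t r = 0; r < m.rows; ++r) {
      auto rowBegin = m.data.begin() + r * m.cols;
      out.data.insert(out.data.end(), rowBegin, rowBegin + m.cols);
      // Fill the missing trailing columns of this row with padValue.
      out.data.insert(out.data.end(),
                      static_cast<std::size_t>(out.cols - m.cols), padValue);
    }
    out.rows += m.rows;
  }
  return out;
}
```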
Constructor & Destructor Documentation
cpid::SubBatchAsyncBatcher::SubBatchAsyncBatcher(int batchSize, ag::Container model = nullptr)
cpid::SubBatchAsyncBatcher::~SubBatchAsyncBatcher()
Member Function Documentation
void cpid::SubBatchAsyncBatcher::allowPadding(bool allowPadding) [inline]
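std::vector<int64_t> cpid::SubBatchAsyncBatcher::findBatchInfo(ag::Variant const &batchInfoVar, std::string const &variableName) [static]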
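The kBatchInfoKey constant ("batch_info") suggests that batching stores the sub-batch sizes of each variable inside the batched ag::Variant, and that findBatchInfo() retrieves them by variable name. The sketch below assumes that layout and models the dictionary with std::map. BatchedDict and findBatchInfoSketch are hypothetical, and the real function's error handling is not documented here.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Same key string as SubBatchAsyncBatcher::kBatchInfoKey.
constexpr const char* kBatchInfoKey = "batch_info";

// Hypothetical stand-in for a batched ag::Variant dictionary. Only the
// kBatchInfoKey entry (variable name -> sub-batch sizes) is modelled.
using BatchInfo = std::map<std::string, std::vector<int64_t>>;
using BatchedDict = std::map<std::string, BatchInfo>;

// Returns the sub-batch sizes recorded for variableName.
std::vector<int64_t> findBatchInfoSketch(const BatchedDict& batch,
                                         const std::string& variableName) {
  auto info = batch.find(kBatchInfoKey);
  if (info == batch.end()) {
    throw std::out_of_range("batch carries no batch info");
  }
  auto sizes = info->second.find(variableName);
  if (sizes == info->second.end()) {
    throw std::out_of_range("no batch info for " + variableName);
  }
  return sizes->second;
}
```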
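std::vector<torch::Tensor> cpid::SubBatchAsyncBatcher::forEachSubbatch(ag::Variant const &input, std::string const &inputName, torch::Tensor batchedInput, std::function<torch::Tensor(torch::Tensor)> do_fn = [](torch::Tensor t) { return t; }) [static]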
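forEachSubbatch() takes a batched tensor plus a per-sub-batch function that defaults to the identity. This suggests it splits the batch back into its sub-batches, presumably the job of unBatchTensor(), and applies do_fn to each piece. A minimal sketch under that assumption follows.

In the sketch, the sizes are passed directly. In the real API they presumably come from the input Variant via findBatchInfo() and inputName. Mat2D and forEachSubbatchSketch are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for a 2-D torch::Tensor of shape [rows, cols].
struct Mat2D {
  int64_t rows = 0;
  int64_t cols = 0;
  std::vector<float> data;  // row-major, rows * cols elements
};

// Cuts a [sum(sizes), cols] batch into consecutive row blocks and applies
// doFn to each block independently (identity by default).
std::vector<Mat2D> forEachSubbatchSketch(
    const Mat2D& batched, const std::vector<int64_t>& sizes,
    const std::function<Mat2D(Mat2D)>& doFn = [](Mat2D m) { return m; }) {
  std::vector<Mat2D> results;
  int64_t offset = 0;
  for (int64_t n : sizes) {
    if (offset + n > batched.rows) {
      throw std::out_of_range("sub-batch sizes exceed the batch");
    }
    Mat2D piece{n, batched.cols, {}};
    auto first = batched.data.begin() + offset * batched.cols;
    piece.data.assign(first, first + n * batched.cols);
    results.push_back(doFn(std::move(piece)));
    offset += n;
  }
  return results;
}
```

A typical do_fn reduces each sub-batch, for instance summing or pooling over the units of one query. A single operation over the concatenated tensor cannot express that per-query reduction directly.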
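ag::Variant cpid::SubBatchAsyncBatcher::makeBatch(const std::vector<ag::Variant> &queries, double) [override], [virtual]
Reimplemented from cpid::AsyncBatcher.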
torch::Tensor cpid::SubBatchAsyncBatcher::makeBatchTensors(std::vector<torch::Tensor> const &tensors, double padValue)
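std::vector<ag::Variant> cpid::SubBatchAsyncBatcher::unBatch(const ag::Variant &out, bool stripOutput, double stripValue) [override], [virtual]
Reimplemented from cpid::AsyncBatcher.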
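std::vector<torch::Tensor> cpid::SubBatchAsyncBatcher::unBatchTensor(const torch::Tensor &out, std::vector<int64_t> const &batchSizes) [static]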
Member Data Documentation
bool cpid::SubBatchAsyncBatcher::allowPadding_ = false [protected]
constexpr const char* cpid::SubBatchAsyncBatcher::kBatchInfoKey = "batch_info" [static]