TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
A batcher that can operate on (already) batched data. It should be used when features have a variable batch dimension, for instance the number of units controlled.
#include <batcher.h>
Inherits cpid::AsyncBatcher.
Public Member Functions
| Type | Member |
|---|---|
| | SubBatchAsyncBatcher (int batchSize, ag::Container model=nullptr) |
| | ~SubBatchAsyncBatcher () |
| std::vector< ag::Variant > | unBatch (const ag::Variant &out, bool stripOutput, double stripValue) override |
| ag::Variant | makeBatch (const std::vector< ag::Variant > &queries, double) override |
| void | allowPadding (bool allowPadding) |
| torch::Tensor | makeBatchTensors (std::vector< torch::Tensor > const &tensors, double padValue) |
Public Member Functions inherited from cpid::AsyncBatcher
| Type | Member | Description |
|---|---|---|
| | AsyncBatcher (ag::Container model, int batchsize, int padValue=-1, bool stripOutput=true, double stripValue=-1.) | Construct a batcher. |
| virtual ag::Variant | batchedForward (ag::Variant state) | This function queues up the state for a forward. |
| virtual | ~AsyncBatcher () | |
| void | setModel (ag::Container newModel) | Changes the model to be used for forwarding. |
| std::unique_lock< std::shared_mutex > | lockModel () | Get a lock on the model. |
| virtual std::vector< ag::Variant > | unBatch (const ag::Variant &out) | Given an output of the model, retrieve the replies for all the elements of the batch. |
| virtual ag::Variant | makeBatch (const std::vector< ag::Variant > &queries) | Given a vector of queries, create the batch that is going to be passed to the model. |
| virtual bool | shouldConsume () | This function should return true when the batch is ready to be consumed. |
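The table pairs setModel and lockModel with a std::shared_mutex (modelMutex_). One common way to use such a mutex, sketched below under assumptions, is for forwards to take a shared (reader) lock while model swaps take an exclusive (writer) lock. ModelHolder and its int-to-int "model" are hypothetical stand-ins for ag::Container, and whether AsyncBatcher locks its forward exactly this way is an assumption rather than something this page states. Requires C++17.

```cpp
#include <cassert>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <utility>

// Hypothetical stand-in: a "model" is just a function int -> int.
class ModelHolder {
 public:
  using Model = std::function<int(int)>;

  explicit ModelHolder(Model m) : model_(std::move(m)) {}

  // Many forwards may run concurrently: each holds a shared (reader) lock.
  int forward(int x) {
    std::shared_lock<std::shared_mutex> lock(modelMutex_);
    return model_(x);
  }

  // Swapping the model takes an exclusive (writer) lock, so it waits for
  // in-flight forwards to finish and blocks new ones meanwhile.
  void setModel(Model m) {
    std::unique_lock<std::shared_mutex> lock(modelMutex_);
    model_ = std::move(m);
  }

  // Hands the caller an exclusive lock, mirroring lockModel()'s return type.
  std::unique_lock<std::shared_mutex> lockModel() {
    return std::unique_lock<std::shared_mutex>(modelMutex_);
  }

 private:
  Model model_;
  std::shared_mutex modelMutex_;
};
```

Because the lock is returned by value, a caller can hold it for as long as needed (for example, while updating weights in place) and release it simply by letting it go out of scope.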
Static Public Member Functions
| Type | Member |
|---|---|
| static std::vector< torch::Tensor > | unBatchTensor (const torch::Tensor &out, std::vector< int64_t > const &batchSizes) |
| static std::vector< int64_t > | findBatchInfo (ag::Variant const &batchInfoVar, std::string const &variableName) |
| static std::vector< torch::Tensor > | forEachSubbatch (ag::Variant const &input, std::string const &inputName, torch::Tensor batchedInput, std::function< torch::Tensor(torch::Tensor)> do_fn=[](torch::Tensor t){return t;}) |
Static Public Attributes
| Type | Member |
|---|---|
| static constexpr const char * | kBatchInfoKey = "batch_info" |
Protected Attributes
| Type | Member |
|---|---|
| bool | allowPadding_ = false |
Protected Attributes inherited from cpid::AsyncBatcher
| Type | Member |
|---|---|
| ag::Container | model_ |
| bool | consumeThreadStarted_ = false |
| int | batchSize_ |
| int | padValue_ |
| bool | stripOutput_ |
| double | stripValue_ |
| std::condition_variable_any | batchReadyCV_ |
| std::mutex | batchReadyMutex_ |
| priority_mutex | accessMutex_ |
| std::shared_mutex | modelMutex_ |
| std::atomic_size_t | querySize_ |
| std::vector< ag::Variant > | queries_ |
| std::vector< std::shared_ptr< std::promise< ag::Variant > > > | replies_ |
| std::thread | consumeThread_ |
| std::atomic< bool > | shouldStop_ {false} |
Additional Inherited Members
Protected Member Functions inherited from cpid::AsyncBatcher
| Type | Member |
|---|---|
| void | startBatching (int batchSize) |
| void | stopBatching () |
| void | consumeThread () |
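The protected members listed above (queries_, replies_ holding std::promise objects, batchReadyCV_, consumeThread_, shouldStop_) suggest a queue-and-consume design behind batchedForward. The following is a minimal, self-contained sketch of that general pattern, not the AsyncBatcher implementation: ToyBatcher, its int queries and its batch-level model signature are hypothetical, and it fires only when a full batch is queued (or on shutdown), whereas the real class exposes shouldConsume() for that decision.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

// Toy stand-in: queries are ints, the "model" maps a batch of ints to a batch
// of ints (one output per query). All names here are hypothetical.
class ToyBatcher {
 public:
  using Model = std::function<std::vector<int>(const std::vector<int>&)>;

  ToyBatcher(Model model, std::size_t batchSize)
      : model_(std::move(model)), batchSize_(batchSize) {
    assert(batchSize_ > 0 && "batch size must be positive");
    // Start the consumer only after every member is initialized.
    consumer_ = std::thread([this] { consumeLoop(); });
  }

  ~ToyBatcher() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      stop_ = true;
    }
    cv_.notify_all();
    consumer_.join();  // pending partial batches are flushed before exit
  }

  // Queues one query, then blocks until the consumer thread has run the model
  // on a batch containing it and fulfilled the matching promise.
  int batchedForward(int query) {
    auto reply = std::make_shared<std::promise<int>>();
    std::future<int> result = reply->get_future();
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queries_.push_back(query);
      replies_.push_back(std::move(reply));
    }
    cv_.notify_one();
    return result.get();
  }

 private:
  void consumeLoop() {
    for (;;) {
      std::vector<int> batch;
      std::vector<std::shared_ptr<std::promise<int>>> replies;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        // Sleep until a full batch is queued or shutdown is requested.
        cv_.wait(lock, [this] { return stop_ || queries_.size() >= batchSize_; });
        if (queries_.empty()) return;  // shutting down with nothing pending
        batch.swap(queries_);          // take everything queued so far
        replies.swap(replies_);
      }
      // One forward pass for the whole batch, outside the lock.
      std::vector<int> outputs = model_(batch);
      assert(outputs.size() == replies.size());
      // "Unbatch": hand each caller its own element of the output.
      for (std::size_t i = 0; i < replies.size(); ++i) {
        replies[i]->set_value(outputs[i]);
      }
    }
  }

  Model model_;
  std::size_t batchSize_;
  std::mutex mutex_;
  std::condition_variable cv_;
  std::vector<int> queries_;
  std::vector<std::shared_ptr<std::promise<int>>> replies_;
  bool stop_ = false;
  std::thread consumer_;
};
```

Each caller waits on its own std::future, so callers never see each other's outputs, and the model runs outside the mutex so new queries can be queued during a forward. Error handling (for instance forwarding a model exception with set_exception) is omitted.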
Detailed Description
A batcher that can operate on (already) batched data. It should be used when features have a variable batch dimension, for instance the number of units controlled.
More specifically, tensors with sizes [b1, ft], [b2, ft], ... are batched into a tensor of size [b1 + b2 + ..., ft].
Unlike AsyncBatcher, SubBatchAsyncBatcher expects input tensor shapes to differ in the first dimension only, and it will not pad input tensors unless explicitly authorized with allowPadding().
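To make the rule that [b1, ft], [b2, ft], ... become [b1 + b2 + ..., ft] concrete, here is a sketch that uses std::vector<std::vector<float>> as a stand-in for a 2-D torch::Tensor. That substitution is an assumption made so the example builds without libtorch; with libtorch, the stacking itself would typically be torch::cat(tensors, 0). concatSubBatches is a hypothetical name. It also records each b_i, since those sizes are what a batcher needs to split the model output again.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for a 2-D tensor of shape [rows, features] (hypothetical; the real
// batcher works on torch::Tensor inside ag::Variant).
using Matrix = std::vector<std::vector<float>>;

// Stacks sub-batches of shape [b_i, ft] into one [b1 + b2 + ..., ft] matrix
// and records each b_i in batchSizes, in input order.
Matrix concatSubBatches(const std::vector<Matrix>& parts,
                        std::vector<int64_t>& batchSizes) {
  Matrix out;
  batchSizes.clear();
  std::size_t ft = 0;
  bool seenRow = false;
  for (const Matrix& part : parts) {
    for (const std::vector<float>& row : part) {
      if (!seenRow) {
        ft = row.size();
        seenRow = true;
      }
      // Without padding, only the first (batch) dimension may differ.
      assert(row.size() == ft && "feature dimension mismatch");
      out.push_back(row);
    }
    batchSizes.push_back(static_cast<int64_t>(part.size()));
  }
  return out;
}
```

The shape check uses assert, so it is only active in debug builds; a production version would report the mismatch explicitly.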
Constructor & Destructor Documentation
cpid::SubBatchAsyncBatcher::SubBatchAsyncBatcher (int batchSize, ag::Container model = nullptr)
cpid::SubBatchAsyncBatcher::~SubBatchAsyncBatcher ()
Member Function Documentation
void cpid::SubBatchAsyncBatcher::allowPadding (bool allowPadding)  [inline]
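std::vector< int64_t > cpid::SubBatchAsyncBatcher::findBatchInfo (ag::Variant const &batchInfoVar, std::string const &variableName)  [static]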
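std::vector< torch::Tensor > cpid::SubBatchAsyncBatcher::forEachSubbatch (ag::Variant const &input, std::string const &inputName, torch::Tensor batchedInput, std::function< torch::Tensor(torch::Tensor)> do_fn = [](torch::Tensor t){return t;})  [static]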
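forEachSubbatch takes a batched input and a function that defaults to the identity, and returns a vector of tensors. A plausible reading, sketched below, is that it cuts the batched input into its sub-batches and applies the function to each one. In this sketch the sub-batch sizes are passed explicitly; the real function takes an ag::Variant and an input name instead, presumably to look the sizes up (for instance via findBatchInfo), which is an assumption. Matrix and forEachSubbatchSketch are hypothetical stand-ins for torch::Tensor and the real member.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Stand-in for a 2-D tensor [rows, features] (hypothetical).
using Matrix = std::vector<std::vector<float>>;

// Cuts `batched` into consecutive sub-batches of the given sizes and applies
// doFn to each one; doFn defaults to the identity, like do_fn above.
std::vector<Matrix> forEachSubbatchSketch(
    const Matrix& batched, const std::vector<int64_t>& batchSizes,
    const std::function<Matrix(Matrix)>& doFn = [](Matrix m) { return m; }) {
  std::vector<Matrix> results;
  results.reserve(batchSizes.size());
  std::ptrdiff_t offset = 0;
  for (int64_t b : batchSizes) {
    assert(b >= 0 && offset + b <= static_cast<std::ptrdiff_t>(batched.size()));
    Matrix sub(batched.begin() + offset, batched.begin() + offset + b);
    results.push_back(doFn(std::move(sub)));
    offset += b;
  }
  return results;
}
```

A non-identity doFn is where such a helper earns its keep, for example reducing each sub-batch to a per-query summary row.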
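ag::Variant cpid::SubBatchAsyncBatcher::makeBatch (const std::vector< ag::Variant > &queries, double)  [override, virtual]
Reimplemented from cpid::AsyncBatcher.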
torch::Tensor cpid::SubBatchAsyncBatcher::makeBatchTensors (std::vector< torch::Tensor > const &tensors, double padValue)
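makeBatchTensors combines the per-query tensors and receives a padValue. The sketch below illustrates the allowPadding contract from the detailed description on 2-D stand-ins: shapes that differ beyond the first dimension are rejected unless padding is allowed, in which case shorter rows are right-padded with padValue. makeBatchSketch, Matrix, and the exact padding rule (pad the trailing dimension up to the largest size seen) are assumptions; the real method works on N-D torch::Tensor.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Stand-in for a 2-D tensor [rows, features] (hypothetical; not torch::Tensor).
using Matrix = std::vector<std::vector<float>>;

// Stacks all rows of all sub-batches. Rows whose feature length differs are
// right-padded with padValue to the longest length, but only when
// allowPadding is true; otherwise mismatched shapes are rejected.
Matrix makeBatchSketch(const std::vector<Matrix>& tensors, double padValue,
                       bool allowPadding) {
  assert(!tensors.empty() && "expected at least one sub-batch");
  bool sawRow = false;
  bool mismatch = false;
  std::size_t firstFt = 0;
  std::size_t maxFt = 0;
  for (const Matrix& t : tensors) {
    for (const std::vector<float>& row : t) {
      if (!sawRow) {
        firstFt = row.size();
        sawRow = true;
      }
      mismatch = mismatch || row.size() != firstFt;
      maxFt = std::max(maxFt, row.size());
    }
  }
  if (mismatch && !allowPadding) {
    throw std::invalid_argument(
        "sub-batches differ beyond the first dimension and padding is off");
  }
  Matrix out;
  for (const Matrix& t : tensors) {
    for (const std::vector<float>& row : t) {
      std::vector<float> padded = row;
      // No-op when the row is already maxFt long.
      padded.resize(maxFt, static_cast<float>(padValue));
      out.push_back(std::move(padded));
    }
  }
  return out;
}
```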
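std::vector< ag::Variant > cpid::SubBatchAsyncBatcher::unBatch (const ag::Variant &out, bool stripOutput, double stripValue)  [override, virtual]
Reimplemented from cpid::AsyncBatcher.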
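std::vector< torch::Tensor > cpid::SubBatchAsyncBatcher::unBatchTensor (const torch::Tensor &out, std::vector< int64_t > const &batchSizes)  [static]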
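unBatchTensor takes a batched output and the list of sub-batch sizes, and returns a vector of tensors, one per entry of batchSizes. A minimal sketch of that split on a 2-D stand-in follows; unBatchSketch and Matrix are hypothetical, and with libtorch the same split would likely be done with Tensor::split_with_sizes along dimension 0.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for a 2-D tensor [rows, features] (hypothetical).
using Matrix = std::vector<std::vector<float>>;

// Splits an output of shape [b1 + b2 + ..., ft] back into consecutive pieces
// of shape [b_i, ft], in the order given by batchSizes.
std::vector<Matrix> unBatchSketch(const Matrix& out,
                                  const std::vector<int64_t>& batchSizes) {
  std::vector<Matrix> pieces;
  pieces.reserve(batchSizes.size());
  std::ptrdiff_t offset = 0;
  for (int64_t b : batchSizes) {
    assert(b >= 0 && offset + b <= static_cast<std::ptrdiff_t>(out.size()));
    auto first = out.begin() + offset;
    pieces.emplace_back(first, first + b);  // rows [offset, offset + b)
    offset += b;
  }
  assert(offset == static_cast<std::ptrdiff_t>(out.size()) &&
         "batch sizes must cover every row");
  return pieces;
}
```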
Member Data Documentation
bool cpid::SubBatchAsyncBatcher::allowPadding_ = false  [protected]
constexpr const char* cpid::SubBatchAsyncBatcher::kBatchInfoKey = "batch_info"  [static]