9 #include "prioritymutex.h" 11 #include <glog/logging.h> 14 #include <shared_mutex> 17 #include <autogradpp/autograd.h> 18 #include <torch/torch.h> 37 bool stripOutput =
true,
38 double stripValue = -1.);
55 void setModel(ag::Container newModel);
59 std::unique_lock<std::shared_mutex>
lockModel();
66 virtual std::vector<ag::Variant>
unBatch(
const ag::Variant& out);
67 virtual std::vector<ag::Variant>
68 unBatch(
const ag::Variant& out,
bool stripOutput,
double stripValue);
76 virtual ag::Variant
makeBatch(
const std::vector<ag::Variant>& queries);
78 const std::vector<ag::Variant>& queries,
109 std::vector<std::shared_ptr<std::promise<ag::Variant>>>
replies_;
127 static constexpr
const char* kBatchInfoKey =
"batch_info";
132 std::vector<ag::Variant>
133 unBatch(
const ag::Variant& out,
bool stripOutput,
double stripValue)
override;
135 ag::Variant
makeBatch(
const std::vector<ag::Variant>& queries,
double)
139 allowPadding_ = allowPadding;
142 torch::Tensor makeBatchTensors(
143 std::vector<torch::Tensor>
const& tensors,
147 bool allowPadding_ =
false;
150 static std::vector<torch::Tensor> unBatchTensor(
151 const torch::Tensor& out,
152 std::vector<int64_t>
const& batchSizes);
154 static std::vector<int64_t> findBatchInfo(
155 ag::Variant
const& batchInfoVar,
156 std::string
const& variableName);
158 static std::vector<torch::Tensor> forEachSubbatch(
159 ag::Variant
const& input,
160 std::string
const& inputName,
161 torch::Tensor batchedInput,
162 std::function<torch::Tensor(torch::Tensor)> do_fn = [](torch::Tensor t) {
virtual ag::Variant batchedForward(ag::Variant state)
This function queues up the state for a forward.
Definition: batcher.cpp:63
virtual std::vector< ag::Variant > unBatch(const ag::Variant &out)
Given an output of the model, retrieve the replies for all the element of the batch.
Definition: batcher.cpp:163
void startBatching(int batchSize)
Definition: batcher.cpp:41
std::mutex batchReadyMutex_
Definition: batcher.h:101
std::shared_mutex modelMutex_
Definition: batcher.h:105
This class implements a mutex that offers some control over the priority of the waiting threads...
Definition: prioritymutex.h:49
virtual bool shouldConsume()
This function should return true when the batch is ready to be consumed.
Definition: batcher.cpp:91
std::thread consumeThread_
Definition: batcher.h:111
void consumeThread()
Definition: batcher.cpp:95
virtual ag::Variant makeBatch(const std::vector< ag::Variant > &queries)
Given a vector of queries, create the batch that is going to be passed to the model.
Definition: batcher.cpp:179
std::condition_variable_any batchReadyCV_
Definition: batcher.h:100
ag::Container model_
Definition: batcher.h:92
int padValue_
Definition: batcher.h:96
std::vector< std::shared_ptr< std::promise< ag::Variant > > > replies_
Definition: batcher.h:109
virtual ~AsyncBatcher()
Definition: batcher.cpp:37
bool stripOutput_
Definition: batcher.h:97
priority_mutex accessMutex_
Definition: batcher.h:104
void allowPadding(bool allowPadding)
Definition: batcher.h:138
std::atomic< bool > shouldStop_
Definition: batcher.h:112
void stopBatching()
Definition: batcher.cpp:54
The TorchCraftAI training library.
Definition: batcher.cpp:15
double stripValue_
Definition: batcher.h:98
AsyncBatcher(ag::Container model, int batchsize, int padValue=-1, bool stripOutput=true, double stripValue=-1.)
Construct a batcher.
Definition: batcher.cpp:17
bool consumeThreadStarted_
Definition: batcher.h:93
std::vector< ag::Variant > queries_
Definition: batcher.h:108
int batchSize_
Definition: batcher.h:95
void setModel(ag::Container newModel)
Changes the model to be used for forwarding.
Definition: batcher.cpp:183
std::atomic_size_t querySize_
Definition: batcher.h:107
std::unique_lock< std::shared_mutex > lockModel()
Get a lock on the model.
Definition: batcher.cpp:188
A batcher that can operate on (already) batched data Should be used when features have a variable bat...
Definition: batcher.h:125