TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
batcher.h
/*
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once
#include "prioritymutex.h"
#include <future>
#include <glog/logging.h>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <thread>

#include <autogradpp/autograd.h>
#include <torch/torch.h>

namespace cpid {

class AsyncBatcher {
 public:
  /** Construct a batcher
   * @param model is the model used for forwarding
   * @param batchsize is the maximal size of a batch. A forward will occur when
   * that many inputs have been collected. If the consumer waits for more than
   * 200ms for a batch, it will try to forward an incomplete batch
   * @param padValue This is the value used to pad the inputs to the same size
   * @param stripOutput when true, any negative value in the output tensors
   * will be masked out
   * @param stripValue the value used to detect the entries to strip from the
   * output
   */
  AsyncBatcher(
      ag::Container model,
      int batchsize,
      int padValue = -1,
      bool stripOutput = true,
      double stripValue = -1.);
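
  // Example (illustrative sketch; `myModel` stands in for any ag::Container):
  //   cpid::AsyncBatcher batcher(myModel, /*batchsize=*/32);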

  /** This function queues up the state for a forward. It blocks until the
   * batch is full, then executes a forward and returns the reply
   * corresponding to this state.
   * After a given timeout, the forward will be run anyway, even if the batch
   * is not full.
   * WARNING: this function only executes the forward WITHOUT gradient
   * (and ignores any torch grad guard)
   */
  virtual ag::Variant batchedForward(ag::Variant state);
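
  // Example (illustrative; `state` is this caller's input tensor):
  //   ag::Variant reply = batcher.batchedForward(ag::Variant(state));
  //   // `reply` is the slice of the batched output for this state only.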

  virtual ~AsyncBatcher();

  /** Changes the model used for forwarding. This operation has high
   * priority, but if a forward is about to be executed with the old model,
   * it may still run before the switch takes effect */
  void setModel(ag::Container newModel);

  /** Get a lock on the model. This allows updating the model while ensuring
   * that no forward is being executed */
  std::unique_lock<std::shared_mutex> lockModel();
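
  // Example (illustrative): swap in fresh weights while no forward runs.
  //   {
  //     auto lock = batcher.lockModel();
  //     // ... update the model's parameters here ...
  //   }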

  /** Given an output of the model, retrieve the replies for all the
   * elements of the batch.
   * It will mask out any negative element of the reply tensor (which allows
   * batching replies even though they don't have the same size)
   */
  virtual std::vector<ag::Variant> unBatch(const ag::Variant& out);
  virtual std::vector<ag::Variant>
  unBatch(const ag::Variant& out, bool stripOutput, double stripValue);
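
  // Example (illustrative): split a batched output into per-query replies,
  // stripping the padding that was added at batch time:
  //   std::vector<ag::Variant> replies = batcher.unBatch(out);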

  /** Given a vector of queries, create the batch that is going to be
   * passed to the model.
   * This default implementation finds the max size of each tensor across the
   * batch and resizes all the queries to that size, padding the extra space
   * with the pad value (-1 by default)
   */
  virtual ag::Variant makeBatch(const std::vector<ag::Variant>& queries);
  virtual ag::Variant makeBatch(
      const std::vector<ag::Variant>& queries,
      double padValue);
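
  // Example (illustrative): queries holding tensors of sizes [3] and [5]
  // are padded to size [5] and batched into a single [2, 5] tensor, with
  // the padded positions holding padValue:
  //   ag::Variant batch = batcher.makeBatch(queries, /*padValue=*/-1);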

  /**
   * This function should return true when the batch is ready to be consumed
   */
  virtual bool shouldConsume();

 protected:
  void startBatching(int batchSize);
  void stopBatching();

  void consumeThread();

  ag::Container model_;
  bool consumeThreadStarted_ = false;

  int batchSize_;
  int padValue_;
  bool stripOutput_;
  double stripValue_;

  std::condition_variable_any batchReadyCV_;
  std::mutex batchReadyMutex_;

  // Mutexes have to be acquired in this order:
  priority_mutex accessMutex_;
  std::shared_mutex modelMutex_;

  std::atomic_size_t querySize_; // To make TSAN happy
  std::vector<ag::Variant> queries_;
  std::vector<std::shared_ptr<std::promise<ag::Variant>>> replies_;

  std::thread consumeThread_;
  std::atomic<bool> shouldStop_{false};
};

/** A batcher that can operate on (already) batched data.
 * Should be used when features have a variable batch dimension, for
 * instance the number of units controlled.
 * More specifically, tensors with sizes [b1, ft], [b2, ft], ..., are
 * batched into a tensor of size [b1 + b2 + ..., ft].
 *
 * Contrary to AsyncBatcher, SubBatchAsyncBatcher expects input tensor
 * shapes to differ in the first dimension only, and will not pad input
 * tensors unless explicitly authorized with 'allowPadding'.
 */
class SubBatchAsyncBatcher : public AsyncBatcher {
 public:
  static constexpr const char* kBatchInfoKey = "batch_info";

  SubBatchAsyncBatcher(int batchSize, ag::Container model = nullptr);
  virtual ~SubBatchAsyncBatcher();

  std::vector<ag::Variant>
  unBatch(const ag::Variant& out, bool stripOutput, double stripValue) override;

  ag::Variant makeBatch(const std::vector<ag::Variant>& queries, double)
      override;

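  // Example (illustrative): two queries whose tensors are [4, ft] and
  // [7, ft] are concatenated into one [11, ft] tensor; the per-query sizes
  // {4, 7} travel with the batch under kBatchInfoKey so that unBatch() can
  // split the output again:
  //   ag::Variant batch = batcher.makeBatch(queries, /*padValue=*/-1);
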
  void allowPadding(bool allowPadding) {
    allowPadding_ = allowPadding;
  }

  torch::Tensor makeBatchTensors(
      std::vector<torch::Tensor> const& tensors,
      double padValue);

 protected:
  bool allowPadding_ = false;

 public:
  static std::vector<torch::Tensor> unBatchTensor(
      const torch::Tensor& out,
      std::vector<int64_t> const& batchSizes);

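  // Example (illustrative): split a [b1 + b2, ft] output back into tensors
  // of sizes [b1, ft] and [b2, ft]:
  //   auto parts = SubBatchAsyncBatcher::unBatchTensor(out, {b1, b2});
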
  static std::vector<int64_t> findBatchInfo(
      ag::Variant const& batchInfoVar,
      std::string const& variableName);

  static std::vector<torch::Tensor> forEachSubbatch(
      ag::Variant const& input,
      std::string const& inputName,
      torch::Tensor batchedInput,
      std::function<torch::Tensor(torch::Tensor)> do_fn = [](torch::Tensor t) {
        return t;
      });
};

} // namespace cpid
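
Usage sketch. The snippet below is a minimal, hypothetical example of driving
an AsyncBatcher from several actor threads; makeModel() is a stand-in for
whatever ag::Container your project builds, and the include path, shapes, and
thread count are illustrative assumptions rather than part of the cpid API.

#include <thread>
#include <vector>

#include <autogradpp/autograd.h>
#include <torch/torch.h>

#include "batcher.h"

// Hypothetical helper: assumed to build an ag::Container whose forward()
// maps a padded [batch, n] input tensor to an output of the same batch size.
ag::Container makeModel();

int main() {
  cpid::AsyncBatcher batcher(makeModel(), /*batchsize=*/4);

  std::vector<std::thread> actors;
  for (int i = 0; i < 4; ++i) {
    actors.emplace_back([&batcher, i] {
      // Variable-length input; the batcher pads all queries to a common
      // size before forwarding.
      torch::Tensor state = torch::rand({i + 1});
      // Blocks until the batch of 4 is full (or the ~200ms timeout fires),
      // then returns only this thread's slice of the batched output.
      ag::Variant reply = batcher.batchedForward(ag::Variant(state));
      (void)reply;
    });
  }
  for (auto& t : actors) {
    t.join();
  }
  return 0;
}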