// TorchCraftAI -- a bot for machine learning research on StarCraft: Brood War
// File: models.h
/*
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <autogradpp/autograd.h>

/*
 * Useful helpers for neural networks expressed with Torch.
 */
15 namespace common {
16 
17 /**
18  * Simple MLP of nLayers layers, with hidden size all being the same.
19  * Optionally, we can zero the last layer, which is useful if the output is
20  * suppose to be a probability distribution since values will be uniform after
21  * softmax
22  */
24  public:
25  TORCH_ARG(int, nIn);
26  TORCH_ARG(int, nHid);
27  TORCH_ARG(int, nOut);
28  TORCH_ARG(int, nLayers) = 1;
29  TORCH_ARG(bool, zeroLastLayer);
30  // Type of relu is (at::Tensor (&)(const at::Tensor&))
31  TORCH_ARG(std::function<decltype(torch::relu)>, nonlinearity) = torch::relu;
32  ag::Container seq_;
33 
34  void reset() override;
35  ag::Variant forward(ag::Variant x) override;
36 };
37 
39  ag::Container conv_;
40  ag::Conv2d convBase_;
41  GatedConv(ag::Conv2d conv) : convBase_(conv) {
42  convBase_.output_channels(convBase_.output_channels_ * 2);
43  }
44  void reset() override {
45  conv_ = add(convBase_.make(), "conv_");
46  }
47  ag::Variant forward(ag::Variant inp) override {
48  auto chunked = conv_->forward(inp)[0].chunk(2, 1);
49  return {chunked.at(0) * chunked.at(1).sigmoid_()};
50  }
51 };
52 
53 enum class PadType {
54  Zero,
55  Reflection,
57 };
58 /**
59  * Simple convolutional block, with optional residual connection
60  * From a user perspective, the convolution parameters behave the same as if the
61  * block was a single conv layer. For example if the stride is 2, the output
62  * will be twice smaller than the input, irrespective of the number of inner
63  * layers. In practice the stride and dilation are only applied to the first
64  * layer.
65  * The block also applies padding to compensate for the kernel size and
66  * the dilation. That means that if the input is size hxw, the output will be
67  * h'xw' with h' = (h - 1)/stride + 1 and w' = (w-1)/stride + 1
68  */
70  public:
71  /// Number of feature channels in the input
72  TORCH_ARG(int, nInFeats);
73  /// Number of feature channels in the output
74  TORCH_ARG(int, nOutFeats);
75  /// Non linearity inserted between each convolution
76  // Type of relu is (at::Tensor (&)(const at::Tensor&))
77  TORCH_ARG(std::function<decltype(torch::relu)>, nonlinearity) = torch::relu;
78  /// If true, the module performs transposed convolutions instead
79  TORCH_ARG(bool, deconv) = false;
80  /// Size of the convolution kernels (we use kernelSize X kernelSize)
81  TORCH_ARG(int, kernelSize) = 3;
82  /// Stride of the convolutions
83  TORCH_ARG(int, stride) = 1;
84  /// Dilation of the convolutions
85  TORCH_ARG(int, dilation) = 1;
86  /// Add a residual convolution when true
87  TORCH_ARG(bool, residual) = true;
88  /// Add a batchNorm layers where appropriate, if true
89  TORCH_ARG(bool, batchNorm) = true;
90  /// If true, the intermediate convolutions will have 4 times less features
91  /// than the output
92  TORCH_ARG(bool, bottleNeck) = false;
93  /// Number of convolution layers
94  TORCH_ARG(int, nLayers) = 2;
95  /// Bias in the convolutions
96  TORCH_ARG(bool, bias) = false;
97  /// Whether to use gated convolutions
98  TORCH_ARG(bool, gated) = false;
99  /// How to pad
100  TORCH_ARG(PadType, padType) = PadType::Zero;
101 
102  void reset() override;
103 
104  ag::Variant forward(ag::Variant x) override;
105 
106  ag::Container seq_, resample_;
107 
108  protected:
109  void addLayer(ag::Sequential & trunk, ag::Container layer, int nOut, int id);
110 };
111 
112 enum class ConcatType {
113  None, /// No concatenation
114  Input, /// Always concatenate input
115  Mirror /// Concatenate input of mirror layer
116 };
117 enum class UpsamplingType {
118  None, /// No upsampling
119  Bilin, /// Bilinear upsampling (fixed)
120  Deconv /// Learnt upsampling (transposed convolution)
121 };
122 enum class DecodeType {
123  None, /// No decoding
124  Conv, /// Decode with convolutions
125  Deconv /// Decode with transposed convolutions
126 };
127 enum class DilationScheme {
128  None, /// No dilation
129  Linear, /// The dilation increases linearly at each layer
130  Exponential /// The dilation increases exponentially
131 };
132 AUTOGRAD_CONTAINER_CLASS(EncoderDecoder) {
133  public:
134  /// Shape of the input, given as [c,h,w], where c is the number of channels, h
135  /// is the height and w the width
136  TORCH_ARG(at::IntList, inShape);
137  /// Number of feature channels in the intermediate layers
138  TORCH_ARG(int, intermSize);
139  /// Number of feature channels in the output
140  TORCH_ARG(int, nOutFeats);
141  /// Non linearity inserted between each convolution
142  // Type of relu is (at::Tensor (&)(const at::Tensor&))
143  TORCH_ARG(std::function<decltype(torch::relu)>, nonlinearity) = torch::relu;
144  /// Strategy for concatening previous layers during decoding
145  TORCH_ARG(ConcatType, concatInput) = ConcatType::None;
146  /// Strategy for upsampling, when needed
147  TORCH_ARG(UpsamplingType, upsampling) = UpsamplingType::None;
148  /// Strategy for decoding
149  TORCH_ARG(DecodeType, decodeType) = DecodeType::None;
150  /// Strategy for dilation
151  TORCH_ARG(DilationScheme, dilationType) = DilationScheme::None;
152  /// Size of the convolution kernels (we use kernelSize X kernelSize)
153  TORCH_ARG(int, kernelSize) = 3;
154  /// Stride of the convolutions
155  TORCH_ARG(int, stride) = 1;
156  /// Add a residual convolution when true
157  TORCH_ARG(bool, residual) = true;
158  /// Add a batchNorm layers where appropriate, if true
159  TORCH_ARG(bool, batchNorm) = true;
160  /// If true, the intermediate convolutions will have 4 times less features
161  /// than the output
162  TORCH_ARG(bool, bottleNeck) = false;
163  /// Number of Convolutional blocks in the encoding (if there is decoding, it
164  /// will contain the same amount of blocks)
165  TORCH_ARG(int, numBlocks) = 2;
166  /// Number of convolution layers in each block
167  TORCH_ARG(int, nInnerLayers) = 2;
168  /// Bias in the convolutions
169  TORCH_ARG(bool, bias) = false;
170  /// Whether to use gated convolutions
171  TORCH_ARG(bool, gated) = false;
172 
173  void reset() override;
174 
175  ag::Variant forward(ag::Variant x) override;
176 
177  std::vector<ag::Container> encodingLayers_, decodingLayers_;
178  std::vector<ag::Container> trunkResampling_, skipResampling_;
179 
180  protected:
181  static void add_padding_if_needed(
182  ag::Sequential & module,
183  int size,
184  std::pair<int, int> inShape,
185  std::pair<int, int> targetShape);
186 
187  void add_resample(
188  ag::Sequential & module,
189  int size,
190  std::pair<int, int> inShape,
191  std::pair<int, int> targetShape) const;
192 };
193 
194 // Input is (Q, K, V, mask), where mask contains the valid indices
195 // Q is (bsz, numQueries, queryDim)
196 // K is (bsz, numKeys, queryDim)
197 // V is (bsz, numKeys, valueDim)
198 // mask is (bsz, numQueries, numKeys)
199 // output is (bsz, numQueries, outDim)
200 //
201 // Check the paper for details:
202 // https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
204  ag::Container vLinear_, kLinear_, qLinear_, oLinear_;
205  TORCH_ARG(int, queryDim) = 0;
206  TORCH_ARG(int, valueDim) = 0;
207  TORCH_ARG(int, hidDim) = 0;
208  TORCH_ARG(int, nHeads) = 0;
209  TORCH_ARG(int, outDim) = 0;
210  virtual void reset() override;
211  virtual ag::Variant forward(ag::Variant x) override;
212 };
213 
214 } // namespace common
// ---------------------------------------------------------------------------
// Doxygen cross-reference residue from the documentation page this file was
// extracted from (symbol tooltips); preserved here for provenance:
//   ConcatType -- models.h:112 ("No concatenation.", "Always concatenate
//   input.")
//   UpsamplingType -- models.h:117 ("Bilinear upsampling (fixed)")
//   DilationScheme -- models.h:127 ("The dilation increases linearly at each
//   layer.")
//   DecodeType -- models.h:122
//   PadType -- models.h:53
//   AUTOGRAD_CONTAINER_CLASS(MLP) -- models.h:23 ("Simple MLP of nLayers
//   layers, with hidden size all being the same.")
//   namespace common -- "General utilities." (assert.cpp:7)
// ---------------------------------------------------------------------------