#include <autogradpp/autograd.h>

// MLP: a simple MLP of nLayers layers, with all hidden sizes the same.
TORCH_ARG(int, nLayers) = 1;
TORCH_ARG(bool, zeroLastLayer);
TORCH_ARG(std::function<decltype(torch::relu)>, nonlinearity) = torch::relu;

void reset() override;
ag::Variant forward(ag::Variant x) override;
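A sketch of how such a container might be configured, assuming TORCH_ARG generates chainable setters named after each argument and the container provides a make() factory (the GatedConv code below uses the same pattern on ag::Conv2d); any further required arguments, such as layer sizes, are omitted here:

// Hypothetical usage; the setter chaining and the forward() call shape are
// assumptions, not confirmed by the fragment above.
auto mlp = MLP().nLayers(3).zeroLastLayer(true).nonlinearity(torch::tanh).make();
auto out = mlp->forward({torch::randn({8, 32})});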
// GatedConv: the wrapped convolution is built with twice the requested
// output channels so that forward() can use one half to gate the other.
GatedConv(ag::Conv2d conv) : convBase_(conv) {
  convBase_.output_channels(convBase_.output_channels_ * 2);
}

void reset() override {
  conv_ = add(convBase_.make(), "conv_");
}

ag::Variant forward(ag::Variant inp) override {
  auto chunked = conv_->forward(inp)[0].chunk(2, 1);
  return {chunked.at(0) * chunked.at(1).sigmoid_()};
}
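This is a gated linear unit: the convolution output is split into two halves along the channel dimension, and the first half is multiplied elementwise by the sigmoid of the second. A minimal standalone sketch of that gating in plain libtorch, assuming an input tensor y with 2*C channels:

#include <torch/torch.h>

// Gate a tensor with 2*C channels, as in GatedConv::forward above.
torch::Tensor glu(torch::Tensor y) {
  auto halves = y.chunk(2, /*dim=*/1); // two halves of C channels each
  return halves[0] * halves[1].sigmoid(); // features * sigmoid(gate)
}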
// Convolutional block: nLayers convolutions from nInFeats to nOutFeats
// channels, with optional deconvolution, residual connection, batch
// normalization, bottleneck and gating.
TORCH_ARG(int, nInFeats);
TORCH_ARG(int, nOutFeats);
TORCH_ARG(std::function<decltype(torch::relu)>, nonlinearity) = torch::relu;
TORCH_ARG(bool, deconv) = false;
TORCH_ARG(int, kernelSize) = 3;
TORCH_ARG(int, stride) = 1;
TORCH_ARG(int, dilation) = 1;
TORCH_ARG(bool, residual) = true;
TORCH_ARG(bool, batchNorm) = true;
TORCH_ARG(bool, bottleNeck) = false;
TORCH_ARG(int, nLayers) = 2;
TORCH_ARG(bool, bias) = false;
TORCH_ARG(bool, gated) = false;

void reset() override;
ag::Variant forward(ag::Variant x) override;

ag::Container seq_, resample_;
void addLayer(ag::Sequential& trunk, ag::Container layer, int nOut, int id);
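A hypothetical construction of this block via the setters TORCH_ARG generates (the class name ConvBlock is an assumption; it is not visible in the fragment above):

// Hypothetical: a 64 -> 128 channel block with dilated, gated convolutions.
auto block = ConvBlock()
                 .nInFeats(64)
                 .nOutFeats(128)
                 .kernelSize(3)
                 .dilation(2)
                 .gated(true)
                 .make();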
// Encoder-decoder network built from the blocks above, with separate
// resampling paths for the trunk and the skip connections.
TORCH_ARG(at::IntList, inShape);
TORCH_ARG(int, intermSize);
TORCH_ARG(int, nOutFeats);
TORCH_ARG(std::function<decltype(torch::relu)>, nonlinearity) = torch::relu;
TORCH_ARG(int, kernelSize) = 3;
TORCH_ARG(int, stride) = 1;
TORCH_ARG(bool, residual) = true;
TORCH_ARG(bool, batchNorm) = true;
TORCH_ARG(bool, bottleNeck) = false;
TORCH_ARG(int, numBlocks) = 2;
TORCH_ARG(int, nInnerLayers) = 2;
TORCH_ARG(bool, bias) = false;
TORCH_ARG(bool, gated) = false;

void reset() override;
ag::Variant forward(ag::Variant x) override;

std::vector<ag::Container> encodingLayers_, decodingLayers_;
std::vector<ag::Container> trunkResampling_, skipResampling_;

static void add_padding_if_needed(
    ag::Sequential& module,
    std::pair<int, int> inShape,
    std::pair<int, int> targetShape);
// The name of this second helper is not shown in the fragment; add_resample
// is an assumption based on the trunk/skip resampling members above.
void add_resample(
    ag::Sequential& module,
    std::pair<int, int> inShape,
    std::pair<int, int> targetShape) const;
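Only the helpers' signatures appear here; as a sketch of the shape arithmetic a padding helper of this form presumably performs (the helper name and the clamping are assumptions):

#include <algorithm>
#include <utility>

// Hypothetical: total (height, width) padding needed to grow inShape to
// targetShape; shapes already large enough need no padding.
inline std::pair<int, int> paddingNeeded(
    std::pair<int, int> inShape,
    std::pair<int, int> targetShape) {
  return {std::max(0, targetShape.first - inShape.first),
          std::max(0, targetShape.second - inShape.second)};
}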
// Multi-head attention: separate linear projections for values, keys,
// queries and the output.
ag::Container vLinear_, kLinear_, qLinear_, oLinear_;
TORCH_ARG(int, queryDim) = 0;
TORCH_ARG(int, valueDim) = 0;
TORCH_ARG(int, hidDim) = 0;
TORCH_ARG(int, nHeads) = 0;
TORCH_ARG(int, outDim) = 0;

virtual void reset() override;
virtual ag::Variant forward(ag::Variant x) override;
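These members map onto standard multi-head attention: qLinear_, kLinear_ and vLinear_ project queries, keys and values to hidDim, attention is computed independently per head, and oLinear_ projects the concatenated heads to outDim. A minimal sketch of the per-head attention arithmetic in plain libtorch (an illustration under those assumptions, not this container's actual forward()):

#include <cmath>
#include <torch/torch.h>

// Scaled dot-product attention over nHeads heads; the projections that the
// container's qLinear_/kLinear_/vLinear_/oLinear_ would apply are omitted.
torch::Tensor multiHeadAttention(
    torch::Tensor q, // [batch, lenQ, hidDim]
    torch::Tensor k, // [batch, lenK, hidDim]
    torch::Tensor v, // [batch, lenK, hidDim]
    int64_t nHeads) {
  auto B = q.size(0), Lq = q.size(1), Lk = k.size(1);
  auto dHead = q.size(2) / nHeads;
  auto split = [&](torch::Tensor t, int64_t len) {
    // [batch, len, hidDim] -> [batch, nHeads, len, dHead]
    return t.view({B, len, nHeads, dHead}).transpose(1, 2);
  };
  auto scores = split(q, Lq).matmul(split(k, Lk).transpose(-2, -1)) /
      std::sqrt(static_cast<double>(dHead));
  auto context = scores.softmax(-1).matmul(split(v, Lk));
  // [batch, nHeads, lenQ, dHead] -> [batch, lenQ, hidDim]
  return context.transpose(1, 2).contiguous().view({B, Lq, nHeads * dHead});
}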
Cross-references from this listing:
- PadType: models.h:53
- ConcatType: models.h:112 (always concatenate input)
- UpsamplingType: models.h:117 (bilinear upsampling, fixed)
- DecodeType: models.h:122
- DilationScheme: models.h:127 (the dilation increases linearly at each layer)
- MLP (AUTOGRAD_CONTAINER_CLASS): models.h:23 (simple MLP of nLayers layers, with all hidden sizes the same)
- General utilities: assert.cpp:7