10 #include <autogradpp/autograd.h>    28   TORCH_ARG(
int, nLayers) = 1;
    29   TORCH_ARG(
bool, zeroLastLayer);
    31   TORCH_ARG(std::function<decltype(torch::relu)>, nonlinearity) = torch::relu;
    34   void reset() 
override;
    35   ag::Variant forward(ag::Variant x) 
override;
    41   GatedConv(ag::Conv2d conv) : convBase_(conv) {
    42     convBase_.output_channels(convBase_.output_channels_ * 2);
    44   void reset()
 override {
    45     conv_ = add(convBase_.make(), 
"conv_");
    47   ag::Variant forward(ag::Variant inp)
 override {
    48     auto chunked = conv_->forward(inp)[0].chunk(2, 1);
    49     return {chunked.at(0) * chunked.at(1).sigmoid_()};
    72   TORCH_ARG(
int, nInFeats);
    74   TORCH_ARG(
int, nOutFeats);
    77   TORCH_ARG(std::function<decltype(torch::relu)>, nonlinearity) = torch::relu;
    79   TORCH_ARG(
bool, deconv) = 
false;
    81   TORCH_ARG(
int, kernelSize) = 3;
    83   TORCH_ARG(
int, stride) = 1;
    85   TORCH_ARG(
int, dilation) = 1;
    87   TORCH_ARG(
bool, residual) = 
true;
    89   TORCH_ARG(
bool, batchNorm) = 
true;
    92   TORCH_ARG(
bool, bottleNeck) = 
false;
    94   TORCH_ARG(
int, nLayers) = 2;
    96   TORCH_ARG(
bool, bias) = 
false;
    98   TORCH_ARG(
bool, gated) = 
false;
   102   void reset() 
override;
   104   ag::Variant forward(ag::Variant x) 
override;
   106   ag::Container seq_, resample_;
   109   void addLayer(ag::Sequential & trunk, ag::Container layer, 
int nOut, 
int id);
   136   TORCH_ARG(at::IntList, inShape);
   138   TORCH_ARG(
int, intermSize);
   140   TORCH_ARG(
int, nOutFeats);
   143   TORCH_ARG(std::function<decltype(torch::relu)>, nonlinearity) = torch::relu;
   153   TORCH_ARG(
int, kernelSize) = 3;
   155   TORCH_ARG(
int, stride) = 1;
   157   TORCH_ARG(
bool, residual) = 
true;
   159   TORCH_ARG(
bool, batchNorm) = 
true;
   162   TORCH_ARG(
bool, bottleNeck) = 
false;
   165   TORCH_ARG(
int, numBlocks) = 2;
   167   TORCH_ARG(
int, nInnerLayers) = 2;
   169   TORCH_ARG(
bool, bias) = 
false;
   171   TORCH_ARG(
bool, gated) = 
false;
   173   void reset() 
override;
   175   ag::Variant forward(ag::Variant x) 
override;
   177   std::vector<ag::Container> encodingLayers_, decodingLayers_;
   178   std::vector<ag::Container> trunkResampling_, skipResampling_;
   181   static void add_padding_if_needed(
   182       ag::Sequential & module,
   184       std::pair<int, int> inShape,
   185       std::pair<int, int> targetShape);
   188       ag::Sequential & module,
   190       std::pair<int, int> inShape,
   191       std::pair<int, int> targetShape) 
const;
   204   ag::Container vLinear_, kLinear_, qLinear_, oLinear_;
   205   TORCH_ARG(
int, queryDim) = 0;
   206   TORCH_ARG(
int, valueDim) = 0;
   207   TORCH_ARG(
int, hidDim) = 0;
   208   TORCH_ARG(
int, nHeads) = 0;
   209   TORCH_ARG(
int, outDim) = 0;
   210   virtual void reset() 
override;
   211   virtual ag::Variant forward(ag::Variant x) 
override;
 
Always concatenate input. 
ConcatType
Definition: models.h:112
UpsamplingType
Definition: models.h:117
Bilinear upsampling (fixed) 
DilationScheme
Definition: models.h:127
DecodeType
Definition: models.h:122
General utilities. 
Definition: assert.cpp:7
PadType
Definition: models.h:53
AUTOGRAD_CONTAINER_CLASS(MLP)
Simple MLP of nLayers layers, with hidden size all being the same. 
Definition: models.h:23
The dilation increases linearly at each layer.