10 #include <common/autograd.h> 12 #include <gflags/gflags.h> 14 DECLARE_string(bos_model_type);
15 DECLARE_bool(bos_bo_input);
16 DECLARE_bool(bos_mapid_input);
17 DECLARE_bool(bos_time_input);
18 DECLARE_bool(bos_res_input);
19 DECLARE_bool(bos_tech_input);
20 DECLARE_bool(bos_ptech_input);
21 DECLARE_bool(bos_units_input);
22 DECLARE_bool(bos_fabs_input);
23 DECLARE_int32(bos_hid_dim);
24 DECLARE_int32(bos_num_layers);
25 DECLARE_string(bos_targets);
32 std::map<std::string, std::string>
modelFlags();
36 void reset()
override;
37 ag::Variant forward(ag::Variant input)
override;
42 TORCH_ARG(
int, bo_embsize) = 8;
43 TORCH_ARG(
int, mapid_embsize) = 8;
44 TORCH_ARG(
int, n_builds) = -1;
45 TORCH_ARG(
int, race_embsize) = 8;
46 TORCH_ARG(
int, resources_embsize) = 8;
47 TORCH_ARG(
int, tech_embsize) = 8;
48 TORCH_ARG(
int, ptech_embsize) = 8;
49 TORCH_ARG(
int, time_embsize) = 8;
51 void reset()
override;
67 ag::Variant forward(ag::Variant input)
override;
70 ag::Container embedM_;
71 ag::Container embedR_;
72 ag::Container embedRS_;
73 ag::Container embedT_;
74 ag::Container embedPT_;
75 ag::Container embedTM_;
76 ag::Container embedBO_;
81 TORCH_ARG(
int, bo_embsize) = 8;
82 TORCH_ARG(
int, hid_dim) = 256;
83 TORCH_ARG(
int, mapid_embsize) = 8;
84 TORCH_ARG(
int, n_builds) = -1;
85 TORCH_ARG(
int, n_unit_types) = 118 * 2;
86 TORCH_ARG(
int, ptech_embsize) = 8;
87 TORCH_ARG(
int, race_embsize) = 8;
88 TORCH_ARG(
int, resources_embsize) = 8;
89 TORCH_ARG(std::set<std::string>, target_builds) = {};
90 TORCH_ARG(
int, tech_embsize) = 8;
91 TORCH_ARG(
int, time_embsize) = 8;
92 TORCH_ARG(
bool, use_fabs) =
false;
93 TORCH_ARG(
bool, zero_units) =
false;
95 void reset()
override;
115 ag::Variant forward(ag::Variant input)
override;
118 ag::Container trunk_;
119 ag::Container linear_;
120 ag::Container vHeads_;
121 torch::Tensor masks_;
126 TORCH_ARG(
int, bo_embsize) = 8;
127 TORCH_ARG(
int, hid_dim) = 256;
128 TORCH_ARG(
int, mapid_embsize) = 8;
129 TORCH_ARG(
int, n_builds) = -1;
130 TORCH_ARG(
int, n_layers) = 3;
131 TORCH_ARG(
int, n_unit_types) = 118 * 2;
132 TORCH_ARG(
int, race_embsize) = 8;
133 TORCH_ARG(
int, resources_embsize) = 8;
134 TORCH_ARG(
int, tech_embsize) = 8;
135 TORCH_ARG(
int, ptech_embsize) = 8;
136 TORCH_ARG(
int, time_embsize) = 8;
137 TORCH_ARG(
bool, use_fabs) =
false;
138 TORCH_ARG(
bool, zero_units) =
false;
139 TORCH_ARG(std::set<std::string>, target_builds) = {};
141 void reset()
override;
161 ag::Variant forward(ag::Variant input)
override;
164 ag::Container trunk_;
166 ag::Container vHeads_;
167 torch::Tensor masks_;
172 TORCH_ARG(
int, bo_embsize) = 8;
173 TORCH_ARG(
int, hid_dim) = 256;
174 TORCH_ARG(
int, mapid_embsize) = 8;
175 TORCH_ARG(
int, n_builds) = -1;
176 TORCH_ARG(
int, n_layers) = 1;
177 TORCH_ARG(
int, n_unit_types) = 118 * 2;
178 TORCH_ARG(
int, race_embsize) = 8;
179 TORCH_ARG(
int, resources_embsize) = 8;
180 TORCH_ARG(
int, tech_embsize) = 8;
181 TORCH_ARG(
int, ptech_embsize) = 8;
182 TORCH_ARG(
int, time_embsize) = 8;
183 TORCH_ARG(
bool, use_fabs) =
false;
184 TORCH_ARG(
bool, zero_units) =
false;
185 TORCH_ARG(std::set<std::string>, target_builds) = {};
187 void reset()
override;
211 ag::Variant forward(ag::Variant input)
override;
214 ag::Container trunk_;
216 ag::Container vHeads_;
217 torch::Tensor masks_;
222 TORCH_ARG(
int, bo_embsize) = 8;
223 TORCH_ARG(std::function<decltype(torch::relu)>, cnn_nonlinearity) =
225 TORCH_ARG(
bool, deep_conv) =
false;
226 TORCH_ARG(
int, hid_dim) = 256;
227 TORCH_ARG(
int, kernel_size) = 5;
228 TORCH_ARG(
bool, map_features) =
false;
229 TORCH_ARG(
int, mapid_embsize) = 8;
230 TORCH_ARG(
int, n_builds) = -1;
231 TORCH_ARG(
int, n_layers) = 1;
232 TORCH_ARG(
int, n_unit_types) = 118 * 2;
233 TORCH_ARG(
int, ptech_embsize) = 8;
234 TORCH_ARG(
int, race_embsize) = 8;
235 TORCH_ARG(
int, resources_embsize) = 8;
236 TORCH_ARG(
int, spatial_embsize) = 128;
237 TORCH_ARG(std::set<std::string>, target_builds) = {};
238 TORCH_ARG(
int, tech_embsize) = 8;
239 TORCH_ARG(
int, time_embsize) = 8;
240 TORCH_ARG(
bool, use_fabs) =
false;
242 void reset()
override;
271 ag::Variant forward(ag::Variant input)
override;
274 ag::Container trunk_;
275 ag::Container mapConv_;
276 ag::Container convnet_;
277 ag::Container cembed_;
279 ag::Container vHeads_;
280 torch::Tensor masks_;
AUTOGRAD_CONTAINER_CLASS(IdleModel)
Definition: models.h:34
std::map< std::string, std::string > modelFlags()
Definition: models.cpp:159
Main namespace for bot-related code.
Definition: areainfo.cpp:17
ag::Container modelMakeFromCli(double dropout)
Construct a BOS module according to command-line flags.
Definition: models.cpp:84