TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
optimizers.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include <autogradpp/autograd.h>
11 #include <gflags/gflags.h>
12 
13 DECLARE_string(optim); // optimizer type; presumably read by cpid::selectOptimizer (confirm in optimizers.cpp)
14 DECLARE_double(lr);
15 DECLARE_double(weight_decay);
16 DECLARE_double(momentum);
17 DECLARE_double(optim_eps);
18 
19 // adagrad. NOTE(review): "adadgrad" looks like a typo, but DECLARE must match the DEFINE (presumably in optimizers.cpp) and any --adadgrad_lr_decay users; rename all together.
20 DECLARE_double(adadgrad_lr_decay);
21 
22 // adam
23 DECLARE_double(adam_beta1);
24 DECLARE_double(adam_beta2);
25 DECLARE_bool(adam_amsgrad);
26 
27 // rmsprop
28 DECLARE_double(rmsprop_alpha);
29 DECLARE_bool(rmsprop_centered);
30 
31 // sgd
32 DECLARE_double(sgd_dampening);
33 DECLARE_bool(sgd_nesterov);
34 
35 // This header declares the optimizer flags for you. To change a default in your
36 // own script, assign it in main before parsing, so command-line values still win:
37 // FLAGS_lr = 1e-3;
38 // gflags::ParseCommandLineFlags(&argc, &argv, true);
39 
40 namespace cpid {
41 
42 /// Builds the optimizer to use for training `module`.
43 /// NOTE(review): presumably the optimizer type and hyper-parameters come from
44 /// the flags declared above (--optim, --lr, --weight_decay, ...); the actual
45 /// mapping lives in optimizers.cpp — confirm there.
46 /// @param module Module whose parameters are to be optimized.
47 /// @return The constructed optimizer.
48 ag::Optimizer selectOptimizer(std::shared_ptr<torch::nn::Module> module);
49 
50 /// Returns optimizer-related flag settings as name -> value strings.
51 /// NOTE(review): presumably intended for logging / experiment bookkeeping;
52 /// the exact set of keys is defined in optimizers.cpp — confirm there.
53 std::map<std::string, std::string> optimizerFlags();
54 
55 } // namespace cpid
The TorchCraftAI training library.
Definition: batcher.cpp:15
std::map< std::string, std::string > optimizerFlags()
Definition: optimizers.cpp:76
ag::Optimizer selectOptimizer(std::shared_ptr< torch::nn::Module > module)
Definition: optimizers.cpp:35