TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
runner.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "models/bos/sample.h"
11 
12 #ifdef HAVE_CPID
13 #include <cpid/trainer.h>
14 #endif // HAVE_CPID
15 
16 #include <unordered_map>
17 
18 #include <autogradpp/autograd.h>
19 
20 namespace cherrypi {
21 namespace bos {
22 
23 /**
24  * Helper class for running BOS models.
25  * Once instantiated, the runner is valid for the current game only.
26  */
27 struct ModelRunner {
28 #ifdef HAVE_CPID
29  using EpisodeHandle = cpid::EpisodeHandle; // cpid builds: reuse the handle type exported by <cpid/trainer.h>
30 #else // HAVE_CPID
31  using EpisodeHandle = std::string; // stand-in handle type so signatures stay identical without cpid
32 #endif // HAVE_CPID
33 
34  std::shared_ptr<StaticData> staticData = nullptr; // shared static data; null until assigned
35 #ifdef HAVE_CPID
36  std::shared_ptr<cpid::Trainer> trainer; // trainer handle; this member exists in cpid builds only
37 #endif // HAVE_CPID
38  ag::Container model; // model container
39  std::unordered_map<int64_t, std::string> indexToBo; // integer index -> string; presumably build order names (TODO confirm)
40  std::string modelType; // model type identifier string
41  torch::Tensor boMask; // presumably a mask over build orders, cf. blacklistBuildOrder() (TODO confirm in runner.cpp)
42 
43 #ifdef HAVE_CPID
44  ModelRunner(std::shared_ptr<cpid::Trainer> trainer); // construct from a trainer (cpid builds only)
45 #endif // HAVE_CPID
46  ModelRunner(ag::Container model); // construct from a model container
47  virtual ~ModelRunner() = default; // virtual destructor: the class declares virtual methods for overriding
48 
49  Sample takeSample(State* state) const; // returns a Sample for the given game state
50  virtual ag::Variant makeInput(Sample const& sample) const; // overridable; turns a sample into an ag::Variant (presumably the model input)
51  ag::Variant forward(
52  Sample const& sample,
53  EpisodeHandle const& handle = EpisodeHandle()); // forward from a sample; handle defaults to a default-constructed EpisodeHandle
54  ag::Variant forward(
55  ag::Variant input,
56  Sample const& sample,
57  EpisodeHandle const& = EpisodeHandle()); // overload that additionally takes an explicit input Variant
58 
59  void blacklistBuildOrder(std::string buildOrder); // takes a build order by name; effect is defined in runner.cpp
60 
61  protected:
62  virtual ag::Variant modelForward(ag::Variant input); // overridable forward hook taking only the input
63 #ifdef HAVE_CPID
64  virtual ag::Variant trainerForward(ag::Variant input, EpisodeHandle const&); // overridable forward hook taking an episode handle (cpid builds only)
65 #endif // HAVE_CPID
66 };
67 
68 #ifdef HAVE_CPID
69 std::unique_ptr<ModelRunner> makeModelRunner(
70  std::shared_ptr<cpid::Trainer> trainer,
71  std::string modelType); // factory overload: returns an owning ModelRunner pointer from a cpid trainer and a model type string; declared in cpid builds only
72 #endif // HAVE_CPID
73 std::unique_ptr<ModelRunner> makeModelRunner(
74  ag::Container model,
75  std::string modelType); // factory overload: returns an owning ModelRunner pointer from a model container and a model type string; how modelType is interpreted is defined in runner.cpp
76 
77 } // namespace bos
78 } // namespace cherrypi
Game state.
Definition: state.h:42
Trainer::EpisodeHandle EpisodeHandle
Definition: trainer.h:303
ag::Container model
Definition: runner.h:38
std::string EpisodeHandle
Definition: runner.h:31
ModelRunner(ag::Container model)
Definition: runner.cpp:151
Sample takeSample(State *state) const
Definition: runner.cpp:156
virtual ag::Variant makeInput(Sample const &sample) const
Definition: runner.cpp:162
std::unordered_map< int64_t, std::string > indexToBo
Definition: runner.h:39
std::shared_ptr< StaticData > staticData
Definition: runner.h:34
virtual ~ModelRunner()=default
std::unique_ptr< ModelRunner > makeModelRunner(ag::Container model, std::string modelType)
Definition: runner.cpp:250
ag::Variant forward(Sample const &sample, EpisodeHandle const &handle=EpisodeHandle())
Definition: runner.cpp:166
void blacklistBuildOrder(std::string buildOrder)
Definition: runner.cpp:222
virtual ag::Variant modelForward(ag::Variant input)
Definition: runner.cpp:212
Helper class for running BOS models.
Definition: runner.h:27
Definition: sample.h:148
std::string modelType
Definition: runner.h:40
Main namespace for bot-related code.
Definition: areainfo.cpp:17
torch::Tensor boMask
Definition: runner.h:41