TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
upcstorage.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include "upc.h"
11 
12 #include <deque>
13 
14 namespace cherrypi {
15 
16 class Module;
17 
/**
 * Polymorphic base for payloads that accompany the posting of a UPC.
 *
 * During backpropagation through the bot, the payload stored with a posted UPC
 * is handed back so that gradients for that UPC can be computed. A typical
 * payload is the output of a featurizer.
 */
struct UpcPostData {
  // Deliberately stateless: concrete payload types derive from this and add
  // their own members. The virtual destructor makes it safe to destroy a
  // derived payload through a pointer to this base.
  virtual ~UpcPostData() = default;
};
29 
30 /**
31  * Stores information about UPCs that have been posted to the board.
32  */
33 struct UpcPost {
34  public:
35  /// Game frame at time of post
36  FrameNum frame = -1;
37  /// Identifier of posted UPC
39  /// Identifier of source UPC
40  UpcId sourceId = kInvalidUpcId;
41  /// The module performing the transaction
42  Module* module = nullptr;
43  /// The actual UPC data
44  std::shared_ptr<UPCTuple> upc = nullptr;
45  /// Data attached to this transaction
46  std::shared_ptr<UpcPostData> data = nullptr;
47 
48  UpcPost() {}
50  FrameNum frame,
51  UpcId upcId,
52  UpcId sourceId,
53  Module* module,
54  std::shared_ptr<UPCTuple> upc = nullptr,
55  std::shared_ptr<UpcPostData> data = nullptr)
56  : frame(frame),
57  upcId(upcId),
58  sourceId(sourceId),
59  module(module),
60  upc(std::move(upc)),
61  data(std::move(data)) {}
62 };
63 
64 /**
65  * Stores a graph of UPC communication, including any transactional data.
66  *
67  * The storage will retain any UPCTuple and accompanying UpcPostData objects
68  * added via addUpc(). It is possible to disable permanent storage of tuples and
69  * data objects via setPersistent(), which is useful for evaluation settings in
70  * which memory is scarce.
71  */
72 class UpcStorage {
73  public:
74  UpcStorage();
75  ~UpcStorage();
76 
77  /// Controls whether UPCTuple and UpcPostData objects should be stored.
78  void setPersistent(bool persistent);
79 
80  /// Adds a UPC tuple with accompanying transaction data.
81  /// The returned ID can be used to refer to this UPC in the future.
82  UpcId addUpc(
83  FrameNum frame,
84  UpcId sourceId,
85  Module* source,
86  std::shared_ptr<UPCTuple> upc,
87  std::shared_ptr<UpcPostData> data = nullptr);
88 
89  /// Retrieve the source UPC ID for the given UPC ID.
90  UpcId sourceId(UpcId id) const;
91 
92  /// Recursively retrieve source UPC IDs up to a given module.
93  /// If upTo is nullptr, retrieve all source UPC IDs up to and including the
94  /// root UPC.
95  std::vector<UpcId> sourceIds(UpcId id, Module* upTo = nullptr) const;
96 
97  /// Retrieve the UPC Tuple for a given ID.
98  /// If the storage is not persistent, this function will return nullptr.
99  std::shared_ptr<UPCTuple> upc(UpcId id) const;
100 
101  /// Retrieve the full post data for a given ID.
102  UpcPost const* post(UpcId id) const;
103 
104  /// Retrieve all posts from a given module.
105  /// Optionally, this can be restricted to a given frame number. Note that
106  /// these pointers might be invalidated in subsequent calls to addUpc().
107  std::vector<UpcPost const*> upcPostsFrom(Module* module, FrameNum frame = -1)
108  const;
109 
110  std::vector<UpcPost> const& getAllUpcs() const {
111  return posts_;
112  }
113 
114  private:
115  /// The UPC IDs we provide are indices to this container.
116  /// However, the indices actually start at 1.
117  std::vector<UpcPost> posts_;
118 
119  bool persistent_ = true;
120 };
121 
122 } // namespace cherrypi
int FrameNum
Definition: basetypes.h:22
Stores a graph of UPC communication, including any transactional data.
Definition: upcstorage.h:72
STL namespace.
UpcPost(FrameNum frame, UpcId upcId, UpcId sourceId, Module *module, std::shared_ptr< UPCTuple > upc=nullptr, std::shared_ptr< UpcPostData > data=nullptr)
Definition: upcstorage.h:49
Base class for data attached to the posting of a UPC.
Definition: upcstorage.h:25
std::vector< UpcPost > const & getAllUpcs() const
Definition: upcstorage.h:110
virtual ~UpcPostData()=default
Main namespace for bot-related code.
Definition: areainfo.cpp:17
int UpcId
Definition: basetypes.h:23
UpcPost()
Definition: upcstorage.h:48
Interface for bot modules.
Definition: module.h:30
Stores information about UPCs that have been posted to the board.
Definition: upcstorage.h:33
UpcId constexpr kInvalidUpcId
Definition: basetypes.h:25