TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
checkpointer.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include <autogradpp/autograd.h>
11 
12 namespace visdom {
13 class Visdom; // forward declaration only; this header refers to it through std::shared_ptr<visdom::Visdom>
14 }
15 namespace cpid {
16 
17 class Trainer; // forward declaration; this header only uses it through std::shared_ptr<Trainer>
18 class Checkpointer { // Checkpointing and metrics-reporting helper built around a shared Trainer; configured through the TORCH_ARG options below
19  using hires_clock = std::chrono::steady_clock; // monotonic clock; its time_point type is used for lastEpochStamp_
20 
21  public:
22  Checkpointer(std::shared_ptr<Trainer> trainer); // NOTE(review): single-argument constructor is not `explicit` -- confirm implicit conversion is intended
23 
24  /// Entry point to be called by trainers (presumably once per completed update -- confirm against callers)
25  void updateDone(int updateCount);
26 
27  /// Creates a checkpoint on disk; `suffix` defaults to "final"
28  void checkpointTrainer(std::string const& suffix = "final");
29  static void checkpointTrainer(
30  std::shared_ptr<Trainer> trainer,
31  std::string const& filename = "trainer_final.bin"); // static overload: takes the trainer and the target filename explicitly
32 
33  /// Returns the path where the latest model would be saved. (It is not
34  /// guaranteed that a model has been saved there yet.)
35  std::string getModelPath() const;
36 
37  /// Epoch length (in number of updates)
38  TORCH_ARG(int, epochLength) = 500;
39 
40  /// Visdom server. Give nullptr to disable plotting
41  TORCH_ARG(std::shared_ptr<visdom::Visdom>, visdom);
42 
43  /// List of metrics keys to plot
44  TORCH_ARG(std::vector<std::string>, visdomKeys);
45 
46  /**
47  * If true, the visdom visualization will happen at the end of the epoch, and
48  * will print the mean of the parameters during that epoch. Otherwise, it will
49  * plot the last value of the parameters, at the defined frequency
50  */
51  TORCH_ARG(bool, visdomOnEpoch) = true;
52 
53  /**
54  * If visdomOnEpoch = false, this is the frequency at which visdom plots are
55  * updated. NOTE(review): the unit and the meaning of the default -1 are not visible in this header
56  */
57  TORCH_ARG(int, visdomPlotFreq) = -1;
58 
59  /// Where to save everything
60  std::string checkpointPath_; // NOTE(review): declared by hand in the public section rather than via TORCH_ARG
61  Checkpointer& checkpointPath(std::string const& path); // setter; returns Checkpointer& (presumably *this for chaining -- confirm)
62  std::string const& checkpointPath() const; // getter; returns a const reference
63 
64  /// Metric used to assess the performance of a model
65  /// Disables performance based checkpoints if empty
66  TORCH_ARG(std::string, compareMetric) = "";
67 
68  /// If true, print the mean of the metrics at each epoch
69  TORCH_ARG(bool, printMetricsSummary) = true;
70 
71  /// If true, the metrics are aggregated over all workers
72  TORCH_ARG(bool, aggregateMetrics) = true;
73 
74  /// If true, we clear the metrics at the end of the epoch
75  TORCH_ARG(bool, flushMetrics) = false;
76 
77  /// If true, we dump the json of the metrics at each epoch
78  TORCH_ARG(bool, dumpMetrics) = false;
79 
80  /// Choose a format for stdout metrics
84  }; // NOTE(review): the opening of enum MetricsSummaryFormat and its enumerators (header lines 81-83) are missing from this listing
85  TORCH_ARG(MetricsSummaryFormat, metricsSummaryFormat) = FORMAT_DEFAULT;
86 
87  /// If true, we reduce across nodes using the max operator instead
88  TORCH_ARG(bool, reduceMax) = true;
89 
90  using Hook = std::function<void(int)>; // callback taking one int (presumably the update count -- confirm)
91 
92  // Function to call at the end of every epoch (defaults to a no-op lambda)
93  TORCH_ARG(Hook, epochHook) = [](int) {};
94 
95  // Function to call at the end of every update (defaults to a no-op lambda)
96  TORCH_ARG(Hook, updateHook) = [](int) {};
97 
98  protected:
99  void onUpdate(int updateCount); // presumably the per-update step driven by updateDone -- confirm in the implementation
100  void onEpoch(int updateCount); // presumably the end-of-epoch step driven by updateDone -- confirm in the implementation
101  void plotVisdom(const std::vector<float>& values, int count);
102  void printSummary(
103  std::unordered_map<std::string, float> means,
104  std::unordered_map<std::string, float> mins,
105  std::unordered_map<std::string, float> maxs); // all three maps are taken by value
106  void reduceMetrics(std::vector<float>& values); // takes `values` by non-const reference, so it may modify them in place
107  std::shared_ptr<Trainer> trainer_; // shared handle to the trainer (presumably the one given to the constructor)
108 
109  std::vector<std::string> visdomLines_;
110  hires_clock::time_point lastEpochStamp_; // presumably the time of the previous epoch boundary -- confirm
111  int lastEpochUpdateNum_ = 0; // presumably the update count at the previous epoch boundary -- confirm
112 };
113 } // namespace cpid
Definition: checkpointer.h:12
Definition: visdom.h:80
MetricsSummaryFormat
Choose a format for stdout metrics.
Definition: checkpointer.h:81
std::shared_ptr< Trainer > trainer_
Definition: checkpointer.h:107
Definition: checkpointer.h:82
Definition: checkpointer.h:83
std::function< void(int)> Hook
Definition: checkpointer.h:90
std::vector< std::string > visdomLines_
Definition: checkpointer.h:109
The TorchCraftAI training library.
Definition: batcher.cpp:15
hires_clock::time_point lastEpochStamp_
Definition: checkpointer.h:110
std::string checkpointPath_
Where to save everything.
Definition: checkpointer.h:60
Definition: checkpointer.h:18