10 #include <autogradpp/autograd.h> 19 using hires_clock = std::chrono::steady_clock;
25 void updateDone(
int updateCount);
28 void checkpointTrainer(std::string
const& suffix =
"final");
29 static void checkpointTrainer(
30 std::shared_ptr<Trainer> trainer,
31 std::string
const& filename =
"trainer_final.bin");
35 std::string getModelPath()
const;
38 TORCH_ARG(
int, epochLength) = 500;
41 TORCH_ARG(std::shared_ptr<visdom::Visdom>,
visdom);
44 TORCH_ARG(std::vector<std::string>, visdomKeys);
51 TORCH_ARG(
bool, visdomOnEpoch) =
true;
57 TORCH_ARG(
int, visdomPlotFreq) = -1;
62 std::string
const& checkpointPath()
const;
66 TORCH_ARG(std::string, compareMetric) =
"";
69 TORCH_ARG(
bool, printMetricsSummary) =
true;
72 TORCH_ARG(
bool, aggregateMetrics) =
true;
75 TORCH_ARG(
bool, flushMetrics) =
false;
78 TORCH_ARG(
bool, dumpMetrics) =
false;
88 TORCH_ARG(
bool, reduceMax) =
true;
90 using Hook = std::function<void(int)>;
93 TORCH_ARG(
Hook, epochHook) = [](int) {};
96 TORCH_ARG(
Hook, updateHook) = [](int) {};
99 void onUpdate(
int updateCount);
100 void onEpoch(
int updateCount);
101 void plotVisdom(
const std::vector<float>& values,
int count);
103 std::unordered_map<std::string, float> means,
104 std::unordered_map<std::string, float> mins,
105 std::unordered_map<std::string, float> maxs);
106 void reduceMetrics(std::vector<float>& values);
111 int lastEpochUpdateNum_ = 0;
Definition: checkpointer.h:12
MetricsSummaryFormat
Choose a format for stdout metrics.
Definition: checkpointer.h:81
std::shared_ptr< Trainer > trainer_
Definition: checkpointer.h:107
Definition: checkpointer.h:82
Definition: checkpointer.h:83
std::function< void(int)> Hook
Definition: checkpointer.h:90
std::vector< std::string > visdomLines_
Definition: checkpointer.h:109
The TorchCraftAI training library.
Definition: batcher.cpp:15
hires_clock::time_point lastEpochStamp_
Definition: checkpointer.h:110
std::string checkpointPath_
Where to save everything.
Definition: checkpointer.h:60
Definition: checkpointer.h:18