TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
#include <checkpointer.h>
Public Types | |
enum | MetricsSummaryFormat { FORMAT_DEFAULT, FORMAT_TORCHBOARD } |
Choose a format for stdout metrics. More... | |
using | Hook = std::function< void(int)> |
Public Member Functions | |
Checkpointer (std::shared_ptr< Trainer > trainer) | |
void | updateDone (int updateCount) |
This is the entry point to be called by trainers. More... | |
void | checkpointTrainer (std::string const &suffix="final") |
Creates a checkpoint on the disk. More... | |
std::string | getModelPath () const |
Returns the path where the latest model would be saved. More... | |
TORCH_ARG (int, epochLength) | |
Epoch length (in number of updates) More... | |
TORCH_ARG (std::shared_ptr< visdom::Visdom >, visdom) | |
Visdom server. Give nullptr to disable plotting. More... | |
TORCH_ARG (std::vector< std::string >, visdomKeys) | |
List of metrics keys to plot. More... | |
TORCH_ARG (bool, visdomOnEpoch) | |
If true, the visdom visualization will happen at the end of the epoch, and will print the mean of the parameters during that epoch. More... | |
TORCH_ARG (int, visdomPlotFreq) | |
If visdomOnEpoch = false, this is the frequency at which visdom plots are updated. More... | |
Checkpointer & | checkpointPath (std::string const &path) |
std::string const & | checkpointPath () const |
TORCH_ARG (std::string, compareMetric) | |
Metric used to assess the performance of a model. Disables performance-based checkpoints if empty. More... | |
TORCH_ARG (bool, printMetricsSummary) | |
If true, print the mean of the metrics at each epoch. More... | |
TORCH_ARG (bool, aggregateMetrics) | |
If true, the metrics are aggregated over all workers. More... | |
TORCH_ARG (bool, flushMetrics) | |
If true, we clear the metrics at the end of the epoch. More... | |
TORCH_ARG (bool, dumpMetrics) | |
If true, we dump the json of the metrics at each epoch. More... | |
TORCH_ARG (MetricsSummaryFormat, metricsSummaryFormat) | |
TORCH_ARG (bool, reduceMax) | |
If true, we reduce across nodes using the max operator instead. More... | |
TORCH_ARG (Hook, epochHook) | |
TORCH_ARG (Hook, updateHook) | |
Static Public Member Functions | |
static void | checkpointTrainer (std::shared_ptr< Trainer > trainer, std::string const &filename="trainer_final.bin") |
Public Attributes | |
std::string | checkpointPath_ |
Where to save everything. More... | |
Protected Member Functions | |
void | onUpdate (int updateCount) |
void | onEpoch (int updateCount) |
void | plotVisdom (const std::vector< float > &values, int count) |
void | printSummary (std::unordered_map< std::string, float > means, std::unordered_map< std::string, float > mins, std::unordered_map< std::string, float > maxs) |
void | reduceMetrics (std::vector< float > &values) |
Protected Attributes | |
std::shared_ptr< Trainer > | trainer_ |
std::vector< std::string > | visdomLines_ |
hires_clock::time_point | lastEpochStamp_ |
int | lastEpochUpdateNum_ = 0 |
using cpid::Checkpointer::Hook = std::function<void(int)> |
cpid::Checkpointer::Checkpointer | ( | std::shared_ptr< Trainer > | trainer | ) |
Checkpointer & cpid::Checkpointer::checkpointPath | ( | std::string const & | path | ) |
std::string const & cpid::Checkpointer::checkpointPath | ( | ) | const |
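The setter above returns `Checkpointer &`, so option calls can be chained, and the getter returns the stored value. The following is a minimal sketch of that chained-setter pattern. `SKETCH_ARG` and `SketchOptions` are hypothetical stand-ins written for this illustration, not the real `TORCH_ARG` macro or the real class, and the default values are arbitrary.

```cpp
#include <string>

// Hypothetical stand-in for an argument macro: declares a member `name_`,
// a setter that returns *this (so calls can be chained), and a const getter.
#define SKETCH_ARG(T, name)                      \
 public:                                         \
  auto name(T const& value) -> decltype(*this) { \
    name##_ = value;                             \
    return *this;                                \
  }                                              \
  T const& name() const { return name##_; }      \
  T name##_

// Hypothetical options holder; the defaults are arbitrary sketch values.
struct SketchOptions {
  SKETCH_ARG(int, epochLength) = 100;
  SKETCH_ARG(bool, printMetricsSummary) = true;
  SKETCH_ARG(std::string, checkpointPath) = "checkpoints";
};

// Configure several options in one chained expression.
inline SketchOptions makeOptions() {
  SketchOptions opts;
  opts.epochLength(500).printMetricsSummary(false).checkpointPath("/tmp/run1");
  return opts;
}
```

Calling `makeOptions().epochLength()` reads back the value set in the chain. The same call shape would apply to the options documented on this page, assuming they follow the pattern shown for `checkpointPath`.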
void cpid::Checkpointer::checkpointTrainer | ( | std::string const & | suffix = "final" | ) |
Creates a checkpoint on the disk.
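void cpid::Checkpointer::checkpointTrainer | ( | std::shared_ptr< Trainer > | trainer, |
std::string const & | filename = "trainer_final.bin" | ||
) | static |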
std::string cpid::Checkpointer::getModelPath | ( | ) | const |
Returns the path where the latest model would be saved.
(It is not guaranteed that one has been saved yet.)
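void cpid::Checkpointer::onEpoch | ( | int | updateCount | ) | protected |
void cpid::Checkpointer::onUpdate | ( | int | updateCount | ) | protected |
void cpid::Checkpointer::plotVisdom | ( | const std::vector< float > & | values, |
int | count | ||
) | protected |
void cpid::Checkpointer::printSummary | ( | std::unordered_map< std::string, float > | means, |
std::unordered_map< std::string, float > | mins, | ||
std::unordered_map< std::string, float > | maxs | ||
) | protected |
void cpid::Checkpointer::reduceMetrics | ( | std::vector< float > & | values | ) | protected |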
cpid::Checkpointer::TORCH_ARG | ( | int | , |
epochLength | |||
) |
Epoch length (in number of updates)
cpid::Checkpointer::TORCH_ARG | ( | std::shared_ptr< visdom::Visdom > | , |
visdom | |||
) |
Visdom server. Give nullptr to disable plotting.
cpid::Checkpointer::TORCH_ARG | ( | std::vector< std::string > | , |
visdomKeys | |||
) |
List of metrics keys to plot.
cpid::Checkpointer::TORCH_ARG | ( | bool | , |
visdomOnEpoch | |||
) |
If true, the visdom visualization will happen at the end of the epoch, and will print the mean of the parameters during that epoch.
Otherwise, it will plot the last value of the parameters, at the defined frequency.
cpid::Checkpointer::TORCH_ARG | ( | int | , |
visdomPlotFreq | |||
) |
If visdomOnEpoch = false, this is the frequency at which visdom plots are updated.
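The two plotting modes described above can be sketched as follows. `valueToPlot` and `shouldPlot` are hypothetical helpers written for this illustration. The sketch assumes that the plot frequency is counted in updates and that an epoch ends whenever the update count is a multiple of the epoch length; neither detail is stated on this page.

```cpp
#include <numeric>
#include <vector>

// Hypothetical helper: the value plotted for one metric key.
//   onEpoch == true  -> mean of the values recorded during the epoch
//   onEpoch == false -> most recent value
inline float valueToPlot(std::vector<float> const& values, bool onEpoch) {
  if (values.empty()) {
    return 0.0f;  // sketch choice: nothing recorded yet
  }
  if (onEpoch) {
    float sum = std::accumulate(values.begin(), values.end(), 0.0f);
    return sum / static_cast<float>(values.size());
  }
  return values.back();
}

// Hypothetical helper: should the plot be refreshed after this update?
// Assumes both epochLength and plotFreq are measured in updates.
inline bool shouldPlot(int updateCount, bool onEpoch, int epochLength,
                       int plotFreq) {
  int period = onEpoch ? epochLength : plotFreq;
  return period > 0 && updateCount % period == 0;
}
```

With `onEpoch` set, each plotted point summarizes a whole epoch; without it, points arrive more often but show only the latest sample.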
cpid::Checkpointer::TORCH_ARG | ( | std::string | , |
compareMetric | |||
) |
Metric used to assess the performance of a model. Disables performance-based checkpoints if empty.
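One way to picture a performance-based checkpoint is a tracker that remembers the best value seen so far for the chosen metric. `BestMetricTracker` is a hypothetical class written for this illustration, and the sketch assumes that a higher value is better and that the comparison uses per-epoch means; this page does not confirm either point.

```cpp
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical tracker for "save when the metric improves" checkpoints.
class BestMetricTracker {
 public:
  explicit BestMetricTracker(std::string compareMetric)
      : compareMetric_(std::move(compareMetric)) {}

  // Returns true when a "best so far" checkpoint should be written.
  bool shouldCheckpoint(
      std::unordered_map<std::string, float> const& epochMeans) {
    if (compareMetric_.empty()) {
      return false;  // empty key disables performance-based checkpoints
    }
    auto it = epochMeans.find(compareMetric_);
    if (it == epochMeans.end()) {
      return false;  // metric not reported this epoch
    }
    if (hasBest_ && it->second <= best_) {
      return false;  // no improvement (assumes higher is better)
    }
    best_ = it->second;
    hasBest_ = true;
    return true;
  }

 private:
  std::string compareMetric_;
  bool hasBest_ = false;
  float best_ = 0.0f;
};
```

A tracker built with an empty string never requests a checkpoint, which mirrors the "disables ... if empty" behavior described above.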
cpid::Checkpointer::TORCH_ARG | ( | bool | , |
printMetricsSummary | |||
) |
If true, print the mean of the metrics at each epoch.
cpid::Checkpointer::TORCH_ARG | ( | bool | , |
aggregateMetrics | |||
) |
If true, the metrics are aggregated over all workers.
cpid::Checkpointer::TORCH_ARG | ( | bool | , |
flushMetrics | |||
) |
If true, we clear the metrics at the end of the epoch.
cpid::Checkpointer::TORCH_ARG | ( | bool | , |
dumpMetrics | |||
) |
If true, we dump the json of the metrics at each epoch.
cpid::Checkpointer::TORCH_ARG | ( | MetricsSummaryFormat | , |
metricsSummaryFormat | |||
) |
cpid::Checkpointer::TORCH_ARG | ( | bool | , |
reduceMax | |||
) |
If true, we reduce across nodes using the max operator instead.
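The effect of this flag can be illustrated with a single-process sketch. `reduceAcrossNodes` is a hypothetical function written for this illustration; it takes one vector of metric values per node instead of performing real inter-node communication. The sketch assumes the default reduction is an average, which this page does not state.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical reduction: perNode[n][i] is the i-th metric value from node n.
// All inner vectors are assumed to have the same length.
inline std::vector<float> reduceAcrossNodes(
    std::vector<std::vector<float>> const& perNode, bool reduceMax) {
  if (perNode.empty()) {
    return {};
  }
  std::vector<float> out = perNode.front();
  for (std::size_t n = 1; n < perNode.size(); ++n) {
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] = reduceMax ? std::max(out[i], perNode[n][i])
                         : out[i] + perNode[n][i];
    }
  }
  if (!reduceMax) {
    // Assumed default: average the summed values over the nodes.
    for (float& v : out) {
      v /= static_cast<float>(perNode.size());
    }
  }
  return out;
}
```

A max reduction is useful when the interesting quantity is the worst or largest value on any node, such as a peak timing, rather than the typical one.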
cpid::Checkpointer::TORCH_ARG | ( | Hook | , |
epochHook | |||
) |
cpid::Checkpointer::TORCH_ARG | ( | Hook | , |
updateHook | |||
) |
void cpid::Checkpointer::updateDone | ( | int | updateCount | ) |
This is the entry point to be called by trainers.
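A trainer would call this once per completed update. The sketch below shows one plausible way such an entry point could drive the `updateHook` and `epochHook` callbacks. `SketchCheckpointer` and `runSketchTraining` are hypothetical names written for this illustration. The sketch assumes that hooks receive the update count and that an epoch ends when the count is a multiple of `epochLength`; the real class may differ.

```cpp
#include <functional>
#include <vector>

using Hook = std::function<void(int)>;  // same signature as the Hook alias

// Hypothetical stand-in for the update/epoch dispatch logic.
struct SketchCheckpointer {
  int epochLength = 3;  // arbitrary sketch value
  Hook updateHook;      // called after every update, if set
  Hook epochHook;       // called at each assumed epoch boundary, if set

  void updateDone(int updateCount) {
    if (updateHook) {
      updateHook(updateCount);
    }
    // Assumption: an epoch ends every `epochLength` updates.
    if (epochLength > 0 && updateCount % epochLength == 0 && epochHook) {
      epochHook(updateCount);
    }
  }
};

// Runs a fake training loop and returns the update counts at which the
// epoch hook fired.
inline std::vector<int> runSketchTraining(int numUpdates, int epochLength) {
  std::vector<int> epochEnds;
  SketchCheckpointer cp;
  cp.epochLength = epochLength;
  cp.epochHook = [&epochEnds](int updateCount) {
    epochEnds.push_back(updateCount);
  };
  for (int u = 1; u <= numUpdates; ++u) {
    // ... one model update would happen here ...
    cp.updateDone(u);
  }
  return epochEnds;
}
```

Keeping the hooks as `std::function<void(int)>` lets callers attach lambdas that capture their own state, such as a logger or an evaluation routine.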
std::string cpid::Checkpointer::checkpointPath_ |
Where to save everything.
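hires_clock::time_point cpid::Checkpointer::lastEpochStamp_ | protected |
int cpid::Checkpointer::lastEpochUpdateNum_ = 0 | protected |
std::shared_ptr< Trainer > cpid::Checkpointer::trainer_ | protected |
std::vector< std::string > cpid::Checkpointer::visdomLines_ | protected |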