TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cpid::Checkpointer Class Reference

#include <checkpointer.h>

Public Types

enum  MetricsSummaryFormat { FORMAT_DEFAULT, FORMAT_TORCHBOARD }
 Choose a format for stdout metrics.
 
using Hook = std::function< void(int)>
 

Public Member Functions

 Checkpointer (std::shared_ptr< Trainer > trainer)
 
void updateDone (int updateCount)
 This is the entry point to be called by trainers.
 
void checkpointTrainer (std::string const &suffix="final")
 Creates a checkpoint on the disk.
 
std::string getModelPath () const
 Returns the path where the latest model would be saved.
 
 TORCH_ARG (int, epochLength)
 Epoch length (in number of updates).
 
 TORCH_ARG (std::shared_ptr< visdom::Visdom >, visdom)
 Visdom server. Give nullptr to disable plotting.
 
 TORCH_ARG (std::vector< std::string >, visdomKeys)
 List of metrics keys to plot.
 
 TORCH_ARG (bool, visdomOnEpoch)
 If true, the visdom visualization will happen at the end of the epoch, and will print the mean of the parameters during that epoch.
 
 TORCH_ARG (int, visdomPlotFreq)
 If visdomOnEpoch = false, this is the frequency at which visdom plots are updated.
 
Checkpointer & checkpointPath (std::string const &path)
 
std::string const & checkpointPath () const
 
 TORCH_ARG (std::string, compareMetric)
 Metric used to assess the performance of a model. Disables performance-based checkpoints if empty.
 
 TORCH_ARG (bool, printMetricsSummary)
 If true, print the mean of the metrics at each epoch.
 
 TORCH_ARG (bool, aggregateMetrics)
 If true, the metrics are aggregated over all workers.
 
 TORCH_ARG (bool, flushMetrics)
 If true, we clear the metrics at the end of the epoch.
 
 TORCH_ARG (bool, dumpMetrics)
 If true, we dump the JSON of the metrics at each epoch.
 
 TORCH_ARG (MetricsSummaryFormat, metricsSummaryFormat)
 
 TORCH_ARG (bool, reduceMax)
 If true, we reduce across nodes using the max operator instead.
 
 TORCH_ARG (Hook, epochHook)
 
 TORCH_ARG (Hook, updateHook)
 

Static Public Member Functions

static void checkpointTrainer (std::shared_ptr< Trainer > trainer, std::string const &filename="trainer_final.bin")
 

Public Attributes

std::string checkpointPath_
 Where to save everything.
 

Protected Member Functions

void onUpdate (int updateCount)
 
void onEpoch (int updateCount)
 
void plotVisdom (const std::vector< float > &values, int count)
 
void printSummary (std::unordered_map< std::string, float > means, std::unordered_map< std::string, float > mins, std::unordered_map< std::string, float > maxs)
 
void reduceMetrics (std::vector< float > &values)
 

Protected Attributes

std::shared_ptr< Trainer > trainer_
 
std::vector< std::string > visdomLines_
 
hires_clock::time_point lastEpochStamp_
 
int lastEpochUpdateNum_ = 0
 

Member Typedef Documentation

using cpid::Checkpointer::Hook = std::function<void(int)>
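
A Hook is any callable that takes a single int. Judging by the names of the epochHook and updateHook arguments, that int is presumably the update count, but this page does not say so explicitly. The following is a minimal, self-contained sketch of how such a callback can be built and safely invoked; makeRecordingHook and callIfSet are hypothetical helpers for illustration, not part of cpid.

```cpp
#include <functional>
#include <vector>

// Same signature as cpid::Checkpointer::Hook.
using Hook = std::function<void(int)>;

// Hypothetical helper: builds a hook that records every value it receives.
// The caller must keep `seen` alive for as long as the hook is used.
Hook makeRecordingHook(std::vector<int>& seen) {
  return [&seen](int updateCount) { seen.push_back(updateCount); };
}

// Hypothetical helper: invokes the hook only if one was set. Calling an
// empty std::function throws std::bad_function_call, hence the check.
void callIfSet(Hook const& hook, int updateCount) {
  if (hook) {
    hook(updateCount);
  }
}
```

Because the type is a std::function, lambdas, free functions, and bound member functions are all acceptable as hooks.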

Member Enumeration Documentation

enum cpid::Checkpointer::MetricsSummaryFormat

Choose a format for stdout metrics.

Enumerator
FORMAT_DEFAULT 
FORMAT_TORCHBOARD 

Constructor & Destructor Documentation

cpid::Checkpointer::Checkpointer ( std::shared_ptr< Trainer >  trainer)

Member Function Documentation

Checkpointer & cpid::Checkpointer::checkpointPath ( std::string const &  path)

std::string const & cpid::Checkpointer::checkpointPath ( ) const

void cpid::Checkpointer::checkpointTrainer ( std::string const &  suffix = "final")

Creates a checkpoint on the disk.

void cpid::Checkpointer::checkpointTrainer ( std::shared_ptr< Trainer >  trainer,
std::string const &  filename = "trainer_final.bin" 
)
static
std::string cpid::Checkpointer::getModelPath ( ) const

Returns the path where the latest model would be saved.

(It's not guaranteed that one has been saved yet.)
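
Since the returned path may not point to an existing file yet, callers probably want to check before trying to load from it. A minimal sketch using only the standard library; modelFileExists is a hypothetical helper, not part of cpid.

```cpp
#include <fstream>
#include <string>

// Hypothetical helper: returns true if `path` can be opened for reading.
// Intended to guard a load from the path returned by getModelPath().
bool modelFileExists(std::string const& path) {
  std::ifstream in(path, std::ios::binary);
  return in.good();
}
```

A caller would then load the model only when modelFileExists(checkpointer.getModelPath()) returns true, and fall back to a fresh model otherwise.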

void cpid::Checkpointer::onEpoch ( int  updateCount)
protected
void cpid::Checkpointer::onUpdate ( int  updateCount)
protected
void cpid::Checkpointer::plotVisdom ( const std::vector< float > &  values,
int  count 
)
protected
void cpid::Checkpointer::printSummary ( std::unordered_map< std::string, float >  means,
std::unordered_map< std::string, float >  mins,
std::unordered_map< std::string, float >  maxs 
)
protected
void cpid::Checkpointer::reduceMetrics ( std::vector< float > &  values)
protected
cpid::Checkpointer::TORCH_ARG ( int  ,
epochLength   
)

Epoch length (in number of updates)
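
TORCH_ARG is PyTorch's argument macro: for each (type, name) pair it declares a member plus a getter and a chainable setter of the same name, which is why the options on this page can be set fluently. The sketch below uses a simplified stand-in macro so that it compiles without PyTorch. SKETCH_ARG, CheckpointerOptions, and the default values are assumptions for illustration only; the real macro also provides rvalue and non-const overloads.

```cpp
#include <string>

// Simplified stand-in for TORCH_ARG (illustrative assumption, not the real macro).
#define SKETCH_ARG(T, name)                    \
 public:                                       \
  auto name(T const& v) -> decltype(*this) {   \
    name##_ = v;                               \
    return *this;                              \
  }                                            \
  T const& name() const { return name##_; }    \
                                               \
 private:                                      \
  T name##_

// Hypothetical options class mirroring a few Checkpointer arguments.
// The defaults below are made up for this sketch.
class CheckpointerOptions {
  SKETCH_ARG(int, epochLength) = 500;
  SKETCH_ARG(bool, flushMetrics) = true;
  SKETCH_ARG(std::string, compareMetric) = "";
};
```

With this pattern, a call chain such as opts.epochLength(100).flushMetrics(false) sets both values, and opts.epochLength() reads the first one back.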

cpid::Checkpointer::TORCH_ARG ( std::shared_ptr< visdom::Visdom >  ,
visdom   
)

Visdom server. Give nullptr to disable plotting.

cpid::Checkpointer::TORCH_ARG ( std::vector< std::string >  ,
visdomKeys   
)

List of metrics keys to plot.

cpid::Checkpointer::TORCH_ARG ( bool  ,
visdomOnEpoch   
)

If true, the visdom visualization will happen at the end of the epoch, and will print the mean of the parameters during that epoch.

Otherwise, it will plot the last value of the parameters, at the defined frequency.
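
The difference between the two plotting modes can be sketched as follows. This is not the Checkpointer implementation: the function names are hypothetical, and the modulo test in plotDue is only an assumption about how a plotting frequency might be applied.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Value shown when plotting once per epoch: the mean over the epoch's samples.
float epochMean(std::vector<float> const& values) {
  assert(!values.empty());
  float sum = std::accumulate(values.begin(), values.end(), 0.0f);
  return sum / static_cast<float>(values.size());
}

// Value shown when plotting at a fixed frequency: the most recent sample.
float lastValue(std::vector<float> const& values) {
  assert(!values.empty());
  return values.back();
}

// Assumed rule: a plot is due every `visdomPlotFreq` updates.
bool plotDue(int updateCount, int visdomPlotFreq) {
  return visdomPlotFreq > 0 && updateCount % visdomPlotFreq == 0;
}
```

Per-epoch means give smoother curves, while last-value plotting reacts faster but is noisier.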

cpid::Checkpointer::TORCH_ARG ( int  ,
visdomPlotFreq   
)

If visdomOnEpoch = false, this is the frequency at which visdom plots are updated.

cpid::Checkpointer::TORCH_ARG ( std::string  ,
compareMetric   
)

Metric used to assess the performance of a model. Disables performance-based checkpoints if empty.

cpid::Checkpointer::TORCH_ARG ( bool  ,
printMetricsSummary   
)

If true, print the mean of the metrics at each epoch.

cpid::Checkpointer::TORCH_ARG ( bool  ,
aggregateMetrics   
)

If true, the metrics are aggregated over all workers.

cpid::Checkpointer::TORCH_ARG ( bool  ,
flushMetrics   
)

If true, we clear the metrics at the end of the epoch.

cpid::Checkpointer::TORCH_ARG ( bool  ,
dumpMetrics   
)

If true, we dump the JSON of the metrics at each epoch.

cpid::Checkpointer::TORCH_ARG ( MetricsSummaryFormat  ,
metricsSummaryFormat   
)
cpid::Checkpointer::TORCH_ARG ( bool  ,
reduceMax   
)

If true, we reduce across nodes using the max operator instead.
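
The effect of this flag can be illustrated with an element-wise reduction over per-node metric vectors. This page does not state what the default reduction is; the sketch below assumes an average, and reduceAcrossNodes is a hypothetical single-process stand-in for what would be a distributed all-reduce in practice.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// perNode[n] holds the metric values reported by node n (all the same length).
// With reduceMax, each output element is the maximum across nodes; otherwise
// it is the mean across nodes (the mean is an assumption of this sketch).
std::vector<float> reduceAcrossNodes(
    std::vector<std::vector<float>> const& perNode, bool reduceMax) {
  assert(!perNode.empty());
  std::vector<float> out = perNode.front();
  for (std::size_t n = 1; n < perNode.size(); ++n) {
    assert(perNode[n].size() == out.size());
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] = reduceMax ? std::max(out[i], perNode[n][i])
                         : out[i] + perNode[n][i];
    }
  }
  if (!reduceMax) {
    for (float& v : out) {
      v /= static_cast<float>(perNode.size());
    }
  }
  return out;
}
```

A max reduction is useful for metrics where the worst or best node matters, such as peak memory or longest step time, whereas averaging suits losses and rewards.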

cpid::Checkpointer::TORCH_ARG ( Hook  ,
epochHook   
)
cpid::Checkpointer::TORCH_ARG ( Hook  ,
updateHook   
)
void cpid::Checkpointer::updateDone ( int  updateCount)

This is the entry point to be called by trainers.
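
How a trainer might drive this entry point is sketched below with a stand-in class. MiniCheckpointer is hypothetical and only mimics the documented surface (updateDone, an epoch length in updates, and the two hooks); the rule used to detect an epoch boundary is an assumption, not taken from the TorchCraftAI sources.

```cpp
#include <functional>
#include <utility>

// Hypothetical stand-in for cpid::Checkpointer, for illustration only.
class MiniCheckpointer {
 public:
  using Hook = std::function<void(int)>;

  explicit MiniCheckpointer(int epochLength) : epochLength_(epochLength) {}

  void updateHook(Hook h) { updateHook_ = std::move(h); }
  void epochHook(Hook h) { epochHook_ = std::move(h); }

  // Called by the trainer after every model update.
  void updateDone(int updateCount) {
    if (updateHook_) {
      updateHook_(updateCount);
    }
    // Assumed rule: an epoch ends once epochLength_ updates have passed
    // since the last recorded epoch boundary.
    if (epochLength_ > 0 &&
        updateCount - lastEpochUpdateNum_ >= epochLength_) {
      lastEpochUpdateNum_ = updateCount;
      if (epochHook_) {
        epochHook_(updateCount);
      }
    }
  }

 private:
  int epochLength_;
  int lastEpochUpdateNum_ = 0;
  Hook updateHook_;
  Hook epochHook_;
};
```

In this sketch, a trainer loop that calls updateDone(1) through updateDone(7) with an epoch length of 3 fires the update hook seven times and the epoch hook at updates 3 and 6.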

Member Data Documentation

std::string cpid::Checkpointer::checkpointPath_

Where to save everything.

hires_clock::time_point cpid::Checkpointer::lastEpochStamp_
protected
int cpid::Checkpointer::lastEpochUpdateNum_ = 0
protected
std::shared_ptr<Trainer> cpid::Checkpointer::trainer_
protected
std::vector<std::string> cpid::Checkpointer::visdomLines_
protected

The documentation for this class was generated from the following files: