TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cpid::ContinuousGaussianSampler Class Reference

This sampler expects as input an unordered_map<string, Variant> containing an entry under policyKey, which is a tensor of size [b, n].

#include <sampler.h>

Inherits cpid::BaseSampler.

Public Member Functions

ContinuousGaussianSampler (const std::string &policyKey=kPiKey, const std::string &stdKey=kSigmaKey, const std::string &actionKey=kActionKey, const std::string &pActionKey=kPActionKey)
ag::Variant sample (ag::Variant in) override
ag::Variant computeProba (const ag::Variant &in, const ag::Variant &action) override

Public Member Functions inherited from cpid::BaseSampler
BaseSampler ()
virtual ~BaseSampler ()=default

Protected Attributes

std::string policyKey_
std::string stdKey_
std::string actionKey_
std::string pActionKey_

Detailed Description

This sampler expects as input an unordered_map<string, Variant> containing an entry under policyKey, which is a tensor of size [b, n], where b is the batch size.

It outputs the same map with a new entry under actionKey (kActionKey by default): a tensor of size [b, n] where each row action[i] is sampled from a normal distribution centered on policy[i]. The input map must also contain an entry under stdKey, which is used as the standard deviation of the normal. This entry can be either a float/double, in which case the same standard deviation is used for the whole batch, or a tensor of the same shape as the policy, for finer per-element control. Finally, the sampler adds an entry under pActionKey holding the probability of the sampled action.
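The sampling behaviour described above can be sketched with the standard library alone. This is not the cpid implementation: ag::Variant and torch tensors are replaced by hypothetical stand-ins (Matrix, Value, Dict), the function name sampleGaussian is made up for illustration, and each row action[i] is treated as an n-vector drawn around policy[i], so the action has the policy's [b, n] shape. Storing the probability as a per-element normal density is also an assumption; the real sampler may reduce it over the n dimensions or lay it out differently.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins (not the cpid/autogradpp types): a [b, n] tensor as
// nested vectors, a value that is either a scalar or such a tensor, and the
// string-keyed map the sampler reads from and writes to.
using Matrix = std::vector<std::vector<double>>;
using Value = std::variant<double, Matrix>;
using Dict = std::unordered_map<std::string, Value>;

// Standard deviation for element (i, j): the shared scalar if one was given,
// otherwise the matching entry of a per-element std tensor.
double stdAt(const Value& sigma, std::size_t i, std::size_t j) {
  if (const double* s = std::get_if<double>(&sigma)) {
    return *s;
  }
  return std::get<Matrix>(sigma)[i][j];
}

// Density of the univariate normal N(mean, sd^2) evaluated at x.
double normalPdf(double x, double mean, double sd) {
  const double kPi = 3.14159265358979323846;
  const double z = (x - mean) / sd;
  return std::exp(-0.5 * z * z) / (sd * std::sqrt(2.0 * kPi));
}

// Hypothetical sketch: draws action[i][j] ~ N(policy[i][j], sd_ij^2) and writes
// the action and its per-element density (assumed layout) back into the map.
void sampleGaussian(Dict& dict, std::mt19937& rng, const std::string& policyKey,
                    const std::string& stdKey, const std::string& actionKey,
                    const std::string& pActionKey) {
  const Matrix& policy = std::get<Matrix>(dict.at(policyKey));
  const Value& sigma = dict.at(stdKey);
  const Matrix* perElement = std::get_if<Matrix>(&sigma);
  // A per-element std must have the same [b, n] shape as the policy.
  assert(perElement == nullptr || perElement->size() == policy.size());

  Matrix action = policy;   // same [b, n] shape as the policy
  Matrix pAction = policy;  // per-element density, same shape
  for (std::size_t i = 0; i < policy.size(); ++i) {
    assert(perElement == nullptr || (*perElement)[i].size() == policy[i].size());
    for (std::size_t j = 0; j < policy[i].size(); ++j) {
      const double sd = stdAt(sigma, i, j);
      assert(sd > 0.0);  // std::normal_distribution needs a positive stddev
      std::normal_distribution<double> dist(policy[i][j], sd);
      action[i][j] = dist(rng);
      pAction[i][j] = normalPdf(action[i][j], policy[i][j], sd);
    }
  }
  dict[actionKey] = std::move(action);
  dict[pActionKey] = std::move(pAction);
}
```

Storing a single double under the std key applies one standard deviation to the whole batch, while storing a Matrix of the policy's shape gives per-element control, mirroring the two options listed above. The key strings a caller passes (for example "pi" or "sigma") are placeholders; the actual values of kPiKey, kSigmaKey, kActionKey and kPActionKey are not given here.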

Constructor & Destructor Documentation

cpid::ContinuousGaussianSampler::ContinuousGaussianSampler (
    const std::string &policyKey = kPiKey,
    const std::string &stdKey = kSigmaKey,
    const std::string &actionKey = kActionKey,
    const std::string &pActionKey = kPActionKey)

The four arguments name the map entries the sampler reads (policyKey, stdKey) and writes (actionKey, pActionKey).

Member Function Documentation

ag::Variant cpid::ContinuousGaussianSampler::computeProba (const ag::Variant &in, const ag::Variant &action)
[override], [virtual]

Reimplemented from cpid::BaseSampler.

ag::Variant cpid::ContinuousGaussianSampler::sample (ag::Variant in)
[override], [virtual]

Reimplemented from cpid::BaseSampler.

Member Data Documentation

std::string cpid::ContinuousGaussianSampler::actionKey_ [protected]
std::string cpid::ContinuousGaussianSampler::pActionKey_ [protected]
std::string cpid::ContinuousGaussianSampler::policyKey_ [protected]
std::string cpid::ContinuousGaussianSampler::stdKey_ [protected]

The documentation for this class was generated from the following files: