TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
cpid::DiscreteMaxSampler Class Reference

This sampler expects as input an unordered_map<string, Variant> containing an entry under policyKey (kPiKey by default), which is a tensor of size [b, n].

#include <sampler.h>

Inherits cpid::BaseSampler.

Public Member Functions

 DiscreteMaxSampler (const std::string &policyKey=kPiKey, const std::string &actionKey=kActionKey)
 
ag::Variant sample (ag::Variant in) override
 
- Public Member Functions inherited from cpid::BaseSampler
 BaseSampler ()
 
virtual ~BaseSampler ()=default
 
virtual ag::Variant computeProba (const ag::Variant &in, const ag::Variant &action)
 

Protected Attributes

std::string policyKey_
 
std::string actionKey_
 

Detailed Description

This sampler expects as input an unordered_map<string, Variant> containing an entry under policyKey (kPiKey by default), which is a tensor of size [b, n].

It outputs the same map with a new entry under actionKey (kActionKey by default): a tensor of size [b] in which each entry is in [0, n-1] and corresponds to the action with the highest score.
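The sampling rule itself is a row-wise argmax over the [b, n] scores. The sketch below shows that rule on plain C++ containers rather than ag::Variant or torch tensors. `Scores` and `argmaxPerRow` are hypothetical names introduced for illustration, and resolving ties toward the lowest index is a choice of this sketch, not documented behavior of the class.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical stand-in for a [b, n] score tensor: b rows of n scores each.
using Scores = std::vector<std::vector<float>>;

// Returns a vector of size b whose i-th entry is the column index in [0, n-1]
// of the highest score in row i. std::max_element returns the first of several
// equal maxima, so ties go to the lowest index. Assumes every row has n >= 1.
std::vector<long> argmaxPerRow(const Scores& scores) {
  std::vector<long> actions;
  actions.reserve(scores.size());
  for (const auto& row : scores) {
    auto best = std::max_element(row.begin(), row.end());
    actions.push_back(static_cast<long>(std::distance(row.begin(), best)));
  }
  return actions;
}
```

For example, the scores {{0.1, 0.9, 0.3}, {2.0, -1.0, 2.0}} yield the actions {1, 0}.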

Constructor & Destructor Documentation

cpid::DiscreteMaxSampler::DiscreteMaxSampler (const std::string &policyKey = kPiKey, const std::string &actionKey = kActionKey)

Member Function Documentation

ag::Variant cpid::DiscreteMaxSampler::sample (ag::Variant in)
override virtual

Reimplemented from cpid::BaseSampler.
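To illustrate how the two constructor keys route data through sample(), here is a minimal mock that mirrors the documented constructor signature and protected members. It is a sketch under stated assumptions, not the cpid implementation: `Batch`, `MaxSamplerSketch`, and the placeholder key strings are hypothetical (the actual values of kPiKey and kActionKey are not given on this page), and two typed maps stand in for the heterogeneous unordered_map<string, Variant>.

```cpp
#include <algorithm>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the Variant map: scores [b, n] and actions [b].
using Matrix = std::vector<std::vector<float>>;
struct Batch {
  std::unordered_map<std::string, Matrix> tensors;
  std::unordered_map<std::string, std::vector<long>> actions;
};

// Placeholder key strings; NOT the real values of cpid's kPiKey / kActionKey.
const std::string kPiKeyPlaceholder = "policy";
const std::string kActionKeyPlaceholder = "action";

class MaxSamplerSketch {
 public:
  explicit MaxSamplerSketch(std::string policyKey = kPiKeyPlaceholder,
                            std::string actionKey = kActionKeyPlaceholder)
      : policyKey_(std::move(policyKey)), actionKey_(std::move(actionKey)) {}

  // Reads the [b, n] scores stored under policyKey_, writes the per-row
  // argmax (first maximum on ties) under actionKey_, and returns the same
  // map with the new entry added. Throws std::out_of_range if policyKey_
  // is missing from the input.
  Batch sample(Batch in) const {
    const Matrix& scores = in.tensors.at(policyKey_);
    std::vector<long> picked;
    picked.reserve(scores.size());
    for (const auto& row : scores) {
      picked.push_back(static_cast<long>(
          std::distance(row.begin(), std::max_element(row.begin(), row.end()))));
    }
    in.actions[actionKey_] = std::move(picked);
    return in;
  }

 protected:
  std::string policyKey_;
  std::string actionKey_;
};
```

Constructing the mock with custom keys, e.g. `MaxSamplerSketch("q", "greedy")`, makes it read scores from "q" and write actions under "greedy", which is the kind of rerouting the policyKey and actionKey constructor parameters allow.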

Member Data Documentation

std::string cpid::DiscreteMaxSampler::actionKey_
protected
std::string cpid::DiscreteMaxSampler::policyKey_
protected

The documentation for this class was generated from the following files: