TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
This sampler expects as input an unordered_map<string, Variant>, containing an entry QKey, which is a tensor of size [b, n].
#include <sampler.h>
Inherits cpid::BaseSampler.
Public Member Functions
DiscreteMaxSampler (const std::string &policyKey=kPiKey, const std::string &actionKey=kActionKey)
ag::Variant sample (ag::Variant in) override
Public Member Functions inherited from cpid::BaseSampler
BaseSampler ()
virtual ~BaseSampler ()=default
virtual ag::Variant computeProba (const ag::Variant &in, const ag::Variant &action)
Protected Attributes
std::string policyKey_
std::string actionKey_
This sampler expects as input an unordered_map<string, Variant>, containing an entry QKey, which is a tensor of size [b, n].
It outputs the same map with a new key kActionKey, whose value is a tensor of size [b]. Each entry lies in [0, n-1] and corresponds to the action with the highest score in that row.
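The row-wise selection described above can be illustrated with a minimal, library-free sketch. It is not the TorchCraftAI implementation: the real sampler operates on ag::Variant maps holding tensors, whereas the function name argmaxRows and the nested std::vector representation of the [b, n] scores are assumptions made only for this example. Tie-breaking toward the lowest index is a property of std::max_element in this sketch, not a documented guarantee of the sampler.

```cpp
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <stdexcept>
#include <vector>

// Hypothetical helper, for illustration only.
// `scores` plays the role of the [b, n] input tensor: b rows, n scores each.
// Returns a vector of size b whose i-th entry is the index in [0, n-1]
// of the highest score in row i.
std::vector<std::int64_t> argmaxRows(
    const std::vector<std::vector<float>>& scores) {
  std::vector<std::int64_t> actions;
  actions.reserve(scores.size());
  for (const auto& row : scores) {
    if (row.empty()) {
      throw std::invalid_argument("each row needs at least one score");
    }
    // std::max_element returns the first maximum, so ties go to the
    // lowest index in this sketch.
    auto best = std::max_element(row.begin(), row.end());
    actions.push_back(
        static_cast<std::int64_t>(std::distance(row.begin(), best)));
  }
  return actions;
}
```

For example, with two rows of three scores each (b = 2, n = 3), {0.1, 0.7, 0.2} selects index 1, and {2.0, -1.0, 2.0} selects index 0 because the first of the tied maxima wins.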
cpid::DiscreteMaxSampler::DiscreteMaxSampler (const std::string &policyKey = kPiKey, const std::string &actionKey = kActionKey)
ag::Variant cpid::DiscreteMaxSampler::sample (ag::Variant in) [override, virtual]
Reimplemented from cpid::BaseSampler.