TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
rand.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include <memory>
11 #include <mutex>
12 #include <random>
13 
14 #include <ATen/Generator.h>
15 
16 namespace common {
17 
18 // This provides some thead-safe random primitives
19 
20 class Rand {
21  public:
22  /// Set the seed for random generators: this one, rand(3) and ATen
23  static void setSeed(int64_t seed);
24 
25  /// Set a static seed for the local thread.
26  static void setLocalSeed(int64_t seed);
27 
28  /// Sample random value
29  static uint64_t rand();
30 
31  /// Default random seed used by init()
32  static int64_t defaultRandomSeed();
33 
34  /// Random number engine based on previously set seed
35  template <typename T>
36  static T makeRandEngine() {
37  std::seed_seq seed{Rand::rand(), Rand::rand()};
38  return T(seed);
39  }
40 
41  /// Sample from a given distribution
42  template <typename T>
43  static auto sample(T&& distrib) -> decltype(distrib.min()) {
44  if (hasLocalSeed_) {
45  return distrib(localRandEngine_);
46  }
47  std::lock_guard<std::mutex> guard(randEngineMutex_);
48  return distrib(randEngine_);
49  }
50 
51  /**
52  * This allows to use a custom seed in torch.
53  * For example: at::normal(mean, dev, Rand::gen());
54  * Similarly to rand(), this will use a thread_local generator if a local seed
55  * is set
56  */
57  static at::Generator* gen();
58 
59  protected:
60  static std::mt19937 randEngine_;
61  static std::mutex randEngineMutex_;
62 
63  static thread_local bool hasLocalSeed_;
64  static thread_local std::mt19937 localRandEngine_;
65 
66  static std::unique_ptr<at::Generator> torchEngine_;
67  static thread_local std::unique_ptr<at::Generator> localTorchEngine_;
68 };
69 
70 // Generates a random alphanumeric string. It deliberately does not use the
71 // Rand class, so processes sharing the same seed do not produce the same IDs.
72 std::string randId(size_t length);
73 
74 // This method was originally written by Christopher Smith at
75 // https://stackoverflow.com/questions/6942273/get-random-element-from-container
76 // and is used under CC BY-SA: https://creativecommons.org/licenses/by-sa/2.0/
77 template <typename Iter, typename RandomGenerator>
78 Iter select_randomly(Iter start, Iter end, RandomGenerator& g) {
79  std::uniform_int_distribution<> dis(0, std::distance(start, end) - 1);
80  std::advance(start, dis(g));
81  return start;
82 }
83 } // namespace common
std::string randId(size_t len)
Definition: rand.cpp:76
Definition: rand.h:20
static thread_local std::unique_ptr< at::Generator > localTorchEngine_
Definition: rand.h:67
static std::mutex randEngineMutex_
Definition: rand.h:61
static std::mt19937 randEngine_
Definition: rand.h:60
static thread_local std::mt19937 localRandEngine_
Definition: rand.h:64
static T makeRandEngine()
Random number engine based on previously set seed.
Definition: rand.h:36
static at::Generator * gen()
Allows using a custom seed in torch.
Definition: rand.cpp:69
static int64_t defaultRandomSeed()
Default random seed used by init()
Definition: rand.cpp:27
static std::unique_ptr< at::Generator > torchEngine_
Definition: rand.h:66
General utilities.
Definition: assert.cpp:7
static thread_local bool hasLocalSeed_
Definition: rand.h:63
static auto sample(T &&distrib) -> decltype(distrib.min())
Sample from a given distribution.
Definition: rand.h:43
static void setSeed(int64_t seed)
Set the seed for random generators: this one, rand(3) and ATen.
Definition: rand.cpp:31
static uint64_t rand()
Sample random value.
Definition: rand.cpp:61
static void setLocalSeed(int64_t seed)
Set a static seed for the local thread.
Definition: rand.cpp:48
Iter select_randomly(Iter start, Iter end, RandomGenerator &g)
Definition: rand.h:78