TorchCraftAI
A bot for machine learning research on StarCraft: Brood War
debug.h
1 /*
2  * Copyright (c) 2017-present, Facebook, Inc.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <string>
#include <utility>

#include <autogradpp/autograd.h>
#include <glog/logging.h>
#include <torch/torch.h>
13 
/*
 * Checks the sizes of tensor expression `T` through common::assertSize,
 * passing the stringified expression (#T) as the `name` argument.
 * Every remaining macro argument is forwarded through __VA_ARGS__, so a braced
 * size list such as {2, -1, 3} -- whose commas the preprocessor does not
 * protect -- is reassembled intact at the call site.
 */
#define ASSERT_SIZE(T, ...) \
  ::common::assertSize(#T, T, __VA_ARGS__)
15 
16 /*
17  * Useful helpers for neural networks expressed with Torch.
18  */
19 namespace common {
20 
21 /**
22  * Returns a string containing the tensor type and sizes
23  */
24 std::string tensorInfo(torch::Tensor x);
25 
26 /**
27  * Returns a string describing the content of a variant
28  */
29 std::string variantInfo(ag::Variant x);
30 
31 /**
32  * Returns a string containing the tensor info, the max/min/mean and sum
33  */
34 std::string tensorStats(torch::Tensor x);
35 
36 /**
37  * Throws if the given float tensor has a NaN or +/- infinity.
38  */
39 void checkTensor(torch::Tensor x, bool logOnError = true);
40 
41 using VarList = torch::autograd::variable_list;
42 using HookFunction = std::function<VarList(const VarList&, const VarList&)>;
43 /**
44  * Adds a hook to the backwards of the variable.
45  * The hook function takes gradInput and gradOutput, and should by default
46  * return gradInput, that is, the identity function looks like:
47  * [](VarList const& gradInp, Varlist const& gradOutp) { return gradInp; }
48  *
49  * https://pytorch.org/docs/stable/nn.html?highlight=hook#torch.nn.Module.register_backward_hook
50  */
51 torch::Tensor const& addHook(torch::Tensor const& tensor, HookFunction&& f);
52 
53 /**
54  * Verifies that a tensor's dimension sizes match expectations.
55  * If a dimension is negative (e.g. -1) it won't be checked.
56  * Throws a std::range_error if they don't.
57  */
58 void assertSize(
59  const std::string& name,
60  const torch::Tensor& tensor,
61  at::IntList sizes);
62 
63 /**
64  * Collects metrics about a container's weights
65  */
66 struct WeightSummary {
67  WeightSummary(torch::nn::Module&);
68  long weights = 0;
69  long zeroes = 0;
70  long nans = 0;
71  float norm1 = 0.0;
72  float norm2 = 0.0;
73  std::string toString() const;
74 };
75 std::ostream& operator<<(std::ostream& out, const WeightSummary& summary);
76 
77 /**
78  * Show the current memory usage, the first element is the amount allocated,
79  * or currently used by tensors that are alive, and the second element
80  * is the amount cached by the caching allocator.
81  * WARNING: This function will call cudaDeviceSynchronize, so it's extremely
82  * expensive, and should not be in any training runs unless it's hidden behind
83  * an if statement.
84  */
85 std::pair<int64_t, int64_t> torchMemoryUsage(int device = 0);
86 
87 } // namespace common
std::function< VarList(const VarList &, const VarList &)> HookFunction
Definition: debug.h:42
long weights
Definition: debug.h:68
float norm2
Definition: debug.h:72
std::ostream & operator<<(std::ostream &out, const WeightSummary &summary)
Definition: debug.cpp:177
std::pair< int64_t, int64_t > torchMemoryUsage(int device)
Returns the current memory usage: the first element is the amount allocated, or currently used by tensors that are alive, and the second element is the amount cached by the caching allocator.
Definition: debug.cpp:181
WeightSummary(torch::nn::Module &)
Definition: debug.cpp:150
std::string variantInfo(ag::Variant x)
Returns a string describing the content of a variant.
Definition: debug.cpp:91
void assertSize(const std::string &name, const torch::Tensor &tensor, at::IntList sizes)
Verifies that a tensor's dimension sizes match expectations.
Definition: debug.cpp:124
torch::autograd::variable_list VarList
Definition: debug.h:41
long zeroes
Definition: debug.h:69
std::string tensorInfo(torch::Tensor x)
Returns a string containing the tensor type and sizes.
Definition: debug.cpp:17
std::string toString() const
Definition: debug.cpp:166
void checkTensor(torch::Tensor x, bool logOnError)
Throws if the given float tensor has a NaN or +/- infinity.
Definition: debug.cpp:106
float norm1
Definition: debug.h:71
General utilities.
Definition: assert.cpp:7
std::string tensorStats(torch::Tensor x)
Returns a string containing the tensor info, the max/min/mean and sum.
Definition: debug.cpp:95
torch::Tensor const & addHook(torch::Tensor const &tensor, HookFunction &&f)
Adds a hook to the backwards of the variable.
Definition: debug.cpp:116
Collects metrics about a container's weights.
Definition: debug.h:66
long nans
Definition: debug.h:70