10 #include <autogradpp/autograd.h> 11 #include <glog/logging.h> 12 #include <torch/torch.h> 14 #define ASSERT_SIZE(T, ...) ::common::assertSize(#T, T, __VA_ARGS__) 39 void checkTensor(torch::Tensor x,
bool logOnError =
true);
41 using VarList = torch::autograd::variable_list;
42 using HookFunction = std::function<VarList(const VarList&, const VarList&)>;
59 const std::string& name,
60 const torch::Tensor& tensor,
std::function< VarList(const VarList &, const VarList &)> HookFunction
Definition: debug.h:42
long weights
Definition: debug.h:68
float norm2
Definition: debug.h:72
std::ostream & operator<<(std::ostream &out, const WeightSummary &summary)
Definition: debug.cpp:177
std::pair< int64_t, int64_t > torchMemoryUsage(int device)
Show the current memory usage, the first element is the amount allocated, or currently used by tensor...
Definition: debug.cpp:181
WeightSummary(torch::nn::Module &)
Definition: debug.cpp:150
std::string variantInfo(ag::Variant x)
Returns a string describing the content of a variant.
Definition: debug.cpp:91
void assertSize(const std::string &name, const torch::Tensor &tensor, at::IntList sizes)
Verifies that a tensor's dimension sizes match expectations.
Definition: debug.cpp:124
torch::autograd::variable_list VarList
Definition: debug.h:41
long zeroes
Definition: debug.h:69
std::string tensorInfo(torch::Tensor x)
Returns a string containing the tensor type and sizes.
Definition: debug.cpp:17
std::string toString() const
Definition: debug.cpp:166
void checkTensor(torch::Tensor x, bool logOnError)
Throws if the given float tensor has a NaN or +/- infinity.
Definition: debug.cpp:106
float norm1
Definition: debug.h:71
General utilities.
Definition: assert.cpp:7
std::string tensorStats(torch::Tensor x)
Returns a string containing the tensor info, the max/min/mean and sum.
Definition: debug.cpp:95
torch::Tensor const & addHook(torch::Tensor const &tensor, HookFunction &&f)
Adds a hook to the backwards of the variable.
Definition: debug.cpp:116
Collects metrics about a container's weights.
Definition: debug.h:66
long nans
Definition: debug.h:70