#pragma once

#include <torch/csrc/Export.h>
#include <memory>
#include <string>

namespace torch {
namespace autograd {

// forward declaration of Node from function.h
struct Node;

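/// Global flags controlling autograd anomaly detection. `set_enabled`
/// toggles the mode for all threads at once, and `check_nan` controls
/// whether the results of backward functions are additionally checked
/// for NaNs.
///
/// A minimal sketch of driving these flags directly; most code should
/// prefer the RAII `DetectAnomalyGuard` declared below instead:
/// @code
/// namespace ag = torch::autograd;
/// ag::AnomalyMode::set_enabled(/*enabled=*/true, /*check_nan=*/true);
/// // ... run the forward/backward pass that needs debugging ...
/// ag::AnomalyMode::set_enabled(/*enabled=*/false);
/// @endcode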
struct TORCH_API AnomalyMode {
  static bool is_enabled() {
    return _enabled;
  }
  static bool should_check_nan() {
    return _check_nan;
  }
  static void set_enabled(bool enabled, bool check_nan = true) {
    _enabled = enabled;
    _check_nan = check_nan;
  }

 private:
  static bool _enabled;
  static bool _check_nan;
};

/// A RAII guard that enables Anomaly Detection Mode.
///
/// Anomaly detection mode is useful for debugging problems happening
/// in the backward pass, such as unexpectedly modified tensors or NaNs
/// occurring in the gradients. Pass `check_nan = false` to the
/// constructor to skip the additional NaN checks.
///
/// Enabling anomaly mode is global: as soon as one such guard exists,
/// it is enabled for all computation and all threads. It also comes
/// with a significant performance penalty.
///
/// Example:
/// @code
/// {
///   torch::autograd::DetectAnomalyGuard detect_anomaly;
///   auto x = torch::tensor({5.0}, torch::requires_grad());
///   auto y = x * x;
///   auto z = y * y;
///   y += 1; // in-place update of a tensor needed for the backward pass
///   z.backward(); // fails; anomaly mode points at the offending forward op
/// }
/// @endcode
class TORCH_API DetectAnomalyGuard {
 public:
  DetectAnomalyGuard(bool check_nan = true);
  ~DetectAnomalyGuard();

 private:
  // Previous NaN-checking state, restored when the guard is destroyed.
  bool prev_check_nan_;
};

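/// Per-Node metadata used when anomaly mode is enabled: `store_stack`
/// records the current stack trace during the forward pass,
/// `print_stack` replays it when the node named `current_node_name`
/// fails in the backward pass, and `assign_parent` links a node created
/// inside a backward pass to its parent for nested (double) backward.
///
/// A sketch of how the engine might consult this metadata, assuming the
/// `Node::metadata()` and `Node::name()` accessors from function.h:
/// @code
/// if (torch::autograd::AnomalyMode::is_enabled()) {
///   node->metadata()->store_stack(); // capture the stack at creation time
/// }
/// // ... later, if evaluating `node` in the backward pass throws ...
/// node->metadata()->print_stack(node->name());
/// @endcode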
struct TORCH_API AnomalyMetadata {
  virtual ~AnomalyMetadata();
  virtual void store_stack();
  virtual void print_stack(const std::string& current_node_name);
  virtual void assign_parent(const std::shared_ptr<Node>& parent_node);

 private:
  std::string traceback_;
  std::shared_ptr<Node> parent_;
};

} // namespace autograd
} // namespace torch