1 | #pragma once |
2 | |
3 | #include <torch/csrc/Export.h> |
4 | #include <memory> |
5 | #include <string> |
6 | |
7 | namespace torch { |
8 | namespace autograd { |
9 | |
10 | // forward declaration of Node from function.h |
11 | struct Node; |
12 | |
13 | struct TORCH_API AnomalyMode { |
14 | static bool is_enabled() { |
15 | return _enabled; |
16 | } |
17 | static bool should_check_nan() { |
18 | return _check_nan; |
19 | } |
20 | static void set_enabled(bool enabled, bool check_nan = true) { |
21 | _enabled = enabled; |
22 | _check_nan = check_nan; |
23 | } |
24 | |
25 | private: |
26 | static bool _enabled; |
27 | static bool _check_nan; |
28 | }; |
29 | |
30 | /// A RAII guard that enables Anomaly Detection Mode. |
31 | /// |
32 | /// Anomaly detection mode is useful for debugging problems happening |
33 | /// in the backward, such as unexpectedly modified tensors or NaNs |
34 | /// occuring in the backward. |
35 | /// |
36 | /// The enabling of anomaly mode is global - as soon as there is one |
37 | /// such guard, it is enabled for all computation and threads. It also |
38 | /// comes with a significant performance penalty. |
39 | /// |
40 | /// Example: |
41 | /// @code |
42 | /// auto x = torch::tensor({1.}, torch::requires_grad()); |
43 | /// { |
44 | /// torch::autograd::DetectAnomalyGuard detect_anomaly; |
45 | /// auto x = torch::tensor({5.0}, torch::requires_grad()); |
46 | /// auto y = x * x; |
47 | /// auto z = y * y; |
48 | /// y += 1; |
49 | /// z.backward(); |
50 | /// } |
51 | /// @endcode |
52 | class TORCH_API DetectAnomalyGuard { |
53 | public: |
54 | DetectAnomalyGuard(bool check_nan = true); |
55 | ~DetectAnomalyGuard(); |
56 | |
57 | private: |
58 | bool prev_check_nan_; |
59 | }; |
60 | |
61 | struct TORCH_API AnomalyMetadata { |
62 | virtual ~AnomalyMetadata(); |
63 | virtual void store_stack(); |
64 | virtual void print_stack(const std::string& current_node_name); |
65 | virtual void assign_parent(const std::shared_ptr<Node>& parent_node); |
66 | |
67 | private: |
68 | std::string traceback_; |
69 | std::shared_ptr<Node> parent_; |
70 | }; |
71 | |
72 | } // namespace autograd |
73 | } // namespace torch |
74 | |