1 | #include <c10/util/Exception.h> |
2 | #include <pybind11/pybind11.h> |
3 | #include <torch/csrc/Exceptions.h> |
4 | #include <torch/csrc/autograd/python_anomaly_mode.h> |
5 | #include <torch/csrc/autograd/python_cpp_function.h> |
6 | #include <torch/csrc/python_headers.h> |
7 | #include <torch/csrc/utils/auto_gil.h> |
8 | #include <torch/csrc/utils/object_ptr.h> |
9 | #include <torch/csrc/utils/pybind.h> |
10 | #include <torch/csrc/utils/python_strings.h> |
11 | |
12 | #include <iostream> |
13 | |
14 | namespace torch { |
15 | namespace autograd { |
16 | |
17 | void PyAnomalyMetadata::store_stack() { |
18 | pybind11::gil_scoped_acquire gil; |
19 | THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback" )); |
20 | if (!mod) { |
21 | throw python_error(); |
22 | } |
23 | |
24 | THPObjectPtr list(PyObject_CallMethod(mod.get(), "format_stack" , "" )); |
25 | if (!list) { |
26 | throw python_error(); |
27 | } |
28 | |
29 | if (PyDict_SetItemString(dict(), ANOMALY_TRACE_KEY, list.get())) { |
30 | throw python_error(); |
31 | } |
32 | } |
33 | |
34 | void PyAnomalyMetadata::print_stack(const std::string& current_node_name) { |
35 | pybind11::gil_scoped_acquire gil; |
36 | if (!PyDict_Check(dict())) { |
37 | throw std::runtime_error("Anomaly metadata is not a python dictionary." ); |
38 | } |
39 | PyObject* trace_stack = PyDict_GetItemString(dict(), ANOMALY_TRACE_KEY); |
40 | _print_stack(trace_stack, current_node_name, false); |
41 | PyObject* pyparent(PyDict_GetItemString(dict(), ANOMALY_PARENT_KEY)); |
42 | |
43 | // if there is no "parent_" in metadata, then it means this metadata's node |
44 | // is the root and stop printing the traceback |
45 | while (pyparent) { |
46 | THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata" )); |
47 | if (!parent_metadata) { |
48 | throw python_error(); |
49 | } |
50 | THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name" , "" )); |
51 | if (!parent_name_pyobj) { |
52 | throw python_error(); |
53 | } |
54 | const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get()); |
55 | if (!parent_name_char) { |
56 | throw python_error(); |
57 | } |
58 | const std::string parent_name(parent_name_char); |
59 | PyObject* parent_stack = |
60 | PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY); |
61 | _print_stack(parent_stack, parent_name, true); |
62 | // get the parent of this node, if this node is a root, pyparent is simply |
63 | // null |
64 | pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY); |
65 | } |
66 | } |
67 | |
68 | void PyAnomalyMetadata::assign_parent( |
69 | const std::shared_ptr<Node>& parent_node) { |
70 | // assign the python object of parent_node in metadata["parent_"] |
71 | // if parent_node is nullptr, then do nothing (it can mean that "parent_" key |
72 | // is not in metadata) |
73 | |
74 | pybind11::gil_scoped_acquire gil; |
75 | if (!parent_node) |
76 | return; |
77 | |
78 | THPObjectPtr parent_node_(functionToPyObject(parent_node)); |
79 | if (!parent_node_) { |
80 | throw python_error(); |
81 | } |
82 | if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) { |
83 | throw python_error(); |
84 | } |
85 | } |
86 | |
87 | void _print_stack( |
88 | PyObject* stack, |
89 | const std::string& current_node_name, |
90 | bool is_parent) { |
91 | if (!stack) { |
92 | TORCH_WARN( |
93 | "Error detected in " , |
94 | current_node_name, |
95 | ". " , |
96 | "No forward pass information available. Enable detect anomaly " |
97 | "during forward pass for more information." ); |
98 | return; |
99 | } |
100 | |
101 | THPObjectPtr empty_string(PyUnicode_FromString("" )); |
102 | if (!empty_string) { |
103 | throw python_error(); |
104 | } |
105 | |
106 | // stack is a list of Python strings ending with newlines. Use join to convert |
107 | // to a single string. |
108 | THPObjectPtr msg(PyUnicode_Join(empty_string, stack)); |
109 | if (!msg) { |
110 | throw python_error(); |
111 | } |
112 | |
113 | if (!is_parent) { |
114 | TORCH_WARN( |
115 | "Error detected in " , |
116 | current_node_name, |
117 | ". " , |
118 | "Traceback of forward call that caused the error:\n" , |
119 | THPUtils_unpackString(msg.get())); |
120 | } else { |
121 | TORCH_WARN( |
122 | "\n\n" , |
123 | "Previous calculation was induced by " , |
124 | current_node_name, |
125 | ". " |
126 | "Traceback of forward call that induced the previous calculation:\n" , |
127 | THPUtils_unpackString(msg.get())); |
128 | } |
129 | } |
130 | |
131 | } // namespace autograd |
132 | } // namespace torch |
133 | |