#include <c10/util/Exception.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/auto_gil.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_strings.h>

#include <iostream>
#include <utility>
13
14namespace torch {
15namespace autograd {
16
17void PyAnomalyMetadata::store_stack() {
18 pybind11::gil_scoped_acquire gil;
19 THPObjectPtr mod(PyImport_ImportModule("torch.fx.traceback"));
20 if (!mod) {
21 throw python_error();
22 }
23
24 THPObjectPtr list(PyObject_CallMethod(mod.get(), "format_stack", ""));
25 if (!list) {
26 throw python_error();
27 }
28
29 if (PyDict_SetItemString(dict(), ANOMALY_TRACE_KEY, list.get())) {
30 throw python_error();
31 }
32}
33
34void PyAnomalyMetadata::print_stack(const std::string& current_node_name) {
35 pybind11::gil_scoped_acquire gil;
36 if (!PyDict_Check(dict())) {
37 throw std::runtime_error("Anomaly metadata is not a python dictionary.");
38 }
39 PyObject* trace_stack = PyDict_GetItemString(dict(), ANOMALY_TRACE_KEY);
40 _print_stack(trace_stack, current_node_name, false);
41 PyObject* pyparent(PyDict_GetItemString(dict(), ANOMALY_PARENT_KEY));
42
43 // if there is no "parent_" in metadata, then it means this metadata's node
44 // is the root and stop printing the traceback
45 while (pyparent) {
46 THPObjectPtr parent_metadata(PyObject_GetAttrString(pyparent, "metadata"));
47 if (!parent_metadata) {
48 throw python_error();
49 }
50 THPObjectPtr parent_name_pyobj(PyObject_CallMethod(pyparent, "name", ""));
51 if (!parent_name_pyobj) {
52 throw python_error();
53 }
54 const char* parent_name_char = PyUnicode_AsUTF8(parent_name_pyobj.get());
55 if (!parent_name_char) {
56 throw python_error();
57 }
58 const std::string parent_name(parent_name_char);
59 PyObject* parent_stack =
60 PyDict_GetItemString(parent_metadata.get(), ANOMALY_TRACE_KEY);
61 _print_stack(parent_stack, parent_name, true);
62 // get the parent of this node, if this node is a root, pyparent is simply
63 // null
64 pyparent = PyDict_GetItemString(parent_metadata.get(), ANOMALY_PARENT_KEY);
65 }
66}
67
68void PyAnomalyMetadata::assign_parent(
69 const std::shared_ptr<Node>& parent_node) {
70 // assign the python object of parent_node in metadata["parent_"]
71 // if parent_node is nullptr, then do nothing (it can mean that "parent_" key
72 // is not in metadata)
73
74 pybind11::gil_scoped_acquire gil;
75 if (!parent_node)
76 return;
77
78 THPObjectPtr parent_node_(functionToPyObject(parent_node));
79 if (!parent_node_) {
80 throw python_error();
81 }
82 if (PyDict_SetItemString(dict(), ANOMALY_PARENT_KEY, parent_node_.get())) {
83 throw python_error();
84 }
85}
86
87void _print_stack(
88 PyObject* stack,
89 const std::string& current_node_name,
90 bool is_parent) {
91 if (!stack) {
92 TORCH_WARN(
93 "Error detected in ",
94 current_node_name,
95 ". ",
96 "No forward pass information available. Enable detect anomaly "
97 "during forward pass for more information.");
98 return;
99 }
100
101 THPObjectPtr empty_string(PyUnicode_FromString(""));
102 if (!empty_string) {
103 throw python_error();
104 }
105
106 // stack is a list of Python strings ending with newlines. Use join to convert
107 // to a single string.
108 THPObjectPtr msg(PyUnicode_Join(empty_string, stack));
109 if (!msg) {
110 throw python_error();
111 }
112
113 if (!is_parent) {
114 TORCH_WARN(
115 "Error detected in ",
116 current_node_name,
117 ". ",
118 "Traceback of forward call that caused the error:\n",
119 THPUtils_unpackString(msg.get()));
120 } else {
121 TORCH_WARN(
122 "\n\n",
123 "Previous calculation was induced by ",
124 current_node_name,
125 ". "
126 "Traceback of forward call that induced the previous calculation:\n",
127 THPUtils_unpackString(msg.get()));
128 }
129}
130
131} // namespace autograd
132} // namespace torch
133