1 | #include <torch/csrc/Exceptions.h> |
2 | #include <torch/csrc/python_headers.h> |
3 | |
4 | #include <array> |
5 | #include <cstdarg> |
6 | #include <exception> |
7 | #include <utility> |
8 | |
9 | #include <fmt/format.h> |
10 | #include <torch/csrc/THP.h> |
11 | |
12 | #include <c10/util/StringUtil.h> |
13 | |
14 | PyObject *THPException_FatalError, *THPException_LinAlgError, |
15 | *THPException_OutOfMemoryError, *THPException_DistBackendError; |
16 | |
17 | #define ASSERT_TRUE(cond) \ |
18 | if (!(cond)) \ |
19 | return false |
20 | bool THPException_init(PyObject* module) { |
21 | ASSERT_TRUE( |
22 | THPException_FatalError = |
23 | PyErr_NewException("torch.FatalError" , nullptr, nullptr)); |
24 | ASSERT_TRUE( |
25 | PyModule_AddObject(module, "FatalError" , THPException_FatalError) == 0); |
26 | |
27 | // Set the doc string here since _add_docstr throws malloc errors if tp_doc is |
28 | // modified for an error class. |
29 | ASSERT_TRUE( |
30 | THPException_LinAlgError = PyErr_NewExceptionWithDoc( |
31 | "torch._C._LinAlgError" , |
32 | "Error raised by torch.linalg function when the cause of error is a numerical inconsistency in the data.\n \ |
33 | For example, you can the torch.linalg.inv function will raise torch.linalg.LinAlgError when it finds that \ |
34 | a matrix is not invertible.\n \ |
35 | \n\ |
36 | Example:\n \ |
37 | >>> # xdoctest: +REQUIRES(env:TORCH_DOCKTEST_LAPACK)\n \ |
38 | >>> matrix = torch.eye(3, 3)\n \ |
39 | >>> matrix[-1, -1] = 0\n \ |
40 | >>> matrix\n \ |
41 | tensor([[1., 0., 0.],\n \ |
42 | [0., 1., 0.],\n \ |
43 | [0., 0., 0.]])\n \ |
44 | >>> torch.linalg.inv(matrix)\n \ |
45 | Traceback (most recent call last):\n \ |
46 | File \"<stdin>\", line 1, in <module>\n \ |
47 | torch._C._LinAlgError: torch.linalg.inv: The diagonal element 3 is zero, the inversion\n \ |
48 | could not be completed because the input matrix is singular." , |
49 | PyExc_RuntimeError, |
50 | nullptr)); |
51 | ASSERT_TRUE( |
52 | PyModule_AddObject(module, "_LinAlgError" , THPException_LinAlgError) == |
53 | 0); |
54 | |
55 | ASSERT_TRUE( |
56 | THPException_OutOfMemoryError = PyErr_NewExceptionWithDoc( |
57 | "torch.cuda.OutOfMemoryError" , |
58 | "Exception raised when CUDA is out of memory" , |
59 | PyExc_RuntimeError, |
60 | nullptr)); |
61 | ASSERT_TRUE( |
62 | PyModule_AddObject( |
63 | module, "_OutOfMemoryError" , THPException_OutOfMemoryError) == 0); |
64 | |
65 | ASSERT_TRUE( |
66 | THPException_DistBackendError = PyErr_NewExceptionWithDoc( |
67 | "torch.distributed.DistBackendError" , |
68 | "Exception raised when a backend error occurs in distributed" , |
69 | PyExc_RuntimeError, |
70 | nullptr)); |
71 | ASSERT_TRUE( |
72 | PyModule_AddObject( |
73 | module, "_DistBackendError" , THPException_DistBackendError) == 0); |
74 | |
75 | return true; |
76 | } |
77 | |
78 | namespace torch { |
79 | |
80 | void processErrorMsgInplace(std::string& str) { |
81 | // Translate Aten types to their respective pytorch ones |
82 | constexpr std::array<std::pair<c10::string_view, c10::string_view>, 64> |
83 | changes{{ |
84 | {"Variable[SparseCUDAByteType]" , "torch.cuda.sparse.ByteTensor" }, |
85 | {"Variable[SparseCUDACharType]" , "torch.cuda.sparse.CharTensor" }, |
86 | {"Variable[SparseCUDADoubleType]" , "torch.cuda.sparse.DoubleTensor" }, |
87 | {"Variable[SparseCUDAFloatType]" , "torch.cuda.sparse.FloatTensor" }, |
88 | {"Variable[SparseCUDAIntType]" , "torch.cuda.sparse.IntTensor" }, |
89 | {"Variable[SparseCUDALongType]" , "torch.cuda.sparse.LongTensor" }, |
90 | {"Variable[SparseCUDAShortType]" , "torch.cuda.sparse.ShortTensor" }, |
91 | {"Variable[SparseCUDAHalfType]" , "torch.cuda.sparse.HalfTensor" }, |
92 | {"Variable[SparseCPUByteType]" , "torch.sparse.ByteTensor" }, |
93 | {"Variable[SparseCPUCharType]" , "torch.sparse.CharTensor" }, |
94 | {"Variable[SparseCPUDoubleType]" , "torch.sparse.DoubleTensor" }, |
95 | {"Variable[SparseCPUFloatType]" , "torch.sparse.FloatTensor" }, |
96 | {"Variable[SparseCPUIntType]" , "torch.sparse.IntTensor" }, |
97 | {"Variable[SparseCPULongType]" , "torch.sparse.LongTensor" }, |
98 | {"Variable[SparseCPUShortType]" , "torch.sparse.ShortTensor" }, |
99 | {"Variable[SparseCPUHalfType]" , "torch.sparse.HalfTensor" }, |
100 | {"Variable[CUDAByteType]" , "torch.cuda.ByteTensor" }, |
101 | {"Variable[CUDACharType]" , "torch.cuda.CharTensor" }, |
102 | {"Variable[CUDADoubleType]" , "torch.cuda.DoubleTensor" }, |
103 | {"Variable[CUDAFloatType]" , "torch.cuda.FloatTensor" }, |
104 | {"Variable[CUDAIntType]" , "torch.cuda.IntTensor" }, |
105 | {"Variable[CUDALongType]" , "torch.cuda.LongTensor" }, |
106 | {"Variable[CUDAShortType]" , "torch.cuda.ShortTensor" }, |
107 | {"Variable[CUDAHalfType]" , "torch.cuda.HalfTensor" }, |
108 | {"Variable[CPUByteType]" , "torch.ByteTensor" }, |
109 | {"Variable[CPUCharType]" , "torch.CharTensor" }, |
110 | {"Variable[CPUDoubleType]" , "torch.DoubleTensor" }, |
111 | {"Variable[CPUFloatType]" , "torch.FloatTensor" }, |
112 | {"Variable[CPUIntType]" , "torch.IntTensor" }, |
113 | {"Variable[CPULongType]" , "torch.LongTensor" }, |
114 | {"Variable[CPUShortType]" , "torch.ShortTensor" }, |
115 | {"Variable[CPUHalfType]" , "torch.HalfTensor" }, |
116 | {"SparseCUDAByteType" , "torch.cuda.sparse.ByteTensor" }, |
117 | {"SparseCUDACharType" , "torch.cuda.sparse.CharTensor" }, |
118 | {"SparseCUDADoubleType" , "torch.cuda.sparse.DoubleTensor" }, |
119 | {"SparseCUDAFloatType" , "torch.cuda.sparse.FloatTensor" }, |
120 | {"SparseCUDAIntType" , "torch.cuda.sparse.IntTensor" }, |
121 | {"SparseCUDALongType" , "torch.cuda.sparse.LongTensor" }, |
122 | {"SparseCUDAShortType" , "torch.cuda.sparse.ShortTensor" }, |
123 | {"SparseCUDAHalfType" , "torch.cuda.sparse.HalfTensor" }, |
124 | {"SparseCPUByteType" , "torch.sparse.ByteTensor" }, |
125 | {"SparseCPUCharType" , "torch.sparse.CharTensor" }, |
126 | {"SparseCPUDoubleType" , "torch.sparse.DoubleTensor" }, |
127 | {"SparseCPUFloatType" , "torch.sparse.FloatTensor" }, |
128 | {"SparseCPUIntType" , "torch.sparse.IntTensor" }, |
129 | {"SparseCPULongType" , "torch.sparse.LongTensor" }, |
130 | {"SparseCPUShortType" , "torch.sparse.ShortTensor" }, |
131 | {"SparseCPUHalfType" , "torch.sparse.HalfTensor" }, |
132 | {"CUDAByteType" , "torch.cuda.ByteTensor" }, |
133 | {"CUDACharType" , "torch.cuda.CharTensor" }, |
134 | {"CUDADoubleType" , "torch.cuda.DoubleTensor" }, |
135 | {"CUDAFloatType" , "torch.cuda.FloatTensor" }, |
136 | {"CUDAIntType" , "torch.cuda.IntTensor" }, |
137 | {"CUDALongType" , "torch.cuda.LongTensor" }, |
138 | {"CUDAShortType" , "torch.cuda.ShortTensor" }, |
139 | {"CUDAHalfType" , "torch.cuda.HalfTensor" }, |
140 | {"CPUByteType" , "torch.ByteTensor" }, |
141 | {"CPUCharType" , "torch.CharTensor" }, |
142 | {"CPUDoubleType" , "torch.DoubleTensor" }, |
143 | {"CPUFloatType" , "torch.FloatTensor" }, |
144 | {"CPUIntType" , "torch.IntTensor" }, |
145 | {"CPULongType" , "torch.LongTensor" }, |
146 | {"CPUShortType" , "torch.ShortTensor" }, |
147 | {"CPUHalfType" , "torch.HalfTensor" }, |
148 | }}; |
149 | |
150 | // Avoid doing any work if no types need translated |
151 | if (str.find("Type" ) == str.npos) { |
152 | return; |
153 | } |
154 | for (const auto& it : changes) { |
155 | c10::ReplaceAll(str, it.first, it.second); |
156 | } |
157 | } |
158 | |
159 | std::string processErrorMsg(std::string str) { |
160 | processErrorMsgInplace(str); |
161 | return str; |
162 | } |
163 | |
/// Expands a printf-style `format` with the arguments in `fmt_args`.
///
/// Short messages are formatted into a stack buffer. Messages that do not fit
/// are formatted a second time into a heap buffer of exactly the required
/// size, so the result is never truncated.
///
/// @param format    printf-style format string.
/// @param fmt_args  argument list matching `format`; it is consumed by this
///                  call, and the caller remains responsible for va_end().
/// @return the formatted message, or `format` itself when vsnprintf reports
///         an error (negative return value).
static std::string formatMessage(const char* format, va_list fmt_args) {
  constexpr size_t kStackBufSize = 1024;
  std::array<char, kStackBufSize> stack_buf{};

  // vsnprintf leaves its va_list in an indeterminate state, so keep a copy for
  // the second pass that is needed when the stack buffer is too small.
  va_list retry_args;
  va_copy(retry_args, fmt_args);

  const int needed =
      vsnprintf(stack_buf.data(), stack_buf.size(), format, fmt_args);

  std::string message;
  if (needed < 0) {
    // Formatting failed and the buffer contents are unspecified; the raw
    // format string is the most useful thing left to report.
    message = format;
  } else if (static_cast<size_t>(needed) < stack_buf.size()) {
    message.assign(stack_buf.data(), static_cast<size_t>(needed));
  } else {
    // `needed` excludes the terminating NUL, so reserve one extra byte for it
    // and trim it off again afterwards.
    message.resize(static_cast<size_t>(needed) + 1);
    vsnprintf(&message[0], message.size(), format, retry_args);
    message.resize(static_cast<size_t>(needed));
  }

  va_end(retry_args);
  return message;
}
175 | |
176 | void translate_exception_to_python(const std::exception_ptr& e_ptr) { |
177 | try { |
178 | TORCH_INTERNAL_ASSERT( |
179 | e_ptr, |
180 | "translate_exception_to_python " |
181 | "called with invalid exception pointer" ); |
182 | std::rethrow_exception(e_ptr); |
183 | } |
184 | CATCH_ALL_ERRORS(return ) |
185 | } |
186 | |
187 | IndexError::IndexError(const char* format, ...) { |
188 | va_list fmt_args; |
189 | va_start(fmt_args, format); |
190 | msg = formatMessage(format, fmt_args); |
191 | va_end(fmt_args); |
192 | } |
193 | |
194 | TypeError::TypeError(const char* format, ...) { |
195 | va_list fmt_args; |
196 | va_start(fmt_args, format); |
197 | msg = formatMessage(format, fmt_args); |
198 | va_end(fmt_args); |
199 | } |
200 | |
201 | ValueError::ValueError(const char* format, ...) { |
202 | va_list fmt_args; |
203 | va_start(fmt_args, format); |
204 | msg = formatMessage(format, fmt_args); |
205 | va_end(fmt_args); |
206 | } |
207 | |
208 | AttributeError::AttributeError(const char* format, ...) { |
209 | va_list fmt_args; |
210 | va_start(fmt_args, format); |
211 | msg = formatMessage(format, fmt_args); |
212 | va_end(fmt_args); |
213 | } |
214 | |
215 | LinAlgError::LinAlgError(const char* format, ...) { |
216 | va_list fmt_args; |
217 | va_start(fmt_args, format); |
218 | msg = formatMessage(format, fmt_args); |
219 | va_end(fmt_args); |
220 | } |
221 | |
222 | void PyWarningHandler::InternalHandler::process(const c10::Warning& warning) { |
223 | warning_buffer_.push_back(warning); |
224 | } |
225 | |
226 | PyWarningHandler::PyWarningHandler() noexcept(true) |
227 | : prev_handler_(c10::WarningUtils::get_warning_handler()), |
228 | in_exception_(false) { |
229 | c10::WarningUtils::set_warning_handler(&internal_handler_); |
230 | } |
231 | |
232 | // Get the Python warning type for a warning |
233 | PyObject* map_warning_to_python_type(const c10::Warning& warning) { |
234 | struct Visitor { |
235 | PyObject* operator()(const c10::UserWarning&) const { |
236 | return PyExc_UserWarning; |
237 | } |
238 | PyObject* operator()(const c10::DeprecationWarning&) const { |
239 | return PyExc_DeprecationWarning; |
240 | } |
241 | }; |
242 | return c10::visit(Visitor(), warning.type()); |
243 | } |
244 | |
245 | /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification |
246 | /// NOLINTNEXTLINE(bugprone-exception-escape) |
247 | PyWarningHandler::~PyWarningHandler() noexcept(false) { |
248 | c10::WarningUtils::set_warning_handler(prev_handler_); |
249 | auto& warning_buffer = internal_handler_.warning_buffer_; |
250 | |
251 | if (!warning_buffer.empty()) { |
252 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
253 | PyObject *type, *value, *traceback; |
254 | pybind11::gil_scoped_acquire gil; |
255 | auto result = 0; |
256 | if (in_exception_) { |
257 | // This (combined with PyErr_Restore below) also works when no python |
258 | // error has been set yet |
259 | PyErr_Fetch(&type, &value, &traceback); |
260 | } |
261 | for (const auto& warning : warning_buffer) { |
262 | auto source_location = warning.source_location(); |
263 | auto msg = warning.msg(); |
264 | processErrorMsgInplace(msg); |
265 | if (source_location.file == nullptr) { |
266 | result = |
267 | PyErr_WarnEx(map_warning_to_python_type(warning), msg.c_str(), 1); |
268 | } else if (warning.verbatim()) { |
269 | // Sets the source location from the warning |
270 | // Note: PyErr_WarnExplicit will disregard Python's warning filter |
271 | // and always appear. This is in contrast to PyErr_WarnEx, |
272 | // which respects the warning filter. |
273 | result = PyErr_WarnExplicit( |
274 | /*category=*/map_warning_to_python_type(warning), |
275 | /*message=*/msg.c_str(), |
276 | /*filename=*/source_location.file, |
277 | /*lineno=*/source_location.line, |
278 | /*module=*/nullptr, |
279 | /*registry=*/nullptr); |
280 | } else { |
281 | // Lets Python set the source location and puts the C++ warning |
282 | // location into the message. |
283 | auto buf = fmt::format( |
284 | "{} (Triggered internally at {}:{}.)" , |
285 | msg, |
286 | source_location.file, |
287 | source_location.line); |
288 | result = |
289 | PyErr_WarnEx(map_warning_to_python_type(warning), buf.c_str(), 1); |
290 | } |
291 | if (result < 0) { |
292 | if (in_exception_) { |
293 | // PyErr_Print prints the traceback to sys.stderr and |
294 | // clears the error indicator |
295 | PyErr_Print(); |
296 | } else { |
297 | break; |
298 | } |
299 | } |
300 | } |
301 | warning_buffer.clear(); |
302 | if ((result < 0) && (!in_exception_)) { |
303 | /// A warning raised an error, we need to force the parent |
304 | /// function to return an error code. |
305 | throw python_error(); |
306 | } |
307 | if (in_exception_) { |
308 | // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) |
309 | PyErr_Restore(type, value, traceback); |
310 | } |
311 | } |
312 | } |
313 | |
314 | } // namespace torch |
315 | |