1 | #pragma once |
2 | |
3 | #include <exception> |
4 | #include <memory> |
5 | #include <mutex> |
6 | #include <queue> |
7 | #include <string> |
8 | #include <system_error> |
9 | |
10 | #include <ATen/detail/FunctionTraits.h> |
11 | #include <c10/util/C++17.h> |
12 | #include <c10/util/Exception.h> |
13 | #include <c10/util/StringUtil.h> |
14 | #include <pybind11/pybind11.h> |
15 | #include <torch/csrc/Export.h> |
16 | #include <torch/csrc/jit/runtime/jit_exception.h> |
17 | #include <torch/csrc/utils/auto_gil.h> |
18 | #include <torch/csrc/utils/cpp_stacktraces.h> |
19 | #include <torch/csrc/utils/pybind.h> |
20 | |
21 | #if defined(USE_DISTRIBUTED) && defined(USE_C10D) |
22 | #include <torch/csrc/distributed/c10d/exception.h> |
23 | #endif |
24 | |
25 | static inline void PyErr_SetString(PyObject* type, const std::string& message) { |
26 | PyErr_SetString(type, message.c_str()); |
27 | } |
/// NOTE [ Conversion Cpp Python Warning ]
/// The warning handler cannot set python warnings immediately
/// as it requires acquiring the GIL (potential deadlock)
/// and would need to cleanly exit if the warning raised a
/// python error. To solve this, we buffer the warnings and
/// process them when we go back to python.
/// This requires the two try/catch blocks below to handle the
/// following cases:
///   - If there is no Error raised in the inner try/catch, the
///     buffered warnings are processed as python warnings.
///     - If they don't raise an error, the function proceeds with the
///       original return code.
///     - If any of them raise an error, the error is set (PyErr_*) and
///       the destructor will raise a cpp exception python_error() that
///       will be caught by the outer try/catch that will be able to change
///       the return value of the function to reflect the error.
///   - If an Error was raised in the inner try/catch, the inner try/catch
///     must set the python error. The buffered warnings are then
///     processed as cpp warnings as we cannot predict beforehand
///     whether a python warning will raise an error or not and we
///     cannot handle two errors at the same time.
/// This advanced handler will only be used in the current thread.
/// If any other thread is used, warnings will be processed as
/// cpp warnings.
// Opens the guarded region described in
// NOTE [ Conversion Cpp Python Warning ]. It consists of:
// - an outer `try` that owns a scoped torch::PyWarningHandler;
// - an inner `try` that wraps the user code.
// It must be closed by one of the END_HANDLE_TH_ERRORS* macros, which refer
// to the handler object by the name `__enforce_warning_buffer`.
#define HANDLE_TH_ERRORS                              \
  try {                                               \
    torch::PyWarningHandler __enforce_warning_buffer; \
    try {
// Emits a single `catch` clause for c10::ErrorType. The clause:
// - sets the Python error indicator to PythonErrorType;
// - uses the full what() text only when C++ stacktraces are enabled,
//   otherwise what_without_backtrace();
// - then executes `retstmnt`.
#define _CATCH_GENERIC_ERROR(ErrorType, PythonErrorType, retstmnt)  \
  catch (const c10::ErrorType& e) {                                \
    const char* msg = nullptr;                                     \
    if (torch::get_cpp_stacktraces_enabled()) {                    \
      msg = e.what();                                              \
    } else {                                                       \
      msg = e.what_without_backtrace();                            \
    }                                                              \
    PyErr_SetString(PythonErrorType, torch::processErrorMsg(msg)); \
    retstmnt;                                                      \
  }
64 | |
// Only catch torch-specific exceptions.
//
// Expands to a chain of `catch` clauses, to be placed after a `try` block.
// Each clause installs the Python error indicator and then runs `retstmnt`:
//   * python_error / py::error_already_set: the saved Python error is
//     re-installed with restore().
//   * c10 error types: mapped to the listed Python exception type.
//     c10::Error is listed last as the fallback for c10 errors.
//   * torch::PyTorchError: uses the type reported by python_type().
// `py` must name the pybind11 namespace at the expansion site.
#define CATCH_CORE_ERRORS(retstmnt)                               \
  catch (python_error& e) {                                       \
    e.restore();                                                  \
    retstmnt;                                                     \
  }                                                               \
  catch (py::error_already_set& e) {                              \
    e.restore();                                                  \
    retstmnt;                                                     \
  }                                                               \
  _CATCH_GENERIC_ERROR(                                           \
      IndexError, PyExc_IndexError, retstmnt)                     \
  _CATCH_GENERIC_ERROR(                                           \
      ValueError, PyExc_ValueError, retstmnt)                     \
  _CATCH_GENERIC_ERROR(                                           \
      TypeError, PyExc_TypeError, retstmnt)                       \
  _CATCH_GENERIC_ERROR(                                           \
      NotImplementedError, PyExc_NotImplementedError, retstmnt)   \
  _CATCH_GENERIC_ERROR(                                           \
      LinAlgError, THPException_LinAlgError, retstmnt)            \
  _CATCH_GENERIC_ERROR(                                           \
      OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \
  _CATCH_GENERIC_ERROR(                                           \
      DistBackendError, THPException_DistBackendError, retstmnt)  \
  _CATCH_GENERIC_ERROR(                                           \
      Error, PyExc_RuntimeError, retstmnt)                        \
  catch (torch::PyTorchError& e) {                                \
    auto msg = torch::processErrorMsg(e.what());                  \
    PyErr_SetString(e.python_type(), msg);                        \
    retstmnt;                                                     \
  }
91 | |
// c10d-specific handlers. They exist only in builds with both
// USE_DISTRIBUTED and USE_C10D; otherwise the macro expands to nothing.
// c10d::TimeoutError becomes Python TimeoutError, and any other
// c10d::C10dError becomes RuntimeError. TimeoutError is listed first so the
// more general C10dError clause does not shadow it.
#if !defined(USE_DISTRIBUTED) || !defined(USE_C10D)
#define CATCH_C10D_ERRORS(retstmnt)
#else
#define CATCH_C10D_ERRORS(retstmnt)              \
  catch (const c10d::TimeoutError& e) {          \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_TimeoutError, msg);    \
    retstmnt;                                    \
  }                                              \
  catch (const c10d::C10dError& e) {             \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_RuntimeError, msg);    \
    retstmnt;                                    \
  }
#endif
107 | |
// All torch-aware handlers: the core torch/c10 clauses, followed by the
// c10d clauses (which are empty in non-distributed builds).
#define CATCH_TH_ERRORS(retstmnt) \
  CATCH_CORE_ERRORS(retstmnt)     \
  CATCH_C10D_ERRORS(retstmnt)
111 | |
// CATCH_TH_ERRORS plus a last-resort clause that reports any other
// std::exception as a Python RuntimeError before running `retstmnt`.
#define CATCH_ALL_ERRORS(retstmnt)               \
  CATCH_TH_ERRORS(retstmnt)                      \
  catch (const std::exception& e) {              \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(PyExc_RuntimeError, msg);    \
    retstmnt;                                    \
  }
119 | |
// Closes HANDLE_TH_ERRORS in functions bound through pybind11. Such functions
// report errors by throwing, not by a return value:
//   * Any exception escaping the body first marks the warning buffer as
//     being in an exception, then is rethrown.
//   * py::error_already_set, py::builtin_exception and
//     torch::jit::JITException are rethrown unchanged.
//   * Any other std::exception is converted into the Python error indicator
//     by translate_exception_to_python and re-raised as py::error_already_set.
// The rethrow-only handlers catch by const reference and leave the exception
// unnamed; `throw;` rethrows the original object either way.
#define END_HANDLE_TH_ERRORS_PYBIND                                 \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (const py::error_already_set&) {                            \
    throw;                                                          \
  }                                                                 \
  catch (const py::builtin_exception&) {                            \
    throw;                                                          \
  }                                                                 \
  catch (const torch::jit::JITException&) {                         \
    throw;                                                          \
  }                                                                 \
  catch (const std::exception&) {                                   \
    torch::translate_exception_to_python(std::current_exception()); \
    throw py::error_already_set();                                  \
  }
140 | |
// Closes HANDLE_TH_ERRORS in functions that signal failure through their
// return value instead of by throwing:
// - Any exception escaping the body first marks the warning buffer as being
//   in an exception, then is rethrown.
// - The outer clause converts any std::exception into the Python error
//   indicator and returns `retval`.
#define END_HANDLE_TH_ERRORS_RET(retval)                            \
  }                                                                 \
  catch (...) {                                                     \
    __enforce_warning_buffer.set_in_exception();                    \
    throw;                                                          \
  }                                                                 \
  }                                                                 \
  catch (const std::exception& e) {                                 \
    torch::translate_exception_to_python(std::current_exception()); \
    return retval;                                                  \
  }
152 | |
// Default closer: failure is reported by returning nullptr (typically from a
// function returning PyObject*).
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
154 | |
155 | extern PyObject *THPException_FatalError, *THPException_LinAlgError, |
156 | *THPException_OutOfMemoryError, *THPException_DistBackendError; |
157 | |
158 | // Throwing this exception means that the python error flags have been already |
159 | // set and control should be immediately returned to the interpreter. |
160 | struct python_error : public std::exception { |
161 | python_error() : type(nullptr), value(nullptr), traceback(nullptr) {} |
162 | |
163 | python_error(const python_error& other) |
164 | : type(other.type), |
165 | value(other.value), |
166 | traceback(other.traceback), |
167 | message(other.message) { |
168 | pybind11::gil_scoped_acquire gil; |
169 | Py_XINCREF(type); |
170 | Py_XINCREF(value); |
171 | Py_XINCREF(traceback); |
172 | } |
173 | |
174 | python_error(python_error&& other) noexcept |
175 | : type(other.type), |
176 | value(other.value), |
177 | traceback(other.traceback), |
178 | message(std::move(other.message)) { |
179 | other.type = nullptr; |
180 | other.value = nullptr; |
181 | other.traceback = nullptr; |
182 | } |
183 | |
184 | ~python_error() override { |
185 | if (type || value || traceback) { |
186 | pybind11::gil_scoped_acquire gil; |
187 | Py_XDECREF(type); |
188 | Py_XDECREF(value); |
189 | Py_XDECREF(traceback); |
190 | } |
191 | } |
192 | |
193 | const char* what() const noexcept override { |
194 | return message.c_str(); |
195 | } |
196 | |
197 | void build_message() { |
198 | // Ensure we have the GIL. |
199 | pybind11::gil_scoped_acquire gil; |
200 | |
201 | // No errors should be set when we enter the function since PyErr_Fetch |
202 | // clears the error indicator. |
203 | TORCH_INTERNAL_ASSERT(!PyErr_Occurred()); |
204 | |
205 | // Default message. |
206 | message = "python_error" ; |
207 | |
208 | // Try to retrieve the error message from the value. |
209 | if (value != nullptr) { |
210 | // Reference count should not be zero. |
211 | TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0); |
212 | |
213 | PyObject* pyStr = PyObject_Str(value); |
214 | if (pyStr != nullptr) { |
215 | PyObject* encodedString = |
216 | PyUnicode_AsEncodedString(pyStr, "utf-8" , "strict" ); |
217 | if (encodedString != nullptr) { |
218 | char* bytes = PyBytes_AS_STRING(encodedString); |
219 | if (bytes != nullptr) { |
220 | // Set the message. |
221 | message = std::string(bytes); |
222 | } |
223 | Py_XDECREF(encodedString); |
224 | } |
225 | Py_XDECREF(pyStr); |
226 | } |
227 | } |
228 | |
229 | // Clear any errors since we don't want to propagate errors for functions |
230 | // that are trying to build a string for the error message. |
231 | PyErr_Clear(); |
232 | } |
233 | |
234 | /** Saves the exception so that it can be re-thrown on a different thread */ |
235 | inline void persist() { |
236 | if (type) |
237 | return; // Don't overwrite exceptions |
238 | // PyErr_Fetch overwrites the pointers |
239 | pybind11::gil_scoped_acquire gil; |
240 | Py_XDECREF(type); |
241 | Py_XDECREF(value); |
242 | Py_XDECREF(traceback); |
243 | PyErr_Fetch(&type, &value, &traceback); |
244 | build_message(); |
245 | } |
246 | |
247 | /** Sets the current Python error from this exception */ |
248 | inline void restore() { |
249 | if (!type) |
250 | return; |
251 | // PyErr_Restore steals references |
252 | pybind11::gil_scoped_acquire gil; |
253 | Py_XINCREF(type); |
254 | Py_XINCREF(value); |
255 | Py_XINCREF(traceback); |
256 | PyErr_Restore(type, value, traceback); |
257 | } |
258 | |
259 | PyObject* type; |
260 | PyObject* value; |
261 | PyObject* traceback; |
262 | |
263 | // Message to return to the user when 'what()' is invoked. |
264 | std::string message; |
265 | }; |
266 | |
267 | bool THPException_init(PyObject* module); |
268 | |
269 | namespace torch { |
270 | |
271 | // Set python current exception from a C++ exception |
272 | TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&); |
273 | |
274 | TORCH_PYTHON_API std::string processErrorMsg(std::string str); |
275 | |
276 | // Abstract base class for exceptions which translate to specific Python types |
277 | struct PyTorchError : public std::exception { |
278 | PyTorchError() = default; |
279 | PyTorchError(std::string msg_) : msg(std::move(msg_)) {} |
280 | virtual PyObject* python_type() = 0; |
281 | const char* what() const noexcept override { |
282 | return msg.c_str(); |
283 | } |
284 | std::string msg; |
285 | }; |
286 | |
// Marks a function as printf-like on gcc & clang (both define __GNUC__), so
// the compiler can warn on invalid format specifiers. FORMAT_INDEX and
// VA_ARGS_INDEX are 1-based parameter positions. For non-static member
// functions, including constructors, position 1 is the implicit `this`.
// Elsewhere the macro expands to nothing.
#if defined(__GNUC__)
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
  __attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif
295 | |
296 | // Translates to Python IndexError |
297 | struct IndexError : public PyTorchError { |
298 | using PyTorchError::PyTorchError; |
299 | IndexError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); |
300 | PyObject* python_type() override { |
301 | return PyExc_IndexError; |
302 | } |
303 | }; |
304 | |
305 | // Translates to Python TypeError |
306 | struct TypeError : public PyTorchError { |
307 | using PyTorchError::PyTorchError; |
308 | TORCH_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); |
309 | PyObject* python_type() override { |
310 | return PyExc_TypeError; |
311 | } |
312 | }; |
313 | |
314 | // Translates to Python ValueError |
315 | struct ValueError : public PyTorchError { |
316 | using PyTorchError::PyTorchError; |
317 | ValueError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); |
318 | PyObject* python_type() override { |
319 | return PyExc_ValueError; |
320 | } |
321 | }; |
322 | |
323 | // Translates to Python NotImplementedError |
324 | struct NotImplementedError : public PyTorchError { |
325 | NotImplementedError() = default; |
326 | PyObject* python_type() override { |
327 | return PyExc_NotImplementedError; |
328 | } |
329 | }; |
330 | |
331 | // Translates to Python AttributeError |
332 | struct AttributeError : public PyTorchError { |
333 | AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); |
334 | PyObject* python_type() override { |
335 | return PyExc_AttributeError; |
336 | } |
337 | }; |
338 | |
339 | // Translates to Python LinAlgError |
340 | struct LinAlgError : public PyTorchError { |
341 | LinAlgError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3); |
342 | PyObject* python_type() override { |
343 | return THPException_LinAlgError; |
344 | } |
345 | }; |
346 | |
347 | // ATen warning handler for Python |
348 | struct PyWarningHandler { |
349 | // Move actual handler into a separate class with a noexcept |
350 | // destructor. Otherwise, we need to force all WarningHandler |
351 | // subclasses to have a noexcept(false) destructor. |
352 | struct InternalHandler : at::WarningHandler { |
353 | ~InternalHandler() override = default; |
354 | void process(const c10::Warning& warning) override; |
355 | |
356 | std::vector<c10::Warning> warning_buffer_; |
357 | }; |
358 | |
359 | public: |
360 | /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification |
361 | TORCH_API PyWarningHandler() noexcept(true); |
362 | // NOLINTNEXTLINE(bugprone-exception-escape) |
363 | TORCH_API ~PyWarningHandler() noexcept(false); |
364 | |
365 | /** Call if an exception has been thrown |
366 | |
367 | * Necessary to determine if it is safe to throw from the desctructor since |
368 | * std::uncaught_exception is buggy on some platforms and generally |
369 | * unreliable across dynamic library calls. |
370 | */ |
371 | void set_in_exception() { |
372 | in_exception_ = true; |
373 | } |
374 | |
375 | private: |
376 | InternalHandler internal_handler_; |
377 | at::WarningHandler* prev_handler_; |
378 | bool in_exception_; |
379 | }; |
380 | |
381 | namespace detail { |
382 | template <typename Func, size_t i> |
383 | using Arg = typename invoke_traits<Func>::template arg<i>::type; |
384 | |
385 | template <typename Func, size_t... Is> |
386 | auto wrap_pybind_function_impl_( |
387 | Func&& f, |
388 | std::index_sequence<Is...>, |
389 | bool release_gil) { |
390 | using result_type = typename invoke_traits<Func>::result_type; |
391 | namespace py = pybind11; |
392 | |
393 | // f=f is needed to handle function references on older compilers |
394 | return [f = std::forward<Func>(f), |
395 | release_gil](Arg<Func, Is>... args) -> result_type { |
396 | HANDLE_TH_ERRORS |
397 | if (release_gil) { |
398 | py::gil_scoped_release no_gil; |
399 | return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...); |
400 | } else { |
401 | return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...); |
402 | } |
403 | END_HANDLE_TH_ERRORS_PYBIND |
404 | }; |
405 | } |
406 | } // namespace detail |
407 | |
408 | // Wrap a function with TH error and warning handling. |
409 | // Returns a function object suitable for registering with pybind11. |
410 | template <typename Func> |
411 | auto wrap_pybind_function(Func&& f) { |
412 | using traits = invoke_traits<Func>; |
413 | return torch::detail::wrap_pybind_function_impl_( |
414 | std::forward<Func>(f), std::make_index_sequence<traits::arity>{}, false); |
415 | } |
416 | |
417 | // Wrap a function with TH error, warning handling and releases the GIL. |
418 | // Returns a function object suitable for registering with pybind11. |
419 | template <typename Func> |
420 | auto wrap_pybind_function_no_gil(Func&& f) { |
421 | using traits = invoke_traits<Func>; |
422 | return torch::detail::wrap_pybind_function_impl_( |
423 | std::forward<Func>(f), std::make_index_sequence<traits::arity>{}, true); |
424 | } |
425 | |
426 | } // namespace torch |
427 | |