1#pragma once
2
3#include <exception>
4#include <memory>
5#include <mutex>
6#include <queue>
7#include <string>
8#include <system_error>
9
10#include <ATen/detail/FunctionTraits.h>
11#include <c10/util/C++17.h>
12#include <c10/util/Exception.h>
13#include <c10/util/StringUtil.h>
14#include <pybind11/pybind11.h>
15#include <torch/csrc/Export.h>
16#include <torch/csrc/jit/runtime/jit_exception.h>
17#include <torch/csrc/utils/auto_gil.h>
18#include <torch/csrc/utils/cpp_stacktraces.h>
19#include <torch/csrc/utils/pybind.h>
20
21#if defined(USE_DISTRIBUTED) && defined(USE_C10D)
22#include <torch/csrc/distributed/c10d/exception.h>
23#endif
24
25static inline void PyErr_SetString(PyObject* type, const std::string& message) {
26 PyErr_SetString(type, message.c_str());
27}
28/// NOTE [ Conversion Cpp Python Warning ]
29/// The warning handler cannot set python warnings immediately
30/// as it requires acquiring the GIL (potential deadlock)
31/// and would need to cleanly exit if the warning raised a
32/// python error. To solve this, we buffer the warnings and
33/// process them when we go back to python.
34/// This requires the two try/catch blocks below to handle the
35/// following cases:
36/// - If there is no Error raised in the inner try/catch, the
37/// buffered warnings are processed as python warnings.
///     - If they don't raise an error, the function proceeds with the
///       original return code.
40/// - If any of them raise an error, the error is set (PyErr_*) and
41/// the destructor will raise a cpp exception python_error() that
42/// will be caught by the outer try/catch that will be able to change
43/// the return value of the function to reflect the error.
44/// - If an Error was raised in the inner try/catch, the inner try/catch
45/// must set the python error. The buffered warnings are then
///     processed as cpp warnings, as we cannot predict beforehand
///     whether a python warning will raise an error, and we
///     cannot handle two errors at the same time.
49/// This advanced handler will only be used in the current thread.
50/// If any other thread is used, warnings will be processed as
51/// cpp warnings.
// Opens two nested try blocks around a function body and declares a
// torch::PyWarningHandler local named __enforce_warning_buffer in the outer
// one (see NOTE [ Conversion Cpp Python Warning ]). Must be closed by
// END_HANDLE_TH_ERRORS, END_HANDLE_TH_ERRORS_RET or
// END_HANDLE_TH_ERRORS_PYBIND, which close both blocks and refer to that
// local by name.
#define HANDLE_TH_ERRORS \
  try { \
    torch::PyWarningHandler __enforce_warning_buffer; \
    try {
// Expands to a single catch handler for c10::ErrorType. The handler:
//   1. picks e.what() when torch::get_cpp_stacktraces_enabled() is true, and
//      e.what_without_backtrace() otherwise;
//   2. passes that message through torch::processErrorMsg();
//   3. sets it as the Python error of type PythonErrorType;
//   4. runs retstmnt.
#define _CATCH_GENERIC_ERROR(ErrorType, PythonErrorType, retstmnt) \
  catch (const c10::ErrorType& e) { \
    const char* msg = nullptr; \
    if (torch::get_cpp_stacktraces_enabled()) { \
      msg = e.what(); \
    } else { \
      msg = e.what_without_backtrace(); \
    } \
    PyErr_SetString(PythonErrorType, torch::processErrorMsg(msg)); \
    retstmnt; \
  }
64
// Only catch torch-specific exceptions.
//
// Expands to a sequence of catch handlers, so it must directly follow a try
// block. Every handler leaves a Python error set and then runs `retstmnt`:
//   - python_error / py::error_already_set: re-raise the already captured
//     Python error via restore().
//   - c10 error types: map each to the listed Python exception type.
//   - torch::PyTorchError: use the type returned by python_type().
// C++ picks the first matching handler. The specific c10 error types
// (presumably subclasses of c10::Error; confirm in c10/util/Exception.h)
// must therefore stay listed before the c10::Error catch-all.
#define CATCH_CORE_ERRORS(retstmnt) \
  catch (python_error & e) { \
    e.restore(); \
    retstmnt; \
  } \
  catch (py::error_already_set & e) { \
    e.restore(); \
    retstmnt; \
  } \
  _CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt) \
  _CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt) \
  _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
  _CATCH_GENERIC_ERROR( \
      NotImplementedError, PyExc_NotImplementedError, retstmnt) \
  _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \
  _CATCH_GENERIC_ERROR( \
      OutOfMemoryError, THPException_OutOfMemoryError, retstmnt) \
  _CATCH_GENERIC_ERROR( \
      DistBackendError, THPException_DistBackendError, retstmnt) \
  _CATCH_GENERIC_ERROR(Error, PyExc_RuntimeError, retstmnt) \
  catch (torch::PyTorchError & e) { \
    auto msg = torch::processErrorMsg(e.what()); \
    PyErr_SetString(e.python_type(), msg); \
    retstmnt; \
  }
91
// Catch handlers for c10d (distributed) exceptions. Expands to nothing unless
// both USE_DISTRIBUTED and USE_C10D are defined.
#if !(defined(USE_DISTRIBUTED) && defined(USE_C10D))
#define CATCH_C10D_ERRORS(retstmnt)
#else
// c10d::TimeoutError is tried first, so it is matched before the (presumably
// more general) c10d::C10dError handler and surfaces as Python TimeoutError.
#define CATCH_C10D_ERRORS(retstmnt) \
  catch (const c10d::TimeoutError& e) { \
    PyErr_SetString(PyExc_TimeoutError, torch::processErrorMsg(e.what())); \
    retstmnt; \
  } \
  catch (const c10d::C10dError& e) { \
    PyErr_SetString(PyExc_RuntimeError, torch::processErrorMsg(e.what())); \
    retstmnt; \
  }
#endif
107
// All torch-aware catch handlers. The core torch/c10 handlers come first,
// then the c10d handlers, which are empty in non-distributed builds.
#define CATCH_TH_ERRORS(retstmnt) \
  CATCH_CORE_ERRORS(retstmnt) CATCH_C10D_ERRORS(retstmnt)
111
// CATCH_TH_ERRORS followed by a last-resort std::exception handler. That
// handler reports any other standard exception as a Python RuntimeError.
#define CATCH_ALL_ERRORS(retstmnt) \
  CATCH_TH_ERRORS(retstmnt) \
  catch (const std::exception& e) { \
    PyErr_SetString(PyExc_RuntimeError, torch::processErrorMsg(e.what())); \
    retstmnt; \
  }
119
// Closes the two try blocks opened by HANDLE_TH_ERRORS, for functions exposed
// through pybind11:
//   - inner catch (...): marks __enforce_warning_buffer as being in an
//     exception (set_in_exception()) and rethrows;
//   - py::error_already_set, py::builtin_exception and
//     torch::jit::JITException are rethrown unchanged;
//   - any other std::exception is converted to a Python error by
//     torch::translate_exception_to_python() and rethrown as
//     py::error_already_set.
// None of the handlers use the exception object, so it is left unnamed. This
// avoids an unused-variable warning at every expansion site.
#define END_HANDLE_TH_ERRORS_PYBIND \
  } \
  catch (...) { \
    __enforce_warning_buffer.set_in_exception(); \
    throw; \
  } \
  } \
  catch (py::error_already_set&) { \
    throw; \
  } \
  catch (py::builtin_exception&) { \
    throw; \
  } \
  catch (torch::jit::JITException&) { \
    throw; \
  } \
  catch (const std::exception&) { \
    torch::translate_exception_to_python(std::current_exception()); \
    throw py::error_already_set(); \
  }
140
// Closes the two try blocks opened by HANDLE_TH_ERRORS, for functions called
// directly from CPython.
//   - Any exception first marks __enforce_warning_buffer as being in an
//     exception (set_in_exception()).
//   - A std::exception is then converted to a Python error by
//     torch::translate_exception_to_python(), and the function returns
//     `retval`.
// The std::exception object itself is not used, so it is left unnamed to
// avoid unused-variable warnings at every expansion site.
#define END_HANDLE_TH_ERRORS_RET(retval) \
  } \
  catch (...) { \
    __enforce_warning_buffer.set_in_exception(); \
    throw; \
  } \
  } \
  catch (const std::exception&) { \
    torch::translate_exception_to_python(std::current_exception()); \
    return retval; \
  }
152
// Default closer for HANDLE_TH_ERRORS. A translated C++ exception makes the
// enclosing function return nullptr, which suits functions whose error return
// is a null pointer (typically CPython entry points returning PyObject*).
#define END_HANDLE_TH_ERRORS END_HANDLE_TH_ERRORS_RET(nullptr)
154
155extern PyObject *THPException_FatalError, *THPException_LinAlgError,
156 *THPException_OutOfMemoryError, *THPException_DistBackendError;
157
158// Throwing this exception means that the python error flags have been already
159// set and control should be immediately returned to the interpreter.
160struct python_error : public std::exception {
161 python_error() : type(nullptr), value(nullptr), traceback(nullptr) {}
162
163 python_error(const python_error& other)
164 : type(other.type),
165 value(other.value),
166 traceback(other.traceback),
167 message(other.message) {
168 pybind11::gil_scoped_acquire gil;
169 Py_XINCREF(type);
170 Py_XINCREF(value);
171 Py_XINCREF(traceback);
172 }
173
174 python_error(python_error&& other) noexcept
175 : type(other.type),
176 value(other.value),
177 traceback(other.traceback),
178 message(std::move(other.message)) {
179 other.type = nullptr;
180 other.value = nullptr;
181 other.traceback = nullptr;
182 }
183
184 ~python_error() override {
185 if (type || value || traceback) {
186 pybind11::gil_scoped_acquire gil;
187 Py_XDECREF(type);
188 Py_XDECREF(value);
189 Py_XDECREF(traceback);
190 }
191 }
192
193 const char* what() const noexcept override {
194 return message.c_str();
195 }
196
197 void build_message() {
198 // Ensure we have the GIL.
199 pybind11::gil_scoped_acquire gil;
200
201 // No errors should be set when we enter the function since PyErr_Fetch
202 // clears the error indicator.
203 TORCH_INTERNAL_ASSERT(!PyErr_Occurred());
204
205 // Default message.
206 message = "python_error";
207
208 // Try to retrieve the error message from the value.
209 if (value != nullptr) {
210 // Reference count should not be zero.
211 TORCH_INTERNAL_ASSERT(Py_REFCNT(value) > 0);
212
213 PyObject* pyStr = PyObject_Str(value);
214 if (pyStr != nullptr) {
215 PyObject* encodedString =
216 PyUnicode_AsEncodedString(pyStr, "utf-8", "strict");
217 if (encodedString != nullptr) {
218 char* bytes = PyBytes_AS_STRING(encodedString);
219 if (bytes != nullptr) {
220 // Set the message.
221 message = std::string(bytes);
222 }
223 Py_XDECREF(encodedString);
224 }
225 Py_XDECREF(pyStr);
226 }
227 }
228
229 // Clear any errors since we don't want to propagate errors for functions
230 // that are trying to build a string for the error message.
231 PyErr_Clear();
232 }
233
234 /** Saves the exception so that it can be re-thrown on a different thread */
235 inline void persist() {
236 if (type)
237 return; // Don't overwrite exceptions
238 // PyErr_Fetch overwrites the pointers
239 pybind11::gil_scoped_acquire gil;
240 Py_XDECREF(type);
241 Py_XDECREF(value);
242 Py_XDECREF(traceback);
243 PyErr_Fetch(&type, &value, &traceback);
244 build_message();
245 }
246
247 /** Sets the current Python error from this exception */
248 inline void restore() {
249 if (!type)
250 return;
251 // PyErr_Restore steals references
252 pybind11::gil_scoped_acquire gil;
253 Py_XINCREF(type);
254 Py_XINCREF(value);
255 Py_XINCREF(traceback);
256 PyErr_Restore(type, value, traceback);
257 }
258
259 PyObject* type;
260 PyObject* value;
261 PyObject* traceback;
262
263 // Message to return to the user when 'what()' is invoked.
264 std::string message;
265};
266
// Initialization hook for the THPException_* Python exception objects.
// Presumably creates them, registers them on `module`, and returns false on
// failure. TODO confirm against the definition.
bool THPException_init(PyObject* module);
268
269namespace torch {
270
271// Set python current exception from a C++ exception
272TORCH_PYTHON_API void translate_exception_to_python(const std::exception_ptr&);
273
274TORCH_PYTHON_API std::string processErrorMsg(std::string str);
275
276// Abstract base class for exceptions which translate to specific Python types
277struct PyTorchError : public std::exception {
278 PyTorchError() = default;
279 PyTorchError(std::string msg_) : msg(std::move(msg_)) {}
280 virtual PyObject* python_type() = 0;
281 const char* what() const noexcept override {
282 return msg.c_str();
283 }
284 std::string msg;
285};
286
// Marks a printf-like function on gcc & clang so the compiler can flag format
// specifiers that don't match the arguments; a no-op elsewhere.
// FORMAT_INDEX and VA_ARGS_INDEX are 1-based, and for non-static member
// functions (including constructors) the implicit `this` counts as argument 1.
#if defined(__GNUC__)
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX) \
  __attribute__((format(printf, FORMAT_INDEX, VA_ARGS_INDEX)))
#else
#define TORCH_FORMAT_FUNC(FORMAT_INDEX, VA_ARGS_INDEX)
#endif
295
296// Translates to Python IndexError
297struct IndexError : public PyTorchError {
298 using PyTorchError::PyTorchError;
299 IndexError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
300 PyObject* python_type() override {
301 return PyExc_IndexError;
302 }
303};
304
305// Translates to Python TypeError
306struct TypeError : public PyTorchError {
307 using PyTorchError::PyTorchError;
308 TORCH_API TypeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
309 PyObject* python_type() override {
310 return PyExc_TypeError;
311 }
312};
313
314// Translates to Python ValueError
315struct ValueError : public PyTorchError {
316 using PyTorchError::PyTorchError;
317 ValueError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
318 PyObject* python_type() override {
319 return PyExc_ValueError;
320 }
321};
322
323// Translates to Python NotImplementedError
324struct NotImplementedError : public PyTorchError {
325 NotImplementedError() = default;
326 PyObject* python_type() override {
327 return PyExc_NotImplementedError;
328 }
329};
330
331// Translates to Python AttributeError
332struct AttributeError : public PyTorchError {
333 AttributeError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
334 PyObject* python_type() override {
335 return PyExc_AttributeError;
336 }
337};
338
339// Translates to Python LinAlgError
340struct LinAlgError : public PyTorchError {
341 LinAlgError(const char* format, ...) TORCH_FORMAT_FUNC(2, 3);
342 PyObject* python_type() override {
343 return THPException_LinAlgError;
344 }
345};
346
347// ATen warning handler for Python
348struct PyWarningHandler {
349 // Move actual handler into a separate class with a noexcept
350 // destructor. Otherwise, we need to force all WarningHandler
351 // subclasses to have a noexcept(false) destructor.
352 struct InternalHandler : at::WarningHandler {
353 ~InternalHandler() override = default;
354 void process(const c10::Warning& warning) override;
355
356 std::vector<c10::Warning> warning_buffer_;
357 };
358
359 public:
360 /// See NOTE [ Conversion Cpp Python Warning ] for noexcept justification
361 TORCH_API PyWarningHandler() noexcept(true);
362 // NOLINTNEXTLINE(bugprone-exception-escape)
363 TORCH_API ~PyWarningHandler() noexcept(false);
364
365 /** Call if an exception has been thrown
366
367 * Necessary to determine if it is safe to throw from the desctructor since
368 * std::uncaught_exception is buggy on some platforms and generally
369 * unreliable across dynamic library calls.
370 */
371 void set_in_exception() {
372 in_exception_ = true;
373 }
374
375 private:
376 InternalHandler internal_handler_;
377 at::WarningHandler* prev_handler_;
378 bool in_exception_;
379};
380
381namespace detail {
382template <typename Func, size_t i>
383using Arg = typename invoke_traits<Func>::template arg<i>::type;
384
385template <typename Func, size_t... Is>
386auto wrap_pybind_function_impl_(
387 Func&& f,
388 std::index_sequence<Is...>,
389 bool release_gil) {
390 using result_type = typename invoke_traits<Func>::result_type;
391 namespace py = pybind11;
392
393 // f=f is needed to handle function references on older compilers
394 return [f = std::forward<Func>(f),
395 release_gil](Arg<Func, Is>... args) -> result_type {
396 HANDLE_TH_ERRORS
397 if (release_gil) {
398 py::gil_scoped_release no_gil;
399 return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
400 } else {
401 return c10::guts::invoke(f, std::forward<Arg<Func, Is>>(args)...);
402 }
403 END_HANDLE_TH_ERRORS_PYBIND
404 };
405}
406} // namespace detail
407
408// Wrap a function with TH error and warning handling.
409// Returns a function object suitable for registering with pybind11.
410template <typename Func>
411auto wrap_pybind_function(Func&& f) {
412 using traits = invoke_traits<Func>;
413 return torch::detail::wrap_pybind_function_impl_(
414 std::forward<Func>(f), std::make_index_sequence<traits::arity>{}, false);
415}
416
417// Wrap a function with TH error, warning handling and releases the GIL.
418// Returns a function object suitable for registering with pybind11.
419template <typename Func>
420auto wrap_pybind_function_no_gil(Func&& f) {
421 using traits = invoke_traits<Func>;
422 return torch::detail::wrap_pybind_function_impl_(
423 std::forward<Func>(f), std::make_index_sequence<traits::arity>{}, true);
424}
425
426} // namespace torch
427