1 | #pragma once |
2 | |
3 | #include <torch/csrc/python_headers.h> |
4 | #include <torch/csrc/utils/object_ptr.h> |
5 | #include <torch/csrc/utils/pybind.h> |
6 | #include <stdexcept> |
7 | #include <string> |
8 | |
9 | // Utilities for handling Python strings. Note that PyString, when defined, is |
10 | // the same as PyBytes. |
11 | |
12 | // Returns true if obj is a bytes/str or unicode object |
13 | // As of Python 3.6, this does not require the GIL |
14 | inline bool THPUtils_checkString(PyObject* obj) { |
15 | return PyBytes_Check(obj) || PyUnicode_Check(obj); |
16 | } |
17 | |
18 | // Unpacks PyBytes (PyString) or PyUnicode as std::string |
19 | // PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8. |
20 | // NOTE: this method requires the GIL |
21 | inline std::string THPUtils_unpackString(PyObject* obj) { |
22 | if (PyBytes_Check(obj)) { |
23 | size_t size = PyBytes_GET_SIZE(obj); |
24 | return std::string(PyBytes_AS_STRING(obj), size); |
25 | } |
26 | if (PyUnicode_Check(obj)) { |
27 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
28 | Py_ssize_t size; |
29 | const char* data = PyUnicode_AsUTF8AndSize(obj, &size); |
30 | if (!data) { |
31 | throw std::runtime_error("error unpacking string as utf-8" ); |
32 | } |
33 | return std::string(data, (size_t)size); |
34 | } |
35 | throw std::runtime_error("unpackString: expected bytes or unicode object" ); |
36 | } |
37 | |
38 | // Unpacks PyBytes (PyString) or PyUnicode as c10::string_view |
39 | // PyBytes are unpacked as-is. PyUnicode is unpacked as UTF-8. |
40 | // NOTE: If `obj` is destroyed, then the non-owning c10::string_view will |
41 | // become invalid. If the string needs to be accessed at any point after |
42 | // `obj` is destroyed, then the c10::string_view should be copied into |
43 | // a std::string, or another owning object, and kept alive. For an example, |
44 | // look at how IValue and autograd nodes handle c10::string_view arguments. |
45 | // NOTE: this method requires the GIL |
46 | inline c10::string_view THPUtils_unpackStringView(PyObject* obj) { |
47 | if (PyBytes_Check(obj)) { |
48 | size_t size = PyBytes_GET_SIZE(obj); |
49 | return c10::string_view(PyBytes_AS_STRING(obj), size); |
50 | } |
51 | if (PyUnicode_Check(obj)) { |
52 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
53 | Py_ssize_t size; |
54 | const char* data = PyUnicode_AsUTF8AndSize(obj, &size); |
55 | if (!data) { |
56 | throw std::runtime_error("error unpacking string as utf-8" ); |
57 | } |
58 | return c10::string_view(data, (size_t)size); |
59 | } |
60 | throw std::runtime_error("unpackString: expected bytes or unicode object" ); |
61 | } |
62 | |
63 | inline PyObject* THPUtils_packString(const char* str) { |
64 | return PyUnicode_FromString(str); |
65 | } |
66 | |
67 | inline PyObject* THPUtils_packString(const std::string& str) { |
68 | return PyUnicode_FromStringAndSize(str.c_str(), str.size()); |
69 | } |
70 | |
71 | inline PyObject* THPUtils_internString(const std::string& str) { |
72 | return PyUnicode_InternFromString(str.c_str()); |
73 | } |
74 | |
75 | // Precondition: THPUtils_checkString(obj) must be true |
76 | inline bool THPUtils_isInterned(PyObject* obj) { |
77 | return PyUnicode_CHECK_INTERNED(obj); |
78 | } |
79 | |
80 | // Precondition: THPUtils_checkString(obj) must be true |
81 | inline void THPUtils_internStringInPlace(PyObject** obj) { |
82 | PyUnicode_InternInPlace(obj); |
83 | } |
84 | |
85 | /* |
86 | * Reference: |
87 | * https://github.com/numpy/numpy/blob/f4c497c768e0646df740b647782df463825bfd27/numpy/core/src/common/get_attr_string.h#L42 |
88 | * |
89 | * Stripped down version of PyObject_GetAttrString, |
90 | * avoids lookups for None, tuple, and List objects, |
91 | * and doesn't create a PyErr since this code ignores it. |
92 | * |
93 | * This can be much faster then PyObject_GetAttrString where |
94 | * exceptions are not used by caller. |
95 | * |
96 | * 'obj' is the object to search for attribute. |
97 | * |
98 | * 'name' is the attribute to search for. |
99 | * |
100 | * Returns a py::object wrapping the return value. If the attribute lookup |
101 | * failed the value will be NULL. |
102 | * |
103 | */ |
104 | |
105 | // NOLINTNEXTLINE(clang-diagnostic-unused-function) |
106 | static py::object PyObject_FastGetAttrString(PyObject* obj, const char* name) { |
107 | PyTypeObject* tp = Py_TYPE(obj); |
108 | PyObject* res = (PyObject*)nullptr; |
109 | |
110 | /* Attribute referenced by (char *)name */ |
111 | if (tp->tp_getattr != nullptr) { |
112 | // This is OK per https://bugs.python.org/issue39620 |
113 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
114 | res = (*tp->tp_getattr)(obj, const_cast<char*>(name)); |
115 | if (res == nullptr) { |
116 | PyErr_Clear(); |
117 | } |
118 | } |
119 | /* Attribute referenced by (PyObject *)name */ |
120 | else if (tp->tp_getattro != nullptr) { |
121 | auto w = py::reinterpret_steal<py::object>(THPUtils_internString(name)); |
122 | if (w.ptr() == nullptr) { |
123 | return py::object(); |
124 | } |
125 | res = (*tp->tp_getattro)(obj, w.ptr()); |
126 | if (res == nullptr) { |
127 | PyErr_Clear(); |
128 | } |
129 | } |
130 | return py::reinterpret_steal<py::object>(res); |
131 | } |
132 | |