1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | // clang-format off |
16 | // These headers must be at the top, before including Python.h header |
17 | // Otherwise, we get C2039 on MSVC due to 'copysign' |
18 | #include "pybind11/complex.h" |
19 | #include "pybind11/pybind11.h" |
20 | // clang-format on |
21 | |
22 | #include "Python.h" |
23 | #include "absl/container/flat_hash_map.h" |
24 | #include "tensorflow/core/platform/logging.h" |
25 | |
26 | namespace py = pybind11; |
27 | constexpr int PY_MODULE_TYPE_TP_BASIC_SIZE = 56; |
28 | |
// Instance layout of FastModuleType.
//
// C++ never constructs or destroys this struct directly (both special members
// are deleted). Instances are reached by reinterpret_cast-ing a PyObject*, so
// the member order and sizes below ARE the binary layout:
//   [ base PyModuleObject bytes | attr_map | cb_getattribute | cb_getattr ]
// attr_map is constructed in place by FastModule_init and destroyed explicitly
// by FastModuleObjectDealloc.
struct FastModuleObject {
  // Placeholder for the fields inherited from PyModuleObject. Its size must be
  // at least PyModule_Type.tp_basicsize, otherwise the base object would
  // overlap the members that follow (FastModule_init DCHECKs equality).
  const std::array<char, PY_MODULE_TYPE_TP_BASIC_SIZE> opaque_base_fields;
  // Attribute cache that short-circuits attribute lookup. Keys are the name
  // objects compared by pointer identity (default pointer hash); the cache
  // holds a strong reference to every key and value it stores.
  absl::flat_hash_map<PyObject *, PyObject *> attr_map;
  // Strong reference to the callable installed by set_getattribute_callback(),
  // or nullptr when none is installed.
  PyObject *cb_getattribute = nullptr;
  // Strong reference to the callable installed by set_getattr_callback(), or
  // nullptr when none is installed.
  PyObject *cb_getattr = nullptr;

  FastModuleObject() = delete;
  ~FastModuleObject() = delete;
  // Reinterprets 'obj' as a FastModuleObject; the type is checked only by a
  // DCHECK (debug builds).
  static FastModuleObject *UncheckedCast(PyObject *obj);
};
45 | |
46 | static int FastModule_init(FastModuleObject *self, PyObject *args, |
47 | PyObject *kwds) { |
48 | DCHECK_EQ(PY_MODULE_TYPE_TP_BASIC_SIZE, PyModule_Type.tp_basicsize); |
49 | if (PyModule_Type.tp_init(reinterpret_cast<PyObject *>(self), args, kwds) < 0) |
50 | return -1; |
51 | new (&(self->attr_map)) absl::flat_hash_map<PyObject *, PyObject *>(); |
52 | return 0; |
53 | } |
54 | |
55 | // Parses the input as a callable and checks the result. |
56 | static PyObject *ParseFunc(PyObject *args) { |
57 | PyObject *func; |
58 | if (!PyArg_ParseTuple(args, "O:set_callback" , &func)) return nullptr; |
59 | if (!PyCallable_Check(func)) { |
60 | PyErr_SetString(PyExc_TypeError, "input args must be callable" ); |
61 | return nullptr; |
62 | } |
63 | Py_INCREF(func); // Add a reference to new callback |
64 | return func; |
65 | } |
66 | |
67 | // Sets the pointer 'cb_getattribute' in the FastModuleObject object |
68 | // corresponding to 'self'. |
69 | static PyObject *SetGetattributeCallback(PyObject *self, PyObject *args) { |
70 | PyObject *func = ParseFunc(args); |
71 | // Dispose of previous callback |
72 | Py_XDECREF(FastModuleObject::UncheckedCast(self)->cb_getattribute); |
73 | // Remember new callback |
74 | FastModuleObject::UncheckedCast(self)->cb_getattribute = func; |
75 | Py_RETURN_NONE; |
76 | } |
77 | |
78 | // Sets the pointer 'cb_getattr' in the FastModuleObject object |
79 | // corresponding to 'self'. |
80 | static PyObject *SetGetattrCallback(PyObject *self, PyObject *args) { |
81 | PyObject *func = ParseFunc(args); |
82 | // Dispose of previous callback |
83 | Py_XDECREF(FastModuleObject::UncheckedCast(self)->cb_getattr); |
84 | // Remember new callback |
85 | FastModuleObject::UncheckedCast(self)->cb_getattr = func; |
86 | Py_RETURN_NONE; |
87 | } |
88 | |
89 | // Inserts or updates a key-value pair in the cache 'attr_map' |
90 | // of the FastModuleObject object corresponding to 'self'. |
91 | static PyObject *FastDictInsert(FastModuleObject *self, PyObject *args) { |
92 | PyObject *name, *value; |
93 | if (!PyArg_ParseTuple(args, "OO" , &name, &value)) { |
94 | PyErr_SetString(PyExc_TypeError, "_fastdict_insert: incorrect inputs" ); |
95 | return nullptr; |
96 | } |
97 | auto &attr_map = self->attr_map; |
98 | if (attr_map.find(name) != attr_map.end()) { |
99 | Py_DECREF(name); |
100 | Py_DECREF(value); |
101 | } |
102 | attr_map.insert_or_assign(name, value); |
103 | // Increment the reference count |
104 | Py_INCREF(name); |
105 | Py_INCREF(value); |
106 | // Properly handle returning Py_None |
107 | Py_RETURN_NONE; |
108 | } |
109 | |
110 | // Gets a value from a key in the cache 'attr_map' |
111 | // of the FastModuleObject object corresponding to 'self'. |
112 | static PyObject *FastDictGet(FastModuleObject *self, PyObject *args) { |
113 | PyObject *name; |
114 | if (!PyArg_ParseTuple(args, "O" , &name)) { |
115 | PyErr_SetString(PyExc_TypeError, "_fastdict_get: incorrect inputs" ); |
116 | return nullptr; |
117 | } |
118 | auto &attr_map = self->attr_map; |
119 | auto result = attr_map.find(name); |
120 | if (result != attr_map.end()) { |
121 | PyObject *value = result->second; |
122 | Py_INCREF(value); |
123 | return value; |
124 | } |
125 | // Copied from CPython's moduleobject.c |
126 | PyErr_Format(PyExc_KeyError, "module has no attribute '%U'" , name); |
127 | return nullptr; |
128 | } |
129 | |
130 | // Returns true if a key exists in the cache 'attr_map' |
131 | // of the FastModuleObject object corresponding to 'self', |
132 | // otherwise returns false. |
133 | static PyObject *FastDictContains(FastModuleObject *self, PyObject *args) { |
134 | PyObject *name; |
135 | if (!PyArg_ParseTuple(args, "O" , &name)) { |
136 | PyErr_SetString(PyExc_TypeError, "_fastdict_key_in: incorrect inputs" ); |
137 | return nullptr; |
138 | } |
139 | const auto &attr_map = self->attr_map; |
140 | const auto result = attr_map.contains(name); |
141 | if (result) { |
142 | // Properly handle returning Py_True |
143 | Py_RETURN_TRUE; |
144 | } |
145 | // Properly handle returning Py_False |
146 | Py_RETURN_FALSE; |
147 | } |
148 | |
149 | // Calls a function 'func' with inputs 'self' and 'args'. |
150 | static PyObject *CallFunc(FastModuleObject *self, PyObject *args, |
151 | PyObject *func) { |
152 | if (func == nullptr) { |
153 | PyErr_SetString(PyExc_NameError, |
154 | "Attempting to call a callback that was not defined" ); |
155 | return nullptr; |
156 | } |
157 | PyObject *name; |
158 | if (!PyArg_ParseTuple(args, "O" , &name)) { |
159 | PyErr_SetString(PyExc_TypeError, "CallFunc: incorrect inputs" ); |
160 | return nullptr; |
161 | } |
162 | PyObject *arglist = Py_BuildValue("(OO)" , self, name); |
163 | auto result = PyObject_CallObject(func, arglist); |
164 | Py_DECREF(arglist); |
165 | return result; |
166 | } |
167 | |
168 | static PyMethodDef FastModule_methods[] = { |
169 | {"_fastdict_insert" , reinterpret_cast<PyCFunction>(FastDictInsert), |
170 | METH_VARARGS, "Registers a method to the fast lookup table." }, |
171 | {"_fastdict_get" , reinterpret_cast<PyCFunction>(FastDictGet), METH_VARARGS, |
172 | "Gets a method from the fast lookup table." }, |
173 | {"_fastdict_key_in" , reinterpret_cast<PyCFunction>(FastDictContains), |
174 | METH_VARARGS, "Checks if a method exists in the fast lookup table." }, |
175 | {"set_getattribute_callback" , SetGetattributeCallback, METH_VARARGS, |
176 | "Defines the callback function to replace __getattribute__" }, |
177 | {"set_getattr_callback" , SetGetattrCallback, METH_VARARGS, |
178 | "Defines the callback function to replace __getattr__" }, |
179 | {nullptr, nullptr, 0, nullptr}, |
180 | }; |
181 | |
182 | // Attempts to get the attribute based on 'name' as the key in cache 'attr_map' |
183 | // of the FastModuleObject object corresponding to 'module'. |
184 | // If the lookup fails in the cache, either uses |
185 | // a user-defined callback 'cb_getattribute' |
186 | // or the default 'tp_getattro' function to look for the attribute. |
187 | static PyObject *FastTpGetattro(PyObject *module, PyObject *name) { |
188 | FastModuleObject *fast_module = FastModuleObject::UncheckedCast(module); |
189 | auto &attr_map = fast_module->attr_map; |
190 | auto it = attr_map.find(name); |
191 | // If the attribute lookup is successful in the cache, directly return it. |
192 | if (it != attr_map.end()) { |
193 | PyObject *value = it->second; |
194 | Py_INCREF(value); |
195 | return value; |
196 | } |
197 | PyObject *arglist = Py_BuildValue("(O)" , name); |
198 | PyObject *result; |
199 | // Prefer the customized callback function over the default function. |
200 | if (fast_module->cb_getattribute != nullptr) { |
201 | result = CallFunc(fast_module, arglist, fast_module->cb_getattribute); |
202 | } else { |
203 | result = PyModule_Type.tp_getattro(module, name); |
204 | } |
205 | // Return result if it's found |
206 | if (result != nullptr) { |
207 | return result; |
208 | } |
209 | // If the default lookup fails and an AttributeError is raised, |
210 | // clear the error status before using the __getattr__ callback function. |
211 | auto is_error = PyErr_Occurred(); |
212 | if (is_error && PyErr_ExceptionMatches(PyExc_AttributeError) && |
213 | fast_module->cb_getattr != nullptr) { |
214 | PyErr_Clear(); |
215 | return CallFunc(fast_module, arglist, fast_module->cb_getattr); |
216 | } |
217 | // If all options were used up |
218 | return result; |
219 | } |
220 | |
221 | // Customized destructor for FastModuleType.tp_dealloc |
222 | // In addition to default behavior it also clears up the contents in attr_map. |
223 | static void FastModuleObjectDealloc(PyObject *module) { |
224 | auto &attr_map = FastModuleObject::UncheckedCast(module)->attr_map; |
225 | for (auto &it : attr_map) { |
226 | Py_DECREF(it.first); |
227 | Py_DECREF(it.second); |
228 | } |
229 | attr_map.~flat_hash_map<PyObject *, PyObject *>(); |
230 | Py_TYPE(module)->tp_free(module); |
231 | } |
232 | |
233 | static PyTypeObject FastModuleType = []() { |
234 | PyTypeObject obj = {PyVarObject_HEAD_INIT(&PyType_Type, 0)}; |
235 | obj.tp_name = "fast_module_type.FastModuleType" ; |
236 | obj.tp_basicsize = sizeof(FastModuleObject); |
237 | obj.tp_itemsize = 0; |
238 | obj.tp_dealloc = FastModuleObjectDealloc; |
239 | obj.tp_getattro = FastTpGetattro; |
240 | obj.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; |
241 | obj.tp_doc = "FastModuleType objects" ; |
242 | obj.tp_methods = FastModule_methods; |
243 | obj.tp_init = reinterpret_cast<initproc>(FastModule_init); |
244 | return obj; |
245 | }(); |
246 | |
247 | // Returns true if the type of 'obj' or any of its parent class |
248 | // is equal to 'target'. Otherwise returns false. |
249 | bool IsAnyBaseSameType(const PyObject *obj, const PyTypeObject *target) { |
250 | auto *tp = Py_TYPE(obj); |
251 | while (true) { |
252 | if (tp == target) return true; |
253 | // If the default type is found, there is no need to search further |
254 | if (tp == &PyBaseObject_Type) break; |
255 | tp = tp->tp_base; |
256 | } |
257 | return false; |
258 | } |
259 | |
260 | // Casts 'obj' to 'FastModuleObject *'. |
261 | // Conducts a check only in non-optimized builds. |
262 | FastModuleObject *FastModuleObject::UncheckedCast(PyObject *obj) { |
263 | DCHECK(IsAnyBaseSameType(obj, &FastModuleType)); |
264 | return reinterpret_cast<FastModuleObject *>(obj); |
265 | } |
266 | |
267 | PYBIND11_MODULE(fast_module_type, m) { |
268 | FastModuleType.tp_base = &PyModule_Type; |
269 | FastModuleType.tp_setattro = [](PyObject *module, PyObject *name, |
270 | PyObject *value) -> int { |
271 | auto &attr_map = FastModuleObject::UncheckedCast(module)->attr_map; |
272 | if (attr_map.find(name) != attr_map.end()) { |
273 | Py_DECREF(name); |
274 | Py_DECREF(value); |
275 | } |
276 | attr_map.insert_or_assign(name, value); |
277 | // Increment the reference count |
278 | Py_INCREF(name); |
279 | Py_INCREF(value); |
280 | PyObject_GenericSetAttr(module, name, value); |
281 | return 0; |
282 | }; |
283 | |
284 | m.doc() = R"pbdoc( |
285 | fast_module_type |
286 | ----- |
287 | )pbdoc" ; |
288 | // Use getter function to hold attributes rather than pybind11's m.attr due to |
289 | // b/145559202. |
290 | m.def( |
291 | "get_fast_module_type_class" , |
292 | []() { |
293 | return py::cast<py::object>( |
294 | reinterpret_cast<PyObject *>(&FastModuleType)); |
295 | }, |
296 | py::return_value_policy::reference); |
297 | } |
298 | |