1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15// clang-format off
16// These headers must be at the top, before including Python.h header
17// Otherwise, we get C2039 on MSVC due to 'copysign'
18#include "pybind11/complex.h"
19#include "pybind11/pybind11.h"
20// clang-format on
21
22#include "Python.h"
23#include "absl/container/flat_hash_map.h"
24#include "tensorflow/core/platform/logging.h"
25
namespace py = pybind11;
// Number of bytes reserved at the beginning of FastModuleObject for the fields
// of the base module object. FastModule_init DCHECKs that this value equals
// PyModule_Type.tp_basicsize, so a mismatch is only detected in builds where
// DCHECK is active.
constexpr int PY_MODULE_TYPE_TP_BASIC_SIZE = 56;
28
// Instance layout of objects whose type is FastModuleType (or a subclass).
// The struct is never constructed or destroyed through C++: both the
// constructor and the destructor are deleted. 'attr_map' is constructed with
// placement-new in FastModule_init and destroyed explicitly in
// FastModuleObjectDealloc.
struct FastModuleObject {
  // A dummy array that ensures enough size is reserved for FastModuleObject,
  // because it's inherited from PyModuleObject.
  const std::array<char, PY_MODULE_TYPE_TP_BASIC_SIZE> opaque_base_fields;
  // A cache that helps reduce attribute lookup overhead.
  // Keys are raw PyObject pointers, so a lookup matches on object identity
  // rather than on string equality.
  absl::flat_hash_map<PyObject *, PyObject *> attr_map;
  // Pointer to the external getattribute function (set through
  // set_getattribute_callback).
  // NOTE(review): because the constructor is deleted, no code in this file
  // applies the '= nullptr' initializers below; presumably the fields rely on
  // the Python allocator zero-filling new objects - confirm.
  PyObject *cb_getattribute = nullptr;
  // Pointer to the external getattr function (set through
  // set_getattr_callback).
  PyObject *cb_getattr = nullptr;

  FastModuleObject() = delete;
  ~FastModuleObject() = delete;
  // Reinterprets 'obj' as a FastModuleObject; defined after FastModuleType.
  static FastModuleObject *UncheckedCast(PyObject *obj);
};
45
46static int FastModule_init(FastModuleObject *self, PyObject *args,
47 PyObject *kwds) {
48 DCHECK_EQ(PY_MODULE_TYPE_TP_BASIC_SIZE, PyModule_Type.tp_basicsize);
49 if (PyModule_Type.tp_init(reinterpret_cast<PyObject *>(self), args, kwds) < 0)
50 return -1;
51 new (&(self->attr_map)) absl::flat_hash_map<PyObject *, PyObject *>();
52 return 0;
53}
54
55// Parses the input as a callable and checks the result.
56static PyObject *ParseFunc(PyObject *args) {
57 PyObject *func;
58 if (!PyArg_ParseTuple(args, "O:set_callback", &func)) return nullptr;
59 if (!PyCallable_Check(func)) {
60 PyErr_SetString(PyExc_TypeError, "input args must be callable");
61 return nullptr;
62 }
63 Py_INCREF(func); // Add a reference to new callback
64 return func;
65}
66
67// Sets the pointer 'cb_getattribute' in the FastModuleObject object
68// corresponding to 'self'.
69static PyObject *SetGetattributeCallback(PyObject *self, PyObject *args) {
70 PyObject *func = ParseFunc(args);
71 // Dispose of previous callback
72 Py_XDECREF(FastModuleObject::UncheckedCast(self)->cb_getattribute);
73 // Remember new callback
74 FastModuleObject::UncheckedCast(self)->cb_getattribute = func;
75 Py_RETURN_NONE;
76}
77
78// Sets the pointer 'cb_getattr' in the FastModuleObject object
79// corresponding to 'self'.
80static PyObject *SetGetattrCallback(PyObject *self, PyObject *args) {
81 PyObject *func = ParseFunc(args);
82 // Dispose of previous callback
83 Py_XDECREF(FastModuleObject::UncheckedCast(self)->cb_getattr);
84 // Remember new callback
85 FastModuleObject::UncheckedCast(self)->cb_getattr = func;
86 Py_RETURN_NONE;
87}
88
89// Inserts or updates a key-value pair in the cache 'attr_map'
90// of the FastModuleObject object corresponding to 'self'.
91static PyObject *FastDictInsert(FastModuleObject *self, PyObject *args) {
92 PyObject *name, *value;
93 if (!PyArg_ParseTuple(args, "OO", &name, &value)) {
94 PyErr_SetString(PyExc_TypeError, "_fastdict_insert: incorrect inputs");
95 return nullptr;
96 }
97 auto &attr_map = self->attr_map;
98 if (attr_map.find(name) != attr_map.end()) {
99 Py_DECREF(name);
100 Py_DECREF(value);
101 }
102 attr_map.insert_or_assign(name, value);
103 // Increment the reference count
104 Py_INCREF(name);
105 Py_INCREF(value);
106 // Properly handle returning Py_None
107 Py_RETURN_NONE;
108}
109
110// Gets a value from a key in the cache 'attr_map'
111// of the FastModuleObject object corresponding to 'self'.
112static PyObject *FastDictGet(FastModuleObject *self, PyObject *args) {
113 PyObject *name;
114 if (!PyArg_ParseTuple(args, "O", &name)) {
115 PyErr_SetString(PyExc_TypeError, "_fastdict_get: incorrect inputs");
116 return nullptr;
117 }
118 auto &attr_map = self->attr_map;
119 auto result = attr_map.find(name);
120 if (result != attr_map.end()) {
121 PyObject *value = result->second;
122 Py_INCREF(value);
123 return value;
124 }
125 // Copied from CPython's moduleobject.c
126 PyErr_Format(PyExc_KeyError, "module has no attribute '%U'", name);
127 return nullptr;
128}
129
130// Returns true if a key exists in the cache 'attr_map'
131// of the FastModuleObject object corresponding to 'self',
132// otherwise returns false.
133static PyObject *FastDictContains(FastModuleObject *self, PyObject *args) {
134 PyObject *name;
135 if (!PyArg_ParseTuple(args, "O", &name)) {
136 PyErr_SetString(PyExc_TypeError, "_fastdict_key_in: incorrect inputs");
137 return nullptr;
138 }
139 const auto &attr_map = self->attr_map;
140 const auto result = attr_map.contains(name);
141 if (result) {
142 // Properly handle returning Py_True
143 Py_RETURN_TRUE;
144 }
145 // Properly handle returning Py_False
146 Py_RETURN_FALSE;
147}
148
149// Calls a function 'func' with inputs 'self' and 'args'.
150static PyObject *CallFunc(FastModuleObject *self, PyObject *args,
151 PyObject *func) {
152 if (func == nullptr) {
153 PyErr_SetString(PyExc_NameError,
154 "Attempting to call a callback that was not defined");
155 return nullptr;
156 }
157 PyObject *name;
158 if (!PyArg_ParseTuple(args, "O", &name)) {
159 PyErr_SetString(PyExc_TypeError, "CallFunc: incorrect inputs");
160 return nullptr;
161 }
162 PyObject *arglist = Py_BuildValue("(OO)", self, name);
163 auto result = PyObject_CallObject(func, arglist);
164 Py_DECREF(arglist);
165 return result;
166}
167
// Method table installed as tp_methods of FastModuleType.
// The _fastdict_* entries need a reinterpret_cast because those functions take
// 'FastModuleObject *' rather than 'PyObject *' as their first parameter.
// The all-null entry at the end terminates the table.
static PyMethodDef FastModule_methods[] = {
    {"_fastdict_insert", reinterpret_cast<PyCFunction>(FastDictInsert),
     METH_VARARGS, "Registers a method to the fast lookup table."},
    {"_fastdict_get", reinterpret_cast<PyCFunction>(FastDictGet), METH_VARARGS,
     "Gets a method from the fast lookup table."},
    {"_fastdict_key_in", reinterpret_cast<PyCFunction>(FastDictContains),
     METH_VARARGS, "Checks if a method exists in the fast lookup table."},
    {"set_getattribute_callback", SetGetattributeCallback, METH_VARARGS,
     "Defines the callback function to replace __getattribute__"},
    {"set_getattr_callback", SetGetattrCallback, METH_VARARGS,
     "Defines the callback function to replace __getattr__"},
    {nullptr, nullptr, 0, nullptr},
};
181
182// Attempts to get the attribute based on 'name' as the key in cache 'attr_map'
183// of the FastModuleObject object corresponding to 'module'.
184// If the lookup fails in the cache, either uses
185// a user-defined callback 'cb_getattribute'
186// or the default 'tp_getattro' function to look for the attribute.
187static PyObject *FastTpGetattro(PyObject *module, PyObject *name) {
188 FastModuleObject *fast_module = FastModuleObject::UncheckedCast(module);
189 auto &attr_map = fast_module->attr_map;
190 auto it = attr_map.find(name);
191 // If the attribute lookup is successful in the cache, directly return it.
192 if (it != attr_map.end()) {
193 PyObject *value = it->second;
194 Py_INCREF(value);
195 return value;
196 }
197 PyObject *arglist = Py_BuildValue("(O)", name);
198 PyObject *result;
199 // Prefer the customized callback function over the default function.
200 if (fast_module->cb_getattribute != nullptr) {
201 result = CallFunc(fast_module, arglist, fast_module->cb_getattribute);
202 } else {
203 result = PyModule_Type.tp_getattro(module, name);
204 }
205 // Return result if it's found
206 if (result != nullptr) {
207 return result;
208 }
209 // If the default lookup fails and an AttributeError is raised,
210 // clear the error status before using the __getattr__ callback function.
211 auto is_error = PyErr_Occurred();
212 if (is_error && PyErr_ExceptionMatches(PyExc_AttributeError) &&
213 fast_module->cb_getattr != nullptr) {
214 PyErr_Clear();
215 return CallFunc(fast_module, arglist, fast_module->cb_getattr);
216 }
217 // If all options were used up
218 return result;
219}
220
221// Customized destructor for FastModuleType.tp_dealloc
222// In addition to default behavior it also clears up the contents in attr_map.
223static void FastModuleObjectDealloc(PyObject *module) {
224 auto &attr_map = FastModuleObject::UncheckedCast(module)->attr_map;
225 for (auto &it : attr_map) {
226 Py_DECREF(it.first);
227 Py_DECREF(it.second);
228 }
229 attr_map.~flat_hash_map<PyObject *, PyObject *>();
230 Py_TYPE(module)->tp_free(module);
231}
232
// Static type object for FastModuleType, built by an immediately-invoked
// lambda so that only the slots used here are spelled out.
// tp_base and tp_setattro are not set here; they are assigned later inside
// PYBIND11_MODULE.
static PyTypeObject FastModuleType = []() {
  PyTypeObject obj = {PyVarObject_HEAD_INIT(&PyType_Type, 0)};
  obj.tp_name = "fast_module_type.FastModuleType";
  // Instances are laid out as FastModuleObject.
  obj.tp_basicsize = sizeof(FastModuleObject);
  obj.tp_itemsize = 0;
  obj.tp_dealloc = FastModuleObjectDealloc;
  obj.tp_getattro = FastTpGetattro;
  // Py_TPFLAGS_BASETYPE allows the type to be subclassed.
  obj.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  obj.tp_doc = "FastModuleType objects";
  obj.tp_methods = FastModule_methods;
  // The cast is needed because FastModule_init takes 'FastModuleObject *'
  // instead of 'PyObject *' as its first parameter.
  obj.tp_init = reinterpret_cast<initproc>(FastModule_init);
  return obj;
}();
246
247// Returns true if the type of 'obj' or any of its parent class
248// is equal to 'target'. Otherwise returns false.
249bool IsAnyBaseSameType(const PyObject *obj, const PyTypeObject *target) {
250 auto *tp = Py_TYPE(obj);
251 while (true) {
252 if (tp == target) return true;
253 // If the default type is found, there is no need to search further
254 if (tp == &PyBaseObject_Type) break;
255 tp = tp->tp_base;
256 }
257 return false;
258}
259
260// Casts 'obj' to 'FastModuleObject *'.
261// Conducts a check only in non-optimized builds.
262FastModuleObject *FastModuleObject::UncheckedCast(PyObject *obj) {
263 DCHECK(IsAnyBaseSameType(obj, &FastModuleType));
264 return reinterpret_cast<FastModuleObject *>(obj);
265}
266
267PYBIND11_MODULE(fast_module_type, m) {
268 FastModuleType.tp_base = &PyModule_Type;
269 FastModuleType.tp_setattro = [](PyObject *module, PyObject *name,
270 PyObject *value) -> int {
271 auto &attr_map = FastModuleObject::UncheckedCast(module)->attr_map;
272 if (attr_map.find(name) != attr_map.end()) {
273 Py_DECREF(name);
274 Py_DECREF(value);
275 }
276 attr_map.insert_or_assign(name, value);
277 // Increment the reference count
278 Py_INCREF(name);
279 Py_INCREF(value);
280 PyObject_GenericSetAttr(module, name, value);
281 return 0;
282 };
283
284 m.doc() = R"pbdoc(
285 fast_module_type
286 -----
287 )pbdoc";
288 // Use getter function to hold attributes rather than pybind11's m.attr due to
289 // b/145559202.
290 m.def(
291 "get_fast_module_type_class",
292 []() {
293 return py::cast<py::object>(
294 reinterpret_cast<PyObject *>(&FastModuleType));
295 },
296 py::return_value_policy::reference);
297}
298