1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | // Python bindings for tensorflow/python/framework/python_api_dispatcher.h. |
16 | |
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/python/framework/python_api_dispatcher.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
22 | |
23 | namespace py = pybind11; |
24 | |
25 | using tensorflow::py_dispatch::PyInstanceChecker; |
26 | using tensorflow::py_dispatch::PyListChecker; |
27 | using tensorflow::py_dispatch::PySignatureChecker; |
28 | using tensorflow::py_dispatch::PythonAPIDispatcher; |
29 | using tensorflow::py_dispatch::PyTypeChecker; |
30 | using tensorflow::py_dispatch::PyUnionChecker; |
31 | |
32 | namespace { |
33 | |
34 | py::object Dispatch(PythonAPIDispatcher* self, py::handle args, |
35 | py::handle kwargs) { |
36 | auto result = self->Dispatch(args.ptr(), kwargs.ptr()); |
37 | if (result == nullptr) { |
38 | throw py::error_already_set(); |
39 | } else { |
40 | return py::reinterpret_steal<py::object>(result.release()); |
41 | } |
42 | } |
43 | |
44 | PythonAPIDispatcher MakePythonAPIDispatcher( |
45 | const std::string& api_name, const std::vector<std::string>& arg_names, |
46 | py::handle defaults) { |
47 | std::vector<const char*> name_strs; |
48 | name_strs.reserve(arg_names.size()); |
49 | for (const auto& name : arg_names) { |
50 | name_strs.push_back(name.c_str()); |
51 | } |
52 | absl::Span<const char*> arg_names_span(name_strs); |
53 | if (defaults.ptr() == Py_None) { |
54 | return PythonAPIDispatcher(api_name, arg_names_span, {}); |
55 | } else { |
56 | tensorflow::Safe_PyObjectPtr fast_defaults( |
57 | PySequence_Fast(defaults.ptr(), "defaults is not a sequence" )); |
58 | if (!fast_defaults) { |
59 | throw py::error_already_set(); |
60 | } |
61 | return PythonAPIDispatcher( |
62 | api_name, arg_names_span, |
63 | absl::MakeSpan(PySequence_Fast_ITEMS(fast_defaults.get()), |
64 | PySequence_Fast_GET_SIZE(fast_defaults.get()))); |
65 | } |
66 | } |
67 | |
68 | } // namespace |
69 | |
70 | PYBIND11_MODULE(_pywrap_python_api_dispatcher, m) { |
71 | py::enum_<PyTypeChecker::MatchType>(m, "MatchType" ) |
72 | .value("NO_MATCH" , PyTypeChecker::MatchType::NO_MATCH) |
73 | .value("MATCH" , PyTypeChecker::MatchType::MATCH) |
74 | .value("MATCH_DISPATCHABLE" , PyTypeChecker::MatchType::MATCH_DISPATCHABLE) |
75 | .export_values(); |
76 | |
77 | py::class_<PyTypeChecker, std::shared_ptr<PyTypeChecker>>(m, "PyTypeChecker" ) |
78 | .def("Check" , [](PyTypeChecker* self, |
79 | py::handle value) { return self->Check(value.ptr()); }) |
80 | .def("cost" , &PyTypeChecker::cost) |
81 | .def("cache_size" , |
82 | [](PyTypeChecker* self) { |
83 | return static_cast<PyInstanceChecker*>(self)->cache_size(); |
84 | }) |
85 | .def("__repr__" , [](PyTypeChecker* self) { |
86 | return absl::StrCat("<PyTypeChecker " , self->DebugString(), ">" ); |
87 | }); |
88 | |
89 | py::class_<PySignatureChecker>(m, "PySignatureChecker" ) |
90 | .def(py::init< |
91 | std::vector<std::pair<int, std::shared_ptr<PyTypeChecker>>>>()) |
92 | .def("CheckCanonicalizedArgs" , |
93 | [](PySignatureChecker* self, py::tuple args) { |
94 | tensorflow::Safe_PyObjectPtr seq(PySequence_Fast(args.ptr(), "" )); |
95 | PyObject** items = PySequence_Fast_ITEMS(seq.get()); |
96 | int n = PySequence_Fast_GET_SIZE(seq.get()); |
97 | return self->CheckCanonicalizedArgs(absl::MakeSpan(items, n)); |
98 | }) |
99 | .def("__repr__" , [](PySignatureChecker* self) { |
100 | return absl::StrCat("<PySignatureChecker " , self->DebugString(), ">" ); |
101 | }); |
102 | |
103 | py::class_<PythonAPIDispatcher>(m, "PythonAPIDispatcher" ) |
104 | .def(py::init(&MakePythonAPIDispatcher)) |
105 | .def("Register" , |
106 | [](PythonAPIDispatcher* self, PySignatureChecker signature_checker, |
107 | py::handle func) { |
108 | return self->Register(signature_checker, func.ptr()); |
109 | }) |
110 | .def("Dispatch" , &Dispatch) |
111 | .def("Unregister" , |
112 | [](PythonAPIDispatcher* self, py::handle func) { |
113 | return self->Unregister(func.ptr()); |
114 | }) |
115 | .def("__repr__" , &PythonAPIDispatcher::DebugString); |
116 | |
117 | m.def("MakeInstanceChecker" , [](py::args py_classes) { |
118 | std::vector<PyObject*> py_classes_vector; |
119 | py_classes_vector.reserve(py_classes.size()); |
120 | for (auto& cls : py_classes) { |
121 | if (!PyType_Check(cls.ptr())) { |
122 | throw py::type_error("`*py_classes` must be a tuple of types." ); |
123 | } |
124 | py_classes_vector.push_back(cls.ptr()); |
125 | } |
126 | return std::shared_ptr<PyTypeChecker>( |
127 | std::make_shared<PyInstanceChecker>(py_classes_vector)); |
128 | }); |
129 | m.def("MakeListChecker" , [](std::shared_ptr<PyTypeChecker> elt_type) { |
130 | return std::shared_ptr<PyTypeChecker>( |
131 | std::make_shared<PyListChecker>(elt_type)); |
132 | }); |
133 | m.def("MakeUnionChecker" , |
134 | [](const std::vector<std::shared_ptr<PyTypeChecker>>& options) { |
135 | return std::shared_ptr<PyTypeChecker>( |
136 | std::make_shared<PyUnionChecker>(options)); |
137 | }); |
138 | m.def("register_dispatchable_type" , [](py::handle py_class) { |
139 | if (!tensorflow::py_dispatch::RegisterDispatchableType(py_class.ptr())) { |
140 | throw py::error_already_set(); |
141 | } else { |
142 | return py_class; |
143 | } |
144 | }); |
145 | } |
146 | |