1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15// Python bindings for tensorflow/python/framework/python_api_dispatcher.h.
16
17#include "pybind11/pybind11.h"
18#include "pybind11/pytypes.h"
19#include "pybind11/stl.h"
20#include "tensorflow/python/framework/python_api_dispatcher.h"
21#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
22
23namespace py = pybind11;
24
25using tensorflow::py_dispatch::PyInstanceChecker;
26using tensorflow::py_dispatch::PyListChecker;
27using tensorflow::py_dispatch::PySignatureChecker;
28using tensorflow::py_dispatch::PythonAPIDispatcher;
29using tensorflow::py_dispatch::PyTypeChecker;
30using tensorflow::py_dispatch::PyUnionChecker;
31
32namespace {
33
34py::object Dispatch(PythonAPIDispatcher* self, py::handle args,
35 py::handle kwargs) {
36 auto result = self->Dispatch(args.ptr(), kwargs.ptr());
37 if (result == nullptr) {
38 throw py::error_already_set();
39 } else {
40 return py::reinterpret_steal<py::object>(result.release());
41 }
42}
43
44PythonAPIDispatcher MakePythonAPIDispatcher(
45 const std::string& api_name, const std::vector<std::string>& arg_names,
46 py::handle defaults) {
47 std::vector<const char*> name_strs;
48 name_strs.reserve(arg_names.size());
49 for (const auto& name : arg_names) {
50 name_strs.push_back(name.c_str());
51 }
52 absl::Span<const char*> arg_names_span(name_strs);
53 if (defaults.ptr() == Py_None) {
54 return PythonAPIDispatcher(api_name, arg_names_span, {});
55 } else {
56 tensorflow::Safe_PyObjectPtr fast_defaults(
57 PySequence_Fast(defaults.ptr(), "defaults is not a sequence"));
58 if (!fast_defaults) {
59 throw py::error_already_set();
60 }
61 return PythonAPIDispatcher(
62 api_name, arg_names_span,
63 absl::MakeSpan(PySequence_Fast_ITEMS(fast_defaults.get()),
64 PySequence_Fast_GET_SIZE(fast_defaults.get())));
65 }
66}
67
68} // namespace
69
70PYBIND11_MODULE(_pywrap_python_api_dispatcher, m) {
71 py::enum_<PyTypeChecker::MatchType>(m, "MatchType")
72 .value("NO_MATCH", PyTypeChecker::MatchType::NO_MATCH)
73 .value("MATCH", PyTypeChecker::MatchType::MATCH)
74 .value("MATCH_DISPATCHABLE", PyTypeChecker::MatchType::MATCH_DISPATCHABLE)
75 .export_values();
76
77 py::class_<PyTypeChecker, std::shared_ptr<PyTypeChecker>>(m, "PyTypeChecker")
78 .def("Check", [](PyTypeChecker* self,
79 py::handle value) { return self->Check(value.ptr()); })
80 .def("cost", &PyTypeChecker::cost)
81 .def("cache_size",
82 [](PyTypeChecker* self) {
83 return static_cast<PyInstanceChecker*>(self)->cache_size();
84 })
85 .def("__repr__", [](PyTypeChecker* self) {
86 return absl::StrCat("<PyTypeChecker ", self->DebugString(), ">");
87 });
88
89 py::class_<PySignatureChecker>(m, "PySignatureChecker")
90 .def(py::init<
91 std::vector<std::pair<int, std::shared_ptr<PyTypeChecker>>>>())
92 .def("CheckCanonicalizedArgs",
93 [](PySignatureChecker* self, py::tuple args) {
94 tensorflow::Safe_PyObjectPtr seq(PySequence_Fast(args.ptr(), ""));
95 PyObject** items = PySequence_Fast_ITEMS(seq.get());
96 int n = PySequence_Fast_GET_SIZE(seq.get());
97 return self->CheckCanonicalizedArgs(absl::MakeSpan(items, n));
98 })
99 .def("__repr__", [](PySignatureChecker* self) {
100 return absl::StrCat("<PySignatureChecker ", self->DebugString(), ">");
101 });
102
103 py::class_<PythonAPIDispatcher>(m, "PythonAPIDispatcher")
104 .def(py::init(&MakePythonAPIDispatcher))
105 .def("Register",
106 [](PythonAPIDispatcher* self, PySignatureChecker signature_checker,
107 py::handle func) {
108 return self->Register(signature_checker, func.ptr());
109 })
110 .def("Dispatch", &Dispatch)
111 .def("Unregister",
112 [](PythonAPIDispatcher* self, py::handle func) {
113 return self->Unregister(func.ptr());
114 })
115 .def("__repr__", &PythonAPIDispatcher::DebugString);
116
117 m.def("MakeInstanceChecker", [](py::args py_classes) {
118 std::vector<PyObject*> py_classes_vector;
119 py_classes_vector.reserve(py_classes.size());
120 for (auto& cls : py_classes) {
121 if (!PyType_Check(cls.ptr())) {
122 throw py::type_error("`*py_classes` must be a tuple of types.");
123 }
124 py_classes_vector.push_back(cls.ptr());
125 }
126 return std::shared_ptr<PyTypeChecker>(
127 std::make_shared<PyInstanceChecker>(py_classes_vector));
128 });
129 m.def("MakeListChecker", [](std::shared_ptr<PyTypeChecker> elt_type) {
130 return std::shared_ptr<PyTypeChecker>(
131 std::make_shared<PyListChecker>(elt_type));
132 });
133 m.def("MakeUnionChecker",
134 [](const std::vector<std::shared_ptr<PyTypeChecker>>& options) {
135 return std::shared_ptr<PyTypeChecker>(
136 std::make_shared<PyUnionChecker>(options));
137 });
138 m.def("register_dispatchable_type", [](py::handle py_class) {
139 if (!tensorflow::py_dispatch::RegisterDispatchableType(py_class.ptr())) {
140 throw py::error_already_set();
141 } else {
142 return py_class;
143 }
144 });
145}
146