1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/python/framework/python_api_dispatcher.h"

#include <algorithm>
#include <set>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"
26
27namespace tensorflow {
28namespace py_dispatch {
29
30namespace {
31
32std::vector<Safe_PyObjectPtr>& GetRegisteredDispatchableTypes() {
33 static std::vector<Safe_PyObjectPtr>* registered_dispatchable_types =
34 new std::vector<Safe_PyObjectPtr>();
35 if (registered_dispatchable_types->empty()) {
36 static PyObject* composite_tensor =
37 swig::GetRegisteredPyObject("CompositeTensor");
38 Py_INCREF(composite_tensor);
39 registered_dispatchable_types->push_back(
40 Safe_PyObjectPtr(composite_tensor));
41 }
42 return *registered_dispatchable_types;
43}
44
45// Returns true if `py_class` is a registered dispatchable type.
46bool IsRegisteredDispatchableType(PyObject* py_class) {
47 DCheckPyGilState();
48 for (const auto& registered_type : GetRegisteredDispatchableTypes()) {
49 int result = PyObject_IsSubclass(py_class, registered_type.get());
50 if (result > 0) return true;
51 if (result < 0) PyErr_Clear();
52 }
53 return false;
54}
55
56// Raises an exception indicating that multiple dispatch targets matched.
57Safe_PyObjectPtr RaiseDispatchConflictError(const std::string& api_name,
58 PyObject* selected,
59 PyObject* target) {
60 Safe_PyObjectPtr s1(PyObject_Str(selected));
61 Safe_PyObjectPtr s2(PyObject_Str(target));
62 PyErr_SetString(PyExc_ValueError,
63 absl::StrCat("Multiple dispatch targets that were "
64 "registered with tf.dispatch_for (",
65 s1 ? PyUnicode_AsUTF8(s1.get()) : "?", " and ",
66 s2 ? PyUnicode_AsUTF8(s2.get()) : "?",
67 ") match the arguments to ", api_name)
68 .c_str());
69 return nullptr;
70}
71
72} // namespace
73
74bool RegisterDispatchableType(PyObject* py_class) {
75 DCheckPyGilState();
76 if (!PyType_Check(py_class)) {
77 PyErr_SetString(
78 PyExc_ValueError,
79 absl::StrCat("Expected a type object; got object with type ",
80 py_class->ob_type->tp_name)
81 .c_str());
82 return false;
83 }
84 if (IsRegisteredDispatchableType(py_class)) {
85 Safe_PyObjectPtr s(PyObject_Str(py_class));
86 PyErr_SetString(PyExc_ValueError,
87 absl::StrCat("Type ", s ? PyUnicode_AsUTF8(s.get()) : "?",
88 " (or one of its bases clases) has "
89 "already been registered")
90 .c_str());
91 return false;
92 }
93 Py_INCREF(py_class);
94 GetRegisteredDispatchableTypes().push_back(Safe_PyObjectPtr(py_class));
95 return true;
96}
97
// Creates a dispatcher for the API named `api_name`.
//
// `arg_names` and `defaults` are forwarded to `canonicalizer_`, which
// Dispatch() uses to turn (args, kwargs) into a positional argument list.
// `canonicalized_args_storage_` is constructed from
// canonicalizer_.GetArgSize(), and `canonicalized_args_span_` is a view over
// that storage; both members are reused by every Dispatch() call.
// NOTE(review): these initializers read earlier members (canonicalizer_, then
// canonicalized_args_storage_), so they rely on the member declaration order
// in the header -- confirm before reordering members there.
PythonAPIDispatcher::PythonAPIDispatcher(const std::string& api_name,
                                         absl::Span<const char*> arg_names,
                                         absl::Span<PyObject*> defaults)
    : api_name_(api_name),
      canonicalizer_(arg_names, defaults),
      canonicalized_args_storage_(canonicalizer_.GetArgSize()),
      canonicalized_args_span_(canonicalized_args_storage_) {}
105
106void PythonAPIDispatcher::Register(PySignatureChecker signature_checker,
107 PyObject* dispatch_target) {
108 DCheckPyGilState();
109 Py_INCREF(dispatch_target);
110 targets_.emplace_back(std::move(signature_checker),
111 Safe_PyObjectPtr(dispatch_target));
112}
113
114Safe_PyObjectPtr PythonAPIDispatcher::Dispatch(PyObject* args,
115 PyObject* kwargs) {
116 DCheckPyGilState();
117 if (kwargs == Py_None) {
118 kwargs = nullptr;
119 }
120 // Canonicalize args (so we don't need to deal with kwargs).
121 if (!canonicalizer_.Canonicalize(args, kwargs, canonicalized_args_span_)) {
122 return nullptr;
123 }
124
125 PyObject* selected = nullptr;
126 for (auto& target : targets_) {
127 if (target.first.CheckCanonicalizedArgs(canonicalized_args_span_)) {
128 if (selected && selected != target.second.get()) {
129 return RaiseDispatchConflictError(api_name_, selected,
130 target.second.get());
131 }
132 selected = target.second.get();
133 }
134 }
135 if (selected) {
136 return Safe_PyObjectPtr(PyObject_Call(selected, args, kwargs));
137 } else {
138 Py_INCREF(Py_NotImplemented);
139 return Safe_PyObjectPtr(Py_NotImplemented);
140 }
141}
142
143// TODO(b/194903203) Raise an error if `func` is not registered.
144void PythonAPIDispatcher::Unregister(PyObject* func) {
145 DCheckPyGilState();
146 using DispatchTargetPair = std::pair<PySignatureChecker, Safe_PyObjectPtr>;
147 targets_.erase(std::remove_if(targets_.begin(), targets_.end(),
148 [func](const DispatchTargetPair& t) {
149 return t.second.get() == func;
150 }),
151 targets_.end());
152}
153
154std::string PythonAPIDispatcher::DebugString() const {
155 DCheckPyGilState();
156 std::string out = absl::StrCat("<Dispatch(", api_name_, "): ");
157
158 const char* sep = "";
159 for (const auto& target : targets_) {
160 Safe_PyObjectPtr target_str(PyObject_Str(target.second.get()));
161 absl::StrAppend(&out, sep, target.first.DebugString(), " -> ",
162 target_str ? PyUnicode_AsUTF8(target_str.get()) : "?");
163 sep = ", ";
164 }
165 return out;
166}
167
168PySignatureChecker::PySignatureChecker(
169 std::vector<ParamChecker> parameter_checkers)
170 : positional_parameter_checkers_(std::move(parameter_checkers)) {
171 // Check less expensive parameters first.
172 std::sort(positional_parameter_checkers_.begin(),
173 positional_parameter_checkers_.end(),
174 [](ParamChecker a, ParamChecker b) {
175 return a.second->cost() < b.second->cost();
176 });
177}
178
179bool PySignatureChecker::CheckCanonicalizedArgs(
180 absl::Span<PyObject*> canon_args) const {
181 bool matched_dispatchable_type = false;
182 for (auto& c : positional_parameter_checkers_) {
183 int index = c.first;
184 auto& param_checker = c.second;
185 if (index >= canon_args.size()) {
186 return false;
187 }
188 switch (param_checker->Check(canon_args[index])) {
189 case PyTypeChecker::MatchType::NO_MATCH:
190 return false;
191 case PyTypeChecker::MatchType::MATCH_DISPATCHABLE:
192 matched_dispatchable_type = true;
193 break;
194 case PyTypeChecker::MatchType::MATCH:
195 break;
196 }
197 }
198 return matched_dispatchable_type;
199}
200
201std::string PySignatureChecker::DebugString() const {
202 return absl::StrJoin(positional_parameter_checkers_, ", ",
203 [](std::string* out, ParamChecker p) {
204 absl::StrAppend(out, "args[", p.first,
205 "]:", p.second->DebugString());
206 });
207}
208
209PyInstanceChecker::PyInstanceChecker(const std::vector<PyObject*>& py_classes) {
210 DCheckPyGilState();
211 py_classes_.reserve(py_classes.size());
212 for (PyObject* py_class : py_classes) {
213 py_classes_.emplace_back(py_class);
214 Py_INCREF(py_class);
215 }
216}
217
218PyInstanceChecker::~PyInstanceChecker() {
219 DCheckPyGilState();
220 for (const auto& pair : py_class_cache_) {
221 Py_DECREF(pair.first);
222 }
223}
224
225PyTypeChecker::MatchType PyInstanceChecker::Check(PyObject* value) {
226 DCheckPyGilState();
227 auto* type = Py_TYPE(value);
228 auto it = py_class_cache_.find(type);
229 if (it != py_class_cache_.end()) {
230 return it->second;
231 }
232
233 MatchType result = MatchType::NO_MATCH;
234 for (const auto& py_class : py_classes_) {
235 int is_instance = PyObject_IsInstance(value, py_class.get());
236 if (is_instance == 1) {
237 if (IsRegisteredDispatchableType(py_class.get())) {
238 result = MatchType::MATCH_DISPATCHABLE;
239 break;
240 } else {
241 result = MatchType::MATCH;
242 }
243 } else if (is_instance < 0) {
244 PyErr_Clear();
245 return MatchType::NO_MATCH;
246 }
247 }
248
249 if (py_class_cache_.size() < kMaxItemsInCache) {
250 Py_INCREF(type);
251 auto insert_result = py_class_cache_.insert({type, result});
252 if (!insert_result.second) {
253 Py_DECREF(type); // Result was added by a different thread.
254 }
255 }
256 return result;
257}
258
259int PyInstanceChecker::cost() const { return py_classes_.size(); }
260
261std::string PyInstanceChecker::DebugString() const {
262 DCheckPyGilState();
263 std::vector<const char*> type_names;
264 for (const auto& py_class : py_classes_) {
265 type_names.push_back(
266 reinterpret_cast<PyTypeObject*>(py_class.get())->tp_name);
267 }
268 return absl::StrJoin(
269 py_classes_, ", ", [](std::string* out, const Safe_PyObjectPtr& v) {
270 out->append(reinterpret_cast<PyTypeObject*>(v.get())->tp_name);
271 });
272}
273
274PyTypeChecker::MatchType PyListChecker::Check(PyObject* value) {
275 DCheckPyGilState();
276 if (!(PyList_Check(value) || PyTuple_Check(value))) {
277 return MatchType::NO_MATCH;
278 }
279
280 Safe_PyObjectPtr seq(PySequence_Fast(value, ""));
281 if (!seq) {
282 PyErr_Clear();
283 return MatchType::NO_MATCH; // value is not a sequence.
284 }
285
286 MatchType result = MatchType::MATCH;
287 for (int i = 0; i < PySequence_Fast_GET_SIZE(seq.get()); ++i) {
288 switch (element_type_->Check(PySequence_Fast_GET_ITEM(seq.get(), i))) {
289 case MatchType::NO_MATCH:
290 return MatchType::NO_MATCH;
291 case MatchType::MATCH_DISPATCHABLE:
292 result = MatchType::MATCH_DISPATCHABLE;
293 break;
294 case MatchType::MATCH:
295 break;
296 }
297 }
298 return result;
299}
300
301int PyListChecker::cost() const { return 10 * element_type_->cost(); }
302
303std::string PyListChecker::DebugString() const {
304 return absl::StrCat("List[", element_type_->DebugString(), "]");
305}
306
307PyTypeChecker::MatchType PyUnionChecker::Check(PyObject* value) {
308 MatchType result = MatchType::NO_MATCH;
309 for (auto& type_option : options_) {
310 switch (type_option->Check(value)) {
311 case MatchType::MATCH:
312 result = MatchType::MATCH;
313 break;
314 case MatchType::MATCH_DISPATCHABLE:
315 return MatchType::MATCH_DISPATCHABLE;
316 case MatchType::NO_MATCH:
317 break;
318 }
319 }
320 return result;
321}
322
323int PyUnionChecker::cost() const {
324 int cost = 1;
325 for (auto& type_option : options_) {
326 cost += type_option->cost();
327 }
328 return cost;
329}
330
331std::string PyUnionChecker::DebugString() const {
332 return absl::StrCat("Union[",
333 absl::StrJoin(options_, ", ",
334 [](std::string* out, PyTypeChecker_ptr v) {
335 out->append(v->DebugString());
336 }),
337 "]");
338}
339
340} // namespace py_dispatch
341} // namespace tensorflow
342