1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/framework/python_api_dispatcher.h" |
17 | |
18 | #include <set> |
19 | |
20 | #include "absl/strings/str_join.h" |
21 | #include "tensorflow/core/platform/logging.h" |
22 | #include "tensorflow/core/platform/macros.h" |
23 | #include "tensorflow/python/lib/core/py_util.h" |
24 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
25 | #include "tensorflow/python/util/util.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace py_dispatch { |
29 | |
30 | namespace { |
31 | |
32 | std::vector<Safe_PyObjectPtr>& GetRegisteredDispatchableTypes() { |
33 | static std::vector<Safe_PyObjectPtr>* registered_dispatchable_types = |
34 | new std::vector<Safe_PyObjectPtr>(); |
35 | if (registered_dispatchable_types->empty()) { |
36 | static PyObject* composite_tensor = |
37 | swig::GetRegisteredPyObject("CompositeTensor" ); |
38 | Py_INCREF(composite_tensor); |
39 | registered_dispatchable_types->push_back( |
40 | Safe_PyObjectPtr(composite_tensor)); |
41 | } |
42 | return *registered_dispatchable_types; |
43 | } |
44 | |
45 | // Returns true if `py_class` is a registered dispatchable type. |
46 | bool IsRegisteredDispatchableType(PyObject* py_class) { |
47 | DCheckPyGilState(); |
48 | for (const auto& registered_type : GetRegisteredDispatchableTypes()) { |
49 | int result = PyObject_IsSubclass(py_class, registered_type.get()); |
50 | if (result > 0) return true; |
51 | if (result < 0) PyErr_Clear(); |
52 | } |
53 | return false; |
54 | } |
55 | |
56 | // Raises an exception indicating that multiple dispatch targets matched. |
57 | Safe_PyObjectPtr RaiseDispatchConflictError(const std::string& api_name, |
58 | PyObject* selected, |
59 | PyObject* target) { |
60 | Safe_PyObjectPtr s1(PyObject_Str(selected)); |
61 | Safe_PyObjectPtr s2(PyObject_Str(target)); |
62 | PyErr_SetString(PyExc_ValueError, |
63 | absl::StrCat("Multiple dispatch targets that were " |
64 | "registered with tf.dispatch_for (" , |
65 | s1 ? PyUnicode_AsUTF8(s1.get()) : "?" , " and " , |
66 | s2 ? PyUnicode_AsUTF8(s2.get()) : "?" , |
67 | ") match the arguments to " , api_name) |
68 | .c_str()); |
69 | return nullptr; |
70 | } |
71 | |
72 | } // namespace |
73 | |
74 | bool RegisterDispatchableType(PyObject* py_class) { |
75 | DCheckPyGilState(); |
76 | if (!PyType_Check(py_class)) { |
77 | PyErr_SetString( |
78 | PyExc_ValueError, |
79 | absl::StrCat("Expected a type object; got object with type " , |
80 | py_class->ob_type->tp_name) |
81 | .c_str()); |
82 | return false; |
83 | } |
84 | if (IsRegisteredDispatchableType(py_class)) { |
85 | Safe_PyObjectPtr s(PyObject_Str(py_class)); |
86 | PyErr_SetString(PyExc_ValueError, |
87 | absl::StrCat("Type " , s ? PyUnicode_AsUTF8(s.get()) : "?" , |
88 | " (or one of its bases clases) has " |
89 | "already been registered" ) |
90 | .c_str()); |
91 | return false; |
92 | } |
93 | Py_INCREF(py_class); |
94 | GetRegisteredDispatchableTypes().push_back(Safe_PyObjectPtr(py_class)); |
95 | return true; |
96 | } |
97 | |
// Creates a dispatcher for the API named `api_name`. `arg_names` and
// `defaults` are forwarded to the argument canonicalizer, which Dispatch()
// uses to turn (args, kwargs) into one positional argument list.
//
// NOTE(review): this initializer list relies on the member declaration order
// in the header, which is not visible here. `canonicalizer_` must be
// initialized before `canonicalized_args_storage_`, because the storage is
// sized from canonicalizer_.GetArgSize(). The storage must in turn precede
// `canonicalized_args_span_`, which is constructed from it.
PythonAPIDispatcher::PythonAPIDispatcher(const std::string& api_name,
                                         absl::Span<const char*> arg_names,
                                         absl::Span<PyObject*> defaults)
    : api_name_(api_name),
      canonicalizer_(arg_names, defaults),
      canonicalized_args_storage_(canonicalizer_.GetArgSize()),
      canonicalized_args_span_(canonicalized_args_storage_) {}
105 | |
106 | void PythonAPIDispatcher::Register(PySignatureChecker signature_checker, |
107 | PyObject* dispatch_target) { |
108 | DCheckPyGilState(); |
109 | Py_INCREF(dispatch_target); |
110 | targets_.emplace_back(std::move(signature_checker), |
111 | Safe_PyObjectPtr(dispatch_target)); |
112 | } |
113 | |
114 | Safe_PyObjectPtr PythonAPIDispatcher::Dispatch(PyObject* args, |
115 | PyObject* kwargs) { |
116 | DCheckPyGilState(); |
117 | if (kwargs == Py_None) { |
118 | kwargs = nullptr; |
119 | } |
120 | // Canonicalize args (so we don't need to deal with kwargs). |
121 | if (!canonicalizer_.Canonicalize(args, kwargs, canonicalized_args_span_)) { |
122 | return nullptr; |
123 | } |
124 | |
125 | PyObject* selected = nullptr; |
126 | for (auto& target : targets_) { |
127 | if (target.first.CheckCanonicalizedArgs(canonicalized_args_span_)) { |
128 | if (selected && selected != target.second.get()) { |
129 | return RaiseDispatchConflictError(api_name_, selected, |
130 | target.second.get()); |
131 | } |
132 | selected = target.second.get(); |
133 | } |
134 | } |
135 | if (selected) { |
136 | return Safe_PyObjectPtr(PyObject_Call(selected, args, kwargs)); |
137 | } else { |
138 | Py_INCREF(Py_NotImplemented); |
139 | return Safe_PyObjectPtr(Py_NotImplemented); |
140 | } |
141 | } |
142 | |
143 | // TODO(b/194903203) Raise an error if `func` is not registered. |
144 | void PythonAPIDispatcher::Unregister(PyObject* func) { |
145 | DCheckPyGilState(); |
146 | using DispatchTargetPair = std::pair<PySignatureChecker, Safe_PyObjectPtr>; |
147 | targets_.erase(std::remove_if(targets_.begin(), targets_.end(), |
148 | [func](const DispatchTargetPair& t) { |
149 | return t.second.get() == func; |
150 | }), |
151 | targets_.end()); |
152 | } |
153 | |
154 | std::string PythonAPIDispatcher::DebugString() const { |
155 | DCheckPyGilState(); |
156 | std::string out = absl::StrCat("<Dispatch(" , api_name_, "): " ); |
157 | |
158 | const char* sep = "" ; |
159 | for (const auto& target : targets_) { |
160 | Safe_PyObjectPtr target_str(PyObject_Str(target.second.get())); |
161 | absl::StrAppend(&out, sep, target.first.DebugString(), " -> " , |
162 | target_str ? PyUnicode_AsUTF8(target_str.get()) : "?" ); |
163 | sep = ", " ; |
164 | } |
165 | return out; |
166 | } |
167 | |
168 | PySignatureChecker::PySignatureChecker( |
169 | std::vector<ParamChecker> parameter_checkers) |
170 | : positional_parameter_checkers_(std::move(parameter_checkers)) { |
171 | // Check less expensive parameters first. |
172 | std::sort(positional_parameter_checkers_.begin(), |
173 | positional_parameter_checkers_.end(), |
174 | [](ParamChecker a, ParamChecker b) { |
175 | return a.second->cost() < b.second->cost(); |
176 | }); |
177 | } |
178 | |
179 | bool PySignatureChecker::CheckCanonicalizedArgs( |
180 | absl::Span<PyObject*> canon_args) const { |
181 | bool matched_dispatchable_type = false; |
182 | for (auto& c : positional_parameter_checkers_) { |
183 | int index = c.first; |
184 | auto& param_checker = c.second; |
185 | if (index >= canon_args.size()) { |
186 | return false; |
187 | } |
188 | switch (param_checker->Check(canon_args[index])) { |
189 | case PyTypeChecker::MatchType::NO_MATCH: |
190 | return false; |
191 | case PyTypeChecker::MatchType::MATCH_DISPATCHABLE: |
192 | matched_dispatchable_type = true; |
193 | break; |
194 | case PyTypeChecker::MatchType::MATCH: |
195 | break; |
196 | } |
197 | } |
198 | return matched_dispatchable_type; |
199 | } |
200 | |
201 | std::string PySignatureChecker::DebugString() const { |
202 | return absl::StrJoin(positional_parameter_checkers_, ", " , |
203 | [](std::string* out, ParamChecker p) { |
204 | absl::StrAppend(out, "args[" , p.first, |
205 | "]:" , p.second->DebugString()); |
206 | }); |
207 | } |
208 | |
209 | PyInstanceChecker::PyInstanceChecker(const std::vector<PyObject*>& py_classes) { |
210 | DCheckPyGilState(); |
211 | py_classes_.reserve(py_classes.size()); |
212 | for (PyObject* py_class : py_classes) { |
213 | py_classes_.emplace_back(py_class); |
214 | Py_INCREF(py_class); |
215 | } |
216 | } |
217 | |
218 | PyInstanceChecker::~PyInstanceChecker() { |
219 | DCheckPyGilState(); |
220 | for (const auto& pair : py_class_cache_) { |
221 | Py_DECREF(pair.first); |
222 | } |
223 | } |
224 | |
225 | PyTypeChecker::MatchType PyInstanceChecker::Check(PyObject* value) { |
226 | DCheckPyGilState(); |
227 | auto* type = Py_TYPE(value); |
228 | auto it = py_class_cache_.find(type); |
229 | if (it != py_class_cache_.end()) { |
230 | return it->second; |
231 | } |
232 | |
233 | MatchType result = MatchType::NO_MATCH; |
234 | for (const auto& py_class : py_classes_) { |
235 | int is_instance = PyObject_IsInstance(value, py_class.get()); |
236 | if (is_instance == 1) { |
237 | if (IsRegisteredDispatchableType(py_class.get())) { |
238 | result = MatchType::MATCH_DISPATCHABLE; |
239 | break; |
240 | } else { |
241 | result = MatchType::MATCH; |
242 | } |
243 | } else if (is_instance < 0) { |
244 | PyErr_Clear(); |
245 | return MatchType::NO_MATCH; |
246 | } |
247 | } |
248 | |
249 | if (py_class_cache_.size() < kMaxItemsInCache) { |
250 | Py_INCREF(type); |
251 | auto insert_result = py_class_cache_.insert({type, result}); |
252 | if (!insert_result.second) { |
253 | Py_DECREF(type); // Result was added by a different thread. |
254 | } |
255 | } |
256 | return result; |
257 | } |
258 | |
259 | int PyInstanceChecker::cost() const { return py_classes_.size(); } |
260 | |
261 | std::string PyInstanceChecker::DebugString() const { |
262 | DCheckPyGilState(); |
263 | std::vector<const char*> type_names; |
264 | for (const auto& py_class : py_classes_) { |
265 | type_names.push_back( |
266 | reinterpret_cast<PyTypeObject*>(py_class.get())->tp_name); |
267 | } |
268 | return absl::StrJoin( |
269 | py_classes_, ", " , [](std::string* out, const Safe_PyObjectPtr& v) { |
270 | out->append(reinterpret_cast<PyTypeObject*>(v.get())->tp_name); |
271 | }); |
272 | } |
273 | |
274 | PyTypeChecker::MatchType PyListChecker::Check(PyObject* value) { |
275 | DCheckPyGilState(); |
276 | if (!(PyList_Check(value) || PyTuple_Check(value))) { |
277 | return MatchType::NO_MATCH; |
278 | } |
279 | |
280 | Safe_PyObjectPtr seq(PySequence_Fast(value, "" )); |
281 | if (!seq) { |
282 | PyErr_Clear(); |
283 | return MatchType::NO_MATCH; // value is not a sequence. |
284 | } |
285 | |
286 | MatchType result = MatchType::MATCH; |
287 | for (int i = 0; i < PySequence_Fast_GET_SIZE(seq.get()); ++i) { |
288 | switch (element_type_->Check(PySequence_Fast_GET_ITEM(seq.get(), i))) { |
289 | case MatchType::NO_MATCH: |
290 | return MatchType::NO_MATCH; |
291 | case MatchType::MATCH_DISPATCHABLE: |
292 | result = MatchType::MATCH_DISPATCHABLE; |
293 | break; |
294 | case MatchType::MATCH: |
295 | break; |
296 | } |
297 | } |
298 | return result; |
299 | } |
300 | |
301 | int PyListChecker::cost() const { return 10 * element_type_->cost(); } |
302 | |
303 | std::string PyListChecker::DebugString() const { |
304 | return absl::StrCat("List[" , element_type_->DebugString(), "]" ); |
305 | } |
306 | |
307 | PyTypeChecker::MatchType PyUnionChecker::Check(PyObject* value) { |
308 | MatchType result = MatchType::NO_MATCH; |
309 | for (auto& type_option : options_) { |
310 | switch (type_option->Check(value)) { |
311 | case MatchType::MATCH: |
312 | result = MatchType::MATCH; |
313 | break; |
314 | case MatchType::MATCH_DISPATCHABLE: |
315 | return MatchType::MATCH_DISPATCHABLE; |
316 | case MatchType::NO_MATCH: |
317 | break; |
318 | } |
319 | } |
320 | return result; |
321 | } |
322 | |
323 | int PyUnionChecker::cost() const { |
324 | int cost = 1; |
325 | for (auto& type_option : options_) { |
326 | cost += type_option->cost(); |
327 | } |
328 | return cost; |
329 | } |
330 | |
331 | std::string PyUnionChecker::DebugString() const { |
332 | return absl::StrCat("Union[" , |
333 | absl::StrJoin(options_, ", " , |
334 | [](std::string* out, PyTypeChecker_ptr v) { |
335 | out->append(v->DebugString()); |
336 | }), |
337 | "]" ); |
338 | } |
339 | |
340 | } // namespace py_dispatch |
341 | } // namespace tensorflow |
342 | |