1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/python/util/util.h" |
16 | |
17 | #include <functional> |
18 | #include <memory> |
19 | #include <unordered_map> |
20 | #include <vector> |
21 | |
22 | #include "absl/memory/memory.h" |
23 | #include "tensorflow/core/lib/gtl/map_util.h" |
24 | #include "tensorflow/core/lib/strings/strcat.h" |
25 | #include "tensorflow/core/platform/logging.h" |
26 | #include "tensorflow/core/platform/mutex.h" |
27 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
28 | |
29 | namespace tensorflow { |
30 | namespace swig { |
31 | |
32 | namespace { |
33 | string PyObjectToString(PyObject* o); |
34 | } // namespace |
35 | |
36 | std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() { |
37 | static auto* m = new std::unordered_map<string, PyObject*>(); |
38 | return m; |
39 | } |
40 | |
41 | PyObject* GetRegisteredPyObject(const string& name) { |
42 | const auto* m = RegisteredPyObjectMap(); |
43 | auto it = m->find(name); |
44 | if (it == m->end()) { |
45 | PyErr_SetString(PyExc_TypeError, |
46 | tensorflow::strings::StrCat("No object with name " , name, |
47 | " has been registered." ) |
48 | .c_str()); |
49 | return nullptr; |
50 | } |
51 | return it->second; |
52 | } |
53 | |
54 | PyObject* RegisterType(PyObject* type_name, PyObject* type) { |
55 | if (!PyType_Check(type)) { |
56 | PyErr_SetString(PyExc_TypeError, |
57 | tensorflow::strings::StrCat("Expecting a type, got " , |
58 | Py_TYPE(type)->tp_name) |
59 | .c_str()); |
60 | return nullptr; |
61 | } |
62 | return RegisterPyObject(type_name, type); |
63 | } |
64 | |
65 | PyObject* RegisterPyObject(PyObject* name, PyObject* value) { |
66 | string key; |
67 | if (PyBytes_Check(name)) { |
68 | key = PyBytes_AsString(name); |
69 | #if PY_MAJOR_VERSION >= 3 |
70 | } else if (PyUnicode_Check(name)) { |
71 | key = PyUnicode_AsUTF8(name); |
72 | #endif |
73 | } else { |
74 | PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat( |
75 | "Expected name to be a str, got" , |
76 | PyObjectToString(name)) |
77 | .c_str()); |
78 | return nullptr; |
79 | } |
80 | |
81 | auto* m = RegisteredPyObjectMap(); |
82 | if (m->find(key) != m->end()) { |
83 | PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat( |
84 | "Value already registered for " , key) |
85 | .c_str()); |
86 | return nullptr; |
87 | } |
88 | |
89 | Py_INCREF(value); |
90 | m->emplace(key, value); |
91 | |
92 | Py_RETURN_NONE; |
93 | } |
94 | |
95 | namespace { |
// Maximum number of types a CachedTypeCheck instance stores; once its cache
// holds this many entries, further types are still checked but their results
// are not cached.
const int kMaxItemsInCache = 1024;
97 | |
98 | bool IsString(PyObject* o) { |
99 | return PyBytes_Check(o) || |
100 | #if PY_MAJOR_VERSION < 3 |
101 | PyString_Check(o) || |
102 | #endif |
103 | PyUnicode_Check(o); |
104 | } |
105 | |
106 | // Equivalent to Python's 'o.__class__.__name__' |
107 | // Note that '__class__' attribute is set only in new-style classes. |
108 | // A lot of tensorflow code uses __class__ without checks, so it seems like |
109 | // we only support new-style classes. |
110 | StringPiece GetClassName(PyObject* o) { |
111 | // __class__ is equivalent to type() for new style classes. |
112 | // type() is equivalent to PyObject_Type() |
113 | // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type) |
114 | // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which |
115 | // we don't need here. |
116 | PyTypeObject* type = o->ob_type; |
117 | |
118 | // __name__ is the value of `tp_name` after the last '.' |
119 | // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name) |
120 | StringPiece name(type->tp_name); |
121 | size_t pos = name.rfind('.'); |
122 | if (pos != StringPiece::npos) { |
123 | name.remove_prefix(pos + 1); |
124 | } |
125 | return name; |
126 | } |
127 | |
128 | string PyObjectToString(PyObject* o) { |
129 | if (o == nullptr) { |
130 | return "<null object>" ; |
131 | } |
132 | PyObject* str = PyObject_Str(o); |
133 | if (str) { |
134 | #if PY_MAJOR_VERSION < 3 |
135 | string s(PyString_AS_STRING(str)); |
136 | #else |
137 | string s(PyUnicode_AsUTF8(str)); |
138 | #endif |
139 | Py_DECREF(str); |
140 | return tensorflow::strings::StrCat("type=" , GetClassName(o), " str=" , s); |
141 | } else { |
142 | return "<failed to execute str() on object>" ; |
143 | } |
144 | } |
145 | |
146 | class CachedTypeCheck { |
147 | public: |
148 | explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate) |
149 | : ternary_predicate_(std::move(ternary_predicate)) {} |
150 | |
151 | ~CachedTypeCheck() { |
152 | mutex_lock l(type_to_sequence_map_mu_); |
153 | for (const auto& pair : type_to_sequence_map_) { |
154 | Py_DECREF(pair.first); |
155 | } |
156 | } |
157 | |
158 | // Caches successful executions of the one-argument (PyObject*) callable |
159 | // "ternary_predicate" based on the type of "o". -1 from the callable |
160 | // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type |
161 | // does not match the predicate, and 1 indicates that it does. Used to avoid |
162 | // calling back into Python for expensive isinstance checks. |
163 | int CachedLookup(PyObject* o) { |
164 | // Try not to return to Python - see if the type has already been seen |
165 | // before. |
166 | |
167 | auto* type = Py_TYPE(o); |
168 | |
169 | { |
170 | tf_shared_lock l(type_to_sequence_map_mu_); |
171 | auto it = type_to_sequence_map_.find(type); |
172 | if (it != type_to_sequence_map_.end()) { |
173 | return it->second; |
174 | } |
175 | } |
176 | |
177 | int check_result = ternary_predicate_(o); |
178 | |
179 | if (check_result == -1) { |
180 | return -1; // Type check error, not cached. |
181 | } |
182 | |
183 | // NOTE: This is never decref'd as long as the object lives, which is likely |
184 | // forever, but we don't want the type to get deleted as long as it is in |
185 | // the map. This should not be too much of a leak, as there should only be a |
186 | // relatively small number of types in the map, and an even smaller number |
187 | // that are eligible for decref. As a precaution, we limit the size of the |
188 | // map to 1024. |
189 | { |
190 | mutex_lock l(type_to_sequence_map_mu_); |
191 | if (type_to_sequence_map_.size() < kMaxItemsInCache) { |
192 | Py_INCREF(type); |
193 | auto insert_result = type_to_sequence_map_.insert({type, check_result}); |
194 | if (!insert_result.second) { |
195 | // The type was added to the cache by a concurrent thread after we |
196 | // looked it up above. |
197 | Py_DECREF(type); |
198 | } |
199 | } |
200 | } |
201 | |
202 | return check_result; |
203 | } |
204 | |
205 | private: |
206 | std::function<int(PyObject*)> ternary_predicate_; |
207 | mutex type_to_sequence_map_mu_; |
208 | std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_ |
209 | TF_GUARDED_BY(type_to_sequence_map_mu_); |
210 | }; |
211 | |
212 | // Returns 1 if 'obj' is an instance of 'type_name' |
213 | // Returns 0 otherwise. |
214 | // Returns -1 if an error occurred (e.g., if 'type_name' is not registered.) |
215 | int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) { |
216 | PyObject* type_obj = GetRegisteredPyObject(type_name); |
217 | if (TF_PREDICT_FALSE(type_obj == nullptr)) { |
218 | PyErr_SetString(PyExc_RuntimeError, |
219 | tensorflow::strings::StrCat( |
220 | type_name, |
221 | " type has not been set. " |
222 | "Please register the type with the identifier \"" , |
223 | type_name, "\" using RegisterType." ) |
224 | .c_str()); |
225 | return -1; |
226 | } |
227 | return PyObject_IsInstance(obj, type_obj); |
228 | } |
229 | |
230 | // Returns 1 if `o` is considered a mapping for the purposes of Flatten(). |
231 | // Returns 0 otherwise. |
232 | // Returns -1 if an error occurred. |
233 | int IsMappingHelper(PyObject* o) { |
234 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
235 | return IsInstanceOfRegisteredType(to_check, "Mapping" ); |
236 | }); |
237 | if (PyDict_Check(o)) return true; |
238 | return check_cache->CachedLookup(o); |
239 | } |
240 | |
241 | // Returns 1 if `o` is considered a mutable mapping for the purposes of |
242 | // Flatten(). Returns 0 otherwise. Returns -1 if an error occurred. |
243 | int IsMutableMappingHelper(PyObject* o) { |
244 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
245 | return IsInstanceOfRegisteredType(to_check, "MutableMapping" ); |
246 | }); |
247 | if (PyDict_Check(o)) return true; |
248 | return check_cache->CachedLookup(o); |
249 | } |
250 | |
251 | // Returns 1 if `o` is considered a mapping view for the purposes of Flatten(). |
252 | // Returns 0 otherwise. |
253 | // Returns -1 if an error occurred. |
254 | int IsMappingViewHelper(PyObject* o) { |
255 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
256 | return IsInstanceOfRegisteredType(to_check, "MappingView" ); |
257 | }); |
258 | return check_cache->CachedLookup(o); |
259 | } |
260 | |
261 | // Returns 1 if `o` is considered an object proxy |
262 | // Returns 0 otherwise. |
263 | // Returns -1 if an error occurred. |
264 | int IsObjectProxy(PyObject* o) { |
265 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
266 | return IsInstanceOfRegisteredType(to_check, "ObjectProxy" ); |
267 | }); |
268 | return check_cache->CachedLookup(o); |
269 | } |
270 | |
271 | // Returns 1 if `o` is an instance of attrs-decorated class. |
272 | // Returns 0 otherwise. |
273 | int IsAttrsHelper(PyObject* o) { |
274 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
275 | Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__" )); |
276 | if (cls) { |
277 | return PyObject_HasAttrString(cls.get(), "__attrs_attrs__" ); |
278 | } |
279 | |
280 | // PyObject_GetAttrString returns null on error |
281 | PyErr_Clear(); |
282 | return 0; |
283 | }); |
284 | return check_cache->CachedLookup(o); |
285 | } |
286 | |
287 | // Returns 1 if `o` is an object of type IndexedSlices. |
288 | // Returns 0 otherwise. |
289 | // Returns -1 if an error occurred. |
290 | int IsIndexedSlicesHelper(PyObject* o) { |
291 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
292 | return IsInstanceOfRegisteredType(to_check, "IndexedSlices" ); |
293 | }); |
294 | return check_cache->CachedLookup(o); |
295 | } |
296 | |
297 | // Returns 1 if `o` is a Tensor. |
298 | // Returns 0 otherwise. |
299 | // Returns -1 if an error occurred. |
300 | int IsTensorHelper(PyObject* o) { |
301 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
302 | return IsInstanceOfRegisteredType(to_check, "Tensor" ); |
303 | }); |
304 | return check_cache->CachedLookup(o); |
305 | } |
306 | |
307 | // Returns 1 if `o` is a TensorSpec. |
308 | // Returns 0 otherwise. |
309 | // Returns -1 if an error occurred. |
310 | int IsTensorSpecHelper(PyObject* o) { |
311 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
312 | return IsInstanceOfRegisteredType(to_check, "TensorSpec" ); |
313 | }); |
314 | return check_cache->CachedLookup(o); |
315 | } |
316 | |
317 | // Returns 1 if `o` is an EagerTensor. |
318 | // Returns 0 otherwise. |
319 | // Returns -1 if an error occurred. |
320 | int IsEagerTensorHelper(PyObject* o) { |
321 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
322 | return IsInstanceOfRegisteredType(to_check, "EagerTensor" ); |
323 | }); |
324 | return check_cache->CachedLookup(o); |
325 | } |
326 | |
327 | // Returns 1 if `o` is a ResourceVariable. |
328 | // Returns 0 otherwise. |
329 | // Returns -1 if an error occurred. |
330 | int IsResourceVariableHelper(PyObject* o) { |
331 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
332 | return IsInstanceOfRegisteredType(to_check, "ResourceVariable" ); |
333 | }); |
334 | return check_cache->CachedLookup(o); |
335 | } |
336 | |
337 | // Returns 1 if `o` is a OwnedIterator. |
338 | // Returns 0 otherwise. |
339 | // Returns -1 if an error occurred. |
340 | int IsOwnedIteratorHelper(PyObject* o) { |
341 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
342 | return IsInstanceOfRegisteredType(to_check, "OwnedIterator" ); |
343 | }); |
344 | return check_cache->CachedLookup(o); |
345 | } |
346 | |
347 | // Returns 1 if `o` is a ResourceVariable. |
348 | // Returns 0 otherwise. |
349 | // Returns -1 if an error occurred. |
350 | int IsVariableHelper(PyObject* o) { |
351 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
352 | return IsInstanceOfRegisteredType(to_check, "Variable" ); |
353 | }); |
354 | return check_cache->CachedLookup(o); |
355 | } |
356 | |
357 | // Returns 1 if `o` is considered a sequence for the purposes of Flatten(). |
358 | // Returns 0 otherwise. |
359 | // Returns -1 if an error occurred. |
360 | int IsNestedHelper(PyObject* o) { |
361 | // We treat dicts and other mappings as special cases of sequences. |
362 | if (IsMappingHelper(o)) return true; |
363 | if (IsMappingViewHelper(o)) return true; |
364 | if (IsAttrsHelper(o)) return true; |
365 | |
366 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
367 | int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence" ); |
368 | |
369 | // Don't cache a failed is_instance check. |
370 | if (is_instance == -1) return -1; |
371 | |
372 | return static_cast<int>(is_instance != 0 && !IsString(to_check)); |
373 | }); |
374 | return check_cache->CachedLookup(o); |
375 | } |
376 | |
377 | // Returns 1 if `o`'s class has a `__tf_dispatch__` attribute. |
378 | // Returns 0 otherwise. |
379 | int IsDispatchableHelper(PyObject* o) { |
380 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
381 | return PyObject_HasAttrString( |
382 | reinterpret_cast<PyObject*>(to_check->ob_type), "__tf_dispatch__" ); |
383 | }); |
384 | return check_cache->CachedLookup(o); |
385 | } |
386 | |
387 | // ValueIterator interface |
388 | class ValueIterator { |
389 | public: |
390 | virtual ~ValueIterator() {} |
391 | virtual Safe_PyObjectPtr next() = 0; |
392 | |
393 | bool valid() const { return is_valid_; } |
394 | |
395 | protected: |
396 | void invalidate() { is_valid_ = false; } |
397 | |
398 | private: |
399 | bool is_valid_ = true; |
400 | }; |
401 | |
402 | using ValueIteratorPtr = std::unique_ptr<ValueIterator>; |
403 | |
404 | // Iterate through dictionaries in a deterministic order by sorting the |
405 | // keys. Notice this means that we ignore the original order of |
406 | // `OrderedDict` instances. This is intentional, to avoid potential |
407 | // bugs caused by mixing ordered and plain dicts (e.g., flattening |
408 | // a dict but using a corresponding `OrderedDict` to pack it back). |
409 | class DictValueIterator : public ValueIterator { |
410 | public: |
411 | explicit DictValueIterator(PyObject* dict) |
412 | : dict_(dict), keys_(PyDict_Keys(dict)) { |
413 | if (PyList_Sort(keys_.get()) == -1) { |
414 | invalidate(); |
415 | } else { |
416 | iter_.reset(PyObject_GetIter(keys_.get())); |
417 | } |
418 | } |
419 | |
420 | Safe_PyObjectPtr next() override { |
421 | Safe_PyObjectPtr result; |
422 | Safe_PyObjectPtr key(PyIter_Next(iter_.get())); |
423 | if (key) { |
424 | // PyDict_GetItem returns a borrowed reference. |
425 | PyObject* elem = PyDict_GetItem(dict_, key.get()); |
426 | if (elem) { |
427 | Py_INCREF(elem); |
428 | result.reset(elem); |
429 | } else { |
430 | PyErr_SetString(PyExc_RuntimeError, |
431 | "Dictionary was modified during iteration over it" ); |
432 | } |
433 | } |
434 | return result; |
435 | } |
436 | |
437 | private: |
438 | PyObject* dict_; |
439 | Safe_PyObjectPtr keys_; |
440 | Safe_PyObjectPtr iter_; |
441 | }; |
442 | |
443 | // Iterate over mapping objects by sorting the keys first |
444 | class MappingValueIterator : public ValueIterator { |
445 | public: |
446 | explicit MappingValueIterator(PyObject* mapping) |
447 | : mapping_(mapping), keys_(MappingKeys(mapping)) { |
448 | if (!keys_ || PyList_Sort(keys_.get()) == -1) { |
449 | invalidate(); |
450 | } else { |
451 | iter_.reset(PyObject_GetIter(keys_.get())); |
452 | } |
453 | } |
454 | |
455 | Safe_PyObjectPtr next() override { |
456 | Safe_PyObjectPtr result; |
457 | Safe_PyObjectPtr key(PyIter_Next(iter_.get())); |
458 | if (key) { |
459 | // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference. |
460 | PyObject* elem = PyObject_GetItem(mapping_, key.get()); |
461 | if (elem) { |
462 | result.reset(elem); |
463 | } else { |
464 | PyErr_SetString(PyExc_RuntimeError, |
465 | "Mapping was modified during iteration over it" ); |
466 | } |
467 | } |
468 | return result; |
469 | } |
470 | |
471 | private: |
472 | PyObject* mapping_; |
473 | Safe_PyObjectPtr keys_; |
474 | Safe_PyObjectPtr iter_; |
475 | }; |
476 | |
477 | // Iterate over a sequence, by index. |
478 | class SequenceValueIterator : public ValueIterator { |
479 | public: |
480 | explicit SequenceValueIterator(PyObject* iterable) |
481 | : seq_(PySequence_Fast(iterable, "" )), |
482 | size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0), |
483 | index_(0) {} |
484 | |
485 | Safe_PyObjectPtr next() override { |
486 | Safe_PyObjectPtr result; |
487 | if (index_ < size_) { |
488 | // PySequence_Fast_GET_ITEM returns a borrowed reference. |
489 | PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_); |
490 | ++index_; |
491 | if (elem) { |
492 | Py_INCREF(elem); |
493 | result.reset(elem); |
494 | } |
495 | } |
496 | |
497 | return result; |
498 | } |
499 | |
500 | private: |
501 | Safe_PyObjectPtr seq_; |
502 | const Py_ssize_t size_; |
503 | Py_ssize_t index_; |
504 | }; |
505 | |
506 | // Iterator that just returns a single python object. |
507 | class SingleValueIterator : public ValueIterator { |
508 | public: |
509 | explicit SingleValueIterator(PyObject* x) : x_(x) { Py_INCREF(x); } |
510 | |
511 | Safe_PyObjectPtr next() override { return std::move(x_); } |
512 | |
513 | private: |
514 | Safe_PyObjectPtr x_; |
515 | }; |
516 | |
517 | // Returns nullptr (to raise an exception) when next() is called. Caller |
518 | // should have already called PyErr_SetString. |
519 | class ErrorValueIterator : public ValueIterator { |
520 | public: |
521 | ErrorValueIterator() {} |
522 | Safe_PyObjectPtr next() override { return nullptr; } |
523 | }; |
524 | |
525 | class AttrsValueIterator : public ValueIterator { |
526 | public: |
527 | explicit AttrsValueIterator(PyObject* nested) : nested_(nested) { |
528 | Py_INCREF(nested); |
529 | cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__" )); |
530 | if (cls_) { |
531 | attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__" )); |
532 | if (attrs_) { |
533 | iter_.reset(PyObject_GetIter(attrs_.get())); |
534 | } |
535 | } |
536 | if (!iter_ || PyErr_Occurred()) invalidate(); |
537 | } |
538 | |
539 | Safe_PyObjectPtr next() override { |
540 | Safe_PyObjectPtr result; |
541 | Safe_PyObjectPtr item(PyIter_Next(iter_.get())); |
542 | if (item) { |
543 | Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name" )); |
544 | result.reset(PyObject_GetAttr(nested_.get(), name.get())); |
545 | } |
546 | |
547 | return result; |
548 | } |
549 | |
550 | private: |
551 | Safe_PyObjectPtr nested_; |
552 | Safe_PyObjectPtr cls_; |
553 | Safe_PyObjectPtr attrs_; |
554 | Safe_PyObjectPtr iter_; |
555 | }; |
556 | |
557 | bool IsSparseTensorValueType(PyObject* o) { |
558 | PyObject* sparse_tensor_value_type = |
559 | GetRegisteredPyObject("SparseTensorValue" ); |
560 | if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) { |
561 | return false; |
562 | } |
563 | |
564 | return PyObject_TypeCheck( |
565 | o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1; |
566 | } |
567 | |
568 | // Returns 1 if `o` is an instance of CompositeTensor. |
569 | // Returns 0 otherwise. |
570 | // Returns -1 if an error occurred. |
571 | bool IsCompositeTensorHelper(PyObject* o) { |
572 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
573 | // TODO(b/246438937): Remove the ResourceVariable test. |
574 | return IsInstanceOfRegisteredType(to_check, "CompositeTensor" ) && |
575 | !IsResourceVariable(to_check); |
576 | }); |
577 | return check_cache->CachedLookup(o); |
578 | } |
579 | |
580 | // Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or |
581 | // VariableSpec. |
582 | // Returns 0 otherwise. |
583 | // Returns -1 if an error occurred. |
584 | bool IsTypeSpecHelper(PyObject* o) { |
585 | static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { |
586 | int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec" ); |
587 | // TODO(b/246438937): Remove the VariableSpec special case. |
588 | int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec" ) || |
589 | IsInstanceOfRegisteredType(to_check, "VariableSpec" )); |
590 | if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1; |
591 | return static_cast<int>(is_type_spec && !is_dense_spec); |
592 | }); |
593 | return check_cache->CachedLookup(o); |
594 | } |
595 | |
596 | // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or |
597 | // (non-TensorSpec and non-VariableSpec) TypeSpec. |
598 | // Returns 0 otherwise. |
599 | // Returns -1 if an error occurred. |
600 | int IsNestedOrCompositeHelper(PyObject* o) { |
601 | int is_nested = IsNestedHelper(o); |
602 | int is_composite = IsCompositeTensorHelper(o); |
603 | int is_type_spec = IsTypeSpecHelper(o); |
604 | if ((is_nested == -1) || (is_composite == -1) || (is_type_spec == -1)) { |
605 | return -1; |
606 | } |
607 | return is_nested || is_composite || is_type_spec; |
608 | } |
609 | |
610 | int IsNestedForDataHelper(PyObject* o) { |
611 | return IsNestedHelper(o) == 1 && !PyList_Check(o) && |
612 | !IsSparseTensorValueType(o); |
613 | } |
614 | |
615 | ValueIteratorPtr GetValueIterator(PyObject* nested) { |
616 | if (PyDict_Check(nested)) { |
617 | return absl::make_unique<DictValueIterator>(nested); |
618 | } else if (IsMappingHelper(nested)) { |
619 | return absl::make_unique<MappingValueIterator>(nested); |
620 | } else if (IsAttrsHelper(nested)) { |
621 | return absl::make_unique<AttrsValueIterator>(nested); |
622 | } else { |
623 | return absl::make_unique<SequenceValueIterator>(nested); |
624 | } |
625 | } |
626 | |
627 | // Similar to above, just specialized for the functions in the data package. |
628 | ValueIteratorPtr GetValueIteratorForData(PyObject* nested) { |
629 | if (PyDict_Check(nested)) { |
630 | return absl::make_unique<DictValueIterator>(nested); |
631 | } else if (IsMappingHelper(nested)) { |
632 | return absl::make_unique<MappingValueIterator>(nested); |
633 | } else if (IsAttrsHelper(nested)) { |
634 | return absl::make_unique<AttrsValueIterator>(nested); |
635 | } else if (IsSparseTensorValueType(nested)) { |
636 | return absl::make_unique<SingleValueIterator>(nested); |
637 | } else { |
638 | return absl::make_unique<SequenceValueIterator>(nested); |
639 | } |
640 | } |
641 | |
642 | // Similar to GetValueIterator above, but expands CompositeTensor and TypeSpec. |
643 | ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) { |
644 | if (IsCompositeTensor(nested)) { |
645 | Safe_PyObjectPtr spec(PyObject_GetAttrString(nested, "_type_spec" )); |
646 | if (PyErr_Occurred() || !spec) { |
647 | return absl::make_unique<ErrorValueIterator>(); |
648 | } |
649 | |
650 | static char to_components[] = "_to_components" ; |
651 | static char argspec[] = "(O)" ; |
652 | Safe_PyObjectPtr components( |
653 | PyObject_CallMethod(spec.get(), to_components, argspec, nested)); |
654 | if (PyErr_Occurred() || components == nullptr) { |
655 | return absl::make_unique<ErrorValueIterator>(); |
656 | } |
657 | return absl::make_unique<SingleValueIterator>(components.get()); |
658 | } |
659 | |
660 | if (IsTypeSpec(nested)) { |
661 | Safe_PyObjectPtr specs(PyObject_GetAttrString(nested, "_component_specs" )); |
662 | if (PyErr_Occurred() || specs == nullptr) { |
663 | return absl::make_unique<ErrorValueIterator>(); |
664 | } |
665 | return absl::make_unique<SingleValueIterator>(specs.get()); |
666 | } |
667 | |
668 | return GetValueIterator(nested); |
669 | } |
670 | |
671 | bool FlattenHelper( |
672 | PyObject* nested, PyObject* list, |
673 | const std::function<int(PyObject*)>& is_nested_helper, |
674 | const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) { |
675 | // if nested is not a sequence, append itself and exit |
676 | int is_nested = is_nested_helper(nested); |
677 | if (is_nested == -1) return false; |
678 | if (!is_nested) { |
679 | return PyList_Append(list, nested) != -1; |
680 | } |
681 | |
682 | ValueIteratorPtr iter = value_iterator_getter(nested); |
683 | if (!iter->valid()) return false; |
684 | |
685 | for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) { |
686 | if (Py_EnterRecursiveCall(" in flatten" )) { |
687 | return false; |
688 | } |
689 | const bool success = FlattenHelper(item.get(), list, is_nested_helper, |
690 | value_iterator_getter); |
691 | Py_LeaveRecursiveCall(); |
692 | if (!success) { |
693 | return false; |
694 | } |
695 | } |
696 | return true; |
697 | } |
698 | |
699 | // Sets error using keys of 'dict1' and 'dict2'. |
700 | // 'dict1' and 'dict2' are assumed to be Python dictionaries. |
701 | void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg, |
702 | bool* is_type_error) { |
703 | Safe_PyObjectPtr k1(MappingKeys(dict1)); |
704 | if (PyErr_Occurred() || k1.get() == nullptr) { |
705 | *error_msg = |
706 | ("The two dictionaries don't have the same set of keys. Failed to " |
707 | "fetch keys." ); |
708 | return; |
709 | } |
710 | Safe_PyObjectPtr k2(MappingKeys(dict2)); |
711 | if (PyErr_Occurred() || k2.get() == nullptr) { |
712 | *error_msg = |
713 | ("The two dictionaries don't have the same set of keys. Failed to " |
714 | "fetch keys." ); |
715 | return; |
716 | } |
717 | *is_type_error = false; |
718 | *error_msg = tensorflow::strings::StrCat( |
719 | "The two dictionaries don't have the same set of keys. " |
720 | "First structure has keys " , |
721 | PyObjectToString(k1.get()), ", while second structure has keys " , |
722 | PyObjectToString(k2.get())); |
723 | } |
724 | |
725 | // Returns true iff there were no "internal" errors. In other words, |
726 | // errors that has nothing to do with structure checking. |
727 | // If an "internal" error occurred, the appropriate Python error will be |
728 | // set and the caller can propage it directly to the user. |
729 | // |
730 | // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must |
731 | // be empty. |
732 | // Leaves `error_msg` empty if structures matched. Else, fills `error_msg` |
733 | // with appropriate error and sets `is_type_error` to true iff |
734 | // the error to be raised should be TypeError. |
735 | bool AssertSameStructureHelper( |
736 | PyObject* o1, PyObject* o2, bool check_types, string* error_msg, |
737 | bool* is_type_error, const std::function<int(PyObject*)>& is_nested_helper, |
738 | const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter, |
739 | bool check_composite_tensor_type_spec) { |
740 | DCHECK(error_msg); |
741 | DCHECK(is_type_error); |
742 | const bool is_nested1 = is_nested_helper(o1); |
743 | const bool is_nested2 = is_nested_helper(o2); |
744 | if (PyErr_Occurred()) return false; |
745 | if (is_nested1 != is_nested2) { |
746 | string seq_str = is_nested1 ? PyObjectToString(o1) : PyObjectToString(o2); |
747 | string non_seq_str = |
748 | is_nested1 ? PyObjectToString(o2) : PyObjectToString(o1); |
749 | *is_type_error = false; |
750 | *error_msg = tensorflow::strings::StrCat( |
751 | "Substructure \"" , seq_str, "\" is a sequence, while substructure \"" , |
752 | non_seq_str, "\" is not" ); |
753 | return true; |
754 | } |
755 | |
756 | // Got to objects that are considered non-sequences. Note that in tf.data |
757 | // use case lists and sparse_tensors are not considered sequences. So finished |
758 | // checking, structures are the same. |
759 | if (!is_nested1) return true; |
760 | |
761 | if (check_types) { |
762 | // Treat wrapped tuples as tuples. |
763 | tensorflow::Safe_PyObjectPtr o1_wrapped; |
764 | if (IsObjectProxy(o1)) { |
765 | o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__" )); |
766 | o1 = o1_wrapped.get(); |
767 | } |
768 | tensorflow::Safe_PyObjectPtr o2_wrapped; |
769 | if (IsObjectProxy(o2)) { |
770 | o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__" )); |
771 | o2 = o2_wrapped.get(); |
772 | } |
773 | |
774 | const PyTypeObject* type1 = o1->ob_type; |
775 | const PyTypeObject* type2 = o2->ob_type; |
776 | |
777 | // We treat two different namedtuples with identical name and fields |
778 | // as having the same type. |
779 | const PyObject* o1_tuple = IsNamedtuple(o1, false); |
780 | if (o1_tuple == nullptr) return false; |
781 | const PyObject* o2_tuple = IsNamedtuple(o2, false); |
782 | if (o2_tuple == nullptr) { |
783 | Py_DECREF(o1_tuple); |
784 | return false; |
785 | } |
786 | bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True; |
787 | Py_DECREF(o1_tuple); |
788 | Py_DECREF(o2_tuple); |
789 | |
790 | if (both_tuples) { |
791 | const PyObject* same_tuples = SameNamedtuples(o1, o2); |
792 | if (same_tuples == nullptr) return false; |
793 | bool not_same_tuples = same_tuples != Py_True; |
794 | Py_DECREF(same_tuples); |
795 | if (not_same_tuples) { |
796 | *is_type_error = true; |
797 | *error_msg = tensorflow::strings::StrCat( |
798 | "The two namedtuples don't have the same sequence type. " |
799 | "First structure " , |
800 | PyObjectToString(o1), " has type " , type1->tp_name, |
801 | ", while second structure " , PyObjectToString(o2), " has type " , |
802 | type2->tp_name); |
803 | return true; |
804 | } |
805 | } else if (type1 != type2 |
806 | /* If both sequences are list types, don't complain. This allows |
807 | one to be a list subclass (e.g. _ListWrapper used for |
808 | automatic dependency tracking.) */ |
809 | && !(PyList_Check(o1) && PyList_Check(o2)) |
810 | /* Two mapping types will also compare equal, making _DictWrapper |
811 | and dict compare equal. */ |
812 | && !(IsMappingHelper(o1) && IsMappingHelper(o2)) |
813 | /* For CompositeTensor & TypeSpec, we check below. */ |
814 | && !(check_composite_tensor_type_spec && |
815 | (IsCompositeTensor(o1) || IsTypeSpec(o1)) && |
816 | (IsCompositeTensor(o2) || IsTypeSpec(o2)))) { |
817 | *is_type_error = true; |
818 | *error_msg = tensorflow::strings::StrCat( |
819 | "The two namedtuples don't have the same sequence type. " |
820 | "First structure " , |
821 | PyObjectToString(o1), " has type " , type1->tp_name, |
822 | ", while second structure " , PyObjectToString(o2), " has type " , |
823 | type2->tp_name); |
824 | return true; |
825 | } |
826 | |
827 | if (PyDict_Check(o1) && PyDict_Check(o2)) { |
828 | if (PyDict_Size(o1) != PyDict_Size(o2)) { |
829 | SetDifferentKeysError(o1, o2, error_msg, is_type_error); |
830 | return true; |
831 | } |
832 | |
833 | PyObject* key; |
834 | Py_ssize_t pos = 0; |
835 | while (PyDict_Next(o1, &pos, &key, nullptr)) { |
836 | if (PyDict_GetItem(o2, key) == nullptr) { |
837 | SetDifferentKeysError(o1, o2, error_msg, is_type_error); |
838 | return true; |
839 | } |
840 | } |
841 | } else if (IsMappingHelper(o1)) { |
842 | // Fallback for custom mapping types. Instead of using PyDict methods |
843 | // which stay in C, we call iter(o1). |
844 | if (PyMapping_Size(o1) != PyMapping_Size(o2)) { |
845 | SetDifferentKeysError(o1, o2, error_msg, is_type_error); |
846 | return true; |
847 | } |
848 | |
849 | Safe_PyObjectPtr iter(PyObject_GetIter(o1)); |
850 | PyObject* key; |
851 | while ((key = PyIter_Next(iter.get())) != nullptr) { |
852 | if (!PyMapping_HasKey(o2, key)) { |
853 | SetDifferentKeysError(o1, o2, error_msg, is_type_error); |
854 | Py_DECREF(key); |
855 | return true; |
856 | } |
857 | Py_DECREF(key); |
858 | } |
859 | } |
860 | } |
861 | |
862 | if (check_composite_tensor_type_spec && |
863 | (IsCompositeTensor(o1) || IsCompositeTensor(o2))) { |
864 | Safe_PyObjectPtr owned_type_spec_1; |
865 | PyObject* type_spec_1 = o1; |
866 | if (IsCompositeTensor(o1)) { |
867 | owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec" )); |
868 | type_spec_1 = owned_type_spec_1.get(); |
869 | } |
870 | |
871 | Safe_PyObjectPtr owned_type_spec_2; |
872 | PyObject* type_spec_2 = o2; |
873 | if (IsCompositeTensor(o2)) { |
874 | owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec" )); |
875 | type_spec_2 = owned_type_spec_2.get(); |
876 | } |
877 | |
878 | // Two composite tensors are considered to have the same structure if |
879 | // they share a type spec that is a supertype of both of them. We do *not* |
880 | // use is_subtype_of, since that would prevent us from e.g. using a |
881 | // cond statement where the two sides have different shapes. |
882 | |
883 | // TODO(b/206014848): We have to explicitly remove the names. |
884 | Safe_PyObjectPtr owned_nameless_type_spec_1( |
885 | PyObject_CallMethod(type_spec_1, "_without_tensor_names" , nullptr)); |
886 | Safe_PyObjectPtr owned_nameless_type_spec_2( |
887 | PyObject_CallMethod(type_spec_2, "_without_tensor_names" , nullptr)); |
888 | // TODO(b/222123181): Reconsider most_specific_common_supertype usage. |
889 | static char compatible_type[] = "most_specific_common_supertype" ; |
890 | static char argspec[] = "([O])" ; |
891 | Safe_PyObjectPtr struct_compatible( |
892 | PyObject_CallMethod(owned_nameless_type_spec_1.get(), compatible_type, |
893 | argspec, owned_nameless_type_spec_2.get())); |
894 | if (PyErr_Occurred()) { |
895 | return false; |
896 | } |
897 | if (struct_compatible.get() == Py_None) { |
898 | *is_type_error = false; |
899 | *error_msg = tensorflow::strings::StrCat( |
900 | "Incompatible CompositeTensor TypeSpecs: " , |
901 | PyObjectToString(type_spec_1), " vs. " , |
902 | PyObjectToString(type_spec_2)); |
903 | return true; |
904 | } |
905 | } |
906 | |
907 | ValueIteratorPtr iter1 = value_iterator_getter(o1); |
908 | ValueIteratorPtr iter2 = value_iterator_getter(o2); |
909 | |
910 | if (!iter1->valid() || !iter2->valid()) return false; |
911 | |
912 | while (true) { |
913 | Safe_PyObjectPtr v1 = iter1->next(); |
914 | Safe_PyObjectPtr v2 = iter2->next(); |
915 | if (v1 && v2) { |
916 | if (Py_EnterRecursiveCall(" in assert_same_structure" )) { |
917 | return false; |
918 | } |
919 | bool no_internal_errors = AssertSameStructureHelper( |
920 | v1.get(), v2.get(), check_types, error_msg, is_type_error, |
921 | is_nested_helper, value_iterator_getter, |
922 | check_composite_tensor_type_spec); |
923 | Py_LeaveRecursiveCall(); |
924 | if (!no_internal_errors) return false; |
925 | if (!error_msg->empty()) return true; |
926 | } else if (!v1 && !v2) { |
927 | // Done with all recursive calls. Structure matched. |
928 | return true; |
929 | } else { |
930 | *is_type_error = false; |
931 | *error_msg = tensorflow::strings::StrCat( |
932 | "The two structures don't have the same number of elements. " , |
933 | "First structure: " , PyObjectToString(o1), |
934 | ". Second structure: " , PyObjectToString(o2)); |
935 | return true; |
936 | } |
937 | } |
938 | } |
939 | |
940 | } // namespace |
941 | |
942 | bool IsNested(PyObject* o) { return IsNestedHelper(o) == 1; } |
943 | bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; } |
944 | bool IsMutableMapping(PyObject* o) { return IsMutableMappingHelper(o) == 1; } |
945 | bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; } |
946 | bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; } |
947 | bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; } |
948 | bool IsTensorSpec(PyObject* o) { return IsTensorSpecHelper(o) == 1; } |
949 | bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; } |
950 | bool IsResourceVariable(PyObject* o) { |
951 | return IsResourceVariableHelper(o) == 1; |
952 | } |
953 | bool IsOwnedIterator(PyObject* o) { return IsOwnedIteratorHelper(o) == 1; } |
954 | bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; } |
955 | bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; } |
956 | bool IsDispatchable(PyObject* o) { return IsDispatchableHelper(o) == 1; } |
957 | |
958 | bool IsTuple(PyObject* o) { |
959 | tensorflow::Safe_PyObjectPtr wrapped; |
960 | if (IsObjectProxy(o)) { |
961 | wrapped.reset(PyObject_GetAttrString(o, "__wrapped__" )); |
962 | o = wrapped.get(); |
963 | } |
964 | return PyTuple_Check(o); |
965 | } |
966 | |
967 | // Work around a writable-strings warning with Python 2's PyMapping_Keys macro, |
968 | // and while we're at it give them consistent behavior by making sure the |
969 | // returned value is a list. |
970 | // |
971 | // As with PyMapping_Keys, returns a new reference. |
972 | // |
973 | // On failure, returns nullptr. |
974 | PyObject* MappingKeys(PyObject* o) { |
975 | #if PY_MAJOR_VERSION >= 3 |
976 | return PyMapping_Keys(o); |
977 | #else |
978 | static char key_method_name[] = "keys" ; |
979 | Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr)); |
980 | if (PyErr_Occurred() || raw_result.get() == nullptr) { |
981 | return nullptr; |
982 | } |
983 | return PySequence_Fast( |
984 | raw_result.get(), |
985 | "The '.keys()' method of a custom mapping returned a non-sequence." ); |
986 | #endif |
987 | } |
988 | |
989 | PyObject* Flatten(PyObject* nested, bool expand_composites) { |
990 | PyObject* list = PyList_New(0); |
991 | const std::function<int(PyObject*)>& is_nested_helper = |
992 | expand_composites ? IsNestedOrCompositeHelper : IsNestedHelper; |
993 | const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator = |
994 | expand_composites ? GetValueIteratorForComposite : GetValueIterator; |
995 | if (FlattenHelper(nested, list, is_nested_helper, get_value_iterator)) { |
996 | return list; |
997 | } else { |
998 | Py_DECREF(list); |
999 | return nullptr; |
1000 | } |
1001 | } |
1002 | |
1003 | bool IsNestedOrComposite(PyObject* o) { |
1004 | return IsNestedOrCompositeHelper(o) == 1; |
1005 | } |
1006 | |
1007 | bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; } |
1008 | |
1009 | bool IsTypeSpec(PyObject* o) { return IsTypeSpecHelper(o) == 1; } |
1010 | |
1011 | bool IsNestedForData(PyObject* o) { return IsNestedForDataHelper(o) == 1; } |
1012 | |
1013 | PyObject* FlattenForData(PyObject* nested) { |
1014 | PyObject* list = PyList_New(0); |
1015 | if (FlattenHelper(nested, list, IsNestedForDataHelper, |
1016 | GetValueIteratorForData)) { |
1017 | return list; |
1018 | } else { |
1019 | Py_DECREF(list); |
1020 | return nullptr; |
1021 | } |
1022 | } |
1023 | |
1024 | PyObject* IsNamedtuple(PyObject* o, bool strict) { |
1025 | // Some low-level CPython calls do not work with wrapt.ObjectProxy, so they |
1026 | // require some unwrapping if we want to treat them like the objects they're |
1027 | // wrapping. |
1028 | tensorflow::Safe_PyObjectPtr o_wrapped; |
1029 | if (IsObjectProxy(o)) { |
1030 | o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__" )); |
1031 | o = o_wrapped.get(); |
1032 | } |
1033 | |
1034 | // Must be subclass of tuple |
1035 | if (!PyTuple_Check(o)) { |
1036 | Py_RETURN_FALSE; |
1037 | } |
1038 | |
1039 | // If strict, o.__class__.__base__ must be tuple |
1040 | if (strict) { |
1041 | PyObject* klass = PyObject_GetAttrString(o, "__class__" ); |
1042 | if (klass == nullptr) return nullptr; |
1043 | PyObject* base = PyObject_GetAttrString(klass, "__base__" ); |
1044 | Py_DECREF(klass); |
1045 | if (base == nullptr) return nullptr; |
1046 | |
1047 | const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base); |
1048 | // built-in object types are singletons |
1049 | bool tuple_base = base_type == &PyTuple_Type; |
1050 | Py_DECREF(base); |
1051 | if (!tuple_base) { |
1052 | Py_RETURN_FALSE; |
1053 | } |
1054 | } |
1055 | |
1056 | // o must have attribute '_fields' and every element in |
1057 | // '_fields' must be a string. |
1058 | int has_fields = PyObject_HasAttrString(o, "_fields" ); |
1059 | if (!has_fields) { |
1060 | Py_RETURN_FALSE; |
1061 | } |
1062 | |
1063 | Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields" )); |
1064 | int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence" ); |
1065 | if (is_instance == 0) { |
1066 | Py_RETURN_FALSE; |
1067 | } else if (is_instance == -1) { |
1068 | return nullptr; |
1069 | } |
1070 | |
1071 | Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), "" )); |
1072 | const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get()); |
1073 | for (Py_ssize_t i = 0; i < s; ++i) { |
1074 | // PySequence_Fast_GET_ITEM returns borrowed ref |
1075 | PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i); |
1076 | if (!IsString(elem)) { |
1077 | Py_RETURN_FALSE; |
1078 | } |
1079 | } |
1080 | |
1081 | Py_RETURN_TRUE; |
1082 | } |
1083 | |
1084 | PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) { |
1085 | Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields" )); |
1086 | Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields" )); |
1087 | if (f1 == nullptr || f2 == nullptr) { |
1088 | PyErr_SetString( |
1089 | PyExc_RuntimeError, |
1090 | "Expected namedtuple-like objects (that have _fields attr)" ); |
1091 | return nullptr; |
1092 | } |
1093 | |
1094 | if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) { |
1095 | Py_RETURN_FALSE; |
1096 | } |
1097 | |
1098 | if (GetClassName(o1).compare(GetClassName(o2)) == 0) { |
1099 | Py_RETURN_TRUE; |
1100 | } else { |
1101 | Py_RETURN_FALSE; |
1102 | } |
1103 | } |
1104 | |
1105 | PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types, |
1106 | bool expand_composites) { |
1107 | const std::function<int(PyObject*)>& is_nested_helper = |
1108 | expand_composites ? IsNestedOrCompositeHelper : IsNestedHelper; |
1109 | const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator = |
1110 | expand_composites ? GetValueIteratorForComposite : GetValueIterator; |
1111 | const bool check_composite_tensor_type_spec = expand_composites; |
1112 | string error_msg; |
1113 | bool is_type_error = false; |
1114 | AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error, |
1115 | is_nested_helper, get_value_iterator, |
1116 | check_composite_tensor_type_spec); |
1117 | if (PyErr_Occurred()) { |
1118 | // Don't hide Python exceptions while checking (e.g. errors fetching keys |
1119 | // from custom mappings). |
1120 | return nullptr; |
1121 | } |
1122 | if (!error_msg.empty()) { |
1123 | PyErr_SetString( |
1124 | is_type_error ? PyExc_TypeError : PyExc_ValueError, |
1125 | tensorflow::strings::StrCat( |
1126 | "The two structures don't have the same nested structure.\n\n" , |
1127 | "First structure: " , PyObjectToString(o1), "\n\nSecond structure: " , |
1128 | PyObjectToString(o2), "\n\nMore specifically: " , error_msg) |
1129 | .c_str()); |
1130 | return nullptr; |
1131 | } |
1132 | Py_RETURN_NONE; |
1133 | } |
1134 | |
1135 | PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2, |
1136 | bool check_types) { |
1137 | string error_msg; |
1138 | bool is_type_error = false; |
1139 | AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error, |
1140 | IsNestedForDataHelper, GetValueIterator, false); |
1141 | if (PyErr_Occurred()) { |
1142 | // Don't hide Python exceptions while checking (e.g. errors fetching keys |
1143 | // from custom mappings). |
1144 | return nullptr; |
1145 | } |
1146 | if (!error_msg.empty()) { |
1147 | PyErr_SetString( |
1148 | is_type_error ? PyExc_TypeError : PyExc_ValueError, |
1149 | tensorflow::strings::StrCat( |
1150 | "The two structures don't have the same nested structure.\n\n" , |
1151 | "First structure: " , PyObjectToString(o1), "\n\nSecond structure: " , |
1152 | PyObjectToString(o2), "\n\nMore specifically: " , error_msg) |
1153 | .c_str()); |
1154 | return nullptr; |
1155 | } |
1156 | Py_RETURN_NONE; |
1157 | } |
1158 | |
1159 | } // namespace swig |
1160 | } // namespace tensorflow |
1161 | |