1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/python/util/util.h"
16
17#include <functional>
18#include <memory>
19#include <unordered_map>
20#include <vector>
21
22#include "absl/memory/memory.h"
23#include "tensorflow/core/lib/gtl/map_util.h"
24#include "tensorflow/core/lib/strings/strcat.h"
25#include "tensorflow/core/platform/logging.h"
26#include "tensorflow/core/platform/mutex.h"
27#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
28
29namespace tensorflow {
30namespace swig {
31
32namespace {
33string PyObjectToString(PyObject* o);
34} // namespace
35
36std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
37 static auto* m = new std::unordered_map<string, PyObject*>();
38 return m;
39}
40
41PyObject* GetRegisteredPyObject(const string& name) {
42 const auto* m = RegisteredPyObjectMap();
43 auto it = m->find(name);
44 if (it == m->end()) {
45 PyErr_SetString(PyExc_TypeError,
46 tensorflow::strings::StrCat("No object with name ", name,
47 " has been registered.")
48 .c_str());
49 return nullptr;
50 }
51 return it->second;
52}
53
54PyObject* RegisterType(PyObject* type_name, PyObject* type) {
55 if (!PyType_Check(type)) {
56 PyErr_SetString(PyExc_TypeError,
57 tensorflow::strings::StrCat("Expecting a type, got ",
58 Py_TYPE(type)->tp_name)
59 .c_str());
60 return nullptr;
61 }
62 return RegisterPyObject(type_name, type);
63}
64
65PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
66 string key;
67 if (PyBytes_Check(name)) {
68 key = PyBytes_AsString(name);
69#if PY_MAJOR_VERSION >= 3
70 } else if (PyUnicode_Check(name)) {
71 key = PyUnicode_AsUTF8(name);
72#endif
73 } else {
74 PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
75 "Expected name to be a str, got",
76 PyObjectToString(name))
77 .c_str());
78 return nullptr;
79 }
80
81 auto* m = RegisteredPyObjectMap();
82 if (m->find(key) != m->end()) {
83 PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
84 "Value already registered for ", key)
85 .c_str());
86 return nullptr;
87 }
88
89 Py_INCREF(value);
90 m->emplace(key, value);
91
92 Py_RETURN_NONE;
93}
94
95namespace {
96const int kMaxItemsInCache = 1024;
97
98bool IsString(PyObject* o) {
99 return PyBytes_Check(o) ||
100#if PY_MAJOR_VERSION < 3
101 PyString_Check(o) ||
102#endif
103 PyUnicode_Check(o);
104}
105
106// Equivalent to Python's 'o.__class__.__name__'
107// Note that '__class__' attribute is set only in new-style classes.
108// A lot of tensorflow code uses __class__ without checks, so it seems like
109// we only support new-style classes.
110StringPiece GetClassName(PyObject* o) {
111 // __class__ is equivalent to type() for new style classes.
112 // type() is equivalent to PyObject_Type()
113 // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
114 // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
115 // we don't need here.
116 PyTypeObject* type = o->ob_type;
117
118 // __name__ is the value of `tp_name` after the last '.'
119 // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
120 StringPiece name(type->tp_name);
121 size_t pos = name.rfind('.');
122 if (pos != StringPiece::npos) {
123 name.remove_prefix(pos + 1);
124 }
125 return name;
126}
127
128string PyObjectToString(PyObject* o) {
129 if (o == nullptr) {
130 return "<null object>";
131 }
132 PyObject* str = PyObject_Str(o);
133 if (str) {
134#if PY_MAJOR_VERSION < 3
135 string s(PyString_AS_STRING(str));
136#else
137 string s(PyUnicode_AsUTF8(str));
138#endif
139 Py_DECREF(str);
140 return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
141 } else {
142 return "<failed to execute str() on object>";
143 }
144}
145
146class CachedTypeCheck {
147 public:
148 explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
149 : ternary_predicate_(std::move(ternary_predicate)) {}
150
151 ~CachedTypeCheck() {
152 mutex_lock l(type_to_sequence_map_mu_);
153 for (const auto& pair : type_to_sequence_map_) {
154 Py_DECREF(pair.first);
155 }
156 }
157
158 // Caches successful executions of the one-argument (PyObject*) callable
159 // "ternary_predicate" based on the type of "o". -1 from the callable
160 // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
161 // does not match the predicate, and 1 indicates that it does. Used to avoid
162 // calling back into Python for expensive isinstance checks.
163 int CachedLookup(PyObject* o) {
164 // Try not to return to Python - see if the type has already been seen
165 // before.
166
167 auto* type = Py_TYPE(o);
168
169 {
170 tf_shared_lock l(type_to_sequence_map_mu_);
171 auto it = type_to_sequence_map_.find(type);
172 if (it != type_to_sequence_map_.end()) {
173 return it->second;
174 }
175 }
176
177 int check_result = ternary_predicate_(o);
178
179 if (check_result == -1) {
180 return -1; // Type check error, not cached.
181 }
182
183 // NOTE: This is never decref'd as long as the object lives, which is likely
184 // forever, but we don't want the type to get deleted as long as it is in
185 // the map. This should not be too much of a leak, as there should only be a
186 // relatively small number of types in the map, and an even smaller number
187 // that are eligible for decref. As a precaution, we limit the size of the
188 // map to 1024.
189 {
190 mutex_lock l(type_to_sequence_map_mu_);
191 if (type_to_sequence_map_.size() < kMaxItemsInCache) {
192 Py_INCREF(type);
193 auto insert_result = type_to_sequence_map_.insert({type, check_result});
194 if (!insert_result.second) {
195 // The type was added to the cache by a concurrent thread after we
196 // looked it up above.
197 Py_DECREF(type);
198 }
199 }
200 }
201
202 return check_result;
203 }
204
205 private:
206 std::function<int(PyObject*)> ternary_predicate_;
207 mutex type_to_sequence_map_mu_;
208 std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
209 TF_GUARDED_BY(type_to_sequence_map_mu_);
210};
211
212// Returns 1 if 'obj' is an instance of 'type_name'
213// Returns 0 otherwise.
214// Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
215int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
216 PyObject* type_obj = GetRegisteredPyObject(type_name);
217 if (TF_PREDICT_FALSE(type_obj == nullptr)) {
218 PyErr_SetString(PyExc_RuntimeError,
219 tensorflow::strings::StrCat(
220 type_name,
221 " type has not been set. "
222 "Please register the type with the identifier \"",
223 type_name, "\" using RegisterType.")
224 .c_str());
225 return -1;
226 }
227 return PyObject_IsInstance(obj, type_obj);
228}
229
230// Returns 1 if `o` is considered a mapping for the purposes of Flatten().
231// Returns 0 otherwise.
232// Returns -1 if an error occurred.
233int IsMappingHelper(PyObject* o) {
234 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
235 return IsInstanceOfRegisteredType(to_check, "Mapping");
236 });
237 if (PyDict_Check(o)) return true;
238 return check_cache->CachedLookup(o);
239}
240
241// Returns 1 if `o` is considered a mutable mapping for the purposes of
242// Flatten(). Returns 0 otherwise. Returns -1 if an error occurred.
243int IsMutableMappingHelper(PyObject* o) {
244 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
245 return IsInstanceOfRegisteredType(to_check, "MutableMapping");
246 });
247 if (PyDict_Check(o)) return true;
248 return check_cache->CachedLookup(o);
249}
250
251// Returns 1 if `o` is considered a mapping view for the purposes of Flatten().
252// Returns 0 otherwise.
253// Returns -1 if an error occurred.
254int IsMappingViewHelper(PyObject* o) {
255 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
256 return IsInstanceOfRegisteredType(to_check, "MappingView");
257 });
258 return check_cache->CachedLookup(o);
259}
260
261// Returns 1 if `o` is considered an object proxy
262// Returns 0 otherwise.
263// Returns -1 if an error occurred.
264int IsObjectProxy(PyObject* o) {
265 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
266 return IsInstanceOfRegisteredType(to_check, "ObjectProxy");
267 });
268 return check_cache->CachedLookup(o);
269}
270
271// Returns 1 if `o` is an instance of attrs-decorated class.
272// Returns 0 otherwise.
273int IsAttrsHelper(PyObject* o) {
274 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
275 Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
276 if (cls) {
277 return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
278 }
279
280 // PyObject_GetAttrString returns null on error
281 PyErr_Clear();
282 return 0;
283 });
284 return check_cache->CachedLookup(o);
285}
286
287// Returns 1 if `o` is an object of type IndexedSlices.
288// Returns 0 otherwise.
289// Returns -1 if an error occurred.
290int IsIndexedSlicesHelper(PyObject* o) {
291 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
292 return IsInstanceOfRegisteredType(to_check, "IndexedSlices");
293 });
294 return check_cache->CachedLookup(o);
295}
296
297// Returns 1 if `o` is a Tensor.
298// Returns 0 otherwise.
299// Returns -1 if an error occurred.
300int IsTensorHelper(PyObject* o) {
301 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
302 return IsInstanceOfRegisteredType(to_check, "Tensor");
303 });
304 return check_cache->CachedLookup(o);
305}
306
307// Returns 1 if `o` is a TensorSpec.
308// Returns 0 otherwise.
309// Returns -1 if an error occurred.
310int IsTensorSpecHelper(PyObject* o) {
311 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
312 return IsInstanceOfRegisteredType(to_check, "TensorSpec");
313 });
314 return check_cache->CachedLookup(o);
315}
316
317// Returns 1 if `o` is an EagerTensor.
318// Returns 0 otherwise.
319// Returns -1 if an error occurred.
320int IsEagerTensorHelper(PyObject* o) {
321 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
322 return IsInstanceOfRegisteredType(to_check, "EagerTensor");
323 });
324 return check_cache->CachedLookup(o);
325}
326
327// Returns 1 if `o` is a ResourceVariable.
328// Returns 0 otherwise.
329// Returns -1 if an error occurred.
330int IsResourceVariableHelper(PyObject* o) {
331 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
332 return IsInstanceOfRegisteredType(to_check, "ResourceVariable");
333 });
334 return check_cache->CachedLookup(o);
335}
336
337// Returns 1 if `o` is a OwnedIterator.
338// Returns 0 otherwise.
339// Returns -1 if an error occurred.
340int IsOwnedIteratorHelper(PyObject* o) {
341 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
342 return IsInstanceOfRegisteredType(to_check, "OwnedIterator");
343 });
344 return check_cache->CachedLookup(o);
345}
346
347// Returns 1 if `o` is a ResourceVariable.
348// Returns 0 otherwise.
349// Returns -1 if an error occurred.
350int IsVariableHelper(PyObject* o) {
351 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
352 return IsInstanceOfRegisteredType(to_check, "Variable");
353 });
354 return check_cache->CachedLookup(o);
355}
356
357// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
358// Returns 0 otherwise.
359// Returns -1 if an error occurred.
360int IsNestedHelper(PyObject* o) {
361 // We treat dicts and other mappings as special cases of sequences.
362 if (IsMappingHelper(o)) return true;
363 if (IsMappingViewHelper(o)) return true;
364 if (IsAttrsHelper(o)) return true;
365
366 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
367 int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence");
368
369 // Don't cache a failed is_instance check.
370 if (is_instance == -1) return -1;
371
372 return static_cast<int>(is_instance != 0 && !IsString(to_check));
373 });
374 return check_cache->CachedLookup(o);
375}
376
377// Returns 1 if `o`'s class has a `__tf_dispatch__` attribute.
378// Returns 0 otherwise.
379int IsDispatchableHelper(PyObject* o) {
380 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
381 return PyObject_HasAttrString(
382 reinterpret_cast<PyObject*>(to_check->ob_type), "__tf_dispatch__");
383 });
384 return check_cache->CachedLookup(o);
385}
386
387// ValueIterator interface
388class ValueIterator {
389 public:
390 virtual ~ValueIterator() {}
391 virtual Safe_PyObjectPtr next() = 0;
392
393 bool valid() const { return is_valid_; }
394
395 protected:
396 void invalidate() { is_valid_ = false; }
397
398 private:
399 bool is_valid_ = true;
400};
401
402using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
403
404// Iterate through dictionaries in a deterministic order by sorting the
405// keys. Notice this means that we ignore the original order of
406// `OrderedDict` instances. This is intentional, to avoid potential
407// bugs caused by mixing ordered and plain dicts (e.g., flattening
408// a dict but using a corresponding `OrderedDict` to pack it back).
409class DictValueIterator : public ValueIterator {
410 public:
411 explicit DictValueIterator(PyObject* dict)
412 : dict_(dict), keys_(PyDict_Keys(dict)) {
413 if (PyList_Sort(keys_.get()) == -1) {
414 invalidate();
415 } else {
416 iter_.reset(PyObject_GetIter(keys_.get()));
417 }
418 }
419
420 Safe_PyObjectPtr next() override {
421 Safe_PyObjectPtr result;
422 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
423 if (key) {
424 // PyDict_GetItem returns a borrowed reference.
425 PyObject* elem = PyDict_GetItem(dict_, key.get());
426 if (elem) {
427 Py_INCREF(elem);
428 result.reset(elem);
429 } else {
430 PyErr_SetString(PyExc_RuntimeError,
431 "Dictionary was modified during iteration over it");
432 }
433 }
434 return result;
435 }
436
437 private:
438 PyObject* dict_;
439 Safe_PyObjectPtr keys_;
440 Safe_PyObjectPtr iter_;
441};
442
443// Iterate over mapping objects by sorting the keys first
444class MappingValueIterator : public ValueIterator {
445 public:
446 explicit MappingValueIterator(PyObject* mapping)
447 : mapping_(mapping), keys_(MappingKeys(mapping)) {
448 if (!keys_ || PyList_Sort(keys_.get()) == -1) {
449 invalidate();
450 } else {
451 iter_.reset(PyObject_GetIter(keys_.get()));
452 }
453 }
454
455 Safe_PyObjectPtr next() override {
456 Safe_PyObjectPtr result;
457 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
458 if (key) {
459 // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
460 PyObject* elem = PyObject_GetItem(mapping_, key.get());
461 if (elem) {
462 result.reset(elem);
463 } else {
464 PyErr_SetString(PyExc_RuntimeError,
465 "Mapping was modified during iteration over it");
466 }
467 }
468 return result;
469 }
470
471 private:
472 PyObject* mapping_;
473 Safe_PyObjectPtr keys_;
474 Safe_PyObjectPtr iter_;
475};
476
477// Iterate over a sequence, by index.
478class SequenceValueIterator : public ValueIterator {
479 public:
480 explicit SequenceValueIterator(PyObject* iterable)
481 : seq_(PySequence_Fast(iterable, "")),
482 size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
483 index_(0) {}
484
485 Safe_PyObjectPtr next() override {
486 Safe_PyObjectPtr result;
487 if (index_ < size_) {
488 // PySequence_Fast_GET_ITEM returns a borrowed reference.
489 PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
490 ++index_;
491 if (elem) {
492 Py_INCREF(elem);
493 result.reset(elem);
494 }
495 }
496
497 return result;
498 }
499
500 private:
501 Safe_PyObjectPtr seq_;
502 const Py_ssize_t size_;
503 Py_ssize_t index_;
504};
505
506// Iterator that just returns a single python object.
507class SingleValueIterator : public ValueIterator {
508 public:
509 explicit SingleValueIterator(PyObject* x) : x_(x) { Py_INCREF(x); }
510
511 Safe_PyObjectPtr next() override { return std::move(x_); }
512
513 private:
514 Safe_PyObjectPtr x_;
515};
516
517// Returns nullptr (to raise an exception) when next() is called. Caller
518// should have already called PyErr_SetString.
519class ErrorValueIterator : public ValueIterator {
520 public:
521 ErrorValueIterator() {}
522 Safe_PyObjectPtr next() override { return nullptr; }
523};
524
525class AttrsValueIterator : public ValueIterator {
526 public:
527 explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
528 Py_INCREF(nested);
529 cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
530 if (cls_) {
531 attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
532 if (attrs_) {
533 iter_.reset(PyObject_GetIter(attrs_.get()));
534 }
535 }
536 if (!iter_ || PyErr_Occurred()) invalidate();
537 }
538
539 Safe_PyObjectPtr next() override {
540 Safe_PyObjectPtr result;
541 Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
542 if (item) {
543 Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
544 result.reset(PyObject_GetAttr(nested_.get(), name.get()));
545 }
546
547 return result;
548 }
549
550 private:
551 Safe_PyObjectPtr nested_;
552 Safe_PyObjectPtr cls_;
553 Safe_PyObjectPtr attrs_;
554 Safe_PyObjectPtr iter_;
555};
556
557bool IsSparseTensorValueType(PyObject* o) {
558 PyObject* sparse_tensor_value_type =
559 GetRegisteredPyObject("SparseTensorValue");
560 if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
561 return false;
562 }
563
564 return PyObject_TypeCheck(
565 o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
566}
567
568// Returns 1 if `o` is an instance of CompositeTensor.
569// Returns 0 otherwise.
570// Returns -1 if an error occurred.
571bool IsCompositeTensorHelper(PyObject* o) {
572 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
573 // TODO(b/246438937): Remove the ResourceVariable test.
574 return IsInstanceOfRegisteredType(to_check, "CompositeTensor") &&
575 !IsResourceVariable(to_check);
576 });
577 return check_cache->CachedLookup(o);
578}
579
580// Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or
581// VariableSpec.
582// Returns 0 otherwise.
583// Returns -1 if an error occurred.
584bool IsTypeSpecHelper(PyObject* o) {
585 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
586 int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec");
587 // TODO(b/246438937): Remove the VariableSpec special case.
588 int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") ||
589 IsInstanceOfRegisteredType(to_check, "VariableSpec"));
590 if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1;
591 return static_cast<int>(is_type_spec && !is_dense_spec);
592 });
593 return check_cache->CachedLookup(o);
594}
595
596// Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
597// (non-TensorSpec and non-VariableSpec) TypeSpec.
598// Returns 0 otherwise.
599// Returns -1 if an error occurred.
600int IsNestedOrCompositeHelper(PyObject* o) {
601 int is_nested = IsNestedHelper(o);
602 int is_composite = IsCompositeTensorHelper(o);
603 int is_type_spec = IsTypeSpecHelper(o);
604 if ((is_nested == -1) || (is_composite == -1) || (is_type_spec == -1)) {
605 return -1;
606 }
607 return is_nested || is_composite || is_type_spec;
608}
609
610int IsNestedForDataHelper(PyObject* o) {
611 return IsNestedHelper(o) == 1 && !PyList_Check(o) &&
612 !IsSparseTensorValueType(o);
613}
614
615ValueIteratorPtr GetValueIterator(PyObject* nested) {
616 if (PyDict_Check(nested)) {
617 return absl::make_unique<DictValueIterator>(nested);
618 } else if (IsMappingHelper(nested)) {
619 return absl::make_unique<MappingValueIterator>(nested);
620 } else if (IsAttrsHelper(nested)) {
621 return absl::make_unique<AttrsValueIterator>(nested);
622 } else {
623 return absl::make_unique<SequenceValueIterator>(nested);
624 }
625}
626
627// Similar to above, just specialized for the functions in the data package.
628ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
629 if (PyDict_Check(nested)) {
630 return absl::make_unique<DictValueIterator>(nested);
631 } else if (IsMappingHelper(nested)) {
632 return absl::make_unique<MappingValueIterator>(nested);
633 } else if (IsAttrsHelper(nested)) {
634 return absl::make_unique<AttrsValueIterator>(nested);
635 } else if (IsSparseTensorValueType(nested)) {
636 return absl::make_unique<SingleValueIterator>(nested);
637 } else {
638 return absl::make_unique<SequenceValueIterator>(nested);
639 }
640}
641
642// Similar to GetValueIterator above, but expands CompositeTensor and TypeSpec.
643ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
644 if (IsCompositeTensor(nested)) {
645 Safe_PyObjectPtr spec(PyObject_GetAttrString(nested, "_type_spec"));
646 if (PyErr_Occurred() || !spec) {
647 return absl::make_unique<ErrorValueIterator>();
648 }
649
650 static char to_components[] = "_to_components";
651 static char argspec[] = "(O)";
652 Safe_PyObjectPtr components(
653 PyObject_CallMethod(spec.get(), to_components, argspec, nested));
654 if (PyErr_Occurred() || components == nullptr) {
655 return absl::make_unique<ErrorValueIterator>();
656 }
657 return absl::make_unique<SingleValueIterator>(components.get());
658 }
659
660 if (IsTypeSpec(nested)) {
661 Safe_PyObjectPtr specs(PyObject_GetAttrString(nested, "_component_specs"));
662 if (PyErr_Occurred() || specs == nullptr) {
663 return absl::make_unique<ErrorValueIterator>();
664 }
665 return absl::make_unique<SingleValueIterator>(specs.get());
666 }
667
668 return GetValueIterator(nested);
669}
670
671bool FlattenHelper(
672 PyObject* nested, PyObject* list,
673 const std::function<int(PyObject*)>& is_nested_helper,
674 const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
675 // if nested is not a sequence, append itself and exit
676 int is_nested = is_nested_helper(nested);
677 if (is_nested == -1) return false;
678 if (!is_nested) {
679 return PyList_Append(list, nested) != -1;
680 }
681
682 ValueIteratorPtr iter = value_iterator_getter(nested);
683 if (!iter->valid()) return false;
684
685 for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
686 if (Py_EnterRecursiveCall(" in flatten")) {
687 return false;
688 }
689 const bool success = FlattenHelper(item.get(), list, is_nested_helper,
690 value_iterator_getter);
691 Py_LeaveRecursiveCall();
692 if (!success) {
693 return false;
694 }
695 }
696 return true;
697}
698
699// Sets error using keys of 'dict1' and 'dict2'.
700// 'dict1' and 'dict2' are assumed to be Python dictionaries.
701void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
702 bool* is_type_error) {
703 Safe_PyObjectPtr k1(MappingKeys(dict1));
704 if (PyErr_Occurred() || k1.get() == nullptr) {
705 *error_msg =
706 ("The two dictionaries don't have the same set of keys. Failed to "
707 "fetch keys.");
708 return;
709 }
710 Safe_PyObjectPtr k2(MappingKeys(dict2));
711 if (PyErr_Occurred() || k2.get() == nullptr) {
712 *error_msg =
713 ("The two dictionaries don't have the same set of keys. Failed to "
714 "fetch keys.");
715 return;
716 }
717 *is_type_error = false;
718 *error_msg = tensorflow::strings::StrCat(
719 "The two dictionaries don't have the same set of keys. "
720 "First structure has keys ",
721 PyObjectToString(k1.get()), ", while second structure has keys ",
722 PyObjectToString(k2.get()));
723}
724
// Recursively compares the structure of `o1` and `o2`.
//
// Returns true iff there were no "internal" errors. In other words,
// errors that have nothing to do with structure checking.
// If an "internal" error occurred, the appropriate Python error will be
// set and the caller can propagate it directly to the user.
//
// Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
// be empty.
// Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
// with appropriate error and sets `is_type_error` to true iff
// the error to be raised should be TypeError.
//
// `check_types` additionally compares the container types (and, for dicts
// and mappings, the key sets). `is_nested_helper` and `value_iterator_getter`
// select what counts as a structure and how its children are enumerated.
// `check_composite_tensor_type_spec` enables the TypeSpec compatibility
// check for CompositeTensors.
bool AssertSameStructureHelper(
    PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
    bool* is_type_error, const std::function<int(PyObject*)>& is_nested_helper,
    const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter,
    bool check_composite_tensor_type_spec) {
  DCHECK(error_msg);
  DCHECK(is_type_error);
  // NOTE(review): the helper's int result is stored in a bool, so a -1 error
  // value becomes true; errors are instead detected through PyErr_Occurred()
  // right below.
  const bool is_nested1 = is_nested_helper(o1);
  const bool is_nested2 = is_nested_helper(o2);
  if (PyErr_Occurred()) return false;
  if (is_nested1 != is_nested2) {
    // Exactly one side is a structure: report which one.
    string seq_str = is_nested1 ? PyObjectToString(o1) : PyObjectToString(o2);
    string non_seq_str =
        is_nested1 ? PyObjectToString(o2) : PyObjectToString(o1);
    *is_type_error = false;
    *error_msg = tensorflow::strings::StrCat(
        "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
        non_seq_str, "\" is not");
    return true;
  }

  // Got to objects that are considered non-sequences. Note that in tf.data
  // use case lists and sparse_tensors are not considered sequences. So finished
  // checking, structures are the same.
  if (!is_nested1) return true;

  if (check_types) {
    // Treat wrapped tuples as tuples.
    // NOTE(review): the result of the `__wrapped__` lookup is not checked for
    // null before `o1`/`o2` are dereferenced below — confirm it cannot fail.
    tensorflow::Safe_PyObjectPtr o1_wrapped;
    if (IsObjectProxy(o1)) {
      o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__"));
      o1 = o1_wrapped.get();
    }
    tensorflow::Safe_PyObjectPtr o2_wrapped;
    if (IsObjectProxy(o2)) {
      o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__"));
      o2 = o2_wrapped.get();
    }

    const PyTypeObject* type1 = o1->ob_type;
    const PyTypeObject* type2 = o2->ob_type;

    // We treat two different namedtuples with identical name and fields
    // as having the same type.
    // The results of IsNamedtuple are decref'd here once they have been
    // compared against Py_True; nullptr signals an internal error.
    const PyObject* o1_tuple = IsNamedtuple(o1, false);
    if (o1_tuple == nullptr) return false;
    const PyObject* o2_tuple = IsNamedtuple(o2, false);
    if (o2_tuple == nullptr) {
      Py_DECREF(o1_tuple);
      return false;
    }
    bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
    Py_DECREF(o1_tuple);
    Py_DECREF(o2_tuple);

    if (both_tuples) {
      const PyObject* same_tuples = SameNamedtuples(o1, o2);
      if (same_tuples == nullptr) return false;
      bool not_same_tuples = same_tuples != Py_True;
      Py_DECREF(same_tuples);
      if (not_same_tuples) {
        *is_type_error = true;
        *error_msg = tensorflow::strings::StrCat(
            "The two namedtuples don't have the same sequence type. "
            "First structure ",
            PyObjectToString(o1), " has type ", type1->tp_name,
            ", while second structure ", PyObjectToString(o2), " has type ",
            type2->tp_name);
        return true;
      }
    } else if (type1 != type2
               /* If both sequences are list types, don't complain. This allows
                  one to be a list subclass (e.g. _ListWrapper used for
                  automatic dependency tracking.) */
               && !(PyList_Check(o1) && PyList_Check(o2))
               /* Two mapping types will also compare equal, making _DictWrapper
                  and dict compare equal. */
               && !(IsMappingHelper(o1) && IsMappingHelper(o2))
               /* For CompositeTensor & TypeSpec, we check below. */
               && !(check_composite_tensor_type_spec &&
                    (IsCompositeTensor(o1) || IsTypeSpec(o1)) &&
                    (IsCompositeTensor(o2) || IsTypeSpec(o2)))) {
      // NOTE(review): this message says "namedtuples" although this branch is
      // reached when the two objects are not both namedtuples.
      *is_type_error = true;
      *error_msg = tensorflow::strings::StrCat(
          "The two namedtuples don't have the same sequence type. "
          "First structure ",
          PyObjectToString(o1), " has type ", type1->tp_name,
          ", while second structure ", PyObjectToString(o2), " has type ",
          type2->tp_name);
      return true;
    }

    if (PyDict_Check(o1) && PyDict_Check(o2)) {
      // Same size plus every key of o1 present in o2 means equal key sets.
      if (PyDict_Size(o1) != PyDict_Size(o2)) {
        SetDifferentKeysError(o1, o2, error_msg, is_type_error);
        return true;
      }

      PyObject* key;
      Py_ssize_t pos = 0;
      while (PyDict_Next(o1, &pos, &key, nullptr)) {
        if (PyDict_GetItem(o2, key) == nullptr) {
          SetDifferentKeysError(o1, o2, error_msg, is_type_error);
          return true;
        }
      }
    } else if (IsMappingHelper(o1)) {
      // Fallback for custom mapping types. Instead of using PyDict methods
      // which stay in C, we call iter(o1).
      if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
        SetDifferentKeysError(o1, o2, error_msg, is_type_error);
        return true;
      }

      // NOTE(review): the iterator returned by PyObject_GetIter is not checked
      // for null before being passed to PyIter_Next.
      Safe_PyObjectPtr iter(PyObject_GetIter(o1));
      PyObject* key;
      // PyIter_Next returns a new reference, released on every path below.
      while ((key = PyIter_Next(iter.get())) != nullptr) {
        if (!PyMapping_HasKey(o2, key)) {
          SetDifferentKeysError(o1, o2, error_msg, is_type_error);
          Py_DECREF(key);
          return true;
        }
        Py_DECREF(key);
      }
    }
  }

  if (check_composite_tensor_type_spec &&
      (IsCompositeTensor(o1) || IsCompositeTensor(o2))) {
    // For a CompositeTensor use its `_type_spec`; otherwise the object itself
    // is used as the spec.
    Safe_PyObjectPtr owned_type_spec_1;
    PyObject* type_spec_1 = o1;
    if (IsCompositeTensor(o1)) {
      owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec"));
      type_spec_1 = owned_type_spec_1.get();
    }

    Safe_PyObjectPtr owned_type_spec_2;
    PyObject* type_spec_2 = o2;
    if (IsCompositeTensor(o2)) {
      owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec"));
      type_spec_2 = owned_type_spec_2.get();
    }

    // Two composite tensors are considered to have the same structure if
    // they share a type spec that is a supertype of both of them. We do *not*
    // use is_subtype_of, since that would prevent us from e.g. using a
    // cond statement where the two sides have different shapes.

    // TODO(b/206014848): We have to explicitly remove the names.
    Safe_PyObjectPtr owned_nameless_type_spec_1(
        PyObject_CallMethod(type_spec_1, "_without_tensor_names", nullptr));
    Safe_PyObjectPtr owned_nameless_type_spec_2(
        PyObject_CallMethod(type_spec_2, "_without_tensor_names", nullptr));
    // TODO(b/222123181): Reconsider most_specific_common_supertype usage.
    static char compatible_type[] = "most_specific_common_supertype";
    static char argspec[] = "([O])";
    Safe_PyObjectPtr struct_compatible(
        PyObject_CallMethod(owned_nameless_type_spec_1.get(), compatible_type,
                            argspec, owned_nameless_type_spec_2.get()));
    // Failures of any of the calls above are detected here, after the fact.
    if (PyErr_Occurred()) {
      return false;
    }
    // A None result means there is no common supertype.
    if (struct_compatible.get() == Py_None) {
      *is_type_error = false;
      *error_msg = tensorflow::strings::StrCat(
          "Incompatible CompositeTensor TypeSpecs: ",
          PyObjectToString(type_spec_1), " vs. ",
          PyObjectToString(type_spec_2));
      return true;
    }
  }

  ValueIteratorPtr iter1 = value_iterator_getter(o1);
  ValueIteratorPtr iter2 = value_iterator_getter(o2);

  if (!iter1->valid() || !iter2->valid()) return false;

  // Walk both structures in lockstep, recursing into each pair of children.
  while (true) {
    Safe_PyObjectPtr v1 = iter1->next();
    Safe_PyObjectPtr v2 = iter2->next();
    if (v1 && v2) {
      if (Py_EnterRecursiveCall(" in assert_same_structure")) {
        return false;
      }
      bool no_internal_errors = AssertSameStructureHelper(
          v1.get(), v2.get(), check_types, error_msg, is_type_error,
          is_nested_helper, value_iterator_getter,
          check_composite_tensor_type_spec);
      Py_LeaveRecursiveCall();
      if (!no_internal_errors) return false;
      // A mismatch was recorded further down; stop and report it.
      if (!error_msg->empty()) return true;
    } else if (!v1 && !v2) {
      // Done with all recursive calls. Structure matched.
      return true;
    } else {
      // One iterator ran out before the other.
      *is_type_error = false;
      *error_msg = tensorflow::strings::StrCat(
          "The two structures don't have the same number of elements. ",
          "First structure: ", PyObjectToString(o1),
          ". Second structure: ", PyObjectToString(o2));
      return true;
    }
  }
}
939
940} // namespace
941
942bool IsNested(PyObject* o) { return IsNestedHelper(o) == 1; }
943bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
944bool IsMutableMapping(PyObject* o) { return IsMutableMappingHelper(o) == 1; }
945bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
946bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
947bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
948bool IsTensorSpec(PyObject* o) { return IsTensorSpecHelper(o) == 1; }
949bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; }
950bool IsResourceVariable(PyObject* o) {
951 return IsResourceVariableHelper(o) == 1;
952}
953bool IsOwnedIterator(PyObject* o) { return IsOwnedIteratorHelper(o) == 1; }
954bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; }
955bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
956bool IsDispatchable(PyObject* o) { return IsDispatchableHelper(o) == 1; }
957
958bool IsTuple(PyObject* o) {
959 tensorflow::Safe_PyObjectPtr wrapped;
960 if (IsObjectProxy(o)) {
961 wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
962 o = wrapped.get();
963 }
964 return PyTuple_Check(o);
965}
966
967// Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
968// and while we're at it give them consistent behavior by making sure the
969// returned value is a list.
970//
971// As with PyMapping_Keys, returns a new reference.
972//
973// On failure, returns nullptr.
974PyObject* MappingKeys(PyObject* o) {
975#if PY_MAJOR_VERSION >= 3
976 return PyMapping_Keys(o);
977#else
978 static char key_method_name[] = "keys";
979 Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
980 if (PyErr_Occurred() || raw_result.get() == nullptr) {
981 return nullptr;
982 }
983 return PySequence_Fast(
984 raw_result.get(),
985 "The '.keys()' method of a custom mapping returned a non-sequence.");
986#endif
987}
988
989PyObject* Flatten(PyObject* nested, bool expand_composites) {
990 PyObject* list = PyList_New(0);
991 const std::function<int(PyObject*)>& is_nested_helper =
992 expand_composites ? IsNestedOrCompositeHelper : IsNestedHelper;
993 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
994 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
995 if (FlattenHelper(nested, list, is_nested_helper, get_value_iterator)) {
996 return list;
997 } else {
998 Py_DECREF(list);
999 return nullptr;
1000 }
1001}
1002
1003bool IsNestedOrComposite(PyObject* o) {
1004 return IsNestedOrCompositeHelper(o) == 1;
1005}
1006
1007bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; }
1008
1009bool IsTypeSpec(PyObject* o) { return IsTypeSpecHelper(o) == 1; }
1010
1011bool IsNestedForData(PyObject* o) { return IsNestedForDataHelper(o) == 1; }
1012
1013PyObject* FlattenForData(PyObject* nested) {
1014 PyObject* list = PyList_New(0);
1015 if (FlattenHelper(nested, list, IsNestedForDataHelper,
1016 GetValueIteratorForData)) {
1017 return list;
1018 } else {
1019 Py_DECREF(list);
1020 return nullptr;
1021 }
1022}
1023
1024PyObject* IsNamedtuple(PyObject* o, bool strict) {
1025 // Some low-level CPython calls do not work with wrapt.ObjectProxy, so they
1026 // require some unwrapping if we want to treat them like the objects they're
1027 // wrapping.
1028 tensorflow::Safe_PyObjectPtr o_wrapped;
1029 if (IsObjectProxy(o)) {
1030 o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
1031 o = o_wrapped.get();
1032 }
1033
1034 // Must be subclass of tuple
1035 if (!PyTuple_Check(o)) {
1036 Py_RETURN_FALSE;
1037 }
1038
1039 // If strict, o.__class__.__base__ must be tuple
1040 if (strict) {
1041 PyObject* klass = PyObject_GetAttrString(o, "__class__");
1042 if (klass == nullptr) return nullptr;
1043 PyObject* base = PyObject_GetAttrString(klass, "__base__");
1044 Py_DECREF(klass);
1045 if (base == nullptr) return nullptr;
1046
1047 const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
1048 // built-in object types are singletons
1049 bool tuple_base = base_type == &PyTuple_Type;
1050 Py_DECREF(base);
1051 if (!tuple_base) {
1052 Py_RETURN_FALSE;
1053 }
1054 }
1055
1056 // o must have attribute '_fields' and every element in
1057 // '_fields' must be a string.
1058 int has_fields = PyObject_HasAttrString(o, "_fields");
1059 if (!has_fields) {
1060 Py_RETURN_FALSE;
1061 }
1062
1063 Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
1064 int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence");
1065 if (is_instance == 0) {
1066 Py_RETURN_FALSE;
1067 } else if (is_instance == -1) {
1068 return nullptr;
1069 }
1070
1071 Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
1072 const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
1073 for (Py_ssize_t i = 0; i < s; ++i) {
1074 // PySequence_Fast_GET_ITEM returns borrowed ref
1075 PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
1076 if (!IsString(elem)) {
1077 Py_RETURN_FALSE;
1078 }
1079 }
1080
1081 Py_RETURN_TRUE;
1082}
1083
1084PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
1085 Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
1086 Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
1087 if (f1 == nullptr || f2 == nullptr) {
1088 PyErr_SetString(
1089 PyExc_RuntimeError,
1090 "Expected namedtuple-like objects (that have _fields attr)");
1091 return nullptr;
1092 }
1093
1094 if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
1095 Py_RETURN_FALSE;
1096 }
1097
1098 if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
1099 Py_RETURN_TRUE;
1100 } else {
1101 Py_RETURN_FALSE;
1102 }
1103}
1104
1105PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
1106 bool expand_composites) {
1107 const std::function<int(PyObject*)>& is_nested_helper =
1108 expand_composites ? IsNestedOrCompositeHelper : IsNestedHelper;
1109 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
1110 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
1111 const bool check_composite_tensor_type_spec = expand_composites;
1112 string error_msg;
1113 bool is_type_error = false;
1114 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1115 is_nested_helper, get_value_iterator,
1116 check_composite_tensor_type_spec);
1117 if (PyErr_Occurred()) {
1118 // Don't hide Python exceptions while checking (e.g. errors fetching keys
1119 // from custom mappings).
1120 return nullptr;
1121 }
1122 if (!error_msg.empty()) {
1123 PyErr_SetString(
1124 is_type_error ? PyExc_TypeError : PyExc_ValueError,
1125 tensorflow::strings::StrCat(
1126 "The two structures don't have the same nested structure.\n\n",
1127 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1128 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1129 .c_str());
1130 return nullptr;
1131 }
1132 Py_RETURN_NONE;
1133}
1134
1135PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
1136 bool check_types) {
1137 string error_msg;
1138 bool is_type_error = false;
1139 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1140 IsNestedForDataHelper, GetValueIterator, false);
1141 if (PyErr_Occurred()) {
1142 // Don't hide Python exceptions while checking (e.g. errors fetching keys
1143 // from custom mappings).
1144 return nullptr;
1145 }
1146 if (!error_msg.empty()) {
1147 PyErr_SetString(
1148 is_type_error ? PyExc_TypeError : PyExc_ValueError,
1149 tensorflow::strings::StrCat(
1150 "The two structures don't have the same nested structure.\n\n",
1151 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1152 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1153 .c_str());
1154 return nullptr;
1155 }
1156 Py_RETURN_NONE;
1157}
1158
1159} // namespace swig
1160} // namespace tensorflow
1161