1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/python/util/nest.h"
16
17#include <utility>
18
19#include "tensorflow/core/lib/strings/strcat.h"
20#include "tensorflow/core/platform/logging.h"
21#include "tensorflow/core/platform/stringpiece.h"
22#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
23#include "tensorflow/python/util/util.h"
24
25namespace tensorflow {
26
27namespace {
28
29// Gets a string representation of the input object.
30//
31// Args:
32// o: a python object.
33// length: If set to negative, the whole string is returned. Otherwise, the
34// string gets clipped to 'length' in size.
35//
36// Returns:
37// A string representation.
38std::string PyObject_ToString(PyObject* o, int length = -1) {
39 auto str_o = make_safe(PyObject_Str(o));
40 std::string str = PyUnicode_AsUTF8(str_o.get());
41 if (length < 0 || str.size() <= length) {
42 return str;
43 }
44 tensorflow::StringPiece str_piece(str);
45 return tensorflow::strings::StrCat(str_piece.substr(length), "...");
46}
47
48// Gets a list of keys from a dict or mapping type object.
49//
50// Args:
51// o: a dictionary or mapping type object.
52//
53// Returns:
54// A new reference to a list.
55//
56// Raises:
57// TypeError: if `o` is not a dict or mapping type object.
58PyObject* GetKeysFromDictOrMapping(PyObject* o) {
59 if (PyDict_Check(o)) {
60 return PyDict_Keys(o);
61 } else if (PyMapping_Check(o)) {
62 return PyMapping_Keys(o);
63 } else {
64 auto* o_type = Py_TYPE(o);
65 PyErr_SetString(
66 PyExc_TypeError,
67 tensorflow::strings::StrCat(
68 "Expecting a type compatible with dict or mapping, got '",
69 o_type->tp_name, "'")
70 .c_str());
71 return nullptr;
72 }
73}
74
75} // namespace
76
77PyObject* FlattenDictItems(PyObject* dict) {
78 if (!PyDict_Check(dict) && !swig::IsMapping(dict)) {
79 PyErr_SetString(PyExc_TypeError,
80 tensorflow::strings::StrCat(
81 "FlattenDictItems: 'dict' must be a dictionary or ",
82 "collection.Mapping type object, instead of '",
83 Py_TYPE(dict)->tp_name, "'.")
84 .c_str());
85 return nullptr;
86 }
87 PyObject* flat_dictionary = PyDict_New();
88 auto keys = make_safe(GetKeysFromDictOrMapping(dict));
89 for (size_t i = 0; i < PyList_Size(keys.get()); ++i) {
90 auto* key = PyList_GetItem(keys.get(), i);
91 // We use a general approach in case 'dict' is a PyMapping type,
92 // but not a PyDict type.
93 auto* value = PyObject_GetItem(dict, key);
94 if (swig::IsNested(key)) {
95 // The dict might contain list - list pairs.
96 auto flat_keys = make_safe(swig::Flatten(key, false));
97 auto flat_values = make_safe(swig::Flatten(value, false));
98 size_t flat_keys_sz = PyList_Size(flat_keys.get());
99 size_t flat_values_sz = PyList_Size(flat_values.get());
100 if (flat_keys_sz != flat_values_sz) {
101 PyErr_SetString(
102 PyExc_ValueError,
103 tensorflow::strings::StrCat(
104 "Could not flatten dictionary. Key had ", flat_keys_sz,
105 " elements, but value had ", flat_values_sz,
106 " elements. Key: ", PyObject_ToString(flat_keys.get()),
107 ", value: ", PyObject_ToString(flat_values.get()), ".")
108 .c_str());
109 Py_DecRef(flat_dictionary);
110 return nullptr;
111 }
112 for (size_t i = 0; i < flat_keys_sz; ++i) {
113 auto* flat_key = PyList_GetItem(flat_keys.get(), i);
114 auto* flat_value = PyList_GetItem(flat_values.get(), i);
115 if (PyDict_GetItem(flat_dictionary, flat_key) != nullptr) {
116 PyErr_SetString(
117 PyExc_ValueError,
118 tensorflow::strings::StrCat(
119 "Cannot flatten dict because this key is not unique: ",
120 PyObject_ToString(flat_key))
121 .c_str());
122 Py_DecRef(flat_dictionary);
123 return nullptr;
124 }
125 PyDict_SetItem(flat_dictionary, flat_key, flat_value);
126 }
127 } else {
128 if (PyDict_GetItem(flat_dictionary, key) != nullptr) {
129 PyErr_SetString(
130 PyExc_ValueError,
131 tensorflow::strings::StrCat(
132 "Cannot flatten dict because this key is not unique: ",
133 PyObject_ToString(key))
134 .c_str());
135 Py_DecRef(flat_dictionary);
136 return nullptr;
137 }
138 PyDict_SetItem(flat_dictionary, key, value);
139 }
140 // Manually decrease because PyObject_GetItem() returns a new reference.
141 Py_DECREF(value);
142 }
143 return flat_dictionary;
144}
145
146} // namespace tensorflow
147