1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/util/function_parameter_canonicalizer.h" |
17 | |
18 | #include "absl/container/flat_hash_set.h" |
19 | #include "tensorflow/core/platform/logging.h" |
20 | #include "tensorflow/core/platform/macros.h" |
21 | #include "tensorflow/python/lib/core/py_util.h" |
22 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
23 | |
24 | namespace { |
25 | inline const char* PyUnicodeAsUtf8Compat(PyObject* obj) { |
26 | #if PY_MAJOR_VERSION < 3 |
27 | return PyString_AS_STRING(obj); |
28 | #else |
29 | return PyUnicode_AsUTF8(obj); |
30 | #endif |
31 | } |
32 | |
33 | inline PyObject* PyUnicodeInternFromStringCompat(const char* str) { |
34 | #if PY_MAJOR_VERSION < 3 |
35 | return PyString_InternFromString(str); |
36 | #else |
37 | return PyUnicode_InternFromString(str); |
38 | #endif |
39 | } |
40 | |
41 | inline void PyUnicodeInternInPlaceCompat(PyObject** obj) { |
42 | #if PY_MAJOR_VERSION < 3 |
43 | PyString_InternInPlace(obj); |
44 | #else |
45 | PyUnicode_InternInPlace(obj); |
46 | #endif |
47 | } |
48 | |
49 | } // namespace |
50 | |
51 | namespace tensorflow { |
52 | |
53 | FunctionParameterCanonicalizer::FunctionParameterCanonicalizer( |
54 | absl::Span<const char*> arg_names, absl::Span<PyObject*> defaults) |
55 | : positional_args_size_(arg_names.size() - defaults.size()) { |
56 | DCheckPyGilState(); |
57 | DCHECK_GE(positional_args_size_, 0); |
58 | |
59 | interned_arg_names_.reserve(arg_names.size()); |
60 | for (const char* obj : arg_names) |
61 | interned_arg_names_.emplace_back(PyUnicodeInternFromStringCompat(obj)); |
62 | |
63 | DCHECK(AreInternedArgNamesUnique()); |
64 | |
65 | for (PyObject* obj : defaults) Py_INCREF(obj); |
66 | defaults_ = std::vector<Safe_PyObjectPtr>(defaults.begin(), defaults.end()); |
67 | } |
68 | |
69 | bool FunctionParameterCanonicalizer::Canonicalize( |
70 | PyObject* args, PyObject* kwargs, absl::Span<PyObject*> result) { |
71 | // TODO(kkb): Closely follow `Python/ceval.c`'s logic and error handling. |
72 | |
73 | DCheckPyGilState(); |
74 | DCHECK(PyTuple_CheckExact(args)); |
75 | DCHECK(kwargs == nullptr || PyDict_CheckExact(kwargs)); |
76 | DCHECK_EQ(result.size(), interned_arg_names_.size()); |
77 | |
78 | const int args_size = Py_SIZE(args); |
79 | int remaining_positional_args_count = positional_args_size_ - args_size; |
80 | |
81 | // Check if the number of input arguments are too many. |
82 | if (TF_PREDICT_FALSE(args_size > interned_arg_names_.size())) { |
83 | PyErr_SetString( |
84 | PyExc_TypeError, |
85 | absl::StrCat("Too many arguments were given. Expected " , |
86 | interned_arg_names_.size(), " but got " , args_size, "." ) |
87 | .c_str()); |
88 | return false; |
89 | } |
90 | |
91 | // Fill positional arguments. |
92 | for (int i = 0; i < args_size; ++i) result[i] = PyTuple_GET_ITEM(args, i); |
93 | |
94 | // Fill default arguments. |
95 | for (int i = std::max(positional_args_size_, args_size); |
96 | i < interned_arg_names_.size(); ++i) |
97 | result[i] = defaults_[i - positional_args_size_].get(); |
98 | |
99 | // Fill keyword arguments. |
100 | if (kwargs != nullptr) { |
101 | PyObject *key, *value; |
102 | Py_ssize_t pos = 0; |
103 | while (PyDict_Next(kwargs, &pos, &key, &value)) { |
104 | std::size_t index = InternedArgNameLinearSearch(key); |
105 | |
106 | // Check if key object(argument name) was found in the pre-built intern |
107 | // string table. |
108 | if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) { |
109 | // `key` might not be an interend string, so get the interned string |
110 | // and try again. Note: we need to call INCREF before we use |
111 | // InternInPlace, to prevent the key in the dictionary from being |
112 | // prematurely deleted in the case where InternInPlace switches `key` |
113 | // to point at a new object. We call DECREF(key) once we're done |
114 | // (which might decref the original key *or* the interned version). |
115 | Py_INCREF(key); |
116 | PyUnicodeInternInPlaceCompat(&key); |
117 | index = InternedArgNameLinearSearch(key); |
118 | Py_DECREF(key); |
119 | |
120 | // Stil not found, then return an error. |
121 | if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) { |
122 | PyErr_Format(PyExc_TypeError, |
123 | "Got an unexpected keyword argument '%s'" , |
124 | PyUnicodeAsUtf8Compat(key)); |
125 | return false; |
126 | } |
127 | } |
128 | |
129 | // Check if the keyword argument overlaps with positional arguments. |
130 | if (TF_PREDICT_FALSE(index < args_size)) { |
131 | PyErr_Format(PyExc_TypeError, "Got multiple values for argument '%s'" , |
132 | PyUnicodeAsUtf8Compat(key)); |
133 | return false; |
134 | } |
135 | |
136 | if (TF_PREDICT_FALSE(index < positional_args_size_)) |
137 | --remaining_positional_args_count; |
138 | |
139 | result[index] = value; |
140 | } |
141 | } |
142 | |
143 | // Check if all the arguments are filled. |
144 | // Example failure, not enough number of arguments passed: `matmul(x)` |
145 | if (TF_PREDICT_FALSE(remaining_positional_args_count > 0)) { |
146 | // TODO(kkb): Report what arguments are missing. |
147 | PyErr_SetString(PyExc_TypeError, "Missing required positional argument" ); |
148 | return false; |
149 | } |
150 | |
151 | return true; |
152 | } |
153 | |
154 | ABSL_MUST_USE_RESULT |
155 | ABSL_ATTRIBUTE_HOT |
156 | inline std::size_t FunctionParameterCanonicalizer::InternedArgNameLinearSearch( |
157 | PyObject* name) { |
158 | std::size_t result = interned_arg_names_.size(); |
159 | |
160 | for (std::size_t i = 0; i < interned_arg_names_.size(); ++i) |
161 | if (TF_PREDICT_FALSE(name == interned_arg_names_[i].get())) return i; |
162 | |
163 | return result; |
164 | } |
165 | |
166 | bool FunctionParameterCanonicalizer::AreInternedArgNamesUnique() { |
167 | absl::flat_hash_set<PyObject*> interned_arg_names_set; |
168 | for (const Safe_PyObjectPtr& obj : interned_arg_names_) |
169 | interned_arg_names_set.emplace(obj.get()); |
170 | |
171 | return interned_arg_names_set.size() == interned_arg_names_.size(); |
172 | } |
173 | } // namespace tensorflow |
174 | |