1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/python/util/function_parameter_canonicalizer.h"
17
18#include "absl/container/flat_hash_set.h"
19#include "tensorflow/core/platform/logging.h"
20#include "tensorflow/core/platform/macros.h"
21#include "tensorflow/python/lib/core/py_util.h"
22#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
23
24namespace {
25inline const char* PyUnicodeAsUtf8Compat(PyObject* obj) {
26#if PY_MAJOR_VERSION < 3
27 return PyString_AS_STRING(obj);
28#else
29 return PyUnicode_AsUTF8(obj);
30#endif
31}
32
33inline PyObject* PyUnicodeInternFromStringCompat(const char* str) {
34#if PY_MAJOR_VERSION < 3
35 return PyString_InternFromString(str);
36#else
37 return PyUnicode_InternFromString(str);
38#endif
39}
40
41inline void PyUnicodeInternInPlaceCompat(PyObject** obj) {
42#if PY_MAJOR_VERSION < 3
43 PyString_InternInPlace(obj);
44#else
45 PyUnicode_InternInPlace(obj);
46#endif
47}
48
49} // namespace
50
51namespace tensorflow {
52
53FunctionParameterCanonicalizer::FunctionParameterCanonicalizer(
54 absl::Span<const char*> arg_names, absl::Span<PyObject*> defaults)
55 : positional_args_size_(arg_names.size() - defaults.size()) {
56 DCheckPyGilState();
57 DCHECK_GE(positional_args_size_, 0);
58
59 interned_arg_names_.reserve(arg_names.size());
60 for (const char* obj : arg_names)
61 interned_arg_names_.emplace_back(PyUnicodeInternFromStringCompat(obj));
62
63 DCHECK(AreInternedArgNamesUnique());
64
65 for (PyObject* obj : defaults) Py_INCREF(obj);
66 defaults_ = std::vector<Safe_PyObjectPtr>(defaults.begin(), defaults.end());
67}
68
69bool FunctionParameterCanonicalizer::Canonicalize(
70 PyObject* args, PyObject* kwargs, absl::Span<PyObject*> result) {
71 // TODO(kkb): Closely follow `Python/ceval.c`'s logic and error handling.
72
73 DCheckPyGilState();
74 DCHECK(PyTuple_CheckExact(args));
75 DCHECK(kwargs == nullptr || PyDict_CheckExact(kwargs));
76 DCHECK_EQ(result.size(), interned_arg_names_.size());
77
78 const int args_size = Py_SIZE(args);
79 int remaining_positional_args_count = positional_args_size_ - args_size;
80
81 // Check if the number of input arguments are too many.
82 if (TF_PREDICT_FALSE(args_size > interned_arg_names_.size())) {
83 PyErr_SetString(
84 PyExc_TypeError,
85 absl::StrCat("Too many arguments were given. Expected ",
86 interned_arg_names_.size(), " but got ", args_size, ".")
87 .c_str());
88 return false;
89 }
90
91 // Fill positional arguments.
92 for (int i = 0; i < args_size; ++i) result[i] = PyTuple_GET_ITEM(args, i);
93
94 // Fill default arguments.
95 for (int i = std::max(positional_args_size_, args_size);
96 i < interned_arg_names_.size(); ++i)
97 result[i] = defaults_[i - positional_args_size_].get();
98
99 // Fill keyword arguments.
100 if (kwargs != nullptr) {
101 PyObject *key, *value;
102 Py_ssize_t pos = 0;
103 while (PyDict_Next(kwargs, &pos, &key, &value)) {
104 std::size_t index = InternedArgNameLinearSearch(key);
105
106 // Check if key object(argument name) was found in the pre-built intern
107 // string table.
108 if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
109 // `key` might not be an interend string, so get the interned string
110 // and try again. Note: we need to call INCREF before we use
111 // InternInPlace, to prevent the key in the dictionary from being
112 // prematurely deleted in the case where InternInPlace switches `key`
113 // to point at a new object. We call DECREF(key) once we're done
114 // (which might decref the original key *or* the interned version).
115 Py_INCREF(key);
116 PyUnicodeInternInPlaceCompat(&key);
117 index = InternedArgNameLinearSearch(key);
118 Py_DECREF(key);
119
120 // Stil not found, then return an error.
121 if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
122 PyErr_Format(PyExc_TypeError,
123 "Got an unexpected keyword argument '%s'",
124 PyUnicodeAsUtf8Compat(key));
125 return false;
126 }
127 }
128
129 // Check if the keyword argument overlaps with positional arguments.
130 if (TF_PREDICT_FALSE(index < args_size)) {
131 PyErr_Format(PyExc_TypeError, "Got multiple values for argument '%s'",
132 PyUnicodeAsUtf8Compat(key));
133 return false;
134 }
135
136 if (TF_PREDICT_FALSE(index < positional_args_size_))
137 --remaining_positional_args_count;
138
139 result[index] = value;
140 }
141 }
142
143 // Check if all the arguments are filled.
144 // Example failure, not enough number of arguments passed: `matmul(x)`
145 if (TF_PREDICT_FALSE(remaining_positional_args_count > 0)) {
146 // TODO(kkb): Report what arguments are missing.
147 PyErr_SetString(PyExc_TypeError, "Missing required positional argument");
148 return false;
149 }
150
151 return true;
152}
153
154ABSL_MUST_USE_RESULT
155ABSL_ATTRIBUTE_HOT
156inline std::size_t FunctionParameterCanonicalizer::InternedArgNameLinearSearch(
157 PyObject* name) {
158 std::size_t result = interned_arg_names_.size();
159
160 for (std::size_t i = 0; i < interned_arg_names_.size(); ++i)
161 if (TF_PREDICT_FALSE(name == interned_arg_names_[i].get())) return i;
162
163 return result;
164}
165
166bool FunctionParameterCanonicalizer::AreInternedArgNamesUnique() {
167 absl::flat_hash_set<PyObject*> interned_arg_names_set;
168 for (const Safe_PyObjectPtr& obj : interned_arg_names_)
169 interned_arg_names_set.emplace(obj.get());
170
171 return interned_arg_names_set.size() == interned_arg_names_.size();
172}
173} // namespace tensorflow
174