1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/python/framework/python_api_parameter_converter.h"
16
17#include "absl/strings/str_cat.h"
18#include "tensorflow/core/framework/op.h"
19#include "tensorflow/core/lib/gtl/map_util.h"
20#include "tensorflow/python/eager/pywrap_tensor.h"
21#include "tensorflow/python/framework/op_def_util.h"
22#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
23#include "tensorflow/python/util/util.h"
24
// Python-version compatibility shims: read a Python integer as a C long, and
// build an interned Python string from a C string.
#if PY_MAJOR_VERSION >= 3
// Python 3.x:
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#else
// Python 2.x:
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#endif
34
// Evaluates `condition` once; when it is false, the enclosing function
// returns false.  The do/while wrapper makes the macro usable as a single
// statement.
#define RETURN_IF_FALSE(condition) \
  do {                             \
    if (!(condition)) {            \
      return false;                \
    }                              \
  } while (false)
40
// Reinterprets `o` as a `PyListObject*` and yields its `ob_item` member.  No
// type check is performed on `o`.
#define PyList_ITEMS(o) (reinterpret_cast<PyListObject*>(o)->ob_item)
42
43namespace tensorflow {
44
45using InferredAttributes = PythonAPIInfo::InferredAttributes;
46using ParamIndex = PythonAPIInfo::ParamIndex;
47using Attribute = PythonAPIInfo::Attribute;
48using InputWithFixedDType = PythonAPIInfo::InputWithFixedDType;
49using InputsWithTypeAttr = PythonAPIInfo::InputsWithTypeAttr;
50using InputsWithTypeListAttr = PythonAPIInfo::InputsWithTypeListAttr;
51
52namespace {
53
54// Returns `dtype._type_enum`.
55Safe_PyObjectPtr GetAttr_TypeEnum(PyObject* dtype) {
56 static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
57 return Safe_PyObjectPtr(PyObject_GetAttr(dtype, attr));
58}
59
60// Returns `tensor.dtype`.
61Safe_PyObjectPtr GetAttr_DType(PyObject* tensor) {
62 static PyObject* attr = PY_STRING_INTERN_FROM_STRING("dtype");
63 return Safe_PyObjectPtr(PyObject_GetAttr(tensor, attr));
64}
65
66// Raises a TypeError with a message constructed by applying StrCat to the
67// specified strings. If an exception has already been set when this function
68// is called, then add its message as a suffix to the message string.
69template <typename... Args>
70void RaiseTypeError(Args... args) {
71 string message = absl::StrCat(args...);
72 if (!PyErr_Occurred()) {
73 PyErr_SetString(PyExc_TypeError, message.c_str());
74 } else {
75 PyObject* exc_type;
76 PyObject* exc_value;
77 PyObject* exc_traceback;
78 PyErr_Fetch(&exc_type, &exc_value, &exc_traceback);
79 PyErr_Format(PyExc_TypeError, "%s: %S", message.c_str(), exc_value);
80 Py_XDECREF(exc_type);
81 Py_XDECREF(exc_value);
82 Py_XDECREF(exc_traceback);
83 }
84}
85
86// Returns the DataType for a `tf.dtypes.DType` object (or DT_INVALID if it
87// is not a valid DType object).
88ABSL_MUST_USE_RESULT
89DataType DataTypeFromPyDType(PyObject* dtype) {
90 if (!dtype) {
91 return DT_INVALID;
92 }
93 Safe_PyObjectPtr enum_field = GetAttr_TypeEnum(dtype);
94 if (!enum_field) {
95 return DT_INVALID;
96 }
97 DataType result = static_cast<DataType>(PY_INT_AS_LONG(enum_field.get()));
98 return result;
99}
100
101// Update `dtype` with an inferred dtype from `value`. In particular, if
102// `dtype == DT_INVALID` and `value` is a `Tensor`, then set `dtype` to
103// `value.dtype`. (If `dtype` is not `DT_INVALID`, or `value` is not a
104// tensor, then do nothing.) Returns false on exception.
105ABSL_MUST_USE_RESULT
106bool InferDType(PyObject* value, DataType& dtype) {
107 if (dtype != DT_INVALID) return true; // Already have dtype.
108
109 if (EagerTensor_CheckExact(value)) {
110 dtype = PyEagerTensor_Dtype(value);
111 return true;
112 }
113
114 if (swig::IsTensor(value)) {
115 Safe_PyObjectPtr py_dtype = GetAttr_DType(value);
116 if (!py_dtype) return false;
117 dtype = DataTypeFromPyDType(py_dtype.get()); // set output parameter
118 return true;
119 }
120 return true;
121}
122
123// Returns true if `dtype` is in `ok_dtypes`, or `ok_dtypes` is null or empty.
124ABSL_MUST_USE_RESULT
125bool IsOkDType(DataType dtype, const std::vector<DataType>* ok_dtypes) {
126 return (ok_dtypes == nullptr || ok_dtypes->empty() ||
127 std::find(ok_dtypes->begin(), ok_dtypes->end(), dtype) !=
128 ok_dtypes->end());
129}
130
131// Formatter for DataTypes for absl::StrJoin.
132struct DataTypeFormatter {
133 void operator()(std::string* out, DataType dtype) const {
134 out->append(DataType_Name(dtype));
135 }
136};
137
138// Converts `src` to a tensor using `tensor_converter.Convert`. If `src` is
139// replaced by a new value then decref the replaced value. If an error
140// occurs, then re-raise it as a TypeError with a prefix indicating the API
141// name and the parameter name.
142//
143// Args:
144// src: The value that should be converted (in-place).
145// dtype: The dtype to convert `src` to, or DT_INVALID for unconstraned.
146// If DT_INVALID, then `dtype` will be set to the actual dtype of the
147// converted value.
148// tensor_converter: Class used to convert python values to tensors.
149// api_info: Information about the API we're converting this value for
150// (for error messages).
151// param_index: Index of the parameter we're converting (for error messages).
152// ok_dtypes: List of valid dtypes for conversion (optional).
153// default_dtype: Default dtype -- used if converting the value to a tensor
154// with unconstrained dtype returns a value not in ok_dtypes.
155ABSL_MUST_USE_RESULT
156bool ConvertToTensorInPlace(PyObject*& src, DataType& dtype,
157 const PythonTensorConverter& tensor_converter,
158 const PythonAPIInfo& api_info, int param_index,
159 const std::vector<DataType>* ok_dtypes = nullptr,
160 DataType default_dtype = DT_INVALID) {
161 bool inferred_dtype = (dtype == DT_INVALID);
162 Safe_PyObjectPtr converted = tensor_converter.Convert(src, dtype);
163 if (!converted) {
164 RaiseTypeError(api_info.api_name(), " argument ",
165 api_info.param_names()[param_index]);
166 return false;
167 }
168
169 if (inferred_dtype && !IsOkDType(dtype, ok_dtypes)) {
170 // Converting `src` to a tensor gave us a disallowed dtype; try again
171 // with `default_dtype`.
172 if (default_dtype == DT_INVALID) {
173 RaiseTypeError(api_info.api_name(), " argument ",
174 api_info.param_names()[param_index], ": Expected one of {",
175 absl::StrJoin(*ok_dtypes, ", ", DataTypeFormatter()),
176 "}, but got ", DataType_Name(dtype));
177 return false;
178 } else {
179 dtype = default_dtype;
180 converted = tensor_converter.Convert(src, dtype);
181 if (!converted) {
182 RaiseTypeError(api_info.api_name(), " argument ",
183 api_info.param_names()[param_index]);
184 return false;
185 }
186 }
187 }
188
189 Py_DECREF(src);
190 src = converted.release();
191 return true;
192}
193
194// Converts the specified attribute parameter to the expected type. Modifies
195// `params` in-place. Returns true on success, or sets an exception and
196// returns false on failure.
197ABSL_MUST_USE_RESULT
198bool ConvertAttribute(const Attribute& attr, const PythonAPIInfo& api_info,
199 absl::Span<PyObject*> params) {
200 if (attr.index == -1) return true; // Inferred attribute.
201 PyObject* src = params[attr.index];
202 Safe_PyObjectPtr converted = ConvertPyObjectToAttributeType(src, attr.type);
203 if (!converted) {
204 RaiseTypeError(api_info.api_name(), " argument ",
205 api_info.param_names()[attr.index]);
206 return false;
207 }
208 if (converted.get() != src) {
209 Py_DECREF(src);
210 params[attr.index] = converted.release();
211 }
212 return true;
213}
214
215// Converts the specified fixed-dtype input parameter to a Tensor with the
216// expected dtype. Modifies `params` in-place. Returns true on success, or
217// sets an exception and returns false on failure.
218ABSL_MUST_USE_RESULT
219bool ConvertInputWithFixedDType(const InputWithFixedDType& input,
220 const PythonTensorConverter& tensor_converter,
221 const PythonAPIInfo& api_info,
222 absl::Span<PyObject*> params) {
223 DataType dtype = input.dtype;
224 PyObject*& src = params[input.index];
225 if (!input.is_list) {
226 RETURN_IF_FALSE(ConvertToTensorInPlace(src, dtype, tensor_converter,
227 api_info, input.index));
228 } else {
229 DCHECK(PyList_CheckExact(src));
230 PyObject** items = PyList_ITEMS(src);
231 Py_ssize_t len = PyList_GET_SIZE(src);
232 for (Py_ssize_t i = 0; i < len; ++i) {
233 RETURN_IF_FALSE(ConvertToTensorInPlace(items[i], dtype, tensor_converter,
234 api_info, input.index));
235 }
236 }
237 return true;
238}
239
240// Infers a consistent dtype for the specified collection of homogeneous-dtype
241// input parameters, and converts those parameters to Tensors (or lists of
242// Tensors) with that dtype. Modifies `params` in-place, and updates
243// `inferred_attrs` with the inferred dtype (if it's not null). Returns true
244// on success, or sets an exception and returns false on failure.
245ABSL_MUST_USE_RESULT
246bool ConvertInputsWithTypeAttr(const InputsWithTypeAttr& input,
247 const PythonTensorConverter& tensor_converter,
248 const PythonAPIInfo& api_info,
249 absl::Span<PyObject*> params,
250 InferredAttributes* inferred_attrs) {
251 DataType dtype = DT_INVALID;
252 if (input.type_attr->index != -1) {
253 // explicit type attribute
254 PyObject* py_dtype = params[input.type_attr->index];
255 dtype = DataTypeFromPyDType(py_dtype);
256 } else {
257 // implicit type attribute: infer the dtype.
258 // First, check the single-tensor inputs.
259 for (ParamIndex index : input.tensor_params) {
260 RETURN_IF_FALSE(InferDType(params[index], dtype));
261 if (dtype != DT_INVALID) break;
262 }
263 // Next, check the list-of-tensor inputs.
264 if (dtype == DT_INVALID) {
265 for (ParamIndex index : input.tensor_list_params) {
266 PyObject* tensor_list = params[index];
267 DCHECK(PyList_CheckExact(tensor_list));
268 Py_ssize_t num_tensors = PyList_GET_SIZE(tensor_list);
269 PyObject** tensors = PyList_ITEMS(tensor_list);
270 for (Py_ssize_t i = 0; i < num_tensors; ++i) {
271 RETURN_IF_FALSE(InferDType(tensors[i], dtype));
272 if (dtype != DT_INVALID) break;
273 }
274 if (dtype != DT_INVALID) break;
275 }
276 }
277 }
278
279 // Convert the single-tensor inputs to tensors.
280 for (ParamIndex index : input.tensor_params) {
281 RETURN_IF_FALSE(
282 ConvertToTensorInPlace(params[index], dtype, tensor_converter, api_info,
283 index, &input.ok_dtypes, input.default_dtype));
284 }
285
286 // Convert the list-of-tensor inputs to tensors.
287 for (ParamIndex index : input.tensor_list_params) {
288 PyObject* tensor_list = params[index];
289 DCHECK(PyList_CheckExact(tensor_list));
290 Py_ssize_t num_tensors = PyList_GET_SIZE(tensor_list);
291 PyObject** items = PyList_ITEMS(tensor_list);
292 for (Py_ssize_t i = 0; i < num_tensors; ++i) {
293 RETURN_IF_FALSE(ConvertToTensorInPlace(items[i], dtype, tensor_converter,
294 api_info, index, &input.ok_dtypes,
295 input.default_dtype));
296 }
297 }
298
299 if (inferred_attrs) {
300 if (dtype == DT_INVALID) {
301 dtype = input.default_dtype;
302 }
303 // TODO(b/164980194) Should we raise an exception here if we didn't manage
304 // to infer a dtype? (I.e., if there were no single-tensor inputs and all
305 // list-of-tensor inputs were empty, and there's no default dtype.)
306 int inferred_index = input.type_attr->inferred_index;
307 if (inferred_index != -1) {
308 inferred_attrs->types[inferred_index] = dtype;
309 }
310 }
311
312 return true;
313}
314
315// Infers a consistent list of dtypes for the specified collection of
316// heterogeneous-dtype input parameters, and converts those parameters to lists
317// of Tensors with those dtypes. Modifies `params` in-place, and updates
318// `inferred_attrs` with the inferred dtypes (if it's not null). Returns true
319// on success, or sets an exception and returns false on failure.
320ABSL_MUST_USE_RESULT
321bool ConvertInputsWithTypeListAttr(
322 const InputsWithTypeListAttr& input,
323 const PythonTensorConverter& tensor_converter,
324 const PythonAPIInfo& api_info, absl::Span<PyObject*> params,
325 InferredAttributes* inferred_attrs) {
326 DCHECK(!input.tensor_list_params.empty());
327
328 // Get the number of tensors from the first input list; and check that the
329 // remaining lists have the same size.
330 DCHECK(PyList_CheckExact(params[input.tensor_list_params[0]]));
331 Py_ssize_t num_tensors = PyList_GET_SIZE(params[input.tensor_list_params[0]]);
332 for (int i = 1; i < input.tensor_list_params.size(); ++i) {
333 DCHECK(PyList_CheckExact(params[input.tensor_list_params[i]]));
334 if (num_tensors != PyList_GET_SIZE(params[input.tensor_list_params[i]])) {
335 RaiseTypeError(api_info.api_name(), " expected parameters ",
336 api_info.param_names()[0], " and ",
337 api_info.param_names()[i],
338 " to be lists of the same length.");
339 return false;
340 }
341 }
342
343 // Get the list of dtypes.
344 std::vector<DataType> dtypes(num_tensors, DT_INVALID);
345 if (input.type_list_attr->index != -1) {
346 // Dtypes are specified by an explicit attribute.
347 PyObject* py_dtypes = params[input.type_list_attr->index];
348 if (PyList_GET_SIZE(py_dtypes) != num_tensors) {
349 RaiseTypeError(api_info.api_name(), " expected parameters ",
350 api_info.param_names()[0], " and ",
351 api_info.param_names()[input.type_list_attr->index],
352 "to be lists of the same length.");
353 return false;
354 }
355 for (Py_ssize_t i = 0; i < PyList_GET_SIZE(py_dtypes); ++i) {
356 dtypes[i] = DataTypeFromPyDType(PyList_GetItem(py_dtypes, i));
357 }
358 } else {
359 // Dtypes are implicit: infer them.
360 for (Py_ssize_t i = 0; i < num_tensors; ++i) {
361 for (ParamIndex index : input.tensor_list_params) {
362 PyObject* tensor_list = params[index];
363 DCHECK(PyList_CheckExact(tensor_list));
364 PyObject* item = PyList_GET_ITEM(tensor_list, i);
365 RETURN_IF_FALSE(InferDType(item, dtypes[i]));
366 if (dtypes[i] != DT_INVALID) break;
367 }
368 }
369 }
370
371 // Convert tensors.
372 for (ParamIndex index : input.tensor_list_params) {
373 PyObject* tensor_list = params[index];
374 PyObject** items = PyList_ITEMS(tensor_list);
375 for (Py_ssize_t i = 0; i < num_tensors; ++i) {
376 DataType default_dtype = i < input.default_dtypes.size()
377 ? input.default_dtypes[i]
378 : DT_INVALID;
379 RETURN_IF_FALSE(ConvertToTensorInPlace(items[i], dtypes[i],
380 tensor_converter, api_info, index,
381 &input.ok_dtypes, default_dtype));
382 }
383 }
384
385 if (inferred_attrs) {
386 int inferred_index = input.type_list_attr->inferred_index;
387 if (inferred_index != -1) {
388 inferred_attrs->type_lists[inferred_index].swap(dtypes);
389 }
390 }
391
392 return true;
393}
394
395// Infers length attributes for Tensor-list parameters from their values, and
396// updates `inferred_length_attrs` with the inferred length. Sets an exception
397// if multiple Tensor-list parameters have the same length attribute but
398// different lengths. Returns true on success, or sets an exception and returns
399// false on failure.
400ABSL_MUST_USE_RESULT
401bool InferLengthAttributes(const absl::Span<PyObject*> params,
402 const PythonAPIInfo& api_info,
403 std::vector<int64_t>& inferred_length_attrs) {
404 for (int i = 0; i < api_info.inputs_with_number_attrs().size(); ++i) {
405 const auto& inputs = api_info.inputs_with_number_attrs()[i];
406 DCHECK(!inputs.tensor_list_params.empty());
407
408 // Use the first tensor_list parameter to infer the length attribute.
409 PyObject* tensors = params[inputs.tensor_list_params[0]];
410 DCHECK(PyList_CheckExact(tensors));
411 int inferred_length = PyList_GET_SIZE(tensors);
412
413 // Check that any other tensor_list parameters have matching length.
414 for (int j = 1; j < inputs.tensor_list_params.size(); ++j) {
415 int num_tensors = PyList_GET_SIZE(params[inputs.tensor_list_params[j]]);
416 if (num_tensors != inferred_length) {
417 RaiseTypeError(api_info.api_name(), " expected parameters ",
418 api_info.param_names()[inputs.tensor_list_params[0]],
419 " and ",
420 api_info.param_names()[inputs.tensor_list_params[j]],
421 " to be lists with the same length.");
422 }
423 }
424
425 int inferred_index = inputs.number_attr->inferred_index;
426 if (inferred_index != -1) {
427 inferred_length_attrs[inferred_index] = inferred_length;
428 }
429 }
430 return true;
431}
432
433} // namespace
434
435bool ConvertPythonAPIParameters(const PythonAPIInfo& api_info,
436 const PythonTensorConverter& tensor_converter,
437 absl::Span<PyObject*> params,
438 InferredAttributes* inferred_attrs) {
439 // Make room for inferred attributes.
440 if (inferred_attrs) {
441 inferred_attrs->types.resize(api_info.inferred_type_attrs().size());
442 inferred_attrs->type_lists.resize(
443 api_info.inferred_type_list_attrs().size());
444 inferred_attrs->lengths.resize(api_info.inferred_length_attrs().size());
445 }
446
447 for (const auto& attr : api_info.attributes()) {
448 RETURN_IF_FALSE(ConvertAttribute(attr, api_info, params));
449 }
450
451 for (const auto& input : api_info.inputs_with_fixed_dtype()) {
452 RETURN_IF_FALSE(
453 ConvertInputWithFixedDType(input, tensor_converter, api_info, params));
454 }
455
456 for (int i = 0; i < api_info.inputs_with_type_attrs().size(); ++i) {
457 RETURN_IF_FALSE(ConvertInputsWithTypeAttr(
458 api_info.inputs_with_type_attrs()[i], tensor_converter, api_info,
459 params, inferred_attrs));
460 }
461
462 for (int i = 0; i < api_info.inputs_with_type_list_attrs().size(); ++i) {
463 RETURN_IF_FALSE(ConvertInputsWithTypeListAttr(
464 api_info.inputs_with_type_list_attrs()[i], tensor_converter, api_info,
465 params, inferred_attrs));
466 }
467
468 if (inferred_attrs) {
469 RETURN_IF_FALSE(
470 InferLengthAttributes(params, api_info, inferred_attrs->lengths));
471 }
472
473 return true;
474}
475
476bool CopyPythonAPITensorLists(const PythonAPIInfo& api_info,
477 absl::Span<PyObject*> params) {
478 for (const auto& input : api_info.inputs()) {
479 if (input.is_list) {
480 PyObject* src = params[input.index];
481 PyObject* copy = PySequence_List(src);
482 if (!copy) {
483 RaiseTypeError(api_info.api_name(), " expected a list of Tensors for '",
484 api_info.param_names()[input.index], "'; got ",
485 src->ob_type->tp_name, ".");
486 return false;
487 }
488 Py_DECREF(params[input.index]);
489 params[input.index] = copy;
490 }
491 }
492 return true;
493}
494
495} // namespace tensorflow
496