1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/python/framework/python_api_parameter_converter.h" |
16 | |
17 | #include "absl/strings/str_cat.h" |
18 | #include "tensorflow/core/framework/op.h" |
19 | #include "tensorflow/core/lib/gtl/map_util.h" |
20 | #include "tensorflow/python/eager/pywrap_tensor.h" |
21 | #include "tensorflow/python/framework/op_def_util.h" |
22 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
23 | #include "tensorflow/python/util/util.h" |
24 | |
// Compatibility shims: pick the Python 2 or Python 3 spelling of the C-API
// calls for integer conversion and string interning, based on
// PY_MAJOR_VERSION.
#if PY_MAJOR_VERSION < 3
// Python 2.x:
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#else
// Python 3.x:
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#endif

// Evaluate `condition`, and if it returns false then return false.
// The `do { ... } while (0)` wrapper makes the expansion a single statement,
// so the macro can be used safely inside unbraced `if`/`else` bodies.
#define RETURN_IF_FALSE(condition) \
  do { \
    if (!(condition)) return false; \
  } while (0)

// Expands to the `ob_item` array of `o` after casting `o` to `PyListObject*`.
// The macro itself performs no type check, so callers must only pass lists.
#define PyList_ITEMS(o) (((PyListObject*)(o))->ob_item)
42 | |
43 | namespace tensorflow { |
44 | |
// Short aliases for the nested PythonAPIInfo types that the converters in this
// file refer to.
using InferredAttributes = PythonAPIInfo::InferredAttributes;
using ParamIndex = PythonAPIInfo::ParamIndex;
using Attribute = PythonAPIInfo::Attribute;
using InputWithFixedDType = PythonAPIInfo::InputWithFixedDType;
using InputsWithTypeAttr = PythonAPIInfo::InputsWithTypeAttr;
using InputsWithTypeListAttr = PythonAPIInfo::InputsWithTypeListAttr;
51 | |
52 | namespace { |
53 | |
54 | // Returns `dtype._type_enum`. |
55 | Safe_PyObjectPtr GetAttr_TypeEnum(PyObject* dtype) { |
56 | static PyObject* attr = PY_STRING_INTERN_FROM_STRING("_type_enum" ); |
57 | return Safe_PyObjectPtr(PyObject_GetAttr(dtype, attr)); |
58 | } |
59 | |
60 | // Returns `tensor.dtype`. |
61 | Safe_PyObjectPtr GetAttr_DType(PyObject* tensor) { |
62 | static PyObject* attr = PY_STRING_INTERN_FROM_STRING("dtype" ); |
63 | return Safe_PyObjectPtr(PyObject_GetAttr(tensor, attr)); |
64 | } |
65 | |
66 | // Raises a TypeError with a message constructed by applying StrCat to the |
67 | // specified strings. If an exception has already been set when this function |
68 | // is called, then add its message as a suffix to the message string. |
69 | template <typename... Args> |
70 | void RaiseTypeError(Args... args) { |
71 | string message = absl::StrCat(args...); |
72 | if (!PyErr_Occurred()) { |
73 | PyErr_SetString(PyExc_TypeError, message.c_str()); |
74 | } else { |
75 | PyObject* exc_type; |
76 | PyObject* exc_value; |
77 | PyObject* exc_traceback; |
78 | PyErr_Fetch(&exc_type, &exc_value, &exc_traceback); |
79 | PyErr_Format(PyExc_TypeError, "%s: %S" , message.c_str(), exc_value); |
80 | Py_XDECREF(exc_type); |
81 | Py_XDECREF(exc_value); |
82 | Py_XDECREF(exc_traceback); |
83 | } |
84 | } |
85 | |
86 | // Returns the DataType for a `tf.dtypes.DType` object (or DT_INVALID if it |
87 | // is not a valid DType object). |
88 | ABSL_MUST_USE_RESULT |
89 | DataType DataTypeFromPyDType(PyObject* dtype) { |
90 | if (!dtype) { |
91 | return DT_INVALID; |
92 | } |
93 | Safe_PyObjectPtr enum_field = GetAttr_TypeEnum(dtype); |
94 | if (!enum_field) { |
95 | return DT_INVALID; |
96 | } |
97 | DataType result = static_cast<DataType>(PY_INT_AS_LONG(enum_field.get())); |
98 | return result; |
99 | } |
100 | |
101 | // Update `dtype` with an inferred dtype from `value`. In particular, if |
102 | // `dtype == DT_INVALID` and `value` is a `Tensor`, then set `dtype` to |
103 | // `value.dtype`. (If `dtype` is not `DT_INVALID`, or `value` is not a |
104 | // tensor, then do nothing.) Returns false on exception. |
105 | ABSL_MUST_USE_RESULT |
106 | bool InferDType(PyObject* value, DataType& dtype) { |
107 | if (dtype != DT_INVALID) return true; // Already have dtype. |
108 | |
109 | if (EagerTensor_CheckExact(value)) { |
110 | dtype = PyEagerTensor_Dtype(value); |
111 | return true; |
112 | } |
113 | |
114 | if (swig::IsTensor(value)) { |
115 | Safe_PyObjectPtr py_dtype = GetAttr_DType(value); |
116 | if (!py_dtype) return false; |
117 | dtype = DataTypeFromPyDType(py_dtype.get()); // set output parameter |
118 | return true; |
119 | } |
120 | return true; |
121 | } |
122 | |
123 | // Returns true if `dtype` is in `ok_dtypes`, or `ok_dtypes` is null or empty. |
124 | ABSL_MUST_USE_RESULT |
125 | bool IsOkDType(DataType dtype, const std::vector<DataType>* ok_dtypes) { |
126 | return (ok_dtypes == nullptr || ok_dtypes->empty() || |
127 | std::find(ok_dtypes->begin(), ok_dtypes->end(), dtype) != |
128 | ok_dtypes->end()); |
129 | } |
130 | |
131 | // Formatter for DataTypes for absl::StrJoin. |
132 | struct DataTypeFormatter { |
133 | void operator()(std::string* out, DataType dtype) const { |
134 | out->append(DataType_Name(dtype)); |
135 | } |
136 | }; |
137 | |
138 | // Converts `src` to a tensor using `tensor_converter.Convert`. If `src` is |
139 | // replaced by a new value then decref the replaced value. If an error |
140 | // occurs, then re-raise it as a TypeError with a prefix indicating the API |
141 | // name and the parameter name. |
142 | // |
143 | // Args: |
144 | // src: The value that should be converted (in-place). |
145 | // dtype: The dtype to convert `src` to, or DT_INVALID for unconstraned. |
146 | // If DT_INVALID, then `dtype` will be set to the actual dtype of the |
147 | // converted value. |
148 | // tensor_converter: Class used to convert python values to tensors. |
149 | // api_info: Information about the API we're converting this value for |
150 | // (for error messages). |
151 | // param_index: Index of the parameter we're converting (for error messages). |
152 | // ok_dtypes: List of valid dtypes for conversion (optional). |
153 | // default_dtype: Default dtype -- used if converting the value to a tensor |
154 | // with unconstrained dtype returns a value not in ok_dtypes. |
155 | ABSL_MUST_USE_RESULT |
156 | bool ConvertToTensorInPlace(PyObject*& src, DataType& dtype, |
157 | const PythonTensorConverter& tensor_converter, |
158 | const PythonAPIInfo& api_info, int param_index, |
159 | const std::vector<DataType>* ok_dtypes = nullptr, |
160 | DataType default_dtype = DT_INVALID) { |
161 | bool inferred_dtype = (dtype == DT_INVALID); |
162 | Safe_PyObjectPtr converted = tensor_converter.Convert(src, dtype); |
163 | if (!converted) { |
164 | RaiseTypeError(api_info.api_name(), " argument " , |
165 | api_info.param_names()[param_index]); |
166 | return false; |
167 | } |
168 | |
169 | if (inferred_dtype && !IsOkDType(dtype, ok_dtypes)) { |
170 | // Converting `src` to a tensor gave us a disallowed dtype; try again |
171 | // with `default_dtype`. |
172 | if (default_dtype == DT_INVALID) { |
173 | RaiseTypeError(api_info.api_name(), " argument " , |
174 | api_info.param_names()[param_index], ": Expected one of {" , |
175 | absl::StrJoin(*ok_dtypes, ", " , DataTypeFormatter()), |
176 | "}, but got " , DataType_Name(dtype)); |
177 | return false; |
178 | } else { |
179 | dtype = default_dtype; |
180 | converted = tensor_converter.Convert(src, dtype); |
181 | if (!converted) { |
182 | RaiseTypeError(api_info.api_name(), " argument " , |
183 | api_info.param_names()[param_index]); |
184 | return false; |
185 | } |
186 | } |
187 | } |
188 | |
189 | Py_DECREF(src); |
190 | src = converted.release(); |
191 | return true; |
192 | } |
193 | |
194 | // Converts the specified attribute parameter to the expected type. Modifies |
195 | // `params` in-place. Returns true on success, or sets an exception and |
196 | // returns false on failure. |
197 | ABSL_MUST_USE_RESULT |
198 | bool ConvertAttribute(const Attribute& attr, const PythonAPIInfo& api_info, |
199 | absl::Span<PyObject*> params) { |
200 | if (attr.index == -1) return true; // Inferred attribute. |
201 | PyObject* src = params[attr.index]; |
202 | Safe_PyObjectPtr converted = ConvertPyObjectToAttributeType(src, attr.type); |
203 | if (!converted) { |
204 | RaiseTypeError(api_info.api_name(), " argument " , |
205 | api_info.param_names()[attr.index]); |
206 | return false; |
207 | } |
208 | if (converted.get() != src) { |
209 | Py_DECREF(src); |
210 | params[attr.index] = converted.release(); |
211 | } |
212 | return true; |
213 | } |
214 | |
215 | // Converts the specified fixed-dtype input parameter to a Tensor with the |
216 | // expected dtype. Modifies `params` in-place. Returns true on success, or |
217 | // sets an exception and returns false on failure. |
218 | ABSL_MUST_USE_RESULT |
219 | bool ConvertInputWithFixedDType(const InputWithFixedDType& input, |
220 | const PythonTensorConverter& tensor_converter, |
221 | const PythonAPIInfo& api_info, |
222 | absl::Span<PyObject*> params) { |
223 | DataType dtype = input.dtype; |
224 | PyObject*& src = params[input.index]; |
225 | if (!input.is_list) { |
226 | RETURN_IF_FALSE(ConvertToTensorInPlace(src, dtype, tensor_converter, |
227 | api_info, input.index)); |
228 | } else { |
229 | DCHECK(PyList_CheckExact(src)); |
230 | PyObject** items = PyList_ITEMS(src); |
231 | Py_ssize_t len = PyList_GET_SIZE(src); |
232 | for (Py_ssize_t i = 0; i < len; ++i) { |
233 | RETURN_IF_FALSE(ConvertToTensorInPlace(items[i], dtype, tensor_converter, |
234 | api_info, input.index)); |
235 | } |
236 | } |
237 | return true; |
238 | } |
239 | |
240 | // Infers a consistent dtype for the specified collection of homogeneous-dtype |
241 | // input parameters, and converts those parameters to Tensors (or lists of |
242 | // Tensors) with that dtype. Modifies `params` in-place, and updates |
243 | // `inferred_attrs` with the inferred dtype (if it's not null). Returns true |
244 | // on success, or sets an exception and returns false on failure. |
245 | ABSL_MUST_USE_RESULT |
246 | bool ConvertInputsWithTypeAttr(const InputsWithTypeAttr& input, |
247 | const PythonTensorConverter& tensor_converter, |
248 | const PythonAPIInfo& api_info, |
249 | absl::Span<PyObject*> params, |
250 | InferredAttributes* inferred_attrs) { |
251 | DataType dtype = DT_INVALID; |
252 | if (input.type_attr->index != -1) { |
253 | // explicit type attribute |
254 | PyObject* py_dtype = params[input.type_attr->index]; |
255 | dtype = DataTypeFromPyDType(py_dtype); |
256 | } else { |
257 | // implicit type attribute: infer the dtype. |
258 | // First, check the single-tensor inputs. |
259 | for (ParamIndex index : input.tensor_params) { |
260 | RETURN_IF_FALSE(InferDType(params[index], dtype)); |
261 | if (dtype != DT_INVALID) break; |
262 | } |
263 | // Next, check the list-of-tensor inputs. |
264 | if (dtype == DT_INVALID) { |
265 | for (ParamIndex index : input.tensor_list_params) { |
266 | PyObject* tensor_list = params[index]; |
267 | DCHECK(PyList_CheckExact(tensor_list)); |
268 | Py_ssize_t num_tensors = PyList_GET_SIZE(tensor_list); |
269 | PyObject** tensors = PyList_ITEMS(tensor_list); |
270 | for (Py_ssize_t i = 0; i < num_tensors; ++i) { |
271 | RETURN_IF_FALSE(InferDType(tensors[i], dtype)); |
272 | if (dtype != DT_INVALID) break; |
273 | } |
274 | if (dtype != DT_INVALID) break; |
275 | } |
276 | } |
277 | } |
278 | |
279 | // Convert the single-tensor inputs to tensors. |
280 | for (ParamIndex index : input.tensor_params) { |
281 | RETURN_IF_FALSE( |
282 | ConvertToTensorInPlace(params[index], dtype, tensor_converter, api_info, |
283 | index, &input.ok_dtypes, input.default_dtype)); |
284 | } |
285 | |
286 | // Convert the list-of-tensor inputs to tensors. |
287 | for (ParamIndex index : input.tensor_list_params) { |
288 | PyObject* tensor_list = params[index]; |
289 | DCHECK(PyList_CheckExact(tensor_list)); |
290 | Py_ssize_t num_tensors = PyList_GET_SIZE(tensor_list); |
291 | PyObject** items = PyList_ITEMS(tensor_list); |
292 | for (Py_ssize_t i = 0; i < num_tensors; ++i) { |
293 | RETURN_IF_FALSE(ConvertToTensorInPlace(items[i], dtype, tensor_converter, |
294 | api_info, index, &input.ok_dtypes, |
295 | input.default_dtype)); |
296 | } |
297 | } |
298 | |
299 | if (inferred_attrs) { |
300 | if (dtype == DT_INVALID) { |
301 | dtype = input.default_dtype; |
302 | } |
303 | // TODO(b/164980194) Should we raise an exception here if we didn't manage |
304 | // to infer a dtype? (I.e., if there were no single-tensor inputs and all |
305 | // list-of-tensor inputs were empty, and there's no default dtype.) |
306 | int inferred_index = input.type_attr->inferred_index; |
307 | if (inferred_index != -1) { |
308 | inferred_attrs->types[inferred_index] = dtype; |
309 | } |
310 | } |
311 | |
312 | return true; |
313 | } |
314 | |
315 | // Infers a consistent list of dtypes for the specified collection of |
316 | // heterogeneous-dtype input parameters, and converts those parameters to lists |
317 | // of Tensors with those dtypes. Modifies `params` in-place, and updates |
318 | // `inferred_attrs` with the inferred dtypes (if it's not null). Returns true |
319 | // on success, or sets an exception and returns false on failure. |
320 | ABSL_MUST_USE_RESULT |
321 | bool ConvertInputsWithTypeListAttr( |
322 | const InputsWithTypeListAttr& input, |
323 | const PythonTensorConverter& tensor_converter, |
324 | const PythonAPIInfo& api_info, absl::Span<PyObject*> params, |
325 | InferredAttributes* inferred_attrs) { |
326 | DCHECK(!input.tensor_list_params.empty()); |
327 | |
328 | // Get the number of tensors from the first input list; and check that the |
329 | // remaining lists have the same size. |
330 | DCHECK(PyList_CheckExact(params[input.tensor_list_params[0]])); |
331 | Py_ssize_t num_tensors = PyList_GET_SIZE(params[input.tensor_list_params[0]]); |
332 | for (int i = 1; i < input.tensor_list_params.size(); ++i) { |
333 | DCHECK(PyList_CheckExact(params[input.tensor_list_params[i]])); |
334 | if (num_tensors != PyList_GET_SIZE(params[input.tensor_list_params[i]])) { |
335 | RaiseTypeError(api_info.api_name(), " expected parameters " , |
336 | api_info.param_names()[0], " and " , |
337 | api_info.param_names()[i], |
338 | " to be lists of the same length." ); |
339 | return false; |
340 | } |
341 | } |
342 | |
343 | // Get the list of dtypes. |
344 | std::vector<DataType> dtypes(num_tensors, DT_INVALID); |
345 | if (input.type_list_attr->index != -1) { |
346 | // Dtypes are specified by an explicit attribute. |
347 | PyObject* py_dtypes = params[input.type_list_attr->index]; |
348 | if (PyList_GET_SIZE(py_dtypes) != num_tensors) { |
349 | RaiseTypeError(api_info.api_name(), " expected parameters " , |
350 | api_info.param_names()[0], " and " , |
351 | api_info.param_names()[input.type_list_attr->index], |
352 | "to be lists of the same length." ); |
353 | return false; |
354 | } |
355 | for (Py_ssize_t i = 0; i < PyList_GET_SIZE(py_dtypes); ++i) { |
356 | dtypes[i] = DataTypeFromPyDType(PyList_GetItem(py_dtypes, i)); |
357 | } |
358 | } else { |
359 | // Dtypes are implicit: infer them. |
360 | for (Py_ssize_t i = 0; i < num_tensors; ++i) { |
361 | for (ParamIndex index : input.tensor_list_params) { |
362 | PyObject* tensor_list = params[index]; |
363 | DCHECK(PyList_CheckExact(tensor_list)); |
364 | PyObject* item = PyList_GET_ITEM(tensor_list, i); |
365 | RETURN_IF_FALSE(InferDType(item, dtypes[i])); |
366 | if (dtypes[i] != DT_INVALID) break; |
367 | } |
368 | } |
369 | } |
370 | |
371 | // Convert tensors. |
372 | for (ParamIndex index : input.tensor_list_params) { |
373 | PyObject* tensor_list = params[index]; |
374 | PyObject** items = PyList_ITEMS(tensor_list); |
375 | for (Py_ssize_t i = 0; i < num_tensors; ++i) { |
376 | DataType default_dtype = i < input.default_dtypes.size() |
377 | ? input.default_dtypes[i] |
378 | : DT_INVALID; |
379 | RETURN_IF_FALSE(ConvertToTensorInPlace(items[i], dtypes[i], |
380 | tensor_converter, api_info, index, |
381 | &input.ok_dtypes, default_dtype)); |
382 | } |
383 | } |
384 | |
385 | if (inferred_attrs) { |
386 | int inferred_index = input.type_list_attr->inferred_index; |
387 | if (inferred_index != -1) { |
388 | inferred_attrs->type_lists[inferred_index].swap(dtypes); |
389 | } |
390 | } |
391 | |
392 | return true; |
393 | } |
394 | |
395 | // Infers length attributes for Tensor-list parameters from their values, and |
396 | // updates `inferred_length_attrs` with the inferred length. Sets an exception |
397 | // if multiple Tensor-list parameters have the same length attribute but |
398 | // different lengths. Returns true on success, or sets an exception and returns |
399 | // false on failure. |
400 | ABSL_MUST_USE_RESULT |
401 | bool InferLengthAttributes(const absl::Span<PyObject*> params, |
402 | const PythonAPIInfo& api_info, |
403 | std::vector<int64_t>& inferred_length_attrs) { |
404 | for (int i = 0; i < api_info.inputs_with_number_attrs().size(); ++i) { |
405 | const auto& inputs = api_info.inputs_with_number_attrs()[i]; |
406 | DCHECK(!inputs.tensor_list_params.empty()); |
407 | |
408 | // Use the first tensor_list parameter to infer the length attribute. |
409 | PyObject* tensors = params[inputs.tensor_list_params[0]]; |
410 | DCHECK(PyList_CheckExact(tensors)); |
411 | int inferred_length = PyList_GET_SIZE(tensors); |
412 | |
413 | // Check that any other tensor_list parameters have matching length. |
414 | for (int j = 1; j < inputs.tensor_list_params.size(); ++j) { |
415 | int num_tensors = PyList_GET_SIZE(params[inputs.tensor_list_params[j]]); |
416 | if (num_tensors != inferred_length) { |
417 | RaiseTypeError(api_info.api_name(), " expected parameters " , |
418 | api_info.param_names()[inputs.tensor_list_params[0]], |
419 | " and " , |
420 | api_info.param_names()[inputs.tensor_list_params[j]], |
421 | " to be lists with the same length." ); |
422 | } |
423 | } |
424 | |
425 | int inferred_index = inputs.number_attr->inferred_index; |
426 | if (inferred_index != -1) { |
427 | inferred_length_attrs[inferred_index] = inferred_length; |
428 | } |
429 | } |
430 | return true; |
431 | } |
432 | |
433 | } // namespace |
434 | |
435 | bool ConvertPythonAPIParameters(const PythonAPIInfo& api_info, |
436 | const PythonTensorConverter& tensor_converter, |
437 | absl::Span<PyObject*> params, |
438 | InferredAttributes* inferred_attrs) { |
439 | // Make room for inferred attributes. |
440 | if (inferred_attrs) { |
441 | inferred_attrs->types.resize(api_info.inferred_type_attrs().size()); |
442 | inferred_attrs->type_lists.resize( |
443 | api_info.inferred_type_list_attrs().size()); |
444 | inferred_attrs->lengths.resize(api_info.inferred_length_attrs().size()); |
445 | } |
446 | |
447 | for (const auto& attr : api_info.attributes()) { |
448 | RETURN_IF_FALSE(ConvertAttribute(attr, api_info, params)); |
449 | } |
450 | |
451 | for (const auto& input : api_info.inputs_with_fixed_dtype()) { |
452 | RETURN_IF_FALSE( |
453 | ConvertInputWithFixedDType(input, tensor_converter, api_info, params)); |
454 | } |
455 | |
456 | for (int i = 0; i < api_info.inputs_with_type_attrs().size(); ++i) { |
457 | RETURN_IF_FALSE(ConvertInputsWithTypeAttr( |
458 | api_info.inputs_with_type_attrs()[i], tensor_converter, api_info, |
459 | params, inferred_attrs)); |
460 | } |
461 | |
462 | for (int i = 0; i < api_info.inputs_with_type_list_attrs().size(); ++i) { |
463 | RETURN_IF_FALSE(ConvertInputsWithTypeListAttr( |
464 | api_info.inputs_with_type_list_attrs()[i], tensor_converter, api_info, |
465 | params, inferred_attrs)); |
466 | } |
467 | |
468 | if (inferred_attrs) { |
469 | RETURN_IF_FALSE( |
470 | InferLengthAttributes(params, api_info, inferred_attrs->lengths)); |
471 | } |
472 | |
473 | return true; |
474 | } |
475 | |
476 | bool CopyPythonAPITensorLists(const PythonAPIInfo& api_info, |
477 | absl::Span<PyObject*> params) { |
478 | for (const auto& input : api_info.inputs()) { |
479 | if (input.is_list) { |
480 | PyObject* src = params[input.index]; |
481 | PyObject* copy = PySequence_List(src); |
482 | if (!copy) { |
483 | RaiseTypeError(api_info.api_name(), " expected a list of Tensors for '" , |
484 | api_info.param_names()[input.index], "'; got " , |
485 | src->ob_type->tp_name, "." ); |
486 | return false; |
487 | } |
488 | Py_DECREF(params[input.index]); |
489 | params[input.index] = copy; |
490 | } |
491 | } |
492 | return true; |
493 | } |
494 | |
495 | } // namespace tensorflow |
496 | |