1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/python/framework/python_tensor_converter.h" |
16 | |
17 | #include "absl/strings/str_cat.h" |
18 | #include "tensorflow/python/eager/pywrap_tensor.h" |
19 | #include "tensorflow/python/eager/pywrap_tfe.h" |
20 | #include "tensorflow/python/util/util.h" |
21 | |
// Compatibility shims: select the Python 2 or Python 3 C-API spelling of
// "convert int object to C long" and "create an interned string", based on
// PY_MAJOR_VERSION, so the code below is written once against these macros.
#if PY_MAJOR_VERSION < 3
// Python 2.x:
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#else
// Python 3.x:
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#endif
31 | |
32 | namespace tensorflow { |
33 | namespace { |
34 | |
35 | // Returns `tensor.dtype._type_enum` as a DataType enum. Assumes that `tensor` |
36 | // is a python `Tensor` object. |
37 | // |
38 | // On error: sets a python AttributeError exception and returns DT_INVALID. |
39 | DataType DataTypeForTensor(PyObject* tensor) { |
40 | static PyObject* dtype_attr = PY_STRING_INTERN_FROM_STRING("dtype" ); |
41 | static PyObject* type_enum_attr = PY_STRING_INTERN_FROM_STRING("_type_enum" ); |
42 | |
43 | Safe_PyObjectPtr py_dtype(PyObject_GetAttr(tensor, dtype_attr)); |
44 | if (!py_dtype) return DT_INVALID; |
45 | |
46 | Safe_PyObjectPtr enum_field(PyObject_GetAttr(py_dtype.get(), type_enum_attr)); |
47 | if (!enum_field) return DT_INVALID; |
48 | |
49 | DataType result = static_cast<DataType>(PY_INT_AS_LONG(enum_field.get())); |
50 | return result; |
51 | } |
52 | |
53 | // Check that actual_dtype == expected_dtype. If not, set an exception and |
54 | // return false. (If expected_dtype is DT_INVALID, then instead simply update |
55 | // its value to `actual_dtype` and return true.) |
56 | bool CheckDType(DataType actual_dtype, DataType& expected_dtype) { |
57 | if (expected_dtype == DT_INVALID) { |
58 | expected_dtype = actual_dtype; // set output parameter. |
59 | } else if (expected_dtype != actual_dtype) { |
60 | PyErr_SetString(PyExc_TypeError, |
61 | absl::StrCat("Expected " , DataType_Name(expected_dtype), |
62 | " but got " , DataType_Name(actual_dtype)) |
63 | .c_str()); |
64 | return false; |
65 | } |
66 | return true; |
67 | } |
68 | |
69 | } // namespace |
70 | |
71 | Safe_PyObjectPtr PythonTensorConverter::Convert(PyObject* src, DataType& dtype, |
72 | bool* used_fallback) const { |
73 | // First, try converting `src` to a Tensor without calling back into Python. |
74 | if (ctx_) { // Eager mode |
75 | // TODO(b/164980194): Handle resource variables as well. (See |
76 | // ConvertToTensor function in pywrap_tfe_src.cc). |
77 | if (EagerTensor_CheckExact(src)) { |
78 | // `src` is already an eager tensor; check its type, and return it as-is. |
79 | if (!CheckDType(PyEagerTensor_Dtype(src), dtype)) return nullptr; |
80 | Py_INCREF(src); |
81 | return Safe_PyObjectPtr(src); |
82 | } else { |
83 | TFE_TensorHandle* handle = |
84 | tensorflow::ConvertToEagerTensor(ctx_, src, dtype, device_name_); |
85 | if (handle) { |
86 | Safe_PyObjectPtr result(EagerTensorFromHandle(handle)); |
87 | if (!CheckDType(PyEagerTensor_Dtype(result.get()), dtype)) { |
88 | return nullptr; |
89 | } |
90 | return result; |
91 | } else { |
92 | PyErr_Clear(); |
93 | } |
94 | } |
95 | } else { // Graph mode |
96 | if (swig::IsTensor(src)) { |
97 | DataType src_dtype = DataTypeForTensor(src); |
98 | if (src_dtype == DT_INVALID) return nullptr; |
99 | if (!CheckDType(src_dtype, dtype)) return nullptr; |
100 | Py_INCREF(src); |
101 | return Safe_PyObjectPtr(src); |
102 | } |
103 | } |
104 | |
105 | // Fallback: use the Python tf.convert_to_tensor function. |
106 | // Currently this is used: |
107 | // |
108 | // * In Eager mode: for anything that's not already an Eager tensor, or |
109 | // handled by `tensorflow::ConvertToEagerTensor`. (At time of writing |
110 | // for this comment, ConvertToEagerTensor handles simple values like ints, |
111 | // nested lists of simple values, and numpy arrays.) |
112 | // * In graph mode: for anything that's not already a tensor. |
113 | // |
114 | // TODO(b/164980194) Reduce/eliminate cases where fallback is used. |
115 | if (used_fallback) *used_fallback = true; |
116 | static PyObject* convert_to_tensor = |
117 | swig::GetRegisteredPyObject("tf.convert_to_tensor" ); |
118 | if (!convert_to_tensor) return nullptr; |
119 | |
120 | Safe_PyObjectPtr args(PyTuple_New(dtype == DT_INVALID ? 1 : 2)); |
121 | Safe_PyObjectPtr kwargs(PyDict_New()); |
122 | Py_INCREF(src); |
123 | PyTuple_SetItem(args.get(), 0, src); |
124 | if (dtype != DT_INVALID) { |
125 | PyTuple_SetItem(args.get(), 1, PyLong_FromLong(dtype)); |
126 | } |
127 | PyDict_SetItemString(kwargs.get(), "ctx" , py_eager_context_); |
128 | Safe_PyObjectPtr result( |
129 | PyObject_Call(convert_to_tensor, args.get(), kwargs.get())); |
130 | if (!result) return nullptr; |
131 | dtype = DataTypeForTensor(result.get()); // set output parameter. |
132 | if (dtype == DT_INVALID) return nullptr; |
133 | return result; |
134 | } |
135 | |
136 | } // namespace tensorflow |
137 | |