1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/python/framework/python_tensor_converter.h"
16
17#include "absl/strings/str_cat.h"
18#include "tensorflow/python/eager/pywrap_tensor.h"
19#include "tensorflow/python/eager/pywrap_tfe.h"
20#include "tensorflow/python/util/util.h"
21
// Compatibility shims so the conversion code below can be written once against
// either the Python 3 or the Python 2 C API.
#if PY_MAJOR_VERSION >= 3
// Python 3.x:
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#else
// Python 2.x:
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#endif
31
32namespace tensorflow {
33namespace {
34
35// Returns `tensor.dtype._type_enum` as a DataType enum. Assumes that `tensor`
36// is a python `Tensor` object.
37//
38// On error: sets a python AttributeError exception and returns DT_INVALID.
39DataType DataTypeForTensor(PyObject* tensor) {
40 static PyObject* dtype_attr = PY_STRING_INTERN_FROM_STRING("dtype");
41 static PyObject* type_enum_attr = PY_STRING_INTERN_FROM_STRING("_type_enum");
42
43 Safe_PyObjectPtr py_dtype(PyObject_GetAttr(tensor, dtype_attr));
44 if (!py_dtype) return DT_INVALID;
45
46 Safe_PyObjectPtr enum_field(PyObject_GetAttr(py_dtype.get(), type_enum_attr));
47 if (!enum_field) return DT_INVALID;
48
49 DataType result = static_cast<DataType>(PY_INT_AS_LONG(enum_field.get()));
50 return result;
51}
52
53// Check that actual_dtype == expected_dtype. If not, set an exception and
54// return false. (If expected_dtype is DT_INVALID, then instead simply update
55// its value to `actual_dtype` and return true.)
56bool CheckDType(DataType actual_dtype, DataType& expected_dtype) {
57 if (expected_dtype == DT_INVALID) {
58 expected_dtype = actual_dtype; // set output parameter.
59 } else if (expected_dtype != actual_dtype) {
60 PyErr_SetString(PyExc_TypeError,
61 absl::StrCat("Expected ", DataType_Name(expected_dtype),
62 " but got ", DataType_Name(actual_dtype))
63 .c_str());
64 return false;
65 }
66 return true;
67}
68
69} // namespace
70
71Safe_PyObjectPtr PythonTensorConverter::Convert(PyObject* src, DataType& dtype,
72 bool* used_fallback) const {
73 // First, try converting `src` to a Tensor without calling back into Python.
74 if (ctx_) { // Eager mode
75 // TODO(b/164980194): Handle resource variables as well. (See
76 // ConvertToTensor function in pywrap_tfe_src.cc).
77 if (EagerTensor_CheckExact(src)) {
78 // `src` is already an eager tensor; check its type, and return it as-is.
79 if (!CheckDType(PyEagerTensor_Dtype(src), dtype)) return nullptr;
80 Py_INCREF(src);
81 return Safe_PyObjectPtr(src);
82 } else {
83 TFE_TensorHandle* handle =
84 tensorflow::ConvertToEagerTensor(ctx_, src, dtype, device_name_);
85 if (handle) {
86 Safe_PyObjectPtr result(EagerTensorFromHandle(handle));
87 if (!CheckDType(PyEagerTensor_Dtype(result.get()), dtype)) {
88 return nullptr;
89 }
90 return result;
91 } else {
92 PyErr_Clear();
93 }
94 }
95 } else { // Graph mode
96 if (swig::IsTensor(src)) {
97 DataType src_dtype = DataTypeForTensor(src);
98 if (src_dtype == DT_INVALID) return nullptr;
99 if (!CheckDType(src_dtype, dtype)) return nullptr;
100 Py_INCREF(src);
101 return Safe_PyObjectPtr(src);
102 }
103 }
104
105 // Fallback: use the Python tf.convert_to_tensor function.
106 // Currently this is used:
107 //
108 // * In Eager mode: for anything that's not already an Eager tensor, or
109 // handled by `tensorflow::ConvertToEagerTensor`. (At time of writing
110 // for this comment, ConvertToEagerTensor handles simple values like ints,
111 // nested lists of simple values, and numpy arrays.)
112 // * In graph mode: for anything that's not already a tensor.
113 //
114 // TODO(b/164980194) Reduce/eliminate cases where fallback is used.
115 if (used_fallback) *used_fallback = true;
116 static PyObject* convert_to_tensor =
117 swig::GetRegisteredPyObject("tf.convert_to_tensor");
118 if (!convert_to_tensor) return nullptr;
119
120 Safe_PyObjectPtr args(PyTuple_New(dtype == DT_INVALID ? 1 : 2));
121 Safe_PyObjectPtr kwargs(PyDict_New());
122 Py_INCREF(src);
123 PyTuple_SetItem(args.get(), 0, src);
124 if (dtype != DT_INVALID) {
125 PyTuple_SetItem(args.get(), 1, PyLong_FromLong(dtype));
126 }
127 PyDict_SetItemString(kwargs.get(), "ctx", py_eager_context_);
128 Safe_PyObjectPtr result(
129 PyObject_Call(convert_to_tensor, args.get(), kwargs.get()));
130 if (!result) return nullptr;
131 dtype = DataTypeForTensor(result.get()); // set output parameter.
132 if (dtype == DT_INVALID) return nullptr;
133 return result;
134}
135
136} // namespace tensorflow
137