1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/python/framework/op_def_util.h"
16
17#include <map>
18
19#include "absl/strings/str_cat.h"
20#include "tensorflow/core/framework/attr_value.pb.h"
21#include "tensorflow/core/framework/tensor_shape.pb.h"
22#include "tensorflow/core/framework/types.pb.h"
23#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
24#include "tensorflow/python/util/util.h"
25
26using ::tensorflow::swig::GetRegisteredPyObject;
27
// Thin compatibility layer over the Python 2 and Python 3 C APIs, so that the
// conversion code below can be written once for both.
#if PY_MAJOR_VERSION >= 3
// Python 3.x (`int` is backed by PyLong; text is created as unicode).
#define PY_INT_CHECK(x) (PyLong_Check(x))
#define PY_INT_TYPE PyLong_Type
#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
#else
// Python 2.x (`int` is backed by PyInt; text is created as a byte string).
#define PY_INT_CHECK(x) (PyInt_Check(x))
#define PY_INT_TYPE PyInt_Type
#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
#endif
43
44namespace tensorflow {
45
46namespace {
47
48const std::map<std::string, AttributeType>* AttributeTypeNameMap() {
49 static auto* type_map = new std::map<std::string, AttributeType>(
50 {{"any", AttributeType::ANY},
51 {"float", AttributeType::FLOAT},
52 {"int", AttributeType::INT},
53 {"string", AttributeType::STRING},
54 {"bool", AttributeType::BOOL},
55 {"shape", AttributeType::SHAPE},
56 {"type", AttributeType::DTYPE},
57 {"tensor", AttributeType::TENSOR},
58 {"list(any)", AttributeType::LIST_ANY},
59 {"list(float)", AttributeType::LIST_FLOAT},
60 {"list(int)", AttributeType::LIST_INT},
61 {"list(string)", AttributeType::LIST_STRING},
62 {"list(bool)", AttributeType::LIST_BOOL},
63 {"list(type)", AttributeType::LIST_DTYPE},
64 {"list(shape)", AttributeType::LIST_SHAPE},
65 {"list(tensor)", AttributeType::LIST_TENSOR}});
66 return type_map;
67}
68
// Note: we define functors for converting value types (rather than simple
// functions) so that they can be passed to the generic ConvertListAttr
// template. These functors all return a new reference on success, or nullptr
// on failure. On failure, a Python exception may or may not be set.
73
74struct ConvertAnyFunctor {
75 Safe_PyObjectPtr operator()(PyObject* value) {
76 Py_INCREF(value);
77 return Safe_PyObjectPtr(value);
78 }
79};
80
81struct ConvertFloatFunctor {
82 Safe_PyObjectPtr operator()(PyObject* value) {
83 Safe_PyObjectPtr result;
84 if (PyFloat_Check(value)) {
85 Py_INCREF(value);
86 result.reset(value);
87 } else if (!PY_STRING_CHECK(value)) {
88 result.reset(PyObject_CallFunctionObjArgs(
89 reinterpret_cast<PyObject*>(&PyFloat_Type), value, nullptr));
90 }
91 return result;
92 }
93};
94
95struct ConvertIntFunctor {
96 Safe_PyObjectPtr operator()(PyObject* value) {
97 Safe_PyObjectPtr result;
98 if (PY_INT_CHECK(value)) {
99 Py_INCREF(value);
100 result.reset(value);
101 } else if (!PY_STRING_CHECK(value)) {
102 result.reset(PyObject_CallFunctionObjArgs(
103 reinterpret_cast<PyObject*>(&PY_INT_TYPE), value, nullptr));
104 }
105 return result;
106 }
107};
108
109struct ConvertStringFunctor {
110 Safe_PyObjectPtr operator()(PyObject* value) {
111 Safe_PyObjectPtr result;
112 if (PY_STRING_CHECK(value)) {
113 Py_INCREF(value);
114 result.reset(value);
115 }
116 return result;
117 }
118};
119
120// TODO(edloper): Should we allow ints (or any other values) to be converted
121// to booleans? Currently, TensorFlow does not do this conversion for attribute
122// values in _MakeBool or make_bool.
123struct ConvertBoolFunctor {
124 Safe_PyObjectPtr operator()(PyObject* value) {
125 Safe_PyObjectPtr result;
126 if (PyBool_Check(value)) {
127 Py_INCREF(value);
128 result.reset(value);
129 }
130 return result;
131 }
132};
133
134struct ConvertDTypeFunctor {
135 Safe_PyObjectPtr operator()(PyObject* value) {
136 Safe_PyObjectPtr result;
137 // The following symbols are registered in op_def_library.py
138 static PyObject* dtype = GetRegisteredPyObject("tf.dtypes.DType");
139 static PyObject* as_dtype = GetRegisteredPyObject("tf.dtypes.as_dtype");
140 if (reinterpret_cast<PyObject*>(value->ob_type) == dtype) {
141 Py_INCREF(value);
142 result.reset(value);
143 } else {
144 result.reset(PyObject_CallFunctionObjArgs(as_dtype, value, nullptr));
145 }
146 return result;
147 }
148};
149
150struct ConvertTensorShapeFunctor {
151 Safe_PyObjectPtr operator()(PyObject* value) {
152 Safe_PyObjectPtr result;
153 // The following symbols are registered in op_def_library.py
154 static PyObject* shape = GetRegisteredPyObject("tf.TensorShape");
155 static PyObject* as_shape = GetRegisteredPyObject("tf.as_shape");
156 if (reinterpret_cast<PyObject*>(value->ob_type) == shape) {
157 Py_INCREF(value);
158 result.reset(value);
159 } else {
160 result.reset(PyObject_CallFunctionObjArgs(as_shape, value, nullptr));
161 }
162 return result;
163 }
164};
165
166struct ConvertTensorProtoFunctor {
167 Safe_PyObjectPtr operator()(PyObject* value) {
168 Safe_PyObjectPtr result;
169 // The following symbols are registered in op_def_library.py
170 static PyObject* tensor_proto = GetRegisteredPyObject("tf.TensorProto");
171 static PyObject* text_format_parse =
172 GetRegisteredPyObject("text_format.Parse");
173 if (reinterpret_cast<PyObject*>(value->ob_type) == tensor_proto) {
174 Py_INCREF(value);
175 result.reset(value);
176 } else if (PY_STRING_CHECK(value)) {
177 result.reset(PyObject_CallObject(tensor_proto, nullptr));
178 if (result) {
179 if (!PyObject_CallFunctionObjArgs(text_format_parse, value,
180 result.get(), nullptr)) {
181 return nullptr;
182 }
183 }
184 }
185 return result;
186 }
187};
188
189// Converts `value` to a list of elements with the same type, using
190// `convert_functor` to convert each element.
191template <typename T>
192Safe_PyObjectPtr ConvertListAttr(PyObject* value, T convert_functor) {
193 // Copy the list.
194 Safe_PyObjectPtr result(PySequence_List(value));
195 if (!result) return nullptr;
196
197 // Check the type of each item in the list.
198 Py_ssize_t len = PySequence_Fast_GET_SIZE(result.get());
199 PyObject** items = PySequence_Fast_ITEMS(result.get());
200 for (Py_ssize_t i = 0; i < len; ++i) {
201 if (!PyFloat_Check(value)) {
202 Safe_PyObjectPtr item = convert_functor(items[i]);
203 if (!item) return nullptr;
204 PySequence_SetItem(result.get(), i, item.get());
205 }
206 }
207 return result;
208}
209
210// Returns the given `value` value, converted to the indicated type.
211// Returns nullptr if `value` is not convertible.
212Safe_PyObjectPtr ConvertAttrOrNull(PyObject* value, AttributeType attr_type) {
213 switch (attr_type) {
214 case AttributeType::ANY:
215 return ConvertAnyFunctor()(value);
216 case AttributeType::FLOAT:
217 return ConvertFloatFunctor()(value);
218 case AttributeType::INT:
219 return ConvertIntFunctor()(value);
220 case AttributeType::STRING:
221 return ConvertStringFunctor()(value);
222 case AttributeType::BOOL:
223 return ConvertBoolFunctor()(value);
224 case AttributeType::DTYPE:
225 return ConvertDTypeFunctor()(value);
226 case AttributeType::SHAPE:
227 return ConvertTensorShapeFunctor()(value);
228 case AttributeType::TENSOR:
229 return ConvertTensorProtoFunctor()(value);
230 case AttributeType::LIST_ANY:
231 return ConvertListAttr(value, ConvertAnyFunctor());
232 case AttributeType::LIST_FLOAT:
233 return ConvertListAttr(value, ConvertFloatFunctor());
234 case AttributeType::LIST_INT:
235 return ConvertListAttr(value, ConvertIntFunctor());
236 case AttributeType::LIST_STRING:
237 return ConvertListAttr(value, ConvertStringFunctor());
238 case AttributeType::LIST_BOOL:
239 return ConvertListAttr(value, ConvertBoolFunctor());
240 case AttributeType::LIST_DTYPE:
241 return ConvertListAttr(value, ConvertDTypeFunctor());
242 case AttributeType::LIST_SHAPE:
243 return ConvertListAttr(value, ConvertTensorShapeFunctor());
244 case AttributeType::LIST_TENSOR:
245 return ConvertListAttr(value, ConvertTensorProtoFunctor());
246 default:
247 return nullptr;
248 }
249}
250
251// Returns a new reference to Py_True or Py_False depending on b.
252PyObject* PyBool_FromBool(bool b) {
253 PyObject* result = b ? Py_True : Py_False;
254 Py_INCREF(result);
255 return result;
256}
257
258Safe_PyObjectPtr AttrValueListToPyObject(AttrValue::ListValue list) {
259 if (list.s_size()) {
260 Safe_PyObjectPtr result(PyList_New(list.s_size()));
261 for (int i = 0; i < list.s_size(); ++i) {
262 PyList_SET_ITEM(result.get(), i, PY_STRING_FROMSTRING(list.s(i).c_str()));
263 }
264 return result;
265 } else if (list.i_size()) {
266 Safe_PyObjectPtr result(PyList_New(list.i_size()));
267 for (int i = 0; i < list.i_size(); ++i) {
268 PyList_SET_ITEM(result.get(), i, PY_INT_FROM_LONG(list.i(i)));
269 }
270 return result;
271 } else if (list.f_size()) {
272 Safe_PyObjectPtr result(PyList_New(list.f_size()));
273 for (int i = 0; i < list.f_size(); ++i) {
274 PyList_SET_ITEM(result.get(), i, PyFloat_FromDouble(list.f(i)));
275 }
276 return result;
277 } else if (list.b_size()) {
278 Safe_PyObjectPtr result(PyList_New(list.b_size()));
279 for (int i = 0; i < list.b_size(); ++i) {
280 PyList_SET_ITEM(result.get(), i, PyBool_FromBool(list.b(i)));
281 }
282 return result;
283 } else if (list.type_size()) {
284 Safe_PyObjectPtr result(PyList_New(list.type_size()));
285 for (int i = 0; i < list.type_size(); ++i) {
286 Safe_PyObjectPtr item(DataTypeToPyObject(list.type(i)));
287 Py_INCREF(item.get());
288 PyList_SET_ITEM(result.get(), i, item.get());
289 }
290 return result;
291 } else if (list.shape_size()) {
292 Safe_PyObjectPtr result(PyList_New(list.shape_size()));
293 for (int i = 0; i < list.shape_size(); ++i) {
294 Safe_PyObjectPtr item(TensorShapeProtoToPyObject(list.shape(i)));
295 Py_INCREF(item.get());
296 PyList_SET_ITEM(result.get(), i, item.get());
297 }
298 return result;
299 } else if (list.tensor_size() || list.func_size()) {
300 // TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
301 PyErr_SetString(PyExc_TypeError, "Unsupported AttrValue type");
302 return nullptr;
303 } else {
304 // Empty list
305 return Safe_PyObjectPtr(PyList_New(0));
306 }
307}
308
309} // namespace
310
311AttributeType AttributeTypeFromName(const std::string& type_name) {
312 const auto* type_map = AttributeTypeNameMap();
313 auto it = type_map->find(type_name);
314 return it != type_map->end() ? it->second : AttributeType::UNKNOWN;
315}
316
317std::string AttributeTypeToName(AttributeType attr_type) {
318 for (const auto& pair : *AttributeTypeNameMap()) {
319 if (pair.second == attr_type) {
320 return pair.first;
321 }
322 }
323 return "<unknown>";
324}
325
326Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
327 AttributeType type) {
328 Safe_PyObjectPtr result = ConvertAttrOrNull(value, type);
329 if (!result) {
330 auto err = absl::StrCat("Failed to convert value of type '",
331 value->ob_type->tp_name, "' to type '",
332 AttributeTypeToName(type), "'.");
333 PyErr_SetString(PyExc_TypeError, err.c_str());
334 }
335
336 return result;
337}
338
339Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value) {
340 switch (attr_value.value_case()) {
341 case tensorflow::AttrValue::kS:
342 return Safe_PyObjectPtr(PY_STRING_FROMSTRING(attr_value.s().c_str()));
343 case tensorflow::AttrValue::kI:
344 return Safe_PyObjectPtr(PY_INT_FROM_LONG(attr_value.i()));
345 case tensorflow::AttrValue::kF:
346 return Safe_PyObjectPtr(PyFloat_FromDouble(attr_value.f()));
347 case tensorflow::AttrValue::kB:
348 return Safe_PyObjectPtr(PyBool_FromBool(attr_value.b()));
349 case tensorflow::AttrValue::kType:
350 return DataTypeToPyObject(attr_value.type());
351 case tensorflow::AttrValue::kShape:
352 return TensorShapeProtoToPyObject(attr_value.shape());
353 case tensorflow::AttrValue::kList:
354 return AttrValueListToPyObject(attr_value.list());
355 default:
356 // TODO(edloper): Add support for tensorflow::AttrValue::kTensor.
357 PyErr_SetString(PyExc_ValueError, "Unsupported AttrValue type");
358 return nullptr;
359 }
360}
361
362Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type) {
363 Safe_PyObjectPtr enum_value(PY_INT_FROM_LONG(data_type));
364 return ConvertDTypeFunctor()(enum_value.get());
365}
366
367Safe_PyObjectPtr TensorShapeProtoToPyObject(
368 const TensorShapeProto& tensor_shape) {
369 if (tensor_shape.unknown_rank()) {
370 return ConvertTensorShapeFunctor()(Py_None);
371 } else {
372 Safe_PyObjectPtr dims(PyTuple_New(tensor_shape.dim_size()));
373 for (int i = 0; i < tensor_shape.dim_size(); ++i) {
374 PyTuple_SET_ITEM(dims.get(), i,
375 PY_INT_FROM_LONG(tensor_shape.dim(i).size()));
376 }
377 return ConvertTensorShapeFunctor()(dims.get());
378 }
379}
380
381} // namespace tensorflow
382