1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/python/framework/op_def_util.h" |
16 | |
17 | #include <map> |
18 | |
19 | #include "absl/strings/str_cat.h" |
20 | #include "tensorflow/core/framework/attr_value.pb.h" |
21 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
22 | #include "tensorflow/core/framework/types.pb.h" |
23 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
24 | #include "tensorflow/python/util/util.h" |
25 | |
26 | using ::tensorflow::swig::GetRegisteredPyObject; |
27 | |
// Compatibility shims so the conversion code below can be written once against
// both the Python 2 and the Python 3 C APIs.
//   PY_STRING_CHECK(x)      - true for any "string-like" object.
//   PY_STRING_FROMSTRING(x) - builds a new native string from a C string.
//   PY_INT_CHECK(x)         - true for the native integer type (or subclasses).
//   PY_INT_TYPE             - the native integer type object.
//   PY_INT_FROM_LONG(x)     - builds a new native integer from a C long.
// Note: PY_STRING_CHECK evaluates its argument twice.
#if PY_MAJOR_VERSION < 3
// Python 2.x: both `str` and `unicode` count as strings; ints are PyInt.
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
#define PY_INT_CHECK(x) (PyInt_Check(x))
#define PY_INT_TYPE PyInt_Type
#define PY_INT_FROM_LONG(x) (PyInt_FromLong(x))
#else
// Python 3.x: both `bytes` and `str` count as strings, new strings are created
// as `str` (PyUnicode), and ints are PyLong.
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
#define PY_INT_CHECK(x) (PyLong_Check(x))
#define PY_INT_TYPE PyLong_Type
#define PY_INT_FROM_LONG(x) (PyLong_FromLong(x))
#endif
43 | |
44 | namespace tensorflow { |
45 | |
46 | namespace { |
47 | |
48 | const std::map<std::string, AttributeType>* AttributeTypeNameMap() { |
49 | static auto* type_map = new std::map<std::string, AttributeType>( |
50 | {{"any" , AttributeType::ANY}, |
51 | {"float" , AttributeType::FLOAT}, |
52 | {"int" , AttributeType::INT}, |
53 | {"string" , AttributeType::STRING}, |
54 | {"bool" , AttributeType::BOOL}, |
55 | {"shape" , AttributeType::SHAPE}, |
56 | {"type" , AttributeType::DTYPE}, |
57 | {"tensor" , AttributeType::TENSOR}, |
58 | {"list(any)" , AttributeType::LIST_ANY}, |
59 | {"list(float)" , AttributeType::LIST_FLOAT}, |
60 | {"list(int)" , AttributeType::LIST_INT}, |
61 | {"list(string)" , AttributeType::LIST_STRING}, |
62 | {"list(bool)" , AttributeType::LIST_BOOL}, |
63 | {"list(type)" , AttributeType::LIST_DTYPE}, |
64 | {"list(shape)" , AttributeType::LIST_SHAPE}, |
65 | {"list(tensor)" , AttributeType::LIST_TENSOR}}); |
66 | return type_map; |
67 | } |
68 | |
69 | // Note: we define functors for converting value types (rather than simple |
70 | // functions) so we can define a generic ConvertListAttr method. These |
71 | // functors all return a new reference on success, or nullptr on failure. |
72 | // They do not (necessarily) call PyErr_SetString. |
73 | |
74 | struct ConvertAnyFunctor { |
75 | Safe_PyObjectPtr operator()(PyObject* value) { |
76 | Py_INCREF(value); |
77 | return Safe_PyObjectPtr(value); |
78 | } |
79 | }; |
80 | |
81 | struct ConvertFloatFunctor { |
82 | Safe_PyObjectPtr operator()(PyObject* value) { |
83 | Safe_PyObjectPtr result; |
84 | if (PyFloat_Check(value)) { |
85 | Py_INCREF(value); |
86 | result.reset(value); |
87 | } else if (!PY_STRING_CHECK(value)) { |
88 | result.reset(PyObject_CallFunctionObjArgs( |
89 | reinterpret_cast<PyObject*>(&PyFloat_Type), value, nullptr)); |
90 | } |
91 | return result; |
92 | } |
93 | }; |
94 | |
95 | struct ConvertIntFunctor { |
96 | Safe_PyObjectPtr operator()(PyObject* value) { |
97 | Safe_PyObjectPtr result; |
98 | if (PY_INT_CHECK(value)) { |
99 | Py_INCREF(value); |
100 | result.reset(value); |
101 | } else if (!PY_STRING_CHECK(value)) { |
102 | result.reset(PyObject_CallFunctionObjArgs( |
103 | reinterpret_cast<PyObject*>(&PY_INT_TYPE), value, nullptr)); |
104 | } |
105 | return result; |
106 | } |
107 | }; |
108 | |
109 | struct ConvertStringFunctor { |
110 | Safe_PyObjectPtr operator()(PyObject* value) { |
111 | Safe_PyObjectPtr result; |
112 | if (PY_STRING_CHECK(value)) { |
113 | Py_INCREF(value); |
114 | result.reset(value); |
115 | } |
116 | return result; |
117 | } |
118 | }; |
119 | |
120 | // TODO(edloper): Should we allow ints (or any other values) to be converted |
121 | // to booleans? Currently, TensorFlow does not do this conversion for attribute |
122 | // values in _MakeBool or make_bool. |
123 | struct ConvertBoolFunctor { |
124 | Safe_PyObjectPtr operator()(PyObject* value) { |
125 | Safe_PyObjectPtr result; |
126 | if (PyBool_Check(value)) { |
127 | Py_INCREF(value); |
128 | result.reset(value); |
129 | } |
130 | return result; |
131 | } |
132 | }; |
133 | |
134 | struct ConvertDTypeFunctor { |
135 | Safe_PyObjectPtr operator()(PyObject* value) { |
136 | Safe_PyObjectPtr result; |
137 | // The following symbols are registered in op_def_library.py |
138 | static PyObject* dtype = GetRegisteredPyObject("tf.dtypes.DType" ); |
139 | static PyObject* as_dtype = GetRegisteredPyObject("tf.dtypes.as_dtype" ); |
140 | if (reinterpret_cast<PyObject*>(value->ob_type) == dtype) { |
141 | Py_INCREF(value); |
142 | result.reset(value); |
143 | } else { |
144 | result.reset(PyObject_CallFunctionObjArgs(as_dtype, value, nullptr)); |
145 | } |
146 | return result; |
147 | } |
148 | }; |
149 | |
150 | struct ConvertTensorShapeFunctor { |
151 | Safe_PyObjectPtr operator()(PyObject* value) { |
152 | Safe_PyObjectPtr result; |
153 | // The following symbols are registered in op_def_library.py |
154 | static PyObject* shape = GetRegisteredPyObject("tf.TensorShape" ); |
155 | static PyObject* as_shape = GetRegisteredPyObject("tf.as_shape" ); |
156 | if (reinterpret_cast<PyObject*>(value->ob_type) == shape) { |
157 | Py_INCREF(value); |
158 | result.reset(value); |
159 | } else { |
160 | result.reset(PyObject_CallFunctionObjArgs(as_shape, value, nullptr)); |
161 | } |
162 | return result; |
163 | } |
164 | }; |
165 | |
166 | struct ConvertTensorProtoFunctor { |
167 | Safe_PyObjectPtr operator()(PyObject* value) { |
168 | Safe_PyObjectPtr result; |
169 | // The following symbols are registered in op_def_library.py |
170 | static PyObject* tensor_proto = GetRegisteredPyObject("tf.TensorProto" ); |
171 | static PyObject* text_format_parse = |
172 | GetRegisteredPyObject("text_format.Parse" ); |
173 | if (reinterpret_cast<PyObject*>(value->ob_type) == tensor_proto) { |
174 | Py_INCREF(value); |
175 | result.reset(value); |
176 | } else if (PY_STRING_CHECK(value)) { |
177 | result.reset(PyObject_CallObject(tensor_proto, nullptr)); |
178 | if (result) { |
179 | if (!PyObject_CallFunctionObjArgs(text_format_parse, value, |
180 | result.get(), nullptr)) { |
181 | return nullptr; |
182 | } |
183 | } |
184 | } |
185 | return result; |
186 | } |
187 | }; |
188 | |
189 | // Converts `value` to a list of elements with the same type, using |
190 | // `convert_functor` to convert each element. |
191 | template <typename T> |
192 | Safe_PyObjectPtr ConvertListAttr(PyObject* value, T convert_functor) { |
193 | // Copy the list. |
194 | Safe_PyObjectPtr result(PySequence_List(value)); |
195 | if (!result) return nullptr; |
196 | |
197 | // Check the type of each item in the list. |
198 | Py_ssize_t len = PySequence_Fast_GET_SIZE(result.get()); |
199 | PyObject** items = PySequence_Fast_ITEMS(result.get()); |
200 | for (Py_ssize_t i = 0; i < len; ++i) { |
201 | if (!PyFloat_Check(value)) { |
202 | Safe_PyObjectPtr item = convert_functor(items[i]); |
203 | if (!item) return nullptr; |
204 | PySequence_SetItem(result.get(), i, item.get()); |
205 | } |
206 | } |
207 | return result; |
208 | } |
209 | |
210 | // Returns the given `value` value, converted to the indicated type. |
211 | // Returns nullptr if `value` is not convertible. |
212 | Safe_PyObjectPtr ConvertAttrOrNull(PyObject* value, AttributeType attr_type) { |
213 | switch (attr_type) { |
214 | case AttributeType::ANY: |
215 | return ConvertAnyFunctor()(value); |
216 | case AttributeType::FLOAT: |
217 | return ConvertFloatFunctor()(value); |
218 | case AttributeType::INT: |
219 | return ConvertIntFunctor()(value); |
220 | case AttributeType::STRING: |
221 | return ConvertStringFunctor()(value); |
222 | case AttributeType::BOOL: |
223 | return ConvertBoolFunctor()(value); |
224 | case AttributeType::DTYPE: |
225 | return ConvertDTypeFunctor()(value); |
226 | case AttributeType::SHAPE: |
227 | return ConvertTensorShapeFunctor()(value); |
228 | case AttributeType::TENSOR: |
229 | return ConvertTensorProtoFunctor()(value); |
230 | case AttributeType::LIST_ANY: |
231 | return ConvertListAttr(value, ConvertAnyFunctor()); |
232 | case AttributeType::LIST_FLOAT: |
233 | return ConvertListAttr(value, ConvertFloatFunctor()); |
234 | case AttributeType::LIST_INT: |
235 | return ConvertListAttr(value, ConvertIntFunctor()); |
236 | case AttributeType::LIST_STRING: |
237 | return ConvertListAttr(value, ConvertStringFunctor()); |
238 | case AttributeType::LIST_BOOL: |
239 | return ConvertListAttr(value, ConvertBoolFunctor()); |
240 | case AttributeType::LIST_DTYPE: |
241 | return ConvertListAttr(value, ConvertDTypeFunctor()); |
242 | case AttributeType::LIST_SHAPE: |
243 | return ConvertListAttr(value, ConvertTensorShapeFunctor()); |
244 | case AttributeType::LIST_TENSOR: |
245 | return ConvertListAttr(value, ConvertTensorProtoFunctor()); |
246 | default: |
247 | return nullptr; |
248 | } |
249 | } |
250 | |
251 | // Returns a new reference to Py_True or Py_False depending on b. |
252 | PyObject* PyBool_FromBool(bool b) { |
253 | PyObject* result = b ? Py_True : Py_False; |
254 | Py_INCREF(result); |
255 | return result; |
256 | } |
257 | |
258 | Safe_PyObjectPtr AttrValueListToPyObject(AttrValue::ListValue list) { |
259 | if (list.s_size()) { |
260 | Safe_PyObjectPtr result(PyList_New(list.s_size())); |
261 | for (int i = 0; i < list.s_size(); ++i) { |
262 | PyList_SET_ITEM(result.get(), i, PY_STRING_FROMSTRING(list.s(i).c_str())); |
263 | } |
264 | return result; |
265 | } else if (list.i_size()) { |
266 | Safe_PyObjectPtr result(PyList_New(list.i_size())); |
267 | for (int i = 0; i < list.i_size(); ++i) { |
268 | PyList_SET_ITEM(result.get(), i, PY_INT_FROM_LONG(list.i(i))); |
269 | } |
270 | return result; |
271 | } else if (list.f_size()) { |
272 | Safe_PyObjectPtr result(PyList_New(list.f_size())); |
273 | for (int i = 0; i < list.f_size(); ++i) { |
274 | PyList_SET_ITEM(result.get(), i, PyFloat_FromDouble(list.f(i))); |
275 | } |
276 | return result; |
277 | } else if (list.b_size()) { |
278 | Safe_PyObjectPtr result(PyList_New(list.b_size())); |
279 | for (int i = 0; i < list.b_size(); ++i) { |
280 | PyList_SET_ITEM(result.get(), i, PyBool_FromBool(list.b(i))); |
281 | } |
282 | return result; |
283 | } else if (list.type_size()) { |
284 | Safe_PyObjectPtr result(PyList_New(list.type_size())); |
285 | for (int i = 0; i < list.type_size(); ++i) { |
286 | Safe_PyObjectPtr item(DataTypeToPyObject(list.type(i))); |
287 | Py_INCREF(item.get()); |
288 | PyList_SET_ITEM(result.get(), i, item.get()); |
289 | } |
290 | return result; |
291 | } else if (list.shape_size()) { |
292 | Safe_PyObjectPtr result(PyList_New(list.shape_size())); |
293 | for (int i = 0; i < list.shape_size(); ++i) { |
294 | Safe_PyObjectPtr item(TensorShapeProtoToPyObject(list.shape(i))); |
295 | Py_INCREF(item.get()); |
296 | PyList_SET_ITEM(result.get(), i, item.get()); |
297 | } |
298 | return result; |
299 | } else if (list.tensor_size() || list.func_size()) { |
300 | // TODO(edloper): Add support for tensorflow::AttrValue::kTensor. |
301 | PyErr_SetString(PyExc_TypeError, "Unsupported AttrValue type" ); |
302 | return nullptr; |
303 | } else { |
304 | // Empty list |
305 | return Safe_PyObjectPtr(PyList_New(0)); |
306 | } |
307 | } |
308 | |
309 | } // namespace |
310 | |
311 | AttributeType AttributeTypeFromName(const std::string& type_name) { |
312 | const auto* type_map = AttributeTypeNameMap(); |
313 | auto it = type_map->find(type_name); |
314 | return it != type_map->end() ? it->second : AttributeType::UNKNOWN; |
315 | } |
316 | |
317 | std::string AttributeTypeToName(AttributeType attr_type) { |
318 | for (const auto& pair : *AttributeTypeNameMap()) { |
319 | if (pair.second == attr_type) { |
320 | return pair.first; |
321 | } |
322 | } |
323 | return "<unknown>" ; |
324 | } |
325 | |
326 | Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value, |
327 | AttributeType type) { |
328 | Safe_PyObjectPtr result = ConvertAttrOrNull(value, type); |
329 | if (!result) { |
330 | auto err = absl::StrCat("Failed to convert value of type '" , |
331 | value->ob_type->tp_name, "' to type '" , |
332 | AttributeTypeToName(type), "'." ); |
333 | PyErr_SetString(PyExc_TypeError, err.c_str()); |
334 | } |
335 | |
336 | return result; |
337 | } |
338 | |
339 | Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value) { |
340 | switch (attr_value.value_case()) { |
341 | case tensorflow::AttrValue::kS: |
342 | return Safe_PyObjectPtr(PY_STRING_FROMSTRING(attr_value.s().c_str())); |
343 | case tensorflow::AttrValue::kI: |
344 | return Safe_PyObjectPtr(PY_INT_FROM_LONG(attr_value.i())); |
345 | case tensorflow::AttrValue::kF: |
346 | return Safe_PyObjectPtr(PyFloat_FromDouble(attr_value.f())); |
347 | case tensorflow::AttrValue::kB: |
348 | return Safe_PyObjectPtr(PyBool_FromBool(attr_value.b())); |
349 | case tensorflow::AttrValue::kType: |
350 | return DataTypeToPyObject(attr_value.type()); |
351 | case tensorflow::AttrValue::kShape: |
352 | return TensorShapeProtoToPyObject(attr_value.shape()); |
353 | case tensorflow::AttrValue::kList: |
354 | return AttrValueListToPyObject(attr_value.list()); |
355 | default: |
356 | // TODO(edloper): Add support for tensorflow::AttrValue::kTensor. |
357 | PyErr_SetString(PyExc_ValueError, "Unsupported AttrValue type" ); |
358 | return nullptr; |
359 | } |
360 | } |
361 | |
362 | Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type) { |
363 | Safe_PyObjectPtr enum_value(PY_INT_FROM_LONG(data_type)); |
364 | return ConvertDTypeFunctor()(enum_value.get()); |
365 | } |
366 | |
367 | Safe_PyObjectPtr TensorShapeProtoToPyObject( |
368 | const TensorShapeProto& tensor_shape) { |
369 | if (tensor_shape.unknown_rank()) { |
370 | return ConvertTensorShapeFunctor()(Py_None); |
371 | } else { |
372 | Safe_PyObjectPtr dims(PyTuple_New(tensor_shape.dim_size())); |
373 | for (int i = 0; i < tensor_shape.dim_size(); ++i) { |
374 | PyTuple_SET_ITEM(dims.get(), i, |
375 | PY_INT_FROM_LONG(tensor_shape.dim(i).size())); |
376 | } |
377 | return ConvertTensorShapeFunctor()(dims.get()); |
378 | } |
379 | } |
380 | |
381 | } // namespace tensorflow |
382 | |