1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "pybind11/pybind11.h" |
16 | #include "tensorflow/python/framework/op_def_util.h" |
17 | |
18 | namespace py = pybind11; |
19 | |
20 | namespace { |
21 | |
22 | py::handle ConvertAttr(py::handle value, std::string attr_type) { |
23 | tensorflow::Safe_PyObjectPtr result = |
24 | ::tensorflow::ConvertPyObjectToAttributeType( |
25 | value.ptr(), ::tensorflow::AttributeTypeFromName(attr_type)); |
26 | if (!result) { |
27 | throw py::error_already_set(); |
28 | } |
29 | Py_INCREF(result.get()); |
30 | return result.release(); |
31 | } |
32 | |
33 | py::handle SerializedAttrValueToPyObject(std::string attr_value_string) { |
34 | tensorflow::AttrValue attr_value; |
35 | attr_value.ParseFromString(attr_value_string); |
36 | tensorflow::Safe_PyObjectPtr result = |
37 | ::tensorflow::AttrValueToPyObject(attr_value); |
38 | if (!result) { |
39 | throw py::error_already_set(); |
40 | } |
41 | Py_INCREF(result.get()); |
42 | return result.release(); |
43 | } |
44 | |
45 | } // namespace |
46 | |
47 | // Expose op_def_util.h functions via Python. |
48 | PYBIND11_MODULE(_op_def_util, m) { |
49 | // Note: the bindings below are added for testing purposes; but the functions |
50 | // are expected to be called from c++, not Python. |
51 | m.def("ConvertPyObjectToAttributeType" , ConvertAttr, py::arg("value" ), |
52 | py::arg("attr_type_enum" )); |
53 | m.def("SerializedAttrValueToPyObject" , SerializedAttrValueToPyObject, |
54 | py::arg("attr_value_string" )); |
55 | } |
56 | |