1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Disallow Numpy 1.7 deprecated symbols. |
17 | #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION |
18 | |
19 | #include "numpy/arrayobject.h" |
20 | #include "numpy/ufuncobject.h" |
21 | #include "pybind11/chrono.h" |
22 | #include "pybind11/complex.h" |
23 | #include "pybind11/functional.h" |
24 | #include "pybind11/pybind11.h" |
25 | #include "pybind11/stl.h" |
26 | #include "tensorflow/c/checkpoint_reader.h" |
27 | #include "tensorflow/c/tf_status.h" |
28 | #include "tensorflow/core/lib/core/errors.h" |
29 | #include "tensorflow/core/lib/core/status.h" |
30 | #include "tensorflow/python/lib/core/ndarray_tensor.h" |
31 | #include "tensorflow/python/lib/core/py_exception_registry.h" |
32 | #include "tensorflow/python/lib/core/pybind11_lib.h" |
33 | #include "tensorflow/python/lib/core/pybind11_status.h" |
34 | #include "tensorflow/python/lib/core/safe_ptr.h" |
35 | |
36 | namespace py = pybind11; |
37 | |
38 | // TODO(amitpatankar): Move the custom type casters to separate common header |
39 | // only libraries. |
40 | |
41 | namespace pybind11 { |
42 | namespace detail { |
43 | |
44 | /* This is a custom type caster for the TensorShape object. For more |
45 | * documentation please refer to this link: |
46 | * https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html#custom-type-casters |
47 | * The PyCheckpointReader methods sometimes return the `TensorShape` object |
48 | * and the `DataType` object as outputs. The custom type casters in this file |
49 | * handle their conversion from C++ to Python. Since we do not accept these |
50 | * classes as arguments from Python, it is not necessary to define the `load` |
51 | * function to cast the object from Python to a C++ object. |
52 | */ |
53 | |
54 | template <> |
55 | struct type_caster<tensorflow::TensorShape> { |
56 | public: |
57 | PYBIND11_TYPE_CASTER(tensorflow::TensorShape, _("tensorflow::TensorShape" )); |
58 | |
59 | static handle cast(const tensorflow::TensorShape& src, |
60 | return_value_policy unused_policy, handle unused_handle) { |
61 | // TODO(amitpatankar): Simplify handling TensorShape as output later. |
62 | size_t dims = src.dims(); |
63 | tensorflow::Safe_PyObjectPtr value(PyList_New(dims)); |
64 | for (size_t i = 0; i < dims; ++i) { |
65 | #if PY_MAJOR_VERSION >= 3 |
66 | tensorflow::Safe_PyObjectPtr dim_value( |
67 | tensorflow::make_safe(PyLong_FromLong(src.dim_size(i)))); |
68 | #else |
69 | tensorflow::Safe_PyObjectPtr dim_value( |
70 | tensorflow::make_safe(PyInt_FromLong(src.dim_size(i)))); |
71 | #endif |
72 | PyList_SET_ITEM(value.get(), i, dim_value.release()); |
73 | } |
74 | |
75 | return value.release(); |
76 | } |
77 | }; |
78 | |
79 | template <> |
80 | struct type_caster<tensorflow::DataType> { |
81 | public: |
82 | PYBIND11_TYPE_CASTER(tensorflow::DataType, _("tensorflow::DataType" )); |
83 | |
84 | static handle cast(const tensorflow::DataType& src, |
85 | return_value_policy unused_policy, handle unused_handle) { |
86 | #if PY_MAJOR_VERSION >= 3 |
87 | tensorflow::Safe_PyObjectPtr value( |
88 | tensorflow::make_safe(PyLong_FromLong(src))); |
89 | #else |
90 | tensorflow::Safe_PyObjectPtr value( |
91 | tensorflow::make_safe(PyInt_FromLong(src))); |
92 | #endif |
93 | return value.release(); |
94 | } |
95 | }; |
96 | |
97 | } // namespace detail |
98 | } // namespace pybind11 |
99 | |
100 | namespace tensorflow { |
101 | |
102 | static py::object CheckpointReader_GetTensor( |
103 | tensorflow::checkpoint::CheckpointReader* reader, const string& name) { |
104 | Safe_TF_StatusPtr status = make_safe(TF_NewStatus()); |
105 | PyObject* py_obj = Py_None; |
106 | std::unique_ptr<tensorflow::Tensor> tensor; |
107 | reader->GetTensor(name, &tensor, status.get()); |
108 | |
109 | // Error handling if unable to get Tensor. |
110 | tensorflow::MaybeRaiseFromTFStatus(status.get()); |
111 | |
112 | tensorflow::MaybeRaiseFromStatus( |
113 | tensorflow::TensorToNdarray(*tensor, &py_obj)); |
114 | |
115 | return tensorflow::PyoOrThrow( |
116 | PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj))); |
117 | } |
118 | |
119 | } // namespace tensorflow |
120 | |
121 | PYBIND11_MODULE(_pywrap_checkpoint_reader, m) { |
122 | // Initialization code to use numpy types in the type casters. |
123 | import_array1(); |
124 | py::class_<tensorflow::checkpoint::CheckpointReader> checkpoint_reader_class( |
125 | m, "CheckpointReader" ); |
126 | checkpoint_reader_class |
127 | .def(py::init([](const std::string& filename) { |
128 | tensorflow::Safe_TF_StatusPtr status = |
129 | tensorflow::make_safe(TF_NewStatus()); |
130 | // pybind11 support smart pointers and will own freeing the memory when |
131 | // complete. |
132 | // https://pybind11.readthedocs.io/en/master/advanced/smart_ptrs.html#std-unique-ptr |
133 | auto checkpoint = |
134 | std::make_unique<tensorflow::checkpoint::CheckpointReader>( |
135 | filename, status.get()); |
136 | tensorflow::MaybeRaiseFromTFStatus(status.get()); |
137 | return checkpoint; |
138 | })) |
139 | .def("debug_string" , |
140 | [](tensorflow::checkpoint::CheckpointReader& self) { |
141 | return py::bytes(self.DebugString()); |
142 | }) |
143 | .def("get_variable_to_shape_map" , |
144 | &tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap) |
145 | .def("_GetVariableToDataTypeMap" , |
146 | &tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap) |
147 | .def("_HasTensor" , &tensorflow::checkpoint::CheckpointReader::HasTensor) |
148 | .def_static("CheckpointReader_GetTensor" , |
149 | &tensorflow::CheckpointReader_GetTensor); |
150 | }; |
151 | |