1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Disallow Numpy 1.7 deprecated symbols.
17#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
18
19#include "numpy/arrayobject.h"
20#include "numpy/ufuncobject.h"
21#include "pybind11/chrono.h"
22#include "pybind11/complex.h"
23#include "pybind11/functional.h"
24#include "pybind11/pybind11.h"
25#include "pybind11/stl.h"
26#include "tensorflow/c/checkpoint_reader.h"
27#include "tensorflow/c/tf_status.h"
28#include "tensorflow/core/lib/core/errors.h"
29#include "tensorflow/core/lib/core/status.h"
30#include "tensorflow/python/lib/core/ndarray_tensor.h"
31#include "tensorflow/python/lib/core/py_exception_registry.h"
32#include "tensorflow/python/lib/core/pybind11_lib.h"
33#include "tensorflow/python/lib/core/pybind11_status.h"
34#include "tensorflow/python/lib/core/safe_ptr.h"
35
36namespace py = pybind11;
37
38// TODO(amitpatankar): Move the custom type casters to separate common header
39// only libraries.
40
41namespace pybind11 {
42namespace detail {
43
44/* This is a custom type caster for the TensorShape object. For more
45 * documentation please refer to this link:
46 * https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html#custom-type-casters
47 * The PyCheckpointReader methods sometimes return the `TensorShape` object
48 * and the `DataType` object as outputs. This custom type caster helps Python
 * handle its conversion from C++ to Python. Since we do not accept these
50 * classes as arguments from Python, it is not necessary to define the `load`
51 * function to cast the object from Python to a C++ object.
52 */
53
54template <>
55struct type_caster<tensorflow::TensorShape> {
56 public:
57 PYBIND11_TYPE_CASTER(tensorflow::TensorShape, _("tensorflow::TensorShape"));
58
59 static handle cast(const tensorflow::TensorShape& src,
60 return_value_policy unused_policy, handle unused_handle) {
61 // TODO(amitpatankar): Simplify handling TensorShape as output later.
62 size_t dims = src.dims();
63 tensorflow::Safe_PyObjectPtr value(PyList_New(dims));
64 for (size_t i = 0; i < dims; ++i) {
65#if PY_MAJOR_VERSION >= 3
66 tensorflow::Safe_PyObjectPtr dim_value(
67 tensorflow::make_safe(PyLong_FromLong(src.dim_size(i))));
68#else
69 tensorflow::Safe_PyObjectPtr dim_value(
70 tensorflow::make_safe(PyInt_FromLong(src.dim_size(i))));
71#endif
72 PyList_SET_ITEM(value.get(), i, dim_value.release());
73 }
74
75 return value.release();
76 }
77};
78
79template <>
80struct type_caster<tensorflow::DataType> {
81 public:
82 PYBIND11_TYPE_CASTER(tensorflow::DataType, _("tensorflow::DataType"));
83
84 static handle cast(const tensorflow::DataType& src,
85 return_value_policy unused_policy, handle unused_handle) {
86#if PY_MAJOR_VERSION >= 3
87 tensorflow::Safe_PyObjectPtr value(
88 tensorflow::make_safe(PyLong_FromLong(src)));
89#else
90 tensorflow::Safe_PyObjectPtr value(
91 tensorflow::make_safe(PyInt_FromLong(src)));
92#endif
93 return value.release();
94 }
95};
96
97} // namespace detail
98} // namespace pybind11
99
100namespace tensorflow {
101
102static py::object CheckpointReader_GetTensor(
103 tensorflow::checkpoint::CheckpointReader* reader, const string& name) {
104 Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
105 PyObject* py_obj = Py_None;
106 std::unique_ptr<tensorflow::Tensor> tensor;
107 reader->GetTensor(name, &tensor, status.get());
108
109 // Error handling if unable to get Tensor.
110 tensorflow::MaybeRaiseFromTFStatus(status.get());
111
112 tensorflow::MaybeRaiseFromStatus(
113 tensorflow::TensorToNdarray(*tensor, &py_obj));
114
115 return tensorflow::PyoOrThrow(
116 PyArray_Return(reinterpret_cast<PyArrayObject*>(py_obj)));
117}
118
119} // namespace tensorflow
120
121PYBIND11_MODULE(_pywrap_checkpoint_reader, m) {
122 // Initialization code to use numpy types in the type casters.
123 import_array1();
124 py::class_<tensorflow::checkpoint::CheckpointReader> checkpoint_reader_class(
125 m, "CheckpointReader");
126 checkpoint_reader_class
127 .def(py::init([](const std::string& filename) {
128 tensorflow::Safe_TF_StatusPtr status =
129 tensorflow::make_safe(TF_NewStatus());
130 // pybind11 support smart pointers and will own freeing the memory when
131 // complete.
132 // https://pybind11.readthedocs.io/en/master/advanced/smart_ptrs.html#std-unique-ptr
133 auto checkpoint =
134 std::make_unique<tensorflow::checkpoint::CheckpointReader>(
135 filename, status.get());
136 tensorflow::MaybeRaiseFromTFStatus(status.get());
137 return checkpoint;
138 }))
139 .def("debug_string",
140 [](tensorflow::checkpoint::CheckpointReader& self) {
141 return py::bytes(self.DebugString());
142 })
143 .def("get_variable_to_shape_map",
144 &tensorflow::checkpoint::CheckpointReader::GetVariableToShapeMap)
145 .def("_GetVariableToDataTypeMap",
146 &tensorflow::checkpoint::CheckpointReader::GetVariableToDataTypeMap)
147 .def("_HasTensor", &tensorflow::checkpoint::CheckpointReader::HasTensor)
148 .def_static("CheckpointReader_GetTensor",
149 &tensorflow::CheckpointReader_GetTensor);
150};
151