1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Must be at top (before any system includes and Python.h). |
17 | // clang-format off |
18 | #include "pybind11/chrono.h" |
19 | #include "pybind11/complex.h" |
20 | #include "pybind11/functional.h" |
21 | #include "pybind11/pybind11.h" |
22 | #include "pybind11/stl.h" |
23 | // clang-format on |
24 | |
25 | #include "Python.h" |
26 | #include "absl/types/optional.h" |
27 | #include "third_party/eigen3/Eigen/Core" |
28 | #include "tensorflow/c/c_api.h" |
29 | #include "tensorflow/c/c_api_experimental.h" |
30 | #include "tensorflow/c/c_api_internal.h" |
31 | #include "tensorflow/c/python_api.h" |
32 | #include "tensorflow/c/tf_datatype.h" |
33 | #include "tensorflow/core/distributed_runtime/server_lib.h" |
34 | #include "tensorflow/core/framework/full_type.pb.h" |
35 | #include "tensorflow/core/public/version.h" |
36 | #include "tensorflow/core/util/version_info.h" |
37 | #include "tensorflow/python/client/tf_session_helper.h" |
38 | #include "tensorflow/python/lib/core/numpy.h" |
39 | #include "tensorflow/python/lib/core/pybind11_lib.h" |
40 | #include "tensorflow/python/lib/core/pybind11_status.h" |
41 | #include "tensorflow/python/lib/core/safe_ptr.h" |
42 | |
43 | namespace pybind11 { |
44 | namespace detail { |
45 | // Convert between absl::optional and python. |
46 | // |
47 | // pybind11 supports std::optional, and absl::optional is meant to be a |
48 | // drop-in replacement for std::optional, so we can just use the built in |
49 | // implementation. |
50 | #ifndef ABSL_USES_STD_OPTIONAL |
51 | template <typename T> |
52 | struct type_caster<absl::optional<T>> |
53 | : public optional_caster<absl::optional<T>> {}; |
54 | template <> |
55 | struct type_caster<absl::nullopt_t> : public void_caster<absl::nullopt_t> {}; |
56 | #endif |
57 | |
58 | } // namespace detail |
59 | } // namespace pybind11 |
60 | |
61 | // TODO(amitpatankar): Consolidate Buffer methods into a separate header file. |
62 | TF_Buffer* ProtoStringToTFBuffer(PyObject* input) { |
63 | // Convert a Python string object to TF_Buffer. |
64 | char* c_string; |
65 | Py_ssize_t py_size; |
66 | // PyBytes_AsStringAndSize() does not copy but simply interprets the input |
67 | if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) { |
68 | // Python has raised an error (likely TypeError or UnicodeEncodeError). |
69 | throw py::error_already_set(); |
70 | } |
71 | return TF_NewBufferFromString(static_cast<void*>(c_string), |
72 | static_cast<size_t>(py_size)); |
73 | } |
74 | |
75 | // Copied from tf_session.i |
76 | // We have to do convoluted logic of passing in a vector of py::bytes. If we |
77 | // pass in strings they are freed prior to the necessary function calls. |
78 | tensorflow::NameVector ConvertPyListToNameVector( |
79 | const std::vector<py::bytes>& py_vector) { |
80 | tensorflow::NameVector temp; |
81 | for (size_t i = 0; i < py_vector.size(); ++i) { |
82 | const char* string_elem = PyBytes_AsString(py_vector.at(i).ptr()); |
83 | temp.push_back(string_elem); |
84 | } |
85 | return temp; |
86 | } |
87 | |
88 | namespace py = pybind11; |
89 | |
90 | PYBIND11_MAKE_OPAQUE(TF_Graph); |
91 | PYBIND11_MAKE_OPAQUE(TF_Session); |
92 | PYBIND11_MAKE_OPAQUE(TF_Operation); |
93 | PYBIND11_MAKE_OPAQUE(TF_Buffer); |
94 | PYBIND11_MAKE_OPAQUE(TF_ImportGraphDefOptions); |
95 | PYBIND11_MAKE_OPAQUE(TF_ImportGraphDefResults); |
96 | PYBIND11_MAKE_OPAQUE(TF_DeprecatedSession); |
97 | PYBIND11_MAKE_OPAQUE(TF_OperationDescription); |
98 | PYBIND11_MAKE_OPAQUE(TF_Library); |
99 | PYBIND11_MAKE_OPAQUE(TF_SessionOptions); |
100 | PYBIND11_MAKE_OPAQUE(TF_ApiDefMap); |
101 | PYBIND11_MAKE_OPAQUE(TF_Server); |
102 | PYBIND11_MAKE_OPAQUE(TF_DeviceList); |
103 | PYBIND11_MAKE_OPAQUE(TF_Status); |
104 | |
105 | PYBIND11_MODULE(_pywrap_tf_session, m) { |
106 | // Numpy initialization code for array checks. |
107 | tensorflow::ImportNumpy(); |
108 | |
109 | py::class_<TF_Graph> TF_Graph_class(m, "TF_Graph" ); |
110 | py::class_<TF_Operation> TF_Operation_class(m, "TF_Operation" ); |
111 | |
112 | py::class_<TF_Output>(m, "TF_Output" ) |
113 | .def(py::init<>()) |
114 | .def_readwrite("oper" , &TF_Output::oper) |
115 | .def_readwrite("index" , &TF_Output::index); |
116 | |
117 | py::class_<TF_Input>(m, "TF_Input" ) |
118 | .def(py::init<>()) |
119 | .def_readwrite("oper" , &TF_Input::oper) |
120 | .def_readwrite("index" , &TF_Input::index); |
121 | |
122 | py::class_<TF_ImportGraphDefOptions> TF_ImportGraphDefOptions_class( |
123 | m, "TF_ImportGraphDefOptions" ); |
124 | py::class_<TF_ImportGraphDefResults> TF_ImportGraphDefResults_class( |
125 | m, "TF_ImportGraphDefResults" ); |
126 | py::class_<TF_DeprecatedSession> TF_DeprecatedSession_class( |
127 | m, "TF_DeprecatedSession" ); |
128 | py::class_<TF_Session> TF_Session_class(m, "TF_Session" ); |
129 | py::class_<TF_OperationDescription> TF_OperationDescription_class( |
130 | m, "TF_OperationDescription" ); |
131 | py::class_<TF_Library> TF_Library_class(m, "TF_Library" ); |
132 | py::class_<TF_SessionOptions> TF_SessionOptions_class(m, "TF_SessionOptions" ); |
133 | py::class_<TF_Buffer> TF_Buffer_class(m, "TF_Buffer" ); |
134 | py::class_<TF_ApiDefMap> TF_ApiDefMap_class(m, "TF_ApiDefMap" ); |
135 | py::class_<TF_Server> TF_Server_class(m, "TF_Server" ); |
136 | py::class_<TF_Status> TF_Status_class(m, "TF_Status" ); |
137 | |
// We only release the Python GIL for certain methods that are not explicitly
// marked "Do not release GIL". We disable this behavior for some functions
// because they use Python method(s) that expect the GIL to be held
// (at least PyArray_Return, maybe others).
142 | |
143 | // Do not release GIL. |
144 | m.def("TF_OperationGetControlInputs_wrapper" , |
145 | tensorflow::TF_OperationGetControlInputs_wrapper); |
146 | // Do not release GIL. |
147 | m.def("TF_OperationGetControlOutputs_wrapper" , |
148 | tensorflow::TF_OperationGetControlOutputs_wrapper); |
149 | m.def("TF_OperationOutputConsumers_wrapper" , |
150 | tensorflow::TF_OperationOutputConsumers_wrapper); |
151 | // Do not release GIL. |
152 | m.def("GetOperationInputs" , tensorflow::GetOperationInputs); |
153 | // Do not release GIL. |
154 | m.def("TF_ImportGraphDefOptionsSetValidateColocationConstraints" , |
155 | TF_ImportGraphDefOptionsSetValidateColocationConstraints); |
156 | // Do not release GIL. |
157 | m.def("TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper" , |
158 | tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper); |
159 | m.def("TF_SessionMakeCallable" , |
160 | [](TF_Session* session, const TF_Buffer* callable_options) { |
161 | int64_t out_handle; |
162 | tensorflow::Safe_TF_StatusPtr status = |
163 | tensorflow::make_safe(TF_NewStatus()); |
164 | |
165 | // Release GIL. |
166 | py::gil_scoped_release release; |
167 | tensorflow::TF_SessionMakeCallable(session, callable_options, |
168 | &out_handle, status.get()); |
169 | |
170 | // Acquire GIL for returning int conversion. |
171 | pybind11::gil_scoped_acquire acquire; |
172 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
173 | return out_handle; |
174 | }); |
175 | m.def("_TF_SetTarget" , TF_SetTarget); |
176 | m.def("_TF_SetConfig" , [](TF_SessionOptions* options, py::bytes proto) { |
177 | tensorflow::Safe_TF_StatusPtr status = |
178 | tensorflow::make_safe(TF_NewStatus()); |
179 | tensorflow::Safe_TF_BufferPtr buf = |
180 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); |
181 | TF_SetConfig(options, buf.get()->data, buf.get()->length, status.get()); |
182 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
183 | }); |
184 | m.def("_TF_NewSessionOptions" , TF_NewSessionOptions, |
185 | py::return_value_policy::reference, |
186 | py::call_guard<py::gil_scoped_release>()); |
187 | m.def("TF_DeleteSessionOptions" , TF_DeleteSessionOptions, |
188 | py::call_guard<py::gil_scoped_release>()); |
189 | |
190 | m.def("EqualGraphDefWrapper" , tensorflow::EqualGraphDefWrapper, |
191 | py::call_guard<py::gil_scoped_release>()); |
192 | m.def("EqualAttrValueWrapper" , tensorflow::EqualAttrValueWrapper, |
193 | py::call_guard<py::gil_scoped_release>()); |
194 | |
195 | m.def( |
196 | "TF_GraphToFunction_wrapper" , |
197 | [](const TF_Graph* fn_body, const char* fn_name, |
198 | bool append_hash_to_fn_name, |
199 | absl::optional<std::vector<TF_Operation*>> opers_opt, |
200 | const std::vector<TF_Output>& inputs, |
201 | const std::vector<TF_Output>& outputs, |
202 | const std::vector<py::bytes> output_names, |
203 | const std::vector<TF_Operation*> control_outputs, |
204 | const std::vector<py::bytes> control_output_names, py::none opts, |
205 | const char* description) { |
206 | tensorflow::Safe_TF_StatusPtr status = |
207 | tensorflow::make_safe(TF_NewStatus()); |
208 | |
209 | // TODO(b/147674626): Use pybind11 list_caster instead. |
210 | tensorflow::NameVector output_names_name_vector = |
211 | ConvertPyListToNameVector(output_names); |
212 | |
213 | // TODO(b/147674626): Use pybind11 list_caster instead. |
214 | tensorflow::NameVector control_output_names_name_vector = |
215 | ConvertPyListToNameVector(control_output_names); |
216 | |
217 | // Release GIL. |
218 | py::gil_scoped_release release; |
219 | auto output = tensorflow::TF_GraphToFunction_wrapper( |
220 | fn_body, fn_name, append_hash_to_fn_name, |
221 | opers_opt.has_value() ? &opers_opt.value() : nullptr, inputs, |
222 | outputs, output_names_name_vector, &control_outputs, |
223 | control_output_names_name_vector, |
224 | /*opts=*/nullptr, description, status.get()); |
225 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
226 | return output; |
227 | }, |
228 | py::return_value_policy::reference); |
229 | |
230 | m.def("TF_GraphGetTensorShapeHelper" , [](TF_Graph* graph, TF_Output output) { |
231 | tensorflow::Safe_TF_StatusPtr status = |
232 | tensorflow::make_safe(TF_NewStatus()); |
233 | bool unknown_shape; |
234 | |
235 | auto result = tensorflow::TF_GraphGetTensorShapeHelper( |
236 | graph, output, status.get(), &unknown_shape); |
237 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
238 | |
239 | // Create a python list from InlinedVector |
240 | py::list py_list; |
241 | for (size_t i = 0; i < result.size(); ++i) { |
242 | py_list.append(py::cast(result[i])); |
243 | } |
244 | |
245 | // Return a tuple. |
246 | py::tuple result_tuple = py::make_tuple(py_list, py::cast(unknown_shape)); |
247 | return result_tuple; |
248 | }); |
249 | |
250 | m.def("TF_GraphSetTensorShape_wrapper" , |
251 | [](TF_Graph* graph, TF_Output output, const std::vector<int64_t>& dims, |
252 | bool unknown_shape) { |
253 | tensorflow::Safe_TF_StatusPtr status = |
254 | tensorflow::make_safe(TF_NewStatus()); |
255 | |
256 | // Release GIL. |
257 | py::gil_scoped_release release; |
258 | tensorflow::TF_GraphSetTensorShape_wrapper( |
259 | graph, output, dims, unknown_shape, status.get()); |
260 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
261 | }); |
262 | |
263 | m.def("TF_GraphGetTensorShape_wrapper" , |
264 | [](TF_Graph* graph, TF_Output output, const std::vector<int64_t>& dims, |
265 | bool unknown_shape) { |
266 | tensorflow::Safe_TF_StatusPtr status = |
267 | tensorflow::make_safe(TF_NewStatus()); |
268 | // Release GIL. |
269 | py::gil_scoped_release release; |
270 | tensorflow::TF_GraphSetTensorShape_wrapper( |
271 | graph, output, dims, unknown_shape, status.get()); |
272 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
273 | }); |
274 | |
275 | m.def("TF_GraphSetOutputHandleShapesAndTypes_wrapper" , |
276 | [](TF_Graph* graph, TF_Output output, |
277 | const std::vector<absl::optional<std::vector<int64_t>>>& shapes, |
278 | const std::vector<int>& ranks, py::handle& types) { |
279 | tensorflow::Safe_TF_StatusPtr status = |
280 | tensorflow::make_safe(TF_NewStatus()); |
281 | |
282 | // Cast types |
283 | std::vector<TF_DataType> types_local; |
284 | PyObject* seq = |
285 | PySequence_Fast(types.ptr(), "$symname: expected list" ); |
286 | if (seq == nullptr) { |
287 | PyErr_SetString(PyExc_RuntimeError, |
288 | "$symname: PySequence_Fast returned NULL." ); |
289 | throw py::error_already_set(); |
290 | } |
291 | |
292 | int size = PySequence_Fast_GET_SIZE(seq); |
293 | if (size == 0) { |
294 | PyErr_SetString(PyExc_ValueError, |
295 | "$symname: shapes list must be non-empty" ); |
296 | throw py::error_already_set(); |
297 | } |
298 | |
299 | for (int i = 0; i < size; ++i) { |
300 | PyObject* item = PySequence_Fast_GET_ITEM(seq, i); |
301 | types_local.push_back((TF_DataType)PyLong_AsLong(item)); |
302 | } |
303 | |
304 | // Convert shapes nested vector |
305 | std::vector<std::vector<int64_t>> shapes_local; |
306 | for (size_t i = 0; i < shapes.size(); ++i) { |
307 | std::vector<int64_t> dims; |
308 | std::vector<int64_t> item = |
309 | shapes[i].has_value() ? shapes[i].value() : dims; |
310 | shapes_local.push_back(item); |
311 | } |
312 | |
313 | Py_DECREF(seq); |
314 | |
315 | tensorflow::TF_GraphSetOutputHandleShapesAndTypes_wrapper( |
316 | graph, output, shapes_local, ranks, types_local, status.get()); |
317 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
318 | }); |
319 | |
320 | // Do not release GIL. |
321 | m.def("TF_CreatePlaceholders" , |
322 | [](TF_Graph* graph, py::handle& dtypes, const char* prefix) { |
323 | tensorflow::Safe_TF_StatusPtr status = |
324 | tensorflow::make_safe(TF_NewStatus()); |
325 | auto output = tensorflow::TF_CreatePlaceholders(graph, dtypes.ptr(), |
326 | prefix, status.get()); |
327 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
328 | return output; |
329 | }); |
330 | |
331 | m.def( |
332 | "TF_NewSession" , |
333 | [](TF_Graph* graph, const TF_SessionOptions* opts) { |
334 | tensorflow::Safe_TF_StatusPtr status = |
335 | tensorflow::make_safe(TF_NewStatus()); |
336 | // Release GIL. |
337 | py::gil_scoped_release release; |
338 | auto output = TF_NewSession(graph, opts, status.get()); |
339 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
340 | return output; |
341 | }, |
342 | py::return_value_policy::reference); |
343 | |
344 | m.def( |
345 | "TF_NewSessionRef" , |
346 | [](TF_Graph* graph, const TF_SessionOptions* opts) { |
347 | tensorflow::Safe_TF_StatusPtr status = |
348 | tensorflow::make_safe(TF_NewStatus()); |
349 | // Release GIL. |
350 | py::gil_scoped_release release; |
351 | auto output = tensorflow::TF_NewSessionRef(graph, opts, status.get()); |
352 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
353 | return output; |
354 | }, |
355 | py::return_value_policy::reference); |
356 | |
357 | m.def("TF_CloseSession" , [](TF_Session* session) { |
358 | tensorflow::Safe_TF_StatusPtr status = |
359 | tensorflow::make_safe(TF_NewStatus()); |
360 | |
361 | // Release GIL. |
362 | py::gil_scoped_release release; |
363 | TF_CloseSession(session, status.get()); |
364 | |
365 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
366 | }); |
367 | |
368 | m.def("TF_DeleteSession" , [](TF_Session* session) { |
369 | tensorflow::Safe_TF_StatusPtr status = |
370 | tensorflow::make_safe(TF_NewStatus()); |
371 | // Release GIL. |
372 | py::gil_scoped_release release; |
373 | TF_DeleteSession(session, status.get()); |
374 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
375 | }); |
376 | |
377 | m.def("SetRequireShapeInferenceFns" , tensorflow::SetRequireShapeInferenceFns); |
378 | |
379 | // Do not release GIL. |
380 | m.def("TF_TryEvaluateConstant_wrapper" , |
381 | [](TF_Graph* graph, const TF_Output output) { |
382 | tensorflow::Safe_TF_StatusPtr status = |
383 | tensorflow::make_safe(TF_NewStatus()); |
384 | auto result = tensorflow::TF_TryEvaluateConstant_wrapper( |
385 | graph, output, status.get()); |
386 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
387 | return tensorflow::PyoOrThrow(result); |
388 | }); |
389 | |
390 | m.def("ExtendSession" , [](TF_Session* session) { |
391 | tensorflow::Safe_TF_StatusPtr status = |
392 | tensorflow::make_safe(TF_NewStatus()); |
393 | // Release GIL for threading. |
394 | pybind11::gil_scoped_release release; |
395 | tensorflow::ExtendSession(session, status.get()); |
396 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
397 | }); |
398 | |
399 | m.def("GetHandleShapeAndType" , [](TF_Graph* graph, TF_Output output) { |
400 | std::string output_string = |
401 | tensorflow::GetHandleShapeAndType(graph, output); |
402 | // Override default py3 behavior of attempting to encode into Unicode as |
403 | // the dependent functions expect bytes. |
404 | return py::bytes(output_string); |
405 | }); |
406 | |
407 | m.def("SetHandleShapeAndType" , |
408 | [](TF_Graph* graph, TF_Output output, py::bytes proto) { |
409 | tensorflow::Safe_TF_StatusPtr status = |
410 | tensorflow::make_safe(TF_NewStatus()); |
411 | tensorflow::Safe_TF_BufferPtr buf = |
412 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); |
413 | tensorflow::SetHandleShapeAndType(graph, output, buf.get()->data, |
414 | buf.get()->length, status.get()); |
415 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
416 | }); |
417 | |
418 | // Do not release GIL. |
419 | m.def("TF_SessionRun_wrapper" , [](TF_Session* session, TF_Buffer* run_options, |
420 | const py::handle& input_dict, |
421 | const std::vector<TF_Output>& outputs, |
422 | const std::vector<TF_Operation*>& targets, |
423 | TF_Buffer* run_metadata) { |
424 | // Convert inputs dictionary |
425 | std::vector<TF_Output> inputs; |
426 | std::vector<PyObject*> input_ndarrays; |
427 | if (!PyDict_Check(input_dict.ptr())) { |
428 | PyErr_SetString( |
429 | PyExc_TypeError, |
430 | "Expected a dictionary as an argument to TF_SessionRun_wrapper." ); |
431 | throw py::error_already_set(); |
432 | } |
433 | PyObject* key; |
434 | PyObject* value; |
435 | Py_ssize_t pos = 0; |
436 | while (PyDict_Next(input_dict.ptr(), &pos, &key, &value)) { |
437 | TF_Output item = py::cast<TF_Output>(key); |
438 | inputs.push_back(item); |
439 | |
440 | // TODO(amitpatankar): Fix this PyArray check. (b/147855599) |
441 | |
442 | // if (!PyArray_Check(value)) { |
443 | // PyErr_SetString( |
444 | // PyExc_TypeError, |
445 | // "$symname: Expected all values in input dict to be ndarray."); |
446 | // throw py::error_already_set(); |
447 | // } |
448 | input_ndarrays.push_back(value); |
449 | } |
450 | |
451 | tensorflow::Safe_TF_StatusPtr status = |
452 | tensorflow::make_safe(TF_NewStatus()); |
453 | std::vector<PyObject*> py_outputs; |
454 | tensorflow::TF_SessionRun_wrapper(session, run_options, inputs, |
455 | input_ndarrays, outputs, targets, |
456 | run_metadata, status.get(), &py_outputs); |
457 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
458 | |
459 | // Create a Python list using the C API rather than py::list. b/147855599 |
460 | PyObject* result = PyList_New(py_outputs.size()); |
461 | if (result == nullptr) { |
462 | PyErr_SetString(PyExc_MemoryError, "Failed to create a list." ); |
463 | throw py::error_already_set(); |
464 | } |
465 | for (size_t i = 0; i < py_outputs.size(); ++i) { |
466 | PyList_SET_ITEM(result, i, py_outputs.at(i)); |
467 | } |
468 | |
469 | return tensorflow::PyoOrThrow(result); |
470 | }); |
471 | |
472 | // Do not release GIL. |
473 | m.def("TF_SessionPRun_wrapper" , [](TF_Session* session, const char* handle, |
474 | const py::handle& input_dict, |
475 | const std::vector<TF_Output>& outputs) { |
476 | // Convert inputs dictionary |
477 | std::vector<TF_Output> inputs; |
478 | std::vector<PyObject*> input_ndarrays; |
479 | if (!PyDict_Check(input_dict.ptr())) { |
480 | PyErr_SetString( |
481 | PyExc_TypeError, |
482 | "Expected a dictionary as an argument to TF_SessionPRun_wrapper." ); |
483 | throw py::error_already_set(); |
484 | } |
485 | PyObject* key; |
486 | PyObject* value; |
487 | Py_ssize_t pos = 0; |
488 | while (PyDict_Next(input_dict.ptr(), &pos, &key, &value)) { |
489 | TF_Output item = py::cast<TF_Output>(key); |
490 | inputs.push_back(item); |
491 | |
492 | // TODO(amitpatankar): Fix this PyArray check. (b/147855599) |
493 | |
494 | // if (!PyArray_Check(value)) { |
495 | // PyErr_SetString( |
496 | // PyExc_TypeError, |
497 | // "$symname: Expected all values in input dict to be ndarray."); |
498 | // throw py::error_already_set(); |
499 | // } |
500 | input_ndarrays.push_back(value); |
501 | } |
502 | |
503 | tensorflow::Safe_TF_StatusPtr status = |
504 | tensorflow::make_safe(TF_NewStatus()); |
505 | std::vector<PyObject*> py_outputs; |
506 | tensorflow::TF_SessionPRun_wrapper(session, handle, inputs, input_ndarrays, |
507 | outputs, status.get(), &py_outputs); |
508 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
509 | |
510 | PyObject* result = PyList_New(py_outputs.size()); |
511 | if (result == nullptr) { |
512 | PyErr_SetString(PyExc_MemoryError, "Failed to create a list." ); |
513 | throw py::error_already_set(); |
514 | } |
515 | for (size_t i = 0; i < py_outputs.size(); ++i) { |
516 | PyList_SET_ITEM(result, i, py_outputs.at(i)); |
517 | } |
518 | |
519 | return tensorflow::PyoOrThrow(result); |
520 | }); |
521 | |
522 | // Do not release GIL. |
523 | m.def("TF_SessionPRunSetup_wrapper" , |
524 | [](TF_Session* session, const std::vector<TF_Output>& inputs, |
525 | const std::vector<TF_Output>& outputs, |
526 | const std::vector<TF_Operation*>& targets) { |
527 | tensorflow::Safe_TF_StatusPtr status = |
528 | tensorflow::make_safe(TF_NewStatus()); |
529 | const char* out_handle; |
530 | tensorflow::TF_SessionPRunSetup_wrapper( |
531 | session, inputs, outputs, targets, &out_handle, status.get()); |
532 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
533 | return out_handle; |
534 | }); |
535 | |
536 | // Do not release GIL. |
537 | m.def("TF_SessionRunCallable" , [](TF_Session* session, int64_t handle, |
538 | py::object feed_values, |
539 | TF_Buffer* run_metadata) { |
540 | tensorflow::PyObjectVector out_values; |
541 | tensorflow::Safe_TF_StatusPtr status = |
542 | tensorflow::make_safe(TF_NewStatus()); |
543 | tensorflow::TF_SessionRunCallable(session, handle, feed_values.ptr(), |
544 | &out_values, run_metadata, status.get()); |
545 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
546 | |
547 | // Return out_values |
548 | py::list py_list; |
549 | for (size_t i = 0; i < out_values.size(); ++i) { |
550 | py::object obj = tensorflow::Pyo(out_values.at(i)); |
551 | py_list.append(obj); |
552 | } |
553 | return py_list; |
554 | }); |
555 | |
556 | m.def("TF_SessionReleaseCallable" , [](TF_Session* session, int64_t handle) { |
557 | tensorflow::Safe_TF_StatusPtr status = |
558 | tensorflow::make_safe(TF_NewStatus()); |
559 | // Release GIL. |
560 | py::gil_scoped_release release; |
561 | tensorflow::TF_SessionReleaseCallable(session, handle, status.get()); |
562 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
563 | }); |
564 | |
565 | m.def("TF_NewGraph" , TF_NewGraph, py::return_value_policy::reference, |
566 | py::call_guard<py::gil_scoped_release>()); |
567 | m.def("TF_DeleteGraph" , TF_DeleteGraph, |
568 | py::call_guard<py::gil_scoped_release>()); |
569 | |
570 | m.def("TF_GraphGetOpDef" , |
571 | [](TF_Graph* graph, const char* op_name, TF_Buffer* output_op_def) { |
572 | tensorflow::Safe_TF_StatusPtr status = |
573 | tensorflow::make_safe(TF_NewStatus()); |
574 | // Release GIL. |
575 | py::gil_scoped_release release; |
576 | TF_GraphGetOpDef(graph, op_name, output_op_def, status.get()); |
577 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
578 | }); |
579 | |
580 | m.def( |
581 | "TF_NewOperation" , |
582 | [](TF_Graph* graph, const char* op_type, const char* oper_name) { |
583 | tensorflow::Safe_TF_StatusPtr status = |
584 | tensorflow::make_safe(TF_NewStatus()); |
585 | // Release GIL. |
586 | py::gil_scoped_release release; |
587 | TF_OperationDescription* output = |
588 | TF_NewOperation(graph, op_type, oper_name); |
589 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
590 | return output; |
591 | }, |
592 | py::return_value_policy::reference); |
593 | |
594 | m.def( |
595 | "TF_FinishOperation" , |
596 | [](TF_OperationDescription* desc) { |
597 | tensorflow::Safe_TF_StatusPtr status = |
598 | tensorflow::make_safe(TF_NewStatus()); |
599 | // Release GIL. |
600 | py::gil_scoped_release release; |
601 | TF_Operation* output = TF_FinishOperation(desc, status.get()); |
602 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
603 | return output; |
604 | }, |
605 | py::return_value_policy::reference); |
606 | |
607 | m.def("TF_OperationGetAttrInt" , |
608 | [](TF_Operation* oper, const char* attr_name) { |
609 | tensorflow::Safe_TF_StatusPtr status = |
610 | tensorflow::make_safe(TF_NewStatus()); |
611 | int64_t value; |
612 | // Release GIL. |
613 | py::gil_scoped_release release; |
614 | TF_OperationGetAttrInt(oper, attr_name, &value, status.get()); |
615 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
616 | // Convert TF_OperationGetAttrInt int64_t* out-argument to Python |
617 | // bool. |
618 | // Acquire GIL for returning output returning. |
619 | pybind11::gil_scoped_acquire acquire; |
620 | return tensorflow::Pyo(PyLong_FromLongLong(value)); |
621 | }); |
622 | |
623 | m.def("TF_SetAttrValueProto" , [](TF_OperationDescription* desc, |
624 | const char* attr_name, py::bytes proto) { |
625 | tensorflow::Safe_TF_StatusPtr status = |
626 | tensorflow::make_safe(TF_NewStatus()); |
627 | tensorflow::Safe_TF_BufferPtr buf = |
628 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); |
629 | TF_SetAttrValueProto(desc, attr_name, buf.get()->data, buf.get()->length, |
630 | status.get()); |
631 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
632 | }); |
633 | |
634 | m.def("TF_OperationNumOutputs" , TF_OperationNumOutputs, |
635 | py::call_guard<py::gil_scoped_release>()); |
636 | |
637 | // Convert types to ints |
638 | m.def("TF_OperationInputType" , TF_OperationInputType, |
639 | py::call_guard<py::gil_scoped_release>()); |
640 | m.def("TF_OperationOutputType" , TF_OperationOutputType, |
641 | py::call_guard<py::gil_scoped_release>()); |
642 | |
643 | m.def("TF_OperationName" , TF_OperationName, |
644 | py::call_guard<py::gil_scoped_release>()); |
645 | m.def("TF_OperationOpType" , TF_OperationOpType, |
646 | py::call_guard<py::gil_scoped_release>()); |
647 | m.def("TF_OperationDevice" , TF_OperationDevice, |
648 | py::call_guard<py::gil_scoped_release>()); |
649 | |
650 | m.def("TF_AddInput" , TF_AddInput); |
651 | |
652 | m.def("TF_OperationToNodeDef" , |
653 | [](TF_Operation* oper, TF_Buffer* output_node_def) { |
654 | tensorflow::Safe_TF_StatusPtr status = |
655 | tensorflow::make_safe(TF_NewStatus()); |
656 | TF_OperationToNodeDef(oper, output_node_def, status.get()); |
657 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
658 | }); |
659 | |
660 | m.def("TF_OperationGetAttrValueProto" , |
661 | [](TF_Operation* oper, const char* attr_name, |
662 | TF_Buffer* output_attr_value) { |
663 | tensorflow::Safe_TF_StatusPtr status = |
664 | tensorflow::make_safe(TF_NewStatus()); |
665 | TF_OperationGetAttrValueProto(oper, attr_name, output_attr_value, |
666 | status.get()); |
667 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
668 | }); |
669 | |
670 | m.def("TF_OperationGetStackTrace" , [](TF_Operation* oper) -> py::object { |
671 | const std::shared_ptr<tensorflow::AbstractStackTrace> trace = |
672 | oper->node.GetStackTrace(); |
673 | if (!trace) { |
674 | return py::none(); |
675 | } |
676 | return py::cast(*trace, py::return_value_policy::reference); |
677 | }); |
678 | |
679 | m.def("SetRequestedDevice" , tensorflow::SetRequestedDevice); |
680 | |
681 | // TF_Buffer util methods |
682 | // TODO(amitpatankar): Consolidate Buffer methods into a separate header file. |
683 | m.def("TF_NewBuffer" , TF_NewBuffer, py::return_value_policy::reference); |
684 | m.def("TF_GetBuffer" , [](TF_Buffer* buf) { |
685 | TF_Buffer buffer = TF_GetBuffer(buf); |
686 | return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize( |
687 | reinterpret_cast<const char*>(buffer.data), buffer.length)); |
688 | }); |
689 | m.def("TF_DeleteBuffer" , &TF_DeleteBuffer); |
690 | m.def( |
691 | "TF_NewBufferFromString" , |
692 | [](py::bytes buffer_as_string) { |
693 | tensorflow::Safe_TF_BufferPtr buf = tensorflow::make_safe( |
694 | ProtoStringToTFBuffer(buffer_as_string.ptr())); |
695 | return TF_NewBufferFromString(buf.get()->data, buf.get()->length); |
696 | }, |
697 | py::return_value_policy::reference); |
698 | |
699 | m.def("SetAttr" , [](TF_Graph* graph, TF_Operation* op, const char* attr_name, |
700 | TF_Buffer* attr_value_proto) { |
701 | tensorflow::Safe_TF_StatusPtr status = |
702 | tensorflow::make_safe(TF_NewStatus()); |
703 | // Release GIL. |
704 | py::gil_scoped_release release; |
705 | tensorflow::SetAttr(graph, op, attr_name, attr_value_proto, status.get()); |
706 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
707 | }); |
708 | |
709 | m.def("ClearAttr" , |
710 | [](TF_Graph* graph, TF_Operation* op, const char* attr_name) { |
711 | tensorflow::Safe_TF_StatusPtr status = |
712 | tensorflow::make_safe(TF_NewStatus()); |
713 | // Release GIL. |
714 | py::gil_scoped_release release; |
715 | tensorflow::ClearAttr(graph, op, attr_name, status.get()); |
716 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
717 | }); |
718 | |
719 | // Note: users should prefer using tf.cast or equivalent, and only when |
720 | // it's infeasible to set the type via OpDef's type constructor and inference |
721 | // function. |
722 | m.def("SetFullType" , [](TF_Graph* graph, TF_Operation* op, |
723 | const std::string& serialized_full_type) { |
724 | tensorflow::FullTypeDef proto; |
725 | proto.ParseFromString(serialized_full_type); |
726 | tensorflow::SetFullType(graph, op, proto); |
727 | }); |
728 | |
729 | m.def( |
730 | "TF_LoadLibrary" , |
731 | [](const char* library_filename) { |
732 | tensorflow::Safe_TF_StatusPtr status = |
733 | tensorflow::make_safe(TF_NewStatus()); |
734 | auto output = TF_LoadLibrary(library_filename, status.get()); |
735 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
736 | return output; |
737 | }, |
738 | py::return_value_policy::reference); |
739 | |
740 | m.def( |
741 | "TF_LoadPluggableDeviceLibrary" , |
742 | [](const char* library_filename) { |
743 | tensorflow::Safe_TF_StatusPtr status = |
744 | tensorflow::make_safe(TF_NewStatus()); |
745 | auto output = |
746 | TF_LoadPluggableDeviceLibrary(library_filename, status.get()); |
747 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
748 | return output; |
749 | }, |
750 | py::return_value_policy::reference); |
751 | |
752 | m.def("TF_GetOpList" , [](TF_Library* lib_handle) { |
753 | TF_Buffer output_buffer = TF_GetOpList(lib_handle); |
754 | return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize( |
755 | reinterpret_cast<const char*>(output_buffer.data), |
756 | output_buffer.length)); |
757 | }); |
758 | |
759 | m.def("TF_DeleteLibraryHandle" , TF_DeleteLibraryHandle, |
760 | py::call_guard<py::gil_scoped_release>()); |
761 | |
762 | m.def("TF_PluggableDeviceLibraryHandle" , |
763 | TF_DeletePluggableDeviceLibraryHandle, |
764 | py::call_guard<py::gil_scoped_release>()); |
765 | |
766 | m.def("TF_AddControlInput" , TF_AddControlInput); |
767 | m.def( |
768 | "TF_AddInputList" , [](TF_OperationDescription* desc, py::handle& inputs) { |
769 | std::vector<TF_Output> vec; |
770 | size_t size = PyList_Size(inputs.ptr()); |
771 | for (size_t i = 0; i < size; ++i) { |
772 | TF_Output item = py::cast<TF_Output>(PyList_GetItem(inputs.ptr(), i)); |
773 | vec.push_back(item); |
774 | } |
775 | TF_AddInputList(desc, vec.data(), vec.size()); |
776 | }); |
777 | |
778 | m.def("UpdateEdge" , [](TF_Graph* graph, TF_Output new_src, TF_Input dst) { |
779 | tensorflow::Safe_TF_StatusPtr status = |
780 | tensorflow::make_safe(TF_NewStatus()); |
781 | // Release GIL. |
782 | py::gil_scoped_release release; |
783 | tensorflow::UpdateEdge(graph, new_src, dst, status.get()); |
784 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
785 | }); |
786 | |
787 | m.def("RemoveAllControlInputs" , tensorflow::RemoveAllControlInputs, |
788 | py::call_guard<py::gil_scoped_release>()); |
789 | m.def("AddControlInput" , tensorflow::AddControlInput, |
790 | py::call_guard<py::gil_scoped_release>()); |
791 | |
792 | m.def("TF_NewImportGraphDefOptions" , TF_NewImportGraphDefOptions, |
793 | py::return_value_policy::reference, |
794 | py::call_guard<py::gil_scoped_release>()); |
795 | m.def("TF_ImportGraphDefOptionsSetPrefix" , TF_ImportGraphDefOptionsSetPrefix, |
796 | py::call_guard<py::gil_scoped_release>()); |
797 | m.def("TF_ImportGraphDefOptionsSetUniquifyNames" , |
798 | TF_ImportGraphDefOptionsSetUniquifyNames, |
799 | py::call_guard<py::gil_scoped_release>()); |
800 | m.def("TF_ImportGraphDefOptionsRemapControlDependency" , |
801 | TF_ImportGraphDefOptionsRemapControlDependency, |
802 | py::call_guard<py::gil_scoped_release>()); |
803 | m.def("TF_ImportGraphDefOptionsAddInputMapping" , |
804 | TF_ImportGraphDefOptionsAddInputMapping, |
805 | py::call_guard<py::gil_scoped_release>()); |
806 | m.def("TF_ImportGraphDefOptionsAddReturnOperation" , |
807 | TF_ImportGraphDefOptionsAddReturnOperation, |
808 | py::call_guard<py::gil_scoped_release>()); |
809 | m.def("TF_ImportGraphDefOptionsAddReturnOutput" , |
810 | TF_ImportGraphDefOptionsAddReturnOutput, |
811 | py::call_guard<py::gil_scoped_release>()); |
812 | |
813 | m.def( |
814 | "TF_GraphImportGraphDefWithResults" , |
815 | [](TF_Graph* graph, const TF_Buffer* graph_def, |
816 | const TF_ImportGraphDefOptions* options) { |
817 | tensorflow::Safe_TF_StatusPtr status = |
818 | tensorflow::make_safe(TF_NewStatus()); |
819 | auto output = TF_GraphImportGraphDefWithResults(graph, graph_def, |
820 | options, status.get()); |
821 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
822 | return output; |
823 | }, |
824 | py::return_value_policy::reference); |
825 | |
826 | m.def( |
827 | "TF_GraphNextOperation" , |
828 | [](TF_Graph* graph, size_t pos) { |
829 | tensorflow::Safe_TF_StatusPtr status = |
830 | tensorflow::make_safe(TF_NewStatus()); |
831 | auto output = TF_GraphNextOperation(graph, &pos); |
832 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
833 | |
834 | // Returns a (TF_Operation*, int pos) tuple. |
835 | py::tuple result_tuple = py::make_tuple( |
836 | py::cast(output), tensorflow::Pyo(PyLong_FromSize_t(pos))); |
837 | return result_tuple; |
838 | }, |
839 | py::return_value_policy::reference); |
840 | |
841 | // Python needs to own deletion of outputs |
842 | m.def("TF_ImportGraphDefResultsReturnOutputs" , |
843 | [](TF_ImportGraphDefResults* results) { |
844 | int num_outputs; |
845 | TF_Output* outputs; |
846 | TF_ImportGraphDefResultsReturnOutputs(results, &num_outputs, |
847 | &outputs); |
848 | py::list py_list; |
849 | for (int i = 0; i < num_outputs; ++i) { |
850 | TF_Output tf_output = TF_Output(outputs[i]); |
851 | py_list.append(tf_output); |
852 | } |
853 | return py_list; |
854 | }); |
855 | |
856 | m.def( |
857 | "TF_ImportGraphDefResultsReturnOperations" , |
858 | [](TF_ImportGraphDefResults* results) { |
859 | int num_opers; |
860 | TF_Operation** opers; |
861 | TF_ImportGraphDefResultsReturnOperations(results, &num_opers, &opers); |
862 | py::list py_list; |
863 | for (int i = 0; i < num_opers; ++i) { |
864 | py_list.append(opers[i]); |
865 | } |
866 | return py_list; |
867 | }, |
868 | py::return_value_policy::reference); |
869 | |
870 | m.def("TF_GraphToGraphDef" , [](TF_Graph* graph, TF_Buffer* output_graph_def) { |
871 | tensorflow::Safe_TF_StatusPtr status = |
872 | tensorflow::make_safe(TF_NewStatus()); |
873 | // Release GIL. |
874 | py::gil_scoped_release release; |
875 | TF_GraphToGraphDef(graph, output_graph_def, status.get()); |
876 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
877 | }); |
878 | |
879 | m.def("TF_OperationNumInputs" , TF_OperationNumInputs, |
880 | py::call_guard<py::gil_scoped_release>()); |
881 | |
882 | m.def("TF_GraphVersions" , [](TF_Graph* graph, TF_Buffer* output_graph_def) { |
883 | tensorflow::Safe_TF_StatusPtr status = |
884 | tensorflow::make_safe(TF_NewStatus()); |
885 | // Release GIL. |
886 | py::gil_scoped_release release; |
887 | TF_GraphVersions(graph, output_graph_def, status.get()); |
888 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
889 | }); |
890 | |
891 | m.def("TF_DeleteFunction" , TF_DeleteFunction, |
892 | py::call_guard<py::gil_scoped_release>()); |
893 | m.def("TF_DeleteImportGraphDefResults" , TF_DeleteImportGraphDefResults, |
894 | py::call_guard<py::gil_scoped_release>()); |
895 | m.def("TF_DeleteImportGraphDefOptions" , TF_DeleteImportGraphDefOptions, |
896 | py::call_guard<py::gil_scoped_release>()); |
897 | |
898 | m.def("TF_FunctionSetAttrValueProto" , |
899 | [](TF_Function* func, const char* attr_name, py::bytes proto) { |
900 | tensorflow::Safe_TF_StatusPtr status = |
901 | tensorflow::make_safe(TF_NewStatus()); |
902 | tensorflow::Safe_TF_BufferPtr buf = |
903 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); |
904 | // Release GIL. |
905 | py::gil_scoped_release release; |
906 | TF_FunctionSetAttrValueProto(func, attr_name, buf.get()->data, |
907 | buf.get()->length, status.get()); |
908 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
909 | }); |
910 | |
911 | m.def("TF_FunctionToFunctionDef" , |
912 | [](TF_Function* graph, TF_Buffer* output_func_def) { |
913 | tensorflow::Safe_TF_StatusPtr status = |
914 | tensorflow::make_safe(TF_NewStatus()); |
915 | // Release GIL. |
916 | py::gil_scoped_release release; |
917 | TF_FunctionToFunctionDef(graph, output_func_def, status.get()); |
918 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
919 | }); |
920 | |
921 | m.def("TF_GraphCopyFunction" , |
922 | [](TF_Graph* graph, const TF_Function* func, const TF_Function* grad) { |
923 | tensorflow::Safe_TF_StatusPtr status = |
924 | tensorflow::make_safe(TF_NewStatus()); |
925 | // Release GIL. |
926 | py::gil_scoped_release release; |
927 | TF_GraphCopyFunction(graph, func, grad, status.get()); |
928 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
929 | }); |
930 | |
931 | m.def( |
932 | "TF_FunctionImportFunctionDef" , |
933 | [](py::bytes proto) { |
934 | tensorflow::Safe_TF_StatusPtr status = |
935 | tensorflow::make_safe(TF_NewStatus()); |
936 | tensorflow::Safe_TF_BufferPtr buf = |
937 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); |
938 | |
939 | // Release GIL. |
940 | py::gil_scoped_release release; |
941 | auto output = TF_FunctionImportFunctionDef( |
942 | buf.get()->data, buf.get()->length, status.get()); |
943 | |
944 | // Acquire GIL for returning output returning. |
945 | pybind11::gil_scoped_acquire acquire; |
946 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
947 | return output; |
948 | }, |
949 | py::return_value_policy::reference); |
950 | |
951 | m.def("EqualAttrValueWrapper" , tensorflow::EqualAttrValueWrapper, |
952 | py::call_guard<py::gil_scoped_release>()); |
953 | |
954 | m.def( |
955 | "TF_GetAllRegisteredKernels" , |
956 | []() { |
957 | tensorflow::Safe_TF_StatusPtr status = |
958 | tensorflow::make_safe(TF_NewStatus()); |
959 | // Release GIL. |
960 | py::gil_scoped_release release; |
961 | auto output = TF_GetAllRegisteredKernels(status.get()); |
962 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
963 | return output; |
964 | }, |
965 | py::return_value_policy::reference); |
966 | |
967 | m.def( |
968 | "TF_GetRegisteredKernelsForOp" , |
969 | [](const char* name) { |
970 | tensorflow::Safe_TF_StatusPtr status = |
971 | tensorflow::make_safe(TF_NewStatus()); |
972 | // Release GIL. |
973 | py::gil_scoped_release release; |
974 | auto output = TF_GetRegisteredKernelsForOp(name, status.get()); |
975 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
976 | return output; |
977 | }, |
978 | py::return_value_policy::reference); |
979 | |
980 | m.def("TF_GetAllOpList" , TF_GetAllOpList, py::return_value_policy::reference, |
981 | py::call_guard<py::gil_scoped_release>()); |
982 | |
983 | m.def( |
984 | "TF_NewApiDefMap" , |
985 | [](TF_Buffer* op_list_buffer) { |
986 | tensorflow::Safe_TF_StatusPtr status = |
987 | tensorflow::make_safe(TF_NewStatus()); |
988 | // Release GIL. |
989 | py::gil_scoped_release release; |
990 | auto output = TF_NewApiDefMap(op_list_buffer, status.get()); |
991 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
992 | return output; |
993 | }, |
994 | py::return_value_policy::reference); |
995 | |
996 | m.def("TF_DeleteApiDefMap" , TF_DeleteApiDefMap, |
997 | py::call_guard<py::gil_scoped_release>()); |
998 | |
999 | m.def( |
1000 | "TF_ApiDefMapGet" , |
1001 | [](TF_ApiDefMap* api_def_map, const char* name, size_t name_len) { |
1002 | tensorflow::Safe_TF_StatusPtr status = |
1003 | tensorflow::make_safe(TF_NewStatus()); |
1004 | // Release GIL. |
1005 | py::gil_scoped_release release; |
1006 | auto output = |
1007 | TF_ApiDefMapGet(api_def_map, name, name_len, status.get()); |
1008 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
1009 | return output; |
1010 | }, |
1011 | py::return_value_policy::reference); |
1012 | |
1013 | m.def("TF_ApiDefMapPut" , |
1014 | [](TF_ApiDefMap* api_def_map, const char* name, size_t name_len) { |
1015 | tensorflow::Safe_TF_StatusPtr status = |
1016 | tensorflow::make_safe(TF_NewStatus()); |
1017 | // Release GIL. |
1018 | py::gil_scoped_release release; |
1019 | TF_ApiDefMapPut(api_def_map, name, name_len, status.get()); |
1020 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
1021 | }); |
1022 | |
1023 | m.def("TF_OperationGetAttrType" , |
1024 | [](TF_Operation* oper, const char* attr_name) { |
1025 | tensorflow::Safe_TF_StatusPtr status = |
1026 | tensorflow::make_safe(TF_NewStatus()); |
1027 | TF_DataType value; |
1028 | // Release GIL. |
1029 | py::gil_scoped_release release; |
1030 | TF_OperationGetAttrType(oper, attr_name, &value, status.get()); |
1031 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
1032 | return value; |
1033 | }); |
1034 | |
1035 | m.def( |
1036 | "TF_NewServer" , |
1037 | [](py::bytes proto) { |
1038 | tensorflow::Safe_TF_StatusPtr status = |
1039 | tensorflow::make_safe(TF_NewStatus()); |
1040 | tensorflow::Safe_TF_BufferPtr buf = |
1041 | tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr())); |
1042 | TF_Server* output = |
1043 | TF_NewServer(buf.get()->data, buf.get()->length, status.get()); |
1044 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1045 | return output; |
1046 | }, |
1047 | py::return_value_policy::reference); |
1048 | |
1049 | m.def("TF_ServerStart" , [](TF_Server* server) { |
1050 | tensorflow::Safe_TF_StatusPtr status = |
1051 | tensorflow::make_safe(TF_NewStatus()); |
1052 | // Release GIL. |
1053 | py::gil_scoped_release release; |
1054 | TF_ServerStart(server, status.get()); |
1055 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
1056 | }); |
1057 | |
1058 | m.def("TF_ServerStop" , [](TF_Server* server) { |
1059 | tensorflow::Safe_TF_StatusPtr status = |
1060 | tensorflow::make_safe(TF_NewStatus()); |
1061 | // Release GIL for threading. |
1062 | py::gil_scoped_release release; |
1063 | TF_ServerStop(server, status.get()); |
1064 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
1065 | }); |
1066 | |
1067 | m.def("TF_ServerJoin" , [](TF_Server* server) { |
1068 | tensorflow::Safe_TF_StatusPtr status = |
1069 | tensorflow::make_safe(TF_NewStatus()); |
1070 | // Release GIL for threading. |
1071 | py::gil_scoped_release release; |
1072 | TF_ServerJoin(server, status.get()); |
1073 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
1074 | }); |
1075 | |
1076 | m.def( |
1077 | "TF_ServerTarget" , |
1078 | [](TF_Server* server) { return TF_ServerTarget(server); }, |
1079 | py::call_guard<py::gil_scoped_release>()); |
1080 | |
1081 | m.def( |
1082 | "TF_SessionListDevices" , |
1083 | [](TF_Session* session) { |
1084 | tensorflow::Safe_TF_StatusPtr status = |
1085 | tensorflow::make_safe(TF_NewStatus()); |
1086 | TF_DeviceList* output = TF_SessionListDevices(session, status.get()); |
1087 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1088 | return output; |
1089 | }, |
1090 | py::return_value_policy::reference); |
1091 | |
1092 | m.def("TF_DeviceListCount" , |
1093 | [](const TF_DeviceList* list) { return TF_DeviceListCount(list); }); |
1094 | |
1095 | m.def("TF_DeviceListName" , [](const TF_DeviceList* list, int index) { |
1096 | tensorflow::Safe_TF_StatusPtr status = |
1097 | tensorflow::make_safe(TF_NewStatus()); |
1098 | const char* output = TF_DeviceListName(list, index, status.get()); |
1099 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1100 | return output; |
1101 | }); |
1102 | |
1103 | m.def("TF_DeviceListType" , [](const TF_DeviceList* list, int index) { |
1104 | tensorflow::Safe_TF_StatusPtr status = |
1105 | tensorflow::make_safe(TF_NewStatus()); |
1106 | const char* output = TF_DeviceListType(list, index, status.get()); |
1107 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1108 | return output; |
1109 | }); |
1110 | |
1111 | m.def("TF_DeviceListMemoryBytes" , [](const TF_DeviceList* list, int index) { |
1112 | tensorflow::Safe_TF_StatusPtr status = |
1113 | tensorflow::make_safe(TF_NewStatus()); |
1114 | int64_t output = TF_DeviceListMemoryBytes(list, index, status.get()); |
1115 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1116 | return output; |
1117 | }); |
1118 | |
1119 | m.def("TF_DeviceListIncarnation" , [](const TF_DeviceList* list, int index) { |
1120 | tensorflow::Safe_TF_StatusPtr status = |
1121 | tensorflow::make_safe(TF_NewStatus()); |
1122 | int64_t output = TF_DeviceListIncarnation(list, index, status.get()); |
1123 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1124 | return output; |
1125 | }); |
1126 | |
1127 | m.def("TF_SetDevice" , TF_SetDevice); |
1128 | |
1129 | m.def("TF_DeleteDeviceList" , TF_DeleteDeviceList); |
1130 | |
1131 | m.def("TF_OperationGetAttrBool" , |
1132 | [](TF_Operation* oper, const char* attr_name) { |
1133 | tensorflow::Safe_TF_StatusPtr status = |
1134 | tensorflow::make_safe(TF_NewStatus()); |
1135 | unsigned char value; |
1136 | // Release GIL for threading. |
1137 | { |
1138 | py::gil_scoped_release release; |
1139 | TF_OperationGetAttrBool(oper, attr_name, &value, status.get()); |
1140 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
1141 | } |
1142 | return tensorflow::Pyo(PyBool_FromLong(value)); |
1143 | }); |
1144 | |
1145 | m.def("TF_NewStatus" , TF_NewStatus, py::return_value_policy::reference); |
1146 | m.def("TF_DeleteStatus" , TF_DeleteStatus); |
1147 | |
1148 | m.def("TF_DeleteDeviceList" , TF_DeleteDeviceList); |
1149 | |
1150 | m.def("AddWhileInputHack" , |
1151 | [](TF_Graph* graph, TF_Output new_src, TF_Operation* dst) { |
1152 | tensorflow::Safe_TF_StatusPtr status = |
1153 | tensorflow::make_safe(TF_NewStatus()); |
1154 | // Release GIL for threading. |
1155 | py::gil_scoped_release release; |
1156 | tensorflow::AddWhileInputHack(graph, new_src, dst, status.get()); |
1157 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
1158 | }); |
1159 | |
1160 | m.def("TF_Reset_wrapper" , [](const TF_SessionOptions* opt, |
1161 | const std::vector<py::bytes> containers) { |
1162 | tensorflow::Safe_TF_StatusPtr status = |
1163 | tensorflow::make_safe(TF_NewStatus()); |
1164 | // Release GIL for threading. |
1165 | py::gil_scoped_release release; |
1166 | tensorflow::NameVector containers_name_vector = |
1167 | ConvertPyListToNameVector(containers); |
1168 | tensorflow::TF_Reset_wrapper(opt, containers_name_vector, status.get()); |
1169 | tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get()); |
1170 | }); |
1171 | m.def("TF_GetCode" , TF_GetCode); |
1172 | |
1173 | m.def("TF_SetXlaAutoJitMode" , TF_SetXlaAutoJitMode); |
1174 | m.def("TF_GetXlaAutoJitEnabled" , TF_GetXlaAutoJitEnabled); |
1175 | m.def("TF_SetXlaEnableLazyCompilation" , TF_SetXlaEnableLazyCompilation); |
1176 | m.def("TF_SetTfXlaCpuGlobalJit" , TF_SetTfXlaCpuGlobalJit); |
1177 | m.def("TF_SetXlaMinClusterSize" , TF_SetXlaMinClusterSize); |
1178 | m.def("TF_GetXlaConstantFoldingDisabled" , TF_GetXlaConstantFoldingDisabled); |
1179 | m.def("TF_SetXlaConstantFoldingDisabled" , TF_SetXlaConstantFoldingDisabled); |
1180 | |
1181 | // // Static constants are not working on Windows. b/145559202 |
1182 | // // Creating getters instead. |
1183 | |
1184 | m.def("get_version" , []() { return TF_VERSION_STRING; }); |
1185 | m.def("get_git_version" , []() { return TF_GIT_VERSION; }); |
1186 | m.def("get_compiler_version" , []() { return TF_COMPILER_VERSION; }); |
1187 | m.def("get_cxx11_abi_flag" , []() { return TF_CXX11_ABI_FLAG; }); |
1188 | m.def("get_cxx_version" , []() { return TF_CXX_VERSION; }); |
1189 | m.def("get_eigen_max_align_bytes" , []() { return EIGEN_MAX_ALIGN_BYTES; }); |
1190 | m.def("get_monolithic_build" , []() { return TF_MONOLITHIC_BUILD; }); |
1191 | m.def("get_graph_def_version" , []() { return TF_GRAPH_DEF_VERSION; }); |
1192 | m.def("get_graph_def_version_min_consumer" , |
1193 | []() { return TF_GRAPH_DEF_VERSION_MIN_CONSUMER; }); |
1194 | m.def("get_graph_def_version_min_producer" , |
1195 | []() { return TF_GRAPH_DEF_VERSION_MIN_PRODUCER; }); |
1196 | m.def("get_tensor_handle_key" , []() { |
1197 | // TODO(amitpatankar): Look into a more elegant solution. |
1198 | // Since this is a shared object we will hard code the value from |
1199 | // third_party/tensorflow/core/common_runtime/session_state.cc because |
1200 | // the Windows import will not load the libraries necessarily |
1201 | // in order. b/145559202 |
1202 | return "TensorHandle" ; |
1203 | }); |
1204 | |
1205 | m.def("TF_RegisterFilesystemPlugin" , [](const char* plugin_filename) { |
1206 | tensorflow::Safe_TF_StatusPtr status = |
1207 | tensorflow::make_safe(TF_NewStatus()); |
1208 | TF_RegisterFilesystemPlugin(plugin_filename, status.get()); |
1209 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1210 | }); |
1211 | |
1212 | py::enum_<TF_DataType>(m, "TF_DataType" ) |
1213 | .value("TF_FLOAT" , TF_FLOAT) |
1214 | .value("TF_DOUBLE" , TF_DOUBLE) |
1215 | .value("TF_INT32" , TF_INT32) |
1216 | .value("TF_UINT8" , TF_UINT8) |
1217 | .value("TF_INT16" , TF_INT16) |
1218 | .value("TF_INT8" , TF_INT8) |
1219 | .value("TF_STRING" , TF_STRING) |
1220 | .value("TF_COMPLEX64" , TF_COMPLEX64) |
1221 | .value("TF_COMPLEX" , TF_COMPLEX) |
1222 | .value("TF_INT64" , TF_INT64) |
1223 | .value("TF_BOOL" , TF_BOOL) |
1224 | .value("TF_QINT8" , TF_QINT8) |
1225 | .value("TF_QUINT8" , TF_QUINT8) |
1226 | .value("TF_QINT32" , TF_QINT32) |
1227 | .value("TF_BFLOAT16" , TF_BFLOAT16) |
1228 | .value("TF_QINT16" , TF_QINT16) |
1229 | .value("TF_QUINT16" , TF_QUINT16) |
1230 | .value("TF_UINT16" , TF_UINT16) |
1231 | .value("TF_COMPLEX128" , TF_COMPLEX128) |
1232 | .value("TF_HALF" , TF_HALF) |
1233 | .value("TF_RESOURCE" , TF_RESOURCE) |
1234 | .value("TF_VARIANT" , TF_VARIANT) |
1235 | .value("TF_UINT32" , TF_UINT32) |
1236 | .value("TF_UINT64" , TF_UINT64) |
1237 | .export_values(); |
1238 | |
1239 | py::enum_<TF_Code>(m, "TF_Code" ) |
1240 | .value("TF_OK" , TF_OK) |
1241 | .value("TF_CANCELLED" , TF_CANCELLED) |
1242 | .value("TF_UNKNOWN" , TF_UNKNOWN) |
1243 | .value("TF_INVALID_ARGUMENT" , TF_INVALID_ARGUMENT) |
1244 | .value("TF_DEADLINE_EXCEEDED" , TF_DEADLINE_EXCEEDED) |
1245 | .value("TF_PERMISSION_DENIED" , TF_PERMISSION_DENIED) |
1246 | .value("TF_UNAUTHENTICATED" , TF_UNAUTHENTICATED) |
1247 | .value("TF_RESOURCE_EXHAUSTED" , TF_RESOURCE_EXHAUSTED) |
1248 | .value("TF_FAILED_PRECONDITION" , TF_FAILED_PRECONDITION) |
1249 | .value("TF_ABORTED" , TF_ABORTED) |
1250 | .value("TF_OUT_OF_RANGE" , TF_OUT_OF_RANGE) |
1251 | .value("TF_UNIMPLEMENTED" , TF_UNIMPLEMENTED) |
1252 | .value("TF_INTERNAL" , TF_INTERNAL) |
1253 | .value("TF_DATA_LOSS" , TF_DATA_LOSS) |
1254 | .export_values(); |
1255 | }; |
1256 | |