/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Must be at top (before any system includes and Python.h).
// clang-format off
#include "pybind11/chrono.h"
#include "pybind11/complex.h"
#include "pybind11/functional.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
// clang-format on

#include "Python.h"
#include "absl/types/optional.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/python_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/version_info.h"
#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"

namespace pybind11 {
namespace detail {
// Convert between absl::optional and python.
//
// pybind11 supports std::optional, and absl::optional is meant to be a
// drop-in replacement for std::optional, so we can just use the built in
// implementation.
#ifndef ABSL_USES_STD_OPTIONAL
template <typename T>
struct type_caster<absl::optional<T>>
    : public optional_caster<absl::optional<T>> {};
template <>
struct type_caster<absl::nullopt_t> : public void_caster<absl::nullopt_t> {};
#endif

}  // namespace detail
}  // namespace pybind11
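
// For illustration only (a sketch, not one of the real bindings): with the
// caster above, a bound lambda can take absl::optional directly and pybind11
// maps Python None to absl::nullopt. The binding name below is hypothetical.
//
//   m.def("example_optional_size",
//         [](absl::optional<std::vector<int>> values) {
//           return values.has_value() ? static_cast<int>(values->size()) : -1;
//         });
//
// TF_GraphToFunction_wrapper below relies on exactly this conversion for its
// `opers_opt` parameter.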

// TODO(amitpatankar): Consolidate Buffer methods into a separate header file.
TF_Buffer* ProtoStringToTFBuffer(PyObject* input) {
  // Convert a Python bytes object to a TF_Buffer.
  char* c_string;
  Py_ssize_t py_size;
  // PyBytes_AsStringAndSize() does not copy the data; it simply points into
  // the input object's internal buffer.
  if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
    // Python has raised an error (likely TypeError or UnicodeEncodeError).
    throw py::error_already_set();
  }
  return TF_NewBufferFromString(static_cast<void*>(c_string),
                                static_cast<size_t>(py_size));
}
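
// Usage sketch (mirrors the pattern used by the bindings below): the returned
// TF_Buffer is typically wrapped in a Safe_TF_BufferPtr so it is released
// automatically once the C call has consumed it.
//
//   tensorflow::Safe_TF_BufferPtr buf =
//       tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr()));
//   TF_SetConfig(options, buf.get()->data, buf.get()->length, status.get());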

// Copied from tf_session.i
// We have to use the convoluted approach of passing in a vector of py::bytes:
// if we pass in strings instead, they are freed before the necessary function
// calls are made.
tensorflow::NameVector ConvertPyListToNameVector(
    const std::vector<py::bytes>& py_vector) {
  tensorflow::NameVector temp;
  for (size_t i = 0; i < py_vector.size(); ++i) {
    const char* string_elem = PyBytes_AsString(py_vector.at(i).ptr());
    temp.push_back(string_elem);
  }
  return temp;
}
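
// Usage sketch (mirrors the pattern used by the bindings below): callers keep
// the Python byte strings alive in a std::vector<py::bytes> and convert them
// to a NameVector just before the C call.
//
//   tensorflow::NameVector output_names_name_vector =
//       ConvertPyListToNameVector(output_names);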

namespace py = pybind11;

PYBIND11_MAKE_OPAQUE(TF_Graph);
PYBIND11_MAKE_OPAQUE(TF_Session);
PYBIND11_MAKE_OPAQUE(TF_Operation);
PYBIND11_MAKE_OPAQUE(TF_Buffer);
PYBIND11_MAKE_OPAQUE(TF_ImportGraphDefOptions);
PYBIND11_MAKE_OPAQUE(TF_ImportGraphDefResults);
PYBIND11_MAKE_OPAQUE(TF_DeprecatedSession);
PYBIND11_MAKE_OPAQUE(TF_OperationDescription);
PYBIND11_MAKE_OPAQUE(TF_Library);
PYBIND11_MAKE_OPAQUE(TF_SessionOptions);
PYBIND11_MAKE_OPAQUE(TF_ApiDefMap);
PYBIND11_MAKE_OPAQUE(TF_Server);
PYBIND11_MAKE_OPAQUE(TF_DeviceList);
PYBIND11_MAKE_OPAQUE(TF_Status);

PYBIND11_MODULE(_pywrap_tf_session, m) {
  // Numpy initialization code for array checks.
  tensorflow::ImportNumpy();

  py::class_<TF_Graph> TF_Graph_class(m, "TF_Graph");
  py::class_<TF_Operation> TF_Operation_class(m, "TF_Operation");

  py::class_<TF_Output>(m, "TF_Output")
      .def(py::init<>())
      .def_readwrite("oper", &TF_Output::oper)
      .def_readwrite("index", &TF_Output::index);

  py::class_<TF_Input>(m, "TF_Input")
      .def(py::init<>())
      .def_readwrite("oper", &TF_Input::oper)
      .def_readwrite("index", &TF_Input::index);

  py::class_<TF_ImportGraphDefOptions> TF_ImportGraphDefOptions_class(
      m, "TF_ImportGraphDefOptions");
  py::class_<TF_ImportGraphDefResults> TF_ImportGraphDefResults_class(
      m, "TF_ImportGraphDefResults");
  py::class_<TF_DeprecatedSession> TF_DeprecatedSession_class(
      m, "TF_DeprecatedSession");
  py::class_<TF_Session> TF_Session_class(m, "TF_Session");
  py::class_<TF_OperationDescription> TF_OperationDescription_class(
      m, "TF_OperationDescription");
  py::class_<TF_Library> TF_Library_class(m, "TF_Library");
  py::class_<TF_SessionOptions> TF_SessionOptions_class(m, "TF_SessionOptions");
  py::class_<TF_Buffer> TF_Buffer_class(m, "TF_Buffer");
  py::class_<TF_ApiDefMap> TF_ApiDefMap_class(m, "TF_ApiDefMap");
  py::class_<TF_Server> TF_Server_class(m, "TF_Server");
  py::class_<TF_Status> TF_Status_class(m, "TF_Status");

  // We only release the Python GIL for certain methods; those that are
  // explicitly marked "Do not release GIL" below keep it held. We disable the
  // release for those functions because they use Python method(s) that expect
  // the GIL to be held (at least PyArray_Return, maybe others).
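  //
  // For illustration, a sketch of the two patterns used below (the names here
  // are hypothetical). A wrapper that never touches Python objects can release
  // the GIL for the whole call:
  //
  //   m.def("SomeSafeWrapper", SomeSafeWrapper,
  //         py::call_guard<py::gil_scoped_release>());
  //
  // A wrapper that must build Python objects releases the GIL only around the
  // C API call and re-acquires it before converting the result:
  //
  //   py::gil_scoped_release release;
  //   SomeTFCApiCall(..., status.get());
  //   py::gil_scoped_acquire acquire;
  //   tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());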

  // Do not release GIL.
  m.def("TF_OperationGetControlInputs_wrapper",
        tensorflow::TF_OperationGetControlInputs_wrapper);
  // Do not release GIL.
  m.def("TF_OperationGetControlOutputs_wrapper",
        tensorflow::TF_OperationGetControlOutputs_wrapper);
  m.def("TF_OperationOutputConsumers_wrapper",
        tensorflow::TF_OperationOutputConsumers_wrapper);
  // Do not release GIL.
  m.def("GetOperationInputs", tensorflow::GetOperationInputs);
  // Do not release GIL.
  m.def("TF_ImportGraphDefOptionsSetValidateColocationConstraints",
        TF_ImportGraphDefOptionsSetValidateColocationConstraints);
  // Do not release GIL.
  m.def("TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper",
        tensorflow::TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper);
  m.def("TF_SessionMakeCallable",
        [](TF_Session* session, const TF_Buffer* callable_options) {
          int64_t out_handle;
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());

          // Release GIL.
          py::gil_scoped_release release;
          tensorflow::TF_SessionMakeCallable(session, callable_options,
                                             &out_handle, status.get());

          // Acquire GIL for the returned int conversion.
          pybind11::gil_scoped_acquire acquire;
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
          return out_handle;
        });
  m.def("_TF_SetTarget", TF_SetTarget);
  m.def("_TF_SetConfig", [](TF_SessionOptions* options, py::bytes proto) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    tensorflow::Safe_TF_BufferPtr buf =
        tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr()));
    TF_SetConfig(options, buf.get()->data, buf.get()->length, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });
  m.def("_TF_NewSessionOptions", TF_NewSessionOptions,
        py::return_value_policy::reference,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_DeleteSessionOptions", TF_DeleteSessionOptions,
        py::call_guard<py::gil_scoped_release>());

  m.def("EqualGraphDefWrapper", tensorflow::EqualGraphDefWrapper,
        py::call_guard<py::gil_scoped_release>());
  m.def("EqualAttrValueWrapper", tensorflow::EqualAttrValueWrapper,
        py::call_guard<py::gil_scoped_release>());

  m.def(
      "TF_GraphToFunction_wrapper",
      [](const TF_Graph* fn_body, const char* fn_name,
         bool append_hash_to_fn_name,
         absl::optional<std::vector<TF_Operation*>> opers_opt,
         const std::vector<TF_Output>& inputs,
         const std::vector<TF_Output>& outputs,
         const std::vector<py::bytes> output_names,
         const std::vector<TF_Operation*> control_outputs,
         const std::vector<py::bytes> control_output_names, py::none opts,
         const char* description) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());

        // TODO(b/147674626): Use pybind11 list_caster instead.
        tensorflow::NameVector output_names_name_vector =
            ConvertPyListToNameVector(output_names);

        // TODO(b/147674626): Use pybind11 list_caster instead.
        tensorflow::NameVector control_output_names_name_vector =
            ConvertPyListToNameVector(control_output_names);

        // Release GIL.
        py::gil_scoped_release release;
        auto output = tensorflow::TF_GraphToFunction_wrapper(
            fn_body, fn_name, append_hash_to_fn_name,
            opers_opt.has_value() ? &opers_opt.value() : nullptr, inputs,
            outputs, output_names_name_vector, &control_outputs,
            control_output_names_name_vector,
            /*opts=*/nullptr, description, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_GraphGetTensorShapeHelper", [](TF_Graph* graph, TF_Output output) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    bool unknown_shape;

    auto result = tensorflow::TF_GraphGetTensorShapeHelper(
        graph, output, status.get(), &unknown_shape);
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());

    // Create a Python list from the InlinedVector.
    py::list py_list;
    for (size_t i = 0; i < result.size(); ++i) {
      py_list.append(py::cast(result[i]));
    }

    // Return a tuple.
    py::tuple result_tuple = py::make_tuple(py_list, py::cast(unknown_shape));
    return result_tuple;
  });

  m.def("TF_GraphSetTensorShape_wrapper",
        [](TF_Graph* graph, TF_Output output, const std::vector<int64_t>& dims,
           bool unknown_shape) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());

          // Release GIL.
          py::gil_scoped_release release;
          tensorflow::TF_GraphSetTensorShape_wrapper(
              graph, output, dims, unknown_shape, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        });

  m.def("TF_GraphSetOutputHandleShapesAndTypes_wrapper",
        [](TF_Graph* graph, TF_Output output,
           const std::vector<absl::optional<std::vector<int64_t>>>& shapes,
           const std::vector<int>& ranks, py::handle& types) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());

          // Cast the types list to a vector of TF_DataType.
          std::vector<TF_DataType> types_local;
          PyObject* seq =
              PySequence_Fast(types.ptr(), "$symname: expected list");
          if (seq == nullptr) {
            PyErr_SetString(PyExc_RuntimeError,
                            "$symname: PySequence_Fast returned NULL.");
            throw py::error_already_set();
          }

          int size = PySequence_Fast_GET_SIZE(seq);
          if (size == 0) {
            PyErr_SetString(PyExc_ValueError,
                            "$symname: shapes list must be non-empty");
            throw py::error_already_set();
          }

          for (int i = 0; i < size; ++i) {
            PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
            types_local.push_back((TF_DataType)PyLong_AsLong(item));
          }

          // Convert the nested shapes vector, substituting an empty dims
          // vector for unknown shapes.
          std::vector<std::vector<int64_t>> shapes_local;
          for (size_t i = 0; i < shapes.size(); ++i) {
            std::vector<int64_t> dims;
            std::vector<int64_t> item =
                shapes[i].has_value() ? shapes[i].value() : dims;
            shapes_local.push_back(item);
          }

          Py_DECREF(seq);

          tensorflow::TF_GraphSetOutputHandleShapesAndTypes_wrapper(
              graph, output, shapes_local, ranks, types_local, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });

  // Do not release GIL.
  m.def("TF_CreatePlaceholders",
        [](TF_Graph* graph, py::handle& dtypes, const char* prefix) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          auto output = tensorflow::TF_CreatePlaceholders(graph, dtypes.ptr(),
                                                          prefix, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
          return output;
        });

  m.def(
      "TF_NewSession",
      [](TF_Graph* graph, const TF_SessionOptions* opts) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // Release GIL.
        py::gil_scoped_release release;
        auto output = TF_NewSession(graph, opts, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def(
      "TF_NewSessionRef",
      [](TF_Graph* graph, const TF_SessionOptions* opts) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // Release GIL.
        py::gil_scoped_release release;
        auto output = tensorflow::TF_NewSessionRef(graph, opts, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_CloseSession", [](TF_Session* session) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());

    // Release GIL.
    py::gil_scoped_release release;
    TF_CloseSession(session, status.get());

    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("TF_DeleteSession", [](TF_Session* session) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL.
    py::gil_scoped_release release;
    TF_DeleteSession(session, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("SetRequireShapeInferenceFns", tensorflow::SetRequireShapeInferenceFns);

  // Do not release GIL.
  m.def("TF_TryEvaluateConstant_wrapper",
        [](TF_Graph* graph, const TF_Output output) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          auto result = tensorflow::TF_TryEvaluateConstant_wrapper(
              graph, output, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
          return tensorflow::PyoOrThrow(result);
        });

  m.def("ExtendSession", [](TF_Session* session) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL for threading.
    pybind11::gil_scoped_release release;
    tensorflow::ExtendSession(session, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("GetHandleShapeAndType", [](TF_Graph* graph, TF_Output output) {
    std::string output_string =
        tensorflow::GetHandleShapeAndType(graph, output);
    // Override the default Python 3 behavior of encoding the result as
    // Unicode, since the dependent functions expect bytes.
    return py::bytes(output_string);
  });

  m.def("SetHandleShapeAndType",
        [](TF_Graph* graph, TF_Output output, py::bytes proto) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          tensorflow::Safe_TF_BufferPtr buf =
              tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr()));
          tensorflow::SetHandleShapeAndType(graph, output, buf.get()->data,
                                            buf.get()->length, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });

  // Do not release GIL.
  m.def("TF_SessionRun_wrapper", [](TF_Session* session, TF_Buffer* run_options,
                                    const py::handle& input_dict,
                                    const std::vector<TF_Output>& outputs,
                                    const std::vector<TF_Operation*>& targets,
                                    TF_Buffer* run_metadata) {
    // Convert inputs dictionary
    std::vector<TF_Output> inputs;
    std::vector<PyObject*> input_ndarrays;
    if (!PyDict_Check(input_dict.ptr())) {
      PyErr_SetString(
          PyExc_TypeError,
          "Expected a dictionary as an argument to TF_SessionRun_wrapper.");
      throw py::error_already_set();
    }
    PyObject* key;
    PyObject* value;
    Py_ssize_t pos = 0;
    while (PyDict_Next(input_dict.ptr(), &pos, &key, &value)) {
      TF_Output item = py::cast<TF_Output>(key);
      inputs.push_back(item);

      // TODO(amitpatankar): Fix this PyArray check. (b/147855599)

      // if (!PyArray_Check(value)) {
      //   PyErr_SetString(
      //       PyExc_TypeError,
      //       "$symname: Expected all values in input dict to be ndarray.");
      //   throw py::error_already_set();
      // }
      input_ndarrays.push_back(value);
    }

    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    std::vector<PyObject*> py_outputs;
    tensorflow::TF_SessionRun_wrapper(session, run_options, inputs,
                                      input_ndarrays, outputs, targets,
                                      run_metadata, status.get(), &py_outputs);
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());

    // Create a Python list using the C API rather than py::list. b/147855599
    PyObject* result = PyList_New(py_outputs.size());
    if (result == nullptr) {
      PyErr_SetString(PyExc_MemoryError, "Failed to create a list.");
      throw py::error_already_set();
    }
    for (size_t i = 0; i < py_outputs.size(); ++i) {
      PyList_SET_ITEM(result, i, py_outputs.at(i));
    }

    return tensorflow::PyoOrThrow(result);
  });

  // Do not release GIL.
  m.def("TF_SessionPRun_wrapper", [](TF_Session* session, const char* handle,
                                     const py::handle& input_dict,
                                     const std::vector<TF_Output>& outputs) {
    // Convert inputs dictionary
    std::vector<TF_Output> inputs;
    std::vector<PyObject*> input_ndarrays;
    if (!PyDict_Check(input_dict.ptr())) {
      PyErr_SetString(
          PyExc_TypeError,
          "Expected a dictionary as an argument to TF_SessionPRun_wrapper.");
      throw py::error_already_set();
    }
    PyObject* key;
    PyObject* value;
    Py_ssize_t pos = 0;
    while (PyDict_Next(input_dict.ptr(), &pos, &key, &value)) {
      TF_Output item = py::cast<TF_Output>(key);
      inputs.push_back(item);

      // TODO(amitpatankar): Fix this PyArray check. (b/147855599)

      // if (!PyArray_Check(value)) {
      //   PyErr_SetString(
      //       PyExc_TypeError,
      //       "$symname: Expected all values in input dict to be ndarray.");
      //   throw py::error_already_set();
      // }
      input_ndarrays.push_back(value);
    }

    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    std::vector<PyObject*> py_outputs;
    tensorflow::TF_SessionPRun_wrapper(session, handle, inputs, input_ndarrays,
                                       outputs, status.get(), &py_outputs);
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());

    PyObject* result = PyList_New(py_outputs.size());
    if (result == nullptr) {
      PyErr_SetString(PyExc_MemoryError, "Failed to create a list.");
      throw py::error_already_set();
    }
    for (size_t i = 0; i < py_outputs.size(); ++i) {
      PyList_SET_ITEM(result, i, py_outputs.at(i));
    }

    return tensorflow::PyoOrThrow(result);
  });

  // Do not release GIL.
  m.def("TF_SessionPRunSetup_wrapper",
        [](TF_Session* session, const std::vector<TF_Output>& inputs,
           const std::vector<TF_Output>& outputs,
           const std::vector<TF_Operation*>& targets) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          const char* out_handle;
          tensorflow::TF_SessionPRunSetup_wrapper(
              session, inputs, outputs, targets, &out_handle, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
          return out_handle;
        });

  // Do not release GIL.
  m.def("TF_SessionRunCallable", [](TF_Session* session, int64_t handle,
                                    py::object feed_values,
                                    TF_Buffer* run_metadata) {
    tensorflow::PyObjectVector out_values;
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    tensorflow::TF_SessionRunCallable(session, handle, feed_values.ptr(),
                                      &out_values, run_metadata, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());

    // Return out_values
    py::list py_list;
    for (size_t i = 0; i < out_values.size(); ++i) {
      py::object obj = tensorflow::Pyo(out_values.at(i));
      py_list.append(obj);
    }
    return py_list;
  });

  m.def("TF_SessionReleaseCallable", [](TF_Session* session, int64_t handle) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL.
    py::gil_scoped_release release;
    tensorflow::TF_SessionReleaseCallable(session, handle, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("TF_NewGraph", TF_NewGraph, py::return_value_policy::reference,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_DeleteGraph", TF_DeleteGraph,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_GraphGetOpDef",
        [](TF_Graph* graph, const char* op_name, TF_Buffer* output_op_def) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          // Release GIL.
          py::gil_scoped_release release;
          TF_GraphGetOpDef(graph, op_name, output_op_def, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        });

  m.def(
      "TF_NewOperation",
      [](TF_Graph* graph, const char* op_type, const char* oper_name) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // Release GIL.
        py::gil_scoped_release release;
        TF_OperationDescription* output =
            TF_NewOperation(graph, op_type, oper_name);
        tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def(
      "TF_FinishOperation",
      [](TF_OperationDescription* desc) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // Release GIL.
        py::gil_scoped_release release;
        TF_Operation* output = TF_FinishOperation(desc, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_OperationGetAttrInt",
        [](TF_Operation* oper, const char* attr_name) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          int64_t value;
          // Release GIL.
          py::gil_scoped_release release;
          TF_OperationGetAttrInt(oper, attr_name, &value, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
          // Convert the TF_OperationGetAttrInt int64_t* out-argument to a
          // Python int.
          // Acquire GIL for the return-value conversion.
          pybind11::gil_scoped_acquire acquire;
          return tensorflow::Pyo(PyLong_FromLongLong(value));
        });

  m.def("TF_SetAttrValueProto", [](TF_OperationDescription* desc,
                                   const char* attr_name, py::bytes proto) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    tensorflow::Safe_TF_BufferPtr buf =
        tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr()));
    TF_SetAttrValueProto(desc, attr_name, buf.get()->data, buf.get()->length,
                         status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });

  m.def("TF_OperationNumOutputs", TF_OperationNumOutputs,
        py::call_guard<py::gil_scoped_release>());

  // Convert types to ints
  m.def("TF_OperationInputType", TF_OperationInputType,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_OperationOutputType", TF_OperationOutputType,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_OperationName", TF_OperationName,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_OperationOpType", TF_OperationOpType,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_OperationDevice", TF_OperationDevice,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_AddInput", TF_AddInput);

  m.def("TF_OperationToNodeDef",
        [](TF_Operation* oper, TF_Buffer* output_node_def) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          TF_OperationToNodeDef(oper, output_node_def, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });

  m.def("TF_OperationGetAttrValueProto",
        [](TF_Operation* oper, const char* attr_name,
           TF_Buffer* output_attr_value) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          TF_OperationGetAttrValueProto(oper, attr_name, output_attr_value,
                                        status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        });

  m.def("TF_OperationGetStackTrace", [](TF_Operation* oper) -> py::object {
    const std::shared_ptr<tensorflow::AbstractStackTrace> trace =
        oper->node.GetStackTrace();
    if (!trace) {
      return py::none();
    }
    return py::cast(*trace, py::return_value_policy::reference);
  });

  m.def("SetRequestedDevice", tensorflow::SetRequestedDevice);

  // TF_Buffer util methods
  // TODO(amitpatankar): Consolidate Buffer methods into a separate header file.
  m.def("TF_NewBuffer", TF_NewBuffer, py::return_value_policy::reference);
  m.def("TF_GetBuffer", [](TF_Buffer* buf) {
    TF_Buffer buffer = TF_GetBuffer(buf);
    return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
        reinterpret_cast<const char*>(buffer.data), buffer.length));
  });
  m.def("TF_DeleteBuffer", &TF_DeleteBuffer);
  m.def(
      "TF_NewBufferFromString",
      [](py::bytes buffer_as_string) {
        tensorflow::Safe_TF_BufferPtr buf = tensorflow::make_safe(
            ProtoStringToTFBuffer(buffer_as_string.ptr()));
        return TF_NewBufferFromString(buf.get()->data, buf.get()->length);
      },
      py::return_value_policy::reference);

  m.def("SetAttr", [](TF_Graph* graph, TF_Operation* op, const char* attr_name,
                      TF_Buffer* attr_value_proto) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL.
    py::gil_scoped_release release;
    tensorflow::SetAttr(graph, op, attr_name, attr_value_proto, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("ClearAttr",
        [](TF_Graph* graph, TF_Operation* op, const char* attr_name) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          // Release GIL.
          py::gil_scoped_release release;
          tensorflow::ClearAttr(graph, op, attr_name, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        });

  // Note: users should prefer tf.cast or an equivalent op, and rely on this
  // only when it is infeasible to set the type via the OpDef's type
  // constructor and type inference function.
  m.def("SetFullType", [](TF_Graph* graph, TF_Operation* op,
                          const std::string& serialized_full_type) {
    tensorflow::FullTypeDef proto;
    proto.ParseFromString(serialized_full_type);
    tensorflow::SetFullType(graph, op, proto);
  });
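
  // For illustration, a sketch of how a caller could produce the serialized
  // FullTypeDef expected by "SetFullType" (the specific TFT_* id here is only
  // an example):
  //
  //   tensorflow::FullTypeDef full_type;
  //   full_type.set_type_id(tensorflow::TFT_PRODUCT);
  //   const std::string serialized = full_type.SerializeAsString();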

  m.def(
      "TF_LoadLibrary",
      [](const char* library_filename) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output = TF_LoadLibrary(library_filename, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def(
      "TF_LoadPluggableDeviceLibrary",
      [](const char* library_filename) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output =
            TF_LoadPluggableDeviceLibrary(library_filename, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_GetOpList", [](TF_Library* lib_handle) {
    TF_Buffer output_buffer = TF_GetOpList(lib_handle);
    return tensorflow::PyoOrThrow(PyBytes_FromStringAndSize(
        reinterpret_cast<const char*>(output_buffer.data),
        output_buffer.length));
  });

  m.def("TF_DeleteLibraryHandle", TF_DeleteLibraryHandle,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_PluggableDeviceLibraryHandle",
        TF_DeletePluggableDeviceLibraryHandle,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_AddControlInput", TF_AddControlInput);
  m.def(
      "TF_AddInputList", [](TF_OperationDescription* desc, py::handle& inputs) {
        std::vector<TF_Output> vec;
        size_t size = PyList_Size(inputs.ptr());
        for (size_t i = 0; i < size; ++i) {
          TF_Output item = py::cast<TF_Output>(PyList_GetItem(inputs.ptr(), i));
          vec.push_back(item);
        }
        TF_AddInputList(desc, vec.data(), vec.size());
      });

  m.def("UpdateEdge", [](TF_Graph* graph, TF_Output new_src, TF_Input dst) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL.
    py::gil_scoped_release release;
    tensorflow::UpdateEdge(graph, new_src, dst, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("RemoveAllControlInputs", tensorflow::RemoveAllControlInputs,
        py::call_guard<py::gil_scoped_release>());
  m.def("AddControlInput", tensorflow::AddControlInput,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_NewImportGraphDefOptions", TF_NewImportGraphDefOptions,
        py::return_value_policy::reference,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_ImportGraphDefOptionsSetPrefix", TF_ImportGraphDefOptionsSetPrefix,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_ImportGraphDefOptionsSetUniquifyNames",
        TF_ImportGraphDefOptionsSetUniquifyNames,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_ImportGraphDefOptionsRemapControlDependency",
        TF_ImportGraphDefOptionsRemapControlDependency,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_ImportGraphDefOptionsAddInputMapping",
        TF_ImportGraphDefOptionsAddInputMapping,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_ImportGraphDefOptionsAddReturnOperation",
        TF_ImportGraphDefOptionsAddReturnOperation,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_ImportGraphDefOptionsAddReturnOutput",
        TF_ImportGraphDefOptionsAddReturnOutput,
        py::call_guard<py::gil_scoped_release>());

  m.def(
      "TF_GraphImportGraphDefWithResults",
      [](TF_Graph* graph, const TF_Buffer* graph_def,
         const TF_ImportGraphDefOptions* options) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output = TF_GraphImportGraphDefWithResults(graph, graph_def,
                                                        options, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def(
      "TF_GraphNextOperation",
      [](TF_Graph* graph, size_t pos) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output = TF_GraphNextOperation(graph, &pos);
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());

        // Returns a (TF_Operation*, int pos) tuple.
        py::tuple result_tuple = py::make_tuple(
            py::cast(output), tensorflow::Pyo(PyLong_FromSize_t(pos)));
        return result_tuple;
      },
      py::return_value_policy::reference);

  // Python needs to own deletion of outputs
  m.def("TF_ImportGraphDefResultsReturnOutputs",
        [](TF_ImportGraphDefResults* results) {
          int num_outputs;
          TF_Output* outputs;
          TF_ImportGraphDefResultsReturnOutputs(results, &num_outputs,
                                                &outputs);
          py::list py_list;
          for (int i = 0; i < num_outputs; ++i) {
            TF_Output tf_output = TF_Output(outputs[i]);
            py_list.append(tf_output);
          }
          return py_list;
        });

  m.def(
      "TF_ImportGraphDefResultsReturnOperations",
      [](TF_ImportGraphDefResults* results) {
        int num_opers;
        TF_Operation** opers;
        TF_ImportGraphDefResultsReturnOperations(results, &num_opers, &opers);
        py::list py_list;
        for (int i = 0; i < num_opers; ++i) {
          py_list.append(opers[i]);
        }
        return py_list;
      },
      py::return_value_policy::reference);

  m.def("TF_GraphToGraphDef", [](TF_Graph* graph, TF_Buffer* output_graph_def) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL.
    py::gil_scoped_release release;
    TF_GraphToGraphDef(graph, output_graph_def, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("TF_OperationNumInputs", TF_OperationNumInputs,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_GraphVersions", [](TF_Graph* graph, TF_Buffer* output_graph_def) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL.
    py::gil_scoped_release release;
    TF_GraphVersions(graph, output_graph_def, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("TF_DeleteFunction", TF_DeleteFunction,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_DeleteImportGraphDefResults", TF_DeleteImportGraphDefResults,
        py::call_guard<py::gil_scoped_release>());
  m.def("TF_DeleteImportGraphDefOptions", TF_DeleteImportGraphDefOptions,
        py::call_guard<py::gil_scoped_release>());

  m.def("TF_FunctionSetAttrValueProto",
        [](TF_Function* func, const char* attr_name, py::bytes proto) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          tensorflow::Safe_TF_BufferPtr buf =
              tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr()));
          // Release GIL.
          py::gil_scoped_release release;
          TF_FunctionSetAttrValueProto(func, attr_name, buf.get()->data,
                                       buf.get()->length, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        });

  m.def("TF_FunctionToFunctionDef",
        [](TF_Function* graph, TF_Buffer* output_func_def) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          // Release GIL.
          py::gil_scoped_release release;
          TF_FunctionToFunctionDef(graph, output_func_def, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        });

  m.def("TF_GraphCopyFunction",
        [](TF_Graph* graph, const TF_Function* func, const TF_Function* grad) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          // Release GIL.
          py::gil_scoped_release release;
          TF_GraphCopyFunction(graph, func, grad, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        });

  m.def(
      "TF_FunctionImportFunctionDef",
      [](py::bytes proto) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        tensorflow::Safe_TF_BufferPtr buf =
            tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr()));

        // Release GIL.
        py::gil_scoped_release release;
        auto output = TF_FunctionImportFunctionDef(
            buf.get()->data, buf.get()->length, status.get());

        // Acquire GIL for returning the output.
        pybind11::gil_scoped_acquire acquire;
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);
  m.def(
      "TF_GetAllRegisteredKernels",
      []() {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // Release GIL.
        py::gil_scoped_release release;
        auto output = TF_GetAllRegisteredKernels(status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def(
      "TF_GetRegisteredKernelsForOp",
      [](const char* name) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // Release GIL.
        py::gil_scoped_release release;
        auto output = TF_GetRegisteredKernelsForOp(name, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_GetAllOpList", TF_GetAllOpList, py::return_value_policy::reference,
        py::call_guard<py::gil_scoped_release>());

  m.def(
      "TF_NewApiDefMap",
      [](TF_Buffer* op_list_buffer) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // Release GIL.
        py::gil_scoped_release release;
        auto output = TF_NewApiDefMap(op_list_buffer, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_DeleteApiDefMap", TF_DeleteApiDefMap,
        py::call_guard<py::gil_scoped_release>());

  m.def(
      "TF_ApiDefMapGet",
      [](TF_ApiDefMap* api_def_map, const char* name, size_t name_len) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        // Release GIL.
        py::gil_scoped_release release;
        auto output =
            TF_ApiDefMapGet(api_def_map, name, name_len, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_ApiDefMapPut",
        [](TF_ApiDefMap* api_def_map, const char* name, size_t name_len) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          // Release GIL.
          py::gil_scoped_release release;
          TF_ApiDefMapPut(api_def_map, name, name_len, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        });

  m.def("TF_OperationGetAttrType",
        [](TF_Operation* oper, const char* attr_name) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          TF_DataType value;
          // Release GIL.
          py::gil_scoped_release release;
          TF_OperationGetAttrType(oper, attr_name, &value, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
          return value;
        });

  m.def(
      "TF_NewServer",
      [](py::bytes proto) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        tensorflow::Safe_TF_BufferPtr buf =
            tensorflow::make_safe(ProtoStringToTFBuffer(proto.ptr()));
        TF_Server* output =
            TF_NewServer(buf.get()->data, buf.get()->length, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_ServerStart", [](TF_Server* server) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL.
    py::gil_scoped_release release;
    TF_ServerStart(server, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("TF_ServerStop", [](TF_Server* server) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL for threading.
    py::gil_scoped_release release;
    TF_ServerStop(server, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def("TF_ServerJoin", [](TF_Server* server) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL for threading.
    py::gil_scoped_release release;
    TF_ServerJoin(server, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });

  m.def(
      "TF_ServerTarget",
      [](TF_Server* server) { return TF_ServerTarget(server); },
      py::call_guard<py::gil_scoped_release>());

  m.def(
      "TF_SessionListDevices",
      [](TF_Session* session) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        TF_DeviceList* output = TF_SessionListDevices(session, status.get());
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);

  m.def("TF_DeviceListCount",
        [](const TF_DeviceList* list) { return TF_DeviceListCount(list); });

  m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    const char* output = TF_DeviceListName(list, index, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    return output;
  });

  m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    const char* output = TF_DeviceListType(list, index, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    return output;
  });

  m.def("TF_DeviceListMemoryBytes", [](const TF_DeviceList* list, int index) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    int64_t output = TF_DeviceListMemoryBytes(list, index, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    return output;
  });

  m.def("TF_DeviceListIncarnation", [](const TF_DeviceList* list, int index) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    int64_t output = TF_DeviceListIncarnation(list, index, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    return output;
  });

  m.def("TF_SetDevice", TF_SetDevice);

  m.def("TF_DeleteDeviceList", TF_DeleteDeviceList);

  m.def("TF_OperationGetAttrBool",
        [](TF_Operation* oper, const char* attr_name) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          unsigned char value;
          // Release GIL for threading.
          {
            py::gil_scoped_release release;
            TF_OperationGetAttrBool(oper, attr_name, &value, status.get());
            tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
          }
          return tensorflow::Pyo(PyBool_FromLong(value));
        });

  m.def("TF_NewStatus", TF_NewStatus, py::return_value_policy::reference);
  m.def("TF_DeleteStatus", TF_DeleteStatus);

  m.def("AddWhileInputHack",
        [](TF_Graph* graph, TF_Output new_src, TF_Operation* dst) {
          tensorflow::Safe_TF_StatusPtr status =
              tensorflow::make_safe(TF_NewStatus());
          // Release GIL for threading.
          py::gil_scoped_release release;
          tensorflow::AddWhileInputHack(graph, new_src, dst, status.get());
          tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
        });

  m.def("TF_Reset_wrapper", [](const TF_SessionOptions* opt,
                               const std::vector<py::bytes> containers) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    // Release GIL for threading.
    py::gil_scoped_release release;
    tensorflow::NameVector containers_name_vector =
        ConvertPyListToNameVector(containers);
    tensorflow::TF_Reset_wrapper(opt, containers_name_vector, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatusWithGIL(status.get());
  });
  m.def("TF_GetCode", TF_GetCode);

  m.def("TF_SetXlaAutoJitMode", TF_SetXlaAutoJitMode);
  m.def("TF_GetXlaAutoJitEnabled", TF_GetXlaAutoJitEnabled);
  m.def("TF_SetXlaEnableLazyCompilation", TF_SetXlaEnableLazyCompilation);
  m.def("TF_SetTfXlaCpuGlobalJit", TF_SetTfXlaCpuGlobalJit);
  m.def("TF_SetXlaMinClusterSize", TF_SetXlaMinClusterSize);
  m.def("TF_GetXlaConstantFoldingDisabled", TF_GetXlaConstantFoldingDisabled);
  m.def("TF_SetXlaConstantFoldingDisabled", TF_SetXlaConstantFoldingDisabled);

  // Static constants are not working on Windows. b/145559202
  // Creating getters instead.

  m.def("get_version", []() { return TF_VERSION_STRING; });
  m.def("get_git_version", []() { return TF_GIT_VERSION; });
  m.def("get_compiler_version", []() { return TF_COMPILER_VERSION; });
  m.def("get_cxx11_abi_flag", []() { return TF_CXX11_ABI_FLAG; });
  m.def("get_cxx_version", []() { return TF_CXX_VERSION; });
  m.def("get_eigen_max_align_bytes", []() { return EIGEN_MAX_ALIGN_BYTES; });
  m.def("get_monolithic_build", []() { return TF_MONOLITHIC_BUILD; });
  m.def("get_graph_def_version", []() { return TF_GRAPH_DEF_VERSION; });
  m.def("get_graph_def_version_min_consumer",
        []() { return TF_GRAPH_DEF_VERSION_MIN_CONSUMER; });
  m.def("get_graph_def_version_min_producer",
        []() { return TF_GRAPH_DEF_VERSION_MIN_PRODUCER; });
  m.def("get_tensor_handle_key", []() {
    // TODO(amitpatankar): Look into a more elegant solution.
    // Since this is a shared object, we hard code the value from
    // third_party/tensorflow/core/common_runtime/session_state.cc because
    // the Windows import will not necessarily load the libraries in
    // order. b/145559202
    return "TensorHandle";
  });

  m.def("TF_RegisterFilesystemPlugin", [](const char* plugin_filename) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    TF_RegisterFilesystemPlugin(plugin_filename, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });

  py::enum_<TF_DataType>(m, "TF_DataType")
      .value("TF_FLOAT", TF_FLOAT)
      .value("TF_DOUBLE", TF_DOUBLE)
      .value("TF_INT32", TF_INT32)
      .value("TF_UINT8", TF_UINT8)
      .value("TF_INT16", TF_INT16)
      .value("TF_INT8", TF_INT8)
      .value("TF_STRING", TF_STRING)
      .value("TF_COMPLEX64", TF_COMPLEX64)
      .value("TF_COMPLEX", TF_COMPLEX)
      .value("TF_INT64", TF_INT64)
      .value("TF_BOOL", TF_BOOL)
      .value("TF_QINT8", TF_QINT8)
      .value("TF_QUINT8", TF_QUINT8)
      .value("TF_QINT32", TF_QINT32)
      .value("TF_BFLOAT16", TF_BFLOAT16)
      .value("TF_QINT16", TF_QINT16)
      .value("TF_QUINT16", TF_QUINT16)
      .value("TF_UINT16", TF_UINT16)
      .value("TF_COMPLEX128", TF_COMPLEX128)
      .value("TF_HALF", TF_HALF)
      .value("TF_RESOURCE", TF_RESOURCE)
      .value("TF_VARIANT", TF_VARIANT)
      .value("TF_UINT32", TF_UINT32)
      .value("TF_UINT64", TF_UINT64)
      .export_values();

  py::enum_<TF_Code>(m, "TF_Code")
      .value("TF_OK", TF_OK)
      .value("TF_CANCELLED", TF_CANCELLED)
      .value("TF_UNKNOWN", TF_UNKNOWN)
      .value("TF_INVALID_ARGUMENT", TF_INVALID_ARGUMENT)
      .value("TF_DEADLINE_EXCEEDED", TF_DEADLINE_EXCEEDED)
      .value("TF_PERMISSION_DENIED", TF_PERMISSION_DENIED)
      .value("TF_UNAUTHENTICATED", TF_UNAUTHENTICATED)
      .value("TF_RESOURCE_EXHAUSTED", TF_RESOURCE_EXHAUSTED)
      .value("TF_FAILED_PRECONDITION", TF_FAILED_PRECONDITION)
      .value("TF_ABORTED", TF_ABORTED)
      .value("TF_OUT_OF_RANGE", TF_OUT_OF_RANGE)
      .value("TF_UNIMPLEMENTED", TF_UNIMPLEMENTED)
      .value("TF_INTERNAL", TF_INTERNAL)
      .value("TF_DATA_LOSS", TF_DATA_LOSS)
      .export_values();
};