1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");;
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <memory>
17
18#include "Python.h"
19#include "absl/strings/match.h"
20#include "absl/strings/str_format.h"
21#include "absl/strings/str_split.h"
22#include "pybind11/chrono.h"
23#include "pybind11/complex.h"
24#include "pybind11/functional.h"
25#include "pybind11/pybind11.h"
26#include "pybind11/pytypes.h"
27#include "pybind11/stl.h"
28#include "tensorflow/c/c_api.h"
29#include "tensorflow/c/c_api_experimental.h"
30#include "tensorflow/c/eager/c_api.h"
31#include "tensorflow/c/eager/c_api_experimental.h"
32#include "tensorflow/c/eager/c_api_internal.h"
33#include "tensorflow/c/eager/dlpack.h"
34#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
35#include "tensorflow/c/eager/tfe_context_internal.h"
36#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
37#include "tensorflow/c/tf_status.h"
38#include "tensorflow/c/tf_status_helper.h"
39#include "tensorflow/compiler/jit/flags.h"
40#include "tensorflow/compiler/jit/get_compiler_ir.h"
41#include "tensorflow/core/common_runtime/eager/context.h"
42#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
43#include "tensorflow/python/eager/pywrap_tfe.h"
44#include "tensorflow/python/lib/core/py_exception_registry.h"
45#include "tensorflow/python/lib/core/pybind11_lib.h"
46#include "tensorflow/python/lib/core/pybind11_status.h"
47#include "tensorflow/python/lib/core/safe_ptr.h"
48#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
49#include "tensorflow/python/util/util.h"
50
51namespace py = pybind11;
52
53PYBIND11_MAKE_OPAQUE(TFE_Executor);
54PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
55PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);
56
57PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
58PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
59PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2);
60PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0);
61PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1);
62PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2);
63PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge3);
64PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge4);
65PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0);
66PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1);
67PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2);
68PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0);
69PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1);
70PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2);
71PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0);
72PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1);
73PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2);
74PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell);
75PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell);
76PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell);
77PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell);
78PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell);
79
80PYBIND11_MAKE_OPAQUE(TF_DeviceList);
81PYBIND11_MAKE_OPAQUE(TF_Function);
82PYBIND11_MAKE_OPAQUE(TF_Buffer);
83
84// Eager helper functions migrated from pywrap_tfe.i.
85
86namespace tensorflow {
87
// We cannot use Context as an opaque type. SWIG also had
// difficulty directly passing the pointer around. These
// typemaps are migrated over from pywrap_tfe.i. I tried
// using a custom type caster, but we got segfaults periodically.
92
93// TODO(amitpatankar): Move input and output logic of Context into a
94// pybind11 custom type caster.
95
96TFE_Context* InputTFE_Context(const py::handle& ctx) {
97 return static_cast<TFE_Context*>(PyCapsule_GetPointer(ctx.ptr(), nullptr));
98}
99
100PyObject* OutputTFE_Context(TFE_Context* context) {
101 return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule);
102}
103
104TF_Buffer* ProtoStringToTFBuffer(PyObject* input) {
105 // Convert a Python string object to TF_Buffer.
106 char* c_string;
107 Py_ssize_t py_size;
108 // PyBytes_AsStringAndSize() does not copy but simply interprets the input
109 if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
110 // Python has raised an error (likely TypeError or UnicodeEncodeError).
111 throw py::error_already_set();
112 }
113 return TF_NewBufferFromString(static_cast<void*>(c_string),
114 static_cast<size_t>(py_size));
115}
116
// Converts the Python argument `input_tensors` into a `TFE_InputTensorHandles`
// vector of raw tensor handles.
//
// This is a hand-written typemap from the Python side rather than a pybind11
// custom type caster, because the logic is easier to follow this way. Callers
// in this file include `TFE_Py_ExecuteCancelable_wrapper`,
// `TFE_Py_PackEagerTensors_wrapper` and `TFE_GetCompilerIr`.
//
// Accepted inputs:
//   * `None`: yields an empty vector.
//   * a `list` whose items are each either an exact `EagerTensor`, or an
//     object accepted by `IsEagerTensorSlow` that carries a
//     `_tf_should_use_wrapped_value` attribute (a TFShouldUse wrapper).
//
// Raises TypeError (via `ThrowTypeError`) for a non-list argument, a missing
// list item, an unsupported EagerTensor subclass, a graph tensor (with an
// "out of scope" message including where it was defined), or any other type.
//
// This function does not take ownership of the handles it collects.
// NOTE(review): presumably they stay owned by the Python objects in
// `input_tensors`, so the caller must keep those alive; confirm.
TFE_InputTensorHandles InputTFE_InputTensorHandles(
    const py::handle& input_tensors) {
  TFE_InputTensorHandles input_tensor_handles;
  // `None` means "no inputs" and leaves the vector empty.
  if (input_tensors.ptr() != Py_None) {
    if (!PyList_Check(input_tensors.ptr())) {
      tensorflow::ThrowTypeError("must provide a list of Tensors as inputs");
    }
    Py_ssize_t len = PyList_Size(input_tensors.ptr());
    input_tensor_handles.resize(len);
    for (Py_ssize_t i = 0; i < len; ++i) {
      // Borrowed reference; the list keeps the item alive.
      PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
      if (!elem) {
        tensorflow::ThrowTypeError("Input Tensor does not exist.");
      }
      if (EagerTensor_CheckExact(elem)) {
        // Fast path: a plain EagerTensor.
        (input_tensor_handles)[i] = EagerTensor_Handle(elem);
      } else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
        // Use equivalent of object.__getattribute__ to get the underlying
        // tf wrapped EagerTensor (if there is one).
        tensorflow::Safe_PyObjectPtr tf_should_use_attr(
#if PY_MAJOR_VERSION < 3
            PyString_InternFromString("_tf_should_use_wrapped_value")
#else
            PyUnicode_InternFromString("_tf_should_use_wrapped_value")
#endif
        );
        tensorflow::Safe_PyObjectPtr value_attr(
            PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
        if (value_attr) {
          // This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
          (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get());
        } else {
          // This is a subclass of EagerTensor that we don't support.
          // Discard the pending error from the failed attribute lookup before
          // raising our own TypeError.
          PyErr_Clear();
          tensorflow::ThrowTypeError(
              tensorflow::strings::StrCat(
                  "Saw an object that is an instance of a strict subclass of "
                  "EagerTensor, which is not supported. Item ",
                  i, " is type: ", elem->ob_type->tp_name)
                  .c_str());
        }
      } else if (tensorflow::swig::IsTensor(elem)) {
        // If it isn't an EagerTensor, but is still a Tensor, it must be a graph
        // tensor. Build a descriptive "out of scope" error for it.
        tensorflow::Safe_PyObjectPtr py_tensor_repr(PyObject_Repr(elem));
        std::string tensor_repr =
            py_tensor_repr ? TFE_GetPythonString(py_tensor_repr.get())
                           : "<unknown>";
        // NOTE(review): `py_op` and `c_op` are not checked for NULL before
        // use below; this assumes every object accepted by `IsTensor` exposes
        // `op` and `op._c_op`. Confirm.
        tensorflow::Safe_PyObjectPtr py_op(PyObject_GetAttrString(elem, "op"));
        tensorflow::Safe_PyObjectPtr py_defined_graph(
            PyObject_GetAttrString(py_op.get(), "graph"));
        tensorflow::Safe_PyObjectPtr py_defined_graph_str(
            PyObject_Str(py_defined_graph.get()));
        std::string defined_graph_str =
            py_defined_graph_str
                ? TFE_GetPythonString(py_defined_graph_str.get())
                : "<unknown>";
        tensorflow::Safe_PyObjectPtr c_op(
            PyObject_GetAttrString(py_op.get(), "_c_op"));
        auto& node = py::cast<TF_Operation*>(c_op.get())->node;
        // NOTE(review): `node_name_str` is currently unused.
        auto node_name_str = node.name();
        // NOTE(review): `frame_str` is filled in below but is not part of the
        // error message that is raised.
        std::string frame_str, traceback_str;
        // Render the node's recorded stack trace, if it has one, as an
        // indented traceback.
        if (auto stack_trace = node.GetStackTrace()) {
          auto frame = stack_trace->LastUserFrame();
          frame_str =
              absl::StrFormat("File \"%s\", line %d, in %s", frame.file_name,
                              frame.line_number, frame.function_name);
          auto stack_trace_list =
              absl::StrSplit(stack_trace->ToString({true}), '\n');
          traceback_str = absl::StrJoin(
              stack_trace_list, "", [&](std::string* out, const auto line) {
                absl::StrAppend(out, " ", line, "\n");
              });
        } else {
          frame_str = "<unknown>";
          traceback_str = "<unknown>\n";
        }
        // Keep in sync with func_graph.py.
        // TODO(b/200991648): Unify those two paths.
        tensorflow::ThrowTypeError(
            tensorflow::strings::StrCat(
                tensor_repr,
                " is out of scope and cannot be used here. "
                "Use return values, explicit Python locals or TensorFlow "
                "collections to access it.\n"
                "Please see https://www.tensorflow.org/guide/"
                "function#all_outputs_of_a_tffunction_must_be_return_values "
                "for more information.\n\n",
                tensor_repr, " was defined here:\n", traceback_str,
                "\nThe tensor ", tensor_repr,
                " cannot be accessed from here, because it was "
                "defined in ",
                defined_graph_str, ", which is out of scope.")
                .c_str());
      } else {
        // Anything else (numpy arrays, Python scalars, ...) is rejected.
        tensorflow::ThrowTypeError(
            tensorflow::strings::StrCat(
                "provided list of inputs contains objects other "
                "than 'EagerTensor'. Item ",
                i, " is type: ", elem->ob_type->tp_name)
                .c_str());
      }
    }
  }
  return input_tensor_handles;
}
226
227// These functions are typemaps from the Python side. I did not use
228// a custom type caster since the logic is slightly harder to follow. This
229// converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
230// This function actually takes a number rather than an output Tensor holder.
231TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
232 const py::handle& num_outputs) {
233 TFE_OutputTensorHandles output_tensor_handles;
234#if PY_MAJOR_VERSION < 3
235 if (!PyInt_Check(num_outputs.ptr())) {
236#else
237 if (!PyLong_Check(num_outputs.ptr())) {
238#endif
239 PyErr_SetString(PyExc_TypeError,
240 "expected an integer value (size of the number of "
241 "outputs of the operation)");
242 throw py::error_already_set();
243 }
244#if PY_MAJOR_VERSION < 3
245 long sz = PyInt_AsLong(num_outputs.ptr()); // NOLINT
246#else
247 long sz = PyLong_AsLong(num_outputs.ptr()); // NOLINT
248#endif
249 // PyLong_AsLong might throw an error if an overflow occurs.
250 if (PyErr_Occurred()) {
251 PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
252 "Number of outputs is too big: ", sz)
253 .c_str());
254 throw py::error_already_set();
255 }
256 // We can't handle more than int32 sizes for number of outputs.
257 if (static_cast<long>(static_cast<int32_t>(sz)) != sz) { // NOLINT
258 PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
259 "Number of outputs is too big: ", sz)
260 .c_str());
261 throw py::error_already_set();
262 }
263 if (sz > 0) {
264#if PY_MAJOR_VERSION < 3
265 output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr);
266#else
267 output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr);
268#endif
269 }
270 return output_tensor_handles;
271}
272
273tensorflow::Device* GetMatchedDevice(py::handle& ctx, const char* device_name) {
274 auto* context = reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
275 tensorflow::InputTFE_Context(ctx));
276
277 tensorflow::DeviceNameUtils::ParsedName input_device_name;
278 if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_name,
279 &input_device_name)) {
280 tensorflow::ThrowValueError(
281 absl::StrFormat("Failed parsing device name: '%s'. Note a valid device "
282 "string should at least contain a device type and a "
283 "device index, like \"GPU:0\".",
284 device_name)
285 .c_str());
286 }
287
288 std::vector<tensorflow::Device*> devices = context->ListLocalTfDevices();
289
290 tensorflow::Device* matched_device = nullptr;
291 for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
292 tensorflow::Device* device = devices[device_idx];
293
294 if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
295 input_device_name, device->parsed_name())) {
296 if (matched_device != nullptr) {
297 tensorflow::ThrowValueError(
298 absl::StrFormat("Multiple devices match the provided string "
299 "'%s': '%s' and '%s'.",
300 device_name, matched_device->name(), device->name())
301 .c_str());
302 }
303 matched_device = device;
304 }
305 }
306
307 if (matched_device == nullptr) {
308 tensorflow::ThrowValueError(
309 absl::StrFormat("No matching devices found for '%s'", device_name)
310 .c_str());
311 }
312
313 return matched_device;
314}
315
316// Packs multiple `EagerTensor`s of the same dtype and shape into one
317// `EagerTensor`.
318py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
319 const py::handle& tensors) {
320 TFE_Context* ctx = tensorflow::InputTFE_Context(context);
321 TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors);
322 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
323 int size = handles.size();
324 TFE_TensorHandle* packed_handle =
325 TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get());
326 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
327 PyObject* packed_tensor =
328 EagerTensorFromHandle(packed_handle, /*is_packed=*/true);
329 return tensorflow::PyoOrThrow(packed_tensor);
330}
331
332// This function was created from fusing the typemap logic in platform/base.i.
333py::object TFE_Py_ExecuteCancelable_wrapper(
334 const py::handle& context, const char* device_name, const char* op_name,
335 const py::handle& inputs, const py::handle& attrs,
336 tensorflow::CancellationManager* cancellation_manager,
337 const py::handle& num_outputs) {
338 TFE_Context* ctx = tensorflow::InputTFE_Context(context);
339 TFE_InputTensorHandles input_tensor_handles =
340 InputTFE_InputTensorHandles(inputs);
341 TFE_OutputTensorHandles output_tensor_handles =
342 InputTFE_OutputTensorHandles(num_outputs);
343 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
344 TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
345 attrs.ptr(), tensorflow::wrap(cancellation_manager),
346 &output_tensor_handles, status.get());
347
348 int output_len = output_tensor_handles.size();
349 PyObject* output_list = PyList_New(output_len);
350 for (int i = 0; i < output_len; ++i) {
351 PyObject* output;
352 output = EagerTensorFromHandle(output_tensor_handles.at(i));
353 PyList_SetItem(output_list, i, output);
354 }
355 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
356 return tensorflow::PyoOrThrow(output_list);
357}
358
359static py::object TF_ListPhysicalDevices() {
360 std::vector<string> devices;
361 tensorflow::Status s =
362 tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
363 MaybeRaiseRegisteredFromStatus(s);
364 PyObject* result = PyList_New(devices.size());
365 int i = 0;
366 for (auto& dev : devices) {
367 PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
368 PyList_SetItem(result, i, dev_obj);
369 ++i;
370 }
371 return tensorflow::PyoOrThrow(result);
372}
373
374static py::object TF_ListPluggablePhysicalDevices() {
375 std::vector<string> devices;
376 tensorflow::Status s =
377 tensorflow::DeviceFactory::ListPluggablePhysicalDevices(&devices);
378 MaybeRaiseRegisteredFromStatus(s);
379 Safe_PyObjectPtr result(PyList_New(devices.size()));
380 int i = 0;
381 for (auto& dev : devices) {
382 PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
383 PyList_SetItem(result.get(), i, dev_obj);
384 ++i;
385 }
386 return tensorflow::PyoOrThrow(result.release());
387}
388
389static std::unordered_map<string, string> TF_GetDeviceDetails(int index) {
390 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
391 std::unordered_map<string, string> device_details;
392 tensorflow::Status s =
393 tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details);
394 tensorflow::Set_TF_Status_from_Status(status.get(), s);
395 MaybeRaiseRegisteredFromTFStatus(status.get());
396 return device_details;
397}
398
399static py::object TFE_ClearScalarCache() {
400 tensorflow::TFE_TensorHandleCache::Get()->Clear();
401 return py::none();
402}
403
404// Returns compiler IR for a given function.
405static py::bytes TFE_GetCompilerIr(py::handle& ctx,
406 const char* concrete_function_name,
407 const char* stage, const char* device_name,
408 py::handle& inputs) {
409 EagerContext* context = ContextFromInterface(
410 reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
411
412 std::string s_stage(stage);
413 IrExportStage selected_stage = [&] {
414 if (s_stage == "hlo") {
415 return IrExportStage::HLO;
416 } else if (s_stage == "hlo_no_metadata") {
417 return IrExportStage::HLO_NO_METADATA;
418 } else if (s_stage == "hlo_serialized") {
419 return IrExportStage::HLO_SERIALIZED;
420 } else if (s_stage == "optimized_hlo") {
421 return IrExportStage::OPTIMIZED_HLO;
422 } else if (s_stage == "optimized_hlo_serialized") {
423 return IrExportStage::OPTIMIZED_HLO_SERIALIZED;
424 } else if (s_stage == "optimized_hlo_proto_serialized") {
425 return IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED;
426 } else if (s_stage == "optimized_hlo_dot") {
427 return IrExportStage::OPTIMIZED_HLO_DOT;
428 } else {
429 ThrowValueError(
430 absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
431 "'hlo', 'hlo_serialized', 'optimized_hlo', "
432 "'optimized_hlo_serialized', 'optimized_hlo_dot'",
433 s_stage)
434 .c_str());
435 }
436 }();
437
438 TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
439
440 std::vector<const TensorHandle*> input_handles;
441 for (TFE_TensorHandle* tensor_handle : handles) {
442 AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
443 input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
444 }
445
446 DeviceNameUtils::ParsedName input_device_name;
447 if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
448 ThrowValueError(
449 absl::StrFormat("Failed parsing device name: '%s'", device_name)
450 .c_str());
451 }
452
453 std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
454 auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
455 return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
456 d->parsed_name());
457 });
458 if (selected_device == devices.end()) {
459 ThrowValueError(
460 absl::StrFormat("No matching device found for '%s'", device_name)
461 .c_str());
462 }
463
464 StatusOr<std::string> hlo_str =
465 GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
466 *selected_device, context, input_handles);
467
468 if (!hlo_str.ok()) {
469 ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
470 hlo_str.status().error_message())
471 .c_str());
472 }
473 return py::bytes(*hlo_str);
474}
475
476} // namespace tensorflow
477
478namespace {
479
480// Wrapper around the EagerContextThreadLocalData struct (defined in
481// pywrap_tfe.h), so it can be accessed from Python.
482//
483// For PyObject* fields, the get_*() methods return a new reference; and the
484// set_*() methods create a new reference (i.e., they do not steal a reference).
485class EagerContextThreadLocalDataWrapper {
486 public:
487 explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
488 py::handle is_eager,
489 py::handle device_spec)
490 : py_eager_context_(py_eager_context.ptr()) {
491 tensorflow::MakeEagerContextThreadLocalData(
492 py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
493 }
494
495 ~EagerContextThreadLocalDataWrapper() {
496 tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
497 }
498
499 bool get_is_eager() const { return GetData()->is_eager; }
500 void set_is_eager(bool v) { GetData()->is_eager = v; }
501
502 bool get_invoking_op_callbacks() const {
503 return GetData()->invoking_op_callbacks;
504 }
505 void set_invoking_op_callbacks(bool v) {
506 GetData()->invoking_op_callbacks = v;
507 }
508
509 py::object get_device_name() const {
510 return GetPyObject(&GetData()->device_name);
511 }
512 void set_device_name(py::handle v) {
513 SetPyObject(v, &GetData()->device_name);
514 }
515
516 py::object get_scope_name() const {
517 return GetPyObject(&GetData()->scope_name);
518 }
519 void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }
520
521 py::object get_device_spec() const {
522 return GetPyObject(&GetData()->device_spec);
523 }
524 void set_device_spec(py::handle v) {
525 SetPyObject(v, &GetData()->device_spec);
526 }
527
528 py::object get_function_call_options() const {
529 return GetPyObject(&GetData()->function_call_options);
530 }
531 void set_function_call_options(py::handle v) {
532 SetPyObject(v, &GetData()->function_call_options);
533 }
534
535 py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
536 void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }
537
538 py::object get_op_callbacks() const {
539 return GetPyObject(&GetData()->op_callbacks);
540 }
541 void set_op_callbacks(py::handle v) {
542 SetPyObject(v, &GetData()->op_callbacks);
543 }
544
545 private:
546 tensorflow::EagerContextThreadLocalData* GetData() const {
547 auto* result =
548 tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
549 if (!result) {
550 throw py::error_already_set();
551 }
552 return result;
553 }
554
555 py::object GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
556 return pybind11::reinterpret_borrow<py::object>(obj->get());
557 }
558
559 void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
560 Py_INCREF(value.ptr());
561 ptr->reset(value.ptr());
562 }
563
564 PyObject* py_eager_context_; // not owned (borrowed reference).
565};
566
567} // namespace
568
569// py::return_value_policy::reference is defined as specified by the
570// pybind11 documents listed here.
571// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
572// This means that C++ maintains ownership of the object. We
573// are only assigning this to functions that return opaque types.
574
575PYBIND11_MODULE(_pywrap_tfe, m) {
576 py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
577 py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
578 "TFE_ContextOptions");
579 py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class(
580 m, "TFE_MonitoringCounter0");
581 py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class(
582 m, "TFE_MonitoringCounter1");
583 py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class(
584 m, "TFE_MonitoringCounter2");
585 py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class(
586 m, "TFE_MonitoringStringGauge0");
587 py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class(
588 m, "TFE_MonitoringStringGauge1");
589 py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class(
590 m, "TFE_MonitoringStringGauge2");
591 py::class_<TFE_MonitoringStringGauge3> TFE_MonitoringStringGauge3_class(
592 m, "TFE_MonitoringStringGauge3");
593 py::class_<TFE_MonitoringStringGauge4> TFE_MonitoringStringGauge4_class(
594 m, "TFE_MonitoringStringGauge4");
595 py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class(
596 m, "TFE_MonitoringIntGauge0");
597 py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class(
598 m, "TFE_MonitoringIntGauge1");
599 py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class(
600 m, "TFE_MonitoringIntGauge2");
601 py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class(
602 m, "TFE_MonitoringBoolGauge0");
603 py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class(
604 m, "TFE_MonitoringBoolGauge1");
605 py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class(
606 m, "TFE_MonitoringBoolGauge2");
607 py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class(
608 m, "TFE_MonitoringCounterCell");
609 py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class(
610 m, "TFE_MonitoringIntGaugeCell");
611 py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class(
612 m, "TFE_MonitoringStringGaugeCell");
613 py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class(
614 m, "TFE_MonitoringBoolGaugeCell");
615 py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class(
616 m, "TFE_MonitoringSamplerCell");
617 py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class(
618 m, "TFE_MonitoringBuckets");
619 py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class(
620 m, "TFE_MonitoringSampler0");
621 py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class(
622 m, "TFE_MonitoringSampler1");
623 py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
624 m, "TFE_MonitoringSampler2");
625 py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
626 m, "TFE_CancellationManager");
627
628 py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
629 py::class_<TF_Function> TF_Function_class(m, "TF_Function");
630
631 m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
632 return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr()));
633 });
634 m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) {
635 return tensorflow::PyoOrThrow(
636 TFE_Py_RegisterFallbackExceptionClass(e.ptr()));
637 });
638
639 m.def("TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
640 tensorflow::Device* matched_device =
641 tensorflow::GetMatchedDevice(ctx, device_name);
642
643 tensorflow::AllocatorAttributes attrs;
644 tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
645
646 if (absl::optional<tensorflow::AllocatorStats> stats =
647 allocator->GetStats()) {
648 return std::map<std::string, int64_t>{{"current", stats->bytes_in_use},
649 {"peak", stats->peak_bytes_in_use}};
650 }
651
652 tensorflow::ThrowValueError(
653 absl::StrFormat("Allocator stats not available for device '%s'",
654 device_name)
655 .c_str());
656 });
657
658 m.def("TFE_ResetMemoryStats", [](py::handle& ctx, const char* device_name) {
659 tensorflow::Device* matched_device =
660 tensorflow::GetMatchedDevice(ctx, device_name);
661
662 tensorflow::AllocatorAttributes attrs;
663 tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
664
665 if (!allocator->ClearStats()) {
666 tensorflow::ThrowValueError(
667 absl::StrFormat("Cannot reset memory stats for device '%s'",
668 device_name)
669 .c_str());
670 }
671 });
672
673 // XLA Eager Logic
674 m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation);
675 m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit);
676 m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode);
677 m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
678 m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
679 m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
680 m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
681
682 // MLIR Logic
683 m.def("TF_IsMlirBridgeEnabled", [] {
684 // Since python protobuf enums are integers, cast to an integer before
685 // returning the enum to python.
686 return static_cast<int32_t>(
687 tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
688 });
689 m.def("TF_EnableMlirBridge", [](bool enabled) {
690 tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
691 enabled
692 ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
693 : tensorflow::ConfigProto::Experimental::
694 MLIR_BRIDGE_ROLLOUT_DISABLED;
695 });
696 m.def("TF_EnableXlaDevices", [] {
697 tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
698 });
699 m.def("TF_ResetJitCompilerFlags",
700 [] { tensorflow::ResetJitCompilerFlags(); });
701
702 // TFE_Context Logic
703 m.def(
704 "TFE_NewContext",
705 [](const TFE_ContextOptions* opts) {
706 tensorflow::Safe_TF_StatusPtr status =
707 tensorflow::make_safe(TF_NewStatus());
708 TFE_Context* context = TFE_NewContext(opts, status.get());
709 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
710 return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context));
711 },
712 py::return_value_policy::reference);
713 m.def("TFE_DeleteContext", [](py::handle& o) {
714 TFE_DeleteContext(tensorflow::InputTFE_Context(o));
715 });
716 m.def(
717 "TFE_ContextListDevices",
718 [](py::handle& o) {
719 tensorflow::Safe_TF_StatusPtr status =
720 tensorflow::make_safe(TF_NewStatus());
721 auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o),
722 status.get());
723 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
724 return output;
725 },
726 py::return_value_policy::reference);
727 m.def(
728 "TFE_SetLogicalCpuDevices",
729 [](py::handle& ctx, int num_cpus, const char* prefix) {
730 tensorflow::Safe_TF_StatusPtr status =
731 tensorflow::make_safe(TF_NewStatus());
732 TFE_SetLogicalCpuDevices(tensorflow::InputTFE_Context(ctx), num_cpus,
733 prefix, status.get());
734 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
735 },
736 py::return_value_policy::reference);
737 m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
738 TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
739 });
740 m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
741 tensorflow::Safe_TF_StatusPtr status =
742 tensorflow::make_safe(TF_NewStatus());
743 TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
744 status.get());
745 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
746 });
747 m.def("TFE_ContextAddFunctionDef",
748 [](py::handle& ctx, const char* serialized_function_def, size_t size) {
749 tensorflow::Safe_TF_StatusPtr status =
750 tensorflow::make_safe(TF_NewStatus());
751 TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx),
752 serialized_function_def, size,
753 status.get());
754 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
755 });
756 m.def("TFE_ContextGetFunctionDef",
757 [](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
758 tensorflow::Safe_TF_StatusPtr status =
759 tensorflow::make_safe(TF_NewStatus());
760 TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
761 function_name, &buf, status.get());
762 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
763 });
764 m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
765 tensorflow::Safe_TF_StatusPtr status =
766 tensorflow::make_safe(TF_NewStatus());
767 TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name,
768 status.get());
769 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
770 });
771 m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) {
772 tensorflow::Safe_TF_StatusPtr status =
773 tensorflow::make_safe(TF_NewStatus());
774 auto output =
775 TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name);
776 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
777 return output;
778 });
779 m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
780 return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
781 ->ListFunctionNames();
782 });
783 m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
784 TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
785 });
786 m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) {
787 TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
788 });
789 m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) {
790 TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx));
791 });
792 m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) {
793 TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx));
794 });
795 m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) {
796 tensorflow::Safe_TF_StatusPtr status =
797 tensorflow::make_safe(TF_NewStatus());
798 TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf,
799 status.get());
800 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
801 });
802 m.def("TFE_ContextClearCaches", [](py::handle& o) {
803 TFE_ContextClearCaches(tensorflow::InputTFE_Context(o));
804 });
805 m.def("TFE_GetContextId", [](py::handle& ctx) {
806 return TFE_GetContextId(tensorflow::InputTFE_Context(ctx));
807 });
808 m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) {
809 return TFE_ContextGetDevicePlacementPolicy(
810 tensorflow::InputTFE_Context(ctx));
811 });
812 m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy",
813 [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) {
814 TFE_ContextSetThreadLocalDevicePlacementPolicy(
815 tensorflow::InputTFE_Context(ctx), policy);
816 });
817 m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs,
818 py::bytes proto) {
819 tensorflow::Safe_TF_StatusPtr status =
820 tensorflow::make_safe(TF_NewStatus());
821 tensorflow::Safe_TF_BufferPtr buf =
822 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
823 TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs,
824 buf.get()->data, buf.get()->length, status.get());
825 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
826 });
827 m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs,
828 py::bytes proto) {
829 tensorflow::Safe_TF_StatusPtr status =
830 tensorflow::make_safe(TF_NewStatus());
831 tensorflow::Safe_TF_BufferPtr buf =
832 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
833 Py_BEGIN_ALLOW_THREADS;
834 TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx),
835 keep_alive_secs, buf.get()->data,
836 buf.get()->length, status.get());
837 Py_END_ALLOW_THREADS;
838 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
839 });
840 m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) {
841 tensorflow::Safe_TF_StatusPtr status =
842 tensorflow::make_safe(TF_NewStatus());
843 bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx),
844 worker_name, status.get());
845 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
846 return output;
847 });
848 m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
849 tensorflow::Safe_TF_StatusPtr status =
850 tensorflow::make_safe(TF_NewStatus());
851 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
852 Py_BEGIN_ALLOW_THREADS;
853 TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
854 Py_END_ALLOW_THREADS;
855 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
856 });
857 m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
858 tensorflow::Safe_TF_StatusPtr status =
859 tensorflow::make_safe(TF_NewStatus());
860 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
861 Py_BEGIN_ALLOW_THREADS;
862 TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
863 Py_END_ALLOW_THREADS;
864 // NOTE: different from TFE_ContextSyncExecutors that raises potential
865 // errors, deliberately ignore executor statuses in cleanup.
866 });
867 m.def(
868 "TFE_InsertConfigKeyValue",
869 [](py::handle& ctx, const char* config_key, const char* config_value) {
870 tensorflow::Safe_TF_StatusPtr status =
871 tensorflow::make_safe(TF_NewStatus());
872 Py_BEGIN_ALLOW_THREADS;
873 TFE_InsertConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
874 config_value, status.get());
875 Py_END_ALLOW_THREADS;
876 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
877 },
878 py::return_value_policy::reference);
879 m.def(
880 "TFE_GetConfigKeyValue",
881 [](py::handle& ctx, const char* config_key, TF_Buffer& config_value) {
882 tensorflow::Safe_TF_StatusPtr status =
883 tensorflow::make_safe(TF_NewStatus());
884 Py_BEGIN_ALLOW_THREADS;
885 TFE_GetConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
886 &config_value, status.get());
887 Py_END_ALLOW_THREADS;
888 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
889 },
890 py::return_value_policy::reference);
891 m.def(
892 "TFE_DeleteConfigKeyValue",
893 [](py::handle& ctx, const char* config_key) {
894 tensorflow::Safe_TF_StatusPtr status =
895 tensorflow::make_safe(TF_NewStatus());
896 Py_BEGIN_ALLOW_THREADS;
897 TFE_DeleteConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
898 status.get());
899 Py_END_ALLOW_THREADS;
900 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
901 },
902 py::return_value_policy::reference);
903 m.def(
904 "TFE_ReportErrorToCluster",
905 [](py::handle& ctx, int error_code, const char* error_message) {
906 tensorflow::Safe_TF_StatusPtr status =
907 tensorflow::make_safe(TF_NewStatus());
908 TFE_ReportErrorToCluster(tensorflow::InputTFE_Context(ctx), error_code,
909 error_message, status.get());
910 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
911 },
912 py::return_value_policy::reference);
913 m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
914 tensorflow::Safe_TF_StatusPtr status =
915 tensorflow::make_safe(TF_NewStatus());
916 TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
917 status.get());
918 });
919 m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
920 tensorflow::Safe_TF_StatusPtr status =
921 tensorflow::make_safe(TF_NewStatus());
922 TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
923 status.get());
924 });
925 m.def("TFE_ContextSetRunEagerOpAsFunction", [](py::handle& ctx, bool enable) {
926 tensorflow::Safe_TF_StatusPtr status =
927 tensorflow::make_safe(TF_NewStatus());
928 TFE_ContextSetRunEagerOpAsFunction(tensorflow::InputTFE_Context(ctx),
929 enable, status.get());
930 });
931 m.def("TFE_ContextSetJitCompileRewrite", [](py::handle& ctx, bool enable) {
932 tensorflow::Safe_TF_StatusPtr status =
933 tensorflow::make_safe(TF_NewStatus());
934 TFE_ContextSetJitCompileRewrite(tensorflow::InputTFE_Context(ctx), enable,
935 status.get());
936 });
937
938 // TFE_Executor logic
939 m.def(
940 "TFE_NewExecutor",
941 [](const bool is_async, const bool enable_streaming_enqueue) {
942 TFE_Executor* exc = TFE_NewExecutor(is_async, enable_streaming_enqueue);
943 return exc;
944 },
945 py::return_value_policy::reference);
946 m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor);
947 m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync);
948 m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) {
949 tensorflow::Safe_TF_StatusPtr status =
950 tensorflow::make_safe(TF_NewStatus());
951 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
952 Py_BEGIN_ALLOW_THREADS;
953 TFE_ExecutorWaitForAllPendingNodes(&exc, status.get());
954 Py_END_ALLOW_THREADS;
955 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
956 });
957 m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError);
958 m.def("TFE_ContextSetExecutorForThread", [](py::handle& ctx,
959 TFE_Executor& exc) {
960 TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc);
961 });
962 m.def(
963 "TFE_ContextGetExecutorForThread",
964 [](py::handle& o) {
965 return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o));
966 },
967 py::return_value_policy::reference);
968
969 m.def("TFE_OpNameGetAttrType",
970 [](py::handle& ctx, const char* op_or_function_name,
971 const char* attr_name) {
972 int temp = 0;
973 unsigned char* is_list = reinterpret_cast<unsigned char*>(&temp);
974 tensorflow::Safe_TF_StatusPtr status =
975 tensorflow::make_safe(TF_NewStatus());
976 auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx),
977 op_or_function_name, attr_name,
978 is_list, status.get());
979 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
980#if PY_MAJOR_VERSION < 3
981 PyObject* output_pyo = PyInt_FromLong(output);
982#else
983 PyObject* output_pyo = PyLong_FromLong(output);
984#endif
985 if (*is_list == 1) {
986 PyObject* list = PyList_New(1);
987 PyList_SetItem(list, 0, output_pyo);
988 return tensorflow::PyoOrThrow(list);
989 }
990 return tensorflow::PyoOrThrow(output_pyo);
991 });
992 m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
993 return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr()));
994 });
995 m.def("TFE_Py_PackEagerTensors",
996 [](const py::handle& context, const py::handle& handles) {
997 return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
998 });
999 m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler);
1000 m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
1001 return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr()));
1002 });
1003 m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) {
1004 return tensorflow::PyoOrThrow(TFE_Py_RegisterGradientFunction(o.ptr()));
1005 });
1006 m.def("TFE_Py_Execute",
1007 [](const py::handle& context, const char* device_name,
1008 const char* op_name, const py::handle& inputs,
1009 const py::handle& attrs, const py::handle& num_outputs) {
1010 return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
1011 context, device_name, op_name, inputs, attrs.ptr(), nullptr,
1012 num_outputs);
1013 });
1014 m.def(
1015 "TFE_Py_ExecuteCancelable",
1016 [](const py::handle& context, const char* device_name,
1017 const char* op_name, const py::handle& inputs, const py::handle& attrs,
1018 tensorflow::CancellationManager& cancellation_manager,
1019 const py::handle& num_outputs) {
1020 return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
1021 context, device_name, op_name, inputs, attrs.ptr(),
1022 &cancellation_manager, num_outputs);
1023 });
1024 m.def("TFE_Py_FastPathExecute", [](const py::args args) {
1025 // TFE_Py_FastPathExecute requires error checking prior to returning.
1026 return tensorflow::PyoOrThrow(TFE_Py_FastPathExecute_C(args.ptr()));
1027 });
1028 m.def("TFE_Py_RecordGradient",
1029 [](const py::handle& op_name, const py::handle& inputs,
1030 const py::handle& attrs, const py::handle& results,
1031 const py::handle& forward_pass_name_scope) {
1032 return tensorflow::PyoOrThrow(TFE_Py_RecordGradient(
1033 op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
1034 forward_pass_name_scope.ptr()));
1035 });
1036 m.def("TFE_Py_UID", []() { return tensorflow::PyoOrThrow(TFE_Py_UID()); });
1037
1038 // TFE_Py_Tape Logic
1039 m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent,
1040 const py::handle& watch_accessed_variables) {
1041 return tensorflow::PyoOrThrow(
1042 TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr()));
1043 });
1044 m.def("TFE_Py_TapeSetAdd",
1045 [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); });
1046 m.def("TFE_Py_TapeSetRemove",
1047 [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); });
1048 m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread);
1049 m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread);
1050 m.def("TFE_Py_TapeSetIsStopped",
1051 []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsStopped()); });
1052 m.def("TFE_Py_TapeSetIsEmpty",
1053 []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsEmpty()); });
1054 m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) {
1055 return tensorflow::PyoOrThrow(
1056 TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr()));
1057 });
1058 m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) {
1059 return tensorflow::PyoOrThrow(
1060 TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr()));
1061 });
1062 m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace);
1063 m.def("TFE_Py_TapeSetRecordOperation",
1064 [](const py::handle& op_type, const py::handle& output_tensors,
1065 const py::handle& input_tensors, const py::handle& backward_function,
1066 const py::handle& forward_function) {
1067 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperation(
1068 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1069 backward_function.ptr(), forward_function.ptr()));
1070 });
1071 m.def(
1072 "TFE_Py_TapeSetRecordOperationBackprop",
1073 [](const py::handle& op_type, const py::handle& output_tensors,
1074 const py::handle& input_tensors, const py::handle& backward_function) {
1075 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationBackprop(
1076 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1077 backward_function.ptr()));
1078 });
1079 m.def(
1080 "TFE_Py_TapeSetRecordOperationForwardprop",
1081 [](const py::handle& op_type, const py::handle& output_tensors,
1082 const py::handle& input_tensors, const py::handle& backward_function,
1083 const py::handle& forwardprop_output_indices) {
1084 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationForwardprop(
1085 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1086 backward_function.ptr(), forwardprop_output_indices.ptr()));
1087 });
1088 m.def("TFE_Py_TapeGradient",
1089 [](const py::handle& tape, const py::handle& target,
1090 const py::handle& sources, const py::handle& output_gradients,
1091 const py::handle& sources_raw,
1092 const py::handle& unconnected_gradients) {
1093 tensorflow::Safe_TF_StatusPtr status =
1094 tensorflow::make_safe(TF_NewStatus());
1095 PyObject* output = TFE_Py_TapeGradient(
1096 tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(),
1097 sources_raw.ptr(), unconnected_gradients.ptr(), status.get());
1098 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1099 return tensorflow::PyoOrThrow(output);
1100 });
1101
1102 m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) {
1103 TFE_Py_TapeVariableAccessed(variable.ptr());
1104 });
1105 m.def("TFE_Py_TapeWatch",
1106 [](const py::handle& tape, const py::handle& tensor) {
1107 TFE_Py_TapeWatch(tape.ptr(), tensor.ptr());
1108 });
1109 m.def("TFE_Py_TapeWatchVariable",
1110 [](const py::handle& tape, const py::handle& variable) {
1111 TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr());
1112 });
1113 m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) {
1114 return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
1115 });
1116
1117 // TFE_Py_VariableWatcher logic.
1118 m.def("TFE_Py_VariableWatcherNew",
1119 []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
1120 m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
1121 TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
1122 });
1123 m.def("TFE_Py_VariableWatcherVariableAccessed",
1124 [](const py::handle& variable) {
1125 TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
1126 });
1127 m.def("TFE_Py_VariableWatcherWatchedVariables",
1128 [](const py::handle& variable_watcher) {
1129 return tensorflow::PyoOrThrow(
1130 TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
1131 });
1132
1133 // TFE_Py_ForwardAccumulator logic.
1134 m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
1135 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
1136 });
1137
1138 m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {
1139 return tensorflow::PyoOrThrow(
1140 TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr()));
1141 });
1142 m.def("TFE_Py_ForwardAccumulatorSetRemove",
1143 [](const py::handle& accumulator) {
1144 TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr());
1145 });
1146
1147 m.def("TFE_Py_ForwardAccumulatorWatch",
1148 [](const py::handle& accumulator, const py::handle& tensor,
1149 const py::handle& tangent) {
1150 TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(),
1151 tangent.ptr());
1152 });
1153 m.def("TFE_Py_ForwardAccumulatorJVP",
1154 [](const py::handle& accumulator, const py::handle& tensor) {
1155 return tensorflow::PyoOrThrow(
1156 TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr()));
1157 });
1158 m.def("TFE_Py_ForwardAccumulatorPushState", []() {
1159 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPushState());
1160 });
1161 m.def("TFE_Py_ForwardAccumulatorPopState", []() {
1162 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPopState());
1163 });
1164 m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) {
1165 return tensorflow::PyoOrThrow(TFE_Py_PackJVPs(tensors.ptr()));
1166 });
1167
1168 // TFE_ContextOptions Logic
1169 m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
1170 py::return_value_policy::reference);
1171 m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options,
1172 py::bytes proto) {
1173 tensorflow::Safe_TF_StatusPtr status =
1174 tensorflow::make_safe(TF_NewStatus());
1175 tensorflow::Safe_TF_BufferPtr buf =
1176 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1177 TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length,
1178 status.get());
1179 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1180 });
1181 m.def("TFE_ContextOptionsSetDevicePlacementPolicy",
1182 &TFE_ContextOptionsSetDevicePlacementPolicy);
1183 m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
1184 m.def("TFE_ContextOptionsSetTfrtDistributedRuntime",
1185 &TFE_ContextOptionsSetTfrtDistributedRuntime);
1186 // Experimental feature, intentionally not exposed as a C API yet.
1187 m.def("TFE_ContextOptionsSetRunEagerOpAsFunction",
1188 [](TFE_ContextOptions* options, bool run_eager_op_as_function) {
1189 options->run_eager_op_as_function = run_eager_op_as_function;
1190 });
1191 m.def("TFE_ContextOptionsSetJitCompileRewrite",
1192 [](TFE_ContextOptions* options, bool jit_compile_rewrite) {
1193 options->jit_compile_rewrite = jit_compile_rewrite;
1194 });
1195 m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
1196 m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,
1197 py::return_value_policy::reference);
1198
1199 // TFE_Py_TensorShape Logic
1200 m.def("TFE_Py_TensorShapeSlice",
1201 [](const py::handle& tensors, int slice_dim) {
1202 return tensorflow::PyoOrThrow(
1203 TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim));
1204 });
1205 m.def("TFE_Py_TensorShapeOnDevice", [](const py::handle& tensors,
1206 int slice_dim) {
1207 return tensorflow::PyoOrThrow(TFE_Py_TensorShapeOnDevice(tensors.ptr()));
1208 });
1209 m.def("TFE_Py_EnableInteractivePythonLogging",
1210 &TFE_Py_EnableInteractivePythonLogging);
1211
1212 // Additional Context Logic
1213 m.def("TFE_Py_SetEagerContext", [](const py::handle& o) {
1214 return tensorflow::PyoOrThrow(TFE_Py_SetEagerContext(o.ptr()));
1215 });
1216 m.def("TFE_Py_SetCEagerContext", [](const py::handle& ctx) {
1217 // TODO(mdan): This cast might need rewriting to ImmediateExecutionContext.
1218 tensorflow::SetCEagerContext(reinterpret_cast<tensorflow::EagerContext*>(
1219 tensorflow::InputTFE_Context(ctx)));
1220 });
1221 m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
1222 return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
1223 });
1224 m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) {
1225 tensorflow::Safe_TF_StatusPtr status =
1226 tensorflow::make_safe(TF_NewStatus());
1227 tensorflow::Safe_TF_BufferPtr buf =
1228 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1229 TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data,
1230 buf.get()->length, status.get());
1231 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1232 });
1233 m.def("TFE_AbortCollectiveOps", [](const py::handle& ctx, int code,
1234 const char* message) {
1235 tensorflow::Safe_TF_StatusPtr status =
1236 tensorflow::make_safe(TF_NewStatus());
1237 TF_SetStatus(status.get(), static_cast<TF_Code>(code), message);
1238 TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
1239 });
1240 m.def("TFE_CollectiveOpsCheckPeerHealth",
1241 [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
1242 tensorflow::Safe_TF_StatusPtr status =
1243 tensorflow::make_safe(TF_NewStatus());
1244 TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
1245 task, timeout_in_ms, status.get());
1246 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1247 });
1248 m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
1249 m.def("TF_ListPluggablePhysicalDevices",
1250 &tensorflow::TF_ListPluggablePhysicalDevices);
1251 m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
1252 m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
1253 py::return_value_policy::reference);
1254 m.def("TF_DeviceListCount", &TF_DeviceListCount);
1255 m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) {
1256 tensorflow::Safe_TF_StatusPtr status =
1257 tensorflow::make_safe(TF_NewStatus());
1258 auto output = TF_DeviceListName(list, index, status.get());
1259 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1260 return output;
1261 });
1262 m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) {
1263 tensorflow::Safe_TF_StatusPtr status =
1264 tensorflow::make_safe(TF_NewStatus());
1265 auto output = TF_DeviceListType(list, index, status.get());
1266 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1267 return output;
1268 });
1269
1270 m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie);
1271
1272 // TFE_MonitoringCounter Logic
1273 m.def("TFE_MonitoringCounterCellIncrementBy",
1274 &TFE_MonitoringCounterCellIncrementBy);
1275 m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue);
1276 m.def(
1277 "TFE_MonitoringNewCounter0",
1278 [](const char* name, const char* description) {
1279 tensorflow::Safe_TF_StatusPtr status =
1280 tensorflow::make_safe(TF_NewStatus());
1281 auto output =
1282 TFE_MonitoringNewCounter0(name, status.get(), description);
1283 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1284 return output;
1285 },
1286 py::return_value_policy::reference);
1287 m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0,
1288 py::return_value_policy::reference);
1289 m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0,
1290 py::return_value_policy::reference);
1291 m.def(
1292 "TFE_MonitoringNewCounter1",
1293 [](const char* name, const char* description, const char* label1) {
1294 tensorflow::Safe_TF_StatusPtr status =
1295 tensorflow::make_safe(TF_NewStatus());
1296 auto output =
1297 TFE_MonitoringNewCounter1(name, status.get(), description, label1);
1298 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1299 return output;
1300 },
1301 py::return_value_policy::reference);
1302 m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1,
1303 py::return_value_policy::reference);
1304 m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1,
1305 py::return_value_policy::reference);
1306 m.def(
1307 "TFE_MonitoringNewCounter2",
1308 [](const char* name, const char* description, const char* label1,
1309 const char* label2) {
1310 tensorflow::Safe_TF_StatusPtr status =
1311 tensorflow::make_safe(TF_NewStatus());
1312 auto output = TFE_MonitoringNewCounter2(name, status.get(), description,
1313 label1, label2);
1314 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1315 return output;
1316 },
1317 py::return_value_policy::reference);
1318 m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2,
1319 py::return_value_policy::reference);
1320 m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2,
1321 py::return_value_policy::reference);
1322
1323 // TFE_MonitoringIntGauge Logic
1324 m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet);
1325 m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue);
1326 m.def(
1327 "TFE_MonitoringNewIntGauge0",
1328 [](const char* name, const char* description) {
1329 tensorflow::Safe_TF_StatusPtr status =
1330 tensorflow::make_safe(TF_NewStatus());
1331 auto output =
1332 TFE_MonitoringNewIntGauge0(name, status.get(), description);
1333 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1334 return output;
1335 },
1336 py::return_value_policy::reference);
1337 m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0,
1338 py::return_value_policy::reference);
1339 m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0,
1340 py::return_value_policy::reference);
1341 m.def(
1342 "TFE_MonitoringNewIntGauge1",
1343 [](const char* name, const char* description, const char* label1) {
1344 tensorflow::Safe_TF_StatusPtr status =
1345 tensorflow::make_safe(TF_NewStatus());
1346 auto output =
1347 TFE_MonitoringNewIntGauge1(name, status.get(), description, label1);
1348 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1349 return output;
1350 },
1351 py::return_value_policy::reference);
1352 m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1,
1353 py::return_value_policy::reference);
1354 m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1,
1355 py::return_value_policy::reference);
1356 m.def(
1357 "TFE_MonitoringNewIntGauge2",
1358 [](const char* name, const char* description, const char* label1,
1359 const char* label2) {
1360 tensorflow::Safe_TF_StatusPtr status =
1361 tensorflow::make_safe(TF_NewStatus());
1362 auto output = TFE_MonitoringNewIntGauge2(name, status.get(),
1363 description, label1, label2);
1364 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1365 return output;
1366 },
1367 py::return_value_policy::reference);
1368 m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2,
1369 py::return_value_policy::reference);
1370 m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2,
1371 py::return_value_policy::reference);
1372 m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet);
1373 m.def("TFE_MonitoringStringGaugeCellValue",
1374 &TFE_MonitoringStringGaugeCellValue);
1375 m.def(
1376 "TFE_MonitoringNewStringGauge0",
1377 [](const char* name, const char* description) {
1378 tensorflow::Safe_TF_StatusPtr status =
1379 tensorflow::make_safe(TF_NewStatus());
1380 auto output =
1381 TFE_MonitoringNewStringGauge0(name, status.get(), description);
1382 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1383 return output;
1384 },
1385 py::return_value_policy::reference);
1386
1387 // TFE_MonitoringStringGauge Logic
1388 m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0);
1389 m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0,
1390 py::return_value_policy::reference);
1391 m.def(
1392 "TFE_MonitoringNewStringGauge1",
1393 [](const char* name, const char* description, const char* label1) {
1394 tensorflow::Safe_TF_StatusPtr status =
1395 tensorflow::make_safe(TF_NewStatus());
1396 auto output = TFE_MonitoringNewStringGauge1(name, status.get(),
1397 description, label1);
1398 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1399 return output;
1400 },
1401 py::return_value_policy::reference);
1402 m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1);
1403 m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1,
1404 py::return_value_policy::reference);
1405 m.def(
1406 "TFE_MonitoringNewStringGauge2",
1407 [](const char* name, const char* description, const char* label1,
1408 const char* label2) {
1409 tensorflow::Safe_TF_StatusPtr status =
1410 tensorflow::make_safe(TF_NewStatus());
1411 auto output = TFE_MonitoringNewStringGauge2(
1412 name, status.get(), description, label1, label2);
1413 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1414 return output;
1415 },
1416 py::return_value_policy::reference);
1417 m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2);
1418 m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2,
1419 py::return_value_policy::reference);
1420
1421 m.def(
1422 "TFE_MonitoringNewStringGauge3",
1423 [](const char* name, const char* description, const char* label1,
1424 const char* label2, const char* label3) {
1425 tensorflow::Safe_TF_StatusPtr status =
1426 tensorflow::make_safe(TF_NewStatus());
1427 auto output = TFE_MonitoringNewStringGauge3(
1428 name, status.get(), description, label1, label2, label3);
1429 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1430 return output;
1431 },
1432 py::return_value_policy::reference);
1433 m.def("TFE_MonitoringDeleteStringGauge3", &TFE_MonitoringDeleteStringGauge3);
1434 m.def("TFE_MonitoringGetCellStringGauge3", &TFE_MonitoringGetCellStringGauge3,
1435 py::return_value_policy::reference);
1436
1437 m.def(
1438 "TFE_MonitoringNewStringGauge4",
1439 [](const char* name, const char* description, const char* label1,
1440 const char* label2, const char* label3, const char* label4) {
1441 tensorflow::Safe_TF_StatusPtr status =
1442 tensorflow::make_safe(TF_NewStatus());
1443 auto output = TFE_MonitoringNewStringGauge4(
1444 name, status.get(), description, label1, label2, label3, label4);
1445 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1446 return output;
1447 },
1448 py::return_value_policy::reference);
1449 m.def("TFE_MonitoringDeleteStringGauge4", &TFE_MonitoringDeleteStringGauge4);
1450 m.def("TFE_MonitoringGetCellStringGauge4", &TFE_MonitoringGetCellStringGauge4,
1451 py::return_value_policy::reference);
1452
1453 // TFE_MonitoringBoolGauge Logic
1454 m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet);
1455 m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue);
1456 m.def(
1457 "TFE_MonitoringNewBoolGauge0",
1458 [](const char* name, const char* description) {
1459 tensorflow::Safe_TF_StatusPtr status =
1460 tensorflow::make_safe(TF_NewStatus());
1461 auto output =
1462 TFE_MonitoringNewBoolGauge0(name, status.get(), description);
1463 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1464 return output;
1465 },
1466 py::return_value_policy::reference);
1467 m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0,
1468 py::return_value_policy::reference);
1469 m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0,
1470 py::return_value_policy::reference);
1471 m.def(
1472 "TFE_MonitoringNewBoolGauge1",
1473 [](const char* name, const char* description, const char* label1) {
1474 tensorflow::Safe_TF_StatusPtr status =
1475 tensorflow::make_safe(TF_NewStatus());
1476 auto output = TFE_MonitoringNewBoolGauge1(name, status.get(),
1477 description, label1);
1478 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1479 return output;
1480 },
1481 py::return_value_policy::reference);
1482 m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1,
1483 py::return_value_policy::reference);
1484 m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1,
1485 py::return_value_policy::reference);
1486 m.def(
1487 "TFE_MonitoringNewBoolGauge2",
1488 [](const char* name, const char* description, const char* label1,
1489 const char* label2) {
1490 tensorflow::Safe_TF_StatusPtr status =
1491 tensorflow::make_safe(TF_NewStatus());
1492 auto output = TFE_MonitoringNewBoolGauge2(name, status.get(),
1493 description, label1, label2);
1494 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1495 return output;
1496 },
1497 py::return_value_policy::reference);
1498 m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2,
1499 py::return_value_policy::reference);
1500 m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2,
1501 py::return_value_policy::reference);
1502
1503 // TFE_MonitoringSampler Logic
1504 m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd);
1505 m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue);
1506 m.def("TFE_MonitoringNewExponentialBuckets",
1507 &TFE_MonitoringNewExponentialBuckets,
1508 py::return_value_policy::reference);
1509 m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets,
1510 py::return_value_policy::reference);
1511 m.def(
1512 "TFE_MonitoringNewSampler0",
1513 [](const char* name, TFE_MonitoringBuckets* buckets,
1514 const char* description) {
1515 tensorflow::Safe_TF_StatusPtr status =
1516 tensorflow::make_safe(TF_NewStatus());
1517 auto output =
1518 TFE_MonitoringNewSampler0(name, buckets, status.get(), description);
1519 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1520 return output;
1521 },
1522 py::return_value_policy::reference);
1523 m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0,
1524 py::return_value_policy::reference);
1525 m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0,
1526 py::return_value_policy::reference);
1527 m.def(
1528 "TFE_MonitoringNewSampler1",
1529 [](const char* name, TFE_MonitoringBuckets* buckets,
1530 const char* description, const char* label1) {
1531 tensorflow::Safe_TF_StatusPtr status =
1532 tensorflow::make_safe(TF_NewStatus());
1533 auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(),
1534 description, label1);
1535 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1536 return output;
1537 },
1538 py::return_value_policy::reference);
1539 m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1,
1540 py::return_value_policy::reference);
1541 m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1,
1542 py::return_value_policy::reference);
1543 m.def(
1544 "TFE_MonitoringNewSampler2",
1545 [](const char* name, TFE_MonitoringBuckets* buckets,
1546 const char* description, const char* label1, const char* label2) {
1547 tensorflow::Safe_TF_StatusPtr status =
1548 tensorflow::make_safe(TF_NewStatus());
1549 auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(),
1550 description, label1, label2);
1551 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1552 return output;
1553 },
1554 py::return_value_policy::reference);
1555 m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2,
1556 py::return_value_policy::reference);
1557 m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2,
1558 py::return_value_policy::reference);
1559
1560 // TFE_CancellationManager Logic
1561 m.def("TFE_NewCancellationManager",
1562 []() { return new tensorflow::CancellationManager(); });
1563 m.def("TFE_CancellationManagerIsCancelled",
1564 &tensorflow::CancellationManager::IsCancelled);
1565 m.def("TFE_CancellationManagerStartCancel",
1566 &tensorflow::CancellationManager::StartCancel);
1567
1568 m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);
1569
1570 // Util buffer helper functions
1571 m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
1572 py::return_value_policy::reference);
1573
1574 // DLPack functions
1575 m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
1576 PyObject* eager_tensor_pyobject_ptr = o.ptr();
1577 tensorflow::Safe_TF_StatusPtr status =
1578 tensorflow::make_safe(TF_NewStatus());
1579
1580 if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
1581 status->status = tensorflow::errors::InvalidArgument(
1582 "The argument to `to_dlpack` must be a TF tensor, not Python object");
1583 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1584 }
1585
1586 TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
1587 void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
1588 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1589
1590 py::capsule capsule(
1591 dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
1592 if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
1593 void* dlm_rptr =
1594 PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
1595 if (dlm_rptr) {
1596 tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
1597 PyCapsule_SetDestructor(capsule, nullptr);
1598 }
1599 }
1600 });
1601 return capsule;
1602 });
1603
1604 m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
1605 const py::handle& context) {
1606 tensorflow::Safe_TF_StatusPtr status =
1607 tensorflow::make_safe(TF_NewStatus());
1608 if (absl::string_view(pycapsule.name()) !=
1609 tensorflow::kDlTensorCapsuleName) {
1610 status->status = tensorflow::errors::InvalidArgument(
1611 "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
1612 "Note that a DLPack tensor may be consumed at most once.",
1613 absl::string_view(pycapsule.name()));
1614 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1615 }
1616
1617 TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
1618 pycapsule, status.get(), tensorflow::InputTFE_Context(context));
1619
1620 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1621
1622 PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
1623 PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
1624
1625 PyObject* pyhandle = EagerTensorFromHandle(thandle);
1626 return tensorflow::PyoOrThrow(pyhandle);
1627 });
1628
1629 m.def("TFE_Py_IsCustomDevice",
1630 [](const py::handle& context, const char* device_name) {
1631 return TFE_IsCustomDevice(tensorflow::InputTFE_Context(context),
1632 device_name);
1633 });
1634
1635 m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
1636 const py::capsule& device,
1637 const char* device_name,
1638 const py::capsule& device_info) {
1639 tensorflow::Safe_TF_StatusPtr status =
1640 tensorflow::make_safe(TF_NewStatus());
1641 if (absl::string_view(device.name()) != "TFE_CustomDevice") {
1642 status->status = tensorflow::errors::InvalidArgument(
1643 "Expected a capsule named 'TFE_CustomDevice' for the `device` "
1644 "argument, got ",
1645 absl::string_view(device.name()));
1646 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1647 }
1648 if (absl::string_view(device_info.name()) !=
1649 "TFE_CustomDevice_DeviceInfo") {
1650 status->status = tensorflow::errors::InvalidArgument(
1651 "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
1652 "the `device_info` argument, got ",
1653 absl::string_view(device_info.name()));
1654 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1655 }
1656 // TFE_RegisterCustomDevice takes ownership
1657 PyCapsule_SetDestructor(device_info.ptr(), nullptr);
1658 TFE_RegisterCustomDevice(
1659 tensorflow::InputTFE_Context(context),
1660 *reinterpret_cast<TFE_CustomDevice*>(
1661 PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
1662 device_name,
1663 PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
1664 status.get());
1665 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1666 });
1667
1668 py::class_<EagerContextThreadLocalDataWrapper>(m,
1669 "EagerContextThreadLocalData")
1670 .def(py::init<py::handle, py::handle, py::handle>(),
1671 py::arg("py_eager_context"), py::arg("is_eager"),
1672 py::arg("device_spec"))
1673 .def_property("is_eager",
1674 &EagerContextThreadLocalDataWrapper::get_is_eager,
1675 &EagerContextThreadLocalDataWrapper::set_is_eager)
1676 .def_property(
1677 "invoking_op_callbacks",
1678 &EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
1679 &EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
1680 .def_property("device_name",
1681 &EagerContextThreadLocalDataWrapper::get_device_name,
1682 &EagerContextThreadLocalDataWrapper::set_device_name)
1683 .def_property("scope_name",
1684 &EagerContextThreadLocalDataWrapper::get_scope_name,
1685 &EagerContextThreadLocalDataWrapper::set_scope_name)
1686 .def_property("device_spec",
1687 &EagerContextThreadLocalDataWrapper::get_device_spec,
1688 &EagerContextThreadLocalDataWrapper::set_device_spec)
1689 .def_property(
1690 "function_call_options",
1691 &EagerContextThreadLocalDataWrapper::get_function_call_options,
1692 &EagerContextThreadLocalDataWrapper::set_function_call_options)
1693 .def_property("executor",
1694 &EagerContextThreadLocalDataWrapper::get_executor,
1695 &EagerContextThreadLocalDataWrapper::set_executor)
1696 .def_property("op_callbacks",
1697 &EagerContextThreadLocalDataWrapper::get_op_callbacks,
1698 &EagerContextThreadLocalDataWrapper::set_op_callbacks);
1699
1700 // C API Enum
1701
1702 py::enum_<TFE_ContextDevicePlacementPolicy>(
1703 m, "TFE_ContextDevicePlacementPolicy")
1704 .value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT)
1705 .value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN)
1706 .value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT)
1707 .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32",
1708 TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
1709 .export_values();
1710
1711 py::enum_<TF_AttrType>(m, "TF_AttrType")
1712 .value("TF_ATTR_STRING", TF_ATTR_STRING)
1713 .value("TF_ATTR_INT", TF_ATTR_INT)
1714 .value("TF_ATTR_FLOAT", TF_ATTR_FLOAT)
1715 .value("TF_ATTR_BOOL", TF_ATTR_BOOL)
1716 .value("TF_ATTR_TYPE", TF_ATTR_TYPE)
1717 .value("TF_ATTR_SHAPE", TF_ATTR_SHAPE)
1718 .value("TF_ATTR_TENSOR", TF_ATTR_TENSOR)
1719 .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER)
1720 .value("TF_ATTR_FUNC", TF_ATTR_FUNC)
1721 .export_values();
1722};
1723