1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <memory> |
17 | |
18 | #include "Python.h" |
19 | #include "absl/strings/match.h" |
20 | #include "absl/strings/str_format.h" |
21 | #include "absl/strings/str_split.h" |
22 | #include "pybind11/chrono.h" |
23 | #include "pybind11/complex.h" |
24 | #include "pybind11/functional.h" |
25 | #include "pybind11/pybind11.h" |
26 | #include "pybind11/pytypes.h" |
27 | #include "pybind11/stl.h" |
28 | #include "tensorflow/c/c_api.h" |
29 | #include "tensorflow/c/c_api_experimental.h" |
30 | #include "tensorflow/c/eager/c_api.h" |
31 | #include "tensorflow/c/eager/c_api_experimental.h" |
32 | #include "tensorflow/c/eager/c_api_internal.h" |
33 | #include "tensorflow/c/eager/dlpack.h" |
34 | #include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" |
35 | #include "tensorflow/c/eager/tfe_context_internal.h" |
36 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
37 | #include "tensorflow/c/tf_status.h" |
38 | #include "tensorflow/c/tf_status_helper.h" |
39 | #include "tensorflow/compiler/jit/flags.h" |
40 | #include "tensorflow/compiler/jit/get_compiler_ir.h" |
41 | #include "tensorflow/core/common_runtime/eager/context.h" |
42 | #include "tensorflow/python/eager/pywrap_tensor_conversion.h" |
43 | #include "tensorflow/python/eager/pywrap_tfe.h" |
44 | #include "tensorflow/python/lib/core/py_exception_registry.h" |
45 | #include "tensorflow/python/lib/core/pybind11_lib.h" |
46 | #include "tensorflow/python/lib/core/pybind11_status.h" |
47 | #include "tensorflow/python/lib/core/safe_ptr.h" |
48 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
49 | #include "tensorflow/python/util/util.h" |
50 | |
51 | namespace py = pybind11; |
52 | |
// Types below are declared opaque to pybind11 via PYBIND11_MAKE_OPAQUE so that
// they cross the Python boundary as opaque handles instead of going through a
// type conversion. Several of them are registered with py::class_ inside
// PYBIND11_MODULE below.

// Executor, context options and cancellation manager.
PYBIND11_MAKE_OPAQUE(TFE_Executor);
PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);

// Monitoring metrics (counters, gauges, samplers) and their cells.
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge3);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge4);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell);

// TensorFlow C API types.
PYBIND11_MAKE_OPAQUE(TF_DeviceList);
PYBIND11_MAKE_OPAQUE(TF_Function);
PYBIND11_MAKE_OPAQUE(TF_Buffer);
83 | |
84 | // Eager helper functions migrated from pywrap_tfe.i. |
85 | |
86 | namespace tensorflow { |
87 | |
// We cannot use Context as an opaque type. SWIG also had
// difficulty passing the pointer around directly. These
// typemaps are migrated over from pywrap_tfe.i. I tried
// using a custom type caster, but we get segfaults periodically.
92 | |
93 | // TODO(amitpatankar): Move input and output logic of Context into a |
94 | // pybind11 custom type caster. |
95 | |
96 | TFE_Context* InputTFE_Context(const py::handle& ctx) { |
97 | return static_cast<TFE_Context*>(PyCapsule_GetPointer(ctx.ptr(), nullptr)); |
98 | } |
99 | |
100 | PyObject* OutputTFE_Context(TFE_Context* context) { |
101 | return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule); |
102 | } |
103 | |
104 | TF_Buffer* ProtoStringToTFBuffer(PyObject* input) { |
105 | // Convert a Python string object to TF_Buffer. |
106 | char* c_string; |
107 | Py_ssize_t py_size; |
108 | // PyBytes_AsStringAndSize() does not copy but simply interprets the input |
109 | if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) { |
110 | // Python has raised an error (likely TypeError or UnicodeEncodeError). |
111 | throw py::error_already_set(); |
112 | } |
113 | return TF_NewBufferFromString(static_cast<void*>(c_string), |
114 | static_cast<size_t>(py_size)); |
115 | } |
116 | |
117 | // These functions are typemaps from the Python side. I did not use |
118 | // a custom type caster since the logic is slightly harder to follow. This |
119 | // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`. |
120 | TFE_InputTensorHandles InputTFE_InputTensorHandles( |
121 | const py::handle& input_tensors) { |
122 | TFE_InputTensorHandles input_tensor_handles; |
123 | if (input_tensors.ptr() != Py_None) { |
124 | if (!PyList_Check(input_tensors.ptr())) { |
125 | tensorflow::ThrowTypeError("must provide a list of Tensors as inputs" ); |
126 | } |
127 | Py_ssize_t len = PyList_Size(input_tensors.ptr()); |
128 | input_tensor_handles.resize(len); |
129 | for (Py_ssize_t i = 0; i < len; ++i) { |
130 | PyObject* elem = PyList_GetItem(input_tensors.ptr(), i); |
131 | if (!elem) { |
132 | tensorflow::ThrowTypeError("Input Tensor does not exist." ); |
133 | } |
134 | if (EagerTensor_CheckExact(elem)) { |
135 | (input_tensor_handles)[i] = EagerTensor_Handle(elem); |
136 | } else if (tensorflow::swig::IsEagerTensorSlow(elem)) { |
137 | // Use equivalent of object.__getattribute__ to get the underlying |
138 | // tf wrapped EagerTensor (if there is one). |
139 | tensorflow::Safe_PyObjectPtr tf_should_use_attr( |
140 | #if PY_MAJOR_VERSION < 3 |
141 | PyString_InternFromString("_tf_should_use_wrapped_value" ) |
142 | #else |
143 | PyUnicode_InternFromString("_tf_should_use_wrapped_value" ) |
144 | #endif |
145 | ); |
146 | tensorflow::Safe_PyObjectPtr value_attr( |
147 | PyObject_GenericGetAttr(elem, tf_should_use_attr.get())); |
148 | if (value_attr) { |
149 | // This is an EagerTensor wrapped inside a TFShouldUse wrapped object. |
150 | (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get()); |
151 | } else { |
152 | // This is a subclass of EagerTensor that we don't support. |
153 | PyErr_Clear(); |
154 | tensorflow::ThrowTypeError( |
155 | tensorflow::strings::StrCat( |
156 | "Saw an object that is an instance of a strict subclass of " |
157 | "EagerTensor, which is not supported. Item " , |
158 | i, " is type: " , elem->ob_type->tp_name) |
159 | .c_str()); |
160 | } |
161 | } else if (tensorflow::swig::IsTensor(elem)) { |
162 | // If it isnt an EagerTensor, but is still a Tensor, it must be a graph |
163 | // tensor. |
164 | tensorflow::Safe_PyObjectPtr py_tensor_repr(PyObject_Repr(elem)); |
165 | std::string tensor_repr = |
166 | py_tensor_repr ? TFE_GetPythonString(py_tensor_repr.get()) |
167 | : "<unknown>" ; |
168 | tensorflow::Safe_PyObjectPtr py_op(PyObject_GetAttrString(elem, "op" )); |
169 | tensorflow::Safe_PyObjectPtr py_defined_graph( |
170 | PyObject_GetAttrString(py_op.get(), "graph" )); |
171 | tensorflow::Safe_PyObjectPtr py_defined_graph_str( |
172 | PyObject_Str(py_defined_graph.get())); |
173 | std::string defined_graph_str = |
174 | py_defined_graph_str |
175 | ? TFE_GetPythonString(py_defined_graph_str.get()) |
176 | : "<unknown>" ; |
177 | tensorflow::Safe_PyObjectPtr c_op( |
178 | PyObject_GetAttrString(py_op.get(), "_c_op" )); |
179 | auto& node = py::cast<TF_Operation*>(c_op.get())->node; |
180 | auto node_name_str = node.name(); |
181 | std::string frame_str, traceback_str; |
182 | if (auto stack_trace = node.GetStackTrace()) { |
183 | auto frame = stack_trace->LastUserFrame(); |
184 | frame_str = |
185 | absl::StrFormat("File \"%s\", line %d, in %s" , frame.file_name, |
186 | frame.line_number, frame.function_name); |
187 | auto stack_trace_list = |
188 | absl::StrSplit(stack_trace->ToString({true}), '\n'); |
189 | traceback_str = absl::StrJoin( |
190 | stack_trace_list, "" , [&](std::string* out, const auto line) { |
191 | absl::StrAppend(out, " " , line, "\n" ); |
192 | }); |
193 | } else { |
194 | frame_str = "<unknown>" ; |
195 | traceback_str = "<unknown>\n" ; |
196 | } |
197 | // Keep in sync with func_graph.py. |
198 | // TODO(b/200991648): Unify those two paths. |
199 | tensorflow::ThrowTypeError( |
200 | tensorflow::strings::StrCat( |
201 | tensor_repr, |
202 | " is out of scope and cannot be used here. " |
203 | "Use return values, explicit Python locals or TensorFlow " |
204 | "collections to access it.\n" |
205 | "Please see https://www.tensorflow.org/guide/" |
206 | "function#all_outputs_of_a_tffunction_must_be_return_values " |
207 | "for more information.\n\n" , |
208 | tensor_repr, " was defined here:\n" , traceback_str, |
209 | "\nThe tensor " , tensor_repr, |
210 | " cannot be accessed from here, because it was " |
211 | "defined in " , |
212 | defined_graph_str, ", which is out of scope." ) |
213 | .c_str()); |
214 | } else { |
215 | tensorflow::ThrowTypeError( |
216 | tensorflow::strings::StrCat( |
217 | "provided list of inputs contains objects other " |
218 | "than 'EagerTensor'. Item " , |
219 | i, " is type: " , elem->ob_type->tp_name) |
220 | .c_str()); |
221 | } |
222 | } |
223 | } |
224 | return input_tensor_handles; |
225 | } |
226 | |
227 | // These functions are typemaps from the Python side. I did not use |
228 | // a custom type caster since the logic is slightly harder to follow. This |
229 | // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`. |
230 | // This function actually takes a number rather than an output Tensor holder. |
231 | TFE_OutputTensorHandles InputTFE_OutputTensorHandles( |
232 | const py::handle& num_outputs) { |
233 | TFE_OutputTensorHandles output_tensor_handles; |
234 | #if PY_MAJOR_VERSION < 3 |
235 | if (!PyInt_Check(num_outputs.ptr())) { |
236 | #else |
237 | if (!PyLong_Check(num_outputs.ptr())) { |
238 | #endif |
239 | PyErr_SetString(PyExc_TypeError, |
240 | "expected an integer value (size of the number of " |
241 | "outputs of the operation)" ); |
242 | throw py::error_already_set(); |
243 | } |
244 | #if PY_MAJOR_VERSION < 3 |
245 | long sz = PyInt_AsLong(num_outputs.ptr()); // NOLINT |
246 | #else |
247 | long sz = PyLong_AsLong(num_outputs.ptr()); // NOLINT |
248 | #endif |
249 | // PyLong_AsLong might throw an error if an overflow occurs. |
250 | if (PyErr_Occurred()) { |
251 | PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat( |
252 | "Number of outputs is too big: " , sz) |
253 | .c_str()); |
254 | throw py::error_already_set(); |
255 | } |
256 | // We can't handle more than int32 sizes for number of outputs. |
257 | if (static_cast<long>(static_cast<int32_t>(sz)) != sz) { // NOLINT |
258 | PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat( |
259 | "Number of outputs is too big: " , sz) |
260 | .c_str()); |
261 | throw py::error_already_set(); |
262 | } |
263 | if (sz > 0) { |
264 | #if PY_MAJOR_VERSION < 3 |
265 | output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr); |
266 | #else |
267 | output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr); |
268 | #endif |
269 | } |
270 | return output_tensor_handles; |
271 | } |
272 | |
273 | tensorflow::Device* GetMatchedDevice(py::handle& ctx, const char* device_name) { |
274 | auto* context = reinterpret_cast<tensorflow::ImmediateExecutionContext*>( |
275 | tensorflow::InputTFE_Context(ctx)); |
276 | |
277 | tensorflow::DeviceNameUtils::ParsedName input_device_name; |
278 | if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_name, |
279 | &input_device_name)) { |
280 | tensorflow::ThrowValueError( |
281 | absl::StrFormat("Failed parsing device name: '%s'. Note a valid device " |
282 | "string should at least contain a device type and a " |
283 | "device index, like \"GPU:0\"." , |
284 | device_name) |
285 | .c_str()); |
286 | } |
287 | |
288 | std::vector<tensorflow::Device*> devices = context->ListLocalTfDevices(); |
289 | |
290 | tensorflow::Device* matched_device = nullptr; |
291 | for (int device_idx = 0; device_idx < devices.size(); device_idx++) { |
292 | tensorflow::Device* device = devices[device_idx]; |
293 | |
294 | if (tensorflow::DeviceNameUtils::AreCompatibleDevNames( |
295 | input_device_name, device->parsed_name())) { |
296 | if (matched_device != nullptr) { |
297 | tensorflow::ThrowValueError( |
298 | absl::StrFormat("Multiple devices match the provided string " |
299 | "'%s': '%s' and '%s'." , |
300 | device_name, matched_device->name(), device->name()) |
301 | .c_str()); |
302 | } |
303 | matched_device = device; |
304 | } |
305 | } |
306 | |
307 | if (matched_device == nullptr) { |
308 | tensorflow::ThrowValueError( |
309 | absl::StrFormat("No matching devices found for '%s'" , device_name) |
310 | .c_str()); |
311 | } |
312 | |
313 | return matched_device; |
314 | } |
315 | |
316 | // Packs multiple `EagerTensor`s of the same dtype and shape into one |
317 | // `EagerTensor`. |
318 | py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context, |
319 | const py::handle& tensors) { |
320 | TFE_Context* ctx = tensorflow::InputTFE_Context(context); |
321 | TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors); |
322 | tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); |
323 | int size = handles.size(); |
324 | TFE_TensorHandle* packed_handle = |
325 | TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get()); |
326 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
327 | PyObject* packed_tensor = |
328 | EagerTensorFromHandle(packed_handle, /*is_packed=*/true); |
329 | return tensorflow::PyoOrThrow(packed_tensor); |
330 | } |
331 | |
332 | // This function was created from fusing the typemap logic in platform/base.i. |
333 | py::object TFE_Py_ExecuteCancelable_wrapper( |
334 | const py::handle& context, const char* device_name, const char* op_name, |
335 | const py::handle& inputs, const py::handle& attrs, |
336 | tensorflow::CancellationManager* cancellation_manager, |
337 | const py::handle& num_outputs) { |
338 | TFE_Context* ctx = tensorflow::InputTFE_Context(context); |
339 | TFE_InputTensorHandles input_tensor_handles = |
340 | InputTFE_InputTensorHandles(inputs); |
341 | TFE_OutputTensorHandles output_tensor_handles = |
342 | InputTFE_OutputTensorHandles(num_outputs); |
343 | tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); |
344 | TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles, |
345 | attrs.ptr(), tensorflow::wrap(cancellation_manager), |
346 | &output_tensor_handles, status.get()); |
347 | |
348 | int output_len = output_tensor_handles.size(); |
349 | PyObject* output_list = PyList_New(output_len); |
350 | for (int i = 0; i < output_len; ++i) { |
351 | PyObject* output; |
352 | output = EagerTensorFromHandle(output_tensor_handles.at(i)); |
353 | PyList_SetItem(output_list, i, output); |
354 | } |
355 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
356 | return tensorflow::PyoOrThrow(output_list); |
357 | } |
358 | |
359 | static py::object TF_ListPhysicalDevices() { |
360 | std::vector<string> devices; |
361 | tensorflow::Status s = |
362 | tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices); |
363 | MaybeRaiseRegisteredFromStatus(s); |
364 | PyObject* result = PyList_New(devices.size()); |
365 | int i = 0; |
366 | for (auto& dev : devices) { |
367 | PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size()); |
368 | PyList_SetItem(result, i, dev_obj); |
369 | ++i; |
370 | } |
371 | return tensorflow::PyoOrThrow(result); |
372 | } |
373 | |
374 | static py::object TF_ListPluggablePhysicalDevices() { |
375 | std::vector<string> devices; |
376 | tensorflow::Status s = |
377 | tensorflow::DeviceFactory::ListPluggablePhysicalDevices(&devices); |
378 | MaybeRaiseRegisteredFromStatus(s); |
379 | Safe_PyObjectPtr result(PyList_New(devices.size())); |
380 | int i = 0; |
381 | for (auto& dev : devices) { |
382 | PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size()); |
383 | PyList_SetItem(result.get(), i, dev_obj); |
384 | ++i; |
385 | } |
386 | return tensorflow::PyoOrThrow(result.release()); |
387 | } |
388 | |
389 | static std::unordered_map<string, string> TF_GetDeviceDetails(int index) { |
390 | tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); |
391 | std::unordered_map<string, string> device_details; |
392 | tensorflow::Status s = |
393 | tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details); |
394 | tensorflow::Set_TF_Status_from_Status(status.get(), s); |
395 | MaybeRaiseRegisteredFromTFStatus(status.get()); |
396 | return device_details; |
397 | } |
398 | |
399 | static py::object TFE_ClearScalarCache() { |
400 | tensorflow::TFE_TensorHandleCache::Get()->Clear(); |
401 | return py::none(); |
402 | } |
403 | |
404 | // Returns compiler IR for a given function. |
405 | static py::bytes TFE_GetCompilerIr(py::handle& ctx, |
406 | const char* concrete_function_name, |
407 | const char* stage, const char* device_name, |
408 | py::handle& inputs) { |
409 | EagerContext* context = ContextFromInterface( |
410 | reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx))); |
411 | |
412 | std::string s_stage(stage); |
413 | IrExportStage selected_stage = [&] { |
414 | if (s_stage == "hlo" ) { |
415 | return IrExportStage::HLO; |
416 | } else if (s_stage == "hlo_no_metadata" ) { |
417 | return IrExportStage::HLO_NO_METADATA; |
418 | } else if (s_stage == "hlo_serialized" ) { |
419 | return IrExportStage::HLO_SERIALIZED; |
420 | } else if (s_stage == "optimized_hlo" ) { |
421 | return IrExportStage::OPTIMIZED_HLO; |
422 | } else if (s_stage == "optimized_hlo_serialized" ) { |
423 | return IrExportStage::OPTIMIZED_HLO_SERIALIZED; |
424 | } else if (s_stage == "optimized_hlo_proto_serialized" ) { |
425 | return IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED; |
426 | } else if (s_stage == "optimized_hlo_dot" ) { |
427 | return IrExportStage::OPTIMIZED_HLO_DOT; |
428 | } else { |
429 | ThrowValueError( |
430 | absl::StrFormat("Invalid stage selected: '%s'. Valid values are: " |
431 | "'hlo', 'hlo_serialized', 'optimized_hlo', " |
432 | "'optimized_hlo_serialized', 'optimized_hlo_dot'" , |
433 | s_stage) |
434 | .c_str()); |
435 | } |
436 | }(); |
437 | |
438 | TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs); |
439 | |
440 | std::vector<const TensorHandle*> input_handles; |
441 | for (TFE_TensorHandle* tensor_handle : handles) { |
442 | AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle); |
443 | input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle)); |
444 | } |
445 | |
446 | DeviceNameUtils::ParsedName input_device_name; |
447 | if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) { |
448 | ThrowValueError( |
449 | absl::StrFormat("Failed parsing device name: '%s'" , device_name) |
450 | .c_str()); |
451 | } |
452 | |
453 | std::vector<Device*> devices = context->local_device_mgr()->ListDevices(); |
454 | auto selected_device = absl::c_find_if(devices, [&](const Device* d) { |
455 | return DeviceNameUtils::AreCompatibleDevNames(input_device_name, |
456 | d->parsed_name()); |
457 | }); |
458 | if (selected_device == devices.end()) { |
459 | ThrowValueError( |
460 | absl::StrFormat("No matching device found for '%s'" , device_name) |
461 | .c_str()); |
462 | } |
463 | |
464 | StatusOr<std::string> hlo_str = |
465 | GetCompilerIr(selected_stage, context->pflr(), concrete_function_name, |
466 | *selected_device, context, input_handles); |
467 | |
468 | if (!hlo_str.ok()) { |
469 | ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'" , |
470 | hlo_str.status().error_message()) |
471 | .c_str()); |
472 | } |
473 | return py::bytes(*hlo_str); |
474 | } |
475 | |
476 | } // namespace tensorflow |
477 | |
478 | namespace { |
479 | |
480 | // Wrapper around the EagerContextThreadLocalData struct (defined in |
481 | // pywrap_tfe.h), so it can be accessed from Python. |
482 | // |
483 | // For PyObject* fields, the get_*() methods return a new reference; and the |
484 | // set_*() methods create a new reference (i.e., they do not steal a reference). |
485 | class EagerContextThreadLocalDataWrapper { |
486 | public: |
487 | explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context, |
488 | py::handle is_eager, |
489 | py::handle device_spec) |
490 | : py_eager_context_(py_eager_context.ptr()) { |
491 | tensorflow::MakeEagerContextThreadLocalData( |
492 | py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr()); |
493 | } |
494 | |
495 | ~EagerContextThreadLocalDataWrapper() { |
496 | tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_); |
497 | } |
498 | |
499 | bool get_is_eager() const { return GetData()->is_eager; } |
500 | void set_is_eager(bool v) { GetData()->is_eager = v; } |
501 | |
502 | bool get_invoking_op_callbacks() const { |
503 | return GetData()->invoking_op_callbacks; |
504 | } |
505 | void set_invoking_op_callbacks(bool v) { |
506 | GetData()->invoking_op_callbacks = v; |
507 | } |
508 | |
509 | py::object get_device_name() const { |
510 | return GetPyObject(&GetData()->device_name); |
511 | } |
512 | void set_device_name(py::handle v) { |
513 | SetPyObject(v, &GetData()->device_name); |
514 | } |
515 | |
516 | py::object get_scope_name() const { |
517 | return GetPyObject(&GetData()->scope_name); |
518 | } |
519 | void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); } |
520 | |
521 | py::object get_device_spec() const { |
522 | return GetPyObject(&GetData()->device_spec); |
523 | } |
524 | void set_device_spec(py::handle v) { |
525 | SetPyObject(v, &GetData()->device_spec); |
526 | } |
527 | |
528 | py::object get_function_call_options() const { |
529 | return GetPyObject(&GetData()->function_call_options); |
530 | } |
531 | void set_function_call_options(py::handle v) { |
532 | SetPyObject(v, &GetData()->function_call_options); |
533 | } |
534 | |
535 | py::handle get_executor() const { return GetPyObject(&GetData()->executor); } |
536 | void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); } |
537 | |
538 | py::object get_op_callbacks() const { |
539 | return GetPyObject(&GetData()->op_callbacks); |
540 | } |
541 | void set_op_callbacks(py::handle v) { |
542 | SetPyObject(v, &GetData()->op_callbacks); |
543 | } |
544 | |
545 | private: |
546 | tensorflow::EagerContextThreadLocalData* GetData() const { |
547 | auto* result = |
548 | tensorflow::GetEagerContextThreadLocalData(py_eager_context_); |
549 | if (!result) { |
550 | throw py::error_already_set(); |
551 | } |
552 | return result; |
553 | } |
554 | |
555 | py::object GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const { |
556 | return pybind11::reinterpret_borrow<py::object>(obj->get()); |
557 | } |
558 | |
559 | void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) { |
560 | Py_INCREF(value.ptr()); |
561 | ptr->reset(value.ptr()); |
562 | } |
563 | |
564 | PyObject* py_eager_context_; // not owned (borrowed reference). |
565 | }; |
566 | |
567 | } // namespace |
568 | |
569 | // py::return_value_policy::reference is defined as specified by the |
570 | // pybind11 documents listed here. |
571 | // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies |
572 | // This means that C++ maintains ownership of the object. We |
573 | // are only assigning this to functions that return opaque types. |
574 | |
575 | PYBIND11_MODULE(_pywrap_tfe, m) { |
576 | py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor" ); |
577 | py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m, |
578 | "TFE_ContextOptions" ); |
579 | py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class( |
580 | m, "TFE_MonitoringCounter0" ); |
581 | py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class( |
582 | m, "TFE_MonitoringCounter1" ); |
583 | py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class( |
584 | m, "TFE_MonitoringCounter2" ); |
585 | py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class( |
586 | m, "TFE_MonitoringStringGauge0" ); |
587 | py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class( |
588 | m, "TFE_MonitoringStringGauge1" ); |
589 | py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class( |
590 | m, "TFE_MonitoringStringGauge2" ); |
591 | py::class_<TFE_MonitoringStringGauge3> TFE_MonitoringStringGauge3_class( |
592 | m, "TFE_MonitoringStringGauge3" ); |
593 | py::class_<TFE_MonitoringStringGauge4> TFE_MonitoringStringGauge4_class( |
594 | m, "TFE_MonitoringStringGauge4" ); |
595 | py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class( |
596 | m, "TFE_MonitoringIntGauge0" ); |
597 | py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class( |
598 | m, "TFE_MonitoringIntGauge1" ); |
599 | py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class( |
600 | m, "TFE_MonitoringIntGauge2" ); |
601 | py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class( |
602 | m, "TFE_MonitoringBoolGauge0" ); |
603 | py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class( |
604 | m, "TFE_MonitoringBoolGauge1" ); |
605 | py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class( |
606 | m, "TFE_MonitoringBoolGauge2" ); |
607 | py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class( |
608 | m, "TFE_MonitoringCounterCell" ); |
609 | py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class( |
610 | m, "TFE_MonitoringIntGaugeCell" ); |
611 | py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class( |
612 | m, "TFE_MonitoringStringGaugeCell" ); |
613 | py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class( |
614 | m, "TFE_MonitoringBoolGaugeCell" ); |
615 | py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class( |
616 | m, "TFE_MonitoringSamplerCell" ); |
617 | py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class( |
618 | m, "TFE_MonitoringBuckets" ); |
619 | py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class( |
620 | m, "TFE_MonitoringSampler0" ); |
621 | py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class( |
622 | m, "TFE_MonitoringSampler1" ); |
623 | py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class( |
624 | m, "TFE_MonitoringSampler2" ); |
625 | py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class( |
626 | m, "TFE_CancellationManager" ); |
627 | |
628 | py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList" ); |
629 | py::class_<TF_Function> TF_Function_class(m, "TF_Function" ); |
630 | |
631 | m.def("TFE_Py_RegisterExceptionClass" , [](const py::handle& e) { |
632 | return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr())); |
633 | }); |
634 | m.def("TFE_Py_RegisterFallbackExceptionClass" , [](const py::handle& e) { |
635 | return tensorflow::PyoOrThrow( |
636 | TFE_Py_RegisterFallbackExceptionClass(e.ptr())); |
637 | }); |
638 | |
639 | m.def("TFE_GetMemoryInfo" , [](py::handle& ctx, const char* device_name) { |
640 | tensorflow::Device* matched_device = |
641 | tensorflow::GetMatchedDevice(ctx, device_name); |
642 | |
643 | tensorflow::AllocatorAttributes attrs; |
644 | tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs); |
645 | |
646 | if (absl::optional<tensorflow::AllocatorStats> stats = |
647 | allocator->GetStats()) { |
648 | return std::map<std::string, int64_t>{{"current" , stats->bytes_in_use}, |
649 | {"peak" , stats->peak_bytes_in_use}}; |
650 | } |
651 | |
652 | tensorflow::ThrowValueError( |
653 | absl::StrFormat("Allocator stats not available for device '%s'" , |
654 | device_name) |
655 | .c_str()); |
656 | }); |
657 | |
658 | m.def("TFE_ResetMemoryStats" , [](py::handle& ctx, const char* device_name) { |
659 | tensorflow::Device* matched_device = |
660 | tensorflow::GetMatchedDevice(ctx, device_name); |
661 | |
662 | tensorflow::AllocatorAttributes attrs; |
663 | tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs); |
664 | |
665 | if (!allocator->ClearStats()) { |
666 | tensorflow::ThrowValueError( |
667 | absl::StrFormat("Cannot reset memory stats for device '%s'" , |
668 | device_name) |
669 | .c_str()); |
670 | } |
671 | }); |
672 | |
673 | // XLA Eager Logic |
674 | m.def("TF_SetXlaEnableLazyCompilation" , &TF_SetXlaEnableLazyCompilation); |
675 | m.def("TF_SetTfXlaCpuGlobalJit" , &TF_SetTfXlaCpuGlobalJit); |
676 | m.def("TF_SetXlaAutoJitMode" , &TF_SetXlaAutoJitMode); |
677 | m.def("TF_SetXlaConstantFoldingDisabled" , &TF_SetXlaConstantFoldingDisabled); |
678 | m.def("TF_GetXlaConstantFoldingDisabled" , &TF_GetXlaConstantFoldingDisabled); |
679 | m.def("TF_SetXlaMinClusterSize" , &TF_SetXlaMinClusterSize); |
680 | m.def("TF_GetCompilerIr" , &tensorflow::TFE_GetCompilerIr); |
681 | |
682 | // MLIR Logic |
683 | m.def("TF_IsMlirBridgeEnabled" , [] { |
684 | // Since python protobuf enums are integers, cast to an integer before |
685 | // returning the enum to python. |
686 | return static_cast<int32_t>( |
687 | tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge); |
688 | }); |
689 | m.def("TF_EnableMlirBridge" , [](bool enabled) { |
690 | tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = |
691 | enabled |
692 | ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED |
693 | : tensorflow::ConfigProto::Experimental:: |
694 | MLIR_BRIDGE_ROLLOUT_DISABLED; |
695 | }); |
696 | m.def("TF_EnableXlaDevices" , [] { |
697 | tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; |
698 | }); |
699 | m.def("TF_ResetJitCompilerFlags" , |
700 | [] { tensorflow::ResetJitCompilerFlags(); }); |
701 | |
702 | // TFE_Context Logic |
703 | m.def( |
704 | "TFE_NewContext" , |
705 | [](const TFE_ContextOptions* opts) { |
706 | tensorflow::Safe_TF_StatusPtr status = |
707 | tensorflow::make_safe(TF_NewStatus()); |
708 | TFE_Context* context = TFE_NewContext(opts, status.get()); |
709 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
710 | return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context)); |
711 | }, |
712 | py::return_value_policy::reference); |
713 | m.def("TFE_DeleteContext" , [](py::handle& o) { |
714 | TFE_DeleteContext(tensorflow::InputTFE_Context(o)); |
715 | }); |
716 | m.def( |
717 | "TFE_ContextListDevices" , |
718 | [](py::handle& o) { |
719 | tensorflow::Safe_TF_StatusPtr status = |
720 | tensorflow::make_safe(TF_NewStatus()); |
721 | auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o), |
722 | status.get()); |
723 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
724 | return output; |
725 | }, |
726 | py::return_value_policy::reference); |
727 | m.def( |
728 | "TFE_SetLogicalCpuDevices" , |
729 | [](py::handle& ctx, int num_cpus, const char* prefix) { |
730 | tensorflow::Safe_TF_StatusPtr status = |
731 | tensorflow::make_safe(TF_NewStatus()); |
732 | TFE_SetLogicalCpuDevices(tensorflow::InputTFE_Context(ctx), num_cpus, |
733 | prefix, status.get()); |
734 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
735 | }, |
736 | py::return_value_policy::reference); |
737 | m.def("TFE_HostAddressSpace" , [](py::handle& o, TF_Buffer& buf) { |
738 | TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf); |
739 | }); |
740 | m.def("TFE_ContextAddFunction" , [](py::handle& ctx, TF_Function* func) { |
741 | tensorflow::Safe_TF_StatusPtr status = |
742 | tensorflow::make_safe(TF_NewStatus()); |
743 | TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func, |
744 | status.get()); |
745 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
746 | }); |
747 | m.def("TFE_ContextAddFunctionDef" , |
748 | [](py::handle& ctx, const char* serialized_function_def, size_t size) { |
749 | tensorflow::Safe_TF_StatusPtr status = |
750 | tensorflow::make_safe(TF_NewStatus()); |
751 | TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx), |
752 | serialized_function_def, size, |
753 | status.get()); |
754 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
755 | }); |
756 | m.def("TFE_ContextGetFunctionDef" , |
757 | [](py::handle& ctx, const char* function_name, TF_Buffer& buf) { |
758 | tensorflow::Safe_TF_StatusPtr status = |
759 | tensorflow::make_safe(TF_NewStatus()); |
760 | TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx), |
761 | function_name, &buf, status.get()); |
762 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
763 | }); |
764 | m.def("TFE_ContextRemoveFunction" , [](py::handle& ctx, const char* name) { |
765 | tensorflow::Safe_TF_StatusPtr status = |
766 | tensorflow::make_safe(TF_NewStatus()); |
767 | TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name, |
768 | status.get()); |
769 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
770 | }); |
771 | m.def("TFE_ContextHasFunction" , [](py::handle& ctx, const char* name) { |
772 | tensorflow::Safe_TF_StatusPtr status = |
773 | tensorflow::make_safe(TF_NewStatus()); |
774 | auto output = |
775 | TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name); |
776 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
777 | return output; |
778 | }); |
779 | m.def("TFE_ContextListFunctionNames" , [](py::handle& ctx) { |
780 | return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx)) |
781 | ->ListFunctionNames(); |
782 | }); |
783 | m.def("TFE_ContextEnableRunMetadata" , [](py::handle& ctx) { |
784 | TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx)); |
785 | }); |
786 | m.def("TFE_ContextDisableRunMetadata" , [](py::handle& ctx) { |
787 | TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx)); |
788 | }); |
789 | m.def("TFE_ContextEnableGraphCollection" , [](py::handle& ctx) { |
790 | TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx)); |
791 | }); |
792 | m.def("TFE_ContextDisableGraphCollection" , [](py::handle& ctx) { |
793 | TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx)); |
794 | }); |
795 | m.def("TFE_ContextExportRunMetadata" , [](py::handle& ctx, TF_Buffer& buf) { |
796 | tensorflow::Safe_TF_StatusPtr status = |
797 | tensorflow::make_safe(TF_NewStatus()); |
798 | TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf, |
799 | status.get()); |
800 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
801 | }); |
802 | m.def("TFE_ContextClearCaches" , [](py::handle& o) { |
803 | TFE_ContextClearCaches(tensorflow::InputTFE_Context(o)); |
804 | }); |
805 | m.def("TFE_GetContextId" , [](py::handle& ctx) { |
806 | return TFE_GetContextId(tensorflow::InputTFE_Context(ctx)); |
807 | }); |
808 | m.def("TFE_ContextGetDevicePlacementPolicy" , [](py::handle& ctx) { |
809 | return TFE_ContextGetDevicePlacementPolicy( |
810 | tensorflow::InputTFE_Context(ctx)); |
811 | }); |
812 | m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy" , |
813 | [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) { |
814 | TFE_ContextSetThreadLocalDevicePlacementPolicy( |
815 | tensorflow::InputTFE_Context(ctx), policy); |
816 | }); |
817 | m.def("TFE_ContextSetServerDef" , [](py::handle& ctx, int keep_alive_secs, |
818 | py::bytes proto) { |
819 | tensorflow::Safe_TF_StatusPtr status = |
820 | tensorflow::make_safe(TF_NewStatus()); |
821 | tensorflow::Safe_TF_BufferPtr buf = |
822 | tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); |
823 | TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs, |
824 | buf.get()->data, buf.get()->length, status.get()); |
825 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
826 | }); |
827 | m.def("TFE_ContextUpdateServerDef" , [](py::handle& ctx, int keep_alive_secs, |
828 | py::bytes proto) { |
829 | tensorflow::Safe_TF_StatusPtr status = |
830 | tensorflow::make_safe(TF_NewStatus()); |
831 | tensorflow::Safe_TF_BufferPtr buf = |
832 | tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); |
833 | Py_BEGIN_ALLOW_THREADS; |
834 | TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx), |
835 | keep_alive_secs, buf.get()->data, |
836 | buf.get()->length, status.get()); |
837 | Py_END_ALLOW_THREADS; |
838 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
839 | }); |
840 | m.def("TFE_ContextCheckAlive" , [](py::handle& ctx, const char* worker_name) { |
841 | tensorflow::Safe_TF_StatusPtr status = |
842 | tensorflow::make_safe(TF_NewStatus()); |
843 | bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx), |
844 | worker_name, status.get()); |
845 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
846 | return output; |
847 | }); |
848 | m.def("TFE_ContextSyncExecutors" , [](py::handle& ctx) { |
849 | tensorflow::Safe_TF_StatusPtr status = |
850 | tensorflow::make_safe(TF_NewStatus()); |
851 | // NOTE: release Python GIL for pending PyFunc ops to be executed properly. |
852 | Py_BEGIN_ALLOW_THREADS; |
853 | TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get()); |
854 | Py_END_ALLOW_THREADS; |
855 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
856 | }); |
857 | m.def("TFE_ContextClearExecutors" , [](py::handle& ctx) { |
858 | tensorflow::Safe_TF_StatusPtr status = |
859 | tensorflow::make_safe(TF_NewStatus()); |
860 | // NOTE: release Python GIL for pending PyFunc ops to be executed properly. |
861 | Py_BEGIN_ALLOW_THREADS; |
862 | TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get()); |
863 | Py_END_ALLOW_THREADS; |
864 | // NOTE: different from TFE_ContextSyncExecutors that raises potential |
865 | // errors, deliberately ignore executor statuses in cleanup. |
866 | }); |
867 | m.def( |
868 | "TFE_InsertConfigKeyValue" , |
869 | [](py::handle& ctx, const char* config_key, const char* config_value) { |
870 | tensorflow::Safe_TF_StatusPtr status = |
871 | tensorflow::make_safe(TF_NewStatus()); |
872 | Py_BEGIN_ALLOW_THREADS; |
873 | TFE_InsertConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key, |
874 | config_value, status.get()); |
875 | Py_END_ALLOW_THREADS; |
876 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
877 | }, |
878 | py::return_value_policy::reference); |
879 | m.def( |
880 | "TFE_GetConfigKeyValue" , |
881 | [](py::handle& ctx, const char* config_key, TF_Buffer& config_value) { |
882 | tensorflow::Safe_TF_StatusPtr status = |
883 | tensorflow::make_safe(TF_NewStatus()); |
884 | Py_BEGIN_ALLOW_THREADS; |
885 | TFE_GetConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key, |
886 | &config_value, status.get()); |
887 | Py_END_ALLOW_THREADS; |
888 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
889 | }, |
890 | py::return_value_policy::reference); |
891 | m.def( |
892 | "TFE_DeleteConfigKeyValue" , |
893 | [](py::handle& ctx, const char* config_key) { |
894 | tensorflow::Safe_TF_StatusPtr status = |
895 | tensorflow::make_safe(TF_NewStatus()); |
896 | Py_BEGIN_ALLOW_THREADS; |
897 | TFE_DeleteConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key, |
898 | status.get()); |
899 | Py_END_ALLOW_THREADS; |
900 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
901 | }, |
902 | py::return_value_policy::reference); |
903 | m.def( |
904 | "TFE_ReportErrorToCluster" , |
905 | [](py::handle& ctx, int error_code, const char* error_message) { |
906 | tensorflow::Safe_TF_StatusPtr status = |
907 | tensorflow::make_safe(TF_NewStatus()); |
908 | TFE_ReportErrorToCluster(tensorflow::InputTFE_Context(ctx), error_code, |
909 | error_message, status.get()); |
910 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
911 | }, |
912 | py::return_value_policy::reference); |
913 | m.def("TFE_ContextSetSoftDevicePlacement" , [](py::handle& ctx, bool enable) { |
914 | tensorflow::Safe_TF_StatusPtr status = |
915 | tensorflow::make_safe(TF_NewStatus()); |
916 | TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable, |
917 | status.get()); |
918 | }); |
919 | m.def("TFE_ContextSetLogDevicePlacement" , [](py::handle& ctx, bool enable) { |
920 | tensorflow::Safe_TF_StatusPtr status = |
921 | tensorflow::make_safe(TF_NewStatus()); |
922 | TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable, |
923 | status.get()); |
924 | }); |
925 | m.def("TFE_ContextSetRunEagerOpAsFunction" , [](py::handle& ctx, bool enable) { |
926 | tensorflow::Safe_TF_StatusPtr status = |
927 | tensorflow::make_safe(TF_NewStatus()); |
928 | TFE_ContextSetRunEagerOpAsFunction(tensorflow::InputTFE_Context(ctx), |
929 | enable, status.get()); |
930 | }); |
931 | m.def("TFE_ContextSetJitCompileRewrite" , [](py::handle& ctx, bool enable) { |
932 | tensorflow::Safe_TF_StatusPtr status = |
933 | tensorflow::make_safe(TF_NewStatus()); |
934 | TFE_ContextSetJitCompileRewrite(tensorflow::InputTFE_Context(ctx), enable, |
935 | status.get()); |
936 | }); |
937 | |
938 | // TFE_Executor logic |
939 | m.def( |
940 | "TFE_NewExecutor" , |
941 | [](const bool is_async, const bool enable_streaming_enqueue) { |
942 | TFE_Executor* exc = TFE_NewExecutor(is_async, enable_streaming_enqueue); |
943 | return exc; |
944 | }, |
945 | py::return_value_policy::reference); |
946 | m.def("TFE_DeleteExecutor" , &TFE_DeleteExecutor); |
947 | m.def("TFE_ExecutorIsAsync" , &TFE_ExecutorIsAsync); |
948 | m.def("TFE_ExecutorWaitForAllPendingNodes" , [](TFE_Executor& exc) { |
949 | tensorflow::Safe_TF_StatusPtr status = |
950 | tensorflow::make_safe(TF_NewStatus()); |
951 | // NOTE: release Python GIL for pending PyFunc ops to be executed properly. |
952 | Py_BEGIN_ALLOW_THREADS; |
953 | TFE_ExecutorWaitForAllPendingNodes(&exc, status.get()); |
954 | Py_END_ALLOW_THREADS; |
955 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
956 | }); |
957 | m.def("TFE_ExecutorClearError" , &TFE_ExecutorClearError); |
958 | m.def("TFE_ContextSetExecutorForThread" , [](py::handle& ctx, |
959 | TFE_Executor& exc) { |
960 | TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc); |
961 | }); |
962 | m.def( |
963 | "TFE_ContextGetExecutorForThread" , |
964 | [](py::handle& o) { |
965 | return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o)); |
966 | }, |
967 | py::return_value_policy::reference); |
968 | |
969 | m.def("TFE_OpNameGetAttrType" , |
970 | [](py::handle& ctx, const char* op_or_function_name, |
971 | const char* attr_name) { |
972 | int temp = 0; |
973 | unsigned char* is_list = reinterpret_cast<unsigned char*>(&temp); |
974 | tensorflow::Safe_TF_StatusPtr status = |
975 | tensorflow::make_safe(TF_NewStatus()); |
976 | auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx), |
977 | op_or_function_name, attr_name, |
978 | is_list, status.get()); |
979 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
980 | #if PY_MAJOR_VERSION < 3 |
981 | PyObject* output_pyo = PyInt_FromLong(output); |
982 | #else |
983 | PyObject* output_pyo = PyLong_FromLong(output); |
984 | #endif |
985 | if (*is_list == 1) { |
986 | PyObject* list = PyList_New(1); |
987 | PyList_SetItem(list, 0, output_pyo); |
988 | return tensorflow::PyoOrThrow(list); |
989 | } |
990 | return tensorflow::PyoOrThrow(output_pyo); |
991 | }); |
992 | m.def("TFE_Py_InitEagerTensor" , [](const py::handle& o) { |
993 | return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr())); |
994 | }); |
995 | m.def("TFE_Py_PackEagerTensors" , |
996 | [](const py::handle& context, const py::handle& handles) { |
997 | return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles); |
998 | }); |
999 | m.def("TFE_Py_SetEagerTensorProfiler" , &TFE_Py_SetEagerTensorProfiler); |
1000 | m.def("TFE_Py_RegisterJVPFunction" , [](const py::handle& o) { |
1001 | return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr())); |
1002 | }); |
1003 | m.def("TFE_Py_RegisterGradientFunction" , [](const py::handle& o) { |
1004 | return tensorflow::PyoOrThrow(TFE_Py_RegisterGradientFunction(o.ptr())); |
1005 | }); |
1006 | m.def("TFE_Py_Execute" , |
1007 | [](const py::handle& context, const char* device_name, |
1008 | const char* op_name, const py::handle& inputs, |
1009 | const py::handle& attrs, const py::handle& num_outputs) { |
1010 | return tensorflow::TFE_Py_ExecuteCancelable_wrapper( |
1011 | context, device_name, op_name, inputs, attrs.ptr(), nullptr, |
1012 | num_outputs); |
1013 | }); |
1014 | m.def( |
1015 | "TFE_Py_ExecuteCancelable" , |
1016 | [](const py::handle& context, const char* device_name, |
1017 | const char* op_name, const py::handle& inputs, const py::handle& attrs, |
1018 | tensorflow::CancellationManager& cancellation_manager, |
1019 | const py::handle& num_outputs) { |
1020 | return tensorflow::TFE_Py_ExecuteCancelable_wrapper( |
1021 | context, device_name, op_name, inputs, attrs.ptr(), |
1022 | &cancellation_manager, num_outputs); |
1023 | }); |
1024 | m.def("TFE_Py_FastPathExecute" , [](const py::args args) { |
1025 | // TFE_Py_FastPathExecute requires error checking prior to returning. |
1026 | return tensorflow::PyoOrThrow(TFE_Py_FastPathExecute_C(args.ptr())); |
1027 | }); |
1028 | m.def("TFE_Py_RecordGradient" , |
1029 | [](const py::handle& op_name, const py::handle& inputs, |
1030 | const py::handle& attrs, const py::handle& results, |
1031 | const py::handle& forward_pass_name_scope) { |
1032 | return tensorflow::PyoOrThrow(TFE_Py_RecordGradient( |
1033 | op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(), |
1034 | forward_pass_name_scope.ptr())); |
1035 | }); |
1036 | m.def("TFE_Py_UID" , []() { return tensorflow::PyoOrThrow(TFE_Py_UID()); }); |
1037 | |
1038 | // TFE_Py_Tape Logic |
1039 | m.def("TFE_Py_TapeSetNew" , [](const py::handle& persistent, |
1040 | const py::handle& watch_accessed_variables) { |
1041 | return tensorflow::PyoOrThrow( |
1042 | TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr())); |
1043 | }); |
1044 | m.def("TFE_Py_TapeSetAdd" , |
1045 | [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); }); |
1046 | m.def("TFE_Py_TapeSetRemove" , |
1047 | [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); }); |
1048 | m.def("TFE_Py_TapeSetStopOnThread" , &TFE_Py_TapeSetStopOnThread); |
1049 | m.def("TFE_Py_TapeSetRestartOnThread" , &TFE_Py_TapeSetRestartOnThread); |
1050 | m.def("TFE_Py_TapeSetIsStopped" , |
1051 | []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsStopped()); }); |
1052 | m.def("TFE_Py_TapeSetIsEmpty" , |
1053 | []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsEmpty()); }); |
1054 | m.def("TFE_Py_TapeSetShouldRecordBackprop" , [](const py::handle& tensors) { |
1055 | return tensorflow::PyoOrThrow( |
1056 | TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr())); |
1057 | }); |
1058 | m.def("TFE_Py_TapeSetPossibleGradientTypes" , [](const py::handle& tensors) { |
1059 | return tensorflow::PyoOrThrow( |
1060 | TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr())); |
1061 | }); |
1062 | m.def("TFE_Py_TapeSetDeleteTrace" , &TFE_Py_TapeSetDeleteTrace); |
1063 | m.def("TFE_Py_TapeSetRecordOperation" , |
1064 | [](const py::handle& op_type, const py::handle& output_tensors, |
1065 | const py::handle& input_tensors, const py::handle& backward_function, |
1066 | const py::handle& forward_function) { |
1067 | return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperation( |
1068 | op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(), |
1069 | backward_function.ptr(), forward_function.ptr())); |
1070 | }); |
1071 | m.def( |
1072 | "TFE_Py_TapeSetRecordOperationBackprop" , |
1073 | [](const py::handle& op_type, const py::handle& output_tensors, |
1074 | const py::handle& input_tensors, const py::handle& backward_function) { |
1075 | return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationBackprop( |
1076 | op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(), |
1077 | backward_function.ptr())); |
1078 | }); |
1079 | m.def( |
1080 | "TFE_Py_TapeSetRecordOperationForwardprop" , |
1081 | [](const py::handle& op_type, const py::handle& output_tensors, |
1082 | const py::handle& input_tensors, const py::handle& backward_function, |
1083 | const py::handle& forwardprop_output_indices) { |
1084 | return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationForwardprop( |
1085 | op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(), |
1086 | backward_function.ptr(), forwardprop_output_indices.ptr())); |
1087 | }); |
1088 | m.def("TFE_Py_TapeGradient" , |
1089 | [](const py::handle& tape, const py::handle& target, |
1090 | const py::handle& sources, const py::handle& output_gradients, |
1091 | const py::handle& sources_raw, |
1092 | const py::handle& unconnected_gradients) { |
1093 | tensorflow::Safe_TF_StatusPtr status = |
1094 | tensorflow::make_safe(TF_NewStatus()); |
1095 | PyObject* output = TFE_Py_TapeGradient( |
1096 | tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(), |
1097 | sources_raw.ptr(), unconnected_gradients.ptr(), status.get()); |
1098 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1099 | return tensorflow::PyoOrThrow(output); |
1100 | }); |
1101 | |
1102 | m.def("TFE_Py_TapeVariableAccessed" , [](const py::handle& variable) { |
1103 | TFE_Py_TapeVariableAccessed(variable.ptr()); |
1104 | }); |
1105 | m.def("TFE_Py_TapeWatch" , |
1106 | [](const py::handle& tape, const py::handle& tensor) { |
1107 | TFE_Py_TapeWatch(tape.ptr(), tensor.ptr()); |
1108 | }); |
1109 | m.def("TFE_Py_TapeWatchVariable" , |
1110 | [](const py::handle& tape, const py::handle& variable) { |
1111 | TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr()); |
1112 | }); |
1113 | m.def("TFE_Py_TapeWatchedVariables" , [](const py::handle& tape) { |
1114 | return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr())); |
1115 | }); |
1116 | |
1117 | // TFE_Py_VariableWatcher logic. |
1118 | m.def("TFE_Py_VariableWatcherNew" , |
1119 | []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); }); |
1120 | m.def("TFE_Py_VariableWatcherRemove" , [](const py::handle& variable_watcher) { |
1121 | TFE_Py_VariableWatcherRemove(variable_watcher.ptr()); |
1122 | }); |
1123 | m.def("TFE_Py_VariableWatcherVariableAccessed" , |
1124 | [](const py::handle& variable) { |
1125 | TFE_Py_VariableWatcherVariableAccessed(variable.ptr()); |
1126 | }); |
1127 | m.def("TFE_Py_VariableWatcherWatchedVariables" , |
1128 | [](const py::handle& variable_watcher) { |
1129 | return tensorflow::PyoOrThrow( |
1130 | TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr())); |
1131 | }); |
1132 | |
1133 | // TFE_Py_ForwardAccumulator logic. |
1134 | m.def("TFE_Py_ForwardAccumulatorNew" , [](bool use_batch) { |
1135 | return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch)); |
1136 | }); |
1137 | |
1138 | m.def("TFE_Py_ForwardAccumulatorSetAdd" , [](const py::handle& accumulator) { |
1139 | return tensorflow::PyoOrThrow( |
1140 | TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr())); |
1141 | }); |
1142 | m.def("TFE_Py_ForwardAccumulatorSetRemove" , |
1143 | [](const py::handle& accumulator) { |
1144 | TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr()); |
1145 | }); |
1146 | |
1147 | m.def("TFE_Py_ForwardAccumulatorWatch" , |
1148 | [](const py::handle& accumulator, const py::handle& tensor, |
1149 | const py::handle& tangent) { |
1150 | TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(), |
1151 | tangent.ptr()); |
1152 | }); |
1153 | m.def("TFE_Py_ForwardAccumulatorJVP" , |
1154 | [](const py::handle& accumulator, const py::handle& tensor) { |
1155 | return tensorflow::PyoOrThrow( |
1156 | TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr())); |
1157 | }); |
1158 | m.def("TFE_Py_ForwardAccumulatorPushState" , []() { |
1159 | return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPushState()); |
1160 | }); |
1161 | m.def("TFE_Py_ForwardAccumulatorPopState" , []() { |
1162 | return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPopState()); |
1163 | }); |
1164 | m.def("TFE_Py_PackJVPs" , [](const py::handle& tensors) { |
1165 | return tensorflow::PyoOrThrow(TFE_Py_PackJVPs(tensors.ptr())); |
1166 | }); |
1167 | |
1168 | // TFE_ContextOptions Logic |
1169 | m.def("TFE_NewContextOptions" , &TFE_NewContextOptions, |
1170 | py::return_value_policy::reference); |
1171 | m.def("TFE_ContextOptionsSetConfig" , [](TFE_ContextOptions* options, |
1172 | py::bytes proto) { |
1173 | tensorflow::Safe_TF_StatusPtr status = |
1174 | tensorflow::make_safe(TF_NewStatus()); |
1175 | tensorflow::Safe_TF_BufferPtr buf = |
1176 | tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); |
1177 | TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length, |
1178 | status.get()); |
1179 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1180 | }); |
1181 | m.def("TFE_ContextOptionsSetDevicePlacementPolicy" , |
1182 | &TFE_ContextOptionsSetDevicePlacementPolicy); |
1183 | m.def("TFE_ContextOptionsSetTfrt" , &TFE_ContextOptionsSetTfrt); |
1184 | m.def("TFE_ContextOptionsSetTfrtDistributedRuntime" , |
1185 | &TFE_ContextOptionsSetTfrtDistributedRuntime); |
1186 | // Experimental feature, intentionally not exposed as a C API yet. |
1187 | m.def("TFE_ContextOptionsSetRunEagerOpAsFunction" , |
1188 | [](TFE_ContextOptions* options, bool run_eager_op_as_function) { |
1189 | options->run_eager_op_as_function = run_eager_op_as_function; |
1190 | }); |
1191 | m.def("TFE_ContextOptionsSetJitCompileRewrite" , |
1192 | [](TFE_ContextOptions* options, bool jit_compile_rewrite) { |
1193 | options->jit_compile_rewrite = jit_compile_rewrite; |
1194 | }); |
1195 | m.def("TFE_ContextOptionsSetAsync" , &TFE_ContextOptionsSetAsync); |
1196 | m.def("TFE_DeleteContextOptions" , &TFE_DeleteContextOptions, |
1197 | py::return_value_policy::reference); |
1198 | |
1199 | // TFE_Py_TensorShape Logic |
1200 | m.def("TFE_Py_TensorShapeSlice" , |
1201 | [](const py::handle& tensors, int slice_dim) { |
1202 | return tensorflow::PyoOrThrow( |
1203 | TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim)); |
1204 | }); |
1205 | m.def("TFE_Py_TensorShapeOnDevice" , [](const py::handle& tensors, |
1206 | int slice_dim) { |
1207 | return tensorflow::PyoOrThrow(TFE_Py_TensorShapeOnDevice(tensors.ptr())); |
1208 | }); |
1209 | m.def("TFE_Py_EnableInteractivePythonLogging" , |
1210 | &TFE_Py_EnableInteractivePythonLogging); |
1211 | |
1212 | // Additional Context Logic |
1213 | m.def("TFE_Py_SetEagerContext" , [](const py::handle& o) { |
1214 | return tensorflow::PyoOrThrow(TFE_Py_SetEagerContext(o.ptr())); |
1215 | }); |
1216 | m.def("TFE_Py_SetCEagerContext" , [](const py::handle& ctx) { |
1217 | // TODO(mdan): This cast might need rewriting to ImmediateExecutionContext. |
1218 | tensorflow::SetCEagerContext(reinterpret_cast<tensorflow::EagerContext*>( |
1219 | tensorflow::InputTFE_Context(ctx))); |
1220 | }); |
1221 | m.def("TFE_Py_RegisterVSpace" , [](const py::handle& o) { |
1222 | return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr())); |
1223 | }); |
1224 | m.def("TFE_EnableCollectiveOps" , [](const py::handle& ctx, py::bytes proto) { |
1225 | tensorflow::Safe_TF_StatusPtr status = |
1226 | tensorflow::make_safe(TF_NewStatus()); |
1227 | tensorflow::Safe_TF_BufferPtr buf = |
1228 | tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); |
1229 | TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data, |
1230 | buf.get()->length, status.get()); |
1231 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1232 | }); |
1233 | m.def("TFE_AbortCollectiveOps" , [](const py::handle& ctx, int code, |
1234 | const char* message) { |
1235 | tensorflow::Safe_TF_StatusPtr status = |
1236 | tensorflow::make_safe(TF_NewStatus()); |
1237 | TF_SetStatus(status.get(), static_cast<TF_Code>(code), message); |
1238 | TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get()); |
1239 | }); |
1240 | m.def("TFE_CollectiveOpsCheckPeerHealth" , |
1241 | [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) { |
1242 | tensorflow::Safe_TF_StatusPtr status = |
1243 | tensorflow::make_safe(TF_NewStatus()); |
1244 | TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx), |
1245 | task, timeout_in_ms, status.get()); |
1246 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1247 | }); |
1248 | m.def("TF_ListPhysicalDevices" , &tensorflow::TF_ListPhysicalDevices); |
1249 | m.def("TF_ListPluggablePhysicalDevices" , |
1250 | &tensorflow::TF_ListPluggablePhysicalDevices); |
1251 | m.def("TF_GetDeviceDetails" , &tensorflow::TF_GetDeviceDetails); |
1252 | m.def("TF_DeleteDeviceList" , &TF_DeleteDeviceList, |
1253 | py::return_value_policy::reference); |
1254 | m.def("TF_DeviceListCount" , &TF_DeviceListCount); |
1255 | m.def("TF_DeviceListName" , [](const TF_DeviceList* list, int index) { |
1256 | tensorflow::Safe_TF_StatusPtr status = |
1257 | tensorflow::make_safe(TF_NewStatus()); |
1258 | auto output = TF_DeviceListName(list, index, status.get()); |
1259 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1260 | return output; |
1261 | }); |
1262 | m.def("TF_DeviceListType" , [](const TF_DeviceList* list, int index) { |
1263 | tensorflow::Safe_TF_StatusPtr status = |
1264 | tensorflow::make_safe(TF_NewStatus()); |
1265 | auto output = TF_DeviceListType(list, index, status.get()); |
1266 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1267 | return output; |
1268 | }); |
1269 | |
1270 | m.def("TF_PickUnusedPortOrDie" , &TF_PickUnusedPortOrDie); |
1271 | |
1272 | // TFE_MonitoringCounter Logic |
1273 | m.def("TFE_MonitoringCounterCellIncrementBy" , |
1274 | &TFE_MonitoringCounterCellIncrementBy); |
1275 | m.def("TFE_MonitoringCounterCellValue" , &TFE_MonitoringCounterCellValue); |
1276 | m.def( |
1277 | "TFE_MonitoringNewCounter0" , |
1278 | [](const char* name, const char* description) { |
1279 | tensorflow::Safe_TF_StatusPtr status = |
1280 | tensorflow::make_safe(TF_NewStatus()); |
1281 | auto output = |
1282 | TFE_MonitoringNewCounter0(name, status.get(), description); |
1283 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1284 | return output; |
1285 | }, |
1286 | py::return_value_policy::reference); |
1287 | m.def("TFE_MonitoringDeleteCounter0" , &TFE_MonitoringDeleteCounter0, |
1288 | py::return_value_policy::reference); |
1289 | m.def("TFE_MonitoringGetCellCounter0" , &TFE_MonitoringGetCellCounter0, |
1290 | py::return_value_policy::reference); |
1291 | m.def( |
1292 | "TFE_MonitoringNewCounter1" , |
1293 | [](const char* name, const char* description, const char* label1) { |
1294 | tensorflow::Safe_TF_StatusPtr status = |
1295 | tensorflow::make_safe(TF_NewStatus()); |
1296 | auto output = |
1297 | TFE_MonitoringNewCounter1(name, status.get(), description, label1); |
1298 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1299 | return output; |
1300 | }, |
1301 | py::return_value_policy::reference); |
1302 | m.def("TFE_MonitoringDeleteCounter1" , &TFE_MonitoringDeleteCounter1, |
1303 | py::return_value_policy::reference); |
1304 | m.def("TFE_MonitoringGetCellCounter1" , &TFE_MonitoringGetCellCounter1, |
1305 | py::return_value_policy::reference); |
1306 | m.def( |
1307 | "TFE_MonitoringNewCounter2" , |
1308 | [](const char* name, const char* description, const char* label1, |
1309 | const char* label2) { |
1310 | tensorflow::Safe_TF_StatusPtr status = |
1311 | tensorflow::make_safe(TF_NewStatus()); |
1312 | auto output = TFE_MonitoringNewCounter2(name, status.get(), description, |
1313 | label1, label2); |
1314 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1315 | return output; |
1316 | }, |
1317 | py::return_value_policy::reference); |
1318 | m.def("TFE_MonitoringDeleteCounter2" , &TFE_MonitoringDeleteCounter2, |
1319 | py::return_value_policy::reference); |
1320 | m.def("TFE_MonitoringGetCellCounter2" , &TFE_MonitoringGetCellCounter2, |
1321 | py::return_value_policy::reference); |
1322 | |
1323 | // TFE_MonitoringIntGauge Logic |
1324 | m.def("TFE_MonitoringIntGaugeCellSet" , &TFE_MonitoringIntGaugeCellSet); |
1325 | m.def("TFE_MonitoringIntGaugeCellValue" , &TFE_MonitoringIntGaugeCellValue); |
1326 | m.def( |
1327 | "TFE_MonitoringNewIntGauge0" , |
1328 | [](const char* name, const char* description) { |
1329 | tensorflow::Safe_TF_StatusPtr status = |
1330 | tensorflow::make_safe(TF_NewStatus()); |
1331 | auto output = |
1332 | TFE_MonitoringNewIntGauge0(name, status.get(), description); |
1333 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1334 | return output; |
1335 | }, |
1336 | py::return_value_policy::reference); |
1337 | m.def("TFE_MonitoringDeleteIntGauge0" , &TFE_MonitoringDeleteIntGauge0, |
1338 | py::return_value_policy::reference); |
1339 | m.def("TFE_MonitoringGetCellIntGauge0" , &TFE_MonitoringGetCellIntGauge0, |
1340 | py::return_value_policy::reference); |
1341 | m.def( |
1342 | "TFE_MonitoringNewIntGauge1" , |
1343 | [](const char* name, const char* description, const char* label1) { |
1344 | tensorflow::Safe_TF_StatusPtr status = |
1345 | tensorflow::make_safe(TF_NewStatus()); |
1346 | auto output = |
1347 | TFE_MonitoringNewIntGauge1(name, status.get(), description, label1); |
1348 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1349 | return output; |
1350 | }, |
1351 | py::return_value_policy::reference); |
1352 | m.def("TFE_MonitoringDeleteIntGauge1" , &TFE_MonitoringDeleteIntGauge1, |
1353 | py::return_value_policy::reference); |
1354 | m.def("TFE_MonitoringGetCellIntGauge1" , &TFE_MonitoringGetCellIntGauge1, |
1355 | py::return_value_policy::reference); |
1356 | m.def( |
1357 | "TFE_MonitoringNewIntGauge2" , |
1358 | [](const char* name, const char* description, const char* label1, |
1359 | const char* label2) { |
1360 | tensorflow::Safe_TF_StatusPtr status = |
1361 | tensorflow::make_safe(TF_NewStatus()); |
1362 | auto output = TFE_MonitoringNewIntGauge2(name, status.get(), |
1363 | description, label1, label2); |
1364 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1365 | return output; |
1366 | }, |
1367 | py::return_value_policy::reference); |
1368 | m.def("TFE_MonitoringDeleteIntGauge2" , &TFE_MonitoringDeleteIntGauge2, |
1369 | py::return_value_policy::reference); |
1370 | m.def("TFE_MonitoringGetCellIntGauge2" , &TFE_MonitoringGetCellIntGauge2, |
1371 | py::return_value_policy::reference); |
1372 | m.def("TFE_MonitoringStringGaugeCellSet" , &TFE_MonitoringStringGaugeCellSet); |
1373 | m.def("TFE_MonitoringStringGaugeCellValue" , |
1374 | &TFE_MonitoringStringGaugeCellValue); |
1375 | m.def( |
1376 | "TFE_MonitoringNewStringGauge0" , |
1377 | [](const char* name, const char* description) { |
1378 | tensorflow::Safe_TF_StatusPtr status = |
1379 | tensorflow::make_safe(TF_NewStatus()); |
1380 | auto output = |
1381 | TFE_MonitoringNewStringGauge0(name, status.get(), description); |
1382 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1383 | return output; |
1384 | }, |
1385 | py::return_value_policy::reference); |
1386 | |
1387 | // TFE_MonitoringStringGauge Logic |
1388 | m.def("TFE_MonitoringDeleteStringGauge0" , &TFE_MonitoringDeleteStringGauge0); |
1389 | m.def("TFE_MonitoringGetCellStringGauge0" , &TFE_MonitoringGetCellStringGauge0, |
1390 | py::return_value_policy::reference); |
1391 | m.def( |
1392 | "TFE_MonitoringNewStringGauge1" , |
1393 | [](const char* name, const char* description, const char* label1) { |
1394 | tensorflow::Safe_TF_StatusPtr status = |
1395 | tensorflow::make_safe(TF_NewStatus()); |
1396 | auto output = TFE_MonitoringNewStringGauge1(name, status.get(), |
1397 | description, label1); |
1398 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1399 | return output; |
1400 | }, |
1401 | py::return_value_policy::reference); |
1402 | m.def("TFE_MonitoringDeleteStringGauge1" , &TFE_MonitoringDeleteStringGauge1); |
1403 | m.def("TFE_MonitoringGetCellStringGauge1" , &TFE_MonitoringGetCellStringGauge1, |
1404 | py::return_value_policy::reference); |
1405 | m.def( |
1406 | "TFE_MonitoringNewStringGauge2" , |
1407 | [](const char* name, const char* description, const char* label1, |
1408 | const char* label2) { |
1409 | tensorflow::Safe_TF_StatusPtr status = |
1410 | tensorflow::make_safe(TF_NewStatus()); |
1411 | auto output = TFE_MonitoringNewStringGauge2( |
1412 | name, status.get(), description, label1, label2); |
1413 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1414 | return output; |
1415 | }, |
1416 | py::return_value_policy::reference); |
1417 | m.def("TFE_MonitoringDeleteStringGauge2" , &TFE_MonitoringDeleteStringGauge2); |
1418 | m.def("TFE_MonitoringGetCellStringGauge2" , &TFE_MonitoringGetCellStringGauge2, |
1419 | py::return_value_policy::reference); |
1420 | |
1421 | m.def( |
1422 | "TFE_MonitoringNewStringGauge3" , |
1423 | [](const char* name, const char* description, const char* label1, |
1424 | const char* label2, const char* label3) { |
1425 | tensorflow::Safe_TF_StatusPtr status = |
1426 | tensorflow::make_safe(TF_NewStatus()); |
1427 | auto output = TFE_MonitoringNewStringGauge3( |
1428 | name, status.get(), description, label1, label2, label3); |
1429 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1430 | return output; |
1431 | }, |
1432 | py::return_value_policy::reference); |
1433 | m.def("TFE_MonitoringDeleteStringGauge3" , &TFE_MonitoringDeleteStringGauge3); |
1434 | m.def("TFE_MonitoringGetCellStringGauge3" , &TFE_MonitoringGetCellStringGauge3, |
1435 | py::return_value_policy::reference); |
1436 | |
1437 | m.def( |
1438 | "TFE_MonitoringNewStringGauge4" , |
1439 | [](const char* name, const char* description, const char* label1, |
1440 | const char* label2, const char* label3, const char* label4) { |
1441 | tensorflow::Safe_TF_StatusPtr status = |
1442 | tensorflow::make_safe(TF_NewStatus()); |
1443 | auto output = TFE_MonitoringNewStringGauge4( |
1444 | name, status.get(), description, label1, label2, label3, label4); |
1445 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1446 | return output; |
1447 | }, |
1448 | py::return_value_policy::reference); |
1449 | m.def("TFE_MonitoringDeleteStringGauge4" , &TFE_MonitoringDeleteStringGauge4); |
1450 | m.def("TFE_MonitoringGetCellStringGauge4" , &TFE_MonitoringGetCellStringGauge4, |
1451 | py::return_value_policy::reference); |
1452 | |
1453 | // TFE_MonitoringBoolGauge Logic |
1454 | m.def("TFE_MonitoringBoolGaugeCellSet" , &TFE_MonitoringBoolGaugeCellSet); |
1455 | m.def("TFE_MonitoringBoolGaugeCellValue" , &TFE_MonitoringBoolGaugeCellValue); |
1456 | m.def( |
1457 | "TFE_MonitoringNewBoolGauge0" , |
1458 | [](const char* name, const char* description) { |
1459 | tensorflow::Safe_TF_StatusPtr status = |
1460 | tensorflow::make_safe(TF_NewStatus()); |
1461 | auto output = |
1462 | TFE_MonitoringNewBoolGauge0(name, status.get(), description); |
1463 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1464 | return output; |
1465 | }, |
1466 | py::return_value_policy::reference); |
1467 | m.def("TFE_MonitoringDeleteBoolGauge0" , &TFE_MonitoringDeleteBoolGauge0, |
1468 | py::return_value_policy::reference); |
1469 | m.def("TFE_MonitoringGetCellBoolGauge0" , &TFE_MonitoringGetCellBoolGauge0, |
1470 | py::return_value_policy::reference); |
1471 | m.def( |
1472 | "TFE_MonitoringNewBoolGauge1" , |
1473 | [](const char* name, const char* description, const char* label1) { |
1474 | tensorflow::Safe_TF_StatusPtr status = |
1475 | tensorflow::make_safe(TF_NewStatus()); |
1476 | auto output = TFE_MonitoringNewBoolGauge1(name, status.get(), |
1477 | description, label1); |
1478 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1479 | return output; |
1480 | }, |
1481 | py::return_value_policy::reference); |
1482 | m.def("TFE_MonitoringDeleteBoolGauge1" , &TFE_MonitoringDeleteBoolGauge1, |
1483 | py::return_value_policy::reference); |
1484 | m.def("TFE_MonitoringGetCellBoolGauge1" , &TFE_MonitoringGetCellBoolGauge1, |
1485 | py::return_value_policy::reference); |
1486 | m.def( |
1487 | "TFE_MonitoringNewBoolGauge2" , |
1488 | [](const char* name, const char* description, const char* label1, |
1489 | const char* label2) { |
1490 | tensorflow::Safe_TF_StatusPtr status = |
1491 | tensorflow::make_safe(TF_NewStatus()); |
1492 | auto output = TFE_MonitoringNewBoolGauge2(name, status.get(), |
1493 | description, label1, label2); |
1494 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1495 | return output; |
1496 | }, |
1497 | py::return_value_policy::reference); |
1498 | m.def("TFE_MonitoringDeleteBoolGauge2" , &TFE_MonitoringDeleteBoolGauge2, |
1499 | py::return_value_policy::reference); |
1500 | m.def("TFE_MonitoringGetCellBoolGauge2" , &TFE_MonitoringGetCellBoolGauge2, |
1501 | py::return_value_policy::reference); |
1502 | |
1503 | // TFE_MonitoringSampler Logic |
1504 | m.def("TFE_MonitoringSamplerCellAdd" , &TFE_MonitoringSamplerCellAdd); |
1505 | m.def("TFE_MonitoringSamplerCellValue" , &TFE_MonitoringSamplerCellValue); |
1506 | m.def("TFE_MonitoringNewExponentialBuckets" , |
1507 | &TFE_MonitoringNewExponentialBuckets, |
1508 | py::return_value_policy::reference); |
1509 | m.def("TFE_MonitoringDeleteBuckets" , &TFE_MonitoringDeleteBuckets, |
1510 | py::return_value_policy::reference); |
1511 | m.def( |
1512 | "TFE_MonitoringNewSampler0" , |
1513 | [](const char* name, TFE_MonitoringBuckets* buckets, |
1514 | const char* description) { |
1515 | tensorflow::Safe_TF_StatusPtr status = |
1516 | tensorflow::make_safe(TF_NewStatus()); |
1517 | auto output = |
1518 | TFE_MonitoringNewSampler0(name, buckets, status.get(), description); |
1519 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1520 | return output; |
1521 | }, |
1522 | py::return_value_policy::reference); |
1523 | m.def("TFE_MonitoringDeleteSampler0" , &TFE_MonitoringDeleteSampler0, |
1524 | py::return_value_policy::reference); |
1525 | m.def("TFE_MonitoringGetCellSampler0" , &TFE_MonitoringGetCellSampler0, |
1526 | py::return_value_policy::reference); |
1527 | m.def( |
1528 | "TFE_MonitoringNewSampler1" , |
1529 | [](const char* name, TFE_MonitoringBuckets* buckets, |
1530 | const char* description, const char* label1) { |
1531 | tensorflow::Safe_TF_StatusPtr status = |
1532 | tensorflow::make_safe(TF_NewStatus()); |
1533 | auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(), |
1534 | description, label1); |
1535 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1536 | return output; |
1537 | }, |
1538 | py::return_value_policy::reference); |
1539 | m.def("TFE_MonitoringDeleteSampler1" , &TFE_MonitoringDeleteSampler1, |
1540 | py::return_value_policy::reference); |
1541 | m.def("TFE_MonitoringGetCellSampler1" , &TFE_MonitoringGetCellSampler1, |
1542 | py::return_value_policy::reference); |
1543 | m.def( |
1544 | "TFE_MonitoringNewSampler2" , |
1545 | [](const char* name, TFE_MonitoringBuckets* buckets, |
1546 | const char* description, const char* label1, const char* label2) { |
1547 | tensorflow::Safe_TF_StatusPtr status = |
1548 | tensorflow::make_safe(TF_NewStatus()); |
1549 | auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(), |
1550 | description, label1, label2); |
1551 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1552 | return output; |
1553 | }, |
1554 | py::return_value_policy::reference); |
1555 | m.def("TFE_MonitoringDeleteSampler2" , &TFE_MonitoringDeleteSampler2, |
1556 | py::return_value_policy::reference); |
1557 | m.def("TFE_MonitoringGetCellSampler2" , &TFE_MonitoringGetCellSampler2, |
1558 | py::return_value_policy::reference); |
1559 | |
1560 | // TFE_CancellationManager Logic |
1561 | m.def("TFE_NewCancellationManager" , |
1562 | []() { return new tensorflow::CancellationManager(); }); |
1563 | m.def("TFE_CancellationManagerIsCancelled" , |
1564 | &tensorflow::CancellationManager::IsCancelled); |
1565 | m.def("TFE_CancellationManagerStartCancel" , |
1566 | &tensorflow::CancellationManager::StartCancel); |
1567 | |
1568 | m.def("TFE_ClearScalarCache" , &tensorflow::TFE_ClearScalarCache); |
1569 | |
1570 | // Util buffer helper functions |
1571 | m.def("TF_NewBufferFromString" , &TF_NewBufferFromString, |
1572 | py::return_value_policy::reference); |
1573 | |
1574 | // DLPack functions |
1575 | m.def("TFE_ToDlpackCapsule" , [](py::handle& o) { |
1576 | PyObject* eager_tensor_pyobject_ptr = o.ptr(); |
1577 | tensorflow::Safe_TF_StatusPtr status = |
1578 | tensorflow::make_safe(TF_NewStatus()); |
1579 | |
1580 | if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) { |
1581 | status->status = tensorflow::errors::InvalidArgument( |
1582 | "The argument to `to_dlpack` must be a TF tensor, not Python object" ); |
1583 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1584 | } |
1585 | |
1586 | TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr); |
1587 | void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get()); |
1588 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1589 | |
1590 | py::capsule capsule( |
1591 | dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) { |
1592 | if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) { |
1593 | void* dlm_rptr = |
1594 | PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName); |
1595 | if (dlm_rptr) { |
1596 | tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr); |
1597 | PyCapsule_SetDestructor(capsule, nullptr); |
1598 | } |
1599 | } |
1600 | }); |
1601 | return capsule; |
1602 | }); |
1603 | |
1604 | m.def("TFE_FromDlpackCapsule" , [](const py::capsule& pycapsule, |
1605 | const py::handle& context) { |
1606 | tensorflow::Safe_TF_StatusPtr status = |
1607 | tensorflow::make_safe(TF_NewStatus()); |
1608 | if (absl::string_view(pycapsule.name()) != |
1609 | tensorflow::kDlTensorCapsuleName) { |
1610 | status->status = tensorflow::errors::InvalidArgument( |
1611 | "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " |
1612 | "Note that a DLPack tensor may be consumed at most once." , |
1613 | absl::string_view(pycapsule.name())); |
1614 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1615 | } |
1616 | |
1617 | TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack( |
1618 | pycapsule, status.get(), tensorflow::InputTFE_Context(context)); |
1619 | |
1620 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1621 | |
1622 | PyCapsule_SetName(pycapsule.ptr(), "used_dltensor" ); |
1623 | PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); |
1624 | |
1625 | PyObject* pyhandle = EagerTensorFromHandle(thandle); |
1626 | return tensorflow::PyoOrThrow(pyhandle); |
1627 | }); |
1628 | |
1629 | m.def("TFE_Py_IsCustomDevice" , |
1630 | [](const py::handle& context, const char* device_name) { |
1631 | return TFE_IsCustomDevice(tensorflow::InputTFE_Context(context), |
1632 | device_name); |
1633 | }); |
1634 | |
1635 | m.def("TFE_Py_RegisterCustomDevice" , [](const py::handle& context, |
1636 | const py::capsule& device, |
1637 | const char* device_name, |
1638 | const py::capsule& device_info) { |
1639 | tensorflow::Safe_TF_StatusPtr status = |
1640 | tensorflow::make_safe(TF_NewStatus()); |
1641 | if (absl::string_view(device.name()) != "TFE_CustomDevice" ) { |
1642 | status->status = tensorflow::errors::InvalidArgument( |
1643 | "Expected a capsule named 'TFE_CustomDevice' for the `device` " |
1644 | "argument, got " , |
1645 | absl::string_view(device.name())); |
1646 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1647 | } |
1648 | if (absl::string_view(device_info.name()) != |
1649 | "TFE_CustomDevice_DeviceInfo" ) { |
1650 | status->status = tensorflow::errors::InvalidArgument( |
1651 | "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for " |
1652 | "the `device_info` argument, got " , |
1653 | absl::string_view(device_info.name())); |
1654 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1655 | } |
1656 | // TFE_RegisterCustomDevice takes ownership |
1657 | PyCapsule_SetDestructor(device_info.ptr(), nullptr); |
1658 | TFE_RegisterCustomDevice( |
1659 | tensorflow::InputTFE_Context(context), |
1660 | *reinterpret_cast<TFE_CustomDevice*>( |
1661 | PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice" )), |
1662 | device_name, |
1663 | PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo" ), |
1664 | status.get()); |
1665 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
1666 | }); |
1667 | |
1668 | py::class_<EagerContextThreadLocalDataWrapper>(m, |
1669 | "EagerContextThreadLocalData" ) |
1670 | .def(py::init<py::handle, py::handle, py::handle>(), |
1671 | py::arg("py_eager_context" ), py::arg("is_eager" ), |
1672 | py::arg("device_spec" )) |
1673 | .def_property("is_eager" , |
1674 | &EagerContextThreadLocalDataWrapper::get_is_eager, |
1675 | &EagerContextThreadLocalDataWrapper::set_is_eager) |
1676 | .def_property( |
1677 | "invoking_op_callbacks" , |
1678 | &EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks, |
1679 | &EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks) |
1680 | .def_property("device_name" , |
1681 | &EagerContextThreadLocalDataWrapper::get_device_name, |
1682 | &EagerContextThreadLocalDataWrapper::set_device_name) |
1683 | .def_property("scope_name" , |
1684 | &EagerContextThreadLocalDataWrapper::get_scope_name, |
1685 | &EagerContextThreadLocalDataWrapper::set_scope_name) |
1686 | .def_property("device_spec" , |
1687 | &EagerContextThreadLocalDataWrapper::get_device_spec, |
1688 | &EagerContextThreadLocalDataWrapper::set_device_spec) |
1689 | .def_property( |
1690 | "function_call_options" , |
1691 | &EagerContextThreadLocalDataWrapper::get_function_call_options, |
1692 | &EagerContextThreadLocalDataWrapper::set_function_call_options) |
1693 | .def_property("executor" , |
1694 | &EagerContextThreadLocalDataWrapper::get_executor, |
1695 | &EagerContextThreadLocalDataWrapper::set_executor) |
1696 | .def_property("op_callbacks" , |
1697 | &EagerContextThreadLocalDataWrapper::get_op_callbacks, |
1698 | &EagerContextThreadLocalDataWrapper::set_op_callbacks); |
1699 | |
1700 | // C API Enum |
1701 | |
1702 | py::enum_<TFE_ContextDevicePlacementPolicy>( |
1703 | m, "TFE_ContextDevicePlacementPolicy" ) |
1704 | .value("TFE_DEVICE_PLACEMENT_EXPLICIT" , TFE_DEVICE_PLACEMENT_EXPLICIT) |
1705 | .value("TFE_DEVICE_PLACEMENT_WARN" , TFE_DEVICE_PLACEMENT_WARN) |
1706 | .value("TFE_DEVICE_PLACEMENT_SILENT" , TFE_DEVICE_PLACEMENT_SILENT) |
1707 | .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32" , |
1708 | TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) |
1709 | .export_values(); |
1710 | |
1711 | py::enum_<TF_AttrType>(m, "TF_AttrType" ) |
1712 | .value("TF_ATTR_STRING" , TF_ATTR_STRING) |
1713 | .value("TF_ATTR_INT" , TF_ATTR_INT) |
1714 | .value("TF_ATTR_FLOAT" , TF_ATTR_FLOAT) |
1715 | .value("TF_ATTR_BOOL" , TF_ATTR_BOOL) |
1716 | .value("TF_ATTR_TYPE" , TF_ATTR_TYPE) |
1717 | .value("TF_ATTR_SHAPE" , TF_ATTR_SHAPE) |
1718 | .value("TF_ATTR_TENSOR" , TF_ATTR_TENSOR) |
1719 | .value("TF_ATTR_PLACEHOLDER" , TF_ATTR_PLACEHOLDER) |
1720 | .value("TF_ATTR_FUNC" , TF_ATTR_FUNC) |
1721 | .export_values(); |
1722 | }; |
1723 | |