/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <vector>

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/dtensor/cc/dtensor_device.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/util.h"

namespace py = ::pybind11;
using tensorflow::dtensor::AddMesh;
using tensorflow::dtensor::AllocateDTensorDevice;
using tensorflow::dtensor::ClearTPUCoreIDs;
using tensorflow::dtensor::ExperimentalClearDefaultLayout;
using tensorflow::dtensor::ExperimentalClearDefaultMesh;
using tensorflow::dtensor::ExperimentalSetDefaultLayout;
using tensorflow::dtensor::ExperimentalSetDefaultMesh;
using tensorflow::dtensor::FetchLayout;
using tensorflow::dtensor::GetFunctionCacheHitAndMissCount;
using tensorflow::dtensor::IsSparseDTensor;
using tensorflow::dtensor::Pack;
using tensorflow::dtensor::SetSameShapePolicy;
using tensorflow::dtensor::SetTPUCoreIDs;
using tensorflow::dtensor::SparsePack;
using tensorflow::dtensor::TPUCoreIDsToLocations;
using tensorflow::dtensor::TPUCoreLocationsToIDs;
using tensorflow::dtensor::Unpack;

void PyXDecref(PyObject* obj) { Py_XDECREF(obj); }

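// PyCapsule destructor for the "TFE_CustomDevice" capsule created in
// Allocate() below; frees the TFE_CustomDevice struct itself.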
void CallDelete_Device(PyObject* capsule) {
  delete reinterpret_cast<TFE_CustomDevice*>(
      PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
}

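// PyCapsule destructor for the "TFE_CustomDevice_DeviceInfo" capsule: the
// DeviceInfo destructor is stashed in the capsule's context by Allocate(),
// so retrieve it and invoke it on the stored pointer.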
void CallDelete_DeviceInfo(PyObject* capsule) {
  void (*destructor)(void*) =
      reinterpret_cast<void (*)(void*)>(PyCapsule_GetContext(capsule));
  destructor(PyCapsule_GetPointer(capsule, "TFE_CustomDevice_DeviceInfo"));
}

// Supports 2 cases:
// i) input is an EagerTensor.
// ii) input is an arbitrary Python list/tuple.
void ConvertToTensor(TFE_Context* ctx, PyObject* input,
                     tensorflow::Safe_PyObjectPtr* output_handle,
                     TF_Status* status) {
  if (EagerTensor_CheckExact(input)) {
    // Input is already an EagerTensor, so increment the reference, since the
    // caller will use it through output_handle.
    Py_INCREF(input);
    output_handle->reset(input);
    return;
  }
  TFE_TensorHandle* handle =
      tensorflow::ConvertToEagerTensor(ctx, input, tensorflow::DT_INVALID);
  if (handle == nullptr) {
    TF_SetStatus(status, TF_INTERNAL, "Failure converting to eager tensor.");
    return;
  }
  output_handle->reset(EagerTensorFromHandle(handle));
}

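// Python bindings for the DTensor custom device. The module exposes the
// device allocation, mesh/layout configuration, TPU core bookkeeping, and
// pack/unpack entry points declared in dtensor_device.h.
//
// A rough, illustrative sketch of how the Python side might drive these
// bindings; `device_name` and `serialized_mesh` are placeholder names, and
// the step that registers the returned capsules with the eager context is
// elided:
//
//   device, device_info = _pywrap_dtensor_device.Allocate(device_name)
//   # ... register `device` with the eager context as a custom device ...
//   # The two trailing booleans are is_async and is_host_mesh.
//   _pywrap_dtensor_device.AddMesh(device_info, serialized_mesh, False, False)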
PYBIND11_MODULE(_pywrap_dtensor_device, m) {
  m.def("Allocate", [](const std::string& name) {
    TFE_CustomDevice* device = new TFE_CustomDevice;
    std::unique_ptr<PyObject, decltype(&PyXDecref)> device_capsule(
        PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device),
        PyXDecref);
    void* device_info;
    AllocateDTensorDevice(name, device, &device_info);
    std::unique_ptr<PyObject, decltype(&PyXDecref)> device_info_capsule(
        PyCapsule_New(device_info, "TFE_CustomDevice_DeviceInfo",
                      &CallDelete_DeviceInfo),
        PyXDecref);
    // The PyCapsule destructor needs a pointer to the destructor for
    // DeviceInfo.
    PyCapsule_SetContext(device_info_capsule.get(),
                         reinterpret_cast<void*>(device->delete_device));
    if (PyErr_Occurred()) throw py::error_already_set();
    return pybind11::reinterpret_steal<pybind11::object>(
        PyTuple_Pack(2, device_capsule.get(), device_info_capsule.get()));
  });
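  // Registers a mesh, passed in serialized form, with the DTensor device
  // backing `device_info`.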
  m.def("AddMesh", [](const py::capsule& device_info,
                      const std::string& serialized_mesh, bool is_async,
                      bool is_host_mesh) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    AddMesh(
        serialized_mesh,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        is_async, is_host_mesh, status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
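  // The Experimental* bindings below set or clear the DTensor device's
  // default layout and default mesh.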
  m.def(
      "ExperimentalSetDefaultLayout",
      [](const py::capsule& device_info, const std::string& serialized_layout) {
        std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
            TF_NewStatus(), TF_DeleteStatus);
        ExperimentalSetDefaultLayout(
            serialized_layout,
            PyCapsule_GetPointer(device_info.ptr(),
                                 "TFE_CustomDevice_DeviceInfo"),
            status.get());
        if (TF_GetCode(status.get()) != TF_OK) {
          PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
          throw py::error_already_set();
        }
      });
  m.def("ExperimentalClearDefaultLayout", [](const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    ExperimentalClearDefaultLayout(
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
  m.def("ExperimentalSetDefaultMesh", [](const py::capsule& device_info,
                                         const std::string& serialized_mesh) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    ExperimentalSetDefaultMesh(
        serialized_mesh,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
  m.def("ExperimentalClearDefaultMesh", [](const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    ExperimentalClearDefaultMesh(
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
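  // Enables or disables the DTensor device's same-shape policy (see
  // SetSameShapePolicy in dtensor_device.h).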
  m.def("SetSameShapePolicy", [](const py::capsule& device_info, bool enabled) {
    SetSameShapePolicy(
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        enabled);
  });
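  // TPU bookkeeping: SetTPUCoreIDs records the TPU core IDs backing the named
  // mesh and ClearTPUCoreIDs resets them; the two bindings that follow
  // translate between flat core IDs and core locations.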
  m.def("SetTPUCoreIDs", [](const py::capsule& device_info,
                            const std::string& mesh_name,
                            const std::vector<int>& tpu_core_ids) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    SetTPUCoreIDs(
        mesh_name, tpu_core_ids,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
  });
  m.def("ClearTPUCoreIDs", [](const py::capsule& device_info) {
    ClearTPUCoreIDs(
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"));
  });
  m.def("TPUCoreIDsToLocations", [](const py::handle& context,
                                    const py::capsule& device_info,
                                    const std::vector<int>& tpu_core_ids) {
    return TPUCoreIDsToLocations(
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
        tpu_core_ids,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"));
  });
  m.def("TPUCoreLocationsToIDs",
        [](const py::handle& context, const py::capsule& device_info,
           const std::vector<std::vector<int>>& tpu_core_locations) {
          return TPUCoreLocationsToIDs(
              static_cast<TFE_Context*>(
                  PyCapsule_GetPointer(context.ptr(), nullptr)),
              tpu_core_locations,
              PyCapsule_GetPointer(device_info.ptr(),
                                   "TFE_CustomDevice_DeviceInfo"));
        });
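  // Packs a list of component EagerTensors into a single DTensor laid out
  // according to `string_layout`. When `is_sparse` is true, the input list is
  // expected to contain the indices, values, and dense shapes of the
  // components, concatenated in that order (hence the split into thirds
  // below).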
  m.def("Pack", [](const py::handle& context, const py::handle& input_tensors,
                   const std::string& string_layout,
                   const py::capsule& device_info, const bool is_sparse) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    TFE_Context* ctx =
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr));
    // Convert each Python object into a Python EagerTensor held by a
    // Safe_PyObjectPtr.
    std::vector<tensorflow::Safe_PyObjectPtr> py_eager_tensor_handles;
    Py_ssize_t len = PyList_Size(input_tensors.ptr());
    py_eager_tensor_handles.resize(len);

    for (Py_ssize_t i = 0; i < len; ++i) {
      PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
      ConvertToTensor(ctx, elem, &py_eager_tensor_handles[i], status.get());

      if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
        return tensorflow::PyoOrThrow(nullptr);
    }
    std::vector<TFE_TensorHandle*> input_vector;
    input_vector.resize(len);
    for (int i = 0; i < len; ++i)
      input_vector[i] = EagerTensor_Handle(py_eager_tensor_handles[i].get());
    TFE_TensorHandle* packed_tensor;
    if (is_sparse) {
      auto size = input_vector.size() / 3;
      packed_tensor = SparsePack(
          ctx,
          /*num_inputs=*/size,
          /*indices=*/
          std::vector<TFE_TensorHandle*>(input_vector.begin(),
                                         input_vector.begin() + size)
              .data(),
          /*values=*/
          std::vector<TFE_TensorHandle*>(input_vector.begin() + size,
                                         input_vector.begin() + 2 * size)
              .data(),
          /*shapes=*/
          std::vector<TFE_TensorHandle*>(input_vector.begin() + 2 * size,
                                         input_vector.end())
              .data(),
          string_layout, device_info, status.get());
    } else {
      packed_tensor = Pack(ctx, input_vector.size(), input_vector.data(),
                           string_layout, device_info, status.get());
    }
    if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
      return tensorflow::PyoOrThrow(nullptr);
    // Convert the C++ packed tensor handle into a Python EagerTensor object.
    tensorflow::Safe_PyObjectPtr flat_result(PyList_New(1));
    PyList_SET_ITEM(flat_result.get(), 0, EagerTensorFromHandle(packed_tensor));
    auto* result = PyList_GET_ITEM(flat_result.get(), 0);
    Py_INCREF(result);
    return tensorflow::PyoOrThrow(result);
  });
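  // Unpacks a DTensor into its component tensors and returns them as a
  // Python list of EagerTensors.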
  m.def("Unpack", [](const py::handle& context,
                     const py::handle& dtensor_handle,
                     const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);

    TFE_TensorHandle* input_handle = EagerTensor_Handle(dtensor_handle.ptr());
    std::vector<TFE_TensorHandle*> unpacked_handles = Unpack(
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
        input_handle, device_info, status.get());

    if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
      return tensorflow::PyoOrThrow(nullptr);
    // Convert all TFE_TensorHandles to Python EagerTensors and return them as
    // a Python list.
    int num_outputs = unpacked_handles.size();
    PyObject* result(PyList_New(num_outputs));
    for (int i = 0; i < num_outputs; ++i) {
      PyList_SET_ITEM(result, i, EagerTensorFromHandle(unpacked_handles[i]));
    }
    return tensorflow::PyoOrThrow(result);
  });
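  // Returns the serialized layout of the given DTensor as a Python string.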
  m.def(
      "FetchLayout",
      [](const py::handle& context, const py::handle& dtensor_handle,
         const py::capsule& device_info) -> py::object {
        std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
            TF_NewStatus(), TF_DeleteStatus);

        std::string layout_string =
            FetchLayout(static_cast<TFE_Context*>(
                            PyCapsule_GetPointer(context.ptr(), nullptr)),
                        EagerTensor_Handle(dtensor_handle.ptr()), device_info,
                        status.get());
        if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), nullptr))
          return tensorflow::PyoOrThrow(nullptr);
        return tensorflow::PyoOrThrow(
            PyUnicode_FromString(layout_string.c_str()));
      });
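  // Returns whether the given DTensor holds sparse components.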
  m.def("IsSparseDTensor", [](const py::handle& context,
                              const py::handle& dtensor_handle,
                              const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);

    TFE_TensorHandle* input_handle = EagerTensor_Handle(dtensor_handle.ptr());
    bool is_sparse = IsSparseDTensor(
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
        input_handle, device_info, status.get());

    if (TF_GetCode(status.get()) != TF_OK) {
      PyErr_SetString(PyExc_ValueError, TF_Message(status.get()));
      throw py::error_already_set();
    }
    return is_sparse;
  });
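  // Returns the DTensor device's function cache hit and miss counts.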
  m.def("GetFunctionCacheHitAndMissCount", [](const py::handle& context,
                                              const py::capsule& device_info) {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    return GetFunctionCacheHitAndMissCount(
        static_cast<TFE_Context*>(PyCapsule_GetPointer(context.ptr(), nullptr)),
        device_info, status.get());
  });
}