/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/dlpack.h"

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

namespace {
// Owning context for a DLManagedTensor; it manages the lifetime of the
// DLManagedTensor. When DLManagedTensor::deleter is called, the original
// framework is notified of the destruction, and this context is deleted too.
struct TfDlManagedTensorCtx {
  TensorReference reference;
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  DLManagedTensor tensor;

  explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {}
};

// Gets the tensor from an eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
  if (handle->Type() != TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "DLPack doesn't support ", handle->TypeString(), " tensor");
    return nullptr;
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return tensor;
}

// Deleter for DLManagedTensor.
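// It releases the reference held on the underlying TF tensor and frees the
// owning TfDlManagedTensorCtx.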
void DLManagedTensorDeleter(DLManagedTensor* arg) {
  TfDlManagedTensorCtx* owner =
      static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx);
  owner->reference.Unref();
  delete owner;
}
// Converts a TF_DataType to a DLPack data type.
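// For example, TF_FLOAT maps to {code = kDLFloat, bits = 32, lanes = 1}.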
DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) {
  DLDataType dtype;
  dtype.lanes = 1;
  dtype.bits = TF_DataTypeSize(data_type) * 8;
  switch (data_type) {
    case TF_DataType::TF_HALF:
    case TF_DataType::TF_FLOAT:
    case TF_DataType::TF_DOUBLE:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case TF_DataType::TF_INT8:
    case TF_DataType::TF_INT16:
    case TF_DataType::TF_INT32:
    case TF_DataType::TF_INT64:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case TF_DataType::TF_BOOL:
    case TF_DataType::TF_UINT8:
    case TF_DataType::TF_UINT16:
    case TF_DataType::TF_UINT32:
    case TF_DataType::TF_UINT64:
      dtype.code = DLDataTypeCode::kDLUInt;
      break;
    case TF_DataType::TF_BFLOAT16:
      dtype.code = DLDataTypeCode::kDLBfloat;
      break;
    case TF_DataType::TF_COMPLEX64:
    case TF_DataType::TF_COMPLEX128:
      dtype.code = DLDataTypeCode::kDLComplex;
      break;
    default:
      status->status = tensorflow::errors::InvalidArgument(
          DataType_Name(static_cast<DataType>(data_type)),
          " is not supported by dlpack");
      break;
  }
  return dtype;
}

// Gets DLPack's DLDevice from eager tensor handle.
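// For example, a handle placed on "/job:localhost/replica:0/task:0/device:GPU:0"
// yields {device_type = kDLCUDA, device_id = 0}.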
DLDevice GetDlContext(TFE_TensorHandle* h, TF_Status* status) {
  DLDevice ctx;
  const char* device_name =
      tensorflow::unwrap(h)->BackingDeviceName(&status->status);
  DeviceNameUtils::ParsedName parsed_name;
  tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
  std::string device_type = parsed_name.type;
  int device_id = 0;
  if (parsed_name.has_id) {
    device_id = parsed_name.id;
  }

  ctx.device_id = device_id;
  if (device_type == "CPU") {
    ctx.device_type = DLDeviceType::kDLCPU;
  } else if (device_type == "GPU") {
    ctx.device_type = DLDeviceType::kDLCUDA;
  } else {
    status->status = tensorflow::errors::InvalidArgument(
        "Unsupported Device Type for dlpack");
  }

  return ctx;
}

// Converts DLDevice to TF device name.
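// For example, {kDLCUDA, 1} maps to "GPU:1"; unsupported device types yield
// absl::nullopt.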
absl::optional<std::string> DeviceNameFromDlContext(const DLDevice& ctx,
                                                    TF_Status* status) {
  switch (ctx.device_type) {
    case DLDeviceType::kDLCPU:
      return "CPU:0";
    case DLDeviceType::kDLCUDA:
      return absl::StrCat("GPU:", ctx.device_id);
    default:
      return absl::nullopt;
  }
}

// Converts a DLPack data type to TF_DataType.
Status TfDataTypeFromDlDataType(const DLDataType& dtype,
                                TF_DataType* tf_dtype) {
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
        case 8:
          *tf_dtype = TF_DataType::TF_UINT8;
          return OkStatus();
        case 16:
          *tf_dtype = TF_DataType::TF_UINT16;
          return OkStatus();
        case 32:
          *tf_dtype = TF_DataType::TF_UINT32;
          return OkStatus();
        case 64:
          *tf_dtype = TF_DataType::TF_UINT64;
          return OkStatus();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
                                                     dtype.bits);
      }
      return OkStatus();
    case DLDataTypeCode::kDLInt:
      switch (dtype.bits) {
        case 8:
          *tf_dtype = TF_DataType::TF_INT8;
          return OkStatus();
        case 16:
          *tf_dtype = TF_DataType::TF_INT16;
          return OkStatus();
        case 32:
          *tf_dtype = TF_DataType::TF_INT32;
          return OkStatus();
        case 64:
          *tf_dtype = TF_DataType::TF_INT64;
          return OkStatus();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
                                                     dtype.bits);
      }
      return OkStatus();
    case DLDataTypeCode::kDLFloat:
      switch (dtype.bits) {
        case 16:
          *tf_dtype = TF_DataType::TF_HALF;
          return OkStatus();
        case 32:
          *tf_dtype = TF_DataType::TF_FLOAT;
          return OkStatus();
        case 64:
          *tf_dtype = TF_DataType::TF_DOUBLE;
          return OkStatus();
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
                                                     dtype.bits);
      }
      break;
    case DLDataTypeCode::kDLBfloat:
      switch (dtype.bits) {
        case 16:
          *tf_dtype = TF_DataType::TF_BFLOAT16;
          return OkStatus();
        default:
          return tensorflow::errors::InvalidArgument(
              "Unsupported BFloat bits: ", dtype.bits);
      }
      break;
    case DLDataTypeCode::kDLComplex:
      switch (dtype.bits) {
        case 64:
          *tf_dtype = TF_DataType::TF_COMPLEX64;
          return OkStatus();
        case 128:
          *tf_dtype = TF_DataType::TF_COMPLEX128;
          return OkStatus();
        default:
          return tensorflow::errors::InvalidArgument(
              "Unsupported Complex bits: ", dtype.bits);
      }
      break;
    default:
      return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
                                                 dtype.code);
  }
}

// Wraps the deleter of a DLManagedTensor to match the deallocator signature
// expected by TFE_NewTensorHandleFromDeviceMemory; `data` and `len` are unused
// because the DLPack deleter owns the deallocation.
void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) {
  TFE_CallDLManagedTensorDeleter(dlmt_vptr);
}

// Checks whether the stride array matches the layout of compact, row-major
// data of the given shape.
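// For example, shape [2, 3, 4] requires strides [12, 4, 1]; column-major
// strides such as [1, 2, 6] are rejected.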
bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr,
                                      int ndim) {
  if (ndim >= 1 && stride_arr[ndim - 1] != 1) {
    return false;
  }
  for (int i = ndim - 2; i >= 0; --i) {
    if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) {
      return false;
    }
  }
  return true;
}
}  // namespace

void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
  DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr);
  if (dlMTensor->deleter != nullptr) {
    dlMTensor->deleter(dlMTensor);
  }
}

void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
  auto tf_dlm_context = GetDlContext(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  const Tensor* tensor = GetTensorFromHandle(h, status);
  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());

  auto tf_dlm_type = GetDlDataType(data_type, status);
  if (!status->status.ok()) {
    return nullptr;
  }

  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref()
  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
  tf_dlm_tensor_ctx->reference = tensor_ref;

  DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
  dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
  dlm_tensor->deleter = &DLManagedTensorDeleter;
  dlm_tensor->dl_tensor.device = tf_dlm_context;
  int ndim = tensor->dims();
  dlm_tensor->dl_tensor.ndim = ndim;
  dlm_tensor->dl_tensor.data = tf_dlm_data;
  dlm_tensor->dl_tensor.dtype = tf_dlm_type;

  std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
  std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
  shape_arr->resize(ndim);
  stride_arr->resize(ndim, 1);
  for (int i = 0; i < ndim; i++) {
    (*shape_arr)[i] = tensor->dim_size(i);
  }
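  // Fill in compact row-major strides: the innermost stride is 1 (set by the
  // resize above) and each outer stride is the product of the inner dims.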
  for (int i = ndim - 2; i >= 0; --i) {
    (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
  }

  dlm_tensor->dl_tensor.shape = shape_arr->data();
  // There are two ways to represent compact row-major data:
  // 1) a nullptr strides array, indicating that the tensor is compact and
  //    row-major, or
  // 2) a strides array filled in with the explicit values for compact
  //    row-major data.
  // We choose option 2 here, since some frameworks don't handle the nullptr
  // strides argument properly.
  dlm_tensor->dl_tensor.strides = stride_arr->data();

  // TF doesn't support a nonzero byte offset into the buffer.
  dlm_tensor->dl_tensor.byte_offset = 0;
  return static_cast<void*>(dlm_tensor);
}

TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status,
                                       TFE_Context* ctx) {
  DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm);
  DLTensor* dl_tensor = &dlmt->dl_tensor;
  absl::optional<std::string> device_name =
      DeviceNameFromDlContext(dl_tensor->device, status);
  if (!device_name.has_value()) {
    status->status =
        tensorflow::errors::InvalidArgument("Unsupported Device Type");
    return nullptr;
  }
  TF_DataType dtype;
  Status s = TfDataTypeFromDlDataType(dl_tensor->dtype, &dtype);
  if (!s.ok()) {
    status->status = std::move(s);
    return nullptr;
  }
  int num_dims = dl_tensor->ndim;
  const int64_t* dims = dl_tensor->shape;
  void* data = dl_tensor->data;

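  // Total size in bytes: element size (dtype.bits / 8) times element count.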
  size_t total_bytes = dl_tensor->dtype.bits / 8;
  for (int i = 0; i < num_dims; i++) {
    total_bytes *= dims[i];
  }

  if (dl_tensor->strides != nullptr &&
      !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides,
                                        num_dims)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid strides array from DLPack");
    return nullptr;
  }

  TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
      ctx, device_name.value().c_str(), dtype, dims, num_dims, data,
      total_bytes, &DeallocatorWrapperFunc, dlmt, status);

  return handle;
}
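
// A minimal round-trip sketch (illustrative only; assumes `ctx` is a valid
// TFE_Context and `handle` is a local TFE_TensorHandle created elsewhere):
//
//   TF_Status* status = TF_NewStatus();
//   void* dlm = TFE_HandleToDLPack(handle, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TFE_TensorHandle* imported = TFE_HandleFromDLPack(dlm, status, ctx);
//     // `imported` shares the underlying buffer; deleting it invokes the
//     // DLPack deleter through DeallocatorWrapperFunc.
//   }
//   TF_DeleteStatus(status);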

}  // namespace tensorflow