1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/eager/dlpack.h" |
17 | |
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "include/dlpack/dlpack.h"  // from @dlpack
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
29 | |
30 | namespace tensorflow { |
31 | |
32 | namespace { |
33 | |
// Owning context for a DLManagedTensor. It keeps the underlying TensorFlow
// buffer referenced for as long as a consumer holds the DLPack tensor; when
// DLManagedTensor::deleter is called, the buffer reference is released and
// this context is deleted along with it.
37 | struct TfDlManagedTensorCtx { |
38 | TensorReference reference; |
39 | std::vector<int64_t> shape; |
40 | std::vector<int64_t> strides; |
41 | DLManagedTensor tensor; |
42 | |
43 | explicit TfDlManagedTensorCtx(const TensorReference& ref) : reference(ref) {} |
44 | }; |
45 | |
// Gets the underlying Tensor from an eager tensor handle.
const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
50 | return nullptr; |
51 | } |
52 | tensorflow::TensorHandle* handle = |
53 | tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); |
54 | if (handle->Type() != TensorHandle::LOCAL) { |
    status->status = tensorflow::errors::InvalidArgument(
        "DLPack doesn't support ", handle->TypeString(), " tensor");
57 | return nullptr; |
58 | } |
59 | const tensorflow::Tensor* tensor; |
60 | status->status = handle->Tensor(&tensor); |
61 | if (!status->status.ok()) { |
62 | return nullptr; |
63 | } |
64 | return tensor; |
65 | } |
66 | |
// Deleter for DLManagedTensor: unrefs the TF buffer and frees the context.
68 | void DLManagedTensorDeleter(DLManagedTensor* arg) { |
69 | TfDlManagedTensorCtx* owner = |
70 | static_cast<TfDlManagedTensorCtx*>(arg->manager_ctx); |
71 | owner->reference.Unref(); |
72 | delete owner; |
73 | } |
74 | |
// Converts a TF_DataType to a DLPack data type.
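// For example, TF_FLOAT maps to {code = kDLFloat, bits = 32, lanes = 1} and
// TF_COMPLEX64 to {code = kDLComplex, bits = 64, lanes = 1}.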
76 | DLDataType GetDlDataType(TF_DataType data_type, TF_Status* status) { |
77 | DLDataType dtype; |
78 | dtype.lanes = 1; |
79 | dtype.bits = TF_DataTypeSize(data_type) * 8; |
80 | switch (data_type) { |
81 | case TF_DataType::TF_HALF: |
82 | case TF_DataType::TF_FLOAT: |
83 | case TF_DataType::TF_DOUBLE: |
84 | dtype.code = DLDataTypeCode::kDLFloat; |
85 | break; |
86 | case TF_DataType::TF_INT8: |
87 | case TF_DataType::TF_INT16: |
88 | case TF_DataType::TF_INT32: |
89 | case TF_DataType::TF_INT64: |
90 | dtype.code = DLDataTypeCode::kDLInt; |
91 | break; |
92 | case TF_DataType::TF_BOOL: |
93 | case TF_DataType::TF_UINT8: |
94 | case TF_DataType::TF_UINT16: |
95 | case TF_DataType::TF_UINT32: |
96 | case TF_DataType::TF_UINT64: |
97 | dtype.code = DLDataTypeCode::kDLUInt; |
98 | break; |
99 | case TF_DataType::TF_BFLOAT16: |
100 | dtype.code = DLDataTypeCode::kDLBfloat; |
101 | break; |
102 | case TF_DataType::TF_COMPLEX64: |
103 | case TF_DataType::TF_COMPLEX128: |
104 | dtype.code = DLDataTypeCode::kDLComplex; |
105 | break; |
106 | default: |
107 | status->status = tensorflow::errors::InvalidArgument( |
108 | DataType_Name(static_cast<DataType>(data_type)), |
109 | " is not supported by dlpack" ); |
110 | break; |
111 | } |
112 | return dtype; |
113 | } |
114 | |
115 | // Gets DLPack's DLDevice from eager tensor handle. |
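// For example, a handle backed by the device
// "/job:localhost/replica:0/task:0/device:GPU:0" yields
// {device_type = kDLCUDA, device_id = 0}.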
116 | DLDevice GetDlContext(TFE_TensorHandle* h, TF_Status* status) { |
117 | DLDevice ctx; |
  const char* device_name =
      tensorflow::unwrap(h)->BackingDeviceName(&status->status);
  if (!status->status.ok()) {
    return ctx;
  }
  DeviceNameUtils::ParsedName parsed_name;
  tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name);
122 | std::string device_type = parsed_name.type; |
123 | int device_id = 0; |
124 | if (parsed_name.has_id) { |
125 | device_id = parsed_name.id; |
126 | } |
127 | |
128 | ctx.device_id = device_id; |
129 | if (device_type == "CPU" ) { |
130 | ctx.device_type = DLDeviceType::kDLCPU; |
131 | } else if (device_type == "GPU" ) { |
132 | ctx.device_type = DLDeviceType::kDLCUDA; |
133 | } else { |
134 | status->status = tensorflow::errors::InvalidArgument( |
135 | "Unsupported Device Type for dlpack" ); |
136 | } |
137 | |
138 | return ctx; |
139 | } |
140 | |
141 | // Converts DLDevice to TF device name. |
142 | absl::optional<std::string> DeviceNameFromDlContext(const DLDevice& ctx, |
143 | TF_Status* status) { |
144 | switch (ctx.device_type) { |
145 | case DLDeviceType::kDLCPU: |
146 | return "CPU:0" ; |
147 | case DLDeviceType::kDLCUDA: |
148 | return absl::StrCat("GPU:" , ctx.device_id); |
149 | default: |
150 | return absl::nullopt; |
151 | } |
152 | } |
153 | |
// Converts a DLPack data type to a TF_DataType.
Status TfDataTypeFromDlDataType(const DLDataType& dtype,
                                TF_DataType* tf_dtype) {
157 | switch (dtype.code) { |
158 | case DLDataTypeCode::kDLUInt: |
159 | switch (dtype.bits) { |
160 | case 8: |
161 | *tf_dtype = TF_DataType::TF_UINT8; |
162 | return OkStatus(); |
163 | case 16: |
164 | *tf_dtype = TF_DataType::TF_UINT16; |
165 | return OkStatus(); |
166 | case 32: |
167 | *tf_dtype = TF_DataType::TF_UINT32; |
168 | return OkStatus(); |
169 | case 64: |
170 | *tf_dtype = TF_DataType::TF_UINT64; |
171 | return OkStatus(); |
        default:
          return tensorflow::errors::InvalidArgument("Unsupported UInt bits: ",
                                                     dtype.bits);
      }
177 | case DLDataTypeCode::kDLInt: |
178 | switch (dtype.bits) { |
179 | case 8: |
180 | *tf_dtype = TF_DataType::TF_INT8; |
181 | return OkStatus(); |
182 | case 16: |
183 | *tf_dtype = TF_DataType::TF_INT16; |
184 | return OkStatus(); |
185 | case 32: |
186 | *tf_dtype = TF_DataType::TF_INT32; |
187 | return OkStatus(); |
188 | case 64: |
189 | *tf_dtype = TF_DataType::TF_INT64; |
190 | return OkStatus(); |
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Int bits: ",
                                                     dtype.bits);
      }
196 | case DLDataTypeCode::kDLFloat: |
197 | switch (dtype.bits) { |
198 | case 16: |
199 | *tf_dtype = TF_DataType::TF_HALF; |
200 | return OkStatus(); |
201 | case 32: |
202 | *tf_dtype = TF_DataType::TF_FLOAT; |
203 | return OkStatus(); |
204 | case 64: |
205 | *tf_dtype = TF_DataType::TF_DOUBLE; |
206 | return OkStatus(); |
        default:
          return tensorflow::errors::InvalidArgument("Unsupported Float bits: ",
                                                     dtype.bits);
      }
212 | case DLDataTypeCode::kDLBfloat: |
213 | switch (dtype.bits) { |
214 | case 16: |
215 | *tf_dtype = TF_DataType::TF_BFLOAT16; |
216 | return OkStatus(); |
        default:
          return tensorflow::errors::InvalidArgument(
              "Unsupported BFloat bits: ", dtype.bits);
      }
222 | case DLDataTypeCode::kDLComplex: |
223 | switch (dtype.bits) { |
224 | case 64: |
225 | *tf_dtype = TF_DataType::TF_COMPLEX64; |
226 | return OkStatus(); |
227 | case 128: |
228 | *tf_dtype = TF_DataType::TF_COMPLEX128; |
229 | return OkStatus(); |
        default:
          return tensorflow::errors::InvalidArgument(
              "Unsupported Complex bits: ", dtype.bits);
      }
235 | default: |
      return tensorflow::errors::InvalidArgument("Unsupported Type Codes: ",
                                                 dtype.code);
238 | } |
239 | } |
240 | |
// Wraps the deleter of a DLManagedTensor to match the deallocator signature
// expected by TFE_NewTensorHandleFromDeviceMemory.
243 | void DeallocatorWrapperFunc(void* data, size_t len, void* dlmt_vptr) { |
244 | TFE_CallDLManagedTensorDeleter(dlmt_vptr); |
245 | } |
246 | |
// Checks whether the stride array matches the layout of compact, row-major
// data of the given shape.
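// For example, shape {2, 3, 4} is compact row-major iff the strides are
// {12, 4, 1}. Note that DLPack permits an arbitrary stride on a dimension of
// size 1, so this check can reject some tensors that are in fact compact.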
249 | bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, |
250 | int ndim) { |
251 | if (ndim >= 1 && stride_arr[ndim - 1] != 1) { |
252 | return false; |
253 | } |
254 | for (int i = ndim - 2; i >= 0; --i) { |
255 | if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) { |
256 | return false; |
257 | } |
258 | } |
259 | return true; |
260 | } |
261 | } // namespace |
262 | |
263 | void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { |
264 | DLManagedTensor* dlMTensor = static_cast<DLManagedTensor*>(dlm_ptr); |
265 | if (dlMTensor->deleter != nullptr) { |
266 | dlMTensor->deleter(dlMTensor); |
267 | } |
268 | } |
269 | |
270 | void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) { |
271 | auto tf_dlm_context = GetDlContext(h, status); |
272 | if (!status->status.ok()) { |
273 | return nullptr; |
274 | } |
275 | |
276 | auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status); |
277 | if (!status->status.ok()) { |
278 | return nullptr; |
279 | } |
280 | |
  const Tensor* tensor = GetTensorFromHandle(h, status);
  if (!status->status.ok()) {
    return nullptr;
  }
  TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
283 | |
284 | auto tf_dlm_type = GetDlDataType(data_type, status); |
285 | if (!status->status.ok()) { |
286 | return nullptr; |
287 | } |
288 | |
  TensorReference tensor_ref(*tensor);  // This will call buf_->Ref().
  auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
292 | |
293 | DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor; |
294 | dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; |
295 | dlm_tensor->deleter = &DLManagedTensorDeleter; |
296 | dlm_tensor->dl_tensor.device = tf_dlm_context; |
297 | int ndim = tensor->dims(); |
298 | dlm_tensor->dl_tensor.ndim = ndim; |
299 | dlm_tensor->dl_tensor.data = tf_dlm_data; |
300 | dlm_tensor->dl_tensor.dtype = tf_dlm_type; |
301 | |
302 | std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape; |
303 | std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides; |
304 | shape_arr->resize(ndim); |
305 | stride_arr->resize(ndim, 1); |
306 | for (int i = 0; i < ndim; i++) { |
307 | (*shape_arr)[i] = tensor->dim_size(i); |
308 | } |
309 | for (int i = ndim - 2; i >= 0; --i) { |
310 | (*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1]; |
311 | } |
312 | |
313 | dlm_tensor->dl_tensor.shape = shape_arr->data(); |
  // There are two ways to represent compact row-major data in DLPack:
  // 1) strides == nullptr, which by convention means compact and row-major, or
  // 2) an explicitly filled-in strides array.
  // We choose option 2 because some consumer frameworks do not handle a null
  // strides argument properly.
319 | dlm_tensor->dl_tensor.strides = stride_arr->data(); |
320 | |
  // TF tensors are dense, so the data pointer already addresses the first
  // element and no byte offset is needed.
  dlm_tensor->dl_tensor.byte_offset = 0;
323 | return static_cast<void*>(dlm_tensor); |
324 | } |
325 | |
326 | TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status, |
327 | TFE_Context* ctx) { |
328 | DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(dlm); |
329 | DLTensor* dl_tensor = &dlmt->dl_tensor; |
330 | absl::optional<std::string> device_name = |
331 | DeviceNameFromDlContext(dl_tensor->device, status); |
332 | if (!device_name.has_value()) { |
    status->status =
        tensorflow::errors::InvalidArgument("Unsupported Device Type");
335 | return nullptr; |
336 | } |
337 | TF_DataType dtype; |
  Status s = TfDataTypeFromDlDataType(dl_tensor->dtype, &dtype);
339 | if (!s.ok()) { |
340 | status->status = std::move(s); |
341 | return nullptr; |
342 | } |
343 | int num_dims = dl_tensor->ndim; |
344 | const int64_t* dims = dl_tensor->shape; |
345 | void* data = dl_tensor->data; |
346 | |
347 | size_t total_bytes = dl_tensor->dtype.bits / 8; |
348 | for (int i = 0; i < num_dims; i++) { |
349 | total_bytes *= dims[i]; |
350 | } |
351 | |
352 | if (dl_tensor->strides != nullptr && |
353 | !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides, |
354 | num_dims)) { |
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid strides array from DLPack");
357 | return nullptr; |
358 | } |
359 | |
360 | TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory( |
361 | ctx, device_name.value().c_str(), dtype, dims, num_dims, data, |
362 | total_bytes, &DeallocatorWrapperFunc, dlmt, status); |
363 | |
364 | return handle; |
365 | } |
366 | |
367 | } // namespace tensorflow |
368 | |