1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/kernels_experimental.h" |
17 | |
18 | #include <algorithm> |
19 | #include <string> |
#include <utility>
#include <vector>

#include "absl/memory/memory.h"

22 | #include "tensorflow/c/tf_status_helper.h" |
23 | #include "tensorflow/c/tf_status_internal.h" |
24 | #include "tensorflow/c/tf_tensor_internal.h" |
25 | #include "tensorflow/core/framework/ref_var.h" |
26 | #include "tensorflow/core/framework/resource_mgr.h" |
27 | #include "tensorflow/core/framework/resource_var.h" |
28 | #include "tensorflow/core/framework/variant.h" |
29 | |
30 | #ifndef IS_MOBILE_PLATFORM |
31 | #include "tensorflow/core/kernels/data/optional_ops_util.h" |
32 | #include "tensorflow/core/kernels/tensor_list.h" |
33 | #include "tensorflow/core/kernels/tensor_list_util.h" |
34 | #include "tensorflow/core/kernels/variant_ops_util.h" |
35 | #include "tensorflow/core/platform/abi.h" |
36 | #endif // IS_MOBILE_PLATFORM |
37 | |
38 | #include "tensorflow/core/platform/errors.h" |
39 | #include "tensorflow/core/platform/mutex.h" |
40 | #include "tensorflow/core/platform/refcount.h" |
41 | |
42 | using tensorflow::AllocatorAttributes; |
43 | using tensorflow::mutex_lock; |
44 | using tensorflow::Status; |
45 | using tensorflow::Tensor; |
46 | using tensorflow::TF_TensorFromTensor; |
47 | using tensorflow::Var; |
48 | using tensorflow::Variant; |
49 | using tensorflow::errors::InvalidArgument; |
50 | |
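// Holds everything acquired by TF_MaybeLockVariableInputMutexesInOrder: the
// looked-up variables (each carrying a reference that is released in
// TF_ReleaseVariableInputLockHolder) and the exclusive or shared locks taken
// on their mutexes.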
51 | struct TF_VariableInputLockHolder { |
52 | TF_VariableInputLockHolder( |
53 | std::vector<tensorflow::Var*> vars, |
54 | std::unique_ptr<std::vector<tensorflow::mutex_lock>> locks, |
55 | std::unique_ptr<std::vector<tensorflow::tf_shared_lock>> shared_locks) |
56 | : vars(std::move(vars)), |
57 | locks(std::move(locks)), |
58 | shared_locks(std::move(shared_locks)) {} |
59 | |
60 | std::vector<tensorflow::Var*> vars; |
61 | std::unique_ptr<std::vector<tensorflow::mutex_lock>> locks; |
62 | std::unique_ptr<std::vector<tensorflow::tf_shared_lock>> shared_locks; |
63 | }; |
64 | |
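// Switches `var` into copy-on-read mode so that sparse (per-element) accesses
// are safe without holding the variable's lock across the whole operation. If
// outstanding reads still share the tensor's buffer, the buffer is duplicated
// first: element by element for variants, via the device `copyFunc` otherwise.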
65 | tensorflow::Status EnsureSparseVariableAccess( |
66 | TF_OpKernelContext* ctx, bool variantType, |
67 | void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, |
68 | TF_Tensor* dest), |
69 | tensorflow::Var* var) { |
70 | auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
71 | if (var->copy_on_read_mode.load()) { |
72 | return ::tensorflow::OkStatus(); |
73 | } |
74 | mutex_lock ml(*var->mu()); |
  // Once copy-on-read mode is true, the refcount is guaranteed to be 1. The
  // refcount can also be 1 if there are no concurrent reads of the variable
  // while copy-on-read mode is false.
78 | if (var->tensor()->RefCountIsOne()) { |
79 | var->copy_on_read_mode.store(true); |
80 | return ::tensorflow::OkStatus(); |
81 | } |
82 | Tensor tmp; |
83 | if (variantType) { |
84 | AllocatorAttributes attr; |
85 | attr.set_on_host(true); |
86 | TF_RETURN_IF_ERROR(context->allocate_temp( |
87 | var->tensor()->dtype(), var->tensor()->shape(), &tmp, attr)); |
88 | |
89 | const auto elements_in = var->tensor()->flat<Variant>(); |
90 | auto elements_out = tmp.flat<Variant>(); |
91 | for (int64_t i = 0; i < elements_in.size(); ++i) { |
92 | elements_out(i) = elements_in(i); |
93 | } |
94 | } else { |
95 | AllocatorAttributes attr; |
96 | attr.set_gpu_compatible(true); |
97 | attr.set_nic_compatible(true); |
98 | TF_RETURN_IF_ERROR(context->allocate_temp( |
99 | var->tensor()->dtype(), var->tensor()->shape(), &tmp, attr)); |
    tensorflow::Status s;
    TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
    TF_RETURN_IF_ERROR(s);
    TF_Tensor* tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
    if (!s.ok()) {
      TF_DeleteTensor(tf_tmp);
      return s;
    }
    copyFunc(ctx, tf_tensor, tf_tmp);
    // The wrappers only borrow buffers that stay alive through `tmp` and the
    // variable's tensor, so release them here (assuming `copyFunc` does not
    // retain them past the call).
    TF_DeleteTensor(tf_tensor);
    TF_DeleteTensor(tf_tmp);
104 | } |
105 | *var->tensor() = tmp; |
106 | var->copy_on_read_mode.store(true); |
107 | return ::tensorflow::OkStatus(); |
108 | } |
109 | |
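// Makes `tensor` safe to update in place: if the variable is in copy-on-read
// mode, or an outstanding read still shares the underlying buffer, the buffer
// is copied first (copy-on-write semantics).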
110 | tensorflow::Status PrepareToUpdateVariable( |
111 | TF_OpKernelContext* ctx, tensorflow::Tensor* tensor, bool copy_on_read_mode, |
112 | bool variantType, |
113 | void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, |
114 | TF_Tensor* dest)) { |
115 | auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
116 | if (copy_on_read_mode || !tensor->RefCountIsOne()) { |
117 | // Tensor's buffer is in use by some read, so we need to copy before |
118 | // updating. |
119 | Tensor tmp; |
120 | if (variantType) { |
121 | AllocatorAttributes attr; |
122 | attr.set_on_host(true); |
123 | TF_RETURN_IF_ERROR( |
124 | context->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr)); |
125 | |
126 | const auto elements_in = tensor->flat<Variant>(); |
127 | auto elements_out = tmp.flat<Variant>(); |
128 | for (int64_t i = 0; i < elements_in.size(); ++i) { |
129 | elements_out(i) = elements_in(i); |
130 | } |
131 | } else { |
132 | AllocatorAttributes attr; |
133 | attr.set_gpu_compatible(true); |
134 | attr.set_nic_compatible(true); |
135 | TF_RETURN_IF_ERROR( |
136 | context->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr)); |
      tensorflow::Status s;
      TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
      TF_RETURN_IF_ERROR(s);
      TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
      if (!s.ok()) {
        TF_DeleteTensor(tf_tmp);
        return s;
      }
      copyFunc(ctx, tf_tensor, tf_tmp);
      // Release the borrowed wrappers; `tmp` still owns the copied buffer.
      TF_DeleteTensor(tf_tensor);
      TF_DeleteTensor(tf_tmp);
141 | } |
142 | *tensor = tmp; |
143 | } |
144 | return ::tensorflow::OkStatus(); |
145 | } |
146 | |
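// Returns the mutex guarding variable input `input`. For a DT_RESOURCE input
// this looks up the variable (the caller takes ownership of the reference
// returned through `*maybe_resource`) and, if `sparse` is set, forces it into
// copy-on-read mode; for a legacy ref input it returns the context's ref
// mutex. Returns nullptr if the resource lookup fails.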
147 | tensorflow::mutex* GetTrainingVariableMutex( |
148 | TF_OpKernelContext* ctx, int32_t input, bool sparse, |
149 | void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, |
150 | TF_Tensor* dest), |
151 | tensorflow::Var** maybe_resource) { |
152 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
153 | *maybe_resource = nullptr; |
154 | if (cc_ctx->input_dtype(input) == tensorflow::DT_RESOURCE) { |
155 | if (LookupResource(cc_ctx, HandleFromInput(cc_ctx, input), maybe_resource) |
156 | .ok()) { |
157 | if (sparse) { |
158 | TF_CHECK_OK( |
159 | EnsureSparseVariableAccess(ctx, false, copyFunc, *maybe_resource)); |
160 | } |
161 | return (*maybe_resource)->mu(); |
162 | } else { |
163 | cc_ctx->CtxFailureWithWarning( |
164 | tensorflow::errors::Internal("Invalid variable reference." )); |
165 | return nullptr; |
166 | } |
167 | } |
168 | return cc_ctx->input_ref_mutex(input); |
169 | } |
170 | |
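// Implements AssignVariableOp-style assignment for plugin devices: looks up
// (or creates) the resource variable behind `input_index` and assigns the
// value tensor to it, going through `copyFunc` when the variable is in
// copy-on-read mode and its buffer may still be observed by readers.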
171 | void TF_AssignVariable(TF_OpKernelContext* ctx, int input_index, |
172 | int value_index, bool validate_shape, |
173 | void (*copyFunc)(TF_OpKernelContext* ctx, |
174 | TF_Tensor* source, TF_Tensor* dest), |
175 | TF_Status* status) { |
176 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
177 | tensorflow::core::RefCountPtr<tensorflow::Var> variable; |
178 | const tensorflow::Tensor& value = cc_ctx->input(value_index); |
179 | OP_REQUIRES_OK(cc_ctx, tensorflow::LookupOrCreateResource<tensorflow::Var>( |
180 | cc_ctx, HandleFromInput(cc_ctx, input_index), |
181 | &variable, [&value](tensorflow::Var** ptr) { |
182 | *ptr = new tensorflow::Var(value.dtype()); |
183 | *(*ptr)->tensor() = value; |
184 | (*ptr)->is_initialized = true; |
185 | return ::tensorflow::OkStatus(); |
186 | })); |
187 | tensorflow::mutex_lock ml(*variable->mu()); |
188 | |
189 | if (validate_shape) { |
190 | OP_REQUIRES(cc_ctx, |
191 | (!variable->is_initialized || |
192 | variable->tensor()->shape().IsSameSize(value.shape())), |
193 | InvalidArgument( |
194 | "Trying to assign to variable with tensor with wrong shape." |
195 | " Expected " , |
196 | variable->tensor()->shape().DebugString(), " got " , |
197 | value.shape().DebugString())); |
198 | } |
199 | |
200 | if (variable->copy_on_read_mode.load()) { |
201 | tensorflow::Tensor tmp; |
202 | tensorflow::AllocatorAttributes attr; |
203 | attr.set_gpu_compatible(true); |
204 | attr.set_nic_compatible(true); |
205 | OP_REQUIRES_OK(cc_ctx, cc_ctx->allocate_temp(value.dtype(), value.shape(), |
206 | &tmp, attr)); |
    tensorflow::Status s;
    TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
    OP_REQUIRES_OK(cc_ctx, s);
    TF_Tensor* tf_value = TF_TensorFromTensor(value, &s);
    if (!s.ok()) {
      TF_DeleteTensor(tf_tmp);
      OP_REQUIRES_OK(cc_ctx, s);
    }
    copyFunc(ctx, tf_value, tf_tmp);
    // Release the borrowed wrappers; `tmp` still owns the copied buffer.
    TF_DeleteTensor(tf_value);
    TF_DeleteTensor(tf_tmp);
    *variable->tensor() = tmp;
212 | } else { |
213 | *variable->tensor() = value; |
214 | } |
215 | variable->is_initialized = true; |
216 | TF_SetStatus(status, TF_OK, "" ); |
217 | } |
218 | |
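// Implements Assign-style assignment for legacy ref variables, delegating
// locking and shape validation to tensorflow::AssignRefVariable and the
// actual device copy to `copyFunc`.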
219 | void TF_AssignRefVariable(TF_OpKernelContext* ctx, int input_ref_index, |
220 | int output_ref_index, int value_index, |
221 | bool use_locking, bool validate_shape, |
222 | void (*copyFunc)(TF_OpKernelContext* ctx, |
223 | TF_Tensor* source, TF_Tensor* dest), |
224 | TF_Status* status) { |
225 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
226 | |
227 | auto copy = [copyFunc, ctx](::tensorflow::OpKernelContext* cc_ctx, |
228 | ::tensorflow::Tensor* lhs, |
229 | const ::tensorflow::Tensor& rhs) { |
230 | ::tensorflow::Status s; |
231 | TF_Tensor* tf_lhs = TF_TensorFromTensor(*lhs, &s); |
232 | OP_REQUIRES_OK(cc_ctx, s); |
233 | |
234 | TF_Tensor* tf_rhs = TF_TensorFromTensor(rhs, &s); |
235 | |
236 | if (!s.ok()) { |
237 | TF_DeleteTensor(tf_lhs); |
238 | OP_REQUIRES_OK(cc_ctx, s); |
239 | } |
240 | |
    copyFunc(ctx, tf_rhs, tf_lhs);
    // The wrappers borrow buffers owned by `lhs` and `rhs`; release them.
    TF_DeleteTensor(tf_lhs);
    TF_DeleteTensor(tf_rhs);
242 | }; |
243 | |
244 | ::tensorflow::AssignRefVariable(cc_ctx, input_ref_index, output_ref_index, |
245 | value_index, use_locking, validate_shape, |
246 | false, copy); |
  TF_SetStatus(status, TF_OK, "");
248 | } |
249 | |
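// Implements AssignAdd/AssignSub-style updates: looks up the resource
// variable, checks that the value has a matching shape, copies the buffer if
// readers still share it, and applies `updateFunc` (dispatched on `Op`) in
// place.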
250 | void TF_AssignUpdateVariable(TF_OpKernelContext* ctx, int input_index, |
251 | int value_index, int Op, int isVariantType, |
252 | void (*copyFunc)(TF_OpKernelContext* ctx, |
253 | TF_Tensor* source, |
254 | TF_Tensor* dest), |
255 | void (*updateFunc)(TF_OpKernelContext* ctx, |
256 | TF_Tensor* tensor, |
257 | TF_Tensor* value, int Op), |
258 | TF_Status* tf_status) { |
259 | auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
260 | tensorflow::core::RefCountPtr<Var> variable; |
261 | Status status = |
262 | LookupResource(context, HandleFromInput(context, input_index), &variable); |
  if (!status.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(tf_status, status);
    return;
  }
267 | const Tensor& value = context->input(value_index); |
268 | mutex_lock ml(*variable->mu()); |
269 | Tensor* var_tensor = variable->tensor(); |
270 | OP_REQUIRES( |
271 | context, var_tensor->shape().IsSameSize(value.shape()), |
272 | InvalidArgument("Cannot update variable with shape " , |
273 | var_tensor->shape().DebugString(), |
274 | " using a Tensor with shape " , |
275 | value.shape().DebugString(), ", shapes must be equal." )); |
276 | OP_REQUIRES_OK(context, |
277 | PrepareToUpdateVariable(ctx, var_tensor, |
278 | variable->copy_on_read_mode.load(), |
279 | isVariantType, copyFunc)); |
  tensorflow::Status s;
  TF_Tensor* tf_var_tensor = TF_TensorFromTensor(*var_tensor, &s);
  OP_REQUIRES_OK(context, s);
  TF_Tensor* tf_value = TF_TensorFromTensor(value, &s);
  if (!s.ok()) {
    TF_DeleteTensor(tf_var_tensor);
    OP_REQUIRES_OK(context, s);
  }
  updateFunc(ctx, tf_var_tensor, tf_value, Op);
  // The update is applied to the variable's own buffer; drop the wrappers.
  TF_DeleteTensor(tf_var_tensor);
  TF_DeleteTensor(tf_value);
  TF_SetStatus(tf_status, TF_OK, "");
285 | } |
286 | |
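// Acquires the mutexes of all variable inputs listed in `inputs`, sorted by
// mutex address so that concurrent kernels locking overlapping variable sets
// cannot deadlock. Exclusive locks are taken when `do_lock` is true; shared
// locks are taken when `do_lock` is false but at least one input is a
// resource variable. The holder returned through `lockHolder` must be freed
// with TF_ReleaseVariableInputLockHolder.
//
// A minimal usage sketch from a kernel's compute function; the input index
// and the MyDeviceCopy callback are illustrative, not part of this API:
//
//   TF_VariableInputLockHolder* holder = nullptr;
//   const int var_inputs[] = {0};
//   TF_MaybeLockVariableInputMutexesInOrder(ctx, /*do_lock=*/true,
//                                           /*sparse=*/false, var_inputs, 1,
//                                           MyDeviceCopy, &holder, status);
//   // ... read and update the variable tensors ...
//   TF_ReleaseVariableInputLockHolder(holder);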
287 | void TF_MaybeLockVariableInputMutexesInOrder( |
288 | TF_OpKernelContext* ctx, bool do_lock, bool sparse, const int* const inputs, |
289 | size_t len, |
290 | void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, |
291 | TF_Tensor* dest), |
292 | TF_VariableInputLockHolder** lockHolder, TF_Status* status) { |
293 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
294 | bool any_resource = false; |
295 | std::vector<int> input_ids(inputs, inputs + len); |
296 | for (auto i : input_ids) { |
297 | if (cc_ctx->input_dtype(i) == tensorflow::DT_RESOURCE) { |
298 | any_resource = true; |
299 | break; |
300 | } |
301 | } |
302 | if (!do_lock && !any_resource) { |
303 | *lockHolder = new TF_VariableInputLockHolder({}, {}, {}); |
    TF_SetStatus(status, TF_OK, "");
305 | return; |
306 | } |
307 | std::vector<tensorflow::Var*> vars; |
308 | std::vector<tensorflow::mutex*> mutexes; |
309 | std::vector<int32_t> acquire_order; |
310 | for (auto input : input_ids) { |
311 | tensorflow::Var* var; |
312 | tensorflow::mutex* mutex = |
313 | GetTrainingVariableMutex(ctx, input, sparse, copyFunc, &var); |
314 | if (var) vars.push_back(var); |
315 | // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3). |
316 | if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) { |
317 | acquire_order.push_back(mutexes.size()); |
318 | mutexes.push_back(mutex); |
319 | } |
320 | } |
321 | std::sort(acquire_order.begin(), acquire_order.end(), |
322 | [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); |
323 | |
324 | auto locks = absl::make_unique<std::vector<tensorflow::mutex_lock>>(); |
325 | auto shared_locks = |
326 | absl::make_unique<std::vector<tensorflow::tf_shared_lock>>(); |
327 | locks->reserve(acquire_order.size()); |
328 | |
  // `acquire_order` holds indices into `mutexes`, now sorted by mutex address,
  // so look the mutexes up directly instead of re-resolving the inputs.
  for (auto acquire : acquire_order) {
    tensorflow::mutex* mu = mutexes[acquire];
    if (mu != nullptr) {
      if (do_lock) {
        locks->emplace_back(*mu);
      } else {
        shared_locks->emplace_back(*mu);
      }
    }
  }
342 | *lockHolder = new TF_VariableInputLockHolder( |
343 | std::move(vars), std::move(locks), std::move(shared_locks)); |
  TF_SetStatus(status, TF_OK, "");
345 | } |
346 | |
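// Returns the tensor behind variable input `input` as a mutable TF_Tensor
// owned by the caller, handling resource variables (copy-on-read for sparse
// access, copy-on-write otherwise) as well as legacy ref inputs.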
347 | void TF_GetInputTensorFromVariable(TF_OpKernelContext* ctx, int input, |
348 | bool lock_held, bool isVariantType, |
349 | bool sparse, |
350 | void (*copyFunc)(TF_OpKernelContext* ctx, |
351 | TF_Tensor* source, |
352 | TF_Tensor* dest), |
353 | TF_Tensor** out, TF_Status* status) { |
354 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
355 | tensorflow::Status s; |
356 | if (cc_ctx->input_dtype(input) == tensorflow::DT_RESOURCE) { |
357 | tensorflow::core::RefCountPtr<tensorflow::Var> var; |
358 | OP_REQUIRES_OK( |
359 | cc_ctx, LookupResource(cc_ctx, HandleFromInput(cc_ctx, input), &var)); |
360 | if (sparse) { |
361 | OP_REQUIRES_OK(cc_ctx, EnsureSparseVariableAccess(ctx, isVariantType, |
362 | copyFunc, var.get())); |
363 | *out = ::tensorflow::TF_TensorFromTensor(*var->tensor(), &s); |
364 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
365 | return; |
366 | } |
367 | OP_REQUIRES_OK(cc_ctx, PrepareToUpdateVariable( |
368 | ctx, var->tensor(), |
369 | var->copy_on_read_mode.load(), false, copyFunc)); |
370 | *out = ::tensorflow::TF_TensorFromTensor(*var->tensor(), &s); |
371 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
372 | return; |
373 | } |
374 | *out = ::tensorflow::TF_TensorFromTensor( |
375 | cc_ctx->mutable_input(input, lock_held), &s); |
376 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
377 | } |
378 | |
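// Forwards a legacy ref input to the given ref output; a no-op for
// DT_RESOURCE inputs, which are not forwarded this way.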
379 | void TF_OpKernelContext_ForwardRefInputToRefOutput(TF_OpKernelContext* ctx, |
380 | int32_t input_index, |
381 | int32_t output_index) { |
382 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
383 | if (cc_ctx->input_dtype(input_index) != tensorflow::DT_RESOURCE) { |
384 | cc_ctx->forward_ref_input_to_ref_output(input_index, output_index); |
385 | } |
386 | } |
387 | |
388 | void TF_ReleaseVariableInputLockHolder(TF_VariableInputLockHolder* lockHolder) { |
389 | if (lockHolder != nullptr) { |
    // Release every lock before dropping the variable references; otherwise a
    // Var could be destroyed while its mutex is still held.
    lockHolder->locks.reset();
    lockHolder->shared_locks.reset();
391 | for (tensorflow::Var* var : lockHolder->vars) { |
392 | var->Unref(); |
393 | } |
394 | delete lockHolder; |
395 | } |
396 | } |
397 | |
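// Looks up an op input by name and returns it as a TF_Tensor owned by the
// caller.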
398 | void TF_GetInputByName(TF_OpKernelContext* ctx, const char* inputName, |
399 | TF_Tensor** tensor, TF_Status* status) { |
400 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
401 | const ::tensorflow::Tensor* cc_tensor = nullptr; |
402 | tensorflow::Status s = cc_ctx->input(inputName, &cc_tensor); |
403 | |
404 | if (!s.ok()) { |
405 | ::tensorflow::Set_TF_Status_from_Status(status, s); |
406 | return; |
407 | } |
408 | TF_Tensor* result = |
409 | ::tensorflow::TF_TensorFromTensor(*cc_tensor, &status->status); |
410 | if (TF_GetCode(status) == TF_OK) { |
411 | *tensor = result; |
412 | } |
413 | } |
414 | |
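// Reads a TensorShape attribute into the caller-provided `dims` buffer,
// failing if the attribute's rank does not equal `num_dims`.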
415 | void TF_OpKernelConstruction_GetAttrTensorShape(TF_OpKernelConstruction* ctx, |
416 | const char* attr_name, |
417 | int64_t* dims, size_t num_dims, |
418 | TF_Status* status) { |
419 | ::tensorflow::TensorShape shape; |
420 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); |
  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &shape);
  ::tensorflow::Set_TF_Status_from_Status(status, s);
  if (!status->status.ok()) return;

  size_t rank = static_cast<size_t>(shape.dims());
  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }

  for (size_t i = 0; i < rank; ++i) {
    dims[i] = static_cast<int64_t>(shape.dim_size(i));
  }
436 | } |
437 | |
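// Returns whether input `i` is a legacy ref-typed input.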
438 | bool TF_IsRefInput(TF_OpKernelContext* ctx, int i, TF_Status* status) { |
439 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
440 | if (i < 0 || i >= cc_ctx->num_inputs()) { |
    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
442 | return false; |
443 | } |
  TF_SetStatus(status, TF_OK, "");
445 | return cc_ctx->input_is_ref(i); |
446 | } |
447 | |
448 | #ifndef IS_MOBILE_PLATFORM |
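// Checks that `variant` actually holds a T; otherwise returns an Internal
// error naming the type it does hold.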
449 | template <typename T> |
450 | static Status ValidateVariantType(const Variant& variant) { |
451 | if (variant.get<T>() == nullptr) { |
452 | const std::string type_index_name = |
453 | ::tensorflow::port::MaybeAbiDemangle(variant.TypeId().name()); |
454 | |
    return ::tensorflow::errors::Internal(
        "VariantBinaryOpFn: Could not access object 'a', type_index: ",
        type_index_name);
458 | } |
459 | |
460 | return ::tensorflow::OkStatus(); |
461 | } |
462 | |
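// Implements AddN for DT_VARIANT inputs (TensorList and Optional) on plugin
// devices. The per-tensor additions inside the variants are delegated to
// `binary_add_func`; a DT_INVALID operand acts as an additive identity so
// that gradient aggregation can skip missing values.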
463 | void TF_AddNVariant(TF_OpKernelContext* ctx, |
464 | void (*binary_add_func)(TF_OpKernelContext* ctx, |
465 | TF_Tensor* a, TF_Tensor* b, |
466 | TF_Tensor* out), |
467 | TF_Status* status) { |
468 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
469 | |
470 | auto cc_binary_add_func = [binary_add_func]( |
471 | ::tensorflow::OpKernelContext* cc_ctx, |
472 | const Tensor& cc_a, const Tensor& cc_b, |
473 | Tensor* cc_out) { |
474 | if (cc_a.dtype() == ::tensorflow::DT_INVALID) { |
475 | *cc_out = cc_b; |
476 | return ::tensorflow::OkStatus(); |
477 | } |
478 | if (cc_b.dtype() == ::tensorflow::DT_INVALID) { |
479 | *cc_out = cc_a; |
480 | return ::tensorflow::OkStatus(); |
481 | } |
482 | |
483 | Status status; |
484 | TF_Tensor* a = TF_TensorFromTensor(cc_a, &status); |
485 | TF_RETURN_IF_ERROR(status); |
486 | |
487 | TF_Tensor* b = TF_TensorFromTensor(cc_b, &status); |
488 | if (!status.ok()) { |
489 | TF_DeleteTensor(a); |
490 | return status; |
491 | } |
492 | |
493 | ::tensorflow::AllocatorAttributes attr; |
494 | if (cc_a.dtype() == ::tensorflow::DT_VARIANT) { |
495 | attr.set_on_host(true); |
496 | } |
497 | |
498 | status = cc_ctx->allocate_temp(cc_a.dtype(), cc_a.shape(), cc_out, attr); |
499 | if (!status.ok()) { |
500 | TF_DeleteTensor(a); |
501 | TF_DeleteTensor(b); |
502 | return status; |
503 | } |
504 | |
505 | TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status); |
506 | if (!status.ok()) { |
507 | TF_DeleteTensor(a); |
508 | TF_DeleteTensor(b); |
509 | return status; |
510 | } |
511 | |
    auto* ctx = reinterpret_cast<TF_OpKernelContext*>(cc_ctx);
    binary_add_func(ctx, a, b, out);
    // The sum is written into `cc_out`'s buffer; drop the borrowed wrappers.
    TF_DeleteTensor(a);
    TF_DeleteTensor(b);
    TF_DeleteTensor(out);
    return cc_ctx->status();
515 | }; |
516 | |
517 | auto binary_add_variant = [cc_binary_add_func]( |
518 | ::tensorflow::OpKernelContext* cc_ctx, |
519 | const Variant& a, const Variant& b, |
520 | Variant* out) { |
521 | if (out == nullptr) { |
522 | return ::tensorflow::errors::Internal( |
523 | "The output variant hasn't been initialized" ); |
524 | } |
525 | |
526 | if (a.TypeId() != b.TypeId()) { |
527 | return ::tensorflow::errors::Internal( |
528 | "BinaryOpVariants: Variants a and b have different " |
529 | "type ids. Type names: '" , |
530 | a.TypeName(), "' vs. '" , b.TypeName(), "'" ); |
531 | } |
532 | |
533 | if (a.TypeId() == tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) { |
534 | TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(a)); |
535 | *out = ::tensorflow::TensorList(); |
536 | |
537 | return ::tensorflow::TensorListBinaryAdd( |
538 | cc_ctx, *a.get<::tensorflow::TensorList>(), |
539 | *b.get<::tensorflow::TensorList>(), |
540 | out->get<::tensorflow::TensorList>(), cc_binary_add_func); |
541 | } else if (a.TypeId() == tensorflow::TypeIndex::Make< |
542 | ::tensorflow::data::OptionalVariant>()) { |
543 | TF_RETURN_IF_ERROR( |
544 | ValidateVariantType<::tensorflow::data::OptionalVariant>(a)); |
545 | *out = ::tensorflow::data::OptionalVariant(); |
546 | |
547 | return ::tensorflow::data::OptionalBinaryAdd( |
548 | cc_ctx, *a.get<::tensorflow::data::OptionalVariant>(), |
549 | *b.get<::tensorflow::data::OptionalVariant>(), |
550 | out->get<::tensorflow::data::OptionalVariant>(), cc_binary_add_func); |
551 | } |
552 | |
553 | const std::string type_index_name = |
554 | ::tensorflow::port::MaybeAbiDemangle(a.TypeId().name()); |
555 | |
556 | return ::tensorflow::errors::Internal( |
557 | "No unary variant binary_op function found for op ADD Variant " |
558 | "type_name: " , |
559 | type_index_name, " for device type: " , cc_ctx->device()->name()); |
560 | }; |
561 | ::tensorflow::AddNVariant(cc_ctx, binary_add_variant); |
562 | ::tensorflow::Set_TF_Status_from_Status(status, cc_ctx->status()); |
563 | } |
564 | |
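// Produces a zeros-like value for a variant (TensorList or Optional),
// delegating per-tensor zeroing to `zeros_like_func` and recursing through
// nested variants.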
565 | static Status ZerosLikeVariant(::tensorflow::OpKernelContext* cc_ctx, |
566 | const Variant& input, Variant* out, |
567 | void (*zeros_like_func)(TF_OpKernelContext* ctx, |
568 | TF_Tensor* input, |
569 | TF_Tensor* out)) { |
570 | auto cc_zeros_like_func = [zeros_like_func]( |
571 | ::tensorflow::OpKernelContext* cc_ctx, |
572 | const Tensor& cc_input, Tensor* cc_out) { |
573 | AllocatorAttributes attr; |
574 | if (cc_input.dtype() == ::tensorflow::DT_VARIANT) { |
575 | attr.set_on_host(true); |
576 | } |
577 | TF_RETURN_IF_ERROR(cc_ctx->allocate_temp(cc_input.dtype(), cc_input.shape(), |
578 | cc_out, attr)); |
579 | |
580 | switch (cc_input.dtype()) { |
581 | case ::tensorflow::DT_INVALID: { |
582 | *cc_out = Tensor(::tensorflow::DT_INVALID); |
583 | break; |
584 | } |
585 | case ::tensorflow::DT_VARIANT: { |
        // If the wrapped tensor is itself a variant, recursively call
        // ZerosLikeVariant to unwrap it the same way.
588 | Variant* out_variant = cc_out->scalar<Variant>().data(); |
589 | TF_RETURN_IF_ERROR(ZerosLikeVariant(cc_ctx, |
590 | cc_input.scalar<Variant>()(), |
591 | out_variant, zeros_like_func)); |
592 | break; |
593 | } |
594 | default: { |
595 | Status status; |
596 | TF_Tensor* input = TF_TensorFromTensor(cc_input, &status); |
597 | TF_RETURN_IF_ERROR(status); |
598 | |
599 | TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status); |
600 | if (!status.ok()) { |
601 | TF_DeleteTensor(input); |
602 | return status; |
603 | } |
604 | |
        auto* ctx = reinterpret_cast<TF_OpKernelContext*>(cc_ctx);
        zeros_like_func(ctx, input, out);
        // The zeros are written into `cc_out`'s buffer; drop the wrappers.
        TF_DeleteTensor(input);
        TF_DeleteTensor(out);
607 | } |
608 | } |
609 | return cc_ctx->status(); |
610 | }; |
611 | |
612 | if (out == nullptr) { |
613 | return ::tensorflow::errors::Internal( |
614 | "The output variant hasn't been initialized" ); |
615 | } |
616 | |
617 | if (input.TypeId() == |
618 | tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) { |
619 | TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(input)); |
620 | *out = ::tensorflow::TensorList(); |
621 | |
622 | return ::tensorflow::TensorListZerosLike( |
623 | cc_ctx, *input.get<::tensorflow::TensorList>(), |
624 | out->get<::tensorflow::TensorList>(), cc_zeros_like_func); |
625 | } else if (input.TypeId() == tensorflow::TypeIndex::Make< |
626 | ::tensorflow::data::OptionalVariant>()) { |
627 | TF_RETURN_IF_ERROR( |
628 | ValidateVariantType<::tensorflow::data::OptionalVariant>(input)); |
629 | *out = ::tensorflow::data::OptionalVariant(); |
630 | |
631 | return ::tensorflow::data::OptionalZerosLike( |
632 | cc_ctx, *input.get<::tensorflow::data::OptionalVariant>(), |
633 | out->get<::tensorflow::data::OptionalVariant>(), cc_zeros_like_func); |
634 | } |
635 | |
636 | const std::string type_index_name = |
637 | ::tensorflow::port::MaybeAbiDemangle(input.TypeId().name()); |
638 | |
639 | return ::tensorflow::errors::Internal( |
640 | "No unary variant unary_op function found for op ZEROS_LIKE Variant " |
641 | "type_name: " , |
642 | type_index_name, " for device type: " , cc_ctx->device()->name()); |
643 | } |
644 | |
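// Implements ZerosLike for a scalar DT_VARIANT input on plugin devices. The
// output variant is allocated on the host because variants wrap C++ objects
// that cannot be represented efficiently in device memory.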
645 | void TF_ZerosLikeVariant(TF_OpKernelContext* ctx, |
646 | void (*zeros_like_func)(TF_OpKernelContext* ctx, |
647 | TF_Tensor* input, |
648 | TF_Tensor* out), |
649 | TF_Status* status) { |
650 | auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); |
651 | |
652 | const Tensor& input = cc_ctx->input(0); |
653 | OP_REQUIRES(cc_ctx, input.dims() == 0, |
              InvalidArgument(
                  "ZerosLike non-scalar Tensor with dtype=DT_VARIANT is not "
                  "supported."));
657 | const Variant& v = input.scalar<Variant>()(); |
658 | // DT_VARIANT tensors must be allocated on CPU since they wrap C++ |
659 | // objects which can not be efficiently represented in GPU memory. |
660 | int numa_node = cc_ctx->device()->NumaNode(); |
661 | Tensor out(::tensorflow::cpu_allocator(numa_node), ::tensorflow::DT_VARIANT, |
662 | ::tensorflow::TensorShape({})); |
663 | Variant* out_v = &(out.scalar<Variant>()()); |
664 | Status cc_status = ZerosLikeVariant(cc_ctx, v, out_v, zeros_like_func); |
665 | ::tensorflow::Set_TF_Status_from_Status(status, cc_status); |
666 | OP_REQUIRES_OK(cc_ctx, cc_status); |
667 | cc_ctx->set_output(0, out); |
668 | } |
669 | #endif // IS_MOBILE_PLATFORM |
670 | |