/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/kernels_experimental.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/ref_var.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/variant.h"

#ifndef IS_MOBILE_PLATFORM
#include "tensorflow/core/kernels/data/optional_ops_util.h"
#include "tensorflow/core/kernels/tensor_list.h"
#include "tensorflow/core/kernels/tensor_list_util.h"
#include "tensorflow/core/kernels/variant_ops_util.h"
#include "tensorflow/core/platform/abi.h"
#endif  // IS_MOBILE_PLATFORM

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/refcount.h"

using tensorflow::AllocatorAttributes;
using tensorflow::mutex_lock;
using tensorflow::Status;
using tensorflow::Tensor;
using tensorflow::TF_TensorFromTensor;
using tensorflow::Var;
using tensorflow::Variant;
using tensorflow::errors::InvalidArgument;

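// Holds the resource variables and the exclusive/shared locks acquired by
// TF_MaybeLockVariableInputMutexesInOrder so they can be released together
// by TF_ReleaseVariableInputLockHolder.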
struct TF_VariableInputLockHolder {
  TF_VariableInputLockHolder(
      std::vector<tensorflow::Var*> vars,
      std::unique_ptr<std::vector<tensorflow::mutex_lock>> locks,
      std::unique_ptr<std::vector<tensorflow::tf_shared_lock>> shared_locks)
      : vars(std::move(vars)),
        locks(std::move(locks)),
        shared_locks(std::move(shared_locks)) {}

  std::vector<tensorflow::Var*> vars;
  std::unique_ptr<std::vector<tensorflow::mutex_lock>> locks;
  std::unique_ptr<std::vector<tensorflow::tf_shared_lock>> shared_locks;
};

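// Puts `var` into copy-on-read mode. If other readers may still hold a
// reference to the variable's buffer, the buffer is first copied into a
// freshly allocated tensor; `copyFunc` performs the device copy for
// non-variant dtypes.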
tensorflow::Status EnsureSparseVariableAccess(
    TF_OpKernelContext* ctx, bool variantType,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest),
    tensorflow::Var* var) {
  auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (var->copy_on_read_mode.load()) {
    return ::tensorflow::OkStatus();
  }
  mutex_lock ml(*var->mu());
  // Once copy-on-read mode is True the refcount is guaranteed to be 1. This
  // can also happen if there are no concurrent reads of the variable and
  // copy-on-read mode is false.
  if (var->tensor()->RefCountIsOne()) {
    var->copy_on_read_mode.store(true);
    return ::tensorflow::OkStatus();
  }
  Tensor tmp;
  if (variantType) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(context->allocate_temp(
        var->tensor()->dtype(), var->tensor()->shape(), &tmp, attr));

    const auto elements_in = var->tensor()->flat<Variant>();
    auto elements_out = tmp.flat<Variant>();
    for (int64_t i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  } else {
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    TF_RETURN_IF_ERROR(context->allocate_temp(
        var->tensor()->dtype(), var->tensor()->shape(), &tmp, attr));
    tensorflow::Status s;
    TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
    TF_Tensor* tf_tensor = TF_TensorFromTensor(*var->tensor(), &s);
    copyFunc(ctx, tf_tensor, tf_tmp);
  }
  *var->tensor() = tmp;
  var->copy_on_read_mode.store(true);
  return ::tensorflow::OkStatus();
}

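// Copies `tensor`'s buffer before an in-place update whenever the buffer may
// be shared, i.e. when copy-on-read mode is enabled or the buffer's refcount
// is greater than one.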
tensorflow::Status PrepareToUpdateVariable(
    TF_OpKernelContext* ctx, tensorflow::Tensor* tensor,
    bool copy_on_read_mode, bool variantType,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest)) {
  auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (copy_on_read_mode || !tensor->RefCountIsOne()) {
    // Tensor's buffer is in use by some read, so we need to copy before
    // updating.
    Tensor tmp;
    if (variantType) {
      AllocatorAttributes attr;
      attr.set_on_host(true);
      TF_RETURN_IF_ERROR(context->allocate_temp(tensor->dtype(),
                                                tensor->shape(), &tmp, attr));

      const auto elements_in = tensor->flat<Variant>();
      auto elements_out = tmp.flat<Variant>();
      for (int64_t i = 0; i < elements_in.size(); ++i) {
        elements_out(i) = elements_in(i);
      }
    } else {
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      TF_RETURN_IF_ERROR(context->allocate_temp(tensor->dtype(),
                                                tensor->shape(), &tmp, attr));
      tensorflow::Status s;
      TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
      TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
      copyFunc(ctx, tf_tensor, tf_tmp);
    }
    *tensor = tmp;
  }
  return ::tensorflow::OkStatus();
}

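// Returns the mutex that guards the variable backing input `input`. For
// DT_RESOURCE inputs the resource variable is looked up (and, if `sparse` is
// true, switched into copy-on-read mode) and `*maybe_resource` receives a
// reference that the caller must eventually unref; otherwise the ref-input
// mutex is returned and `*maybe_resource` stays nullptr.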
tensorflow::mutex* GetTrainingVariableMutex(
    TF_OpKernelContext* ctx, int32_t input, bool sparse,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest),
    tensorflow::Var** maybe_resource) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  *maybe_resource = nullptr;
  if (cc_ctx->input_dtype(input) == tensorflow::DT_RESOURCE) {
    if (LookupResource(cc_ctx, HandleFromInput(cc_ctx, input), maybe_resource)
            .ok()) {
      if (sparse) {
        TF_CHECK_OK(
            EnsureSparseVariableAccess(ctx, false, copyFunc, *maybe_resource));
      }
      return (*maybe_resource)->mu();
    } else {
      cc_ctx->CtxFailureWithWarning(
          tensorflow::errors::Internal("Invalid variable reference."));
      return nullptr;
    }
  }
  return cc_ctx->input_ref_mutex(input);
}

void TF_AssignVariable(TF_OpKernelContext* ctx, int input_index,
                       int value_index, bool validate_shape,
                       void (*copyFunc)(TF_OpKernelContext* ctx,
                                        TF_Tensor* source, TF_Tensor* dest),
                       TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  tensorflow::core::RefCountPtr<tensorflow::Var> variable;
  const tensorflow::Tensor& value = cc_ctx->input(value_index);
  OP_REQUIRES_OK(cc_ctx, tensorflow::LookupOrCreateResource<tensorflow::Var>(
                             cc_ctx, HandleFromInput(cc_ctx, input_index),
                             &variable, [&value](tensorflow::Var** ptr) {
                               *ptr = new tensorflow::Var(value.dtype());
                               *(*ptr)->tensor() = value;
                               (*ptr)->is_initialized = true;
                               return ::tensorflow::OkStatus();
                             }));
  tensorflow::mutex_lock ml(*variable->mu());

  if (validate_shape) {
    OP_REQUIRES(cc_ctx,
                (!variable->is_initialized ||
                 variable->tensor()->shape().IsSameSize(value.shape())),
                InvalidArgument(
                    "Trying to assign to variable with tensor with wrong "
                    "shape. Expected ",
                    variable->tensor()->shape().DebugString(), " got ",
                    value.shape().DebugString()));
  }

  if (variable->copy_on_read_mode.load()) {
    tensorflow::Tensor tmp;
    tensorflow::AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    OP_REQUIRES_OK(cc_ctx, cc_ctx->allocate_temp(value.dtype(), value.shape(),
                                                 &tmp, attr));
    tensorflow::Status s;
    TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s);
    TF_Tensor* tf_value = TF_TensorFromTensor(value, &s);
    copyFunc(ctx, tf_value, tf_tmp);
    *variable->tensor() = tmp;
  } else {
    *variable->tensor() = value;
  }
  variable->is_initialized = true;
  TF_SetStatus(status, TF_OK, "");
}

void TF_AssignRefVariable(TF_OpKernelContext* ctx, int input_ref_index,
                          int output_ref_index, int value_index,
                          bool use_locking, bool validate_shape,
                          void (*copyFunc)(TF_OpKernelContext* ctx,
                                           TF_Tensor* source, TF_Tensor* dest),
                          TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);

  auto copy = [copyFunc, ctx](::tensorflow::OpKernelContext* cc_ctx,
                              ::tensorflow::Tensor* lhs,
                              const ::tensorflow::Tensor& rhs) {
    ::tensorflow::Status s;
    TF_Tensor* tf_lhs = TF_TensorFromTensor(*lhs, &s);
    OP_REQUIRES_OK(cc_ctx, s);

    TF_Tensor* tf_rhs = TF_TensorFromTensor(rhs, &s);

    if (!s.ok()) {
      TF_DeleteTensor(tf_lhs);
      OP_REQUIRES_OK(cc_ctx, s);
    }

    copyFunc(ctx, tf_rhs, tf_lhs);
  };

  ::tensorflow::AssignRefVariable(cc_ctx, input_ref_index, output_ref_index,
                                  value_index, use_locking, validate_shape,
                                  false, copy);
  TF_SetStatus(status, TF_OK, "");
}

void TF_AssignUpdateVariable(TF_OpKernelContext* ctx, int input_index,
                             int value_index, int Op, int isVariantType,
                             void (*copyFunc)(TF_OpKernelContext* ctx,
                                              TF_Tensor* source,
                                              TF_Tensor* dest),
                             void (*updateFunc)(TF_OpKernelContext* ctx,
                                                TF_Tensor* tensor,
                                                TF_Tensor* value, int Op),
                             TF_Status* tf_status) {
  auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  tensorflow::core::RefCountPtr<Var> variable;
  Status status =
      LookupResource(context, HandleFromInput(context, input_index), &variable);
  if (!status.ok()) {
    printf("Failed with error: %s\n", status.error_message().c_str());
    abort();
  }
  const Tensor& value = context->input(value_index);
  mutex_lock ml(*variable->mu());
  Tensor* var_tensor = variable->tensor();
  OP_REQUIRES(
      context, var_tensor->shape().IsSameSize(value.shape()),
      InvalidArgument("Cannot update variable with shape ",
                      var_tensor->shape().DebugString(),
                      " using a Tensor with shape ",
                      value.shape().DebugString(), ", shapes must be equal."));
  OP_REQUIRES_OK(context,
                 PrepareToUpdateVariable(ctx, var_tensor,
                                         variable->copy_on_read_mode.load(),
                                         isVariantType, copyFunc));
  tensorflow::Status s;
  TF_Tensor* tf_var_tensor = TF_TensorFromTensor(*var_tensor, &s);
  TF_Tensor* tf_value = TF_TensorFromTensor(value, &s);
  updateFunc(ctx, tf_var_tensor, tf_value, Op);
  TF_SetStatus(tf_status, TF_OK, "");
}

void TF_MaybeLockVariableInputMutexesInOrder(
    TF_OpKernelContext* ctx, bool do_lock, bool sparse,
    const int* const inputs, size_t len,
    void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source,
                     TF_Tensor* dest),
    TF_VariableInputLockHolder** lockHolder, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  bool any_resource = false;
  std::vector<int> input_ids(inputs, inputs + len);
  for (auto i : input_ids) {
    if (cc_ctx->input_dtype(i) == tensorflow::DT_RESOURCE) {
      any_resource = true;
      break;
    }
  }
  if (!do_lock && !any_resource) {
    *lockHolder = new TF_VariableInputLockHolder({}, {}, {});
    TF_SetStatus(status, TF_OK, "");
    return;
  }
  std::vector<tensorflow::Var*> vars;
  std::vector<tensorflow::mutex*> mutexes;
  std::vector<int32_t> acquire_order;
  for (auto input : input_ids) {
    tensorflow::Var* var;
    tensorflow::mutex* mutex =
        GetTrainingVariableMutex(ctx, input, sparse, copyFunc, &var);
    if (var) vars.push_back(var);
    // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
    if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
      acquire_order.push_back(mutexes.size());
      mutexes.push_back(mutex);
    }
  }
  std::sort(acquire_order.begin(), acquire_order.end(),
            [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

  auto locks = absl::make_unique<std::vector<tensorflow::mutex_lock>>();
  auto shared_locks =
      absl::make_unique<std::vector<tensorflow::tf_shared_lock>>();
  locks->reserve(acquire_order.size());

  for (auto input : acquire_order) {
    tensorflow::Var* var;
    tensorflow::mutex* mu =
        GetTrainingVariableMutex(ctx, input, sparse, copyFunc, &var);
    tensorflow::core::ScopedUnref scoped_unref(var);
    if (mu != nullptr) {
      if (do_lock) {
        locks->emplace_back(*mu);
      } else {
        shared_locks->emplace_back(*mu);
      }
    }
  }
  *lockHolder = new TF_VariableInputLockHolder(
      std::move(vars), std::move(locks), std::move(shared_locks));
  TF_SetStatus(status, TF_OK, "");
}

void TF_GetInputTensorFromVariable(TF_OpKernelContext* ctx, int input,
                                   bool lock_held, bool isVariantType,
                                   bool sparse,
                                   void (*copyFunc)(TF_OpKernelContext* ctx,
                                                    TF_Tensor* source,
                                                    TF_Tensor* dest),
                                   TF_Tensor** out, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  tensorflow::Status s;
  if (cc_ctx->input_dtype(input) == tensorflow::DT_RESOURCE) {
    tensorflow::core::RefCountPtr<tensorflow::Var> var;
    OP_REQUIRES_OK(
        cc_ctx, LookupResource(cc_ctx, HandleFromInput(cc_ctx, input), &var));
    if (sparse) {
      OP_REQUIRES_OK(cc_ctx, EnsureSparseVariableAccess(ctx, isVariantType,
                                                        copyFunc, var.get()));
      *out = ::tensorflow::TF_TensorFromTensor(*var->tensor(), &s);
      ::tensorflow::Set_TF_Status_from_Status(status, s);
      return;
    }
    OP_REQUIRES_OK(cc_ctx,
                   PrepareToUpdateVariable(ctx, var->tensor(),
                                           var->copy_on_read_mode.load(),
                                           false, copyFunc));
    *out = ::tensorflow::TF_TensorFromTensor(*var->tensor(), &s);
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return;
  }
  *out = ::tensorflow::TF_TensorFromTensor(
      cc_ctx->mutable_input(input, lock_held), &s);
  ::tensorflow::Set_TF_Status_from_Status(status, s);
}

void TF_OpKernelContext_ForwardRefInputToRefOutput(TF_OpKernelContext* ctx,
                                                   int32_t input_index,
                                                   int32_t output_index) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (cc_ctx->input_dtype(input_index) != tensorflow::DT_RESOURCE) {
    cc_ctx->forward_ref_input_to_ref_output(input_index, output_index);
  }
}

void TF_ReleaseVariableInputLockHolder(
    TF_VariableInputLockHolder* lockHolder) {
  if (lockHolder != nullptr) {
    lockHolder->locks.reset();
    for (tensorflow::Var* var : lockHolder->vars) {
      var->Unref();
    }
    delete lockHolder;
  }
}

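// A minimal usage sketch, assuming a hypothetical plugin kernel with a
// resource-variable input at index 0 and a hypothetical device copy functor
// `MyCopyFunc`; the variable-locking helpers above are typically used
// together from the kernel's Compute function:
//
//   TF_VariableInputLockHolder* holder = nullptr;
//   const int var_inputs[] = {0};
//   TF_MaybeLockVariableInputMutexesInOrder(
//       ctx, /*do_lock=*/true, /*sparse=*/false, var_inputs, 1, MyCopyFunc,
//       &holder, status);
//   TF_Tensor* var_tensor = nullptr;
//   TF_GetInputTensorFromVariable(ctx, /*input=*/0, /*lock_held=*/true,
//                                 /*isVariantType=*/false, /*sparse=*/false,
//                                 MyCopyFunc, &var_tensor, status);
//   // ... apply the update to var_tensor ...
//   TF_DeleteTensor(var_tensor);
//   TF_ReleaseVariableInputLockHolder(holder);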
void TF_GetInputByName(TF_OpKernelContext* ctx, const char* inputName,
                       TF_Tensor** tensor, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  const ::tensorflow::Tensor* cc_tensor = nullptr;
  tensorflow::Status s = cc_ctx->input(inputName, &cc_tensor);

  if (!s.ok()) {
    ::tensorflow::Set_TF_Status_from_Status(status, s);
    return;
  }
  TF_Tensor* result =
      ::tensorflow::TF_TensorFromTensor(*cc_tensor, &status->status);
  if (TF_GetCode(status) == TF_OK) {
    *tensor = result;
  }
}

void TF_OpKernelConstruction_GetAttrTensorShape(TF_OpKernelConstruction* ctx,
                                                const char* attr_name,
                                                int64_t* dims, size_t num_dims,
                                                TF_Status* status) {
  ::tensorflow::TensorShape shape;
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
  ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &shape);
  ::tensorflow::Set_TF_Status_from_Status(status, s);
  size_t rank = static_cast<size_t>(shape.dims());

  if (!status->status.ok()) return;

  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }

  for (size_t i = 0; i < rank; ++i) {
    dims[i] = static_cast<int64_t>(shape.dim_size(i));
  }
}

bool TF_IsRefInput(TF_OpKernelContext* ctx, int i, TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
  if (i < 0 || i >= cc_ctx->num_inputs()) {
    TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
    return false;
  }
  TF_SetStatus(status, TF_OK, "");
  return cc_ctx->input_is_ref(i);
}

#ifndef IS_MOBILE_PLATFORM
template <typename T>
static Status ValidateVariantType(const Variant& variant) {
  if (variant.get<T>() == nullptr) {
    const std::string type_index_name =
        ::tensorflow::port::MaybeAbiDemangle(variant.TypeId().name());

    return ::tensorflow::errors::Internal(
        "VariantBinaryOpFn: Could not access object 'a', type_index: ",
        type_index_name);
  }

  return ::tensorflow::OkStatus();
}

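// Sums this kernel's variant inputs (TensorList or OptionalVariant),
// delegating the element-wise addition of the wrapped tensors to
// `binary_add_func`.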
void TF_AddNVariant(TF_OpKernelContext* ctx,
                    void (*binary_add_func)(TF_OpKernelContext* ctx,
                                            TF_Tensor* a, TF_Tensor* b,
                                            TF_Tensor* out),
                    TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);

  auto cc_binary_add_func = [binary_add_func](
                                ::tensorflow::OpKernelContext* cc_ctx,
                                const Tensor& cc_a, const Tensor& cc_b,
                                Tensor* cc_out) {
    if (cc_a.dtype() == ::tensorflow::DT_INVALID) {
      *cc_out = cc_b;
      return ::tensorflow::OkStatus();
    }
    if (cc_b.dtype() == ::tensorflow::DT_INVALID) {
      *cc_out = cc_a;
      return ::tensorflow::OkStatus();
    }

    Status status;
    TF_Tensor* a = TF_TensorFromTensor(cc_a, &status);
    TF_RETURN_IF_ERROR(status);

    TF_Tensor* b = TF_TensorFromTensor(cc_b, &status);
    if (!status.ok()) {
      TF_DeleteTensor(a);
      return status;
    }

    ::tensorflow::AllocatorAttributes attr;
    if (cc_a.dtype() == ::tensorflow::DT_VARIANT) {
      attr.set_on_host(true);
    }

    status = cc_ctx->allocate_temp(cc_a.dtype(), cc_a.shape(), cc_out, attr);
    if (!status.ok()) {
      TF_DeleteTensor(a);
      TF_DeleteTensor(b);
      return status;
    }

    TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status);
    if (!status.ok()) {
      TF_DeleteTensor(a);
      TF_DeleteTensor(b);
      return status;
    }

    auto* ctx = reinterpret_cast<TF_OpKernelContext*>(cc_ctx);
    binary_add_func(ctx, a, b, out);
    return cc_ctx->status();
  };

  auto binary_add_variant = [cc_binary_add_func](
                                ::tensorflow::OpKernelContext* cc_ctx,
                                const Variant& a, const Variant& b,
                                Variant* out) {
    if (out == nullptr) {
      return ::tensorflow::errors::Internal(
          "The output variant hasn't been initialized");
    }

    if (a.TypeId() != b.TypeId()) {
      return ::tensorflow::errors::Internal(
          "BinaryOpVariants: Variants a and b have different "
          "type ids. Type names: '",
          a.TypeName(), "' vs. '", b.TypeName(), "'");
    }

    if (a.TypeId() ==
        tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) {
      TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(a));
      *out = ::tensorflow::TensorList();

      return ::tensorflow::TensorListBinaryAdd(
          cc_ctx, *a.get<::tensorflow::TensorList>(),
          *b.get<::tensorflow::TensorList>(),
          out->get<::tensorflow::TensorList>(), cc_binary_add_func);
    } else if (a.TypeId() == tensorflow::TypeIndex::Make<
                                 ::tensorflow::data::OptionalVariant>()) {
      TF_RETURN_IF_ERROR(
          ValidateVariantType<::tensorflow::data::OptionalVariant>(a));
      *out = ::tensorflow::data::OptionalVariant();

      return ::tensorflow::data::OptionalBinaryAdd(
          cc_ctx, *a.get<::tensorflow::data::OptionalVariant>(),
          *b.get<::tensorflow::data::OptionalVariant>(),
          out->get<::tensorflow::data::OptionalVariant>(), cc_binary_add_func);
    }

    const std::string type_index_name =
        ::tensorflow::port::MaybeAbiDemangle(a.TypeId().name());

    return ::tensorflow::errors::Internal(
        "No unary variant binary_op function found for op ADD Variant "
        "type_name: ",
        type_index_name, " for device type: ", cc_ctx->device()->name());
  };
  ::tensorflow::AddNVariant(cc_ctx, binary_add_variant);
  ::tensorflow::Set_TF_Status_from_Status(status, cc_ctx->status());
}

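// Builds a zeros-like value for a TensorList or OptionalVariant `input`,
// delegating the per-tensor zero fill to `zeros_like_func` and recursing into
// nested DT_VARIANT elements.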
static Status ZerosLikeVariant(::tensorflow::OpKernelContext* cc_ctx,
                               const Variant& input, Variant* out,
                               void (*zeros_like_func)(TF_OpKernelContext* ctx,
                                                       TF_Tensor* input,
                                                       TF_Tensor* out)) {
  auto cc_zeros_like_func = [zeros_like_func](
                                ::tensorflow::OpKernelContext* cc_ctx,
                                const Tensor& cc_input, Tensor* cc_out) {
    AllocatorAttributes attr;
    if (cc_input.dtype() == ::tensorflow::DT_VARIANT) {
      attr.set_on_host(true);
    }
    TF_RETURN_IF_ERROR(cc_ctx->allocate_temp(cc_input.dtype(),
                                              cc_input.shape(), cc_out, attr));

    switch (cc_input.dtype()) {
      case ::tensorflow::DT_INVALID: {
        *cc_out = Tensor(::tensorflow::DT_INVALID);
        break;
      }
      case ::tensorflow::DT_VARIANT: {
        // If the wrapped tensor is also a variant, recursively call
        // ZerosLikeVariant to unwrap it the same way.
        Variant* out_variant = cc_out->scalar<Variant>().data();
        TF_RETURN_IF_ERROR(ZerosLikeVariant(cc_ctx,
                                            cc_input.scalar<Variant>()(),
                                            out_variant, zeros_like_func));
        break;
      }
      default: {
        Status status;
        TF_Tensor* input = TF_TensorFromTensor(cc_input, &status);
        TF_RETURN_IF_ERROR(status);

        TF_Tensor* out = TF_TensorFromTensor(*cc_out, &status);
        if (!status.ok()) {
          TF_DeleteTensor(input);
          return status;
        }

        auto* ctx = reinterpret_cast<TF_OpKernelContext*>(cc_ctx);
        zeros_like_func(ctx, input, out);
      }
    }
    return cc_ctx->status();
  };

  if (out == nullptr) {
    return ::tensorflow::errors::Internal(
        "The output variant hasn't been initialized");
  }

  if (input.TypeId() ==
      tensorflow::TypeIndex::Make<::tensorflow::TensorList>()) {
    TF_RETURN_IF_ERROR(ValidateVariantType<::tensorflow::TensorList>(input));
    *out = ::tensorflow::TensorList();

    return ::tensorflow::TensorListZerosLike(
        cc_ctx, *input.get<::tensorflow::TensorList>(),
        out->get<::tensorflow::TensorList>(), cc_zeros_like_func);
  } else if (input.TypeId() == tensorflow::TypeIndex::Make<
                                   ::tensorflow::data::OptionalVariant>()) {
    TF_RETURN_IF_ERROR(
        ValidateVariantType<::tensorflow::data::OptionalVariant>(input));
    *out = ::tensorflow::data::OptionalVariant();

    return ::tensorflow::data::OptionalZerosLike(
        cc_ctx, *input.get<::tensorflow::data::OptionalVariant>(),
        out->get<::tensorflow::data::OptionalVariant>(), cc_zeros_like_func);
  }

  const std::string type_index_name =
      ::tensorflow::port::MaybeAbiDemangle(input.TypeId().name());

  return ::tensorflow::errors::Internal(
      "No unary variant unary_op function found for op ZEROS_LIKE Variant "
      "type_name: ",
      type_index_name, " for device type: ", cc_ctx->device()->name());
}

void TF_ZerosLikeVariant(TF_OpKernelContext* ctx,
                         void (*zeros_like_func)(TF_OpKernelContext* ctx,
                                                 TF_Tensor* input,
                                                 TF_Tensor* out),
                         TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);

  const Tensor& input = cc_ctx->input(0);
  OP_REQUIRES(cc_ctx, input.dims() == 0,
              InvalidArgument(
                  "ZerosLike non-scalar Tensor with dtype=DT_VARIANT is not "
                  "supported."));
  const Variant& v = input.scalar<Variant>()();
  // DT_VARIANT tensors must be allocated on CPU since they wrap C++
  // objects which cannot be efficiently represented in GPU memory.
  int numa_node = cc_ctx->device()->NumaNode();
  Tensor out(::tensorflow::cpu_allocator(numa_node), ::tensorflow::DT_VARIANT,
             ::tensorflow::TensorShape({}));
  Variant* out_v = &(out.scalar<Variant>()());
  Status cc_status = ZerosLikeVariant(cc_ctx, v, out_v, zeros_like_func);
  ::tensorflow::Set_TF_Status_from_Status(status, cc_status);
  OP_REQUIRES_OK(cc_ctx, cc_status);
  cc_ctx->set_output(0, out);
}
#endif  // IS_MOBILE_PLATFORM