1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/op_kernel.h"
17#include "tensorflow/core/framework/tensor.h"
18#include "tensorflow/core/lib/core/status.h"
19
20namespace tensorflow {
// Implements the ref-variable Assign op: writes the tensor at `value_index`
// into the ref variable held at input `input_ref_index`, and forwards that
// ref to output `output_ref_index` so downstream ops observe the update.
//
// Args:
//   context: kernel context supplying inputs/outputs and allocators.
//   input_ref_index: index of the ref-typed input holding the variable.
//   output_ref_index: index of the ref-typed output the ref is forwarded to.
//   value_index: index of the input providing the new (rhs) value.
//   use_locking: if true, the element-wise copy runs while the variable's
//     mutex is held; if false, the copy happens after the lock is released.
//   validate_shape: if true, fail unless lhs and rhs shapes match exactly.
//   relax_constraints: if true, skip the conservative GPU/NIC-compatible
//     allocation attributes (set when graph analysis determined it is safe).
//   copy: device-specific functor that copies the rhs into the given output
//     tensor; invoked with (context, destination, rhs).
void AssignRefVariable(
    OpKernelContext* context, int input_ref_index, int output_ref_index,
    int value_index, bool use_locking, bool validate_shape,
    bool relax_constraints,
    std::function<void(OpKernelContext*, Tensor*, const Tensor&)> copy) {
  const Tensor& rhs = context->input(value_index);

  // We always return the input ref.
  context->forward_ref_input_to_ref_output(input_ref_index, output_ref_index);

  // Prevent copying uninitialized data, to solve harder to debug undefined
  // behaviors that cannot be traced back to the original tensor.
  OP_REQUIRES(
      context, rhs.IsInitialized(),
      errors::Internal("Right hand side of AssignOp is not initialized"));

  // We can't always know how this value will be used downstream, so make
  // conservative assumptions in specifying constraints on the memory
  // allocation attributes, unless the Grappler graph analysis determined that
  // it was safe not to.
  AllocatorAttributes attr;
  if (!relax_constraints) {
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
  }

  {
    // All mutations of the ref tensor below happen under the variable's
    // mutex; the copy itself only stays inside this scope when use_locking.
    mutex_lock l(*context->input_ref_mutex(input_ref_index));
    const Tensor& old_lhs =
        context->mutable_input(input_ref_index, /* lock_held */ true);
    const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape());
    if (validate_shape) {
      OP_REQUIRES(context, same_shape,
                  errors::InvalidArgument(
                      "Assign requires shapes of both tensors to match. "
                      "lhs shape= ",
                      old_lhs.shape().DebugString(),
                      " rhs shape= ", rhs.shape().DebugString()));
    }

    // In the code below we try to minimize the amount of memory allocation
    // and copying by trying the following two shortcuts:
    // 1. If the lhs is initialized and has the same number of elements as
    //    the rhs we can avoid a memory allocation.
    // 2. If we can reuse the rhs buffer we avoid both a memory allocation
    //    and copying.

    // 1. Try to copy into an existing buffer.
    if (old_lhs.IsInitialized() &&
        old_lhs.shape().num_elements() == rhs.shape().num_elements()) {
      // The existing lhs tensor has already been initialized and the right
      // hand side can fit in the underlying buffer.
      Tensor reshaped_old_lhs;
      if (same_shape) {
        reshaped_old_lhs = old_lhs;
      } else {
        // Element counts match but the shapes differ. validate_shape must be
        // false here (otherwise the OP_REQUIRES above already returned), so
        // reinterpret the existing buffer with the rhs shape and point the
        // ref at the reshaped view.
        OP_REQUIRES(context, reshaped_old_lhs.CopyFrom(old_lhs, rhs.shape()),
                    errors::Internal(
                        "Unable to copy the value tensor to the ref input"));
        context->replace_ref_input(input_ref_index, reshaped_old_lhs,
                                   /* lock_held */ true);
      }
      if (use_locking) {
        copy(context, &reshaped_old_lhs, rhs);
        return;
      }
      // !use_locking: fall through to the copy performed outside the lock
      // at the bottom of this function.
    } else {
      // 2. Try to reuse the rhs: if the runtime can hand us exclusive
      // ownership of the rhs buffer, alias it and skip both the allocation
      // and the copy.
      std::unique_ptr<Tensor> input_alias = context->forward_input(
          value_index, OpKernelContext::Params::kNoReservation /*output_index*/,
          rhs.dtype(), rhs.shape(), DEVICE_MEMORY, attr);
      if (input_alias != nullptr) {
        // Update the ref to point to the new buffer.
        context->replace_ref_input(input_ref_index, *input_alias,
                                   /* lock_held */ true);
        // No copy needed: the variable now aliases the rhs buffer.
        return;
      }

      // Otherwise, create a new tensor whose shape matches the
      // right hand side, hand off to lhs and copy the rhs into it.
      Tensor copy_tensor;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(old_lhs.dtype(), rhs.shape(),
                                            &copy_tensor, attr));
      // We track memory of variables in variable ops instead of in this
      // assign op.
      context->clear_recorded_memory();
      context->replace_ref_input(input_ref_index, copy_tensor,
                                 /* lock_held */ true);
      if (use_locking) {
        copy(context, &copy_tensor, rhs);
        return;
      }
      // !use_locking: the ref already points at the (uninitialized-content)
      // new buffer; fall through to fill it outside the lock below.
    }
  }

  // The tensor has already been initialized and the right hand side
  // matches the left hand side's shape. We have been told to do the
  // copy outside the lock.
  Tensor old_unlocked_lhs =
      context->mutable_input(input_ref_index, /* lock_held */ false);
  copy(context, &old_unlocked_lhs, rhs);
}
124} // end namespace tensorflow
125