1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op_kernel.h" |
17 | #include "tensorflow/core/framework/tensor.h" |
18 | #include "tensorflow/core/lib/core/status.h" |
19 | |
20 | namespace tensorflow { |
21 | void AssignRefVariable( |
22 | OpKernelContext* context, int input_ref_index, int output_ref_index, |
23 | int value_index, bool use_locking, bool validate_shape, |
24 | bool relax_constraints, |
25 | std::function<void(OpKernelContext*, Tensor*, const Tensor&)> copy) { |
26 | const Tensor& rhs = context->input(value_index); |
27 | |
28 | // We always return the input ref. |
29 | context->forward_ref_input_to_ref_output(input_ref_index, output_ref_index); |
30 | |
31 | // Prevent copying uninitialized data, to solve harder to debug undefined |
32 | // behaviors that cannot be traced back to the original tensor. |
33 | OP_REQUIRES( |
34 | context, rhs.IsInitialized(), |
35 | errors::Internal("Right hand side of AssignOp is not initialized" )); |
36 | |
37 | // We can't always know how this value will be used downstream, so make |
38 | // conservative assumptions in specifying constraints on the memory |
39 | // allocation attributes, unless the Grappler graph analysis determined that |
40 | // it was safe not to. |
41 | AllocatorAttributes attr; |
42 | if (!relax_constraints) { |
43 | attr.set_gpu_compatible(true); |
44 | attr.set_nic_compatible(true); |
45 | } |
46 | |
47 | { |
48 | mutex_lock l(*context->input_ref_mutex(input_ref_index)); |
49 | const Tensor& old_lhs = |
50 | context->mutable_input(input_ref_index, /* lock_held */ true); |
51 | const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape()); |
52 | if (validate_shape) { |
53 | OP_REQUIRES(context, same_shape, |
54 | errors::InvalidArgument( |
55 | "Assign requires shapes of both tensors to match. " |
56 | "lhs shape= " , |
57 | old_lhs.shape().DebugString(), |
58 | " rhs shape= " , rhs.shape().DebugString())); |
59 | } |
60 | |
61 | // In the code below we try to minimize the amount of memory allocation |
62 | // and copying by trying the following two shortcuts: |
63 | // 1. If the lhs is initialized and has the same number of elements as |
64 | // the rhs we can avoid a memory allocation. |
65 | // 2. If we can reuse the rhs buffer we avoid both a memory allocation |
66 | // and copying. |
67 | |
68 | // 1. Try to copy into an existing buffer. |
69 | if (old_lhs.IsInitialized() && |
70 | old_lhs.shape().num_elements() == rhs.shape().num_elements()) { |
71 | // The existing lhs tensor has already been initialized and the right |
72 | // hand side can fit in the underlying buffer. |
73 | Tensor reshaped_old_lhs; |
74 | if (same_shape) { |
75 | reshaped_old_lhs = old_lhs; |
76 | } else { |
77 | OP_REQUIRES(context, reshaped_old_lhs.CopyFrom(old_lhs, rhs.shape()), |
78 | errors::Internal( |
79 | "Unable to copy the value tensor to the ref input" )); |
80 | context->replace_ref_input(input_ref_index, reshaped_old_lhs, |
81 | /* lock_held */ true); |
82 | } |
83 | if (use_locking) { |
84 | copy(context, &reshaped_old_lhs, rhs); |
85 | return; |
86 | } |
87 | } else { |
88 | // 2. Try to reuse the rhs. |
89 | std::unique_ptr<Tensor> input_alias = context->forward_input( |
90 | value_index, OpKernelContext::Params::kNoReservation /*output_index*/, |
91 | rhs.dtype(), rhs.shape(), DEVICE_MEMORY, attr); |
92 | if (input_alias != nullptr) { |
93 | // Update the ref to point to the new buffer. |
94 | context->replace_ref_input(input_ref_index, *input_alias, |
95 | /* lock_held */ true); |
96 | return; |
97 | } |
98 | |
99 | // Otherwise, create a new tensor whose shape matches the |
100 | // right hand side, hand off to lhs and copy the rhs into it. |
101 | Tensor copy_tensor; |
102 | OP_REQUIRES_OK(context, |
103 | context->allocate_temp(old_lhs.dtype(), rhs.shape(), |
104 | ©_tensor, attr)); |
105 | // We track memory of variables in variable ops instead of in this |
106 | // assign op. |
107 | context->clear_recorded_memory(); |
108 | context->replace_ref_input(input_ref_index, copy_tensor, |
109 | /* lock_held */ true); |
110 | if (use_locking) { |
111 | copy(context, ©_tensor, rhs); |
112 | return; |
113 | } |
114 | } |
115 | } |
116 | |
117 | // The tensor has already been initialized and the right hand side |
118 | // matches the left hand side's shape. We have been told to do the |
119 | // copy outside the lock. |
120 | Tensor old_unlocked_lhs = |
121 | context->mutable_input(input_ref_index, /* lock_held */ false); |
122 | copy(context, &old_unlocked_lhs, rhs); |
123 | } |
124 | } // end namespace tensorflow |
125 | |