/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_

#include <algorithm>
#include <memory>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

// Must be called before performing a sparse operation on a variable. Ensures
// that no concurrent dense operations can happen while holding the variable's
// lock.
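//
// A minimal usage sketch (hypothetical caller code, not part of this header;
// `ctx`, the device type `CPUDevice`, the element type `float`, and the input
// index 0 are assumptions):
//
//   Var* var = nullptr;
//   OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
//   core::ScopedUnref unref_var(var);
//   OP_REQUIRES_OK(ctx, EnsureSparseVariableAccess<CPUDevice, float>(ctx, var));
//   // Sparse reads/writes of var->tensor() can now be done while holding
//   // var->mu().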
template <typename Device, typename T>
Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) {
  if (var->copy_on_read_mode.load()) {
    return OkStatus();
  }
  mutex_lock ml(*var->mu());
  // Once copy-on-read mode is true, the refcount is guaranteed to be 1. The
  // refcount can also be 1 if there are no concurrent reads of the variable
  // and copy-on-read mode is false.
  if (var->tensor()->RefCountIsOne()) {
    var->copy_on_read_mode.store(true);
    return OkStatus();
  }
  Tensor tmp;
  if (std::is_same<T, Variant>::value) {
    AllocatorAttributes attr;
    attr.set_on_host(true);
    TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(),
                                          var->tensor()->shape(), &tmp, attr));

    const auto elements_in = var->tensor()->flat<Variant>();
    auto elements_out = tmp.flat<Variant>();
    for (int64_t i = 0; i < elements_in.size(); ++i) {
      elements_out(i) = elements_in(i);
    }
  } else {
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    attr.set_nic_compatible(true);
    TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(),
                                          var->tensor()->shape(), &tmp, attr));
    functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
    copy_functor(ctx->eigen_device<Device>(), tmp.flat<T>(),
                 const_cast<const Tensor*>(var->tensor())->flat<T>());
  }
  *var->tensor() = tmp;
  var->copy_on_read_mode.store(true);
  return OkStatus();
}

// Utility structure that releases a sequence of borrowed mutexes when it is
// deleted.
struct VariableInputLockHolder {
 public:
  VariableInputLockHolder(
      std::vector<Var*> vars, std::unique_ptr<std::vector<mutex_lock>> locks,
      std::unique_ptr<std::vector<tf_shared_lock>> shared_locks)
      : vars_(std::move(vars)),
        locks_(std::move(locks)),
        shared_locks_(std::move(shared_locks)) {}

  VariableInputLockHolder(VariableInputLockHolder&& other)
      : vars_(std::move(other.vars_)),
        locks_(std::move(other.locks_)),
        shared_locks_(std::move(other.shared_locks_)) {}

  ~VariableInputLockHolder() {
    // Release the locks before unreffing the Vars, because each lock
    // is potentially borrowed from a Var in vars_.
    locks_.reset();
    shared_locks_.reset();
    for (Var* var : vars_) {
      var->Unref();
    }
  }

 private:
  std::vector<Var*> vars_;
  // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
  // because a `std::vector<mutex_lock>` is not movable on all platforms.
  std::unique_ptr<std::vector<mutex_lock>> locks_;
  std::unique_ptr<std::vector<tf_shared_lock>> shared_locks_;
};

// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
//
// If `input` corresponds to a `DT_RESOURCE`-type variable input,
// `*maybe_resource` will be updated to contain the underlying resource, and
// the caller will be responsible for calling `Unref()` on that resource.
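//
// A minimal ownership sketch (hypothetical caller; `ctx`, `CPUDevice`,
// `float`, and input index 0 are assumptions):
//
//   Var* var = nullptr;
//   mutex* mu = GetTrainingVariableMutex<CPUDevice, float>(
//       ctx, 0, /*sparse=*/false, &var);
//   if (mu != nullptr) {
//     mutex_lock l(*mu);
//     // ... read or update the variable under the lock ...
//   }
//   if (var != nullptr) var->Unref();  // Caller owns the Var reference.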
template <typename Device, typename T>
mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse,
                                Var** maybe_resource) {
  *maybe_resource = nullptr;
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
      if (sparse) {
        EnsureSparseVariableAccess<Device, T>(ctx, *maybe_resource)
            .IgnoreError();
      }
      return (*maybe_resource)->mu();
    } else {
      ctx->CtxFailureWithWarning(
          errors::Internal("Invalid variable reference."));
      return nullptr;
    }
  }
  return ctx->input_ref_mutex(input);
}

// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
// in address order, to mitigate deadlock. Returns a structure that, when
// deleted, will release the acquired mutexes. It is safe to pass duplicate
// inputs: each distinct mutex is locked only once. If `sparse` is true, the
// variables are switched to copy-on-read mode before the locks are acquired.
// If `do_lock` is false and all inputs are reference variables, this returns
// immediately without locking; for resource variables in copy-on-read mode it
// grabs a shared lock if `do_lock` is false, and an exclusive lock otherwise.
// Note that this silently does not lock mutexes for invalid variable
// references; in all usages this is followed by GetInputTensorFromVariable,
// which will signal a failure.
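//
// A typical call from a training kernel's Compute() might look like the
// following sketch (hypothetical; `use_exclusive_lock_`, the device/dtype, and
// the input index are assumptions, not part of this header):
//
//   auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//       ctx, use_exclusive_lock_, /*sparse=*/false, {0});
//   Tensor var;
//   OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
//                           ctx, 0, use_exclusive_lock_, /*sparse=*/false,
//                           &var));
//   // ... apply the dense update to `var` ...
//   // The mutexes are released when `locks` goes out of scope.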
template <typename Device, typename T>
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
    OpKernelContext* ctx, bool do_lock, bool sparse,
    const std::vector<int>& input_ids) {
  bool any_resource = false;
  for (auto i : input_ids) {
    if (ctx->input_dtype(i) == DT_RESOURCE) {
      any_resource = true;
      break;
    }
  }
  if (!do_lock && !any_resource) {
    return VariableInputLockHolder({}, {}, {});
  }
  std::vector<Var*> vars;
  std::vector<mutex*> mutexes;
  std::vector<int> acquire_order;
  for (auto input : input_ids) {
    Var* var;
    mutex* mutex =
        GetTrainingVariableMutex<Device, T>(ctx, input, sparse, &var);
    if (var) vars.push_back(var);
    // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
    if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
      acquire_order.push_back(mutexes.size());
      mutexes.push_back(mutex);
    }
  }
  std::sort(acquire_order.begin(), acquire_order.end(),
            [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });

  auto locks = std::make_unique<std::vector<mutex_lock>>();
  auto shared_locks = std::make_unique<std::vector<tf_shared_lock>>();
  locks->reserve(acquire_order.size());

  for (auto acquire : acquire_order) {
    mutex* mu = mutexes[acquire];
    if (mu != nullptr) {
      if (!sparse || do_lock) {
        locks->emplace_back(*mu);
      } else {
        shared_locks->emplace_back(*mu);
      }
    }
  }
  return VariableInputLockHolder(std::move(vars), std::move(locks),
                                 std::move(shared_locks));
}

// Forwards the variable at input index `input` to the ref output at index
// `output` if the input is a legacy reference variable; does nothing for
// resource variables.
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
                                     int output);

// This is for use with ResourceVariables to ensure *tensor has a
// reference count of 1 before you update it.
// REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held.
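//
// A minimal sketch (hypothetical caller; assumes `var` is a Var* for a float
// variable and `ctx` is the surrounding OpKernelContext):
//
//   mutex_lock ml(*var->mu());
//   TF_RETURN_IF_ERROR(PrepareToUpdateVariable<CPUDevice, float>(
//       ctx, var->tensor(), var->copy_on_read_mode.load()));
//   // The buffer behind var->tensor() is now uniquely owned and can be
//   // mutated in place.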
template <typename Device, typename T>
Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor,
                               bool copy_on_read_mode) {
  if (copy_on_read_mode || !tensor->RefCountIsOne()) {
    // Tensor's buffer is in use by some read, so we need to copy before
    // updating.
    Tensor tmp;
    if (std::is_same<T, Variant>::value) {
      AllocatorAttributes attr;
      attr.set_on_host(true);
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));

      const auto elements_in = tensor->flat<Variant>();
      auto elements_out = tmp.flat<Variant>();
      for (int64_t i = 0; i < elements_in.size(); ++i) {
        elements_out(i) = elements_in(i);
      }
    } else {
      AllocatorAttributes attr;
      attr.set_gpu_compatible(true);
      attr.set_nic_compatible(true);
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr));
      functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
      copy_functor(ctx->eigen_device<Device>(), tmp.flat<T>(),
                   const_cast<const Tensor*>(tensor)->flat<T>());
    }
    *tensor = tmp;
  }
  return OkStatus();
}

// This gives you `*out`, a tensor you can update, corresponding to a variable
// passed as input index `input`. This handles the differences between
// reference and resource variables. For reference variables we can just grab
// the tensor, grabbing the lock if `lock_held` is false.
//
// For resource variables, if `sparse` is true we first ensure the variable is
// in copy-on-read mode, and then, regardless of the value of `sparse`, ensure
// its refcount is 1 (by potentially copying its contents). For resource
// variables `lock_held` is ignored.
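//
// Sparse-update sketch (hypothetical; the device/dtype, input index, and the
// surrounding scatter-style kernel are assumptions, not part of this header):
//
//   auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
//       ctx, /*do_lock=*/true, /*sparse=*/true, {0});
//   Tensor params;
//   OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
//                           ctx, 0, /*lock_held=*/false, /*sparse=*/true,
//                           &params));
//   // `params` shares the variable's buffer, so scattered writes into it
//   // update the variable directly.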
template <typename Device, typename T>
Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
                                  bool lock_held, bool sparse, Tensor* out) {
  if (ctx->input_dtype(input) == DT_RESOURCE) {
    core::RefCountPtr<Var> var;
    TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
    if (sparse) {
      TF_RETURN_IF_ERROR(EnsureSparseVariableAccess<Device, T>(ctx, var.get()));
      *out = *var->tensor();
      return OkStatus();
    }
    TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(
        ctx, var->tensor(), var->copy_on_read_mode.load()));
    *out = *var->tensor();
    return OkStatus();
  }
  *out = ctx->mutable_input(input, lock_held);
  return OkStatus();
}

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_