/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include <deque>
19#include <utility>
20
21#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22#include "tensorflow/core/framework/resource_mgr.h"
23#include "tensorflow/core/framework/shared_ptr_variant.h"
24#include "tensorflow/core/framework/variant.h"
25#include "tensorflow/core/framework/variant_encode_decode.h"
26#include "tensorflow/core/kernels/ops_util.h"
27#include "tensorflow/core/lib/core/errors.h"
28#include "tensorflow/core/lib/core/threadpool.h"
29#include "tensorflow/core/platform/macros.h"
30#include "tensorflow/core/platform/mutex.h"
31#include "tensorflow/core/platform/types.h"
32
33namespace tensorflow {
34
namespace {

// A ref-counted resource implementing an exclusive lock that is acquired
// asynchronously.  All lock acquisitions for one Mutex are funneled through a
// dedicated single-thread ThreadPool (1 /* num_threads */ below), so at most
// one waiter blocks on the condition variable at a time; the actual lock state
// is the (locked_, cv_, mu_) triple.
class Mutex : public ResourceBase {
 public:
  // Creates the mutex in the unlocked state.  `name` is used both for the
  // waiter thread's name suffix and for DebugString().
  explicit Mutex(OpKernelContext* c, const string& name)
      : locked_(false),
        thread_pool_(new thread::ThreadPool(
            c->env(), ThreadOptions(),
            strings::StrCat("mutex_lock_thread_", SanitizeThreadSuffix(name)),
            1 /* num_threads */, false /* low_latency_hint */)),
        name_(name) {
    VLOG(2) << "Creating mutex with name " << name << ": " << this;
  }

  string DebugString() const override {
    return strings::StrCat("Mutex ", name_);
  }

  // RAII token for a held lock.  Destroying the (last shared_ptr to the)
  // releaser marks the mutex unlocked and wakes any waiter blocked in
  // AcquireAsync.  Non-copyable so the lock cannot be released twice.
  class LockReleaser {
   public:
    explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {}

    LockReleaser(const LockReleaser&) = delete;
    LockReleaser& operator=(const LockReleaser&) = delete;

    virtual ~LockReleaser() {
      VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_;
      if (mutex_) {
        mutex_lock lock(mutex_->mu_);
        // Release the lock and wake the waiter thread (and any cancellation
        // callbacks) blocked on cv_.
        mutex_->locked_ = false;
        mutex_->cv_.notify_all();
        VLOG(3) << "Destroying LockReleaser " << this
                << ": sent notifications.";
      }
    }

   private:
    Mutex* mutex_;  // Not owned; the caller of AcquireAsync holds a ref.
  };

  // Variant-compatible shared handle to a LockReleaser; this is what gets
  // stored in the output variant tensor of MutexLock.
  typedef SharedPtrVariant<LockReleaser> SharedLockReleaser;

  // Asynchronously acquires the lock and then invokes `fn` with OkStatus and
  // a fresh SharedLockReleaser, or with a Cancelled status and a null releaser
  // if the op's cancellation manager fires first.
  void AcquireAsync(
      OpKernelContext* c,
      std::function<void(const Status& s, SharedLockReleaser lock)> fn) {
    CancellationManager* cm = c->cancellation_manager();
    CancellationToken token{};
    bool* cancelled = nullptr;
    if (cm) {
      // Heap-allocated so the cancellation callback and the scheduled closure
      // below can share it; ownership passes to whichever path deletes it.
      cancelled = new bool(false);  // TF_GUARDED_BY(mu_);
      token = cm->get_cancellation_token();
      const bool already_cancelled =
          !cm->RegisterCallback(token, [this, cancelled]() {
            mutex_lock lock(mu_);
            *cancelled = true;
            // Wake the waiter so it can observe *cancelled and bail out.
            cv_.notify_all();
          });
      if (already_cancelled) {
        delete cancelled;
        fn(errors::Cancelled("Lock acquisition cancelled."),
           SharedLockReleaser{nullptr});
        return;
      }
    }
    // Park the blocking wait on the mutex's dedicated thread so ComputeAsync
    // returns immediately.
    thread_pool_->Schedule(std::bind(
        [this, cm, cancelled,
         token](std::function<void(const Status& s, SharedLockReleaser&& lock)>
                    fn_) {
          bool local_locked;
          {
            mutex_lock lock(mu_);
            // Wait until the lock is free or the op was cancelled.
            while (locked_ && !(cancelled && *cancelled)) {
              cv_.wait(lock);
            }
            // Take the lock only if we were not cancelled.
            local_locked = locked_ = !(cancelled && *cancelled);
          }
          if (cm) {
            // Deregister *before* deleting `cancelled`: DeregisterCallback
            // blocks until any in-flight callback (which dereferences
            // `cancelled`) has finished.
            cm->DeregisterCallback(token);
            delete cancelled;
          }
          if (local_locked) {  // Not cancelled.
            fn_(OkStatus(),
                SharedLockReleaser{std::make_shared<LockReleaser>(this)});
          } else {
            fn_(errors::Cancelled("Lock acquisition cancelled."),
                SharedLockReleaser{nullptr});
          }
        },
        std::move(fn)));
  }

 private:
  mutex mu_;
  condition_variable cv_ TF_GUARDED_BY(mu_);
  bool locked_ TF_GUARDED_BY(mu_);
  // Single thread that serializes all blocking lock waits for this mutex.
  std::unique_ptr<thread::ThreadPool> thread_pool_;
  string name_;
};

}  // namespace
135
136class MutexLockOp : public AsyncOpKernel {
137 public:
138 explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {}
139
140 public:
141 void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
142 Mutex* mutex = nullptr;
143 OP_REQUIRES_OK_ASYNC(
144 c,
145 LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex,
146 [c](Mutex** ptr) {
147 *ptr = new Mutex(
148 c, HandleFromInput(c, 0).name());
149 return OkStatus();
150 }),
151 done);
152
153 Tensor* variant;
154 OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant),
155 done);
156
157 mutex->AcquireAsync(
158 c, std::bind(
159 [c, variant, mutex](DoneCallback done_,
160 // End of bound arguments.
161 const Status& s,
162 Mutex::SharedLockReleaser&& lock) {
163 VLOG(2) << "Finished locking mutex " << mutex
164 << " with lock: " << lock.shared_ptr.get()
165 << " status: " << s.ToString();
166 if (s.ok()) {
167 variant->scalar<Variant>()() = std::move(lock);
168 } else {
169 c->SetStatus(s);
170 }
171 mutex->Unref();
172 done_();
173 },
174 std::move(done), std::placeholders::_1, std::placeholders::_2));
175 }
176};
177
178class ConsumeMutexLockOp : public OpKernel {
179 public:
180 explicit ConsumeMutexLockOp(OpKernelConstruction* context)
181 : OpKernel(context) {}
182
183 void Compute(OpKernelContext* c) override {
184 VLOG(2) << "Executing ConsumeMutexLockOp";
185 const Tensor& lock_t = c->input(0);
186 OP_REQUIRES(
187 c, lock_t.dims() == 0,
188 errors::InvalidArgument("Expected input to be a scalar, saw shape: ",
189 lock_t.shape().DebugString()));
190 OP_REQUIRES(
191 c, lock_t.dtype() == DT_VARIANT,
192 errors::InvalidArgument("Expected input to be a variant, saw type: ",
193 DataTypeString(lock_t.dtype())));
194 const auto* lock =
195 lock_t.scalar<Variant>()().get<Mutex::SharedLockReleaser>();
196 OP_REQUIRES(c, lock,
197 errors::InvalidArgument(
198 "Expected input to contain a SharedLockReleaser "
199 "object, but saw variant: '",
200 lock_t.scalar<Variant>()().DebugString(), "'"));
201 const int use_count = lock->shared_ptr.use_count();
202 OP_REQUIRES(
203 c, use_count == 1,
204 errors::InvalidArgument("Expected use count of lock to be 1, but saw: ",
205 use_count));
206 }
207
208 bool IsExpensive() override { return false; }
209};
210
// Kernel registrations.  The DEVICE_DEFAULT variants pin the resource handle
// and the variant lock tensor to host memory via HostMemory(...), since both
// are CPU-resident values regardless of the op's placement.
REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp);

REGISTER_KERNEL_BUILDER(Name("MutexLock")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("mutex_lock")
                            .HostMemory("mutex"),
                        MutexLockOp);

// "MutexV2" simply produces the resource handle for a Mutex resource.
REGISTER_KERNEL_BUILDER(
    Name("MutexV2").Device(DEVICE_CPU).HostMemory("resource"),
    ResourceHandleOp<Mutex>);

REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_DEFAULT),
                        ResourceHandleOp<Mutex>);

REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU),
                        ConsumeMutexLockOp);

REGISTER_KERNEL_BUILDER(
    Name("ConsumeMutexLock").Device(DEVICE_DEFAULT).HostMemory("mutex_lock"),
    ConsumeMutexLockOp);

}  // namespace tensorflow
234