1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include <deque> |
19 | #include <utility> |
20 | |
21 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
22 | #include "tensorflow/core/framework/resource_mgr.h" |
23 | #include "tensorflow/core/framework/shared_ptr_variant.h" |
24 | #include "tensorflow/core/framework/variant.h" |
25 | #include "tensorflow/core/framework/variant_encode_decode.h" |
26 | #include "tensorflow/core/kernels/ops_util.h" |
27 | #include "tensorflow/core/lib/core/errors.h" |
28 | #include "tensorflow/core/lib/core/threadpool.h" |
29 | #include "tensorflow/core/platform/macros.h" |
30 | #include "tensorflow/core/platform/mutex.h" |
31 | #include "tensorflow/core/platform/types.h" |
32 | |
33 | namespace tensorflow { |
34 | |
35 | namespace { |
36 | |
37 | class Mutex : public ResourceBase { |
38 | public: |
39 | explicit Mutex(OpKernelContext* c, const string& name) |
40 | : locked_(false), |
41 | thread_pool_(new thread::ThreadPool( |
42 | c->env(), ThreadOptions(), |
43 | strings::StrCat("mutex_lock_thread_" , SanitizeThreadSuffix(name)), |
44 | 1 /* num_threads */, false /* low_latency_hint */)), |
45 | name_(name) { |
46 | VLOG(2) << "Creating mutex with name " << name << ": " << this; |
47 | } |
48 | |
49 | string DebugString() const override { |
50 | return strings::StrCat("Mutex " , name_); |
51 | } |
52 | |
53 | class LockReleaser { |
54 | public: |
55 | explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {} |
56 | |
57 | LockReleaser(const LockReleaser&) = delete; |
58 | LockReleaser& operator=(const LockReleaser&) = delete; |
59 | |
60 | virtual ~LockReleaser() { |
61 | VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_; |
62 | if (mutex_) { |
63 | mutex_lock lock(mutex_->mu_); |
64 | mutex_->locked_ = false; |
65 | mutex_->cv_.notify_all(); |
66 | VLOG(3) << "Destroying LockReleaser " << this |
67 | << ": sent notifications." ; |
68 | } |
69 | } |
70 | |
71 | private: |
72 | Mutex* mutex_; |
73 | }; |
74 | |
75 | typedef SharedPtrVariant<LockReleaser> SharedLockReleaser; |
76 | |
77 | void AcquireAsync( |
78 | OpKernelContext* c, |
79 | std::function<void(const Status& s, SharedLockReleaser lock)> fn) { |
80 | CancellationManager* cm = c->cancellation_manager(); |
81 | CancellationToken token{}; |
82 | bool* cancelled = nullptr; |
83 | if (cm) { |
84 | cancelled = new bool(false); // TF_GUARDED_BY(mu_); |
85 | token = cm->get_cancellation_token(); |
86 | const bool already_cancelled = |
87 | !cm->RegisterCallback(token, [this, cancelled]() { |
88 | mutex_lock lock(mu_); |
89 | *cancelled = true; |
90 | cv_.notify_all(); |
91 | }); |
92 | if (already_cancelled) { |
93 | delete cancelled; |
94 | fn(errors::Cancelled("Lock acquisition cancelled." ), |
95 | SharedLockReleaser{nullptr}); |
96 | return; |
97 | } |
98 | } |
99 | thread_pool_->Schedule(std::bind( |
100 | [this, cm, cancelled, |
101 | token](std::function<void(const Status& s, SharedLockReleaser&& lock)> |
102 | fn_) { |
103 | bool local_locked; |
104 | { |
105 | mutex_lock lock(mu_); |
106 | while (locked_ && !(cancelled && *cancelled)) { |
107 | cv_.wait(lock); |
108 | } |
109 | local_locked = locked_ = !(cancelled && *cancelled); |
110 | } |
111 | if (cm) { |
112 | cm->DeregisterCallback(token); |
113 | delete cancelled; |
114 | } |
115 | if (local_locked) { // Not cancelled. |
116 | fn_(OkStatus(), |
117 | SharedLockReleaser{std::make_shared<LockReleaser>(this)}); |
118 | } else { |
119 | fn_(errors::Cancelled("Lock acquisition cancelled." ), |
120 | SharedLockReleaser{nullptr}); |
121 | } |
122 | }, |
123 | std::move(fn))); |
124 | } |
125 | |
126 | private: |
127 | mutex mu_; |
128 | condition_variable cv_ TF_GUARDED_BY(mu_); |
129 | bool locked_ TF_GUARDED_BY(mu_); |
130 | std::unique_ptr<thread::ThreadPool> thread_pool_; |
131 | string name_; |
132 | }; |
133 | |
134 | } // namespace |
135 | |
136 | class MutexLockOp : public AsyncOpKernel { |
137 | public: |
138 | explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {} |
139 | |
140 | public: |
141 | void ComputeAsync(OpKernelContext* c, DoneCallback done) override { |
142 | Mutex* mutex = nullptr; |
143 | OP_REQUIRES_OK_ASYNC( |
144 | c, |
145 | LookupOrCreateResource<Mutex>(c, HandleFromInput(c, 0), &mutex, |
146 | [c](Mutex** ptr) { |
147 | *ptr = new Mutex( |
148 | c, HandleFromInput(c, 0).name()); |
149 | return OkStatus(); |
150 | }), |
151 | done); |
152 | |
153 | Tensor* variant; |
154 | OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant), |
155 | done); |
156 | |
157 | mutex->AcquireAsync( |
158 | c, std::bind( |
159 | [c, variant, mutex](DoneCallback done_, |
160 | // End of bound arguments. |
161 | const Status& s, |
162 | Mutex::SharedLockReleaser&& lock) { |
163 | VLOG(2) << "Finished locking mutex " << mutex |
164 | << " with lock: " << lock.shared_ptr.get() |
165 | << " status: " << s.ToString(); |
166 | if (s.ok()) { |
167 | variant->scalar<Variant>()() = std::move(lock); |
168 | } else { |
169 | c->SetStatus(s); |
170 | } |
171 | mutex->Unref(); |
172 | done_(); |
173 | }, |
174 | std::move(done), std::placeholders::_1, std::placeholders::_2)); |
175 | } |
176 | }; |
177 | |
178 | class ConsumeMutexLockOp : public OpKernel { |
179 | public: |
180 | explicit ConsumeMutexLockOp(OpKernelConstruction* context) |
181 | : OpKernel(context) {} |
182 | |
183 | void Compute(OpKernelContext* c) override { |
184 | VLOG(2) << "Executing ConsumeMutexLockOp" ; |
185 | const Tensor& lock_t = c->input(0); |
186 | OP_REQUIRES( |
187 | c, lock_t.dims() == 0, |
188 | errors::InvalidArgument("Expected input to be a scalar, saw shape: " , |
189 | lock_t.shape().DebugString())); |
190 | OP_REQUIRES( |
191 | c, lock_t.dtype() == DT_VARIANT, |
192 | errors::InvalidArgument("Expected input to be a variant, saw type: " , |
193 | DataTypeString(lock_t.dtype()))); |
194 | const auto* lock = |
195 | lock_t.scalar<Variant>()().get<Mutex::SharedLockReleaser>(); |
196 | OP_REQUIRES(c, lock, |
197 | errors::InvalidArgument( |
198 | "Expected input to contain a SharedLockReleaser " |
199 | "object, but saw variant: '" , |
200 | lock_t.scalar<Variant>()().DebugString(), "'" )); |
201 | const int use_count = lock->shared_ptr.use_count(); |
202 | OP_REQUIRES( |
203 | c, use_count == 1, |
204 | errors::InvalidArgument("Expected use count of lock to be 1, but saw: " , |
205 | use_count)); |
206 | } |
207 | |
208 | bool IsExpensive() override { return false; } |
209 | }; |
210 | |
211 | REGISTER_KERNEL_BUILDER(Name("MutexLock" ).Device(DEVICE_CPU), MutexLockOp); |
212 | |
213 | REGISTER_KERNEL_BUILDER(Name("MutexLock" ) |
214 | .Device(DEVICE_DEFAULT) |
215 | .HostMemory("mutex_lock" ) |
216 | .HostMemory("mutex" ), |
217 | MutexLockOp); |
218 | |
219 | REGISTER_KERNEL_BUILDER( |
220 | Name("MutexV2" ).Device(DEVICE_CPU).HostMemory("resource" ), |
221 | ResourceHandleOp<Mutex>); |
222 | |
223 | REGISTER_KERNEL_BUILDER(Name("MutexV2" ).Device(DEVICE_DEFAULT), |
224 | ResourceHandleOp<Mutex>); |
225 | |
226 | REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock" ).Device(DEVICE_CPU), |
227 | ConsumeMutexLockOp); |
228 | |
229 | REGISTER_KERNEL_BUILDER( |
230 | Name("ConsumeMutexLock" ).Device(DEVICE_DEFAULT).HostMemory("mutex_lock" ), |
231 | ConsumeMutexLockOp); |
232 | |
233 | } // namespace tensorflow |
234 | |