/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_TSL_PLATFORM_REFCOUNT_H_
#define TENSORFLOW_TSL_PLATFORM_REFCOUNT_H_
#include <atomic>
#include <functional>
#include <map>
#include <memory>

#include "tensorflow/tsl/platform/logging.h"
#include "tensorflow/tsl/platform/mutex.h"
#include "tensorflow/tsl/platform/thread_annotations.h"

namespace tsl {
namespace core {
class RefCounted {
 public:
  // Initial reference count is one.
  RefCounted();

  // Increments reference count by one.
  void Ref() const;

  // Decrements reference count by one. If the count remains
  // positive, returns false. When the count reaches zero, returns
  // true and deletes this, in which case the caller must not access
  // the object afterward.
  bool Unref() const;

  // Gets the current reference count.
  int_fast32_t RefCount() const;

  // Returns whether the reference count is one.
  // If the reference count is used in the conventional way, a
  // reference count of 1 implies that the current thread owns the
  // reference and no other thread shares it.
  // This call performs the test for a reference count of one, and
  // performs the memory barrier needed for the owning thread
  // to act on the object, knowing that it has exclusive access to the
  // object.
  bool RefCountIsOne() const;

 protected:
  // Make destructor protected so that RefCounted objects cannot
  // be instantiated directly. Only subclasses can be instantiated.
  virtual ~RefCounted();

  // Increments reference count by one if the object is not being destructed.
  // This function is used by WeakRefCounted for safely acquiring a
  // strong reference. It is only safe to call this as part of the weak
  // reference implementation.
  bool TryRef() const;

 private:
  mutable std::atomic_int_fast32_t ref_;

  RefCounted(const RefCounted&) = delete;
  void operator=(const RefCounted&) = delete;
};
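
// Illustrative sketch (not part of this header): typical strong-reference
// usage. `MyObject` is a hypothetical subclass introduced for illustration.
//
//   class MyObject : public RefCounted {
//    public:
//     void DoWork();
//   };
//
//   MyObject* obj = new MyObject;  // ref count starts at 1
//   obj->Ref();                    // ref count is 2
//   obj->Unref();                  // returns false; ref count back to 1
//   obj->DoWork();
//   obj->Unref();                  // returns true; obj deleted itself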

// A deleter class to form a std::unique_ptr that unrefs objects.
struct RefCountDeleter {
  void operator()(const RefCounted* o) const { o->Unref(); }
};

// A unique_ptr that unrefs the owned object on destruction.
template <typename T>
using RefCountPtr = std::unique_ptr<T, RefCountDeleter>;

// Helper class to unref an object when out-of-scope.
class ScopedUnref {
 public:
  explicit ScopedUnref(const RefCounted* o) : obj_(o) {}
  ~ScopedUnref() {
    if (obj_) obj_->Unref();
  }

 private:
  const RefCounted* obj_;

  ScopedUnref(const ScopedUnref&) = delete;
  void operator=(const ScopedUnref&) = delete;
};
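
// Illustrative sketch (not part of this header): both helpers release a
// reference automatically. `MyObject` is the hypothetical RefCounted
// subclass from the example above.
//
//   RefCountPtr<MyObject> ptr(new MyObject);  // owns the initial reference
//   ptr->DoWork();
//   // leaving scope calls Unref() via RefCountDeleter
//
//   MyObject* raw = new MyObject;
//   ScopedUnref unref(raw);  // guarantees raw->Unref() at end of scope
//   raw->DoWork();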

// Forward declaration for friend class of WeakRefCounted.
template <typename T>
class WeakPtr;

// A WeakNotifyFn is called when the weakly referred object is being destroyed.
// The object may already be destructed when the call occurs. A WeakNotifyFn
// can be passed into WeakPtr at construction.
using WeakNotifyFn = std::function<void()>;

// A base class for RefCounted objects that allow weak references by WeakPtr.
// The WeakRefCounted object and every WeakPtr to it each hold a strong
// reference to a shared WeakRefData.
//
// If the WeakRefCounted is valid, WeakPtr::GetNewRef() returns a new strong
// reference to the WeakRefCounted.
// If the WeakRefCounted is being destructed, `WeakRefCounted::ref_ == 0`;
// if the WeakRefCounted is already destructed, `WeakRefData::ptr == nullptr`.
// In either case, WeakPtr::GetNewRef() returns nullptr.
class WeakRefCounted : public RefCounted {
 public:
  int WeakRefCount() const {
    // Each weak ref owns one ref to data_, and *this owns the last one.
    return data_->RefCount() - 1;
  }

 protected:
  ~WeakRefCounted() override { data_->Notify(); }

 private:
  struct WeakRefData : public RefCounted {
    explicit WeakRefData(WeakRefCounted* ptr) : ptr(ptr), next_notifier_id(1) {}

    mutable mutex mu;
    WeakRefCounted* ptr TF_GUARDED_BY(mu);
    std::map<int, WeakNotifyFn> notifiers;
    int next_notifier_id;

    // Notifies WeakPtr instances that this object is being destructed.
    void Notify() {
      mutex_lock ml(mu);

      while (!notifiers.empty()) {
        auto iter = notifiers.begin();
        WeakNotifyFn notify_fn = std::move(iter->second);
        notifiers.erase(iter);

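        // Run notify_fn without holding mu: the callback may re-enter this
        // object (e.g. via WeakPtr methods that lock mu), which would
        // otherwise deadlock.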
        mu.unlock();
        notify_fn();
        mu.lock();
      }
      ptr = nullptr;
    }

    WeakRefCounted* GetNewRef() {
      mutex_lock ml(mu);
      if (ptr != nullptr && ptr->TryRef()) {
        return ptr;
      }
      return nullptr;
    }

    // Inserts notify_fn and returns a non-zero id.
    // Returns 0 if insertion fails because the object is already being
    // destroyed. 0 is also used by WeakPtr to represent "no notify_fn".
    int AddNotifier(WeakNotifyFn notify_fn) {
      mutex_lock ml(mu);
      if (ptr == nullptr) {
        return 0;
      }
      int notifier_id = next_notifier_id++;
      notifiers.emplace(notifier_id, std::move(notify_fn));
      return notifier_id;
    }

    void RemoveNotifier(int notifier_id) {
      mutex_lock ml(mu);
      notifiers.erase(notifier_id);
    }
  };

  RefCountPtr<WeakRefData> data_{new WeakRefData(this)};

  template <typename T>
  friend class WeakPtr;
  // MSVC14 workaround: in MSVC14, the access permission of a nested class
  // member is not treated like that of an ordinary member.
  friend struct WeakRefData;
};
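
// Illustrative sketch (not part of this header): a class opts in to weak
// references simply by deriving from WeakRefCounted. `MyResource` is a
// hypothetical subclass used in the WeakPtr example further below.
//
//   class MyResource : public WeakRefCounted {
//    public:
//     void Use();
//   };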

// A weak reference to a WeakRefCounted object. Refer to WeakRefCounted.
template <typename T>
class WeakPtr {
 public:
  // Creates a weak reference.
  // When the object is being destroyed, notify_fn is called.
  explicit WeakPtr(WeakRefCounted* ptr, WeakNotifyFn notify_fn = nullptr)
      : data_(nullptr), notifier_id_(0) {
    if (ptr != nullptr) {
      ptr->data_->Ref();
      data_.reset(ptr->data_.get());
      if (notify_fn) {
        notifier_id_ = data_->AddNotifier(notify_fn);
      }
    }
  }

  ~WeakPtr() {
    if (data_ != nullptr && notifier_id_ != 0) {
      data_->RemoveNotifier(notifier_id_);
    }
  }

  // NOTE(feyu): change data_ to an IntrusivePtr to make WeakPtr copyable.
  WeakPtr(const WeakPtr& other) = delete;
  WeakPtr& operator=(const WeakPtr& other) = delete;

  WeakPtr(WeakPtr&& other) {
    data_ = std::move(other.data_);
    notifier_id_ = other.notifier_id_;
    other.notifier_id_ = 0;
  }

  WeakPtr& operator=(WeakPtr&& other) {
    if (this != &other) {
      if (data_ != nullptr && notifier_id_ != 0) {
        data_->RemoveNotifier(notifier_id_);
      }
      data_ = std::move(other.data_);
      notifier_id_ = other.notifier_id_;
      other.notifier_id_ = 0;
    }
    return *this;
  }

  // Returns a new strong reference to the referred object, or nullptr if the
  // object is in an invalid state (being destructed or already destructed).
  RefCountPtr<T> GetNewRef() const {
    RefCountPtr<T> ref;
    if (data_ != nullptr) {
      WeakRefCounted* ptr = data_->GetNewRef();
      ref.reset(static_cast<T*>(ptr));
    }
    return ref;  // Plain return enables NRVO; std::move would inhibit it.
  }

 private:
  RefCountPtr<WeakRefCounted::WeakRefData> data_;
  int notifier_id_;
};
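
// Illustrative sketch (not part of this header): taking and using a weak
// reference. `MyResource` is the hypothetical WeakRefCounted subclass from
// the example above.
//
//   MyResource* res = new MyResource;  // strong ref count is 1
//   WeakPtr<MyResource> weak(res, /*notify_fn=*/[] {
//     // Called while res is being destroyed.
//   });
//
//   if (RefCountPtr<MyResource> strong = weak.GetNewRef()) {
//     strong->Use();  // the strong ref keeps the object alive in this scope
//   }
//
//   res->Unref();  // destroys res; weak.GetNewRef() now returns nullptr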

// Inlined routines, since these are performance critical
inline RefCounted::RefCounted() : ref_(1) {}

inline RefCounted::~RefCounted() {
  // A destructing object has ref_ == 0.
  // It is a bug if the object is resurrected (ref_ > 0) before delete is
  // called by Unref().
  DCHECK_EQ(ref_.load(), 0);
}

inline void RefCounted::Ref() const {
  // Ref() uses relaxed order because it is never called with old_ref == 0.
  // When old_ref >= 1, no actions depend on the new value of ref.
  int_fast32_t old_ref = ref_.fetch_add(1, std::memory_order_relaxed);
  DCHECK_GT(old_ref, 0);
}

inline bool RefCounted::TryRef() const {
  // This is not on a hot path.
  // Be conservative and use seq_cst to prevent racing with Unref() when
  // old_ref == 0, as done in LLVM libc++.
  int_fast32_t old_ref = ref_.load();
  while (old_ref != 0) {
    if (ref_.compare_exchange_weak(old_ref, old_ref + 1)) {
      return true;
    }
  }
  // Already destructing, cannot increase ref.
  return false;
}

inline bool RefCounted::Unref() const {
  DCHECK_GT(ref_.load(), 0);
  // acq_rel is used to prevent reordering that would allow the object to be
  // accessed after its destruction.
  // Using release alone would be a bug on systems where acq_rel differs from
  // release (e.g. ARM), according to Herb Sutter's 2012 "atomic<> Weapons"
  // talk.
  if (ref_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
    delete this;
    return true;
  }
  return false;
}

inline int_fast32_t RefCounted::RefCount() const {
  return ref_.load(std::memory_order_acquire);
}

inline bool RefCounted::RefCountIsOne() const {
  return (ref_.load(std::memory_order_acquire) == 1);
}

}  // namespace core
}  // namespace tsl

#endif  // TENSORFLOW_TSL_PLATFORM_REFCOUNT_H_