1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_TSL_PLATFORM_REFCOUNT_H_ |
17 | #define TENSORFLOW_TSL_PLATFORM_REFCOUNT_H_ |
18 | |
#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <utility>

#include "tensorflow/tsl/platform/logging.h"
#include "tensorflow/tsl/platform/mutex.h"
#include "tensorflow/tsl/platform/thread_annotations.h"
26 | |
27 | namespace tsl { |
28 | namespace core { |
29 | |
// Base class for intrusively reference-counted objects. A newly constructed
// object has a reference count of one, owned by its creator; the object
// deletes itself when the count drops to zero via Unref().
class RefCounted {
 public:
  // Initial reference count is one.
  RefCounted();

  // Increments reference count by one.
  void Ref() const;

  // Decrements reference count by one. If the count remains
  // positive, returns false. When the count reaches zero, returns
  // true and deletes this, in which case the caller must not access
  // the object afterward.
  bool Unref() const;

  // Gets the current reference count.
  int_fast32_t RefCount() const;

  // Return whether the reference count is one.
  // If the reference count is used in the conventional way, a
  // reference count of 1 implies that the current thread owns the
  // reference and no other thread shares it.
  // This call performs the test for a reference count of one, and
  // performs the memory barrier needed for the owning thread
  // to act on the object, knowing that it has exclusive access to the
  // object.
  bool RefCountIsOne() const;

 protected:
  // Make destructor protected so that RefCounted objects cannot
  // be instantiated directly. Only subclasses can be instantiated.
  virtual ~RefCounted();

  // Increments reference count by one if the object is not being destructed.
  // This function is used by WeakRefCounted for securely acquiring a
  // strong reference. It is only safe to call this as part of the weak
  // reference implementation.
  bool TryRef() const;

 private:
  // mutable so that the const methods Ref()/Unref()/TryRef() can adjust the
  // count: the reference count is bookkeeping, not logical object state.
  mutable std::atomic_int_fast32_t ref_;

  RefCounted(const RefCounted&) = delete;
  void operator=(const RefCounted&) = delete;
};
74 | |
75 | // A deleter class to form a std::unique_ptr that unrefs objects. |
76 | struct RefCountDeleter { |
77 | void operator()(const RefCounted* o) const { o->Unref(); } |
78 | }; |
79 | |
// A unique_ptr that unrefs the owned object on destruction (instead of
// deleting it). Use it to hold a strong reference obtained via Ref().
template <typename T>
using RefCountPtr = std::unique_ptr<T, RefCountDeleter>;
83 | |
84 | // Helper class to unref an object when out-of-scope. |
85 | class ScopedUnref { |
86 | public: |
87 | explicit ScopedUnref(const RefCounted* o) : obj_(o) {} |
88 | ~ScopedUnref() { |
89 | if (obj_) obj_->Unref(); |
90 | } |
91 | |
92 | private: |
93 | const RefCounted* obj_; |
94 | |
95 | ScopedUnref(const ScopedUnref&) = delete; |
96 | void operator=(const ScopedUnref&) = delete; |
97 | }; |
98 | |
// Forward declaration for friend class of WeakRefCounted.
template <typename T>
class WeakPtr;

// A WeakNotifyFn is called when the weakly referred object is being destroyed.
// The object may already be destructed when the call occurs, so the callback
// must not dereference it. A WeakNotifyFn can be passed into WeakPtr at
// construction.
using WeakNotifyFn = std::function<void()>;
107 | |
// A base class for RefCounted objects that allow weak references by WeakPtr.
// WeakRefCounted and every WeakPtr to it, each holds a strong reference to a
// WeakRefData.
//
// If the WeakRefCounted is valid, WeakPtr::GetNewRef() returns a new strong
// reference to the WeakRefCounted.
// If the WeakRefCounted is being destructed, `WeakRefCounted::ref_ == 0`;
// if the WeakRefcounted is already destructed,`WeakRefData::ptr == nullptr`.
// In either case, WeakPtr::GetNewRef() returns a nullptr.
class WeakRefCounted : public RefCounted {
 public:
  // Returns the number of WeakPtrs currently pointing at this object.
  int WeakRefCount() const {
    // Each weak ref owns one ref to data_, and *this owns the last one.
    return data_->RefCount() - 1;
  }

 protected:
  // Fires all registered notifiers and marks data_ as dead, so subsequent
  // GetNewRef() calls observe ptr == nullptr.
  ~WeakRefCounted() override { data_->Notify(); }

 private:
  // State shared between a WeakRefCounted object and all WeakPtrs to it.
  // Each WeakPtr holds one strong reference to this WeakRefData and the
  // WeakRefCounted holds the last one, so it outlives the object itself.
  struct WeakRefData : public RefCounted {
    explicit WeakRefData(WeakRefCounted* ptr) : ptr(ptr), next_notifier_id(1) {}

    mutable mutex mu;
    // Back-pointer to the object; set to nullptr once destruction completes.
    WeakRefCounted* ptr TF_GUARDED_BY(mu);
    // Pending callbacks keyed by the id handed out from AddNotifier().
    std::map<int, WeakNotifyFn> notifiers;
    // Monotonic id generator; 0 is reserved to mean "no notifier".
    int next_notifier_id;

    // Notifies WeakPtr instances that this object is being destructed.
    void Notify() {
      mutex_lock ml(mu);

      while (!notifiers.empty()) {
        auto iter = notifiers.begin();
        WeakNotifyFn notify_fn = std::move(iter->second);
        notifiers.erase(iter);

        // Run the callback with mu released so that it may safely call
        // back into this WeakRefData (e.g. destroy a WeakPtr, which calls
        // RemoveNotifier) without self-deadlocking.
        mu.unlock();
        notify_fn();
        mu.lock();
      }
      ptr = nullptr;
    }

    // Returns a new strong reference to the object, or nullptr if it is
    // being destructed (TryRef fails) or already gone (ptr == nullptr).
    WeakRefCounted* GetNewRef() {
      mutex_lock ml(mu);
      if (ptr != nullptr && ptr->TryRef()) {
        return ptr;
      }
      return nullptr;
    }

    // Inserts notify_fn and returns a non-zero id.
    // Returns 0 if insertion fails due to the object is being destroyed.
    // 0 is also used by WeakPtr to represent "no notify_fn".
    int AddNotifier(WeakNotifyFn notify_fn) {
      mutex_lock ml(mu);
      if (ptr == nullptr) {
        return 0;
      }
      int notifier_id = next_notifier_id++;
      notifiers.emplace(notifier_id, std::move(notify_fn));
      return notifier_id;
    }

    // Unregisters a notifier previously returned by AddNotifier().
    // No-op if the notifier has already fired or been removed.
    void RemoveNotifier(int notifier_id) {
      mutex_lock ml(mu);
      notifiers.erase(notifier_id);
    }
  };

  RefCountPtr<WeakRefData> data_{new WeakRefData(this)};

  template <typename T>
  friend class WeakPtr;
  // MSVC14 workaround: access permission of a nested class member is not
  // treated as an ordinary member in MSVC14.
  friend struct WeakRefData;
};
187 | |
188 | // A weak reference to a WeakRefCounted object. Refer to WeakRefCounted. |
189 | template <typename T> |
190 | class WeakPtr { |
191 | public: |
192 | // Creates a weak reference. |
193 | // When the object is being destroyed, notify_fn is called. |
194 | explicit WeakPtr(WeakRefCounted* ptr, WeakNotifyFn notify_fn = nullptr) |
195 | : data_(nullptr), notifier_id_(0) { |
196 | if (ptr != nullptr) { |
197 | ptr->data_->Ref(); |
198 | data_.reset(ptr->data_.get()); |
199 | if (notify_fn) { |
200 | notifier_id_ = data_->AddNotifier(notify_fn); |
201 | } |
202 | } |
203 | } |
204 | |
205 | ~WeakPtr() { |
206 | if (data_ != nullptr && notifier_id_ != 0) { |
207 | data_->RemoveNotifier(notifier_id_); |
208 | } |
209 | } |
210 | |
211 | // NOTE(feyu): change data_ to a IntrusivePtr to make WeakPtr copyable. |
212 | WeakPtr(const WeakPtr& other) = delete; |
213 | WeakPtr& operator=(const WeakPtr& other) = delete; |
214 | |
215 | WeakPtr(WeakPtr&& other) { |
216 | data_ = std::move(other.data_); |
217 | notifier_id_ = other.notifier_id_; |
218 | other.notifier_id_ = 0; |
219 | } |
220 | |
221 | WeakPtr& operator=(WeakPtr&& other) { |
222 | if (this != &other) { |
223 | if (data_ != nullptr && notifier_id_ != 0) { |
224 | data_->RemoveNotifier(notifier_id_); |
225 | } |
226 | data_ = std::move(other.data_); |
227 | notifier_id_ = other.notifier_id_; |
228 | other.notifier_id_ = 0; |
229 | } |
230 | return *this; |
231 | } |
232 | |
233 | // Returns a new strong reference to the referred object, or nullptr if the |
234 | // object is in an invalid state (being destructed or already destructed). |
235 | RefCountPtr<T> GetNewRef() const { |
236 | RefCountPtr<T> ref; |
237 | if (data_ != nullptr) { |
238 | WeakRefCounted* ptr = data_->GetNewRef(); |
239 | ref.reset(static_cast<T*>(ptr)); |
240 | } |
241 | return std::move(ref); |
242 | } |
243 | |
244 | private: |
245 | RefCountPtr<WeakRefCounted::WeakRefData> data_; |
246 | int notifier_id_; |
247 | }; |
248 | |
// Inlined routines, since these are performance critical

// Starts at one: the creator implicitly owns the initial reference.
inline RefCounted::RefCounted() : ref_(1) {}
251 | |
inline RefCounted::~RefCounted() {
  // A destructing object has ref_ == 0: Unref() decrements the count to
  // zero before invoking `delete this`.
  // It is a bug if the object is resurrected (ref_ > 0) before delete is
  // called by Unref().
  DCHECK_EQ(ref_.load(), 0);
}
258 | |
259 | inline void RefCounted::Ref() const { |
260 | // Ref() uses relaxed order because it is never called with old_ref == 0. |
261 | // When old_ref >= 1, no actions depend on the new value of ref. |
262 | int_fast32_t old_ref = ref_.fetch_add(1, std::memory_order_relaxed); |
263 | DCHECK_GT(old_ref, 0); |
264 | } |
265 | |
266 | inline bool RefCounted::TryRef() const { |
267 | // This is not on a hot path. |
268 | // Be conservative and use seq_cst to prevent racing with Unref() when |
269 | // old_ref == 0, as done in LLVM libstdc++. |
270 | int_fast32_t old_ref = ref_.load(); |
271 | while (old_ref != 0) { |
272 | if (ref_.compare_exchange_weak(old_ref, old_ref + 1)) { |
273 | return true; |
274 | } |
275 | } |
276 | // Already destructing, cannot increase ref. |
277 | return false; |
278 | } |
279 | |
inline bool RefCounted::Unref() const {
  DCHECK_GT(ref_.load(), 0);
  // acq_rel is used to prevent reordering that would move accesses to the
  // object past its destruction: the release half publishes this thread's
  // writes before the count drops, and the acquire half makes all former
  // owners' writes visible to the thread that sees the count reach zero
  // and runs `delete this`.

  // Using release alone is a bug on systems where acq_rel differs from release
  // (e.g. arm), according to Herb Sutter's 2012 talk on "Atomic<> Weapons".
  if (ref_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
    delete this;
    return true;
  }
  return false;
}
293 | |
294 | inline int_fast32_t RefCounted::RefCount() const { |
295 | return ref_.load(std::memory_order_acquire); |
296 | } |
297 | |
298 | inline bool RefCounted::RefCountIsOne() const { |
299 | return (ref_.load(std::memory_order_acquire) == 1); |
300 | } |
301 | |
302 | } // namespace core |
303 | } // namespace tsl |
304 | |
305 | #endif // TENSORFLOW_TSL_PLATFORM_REFCOUNT_H_ |
306 | |