1#include <c10/macros/Macros.h>
2#include <c10/util/C++17.h>
3#include <c10/util/Synchronized.h>
4#include <array>
5#include <atomic>
6#include <functional>
7#include <mutex>
8#include <shared_mutex>
9#include <thread>
10
11namespace c10 {
12
13namespace detail {
14
15struct IncrementRAII final {
16 public:
17 explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter(counter) {
18 _counter->fetch_add(1);
19 }
20
21 ~IncrementRAII() {
22 _counter->fetch_sub(1);
23 }
24
25 private:
26 std::atomic<int32_t>* _counter;
27
28 C10_DISABLE_COPY_AND_ASSIGN(IncrementRAII);
29};
30
31} // namespace detail
32
33// LeftRight wait-free readers synchronization primitive
34// https://hal.archives-ouvertes.fr/hal-01207881/document
35//
36// LeftRight is quite easy to use (it can make an arbitrary
37// data structure permit wait-free reads), but it has some
38// particular performance characteristics you should be aware
39// of if you're deciding to use it:
40//
41// - Reads still incur an atomic write (this is how LeftRight
42// keeps track of how long it needs to keep around the old
43// data structure)
44//
45// - Writes get executed twice, to keep both the left and right
46// versions up to date. So if your write is expensive or
47// nondeterministic, this is also an inappropriate structure
48//
49// LeftRight is used fairly rarely in PyTorch's codebase. If you
50// are still not sure if you need it or not, consult your local
51// C++ expert.
52//
template <class T>
class LeftRight final {
 public:
  // Constructs both the left and the right copy of T from args.
  // Note that args are expanded twice, so the arguments must be safe to
  // reuse for a second construction (e.g. not consumed by a move).
  template <class... Args>
  explicit LeftRight(const Args&... args)
      : _counters{{{0}, {0}}},
        _foregroundCounterIndex(0),
        _foregroundDataIndex(0),
        _data{{T{args...}, T{args...}}},
        _writeMutex() {}

  // Copying and moving would not be threadsafe.
  // Needs more thought and careful design to make that work.
  LeftRight(const LeftRight&) = delete;
  LeftRight(LeftRight&&) noexcept = delete;
  LeftRight& operator=(const LeftRight&) = delete;
  LeftRight& operator=(LeftRight&&) noexcept = delete;

  ~LeftRight() {
    // wait until any potentially running writers are finished
    { std::unique_lock<std::mutex> lock(_writeMutex); }

    // wait until any potentially running readers are finished
    while (_counters[0].load() != 0 || _counters[1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Runs readFunc against the current foreground copy and returns its
  // result. Readers never block on writers: the only synchronization cost
  // is the atomic increment/decrement of the foreground reader counter,
  // which tells a concurrent write() when the background copy is safe to
  // reuse.
  // NOTE(review): readFunc must not call write() on the same LeftRight
  // instance - write() waits for reader counters to drain, which
  // presumably would deadlock while this reader is registered; confirm
  // before relying on any reentrancy.
  template <typename F>
  auto read(F&& readFunc) const -> typename c10::invoke_result_t<F, const T&> {
    detail::IncrementRAII _increment_counter(
        &_counters[_foregroundCounterIndex.load()]);

    return readFunc(_data[_foregroundDataIndex.load()]);
  }

  // Throwing an exception in writeFunc is ok but causes the state to be either
  // the old or the new state, depending on if the first or the second call to
  // writeFunc threw.
  // Writers are serialized against each other via _writeMutex; writeFunc is
  // invoked twice (once per copy), so it should be deterministic.
  template <typename F>
  auto write(F&& writeFunc) -> typename c10::invoke_result_t<F, T&> {
    std::unique_lock<std::mutex> lock(_writeMutex);

    return _write(writeFunc);
  }

 private:
  // Core write protocol. Caller must hold _writeMutex.
  template <class F>
  auto _write(const F& writeFunc) -> typename c10::invoke_result_t<F, T&> {
    /*
     * Assume, A is in background and B in foreground. In simplified terms, we
     * want to do the following:
     * 1. Write to A (old background)
     * 2. Switch A/B
     * 3. Write to B (new background)
     *
     * More detailed algorithm (explanations on why this is important are below
     * in code):
     * 1. Write to A
     * 2. Switch A/B data pointers
     * 3. Wait until A counter is zero
     * 4. Switch A/B counters
     * 5. Wait until B counter is zero
     * 6. Write to B
     */

    auto localDataIndex = _foregroundDataIndex.load();

    // 1. Write to A
    _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);

    // 2. Switch A/B data pointers
    localDataIndex = localDataIndex ^ 1;
    _foregroundDataIndex = localDataIndex;

    /*
     * 3. Wait until A counter is zero
     *
     * In the previous write run, A was foreground and B was background.
     * There was a time after switching _foregroundDataIndex (B to foreground)
     * and before switching _foregroundCounterIndex, in which new readers could
     * have read B but incremented A's counter.
     *
     * In this current run, we just switched _foregroundDataIndex (A back to
     * foreground), but before writing to the new background B, we have to make
     * sure A's counter was zero briefly, so all these old readers are gone.
     */
    auto localCounterIndex = _foregroundCounterIndex.load();
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    /*
     * 4. Switch A/B counters
     *
     * Now that we know all readers on B are really gone, we can switch the
     * counters and have new readers increment A's counter again, which is the
     * correct counter since they're reading A.
     */
    localCounterIndex = localCounterIndex ^ 1;
    _foregroundCounterIndex = localCounterIndex;

    /*
     * 5. Wait until B counter is zero
     *
     * This waits for all the readers on B that came in while both data and
     * counter for B was in foreground, i.e. normal readers that happened
     * outside of that brief gap between switching data and counter.
     */
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    // 6. Write to B
    return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
  }

  // Runs writeFunc on the background copy (index localDataIndex ^ 1). If
  // writeFunc throws, the background copy is restored by copy-assignment
  // from the foreground copy so both copies stay consistent; this requires
  // T to be copy-assignable.
  template <class F>
  auto _callWriteFuncOnBackgroundInstance(
      const F& writeFunc,
      uint8_t localDataIndex) -> typename c10::invoke_result_t<F, T&> {
    try {
      return writeFunc(_data[localDataIndex ^ 1]);
    } catch (...) {
      // recover invariant by copying from the foreground instance
      _data[localDataIndex ^ 1] = _data[localDataIndex];
      // rethrow
      throw;
    }
  }

  // Spins (yielding the thread) until every reader registered on the
  // background counter (counterIndex ^ 1) has finished.
  void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
    while (_counters[counterIndex ^ 1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Per-copy reader counts. Mutable because read() is const but still has
  // to register the reader.
  mutable std::array<std::atomic<int32_t>, 2> _counters;
  // Which of the two counters new readers should increment.
  std::atomic<uint8_t> _foregroundCounterIndex;
  // Which of the two data copies readers should read from.
  std::atomic<uint8_t> _foregroundDataIndex;
  // The left and right copies of the protected data structure.
  std::array<T, 2> _data;
  // Serializes writers against each other (and the destructor).
  std::mutex _writeMutex;
};
192
193// RWSafeLeftRightWrapper is API compatible with LeftRight and uses a
194// read-write lock to protect T (data).
195template <class T>
196class RWSafeLeftRightWrapper final {
197 public:
198 template <class... Args>
199 explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {}
200
201 // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight
202 // is not copyable or moveable.
203 RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete;
204 RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete;
205 RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete;
206 RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete;
207
208 template <typename F>
209 auto read(F&& readFunc) const -> typename c10::invoke_result_t<F, const T&> {
210 return data_.withLock(
211 [&readFunc](T const& data) { return readFunc(data); });
212 }
213
214 template <typename F>
215 auto write(F&& writeFunc) -> typename c10::invoke_result_t<F, T&> {
216 return data_.withLock([&writeFunc](T& data) { return writeFunc(data); });
217 }
218
219 private:
220 c10::Synchronized<T> data_;
221};
222
223} // namespace c10
224