1 | #include <c10/macros/Macros.h> |
2 | #include <c10/util/C++17.h> |
3 | #include <c10/util/Synchronized.h> |
4 | #include <array> |
5 | #include <atomic> |
6 | #include <functional> |
7 | #include <mutex> |
8 | #include <shared_mutex> |
9 | #include <thread> |
10 | |
11 | namespace c10 { |
12 | |
13 | namespace detail { |
14 | |
struct IncrementRAII final {
 public:
  // RAII guard: bumps *counter for exactly as long as this object is alive.
  // Used by LeftRight::read() to signal "a reader is active on this instance".
  explicit IncrementRAII(std::atomic<int32_t>* counter) : counter_(counter) {
    counter_->fetch_add(1);
  }

  // Copying would cause the counter to be decremented more often than it
  // was incremented, so it is disallowed.
  IncrementRAII(const IncrementRAII&) = delete;
  IncrementRAII& operator=(const IncrementRAII&) = delete;

  ~IncrementRAII() {
    counter_->fetch_sub(1);
  }

 private:
  std::atomic<int32_t>* counter_; // not owned; must outlive this guard
};
30 | |
31 | } // namespace detail |
32 | |
33 | // LeftRight wait-free readers synchronization primitive |
34 | // https://hal.archives-ouvertes.fr/hal-01207881/document |
35 | // |
36 | // LeftRight is quite easy to use (it can make an arbitrary |
37 | // data structure permit wait-free reads), but it has some |
38 | // particular performance characteristics you should be aware |
39 | // of if you're deciding to use it: |
40 | // |
41 | // - Reads still incur an atomic write (this is how LeftRight |
42 | // keeps track of how long it needs to keep around the old |
43 | // data structure) |
44 | // |
45 | // - Writes get executed twice, to keep both the left and right |
46 | // versions up to date. So if your write is expensive or |
47 | // nondeterministic, this is also an inappropriate structure |
48 | // |
49 | // LeftRight is used fairly rarely in PyTorch's codebase. If you |
50 | // are still not sure if you need it or not, consult your local |
51 | // C++ expert. |
52 | // |
template <class T>
class LeftRight final {
 public:
  // Constructs both the left and the right copy from `args`.
  // NOTE: `args` is used twice, so T's constructor runs twice on the same
  // const-ref arguments — another reason nondeterministic construction (like
  // nondeterministic writes, see the class comment) is inappropriate here.
  template <class... Args>
  explicit LeftRight(const Args&... args)
      : _counters{{{0}, {0}}},
        _foregroundCounterIndex(0),
        _foregroundDataIndex(0),
        _data{{T{args...}, T{args...}}},
        _writeMutex() {}

  // Copying and moving would not be threadsafe.
  // Needs more thought and careful design to make that work.
  LeftRight(const LeftRight&) = delete;
  LeftRight(LeftRight&&) noexcept = delete;
  LeftRight& operator=(const LeftRight&) = delete;
  LeftRight& operator=(LeftRight&&) noexcept = delete;

  ~LeftRight() {
    // wait until any potentially running writers are finished
    // (acquiring and immediately releasing _writeMutex suffices: once we got
    // the lock, no writer can be inside _write() anymore)
    { std::unique_lock<std::mutex> lock(_writeMutex); }

    // wait until any potentially running readers are finished
    while (_counters[0].load() != 0 || _counters[1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Runs `readFunc` on the current foreground instance and returns its result.
  // Reads never block on _writeMutex; the only synchronization cost is the
  // atomic increment/decrement done by IncrementRAII.
  //
  // The counter index and the data index are loaded separately, so a
  // concurrent write may switch one between the two loads; the algorithm in
  // _write() is designed to tolerate that (see the step-3 comment there).
  //
  // NOTE(review): calling write() on this same LeftRight from inside
  // `readFunc` would deadlock — write() waits for the reader counters to
  // drain while this reader is still holding its counter.
  template <typename F>
  auto read(F&& readFunc) const -> typename c10::invoke_result_t<F, const T&> {
    detail::IncrementRAII _increment_counter(
        &_counters[_foregroundCounterIndex.load()]);

    return readFunc(_data[_foregroundDataIndex.load()]);
  }

  // Runs `writeFunc` twice, once on each copy (see class comment), while
  // holding _writeMutex so writers are serialized. Readers are never blocked.
  // Throwing an exception in writeFunc is ok but causes the state to be either
  // the old or the new state, depending on if the first or the second call to
  // writeFunc threw.
  template <typename F>
  auto write(F&& writeFunc) -> typename c10::invoke_result_t<F, T&> {
    std::unique_lock<std::mutex> lock(_writeMutex);

    return _write(writeFunc);
  }

 private:
  // Core write algorithm; caller must hold _writeMutex.
  template <class F>
  auto _write(const F& writeFunc) -> typename c10::invoke_result_t<F, T&> {
    /*
     * Assume, A is in background and B in foreground. In simplified terms, we
     * want to do the following:
     * 1. Write to A (old background)
     * 2. Switch A/B
     * 3. Write to B (new background)
     *
     * More detailed algorithm (explanations on why this is important are below
     * in code):
     * 1. Write to A
     * 2. Switch A/B data pointers
     * 3. Wait until A counter is zero
     * 4. Switch A/B counters
     * 5. Wait until B counter is zero
     * 6. Write to B
     */

    auto localDataIndex = _foregroundDataIndex.load();

    // 1. Write to A
    _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);

    // 2. Switch A/B data pointers
    localDataIndex = localDataIndex ^ 1;
    _foregroundDataIndex = localDataIndex;

    /*
     * 3. Wait until A counter is zero
     *
     * In the previous write run, A was foreground and B was background.
     * There was a time after switching _foregroundDataIndex (B to foreground)
     * and before switching _foregroundCounterIndex, in which new readers could
     * have read B but incremented A's counter.
     *
     * In this current run, we just switched _foregroundDataIndex (A back to
     * foreground), but before writing to the new background B, we have to make
     * sure A's counter was zero briefly, so all these old readers are gone.
     */
    auto localCounterIndex = _foregroundCounterIndex.load();
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    /*
     * 4. Switch A/B counters
     *
     * Now that we know all readers on B are really gone, we can switch the
     * counters and have new readers increment A's counter again, which is the
     * correct counter since they're reading A.
     */
    localCounterIndex = localCounterIndex ^ 1;
    _foregroundCounterIndex = localCounterIndex;

    /*
     * 5. Wait until B counter is zero
     *
     * This waits for all the readers on B that came in while both data and
     * counter for B was in foreground, i.e. normal readers that happened
     * outside of that brief gap between switching data and counter.
     */
    _waitForBackgroundCounterToBeZero(localCounterIndex);

    // 6. Write to B
    return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
  }

  // Invokes writeFunc on the background copy (the one readers are not
  // directed to, i.e. index localDataIndex ^ 1). If writeFunc throws, the
  // invariant "both copies are identical between writes" is restored by
  // copy-assigning the foreground copy over the possibly half-written
  // background copy before rethrowing.
  template <class F>
  auto _callWriteFuncOnBackgroundInstance(
      const F& writeFunc,
      uint8_t localDataIndex) -> typename c10::invoke_result_t<F, T&> {
    try {
      return writeFunc(_data[localDataIndex ^ 1]);
    } catch (...) {
      // recover invariant by copying from the foreground instance
      _data[localDataIndex ^ 1] = _data[localDataIndex];
      // rethrow
      throw;
    }
  }

  // Spins (yielding the thread) until no reader holds the background
  // counter (index counterIndex ^ 1) anymore.
  void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
    while (_counters[counterIndex ^ 1].load() != 0) {
      std::this_thread::yield();
    }
  }

  // Reader counts for the two copies. `mutable` because read() is const but
  // still has to bump a counter. All atomic accesses in this class use the
  // default (sequentially consistent) memory order.
  mutable std::array<std::atomic<int32_t>, 2> _counters;
  // Index (0 or 1) into _counters that new readers should increment.
  std::atomic<uint8_t> _foregroundCounterIndex;
  // Index (0 or 1) into _data that readers should read.
  std::atomic<uint8_t> _foregroundDataIndex;
  // The two copies of the user's data structure.
  std::array<T, 2> _data;
  // Serializes writers; never acquired by readers.
  std::mutex _writeMutex;
};
192 | |
193 | // RWSafeLeftRightWrapper is API compatible with LeftRight and uses a |
194 | // read-write lock to protect T (data). |
195 | template <class T> |
196 | class RWSafeLeftRightWrapper final { |
197 | public: |
198 | template <class... Args> |
199 | explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {} |
200 | |
201 | // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight |
202 | // is not copyable or moveable. |
203 | RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete; |
204 | RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete; |
205 | RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete; |
206 | RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete; |
207 | |
208 | template <typename F> |
209 | auto read(F&& readFunc) const -> typename c10::invoke_result_t<F, const T&> { |
210 | return data_.withLock( |
211 | [&readFunc](T const& data) { return readFunc(data); }); |
212 | } |
213 | |
214 | template <typename F> |
215 | auto write(F&& writeFunc) -> typename c10::invoke_result_t<F, T&> { |
216 | return data_.withLock([&writeFunc](T& data) { return writeFunc(data); }); |
217 | } |
218 | |
219 | private: |
220 | c10::Synchronized<T> data_; |
221 | }; |
222 | |
223 | } // namespace c10 |
224 | |