1 | // Copyright (c) 2011 The LevelDB Authors. All rights reserved. |
2 | // Use of this source code is governed by a BSD-style license that can be |
3 | // found in the LICENSE file. See the AUTHORS file for names of contributors. |
4 | |
5 | #include "db/skiplist.h" |
6 | |
7 | #include <atomic> |
8 | #include <set> |
9 | |
10 | #include "gtest/gtest.h" |
11 | #include "leveldb/env.h" |
12 | #include "port/port.h" |
13 | #include "port/thread_annotations.h" |
14 | #include "util/arena.h" |
15 | #include "util/hash.h" |
16 | #include "util/random.h" |
17 | #include "util/testutil.h" |
18 | |
19 | namespace leveldb { |
20 | |
// Keys stored in every skiplist in this file are plain 64-bit unsigned
// integers.
using Key = uint64_t;

// Three-way comparator used as the SkipList ordering: returns a negative
// value when a < b, zero when a == b, and a positive value when a > b.
struct Comparator {
  int operator()(const Key& a, const Key& b) const {
    if (a == b) return 0;
    return (a < b) ? -1 : +1;
  }
};
34 | |
35 | TEST(SkipTest, Empty) { |
36 | Arena arena; |
37 | Comparator cmp; |
38 | SkipList<Key, Comparator> list(cmp, &arena); |
39 | ASSERT_TRUE(!list.Contains(10)); |
40 | |
41 | SkipList<Key, Comparator>::Iterator iter(&list); |
42 | ASSERT_TRUE(!iter.Valid()); |
43 | iter.SeekToFirst(); |
44 | ASSERT_TRUE(!iter.Valid()); |
45 | iter.Seek(100); |
46 | ASSERT_TRUE(!iter.Valid()); |
47 | iter.SeekToLast(); |
48 | ASSERT_TRUE(!iter.Valid()); |
49 | } |
50 | |
51 | TEST(SkipTest, InsertAndLookup) { |
52 | const int N = 2000; |
53 | const int R = 5000; |
54 | Random rnd(1000); |
55 | std::set<Key> keys; |
56 | Arena arena; |
57 | Comparator cmp; |
58 | SkipList<Key, Comparator> list(cmp, &arena); |
59 | for (int i = 0; i < N; i++) { |
60 | Key key = rnd.Next() % R; |
61 | if (keys.insert(key).second) { |
62 | list.Insert(key); |
63 | } |
64 | } |
65 | |
66 | for (int i = 0; i < R; i++) { |
67 | if (list.Contains(i)) { |
68 | ASSERT_EQ(keys.count(i), 1); |
69 | } else { |
70 | ASSERT_EQ(keys.count(i), 0); |
71 | } |
72 | } |
73 | |
74 | // Simple iterator tests |
75 | { |
76 | SkipList<Key, Comparator>::Iterator iter(&list); |
77 | ASSERT_TRUE(!iter.Valid()); |
78 | |
79 | iter.Seek(0); |
80 | ASSERT_TRUE(iter.Valid()); |
81 | ASSERT_EQ(*(keys.begin()), iter.key()); |
82 | |
83 | iter.SeekToFirst(); |
84 | ASSERT_TRUE(iter.Valid()); |
85 | ASSERT_EQ(*(keys.begin()), iter.key()); |
86 | |
87 | iter.SeekToLast(); |
88 | ASSERT_TRUE(iter.Valid()); |
89 | ASSERT_EQ(*(keys.rbegin()), iter.key()); |
90 | } |
91 | |
92 | // Forward iteration test |
93 | for (int i = 0; i < R; i++) { |
94 | SkipList<Key, Comparator>::Iterator iter(&list); |
95 | iter.Seek(i); |
96 | |
97 | // Compare against model iterator |
98 | std::set<Key>::iterator model_iter = keys.lower_bound(i); |
99 | for (int j = 0; j < 3; j++) { |
100 | if (model_iter == keys.end()) { |
101 | ASSERT_TRUE(!iter.Valid()); |
102 | break; |
103 | } else { |
104 | ASSERT_TRUE(iter.Valid()); |
105 | ASSERT_EQ(*model_iter, iter.key()); |
106 | ++model_iter; |
107 | iter.Next(); |
108 | } |
109 | } |
110 | } |
111 | |
112 | // Backward iteration test |
113 | { |
114 | SkipList<Key, Comparator>::Iterator iter(&list); |
115 | iter.SeekToLast(); |
116 | |
117 | // Compare against model iterator |
118 | for (std::set<Key>::reverse_iterator model_iter = keys.rbegin(); |
119 | model_iter != keys.rend(); ++model_iter) { |
120 | ASSERT_TRUE(iter.Valid()); |
121 | ASSERT_EQ(*model_iter, iter.key()); |
122 | iter.Prev(); |
123 | } |
124 | ASSERT_TRUE(!iter.Valid()); |
125 | } |
126 | } |
127 | |
128 | // We want to make sure that with a single writer and multiple |
129 | // concurrent readers (with no synchronization other than when a |
130 | // reader's iterator is created), the reader always observes all the |
131 | // data that was present in the skip list when the iterator was |
132 | // constructed. Because insertions are happening concurrently, we may |
133 | // also observe new values that were inserted since the iterator was |
134 | // constructed, but we should never miss any values that were present |
135 | // at iterator construction time. |
136 | // |
137 | // We generate multi-part keys: |
138 | // <key,gen,hash> |
139 | // where: |
140 | // key is in range [0..K-1] |
141 | // gen is a generation number for key |
142 | // hash is hash(key,gen) |
143 | // |
144 | // The insertion code picks a random key, sets gen to be 1 + the last |
145 | // generation number inserted for that key, and sets hash to Hash(key,gen). |
146 | // |
147 | // At the beginning of a read, we snapshot the last inserted |
148 | // generation number for each key. We then iterate, including random |
149 | // calls to Next() and Seek(). For every key we encounter, we |
150 | // check that it is either expected given the initial snapshot or has |
151 | // been concurrently added since the iterator started. |
152 | class ConcurrentTest { |
153 | private: |
154 | static constexpr uint32_t K = 4; |
155 | |
156 | static uint64_t key(Key key) { return (key >> 40); } |
157 | static uint64_t gen(Key key) { return (key >> 8) & 0xffffffffu; } |
158 | static uint64_t hash(Key key) { return key & 0xff; } |
159 | |
160 | static uint64_t HashNumbers(uint64_t k, uint64_t g) { |
161 | uint64_t data[2] = {k, g}; |
162 | return Hash(reinterpret_cast<char*>(data), sizeof(data), 0); |
163 | } |
164 | |
165 | static Key MakeKey(uint64_t k, uint64_t g) { |
166 | static_assert(sizeof(Key) == sizeof(uint64_t), "" ); |
167 | assert(k <= K); // We sometimes pass K to seek to the end of the skiplist |
168 | assert(g <= 0xffffffffu); |
169 | return ((k << 40) | (g << 8) | (HashNumbers(k, g) & 0xff)); |
170 | } |
171 | |
172 | static bool IsValidKey(Key k) { |
173 | return hash(k) == (HashNumbers(key(k), gen(k)) & 0xff); |
174 | } |
175 | |
176 | static Key RandomTarget(Random* rnd) { |
177 | switch (rnd->Next() % 10) { |
178 | case 0: |
179 | // Seek to beginning |
180 | return MakeKey(0, 0); |
181 | case 1: |
182 | // Seek to end |
183 | return MakeKey(K, 0); |
184 | default: |
185 | // Seek to middle |
186 | return MakeKey(rnd->Next() % K, 0); |
187 | } |
188 | } |
189 | |
190 | // Per-key generation |
191 | struct State { |
192 | std::atomic<int> generation[K]; |
193 | void Set(int k, int v) { |
194 | generation[k].store(v, std::memory_order_release); |
195 | } |
196 | int Get(int k) { return generation[k].load(std::memory_order_acquire); } |
197 | |
198 | State() { |
199 | for (int k = 0; k < K; k++) { |
200 | Set(k, 0); |
201 | } |
202 | } |
203 | }; |
204 | |
205 | // Current state of the test |
206 | State current_; |
207 | |
208 | Arena arena_; |
209 | |
210 | // SkipList is not protected by mu_. We just use a single writer |
211 | // thread to modify it. |
212 | SkipList<Key, Comparator> list_; |
213 | |
214 | public: |
215 | ConcurrentTest() : list_(Comparator(), &arena_) {} |
216 | |
217 | // REQUIRES: External synchronization |
218 | void WriteStep(Random* rnd) { |
219 | const uint32_t k = rnd->Next() % K; |
220 | const intptr_t g = current_.Get(k) + 1; |
221 | const Key key = MakeKey(k, g); |
222 | list_.Insert(key); |
223 | current_.Set(k, g); |
224 | } |
225 | |
226 | void ReadStep(Random* rnd) { |
227 | // Remember the initial committed state of the skiplist. |
228 | State initial_state; |
229 | for (int k = 0; k < K; k++) { |
230 | initial_state.Set(k, current_.Get(k)); |
231 | } |
232 | |
233 | Key pos = RandomTarget(rnd); |
234 | SkipList<Key, Comparator>::Iterator iter(&list_); |
235 | iter.Seek(pos); |
236 | while (true) { |
237 | Key current; |
238 | if (!iter.Valid()) { |
239 | current = MakeKey(K, 0); |
240 | } else { |
241 | current = iter.key(); |
242 | ASSERT_TRUE(IsValidKey(current)) << current; |
243 | } |
244 | ASSERT_LE(pos, current) << "should not go backwards" ; |
245 | |
246 | // Verify that everything in [pos,current) was not present in |
247 | // initial_state. |
248 | while (pos < current) { |
249 | ASSERT_LT(key(pos), K) << pos; |
250 | |
251 | // Note that generation 0 is never inserted, so it is ok if |
252 | // <*,0,*> is missing. |
253 | ASSERT_TRUE((gen(pos) == 0) || |
254 | (gen(pos) > static_cast<Key>(initial_state.Get(key(pos))))) |
255 | << "key: " << key(pos) << "; gen: " << gen(pos) |
256 | << "; initgen: " << initial_state.Get(key(pos)); |
257 | |
258 | // Advance to next key in the valid key space |
259 | if (key(pos) < key(current)) { |
260 | pos = MakeKey(key(pos) + 1, 0); |
261 | } else { |
262 | pos = MakeKey(key(pos), gen(pos) + 1); |
263 | } |
264 | } |
265 | |
266 | if (!iter.Valid()) { |
267 | break; |
268 | } |
269 | |
270 | if (rnd->Next() % 2) { |
271 | iter.Next(); |
272 | pos = MakeKey(key(pos), gen(pos) + 1); |
273 | } else { |
274 | Key new_target = RandomTarget(rnd); |
275 | if (new_target > pos) { |
276 | pos = new_target; |
277 | iter.Seek(new_target); |
278 | } |
279 | } |
280 | } |
281 | } |
282 | }; |
283 | |
284 | // Needed when building in C++11 mode. |
285 | constexpr uint32_t ConcurrentTest::K; |
286 | |
287 | // Simple test that does single-threaded testing of the ConcurrentTest |
288 | // scaffolding. |
289 | TEST(SkipTest, ConcurrentWithoutThreads) { |
290 | ConcurrentTest test; |
291 | Random rnd(test::RandomSeed()); |
292 | for (int i = 0; i < 10000; i++) { |
293 | test.ReadStep(&rnd); |
294 | test.WriteStep(&rnd); |
295 | } |
296 | } |
297 | |
298 | class TestState { |
299 | public: |
300 | ConcurrentTest t_; |
301 | int seed_; |
302 | std::atomic<bool> quit_flag_; |
303 | |
304 | enum ReaderState { STARTING, RUNNING, DONE }; |
305 | |
306 | explicit TestState(int s) |
307 | : seed_(s), quit_flag_(false), state_(STARTING), state_cv_(&mu_) {} |
308 | |
309 | void Wait(ReaderState s) LOCKS_EXCLUDED(mu_) { |
310 | mu_.Lock(); |
311 | while (state_ != s) { |
312 | state_cv_.Wait(); |
313 | } |
314 | mu_.Unlock(); |
315 | } |
316 | |
317 | void Change(ReaderState s) LOCKS_EXCLUDED(mu_) { |
318 | mu_.Lock(); |
319 | state_ = s; |
320 | state_cv_.Signal(); |
321 | mu_.Unlock(); |
322 | } |
323 | |
324 | private: |
325 | port::Mutex mu_; |
326 | ReaderState state_ GUARDED_BY(mu_); |
327 | port::CondVar state_cv_ GUARDED_BY(mu_); |
328 | }; |
329 | |
330 | static void ConcurrentReader(void* arg) { |
331 | TestState* state = reinterpret_cast<TestState*>(arg); |
332 | Random rnd(state->seed_); |
333 | int64_t reads = 0; |
334 | state->Change(TestState::RUNNING); |
335 | while (!state->quit_flag_.load(std::memory_order_acquire)) { |
336 | state->t_.ReadStep(&rnd); |
337 | ++reads; |
338 | } |
339 | state->Change(TestState::DONE); |
340 | } |
341 | |
342 | static void RunConcurrent(int run) { |
343 | const int seed = test::RandomSeed() + (run * 100); |
344 | Random rnd(seed); |
345 | const int N = 1000; |
346 | const int kSize = 1000; |
347 | for (int i = 0; i < N; i++) { |
348 | if ((i % 100) == 0) { |
349 | std::fprintf(stderr, "Run %d of %d\n" , i, N); |
350 | } |
351 | TestState state(seed + 1); |
352 | Env::Default()->Schedule(ConcurrentReader, &state); |
353 | state.Wait(TestState::RUNNING); |
354 | for (int i = 0; i < kSize; i++) { |
355 | state.t_.WriteStep(&rnd); |
356 | } |
357 | state.quit_flag_.store(true, std::memory_order_release); |
358 | state.Wait(TestState::DONE); |
359 | } |
360 | } |
361 | |
362 | TEST(SkipTest, Concurrent1) { RunConcurrent(1); } |
363 | TEST(SkipTest, Concurrent2) { RunConcurrent(2); } |
364 | TEST(SkipTest, Concurrent3) { RunConcurrent(3); } |
365 | TEST(SkipTest, Concurrent4) { RunConcurrent(4); } |
366 | TEST(SkipTest, Concurrent5) { RunConcurrent(5); } |
367 | |
368 | } // namespace leveldb |
369 | |