1#include <c10/util/irange.h>
2#include "StoreTestCommon.hpp"
3
4#include <unistd.h>
5
6#include <iostream>
7#include <thread>
8
9#include <torch/csrc/distributed/c10d/HashStore.hpp>
10#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
11
// Timeout (in milliseconds) installed on the store right before reading a
// key that has been deleted, so that get() gives up quickly rather than
// waiting on whatever timeout the store was created with.
constexpr int64_t kShortStoreTimeoutMillis{100};
13
14void testGetSet(std::string prefix = "") {
15 // Basic set/get
16 {
17 auto hashStore = c10::make_intrusive<c10d::HashStore>();
18 c10d::PrefixStore store(prefix, hashStore);
19 c10d::test::set(store, "key0", "value0");
20 c10d::test::set(store, "key1", "value1");
21 c10d::test::set(store, "key2", "value2");
22 c10d::test::check(store, "key0", "value0");
23 c10d::test::check(store, "key1", "value1");
24 c10d::test::check(store, "key2", "value2");
25
26 // Check compareSet, does not check return value
27 c10d::test::compareSet(store, "key0", "wrongExpectedValue", "newValue");
28 c10d::test::check(store, "key0", "value0");
29 c10d::test::compareSet(store, "key0", "value0", "newValue");
30 c10d::test::check(store, "key0", "newValue");
31
32 auto numKeys = store.getNumKeys();
33 EXPECT_EQ(numKeys, 3);
34 auto delSuccess = store.deleteKey("key0");
35 EXPECT_TRUE(delSuccess);
36 numKeys = store.getNumKeys();
37 EXPECT_EQ(numKeys, 2);
38 auto delFailure = store.deleteKey("badKeyName");
39 EXPECT_FALSE(delFailure);
40 auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
41 store.setTimeout(timeout);
42 EXPECT_THROW(store.get("key0"), std::runtime_error);
43 }
44
45 // get() waits up to timeout_.
46 {
47 auto hashStore = c10::make_intrusive<c10d::HashStore>();
48 c10d::PrefixStore store(prefix, hashStore);
49 std::thread th([&]() { c10d::test::set(store, "key0", "value0"); });
50 c10d::test::check(store, "key0", "value0");
51 th.join();
52 }
53}
54
55void stressTestStore(std::string prefix = "") {
56 // Hammer on HashStore::add
57 const auto numThreads = 4;
58 const auto numIterations = 100;
59
60 std::vector<std::thread> threads;
61 c10d::test::Semaphore sem1, sem2;
62 auto hashStore = c10::make_intrusive<c10d::HashStore>();
63 c10d::PrefixStore store(prefix, hashStore);
64
65 for (C10_UNUSED const auto i : c10::irange(numThreads)) {
66 threads.emplace_back(std::thread([&] {
67 sem1.post();
68 sem2.wait();
69 for (C10_UNUSED const auto j : c10::irange(numIterations)) {
70 store.add("counter", 1);
71 }
72 }));
73 }
74
75 sem1.wait(numThreads);
76 sem2.post(numThreads);
77
78 for (auto& thread : threads) {
79 thread.join();
80 }
81 std::string expected = std::to_string(numThreads * numIterations);
82 c10d::test::check(store, "counter", expected);
83}
84
85TEST(HashStoreTest, testGetAndSet) {
86 testGetSet();
87}
88
89TEST(HashStoreTest, testGetAndSetWithPrefix) {
90 testGetSet("testPrefix");
91}
92
93TEST(HashStoreTest, testStressStore) {
94 stressTestStore();
95}
96
97TEST(HashStoreTest, testStressStoreWithPrefix) {
98 stressTestStore("testPrefix");
99}
100