1 | #include <torch/csrc/distributed/c10d/HashStore.hpp> |
2 | |
3 | #include <unistd.h> |
4 | #include <cerrno> |
5 | #include <cstdint> |
6 | |
7 | #include <chrono> |
8 | #include <cstdio> |
9 | #include <system_error> |
10 | |
11 | #include <c10/util/Exception.h> |
12 | |
13 | namespace c10d { |
14 | |
15 | void HashStore::set(const std::string& key, const std::vector<uint8_t>& data) { |
16 | std::unique_lock<std::mutex> lock(m_); |
17 | map_[key] = data; |
18 | cv_.notify_all(); |
19 | } |
20 | |
21 | std::vector<uint8_t> HashStore::compareSet( |
22 | const std::string& key, |
23 | const std::vector<uint8_t>& expectedValue, |
24 | const std::vector<uint8_t>& desiredValue) { |
25 | std::unique_lock<std::mutex> lock(m_); |
26 | auto it = map_.find(key); |
27 | if ((it == map_.end() && expectedValue.empty()) || |
28 | (it != map_.end() && it->second == expectedValue)) { |
29 | // if the key does not exist and currentValue arg is empty or |
30 | // the key does exist and current value is what is expected, then set it |
31 | map_[key] = desiredValue; |
32 | cv_.notify_all(); |
33 | return desiredValue; |
34 | } else if (it == map_.end()) { |
35 | // if the key does not exist |
36 | return expectedValue; |
37 | } |
38 | // key exists but current value is not expected |
39 | return it->second; |
40 | } |
41 | |
42 | std::vector<uint8_t> HashStore::get(const std::string& key) { |
43 | std::unique_lock<std::mutex> lock(m_); |
44 | auto it = map_.find(key); |
45 | if (it != map_.end()) { |
46 | return it->second; |
47 | } |
48 | // Slow path: wait up to any timeout_. |
49 | auto pred = [&]() { return map_.find(key) != map_.end(); }; |
50 | if (timeout_ == kNoTimeout) { |
51 | cv_.wait(lock, pred); |
52 | } else { |
53 | if (!cv_.wait_for(lock, timeout_, pred)) { |
54 | throw std::system_error( |
55 | ETIMEDOUT, std::system_category(), "Wait timeout" ); |
56 | } |
57 | } |
58 | return map_[key]; |
59 | } |
60 | |
61 | void HashStore::wait( |
62 | const std::vector<std::string>& keys, |
63 | const std::chrono::milliseconds& timeout) { |
64 | const auto end = std::chrono::steady_clock::now() + timeout; |
65 | auto pred = [&]() { |
66 | auto done = true; |
67 | for (const auto& key : keys) { |
68 | if (map_.find(key) == map_.end()) { |
69 | done = false; |
70 | break; |
71 | } |
72 | } |
73 | return done; |
74 | }; |
75 | |
76 | std::unique_lock<std::mutex> lock(m_); |
77 | if (timeout == kNoTimeout) { |
78 | cv_.wait(lock, pred); |
79 | } else { |
80 | if (!cv_.wait_until(lock, end, pred)) { |
81 | throw std::system_error( |
82 | ETIMEDOUT, std::system_category(), "Wait timeout" ); |
83 | } |
84 | } |
85 | } |
86 | |
87 | int64_t HashStore::add(const std::string& key, int64_t i) { |
88 | std::unique_lock<std::mutex> lock(m_); |
89 | const auto& value = map_[key]; |
90 | int64_t ti = i; |
91 | if (!value.empty()) { |
92 | auto buf = reinterpret_cast<const char*>(value.data()); |
93 | auto len = value.size(); |
94 | ti += std::stoll(std::string(buf, len)); |
95 | } |
96 | |
97 | auto str = std::to_string(ti); |
98 | const uint8_t* strB = reinterpret_cast<const uint8_t*>(str.c_str()); |
99 | map_[key] = std::vector<uint8_t>(strB, strB + str.size()); |
100 | return ti; |
101 | } |
102 | |
103 | int64_t HashStore::getNumKeys() { |
104 | std::unique_lock<std::mutex> lock(m_); |
105 | return map_.size(); |
106 | } |
107 | |
108 | bool HashStore::deleteKey(const std::string& key) { |
109 | std::unique_lock<std::mutex> lock(m_); |
110 | auto numDeleted = map_.erase(key); |
111 | return (numDeleted == 1); |
112 | } |
113 | |
114 | bool HashStore::check(const std::vector<std::string>& keys) { |
115 | std::unique_lock<std::mutex> lock(m_); |
116 | for (const auto& key : keys) { |
117 | if (map_.find(key) == map_.end()) { |
118 | return false; |
119 | } |
120 | } |
121 | return true; |
122 | } |
123 | |
124 | } // namespace c10d |
125 | |