1 | #pragma once |
---|---|
2 | |
3 | #include <sys/types.h> |
4 | |
5 | #include <condition_variable> |
6 | #include <mutex> |
7 | #include <unordered_map> |
8 | |
9 | #include <torch/csrc/distributed/c10d/Store.hpp> |
10 | |
11 | namespace c10d { |
12 | |
13 | class TORCH_API HashStore : public Store { |
14 | public: |
15 | ~HashStore() override = default; |
16 | |
17 | void set(const std::string& key, const std::vector<uint8_t>& data) override; |
18 | |
19 | std::vector<uint8_t> compareSet( |
20 | const std::string& key, |
21 | const std::vector<uint8_t>& expectedValue, |
22 | const std::vector<uint8_t>& desiredValue) override; |
23 | |
24 | std::vector<uint8_t> get(const std::string& key) override; |
25 | |
26 | void wait(const std::vector<std::string>& keys) override { |
27 | wait(keys, Store::kDefaultTimeout); |
28 | } |
29 | |
30 | void wait( |
31 | const std::vector<std::string>& keys, |
32 | const std::chrono::milliseconds& timeout) override; |
33 | |
34 | int64_t add(const std::string& key, int64_t value) override; |
35 | |
36 | int64_t getNumKeys() override; |
37 | |
38 | bool check(const std::vector<std::string>& keys) override; |
39 | |
40 | bool deleteKey(const std::string& key) override; |
41 | |
42 | protected: |
43 | std::unordered_map<std::string, std::vector<uint8_t>> map_; |
44 | std::mutex m_; |
45 | std::condition_variable cv_; |
46 | }; |
47 | |
48 | } // namespace c10d |
49 |