1#pragma once
2
3#include <torch/csrc/distributed/c10d/Store.hpp>
4#include <memory>
5
6namespace c10d {
7
8class TORCH_API PrefixStore : public Store {
9 public:
10 explicit PrefixStore(
11 std::string prefix,
12 c10::intrusive_ptr<Store> store);
13
14 ~PrefixStore() override = default;
15
16 using Store::set;
17 void set(const std::string& key, const std::vector<uint8_t>& value) override;
18
19 using Store::compareSet;
20 std::vector<uint8_t> compareSet(
21 const std::string& key,
22 const std::vector<uint8_t>& expectedValue,
23 const std::vector<uint8_t>& desiredValue) override;
24
25 std::vector<uint8_t> get(const std::string& key) override;
26
27 int64_t add(const std::string& key, int64_t value) override;
28
29 bool deleteKey(const std::string& key) override;
30
31 int64_t getNumKeys() override;
32
33 bool check(const std::vector<std::string>& keys) override;
34
35 void wait(const std::vector<std::string>& keys) override;
36
37 void wait(
38 const std::vector<std::string>& keys,
39 const std::chrono::milliseconds& timeout) override;
40
41 const std::chrono::milliseconds& getTimeout() const noexcept override;
42
43 void setTimeout(const std::chrono::milliseconds& timeout) override;
44
45 void watchKey(const std::string& key, WatchKeyCallback callback) override;
46
47 c10::intrusive_ptr<Store> getUnderlyingStore();
48
49 protected:
50 std::string prefix_;
51 c10::intrusive_ptr<Store> store_;
52
53 std::string joinKey(const std::string& key);
54 std::vector<std::string> joinKeys(const std::vector<std::string>& keys);
55};
56
57} // namespace c10d
58