1 | #pragma once |
2 | |
3 | #include <chrono> |
4 | #include <cstdint> |
5 | #include <stdexcept> |
6 | #include <string> |
7 | #include <vector> |
8 | |
9 | #include <c10/macros/Macros.h> |
10 | #include <torch/custom_class.h> |
11 | |
12 | namespace c10d { |
13 | |
14 | // callback function will be given arguments (optional<string> oldValue, |
15 | // optional<string> newValue) |
16 | using WatchKeyCallback = |
17 | std::function<void(c10::optional<std::string>, c10::optional<std::string>)>; |
18 | |
19 | class TORCH_API Store : public torch::CustomClassHolder { |
20 | public: |
21 | static constexpr std::chrono::milliseconds kDefaultTimeout = |
22 | std::chrono::seconds(300); |
23 | static constexpr std::chrono::milliseconds kNoTimeout = |
24 | std::chrono::milliseconds::zero(); |
25 | |
26 | Store() : timeout_(kDefaultTimeout) {} |
27 | |
28 | explicit Store(const std::chrono::milliseconds& timeout) |
29 | : timeout_(timeout) {} |
30 | |
31 | ~Store() override; |
32 | |
33 | void set(const std::string& key, const std::string& value); |
34 | |
35 | virtual void set( |
36 | const std::string& key, |
37 | const std::vector<uint8_t>& value) = 0; |
38 | |
39 | std::string compareSet( |
40 | const std::string& key, |
41 | const std::string& currentValue, |
42 | const std::string& newValue); |
43 | |
44 | virtual std::vector<uint8_t> compareSet( |
45 | const std::string& key, |
46 | const std::vector<uint8_t>& currentValue, |
47 | const std::vector<uint8_t>& newValue) { |
48 | TORCH_INTERNAL_ASSERT(false, "Not implemented." ); |
49 | } |
50 | |
51 | std::string get_to_str(const std::string& key); |
52 | |
53 | virtual std::vector<uint8_t> get(const std::string& key) = 0; |
54 | |
55 | virtual int64_t add(const std::string& key, int64_t value) = 0; |
56 | |
57 | virtual bool deleteKey(const std::string& key) = 0; |
58 | |
59 | virtual bool check(const std::vector<std::string>& keys) = 0; |
60 | |
61 | virtual int64_t getNumKeys() = 0; |
62 | |
63 | virtual void wait(const std::vector<std::string>& keys) = 0; |
64 | |
65 | virtual void wait( |
66 | const std::vector<std::string>& keys, |
67 | const std::chrono::milliseconds& timeout) = 0; |
68 | |
69 | virtual const std::chrono::milliseconds& getTimeout() const noexcept; |
70 | |
71 | virtual void setTimeout(const std::chrono::milliseconds& timeout); |
72 | |
73 | // watchKey() takes two arguments: key and callback function. The callback |
74 | // should be run whenever the key is changed (create, update, or delete). The |
75 | // callback function takes two parameters: currentValue and newValue, which |
76 | // are optional depending on how the key is changed. These key updates should |
77 | // trigger the callback as follows: |
78 | // CREATE: callback(c10::nullopt, newValue) // null currentValue |
79 | // UPDATE: callback(currentValue, newValue) |
80 | // DELETE: callback(currentValue, c10::nullopt) // null newValue |
81 | virtual void watchKey( |
82 | const std::string& /* unused */, |
83 | WatchKeyCallback /* unused */) { |
84 | TORCH_CHECK( |
85 | false, |
86 | "watchKey only implemented for TCPStore and PrefixStore that wraps TCPStore." ); |
87 | } |
88 | |
89 | protected: |
90 | std::chrono::milliseconds timeout_; |
91 | }; |
92 | |
93 | } // namespace c10d |
94 | |