1 | #pragma once |
2 | |
3 | #include <torch/csrc/distributed/rpc/message.h> |
4 | #include <torch/csrc/distributed/rpc/rpc_command_base.h> |
5 | #include <torch/csrc/distributed/rpc/types.h> |
6 | #include <torch/csrc/jit/runtime/operator.h> |
7 | #include <torch/csrc/jit/serialization/pickler.h> |
8 | #include <vector> |
9 | |
10 | namespace torch { |
11 | namespace distributed { |
12 | namespace rpc { |
13 | |
// Temporary solution for RRef operations.
// TODO: Remove all these messages and use rpc + registered functions instead.
16 | class TORCH_API RRefMessageBase : public RpcCommandBase { |
17 | public: |
18 | RRefMessageBase(const RRefId& rrefId, MessageType type) |
19 | : rrefId_(rrefId), type_(type) {} |
20 | |
21 | ~RRefMessageBase() override = default; |
22 | |
23 | const RRefId& rrefId(); |
24 | |
25 | protected: |
26 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
27 | const RRefId rrefId_; |
28 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
29 | const MessageType type_; |
30 | }; |
31 | |
32 | class TORCH_API ForkMessageBase : public RRefMessageBase { |
33 | public: |
34 | ForkMessageBase(const RRefId& rrefId, const ForkId& forkId, MessageType type) |
35 | : RRefMessageBase(rrefId, type), forkId_(forkId) {} |
36 | |
37 | const ForkId& forkId(); |
38 | |
39 | c10::intrusive_ptr<Message> toMessageImpl() && override; |
40 | static std::pair<RRefId, ForkId> fromMessage( |
41 | const Message& message, |
42 | MessageType type); |
43 | |
44 | protected: |
45 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
46 | const ForkId forkId_; |
47 | }; |
48 | |
49 | // UserRRef uses this message to fetch the remote RRef value from the owner. |
50 | class TORCH_API ScriptRRefFetchCall final : public RRefMessageBase { |
51 | public: |
52 | ScriptRRefFetchCall(worker_id_t fromWorkerId, const RRefId& rrefId) |
53 | : RRefMessageBase(rrefId, MessageType::SCRIPT_RREF_FETCH_CALL), |
54 | fromWorkerId_(fromWorkerId) {} |
55 | |
56 | inline worker_id_t fromWorkerId() const { |
57 | return fromWorkerId_; |
58 | } |
59 | |
60 | c10::intrusive_ptr<Message> toMessageImpl() && override; |
61 | static std::unique_ptr<ScriptRRefFetchCall> fromMessage( |
62 | const Message& message); |
63 | |
64 | private: |
65 | const worker_id_t fromWorkerId_; |
66 | }; |
67 | |
68 | class TORCH_API PythonRRefFetchCall final : public RRefMessageBase { |
69 | public: |
70 | PythonRRefFetchCall(worker_id_t fromWorkerId, const RRefId& rrefId) |
71 | : RRefMessageBase(rrefId, MessageType::PYTHON_RREF_FETCH_CALL), |
72 | fromWorkerId_(fromWorkerId) {} |
73 | |
74 | c10::intrusive_ptr<Message> toMessageImpl() && override; |
75 | static std::unique_ptr<PythonRRefFetchCall> fromMessage( |
76 | const Message& message); |
77 | |
78 | private: |
79 | const worker_id_t fromWorkerId_; |
80 | }; |
81 | |
82 | // OwnerRRef uses this message to send the RRef value to a remote UserRRef |
83 | class TORCH_API RRefFetchRet : public RpcCommandBase { |
84 | public: |
85 | RRefFetchRet(std::vector<at::IValue> values, MessageType type) |
86 | : values_(std::move(values)), type_(type) {} |
87 | |
88 | const std::vector<at::IValue>& values(); |
89 | c10::intrusive_ptr<Message> toMessageImpl() && override; |
90 | |
91 | private: |
92 | std::vector<at::IValue> values_; |
93 | const MessageType type_; |
94 | }; |
95 | |
96 | class TORCH_API ScriptRRefFetchRet final : public RRefFetchRet { |
97 | public: |
98 | explicit ScriptRRefFetchRet(std::vector<at::IValue> values) |
99 | : RRefFetchRet(std::move(values), MessageType::SCRIPT_RREF_FETCH_RET) {} |
100 | |
101 | static std::unique_ptr<ScriptRRefFetchRet> fromMessage( |
102 | const Message& message); |
103 | }; |
104 | |
105 | class TORCH_API PythonRRefFetchRet final : public RRefFetchRet { |
106 | public: |
107 | explicit PythonRRefFetchRet(std::vector<at::IValue> values) |
108 | : RRefFetchRet(std::move(values), MessageType::PYTHON_RREF_FETCH_RET) {} |
109 | |
110 | static std::unique_ptr<PythonRRefFetchRet> fromMessage( |
111 | const Message& message); |
112 | }; |
113 | |
114 | // UserRRef (regardless it's the creator or not) uses this message to notiify |
115 | // OwnerRRef on delete. |
116 | class TORCH_API RRefUserDelete final : public ForkMessageBase { |
117 | public: |
118 | RRefUserDelete(const RRefId& rrefId, const ForkId& forkId) |
119 | : ForkMessageBase(rrefId, forkId, MessageType::RREF_USER_DELETE) {} |
120 | |
121 | static std::unique_ptr<RRefUserDelete> fromMessage(const Message& message); |
122 | }; |
123 | |
124 | class TORCH_API RemoteRet final : public ForkMessageBase { |
125 | public: |
126 | RemoteRet(const RRefId& rrefId, const ForkId& forkId) |
127 | : ForkMessageBase(rrefId, forkId, MessageType::REMOTE_RET) {} |
128 | |
129 | static std::unique_ptr<RemoteRet> fromMessage(const Message& message); |
130 | }; |
131 | |
132 | // A child RRef uses this message to notify its parent that the child has been |
133 | // confirmed by the owner. |
134 | class TORCH_API RRefChildAccept final : public RpcCommandBase { |
135 | public: |
136 | explicit RRefChildAccept(const ForkId& forkId) : forkId_(forkId) {} |
137 | |
138 | const ForkId& forkId() const; |
139 | |
140 | c10::intrusive_ptr<Message> toMessageImpl() && override; |
141 | static std::unique_ptr<RRefChildAccept> fromMessage(const Message& message); |
142 | |
143 | private: |
144 | const ForkId forkId_; |
145 | }; |
146 | |
147 | // A child RRef uses this message to send a fork request to the owner. |
148 | class TORCH_API RRefForkRequest final : public ForkMessageBase { |
149 | public: |
150 | RRefForkRequest(const RRefId& rrefId, const ForkId& forkId) |
151 | : ForkMessageBase(rrefId, forkId, MessageType::RREF_FORK_REQUEST) {} |
152 | |
153 | static std::unique_ptr<RRefForkRequest> fromMessage(const Message& message); |
154 | }; |
155 | |
156 | class TORCH_API RRefAck final : public RpcCommandBase { |
157 | public: |
158 | RRefAck() = default; |
159 | |
160 | c10::intrusive_ptr<Message> toMessageImpl() && override; |
161 | static std::unique_ptr<RRefAck> fromMessage(const Message& message); |
162 | }; |
163 | |
164 | } // namespace rpc |
165 | } // namespace distributed |
166 | } // namespace torch |
167 | |