1 | #pragma once |
---|---|
2 | |
3 | #include <torch/csrc/distributed/rpc/script_call.h> |
4 | #include <torch/csrc/distributed/rpc/types.h> |
5 | #include <torch/csrc/jit/runtime/operator.h> |
6 | #include <torch/csrc/jit/serialization/pickler.h> |
7 | #include <vector> |
8 | |
9 | namespace torch { |
10 | namespace distributed { |
11 | namespace rpc { |
12 | |
13 | using torch::jit::Operator; |
14 | |
15 | // A ScriptRemoteCall instance represents an invocation of `dist.remote` on a |
16 | // builtin operator. Currently, it does not support using RRef as arguments yet. |
17 | // Besides the operator and a vector of arguments, ScriptRemoteCall also |
18 | // caontains the RRefId and the ForkId of the return value RRef. |
19 | class TORCH_API ScriptRemoteCall final : public ScriptCall { |
20 | public: |
21 | // Constructor for builitin operator call. |
22 | ScriptRemoteCall( |
23 | std::shared_ptr<Operator> op, |
24 | std::vector<at::IValue>&& stack, |
25 | const RRefId& retRRefId, |
26 | const ForkId& retForkId); |
27 | |
28 | // Constructor for TorchScript function call. |
29 | ScriptRemoteCall( |
30 | const c10::QualifiedName& qualifiedName, |
31 | std::vector<at::IValue>&& stack, |
32 | const RRefId& retRRefId, |
33 | const ForkId& retForkId, |
34 | const bool isAsyncExecution); |
35 | |
36 | inline const RRefId& retRRefId() const { |
37 | return retRRefId_; |
38 | } |
39 | |
40 | inline const ForkId& retForkId() const { |
41 | return retForkId_; |
42 | } |
43 | |
44 | static std::unique_ptr<ScriptRemoteCall> fromIValues( |
45 | std::vector<at::IValue>& ivalues); |
46 | |
47 | c10::intrusive_ptr<Message> toMessageImpl() && override; |
48 | static std::unique_ptr<ScriptRemoteCall> fromMessage(const Message& message); |
49 | |
50 | private: |
51 | const RRefId retRRefId_; |
52 | const ForkId retForkId_; |
53 | }; |
54 | |
55 | } // namespace rpc |
56 | } // namespace distributed |
57 | } // namespace torch |
58 |