1#pragma once
2
3#include <torch/csrc/distributed/rpc/script_call.h>
4#include <torch/csrc/distributed/rpc/types.h>
5#include <torch/csrc/jit/runtime/operator.h>
6#include <torch/csrc/jit/serialization/pickler.h>
7#include <vector>
8
9namespace torch {
10namespace distributed {
11namespace rpc {
12
13using torch::jit::Operator;
14
// A ScriptRemoteCall instance represents an invocation of `dist.remote` on a
// builtin operator or a TorchScript function. Currently, it does not support
// using RRefs as arguments. Besides the operator (or function name) and a
// vector of arguments, ScriptRemoteCall also contains the RRefId and the
// ForkId of the return value RRef.
19class TORCH_API ScriptRemoteCall final : public ScriptCall {
20 public:
21 // Constructor for builitin operator call.
22 ScriptRemoteCall(
23 std::shared_ptr<Operator> op,
24 std::vector<at::IValue>&& stack,
25 const RRefId& retRRefId,
26 const ForkId& retForkId);
27
28 // Constructor for TorchScript function call.
29 ScriptRemoteCall(
30 const c10::QualifiedName& qualifiedName,
31 std::vector<at::IValue>&& stack,
32 const RRefId& retRRefId,
33 const ForkId& retForkId,
34 const bool isAsyncExecution);
35
36 inline const RRefId& retRRefId() const {
37 return retRRefId_;
38 }
39
40 inline const ForkId& retForkId() const {
41 return retForkId_;
42 }
43
44 static std::unique_ptr<ScriptRemoteCall> fromIValues(
45 std::vector<at::IValue>& ivalues);
46
47 c10::intrusive_ptr<Message> toMessageImpl() && override;
48 static std::unique_ptr<ScriptRemoteCall> fromMessage(const Message& message);
49
50 private:
51 const RRefId retRRefId_;
52 const ForkId retForkId_;
53};
54
55} // namespace rpc
56} // namespace distributed
57} // namespace torch
58