1 | #pragma once |
2 | |
3 | #include <c10/util/Optional.h> |
4 | #include <torch/csrc/distributed/rpc/message.h> |
5 | #include <torch/csrc/distributed/rpc/rpc_command_base.h> |
6 | #include <torch/csrc/jit/runtime/operator.h> |
7 | #include <torch/csrc/jit/serialization/pickler.h> |
8 | #include <vector> |
9 | |
10 | namespace torch { |
11 | namespace distributed { |
12 | namespace rpc { |
13 | |
14 | using torch::jit::Operator; |
15 | |
16 | // A ScriptCall instance represents an invocation of a builtin operator for a |
17 | // TorchScript function. If it is a builtin operator, it |
18 | // contains a shared ptr to the `Operator` and a list of arguments. |
19 | // If it is a TorchScript function, it contains a non empty qualifiedName string |
20 | // to the TorchScript function schema name and a list of arguments. |
21 | class TORCH_API ScriptCall : public RpcCommandBase { |
22 | public: |
23 | // Constructor for builitin operator call. |
24 | ScriptCall(std::shared_ptr<Operator> op, std::vector<at::IValue>&& stack); |
25 | // Constructor for TorchScript function call. |
26 | ScriptCall( |
27 | const c10::QualifiedName& qualifiedName, |
28 | std::vector<at::IValue>&& stack, |
29 | const bool isAsyncExecution = false); |
30 | |
31 | bool hasOp() const; |
32 | std::shared_ptr<Operator> op() const; |
33 | bool hasQualifiedName() const; |
34 | const c10::QualifiedName& qualifiedName() const; |
35 | // return the argument stack of this builtin operator |
36 | const std::vector<at::IValue>& stack() const; |
37 | std::vector<at::IValue>& stackRef(); |
38 | inline bool isAsyncExecution() const { |
39 | return isAsyncExecution_; |
40 | } |
41 | |
42 | c10::intrusive_ptr<Message> toMessageImpl() && override; |
43 | static std::unique_ptr<ScriptCall> fromMessage(const Message& message); |
44 | |
45 | ~ScriptCall() override = default; |
46 | |
47 | protected: |
48 | virtual void toIValues(std::vector<at::IValue>& ivalues) const; |
49 | static std::unique_ptr<ScriptCall> fromIValues( |
50 | std::vector<at::IValue>& ivalues); |
51 | |
52 | private: |
53 | // Given an operator symbol and a string schema, return the matched operator. |
54 | static std::shared_ptr<Operator> matchOperator(const std::string& str_schema); |
55 | |
56 | static const std::string BUILTIN_OP_NAMESPACE_; |
57 | static const std::string ATEN_PREFIX_; |
58 | |
59 | // This field has value if this ScriptCall represents invocation of a builtin |
60 | // operator. |
61 | c10::optional<std::shared_ptr<Operator>> op_; |
62 | // This field has non empty string if this ScriptCall represents invocation of |
63 | // an annotated torchscript function defined by users. |
64 | c10::optional<const c10::QualifiedName> qualifiedName_; |
65 | std::vector<at::IValue> stack_; |
66 | const bool isAsyncExecution_; |
67 | }; |
68 | |
69 | } // namespace rpc |
70 | } // namespace distributed |
71 | } // namespace torch |
72 | |