1#pragma once
2
3#include <c10/util/Optional.h>
4#include <torch/csrc/distributed/rpc/message.h>
5#include <torch/csrc/distributed/rpc/rpc_command_base.h>
6#include <torch/csrc/jit/runtime/operator.h>
7#include <torch/csrc/jit/serialization/pickler.h>
8#include <vector>
9
10namespace torch {
11namespace distributed {
12namespace rpc {
13
14using torch::jit::Operator;
15
16// A ScriptCall instance represents an invocation of a builtin operator for a
17// TorchScript function. If it is a builtin operator, it
18// contains a shared ptr to the `Operator` and a list of arguments.
19// If it is a TorchScript function, it contains a non empty qualifiedName string
20// to the TorchScript function schema name and a list of arguments.
21class TORCH_API ScriptCall : public RpcCommandBase {
22 public:
23 // Constructor for builitin operator call.
24 ScriptCall(std::shared_ptr<Operator> op, std::vector<at::IValue>&& stack);
25 // Constructor for TorchScript function call.
26 ScriptCall(
27 const c10::QualifiedName& qualifiedName,
28 std::vector<at::IValue>&& stack,
29 const bool isAsyncExecution = false);
30
31 bool hasOp() const;
32 std::shared_ptr<Operator> op() const;
33 bool hasQualifiedName() const;
34 const c10::QualifiedName& qualifiedName() const;
35 // return the argument stack of this builtin operator
36 const std::vector<at::IValue>& stack() const;
37 std::vector<at::IValue>& stackRef();
38 inline bool isAsyncExecution() const {
39 return isAsyncExecution_;
40 }
41
42 c10::intrusive_ptr<Message> toMessageImpl() && override;
43 static std::unique_ptr<ScriptCall> fromMessage(const Message& message);
44
45 ~ScriptCall() override = default;
46
47 protected:
48 virtual void toIValues(std::vector<at::IValue>& ivalues) const;
49 static std::unique_ptr<ScriptCall> fromIValues(
50 std::vector<at::IValue>& ivalues);
51
52 private:
53 // Given an operator symbol and a string schema, return the matched operator.
54 static std::shared_ptr<Operator> matchOperator(const std::string& str_schema);
55
56 static const std::string BUILTIN_OP_NAMESPACE_;
57 static const std::string ATEN_PREFIX_;
58
59 // This field has value if this ScriptCall represents invocation of a builtin
60 // operator.
61 c10::optional<std::shared_ptr<Operator>> op_;
62 // This field has non empty string if this ScriptCall represents invocation of
63 // an annotated torchscript function defined by users.
64 c10::optional<const c10::QualifiedName> qualifiedName_;
65 std::vector<at::IValue> stack_;
66 const bool isAsyncExecution_;
67};
68
69} // namespace rpc
70} // namespace distributed
71} // namespace torch
72