1#include <torch/csrc/distributed/rpc/rpc_agent.h>
2#include <torch/csrc/distributed/rpc/script_call.h>
3#include <torch/csrc/jit/serialization/pickle.h>
4
5namespace torch {
6namespace distributed {
7namespace rpc {
8
9const std::string ScriptCall::BUILTIN_OP_NAMESPACE_("torch.ops.aten.");
10const std::string ScriptCall::ATEN_PREFIX_("aten::");
11
12ScriptCall::ScriptCall(
13 std::shared_ptr<Operator> op,
14 std::vector<at::IValue>&& stack)
15 : op_(std::move(op)), stack_(stack), isAsyncExecution_(false) {}
16
17ScriptCall::ScriptCall(
18 const c10::QualifiedName& qualifiedName,
19 std::vector<at::IValue>&& stack,
20 const bool isAsyncExecution)
21 : qualifiedName_(qualifiedName),
22 stack_(stack),
23 isAsyncExecution_(isAsyncExecution) {}
24
25bool ScriptCall::hasOp() const {
26 return op_ ? true : false;
27}
28
29std::shared_ptr<Operator> ScriptCall::op() const {
30 return *op_;
31}
32
33bool ScriptCall::hasQualifiedName() const {
34 return qualifiedName_ ? true : false;
35}
36
37const c10::QualifiedName& ScriptCall::qualifiedName() const {
38 return *qualifiedName_;
39}
40
41const std::vector<at::IValue>& ScriptCall::stack() const {
42 return stack_;
43}
44
45std::vector<at::IValue>& ScriptCall::stackRef() {
46 return stack_;
47}
48
49void ScriptCall::toIValues(std::vector<at::IValue>& ivalues) const {
50 for (auto& value : stack_) {
51 ivalues.push_back(value);
52 }
53
54 if (hasOp()) {
55 TORCH_CHECK(
56 !hasQualifiedName(),
57 "It is builtin operator call, qualifiedName_ should not be set.");
58 // TODO: replace this with a real overload_name when FunctionSchema supports
59 // that.
60 ivalues.emplace_back(toString((*op_)->schema()));
61 // insert qualified name
62 auto opName = (*op_)->schema().name();
63 TORCH_CHECK(
64 opName.find("::") == opName.rfind("::") &&
65 opName.rfind(ATEN_PREFIX_) == 0,
66 "Unexpected operator name ",
67 opName);
68 // aten::add -> torch.ops.aten.add
69 opName.replace(0, ATEN_PREFIX_.length(), BUILTIN_OP_NAMESPACE_);
70 ivalues.emplace_back(std::move(opName));
71 } else if (hasQualifiedName()) {
72 ivalues.emplace_back(isAsyncExecution());
73 TORCH_CHECK(
74 !hasOp(),
75 "It is TorchScript function call, operator should not be set.");
76 ivalues.emplace_back((*qualifiedName_).qualifiedName());
77 } else {
78 TORCH_INTERNAL_ASSERT(
79 false,
80 "Either builtin operator or TorchScript function name should be set.");
81 }
82}
83
84std::unique_ptr<ScriptCall> ScriptCall::fromIValues(
85 std::vector<at::IValue>& ivalues) {
86 // Last element in the vector is always qualifiedName for both
87 // builitin operator and TorchScript function
88 // If the qualifiedName is not a builtin operator name, then treat it
89 // as TorchScript function name
90 const std::string& qualifiedName = ivalues.back().toStringRef();
91
92 if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) {
93 ivalues.pop_back();
94 const std::string& str_schema = ivalues.back().toStringRef();
95 auto op = matchOperator(str_schema);
96
97 ivalues.pop_back();
98 // remove str_schema from ivalues
99 return std::make_unique<ScriptCall>(op, std::move(ivalues));
100 } else {
101 ivalues.pop_back();
102 bool isAsyncExecution = ivalues.back().toBool();
103 ivalues.pop_back();
104 return std::make_unique<ScriptCall>(
105 c10::QualifiedName(qualifiedName),
106 std::move(ivalues),
107 isAsyncExecution);
108 }
109}
110
111c10::intrusive_ptr<Message> ScriptCall::toMessageImpl() && {
112 std::vector<IValue> ivalues;
113 toIValues(ivalues);
114
115 std::vector<torch::Tensor> tensor_table;
116 auto payload = jit::pickle(
117 c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
118
119 return c10::make_intrusive<Message>(
120 std::move(payload), std::move(tensor_table), MessageType::SCRIPT_CALL);
121}
122
123std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) {
124 auto payload = static_cast<const char*>(message.payload().data());
125 auto payload_size = message.payload().size();
126 auto value = jit::unpickle(
127 payload,
128 payload_size,
129 *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
130 message.tensors());
131
132 auto values = value.toTupleRef().elements().vec();
133 return fromIValues(values);
134}
135
136std::shared_ptr<Operator> ScriptCall::matchOperator(
137 const std::string& str_schema) {
138 // TODO: This is a temporary solution. We should pass enough information to
139 // allow deterministically matched to one operator.
140
141 // extract symbol from the schema
142 auto schema = torch::jit::parseSchema(str_schema);
143 auto symbol = at::Symbol::fromQualString(schema.name());
144
145 for (auto op : torch::jit::getAllOperatorsFor(symbol)) {
146 if (toString(op->schema()) == str_schema) {
147 return op;
148 }
149 }
150
151 TORCH_CHECK(false, "Cannot find matching operator for schema ", str_schema);
152}
153
154} // namespace rpc
155} // namespace distributed
156} // namespace torch
157