1 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
2 | #include <torch/csrc/distributed/rpc/script_call.h> |
3 | #include <torch/csrc/jit/serialization/pickle.h> |
4 | |
5 | namespace torch { |
6 | namespace distributed { |
7 | namespace rpc { |
8 | |
9 | const std::string ScriptCall::BUILTIN_OP_NAMESPACE_("torch.ops.aten." ); |
10 | const std::string ScriptCall::ATEN_PREFIX_("aten::" ); |
11 | |
12 | ScriptCall::ScriptCall( |
13 | std::shared_ptr<Operator> op, |
14 | std::vector<at::IValue>&& stack) |
15 | : op_(std::move(op)), stack_(stack), isAsyncExecution_(false) {} |
16 | |
17 | ScriptCall::ScriptCall( |
18 | const c10::QualifiedName& qualifiedName, |
19 | std::vector<at::IValue>&& stack, |
20 | const bool isAsyncExecution) |
21 | : qualifiedName_(qualifiedName), |
22 | stack_(stack), |
23 | isAsyncExecution_(isAsyncExecution) {} |
24 | |
25 | bool ScriptCall::hasOp() const { |
26 | return op_ ? true : false; |
27 | } |
28 | |
29 | std::shared_ptr<Operator> ScriptCall::op() const { |
30 | return *op_; |
31 | } |
32 | |
33 | bool ScriptCall::hasQualifiedName() const { |
34 | return qualifiedName_ ? true : false; |
35 | } |
36 | |
37 | const c10::QualifiedName& ScriptCall::qualifiedName() const { |
38 | return *qualifiedName_; |
39 | } |
40 | |
41 | const std::vector<at::IValue>& ScriptCall::stack() const { |
42 | return stack_; |
43 | } |
44 | |
45 | std::vector<at::IValue>& ScriptCall::stackRef() { |
46 | return stack_; |
47 | } |
48 | |
49 | void ScriptCall::toIValues(std::vector<at::IValue>& ivalues) const { |
50 | for (auto& value : stack_) { |
51 | ivalues.push_back(value); |
52 | } |
53 | |
54 | if (hasOp()) { |
55 | TORCH_CHECK( |
56 | !hasQualifiedName(), |
57 | "It is builtin operator call, qualifiedName_ should not be set." ); |
58 | // TODO: replace this with a real overload_name when FunctionSchema supports |
59 | // that. |
60 | ivalues.emplace_back(toString((*op_)->schema())); |
61 | // insert qualified name |
62 | auto opName = (*op_)->schema().name(); |
63 | TORCH_CHECK( |
64 | opName.find("::" ) == opName.rfind("::" ) && |
65 | opName.rfind(ATEN_PREFIX_) == 0, |
66 | "Unexpected operator name " , |
67 | opName); |
68 | // aten::add -> torch.ops.aten.add |
69 | opName.replace(0, ATEN_PREFIX_.length(), BUILTIN_OP_NAMESPACE_); |
70 | ivalues.emplace_back(std::move(opName)); |
71 | } else if (hasQualifiedName()) { |
72 | ivalues.emplace_back(isAsyncExecution()); |
73 | TORCH_CHECK( |
74 | !hasOp(), |
75 | "It is TorchScript function call, operator should not be set." ); |
76 | ivalues.emplace_back((*qualifiedName_).qualifiedName()); |
77 | } else { |
78 | TORCH_INTERNAL_ASSERT( |
79 | false, |
80 | "Either builtin operator or TorchScript function name should be set." ); |
81 | } |
82 | } |
83 | |
84 | std::unique_ptr<ScriptCall> ScriptCall::fromIValues( |
85 | std::vector<at::IValue>& ivalues) { |
86 | // Last element in the vector is always qualifiedName for both |
87 | // builitin operator and TorchScript function |
88 | // If the qualifiedName is not a builtin operator name, then treat it |
89 | // as TorchScript function name |
90 | const std::string& qualifiedName = ivalues.back().toStringRef(); |
91 | |
92 | if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) { |
93 | ivalues.pop_back(); |
94 | const std::string& str_schema = ivalues.back().toStringRef(); |
95 | auto op = matchOperator(str_schema); |
96 | |
97 | ivalues.pop_back(); |
98 | // remove str_schema from ivalues |
99 | return std::make_unique<ScriptCall>(op, std::move(ivalues)); |
100 | } else { |
101 | ivalues.pop_back(); |
102 | bool isAsyncExecution = ivalues.back().toBool(); |
103 | ivalues.pop_back(); |
104 | return std::make_unique<ScriptCall>( |
105 | c10::QualifiedName(qualifiedName), |
106 | std::move(ivalues), |
107 | isAsyncExecution); |
108 | } |
109 | } |
110 | |
111 | c10::intrusive_ptr<Message> ScriptCall::toMessageImpl() && { |
112 | std::vector<IValue> ivalues; |
113 | toIValues(ivalues); |
114 | |
115 | std::vector<torch::Tensor> tensor_table; |
116 | auto payload = jit::pickle( |
117 | c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table); |
118 | |
119 | return c10::make_intrusive<Message>( |
120 | std::move(payload), std::move(tensor_table), MessageType::SCRIPT_CALL); |
121 | } |
122 | |
123 | std::unique_ptr<ScriptCall> ScriptCall::fromMessage(const Message& message) { |
124 | auto payload = static_cast<const char*>(message.payload().data()); |
125 | auto payload_size = message.payload().size(); |
126 | auto value = jit::unpickle( |
127 | payload, |
128 | payload_size, |
129 | *RpcAgent::getCurrentRpcAgent()->getTypeResolver(), |
130 | message.tensors()); |
131 | |
132 | auto values = value.toTupleRef().elements().vec(); |
133 | return fromIValues(values); |
134 | } |
135 | |
136 | std::shared_ptr<Operator> ScriptCall::matchOperator( |
137 | const std::string& str_schema) { |
138 | // TODO: This is a temporary solution. We should pass enough information to |
139 | // allow deterministically matched to one operator. |
140 | |
141 | // extract symbol from the schema |
142 | auto schema = torch::jit::parseSchema(str_schema); |
143 | auto symbol = at::Symbol::fromQualString(schema.name()); |
144 | |
145 | for (auto op : torch::jit::getAllOperatorsFor(symbol)) { |
146 | if (toString(op->schema()) == str_schema) { |
147 | return op; |
148 | } |
149 | } |
150 | |
151 | TORCH_CHECK(false, "Cannot find matching operator for schema " , str_schema); |
152 | } |
153 | |
154 | } // namespace rpc |
155 | } // namespace distributed |
156 | } // namespace torch |
157 | |