#include <torch/csrc/jit/passes/variadic_ops.h>

#include <algorithm>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
7
8namespace torch {
9namespace jit {
10
11namespace {
12
13std::vector<size_t> identifyListArgIndices(const c10::FunctionSchema& schema) {
14 std::vector<size_t> list_indices;
15 const auto& args = schema.arguments();
16 for (const auto i : c10::irange(args.size())) {
17 auto list_type = args[i].type()->castRaw<ListType>();
18 if (list_type && list_type->getElementType()->castRaw<TensorType>()) {
19 list_indices.push_back(i);
20 }
21 }
22 return list_indices;
23}
24
25bool isTensorListConstruct(Node* node) {
26 if (node->kind() != prim::ListConstruct) {
27 return false;
28 }
29 const auto type = node->output()->type()->castRaw<ListType>();
30 TORCH_CHECK(type != nullptr);
31 const auto& elem_type = type->getElementType();
32 return elem_type->castRaw<TensorType>();
33}
34
35class VariadicUpdater {
36 public:
37 VariadicUpdater(
38 std::shared_ptr<Graph> graph,
39 NodeKind op,
40 NodeKind variadic_op)
41 : graph_(std::move(graph)),
42 alias_db_(graph_),
43 op_(op),
44 variadic_op_(variadic_op) {}
45
46 bool run() {
47 collectOpNodes(graph_->block());
48 bool changed = false;
49 for (auto n : op_nodes_) {
50 changed |= replaceWithVariadicOp(n);
51 }
52 return changed;
53 }
54
55 private:
56 void recordSchema(Node* op_node) {
57 const auto& schema = op_node->schema();
58 auto it = schema_to_list_indices_.find(schema.name());
59 if (it == schema_to_list_indices_.end()) {
60 schema_to_list_indices_.emplace(
61 schema.overload_name(), identifyListArgIndices(schema));
62 }
63 }
64
65 const std::vector<size_t>& getListIndices(Node* op_node) const {
66 const auto& schema = op_node->schema();
67 auto it = schema_to_list_indices_.find(schema.overload_name());
68 TORCH_CHECK(it != schema_to_list_indices_.end());
69 return it->second;
70 }
71
72 void collectOpNodes(Block* block) {
73 for (auto node : block->nodes()) {
74 if (node->kind() == op_) {
75 op_nodes_.push_back(node);
76 recordSchema(node);
77 }
78 for (Block* b : node->blocks()) {
79 collectOpNodes(b);
80 }
81 }
82 }
83
84 bool allListInputsAreValid(Node* op_node) {
85 const size_t num_inputs = op_node->inputs().size();
86 for (const auto list_idx : getListIndices(op_node)) {
87 TORCH_CHECK(list_idx < num_inputs);
88 const auto list = op_node->input(list_idx)->node();
89 // We do not transform ops whose list input can not be moved to the
90 // position before op. This in turn implies that there is some mutation
91 // of the input list before op.
92 if (!isTensorListConstruct(list) ||
93 !alias_db_.couldMoveBeforeTopologically(list, op_node)) {
94 return false;
95 }
96 }
97 return true;
98 }
99
100 void insertAllInputsBetween(
101 std::vector<Value*>& inputs,
102 Node* node,
103 size_t start_idx,
104 size_t end_idx) const {
105 const size_t num_inputs = node->inputs().size();
106 TORCH_CHECK(start_idx <= end_idx && end_idx <= num_inputs);
107 inputs.insert(
108 inputs.end(),
109 node->inputs().begin() + start_idx,
110 node->inputs().begin() + end_idx);
111 }
112
113 void insertIntegerInput(std::vector<Value*>& inputs, size_t input) {
114 auto constant = graph_->create(prim::Constant);
115 constant->output()->setType(c10::IntType::get());
116 constant->i_(attr::value, input);
117 graph_->prependNode(constant);
118 inputs.push_back(constant->output());
119 }
120
121 void deleteOpNodeAndLists(Node* op_node) {
122 // Collect the lists before we destroy op_node
123 std::vector<Node*> lists;
124 const auto& list_indices = getListIndices(op_node);
125 lists.reserve(list_indices.size());
126 for (const size_t list_idx : list_indices) {
127 auto* list = op_node->input(list_idx)->node();
128 lists.push_back(list);
129 }
130
131 GRAPH_UPDATE("Deleting\n", *op_node);
132 op_node->destroy();
133 for (auto* list : lists) {
134 if (!list->hasUses()) {
135 GRAPH_UPDATE("Deleting\n", *list);
136 list->destroy();
137 }
138 }
139 }
140
141 bool replaceWithVariadicOp(Node* op_node) {
142 if (!allListInputsAreValid(op_node)) {
143 return false;
144 }
145
146 std::vector<Value*> inputs;
147 size_t cur_idx = 0;
148 std::vector<size_t> list_lens;
149 for (const size_t list_idx : getListIndices(op_node)) {
150 insertAllInputsBetween(inputs, op_node, cur_idx, list_idx);
151 const auto list = op_node->input(list_idx)->node();
152 const auto list_len = list->inputs().size();
153 list_lens.push_back(list_len);
154 insertAllInputsBetween(inputs, list, 0, list_len);
155 cur_idx = list_idx + 1;
156 }
157 insertAllInputsBetween(inputs, op_node, cur_idx, op_node->inputs().size());
158
159 // We insert these extra integers at the end of the argument list only if we
160 // have more than one variadic list (the information is redundant when there
161 // is only one list because the interpreter knows how many arguments there
162 // are).
163 if (list_lens.size() > 1) {
164 for (const size_t list_len : list_lens) {
165 insertIntegerInput(inputs, list_len);
166 }
167 }
168
169 auto var_op_node = op_node->owningGraph()->create(variadic_op_, inputs);
170 var_op_node->output()->setType(op_node->output()->type());
171 GRAPH_UPDATE("Adding\n", *var_op_node);
172 var_op_node->insertBefore(op_node);
173 GRAPH_UPDATE("Replacing\n", *op_node, "with\n", *var_op_node);
174 op_node->output()->replaceAllUsesWith(var_op_node->output());
175 deleteOpNodeAndLists(op_node);
176 return true;
177 }
178
179 std::shared_ptr<Graph> graph_;
180 std::vector<Node*> op_nodes_;
181
182 AliasDb alias_db_;
183
184 NodeKind op_;
185 NodeKind variadic_op_;
186
187 std::unordered_map<std::string, std::vector<size_t>> schema_to_list_indices_;
188};
189
190} // namespace
191
192bool UseVariadicOp(
193 const std::shared_ptr<Graph>& graph,
194 NodeKind op,
195 NodeKind variadic_op) {
196 const std::string pass_name = std::string("variadic ") + op.toQualString();
197 GRAPH_DUMP("Before " + pass_name, graph);
198 bool changed = VariadicUpdater(graph, op, variadic_op).run();
199 if (changed) {
200 ConstantPooling(graph);
201 GRAPH_DUMP("After " + pass_name, graph);
202 }
203 return changed;
204}
205
206bool RemoveListMutationAndUseVariadicOp(
207 const std::shared_ptr<Graph>& graph,
208 NodeKind op,
209 NodeKind variadic_op) {
210 bool changed_in_last_iter = true;
211 bool changed = false;
212 while (changed_in_last_iter) {
213 changed_in_last_iter = RemoveListMutation(graph);
214 changed_in_last_iter =
215 UseVariadicOp(graph, op, variadic_op) || changed_in_last_iter;
216 changed = changed || changed_in_last_iter;
217 }
218 return changed;
219}
220
221bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
222 return UseVariadicOp(graph, aten::cat, prim::VarConcat);
223}
224bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
225 return RemoveListMutationAndUseVariadicOp(graph, aten::cat, prim::VarConcat);
226}
227
228bool UseVariadicStack(const std::shared_ptr<Graph>& graph) {
229 return UseVariadicOp(graph, aten::stack, prim::VarStack);
230}
231bool RemoveListMutationAndUseVariadicStack(
232 const std::shared_ptr<Graph>& graph) {
233 return RemoveListMutationAndUseVariadicOp(graph, aten::stack, prim::VarStack);
234}
235
236} // namespace jit
237} // namespace torch
238