1 | #include <torch/csrc/jit/passes/variadic_ops.h> |
2 | |
3 | #include <torch/csrc/jit/ir/alias_analysis.h> |
4 | #include <torch/csrc/jit/jit_log.h> |
5 | #include <torch/csrc/jit/passes/constant_pooling.h> |
6 | #include <torch/csrc/jit/passes/remove_mutation.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | namespace { |
12 | |
13 | std::vector<size_t> identifyListArgIndices(const c10::FunctionSchema& schema) { |
14 | std::vector<size_t> list_indices; |
15 | const auto& args = schema.arguments(); |
16 | for (const auto i : c10::irange(args.size())) { |
17 | auto list_type = args[i].type()->castRaw<ListType>(); |
18 | if (list_type && list_type->getElementType()->castRaw<TensorType>()) { |
19 | list_indices.push_back(i); |
20 | } |
21 | } |
22 | return list_indices; |
23 | } |
24 | |
25 | bool isTensorListConstruct(Node* node) { |
26 | if (node->kind() != prim::ListConstruct) { |
27 | return false; |
28 | } |
29 | const auto type = node->output()->type()->castRaw<ListType>(); |
30 | TORCH_CHECK(type != nullptr); |
31 | const auto& elem_type = type->getElementType(); |
32 | return elem_type->castRaw<TensorType>(); |
33 | } |
34 | |
35 | class VariadicUpdater { |
36 | public: |
37 | VariadicUpdater( |
38 | std::shared_ptr<Graph> graph, |
39 | NodeKind op, |
40 | NodeKind variadic_op) |
41 | : graph_(std::move(graph)), |
42 | alias_db_(graph_), |
43 | op_(op), |
44 | variadic_op_(variadic_op) {} |
45 | |
46 | bool run() { |
47 | collectOpNodes(graph_->block()); |
48 | bool changed = false; |
49 | for (auto n : op_nodes_) { |
50 | changed |= replaceWithVariadicOp(n); |
51 | } |
52 | return changed; |
53 | } |
54 | |
55 | private: |
56 | void recordSchema(Node* op_node) { |
57 | const auto& schema = op_node->schema(); |
58 | auto it = schema_to_list_indices_.find(schema.name()); |
59 | if (it == schema_to_list_indices_.end()) { |
60 | schema_to_list_indices_.emplace( |
61 | schema.overload_name(), identifyListArgIndices(schema)); |
62 | } |
63 | } |
64 | |
65 | const std::vector<size_t>& getListIndices(Node* op_node) const { |
66 | const auto& schema = op_node->schema(); |
67 | auto it = schema_to_list_indices_.find(schema.overload_name()); |
68 | TORCH_CHECK(it != schema_to_list_indices_.end()); |
69 | return it->second; |
70 | } |
71 | |
72 | void collectOpNodes(Block* block) { |
73 | for (auto node : block->nodes()) { |
74 | if (node->kind() == op_) { |
75 | op_nodes_.push_back(node); |
76 | recordSchema(node); |
77 | } |
78 | for (Block* b : node->blocks()) { |
79 | collectOpNodes(b); |
80 | } |
81 | } |
82 | } |
83 | |
84 | bool allListInputsAreValid(Node* op_node) { |
85 | const size_t num_inputs = op_node->inputs().size(); |
86 | for (const auto list_idx : getListIndices(op_node)) { |
87 | TORCH_CHECK(list_idx < num_inputs); |
88 | const auto list = op_node->input(list_idx)->node(); |
89 | // We do not transform ops whose list input can not be moved to the |
90 | // position before op. This in turn implies that there is some mutation |
91 | // of the input list before op. |
92 | if (!isTensorListConstruct(list) || |
93 | !alias_db_.couldMoveBeforeTopologically(list, op_node)) { |
94 | return false; |
95 | } |
96 | } |
97 | return true; |
98 | } |
99 | |
100 | void insertAllInputsBetween( |
101 | std::vector<Value*>& inputs, |
102 | Node* node, |
103 | size_t start_idx, |
104 | size_t end_idx) const { |
105 | const size_t num_inputs = node->inputs().size(); |
106 | TORCH_CHECK(start_idx <= end_idx && end_idx <= num_inputs); |
107 | inputs.insert( |
108 | inputs.end(), |
109 | node->inputs().begin() + start_idx, |
110 | node->inputs().begin() + end_idx); |
111 | } |
112 | |
113 | void insertIntegerInput(std::vector<Value*>& inputs, size_t input) { |
114 | auto constant = graph_->create(prim::Constant); |
115 | constant->output()->setType(c10::IntType::get()); |
116 | constant->i_(attr::value, input); |
117 | graph_->prependNode(constant); |
118 | inputs.push_back(constant->output()); |
119 | } |
120 | |
121 | void deleteOpNodeAndLists(Node* op_node) { |
122 | // Collect the lists before we destroy op_node |
123 | std::vector<Node*> lists; |
124 | const auto& list_indices = getListIndices(op_node); |
125 | lists.reserve(list_indices.size()); |
126 | for (const size_t list_idx : list_indices) { |
127 | auto* list = op_node->input(list_idx)->node(); |
128 | lists.push_back(list); |
129 | } |
130 | |
131 | GRAPH_UPDATE("Deleting\n" , *op_node); |
132 | op_node->destroy(); |
133 | for (auto* list : lists) { |
134 | if (!list->hasUses()) { |
135 | GRAPH_UPDATE("Deleting\n" , *list); |
136 | list->destroy(); |
137 | } |
138 | } |
139 | } |
140 | |
141 | bool replaceWithVariadicOp(Node* op_node) { |
142 | if (!allListInputsAreValid(op_node)) { |
143 | return false; |
144 | } |
145 | |
146 | std::vector<Value*> inputs; |
147 | size_t cur_idx = 0; |
148 | std::vector<size_t> list_lens; |
149 | for (const size_t list_idx : getListIndices(op_node)) { |
150 | insertAllInputsBetween(inputs, op_node, cur_idx, list_idx); |
151 | const auto list = op_node->input(list_idx)->node(); |
152 | const auto list_len = list->inputs().size(); |
153 | list_lens.push_back(list_len); |
154 | insertAllInputsBetween(inputs, list, 0, list_len); |
155 | cur_idx = list_idx + 1; |
156 | } |
157 | insertAllInputsBetween(inputs, op_node, cur_idx, op_node->inputs().size()); |
158 | |
159 | // We insert these extra integers at the end of the argument list only if we |
160 | // have more than one variadic list (the information is redundant when there |
161 | // is only one list because the interpreter knows how many arguments there |
162 | // are). |
163 | if (list_lens.size() > 1) { |
164 | for (const size_t list_len : list_lens) { |
165 | insertIntegerInput(inputs, list_len); |
166 | } |
167 | } |
168 | |
169 | auto var_op_node = op_node->owningGraph()->create(variadic_op_, inputs); |
170 | var_op_node->output()->setType(op_node->output()->type()); |
171 | GRAPH_UPDATE("Adding\n" , *var_op_node); |
172 | var_op_node->insertBefore(op_node); |
173 | GRAPH_UPDATE("Replacing\n" , *op_node, "with\n" , *var_op_node); |
174 | op_node->output()->replaceAllUsesWith(var_op_node->output()); |
175 | deleteOpNodeAndLists(op_node); |
176 | return true; |
177 | } |
178 | |
179 | std::shared_ptr<Graph> graph_; |
180 | std::vector<Node*> op_nodes_; |
181 | |
182 | AliasDb alias_db_; |
183 | |
184 | NodeKind op_; |
185 | NodeKind variadic_op_; |
186 | |
187 | std::unordered_map<std::string, std::vector<size_t>> schema_to_list_indices_; |
188 | }; |
189 | |
190 | } // namespace |
191 | |
192 | bool UseVariadicOp( |
193 | const std::shared_ptr<Graph>& graph, |
194 | NodeKind op, |
195 | NodeKind variadic_op) { |
196 | const std::string pass_name = std::string("variadic " ) + op.toQualString(); |
197 | GRAPH_DUMP("Before " + pass_name, graph); |
198 | bool changed = VariadicUpdater(graph, op, variadic_op).run(); |
199 | if (changed) { |
200 | ConstantPooling(graph); |
201 | GRAPH_DUMP("After " + pass_name, graph); |
202 | } |
203 | return changed; |
204 | } |
205 | |
206 | bool RemoveListMutationAndUseVariadicOp( |
207 | const std::shared_ptr<Graph>& graph, |
208 | NodeKind op, |
209 | NodeKind variadic_op) { |
210 | bool changed_in_last_iter = true; |
211 | bool changed = false; |
212 | while (changed_in_last_iter) { |
213 | changed_in_last_iter = RemoveListMutation(graph); |
214 | changed_in_last_iter = |
215 | UseVariadicOp(graph, op, variadic_op) || changed_in_last_iter; |
216 | changed = changed || changed_in_last_iter; |
217 | } |
218 | return changed; |
219 | } |
220 | |
// Convenience wrapper: rewrite aten::cat over a ListConstruct into
// prim::VarConcat. Returns true if the graph changed.
bool UseVariadicCat(const std::shared_ptr<Graph>& graph) {
  return UseVariadicOp(graph, aten::cat, prim::VarConcat);
}
// Convenience wrapper: iterate list-mutation removal + variadic cat
// rewriting to a fixpoint. Returns true if the graph changed.
bool RemoveListMutationAndUseVariadicCat(const std::shared_ptr<Graph>& graph) {
  return RemoveListMutationAndUseVariadicOp(graph, aten::cat, prim::VarConcat);
}
227 | |
// Convenience wrapper: rewrite aten::stack over a ListConstruct into
// prim::VarStack. Returns true if the graph changed.
bool UseVariadicStack(const std::shared_ptr<Graph>& graph) {
  return UseVariadicOp(graph, aten::stack, prim::VarStack);
}
// Convenience wrapper: iterate list-mutation removal + variadic stack
// rewriting to a fixpoint. Returns true if the graph changed.
bool RemoveListMutationAndUseVariadicStack(
    const std::shared_ptr<Graph>& graph) {
  return RemoveListMutationAndUseVariadicOp(graph, aten::stack, prim::VarStack);
}
235 | |
236 | } // namespace jit |
237 | } // namespace torch |
238 | |