1 | #include <ATen/core/jit_type.h> |
2 | #include <torch/csrc/jit/ir/alias_analysis.h> |
3 | #include <torch/csrc/jit/ir/ir_views.h> |
4 | #include <torch/csrc/jit/jit_log.h> |
5 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
6 | #include <torch/csrc/jit/passes/peephole.h> |
7 | #include <torch/csrc/jit/passes/peephole_list_idioms.h> |
8 | #include <torch/csrc/jit/passes/value_refinement_utils.h> |
9 | #include <torch/csrc/jit/runtime/graph_executor.h> |
10 | #include <torch/csrc/jit/runtime/slice_indices_adjust.h> |
11 | #include <torch/csrc/utils/memory.h> |
12 | #include <limits> |
13 | #include <utility> |
14 | |
15 | namespace torch { |
16 | namespace jit { |
17 | |
18 | c10::optional<size_t> normalizeIndex(int64_t index, size_t len) { |
19 | if (index < 0) { |
20 | index = index + len; |
21 | } |
22 | if (index >= 0 && index < static_cast<int64_t>(len)) { |
23 | return index; |
24 | } else { |
25 | return c10::nullopt; |
26 | } |
27 | } |
28 | |
29 | // see [value refinement algorithm] |
30 | |
// Replaces aten::len calls on unmutated lists with constants when the
// list's length can be proven from constant comparisons that dominate the
// use (e.g. inside the true branch of `if len(li) == 4:` the length of
// `li` is known to be 4). See [value refinement algorithm].
struct ListLenRefiner {
  ListLenRefiner(
      std::shared_ptr<Graph> graph,
      std::unordered_set<Value*>& mutated_lists)
      : graph_(std::move(graph)), mutated_lists_(mutated_lists) {}

  // Entry point: returns true if any aten::len output was replaced with a
  // constant.
  bool run() {
    std::unordered_set<Value*> li_with_len_use;
    collectListsToRefine(graph_->block(), li_with_len_use);
    if (lists_to_refine_.empty()) {
      // No list has more than one len() use; nothing to refine.
      return false;
    }
    ListRefinement refinements;
    RefineListLens(graph_->block(), std::move(refinements));
    return changed_;
  }

  // we only need to analyze lists that have multiple uses of len(), and we can
  // only analyze lists that are not mutated
  void collectListsToRefine(
      Block* b,
      std::unordered_set<Value*>& li_with_len_use) {
    for (Node* n : b->nodes()) {
      // Recurse into sub-blocks (If branches, Loop bodies) first.
      for (Block* block : n->blocks()) {
        collectListsToRefine(block, li_with_len_use);
      }

      if (n->kind() != aten::len) {
        continue;
      }

      auto first_input = n->input(0);
      // Only lists that are never written through any alias are safe.
      if (first_input->type()->castRaw<ListType>() &&
          !mutated_lists_.count(first_input)) {
        // A list becomes a candidate on its *second* len() use; with only
        // one use there is no second site to refine against.
        if (!li_with_len_use.count(first_input)) {
          li_with_len_use.insert(first_input);
        } else {
          lists_to_refine_.insert(first_input);
        }
      }
    }
  }

  // Walks block `b`, tracking known list lengths. `block_refinements`
  // holds the refinements valid on entry to this block; it is pushed onto
  // `active_refinements_` for the duration of the walk so nested blocks
  // can consult enclosing scopes. Returns the refinements accumulated
  // within the block, for the caller (an If) to join across branches.
  ListRefinement RefineListLens(Block* b, ListRefinement block_refinements) {
    active_refinements_.push_back(&block_refinements);
    for (Node* n : b->nodes()) {
      if (n->matches("aten::eq(int a, int b) -> bool" ) ||
          n->matches("aten::ne(int a, int b) -> bool" )) {
        // check for one input constant and the other coming from len(li)
        for (size_t const_index : {0, 1}) {
          auto ival = constant_as<int64_t>(n->input(const_index));
          if (!ival) {
            continue;
          }
          auto li_len = n->input(1 - const_index);
          if (!li_len->node()->matches("aten::len.t(t[] a) -> int" ) ||
              !lists_to_refine_.count(li_len->node()->input())) {
            continue;
          }
          // For eq: the output being true implies len(list) == *ival.
          // For ne: the output being false implies the same.
          ListRefinement refine;
          refine[li_len->node()->input()] = *ival;
          boolean_value_refinements_[n->output()] = n->kind() == aten::eq
              ? BooleanRefinementMapping::TrueRefinements(std::move(refine))
              : BooleanRefinementMapping::FalseRefinements(std::move(refine));
        }
      } else if (n->kind() == aten::len) {
        // If any active scope pins this list's length, replace the len()
        // output with that constant.
        if (auto maybe_len = tryFindRefinement(n->input(0))) {
          changed_ = true;
          WithInsertPoint guard(n);
          n->output()->replaceAllUsesWith(
              graph_->insertConstant(static_cast<int64_t>(*maybe_len)));
        }
      } else if (n->kind() == prim::If) {
        IfView if_n(n);
        bool has_cond_ref = boolean_value_refinements_.count(if_n.cond()) != 0;
        ListRefinement empty;
        // Each branch starts from the refinements implied by its side of
        // the condition (if any are associated with the condition value).
        auto true_block_refinements = RefineListLens(
            if_n.thenBlock(),
            has_cond_ref ? boolean_value_refinements_[if_n.cond()].true_refine()
                         : empty);
        auto false_block_refinements = RefineListLens(
            if_n.elseBlock(),
            has_cond_ref
                ? boolean_value_refinements_[if_n.cond()].false_refine()
                : empty);

        // Join what both branches established back into this block's
        // refinements and the boolean refinement map.
        joinIfRefinements(
            n,
            throwing_blocks_,
            block_refinements,
            true_block_refinements,
            false_block_refinements,
            boolean_value_refinements_);
      } else {
        // Shared handling for other operators that interact with boolean
        // refinements or mark blocks as throwing.
        handleCommonRefinentOperators(
            n, throwing_blocks_, boolean_value_refinements_);
      }
    }
    active_refinements_.pop_back();
    return block_refinements;
  };

  // Searches all active block scopes for a length refinement of `v`.
  c10::optional<int64_t> tryFindRefinement(Value* v) {
    for (const auto& ref : active_refinements_) {
      auto maybe_refinement = ref->find(v);
      if (maybe_refinement != ref->end()) {
        return maybe_refinement->second;
      }
    }
    return c10::nullopt;
  }

  std::shared_ptr<Graph> graph_;
  // Lists that may be written to; these are never refined.
  std::unordered_set<Value*> mutated_lists_;
  // candidate lists for optimizations
  std::unordered_set<Value*> lists_to_refine_;
  // A stack of active refinements, one for each block
  std::vector<ListRefinement*> active_refinements_;
  // A map from Boolean Value * -> associated refinements
  std::unordered_map<Value*, BooleanRefinementMapping>
      boolean_value_refinements_;
  std::unordered_set<Block*> throwing_blocks_;
  // Set once any len() call has been replaced with a constant.
  bool changed_ = false;
};
155 | |
156 | // This pass only does optimizations on lists which aren't mutated, |
157 | // so we first use the Alias Db to collect the set of list values |
158 | // which we shouldn't optimize. |
struct PeepholeOptimizeListIdiomsImpl {
  PeepholeOptimizeListIdiomsImpl(
      std::shared_ptr<Graph> graph,
      bool refine_list_len)
      : graph_(std::move(graph)),
        aliasDb_(torch::make_unique<AliasDb>(graph_)),
        refine_list_len_(refine_list_len) {}

  // Entry point. Returns true if the graph was modified.
  bool run() {
    collectMutatedLists(graph_->block());
    bool changed = runBlock(graph_->block());
    if (refine_list_len_) {
      // Optionally run the len()-refinement analysis on top of the
      // structural peepholes above.
      changed |= ListLenRefiner(graph_, mutated_lists_).run();
    }
    return changed;
  }

 private:
  // Records `v` in mutated_lists_ if it is a list the alias db reports
  // as having writers.
  void checkForMutatedList(Value* v) {
    if (v->type()->castRaw<ListType>() && aliasDb_->hasWriters(v)) {
      mutated_lists_.insert(v);
    }
  }

  // Recursively collects every list value (block inputs and node outputs)
  // that may be mutated anywhere in the graph.
  void collectMutatedLists(Block* b) {
    for (Value* v : b->inputs()) {
      checkForMutatedList(v);
    }
    for (Node* n : b->nodes()) {
      for (Value* v : n->outputs()) {
        checkForMutatedList(v);
      }
      for (Block* block : n->blocks()) {
        collectMutatedLists(block);
      }
    }
  }

  // Rewrites `li[start:end:step]` on a ListConstruct with all-constant
  // slice arguments into a new ListConstruct of the selected elements.
  // Returns true if the rewrite was applied.
  bool optimizeSlice(Node* slice_node, Node* list_construct_node) {
    auto start_val = toIValue(slice_node->input(1));
    auto end_val = toIValue(slice_node->input(2));
    auto step_val = toIValue(slice_node->input(3));

    // All args must be constant to apply this optimization.
    if (start_val == c10::nullopt || end_val == c10::nullopt ||
        step_val == c10::nullopt) {
      return false;
    }

    // A non-int constant (i.e. None) maps start/end to INT64_MAX, which
    // slice_indices_adjust clamps to the list bounds; None step means 1.
    int64_t start = start_val->isInt() ? start_val->to<int64_t>()
                                       : std::numeric_limits<int64_t>::max();
    int64_t end = end_val->isInt() ? end_val->to<int64_t>()
                                   : std::numeric_limits<int64_t>::max();
    int64_t step = step_val->isInt() ? step_val->to<int64_t>() : 1;

    size_t list_size = list_construct_node->inputs().size();
    // Clamps start/end in place and returns how many elements the slice
    // selects (CPython slice semantics).
    size_t num_values = slice_indices_adjust(list_size, &start, &end, step);

    WithInsertPoint guard(slice_node);
    auto slice_list_construct =
        graph_->insertNode(graph_->create(prim::ListConstruct));
    slice_list_construct->output()->setType(slice_node->output()->type());
    for (size_t i = start, j = 0; j < num_values; ++j) {
      slice_list_construct->addInput(list_construct_node->input(i));
      i += step;
    }

    slice_node->output()->replaceAllUsesWith(slice_list_construct->output());
    // If the old output was considered mutated, the replacement list must
    // inherit that status so later rewrites skip it.
    if (mutated_lists_.count(slice_node->output())) {
      mutated_lists_.insert(slice_list_construct->output());
    }

    return true;
  }

  // Applies the list-idiom peepholes to every node in `block` (and its
  // sub-blocks). Nodes are rewritten in place by redirecting their output
  // uses; dead nodes are left for a later DCE pass. Returns true if
  // anything changed.
  bool runBlock(Block* block) {
    bool changed = false;
    for (Node* node : block->nodes()) {
      for (Block* b : node->blocks()) {
        changed |= runBlock(b);
      }

      // only optimizing list ops
      if (node->inputs().empty() ||
          !node->input(0)->type()->castRaw<ListType>()) {
        continue;
      }

      auto first_input = node->input(0);

      // only optimizing ops with unmutated lists
      if (mutated_lists_.count(first_input)) {
        continue;
      }

      // All rewrites below require statically knowing the list elements,
      // i.e. the list must come directly from a ListConstruct.
      auto list_creation_node = first_input->node();
      if (list_creation_node->kind() != prim::ListConstruct) {
        continue;
      }

      if (node->kind() == aten::len) {
        // len(ListConstruct(...)) -> constant element count.
        WithInsertPoint guard(node);
        node->output()->replaceAllUsesWith(graph_->insertConstant(
            static_cast<int64_t>(first_input->node()->inputs().size())));
        changed = true;
      } else if (node->kind() == aten::__getitem__) {
        // li[c] with constant in-bounds c -> the c-th ListConstruct input.
        if (auto index = toIValue(node->input(1))) {
          size_t list_size = list_creation_node->inputs().size();
          if (auto norm_index = normalizeIndex(index->toInt(), list_size)) {
            node->output()->replaceAllUsesWith(
                list_creation_node->input(*norm_index));
            changed = true;
          }
        }
      } else if (node->kind() == prim::ListUnpack) {
        // if sizes are unequal it's a runtime error
        if (list_creation_node->inputs().size() != node->outputs().size()) {
          continue;
        }
        // Forward each unpacked output to the corresponding list element.
        for (size_t i = 0; i < node->outputs().size(); ++i) {
          node->output(i)->replaceAllUsesWith(list_creation_node->input(i));
          changed = true;
        }
      } else if (node->kind() == aten::add) {
        // ListConstruct(a...) + ListConstruct(b...) ->
        //   ListConstruct(a..., b...)
        if (node->inputs().size() != 2) {
          continue;
        }
        auto second_input = node->input(1);
        // already checked first, need to check second
        if (mutated_lists_.count(second_input)) {
          continue;
        }
        if (second_input->node()->kind() != prim::ListConstruct) {
          continue;
        }
        WithInsertPoint guard(node);
        auto list_construct =
            graph_->insertNode(graph_->create(prim::ListConstruct));
        list_construct->output()->setType(node->output()->type());
        for (Value* v : first_input->node()->inputs()) {
          list_construct->addInput(v);
        }
        for (Value* v : second_input->node()->inputs()) {
          list_construct->addInput(v);
        }
        node->output()->replaceAllUsesWith(list_construct->output());
        // Propagate mutation status to the replacement list (see
        // optimizeSlice for the same pattern).
        if (mutated_lists_.count(node->output())) {
          mutated_lists_.insert(list_construct->output());
        }
        changed = true;
      } else if (node->kind() == aten::slice) {
        changed |= optimizeSlice(node, first_input->node());
      }
    }
    return changed;
  }

  // Lists with writers per the alias db; excluded from every rewrite.
  std::unordered_set<Value*> mutated_lists_;
  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_;
  // Whether run() should also invoke ListLenRefiner.
  bool refine_list_len_;
};
321 | |
322 | bool PeepholeOptimizeListIdioms( |
323 | const std::shared_ptr<Graph>& graph, |
324 | bool refine_list_len) { |
325 | PeepholeOptimizeListIdiomsImpl opt(graph, refine_list_len); |
326 | return opt.run(); |
327 | } |
328 | |
329 | } // namespace jit |
330 | } // namespace torch |
331 | |