1#include <ATen/core/jit_type.h>
2#include <torch/csrc/jit/ir/alias_analysis.h>
3#include <torch/csrc/jit/ir/ir_views.h>
4#include <torch/csrc/jit/jit_log.h>
5#include <torch/csrc/jit/passes/dead_code_elimination.h>
6#include <torch/csrc/jit/passes/peephole.h>
7#include <torch/csrc/jit/passes/peephole_list_idioms.h>
8#include <torch/csrc/jit/passes/value_refinement_utils.h>
9#include <torch/csrc/jit/runtime/graph_executor.h>
10#include <torch/csrc/jit/runtime/slice_indices_adjust.h>
11#include <torch/csrc/utils/memory.h>
12#include <limits>
13#include <utility>
14
15namespace torch {
16namespace jit {
17
18c10::optional<size_t> normalizeIndex(int64_t index, size_t len) {
19 if (index < 0) {
20 index = index + len;
21 }
22 if (index >= 0 && index < static_cast<int64_t>(len)) {
23 return index;
24 } else {
25 return c10::nullopt;
26 }
27}
28
29// see [value refinement algorithm]
30
31struct ListLenRefiner {
32 ListLenRefiner(
33 std::shared_ptr<Graph> graph,
34 std::unordered_set<Value*>& mutated_lists)
35 : graph_(std::move(graph)), mutated_lists_(mutated_lists) {}
36
37 bool run() {
38 std::unordered_set<Value*> li_with_len_use;
39 collectListsToRefine(graph_->block(), li_with_len_use);
40 if (lists_to_refine_.empty()) {
41 return false;
42 }
43 ListRefinement refinements;
44 RefineListLens(graph_->block(), std::move(refinements));
45 return changed_;
46 }
47
48 // we only need to analyze lists that have multiple uses of len(), and we can
49 // only analyze lists that are not mutated
50 void collectListsToRefine(
51 Block* b,
52 std::unordered_set<Value*>& li_with_len_use) {
53 for (Node* n : b->nodes()) {
54 for (Block* block : n->blocks()) {
55 collectListsToRefine(block, li_with_len_use);
56 }
57
58 if (n->kind() != aten::len) {
59 continue;
60 }
61
62 auto first_input = n->input(0);
63 if (first_input->type()->castRaw<ListType>() &&
64 !mutated_lists_.count(first_input)) {
65 if (!li_with_len_use.count(first_input)) {
66 li_with_len_use.insert(first_input);
67 } else {
68 lists_to_refine_.insert(first_input);
69 }
70 }
71 }
72 }
73
74 ListRefinement RefineListLens(Block* b, ListRefinement block_refinements) {
75 active_refinements_.push_back(&block_refinements);
76 for (Node* n : b->nodes()) {
77 if (n->matches("aten::eq(int a, int b) -> bool") ||
78 n->matches("aten::ne(int a, int b) -> bool")) {
79 // check for one input constant and the other coming from len(li)
80 for (size_t const_index : {0, 1}) {
81 auto ival = constant_as<int64_t>(n->input(const_index));
82 if (!ival) {
83 continue;
84 }
85 auto li_len = n->input(1 - const_index);
86 if (!li_len->node()->matches("aten::len.t(t[] a) -> int") ||
87 !lists_to_refine_.count(li_len->node()->input())) {
88 continue;
89 }
90 ListRefinement refine;
91 refine[li_len->node()->input()] = *ival;
92 boolean_value_refinements_[n->output()] = n->kind() == aten::eq
93 ? BooleanRefinementMapping::TrueRefinements(std::move(refine))
94 : BooleanRefinementMapping::FalseRefinements(std::move(refine));
95 }
96 } else if (n->kind() == aten::len) {
97 if (auto maybe_len = tryFindRefinement(n->input(0))) {
98 changed_ = true;
99 WithInsertPoint guard(n);
100 n->output()->replaceAllUsesWith(
101 graph_->insertConstant(static_cast<int64_t>(*maybe_len)));
102 }
103 } else if (n->kind() == prim::If) {
104 IfView if_n(n);
105 bool has_cond_ref = boolean_value_refinements_.count(if_n.cond()) != 0;
106 ListRefinement empty;
107 auto true_block_refinements = RefineListLens(
108 if_n.thenBlock(),
109 has_cond_ref ? boolean_value_refinements_[if_n.cond()].true_refine()
110 : empty);
111 auto false_block_refinements = RefineListLens(
112 if_n.elseBlock(),
113 has_cond_ref
114 ? boolean_value_refinements_[if_n.cond()].false_refine()
115 : empty);
116
117 joinIfRefinements(
118 n,
119 throwing_blocks_,
120 block_refinements,
121 true_block_refinements,
122 false_block_refinements,
123 boolean_value_refinements_);
124 } else {
125 handleCommonRefinentOperators(
126 n, throwing_blocks_, boolean_value_refinements_);
127 }
128 }
129 active_refinements_.pop_back();
130 return block_refinements;
131 };
132
133 c10::optional<int64_t> tryFindRefinement(Value* v) {
134 for (const auto& ref : active_refinements_) {
135 auto maybe_refinement = ref->find(v);
136 if (maybe_refinement != ref->end()) {
137 return maybe_refinement->second;
138 }
139 }
140 return c10::nullopt;
141 }
142
143 std::shared_ptr<Graph> graph_;
144 std::unordered_set<Value*> mutated_lists_;
145 // candidate lists for optimizations
146 std::unordered_set<Value*> lists_to_refine_;
147 // A stack of active refinements, one for each block
148 std::vector<ListRefinement*> active_refinements_;
149 // A map from Boolean Value * -> associated refinements
150 std::unordered_map<Value*, BooleanRefinementMapping>
151 boolean_value_refinements_;
152 std::unordered_set<Block*> throwing_blocks_;
153 bool changed_ = false;
154};
155
156// This pass only does optimizations on lists which aren't mutated,
157// so we first use the Alias Db to collect the set of list values
158// which we shouldn't optimize.
159struct PeepholeOptimizeListIdiomsImpl {
160 PeepholeOptimizeListIdiomsImpl(
161 std::shared_ptr<Graph> graph,
162 bool refine_list_len)
163 : graph_(std::move(graph)),
164 aliasDb_(torch::make_unique<AliasDb>(graph_)),
165 refine_list_len_(refine_list_len) {}
166
167 bool run() {
168 collectMutatedLists(graph_->block());
169 bool changed = runBlock(graph_->block());
170 if (refine_list_len_) {
171 changed |= ListLenRefiner(graph_, mutated_lists_).run();
172 }
173 return changed;
174 }
175
176 private:
177 void checkForMutatedList(Value* v) {
178 if (v->type()->castRaw<ListType>() && aliasDb_->hasWriters(v)) {
179 mutated_lists_.insert(v);
180 }
181 }
182
183 void collectMutatedLists(Block* b) {
184 for (Value* v : b->inputs()) {
185 checkForMutatedList(v);
186 }
187 for (Node* n : b->nodes()) {
188 for (Value* v : n->outputs()) {
189 checkForMutatedList(v);
190 }
191 for (Block* block : n->blocks()) {
192 collectMutatedLists(block);
193 }
194 }
195 }
196
197 bool optimizeSlice(Node* slice_node, Node* list_construct_node) {
198 auto start_val = toIValue(slice_node->input(1));
199 auto end_val = toIValue(slice_node->input(2));
200 auto step_val = toIValue(slice_node->input(3));
201
202 // All args must be constant to apply this optimization.
203 if (start_val == c10::nullopt || end_val == c10::nullopt ||
204 step_val == c10::nullopt) {
205 return false;
206 }
207
208 int64_t start = start_val->isInt() ? start_val->to<int64_t>()
209 : std::numeric_limits<int64_t>::max();
210 int64_t end = end_val->isInt() ? end_val->to<int64_t>()
211 : std::numeric_limits<int64_t>::max();
212 int64_t step = step_val->isInt() ? step_val->to<int64_t>() : 1;
213
214 size_t list_size = list_construct_node->inputs().size();
215 size_t num_values = slice_indices_adjust(list_size, &start, &end, step);
216
217 WithInsertPoint guard(slice_node);
218 auto slice_list_construct =
219 graph_->insertNode(graph_->create(prim::ListConstruct));
220 slice_list_construct->output()->setType(slice_node->output()->type());
221 for (size_t i = start, j = 0; j < num_values; ++j) {
222 slice_list_construct->addInput(list_construct_node->input(i));
223 i += step;
224 }
225
226 slice_node->output()->replaceAllUsesWith(slice_list_construct->output());
227 if (mutated_lists_.count(slice_node->output())) {
228 mutated_lists_.insert(slice_list_construct->output());
229 }
230
231 return true;
232 }
233
234 bool runBlock(Block* block) {
235 bool changed = false;
236 for (Node* node : block->nodes()) {
237 for (Block* b : node->blocks()) {
238 changed |= runBlock(b);
239 }
240
241 // only optimizing list ops
242 if (node->inputs().empty() ||
243 !node->input(0)->type()->castRaw<ListType>()) {
244 continue;
245 }
246
247 auto first_input = node->input(0);
248
249 // only optimizing ops with unmutated lists
250 if (mutated_lists_.count(first_input)) {
251 continue;
252 }
253
254 auto list_creation_node = first_input->node();
255 if (list_creation_node->kind() != prim::ListConstruct) {
256 continue;
257 }
258
259 if (node->kind() == aten::len) {
260 WithInsertPoint guard(node);
261 node->output()->replaceAllUsesWith(graph_->insertConstant(
262 static_cast<int64_t>(first_input->node()->inputs().size())));
263 changed = true;
264 } else if (node->kind() == aten::__getitem__) {
265 if (auto index = toIValue(node->input(1))) {
266 size_t list_size = list_creation_node->inputs().size();
267 if (auto norm_index = normalizeIndex(index->toInt(), list_size)) {
268 node->output()->replaceAllUsesWith(
269 list_creation_node->input(*norm_index));
270 changed = true;
271 }
272 }
273 } else if (node->kind() == prim::ListUnpack) {
274 // if sizes are unequal it's a runtime error
275 if (list_creation_node->inputs().size() != node->outputs().size()) {
276 continue;
277 }
278 for (size_t i = 0; i < node->outputs().size(); ++i) {
279 node->output(i)->replaceAllUsesWith(list_creation_node->input(i));
280 changed = true;
281 }
282 } else if (node->kind() == aten::add) {
283 if (node->inputs().size() != 2) {
284 continue;
285 }
286 auto second_input = node->input(1);
287 // already checked first, need to check second
288 if (mutated_lists_.count(second_input)) {
289 continue;
290 }
291 if (second_input->node()->kind() != prim::ListConstruct) {
292 continue;
293 }
294 WithInsertPoint guard(node);
295 auto list_construct =
296 graph_->insertNode(graph_->create(prim::ListConstruct));
297 list_construct->output()->setType(node->output()->type());
298 for (Value* v : first_input->node()->inputs()) {
299 list_construct->addInput(v);
300 }
301 for (Value* v : second_input->node()->inputs()) {
302 list_construct->addInput(v);
303 }
304 node->output()->replaceAllUsesWith(list_construct->output());
305 if (mutated_lists_.count(node->output())) {
306 mutated_lists_.insert(list_construct->output());
307 }
308 changed = true;
309 } else if (node->kind() == aten::slice) {
310 changed |= optimizeSlice(node, first_input->node());
311 }
312 }
313 return changed;
314 }
315
316 std::unordered_set<Value*> mutated_lists_;
317 std::shared_ptr<Graph> graph_;
318 std::unique_ptr<AliasDb> aliasDb_;
319 bool refine_list_len_;
320};
321
322bool PeepholeOptimizeListIdioms(
323 const std::shared_ptr<Graph>& graph,
324 bool refine_list_len) {
325 PeepholeOptimizeListIdiomsImpl opt(graph, refine_list_len);
326 return opt.run();
327}
328
329} // namespace jit
330} // namespace torch
331