1 | #include <torch/csrc/jit/passes/peephole.h> |
2 | |
3 | #include <ATen/core/jit_type.h> |
4 | #include <c10/util/irange.h> |
5 | #include <torch/csrc/jit/ir/alias_analysis.h> |
6 | #include <torch/csrc/jit/ir/ir_views.h> |
7 | #include <torch/csrc/jit/jit_log.h> |
8 | #include <torch/csrc/jit/passes/concat_opt.h> |
9 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
10 | #include <torch/csrc/jit/passes/peephole_alias_sensitive.h> |
11 | #include <torch/csrc/jit/passes/peephole_dict_idioms.h> |
12 | #include <torch/csrc/jit/passes/peephole_list_idioms.h> |
13 | #include <torch/csrc/jit/passes/peephole_non_tensor.h> |
14 | #include <torch/csrc/jit/runtime/graph_executor.h> |
15 | #include <torch/csrc/utils/memory.h> |
16 | |
17 | namespace torch { |
18 | namespace jit { |
19 | |
20 | // Conservatively compare two optionals. If both are undefined, assume |
21 | // they aren't equal |
22 | template <typename T> |
23 | static bool mustBeEqual(const c10::optional<T>& a, const c10::optional<T>& b) { |
24 | return a == b && a.has_value(); |
25 | } |
26 | |
27 | struct PeepholeOptimizeImpl { |
28 | PeepholeOptimizeImpl( |
29 | // NOLINTNEXTLINE(modernize-pass-by-value) |
30 | const std::shared_ptr<Graph>& graph, |
31 | bool disable_shape_peepholes) |
32 | : graph_(graph), shape_peepholes_(!disable_shape_peepholes) {} |
33 | |
34 | bool run() { |
35 | bool changed = optimizeBlock(graph_->block()); |
36 | changed |= PeepholeOptimizeListIdioms(graph_); |
37 | changed |= PeepholeOptimizeDictIdioms(graph_); |
38 | changed |= PeepholeOptimizeAliasSensitive(graph_, shape_peepholes_); |
39 | changed |= PeepholeOptimizeNonTensor(graph_); |
40 | changed |= CombineConcats(graph_); |
41 | return changed; |
42 | } |
43 | |
44 | // The intent for this optimization pass is to catch all of the small, easy to |
45 | // catch peephole optimizations you might be interested in doing. |
46 | // |
47 | // TODO: Decide what kind of fixed point strategy we will have |
48 | bool optimizeBlock(Block* block) { |
49 | bool changed = false; |
50 | for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { |
51 | auto* node = *it; |
52 | |
53 | for (Block* sub_block : node->blocks()) { |
54 | changed |= optimizeBlock(sub_block); |
55 | } |
56 | |
57 | // XXX: remember that if you want to simplify an expression by combining |
58 | // multiple nodes into a different one, then you need to check that they |
59 | // all belong to the given block |
60 | // TODO: this doesn't work with Scalar-Tensor ops! We should |
61 | // canonicalize those |
62 | if (node->matches( |
63 | "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)" )) { |
64 | // Eliminate no-op _grad_sum_to_size. |
65 | // TODO: this doesn't work with Scalar-Tensor ops! We should |
66 | // canonicalize those |
67 | if (node->input(1)->mustBeNone()) { |
68 | GRAPH_UPDATE( |
69 | getHeader(node), |
70 | " (x._grad_sum_to_size(x, None) == x) is replaced with " , |
71 | node->input(0)->debugName()); |
72 | node->output()->replaceAllUsesWith(node->input(0)); |
73 | changed = true; |
74 | } else { |
75 | auto uses = node->output()->uses(); |
76 | for (Use u : uses) { |
77 | if (u.user->matches( |
78 | "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)" ) && |
79 | u.user->input(1)->type()->isSubtypeOf(*ListType::ofInts())) { |
80 | GRAPH_UPDATE( |
81 | getHeader(node), |
82 | " (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with " , |
83 | node->inputs().at(0)->debugName()); |
84 | u.user->replaceInput(0, node->inputs().at(0)); |
85 | changed = true; |
86 | } |
87 | } |
88 | } |
89 | } else if ( |
90 | node->matches( |
91 | "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor" , |
92 | /*const_inputs=*/attr::size)) { |
93 | // x.expand(x.size()) == x |
94 | auto input_type = |
95 | node->namedInput(attr::self)->type()->cast<TensorType>(); |
96 | if (input_type && shape_peepholes_) { |
97 | auto expanded_sizes = node->get<c10::List<int64_t>>(attr::size); |
98 | auto input_type_sizes = input_type->sizes().concrete_sizes(); |
99 | if (expanded_sizes.has_value() && input_type_sizes && |
100 | expanded_sizes->vec() == *input_type_sizes) { |
101 | GRAPH_UPDATE( |
102 | getHeader(node), |
103 | " (x.expand(x.size()) == x) is replaced with " , |
104 | node->namedInput(attr::self)->debugName()); |
105 | node->output()->replaceAllUsesWith(node->namedInput(attr::self)); |
106 | changed = true; |
107 | } |
108 | } |
109 | } else if (node->matches("aten::t(Tensor self) -> Tensor" )) { |
110 | // x.t().t() == x |
111 | Node* input_node = node->input()->node(); |
112 | if (input_node->matches("aten::t(Tensor self) -> Tensor" )) { |
113 | GRAPH_UPDATE( |
114 | getHeader(node), |
115 | " (x.t().t() == x) is replaced with " , |
116 | input_node->input()->debugName()); |
117 | node->output()->replaceAllUsesWith(input_node->input()); |
118 | changed = true; |
119 | } |
120 | } else if ( |
121 | node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor" ) && |
122 | shape_peepholes_) { |
123 | // x.type_as(y) == x iff x.type() == y.type() |
124 | auto self_type = node->input(0)->type()->expect<TensorType>(); |
125 | auto other_type = node->input(1)->type()->expect<TensorType>(); |
126 | if (mustBeEqual(self_type->scalarType(), other_type->scalarType()) && |
127 | mustBeEqual(self_type->device(), other_type->device())) { |
128 | GRAPH_UPDATE( |
129 | getHeader(node), |
130 | " (x.type_as(y) == x) is replaced with " , |
131 | node->input(0)->debugName()); |
132 | node->output()->replaceAllUsesWith(node->input(0)); |
133 | changed = true; |
134 | } |
135 | } else if ( |
136 | node->kind() == aten::Float || node->kind() == aten::Int || |
137 | node->kind() == aten::FloatImplicit || |
138 | node->kind() == aten::IntImplicit || |
139 | node->kind() == aten::ScalarImplicit) { |
140 | Node* input_node = node->input()->node(); |
141 | if (input_node->kind() == prim::NumToTensor) { |
142 | GRAPH_UPDATE( |
143 | getHeader(node), |
144 | " (x.NumToTensor() == x) is replaced with " , |
145 | node->input()->debugName()); |
146 | node->output()->replaceAllUsesWith(input_node->input()); |
147 | changed = true; |
148 | } |
149 | } else if ( |
150 | node->matches("aten::size(Tensor self) -> int[]" ) && |
151 | shape_peepholes_) { |
152 | if (auto ptt = node->input()->type()->cast<TensorType>()) { |
153 | if (auto sizes = ptt->sizes().concrete_sizes()) { |
154 | GRAPH_UPDATE( |
155 | getHeader(node), |
156 | " (x.size()) is replaced with " , |
157 | node->input()->debugName()); |
158 | WithInsertPoint guard(node); |
159 | IValue ival(sizes); |
160 | auto const_sizes_val = node->owningGraph()->insertConstant(ival); |
161 | node->output()->replaceAllUsesWith(const_sizes_val); |
162 | changed = true; |
163 | } |
164 | } |
165 | } else if ( |
166 | node->matches("aten::len.t(t[] a) -> int" ) && |
167 | node->input()->node()->matches("aten::size(Tensor self) -> int[]" ) && |
168 | shape_peepholes_) { |
169 | auto ptt = node->input()->node()->input()->type()->expect<TensorType>(); |
170 | // only handle one use case for now to avoid modifying mutated lists |
171 | // TODO: canonicalize as aten::dim ? |
172 | if (ptt->sizes().size() && node->input()->uses().size() == 1) { |
173 | WithInsertPoint guard(node); |
174 | auto output = node->owningGraph()->insertConstant( |
175 | static_cast<int64_t>(*ptt->sizes().size())); |
176 | GRAPH_UPDATE( |
177 | "Replacing " , |
178 | getHeader(node), |
179 | " with a \"dim\" constant " , |
180 | output->debugName()); |
181 | node->output()->replaceAllUsesWith(output); |
182 | changed = true; |
183 | } |
184 | } else if ( |
185 | node->matches("aten::size(Tensor self, int dim) -> int" ) && |
186 | shape_peepholes_) { |
187 | if (auto ptt = node->inputs().at(0)->type()->cast<TensorType>()) { |
188 | if (auto maybe_ndim = ptt->sizes().size()) { |
189 | auto ndim = *maybe_ndim; |
190 | auto maybe_index = toIValue(node->inputs().at(1)); |
191 | if (!maybe_index) { |
192 | continue; |
193 | } |
194 | int64_t index = maybe_index->toInt(); |
195 | int64_t norm_index = index < 0 ? ndim + index : index; |
196 | if (norm_index >= 0 && norm_index < static_cast<int64_t>(ndim) && |
197 | ptt->sizes()[norm_index]) { |
198 | WithInsertPoint guard(node); |
199 | IValue ival(*ptt->sizes()[norm_index]); |
200 | auto const_sizes_val = node->owningGraph()->insertConstant(ival); |
201 | node->output()->replaceAllUsesWith(const_sizes_val); |
202 | GRAPH_UPDATE( |
203 | getHeader(node), |
204 | " (x.size(dim)) is replaced with constant " , |
205 | const_sizes_val->debugName()); |
206 | changed = true; |
207 | } |
208 | } |
209 | } |
210 | } else if ( |
211 | node->matches("aten::is_floating_point(Tensor self) -> bool" ) && |
212 | shape_peepholes_) { |
213 | auto ptt = node->inputs().at(0)->type()->cast<TensorType>(); |
214 | if (auto maybe_dtype = ptt->scalarType()) { |
215 | c10::ScalarType dtype = *maybe_dtype; |
216 | WithInsertPoint guard(node); |
217 | IValue ival(at::isFloatingType(dtype)); |
218 | auto new_constant = node->owningGraph()->insertConstant(ival); |
219 | node->output()->replaceAllUsesWith(new_constant); |
220 | GRAPH_UPDATE( |
221 | getHeader(node), |
222 | " (x.is_floating_point()) is replaced with " , |
223 | new_constant->debugName()); |
224 | changed = true; |
225 | } |
226 | } else if ( |
227 | node->matches("aten::is_complex(Tensor self) -> bool" ) && |
228 | shape_peepholes_) { |
229 | auto ptt = node->inputs().at(0)->type()->cast<TensorType>(); |
230 | if (auto maybe_dtype = ptt->scalarType()) { |
231 | c10::ScalarType dtype = *maybe_dtype; |
232 | WithInsertPoint guard(node); |
233 | IValue ival(at::isComplexType(dtype)); |
234 | auto new_constant = node->owningGraph()->insertConstant(ival); |
235 | node->output()->replaceAllUsesWith(new_constant); |
236 | GRAPH_UPDATE( |
237 | getHeader(node), |
238 | " (x.is_complex()) is replaced with " , |
239 | new_constant->debugName()); |
240 | changed = true; |
241 | } |
242 | } else if ( |
243 | node->matches("prim::dtype(Tensor a) -> int" ) && shape_peepholes_) { |
244 | auto ptt = node->input()->type()->expect<TensorType>(); |
245 | if (ptt->scalarType()) { |
246 | WithInsertPoint guard(node); |
247 | auto output = node->owningGraph()->insertConstant( |
248 | static_cast<int64_t>(*ptt->scalarType())); |
249 | GRAPH_UPDATE( |
250 | "Replacing " , |
251 | getHeader(node), |
252 | " with a type constant " , |
253 | output->debugName()); |
254 | node->output()->replaceAllUsesWith(output); |
255 | changed = true; |
256 | } |
257 | } else if ( |
258 | node->matches("prim::device(Tensor a) -> Device" ) && |
259 | shape_peepholes_) { |
260 | auto ptt = node->input()->type()->expect<TensorType>(); |
261 | if (ptt->device()) { |
262 | WithInsertPoint guard(node); |
263 | auto output = node->owningGraph()->insertConstant(*ptt->device()); |
264 | GRAPH_UPDATE( |
265 | "Replacing " , |
266 | getHeader(node), |
267 | " with a device constant " , |
268 | output->debugName()); |
269 | node->output()->replaceAllUsesWith(output); |
270 | changed = true; |
271 | } |
272 | } else if ( |
273 | node->matches("aten::dim(Tensor self) -> int" ) && shape_peepholes_) { |
274 | auto ptt = node->input()->type()->expect<TensorType>(); |
275 | if (auto dim = ptt->sizes().size()) { |
276 | WithInsertPoint guard(node); |
277 | auto output = |
278 | node->owningGraph()->insertConstant(static_cast<int64_t>(*dim)); |
279 | GRAPH_UPDATE( |
280 | "Replacing " , |
281 | getHeader(node), |
282 | " with a \"dim\" constant " , |
283 | output->debugName()); |
284 | node->output()->replaceAllUsesWith(output); |
285 | changed = true; |
286 | } |
287 | } else if ( |
288 | node->matches("prim::is_cuda(Tensor a) -> bool" ) && |
289 | shape_peepholes_) { |
290 | auto ptt = node->input()->type()->expect<TensorType>(); |
291 | if (ptt->device()) { |
292 | WithInsertPoint guard(node); |
293 | auto output = |
294 | node->owningGraph()->insertConstant((*ptt->device()).is_cuda()); |
295 | GRAPH_UPDATE( |
296 | "Replacing " , |
297 | getHeader(node), |
298 | " with a is_cuda constant " , |
299 | output->debugName()); |
300 | node->output()->replaceAllUsesWith(output); |
301 | changed = true; |
302 | } |
303 | } |
304 | } |
305 | return changed; |
306 | } |
307 | |
308 | private: |
309 | std::shared_ptr<Graph> graph_; |
310 | bool shape_peepholes_; |
311 | }; |
312 | |
// Fuses add(mm(mat1, mat2), bias, alpha=1) -- with mm on either side of the
// add -- into addmm(bias, mat1, mat2, 1, 1), recursing into sub-blocks.
// The now-unused add/mm nodes are not destroyed here; only the add's output
// uses are redirected to the new addmm node.
// Returns true if any fusion was performed.
bool FuseAddMM(Block* block) {
  bool changed = false;
  for (Node* node : block->nodes()) {
    // XXX: remember that if you want to simplify an expression by combining
    // multiple nodes into a different one, then you need to check that they
    // all belong to the given block
    if (node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
            /*const_inputs=*/attr::alpha)) {
      // z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z
      // Only fuse when alpha == 1, so the add is a plain sum.
      if (node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
        // Look for mm from both sides of the add
        for (const auto mm_side : c10::irange(2)) {
          // Add will accept tensors of mismatched scalar types, as long as
          // one of them is a scalar, but addmm will throw in that case, so we
          // can only perform this fusion if we're sure that it is correct,
          // and for that we need the add_mat_type. An alternative would be to
          // insert a type_as conditional on the tensor shape being a scalar,
          // but that might add overhead, and make analysis harder.
          auto add_mat_type =
              node->input(1 - mm_side)->type()->expect<TensorType>();
          // if we don't have the rank, we can't tell if the bias is a scalar
          if (!add_mat_type->sizes().size()) {
            continue;
          }

          if (node->input(mm_side)->node()->matches(
                  "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
            // New nodes are inserted just before the add being replaced.
            WithInsertPoint guard(node);

            auto* graph = node->owningGraph();
            auto* mm_node = node->input(mm_side)->node();
            auto* add_mat = node->input(1 - mm_side);
            auto* mat1 = mm_node->input(0);
            auto* mat2 = mm_node->input(1);

            // Attempts to find a matrix with a defined scalar type to type as
            auto* type_as_mat = mat1;
            if (!type_as_mat->type()->expectRef<TensorType>().scalarType()) {
              type_as_mat = mat2;
            }
            auto mat_scalar_type =
                type_as_mat->type()->expectRef<TensorType>().scalarType();

            // we can't use type_as if we don't know the target type (mm), the
            // bias needs to be coerced to
            if (!mat_scalar_type) {
              continue;
            }

            // We insert the type_as if we're sure that the added element is a
            // scalar, and we either don't know the type of the scalar, or
            // know that it's mismatched.
            if (add_mat_type->sizes().size() &&
                *add_mat_type->sizes().size() == 0 &&
                !mustBeEqual(add_mat_type->scalarType(), mat_scalar_type)) {
              auto* type_as_node =
                  graph->insertNode(graph->create(aten::type_as, 1));
              type_as_node->addInput(add_mat);
              type_as_node->addInput(type_as_mat);
              add_mat = type_as_node->output();
              // Propagate a precise output type when the bias type is
              // complete; the coerced bias gets the mm's scalar type.
              if (add_mat_type->isComplete()) {
                auto new_type =
                    add_mat_type->withScalarType(mat_scalar_type)->contiguous();
                add_mat->setType(new_type);
              }
            }

            // addmm(bias, mat1, mat2, beta=1, alpha=1)
            auto* cOne = graph->insertConstant(1);
            auto* addmm_node = graph->insertNode(graph->create(aten::addmm, 1));
            addmm_node->addInput(add_mat);
            addmm_node->addInput(mat1);
            addmm_node->addInput(mat2);
            addmm_node->addInput(cOne);
            addmm_node->addInput(cOne);
            auto* addmm_value = addmm_node->output();

            // Copy shape information from output node
            addmm_value->copyMetadata(node->output());
            GRAPH_UPDATE(
                "Fusing ",
                mm_node->input(0)->debugName(),
                ", ",
                mm_node->input(1)->debugName(),
                " and ",
                node->input(1 - mm_side)->debugName(),
                " into ",
                addmm_value->debugName());
            node->output()->replaceAllUsesWith(addmm_value);
            changed = true;
            continue;
          }
        }
      }
    }
    // Recurse into nested blocks (e.g. if/loop bodies).
    for (Block* b : node->blocks()) {
      changed |= FuseAddMM(b);
    }
  }
  return changed;
}
414 | |
415 | // FuseAddMM is a separate pass from peephole optimize because it is currently |
416 | // used for exporting to ONNX. |
417 | // Today, fusing add + MM has no benefit within PyTorch running ATen |
418 | // ops. However, we rely on seeing the fused version of AddMM for ONNX export, |
419 | // since otherwise after ONNX translation we would see redundant Gemm ops with |
420 | // sub-optimal inputs. |
421 | // It won't be helpful for ATen until we're able to represent |
422 | // torch.addmm(a, b, c, out=a). |
423 | // That's because addmm dispatches internally to gemm, which computes: |
424 | // C = beta * C + alpha * A @ B |
425 | // but aten::addmm(a, b, c, 1, 1) is really: |
426 | // D = beta * C + alpha * A @ B |
427 | // and because it works out of place on C, we're only trading off an |
428 | // explicit add for a copy inside the addmm function. Note that it |
429 | // doesn't even result in fewer reads, because mm won't even load C |
430 | // (because beta == 0 for it). |
431 | bool FuseAddMM(const std::shared_ptr<Graph>& graph) { |
432 | bool changed = FuseAddMM(graph->block()); |
433 | GRAPH_DUMP("After FuseAddMM: " , graph); |
434 | return changed; |
435 | } |
436 | |
437 | bool PeepholeOptimize( |
438 | const std::shared_ptr<Graph>& graph, |
439 | bool addmm_fusion_enabled) { |
440 | PeepholeOptimizeImpl peephole(graph, addmm_fusion_enabled); |
441 | bool changed = peephole.run(); |
442 | GRAPH_DUMP("After PeepholeOptimize: " , graph); |
443 | // Eliminate dead code created by any peephole passes we've just done |
444 | if (changed) { |
445 | EliminateDeadCode(graph->block()); |
446 | } |
447 | return changed; |
448 | } |
449 | |
450 | } // namespace jit |
451 | } // namespace torch |
452 | |