1#include <torch/csrc/jit/passes/peephole.h>
2
3#include <ATen/core/jit_type.h>
4#include <c10/util/irange.h>
5#include <torch/csrc/jit/ir/alias_analysis.h>
6#include <torch/csrc/jit/ir/ir_views.h>
7#include <torch/csrc/jit/jit_log.h>
8#include <torch/csrc/jit/passes/concat_opt.h>
9#include <torch/csrc/jit/passes/dead_code_elimination.h>
10#include <torch/csrc/jit/passes/peephole_alias_sensitive.h>
11#include <torch/csrc/jit/passes/peephole_dict_idioms.h>
12#include <torch/csrc/jit/passes/peephole_list_idioms.h>
13#include <torch/csrc/jit/passes/peephole_non_tensor.h>
14#include <torch/csrc/jit/runtime/graph_executor.h>
15#include <torch/csrc/utils/memory.h>
16
17namespace torch {
18namespace jit {
19
20// Conservatively compare two optionals. If both are undefined, assume
21// they aren't equal
22template <typename T>
23static bool mustBeEqual(const c10::optional<T>& a, const c10::optional<T>& b) {
24 return a == b && a.has_value();
25}
26
// Driver for the peephole pass: walks every block of the graph applying
// small local rewrites, then runs the specialized idiom sub-passes.
struct PeepholeOptimizeImpl {
  PeepholeOptimizeImpl(
      // NOLINTNEXTLINE(modernize-pass-by-value)
      const std::shared_ptr<Graph>& graph,
      bool disable_shape_peepholes)
      : graph_(graph), shape_peepholes_(!disable_shape_peepholes) {}

  // Runs the block-level rewrites followed by the list/dict/alias/non-tensor
  // idiom passes and concat combining. Returns true if anything changed.
  bool run() {
    bool changed = optimizeBlock(graph_->block());
    changed |= PeepholeOptimizeListIdioms(graph_);
    changed |= PeepholeOptimizeDictIdioms(graph_);
    changed |= PeepholeOptimizeAliasSensitive(graph_, shape_peepholes_);
    changed |= PeepholeOptimizeNonTensor(graph_);
    changed |= CombineConcats(graph_);
    return changed;
  }

  // The intent for this optimization pass is to catch all of the small, easy to
  // catch peephole optimizations you might be interested in doing.
  //
  // Rewritten values are not removed here; replaceAllUsesWith leaves the old
  // producer dead and the caller is expected to run DCE afterwards.
  //
  // TODO: Decide what kind of fixed point strategy we will have
  bool optimizeBlock(Block* block) {
    bool changed = false;
    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
      auto* node = *it;

      // Recurse into sub-blocks (if/loop bodies) first.
      for (Block* sub_block : node->blocks()) {
        changed |= optimizeBlock(sub_block);
      }

      // XXX: remember that if you want to simplify an expression by combining
      // multiple nodes into a different one, then you need to check that they
      // all belong to the given block
      // TODO: this doesn't work with Scalar-Tensor ops! We should
      // canonicalize those
      if (node->matches(
              "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")) {
        // Eliminate no-op _grad_sum_to_size.
        if (node->input(1)->mustBeNone()) {
          GRAPH_UPDATE(
              getHeader(node),
              " (x._grad_sum_to_size(x, None) == x) is replaced with ",
              node->input(0)->debugName());
          node->output()->replaceAllUsesWith(node->input(0));
          changed = true;
        } else {
          // Collapse chained _grad_sum_to_size calls: only the last (with a
          // concrete int[] size) matters, so redirect its self input past
          // this node. Copy the use list since replaceInput mutates it.
          auto uses = node->output()->uses();
          for (Use u : uses) {
            if (u.user->matches(
                    "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)") &&
                u.user->input(1)->type()->isSubtypeOf(*ListType::ofInts())) {
              GRAPH_UPDATE(
                  getHeader(node),
                  " (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ",
                  node->inputs().at(0)->debugName());
              u.user->replaceInput(0, node->inputs().at(0));
              changed = true;
            }
          }
        }
      } else if (
          node->matches(
              "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
              /*const_inputs=*/attr::size)) {
        // x.expand(x.size()) == x
        // Requires both the requested sizes (constant input) and the input's
        // concrete profiled sizes to be known and identical.
        auto input_type =
            node->namedInput(attr::self)->type()->cast<TensorType>();
        if (input_type && shape_peepholes_) {
          auto expanded_sizes = node->get<c10::List<int64_t>>(attr::size);
          auto input_type_sizes = input_type->sizes().concrete_sizes();
          if (expanded_sizes.has_value() && input_type_sizes &&
              expanded_sizes->vec() == *input_type_sizes) {
            GRAPH_UPDATE(
                getHeader(node),
                " (x.expand(x.size()) == x) is replaced with ",
                node->namedInput(attr::self)->debugName());
            node->output()->replaceAllUsesWith(node->namedInput(attr::self));
            changed = true;
          }
        }
      } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
        // x.t().t() == x
        Node* input_node = node->input()->node();
        if (input_node->matches("aten::t(Tensor self) -> Tensor")) {
          GRAPH_UPDATE(
              getHeader(node),
              " (x.t().t() == x) is replaced with ",
              input_node->input()->debugName());
          node->output()->replaceAllUsesWith(input_node->input());
          changed = true;
        }
      } else if (
          node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor") &&
          shape_peepholes_) {
        // x.type_as(y) == x iff x.type() == y.type()
        // mustBeEqual treats unknown dtype/device as "not equal", so this
        // only fires when both sides are fully known to match.
        auto self_type = node->input(0)->type()->expect<TensorType>();
        auto other_type = node->input(1)->type()->expect<TensorType>();
        if (mustBeEqual(self_type->scalarType(), other_type->scalarType()) &&
            mustBeEqual(self_type->device(), other_type->device())) {
          GRAPH_UPDATE(
              getHeader(node),
              " (x.type_as(y) == x) is replaced with ",
              node->input(0)->debugName());
          node->output()->replaceAllUsesWith(node->input(0));
          changed = true;
        }
      } else if (
          node->kind() == aten::Float || node->kind() == aten::Int ||
          node->kind() == aten::FloatImplicit ||
          node->kind() == aten::IntImplicit ||
          node->kind() == aten::ScalarImplicit) {
        // Scalar extraction of a NumToTensor is the original scalar:
        // Int(NumToTensor(x)) == x (and likewise for the other casts).
        Node* input_node = node->input()->node();
        if (input_node->kind() == prim::NumToTensor) {
          GRAPH_UPDATE(
              getHeader(node),
              " (x.NumToTensor() == x) is replaced with ",
              node->input()->debugName());
          node->output()->replaceAllUsesWith(input_node->input());
          changed = true;
        }
      } else if (
          node->matches("aten::size(Tensor self) -> int[]") &&
          shape_peepholes_) {
        // Constant-fold x.size() when the full concrete shape is known.
        if (auto ptt = node->input()->type()->cast<TensorType>()) {
          if (auto sizes = ptt->sizes().concrete_sizes()) {
            GRAPH_UPDATE(
                getHeader(node),
                " (x.size()) is replaced with ",
                node->input()->debugName());
            WithInsertPoint guard(node);
            // `sizes` is engaged here, so the constant holds the size list.
            IValue ival(sizes);
            auto const_sizes_val = node->owningGraph()->insertConstant(ival);
            node->output()->replaceAllUsesWith(const_sizes_val);
            changed = true;
          }
        }
      } else if (
          node->matches("aten::len.t(t[] a) -> int") &&
          node->input()->node()->matches("aten::size(Tensor self) -> int[]") &&
          shape_peepholes_) {
        // len(x.size()) == x.dim(); fold to a constant when rank is known.
        auto ptt = node->input()->node()->input()->type()->expect<TensorType>();
        // only handle one use case for now to avoid modifying mutated lists
        // TODO: canonicalize as aten::dim ?
        if (ptt->sizes().size() && node->input()->uses().size() == 1) {
          WithInsertPoint guard(node);
          auto output = node->owningGraph()->insertConstant(
              static_cast<int64_t>(*ptt->sizes().size()));
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a \"dim\" constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      } else if (
          node->matches("aten::size(Tensor self, int dim) -> int") &&
          shape_peepholes_) {
        // Constant-fold x.size(dim) when both the rank, the (constant) dim
        // index, and that particular dimension's extent are known.
        if (auto ptt = node->inputs().at(0)->type()->cast<TensorType>()) {
          if (auto maybe_ndim = ptt->sizes().size()) {
            auto ndim = *maybe_ndim;
            auto maybe_index = toIValue(node->inputs().at(1));
            if (!maybe_index) {
              continue;
            }
            int64_t index = maybe_index->toInt();
            // Normalize negative indices the same way aten::size does.
            int64_t norm_index = index < 0 ? ndim + index : index;
            if (norm_index >= 0 && norm_index < static_cast<int64_t>(ndim) &&
                ptt->sizes()[norm_index]) {
              WithInsertPoint guard(node);
              IValue ival(*ptt->sizes()[norm_index]);
              auto const_sizes_val = node->owningGraph()->insertConstant(ival);
              node->output()->replaceAllUsesWith(const_sizes_val);
              GRAPH_UPDATE(
                  getHeader(node),
                  " (x.size(dim)) is replaced with constant ",
                  const_sizes_val->debugName());
              changed = true;
            }
          }
        }
      } else if (
          node->matches("aten::is_floating_point(Tensor self) -> bool") &&
          shape_peepholes_) {
        // Fold x.is_floating_point() when the dtype is statically known.
        // NOTE(review): `ptt` is dereferenced without a null check, unlike
        // the aten::size branches above; presumably the schema match
        // guarantees a Tensor-typed input so cast<TensorType>() succeeds —
        // confirm.
        auto ptt = node->inputs().at(0)->type()->cast<TensorType>();
        if (auto maybe_dtype = ptt->scalarType()) {
          c10::ScalarType dtype = *maybe_dtype;
          WithInsertPoint guard(node);
          IValue ival(at::isFloatingType(dtype));
          auto new_constant = node->owningGraph()->insertConstant(ival);
          node->output()->replaceAllUsesWith(new_constant);
          GRAPH_UPDATE(
              getHeader(node),
              " (x.is_floating_point()) is replaced with ",
              new_constant->debugName());
          changed = true;
        }
      } else if (
          node->matches("aten::is_complex(Tensor self) -> bool") &&
          shape_peepholes_) {
        // Fold x.is_complex() when the dtype is statically known (same
        // unchecked-cast caveat as the is_floating_point branch above).
        auto ptt = node->inputs().at(0)->type()->cast<TensorType>();
        if (auto maybe_dtype = ptt->scalarType()) {
          c10::ScalarType dtype = *maybe_dtype;
          WithInsertPoint guard(node);
          IValue ival(at::isComplexType(dtype));
          auto new_constant = node->owningGraph()->insertConstant(ival);
          node->output()->replaceAllUsesWith(new_constant);
          GRAPH_UPDATE(
              getHeader(node),
              " (x.is_complex()) is replaced with ",
              new_constant->debugName());
          changed = true;
        }
      } else if (
          node->matches("prim::dtype(Tensor a) -> int") && shape_peepholes_) {
        // Fold prim::dtype to an int constant when the dtype is known.
        auto ptt = node->input()->type()->expect<TensorType>();
        if (ptt->scalarType()) {
          WithInsertPoint guard(node);
          auto output = node->owningGraph()->insertConstant(
              static_cast<int64_t>(*ptt->scalarType()));
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a type constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      } else if (
          node->matches("prim::device(Tensor a) -> Device") &&
          shape_peepholes_) {
        // Fold prim::device to a Device constant when the device is known.
        auto ptt = node->input()->type()->expect<TensorType>();
        if (ptt->device()) {
          WithInsertPoint guard(node);
          auto output = node->owningGraph()->insertConstant(*ptt->device());
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a device constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      } else if (
          node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
        // Fold x.dim() to an int constant when the rank is known.
        auto ptt = node->input()->type()->expect<TensorType>();
        if (auto dim = ptt->sizes().size()) {
          WithInsertPoint guard(node);
          auto output =
              node->owningGraph()->insertConstant(static_cast<int64_t>(*dim));
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a \"dim\" constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      } else if (
          node->matches("prim::is_cuda(Tensor a) -> bool") &&
          shape_peepholes_) {
        // Fold x.is_cuda to a bool constant when the device is known.
        auto ptt = node->input()->type()->expect<TensorType>();
        if (ptt->device()) {
          WithInsertPoint guard(node);
          auto output =
              node->owningGraph()->insertConstant((*ptt->device()).is_cuda());
          GRAPH_UPDATE(
              "Replacing ",
              getHeader(node),
              " with a is_cuda constant ",
              output->debugName());
          node->output()->replaceAllUsesWith(output);
          changed = true;
        }
      }
    }
    return changed;
  }

 private:
  std::shared_ptr<Graph> graph_;
  // True unless the caller disabled shape-dependent rewrites; gates every
  // rewrite that relies on profiled TensorType shape/dtype/device info.
  bool shape_peepholes_;
};
312
// Rewrites add(mm(mat1, mat2), bias, alpha=1) — with mm on either side of
// the add — into addmm(bias, mat1, mat2, 1, 1) within `block`, recursing
// into sub-blocks. Returns true if any fusion was performed. See the
// comment on the graph-level overload below for why this exists (ONNX
// export) and why it is not a perf win for eager ATen.
bool FuseAddMM(Block* block) {
  bool changed = false;
  for (Node* node : block->nodes()) {
    // XXX: remember that if you want to simplify an expression by combining
    // multiple nodes into a different one, then you need to check that they
    // all belong to the given block
    if (node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
            /*const_inputs=*/attr::alpha)) {
      // z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z
      // Only fuse when alpha == 1, since addmm's alpha would scale the
      // matmul product, not reproduce the original add.
      if (node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
        // Look for mm from both sides of the add
        for (const auto mm_side : c10::irange(2)) {
          // Add will accept tensors of mismatched scalar types, as long as
          // one of them is a scalar, but addmm will throw in that case, so we
          // can only perform this fusion if we're sure that it is correct,
          // and for that we need the add_mat_type. An alternative would be to
          // insert a type_as conditional on the tensor shape being a scalar,
          // but that might add overhead, and make analysis harder.
          auto add_mat_type =
              node->input(1 - mm_side)->type()->expect<TensorType>();
          // if we don't have the rank, we can't tell if the bias is a scalar
          if (!add_mat_type->sizes().size()) {
            continue;
          }

          if (node->input(mm_side)->node()->matches(
                  "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
            WithInsertPoint guard(node);

            auto* graph = node->owningGraph();
            auto* mm_node = node->input(mm_side)->node();
            auto* add_mat = node->input(1 - mm_side);
            auto* mat1 = mm_node->input(0);
            auto* mat2 = mm_node->input(1);

            // Attempts to find a matrix with a defined scalar type to type as
            auto* type_as_mat = mat1;
            if (!type_as_mat->type()->expectRef<TensorType>().scalarType()) {
              type_as_mat = mat2;
            }
            auto mat_scalar_type =
                type_as_mat->type()->expectRef<TensorType>().scalarType();

            // we can't use type_as if we don't know the target type (mm), the
            // bias needs to be coerced to
            if (!mat_scalar_type) {
              continue;
            }

            // We insert the type_as if we're sure that the added element is a
            // scalar, and we either don't know the type of the scalar, or
            // know that it's mismatched.
            if (add_mat_type->sizes().size() &&
                *add_mat_type->sizes().size() == 0 &&
                !mustBeEqual(add_mat_type->scalarType(), mat_scalar_type)) {
              auto* type_as_node =
                  graph->insertNode(graph->create(aten::type_as, 1));
              type_as_node->addInput(add_mat);
              type_as_node->addInput(type_as_mat);
              add_mat = type_as_node->output();
              if (add_mat_type->isComplete()) {
                auto new_type =
                    add_mat_type->withScalarType(mat_scalar_type)->contiguous();
                add_mat->setType(new_type);
              }
            }

            // Build addmm(add_mat, mat1, mat2, beta=1, alpha=1).
            auto* cOne = graph->insertConstant(1);
            auto* addmm_node = graph->insertNode(graph->create(aten::addmm, 1));
            addmm_node->addInput(add_mat);
            addmm_node->addInput(mat1);
            addmm_node->addInput(mat2);
            addmm_node->addInput(cOne);
            addmm_node->addInput(cOne);
            auto* addmm_value = addmm_node->output();

            // Copy shape information from output node
            addmm_value->copyMetadata(node->output());
            GRAPH_UPDATE(
                "Fusing ",
                mm_node->input(0)->debugName(),
                ", ",
                mm_node->input(1)->debugName(),
                " and ",
                node->input(1 - mm_side)->debugName(),
                " into ",
                addmm_value->debugName());
            node->output()->replaceAllUsesWith(addmm_value);
            changed = true;
            // Done with this add node; don't try the other side too.
            continue;
          }
        }
      }
    }
    for (Block* b : node->blocks()) {
      changed |= FuseAddMM(b);
    }
  }
  return changed;
}
414
415// FuseAddMM is a separate pass from peephole optimize because it is currently
416// used for exporting to ONNX.
417// Today, fusing add + MM has no benefit within PyTorch running ATen
418// ops. However, we rely on seeing the fused version of AddMM for ONNX export,
419// since otherwise after ONNX translation we would see redundant Gemm ops with
420// sub-optimal inputs.
421// It won't be helpful for ATen until we're able to represent
422// torch.addmm(a, b, c, out=a).
423// That's because addmm dispatches internally to gemm, which computes:
424// C = beta * C + alpha * A @ B
425// but aten::addmm(a, b, c, 1, 1) is really:
426// D = beta * C + alpha * A @ B
427// and because it works out of place on C, we're only trading off an
428// explicit add for a copy inside the addmm function. Note that it
429// doesn't even result in fewer reads, because mm won't even load C
430// (because beta == 0 for it).
431bool FuseAddMM(const std::shared_ptr<Graph>& graph) {
432 bool changed = FuseAddMM(graph->block());
433 GRAPH_DUMP("After FuseAddMM: ", graph);
434 return changed;
435}
436
437bool PeepholeOptimize(
438 const std::shared_ptr<Graph>& graph,
439 bool addmm_fusion_enabled) {
440 PeepholeOptimizeImpl peephole(graph, addmm_fusion_enabled);
441 bool changed = peephole.run();
442 GRAPH_DUMP("After PeepholeOptimize: ", graph);
443 // Eliminate dead code created by any peephole passes we've just done
444 if (changed) {
445 EliminateDeadCode(graph->block());
446 }
447 return changed;
448}
449
450} // namespace jit
451} // namespace torch
452