#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_alias_sensitive.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>
#include <unordered_set>

namespace torch {
namespace jit {

// This pass only does optimizations which require Alias Analysis.
// It is separated out from the Peephole pass so that Peephole does not have
// to maintain alias db correctness throughout the pass.
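// For example, whether dim(conv(x)) can be constant-folded depends on
// whether any alias of the conv output is written to (e.g. resize_'d)
// before the dim call; a plain peephole rewrite never needs to answer
// such questions.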
struct PeepholeOptimizeAliasSensitiveImpl {
  PeepholeOptimizeAliasSensitiveImpl(
      std::shared_ptr<Graph> graph,
      bool shape_peepholes)
      : graph_(std::move(graph)),
        aliasDb_(torch::make_unique<AliasDb>(graph_)),
        shape_peepholes_(shape_peepholes) {}

  bool run() {
    return runBlock(graph_->block());
  }

 private:
  // Replace all uses of `v` with a freshly inserted constant holding `val`.
  void replaceWithIValue(Value* v, IValue val) {
    WithInsertPoint guard(v->node());
    v->replaceAllUsesWith(v->owningGraph()->insertConstant(val));
  }

  // True only when dtype-dependent peepholes are enabled and the scalar
  // type is known to be floating point.
  bool isFloatingPoint(TensorType& t) {
    auto input_dtype = t.scalarType();
    return (
        shape_peepholes_ && input_dtype && at::isFloatingType(*input_dtype));
  }

  bool runBlock(Block* block) {
    bool changed = false;
    for (Node* node : block->nodes()) {
      for (Block* b : node->blocks()) {
        changed |= runBlock(b);
      }

      // dim(conv(x)) is extremely common and prevents Conv->BN fusion
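      // An illustrative before/after (hypothetical IR, for exposition only):
      //   %y : Tensor = aten::conv2d(%x, %w, %b, %stride, %pad, %dil, %groups)
      //   %d : int = aten::dim(%y)
      // becomes
      //   %y : Tensor = aten::conv2d(%x, %w, %b, %stride, %pad, %dil, %groups)
      //   %d : int = prim::Constant[value=4]()
      // because this pass treats conv2d outputs as always 4-dimensional.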
      if (node->kind() == aten::conv1d || node->kind() == aten::conv2d ||
          node->kind() == aten::conv3d) {
        auto dim_uses = c10::filter(node->output()->uses(), [](const Use& use) {
          return use.user->kind() == aten::dim;
        });
        if (dim_uses.empty()) {
          continue;
        }
        auto kind = node->kind();
        int64_t output_size =
            kind == aten::conv1d ? 3 : (kind == aten::conv2d ? 4 : 5);
        // This is to handle potential resize_ calls, however unlikely.
        // If we add more checks related to resize_ in the graph,
        // factor this out like collectResizeSet in shape_analysis.
        if (!aliasDb_->hasWriters(node->output())) {
          for (const Use& dim_use : dim_uses) {
            replaceWithIValue(dim_use.user->output(), output_size);
          }
          changed = true;
        } else {
          for (const Use& dim_use : dim_uses) {
            if (aliasDb_->moveAfterTopologicallyValid(node, dim_use.user)) {
              replaceWithIValue(dim_use.user->output(), output_size);
              changed = true;
            }
          }
        }
        continue;
      } else if (
          node->matches(
              "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
              /*const_inputs=*/{attr::alpha, attr::other}) ||
          node->matches(
              "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
              /*const_inputs=*/{attr::alpha, attr::other})) {
        // x + 0 == x - 0 == x
        // if either scalar input is a float, then removing this operator
        // could remove type promotion and affect semantics
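        // e.g. `int_tensor + 0.0` yields a floating-point tensor, so
        // dropping the add would change the output dtype even though the
        // value is numerically unchanged (illustrative example).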
        if (!isFloatingPoint(node->input(0)->type()->expectRef<TensorType>())) {
          auto inps = node->inputs();
          if (!inps.at(1)->type()->isSubtypeOf(IntType::get()) ||
              !inps.at(2)->type()->isSubtypeOf(IntType::get())) {
            continue;
          }
        }

        if (node->get<at::Scalar>(attr::alpha)->toDouble() == 1 &&
            node->get<at::Scalar>(attr::other)->toDouble() == 0) {
          if (tryToReplaceOutputWithInput(node->input(0), node->output())) {
            GRAPH_UPDATE(
                getHeader(node),
                " (x + 0 == x - 0 == x) is replaced with ",
                node->input(0)->debugName());
            changed = true;
          }
        }
      } else if (
          node->matches(
              "aten::mul(Tensor self, Scalar other) -> Tensor",
              /*const_inputs=*/attr::other) ||
          node->matches(
              "aten::div(Tensor self, Scalar other) -> Tensor",
              /*const_inputs=*/attr::other)) {
        // x * 1 == x / 1 == x
        // if the node is a division or `other` isn't an integer, then
        // removing this operator could remove type promotion and affect
        // semantics
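        // e.g. aten::div performs true division, so `int_tensor / 1` yields
        // a floating-point tensor and the div cannot be dropped for integer
        // inputs (illustrative example).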
        if (!isFloatingPoint(node->input(0)->type()->expectRef<TensorType>())) {
          if (node->kind() == aten::div ||
              !node->input(1)->type()->isSubtypeOf(IntType::get())) {
            continue;
          }
        }

        if (node->get<at::Scalar>(attr::other)->toDouble() == 1) {
          if (tryToReplaceOutputWithInput(node->input(0), node->output())) {
            GRAPH_UPDATE(
                getHeader(node),
                " (x * 1 == x / 1 == x) is replaced with ",
                node->input(0)->debugName());
            changed = true;
          }
        }
      }
    }
    return changed;
  }

  bool tryToReplaceOutputWithInput(Value* input, Value* output) {
    if (!aliasDb_->safeToChangeAliasingRelationship(input, output)) {
      return false;
    }
    // Whenever we replace an output with an input, all of the aliasing
    // properties of the output are now present on the input.
    // For example, if the output aliases a graph output, the input now
    // does as well.
    // To avoid re-instantiating an alias db on each change, which would be
    // O(n^2), or modifying it in place, which would require invalidating all
    // of the memory DAG caches, we keep a set of values which are "stale"
    // (their aliasing properties are not up to date) and avoid doing further
    // optimizations on values which alias them.
    if (aliasDb_->mayAlias({input, output}, stale_alias_values_)) {
      return false;
    }
    output->replaceAllUsesWith(input);
    stale_alias_values_.insert(input);
    stale_alias_values_.insert(output);
    return true;
  }

  ValueSet stale_alias_values_;
  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_;
  bool shape_peepholes_;
};

bool PeepholeOptimizeAliasSensitive(
    const std::shared_ptr<Graph>& graph,
    bool shape_peepholes) {
  PeepholeOptimizeAliasSensitiveImpl opt(graph, shape_peepholes);
  return opt.run();
}
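
// A minimal usage sketch (for illustration only; assumes parseIR from
// torch/csrc/jit/ir/irparser.h):
//
//   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
//   parseIR(R"IR(
//     graph(%x : Tensor):
//       %one : int = prim::Constant[value=1]()
//       %y : Tensor = aten::mul(%x, %one)
//       return (%y))IR",
//       graph.get());
//   // Replaces uses of %y with %x when safe; returns true if the graph
//   // changed.
//   bool changed =
//       PeepholeOptimizeAliasSensitive(graph, /*shape_peepholes=*/true);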

} // namespace jit
} // namespace torch