1 | #include <ATen/core/jit_type.h> |
2 | #include <torch/csrc/jit/ir/alias_analysis.h> |
3 | #include <torch/csrc/jit/ir/ir_views.h> |
4 | #include <torch/csrc/jit/jit_log.h> |
5 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
6 | #include <torch/csrc/jit/passes/peephole.h> |
7 | #include <torch/csrc/jit/passes/peephole_alias_sensitive.h> |
8 | #include <torch/csrc/jit/runtime/graph_executor.h> |
9 | #include <torch/csrc/utils/memory.h> |
10 | #include <unordered_set> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
// This pass only does optimizations which require Alias Analysis.
// It is separated out from the Peephole Pass so that Peephole does not have
// to maintain alias db correctness throughout the pass.
18 | struct PeepholeOptimizeAliasSensitiveImpl { |
19 | PeepholeOptimizeAliasSensitiveImpl( |
20 | std::shared_ptr<Graph> graph, |
21 | bool shape_peepholes) |
22 | : graph_(std::move(graph)), |
23 | aliasDb_(torch::make_unique<AliasDb>(graph_)), |
24 | shape_peepholes_(shape_peepholes) {} |
25 | |
26 | bool run() { |
27 | return runBlock(graph_->block()); |
28 | } |
29 | |
30 | private: |
31 | void replaceWithIValue(Value* v, IValue val) { |
32 | WithInsertPoint guard(v->node()); |
33 | v->replaceAllUsesWith(v->owningGraph()->insertConstant(val)); |
34 | } |
35 | |
36 | bool isFloatingPoint(TensorType& t) { |
37 | auto input_dtype = t.scalarType(); |
38 | return ( |
39 | shape_peepholes_ && input_dtype && at::isFloatingType(*input_dtype)); |
40 | } |
41 | |
42 | bool runBlock(Block* block) { |
43 | bool changed = false; |
44 | for (Node* node : block->nodes()) { |
45 | for (Block* b : node->blocks()) { |
46 | changed |= runBlock(b); |
47 | } |
48 | |
49 | // dim(conv(x)) extremely common and prevents Conv->BN fusion |
50 | if (node->kind() == aten::conv1d || node->kind() == aten::conv2d || |
51 | node->kind() == aten::conv3d) { |
52 | auto dim_uses = c10::filter(node->output()->uses(), [](const Use& use) { |
53 | return use.user->kind() == aten::dim; |
54 | }); |
55 | if (dim_uses.empty()) { |
56 | continue; |
57 | } |
58 | auto kind = node->kind(); |
59 | int64_t output_size = |
60 | kind == aten::conv1d ? 3 : (kind == aten::conv2d ? 4 : 5); |
61 | // This is to handle potential resize_ calls, however unlikely. |
62 | // If we add more checks related to resize_ in the graph, |
63 | // factor this out like collectResizeSet in shape_analysis. |
64 | if (!aliasDb_->hasWriters(node->output())) { |
65 | for (const Use& dim_use : dim_uses) { |
66 | replaceWithIValue(dim_use.user->output(), output_size); |
67 | } |
68 | changed = true; |
69 | } else { |
70 | for (const Use& dim_use : dim_uses) { |
71 | if (aliasDb_->moveAfterTopologicallyValid(node, dim_use.user)) { |
72 | replaceWithIValue(dim_use.user->output(), output_size); |
73 | changed = true; |
74 | } |
75 | } |
76 | } |
77 | continue; |
78 | } else if ( |
79 | node->matches( |
80 | "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
81 | /*const_inputs=*/{attr::alpha, attr::other}) || |
82 | node->matches( |
83 | "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor" , |
84 | /*const_inputs=*/{attr::alpha, attr::other})) { |
85 | // x + 0 == x - 0 == x |
86 | // if either scalar input is a float, than removing this operator could |
87 | // remove type promotion and affect semantics |
88 | if (!isFloatingPoint(node->input(0)->type()->expectRef<TensorType>())) { |
89 | auto inps = node->inputs(); |
90 | if (!inps.at(1)->type()->isSubtypeOf(IntType::get()) || |
91 | !inps.at(2)->type()->isSubtypeOf(IntType::get())) { |
92 | continue; |
93 | } |
94 | } |
95 | |
96 | if (node->get<at::Scalar>(attr::alpha)->toDouble() == 1 && |
97 | node->get<at::Scalar>(attr::other)->toDouble() == 0) { |
98 | if (tryToReplaceOutputWithInput(node->input(0), node->output())) { |
99 | GRAPH_UPDATE( |
100 | getHeader(node), |
101 | " (x + 0 == x - 0 == x) is replaced with " , |
102 | node->input(0)->debugName()); |
103 | node->output()->replaceAllUsesWith(node->input(0)); |
104 | changed = true; |
105 | } |
106 | } |
107 | } else if ( |
108 | node->matches( |
109 | "aten::mul(Tensor self, Scalar other) -> Tensor" , |
110 | /*const_inputs=*/attr::other) || |
111 | node->matches( |
112 | "aten::div(Tensor self, Scalar other) -> Tensor" , |
113 | /*const_inputs=*/attr::other)) { |
114 | // x * 1 == x / 1 == x |
115 | // is the node is a division or other isn't an integer, than removing |
116 | // this operator could remove type promotion and affect semantics |
117 | if (!isFloatingPoint(node->input(0)->type()->expectRef<TensorType>())) { |
118 | if (node->kind() == aten::div || |
119 | !node->input(1)->type()->isSubtypeOf(IntType::get())) { |
120 | continue; |
121 | } |
122 | } |
123 | |
124 | if (node->get<at::Scalar>(attr::other)->toDouble() == 1) { |
125 | if (tryToReplaceOutputWithInput(node->input(0), node->output())) { |
126 | GRAPH_UPDATE( |
127 | getHeader(node), |
128 | " (x * 1 == x / 1 == x) is replaced with " , |
129 | node->input(0)->debugName()); |
130 | |
131 | changed = true; |
132 | } |
133 | } |
134 | } |
135 | } |
136 | return changed; |
137 | } |
138 | |
139 | bool tryToReplaceOutputWithInput(Value* input, Value* output) { |
140 | if (!aliasDb_->safeToChangeAliasingRelationship(input, output)) { |
141 | return false; |
142 | } |
143 | // whenever we replace an output with an input, all of the aliasing |
144 | // properties of the output are now present on the input. |
145 | // For example, if the output aliases a graph output, the input will now |
146 | // as well. |
147 | // in order to avoid re-instantiating an alias db on each change, which |
148 | // would be O(n^2), or inplace modifying it, which would involve |
149 | // invalidating all of the memory dag caches, we just keep a set of values |
150 | // which are "stale" (aliasing properties not up to date), and avoid doing |
151 | // further optimizations on values which alias them |
152 | if (aliasDb_->mayAlias({input, output}, stale_alias_values_)) { |
153 | return false; |
154 | } |
155 | output->replaceAllUsesWith(input); |
156 | stale_alias_values_.insert(input); |
157 | stale_alias_values_.insert(output); |
158 | return true; |
159 | } |
160 | |
161 | ValueSet stale_alias_values_; |
162 | std::shared_ptr<Graph> graph_; |
163 | std::unique_ptr<AliasDb> aliasDb_; |
164 | bool shape_peepholes_; |
165 | }; |
166 | |
167 | bool PeepholeOptimizeAliasSensitive( |
168 | const std::shared_ptr<Graph>& graph, |
169 | bool shape_peepholes) { |
170 | PeepholeOptimizeAliasSensitiveImpl opt(graph, shape_peepholes); |
171 | return opt.run(); |
172 | } |
173 | |
174 | } // namespace jit |
175 | } // namespace torch |
176 | |