1 | #include <ATen/core/jit_type.h> |
2 | #include <torch/csrc/jit/ir/ir.h> |
3 | #include <torch/csrc/jit/jit_log.h> |
4 | #include <torch/csrc/jit/passes/integer_value_refinement.h> |
5 | #include <torch/csrc/jit/passes/value_refinement_utils.h> |
6 | #include <torch/csrc/utils/memory.h> |
7 | |
8 | #include <utility> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
// Maps an int-typed Value to the constant it is known to equal within some
// region of the graph (see the refinement pass below).
using IntegerRefinement = std::unordered_map<Value*, int64_t>;
14 | |
// See [value refinement algorithm] for a full explanation.
// When a comparison like `cond = x == 4` or `cond = x != 4` is made, the
// `cond` value carries information (refinements) about the value of `x`.
// In an example like:
//   if x == 1:
//     ...
// we can substitute all uses of `x` dominated by the true block
// with 1. For `x != 1` the same substitution applies in the false block.
23 | |
24 | struct IntegerValueRefiner { |
25 | IntegerValueRefiner(std::shared_ptr<Graph> graph) |
26 | : graph_(std::move(graph)) {} |
27 | |
28 | bool run() { |
29 | if (!blockHasIntComparisons(graph_->block())) { |
30 | return false; |
31 | } |
32 | IntegerRefinement refinements; |
33 | RefineIntegerValues(graph_->block(), std::move(refinements)); |
34 | return changed_; |
35 | } |
36 | |
37 | bool blockHasIntComparisons(Block* b) { |
38 | for (Node* n : b->nodes()) { |
39 | if (n->matches("aten::eq(int a, int b) -> bool" ) || |
40 | n->matches("aten::ne(int a, int b) -> bool" )) { |
41 | for (size_t const_index : {0, 1}) { |
42 | auto non_const_index = 1 - const_index; |
43 | if (n->inputs().at(const_index)->node()->kind() == prim::Constant && |
44 | n->inputs().at(non_const_index)->uses().size() > 1) { |
45 | return true; |
46 | } |
47 | } |
48 | } |
49 | for (Block* block : n->blocks()) { |
50 | if (blockHasIntComparisons(block)) { |
51 | return true; |
52 | } |
53 | } |
54 | } |
55 | return false; |
56 | } |
57 | |
58 | void removeIfNodeOutputsWithRefinements( |
59 | Node* if_node, |
60 | IntegerRefinement& true_block_refinements, |
61 | IntegerRefinement& false_block_refinements) { |
62 | // we are looking for cases where we can replace both block outputs with the |
63 | // same value, which opens up further optimization opportunities. The pass |
64 | // will already handle if both outputs are refined to the same constant. |
65 | // Here, we look for cases where one block output has been refined in the |
66 | // other block to be equal to the same constant value as the other other |
67 | // block output: |
68 | // graph(%y.1 : int): |
69 | // %one_constant : int = prim::Constant[value=1]() |
70 | // %3 : bool = aten::eq(%y.1, %one_constant) |
71 | // %15 : int = prim::If(%3) |
72 | // block0(): |
73 | // -> (%one_constant) |
74 | // block1(): |
75 | // -> (%y.1) |
76 | // return (%15) |
77 | // %15 can always be safely replaced with %y.1 |
78 | // this is an important case for symbolic shape analysis |
79 | for (size_t block_index : {0, 1}) { |
80 | Block* if_block = if_node->blocks().at(block_index); |
81 | Block* other_if_block = if_node->blocks().at(1 - block_index); |
82 | for (size_t i = 0; i < if_node->outputs().size(); ++i) { |
83 | Value* block_output = if_block->outputs().at(i); |
84 | if (!block_output->type()->cast<IntType>()) { |
85 | continue; |
86 | } |
87 | // Value must be in scope for both blocks |
88 | // in example above, %y.1 cannot be defined in block1 |
89 | if (!if_node->isDominatedBy(block_output->node())) { |
90 | continue; |
91 | } |
92 | // one constant value one not - we are looking for the pattern |
93 | // where y.1 is refined to the existing block output %one_constant |
94 | auto other_output = other_if_block->outputs().at(i); |
95 | auto other_const_value = other_output->type()->cast<IntType>() |
96 | ? constant_as<int64_t>(other_output) |
97 | : c10::nullopt; |
98 | if (!other_const_value || |
99 | block_output->node()->kind() == prim::Constant) { |
100 | continue; |
101 | } |
102 | // here, we are looking in refinements in the other block of our |
103 | // current output. in the example, we are looking for refinements of |
104 | // %y.1 in `block0`, and we are checking that %y.1 is refined |
105 | // to the constant value of %one_constant |
106 | const auto& other_block_refinements = |
107 | block_index == 0 ? false_block_refinements : true_block_refinements; |
108 | if (!other_block_refinements.count(block_output)) { |
109 | continue; |
110 | } |
111 | if (other_block_refinements.at(block_output) == *other_const_value) { |
112 | if_node->outputs().at(i)->replaceAllUsesWith(block_output); |
113 | changed_ = true; |
114 | } |
115 | } |
116 | } |
117 | } |
118 | |
119 | // iteratively look through the block `b` for refinements or Value uses that |
120 | // can be refined, `block_refinements` are the refinements present starting at |
121 | // this block (and for all blocks dominated by this block). |
122 | IntegerRefinement RefineIntegerValues( |
123 | Block* b, |
124 | IntegerRefinement block_refinements) { |
125 | active_refinements_.push_back(&block_refinements); |
126 | for (Node* n : b->nodes()) { |
127 | if (n->matches("aten::eq(int a, int b) -> bool" ) || |
128 | n->matches("aten::ne(int a, int b) -> bool" )) { |
129 | for (size_t const_index : {0, 1}) { |
130 | if (auto ival = constant_as<int64_t>(n->inputs().at(const_index))) { |
131 | IntegerRefinement refine; |
132 | refine[n->inputs().at(1 - const_index)] = *ival; |
133 | info_[n->output()] = n->kind() == aten::eq |
134 | ? BooleanRefinementMapping::TrueRefinements(std::move(refine)) |
135 | : BooleanRefinementMapping::FalseRefinements(std::move(refine)); |
136 | } |
137 | } |
138 | } |
139 | for (size_t input = 0; input < n->inputs().size(); ++input) { |
140 | Value* input_v = n->inputs().at(input); |
141 | if (!input_v->type()->cast<IntType>()) { |
142 | continue; |
143 | } |
144 | |
145 | if (auto refine = tryFindRefinement(input_v)) { |
146 | WithInsertPoint guard(n); |
147 | auto refine_constant = |
148 | graph_->insertConstant(static_cast<int64_t>(*refine)); |
149 | n->replaceInputWith(input_v, refine_constant); |
150 | changed_ = true; |
151 | } |
152 | } |
153 | |
154 | if (n->kind() == prim::If) { |
155 | IfView if_n(n); |
156 | bool has_cond_ref = info_.count(if_n.cond()) != 0; |
157 | IntegerRefinement empty; |
158 | auto true_block_refinements = RefineIntegerValues( |
159 | if_n.thenBlock(), |
160 | has_cond_ref ? info_[if_n.cond()].true_refine() : empty); |
161 | auto false_block_refinements = RefineIntegerValues( |
162 | if_n.elseBlock(), |
163 | has_cond_ref ? info_[if_n.cond()].false_refine() : empty); |
164 | |
165 | removeIfNodeOutputsWithRefinements( |
166 | n, true_block_refinements, false_block_refinements); |
167 | |
168 | joinIfRefinements( |
169 | n, |
170 | throwing_blocks_, |
171 | block_refinements, |
172 | true_block_refinements, |
173 | false_block_refinements, |
174 | info_); |
175 | } else { |
176 | handleCommonRefinentOperators(n, throwing_blocks_, info_); |
177 | } |
178 | } |
179 | |
180 | // iterating over all nodes in the block will not iterate over |
181 | // block outputs, so we need to add handling of them. |
182 | // %3 : int = prim::Constant[value=3]() |
183 | // %4 : bool = aten::eq(%y.1, %3) |
184 | // %a : int = prim::If(%4) |
185 | // block0(): |
186 | // -> (%y.1) |
187 | // Here, we can replace y.1 with 3 |
188 | |
189 | for (size_t i = 0; i < b->outputs().size(); ++i) { |
190 | Value* output_v = b->outputs().at(i); |
191 | if (!output_v->type()->cast<IntType>()) { |
192 | continue; |
193 | } |
194 | |
195 | if (auto refine = tryFindRefinement(output_v)) { |
196 | WithInsertPoint guard(b); |
197 | auto refine_constant = |
198 | graph_->insertConstant(static_cast<int64_t>(*refine)); |
199 | b->replaceOutput(i, refine_constant); |
200 | changed_ = true; |
201 | } |
202 | } |
203 | |
204 | active_refinements_.pop_back(); |
205 | return block_refinements; |
206 | }; |
207 | |
208 | c10::optional<int64_t> tryFindRefinement(Value* v) { |
209 | for (const auto& ref : active_refinements_) { |
210 | auto maybe_refinement = ref->find(v); |
211 | if (maybe_refinement != ref->end()) { |
212 | return maybe_refinement->second; |
213 | } |
214 | } |
215 | return c10::nullopt; |
216 | } |
217 | |
218 | std::shared_ptr<Graph> graph_; |
219 | // A stack of active refinements, one for each block |
220 | std::vector<IntegerRefinement*> active_refinements_; |
221 | // A map from Boolean Value * -> associated refinements |
222 | std::unordered_map<Value*, BooleanRefinementMapping> info_; |
223 | std::unordered_set<Block*> throwing_blocks_; |
224 | bool changed_ = false; |
225 | }; |
226 | |
227 | bool RefineIntegerValues(const std::shared_ptr<Graph>& graph) { |
228 | return IntegerValueRefiner(graph).run(); |
229 | } |
230 | |
231 | } // namespace jit |
232 | } // namespace torch |
233 | |