1 | #include <c10/util/irange.h> |
2 | #include <torch/csrc/jit/passes/value_refinement_utils.h> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | |
7 | // [value refinement algorithm] |
8 | |
9 | // When a comparison like `cond = len(x) == 4` or `cond = len(x) != 4` is made, |
10 | // `cond` value carries information (refinements) about the len of `x`. |
11 | // When `cond` is used as the conditional of an if statement, the information |
12 | // it carries for its true value can be inserted into the true block |
13 | // and the same for its false value. |
14 | // For something like `y = len(x) if len(x) == 1 else 1`, in the true branch |
15 | // we can replace len(x) with 1 because the true refinements from `len(x) == 1` |
16 | // will be present in the true block. |
// Additionally, we can optimize something like:
// if len(x) != 4:
//     raise Exception(...)
// return len(x)
// Because the true block always throws, whatever refinements exist in the false
// block become present in the owning block of the if node. We can also merge
// refinements carried by two different booleans across an if node join by
// taking the intersection of their refinements:
// if cond:
//     z = len(x) == 4 and len(y) == 5
// else:
//     z = len(x) == 4
// Here, z's true value will refine len(x) to 4, but not len(y).
// If the code was written as:
// if cond:
//     z = len(x) == 4 and len(y) == 5
// else:
//     z = False
//
// then z's true value would refine x and y, because if z is true it had to have
// come from the true block. Code that is written with `and` or `or` will
// desugar to something similar. Additionally, any True refinements that were
// present on `cond` can also be associated with the if node True output value.
40 | |
// The intersection of two refinements is the set of Value*s that appear in
// both refinements and are refined to the same length.
// In an example like:
// if cond:
//     x = len(a) == 4 and len(b) == 5
// else:
//     x = len(a) == 4
// for the x output of the node we take the intersection between
// the refinements stored on each block output, which will result
// in only the refinement len(a) == 4.
51 | ListRefinement intersectRefinements( |
52 | const ListRefinement& ref1, |
53 | const ListRefinement& ref2) { |
54 | ListRefinement out; |
55 | for (const auto& pair : ref1) { |
56 | auto val2 = ref2.find(pair.first); |
57 | if (val2 != ref2.end() && val2->second == pair.second) { |
58 | out[pair.first] = pair.second; |
59 | } |
60 | } |
61 | return out; |
62 | } |
63 | |
// To union, just take all refinements from both inputs. We do not need to worry
// about len refinements disagreeing because a path like `if len(x) == 4 and
// len(x) == 5` will never be taken.
// In an example like:
// if len(a) == 5:
//     x = len(b) == 4
// else:
//     x = False
// for the output x Value, if it is true then the refinements present in the
// true block must also be true, so we take the union of `len(a) == 5` and
// `len(b) == 4` and assign them to the true refinements of the output x value.
// This is a very common pattern in the desugaring of `and` or `or` boolean
// expressions.
76 | ListRefinement unionRefinements( |
77 | const ListRefinement& ref1, |
78 | const ListRefinement& ref2) { |
79 | ListRefinement out = ref1; |
80 | out.insert(ref2.begin(), ref2.end()); |
81 | return out; |
82 | } |
83 | |
84 | void joinIfRefinements( |
85 | Node* if_node, |
86 | std::unordered_set<Block*>& throwing_blocks, |
87 | ListRefinement& curr_block_refinements, |
88 | ListRefinement& true_block_refinements, |
89 | ListRefinement& false_block_refinements, |
90 | std::unordered_map<Value*, BooleanRefinementMapping>& |
91 | boolean_value_refinements) { |
92 | IfView if_n(if_node); |
93 | Block* b = if_node->owningBlock(); |
94 | |
95 | bool true_block_throws = throwing_blocks.count(if_n.thenBlock()); |
96 | bool false_block_throws = throwing_blocks.count(if_n.elseBlock()); |
97 | |
98 | // if one block throws, the refinements for the other block |
99 | // become present in the current block, and all bool outputs |
100 | // of the if node take their refinements from non throwing block |
101 | // output |
102 | |
103 | if (true_block_throws || false_block_throws) { |
104 | if (true_block_throws && false_block_throws) { |
105 | throwing_blocks.insert(b); |
106 | return; |
107 | } |
108 | if (true_block_throws) { |
109 | curr_block_refinements.insert( |
110 | false_block_refinements.begin(), false_block_refinements.end()); |
111 | } else { |
112 | curr_block_refinements.insert( |
113 | true_block_refinements.begin(), true_block_refinements.end()); |
114 | } |
115 | Block* non_throwing_block = |
116 | true_block_throws ? if_node->blocks().at(1) : if_node->blocks().at(0); |
117 | for (const auto i : c10::irange(if_n.outputs().size())) { |
118 | if (boolean_value_refinements.count( |
119 | non_throwing_block->outputs().at(i))) { |
120 | boolean_value_refinements[if_node->outputs().at(i)] = |
121 | boolean_value_refinements[non_throwing_block->outputs().at(i)]; |
122 | } |
123 | } |
124 | return; |
125 | } |
126 | |
127 | for (const auto i : c10::irange(if_n.outputs().size())) { |
128 | if (!(if_n.outputs().at(i)->type() == BoolType::get())) { |
129 | return; |
130 | } |
131 | Value* true_v = if_n.thenOutputs().at(i); |
132 | Value* false_v = if_n.elseOutputs().at(i); |
133 | |
134 | if (!boolean_value_refinements.count(true_v) && |
135 | !boolean_value_refinements.count(false_v) && |
136 | !constant_as<bool>(true_v) && !constant_as<bool>(false_v)) { |
137 | return; |
138 | } |
139 | |
140 | // if either block has a constant bool output, e.g. `true` on the |
141 | // true block, then for the `false` value we can take the false |
142 | // refinements present on the false block and from the other block |
143 | // output value bc if the output is false it had to have come from the |
144 | // false block. if len(a) == 5: |
145 | // x = len(b) == 4 |
146 | // else: |
147 | // x = False |
148 | // if x is true, then we know both len(a) == 5 and len(b) == 4 |
149 | // |
150 | // if neither block has a constant bool value, we just take the |
151 | // intersection of the refinements from boolean outputs. |
152 | // if cond: |
153 | // x = len(a) == 4 and len(b) == 5 |
154 | // else: |
155 | // x = len(a) == 4 |
156 | // here, we know if x is true, then len(a) == 4, but not len(b) |
157 | // == 5, because that refinement is not present in the true block. |
158 | // TODO: could also take intersection of refinements present in |
159 | // both blocks, but it's not a real use case. |
160 | |
161 | // boolean_value_refinements[value] is safe to access because |
162 | // BooleanRefinementMapping has a default constructor |
163 | |
164 | BooleanRefinementMapping out; |
165 | if (auto maybe_bool = constant_as<bool>(true_v)) { |
166 | if (*maybe_bool) { |
167 | out = BooleanRefinementMapping::FalseRefinements(unionRefinements( |
168 | boolean_value_refinements[false_v].false_refine(), |
169 | false_block_refinements)); |
170 | } else { |
171 | out = BooleanRefinementMapping::TrueRefinements(unionRefinements( |
172 | boolean_value_refinements[false_v].true_refine(), |
173 | false_block_refinements)); |
174 | } |
175 | } else if (auto maybe_bool = constant_as<bool>(false_v)) { |
176 | if (*maybe_bool) { |
177 | out = BooleanRefinementMapping::FalseRefinements(unionRefinements( |
178 | boolean_value_refinements[true_v].false_refine(), |
179 | true_block_refinements)); |
180 | } else { |
181 | out = BooleanRefinementMapping::TrueRefinements(unionRefinements( |
182 | boolean_value_refinements[true_v].true_refine(), |
183 | true_block_refinements)); |
184 | } |
185 | } else if ( |
186 | boolean_value_refinements.count(true_v) && |
187 | boolean_value_refinements.count(false_v)) { |
188 | out = boolean_value_refinements[true_v].intersectBooleanRefinementMapping( |
189 | boolean_value_refinements[false_v]); |
190 | } |
191 | boolean_value_refinements[if_n.outputs().at(i)] = out; |
192 | } |
193 | } |
194 | |
195 | bool handleCommonRefinentOperators( |
196 | Node* n, |
197 | std::unordered_set<Block*>& throwing_blocks, |
198 | std::unordered_map<Value*, BooleanRefinementMapping>& info) { |
199 | if (n->kind() == prim::RaiseException) { |
200 | throwing_blocks.insert(n->owningBlock()); |
201 | return true; |
202 | } |
203 | if (n->kind() == aten::__not__ && |
204 | n->inputs().at(0)->type()->cast<BoolType>()) { |
205 | // __not__(inp) -> reverse refinements |
206 | if (info.count(n->input())) { |
207 | auto& input_ref = info[n->input()]; |
208 | info[n->output()] = BooleanRefinementMapping( |
209 | input_ref.false_refine(), input_ref.true_refine()); |
210 | } |
211 | return true; |
212 | } |
213 | if (n->matches("aten::eq(bool a, bool b) -> bool" ) || |
214 | (n->matches("aten::ne(bool a, bool b) -> bool" ))) { |
215 | for (size_t const_index : {0, 1}) { |
216 | if (n->input(const_index)->node()->kind() != prim::Constant) { |
217 | continue; |
218 | } |
219 | auto const_input = constant_as<bool>(n->input(const_index)).value(); |
220 | auto non_const_input = n->input(1 - const_index); |
221 | if (!info.count(non_const_input)) { |
222 | continue; |
223 | } |
224 | // value == False / value != True -> equivalent to __not__ value |
225 | // value == True / value != False -> equivalent to value |
226 | auto& input_ref = info[non_const_input]; |
227 | if ((!const_input && n->kind() == aten::eq) || |
228 | (const_input && n->kind() == aten::ne)) { |
229 | info[n->output()] = BooleanRefinementMapping( |
230 | input_ref.false_refine(), input_ref.true_refine()); |
231 | } else { |
232 | info[n->output()] = BooleanRefinementMapping( |
233 | input_ref.true_refine(), input_ref.false_refine()); |
234 | } |
235 | } |
236 | return true; |
237 | } |
238 | return false; |
239 | } |
240 | |
241 | } // namespace jit |
242 | } // namespace torch |
243 | |