1#include <c10/util/irange.h>
2#include <torch/csrc/jit/passes/value_refinement_utils.h>
3
4namespace torch {
5namespace jit {
6
// [value refinement algorithm]
//
// When a comparison like `cond = len(x) == 4` or `cond = len(x) != 4` is made,
// the value `cond` carries information (refinements) about the length of `x`.
// When `cond` is used as the condition of an if statement, the refinements that
// hold when it is true can be inserted into the true block, and the refinements
// that hold when it is false can be inserted into the false block.
// For something like `y = len(x) if len(x) == 1 else 1`, in the true branch
// we can replace len(x) with 1 because the true refinements from `len(x) == 1`
// are present in the true block.
// Additionally, we can optimize something like:
//   if len(x) != 4:
//     raise Exception(...)
//   return len(x)
// Because the true block always throws, whatever refinements exist in the false
// block become present in the owning block of the if node. We can also merge
// refinements carried by two different booleans across an if node join by
// taking the intersection of their refinements:
//   if cond:
//     z = len(x) == 4 and len(y) == 5
//   else:
//     z = len(x) == 4
// Here, z's true value will refine len(x) to 4, but not len(y).
// If the code was instead written as:
//   if cond:
//     z = len(x) == 4 and len(y) == 5
//   else:
//     z = False
//
// then z's true value would refine both x and y, because if z is true it must
// have come from the true block. Code written with `and` or `or` desugars to
// something similar. Additionally, any true refinements that were present on
// `cond` can also be associated with the if node's true output value.
40
41// The intersection of the refinements is the Value* which are in both
42// refinements and are refined to the same length
43// in an example like:
44// if cond:
45// x = len(a) == 4 and len(b) == 5
46// else:
47// x = len(a) == 4
48// For the x output of the node we take the intersection between
49// the refinements stored on each block output, which will result
50// in only the refinement of len(a) == 4
51ListRefinement intersectRefinements(
52 const ListRefinement& ref1,
53 const ListRefinement& ref2) {
54 ListRefinement out;
55 for (const auto& pair : ref1) {
56 auto val2 = ref2.find(pair.first);
57 if (val2 != ref2.end() && val2->second == pair.second) {
58 out[pair.first] = pair.second;
59 }
60 }
61 return out;
62}
63
64// To union, just take all refinements from both inputs. We do not need to worry
65// about len refinements disagreeing because a path like `if len(x) == 4 and
66// len(x) == 5` will never be taken
67// in an example like:
68// if len(a) == 5:
69// x = len(b) == 4
70// else:
71// x = False
72// For the output x Value, if is true then the refinements present in the true
73// block must also be true, so we take the union of `len(a) == 5` and len(b) ==
74// 4` and assign them to true refinements of the output x value. This is a very
75// common pattern in desugaring of `and` or `or` boolean expressions
76ListRefinement unionRefinements(
77 const ListRefinement& ref1,
78 const ListRefinement& ref2) {
79 ListRefinement out = ref1;
80 out.insert(ref2.begin(), ref2.end());
81 return out;
82}
83
84void joinIfRefinements(
85 Node* if_node,
86 std::unordered_set<Block*>& throwing_blocks,
87 ListRefinement& curr_block_refinements,
88 ListRefinement& true_block_refinements,
89 ListRefinement& false_block_refinements,
90 std::unordered_map<Value*, BooleanRefinementMapping>&
91 boolean_value_refinements) {
92 IfView if_n(if_node);
93 Block* b = if_node->owningBlock();
94
95 bool true_block_throws = throwing_blocks.count(if_n.thenBlock());
96 bool false_block_throws = throwing_blocks.count(if_n.elseBlock());
97
98 // if one block throws, the refinements for the other block
99 // become present in the current block, and all bool outputs
100 // of the if node take their refinements from non throwing block
101 // output
102
103 if (true_block_throws || false_block_throws) {
104 if (true_block_throws && false_block_throws) {
105 throwing_blocks.insert(b);
106 return;
107 }
108 if (true_block_throws) {
109 curr_block_refinements.insert(
110 false_block_refinements.begin(), false_block_refinements.end());
111 } else {
112 curr_block_refinements.insert(
113 true_block_refinements.begin(), true_block_refinements.end());
114 }
115 Block* non_throwing_block =
116 true_block_throws ? if_node->blocks().at(1) : if_node->blocks().at(0);
117 for (const auto i : c10::irange(if_n.outputs().size())) {
118 if (boolean_value_refinements.count(
119 non_throwing_block->outputs().at(i))) {
120 boolean_value_refinements[if_node->outputs().at(i)] =
121 boolean_value_refinements[non_throwing_block->outputs().at(i)];
122 }
123 }
124 return;
125 }
126
127 for (const auto i : c10::irange(if_n.outputs().size())) {
128 if (!(if_n.outputs().at(i)->type() == BoolType::get())) {
129 return;
130 }
131 Value* true_v = if_n.thenOutputs().at(i);
132 Value* false_v = if_n.elseOutputs().at(i);
133
134 if (!boolean_value_refinements.count(true_v) &&
135 !boolean_value_refinements.count(false_v) &&
136 !constant_as<bool>(true_v) && !constant_as<bool>(false_v)) {
137 return;
138 }
139
140 // if either block has a constant bool output, e.g. `true` on the
141 // true block, then for the `false` value we can take the false
142 // refinements present on the false block and from the other block
143 // output value bc if the output is false it had to have come from the
144 // false block. if len(a) == 5:
145 // x = len(b) == 4
146 // else:
147 // x = False
148 // if x is true, then we know both len(a) == 5 and len(b) == 4
149 //
150 // if neither block has a constant bool value, we just take the
151 // intersection of the refinements from boolean outputs.
152 // if cond:
153 // x = len(a) == 4 and len(b) == 5
154 // else:
155 // x = len(a) == 4
156 // here, we know if x is true, then len(a) == 4, but not len(b)
157 // == 5, because that refinement is not present in the true block.
158 // TODO: could also take intersection of refinements present in
159 // both blocks, but it's not a real use case.
160
161 // boolean_value_refinements[value] is safe to access because
162 // BooleanRefinementMapping has a default constructor
163
164 BooleanRefinementMapping out;
165 if (auto maybe_bool = constant_as<bool>(true_v)) {
166 if (*maybe_bool) {
167 out = BooleanRefinementMapping::FalseRefinements(unionRefinements(
168 boolean_value_refinements[false_v].false_refine(),
169 false_block_refinements));
170 } else {
171 out = BooleanRefinementMapping::TrueRefinements(unionRefinements(
172 boolean_value_refinements[false_v].true_refine(),
173 false_block_refinements));
174 }
175 } else if (auto maybe_bool = constant_as<bool>(false_v)) {
176 if (*maybe_bool) {
177 out = BooleanRefinementMapping::FalseRefinements(unionRefinements(
178 boolean_value_refinements[true_v].false_refine(),
179 true_block_refinements));
180 } else {
181 out = BooleanRefinementMapping::TrueRefinements(unionRefinements(
182 boolean_value_refinements[true_v].true_refine(),
183 true_block_refinements));
184 }
185 } else if (
186 boolean_value_refinements.count(true_v) &&
187 boolean_value_refinements.count(false_v)) {
188 out = boolean_value_refinements[true_v].intersectBooleanRefinementMapping(
189 boolean_value_refinements[false_v]);
190 }
191 boolean_value_refinements[if_n.outputs().at(i)] = out;
192 }
193}
194
195bool handleCommonRefinentOperators(
196 Node* n,
197 std::unordered_set<Block*>& throwing_blocks,
198 std::unordered_map<Value*, BooleanRefinementMapping>& info) {
199 if (n->kind() == prim::RaiseException) {
200 throwing_blocks.insert(n->owningBlock());
201 return true;
202 }
203 if (n->kind() == aten::__not__ &&
204 n->inputs().at(0)->type()->cast<BoolType>()) {
205 // __not__(inp) -> reverse refinements
206 if (info.count(n->input())) {
207 auto& input_ref = info[n->input()];
208 info[n->output()] = BooleanRefinementMapping(
209 input_ref.false_refine(), input_ref.true_refine());
210 }
211 return true;
212 }
213 if (n->matches("aten::eq(bool a, bool b) -> bool") ||
214 (n->matches("aten::ne(bool a, bool b) -> bool"))) {
215 for (size_t const_index : {0, 1}) {
216 if (n->input(const_index)->node()->kind() != prim::Constant) {
217 continue;
218 }
219 auto const_input = constant_as<bool>(n->input(const_index)).value();
220 auto non_const_input = n->input(1 - const_index);
221 if (!info.count(non_const_input)) {
222 continue;
223 }
224 // value == False / value != True -> equivalent to __not__ value
225 // value == True / value != False -> equivalent to value
226 auto& input_ref = info[non_const_input];
227 if ((!const_input && n->kind() == aten::eq) ||
228 (const_input && n->kind() == aten::ne)) {
229 info[n->output()] = BooleanRefinementMapping(
230 input_ref.false_refine(), input_ref.true_refine());
231 } else {
232 info[n->output()] = BooleanRefinementMapping(
233 input_ref.true_refine(), input_ref.false_refine());
234 }
235 }
236 return true;
237 }
238 return false;
239}
240
241} // namespace jit
242} // namespace torch
243