1#include <ATen/core/jit_type.h>
2#include <torch/csrc/jit/ir/ir.h>
3#include <torch/csrc/jit/jit_log.h>
4#include <torch/csrc/jit/passes/integer_value_refinement.h>
5#include <torch/csrc/jit/passes/value_refinement_utils.h>
6#include <torch/csrc/utils/memory.h>
7
8#include <utility>
9
10namespace torch {
11namespace jit {
12
13using IntegerRefinement = std::unordered_map<Value*, int64_t>;
14
// See [value refinement algorithm] for a full explanation.
// When a comparison like `cond = x == 4` or `cond = x != 4` is made,
// `cond` carries information (refinements) about the value of `x`.
// In an example like:
// if x == 1:
//   ...
// we can substitute all uses of x dominated by the true block
// with 1.
23
24struct IntegerValueRefiner {
25 IntegerValueRefiner(std::shared_ptr<Graph> graph)
26 : graph_(std::move(graph)) {}
27
28 bool run() {
29 if (!blockHasIntComparisons(graph_->block())) {
30 return false;
31 }
32 IntegerRefinement refinements;
33 RefineIntegerValues(graph_->block(), std::move(refinements));
34 return changed_;
35 }
36
37 bool blockHasIntComparisons(Block* b) {
38 for (Node* n : b->nodes()) {
39 if (n->matches("aten::eq(int a, int b) -> bool") ||
40 n->matches("aten::ne(int a, int b) -> bool")) {
41 for (size_t const_index : {0, 1}) {
42 auto non_const_index = 1 - const_index;
43 if (n->inputs().at(const_index)->node()->kind() == prim::Constant &&
44 n->inputs().at(non_const_index)->uses().size() > 1) {
45 return true;
46 }
47 }
48 }
49 for (Block* block : n->blocks()) {
50 if (blockHasIntComparisons(block)) {
51 return true;
52 }
53 }
54 }
55 return false;
56 }
57
58 void removeIfNodeOutputsWithRefinements(
59 Node* if_node,
60 IntegerRefinement& true_block_refinements,
61 IntegerRefinement& false_block_refinements) {
62 // we are looking for cases where we can replace both block outputs with the
63 // same value, which opens up further optimization opportunities. The pass
64 // will already handle if both outputs are refined to the same constant.
65 // Here, we look for cases where one block output has been refined in the
66 // other block to be equal to the same constant value as the other other
67 // block output:
68 // graph(%y.1 : int):
69 // %one_constant : int = prim::Constant[value=1]()
70 // %3 : bool = aten::eq(%y.1, %one_constant)
71 // %15 : int = prim::If(%3)
72 // block0():
73 // -> (%one_constant)
74 // block1():
75 // -> (%y.1)
76 // return (%15)
77 // %15 can always be safely replaced with %y.1
78 // this is an important case for symbolic shape analysis
79 for (size_t block_index : {0, 1}) {
80 Block* if_block = if_node->blocks().at(block_index);
81 Block* other_if_block = if_node->blocks().at(1 - block_index);
82 for (size_t i = 0; i < if_node->outputs().size(); ++i) {
83 Value* block_output = if_block->outputs().at(i);
84 if (!block_output->type()->cast<IntType>()) {
85 continue;
86 }
87 // Value must be in scope for both blocks
88 // in example above, %y.1 cannot be defined in block1
89 if (!if_node->isDominatedBy(block_output->node())) {
90 continue;
91 }
92 // one constant value one not - we are looking for the pattern
93 // where y.1 is refined to the existing block output %one_constant
94 auto other_output = other_if_block->outputs().at(i);
95 auto other_const_value = other_output->type()->cast<IntType>()
96 ? constant_as<int64_t>(other_output)
97 : c10::nullopt;
98 if (!other_const_value ||
99 block_output->node()->kind() == prim::Constant) {
100 continue;
101 }
102 // here, we are looking in refinements in the other block of our
103 // current output. in the example, we are looking for refinements of
104 // %y.1 in `block0`, and we are checking that %y.1 is refined
105 // to the constant value of %one_constant
106 const auto& other_block_refinements =
107 block_index == 0 ? false_block_refinements : true_block_refinements;
108 if (!other_block_refinements.count(block_output)) {
109 continue;
110 }
111 if (other_block_refinements.at(block_output) == *other_const_value) {
112 if_node->outputs().at(i)->replaceAllUsesWith(block_output);
113 changed_ = true;
114 }
115 }
116 }
117 }
118
119 // iteratively look through the block `b` for refinements or Value uses that
120 // can be refined, `block_refinements` are the refinements present starting at
121 // this block (and for all blocks dominated by this block).
122 IntegerRefinement RefineIntegerValues(
123 Block* b,
124 IntegerRefinement block_refinements) {
125 active_refinements_.push_back(&block_refinements);
126 for (Node* n : b->nodes()) {
127 if (n->matches("aten::eq(int a, int b) -> bool") ||
128 n->matches("aten::ne(int a, int b) -> bool")) {
129 for (size_t const_index : {0, 1}) {
130 if (auto ival = constant_as<int64_t>(n->inputs().at(const_index))) {
131 IntegerRefinement refine;
132 refine[n->inputs().at(1 - const_index)] = *ival;
133 info_[n->output()] = n->kind() == aten::eq
134 ? BooleanRefinementMapping::TrueRefinements(std::move(refine))
135 : BooleanRefinementMapping::FalseRefinements(std::move(refine));
136 }
137 }
138 }
139 for (size_t input = 0; input < n->inputs().size(); ++input) {
140 Value* input_v = n->inputs().at(input);
141 if (!input_v->type()->cast<IntType>()) {
142 continue;
143 }
144
145 if (auto refine = tryFindRefinement(input_v)) {
146 WithInsertPoint guard(n);
147 auto refine_constant =
148 graph_->insertConstant(static_cast<int64_t>(*refine));
149 n->replaceInputWith(input_v, refine_constant);
150 changed_ = true;
151 }
152 }
153
154 if (n->kind() == prim::If) {
155 IfView if_n(n);
156 bool has_cond_ref = info_.count(if_n.cond()) != 0;
157 IntegerRefinement empty;
158 auto true_block_refinements = RefineIntegerValues(
159 if_n.thenBlock(),
160 has_cond_ref ? info_[if_n.cond()].true_refine() : empty);
161 auto false_block_refinements = RefineIntegerValues(
162 if_n.elseBlock(),
163 has_cond_ref ? info_[if_n.cond()].false_refine() : empty);
164
165 removeIfNodeOutputsWithRefinements(
166 n, true_block_refinements, false_block_refinements);
167
168 joinIfRefinements(
169 n,
170 throwing_blocks_,
171 block_refinements,
172 true_block_refinements,
173 false_block_refinements,
174 info_);
175 } else {
176 handleCommonRefinentOperators(n, throwing_blocks_, info_);
177 }
178 }
179
180 // iterating over all nodes in the block will not iterate over
181 // block outputs, so we need to add handling of them.
182 // %3 : int = prim::Constant[value=3]()
183 // %4 : bool = aten::eq(%y.1, %3)
184 // %a : int = prim::If(%4)
185 // block0():
186 // -> (%y.1)
187 // Here, we can replace y.1 with 3
188
189 for (size_t i = 0; i < b->outputs().size(); ++i) {
190 Value* output_v = b->outputs().at(i);
191 if (!output_v->type()->cast<IntType>()) {
192 continue;
193 }
194
195 if (auto refine = tryFindRefinement(output_v)) {
196 WithInsertPoint guard(b);
197 auto refine_constant =
198 graph_->insertConstant(static_cast<int64_t>(*refine));
199 b->replaceOutput(i, refine_constant);
200 changed_ = true;
201 }
202 }
203
204 active_refinements_.pop_back();
205 return block_refinements;
206 };
207
208 c10::optional<int64_t> tryFindRefinement(Value* v) {
209 for (const auto& ref : active_refinements_) {
210 auto maybe_refinement = ref->find(v);
211 if (maybe_refinement != ref->end()) {
212 return maybe_refinement->second;
213 }
214 }
215 return c10::nullopt;
216 }
217
218 std::shared_ptr<Graph> graph_;
219 // A stack of active refinements, one for each block
220 std::vector<IntegerRefinement*> active_refinements_;
221 // A map from Boolean Value * -> associated refinements
222 std::unordered_map<Value*, BooleanRefinementMapping> info_;
223 std::unordered_set<Block*> throwing_blocks_;
224 bool changed_ = false;
225};
226
227bool RefineIntegerValues(const std::shared_ptr<Graph>& graph) {
228 return IntegerValueRefiner(graph).run();
229}
230
231} // namespace jit
232} // namespace torch
233