1 | // This pass analyzes compile-time known offsets for two values. |
2 | |
#include <algorithm>
#include <cstdint>
#include <limits>
#include <optional>

#include "taichi/ir/ir.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/visitors.h"
7 | |
8 | namespace taichi::lang { |
9 | |
10 | DiffRange operator+(const DiffRange &a, const DiffRange &b) { |
11 | return DiffRange(a.related() && b.related(), a.coeff + b.coeff, a.low + b.low, |
12 | a.high + b.high - 1); |
13 | } |
14 | |
15 | DiffRange operator-(const DiffRange &a, const DiffRange &b) { |
16 | return DiffRange(a.related() && b.related(), a.coeff - b.coeff, |
17 | a.low - b.high + 1, a.high - b.low); |
18 | } |
19 | |
20 | DiffRange operator*(const DiffRange &a, const DiffRange &b) { |
21 | return DiffRange( |
22 | a.related() && b.related() && a.coeff * b.coeff == 0, |
23 | fmax(a.low * b.coeff, a.coeff * b.low), |
24 | fmin(a.low * b.low, |
25 | fmin(a.low * (b.high - 1), |
26 | fmin(b.low * (a.high - 1), (a.high - 1) * (b.high - 1)))), |
27 | fmax(a.low * b.low, |
28 | fmax(a.low * (b.high - 1), |
29 | fmax(b.low * (a.high - 1), (a.high - 1) * (b.high - 1)))) + |
30 | 1); |
31 | } |
32 | |
33 | DiffRange operator<<(const DiffRange &a, const DiffRange &b) { |
34 | return DiffRange( |
35 | a.related() && b.related() && b.coeff == 0 && b.high - b.low == 1, |
36 | a.coeff << b.low, a.low << b.low, ((a.high - 1) << b.low) + 1); |
37 | } |
38 | |
39 | namespace { |
40 | |
41 | class ValueDiffLoopIndex : public IRVisitor { |
42 | public: |
43 | // first: related, second: offset |
44 | using ret_type = DiffRange; |
45 | Stmt *input_stmt, *loop; |
46 | int loop_index; |
47 | std::map<int, ret_type> results; |
48 | |
49 | ValueDiffLoopIndex(Stmt *stmt, Stmt *loop, int loop_index) |
50 | : input_stmt(stmt), loop(loop), loop_index(loop_index) { |
51 | allow_undefined_visitor = true; |
52 | invoke_default_visitor = true; |
53 | } |
54 | |
55 | void visit(Stmt *stmt) override { |
56 | results[stmt->instance_id] = DiffRange(); |
57 | } |
58 | |
59 | void visit(GlobalLoadStmt *stmt) override { |
60 | results[stmt->instance_id] = DiffRange(); |
61 | } |
62 | |
63 | void visit(LoopIndexStmt *stmt) override { |
64 | results[stmt->instance_id] = DiffRange(); |
65 | if (stmt->loop == loop && stmt->index == loop_index) { |
66 | results[stmt->instance_id] = |
67 | DiffRange(/*related=*/true, /*coeff=*/1, /*low=*/0); |
68 | } else if (auto range_for = stmt->loop->cast<RangeForStmt>()) { |
69 | if (range_for->begin->is<ConstStmt>() && |
70 | range_for->end->is<ConstStmt>()) { |
71 | auto begin_val = range_for->begin->as<ConstStmt>()->val.val_int(); |
72 | auto end_val = range_for->end->as<ConstStmt>()->val.val_int(); |
73 | // We have begin_val <= end_val even when range_for->reversed is true: |
74 | // in that case, the loop is iterated from end_val - 1 to begin_val. |
75 | results[stmt->instance_id] = DiffRange( |
76 | /*related=*/true, /*coeff=*/0, /*low=*/begin_val, /*high=*/end_val); |
77 | } |
78 | } |
79 | } |
80 | |
81 | void visit(ConstStmt *stmt) override { |
82 | if (stmt->val.dt->is_primitive(PrimitiveTypeID::i32)) { |
83 | results[stmt->instance_id] = DiffRange(true, 0, stmt->val.val_i32); |
84 | } else { |
85 | results[stmt->instance_id] = DiffRange(); |
86 | } |
87 | } |
88 | |
89 | void visit(RangeAssumptionStmt *stmt) override { |
90 | stmt->base->accept(this); |
91 | results[stmt->instance_id] = results[stmt->base->instance_id] + |
92 | DiffRange(true, 0, stmt->low, stmt->high); |
93 | } |
94 | |
95 | void visit(BinaryOpStmt *stmt) override { |
96 | if (stmt->op_type == BinaryOpType::add || |
97 | stmt->op_type == BinaryOpType::sub || |
98 | stmt->op_type == BinaryOpType::mul || |
99 | stmt->op_type == BinaryOpType::bit_shl) { |
100 | stmt->lhs->accept(this); |
101 | stmt->rhs->accept(this); |
102 | auto ret1 = results[stmt->lhs->instance_id]; |
103 | auto ret2 = results[stmt->rhs->instance_id]; |
104 | if (ret1.related() && ret2.related()) { |
105 | if (stmt->op_type == BinaryOpType::add) { |
106 | results[stmt->instance_id] = ret1 + ret2; |
107 | } else if (stmt->op_type == BinaryOpType::sub) { |
108 | results[stmt->instance_id] = ret1 - ret2; |
109 | } else if (stmt->op_type == BinaryOpType::mul) { |
110 | results[stmt->instance_id] = ret1 * ret2; |
111 | } else { |
112 | results[stmt->instance_id] = ret1 << ret2; |
113 | } |
114 | return; |
115 | } |
116 | } |
117 | results[stmt->instance_id] = {false, 0}; |
118 | } |
119 | |
120 | ret_type run() { |
121 | input_stmt->accept(this); |
122 | return results[input_stmt->instance_id]; |
123 | } |
124 | }; |
125 | |
126 | class FindDirectValueBaseAndOffset : public IRVisitor { |
127 | public: |
128 | // In the return value, <first> is true if this class finds that the input |
129 | // statement has value equal to <second> + <third> (base + offset), or |
130 | // <first> is false if this class can't find the decomposition. |
131 | using ret_type = std::tuple<bool, Stmt *, int>; |
132 | ret_type result; |
133 | FindDirectValueBaseAndOffset() : result(false, nullptr, 0) { |
134 | allow_undefined_visitor = true; |
135 | invoke_default_visitor = true; |
136 | } |
137 | |
138 | void visit(Stmt *stmt) override { |
139 | result = std::make_tuple(false, nullptr, 0); |
140 | } |
141 | |
142 | void visit(ConstStmt *stmt) override { |
143 | if (stmt->val.dt->is_primitive(PrimitiveTypeID::i32)) { |
144 | result = std::make_tuple(true, nullptr, stmt->val.val_i32); |
145 | } |
146 | } |
147 | |
148 | void visit(BinaryOpStmt *stmt) override { |
149 | if (stmt->rhs->is<ConstStmt>()) |
150 | stmt->rhs->accept(this); |
151 | if (!std::get<0>(result) || std::get<1>(result) != nullptr || |
152 | (stmt->op_type != BinaryOpType::add && |
153 | stmt->op_type != BinaryOpType::sub)) { |
154 | result = std::make_tuple(false, nullptr, 0); |
155 | return; |
156 | } |
157 | if (stmt->op_type == BinaryOpType::sub) |
158 | std::get<2>(result) = -std::get<2>(result); |
159 | std::get<1>(result) = stmt->lhs; |
160 | } |
161 | |
162 | static ret_type run(Stmt *val) { |
163 | FindDirectValueBaseAndOffset instance; |
164 | val->accept(&instance); |
165 | return instance.result; |
166 | } |
167 | }; |
168 | |
169 | } // namespace |
170 | |
171 | namespace irpass { |
172 | namespace analysis { |
173 | |
174 | DiffRange value_diff_loop_index(Stmt *stmt, Stmt *loop, int index_id) { |
175 | TI_ASSERT(loop->is<StructForStmt>() || loop->is<OffloadedStmt>()); |
176 | if (loop->is<OffloadedStmt>()) { |
177 | TI_ASSERT(loop->as<OffloadedStmt>()->task_type == |
178 | OffloadedStmt::TaskType::struct_for); |
179 | } |
180 | if (auto loop_index = stmt->cast<LoopIndexStmt>(); loop_index) { |
181 | if (loop_index->loop == loop && loop_index->index == index_id) { |
182 | return DiffRange(true, 1, 0); |
183 | } |
184 | } |
185 | auto diff = ValueDiffLoopIndex(stmt, loop, index_id); |
186 | return diff.run(); |
187 | } |
188 | |
189 | DiffPtrResult value_diff_ptr_index(Stmt *val1, Stmt *val2) { |
190 | if (val1 == val2) { |
191 | return DiffPtrResult::make_certain(0); |
192 | } |
193 | auto v1 = FindDirectValueBaseAndOffset::run(val1); |
194 | auto v2 = FindDirectValueBaseAndOffset::run(val2); |
195 | if (!std::get<0>(v1) || !std::get<0>(v2) || |
196 | std::get<1>(v1) != std::get<1>(v2)) { |
197 | return DiffPtrResult::make_uncertain(); |
198 | } |
199 | return DiffPtrResult::make_certain(std::get<2>(v1) - std::get<2>(v2)); |
200 | } |
201 | |
202 | } // namespace analysis |
203 | } // namespace irpass |
204 | } // namespace taichi::lang |
205 | |