1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/ir_builder.h" |
3 | #include "taichi/ir/statements.h" |
4 | #include "taichi/ir/transforms.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/program/program.h" |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | // Demote Operations into pieces for backends to deal easier |
11 | class DemoteOperations : public BasicStmtVisitor { |
12 | public: |
13 | using BasicStmtVisitor::visit; |
14 | DelayedIRModifier modifier; |
15 | |
16 | DemoteOperations() { |
17 | } |
18 | |
19 | std::unique_ptr<Stmt> demote_ifloordiv(BinaryOpStmt *stmt, |
20 | Stmt *lhs, |
21 | Stmt *rhs) { |
22 | auto ret = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs); |
23 | auto zero = Stmt::make<ConstStmt>(TypedConstant(0)); |
24 | auto lhs_ltz = |
25 | Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, lhs, zero.get()); |
26 | auto rhs_ltz = |
27 | Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, rhs, zero.get()); |
28 | auto rhs_mul_ret = |
29 | Stmt::make<BinaryOpStmt>(BinaryOpType::mul, rhs, ret.get()); |
30 | auto cond1 = Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, lhs_ltz.get(), |
31 | rhs_ltz.get()); |
32 | auto cond2 = |
33 | Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, lhs, zero.get()); |
34 | auto cond3 = |
35 | Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, rhs_mul_ret.get(), lhs); |
36 | auto cond12 = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, cond1.get(), |
37 | cond2.get()); |
38 | auto cond = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, cond12.get(), |
39 | cond3.get()); |
40 | auto real_ret = |
41 | Stmt::make<BinaryOpStmt>(BinaryOpType::add, ret.get(), cond.get()); |
42 | modifier.insert_before(stmt, std::move(ret)); |
43 | modifier.insert_before(stmt, std::move(zero)); |
44 | modifier.insert_before(stmt, std::move(lhs_ltz)); |
45 | modifier.insert_before(stmt, std::move(rhs_ltz)); |
46 | modifier.insert_before(stmt, std::move(rhs_mul_ret)); |
47 | modifier.insert_before(stmt, std::move(cond1)); |
48 | modifier.insert_before(stmt, std::move(cond2)); |
49 | modifier.insert_before(stmt, std::move(cond3)); |
50 | modifier.insert_before(stmt, std::move(cond12)); |
51 | modifier.insert_before(stmt, std::move(cond)); |
52 | return real_ret; |
53 | } |
54 | |
55 | std::unique_ptr<Stmt> demote_ffloor(BinaryOpStmt *stmt, |
56 | Stmt *lhs, |
57 | Stmt *rhs) { |
58 | auto div = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs); |
59 | auto floor = Stmt::make<UnaryOpStmt>(UnaryOpType::floor, div.get()); |
60 | modifier.insert_before(stmt, std::move(div)); |
61 | return floor; |
62 | } |
63 | |
64 | void visit(BinaryOpStmt *stmt) override { |
65 | auto lhs = stmt->lhs; |
66 | auto rhs = stmt->rhs; |
67 | if (stmt->op_type == BinaryOpType::floordiv) { |
68 | if (is_integral(rhs->element_type()) && |
69 | is_integral(lhs->element_type())) { |
70 | // @ti.func |
71 | // def ifloordiv(a, b): |
72 | // r = ti.raw_div(a, b) |
73 | // if (a < 0) != (b < 0) and a and b * r != a: |
74 | // r = r - 1 |
75 | // return r |
76 | // |
77 | // simply `a * b < 0` may leads to overflow (#969) |
78 | // |
79 | // Formal Anti-Regression Verification (FARV): |
80 | // |
81 | // old = a * b < 0 |
82 | // new = (a < 0) != (b < 0) && a |
83 | // |
84 | // a b old new |
85 | // - - f = f (f&t) |
86 | // - + t = t (t&t) |
87 | // 0 - f = f (t&f) |
88 | // 0 + f = f (f&f) |
89 | // + - t = t (t&t) |
90 | // + + f = f (f&t) |
91 | // |
92 | // the situation of `b = 0` is ignored since we get FPE anyway. |
93 | auto real_ret = demote_ifloordiv(stmt, lhs, rhs); |
94 | real_ret->ret_type = stmt->ret_type; |
95 | stmt->replace_usages_with(real_ret.get()); |
96 | modifier.insert_before(stmt, std::move(real_ret)); |
97 | modifier.erase(stmt); |
98 | |
99 | } else if (is_real(rhs->element_type()) || is_real(lhs->element_type())) { |
100 | // @ti.func |
101 | // def ffloordiv(a, b): |
102 | // r = ti.raw_div(a, b) |
103 | // return ti.floor(r) |
104 | auto floor = demote_ffloor(stmt, lhs, rhs); |
105 | floor->ret_type = stmt->ret_type; |
106 | stmt->replace_usages_with(floor.get()); |
107 | modifier.insert_before(stmt, std::move(floor)); |
108 | modifier.erase(stmt); |
109 | } else if (lhs->ret_type->is<TensorType>() && |
110 | rhs->ret_type->is<TensorType>()) { |
111 | bool use_integral = is_integral(lhs->ret_type.get_element_type()) && |
112 | is_integral(rhs->ret_type.get_element_type()); |
113 | std::vector<Stmt *> ret_stmts; |
114 | auto lhs_tensor_ty = lhs->ret_type->cast<TensorType>(); |
115 | auto rhs_tensor_ty = rhs->ret_type->cast<TensorType>(); |
116 | auto lhs_alloca = Stmt::make<AllocaStmt>(lhs_tensor_ty); |
117 | auto rhs_alloca = Stmt::make<AllocaStmt>(rhs_tensor_ty); |
118 | auto lhs_store = |
119 | Stmt::make<LocalStoreStmt>(lhs_alloca.get(), stmt->lhs); |
120 | auto rhs_store = |
121 | Stmt::make<LocalStoreStmt>(rhs_alloca.get(), stmt->rhs); |
122 | auto lhs_ptr = lhs_alloca.get(); |
123 | auto rhs_ptr = rhs_alloca.get(); |
124 | modifier.insert_before(stmt, std::move(lhs_alloca)); |
125 | modifier.insert_before(stmt, std::move(rhs_alloca)); |
126 | modifier.insert_before(stmt, std::move(lhs_store)); |
127 | modifier.insert_before(stmt, std::move(rhs_store)); |
128 | for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) { |
129 | auto idx = Stmt::make<ConstStmt>(TypedConstant(i)); |
130 | auto lhs_i = Stmt::make<MatrixPtrStmt>(lhs_ptr, idx.get()); |
131 | auto rhs_i = Stmt::make<MatrixPtrStmt>(rhs_ptr, idx.get()); |
132 | auto lhs_load = Stmt::make<LocalLoadStmt>(lhs_i.get()); |
133 | auto rhs_load = Stmt::make<LocalLoadStmt>(rhs_i.get()); |
134 | auto cur_lhs = lhs_load.get(); |
135 | auto cur_rhs = rhs_load.get(); |
136 | modifier.insert_before(stmt, std::move(idx)); |
137 | modifier.insert_before(stmt, std::move(lhs_i)); |
138 | modifier.insert_before(stmt, std::move(rhs_i)); |
139 | modifier.insert_before(stmt, std::move(lhs_load)); |
140 | modifier.insert_before(stmt, std::move(rhs_load)); |
141 | auto ret_i = use_integral ? demote_ifloordiv(stmt, cur_lhs, cur_rhs) |
142 | : demote_ffloor(stmt, cur_lhs, cur_rhs); |
143 | ret_stmts.push_back(ret_i.get()); |
144 | modifier.insert_before(stmt, std::move(ret_i)); |
145 | } |
146 | auto new_matrix = Stmt::make<MatrixInitStmt>(ret_stmts); |
147 | new_matrix->ret_type = stmt->ret_type; |
148 | stmt->replace_usages_with(new_matrix.get()); |
149 | modifier.insert_before(stmt, std::move(new_matrix)); |
150 | modifier.erase(stmt); |
151 | } |
152 | } else if (stmt->op_type == BinaryOpType::bit_shr && |
153 | is_integral(lhs->element_type()) && |
154 | is_integral(rhs->element_type()) && |
155 | is_signed(lhs->element_type())) { |
156 | // @ti.func |
157 | // def bit_shr(a, b): |
158 | // unsigned_a = ti.cast(a, ti.uXX) |
159 | // shifted = ti.bit_sar(unsigned_a, b) |
160 | // ret = ti.cast(shifted, ti.iXX) |
161 | // return ret |
162 | auto unsigned_cast = Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, lhs); |
163 | unsigned_cast->as<UnaryOpStmt>()->cast_type = |
164 | to_unsigned(lhs->element_type()); |
165 | auto shift = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_sar, |
166 | unsigned_cast.get(), rhs); |
167 | auto signed_cast = |
168 | Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, shift.get()); |
169 | signed_cast->as<UnaryOpStmt>()->cast_type = lhs->element_type(); |
170 | signed_cast->ret_type = stmt->ret_type; |
171 | stmt->replace_usages_with(signed_cast.get()); |
172 | modifier.insert_before(stmt, std::move(unsigned_cast)); |
173 | modifier.insert_before(stmt, std::move(shift)); |
174 | modifier.insert_before(stmt, std::move(signed_cast)); |
175 | modifier.erase(stmt); |
176 | } else if (stmt->op_type == BinaryOpType::pow && |
177 | is_integral(rhs->element_type())) { |
178 | // @ti.func |
179 | // def pow(lhs, rhs): |
180 | // a = lhs |
181 | // b = abs(rhs) |
182 | // result = 1 |
183 | // while b > 0: |
184 | // if b & 1: |
185 | // result *= a |
186 | // a *= a |
187 | // b >>= 1 |
188 | // if rhs < 0: # for real lhs |
189 | // result = 1 / result # for real lhs |
190 | // return result |
191 | IRBuilder builder; |
192 | auto one_lhs = builder.get_constant(lhs->element_type(), 1); |
193 | auto one_rhs = builder.get_constant(rhs->element_type(), 1); |
194 | auto zero_rhs = builder.get_constant(rhs->element_type(), 0); |
195 | auto a = builder.create_local_var(lhs->element_type()); |
196 | builder.create_local_store(a, lhs); |
197 | auto b = builder.create_local_var(rhs->element_type()); |
198 | builder.create_local_store(b, builder.create_abs(rhs)); |
199 | auto result = builder.create_local_var(lhs->element_type()); |
200 | builder.create_local_store(result, one_lhs); |
201 | auto loop = builder.create_while_true(); |
202 | { |
203 | auto loop_guard = builder.get_loop_guard(loop); |
204 | auto current_a = builder.create_local_load(a); |
205 | auto current_b = builder.create_local_load(b); |
206 | auto if_stmt = |
207 | builder.create_if(builder.create_cmp_le(current_b, zero_rhs)); |
208 | { |
209 | auto _ = builder.get_if_guard(if_stmt, true); |
210 | builder.create_break(); |
211 | } |
212 | auto bit_and = builder.create_and(current_b, one_rhs); |
213 | if_stmt = builder.create_if(builder.create_cmp_ne(bit_and, zero_rhs)); |
214 | { |
215 | auto _ = builder.get_if_guard(if_stmt, true); |
216 | auto current_result = builder.create_local_load(result); |
217 | auto new_result = builder.create_mul(current_result, current_a); |
218 | builder.create_local_store(result, new_result); |
219 | } |
220 | auto new_a = builder.create_mul(current_a, current_a); |
221 | builder.create_local_store(a, new_a); |
222 | auto new_b = builder.create_sar(current_b, one_rhs); |
223 | builder.create_local_store(b, new_b); |
224 | } |
225 | if (is_real(lhs->element_type())) { |
226 | auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs)); |
227 | { |
228 | auto _ = builder.get_if_guard(if_stmt, true); |
229 | auto current_result = builder.create_local_load(result); |
230 | auto new_result = builder.create_div(one_lhs, current_result); |
231 | builder.create_local_store(result, new_result); |
232 | } |
233 | } |
234 | auto final_result = builder.create_local_load(result); |
235 | stmt->replace_usages_with(final_result); |
236 | modifier.insert_before( |
237 | stmt, VecStatement(std::move(builder.extract_ir()->statements))); |
238 | modifier.erase(stmt); |
239 | } |
240 | } |
241 | |
242 | static bool run(IRNode *node, const CompileConfig &config) { |
243 | DemoteOperations demoter; |
244 | bool modified = false; |
245 | while (true) { |
246 | node->accept(&demoter); |
247 | if (demoter.modifier.modify_ir()) |
248 | modified = true; |
249 | else |
250 | break; |
251 | irpass::type_check(node, config); |
252 | } |
253 | if (modified) { |
254 | irpass::type_check(node, config); |
255 | } |
256 | return modified; |
257 | } |
258 | }; |
259 | |
260 | namespace irpass { |
261 | |
262 | bool demote_operations(IRNode *root, const CompileConfig &config) { |
263 | TI_AUTO_PROF; |
264 | bool modified = DemoteOperations::run(root, config); |
265 | return modified; |
266 | } |
267 | |
268 | } // namespace irpass |
269 | |
270 | } // namespace taichi::lang |
271 | |