1#include "taichi/ir/ir.h"
2#include "taichi/ir/ir_builder.h"
3#include "taichi/ir/statements.h"
4#include "taichi/ir/transforms.h"
5#include "taichi/ir/visitors.h"
6#include "taichi/program/program.h"
7
8namespace taichi::lang {
9
10// Demote Operations into pieces for backends to deal easier
11class DemoteOperations : public BasicStmtVisitor {
12 public:
13 using BasicStmtVisitor::visit;
14 DelayedIRModifier modifier;
15
16 DemoteOperations() {
17 }
18
19 std::unique_ptr<Stmt> demote_ifloordiv(BinaryOpStmt *stmt,
20 Stmt *lhs,
21 Stmt *rhs) {
22 auto ret = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs);
23 auto zero = Stmt::make<ConstStmt>(TypedConstant(0));
24 auto lhs_ltz =
25 Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, lhs, zero.get());
26 auto rhs_ltz =
27 Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, rhs, zero.get());
28 auto rhs_mul_ret =
29 Stmt::make<BinaryOpStmt>(BinaryOpType::mul, rhs, ret.get());
30 auto cond1 = Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, lhs_ltz.get(),
31 rhs_ltz.get());
32 auto cond2 =
33 Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, lhs, zero.get());
34 auto cond3 =
35 Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, rhs_mul_ret.get(), lhs);
36 auto cond12 = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, cond1.get(),
37 cond2.get());
38 auto cond = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, cond12.get(),
39 cond3.get());
40 auto real_ret =
41 Stmt::make<BinaryOpStmt>(BinaryOpType::add, ret.get(), cond.get());
42 modifier.insert_before(stmt, std::move(ret));
43 modifier.insert_before(stmt, std::move(zero));
44 modifier.insert_before(stmt, std::move(lhs_ltz));
45 modifier.insert_before(stmt, std::move(rhs_ltz));
46 modifier.insert_before(stmt, std::move(rhs_mul_ret));
47 modifier.insert_before(stmt, std::move(cond1));
48 modifier.insert_before(stmt, std::move(cond2));
49 modifier.insert_before(stmt, std::move(cond3));
50 modifier.insert_before(stmt, std::move(cond12));
51 modifier.insert_before(stmt, std::move(cond));
52 return real_ret;
53 }
54
55 std::unique_ptr<Stmt> demote_ffloor(BinaryOpStmt *stmt,
56 Stmt *lhs,
57 Stmt *rhs) {
58 auto div = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs);
59 auto floor = Stmt::make<UnaryOpStmt>(UnaryOpType::floor, div.get());
60 modifier.insert_before(stmt, std::move(div));
61 return floor;
62 }
63
64 void visit(BinaryOpStmt *stmt) override {
65 auto lhs = stmt->lhs;
66 auto rhs = stmt->rhs;
67 if (stmt->op_type == BinaryOpType::floordiv) {
68 if (is_integral(rhs->element_type()) &&
69 is_integral(lhs->element_type())) {
70 // @ti.func
71 // def ifloordiv(a, b):
72 // r = ti.raw_div(a, b)
73 // if (a < 0) != (b < 0) and a and b * r != a:
74 // r = r - 1
75 // return r
76 //
77 // simply `a * b < 0` may leads to overflow (#969)
78 //
79 // Formal Anti-Regression Verification (FARV):
80 //
81 // old = a * b < 0
82 // new = (a < 0) != (b < 0) && a
83 //
84 // a b old new
85 // - - f = f (f&t)
86 // - + t = t (t&t)
87 // 0 - f = f (t&f)
88 // 0 + f = f (f&f)
89 // + - t = t (t&t)
90 // + + f = f (f&t)
91 //
92 // the situation of `b = 0` is ignored since we get FPE anyway.
93 auto real_ret = demote_ifloordiv(stmt, lhs, rhs);
94 real_ret->ret_type = stmt->ret_type;
95 stmt->replace_usages_with(real_ret.get());
96 modifier.insert_before(stmt, std::move(real_ret));
97 modifier.erase(stmt);
98
99 } else if (is_real(rhs->element_type()) || is_real(lhs->element_type())) {
100 // @ti.func
101 // def ffloordiv(a, b):
102 // r = ti.raw_div(a, b)
103 // return ti.floor(r)
104 auto floor = demote_ffloor(stmt, lhs, rhs);
105 floor->ret_type = stmt->ret_type;
106 stmt->replace_usages_with(floor.get());
107 modifier.insert_before(stmt, std::move(floor));
108 modifier.erase(stmt);
109 } else if (lhs->ret_type->is<TensorType>() &&
110 rhs->ret_type->is<TensorType>()) {
111 bool use_integral = is_integral(lhs->ret_type.get_element_type()) &&
112 is_integral(rhs->ret_type.get_element_type());
113 std::vector<Stmt *> ret_stmts;
114 auto lhs_tensor_ty = lhs->ret_type->cast<TensorType>();
115 auto rhs_tensor_ty = rhs->ret_type->cast<TensorType>();
116 auto lhs_alloca = Stmt::make<AllocaStmt>(lhs_tensor_ty);
117 auto rhs_alloca = Stmt::make<AllocaStmt>(rhs_tensor_ty);
118 auto lhs_store =
119 Stmt::make<LocalStoreStmt>(lhs_alloca.get(), stmt->lhs);
120 auto rhs_store =
121 Stmt::make<LocalStoreStmt>(rhs_alloca.get(), stmt->rhs);
122 auto lhs_ptr = lhs_alloca.get();
123 auto rhs_ptr = rhs_alloca.get();
124 modifier.insert_before(stmt, std::move(lhs_alloca));
125 modifier.insert_before(stmt, std::move(rhs_alloca));
126 modifier.insert_before(stmt, std::move(lhs_store));
127 modifier.insert_before(stmt, std::move(rhs_store));
128 for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) {
129 auto idx = Stmt::make<ConstStmt>(TypedConstant(i));
130 auto lhs_i = Stmt::make<MatrixPtrStmt>(lhs_ptr, idx.get());
131 auto rhs_i = Stmt::make<MatrixPtrStmt>(rhs_ptr, idx.get());
132 auto lhs_load = Stmt::make<LocalLoadStmt>(lhs_i.get());
133 auto rhs_load = Stmt::make<LocalLoadStmt>(rhs_i.get());
134 auto cur_lhs = lhs_load.get();
135 auto cur_rhs = rhs_load.get();
136 modifier.insert_before(stmt, std::move(idx));
137 modifier.insert_before(stmt, std::move(lhs_i));
138 modifier.insert_before(stmt, std::move(rhs_i));
139 modifier.insert_before(stmt, std::move(lhs_load));
140 modifier.insert_before(stmt, std::move(rhs_load));
141 auto ret_i = use_integral ? demote_ifloordiv(stmt, cur_lhs, cur_rhs)
142 : demote_ffloor(stmt, cur_lhs, cur_rhs);
143 ret_stmts.push_back(ret_i.get());
144 modifier.insert_before(stmt, std::move(ret_i));
145 }
146 auto new_matrix = Stmt::make<MatrixInitStmt>(ret_stmts);
147 new_matrix->ret_type = stmt->ret_type;
148 stmt->replace_usages_with(new_matrix.get());
149 modifier.insert_before(stmt, std::move(new_matrix));
150 modifier.erase(stmt);
151 }
152 } else if (stmt->op_type == BinaryOpType::bit_shr &&
153 is_integral(lhs->element_type()) &&
154 is_integral(rhs->element_type()) &&
155 is_signed(lhs->element_type())) {
156 // @ti.func
157 // def bit_shr(a, b):
158 // unsigned_a = ti.cast(a, ti.uXX)
159 // shifted = ti.bit_sar(unsigned_a, b)
160 // ret = ti.cast(shifted, ti.iXX)
161 // return ret
162 auto unsigned_cast = Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, lhs);
163 unsigned_cast->as<UnaryOpStmt>()->cast_type =
164 to_unsigned(lhs->element_type());
165 auto shift = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_sar,
166 unsigned_cast.get(), rhs);
167 auto signed_cast =
168 Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, shift.get());
169 signed_cast->as<UnaryOpStmt>()->cast_type = lhs->element_type();
170 signed_cast->ret_type = stmt->ret_type;
171 stmt->replace_usages_with(signed_cast.get());
172 modifier.insert_before(stmt, std::move(unsigned_cast));
173 modifier.insert_before(stmt, std::move(shift));
174 modifier.insert_before(stmt, std::move(signed_cast));
175 modifier.erase(stmt);
176 } else if (stmt->op_type == BinaryOpType::pow &&
177 is_integral(rhs->element_type())) {
178 // @ti.func
179 // def pow(lhs, rhs):
180 // a = lhs
181 // b = abs(rhs)
182 // result = 1
183 // while b > 0:
184 // if b & 1:
185 // result *= a
186 // a *= a
187 // b >>= 1
188 // if rhs < 0: # for real lhs
189 // result = 1 / result # for real lhs
190 // return result
191 IRBuilder builder;
192 auto one_lhs = builder.get_constant(lhs->element_type(), 1);
193 auto one_rhs = builder.get_constant(rhs->element_type(), 1);
194 auto zero_rhs = builder.get_constant(rhs->element_type(), 0);
195 auto a = builder.create_local_var(lhs->element_type());
196 builder.create_local_store(a, lhs);
197 auto b = builder.create_local_var(rhs->element_type());
198 builder.create_local_store(b, builder.create_abs(rhs));
199 auto result = builder.create_local_var(lhs->element_type());
200 builder.create_local_store(result, one_lhs);
201 auto loop = builder.create_while_true();
202 {
203 auto loop_guard = builder.get_loop_guard(loop);
204 auto current_a = builder.create_local_load(a);
205 auto current_b = builder.create_local_load(b);
206 auto if_stmt =
207 builder.create_if(builder.create_cmp_le(current_b, zero_rhs));
208 {
209 auto _ = builder.get_if_guard(if_stmt, true);
210 builder.create_break();
211 }
212 auto bit_and = builder.create_and(current_b, one_rhs);
213 if_stmt = builder.create_if(builder.create_cmp_ne(bit_and, zero_rhs));
214 {
215 auto _ = builder.get_if_guard(if_stmt, true);
216 auto current_result = builder.create_local_load(result);
217 auto new_result = builder.create_mul(current_result, current_a);
218 builder.create_local_store(result, new_result);
219 }
220 auto new_a = builder.create_mul(current_a, current_a);
221 builder.create_local_store(a, new_a);
222 auto new_b = builder.create_sar(current_b, one_rhs);
223 builder.create_local_store(b, new_b);
224 }
225 if (is_real(lhs->element_type())) {
226 auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs));
227 {
228 auto _ = builder.get_if_guard(if_stmt, true);
229 auto current_result = builder.create_local_load(result);
230 auto new_result = builder.create_div(one_lhs, current_result);
231 builder.create_local_store(result, new_result);
232 }
233 }
234 auto final_result = builder.create_local_load(result);
235 stmt->replace_usages_with(final_result);
236 modifier.insert_before(
237 stmt, VecStatement(std::move(builder.extract_ir()->statements)));
238 modifier.erase(stmt);
239 }
240 }
241
242 static bool run(IRNode *node, const CompileConfig &config) {
243 DemoteOperations demoter;
244 bool modified = false;
245 while (true) {
246 node->accept(&demoter);
247 if (demoter.modifier.modify_ir())
248 modified = true;
249 else
250 break;
251 irpass::type_check(node, config);
252 }
253 if (modified) {
254 irpass::type_check(node, config);
255 }
256 return modified;
257 }
258};
259
260namespace irpass {
261
262bool demote_operations(IRNode *root, const CompileConfig &config) {
263 TI_AUTO_PROF;
264 bool modified = DemoteOperations::run(root, config);
265 return modified;
266}
267
268} // namespace irpass
269
270} // namespace taichi::lang
271