1#include "taichi/ir/statements.h"
2
3namespace taichi::lang {
4
5Stmt *generate_mod(VecStatement *stmts, Stmt *x, int y) {
6 if (bit::is_power_of_two(y)) {
7 auto const_stmt = stmts->push_back<ConstStmt>(TypedConstant(y - 1));
8 return stmts->push_back<BinaryOpStmt>(BinaryOpType::bit_and, x, const_stmt);
9 }
10 auto const_stmt = stmts->push_back<ConstStmt>(TypedConstant(y));
11 return stmts->push_back<BinaryOpStmt>(BinaryOpType::mod, x, const_stmt);
12}
13
14Stmt *generate_div(VecStatement *stmts, Stmt *x, int y) {
15 if (bit::is_power_of_two(y)) {
16 auto const_stmt = stmts->push_back<ConstStmt>(
17 TypedConstant(PrimitiveType::i32, bit::log2int(y)));
18 return stmts->push_back<BinaryOpStmt>(BinaryOpType::bit_shr, x, const_stmt);
19 }
20 auto const_stmt = stmts->push_back<ConstStmt>(TypedConstant(y));
21 return stmts->push_back<BinaryOpStmt>(BinaryOpType::div, x, const_stmt);
22}
23
24} // namespace taichi::lang
25