1#include <cmath>
2#include <deque>
3#include <set>
4#include <thread>
5
6#include "taichi/ir/ir.h"
7#include "taichi/ir/snode.h"
8#include "taichi/ir/statements.h"
9#include "taichi/ir/transforms.h"
10#include "taichi/ir/visitors.h"
11#include "taichi/transforms/constant_fold.h"
12#include "taichi/program/program.h"
13
14namespace taichi::lang {
15
16class ConstantFold : public BasicStmtVisitor {
17 public:
18 using BasicStmtVisitor::visit;
19 DelayedIRModifier modifier;
20 Program *program;
21 CompileConfig compile_config;
22
23 explicit ConstantFold(Program *program, const CompileConfig &compile_config)
24 : program(program), compile_config(compile_config) {
25 this->compile_config.advanced_optimization = false;
26 this->compile_config.constant_folding = false;
27 this->compile_config.external_optimization_level = 0;
28 }
29
30 Kernel *get_jit_evaluator_kernel(JITEvaluatorId const &id) {
31 auto &cache = program->jit_evaluator_cache;
32 // Discussion:
33 // https://github.com/taichi-dev/taichi/pull/954#discussion_r423442606
34 std::lock_guard<std::mutex> _(program->jit_evaluator_cache_mut);
35 auto it = cache.find(id);
36 if (it != cache.end()) // cached?
37 return it->second.get();
38
39 auto kernel_name = fmt::format("jit_evaluator_{}", cache.size());
40 auto func = [&id](Kernel *kernel) {
41 auto lhstmt =
42 Stmt::make<ArgLoadStmt>(/*arg_id=*/0, id.lhs, /*is_ptr=*/false);
43 auto rhstmt =
44 Stmt::make<ArgLoadStmt>(/*arg_id=*/1, id.rhs, /*is_ptr=*/false);
45 pStmt oper;
46 if (id.is_binary) {
47 oper = Stmt::make<BinaryOpStmt>(id.binary_op(), lhstmt.get(),
48 rhstmt.get());
49 oper->set_tb(id.tb);
50 } else {
51 oper = Stmt::make<UnaryOpStmt>(id.unary_op(), lhstmt.get());
52 if (unary_op_is_cast(id.unary_op())) {
53 oper->cast<UnaryOpStmt>()->cast_type = id.rhs;
54 }
55 }
56 auto &ast_builder = kernel->context->builder();
57 auto ret = Stmt::make<ReturnStmt>(oper.get());
58 ast_builder.insert(std::move(lhstmt));
59 if (id.is_binary) {
60 ast_builder.insert(std::move(rhstmt));
61 }
62 ast_builder.insert(std::move(oper));
63 ast_builder.insert(std::move(ret));
64 };
65
66 auto ker = std::make_unique<Kernel>(*program, func, kernel_name);
67 ker->insert_ret(id.ret);
68 ker->insert_scalar_param(id.lhs);
69 if (id.is_binary)
70 ker->insert_scalar_param(id.rhs);
71 ker->is_evaluator = true;
72 ker->finalize_rets();
73
74 auto *ker_ptr = ker.get();
75 TI_TRACE("Saving JIT evaluator cache entry id={}",
76 std::hash<JITEvaluatorId>{}(id));
77 cache[id] = std::move(ker);
78
79 return ker_ptr;
80 }
81
82 static bool is_good_type(DataType dt) {
83 // ConstStmt of `bad` types like `i8` is not supported by LLVM.
84 // Discussion:
85 // https://github.com/taichi-dev/taichi/pull/839#issuecomment-625902727
86 if (dt->is_primitive(PrimitiveTypeID::i32) ||
87 dt->is_primitive(PrimitiveTypeID::i64) ||
88 dt->is_primitive(PrimitiveTypeID::u32) ||
89 dt->is_primitive(PrimitiveTypeID::u64) ||
90 dt->is_primitive(PrimitiveTypeID::f32) ||
91 dt->is_primitive(PrimitiveTypeID::f64))
92 return true;
93 else
94 return false;
95 }
96
97 bool jit_evaluate_binary_op(TypedConstant &ret,
98 BinaryOpStmt *stmt,
99 const TypedConstant &lhs,
100 const TypedConstant &rhs) {
101 if (!is_good_type(ret.dt))
102 return false;
103 JITEvaluatorId id{std::this_thread::get_id(),
104 (int)stmt->op_type,
105 ret.dt,
106 lhs.dt,
107 rhs.dt,
108 compile_config.debug ? stmt->tb : "",
109 true};
110 auto *ker = get_jit_evaluator_kernel(id);
111 auto launch_ctx = ker->make_launch_context();
112 launch_ctx.set_arg_raw(0, lhs.val_u64);
113 launch_ctx.set_arg_raw(1, rhs.val_u64);
114 {
115 std::lock_guard<std::mutex> _(program->jit_evaluator_cache_mut);
116 (*ker)(compile_config, launch_ctx);
117 ret.val_i64 = program->fetch_result<int64_t>(0);
118 }
119 return true;
120 }
121
122 bool jit_evaluate_unary_op(TypedConstant &ret,
123 UnaryOpStmt *stmt,
124 const TypedConstant &operand) {
125 if (!is_good_type(ret.dt))
126 return false;
127 JITEvaluatorId id{std::this_thread::get_id(),
128 (int)stmt->op_type,
129 ret.dt,
130 operand.dt,
131 stmt->cast_type,
132 "",
133 false};
134 auto *ker = get_jit_evaluator_kernel(id);
135 auto launch_ctx = ker->make_launch_context();
136 launch_ctx.set_arg_raw(0, operand.val_u64);
137 {
138 std::lock_guard<std::mutex> _(program->jit_evaluator_cache_mut);
139 (*ker)(compile_config, launch_ctx);
140 ret.val_i64 = program->fetch_result<int64_t>(0);
141 }
142 return true;
143 }
144
145 void visit(BinaryOpStmt *stmt) override {
146 auto lhs = stmt->lhs->cast<ConstStmt>();
147 auto rhs = stmt->rhs->cast<ConstStmt>();
148 if (!lhs || !rhs)
149 return;
150 auto dst_type = stmt->ret_type;
151 TypedConstant new_constant(dst_type);
152
153 if (stmt->op_type == BinaryOpType::pow) {
154 if (is_integral(rhs->ret_type)) {
155 auto rhs_val = rhs->val.val_int();
156 if (rhs_val < 0 && is_integral(stmt->ret_type)) {
157 TI_ERROR("Negative exponent in pow(int, int) is not allowed.");
158 }
159 }
160 }
161
162 if (jit_evaluate_binary_op(new_constant, stmt, lhs->val, rhs->val)) {
163 auto evaluated = Stmt::make<ConstStmt>(TypedConstant(new_constant));
164 stmt->replace_usages_with(evaluated.get());
165 modifier.insert_before(stmt, std::move(evaluated));
166 modifier.erase(stmt);
167 }
168 }
169
170 void visit(UnaryOpStmt *stmt) override {
171 if (stmt->is_cast() && stmt->cast_type == stmt->operand->ret_type) {
172 stmt->replace_usages_with(stmt->operand);
173 modifier.erase(stmt);
174 return;
175 }
176 auto operand = stmt->operand->cast<ConstStmt>();
177 if (!operand)
178 return;
179 if (stmt->is_cast()) {
180 bool cast_available = true;
181 TypedConstant new_constant(stmt->ret_type);
182 auto operand = stmt->operand->cast<ConstStmt>();
183 if (stmt->op_type == UnaryOpType::cast_bits) {
184 new_constant.value_bits = operand->val.value_bits;
185 } else {
186 if (stmt->cast_type == PrimitiveType::f32) {
187 new_constant.val_f32 = float32(operand->val.val_cast_to_float64());
188 } else if (stmt->cast_type == PrimitiveType::f64) {
189 new_constant.val_f64 = operand->val.val_cast_to_float64();
190 } else {
191 cast_available = false;
192 }
193 }
194 if (cast_available) {
195 auto evaluated = Stmt::make<ConstStmt>(TypedConstant(new_constant));
196 stmt->replace_usages_with(evaluated.get());
197 modifier.insert_before(stmt, std::move(evaluated));
198 modifier.erase(stmt);
199 return;
200 }
201 }
202 auto dst_type = stmt->ret_type;
203 TypedConstant new_constant(dst_type);
204 if (jit_evaluate_unary_op(new_constant, stmt, operand->val)) {
205 auto evaluated = Stmt::make<ConstStmt>(TypedConstant(new_constant));
206 stmt->replace_usages_with(evaluated.get());
207 modifier.insert_before(stmt, std::move(evaluated));
208 modifier.erase(stmt);
209 }
210 }
211
212 static bool run(IRNode *node,
213 Program *program,
214 const CompileConfig &compile_config) {
215 ConstantFold folder(program, compile_config);
216 bool modified = false;
217
218 while (true) {
219 node->accept(&folder);
220 if (folder.modifier.modify_ir()) {
221 modified = true;
222 } else {
223 break;
224 }
225 }
226
227 return modified;
228 }
229};
230
// Out-of-class definition of the static member ConstantFoldPass::id; its
// value is the pass's own name as a string.
const PassID ConstantFoldPass::id = "ConstantFoldPass";
232
233namespace irpass {
234
235bool constant_fold(IRNode *root,
236 const CompileConfig &compile_config,
237 const ConstantFoldPass::Args &args) {
238 TI_AUTO_PROF;
239 if (!compile_config.advanced_optimization)
240 return false;
241 return ConstantFold::run(root, args.program, compile_config);
242}
243
244} // namespace irpass
245
246} // namespace taichi::lang
247