1 | #pragma once |
2 | |
3 | #include "taichi/ir/expr.h" |
4 | #include "taichi/ir/expression.h" |
5 | #include "taichi/ir/frontend_ir.h" |
6 | #include "taichi/program/program.h" |
7 | #include "taichi/analysis/offline_cache_util.h" |
8 | |
9 | namespace taichi::lang { |
10 | |
11 | class ExpressionPrinter : public ExpressionVisitor { |
12 | public: |
13 | explicit ExpressionPrinter(std::ostream *os = nullptr) : os_(os) { |
14 | } |
15 | |
16 | void set_ostream(std::ostream *os) { |
17 | os_ = os; |
18 | } |
19 | |
20 | std::ostream *get_ostream() { |
21 | return os_; |
22 | } |
23 | |
24 | private: |
25 | std::ostream *os_{nullptr}; |
26 | }; |
27 | |
28 | class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { |
29 | public: |
30 | explicit ExpressionHumanFriendlyPrinter(std::ostream *os = nullptr) |
31 | : ExpressionPrinter(os) { |
32 | } |
33 | |
34 | void visit(ExprGroup &expr_group) override { |
35 | emit_vector(expr_group.exprs); |
36 | } |
37 | |
38 | void visit(ArgLoadExpression *expr) override { |
39 | emit( |
40 | fmt::format("arg[{}] (dt={})" , expr->arg_id, data_type_name(expr->dt))); |
41 | } |
42 | |
43 | void visit(TexturePtrExpression *expr) override { |
44 | emit(fmt::format("(Texture *)(arg[{}])" , expr->arg_id)); |
45 | } |
46 | |
47 | void visit(TextureOpExpression *expr) override { |
48 | emit(fmt::format("texture_{}(" , texture_op_type_name(expr->op))); |
49 | visit(expr->args); |
50 | emit(")" ); |
51 | } |
52 | |
53 | void visit(RandExpression *expr) override { |
54 | emit(fmt::format("rand<{}>()" , data_type_name(expr->dt))); |
55 | } |
56 | |
57 | void visit(UnaryOpExpression *expr) override { |
58 | emit('('); |
59 | if (expr->is_cast()) { |
60 | emit(expr->type == UnaryOpType::cast_value ? "" : "reinterpret_" ); |
61 | emit(unary_op_type_name(expr->type)); |
62 | emit('<', data_type_name(expr->cast_type), "> " ); |
63 | } else { |
64 | emit(unary_op_type_name(expr->type), ' '); |
65 | } |
66 | expr->operand->accept(this); |
67 | emit(')'); |
68 | } |
69 | |
70 | void visit(BinaryOpExpression *expr) override { |
71 | emit('('); |
72 | expr->lhs->accept(this); |
73 | emit(' ', binary_op_type_symbol(expr->type), ' '); |
74 | expr->rhs->accept(this); |
75 | emit(')'); |
76 | } |
77 | |
78 | void visit(TernaryOpExpression *expr) override { |
79 | emit(ternary_type_name(expr->type), '('); |
80 | expr->op1->accept(this); |
81 | emit(' '); |
82 | expr->op2->accept(this); |
83 | emit(' '); |
84 | expr->op3->accept(this); |
85 | emit(')'); |
86 | } |
87 | |
88 | void visit(InternalFuncCallExpression *expr) override { |
89 | emit("internal call " , expr->func_name, '('); |
90 | if (expr->with_runtime_context) { |
91 | emit("runtime, " ); |
92 | } |
93 | emit_vector(expr->args); |
94 | emit(')'); |
95 | } |
96 | |
97 | void visit(ExternalTensorExpression *expr) override { |
98 | emit(fmt::format("{}d_ext_arr (element_dim={}, dt={}, grad={})" , expr->dim, |
99 | expr->element_dim, expr->dt->to_string(), expr->is_grad)); |
100 | } |
101 | |
102 | void visit(FieldExpression *expr) override { |
103 | emit("#" , expr->ident.name()); |
104 | if (expr->snode) { |
105 | emit( |
106 | fmt::format(" (snode={})" , expr->snode->get_node_type_name_hinted())); |
107 | } else { |
108 | emit(fmt::format(" (dt={})" , expr->dt->to_string())); |
109 | } |
110 | } |
111 | |
112 | void visit(MatrixFieldExpression *expr) override { |
113 | emit('['); |
114 | emit_vector(expr->fields); |
115 | emit("] (" ); |
116 | emit_vector(expr->element_shape); |
117 | if (expr->dynamic_index_stride) { |
118 | emit(", dynamic_index_stride = " , expr->dynamic_index_stride); |
119 | } |
120 | emit(')'); |
121 | } |
122 | |
123 | void visit(MatrixExpression *expr) override { |
124 | emit('['); |
125 | emit_vector(expr->elements); |
126 | emit(']'); |
127 | emit(fmt::format(" (dt={})" , expr->dt->to_string())); |
128 | } |
129 | |
130 | void visit(IndexExpression *expr) override { |
131 | expr->var->accept(this); |
132 | emit('['); |
133 | if (expr->ret_shape.empty()) { |
134 | emit_vector(expr->indices_group[0].exprs); |
135 | } else { |
136 | for (auto &indices : expr->indices_group) { |
137 | emit('('); |
138 | emit_vector(indices.exprs); |
139 | emit("), " ); |
140 | } |
141 | emit("shape=(" ); |
142 | emit_vector(expr->ret_shape); |
143 | emit(')'); |
144 | } |
145 | emit(']'); |
146 | } |
147 | |
148 | void visit(RangeAssumptionExpression *expr) override { |
149 | emit("assume_in_range({" ); |
150 | expr->base->accept(this); |
151 | emit(fmt::format("{:+d}" , expr->low), " <= (" ); |
152 | expr->input->accept(this); |
153 | emit(") < " ); |
154 | expr->base->accept(this); |
155 | emit(fmt::format("{:+d})" , expr->high)); |
156 | } |
157 | |
158 | void visit(LoopUniqueExpression *expr) override { |
159 | emit("loop_unique(" ); |
160 | expr->input->accept(this); |
161 | if (!expr->covers.empty()) { |
162 | emit(", covers=[" ); |
163 | emit_vector(expr->covers); |
164 | emit(']'); |
165 | } |
166 | emit(')'); |
167 | } |
168 | |
169 | void visit(IdExpression *expr) override { |
170 | emit(expr->id.name()); |
171 | } |
172 | |
173 | void visit(AtomicOpExpression *expr) override { |
174 | const auto op_type = (std::size_t)expr->op_type; |
175 | constexpr const char *names_table[] = { |
176 | "atomic_add" , "atomic_sub" , "atomic_min" , "atomic_max" , |
177 | "atomic_bit_and" , "atomic_bit_or" , "atomic_bit_xor" , |
178 | }; |
179 | if (op_type > std::size(names_table)) { |
180 | // min/max not supported in the LLVM backend yet. |
181 | TI_NOT_IMPLEMENTED; |
182 | } |
183 | emit(names_table[op_type], '('); |
184 | expr->dest->accept(this); |
185 | emit(", " ); |
186 | expr->val->accept(this); |
187 | emit(")" ); |
188 | } |
189 | |
190 | void visit(SNodeOpExpression *expr) override { |
191 | emit(snode_op_type_name(expr->op_type)); |
192 | emit('(', expr->snode->get_node_type_name_hinted(), ", [" ); |
193 | emit_vector(expr->indices.exprs); |
194 | emit("]" ); |
195 | if (!expr->values.empty()) { |
196 | emit(' '); |
197 | emit_vector(expr->values); |
198 | } |
199 | emit(')'); |
200 | } |
201 | |
202 | void visit(ConstExpression *expr) override { |
203 | emit(expr->val.stringify()); |
204 | } |
205 | |
206 | void visit(ExternalTensorShapeAlongAxisExpression *expr) override { |
207 | emit("external_tensor_shape_along_axis(" ); |
208 | expr->ptr->accept(this); |
209 | emit(", " , expr->axis, ')'); |
210 | } |
211 | |
212 | void visit(MeshPatchIndexExpression *expr) override { |
213 | emit("mesh_patch_idx()" ); |
214 | } |
215 | |
216 | void visit(MeshRelationAccessExpression *expr) override { |
217 | if (expr->neighbor_idx) { |
218 | emit("mesh_relation_access(" ); |
219 | expr->mesh_idx->accept(this); |
220 | emit(", " , mesh::element_type_name(expr->to_type), '['); |
221 | expr->neighbor_idx->accept(this); |
222 | emit("])" ); |
223 | } else { |
224 | emit("mesh_relation_size(" ); |
225 | expr->mesh_idx->accept(this); |
226 | emit(", " , mesh::element_type_name(expr->to_type), ')'); |
227 | } |
228 | } |
229 | |
230 | void visit(MeshIndexConversionExpression *expr) override { |
231 | emit("mesh_index_conversion(" , mesh::conv_type_name(expr->conv_type), ", " , |
232 | mesh::element_type_name(expr->idx_type), ", " ); |
233 | expr->idx->accept(this); |
234 | emit(")" ); |
235 | } |
236 | |
237 | void visit(ReferenceExpression *expr) override { |
238 | emit("ref(" ); |
239 | expr->var->accept(this); |
240 | emit(")" ); |
241 | } |
242 | |
243 | void visit(GetElementExpression *expr) override { |
244 | emit("get_element(" ); |
245 | expr->src->accept(this); |
246 | emit(", " ); |
247 | emit_vector(expr->index); |
248 | emit(")" ); |
249 | } |
250 | |
251 | static std::string expr_to_string(Expr &expr) { |
252 | std::ostringstream oss; |
253 | ExpressionHumanFriendlyPrinter printer(&oss); |
254 | expr->accept(&printer); |
255 | return oss.str(); |
256 | } |
257 | |
258 | protected: |
259 | template <typename... Args> |
260 | void emit(Args &&...args) { |
261 | TI_ASSERT(this->get_ostream()); |
262 | (*this->get_ostream() << ... << std::forward<Args>(args)); |
263 | } |
264 | |
265 | template <typename T> |
266 | void emit_vector(std::vector<T> &v) { |
267 | if (!v.empty()) { |
268 | emit_element(v[0]); |
269 | const auto size = v.size(); |
270 | for (std::size_t i = 1; i < size; ++i) { |
271 | emit(", " ); |
272 | emit_element(v[i]); |
273 | } |
274 | } |
275 | } |
276 | |
277 | template <typename D> |
278 | void emit_element(D &&e) { |
279 | using T = |
280 | typename std::remove_cv<typename std::remove_reference<D>::type>::type; |
281 | if constexpr (std::is_same_v<T, Expr>) { |
282 | e->accept(this); |
283 | } else if constexpr (std::is_same_v<T, SNode *>) { |
284 | emit(e->get_node_type_name_hinted()); |
285 | } else { |
286 | emit(std::forward<D>(e)); |
287 | } |
288 | } |
289 | }; |
290 | |
291 | } // namespace taichi::lang |
292 | |