1#pragma once
2
3#include "taichi/ir/expr.h"
4#include "taichi/ir/expression.h"
5#include "taichi/ir/frontend_ir.h"
6#include "taichi/program/program.h"
7#include "taichi/analysis/offline_cache_util.h"
8
9namespace taichi::lang {
10
11class ExpressionPrinter : public ExpressionVisitor {
12 public:
13 explicit ExpressionPrinter(std::ostream *os = nullptr) : os_(os) {
14 }
15
16 void set_ostream(std::ostream *os) {
17 os_ = os;
18 }
19
20 std::ostream *get_ostream() {
21 return os_;
22 }
23
24 private:
25 std::ostream *os_{nullptr};
26};
27
28class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
29 public:
30 explicit ExpressionHumanFriendlyPrinter(std::ostream *os = nullptr)
31 : ExpressionPrinter(os) {
32 }
33
34 void visit(ExprGroup &expr_group) override {
35 emit_vector(expr_group.exprs);
36 }
37
38 void visit(ArgLoadExpression *expr) override {
39 emit(
40 fmt::format("arg[{}] (dt={})", expr->arg_id, data_type_name(expr->dt)));
41 }
42
43 void visit(TexturePtrExpression *expr) override {
44 emit(fmt::format("(Texture *)(arg[{}])", expr->arg_id));
45 }
46
47 void visit(TextureOpExpression *expr) override {
48 emit(fmt::format("texture_{}(", texture_op_type_name(expr->op)));
49 visit(expr->args);
50 emit(")");
51 }
52
53 void visit(RandExpression *expr) override {
54 emit(fmt::format("rand<{}>()", data_type_name(expr->dt)));
55 }
56
57 void visit(UnaryOpExpression *expr) override {
58 emit('(');
59 if (expr->is_cast()) {
60 emit(expr->type == UnaryOpType::cast_value ? "" : "reinterpret_");
61 emit(unary_op_type_name(expr->type));
62 emit('<', data_type_name(expr->cast_type), "> ");
63 } else {
64 emit(unary_op_type_name(expr->type), ' ');
65 }
66 expr->operand->accept(this);
67 emit(')');
68 }
69
70 void visit(BinaryOpExpression *expr) override {
71 emit('(');
72 expr->lhs->accept(this);
73 emit(' ', binary_op_type_symbol(expr->type), ' ');
74 expr->rhs->accept(this);
75 emit(')');
76 }
77
78 void visit(TernaryOpExpression *expr) override {
79 emit(ternary_type_name(expr->type), '(');
80 expr->op1->accept(this);
81 emit(' ');
82 expr->op2->accept(this);
83 emit(' ');
84 expr->op3->accept(this);
85 emit(')');
86 }
87
88 void visit(InternalFuncCallExpression *expr) override {
89 emit("internal call ", expr->func_name, '(');
90 if (expr->with_runtime_context) {
91 emit("runtime, ");
92 }
93 emit_vector(expr->args);
94 emit(')');
95 }
96
97 void visit(ExternalTensorExpression *expr) override {
98 emit(fmt::format("{}d_ext_arr (element_dim={}, dt={}, grad={})", expr->dim,
99 expr->element_dim, expr->dt->to_string(), expr->is_grad));
100 }
101
102 void visit(FieldExpression *expr) override {
103 emit("#", expr->ident.name());
104 if (expr->snode) {
105 emit(
106 fmt::format(" (snode={})", expr->snode->get_node_type_name_hinted()));
107 } else {
108 emit(fmt::format(" (dt={})", expr->dt->to_string()));
109 }
110 }
111
112 void visit(MatrixFieldExpression *expr) override {
113 emit('[');
114 emit_vector(expr->fields);
115 emit("] (");
116 emit_vector(expr->element_shape);
117 if (expr->dynamic_index_stride) {
118 emit(", dynamic_index_stride = ", expr->dynamic_index_stride);
119 }
120 emit(')');
121 }
122
123 void visit(MatrixExpression *expr) override {
124 emit('[');
125 emit_vector(expr->elements);
126 emit(']');
127 emit(fmt::format(" (dt={})", expr->dt->to_string()));
128 }
129
130 void visit(IndexExpression *expr) override {
131 expr->var->accept(this);
132 emit('[');
133 if (expr->ret_shape.empty()) {
134 emit_vector(expr->indices_group[0].exprs);
135 } else {
136 for (auto &indices : expr->indices_group) {
137 emit('(');
138 emit_vector(indices.exprs);
139 emit("), ");
140 }
141 emit("shape=(");
142 emit_vector(expr->ret_shape);
143 emit(')');
144 }
145 emit(']');
146 }
147
148 void visit(RangeAssumptionExpression *expr) override {
149 emit("assume_in_range({");
150 expr->base->accept(this);
151 emit(fmt::format("{:+d}", expr->low), " <= (");
152 expr->input->accept(this);
153 emit(") < ");
154 expr->base->accept(this);
155 emit(fmt::format("{:+d})", expr->high));
156 }
157
158 void visit(LoopUniqueExpression *expr) override {
159 emit("loop_unique(");
160 expr->input->accept(this);
161 if (!expr->covers.empty()) {
162 emit(", covers=[");
163 emit_vector(expr->covers);
164 emit(']');
165 }
166 emit(')');
167 }
168
169 void visit(IdExpression *expr) override {
170 emit(expr->id.name());
171 }
172
173 void visit(AtomicOpExpression *expr) override {
174 const auto op_type = (std::size_t)expr->op_type;
175 constexpr const char *names_table[] = {
176 "atomic_add", "atomic_sub", "atomic_min", "atomic_max",
177 "atomic_bit_and", "atomic_bit_or", "atomic_bit_xor",
178 };
179 if (op_type > std::size(names_table)) {
180 // min/max not supported in the LLVM backend yet.
181 TI_NOT_IMPLEMENTED;
182 }
183 emit(names_table[op_type], '(');
184 expr->dest->accept(this);
185 emit(", ");
186 expr->val->accept(this);
187 emit(")");
188 }
189
190 void visit(SNodeOpExpression *expr) override {
191 emit(snode_op_type_name(expr->op_type));
192 emit('(', expr->snode->get_node_type_name_hinted(), ", [");
193 emit_vector(expr->indices.exprs);
194 emit("]");
195 if (!expr->values.empty()) {
196 emit(' ');
197 emit_vector(expr->values);
198 }
199 emit(')');
200 }
201
202 void visit(ConstExpression *expr) override {
203 emit(expr->val.stringify());
204 }
205
206 void visit(ExternalTensorShapeAlongAxisExpression *expr) override {
207 emit("external_tensor_shape_along_axis(");
208 expr->ptr->accept(this);
209 emit(", ", expr->axis, ')');
210 }
211
212 void visit(MeshPatchIndexExpression *expr) override {
213 emit("mesh_patch_idx()");
214 }
215
216 void visit(MeshRelationAccessExpression *expr) override {
217 if (expr->neighbor_idx) {
218 emit("mesh_relation_access(");
219 expr->mesh_idx->accept(this);
220 emit(", ", mesh::element_type_name(expr->to_type), '[');
221 expr->neighbor_idx->accept(this);
222 emit("])");
223 } else {
224 emit("mesh_relation_size(");
225 expr->mesh_idx->accept(this);
226 emit(", ", mesh::element_type_name(expr->to_type), ')');
227 }
228 }
229
230 void visit(MeshIndexConversionExpression *expr) override {
231 emit("mesh_index_conversion(", mesh::conv_type_name(expr->conv_type), ", ",
232 mesh::element_type_name(expr->idx_type), ", ");
233 expr->idx->accept(this);
234 emit(")");
235 }
236
237 void visit(ReferenceExpression *expr) override {
238 emit("ref(");
239 expr->var->accept(this);
240 emit(")");
241 }
242
243 void visit(GetElementExpression *expr) override {
244 emit("get_element(");
245 expr->src->accept(this);
246 emit(", ");
247 emit_vector(expr->index);
248 emit(")");
249 }
250
251 static std::string expr_to_string(Expr &expr) {
252 std::ostringstream oss;
253 ExpressionHumanFriendlyPrinter printer(&oss);
254 expr->accept(&printer);
255 return oss.str();
256 }
257
258 protected:
259 template <typename... Args>
260 void emit(Args &&...args) {
261 TI_ASSERT(this->get_ostream());
262 (*this->get_ostream() << ... << std::forward<Args>(args));
263 }
264
265 template <typename T>
266 void emit_vector(std::vector<T> &v) {
267 if (!v.empty()) {
268 emit_element(v[0]);
269 const auto size = v.size();
270 for (std::size_t i = 1; i < size; ++i) {
271 emit(", ");
272 emit_element(v[i]);
273 }
274 }
275 }
276
277 template <typename D>
278 void emit_element(D &&e) {
279 using T =
280 typename std::remove_cv<typename std::remove_reference<D>::type>::type;
281 if constexpr (std::is_same_v<T, Expr>) {
282 e->accept(this);
283 } else if constexpr (std::is_same_v<T, SNode *>) {
284 emit(e->get_node_type_name_hinted());
285 } else {
286 emit(std::forward<D>(e));
287 }
288 }
289};
290
291} // namespace taichi::lang
292