#include "offline_cache_util.h"

#include <algorithm>

#include "taichi/ir/expr.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/mesh.h"
#include "taichi/ir/type.h"
#include "taichi/program/function.h"
#include "taichi/program/program.h"
10
11namespace taichi::lang {
12
13namespace {
14
// One-byte tag written in front of every serialized expression.
// NIL is what ASTSerializer::emit(const Expr &) writes for a null Expr; the
// remaining enumerators are generated, one per PER_EXPRESSION(x) entry of
// taichi/inc/expressions.inc.h, in the order that header lists them.
enum class ExprOpCode : std::uint8_t {
  NIL,
#define PER_EXPRESSION(x) x,
#include "taichi/inc/expressions.inc.h"
#undef PER_EXPRESSION
};
21
// One-byte tag written in front of every serialized statement.
//  - NIL:        written by ASTSerializer::emit(Stmt *) for a null statement.
//  - EnterBlock: first tag written by visit(Block *).
//  - ExitBlock:  last tag written by visit(Block *).
//  - StopGrad:   written by visit(Block *) right before the block's
//                stop_gradients list.
// The remaining enumerators are generated, one per PER_STATEMENT(x) entry of
// taichi/inc/frontend_statements.inc.h, in the order that header lists them.
enum class StmtOpCode : std::uint8_t {
  NIL,
  EnterBlock,
  ExitBlock,
  StopGrad,
#define PER_STATEMENT(x) x,
#include "taichi/inc/frontend_statements.inc.h"
#undef PER_STATEMENT
};
31
// Discriminator written by visit(FrontendForStmt *) to tell which kind of loop
// header follows. The numeric values are spelled out because they end up in
// the serialized bytes.
enum class ForLoopType : std::uint8_t {
  StructForOnSNode = 0,           // loop over stmt->snode
  StructForOnExternalTensor = 1,  // loop over stmt->external_tensor
  MeshFor = 2,                    // loop over stmt->mesh elements
  RangeFor = 3,                   // loop over [stmt->begin, stmt->end)
};
38
// Discriminator written by visit(FrontendExternalFuncStmt *) to tell which
// flavour of external function is being called. The numeric values are spelled
// out because they end up in the serialized bytes.
enum class ExternalFuncType : std::uint8_t {
  SO = 0,   // stmt->so_func is set
  ASM = 1,  // stmt->asm_source is non-empty
  BC = 2,   // otherwise: identified by bc_filename / bc_funcname
};
44
// Discriminator written by visit(MeshRelationAccessExpression *): Access when
// the expression carries a neighbor index, Size otherwise.
// Note: unlike the other tag enums in this file this one keeps the default
// underlying type (int), so it is serialized with sizeof(int) bytes.
enum class MeshRelationAccessType {
  Access = 0,  // mesh_relation_access
  Size = 1,    // mesh_relation_size
};
49
50class ASTSerializer : public IRVisitor, public ExpressionVisitor {
51 private:
52 using ExpressionVisitor::visit;
53 using IRVisitor::visit;
54
55 public:
56 explicit ASTSerializer(std::ostream *os) : ExpressionVisitor(true), os_(os) {
57 // TODO(PGZXB): Set allow_undefined_visitor as false. (blocked by
58 // constant-folding)
59 this->allow_undefined_visitor = true;
60 }
61
62 void set_ostream(std::ostream *os) {
63 this->os_ = os;
64 }
65
66 std::ostream *get_ostream() {
67 return this->os_;
68 }
69
70 void visit(Expression *expr) override {
71 this->ExpressionVisitor::visit(expr);
72 }
73
74 void visit(Stmt *stmt) override {
75 this->IRVisitor::visit(stmt);
76 }
77
78 void visit(ExprGroup &expr_group) override {
79 emit(expr_group.exprs);
80 }
81
82 void visit(ArgLoadExpression *expr) override {
83 emit(ExprOpCode::ArgLoadExpression);
84 emit(expr->dt);
85 emit(expr->arg_id);
86 emit(expr->is_ptr);
87 }
88
89 void visit(TexturePtrExpression *expr) override {
90 emit(ExprOpCode::TexturePtrExpression);
91 emit(expr->arg_id);
92 emit(expr->num_dims);
93 emit(expr->is_storage);
94 emit(expr->num_channels);
95 emit(expr->channel_format);
96 emit(expr->lod);
97 }
98
99 void visit(TextureOpExpression *expr) override {
100 emit(ExprOpCode::TextureOpExpression);
101 emit(expr->op);
102 emit(expr->texture_ptr);
103 emit(expr->args.exprs);
104 }
105
106 void visit(RandExpression *expr) override {
107 emit(ExprOpCode::RandExpression);
108 emit(expr->dt);
109 }
110
111 void visit(UnaryOpExpression *expr) override {
112 emit(ExprOpCode::UnaryOpExpression);
113 emit(expr->type);
114 if (expr->is_cast()) {
115 emit(expr->cast_type);
116 }
117 emit(expr->operand);
118 }
119
120 void visit(BinaryOpExpression *expr) override {
121 emit(ExprOpCode::BinaryOpExpression);
122 emit(expr->type);
123 emit(expr->lhs);
124 emit(expr->rhs);
125 }
126
127 void visit(TernaryOpExpression *expr) override {
128 emit(ExprOpCode::TernaryOpExpression);
129 emit(expr->type);
130 emit(expr->op1);
131 emit(expr->op2);
132 emit(expr->op3);
133 }
134
135 void visit(InternalFuncCallExpression *expr) override {
136 emit(ExprOpCode::InternalFuncCallExpression);
137 emit(expr->func_name);
138 emit(expr->args);
139 emit(expr->with_runtime_context);
140 }
141
142 void visit(ExternalTensorExpression *expr) override {
143 emit(ExprOpCode::ExternalTensorExpression);
144 emit(expr->dt);
145 emit(expr->dim);
146 emit(expr->arg_id);
147 emit(expr->element_dim);
148 }
149
150 void visit(FieldExpression *expr) override {
151 emit(ExprOpCode::FieldExpression);
152 emit(expr->ident);
153 emit(expr->dt);
154 emit(expr->snode);
155 emit(expr->has_ambient);
156 emit(expr->ambient_value);
157 emit(expr->snode_grad_type);
158 emit(expr->adjoint);
159 emit(expr->dual);
160 emit(expr->adjoint_checkbit);
161 }
162
163 void visit(MatrixFieldExpression *expr) override {
164 emit(ExprOpCode::MatrixFieldExpression);
165 emit(expr->fields);
166 emit(expr->element_shape);
167 emit(expr->dynamic_index_stride);
168 }
169
170 void visit(IndexExpression *expr) override {
171 emit(ExprOpCode::IndexExpression);
172 emit(expr->var);
173 for (auto &indices : expr->indices_group) {
174 emit(indices.exprs);
175 }
176 emit(expr->ret_shape);
177 }
178
179 void visit(MatrixExpression *expr) override {
180 emit(ExprOpCode::MatrixExpression);
181 emit(expr->dt);
182 for (auto elt : expr->elements) {
183 emit(elt);
184 }
185 }
186
187 void visit(RangeAssumptionExpression *expr) override {
188 emit(ExprOpCode::RangeAssumptionExpression);
189 emit(expr->input);
190 emit(expr->base);
191 emit(expr->low);
192 emit(expr->high);
193 }
194
195 void visit(LoopUniqueExpression *expr) override {
196 emit(ExprOpCode::LoopUniqueExpression);
197 emit(expr->input);
198 emit(expr->covers);
199 }
200
201 void visit(IdExpression *expr) override {
202 emit(ExprOpCode::IdExpression);
203 emit(expr->id);
204 }
205
206 void visit(AtomicOpExpression *expr) override {
207 emit(ExprOpCode::AtomicOpExpression);
208 emit(expr->op_type);
209 emit(expr->dest);
210 emit(expr->val);
211 }
212
213 void visit(SNodeOpExpression *expr) override {
214 emit(ExprOpCode::SNodeOpExpression);
215 emit(expr->op_type);
216 emit(expr->snode);
217 emit(expr->indices.exprs);
218 emit(expr->values);
219 }
220
221 void visit(ConstExpression *expr) override {
222 emit(ExprOpCode::ConstExpression);
223 emit(expr->val);
224 }
225
226 void visit(ExternalTensorShapeAlongAxisExpression *expr) override {
227 emit(ExprOpCode::ExternalTensorShapeAlongAxisExpression);
228 emit(expr->ptr);
229 emit(expr->axis);
230 }
231
232 void visit(FrontendFuncCallStmt *expr) override {
233 emit(StmtOpCode::FrontendFuncCallStmt);
234 emit(expr->func);
235 emit(expr->args.exprs);
236 }
237
238 void visit(MeshPatchIndexExpression *expr) override {
239 emit(ExprOpCode::MeshPatchIndexExpression);
240 }
241
242 void visit(MeshRelationAccessExpression *expr) override {
243 emit(ExprOpCode::MeshRelationAccessExpression);
244 if (expr->neighbor_idx) {
245 emit(MeshRelationAccessType::Access);
246 emit(expr->neighbor_idx);
247 } else {
248 emit(MeshRelationAccessType::Size);
249 }
250 emit(expr->mesh);
251 emit(expr->to_type);
252 emit(expr->mesh_idx);
253 }
254
255 void visit(MeshIndexConversionExpression *expr) override {
256 emit(ExprOpCode::MeshIndexConversionExpression);
257 emit(expr->mesh);
258 emit(expr->idx_type);
259 emit(expr->idx);
260 emit(expr->conv_type);
261 }
262
263 void visit(ReferenceExpression *expr) override {
264 emit(ExprOpCode::ReferenceExpression);
265 emit(expr->var);
266 }
267
268 void visit(GetElementExpression *expr) override {
269 emit(ExprOpCode::GetElementExpression);
270 emit(expr->src);
271 emit(expr->index);
272 }
273
274 void visit(Block *block) override {
275 emit(StmtOpCode::EnterBlock);
276 emit(static_cast<std::size_t>(block->statements.size()));
277 for (auto &stmt : block->statements) {
278 emit(stmt.get());
279 }
280 emit(StmtOpCode::StopGrad);
281 emit(block->stop_gradients);
282 emit(StmtOpCode::ExitBlock);
283 }
284
285 void visit(FrontendExprStmt *stmt) override {
286 emit(StmtOpCode::FrontendExprStmt);
287 emit(stmt->val);
288 }
289
290 void visit(FrontendBreakStmt *stmt) override {
291 emit(StmtOpCode::FrontendBreakStmt);
292 }
293
294 void visit(FrontendContinueStmt *stmt) override {
295 emit(StmtOpCode::FrontendContinueStmt);
296 }
297
298 void visit(FrontendAssignStmt *stmt) override {
299 emit(StmtOpCode::FrontendAssignStmt);
300 emit(stmt->lhs);
301 emit(stmt->rhs);
302 }
303
304 void visit(FrontendAllocaStmt *stmt) override {
305 emit(StmtOpCode::FrontendAllocaStmt);
306 emit(stmt->ident);
307 }
308
309 void visit(FrontendAssertStmt *stmt) override {
310 emit(StmtOpCode::FrontendAssertStmt);
311 emit(stmt->cond);
312 emit(stmt->text);
313 emit(stmt->args);
314 }
315
316 void visit(FrontendSNodeOpStmt *stmt) override {
317 emit(StmtOpCode::FrontendSNodeOpStmt);
318 emit(stmt->op_type);
319 emit(stmt->snode);
320 emit(stmt->indices.exprs);
321 emit(stmt->val);
322 }
323
324 void visit(FrontendIfStmt *stmt) override {
325 emit(StmtOpCode::FrontendIfStmt);
326 emit(stmt->condition);
327 std::uint8_t branch_count = 0;
328 if (stmt->true_statements) {
329 ++branch_count;
330 }
331 if (stmt->false_statements) {
332 ++branch_count;
333 }
334 emit(branch_count);
335 if (stmt->true_statements) {
336 emit(stmt->true_statements.get());
337 }
338 if (stmt->false_statements) {
339 emit(stmt->false_statements.get());
340 }
341 }
342
343 void visit(FrontendPrintStmt *stmt) override {
344 emit(StmtOpCode::FrontendPrintStmt);
345 emit(static_cast<std::size_t>(stmt->contents.size()));
346 for (const auto &c : stmt->contents) {
347 emit(static_cast<std::uint8_t>(c.index()));
348 if (std::holds_alternative<Expr>(c)) {
349 emit(std::get<Expr>(c));
350 } else {
351 emit(std::get<std::string>(c));
352 }
353 }
354 }
355
356 void visit(FrontendFuncDefStmt *stmt) override {
357 emit(StmtOpCode::FrontendFuncDefStmt);
358 emit(stmt->body.get());
359 }
360
361 void visit(FrontendWhileStmt *stmt) override {
362 emit(StmtOpCode::FrontendWhileStmt);
363 emit(stmt->cond);
364 emit(stmt->body.get());
365 }
366
367 void visit(FrontendForStmt *stmt) override {
368 emit(StmtOpCode::FrontendForStmt);
369 if (stmt->snode) {
370 emit(ForLoopType::StructForOnSNode);
371 emit(stmt->snode);
372 } else if (stmt->external_tensor) {
373 emit(ForLoopType::StructForOnExternalTensor);
374 emit(stmt->external_tensor);
375 } else if (stmt->mesh) {
376 emit(ForLoopType::MeshFor);
377 emit(stmt->element_type);
378 emit(stmt->mesh);
379 } else {
380 emit(ForLoopType::RangeFor);
381 emit(stmt->begin);
382 emit(stmt->end);
383 }
384 emit(stmt->loop_var_ids);
385 emit(stmt->is_bit_vectorized);
386 emit(stmt->num_cpu_threads);
387 emit(stmt->strictly_serialized);
388 emit(stmt->mem_access_opt);
389 emit(stmt->block_dim);
390 emit(stmt->body.get());
391 }
392
393 void visit(FrontendReturnStmt *stmt) override {
394 emit(StmtOpCode::FrontendReturnStmt);
395 emit(stmt->values.exprs);
396 }
397
398 void visit(FrontendExternalFuncStmt *stmt) override {
399 // Note: The result of serializing FrontendExternalFuncStmt is not parsable
400 // now
401 emit(StmtOpCode::FrontendExternalFuncStmt);
402 if (stmt->so_func != nullptr) {
403 emit(ExternalFuncType::SO);
404 } else if (!stmt->asm_source.empty()) {
405 emit(ExternalFuncType::ASM);
406 emit(stmt->asm_source);
407 } else {
408 emit(ExternalFuncType::BC);
409 emit(stmt->bc_filename);
410 emit(stmt->bc_funcname);
411 }
412 emit(stmt->args);
413 emit(stmt->outputs);
414 }
415
416 static void run(IRNode *ast, std::ostream *os) {
417 ASTSerializer serializer(os);
418 ast->accept(&serializer);
419 serializer.emit_dependencies();
420 }
421
422 private:
423 void emit_dependencies() {
424 // Serialize dependent real-functions
425 emit(real_funcs_.size());
426 for (auto &[func, id] : real_funcs_) {
427 if (auto &ast_str = func->try_get_ast_serialization_data();
428 ast_str.has_value()) {
429 emit_bytes(ast_str->c_str(), ast_str->size());
430 }
431 }
432
433 // Serialize snode_trees(Temporary: using offline-cache-key of SNode)
434 // Note: The result of serializing snode_tree_roots_ is not parsable now
435 emit(static_cast<std::size_t>(snode_tree_roots_.size()));
436 for (const auto *snode : snode_tree_roots_) {
437 auto key = get_hashed_offline_cache_key_of_snode(snode);
438 emit_bytes(key.c_str(), key.size());
439 }
440
441 // Dump string-pool
442 emit(static_cast<std::size_t>(string_pool_.size()));
443 emit_bytes(string_pool_.data(), string_pool_.size());
444 }
445
446 template <typename T>
447 void emit_pod(const T &val) {
448 static_assert(std::is_pod<T>::value);
449 TI_ASSERT(os_);
450 os_->write((const char *)&val, sizeof(T));
451 }
452
453 void emit_bytes(const char *bytes, std::size_t len) {
454 TI_ASSERT(os_);
455 if (!bytes)
456 return;
457 os_->write(bytes, len);
458 }
459
460 template <typename T>
461 void emit(const std::vector<T> &v) {
462 emit(static_cast<std::size_t>(v.size()));
463 for (const auto &e : v) {
464 emit(e);
465 }
466 }
467
468 template <typename K, typename V>
469 void emit(const std::unordered_map<K, V> &map) {
470 emit(static_cast<std::size_t>(map.size()));
471 for (const auto &[k, v] : map) {
472 emit(k);
473 emit(v);
474 }
475 }
476
477 template <typename T1, typename T2>
478 void emit(const std::pair<T1, T2> &pair) {
479 emit(pair.first);
480 emit(pair.second);
481 }
482
483 template <typename K, typename V>
484 void emit(const std::map<K, V> &map) {
485 emit(static_cast<std::size_t>(map.size()));
486 for (const auto &[k, v] : map) {
487 emit(k);
488 emit(v);
489 }
490 }
491
492 void emit(const std::string &str) {
493 std::size_t size = str.size();
494 std::size_t offset = string_pool_.size();
495 string_pool_.insert(string_pool_.end(), str.begin(), str.end());
496 emit(size);
497 emit(offset);
498 }
499
500 void emit(Function *func) {
501 TI_ASSERT(func);
502 auto iter = real_funcs_.find(func);
503 if (iter != real_funcs_.end()) {
504 emit(iter->second);
505 } else {
506 auto [iter, ok] = real_funcs_.insert({func, real_funcs_.size()});
507 TI_ASSERT(ok);
508 emit(iter->second);
509 }
510 }
511
512 void emit(const TypedConstant &val) {
513 emit(val.dt);
514 if (!val.dt->is_primitive(PrimitiveTypeID::unknown)) {
515 emit(val.stringify());
516 }
517 }
518
519 void emit(const SNode *snode) {
520 if (snode) {
521 emit(static_cast<std::size_t>(snode->get_snode_tree_id()));
522 emit(static_cast<std::size_t>(snode->id));
523 const auto *root = snode->get_root();
524 snode_tree_roots_.insert(root);
525 } else {
526 emit(std::numeric_limits<std::size_t>::max());
527 emit(std::numeric_limits<std::size_t>::max());
528 }
529 }
530
531 void emit(const mesh::MeshLocalRelation &r) {
532 emit(r.fixed);
533 emit(r.value);
534 emit(r.patch_offset);
535 emit(r.offset);
536 }
537
538 void emit(mesh::Mesh *mesh) {
539 TI_ASSERT(mesh);
540 emit(mesh->num_patches);
541 emit(mesh->num_elements);
542 emit(mesh->patch_max_element_num);
543 emit(mesh->owned_offset);
544 emit(mesh->total_offset);
545 emit(mesh->index_mapping);
546 emit(mesh->relations);
547 }
548
549 void emit(const Identifier &ident) {
550 emit(ident.id);
551 }
552
553 void emit(const DataType &type) {
554 if (auto *p = type->cast<PrimitiveType>()) {
555 emit(p->type);
556 } else {
557 auto type_str = type->to_string();
558 emit(type_str);
559 }
560 }
561
562 void emit(IRNode *ir) {
563 TI_ASSERT(ir);
564 ir->accept(this);
565 }
566
567 void emit(const Expr &expr) {
568 if (expr) {
569 emit(expr.const_value);
570 emit(expr.atomic);
571 auto *e = expr.expr.get();
572 emit(e->get_flattened_stmt());
573 emit(e->attributes);
574 emit(e->ret_type);
575 expr.expr->accept(this);
576 } else {
577 emit(ExprOpCode::NIL);
578 }
579 }
580
581 void emit(Stmt *stmt) {
582 if (stmt) {
583 emit(stmt->get_operands());
584 emit(stmt->erased);
585 emit(stmt->fields_registered);
586 emit(stmt->ret_type);
587 stmt->accept(this);
588 } else {
589 emit(StmtOpCode::NIL);
590 }
591 }
592
593 void emit(std::size_t size) {
594 emit_pod(size);
595 }
596
597 void emit(std::uint8_t u8) {
598 emit_pod(u8);
599 }
600
601 void emit(int i) {
602 emit_pod(i);
603 }
604
605 void emit(bool v) {
606 emit_pod(v);
607 }
608
609 void emit(const MemoryAccessOptions &mem_access_options) {
610 auto all_options = mem_access_options.get_all();
611 emit(static_cast<std::size_t>(all_options.size()));
612 for (const auto &[snode, options] : all_options) {
613 emit(snode);
614 emit(static_cast<std::size_t>(options.size()));
615 for (auto e : options) {
616 emit(e);
617 }
618 }
619 }
620
621#define DEFINE_EMIT_ENUM(EnumType) \
622 void emit(EnumType type) { emit_pod(type); }
623
624 DEFINE_EMIT_ENUM(ExprOpCode);
625 DEFINE_EMIT_ENUM(StmtOpCode);
626 DEFINE_EMIT_ENUM(PrimitiveTypeID);
627 DEFINE_EMIT_ENUM(UnaryOpType);
628 DEFINE_EMIT_ENUM(BinaryOpType);
629 DEFINE_EMIT_ENUM(TernaryOpType);
630 DEFINE_EMIT_ENUM(AtomicOpType);
631 DEFINE_EMIT_ENUM(SNodeOpType);
632 DEFINE_EMIT_ENUM(ForLoopType);
633 DEFINE_EMIT_ENUM(SNodeAccessFlag);
634 DEFINE_EMIT_ENUM(MeshRelationAccessType);
635 DEFINE_EMIT_ENUM(ExternalFuncType);
636 DEFINE_EMIT_ENUM(TextureOpType);
637 DEFINE_EMIT_ENUM(mesh::MeshElementType);
638 DEFINE_EMIT_ENUM(mesh::MeshRelationType);
639 DEFINE_EMIT_ENUM(mesh::ConvType);
640 DEFINE_EMIT_ENUM(SNodeGradType);
641
642#undef DEFINE_EMIT_ENUM
643
644 std::ostream *os_{nullptr};
645 std::unordered_set<const SNode *> snode_tree_roots_;
646 std::unordered_map<Function *, std::size_t> real_funcs_;
647 std::vector<char> string_pool_;
648};
649
650} // namespace
651
// Serializes `ast` (and the functions / SNode trees / strings it refers to)
// into `os` by running ASTSerializer over it. The written bytes are the
// material the offline-cache key is derived from.
void gen_offline_cache_key(IRNode *ast, std::ostream *os) {
  ASTSerializer::run(ast, os);
}
655
656} // namespace taichi::lang
657