#include "offline_cache_util.h"

#include <algorithm>

#include "taichi/ir/expr.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/mesh.h"
#include "taichi/ir/type.h"
#include "taichi/program/function.h"
#include "taichi/program/program.h"
10 | |
11 | namespace taichi::lang { |
12 | |
13 | namespace { |
14 | |
15 | enum class ExprOpCode : std::uint8_t { |
16 | NIL, |
17 | #define PER_EXPRESSION(x) x, |
18 | #include "taichi/inc/expressions.inc.h" |
19 | #undef PER_EXPRESSION |
20 | }; |
21 | |
22 | enum class StmtOpCode : std::uint8_t { |
23 | NIL, |
24 | EnterBlock, |
25 | ExitBlock, |
26 | StopGrad, |
27 | #define PER_STATEMENT(x) x, |
28 | #include "taichi/inc/frontend_statements.inc.h" |
29 | #undef PER_STATEMENT |
30 | }; |
31 | |
// Tag written after StmtOpCode::FrontendForStmt. It selects which
// loop-specific payload follows: an SNode, an external tensor, a mesh, or the
// begin/end pair of a range loop. The values are serialized as raw bytes, so
// they are spelled out to keep them stable.
enum class ForLoopType : std::uint8_t {
  StructForOnSNode = 0,
  StructForOnExternalTensor = 1,
  MeshFor = 2,
  RangeFor = 3,
};
38 | |
// Tag written for a FrontendExternalFuncStmt to record which source describes
// the external function:
//   SO  - stmt->so_func is set (nothing else is emitted for it),
//   ASM - stmt->asm_source is non-empty (the source text follows),
//   BC  - otherwise (bc_filename and bc_funcname follow).
// Serialized as a raw byte, hence the explicit values.
enum class ExternalFuncType : std::uint8_t {
  SO = 0,
  ASM = 1,
  BC = 2,
};
44 | |
// Tag written for a MeshRelationAccessExpression. It is Access when the
// expression carries a neighbor index (which is then emitted) and Size
// otherwise. Like the other tag enums in this file, it is stored in one byte,
// because it is written to the key stream with its raw object representation.
enum class MeshRelationAccessType : std::uint8_t {
  Access,  // mesh_relation_access
  Size,    // mesh_relation_size
};
49 | |
50 | class ASTSerializer : public IRVisitor, public ExpressionVisitor { |
51 | private: |
52 | using ExpressionVisitor::visit; |
53 | using IRVisitor::visit; |
54 | |
55 | public: |
56 | explicit ASTSerializer(std::ostream *os) : ExpressionVisitor(true), os_(os) { |
57 | // TODO(PGZXB): Set allow_undefined_visitor as false. (blocked by |
58 | // constant-folding) |
59 | this->allow_undefined_visitor = true; |
60 | } |
61 | |
62 | void set_ostream(std::ostream *os) { |
63 | this->os_ = os; |
64 | } |
65 | |
66 | std::ostream *get_ostream() { |
67 | return this->os_; |
68 | } |
69 | |
70 | void visit(Expression *expr) override { |
71 | this->ExpressionVisitor::visit(expr); |
72 | } |
73 | |
74 | void visit(Stmt *stmt) override { |
75 | this->IRVisitor::visit(stmt); |
76 | } |
77 | |
78 | void visit(ExprGroup &expr_group) override { |
79 | emit(expr_group.exprs); |
80 | } |
81 | |
82 | void visit(ArgLoadExpression *expr) override { |
83 | emit(ExprOpCode::ArgLoadExpression); |
84 | emit(expr->dt); |
85 | emit(expr->arg_id); |
86 | emit(expr->is_ptr); |
87 | } |
88 | |
89 | void visit(TexturePtrExpression *expr) override { |
90 | emit(ExprOpCode::TexturePtrExpression); |
91 | emit(expr->arg_id); |
92 | emit(expr->num_dims); |
93 | emit(expr->is_storage); |
94 | emit(expr->num_channels); |
95 | emit(expr->channel_format); |
96 | emit(expr->lod); |
97 | } |
98 | |
99 | void visit(TextureOpExpression *expr) override { |
100 | emit(ExprOpCode::TextureOpExpression); |
101 | emit(expr->op); |
102 | emit(expr->texture_ptr); |
103 | emit(expr->args.exprs); |
104 | } |
105 | |
106 | void visit(RandExpression *expr) override { |
107 | emit(ExprOpCode::RandExpression); |
108 | emit(expr->dt); |
109 | } |
110 | |
111 | void visit(UnaryOpExpression *expr) override { |
112 | emit(ExprOpCode::UnaryOpExpression); |
113 | emit(expr->type); |
114 | if (expr->is_cast()) { |
115 | emit(expr->cast_type); |
116 | } |
117 | emit(expr->operand); |
118 | } |
119 | |
120 | void visit(BinaryOpExpression *expr) override { |
121 | emit(ExprOpCode::BinaryOpExpression); |
122 | emit(expr->type); |
123 | emit(expr->lhs); |
124 | emit(expr->rhs); |
125 | } |
126 | |
127 | void visit(TernaryOpExpression *expr) override { |
128 | emit(ExprOpCode::TernaryOpExpression); |
129 | emit(expr->type); |
130 | emit(expr->op1); |
131 | emit(expr->op2); |
132 | emit(expr->op3); |
133 | } |
134 | |
135 | void visit(InternalFuncCallExpression *expr) override { |
136 | emit(ExprOpCode::InternalFuncCallExpression); |
137 | emit(expr->func_name); |
138 | emit(expr->args); |
139 | emit(expr->with_runtime_context); |
140 | } |
141 | |
142 | void visit(ExternalTensorExpression *expr) override { |
143 | emit(ExprOpCode::ExternalTensorExpression); |
144 | emit(expr->dt); |
145 | emit(expr->dim); |
146 | emit(expr->arg_id); |
147 | emit(expr->element_dim); |
148 | } |
149 | |
150 | void visit(FieldExpression *expr) override { |
151 | emit(ExprOpCode::FieldExpression); |
152 | emit(expr->ident); |
153 | emit(expr->dt); |
154 | emit(expr->snode); |
155 | emit(expr->has_ambient); |
156 | emit(expr->ambient_value); |
157 | emit(expr->snode_grad_type); |
158 | emit(expr->adjoint); |
159 | emit(expr->dual); |
160 | emit(expr->adjoint_checkbit); |
161 | } |
162 | |
163 | void visit(MatrixFieldExpression *expr) override { |
164 | emit(ExprOpCode::MatrixFieldExpression); |
165 | emit(expr->fields); |
166 | emit(expr->element_shape); |
167 | emit(expr->dynamic_index_stride); |
168 | } |
169 | |
170 | void visit(IndexExpression *expr) override { |
171 | emit(ExprOpCode::IndexExpression); |
172 | emit(expr->var); |
173 | for (auto &indices : expr->indices_group) { |
174 | emit(indices.exprs); |
175 | } |
176 | emit(expr->ret_shape); |
177 | } |
178 | |
179 | void visit(MatrixExpression *expr) override { |
180 | emit(ExprOpCode::MatrixExpression); |
181 | emit(expr->dt); |
182 | for (auto elt : expr->elements) { |
183 | emit(elt); |
184 | } |
185 | } |
186 | |
187 | void visit(RangeAssumptionExpression *expr) override { |
188 | emit(ExprOpCode::RangeAssumptionExpression); |
189 | emit(expr->input); |
190 | emit(expr->base); |
191 | emit(expr->low); |
192 | emit(expr->high); |
193 | } |
194 | |
195 | void visit(LoopUniqueExpression *expr) override { |
196 | emit(ExprOpCode::LoopUniqueExpression); |
197 | emit(expr->input); |
198 | emit(expr->covers); |
199 | } |
200 | |
201 | void visit(IdExpression *expr) override { |
202 | emit(ExprOpCode::IdExpression); |
203 | emit(expr->id); |
204 | } |
205 | |
206 | void visit(AtomicOpExpression *expr) override { |
207 | emit(ExprOpCode::AtomicOpExpression); |
208 | emit(expr->op_type); |
209 | emit(expr->dest); |
210 | emit(expr->val); |
211 | } |
212 | |
213 | void visit(SNodeOpExpression *expr) override { |
214 | emit(ExprOpCode::SNodeOpExpression); |
215 | emit(expr->op_type); |
216 | emit(expr->snode); |
217 | emit(expr->indices.exprs); |
218 | emit(expr->values); |
219 | } |
220 | |
221 | void visit(ConstExpression *expr) override { |
222 | emit(ExprOpCode::ConstExpression); |
223 | emit(expr->val); |
224 | } |
225 | |
226 | void visit(ExternalTensorShapeAlongAxisExpression *expr) override { |
227 | emit(ExprOpCode::ExternalTensorShapeAlongAxisExpression); |
228 | emit(expr->ptr); |
229 | emit(expr->axis); |
230 | } |
231 | |
232 | void visit(FrontendFuncCallStmt *expr) override { |
233 | emit(StmtOpCode::FrontendFuncCallStmt); |
234 | emit(expr->func); |
235 | emit(expr->args.exprs); |
236 | } |
237 | |
238 | void visit(MeshPatchIndexExpression *expr) override { |
239 | emit(ExprOpCode::MeshPatchIndexExpression); |
240 | } |
241 | |
242 | void visit(MeshRelationAccessExpression *expr) override { |
243 | emit(ExprOpCode::MeshRelationAccessExpression); |
244 | if (expr->neighbor_idx) { |
245 | emit(MeshRelationAccessType::Access); |
246 | emit(expr->neighbor_idx); |
247 | } else { |
248 | emit(MeshRelationAccessType::Size); |
249 | } |
250 | emit(expr->mesh); |
251 | emit(expr->to_type); |
252 | emit(expr->mesh_idx); |
253 | } |
254 | |
255 | void visit(MeshIndexConversionExpression *expr) override { |
256 | emit(ExprOpCode::MeshIndexConversionExpression); |
257 | emit(expr->mesh); |
258 | emit(expr->idx_type); |
259 | emit(expr->idx); |
260 | emit(expr->conv_type); |
261 | } |
262 | |
263 | void visit(ReferenceExpression *expr) override { |
264 | emit(ExprOpCode::ReferenceExpression); |
265 | emit(expr->var); |
266 | } |
267 | |
268 | void visit(GetElementExpression *expr) override { |
269 | emit(ExprOpCode::GetElementExpression); |
270 | emit(expr->src); |
271 | emit(expr->index); |
272 | } |
273 | |
274 | void visit(Block *block) override { |
275 | emit(StmtOpCode::EnterBlock); |
276 | emit(static_cast<std::size_t>(block->statements.size())); |
277 | for (auto &stmt : block->statements) { |
278 | emit(stmt.get()); |
279 | } |
280 | emit(StmtOpCode::StopGrad); |
281 | emit(block->stop_gradients); |
282 | emit(StmtOpCode::ExitBlock); |
283 | } |
284 | |
285 | void visit(FrontendExprStmt *stmt) override { |
286 | emit(StmtOpCode::FrontendExprStmt); |
287 | emit(stmt->val); |
288 | } |
289 | |
290 | void visit(FrontendBreakStmt *stmt) override { |
291 | emit(StmtOpCode::FrontendBreakStmt); |
292 | } |
293 | |
294 | void visit(FrontendContinueStmt *stmt) override { |
295 | emit(StmtOpCode::FrontendContinueStmt); |
296 | } |
297 | |
298 | void visit(FrontendAssignStmt *stmt) override { |
299 | emit(StmtOpCode::FrontendAssignStmt); |
300 | emit(stmt->lhs); |
301 | emit(stmt->rhs); |
302 | } |
303 | |
304 | void visit(FrontendAllocaStmt *stmt) override { |
305 | emit(StmtOpCode::FrontendAllocaStmt); |
306 | emit(stmt->ident); |
307 | } |
308 | |
309 | void visit(FrontendAssertStmt *stmt) override { |
310 | emit(StmtOpCode::FrontendAssertStmt); |
311 | emit(stmt->cond); |
312 | emit(stmt->text); |
313 | emit(stmt->args); |
314 | } |
315 | |
316 | void visit(FrontendSNodeOpStmt *stmt) override { |
317 | emit(StmtOpCode::FrontendSNodeOpStmt); |
318 | emit(stmt->op_type); |
319 | emit(stmt->snode); |
320 | emit(stmt->indices.exprs); |
321 | emit(stmt->val); |
322 | } |
323 | |
324 | void visit(FrontendIfStmt *stmt) override { |
325 | emit(StmtOpCode::FrontendIfStmt); |
326 | emit(stmt->condition); |
327 | std::uint8_t branch_count = 0; |
328 | if (stmt->true_statements) { |
329 | ++branch_count; |
330 | } |
331 | if (stmt->false_statements) { |
332 | ++branch_count; |
333 | } |
334 | emit(branch_count); |
335 | if (stmt->true_statements) { |
336 | emit(stmt->true_statements.get()); |
337 | } |
338 | if (stmt->false_statements) { |
339 | emit(stmt->false_statements.get()); |
340 | } |
341 | } |
342 | |
343 | void visit(FrontendPrintStmt *stmt) override { |
344 | emit(StmtOpCode::FrontendPrintStmt); |
345 | emit(static_cast<std::size_t>(stmt->contents.size())); |
346 | for (const auto &c : stmt->contents) { |
347 | emit(static_cast<std::uint8_t>(c.index())); |
348 | if (std::holds_alternative<Expr>(c)) { |
349 | emit(std::get<Expr>(c)); |
350 | } else { |
351 | emit(std::get<std::string>(c)); |
352 | } |
353 | } |
354 | } |
355 | |
356 | void visit(FrontendFuncDefStmt *stmt) override { |
357 | emit(StmtOpCode::FrontendFuncDefStmt); |
358 | emit(stmt->body.get()); |
359 | } |
360 | |
361 | void visit(FrontendWhileStmt *stmt) override { |
362 | emit(StmtOpCode::FrontendWhileStmt); |
363 | emit(stmt->cond); |
364 | emit(stmt->body.get()); |
365 | } |
366 | |
367 | void visit(FrontendForStmt *stmt) override { |
368 | emit(StmtOpCode::FrontendForStmt); |
369 | if (stmt->snode) { |
370 | emit(ForLoopType::StructForOnSNode); |
371 | emit(stmt->snode); |
372 | } else if (stmt->external_tensor) { |
373 | emit(ForLoopType::StructForOnExternalTensor); |
374 | emit(stmt->external_tensor); |
375 | } else if (stmt->mesh) { |
376 | emit(ForLoopType::MeshFor); |
377 | emit(stmt->element_type); |
378 | emit(stmt->mesh); |
379 | } else { |
380 | emit(ForLoopType::RangeFor); |
381 | emit(stmt->begin); |
382 | emit(stmt->end); |
383 | } |
384 | emit(stmt->loop_var_ids); |
385 | emit(stmt->is_bit_vectorized); |
386 | emit(stmt->num_cpu_threads); |
387 | emit(stmt->strictly_serialized); |
388 | emit(stmt->mem_access_opt); |
389 | emit(stmt->block_dim); |
390 | emit(stmt->body.get()); |
391 | } |
392 | |
393 | void visit(FrontendReturnStmt *stmt) override { |
394 | emit(StmtOpCode::FrontendReturnStmt); |
395 | emit(stmt->values.exprs); |
396 | } |
397 | |
398 | void visit(FrontendExternalFuncStmt *stmt) override { |
399 | // Note: The result of serializing FrontendExternalFuncStmt is not parsable |
400 | // now |
401 | emit(StmtOpCode::FrontendExternalFuncStmt); |
402 | if (stmt->so_func != nullptr) { |
403 | emit(ExternalFuncType::SO); |
404 | } else if (!stmt->asm_source.empty()) { |
405 | emit(ExternalFuncType::ASM); |
406 | emit(stmt->asm_source); |
407 | } else { |
408 | emit(ExternalFuncType::BC); |
409 | emit(stmt->bc_filename); |
410 | emit(stmt->bc_funcname); |
411 | } |
412 | emit(stmt->args); |
413 | emit(stmt->outputs); |
414 | } |
415 | |
416 | static void run(IRNode *ast, std::ostream *os) { |
417 | ASTSerializer serializer(os); |
418 | ast->accept(&serializer); |
419 | serializer.emit_dependencies(); |
420 | } |
421 | |
422 | private: |
423 | void emit_dependencies() { |
424 | // Serialize dependent real-functions |
425 | emit(real_funcs_.size()); |
426 | for (auto &[func, id] : real_funcs_) { |
427 | if (auto &ast_str = func->try_get_ast_serialization_data(); |
428 | ast_str.has_value()) { |
429 | emit_bytes(ast_str->c_str(), ast_str->size()); |
430 | } |
431 | } |
432 | |
433 | // Serialize snode_trees(Temporary: using offline-cache-key of SNode) |
434 | // Note: The result of serializing snode_tree_roots_ is not parsable now |
435 | emit(static_cast<std::size_t>(snode_tree_roots_.size())); |
436 | for (const auto *snode : snode_tree_roots_) { |
437 | auto key = get_hashed_offline_cache_key_of_snode(snode); |
438 | emit_bytes(key.c_str(), key.size()); |
439 | } |
440 | |
441 | // Dump string-pool |
442 | emit(static_cast<std::size_t>(string_pool_.size())); |
443 | emit_bytes(string_pool_.data(), string_pool_.size()); |
444 | } |
445 | |
446 | template <typename T> |
447 | void emit_pod(const T &val) { |
448 | static_assert(std::is_pod<T>::value); |
449 | TI_ASSERT(os_); |
450 | os_->write((const char *)&val, sizeof(T)); |
451 | } |
452 | |
453 | void emit_bytes(const char *bytes, std::size_t len) { |
454 | TI_ASSERT(os_); |
455 | if (!bytes) |
456 | return; |
457 | os_->write(bytes, len); |
458 | } |
459 | |
460 | template <typename T> |
461 | void emit(const std::vector<T> &v) { |
462 | emit(static_cast<std::size_t>(v.size())); |
463 | for (const auto &e : v) { |
464 | emit(e); |
465 | } |
466 | } |
467 | |
468 | template <typename K, typename V> |
469 | void emit(const std::unordered_map<K, V> &map) { |
470 | emit(static_cast<std::size_t>(map.size())); |
471 | for (const auto &[k, v] : map) { |
472 | emit(k); |
473 | emit(v); |
474 | } |
475 | } |
476 | |
477 | template <typename T1, typename T2> |
478 | void emit(const std::pair<T1, T2> &pair) { |
479 | emit(pair.first); |
480 | emit(pair.second); |
481 | } |
482 | |
483 | template <typename K, typename V> |
484 | void emit(const std::map<K, V> &map) { |
485 | emit(static_cast<std::size_t>(map.size())); |
486 | for (const auto &[k, v] : map) { |
487 | emit(k); |
488 | emit(v); |
489 | } |
490 | } |
491 | |
492 | void emit(const std::string &str) { |
493 | std::size_t size = str.size(); |
494 | std::size_t offset = string_pool_.size(); |
495 | string_pool_.insert(string_pool_.end(), str.begin(), str.end()); |
496 | emit(size); |
497 | emit(offset); |
498 | } |
499 | |
500 | void emit(Function *func) { |
501 | TI_ASSERT(func); |
502 | auto iter = real_funcs_.find(func); |
503 | if (iter != real_funcs_.end()) { |
504 | emit(iter->second); |
505 | } else { |
506 | auto [iter, ok] = real_funcs_.insert({func, real_funcs_.size()}); |
507 | TI_ASSERT(ok); |
508 | emit(iter->second); |
509 | } |
510 | } |
511 | |
512 | void emit(const TypedConstant &val) { |
513 | emit(val.dt); |
514 | if (!val.dt->is_primitive(PrimitiveTypeID::unknown)) { |
515 | emit(val.stringify()); |
516 | } |
517 | } |
518 | |
519 | void emit(const SNode *snode) { |
520 | if (snode) { |
521 | emit(static_cast<std::size_t>(snode->get_snode_tree_id())); |
522 | emit(static_cast<std::size_t>(snode->id)); |
523 | const auto *root = snode->get_root(); |
524 | snode_tree_roots_.insert(root); |
525 | } else { |
526 | emit(std::numeric_limits<std::size_t>::max()); |
527 | emit(std::numeric_limits<std::size_t>::max()); |
528 | } |
529 | } |
530 | |
531 | void emit(const mesh::MeshLocalRelation &r) { |
532 | emit(r.fixed); |
533 | emit(r.value); |
534 | emit(r.patch_offset); |
535 | emit(r.offset); |
536 | } |
537 | |
538 | void emit(mesh::Mesh *mesh) { |
539 | TI_ASSERT(mesh); |
540 | emit(mesh->num_patches); |
541 | emit(mesh->num_elements); |
542 | emit(mesh->patch_max_element_num); |
543 | emit(mesh->owned_offset); |
544 | emit(mesh->total_offset); |
545 | emit(mesh->index_mapping); |
546 | emit(mesh->relations); |
547 | } |
548 | |
549 | void emit(const Identifier &ident) { |
550 | emit(ident.id); |
551 | } |
552 | |
553 | void emit(const DataType &type) { |
554 | if (auto *p = type->cast<PrimitiveType>()) { |
555 | emit(p->type); |
556 | } else { |
557 | auto type_str = type->to_string(); |
558 | emit(type_str); |
559 | } |
560 | } |
561 | |
562 | void emit(IRNode *ir) { |
563 | TI_ASSERT(ir); |
564 | ir->accept(this); |
565 | } |
566 | |
567 | void emit(const Expr &expr) { |
568 | if (expr) { |
569 | emit(expr.const_value); |
570 | emit(expr.atomic); |
571 | auto *e = expr.expr.get(); |
572 | emit(e->get_flattened_stmt()); |
573 | emit(e->attributes); |
574 | emit(e->ret_type); |
575 | expr.expr->accept(this); |
576 | } else { |
577 | emit(ExprOpCode::NIL); |
578 | } |
579 | } |
580 | |
581 | void emit(Stmt *stmt) { |
582 | if (stmt) { |
583 | emit(stmt->get_operands()); |
584 | emit(stmt->erased); |
585 | emit(stmt->fields_registered); |
586 | emit(stmt->ret_type); |
587 | stmt->accept(this); |
588 | } else { |
589 | emit(StmtOpCode::NIL); |
590 | } |
591 | } |
592 | |
593 | void emit(std::size_t size) { |
594 | emit_pod(size); |
595 | } |
596 | |
597 | void emit(std::uint8_t u8) { |
598 | emit_pod(u8); |
599 | } |
600 | |
601 | void emit(int i) { |
602 | emit_pod(i); |
603 | } |
604 | |
605 | void emit(bool v) { |
606 | emit_pod(v); |
607 | } |
608 | |
609 | void emit(const MemoryAccessOptions &mem_access_options) { |
610 | auto all_options = mem_access_options.get_all(); |
611 | emit(static_cast<std::size_t>(all_options.size())); |
612 | for (const auto &[snode, options] : all_options) { |
613 | emit(snode); |
614 | emit(static_cast<std::size_t>(options.size())); |
615 | for (auto e : options) { |
616 | emit(e); |
617 | } |
618 | } |
619 | } |
620 | |
621 | #define DEFINE_EMIT_ENUM(EnumType) \ |
622 | void emit(EnumType type) { emit_pod(type); } |
623 | |
624 | DEFINE_EMIT_ENUM(ExprOpCode); |
625 | DEFINE_EMIT_ENUM(StmtOpCode); |
626 | DEFINE_EMIT_ENUM(PrimitiveTypeID); |
627 | DEFINE_EMIT_ENUM(UnaryOpType); |
628 | DEFINE_EMIT_ENUM(BinaryOpType); |
629 | DEFINE_EMIT_ENUM(TernaryOpType); |
630 | DEFINE_EMIT_ENUM(AtomicOpType); |
631 | DEFINE_EMIT_ENUM(SNodeOpType); |
632 | DEFINE_EMIT_ENUM(ForLoopType); |
633 | DEFINE_EMIT_ENUM(SNodeAccessFlag); |
634 | DEFINE_EMIT_ENUM(MeshRelationAccessType); |
635 | DEFINE_EMIT_ENUM(ExternalFuncType); |
636 | DEFINE_EMIT_ENUM(TextureOpType); |
637 | DEFINE_EMIT_ENUM(mesh::MeshElementType); |
638 | DEFINE_EMIT_ENUM(mesh::MeshRelationType); |
639 | DEFINE_EMIT_ENUM(mesh::ConvType); |
640 | DEFINE_EMIT_ENUM(SNodeGradType); |
641 | |
642 | #undef DEFINE_EMIT_ENUM |
643 | |
644 | std::ostream *os_{nullptr}; |
645 | std::unordered_set<const SNode *> snode_tree_roots_; |
646 | std::unordered_map<Function *, std::size_t> real_funcs_; |
647 | std::vector<char> string_pool_; |
648 | }; |
649 | |
650 | } // namespace |
651 | |
652 | void gen_offline_cache_key(IRNode *ast, std::ostream *os) { |
653 | ASTSerializer::run(ast, os); |
654 | } |
655 | |
656 | } // namespace taichi::lang |
657 | |