1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/snode.h" |
3 | #include "taichi/ir/mesh.h" |
4 | #include "taichi/ir/visitors.h" |
5 | #include "taichi/ir/analysis.h" |
6 | #include "taichi/ir/statements.h" |
7 | |
8 | namespace taichi::lang { |
9 | |
10 | namespace irpass::analysis { |
11 | |
12 | class GatherMeshforRelationTypes : public BasicStmtVisitor { |
13 | public: |
14 | using BasicStmtVisitor::visit; |
15 | |
16 | GatherMeshforRelationTypes() { |
17 | allow_undefined_visitor = true; |
18 | invoke_default_visitor = true; |
19 | } |
20 | |
21 | static void run(IRNode *root) { |
22 | GatherMeshforRelationTypes analyser; |
23 | root->accept(&analyser); |
24 | } |
25 | |
26 | void visit(MeshForStmt *stmt) override { |
27 | TI_ASSERT(mesh_for == nullptr); |
28 | TI_ASSERT(stmt->major_to_types.size() == 0); |
29 | TI_ASSERT(stmt->minor_relation_types.size() == 0); |
30 | mesh_for = stmt; |
31 | stmt->body->accept(this); |
32 | |
33 | // Check metadata available |
34 | std::set<mesh::MeshElementType> all_elements; |
35 | all_elements.insert(mesh_for->major_from_type); |
36 | for (auto _type : mesh_for->major_to_types) { |
37 | all_elements.insert(_type); |
38 | } |
39 | for (auto _type : all_elements) { |
40 | TI_ERROR_IF(mesh_for->mesh->num_elements.find(_type) == |
41 | mesh_for->mesh->num_elements.end(), |
42 | "Cannot load mesh element {}'s metadata" , |
43 | mesh::element_type_name(_type)); |
44 | } |
45 | |
46 | std::set<mesh::MeshRelationType> all_relations; |
47 | for (auto _type : mesh_for->major_to_types) { |
48 | all_relations.insert( |
49 | mesh::relation_by_orders(int(mesh_for->major_from_type), int(_type))); |
50 | } |
51 | for (auto _type : mesh_for->minor_relation_types) { |
52 | all_relations.insert(_type); |
53 | } |
54 | |
55 | bool missing = false; |
56 | std::string full_name; |
57 | std::string short_name; |
58 | for (auto _type : all_relations) { |
59 | if (mesh_for->mesh->relations.find(_type) == |
60 | mesh_for->mesh->relations.end()) { |
61 | if (missing) { |
62 | full_name += ", " ; |
63 | short_name += ", " ; |
64 | } |
65 | full_name += mesh::relation_type_name(_type); |
66 | short_name += '\''; |
67 | short_name += char(mesh::element_type_name(mesh::MeshElementType( |
68 | mesh::from_end_element_order(_type)))[0] + |
69 | 'A' - 'a'); |
70 | short_name += char(mesh::element_type_name(mesh::MeshElementType( |
71 | mesh::to_end_element_order(_type)))[0] + |
72 | 'A' - 'a'); |
73 | short_name += '\''; |
74 | missing = true; |
75 | } |
76 | } |
77 | |
78 | if (missing) { |
79 | TI_ERROR( |
80 | "Relation {} detected in mesh-for loop but not initialized." |
81 | " Please add them with syntax: Patcher.load_mesh(..., " |
82 | "relations=[..., {}])" , |
83 | full_name, short_name); |
84 | } |
85 | |
86 | mesh_for = nullptr; |
87 | } |
88 | |
89 | void visit(MeshRelationAccessStmt *stmt) override { |
90 | if (auto from_stmt = |
91 | stmt->mesh_idx->cast<LoopIndexStmt>()) { // major relation |
92 | TI_ASSERT(from_stmt->mesh_index_type() == mesh_for->major_from_type); |
93 | mesh_for->major_to_types.insert(stmt->to_type); |
94 | } else if (auto from_stmt = |
95 | stmt->mesh_idx |
96 | ->cast<MeshRelationAccessStmt>()) { // minor relation |
97 | TI_ASSERT(!from_stmt->is_size()); |
98 | auto from_order = mesh::element_order(from_stmt->to_type); |
99 | auto to_order = mesh::element_order(stmt->to_type); |
100 | TI_ASSERT_INFO(from_order > to_order, |
101 | "Cannot access an indeterminate relation (E.g, Vert-Vert) " |
102 | "in a nested neighbor access" ); |
103 | mesh_for->minor_relation_types.insert( |
104 | mesh::relation_by_orders(from_order, to_order)); |
105 | } else { |
106 | TI_NOT_IMPLEMENTED; |
107 | } |
108 | } |
109 | |
110 | MeshForStmt *mesh_for{nullptr}; |
111 | }; |
112 | |
113 | void gather_meshfor_relation_types(IRNode *node) { |
114 | GatherMeshforRelationTypes::run(node); |
115 | } |
116 | |
117 | } // namespace irpass::analysis |
118 | |
119 | } // namespace taichi::lang |
120 | |