1 | #pragma once |
2 | |
#include "taichi/ir/pass.h"
#include "taichi/ir/statements.h"
#include "taichi/analysis/mesh_bls_analyzer.h"

#include <cstddef>
#include <functional>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
8 | |
9 | namespace taichi::lang { |
10 | |
11 | class MakeMeshBlockLocal : public Pass { |
12 | public: |
13 | static const PassID id; |
14 | |
15 | struct Args { |
16 | std::string kernel_name; |
17 | }; |
18 | |
19 | MakeMeshBlockLocal(OffloadedStmt *offload, const CompileConfig &config); |
20 | |
21 | static void run(OffloadedStmt *offload, |
22 | const CompileConfig &config, |
23 | const std::string &kernel_name); |
24 | |
25 | private: |
26 | void simplify_nested_conversion(); |
27 | void gather_candidate_mapping(); |
28 | void replace_conv_statements(); |
29 | void replace_global_ptrs(SNode *snode); |
30 | |
31 | void fetch_attr_to_bls(Block *body, Stmt *idx_val, Stmt *mapping_val); |
32 | void push_attr_to_global(Block *body, Stmt *idx_val, Stmt *mapping_val); |
33 | |
34 | Stmt *create_xlogue( |
35 | Stmt *start_val, |
36 | Stmt *end_val, |
37 | std::function<void(Block * /*block*/, Stmt * /*idx_val*/)> body); |
38 | Stmt *create_cache_mapping( |
39 | Stmt *start_val, |
40 | Stmt *end_val, |
41 | std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)> global_val); |
42 | |
43 | void fetch_mapping( |
44 | std::function< |
45 | Stmt *(Stmt * /*start_val*/, |
46 | Stmt * /*end_val*/, |
47 | std::function<Stmt *(Block * /*block*/, |
48 | Stmt * /*idx_val*/)>)/*global_val*/> |
49 | mapping_callback_handler, |
50 | std::function<void(Block * /*body*/, |
51 | Stmt * /*idx_val*/, |
52 | Stmt * /*mapping_val*/)> attr_callback_handler); |
53 | |
54 | const CompileConfig &config_; |
55 | OffloadedStmt *offload_{nullptr}; |
56 | std::set<std::pair<mesh::MeshElementType, mesh::ConvType>> mappings_{}; |
57 | MeshBLSCaches::Rec rec_; |
58 | |
59 | Block *block_; |
60 | |
61 | std::size_t bls_offset_in_bytes_{0}; |
62 | std::size_t mapping_bls_offset_in_bytes_{0}; |
63 | std::unordered_map<SNode *, std::size_t> attr_bls_offset_in_bytes_{}; |
64 | |
65 | mesh::MeshElementType element_type_; |
66 | mesh::ConvType conv_type_; |
67 | SNode *mapping_snode_{nullptr}; |
68 | DataType mapping_data_type_; |
69 | int mapping_dtype_size_{0}; |
70 | }; |
71 | |
72 | } // namespace taichi::lang |
73 | |