1#pragma once
2
#include "taichi/ir/pass.h"
#include "taichi/ir/statements.h"
#include "taichi/analysis/mesh_bls_analyzer.h"

#include <cstddef>
#include <functional>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
8
9namespace taichi::lang {
10
11class MakeMeshBlockLocal : public Pass {
12 public:
13 static const PassID id;
14
15 struct Args {
16 std::string kernel_name;
17 };
18
19 MakeMeshBlockLocal(OffloadedStmt *offload, const CompileConfig &config);
20
21 static void run(OffloadedStmt *offload,
22 const CompileConfig &config,
23 const std::string &kernel_name);
24
25 private:
26 void simplify_nested_conversion();
27 void gather_candidate_mapping();
28 void replace_conv_statements();
29 void replace_global_ptrs(SNode *snode);
30
31 void fetch_attr_to_bls(Block *body, Stmt *idx_val, Stmt *mapping_val);
32 void push_attr_to_global(Block *body, Stmt *idx_val, Stmt *mapping_val);
33
34 Stmt *create_xlogue(
35 Stmt *start_val,
36 Stmt *end_val,
37 std::function<void(Block * /*block*/, Stmt * /*idx_val*/)> body);
38 Stmt *create_cache_mapping(
39 Stmt *start_val,
40 Stmt *end_val,
41 std::function<Stmt *(Block * /*block*/, Stmt * /*idx_val*/)> global_val);
42
43 void fetch_mapping(
44 std::function<
45 Stmt *(Stmt * /*start_val*/,
46 Stmt * /*end_val*/,
47 std::function<Stmt *(Block * /*block*/,
48 Stmt * /*idx_val*/)>)/*global_val*/>
49 mapping_callback_handler,
50 std::function<void(Block * /*body*/,
51 Stmt * /*idx_val*/,
52 Stmt * /*mapping_val*/)> attr_callback_handler);
53
54 const CompileConfig &config_;
55 OffloadedStmt *offload_{nullptr};
56 std::set<std::pair<mesh::MeshElementType, mesh::ConvType>> mappings_{};
57 MeshBLSCaches::Rec rec_;
58
59 Block *block_;
60
61 std::size_t bls_offset_in_bytes_{0};
62 std::size_t mapping_bls_offset_in_bytes_{0};
63 std::unordered_map<SNode *, std::size_t> attr_bls_offset_in_bytes_{};
64
65 mesh::MeshElementType element_type_;
66 mesh::ConvType conv_type_;
67 SNode *mapping_snode_{nullptr};
68 DataType mapping_data_type_;
69 int mapping_dtype_size_{0};
70};
71
72} // namespace taichi::lang
73