1 | #pragma once |
2 | |
#include <unordered_map>
#include <vector>

#include "taichi/ir/visitors.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/scratch_pad.h"
6 | |
7 | namespace taichi::lang { |
8 | |
9 | // Figure out accessed SNodes, and their ranges in this for stmt |
10 | class BLSAnalyzer : public BasicStmtVisitor { |
11 | using BasicStmtVisitor::visit; |
12 | |
13 | public: |
14 | // The lowest and highest index in each dimension. |
15 | struct IndexRange { |
16 | int low{0}; |
17 | int high{0}; |
18 | }; |
19 | using BlockIndices = std::vector<IndexRange>; |
20 | |
21 | BLSAnalyzer(OffloadedStmt *for_stmt, ScratchPads *pads); |
22 | |
23 | void visit(GlobalPtrStmt *stmt) override { |
24 | } |
25 | |
26 | // Do not eliminate global data access |
27 | void visit(GlobalLoadStmt *stmt) override; |
28 | |
29 | void visit(GlobalStoreStmt *stmt) override; |
30 | |
31 | void visit(AtomicOpStmt *stmt) override; |
32 | |
33 | void visit(Stmt *stmt) override; |
34 | |
35 | /** |
36 | * Run the block local analysis |
37 | * @return: true if the block range could be successfully inferred |
38 | */ |
39 | bool run(); |
40 | |
41 | private: |
42 | // Generate the index bounds in a SNode (block). E.g., a dense(ti.ij, (2, 4)) |
43 | // SNode has index bounds [[0, 1], [0, 3]]. |
44 | static void generate_block_indices(SNode *snode, BlockIndices *indices); |
45 | |
46 | void record_access(Stmt *stmt, AccessFlag flag); |
47 | |
48 | OffloadedStmt *for_stmt_{nullptr}; |
49 | ScratchPads *pads_{nullptr}; |
50 | std::unordered_map<SNode *, BlockIndices> block_indices_; |
51 | // true means analysis is OK for now |
52 | // it could be failed by any of the following reasons |
53 | // compiler could not infer the scratch pad range at compile time |
54 | // ... |
55 | bool analysis_ok_{true}; |
56 | }; |
57 | |
58 | } // namespace taichi::lang |
59 | |