1 | #pragma once |
2 | |
3 | #include <vector> |
4 | |
5 | #include "taichi/ir/stmt_op_types.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | class LinearizeStmt; |
10 | class SNode; |
11 | class Stmt; |
12 | class StructForStmt; |
13 | class VecStatement; |
14 | |
15 | /** |
16 | * Lowers an SNode at a given indices to a series of concrete ops. |
17 | */ |
18 | class ScalarPointerLowerer { |
19 | public: |
20 | /** |
21 | * Constructor |
22 | * |
23 | * @param leaf_snode: SNode of the accessed field |
24 | * @param indices: Indices to access the field |
25 | * @param snode_op: SNode operation |
26 | * @param is_bit_vectorized: Is @param leaf_snode bit vectorized |
27 | * @param lowered: Collects the output ops |
28 | */ |
29 | explicit ScalarPointerLowerer(SNode *leaf_snode, |
30 | const std::vector<Stmt *> &indices, |
31 | const SNodeOpType snode_op, |
32 | const bool is_bit_vectorized, |
33 | VecStatement *lowered); |
34 | |
35 | virtual ~ScalarPointerLowerer() = default; |
36 | /** |
37 | * Runs the lowering process. |
38 | * |
39 | * This can only be called once. |
40 | */ |
41 | void run(); |
42 | |
43 | protected: |
44 | /** |
45 | * Handles the SNode at a given @param level. |
46 | * |
47 | * @param level: Level of the SNode in the access path |
48 | * @param linearized: Linearized indices statement for this level |
49 | * @param last: SNode access op (e.g. GetCh) of the last iteration |
50 | */ |
51 | virtual Stmt *handle_snode_at_level(int level, |
52 | LinearizeStmt *linearized, |
53 | Stmt *last) { |
54 | return last; |
55 | } |
56 | |
57 | std::vector<SNode *> snodes() const { |
58 | return snodes_; |
59 | } |
60 | |
61 | int path_length() const { |
62 | return path_length_; |
63 | } |
64 | |
65 | const std::vector<Stmt *> indices_; |
66 | const SNodeOpType snode_op_; |
67 | const bool is_bit_vectorized_; |
68 | VecStatement *const lowered_; |
69 | |
70 | private: |
71 | std::vector<SNode *> snodes_; |
72 | int path_length_{0}; |
73 | }; |
74 | |
75 | } // namespace taichi::lang |
76 | |