1#pragma once
2
3#include <vector>
4
5#include "taichi/ir/stmt_op_types.h"
6
7namespace taichi::lang {
8
9class LinearizeStmt;
10class SNode;
11class Stmt;
12class StructForStmt;
13class VecStatement;
14
15/**
16 * Lowers an SNode at a given indices to a series of concrete ops.
17 */
18class ScalarPointerLowerer {
19 public:
20 /**
21 * Constructor
22 *
23 * @param leaf_snode: SNode of the accessed field
24 * @param indices: Indices to access the field
25 * @param snode_op: SNode operation
26 * @param is_bit_vectorized: Is @param leaf_snode bit vectorized
27 * @param lowered: Collects the output ops
28 */
29 explicit ScalarPointerLowerer(SNode *leaf_snode,
30 const std::vector<Stmt *> &indices,
31 const SNodeOpType snode_op,
32 const bool is_bit_vectorized,
33 VecStatement *lowered);
34
35 virtual ~ScalarPointerLowerer() = default;
36 /**
37 * Runs the lowering process.
38 *
39 * This can only be called once.
40 */
41 void run();
42
43 protected:
44 /**
45 * Handles the SNode at a given @param level.
46 *
47 * @param level: Level of the SNode in the access path
48 * @param linearized: Linearized indices statement for this level
49 * @param last: SNode access op (e.g. GetCh) of the last iteration
50 */
51 virtual Stmt *handle_snode_at_level(int level,
52 LinearizeStmt *linearized,
53 Stmt *last) {
54 return last;
55 }
56
57 std::vector<SNode *> snodes() const {
58 return snodes_;
59 }
60
61 int path_length() const {
62 return path_length_;
63 }
64
65 const std::vector<Stmt *> indices_;
66 const SNodeOpType snode_op_;
67 const bool is_bit_vectorized_;
68 VecStatement *const lowered_;
69
70 private:
71 std::vector<SNode *> snodes_;
72 int path_length_{0};
73};
74
75} // namespace taichi::lang
76