1 | #include <algorithm> |
2 | #include <array> |
3 | |
4 | #include "taichi/inc/constants.h" |
5 | #include "taichi/ir/analysis.h" |
6 | #include "taichi/ir/snode.h" |
7 | #include "taichi/ir/statements.h" |
8 | #include "taichi/transforms/scalar_pointer_lowerer.h" |
9 | #include "taichi/transforms/utils.h" |
10 | |
11 | namespace taichi::lang { |
12 | |
13 | ScalarPointerLowerer::ScalarPointerLowerer(SNode *leaf_snode, |
14 | const std::vector<Stmt *> &indices, |
15 | const SNodeOpType snode_op, |
16 | const bool is_bit_vectorized, |
17 | VecStatement *lowered) |
18 | : indices_(indices), |
19 | snode_op_(snode_op), |
20 | is_bit_vectorized_(is_bit_vectorized), |
21 | lowered_(lowered) { |
22 | for (auto *s = leaf_snode; s != nullptr; s = s->parent) { |
23 | snodes_.push_back(s); |
24 | } |
25 | // From root to leaf |
26 | std::reverse(snodes_.begin(), snodes_.end()); |
27 | |
28 | const int path_inc = (int)(snode_op_ != SNodeOpType::undefined); |
29 | path_length_ = (int)snodes_.size() - 1 + path_inc; |
30 | } |
31 | |
32 | void ScalarPointerLowerer::run() { |
33 | std::array<int, taichi_max_num_indices> total_shape; |
34 | total_shape.fill(1); |
35 | for (const auto *s : snodes_) { |
36 | for (int j = 0; j < taichi_max_num_indices; j++) { |
37 | total_shape[j] *= s->extractors[j].shape; |
38 | } |
39 | } |
40 | std::array<bool, taichi_max_num_indices> ; |
41 | is_first_extraction.fill(true); |
42 | |
43 | if (path_length_ == 0) |
44 | return; |
45 | |
46 | auto *leaf_snode = snodes_[path_length_ - 1]; |
47 | Stmt *last = lowered_->push_back<GetRootStmt>(snodes_[0]); |
48 | for (int i = 0; i < path_length_; i++) { |
49 | auto *snode = snodes_[i]; |
50 | // TODO: Explain this condition |
51 | if (is_bit_vectorized_ && (snode->type == SNodeType::quant_array) && |
52 | (i == path_length_ - 1) && (snodes_[i - 1]->type == SNodeType::dense)) { |
53 | continue; |
54 | } |
55 | std::vector<Stmt *> lowered_indices; |
56 | std::vector<int> strides; |
57 | // extract lowered indices |
58 | for (int k_ = 0; k_ < (int)indices_.size(); k_++) { |
59 | int k = leaf_snode->physical_index_position[k_]; |
60 | if (!snode->extractors[k].active) |
61 | continue; |
62 | Stmt *; |
63 | const int prev = total_shape[k]; |
64 | total_shape[k] /= snode->extractors[k].shape; |
65 | const int next = total_shape[k]; |
66 | // Upon first extraction on axis k, "indices_[k_]" is the user |
67 | // coordinate on axis k and "prev" is the total shape of axis k. |
68 | // Unless it is an invalid out-of-bound access, we can assume |
69 | // "indices_[k_] < prev" so we don't need a mod here. |
70 | if (is_first_extraction[k]) { |
71 | extracted = indices_[k_]; |
72 | } else { |
73 | extracted = generate_mod(lowered_, indices_[k_], prev); |
74 | } |
75 | extracted = generate_div(lowered_, extracted, next); |
76 | is_first_extraction[k] = false; |
77 | lowered_indices.push_back(extracted); |
78 | strides.push_back(snode->extractors[k].shape); |
79 | } |
80 | // linearize |
81 | auto *linearized = |
82 | lowered_->push_back<LinearizeStmt>(lowered_indices, strides); |
83 | |
84 | last = handle_snode_at_level(i, linearized, last); |
85 | } |
86 | } |
87 | |
88 | } // namespace taichi::lang |
89 | |