1#include <algorithm>
2#include <array>
3
4#include "taichi/inc/constants.h"
5#include "taichi/ir/analysis.h"
6#include "taichi/ir/snode.h"
7#include "taichi/ir/statements.h"
8#include "taichi/transforms/scalar_pointer_lowerer.h"
9#include "taichi/transforms/utils.h"
10
11namespace taichi::lang {
12
13ScalarPointerLowerer::ScalarPointerLowerer(SNode *leaf_snode,
14 const std::vector<Stmt *> &indices,
15 const SNodeOpType snode_op,
16 const bool is_bit_vectorized,
17 VecStatement *lowered)
18 : indices_(indices),
19 snode_op_(snode_op),
20 is_bit_vectorized_(is_bit_vectorized),
21 lowered_(lowered) {
22 for (auto *s = leaf_snode; s != nullptr; s = s->parent) {
23 snodes_.push_back(s);
24 }
25 // From root to leaf
26 std::reverse(snodes_.begin(), snodes_.end());
27
28 const int path_inc = (int)(snode_op_ != SNodeOpType::undefined);
29 path_length_ = (int)snodes_.size() - 1 + path_inc;
30}
31
32void ScalarPointerLowerer::run() {
33 std::array<int, taichi_max_num_indices> total_shape;
34 total_shape.fill(1);
35 for (const auto *s : snodes_) {
36 for (int j = 0; j < taichi_max_num_indices; j++) {
37 total_shape[j] *= s->extractors[j].shape;
38 }
39 }
40 std::array<bool, taichi_max_num_indices> is_first_extraction;
41 is_first_extraction.fill(true);
42
43 if (path_length_ == 0)
44 return;
45
46 auto *leaf_snode = snodes_[path_length_ - 1];
47 Stmt *last = lowered_->push_back<GetRootStmt>(snodes_[0]);
48 for (int i = 0; i < path_length_; i++) {
49 auto *snode = snodes_[i];
50 // TODO: Explain this condition
51 if (is_bit_vectorized_ && (snode->type == SNodeType::quant_array) &&
52 (i == path_length_ - 1) && (snodes_[i - 1]->type == SNodeType::dense)) {
53 continue;
54 }
55 std::vector<Stmt *> lowered_indices;
56 std::vector<int> strides;
57 // extract lowered indices
58 for (int k_ = 0; k_ < (int)indices_.size(); k_++) {
59 int k = leaf_snode->physical_index_position[k_];
60 if (!snode->extractors[k].active)
61 continue;
62 Stmt *extracted;
63 const int prev = total_shape[k];
64 total_shape[k] /= snode->extractors[k].shape;
65 const int next = total_shape[k];
66 // Upon first extraction on axis k, "indices_[k_]" is the user
67 // coordinate on axis k and "prev" is the total shape of axis k.
68 // Unless it is an invalid out-of-bound access, we can assume
69 // "indices_[k_] < prev" so we don't need a mod here.
70 if (is_first_extraction[k]) {
71 extracted = indices_[k_];
72 } else {
73 extracted = generate_mod(lowered_, indices_[k_], prev);
74 }
75 extracted = generate_div(lowered_, extracted, next);
76 is_first_extraction[k] = false;
77 lowered_indices.push_back(extracted);
78 strides.push_back(snode->extractors[k].shape);
79 }
80 // linearize
81 auto *linearized =
82 lowered_->push_back<LinearizeStmt>(lowered_indices, strides);
83
84 last = handle_snode_at_level(i, linearized, last);
85 }
86}
87
88} // namespace taichi::lang
89