// This file implements analyses that compute compile-time-known offsets between two values, or between a value and a loop index.
2
#include "taichi/ir/ir.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/visitors.h"

#include <algorithm>
#include <limits>
7
8namespace taichi::lang {
9
10DiffRange operator+(const DiffRange &a, const DiffRange &b) {
11 return DiffRange(a.related() && b.related(), a.coeff + b.coeff, a.low + b.low,
12 a.high + b.high - 1);
13}
14
15DiffRange operator-(const DiffRange &a, const DiffRange &b) {
16 return DiffRange(a.related() && b.related(), a.coeff - b.coeff,
17 a.low - b.high + 1, a.high - b.low);
18}
19
20DiffRange operator*(const DiffRange &a, const DiffRange &b) {
21 return DiffRange(
22 a.related() && b.related() && a.coeff * b.coeff == 0,
23 fmax(a.low * b.coeff, a.coeff * b.low),
24 fmin(a.low * b.low,
25 fmin(a.low * (b.high - 1),
26 fmin(b.low * (a.high - 1), (a.high - 1) * (b.high - 1)))),
27 fmax(a.low * b.low,
28 fmax(a.low * (b.high - 1),
29 fmax(b.low * (a.high - 1), (a.high - 1) * (b.high - 1)))) +
30 1);
31}
32
33DiffRange operator<<(const DiffRange &a, const DiffRange &b) {
34 return DiffRange(
35 a.related() && b.related() && b.coeff == 0 && b.high - b.low == 1,
36 a.coeff << b.low, a.low << b.low, ((a.high - 1) << b.low) + 1);
37}
38
39namespace {
40
41class ValueDiffLoopIndex : public IRVisitor {
42 public:
43 // first: related, second: offset
44 using ret_type = DiffRange;
45 Stmt *input_stmt, *loop;
46 int loop_index;
47 std::map<int, ret_type> results;
48
49 ValueDiffLoopIndex(Stmt *stmt, Stmt *loop, int loop_index)
50 : input_stmt(stmt), loop(loop), loop_index(loop_index) {
51 allow_undefined_visitor = true;
52 invoke_default_visitor = true;
53 }
54
55 void visit(Stmt *stmt) override {
56 results[stmt->instance_id] = DiffRange();
57 }
58
59 void visit(GlobalLoadStmt *stmt) override {
60 results[stmt->instance_id] = DiffRange();
61 }
62
63 void visit(LoopIndexStmt *stmt) override {
64 results[stmt->instance_id] = DiffRange();
65 if (stmt->loop == loop && stmt->index == loop_index) {
66 results[stmt->instance_id] =
67 DiffRange(/*related=*/true, /*coeff=*/1, /*low=*/0);
68 } else if (auto range_for = stmt->loop->cast<RangeForStmt>()) {
69 if (range_for->begin->is<ConstStmt>() &&
70 range_for->end->is<ConstStmt>()) {
71 auto begin_val = range_for->begin->as<ConstStmt>()->val.val_int();
72 auto end_val = range_for->end->as<ConstStmt>()->val.val_int();
73 // We have begin_val <= end_val even when range_for->reversed is true:
74 // in that case, the loop is iterated from end_val - 1 to begin_val.
75 results[stmt->instance_id] = DiffRange(
76 /*related=*/true, /*coeff=*/0, /*low=*/begin_val, /*high=*/end_val);
77 }
78 }
79 }
80
81 void visit(ConstStmt *stmt) override {
82 if (stmt->val.dt->is_primitive(PrimitiveTypeID::i32)) {
83 results[stmt->instance_id] = DiffRange(true, 0, stmt->val.val_i32);
84 } else {
85 results[stmt->instance_id] = DiffRange();
86 }
87 }
88
89 void visit(RangeAssumptionStmt *stmt) override {
90 stmt->base->accept(this);
91 results[stmt->instance_id] = results[stmt->base->instance_id] +
92 DiffRange(true, 0, stmt->low, stmt->high);
93 }
94
95 void visit(BinaryOpStmt *stmt) override {
96 if (stmt->op_type == BinaryOpType::add ||
97 stmt->op_type == BinaryOpType::sub ||
98 stmt->op_type == BinaryOpType::mul ||
99 stmt->op_type == BinaryOpType::bit_shl) {
100 stmt->lhs->accept(this);
101 stmt->rhs->accept(this);
102 auto ret1 = results[stmt->lhs->instance_id];
103 auto ret2 = results[stmt->rhs->instance_id];
104 if (ret1.related() && ret2.related()) {
105 if (stmt->op_type == BinaryOpType::add) {
106 results[stmt->instance_id] = ret1 + ret2;
107 } else if (stmt->op_type == BinaryOpType::sub) {
108 results[stmt->instance_id] = ret1 - ret2;
109 } else if (stmt->op_type == BinaryOpType::mul) {
110 results[stmt->instance_id] = ret1 * ret2;
111 } else {
112 results[stmt->instance_id] = ret1 << ret2;
113 }
114 return;
115 }
116 }
117 results[stmt->instance_id] = {false, 0};
118 }
119
120 ret_type run() {
121 input_stmt->accept(this);
122 return results[input_stmt->instance_id];
123 }
124};
125
126class FindDirectValueBaseAndOffset : public IRVisitor {
127 public:
128 // In the return value, <first> is true if this class finds that the input
129 // statement has value equal to <second> + <third> (base + offset), or
130 // <first> is false if this class can't find the decomposition.
131 using ret_type = std::tuple<bool, Stmt *, int>;
132 ret_type result;
133 FindDirectValueBaseAndOffset() : result(false, nullptr, 0) {
134 allow_undefined_visitor = true;
135 invoke_default_visitor = true;
136 }
137
138 void visit(Stmt *stmt) override {
139 result = std::make_tuple(false, nullptr, 0);
140 }
141
142 void visit(ConstStmt *stmt) override {
143 if (stmt->val.dt->is_primitive(PrimitiveTypeID::i32)) {
144 result = std::make_tuple(true, nullptr, stmt->val.val_i32);
145 }
146 }
147
148 void visit(BinaryOpStmt *stmt) override {
149 if (stmt->rhs->is<ConstStmt>())
150 stmt->rhs->accept(this);
151 if (!std::get<0>(result) || std::get<1>(result) != nullptr ||
152 (stmt->op_type != BinaryOpType::add &&
153 stmt->op_type != BinaryOpType::sub)) {
154 result = std::make_tuple(false, nullptr, 0);
155 return;
156 }
157 if (stmt->op_type == BinaryOpType::sub)
158 std::get<2>(result) = -std::get<2>(result);
159 std::get<1>(result) = stmt->lhs;
160 }
161
162 static ret_type run(Stmt *val) {
163 FindDirectValueBaseAndOffset instance;
164 val->accept(&instance);
165 return instance.result;
166 }
167};
168
169} // namespace
170
171namespace irpass {
172namespace analysis {
173
174DiffRange value_diff_loop_index(Stmt *stmt, Stmt *loop, int index_id) {
175 TI_ASSERT(loop->is<StructForStmt>() || loop->is<OffloadedStmt>());
176 if (loop->is<OffloadedStmt>()) {
177 TI_ASSERT(loop->as<OffloadedStmt>()->task_type ==
178 OffloadedStmt::TaskType::struct_for);
179 }
180 if (auto loop_index = stmt->cast<LoopIndexStmt>(); loop_index) {
181 if (loop_index->loop == loop && loop_index->index == index_id) {
182 return DiffRange(true, 1, 0);
183 }
184 }
185 auto diff = ValueDiffLoopIndex(stmt, loop, index_id);
186 return diff.run();
187}
188
189DiffPtrResult value_diff_ptr_index(Stmt *val1, Stmt *val2) {
190 if (val1 == val2) {
191 return DiffPtrResult::make_certain(0);
192 }
193 auto v1 = FindDirectValueBaseAndOffset::run(val1);
194 auto v2 = FindDirectValueBaseAndOffset::run(val2);
195 if (!std::get<0>(v1) || !std::get<0>(v2) ||
196 std::get<1>(v1) != std::get<1>(v2)) {
197 return DiffPtrResult::make_uncertain();
198 }
199 return DiffPtrResult::make_certain(std::get<2>(v1) - std::get<2>(v2));
200}
201
202} // namespace analysis
203} // namespace irpass
204} // namespace taichi::lang
205