#include "taichi/analysis/arithmetic_interpretor.h"

#include <algorithm>
#include <limits>
#include <type_traits>
#include <vector>

#include "taichi/ir/type_utils.h"
#include "taichi/ir/visitors.h"
9
10namespace taichi::lang {
11namespace {
12
13using CodeRegion = ArithmeticInterpretor::CodeRegion;
14using EvalContext = ArithmeticInterpretor::EvalContext;
15
16std::vector<Stmt *> get_raw_statements(const Block *block) {
17 const auto &stmts = block->statements;
18 std::vector<Stmt *> res(stmts.size());
19 std::transform(stmts.begin(), stmts.end(), res.begin(),
20 [](const std::unique_ptr<Stmt> &s) { return s.get(); });
21 return res;
22}
23
24class EvalVisitor : public IRVisitor {
25 public:
26 explicit EvalVisitor() {
27 allow_undefined_visitor = true;
28 invoke_default_visitor = true;
29 }
30
31 std::optional<TypedConstant> run(const CodeRegion &region,
32 const EvalContext &init_ctx) {
33 context_ = init_ctx;
34 failed_ = false;
35
36 auto stmts = get_raw_statements(region.block);
37 if (stmts.empty()) {
38 return std::nullopt;
39 }
40 auto *begin_stmt = (region.begin == nullptr) ? stmts.front() : region.begin;
41 auto *end_stmt = (region.end == nullptr) ? stmts.back() : region.end;
42
43 auto cur_iter = std::find(stmts.begin(), stmts.end(), begin_stmt);
44 auto end_iter = std::find(stmts.begin(), stmts.end(), end_stmt);
45 if ((cur_iter == stmts.end()) || (end_iter == stmts.end())) {
46 return std::nullopt;
47 }
48 Stmt *cur_stmt = nullptr;
49 while (cur_iter != end_iter) {
50 cur_stmt = *cur_iter;
51 cur_stmt->accept(this);
52 if (failed_) {
53 return std::nullopt;
54 }
55 ++cur_iter;
56 }
57 return context_.maybe_get(cur_stmt);
58 }
59
60 void visit(ConstStmt *stmt) override {
61 context_.insert(stmt, stmt->val);
62 }
63
64 void visit(BinaryOpStmt *stmt) override {
65 auto lhs_opt = context_.maybe_get(stmt->lhs);
66 auto rhs_opt = context_.maybe_get(stmt->rhs);
67 if (!lhs_opt || !rhs_opt) {
68 failed_ = true;
69 return;
70 }
71 auto lhs = lhs_opt.value();
72 auto rhs = rhs_opt.value();
73 if (lhs.dt != rhs.dt) {
74 failed_ = true;
75 return;
76 }
77
78 const auto op = stmt->op_type;
79 const auto dt = lhs.dt;
80 // TODO: Consider using macros to avoid duplication
81 if (is_real(dt)) {
82 // Put floating point numbers first because is_signed/unsigned asserts
83 // that the data type being integral.
84 auto res_opt = eval_bin_op(lhs.val_float(), rhs.val_float(), op);
85 insert_or_failed(stmt, dt, res_opt);
86 } else if (is_signed(dt)) {
87 auto res_opt = eval_bin_op(lhs.val_int(), rhs.val_int(), op);
88 insert_or_failed(stmt, dt, res_opt);
89 } else if (is_unsigned(dt)) {
90 auto res_opt = eval_bin_op(lhs.val_uint(), rhs.val_uint(), op);
91 insert_or_failed(stmt, dt, res_opt);
92 } else {
93 TI_NOT_IMPLEMENTED;
94 failed_ = true;
95 }
96 }
97
98 void visit(LinearizeStmt *stmt) override {
99 int64_t val = 0;
100 for (int i = 0; i < (int)stmt->inputs.size(); ++i) {
101 auto idx_opt = context_.maybe_get(stmt->inputs[i]);
102 if (!idx_opt) {
103 failed_ = true;
104 return;
105 }
106 val = (val * stmt->strides[i]) + idx_opt.value().val_int();
107 }
108 insert_to_ctx(stmt, stmt->ret_type, val);
109 }
110
111 void visit(Stmt *stmt) override {
112 if (context_.should_ignore(stmt)) {
113 return;
114 }
115 failed_ = (context_.maybe_get(stmt) == std::nullopt);
116 }
117
118 private:
119 template <typename T>
120 static std::optional<T> eval_bin_op(T lhs, T rhs, BinaryOpType op) {
121 if (op == BinaryOpType::add) {
122 return lhs + rhs;
123 }
124 if (op == BinaryOpType::sub) {
125 return lhs - rhs;
126 }
127 if (op == BinaryOpType::mul) {
128 return lhs * rhs;
129 }
130 if (op == BinaryOpType::div) {
131 return lhs / rhs;
132 }
133 if constexpr (std::is_integral_v<T>) {
134 if (op == BinaryOpType::mod) {
135 return lhs % rhs;
136 }
137 if (op == BinaryOpType::bit_and) {
138 return lhs & rhs;
139 }
140 if (op == BinaryOpType::bit_shr) {
141 return static_cast<std::make_unsigned_t<T>>(lhs) >> rhs;
142 }
143 }
144 return std::nullopt;
145 }
146
147 template <typename T>
148 void insert_or_failed(const Stmt *stmt,
149 DataType dt,
150 std::optional<T> val_opt) {
151 if (!val_opt) {
152 failed_ = true;
153 return;
154 }
155 context_.insert(stmt, TypedConstant(dt, val_opt.value()));
156 }
157
158 template <typename T>
159 void insert_to_ctx(const Stmt *stmt, DataType dt, const T &val) {
160 context_.insert(stmt, TypedConstant(dt, val));
161 }
162
163 EvalContext context_;
164 bool failed_{false};
165};
166
167} // namespace
168
169std::optional<TypedConstant> ArithmeticInterpretor::evaluate(
170 const CodeRegion &region,
171 const EvalContext &init_ctx) const {
172 EvalVisitor ev;
173 return ev.run(region, init_ctx);
174}
175
176} // namespace taichi::lang
177