1 | #include "taichi/analysis/arithmetic_interpretor.h" |
2 | |
#include <algorithm>
#include <limits>
#include <type_traits>
#include <vector>
6 | |
7 | #include "taichi/ir/type_utils.h" |
8 | #include "taichi/ir/visitors.h" |
9 | |
10 | namespace taichi::lang { |
11 | namespace { |
12 | |
13 | using CodeRegion = ArithmeticInterpretor::CodeRegion; |
14 | using EvalContext = ArithmeticInterpretor::EvalContext; |
15 | |
16 | std::vector<Stmt *> get_raw_statements(const Block *block) { |
17 | const auto &stmts = block->statements; |
18 | std::vector<Stmt *> res(stmts.size()); |
19 | std::transform(stmts.begin(), stmts.end(), res.begin(), |
20 | [](const std::unique_ptr<Stmt> &s) { return s.get(); }); |
21 | return res; |
22 | } |
23 | |
24 | class EvalVisitor : public IRVisitor { |
25 | public: |
26 | explicit EvalVisitor() { |
27 | allow_undefined_visitor = true; |
28 | invoke_default_visitor = true; |
29 | } |
30 | |
31 | std::optional<TypedConstant> run(const CodeRegion ®ion, |
32 | const EvalContext &init_ctx) { |
33 | context_ = init_ctx; |
34 | failed_ = false; |
35 | |
36 | auto stmts = get_raw_statements(region.block); |
37 | if (stmts.empty()) { |
38 | return std::nullopt; |
39 | } |
40 | auto *begin_stmt = (region.begin == nullptr) ? stmts.front() : region.begin; |
41 | auto *end_stmt = (region.end == nullptr) ? stmts.back() : region.end; |
42 | |
43 | auto cur_iter = std::find(stmts.begin(), stmts.end(), begin_stmt); |
44 | auto end_iter = std::find(stmts.begin(), stmts.end(), end_stmt); |
45 | if ((cur_iter == stmts.end()) || (end_iter == stmts.end())) { |
46 | return std::nullopt; |
47 | } |
48 | Stmt *cur_stmt = nullptr; |
49 | while (cur_iter != end_iter) { |
50 | cur_stmt = *cur_iter; |
51 | cur_stmt->accept(this); |
52 | if (failed_) { |
53 | return std::nullopt; |
54 | } |
55 | ++cur_iter; |
56 | } |
57 | return context_.maybe_get(cur_stmt); |
58 | } |
59 | |
60 | void visit(ConstStmt *stmt) override { |
61 | context_.insert(stmt, stmt->val); |
62 | } |
63 | |
64 | void visit(BinaryOpStmt *stmt) override { |
65 | auto lhs_opt = context_.maybe_get(stmt->lhs); |
66 | auto rhs_opt = context_.maybe_get(stmt->rhs); |
67 | if (!lhs_opt || !rhs_opt) { |
68 | failed_ = true; |
69 | return; |
70 | } |
71 | auto lhs = lhs_opt.value(); |
72 | auto rhs = rhs_opt.value(); |
73 | if (lhs.dt != rhs.dt) { |
74 | failed_ = true; |
75 | return; |
76 | } |
77 | |
78 | const auto op = stmt->op_type; |
79 | const auto dt = lhs.dt; |
80 | // TODO: Consider using macros to avoid duplication |
81 | if (is_real(dt)) { |
82 | // Put floating point numbers first because is_signed/unsigned asserts |
83 | // that the data type being integral. |
84 | auto res_opt = eval_bin_op(lhs.val_float(), rhs.val_float(), op); |
85 | insert_or_failed(stmt, dt, res_opt); |
86 | } else if (is_signed(dt)) { |
87 | auto res_opt = eval_bin_op(lhs.val_int(), rhs.val_int(), op); |
88 | insert_or_failed(stmt, dt, res_opt); |
89 | } else if (is_unsigned(dt)) { |
90 | auto res_opt = eval_bin_op(lhs.val_uint(), rhs.val_uint(), op); |
91 | insert_or_failed(stmt, dt, res_opt); |
92 | } else { |
93 | TI_NOT_IMPLEMENTED; |
94 | failed_ = true; |
95 | } |
96 | } |
97 | |
98 | void visit(LinearizeStmt *stmt) override { |
99 | int64_t val = 0; |
100 | for (int i = 0; i < (int)stmt->inputs.size(); ++i) { |
101 | auto idx_opt = context_.maybe_get(stmt->inputs[i]); |
102 | if (!idx_opt) { |
103 | failed_ = true; |
104 | return; |
105 | } |
106 | val = (val * stmt->strides[i]) + idx_opt.value().val_int(); |
107 | } |
108 | insert_to_ctx(stmt, stmt->ret_type, val); |
109 | } |
110 | |
111 | void visit(Stmt *stmt) override { |
112 | if (context_.should_ignore(stmt)) { |
113 | return; |
114 | } |
115 | failed_ = (context_.maybe_get(stmt) == std::nullopt); |
116 | } |
117 | |
118 | private: |
119 | template <typename T> |
120 | static std::optional<T> eval_bin_op(T lhs, T rhs, BinaryOpType op) { |
121 | if (op == BinaryOpType::add) { |
122 | return lhs + rhs; |
123 | } |
124 | if (op == BinaryOpType::sub) { |
125 | return lhs - rhs; |
126 | } |
127 | if (op == BinaryOpType::mul) { |
128 | return lhs * rhs; |
129 | } |
130 | if (op == BinaryOpType::div) { |
131 | return lhs / rhs; |
132 | } |
133 | if constexpr (std::is_integral_v<T>) { |
134 | if (op == BinaryOpType::mod) { |
135 | return lhs % rhs; |
136 | } |
137 | if (op == BinaryOpType::bit_and) { |
138 | return lhs & rhs; |
139 | } |
140 | if (op == BinaryOpType::bit_shr) { |
141 | return static_cast<std::make_unsigned_t<T>>(lhs) >> rhs; |
142 | } |
143 | } |
144 | return std::nullopt; |
145 | } |
146 | |
147 | template <typename T> |
148 | void insert_or_failed(const Stmt *stmt, |
149 | DataType dt, |
150 | std::optional<T> val_opt) { |
151 | if (!val_opt) { |
152 | failed_ = true; |
153 | return; |
154 | } |
155 | context_.insert(stmt, TypedConstant(dt, val_opt.value())); |
156 | } |
157 | |
158 | template <typename T> |
159 | void insert_to_ctx(const Stmt *stmt, DataType dt, const T &val) { |
160 | context_.insert(stmt, TypedConstant(dt, val)); |
161 | } |
162 | |
163 | EvalContext context_; |
164 | bool failed_{false}; |
165 | }; |
166 | |
167 | } // namespace |
168 | |
169 | std::optional<TypedConstant> ArithmeticInterpretor::evaluate( |
170 | const CodeRegion ®ion, |
171 | const EvalContext &init_ctx) const { |
172 | EvalVisitor ev; |
173 | return ev.run(region, init_ctx); |
174 | } |
175 | |
176 | } // namespace taichi::lang |
177 | |