1 | // Intermediate representation system |
2 | |
3 | #pragma once |
4 | |
5 | #include <atomic> |
6 | #include <unordered_set> |
7 | #include <unordered_map> |
8 | #include <variant> |
9 | #include <tuple> |
10 | |
11 | #include "taichi/common/core.h" |
12 | #include "taichi/common/exceptions.h" |
13 | #include "taichi/common/one_or_more.h" |
14 | #include "taichi/ir/snode.h" |
15 | #include "taichi/ir/mesh.h" |
16 | #include "taichi/ir/type_factory.h" |
17 | #include "taichi/util/short_name.h" |
18 | |
19 | #ifdef TI_WITH_LLVM |
20 | #include "llvm/ADT/SmallVector.h" |
21 | #include "llvm/ADT/MapVector.h" |
22 | #endif |
23 | |
24 | namespace taichi::lang { |
25 | |
// Forward declarations of the core IR types used throughout this header.
class IRNode;
class Block;
class Stmt;
// Owning handle to a statement; stmt_vector (and hence Block and
// VecStatement) stores statements through this type.
using pStmt = std::unique_ptr<Stmt>;

class SNode;

class Kernel;
struct CompileConfig;

// Per-SNode access hints; MemoryAccessOptions records a set of these per SNode.
enum class SNodeAccessFlag : int { block_local, read_only, mesh_local };
// Printable name of `type` (declared here, defined out of line).
std::string snode_access_flag_name(SNodeAccessFlag type);
38 | |
39 | class MemoryAccessOptions { |
40 | public: |
41 | void add_flag(SNode *snode, SNodeAccessFlag flag) { |
42 | options_[snode].insert(flag); |
43 | } |
44 | |
45 | bool has_flag(SNode *snode, SNodeAccessFlag flag) const { |
46 | if (auto it = options_.find(snode); it != options_.end()) |
47 | return it->second.count(flag) != 0; |
48 | else |
49 | return false; |
50 | } |
51 | |
52 | std::vector<SNode *> get_snodes_with_flag(SNodeAccessFlag flag) const { |
53 | std::vector<SNode *> snodes; |
54 | for (const auto &opt : options_) { |
55 | if (has_flag(opt.first, flag)) { |
56 | snodes.push_back(opt.first); |
57 | } |
58 | } |
59 | return snodes; |
60 | } |
61 | |
62 | void clear() { |
63 | options_.clear(); |
64 | } |
65 | |
66 | std::unordered_map<SNode *, std::unordered_set<SNodeAccessFlag>> get_all() |
67 | const { |
68 | return options_; |
69 | } |
70 | |
71 | private: |
72 | std::unordered_map<SNode *, std::unordered_set<SNodeAccessFlag>> options_; |
73 | }; |
74 | |
// Forward-declare every statement class. statements.inc.h is expected to
// expand PER_STATEMENT(Name) once per statement class, so each expansion here
// becomes `class Name;`.
#define PER_STATEMENT(x) class x;
#include "taichi/inc/statements.inc.h"
#undef PER_STATEMENT
78 | |
// A named variable identity. Several identifiers may share `name_`; they are
// told apart only by `id`, which is also the sole key for ordering and
// equality.
class Identifier {
 public:
  std::string name_;
  int id{0};

  // No default constructor: every identifier is created with an explicit id.
  explicit Identifier(int id, const std::string &name = "")
      : name_{name}, id{id} {
  }

  // Unprefixed printable name (declared here, defined out of line).
  std::string raw_name() const;

  // raw_name() with a leading '@'.
  std::string name() const {
    std::string prefixed{"@"};
    prefixed += raw_name();
    return prefixed;
  }

  bool operator<(const Identifier &other) const {
    return this->id < other.id;
  }

  bool operator==(const Identifier &other) const {
    return this->id == other.id;
  }
};
105 | |
// Statement containers: stmt_vector owns its statements (elements are pStmt),
// stmt_ref_vector only points at them. With LLVM available these are
// llvm::SmallVector with inline capacity 8 and 2 respectively; otherwise plain
// std::vector. Code using them should stick to operations both types provide.
#ifdef TI_WITH_LLVM
using stmt_vector = llvm::SmallVector<pStmt, 8>;
using stmt_ref_vector = llvm::SmallVector<Stmt *, 2>;
#else
using stmt_vector = std::vector<pStmt>;
using stmt_ref_vector = std::vector<Stmt *>;
#endif
113 | |
114 | class VecStatement { |
115 | public: |
116 | stmt_vector stmts; |
117 | |
118 | VecStatement() { |
119 | } |
120 | |
121 | // NOLINTNEXTLINE(google-explicit-constructor) |
122 | VecStatement(pStmt &&stmt) { |
123 | push_back(std::move(stmt)); |
124 | } |
125 | |
126 | VecStatement(VecStatement &&o) { |
127 | stmts = std::move(o.stmts); |
128 | } |
129 | |
130 | // NOLINTNEXTLINE(google-explicit-constructor) |
131 | VecStatement(stmt_vector &&other_stmts) { |
132 | stmts = std::move(other_stmts); |
133 | } |
134 | |
135 | Stmt *push_back(pStmt &&stmt); |
136 | |
137 | template <typename T, typename... Args> |
138 | T *push_back(Args &&...args) { |
139 | auto up = std::make_unique<T>(std::forward<Args>(args)...); |
140 | auto ptr = up.get(); |
141 | stmts.push_back(std::move(up)); |
142 | return ptr; |
143 | } |
144 | |
145 | pStmt &back() { |
146 | return stmts.back(); |
147 | } |
148 | |
149 | std::size_t size() const { |
150 | return stmts.size(); |
151 | } |
152 | |
153 | pStmt &operator[](int i) { |
154 | return stmts[i]; |
155 | } |
156 | }; |
157 | |
// Base class for IR traversals. It declares a virtual visit() overload for
// Block and for every statement class listed in statements.inc.h, plus a
// fallback visit(Stmt *).
class IRVisitor {
 public:
  // When false, reaching a visit() overload the subclass did not override is
  // an error.
  bool allow_undefined_visitor;
  // Only consulted when allow_undefined_visitor is true: if set, typed
  // overloads that were not overridden forward to visit(Stmt *).
  bool invoke_default_visitor;

  IRVisitor() {
    allow_undefined_visitor = false;
    invoke_default_visitor = false;
  }

  virtual ~IRVisitor() = default;

  // default visitor
  // Reports an error unless allow_undefined_visitor is set; otherwise a no-op.
  virtual void visit(Stmt *stmt) {
    if (!allow_undefined_visitor) {
      TI_ERROR(
          "missing visitor function. Is the statement class registered via "
          "DEFINE_VISIT?");
    }
  }

// DEFINE_VISIT(T) declares `virtual void visit(T *stmt)` with the default
// policy: error (TI_NOT_IMPLEMENTED) unless allow_undefined_visitor is set,
// in which case it optionally forwards to visit(Stmt *).
// NOTE(review): `(Stmt *)stmt` is a C-style cast. Block and the statement
// classes are only forward-declared at this point, so it acts as a
// reinterpret_cast. Block derives from IRNode, not Stmt, so for
// DEFINE_VISIT(Block) an override of visit(Stmt *) that dereferences its
// argument would misbehave — confirm no visitor relies on that path.
#define DEFINE_VISIT(T)             \
  virtual void visit(T *stmt) {     \
    if (allow_undefined_visitor) {  \
      if (invoke_default_visitor)   \
        visit((Stmt *)stmt);        \
    } else                          \
      TI_NOT_IMPLEMENTED;           \
  }

  DEFINE_VISIT(Block);
#define PER_STATEMENT(x) DEFINE_VISIT(x)
#include "taichi/inc/statements.inc.h"

#undef PER_STATEMENT
#undef DEFINE_VISIT
};
195 | |
196 | struct CompileConfig; |
197 | class Kernel; |
198 | |
199 | using stmt_refs = one_or_more<Stmt *>; |
200 | |
namespace ir_traits {

// FIXME: Use C++ 20 concepts to replace `dynamic_cast<T>() != nullptr`

// Interface for statements that perform a store. Passes detect it with a
// dynamic_cast (see the FIXME above).
class Store {
 public:
  virtual ~Store() = default;

  // Get the list of sinks/destinations of the store operation
  virtual stmt_refs get_store_destination() const = 0;

  // If store_stmt provides one data source, return the data.
  // NOTE(review): what implementers return when there is no single source is
  // not specified here — check the implementations before relying on it.
  virtual Stmt *get_store_data() const = 0;
};

// Interface for statements that perform a load.
class Load {
 public:
  virtual ~Load() = default;

  // If load_stmt loads some variables or a stack, return the pointers of them.
  virtual stmt_refs get_load_pointers() const = 0;
};

}  // namespace ir_traits
225 | |
226 | class IRNode { |
227 | public: |
228 | virtual void accept(IRVisitor *visitor) { |
229 | TI_NOT_IMPLEMENTED |
230 | } |
231 | |
232 | // * For a Stmt, this returns its enclosing Block |
233 | // * For a Block, this returns its enclosing Stmt |
234 | virtual IRNode *get_parent() const = 0; |
235 | |
236 | IRNode *get_ir_root(); |
237 | |
238 | virtual ~IRNode() = default; |
239 | |
240 | template <typename T> |
241 | bool is() const { |
242 | return dynamic_cast<const T *>(this) != nullptr; |
243 | } |
244 | |
245 | template <typename T> |
246 | T *as() { |
247 | TI_ASSERT(is<T>()); |
248 | return dynamic_cast<T *>(this); |
249 | } |
250 | |
251 | template <typename T> |
252 | const T *as() const { |
253 | TI_ASSERT(is<T>()); |
254 | return dynamic_cast<const T *>(this); |
255 | } |
256 | |
257 | template <typename T> |
258 | T *cast() { |
259 | return dynamic_cast<T *>(this); |
260 | } |
261 | |
262 | template <typename T> |
263 | const T *cast() const { |
264 | return dynamic_cast<const T *>(this); |
265 | } |
266 | |
267 | std::unique_ptr<IRNode> clone(); |
268 | }; |
269 | |
// TI_DEFINE_ACCEPT: accept() that dispatches to the visit() overload matching
// the concrete class of `this`.
#define TI_DEFINE_ACCEPT \
  void accept(IRVisitor *visitor) override { visitor->visit(this); }

// TI_DEFINE_CLONE: clone() that copy-constructs the concrete class, then
// re-runs field registration on the copy (mark_fields_registered() followed by
// io(field_manager)) before returning it as a std::unique_ptr<Stmt>.
#define TI_DEFINE_CLONE                                 \
  std::unique_ptr<Stmt> clone() const override {        \
    using Self = std::decay_t<decltype(*this)>;         \
    auto cloned = std::make_unique<Self>(*this);        \
    cloned->mark_fields_registered();                   \
    cloned->io(cloned->field_manager);                  \
    return cloned;                                      \
  }

// Both of the above.
#define TI_DEFINE_ACCEPT_AND_CLONE \
  TI_DEFINE_ACCEPT                 \
  TI_DEFINE_CLONE
285 | |
// Abstract handle to one registered field of a statement. Implementations
// compare themselves with the corresponding field of another statement.
class StmtField {
 public:
  StmtField() = default;

  // Returns whether `rhs` is the same kind of field holding an equal value.
  virtual bool equal(const StmtField *rhs) const = 0;

  virtual ~StmtField() = default;
};
294 | |
295 | template <typename T> |
296 | class StmtFieldNumeric final : public StmtField { |
297 | private: |
298 | std::variant<T *, T> value_; |
299 | |
300 | public: |
301 | explicit StmtFieldNumeric(T *value) : value_(value) { |
302 | } |
303 | |
304 | explicit StmtFieldNumeric(T value) : value_(value) { |
305 | } |
306 | |
307 | bool equal(const StmtField *other_generic) const override { |
308 | if (auto other = dynamic_cast<const StmtFieldNumeric *>(other_generic)) { |
309 | if (std::holds_alternative<T *>(other->value_) && |
310 | std::holds_alternative<T *>(value_)) { |
311 | return *(std::get<T *>(other->value_)) == *(std::get<T *>(value_)); |
312 | } else if (std::holds_alternative<T *>(other->value_) || |
313 | std::holds_alternative<T *>(value_)) { |
314 | TI_ERROR( |
315 | "Inconsistent StmtField value types: a pointer value is compared " |
316 | "to a non-pointer value." ); |
317 | return false; |
318 | } else { |
319 | return std::get<T>(other->value_) == std::get<T>(value_); |
320 | } |
321 | } else { |
322 | // Different types |
323 | return false; |
324 | } |
325 | } |
326 | }; |
327 | |
328 | class StmtFieldSNode final : public StmtField { |
329 | private: |
330 | SNode *const &snode_; |
331 | |
332 | public: |
333 | explicit StmtFieldSNode(SNode *const &snode) : snode_(snode) { |
334 | } |
335 | |
336 | static int get_snode_id(SNode *snode); |
337 | |
338 | bool equal(const StmtField *other_generic) const override; |
339 | }; |
340 | |
341 | class StmtFieldMemoryAccessOptions final : public StmtField { |
342 | private: |
343 | MemoryAccessOptions const &opt_; |
344 | |
345 | public: |
346 | explicit StmtFieldMemoryAccessOptions(MemoryAccessOptions const &opt) |
347 | : opt_(opt) { |
348 | } |
349 | |
350 | bool equal(const StmtField *other_generic) const override; |
351 | }; |
352 | |
353 | class StmtFieldManager { |
354 | private: |
355 | Stmt *stmt_; |
356 | |
357 | public: |
358 | std::vector<std::unique_ptr<StmtField>> fields; |
359 | |
360 | explicit StmtFieldManager(Stmt *stmt) : stmt_(stmt) { |
361 | } |
362 | |
363 | template <typename T> |
364 | void operator()(const char *key, T &&value); |
365 | |
366 | template <typename T, typename... Args> |
367 | void operator()(const char *key_, T &&t, Args &&...rest) { |
368 | std::string key(key_); |
369 | size_t pos = key.find(','); |
370 | std::string first_name = key.substr(0, pos); |
371 | std::string rest_names = |
372 | key.substr(pos + 2, int(key.size()) - (int)pos - 2); |
373 | this->operator()(first_name.c_str(), std::forward<T>(t)); |
374 | this->operator()(rest_names.c_str(), std::forward<Args>(rest)...); |
375 | } |
376 | |
377 | bool equal(StmtFieldManager &other) const; |
378 | }; |
379 | |
// TI_STMT_DEF_FIELDS(a, b, ...) defines a const member template
// `io(S &serializer)` that hands the listed fields to TI_IO.
// NOTE(review): keep the parameter named `serializer`; TI_IO presumably refers
// to it by that name — confirm in the TI_IO definition.
#define TI_STMT_DEF_FIELDS(...) \
  template <typename S>         \
  void io(S &serializer) const { \
    TI_IO(__VA_ARGS__);          \
  }
// TI_STMT_REG_FIELDS calls mark_fields_registered() and then
// io(field_manager), which records each listed field in field_manager. It has
// no trailing semicolon; the use site supplies it.
#define TI_STMT_REG_FIELDS \
  mark_fields_registered(); \
  io(field_manager)
388 | |
// Base class of all IR statements. A statement lives in a Block (`parent`),
// reaches the statements it consumes through `operands`, and exposes its
// comparable fields through `field_manager`.
class Stmt : public IRNode {
 protected:
  // Addresses of the Stmt * members holding this statement's operands.
  // operand(i) dereferences the i-th entry, so it always sees the member's
  // current value.
  std::vector<Stmt **> operands;

 public:
  // Registered fields of this statement, populated by io(field_manager)
  // (e.g. through TI_STMT_REG_FIELDS).
  StmtFieldManager field_manager;
  // Shared by all statements; reset_counter() sets it back to 0.
  // NOTE(review): presumably the source of instance_id — confirm in Stmt().
  static std::atomic<int> instance_id_counter;
  int instance_id;
  // Printed by name() ("$<id>"), raw_name() ("tmp<id>") and short_name().
  int id;
  // Enclosing block; Block::push_back sets it.
  Block *parent;
  bool erased;
  bool fields_registered;
  // Set through set_tb(). NOTE(review): presumably a creation traceback.
  std::string tb;
  // Result type; ret_data_type_name() prints it and element_type() exposes it
  // by reference.
  DataType ret_type;

  Stmt();
  Stmt(const Stmt &stmt);

  // False here; presumably overridden by statements that own nested blocks.
  virtual bool is_container_statement() const {
    return false;
  }

  DataType &element_type() {
    return ret_type;
  }

  std::string ret_data_type_name() const {
    return ret_type->to_string();
  }

  std::string type_hint() const;

  std::string name() const {
    return fmt::format("${}", id);
  }

  std::string short_name() const {
    return make_short_name_by_id(id);
  }

  std::string raw_name() const {
    return fmt::format("tmp{}", id);
  }

  TI_FORCE_INLINE int num_operands() const {
    return (int)operands.size();
  }

  // Unchecked: `i` must be in [0, num_operands()); the bounds assertion
  // below is commented out.
  TI_FORCE_INLINE Stmt *operand(int i) const {
    // TI_ASSERT(0 <= i && i < (int)operands.size());
    return *operands[i];
  }

  std::vector<Stmt *> get_operands() const;

  void set_operand(int i, Stmt *stmt);
  void register_operand(Stmt *&stmt);
  int locate_operand(Stmt **stmt);
  void mark_fields_registered();

  bool has_operand(Stmt *stmt) const;

  void replace_usages_with(Stmt *new_stmt);
  void replace_with(VecStatement &&new_statements, bool replace_usages = true);
  virtual void replace_operand_with(Stmt *old_stmt, Stmt *new_stmt);

  IRNode *get_parent() const override;

  // returns the inserted stmt
  Stmt *insert_before_me(std::unique_ptr<Stmt> &&new_stmt);

  // returns the inserted stmt
  Stmt *insert_after_me(std::unique_ptr<Stmt> &&new_stmt);

  // Conservative default: assume a global side effect. Both elimination
  // predicates below default to the negation of this.
  virtual bool has_global_side_effect() const {
    return true;
  }

  virtual bool dead_instruction_eliminable() const {
    return !has_global_side_effect();
  }

  virtual bool common_statement_eliminable() const {
    return !has_global_side_effect();
  }

  // Factory helpers: make_typed keeps the concrete type, make returns pStmt.
  template <typename T, typename... Args>
  static std::unique_ptr<T> make_typed(Args &&...args) {
    return std::make_unique<T>(std::forward<Args>(args)...);
  }

  template <typename T, typename... Args>
  static pStmt make(Args &&...args) {
    return make_typed<T>(std::forward<Args>(args)...);
  }

  void set_tb(const std::string &tb) {
    this->tb = tb;
  }

  std::string type();

  // Concrete statements override this (e.g. via TI_DEFINE_CLONE); the base
  // version hits TI_NOT_IMPLEMENTED.
  virtual std::unique_ptr<Stmt> clone() const {
    TI_NOT_IMPLEMENTED
  }

  ~Stmt() override = default;

  static void reset_counter() {
    instance_id_counter = 0;
  }
};
501 | |
502 | class Block : public IRNode { |
503 | public: |
504 | Stmt *parent_stmt{nullptr}; |
505 | stmt_vector statements; |
506 | stmt_vector trash_bin; |
507 | std::vector<SNode *> stop_gradients; |
508 | |
509 | // Only used in frontend. Stores LoopIndexStmt or BinaryOpStmt for loop |
510 | // variables, and AllocaStmt for other variables. |
511 | std::map<Identifier, Stmt *> local_var_to_stmt; |
512 | |
513 | Block() { |
514 | parent_stmt = nullptr; |
515 | } |
516 | |
517 | Block *parent_block() const; |
518 | |
519 | bool has_container_statements(); |
520 | int locate(Stmt *stmt); |
521 | stmt_vector::iterator locate(int location); |
522 | stmt_vector::iterator find(Stmt *stmt); |
523 | void erase(int location); |
524 | void erase(Stmt *stmt); |
525 | void erase_range(stmt_vector::iterator begin, stmt_vector::iterator end); |
526 | void erase(std::unordered_set<Stmt *> stmts); |
527 | std::unique_ptr<Stmt> (int location); |
528 | std::unique_ptr<Stmt> (Stmt *stmt); |
529 | |
530 | // Returns stmt.get() |
531 | Stmt *insert(std::unique_ptr<Stmt> &&stmt, int location = -1); |
532 | Stmt *insert_at(std::unique_ptr<Stmt> &&stmt, stmt_vector::iterator location); |
533 | |
534 | // Returns stmt.back().get() or nullptr if stmt is empty |
535 | Stmt *insert(VecStatement &&stmt, int location = -1); |
536 | Stmt *insert_at(VecStatement &&stmt, stmt_vector::iterator location); |
537 | |
538 | void replace_statements_in_range(int start, int end, VecStatement &&stmts); |
539 | void set_statements(VecStatement &&stmts); |
540 | void replace_with(Stmt *old_statement, |
541 | std::unique_ptr<Stmt> &&new_statement, |
542 | bool replace_usages = true); |
543 | void insert_before(Stmt *old_statement, VecStatement &&new_statements); |
544 | void insert_after(Stmt *old_statement, VecStatement &&new_statements); |
545 | void replace_with(Stmt *old_statement, |
546 | VecStatement &&new_statements, |
547 | bool replace_usages = true); |
548 | Stmt *lookup_var(const Identifier &ident) const; |
549 | IRNode *get_parent() const override; |
550 | |
551 | Stmt *back() const { |
552 | return statements.back().get(); |
553 | } |
554 | |
555 | template <typename T, typename... Args> |
556 | Stmt *push_back(Args &&...args) { |
557 | auto stmt = std::make_unique<T>(std::forward<Args>(args)...); |
558 | stmt->parent = this; |
559 | statements.emplace_back(std::move(stmt)); |
560 | return back(); |
561 | } |
562 | |
563 | std::size_t size() const { |
564 | return statements.size(); |
565 | } |
566 | |
567 | pStmt &operator[](int i) { |
568 | return statements[i]; |
569 | } |
570 | |
571 | std::unique_ptr<Block> clone() const; |
572 | |
573 | TI_DEFINE_ACCEPT |
574 | }; |
575 | |
576 | class DelayedIRModifier { |
577 | private: |
578 | std::vector<std::pair<Stmt *, VecStatement>> to_insert_before_; |
579 | std::vector<std::pair<Stmt *, VecStatement>> to_insert_after_; |
580 | std::vector<std::tuple<Stmt *, VecStatement, bool>> to_replace_with_; |
581 | std::vector<Stmt *> to_erase_; |
582 | std::vector<std::pair<Stmt *, Block *>> ; |
583 | std::vector<std::pair<IRNode *, CompileConfig>> to_type_check_; |
584 | bool modified_{false}; |
585 | |
586 | public: |
587 | ~DelayedIRModifier(); |
588 | void erase(Stmt *stmt); |
589 | void insert_before(Stmt *old_statement, std::unique_ptr<Stmt> new_statement); |
590 | void insert_before(Stmt *old_statement, VecStatement &&new_statements); |
591 | void insert_after(Stmt *old_statement, std::unique_ptr<Stmt> new_statement); |
592 | void insert_after(Stmt *old_statement, VecStatement &&new_statements); |
593 | void replace_with(Stmt *stmt, |
594 | VecStatement &&new_statements, |
595 | bool replace_usages = true); |
596 | void (Stmt *stmt, Block *blk); |
597 | void type_check(IRNode *node, CompileConfig cfg); |
598 | bool modify_ir(); |
599 | |
600 | // Force the next call of modify_ir() to return true. |
601 | void mark_as_modified(); |
602 | }; |
603 | |
604 | // ImmediateIRModifier aims at replacing Stmt::replace_usages_with, which visits |
605 | // the whole tree for a single replacement. ImmediateIRModifier is currently |
606 | // associated with a pass, visits the whole tree once at the beginning of that |
607 | // pass, and performs a single replacement with amortized constant time. |
608 | class ImmediateIRModifier { |
609 | private: |
610 | std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> stmt_usages_; |
611 | |
612 | public: |
613 | explicit ImmediateIRModifier(IRNode *root); |
614 | void replace_usages_with(Stmt *old_stmt, Stmt *new_stmt); |
615 | }; |
616 | |
617 | template <typename T> |
618 | inline void StmtFieldManager::operator()(const char *key, T &&value) { |
619 | using decay_T = typename std::decay<T>::type; |
620 | if constexpr (is_specialization<decay_T, std::vector>::value) { |
621 | stmt_->field_manager.fields.emplace_back( |
622 | std::make_unique<StmtFieldNumeric<std::size_t>>(value.size())); |
623 | for (int i = 0; i < (int)value.size(); i++) { |
624 | (*this)("__element" , value[i]); |
625 | } |
626 | } else if constexpr (std::is_same<decay_T, |
627 | std::variant<Stmt *, std::string>>::value) { |
628 | if (std::holds_alternative<std::string>(value)) { |
629 | stmt_->field_manager.fields.emplace_back( |
630 | std::make_unique<StmtFieldNumeric<std::string>>( |
631 | std::get<std::string>(value))); |
632 | } else { |
633 | (*this)("__element" , std::get<Stmt *>(value)); |
634 | } |
635 | } else if constexpr (std::is_same<decay_T, Stmt *>::value) { |
636 | stmt_->register_operand(const_cast<Stmt *&>(value)); |
637 | } else if constexpr (std::is_same<decay_T, SNode *>::value) { |
638 | stmt_->field_manager.fields.emplace_back( |
639 | std::make_unique<StmtFieldSNode>(value)); |
640 | } else if constexpr (std::is_same<decay_T, MemoryAccessOptions>::value) { |
641 | stmt_->field_manager.fields.emplace_back( |
642 | std::make_unique<StmtFieldMemoryAccessOptions>(value)); |
643 | } else { |
644 | stmt_->field_manager.fields.emplace_back( |
645 | std::make_unique<StmtFieldNumeric<std::remove_reference_t<T>>>(&value)); |
646 | } |
647 | } |
648 | |
649 | } // namespace taichi::lang |
650 | |