1// Intermediate representation system
2
3#pragma once
4
#include <atomic>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>
10
11#include "taichi/common/core.h"
12#include "taichi/common/exceptions.h"
13#include "taichi/common/one_or_more.h"
14#include "taichi/ir/snode.h"
15#include "taichi/ir/mesh.h"
16#include "taichi/ir/type_factory.h"
17#include "taichi/util/short_name.h"
18
19#ifdef TI_WITH_LLVM
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/ADT/MapVector.h"
22#endif
23
24namespace taichi::lang {
25
class IRNode;
class Block;
class Stmt;
using pStmt = std::unique_ptr<Stmt>;

class SNode;

class Kernel;
struct CompileConfig;

// Access hints that can be attached to an SNode; stored per SNode in
// MemoryAccessOptions below.
enum class SNodeAccessFlag : int { block_local, read_only, mesh_local };
std::string snode_access_flag_name(SNodeAccessFlag type);

// A mapping from SNode to the set of SNodeAccessFlag values attached to it.
class MemoryAccessOptions {
 public:
  // Attaches `flag` to `snode`; attaching an already-present flag is a no-op.
  void add_flag(SNode *snode, SNodeAccessFlag flag) {
    options_[snode].insert(flag);
  }

  // Returns true iff `flag` has been attached to `snode` (and not cleared).
  bool has_flag(SNode *snode, SNodeAccessFlag flag) const {
    if (auto it = options_.find(snode); it != options_.end())
      return it->second.count(flag) != 0;
    else
      return false;
  }

  // Returns every SNode that carries `flag`, in the (unspecified) iteration
  // order of the underlying hash map.
  std::vector<SNode *> get_snodes_with_flag(SNodeAccessFlag flag) const {
    std::vector<SNode *> snodes;
    for (const auto &[snode, flags] : options_) {
      // Inspect the flag set we are already visiting instead of calling
      // has_flag(), which would repeat the hash lookup for this same key.
      if (flags.count(flag) != 0) {
        snodes.push_back(snode);
      }
    }
    return snodes;
  }

  // Removes all flags from all SNodes.
  void clear() {
    options_.clear();
  }

  // Returns a copy of the whole SNode -> flags mapping.
  std::unordered_map<SNode *, std::unordered_set<SNodeAccessFlag>> get_all()
      const {
    return options_;
  }

 private:
  std::unordered_map<SNode *, std::unordered_set<SNodeAccessFlag>> options_;
};
74
// Forward-declare every statement class listed in statements.inc.h. The same
// list is expanded again inside IRVisitor to generate one visit() overload per
// statement class.
#define PER_STATEMENT(x) class x;
#include "taichi/inc/statements.inc.h"
#undef PER_STATEMENT
78
// Names a (frontend) variable. Ordering and equality look at `id` only, so
// several identifiers may share the same name but must have different ids.
class Identifier {
 public:
  std::string name_;
  int id{0};

  // `name` is taken by value and moved into place (sink parameter), so callers
  // passing a temporary or a string literal do not pay for an extra copy.
  explicit Identifier(int id, std::string name = "")
      : name_(std::move(name)), id(id) {
  }

  // Defined out of line.
  std::string raw_name() const;

  // raw_name() prefixed with '@'.
  std::string name() const {
    return "@" + raw_name();
  }

  bool operator<(const Identifier &o) const {
    return id < o.id;
  }

  bool operator==(const Identifier &o) const {
    return id == o.id;
  }

  // Provided alongside operator== (C++17 does not synthesize it).
  bool operator!=(const Identifier &o) const {
    return !(*this == o);
  }
};
105
106#ifdef TI_WITH_LLVM
107using stmt_vector = llvm::SmallVector<pStmt, 8>;
108using stmt_ref_vector = llvm::SmallVector<Stmt *, 2>;
109#else
110using stmt_vector = std::vector<pStmt>;
111using stmt_ref_vector = std::vector<Stmt *>;
112#endif
113
114class VecStatement {
115 public:
116 stmt_vector stmts;
117
118 VecStatement() {
119 }
120
121 // NOLINTNEXTLINE(google-explicit-constructor)
122 VecStatement(pStmt &&stmt) {
123 push_back(std::move(stmt));
124 }
125
126 VecStatement(VecStatement &&o) {
127 stmts = std::move(o.stmts);
128 }
129
130 // NOLINTNEXTLINE(google-explicit-constructor)
131 VecStatement(stmt_vector &&other_stmts) {
132 stmts = std::move(other_stmts);
133 }
134
135 Stmt *push_back(pStmt &&stmt);
136
137 template <typename T, typename... Args>
138 T *push_back(Args &&...args) {
139 auto up = std::make_unique<T>(std::forward<Args>(args)...);
140 auto ptr = up.get();
141 stmts.push_back(std::move(up));
142 return ptr;
143 }
144
145 pStmt &back() {
146 return stmts.back();
147 }
148
149 std::size_t size() const {
150 return stmts.size();
151 }
152
153 pStmt &operator[](int i) {
154 return stmts[i];
155 }
156};
157
158class IRVisitor {
159 public:
160 bool allow_undefined_visitor;
161 bool invoke_default_visitor;
162
163 IRVisitor() {
164 allow_undefined_visitor = false;
165 invoke_default_visitor = false;
166 }
167
168 virtual ~IRVisitor() = default;
169
170 // default visitor
171 virtual void visit(Stmt *stmt) {
172 if (!allow_undefined_visitor) {
173 TI_ERROR(
174 "missing visitor function. Is the statement class registered via "
175 "DEFINE_VISIT?");
176 }
177 }
178
179#define DEFINE_VISIT(T) \
180 virtual void visit(T *stmt) { \
181 if (allow_undefined_visitor) { \
182 if (invoke_default_visitor) \
183 visit((Stmt *)stmt); \
184 } else \
185 TI_NOT_IMPLEMENTED; \
186 }
187
188 DEFINE_VISIT(Block);
189#define PER_STATEMENT(x) DEFINE_VISIT(x)
190#include "taichi/inc/statements.inc.h"
191
192#undef PER_STATEMENT
193#undef DEFINE_VISIT
194};
195
196struct CompileConfig;
197class Kernel;
198
199using stmt_refs = one_or_more<Stmt *>;
200
namespace ir_traits {

// Mix-in interfaces that statement classes may inherit to expose generic
// load/store information. Users detect them with dynamic_cast.
// FIXME: Use C++ 20 concepts to replace `dynamic_cast<T>() != nullptr`

// Interface for statements that perform a store operation.
class Store {
 public:
  virtual ~Store() = default;

  // Get the list of sinks/destinations of the store operation
  virtual stmt_refs get_store_destination() const = 0;

  // If store_stmt provides one data source, return the data.
  // NOTE(review): the behavior when there is no single data source is defined
  // by the implementers (not visible here) -- confirm before relying on it.
  virtual Stmt *get_store_data() const = 0;
};

// Interface for statements that perform a load operation.
class Load {
 public:
  virtual ~Load() = default;

  // If load_stmt loads some variables or a stack, return the pointers of them.
  virtual stmt_refs get_load_pointers() const = 0;
};

}  // namespace ir_traits
225
226class IRNode {
227 public:
228 virtual void accept(IRVisitor *visitor) {
229 TI_NOT_IMPLEMENTED
230 }
231
232 // * For a Stmt, this returns its enclosing Block
233 // * For a Block, this returns its enclosing Stmt
234 virtual IRNode *get_parent() const = 0;
235
236 IRNode *get_ir_root();
237
238 virtual ~IRNode() = default;
239
240 template <typename T>
241 bool is() const {
242 return dynamic_cast<const T *>(this) != nullptr;
243 }
244
245 template <typename T>
246 T *as() {
247 TI_ASSERT(is<T>());
248 return dynamic_cast<T *>(this);
249 }
250
251 template <typename T>
252 const T *as() const {
253 TI_ASSERT(is<T>());
254 return dynamic_cast<const T *>(this);
255 }
256
257 template <typename T>
258 T *cast() {
259 return dynamic_cast<T *>(this);
260 }
261
262 template <typename T>
263 const T *cast() const {
264 return dynamic_cast<const T *>(this);
265 }
266
267 std::unique_ptr<IRNode> clone();
268};
269
// Boilerplate for concrete IR node classes.
//   TI_DEFINE_ACCEPT: dispatches to IRVisitor::visit(<this class> *).
//   TI_DEFINE_CLONE:  copy-constructs the concrete class, marks the copy's
//                     fields as registered and re-runs io() on the copy's own
//                     field_manager.
#define TI_DEFINE_ACCEPT                      \
  void accept(IRVisitor *visitor) override { \
    visitor->visit(this);                    \
  }

#define TI_DEFINE_CLONE                                  \
  std::unique_ptr<Stmt> clone() const override {         \
    using ConcreteStmt = std::decay_t<decltype(*this)>;  \
    auto copy = std::make_unique<ConcreteStmt>(*this);   \
    copy->mark_fields_registered();                      \
    copy->io(copy->field_manager);                       \
    return copy;                                         \
  }

#define TI_DEFINE_ACCEPT_AND_CLONE \
  TI_DEFINE_ACCEPT                 \
  TI_DEFINE_CLONE
285
// One registered, comparable field of a statement. Subclasses decide how two
// fields compare.
class StmtField {
 public:
  // Returns true when `other` represents an equal field.
  virtual bool equal(const StmtField *other) const = 0;

  virtual ~StmtField() = default;
};
294
295template <typename T>
296class StmtFieldNumeric final : public StmtField {
297 private:
298 std::variant<T *, T> value_;
299
300 public:
301 explicit StmtFieldNumeric(T *value) : value_(value) {
302 }
303
304 explicit StmtFieldNumeric(T value) : value_(value) {
305 }
306
307 bool equal(const StmtField *other_generic) const override {
308 if (auto other = dynamic_cast<const StmtFieldNumeric *>(other_generic)) {
309 if (std::holds_alternative<T *>(other->value_) &&
310 std::holds_alternative<T *>(value_)) {
311 return *(std::get<T *>(other->value_)) == *(std::get<T *>(value_));
312 } else if (std::holds_alternative<T *>(other->value_) ||
313 std::holds_alternative<T *>(value_)) {
314 TI_ERROR(
315 "Inconsistent StmtField value types: a pointer value is compared "
316 "to a non-pointer value.");
317 return false;
318 } else {
319 return std::get<T>(other->value_) == std::get<T>(value_);
320 }
321 } else {
322 // Different types
323 return false;
324 }
325 }
326};
327
328class StmtFieldSNode final : public StmtField {
329 private:
330 SNode *const &snode_;
331
332 public:
333 explicit StmtFieldSNode(SNode *const &snode) : snode_(snode) {
334 }
335
336 static int get_snode_id(SNode *snode);
337
338 bool equal(const StmtField *other_generic) const override;
339};
340
341class StmtFieldMemoryAccessOptions final : public StmtField {
342 private:
343 MemoryAccessOptions const &opt_;
344
345 public:
346 explicit StmtFieldMemoryAccessOptions(MemoryAccessOptions const &opt)
347 : opt_(opt) {
348 }
349
350 bool equal(const StmtField *other_generic) const override;
351};
352
353class StmtFieldManager {
354 private:
355 Stmt *stmt_;
356
357 public:
358 std::vector<std::unique_ptr<StmtField>> fields;
359
360 explicit StmtFieldManager(Stmt *stmt) : stmt_(stmt) {
361 }
362
363 template <typename T>
364 void operator()(const char *key, T &&value);
365
366 template <typename T, typename... Args>
367 void operator()(const char *key_, T &&t, Args &&...rest) {
368 std::string key(key_);
369 size_t pos = key.find(',');
370 std::string first_name = key.substr(0, pos);
371 std::string rest_names =
372 key.substr(pos + 2, int(key.size()) - (int)pos - 2);
373 this->operator()(first_name.c_str(), std::forward<T>(t));
374 this->operator()(rest_names.c_str(), std::forward<Args>(rest)...);
375 }
376
377 bool equal(StmtFieldManager &other) const;
378};
379
// TI_STMT_DEF_FIELDS(f1, f2, ...) defines the statement's io() member, which
// feeds the listed fields to a serializer through TI_IO. StmtFieldManager is
// one such serializer (see TI_STMT_REG_FIELDS).
#define TI_STMT_DEF_FIELDS(...) \
  template <typename S>         \
  void io(S &serializer) const { \
    TI_IO(__VA_ARGS__);         \
  }
// TI_STMT_REG_FIELDS marks the fields as registered and registers them with
// the statement's field_manager. It is wrapped in do/while(false) so that it
// expands to a single statement and stays correct under an unbraced if/else.
#define TI_STMT_REG_FIELDS    \
  do {                        \
    mark_fields_registered(); \
    io(field_manager);        \
  } while (false)
388
// Base class of IR statements. Concrete statement classes declare their fields
// with TI_STMT_DEF_FIELDS, register them with TI_STMT_REG_FIELDS, and usually
// get accept()/clone() from TI_DEFINE_ACCEPT_AND_CLONE.
// NOTE(review): a copy constructor is declared below but no copy assignment
// operator, so the implicit one would copy field_manager (including its
// Stmt * back-pointer) verbatim. Confirm that no code assigns Stmt objects.
class Stmt : public IRNode {
 protected:
  // Addresses of this statement's Stmt * fields that refer to other
  // statements. operand(i) dereferences entry i, so writing through an entry
  // rewrites the underlying field. StmtFieldManager fills this via
  // register_operand().
  std::vector<Stmt **> operands;

 public:
  // Collects this statement's registered fields (populated by io()).
  StmtFieldManager field_manager;
  // Shared across all statements; reset_counter() sets it back to 0.
  // NOTE(review): presumably the source of instance_id -- the constructors
  // are defined out of line, confirm there.
  static std::atomic<int> instance_id_counter;
  int instance_id;
  // Printed by name()/raw_name()/short_name().
  int id;
  // Enclosing block (Block::push_back sets it).
  Block *parent;
  bool erased;
  bool fields_registered;
  // Set via set_tb(). NOTE(review): presumably a traceback / source location.
  std::string tb;
  DataType ret_type;

  Stmt();
  Stmt(const Stmt &stmt);

  // Overridden to return true by statements that contain blocks.
  virtual bool is_container_statement() const {
    return false;
  }

  // Mutable access to ret_type.
  DataType &element_type() {
    return ret_type;
  }

  std::string ret_data_type_name() const {
    return ret_type->to_string();
  }

  std::string type_hint() const;

  // "$<id>"
  std::string name() const {
    return fmt::format("${}", id);
  }

  std::string short_name() const {
    return make_short_name_by_id(id);
  }

  // "tmp<id>"
  std::string raw_name() const {
    return fmt::format("tmp{}", id);
  }

  TI_FORCE_INLINE int num_operands() const {
    return (int)operands.size();
  }

  // No bounds check is performed; `i` must be in [0, num_operands()).
  TI_FORCE_INLINE Stmt *operand(int i) const {
    // TI_ASSERT(0 <= i && i < (int)operands.size());
    return *operands[i];
  }

  std::vector<Stmt *> get_operands() const;

  void set_operand(int i, Stmt *stmt);
  // Called by StmtFieldManager for each Stmt * field; takes the field by
  // reference so its address can be recorded.
  void register_operand(Stmt *&stmt);
  int locate_operand(Stmt **stmt);
  void mark_fields_registered();

  bool has_operand(Stmt *stmt) const;

  void replace_usages_with(Stmt *new_stmt);
  void replace_with(VecStatement &&new_statements, bool replace_usages = true);
  virtual void replace_operand_with(Stmt *old_stmt, Stmt *new_stmt);

  IRNode *get_parent() const override;

  // returns the inserted stmt
  Stmt *insert_before_me(std::unique_ptr<Stmt> &&new_stmt);

  // returns the inserted stmt
  Stmt *insert_after_me(std::unique_ptr<Stmt> &&new_stmt);

  // Defaults to true; statement classes override it. The two eliminability
  // queries below default to its negation.
  virtual bool has_global_side_effect() const {
    return true;
  }

  virtual bool dead_instruction_eliminable() const {
    return !has_global_side_effect();
  }

  virtual bool common_statement_eliminable() const {
    return !has_global_side_effect();
  }

  // Construct a statement of type T, returned with its concrete type.
  template <typename T, typename... Args>
  static std::unique_ptr<T> make_typed(Args &&...args) {
    return std::make_unique<T>(std::forward<Args>(args)...);
  }

  // Same as make_typed, but returned as a pStmt.
  template <typename T, typename... Args>
  static pStmt make(Args &&...args) {
    return make_typed<T>(std::forward<Args>(args)...);
  }

  void set_tb(const std::string &tb) {
    this->tb = tb;
  }

  std::string type();

  // Overridden by TI_DEFINE_CLONE; the base version hits TI_NOT_IMPLEMENTED.
  virtual std::unique_ptr<Stmt> clone() const {
    TI_NOT_IMPLEMENTED
  }

  ~Stmt() override = default;

  static void reset_counter() {
    instance_id_counter = 0;
  }
};
501
502class Block : public IRNode {
503 public:
504 Stmt *parent_stmt{nullptr};
505 stmt_vector statements;
506 stmt_vector trash_bin;
507 std::vector<SNode *> stop_gradients;
508
509 // Only used in frontend. Stores LoopIndexStmt or BinaryOpStmt for loop
510 // variables, and AllocaStmt for other variables.
511 std::map<Identifier, Stmt *> local_var_to_stmt;
512
513 Block() {
514 parent_stmt = nullptr;
515 }
516
517 Block *parent_block() const;
518
519 bool has_container_statements();
520 int locate(Stmt *stmt);
521 stmt_vector::iterator locate(int location);
522 stmt_vector::iterator find(Stmt *stmt);
523 void erase(int location);
524 void erase(Stmt *stmt);
525 void erase_range(stmt_vector::iterator begin, stmt_vector::iterator end);
526 void erase(std::unordered_set<Stmt *> stmts);
527 std::unique_ptr<Stmt> extract(int location);
528 std::unique_ptr<Stmt> extract(Stmt *stmt);
529
530 // Returns stmt.get()
531 Stmt *insert(std::unique_ptr<Stmt> &&stmt, int location = -1);
532 Stmt *insert_at(std::unique_ptr<Stmt> &&stmt, stmt_vector::iterator location);
533
534 // Returns stmt.back().get() or nullptr if stmt is empty
535 Stmt *insert(VecStatement &&stmt, int location = -1);
536 Stmt *insert_at(VecStatement &&stmt, stmt_vector::iterator location);
537
538 void replace_statements_in_range(int start, int end, VecStatement &&stmts);
539 void set_statements(VecStatement &&stmts);
540 void replace_with(Stmt *old_statement,
541 std::unique_ptr<Stmt> &&new_statement,
542 bool replace_usages = true);
543 void insert_before(Stmt *old_statement, VecStatement &&new_statements);
544 void insert_after(Stmt *old_statement, VecStatement &&new_statements);
545 void replace_with(Stmt *old_statement,
546 VecStatement &&new_statements,
547 bool replace_usages = true);
548 Stmt *lookup_var(const Identifier &ident) const;
549 IRNode *get_parent() const override;
550
551 Stmt *back() const {
552 return statements.back().get();
553 }
554
555 template <typename T, typename... Args>
556 Stmt *push_back(Args &&...args) {
557 auto stmt = std::make_unique<T>(std::forward<Args>(args)...);
558 stmt->parent = this;
559 statements.emplace_back(std::move(stmt));
560 return back();
561 }
562
563 std::size_t size() const {
564 return statements.size();
565 }
566
567 pStmt &operator[](int i) {
568 return statements[i];
569 }
570
571 std::unique_ptr<Block> clone() const;
572
573 TI_DEFINE_ACCEPT
574};
575
576class DelayedIRModifier {
577 private:
578 std::vector<std::pair<Stmt *, VecStatement>> to_insert_before_;
579 std::vector<std::pair<Stmt *, VecStatement>> to_insert_after_;
580 std::vector<std::tuple<Stmt *, VecStatement, bool>> to_replace_with_;
581 std::vector<Stmt *> to_erase_;
582 std::vector<std::pair<Stmt *, Block *>> to_extract_to_block_front_;
583 std::vector<std::pair<IRNode *, CompileConfig>> to_type_check_;
584 bool modified_{false};
585
586 public:
587 ~DelayedIRModifier();
588 void erase(Stmt *stmt);
589 void insert_before(Stmt *old_statement, std::unique_ptr<Stmt> new_statement);
590 void insert_before(Stmt *old_statement, VecStatement &&new_statements);
591 void insert_after(Stmt *old_statement, std::unique_ptr<Stmt> new_statement);
592 void insert_after(Stmt *old_statement, VecStatement &&new_statements);
593 void replace_with(Stmt *stmt,
594 VecStatement &&new_statements,
595 bool replace_usages = true);
596 void extract_to_block_front(Stmt *stmt, Block *blk);
597 void type_check(IRNode *node, CompileConfig cfg);
598 bool modify_ir();
599
600 // Force the next call of modify_ir() to return true.
601 void mark_as_modified();
602};
603
604// ImmediateIRModifier aims at replacing Stmt::replace_usages_with, which visits
605// the whole tree for a single replacement. ImmediateIRModifier is currently
606// associated with a pass, visits the whole tree once at the beginning of that
607// pass, and performs a single replacement with amortized constant time.
608class ImmediateIRModifier {
609 private:
610 std::unordered_map<Stmt *, std::vector<std::pair<Stmt *, int>>> stmt_usages_;
611
612 public:
613 explicit ImmediateIRModifier(IRNode *root);
614 void replace_usages_with(Stmt *old_stmt, Stmt *new_stmt);
615};
616
617template <typename T>
618inline void StmtFieldManager::operator()(const char *key, T &&value) {
619 using decay_T = typename std::decay<T>::type;
620 if constexpr (is_specialization<decay_T, std::vector>::value) {
621 stmt_->field_manager.fields.emplace_back(
622 std::make_unique<StmtFieldNumeric<std::size_t>>(value.size()));
623 for (int i = 0; i < (int)value.size(); i++) {
624 (*this)("__element", value[i]);
625 }
626 } else if constexpr (std::is_same<decay_T,
627 std::variant<Stmt *, std::string>>::value) {
628 if (std::holds_alternative<std::string>(value)) {
629 stmt_->field_manager.fields.emplace_back(
630 std::make_unique<StmtFieldNumeric<std::string>>(
631 std::get<std::string>(value)));
632 } else {
633 (*this)("__element", std::get<Stmt *>(value));
634 }
635 } else if constexpr (std::is_same<decay_T, Stmt *>::value) {
636 stmt_->register_operand(const_cast<Stmt *&>(value));
637 } else if constexpr (std::is_same<decay_T, SNode *>::value) {
638 stmt_->field_manager.fields.emplace_back(
639 std::make_unique<StmtFieldSNode>(value));
640 } else if constexpr (std::is_same<decay_T, MemoryAccessOptions>::value) {
641 stmt_->field_manager.fields.emplace_back(
642 std::make_unique<StmtFieldMemoryAccessOptions>(value));
643 } else {
644 stmt_->field_manager.fields.emplace_back(
645 std::make_unique<StmtFieldNumeric<std::remove_reference_t<T>>>(&value));
646 }
647}
648
649} // namespace taichi::lang
650