1#include "taichi/ir/ir.h"
2#include "taichi/ir/transforms.h"
3#include "taichi/ir/visitors.h"
4#include "taichi/ir/frontend_ir.h"
5#include "taichi/system/profiler.h"
6
7#include <set>
8
9namespace taichi::lang {
10
11namespace irpass {
12
13// TODO: gather Expr as well?
14class GatherStmts : public BasicStmtVisitor {
15 public:
16 using BasicStmtVisitor::visit;
17
18 std::vector<Stmt *> stmts;
19
20 GatherStmts() {
21 invoke_default_visitor = true;
22 }
23
24 void visit(Stmt *stmt) override {
25 stmts.push_back(stmt);
26 }
27};
28
29void reverse_segments(IRNode *root) {
30 TI_AUTO_PROF;
31 auto block = dynamic_cast<Block *>(root);
32 std::vector<std::vector<pStmt>> statement_blocks(1);
33 bool has_for = false;
34 bool has_non_for = false;
35 for (auto &&s : block->statements) {
36 if (s->is<FrontendForStmt>()) {
37 has_for = true;
38 statement_blocks.emplace_back();
39 statement_blocks.back().push_back(std::move(s));
40 statement_blocks.emplace_back();
41 } else {
42 has_non_for = true;
43 statement_blocks.back().push_back(std::move(s));
44 }
45 }
46 block->statements.clear();
47 std::reverse(statement_blocks.begin(), statement_blocks.end());
48 /*
49 for (auto &b : statement_blocks) {
50 std::vector<Stmt *> stmts;
51 for (auto &s : b) {
52 GatherStmts gather;
53 s->accept(&gather);
54 stmts.insert(stmts.end(), gather.stmts.begin(), gather.stmts.end());
55 }
56 std::set<Stmt *> stmt_set(stmts.begin(), stmts.end());
57 bool valid = true;
58 for (auto s : stmts) {
59 for (auto op : s->get_operands()) {
60 if (stmt_set.find(op) == stmt_set.end()) {
61 valid = false;
62 }
63 }
64 }
65 }
66 */
67 if (has_for && has_non_for) {
68 TI_ERROR(
69 "Invalid program input for autodiff: "
70 "Mixed usage of for-loops and statements without looping. \n"
71 "Please split them into two kernels "
72 "and check the documentation for more details:\n"
73 "https://docs.taichi-lang.org/docs/"
74 "differentiable_programming");
75 }
76 for (auto &sblock : statement_blocks) {
77 for (auto &&s : sblock) {
78 block->statements.push_back(std::move(s));
79 }
80 }
81}
82
83} // namespace irpass
84
85} // namespace taichi::lang
86