/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | #include <tvm/tir/analysis.h> |
20 | #include <tvm/tir/stmt_functor.h> |
21 | |
22 | namespace tvm { |
23 | namespace tir { |
24 | |
25 | const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) { |
26 | GlobalVar result = NullValue<GlobalVar>(); |
27 | // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` |
28 | int num_prim_func = 0; |
29 | const tir::PrimFuncNode* main_func = nullptr; |
30 | const tir::PrimFuncNode* last_func = nullptr; |
31 | for (const auto& kv : mod->functions) { |
32 | GlobalVar gv = kv.first; |
33 | BaseFunc base_func = kv.second; |
34 | if (const auto* func = base_func.as<tir::PrimFuncNode>()) { |
35 | last_func = func; |
36 | if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { |
37 | if (result_g_var != nullptr) { |
38 | *result_g_var = gv; |
39 | } |
40 | return func; |
41 | } |
42 | if (gv->name_hint == "main") { |
43 | main_func = func; |
44 | result = gv; |
45 | } |
46 | ++num_prim_func; |
47 | } |
48 | } |
49 | // Priority 2: PrimFunc whose name is `main` |
50 | if (main_func != nullptr) { |
51 | if (result_g_var != nullptr) { |
52 | *result_g_var = result; |
53 | } |
54 | return main_func; |
55 | } |
56 | // Priority 3: The only PrimFunc in the IRModule |
57 | if (num_prim_func == 1) { |
58 | if (result_g_var != nullptr) { |
59 | *result_g_var = result; |
60 | } |
61 | return last_func; |
62 | } |
63 | return nullptr; |
64 | } |
65 | |
66 | Stmt GetEnclosingLoop(const BlockNode* block, Stmt func_body) { |
67 | struct GetRootSeqStmt : public StmtVisitor { |
68 | void VisitStmt_(const SeqStmtNode* seq) override { result = seq; } |
69 | const SeqStmtNode* result; |
70 | }; |
71 | |
72 | struct BlockFinder : public StmtVisitor { |
73 | explicit BlockFinder(const BlockNode* tgt) : target(tgt) {} |
74 | |
75 | void VisitStmt_(const BlockNode* block) override { |
76 | if (block == target) { |
77 | found = true; |
78 | } |
79 | } |
80 | |
81 | const BlockNode* target; |
82 | bool found = false; |
83 | }; |
84 | |
85 | GetRootSeqStmt seq_finder; |
86 | seq_finder(func_body); |
87 | |
88 | ICHECK(seq_finder.result); |
89 | |
90 | for (auto stmt : seq_finder.result->seq) { |
91 | if (stmt->IsInstance<ForNode>()) { |
92 | BlockFinder finder(block); |
93 | finder(stmt); |
94 | if (finder.found) { |
95 | return stmt; |
96 | } |
97 | } |
98 | } |
99 | |
100 | LOG(FATAL) << "Enclosing loop not found for a block "<< GetRef<Block>(block); |
101 | } |
102 | |
103 | const BlockNode* FindAnchorBlock(const IRModule& mod) { |
104 | struct ReductionBlockCollector : public StmtVisitor { |
105 | void VisitStmt_(const BlockNode* block) override { |
106 | if (block->init) { |
107 | blocks.push_back(block); |
108 | } |
109 | StmtVisitor::VisitStmt(block->body); |
110 | } |
111 | std::vector<const BlockNode*> blocks; |
112 | }; |
113 | |
114 | if (auto prim_func = FindEntryFunc(mod, nullptr)) { |
115 | ReductionBlockCollector collector; |
116 | collector(prim_func->body); |
117 | |
118 | const auto& candidates = collector.blocks; |
119 | |
120 | if (candidates.empty()) { |
121 | return nullptr; |
122 | } else if (candidates.size() == 1) { |
123 | return candidates[0]; |
124 | } |
125 | |
126 | double best_flops = -1; |
127 | int best_idx = 0; |
128 | for (size_t i = 0; i < candidates.size(); ++i) { |
129 | auto loop = GetEnclosingLoop(candidates[i], prim_func->body); |
130 | auto flops = EstimateTIRFlops(loop); |
131 | if (flops > best_flops) { |
132 | best_flops = flops; |
133 | best_idx = i; |
134 | } |
135 | } |
136 | return candidates[best_idx]; |
137 | } |
138 | return nullptr; |
139 | } |
140 | |
141 | TVM_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) { |
142 | auto ret = FindAnchorBlock(mod); |
143 | if (ret) { |
144 | return Optional<Block>(GetRef<Block>(ret)); |
145 | } |
146 | return Optional<Block>(NullOpt); |
147 | }); |
148 | |
149 | } // namespace tir |
150 | } // namespace tvm |
151 |