1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/tir/analysis.h>
20#include <tvm/tir/stmt_functor.h>
21
22namespace tvm {
23namespace tir {
24
25const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) {
26 GlobalVar result = NullValue<GlobalVar>();
27 // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
28 int num_prim_func = 0;
29 const tir::PrimFuncNode* main_func = nullptr;
30 const tir::PrimFuncNode* last_func = nullptr;
31 for (const auto& kv : mod->functions) {
32 GlobalVar gv = kv.first;
33 BaseFunc base_func = kv.second;
34 if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
35 last_func = func;
36 if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
37 if (result_g_var != nullptr) {
38 *result_g_var = gv;
39 }
40 return func;
41 }
42 if (gv->name_hint == "main") {
43 main_func = func;
44 result = gv;
45 }
46 ++num_prim_func;
47 }
48 }
49 // Priority 2: PrimFunc whose name is `main`
50 if (main_func != nullptr) {
51 if (result_g_var != nullptr) {
52 *result_g_var = result;
53 }
54 return main_func;
55 }
56 // Priority 3: The only PrimFunc in the IRModule
57 if (num_prim_func == 1) {
58 if (result_g_var != nullptr) {
59 *result_g_var = result;
60 }
61 return last_func;
62 }
63 return nullptr;
64}
65
66Stmt GetEnclosingLoop(const BlockNode* block, Stmt func_body) {
67 struct GetRootSeqStmt : public StmtVisitor {
68 void VisitStmt_(const SeqStmtNode* seq) override { result = seq; }
69 const SeqStmtNode* result;
70 };
71
72 struct BlockFinder : public StmtVisitor {
73 explicit BlockFinder(const BlockNode* tgt) : target(tgt) {}
74
75 void VisitStmt_(const BlockNode* block) override {
76 if (block == target) {
77 found = true;
78 }
79 }
80
81 const BlockNode* target;
82 bool found = false;
83 };
84
85 GetRootSeqStmt seq_finder;
86 seq_finder(func_body);
87
88 ICHECK(seq_finder.result);
89
90 for (auto stmt : seq_finder.result->seq) {
91 if (stmt->IsInstance<ForNode>()) {
92 BlockFinder finder(block);
93 finder(stmt);
94 if (finder.found) {
95 return stmt;
96 }
97 }
98 }
99
100 LOG(FATAL) << "Enclosing loop not found for a block " << GetRef<Block>(block);
101}
102
103const BlockNode* FindAnchorBlock(const IRModule& mod) {
104 struct ReductionBlockCollector : public StmtVisitor {
105 void VisitStmt_(const BlockNode* block) override {
106 if (block->init) {
107 blocks.push_back(block);
108 }
109 StmtVisitor::VisitStmt(block->body);
110 }
111 std::vector<const BlockNode*> blocks;
112 };
113
114 if (auto prim_func = FindEntryFunc(mod, nullptr)) {
115 ReductionBlockCollector collector;
116 collector(prim_func->body);
117
118 const auto& candidates = collector.blocks;
119
120 if (candidates.empty()) {
121 return nullptr;
122 } else if (candidates.size() == 1) {
123 return candidates[0];
124 }
125
126 double best_flops = -1;
127 int best_idx = 0;
128 for (size_t i = 0; i < candidates.size(); ++i) {
129 auto loop = GetEnclosingLoop(candidates[i], prim_func->body);
130 auto flops = EstimateTIRFlops(loop);
131 if (flops > best_flops) {
132 best_flops = flops;
133 best_idx = i;
134 }
135 }
136 return candidates[best_idx];
137 }
138 return nullptr;
139}
140
141TVM_REGISTER_GLOBAL("tir.analysis.find_anchor_block").set_body_typed([](const IRModule& mod) {
142 auto ret = FindAnchorBlock(mod);
143 if (ret) {
144 return Optional<Block>(GetRef<Block>(ret));
145 }
146 return Optional<Block>(NullOpt);
147});
148
149} // namespace tir
150} // namespace tvm
151