1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | class SRefTreeVerifier : public StmtVisitor { |
25 | public: |
26 | static void Verify(const ScheduleStateNode* self) { SRefTreeVerifier(self).Verify(); } |
27 | |
28 | private: |
29 | /*! \brief Constructor */ |
30 | explicit SRefTreeVerifier(const ScheduleStateNode* self) : self_(self) {} |
31 | |
32 | void Verify() { |
33 | VisitPrimFuncs(self_->mod, [this](const PrimFuncNode* func) { this->VisitStmt(func->body); }); |
34 | ICHECK_EQ(n_sref_visited_, static_cast<int>(self_->stmt2ref.size())); |
35 | for (const auto& kv : self_->block_info) { |
36 | const StmtSRef& sref = kv.first; |
37 | ICHECK(sref->stmt != nullptr) |
38 | << "InternalError: An expired sref is found in the block_scope mapping" ; |
39 | auto it = self_->stmt2ref.find(sref->stmt); |
40 | ICHECK(it != self_->stmt2ref.end()) |
41 | << "InternalError: The sref points to a statement that does not exist in stmt2ref" ; |
42 | const StmtSRef& sref2 = it->second; |
43 | ICHECK(sref.same_as(sref2)) |
44 | << "InternalError: The sref points to a statement whose corresponding sref in stmt2ref " |
45 | "is not the same object as itself" ; |
46 | } |
47 | ICHECK_EQ(n_block_sref_visited_, static_cast<int>(self_->block_info.size())); |
48 | } |
49 | |
50 | void VisitStmt_(const BlockNode* block) final { |
51 | if (init_block_depth_) { |
52 | ICHECK(!self_->stmt2ref.count(block)) << "InternalError: A block inside init block has its " |
53 | "corresponding sref, which is not allowed" ; |
54 | StmtVisitor::VisitStmt_(block); |
55 | return; |
56 | } |
57 | ICHECK(self_->stmt2ref.count(block)) |
58 | << "InternalError: A BlockNode should appear in sref map, but it didn't\n" |
59 | << GetRef<Stmt>(block); |
60 | ++n_sref_visited_; |
61 | ++n_block_sref_visited_; |
62 | const StmtSRef& sref = self_->stmt2ref.at(block); |
63 | ICHECK(self_->block_info.count(sref)) |
64 | << "InternalError: Cannot find scope information of the BlockNode:\n" |
65 | << GetRef<Stmt>(block); |
66 | ICHECK(sref->parent == ancestors_.back()) |
67 | << "InternalError: Parent information mismatch for BlockNode:\n" |
68 | << GetRef<Stmt>(block) << "\nIts parent is supposed to be:\n" |
69 | << GetRef<Stmt>(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" |
70 | << (sref->parent ? Optional<Stmt>(GetRef<Stmt>(sref->parent->stmt)) |
71 | : Optional<Stmt>(NullOpt)); |
72 | ancestors_.push_back(sref.operator->()); |
73 | if (block->init.defined()) { |
74 | ++init_block_depth_; |
75 | VisitStmt(block->init.value()); |
76 | --init_block_depth_; |
77 | } |
78 | VisitStmt(block->body); |
79 | ancestors_.pop_back(); |
80 | } |
81 | |
82 | void VisitStmt_(const ForNode* loop) final { |
83 | if (init_block_depth_) { |
84 | ICHECK(!self_->stmt2ref.count(loop)) << "InternalError: A loop inside init block has its " |
85 | "corresponding sref, which is not allowed" ; |
86 | StmtVisitor::VisitStmt_(loop); |
87 | return; |
88 | } |
89 | ICHECK(self_->stmt2ref.count(loop)) |
90 | << "InternalError: A ForNode should appear in sref map, but it didn't\n" |
91 | << GetRef<Stmt>(loop); |
92 | ++n_sref_visited_; |
93 | const StmtSRef& sref = self_->stmt2ref.at(loop); |
94 | Optional<Stmt> stmt = NullOpt; |
95 | ICHECK(sref->parent == ancestors_.back()) |
96 | << "InternalError: Parent information mismatch for ForNode:\n" |
97 | << GetRef<Stmt>(loop) << "\nIts parent is supposed to be:\n" |
98 | << GetRef<Stmt>(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n" |
99 | << (sref->parent ? Optional<Stmt>(GetRef<Stmt>(sref->parent->stmt)) |
100 | : Optional<Stmt>(NullOpt)); |
101 | ancestors_.push_back(sref.operator->()); |
102 | StmtVisitor::VisitStmt_(loop); |
103 | ancestors_.pop_back(); |
104 | } |
105 | |
106 | void VisitStmt_(const SeqStmtNode* seq_stmt) final { |
107 | // Verify seq_index |
108 | if (init_block_depth_) { |
109 | StmtVisitor::VisitStmt_(seq_stmt); |
110 | return; |
111 | } |
112 | int n = static_cast<int>(seq_stmt->seq.size()); |
113 | for (int i = 0; i < n; ++i) { |
114 | const Stmt& child = seq_stmt->seq[i]; |
115 | StmtSRef sref{nullptr}; |
116 | if (const auto* realize = child.as<BlockRealizeNode>()) { |
117 | const auto* block = realize->block.get(); |
118 | ICHECK(self_->stmt2ref.count(block)); |
119 | sref = self_->stmt2ref.at(block); |
120 | } else if (child->IsInstance<ForNode>()) { |
121 | ICHECK(self_->stmt2ref.count(child.get())); |
122 | sref = self_->stmt2ref.at(child.get()); |
123 | } else { |
124 | continue; |
125 | } |
126 | ICHECK_EQ(sref->seq_index, i) << "InternalError: A StmtSRef has incorrect seq_index" ; |
127 | } |
128 | StmtVisitor::VisitStmt_(seq_stmt); |
129 | } |
130 | |
131 | /*! \brief The schedule it belongs to */ |
132 | const ScheduleStateNode* self_; |
133 | /*! \brief Parent information during the visit */ |
134 | std::vector<const StmtSRefNode*> ancestors_ = {nullptr}; |
135 | /*! \brief If the visitor is currently in the init block */ |
136 | int init_block_depth_ = 0; |
137 | /*! \brief Number of srefs that are visited */ |
138 | int n_sref_visited_ = 0; |
139 | /*! \brief Number of block srefs that are visited */ |
140 | int n_block_sref_visited_ = 0; |
141 | }; |
142 | |
143 | void VerifySRefTree(const ScheduleState& self) { SRefTreeVerifier::Verify(self.get()); } |
144 | |
145 | void VerifyCachedFlags(const ScheduleState& self) { |
146 | std::vector<StmtSRef> block_info_not_found; |
147 | std::vector<std::tuple<StmtSRef, bool, bool>> block_info_wrong_affine_binding; |
148 | std::vector<std::tuple<StmtSRef, bool, bool>> block_info_wrong_region_cover; |
149 | std::vector<std::tuple<StmtSRef, bool, bool>> block_info_wrong_stage_pipeline; |
150 | |
151 | ScheduleState new_state(self->mod); |
152 | for (const auto& kv : new_state->stmt2ref) { |
153 | const StmtNode* stmt = kv.first; |
154 | const StmtSRef& new_sref = kv.second; |
155 | if (stmt->IsInstance<ForNode>() || !self->stmt2ref.count(stmt)) { |
156 | continue; |
157 | } |
158 | const BlockInfo& new_block_info = new_state->block_info.at(new_sref); |
159 | const StmtSRef& old_sref = self->stmt2ref.at(stmt); |
160 | if (!self->block_info.count(old_sref)) { |
161 | block_info_not_found.push_back(new_sref); |
162 | continue; |
163 | } |
164 | const BlockInfo& old_block_info = self->block_info.at(old_sref); |
165 | if (new_block_info.affine_binding != old_block_info.affine_binding) { |
166 | block_info_wrong_affine_binding.emplace_back(new_sref, // |
167 | new_block_info.affine_binding, |
168 | old_block_info.affine_binding); |
169 | } |
170 | if (new_block_info.region_cover != old_block_info.region_cover) { |
171 | block_info_wrong_region_cover.emplace_back(new_sref, // |
172 | new_block_info.region_cover, |
173 | old_block_info.region_cover); |
174 | } |
175 | if (new_block_info.scope->stage_pipeline != old_block_info.scope->stage_pipeline) { |
176 | block_info_wrong_stage_pipeline.emplace_back(new_sref, // |
177 | new_block_info.scope->stage_pipeline, |
178 | old_block_info.scope->stage_pipeline); |
179 | } |
180 | } |
181 | |
182 | bool has_not_found = !block_info_not_found.empty(); |
183 | bool has_wrong_affine_binding = !block_info_wrong_affine_binding.empty(); |
184 | bool has_wrong_region_cover = !block_info_wrong_region_cover.empty(); |
185 | bool has_wrong_stage_pipeline = !block_info_wrong_stage_pipeline.empty(); |
186 | if (!(has_not_found || has_wrong_affine_binding || has_wrong_region_cover || |
187 | has_wrong_stage_pipeline)) { |
188 | return; |
189 | } |
190 | std::ostringstream os; |
191 | if (has_not_found) { |
192 | os << "- BlockInfo not found:" ; |
193 | for (const StmtSRef& block_sref : block_info_not_found) { |
194 | const auto* block = block_sref->StmtAs<BlockNode>(); |
195 | ICHECK(block); |
196 | os << " " << block->name_hint; |
197 | } |
198 | os << std::endl; |
199 | } |
200 | if (has_wrong_affine_binding) { |
201 | os << "- Wrong affine_binding: " ; |
202 | for (const std::tuple<StmtSRef, bool, bool>& record : block_info_wrong_affine_binding) { |
203 | const StmtSRef& block_sref = std::get<0>(record); |
204 | bool expected = std::get<1>(record); |
205 | bool actual = std::get<2>(record); |
206 | const auto* block = block_sref->StmtAs<BlockNode>(); |
207 | ICHECK(block); |
208 | os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")" ; |
209 | } |
210 | os << std::endl; |
211 | } |
212 | if (has_wrong_region_cover) { |
213 | os << "- Wrong region_cover: " ; |
214 | for (const std::tuple<StmtSRef, bool, bool>& record : block_info_wrong_region_cover) { |
215 | const StmtSRef& block_sref = std::get<0>(record); |
216 | bool expected = std::get<1>(record); |
217 | bool actual = std::get<2>(record); |
218 | const auto* block = block_sref->StmtAs<BlockNode>(); |
219 | ICHECK(block); |
220 | os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")" ; |
221 | } |
222 | os << std::endl; |
223 | } |
224 | if (has_wrong_stage_pipeline) { |
225 | os << "- Wrong stage_pipeline: " ; |
226 | for (const std::tuple<StmtSRef, bool, bool>& record : block_info_wrong_stage_pipeline) { |
227 | const StmtSRef& block_sref = std::get<0>(record); |
228 | bool expected = std::get<1>(record); |
229 | bool actual = std::get<2>(record); |
230 | const auto* block = block_sref->StmtAs<BlockNode>(); |
231 | ICHECK(block); |
232 | os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")" ; |
233 | } |
234 | os << std::endl; |
235 | } |
236 | LOG(FATAL) << "Schedule verification failed. The IR is:\n" |
237 | << self->mod << "\nThe errors are:\n" |
238 | << os.str(); |
239 | throw; |
240 | } |
241 | |
242 | } // namespace tir |
243 | } // namespace tvm |
244 | |