1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24class SRefTreeVerifier : public StmtVisitor {
25 public:
26 static void Verify(const ScheduleStateNode* self) { SRefTreeVerifier(self).Verify(); }
27
28 private:
29 /*! \brief Constructor */
30 explicit SRefTreeVerifier(const ScheduleStateNode* self) : self_(self) {}
31
32 void Verify() {
33 VisitPrimFuncs(self_->mod, [this](const PrimFuncNode* func) { this->VisitStmt(func->body); });
34 ICHECK_EQ(n_sref_visited_, static_cast<int>(self_->stmt2ref.size()));
35 for (const auto& kv : self_->block_info) {
36 const StmtSRef& sref = kv.first;
37 ICHECK(sref->stmt != nullptr)
38 << "InternalError: An expired sref is found in the block_scope mapping";
39 auto it = self_->stmt2ref.find(sref->stmt);
40 ICHECK(it != self_->stmt2ref.end())
41 << "InternalError: The sref points to a statement that does not exist in stmt2ref";
42 const StmtSRef& sref2 = it->second;
43 ICHECK(sref.same_as(sref2))
44 << "InternalError: The sref points to a statement whose corresponding sref in stmt2ref "
45 "is not the same object as itself";
46 }
47 ICHECK_EQ(n_block_sref_visited_, static_cast<int>(self_->block_info.size()));
48 }
49
50 void VisitStmt_(const BlockNode* block) final {
51 if (init_block_depth_) {
52 ICHECK(!self_->stmt2ref.count(block)) << "InternalError: A block inside init block has its "
53 "corresponding sref, which is not allowed";
54 StmtVisitor::VisitStmt_(block);
55 return;
56 }
57 ICHECK(self_->stmt2ref.count(block))
58 << "InternalError: A BlockNode should appear in sref map, but it didn't\n"
59 << GetRef<Stmt>(block);
60 ++n_sref_visited_;
61 ++n_block_sref_visited_;
62 const StmtSRef& sref = self_->stmt2ref.at(block);
63 ICHECK(self_->block_info.count(sref))
64 << "InternalError: Cannot find scope information of the BlockNode:\n"
65 << GetRef<Stmt>(block);
66 ICHECK(sref->parent == ancestors_.back())
67 << "InternalError: Parent information mismatch for BlockNode:\n"
68 << GetRef<Stmt>(block) << "\nIts parent is supposed to be:\n"
69 << GetRef<Stmt>(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n"
70 << (sref->parent ? Optional<Stmt>(GetRef<Stmt>(sref->parent->stmt))
71 : Optional<Stmt>(NullOpt));
72 ancestors_.push_back(sref.operator->());
73 if (block->init.defined()) {
74 ++init_block_depth_;
75 VisitStmt(block->init.value());
76 --init_block_depth_;
77 }
78 VisitStmt(block->body);
79 ancestors_.pop_back();
80 }
81
82 void VisitStmt_(const ForNode* loop) final {
83 if (init_block_depth_) {
84 ICHECK(!self_->stmt2ref.count(loop)) << "InternalError: A loop inside init block has its "
85 "corresponding sref, which is not allowed";
86 StmtVisitor::VisitStmt_(loop);
87 return;
88 }
89 ICHECK(self_->stmt2ref.count(loop))
90 << "InternalError: A ForNode should appear in sref map, but it didn't\n"
91 << GetRef<Stmt>(loop);
92 ++n_sref_visited_;
93 const StmtSRef& sref = self_->stmt2ref.at(loop);
94 Optional<Stmt> stmt = NullOpt;
95 ICHECK(sref->parent == ancestors_.back())
96 << "InternalError: Parent information mismatch for ForNode:\n"
97 << GetRef<Stmt>(loop) << "\nIts parent is supposed to be:\n"
98 << GetRef<Stmt>(ancestors_.back()->stmt) << "\nHowever, its parent is incorrect and is:\n"
99 << (sref->parent ? Optional<Stmt>(GetRef<Stmt>(sref->parent->stmt))
100 : Optional<Stmt>(NullOpt));
101 ancestors_.push_back(sref.operator->());
102 StmtVisitor::VisitStmt_(loop);
103 ancestors_.pop_back();
104 }
105
106 void VisitStmt_(const SeqStmtNode* seq_stmt) final {
107 // Verify seq_index
108 if (init_block_depth_) {
109 StmtVisitor::VisitStmt_(seq_stmt);
110 return;
111 }
112 int n = static_cast<int>(seq_stmt->seq.size());
113 for (int i = 0; i < n; ++i) {
114 const Stmt& child = seq_stmt->seq[i];
115 StmtSRef sref{nullptr};
116 if (const auto* realize = child.as<BlockRealizeNode>()) {
117 const auto* block = realize->block.get();
118 ICHECK(self_->stmt2ref.count(block));
119 sref = self_->stmt2ref.at(block);
120 } else if (child->IsInstance<ForNode>()) {
121 ICHECK(self_->stmt2ref.count(child.get()));
122 sref = self_->stmt2ref.at(child.get());
123 } else {
124 continue;
125 }
126 ICHECK_EQ(sref->seq_index, i) << "InternalError: A StmtSRef has incorrect seq_index";
127 }
128 StmtVisitor::VisitStmt_(seq_stmt);
129 }
130
131 /*! \brief The schedule it belongs to */
132 const ScheduleStateNode* self_;
133 /*! \brief Parent information during the visit */
134 std::vector<const StmtSRefNode*> ancestors_ = {nullptr};
135 /*! \brief If the visitor is currently in the init block */
136 int init_block_depth_ = 0;
137 /*! \brief Number of srefs that are visited */
138 int n_sref_visited_ = 0;
139 /*! \brief Number of block srefs that are visited */
140 int n_block_sref_visited_ = 0;
141};
142
143void VerifySRefTree(const ScheduleState& self) { SRefTreeVerifier::Verify(self.get()); }
144
145void VerifyCachedFlags(const ScheduleState& self) {
146 std::vector<StmtSRef> block_info_not_found;
147 std::vector<std::tuple<StmtSRef, bool, bool>> block_info_wrong_affine_binding;
148 std::vector<std::tuple<StmtSRef, bool, bool>> block_info_wrong_region_cover;
149 std::vector<std::tuple<StmtSRef, bool, bool>> block_info_wrong_stage_pipeline;
150
151 ScheduleState new_state(self->mod);
152 for (const auto& kv : new_state->stmt2ref) {
153 const StmtNode* stmt = kv.first;
154 const StmtSRef& new_sref = kv.second;
155 if (stmt->IsInstance<ForNode>() || !self->stmt2ref.count(stmt)) {
156 continue;
157 }
158 const BlockInfo& new_block_info = new_state->block_info.at(new_sref);
159 const StmtSRef& old_sref = self->stmt2ref.at(stmt);
160 if (!self->block_info.count(old_sref)) {
161 block_info_not_found.push_back(new_sref);
162 continue;
163 }
164 const BlockInfo& old_block_info = self->block_info.at(old_sref);
165 if (new_block_info.affine_binding != old_block_info.affine_binding) {
166 block_info_wrong_affine_binding.emplace_back(new_sref, //
167 new_block_info.affine_binding,
168 old_block_info.affine_binding);
169 }
170 if (new_block_info.region_cover != old_block_info.region_cover) {
171 block_info_wrong_region_cover.emplace_back(new_sref, //
172 new_block_info.region_cover,
173 old_block_info.region_cover);
174 }
175 if (new_block_info.scope->stage_pipeline != old_block_info.scope->stage_pipeline) {
176 block_info_wrong_stage_pipeline.emplace_back(new_sref, //
177 new_block_info.scope->stage_pipeline,
178 old_block_info.scope->stage_pipeline);
179 }
180 }
181
182 bool has_not_found = !block_info_not_found.empty();
183 bool has_wrong_affine_binding = !block_info_wrong_affine_binding.empty();
184 bool has_wrong_region_cover = !block_info_wrong_region_cover.empty();
185 bool has_wrong_stage_pipeline = !block_info_wrong_stage_pipeline.empty();
186 if (!(has_not_found || has_wrong_affine_binding || has_wrong_region_cover ||
187 has_wrong_stage_pipeline)) {
188 return;
189 }
190 std::ostringstream os;
191 if (has_not_found) {
192 os << "- BlockInfo not found:";
193 for (const StmtSRef& block_sref : block_info_not_found) {
194 const auto* block = block_sref->StmtAs<BlockNode>();
195 ICHECK(block);
196 os << " " << block->name_hint;
197 }
198 os << std::endl;
199 }
200 if (has_wrong_affine_binding) {
201 os << "- Wrong affine_binding: ";
202 for (const std::tuple<StmtSRef, bool, bool>& record : block_info_wrong_affine_binding) {
203 const StmtSRef& block_sref = std::get<0>(record);
204 bool expected = std::get<1>(record);
205 bool actual = std::get<2>(record);
206 const auto* block = block_sref->StmtAs<BlockNode>();
207 ICHECK(block);
208 os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")";
209 }
210 os << std::endl;
211 }
212 if (has_wrong_region_cover) {
213 os << "- Wrong region_cover: ";
214 for (const std::tuple<StmtSRef, bool, bool>& record : block_info_wrong_region_cover) {
215 const StmtSRef& block_sref = std::get<0>(record);
216 bool expected = std::get<1>(record);
217 bool actual = std::get<2>(record);
218 const auto* block = block_sref->StmtAs<BlockNode>();
219 ICHECK(block);
220 os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")";
221 }
222 os << std::endl;
223 }
224 if (has_wrong_stage_pipeline) {
225 os << "- Wrong stage_pipeline: ";
226 for (const std::tuple<StmtSRef, bool, bool>& record : block_info_wrong_stage_pipeline) {
227 const StmtSRef& block_sref = std::get<0>(record);
228 bool expected = std::get<1>(record);
229 bool actual = std::get<2>(record);
230 const auto* block = block_sref->StmtAs<BlockNode>();
231 ICHECK(block);
232 os << " (" << block->name_hint << ", expected=" << expected << ", actual=" << actual << ")";
233 }
234 os << std::endl;
235 }
236 LOG(FATAL) << "Schedule verification failed. The IR is:\n"
237 << self->mod << "\nThe errors are:\n"
238 << os.str();
239 throw;
240}
241
242} // namespace tir
243} // namespace tvm
244