1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/analysis/verify_well_formed.cc
22 * \brief Check if schedulable tir is well-formed.
23 */
24
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
#include <vector>

#include "../ir/functor_common.h"
30
31namespace tvm {
32namespace tir {
33
34/*! \brief Verify all Expr inside the block does not contain:
35 * 1. loop vars outside the current block.
36 * 2. block vars of parent blocks.
37 */
38class BlockVarAccessVerifier : public StmtExprVisitor {
39 public:
40 static bool Verify(const PrimFunc& func, bool assert_mode) {
41 BlockVarAccessVerifier verifier(assert_mode);
42 verifier(func->body);
43 return !verifier.has_error_;
44 }
45
46 private:
47 explicit BlockVarAccessVerifier(bool assert_mode) : assert_mode_(assert_mode) {}
48
49 void VisitStmt(const Stmt& stmt) final {
50 if (!has_error_) {
51 StmtExprVisitor::VisitStmt(stmt);
52 }
53 }
54
55 void VisitExpr(const PrimExpr& expr) final {
56 if (!has_error_) {
57 StmtExprVisitor::VisitExpr(expr);
58 }
59 }
60
61 void VisitExpr_(const VarNode* op) final {
62 auto it = loop_vars_.find(op);
63 if (it != loop_vars_.end() && it->second < cur_block_level_) {
64 has_error_ = true;
65 if (assert_mode_) {
66 report_error(op);
67 }
68 }
69 }
70
71 void VisitStmt_(const ForNode* op) final {
72 ICHECK(loop_vars_.find(op->loop_var.get()) == loop_vars_.end());
73 loop_vars_[op->loop_var.get()] = cur_block_level_;
74 StmtExprVisitor::VisitStmt_(op);
75 loop_vars_.erase(op->loop_var.get());
76 }
77
78 void VisitStmt_(const BlockNode* op) final {
79 // Do not check boundary if it's a opaque block.
80 cur_block_level_ += !op->iter_vars.empty();
81
82 // Step 0. Skip block iter var's domain
83
84 // Step 1. Visit read/write regions
85 auto fvisit_buffer_region = [this](const BufferRegion& s) {
86 for (const auto& range : s->region) {
87 this->VisitExpr(range->min);
88 this->VisitExpr(range->extent);
89 }
90 };
91 VisitArray(op->reads, fvisit_buffer_region);
92 VisitArray(op->writes, fvisit_buffer_region);
93
94 // Step 2. Visit match buffers
95 VisitArray(op->match_buffers,
96 [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) {
97 fvisit_buffer_region(match_buffer_region->source);
98 });
99
100 // Step 3. Visit init and body
101 if (op->init.defined()) {
102 this->VisitStmt(op->init.value());
103 }
104 this->VisitStmt(op->body);
105
106 cur_block_level_ -= !op->iter_vars.empty();
107 }
108
109 private:
110 void report_error(const VarNode* var) {
111 // TODO(siyuan): use the error message from the parser.
112 LOG(FATAL) << "Well-formedness check failed: outside defined var " << var->name_hint
113 << " is used inside the current block.";
114 }
115
116 /*! \brief The map from outside loop vars to its corresponding block level. */
117 std::unordered_map<const VarNode*, size_t> loop_vars_;
118 /*! \brief Whether it's in assert mode. */
119 bool assert_mode_;
120 /*! \brief Current nested block stack level. */
121 size_t cur_block_level_{0};
122 /*! \brief Whether there is error. */
123 bool has_error_{false};
124};
125
126bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
127 if (!BlockVarAccessVerifier::Verify(func, assert_mode)) {
128 return false;
129 }
130 // TODO(Siyuan): add more checks here.
131 return true;
132}
133
134TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed);
135
136} // namespace tir
137} // namespace tvm
138