1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/analysis/verify_well_formed.cc |
22 | * \brief Check if schedulable tir is well-formed. |
23 | */ |
24 | |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/stmt.h> |
27 | #include <tvm/tir/stmt_functor.h> |
28 | |
29 | #include "../ir/functor_common.h" |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
34 | /*! \brief Verify all Expr inside the block does not contain: |
35 | * 1. loop vars outside the current block. |
36 | * 2. block vars of parent blocks. |
37 | */ |
38 | class BlockVarAccessVerifier : public StmtExprVisitor { |
39 | public: |
40 | static bool Verify(const PrimFunc& func, bool assert_mode) { |
41 | BlockVarAccessVerifier verifier(assert_mode); |
42 | verifier(func->body); |
43 | return !verifier.has_error_; |
44 | } |
45 | |
46 | private: |
47 | explicit BlockVarAccessVerifier(bool assert_mode) : assert_mode_(assert_mode) {} |
48 | |
49 | void VisitStmt(const Stmt& stmt) final { |
50 | if (!has_error_) { |
51 | StmtExprVisitor::VisitStmt(stmt); |
52 | } |
53 | } |
54 | |
55 | void VisitExpr(const PrimExpr& expr) final { |
56 | if (!has_error_) { |
57 | StmtExprVisitor::VisitExpr(expr); |
58 | } |
59 | } |
60 | |
61 | void VisitExpr_(const VarNode* op) final { |
62 | auto it = loop_vars_.find(op); |
63 | if (it != loop_vars_.end() && it->second < cur_block_level_) { |
64 | has_error_ = true; |
65 | if (assert_mode_) { |
66 | report_error(op); |
67 | } |
68 | } |
69 | } |
70 | |
71 | void VisitStmt_(const ForNode* op) final { |
72 | ICHECK(loop_vars_.find(op->loop_var.get()) == loop_vars_.end()); |
73 | loop_vars_[op->loop_var.get()] = cur_block_level_; |
74 | StmtExprVisitor::VisitStmt_(op); |
75 | loop_vars_.erase(op->loop_var.get()); |
76 | } |
77 | |
78 | void VisitStmt_(const BlockNode* op) final { |
79 | // Do not check boundary if it's a opaque block. |
80 | cur_block_level_ += !op->iter_vars.empty(); |
81 | |
82 | // Step 0. Skip block iter var's domain |
83 | |
84 | // Step 1. Visit read/write regions |
85 | auto fvisit_buffer_region = [this](const BufferRegion& s) { |
86 | for (const auto& range : s->region) { |
87 | this->VisitExpr(range->min); |
88 | this->VisitExpr(range->extent); |
89 | } |
90 | }; |
91 | VisitArray(op->reads, fvisit_buffer_region); |
92 | VisitArray(op->writes, fvisit_buffer_region); |
93 | |
94 | // Step 2. Visit match buffers |
95 | VisitArray(op->match_buffers, |
96 | [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { |
97 | fvisit_buffer_region(match_buffer_region->source); |
98 | }); |
99 | |
100 | // Step 3. Visit init and body |
101 | if (op->init.defined()) { |
102 | this->VisitStmt(op->init.value()); |
103 | } |
104 | this->VisitStmt(op->body); |
105 | |
106 | cur_block_level_ -= !op->iter_vars.empty(); |
107 | } |
108 | |
109 | private: |
110 | void report_error(const VarNode* var) { |
111 | // TODO(siyuan): use the error message from the parser. |
112 | LOG(FATAL) << "Well-formedness check failed: outside defined var " << var->name_hint |
113 | << " is used inside the current block." ; |
114 | } |
115 | |
116 | /*! \brief The map from outside loop vars to its corresponding block level. */ |
117 | std::unordered_map<const VarNode*, size_t> loop_vars_; |
118 | /*! \brief Whether it's in assert mode. */ |
119 | bool assert_mode_; |
120 | /*! \brief Current nested block stack level. */ |
121 | size_t cur_block_level_{0}; |
122 | /*! \brief Whether there is error. */ |
123 | bool has_error_{false}; |
124 | }; |
125 | |
126 | bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) { |
127 | if (!BlockVarAccessVerifier::Verify(func, assert_mode)) { |
128 | return false; |
129 | } |
130 | // TODO(Siyuan): add more checks here. |
131 | return true; |
132 | } |
133 | |
134 | TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed" ).set_body_typed(VerifyWellFormed); |
135 | |
136 | } // namespace tir |
137 | } // namespace tvm |
138 | |