1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * SSA related checks and pass. |
22 | * |
 * SSA requires each variable to be defined only once.
24 | * \file verify_ssa.cc |
25 | */ |
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | |
38 | class SSAVerifier final : public StmtExprVisitor { |
39 | public: |
40 | bool is_ssa_{true}; |
41 | |
42 | void VisitExpr(const PrimExpr& n) final { |
43 | if (!is_ssa_) return; |
44 | StmtExprVisitor::VisitExpr(n); |
45 | } |
46 | void VisitStmt(const Stmt& n) final { |
47 | if (!is_ssa_) return; |
48 | StmtExprVisitor::VisitStmt(n); |
49 | } |
50 | void VisitExpr_(const LetNode* op) final { |
51 | // Weaker SSA condition |
52 | // A single var can be binded in multiple lets |
53 | // but they have to bind to the same value. |
54 | // This is used to enable cases when we reuse a single let |
55 | // expression to cosntruct a nested expr. |
56 | // (let x = 1 in x + 1) * (let x = 1 in x + 1) |
57 | auto it = def_map_.find(op->var); |
58 | if (it != def_map_.end()) { |
59 | if (!deep_equal_(it->second, op->value)) { |
60 | is_ssa_ = false; |
61 | return; |
62 | } |
63 | } else { |
64 | MarkDef(op->var, op->value); |
65 | } |
66 | StmtExprVisitor::VisitExpr_(op); |
67 | } |
68 | |
69 | void VisitStmt_(const LetStmtNode* op) final { |
70 | MarkDef(op->var, op->value); |
71 | StmtExprVisitor::VisitStmt_(op); |
72 | } |
73 | void VisitStmt_(const ForNode* op) final { |
74 | MarkDef(op->loop_var, op->loop_var); |
75 | StmtExprVisitor::VisitStmt_(op); |
76 | } |
77 | void VisitStmt_(const AllocateNode* op) final { |
78 | MarkDef(op->buffer_var, op->buffer_var); |
79 | StmtExprVisitor::VisitStmt_(op); |
80 | } |
81 | |
82 | void VisitExpr_(const VarNode* node) final { |
83 | auto var = GetRef<Var>(node); |
84 | if (match_scope_) { |
85 | MarkDef(var, var, true); |
86 | } |
87 | } |
88 | |
89 | void Run(const PrimFunc& func) { |
90 | for (auto param : func->params) { |
91 | MarkDef(param, param); |
92 | } |
93 | |
94 | for (auto kv : func->buffer_map) { |
95 | this->DefineBuffer(kv.second); |
96 | } |
97 | this->VisitStmt(func->body); |
98 | } |
99 | |
100 | void DefineBuffer(const Buffer& buffer) { |
101 | match_scope_ = true; |
102 | this->VisitExpr(buffer->data); |
103 | for (size_t i = 0; i < buffer->shape.size(); ++i) { |
104 | this->VisitExpr(buffer->shape[i]); |
105 | } |
106 | |
107 | if (buffer->strides.defined()) { |
108 | for (size_t i = 0; i < buffer->strides.size(); ++i) { |
109 | this->VisitExpr(buffer->strides[i]); |
110 | } |
111 | } |
112 | this->VisitExpr(buffer->elem_offset); |
113 | |
114 | match_scope_ = false; |
115 | } |
116 | |
117 | private: |
118 | void MarkDef(const Var& var, PrimExpr value, bool allow_dup = false) { |
119 | if (def_map_.count(var) != 0) { |
120 | if (!allow_dup) { |
121 | is_ssa_ = false; |
122 | return; |
123 | } |
124 | } else { |
125 | def_map_[var] = value; |
126 | } |
127 | } |
128 | // whether we are in match scope, where a var can occur multiple times. |
129 | bool match_scope_{false}; |
130 | // deep equal |
131 | ExprDeepEqual deep_equal_; |
132 | // def map, for let, maps to the bind value, for others maps to self. |
133 | std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> def_map_; |
134 | }; |
135 | |
136 | bool VerifySSA(const PrimFunc& func) { |
137 | SSAVerifier visitor; |
138 | visitor.Run(func); |
139 | return visitor.is_ssa_; |
140 | } |
141 | |
142 | TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa" ).set_body_typed(VerifySSA); |
143 | |
144 | namespace transform { |
145 | |
146 | Pass VerifySSA() { |
147 | auto pass_func = [=](IRModule mod, PassContext ctx) { |
148 | for (auto kv : mod->functions) { |
149 | if (auto* n = kv.second.as<PrimFuncNode>()) { |
150 | auto func = GetRef<PrimFunc>(n); |
151 | ICHECK(VerifySSA(func)) << "RuntimeError: IR is not in SSA form" << func; |
152 | } |
153 | } |
154 | return mod; |
155 | }; |
156 | return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA" , {}); |
157 | } |
158 | |
159 | TVM_REGISTER_GLOBAL("tir.transform.VerifySSA" ).set_body_typed(VerifySSA); |
160 | |
161 | } // namespace transform |
162 | |
163 | } // namespace tir |
164 | } // namespace tvm |
165 | |