1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * SSA related checks and pass.
22 *
 * SSA requires each variable to be defined only once.
24 * \file verify_ssa.cc
25 */
26#include <tvm/runtime/registry.h>
27#include <tvm/tir/analysis.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/stmt_functor.h>
30
31#include <unordered_map>
32#include <unordered_set>
33#include <vector>
34
35namespace tvm {
36namespace tir {
37
38class SSAVerifier final : public StmtExprVisitor {
39 public:
40 bool is_ssa_{true};
41
42 void VisitExpr(const PrimExpr& n) final {
43 if (!is_ssa_) return;
44 StmtExprVisitor::VisitExpr(n);
45 }
46 void VisitStmt(const Stmt& n) final {
47 if (!is_ssa_) return;
48 StmtExprVisitor::VisitStmt(n);
49 }
50 void VisitExpr_(const LetNode* op) final {
51 // Weaker SSA condition
52 // A single var can be binded in multiple lets
53 // but they have to bind to the same value.
54 // This is used to enable cases when we reuse a single let
55 // expression to cosntruct a nested expr.
56 // (let x = 1 in x + 1) * (let x = 1 in x + 1)
57 auto it = def_map_.find(op->var);
58 if (it != def_map_.end()) {
59 if (!deep_equal_(it->second, op->value)) {
60 is_ssa_ = false;
61 return;
62 }
63 } else {
64 MarkDef(op->var, op->value);
65 }
66 StmtExprVisitor::VisitExpr_(op);
67 }
68
69 void VisitStmt_(const LetStmtNode* op) final {
70 MarkDef(op->var, op->value);
71 StmtExprVisitor::VisitStmt_(op);
72 }
73 void VisitStmt_(const ForNode* op) final {
74 MarkDef(op->loop_var, op->loop_var);
75 StmtExprVisitor::VisitStmt_(op);
76 }
77 void VisitStmt_(const AllocateNode* op) final {
78 MarkDef(op->buffer_var, op->buffer_var);
79 StmtExprVisitor::VisitStmt_(op);
80 }
81
82 void VisitExpr_(const VarNode* node) final {
83 auto var = GetRef<Var>(node);
84 if (match_scope_) {
85 MarkDef(var, var, true);
86 }
87 }
88
89 void Run(const PrimFunc& func) {
90 for (auto param : func->params) {
91 MarkDef(param, param);
92 }
93
94 for (auto kv : func->buffer_map) {
95 this->DefineBuffer(kv.second);
96 }
97 this->VisitStmt(func->body);
98 }
99
100 void DefineBuffer(const Buffer& buffer) {
101 match_scope_ = true;
102 this->VisitExpr(buffer->data);
103 for (size_t i = 0; i < buffer->shape.size(); ++i) {
104 this->VisitExpr(buffer->shape[i]);
105 }
106
107 if (buffer->strides.defined()) {
108 for (size_t i = 0; i < buffer->strides.size(); ++i) {
109 this->VisitExpr(buffer->strides[i]);
110 }
111 }
112 this->VisitExpr(buffer->elem_offset);
113
114 match_scope_ = false;
115 }
116
117 private:
118 void MarkDef(const Var& var, PrimExpr value, bool allow_dup = false) {
119 if (def_map_.count(var) != 0) {
120 if (!allow_dup) {
121 is_ssa_ = false;
122 return;
123 }
124 } else {
125 def_map_[var] = value;
126 }
127 }
128 // whether we are in match scope, where a var can occur multiple times.
129 bool match_scope_{false};
130 // deep equal
131 ExprDeepEqual deep_equal_;
132 // def map, for let, maps to the bind value, for others maps to self.
133 std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> def_map_;
134};
135
136bool VerifySSA(const PrimFunc& func) {
137 SSAVerifier visitor;
138 visitor.Run(func);
139 return visitor.is_ssa_;
140}
141
// Register the function-level check under the global name "tir.analysis.verify_ssa".
TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA);
143
144namespace transform {
145
146Pass VerifySSA() {
147 auto pass_func = [=](IRModule mod, PassContext ctx) {
148 for (auto kv : mod->functions) {
149 if (auto* n = kv.second.as<PrimFuncNode>()) {
150 auto func = GetRef<PrimFunc>(n);
151 ICHECK(VerifySSA(func)) << "RuntimeError: IR is not in SSA form" << func;
152 }
153 }
154 return mod;
155 };
156 return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {});
157}
158
// Register the pass constructor under the global name "tir.transform.VerifySSA".
TVM_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA);
160
161} // namespace transform
162
163} // namespace tir
164} // namespace tvm
165