1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file remove_store_undef.cc
22 * \brief Remove stores of tir::builtin::undef
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/analysis.h>
26#include <tvm/tir/builtin.h>
27#include <tvm/tir/op.h>
28#include <tvm/tir/stmt.h>
29#include <tvm/tir/stmt_functor.h>
30#include <tvm/tir/transform.h>
31
32namespace tvm {
33namespace tir {
34
35class StoreUndefLocator : public StmtExprVisitor {
36 public:
37 static std::unordered_set<const BufferStoreNode*> Locate(Stmt stmt) {
38 StoreUndefLocator locator;
39 locator(std::move(stmt));
40 return locator.undef_stores_;
41 }
42
43 private:
44 StoreUndefLocator() = default;
45
46 void VisitStmt_(const BufferStoreNode* op) final {
47 bool stash_undef = false;
48 std::swap(has_undef_, stash_undef);
49 StmtExprVisitor::VisitExpr(op->value);
50 std::swap(has_undef_, stash_undef);
51 if (stash_undef) {
52 ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState)
53 << "Error: T.undef() used in BufferStore expressions "
54 << "must not have other side effects";
55 undef_stores_.insert(op);
56 }
57 }
58
59 void VisitExpr_(const BufferLoadNode* op) final {
60 // This function left deliberately empty. builtin::undef()
61 // shouldn't occur in the indices of BufferLoad. Avoiding
62 // visiting the indices catches the builtin::undef in
63 // ValidateAllUndefRemoved.
64 }
65
66 void VisitStmt_(const LetStmtNode* op) final {
67 bool stash_undef = false;
68 std::swap(has_undef_, stash_undef);
69 StmtExprVisitor::VisitExpr(op->value);
70 std::swap(has_undef_, stash_undef);
71 if (stash_undef) {
72 ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState)
73 << "Error: T.undef() used in Let expressions "
74 << "must not have other side effects";
75 var_bindings_with_undef_.insert(op->var.get());
76 }
77
78 StmtExprVisitor::VisitStmt(op->body);
79 }
80
81 void VisitExpr_(const VarNode* op) final {
82 if (var_bindings_with_undef_.count(op)) {
83 has_undef_ = true;
84 }
85 }
86
87 void VisitExpr_(const CallNode* op) final {
88 if (op->op.same_as(builtin::undef())) {
89 has_undef_ = true;
90 }
91 StmtExprVisitor::VisitExpr_(op);
92 }
93
94 bool has_undef_{false};
95
96 std::unordered_set<const VarNode*> var_bindings_with_undef_;
97 std::unordered_set<const BufferStoreNode*> undef_stores_;
98};
99
100// Remove any BufferStores whose value depends on T.undef
101class StoreUndefRemover : public StmtExprMutator {
102 public:
103 static Stmt Apply(Stmt stmt) {
104 auto to_remove = StoreUndefLocator::Locate(stmt);
105 StoreUndefRemover mutator(to_remove);
106 return mutator(std::move(stmt));
107 }
108
109 private:
110 using Parent = StmtExprMutator;
111
112 explicit StoreUndefRemover(const std::unordered_set<const BufferStoreNode*>& to_remove)
113 : to_remove_(to_remove) {}
114
115 Stmt VisitStmt_(const BufferStoreNode* op) final {
116 if (to_remove_.count(op)) {
117 return Evaluate(0);
118 } else {
119 return Parent::VisitStmt_(op);
120 }
121 }
122
123 const std::unordered_set<const BufferStoreNode*>& to_remove_;
124};
125
126// Remove any BufferStores whose value depends on T.undef
127class ContainsUndefChecker : public StmtExprVisitor {
128 public:
129 static bool Check(const Stmt& stmt) {
130 ContainsUndefChecker checker;
131 checker(stmt);
132 return checker.contains_undef;
133 }
134
135 private:
136 void VisitExpr_(const CallNode* op) final {
137 if (op->op.same_as(builtin::undef())) {
138 contains_undef = true;
139 }
140 StmtExprVisitor::VisitExpr_(op);
141 }
142
143 bool contains_undef{false};
144};
145
146namespace transform {
147Pass RemoveStoreUndefInternal() {
148 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
149 auto* n = f.CopyOnWrite();
150 n->body = StoreUndefRemover::Apply(std::move(n->body));
151 return f;
152 };
153 return CreatePrimFuncPass(pass_func, 0, "tir.RemoveStoreUndefInternal", {});
154}
155
156Pass ValidateAllUndefRemoved() {
157 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
158 bool contains_undef = ContainsUndefChecker::Check(f->body);
159 ICHECK(!contains_undef) << "Expected removal of BufferStore containing builtin::undef() "
160 << "to remove all instances of builtin::undef(). "
161 << "Instead, result was"
162 << "\n"
163 << f;
164 return f;
165 };
166 return CreatePrimFuncPass(pass_func, 0, "tir.ValidateAllUndefRemoved", {});
167}
168
169Pass RemoveStoreUndef() {
170 return Sequential({RemoveStoreUndefInternal(), RemoveNoOp(), ValidateAllUndefRemoved()},
171 "tir.RemoveStoreUndef");
172}
173
174TVM_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef").set_body_typed(RemoveStoreUndef);
175
176} // namespace transform
177
178} // namespace tir
179} // namespace tvm
180