1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file remove_store_undef.cc |
22 | * \brief Remove stores of tir::builtin::undef |
23 | */ |
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_set>
#include <utility>
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | class StoreUndefLocator : public StmtExprVisitor { |
36 | public: |
37 | static std::unordered_set<const BufferStoreNode*> Locate(Stmt stmt) { |
38 | StoreUndefLocator locator; |
39 | locator(std::move(stmt)); |
40 | return locator.undef_stores_; |
41 | } |
42 | |
43 | private: |
44 | StoreUndefLocator() = default; |
45 | |
46 | void VisitStmt_(const BufferStoreNode* op) final { |
47 | bool stash_undef = false; |
48 | std::swap(has_undef_, stash_undef); |
49 | StmtExprVisitor::VisitExpr(op->value); |
50 | std::swap(has_undef_, stash_undef); |
51 | if (stash_undef) { |
52 | ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) |
53 | << "Error: T.undef() used in BufferStore expressions " |
54 | << "must not have other side effects" ; |
55 | undef_stores_.insert(op); |
56 | } |
57 | } |
58 | |
59 | void VisitExpr_(const BufferLoadNode* op) final { |
60 | // This function left deliberately empty. builtin::undef() |
61 | // shouldn't occur in the indices of BufferLoad. Avoiding |
62 | // visiting the indices catches the builtin::undef in |
63 | // ValidateAllUndefRemoved. |
64 | } |
65 | |
66 | void VisitStmt_(const LetStmtNode* op) final { |
67 | bool stash_undef = false; |
68 | std::swap(has_undef_, stash_undef); |
69 | StmtExprVisitor::VisitExpr(op->value); |
70 | std::swap(has_undef_, stash_undef); |
71 | if (stash_undef) { |
72 | ICHECK(SideEffect(op->value) <= CallEffectKind::kReadState) |
73 | << "Error: T.undef() used in Let expressions " |
74 | << "must not have other side effects" ; |
75 | var_bindings_with_undef_.insert(op->var.get()); |
76 | } |
77 | |
78 | StmtExprVisitor::VisitStmt(op->body); |
79 | } |
80 | |
81 | void VisitExpr_(const VarNode* op) final { |
82 | if (var_bindings_with_undef_.count(op)) { |
83 | has_undef_ = true; |
84 | } |
85 | } |
86 | |
87 | void VisitExpr_(const CallNode* op) final { |
88 | if (op->op.same_as(builtin::undef())) { |
89 | has_undef_ = true; |
90 | } |
91 | StmtExprVisitor::VisitExpr_(op); |
92 | } |
93 | |
94 | bool has_undef_{false}; |
95 | |
96 | std::unordered_set<const VarNode*> var_bindings_with_undef_; |
97 | std::unordered_set<const BufferStoreNode*> undef_stores_; |
98 | }; |
99 | |
100 | // Remove any BufferStores whose value depends on T.undef |
101 | class StoreUndefRemover : public StmtExprMutator { |
102 | public: |
103 | static Stmt Apply(Stmt stmt) { |
104 | auto to_remove = StoreUndefLocator::Locate(stmt); |
105 | StoreUndefRemover mutator(to_remove); |
106 | return mutator(std::move(stmt)); |
107 | } |
108 | |
109 | private: |
110 | using Parent = StmtExprMutator; |
111 | |
112 | explicit StoreUndefRemover(const std::unordered_set<const BufferStoreNode*>& to_remove) |
113 | : to_remove_(to_remove) {} |
114 | |
115 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
116 | if (to_remove_.count(op)) { |
117 | return Evaluate(0); |
118 | } else { |
119 | return Parent::VisitStmt_(op); |
120 | } |
121 | } |
122 | |
123 | const std::unordered_set<const BufferStoreNode*>& to_remove_; |
124 | }; |
125 | |
126 | // Remove any BufferStores whose value depends on T.undef |
127 | class ContainsUndefChecker : public StmtExprVisitor { |
128 | public: |
129 | static bool Check(const Stmt& stmt) { |
130 | ContainsUndefChecker checker; |
131 | checker(stmt); |
132 | return checker.contains_undef; |
133 | } |
134 | |
135 | private: |
136 | void VisitExpr_(const CallNode* op) final { |
137 | if (op->op.same_as(builtin::undef())) { |
138 | contains_undef = true; |
139 | } |
140 | StmtExprVisitor::VisitExpr_(op); |
141 | } |
142 | |
143 | bool contains_undef{false}; |
144 | }; |
145 | |
146 | namespace transform { |
147 | Pass RemoveStoreUndefInternal() { |
148 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
149 | auto* n = f.CopyOnWrite(); |
150 | n->body = StoreUndefRemover::Apply(std::move(n->body)); |
151 | return f; |
152 | }; |
153 | return CreatePrimFuncPass(pass_func, 0, "tir.RemoveStoreUndefInternal" , {}); |
154 | } |
155 | |
156 | Pass ValidateAllUndefRemoved() { |
157 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
158 | bool contains_undef = ContainsUndefChecker::Check(f->body); |
159 | ICHECK(!contains_undef) << "Expected removal of BufferStore containing builtin::undef() " |
160 | << "to remove all instances of builtin::undef(). " |
161 | << "Instead, result was" |
162 | << "\n" |
163 | << f; |
164 | return f; |
165 | }; |
166 | return CreatePrimFuncPass(pass_func, 0, "tir.ValidateAllUndefRemoved" , {}); |
167 | } |
168 | |
169 | Pass RemoveStoreUndef() { |
170 | return Sequential({RemoveStoreUndefInternal(), RemoveNoOp(), ValidateAllUndefRemoved()}, |
171 | "tir.RemoveStoreUndef" ); |
172 | } |
173 | |
// Register RemoveStoreUndef as a global function under the name
// "tir.transform.RemoveStoreUndef".
TVM_REGISTER_GLOBAL("tir.transform.RemoveStoreUndef" ).set_body_typed(RemoveStoreUndef);
175 | |
176 | } // namespace transform |
177 | |
178 | } // namespace tir |
179 | } // namespace tvm |
180 | |