1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file remove_no_op.cc |
 * \brief Remove no-op statements from a TIR statement.
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/tir/stmt.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | #include <tvm/tir/transform.h> |
31 | |
32 | #include <optional> |
33 | #include <unordered_map> |
34 | |
35 | #include "../../arith/const_fold.h" |
36 | #include "../../arith/ir_mutator_with_analyzer.h" |
37 | #include "../analysis/control_flow_graph.h" |
38 | #include "ir_utils.h" |
39 | |
40 | namespace tvm { |
41 | namespace tir { |
42 | |
/*!
 * \brief Attribute node holding the options of the RemoveNoOp pass.
 *
 * The attributes are declared under the type key
 * "tir.transform.RemoveNoOpConfig".
 */
struct RemoveNoOpConfigNode : public tvm::AttrsNode<RemoveNoOpConfigNode> {
  // When true, the pass builds a ControlFlowGraph of the function body and
  // uses it while visiting buffer stores.  Declared default: false.
  bool use_dataflow_analysis;

  TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") {
    TVM_ATTR_FIELD(use_dataflow_analysis)
        .describe(
            "If true, known buffer values are propagated and used "
            "to statically prove statements as no-ops.")
        .set_default(false);
  }
};
54 | |
/*!
 * \brief Reference class for RemoveNoOpConfigNode.
 *
 * The reference methods are generated by the not-nullable object-ref macro.
 */
class RemoveNoOpConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode);
};
59 | |
// Register the config node type, and register RemoveNoOpConfig as the
// pass-config option stored under the key "tir.RemoveNoOp" (the same key the
// pass reads through PassContext::GetConfig).
TVM_REGISTER_NODE_TYPE(RemoveNoOpConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.RemoveNoOp", RemoveNoOpConfig);
62 | |
63 | // Mark the statement of each stage. |
64 | class NoOpRemover : public arith::IRMutatorWithAnalyzer { |
65 | public: |
66 | static Stmt Apply(Stmt stmt, arith::Analyzer* analyzer, |
67 | std::optional<ControlFlowGraph> touch_pattern, const StmtNode* context) { |
68 | NoOpRemover visitor(analyzer, touch_pattern, context); |
69 | return visitor(std::move(stmt)); |
70 | } |
71 | |
72 | private: |
73 | using Parent = IRMutatorWithAnalyzer; |
74 | using Parent::VisitStmt; |
75 | using Parent::VisitStmt_; |
76 | |
77 | NoOpRemover(arith::Analyzer* analyzer, std::optional<ControlFlowGraph> touch_pattern, |
78 | const StmtNode* context) |
79 | : Parent(analyzer), touch_pattern_(touch_pattern), context_(context) {} |
80 | |
81 | Stmt VisitStmt_(const LetStmtNode* op) final { |
82 | Stmt stmt = Parent::VisitStmt_(op); |
83 | op = stmt.as<LetStmtNode>(); |
84 | if (is_no_op(op->body)) { |
85 | return MakeEvaluate(op->value); |
86 | } |
87 | |
88 | bool body_uses_bound_variable = |
89 | !UsesVar(op->body, [&](const VarNode* var) { return var == op->var.get(); }); |
90 | if (body_uses_bound_variable && HasSideEffect(op->value)) { |
91 | return SeqStmt({MakeEvaluate(op->value), op->body}); |
92 | } else if (body_uses_bound_variable) { |
93 | return op->body; |
94 | } else { |
95 | return stmt; |
96 | } |
97 | } |
98 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
99 | if (op->attr_key == "pragma_debug_skip_region" ) { |
100 | return MakeEvaluate(0); |
101 | } else if (op->attr_key == attr::async_wait_queue_scope) { |
102 | auto wait_attrs = GetAsyncWaitAttributes(op); |
103 | auto wait_cnt = wait_attrs.second; |
104 | arith::Analyzer ana; |
105 | if (ana.CanProve(wait_cnt < 0)) { |
106 | // A negative wait count can arise if it depends on a loop variable. |
107 | // For example, a wait count 1 - i can be negative after loop unrolling. |
108 | // We assume that such wait is a nop. |
109 | auto inner = op->body.as<AttrStmtNode>(); |
110 | ICHECK(inner); |
111 | return Parent::VisitStmt(inner->body); |
112 | } |
113 | } |
114 | |
115 | Stmt stmt = Parent::VisitStmt_(op); |
116 | op = stmt.as<AttrStmtNode>(); |
117 | return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt; |
118 | } |
119 | Stmt VisitStmt_(const IfThenElseNode* op) final { |
120 | Stmt stmt = Parent::VisitStmt_(op); |
121 | op = stmt.as<IfThenElseNode>(); |
122 | // Sometimes the condition can be statically determined, |
123 | // in which the type of the `stmt` will not be IfThenElseNode. |
124 | if (!op) { |
125 | return stmt; |
126 | } |
127 | if (op->else_case) { |
128 | bool no_op_else = is_no_op(op->else_case.value()); |
129 | bool no_op_then = is_no_op(op->then_case); |
130 | if (no_op_else && no_op_then) { |
131 | return MakeEvaluate(op->condition); |
132 | } else if (no_op_else) { |
133 | return IfThenElse(op->condition, op->then_case); |
134 | } else if (no_op_then) { |
135 | return IfThenElse(!op->condition, op->else_case.value()); |
136 | } else { |
137 | return stmt; |
138 | } |
139 | } else { |
140 | if (is_no_op(op->then_case)) { |
141 | return MakeEvaluate(op->condition); |
142 | } else { |
143 | return stmt; |
144 | } |
145 | } |
146 | } |
147 | Stmt VisitStmt_(const ForNode* op) final { |
148 | auto extent_range = arith::EvalSet(op->extent, var_range_map_); |
149 | if (!arith::is_neg_inf(extent_range.max()) && !arith::is_pos_inf(extent_range.max()) && |
150 | analyzer_->CanProve(extent_range.max() <= 0)) { |
151 | return Evaluate(0); |
152 | } |
153 | var_range_map_[op->loop_var.get()] = arith::IntSet::FromMinExtent(op->min, op->extent); |
154 | Stmt stmt = Parent::VisitStmt_(op); |
155 | var_range_map_.erase(op->loop_var.get()); |
156 | op = stmt.as<ForNode>(); |
157 | if (is_zero(op->extent)) { |
158 | return Evaluate(0); |
159 | } |
160 | return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt; |
161 | } |
162 | Stmt VisitStmt_(const AllocateNode* op) final { |
163 | Stmt stmt = StmtMutator::VisitStmt_(op); |
164 | op = stmt.as<AllocateNode>(); |
165 | return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; |
166 | } |
167 | |
168 | Stmt VisitStmt_(const ProducerRealizeNode* op) final { |
169 | Stmt stmt = StmtMutator::VisitStmt_(op); |
170 | op = stmt.as<ProducerRealizeNode>(); |
171 | return is_no_op(op->body) ? op->body : stmt; |
172 | } |
173 | Stmt VisitStmt_(const EvaluateNode* op) final { |
174 | if (HasSideEffect(op->value)) { |
175 | return GetRef<Stmt>(op); |
176 | } else { |
177 | return Evaluate(0); |
178 | } |
179 | } |
180 | |
181 | Stmt VisitStmt_(const SeqStmtNode* op) final { |
182 | auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(op, true)); |
183 | |
184 | bool need_compact = std::any_of(ret->seq.begin(), ret->seq.end(), |
185 | [](const auto& stmt) { return is_no_op(stmt); }); |
186 | |
187 | if (need_compact) { |
188 | Array<Stmt> filtered; |
189 | for (Stmt stmt : ret->seq) { |
190 | if (!is_no_op(stmt)) { |
191 | filtered.push_back(std::move(stmt)); |
192 | } |
193 | } |
194 | ret = SeqStmt(filtered); |
195 | } |
196 | |
197 | if (ret->size() == 0) { |
198 | return Evaluate(0); |
199 | } else if (ret->size() == 1) { |
200 | return ret->seq[0]; |
201 | } else { |
202 | return std::move(ret); |
203 | } |
204 | } |
205 | |
206 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
207 | BufferStore store = GetRef<BufferStore>(op); |
208 | |
209 | // Helper function that returns a statement containing only the |
210 | // side effects of evaluating this BufferStore, but not the store |
211 | // itself. |
212 | auto only_side_effects = [&]() { |
213 | Array<Stmt> statements; |
214 | statements.push_back(MakeEvaluate(store->value)); |
215 | for (const auto& index : store->indices) { |
216 | statements.push_back(MakeEvaluate(index)); |
217 | } |
218 | return this->VisitStmt(SeqStmt(statements)); |
219 | }; |
220 | |
221 | if (touch_pattern_.has_value()) { |
222 | // A write that is later overwritten is a no-op. |
223 | Stmt context = context_ ? GetRef<Stmt>(context_) : store; |
224 | if (touch_pattern_->IsOverwrittenWithoutEffect(store, context)) { |
225 | touch_pattern_->RemoveStore(store); |
226 | return only_side_effects(); |
227 | } |
228 | } |
229 | |
230 | // A write whose destination is known to already contain the |
231 | // values to be written is a no-op. |
232 | // PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices); |
233 | PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices) == 0; |
234 | if (touch_pattern_.has_value()) { |
235 | Stmt context_arg = context_ ? GetRef<Stmt>(context_) : Stmt(store); |
236 | stores_existing_value = |
237 | touch_pattern_->SimplifyInContext(stores_existing_value, context_arg, analyzer_); |
238 | } else { |
239 | stores_existing_value = analyzer_->Simplify(stores_existing_value); |
240 | } |
241 | if (is_one(stores_existing_value)) { |
242 | return only_side_effects(); |
243 | } |
244 | |
245 | // If the stored value is a load from the same location, the |
246 | // statement is a no-op, regardless of contextual information. |
247 | if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) { |
248 | if (load->buffer->data.same_as(store->buffer->data) && |
249 | analyzer_->CanProveEqual(load->buffer->elem_offset, store->buffer->elem_offset) && |
250 | ArrayValueEqual(load->buffer->shape, store->buffer->shape) && |
251 | ArrayValueEqual(load->buffer->strides, store->buffer->strides) && |
252 | ArrayValueEqual(load->indices, store->indices)) { |
253 | return only_side_effects(); |
254 | } |
255 | } |
256 | |
257 | return std::move(store); |
258 | } |
259 | |
260 | private: |
261 | bool ArrayValueEqual(const Array<PrimExpr>& a, const Array<PrimExpr>& b) { |
262 | if (a.size() != b.size()) { |
263 | return false; |
264 | } |
265 | for (size_t i = 0; i < a.size(); i++) { |
266 | if (!analyzer_->CanProveEqual(a[i], b[i])) { |
267 | return false; |
268 | } |
269 | } |
270 | return true; |
271 | } |
272 | |
273 | bool HasSideEffect(const PrimExpr& value) { |
274 | return SideEffect(value) > CallEffectKind::kReadState; |
275 | } |
276 | |
277 | Stmt MakeEvaluate(PrimExpr value) { |
278 | if (SideEffect(value) > CallEffectKind::kReadState) { |
279 | return Evaluate(value); |
280 | } else { |
281 | return Evaluate(0); |
282 | } |
283 | } |
284 | Stmt MakeEvaluate(const Array<PrimExpr>& values) { |
285 | Array<Stmt> stmts; |
286 | for (PrimExpr e : values) { |
287 | if (SideEffect(e) > CallEffectKind::kReadState) { |
288 | stmts.push_back(Evaluate(e)); |
289 | } |
290 | } |
291 | |
292 | if (stmts.size() == 0) { |
293 | return Evaluate(0); |
294 | } else if (stmts.size() == 1) { |
295 | return stmts[0]; |
296 | } else { |
297 | return SeqStmt(stmts); |
298 | } |
299 | } |
300 | |
301 | std::unordered_map<const VarNode*, arith::IntSet> var_range_map_; |
302 | std::optional<ControlFlowGraph> touch_pattern_; |
303 | const StmtNode* context_; |
304 | }; |
305 | |
306 | Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, std::optional<ControlFlowGraph> touch_pattern, |
307 | const StmtNode* context) { |
308 | return NoOpRemover::Apply(std::move(stmt), analyzer, std::move(touch_pattern), context); |
309 | } |
310 | |
311 | namespace transform { |
312 | |
313 | Pass RemoveNoOp() { |
314 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
315 | std::optional<ControlFlowGraph> touch_pattern = std::nullopt; |
316 | |
317 | RemoveNoOpConfig config = ctx->GetConfig<RemoveNoOpConfig>("tir.RemoveNoOp" ) |
318 | .value_or(AttrsWithDefaultValues<RemoveNoOpConfig>()); |
319 | if (config->use_dataflow_analysis) { |
320 | touch_pattern.emplace(f->body); |
321 | } |
322 | |
323 | arith::Analyzer analyzer; |
324 | |
325 | auto* n = f.CopyOnWrite(); |
326 | n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr); |
327 | return f; |
328 | }; |
329 | return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp" , {}); |
330 | } |
331 | |
// Register the pass constructor in the global function registry under the
// name "tir.transform.RemoveNoOp".
TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp);
333 | |
334 | } // namespace transform |
335 | |
336 | } // namespace tir |
337 | } // namespace tvm |
338 | |