1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file remove_no_op.cc
 * \brief Remove no-op statements from a TIR statement.
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/op.h>
28#include <tvm/tir/stmt.h>
29#include <tvm/tir/stmt_functor.h>
30#include <tvm/tir/transform.h>
31
32#include <optional>
33#include <unordered_map>
34
35#include "../../arith/const_fold.h"
36#include "../../arith/ir_mutator_with_analyzer.h"
37#include "../analysis/control_flow_graph.h"
38#include "ir_utils.h"
39
40namespace tvm {
41namespace tir {
42
/*!
 * \brief Attributes configuring tir.transform.RemoveNoOp.
 *
 * Registered as the "tir.RemoveNoOp" PassContext config option.
 */
struct RemoveNoOpConfigNode : public tvm::AttrsNode<RemoveNoOpConfigNode> {
  /*! \brief Enable dataflow-based proofs of no-op statements (default: false). */
  bool use_dataflow_analysis;

  TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") {
    TVM_ATTR_FIELD(use_dataflow_analysis)
        .describe(
            "If true, known buffer values are propagated and used "
            "to statically prove statements as no-ops.")
        .set_default(false);
  }
};
54
/*! \brief Managed (non-nullable) reference to RemoveNoOpConfigNode. */
class RemoveNoOpConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode);
};
59
// Register the config node type, and expose RemoveNoOpConfig as the
// "tir.RemoveNoOp" PassContext config option.
TVM_REGISTER_NODE_TYPE(RemoveNoOpConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.RemoveNoOp", RemoveNoOpConfig);
62
63// Mark the statement of each stage.
64class NoOpRemover : public arith::IRMutatorWithAnalyzer {
65 public:
66 static Stmt Apply(Stmt stmt, arith::Analyzer* analyzer,
67 std::optional<ControlFlowGraph> touch_pattern, const StmtNode* context) {
68 NoOpRemover visitor(analyzer, touch_pattern, context);
69 return visitor(std::move(stmt));
70 }
71
72 private:
73 using Parent = IRMutatorWithAnalyzer;
74 using Parent::VisitStmt;
75 using Parent::VisitStmt_;
76
77 NoOpRemover(arith::Analyzer* analyzer, std::optional<ControlFlowGraph> touch_pattern,
78 const StmtNode* context)
79 : Parent(analyzer), touch_pattern_(touch_pattern), context_(context) {}
80
81 Stmt VisitStmt_(const LetStmtNode* op) final {
82 Stmt stmt = Parent::VisitStmt_(op);
83 op = stmt.as<LetStmtNode>();
84 if (is_no_op(op->body)) {
85 return MakeEvaluate(op->value);
86 }
87
88 bool body_uses_bound_variable =
89 !UsesVar(op->body, [&](const VarNode* var) { return var == op->var.get(); });
90 if (body_uses_bound_variable && HasSideEffect(op->value)) {
91 return SeqStmt({MakeEvaluate(op->value), op->body});
92 } else if (body_uses_bound_variable) {
93 return op->body;
94 } else {
95 return stmt;
96 }
97 }
98 Stmt VisitStmt_(const AttrStmtNode* op) final {
99 if (op->attr_key == "pragma_debug_skip_region") {
100 return MakeEvaluate(0);
101 } else if (op->attr_key == attr::async_wait_queue_scope) {
102 auto wait_attrs = GetAsyncWaitAttributes(op);
103 auto wait_cnt = wait_attrs.second;
104 arith::Analyzer ana;
105 if (ana.CanProve(wait_cnt < 0)) {
106 // A negative wait count can arise if it depends on a loop variable.
107 // For example, a wait count 1 - i can be negative after loop unrolling.
108 // We assume that such wait is a nop.
109 auto inner = op->body.as<AttrStmtNode>();
110 ICHECK(inner);
111 return Parent::VisitStmt(inner->body);
112 }
113 }
114
115 Stmt stmt = Parent::VisitStmt_(op);
116 op = stmt.as<AttrStmtNode>();
117 return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
118 }
119 Stmt VisitStmt_(const IfThenElseNode* op) final {
120 Stmt stmt = Parent::VisitStmt_(op);
121 op = stmt.as<IfThenElseNode>();
122 // Sometimes the condition can be statically determined,
123 // in which the type of the `stmt` will not be IfThenElseNode.
124 if (!op) {
125 return stmt;
126 }
127 if (op->else_case) {
128 bool no_op_else = is_no_op(op->else_case.value());
129 bool no_op_then = is_no_op(op->then_case);
130 if (no_op_else && no_op_then) {
131 return MakeEvaluate(op->condition);
132 } else if (no_op_else) {
133 return IfThenElse(op->condition, op->then_case);
134 } else if (no_op_then) {
135 return IfThenElse(!op->condition, op->else_case.value());
136 } else {
137 return stmt;
138 }
139 } else {
140 if (is_no_op(op->then_case)) {
141 return MakeEvaluate(op->condition);
142 } else {
143 return stmt;
144 }
145 }
146 }
147 Stmt VisitStmt_(const ForNode* op) final {
148 auto extent_range = arith::EvalSet(op->extent, var_range_map_);
149 if (!arith::is_neg_inf(extent_range.max()) && !arith::is_pos_inf(extent_range.max()) &&
150 analyzer_->CanProve(extent_range.max() <= 0)) {
151 return Evaluate(0);
152 }
153 var_range_map_[op->loop_var.get()] = arith::IntSet::FromMinExtent(op->min, op->extent);
154 Stmt stmt = Parent::VisitStmt_(op);
155 var_range_map_.erase(op->loop_var.get());
156 op = stmt.as<ForNode>();
157 if (is_zero(op->extent)) {
158 return Evaluate(0);
159 }
160 return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt;
161 }
162 Stmt VisitStmt_(const AllocateNode* op) final {
163 Stmt stmt = StmtMutator::VisitStmt_(op);
164 op = stmt.as<AllocateNode>();
165 return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt;
166 }
167
168 Stmt VisitStmt_(const ProducerRealizeNode* op) final {
169 Stmt stmt = StmtMutator::VisitStmt_(op);
170 op = stmt.as<ProducerRealizeNode>();
171 return is_no_op(op->body) ? op->body : stmt;
172 }
173 Stmt VisitStmt_(const EvaluateNode* op) final {
174 if (HasSideEffect(op->value)) {
175 return GetRef<Stmt>(op);
176 } else {
177 return Evaluate(0);
178 }
179 }
180
181 Stmt VisitStmt_(const SeqStmtNode* op) final {
182 auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(op, true));
183
184 bool need_compact = std::any_of(ret->seq.begin(), ret->seq.end(),
185 [](const auto& stmt) { return is_no_op(stmt); });
186
187 if (need_compact) {
188 Array<Stmt> filtered;
189 for (Stmt stmt : ret->seq) {
190 if (!is_no_op(stmt)) {
191 filtered.push_back(std::move(stmt));
192 }
193 }
194 ret = SeqStmt(filtered);
195 }
196
197 if (ret->size() == 0) {
198 return Evaluate(0);
199 } else if (ret->size() == 1) {
200 return ret->seq[0];
201 } else {
202 return std::move(ret);
203 }
204 }
205
206 Stmt VisitStmt_(const BufferStoreNode* op) final {
207 BufferStore store = GetRef<BufferStore>(op);
208
209 // Helper function that returns a statement containing only the
210 // side effects of evaluating this BufferStore, but not the store
211 // itself.
212 auto only_side_effects = [&]() {
213 Array<Stmt> statements;
214 statements.push_back(MakeEvaluate(store->value));
215 for (const auto& index : store->indices) {
216 statements.push_back(MakeEvaluate(index));
217 }
218 return this->VisitStmt(SeqStmt(statements));
219 };
220
221 if (touch_pattern_.has_value()) {
222 // A write that is later overwritten is a no-op.
223 Stmt context = context_ ? GetRef<Stmt>(context_) : store;
224 if (touch_pattern_->IsOverwrittenWithoutEffect(store, context)) {
225 touch_pattern_->RemoveStore(store);
226 return only_side_effects();
227 }
228 }
229
230 // A write whose destination is known to already contain the
231 // values to be written is a no-op.
232 // PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices);
233 PrimExpr stores_existing_value = store->value - BufferLoad(store->buffer, store->indices) == 0;
234 if (touch_pattern_.has_value()) {
235 Stmt context_arg = context_ ? GetRef<Stmt>(context_) : Stmt(store);
236 stores_existing_value =
237 touch_pattern_->SimplifyInContext(stores_existing_value, context_arg, analyzer_);
238 } else {
239 stores_existing_value = analyzer_->Simplify(stores_existing_value);
240 }
241 if (is_one(stores_existing_value)) {
242 return only_side_effects();
243 }
244
245 // If the stored value is a load from the same location, the
246 // statement is a no-op, regardless of contextual information.
247 if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
248 if (load->buffer->data.same_as(store->buffer->data) &&
249 analyzer_->CanProveEqual(load->buffer->elem_offset, store->buffer->elem_offset) &&
250 ArrayValueEqual(load->buffer->shape, store->buffer->shape) &&
251 ArrayValueEqual(load->buffer->strides, store->buffer->strides) &&
252 ArrayValueEqual(load->indices, store->indices)) {
253 return only_side_effects();
254 }
255 }
256
257 return std::move(store);
258 }
259
260 private:
261 bool ArrayValueEqual(const Array<PrimExpr>& a, const Array<PrimExpr>& b) {
262 if (a.size() != b.size()) {
263 return false;
264 }
265 for (size_t i = 0; i < a.size(); i++) {
266 if (!analyzer_->CanProveEqual(a[i], b[i])) {
267 return false;
268 }
269 }
270 return true;
271 }
272
273 bool HasSideEffect(const PrimExpr& value) {
274 return SideEffect(value) > CallEffectKind::kReadState;
275 }
276
277 Stmt MakeEvaluate(PrimExpr value) {
278 if (SideEffect(value) > CallEffectKind::kReadState) {
279 return Evaluate(value);
280 } else {
281 return Evaluate(0);
282 }
283 }
284 Stmt MakeEvaluate(const Array<PrimExpr>& values) {
285 Array<Stmt> stmts;
286 for (PrimExpr e : values) {
287 if (SideEffect(e) > CallEffectKind::kReadState) {
288 stmts.push_back(Evaluate(e));
289 }
290 }
291
292 if (stmts.size() == 0) {
293 return Evaluate(0);
294 } else if (stmts.size() == 1) {
295 return stmts[0];
296 } else {
297 return SeqStmt(stmts);
298 }
299 }
300
301 std::unordered_map<const VarNode*, arith::IntSet> var_range_map_;
302 std::optional<ControlFlowGraph> touch_pattern_;
303 const StmtNode* context_;
304};
305
306Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, std::optional<ControlFlowGraph> touch_pattern,
307 const StmtNode* context) {
308 return NoOpRemover::Apply(std::move(stmt), analyzer, std::move(touch_pattern), context);
309}
310
311namespace transform {
312
313Pass RemoveNoOp() {
314 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
315 std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
316
317 RemoveNoOpConfig config = ctx->GetConfig<RemoveNoOpConfig>("tir.RemoveNoOp")
318 .value_or(AttrsWithDefaultValues<RemoveNoOpConfig>());
319 if (config->use_dataflow_analysis) {
320 touch_pattern.emplace(f->body);
321 }
322
323 arith::Analyzer analyzer;
324
325 auto* n = f.CopyOnWrite();
326 n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr);
327 return f;
328 };
329 return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
330}
331
// Register the pass constructor in the global function registry under
// "tir.transform.RemoveNoOp".
TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp);
333
334} // namespace transform
335
336} // namespace tir
337} // namespace tvm
338