1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file simplify.cc
22 * \brief Statement simplifier based on analyzer
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/builtin.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/op.h>
30#include <tvm/tir/transform.h>
31
32#include <optional>
33
34#include "../../arith/ir_mutator_with_analyzer.h"
35#include "../../tir/analysis/control_flow_graph.h"
36
37namespace tvm {
38namespace arith {
39
40using namespace tir;
41
42struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
43 bool transitively_prove_inequalities;
44 bool propagate_knowns_to_prove_conditional;
45 bool propagate_knowns_to_simplify_expressions;
46 bool convert_boolean_to_and_of_ors;
47 bool apply_constraints_to_boolean_branches;
48
49 TVM_DECLARE_ATTRS(SimplifyConfigNode, "tir.transform.SimplifyConfig") {
50 TVM_ATTR_FIELD(transitively_prove_inequalities)
51 .describe(
52 "If true, simplify conditionals with transitive combinations of scoped constraints")
53 .set_default(false);
54
55 TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
56 .describe(
57 "If true, known buffer values are propagated and used to statically prove conditionals")
58 .set_default(false);
59
60 TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
61 .describe(
62 "If true, known buffer values are propagated and used to replace BufferLoad wherever "
63 "possible")
64 .set_default(false);
65
66 TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
67 .describe("If true, simplify conditionals into an AND of ORs")
68 .set_default(false);
69
70 TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
71 .describe(
72 "If true, simplify each branch of AND/OR "
73 "under a constraints provided by the other branch")
74 .set_default(false);
75 }
76
77 RewriteSimplifier::Extension GetEnabledExtensions() const {
78 RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
79 if (transitively_prove_inequalities) {
80 flags =
81 RewriteSimplifier::Extension(flags | RewriteSimplifier::kTransitivelyProveInequalities);
82 }
83 if (convert_boolean_to_and_of_ors) {
84 flags = RewriteSimplifier::Extension(flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
85 }
86 if (apply_constraints_to_boolean_branches) {
87 flags = RewriteSimplifier::Extension(flags |
88 RewriteSimplifier::kApplyConstraintsToBooleanBranches);
89 }
90 return flags;
91 }
92};
93
/*!
 * \brief Managed reference to SimplifyConfigNode.
 *
 * Its reference methods come from TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS.
 * This is the type the pass reads from the PassContext under "tir.Simplify".
 */
class SimplifyConfig : public Attrs {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode);
};
98
// Register the config node type with TVM's object system.
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
// Declare "tir.Simplify" as a PassContext config option whose value is a
// SimplifyConfig; the pass reads it back with GetConfig("tir.Simplify").
TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig);
101
102class StmtSimplifier : public IRMutatorWithAnalyzer {
103 public:
104 static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional<SimplifyConfig> config_opt = NullOpt) {
105 auto config = config_opt.value_or(AttrsWithDefaultValues<arith::SimplifyConfig>());
106 analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
107
108 std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
109 if (config->propagate_knowns_to_prove_conditional ||
110 config->propagate_knowns_to_simplify_expressions) {
111 touch_pattern = ControlFlowGraph(stmt);
112 }
113 StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern));
114 return simplifier(std::move(stmt));
115 }
116
117 private:
118 explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config,
119 std::optional<ControlFlowGraph> touch_pattern)
120 : IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {}
121
122 using Parent = IRMutatorWithAnalyzer;
123 using Parent::VisitStmt;
124 using Parent::VisitStmt_;
125
126 PrimExpr VisitExpr(const PrimExpr& expr) final {
127 if (config_->propagate_knowns_to_simplify_expressions) {
128 return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), analyzer_);
129 } else {
130 return analyzer_->Simplify(expr);
131 }
132 }
133
134 Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }
135
136 Stmt VisitStmt(const Stmt& stmt) override {
137 Optional<Stmt> cache = this->current_stmt_;
138 this->current_stmt_ = stmt;
139 Stmt output = Parent::VisitStmt(stmt);
140 this->current_stmt_ = std::move(cache);
141 return output;
142 }
143
144 Stmt VisitStmt_(const ForNode* op) final {
145 analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
146 With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
147 With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
148 return Parent::VisitStmt_(op);
149 }
150
151 bool CanInlineLetStmt(const LetStmtNode* op) {
152 if (is_const_number(op->value)) return true;
153 if (op->value.as<VarNode>()) return true;
154 // Won't face the deep expression explosion problem as in Let expression.
155 // attempt to inline as much as possible if the value integer type(can be index).
156 if (!op->value.dtype().is_int()) return false;
157 return SideEffect(op->value) <= CallEffectKind::kPure;
158 }
159
160 Stmt VisitStmt_(const LetStmtNode* op) override {
161 PrimExpr value = this->VisitExpr(op->value);
162 if (CanInlineLetStmt(op)) {
163 // it is fine to discard the let binding
164 // because the call to simplify will always inline the var.
165 analyzer_->Bind(op->var, value);
166 return this->VisitStmt(op->body);
167 } else if (SideEffect(op->value) <= CallEffectKind::kPure) {
168 // Even if we aren't replacing all occurrences, they may be
169 // necessary for proving conditional statements.
170 non_inlined_bindings_.Set(op->var, value);
171 }
172 Stmt body = this->VisitStmt(op->body);
173 if (value.same_as(op->value) && body.same_as(op->body)) {
174 return GetRef<Stmt>(op);
175 } else {
176 auto n = this->CopyOnWrite(op);
177 n->value = std::move(value);
178 n->body = std::move(body);
179 return Stmt(n);
180 }
181 }
182
183 Stmt VisitStmt_(const IfThenElseNode* op) override {
184 if (Optional<Bool> cond = ProveCondition(op->condition)) {
185 if (cond.value()->value) {
186 return this->VisitStmt(op->then_case);
187 } else if (op->else_case) {
188 return this->VisitStmt(op->else_case.value());
189 } else {
190 return Evaluate(0);
191 }
192 } else {
193 return Parent::VisitStmt_(op);
194 }
195 }
196
197 PrimExpr VisitExpr_(const CallNode* op) override {
198 if (op->op.same_as(builtin::if_then_else())) {
199 if (Optional<Bool> cond = ProveCondition(op->args[0])) {
200 if (cond.value()->value) {
201 return this->VisitExpr(op->args[1]);
202 } else {
203 return this->VisitExpr(op->args[2]);
204 }
205 }
206 }
207 return Parent::VisitExpr_(op);
208 }
209
210 Stmt VisitStmt_(const StoreNode* op) final {
211 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
212 }
213
214 // eliminate useless stores
215 Stmt VisitStmt_(const BufferStoreNode* op) final {
216 BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
217 if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
218 if (load->buffer->data.same_as(store->buffer->data) &&
219 ArrayDeepEqual(load->indices, store->indices) &&
220 tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) &&
221 ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
222 ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
223 return Evaluate(0);
224 }
225 }
226 return std::move(store);
227 }
228
229 private:
230 bool ArrayDeepEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
231 if (lhs.size() != rhs.size()) {
232 return false;
233 }
234 for (size_t i = 0; i < lhs.size(); i++) {
235 if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) {
236 return false;
237 }
238 }
239 return true;
240 }
241
242 /* \brief Internal utility for checking conditionals
243 *
244 * Uses more aggressive optimization, such as performing additional
245 * inlining and tracking known buffer values.
246 */
247 Optional<Bool> ProveCondition(PrimExpr condition) const {
248 condition = Substitute(condition, non_inlined_bindings_);
249 if (config_->propagate_knowns_to_prove_conditional) {
250 ICHECK(touch_pattern_.has_value());
251 condition = touch_pattern_->SimplifyInContext(condition, current_stmt_.value(), analyzer_);
252 } else {
253 condition = analyzer_->Simplify(condition);
254 }
255 if (const int64_t* as_int = as_const_int(condition)) {
256 return Bool(*as_int);
257 } else {
258 return NullOpt;
259 }
260 }
261
262 SimplifyConfig config_;
263 std::optional<ControlFlowGraph> touch_pattern_;
264
265 Map<Var, PrimExpr> non_inlined_bindings_;
266 Optional<Stmt> current_stmt_{NullOpt};
267};
268
269} // namespace arith
270
271namespace tir {
272
273Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) {
274 return arith::StmtSimplifier::Apply(stmt, analyzer);
275}
276
277namespace transform {
278
279Pass Simplify() {
280 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
281 arith::Analyzer analyzer;
282 auto cfg = ctx->GetConfig<arith::SimplifyConfig>("tir.Simplify");
283
284 auto* n = f.CopyOnWrite();
285 n->body = arith::StmtSimplifier::Apply(std::move(n->body), &analyzer, cfg);
286 return f;
287 };
288 return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
289}
290
// Expose the pass constructor transform::Simplify as the global function
// "tir.transform.Simplify".
TVM_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify);
292
293} // namespace transform
294} // namespace tir
295} // namespace tvm
296