1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file simplify.cc |
22 | * \brief Statement simplifier based on analyzer |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | #include <tvm/tir/transform.h> |
31 | |
32 | #include <optional> |
33 | |
34 | #include "../../arith/ir_mutator_with_analyzer.h" |
35 | #include "../../tir/analysis/control_flow_graph.h" |
36 | |
37 | namespace tvm { |
38 | namespace arith { |
39 | |
40 | using namespace tir; |
41 | |
42 | struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> { |
43 | bool transitively_prove_inequalities; |
44 | bool propagate_knowns_to_prove_conditional; |
45 | bool propagate_knowns_to_simplify_expressions; |
46 | bool convert_boolean_to_and_of_ors; |
47 | bool apply_constraints_to_boolean_branches; |
48 | |
49 | TVM_DECLARE_ATTRS(SimplifyConfigNode, "tir.transform.SimplifyConfig" ) { |
50 | TVM_ATTR_FIELD(transitively_prove_inequalities) |
51 | .describe( |
52 | "If true, simplify conditionals with transitive combinations of scoped constraints" ) |
53 | .set_default(false); |
54 | |
55 | TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional) |
56 | .describe( |
57 | "If true, known buffer values are propagated and used to statically prove conditionals" ) |
58 | .set_default(false); |
59 | |
60 | TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions) |
61 | .describe( |
62 | "If true, known buffer values are propagated and used to replace BufferLoad wherever " |
63 | "possible" ) |
64 | .set_default(false); |
65 | |
66 | TVM_ATTR_FIELD(convert_boolean_to_and_of_ors) |
67 | .describe("If true, simplify conditionals into an AND of ORs" ) |
68 | .set_default(false); |
69 | |
70 | TVM_ATTR_FIELD(apply_constraints_to_boolean_branches) |
71 | .describe( |
72 | "If true, simplify each branch of AND/OR " |
73 | "under a constraints provided by the other branch" ) |
74 | .set_default(false); |
75 | } |
76 | |
77 | RewriteSimplifier::Extension GetEnabledExtensions() const { |
78 | RewriteSimplifier::Extension flags = RewriteSimplifier::kNone; |
79 | if (transitively_prove_inequalities) { |
80 | flags = |
81 | RewriteSimplifier::Extension(flags | RewriteSimplifier::kTransitivelyProveInequalities); |
82 | } |
83 | if (convert_boolean_to_and_of_ors) { |
84 | flags = RewriteSimplifier::Extension(flags | RewriteSimplifier::kConvertBooleanToAndOfOrs); |
85 | } |
86 | if (apply_constraints_to_boolean_branches) { |
87 | flags = RewriteSimplifier::Extension(flags | |
88 | RewriteSimplifier::kApplyConstraintsToBooleanBranches); |
89 | } |
90 | return flags; |
91 | } |
92 | }; |
93 | |
/*!
 * \brief Reference class for SimplifyConfigNode, derived from Attrs.
 *
 * Its reference methods are generated by
 * TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS.
 */
94 | class SimplifyConfig : public Attrs { |
95 | public: |
96 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode); |
97 | }; |
98 | |
// Register the config node type, and register SimplifyConfig as the pass
// config option "tir.Simplify" (the same key that transform::Simplify()
// queries through PassContext::GetConfig).
99 | TVM_REGISTER_NODE_TYPE(SimplifyConfigNode); |
100 | TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify" , SimplifyConfig); |
101 | |
102 | class StmtSimplifier : public IRMutatorWithAnalyzer { |
103 | public: |
104 | static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional<SimplifyConfig> config_opt = NullOpt) { |
105 | auto config = config_opt.value_or(AttrsWithDefaultValues<arith::SimplifyConfig>()); |
106 | analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); |
107 | |
108 | std::optional<ControlFlowGraph> touch_pattern = std::nullopt; |
109 | if (config->propagate_knowns_to_prove_conditional || |
110 | config->propagate_knowns_to_simplify_expressions) { |
111 | touch_pattern = ControlFlowGraph(stmt); |
112 | } |
113 | StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern)); |
114 | return simplifier(std::move(stmt)); |
115 | } |
116 | |
117 | private: |
118 | explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config, |
119 | std::optional<ControlFlowGraph> touch_pattern) |
120 | : IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {} |
121 | |
122 | using Parent = IRMutatorWithAnalyzer; |
123 | using Parent::VisitStmt; |
124 | using Parent::VisitStmt_; |
125 | |
126 | PrimExpr VisitExpr(const PrimExpr& expr) final { |
127 | if (config_->propagate_knowns_to_simplify_expressions) { |
128 | return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), analyzer_); |
129 | } else { |
130 | return analyzer_->Simplify(expr); |
131 | } |
132 | } |
133 | |
134 | Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } |
135 | |
136 | Stmt VisitStmt(const Stmt& stmt) override { |
137 | Optional<Stmt> cache = this->current_stmt_; |
138 | this->current_stmt_ = stmt; |
139 | Stmt output = Parent::VisitStmt(stmt); |
140 | this->current_stmt_ = std::move(cache); |
141 | return output; |
142 | } |
143 | |
144 | Stmt VisitStmt_(const ForNode* op) final { |
145 | analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); |
146 | With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min); |
147 | With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent); |
148 | return Parent::VisitStmt_(op); |
149 | } |
150 | |
151 | bool CanInlineLetStmt(const LetStmtNode* op) { |
152 | if (is_const_number(op->value)) return true; |
153 | if (op->value.as<VarNode>()) return true; |
154 | // Won't face the deep expression explosion problem as in Let expression. |
155 | // attempt to inline as much as possible if the value integer type(can be index). |
156 | if (!op->value.dtype().is_int()) return false; |
157 | return SideEffect(op->value) <= CallEffectKind::kPure; |
158 | } |
159 | |
160 | Stmt VisitStmt_(const LetStmtNode* op) override { |
161 | PrimExpr value = this->VisitExpr(op->value); |
162 | if (CanInlineLetStmt(op)) { |
163 | // it is fine to discard the let binding |
164 | // because the call to simplify will always inline the var. |
165 | analyzer_->Bind(op->var, value); |
166 | return this->VisitStmt(op->body); |
167 | } else if (SideEffect(op->value) <= CallEffectKind::kPure) { |
168 | // Even if we aren't replacing all occurrences, they may be |
169 | // necessary for proving conditional statements. |
170 | non_inlined_bindings_.Set(op->var, value); |
171 | } |
172 | Stmt body = this->VisitStmt(op->body); |
173 | if (value.same_as(op->value) && body.same_as(op->body)) { |
174 | return GetRef<Stmt>(op); |
175 | } else { |
176 | auto n = this->CopyOnWrite(op); |
177 | n->value = std::move(value); |
178 | n->body = std::move(body); |
179 | return Stmt(n); |
180 | } |
181 | } |
182 | |
183 | Stmt VisitStmt_(const IfThenElseNode* op) override { |
184 | if (Optional<Bool> cond = ProveCondition(op->condition)) { |
185 | if (cond.value()->value) { |
186 | return this->VisitStmt(op->then_case); |
187 | } else if (op->else_case) { |
188 | return this->VisitStmt(op->else_case.value()); |
189 | } else { |
190 | return Evaluate(0); |
191 | } |
192 | } else { |
193 | return Parent::VisitStmt_(op); |
194 | } |
195 | } |
196 | |
197 | PrimExpr VisitExpr_(const CallNode* op) override { |
198 | if (op->op.same_as(builtin::if_then_else())) { |
199 | if (Optional<Bool> cond = ProveCondition(op->args[0])) { |
200 | if (cond.value()->value) { |
201 | return this->VisitExpr(op->args[1]); |
202 | } else { |
203 | return this->VisitExpr(op->args[2]); |
204 | } |
205 | } |
206 | } |
207 | return Parent::VisitExpr_(op); |
208 | } |
209 | |
210 | Stmt VisitStmt_(const StoreNode* op) final { |
211 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
212 | } |
213 | |
214 | // eliminate useless stores |
215 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
216 | BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op)); |
217 | if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) { |
218 | if (load->buffer->data.same_as(store->buffer->data) && |
219 | ArrayDeepEqual(load->indices, store->indices) && |
220 | tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) && |
221 | ArrayDeepEqual(load->buffer->shape, store->buffer->shape) && |
222 | ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) { |
223 | return Evaluate(0); |
224 | } |
225 | } |
226 | return std::move(store); |
227 | } |
228 | |
229 | private: |
230 | bool ArrayDeepEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) { |
231 | if (lhs.size() != rhs.size()) { |
232 | return false; |
233 | } |
234 | for (size_t i = 0; i < lhs.size(); i++) { |
235 | if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) { |
236 | return false; |
237 | } |
238 | } |
239 | return true; |
240 | } |
241 | |
242 | /* \brief Internal utility for checking conditionals |
243 | * |
244 | * Uses more aggressive optimization, such as performing additional |
245 | * inlining and tracking known buffer values. |
246 | */ |
247 | Optional<Bool> ProveCondition(PrimExpr condition) const { |
248 | condition = Substitute(condition, non_inlined_bindings_); |
249 | if (config_->propagate_knowns_to_prove_conditional) { |
250 | ICHECK(touch_pattern_.has_value()); |
251 | condition = touch_pattern_->SimplifyInContext(condition, current_stmt_.value(), analyzer_); |
252 | } else { |
253 | condition = analyzer_->Simplify(condition); |
254 | } |
255 | if (const int64_t* as_int = as_const_int(condition)) { |
256 | return Bool(*as_int); |
257 | } else { |
258 | return NullOpt; |
259 | } |
260 | } |
261 | |
262 | SimplifyConfig config_; |
263 | std::optional<ControlFlowGraph> touch_pattern_; |
264 | |
265 | Map<Var, PrimExpr> non_inlined_bindings_; |
266 | Optional<Stmt> current_stmt_{NullOpt}; |
267 | }; |
268 | |
269 | } // namespace arith |
270 | |
271 | namespace tir { |
272 | |
273 | Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) { |
274 | return arith::StmtSimplifier::Apply(stmt, analyzer); |
275 | } |
276 | |
277 | namespace transform { |
278 | |
279 | Pass Simplify() { |
280 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
281 | arith::Analyzer analyzer; |
282 | auto cfg = ctx->GetConfig<arith::SimplifyConfig>("tir.Simplify" ); |
283 | |
284 | auto* n = f.CopyOnWrite(); |
285 | n->body = arith::StmtSimplifier::Apply(std::move(n->body), &analyzer, cfg); |
286 | return f; |
287 | }; |
288 | return CreatePrimFuncPass(pass_func, 0, "tir.Simplify" , {}); |
289 | } |
290 | |
// Register transform::Simplify under the global function name
// "tir.transform.Simplify".
291 | TVM_REGISTER_GLOBAL("tir.transform.Simplify" ).set_body_typed(Simplify); |
292 | |
293 | } // namespace transform |
294 | } // namespace tir |
295 | } // namespace tvm |
296 | |