/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | |
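/*!
 * \file tvm/arith/analyzer.cc
 * \brief Implementation of the Analyzer facade, which forwards variable bindings,
 *        constraint scopes and proof queries to its arithmetic sub-analyzers.
 */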
23 | #include <tvm/arith/analyzer.h> |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/expr.h> |
26 | #include <tvm/tir/op.h> |
27 | |
28 | namespace tvm { |
29 | namespace arith { |
30 | |
31 | Analyzer::Analyzer() |
32 | : const_int_bound(this), |
33 | modular_set(this), |
34 | rewrite_simplify(this), |
35 | canonical_simplify(this), |
36 | int_set(this) {} |
37 | |
38 | void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { |
39 | PrimExpr new_expr = expr; |
40 | new_expr = this->canonical_simplify(new_expr); |
41 | new_expr = this->rewrite_simplify(new_expr); |
42 | |
43 | this->const_int_bound.Update(var, this->const_int_bound(new_expr), allow_override); |
44 | this->modular_set.Update(var, this->modular_set(new_expr), allow_override); |
45 | this->rewrite_simplify.Update(var, new_expr, allow_override); |
46 | this->canonical_simplify.Update(var, new_expr, allow_override); |
47 | this->int_set.Update(var, this->int_set(new_expr), allow_override); |
48 | this->transitive_comparisons.Bind(var, expr, allow_override); |
49 | } |
50 | |
51 | void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { |
52 | ICHECK(range.defined()); |
53 | if (tir::is_one(range->extent)) { |
54 | this->Bind(var, range->min, allow_override); |
55 | } else { |
56 | this->const_int_bound.Bind(var, range, allow_override); |
57 | this->int_set.Bind(var, range, allow_override); |
58 | this->transitive_comparisons.Bind(var, range, allow_override); |
59 | } |
60 | // skip modular_set |
61 | // skip rewrite simplify |
62 | } |
63 | |
64 | void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) { |
65 | for (const auto& iter : variables) { |
66 | this->Bind(iter.first, iter.second, allow_override); |
67 | } |
68 | } |
69 | |
70 | void ConstraintContext::EnterWithScope() { |
71 | ICHECK(recovery_functions_.size() == 0); |
72 | // entering the scope. |
73 | recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_)); |
74 | recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_)); |
75 | recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_)); |
76 | recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_)); |
77 | recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_)); |
78 | } |
79 | |
80 | void ConstraintContext::ExitWithScope() { |
81 | while (recovery_functions_.size()) { |
82 | auto& func = recovery_functions_.back(); |
83 | if (func) { |
84 | func(); |
85 | } |
86 | recovery_functions_.pop_back(); |
87 | } |
88 | } |
89 | |
90 | bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { |
91 | if (const auto* ptr = expr.as<tir::IntImmNode>()) { |
92 | return ptr->value >= lower_bound; |
93 | } |
94 | auto bd = this->const_int_bound(this->rewrite_simplify(expr)); |
95 | if (bd->min_value >= lower_bound) return true; |
96 | return false; |
97 | } |
98 | |
99 | bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) { |
100 | if (const auto* ptr = expr.as<tir::IntImmNode>()) { |
101 | return ptr->value < upper_bound; |
102 | } |
103 | auto bd = this->const_int_bound(this->rewrite_simplify(expr)); |
104 | if (bd->max_value < upper_bound) return true; |
105 | return false; |
106 | } |
107 | |
108 | bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { |
109 | const auto* clhs = lhs.as<IntImmNode>(); |
110 | const auto* crhs = rhs.as<IntImmNode>(); |
111 | if (clhs && crhs) return clhs->value == crhs->value; |
112 | if (lhs->dtype.is_handle() || rhs->dtype.is_handle()) { |
113 | return lhs.same_as(rhs); |
114 | } |
115 | return CanProve(lhs - rhs == 0); |
116 | } |
117 | |
118 | bool Analyzer::CanProve(const PrimExpr& expr) { |
119 | // Avoid potentially expensive simplification unless required. |
120 | if (const auto* ptr = expr.as<IntImmNode>()) { |
121 | return ptr->value != 0; |
122 | } |
123 | |
124 | PrimExpr simplified = Simplify(expr); |
125 | const int64_t* as_int = tir::as_const_int(simplified); |
126 | return as_int && *as_int; |
127 | } |
128 | |
129 | PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { |
130 | PrimExpr res = expr; |
131 | |
132 | // Always starts with a canonical simplification, as some structural property |
133 | // of an expression might be destroyed by rewrite simplification. |
134 | res = this->canonical_simplify(res); |
135 | |
136 | for (int i = 0; i < steps; ++i) { |
137 | if (tir::is_const_int(res)) { |
138 | return res; |
139 | } |
140 | if (i % 2 == 0) { |
141 | res = this->rewrite_simplify(res); |
142 | } else { |
143 | res = this->canonical_simplify(res); |
144 | } |
145 | } |
146 | |
147 | return res; |
148 | } |
149 | |
150 | TVM_REGISTER_GLOBAL("arith.CreateAnalyzer" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
151 | using runtime::PackedFunc; |
152 | using runtime::TypedPackedFunc; |
153 | auto self = std::make_shared<Analyzer>(); |
154 | auto f = [self](std::string name) -> PackedFunc { |
155 | if (name == "const_int_bound" ) { |
156 | return PackedFunc( |
157 | [self](TVMArgs args, TVMRetValue* ret) { *ret = self->const_int_bound(args[0]); }); |
158 | } else if (name == "modular_set" ) { |
159 | return PackedFunc( |
160 | [self](TVMArgs args, TVMRetValue* ret) { *ret = self->modular_set(args[0]); }); |
161 | } else if (name == "const_int_bound_update" ) { |
162 | return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { |
163 | self->const_int_bound.Update(args[0], args[1], args[2]); |
164 | }); |
165 | } else if (name == "Simplify" ) { |
166 | return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { |
167 | if (args.size() == 1) { |
168 | *ret = self->Simplify(args[0]); |
169 | } else if (args.size() == 2) { |
170 | *ret = self->Simplify(args[0], args[1]); |
171 | } else { |
172 | LOG(FATAL) << "Invalid size of argument (" << args.size() << ")" ; |
173 | } |
174 | }); |
175 | } else if (name == "rewrite_simplify" ) { |
176 | return PackedFunc( |
177 | [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); }); |
178 | } else if (name == "canonical_simplify" ) { |
179 | return PackedFunc( |
180 | [self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); }); |
181 | } else if (name == "int_set" ) { |
182 | return PackedFunc( |
183 | [self](TVMArgs args, TVMRetValue* ret) { *ret = self->int_set(args[0], args[1]); }); |
184 | } else if (name == "bind" ) { |
185 | return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { |
186 | if (args[1].IsObjectRef<Range>()) { |
187 | self->Bind(args[0], args[1].operator Range()); |
188 | } else { |
189 | self->Bind(args[0], args[1].operator PrimExpr()); |
190 | } |
191 | }); |
192 | } else if (name == "enter_constraint_context" ) { |
193 | return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { |
194 | // can't use make_shared due to noexcept(false) decl in destructor, |
195 | // see https://stackoverflow.com/a/43907314 |
196 | auto ctx = std::shared_ptr<With<ConstraintContext>>( |
197 | new With<ConstraintContext>(self.get(), args[0])); |
198 | auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); }; |
199 | *ret = PackedFunc(fexit); |
200 | }); |
201 | } else if (name == "can_prove_equal" ) { |
202 | return PackedFunc( |
203 | [self](TVMArgs args, TVMRetValue* ret) { *ret = self->CanProveEqual(args[0], args[1]); }); |
204 | } |
205 | return PackedFunc(); |
206 | }; |
207 | *ret = TypedPackedFunc<PackedFunc(std::string)>(f); |
208 | }); |
209 | |
210 | } // namespace arith |
211 | } // namespace tvm |
212 | |