1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/arith/analyzer.cc
22 */
23#include <tvm/arith/analyzer.h>
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/expr.h>
26#include <tvm/tir/op.h>
27
28namespace tvm {
29namespace arith {
30
31Analyzer::Analyzer()
32 : const_int_bound(this),
33 modular_set(this),
34 rewrite_simplify(this),
35 canonical_simplify(this),
36 int_set(this) {}
37
38void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
39 PrimExpr new_expr = expr;
40 new_expr = this->canonical_simplify(new_expr);
41 new_expr = this->rewrite_simplify(new_expr);
42
43 this->const_int_bound.Update(var, this->const_int_bound(new_expr), allow_override);
44 this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
45 this->rewrite_simplify.Update(var, new_expr, allow_override);
46 this->canonical_simplify.Update(var, new_expr, allow_override);
47 this->int_set.Update(var, this->int_set(new_expr), allow_override);
48 this->transitive_comparisons.Bind(var, expr, allow_override);
49}
50
51void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
52 ICHECK(range.defined());
53 if (tir::is_one(range->extent)) {
54 this->Bind(var, range->min, allow_override);
55 } else {
56 this->const_int_bound.Bind(var, range, allow_override);
57 this->int_set.Bind(var, range, allow_override);
58 this->transitive_comparisons.Bind(var, range, allow_override);
59 }
60 // skip modular_set
61 // skip rewrite simplify
62}
63
64void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
65 for (const auto& iter : variables) {
66 this->Bind(iter.first, iter.second, allow_override);
67 }
68}
69
70void ConstraintContext::EnterWithScope() {
71 ICHECK(recovery_functions_.size() == 0);
72 // entering the scope.
73 recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_));
74 recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_));
75 recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
76 recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
77 recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_));
78}
79
80void ConstraintContext::ExitWithScope() {
81 while (recovery_functions_.size()) {
82 auto& func = recovery_functions_.back();
83 if (func) {
84 func();
85 }
86 recovery_functions_.pop_back();
87 }
88}
89
90bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
91 if (const auto* ptr = expr.as<tir::IntImmNode>()) {
92 return ptr->value >= lower_bound;
93 }
94 auto bd = this->const_int_bound(this->rewrite_simplify(expr));
95 if (bd->min_value >= lower_bound) return true;
96 return false;
97}
98
99bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
100 if (const auto* ptr = expr.as<tir::IntImmNode>()) {
101 return ptr->value < upper_bound;
102 }
103 auto bd = this->const_int_bound(this->rewrite_simplify(expr));
104 if (bd->max_value < upper_bound) return true;
105 return false;
106}
107
108bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) {
109 const auto* clhs = lhs.as<IntImmNode>();
110 const auto* crhs = rhs.as<IntImmNode>();
111 if (clhs && crhs) return clhs->value == crhs->value;
112 if (lhs->dtype.is_handle() || rhs->dtype.is_handle()) {
113 return lhs.same_as(rhs);
114 }
115 return CanProve(lhs - rhs == 0);
116}
117
118bool Analyzer::CanProve(const PrimExpr& expr) {
119 // Avoid potentially expensive simplification unless required.
120 if (const auto* ptr = expr.as<IntImmNode>()) {
121 return ptr->value != 0;
122 }
123
124 PrimExpr simplified = Simplify(expr);
125 const int64_t* as_int = tir::as_const_int(simplified);
126 return as_int && *as_int;
127}
128
129PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
130 PrimExpr res = expr;
131
132 // Always starts with a canonical simplification, as some structural property
133 // of an expression might be destroyed by rewrite simplification.
134 res = this->canonical_simplify(res);
135
136 for (int i = 0; i < steps; ++i) {
137 if (tir::is_const_int(res)) {
138 return res;
139 }
140 if (i % 2 == 0) {
141 res = this->rewrite_simplify(res);
142 } else {
143 res = this->canonical_simplify(res);
144 }
145 }
146
147 return res;
148}
149
150TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) {
151 using runtime::PackedFunc;
152 using runtime::TypedPackedFunc;
153 auto self = std::make_shared<Analyzer>();
154 auto f = [self](std::string name) -> PackedFunc {
155 if (name == "const_int_bound") {
156 return PackedFunc(
157 [self](TVMArgs args, TVMRetValue* ret) { *ret = self->const_int_bound(args[0]); });
158 } else if (name == "modular_set") {
159 return PackedFunc(
160 [self](TVMArgs args, TVMRetValue* ret) { *ret = self->modular_set(args[0]); });
161 } else if (name == "const_int_bound_update") {
162 return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
163 self->const_int_bound.Update(args[0], args[1], args[2]);
164 });
165 } else if (name == "Simplify") {
166 return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
167 if (args.size() == 1) {
168 *ret = self->Simplify(args[0]);
169 } else if (args.size() == 2) {
170 *ret = self->Simplify(args[0], args[1]);
171 } else {
172 LOG(FATAL) << "Invalid size of argument (" << args.size() << ")";
173 }
174 });
175 } else if (name == "rewrite_simplify") {
176 return PackedFunc(
177 [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); });
178 } else if (name == "canonical_simplify") {
179 return PackedFunc(
180 [self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); });
181 } else if (name == "int_set") {
182 return PackedFunc(
183 [self](TVMArgs args, TVMRetValue* ret) { *ret = self->int_set(args[0], args[1]); });
184 } else if (name == "bind") {
185 return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
186 if (args[1].IsObjectRef<Range>()) {
187 self->Bind(args[0], args[1].operator Range());
188 } else {
189 self->Bind(args[0], args[1].operator PrimExpr());
190 }
191 });
192 } else if (name == "enter_constraint_context") {
193 return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
194 // can't use make_shared due to noexcept(false) decl in destructor,
195 // see https://stackoverflow.com/a/43907314
196 auto ctx = std::shared_ptr<With<ConstraintContext>>(
197 new With<ConstraintContext>(self.get(), args[0]));
198 auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); };
199 *ret = PackedFunc(fexit);
200 });
201 } else if (name == "can_prove_equal") {
202 return PackedFunc(
203 [self](TVMArgs args, TVMRetValue* ret) { *ret = self->CanProveEqual(args[0], args[1]); });
204 }
205 return PackedFunc();
206 };
207 *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
208});
209
210} // namespace arith
211} // namespace tvm
212