1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/arith/conjunctive_normal_form.cc
22 */
23
#include "conjunctive_normal_form.h"

#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>

#include <algorithm>
#include <functional>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "pattern_match.h"
#include "rewrite_simplify.h"
37
38namespace tvm {
39namespace arith {
40
41namespace {
42/* \brief A utility for simplifying expressions using conjunctive/disjuctive normal forms */
43class AndOfOrs {
44 public:
45 /*! \brief Construct the simplifier
46 *
47 * Convert a PrimExpr to the internal representation.
48 *
49 * \param expr The PrimExpr to be simplified.
50 */
51 explicit AndOfOrs(const PrimExpr& expr);
52
53 /*! \brief Convert internal representation to PrimExpr */
54 PrimExpr AsPrimExpr() const;
55
56 /*! \brief Simplify the internal representation */
57 void Simplify(Analyzer* analyzer);
58
59 private:
60 /*! \brief Internal utility, simplify within each group of expressions
61 *
62 * For each pair of values within a chunk, attempt to simplify them into
63 * a single expression.
64 *
65 * For example,
66 * before = (a == 5) && ((b < 10) || (b > 10))
67 * after = (a == 5) && ((b != 10) || false)
68 */
69 void SimplifyWithinChunks(Analyzer* analyzer);
70
71 /*! \brief Internal utility, simplify across groups of expressions
72 *
73 * For each pair of chunks, if the two chunks differ by only a single
74 * term, attempt to simplify those differing terms.
75 *
76 * For example,
77 * before = ((a == 5) || (b <= 10)) && ((a == 5) || (b >= 10))
78 * after = ((a == 5) || (b == 10)) && ((a == 5) || true)
79 */
80 void SimplifyAcrossChunks(Analyzer* analyzer);
81
82 /*! \brief Remove instances of true/false from internal representation
83 *
84 * To avoid invalidating iterators, `SimplifyWithinChunks` and
85 * `SimplifyAcrossChunks` may replace keys, but may not remove keys
86 * from the internal representation. For example, `(a < 5) && (a <
87 * 10)` would be simplified to `(a < 5) && true`. The
88 * `RemoveTrueFalse` function removes these leftover instances of
89 * true/false.
90 */
91 void RemoveTrueFalse();
92
93 /*! \brief Internal utility function used to convert to internal form */
94 static void VisitAndExpressions(const PrimExpr& expr,
95 std::function<void(const PrimExpr&)> callback);
96 /*! \brief Internal utility function used to convert to internal form */
97 static void VisitOrExpressions(const PrimExpr& expr,
98 std::function<void(const PrimExpr&)> callback);
99
100 /* \brief Type-safe wrapper class that represents an PrimExpr
101 *
102 * Because integer indices are used frequently through this class,
103 * maintaining a separation between integer indices used to access
104 * specific elements of the internal representation, and unique
105 * identifiers used to represent expressions PrimExpr, is useful.
106 */
107 enum class Key : size_t {};
108
109 /*! \brief Convert a PrimExpr to a Key */
110 Key GetKey(const PrimExpr& expr);
111
112 /*! \brief Convert a Key to a PrimExpr */
113 PrimExpr GetExpr(Key key) const;
114
115 /*! \brief Attempt to simplify (a && b)
116 *
117 * If successful, will overwrite the parameters `a` and `b` with the
118 * simplified form.
119 */
120 void TrySimplifyOr(Key* a, Key* b, Analyzer* analyzer);
121
122 /*! \brief Attempt to simplify (a || b)
123 *
124 * If successful, will overwrite the parameters `a` and `b` with the
125 * simplified form.
126 */
127 void TrySimplifyAnd(Key* a, Key* b, Analyzer* analyzer);
128
129 /*! \brief The internal representation
130 *
131 * `chunks[i][j]` is the j-th expression in the i-th OR-group.
132 */
133 std::vector<std::vector<Key>> chunks_;
134
135 /*! \brief Mapping from internal Key to PrimExpr */
136 std::unordered_map<Key, PrimExpr, StructuralHash, StructuralEqual> key_to_expr_;
137
138 /*! \brief Mapping from PrimExpr to internal Key */
139 std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key_;
140
141 /*! \brief Cached key representing tir::Bool(true) */
142 Key key_true_;
143
144 /*! \brief Cached key representing tir::Bool(false) */
145 Key key_false_;
146};
147
148AndOfOrs::AndOfOrs(const PrimExpr& expr)
149 : key_true_(GetKey(Bool(true))), key_false_(GetKey(Bool(false))) {
150 VisitAndExpressions(expr, [&](const PrimExpr& outer_expr) {
151 std::vector<Key> or_components;
152 VisitOrExpressions(outer_expr, [&](const PrimExpr& inner_expr) {
153 Key key = GetKey(inner_expr);
154 bool is_duplicate = std::any_of(or_components.begin(), or_components.end(),
155 [&](Key prev) { return prev == key; });
156 if (!is_duplicate) {
157 or_components.push_back(key);
158 }
159 });
160
161 bool is_permutation =
162 std::any_of(chunks_.begin(), chunks_.end(), [&](const std::vector<Key>& prev_components) {
163 return or_components.size() == prev_components.size() &&
164 std::is_permutation(prev_components.begin(), prev_components.end(),
165 or_components.begin());
166 });
167 if (!is_permutation) {
168 chunks_.push_back(std::move(or_components));
169 }
170 });
171}
172
173void AndOfOrs::VisitAndExpressions(const PrimExpr& expr,
174 std::function<void(const PrimExpr&)> callback) {
175 PVar<PrimExpr> x, y, z;
176 if ((x && y).Match(expr)) {
177 // These are separate AND conditions, recurse into them in case
178 // they contain AND internally.
179 VisitAndExpressions(x.Eval(), callback);
180 VisitAndExpressions(y.Eval(), callback);
181 } else if ((x || y).Match(expr)) {
182 // This may be the bottom-most breakdown, but either x or y may
183 // themselves contain AND. (e.g. (A && B) || (C && D) should be
184 // split into (A || C), (A || D), (B || C), and (B || D).)
185 // Recurse into each, then reconstruct an OR condition.
186 VisitAndExpressions(x.Eval(), [&](const PrimExpr& x_part) {
187 VisitAndExpressions(y.Eval(), [&](const PrimExpr& y_part) { callback(x_part || y_part); });
188 });
189 } else {
190 // This is bottom-most breakdown.
191 callback(expr);
192 }
193}
194
195void AndOfOrs::VisitOrExpressions(const PrimExpr& expr,
196 std::function<void(const PrimExpr&)> callback) {
197 PVar<PrimExpr> x, y, z;
198 if ((x || y).Match(expr)) {
199 // These are separate OR conditions, recurse into them in case
200 // they contain OR internally.
201 VisitOrExpressions(x.Eval(), callback);
202 VisitOrExpressions(y.Eval(), callback);
203 } else if ((x && y).Match(expr)) {
204 // This may be the bottom-most breakdown, but either x or y may
205 // themselves contain OR. (e.g. (A || B) && (C || D) should be
206 // split into (A && C), (A && D), (B && C), and (B && D).)
207 // Recurse into each, then reconstruct an AND condition.
208 VisitOrExpressions(x.Eval(), [&](const PrimExpr& x_part) {
209 VisitOrExpressions(y.Eval(), [&](const PrimExpr& y_part) { callback(x_part && y_part); });
210 });
211 } else {
212 // This is bottom-most breakdown.
213 callback(expr);
214 }
215}
216
217AndOfOrs::Key AndOfOrs::GetKey(const PrimExpr& expr) {
218 auto it = expr_to_key_.find(expr);
219 if (it != expr_to_key_.end()) {
220 return it->second;
221 }
222
223 Key key{expr_to_key_.size()};
224 expr_to_key_[expr] = key;
225 key_to_expr_[key] = expr;
226 return key;
227}
228
229PrimExpr AndOfOrs::GetExpr(AndOfOrs::Key key) const {
230 auto it = key_to_expr_.find(key);
231 ICHECK(it != key_to_expr_.end());
232 return it->second;
233}
234
235PrimExpr AndOfOrs::AsPrimExpr() const {
236 PrimExpr expr = Bool(true);
237 for (const auto& chunk : chunks_) {
238 PrimExpr chunk_expr = Bool(false);
239 for (Key j : chunk) {
240 chunk_expr = chunk_expr || GetExpr(j);
241 }
242 expr = expr && chunk_expr;
243 }
244 return expr;
245}
246
247void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
248 Key& a = *a_ptr;
249 Key& b = *b_ptr;
250 PrimExpr joint = GetExpr(a) || GetExpr(b);
251 PrimExpr simplified = analyzer->rewrite_simplify(joint);
252 if (!ExprDeepEqual()(simplified, joint)) {
253 if (auto* simplified_or = simplified.as<OrNode>()) {
254 a = GetKey(simplified_or->a);
255 b = GetKey(simplified_or->b);
256 } else {
257 a = key_false_;
258 b = GetKey(simplified);
259 }
260 }
261}
262
263void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
264 Key& a = *a_ptr;
265 Key& b = *b_ptr;
266 PrimExpr joint = GetExpr(a) && GetExpr(b);
267 PrimExpr simplified = analyzer->rewrite_simplify(joint);
268 if (!ExprDeepEqual()(simplified, joint)) {
269 if (auto* simplified_and = simplified.as<AndNode>()) {
270 a = GetKey(simplified_and->a);
271 b = GetKey(simplified_and->b);
272 } else {
273 a = key_true_;
274 b = GetKey(simplified);
275 }
276 }
277}
278
279void AndOfOrs::Simplify(Analyzer* analyzer) {
280 SimplifyWithinChunks(analyzer);
281 RemoveTrueFalse();
282 SimplifyAcrossChunks(analyzer);
283 RemoveTrueFalse();
284}
285
286void AndOfOrs::SimplifyWithinChunks(Analyzer* analyzer) {
287 for (auto& chunk : chunks_) {
288 for (size_t expr_i = 0; expr_i < chunk.size(); expr_i++) {
289 for (size_t expr_j = expr_i + 1; expr_j < chunk.size(); expr_j++) {
290 Key& key_i = chunk[expr_i];
291 Key& key_j = chunk[expr_j];
292
293 TrySimplifyOr(&key_i, &key_j, analyzer);
294 }
295 }
296 }
297}
298
299void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) {
300 for (size_t i_and = 0; i_and < chunks_.size(); i_and++) {
301 for (size_t j_and = i_and + 1; j_and < chunks_.size(); j_and++) {
302 auto& i_chunk = chunks_[i_and];
303 auto& j_chunk = chunks_[j_and];
304
305 if (i_chunk.size() == 1 && j_chunk.size() == 1) {
306 auto& key_i = i_chunk[0];
307 auto& key_j = j_chunk[0];
308 TrySimplifyAnd(&key_i, &key_j, analyzer);
309 continue;
310 }
311 std::unordered_set<Key> j_set(j_chunk.begin(), j_chunk.end());
312
313 std::optional<size_t> i_distinct_index;
314 for (size_t i = 0; i < i_chunk.size(); i++) {
315 if (!j_set.count(i_chunk[i])) {
316 i_distinct_index = i;
317 break;
318 }
319 }
320
321 if (!i_distinct_index.has_value()) {
322 // I = (i_0 || i_1 || ... || i_N)
323 // J = (i_0 || i_1 || ... || i_N || j_0 || ... || j_N)
324 // I && J == I == I && true
325
326 j_chunk = {key_true_};
327 continue;
328 }
329
330 std::unordered_set<Key> i_set(i_chunk.begin(), i_chunk.end());
331
332 std::optional<size_t> j_distinct_index;
333 for (size_t j = 0; j < j_chunk.size(); j++) {
334 if (!i_set.count(j_chunk[j])) {
335 j_distinct_index = j;
336 break;
337 }
338 }
339
340 if (!j_distinct_index.has_value()) {
341 // I = (i_0 || ... || i_N || j_0 || ... || j_N)
342 // J = (j_0 || ... || j_N)
343 // I && J == J == true && J
344
345 i_chunk = {key_true_};
346 continue;
347 }
348
349 if (i_chunk.size() == j_chunk.size()) {
350 size_t num_shared_exprs = 0;
351 for (const auto& j_key : j_chunk) {
352 if (i_set.count(j_key)) {
353 ++num_shared_exprs;
354 }
355 }
356
357 if (num_shared_exprs + 1 == i_chunk.size()) {
358 // All but one of the expressions are shared. If the AND
359 // of the distinct expressions can be simplified, we can
360 // replace.
361 //
362 // (A or B) and (A or C) => A or (B and C)
363 auto& key_i = i_chunk[i_distinct_index.value()];
364 auto& key_j = j_chunk[j_distinct_index.value()];
365
366 // When attempting to simplify (B and C), the analyzer may
367 // assume that A is false.
368 PrimExpr known = [&]() {
369 PrimExpr known = Bool(true);
370 for (const auto& key : i_chunk) {
371 if (&key != &key_i) {
372 known = known && analyzer->Simplify(!GetExpr(key));
373 }
374 }
375 return known;
376 }();
377
378 With<ConstraintContext> context(analyzer, known);
379 TrySimplifyAnd(&key_i, &key_j, analyzer);
380 }
381 }
382 }
383 }
384}
385
386void AndOfOrs::RemoveTrueFalse() {
387 for (auto& chunk : chunks_) {
388 // Any occurrence of True inside an OR makes the entire expression True.
389 if (std::any_of(chunk.begin(), chunk.end(), [&](Key key) { return key == key_true_; })) {
390 chunk = {key_true_};
391 } else {
392 // Any occurrence of False inside an OR can be removed
393 chunk.erase(
394 std::remove_if(chunk.begin(), chunk.end(), [&](Key key) { return key == key_false_; }),
395 chunk.end());
396 }
397 }
398
399 // Any occurence of False inside an AND makes the entire expression False.
400 if (std::any_of(chunks_.begin(), chunks_.end(),
401 [&](const std::vector<Key>& chunk) { return chunk.size() == 0; })) {
402 chunks_ = {{}};
403 } else {
404 // Any occurrence of True inside an AND can be removed.
405 chunks_.erase(std::remove_if(chunks_.begin(), chunks_.end(),
406 [&](const std::vector<Key>& chunk) {
407 return chunk.size() == 1 && chunk[0] == key_true_;
408 }),
409 chunks_.end());
410 }
411}
412
413// Helper utility for temporarily disabling the
414// kConvertBooleanToAndOfOrs flag on an analyzer, to prevent infinite
415// recursion.
416class DisableAndOfOrRecursion {
417 public:
418 explicit DisableAndOfOrRecursion(Analyzer* analyzer)
419 : analyzer_(analyzer), cached_flags_(analyzer->rewrite_simplify.GetEnabledExtensions()) {
420 auto new_flags = static_cast<RewriteSimplifier::Extension>(
421 cached_flags_ & (~RewriteSimplifier::kConvertBooleanToAndOfOrs));
422 analyzer->rewrite_simplify.SetEnabledExtensions(new_flags);
423 }
424 ~DisableAndOfOrRecursion() { analyzer_->rewrite_simplify.SetEnabledExtensions(cached_flags_); }
425
426 DisableAndOfOrRecursion(const DisableAndOfOrRecursion&) = delete;
427 DisableAndOfOrRecursion& operator=(const DisableAndOfOrRecursion&) = delete;
428
429 private:
430 Analyzer* analyzer_;
431 RewriteSimplifier::Extension cached_flags_;
432};
433
434} // namespace
435
436PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, Analyzer* analyzer) {
437 DisableAndOfOrRecursion context(analyzer);
438 AndOfOrs repr(analyzer->Simplify(expr));
439 repr.Simplify(analyzer);
440 return repr.AsPrimExpr();
441}
442
443} // namespace arith
444} // namespace tvm
445