1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/arith/conjunctive_normal_form.cc |
22 | */ |
23 | |
#include "conjunctive_normal_form.h"

#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>

#include <algorithm>
#include <functional>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "pattern_match.h"
#include "rewrite_simplify.h"
37 | |
38 | namespace tvm { |
39 | namespace arith { |
40 | |
41 | namespace { |
42 | /* \brief A utility for simplifying expressions using conjunctive/disjuctive normal forms */ |
43 | class AndOfOrs { |
44 | public: |
45 | /*! \brief Construct the simplifier |
46 | * |
47 | * Convert a PrimExpr to the internal representation. |
48 | * |
49 | * \param expr The PrimExpr to be simplified. |
50 | */ |
51 | explicit AndOfOrs(const PrimExpr& expr); |
52 | |
53 | /*! \brief Convert internal representation to PrimExpr */ |
54 | PrimExpr AsPrimExpr() const; |
55 | |
56 | /*! \brief Simplify the internal representation */ |
57 | void Simplify(Analyzer* analyzer); |
58 | |
59 | private: |
60 | /*! \brief Internal utility, simplify within each group of expressions |
61 | * |
62 | * For each pair of values within a chunk, attempt to simplify them into |
63 | * a single expression. |
64 | * |
65 | * For example, |
66 | * before = (a == 5) && ((b < 10) || (b > 10)) |
67 | * after = (a == 5) && ((b != 10) || false) |
68 | */ |
69 | void SimplifyWithinChunks(Analyzer* analyzer); |
70 | |
71 | /*! \brief Internal utility, simplify across groups of expressions |
72 | * |
73 | * For each pair of chunks, if the two chunks differ by only a single |
74 | * term, attempt to simplify those differing terms. |
75 | * |
76 | * For example, |
77 | * before = ((a == 5) || (b <= 10)) && ((a == 5) || (b >= 10)) |
78 | * after = ((a == 5) || (b == 10)) && ((a == 5) || true) |
79 | */ |
80 | void SimplifyAcrossChunks(Analyzer* analyzer); |
81 | |
82 | /*! \brief Remove instances of true/false from internal representation |
83 | * |
84 | * To avoid invalidating iterators, `SimplifyWithinChunks` and |
85 | * `SimplifyAcrossChunks` may replace keys, but may not remove keys |
86 | * from the internal representation. For example, `(a < 5) && (a < |
87 | * 10)` would be simplified to `(a < 5) && true`. The |
88 | * `RemoveTrueFalse` function removes these leftover instances of |
89 | * true/false. |
90 | */ |
91 | void RemoveTrueFalse(); |
92 | |
93 | /*! \brief Internal utility function used to convert to internal form */ |
94 | static void VisitAndExpressions(const PrimExpr& expr, |
95 | std::function<void(const PrimExpr&)> callback); |
96 | /*! \brief Internal utility function used to convert to internal form */ |
97 | static void VisitOrExpressions(const PrimExpr& expr, |
98 | std::function<void(const PrimExpr&)> callback); |
99 | |
100 | /* \brief Type-safe wrapper class that represents an PrimExpr |
101 | * |
102 | * Because integer indices are used frequently through this class, |
103 | * maintaining a separation between integer indices used to access |
104 | * specific elements of the internal representation, and unique |
105 | * identifiers used to represent expressions PrimExpr, is useful. |
106 | */ |
107 | enum class Key : size_t {}; |
108 | |
109 | /*! \brief Convert a PrimExpr to a Key */ |
110 | Key GetKey(const PrimExpr& expr); |
111 | |
112 | /*! \brief Convert a Key to a PrimExpr */ |
113 | PrimExpr GetExpr(Key key) const; |
114 | |
115 | /*! \brief Attempt to simplify (a && b) |
116 | * |
117 | * If successful, will overwrite the parameters `a` and `b` with the |
118 | * simplified form. |
119 | */ |
120 | void TrySimplifyOr(Key* a, Key* b, Analyzer* analyzer); |
121 | |
122 | /*! \brief Attempt to simplify (a || b) |
123 | * |
124 | * If successful, will overwrite the parameters `a` and `b` with the |
125 | * simplified form. |
126 | */ |
127 | void TrySimplifyAnd(Key* a, Key* b, Analyzer* analyzer); |
128 | |
129 | /*! \brief The internal representation |
130 | * |
131 | * `chunks[i][j]` is the j-th expression in the i-th OR-group. |
132 | */ |
133 | std::vector<std::vector<Key>> chunks_; |
134 | |
135 | /*! \brief Mapping from internal Key to PrimExpr */ |
136 | std::unordered_map<Key, PrimExpr, StructuralHash, StructuralEqual> key_to_expr_; |
137 | |
138 | /*! \brief Mapping from PrimExpr to internal Key */ |
139 | std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key_; |
140 | |
141 | /*! \brief Cached key representing tir::Bool(true) */ |
142 | Key key_true_; |
143 | |
144 | /*! \brief Cached key representing tir::Bool(false) */ |
145 | Key key_false_; |
146 | }; |
147 | |
148 | AndOfOrs::AndOfOrs(const PrimExpr& expr) |
149 | : key_true_(GetKey(Bool(true))), key_false_(GetKey(Bool(false))) { |
150 | VisitAndExpressions(expr, [&](const PrimExpr& outer_expr) { |
151 | std::vector<Key> or_components; |
152 | VisitOrExpressions(outer_expr, [&](const PrimExpr& inner_expr) { |
153 | Key key = GetKey(inner_expr); |
154 | bool is_duplicate = std::any_of(or_components.begin(), or_components.end(), |
155 | [&](Key prev) { return prev == key; }); |
156 | if (!is_duplicate) { |
157 | or_components.push_back(key); |
158 | } |
159 | }); |
160 | |
161 | bool is_permutation = |
162 | std::any_of(chunks_.begin(), chunks_.end(), [&](const std::vector<Key>& prev_components) { |
163 | return or_components.size() == prev_components.size() && |
164 | std::is_permutation(prev_components.begin(), prev_components.end(), |
165 | or_components.begin()); |
166 | }); |
167 | if (!is_permutation) { |
168 | chunks_.push_back(std::move(or_components)); |
169 | } |
170 | }); |
171 | } |
172 | |
173 | void AndOfOrs::VisitAndExpressions(const PrimExpr& expr, |
174 | std::function<void(const PrimExpr&)> callback) { |
175 | PVar<PrimExpr> x, y, z; |
176 | if ((x && y).Match(expr)) { |
177 | // These are separate AND conditions, recurse into them in case |
178 | // they contain AND internally. |
179 | VisitAndExpressions(x.Eval(), callback); |
180 | VisitAndExpressions(y.Eval(), callback); |
181 | } else if ((x || y).Match(expr)) { |
182 | // This may be the bottom-most breakdown, but either x or y may |
183 | // themselves contain AND. (e.g. (A && B) || (C && D) should be |
184 | // split into (A || C), (A || D), (B || C), and (B || D).) |
185 | // Recurse into each, then reconstruct an OR condition. |
186 | VisitAndExpressions(x.Eval(), [&](const PrimExpr& x_part) { |
187 | VisitAndExpressions(y.Eval(), [&](const PrimExpr& y_part) { callback(x_part || y_part); }); |
188 | }); |
189 | } else { |
190 | // This is bottom-most breakdown. |
191 | callback(expr); |
192 | } |
193 | } |
194 | |
195 | void AndOfOrs::VisitOrExpressions(const PrimExpr& expr, |
196 | std::function<void(const PrimExpr&)> callback) { |
197 | PVar<PrimExpr> x, y, z; |
198 | if ((x || y).Match(expr)) { |
199 | // These are separate OR conditions, recurse into them in case |
200 | // they contain OR internally. |
201 | VisitOrExpressions(x.Eval(), callback); |
202 | VisitOrExpressions(y.Eval(), callback); |
203 | } else if ((x && y).Match(expr)) { |
204 | // This may be the bottom-most breakdown, but either x or y may |
205 | // themselves contain OR. (e.g. (A || B) && (C || D) should be |
206 | // split into (A && C), (A && D), (B && C), and (B && D).) |
207 | // Recurse into each, then reconstruct an AND condition. |
208 | VisitOrExpressions(x.Eval(), [&](const PrimExpr& x_part) { |
209 | VisitOrExpressions(y.Eval(), [&](const PrimExpr& y_part) { callback(x_part && y_part); }); |
210 | }); |
211 | } else { |
212 | // This is bottom-most breakdown. |
213 | callback(expr); |
214 | } |
215 | } |
216 | |
217 | AndOfOrs::Key AndOfOrs::GetKey(const PrimExpr& expr) { |
218 | auto it = expr_to_key_.find(expr); |
219 | if (it != expr_to_key_.end()) { |
220 | return it->second; |
221 | } |
222 | |
223 | Key key{expr_to_key_.size()}; |
224 | expr_to_key_[expr] = key; |
225 | key_to_expr_[key] = expr; |
226 | return key; |
227 | } |
228 | |
229 | PrimExpr AndOfOrs::GetExpr(AndOfOrs::Key key) const { |
230 | auto it = key_to_expr_.find(key); |
231 | ICHECK(it != key_to_expr_.end()); |
232 | return it->second; |
233 | } |
234 | |
235 | PrimExpr AndOfOrs::AsPrimExpr() const { |
236 | PrimExpr expr = Bool(true); |
237 | for (const auto& chunk : chunks_) { |
238 | PrimExpr chunk_expr = Bool(false); |
239 | for (Key j : chunk) { |
240 | chunk_expr = chunk_expr || GetExpr(j); |
241 | } |
242 | expr = expr && chunk_expr; |
243 | } |
244 | return expr; |
245 | } |
246 | |
247 | void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { |
248 | Key& a = *a_ptr; |
249 | Key& b = *b_ptr; |
250 | PrimExpr joint = GetExpr(a) || GetExpr(b); |
251 | PrimExpr simplified = analyzer->rewrite_simplify(joint); |
252 | if (!ExprDeepEqual()(simplified, joint)) { |
253 | if (auto* simplified_or = simplified.as<OrNode>()) { |
254 | a = GetKey(simplified_or->a); |
255 | b = GetKey(simplified_or->b); |
256 | } else { |
257 | a = key_false_; |
258 | b = GetKey(simplified); |
259 | } |
260 | } |
261 | } |
262 | |
263 | void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { |
264 | Key& a = *a_ptr; |
265 | Key& b = *b_ptr; |
266 | PrimExpr joint = GetExpr(a) && GetExpr(b); |
267 | PrimExpr simplified = analyzer->rewrite_simplify(joint); |
268 | if (!ExprDeepEqual()(simplified, joint)) { |
269 | if (auto* simplified_and = simplified.as<AndNode>()) { |
270 | a = GetKey(simplified_and->a); |
271 | b = GetKey(simplified_and->b); |
272 | } else { |
273 | a = key_true_; |
274 | b = GetKey(simplified); |
275 | } |
276 | } |
277 | } |
278 | |
279 | void AndOfOrs::Simplify(Analyzer* analyzer) { |
280 | SimplifyWithinChunks(analyzer); |
281 | RemoveTrueFalse(); |
282 | SimplifyAcrossChunks(analyzer); |
283 | RemoveTrueFalse(); |
284 | } |
285 | |
286 | void AndOfOrs::SimplifyWithinChunks(Analyzer* analyzer) { |
287 | for (auto& chunk : chunks_) { |
288 | for (size_t expr_i = 0; expr_i < chunk.size(); expr_i++) { |
289 | for (size_t expr_j = expr_i + 1; expr_j < chunk.size(); expr_j++) { |
290 | Key& key_i = chunk[expr_i]; |
291 | Key& key_j = chunk[expr_j]; |
292 | |
293 | TrySimplifyOr(&key_i, &key_j, analyzer); |
294 | } |
295 | } |
296 | } |
297 | } |
298 | |
299 | void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) { |
300 | for (size_t i_and = 0; i_and < chunks_.size(); i_and++) { |
301 | for (size_t j_and = i_and + 1; j_and < chunks_.size(); j_and++) { |
302 | auto& i_chunk = chunks_[i_and]; |
303 | auto& j_chunk = chunks_[j_and]; |
304 | |
305 | if (i_chunk.size() == 1 && j_chunk.size() == 1) { |
306 | auto& key_i = i_chunk[0]; |
307 | auto& key_j = j_chunk[0]; |
308 | TrySimplifyAnd(&key_i, &key_j, analyzer); |
309 | continue; |
310 | } |
311 | std::unordered_set<Key> j_set(j_chunk.begin(), j_chunk.end()); |
312 | |
313 | std::optional<size_t> i_distinct_index; |
314 | for (size_t i = 0; i < i_chunk.size(); i++) { |
315 | if (!j_set.count(i_chunk[i])) { |
316 | i_distinct_index = i; |
317 | break; |
318 | } |
319 | } |
320 | |
321 | if (!i_distinct_index.has_value()) { |
322 | // I = (i_0 || i_1 || ... || i_N) |
323 | // J = (i_0 || i_1 || ... || i_N || j_0 || ... || j_N) |
324 | // I && J == I == I && true |
325 | |
326 | j_chunk = {key_true_}; |
327 | continue; |
328 | } |
329 | |
330 | std::unordered_set<Key> i_set(i_chunk.begin(), i_chunk.end()); |
331 | |
332 | std::optional<size_t> j_distinct_index; |
333 | for (size_t j = 0; j < j_chunk.size(); j++) { |
334 | if (!i_set.count(j_chunk[j])) { |
335 | j_distinct_index = j; |
336 | break; |
337 | } |
338 | } |
339 | |
340 | if (!j_distinct_index.has_value()) { |
341 | // I = (i_0 || ... || i_N || j_0 || ... || j_N) |
342 | // J = (j_0 || ... || j_N) |
343 | // I && J == J == true && J |
344 | |
345 | i_chunk = {key_true_}; |
346 | continue; |
347 | } |
348 | |
349 | if (i_chunk.size() == j_chunk.size()) { |
350 | size_t num_shared_exprs = 0; |
351 | for (const auto& j_key : j_chunk) { |
352 | if (i_set.count(j_key)) { |
353 | ++num_shared_exprs; |
354 | } |
355 | } |
356 | |
357 | if (num_shared_exprs + 1 == i_chunk.size()) { |
358 | // All but one of the expressions are shared. If the AND |
359 | // of the distinct expressions can be simplified, we can |
360 | // replace. |
361 | // |
362 | // (A or B) and (A or C) => A or (B and C) |
363 | auto& key_i = i_chunk[i_distinct_index.value()]; |
364 | auto& key_j = j_chunk[j_distinct_index.value()]; |
365 | |
366 | // When attempting to simplify (B and C), the analyzer may |
367 | // assume that A is false. |
368 | PrimExpr known = [&]() { |
369 | PrimExpr known = Bool(true); |
370 | for (const auto& key : i_chunk) { |
371 | if (&key != &key_i) { |
372 | known = known && analyzer->Simplify(!GetExpr(key)); |
373 | } |
374 | } |
375 | return known; |
376 | }(); |
377 | |
378 | With<ConstraintContext> context(analyzer, known); |
379 | TrySimplifyAnd(&key_i, &key_j, analyzer); |
380 | } |
381 | } |
382 | } |
383 | } |
384 | } |
385 | |
386 | void AndOfOrs::RemoveTrueFalse() { |
387 | for (auto& chunk : chunks_) { |
388 | // Any occurrence of True inside an OR makes the entire expression True. |
389 | if (std::any_of(chunk.begin(), chunk.end(), [&](Key key) { return key == key_true_; })) { |
390 | chunk = {key_true_}; |
391 | } else { |
392 | // Any occurrence of False inside an OR can be removed |
393 | chunk.erase( |
394 | std::remove_if(chunk.begin(), chunk.end(), [&](Key key) { return key == key_false_; }), |
395 | chunk.end()); |
396 | } |
397 | } |
398 | |
399 | // Any occurence of False inside an AND makes the entire expression False. |
400 | if (std::any_of(chunks_.begin(), chunks_.end(), |
401 | [&](const std::vector<Key>& chunk) { return chunk.size() == 0; })) { |
402 | chunks_ = {{}}; |
403 | } else { |
404 | // Any occurrence of True inside an AND can be removed. |
405 | chunks_.erase(std::remove_if(chunks_.begin(), chunks_.end(), |
406 | [&](const std::vector<Key>& chunk) { |
407 | return chunk.size() == 1 && chunk[0] == key_true_; |
408 | }), |
409 | chunks_.end()); |
410 | } |
411 | } |
412 | |
413 | // Helper utility for temporarily disabling the |
414 | // kConvertBooleanToAndOfOrs flag on an analyzer, to prevent infinite |
415 | // recursion. |
416 | class DisableAndOfOrRecursion { |
417 | public: |
418 | explicit DisableAndOfOrRecursion(Analyzer* analyzer) |
419 | : analyzer_(analyzer), cached_flags_(analyzer->rewrite_simplify.GetEnabledExtensions()) { |
420 | auto new_flags = static_cast<RewriteSimplifier::Extension>( |
421 | cached_flags_ & (~RewriteSimplifier::kConvertBooleanToAndOfOrs)); |
422 | analyzer->rewrite_simplify.SetEnabledExtensions(new_flags); |
423 | } |
424 | ~DisableAndOfOrRecursion() { analyzer_->rewrite_simplify.SetEnabledExtensions(cached_flags_); } |
425 | |
426 | DisableAndOfOrRecursion(const DisableAndOfOrRecursion&) = delete; |
427 | DisableAndOfOrRecursion& operator=(const DisableAndOfOrRecursion&) = delete; |
428 | |
429 | private: |
430 | Analyzer* analyzer_; |
431 | RewriteSimplifier::Extension cached_flags_; |
432 | }; |
433 | |
434 | } // namespace |
435 | |
436 | PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, Analyzer* analyzer) { |
437 | DisableAndOfOrRecursion context(analyzer); |
438 | AndOfOrs repr(analyzer->Simplify(expr)); |
439 | repr.Simplify(analyzer); |
440 | return repr.AsPrimExpr(); |
441 | } |
442 | |
443 | } // namespace arith |
444 | } // namespace tvm |
445 | |