1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/simplify.hpp" |
18 | |
19 | #include <algorithm> |
20 | #include <iostream> |
21 | #include <sstream> |
22 | #include <string> |
23 | #include <vector> |
24 | |
25 | #include "common/cpp_compat.hpp" |
26 | #include "common/math_utils.hpp" |
27 | #include "gpu/jit/ir/ir.hpp" |
28 | #include "gpu/jit/utils/trace.hpp" |
29 | #include "gpu/jit/utils/utils.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | namespace jit { |
35 | |
36 | using namespace ir_utils; |
37 | |
38 | // Generic pattern expression, used as a wild card during pattern matching. Can |
39 | // match any expression. |
40 | class pexpr_t : public expr_impl_t { |
41 | public: |
42 | IR_DECL_EXPR_TYPE_ID(pexpr_t) |
43 | |
44 | static expr_t make(int id) { return expr_t(new pexpr_t(id)); } |
45 | |
46 | bool is_equal(const object_impl_t &obj) const override { |
47 | if (!obj.is<self_type>()) return false; |
48 | auto &other = obj.as<self_type>(); |
49 | |
50 | return id == other.id; |
51 | } |
52 | |
53 | size_t get_hash() const override { return ir_utils::get_hash(id); } |
54 | |
55 | std::string str() const override { |
56 | std::ostringstream oss; |
57 | oss << "pexpr_t(" << id << ")" ; |
58 | return oss.str(); |
59 | } |
60 | |
61 | static expr_t x() { |
62 | static thread_local expr_t x = pexpr_t::make(0); |
63 | return x; |
64 | } |
65 | static expr_t y() { |
66 | static thread_local expr_t y = pexpr_t::make(1); |
67 | return y; |
68 | } |
69 | static expr_t z() { |
70 | static thread_local expr_t z = pexpr_t::make(2); |
71 | return z; |
72 | } |
73 | |
74 | IR_DECLARE_TRAVERSERS() |
75 | |
76 | int id; |
77 | |
78 | private: |
79 | pexpr_t(int id) : expr_impl_t(_type_info(), type_t::undef()), id(id) {} |
80 | }; |
81 | |
82 | // Pattern expression for int_imm_t, used as a wild card during pattern |
83 | // matching. Can match any int_imm_t with the given value. |
84 | class pint_imm_t : public expr_impl_t { |
85 | public: |
86 | IR_DECL_EXPR_TYPE_ID(pint_imm_t) |
87 | |
88 | // Matches an integer constant with the given value. |
89 | static expr_t make(int64_t value) { |
90 | return expr_t(new pint_imm_t(-1, value)); |
91 | } |
92 | |
93 | static expr_t _0() { |
94 | static thread_local expr_t ret = pint_imm_t::make(0); |
95 | return ret; |
96 | } |
97 | static expr_t _1() { |
98 | static thread_local expr_t ret = pint_imm_t::make(1); |
99 | return ret; |
100 | } |
101 | |
102 | // Matches any integer constant. |
103 | static expr_t make_any(int64_t id) { return expr_t(new pint_imm_t(id, 0)); } |
104 | |
105 | bool matches(const int_imm_t &imm) const { |
106 | if (id == -1) return value == imm.value; |
107 | return true; |
108 | } |
109 | |
110 | bool is_equal(const object_impl_t &obj) const override { |
111 | if (!obj.is<self_type>()) return false; |
112 | auto &other = obj.as<self_type>(); |
113 | |
114 | return value == other.value; |
115 | } |
116 | |
117 | size_t get_hash() const override { return ir_utils::get_hash(value); } |
118 | |
119 | std::string str() const override { |
120 | std::ostringstream oss; |
121 | oss << "pint_imm_t(" << value << ")" ; |
122 | return oss.str(); |
123 | } |
124 | |
125 | int id; |
126 | int64_t value; |
127 | |
128 | private: |
129 | pint_imm_t(int id, int64_t value) |
130 | : expr_impl_t(_type_info(), type_t::undef()), id(id), value(value) {} |
131 | }; |
132 | |
133 | // Stores already matched pairs of <pattern expression, matched expression>. |
134 | class match_context_t { |
135 | public: |
136 | bool contains(const expr_t &ptrn) const { |
137 | return expr_matched_.count(ptrn) != 0; |
138 | } |
139 | |
140 | void set(const expr_t &ptrn, const expr_t &e) { |
141 | ir_assert(ptrn.is<pexpr_t>()); |
142 | auto ret = expr_matched_.insert({ptrn, e}); |
143 | ir_assert(ret.second); |
144 | MAYBE_UNUSED(ret); |
145 | } |
146 | |
147 | const expr_t &operator[](const expr_t &ptrn) const { |
148 | return expr_matched_.at(ptrn); |
149 | } |
150 | |
151 | template <typename T> |
152 | const T &at(const expr_t &ptrn) const { |
153 | return expr_matched_.at(ptrn).as<T>(); |
154 | } |
155 | |
156 | expr_t sub(const expr_t &expr) const; |
157 | |
158 | private: |
159 | object_eq_map_t<expr_t, expr_t> expr_matched_; |
160 | }; |
161 | |
162 | class pexpr_substitute_t : public ir_mutator_t { |
163 | public: |
164 | using ir_mutator_t::_mutate; |
165 | |
166 | pexpr_substitute_t(const match_context_t *ctx) : ctx_(ctx) {} |
167 | |
168 | object_t _mutate(const pexpr_t &obj) override { |
169 | return (*ctx_)[expr_t(obj)]; |
170 | } |
171 | |
172 | private: |
173 | const match_context_t *ctx_; |
174 | }; |
175 | |
176 | // Replaces occurrences of pattern expressions in `expr` according to the |
177 | // context. |
178 | expr_t match_context_t::sub(const expr_t &expr) const { |
179 | pexpr_substitute_t s(this); |
180 | return s.mutate(expr); |
181 | } |
182 | |
// Returns true if the expression matches the pattern, false otherwise. Upon
// a successful match the context contains information about the matched
// pattern expressions.
186 | bool match(const expr_t &ptrn, const expr_t &expr, match_context_t &ctx); |
187 | |
188 | bool match_binary( |
189 | const expr_t &ptrn, const expr_t &expr, match_context_t &ctx) { |
190 | bool ptrn_is_binary = is_binary_op(ptrn); |
191 | bool expr_is_binary = is_binary_op(expr); |
192 | |
193 | if (!ptrn_is_binary || !expr_is_binary) return false; |
194 | |
195 | auto &ptrn_op = ptrn.as<binary_op_t>(); |
196 | auto &expr_op = expr.as<binary_op_t>(); |
197 | if (ptrn_op.op_kind != expr_op.op_kind) return false; |
198 | |
199 | match_context_t ctx_copy = ctx; |
200 | if (match(ptrn_op.a, expr_op.a, ctx_copy) |
201 | && match(ptrn_op.b, expr_op.b, ctx_copy)) { |
202 | ctx = ctx_copy; |
203 | return true; |
204 | } |
205 | return false; |
206 | } |
207 | |
208 | bool match_iif(const expr_t &ptrn, const expr_t &expr, match_context_t &ctx) { |
209 | bool ptrn_is_iif = ptrn.is<iif_t>(); |
210 | bool expr_is_iif = expr.is<iif_t>(); |
211 | |
212 | if (!ptrn_is_iif || !expr_is_iif) return false; |
213 | |
214 | auto &ptrn_iif = ptrn.as<iif_t>(); |
215 | auto &expr_iif = expr.as<iif_t>(); |
216 | |
217 | match_context_t ctx_copy = ctx; |
218 | if (match(ptrn_iif.cond, expr_iif.cond, ctx_copy) |
219 | && match(ptrn_iif.true_expr, expr_iif.true_expr, ctx_copy) |
220 | && match(ptrn_iif.false_expr, expr_iif.false_expr, ctx_copy)) { |
221 | ctx = ctx_copy; |
222 | return true; |
223 | } |
224 | |
225 | return false; |
226 | } |
227 | |
228 | bool match(const expr_t &ptrn, const expr_t &expr, match_context_t &ctx) { |
229 | if (ptrn.is_equal(expr)) return true; |
230 | |
231 | if (ptrn.is<pint_imm_t>()) { |
232 | auto &ptrn_imm = ptrn.as<pint_imm_t>(); |
233 | |
234 | bool ok = false; |
235 | if (expr.is<int_imm_t>()) { |
236 | ok = ptrn_imm.matches(expr.as<int_imm_t>()); |
237 | } else if (ptrn_imm.id == -1 && expr.is<float_imm_t>()) { |
238 | ok = (to_cpp<float>(expr) == ptrn_imm.value); |
239 | } |
240 | return ok; |
241 | } |
242 | |
243 | if (ptrn.is<pexpr_t>()) { |
244 | if (ctx.contains(ptrn)) { |
245 | if (!ctx[ptrn].is_equal(expr)) return false; |
246 | } else { |
247 | ctx.set(ptrn, expr); |
248 | } |
249 | return true; |
250 | } |
251 | |
252 | if (match_binary(ptrn, expr, ctx)) return true; |
253 | if (match_iif(ptrn, expr, ctx)) return true; |
254 | |
255 | return false; |
256 | } |
257 | |
258 | // Rewrites expression `expr` according to `from` -> `to` rule. |
259 | // Example: |
260 | // auto x = pexpr_t::x(); |
261 | // auto c = rewrite(a + a, x + x, 2 * x); |
262 | // // Now c is equal to (2 * a). |
263 | expr_t rewrite(const expr_t &expr, const expr_t &from, const expr_t &to, |
264 | bool *rewritten = nullptr) { |
265 | match_context_t ctx; |
266 | if (match(from, expr, ctx)) { |
267 | if (rewritten) *rewritten = true; |
268 | return ctx.sub(to); |
269 | } |
270 | if (rewritten) *rewritten = false; |
271 | return expr; |
272 | } |
273 | |
274 | expr_t rewrite_binary(const expr_t &expr, const expr_t &from, const expr_t &to, |
275 | bool *rewritten = nullptr) { |
276 | match_context_t ctx; |
277 | if (match_binary(from, expr, ctx)) { |
278 | if (rewritten) *rewritten = true; |
279 | return ctx.sub(to); |
280 | } |
281 | if (rewritten) *rewritten = false; |
282 | return expr; |
283 | } |
284 | |
// Helper macros for the rewrite rules below. Each macro expects a local
// variable `expr_t e` in the enclosing function, tries to rewrite it from
// pattern `a` to `b` and, when the rule applies, returns the rewritten
// expression from the enclosing function.

// Generic rule, matched with rewrite(). The pattern expressions are stored in
// thread-local statics, so `a` and `b` are evaluated once per thread.
#define REWRITE(a, b) \
    do { \
        bool rewritten = false; \
        static thread_local auto _a = a; \
        static thread_local auto _b = b; \
        e = rewrite(e, _a, _b, &rewritten); \
        if (rewritten) return e; \
    } while (false)

// Rule for binary operations, matched with rewrite_binary(). The pattern
// expressions are stored in thread-local statics.
#define REWRITE_BINARY(a, b) \
    do { \
        bool rewritten = false; \
        static thread_local auto _a = a; \
        static thread_local auto _b = b; \
        e = rewrite_binary(e, _a, _b, &rewritten); \
        if (rewritten) return e; \
    } while (false)

// Rule for binary operations whose patterns depend on run-time values and
// hence must not be cached in statics. Matched with rewrite_binary(), like
// REWRITE_BINARY.
#define REWRITE_BINARY_NO_STATIC(a, b) \
    do { \
        bool rewritten = false; \
        e = rewrite_binary(e, a, b, &rewritten); \
        if (rewritten) return e; \
    } while (false)
309 | |
310 | expr_t simplify_rewrite_add(const expr_t &_e) { |
311 | auto x = pexpr_t::x(); |
312 | auto _0 = pint_imm_t::_0(); |
313 | |
314 | auto &obj = _e.as<binary_op_t>(); |
315 | ir_assert(obj.op_kind == op_kind_t::_add); |
316 | |
317 | auto e = _e; |
318 | |
319 | REWRITE_BINARY(x + _0, x); |
320 | REWRITE_BINARY(_0 + x, x); |
321 | REWRITE_BINARY(x + x, 2 * x); |
322 | |
323 | return _e; |
324 | } |
325 | |
326 | expr_t simplify_rewrite_sub(const expr_t &_e) { |
327 | auto x = pexpr_t::x(); |
328 | auto _0 = pint_imm_t::_0(); |
329 | |
330 | auto &obj = _e.as<binary_op_t>(); |
331 | ir_assert(obj.op_kind == op_kind_t::_sub); |
332 | |
333 | auto e = _e; |
334 | |
335 | REWRITE_BINARY(x - _0, x); |
336 | REWRITE_BINARY(_0 - x, -x); |
337 | REWRITE_BINARY(x - x, 0); |
338 | |
339 | return e; |
340 | } |
341 | |
342 | expr_t simplify_rewrite_mul(const expr_t &_e) { |
343 | auto x = pexpr_t::x(); |
344 | auto _0 = pint_imm_t::_0(); |
345 | auto _1 = pint_imm_t::_1(); |
346 | |
347 | auto &obj = _e.as<binary_op_t>(); |
348 | ir_assert(obj.op_kind == op_kind_t::_mul); |
349 | |
350 | auto e = _e; |
351 | |
352 | REWRITE_BINARY(x * _0, 0); |
353 | REWRITE_BINARY(_0 * x, 0); |
354 | REWRITE_BINARY(x * _1, x); |
355 | REWRITE_BINARY(_1 * x, x); |
356 | |
357 | return e; |
358 | } |
359 | |
360 | expr_t simplify_rewrite_div(const expr_t &_e) { |
361 | auto x = pexpr_t::x(); |
362 | auto y = pexpr_t::y(); |
363 | auto _0 = pint_imm_t::_0(); |
364 | auto _1 = pint_imm_t::_1(); |
365 | |
366 | auto &obj = _e.as<binary_op_t>(); |
367 | ir_assert(obj.op_kind == op_kind_t::_div); |
368 | |
369 | auto e = _e; |
370 | |
371 | REWRITE_BINARY(_0 / x, 0); |
372 | REWRITE_BINARY(x / _1, x); |
373 | REWRITE_BINARY(x / x, 1); |
374 | |
375 | return e; |
376 | } |
377 | |
378 | expr_t simplify_rewrite_mod(const expr_t &_e) { |
379 | auto x = pexpr_t::x(); |
380 | auto _0 = pint_imm_t::_0(); |
381 | auto _1 = pint_imm_t::_1(); |
382 | |
383 | auto &obj = _e.as<binary_op_t>(); |
384 | ir_assert(obj.op_kind == op_kind_t::_mod); |
385 | |
386 | auto e = _e; |
387 | |
388 | REWRITE_BINARY(x % _1, 0); |
389 | REWRITE_BINARY(0 % x, 0); |
390 | |
391 | return e; |
392 | } |
393 | |
394 | expr_t simplify_rewrite_and(const expr_t &_e) { |
395 | auto x = pexpr_t::x(); |
396 | |
397 | auto &obj = _e.as<binary_op_t>(); |
398 | ir_assert(obj.op_kind == op_kind_t::_and); |
399 | |
400 | auto e = _e; |
401 | |
402 | // Boolean rules. |
403 | if (e.type().is_bool()) { |
404 | auto _true = (e.type().is_scalar() |
405 | ? expr_t(true) |
406 | : shuffle_t::make_broadcast( |
407 | expr_t(true), e.type().elems())); |
408 | auto _false = (e.type().is_scalar() |
409 | ? expr_t(false) |
410 | : shuffle_t::make_broadcast( |
411 | expr_t(false), e.type().elems())); |
412 | REWRITE_BINARY_NO_STATIC(_true & x, x); |
413 | REWRITE_BINARY_NO_STATIC(x & _true, x); |
414 | REWRITE_BINARY_NO_STATIC(_false & x, _false); |
415 | REWRITE_BINARY_NO_STATIC(x & _false, _false); |
416 | } |
417 | |
418 | return e; |
419 | } |
420 | |
421 | expr_t simplify_rewrite_iif(const expr_t &_e) { |
422 | auto x = pexpr_t::x(); |
423 | auto y = pexpr_t::y(); |
424 | |
425 | auto e = _e; |
426 | |
427 | // Ternary operation rules. |
428 | REWRITE(iif_t::make(expr_t(true), x, y), x); |
429 | REWRITE(iif_t::make(expr_t(false), x, y), y); |
430 | REWRITE(iif_t::make(x, y, y), y); |
431 | |
432 | return e; |
433 | } |
434 | |
435 | expr_t simplify_try_ternary_rules(const expr_t &_e) { |
436 | auto x = pexpr_t::x(); |
437 | auto y = pexpr_t::y(); |
438 | auto z = pexpr_t::z(); |
439 | |
440 | auto e = _e; |
441 | |
442 | // add3 rules. |
443 | REWRITE((x + y) + z, ternary_add3(x, y, z)); |
444 | REWRITE(x + (y + z), ternary_add3(x, y, z)); |
445 | |
446 | // mad rules. |
447 | REWRITE(x + y * z, ternary_mad(x, y, z)); |
448 | REWRITE(x - y * z, ternary_mad(x, -y, z)); |
449 | REWRITE(y * z + x, ternary_mad(x, y, z)); |
450 | REWRITE(y * z - x, ternary_mad(-x, y, z)); |
451 | |
452 | return e; |
453 | } |
454 | |
// The rewrite helper macros are only meant for the rule functions above;
// undefine every macro that was actually defined.
#undef REWRITE
#undef REWRITE_BINARY
#undef REWRITE_BINARY_NO_STATIC
457 | |
458 | class term_rewrite_transformer_t : public ir_mutator_t { |
459 | public: |
460 | object_t _mutate(const binary_op_t &obj) override { |
461 | auto e = ir_mutator_t::_mutate(obj); |
462 | switch (obj.op_kind) { |
463 | case op_kind_t::_add: return simplify_rewrite_add(e); |
464 | case op_kind_t::_sub: return simplify_rewrite_sub(e); |
465 | case op_kind_t::_mul: return simplify_rewrite_mul(e); |
466 | case op_kind_t::_div: return simplify_rewrite_div(e); |
467 | case op_kind_t::_mod: return simplify_rewrite_mod(e); |
468 | case op_kind_t::_and: return simplify_rewrite_and(e); |
469 | default: return e; |
470 | } |
471 | } |
472 | object_t _mutate(const iif_t &obj) override { |
473 | auto e = ir_mutator_t::_mutate(obj); |
474 | return simplify_rewrite_iif(e); |
475 | } |
476 | }; |
477 | |
478 | expr_t simplify_rewrite(const expr_t &e) { |
479 | expr_t ret; |
480 | if (is_const(e) || is_var(e)) { |
481 | ret = e; |
482 | } else { |
483 | term_rewrite_transformer_t trt; |
484 | ret = trt.mutate(e); |
485 | } |
486 | return ret; |
487 | } |
488 | |
489 | class ternary_rewrite_transformer_t : public ir_mutator_t { |
490 | public: |
491 | object_t _mutate(const binary_op_t &obj) override { |
492 | return mutate_expr(obj); |
493 | } |
494 | object_t _mutate(const iif_t &obj) override { return mutate_expr(obj); } |
495 | |
496 | template <typename T> |
497 | expr_t mutate_expr(const T &obj) { |
498 | auto e_old = ir_mutator_t::_mutate(obj); |
499 | auto e = simplify_try_ternary_rules(e_old); |
500 | if (e.is_same(e_old)) return e_old; |
501 | return mutate(e); |
502 | } |
503 | }; |
504 | |
505 | expr_t simplify_rewrite_with_ternary(const expr_t &e, bool recursive) { |
506 | expr_t ret; |
507 | if (is_const(e) || is_var(e)) { |
508 | ret = e; |
509 | } else if (!recursive) { |
510 | ret = simplify_try_ternary_rules(e); |
511 | } else { |
512 | ternary_rewrite_transformer_t trt; |
513 | ret = trt.mutate(e); |
514 | } |
515 | return ret; |
516 | } |
517 | |
518 | class cmp_simplifier_t : public ir_mutator_t { |
519 | public: |
520 | object_t _mutate(const binary_op_t &obj) override { |
521 | auto e = ir_mutator_t::_mutate(obj); |
522 | if (!is_binary_cmp_op(e)) return e; |
523 | |
524 | e = simplify_mod_comparison(e); |
525 | |
526 | return e; |
527 | } |
528 | |
529 | static expr_t reduce_lhs_rhs(const expr_t &e) { |
530 | if (!is_binary_cmp_op(e)) return e; |
531 | |
532 | auto &op = e.as<binary_op_t>(); |
533 | |
534 | // Rule: |
535 | // (c0 * x op c1) or (x * c0 op c1) -> |
536 | // x new_op (c1 / c0) if abs(c1) % abs(c0) == 0 |
537 | // new_op == op or new_op == negate_cmp_op(op) |
538 | expr_t c0; |
539 | expr_t c1 = op.b; |
540 | expr_t x; |
541 | |
542 | if (!is_const(c1)) return e; |
543 | if (!is_binary_op(op.a, op_kind_t::_mul)) return e; |
544 | |
545 | auto &a_op = op.a.as<binary_op_t>(); |
546 | if (is_const(a_op.a)) { |
547 | c0 = a_op.a; |
548 | x = a_op.b; |
549 | } else if (is_const(a_op.b)) { |
550 | x = a_op.a; |
551 | c0 = a_op.b; |
552 | } |
553 | |
554 | if (c0.is_empty()) return e; |
555 | if (!c0.type().is_int()) return e; |
556 | if (!c1.type().is_int()) return e; |
557 | |
558 | auto i_c0 = to_cpp<int64_t>(c0); |
559 | auto i_c1 = to_cpp<int64_t>(c1); |
560 | |
561 | bool is_c0_neg = (i_c0 < 0); |
562 | bool sign = ((i_c0 < 0) != (i_c1 < 0)); |
563 | i_c0 = std::abs(i_c0); |
564 | i_c1 = std::abs(i_c1); |
565 | |
566 | bool has_mod = (i_c1 % i_c0 != 0); |
567 | if (has_mod |
568 | && utils::one_of(op.op_kind, op_kind_t::_eq, op_kind_t::_ne)) |
569 | return e; |
570 | |
571 | auto new_op_kind = (is_c0_neg ? negate_cmp_op(op.op_kind) : op.op_kind); |
572 | int64_t div = i_c1 / i_c0; |
573 | if (has_mod) { |
574 | switch (new_op_kind) { |
575 | case op_kind_t::_ge: |
576 | case op_kind_t::_gt: |
577 | new_op_kind = op_kind_t::_ge; |
578 | div = (sign ? div : div + 1); |
579 | break; |
580 | case op_kind_t::_le: |
581 | case op_kind_t::_lt: |
582 | new_op_kind = op_kind_t::_le; |
583 | div = (sign ? div + 1 : div); |
584 | break; |
585 | default: ir_error_not_expected(); |
586 | } |
587 | } |
588 | |
589 | return binary_op_t::make(new_op_kind, x, (sign ? -1 : 1) * div); |
590 | } |
591 | |
592 | static expr_t simplify_mod_comparison(const expr_t &e) { |
593 | if (!is_binary_cmp_op(e)) return e; |
594 | |
595 | auto &op = e.as<binary_op_t>(); |
596 | |
597 | // Use the following inequalities: |
598 | // 0 <= (x % c0) < c0 |
599 | if (!is_binary_op(op.a, op_kind_t::_mod)) return e; |
600 | if (!is_const(op.b)) return e; |
601 | |
602 | auto &a_op = op.a.as<binary_op_t>(); |
603 | if (!is_const(a_op.b)) return e; |
604 | |
605 | auto &c0 = a_op.b; |
606 | ir_assert(to_cpp<int64_t>(c0) > 0) << e; |
607 | |
608 | // Comparison against a constant is a continuous function, just check |
609 | // boundary points. |
610 | auto cond0 = binary_op_t::make(op.op_kind, 0, op.b); |
611 | auto cond1 = binary_op_t::make(op.op_kind, c0 - 1, op.b); |
612 | |
613 | bool is_cond0 = to_cpp<bool>(const_fold_non_recursive(cond0)); |
614 | bool is_cond1 = to_cpp<bool>(const_fold_non_recursive(cond1)); |
615 | |
616 | // Conditions are equal, can prove. |
617 | if (is_cond0 == is_cond1) return expr_t(is_cond0); |
618 | |
619 | // Can't prove, return the original expression. |
620 | return e; |
621 | } |
622 | }; |
623 | |
624 | expr_t simplify_comparison(const expr_t &e) { |
625 | return cmp_simplifier_t().mutate(e); |
626 | } |
627 | |
628 | class range_simplifier_t : public ir_mutator_t { |
629 | public: |
630 | range_simplifier_t(const constraint_set_t &cset) : cset(cset) {} |
631 | |
632 | object_t _mutate(const var_t &obj) override { |
633 | expr_t value; |
634 | if (cset.is_single_value(obj, value)) return std::move(value); |
635 | return obj; |
636 | } |
637 | |
638 | const constraint_set_t &cset; |
639 | }; |
640 | |
641 | // Finds all constant operands on an N-ary operation, returns the folded |
642 | // constant in `const_arg` and the remaining operands in `other_args`. |
643 | void split_const_nary_op_arg(op_kind_t op_kind, const std::vector<expr_t> &args, |
644 | expr_t &const_arg, std::vector<expr_t> &other_args) { |
645 | other_args.resize(0); |
646 | |
647 | const_arg = expr_t(); |
648 | for (auto &a : args) { |
649 | if (is_const(a)) { |
650 | if (const_arg.is_empty()) { |
651 | const_arg = a; |
652 | continue; |
653 | } |
654 | const_arg = const_fold_non_recursive( |
655 | binary_op_t::make(op_kind, const_arg, a)); |
656 | } else { |
657 | other_args.push_back(a); |
658 | } |
659 | } |
660 | } |
661 | |
662 | // Folds all constant operands into one. |
663 | void fold_const_nary_op_args(op_kind_t op_kind, const std::vector<expr_t> &args, |
664 | std::vector<expr_t> &new_args) { |
665 | expr_t c; |
666 | split_const_nary_op_arg(op_kind, args, c, new_args); |
667 | if (c.is_empty()) return; |
668 | if (op_kind == op_kind_t::_mul && is_zero(c)) { |
669 | new_args.clear(); |
670 | new_args.push_back(c); |
671 | return; |
672 | } |
673 | if (op_kind == op_kind_t::_mul && is_one(c)) return; |
674 | if (op_kind == op_kind_t::_add && is_zero(c)) return; |
675 | new_args.push_back(c); |
676 | } |
677 | |
678 | expr_t cvt_mul_to_nary_op(const expr_t &a, const expr_t &b) { |
679 | auto *a_nary = a.as_ptr<nary_op_t>(); |
680 | auto *b_nary = b.as_ptr<nary_op_t>(); |
681 | |
682 | if (a_nary) ir_assert(a_nary->op_kind == op_kind_t::_mul); |
683 | if (b_nary) ir_assert(b_nary->op_kind == op_kind_t::_mul); |
684 | |
685 | auto a_args = cvt_expr_to_nary_op_args(a); |
686 | auto b_args = cvt_expr_to_nary_op_args(b); |
687 | |
688 | std::vector<expr_t> args; |
689 | args.insert(args.end(), a_args.begin(), a_args.end()); |
690 | args.insert(args.end(), b_args.begin(), b_args.end()); |
691 | return make_nary_op(op_kind_t::_mul, args); |
692 | } |
693 | |
694 | class nary_op_visitor_t : public ir_visitor_t { |
695 | public: |
696 | using ir_visitor_t::_visit; |
697 | |
698 | virtual void _visit(const nary_op_t &obj) { visit(obj.args); } |
699 | }; |
700 | |
701 | class nary_op_mutator_t : public ir_mutator_t { |
702 | public: |
703 | using ir_mutator_t::_mutate; |
704 | |
705 | virtual object_t _mutate(const nary_op_t &obj) { |
706 | auto args = mutate(obj.args); |
707 | if (ir_utils::is_equal(args, obj.args)) return obj; |
708 | return make_nary_op(obj.op_kind, args); |
709 | } |
710 | }; |
711 | |
712 | class nary_op_transformer_t : public nary_op_mutator_t { |
713 | public: |
714 | using nary_op_mutator_t::_mutate; |
715 | |
716 | object_t _mutate(const binary_op_t &obj) override { |
717 | // Skip vector types. |
718 | if (!obj.type.is_scalar()) return nary_op_mutator_t::_mutate(obj); |
719 | switch (obj.op_kind) { |
720 | case op_kind_t::_add: |
721 | case op_kind_t::_sub: |
722 | case op_kind_t::_mul: { |
723 | auto a = mutate(obj.a); |
724 | auto b = obj.b; |
725 | auto nary_op_kind = obj.op_kind; |
726 | if (obj.op_kind == op_kind_t::_sub) { |
727 | nary_op_kind = op_kind_t::_add; |
728 | b *= -1; |
729 | } |
730 | b = mutate(b); |
731 | return make_nary_op(nary_op_kind, {a, b}); |
732 | } |
733 | default: return nary_op_mutator_t::_mutate(obj); |
734 | } |
735 | } |
736 | }; |
737 | |
738 | class nary_op_flattener_t : public nary_op_mutator_t { |
739 | public: |
740 | object_t _mutate(const nary_op_t &obj) override { |
741 | std::vector<expr_t> args; |
742 | for (auto &a : obj.args) { |
743 | auto new_a = mutate(a); |
744 | auto *nary = new_a.as_ptr<nary_op_t>(); |
745 | if (nary && nary->op_kind == obj.op_kind) { |
746 | args.insert(args.end(), nary->args.begin(), nary->args.end()); |
747 | } else { |
748 | args.push_back(new_a); |
749 | } |
750 | } |
751 | return make_nary_op(obj.op_kind, args); |
752 | } |
753 | }; |
754 | |
755 | expr_t nary_op_flatten(const expr_t &e) { |
756 | return nary_op_flattener_t().mutate(e); |
757 | } |
758 | |
759 | class mul_nary_op_expander_t : public nary_op_flattener_t { |
760 | public: |
761 | object_t _mutate(const nary_op_t &obj) override { |
762 | auto flat_object = nary_op_flattener_t::_mutate(obj); |
763 | if (obj.op_kind != op_kind_t::_mul) { return flat_object; } |
764 | |
765 | auto args = flat_object.as<nary_op_t>().args; |
766 | std::vector<expr_t> new_args; |
767 | for (size_t i = 0; i < args.size(); i++) { |
768 | auto *nary = args[i].as_ptr<nary_op_t>(); |
769 | if (nary && nary->op_kind != op_kind_t::_add) { |
770 | ir_error_not_expected(); |
771 | } |
772 | auto i_args = cvt_expr_to_nary_op_args(args[i]); |
773 | if (new_args.empty()) { |
774 | new_args = i_args; |
775 | continue; |
776 | } |
777 | std::vector<expr_t> next_args; |
778 | for (auto &a : new_args) |
779 | for (auto &b : i_args) |
780 | next_args.push_back(cvt_mul_to_nary_op(a, b)); |
781 | |
782 | new_args = next_args; |
783 | } |
784 | return make_nary_op(op_kind_t::_add, new_args); |
785 | } |
786 | }; |
787 | |
788 | class nary_op_canonical_verifier_t : public nary_op_visitor_t { |
789 | public: |
790 | bool is_canonical() const { return is_canonical_; } |
791 | |
792 | void _visit(const binary_op_t &obj) override { |
793 | // Skip vector types. |
794 | if (!obj.type.is_scalar()) { |
795 | visit_new_scope(obj); |
796 | return; |
797 | } |
798 | switch (obj.op_kind) { |
799 | // These operations must be converted to nary_op_t at this point. |
800 | case op_kind_t::_add: |
801 | case op_kind_t::_sub: |
802 | case op_kind_t::_mul: set_canonical_false(); break; |
803 | default: { |
804 | // Assume new scope here, n-ary operations from different |
805 | // scopes can't be merged. |
806 | visit_new_scope(obj); |
807 | break; |
808 | } |
809 | } |
810 | } |
811 | |
812 | void _visit(const iif_t &obj) override { visit_new_scope(obj); } |
813 | |
814 | void _visit(const load_t &obj) override { visit_new_scope(obj); } |
815 | |
816 | void _visit(const ptr_t &obj) override { visit_new_scope(obj); } |
817 | |
818 | void _visit(const unary_op_t &obj) override { visit_new_scope(obj); } |
819 | |
820 | void _visit(const nary_op_t &obj) override { |
821 | if (parent_nary_) { |
822 | if (!(parent_nary_->op_kind == op_kind_t::_add |
823 | && obj.op_kind == op_kind_t::_mul)) { |
824 | // Multiplications must be expanded at this point. |
825 | set_canonical_false(); |
826 | return; |
827 | } |
828 | } |
829 | |
830 | auto *old_parent_nary = parent_nary_; |
831 | parent_nary_ = &obj; |
832 | visit(obj.args); |
833 | parent_nary_ = old_parent_nary; |
834 | } |
835 | |
836 | private: |
837 | void set_canonical_false() { is_canonical_ = false; } |
838 | |
839 | template <typename T> |
840 | void visit_new_scope(const T &obj) { |
841 | auto *old_parent_nary = parent_nary_; |
842 | parent_nary_ = nullptr; |
843 | nary_op_visitor_t::_visit(obj); |
844 | parent_nary_ = old_parent_nary; |
845 | } |
846 | |
847 | bool is_canonical_ = true; |
848 | const nary_op_t *parent_nary_ = nullptr; |
849 | }; |
850 | |
851 | // Checks if the expression is in the canonical N-ary form. |
852 | bool is_nary_op_canonical(const expr_t &e) { |
853 | nary_op_canonical_verifier_t v; |
854 | v.visit(e); |
855 | return v.is_canonical(); |
856 | } |
857 | |
858 | class nary_op_back_transformer_t : public nary_op_mutator_t { |
859 | public: |
860 | object_t _mutate(const nary_op_t &obj) { |
861 | auto new_obj = nary_op_mutator_t::_mutate(obj); |
862 | auto &nary = new_obj.as<nary_op_t>(); |
863 | ir_assert(nary.args.size() > 0) << new_obj; |
864 | |
865 | if (nary.args.size() == 1) return nary.args[0]; |
866 | |
867 | if (nary.op_kind == op_kind_t::_add) { |
868 | expr_t ret = nary.args[0] + nary.args[1]; |
869 | for (size_t i = 2; i < nary.args.size(); i++) |
870 | ret += nary.args[i]; |
871 | return std::move(ret); |
872 | } else if (nary.op_kind == op_kind_t::_mul) { |
873 | expr_t ret = nary.args[0] * nary.args[1]; |
874 | for (size_t i = 2; i < nary.args.size(); i++) |
875 | ret *= nary.args[i]; |
876 | return std::move(ret); |
877 | } |
878 | ir_error_not_expected(); |
879 | return expr_t(); |
880 | } |
881 | }; |
882 | |
883 | // Stores factorization of an expression in the canonical (normalized) form: |
884 | // expr = (f(0), f(1), f(2), ... f(n)) |
885 | // f(0), ... f(n-1) are non-constant expressions, f(n) is a constant. |
886 | class factored_expr_t : public expr_impl_t { |
887 | public: |
888 | IR_DECL_EXPR_TYPE_ID(factored_expr_t); |
889 | |
890 | static expr_t make(const expr_t &e) { |
891 | return expr_t(new factored_expr_t(e)); |
892 | } |
893 | |
894 | static expr_t make(const type_t &type, const std::vector<expr_t> &factors) { |
895 | return expr_t(new factored_expr_t(type, factors)); |
896 | } |
897 | |
898 | bool is_equal(const object_impl_t &obj) const override { |
899 | if (!obj.is<self_type>()) return false; |
900 | auto &other = obj.as<self_type>(); |
901 | |
902 | if (factors.size() != other.factors.size()) return false; |
903 | if (!factors.back().is_equal(other.factors.back())) return false; |
904 | |
905 | auto common = intersect(obj); |
906 | auto &f_common = common.as<factored_expr_t>(); |
907 | return f_common.factors.size() == factors.size(); |
908 | } |
909 | |
910 | // Constant factor is ignored during comparison. |
911 | bool is_equal_ignore_const(const object_impl_t &obj) const { |
912 | if (!obj.is<self_type>()) return false; |
913 | auto &other = obj.as<self_type>(); |
914 | |
915 | if (factors.size() != other.factors.size()) return false; |
916 | |
917 | auto common = intersect_ignore_const(obj); |
918 | auto &f_common = common.as<factored_expr_t>(); |
919 | return f_common.factors.size() == factors.size(); |
920 | } |
921 | |
922 | size_t get_hash() const override { return ir_utils::get_hash(factors); } |
923 | |
924 | std::string str() const override { |
925 | std::ostringstream oss; |
926 | oss << "f(" ; |
927 | for (size_t i = 0; i < factors.size(); i++) { |
928 | oss << (i != 0 ? " x " : "" ) << factors[i]; |
929 | } |
930 | |
931 | if (factors.empty()) oss << "1" ; |
932 | |
933 | oss << ")" ; |
934 | return oss.str(); |
935 | } |
936 | |
937 | expr_t expr() const { |
938 | if (factors.size() > 1 && jit::is_one(factors.back())) { |
939 | std::vector<expr_t> f(factors.begin(), factors.end() - 1); |
940 | return make_nary_op(op_kind_t::_mul, f); |
941 | } |
942 | return make_nary_op(op_kind_t::_mul, factors); |
943 | } |
944 | |
945 | expr_t const_factor() const { return factors.back(); } |
946 | |
947 | bool is_one() const { |
948 | return (factors.size() == 1) && jit::is_one(factors[0]); |
949 | } |
950 | |
951 | bool is_const() const { return factors.size() == 1; } |
952 | |
953 | // Returns multiplication of this and other as factored_expr_t. |
954 | expr_t merge(const expr_t &other) const { |
955 | auto &f_other = other.as<factored_expr_t>(); |
956 | std::vector<expr_t> merged_factors(factors.begin(), factors.end()); |
957 | merged_factors.insert(merged_factors.end(), f_other.factors.begin(), |
958 | f_other.factors.end()); |
959 | return factored_expr_t::make(type, merged_factors); |
960 | } |
961 | |
962 | // Returns common factors of this and other as factored_expr_t. |
963 | expr_t intersect(const expr_t &other) const { |
964 | return intersect_impl(other, false); |
965 | } |
966 | |
967 | // Returns common factors of this and other as factored_expr_t (ignores |
968 | // constant factors). |
969 | expr_t intersect_ignore_const(const expr_t &other) const { |
970 | return intersect_impl(other, true); |
971 | } |
972 | |
973 | // Returns factors of this not presented in other as factored_expr_t. |
974 | expr_t diff(const expr_t &_other) const { |
975 | auto &other = _other.as<factored_expr_t>(); |
976 | object_eq_map_t<expr_t, int> f_map; |
977 | // Skip constant factor. |
978 | for (size_t i = 0; i < factors.size() - 1; i++) |
979 | f_map[factors[i]]++; |
980 | |
981 | for (auto &e : other.factors) { |
982 | if (f_map[e] > 0) f_map[e]--; |
983 | } |
984 | std::vector<expr_t> diff_factors; |
985 | for (auto &kv : f_map) { |
986 | for (int i = 0; i < kv.second; i++) |
987 | diff_factors.push_back(kv.first); |
988 | } |
989 | // Handle constant factor. |
990 | int64_t a_const = to_cpp<int64_t>(factors.back()); |
991 | int64_t b_const = to_cpp<int64_t>(other.factors.back()); |
992 | if (a_const != 0 && b_const != 0) { |
993 | int64_t ab_gcd = ((a_const < 0) && (b_const < 0)) ? -1 : 1; |
994 | ab_gcd *= math::gcd(std::abs(a_const), std::abs(b_const)); |
995 | diff_factors.push_back(to_expr(a_const / ab_gcd, type)); |
996 | } else if (a_const != 0 || b_const != 0) { |
997 | diff_factors.push_back(to_expr(a_const, type)); |
998 | } |
999 | |
1000 | return factored_expr_t::make(type, diff_factors); |
1001 | } |
1002 | |
1003 | // Returns factors of this reduced by factors of other as factored_expr_t. |
1004 | // This object must be reducible by other. |
1005 | expr_t reduce(const expr_t &other) const { |
1006 | auto &f_other = other.as<factored_expr_t>(); |
1007 | auto f_common = intersect(other); |
1008 | auto diff_other = f_other.diff(f_common); |
1009 | // Other must be reducible. |
1010 | ir_assert(diff_other.as<factored_expr_t>().is_one()) << diff_other; |
1011 | return diff(f_common); |
1012 | } |
1013 | |
1014 | // Returns true if this can be reduced by other. |
1015 | bool is_reducible(const expr_t &other) const { |
1016 | auto f_common = intersect(other); |
1017 | return f_common.is_equal(other); |
1018 | } |
1019 | |
1020 | static expr_t reduce(expr_t &a, expr_t &b) { |
1021 | auto fa_expr = factored_expr_t::make(a); |
1022 | auto fb_expr = factored_expr_t::make(b); |
1023 | auto &fa = fa_expr.as<factored_expr_t>(); |
1024 | auto &fb = fb_expr.as<factored_expr_t>(); |
1025 | auto f_common = fa.intersect(&fb); |
1026 | a = fa.reduce(f_common).as<factored_expr_t>().expr(); |
1027 | b = fb.reduce(f_common).as<factored_expr_t>().expr(); |
1028 | return f_common; |
1029 | } |
1030 | |
1031 | std::vector<expr_t> factors; |
1032 | |
1033 | private: |
1034 | factored_expr_t(const expr_t &e) : expr_impl_t(_type_info(), e.type()) { |
1035 | init_factors(e); |
1036 | } |
1037 | |
1038 | factored_expr_t(const type_t &type, const std::vector<expr_t> &factors) |
1039 | : expr_impl_t(_type_info(), type) { |
1040 | init_normalize(factors); |
1041 | } |
1042 | |
1043 | void init_normalize(const std::vector<expr_t> &f) { |
1044 | bool sign = false; |
1045 | expr_t e_const = to_expr(1); |
1046 | for (auto &e : f) { |
1047 | if (!jit::is_const(e)) { |
1048 | factors.push_back(e); |
1049 | continue; |
1050 | } |
1051 | if (to_cpp<int64_t>(e) < 0) sign = !sign; |
1052 | if (jit::is_one(e) || jit::is_minus_one(e)) continue; |
1053 | |
1054 | e_const = e_const * abs(e); |
1055 | } |
1056 | if (sign) e_const = -e_const; |
1057 | factors.push_back(e_const); |
1058 | } |
1059 | |
1060 | void init_factors(const expr_t &e) { |
1061 | auto *nary = e.as_ptr<nary_op_t>(); |
1062 | if (!nary) { |
1063 | auto *unary = e.as_ptr<unary_op_t>(); |
1064 | if (unary && unary->op_kind == op_kind_t::_minus) { |
1065 | init_factors(unary->a); |
1066 | factors.back() *= -1; |
1067 | return; |
1068 | } |
1069 | init_normalize({e}); |
1070 | return; |
1071 | } |
1072 | |
1073 | if (nary->op_kind == op_kind_t::_mul) { |
1074 | expr_t f_mul = factored_expr_t::make(to_expr(1)); |
1075 | for (auto &a : nary->args) { |
1076 | f_mul = f_mul.as<factored_expr_t>().merge( |
1077 | factored_expr_t::make(a)); |
1078 | } |
1079 | factors = f_mul.as<factored_expr_t>().factors; |
1080 | return; |
1081 | } |
1082 | |
1083 | if (nary->op_kind == op_kind_t::_add) { |
1084 | expr_t common; |
1085 | for (auto &a : nary->args) { |
1086 | if (common.is_empty()) { |
1087 | common = factored_expr_t::make(a); |
1088 | continue; |
1089 | } |
1090 | common = common.as<factored_expr_t>().intersect( |
1091 | factored_expr_t::make(a)); |
1092 | } |
1093 | if (common.as<factored_expr_t>().is_one()) { |
1094 | init_normalize({e}); |
1095 | return; |
1096 | } |
1097 | std::vector<expr_t> rest_factors; |
1098 | for (auto &a : nary->args) { |
1099 | auto fa_expr = factored_expr_t::make(a); |
1100 | auto &fa = fa_expr.as<factored_expr_t>(); |
1101 | rest_factors.push_back( |
1102 | fa.reduce(common).as<factored_expr_t>().expr()); |
1103 | } |
1104 | auto &f_common = common.as<factored_expr_t>(); |
1105 | auto rest = factored_expr_t::make( |
1106 | make_nary_op(op_kind_t::_add, rest_factors)); |
1107 | factors = f_common.merge(rest).as<factored_expr_t>().factors; |
1108 | return; |
1109 | } |
1110 | ir_error_not_expected(); |
1111 | } |
1112 | |
1113 | expr_t intersect_impl(const expr_t &other, bool ignore_constants) const { |
1114 | auto &f_other = other.as<factored_expr_t>(); |
1115 | object_eq_map_t<expr_t, int> f_map; |
1116 | // Skip constant factor. |
1117 | for (size_t i = 0; i < factors.size() - 1; i++) |
1118 | f_map[factors[i]]++; |
1119 | |
1120 | std::vector<expr_t> common_factors; |
1121 | for (auto &e : f_other.factors) { |
1122 | auto it = f_map.find(e); |
1123 | if (it == f_map.end() || it->second == 0) continue; |
1124 | f_map[e]--; |
1125 | common_factors.push_back(e); |
1126 | } |
1127 | |
1128 | if (ignore_constants) |
1129 | return factored_expr_t::make(type, common_factors); |
1130 | |
1131 | // Handle constant factor. |
1132 | int64_t a_const = to_cpp<int64_t>(factors.back()); |
1133 | int64_t b_const = to_cpp<int64_t>(f_other.factors.back()); |
1134 | if (a_const != 0 && b_const != 0) { |
1135 | int64_t ab_gcd = ((a_const < 0) && (b_const < 0)) ? -1 : 1; |
1136 | ab_gcd *= math::gcd(std::abs(a_const), std::abs(b_const)); |
1137 | if (ab_gcd != 1) common_factors.push_back(to_expr(ab_gcd, type)); |
1138 | } else if (a_const == 0 && b_const == 0) { |
1139 | common_factors.push_back(to_expr(0, type)); |
1140 | } |
1141 | |
1142 | return factored_expr_t::make(type, common_factors); |
1143 | } |
1144 | }; |
1145 | |
1146 | class division_reducer_t : public nary_op_mutator_t { |
1147 | public: |
1148 | using nary_op_mutator_t::_mutate; |
1149 | |
1150 | object_t _mutate(const binary_op_t &obj) override { |
1151 | if (obj.op_kind != op_kind_t::_div) |
1152 | return nary_op_mutator_t::_mutate(obj); |
1153 | |
1154 | expr_t a = mutate(obj.a); |
1155 | expr_t b = mutate(obj.b); |
1156 | |
1157 | factored_expr_t::reduce(a, b); |
1158 | |
1159 | if (is_one(b)) return std::move(a); |
1160 | |
1161 | return binary_op_t::make(op_kind_t::_div, a, b); |
1162 | } |
1163 | }; |
1164 | |
1165 | bool is_divisible( |
1166 | const expr_t &a, const expr_t &b, const constraint_set_t &cset) { |
1167 | if (cset.can_prove(a % b == 0, /*try_simplify=*/false)) return true; |
1168 | |
1169 | // Try to find b in factors of a. |
1170 | auto fa = factored_expr_t::make(a); |
1171 | auto fb = factored_expr_t::make(b); |
1172 | return fa.as<factored_expr_t>().is_reducible(fb); |
1173 | } |
1174 | |
1175 | class int_div_mod_expander_t : public nary_op_mutator_t { |
1176 | public: |
1177 | using nary_op_mutator_t::_mutate; |
1178 | |
1179 | int_div_mod_expander_t(const constraint_set_t &cset) : cset(cset) {} |
1180 | |
1181 | object_t _mutate(const binary_op_t &_obj) override { |
1182 | auto obj = nary_op_mutator_t::_mutate(_obj); |
1183 | auto *binary_op = obj.as_ptr<binary_op_t>(); |
1184 | if (!binary_op) return obj; |
1185 | if (!utils::one_of( |
1186 | binary_op->op_kind, op_kind_t::_div, op_kind_t::_mod)) |
1187 | return obj; |
1188 | if (!binary_op->type.is_int()) return obj; |
1189 | |
1190 | auto a = binary_op->a; |
1191 | auto b = binary_op->b; |
1192 | |
1193 | auto _b = nary_op_back_transform(b); |
1194 | if (!cset.can_prove(_b > 0)) return obj; |
1195 | |
1196 | auto *a_nary = a.as_ptr<nary_op_t>(); |
1197 | |
1198 | if (a_nary && a_nary->op_kind == op_kind_t::_add) |
1199 | return mutate_with_add(*binary_op); |
1200 | |
1201 | // Try to reduce a and b. |
1202 | auto common_factor = factored_expr_t::reduce(a, b); |
1203 | |
1204 | if (is_one(b)) { |
1205 | if (binary_op->op_kind == op_kind_t::_mod) |
1206 | return to_expr(0, binary_op->type); |
1207 | if (binary_op->op_kind == op_kind_t::_div) return std::move(a); |
1208 | } |
1209 | |
1210 | if (binary_op->op_kind == op_kind_t::_div) { |
1211 | return a / b; |
1212 | } else if (binary_op->op_kind == op_kind_t::_mod) { |
1213 | auto &c = common_factor.as<factored_expr_t>(); |
1214 | if (c.is_const() && to_cpp<int64_t>(c.const_factor()) > 1) |
1215 | return make_nary_op(op_kind_t::_mul, {c.const_factor(), a % b}); |
1216 | } |
1217 | |
1218 | return obj; |
1219 | } |
1220 | |
1221 | expr_t mutate_with_add(const binary_op_t &obj) { |
1222 | expr_t e = obj; |
1223 | if (reduce_v1(e)) return e; |
1224 | if (reduce_v2(e)) return e; |
1225 | return e; |
1226 | } |
1227 | |
1228 | // Applies the following rules: |
1229 | // 1) (A + B) % C -> B % C, when |
1230 | // - A % C == 0 |
1231 | // - B >= 0 |
1232 | // 2) (A + B) / C -> (A / C) + (B / C), when |
1233 | // - A % C == 0 |
1234 | // - B >= 0 |
1235 | bool reduce_v1(expr_t &expr) { |
1236 | auto *binary_op = expr.as_ptr<binary_op_t>(); |
1237 | if (!binary_op) return false; |
1238 | |
1239 | auto op_kind = binary_op->op_kind; |
1240 | auto &a = binary_op->a; |
1241 | auto &b = binary_op->b; |
1242 | |
1243 | std::vector<expr_t> lhs_args; // Reducible summands. |
1244 | std::vector<expr_t> rhs_args; // Non-reducible summands. |
1245 | |
1246 | auto *a_nary = a.as_ptr<nary_op_t>(); |
1247 | for (auto &e : a_nary->args) { |
1248 | if (is_div_reducible(e, b)) { |
1249 | lhs_args.push_back(e); |
1250 | } else { |
1251 | rhs_args.push_back(e); |
1252 | } |
1253 | } |
1254 | |
1255 | // Nothing to reduce, return expression as is. |
1256 | if (lhs_args.empty()) return false; |
1257 | |
1258 | auto rhs_nary = make_nary_op(op_kind_t::_add, rhs_args); |
1259 | auto _rhs = nary_op_back_transform(rhs_nary); |
1260 | bool rhs_ge_0 = cset.can_prove(_rhs >= 0); |
1261 | |
1262 | if (op_kind == op_kind_t::_mod) { |
1263 | if (rhs_args.empty()) { |
1264 | expr = to_expr(0, expr.type()); |
1265 | return true; |
1266 | } |
1267 | if (!rhs_ge_0) return false; |
1268 | expr = rhs_nary % b; |
1269 | return true; |
1270 | } |
1271 | |
1272 | if (op_kind == op_kind_t::_div) { |
1273 | if (!rhs_ge_0) return false; |
1274 | if (rhs_args.empty()) { |
1275 | expr = mutate(lhs_args[0] / b); |
1276 | for (int i = 1; i < int(lhs_args.size()); i++) { |
1277 | expr += mutate(lhs_args[i] / b); |
1278 | } |
1279 | return true; |
1280 | } |
1281 | auto lhs_div = make_nary_op(op_kind_t::_add, lhs_args) / b; |
1282 | auto rhs_div = rhs_nary / b; |
1283 | expr = mutate(lhs_div) + mutate(rhs_div); |
1284 | return true; |
1285 | } |
1286 | |
1287 | ir_error_not_expected() << expr; |
1288 | |
1289 | return false; |
1290 | } |
1291 | |
1292 | // Applies the following rules: |
1293 | // 1) (A * B + D) / (A * C) -> (A * B) / (A * C), when |
1294 | // - A > 0 |
1295 | // - C > 0 |
1296 | // - 0 <= D < A |
1297 | // 2) (A * B + D) % (A * C) -> (A * B) % (A * C) + D % (A * C), when |
1298 | // - A > 0 |
1299 | // - C > 0 |
1300 | // - 0 <= D < A |
1301 | bool reduce_v2(expr_t &expr) { |
1302 | auto *binary_op = expr.as_ptr<binary_op_t>(); |
1303 | if (!binary_op) return false; |
1304 | |
1305 | auto op_kind = binary_op->op_kind; |
1306 | auto &a = binary_op->a; |
1307 | auto &b = binary_op->b; |
1308 | if (!is_const(b)) return false; |
1309 | |
1310 | auto const_factor = [&](const expr_t &e) { |
1311 | auto _fe = factored_expr_t::make(e); |
1312 | auto &fe = _fe.as<factored_expr_t>(); |
1313 | auto ret = to_cpp<int64_t>(fe.const_factor()); |
1314 | for (auto &f : fe.factors) |
1315 | if (is_var(f)) ret *= cset.max_proven_gcd(f); |
1316 | return ret; |
1317 | }; |
1318 | |
1319 | // TODO: Check 0. |
1320 | // Find max constant GCD. |
1321 | int64_t b_gcd = const_factor(b); |
1322 | int64_t max_gcd = 0; |
1323 | auto *a_nary = a.as_ptr<nary_op_t>(); |
1324 | for (auto &e : a_nary->args) { |
1325 | int64_t gcd = math::gcd(b_gcd, const_factor(e)); |
1326 | if (gcd > max_gcd) max_gcd = gcd; |
1327 | } |
1328 | |
1329 | if (max_gcd == 0) return false; |
1330 | |
1331 | std::vector<expr_t> lhs_args; // Reducible summands. |
1332 | std::vector<expr_t> rhs_args; // Non-reducible summands. |
1333 | for (auto &e : a_nary->args) { |
1334 | if (is_div_reducible(e, max_gcd)) { |
1335 | lhs_args.push_back(e); |
1336 | } else { |
1337 | rhs_args.push_back(e); |
1338 | } |
1339 | } |
1340 | |
1341 | // max_gcd is the GCD for some summand so at least one summand must be |
1342 | // reducible. |
1343 | ir_assert(!lhs_args.empty()); |
1344 | |
1345 | if (rhs_args.empty()) return false; |
1346 | |
1347 | int64_t A = max_gcd; |
1348 | int64_t C = to_cpp<int64_t>(b) / A; |
1349 | if (A <= 0 || C <= 0) return false; |
1350 | |
1351 | auto rhs_nary = make_nary_op(op_kind_t::_add, rhs_args); |
1352 | auto D = nary_op_back_transform(rhs_nary); |
1353 | if (!cset.can_prove(D >= 0) || !cset.can_prove(D < A)) return false; |
1354 | |
1355 | if (op_kind == op_kind_t::_mod) { |
1356 | auto lhs_mod = make_nary_op(op_kind_t::_add, lhs_args) % b; |
1357 | auto rhs_mod = rhs_nary % b; |
1358 | expr = mutate(lhs_mod) + mutate(rhs_mod); |
1359 | return true; |
1360 | } |
1361 | |
1362 | if (op_kind == op_kind_t::_div) { |
1363 | auto lhs_div = make_nary_op(op_kind_t::_add, lhs_args) / b; |
1364 | expr = lhs_div; |
1365 | return true; |
1366 | } |
1367 | |
1368 | ir_error_not_expected() << expr; |
1369 | |
1370 | return false; |
1371 | } |
1372 | |
1373 | bool is_div_reducible(const expr_t &a, const expr_t &b) const { |
1374 | if (is_const(a) && is_const(b)) { |
1375 | return to_cpp<int64_t>(a) % to_cpp<int64_t>(b) == 0; |
1376 | } |
1377 | |
1378 | if (b.is_equal(to_expr(1, b.type()))) return true; |
1379 | |
1380 | return is_divisible(a, b, cset); |
1381 | } |
1382 | |
1383 | const constraint_set_t &cset; |
1384 | }; |
1385 | |
1386 | class int_div_mod_range_simplifier_t : public nary_op_mutator_t { |
1387 | public: |
1388 | using nary_op_mutator_t::_mutate; |
1389 | |
1390 | int_div_mod_range_simplifier_t(const constraint_set_t &cset) : cset(cset) {} |
1391 | |
1392 | object_t _mutate(const binary_op_t &obj) override { |
1393 | if (!utils::one_of(obj.op_kind, op_kind_t::_div, op_kind_t::_mod)) |
1394 | return nary_op_mutator_t::_mutate(obj); |
1395 | |
1396 | auto a = mutate(obj.a); |
1397 | auto b = mutate(obj.b); |
1398 | |
1399 | auto _a = nary_op_back_transform(a); |
1400 | auto _b = nary_op_back_transform(b); |
1401 | |
1402 | // 0 <= a < b => (a / b) == 0 |
1403 | bool abs_a_lt_b = cset.can_prove(_a >= 0) && cset.can_prove(_a < _b); |
1404 | |
1405 | // 0 <= a < b => (a % b) == a |
1406 | if (abs_a_lt_b) { |
1407 | if (obj.op_kind == op_kind_t::_div) return to_expr(0); |
1408 | if (obj.op_kind == op_kind_t::_mod) return a; |
1409 | } |
1410 | |
1411 | return binary_op_t::make(obj.op_kind, a, b); |
1412 | } |
1413 | |
1414 | const constraint_set_t &cset; |
1415 | }; |
1416 | |
1417 | // Factors out common factors in an N-ary expression. |
1418 | class common_factor_simplifier_t : public nary_op_mutator_t { |
1419 | public: |
1420 | object_t _mutate(const nary_op_t &obj) override { |
1421 | if (obj.op_kind != op_kind_t::_add) |
1422 | return nary_op_mutator_t::_mutate(obj); |
1423 | |
1424 | auto args = mutate(obj.args); |
1425 | for (auto &a : args) { |
1426 | auto *nary = a.as_ptr<nary_op_t>(); |
1427 | if (nary) ir_assert(nary->op_kind == op_kind_t::_mul) << a; |
1428 | } |
1429 | |
1430 | // Fold same factors (find exact match, ignore constants). |
1431 | // Example: |
1432 | // (a * c1 + a * c2 + b) -> |
1433 | // (a * c3 + b) where c3 = (c1 + c2) |
1434 | for (size_t i = 0; i < args.size(); i++) { |
1435 | auto e_fi = factored_expr_t::make(args[i]); |
1436 | for (size_t j = i + 1; j < args.size(); j++) { |
1437 | auto e_fj = factored_expr_t::make(args[j]); |
1438 | |
1439 | auto &fi = e_fi.as<factored_expr_t>(); |
1440 | auto &fj = e_fj.as<factored_expr_t>(); |
1441 | |
1442 | auto e_fij_common = fi.intersect_ignore_const(e_fj); |
1443 | auto &fij_common = e_fij_common.as<factored_expr_t>(); |
1444 | if (fi.is_equal_ignore_const(fij_common) |
1445 | && fj.is_equal_ignore_const(fij_common)) { |
1446 | auto new_args = fij_common.factors; |
1447 | new_args.push_back(fi.const_factor() + fj.const_factor()); |
1448 | args[i] = make_nary_op(op_kind_t::_mul, new_args); |
1449 | e_fi = factored_expr_t::make(args[i]); |
1450 | args[j] = to_expr(0, args[j].type()); |
1451 | } |
1452 | } |
1453 | } |
1454 | |
1455 | // Partial folding (fold any match). |
1456 | // Example: |
1457 | // (a * b * c + a * b * d + e) -> |
1458 | // ((a * b * (c + d)) + e) |
1459 | for (size_t i = 0; i < args.size(); i++) { |
1460 | if (is_zero(args[i])) continue; |
1461 | auto e_fi = factored_expr_t::make(args[i]); |
1462 | for (size_t j = i + 1; j < args.size(); j++) { |
1463 | if (is_zero(args[j])) continue; |
1464 | auto e_fj = factored_expr_t::make(args[j]); |
1465 | |
1466 | auto &fi = e_fi.as<factored_expr_t>(); |
1467 | auto &fj = e_fj.as<factored_expr_t>(); |
1468 | |
1469 | auto e_fij_common = fi.intersect_ignore_const(e_fj); |
1470 | auto &fij_common = e_fij_common.as<factored_expr_t>(); |
1471 | |
1472 | // fij_common = 1 means no common factors, other constant |
1473 | // factors are also ignored, for simplicity (though it might be |
1474 | // beneficial to fold them as well). |
1475 | if (fij_common.is_const()) continue; |
1476 | |
1477 | // factored_expr_t::make() will find common factors. |
1478 | auto e_fi_add_fj = factored_expr_t::make( |
1479 | make_nary_op(op_kind_t::_add, {fi.expr(), fj.expr()})); |
1480 | auto &fi_add_fj = e_fi_add_fj.as<factored_expr_t>(); |
1481 | args[i] = make_nary_op(op_kind_t::_mul, fi_add_fj.factors); |
1482 | e_fi = e_fi_add_fj; |
1483 | args[j] = to_expr(0, args[j].type()); |
1484 | } |
1485 | } |
1486 | |
1487 | return make_nary_op(obj.op_kind, args); |
1488 | } |
1489 | }; |
1490 | |
1491 | // Rewrites addition with mixed 64-bit/32-bit expressions to reduce 64-bit |
1492 | // arithmetic. Example: |
1493 | // Before: ((x.s64 + y.s32) + z.s32) [two 64-bit add] |
1494 | // After: ((y.s32 + z.s32) + x.s64) [one 32-bit add and one 64-bit add] |
1495 | class _64_bit_add_optimizer_t : public nary_op_mutator_t { |
1496 | public: |
1497 | object_t _mutate(const nary_op_t &obj) override { |
1498 | auto new_obj = nary_op_mutator_t::_mutate(obj); |
1499 | auto *nary_op = new_obj.as_ptr<nary_op_t>(); |
1500 | if (nary_op->op_kind != op_kind_t::_add || nary_op->args.size() <= 2) |
1501 | return new_obj; |
1502 | |
1503 | std::vector<expr_t> other_args; |
1504 | std::vector<expr_t> x64_args; |
1505 | for (auto &a : nary_op->args) { |
1506 | if (a.type().is_x64()) { |
1507 | x64_args.push_back(a); |
1508 | } else { |
1509 | other_args.push_back(a); |
1510 | } |
1511 | } |
1512 | |
1513 | if (other_args.empty() || x64_args.empty()) return new_obj; |
1514 | |
1515 | std::vector<expr_t> new_args = std::move(other_args); |
1516 | new_args.insert(new_args.end(), x64_args.begin(), x64_args.end()); |
1517 | |
1518 | return nary_op_t::make(nary_op->op_kind, new_args); |
1519 | } |
1520 | }; |
1521 | |
1522 | // Simplifies using the N-ary form. |
1523 | expr_t simplify_with_nary(const expr_t &_e, const constraint_set_t &cset) { |
1524 | auto e = _e; |
1525 | |
1526 | if (!e.type().is_scalar() || e.type().is_fp()) { return e; } |
1527 | e = nary_op_canonicalize(e); |
1528 | |
1529 | e = division_reducer_t().mutate(e); |
1530 | e = nary_op_flatten(e); |
1531 | e = int_div_mod_expander_t(cset).mutate(e); |
1532 | e = common_factor_simplifier_t().mutate(e); |
1533 | e = int_div_mod_range_simplifier_t(cset).mutate(e); |
1534 | e = _64_bit_add_optimizer_t().mutate(e); |
1535 | |
1536 | e = nary_op_back_transform(e); |
1537 | |
1538 | return e; |
1539 | } |
1540 | |
1541 | class stmt_simplifier_t : public ir_mutator_t { |
1542 | public: |
1543 | stmt_simplifier_t(const constraint_set_t &cset) : cset_(cset) {} |
1544 | |
1545 | ~stmt_simplifier_t() override { |
1546 | if (!cpp_compat::uncaught_exceptions()) { |
1547 | ir_assert(continue_calls_.empty()) << "Unexpected continue calls." ; |
1548 | } |
1549 | } |
1550 | |
1551 | object_t _mutate(const binary_op_t &obj) override { |
1552 | return simplify(obj, cset_); |
1553 | } |
1554 | |
1555 | object_t _mutate(const func_call_t &obj) override { |
1556 | if (obj.func.is_equal(funcs::continue_func())) { |
1557 | continue_calls_.push_back(obj); |
1558 | } |
1559 | return ir_mutator_t::_mutate(obj); |
1560 | } |
1561 | |
1562 | object_t _mutate(const if_t &obj) override { |
1563 | auto cond = simplify(obj.cond); |
1564 | |
1565 | if (all_of(cond, expr_t(true))) return mutate(obj.body); |
1566 | if (all_of(cond, expr_t(false))) return mutate(obj.else_body); |
1567 | |
1568 | auto body = obj.body; |
1569 | if (!body.is_empty()) { |
1570 | auto cset_old = cset_; |
1571 | cset_.add_constraint(cond); |
1572 | body = ir_mutator_t::mutate(body); |
1573 | cset_ = cset_old; |
1574 | } |
1575 | |
1576 | auto else_body = obj.else_body; |
1577 | if (!else_body.is_empty()) { |
1578 | auto cset_old = cset_; |
1579 | cset_.add_constraint(flip_condition(cond)); |
1580 | else_body = ir_mutator_t::mutate(else_body); |
1581 | cset_ = cset_old; |
1582 | } |
1583 | |
1584 | return if_t::make(cond, body, else_body); |
1585 | } |
1586 | |
1587 | object_t _mutate(const let_t &obj) override { |
1588 | // External variable. |
1589 | if (obj.value.is_empty()) return ir_mutator_t::_mutate(obj); |
1590 | |
1591 | // Substitute constants. |
1592 | auto value = simplify(obj.value); |
1593 | if (is_const(value)) { |
1594 | // Constants are not necessarily the same type as the assigned |
1595 | // variable |
1596 | value = cast_t::make(obj.var.as<var_t>().type, value); |
1597 | |
1598 | auto body = substitute(obj.body, obj.var, value); |
1599 | return mutate(body); |
1600 | } else if (is_var(value)) { |
1601 | auto body = substitute(obj.body, obj.var, value); |
1602 | return mutate(body); |
1603 | } |
1604 | |
1605 | auto cset_old = cset_; |
1606 | cset_.add_constraint(obj.var == value); |
1607 | auto new_obj = let_t::make(obj.var, value, obj.body); |
1608 | new_obj = ir_mutator_t::_mutate(new_obj.as<let_t>()); |
1609 | cset_ = cset_old; |
1610 | |
1611 | return std::move(new_obj); |
1612 | } |
1613 | |
1614 | object_t _mutate(const for_t &obj) override { |
1615 | object_t new_obj; |
1616 | bool found_continue = false; |
1617 | size_t ncontinue_calls = continue_calls_.size(); |
1618 | if (is_zero(obj.init) && is_one(obj.bound)) { |
1619 | auto body = substitute(obj.body, obj.var, expr_t(0)); |
1620 | body = mutate(body); |
1621 | if (continue_calls_.size() > ncontinue_calls) found_continue = true; |
1622 | if (found_continue) { |
1623 | new_obj = for_t::make( |
1624 | obj.var, obj.init, obj.bound, body, obj.unroll); |
1625 | } else { |
1626 | new_obj = body; |
1627 | } |
1628 | } else { |
1629 | auto cset_old = cset_; |
1630 | cset_.add_constraint(obj.var >= obj.init); |
1631 | cset_.add_constraint(obj.var < obj.bound); |
1632 | new_obj = ir_mutator_t::_mutate(obj); |
1633 | if (continue_calls_.size() > ncontinue_calls) found_continue = true; |
1634 | cset_ = cset_old; |
1635 | } |
1636 | |
1637 | // Remove continue call. |
1638 | if (found_continue) continue_calls_.pop_back(); |
1639 | |
1640 | return new_obj; |
1641 | } |
1642 | |
1643 | object_t _mutate(const store_t &obj) override { |
1644 | auto new_obj = ir_mutator_t::_mutate(obj); |
1645 | if (new_obj.is_empty()) return stmt_t(); |
1646 | |
1647 | auto &store = new_obj.as<store_t>(); |
1648 | if (!store.value.is<load_t>()) return new_obj; |
1649 | |
1650 | auto &load = store.value.as<load_t>(); |
1651 | if (!store.buf.is_equal(load.buf)) return new_obj; |
1652 | if (!store.off.is_equal(load.off)) return new_obj; |
1653 | if (store.stride != load.stride) return new_obj; |
1654 | |
1655 | // This is a load/store of the same value which is a no-op. |
1656 | return stmt_t(); |
1657 | } |
1658 | |
1659 | private: |
1660 | static op_kind_t flip_cmp_op(op_kind_t op_kind) { |
1661 | switch (op_kind) { |
1662 | case op_kind_t::_eq: return op_kind_t::_ne; |
1663 | case op_kind_t::_ge: return op_kind_t::_lt; |
1664 | case op_kind_t::_gt: return op_kind_t::_le; |
1665 | case op_kind_t::_le: return op_kind_t::_gt; |
1666 | case op_kind_t::_lt: return op_kind_t::_ge; |
1667 | case op_kind_t::_ne: return op_kind_t::_eq; |
1668 | default: ir_error_not_expected(); |
1669 | } |
1670 | return op_kind_t::undef; |
1671 | } |
1672 | |
1673 | static expr_t flip_condition(const expr_t &cond) { |
1674 | ir_assert(cond.type().is_bool()); |
1675 | |
1676 | auto *binary_op = cond.as_ptr<binary_op_t>(); |
1677 | if (binary_op) { |
1678 | auto &a = binary_op->a; |
1679 | auto &b = binary_op->b; |
1680 | auto op_kind = binary_op->op_kind; |
1681 | return binary_op_t::make(flip_cmp_op(op_kind), a, b); |
1682 | } |
1683 | |
1684 | auto *shuffle = cond.as_ptr<shuffle_t>(); |
1685 | if (shuffle && shuffle->is_broadcast()) { |
1686 | return shuffle_t::make_broadcast( |
1687 | flip_condition(shuffle->vec[0]), shuffle->elems()); |
1688 | } |
1689 | |
1690 | ir_error_not_expected(); |
1691 | return expr_t(); |
1692 | } |
1693 | |
1694 | constraint_set_t cset_; |
1695 | std::vector<stmt_t> continue_calls_; |
1696 | }; |
1697 | |
1698 | expr_t simplify_expr(const expr_t &_e, const constraint_set_t &cset) { |
1699 | expr_t e = _e; |
1700 | |
1701 | if (is_const(e) || is_var(e)) return e; |
1702 | |
1703 | e = const_fold(e); |
1704 | e = simplify_rewrite(e); |
1705 | |
1706 | e = simplify_comparison(e); |
1707 | e = range_simplifier_t(cset).mutate(e); |
1708 | e = simplify_with_nary(e, cset); |
1709 | |
1710 | e = const_fold(e); |
1711 | e = simplify_rewrite(e); |
1712 | |
1713 | return e; |
1714 | } |
1715 | |
1716 | stmt_t simplify_stmt(const stmt_t &s, const constraint_set_t &cset) { |
1717 | stmt_simplifier_t simplifier(cset); |
1718 | return simplifier.mutate(s); |
1719 | } |
1720 | |
1721 | int64_t get_max_const_factor(const expr_t &_e, const constraint_set_t &cset) { |
1722 | ir_assert(_e.type().is_int()); |
1723 | auto e = _e; |
1724 | // Some complex expressions need more than one simplify() call. |
1725 | int max_tries = 3; |
1726 | for (int i = 0; i < max_tries; i++) |
1727 | e = simplify(e, cset); |
1728 | auto o = factored_expr_t::make(nary_op_canonicalize(e)); |
1729 | auto &expr = o.as<factored_expr_t>(); |
1730 | return to_cpp<int64_t>(expr.const_factor()); |
1731 | } |
1732 | |
1733 | template <op_kind_t op_kind> |
1734 | struct op_traits_t {}; |
1735 | |
1736 | #define DECL_OP_TRAITS(name, op) \ |
1737 | template <> \ |
1738 | struct op_traits_t<name> { \ |
1739 | template <typename T, \ |
1740 | typename = typename std::enable_if< \ |
1741 | !std::is_same<T, bool>::value>::type> \ |
1742 | static auto compute(T a, T b) -> decltype(a op b) { \ |
1743 | return a op b; \ |
1744 | } \ |
1745 | template <op_kind_t dummy_op = name, \ |
1746 | typename \ |
1747 | = typename std::enable_if<dummy_op == op_kind_t::_and>::type> \ |
1748 | static bool compute(bool a, bool b) { \ |
1749 | return a op b; \ |
1750 | } \ |
1751 | }; |
1752 | |
1753 | DECL_OP_TRAITS(op_kind_t::_add, +) |
1754 | DECL_OP_TRAITS(op_kind_t::_sub, -) |
1755 | DECL_OP_TRAITS(op_kind_t::_mul, *) |
1756 | DECL_OP_TRAITS(op_kind_t::_div, /) |
1757 | DECL_OP_TRAITS(op_kind_t::_mod, %) |
1758 | |
1759 | DECL_OP_TRAITS(op_kind_t::_eq, ==) |
1760 | DECL_OP_TRAITS(op_kind_t::_ne, !=) |
1761 | DECL_OP_TRAITS(op_kind_t::_gt, >) |
1762 | DECL_OP_TRAITS(op_kind_t::_ge, >=) |
1763 | DECL_OP_TRAITS(op_kind_t::_lt, <) |
1764 | DECL_OP_TRAITS(op_kind_t::_le, <=) |
1765 | |
1766 | DECL_OP_TRAITS(op_kind_t::_and, &) |
1767 | |
1768 | template <> |
1769 | struct op_traits_t<op_kind_t::_min> { |
1770 | template <typename T> |
1771 | static T compute(T a, T b) { |
1772 | return std::min(a, b); |
1773 | } |
1774 | }; |
1775 | |
1776 | template <> |
1777 | struct op_traits_t<op_kind_t::_max> { |
1778 | template <typename T> |
1779 | static T compute(T a, T b) { |
1780 | return std::max(a, b); |
1781 | } |
1782 | }; |
1783 | |
1784 | #undef DECL_OP_TRAITS |
1785 | |
1786 | template <op_kind_t op_kind, typename T, typename = void> |
1787 | struct compute_helper_t { |
1788 | static expr_t call(T a, T b) { return expr_t(); } |
1789 | }; |
1790 | |
1791 | template <typename> |
1792 | struct voider_t { |
1793 | using type = void; |
1794 | }; |
1795 | |
1796 | template <op_kind_t op_kind, typename T> |
1797 | struct compute_helper_t<op_kind, T, |
1798 | typename voider_t<decltype( |
1799 | op_traits_t<op_kind>::compute(T(), T()))>::type> { |
1800 | static expr_t call(T a, T b) { |
1801 | return to_expr(op_traits_t<op_kind>::compute(a, b)); |
1802 | } |
1803 | }; |
1804 | |
1805 | template <typename T> |
1806 | class const_fold_helper_t { |
1807 | public: |
1808 | template <typename U = T> |
1809 | static expr_t call(op_kind_t op_kind, T a, T b) { |
1810 | switch (op_kind) { |
1811 | #define CASE(op) \ |
1812 | case op: return compute_helper_t<op, T>::call(a, b); |
1813 | |
1814 | CASE(op_kind_t::_add) |
1815 | CASE(op_kind_t::_sub) |
1816 | CASE(op_kind_t::_mul) |
1817 | CASE(op_kind_t::_div) |
1818 | CASE(op_kind_t::_mod) |
1819 | |
1820 | CASE(op_kind_t::_eq) |
1821 | CASE(op_kind_t::_ne) |
1822 | CASE(op_kind_t::_gt) |
1823 | CASE(op_kind_t::_ge) |
1824 | CASE(op_kind_t::_lt) |
1825 | CASE(op_kind_t::_le) |
1826 | |
1827 | CASE(op_kind_t::_and) |
1828 | CASE(op_kind_t::_min) |
1829 | CASE(op_kind_t::_max) |
1830 | |
1831 | default: ir_error_not_expected(); |
1832 | |
1833 | #undef CASE |
1834 | } |
1835 | return expr_t(); |
1836 | } |
1837 | }; |
1838 | |
1839 | class const_folder_t : public ir_mutator_t { |
1840 | public: |
1841 | object_t _mutate(const binary_op_t &obj) override { |
1842 | return mutate_expr(obj); |
1843 | } |
1844 | object_t _mutate(const cast_t &obj) override { return mutate_expr(obj); } |
1845 | object_t _mutate(const iif_t &obj) override { return mutate_expr(obj); } |
1846 | object_t _mutate(const unary_op_t &obj) override { |
1847 | return mutate_expr(obj); |
1848 | } |
1849 | |
1850 | private: |
1851 | template <typename T> |
1852 | object_t mutate_expr(const T &obj) { |
1853 | auto new_obj = ir_mutator_t::_mutate(obj); |
1854 | return const_fold_non_recursive(new_obj); |
1855 | } |
1856 | }; |
1857 | |
1858 | bool is_const_or_shuffle_const(const expr_t &e) { |
1859 | return is_const(e) || is_shuffle_const(e); |
1860 | } |
1861 | |
1862 | expr_t const_fold_unary(op_kind_t op_kind, const expr_t &a) { |
1863 | ir_assert(op_kind == op_kind_t::_minus); |
1864 | if (!a.type().is_scalar()) { |
1865 | int elems = a.type().elems(); |
1866 | std::vector<expr_t> ret; |
1867 | for (int i = 0; i < elems; i++) { |
1868 | ret.push_back(const_fold_unary(op_kind, a[i])); |
1869 | } |
1870 | return shuffle_t::make(ret); |
1871 | } |
1872 | |
1873 | #define CASE(ir_type, cpp_type) \ |
1874 | if (a.type() == type_t::ir_type()) return to_expr(-to_cpp<cpp_type>(a)) |
1875 | |
1876 | CASE(f32, float); |
1877 | CASE(s16, int16_t); |
1878 | CASE(s32, int32_t); |
1879 | CASE(s64, int64_t); |
1880 | |
1881 | if (a.type().is_bool()) return to_expr(!to_cpp<bool>(a)); |
1882 | |
1883 | #undef CASE |
1884 | |
1885 | ir_error_not_expected() << "Cannot handle type: " << a; |
1886 | return expr_t(); |
1887 | } |
1888 | |
1889 | expr_t const_fold_binary(const type_t &compute_type, op_kind_t op_kind, |
1890 | const expr_t &a, const expr_t &b) { |
1891 | if (!compute_type.is_scalar()) { |
1892 | int elems = compute_type.elems(); |
1893 | auto scalar_type = compute_type.scalar(); |
1894 | std::vector<expr_t> ret; |
1895 | for (int i = 0; i < elems; i++) { |
1896 | ret.push_back(const_fold_binary(scalar_type, op_kind, a[i], b[i])); |
1897 | } |
1898 | return shuffle_t::make(ret); |
1899 | } |
1900 | |
1901 | if (compute_type.is_unsigned()) { |
1902 | auto a_s64 = to_cpp<int64_t>(a); |
1903 | auto b_s64 = to_cpp<int64_t>(b); |
1904 | ir_assert(a_s64 >= 0 && b_s64 >= 0) |
1905 | << "Overflow detected: fix data types." ; |
1906 | MAYBE_UNUSED(a_s64); |
1907 | MAYBE_UNUSED(b_s64); |
1908 | } |
1909 | |
1910 | #define CASE(ir_type, cpp_type) \ |
1911 | if (compute_type == type_t::ir_type()) { \ |
1912 | auto _a = to_cpp<cpp_type>(a); \ |
1913 | auto _b = to_cpp<cpp_type>(b); \ |
1914 | return const_fold_helper_t<cpp_type>::call(op_kind, _a, _b); \ |
1915 | } |
1916 | |
1917 | CASE(_bool, bool) |
1918 | CASE(f32, float) |
1919 | CASE(s16, int16_t) |
1920 | CASE(s32, int32_t) |
1921 | CASE(s64, int64_t) |
1922 | CASE(u16, uint16_t) |
1923 | CASE(u32, uint32_t) |
1924 | CASE(u64, uint64_t) |
1925 | |
1926 | #undef CASE |
1927 | |
1928 | ir_error_not_expected() << "Unknown type." ; |
1929 | return expr_t(); |
1930 | } |
1931 | |
1932 | object_t simplify(const object_t &obj, const constraint_set_t &cset) { |
1933 | if (obj.is_expr()) return simplify_expr(obj, cset); |
1934 | if (obj.is_stmt()) return simplify_stmt(obj, cset); |
1935 | ir_assert(obj.is_empty()); |
1936 | return object_t(); |
1937 | } |
1938 | |
1939 | expr_t simplify_cmp_move_const_to_rhs(const expr_t &e) { |
1940 | if (!is_binary_cmp_op(e)) return e; |
1941 | |
1942 | auto &op = e.as<binary_op_t>(); |
1943 | if (!is_const(op.b)) return e; |
1944 | if (!is_binary_op(op.a)) return e; |
1945 | |
1946 | auto &a_op = op.a.as<binary_op_t>(); |
1947 | |
1948 | bool is_lhs_add = (a_op.op_kind == op_kind_t::_add); |
1949 | bool is_lhs_sub = (a_op.op_kind == op_kind_t::_sub); |
1950 | if (!is_lhs_add && !is_lhs_sub) return e; |
1951 | |
1952 | auto &c1 = op.b; |
1953 | |
1954 | expr_t lhs; |
1955 | expr_t rhs; |
1956 | op_kind_t op_kind; |
1957 | if (is_const(a_op.a)) { |
1958 | auto &c0 = a_op.a; |
1959 | auto &x = a_op.b; |
1960 | if (is_lhs_add) { |
1961 | // ((c0 + x) op c1) -> (x op (c1 - c0)) |
1962 | lhs = x; |
1963 | rhs = c1 - c0; |
1964 | op_kind = op.op_kind; |
1965 | } else { |
1966 | // ((c0 - x) op c1) -> (x -op (c0 - c1)) |
1967 | lhs = x; |
1968 | rhs = c0 - c1; |
1969 | op_kind = negate_cmp_op(op.op_kind); |
1970 | } |
1971 | } else if (is_const(a_op.b)) { |
1972 | auto &x = a_op.a; |
1973 | auto &c0 = a_op.b; |
1974 | if (is_lhs_add) { |
1975 | // ((x + c0) op c1) -> (x op (c1 - c0)) |
1976 | lhs = x; |
1977 | rhs = c1 - c0; |
1978 | op_kind = op.op_kind; |
1979 | } else { |
1980 | // ((x - c0) op c1) -> (x op (c0 + c1)) |
1981 | lhs = x; |
1982 | rhs = c0 + c1; |
1983 | op_kind = op.op_kind; |
1984 | } |
1985 | } else { |
1986 | return e; |
1987 | } |
1988 | return binary_op_t::make(op_kind, lhs, rhs); |
1989 | } |
1990 | |
1991 | expr_t simplify_cmp_reduce_lhs_rhs(const expr_t &e) { |
1992 | if (!is_binary_cmp_op(e)) return e; |
1993 | |
1994 | auto &op = e.as<binary_op_t>(); |
1995 | |
1996 | // Rule: |
1997 | // (c0 * x op c1) or (x * c0 op c1) -> |
1998 | // x new_op (c1 / c0) if abs(c1) % abs(c0) == 0 |
1999 | // new_op == op or new_op == negate_cmp_op(op) |
2000 | expr_t c0; |
2001 | expr_t c1 = op.b; |
2002 | expr_t x; |
2003 | |
2004 | if (!is_const(c1)) return e; |
2005 | if (!is_binary_op(op.a, op_kind_t::_mul)) return e; |
2006 | |
2007 | auto &a_op = op.a.as<binary_op_t>(); |
2008 | if (is_const(a_op.a)) { |
2009 | c0 = a_op.a; |
2010 | x = a_op.b; |
2011 | } else if (is_const(a_op.b)) { |
2012 | x = a_op.a; |
2013 | c0 = a_op.b; |
2014 | } |
2015 | |
2016 | if (c0.is_empty()) return e; |
2017 | if (!c0.type().is_int()) return e; |
2018 | if (!c1.type().is_int()) return e; |
2019 | |
2020 | auto i_c0 = to_cpp<int64_t>(c0); |
2021 | auto i_c1 = to_cpp<int64_t>(c1); |
2022 | |
2023 | bool is_c0_neg = (i_c0 < 0); |
2024 | bool sign = ((i_c0 < 0) != (i_c1 < 0)); |
2025 | i_c0 = std::abs(i_c0); |
2026 | i_c1 = std::abs(i_c1); |
2027 | |
2028 | bool has_mod = (i_c1 % i_c0 != 0); |
2029 | if (has_mod && utils::one_of(op.op_kind, op_kind_t::_eq, op_kind_t::_ne)) |
2030 | return e; |
2031 | |
2032 | auto new_op_kind = (is_c0_neg ? negate_cmp_op(op.op_kind) : op.op_kind); |
2033 | int64_t div = i_c1 / i_c0; |
2034 | if (has_mod) { |
2035 | switch (new_op_kind) { |
2036 | case op_kind_t::_ge: |
2037 | case op_kind_t::_gt: |
2038 | new_op_kind = op_kind_t::_ge; |
2039 | div = (sign ? div : div + 1); |
2040 | break; |
2041 | case op_kind_t::_le: |
2042 | case op_kind_t::_lt: |
2043 | new_op_kind = op_kind_t::_le; |
2044 | div = (sign ? div + 1 : div); |
2045 | break; |
2046 | default: ir_error_not_expected(); |
2047 | } |
2048 | } |
2049 | |
2050 | return binary_op_t::make(new_op_kind, x, (sign ? -1 : 1) * div); |
2051 | } |
2052 | |
2053 | bool const_to_const_binary(const expr_t &e, op_kind_t op_kind, |
2054 | const type_t &a_type, const type_t &b_type, expr_t &a, expr_t &b) { |
2055 | bool is_true = to_cpp<bool>(e); |
2056 | // Assume: |
2057 | // - a0 < b1 |
2058 | // - a1 > b0 |
2059 | // - a_eq == b_eq |
2060 | expr_t a0 = to_expr(0, a_type); |
2061 | expr_t a1 = to_expr(1, a_type); |
2062 | expr_t b0 = to_expr(0, b_type); |
2063 | expr_t b1 = to_expr(1, b_type); |
2064 | expr_t a_eq = to_expr(0, a_type); |
2065 | expr_t b_eq = to_expr(0, b_type); |
2066 | if (!a.is_empty()) { |
2067 | a0 = a1 = a; |
2068 | b0 = a - 1; |
2069 | b1 = a + 1; |
2070 | a_eq = b_eq = a; |
2071 | } else if (!b.is_empty()) { |
2072 | b0 = b1 = b; |
2073 | a0 = b - 1; |
2074 | a1 = b + 1; |
2075 | a_eq = b_eq = b; |
2076 | } |
2077 | switch (op_kind) { |
2078 | case op_kind_t::_and: a = b = e; return true; |
2079 | case op_kind_t::_le: |
2080 | case op_kind_t::_lt: |
2081 | a = (is_true ? a0 : a1); |
2082 | b = (is_true ? b1 : b0); |
2083 | return true; |
2084 | case op_kind_t::_ge: |
2085 | case op_kind_t::_gt: |
2086 | a = (is_true ? a1 : a0); |
2087 | b = (is_true ? b0 : b1); |
2088 | return true; |
2089 | case op_kind_t::_eq: |
2090 | a = (is_true ? a_eq : a0); |
2091 | b = (is_true ? b_eq : b1); |
2092 | return true; |
2093 | case op_kind_t::_ne: |
2094 | a = (is_true ? a0 : a_eq); |
2095 | b = (is_true ? b1 : b_eq); |
2096 | return true; |
2097 | default: return false; |
2098 | } |
2099 | } |
2100 | |
2101 | expr_t simplify_propagate_shuffle(const expr_t &e) { |
2102 | if (!e.type().is_bool()) return e; |
2103 | |
2104 | auto *shuffle = e.as_ptr<shuffle_t>(); |
2105 | if (!shuffle) return e; |
2106 | |
2107 | // Handle binary operation. |
2108 | { |
2109 | type_t a_type; |
2110 | type_t b_type; |
2111 | expr_t a_common_const; |
2112 | expr_t b_common_const; |
2113 | op_kind_t op_kind = op_kind_t::undef; |
2114 | bool found_binary = false; |
2115 | for (int i : shuffle->idx) { |
2116 | if (is_binary_op(shuffle->vec[i])) { |
2117 | auto &op = shuffle->vec[i].as<binary_op_t>(); |
2118 | if (found_binary && op.op_kind != op_kind_t::_and) continue; |
2119 | found_binary = true; |
2120 | a_type = op.a.type(); |
2121 | b_type = op.b.type(); |
2122 | op_kind = op.op_kind; |
2123 | if (is_const(op.a)) a_common_const = op.a; |
2124 | if (is_const(op.b)) b_common_const = op.b; |
2125 | if (op_kind == op_kind_t::_and) break; |
2126 | } |
2127 | } |
2128 | if (!found_binary) return e; |
2129 | |
2130 | for (int i : shuffle->idx) { |
2131 | auto &elem = shuffle->vec[i]; |
2132 | if (is_binary_op(elem, op_kind)) { |
2133 | auto &op = elem.as<binary_op_t>(); |
2134 | if (!a_common_const.is_equal(op.a)) { |
2135 | a_common_const = expr_t(); |
2136 | } |
2137 | if (!b_common_const.is_equal(op.b)) { |
2138 | b_common_const = expr_t(); |
2139 | } |
2140 | } |
2141 | } |
2142 | |
2143 | bool ok = true; |
2144 | std::vector<expr_t> a; |
2145 | std::vector<expr_t> b; |
2146 | for (int i : shuffle->idx) { |
2147 | auto &elem = shuffle->vec[i]; |
2148 | if (is_binary_op(elem, op_kind)) { |
2149 | auto &op = elem.as<binary_op_t>(); |
2150 | a.push_back(op.a); |
2151 | b.push_back(op.b); |
2152 | } else if (is_const(elem)) { |
2153 | expr_t op_a = a_common_const; |
2154 | expr_t op_b = b_common_const; |
2155 | if (!const_to_const_binary( |
2156 | elem, op_kind, a_type, b_type, op_a, op_b)) { |
2157 | ok = false; |
2158 | break; |
2159 | } |
2160 | a.push_back(op_a); |
2161 | b.push_back(op_b); |
2162 | } else if (op_kind == op_kind_t::_and) { |
2163 | // Replace with expression true <op_kind> elem to allow matching |
2164 | // this op against future binary operation. |
2165 | a.push_back(bool_imm_t::make(true)); |
2166 | b.push_back(elem); |
2167 | } else { |
2168 | ok = false; |
2169 | break; |
2170 | } |
2171 | } |
2172 | if (ok) { |
2173 | auto _a = simplify_propagate_shuffle(shuffle_t::make(a)); |
2174 | auto _b = simplify_propagate_shuffle(shuffle_t::make(b)); |
2175 | return binary_op_t::make(op_kind, _a, _b); |
2176 | } |
2177 | } |
2178 | |
2179 | return e; |
2180 | } |
2181 | |
2182 | expr_t const_fold_non_recursive(const expr_t &e) { |
2183 | auto *unary_op = e.as_ptr<unary_op_t>(); |
2184 | if (unary_op) { |
2185 | auto &a = unary_op->a; |
2186 | if (!is_const_or_shuffle_const(a)) return e; |
2187 | return const_fold_unary(unary_op->op_kind, a); |
2188 | } |
2189 | |
2190 | auto *binary_op = e.as_ptr<binary_op_t>(); |
2191 | if (binary_op) { |
2192 | auto op_kind = binary_op->op_kind; |
2193 | auto &a = binary_op->a; |
2194 | auto &b = binary_op->b; |
2195 | if (!is_const_or_shuffle_const(a) || !is_const_or_shuffle_const(b)) |
2196 | return e; |
2197 | |
2198 | auto compute_type = common_type(a, b); |
2199 | return const_fold_binary(compute_type, op_kind, a, b); |
2200 | } |
2201 | |
2202 | auto *iif = e.as_ptr<iif_t>(); |
2203 | if (iif) { |
2204 | if (!is_const(iif->cond)) return e; |
2205 | if (to_cpp<bool>(iif->cond)) return iif->true_expr; |
2206 | return iif->false_expr; |
2207 | } |
2208 | |
2209 | auto *cast = e.as_ptr<cast_t>(); |
2210 | if (cast && !cast->saturate) { |
2211 | if (cast->expr.is<bool_imm_t>()) |
2212 | return to_expr(to_cpp<bool>(cast->expr), cast->type); |
2213 | if (cast->expr.is<int_imm_t>()) |
2214 | return to_expr(to_cpp<int64_t>(cast->expr), cast->type); |
2215 | if (cast->expr.is<float_imm_t>()) |
2216 | return to_expr(to_cpp<double>(cast->expr), cast->type); |
2217 | } |
2218 | |
2219 | return e; |
2220 | } |
2221 | |
2222 | object_t const_fold(const object_t &obj) { |
2223 | return const_folder_t().mutate(obj); |
2224 | } |
2225 | |
2226 | expr_t nary_op_back_transform(const expr_t &e) { |
2227 | // Convert nary_op_t back to binary_op_t. |
2228 | return nary_op_back_transformer_t().mutate(e); |
2229 | } |
2230 | |
2231 | expr_t nary_op_canonicalize(const expr_t &_e) { |
2232 | auto e = _e; |
2233 | |
2234 | e = nary_op_transformer_t().mutate(e); |
2235 | e = mul_nary_op_expander_t().mutate(e); |
2236 | |
2237 | ir_assert(is_nary_op_canonical(e)) << e; |
2238 | MAYBE_UNUSED(is_nary_op_canonical); |
2239 | |
2240 | return e; |
2241 | } |
2242 | |
2243 | expr_t make_nary_op(op_kind_t op_kind, const std::vector<expr_t> &args) { |
2244 | if (args.empty()) { |
2245 | if (op_kind == op_kind_t::_add) return 0; |
2246 | if (op_kind == op_kind_t::_mul) return 1; |
2247 | ir_error_not_expected() << to_string(op_kind); |
2248 | } |
2249 | if (args.size() == 1) return args[0]; |
2250 | |
2251 | // Do eager constant folding. |
2252 | std::vector<expr_t> new_args; |
2253 | fold_const_nary_op_args(op_kind, args, new_args); |
2254 | |
2255 | if (new_args.size() < args.size()) return make_nary_op(op_kind, new_args); |
2256 | |
2257 | return nary_op_t::make(op_kind, new_args); |
2258 | } |
2259 | |
2260 | std::vector<expr_t> cvt_expr_to_nary_op_args(const expr_t &e) { |
2261 | auto *nary = e.as_ptr<nary_op_t>(); |
2262 | if (nary) return nary->args; |
2263 | return {e}; |
2264 | } |
2265 | |
2266 | stmt_t simplify(const stmt_t &s, ir_context_t &ir_ctx) { |
2267 | trace_start(); |
2268 | auto ret = simplify(s, ir_ctx.cset()); |
2269 | trace_pass("simplify_pass" , ret, ir_ctx); |
2270 | return ret; |
2271 | } |
2272 | |
2273 | } // namespace jit |
2274 | } // namespace gpu |
2275 | } // namespace impl |
2276 | } // namespace dnnl |
2277 | |