1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/simplify.hpp"
18
19#include <algorithm>
20#include <iostream>
21#include <sstream>
22#include <string>
23#include <vector>
24
25#include "common/cpp_compat.hpp"
26#include "common/math_utils.hpp"
27#include "gpu/jit/ir/ir.hpp"
28#include "gpu/jit/utils/trace.hpp"
29#include "gpu/jit/utils/utils.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace gpu {
34namespace jit {
35
36using namespace ir_utils;
37
38// Generic pattern expression, used as a wild card during pattern matching. Can
39// match any expression.
40class pexpr_t : public expr_impl_t {
41public:
42 IR_DECL_EXPR_TYPE_ID(pexpr_t)
43
44 static expr_t make(int id) { return expr_t(new pexpr_t(id)); }
45
46 bool is_equal(const object_impl_t &obj) const override {
47 if (!obj.is<self_type>()) return false;
48 auto &other = obj.as<self_type>();
49
50 return id == other.id;
51 }
52
53 size_t get_hash() const override { return ir_utils::get_hash(id); }
54
55 std::string str() const override {
56 std::ostringstream oss;
57 oss << "pexpr_t(" << id << ")";
58 return oss.str();
59 }
60
61 static expr_t x() {
62 static thread_local expr_t x = pexpr_t::make(0);
63 return x;
64 }
65 static expr_t y() {
66 static thread_local expr_t y = pexpr_t::make(1);
67 return y;
68 }
69 static expr_t z() {
70 static thread_local expr_t z = pexpr_t::make(2);
71 return z;
72 }
73
74 IR_DECLARE_TRAVERSERS()
75
76 int id;
77
78private:
79 pexpr_t(int id) : expr_impl_t(_type_info(), type_t::undef()), id(id) {}
80};
81
82// Pattern expression for int_imm_t, used as a wild card during pattern
83// matching. Can match any int_imm_t with the given value.
84class pint_imm_t : public expr_impl_t {
85public:
86 IR_DECL_EXPR_TYPE_ID(pint_imm_t)
87
88 // Matches an integer constant with the given value.
89 static expr_t make(int64_t value) {
90 return expr_t(new pint_imm_t(-1, value));
91 }
92
93 static expr_t _0() {
94 static thread_local expr_t ret = pint_imm_t::make(0);
95 return ret;
96 }
97 static expr_t _1() {
98 static thread_local expr_t ret = pint_imm_t::make(1);
99 return ret;
100 }
101
102 // Matches any integer constant.
103 static expr_t make_any(int64_t id) { return expr_t(new pint_imm_t(id, 0)); }
104
105 bool matches(const int_imm_t &imm) const {
106 if (id == -1) return value == imm.value;
107 return true;
108 }
109
110 bool is_equal(const object_impl_t &obj) const override {
111 if (!obj.is<self_type>()) return false;
112 auto &other = obj.as<self_type>();
113
114 return value == other.value;
115 }
116
117 size_t get_hash() const override { return ir_utils::get_hash(value); }
118
119 std::string str() const override {
120 std::ostringstream oss;
121 oss << "pint_imm_t(" << value << ")";
122 return oss.str();
123 }
124
125 int id;
126 int64_t value;
127
128private:
129 pint_imm_t(int id, int64_t value)
130 : expr_impl_t(_type_info(), type_t::undef()), id(id), value(value) {}
131};
132
133// Stores already matched pairs of <pattern expression, matched expression>.
134class match_context_t {
135public:
136 bool contains(const expr_t &ptrn) const {
137 return expr_matched_.count(ptrn) != 0;
138 }
139
140 void set(const expr_t &ptrn, const expr_t &e) {
141 ir_assert(ptrn.is<pexpr_t>());
142 auto ret = expr_matched_.insert({ptrn, e});
143 ir_assert(ret.second);
144 MAYBE_UNUSED(ret);
145 }
146
147 const expr_t &operator[](const expr_t &ptrn) const {
148 return expr_matched_.at(ptrn);
149 }
150
151 template <typename T>
152 const T &at(const expr_t &ptrn) const {
153 return expr_matched_.at(ptrn).as<T>();
154 }
155
156 expr_t sub(const expr_t &expr) const;
157
158private:
159 object_eq_map_t<expr_t, expr_t> expr_matched_;
160};
161
162class pexpr_substitute_t : public ir_mutator_t {
163public:
164 using ir_mutator_t::_mutate;
165
166 pexpr_substitute_t(const match_context_t *ctx) : ctx_(ctx) {}
167
168 object_t _mutate(const pexpr_t &obj) override {
169 return (*ctx_)[expr_t(obj)];
170 }
171
172private:
173 const match_context_t *ctx_;
174};
175
176// Replaces occurrences of pattern expressions in `expr` according to the
177// context.
178expr_t match_context_t::sub(const expr_t &expr) const {
179 pexpr_substitute_t s(this);
180 return s.mutate(expr);
181}
182
// Returns true if the expression matches the pattern, false otherwise. Upon
// a successful match the context contains information about the matched
// pattern expressions.
186bool match(const expr_t &ptrn, const expr_t &expr, match_context_t &ctx);
187
188bool match_binary(
189 const expr_t &ptrn, const expr_t &expr, match_context_t &ctx) {
190 bool ptrn_is_binary = is_binary_op(ptrn);
191 bool expr_is_binary = is_binary_op(expr);
192
193 if (!ptrn_is_binary || !expr_is_binary) return false;
194
195 auto &ptrn_op = ptrn.as<binary_op_t>();
196 auto &expr_op = expr.as<binary_op_t>();
197 if (ptrn_op.op_kind != expr_op.op_kind) return false;
198
199 match_context_t ctx_copy = ctx;
200 if (match(ptrn_op.a, expr_op.a, ctx_copy)
201 && match(ptrn_op.b, expr_op.b, ctx_copy)) {
202 ctx = ctx_copy;
203 return true;
204 }
205 return false;
206}
207
208bool match_iif(const expr_t &ptrn, const expr_t &expr, match_context_t &ctx) {
209 bool ptrn_is_iif = ptrn.is<iif_t>();
210 bool expr_is_iif = expr.is<iif_t>();
211
212 if (!ptrn_is_iif || !expr_is_iif) return false;
213
214 auto &ptrn_iif = ptrn.as<iif_t>();
215 auto &expr_iif = expr.as<iif_t>();
216
217 match_context_t ctx_copy = ctx;
218 if (match(ptrn_iif.cond, expr_iif.cond, ctx_copy)
219 && match(ptrn_iif.true_expr, expr_iif.true_expr, ctx_copy)
220 && match(ptrn_iif.false_expr, expr_iif.false_expr, ctx_copy)) {
221 ctx = ctx_copy;
222 return true;
223 }
224
225 return false;
226}
227
228bool match(const expr_t &ptrn, const expr_t &expr, match_context_t &ctx) {
229 if (ptrn.is_equal(expr)) return true;
230
231 if (ptrn.is<pint_imm_t>()) {
232 auto &ptrn_imm = ptrn.as<pint_imm_t>();
233
234 bool ok = false;
235 if (expr.is<int_imm_t>()) {
236 ok = ptrn_imm.matches(expr.as<int_imm_t>());
237 } else if (ptrn_imm.id == -1 && expr.is<float_imm_t>()) {
238 ok = (to_cpp<float>(expr) == ptrn_imm.value);
239 }
240 return ok;
241 }
242
243 if (ptrn.is<pexpr_t>()) {
244 if (ctx.contains(ptrn)) {
245 if (!ctx[ptrn].is_equal(expr)) return false;
246 } else {
247 ctx.set(ptrn, expr);
248 }
249 return true;
250 }
251
252 if (match_binary(ptrn, expr, ctx)) return true;
253 if (match_iif(ptrn, expr, ctx)) return true;
254
255 return false;
256}
257
258// Rewrites expression `expr` according to `from` -> `to` rule.
259// Example:
260// auto x = pexpr_t::x();
261// auto c = rewrite(a + a, x + x, 2 * x);
262// // Now c is equal to (2 * a).
263expr_t rewrite(const expr_t &expr, const expr_t &from, const expr_t &to,
264 bool *rewritten = nullptr) {
265 match_context_t ctx;
266 if (match(from, expr, ctx)) {
267 if (rewritten) *rewritten = true;
268 return ctx.sub(to);
269 }
270 if (rewritten) *rewritten = false;
271 return expr;
272}
273
274expr_t rewrite_binary(const expr_t &expr, const expr_t &from, const expr_t &to,
275 bool *rewritten = nullptr) {
276 match_context_t ctx;
277 if (match_binary(from, expr, ctx)) {
278 if (rewritten) *rewritten = true;
279 return ctx.sub(to);
280 }
281 if (rewritten) *rewritten = false;
282 return expr;
283}
284
// Tries the rule `a` -> `b` on the local variable `e` via rewrite() and
// returns `e` from the enclosing function if the rule was applied. The
// pattern and replacement are built once per thread (static thread_local), so
// `a` and `b` must not depend on per-call values.
#define REWRITE(a, b) \
    do { \
        bool rewritten; \
        static thread_local auto _a = a; \
        static thread_local auto _b = b; \
        e = rewrite(e, _a, _b, &rewritten); \
        if (rewritten) return e; \
    } while (false)
293
// Same as REWRITE but uses rewrite_binary(), i.e. the rule is tried only as a
// binary-operation pattern. The pattern and replacement are built once per
// thread (static thread_local).
#define REWRITE_BINARY(a, b) \
    do { \
        bool rewritten; \
        static thread_local auto _a = a; \
        static thread_local auto _b = b; \
        e = rewrite_binary(e, _a, _b, &rewritten); \
        if (rewritten) return e; \
    } while (false)
302
// Variant of REWRITE_BINARY for rules whose pattern/replacement depend on
// per-call values (so they cannot be cached in static thread_local storage).
// Tries the binary rule `a` -> `b` on the local variable `e` and returns `e`
// from the enclosing function if the rule was applied. Uses rewrite_binary()
// to stay consistent with REWRITE_BINARY.
#define REWRITE_BINARY_NO_STATIC(a, b) \
    do { \
        bool rewritten; \
        e = rewrite_binary(e, a, b, &rewritten); \
        if (rewritten) return e; \
    } while (false)
309
310expr_t simplify_rewrite_add(const expr_t &_e) {
311 auto x = pexpr_t::x();
312 auto _0 = pint_imm_t::_0();
313
314 auto &obj = _e.as<binary_op_t>();
315 ir_assert(obj.op_kind == op_kind_t::_add);
316
317 auto e = _e;
318
319 REWRITE_BINARY(x + _0, x);
320 REWRITE_BINARY(_0 + x, x);
321 REWRITE_BINARY(x + x, 2 * x);
322
323 return _e;
324}
325
326expr_t simplify_rewrite_sub(const expr_t &_e) {
327 auto x = pexpr_t::x();
328 auto _0 = pint_imm_t::_0();
329
330 auto &obj = _e.as<binary_op_t>();
331 ir_assert(obj.op_kind == op_kind_t::_sub);
332
333 auto e = _e;
334
335 REWRITE_BINARY(x - _0, x);
336 REWRITE_BINARY(_0 - x, -x);
337 REWRITE_BINARY(x - x, 0);
338
339 return e;
340}
341
342expr_t simplify_rewrite_mul(const expr_t &_e) {
343 auto x = pexpr_t::x();
344 auto _0 = pint_imm_t::_0();
345 auto _1 = pint_imm_t::_1();
346
347 auto &obj = _e.as<binary_op_t>();
348 ir_assert(obj.op_kind == op_kind_t::_mul);
349
350 auto e = _e;
351
352 REWRITE_BINARY(x * _0, 0);
353 REWRITE_BINARY(_0 * x, 0);
354 REWRITE_BINARY(x * _1, x);
355 REWRITE_BINARY(_1 * x, x);
356
357 return e;
358}
359
360expr_t simplify_rewrite_div(const expr_t &_e) {
361 auto x = pexpr_t::x();
362 auto y = pexpr_t::y();
363 auto _0 = pint_imm_t::_0();
364 auto _1 = pint_imm_t::_1();
365
366 auto &obj = _e.as<binary_op_t>();
367 ir_assert(obj.op_kind == op_kind_t::_div);
368
369 auto e = _e;
370
371 REWRITE_BINARY(_0 / x, 0);
372 REWRITE_BINARY(x / _1, x);
373 REWRITE_BINARY(x / x, 1);
374
375 return e;
376}
377
378expr_t simplify_rewrite_mod(const expr_t &_e) {
379 auto x = pexpr_t::x();
380 auto _0 = pint_imm_t::_0();
381 auto _1 = pint_imm_t::_1();
382
383 auto &obj = _e.as<binary_op_t>();
384 ir_assert(obj.op_kind == op_kind_t::_mod);
385
386 auto e = _e;
387
388 REWRITE_BINARY(x % _1, 0);
389 REWRITE_BINARY(0 % x, 0);
390
391 return e;
392}
393
394expr_t simplify_rewrite_and(const expr_t &_e) {
395 auto x = pexpr_t::x();
396
397 auto &obj = _e.as<binary_op_t>();
398 ir_assert(obj.op_kind == op_kind_t::_and);
399
400 auto e = _e;
401
402 // Boolean rules.
403 if (e.type().is_bool()) {
404 auto _true = (e.type().is_scalar()
405 ? expr_t(true)
406 : shuffle_t::make_broadcast(
407 expr_t(true), e.type().elems()));
408 auto _false = (e.type().is_scalar()
409 ? expr_t(false)
410 : shuffle_t::make_broadcast(
411 expr_t(false), e.type().elems()));
412 REWRITE_BINARY_NO_STATIC(_true & x, x);
413 REWRITE_BINARY_NO_STATIC(x & _true, x);
414 REWRITE_BINARY_NO_STATIC(_false & x, _false);
415 REWRITE_BINARY_NO_STATIC(x & _false, _false);
416 }
417
418 return e;
419}
420
421expr_t simplify_rewrite_iif(const expr_t &_e) {
422 auto x = pexpr_t::x();
423 auto y = pexpr_t::y();
424
425 auto e = _e;
426
427 // Ternary operation rules.
428 REWRITE(iif_t::make(expr_t(true), x, y), x);
429 REWRITE(iif_t::make(expr_t(false), x, y), y);
430 REWRITE(iif_t::make(x, y, y), y);
431
432 return e;
433}
434
435expr_t simplify_try_ternary_rules(const expr_t &_e) {
436 auto x = pexpr_t::x();
437 auto y = pexpr_t::y();
438 auto z = pexpr_t::z();
439
440 auto e = _e;
441
442 // add3 rules.
443 REWRITE((x + y) + z, ternary_add3(x, y, z));
444 REWRITE(x + (y + z), ternary_add3(x, y, z));
445
446 // mad rules.
447 REWRITE(x + y * z, ternary_mad(x, y, z));
448 REWRITE(x - y * z, ternary_mad(x, -y, z));
449 REWRITE(y * z + x, ternary_mad(x, y, z));
450 REWRITE(y * z - x, ternary_mad(-x, y, z));
451
452 return e;
453}
454
// NOTE(review): the rewrite macros defined earlier in this file are REWRITE,
// REWRITE_BINARY and REWRITE_BINARY_NO_STATIC. REWRITE_NO_STATIC is not
// defined in the visible code, so the second #undef looks like a no-op and
// the two REWRITE_BINARY* macros stay defined - confirm whether this is
// intended.
#undef REWRITE
#undef REWRITE_NO_STATIC
457
458class term_rewrite_transformer_t : public ir_mutator_t {
459public:
460 object_t _mutate(const binary_op_t &obj) override {
461 auto e = ir_mutator_t::_mutate(obj);
462 switch (obj.op_kind) {
463 case op_kind_t::_add: return simplify_rewrite_add(e);
464 case op_kind_t::_sub: return simplify_rewrite_sub(e);
465 case op_kind_t::_mul: return simplify_rewrite_mul(e);
466 case op_kind_t::_div: return simplify_rewrite_div(e);
467 case op_kind_t::_mod: return simplify_rewrite_mod(e);
468 case op_kind_t::_and: return simplify_rewrite_and(e);
469 default: return e;
470 }
471 }
472 object_t _mutate(const iif_t &obj) override {
473 auto e = ir_mutator_t::_mutate(obj);
474 return simplify_rewrite_iif(e);
475 }
476};
477
478expr_t simplify_rewrite(const expr_t &e) {
479 expr_t ret;
480 if (is_const(e) || is_var(e)) {
481 ret = e;
482 } else {
483 term_rewrite_transformer_t trt;
484 ret = trt.mutate(e);
485 }
486 return ret;
487}
488
489class ternary_rewrite_transformer_t : public ir_mutator_t {
490public:
491 object_t _mutate(const binary_op_t &obj) override {
492 return mutate_expr(obj);
493 }
494 object_t _mutate(const iif_t &obj) override { return mutate_expr(obj); }
495
496 template <typename T>
497 expr_t mutate_expr(const T &obj) {
498 auto e_old = ir_mutator_t::_mutate(obj);
499 auto e = simplify_try_ternary_rules(e_old);
500 if (e.is_same(e_old)) return e_old;
501 return mutate(e);
502 }
503};
504
505expr_t simplify_rewrite_with_ternary(const expr_t &e, bool recursive) {
506 expr_t ret;
507 if (is_const(e) || is_var(e)) {
508 ret = e;
509 } else if (!recursive) {
510 ret = simplify_try_ternary_rules(e);
511 } else {
512 ternary_rewrite_transformer_t trt;
513 ret = trt.mutate(e);
514 }
515 return ret;
516}
517
518class cmp_simplifier_t : public ir_mutator_t {
519public:
520 object_t _mutate(const binary_op_t &obj) override {
521 auto e = ir_mutator_t::_mutate(obj);
522 if (!is_binary_cmp_op(e)) return e;
523
524 e = simplify_mod_comparison(e);
525
526 return e;
527 }
528
529 static expr_t reduce_lhs_rhs(const expr_t &e) {
530 if (!is_binary_cmp_op(e)) return e;
531
532 auto &op = e.as<binary_op_t>();
533
534 // Rule:
535 // (c0 * x op c1) or (x * c0 op c1) ->
536 // x new_op (c1 / c0) if abs(c1) % abs(c0) == 0
537 // new_op == op or new_op == negate_cmp_op(op)
538 expr_t c0;
539 expr_t c1 = op.b;
540 expr_t x;
541
542 if (!is_const(c1)) return e;
543 if (!is_binary_op(op.a, op_kind_t::_mul)) return e;
544
545 auto &a_op = op.a.as<binary_op_t>();
546 if (is_const(a_op.a)) {
547 c0 = a_op.a;
548 x = a_op.b;
549 } else if (is_const(a_op.b)) {
550 x = a_op.a;
551 c0 = a_op.b;
552 }
553
554 if (c0.is_empty()) return e;
555 if (!c0.type().is_int()) return e;
556 if (!c1.type().is_int()) return e;
557
558 auto i_c0 = to_cpp<int64_t>(c0);
559 auto i_c1 = to_cpp<int64_t>(c1);
560
561 bool is_c0_neg = (i_c0 < 0);
562 bool sign = ((i_c0 < 0) != (i_c1 < 0));
563 i_c0 = std::abs(i_c0);
564 i_c1 = std::abs(i_c1);
565
566 bool has_mod = (i_c1 % i_c0 != 0);
567 if (has_mod
568 && utils::one_of(op.op_kind, op_kind_t::_eq, op_kind_t::_ne))
569 return e;
570
571 auto new_op_kind = (is_c0_neg ? negate_cmp_op(op.op_kind) : op.op_kind);
572 int64_t div = i_c1 / i_c0;
573 if (has_mod) {
574 switch (new_op_kind) {
575 case op_kind_t::_ge:
576 case op_kind_t::_gt:
577 new_op_kind = op_kind_t::_ge;
578 div = (sign ? div : div + 1);
579 break;
580 case op_kind_t::_le:
581 case op_kind_t::_lt:
582 new_op_kind = op_kind_t::_le;
583 div = (sign ? div + 1 : div);
584 break;
585 default: ir_error_not_expected();
586 }
587 }
588
589 return binary_op_t::make(new_op_kind, x, (sign ? -1 : 1) * div);
590 }
591
592 static expr_t simplify_mod_comparison(const expr_t &e) {
593 if (!is_binary_cmp_op(e)) return e;
594
595 auto &op = e.as<binary_op_t>();
596
597 // Use the following inequalities:
598 // 0 <= (x % c0) < c0
599 if (!is_binary_op(op.a, op_kind_t::_mod)) return e;
600 if (!is_const(op.b)) return e;
601
602 auto &a_op = op.a.as<binary_op_t>();
603 if (!is_const(a_op.b)) return e;
604
605 auto &c0 = a_op.b;
606 ir_assert(to_cpp<int64_t>(c0) > 0) << e;
607
608 // Comparison against a constant is a continuous function, just check
609 // boundary points.
610 auto cond0 = binary_op_t::make(op.op_kind, 0, op.b);
611 auto cond1 = binary_op_t::make(op.op_kind, c0 - 1, op.b);
612
613 bool is_cond0 = to_cpp<bool>(const_fold_non_recursive(cond0));
614 bool is_cond1 = to_cpp<bool>(const_fold_non_recursive(cond1));
615
616 // Conditions are equal, can prove.
617 if (is_cond0 == is_cond1) return expr_t(is_cond0);
618
619 // Can't prove, return the original expression.
620 return e;
621 }
622};
623
624expr_t simplify_comparison(const expr_t &e) {
625 return cmp_simplifier_t().mutate(e);
626}
627
628class range_simplifier_t : public ir_mutator_t {
629public:
630 range_simplifier_t(const constraint_set_t &cset) : cset(cset) {}
631
632 object_t _mutate(const var_t &obj) override {
633 expr_t value;
634 if (cset.is_single_value(obj, value)) return std::move(value);
635 return obj;
636 }
637
638 const constraint_set_t &cset;
639};
640
641// Finds all constant operands on an N-ary operation, returns the folded
642// constant in `const_arg` and the remaining operands in `other_args`.
643void split_const_nary_op_arg(op_kind_t op_kind, const std::vector<expr_t> &args,
644 expr_t &const_arg, std::vector<expr_t> &other_args) {
645 other_args.resize(0);
646
647 const_arg = expr_t();
648 for (auto &a : args) {
649 if (is_const(a)) {
650 if (const_arg.is_empty()) {
651 const_arg = a;
652 continue;
653 }
654 const_arg = const_fold_non_recursive(
655 binary_op_t::make(op_kind, const_arg, a));
656 } else {
657 other_args.push_back(a);
658 }
659 }
660}
661
662// Folds all constant operands into one.
663void fold_const_nary_op_args(op_kind_t op_kind, const std::vector<expr_t> &args,
664 std::vector<expr_t> &new_args) {
665 expr_t c;
666 split_const_nary_op_arg(op_kind, args, c, new_args);
667 if (c.is_empty()) return;
668 if (op_kind == op_kind_t::_mul && is_zero(c)) {
669 new_args.clear();
670 new_args.push_back(c);
671 return;
672 }
673 if (op_kind == op_kind_t::_mul && is_one(c)) return;
674 if (op_kind == op_kind_t::_add && is_zero(c)) return;
675 new_args.push_back(c);
676}
677
678expr_t cvt_mul_to_nary_op(const expr_t &a, const expr_t &b) {
679 auto *a_nary = a.as_ptr<nary_op_t>();
680 auto *b_nary = b.as_ptr<nary_op_t>();
681
682 if (a_nary) ir_assert(a_nary->op_kind == op_kind_t::_mul);
683 if (b_nary) ir_assert(b_nary->op_kind == op_kind_t::_mul);
684
685 auto a_args = cvt_expr_to_nary_op_args(a);
686 auto b_args = cvt_expr_to_nary_op_args(b);
687
688 std::vector<expr_t> args;
689 args.insert(args.end(), a_args.begin(), a_args.end());
690 args.insert(args.end(), b_args.begin(), b_args.end());
691 return make_nary_op(op_kind_t::_mul, args);
692}
693
// IR visitor extended with support for nary_op_t: by default visits all
// operands of the N-ary operation. Derived classes may override _visit() for
// nary_op_t.
class nary_op_visitor_t : public ir_visitor_t {
public:
    using ir_visitor_t::_visit;

    virtual void _visit(const nary_op_t &obj) { visit(obj.args); }
};
700
701class nary_op_mutator_t : public ir_mutator_t {
702public:
703 using ir_mutator_t::_mutate;
704
705 virtual object_t _mutate(const nary_op_t &obj) {
706 auto args = mutate(obj.args);
707 if (ir_utils::is_equal(args, obj.args)) return obj;
708 return make_nary_op(obj.op_kind, args);
709 }
710};
711
712class nary_op_transformer_t : public nary_op_mutator_t {
713public:
714 using nary_op_mutator_t::_mutate;
715
716 object_t _mutate(const binary_op_t &obj) override {
717 // Skip vector types.
718 if (!obj.type.is_scalar()) return nary_op_mutator_t::_mutate(obj);
719 switch (obj.op_kind) {
720 case op_kind_t::_add:
721 case op_kind_t::_sub:
722 case op_kind_t::_mul: {
723 auto a = mutate(obj.a);
724 auto b = obj.b;
725 auto nary_op_kind = obj.op_kind;
726 if (obj.op_kind == op_kind_t::_sub) {
727 nary_op_kind = op_kind_t::_add;
728 b *= -1;
729 }
730 b = mutate(b);
731 return make_nary_op(nary_op_kind, {a, b});
732 }
733 default: return nary_op_mutator_t::_mutate(obj);
734 }
735 }
736};
737
738class nary_op_flattener_t : public nary_op_mutator_t {
739public:
740 object_t _mutate(const nary_op_t &obj) override {
741 std::vector<expr_t> args;
742 for (auto &a : obj.args) {
743 auto new_a = mutate(a);
744 auto *nary = new_a.as_ptr<nary_op_t>();
745 if (nary && nary->op_kind == obj.op_kind) {
746 args.insert(args.end(), nary->args.begin(), nary->args.end());
747 } else {
748 args.push_back(new_a);
749 }
750 }
751 return make_nary_op(obj.op_kind, args);
752 }
753};
754
755expr_t nary_op_flatten(const expr_t &e) {
756 return nary_op_flattener_t().mutate(e);
757}
758
759class mul_nary_op_expander_t : public nary_op_flattener_t {
760public:
761 object_t _mutate(const nary_op_t &obj) override {
762 auto flat_object = nary_op_flattener_t::_mutate(obj);
763 if (obj.op_kind != op_kind_t::_mul) { return flat_object; }
764
765 auto args = flat_object.as<nary_op_t>().args;
766 std::vector<expr_t> new_args;
767 for (size_t i = 0; i < args.size(); i++) {
768 auto *nary = args[i].as_ptr<nary_op_t>();
769 if (nary && nary->op_kind != op_kind_t::_add) {
770 ir_error_not_expected();
771 }
772 auto i_args = cvt_expr_to_nary_op_args(args[i]);
773 if (new_args.empty()) {
774 new_args = i_args;
775 continue;
776 }
777 std::vector<expr_t> next_args;
778 for (auto &a : new_args)
779 for (auto &b : i_args)
780 next_args.push_back(cvt_mul_to_nary_op(a, b));
781
782 new_args = next_args;
783 }
784 return make_nary_op(op_kind_t::_add, new_args);
785 }
786};
787
788class nary_op_canonical_verifier_t : public nary_op_visitor_t {
789public:
790 bool is_canonical() const { return is_canonical_; }
791
792 void _visit(const binary_op_t &obj) override {
793 // Skip vector types.
794 if (!obj.type.is_scalar()) {
795 visit_new_scope(obj);
796 return;
797 }
798 switch (obj.op_kind) {
799 // These operations must be converted to nary_op_t at this point.
800 case op_kind_t::_add:
801 case op_kind_t::_sub:
802 case op_kind_t::_mul: set_canonical_false(); break;
803 default: {
804 // Assume new scope here, n-ary operations from different
805 // scopes can't be merged.
806 visit_new_scope(obj);
807 break;
808 }
809 }
810 }
811
812 void _visit(const iif_t &obj) override { visit_new_scope(obj); }
813
814 void _visit(const load_t &obj) override { visit_new_scope(obj); }
815
816 void _visit(const ptr_t &obj) override { visit_new_scope(obj); }
817
818 void _visit(const unary_op_t &obj) override { visit_new_scope(obj); }
819
820 void _visit(const nary_op_t &obj) override {
821 if (parent_nary_) {
822 if (!(parent_nary_->op_kind == op_kind_t::_add
823 && obj.op_kind == op_kind_t::_mul)) {
824 // Multiplications must be expanded at this point.
825 set_canonical_false();
826 return;
827 }
828 }
829
830 auto *old_parent_nary = parent_nary_;
831 parent_nary_ = &obj;
832 visit(obj.args);
833 parent_nary_ = old_parent_nary;
834 }
835
836private:
837 void set_canonical_false() { is_canonical_ = false; }
838
839 template <typename T>
840 void visit_new_scope(const T &obj) {
841 auto *old_parent_nary = parent_nary_;
842 parent_nary_ = nullptr;
843 nary_op_visitor_t::_visit(obj);
844 parent_nary_ = old_parent_nary;
845 }
846
847 bool is_canonical_ = true;
848 const nary_op_t *parent_nary_ = nullptr;
849};
850
851// Checks if the expression is in the canonical N-ary form.
852bool is_nary_op_canonical(const expr_t &e) {
853 nary_op_canonical_verifier_t v;
854 v.visit(e);
855 return v.is_canonical();
856}
857
858class nary_op_back_transformer_t : public nary_op_mutator_t {
859public:
860 object_t _mutate(const nary_op_t &obj) {
861 auto new_obj = nary_op_mutator_t::_mutate(obj);
862 auto &nary = new_obj.as<nary_op_t>();
863 ir_assert(nary.args.size() > 0) << new_obj;
864
865 if (nary.args.size() == 1) return nary.args[0];
866
867 if (nary.op_kind == op_kind_t::_add) {
868 expr_t ret = nary.args[0] + nary.args[1];
869 for (size_t i = 2; i < nary.args.size(); i++)
870 ret += nary.args[i];
871 return std::move(ret);
872 } else if (nary.op_kind == op_kind_t::_mul) {
873 expr_t ret = nary.args[0] * nary.args[1];
874 for (size_t i = 2; i < nary.args.size(); i++)
875 ret *= nary.args[i];
876 return std::move(ret);
877 }
878 ir_error_not_expected();
879 return expr_t();
880 }
881};
882
// Stores factorization of an expression in the canonical (normalized) form:
//     expr = (f(0), f(1), f(2), ... f(n))
// f(0), ... f(n-1) are non-constant expressions, f(n) is a constant.
// The constant factor is always present (init_normalize() appends it even
// when it is 1), so `factors` is never empty for a constructed object.
class factored_expr_t : public expr_impl_t {
public:
    IR_DECL_EXPR_TYPE_ID(factored_expr_t);

    // Factorizes the expression `e`.
    static expr_t make(const expr_t &e) {
        return expr_t(new factored_expr_t(e));
    }

    // Builds a factored expression of the given type from a list of factors;
    // the list is normalized (constants are folded into one trailing factor).
    static expr_t make(const type_t &type, const std::vector<expr_t> &factors) {
        return expr_t(new factored_expr_t(type, factors));
    }

    // Two factorizations are equal when they have the same number of factors,
    // the same constant factor and the same non-constant factors (compared
    // through the intersection, so the order does not matter).
    bool is_equal(const object_impl_t &obj) const override {
        if (!obj.is<self_type>()) return false;
        auto &other = obj.as<self_type>();

        if (factors.size() != other.factors.size()) return false;
        if (!factors.back().is_equal(other.factors.back())) return false;

        auto common = intersect(obj);
        auto &f_common = common.as<factored_expr_t>();
        return f_common.factors.size() == factors.size();
    }

    // Constant factor is ignored during comparison.
    bool is_equal_ignore_const(const object_impl_t &obj) const {
        if (!obj.is<self_type>()) return false;
        auto &other = obj.as<self_type>();

        if (factors.size() != other.factors.size()) return false;

        auto common = intersect_ignore_const(obj);
        auto &f_common = common.as<factored_expr_t>();
        return f_common.factors.size() == factors.size();
    }

    size_t get_hash() const override { return ir_utils::get_hash(factors); }

    // Prints the factors as "f(a x b x c)".
    std::string str() const override {
        std::ostringstream oss;
        oss << "f(";
        for (size_t i = 0; i < factors.size(); i++) {
            oss << (i != 0 ? " x " : "") << factors[i];
        }

        if (factors.empty()) oss << "1";

        oss << ")";
        return oss.str();
    }

    // Returns the factorization as an N-ary multiplication. A trailing
    // constant factor equal to one is dropped unless it is the only factor.
    expr_t expr() const {
        if (factors.size() > 1 && jit::is_one(factors.back())) {
            std::vector<expr_t> f(factors.begin(), factors.end() - 1);
            return make_nary_op(op_kind_t::_mul, f);
        }
        return make_nary_op(op_kind_t::_mul, factors);
    }

    // Returns the (trailing) constant factor.
    expr_t const_factor() const { return factors.back(); }

    // True if the factorization consists of the single constant factor 1.
    bool is_one() const {
        return (factors.size() == 1) && jit::is_one(factors[0]);
    }

    // True if the factorization consists of the constant factor only.
    bool is_const() const { return factors.size() == 1; }

    // Returns multiplication of this and other as factored_expr_t.
    expr_t merge(const expr_t &other) const {
        auto &f_other = other.as<factored_expr_t>();
        std::vector<expr_t> merged_factors(factors.begin(), factors.end());
        merged_factors.insert(merged_factors.end(), f_other.factors.begin(),
                f_other.factors.end());
        return factored_expr_t::make(type, merged_factors);
    }

    // Returns common factors of this and other as factored_expr_t.
    expr_t intersect(const expr_t &other) const {
        return intersect_impl(other, false);
    }

    // Returns common factors of this and other as factored_expr_t (ignores
    // constant factors).
    expr_t intersect_ignore_const(const expr_t &other) const {
        return intersect_impl(other, true);
    }

    // Returns factors of this not presented in other as factored_expr_t.
    expr_t diff(const expr_t &_other) const {
        auto &other = _other.as<factored_expr_t>();
        // Multiplicity of every non-constant factor of this.
        object_eq_map_t<expr_t, int> f_map;
        // Skip constant factor.
        for (size_t i = 0; i < factors.size() - 1; i++)
            f_map[factors[i]]++;

        // Cancel one occurrence for every matching factor of other. Note
        // that f_map[e] may add zero-count entries for factors found only in
        // other; they contribute nothing to the result below.
        for (auto &e : other.factors) {
            if (f_map[e] > 0) f_map[e]--;
        }
        std::vector<expr_t> diff_factors;
        for (auto &kv : f_map) {
            for (int i = 0; i < kv.second; i++)
                diff_factors.push_back(kv.first);
        }
        // Handle constant factor.
        int64_t a_const = to_cpp<int64_t>(factors.back());
        int64_t b_const = to_cpp<int64_t>(other.factors.back());
        if (a_const != 0 && b_const != 0) {
            // Divide this' constant by the GCD of both constants; the GCD is
            // taken as negative only when both constants are negative.
            int64_t ab_gcd = ((a_const < 0) && (b_const < 0)) ? -1 : 1;
            ab_gcd *= math::gcd(std::abs(a_const), std::abs(b_const));
            diff_factors.push_back(to_expr(a_const / ab_gcd, type));
        } else if (a_const != 0 || b_const != 0) {
            // Exactly one constant is zero: keep this' constant as is.
            diff_factors.push_back(to_expr(a_const, type));
        }

        return factored_expr_t::make(type, diff_factors);
    }

    // Returns factors of this reduced by factors of other as factored_expr_t.
    // This object must be reducible by other.
    expr_t reduce(const expr_t &other) const {
        auto &f_other = other.as<factored_expr_t>();
        auto f_common = intersect(other);
        auto diff_other = f_other.diff(f_common);
        // Other must be reducible.
        ir_assert(diff_other.as<factored_expr_t>().is_one()) << diff_other;
        return diff(f_common);
    }

    // Returns true if this can be reduced by other.
    bool is_reducible(const expr_t &other) const {
        auto f_common = intersect(other);
        return f_common.is_equal(other);
    }

    // Divides both `a` and `b` (updated in place) by their common factors and
    // returns the common factors as factored_expr_t.
    static expr_t reduce(expr_t &a, expr_t &b) {
        auto fa_expr = factored_expr_t::make(a);
        auto fb_expr = factored_expr_t::make(b);
        auto &fa = fa_expr.as<factored_expr_t>();
        auto &fb = fb_expr.as<factored_expr_t>();
        auto f_common = fa.intersect(&fb);
        a = fa.reduce(f_common).as<factored_expr_t>().expr();
        b = fb.reduce(f_common).as<factored_expr_t>().expr();
        return f_common;
    }

    // Non-constant factors followed by the single constant factor.
    std::vector<expr_t> factors;

private:
    factored_expr_t(const expr_t &e) : expr_impl_t(_type_info(), e.type()) {
        init_factors(e);
    }

    factored_expr_t(const type_t &type, const std::vector<expr_t> &factors)
        : expr_impl_t(_type_info(), type) {
        init_normalize(factors);
    }

    // Appends the non-constant entries of `f` to `factors` (keeping their
    // order) followed by one constant factor: the product of the absolute
    // values of all constants in `f`, negated when an odd number of them is
    // negative. The constant is appended even when it is 1.
    void init_normalize(const std::vector<expr_t> &f) {
        bool sign = false;
        expr_t e_const = to_expr(1);
        for (auto &e : f) {
            if (!jit::is_const(e)) {
                factors.push_back(e);
                continue;
            }
            if (to_cpp<int64_t>(e) < 0) sign = !sign;
            // +1/-1 contribute their sign only.
            if (jit::is_one(e) || jit::is_minus_one(e)) continue;

            e_const = e_const * abs(e);
        }
        if (sign) e_const = -e_const;
        factors.push_back(e_const);
    }

    // Factorizes `e`:
    // - unary minus: factors of the operand with the constant factor negated;
    // - N-ary multiplication: merged factors of all operands;
    // - N-ary addition: common factors of all operands times the sum of the
    //   reduced operands (the sum is kept as a single factor when there are
    //   no common factors);
    // - anything else: a single factor.
    void init_factors(const expr_t &e) {
        auto *nary = e.as_ptr<nary_op_t>();
        if (!nary) {
            auto *unary = e.as_ptr<unary_op_t>();
            if (unary && unary->op_kind == op_kind_t::_minus) {
                init_factors(unary->a);
                factors.back() *= -1;
                return;
            }
            init_normalize({e});
            return;
        }

        if (nary->op_kind == op_kind_t::_mul) {
            // Start from the factorization of 1 and merge in every operand.
            expr_t f_mul = factored_expr_t::make(to_expr(1));
            for (auto &a : nary->args) {
                f_mul = f_mul.as<factored_expr_t>().merge(
                        factored_expr_t::make(a));
            }
            factors = f_mul.as<factored_expr_t>().factors;
            return;
        }

        if (nary->op_kind == op_kind_t::_add) {
            // Factors common to all summands.
            expr_t common;
            for (auto &a : nary->args) {
                if (common.is_empty()) {
                    common = factored_expr_t::make(a);
                    continue;
                }
                common = common.as<factored_expr_t>().intersect(
                        factored_expr_t::make(a));
            }
            if (common.as<factored_expr_t>().is_one()) {
                // Nothing in common: the whole sum is one factor.
                init_normalize({e});
                return;
            }
            // Divide every summand by the common factors.
            std::vector<expr_t> rest_factors;
            for (auto &a : nary->args) {
                auto fa_expr = factored_expr_t::make(a);
                auto &fa = fa_expr.as<factored_expr_t>();
                rest_factors.push_back(
                        fa.reduce(common).as<factored_expr_t>().expr());
            }
            auto &f_common = common.as<factored_expr_t>();
            auto rest = factored_expr_t::make(
                    make_nary_op(op_kind_t::_add, rest_factors));
            factors = f_common.merge(rest).as<factored_expr_t>().factors;
            return;
        }
        ir_error_not_expected();
    }

    // Returns the factors common to this and other (with multiplicity) as
    // factored_expr_t. Unless `ignore_constants` is set, the constant factors
    // are intersected as well (see below).
    expr_t intersect_impl(const expr_t &other, bool ignore_constants) const {
        auto &f_other = other.as<factored_expr_t>();
        // Multiplicity of every non-constant factor of this.
        object_eq_map_t<expr_t, int> f_map;
        // Skip constant factor.
        for (size_t i = 0; i < factors.size() - 1; i++)
            f_map[factors[i]]++;

        std::vector<expr_t> common_factors;
        for (auto &e : f_other.factors) {
            auto it = f_map.find(e);
            if (it == f_map.end() || it->second == 0) continue;
            f_map[e]--;
            common_factors.push_back(e);
        }

        if (ignore_constants)
            return factored_expr_t::make(type, common_factors);

        // Handle constant factor.
        int64_t a_const = to_cpp<int64_t>(factors.back());
        int64_t b_const = to_cpp<int64_t>(f_other.factors.back());
        if (a_const != 0 && b_const != 0) {
            // The common constant is the GCD of both constants, negative only
            // when both constants are negative. A GCD of 1 is not added
            // explicitly: normalization supplies the constant 1 anyway.
            int64_t ab_gcd = ((a_const < 0) && (b_const < 0)) ? -1 : 1;
            ab_gcd *= math::gcd(std::abs(a_const), std::abs(b_const));
            if (ab_gcd != 1) common_factors.push_back(to_expr(ab_gcd, type));
        } else if (a_const == 0 && b_const == 0) {
            common_factors.push_back(to_expr(0, type));
        }

        return factored_expr_t::make(type, common_factors);
    }
};
1145
1146class division_reducer_t : public nary_op_mutator_t {
1147public:
1148 using nary_op_mutator_t::_mutate;
1149
1150 object_t _mutate(const binary_op_t &obj) override {
1151 if (obj.op_kind != op_kind_t::_div)
1152 return nary_op_mutator_t::_mutate(obj);
1153
1154 expr_t a = mutate(obj.a);
1155 expr_t b = mutate(obj.b);
1156
1157 factored_expr_t::reduce(a, b);
1158
1159 if (is_one(b)) return std::move(a);
1160
1161 return binary_op_t::make(op_kind_t::_div, a, b);
1162 }
1163};
1164
1165bool is_divisible(
1166 const expr_t &a, const expr_t &b, const constraint_set_t &cset) {
1167 if (cset.can_prove(a % b == 0, /*try_simplify=*/false)) return true;
1168
1169 // Try to find b in factors of a.
1170 auto fa = factored_expr_t::make(a);
1171 auto fb = factored_expr_t::make(b);
1172 return fa.as<factored_expr_t>().is_reducible(fb);
1173}
1174
1175class int_div_mod_expander_t : public nary_op_mutator_t {
1176public:
1177 using nary_op_mutator_t::_mutate;
1178
1179 int_div_mod_expander_t(const constraint_set_t &cset) : cset(cset) {}
1180
1181 object_t _mutate(const binary_op_t &_obj) override {
1182 auto obj = nary_op_mutator_t::_mutate(_obj);
1183 auto *binary_op = obj.as_ptr<binary_op_t>();
1184 if (!binary_op) return obj;
1185 if (!utils::one_of(
1186 binary_op->op_kind, op_kind_t::_div, op_kind_t::_mod))
1187 return obj;
1188 if (!binary_op->type.is_int()) return obj;
1189
1190 auto a = binary_op->a;
1191 auto b = binary_op->b;
1192
1193 auto _b = nary_op_back_transform(b);
1194 if (!cset.can_prove(_b > 0)) return obj;
1195
1196 auto *a_nary = a.as_ptr<nary_op_t>();
1197
1198 if (a_nary && a_nary->op_kind == op_kind_t::_add)
1199 return mutate_with_add(*binary_op);
1200
1201 // Try to reduce a and b.
1202 auto common_factor = factored_expr_t::reduce(a, b);
1203
1204 if (is_one(b)) {
1205 if (binary_op->op_kind == op_kind_t::_mod)
1206 return to_expr(0, binary_op->type);
1207 if (binary_op->op_kind == op_kind_t::_div) return std::move(a);
1208 }
1209
1210 if (binary_op->op_kind == op_kind_t::_div) {
1211 return a / b;
1212 } else if (binary_op->op_kind == op_kind_t::_mod) {
1213 auto &c = common_factor.as<factored_expr_t>();
1214 if (c.is_const() && to_cpp<int64_t>(c.const_factor()) > 1)
1215 return make_nary_op(op_kind_t::_mul, {c.const_factor(), a % b});
1216 }
1217
1218 return obj;
1219 }
1220
1221 expr_t mutate_with_add(const binary_op_t &obj) {
1222 expr_t e = obj;
1223 if (reduce_v1(e)) return e;
1224 if (reduce_v2(e)) return e;
1225 return e;
1226 }
1227
1228 // Applies the following rules:
1229 // 1) (A + B) % C -> B % C, when
1230 // - A % C == 0
1231 // - B >= 0
1232 // 2) (A + B) / C -> (A / C) + (B / C), when
1233 // - A % C == 0
1234 // - B >= 0
1235 bool reduce_v1(expr_t &expr) {
1236 auto *binary_op = expr.as_ptr<binary_op_t>();
1237 if (!binary_op) return false;
1238
1239 auto op_kind = binary_op->op_kind;
1240 auto &a = binary_op->a;
1241 auto &b = binary_op->b;
1242
1243 std::vector<expr_t> lhs_args; // Reducible summands.
1244 std::vector<expr_t> rhs_args; // Non-reducible summands.
1245
1246 auto *a_nary = a.as_ptr<nary_op_t>();
1247 for (auto &e : a_nary->args) {
1248 if (is_div_reducible(e, b)) {
1249 lhs_args.push_back(e);
1250 } else {
1251 rhs_args.push_back(e);
1252 }
1253 }
1254
1255 // Nothing to reduce, return expression as is.
1256 if (lhs_args.empty()) return false;
1257
1258 auto rhs_nary = make_nary_op(op_kind_t::_add, rhs_args);
1259 auto _rhs = nary_op_back_transform(rhs_nary);
1260 bool rhs_ge_0 = cset.can_prove(_rhs >= 0);
1261
1262 if (op_kind == op_kind_t::_mod) {
1263 if (rhs_args.empty()) {
1264 expr = to_expr(0, expr.type());
1265 return true;
1266 }
1267 if (!rhs_ge_0) return false;
1268 expr = rhs_nary % b;
1269 return true;
1270 }
1271
1272 if (op_kind == op_kind_t::_div) {
1273 if (!rhs_ge_0) return false;
1274 if (rhs_args.empty()) {
1275 expr = mutate(lhs_args[0] / b);
1276 for (int i = 1; i < int(lhs_args.size()); i++) {
1277 expr += mutate(lhs_args[i] / b);
1278 }
1279 return true;
1280 }
1281 auto lhs_div = make_nary_op(op_kind_t::_add, lhs_args) / b;
1282 auto rhs_div = rhs_nary / b;
1283 expr = mutate(lhs_div) + mutate(rhs_div);
1284 return true;
1285 }
1286
1287 ir_error_not_expected() << expr;
1288
1289 return false;
1290 }
1291
1292 // Applies the following rules:
1293 // 1) (A * B + D) / (A * C) -> (A * B) / (A * C), when
1294 // - A > 0
1295 // - C > 0
1296 // - 0 <= D < A
1297 // 2) (A * B + D) % (A * C) -> (A * B) % (A * C) + D % (A * C), when
1298 // - A > 0
1299 // - C > 0
1300 // - 0 <= D < A
1301 bool reduce_v2(expr_t &expr) {
1302 auto *binary_op = expr.as_ptr<binary_op_t>();
1303 if (!binary_op) return false;
1304
1305 auto op_kind = binary_op->op_kind;
1306 auto &a = binary_op->a;
1307 auto &b = binary_op->b;
1308 if (!is_const(b)) return false;
1309
1310 auto const_factor = [&](const expr_t &e) {
1311 auto _fe = factored_expr_t::make(e);
1312 auto &fe = _fe.as<factored_expr_t>();
1313 auto ret = to_cpp<int64_t>(fe.const_factor());
1314 for (auto &f : fe.factors)
1315 if (is_var(f)) ret *= cset.max_proven_gcd(f);
1316 return ret;
1317 };
1318
1319 // TODO: Check 0.
1320 // Find max constant GCD.
1321 int64_t b_gcd = const_factor(b);
1322 int64_t max_gcd = 0;
1323 auto *a_nary = a.as_ptr<nary_op_t>();
1324 for (auto &e : a_nary->args) {
1325 int64_t gcd = math::gcd(b_gcd, const_factor(e));
1326 if (gcd > max_gcd) max_gcd = gcd;
1327 }
1328
1329 if (max_gcd == 0) return false;
1330
1331 std::vector<expr_t> lhs_args; // Reducible summands.
1332 std::vector<expr_t> rhs_args; // Non-reducible summands.
1333 for (auto &e : a_nary->args) {
1334 if (is_div_reducible(e, max_gcd)) {
1335 lhs_args.push_back(e);
1336 } else {
1337 rhs_args.push_back(e);
1338 }
1339 }
1340
1341 // max_gcd is the GCD for some summand so at least one summand must be
1342 // reducible.
1343 ir_assert(!lhs_args.empty());
1344
1345 if (rhs_args.empty()) return false;
1346
1347 int64_t A = max_gcd;
1348 int64_t C = to_cpp<int64_t>(b) / A;
1349 if (A <= 0 || C <= 0) return false;
1350
1351 auto rhs_nary = make_nary_op(op_kind_t::_add, rhs_args);
1352 auto D = nary_op_back_transform(rhs_nary);
1353 if (!cset.can_prove(D >= 0) || !cset.can_prove(D < A)) return false;
1354
1355 if (op_kind == op_kind_t::_mod) {
1356 auto lhs_mod = make_nary_op(op_kind_t::_add, lhs_args) % b;
1357 auto rhs_mod = rhs_nary % b;
1358 expr = mutate(lhs_mod) + mutate(rhs_mod);
1359 return true;
1360 }
1361
1362 if (op_kind == op_kind_t::_div) {
1363 auto lhs_div = make_nary_op(op_kind_t::_add, lhs_args) / b;
1364 expr = lhs_div;
1365 return true;
1366 }
1367
1368 ir_error_not_expected() << expr;
1369
1370 return false;
1371 }
1372
1373 bool is_div_reducible(const expr_t &a, const expr_t &b) const {
1374 if (is_const(a) && is_const(b)) {
1375 return to_cpp<int64_t>(a) % to_cpp<int64_t>(b) == 0;
1376 }
1377
1378 if (b.is_equal(to_expr(1, b.type()))) return true;
1379
1380 return is_divisible(a, b, cset);
1381 }
1382
1383 const constraint_set_t &cset;
1384};
1385
1386class int_div_mod_range_simplifier_t : public nary_op_mutator_t {
1387public:
1388 using nary_op_mutator_t::_mutate;
1389
1390 int_div_mod_range_simplifier_t(const constraint_set_t &cset) : cset(cset) {}
1391
1392 object_t _mutate(const binary_op_t &obj) override {
1393 if (!utils::one_of(obj.op_kind, op_kind_t::_div, op_kind_t::_mod))
1394 return nary_op_mutator_t::_mutate(obj);
1395
1396 auto a = mutate(obj.a);
1397 auto b = mutate(obj.b);
1398
1399 auto _a = nary_op_back_transform(a);
1400 auto _b = nary_op_back_transform(b);
1401
1402 // 0 <= a < b => (a / b) == 0
1403 bool abs_a_lt_b = cset.can_prove(_a >= 0) && cset.can_prove(_a < _b);
1404
1405 // 0 <= a < b => (a % b) == a
1406 if (abs_a_lt_b) {
1407 if (obj.op_kind == op_kind_t::_div) return to_expr(0);
1408 if (obj.op_kind == op_kind_t::_mod) return a;
1409 }
1410
1411 return binary_op_t::make(obj.op_kind, a, b);
1412 }
1413
1414 const constraint_set_t &cset;
1415};
1416
1417// Factors out common factors in an N-ary expression.
1418class common_factor_simplifier_t : public nary_op_mutator_t {
1419public:
1420 object_t _mutate(const nary_op_t &obj) override {
1421 if (obj.op_kind != op_kind_t::_add)
1422 return nary_op_mutator_t::_mutate(obj);
1423
1424 auto args = mutate(obj.args);
1425 for (auto &a : args) {
1426 auto *nary = a.as_ptr<nary_op_t>();
1427 if (nary) ir_assert(nary->op_kind == op_kind_t::_mul) << a;
1428 }
1429
1430 // Fold same factors (find exact match, ignore constants).
1431 // Example:
1432 // (a * c1 + a * c2 + b) ->
1433 // (a * c3 + b) where c3 = (c1 + c2)
1434 for (size_t i = 0; i < args.size(); i++) {
1435 auto e_fi = factored_expr_t::make(args[i]);
1436 for (size_t j = i + 1; j < args.size(); j++) {
1437 auto e_fj = factored_expr_t::make(args[j]);
1438
1439 auto &fi = e_fi.as<factored_expr_t>();
1440 auto &fj = e_fj.as<factored_expr_t>();
1441
1442 auto e_fij_common = fi.intersect_ignore_const(e_fj);
1443 auto &fij_common = e_fij_common.as<factored_expr_t>();
1444 if (fi.is_equal_ignore_const(fij_common)
1445 && fj.is_equal_ignore_const(fij_common)) {
1446 auto new_args = fij_common.factors;
1447 new_args.push_back(fi.const_factor() + fj.const_factor());
1448 args[i] = make_nary_op(op_kind_t::_mul, new_args);
1449 e_fi = factored_expr_t::make(args[i]);
1450 args[j] = to_expr(0, args[j].type());
1451 }
1452 }
1453 }
1454
1455 // Partial folding (fold any match).
1456 // Example:
1457 // (a * b * c + a * b * d + e) ->
1458 // ((a * b * (c + d)) + e)
1459 for (size_t i = 0; i < args.size(); i++) {
1460 if (is_zero(args[i])) continue;
1461 auto e_fi = factored_expr_t::make(args[i]);
1462 for (size_t j = i + 1; j < args.size(); j++) {
1463 if (is_zero(args[j])) continue;
1464 auto e_fj = factored_expr_t::make(args[j]);
1465
1466 auto &fi = e_fi.as<factored_expr_t>();
1467 auto &fj = e_fj.as<factored_expr_t>();
1468
1469 auto e_fij_common = fi.intersect_ignore_const(e_fj);
1470 auto &fij_common = e_fij_common.as<factored_expr_t>();
1471
1472 // fij_common = 1 means no common factors, other constant
1473 // factors are also ignored, for simplicity (though it might be
1474 // beneficial to fold them as well).
1475 if (fij_common.is_const()) continue;
1476
1477 // factored_expr_t::make() will find common factors.
1478 auto e_fi_add_fj = factored_expr_t::make(
1479 make_nary_op(op_kind_t::_add, {fi.expr(), fj.expr()}));
1480 auto &fi_add_fj = e_fi_add_fj.as<factored_expr_t>();
1481 args[i] = make_nary_op(op_kind_t::_mul, fi_add_fj.factors);
1482 e_fi = e_fi_add_fj;
1483 args[j] = to_expr(0, args[j].type());
1484 }
1485 }
1486
1487 return make_nary_op(obj.op_kind, args);
1488 }
1489};
1490
1491// Rewrites addition with mixed 64-bit/32-bit expressions to reduce 64-bit
1492// arithmetic. Example:
1493// Before: ((x.s64 + y.s32) + z.s32) [two 64-bit add]
1494// After: ((y.s32 + z.s32) + x.s64) [one 32-bit add and one 64-bit add]
1495class _64_bit_add_optimizer_t : public nary_op_mutator_t {
1496public:
1497 object_t _mutate(const nary_op_t &obj) override {
1498 auto new_obj = nary_op_mutator_t::_mutate(obj);
1499 auto *nary_op = new_obj.as_ptr<nary_op_t>();
1500 if (nary_op->op_kind != op_kind_t::_add || nary_op->args.size() <= 2)
1501 return new_obj;
1502
1503 std::vector<expr_t> other_args;
1504 std::vector<expr_t> x64_args;
1505 for (auto &a : nary_op->args) {
1506 if (a.type().is_x64()) {
1507 x64_args.push_back(a);
1508 } else {
1509 other_args.push_back(a);
1510 }
1511 }
1512
1513 if (other_args.empty() || x64_args.empty()) return new_obj;
1514
1515 std::vector<expr_t> new_args = std::move(other_args);
1516 new_args.insert(new_args.end(), x64_args.begin(), x64_args.end());
1517
1518 return nary_op_t::make(nary_op->op_kind, new_args);
1519 }
1520};
1521
1522// Simplifies using the N-ary form.
1523expr_t simplify_with_nary(const expr_t &_e, const constraint_set_t &cset) {
1524 auto e = _e;
1525
1526 if (!e.type().is_scalar() || e.type().is_fp()) { return e; }
1527 e = nary_op_canonicalize(e);
1528
1529 e = division_reducer_t().mutate(e);
1530 e = nary_op_flatten(e);
1531 e = int_div_mod_expander_t(cset).mutate(e);
1532 e = common_factor_simplifier_t().mutate(e);
1533 e = int_div_mod_range_simplifier_t(cset).mutate(e);
1534 e = _64_bit_add_optimizer_t().mutate(e);
1535
1536 e = nary_op_back_transform(e);
1537
1538 return e;
1539}
1540
1541class stmt_simplifier_t : public ir_mutator_t {
1542public:
1543 stmt_simplifier_t(const constraint_set_t &cset) : cset_(cset) {}
1544
1545 ~stmt_simplifier_t() override {
1546 if (!cpp_compat::uncaught_exceptions()) {
1547 ir_assert(continue_calls_.empty()) << "Unexpected continue calls.";
1548 }
1549 }
1550
1551 object_t _mutate(const binary_op_t &obj) override {
1552 return simplify(obj, cset_);
1553 }
1554
1555 object_t _mutate(const func_call_t &obj) override {
1556 if (obj.func.is_equal(funcs::continue_func())) {
1557 continue_calls_.push_back(obj);
1558 }
1559 return ir_mutator_t::_mutate(obj);
1560 }
1561
1562 object_t _mutate(const if_t &obj) override {
1563 auto cond = simplify(obj.cond);
1564
1565 if (all_of(cond, expr_t(true))) return mutate(obj.body);
1566 if (all_of(cond, expr_t(false))) return mutate(obj.else_body);
1567
1568 auto body = obj.body;
1569 if (!body.is_empty()) {
1570 auto cset_old = cset_;
1571 cset_.add_constraint(cond);
1572 body = ir_mutator_t::mutate(body);
1573 cset_ = cset_old;
1574 }
1575
1576 auto else_body = obj.else_body;
1577 if (!else_body.is_empty()) {
1578 auto cset_old = cset_;
1579 cset_.add_constraint(flip_condition(cond));
1580 else_body = ir_mutator_t::mutate(else_body);
1581 cset_ = cset_old;
1582 }
1583
1584 return if_t::make(cond, body, else_body);
1585 }
1586
1587 object_t _mutate(const let_t &obj) override {
1588 // External variable.
1589 if (obj.value.is_empty()) return ir_mutator_t::_mutate(obj);
1590
1591 // Substitute constants.
1592 auto value = simplify(obj.value);
1593 if (is_const(value)) {
1594 // Constants are not necessarily the same type as the assigned
1595 // variable
1596 value = cast_t::make(obj.var.as<var_t>().type, value);
1597
1598 auto body = substitute(obj.body, obj.var, value);
1599 return mutate(body);
1600 } else if (is_var(value)) {
1601 auto body = substitute(obj.body, obj.var, value);
1602 return mutate(body);
1603 }
1604
1605 auto cset_old = cset_;
1606 cset_.add_constraint(obj.var == value);
1607 auto new_obj = let_t::make(obj.var, value, obj.body);
1608 new_obj = ir_mutator_t::_mutate(new_obj.as<let_t>());
1609 cset_ = cset_old;
1610
1611 return std::move(new_obj);
1612 }
1613
1614 object_t _mutate(const for_t &obj) override {
1615 object_t new_obj;
1616 bool found_continue = false;
1617 size_t ncontinue_calls = continue_calls_.size();
1618 if (is_zero(obj.init) && is_one(obj.bound)) {
1619 auto body = substitute(obj.body, obj.var, expr_t(0));
1620 body = mutate(body);
1621 if (continue_calls_.size() > ncontinue_calls) found_continue = true;
1622 if (found_continue) {
1623 new_obj = for_t::make(
1624 obj.var, obj.init, obj.bound, body, obj.unroll);
1625 } else {
1626 new_obj = body;
1627 }
1628 } else {
1629 auto cset_old = cset_;
1630 cset_.add_constraint(obj.var >= obj.init);
1631 cset_.add_constraint(obj.var < obj.bound);
1632 new_obj = ir_mutator_t::_mutate(obj);
1633 if (continue_calls_.size() > ncontinue_calls) found_continue = true;
1634 cset_ = cset_old;
1635 }
1636
1637 // Remove continue call.
1638 if (found_continue) continue_calls_.pop_back();
1639
1640 return new_obj;
1641 }
1642
1643 object_t _mutate(const store_t &obj) override {
1644 auto new_obj = ir_mutator_t::_mutate(obj);
1645 if (new_obj.is_empty()) return stmt_t();
1646
1647 auto &store = new_obj.as<store_t>();
1648 if (!store.value.is<load_t>()) return new_obj;
1649
1650 auto &load = store.value.as<load_t>();
1651 if (!store.buf.is_equal(load.buf)) return new_obj;
1652 if (!store.off.is_equal(load.off)) return new_obj;
1653 if (store.stride != load.stride) return new_obj;
1654
1655 // This is a load/store of the same value which is a no-op.
1656 return stmt_t();
1657 }
1658
1659private:
1660 static op_kind_t flip_cmp_op(op_kind_t op_kind) {
1661 switch (op_kind) {
1662 case op_kind_t::_eq: return op_kind_t::_ne;
1663 case op_kind_t::_ge: return op_kind_t::_lt;
1664 case op_kind_t::_gt: return op_kind_t::_le;
1665 case op_kind_t::_le: return op_kind_t::_gt;
1666 case op_kind_t::_lt: return op_kind_t::_ge;
1667 case op_kind_t::_ne: return op_kind_t::_eq;
1668 default: ir_error_not_expected();
1669 }
1670 return op_kind_t::undef;
1671 }
1672
1673 static expr_t flip_condition(const expr_t &cond) {
1674 ir_assert(cond.type().is_bool());
1675
1676 auto *binary_op = cond.as_ptr<binary_op_t>();
1677 if (binary_op) {
1678 auto &a = binary_op->a;
1679 auto &b = binary_op->b;
1680 auto op_kind = binary_op->op_kind;
1681 return binary_op_t::make(flip_cmp_op(op_kind), a, b);
1682 }
1683
1684 auto *shuffle = cond.as_ptr<shuffle_t>();
1685 if (shuffle && shuffle->is_broadcast()) {
1686 return shuffle_t::make_broadcast(
1687 flip_condition(shuffle->vec[0]), shuffle->elems());
1688 }
1689
1690 ir_error_not_expected();
1691 return expr_t();
1692 }
1693
1694 constraint_set_t cset_;
1695 std::vector<stmt_t> continue_calls_;
1696};
1697
1698expr_t simplify_expr(const expr_t &_e, const constraint_set_t &cset) {
1699 expr_t e = _e;
1700
1701 if (is_const(e) || is_var(e)) return e;
1702
1703 e = const_fold(e);
1704 e = simplify_rewrite(e);
1705
1706 e = simplify_comparison(e);
1707 e = range_simplifier_t(cset).mutate(e);
1708 e = simplify_with_nary(e, cset);
1709
1710 e = const_fold(e);
1711 e = simplify_rewrite(e);
1712
1713 return e;
1714}
1715
1716stmt_t simplify_stmt(const stmt_t &s, const constraint_set_t &cset) {
1717 stmt_simplifier_t simplifier(cset);
1718 return simplifier.mutate(s);
1719}
1720
1721int64_t get_max_const_factor(const expr_t &_e, const constraint_set_t &cset) {
1722 ir_assert(_e.type().is_int());
1723 auto e = _e;
1724 // Some complex expressions need more than one simplify() call.
1725 int max_tries = 3;
1726 for (int i = 0; i < max_tries; i++)
1727 e = simplify(e, cset);
1728 auto o = factored_expr_t::make(nary_op_canonicalize(e));
1729 auto &expr = o.as<factored_expr_t>();
1730 return to_cpp<int64_t>(expr.const_factor());
1731}
1732
1733template <op_kind_t op_kind>
1734struct op_traits_t {};
1735
1736#define DECL_OP_TRAITS(name, op) \
1737 template <> \
1738 struct op_traits_t<name> { \
1739 template <typename T, \
1740 typename = typename std::enable_if< \
1741 !std::is_same<T, bool>::value>::type> \
1742 static auto compute(T a, T b) -> decltype(a op b) { \
1743 return a op b; \
1744 } \
1745 template <op_kind_t dummy_op = name, \
1746 typename \
1747 = typename std::enable_if<dummy_op == op_kind_t::_and>::type> \
1748 static bool compute(bool a, bool b) { \
1749 return a op b; \
1750 } \
1751 };
1752
1753DECL_OP_TRAITS(op_kind_t::_add, +)
1754DECL_OP_TRAITS(op_kind_t::_sub, -)
1755DECL_OP_TRAITS(op_kind_t::_mul, *)
1756DECL_OP_TRAITS(op_kind_t::_div, /)
1757DECL_OP_TRAITS(op_kind_t::_mod, %)
1758
1759DECL_OP_TRAITS(op_kind_t::_eq, ==)
1760DECL_OP_TRAITS(op_kind_t::_ne, !=)
1761DECL_OP_TRAITS(op_kind_t::_gt, >)
1762DECL_OP_TRAITS(op_kind_t::_ge, >=)
1763DECL_OP_TRAITS(op_kind_t::_lt, <)
1764DECL_OP_TRAITS(op_kind_t::_le, <=)
1765
1766DECL_OP_TRAITS(op_kind_t::_and, &)
1767
1768template <>
1769struct op_traits_t<op_kind_t::_min> {
1770 template <typename T>
1771 static T compute(T a, T b) {
1772 return std::min(a, b);
1773 }
1774};
1775
1776template <>
1777struct op_traits_t<op_kind_t::_max> {
1778 template <typename T>
1779 static T compute(T a, T b) {
1780 return std::max(a, b);
1781 }
1782};
1783
1784#undef DECL_OP_TRAITS
1785
1786template <op_kind_t op_kind, typename T, typename = void>
1787struct compute_helper_t {
1788 static expr_t call(T a, T b) { return expr_t(); }
1789};
1790
1791template <typename>
1792struct voider_t {
1793 using type = void;
1794};
1795
1796template <op_kind_t op_kind, typename T>
1797struct compute_helper_t<op_kind, T,
1798 typename voider_t<decltype(
1799 op_traits_t<op_kind>::compute(T(), T()))>::type> {
1800 static expr_t call(T a, T b) {
1801 return to_expr(op_traits_t<op_kind>::compute(a, b));
1802 }
1803};
1804
1805template <typename T>
1806class const_fold_helper_t {
1807public:
1808 template <typename U = T>
1809 static expr_t call(op_kind_t op_kind, T a, T b) {
1810 switch (op_kind) {
1811#define CASE(op) \
1812 case op: return compute_helper_t<op, T>::call(a, b);
1813
1814 CASE(op_kind_t::_add)
1815 CASE(op_kind_t::_sub)
1816 CASE(op_kind_t::_mul)
1817 CASE(op_kind_t::_div)
1818 CASE(op_kind_t::_mod)
1819
1820 CASE(op_kind_t::_eq)
1821 CASE(op_kind_t::_ne)
1822 CASE(op_kind_t::_gt)
1823 CASE(op_kind_t::_ge)
1824 CASE(op_kind_t::_lt)
1825 CASE(op_kind_t::_le)
1826
1827 CASE(op_kind_t::_and)
1828 CASE(op_kind_t::_min)
1829 CASE(op_kind_t::_max)
1830
1831 default: ir_error_not_expected();
1832
1833#undef CASE
1834 }
1835 return expr_t();
1836 }
1837};
1838
1839class const_folder_t : public ir_mutator_t {
1840public:
1841 object_t _mutate(const binary_op_t &obj) override {
1842 return mutate_expr(obj);
1843 }
1844 object_t _mutate(const cast_t &obj) override { return mutate_expr(obj); }
1845 object_t _mutate(const iif_t &obj) override { return mutate_expr(obj); }
1846 object_t _mutate(const unary_op_t &obj) override {
1847 return mutate_expr(obj);
1848 }
1849
1850private:
1851 template <typename T>
1852 object_t mutate_expr(const T &obj) {
1853 auto new_obj = ir_mutator_t::_mutate(obj);
1854 return const_fold_non_recursive(new_obj);
1855 }
1856};
1857
1858bool is_const_or_shuffle_const(const expr_t &e) {
1859 return is_const(e) || is_shuffle_const(e);
1860}
1861
1862expr_t const_fold_unary(op_kind_t op_kind, const expr_t &a) {
1863 ir_assert(op_kind == op_kind_t::_minus);
1864 if (!a.type().is_scalar()) {
1865 int elems = a.type().elems();
1866 std::vector<expr_t> ret;
1867 for (int i = 0; i < elems; i++) {
1868 ret.push_back(const_fold_unary(op_kind, a[i]));
1869 }
1870 return shuffle_t::make(ret);
1871 }
1872
1873#define CASE(ir_type, cpp_type) \
1874 if (a.type() == type_t::ir_type()) return to_expr(-to_cpp<cpp_type>(a))
1875
1876 CASE(f32, float);
1877 CASE(s16, int16_t);
1878 CASE(s32, int32_t);
1879 CASE(s64, int64_t);
1880
1881 if (a.type().is_bool()) return to_expr(!to_cpp<bool>(a));
1882
1883#undef CASE
1884
1885 ir_error_not_expected() << "Cannot handle type: " << a;
1886 return expr_t();
1887}
1888
1889expr_t const_fold_binary(const type_t &compute_type, op_kind_t op_kind,
1890 const expr_t &a, const expr_t &b) {
1891 if (!compute_type.is_scalar()) {
1892 int elems = compute_type.elems();
1893 auto scalar_type = compute_type.scalar();
1894 std::vector<expr_t> ret;
1895 for (int i = 0; i < elems; i++) {
1896 ret.push_back(const_fold_binary(scalar_type, op_kind, a[i], b[i]));
1897 }
1898 return shuffle_t::make(ret);
1899 }
1900
1901 if (compute_type.is_unsigned()) {
1902 auto a_s64 = to_cpp<int64_t>(a);
1903 auto b_s64 = to_cpp<int64_t>(b);
1904 ir_assert(a_s64 >= 0 && b_s64 >= 0)
1905 << "Overflow detected: fix data types.";
1906 MAYBE_UNUSED(a_s64);
1907 MAYBE_UNUSED(b_s64);
1908 }
1909
1910#define CASE(ir_type, cpp_type) \
1911 if (compute_type == type_t::ir_type()) { \
1912 auto _a = to_cpp<cpp_type>(a); \
1913 auto _b = to_cpp<cpp_type>(b); \
1914 return const_fold_helper_t<cpp_type>::call(op_kind, _a, _b); \
1915 }
1916
1917 CASE(_bool, bool)
1918 CASE(f32, float)
1919 CASE(s16, int16_t)
1920 CASE(s32, int32_t)
1921 CASE(s64, int64_t)
1922 CASE(u16, uint16_t)
1923 CASE(u32, uint32_t)
1924 CASE(u64, uint64_t)
1925
1926#undef CASE
1927
1928 ir_error_not_expected() << "Unknown type.";
1929 return expr_t();
1930}
1931
1932object_t simplify(const object_t &obj, const constraint_set_t &cset) {
1933 if (obj.is_expr()) return simplify_expr(obj, cset);
1934 if (obj.is_stmt()) return simplify_stmt(obj, cset);
1935 ir_assert(obj.is_empty());
1936 return object_t();
1937}
1938
1939expr_t simplify_cmp_move_const_to_rhs(const expr_t &e) {
1940 if (!is_binary_cmp_op(e)) return e;
1941
1942 auto &op = e.as<binary_op_t>();
1943 if (!is_const(op.b)) return e;
1944 if (!is_binary_op(op.a)) return e;
1945
1946 auto &a_op = op.a.as<binary_op_t>();
1947
1948 bool is_lhs_add = (a_op.op_kind == op_kind_t::_add);
1949 bool is_lhs_sub = (a_op.op_kind == op_kind_t::_sub);
1950 if (!is_lhs_add && !is_lhs_sub) return e;
1951
1952 auto &c1 = op.b;
1953
1954 expr_t lhs;
1955 expr_t rhs;
1956 op_kind_t op_kind;
1957 if (is_const(a_op.a)) {
1958 auto &c0 = a_op.a;
1959 auto &x = a_op.b;
1960 if (is_lhs_add) {
1961 // ((c0 + x) op c1) -> (x op (c1 - c0))
1962 lhs = x;
1963 rhs = c1 - c0;
1964 op_kind = op.op_kind;
1965 } else {
1966 // ((c0 - x) op c1) -> (x -op (c0 - c1))
1967 lhs = x;
1968 rhs = c0 - c1;
1969 op_kind = negate_cmp_op(op.op_kind);
1970 }
1971 } else if (is_const(a_op.b)) {
1972 auto &x = a_op.a;
1973 auto &c0 = a_op.b;
1974 if (is_lhs_add) {
1975 // ((x + c0) op c1) -> (x op (c1 - c0))
1976 lhs = x;
1977 rhs = c1 - c0;
1978 op_kind = op.op_kind;
1979 } else {
1980 // ((x - c0) op c1) -> (x op (c0 + c1))
1981 lhs = x;
1982 rhs = c0 + c1;
1983 op_kind = op.op_kind;
1984 }
1985 } else {
1986 return e;
1987 }
1988 return binary_op_t::make(op_kind, lhs, rhs);
1989}
1990
1991expr_t simplify_cmp_reduce_lhs_rhs(const expr_t &e) {
1992 if (!is_binary_cmp_op(e)) return e;
1993
1994 auto &op = e.as<binary_op_t>();
1995
1996 // Rule:
1997 // (c0 * x op c1) or (x * c0 op c1) ->
1998 // x new_op (c1 / c0) if abs(c1) % abs(c0) == 0
1999 // new_op == op or new_op == negate_cmp_op(op)
2000 expr_t c0;
2001 expr_t c1 = op.b;
2002 expr_t x;
2003
2004 if (!is_const(c1)) return e;
2005 if (!is_binary_op(op.a, op_kind_t::_mul)) return e;
2006
2007 auto &a_op = op.a.as<binary_op_t>();
2008 if (is_const(a_op.a)) {
2009 c0 = a_op.a;
2010 x = a_op.b;
2011 } else if (is_const(a_op.b)) {
2012 x = a_op.a;
2013 c0 = a_op.b;
2014 }
2015
2016 if (c0.is_empty()) return e;
2017 if (!c0.type().is_int()) return e;
2018 if (!c1.type().is_int()) return e;
2019
2020 auto i_c0 = to_cpp<int64_t>(c0);
2021 auto i_c1 = to_cpp<int64_t>(c1);
2022
2023 bool is_c0_neg = (i_c0 < 0);
2024 bool sign = ((i_c0 < 0) != (i_c1 < 0));
2025 i_c0 = std::abs(i_c0);
2026 i_c1 = std::abs(i_c1);
2027
2028 bool has_mod = (i_c1 % i_c0 != 0);
2029 if (has_mod && utils::one_of(op.op_kind, op_kind_t::_eq, op_kind_t::_ne))
2030 return e;
2031
2032 auto new_op_kind = (is_c0_neg ? negate_cmp_op(op.op_kind) : op.op_kind);
2033 int64_t div = i_c1 / i_c0;
2034 if (has_mod) {
2035 switch (new_op_kind) {
2036 case op_kind_t::_ge:
2037 case op_kind_t::_gt:
2038 new_op_kind = op_kind_t::_ge;
2039 div = (sign ? div : div + 1);
2040 break;
2041 case op_kind_t::_le:
2042 case op_kind_t::_lt:
2043 new_op_kind = op_kind_t::_le;
2044 div = (sign ? div + 1 : div);
2045 break;
2046 default: ir_error_not_expected();
2047 }
2048 }
2049
2050 return binary_op_t::make(new_op_kind, x, (sign ? -1 : 1) * div);
2051}
2052
2053bool const_to_const_binary(const expr_t &e, op_kind_t op_kind,
2054 const type_t &a_type, const type_t &b_type, expr_t &a, expr_t &b) {
2055 bool is_true = to_cpp<bool>(e);
2056 // Assume:
2057 // - a0 < b1
2058 // - a1 > b0
2059 // - a_eq == b_eq
2060 expr_t a0 = to_expr(0, a_type);
2061 expr_t a1 = to_expr(1, a_type);
2062 expr_t b0 = to_expr(0, b_type);
2063 expr_t b1 = to_expr(1, b_type);
2064 expr_t a_eq = to_expr(0, a_type);
2065 expr_t b_eq = to_expr(0, b_type);
2066 if (!a.is_empty()) {
2067 a0 = a1 = a;
2068 b0 = a - 1;
2069 b1 = a + 1;
2070 a_eq = b_eq = a;
2071 } else if (!b.is_empty()) {
2072 b0 = b1 = b;
2073 a0 = b - 1;
2074 a1 = b + 1;
2075 a_eq = b_eq = b;
2076 }
2077 switch (op_kind) {
2078 case op_kind_t::_and: a = b = e; return true;
2079 case op_kind_t::_le:
2080 case op_kind_t::_lt:
2081 a = (is_true ? a0 : a1);
2082 b = (is_true ? b1 : b0);
2083 return true;
2084 case op_kind_t::_ge:
2085 case op_kind_t::_gt:
2086 a = (is_true ? a1 : a0);
2087 b = (is_true ? b0 : b1);
2088 return true;
2089 case op_kind_t::_eq:
2090 a = (is_true ? a_eq : a0);
2091 b = (is_true ? b_eq : b1);
2092 return true;
2093 case op_kind_t::_ne:
2094 a = (is_true ? a0 : a_eq);
2095 b = (is_true ? b1 : b_eq);
2096 return true;
2097 default: return false;
2098 }
2099}
2100
2101expr_t simplify_propagate_shuffle(const expr_t &e) {
2102 if (!e.type().is_bool()) return e;
2103
2104 auto *shuffle = e.as_ptr<shuffle_t>();
2105 if (!shuffle) return e;
2106
2107 // Handle binary operation.
2108 {
2109 type_t a_type;
2110 type_t b_type;
2111 expr_t a_common_const;
2112 expr_t b_common_const;
2113 op_kind_t op_kind = op_kind_t::undef;
2114 bool found_binary = false;
2115 for (int i : shuffle->idx) {
2116 if (is_binary_op(shuffle->vec[i])) {
2117 auto &op = shuffle->vec[i].as<binary_op_t>();
2118 if (found_binary && op.op_kind != op_kind_t::_and) continue;
2119 found_binary = true;
2120 a_type = op.a.type();
2121 b_type = op.b.type();
2122 op_kind = op.op_kind;
2123 if (is_const(op.a)) a_common_const = op.a;
2124 if (is_const(op.b)) b_common_const = op.b;
2125 if (op_kind == op_kind_t::_and) break;
2126 }
2127 }
2128 if (!found_binary) return e;
2129
2130 for (int i : shuffle->idx) {
2131 auto &elem = shuffle->vec[i];
2132 if (is_binary_op(elem, op_kind)) {
2133 auto &op = elem.as<binary_op_t>();
2134 if (!a_common_const.is_equal(op.a)) {
2135 a_common_const = expr_t();
2136 }
2137 if (!b_common_const.is_equal(op.b)) {
2138 b_common_const = expr_t();
2139 }
2140 }
2141 }
2142
2143 bool ok = true;
2144 std::vector<expr_t> a;
2145 std::vector<expr_t> b;
2146 for (int i : shuffle->idx) {
2147 auto &elem = shuffle->vec[i];
2148 if (is_binary_op(elem, op_kind)) {
2149 auto &op = elem.as<binary_op_t>();
2150 a.push_back(op.a);
2151 b.push_back(op.b);
2152 } else if (is_const(elem)) {
2153 expr_t op_a = a_common_const;
2154 expr_t op_b = b_common_const;
2155 if (!const_to_const_binary(
2156 elem, op_kind, a_type, b_type, op_a, op_b)) {
2157 ok = false;
2158 break;
2159 }
2160 a.push_back(op_a);
2161 b.push_back(op_b);
2162 } else if (op_kind == op_kind_t::_and) {
2163 // Replace with expression true <op_kind> elem to allow matching
2164 // this op against future binary operation.
2165 a.push_back(bool_imm_t::make(true));
2166 b.push_back(elem);
2167 } else {
2168 ok = false;
2169 break;
2170 }
2171 }
2172 if (ok) {
2173 auto _a = simplify_propagate_shuffle(shuffle_t::make(a));
2174 auto _b = simplify_propagate_shuffle(shuffle_t::make(b));
2175 return binary_op_t::make(op_kind, _a, _b);
2176 }
2177 }
2178
2179 return e;
2180}
2181
2182expr_t const_fold_non_recursive(const expr_t &e) {
2183 auto *unary_op = e.as_ptr<unary_op_t>();
2184 if (unary_op) {
2185 auto &a = unary_op->a;
2186 if (!is_const_or_shuffle_const(a)) return e;
2187 return const_fold_unary(unary_op->op_kind, a);
2188 }
2189
2190 auto *binary_op = e.as_ptr<binary_op_t>();
2191 if (binary_op) {
2192 auto op_kind = binary_op->op_kind;
2193 auto &a = binary_op->a;
2194 auto &b = binary_op->b;
2195 if (!is_const_or_shuffle_const(a) || !is_const_or_shuffle_const(b))
2196 return e;
2197
2198 auto compute_type = common_type(a, b);
2199 return const_fold_binary(compute_type, op_kind, a, b);
2200 }
2201
2202 auto *iif = e.as_ptr<iif_t>();
2203 if (iif) {
2204 if (!is_const(iif->cond)) return e;
2205 if (to_cpp<bool>(iif->cond)) return iif->true_expr;
2206 return iif->false_expr;
2207 }
2208
2209 auto *cast = e.as_ptr<cast_t>();
2210 if (cast && !cast->saturate) {
2211 if (cast->expr.is<bool_imm_t>())
2212 return to_expr(to_cpp<bool>(cast->expr), cast->type);
2213 if (cast->expr.is<int_imm_t>())
2214 return to_expr(to_cpp<int64_t>(cast->expr), cast->type);
2215 if (cast->expr.is<float_imm_t>())
2216 return to_expr(to_cpp<double>(cast->expr), cast->type);
2217 }
2218
2219 return e;
2220}
2221
2222object_t const_fold(const object_t &obj) {
2223 return const_folder_t().mutate(obj);
2224}
2225
2226expr_t nary_op_back_transform(const expr_t &e) {
2227 // Convert nary_op_t back to binary_op_t.
2228 return nary_op_back_transformer_t().mutate(e);
2229}
2230
2231expr_t nary_op_canonicalize(const expr_t &_e) {
2232 auto e = _e;
2233
2234 e = nary_op_transformer_t().mutate(e);
2235 e = mul_nary_op_expander_t().mutate(e);
2236
2237 ir_assert(is_nary_op_canonical(e)) << e;
2238 MAYBE_UNUSED(is_nary_op_canonical);
2239
2240 return e;
2241}
2242
2243expr_t make_nary_op(op_kind_t op_kind, const std::vector<expr_t> &args) {
2244 if (args.empty()) {
2245 if (op_kind == op_kind_t::_add) return 0;
2246 if (op_kind == op_kind_t::_mul) return 1;
2247 ir_error_not_expected() << to_string(op_kind);
2248 }
2249 if (args.size() == 1) return args[0];
2250
2251 // Do eager constant folding.
2252 std::vector<expr_t> new_args;
2253 fold_const_nary_op_args(op_kind, args, new_args);
2254
2255 if (new_args.size() < args.size()) return make_nary_op(op_kind, new_args);
2256
2257 return nary_op_t::make(op_kind, new_args);
2258}
2259
2260std::vector<expr_t> cvt_expr_to_nary_op_args(const expr_t &e) {
2261 auto *nary = e.as_ptr<nary_op_t>();
2262 if (nary) return nary->args;
2263 return {e};
2264}
2265
2266stmt_t simplify(const stmt_t &s, ir_context_t &ir_ctx) {
2267 trace_start();
2268 auto ret = simplify(s, ir_ctx.cset());
2269 trace_pass("simplify_pass", ret, ir_ctx);
2270 return ret;
2271}
2272
2273} // namespace jit
2274} // namespace gpu
2275} // namespace impl
2276} // namespace dnnl
2277