1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/ir/ir.hpp"
18
19#include <sstream>
20
21#include "common/math_utils.hpp"
22#include "common/optional.hpp"
23#include "gpu/jit/codegen/register_allocator.hpp"
24#include "gpu/jit/ir/core.hpp"
25#include "gpu/jit/ir/message.hpp"
26#include "gpu/jit/pass/simplify.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace gpu {
31namespace jit {
32
33using namespace ir_utils;
34
35namespace {
36
37// Helper class to print IR objects.
38class ir_printer_t : public ir_visitor_t {
39public:
40 ir_printer_t(std::ostream &out) : out_(out) {}
41
42 void _visit(const alloc_t &obj) override {
43 auto guard
44 = mem_usage_guard(obj.kind == alloc_kind_t::grf ? obj.size : 0);
45 print_indent();
46 out_ << "alloc " << obj.buf.as<var_t>().name << "[" << obj.size
47 << "] (mem_usage: " << mem_usage_ << ")\n";
48 visit(obj.body);
49 }
50
51 void _visit(const binary_op_t &obj) override {
52 if (utils::one_of(obj.op_kind, op_kind_t::_min, op_kind_t::_max)) {
53 out_ << to_string(obj.op_kind) << "(" << obj.a << ", " << obj.b
54 << ")";
55 return;
56 }
57 out_ << "(";
58 visit(obj.a);
59 out_ << " " << to_string(obj.op_kind) << " ";
60 visit(obj.b);
61 out_ << ")";
62 }
63
64 void _visit(const bool_imm_t &obj) override {
65 out_ << (obj.value ? "true" : "false");
66 }
67
68 void _visit(const cast_t &obj) override {
69 out_ << obj.type;
70 if (obj.saturate) out_ << ".sat";
71 out_ << "(" << obj.expr << ")";
72 }
73
74 void _visit(const float_imm_t &obj) override { out_ << obj.value; }
75
76 void _visit(const for_t &obj) override {
77 print_indent();
78 out_ << "for (" << obj.var << " = " << obj.init << "; " << obj.var
79 << " < " << obj.bound << "; " << obj.var << "++) ";
80 if (obj.unroll != 1) out_ << "[unroll: " << obj.unroll << "] ";
81 out_ << "{\n";
82 add_indent();
83 visit(obj.body);
84 remove_indent();
85 print_indent();
86 out_ << "}\n";
87 }
88
89 void _visit(const func_call_t &obj) override {
90 print_indent();
91 out_ << obj.func << "(" << make_seq_print_helper(obj.args) << ")";
92 if (!obj.attr.is_empty()) out_ << " " << obj.attr;
93 out_ << "\n";
94 }
95
96 void _visit(const func_impl_t &obj) override { out_ << obj.str(); }
97
98 void _visit(const if_t &obj) override {
99 print_indent();
100 out_ << "if (" << strip_parens(obj.cond.str()) << ") {\n";
101 add_indent();
102 visit(obj.body);
103 remove_indent();
104 print_indent();
105 if (obj.else_body.is_empty()) {
106 out_ << "}\n";
107 return;
108 }
109 out_ << "} else {\n";
110 add_indent();
111 visit(obj.else_body);
112 remove_indent();
113 print_indent();
114 out_ << "}\n";
115 }
116
117 void _visit(const iif_t &obj) override {
118 out_ << "(" << obj.cond << " ? " << obj.true_expr << " : "
119 << obj.false_expr << ")";
120 }
121
122 void _visit(const int_imm_t &obj) override {
123 out_ << std::to_string(obj.value);
124 }
125
126 void _visit(const let_t &obj) override {
127 // Empty objects are allocated in reserved space
128 // nGEN only claims subregisters at dword granularity
129 int size = obj.value.is_empty() ? 0
130 : utils::rnd_up(obj.var.type().size(),
131 reg_allocator_t::granularity);
132 auto guard = mem_usage_guard(size);
133 print_indent();
134 out_ << obj.var << "." << obj.var.type() << " = " << obj.value << "\n";
135 visit(obj.body);
136 }
137
138 void _visit(const load_t &obj) override {
139 out_ << obj.buf;
140 if (obj.has_default_stride()) {
141 out_ << "." << obj.type << "(" << obj.off / obj.type.size() << ")";
142 } else {
143 out_ << "[" << obj.off << "]." << obj.type;
144 out_ << "<" << obj.stride << ">";
145 }
146 }
147
148 void _visit(const ptr_t &obj) override {
149 out_ << obj.base << "[" << obj.off << "]";
150 }
151
152 void _visit(const shuffle_t &obj) override {
153 if (obj.is_broadcast()) {
154 out_ << "bcast" << obj.elems() << "(" << obj.vec[0] << ")";
155 return;
156 }
157 std::vector<expr_t> vec_all;
158 for (auto &v : obj.vec) {
159 for (int i = 0; i < v.type().elems(); i++)
160 vec_all.push_back(v);
161 }
162 int elems = obj.type.elems();
163 out_ << "(";
164 for (int i = 0; i < elems; i++) {
165 int idx = obj.idx[i];
166 auto &v = vec_all[idx];
167 int v_elems = v.type().elems();
168 out_ << v;
169 if (v_elems != 1) out_ << "[" << idx << "]";
170 if (i != elems - 1) out_ << ", ";
171 }
172 out_ << ")";
173 }
174
175 void _visit(const stmt_group_t &obj) override {
176 print_indent();
177 out_ << obj.label << " {\n";
178 add_indent();
179 visit(obj.body);
180 remove_indent();
181 print_indent();
182 out_ << "}\n";
183 return;
184 }
185
186 void _visit(const stmt_seq_t &obj) override {
187 visit(obj.head);
188 visit(obj.tail);
189 }
190
191 void _visit(const store_t &obj) override {
192 print_indent();
193 out_ << load_t::make(obj.value.type(), obj.buf, obj.off, obj.stride);
194 out_ << " = " << obj.value;
195 if (!obj.mask.is_empty()) {
196 out_ << ", mask = " << obj.mask.str();
197 if (obj.fill_mask0) out_ << " [FILL]";
198 }
199 out_ << "\n";
200 }
201
202 void _visit(const ternary_op_t &obj) override {
203 out_ << to_string(obj.op_kind) << "(" << obj.a << ", " << obj.b << ", "
204 << obj.c << ")";
205 return;
206 }
207
208 void _visit(const unary_op_t &obj) override {
209 out_ << to_string(obj.op_kind);
210 visit(obj.a);
211 }
212
213 void _visit(const var_t &obj) override { out_ << obj.name; }
214
215private:
216 mem_usage_guard_t mem_usage_guard(int size) {
217 return mem_usage_guard_t(&mem_usage_, size);
218 }
219
220 static std::string strip_parens(const std::string &s) {
221 if (s.size() < 2 || s[0] != '(' || s[s.size() - 1] != ')') return s;
222 auto ret = s;
223 ret.resize(s.size() - 1);
224 return ret.substr(1);
225 }
226
227 void print_indent() {
228 for (int i = 0; i < indent_; i++)
229 out_ << prefix_;
230 }
231
232 void add_indent() { indent_++; }
233 void remove_indent() { indent_--; }
234
235 std::ostream &out_;
236 int indent_ = 0;
237
238 std::string prefix_ = " ";
239
240 // Size required for all enclosed let/alloc statements. The value is
241 // updated during traversal.
242 int mem_usage_ = 0;
243};
244
245class substitute_mutator_t : public ir_mutator_t {
246public:
247 substitute_mutator_t(const object_t &from, const object_t &to)
248 : from_(from), to_(to) {}
249
250 int substitutions() const { return substitutions_; }
251
252#define HANDLE_IR_OBJECT(type) \
253 object_t _mutate(const type &obj) override { \
254 if (from_.impl() == (const object_impl_t *)&obj) { \
255 substitutions_++; \
256 return to_; \
257 } \
258 return ir_mutator_t::_mutate(obj); \
259 };
260
261 HANDLE_TRAVERSE_TARGETS()
262
263#undef HANDLE_IR_OBJECT
264
265private:
266 object_t from_;
267 object_t to_;
268
269 int substitutions_ = 0;
270};
271
272class substitute_and_type_mutator_t : public ir_mutator_t {
273public:
274 substitute_and_type_mutator_t(const object_t &from, const object_t &to) {
275 substitutes_[from] = to;
276 }
277
278 int substitutions() const { return substitutions_; }
279
280 template <typename T>
281 object_t _mutate_after(const T &obj) {
282 return ir_mutator_t::_mutate(obj);
283 }
284
285 object_t _mutate_after(const let_t &obj) {
286 auto var = mutate(obj.var);
287 auto value = mutate(obj.value);
288
289 // Allow changing variable types when performing substitutions. Avoids
290 // the following invalid substitute transformation sequence:
291 //
292 // tmp0.s32 -> tmp0_0.u64
293 // tmp1.s32 = tmp0.s32 -> tmp1.s32 = tmp0_0.u64
294 if (!value.is_empty()) {
295 auto &value_type = expr_t(value).type();
296 if (var.as<var_t>().type != value_type) {
297 auto var_old = var;
298 var = var_t::make(value_type, var.as<var_t>().name);
299
300 substitutes_[var_old] = var;
301 }
302 }
303
304 auto body = mutate(obj.body);
305
306 if (var.is_same(obj.var) && value.is_same(obj.value)
307 && body.is_same(obj.body))
308 return obj;
309
310 return let_t::make(var, value, body);
311 }
312
313#define HANDLE_IR_OBJECT(type) \
314 object_t _mutate(const type &obj) override { \
315 auto it = substitutes_.find(obj); \
316 if (it != substitutes_.end()) { \
317 substitutions_++; \
318 return it->second; \
319 } \
320 return _mutate_after(obj); \
321 };
322
323 HANDLE_ALL_IR_OBJECTS()
324
325#undef HANDLE_IR_OBJECT
326
327private:
328 object_eq_map_t<object_t, object_t> substitutes_;
329
330 int substitutions_ = 0;
331};
332
333class stmt_flattener_t : public ir_visitor_t {
334public:
335#define HANDLE_IR_OBJECT(type) \
336 void _visit(const type &obj) { \
337 size_t old_size = stmts.size(); \
338 ir_visitor_t::_visit(obj); \
339 if (stmts.size() > old_size) return; \
340 if (obj.is_stmt()) stmts.push_back(obj); \
341 }
342
343 HANDLE_ALL_IR_OBJECTS()
344
345#undef HANDLE_IR_OBJECT
346
347 std::vector<stmt_t> stmts;
348};
349
350class alloc_injector_t : public ir_mutator_t {
351public:
352 alloc_injector_t(const stmt_t &root, const std::vector<stmt_t> &allocs,
353 bool put_innermost)
354 : root_(root), put_innermost_(put_innermost), allocs_(allocs) {
355 for (auto &_a : allocs) {
356 auto &a = _a.as<alloc_t>();
357 if (a.kind != alloc_kind_t::global) ir_assert(a.size > 0) << _a;
358 alloc_map_.insert({a.buf, _a});
359 }
360 mutate(root_);
361 buf_total_refs_ = buf_cur_refs_;
362 for (auto &kv : buf_cur_refs_)
363 kv.second = 0;
364 in_ctor_ = false;
365 }
366
367#define HANDLE_IR_OBJECT(type) \
368 object_t _mutate(const type &obj) override { return mutate_stmt(obj); }
369
370 HANDLE_STMT_IR_OBJECTS()
371
372#undef HANDLE_IR_OBJECT
373 object_t _mutate(const var_t &obj) override {
374 if (alloc_map_.find(obj) != alloc_map_.end()) buf_cur_refs_[obj]++;
375 return obj;
376 }
377
378private:
379 template <typename T>
380 object_t mutate_stmt(const T &obj) {
381 if (in_ctor_) return ir_mutator_t::_mutate(obj);
382 object_t new_obj = obj;
383 object_set_t<expr_t> undef_bufs;
384 if (put_innermost_) {
385 for (auto &kv : buf_cur_refs_)
386 if (kv.second == 0) undef_bufs.insert(kv.first);
387 new_obj = ir_mutator_t::_mutate(obj);
388 }
389 for (auto &a : allocs_) {
390 auto it = alloc_map_.find(a.as<alloc_t>().buf);
391 auto &buf = it->first;
392 if (it->second.is_empty()) continue; // Already injected.
393 bool do_inject = false;
394 if (put_innermost_) {
395 int cur_refs = buf_cur_refs_[buf];
396 int total_refs = buf_total_refs_[buf];
397 bool was_undef = (undef_bufs.count(buf) != 0);
398 do_inject = was_undef && (cur_refs == total_refs);
399 } else {
400 do_inject = root_.is_same(obj);
401 }
402 if (do_inject) {
403 auto &a = it->second.as<alloc_t>();
404 new_obj = alloc_t::make(
405 a.buf, a.size, a.kind, a.attrs, new_obj);
406 it->second = stmt_t();
407 }
408 }
409 return new_obj;
410 }
411
412 bool in_ctor_ = true;
413 const stmt_t &root_;
414 bool put_innermost_;
415 std::vector<stmt_t> allocs_;
416 object_map_t<expr_t, stmt_t> alloc_map_;
417 object_map_t<expr_t, int> buf_total_refs_;
418 object_map_t<expr_t, int> buf_cur_refs_;
419};
420
421} // namespace
422
423std::string object_impl_t::str() const {
424 std::ostringstream oss;
425 ir_printer_t printer(oss);
426 printer.visit(this);
427 return oss.str();
428}
429
430object_t substitute(const object_t &root, const object_t &from,
431 const object_t &to, int max_substitutions) {
432 if (to.is_same(from)) return root;
433 substitute_mutator_t sm(from, to);
434 auto ret = sm.mutate(root);
435 ir_assert(sm.substitutions() <= max_substitutions)
436 << "Unexpected number of substitutions.";
437 return ret;
438}
439
440object_t substitute_with_different_type(const object_t &root,
441 const object_t &from, const object_t &to, int max_substitutions) {
442 if (to.is_same(from)) return root;
443 substitute_and_type_mutator_t sm(from, to);
444 auto ret = sm.mutate(root);
445 ir_assert(sm.substitutions() <= max_substitutions)
446 << "Unexpected number of substitutions.";
447 return ret;
448}
449
450std::vector<stmt_t> flatten_statements(const stmt_t &root) {
451 stmt_flattener_t f;
452 f.visit(root);
453 return f.stmts;
454}
455
456stmt_t inject_alloc_stmts(const stmt_t &stmt, const std::vector<stmt_t> &allocs,
457 bool put_innermost) {
458 alloc_injector_t injector(stmt, allocs, put_innermost);
459 return injector.mutate(stmt);
460}
461
462stmt_t inject_let_stmts(const stmt_t &stmt, const std::vector<stmt_t> &lets) {
463 stmt_t ret = stmt;
464 for (auto it = lets.rbegin(); it != lets.rend(); ++it) {
465 auto &let = it->as<let_t>();
466 ret = let_t::make(let.var, let.value, ret);
467 }
468 return ret;
469}
470
471expr_t abs(const expr_t &e) {
472 ir_assert(is_const(e)) << e;
473 if (to_cpp<bool>(e >= 0)) return e;
474 return -e;
475}
476
477expr_t cast(const expr_t &e, const type_t &type, bool saturate) {
478 return const_fold(cast_t::make(type, e, saturate));
479}
480
481bool is_zero(const expr_t &e) {
482 if (e.is_empty()) return false;
483 if (!e.type().is_scalar() || e.type().is_ptr()) return false;
484 return e.is_equal(to_expr(0, e.type()));
485}
486
487bool is_one(const expr_t &e) {
488 if (e.is_empty()) return false;
489 if (!e.type().is_scalar() || e.type().is_ptr()) return false;
490 return e.is_equal(to_expr(1, e.type()));
491}
492
493bool is_minus_one(const expr_t &e) {
494 if (e.is_empty()) return false;
495 if (!e.type().is_scalar() || e.type().is_ptr()) return false;
496 return e.is_equal(to_expr(-1, e.type()));
497}
498
499bool is_const_broadcast(const expr_t &e) {
500 auto *shuffle = e.as_ptr<shuffle_t>();
501 if (!shuffle) return false;
502 if (!shuffle->is_broadcast()) return false;
503 return is_const(shuffle->vec[0]);
504}
505
506bool is_const_broadcast(const expr_t &e, const expr_t &value) {
507 if (!is_const_broadcast(e)) return false;
508 return e.as<shuffle_t>().vec[0].is_equal(value);
509}
510
511expr_t make_buffer(const std::string &name) {
512 return var_t::make(type_t::byte_ptr(), name);
513}
514
515// Returns number of occurrences of `obj` in `root` (based on identity equality).
516int count_object(const object_t &root, const object_t &obj) {
517 ir_assert(!obj.is_empty());
518
519 std::vector<object_t> found;
520 do {
521#define HANDLE_IR_OBJECT(type) \
522 if (obj.dispatch_type_id() == type::_dispatch_type_id()) { \
523 found = find_objects<type>(root); \
524 break; \
525 }
526
527 HANDLE_ALL_IR_OBJECTS()
528
529#undef HANDLE_IR_OBJECT
530
531 ir_error_not_expected() << obj;
532 } while (false);
533
534 int ret = 0;
535 for (auto &f : found)
536 if (f.is_equal(obj)) ret++;
537 return ret;
538}
539
540bool contains_object(const object_t &root, const object_t &obj) {
541 ir_assert(is_var(obj)) << obj;
542 return count_object(root, obj) > 0;
543}
544
545std::vector<stmt_t> find_stmt_groups(
546 const object_t &root, const stmt_label_t &label) {
547 auto groups = find_objects<stmt_group_t>(root);
548 std::vector<stmt_t> ret;
549 for (auto &g : groups) {
550 if (g.as<stmt_group_t>().label == label) ret.push_back(g);
551 }
552 return ret;
553}
554
555utils::optional_t<stmt_t> find_stmt_group(
556 const object_t &root, const stmt_label_t &label) {
557 auto groups = find_stmt_groups(root, label);
558 if (groups.size() == 1)
559 return groups[0];
560 else
561 return utils::nullopt;
562}
563
564class stmt_group_remover_t : public ir_mutator_t {
565public:
566 stmt_group_remover_t(stmt_label_t label) : label_(label) {}
567 object_t _mutate(const stmt_group_t &obj) override {
568 if (obj.label == label_) return stmt_t();
569 return ir_mutator_t::_mutate(obj);
570 }
571 stmt_label_t label_;
572};
573
574object_t remove_stmt_group(const object_t &root, stmt_label_t label) {
575 stmt_group_remover_t remover(label);
576 return remover.mutate(root);
577}
578
579stmt_t get_stmt_body(const stmt_t &stmt) {
580 auto *alloc = stmt.as_ptr<alloc_t>();
581 if (alloc) return alloc->body;
582
583 auto *_for = stmt.as_ptr<for_t>();
584 if (_for) return _for->body;
585
586 auto *let = stmt.as_ptr<let_t>();
587 if (let) return let->body;
588
589 auto *group = stmt.as_ptr<stmt_group_t>();
590 if (group) return group->body;
591
592 return stmt;
593}
594
595stmt_t replace_stmt_body(const stmt_t &stmt, const stmt_t &new_body) {
596 auto *alloc = stmt.as_ptr<alloc_t>();
597 if (alloc) {
598 return alloc_t::make(
599 alloc->buf, alloc->size, alloc->kind, alloc->attrs, new_body);
600 }
601
602 auto *_for = stmt.as_ptr<for_t>();
603 if (_for) {
604 return for_t::make(
605 _for->var, _for->init, _for->bound, new_body, _for->unroll);
606 }
607
608 auto *let = stmt.as_ptr<let_t>();
609 if (let) { return let_t::make(let->var, let->value, new_body); }
610
611 auto *group = stmt.as_ptr<stmt_group_t>();
612 if (group) { return stmt_group_t::make(group->label, new_body); }
613
614 return new_body;
615}
616
617class grf_usage_visitor_t : public ir_visitor_t {
618public:
619 grf_usage_visitor_t(int grf_size, int external_usage, bool skip_let)
620 : grf_size_(grf_size)
621 , skip_let_(skip_let)
622 , grf_usage_(external_usage) {}
623
624 void _visit(const alloc_t &obj) override {
625 int size = (obj.kind == alloc_kind_t::grf ? obj.size : 0);
626 size = utils::rnd_up(size, grf_size_);
627 auto guard = grf_usage_guard(size);
628 ir_visitor_t::_visit(obj);
629 }
630
631 void _visit(const let_t &obj) override {
632 // Empty objects are allocated in reserved space
633 // nGEN only claims subregisters at dword granularity
634 int size = (skip_let_ || obj.value.is_empty())
635 ? 0
636 : utils::rnd_up(
637 obj.var.type().size(), reg_allocator_t::granularity);
638 auto guard = grf_usage_guard(size);
639 ir_visitor_t::_visit(obj);
640 }
641
642 int peak_grf_usage() const { return peak_grf_usage_; }
643
644private:
645 mem_usage_guard_t grf_usage_guard(int size) {
646 auto ret = mem_usage_guard_t(&grf_usage_, size);
647 peak_grf_usage_ = std::max(peak_grf_usage_, grf_usage_);
648 return ret;
649 }
650
651 int grf_size_ = 0;
652 bool skip_let_ = false;
653 int grf_usage_ = 0;
654 int peak_grf_usage_ = 0;
655};
656
657int get_peak_grf_usage(
658 const stmt_t &stmt, int grf_size, int external_usage, bool skip_let) {
659 grf_usage_visitor_t visitor(grf_size, external_usage, skip_let);
660 visitor.visit(stmt);
661 return utils::div_up(visitor.peak_grf_usage(), grf_size);
662}
663
664class has_send_atomics_visitor_t : public ir_visitor_t {
665public:
666 void _visit(const func_call_t &obj) override {
667 auto *send = obj.func.as_ptr<send_t>();
668 if (send && send->is_atomic()) found = true;
669 }
670
671 bool found = false;
672};
673
674bool has_send_atomics(const stmt_t &s) {
675 has_send_atomics_visitor_t visitor;
676 visitor.visit(s);
677 return visitor.found;
678}
679
680bool relation_t::implies(const relation_t &other) const {
681 ir_assert(var().is_same(other.var()));
682
683 if (op_kind() != other.op_kind()) return false;
684
685 auto A = to_cpp<int64_t>(rhs());
686 auto B = to_cpp<int64_t>(other.rhs());
687
688 switch (op_kind()) {
689 // (x > A) && (A >= B) => (x > B)
690 // (x >= A) && (A >= B) => (x >= B)
691 case op_kind_t::_gt:
692 case op_kind_t::_ge: return A >= B;
693 // (x < A) && (A <= B) => (x < B)
694 // (x <= A) && (A <= B) => (x <= B)
695 case op_kind_t::_lt:
696 case op_kind_t::_le: return A <= B;
697 default: ir_error_not_expected() << "Not implemented: " << expr_;
698 }
699 return false;
700}
701
702relation_t relation_t::transform(
703 const linear_transform_t &t, const expr_t &new_var) {
704 ir_assert(t.a == 1) << "Not implemented.";
705 return relation_t(binary_op_t::make(op_kind(), new_var, rhs() + t.b));
706}
707
708expr_t relation_t::normalize(const expr_t &e) {
709 ir_assert(is_relation_constraint(e)) << e;
710 auto &op = e.as<binary_op_t>();
711
712 auto op_kind = op.op_kind;
713 auto a = op.a;
714 auto b = op.b;
715
716 switch (op_kind) {
717 case op_kind_t::_lt:
718 op_kind = op_kind_t::_le;
719 b -= 1;
720 break;
721 case op_kind_t::_gt:
722 op_kind = op_kind_t::_ge;
723 b += 1;
724 break;
725 default: return e;
726 }
727 return binary_op_t::make(op_kind, a, b);
728}
729
730bool modulus_info_t::is_modulus_constraint(const expr_t &e) {
731 auto *binary_op = e.as_ptr<binary_op_t>();
732 if (!binary_op) return false;
733 if (!is_zero(binary_op->b)) return false;
734 if (binary_op->op_kind != op_kind_t::_eq) return false;
735
736 auto *mod_op = binary_op->a.as_ptr<binary_op_t>();
737 if (!mod_op) return false;
738 if (mod_op->op_kind != op_kind_t::_mod) return false;
739 if (!is_var(mod_op->a)) return false;
740 if (!is_const(mod_op->b)) return false;
741
742 return true;
743}
744
745int64_t bound_finder_base_t::find_bound_impl(
746 const expr_t &e, bool is_low) const {
747 int64_t def_bound = unlimited_bound(is_low);
748 if (is_const(e)) return to_cpp<int64_t>(e);
749 if (is_var(e)) return get_var_bound(e, is_low);
750
751 auto *unary = e.as_ptr<unary_op_t>();
752 if (unary) {
753 ir_assert(unary->op_kind == op_kind_t::_minus) << e;
754 auto a = find_bound_impl(unary->a, !is_low);
755 if (!is_good_bound(a)) return def_bound;
756 return -a;
757 }
758
759 auto *binary = e.as_ptr<binary_op_t>();
760 if (binary) {
761 switch (binary->op_kind) {
762 case op_kind_t::_add: {
763 auto a = find_bound_impl(binary->a, is_low);
764 auto b = find_bound_impl(binary->b, is_low);
765 if (!is_good_bound(a) || !is_good_bound(b)) return def_bound;
766 return a + b;
767 }
768 case op_kind_t::_sub: {
769 auto a = find_bound_impl(binary->a, is_low);
770 auto b = find_bound_impl(binary->b, !is_low);
771 if (!is_good_bound(a) || !is_good_bound(b)) return def_bound;
772 return a - b;
773 }
774 case op_kind_t::_mul: {
775 auto a = binary->a;
776 auto b = binary->b;
777 if (!is_const(a) && is_const(b)) std::swap(a, b);
778 if (!is_const(a)) return def_bound;
779
780 auto a_const = to_cpp<int64_t>(a);
781 if (a_const == 0) return 0;
782
783 auto b_lo = find_low_bound(b);
784 auto b_hi = find_high_bound(b);
785 auto b_lo_ok = is_good_bound(b_lo);
786 auto b_hi_ok = is_good_bound(b_hi);
787
788 if ((a_const > 0) == is_low && b_lo_ok) return a_const * b_lo;
789 if ((a_const > 0) != is_low && b_hi_ok) return a_const * b_hi;
790
791 break;
792 }
793 case op_kind_t::_div: {
794 if (!is_const(binary->b)) return def_bound;
795
796 auto b = to_cpp<int64_t>(binary->b);
797 ir_assert(b != 0);
798
799 auto a = find_bound_impl(binary->a, b > 0 ? is_low : !is_low);
800 if (!is_good_bound(a)) return def_bound;
801
802 bool is_neg = ((a > 0) && (b < 0)) || ((a < 0) && (b > 0));
803
804 int64_t div_bound;
805 if (is_low != is_neg) {
806 // Truncate away from zero.
807 div_bound = utils::div_up(std::abs(a), std::abs(b));
808 } else {
809 // Truncate towards zero.
810 div_bound = std::abs(a) / std::abs(b);
811 }
812 if (is_neg) div_bound *= -1;
813 return div_bound;
814 }
815 case op_kind_t::_mod: {
816 if (is_low) return 0;
817 auto max_mod = find_bound_impl(binary->b, /*is_low=*/false);
818 if (!is_good_bound(max_mod)) return def_bound;
819 return max_mod - 1;
820 }
821 case op_kind_t::_and: {
822 if (e.type().is_u16()) {
823 return is_low ? e.type().min<int64_t>()
824 : e.type().max<int64_t>();
825 }
826 break;
827 }
828 default: break;
829 }
830 }
831
832 auto *cast = e.as_ptr<cast_t>();
833 if (cast) {
834 // Saturate if needed, otherwise assume the same bounds.
835 if (!cast->is_bool_vec_u16() && !cast->saturate)
836 return find_bound_impl(cast->expr, is_low);
837
838 if (is_low) {
839 auto type_lo = cast->type.min<int64_t>();
840 auto lo = find_low_bound(cast->expr);
841 return std::max(type_lo, lo);
842 }
843 // Check u64 explicitly as its max doesn't fit into int64_t.
844 if (cast->type.is_u64()) return find_bound_impl(cast->expr, is_low);
845 auto type_hi = cast->type.max<int64_t>();
846 auto hi = find_high_bound(cast->expr);
847 return std::min(type_hi, hi);
848 }
849
850 if (e.type().is_bool()) return is_low ? 0 : 1;
851
852 return def_bound;
853}
854
855bool is_linear_var_transform(const expr_t &e, linear_transform_t &t) {
856 if (is_var(e)) {
857 t.x = e;
858 t.a = 1;
859 t.b = 0;
860 return true;
861 }
862
863 auto *binary_op = e.as_ptr<binary_op_t>();
864 if (!binary_op) return false;
865
866 auto vars = find_objects<var_t>(e);
867 if (vars.size() != 1) return false;
868
869 auto &var = vars[0];
870
871 // TODO: Extend to match multiplication: (a * var).
872 if (!utils::one_of(binary_op->op_kind, op_kind_t::_add, op_kind_t::_sub))
873 return false;
874
875 auto &a = binary_op->a;
876 auto &b = binary_op->b;
877
878 bool is_sub = (binary_op->op_kind == op_kind_t::_sub);
879
880 // var op b -> (t.a = 1, t.b = +/-b)
881 if (a.is_same(var) && is_const(b)) {
882 t.x = var;
883 t.a = 1;
884 t.b = (is_sub ? -1 : 1) * to_cpp<int>(b);
885 return true;
886 }
887
888 // a op var -> (t.a = +/-1, t.b = a)
889 if (is_const(a) && b.is_same(var)) {
890 t.x = var;
891 t.a = (is_sub ? -1 : 1);
892 t.b = to_cpp<int>(a);
893 return true;
894 }
895
896 return false;
897}
898
899void ir_context_t::add_constraint(const expr_t &e) {
900 cset_.add_constraint(e);
901}
902
903void constraint_set_t::add_constraint(const expr_t &e) {
904 auto *shuffle = e.as_ptr<shuffle_t>();
905 if (shuffle) {
906 if (shuffle->is_broadcast()) add_constraint(shuffle->vec[0]);
907 return;
908 }
909
910 if (modulus_info_t::is_modulus_constraint(e)) {
911 modulus_info_t mi(e);
912 modulus_infos_[mi.var()].push_back(mi);
913 return;
914 }
915
916 if (relation_t::is_relation_constraint(e)) {
917 relation_t rel(e);
918 relations_[rel.var()].push_back(rel);
919 return;
920 }
921
922 // Propagate constraints from y for (x == y) equalities.
923 auto *binary_op = e.as_ptr<binary_op_t>();
924 if (binary_op && binary_op->op_kind == op_kind_t::_eq) {
925 auto &a = binary_op->a;
926 auto &b = binary_op->b;
927 linear_transform_t t;
928 if (is_var(a) && is_linear_var_transform(b, t)) {
929 // Relations.
930 auto r_it = relations_.find(t.x);
931 if (r_it != relations_.end()) {
932 for (auto &c : r_it->second) {
933 add_constraint(c.transform(t, a).expr());
934 }
935 }
936 // Modulus.
937 if (t.is_identity()) {
938 auto m_it = modulus_infos_.find(t.x);
939 if (m_it != modulus_infos_.end()) {
940 for (auto &c : m_it->second) {
941 add_constraint(substitute(c.expr(), b, a));
942 }
943 }
944 }
945 return;
946 }
947 }
948}
949
950bool constraint_set_t::is_single_value(const expr_t &e, expr_t &value) const {
951 ir_assert(is_var(e)) << e;
952 auto it = relations_.find(e);
953 if (it == relations_.end()) return false;
954
955 expr_t lo;
956 expr_t hi;
957 for (auto &rel : it->second) {
958 ir_assert(is_const(rel.rhs())) << rel;
959 bool do_break = false;
960 switch (rel.op_kind()) {
961 case op_kind_t::_eq:
962 lo = hi = rel.rhs();
963 do_break = true;
964 break;
965 case op_kind_t::_ge:
966 case op_kind_t::_gt: {
967 auto cur_lo = (rel.op_kind() == op_kind_t::_ge ? rel.rhs()
968 : rel.rhs() + 1);
969 if (lo.is_empty() || to_cpp<bool>(cur_lo > lo)) { lo = cur_lo; }
970 break;
971 }
972 case op_kind_t::_le:
973 case op_kind_t::_lt: {
974 auto cur_hi = (rel.op_kind() == op_kind_t::_le ? rel.rhs()
975 : rel.rhs() - 1);
976 if (hi.is_empty() || to_cpp<bool>(cur_hi < hi)) { hi = cur_hi; }
977 break;
978 }
979 default: ir_error_not_expected() << rel;
980 }
981 if (do_break) break;
982 }
983 bool ret = !lo.is_empty() && lo.is_equal(hi);
984 if (ret) value = lo;
985 return ret;
986}
987
988bool constraint_set_t::can_prove_impl(
989 const expr_t &_e, bool do_simplify) const {
990 auto e = _e;
991 if (is_const(e)) {
992 ir_assert(e.type() == type_t::_bool()) << e;
993 return to_cpp<bool>(e);
994 }
995
996 if (do_simplify) {
997 // These passes for comparison help to prove more inequalities.
998 e = simplify_cmp_move_const_to_rhs(e);
999 e = simplify_cmp_reduce_lhs_rhs(e);
1000 e = simplify(e);
1001 if (is_const(e)) {
1002 ir_assert(e.type() == type_t::_bool()) << e;
1003 return to_cpp<bool>(e);
1004 }
1005 }
1006
1007 if (modulus_info_t::is_modulus_constraint(e)) return can_prove_modulus(e);
1008 if (relation_t::is_relation_constraint(e)) return can_prove_relation(e);
1009
1010 // Try to estimate bounds for compound relation.
1011 if (try_prove_compound_relation(e)) return true;
1012
1013 // Can't prove.
1014 return false;
1015}
1016
1017int constraint_set_t::max_proven_gcd(const expr_t &var) const {
1018 auto it = modulus_infos_.find(var);
1019 if (it == modulus_infos_.end()) return 1;
1020 int ret = 1;
1021 for (auto &c : it->second) {
1022 ret = math::lcm(ret, to_cpp<int>(c.mod()));
1023 }
1024 return ret;
1025}
1026
1027} // namespace jit
1028} // namespace gpu
1029} // namespace impl
1030} // namespace dnnl
1031