1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/ir/core.hpp"
18
19#include <algorithm>
20
21namespace dnnl {
22namespace impl {
23namespace gpu {
24namespace jit {
25
// Forward declarations of the constant-folding entry points used by the
// helpers below (their definitions are not part of this section).
expr_t const_fold_non_recursive(const expr_t &expr);
object_t const_fold(const object_t &obj);
28
29std::string to_string(type_kind_t kind) {
30#define CASE(_kind) \
31 case type_kind_t::_kind: return #_kind
32 switch (kind) {
33 CASE(undef);
34 CASE(u8);
35 CASE(s8);
36 CASE(u16);
37 CASE(s16);
38 CASE(u32);
39 CASE(s32);
40 CASE(u64);
41 CASE(s64);
42 CASE(bf16);
43 CASE(f16);
44 CASE(tf32);
45 CASE(f32);
46 CASE(f64);
47 CASE(byte);
48 CASE(dword);
49 CASE(qword);
50 CASE(oword);
51 CASE(hword);
52 case type_kind_t::_bool: return "bool";
53 default: ir_error_not_expected();
54 }
55#undef CASE
56 return {};
57}
58
59int type_t::size() const {
60 if (is_ptr()) return sizeof(uint64_t);
61
62 if (is_bool()) return utils::div_up(elems(), 8);
63
64 if (elems() != 1) return elems() * scalar().size();
65
66 switch (kind()) {
67 case type_kind_t::u8:
68 case type_kind_t::s8:
69 case type_kind_t::byte: return 1;
70 case type_kind_t::u16:
71 case type_kind_t::s16:
72 case type_kind_t::bf16:
73 case type_kind_t::f16: return 2;
74 case type_kind_t::u32:
75 case type_kind_t::s32:
76 case type_kind_t::tf32:
77 case type_kind_t::f32:
78 case type_kind_t::dword: return 4;
79 case type_kind_t::f64:
80 case type_kind_t::u64:
81 case type_kind_t::s64:
82 case type_kind_t::qword: return 8;
83 case type_kind_t::oword: return 16;
84 case type_kind_t::hword: return 32;
85 default: ir_error_not_expected();
86 }
87 return 0;
88}
89
90data_type_t to_dnnl(const type_t &type) {
91 ir_assert(type.elems() == 1) << type;
92 ir_assert(!type.is_ptr() == 1) << type;
93 switch (type.kind()) {
94 case type_kind_t::bf16: return data_type::bf16;
95 case type_kind_t::f16: return data_type::f16;
96 case type_kind_t::tf32: return data_type::tf32;
97 case type_kind_t::f32: return data_type::f32;
98 case type_kind_t::f64: return data_type::f64;
99 case type_kind_t::s32: return data_type::s32;
100 case type_kind_t::s8: return data_type::s8;
101 case type_kind_t::u8: return data_type::u8;
102 default: ir_error_not_expected();
103 }
104 return data_type::undef;
105}
106
107std::string to_string(op_kind_t kind) {
108 switch (kind) {
109 case op_kind_t::_minus: return "-";
110
111 case op_kind_t::_add: return "+";
112 case op_kind_t::_sub: return "-";
113 case op_kind_t::_mul: return "*";
114 case op_kind_t::_div: return "/";
115 case op_kind_t::_mod: return "%";
116 case op_kind_t::_shl: return "<<";
117 case op_kind_t::_shr: return ">>";
118 case op_kind_t::_min: return "min";
119 case op_kind_t::_max: return "max";
120
121 case op_kind_t::_lt: return "<";
122 case op_kind_t::_le: return "<=";
123 case op_kind_t::_gt: return ">";
124 case op_kind_t::_ge: return ">=";
125 case op_kind_t::_eq: return "==";
126 case op_kind_t::_ne: return "!=";
127
128 case op_kind_t::_and: return "&&";
129
130 case op_kind_t::_add3: return "add3";
131 case op_kind_t::_mad: return "mad";
132 case op_kind_t::_prelu: return "prelu";
133
134 case op_kind_t::_dp4a: return "dp4a";
135
136 default: ir_error_not_expected() << "Unknown op_kind_t value.";
137 }
138 return "";
139}
140
141bool is_cmp_op(op_kind_t op_kind) {
142 switch (op_kind) {
143 case op_kind_t::_ge:
144 case op_kind_t::_gt:
145 case op_kind_t::_le:
146 case op_kind_t::_lt:
147 case op_kind_t::_eq:
148 case op_kind_t::_ne: return true;
149 default: return false;
150 }
151}
152
153op_kind_t negate_cmp_op(op_kind_t op_kind) {
154 switch (op_kind) {
155 case op_kind_t::_ge: return op_kind_t::_le;
156 case op_kind_t::_gt: return op_kind_t::_lt;
157 case op_kind_t::_le: return op_kind_t::_ge;
158 case op_kind_t::_lt: return op_kind_t::_gt;
159 case op_kind_t::_eq: return op_kind_t::_eq;
160 case op_kind_t::_ne: return op_kind_t::_ne;
161 default: ir_error_not_expected();
162 }
163 return op_kind_t::undef;
164}
165
166type_t unary_op_type(op_kind_t op_kind, const expr_t &a) {
167 switch (op_kind) {
168 case op_kind_t::_minus: {
169 auto &t = a.type();
170 if (!t.is_int()) return t;
171 if (t.size() < int(sizeof(int32_t))) return type_t::s32(t.elems());
172 return t;
173 }
174 default:
175 ir_error_not_expected() << "Unknown op_kind_t value: " << op_kind;
176 }
177 return type_t::undef();
178}
179
180type_t common_int_type(const type_t &_a, const type_t &_b) {
181 ir_assert(_a.is_int() && _b.is_int()) << "Unexpected types.";
182
183 int elems = _a.elems();
184
185 // Promote to s32 first.
186 type_t a = _a.size() < int(sizeof(int32_t)) ? type_t::s32() : _a;
187 type_t b = _b.size() < int(sizeof(int32_t)) ? type_t::s32() : _b;
188 a = a.scalar();
189 b = b.scalar();
190
191 // Integer promotion, follow C++ rules.
192 int common_bits = 8 * std::max(a.size(), b.size());
193 if (a.is_signed() == b.is_signed()) {
194 if (a.is_signed()) return type_t::s(common_bits, elems);
195 return type_t::u(common_bits, elems);
196 }
197
198 if (a.size() >= b.size() && a.is_unsigned())
199 return type_t::u(common_bits, elems);
200 if (b.size() >= a.size() && b.is_unsigned())
201 return type_t::u(common_bits, elems);
202 if (a.size() > b.size() && a.is_signed())
203 return type_t::s(common_bits, elems);
204 if (b.size() > a.size() && b.is_signed())
205 return type_t::s(common_bits, elems);
206
207 return type_t::u(common_bits, elems);
208}
209
210type_t common_type(const type_t &a, const type_t &b) {
211 ir_assert(a.elems() == b.elems())
212 << "Types must have the same number of components.";
213 if (a.is_undef() || b.is_undef()) return type_t::undef();
214 if (a.is_fp() && !b.is_fp()) return a;
215 if (!a.is_fp() && b.is_fp()) return b;
216 if (a.is_fp() && b.is_fp()) return (a.size() > b.size() ? a : b);
217 if (a.is_bool() && b.is_bool()) return a;
218 return common_int_type(a, b);
219}
220
// Convenience overload: returns the common type of the types of two
// expressions.
type_t common_type(const expr_t &a, const expr_t &b) {
    return common_type(a.type(), b.type());
}
224
225type_t binary_op_type(op_kind_t op_kind, const type_t &a, const type_t &b,
226 const expr_t &a_expr = expr_t(), const expr_t &b_expr = expr_t()) {
227 if (a.is_undef() || b.is_undef()) return type_t::undef();
228 ir_assert(a.elems() == b.elems())
229 << "Types must have the same number of components.";
230 if (is_cmp_op(op_kind)) return type_t::_bool(a.elems());
231 if (utils::one_of(op_kind, op_kind_t::_shl, op_kind_t::_shr)) {
232 ir_assert(a.is_unsigned())
233 << "a must be unsigned for shift left/right.";
234 return type_t::u32(a.elems());
235 }
236 if (op_kind == op_kind_t::_and) {
237 if (a == b) return a;
238 if (is_const(a_expr)) return b;
239 if (is_const(b_expr)) return a;
240 return (a.size() >= b.size()) ? a : b;
241 }
242 return common_type(a, b);
243}
244
// Convenience overload: derives the result type of a binary operation from
// the operand expressions, forwarding both their types and the expressions
// themselves.
type_t binary_op_type(op_kind_t op_kind, const expr_t &a, const expr_t &b) {
    return binary_op_type(op_kind, a.type(), b.type(), a, b);
}
248
249type_t ternary_op_type(
250 op_kind_t op_kind, const expr_t &a, const expr_t &b, const expr_t &c) {
251 switch (op_kind) {
252 case op_kind_t::_add3:
253 return binary_op_type(op_kind_t::_add, a.type(),
254 binary_op_type(op_kind_t::_add, b, c));
255 case op_kind_t::_mad:
256 return binary_op_type(op_kind_t::_add, a.type(),
257 binary_op_type(op_kind_t::_mul, b, c));
258 default: ir_error_not_expected();
259 }
260 return type_t::undef();
261}
262
263type_t nary_op_type(op_kind_t op_kind, const std::vector<expr_t> &args) {
264 ir_assert(!args.empty());
265 if (args.size() == 1) return args[0].type();
266
267 auto type = args[0].type();
268 for (size_t i = 1; i < args.size(); i++)
269 type = common_type(type, args[i].type());
270
271 return type;
272}
273
274void ptr_t::normalize(expr_t &base, expr_t &off, op_kind_t op_kind) {
275 ir_assert(base.type().is_ptr()) << "base is not a pointer: " << base;
276 ir_assert(off.type().is_int()) << "off is not an integer: " << off;
277 ir_assert(utils::one_of(op_kind, op_kind_t::_add, op_kind_t::_sub))
278 << "Can't apply this operation to pointer: " << to_string(op_kind);
279
280 if (!base.is<ptr_t>()) {
281 if (op_kind == op_kind_t::_sub) off = const_fold(-off);
282 return;
283 }
284
285 auto &base_off = base.as<ptr_t>().off;
286 base = base.as<ptr_t>().base;
287 off = const_fold_non_recursive(binary_op_t::make(op_kind, base_off, off));
288}
289
// Builds the pointer expression `a op_kind b`. The operands are copied,
// normalized with ptr_t::normalize() (which validates them and unwraps a
// nested ptr_t base), and wrapped into a new ptr_t node.
expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b) {
    // Work on copies: normalize() modifies its arguments in place.
    expr_t base = a;
    expr_t off = b;
    ptr_t::normalize(base, off, op_kind);
    return ptr_t::make(base, off);
}
296
297void normalize_ptr(const type_t &type, expr_t &base_expr, expr_t &off) {
298 if (base_expr.is<ptr_t>()) {
299 auto &base = base_expr.as<ptr_t>().base;
300 auto &base_off = base_expr.as<ptr_t>().off;
301
302 base_expr = base;
303 off = const_fold_non_recursive(base_off + off);
304 }
305 ir_assert(to_cpp<int64_t>(off) % type.scalar().size() == 0)
306 << "Incompatible offset: " << off;
307}
308
309expr_t expr_t::operator[](const expr_t &off) const {
310 if (is<shuffle_t>()) {
311 ir_assert(is_const(off)) << "Offset is not constant.";
312 auto &shuffle = as<shuffle_t>();
313 int idx = shuffle.idx[to_cpp<int>(off)];
314 return shuffle.vec[idx];
315 }
316 return shift_ptr(op_kind_t::_add, *this, off);
317}
318
// Converting constructors: wrap a C++ scalar into a newly allocated
// immediate node.
// bool -> bool_imm_t.
expr_t::expr_t(bool value) : object_t(new bool_imm_t(value)) {}
// float -> float_imm_t constructed from the value alone.
expr_t::expr_t(float value) : object_t(new float_imm_t(value)) {}
// double -> float_imm_t explicitly tagged with the f64 type.
expr_t::expr_t(double value)
    : object_t(new float_imm_t(value, type_t::f64())) {}
// Integers of every supported width/signedness -> int_imm_t constructed from
// the value alone.
expr_t::expr_t(int16_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(int32_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(int64_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(uint16_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(uint32_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(uint64_t value) : object_t(new int_imm_t(value)) {}
329
// Unary minus: builds a `_minus` unary_op_t node and passes it through the
// non-recursive constant folder.
expr_t operator-(const expr_t &a) {
    return const_fold_non_recursive(unary_op_t::make(op_kind_t::_minus, a));
}
333
// Defines a free binary operator on expressions. When the left operand has a
// pointer type the operation is routed to shift_ptr() (which only accepts
// `_add`/`_sub`); otherwise a binary_op_t node is built and passed through
// the non-recursive constant folder.
#define DEFINE_BINARY_OPERATOR(op, op_kind) \
    expr_t operator op(const expr_t &a, const expr_t &b) { \
        if (a.type().is_ptr()) return shift_ptr(op_kind, a, b); \
        return const_fold_non_recursive(binary_op_t::make(op_kind, a, b)); \
    }

// Arithmetic and shift operators.
DEFINE_BINARY_OPERATOR(+, op_kind_t::_add)
DEFINE_BINARY_OPERATOR(-, op_kind_t::_sub)
DEFINE_BINARY_OPERATOR(*, op_kind_t::_mul)
DEFINE_BINARY_OPERATOR(/, op_kind_t::_div)
DEFINE_BINARY_OPERATOR(%, op_kind_t::_mod)
DEFINE_BINARY_OPERATOR(<<, op_kind_t::_shl)
DEFINE_BINARY_OPERATOR(>>, op_kind_t::_shr)

// Comparison operators.
DEFINE_BINARY_OPERATOR(==, op_kind_t::_eq)
DEFINE_BINARY_OPERATOR(!=, op_kind_t::_ne)
DEFINE_BINARY_OPERATOR(>, op_kind_t::_gt)
DEFINE_BINARY_OPERATOR(>=, op_kind_t::_ge)
DEFINE_BINARY_OPERATOR(<, op_kind_t::_lt)
DEFINE_BINARY_OPERATOR(<=, op_kind_t::_le)

// `&` maps to the `_and` operation.
DEFINE_BINARY_OPERATOR(&, op_kind_t::_and)

#undef DEFINE_BINARY_OPERATOR
358
// Defines a compound assignment operator (`op=`) in terms of the matching
// binary operator: evaluates `*this op rhs`, stores the result back into
// `*this` and returns a reference to it.
#define DEFINE_BINARY_ASSIGN_OPERATOR(op) \
    expr_t &expr_t::operator op##=(const expr_t &rhs) { \
        auto tmp = (*this)op rhs; \
        *this = tmp; \
        return *this; \
    }

DEFINE_BINARY_ASSIGN_OPERATOR(+)
DEFINE_BINARY_ASSIGN_OPERATOR(-)
DEFINE_BINARY_ASSIGN_OPERATOR(*)
DEFINE_BINARY_ASSIGN_OPERATOR(/)
DEFINE_BINARY_ASSIGN_OPERATOR(%)
DEFINE_BINARY_ASSIGN_OPERATOR(&)

#undef DEFINE_BINARY_ASSIGN_OPERATOR
374
// Base implementation: ignores the mutator and returns this object
// unchanged.
object_t object_impl_t::_mutate(ir_mutator_t &mutator) const {
    return *this;
}
// Base implementation: ignores the visitor and does nothing.
void object_impl_t::_visit(ir_visitor_t &visitor) const {}
379
// Defines the traversal hooks for a leaf node type: the mutator returns the
// object unchanged and the visitor does nothing.
#define DECL_TRAVERSE_LEAF(name) \
    object_t ir_mutator_t::_mutate(const name &obj) { return obj; } \
    void ir_visitor_t::_visit(const name &obj) {}

DECL_TRAVERSE_LEAF(bool_imm_t)
DECL_TRAVERSE_LEAF(float_imm_t)
DECL_TRAVERSE_LEAF(func_impl_t)
DECL_TRAVERSE_LEAF(int_imm_t)
DECL_TRAVERSE_LEAF(var_t)

#undef DECL_TRAVERSE_LEAF
391
392object_t ir_mutator_t::_mutate(const alloc_t &obj) {
393 auto buf = mutate(obj.buf);
394 auto body = mutate(obj.body);
395
396 if (buf.is_same(obj.buf) && body.is_same(obj.body)) return obj;
397
398 return alloc_t::make(buf, obj.size, obj.kind, obj.attrs, body);
399}
400
// Visits the children of an alloc_t node: the buffer, then the body.
void ir_visitor_t::_visit(const alloc_t &obj) {
    visit(obj.buf);
    visit(obj.body);
}
405
406object_t ir_mutator_t::_mutate(const binary_op_t &obj) {
407 auto a = mutate(obj.a);
408 auto b = mutate(obj.b);
409
410 if (a.is_same(obj.a) && b.is_same(obj.b)) return obj;
411
412 return binary_op_t::make(obj.op_kind, a, b);
413}
414
// Visits both operands of a binary_op_t node: `a`, then `b`.
void ir_visitor_t::_visit(const binary_op_t &obj) {
    visit(obj.a);
    visit(obj.b);
}
419
420object_t ir_mutator_t::_mutate(const cast_t &obj) {
421 auto expr = mutate(obj.expr);
422
423 if (expr.is_same(obj.expr)) return obj;
424
425 return cast_t::make(obj.type, expr, obj.saturate);
426}
427
// Visits the operand expression of a cast_t node.
void ir_visitor_t::_visit(const cast_t &obj) {
    visit(obj.expr);
}
431
432object_t ir_mutator_t::_mutate(const for_t &obj) {
433 auto var = mutate(obj.var);
434 auto init = mutate(obj.init);
435 auto bound = mutate(obj.bound);
436 auto body = mutate(obj.body);
437
438 if (var.is_same(obj.var) && init.is_same(obj.init)
439 && bound.is_same(obj.bound) && body.is_same(obj.body))
440 return obj;
441
442 return for_t::make(var, init, bound, body, obj.unroll);
443}
444
// Visits the children of a for_t node in order: loop variable, initial
// value, bound, body.
void ir_visitor_t::_visit(const for_t &obj) {
    visit(obj.var);
    visit(obj.init);
    visit(obj.bound);
    visit(obj.body);
}
451
452object_t ir_mutator_t::_mutate(const func_call_t &obj) {
453 auto func = mutate(obj.func);
454 auto args = mutate(obj.args);
455
456 if (func.is_same(obj.func) && ir_utils::is_same(args, obj.args)) return obj;
457
458 return func_call_t::make(func, args, obj.attr);
459}
460
// Visits the callee of a func_call_t node, then its arguments.
void ir_visitor_t::_visit(const func_call_t &obj) {
    visit(obj.func);
    visit(obj.args);
}
465
466object_t ir_mutator_t::_mutate(const if_t &obj) {
467 auto cond = mutate(obj.cond);
468 auto body = mutate(obj.body);
469 auto else_body = mutate(obj.else_body);
470
471 if (cond.is_same(obj.cond) && body.is_same(obj.body)
472 && else_body.is_same(obj.else_body))
473 return obj;
474
475 return if_t::make(cond, body, else_body);
476}
477
// Visits the children of an if_t node in order: condition, body, else body.
void ir_visitor_t::_visit(const if_t &obj) {
    visit(obj.cond);
    visit(obj.body);
    visit(obj.else_body);
}
483
484object_t ir_mutator_t::_mutate(const iif_t &obj) {
485 auto cond = mutate(obj.cond);
486 auto true_expr = mutate(obj.true_expr);
487 auto false_expr = mutate(obj.false_expr);
488
489 if (cond.is_same(obj.cond) && true_expr.is_same(obj.true_expr)
490 && false_expr.is_same(obj.false_expr))
491 return obj;
492
493 return iif_t::make(cond, true_expr, false_expr);
494}
495
// Visits the children of an iif_t node in order: condition, true
// expression, false expression.
void ir_visitor_t::_visit(const iif_t &obj) {
    visit(obj.cond);
    visit(obj.true_expr);
    visit(obj.false_expr);
}
501
502object_t ir_mutator_t::_mutate(const let_t &obj) {
503 auto var = mutate(obj.var);
504 auto value = mutate(obj.value);
505 auto body = mutate(obj.body);
506
507 if (var.is_same(obj.var) && value.is_same(obj.value)
508 && body.is_same(obj.body))
509 return obj;
510
511 return let_t::make(var, value, body);
512}
513
// Visits the children of a let_t node in order: variable, value, body.
void ir_visitor_t::_visit(const let_t &obj) {
    visit(obj.var);
    visit(obj.value);
    visit(obj.body);
}
519
520object_t ir_mutator_t::_mutate(const load_t &obj) {
521 auto buf = mutate(obj.buf);
522 auto off = mutate(obj.off);
523
524 if (buf.is_same(obj.buf) && off.is_same(obj.off)) return obj;
525
526 return load_t::make(obj.type, buf, off, obj.stride);
527}
528
// Visits the buffer of a load_t node, then its offset.
void ir_visitor_t::_visit(const load_t &obj) {
    visit(obj.buf);
    visit(obj.off);
}
533
534object_t ir_mutator_t::_mutate(const ptr_t &obj) {
535 auto base = mutate(obj.base);
536 auto off = mutate(obj.off);
537
538 if (base.is_same(obj.base) && off.is_same(obj.off)) return obj;
539
540 return ptr_t::make(base, off);
541}
542
// Visits the base of a ptr_t node, then its offset.
void ir_visitor_t::_visit(const ptr_t &obj) {
    visit(obj.base);
    visit(obj.off);
}
547
548object_t ir_mutator_t::_mutate(const shuffle_t &obj) {
549 auto vec = mutate(obj.vec);
550
551 if (ir_utils::is_same(vec, obj.vec)) return obj;
552
553 return shuffle_t::make(vec, obj.idx);
554}
555
// Visits the vector elements of a shuffle_t node.
void ir_visitor_t::_visit(const shuffle_t &obj) {
    visit(obj.vec);
}
559
560object_t ir_mutator_t::_mutate(const stmt_group_t &obj) {
561 auto body = mutate(obj.body);
562
563 if (body.is_same(obj.body)) return obj;
564
565 return stmt_group_t::make(obj.label, body);
566}
567
// Visits the body of a stmt_group_t node.
void ir_visitor_t::_visit(const stmt_group_t &obj) {
    visit(obj.body);
}
571
572object_t ir_mutator_t::_mutate(const stmt_seq_t &obj) {
573 auto head = mutate(obj.head);
574 auto tail = mutate(obj.tail);
575
576 if (head.is_same(obj.head) && tail.is_same(obj.tail)) return obj;
577
578 return stmt_seq_t::make(head, tail);
579}
580
// Visits the head of a stmt_seq_t node, then its tail.
void ir_visitor_t::_visit(const stmt_seq_t &obj) {
    visit(obj.head);
    visit(obj.tail);
}
585
586object_t ir_mutator_t::_mutate(const store_t &obj) {
587 auto buf = mutate(obj.buf);
588 auto off = mutate(obj.off);
589 auto value = mutate(obj.value);
590 auto mask = mutate(obj.mask);
591
592 if (buf.is_same(obj.buf) && off.is_same(obj.off) && value.is_same(obj.value)
593 && mask.is_same(obj.mask))
594 return obj;
595
596 return store_t::make(buf, off, value, obj.stride, mask, obj.fill_mask0);
597}
598
// Visits the children of a store_t node in order: buffer, offset, value,
// mask.
void ir_visitor_t::_visit(const store_t &obj) {
    visit(obj.buf);
    visit(obj.off);
    visit(obj.value);
    visit(obj.mask);
}
605
606object_t ir_mutator_t::_mutate(const ternary_op_t &obj) {
607 auto a = mutate(obj.a);
608 auto b = mutate(obj.b);
609 auto c = mutate(obj.c);
610
611 if (a.is_same(obj.a) && b.is_same(obj.b) && c.is_same(obj.c)) return obj;
612
613 return ternary_op_t::make(obj.op_kind, a, b, c);
614}
615
// Visits the three operands of a ternary_op_t node in order: a, b, c.
void ir_visitor_t::_visit(const ternary_op_t &obj) {
    visit(obj.a);
    visit(obj.b);
    visit(obj.c);
}
621
622object_t ir_mutator_t::_mutate(const unary_op_t &obj) {
623 auto a = mutate(obj.a);
624 if (a.is_same(obj.a)) return obj;
625 return unary_op_t::make(obj.op_kind, a);
626}
627
// Visits the operand of a unary_op_t node.
void ir_visitor_t::_visit(const unary_op_t &obj) {
    visit(obj.a);
}
631
632// Catch missing mutates that are not expected to dispatch to the base
633// mutator
// The base mutator does not handle nary_op_t: reaching this overload reports
// an error and returns an empty object.
object_t ir_mutator_t::_mutate(const nary_op_t &obj) {
    ir_error_not_expected() << "Can't handle type: nary_op_t";
    return {};
}
// The base visitor does not handle nary_op_t: reaching this overload reports
// an error.
void ir_visitor_t::_visit(const nary_op_t &obj) {
    ir_error_not_expected() << "Can't handle type: nary_op_t";
}
// The base mutator does not handle pexpr_t: reaching this overload reports
// an error and returns an empty object.
object_t ir_mutator_t::_mutate(const pexpr_t &obj) {
    ir_error_not_expected() << "Can't handle type: pexpr_t";
    return {};
}
// The base visitor does not handle pexpr_t: reaching this overload reports
// an error.
void ir_visitor_t::_visit(const pexpr_t &obj) {
    ir_error_not_expected() << "Can't handle type: pexpr_t";
}
648
649} // namespace jit
650} // namespace gpu
651} // namespace impl
652} // namespace dnnl
653