1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_IR_CORE_HPP
18#define GPU_JIT_IR_CORE_HPP
19
#include <algorithm>
#include <atomic>
#include <cstdio>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
26
27#include "common/c_types_map.hpp"
28#include "common/math_utils.hpp"
29#include "gpu/jit/utils/ngen_proxy.hpp"
30#include "gpu/jit/utils/utils.hpp"
31
// SANITY_CHECK enables the extra sanity_check() validation performed by
// object_t and its wrappers. It is defined in debug builds (NDEBUG not
// defined) and whenever GEN_CONV_DEBUG is defined.
#if !defined(NDEBUG) || defined(GEN_CONV_DEBUG)
#define SANITY_CHECK 1
#endif

// X-macro lists of IR object classes. Each list expands
// HANDLE_IR_OBJECT(type) once per class; users define HANDLE_IR_OBJECT right
// before expanding a list and #undef it afterwards (see ir_type_id_t, the
// forward declarations and the ir_mutator_t/ir_visitor_t handlers).

// All IR expression objects.
#define HANDLE_EXPR_IR_OBJECTS() \
    HANDLE_IR_OBJECT(binary_op_t) \
    HANDLE_IR_OBJECT(bool_imm_t) \
    HANDLE_IR_OBJECT(cast_t) \
    HANDLE_IR_OBJECT(float_imm_t) \
    HANDLE_IR_OBJECT(iif_t) \
    HANDLE_IR_OBJECT(int_imm_t) \
    HANDLE_IR_OBJECT(load_t) \
    HANDLE_IR_OBJECT(ptr_t) \
    HANDLE_IR_OBJECT(shuffle_t) \
    HANDLE_IR_OBJECT(ternary_op_t) \
    HANDLE_IR_OBJECT(unary_op_t) \
    HANDLE_IR_OBJECT(var_t)

// All IR statement objects.
#define HANDLE_STMT_IR_OBJECTS() \
    HANDLE_IR_OBJECT(alloc_t) \
    HANDLE_IR_OBJECT(for_t) \
    HANDLE_IR_OBJECT(func_call_t) \
    HANDLE_IR_OBJECT(if_t) \
    HANDLE_IR_OBJECT(let_t) \
    HANDLE_IR_OBJECT(stmt_group_t) \
    HANDLE_IR_OBJECT(stmt_seq_t) \
    HANDLE_IR_OBJECT(store_t)

// Every class that ir_mutator_t and ir_visitor_t declare a handler for:
// all expressions and statements plus func_impl_t, nary_op_t and pexpr_t.
#define HANDLE_TRAVERSE_TARGETS() \
    HANDLE_EXPR_IR_OBJECTS() \
    HANDLE_STMT_IR_OBJECTS() \
    HANDLE_IR_OBJECT(func_impl_t) \
    HANDLE_IR_OBJECT(nary_op_t) \
    HANDLE_IR_OBJECT(pexpr_t)

// Classes that get the leading ("visitable") ir_type_id_t values:
// all expressions and statements plus func_impl_t.
#define HANDLE_ALL_IR_OBJECTS() \
    HANDLE_EXPR_IR_OBJECTS() \
    HANDLE_STMT_IR_OBJECTS() \
    HANDLE_IR_OBJECT(func_impl_t)
73
// Type IDs of all IR objects. Note this enum is declared at global scope,
// outside the dnnl namespaces, and its enumerators share their names with the
// IR classes; that is why the IR_DECL_* macros refer to them as
// ir_type_id_t::class_name.
enum ir_type_id_t : uint8_t {
#define HANDLE_IR_OBJECT(type) type,

    // Create typeid for objects which can be visited/mutated. These need to be
    // first as the typeid is used as an index into an array to dispatch to the
    // correct mutate function.
    HANDLE_ALL_IR_OBJECTS()

    // Used to calculate number of IR objects that can be visited/mutated
    // (the values above start at 0, so this equals their count).
    end_visitable_ir_objects,

    // Other IR objects. Their IDs continue after the visitable range;
    // expr_impl_t shares its value with end_visitable_ir_objects.
    expr_impl_t = end_visitable_ir_objects,
    nary_op_t,
    stmt_impl_t,
    grf_permute_attr_t,
    bank_conflict_attr_t,
    instruction_modifier_attr_t,
    builtin_t,
    pexpr_t,
    pint_imm_t,
    factored_expr_t,
    send_t,
    dpas_t,
    mad_t,
    reduce_t,
    reorder_t,
    eltwise_t,

#undef HANDLE_IR_OBJECT
};
105
106struct type_info_t {
107 type_info_t(ir_type_id_t type_id, bool is_expr, bool is_stmt)
108 : type_id(type_id), is_expr(is_expr), is_stmt(is_stmt) {};
109 ir_type_id_t type_id;
110 bool is_expr;
111 bool is_stmt;
112};
113
// Auxiliary macros to reduce boilerplate.

// Declares the static type-identification interface of an IR class:
// self_type, _type_id(), _dispatch_type_id() (equal to _type_id()) and
// _type_info(). _is_expr()/_is_stmt() are looked up in the class, so the
// class either declares them or inherits object_impl_t's versions (which
// return false).
#define IR_DECL_TYPE_ID(class_name) \
    using self_type = class_name; \
    static ir_type_id_t _type_id() { return ir_type_id_t::class_name; } \
    static ir_type_id_t _dispatch_type_id() { return _type_id(); } \
    static type_info_t _type_info() { \
        return type_info_t(_type_id(), _is_expr(), _is_stmt()); \
    }

// Like IR_DECL_TYPE_ID for a class derived from the IR class base_name: the
// class keeps its own _type_id() but overrides dispatch_type_id() so that
// visitors and mutators treat it as base_name.
#define IR_DECL_DERIVED_TYPE_ID(class_name, base_name) \
    using self_type = class_name; \
    static ir_type_id_t _type_id() { return ir_type_id_t::class_name; } \
    static ir_type_id_t _dispatch_type_id() { return base_name::_type_id(); } \
    ir_type_id_t dispatch_type_id() const override { \
        return _dispatch_type_id(); \
    } \
    static type_info_t _type_info() { \
        return type_info_t(_type_id(), _is_expr(), _is_stmt()); \
    }

// IR_DECL_TYPE_ID for expression classes; also makes _is_expr() true.
#define IR_DECL_EXPR_TYPE_ID(class_name) \
    IR_DECL_TYPE_ID(class_name) \
    static bool _is_expr() { return true; };

// IR_DECL_TYPE_ID for statement classes; also makes _is_stmt() true.
#define IR_DECL_STMT_TYPE_ID(class_name) \
    IR_DECL_TYPE_ID(class_name) \
    static bool _is_stmt() { return true; };

// Overrides object_impl_t::_mutate() to call the mutator's overload for the
// concrete class (double dispatch).
#define IR_DECL_MUTATE(mutator_template) \
    object_t _mutate(mutator_template &mutator) const override { \
        return mutator._mutate(*this); \
    }
// Overrides object_impl_t::_visit() to call the visitor's overload for the
// concrete class (double dispatch).
#define IR_DECL_VISIT(visitor_template) \
    void _visit(visitor_template &visitor) const override { \
        visitor._visit(*this); \
    }

// Declares both traversal hooks for ir_mutator_t and ir_visitor_t.
#define IR_DECLARE_TRAVERSERS() \
    IR_DECL_MUTATE(ir_mutator_t) \
    IR_DECL_VISIT(ir_visitor_t)

// Defines getter for a function argument.
// arg_<name>(s) asserts that s is a func_call_t whose func is of self_type
// and returns args[index]; the vector overloads index an argument list
// directly (no checks).
#define IR_DEFINE_ARG_GET(name, index) \
    static const expr_t &arg_##name(const stmt_t &s) { \
        ir_assert(s.is<func_call_t>()) << s; \
        auto &c = s.as<func_call_t>(); \
        ir_assert(c.func.is<self_type>()) << s; \
        return c.args[index]; \
    } \
    template <typename T> \
    static T &arg_##name(std::vector<T> &args) { \
        return args[index]; \
    } \
    template <typename T> \
    static const T &arg_##name(const std::vector<T> &args) { \
        return args[index]; \
    }

#if defined(__GNUC__)
// clang-format off
// Defines dump() method for debugging purposes, to pretty print the object.
// noinline/used keep an out-of-line copy of dump() in the binary (presumably
// so it can be called from a debugger). On non-GNU compilers the macro
// expands to nothing.
#define IR_DEFINE_DUMP() \
    __attribute__((noinline)) \
    __attribute__((used)) \
    void dump() const { \
        printf("%s\n", str().c_str()); \
    }
// clang-format on
#else
#define IR_DEFINE_DUMP()
#endif
185
186namespace dnnl {
187namespace impl {
188namespace gpu {
189namespace jit {
190
// Kind of an IR data type. type_t combines a kind with a vector length and a
// pointer flag.
enum class type_kind_t {
    undef,
    _bool,

    // Integer types.
    u8,
    s8,
    u16,
    s16,
    u32,
    s32,
    u64,
    s64,

    // Floating point types.
    bf16,
    f16,
    tf32,
    f32,
    f64,

    // Message data types.
    byte,
    dword,
    qword,
    oword,
    hword
};

// Returns the textual name of a kind (used by type_t::str()).
std::string to_string(type_kind_t kind);
221
222class type_t {
223public:
224 static type_t undef() { return type_t(type_kind_t::undef); }
225 static type_t _bool(int elems = 1) {
226 return type_t(type_kind_t::_bool, elems);
227 }
228
229 static type_t u8(int elems = 1) { return type_t(type_kind_t::u8, elems); }
230 static type_t s8(int elems = 1) { return type_t(type_kind_t::s8, elems); }
231 static type_t u16(int elems = 1) { return type_t(type_kind_t::u16, elems); }
232 static type_t s16(int elems = 1) { return type_t(type_kind_t::s16, elems); }
233 static type_t u32(int elems = 1) { return type_t(type_kind_t::u32, elems); }
234 static type_t s32(int elems = 1) { return type_t(type_kind_t::s32, elems); }
235 static type_t u64(int elems = 1) { return type_t(type_kind_t::u64, elems); }
236 static type_t s64(int elems = 1) { return type_t(type_kind_t::s64, elems); }
237
238 // Returns unsigned integer type.
239 static type_t u(int bits, int elems = 1) {
240 switch (bits) {
241 case 8: return u8(elems);
242 case 16: return u16(elems);
243 case 32: return u32(elems);
244 case 64: return u64(elems);
245 default: ir_error_not_expected();
246 }
247 return type_t::undef();
248 }
249
250 // Returns signed integer type.
251 static type_t s(int bits, int elems = 1) {
252 switch (bits) {
253 case 8: return s8(elems);
254 case 16: return s16(elems);
255 case 32: return s32(elems);
256 case 64: return s64(elems);
257 default: ir_error_not_expected();
258 }
259 return type_t::undef();
260 }
261
262 static type_t bf16(int elems = 1) {
263 return type_t(type_kind_t::bf16, elems);
264 }
265 static type_t f16(int elems = 1) { return type_t(type_kind_t::f16, elems); }
266 static type_t tf32(int elems = 1) {
267 return type_t(type_kind_t::tf32, elems);
268 }
269 static type_t f32(int elems = 1) { return type_t(type_kind_t::f32, elems); }
270 static type_t f64(int elems = 1) { return type_t(type_kind_t::f64, elems); }
271
272 static type_t byte(int elems = 1) {
273 return type_t(type_kind_t::byte, elems);
274 }
275 static type_t byte_ptr(int elems = 1) {
276 return type_t(type_kind_t::byte, elems).with_ptr();
277 }
278 static type_t dword(int elems = 1) {
279 return type_t(type_kind_t::dword, elems);
280 }
281 static type_t qword(int elems = 1) {
282 return type_t(type_kind_t::qword, elems);
283 }
284 static type_t oword(int elems = 1) {
285 return type_t(type_kind_t::oword, elems);
286 }
287 static type_t hword(int elems = 1) {
288 return type_t(type_kind_t::hword, elems);
289 }
290
291 template <typename T>
292 static type_t from_cpp() {
293#define CASE(cpp_type, type) \
294 if (std::is_same<T, cpp_type>::value) return type()
295
296 CASE(bool, _bool);
297 CASE(float, f32);
298 CASE(double, f64);
299 CASE(int16_t, s16);
300 CASE(int32_t, s32);
301 CASE(int64_t, s64);
302 CASE(uint16_t, u16);
303 CASE(uint32_t, u32);
304 CASE(uint64_t, u64);
305
306#undef CASE
307
308 ir_error_not_expected();
309
310 return undef();
311 }
312
313 template <typename T>
314 T max() const {
315 switch (kind()) {
316 case type_kind_t::u8:
317 case type_kind_t::s8:
318 case type_kind_t::u16:
319 case type_kind_t::s16:
320 case type_kind_t::u32:
321 case type_kind_t::s32:
322 case type_kind_t::u64:
323 case type_kind_t::s64: {
324 int bits = 8 * size();
325 if (is_signed()) bits--;
326 T ret = T(1) << (bits - 1);
327 return ret + (ret - 1);
328 }
329 default: ir_error_not_expected();
330 }
331 return 0;
332 }
333
334 template <typename T>
335 T min() const {
336 switch (kind()) {
337 case type_kind_t::u8:
338 case type_kind_t::s8:
339 case type_kind_t::u16:
340 case type_kind_t::s16:
341 case type_kind_t::u32:
342 case type_kind_t::s32:
343 case type_kind_t::u64:
344 case type_kind_t::s64: {
345 if (is_unsigned()) return 0;
346 return -max<T>() - 1;
347 }
348 default: ir_error_not_expected();
349 }
350 return 0;
351 }
352
353 static bool is_vector(int elems) { return elems != 1; }
354
355 type_t() : type_t(type_t::undef()) {}
356
357 type_t(type_kind_t kind, uint32_t elems = 1) : kind_(kind), elems_(elems) {}
358
359 // Constructor from dnnl_data_type_t.
360 type_t(data_type_t dt) {
361 elems_ = 1;
362 switch ((int)dt) {
363#define CASE(x) \
364 case data_type::x: kind_ = type_kind_t::x; break;
365 CASE(bf16);
366 CASE(f16);
367 CASE(tf32);
368 CASE(f32);
369 CASE(f64);
370 CASE(s32);
371 CASE(s8);
372 CASE(u8);
373#undef CASE
374 default: ir_error_not_expected();
375 }
376 }
377
378 type_kind_t kind() const { return kind_; }
379
380 int elems() const { return elems_; }
381
382 bool is_ptr() const { return is_ptr_; }
383
384 bool operator==(const type_t &other) const {
385 return (kind() == other.kind()) && (elems() == other.elems())
386 && (is_ptr() == other.is_ptr());
387 }
388
389 bool operator!=(const type_t &other) const { return !operator==(other); }
390
391 bool is_equal(const type_t &other) const { return operator==(other); }
392
393 size_t get_hash() const {
394 return ir_utils::get_hash(kind(), elems(), is_ptr());
395 }
396
397 bool is_undef() const { return kind() == type_kind_t::undef; }
398
399 bool is_vector() const { return type_t::is_vector(elems()); }
400
401 bool is_bool() const { return kind() == type_kind_t::_bool; }
402
403 bool is_fp() const {
404 return utils::one_of(kind(), type_kind_t::bf16, type_kind_t::f16,
405 type_kind_t::tf32, type_kind_t::f32, type_kind_t::f64);
406 }
407
408 bool is_bf16() const { return kind() == type_kind_t::bf16; }
409 bool is_f16() const { return kind() == type_kind_t::f16; }
410 bool is_tf32() const { return kind() == type_kind_t::tf32; }
411 bool is_f32() const { return kind() == type_kind_t::f32; }
412 bool is_f64() const { return kind() == type_kind_t::f64; }
413
414 bool is_int() const {
415 return utils::one_of(kind(), type_kind_t::u8, type_kind_t::s8,
416 type_kind_t::u16, type_kind_t::s16, type_kind_t::u32,
417 type_kind_t::s32, type_kind_t::u64, type_kind_t::s64);
418 }
419
420 bool is_s8() const { return kind() == type_kind_t::s8; }
421 bool is_u8() const { return kind() == type_kind_t::u8; }
422 bool is_x8() const {
423 return utils::one_of(kind(), type_kind_t::s8, type_kind_t::u8);
424 }
425
426 bool is_s16() const { return kind() == type_kind_t::s16; }
427 bool is_u16() const { return kind() == type_kind_t::u16; }
428 bool is_x16() const {
429 return utils::one_of(kind(), type_kind_t::s16, type_kind_t::u16);
430 }
431
432 bool is_s32() const { return kind() == type_kind_t::s32; }
433 bool is_u32() const { return kind() == type_kind_t::u32; }
434 bool is_x32() const {
435 return utils::one_of(kind(), type_kind_t::s32, type_kind_t::u32);
436 }
437
438 bool is_s64() const { return kind() == type_kind_t::s64; }
439 bool is_u64() const { return kind() == type_kind_t::u64; }
440 bool is_x64() const {
441 return utils::one_of(kind(), type_kind_t::s64, type_kind_t::u64);
442 }
443
444 bool is_byte() const { return kind() == type_kind_t::byte; }
445 bool is_dword() const { return kind() == type_kind_t::dword; }
446 bool is_qword() const { return kind() == type_kind_t::qword; }
447 bool is_oword() const { return kind() == type_kind_t::oword; }
448 bool is_hword() const { return kind() == type_kind_t::hword; }
449
450 bool is_signed(int elems = -1) const {
451 if (elems != -1 && elems_ != elems) return false;
452 return utils::one_of(kind(), type_kind_t::s8, type_kind_t::s16,
453 type_kind_t::s32, type_kind_t::s64);
454 }
455
456 bool is_unsigned(int elems = -1) const {
457 if (elems != -1 && elems_ != elems) return false;
458 return utils::one_of(kind(), type_kind_t::u8, type_kind_t::u16,
459 type_kind_t::u32, type_kind_t::u64);
460 }
461
462 bool is_scalar() const { return elems() == 1; }
463
464 template <typename T>
465 bool is_cpp() const {
466 return *this == type_t::from_cpp<T>();
467 }
468
469 bool is_bitwise_compatible(const type_t &other) const {
470 if (*this == other) return true;
471
472 // tf32 is bitwise compatible with f32.
473 if (kind() == type_kind_t::f32 && other.kind() == type_kind_t::tf32)
474 return elems() == other.elems();
475
476 return false;
477 }
478
479 type_t remove_elems() const { return with_elems(1); }
480
481 type_t remove_ptr() const {
482 type_t copy = *this;
483 copy.is_ptr_ = false;
484 return copy;
485 }
486
487 type_t with_elems(int new_elems) const {
488 type_t copy = *this;
489 copy.elems_ = new_elems;
490 return copy;
491 }
492
493 type_t with_ptr() const {
494 type_t copy = *this;
495 copy.is_ptr_ = true;
496 return copy;
497 }
498
499 type_t scalar() const { return with_elems(1); }
500
501 // Returns size in bytes.
502 int size() const;
503
504 std::string str() const {
505 std::ostringstream oss;
506 oss << to_string(kind());
507 if (elems() > 1) oss << "x" << elems();
508 if (is_ptr()) oss << "*";
509 return oss.str();
510 }
511
512 IR_DEFINE_DUMP()
513
514private:
515 type_kind_t kind_ = type_kind_t::undef;
516 int elems_ = 0;
517 bool is_ptr_ = false;
518};
519
520inline std::ostream &operator<<(std::ostream &out, const type_t &type) {
521 out << type.str();
522 return out;
523}
524
525// type_t to dnnl_data_type_t convertor.
526data_type_t to_dnnl(const type_t &type);
527
// Reference counter for IR objects. It is a plain uint32_t with no
// synchronization, so it is not thread-safe. Both operations return the
// updated count.
class ref_count_t {
public:
    ref_count_t() = default;
    ref_count_t(const ref_count_t &) = delete;

    uint32_t increment() {
        value_ += 1;
        return value_;
    }

    uint32_t decrement() {
        value_ -= 1;
        return value_;
    }

private:
    uint32_t value_ = 0;
};
540
// Forward declarations of the handle and traversal classes, and of every
// class listed in HANDLE_TRAVERSE_TARGETS().
class object_t;
class ir_mutator_t;
class ir_visitor_t;

#define HANDLE_IR_OBJECT(type) class type;
HANDLE_TRAVERSE_TARGETS()
#undef HANDLE_IR_OBJECT
549
// Base class for all IR objects. Implemented as an intrusive pointer, with
// the reference counter stored inside the object. Objects are non-copyable;
// object_t handles manage the counter and delete the object when it drops to
// zero.
class object_impl_t {
public:
    object_impl_t(type_info_t type_info)
        : ref_count_(), type_info_(type_info) {};

    object_impl_t(const object_impl_t &) = delete;

    virtual ~object_impl_t() = default;

    // Intrusive reference counter, incremented/decremented by object_t.
    ref_count_t &ref_count() { return ref_count_; }

    // Type ID used for dispatching in ir_visitor_t and ir_mutator_t.
    // For most IR objects this is the object's own type ID; classes declared
    // with IR_DECL_DERIVED_TYPE_ID override it to return their base class ID
    // so that they are traversed as the base class.
    virtual ir_type_id_t dispatch_type_id() const { return type_id(); }

    // Provides equality semantics.
    virtual bool is_equal(const object_impl_t &obj) const = 0;

    // Hash used by equality-based containers (see object_eq_hash_t).
    virtual size_t get_hash() const = 0;

    // Defaults picked up by _type_info() in the IR_DECL_* macros; expression
    // and statement classes hide them via IR_DECL_EXPR_TYPE_ID /
    // IR_DECL_STMT_TYPE_ID.
    static bool _is_expr() { return false; };
    static bool _is_stmt() { return false; };
    // Flags captured from the type_info passed to the constructor.
    bool is_expr() const { return type_info_.is_expr; }
    bool is_stmt() const { return type_info_.is_stmt; }

    // Downcasts the object to the IR type, returns a reference. The IR type
    // must match the real IR type.
    template <typename T>
    const T &as() const {
        ir_assert(this->is<T>());
        return *(const T *)this;
    }

    template <typename T>
    T &as() {
        ir_assert(this->is<T>());
        return *(T *)this;
    }

    // Downcasts the object to the IR type, returns a pointer. If the IR type
    // doesn't match the real IR type, returns nullptr.
    template <typename T>
    const T *as_ptr() const {
        if (!this->is<T>()) return nullptr;
        return (const T *)this;
    }

    template <typename T>
    T *as_ptr() {
        if (!this->is<T>()) return nullptr;
        return (T *)this;
    }

    // Returns true if T matches the real IR type. This is an exact match: a
    // class declared with IR_DECL_DERIVED_TYPE_ID does not match its base.
    template <typename T>
    bool is() const {
        return type_id() == T::_type_id();
    }

    // Textual representation (used by object_t::str()).
    virtual std::string str() const;

    // Double-dispatch hooks, overridden by IR_DECL_MUTATE / IR_DECL_VISIT.
    // The base versions are defined out of line (not visible here).
    virtual object_t _mutate(ir_mutator_t &mutator) const;
    virtual void _visit(ir_visitor_t &visitor) const;
    IR_DEFINE_DUMP()

private:
    // Unique type ID.
    ir_type_id_t type_id() const { return type_info_.type_id; };

    ref_count_t ref_count_;
    type_info_t type_info_;
};
624
625// Base wrapper for IR objects.
626class object_t {
627public:
628 object_t(object_impl_t *impl = nullptr) : impl_(impl) {
629 increment(impl_);
630#ifdef SANITY_CHECK
631 sanity_check();
632#endif
633 }
634 object_t(const object_impl_t &impl)
635 : object_t(const_cast<object_impl_t *>(&impl)) {}
636 object_t(const object_impl_t *impl)
637 : object_t(const_cast<object_impl_t *>(impl)) {}
638 object_t(const object_t &obj) : object_t(obj.impl()) {}
639 object_t(object_t &&obj) : impl_(obj.impl_) {
640 obj.impl_ = nullptr;
641#ifdef SANITY_CHECK
642 sanity_check();
643#endif
644 }
645
646#ifdef SANITY_CHECK
647 virtual ~object_t() { decrement_and_maybe_destroy(impl_); }
648#else
649 ~object_t() { decrement_and_maybe_destroy(impl_); }
650#endif
651
652 object_t &operator=(const object_t &other) {
653 increment(other.impl());
654 decrement_and_maybe_destroy(impl_);
655 impl_ = other.impl();
656#ifdef SANITY_CHECK
657 sanity_check();
658#endif
659 return *this;
660 }
661
662 object_t &operator=(object_t &&other) {
663 std::swap(impl_, other.impl_);
664#ifdef SANITY_CHECK
665 sanity_check();
666#endif
667 return *this;
668 }
669
670 object_impl_t *impl() const { return impl_; }
671
672 bool is_empty() const { return !impl_; }
673
674 ir_type_id_t dispatch_type_id() const { return impl_->dispatch_type_id(); }
675
676 template <typename T>
677 const T &as() const {
678 ir_assert(impl_);
679 return impl_->as<T>();
680 }
681
682 template <typename T>
683 T &as() {
684 ir_assert(impl_);
685 return impl_->as<T>();
686 }
687
688 template <typename T>
689 const T *as_ptr() const {
690 if (!impl_) return nullptr;
691 return impl_->as_ptr<T>();
692 }
693
694 template <typename T>
695 T *as_ptr() {
696 if (!impl_) return nullptr;
697 return impl_->as_ptr<T>();
698 }
699
700 template <typename T>
701 bool is() const {
702 if (is_empty()) return false;
703 return impl_->is<T>();
704 }
705
706 // Comparison with identity semantics.
707 bool is_same(const object_t &other) const { return impl_ == other.impl(); }
708
709 // Comparison with equality semantics.
710 bool is_equal(const object_t &other) const {
711 if (is_empty() || other.is_empty())
712 return is_empty() == other.is_empty();
713
714 return impl_->is_equal(*other.impl());
715 }
716
717 size_t get_hash() const {
718 if (is_empty()) return 0;
719 return impl()->get_hash();
720 }
721
722 bool is_expr() const { return impl_ && impl_->is_expr(); }
723 bool is_stmt() const { return impl_ && impl_->is_stmt(); }
724
725 std::string str() const {
726 if (is_empty()) return "(nil)";
727 return impl()->str();
728 }
729
730 IR_DEFINE_DUMP()
731
732protected:
733#ifdef SANITY_CHECK
734 virtual void sanity_check() const {}
735#endif
736
737private:
738 static void increment(object_impl_t *impl) {
739 if (!impl) return;
740 impl->ref_count().increment();
741 }
742
743 static void decrement_and_maybe_destroy(object_impl_t *impl) {
744 if (!impl) return;
745 if (impl->ref_count().decrement() == 0) { delete impl; }
746 }
747
748 object_impl_t *impl_;
749};
750
751inline std::ostream &operator<<(std::ostream &out, const object_t &obj) {
752 out << obj.str();
753 return out;
754}
755
756// Helper classes for containers to store object_t.
757struct object_id_hash_t {
758 size_t operator()(const object_t &obj) const {
759 return std::hash<const object_impl_t *>()(obj.impl());
760 }
761};
762
763struct object_eq_hash_t {
764 size_t operator()(const object_t &obj) const { return obj.get_hash(); }
765};
766
767struct object_id_equal_t {
768 bool operator()(const object_t &a, const object_t &b) const {
769 return a.is_same(b);
770 }
771};
772
773struct object_eq_equal_t {
774 bool operator()(const object_t &a, const object_t &b) const {
775 return a.is_equal(b);
776 }
777};
778
779// Containers to store object_t.
780
781// Unordered set, uses identity comparison for keys.
782template <typename KeyT>
783using object_set_t
784 = std::unordered_set<KeyT, object_id_hash_t, object_id_equal_t>;
785
786// Unordered set, uses equality comparison for keys.
787template <typename KeyT>
788using object_eq_set_t
789 = std::unordered_set<KeyT, object_eq_hash_t, object_eq_equal_t>;
790
791// Unordered map, uses identity comparison for keys.
792template <typename KeyT, typename ValueT>
793using object_map_t
794 = std::unordered_map<KeyT, ValueT, object_id_hash_t, object_id_equal_t>;
795
796// Unordered map, uses equality comparison for keys.
797template <typename KeyT, typename ValueT>
798using object_eq_map_t
799 = std::unordered_map<KeyT, ValueT, object_eq_hash_t, object_eq_equal_t>;
800
801// Helper class to mutate IR tree.
802class ir_mutator_t {
803public:
804 virtual ~ir_mutator_t() = default;
805
806 object_t mutate(const object_t &obj) {
807 auto impl = obj.impl();
808 if (!impl) return impl;
809 return impl->_mutate(*this);
810 }
811
812 template <typename T>
813 std::vector<T> mutate(const std::vector<T> &v) {
814 std::vector<T> new_v;
815 for (auto &e : v)
816 new_v.push_back(mutate(e));
817 return new_v;
818 }
819
820 // To catch missing _mutate() handlers in ir_mutator_t.
821 object_t _mutate(const object_impl_t &obj) {
822 ir_error_not_expected() << "Can't handle type: " << object_t(&obj);
823 return {};
824 }
825
826#define HANDLE_IR_OBJECT(type) virtual object_t _mutate(const type &obj);
827 HANDLE_TRAVERSE_TARGETS()
828#undef HANDLE_IR_OBJECT
829};
830
831// Helper class to walk through IR tree.
832class ir_visitor_t {
833public:
834 virtual ~ir_visitor_t() = default;
835
836 void visit(const object_t &obj) {
837 const object_impl_t *impl = obj.impl();
838 if (impl) {
839 pre_visit(*impl);
840 impl->_visit(*this);
841 post_visit(*impl);
842 };
843 }
844
845 template <typename T>
846 void visit(const std::vector<T> &v) {
847 for (auto &e : v)
848 visit(e);
849 }
850
851 virtual void pre_visit(const object_impl_t &obj) {}
852 virtual void post_visit(const object_impl_t &obj) {}
853
854 // To catch missing _visit() handlers in ir_visitor_t.
855 void _visit(const object_impl_t &obj) {
856 ir_error_not_expected() << "Can't handle type: " << object_t(obj);
857 }
858
859#define HANDLE_IR_OBJECT(type) virtual void _visit(const type &obj);
860 HANDLE_TRAVERSE_TARGETS()
861#undef HANDLE_IR_OBJECT
862};
863
864// Base class for IR expression objects.
865class expr_impl_t : public object_impl_t {
866public:
867 IR_DECL_TYPE_ID(expr_impl_t)
868
869 expr_impl_t(type_info_t type_info, const type_t &type)
870 : object_impl_t(type_info), type(type) {}
871
872 type_t type;
873};
874
875// Wrapper for IR expression objects.
876class expr_t : public object_t {
877public:
878 using object_t::object_t;
879
880 expr_t() = default;
881 expr_t(const object_t &obj) : object_t(obj) {}
882 expr_t(object_t &&obj) : object_t(obj) {}
883 expr_t &operator=(const object_t &obj) {
884 object_t::operator=(obj);
885 return *this;
886 }
887 expr_t &operator=(object_t &&obj) {
888 object_t::operator=(obj);
889 return *this;
890 }
891
892 explicit expr_t(bool v);
893 expr_t(float v);
894 expr_t(double v);
895 expr_t(int16_t v);
896 expr_t(int32_t v);
897 expr_t(int64_t v);
898 expr_t(uint16_t v);
899 expr_t(uint32_t v);
900 expr_t(uint64_t v);
901
902 const type_t &type() const {
903 ir_assert(!is_empty());
904 return ((const expr_impl_t *)impl())->type;
905 }
906
907#define DECLARE_BINARY_ASSIGN_OPERATOR(op) \
908 expr_t &operator op##=(const expr_t &rhs);
909
910 DECLARE_BINARY_ASSIGN_OPERATOR(+)
911 DECLARE_BINARY_ASSIGN_OPERATOR(-)
912 DECLARE_BINARY_ASSIGN_OPERATOR(*)
913 DECLARE_BINARY_ASSIGN_OPERATOR(/)
914 DECLARE_BINARY_ASSIGN_OPERATOR(%)
915 DECLARE_BINARY_ASSIGN_OPERATOR(&)
916
917#undef DECLARE_BINARY_ASSIGN_OPERATOR
918
919 // Returns a pointer shifted by `off` bytes relative to this pointer. The
920 // base expression must be a pointer.
921 expr_t operator[](const expr_t &off) const;
922
923private:
924#ifdef SANITY_CHECK
925 void sanity_check() const override {
926 ir_assert(dynamic_cast<const expr_impl_t *>(impl()) == impl())
927 << object_t(impl());
928 }
929#endif
930};
931
// Helper functions. Declared inline here; their definitions are expected
// later in this header (not visible in this chunk).
inline bool is_const(const expr_t &e);
inline bool is_var(const expr_t &e);
inline bool all_of(const expr_t &e, const expr_t &value);
936
// Unary and binary operators, followed by a few operation kinds with more
// operands (documented inline).
enum class op_kind_t {
    undef,

    _minus,
    _add,
    _sub,
    _mul,
    _div,
    _mod,
    _shl,
    _shr,
    _min,
    _max,

    // Comparison kinds (see is_cmp_op() / negate_cmp_op()).
    _lt,
    _le,
    _gt,
    _ge,
    _ne,
    _eq,

    _and,

    _prelu, // binary relu(a, b)
    _add3, // a + b + c
    _mad, // a + b * c
    _dp4a, // dpas.1x1
};

// Returns the textual name of an operation kind.
std::string to_string(op_kind_t kind);
968
969inline std::ostream &operator<<(std::ostream &out, op_kind_t kind) {
970 out << to_string(kind);
971 return out;
972}
973
// Operator classification and result-type deduction helpers. Definitions are
// out of line (not visible here). binary_op_t::make() uses binary_op_type()
// when no explicit result type is given.
bool is_cmp_op(op_kind_t op_kind);

op_kind_t negate_cmp_op(op_kind_t op_kind);

type_t unary_op_type(op_kind_t op_kind, const expr_t &a);

type_t common_int_type(const type_t &_a, const type_t &_b);

type_t common_type(const type_t &a, const type_t &b);

type_t common_type(const expr_t &a, const expr_t &b);

type_t binary_op_type(op_kind_t op_kind, const expr_t &a, const expr_t &b);

type_t ternary_op_type(
        op_kind_t op_kind, const expr_t &a, const expr_t &b, const expr_t &c);

type_t nary_op_type(op_kind_t op_kind, const std::vector<expr_t> &args);
992
993// Binary operation: (a op b).
994class binary_op_t : public expr_impl_t {
995public:
996 IR_DECL_EXPR_TYPE_ID(binary_op_t)
997
998 static expr_t make(op_kind_t op_kind, const expr_t &a, const expr_t &b,
999 type_t ty = type_t()) {
1000 return expr_t(new binary_op_t(op_kind, a, b, ty));
1001 }
1002
1003 bool is_equal(const object_impl_t &obj) const override {
1004 if (!obj.is<self_type>()) return false;
1005 auto &other = obj.as<self_type>();
1006
1007 return (op_kind == other.op_kind) && a.is_equal(other.a)
1008 && b.is_equal(other.b);
1009 }
1010
1011 size_t get_hash() const override {
1012 return ir_utils::get_hash(op_kind, a, b);
1013 }
1014
1015 IR_DECLARE_TRAVERSERS()
1016
1017 op_kind_t op_kind;
1018 expr_t a;
1019 expr_t b;
1020
1021private:
1022 binary_op_t(op_kind_t op_kind, const expr_t &a, const expr_t &b, type_t ty)
1023 : expr_impl_t(_type_info(),
1024 (ty.is_undef()) ? binary_op_type(op_kind, a, b) : ty)
1025 , op_kind(op_kind)
1026 , a(a)
1027 , b(b) {}
1028};
1029
1030// Boolean immediate value.
1031class bool_imm_t : public expr_impl_t {
1032public:
1033 friend class expr_t;
1034 IR_DECL_EXPR_TYPE_ID(bool_imm_t)
1035
1036 static expr_t make(bool value) { return expr_t(new bool_imm_t(value)); }
1037
1038 bool is_equal(const object_impl_t &obj) const override {
1039 if (!obj.is<self_type>()) return false;
1040 auto &other = obj.as<self_type>();
1041
1042 return value == other.value;
1043 }
1044
1045 size_t get_hash() const override { return ir_utils::get_hash(value); }
1046
1047 IR_DECLARE_TRAVERSERS()
1048
1049 bool value;
1050
1051private:
1052 bool_imm_t(bool value)
1053 : expr_impl_t(_type_info(), type_t::_bool()), value(value) {}
1054};
1055
1056// Cast between data types. In general conversion follows the C++ casting
1057// rules. Several modes/scenarios are supported:
1058// - Cast with saturation: cast(T, e) = max(T_min, min(T_max, e))
1059// By default saturation is disabled and any underflow/overflow is unhandled.
1060// - Bitwise cast from bool vector to u16 (boolxN -> u16, 2 <= N <= 16):
1061// In this case the lower N bits of the resulting value are initialized based
1062// on the boolean elements. The upper (16 - N) bits are uninitialized.
1063class cast_t : public expr_impl_t {
1064public:
1065 IR_DECL_EXPR_TYPE_ID(cast_t)
1066
1067 static expr_t make(
1068 const type_t &type, const expr_t &expr, bool saturate = false) {
1069 if (expr.type() == type) return expr;
1070 if (!saturate) {
1071 auto *expr_cast = expr.as_ptr<cast_t>();
1072 if (expr_cast && !expr_cast->saturate
1073 && type == expr_cast->expr.type())
1074 return expr_cast->expr;
1075 }
1076 return expr_t(new cast_t(type, expr, saturate));
1077 }
1078
1079 bool is_equal(const object_impl_t &obj) const override {
1080 if (!obj.is<self_type>()) return false;
1081 auto &other = obj.as<self_type>();
1082
1083 return (type == other.type) && expr.is_equal(other.expr)
1084 && (saturate == other.saturate);
1085 }
1086
1087 size_t get_hash() const override {
1088 return ir_utils::get_hash(type, expr, saturate);
1089 }
1090
1091 bool is_bool_vec_u16() const {
1092 if (is_bool_vec(expr.type()) && is_u16_scalar(type)) return true;
1093 if (is_bool_vec(type) && is_u16_scalar(expr.type())) return true;
1094 return false;
1095 }
1096
1097 IR_DECLARE_TRAVERSERS()
1098
1099 expr_t expr;
1100 bool saturate;
1101
1102private:
1103 cast_t(const type_t &type, const expr_t &expr, bool saturate)
1104 : expr_impl_t(_type_info(), type), expr(expr), saturate(saturate) {
1105 if (!is_bool_vec_u16()) {
1106 ir_assert(type.elems() == expr.type().elems())
1107 << "Number of elements must match.";
1108 }
1109 }
1110
1111 static bool is_bool_vec(const type_t &type) {
1112 return type.is_bool() && type.elems() > 1;
1113 }
1114
1115 static bool is_u16_scalar(const type_t &type) {
1116 return type.is_u16() && type.is_scalar();
1117 }
1118};
1119
1120// Floating-point immediate value.
1121class float_imm_t : public expr_impl_t {
1122public:
1123 friend class expr_t;
1124 IR_DECL_EXPR_TYPE_ID(float_imm_t)
1125
1126 static expr_t make(double value, const type_t &type = type_t::undef()) {
1127 return expr_t(new float_imm_t(value, type));
1128 }
1129
1130 bool is_equal(const object_impl_t &obj) const override {
1131 if (!obj.is<self_type>()) return false;
1132 auto &other = obj.as<self_type>();
1133
1134 return value == other.value;
1135 }
1136
1137 size_t get_hash() const override { return ir_utils::get_hash(value); }
1138
1139 IR_DECLARE_TRAVERSERS()
1140
1141 double value;
1142
1143private:
1144 float_imm_t(double value, const type_t &type = type_t::undef())
1145 : expr_impl_t(_type_info(), type.is_undef() ? type_t::f32() : type)
1146 , value(value) {}
1147};
1148
1149// Integer immediate value.
1150class int_imm_t : public expr_impl_t {
1151public:
1152 friend class expr_t;
1153 IR_DECL_EXPR_TYPE_ID(int_imm_t);
1154
1155 template <typename T>
1156 static expr_t make(T value, const type_t &type = type_t::undef()) {
1157 return expr_t(new int_imm_t(value, type));
1158 }
1159
1160 bool is_equal(const object_impl_t &obj) const override {
1161 if (!obj.is<self_type>()) return false;
1162 auto &other = obj.as<self_type>();
1163
1164 return value == other.value;
1165 }
1166
1167 size_t get_hash() const override { return ir_utils::get_hash(value); }
1168
1169 static expr_t shrink_type(const expr_t &e) {
1170 auto &imm = e.as<int_imm_t>();
1171 type_t new_type = shrink_type(imm.value);
1172 if (new_type == imm.type) return e;
1173 return make(imm.value, new_type);
1174 }
1175
1176 template <typename T>
1177 static bool try_shrink_type(int64_t v) {
1178 if (v >= std::numeric_limits<T>::min()
1179 && v <= std::numeric_limits<T>::max())
1180 return true;
1181 return false;
1182 }
1183
1184 IR_DECLARE_TRAVERSERS()
1185
1186 int64_t value;
1187
1188private:
1189 int_imm_t(int64_t value, const type_t &type = type_t::undef())
1190 : expr_impl_t(_type_info(), type.is_undef() ? shrink_type(value) : type)
1191 , value(value) {}
1192
1193 static type_t shrink_type(int64_t v) {
1194 if (try_shrink_type<int32_t>(v)) return type_t::s32();
1195 return type_t::s64();
1196 }
1197};
1198
1199// Immediate if or the conditional (ternary) operator.
1200// C++ equivalent: (cond ? true_expr : false_expr).
1201class iif_t : public expr_impl_t {
1202public:
1203 IR_DECL_EXPR_TYPE_ID(iif_t);
1204
1205 static expr_t make(const expr_t &cond, const expr_t &true_expr,
1206 const expr_t &false_expr) {
1207 return expr_t(new iif_t(cond, true_expr, false_expr));
1208 }
1209
1210 bool is_equal(const object_impl_t &obj) const override {
1211 if (!obj.is<self_type>()) return false;
1212 auto &other = obj.as<self_type>();
1213
1214 return cond.is_equal(other.cond) && true_expr.is_equal(other.true_expr)
1215 && false_expr.is_equal(other.false_expr);
1216 }
1217
1218 size_t get_hash() const override {
1219 return ir_utils::get_hash(cond, true_expr, false_expr);
1220 }
1221
1222 IR_DECLARE_TRAVERSERS()
1223
1224 expr_t cond;
1225 expr_t true_expr;
1226 expr_t false_expr;
1227
1228private:
1229 iif_t(const expr_t &cond, const expr_t &true_expr, const expr_t &false_expr)
1230 : expr_impl_t(
1231 _type_info(), common_type(true_expr.type(), false_expr.type()))
1232 , cond(cond)
1233 , true_expr(true_expr)
1234 , false_expr(false_expr) {}
1235};
1236
// Updates `base` and `off` so that after return:
// - base contains a variable of a pointer type
// - off contains an offset
// `type` is the type of the accessed value: load_t passes the loaded type and
// store_t passes the stored value's type. How `type` influences the result is
// defined by the implementation (not visible here).
void normalize_ptr(const type_t &type, expr_t &base, expr_t &off);
1241
1242// Load from a GRF buffer.
1243// C++ equivalent (when type is scalar):
1244// load = *(type *)(&buf[off]);
1245// C++ equivalent (when type is vector):
1246// int _stride = (has_default_stride() ? sizeof(scalar_type) : stride);
1247// for (int i = 0; i < elems; i++) {
1248// load[i] = *(scalar_type *)(&buf[off + i * _stride]);
1249// }
1250class load_t : public expr_impl_t {
1251public:
1252 IR_DECL_EXPR_TYPE_ID(load_t)
1253
1254 // offset and stride are expressed in bytes.
1255 // default stride means unit stride (in terms of type.scalar() elements).
1256 static expr_t make(const type_t &type, const expr_t &buf, const expr_t &off,
1257 int stride = default_stride) {
1258 return expr_t(new load_t(type, buf, off, stride));
1259 }
1260
1261 bool is_equal(const object_impl_t &obj) const override {
1262 if (!obj.is<self_type>()) return false;
1263 auto &other = obj.as<self_type>();
1264
1265 return type.is_equal(other.type) && buf.is_equal(other.buf)
1266 && off.is_equal(other.off) && (stride == other.stride);
1267 }
1268
1269 size_t get_hash() const override {
1270 return ir_utils::get_hash(type, buf, off, stride);
1271 }
1272
1273 bool has_default_stride() const { return stride == default_stride; }
1274
1275 IR_DECLARE_TRAVERSERS()
1276
1277 static const int default_stride = -1;
1278
1279 expr_t buf;
1280 expr_t off;
1281 int stride;
1282
1283private:
1284 load_t(const type_t &_type, const expr_t &_buf, const expr_t &_off,
1285 int _stride)
1286 : expr_impl_t(_type_info(), _type)
1287 , buf(_buf)
1288 , off(_off)
1289 , stride(_stride) {
1290 normalize_ptr(type, buf, off);
1291 ir_assert(is_var(buf)) << buf;
1292 ir_assert(buf.type().is_ptr()) << buf;
1293 if (stride == type.scalar().size()) stride = default_stride;
1294 }
1295};
1296
1297// N-ary expression: (a[0] op a[1] op ... op a[n - 1]),
1298// where <op> is either addition or multiplication.
1299class nary_op_t : public expr_impl_t {
1300public:
1301 IR_DECL_EXPR_TYPE_ID(nary_op_t)
1302
1303 static expr_t make(op_kind_t op_kind, const std::vector<expr_t> &args) {
1304 return expr_t(new nary_op_t(op_kind, args));
1305 }
1306
1307 bool is_equal(const object_impl_t &obj) const override {
1308 if (!obj.is<self_type>()) return false;
1309 auto &other = obj.as<self_type>();
1310
1311 return (op_kind == other.op_kind)
1312 && ir_utils::is_equal(args, other.args);
1313 }
1314
1315 size_t get_hash() const override {
1316 return ir_utils::get_hash(op_kind, args);
1317 }
1318
1319 std::string str() const override {
1320 std::ostringstream oss;
1321 oss << "(";
1322 for (size_t i = 0; i < args.size(); i++) {
1323 oss << (i != 0 ? " " + to_string(op_kind) + " " : "") << args[i];
1324 }
1325
1326 oss << ")";
1327 return oss.str();
1328 }
1329
1330 IR_DECLARE_TRAVERSERS()
1331
1332 op_kind_t op_kind;
1333 std::vector<expr_t> args;
1334
1335private:
1336 nary_op_t(op_kind_t op_kind, const std::vector<expr_t> &args)
1337 : expr_impl_t(_type_info(), nary_op_type(op_kind, args))
1338 , op_kind(op_kind)
1339 , args(args) {}
1340};
1341
1342// Pointer expression: (base_ptr + off).
1343class ptr_t : public expr_impl_t {
1344public:
1345 IR_DECL_EXPR_TYPE_ID(ptr_t)
1346
1347 // off - offset in bytes.
1348 static expr_t make(const expr_t &base, const expr_t &off) {
1349 return expr_t(new ptr_t(base, off));
1350 }
1351
1352 bool is_equal(const object_impl_t &obj) const override {
1353 if (!obj.is<self_type>()) return false;
1354 auto &other = obj.as<self_type>();
1355
1356 return base.is_equal(other.base) && off.is_equal(other.off);
1357 }
1358
1359 size_t get_hash() const override { return ir_utils::get_hash(base, off); }
1360
1361 // Normalizes (base op off) pointer so that the new base is a variable and
1362 // off is an offset expression.
1363 // Example:
1364 // Before call: base = (base0 + off0), off = off1
1365 // After call: base = base0, off = off0 + off1
1366 static void normalize(
1367 expr_t &base, expr_t &off, op_kind_t op_kind = op_kind_t::_add);
1368
1369 IR_DECLARE_TRAVERSERS()
1370
1371 expr_t base;
1372 expr_t off;
1373
1374private:
1375 ptr_t(const expr_t &base, const expr_t &off)
1376 : expr_impl_t(_type_info(), base.type()), base(base), off(off) {
1377 normalize(this->base, this->off);
1378 }
1379};
1380
1381inline const expr_t &get_base(const expr_t &e) {
1382 if (e.is_empty()) return e;
1383 if (e.is<var_t>()) return e;
1384 if (e.is<ptr_t>()) return e.as<ptr_t>().base;
1385 ir_error_not_expected() << e;
1386 return e;
1387}
1388
// Shuffle (gather) expression over a list of expressions: element i of the
// result is vec[idx[i]]. The result type is the common type of `vec` entries
// with idx.size() elements.
class shuffle_t : public expr_impl_t {
public:
    IR_DECL_EXPR_TYPE_ID(shuffle_t)

    // Creates a shuffle from explicit sources and indices. A single index
    // yields the selected source expression itself rather than a shuffle.
    static expr_t make(
            const std::vector<expr_t> &vec, const std::vector<int> &idx) {
        if (idx.size() == 1) return vec[idx[0]];
        return expr_t(new shuffle_t(vec, idx));
    }

    // Creates a shuffle whose i-th element is _vec[i]. With `find_equal`,
    // elements equal (is_equal()) to an earlier one are stored once and
    // referenced by index; otherwise every element gets its own slot.
    static expr_t make(
            const std::vector<expr_t> &_vec, bool find_equal = true) {
        std::vector<expr_t> vec;
        std::vector<int> idx;
        for (auto &v : _vec) {
            bool found = false;
            int size = int(vec.size());
            if (find_equal) {
                // Linear search over the unique elements collected so far.
                for (int i = 0; i < size; i++) {
                    if (v.is_equal(vec[i])) {
                        idx.push_back(i);
                        found = true;
                        break;
                    }
                }
            }
            if (!found) {
                vec.push_back(v);
                idx.push_back(size);
            }
        }
        return make(vec, idx);
    }

    // Repeats the scalar `expr` `elems` times; `elems` must be a power of two.
    static expr_t make_broadcast(const expr_t &expr, int elems) {
        ir_assert(expr.type().is_scalar()) << expr;
        ir_assert(math::is_pow2(elems));
        return make({expr}, std::vector<int>(elems, 0));
    }

    // Slices the existing shuffle expression. For inputs (S, beg, end) returns
    // (S[beg], S[beg + 1], ..., S[end - 1]) vector.
    static expr_t make(const expr_t &_shuffle, int beg, int end) {
        auto &shuffle = _shuffle.as<shuffle_t>();
        ir_assert(beg >= 0 && beg <= shuffle.elems());
        ir_assert(end >= 0 && end <= shuffle.elems());
        ir_assert(beg < end);
        std::vector<expr_t> vec;
        // -1 marks slice positions not yet assigned a new index.
        std::vector<int> idx(end - beg, -1);
        for (int i = beg; i < end; i++) {
            if (idx[i - beg] != -1) continue;
            // First occurrence of this source in the slice: copy it into the
            // new `vec` and point every later position that uses the same
            // source at the new slot.
            int old_idx = shuffle.idx[i];
            vec.push_back(shuffle.vec[old_idx]);
            for (int j = i; j < end; j++) {
                if (shuffle.idx[j] == old_idx)
                    idx[j - beg] = int(vec.size()) - 1;
            }
        }
        return make(vec, idx);
    }

    bool is_equal(const object_impl_t &obj) const override {
        if (!obj.is<self_type>()) return false;
        auto &other = obj.as<self_type>();

        return ir_utils::is_equal(vec, other.vec)
                && ir_utils::is_equal(idx, other.idx);
    }

    size_t get_hash() const override { return ir_utils::get_hash(vec, idx); }

    // Number of elements in the result.
    int elems() const { return int(idx.size()); }

    // True when idx is the identity (idx[i] == i for every i).
    bool is_vector() const {
        for (int i = 0; i < elems(); i++)
            if (idx[i] != i) return false;
        return true;
    }

    // True when all elements come from a single source expression.
    bool is_broadcast() const { return vec.size() == 1; }

    IR_DECLARE_TRAVERSERS()

    std::vector<expr_t> vec; // Source expressions.
    std::vector<int> idx; // Per-element index into `vec`.

private:
    shuffle_t(const std::vector<expr_t> &vec, const std::vector<int> &idx)
        : expr_impl_t(_type_info(), shuffle_type(vec, idx))
        , vec(vec)
        , idx(idx) {
        ir_assert(idx.size() > 1) << "Unexpected empty or scalar shuffle.";
    }

    // Computes the result type: the common type of all sources, widened to
    // idx.size() elements. Also validates that every index is in range.
    static type_t shuffle_type(
            const std::vector<expr_t> &vec, const std::vector<int> &idx) {
        ir_assert(!vec.empty() && !idx.empty());

        auto elem_type = vec[0].type();
        for (auto &v : vec)
            elem_type = common_type(elem_type, v.type());

        for (size_t i = 0; i < idx.size(); i++) {
            ir_assert(idx[i] >= 0 && idx[i] < int(vec.size()))
                    << "Incorrect index.";
            // `i` is otherwise unused when assertions compile out.
            MAYBE_UNUSED(i);
        }

        int elems = int(idx.size());
        return elem_type.with_elems(elems);
    }
};
1501
1502// Ternary operation: op(a, b, c).
1503class ternary_op_t : public expr_impl_t {
1504public:
1505 IR_DECL_EXPR_TYPE_ID(ternary_op_t)
1506
1507 static expr_t make(op_kind_t op_kind, const expr_t &a, const expr_t &b,
1508 const expr_t &c, type_t ty = type_t()) {
1509 return expr_t(new ternary_op_t(op_kind, a, b, c, ty));
1510 }
1511
1512 bool is_equal(const object_impl_t &obj) const override {
1513 if (!obj.is<self_type>()) return false;
1514 auto &other = obj.as<self_type>();
1515
1516 return (op_kind == other.op_kind) && a.is_equal(other.a)
1517 && b.is_equal(other.b) && c.is_equal(other.c);
1518 }
1519
1520 size_t get_hash() const override {
1521 return ir_utils::get_hash(op_kind, a, b, c);
1522 }
1523
1524 IR_DECLARE_TRAVERSERS()
1525
1526 op_kind_t op_kind;
1527 expr_t a;
1528 expr_t b;
1529 expr_t c;
1530
1531private:
1532 ternary_op_t(op_kind_t op_kind, const expr_t &a, const expr_t &b,
1533 const expr_t &c, type_t ty)
1534 : expr_impl_t(_type_info(),
1535 (ty.is_undef()) ? ternary_op_type(op_kind, a, b, c) : ty)
1536 , op_kind(op_kind)
1537 , a(a)
1538 , b(b)
1539 , c(c) {}
1540};
1541
1542inline expr_t ternary_mad(const expr_t &a, const expr_t &b, const expr_t &c) {
1543 return ternary_op_t::make(op_kind_t::_mad, a, b, c);
1544}
1545
1546inline expr_t ternary_add3(const expr_t &a, const expr_t &b, const expr_t &c) {
1547 return ternary_op_t::make(op_kind_t::_add3, a, b, c);
1548}
1549
1550// Unary operation: (op a).
1551class unary_op_t : public expr_impl_t {
1552public:
1553 IR_DECL_EXPR_TYPE_ID(unary_op_t)
1554
1555 static expr_t make(op_kind_t op_kind, const expr_t &a) {
1556 return expr_t(new unary_op_t(op_kind, a));
1557 }
1558
1559 bool is_equal(const object_impl_t &obj) const override {
1560 if (!obj.is<self_type>()) return false;
1561 auto &other = obj.as<self_type>();
1562
1563 return (op_kind == other.op_kind) && a.is_equal(other.a);
1564 }
1565
1566 size_t get_hash() const override { return ir_utils::get_hash(op_kind, a); }
1567
1568 IR_DECLARE_TRAVERSERS()
1569
1570 op_kind_t op_kind;
1571 expr_t a;
1572
1573private:
1574 unary_op_t(op_kind_t op_kind, const expr_t &a)
1575 : expr_impl_t(_type_info(), unary_op_type(op_kind, a))
1576 , op_kind(op_kind)
1577 , a(a) {}
1578};
1579
1580class var_t : public expr_impl_t {
1581public:
1582 IR_DECL_EXPR_TYPE_ID(var_t)
1583
1584 static expr_t make(const type_t &type, const std::string &name) {
1585 return expr_t(new var_t(type, name));
1586 }
1587
1588 bool is_equal(const object_impl_t &obj) const override {
1589 // Do not allow variable cloning.
1590 return this == &obj;
1591 }
1592
1593 size_t get_hash() const override { return ir_utils::get_hash(name); }
1594
1595 IR_DECLARE_TRAVERSERS()
1596
1597 std::string name;
1598
1599private:
1600 var_t(const type_t &type, const std::string &name)
1601 : expr_impl_t(_type_info(), type), name(name) {}
1602};
1603
1604// Convertor from C++ type to IR expression.
1605template <typename T>
1606expr_t to_expr(T value, const type_t &type) {
1607#define CASE(ir_type, cpp_type) \
1608 if (type == type_t::ir_type()) return expr_t((cpp_type)value)
1609
1610 CASE(_bool, bool);
1611 CASE(f32, float);
1612 CASE(f64, double);
1613 CASE(s16, int16_t);
1614 CASE(s32, int32_t);
1615 CASE(s64, int64_t);
1616 CASE(u16, uint16_t);
1617 CASE(u32, uint32_t);
1618 CASE(u64, uint64_t);
1619
1620#undef CASE
1621
1622 ir_error_not_expected() << type;
1623
1624 return expr_t();
1625}
1626
1627template <typename T>
1628expr_t to_expr(T value) {
1629 return to_expr(value, type_t::from_cpp<T>());
1630}
1631
1632inline bool is_binary_op(const expr_t &e) {
1633 return e.is<binary_op_t>();
1634}
1635
1636inline bool is_binary_op(const expr_t &e, op_kind_t op_kind) {
1637 if (!is_binary_op(e)) return false;
1638 return e.as<binary_op_t>().op_kind == op_kind;
1639}
1640
1641inline bool is_binary_cmp_op(const expr_t &e) {
1642 if (!is_binary_op(e)) return false;
1643 return is_cmp_op(e.as<binary_op_t>().op_kind);
1644}
1645
1646inline bool is_const(const expr_t &e) {
1647 return e.is<bool_imm_t>() || e.is<int_imm_t>() || e.is<float_imm_t>();
1648}
1649
1650inline bool all_of(const expr_t &e, const expr_t &value) {
1651 auto *shuffle = e.as_ptr<shuffle_t>();
1652 if (!shuffle) return e.is_equal(value);
1653 for (auto &i : shuffle->idx) {
1654 if (!shuffle->vec[i].is_equal(value)) return false;
1655 }
1656 return true;
1657}
1658
1659inline bool is_shuffle_const(const expr_t &e) {
1660 auto *shuffle = e.as_ptr<shuffle_t>();
1661 if (!shuffle) return false;
1662 for (auto &v : shuffle->vec)
1663 if (!is_const(v)) return false;
1664 return true;
1665}
1666
1667inline bool is_var(const expr_t &e) {
1668 return e.is<var_t>();
1669}
1670
1671// Convertor from IR expression to C++ constant.
1672template <typename T>
1673T to_cpp(const expr_t &e) {
1674 ir_assert(is_const(e)) << "Expression must be constant.";
1675
1676 if (e.is<int_imm_t>()) return (T)e.as<int_imm_t>().value;
1677 if (e.is<float_imm_t>()) return (T)e.as<float_imm_t>().value;
1678 if (e.is<bool_imm_t>()) return (T)e.as<bool_imm_t>().value;
1679
1680 ir_error_not_expected();
1681 return 0;
1682}
1683
// Unary minus on an expression (defined out of line).
expr_t operator-(const expr_t &a);

// Declares a binary operator on expressions (definitions are out of line).
// The `op_kind` macro argument is not used by the expansion; it only records
// the op_kind_t paired with each operator. NOTE(review): presumably the
// definitions build a binary_op_t of that kind; verify against the .cpp.
#define DECLARE_BINARY_OPERATOR(op, op_kind) \
    expr_t operator op(const expr_t &a, const expr_t &b);

// Arithmetic and shift operators.
DECLARE_BINARY_OPERATOR(+, op_kind_t::_add)
DECLARE_BINARY_OPERATOR(-, op_kind_t::_sub)
DECLARE_BINARY_OPERATOR(*, op_kind_t::_mul)
DECLARE_BINARY_OPERATOR(/, op_kind_t::_div)
DECLARE_BINARY_OPERATOR(%, op_kind_t::_mod)
DECLARE_BINARY_OPERATOR(<<, op_kind_t::_shl)
DECLARE_BINARY_OPERATOR(>>, op_kind_t::_shr)

// Comparison operators.
DECLARE_BINARY_OPERATOR(==, op_kind_t::_eq)
DECLARE_BINARY_OPERATOR(!=, op_kind_t::_ne)
DECLARE_BINARY_OPERATOR(>, op_kind_t::_gt)
DECLARE_BINARY_OPERATOR(>=, op_kind_t::_ge)
DECLARE_BINARY_OPERATOR(<, op_kind_t::_lt)
DECLARE_BINARY_OPERATOR(<=, op_kind_t::_le)

// Bitwise/logical AND.
DECLARE_BINARY_OPERATOR(&, op_kind_t::_and)

#undef DECLARE_BINARY_OPERATOR

// Returns a shifted pointer with base `a` (pointer) and offset `b` (in bytes).
// shift_ptr(op, a, b) returns &(a op b) in C++ terms (op is either addition or
// subtraction).
expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b);
1712
1713// Base class for IR statement objects.
1714class stmt_impl_t : public object_impl_t {
1715public:
1716 IR_DECL_TYPE_ID(stmt_impl_t)
1717 stmt_impl_t(type_info_t type_info) : object_impl_t(type_info) {}
1718};
1719
1720// Wrapper for IR statement objects.
1721class stmt_t : public object_t {
1722public:
1723 using object_t::object_t;
1724
1725 stmt_t() = default;
1726 stmt_t(const object_t &obj) : object_t(obj) {}
1727 stmt_t(object_t &&obj) : object_t(obj) {}
1728 stmt_t &operator=(const object_t &obj) {
1729 object_t::operator=(obj);
1730 return *this;
1731 }
1732 stmt_t &operator=(object_t &&obj) {
1733 object_t::operator=(obj);
1734 return *this;
1735 }
1736
1737 stmt_t append(const stmt_t &s) const;
1738
1739private:
1740#ifdef SANITY_CHECK
1741 void sanity_check() const override {
1742 ir_assert(dynamic_cast<const stmt_impl_t *>(impl()) == impl())
1743 << object_t(impl());
1744 }
1745#endif
1746};
1747
1748enum class alloc_kind_t {
1749 undef,
1750 grf, // GRF - general register file.
1751 slm, // SLM - shared local memory.
1752 global, // Global memory.
1753};
1754
1755class alloc_attr_impl_t : public object_impl_t {
1756public:
1757 alloc_attr_impl_t(type_info_t type_info) : object_impl_t(type_info) {}
1758};
1759
1760class alloc_attr_t : public object_t {
1761public:
1762 using object_t::object_t;
1763
1764 alloc_attr_t() = default;
1765 alloc_attr_t(const object_t &obj) : object_t(obj) {}
1766 alloc_attr_t(object_t &&obj) : object_t(obj) {}
1767 alloc_attr_t &operator=(const object_t &obj) {
1768 object_t::operator=(obj);
1769 return *this;
1770 }
1771 alloc_attr_t &operator=(object_t &&obj) {
1772 object_t::operator=(obj);
1773 return *this;
1774 }
1775
1776private:
1777#ifdef SANITY_CHECK
1778 void sanity_check() const override {
1779 ir_assert(dynamic_cast<const alloc_attr_impl_t *>(impl()) == impl())
1780 << object_t(impl());
1781 }
1782#endif
1783};
1784
1785class grf_permutation_t;
1786
1787// Allocation attribute specifying permutation for a GRF buffer.
1788class grf_permute_attr_t : public alloc_attr_impl_t {
1789public:
1790 IR_DECL_TYPE_ID(grf_permute_attr_t)
1791
1792 static alloc_attr_t make(
1793 const std::shared_ptr<grf_permutation_t> &grf_perm) {
1794 return alloc_attr_t(new grf_permute_attr_t(grf_perm));
1795 }
1796
1797 bool is_equal(const object_impl_t &obj) const override {
1798 return this == &obj;
1799 }
1800
1801 size_t get_hash() const override {
1802 return std::hash<const self_type *>()(this);
1803 }
1804
1805 std::shared_ptr<grf_permutation_t> grf_perm;
1806
1807private:
1808 grf_permute_attr_t(const std::shared_ptr<grf_permutation_t> &grf_perm)
1809 : alloc_attr_impl_t(_type_info()), grf_perm(grf_perm) {}
1810};
1811
1812// Allocation attribute to store extra information to avoid bank conflicts.
1813class bank_conflict_attr_t : public alloc_attr_impl_t {
1814public:
1815 IR_DECL_TYPE_ID(bank_conflict_attr_t)
1816
1817 static alloc_attr_t make(const std::vector<expr_t> &bufs,
1818 const std::vector<int> &buf_sizes,
1819 const std::vector<int> &buf_min_block_sizes,
1820 const std::vector<stmt_t> &instructions) {
1821 return alloc_attr_t(new bank_conflict_attr_t(
1822 bufs, buf_sizes, buf_min_block_sizes, instructions));
1823 }
1824
1825 bool is_equal(const object_impl_t &obj) const override {
1826 return this == &obj;
1827 }
1828
1829 size_t get_hash() const override {
1830 return std::hash<const self_type *>()(this);
1831 }
1832
1833 // List of buffers accessed from instructions.
1834 std::vector<expr_t> bufs;
1835 // Buffer sizes in bytes.
1836 std::vector<int> buf_sizes;
1837 // Minimum power-of-two block sizes for each buffer to avoid unhandled
1838 // cross-boundary accesses. A buffer may be allocated in fixed-size blocks
1839 // to avoid bank conflicts however the block size can't be arbitrary - we
1840 // need to avoid unhandled boundary crossings (e.g. in memory loads).
1841 std::vector<int> buf_min_block_sizes;
1842 // List of instructions whose bank conflicts are to be avoided.
1843 std::vector<stmt_t> instructions;
1844
1845private:
1846 bank_conflict_attr_t(const std::vector<expr_t> &bufs,
1847 const std::vector<int> &buf_sizes,
1848 const std::vector<int> &buf_min_block_sizes,
1849 const std::vector<stmt_t> &instructions)
1850 : alloc_attr_impl_t(_type_info())
1851 , bufs(bufs)
1852 , buf_sizes(buf_sizes)
1853 , buf_min_block_sizes(buf_min_block_sizes)
1854 , instructions(instructions) {}
1855};
1856
1857// Allocation for SLM and GRF buffers.
1858// C++ equivalent:
1859// {
1860// byte *buf = new byte[size];
1861// body;
1862// }
1863class alloc_t : public stmt_impl_t {
1864public:
1865 IR_DECL_STMT_TYPE_ID(alloc_t)
1866
1867 static stmt_t make(const expr_t &buf, int size, alloc_kind_t kind,
1868 const std::vector<alloc_attr_t> &attrs, const stmt_t &body = {}) {
1869 return stmt_t(new alloc_t(buf, size, kind, attrs, body));
1870 }
1871
1872 static stmt_t make(const expr_t &buf, int size, alloc_kind_t kind,
1873 const alloc_attr_t &attr, const stmt_t &body = {}) {
1874 std::vector<alloc_attr_t> attrs = {attr};
1875 return make(buf, size, kind, attrs, body);
1876 }
1877
1878 static stmt_t make(const expr_t &buf, int size, alloc_kind_t kind,
1879 const stmt_t &body = {}) {
1880 return make(buf, size, kind, std::vector<alloc_attr_t>(), body);
1881 }
1882
1883 bool is_equal(const object_impl_t &obj) const override {
1884 if (!obj.is<self_type>()) return false;
1885 auto &other = obj.as<self_type>();
1886
1887 return buf.is_equal(other.buf) && (size == other.size)
1888 && (kind == other.kind)
1889 && ir_utils::is_equal(attrs, other.attrs)
1890 && body.is_equal(other.body);
1891 }
1892
1893 size_t get_hash() const override {
1894 return ir_utils::get_hash(buf, size, kind, attrs, body);
1895 }
1896
1897 template <typename T>
1898 bool has_attr() const {
1899 for (auto &a : attrs)
1900 if (a.is<T>()) return true;
1901 return false;
1902 }
1903
1904 template <typename T>
1905 const T &get_attr() const {
1906 for (auto &a : attrs)
1907 if (a.is<T>()) return a.as<T>();
1908 ir_error_not_expected() << "Can't find attribute.";
1909 return attrs[0].as<T>();
1910 }
1911
1912 IR_DECLARE_TRAVERSERS()
1913
1914 expr_t buf;
1915 int size;
1916 alloc_kind_t kind;
1917 std::vector<alloc_attr_t> attrs;
1918 stmt_t body;
1919
1920private:
1921 alloc_t(const expr_t &buf, int size, alloc_kind_t kind,
1922 const std::vector<alloc_attr_t> &attrs, const stmt_t &body)
1923 : stmt_impl_t(_type_info())
1924 , buf(buf)
1925 , size(size)
1926 , kind(kind)
1927 , attrs(attrs)
1928 , body(body) {
1929 ir_assert(buf.type().is_ptr()) << buf;
1930 }
1931};
1932
1933// Store to a GRF buffer.
1934// C++ equivalent (when value is scalar):
1935// *(value_type *)(&buf[off]) = value;
1936// C++ equivalent (when value is vector):
1937// int _stride = (has_default_stride() ? sizeof(scalar_type) : stride);
1938// for (int i = 0; i < elems; i++) {
1939// *(scalar_type *)(&buf[off + i * _stride]) = value[i];
1940// }
1941class store_t : public stmt_impl_t {
1942public:
1943 IR_DECL_STMT_TYPE_ID(store_t)
1944
1945 // offset and stride are expressed in bytes.
1946 // default stride means unit stride (in terms of value.type().scalar()
1947 // elements).
1948 static stmt_t make(const expr_t &buf, const expr_t &off,
1949 const expr_t &value, int stride = default_stride,
1950 const expr_t &_mask = expr_t(), bool fill_mask0 = false) {
1951 auto mask = _mask;
1952 if (!mask.is_empty()) {
1953 if (all_of(mask, expr_t(true))) {
1954 mask = expr_t();
1955 } else if (all_of(mask, expr_t(false))) {
1956 // No need to store anything with a false mask.
1957 return stmt_t();
1958 }
1959 }
1960 return stmt_t(new store_t(buf, off, value, stride, mask, fill_mask0));
1961 }
1962
1963 bool is_equal(const object_impl_t &obj) const override {
1964 if (!obj.is<self_type>()) return false;
1965 auto &other = obj.as<self_type>();
1966
1967 return buf.is_equal(other.buf) && off.is_equal(other.off)
1968 && value.is_equal(other.value) && mask.is_equal(other.mask)
1969 && (stride == other.stride) && (fill_mask0 == other.fill_mask0);
1970 }
1971
1972 size_t get_hash() const override {
1973 return ir_utils::get_hash(buf, off, value, stride, mask, fill_mask0);
1974 }
1975
1976 bool has_default_stride() const { return stride == default_stride; }
1977
1978 IR_DECLARE_TRAVERSERS()
1979
1980 static const int default_stride = -1;
1981
1982 expr_t buf;
1983 expr_t off;
1984 expr_t value;
1985 int stride;
1986 expr_t mask;
1987 bool fill_mask0;
1988
1989private:
1990 store_t(const expr_t &_buf, const expr_t &_off, const expr_t &_value,
1991 int _stride, const expr_t &_mask, bool _fill_mask0)
1992 : stmt_impl_t(_type_info())
1993 , buf(_buf)
1994 , off(_off)
1995 , value(_value)
1996 , stride(_stride)
1997 , mask(_mask)
1998 , fill_mask0(_fill_mask0) {
1999 normalize_ptr(value.type(), buf, off);
2000 ir_assert(is_var(buf)) << buf;
2001 ir_assert(buf.type().is_ptr()) << buf;
2002 if (stride == value.type().scalar().size()) stride = default_stride;
2003 if (!mask.is_empty())
2004 ir_assert(mask.type() == type_t::_bool(value.type().elems()));
2005 }
2006};
2007
2008// Loop statement with unit increment.
2009// C++ equivalent:
2010// for (var = init; var < bound; var++) {
2011// body;
2012// }
2013// unroll specifies the unroll factor, unroll = 1 means no unrolling.
2014class for_t : public stmt_impl_t {
2015public:
2016 IR_DECL_STMT_TYPE_ID(for_t)
2017
2018 static stmt_t make(const expr_t &var, const expr_t &init,
2019 const expr_t &bound, const stmt_t &body = {}, int unroll = 1) {
2020 return stmt_t(new for_t(var, init, bound, body, unroll));
2021 }
2022
2023 bool is_equal(const object_impl_t &obj) const override {
2024 if (!obj.is<self_type>()) return false;
2025 auto &other = obj.as<self_type>();
2026
2027 return var.is_equal(other.var) && init.is_equal(other.init)
2028 && bound.is_equal(other.bound) && body.is_equal(other.body)
2029 && (unroll == other.unroll);
2030 }
2031
2032 size_t get_hash() const override {
2033 return ir_utils::get_hash(var, init, bound, body, unroll);
2034 }
2035
2036 IR_DECLARE_TRAVERSERS()
2037
2038 expr_t var;
2039 expr_t init;
2040 expr_t bound;
2041 stmt_t body;
2042 int unroll;
2043
2044private:
2045 for_t(const expr_t &var, const expr_t &init, const expr_t &bound,
2046 const stmt_t &body, int unroll)
2047 : stmt_impl_t(_type_info())
2048 , var(var)
2049 , init(init)
2050 , bound(bound)
2051 , body(body)
2052 , unroll(unroll) {}
2053};
2054
2055// If-else statement.
2056// C++ equivalent:
2057// if (cond) {
2058// body;
2059// } else {
2060// else_body;
2061// }
2062class if_t : public stmt_impl_t {
2063public:
2064 IR_DECL_STMT_TYPE_ID(if_t)
2065
2066 static stmt_t make(const expr_t &cond, const stmt_t &body,
2067 const stmt_t &else_body = stmt_t()) {
2068 return stmt_t(new if_t(cond, body, else_body));
2069 }
2070
2071 bool is_equal(const object_impl_t &obj) const override {
2072 if (!obj.is<self_type>()) return false;
2073 auto &other = obj.as<self_type>();
2074
2075 return cond.is_equal(other.cond) && body.is_equal(other.body)
2076 && else_body.is_equal(other.else_body);
2077 }
2078
2079 size_t get_hash() const override {
2080 return ir_utils::get_hash(cond, body, else_body);
2081 }
2082
2083 IR_DECLARE_TRAVERSERS()
2084
2085 expr_t cond;
2086 stmt_t body;
2087 stmt_t else_body;
2088
2089private:
2090 if_t(const expr_t &cond, const stmt_t &body, const stmt_t &else_body)
2091 : stmt_impl_t(_type_info())
2092 , cond(cond)
2093 , body(body)
2094 , else_body(else_body) {}
2095};
2096
2097// Let statement, used to bind a variable to a value within a scope.
2098// C++ equivalent:
2099// {
2100// var = value;
2101// body;
2102// }
2103class let_t : public stmt_impl_t {
2104public:
2105 IR_DECL_STMT_TYPE_ID(let_t)
2106
2107 static stmt_t make(
2108 const expr_t &var, const expr_t &value, const stmt_t &body = {}) {
2109 return stmt_t(new let_t(var, value, body));
2110 }
2111
2112 bool is_equal(const object_impl_t &obj) const override {
2113 if (!obj.is<self_type>()) return false;
2114 auto &other = obj.as<self_type>();
2115
2116 return var.is_equal(other.var) && value.is_equal(other.value)
2117 && body.is_equal(other.body);
2118 }
2119
2120 size_t get_hash() const override {
2121 return ir_utils::get_hash(var, value, body);
2122 }
2123
2124 IR_DECLARE_TRAVERSERS()
2125
2126 expr_t var;
2127 expr_t value;
2128 stmt_t body;
2129
2130private:
2131 let_t(const expr_t &var, const expr_t &value, const stmt_t &body)
2132 : stmt_impl_t(_type_info()), var(var), value(value), body(body) {
2133 if (!value.is_empty() && !is_const(value))
2134 ir_assert(var.type() == value.type());
2135 }
2136};
2137
2138// Statement label, specific to GEMM/convolution.
2139class stmt_label_t {
2140public:
2141 static stmt_label_t kernel(int index = -1) {
2142 return stmt_label_t(kind_t::_kernel, index);
2143 }
2144 static stmt_label_t compute_loop(int index = -1) {
2145 return stmt_label_t(kind_t::_compute_loop, index);
2146 }
2147 static stmt_label_t c_store(int index = -1) {
2148 return stmt_label_t(kind_t::_c_store, index);
2149 }
2150 static stmt_label_t c_zero_out(int index = -1) {
2151 return stmt_label_t(kind_t::_c_zero_out, index);
2152 }
2153 static stmt_label_t b_reduced_zero_out(int index = -1) {
2154 return stmt_label_t(kind_t::_b_reduced_zero_out, index);
2155 }
2156 static stmt_label_t g2s_load(int index = -1) {
2157 return stmt_label_t(kind_t::_g2s_load, index);
2158 }
2159 static stmt_label_t g2s_store(int index = -1) {
2160 return stmt_label_t(kind_t::_g2s_store, index);
2161 }
2162 static stmt_label_t g2r_load(int index = -1) {
2163 return stmt_label_t(kind_t::_g2r_load, index);
2164 }
2165 static stmt_label_t s2r_load(int index = -1) {
2166 return stmt_label_t(kind_t::_s2r_load, index);
2167 }
2168 static stmt_label_t prefetch(int index = -1) {
2169 return stmt_label_t(kind_t::_prefetch, index);
2170 }
2171 static stmt_label_t mul(int index = -1) {
2172 return stmt_label_t(kind_t::_mul, index);
2173 }
2174
2175 bool operator==(const stmt_label_t &other) const {
2176 if (kind_ != other.kind_) return false;
2177 if (index_ == -1 || other.index_ == -1) return true;
2178 return index_ == other.index_;
2179 }
2180
2181 size_t get_hash() const { return ir_utils::get_hash(kind_, index_); }
2182
2183 std::string str() const {
2184 switch (kind_) {
2185#define CASE(kind) \
2186 case kind_t::_##kind: return #kind
2187 CASE(kernel);
2188 CASE(compute_loop);
2189 CASE(c_store);
2190 CASE(c_zero_out);
2191 CASE(g2r_load);
2192 CASE(g2s_load);
2193 CASE(g2s_store);
2194 CASE(s2r_load);
2195 CASE(prefetch);
2196 CASE(mul);
2197#undef CASE
2198 default: ir_error_not_expected();
2199 }
2200 return {};
2201 }
2202
2203private:
// Kinds of labelled statement groups. The underlying type is spelled out; it
// is the same int a scoped enum gets by default.
enum class kind_t : int {
    _undef, // No label; used by the default constructor.
    _kernel, // The whole kernel.
    _compute_loop, // The compute loop.
    _c_store, // Store of C from GRF to GMEM.
    _c_zero_out, // Zero-initialization of C.
    _b_reduced_zero_out, // Zero-initialization of the reduced-B buffer.
    _g2r_load, // GMEM -> GRF load feeding the multiplication.
    _g2s_load, // GMEM -> GRF load, first half of a GMEM -> SLM copy.
    _g2s_store, // GRF -> SLM store, second half of a GMEM -> SLM copy.
    _s2r_load, // SLM -> GRF load feeding the multiplication.
    _prefetch, // GMEM prefetch.
    _mul, // The multiplication itself.
};
2218
2219 stmt_label_t() : kind_(kind_t::_undef), index_(-1) {}
2220 stmt_label_t(kind_t kind, int index) : kind_(kind), index_(index) {}
2221
2222 kind_t kind_;
2223 int index_; // Used to differentiate groups with the same kind.
2224};
2225
2226inline std::ostream &operator<<(std::ostream &out, const stmt_label_t &label) {
2227 out << label.str();
2228 return out;
2229}
2230
2231// Statement group, used to assign a label to a group of statements.
2232class stmt_group_t : public stmt_impl_t {
2233public:
2234 IR_DECL_STMT_TYPE_ID(stmt_group_t)
2235
2236 static stmt_t make(const stmt_label_t &label, const stmt_t &body) {
2237 return stmt_t(new stmt_group_t(label, body));
2238 }
2239
2240 bool is_equal(const object_impl_t &obj) const override {
2241 if (!obj.is<self_type>()) return false;
2242 auto &other = obj.as<self_type>();
2243
2244 return (label == other.label) && body.is_equal(other.body);
2245 }
2246
2247 size_t get_hash() const override { return ir_utils::get_hash(label, body); }
2248
2249 IR_DECLARE_TRAVERSERS()
2250
2251 stmt_label_t label;
2252 stmt_t body;
2253
2254private:
2255 stmt_group_t(const stmt_label_t &label, const stmt_t &body)
2256 : stmt_impl_t(_type_info()), label(label), body(body) {}
2257};
2258
2259// Statement sequence, allows combining two statements.
2260// C++ equivalent:
2261// {
2262// head;
2263// tail;
2264// }
2265class stmt_seq_t : public stmt_impl_t {
2266public:
2267 IR_DECL_STMT_TYPE_ID(stmt_seq_t)
2268
2269 static stmt_t make(const stmt_t &head, const stmt_t &tail) {
2270 return stmt_t(new stmt_seq_t(head, tail));
2271 }
2272
2273 bool is_equal(const object_impl_t &obj) const override {
2274 if (!obj.is<self_type>()) return false;
2275 auto &other = obj.as<self_type>();
2276
2277 return head.is_equal(other.head) && tail.is_equal(other.tail);
2278 }
2279
2280 size_t get_hash() const override { return ir_utils::get_hash(head, tail); }
2281
2282 IR_DECLARE_TRAVERSERS()
2283
2284 stmt_t head;
2285 stmt_t tail;
2286
2287private:
2288 stmt_seq_t(const stmt_t &head, const stmt_t &tail)
2289 : stmt_impl_t(_type_info()), head(head), tail(tail) {}
2290};
2291
2292inline stmt_t stmt_t::append(const stmt_t &s) const {
2293 if (is_empty()) return s;
2294 return stmt_seq_t::make(*this, s);
2295}
2296
2297// Function call attribute.
2298class func_call_attr_impl_t : public object_impl_t {
2299public:
2300 func_call_attr_impl_t(type_info_t type_info) : object_impl_t(type_info) {}
2301};
2302
2303class func_call_attr_t : public object_t {
2304public:
2305 using object_t::object_t;
2306
2307 func_call_attr_t() = default;
2308 func_call_attr_t(const object_t &obj) : object_t(obj) {}
2309 func_call_attr_t(object_t &&obj) : object_t(obj) {}
2310 func_call_attr_t &operator=(const object_t &obj) {
2311 object_t::operator=(obj);
2312 return *this;
2313 }
2314 func_call_attr_t &operator=(object_t &&obj) {
2315 object_t::operator=(obj);
2316 return *this;
2317 }
2318
2319 // Returns a function call with the attribute applied. The input statement
2320 // must be a function call.
2321 stmt_t apply_to(const stmt_t &s) const;
2322
2323private:
2324#ifdef SANITY_CHECK
2325 void sanity_check() const override {
2326 ir_assert(dynamic_cast<const func_call_attr_impl_t *>(impl()) == impl())
2327 << object_t(impl());
2328 }
2329#endif
2330};
2331
2332// Instruction modifier, relies on nGEN API.
2333class instruction_modifier_attr_t : public func_call_attr_impl_t {
2334public:
2335 IR_DECL_TYPE_ID(instruction_modifier_attr_t)
2336
2337 static func_call_attr_t make(const ngen_proxy::InstructionModifier &mod) {
2338 return func_call_attr_t(new instruction_modifier_attr_t(mod));
2339 }
2340
2341 bool is_equal(const object_impl_t &obj) const override {
2342 if (!obj.is<self_type>()) return false;
2343 auto &other = obj.as<self_type>();
2344
2345 return mod == other.mod;
2346 }
2347
2348 size_t get_hash() const override { return ir_utils::get_hash(mod); }
2349
2350 std::string str() const override {
2351 std::ostringstream oss;
2352 oss << "{";
2353 bool is_first = true;
2354 auto append = [&](const std::string &s) {
2355 if (!is_first) oss << ", ";
2356 oss << s;
2357 is_first = false;
2358 };
2359 if (mod.is_atomic) append("Atomic");
2360 if (!mod.sbid.is_empty()) {
2361 append(std::string("$") + std::to_string(mod.sbid.token));
2362 }
2363 oss << "}";
2364 return oss.str();
2365 }
2366
2367 ngen_proxy::InstructionModifier mod;
2368
2369private:
2370 instruction_modifier_attr_t(const ngen_proxy::InstructionModifier &mod)
2371 : func_call_attr_impl_t(_type_info()), mod(mod) {}
2372};
2373
2374// Base class for function IR objects.
2375class func_impl_t : public object_impl_t {
2376public:
2377 IR_DECL_TYPE_ID(func_impl_t)
2378
2379 func_impl_t(type_info_t type_info) : object_impl_t(type_info) {}
2380
2381 stmt_t call(const std::vector<expr_t> &args,
2382 const func_call_attr_t &attr = {}) const;
2383
2384 IR_DECLARE_TRAVERSERS()
2385};
2386
2387// Wrapper for IR function objects.
2388class func_t : public object_t {
2389public:
2390 using object_t::object_t;
2391
2392 func_t() = default;
2393 func_t(const object_t &obj) : object_t(obj) {}
2394 func_t(object_t &&obj) : object_t(obj) {}
2395 func_t &operator=(const object_t &obj) {
2396 object_t::operator=(obj);
2397 return *this;
2398 }
2399 func_t &operator=(object_t &&obj) {
2400 object_t::operator=(obj);
2401 return *this;
2402 }
2403
2404 stmt_t call(const std::vector<expr_t> &args = {},
2405 const func_call_attr_t &attr = {}) const {
2406 return ((const func_impl_t *)impl())->call(args, attr);
2407 }
2408
2409private:
2410#ifdef SANITY_CHECK
2411 void sanity_check() const override {
2412 ir_assert(dynamic_cast<const func_impl_t *>(impl()) == impl())
2413 << object_t(impl());
2414 }
2415#endif
2416};
2417
2418// Function call.
2419class func_call_t : public stmt_impl_t {
2420public:
2421 IR_DECL_STMT_TYPE_ID(func_call_t)
2422
2423 static stmt_t make(const func_t &func, const std::vector<expr_t> &args,
2424 const func_call_attr_t &attr = {}) {
2425 return stmt_t(new func_call_t(func, args, attr));
2426 }
2427
2428 bool is_equal(const object_impl_t &obj) const override {
2429 if (!obj.is<self_type>()) return false;
2430 auto &other = obj.as<self_type>();
2431
2432 return func.is_equal(other.func) && ir_utils::is_equal(args, other.args)
2433 && attr.is_equal(other.attr);
2434 }
2435
2436 size_t get_hash() const override {
2437 return ir_utils::get_hash(func, args, attr);
2438 }
2439
2440 IR_DECLARE_TRAVERSERS()
2441
2442 func_t func;
2443 std::vector<expr_t> args;
2444 func_call_attr_t attr;
2445
2446private:
2447 func_call_t(const func_t &func, const std::vector<expr_t> &args,
2448 const func_call_attr_t &attr)
2449 : stmt_impl_t(_type_info()), func(func), args(args), attr(attr) {
2450 ir_assert(!func.is_empty());
2451 }
2452};
2453
2454inline stmt_t func_impl_t::call(
2455 const std::vector<expr_t> &args, const func_call_attr_t &attr) const {
2456 return func_call_t::make(this, args, attr);
2457}
2458
2459inline stmt_t func_call_attr_t::apply_to(const stmt_t &s) const {
2460 auto &c = s.as<func_call_t>();
2461 ir_assert(c.attr.is_empty())
2462 << "Merging of attributes is not supported: " << s;
2463 return func_call_t::make(c.func, c.args, *this);
2464}
2465
2466template <typename F>
2467inline bool is_func_call(const stmt_t &s) {
2468 auto *c = s.as_ptr<func_call_t>();
2469 if (!c) return false;
2470 return c->func.is<F>();
2471}
2472
2473// Generic function with a name.
2474class builtin_t : public func_impl_t {
2475public:
2476 IR_DECL_DERIVED_TYPE_ID(builtin_t, func_impl_t)
2477
2478 static func_t make(const std::string &name) {
2479 return func_t(new builtin_t(name));
2480 }
2481
2482 bool is_equal(const object_impl_t &obj) const override {
2483 if (!obj.is<self_type>()) return false;
2484 auto &other = obj.as<self_type>();
2485
2486 return name == other.name;
2487 }
2488
2489 size_t get_hash() const override { return ir_utils::get_hash(name); }
2490
2491 std::string str() const override { return name; }
2492
2493 std::string name;
2494
2495private:
2496 builtin_t(const std::string &name)
2497 : func_impl_t(_type_info()), name(name) {}
2498};
2499
2500#ifndef SANITY_CHECK
2501// The following types are intrusive pointers and, as such, should have the same
2502// size as a pointer.
2503static_assert(sizeof(object_t) <= sizeof(void *),
2504 "intrusive pointer type object_t size is greater than void * size.");
2505static_assert(sizeof(expr_t) <= sizeof(void *),
2506 "intrusive pointer type expr_t size is greater than void * size.");
2507static_assert(sizeof(stmt_t) <= sizeof(void *),
2508 "intrusive pointer type stmt_t size is greater than void * size.");
2509#endif
2510
2511} // namespace jit
2512} // namespace gpu
2513} // namespace impl
2514} // namespace dnnl
2515
2516#endif
2517