1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
5const char *xe_hpc_systolic_gemm_copy_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2021-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#if ELEMENT_SIZE == 2 )==""\n"
21R"==(#pragma OPENCL EXTENSION cl_intel_subgroups_short : enable )==""\n"
22R"==(#define ELEMENT ushort )==""\n"
23R"==(#define ELEMENT2 ushort2 )==""\n"
24R"==(#define ELEMENT4 ushort4 )==""\n"
25R"==(#define ELEMENT_WORD ushort )==""\n"
26R"==(#define ELEMENT_WORD4 ushort4 )==""\n"
27R"==(#define ELEMENT_INT ushort2 )==""\n"
28R"==(#define ELEMENT_INT4 ushort8 )==""\n"
29R"==(#define VLOAD_ELEMENT_INT vload2 )==""\n"
30R"==(#define ELEMENTS_PER_INT 2 )==""\n"
31R"==(#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_us2 )==""\n"
32R"==(#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_us4 )==""\n"
33R"==(#define BLOCK_READ_ELEMENT_WORD intel_sub_group_block_read_us )==""\n"
34R"==(#define MASKED_BLOCK_READ_ELEMENT_WORD masked_block_read_element )==""\n"
35R"==(#define BLOCK_WRITE_ELEMENT_WORD4 intel_sub_group_block_write_us4 )==""\n"
36R"==(#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_us8 )==""\n"
37R"==(#elif ELEMENT_SIZE == 1 )==""\n"
38R"==(#define ELEMENT uchar )==""\n"
39R"==(#define ELEMENT2 uchar2 )==""\n"
40R"==(#define ELEMENT4 uchar4 )==""\n"
41R"==(#define ELEMENT_WORD uchar2 )==""\n"
42R"==(#define ELEMENT_WORD4 uchar8 )==""\n"
43R"==(#define ELEMENT_INT uchar4 )==""\n"
44R"==(#define ELEMENT_INT4 uchar16 )==""\n"
45R"==(#define VLOAD_ELEMENT_INT vload4 )==""\n"
46R"==(#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_uc2 )==""\n"
47R"==(#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_uc4 )==""\n"
48R"==(#define BLOCK_READ_ELEMENT_WORD intel_sub_group_block_read_uc2 )==""\n"
49R"==(#define MASKED_BLOCK_READ_ELEMENT_WORD masked_block_read_element2 )==""\n"
50R"==(#define BLOCK_WRITE_ELEMENT_WORD4 intel_sub_group_block_write_uc8 )==""\n"
51R"==(#define BLOCK_WRITE_ELEMENT_INT2 intel_sub_group_block_write_uc8 )==""\n"
52R"==(#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_uc16 )==""\n"
53R"==(#define ELEMENTS_PER_INT 4 )==""\n"
54R"==(#define SUM_T int )==""\n"
55R"==(#define SUM_T2 int2 )==""\n"
56R"==(#define SUM_T4 int4 )==""\n"
57R"==(#define CONVERT_SUM_T convert_int )==""\n"
58R"==(#define CONVERT_SUM_T2 convert_int2 )==""\n"
59R"==(#define CONVERT_SUM_T4 convert_int4 )==""\n"
60R"==(#if COPY_SIGNED )==""\n"
61R"==(#define AS_SIGNED_ELEMENT as_char )==""\n"
62R"==(#define AS_SIGNED_ELEMENT4 as_char4 )==""\n"
63R"==(#define AS_SIGNED_ELEMENT_WORD as_char2 )==""\n"
64R"==(#define AS_SIGNED_ELEMENT_INT as_char4 )==""\n"
65R"==(#define SIGNED_ELEMENT_WORD char2 )==""\n"
66R"==(#define SIGNED_ELEMENT_INT char4 )==""\n"
67R"==(#else )==""\n"
68R"==(#define AS_SIGNED_ELEMENT as_uchar )==""\n"
69R"==(#define AS_SIGNED_ELEMENT4 as_uchar4 )==""\n"
70R"==(#define AS_SIGNED_ELEMENT_WORD as_uchar2 )==""\n"
71R"==(#define AS_SIGNED_ELEMENT_INT as_uchar4 )==""\n"
72R"==(#define SIGNED_ELEMENT_WORD uchar2 )==""\n"
73R"==(#define SIGNED_ELEMENT_INT uchar4 )==""\n"
74R"==(#endif )==""\n"
75R"==(#else )==""\n"
76R"==(#error Unsupported element size. )==""\n"
77R"==(#endif )==""\n"
78R"==(#if !COPY_A && !COPY_B )==""\n"
79R"==(#error Source matrix not defined. )==""\n"
80R"==(#endif )==""\n"
81R"==(inline ELEMENT masked_block_read_element(global ELEMENT *p, int rem) { )==""\n"
82R"==(ELEMENT v; )==""\n"
83R"==(int lid = get_sub_group_local_id(); )==""\n"
84R"==(int sg = get_sub_group_size(); )==""\n"
85R"==(v = (lid < rem) ? p[lid] : 0; )==""\n"
86R"==(return v; )==""\n"
87R"==(} )==""\n"
88R"==(inline ELEMENT2 masked_block_read_element2(global ELEMENT *p, int rem) { )==""\n"
89R"==(ELEMENT2 v; )==""\n"
90R"==(int lid = get_sub_group_local_id(); )==""\n"
91R"==(int sg = get_sub_group_size(); )==""\n"
92R"==(v.s0 = (lid < rem) ? p[lid] : 0; )==""\n"
93R"==(v.s1 = (lid + sg < rem) ? p[lid + sg] : 0; )==""\n"
94R"==(return v; )==""\n"
95R"==(} )==""\n"
96R"==(inline ELEMENT4 masked_block_read_element4(global ELEMENT *p, int rem) { )==""\n"
97R"==(ELEMENT4 v; )==""\n"
98R"==(int lid = get_sub_group_local_id(); )==""\n"
99R"==(int sg = get_sub_group_size(); )==""\n"
100R"==(v.s0 = (lid < rem) ? p[lid] : 0; )==""\n"
101R"==(v.s1 = (lid + sg < rem) ? p[lid + sg] : 0; )==""\n"
102R"==(v.s2 = (lid + 2 * sg < rem) ? p[lid + 2 * sg] : 0; )==""\n"
103R"==(v.s3 = (lid + 3 * sg < rem) ? p[lid + 3 * sg] : 0; )==""\n"
104R"==(return v; )==""\n"
105R"==(} )==""\n"
106R"==(__attribute__((overloadable)) inline int sum(int v) { )==""\n"
107R"==(return sub_group_reduce_add(v); )==""\n"
108R"==(} )==""\n"
109R"==(__attribute__((overloadable)) inline int sum(int2 v) { )==""\n"
110R"==(return sub_group_reduce_add(v.s0) + sub_group_reduce_add(v.s1); )==""\n"
111R"==(} )==""\n"
112R"==(__attribute__((overloadable)) inline int sum(int4 v) { )==""\n"
113R"==(return sub_group_reduce_add(v.s0) + sub_group_reduce_add(v.s1) )==""\n"
114R"==(+ sub_group_reduce_add(v.s2) + sub_group_reduce_add(v.s3); )==""\n"
115R"==(} )==""\n"
116R"==(void dummy_dpas() { )==""\n"
117R"==(if (get_sub_group_local_id() >= 16) { )==""\n"
118R"==(int __builtin_IB_sub_group_idpas_s8_s8_8_1(int, int, int8) )==""\n"
119R"==(__attribute__((const)); )==""\n"
120R"==(global volatile int *_; )==""\n"
121R"==(int z = __builtin_IB_sub_group_idpas_s8_s8_8_1(0, _[0], 1); )==""\n"
122R"==(for (int i = 0; i < z; i++) )==""\n"
123R"==((void)_[0]; )==""\n"
124R"==(} )==""\n"
125R"==(} )==""\n"
126R"==(#define DUMMY_DPAS /*dummy_dpas()*/ )==""\n"
127R"==(#if ELEMENT_SIZE == 2 )==""\n"
128R"==(#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \ )==""\n"
129R"==(if ((2 * cc + 1) < crem) { \ )==""\n"
130R"==(if (lid < rrem) regs[cc] = vload2(0, p); \ )==""\n"
131R"==(} else if ((2 * cc) < crem) { \ )==""\n"
132R"==(if (lid < rrem) regs[cc].s0 = *(p); \ )==""\n"
133R"==(} )==""\n"
134R"==(#elif ELEMENT_SIZE == 1 )==""\n"
135R"==(#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \ )==""\n"
136R"==(if ((4 * cc + 3) < crem) { \ )==""\n"
137R"==(if (lid < rrem) regs[cc] = vload4(0, p); \ )==""\n"
138R"==(} else if ((4 * cc + 2) < crem) { \ )==""\n"
139R"==(if (lid < rrem) regs[cc].s012 = vload3(0, p); \ )==""\n"
140R"==(} else if ((4 * cc + 1) < crem) { \ )==""\n"
141R"==(if (lid < rrem) regs[cc].s01 = vload2(0, p); \ )==""\n"
142R"==(} else if (4 * cc < crem) { \ )==""\n"
143R"==(if (lid < rrem) regs[cc].s0 = *(p); \ )==""\n"
144R"==(} )==""\n"
145R"==(#endif )==""\n"
146R"==(#if COPY_A )==""\n"
147R"==(#define UNROLL_M 64 )==""\n"
148R"==(#define UNROLL_K (32 / ELEMENT_SIZE) )==""\n"
149R"==(#if COPY_SUM )==""\n"
150R"==(#define GET_A_SUM_ADDRESS \ )==""\n"
151R"==(global int *a_sum = (global int *)(a_packed + offseta_packed \ )==""\n"
152R"==(+ (m0 + UNROLL_M) * lda_packed - UNROLL_M * sizeof(int)); )==""\n"
153R"==(#else )==""\n"
154R"==(#define GET_A_SUM_ADDRESS )==""\n"
155R"==(#endif )==""\n"
156R"==(#if COPY_CLEAR_SUM )==""\n"
157R"==(__attribute__((intel_reqd_sub_group_size(16))) kernel void )==""\n"
158R"==(xe_hpc_systolic_gemm_copy(long m, long k, global ELEMENT *a_packed, )==""\n"
159R"==(int offseta_packed, int lda_packed) { )==""\n"
160R"==(uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_M; )==""\n"
161R"==(GET_A_SUM_ADDRESS; )==""\n"
162R"==(uint4 zero = 0; )==""\n"
163R"==(intel_sub_group_block_write4(a_sum, zero); )==""\n"
164R"==(} )==""\n"
165R"==(#elif !COPY_TRANS )==""\n"
166R"==(#if ELEMENT_SIZE == 2 )==""\n"
167R"==(#define REPACK_REG(rr, cc) \ )==""\n"
168R"==(blk_r[rr].s##cc = (((uint)c[2 * cc + 1].s##rr) << 16) | c[2 * cc].s##rr )==""\n"
169R"==(#elif ELEMENT_SIZE == 1 )==""\n"
170R"==(#define REPACK_REG(rr, cc) \ )==""\n"
171R"==(blk_r[rr].s##cc = (((uint)c[4 * cc + 3].s##rr) << 24) \ )==""\n"
172R"==(| (((uint)c[4 * cc + 2].s##rr) << 16) \ )==""\n"
173R"==(| (((uint)c[4 * cc + 1].s##rr) << 8) | c[4 * cc].s##rr )==""\n"
174R"==(#endif )==""\n"
175R"==(#define REPACK_CC(cc) \ )==""\n"
176R"==(REPACK_REG(0, cc); \ )==""\n"
177R"==(REPACK_REG(1, cc); \ )==""\n"
178R"==(REPACK_REG(2, cc); \ )==""\n"
179R"==(REPACK_REG(3, cc) )==""\n"
180R"==(#define REPACK \ )==""\n"
181R"==(REPACK_CC(0); \ )==""\n"
182R"==(REPACK_CC(1); \ )==""\n"
183R"==(REPACK_CC(2); \ )==""\n"
184R"==(REPACK_CC(3); \ )==""\n"
185R"==(REPACK_CC(4); \ )==""\n"
186R"==(REPACK_CC(5); \ )==""\n"
187R"==(REPACK_CC(6); \ )==""\n"
188R"==(REPACK_CC(7) )==""\n"
189R"==(__attribute__((intel_reqd_sub_group_size(16))) kernel void )==""\n"
190R"==(xe_hpc_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta, )==""\n"
191R"==(long lda, global ELEMENT *a_packed, int offseta_packed, )==""\n"
192R"==(int lda_packed) { )==""\n"
193R"==(int lid = get_sub_group_local_id(); )==""\n"
194R"==(uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_M; )==""\n"
195R"==(uint k0 = get_global_id(1) * UNROLL_K; )==""\n"
196R"==(int mrem = m - m0; )==""\n"
197R"==(int krem = k - k0; )==""\n"
198R"==(bool aligned = ((as_long(a) | lda | offseta) & (ELEMENTS_PER_INT - 1)) == 0; )==""\n"
199R"==(if (mrem <= 0 || krem <= 0) return; )==""\n"
200R"==(GET_A_SUM_ADDRESS; )==""\n"
201R"==(a += offseta + m0 + k0 * lda; )==""\n"
202R"==(a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M; )==""\n"
203R"==(ELEMENT4 c[UNROLL_K]; )==""\n"
204R"==(if (mrem >= UNROLL_M && krem >= UNROLL_K && aligned) { )==""\n"
205R"==(for (int h = 0; h < UNROLL_K; h++) )==""\n"
206R"==(c[h] = BLOCK_READ_ELEMENT4(a + h * lda); )==""\n"
207R"==(} else { )==""\n"
208R"==(for (int h = 0; h < UNROLL_K; h++) )==""\n"
209R"==(if (h < krem) )==""\n"
210R"==(c[h] = masked_block_read_element4(a + h * lda, mrem); )==""\n"
211R"==(else )==""\n"
212R"==(c[h] = 0; )==""\n"
213R"==(} )==""\n"
214R"==(uint8 blk_r[UNROLL_M / 16]; )==""\n"
215R"==(REPACK; )==""\n"
216R"==(for (int rr = 0; rr < UNROLL_M / 16; rr++) )==""\n"
217R"==(intel_sub_group_block_write8( )==""\n"
218R"==((global uint *)(a_packed + rr * UNROLL_K * 16), blk_r[rr]); )==""\n"
219R"==(#if COPY_SUM )==""\n"
220R"==(SUM_T4 sum = 0; )==""\n"
221R"==(for (int h = 0; h < UNROLL_K; h++) )==""\n"
222R"==(sum += CONVERT_SUM_T4(AS_SIGNED_ELEMENT4(c[h])); )==""\n"
223R"==(atomic_add(a_sum + lid, sum.s0); )==""\n"
224R"==(atomic_add(a_sum + lid + 16, sum.s1); )==""\n"
225R"==(atomic_add(a_sum + lid + 32, sum.s2); )==""\n"
226R"==(atomic_add(a_sum + lid + 48, sum.s3); )==""\n"
227R"==(#endif )==""\n"
228R"==(DUMMY_DPAS; )==""\n"
229R"==(} )==""\n"
230R"==(#else /* COPY_TRANS */ )==""\n"
231R"==(__attribute__((intel_reqd_workgroup_walk_order(1, 0))) )==""\n"
232R"==(__attribute__((intel_reqd_sub_group_size(16))) kernel void )==""\n"
233R"==(xe_hpc_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta, )==""\n"
234R"==(long lda, global ELEMENT *a_packed, int offseta_packed, )==""\n"
235R"==(int lda_packed) { )==""\n"
236R"==(int lid = get_sub_group_local_id(); )==""\n"
237R"==(uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_M; )==""\n"
238R"==(uint k0 = get_global_id(1) * UNROLL_K; )==""\n"
239R"==(int mrem = m - m0; )==""\n"
240R"==(int krem = k - k0; )==""\n"
241R"==(if (mrem <= 0 || krem <= 0) return; )==""\n"
242R"==(GET_A_SUM_ADDRESS; )==""\n"
243R"==(a += offseta + m0 * lda + k0; )==""\n"
244R"==(a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M; )==""\n"
245R"==(#if COPY_SUM )==""\n"
246R"==(SUM_T sum[UNROLL_M / 16] = {0}; )==""\n"
247R"==(#endif )==""\n"
248R"==(for (int rr = 0; rr < UNROLL_M / 16; rr++, mrem -= 16) { )==""\n"
249R"==(ELEMENT_INT regs[8]; )==""\n"
250R"==(if (mrem >= UNROLL_M && krem >= UNROLL_K) { )==""\n"
251R"==(for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) )==""\n"
252R"==(regs[cc] = VLOAD_ELEMENT_INT(0, )==""\n"
253R"==(a + ((rr * 16) + lid) * lda + (cc * ELEMENTS_PER_INT)); )==""\n"
254R"==(} else { )==""\n"
255R"==(for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) { )==""\n"
256R"==(regs[cc] = 0; )==""\n"
257R"==(PARTIAL_LOAD(regs, mrem, krem, cc, )==""\n"
258R"==(a + ((rr * 16) + lid) * lda + (cc * ELEMENTS_PER_INT)); )==""\n"
259R"==(} )==""\n"
260R"==(} )==""\n"
261R"==(uint8 blk_r; )==""\n"
262R"==(blk_r.s0 = as_uint(regs[0]); )==""\n"
263R"==(blk_r.s1 = as_uint(regs[1]); )==""\n"
264R"==(blk_r.s2 = as_uint(regs[2]); )==""\n"
265R"==(blk_r.s3 = as_uint(regs[3]); )==""\n"
266R"==(blk_r.s4 = as_uint(regs[4]); )==""\n"
267R"==(blk_r.s5 = as_uint(regs[5]); )==""\n"
268R"==(blk_r.s6 = as_uint(regs[6]); )==""\n"
269R"==(blk_r.s7 = as_uint(regs[7]); )==""\n"
270R"==(#if COPY_SUM )==""\n"
271R"==(for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) { )==""\n"
272R"==(sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s0)); )==""\n"
273R"==(sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s1)); )==""\n"
274R"==(sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s2)); )==""\n"
275R"==(sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s3)); )==""\n"
276R"==(} )==""\n"
277R"==(#endif )==""\n"
278R"==(intel_sub_group_block_write8( )==""\n"
279R"==((global uint *)(a_packed + rr * UNROLL_K * 16), blk_r); )==""\n"
280R"==(} )==""\n"
281R"==(#if COPY_SUM )==""\n"
282R"==(atomic_add(a_sum + lid, sum[0]); )==""\n"
283R"==(atomic_add(a_sum + lid + 16, sum[1]); )==""\n"
284R"==(atomic_add(a_sum + lid + 32, sum[2]); )==""\n"
285R"==(atomic_add(a_sum + lid + 48, sum[3]); )==""\n"
286R"==(#endif )==""\n"
287R"==(DUMMY_DPAS; )==""\n"
288R"==(} )==""\n"
289R"==(#endif /* !COPY_TRANS */ )==""\n"
290R"==(#endif /* COPY_A */ )==""\n"
291R"==(#if COPY_B )==""\n"
292R"==(#define UNROLL_K (32 / ELEMENT_SIZE) )==""\n"
293R"==(#if ELEMENT_SIZE == 2 )==""\n"
294R"==(#define REPACK_CC_WORD(cc) \ )==""\n"
295R"==(do { \ )==""\n"
296R"==(colgroups[cc].s0 = cols[cc * 4]; \ )==""\n"
297R"==(colgroups[cc].s1 = cols[cc * 4 + 1]; \ )==""\n"
298R"==(colgroups[cc].s2 = cols[cc * 4 + 2]; \ )==""\n"
299R"==(colgroups[cc].s3 = cols[cc * 4 + 3]; \ )==""\n"
300R"==(} while (false) )==""\n"
301R"==(#define REPACK_CC(cc) \ )==""\n"
302R"==(do { \ )==""\n"
303R"==(colgroups[cc].s01 = cols[cc * 4]; \ )==""\n"
304R"==(colgroups[cc].s23 = cols[cc * 4 + 1]; \ )==""\n"
305R"==(colgroups[cc].s45 = cols[cc * 4 + 2]; \ )==""\n"
306R"==(colgroups[cc].s67 = cols[cc * 4 + 3]; \ )==""\n"
307R"==(} while (false) )==""\n"
308R"==(#elif ELEMENT_SIZE == 1 )==""\n"
309R"==(#define REPACK_CC_WORD(cc) \ )==""\n"
310R"==(do { \ )==""\n"
311R"==(colgroups[cc].s01 = cols[cc * 4]; \ )==""\n"
312R"==(colgroups[cc].s23 = cols[cc * 4 + 1]; \ )==""\n"
313R"==(colgroups[cc].s45 = cols[cc * 4 + 2]; \ )==""\n"
314R"==(colgroups[cc].s67 = cols[cc * 4 + 3]; \ )==""\n"
315R"==(} while (false) )==""\n"
316R"==(#define REPACK_CC(cc) \ )==""\n"
317R"==(do { \ )==""\n"
318R"==(colgroups[cc].s0123 = cols[cc * 4]; \ )==""\n"
319R"==(colgroups[cc].s4567 = cols[cc * 4 + 1]; \ )==""\n"
320R"==(colgroups[cc].s89ab = cols[cc * 4 + 2]; \ )==""\n"
321R"==(colgroups[cc].scdef = cols[cc * 4 + 3]; \ )==""\n"
322R"==(} while (false) )==""\n"
323R"==(#define REPACK_CC2(cc) \ )==""\n"
324R"==(do { \ )==""\n"
325R"==(colgroups[cc].s0246 = cols[cc * 2]; \ )==""\n"
326R"==(colgroups[cc].s1357 = cols2[cc * 2]; \ )==""\n"
327R"==(colgroups[cc].s8ace = cols[cc * 2 + 1]; \ )==""\n"
328R"==(colgroups[cc].s9bdf = cols2[cc * 2 + 1]; \ )==""\n"
329R"==(} while (false) )==""\n"
330R"==(#endif )==""\n"
331R"==(#if COPY_SUM )==""\n"
332R"==(#define GET_B_SUM_ADDRESS \ )==""\n"
333R"==(global int *b_sum = (global int *)(b_packed + offsetb_packed \ )==""\n"
334R"==(+ (n0 + UNROLL_N) * ldb_packed - UNROLL_N * sizeof(int)); )==""\n"
335R"==(#else )==""\n"
336R"==(#define GET_B_SUM_ADDRESS )==""\n"
337R"==(#endif )==""\n"
338R"==(#if COPY_CLEAR_SUM )==""\n"
339R"==(__attribute__((intel_reqd_sub_group_size(16))) kernel void )==""\n"
340R"==(xe_hpc_systolic_gemm_copy(long k, long n, global ELEMENT *b_packed, )==""\n"
341R"==(int offsetb_packed, int ldb_packed) { )==""\n"
342R"==(uint n0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_N; )==""\n"
343R"==(GET_B_SUM_ADDRESS; )==""\n"
344R"==(uint2 zero = 0; )==""\n"
345R"==(intel_sub_group_block_write2(b_sum, zero); )==""\n"
346R"==(#if UNROLL_N > 32 )==""\n"
347R"==(intel_sub_group_block_write(b_sum + 32, zero.s0); )==""\n"
348R"==(#endif )==""\n"
349R"==(} )==""\n"
350R"==(#elif !COPY_TRANS )==""\n"
351R"==(__attribute__((intel_reqd_sub_group_size(16))) kernel void )==""\n"
352R"==(xe_hpc_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb, )==""\n"
353R"==(long ldb, global ELEMENT *b_packed, int offsetb_packed, )==""\n"
354R"==(int ldb_packed) { )==""\n"
355R"==(int lid = get_sub_group_local_id(); )==""\n"
356R"==(uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_K; )==""\n"
357R"==(uint n0 = get_global_id(1) * UNROLL_N; )==""\n"
358R"==(int krem = k - k0; )==""\n"
359R"==(int nrem = n - n0; )==""\n"
360R"==(bool aligned = ((as_long(b) | ldb | offsetb) & (ELEMENTS_PER_INT - 1)) == 0; )==""\n"
361R"==(if (nrem <= 0 || krem <= 0) return; )==""\n"
362R"==(GET_B_SUM_ADDRESS; )==""\n"
363R"==(b += offsetb + k0 + n0 * ldb; )==""\n"
364R"==(b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N; )==""\n"
365R"==(#define UNROLL_N_CHUNK (UNROLL_N / 2) )==""\n"
366R"==(#if COPY_SUM )==""\n"
367R"==(SUM_T sums[UNROLL_N]; )==""\n"
368R"==(#endif )==""\n"
369R"==(ELEMENT_WORD cols[UNROLL_N / 2]; )==""\n"
370R"==(for (int c0 = 0; c0 < UNROLL_N; )==""\n"
371R"==(c0 += UNROLL_N_CHUNK, nrem -= UNROLL_N_CHUNK) { )==""\n"
372R"==(if (krem >= UNROLL_K && nrem >= UNROLL_N_CHUNK && aligned) { )==""\n"
373R"==(for (int c = 0; c < UNROLL_N_CHUNK; c++) )==""\n"
374R"==(cols[c] = BLOCK_READ_ELEMENT_WORD(b + (c + c0) * ldb); )==""\n"
375R"==(} else { )==""\n"
376R"==(for (int c = 0; c < UNROLL_N_CHUNK; c++) )==""\n"
377R"==(if (c < nrem) )==""\n"
378R"==(cols[c] = MASKED_BLOCK_READ_ELEMENT_WORD( )==""\n"
379R"==(b + (c + c0) * ldb, krem); )==""\n"
380R"==(else )==""\n"
381R"==(cols[c] = 0; )==""\n"
382R"==(} )==""\n"
383R"==(ELEMENT_WORD4 colgroups[UNROLL_N_CHUNK / 4]; )==""\n"
384R"==(for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++) )==""\n"
385R"==(REPACK_CC_WORD(cc); )==""\n"
386R"==(for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++) )==""\n"
387R"==(BLOCK_WRITE_ELEMENT_WORD4( )==""\n"
388R"==(b_packed + (cc * 4 + c0) * UNROLL_K, colgroups[cc]); )==""\n"
389R"==(#if COPY_SUM )==""\n"
390R"==(for (int c = 0; c < UNROLL_N_CHUNK; c++) )==""\n"
391R"==(sums[c + c0] = sum(CONVERT_SUM_T2(AS_SIGNED_ELEMENT_WORD(cols[c]))); )==""\n"
392R"==(#endif )==""\n"
393R"==(} )==""\n"
394R"==(#if COPY_SUM )==""\n"
395R"==(for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size()) )==""\n"
396R"==(atomic_add(b_sum + c0 + lid, sums[c0 + lid]); )==""\n"
397R"==(#endif )==""\n"
398R"==(DUMMY_DPAS; )==""\n"
399R"==(} )==""\n"
400R"==(#else /* COPY_TRANS */ )==""\n"
401R"==(#define ADD_SUM(coln) \ )==""\n"
402R"==(for (int cc = 0; cc < UNROLL_N / 4; cc++) { \ )==""\n"
403R"==(sums[4 * cc + 0] \ )==""\n"
404R"==(+= sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s0))); \ )==""\n"
405R"==(sums[4 * cc + 1] \ )==""\n"
406R"==(+= sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s1))); \ )==""\n"
407R"==(sums[4 * cc + 2] \ )==""\n"
408R"==(+= sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s2))); \ )==""\n"
409R"==(sums[4 * cc + 3] \ )==""\n"
410R"==(+= sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s3))); \ )==""\n"
411R"==(} )==""\n"
412R"==(__attribute__((intel_reqd_workgroup_walk_order(1, 0))) )==""\n"
413R"==(__attribute__((intel_reqd_sub_group_size(16))) kernel void )==""\n"
414R"==(xe_hpc_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb, )==""\n"
415R"==(long ldb, global ELEMENT *b_packed, int offsetb_packed, )==""\n"
416R"==(int ldb_packed) { )==""\n"
417R"==(int lid = get_sub_group_local_id(); )==""\n"
418R"==(uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 16) * UNROLL_K; )==""\n"
419R"==(uint n0 = get_global_id(1) * UNROLL_N; )==""\n"
420R"==(int krem = k - k0; )==""\n"
421R"==(int nrem = n - n0; )==""\n"
422R"==(int sg = get_sub_group_size(); )==""\n"
423R"==(if (nrem <= 0 || krem <= 0) return; )==""\n"
424R"==(GET_B_SUM_ADDRESS; )==""\n"
425R"==(b += offsetb + n0 + k0 * ldb; )==""\n"
426R"==(b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N; )==""\n"
427R"==(ELEMENT_INT cols[UNROLL_N / ELEMENTS_PER_INT]; )==""\n"
428R"==(ELEMENT_INT4 colgroups[UNROLL_N / 8]; )==""\n"
429R"==(if (krem >= sg && nrem >= UNROLL_N) { )==""\n"
430R"==(for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) )==""\n"
431R"==(cols[cc] = VLOAD_ELEMENT_INT( )==""\n"
432R"==(0, b + cc * ELEMENTS_PER_INT + lid * ldb); )==""\n"
433R"==(} else { )==""\n"
434R"==(for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) { )==""\n"
435R"==(cols[cc] = 0; )==""\n"
436R"==(PARTIAL_LOAD(cols, krem, nrem, cc, )==""\n"
437R"==(b + cc * ELEMENTS_PER_INT + lid * ldb); )==""\n"
438R"==(} )==""\n"
439R"==(} )==""\n"
440R"==(#if ELEMENT_SIZE == 2 )==""\n"
441R"==(for (int cc = 0; cc < UNROLL_N / 8; cc++) )==""\n"
442R"==(REPACK_CC(cc); )==""\n"
443R"==(#else )==""\n"
444R"==(#if COPY_SUM )==""\n"
445R"==(SUM_T sums[UNROLL_N] = {0}; )==""\n"
446R"==(ADD_SUM(cols); )==""\n"
447R"==(#endif )==""\n"
448R"==(ELEMENT_INT cols2[UNROLL_N / ELEMENTS_PER_INT]; )==""\n"
449R"==(krem -= sg; )==""\n"
450R"==(if (krem >= sg && nrem >= UNROLL_N) { )==""\n"
451R"==(for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) )==""\n"
452R"==(cols2[cc] = VLOAD_ELEMENT_INT( )==""\n"
453R"==(0, b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb); )==""\n"
454R"==(} else { )==""\n"
455R"==(for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) { )==""\n"
456R"==(cols2[cc] = 0; )==""\n"
457R"==(PARTIAL_LOAD(cols2, krem, nrem, cc, )==""\n"
458R"==(b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb); )==""\n"
459R"==(} )==""\n"
460R"==(} )==""\n"
461R"==(for (int cc = 0; cc < UNROLL_N / 8; cc++) )==""\n"
462R"==(REPACK_CC2(cc); )==""\n"
463R"==(#if COPY_SUM )==""\n"
464R"==(ADD_SUM(cols2); )==""\n"
465R"==(#endif )==""\n"
466R"==(#endif )==""\n"
467R"==(for (int cc = 0; cc < UNROLL_N / 8; cc++) )==""\n"
468R"==(BLOCK_WRITE_ELEMENT_INT4(b_packed + cc * 8 * UNROLL_K, colgroups[cc]); )==""\n"
469R"==(#if COPY_SUM )==""\n"
470R"==(for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size()) )==""\n"
471R"==(atomic_add(b_sum + c0 + lid, sums[c0 + lid]); )==""\n"
472R"==(#endif )==""\n"
473R"==(DUMMY_DPAS; )==""\n"
474R"==(} )==""\n"
475R"==(#endif /* !COPY_TRANS */ )==""\n"
476R"==(#endif /* COPY_B */ )==""\n"
477R"==()==";
478}
479}
480}
481}