1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// OpenCL C source for the xe_hp_systolic_gemm_copy kernels, embedded as a
// chain of adjacent raw string literals: one literal per kernel source line,
// each followed by an explicit newline literal. Everything inside the
// literals is runtime text; only the C++ comments placed between literals
// are documentation and they do not alter the resulting string.
//
// The text tests or uses ELEMENT_SIZE, COPY_A, COPY_B, COPY_TRANS, COPY_SUM,
// COPY_CLEAR_SUM, COPY_SIGNED and UNROLL_N without defining them, so they are
// presumably supplied as build options by the host code -- TODO confirm.
// The COPY_A and COPY_B sections each define one kernel named
// xe_hp_systolic_gemm_copy, selected by COPY_CLEAR_SUM and COPY_TRANS.
//
// NOTE(review): the license URL line in the embedded header contains only
// the scheme; the rest of the URL appears to have been dropped -- confirm
// against the original kernel source.
5const char *xe_hp_systolic_gemm_copy_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
// Element type aliases. ELEMENT_SIZE == 2 selects ushort based types and the
// us flavoured sub-group block functions; ELEMENT_SIZE == 1 selects uchar
// based types and the uc flavoured ones; any other value hits the error
// directive. ELEMENT_INT is a vector of ELEMENTS_PER_INT elements (2 or 4)
// and ELEMENT_INT4 is four of those.
20R"==(#include "gpu/ocl/ocl_math_utils.h" )==""\n"
21R"==(#if ELEMENT_SIZE == 2 )==""\n"
22R"==(#pragma OPENCL EXTENSION cl_intel_subgroups_short : enable )==""\n"
23R"==(#define ELEMENT ushort )==""\n"
24R"==(#define ELEMENT2 ushort2 )==""\n"
25R"==(#define ELEMENT4 ushort4 )==""\n"
26R"==(#define ELEMENT8 ushort8 )==""\n"
27R"==(#define ELEMENT16 ushort16 )==""\n"
28R"==(#define ELEMENT_INT ushort2 )==""\n"
29R"==(#define ELEMENT_INT4 ushort8 )==""\n"
30R"==(#define VLOAD_ELEMENT_INT vload2 )==""\n"
31R"==(#define ELEMENTS_PER_INT 2 )==""\n"
32R"==(#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_us2 )==""\n"
33R"==(#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_us4 )==""\n"
34R"==(#define BLOCK_READ_ELEMENT_INT intel_sub_group_block_read_us2 )==""\n"
35R"==(#define MASKED_BLOCK_READ_ELEMENT_INT masked_block_read_element2 )==""\n"
36R"==(#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_us8 )==""\n"
37R"==(#elif ELEMENT_SIZE == 1 )==""\n"
38R"==(#define ELEMENT uchar )==""\n"
39R"==(#define ELEMENT2 uchar2 )==""\n"
40R"==(#define ELEMENT4 uchar4 )==""\n"
41R"==(#define ELEMENT8 uchar8 )==""\n"
42R"==(#define ELEMENT16 uchar16 )==""\n"
43R"==(#define ELEMENT_INT uchar4 )==""\n"
44R"==(#define ELEMENT_INT4 uchar16 )==""\n"
45R"==(#define VLOAD_ELEMENT_INT vload4 )==""\n"
46R"==(#define BLOCK_READ_ELEMENT2 intel_sub_group_block_read_uc2 )==""\n"
47R"==(#define BLOCK_READ_ELEMENT4 intel_sub_group_block_read_uc4 )==""\n"
48R"==(#define BLOCK_READ_ELEMENT_INT intel_sub_group_block_read_uc4 )==""\n"
49R"==(#define MASKED_BLOCK_READ_ELEMENT_INT masked_block_read_element4 )==""\n"
50R"==(#define BLOCK_WRITE_ELEMENT_INT4 intel_sub_group_block_write_uc16 )==""\n"
51R"==(#define ELEMENTS_PER_INT 4 )==""\n"
// Only this ELEMENT_SIZE == 1 branch defines SUM_T, the CONVERT_SUM macros
// and the AS_SIGNED macros, so the COPY_SUM code paths below can only compile
// for one byte elements. COPY_SIGNED picks char versus uchar
// reinterpretation before the values are widened to int.
52R"==(#define SUM_T int )==""\n"
53R"==(#define SUM_T4 int4 )==""\n"
54R"==(#define CONVERT_SUM_T convert_int )==""\n"
55R"==(#define CONVERT_SUM_T4 convert_int4 )==""\n"
56R"==(#if COPY_SIGNED )==""\n"
57R"==(#define AS_SIGNED_ELEMENT as_char )==""\n"
58R"==(#define AS_SIGNED_ELEMENT4 as_char4 )==""\n"
59R"==(#define AS_SIGNED_ELEMENT_INT as_char4 )==""\n"
60R"==(#define SIGNED_ELEMENT_INT char4 )==""\n"
61R"==(#else )==""\n"
62R"==(#define AS_SIGNED_ELEMENT as_uchar )==""\n"
63R"==(#define AS_SIGNED_ELEMENT4 as_uchar4 )==""\n"
64R"==(#define AS_SIGNED_ELEMENT_INT as_uchar4 )==""\n"
65R"==(#define SIGNED_ELEMENT_INT uchar4 )==""\n"
66R"==(#endif )==""\n"
67R"==(#else )==""\n"
68R"==(#error Unsupported element size. )==""\n"
69R"==(#endif )==""\n"
// At least one of COPY_A and COPY_B has to be nonzero.
70R"==(#if !COPY_A && !COPY_B )==""\n"
71R"==(#error Source matrix not defined. )==""\n"
72R"==(#endif )==""\n"
// Bounded sub-group reads. Vector component i of the result is
// p[lid + i * sg], where lid is the sub-group local id and sg the sub-group
// size, or 0 when that index is not below rem.
73R"==(inline ELEMENT2 masked_block_read_element2(global ELEMENT *p, int rem) { )==""\n"
74R"==(ELEMENT2 v; )==""\n"
75R"==(int lid = get_sub_group_local_id(); )==""\n"
76R"==(int sg = get_sub_group_size(); )==""\n"
77R"==(v.s0 = (lid < rem) ? p[lid] : 0; )==""\n"
78R"==(v.s1 = (lid + sg < rem) ? p[lid + sg] : 0; )==""\n"
79R"==(return v; )==""\n"
80R"==(} )==""\n"
81R"==(inline ELEMENT4 masked_block_read_element4(global ELEMENT *p, int rem) { )==""\n"
82R"==(ELEMENT4 v; )==""\n"
83R"==(int lid = get_sub_group_local_id(); )==""\n"
84R"==(int sg = get_sub_group_size(); )==""\n"
85R"==(v.s0 = (lid < rem) ? p[lid] : 0; )==""\n"
86R"==(v.s1 = (lid + sg < rem) ? p[lid + sg] : 0; )==""\n"
87R"==(v.s2 = (lid + 2 * sg < rem) ? p[lid + 2 * sg] : 0; )==""\n"
88R"==(v.s3 = (lid + 3 * sg < rem) ? p[lid + 3 * sg] : 0; )==""\n"
89R"==(return v; )==""\n"
90R"==(} )==""\n"
// sum(): adds a value across the sub-group with sub_group_reduce_add. The
// int4 overload reduces each component separately and adds the four results.
91R"==(__attribute__((overloadable)) inline int sum(int v) { )==""\n"
92R"==(return sub_group_reduce_add(v); )==""\n"
93R"==(} )==""\n"
94R"==(__attribute__((overloadable)) inline int sum(int4 v) { )==""\n"
95R"==(return sub_group_reduce_add(v.s0) + sub_group_reduce_add(v.s1) )==""\n"
96R"==(+ sub_group_reduce_add(v.s2) + sub_group_reduce_add(v.s3); )==""\n"
97R"==(} )==""\n"
// dummy_dpas(): the body only runs for lanes whose sub-group local id is at
// least 16. It declares and calls __builtin_IB_sub_group_idpas_s8_s8_8_1 and
// reads through a pointer that is never assigned.
// NOTE(review): every kernel in this text requests sub-group size 8, so the
// guard looks unreachable and the call presumably exists only to make the
// compiler treat the kernel as using DPAS -- confirm before touching it.
98R"==(void dummy_dpas() { )==""\n"
99R"==(if (get_sub_group_local_id() >= 16) { )==""\n"
100R"==(int __builtin_IB_sub_group_idpas_s8_s8_8_1(int, int, int8) )==""\n"
101R"==(__attribute__((const)); )==""\n"
102R"==(global volatile int *_; )==""\n"
103R"==(int z = __builtin_IB_sub_group_idpas_s8_s8_8_1(0, _[0], 1); )==""\n"
104R"==(for (int i = 0; i < z; i++) )==""\n"
105R"==((void)_[0]; )==""\n"
106R"==(} )==""\n"
107R"==(} )==""\n"
108R"==(#define DUMMY_DPAS dummy_dpas() )==""\n"
// PARTIAL_LOAD(regs, rrem, crem, cc, p): loads into regs[cc] as many leading
// elements from p as have an element index below crem (the full vector, or
// fewer components), and only in lanes where lid < rrem. It expects a
// variable named lid at the expansion site and leaves components it does not
// load untouched, so callers zero regs[cc] first.
109R"==(#if ELEMENT_SIZE == 2 )==""\n"
110R"==(#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \ )==""\n"
111R"==(if ((2 * cc + 1) < crem) { \ )==""\n"
112R"==(if (lid < rrem) regs[cc] = vload2(0, p); \ )==""\n"
113R"==(} else if ((2 * cc) < crem) { \ )==""\n"
114R"==(if (lid < rrem) regs[cc].s0 = *(p); \ )==""\n"
115R"==(} )==""\n"
116R"==(#elif ELEMENT_SIZE == 1 )==""\n"
117R"==(#define PARTIAL_LOAD(regs, rrem, crem, cc, p) \ )==""\n"
118R"==(if ((4 * cc + 3) < crem) { \ )==""\n"
119R"==(if (lid < rrem) regs[cc] = vload4(0, p); \ )==""\n"
120R"==(} else if ((4 * cc + 2) < crem) { \ )==""\n"
121R"==(if (lid < rrem) regs[cc].s012 = vload3(0, p); \ )==""\n"
122R"==(} else if ((4 * cc + 1) < crem) { \ )==""\n"
123R"==(if (lid < rrem) regs[cc].s01 = vload2(0, p); \ )==""\n"
124R"==(} else if (4 * cc < crem) { \ )==""\n"
125R"==(if (lid < rrem) regs[cc].s0 = *(p); \ )==""\n"
126R"==(} )==""\n"
127R"==(#endif )==""\n"
// ---- COPY_A section ----
// Tile size is UNROLL_M = 32 by UNROLL_K = 32 / ELEMENT_SIZE. With COPY_SUM,
// GET_A_SUM_ADDRESS declares a_sum as
//   a_packed + offseta_packed + (m0 + UNROLL_M) * lda_packed
//     - UNROLL_M * sizeof(int)
// evaluated in ELEMENT units and cast to global int pointer; without
// COPY_SUM the macro expands to nothing. The macro body already ends in a
// semicolon, so each use site produces an extra empty statement.
128R"==(#if COPY_A )==""\n"
129R"==(#define UNROLL_M 32 )==""\n"
130R"==(#define UNROLL_K (32 / ELEMENT_SIZE) )==""\n"
131R"==(#if COPY_SUM )==""\n"
132R"==(#define GET_A_SUM_ADDRESS \ )==""\n"
133R"==(global int *a_sum = (global int *)(a_packed + offseta_packed \ )==""\n"
134R"==(+ (m0 + UNROLL_M) * lda_packed - UNROLL_M * sizeof(int)); )==""\n"
135R"==(#else )==""\n"
136R"==(#define GET_A_SUM_ADDRESS )==""\n"
137R"==(#endif )==""\n"
// COPY_A, COPY_CLEAR_SUM variant: each sub-group of 8 lanes block-writes a
// zero uint4 per lane to a_sum, i.e. 32 ints. The body uses a_sum, which is
// only declared when COPY_SUM is also set.
138R"==(#if COPY_CLEAR_SUM )==""\n"
139R"==(__attribute__((intel_reqd_sub_group_size(8))) kernel void )==""\n"
140R"==(xe_hp_systolic_gemm_copy(long m, long k, global ELEMENT *a_packed, )==""\n"
141R"==(int offseta_packed, int lda_packed) { )==""\n"
142R"==(uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_M; )==""\n"
143R"==(GET_A_SUM_ADDRESS; )==""\n"
144R"==(uint4 zero = 0; )==""\n"
145R"==(intel_sub_group_block_write4(a_sum, zero); )==""\n"
146R"==(} )==""\n"
// COPY_A, non-transposed variant.
// REPACK_REG(rr, cc) packs component rr of ELEMENTS_PER_INT consecutive c[]
// entries into one uint, lowest index in the low bits, and stores it in
// component cc of blk_r[rr]. REPACK applies that for rr 0..3 and cc 0..7.
147R"==(#elif !COPY_TRANS )==""\n"
148R"==(#if ELEMENT_SIZE == 2 )==""\n"
149R"==(#define REPACK_REG(rr, cc) \ )==""\n"
150R"==(blk_r[rr].s##cc = (((uint)c[2 * cc + 1].s##rr) << 16) | c[2 * cc].s##rr )==""\n"
151R"==(#elif ELEMENT_SIZE == 1 )==""\n"
152R"==(#define REPACK_REG(rr, cc) \ )==""\n"
153R"==(blk_r[rr].s##cc = (((uint)c[4 * cc + 3].s##rr) << 24) \ )==""\n"
154R"==(| (((uint)c[4 * cc + 2].s##rr) << 16) \ )==""\n"
155R"==(| (((uint)c[4 * cc + 1].s##rr) << 8) | c[4 * cc].s##rr )==""\n"
156R"==(#endif )==""\n"
157R"==(#define REPACK_CC(cc) \ )==""\n"
158R"==(REPACK_REG(0, cc); \ )==""\n"
159R"==(REPACK_REG(1, cc); \ )==""\n"
160R"==(REPACK_REG(2, cc); \ )==""\n"
161R"==(REPACK_REG(3, cc) )==""\n"
162R"==(#define REPACK \ )==""\n"
163R"==(REPACK_CC(0); \ )==""\n"
164R"==(REPACK_CC(1); \ )==""\n"
165R"==(REPACK_CC(2); \ )==""\n"
166R"==(REPACK_CC(3); \ )==""\n"
167R"==(REPACK_CC(4); \ )==""\n"
168R"==(REPACK_CC(5); \ )==""\n"
169R"==(REPACK_CC(6); \ )==""\n"
170R"==(REPACK_CC(7) )==""\n"
// Kernel: m0 comes from global id 0 (divided by 8, times UNROLL_M) and k0
// from global id 1 times UNROLL_K; work-items with no remaining rows or
// columns return early. For each h in 0..UNROLL_K-1 the sub-group reads
// ELEMENT4 per lane from a + h * lda: a sub-group block read when the tile
// is full and the aligned flag holds, otherwise the masked read (or zero for
// h at or beyond krem). The data is repacked into four uint8 values and
// block-written to a_packed + rr * UNROLL_K * 8. With COPY_SUM, the per-lane
// sums over h are atomically added to a_sum at lid, lid + 8, lid + 16 and
// lid + 24.
// NOTE(review): the aligned test ORs the byte address of a with the element
// counts lda and offseta and masks with ELEMENTS_PER_INT - 1; whether that is
// the intended alignment condition for two byte elements is not visible
// here -- confirm.
171R"==(__attribute__((intel_reqd_sub_group_size(8))) kernel void )==""\n"
172R"==(xe_hp_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta, )==""\n"
173R"==(long lda, global ELEMENT *a_packed, int offseta_packed, )==""\n"
174R"==(int lda_packed) { )==""\n"
175R"==(int lid = get_sub_group_local_id(); )==""\n"
176R"==(uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_M; )==""\n"
177R"==(uint k0 = get_global_id(1) * UNROLL_K; )==""\n"
178R"==(int mrem = m - m0; )==""\n"
179R"==(int krem = k - k0; )==""\n"
180R"==(bool aligned = ((as_long(a) | lda | offseta) & (ELEMENTS_PER_INT - 1)) == 0; )==""\n"
181R"==(if (mrem <= 0 || krem <= 0) return; )==""\n"
182R"==(GET_A_SUM_ADDRESS; )==""\n"
183R"==(a += offseta + m0 + k0 * lda; )==""\n"
184R"==(a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M; )==""\n"
185R"==(ELEMENT4 c[UNROLL_K]; )==""\n"
186R"==(if (mrem >= UNROLL_M && krem >= UNROLL_K && aligned) { )==""\n"
187R"==(for (int h = 0; h < UNROLL_K; h++) )==""\n"
188R"==(c[h] = BLOCK_READ_ELEMENT4(a + h * lda); )==""\n"
189R"==(} else { )==""\n"
190R"==(for (int h = 0; h < UNROLL_K; h++) )==""\n"
191R"==(if (h < krem) )==""\n"
192R"==(c[h] = masked_block_read_element4(a + h * lda, mrem); )==""\n"
193R"==(else )==""\n"
194R"==(c[h] = 0; )==""\n"
195R"==(} )==""\n"
196R"==(uint8 blk_r[UNROLL_M / 8]; )==""\n"
197R"==(REPACK; )==""\n"
198R"==(for (int rr = 0; rr < UNROLL_M / 8; rr++) )==""\n"
199R"==(intel_sub_group_block_write8( )==""\n"
200R"==((global uint *)(a_packed + rr * UNROLL_K * 8), blk_r[rr]); )==""\n"
201R"==(#if COPY_SUM )==""\n"
202R"==(SUM_T4 sum = 0; )==""\n"
203R"==(for (int h = 0; h < UNROLL_K; h++) )==""\n"
204R"==(sum += CONVERT_SUM_T4(AS_SIGNED_ELEMENT4(c[h])); )==""\n"
205R"==(atomic_add(a_sum + lid, sum.s0); )==""\n"
206R"==(atomic_add(a_sum + lid + 8, sum.s1); )==""\n"
207R"==(atomic_add(a_sum + lid + 16, sum.s2); )==""\n"
208R"==(atomic_add(a_sum + lid + 24, sum.s3); )==""\n"
209R"==(#endif )==""\n"
210R"==(DUMMY_DPAS; )==""\n"
211R"==(} )==""\n"
// COPY_A, transposed variant. For each rr in 0..3 lane lid loads eight
// ELEMENT_INT groups from row (rr * 8 + lid) of a, either with plain vector
// loads or, on the edge path, by zeroing and using PARTIAL_LOAD. The eight
// groups are reinterpreted as uints, packed in a uint8 and block-written to
// a_packed + rr * UNROLL_K * 8. With COPY_SUM, sum[rr] accumulates components
// s0..s3 of every group (which needs the four component ELEMENT_INT of the
// ELEMENT_SIZE == 1 branch) and is atomically added to a_sum afterwards.
// NOTE(review): mrem is decremented by 8 in the loop header, so the fast
// path test mrem >= UNROLL_M is evaluated against the already reduced value
// on later iterations; an exactly full tile therefore takes the PARTIAL_LOAD
// path for rr > 0. Results look identical, only slower -- confirm intent.
212R"==(#else /* COPY_TRANS */ )==""\n"
213R"==(__attribute__((intel_reqd_workgroup_walk_order(1, 0))) )==""\n"
214R"==(__attribute__((intel_reqd_sub_group_size(8))) kernel void )==""\n"
215R"==(xe_hp_systolic_gemm_copy(long m, long k, global ELEMENT *a, long offseta, )==""\n"
216R"==(long lda, global ELEMENT *a_packed, int offseta_packed, )==""\n"
217R"==(int lda_packed) { )==""\n"
218R"==(int lid = get_sub_group_local_id(); )==""\n"
219R"==(uint m0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_M; )==""\n"
220R"==(uint k0 = get_global_id(1) * UNROLL_K; )==""\n"
221R"==(int mrem = m - m0; )==""\n"
222R"==(int krem = k - k0; )==""\n"
223R"==(if (mrem <= 0 || krem <= 0) return; )==""\n"
224R"==(GET_A_SUM_ADDRESS; )==""\n"
225R"==(a += offseta + m0 * lda + k0; )==""\n"
226R"==(a_packed += offseta_packed + m0 * lda_packed + k0 * UNROLL_M; )==""\n"
227R"==(#if COPY_SUM )==""\n"
228R"==(SUM_T sum[UNROLL_M / 8] = {0}; )==""\n"
229R"==(#endif )==""\n"
230R"==(for (int rr = 0; rr < UNROLL_M / 8; rr++, mrem -= 8) { )==""\n"
231R"==(ELEMENT_INT regs[8]; )==""\n"
232R"==(if (mrem >= UNROLL_M && krem >= UNROLL_K) { )==""\n"
233R"==(for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) )==""\n"
234R"==(regs[cc] = VLOAD_ELEMENT_INT(0, )==""\n"
235R"==(a + ((rr * 8) + lid) * lda + (cc * ELEMENTS_PER_INT)); )==""\n"
236R"==(} else { )==""\n"
237R"==(for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) { )==""\n"
238R"==(regs[cc] = 0; )==""\n"
239R"==(PARTIAL_LOAD(regs, mrem, krem, cc, )==""\n"
240R"==(a + ((rr * 8) + lid) * lda + (cc * ELEMENTS_PER_INT)); )==""\n"
241R"==(} )==""\n"
242R"==(} )==""\n"
243R"==(uint8 blk_r; )==""\n"
244R"==(blk_r.s0 = as_uint(regs[0]); )==""\n"
245R"==(blk_r.s1 = as_uint(regs[1]); )==""\n"
246R"==(blk_r.s2 = as_uint(regs[2]); )==""\n"
247R"==(blk_r.s3 = as_uint(regs[3]); )==""\n"
248R"==(blk_r.s4 = as_uint(regs[4]); )==""\n"
249R"==(blk_r.s5 = as_uint(regs[5]); )==""\n"
250R"==(blk_r.s6 = as_uint(regs[6]); )==""\n"
251R"==(blk_r.s7 = as_uint(regs[7]); )==""\n"
252R"==(#if COPY_SUM )==""\n"
253R"==(for (int cc = 0; cc < UNROLL_K / ELEMENTS_PER_INT; cc++) { )==""\n"
254R"==(sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s0)); )==""\n"
255R"==(sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s1)); )==""\n"
256R"==(sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s2)); )==""\n"
257R"==(sum[rr] += CONVERT_SUM_T(AS_SIGNED_ELEMENT(regs[cc].s3)); )==""\n"
258R"==(} )==""\n"
259R"==(#endif )==""\n"
260R"==(intel_sub_group_block_write8( )==""\n"
261R"==((global uint *)(a_packed + rr * UNROLL_K * 8), blk_r); )==""\n"
262R"==(} )==""\n"
263R"==(#if COPY_SUM )==""\n"
264R"==(atomic_add(a_sum + lid, sum[0]); )==""\n"
265R"==(atomic_add(a_sum + lid + 8, sum[1]); )==""\n"
266R"==(atomic_add(a_sum + lid + 16, sum[2]); )==""\n"
267R"==(atomic_add(a_sum + lid + 24, sum[3]); )==""\n"
268R"==(#endif )==""\n"
269R"==(DUMMY_DPAS; )==""\n"
270R"==(} )==""\n"
271R"==(#endif /* !COPY_TRANS */ )==""\n"
272R"==(#endif /* COPY_A */ )==""\n"
// ---- COPY_B section ----
// UNROLL_K is 32 / ELEMENT_SIZE; UNROLL_N is used throughout but not defined
// in this text. The REPACK macros fill colgroups[cc] (an ELEMENT_INT4):
//   REPACK_CC   concatenates four consecutive cols[] entries;
//   REPACK_CC2  interleaves cols and cols2 entries component by component
//               (two byte elements);
//   REPACK_CC4  interleaves cols, cols2, cols3 and cols4 entries component
//               by component (one byte elements).
273R"==(#if COPY_B )==""\n"
274R"==(#define UNROLL_K (32 / ELEMENT_SIZE) )==""\n"
275R"==(#if ELEMENT_SIZE == 2 )==""\n"
276R"==(#define REPACK_CC(cc) \ )==""\n"
277R"==(do { \ )==""\n"
278R"==(colgroups[cc].s01 = cols[cc * 4]; \ )==""\n"
279R"==(colgroups[cc].s23 = cols[cc * 4 + 1]; \ )==""\n"
280R"==(colgroups[cc].s45 = cols[cc * 4 + 2]; \ )==""\n"
281R"==(colgroups[cc].s67 = cols[cc * 4 + 3]; \ )==""\n"
282R"==(} while (false) )==""\n"
283R"==(#define REPACK_CC2(cc) \ )==""\n"
284R"==(do { \ )==""\n"
285R"==(colgroups[cc].s02 = cols[cc * 2]; \ )==""\n"
286R"==(colgroups[cc].s13 = cols2[cc * 2]; \ )==""\n"
287R"==(colgroups[cc].s46 = cols[cc * 2 + 1]; \ )==""\n"
288R"==(colgroups[cc].s57 = cols2[cc * 2 + 1]; \ )==""\n"
289R"==(} while (false) )==""\n"
290R"==(#elif ELEMENT_SIZE == 1 )==""\n"
291R"==(#define REPACK_CC(cc) \ )==""\n"
292R"==(do { \ )==""\n"
293R"==(colgroups[cc].s0123 = cols[cc * 4]; \ )==""\n"
294R"==(colgroups[cc].s4567 = cols[cc * 4 + 1]; \ )==""\n"
295R"==(colgroups[cc].s89ab = cols[cc * 4 + 2]; \ )==""\n"
296R"==(colgroups[cc].scdef = cols[cc * 4 + 3]; \ )==""\n"
297R"==(} while (false) )==""\n"
298R"==(#define REPACK_CC4(cc) \ )==""\n"
299R"==(do { \ )==""\n"
300R"==(colgroups[cc].s048c = cols[cc]; \ )==""\n"
301R"==(colgroups[cc].s159d = cols2[cc]; \ )==""\n"
302R"==(colgroups[cc].s26ae = cols3[cc]; \ )==""\n"
303R"==(colgroups[cc].s37bf = cols4[cc]; \ )==""\n"
304R"==(} while (false) )==""\n"
305R"==(#endif )==""\n"
// With COPY_SUM, GET_B_SUM_ADDRESS declares b_sum as
//   b_packed + offsetb_packed + (n0 + UNROLL_N) * ldb_packed
//     - UNROLL_N * sizeof(int)
// evaluated in ELEMENT units and cast to global int pointer; otherwise the
// macro expands to nothing.
306R"==(#if COPY_SUM )==""\n"
307R"==(#define GET_B_SUM_ADDRESS \ )==""\n"
308R"==(global int *b_sum = (global int *)(b_packed + offsetb_packed \ )==""\n"
309R"==(+ (n0 + UNROLL_N) * ldb_packed - UNROLL_N * sizeof(int)); )==""\n"
310R"==(#else )==""\n"
311R"==(#define GET_B_SUM_ADDRESS )==""\n"
312R"==(#endif )==""\n"
// COPY_B, COPY_CLEAR_SUM variant: block-writes a zero uint4 per lane to
// b_sum (32 ints for the 8 lane sub-group) and, when UNROLL_N exceeds 32, a
// further zero uint2 per lane starting at b_sum + 32 (16 more ints). As in
// the COPY_A variant, b_sum only exists when COPY_SUM is set.
313R"==(#if COPY_CLEAR_SUM )==""\n"
314R"==(__attribute__((intel_reqd_sub_group_size(8))) kernel void )==""\n"
315R"==(xe_hp_systolic_gemm_copy(long k, long n, global ELEMENT *b_packed, )==""\n"
316R"==(int offsetb_packed, int ldb_packed) { )==""\n"
317R"==(uint n0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_N; )==""\n"
318R"==(GET_B_SUM_ADDRESS; )==""\n"
319R"==(uint4 zero = 0; )==""\n"
320R"==(intel_sub_group_block_write4(b_sum, zero); )==""\n"
321R"==(#if UNROLL_N > 32 )==""\n"
322R"==(intel_sub_group_block_write2(b_sum + 32, zero.s01); )==""\n"
323R"==(#endif )==""\n"
324R"==(} )==""\n"
// COPY_B, non-transposed variant. k0 comes from global id 0 and n0 from
// global id 1. The UNROLL_N columns are handled in two chunks of
// UNROLL_N / 2, with nrem reduced by one chunk per iteration. Each column
// (stride ldb) is read as one ELEMENT_INT per lane, by sub-group block read
// on the full and aligned path or by the masked read otherwise, with columns
// at or beyond nrem set to zero. Groups of four columns are concatenated by
// REPACK_CC and block-written to b_packed + (cc * 4 + c0) * UNROLL_K. With
// COPY_SUM, sums[c + c0] receives the sub-group wide sum of the column and
// each lane then atomically adds sums[c0 + lid] to b_sum[c0 + lid].
// NOTE(review): the aligned test mixes the byte address of b with element
// counts, as in the COPY_A kernel -- confirm.
325R"==(#elif !COPY_TRANS )==""\n"
326R"==(__attribute__((intel_reqd_sub_group_size(8))) kernel void )==""\n"
327R"==(xe_hp_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb, )==""\n"
328R"==(long ldb, global ELEMENT *b_packed, int offsetb_packed, )==""\n"
329R"==(int ldb_packed) { )==""\n"
330R"==(int lid = get_sub_group_local_id(); )==""\n"
331R"==(uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_K; )==""\n"
332R"==(uint n0 = get_global_id(1) * UNROLL_N; )==""\n"
333R"==(int krem = k - k0; )==""\n"
334R"==(int nrem = n - n0; )==""\n"
335R"==(bool aligned = ((as_long(b) | ldb | offsetb) & (ELEMENTS_PER_INT - 1)) == 0; )==""\n"
336R"==(if (nrem <= 0 || krem <= 0) return; )==""\n"
337R"==(GET_B_SUM_ADDRESS; )==""\n"
338R"==(b += offsetb + k0 + n0 * ldb; )==""\n"
339R"==(b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N; )==""\n"
340R"==(#define UNROLL_N_CHUNK (UNROLL_N / 2) )==""\n"
341R"==(#if COPY_SUM )==""\n"
342R"==(SUM_T sums[UNROLL_N]; )==""\n"
343R"==(#endif )==""\n"
344R"==(ELEMENT_INT cols[UNROLL_N / 2]; )==""\n"
345R"==(for (int c0 = 0; c0 < UNROLL_N; )==""\n"
346R"==(c0 += UNROLL_N_CHUNK, nrem -= UNROLL_N_CHUNK) { )==""\n"
347R"==(if (krem >= UNROLL_K && nrem >= UNROLL_N_CHUNK && aligned) { )==""\n"
348R"==(for (int c = 0; c < UNROLL_N_CHUNK; c++) )==""\n"
349R"==(cols[c] = BLOCK_READ_ELEMENT_INT(b + (c + c0) * ldb); )==""\n"
350R"==(} else { )==""\n"
351R"==(for (int c = 0; c < UNROLL_N_CHUNK; c++) )==""\n"
352R"==(if (c < nrem) )==""\n"
353R"==(cols[c] = MASKED_BLOCK_READ_ELEMENT_INT( )==""\n"
354R"==(b + (c + c0) * ldb, krem); )==""\n"
355R"==(else )==""\n"
356R"==(cols[c] = 0; )==""\n"
357R"==(} )==""\n"
358R"==(ELEMENT_INT4 colgroups[UNROLL_N_CHUNK / 4]; )==""\n"
359R"==(for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++) )==""\n"
360R"==(REPACK_CC(cc); )==""\n"
361R"==(for (int cc = 0; cc < UNROLL_N_CHUNK / 4; cc++) )==""\n"
362R"==(BLOCK_WRITE_ELEMENT_INT4( )==""\n"
363R"==(b_packed + (cc * 4 + c0) * UNROLL_K, colgroups[cc]); )==""\n"
364R"==(#if COPY_SUM )==""\n"
365R"==(for (int c = 0; c < UNROLL_N_CHUNK; c++) )==""\n"
366R"==(sums[c + c0] = sum(CONVERT_SUM_T4(AS_SIGNED_ELEMENT_INT(cols[c]))); )==""\n"
367R"==(#endif )==""\n"
368R"==(} )==""\n"
369R"==(#if COPY_SUM )==""\n"
370R"==(for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size()) )==""\n"
371R"==(atomic_add(b_sum + c0 + lid, sums[c0 + lid]); )==""\n"
372R"==(#endif )==""\n"
373R"==(DUMMY_DPAS; )==""\n"
374R"==(} )==""\n"
// COPY_B, transposed variant.
// ADD_SUM(coln): for every group cc below UNROLL_N / 4, adds the sub-group
// wide sum of components s0..s3 of coln[cc] to sums[4 * cc + 0..3].
375R"==(#else /* COPY_TRANS */ )==""\n"
376R"==(#define ADD_SUM(coln) \ )==""\n"
377R"==(for (int cc = 0; cc < UNROLL_N / 4; cc++) { \ )==""\n"
378R"==(sums[4 * cc + 0] \ )==""\n"
379R"==(+= sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s0))); \ )==""\n"
380R"==(sums[4 * cc + 1] \ )==""\n"
381R"==(+= sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s1))); \ )==""\n"
382R"==(sums[4 * cc + 2] \ )==""\n"
383R"==(+= sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s2))); \ )==""\n"
384R"==(sums[4 * cc + 3] \ )==""\n"
385R"==(+= sum(CONVERT_SUM_T(AS_SIGNED_ELEMENT(coln[cc].s3))); \ )==""\n"
386R"==(} )==""\n"
// Kernel: lane lid loads row lid of b into cols and row lid + sg into cols2
// (sg is the sub-group size), using plain vector loads when at least 2 * sg
// rows and UNROLL_N columns remain and zero plus PARTIAL_LOAD otherwise.
// For two byte elements the two row sets are interleaved with REPACK_CC2.
// For one byte elements krem is reduced by 2 * sg, rows lid + 2 * sg and
// lid + 3 * sg are loaded the same way into cols3 and cols4, and all four
// are interleaved with REPACK_CC4. Each colgroups[cc] is block-written to
// b_packed + cc * 4 * UNROLL_K. The COPY_SUM tail references cols3 and cols4,
// which are declared only in the one byte branch.
387R"==(__attribute__((intel_reqd_workgroup_walk_order(1, 0))) )==""\n"
388R"==(__attribute__((intel_reqd_sub_group_size(8))) kernel void )==""\n"
389R"==(xe_hp_systolic_gemm_copy(long k, long n, global ELEMENT *b, long offsetb, )==""\n"
390R"==(long ldb, global ELEMENT *b_packed, int offsetb_packed, )==""\n"
391R"==(int ldb_packed) { )==""\n"
392R"==(int lid = get_sub_group_local_id(); )==""\n"
393R"==(uint k0 = (sub_group_broadcast(get_global_id(0), 0) / 8) * UNROLL_K; )==""\n"
394R"==(uint n0 = get_global_id(1) * UNROLL_N; )==""\n"
395R"==(int krem = k - k0; )==""\n"
396R"==(int nrem = n - n0; )==""\n"
397R"==(int sg = get_sub_group_size(); )==""\n"
398R"==(if (nrem <= 0 || krem <= 0) return; )==""\n"
399R"==(GET_B_SUM_ADDRESS; )==""\n"
400R"==(b += offsetb + n0 + k0 * ldb; )==""\n"
401R"==(b_packed += offsetb_packed + n0 * ldb_packed + k0 * UNROLL_N; )==""\n"
402R"==(ELEMENT_INT cols[UNROLL_N / ELEMENTS_PER_INT]; )==""\n"
403R"==(ELEMENT_INT cols2[UNROLL_N / ELEMENTS_PER_INT]; )==""\n"
404R"==(ELEMENT_INT4 colgroups[UNROLL_N / 4]; )==""\n"
405R"==(if (krem >= 2 * sg && nrem >= UNROLL_N) { )==""\n"
406R"==(for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) { )==""\n"
407R"==(cols[cc] = VLOAD_ELEMENT_INT( )==""\n"
408R"==(0, b + cc * ELEMENTS_PER_INT + lid * ldb); )==""\n"
409R"==(cols2[cc] = VLOAD_ELEMENT_INT( )==""\n"
410R"==(0, b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb); )==""\n"
411R"==(} )==""\n"
412R"==(} else { )==""\n"
413R"==(for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) { )==""\n"
414R"==(cols[cc] = 0; )==""\n"
415R"==(cols2[cc] = 0; )==""\n"
416R"==(PARTIAL_LOAD(cols, krem, nrem, cc, )==""\n"
417R"==(b + cc * ELEMENTS_PER_INT + lid * ldb); )==""\n"
418R"==(PARTIAL_LOAD(cols2, krem - sg, nrem, cc, )==""\n"
419R"==(b + cc * ELEMENTS_PER_INT + (lid + sg) * ldb); )==""\n"
420R"==(} )==""\n"
421R"==(} )==""\n"
422R"==(#if ELEMENT_SIZE == 2 )==""\n"
423R"==(for (int cc = 0; cc < UNROLL_N / 4; cc++) )==""\n"
424R"==(REPACK_CC2(cc); )==""\n"
425R"==(#else )==""\n"
426R"==(ELEMENT_INT cols3[UNROLL_N / ELEMENTS_PER_INT]; )==""\n"
427R"==(ELEMENT_INT cols4[UNROLL_N / ELEMENTS_PER_INT]; )==""\n"
428R"==(krem -= 2 * sg; )==""\n"
429R"==(if (krem >= 2 * sg && nrem >= UNROLL_N) { )==""\n"
430R"==(for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) { )==""\n"
431R"==(cols3[cc] = VLOAD_ELEMENT_INT( )==""\n"
432R"==(0, b + cc * ELEMENTS_PER_INT + (lid + 2 * sg) * ldb); )==""\n"
433R"==(cols4[cc] = VLOAD_ELEMENT_INT( )==""\n"
434R"==(0, b + cc * ELEMENTS_PER_INT + (lid + 3 * sg) * ldb); )==""\n"
435R"==(} )==""\n"
436R"==(} else { )==""\n"
437R"==(for (int cc = 0; cc < UNROLL_N / ELEMENTS_PER_INT; cc++) { )==""\n"
438R"==(cols3[cc] = 0; )==""\n"
439R"==(cols4[cc] = 0; )==""\n"
440R"==(PARTIAL_LOAD(cols3, krem, nrem, cc, )==""\n"
441R"==(b + cc * ELEMENTS_PER_INT + (lid + 2 * sg) * ldb); )==""\n"
442R"==(PARTIAL_LOAD(cols4, krem - sg, nrem, cc, )==""\n"
443R"==(b + cc * ELEMENTS_PER_INT + (lid + 3 * sg) * ldb); )==""\n"
444R"==(} )==""\n"
445R"==(} )==""\n"
446R"==(for (int cc = 0; cc < UNROLL_N / 4; cc++) )==""\n"
447R"==(REPACK_CC4(cc); )==""\n"
448R"==(#endif )==""\n"
449R"==(for (int cc = 0; cc < UNROLL_N / 4; cc++) )==""\n"
450R"==(BLOCK_WRITE_ELEMENT_INT4(b_packed + cc * 4 * UNROLL_K, colgroups[cc]); )==""\n"
451R"==(#if COPY_SUM )==""\n"
452R"==(SUM_T sums[UNROLL_N] = {0}; )==""\n"
453R"==(ADD_SUM(cols); )==""\n"
454R"==(ADD_SUM(cols2); )==""\n"
455R"==(ADD_SUM(cols3); )==""\n"
456R"==(ADD_SUM(cols4); )==""\n"
457R"==(for (int c0 = 0; c0 < UNROLL_N; c0 += get_sub_group_size()) )==""\n"
458R"==(atomic_add(b_sum + c0 + lid, sums[c0 + lid]); )==""\n"
459R"==(#endif )==""\n"
460R"==(DUMMY_DPAS; )==""\n"
461R"==(} )==""\n"
462R"==(#endif /* !COPY_TRANS */ )==""\n"
463R"==(#endif /* COPY_B */ )==""\n"
// Empty final literal; it carries the terminating semicolon of the
// definition.
464R"==()==";
465}
466}
467}
468}