1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// OpenCL C source text for the gen9_gemm_nocopy_f32 kernel, stored as one C string.
// The string is assembled from adjacent raw string literals, one per kernel source
// line, each followed by an explicit newline literal. Everything between the raw
// string delimiters is runtime data handed to the OpenCL compiler, so it must not be
// edited for cosmetic reasons.
// Exactly one of four kernel variants (NN, NT, TN, TT) is compiled, chosen by whether
// TRANS_A and TRANS_B are defined when the kernel text is built.
// NOTE(review): the license header line that normally carries the license URL ends
// right after the scheme; presumably the generator stripped the rest. Verify against
// the original kernel source.
5const char *gen9_gemm_nocopy_f32_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/gemm/ocl_gemm_attrs.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
22R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// The kernel refuses to build unless DT_F32 is defined to 1.
23R"==(#if DT_F32 != 1 )==""\n"
24R"==(#error "Incorrect datatype." )==""\n"
25R"==(#endif )==""\n"
// DO_FMA_NN / DO_FMA_NT: one multiply-accumulate into a single component of c.
// The A operand is broadcast across the sub-group from lane i_mod_16; the B operand
// is component hh of a vector (NN) or element hh of an array (NT).
26R"==(#define DO_FMA_NN(hh, i_mod_16, i_div_16, i_mod_4, i_div_4) \ )==""\n"
27R"==(do { \ )==""\n"
28R"==(c[i_div_4].s##i_mod_4 \ )==""\n"
29R"==(= mad(sub_group_broadcast(a[hh].s##i_div_16, i_mod_16), \ )==""\n"
30R"==(b.s##hh, c[i_div_4].s##i_mod_4); \ )==""\n"
31R"==(} while (0) )==""\n"
32R"==(#define DO_FMA_NT(hh, i_mod_16, i_div_16, i_mod_4, i_div_4) \ )==""\n"
33R"==(do { \ )==""\n"
34R"==(c[i_div_4].s##i_mod_4 \ )==""\n"
35R"==(= mad(sub_group_broadcast(a[hh].s##i_div_16, i_mod_16), b[hh], \ )==""\n"
36R"==(c[i_div_4].s##i_mod_4); \ )==""\n"
37R"==(} while (0) )==""\n"
// DO_FMA_TN / DO_FMA_TT: four-argument form. Each invocation updates two
// accumulators, c[..][0] and c[..][1], from the same broadcast A value.
38R"==(#define DO_FMA_TN(hh, i, i_mod_4, i_div_4) \ )==""\n"
39R"==(do { \ )==""\n"
40R"==(c[i_div_4][0].s##i_mod_4 = mad(sub_group_broadcast(a.s##hh, i), \ )==""\n"
41R"==(b[0].s##hh, c[i_div_4][0].s##i_mod_4); \ )==""\n"
42R"==(c[i_div_4][1].s##i_mod_4 = mad(sub_group_broadcast(a.s##hh, i), \ )==""\n"
43R"==(b[1].s##hh, c[i_div_4][1].s##i_mod_4); \ )==""\n"
44R"==(} while (0) )==""\n"
45R"==(#define DO_FMA_TT(hh, i, i_mod_4, i_div_4) \ )==""\n"
46R"==(do { \ )==""\n"
47R"==(c[i_div_4][0].s##i_mod_4 = mad(sub_group_broadcast(a.s##hh, i), \ )==""\n"
48R"==(b[hh].s0, c[i_div_4][0].s##i_mod_4); \ )==""\n"
49R"==(c[i_div_4][1].s##i_mod_4 = mad(sub_group_broadcast(a.s##hh, i), \ )==""\n"
50R"==(b[hh].s1, c[i_div_4][1].s##i_mod_4); \ )==""\n"
51R"==(} while (0) )==""\n"
// Variant selection: TRANS_A and TRANS_B (defined or not) pick one of NN, NT, TN, TT
// and bind DO_FMA to the matching macro above.
52R"==(#if !defined(TRANS_A) )==""\n"
53R"==(#if !defined(TRANS_B) )==""\n"
54R"==(#define NN )==""\n"
55R"==(#define DO_FMA DO_FMA_NN )==""\n"
56R"==(#else )==""\n"
57R"==(#define NT )==""\n"
58R"==(#define DO_FMA DO_FMA_NT )==""\n"
59R"==(#endif )==""\n"
60R"==(#else )==""\n"
61R"==(#if !defined(TRANS_B) )==""\n"
62R"==(#define TN )==""\n"
63R"==(#define DO_FMA DO_FMA_TN )==""\n"
64R"==(#else )==""\n"
65R"==(#define TT )==""\n"
66R"==(#define DO_FMA DO_FMA_TT )==""\n"
67R"==(#endif )==""\n"
68R"==(#endif )==""\n"
// POST_OP: with WITH_ELTWISE == 1 it passes val through fwd_eltwise, but only when
// both last_k_block and last_k_unroll are nonzero; otherwise it expands to nothing.
69R"==(#if WITH_ELTWISE == 1 )==""\n"
70R"==(#define POST_OP(val) \ )==""\n"
71R"==(do { \ )==""\n"
72R"==(if (last_k_block && last_k_unroll) \ )==""\n"
73R"==(val = fwd_eltwise( \ )==""\n"
74R"==(val, eltwise_alpha, eltwise_beta, eltwise_scale); \ )==""\n"
75R"==(} while (0) )==""\n"
76R"==(#else )==""\n"
77R"==(#define POST_OP(val) )==""\n"
78R"==(#endif )==""\n"
// FMA_I_LOOP_32_ROW: fully unrolled sequence of 32 DO_FMA steps for one k index hh.
// It uses the five-argument DO_FMA form and is invoked by the NN and NT variants.
// The arguments are i mod 16, i div 16, i mod 4 and i div 4 for i = 0..31.
79R"==(#define FMA_I_LOOP_32_ROW(hh) \ )==""\n"
80R"==(do { \ )==""\n"
81R"==(DO_FMA(hh, 0, 0, 0, 0); \ )==""\n"
82R"==(DO_FMA(hh, 1, 0, 1, 0); \ )==""\n"
83R"==(DO_FMA(hh, 2, 0, 2, 0); \ )==""\n"
84R"==(DO_FMA(hh, 3, 0, 3, 0); \ )==""\n"
85R"==(DO_FMA(hh, 4, 0, 0, 1); \ )==""\n"
86R"==(DO_FMA(hh, 5, 0, 1, 1); \ )==""\n"
87R"==(DO_FMA(hh, 6, 0, 2, 1); \ )==""\n"
88R"==(DO_FMA(hh, 7, 0, 3, 1); \ )==""\n"
89R"==(DO_FMA(hh, 8, 0, 0, 2); \ )==""\n"
90R"==(DO_FMA(hh, 9, 0, 1, 2); \ )==""\n"
91R"==(DO_FMA(hh, 10, 0, 2, 2); \ )==""\n"
92R"==(DO_FMA(hh, 11, 0, 3, 2); \ )==""\n"
93R"==(DO_FMA(hh, 12, 0, 0, 3); \ )==""\n"
94R"==(DO_FMA(hh, 13, 0, 1, 3); \ )==""\n"
95R"==(DO_FMA(hh, 14, 0, 2, 3); \ )==""\n"
96R"==(DO_FMA(hh, 15, 0, 3, 3); \ )==""\n"
97R"==(DO_FMA(hh, 16, 1, 0, 4); \ )==""\n"
98R"==(DO_FMA(hh, 17, 1, 1, 4); \ )==""\n"
99R"==(DO_FMA(hh, 18, 1, 2, 4); \ )==""\n"
100R"==(DO_FMA(hh, 19, 1, 3, 4); \ )==""\n"
101R"==(DO_FMA(hh, 20, 1, 0, 5); \ )==""\n"
102R"==(DO_FMA(hh, 21, 1, 1, 5); \ )==""\n"
103R"==(DO_FMA(hh, 22, 1, 2, 5); \ )==""\n"
104R"==(DO_FMA(hh, 23, 1, 3, 5); \ )==""\n"
105R"==(DO_FMA(hh, 24, 1, 0, 6); \ )==""\n"
106R"==(DO_FMA(hh, 25, 1, 1, 6); \ )==""\n"
107R"==(DO_FMA(hh, 26, 1, 2, 6); \ )==""\n"
108R"==(DO_FMA(hh, 27, 1, 3, 6); \ )==""\n"
109R"==(DO_FMA(hh, 28, 1, 0, 7); \ )==""\n"
110R"==(DO_FMA(hh, 29, 1, 1, 7); \ )==""\n"
111R"==(DO_FMA(hh, 30, 1, 2, 7); \ )==""\n"
112R"==(DO_FMA(hh, 31, 1, 3, 7); \ )==""\n"
113R"==(} while (0) )==""\n"
// FMA_I_LOOP_16_ROW: fully unrolled sequence of 16 DO_FMA steps for one k index hh.
// It uses the four-argument DO_FMA form and is invoked by the TN and TT variants.
// The arguments are i, i mod 4 and i div 4 for i = 0..15.
114R"==(#define FMA_I_LOOP_16_ROW(hh) \ )==""\n"
115R"==(do { \ )==""\n"
116R"==(DO_FMA(hh, 0, 0, 0); \ )==""\n"
117R"==(DO_FMA(hh, 1, 1, 0); \ )==""\n"
118R"==(DO_FMA(hh, 2, 2, 0); \ )==""\n"
119R"==(DO_FMA(hh, 3, 3, 0); \ )==""\n"
120R"==(DO_FMA(hh, 4, 0, 1); \ )==""\n"
121R"==(DO_FMA(hh, 5, 1, 1); \ )==""\n"
122R"==(DO_FMA(hh, 6, 2, 1); \ )==""\n"
123R"==(DO_FMA(hh, 7, 3, 1); \ )==""\n"
124R"==(DO_FMA(hh, 8, 0, 2); \ )==""\n"
125R"==(DO_FMA(hh, 9, 1, 2); \ )==""\n"
126R"==(DO_FMA(hh, 10, 2, 2); \ )==""\n"
127R"==(DO_FMA(hh, 11, 3, 2); \ )==""\n"
128R"==(DO_FMA(hh, 12, 0, 3); \ )==""\n"
129R"==(DO_FMA(hh, 13, 1, 3); \ )==""\n"
130R"==(DO_FMA(hh, 14, 2, 3); \ )==""\n"
131R"==(DO_FMA(hh, 15, 3, 3); \ )==""\n"
132R"==(} while (0) )==""\n"
// UPDATE_C_ROW: stores one output element through C and then always advances C.
// The stored value is ATTR_ALPHA times the accumulator component, plus beta times the
// current *C unless betaZero is nonzero, followed by POST_OP. The store only happens
// when jrem > 0 and irem > i.
133R"==(#define UPDATE_C_ROW(i, ii, betaZero) \ )==""\n"
134R"==(do { \ )==""\n"
135R"==(if (jrem > 0) \ )==""\n"
136R"==(if (irem > i) { \ )==""\n"
137R"==(float val = ATTR_ALPHA * c[i / 4].s##ii \ )==""\n"
138R"==(+ ((betaZero) ? 0 : beta * *C); \ )==""\n"
139R"==(POST_OP(val); \ )==""\n"
140R"==(*C = val; \ )==""\n"
141R"==(} \ )==""\n"
142R"==(C++; \ )==""\n"
143R"==(} while (0) )==""\n"
// UPDATE_C_ROW_2X: same update for two output pointers, C_ptrs[0] and C_ptrs[1].
// The first store needs jrem > 0, the second needs jrem > 16; both need irem > i.
// Both pointers are advanced regardless of the guards.
144R"==(#define UPDATE_C_ROW_2X(i, ii, betaZero) \ )==""\n"
145R"==(do { \ )==""\n"
146R"==(if (irem > i) { \ )==""\n"
147R"==(if (jrem > 0) { \ )==""\n"
148R"==(float val = ATTR_ALPHA * c[i / 4][0].s##ii \ )==""\n"
149R"==(+ ((betaZero) ? 0 : beta * *(C_ptrs[0])); \ )==""\n"
150R"==(POST_OP(val); \ )==""\n"
151R"==(*(C_ptrs[0]) = val; \ )==""\n"
152R"==(} \ )==""\n"
153R"==(if (jrem > 16) { \ )==""\n"
154R"==(float val = ATTR_ALPHA * c[i / 4][1].s##ii \ )==""\n"
155R"==(+ ((betaZero) ? 0 : beta * *(C_ptrs[1])); \ )==""\n"
156R"==(POST_OP(val); \ )==""\n"
157R"==(*(C_ptrs[1]) = val; \ )==""\n"
158R"==(} \ )==""\n"
159R"==(} \ )==""\n"
160R"==(C_ptrs[0]++; \ )==""\n"
161R"==(C_ptrs[1]++; \ )==""\n"
162R"==(} while (0) )==""\n"
// UPDATE_C_32_ROW: 32 consecutive UPDATE_C_ROW steps (i = 0..31, ii = i mod 4).
163R"==(#define UPDATE_C_32_ROW(betaZero) \ )==""\n"
164R"==(do { \ )==""\n"
165R"==(UPDATE_C_ROW(0, 0, betaZero); \ )==""\n"
166R"==(UPDATE_C_ROW(1, 1, betaZero); \ )==""\n"
167R"==(UPDATE_C_ROW(2, 2, betaZero); \ )==""\n"
168R"==(UPDATE_C_ROW(3, 3, betaZero); \ )==""\n"
169R"==(UPDATE_C_ROW(4, 0, betaZero); \ )==""\n"
170R"==(UPDATE_C_ROW(5, 1, betaZero); \ )==""\n"
171R"==(UPDATE_C_ROW(6, 2, betaZero); \ )==""\n"
172R"==(UPDATE_C_ROW(7, 3, betaZero); \ )==""\n"
173R"==(UPDATE_C_ROW(8, 0, betaZero); \ )==""\n"
174R"==(UPDATE_C_ROW(9, 1, betaZero); \ )==""\n"
175R"==(UPDATE_C_ROW(10, 2, betaZero); \ )==""\n"
176R"==(UPDATE_C_ROW(11, 3, betaZero); \ )==""\n"
177R"==(UPDATE_C_ROW(12, 0, betaZero); \ )==""\n"
178R"==(UPDATE_C_ROW(13, 1, betaZero); \ )==""\n"
179R"==(UPDATE_C_ROW(14, 2, betaZero); \ )==""\n"
180R"==(UPDATE_C_ROW(15, 3, betaZero); \ )==""\n"
181R"==(UPDATE_C_ROW(16, 0, betaZero); \ )==""\n"
182R"==(UPDATE_C_ROW(17, 1, betaZero); \ )==""\n"
183R"==(UPDATE_C_ROW(18, 2, betaZero); \ )==""\n"
184R"==(UPDATE_C_ROW(19, 3, betaZero); \ )==""\n"
185R"==(UPDATE_C_ROW(20, 0, betaZero); \ )==""\n"
186R"==(UPDATE_C_ROW(21, 1, betaZero); \ )==""\n"
187R"==(UPDATE_C_ROW(22, 2, betaZero); \ )==""\n"
188R"==(UPDATE_C_ROW(23, 3, betaZero); \ )==""\n"
189R"==(UPDATE_C_ROW(24, 0, betaZero); \ )==""\n"
190R"==(UPDATE_C_ROW(25, 1, betaZero); \ )==""\n"
191R"==(UPDATE_C_ROW(26, 2, betaZero); \ )==""\n"
192R"==(UPDATE_C_ROW(27, 3, betaZero); \ )==""\n"
193R"==(UPDATE_C_ROW(28, 0, betaZero); \ )==""\n"
194R"==(UPDATE_C_ROW(29, 1, betaZero); \ )==""\n"
195R"==(UPDATE_C_ROW(30, 2, betaZero); \ )==""\n"
196R"==(UPDATE_C_ROW(31, 3, betaZero); \ )==""\n"
197R"==(} while (0) )==""\n"
// UPDATE_C_16_ROW: 16 consecutive UPDATE_C_ROW_2X steps (i = 0..15, ii = i mod 4).
198R"==(#define UPDATE_C_16_ROW(betaZero) \ )==""\n"
199R"==(do { \ )==""\n"
200R"==(UPDATE_C_ROW_2X(0, 0, betaZero); \ )==""\n"
201R"==(UPDATE_C_ROW_2X(1, 1, betaZero); \ )==""\n"
202R"==(UPDATE_C_ROW_2X(2, 2, betaZero); \ )==""\n"
203R"==(UPDATE_C_ROW_2X(3, 3, betaZero); \ )==""\n"
204R"==(UPDATE_C_ROW_2X(4, 0, betaZero); \ )==""\n"
205R"==(UPDATE_C_ROW_2X(5, 1, betaZero); \ )==""\n"
206R"==(UPDATE_C_ROW_2X(6, 2, betaZero); \ )==""\n"
207R"==(UPDATE_C_ROW_2X(7, 3, betaZero); \ )==""\n"
208R"==(UPDATE_C_ROW_2X(8, 0, betaZero); \ )==""\n"
209R"==(UPDATE_C_ROW_2X(9, 1, betaZero); \ )==""\n"
210R"==(UPDATE_C_ROW_2X(10, 2, betaZero); \ )==""\n"
211R"==(UPDATE_C_ROW_2X(11, 3, betaZero); \ )==""\n"
212R"==(UPDATE_C_ROW_2X(12, 0, betaZero); \ )==""\n"
213R"==(UPDATE_C_ROW_2X(13, 1, betaZero); \ )==""\n"
214R"==(UPDATE_C_ROW_2X(14, 2, betaZero); \ )==""\n"
215R"==(UPDATE_C_ROW_2X(15, 3, betaZero); \ )==""\n"
216R"==(} while (0) )==""\n"
// ---- NN variant (neither TRANS_A nor TRANS_B defined) ----
// Required sub-group size is 16. The row offset is i0 = idM * 32 and the column
// offset is j0 = idN. With WITH_K_UNROLL the kernel takes two extra arguments,
// flag and offset_f.
217R"==(#ifdef NN )==""\n"
218R"==(__attribute__((intel_reqd_sub_group_size(16))) )==""\n"
219R"==(kernel void )==""\n"
220R"==(gen9_gemm_nocopy_f32(global float *A, global float *B, global float *C, )==""\n"
221R"==(long offset_a, long offset_b, long offset_c, int lda, int ldb, int ldc, )==""\n"
222R"==(int m, int n, int k, global float *alpha, float beta, int last_k_block, )==""\n"
223R"==(float eltwise_alpha, float eltwise_beta, float eltwise_scale )==""\n"
224R"==(#ifdef WITH_K_UNROLL )==""\n"
225R"==(, )==""\n"
226R"==(volatile global int *flag, long offset_f) { )==""\n"
227R"==(#else )==""\n"
228R"==() { )==""\n"
229R"==(#endif )==""\n"
230R"==(float2 a[4]; )==""\n"
231R"==(float4 b; )==""\n"
232R"==(float4 c[8]; )==""\n"
233R"==(int idM = get_global_id(1); )==""\n"
234R"==(int idN = get_global_id(0); )==""\n"
235R"==(int lid = get_sub_group_local_id(); )==""\n"
236R"==(int idK = get_global_id(2); )==""\n"
237R"==(int nku = get_global_size(2); )==""\n"
238R"==(int i0 = idM * 32; )==""\n"
239R"==(int j0 = idN; )==""\n"
// irem and jrem are the remaining extents from this tile origin, clamped at zero.
240R"==(int irem = m - i0; )==""\n"
241R"==(int jrem = n - j0; )==""\n"
242R"==(if (irem < 0) irem = 0; )==""\n"
243R"==(if (jrem < 0) jrem = 0; )==""\n"
244R"==(int last_k_unroll = (idK == nku - 1); )==""\n"
// With WITH_K_UNROLL each idK handles a slice of k starting at idK * UNROLL_K; the
// slice length kt is clamped to the range 0..UNROLL_K and replaces k.
245R"==(#ifdef WITH_K_UNROLL )==""\n"
246R"==(int k0 = idK * UNROLL_K; )==""\n"
247R"==(int kt = k - k0; )==""\n"
248R"==(if (kt < 0) kt = 0; )==""\n"
249R"==(if (kt > UNROLL_K) kt = UNROLL_K; )==""\n"
250R"==(A += offset_a + i0 + k0 * lda; )==""\n"
251R"==(B += offset_b + j0 * ldb + k0; )==""\n"
252R"==(k = kt; )==""\n"
253R"==(#else )==""\n"
254R"==(A += offset_a + i0; )==""\n"
255R"==(B += offset_b + j0 * ldb; )==""\n"
256R"==(#endif )==""\n"
257R"==(C += offset_c + i0 + j0 * ldc; )==""\n"
// Four A pointers spaced lda apart. ldbx4 is declared here but not referenced
// anywhere else in this variant.
258R"==(global float *A_cols[4] = {A, A + lda, A + 2 * lda, A + 3 * lda}; )==""\n"
259R"==(int ldax4 = lda << 2; )==""\n"
260R"==(int ldbx4 = ldb << 2; )==""\n"
261R"==(for (int z = 0; z < 8; z++) )==""\n"
262R"==(c[z] = 0.f; )==""\n"
// The work-item with idK == 0 and sub-group local id 0 resets the flag to 0.
263R"==(#ifdef WITH_K_UNROLL )==""\n"
264R"==(flag += offset_f + idM; )==""\n"
265R"==(if (idK == 0 && lid == 0) *flag = 0; )==""\n"
266R"==(#endif )==""\n"
// Unless ALLOW_READ_OVERRUNS is defined, the unguarded block-read path below is taken
// only when irem >= 32 and lane 0 has jrem >= 16; otherwise the guarded path in the
// else branch is used.
267R"==(#ifndef ALLOW_READ_OVERRUNS )==""\n"
268R"==(if (irem >= 32 && sub_group_broadcast(jrem, 0) >= 16) { )==""\n"
269R"==(#endif )==""\n"
// Main loop: four k steps per iteration.
270R"==(for (int h = 0; h < (k >> 2); h++) { )==""\n"
271R"==(for (int hh = 0; hh < 4; hh++) { )==""\n"
272R"==(a[hh] = as_float2( )==""\n"
273R"==(intel_sub_group_block_read2((global uint *)A_cols[hh])); )==""\n"
274R"==(A_cols[hh] += ldax4; )==""\n"
275R"==(} )==""\n"
276R"==(b = vload4(0, B); )==""\n"
277R"==(B += 4; )==""\n"
278R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
279R"==(FMA_I_LOOP_32_ROW(1); )==""\n"
280R"==(FMA_I_LOOP_32_ROW(2); )==""\n"
281R"==(FMA_I_LOOP_32_ROW(3); )==""\n"
282R"==(} )==""\n"
// Remainder: the last k & 3 steps, one at a time, using only A_cols[0].
283R"==(int krem = k & 3; )==""\n"
284R"==(for (int h = 0; h < krem; h++) { )==""\n"
285R"==(a[0] = as_float2( )==""\n"
286R"==(intel_sub_group_block_read2((global uint *)A_cols[0])); )==""\n"
287R"==(A_cols[0] += lda; )==""\n"
288R"==(b = *B++; )==""\n"
289R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
290R"==(} )==""\n"
// Guarded path: per-lane scalar loads conditioned on irem and jrem, two k steps per
// iteration plus one extra step when k is odd.
291R"==(#ifndef ALLOW_READ_OVERRUNS )==""\n"
292R"==(} else { )==""\n"
293R"==(for (int h = 0; h < (k >> 1); h++) { )==""\n"
294R"==(for (int hh = 0; hh < 2; hh++) { )==""\n"
295R"==(if (irem > lid) a[hh].s0 = A_cols[hh][lid]; )==""\n"
296R"==(if (irem > (lid + 16)) a[hh].s1 = A_cols[hh][lid + 16]; )==""\n"
297R"==(A_cols[hh] += (lda << 1); )==""\n"
298R"==(} )==""\n"
299R"==(if (jrem > 0) b.s01 = vload2(0, B); )==""\n"
300R"==(B += 2; )==""\n"
301R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
302R"==(FMA_I_LOOP_32_ROW(1); )==""\n"
303R"==(} )==""\n"
304R"==(if (k & 1) { )==""\n"
305R"==(if (irem > lid) a[0].s0 = A_cols[0][lid]; )==""\n"
306R"==(if (irem > (lid + 16)) a[0].s1 = A_cols[0][lid + 16]; )==""\n"
307R"==(if (jrem > 0) b = *B; )==""\n"
308R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
309R"==(} )==""\n"
310R"==(} )==""\n"
311R"==(#endif )==""\n"
// Output stage. With WITH_K_UNROLL the work-item spins until the flag equals its
// idK. idK 0 then applies the caller beta (skipping the read of C when beta == 0),
// while every later idK forces beta to 1.0 so its result is added to what is already
// in C. Finally lane 0 writes idK + 1 to the flag.
312R"==(#ifdef WITH_K_UNROLL )==""\n"
313R"==(do { )==""\n"
314R"==(read_mem_fence(CLK_GLOBAL_MEM_FENCE); )==""\n"
315R"==(} while (*flag != idK); )==""\n"
316R"==(if (idK == 0) { )==""\n"
317R"==(if (beta == 0) )==""\n"
318R"==(UPDATE_C_32_ROW(1); )==""\n"
319R"==(else )==""\n"
320R"==(UPDATE_C_32_ROW(0); )==""\n"
321R"==(} else { )==""\n"
322R"==(beta = 1.0; )==""\n"
323R"==(UPDATE_C_32_ROW(0); )==""\n"
324R"==(} )==""\n"
325R"==(if (lid == 0) *flag = idK + 1; )==""\n"
326R"==(#else )==""\n"
327R"==(if (beta == 0) )==""\n"
328R"==(UPDATE_C_32_ROW(1); )==""\n"
329R"==(else )==""\n"
330R"==(UPDATE_C_32_ROW(0); )==""\n"
331R"==(#endif )==""\n"
332R"==(} )==""\n"
333R"==(#endif )==""\n"
// ---- NT variant (TRANS_B defined, TRANS_A not defined) ----
// Same signature and tile origin as NN (i0 = idM * 32, j0 = idN). B is read through
// two pointers spaced ldb apart, starting at the column offset of lane 0.
334R"==(#ifdef NT )==""\n"
335R"==(__attribute__((intel_reqd_sub_group_size(16))) )==""\n"
336R"==(kernel void )==""\n"
337R"==(gen9_gemm_nocopy_f32(global float *A, global float *B, global float *C, )==""\n"
338R"==(long offset_a, long offset_b, long offset_c, int lda, int ldb, int ldc, )==""\n"
339R"==(int m, int n, int k, global float *alpha, float beta, int last_k_block, )==""\n"
340R"==(float eltwise_alpha, float eltwise_beta, float eltwise_scale )==""\n"
341R"==(#ifdef WITH_K_UNROLL )==""\n"
342R"==(, )==""\n"
343R"==(volatile global int *flag, long offset_f) { )==""\n"
344R"==(#else )==""\n"
345R"==() { )==""\n"
346R"==(#endif )==""\n"
347R"==(float2 a[2]; )==""\n"
348R"==(float b[2]; )==""\n"
349R"==(float4 c[8]; )==""\n"
350R"==(int idM = get_global_id(1); )==""\n"
351R"==(int idN = get_global_id(0); )==""\n"
352R"==(int lid = get_sub_group_local_id(); )==""\n"
353R"==(int idK = get_global_id(2); )==""\n"
354R"==(int nku = get_global_size(2); )==""\n"
355R"==(int i0 = idM * 32; )==""\n"
356R"==(int j0 = idN; )==""\n"
357R"==(int irem = m - i0; )==""\n"
358R"==(int jrem = n - j0; )==""\n"
359R"==(if (irem < 0) irem = 0; )==""\n"
360R"==(if (jrem < 0) jrem = 0; )==""\n"
361R"==(int last_k_unroll = (idK == nku - 1); )==""\n"
// k slicing for WITH_K_UNROLL, same clamping as in the NN variant.
362R"==(#ifdef WITH_K_UNROLL )==""\n"
363R"==(int k0 = idK * UNROLL_K; )==""\n"
364R"==(int kt = k - k0; )==""\n"
365R"==(if (kt < 0) kt = 0; )==""\n"
366R"==(if (kt > UNROLL_K) kt = UNROLL_K; )==""\n"
367R"==(A += offset_a + i0 + k0 * lda; )==""\n"
368R"==(B += offset_b + sub_group_broadcast(j0, 0) + k0 * ldb; )==""\n"
369R"==(k = kt; )==""\n"
370R"==(#else )==""\n"
371R"==(A += offset_a + i0; )==""\n"
372R"==(B += offset_b + sub_group_broadcast(j0, 0); )==""\n"
373R"==(#endif )==""\n"
374R"==(C += offset_c + i0 + j0 * ldc; )==""\n"
375R"==(global float *A_cols[2] = {A, A + lda}; )==""\n"
376R"==(global float *B_rows[2] = {B, B + ldb}; )==""\n"
377R"==(int ldax2 = lda << 1; )==""\n"
378R"==(int ldbx2 = ldb << 1; )==""\n"
379R"==(for (int z = 0; z < 8; z++) )==""\n"
380R"==(c[z] = 0.f; )==""\n"
381R"==(#ifdef WITH_K_UNROLL )==""\n"
382R"==(flag += offset_f + idM; )==""\n"
383R"==(if (idK == 0 && lid == 0) *flag = 0; )==""\n"
384R"==(#endif )==""\n"
// Unguarded block-read path; taken unconditionally when ALLOW_READ_OVERRUNS is
// defined, otherwise only for tiles with irem >= 32 and lane 0 jrem >= 16.
385R"==(#ifndef ALLOW_READ_OVERRUNS )==""\n"
386R"==(if (irem >= 32 && sub_group_broadcast(jrem, 0) >= 16) { )==""\n"
387R"==(#endif )==""\n"
// Main loop: two k steps per iteration, then one extra step when k is odd.
388R"==(for (int h = 0; h < (k >> 1); h++) { )==""\n"
389R"==(for (int hh = 0; hh < 2; hh++) { )==""\n"
390R"==(a[hh] = as_float2( )==""\n"
391R"==(intel_sub_group_block_read2((global uint *)A_cols[hh])); )==""\n"
392R"==(A_cols[hh] += ldax2; )==""\n"
393R"==(} )==""\n"
394R"==(for (int hh = 0; hh < 2; hh++) { )==""\n"
395R"==(b[hh] = as_float( )==""\n"
396R"==(intel_sub_group_block_read((global uint *)B_rows[hh])); )==""\n"
397R"==(B_rows[hh] += ldbx2; )==""\n"
398R"==(} )==""\n"
399R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
400R"==(FMA_I_LOOP_32_ROW(1); )==""\n"
401R"==(} )==""\n"
402R"==(int krem = k & 1; )==""\n"
403R"==(if (krem > 0) { )==""\n"
404R"==(a[0] = as_float2( )==""\n"
405R"==(intel_sub_group_block_read2((global uint *)A_cols[0])); )==""\n"
406R"==(b[0] = as_float( )==""\n"
407R"==(intel_sub_group_block_read((global uint *)B_rows[0])); )==""\n"
408R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
409R"==(} )==""\n"
// Guarded path: the same loop shape with per-lane loads conditioned on irem and jrem.
410R"==(#ifndef ALLOW_READ_OVERRUNS )==""\n"
411R"==(} else { )==""\n"
412R"==(for (int h = 0; h < (k >> 1); h++) { )==""\n"
413R"==(for (int hh = 0; hh < 2; hh++) { )==""\n"
414R"==(if (irem > lid) a[hh].s0 = A_cols[hh][lid]; )==""\n"
415R"==(if (irem > (lid + 16)) a[hh].s1 = A_cols[hh][lid + 16]; )==""\n"
416R"==(A_cols[hh] += ldax2; )==""\n"
417R"==(} )==""\n"
418R"==(for (int hh = 0; hh < 2; hh++) { )==""\n"
419R"==(if (jrem > 0) b[hh] = B_rows[hh][lid]; )==""\n"
420R"==(B_rows[hh] += ldbx2; )==""\n"
421R"==(} )==""\n"
422R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
423R"==(FMA_I_LOOP_32_ROW(1); )==""\n"
424R"==(} )==""\n"
425R"==(int krem = k & 1; )==""\n"
426R"==(if (krem > 0) { )==""\n"
427R"==(if (irem > lid) a[0].s0 = A_cols[0][lid]; )==""\n"
428R"==(if (irem > (lid + 16)) a[0].s1 = A_cols[0][lid + 16]; )==""\n"
429R"==(if (jrem > 0) b[0] = B_rows[0][lid]; )==""\n"
430R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
431R"==(} )==""\n"
432R"==(} )==""\n"
433R"==(#endif )==""\n"
// Output stage: same flag-ordered update of C as in the NN variant.
434R"==(#ifdef WITH_K_UNROLL )==""\n"
435R"==(do { )==""\n"
436R"==(read_mem_fence(CLK_GLOBAL_MEM_FENCE); )==""\n"
437R"==(} while (*flag != idK); )==""\n"
438R"==(if (idK == 0) { )==""\n"
439R"==(if (beta == 0) )==""\n"
440R"==(UPDATE_C_32_ROW(1); )==""\n"
441R"==(else )==""\n"
442R"==(UPDATE_C_32_ROW(0); )==""\n"
443R"==(} else { )==""\n"
444R"==(beta = 1.0; )==""\n"
445R"==(UPDATE_C_32_ROW(0); )==""\n"
446R"==(} )==""\n"
447R"==(if (lid == 0) *flag = idK + 1; )==""\n"
448R"==(#else )==""\n"
449R"==(if (beta == 0) )==""\n"
450R"==(UPDATE_C_32_ROW(1); )==""\n"
451R"==(else )==""\n"
452R"==(UPDATE_C_32_ROW(0); )==""\n"
453R"==(#endif )==""\n"
454R"==(} )==""\n"
455R"==(#endif )==""\n"
// ---- TN variant (TRANS_A defined, TRANS_B not defined) ----
// The row offset is i0 = idM * 16 and the column offset is j0 = (idN of lane 0) * 2
// + lid. The accumulator is c[4][2]; B and C are each accessed through two pointers
// spaced 16 * ldb and 16 * ldc apart. This variant has no ALLOW_READ_OVERRUNS split:
// every load is guarded by irem or jrem.
456R"==(#ifdef TN )==""\n"
457R"==(__attribute__((intel_reqd_sub_group_size(16))) )==""\n"
458R"==(kernel void )==""\n"
459R"==(gen9_gemm_nocopy_f32(global float *A, global float *B, global float *C, )==""\n"
460R"==(long offset_a, long offset_b, long offset_c, int lda, int ldb, int ldc, )==""\n"
461R"==(int m, int n, int k, global float *alpha, float beta, int last_k_block, )==""\n"
462R"==(float eltwise_alpha, float eltwise_beta, float eltwise_scale )==""\n"
463R"==(#ifdef WITH_K_UNROLL )==""\n"
464R"==(, )==""\n"
465R"==(volatile global int *flag, long offset_f) { )==""\n"
466R"==(#else )==""\n"
467R"==() { )==""\n"
468R"==(#endif )==""\n"
469R"==(float4 a; )==""\n"
470R"==(float4 b[2]; )==""\n"
471R"==(float4 c[4][2]; )==""\n"
472R"==(int idM = get_global_id(1); )==""\n"
473R"==(int idN = get_global_id(0); )==""\n"
474R"==(int lid = get_sub_group_local_id(); )==""\n"
475R"==(int idK = get_global_id(2); )==""\n"
476R"==(int nku = get_global_size(2); )==""\n"
477R"==(int i0 = idM * 16; )==""\n"
478R"==(int j0 = sub_group_broadcast(idN, 0) * 2 + lid; )==""\n"
479R"==(int irem = m - i0; )==""\n"
480R"==(int jrem = n - j0; )==""\n"
481R"==(if (irem < 0) irem = 0; )==""\n"
482R"==(if (jrem < 0) jrem = 0; )==""\n"
483R"==(int last_k_unroll = (idK == nku - 1); )==""\n"
// k slicing for WITH_K_UNROLL; each lane addresses its own row of A via (i0 + lid) * lda.
484R"==(#ifdef WITH_K_UNROLL )==""\n"
485R"==(int k0 = idK * UNROLL_K; )==""\n"
486R"==(int kt = k - k0; )==""\n"
487R"==(if (kt < 0) kt = 0; )==""\n"
488R"==(if (kt > UNROLL_K) kt = UNROLL_K; )==""\n"
489R"==(A += offset_a + (i0 + lid) * lda + k0; )==""\n"
490R"==(B += offset_b + j0 * ldb + k0; )==""\n"
491R"==(k = kt; )==""\n"
492R"==(#else )==""\n"
493R"==(A += offset_a + (i0 + lid) * lda; )==""\n"
494R"==(B += offset_b + j0 * ldb; )==""\n"
495R"==(#endif )==""\n"
496R"==(C += offset_c + i0 + j0 * ldc; )==""\n"
497R"==(global float *B_ptrs[2] = {B, B + 16 * ldb}; )==""\n"
498R"==(for (int ii = 0; ii < 4; ii++) )==""\n"
499R"==(for (int jj = 0; jj < 2; jj++) )==""\n"
500R"==(c[ii][jj] = 0.f; )==""\n"
501R"==(#ifdef WITH_K_UNROLL )==""\n"
502R"==(flag += offset_f + idM; )==""\n"
503R"==(if (idK == 0 && lid == 0) *flag = 0; )==""\n"
504R"==(#endif )==""\n"
// Main loop: four k steps per iteration with guarded vector loads.
505R"==(for (int h = 0; h < (k >> 2); h++) { )==""\n"
506R"==(if (irem > lid) a = vload4(0, A); )==""\n"
507R"==(A += 4; )==""\n"
508R"==(for (int hh = 0; hh < 2; hh++) { )==""\n"
509R"==(if (jrem > hh * 16) b[hh] = vload4(0, B_ptrs[hh]); )==""\n"
510R"==(B_ptrs[hh] += 4; )==""\n"
511R"==(} )==""\n"
512R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
513R"==(FMA_I_LOOP_16_ROW(1); )==""\n"
514R"==(FMA_I_LOOP_16_ROW(2); )==""\n"
515R"==(FMA_I_LOOP_16_ROW(3); )==""\n"
516R"==(} )==""\n"
// Remainder: the last k & 3 steps, one scalar load per operand per step.
517R"==(int krem = k & 3; )==""\n"
518R"==(for (int h = 0; h < krem; h++) { )==""\n"
519R"==(if (irem > lid) a = *A++; )==""\n"
520R"==(for (int hh = 0; hh < 2; hh++) { )==""\n"
521R"==(if (jrem > hh * 16) b[hh] = *B_ptrs[hh]; )==""\n"
522R"==(B_ptrs[hh]++; )==""\n"
523R"==(} )==""\n"
524R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
525R"==(} )==""\n"
// Output stage: the same flag-ordered protocol as the other variants, using the
// two-pointer UPDATE_C_16_ROW macro.
526R"==(global float *C_ptrs[2] = {C, C + 16 * ldc}; )==""\n"
527R"==(#ifdef WITH_K_UNROLL )==""\n"
528R"==(do { )==""\n"
529R"==(read_mem_fence(CLK_GLOBAL_MEM_FENCE); )==""\n"
530R"==(} while (*flag != idK); )==""\n"
531R"==(if (idK == 0) { )==""\n"
532R"==(if (beta == 0) )==""\n"
533R"==(UPDATE_C_16_ROW(1); )==""\n"
534R"==(else )==""\n"
535R"==(UPDATE_C_16_ROW(0); )==""\n"
536R"==(} else { )==""\n"
537R"==(beta = 1.0; )==""\n"
538R"==(UPDATE_C_16_ROW(0); )==""\n"
539R"==(} )==""\n"
540R"==(if (lid == 0) *flag = idK + 1; )==""\n"
541R"==(#else )==""\n"
542R"==(if (beta == 0) )==""\n"
543R"==(UPDATE_C_16_ROW(1); )==""\n"
544R"==(else )==""\n"
545R"==(UPDATE_C_16_ROW(0); )==""\n"
546R"==(#endif )==""\n"
547R"==(} )==""\n"
548R"==(#endif )==""\n"
// ---- TT variant (both TRANS_A and TRANS_B defined) ----
// Tile origin as in TN (i0 = idM * 16, j0 = (idN of lane 0) * 2 + lid). B is read
// through four pointers spaced ldb apart, starting at the column offset of lane 0.
549R"==(#ifdef TT )==""\n"
550R"==(__attribute__((intel_reqd_sub_group_size(16))) )==""\n"
551R"==(kernel void )==""\n"
552R"==(gen9_gemm_nocopy_f32(global float *A, global float *B, global float *C, )==""\n"
553R"==(long offset_a, long offset_b, long offset_c, int lda, int ldb, int ldc, )==""\n"
554R"==(int m, int n, int k, global float *alpha, float beta, int last_k_block, )==""\n"
555R"==(float eltwise_alpha, float eltwise_beta, float eltwise_scale )==""\n"
556R"==(#ifdef WITH_K_UNROLL )==""\n"
557R"==(, )==""\n"
558R"==(volatile global int *flag, long offset_f) { )==""\n"
559R"==(#else )==""\n"
560R"==() { )==""\n"
561R"==(#endif )==""\n"
562R"==(float4 a; )==""\n"
563R"==(float2 b[4]; )==""\n"
564R"==(float4 c[4][2]; )==""\n"
565R"==(int idM = get_global_id(1); )==""\n"
566R"==(int idN = get_global_id(0); )==""\n"
567R"==(int lid = get_sub_group_local_id(); )==""\n"
568R"==(int idK = get_global_id(2); )==""\n"
569R"==(int nku = get_global_size(2); )==""\n"
570R"==(int i0 = idM * 16; )==""\n"
571R"==(int j0 = sub_group_broadcast(idN, 0) * 2 + lid; )==""\n"
572R"==(int irem = m - i0; )==""\n"
573R"==(int jrem = n - j0; )==""\n"
574R"==(if (irem < 0) irem = 0; )==""\n"
575R"==(if (jrem < 0) jrem = 0; )==""\n"
576R"==(int last_k_unroll = (idK == nku - 1); )==""\n"
// k slicing for WITH_K_UNROLL, same clamping as in the other variants.
577R"==(#ifdef WITH_K_UNROLL )==""\n"
578R"==(int k0 = idK * UNROLL_K; )==""\n"
579R"==(int kt = k - k0; )==""\n"
580R"==(if (kt < 0) kt = 0; )==""\n"
581R"==(if (kt > UNROLL_K) kt = UNROLL_K; )==""\n"
582R"==(A += offset_a + (i0 + lid) * lda + k0; )==""\n"
583R"==(B += offset_b + sub_group_broadcast(j0, 0) + k0 * ldb; )==""\n"
584R"==(k = kt; )==""\n"
585R"==(#else )==""\n"
586R"==(A += offset_a + (i0 + lid) * lda; )==""\n"
587R"==(B += offset_b + sub_group_broadcast(j0, 0); )==""\n"
588R"==(#endif )==""\n"
589R"==(C += offset_c + i0 + j0 * ldc; )==""\n"
590R"==(global float *B_rows[4] = {B, B + ldb, B + 2 * ldb, B + 3 * ldb}; )==""\n"
591R"==(int ldbx4 = ldb << 2; )==""\n"
592R"==(for (int ii = 0; ii < 4; ii++) )==""\n"
593R"==(for (int jj = 0; jj < 2; jj++) )==""\n"
594R"==(c[ii][jj] = 0.f; )==""\n"
595R"==(#ifdef WITH_K_UNROLL )==""\n"
596R"==(flag += offset_f + idM; )==""\n"
597R"==(if (idK == 0 && lid == 0) *flag = 0; )==""\n"
598R"==(#endif )==""\n"
// Unguarded path; taken unconditionally when ALLOW_READ_OVERRUNS is defined,
// otherwise only for tiles with irem >= 16 and lane 0 jrem >= 32.
599R"==(#ifndef ALLOW_READ_OVERRUNS )==""\n"
600R"==(if (irem >= 16 && sub_group_broadcast(jrem, 0) >= 32) { )==""\n"
601R"==(#endif )==""\n"
// Main loop: four k steps per iteration, then the last k & 3 steps one at a time.
602R"==(for (int h = 0; h < (k >> 2); h++) { )==""\n"
603R"==(a = vload4(0, A); )==""\n"
604R"==(A += 4; )==""\n"
605R"==(for (int hh = 0; hh < 4; hh++) { )==""\n"
606R"==(b[hh] = as_float2( )==""\n"
607R"==(intel_sub_group_block_read2((global uint *)B_rows[hh])); )==""\n"
608R"==(B_rows[hh] += ldbx4; )==""\n"
609R"==(} )==""\n"
610R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
611R"==(FMA_I_LOOP_16_ROW(1); )==""\n"
612R"==(FMA_I_LOOP_16_ROW(2); )==""\n"
613R"==(FMA_I_LOOP_16_ROW(3); )==""\n"
614R"==(} )==""\n"
615R"==(int krem = k & 3; )==""\n"
616R"==(for (int h = 0; h < krem; h++) { )==""\n"
617R"==(a = *A++; )==""\n"
618R"==(b[0] = as_float2( )==""\n"
619R"==(intel_sub_group_block_read2((global uint *)B_rows[0])); )==""\n"
620R"==(B_rows[0] += ldb; )==""\n"
621R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
622R"==(} )==""\n"
// Guarded path: the same loop shape with per-lane loads conditioned on irem and jrem.
623R"==(#ifndef ALLOW_READ_OVERRUNS )==""\n"
624R"==(} else { )==""\n"
625R"==(for (int h = 0; h < (k >> 2); h++) { )==""\n"
626R"==(if (irem > lid) a = vload4(0, A); )==""\n"
627R"==(A += 4; )==""\n"
628R"==(for (int hh = 0; hh < 4; hh++) { )==""\n"
629R"==(if (jrem > 0) b[hh].s0 = B_rows[hh][lid]; )==""\n"
630R"==(if (jrem > 16) b[hh].s1 = B_rows[hh][lid + 16]; )==""\n"
631R"==(B_rows[hh] += ldbx4; )==""\n"
632R"==(} )==""\n"
633R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
634R"==(FMA_I_LOOP_16_ROW(1); )==""\n"
635R"==(FMA_I_LOOP_16_ROW(2); )==""\n"
636R"==(FMA_I_LOOP_16_ROW(3); )==""\n"
637R"==(} )==""\n"
638R"==(int krem = k & 3; )==""\n"
639R"==(for (int h = 0; h < krem; h++) { )==""\n"
640R"==(if (irem > lid) a = *A++; )==""\n"
641R"==(if (jrem > 0) b[0].s0 = B_rows[0][lid]; )==""\n"
642R"==(if (jrem > 16) b[0].s1 = B_rows[0][lid + 16]; )==""\n"
643R"==(B_rows[0] += ldb; )==""\n"
644R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
645R"==(} )==""\n"
646R"==(} )==""\n"
647R"==(#endif )==""\n"
// Output stage: same flag-ordered protocol as the TN variant.
648R"==(global float *C_ptrs[2] = {C, C + 16 * ldc}; )==""\n"
649R"==(#ifdef WITH_K_UNROLL )==""\n"
650R"==(do { )==""\n"
651R"==(read_mem_fence(CLK_GLOBAL_MEM_FENCE); )==""\n"
652R"==(} while (*flag != idK); )==""\n"
653R"==(if (idK == 0) { )==""\n"
654R"==(if (beta == 0) )==""\n"
655R"==(UPDATE_C_16_ROW(1); )==""\n"
656R"==(else )==""\n"
657R"==(UPDATE_C_16_ROW(0); )==""\n"
658R"==(} else { )==""\n"
659R"==(beta = 1.0; )==""\n"
660R"==(UPDATE_C_16_ROW(0); )==""\n"
661R"==(} )==""\n"
662R"==(if (lid == 0) *flag = idK + 1; )==""\n"
663R"==(#else )==""\n"
664R"==(if (beta == 0) )==""\n"
665R"==(UPDATE_C_16_ROW(1); )==""\n"
666R"==(else )==""\n"
667R"==(UPDATE_C_16_ROW(0); )==""\n"
668R"==(#endif )==""\n"
669R"==(} )==""\n"
670R"==(#endif )==""\n"
671R"==()==";
672}
673}
674}
675}