// Embedded OpenCL C source for the gen9_gemm_nocopy_superkernel_f32 kernel.
// The source text is stored as adjacent raw string literal pieces, each one
// followed by an explicit newline literal, and all pieces are concatenated
// into the single constant gen9_gemm_nocopy_superkernel_f32_kernel.
// Everything inside the raw string pieces is runtime data (kernel source);
// only the C++ comments between the pieces are documentation.
// NOTE(review): the license line that reads http: looks truncated, presumably
// by a comment stripper in the generator - confirm against the original .cl.
1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
5const char *gen9_gemm_nocopy_superkernel_f32_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/gemm/ocl_gemm_attrs.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
22R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// Build guard: the kernel source refuses to compile unless DT_F32 is 1.
23R"==(#if DT_F32 != 1 )==""\n"
24R"==(#error "Only f32 implemented." )==""\n"
25R"==(#endif )==""\n"
// DO_FMA_NN / DO_FMA_NT: one multiply-add into accumulator component
// c[i_div_4].s<i_mod_4>. The A operand is component s<i_div_16> of a[hh],
// broadcast from sub-group lane i_mod_16. The NN form multiplies by component
// s<hh> of the vector b; the NT form multiplies by array element b[hh].
26R"==(#define DO_FMA_NN(hh, i_mod_16, i_div_16, i_mod_4, i_div_4) \ )==""\n"
27R"==(do { \ )==""\n"
28R"==(c[i_div_4].s##i_mod_4 \ )==""\n"
29R"==(= mad(sub_group_broadcast(a[hh].s##i_div_16, i_mod_16), \ )==""\n"
30R"==(b.s##hh, c[i_div_4].s##i_mod_4); \ )==""\n"
31R"==(} while (0) )==""\n"
32R"==(#define DO_FMA_NT(hh, i_mod_16, i_div_16, i_mod_4, i_div_4) \ )==""\n"
33R"==(do { \ )==""\n"
34R"==(c[i_div_4].s##i_mod_4 \ )==""\n"
35R"==(= mad(sub_group_broadcast(a[hh].s##i_div_16, i_mod_16), b[hh], \ )==""\n"
36R"==(c[i_div_4].s##i_mod_4); \ )==""\n"
37R"==(} while (0) )==""\n"
// Variant selection: NN when neither TRANS_A nor TRANS_B is defined, NT when
// only TRANS_B is defined. Defining TRANS_A triggers the #error below.
38R"==(#if !defined(TRANS_A) )==""\n"
39R"==(#if !defined(TRANS_B) )==""\n"
40R"==(#define NN )==""\n"
41R"==(#define DO_FMA DO_FMA_NN )==""\n"
42R"==(#else )==""\n"
43R"==(#define NT )==""\n"
44R"==(#define DO_FMA DO_FMA_NT )==""\n"
45R"==(#endif )==""\n"
46R"==(#else )==""\n"
47R"==(#error "No superkernel implementation." )==""\n"
48R"==(#endif )==""\n"
// FMA_I_LOOP_32_ROW(hh): fully unrolled list of 32 DO_FMA calls, one per row
// index 0..31. The arguments after hh are row % 16, row / 16, row % 4 and
// row / 4, written out as literals so they can be token-pasted.
49R"==(#define FMA_I_LOOP_32_ROW(hh) \ )==""\n"
50R"==(do { \ )==""\n"
51R"==(DO_FMA(hh, 0, 0, 0, 0); \ )==""\n"
52R"==(DO_FMA(hh, 1, 0, 1, 0); \ )==""\n"
53R"==(DO_FMA(hh, 2, 0, 2, 0); \ )==""\n"
54R"==(DO_FMA(hh, 3, 0, 3, 0); \ )==""\n"
55R"==(DO_FMA(hh, 4, 0, 0, 1); \ )==""\n"
56R"==(DO_FMA(hh, 5, 0, 1, 1); \ )==""\n"
57R"==(DO_FMA(hh, 6, 0, 2, 1); \ )==""\n"
58R"==(DO_FMA(hh, 7, 0, 3, 1); \ )==""\n"
59R"==(DO_FMA(hh, 8, 0, 0, 2); \ )==""\n"
60R"==(DO_FMA(hh, 9, 0, 1, 2); \ )==""\n"
61R"==(DO_FMA(hh, 10, 0, 2, 2); \ )==""\n"
62R"==(DO_FMA(hh, 11, 0, 3, 2); \ )==""\n"
63R"==(DO_FMA(hh, 12, 0, 0, 3); \ )==""\n"
64R"==(DO_FMA(hh, 13, 0, 1, 3); \ )==""\n"
65R"==(DO_FMA(hh, 14, 0, 2, 3); \ )==""\n"
66R"==(DO_FMA(hh, 15, 0, 3, 3); \ )==""\n"
67R"==(DO_FMA(hh, 16, 1, 0, 4); \ )==""\n"
68R"==(DO_FMA(hh, 17, 1, 1, 4); \ )==""\n"
69R"==(DO_FMA(hh, 18, 1, 2, 4); \ )==""\n"
70R"==(DO_FMA(hh, 19, 1, 3, 4); \ )==""\n"
71R"==(DO_FMA(hh, 20, 1, 0, 5); \ )==""\n"
72R"==(DO_FMA(hh, 21, 1, 1, 5); \ )==""\n"
73R"==(DO_FMA(hh, 22, 1, 2, 5); \ )==""\n"
74R"==(DO_FMA(hh, 23, 1, 3, 5); \ )==""\n"
75R"==(DO_FMA(hh, 24, 1, 0, 6); \ )==""\n"
76R"==(DO_FMA(hh, 25, 1, 1, 6); \ )==""\n"
77R"==(DO_FMA(hh, 26, 1, 2, 6); \ )==""\n"
78R"==(DO_FMA(hh, 27, 1, 3, 6); \ )==""\n"
79R"==(DO_FMA(hh, 28, 1, 0, 7); \ )==""\n"
80R"==(DO_FMA(hh, 29, 1, 1, 7); \ )==""\n"
81R"==(DO_FMA(hh, 30, 1, 2, 7); \ )==""\n"
82R"==(DO_FMA(hh, 31, 1, 3, 7); \ )==""\n"
83R"==(} while (0) )==""\n"
// FMA_I_LOOP_16_ROW(hh): same pattern restricted to rows 0..15, so it only
// touches a[hh].s0 and accumulators c[0]..c[3].
84R"==(#define FMA_I_LOOP_16_ROW(hh) \ )==""\n"
85R"==(do { \ )==""\n"
86R"==(DO_FMA(hh, 0, 0, 0, 0); \ )==""\n"
87R"==(DO_FMA(hh, 1, 0, 1, 0); \ )==""\n"
88R"==(DO_FMA(hh, 2, 0, 2, 0); \ )==""\n"
89R"==(DO_FMA(hh, 3, 0, 3, 0); \ )==""\n"
90R"==(DO_FMA(hh, 4, 0, 0, 1); \ )==""\n"
91R"==(DO_FMA(hh, 5, 0, 1, 1); \ )==""\n"
92R"==(DO_FMA(hh, 6, 0, 2, 1); \ )==""\n"
93R"==(DO_FMA(hh, 7, 0, 3, 1); \ )==""\n"
94R"==(DO_FMA(hh, 8, 0, 0, 2); \ )==""\n"
95R"==(DO_FMA(hh, 9, 0, 1, 2); \ )==""\n"
96R"==(DO_FMA(hh, 10, 0, 2, 2); \ )==""\n"
97R"==(DO_FMA(hh, 11, 0, 3, 2); \ )==""\n"
98R"==(DO_FMA(hh, 12, 0, 0, 3); \ )==""\n"
99R"==(DO_FMA(hh, 13, 0, 1, 3); \ )==""\n"
100R"==(DO_FMA(hh, 14, 0, 2, 3); \ )==""\n"
101R"==(DO_FMA(hh, 15, 0, 3, 3); \ )==""\n"
102R"==(} while (0) )==""\n"
// POST_OP(val): when WITH_ELTWISE is 1, replaces val with fwd_eltwise(...) but
// only if last_k_block is nonzero; otherwise the macro expands to nothing.
103R"==(#if WITH_ELTWISE == 1 )==""\n"
104R"==(#define POST_OP(val) \ )==""\n"
105R"==(do { \ )==""\n"
106R"==(if (last_k_block) \ )==""\n"
107R"==(val = fwd_eltwise( \ )==""\n"
108R"==(val, eltwise_alpha, eltwise_beta, eltwise_scale); \ )==""\n"
109R"==(} while (0) )==""\n"
110R"==(#else )==""\n"
111R"==(#define POST_OP(val) )==""\n"
112R"==(#endif )==""\n"
// UPDATE_C_ROW(i, ii, betaZero): stores one result element through C when
// jrem > 0 and irem > i. The stored value is ATTR_ALPHA * c[i / 4].s<ii>, plus
// beta * *C unless betaZero is true (C is not read in that case), then passed
// through POST_OP. C is advanced by one element whether or not a store happened.
113R"==(#define UPDATE_C_ROW(i, ii, betaZero) \ )==""\n"
114R"==(do { \ )==""\n"
115R"==(if (jrem > 0) \ )==""\n"
116R"==(if (irem > i) { \ )==""\n"
117R"==(float val = ATTR_ALPHA * c[i / 4].s##ii \ )==""\n"
118R"==(+ ((betaZero) ? 0 : beta * *C); \ )==""\n"
119R"==(POST_OP(val); \ )==""\n"
120R"==(*C = val; \ )==""\n"
121R"==(} \ )==""\n"
122R"==(C++; \ )==""\n"
123R"==(} while (0) )==""\n"
// UPDATE_C_32_ROW(betaZero): unrolled UPDATE_C_ROW for rows 0..31; the second
// argument of each call is row % 4 written as a literal for token pasting.
124R"==(#define UPDATE_C_32_ROW(betaZero) \ )==""\n"
125R"==(do { \ )==""\n"
126R"==(UPDATE_C_ROW(0, 0, betaZero); \ )==""\n"
127R"==(UPDATE_C_ROW(1, 1, betaZero); \ )==""\n"
128R"==(UPDATE_C_ROW(2, 2, betaZero); \ )==""\n"
129R"==(UPDATE_C_ROW(3, 3, betaZero); \ )==""\n"
130R"==(UPDATE_C_ROW(4, 0, betaZero); \ )==""\n"
131R"==(UPDATE_C_ROW(5, 1, betaZero); \ )==""\n"
132R"==(UPDATE_C_ROW(6, 2, betaZero); \ )==""\n"
133R"==(UPDATE_C_ROW(7, 3, betaZero); \ )==""\n"
134R"==(UPDATE_C_ROW(8, 0, betaZero); \ )==""\n"
135R"==(UPDATE_C_ROW(9, 1, betaZero); \ )==""\n"
136R"==(UPDATE_C_ROW(10, 2, betaZero); \ )==""\n"
137R"==(UPDATE_C_ROW(11, 3, betaZero); \ )==""\n"
138R"==(UPDATE_C_ROW(12, 0, betaZero); \ )==""\n"
139R"==(UPDATE_C_ROW(13, 1, betaZero); \ )==""\n"
140R"==(UPDATE_C_ROW(14, 2, betaZero); \ )==""\n"
141R"==(UPDATE_C_ROW(15, 3, betaZero); \ )==""\n"
142R"==(UPDATE_C_ROW(16, 0, betaZero); \ )==""\n"
143R"==(UPDATE_C_ROW(17, 1, betaZero); \ )==""\n"
144R"==(UPDATE_C_ROW(18, 2, betaZero); \ )==""\n"
145R"==(UPDATE_C_ROW(19, 3, betaZero); \ )==""\n"
146R"==(UPDATE_C_ROW(20, 0, betaZero); \ )==""\n"
147R"==(UPDATE_C_ROW(21, 1, betaZero); \ )==""\n"
148R"==(UPDATE_C_ROW(22, 2, betaZero); \ )==""\n"
149R"==(UPDATE_C_ROW(23, 3, betaZero); \ )==""\n"
150R"==(UPDATE_C_ROW(24, 0, betaZero); \ )==""\n"
151R"==(UPDATE_C_ROW(25, 1, betaZero); \ )==""\n"
152R"==(UPDATE_C_ROW(26, 2, betaZero); \ )==""\n"
153R"==(UPDATE_C_ROW(27, 3, betaZero); \ )==""\n"
154R"==(UPDATE_C_ROW(28, 0, betaZero); \ )==""\n"
155R"==(UPDATE_C_ROW(29, 1, betaZero); \ )==""\n"
156R"==(UPDATE_C_ROW(30, 2, betaZero); \ )==""\n"
157R"==(UPDATE_C_ROW(31, 3, betaZero); \ )==""\n"
158R"==(} while (0) )==""\n"
// SUPERKERNEL_PROLOGUE: applies offsetA/offsetB/offsetC to the base pointers
// and OPENS a while (id < threads) loop that SUPERKERNEL_EPILOGUE closes, so
// the two macros must always be used as a pair. The first id is the work-group
// id. For each id, i0 and j0 are read from plan[2 * id + 2] and
// plan[2 * id + 3]; bit 31 of each is extracted into kid0 / kid1 and then
// cleared, and the local id is added to j0.
// NOTE(review): p and kid1 are not used anywhere in the visible source -
// confirm whether they are intentional leftovers.
159R"==(#define SUPERKERNEL_PROLOGUE \ )==""\n"
160R"==(global volatile int *p = plan; \ )==""\n"
161R"==(int id = get_group_id(0); \ )==""\n"
162R"==(\ )==""\n"
163R"==(A0 += offsetA; \ )==""\n"
164R"==(B0 += offsetB; \ )==""\n"
165R"==(C0 += offsetC; \ )==""\n"
166R"==(\ )==""\n"
167R"==(while (id < threads) { \ )==""\n"
168R"==(uint i0, j0; \ )==""\n"
169R"==(uint kid0, kid1; \ )==""\n"
170R"==(\ )==""\n"
171R"==(i0 = plan[2 * id + 2]; \ )==""\n"
172R"==(j0 = plan[2 * id + 3]; \ )==""\n"
173R"==(kid0 = (i0 >> 31); \ )==""\n"
174R"==(kid1 = (j0 >> 31); \ )==""\n"
175R"==(i0 &= ~(1 << 31); \ )==""\n"
176R"==(j0 &= ~(1 << 31); \ )==""\n"
177R"==(j0 += get_local_id(0); )==""\n"
// SUPERKERNEL_EPILOGUE: sub-group lane 0 fetches the next id with
// atomic_inc(plan) and the value is broadcast to the whole sub-group, then the
// while loop from the prologue is closed. After the loop, lane 0 increments
// plan[1]; the caller that sees the previous value get_num_groups(0) - 1
// rewrites plan[0] to get_num_groups(0) and plan[1] to 0.
// NOTE(review): presumably this re-arms the plan for the next launch - confirm
// against the host code that initialises the plan buffer.
178R"==(#define SUPERKERNEL_EPILOGUE \ )==""\n"
179R"==(if (get_sub_group_local_id() == 0) id = atomic_inc(plan); \ )==""\n"
180R"==(\ )==""\n"
181R"==(sub_group_barrier(0); \ )==""\n"
182R"==(id = sub_group_broadcast(id, 0); \ )==""\n"
183R"==(} \ )==""\n"
184R"==(if (get_sub_group_local_id() == 0) { \ )==""\n"
185R"==(if (atomic_inc(plan + 1) == (get_num_groups(0) - 1)) { \ )==""\n"
186R"==(mem_fence(CLK_GLOBAL_MEM_FENCE); \ )==""\n"
187R"==(plan[0] = get_num_groups(0); \ )==""\n"
188R"==(plan[1] = 0; \ )==""\n"
189R"==(} \ )==""\n"
190R"==(} )==""\n"
// NN variant of the kernel, compiled with a required sub-group size of 16.
// Per plan entry: irem / jrem are m - i0 and n - j0 clamped at zero.
// kid0 == 0 selects the 32-row path (block_read2 of A, accumulators c[0..7]);
// otherwise irem is capped at 16 and the 16-row path is used (block_read of A
// into a[j].s0, accumulators c[0..3]). k is consumed four steps at a time with
// a remainder of k & 3. The result is written with UPDATE_C_32_ROW, choosing
// the betaZero form when beta == 0.
// NOTE(review): the remainder paths still load all four A columns and a full
// vload4 of B even when fewer than four k steps remain - presumably the caller
// guarantees those reads are in bounds; confirm.
// NOTE(review): ldbx4 and the alpha argument are not referenced in the visible
// source (scaling uses ATTR_ALPHA, presumably from the included attrs header).
191R"==(#ifdef NN )==""\n"
192R"==(__attribute__((intel_reqd_sub_group_size(16))) )==""\n"
193R"==(kernel void )==""\n"
194R"==(gen9_gemm_nocopy_superkernel_f32(global int *plan, int threads, )==""\n"
195R"==(global float *A0, global float *B0, global float *C0, long offsetA, )==""\n"
196R"==(long offsetB, long offsetC, int lda, int ldb, int ldc, int m, int n, )==""\n"
197R"==(int k, global float *alpha, float beta, int last_k_block, )==""\n"
198R"==(float eltwise_alpha, float eltwise_beta, float eltwise_scale) { )==""\n"
199R"==(SUPERKERNEL_PROLOGUE )==""\n"
200R"==(float2 a[4]; )==""\n"
201R"==(float4 b; )==""\n"
202R"==(float4 c[8]; )==""\n"
203R"==(int irem = m - i0; )==""\n"
204R"==(int jrem = n - j0; )==""\n"
205R"==(if (irem < 0) irem = 0; )==""\n"
206R"==(if (jrem < 0) jrem = 0; )==""\n"
207R"==(global float *A = A0 + i0; )==""\n"
208R"==(global float *B = B0 + j0 * ldb; )==""\n"
209R"==(global float *C = C0 + i0 + j0 * ldc; )==""\n"
210R"==(global float *A_cols[4] = {A, A + lda, A + 2 * lda, A + 3 * lda}; )==""\n"
211R"==(int ldax4 = lda << 2; )==""\n"
212R"==(int ldbx4 = ldb << 2; )==""\n"
213R"==(if (kid0 == 0) { )==""\n"
214R"==(for (int z = 0; z < 8; z++) )==""\n"
215R"==(c[z] = 0.f; )==""\n"
216R"==(for (int h = 0; h < (k >> 2); h++) { )==""\n"
217R"==(for (int j = 0; j < 4; j++) { )==""\n"
218R"==(a[j] = as_float2( )==""\n"
219R"==(intel_sub_group_block_read2((global uint *)A_cols[j])); )==""\n"
220R"==(A_cols[j] += ldax4; )==""\n"
221R"==(} )==""\n"
222R"==(b = vload4(0, B); )==""\n"
223R"==(B += 4; )==""\n"
224R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
225R"==(FMA_I_LOOP_32_ROW(1); )==""\n"
226R"==(FMA_I_LOOP_32_ROW(2); )==""\n"
227R"==(FMA_I_LOOP_32_ROW(3); )==""\n"
228R"==(} )==""\n"
229R"==(int krem = k & 3; )==""\n"
230R"==(if (krem > 0) { )==""\n"
231R"==(for (int j = 0; j < 4; j++) )==""\n"
232R"==(a[j] = as_float2( )==""\n"
233R"==(intel_sub_group_block_read2((global uint *)A_cols[j])); )==""\n"
234R"==(b = vload4(0, B); )==""\n"
235R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
236R"==(if (krem > 1) FMA_I_LOOP_32_ROW(1); )==""\n"
237R"==(if (krem > 2) FMA_I_LOOP_32_ROW(2); )==""\n"
238R"==(} )==""\n"
239R"==(} else { )==""\n"
240R"==(if (irem > 16) irem = 16; )==""\n"
241R"==(for (int z = 0; z < 4; z++) )==""\n"
242R"==(c[z] = 0.f; )==""\n"
243R"==(for (int h = 0; h < (k >> 2); h++) { )==""\n"
244R"==(for (int j = 0; j < 4; j++) { )==""\n"
245R"==(a[j].s0 = as_float( )==""\n"
246R"==(intel_sub_group_block_read((global uint *)A_cols[j])); )==""\n"
247R"==(A_cols[j] += ldax4; )==""\n"
248R"==(} )==""\n"
249R"==(b = vload4(0, B); )==""\n"
250R"==(B += 4; )==""\n"
251R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
252R"==(FMA_I_LOOP_16_ROW(1); )==""\n"
253R"==(FMA_I_LOOP_16_ROW(2); )==""\n"
254R"==(FMA_I_LOOP_16_ROW(3); )==""\n"
255R"==(} )==""\n"
256R"==(int krem = k & 3; )==""\n"
257R"==(if (krem > 0) { )==""\n"
258R"==(for (int j = 0; j < 4; j++) )==""\n"
259R"==(a[j].s0 = as_float( )==""\n"
260R"==(intel_sub_group_block_read((global uint *)A_cols[j])); )==""\n"
261R"==(b = vload4(0, B); )==""\n"
262R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
263R"==(if (krem > 1) FMA_I_LOOP_16_ROW(1); )==""\n"
264R"==(if (krem > 2) FMA_I_LOOP_16_ROW(2); )==""\n"
265R"==(} )==""\n"
266R"==(} )==""\n"
267R"==(if (beta == 0) )==""\n"
268R"==(UPDATE_C_32_ROW(1); )==""\n"
269R"==(else )==""\n"
270R"==(UPDATE_C_32_ROW(0); )==""\n"
271R"==(SUPERKERNEL_EPILOGUE )==""\n"
272R"==(} )==""\n"
273R"==(#endif )==""\n"
// NT variant of the kernel, same signature and sub-group size as the NN one.
// Differences visible in the source: B is addressed as B0 + j0 and read one
// value per lane with intel_sub_group_block_read from two row pointers
// (B_rows) that each advance by 2 * ldb; A uses two column pointers advancing
// by 2 * lda. k is consumed two steps at a time with a remainder of k & 1, and
// the remainder path only loads a[0] and b[0]. The 32-row / 16-row split on
// kid0 and the final C update match the NN variant.
274R"==(#ifdef NT )==""\n"
275R"==(__attribute__((intel_reqd_sub_group_size(16))) )==""\n"
276R"==(kernel void )==""\n"
277R"==(gen9_gemm_nocopy_superkernel_f32(global int *plan, int threads, )==""\n"
278R"==(global float *A0, global float *B0, global float *C0, long offsetA, )==""\n"
279R"==(long offsetB, long offsetC, int lda, int ldb, int ldc, int m, int n, )==""\n"
280R"==(int k, global float *alpha, float beta, int last_k_block, )==""\n"
281R"==(float eltwise_alpha, float eltwise_beta, float eltwise_scale) { )==""\n"
282R"==(SUPERKERNEL_PROLOGUE )==""\n"
283R"==(float2 a[2]; )==""\n"
284R"==(float b[2]; )==""\n"
285R"==(float4 c[8]; )==""\n"
286R"==(int irem = m - i0; )==""\n"
287R"==(int jrem = n - j0; )==""\n"
288R"==(if (irem < 0) irem = 0; )==""\n"
289R"==(if (jrem < 0) jrem = 0; )==""\n"
290R"==(global float *A = A0 + i0; )==""\n"
291R"==(global float *B = B0 + j0; )==""\n"
292R"==(global float *C = C0 + i0 + j0 * ldc; )==""\n"
293R"==(global float *A_cols[2] = {A, A + lda}; )==""\n"
294R"==(global float *B_rows[2] = {B, B + ldb}; )==""\n"
295R"==(int ldax2 = lda << 1; )==""\n"
296R"==(int ldbx2 = ldb << 1; )==""\n"
297R"==(if (kid0 == 0) { )==""\n"
298R"==(for (int z = 0; z < 8; z++) )==""\n"
299R"==(c[z] = 0.f; )==""\n"
300R"==(for (int h = 0; h < (k >> 1); h++) { )==""\n"
301R"==(for (int j = 0; j < 2; j++) { )==""\n"
302R"==(a[j] = as_float2( )==""\n"
303R"==(intel_sub_group_block_read2((global uint *)A_cols[j])); )==""\n"
304R"==(A_cols[j] += ldax2; )==""\n"
305R"==(} )==""\n"
306R"==(for (int i = 0; i < 2; i++) { )==""\n"
307R"==(b[i] = as_float( )==""\n"
308R"==(intel_sub_group_block_read((global uint *)B_rows[i])); )==""\n"
309R"==(B_rows[i] += ldbx2; )==""\n"
310R"==(} )==""\n"
311R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
312R"==(FMA_I_LOOP_32_ROW(1); )==""\n"
313R"==(} )==""\n"
314R"==(int krem = k & 1; )==""\n"
315R"==(if (krem > 0) { )==""\n"
316R"==(a[0] = as_float2( )==""\n"
317R"==(intel_sub_group_block_read2((global uint *)A_cols[0])); )==""\n"
318R"==(b[0] = as_float( )==""\n"
319R"==(intel_sub_group_block_read((global uint *)B_rows[0])); )==""\n"
320R"==(FMA_I_LOOP_32_ROW(0); )==""\n"
321R"==(} )==""\n"
322R"==(} else { )==""\n"
323R"==(if (irem > 16) irem = 16; )==""\n"
324R"==(for (int z = 0; z < 4; z++) )==""\n"
325R"==(c[z] = 0.f; )==""\n"
326R"==(for (int h = 0; h < (k >> 1); h++) { )==""\n"
327R"==(for (int j = 0; j < 2; j++) { )==""\n"
328R"==(a[j].s0 = as_float( )==""\n"
329R"==(intel_sub_group_block_read((global uint *)A_cols[j])); )==""\n"
330R"==(A_cols[j] += ldax2; )==""\n"
331R"==(} )==""\n"
332R"==(for (int i = 0; i < 2; i++) { )==""\n"
333R"==(b[i] = as_float( )==""\n"
334R"==(intel_sub_group_block_read((global uint *)B_rows[i])); )==""\n"
335R"==(B_rows[i] += ldbx2; )==""\n"
336R"==(} )==""\n"
337R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
338R"==(FMA_I_LOOP_16_ROW(1); )==""\n"
339R"==(} )==""\n"
340R"==(int krem = k & 1; )==""\n"
341R"==(if (krem > 0) { )==""\n"
342R"==(a[0].s0 = as_float( )==""\n"
343R"==(intel_sub_group_block_read((global uint *)A_cols[0])); )==""\n"
344R"==(b[0] = as_float( )==""\n"
345R"==(intel_sub_group_block_read((global uint *)B_rows[0])); )==""\n"
346R"==(FMA_I_LOOP_16_ROW(0); )==""\n"
347R"==(} )==""\n"
348R"==(} )==""\n"
349R"==(if (beta == 0) )==""\n"
350R"==(UPDATE_C_32_ROW(1); )==""\n"
351R"==(else )==""\n"
352R"==(UPDATE_C_32_ROW(0); )==""\n"
353R"==(SUPERKERNEL_EPILOGUE )==""\n"
354R"==(} )==""\n"
355R"==(#endif )==""\n"
// Final (empty) raw string piece terminates the concatenated kernel source.
356R"==()==";
// The four closing braces below end namespaces ocl, gpu, impl and dnnl.
357}
358}
359}
360}