namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// Source text of the OpenCL C kernel `gen9_gemm_compute`, held as one
// NUL-terminated string. Every kernel source line is a raw-string piece
// followed by a "\n" piece; the compiler concatenates adjacent literals, so
// the // comments interleaved between the pieces below are C++ comments and
// are NOT part of the string.
//
// NOTE(review): the one-piece-per-line layout (and the license URL on the
// "http:" line being cut off where a `//` would start) suggests this file is
// generated from a .cl source by a comment-stripping script. If so, make
// changes in the .cl source rather than here.
//
// The kernel text uses, but does not define, DT_F32 / DT_F16, UNROLL_M,
// UNROLL_N, DATA_T, DST_DATA_T, DATA_ZERO, DST_TO_REF, REF_TO_DST,
// WITH_ELTWISE, BETA_ZERO (tested with #ifdef) and fwd_eltwise; they must
// come from the two included headers or from the build options.
const char *gen9_gemm_compute_kernel =
// ---- License header of the kernel source (an OpenCL comment). ----
R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http: )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
// ---- Headers included by the kernel; whatever compiles this source must
// be able to resolve them. ----
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// ---- GRX: subgroup width; the kernel requests it through
// intel_reqd_sub_group_size(GRX) and CALC shuffles from lanes 0 .. 7. ----
R"==(#define GRX 8 )==""\n"
// ---- Per-lane vector types (f32). FLOATX/SIZEX: the A vector one lane
// loads per k step, SIZEX = ceil(UNROLL_M / GRX) capped at 4. FLOATY/SIZEY:
// the same for B, derived from UNROLL_N. ----
R"==(#if DT_F32 == 1 )==""\n"
R"==(#if UNROLL_M <= 1 * GRX )==""\n"
R"==(#define FLOATX float )==""\n"
R"==(#define SIZEX 1 )==""\n"
R"==(#elif UNROLL_M <= 2 * GRX )==""\n"
R"==(#define FLOATX float2 )==""\n"
R"==(#define SIZEX 2 )==""\n"
R"==(#elif UNROLL_M <= 3 * GRX )==""\n"
R"==(#define FLOATX float3 )==""\n"
R"==(#define SIZEX 3 )==""\n"
R"==(#else )==""\n"
R"==(#define FLOATX float4 )==""\n"
R"==(#define SIZEX 4 )==""\n"
R"==(#endif )==""\n"
R"==(#if UNROLL_N <= 1 * GRX )==""\n"
R"==(#define FLOATY float )==""\n"
R"==(#define SIZEY 1 )==""\n"
R"==(#elif UNROLL_N <= 2 * GRX )==""\n"
R"==(#define FLOATY float2 )==""\n"
R"==(#define SIZEY 2 )==""\n"
R"==(#elif UNROLL_N <= 3 * GRX )==""\n"
R"==(#define FLOATY float3 )==""\n"
R"==(#define SIZEY 3 )==""\n"
R"==(#else )==""\n"
R"==(#define FLOATY float4 )==""\n"
R"==(#define SIZEY 4 )==""\n"
R"==(#endif )==""\n"
// f32 shuffles map directly onto the Intel subgroup shuffles (the up/down
// forms pass X as both inputs). SHUFFLE_UP / SHUFFLE_DOWN are not referenced
// by the kernel body below.
R"==(#define SHUFFLE(X, Y) intel_sub_group_shuffle(X, Y) )==""\n"
R"==(#define SHUFFLE_DOWN(X, Y) intel_sub_group_shuffle_down(X, X, Y) )==""\n"
R"==(#define SHUFFLE_UP(X, Y) intel_sub_group_shuffle_up(X, X, Y) )==""\n"
// ---- Per-lane vector types (f16): same scheme with half vectors, except
// that sizes above 4 * GRX select an 8-wide half8 vector. ----
R"==(#elif DT_F16 == 1 )==""\n"
R"==(#if UNROLL_M <= 1 * GRX )==""\n"
R"==(#define FLOATX half )==""\n"
R"==(#define SIZEX 1 )==""\n"
R"==(#elif UNROLL_M <= 2 * GRX )==""\n"
R"==(#define FLOATX half2 )==""\n"
R"==(#define SIZEX 2 )==""\n"
R"==(#elif UNROLL_M <= 3 * GRX )==""\n"
R"==(#define FLOATX half3 )==""\n"
R"==(#define SIZEX 3 )==""\n"
R"==(#elif UNROLL_M <= 4 * GRX )==""\n"
R"==(#define FLOATX half4 )==""\n"
R"==(#define SIZEX 4 )==""\n"
R"==(#else )==""\n"
R"==(#define FLOATX half8 )==""\n"
R"==(#define SIZEX 8 )==""\n"
R"==(#endif )==""\n"
R"==(#if UNROLL_N <= 1 * GRX )==""\n"
R"==(#define FLOATY half )==""\n"
R"==(#define SIZEY 1 )==""\n"
R"==(#elif UNROLL_N <= 2 * GRX )==""\n"
R"==(#define FLOATY half2 )==""\n"
R"==(#define SIZEY 2 )==""\n"
R"==(#elif UNROLL_N <= 3 * GRX )==""\n"
R"==(#define FLOATY half3 )==""\n"
R"==(#define SIZEY 3 )==""\n"
R"==(#elif UNROLL_N <= 4 * GRX )==""\n"
R"==(#define FLOATY half4 )==""\n"
R"==(#define SIZEY 4 )==""\n"
R"==(#else )==""\n"
R"==(#define FLOATY half8 )==""\n"
R"==(#define SIZEY 8 )==""\n"
R"==(#endif )==""\n"
// f16 shuffles: the half vector is reinterpreted as a float vector of the
// same byte size, shuffled, and reinterpreted back.
// NOTE(review): each #else branch assumes 8 halves (as_float4). A half or
// half3 operand (SIZEY or SIZEX of 1 or 3) does not match float4 in size,
// so those f16 configurations look unsupported here - confirm the host code
// never selects them.
R"==(#if SIZEY == 2 )==""\n"
R"==(#define SHUFFLE(X, Y) as_half2(intel_sub_group_shuffle(as_float(X), Y)) )==""\n"
R"==(#elif SIZEY == 4 )==""\n"
R"==(#define SHUFFLE(X, Y) as_half4(intel_sub_group_shuffle(as_float2(X), Y)) )==""\n"
R"==(#else )==""\n"
R"==(#define SHUFFLE(X, Y) as_half8(intel_sub_group_shuffle(as_float4(X), Y)) )==""\n"
R"==(#endif )==""\n"
R"==(#if SIZEX == 2 )==""\n"
R"==(#define SHUFFLE_UP(X, Y) \ )==""\n"
R"==(as_half2(intel_sub_group_shuffle_up(as_float(X), as_float(X), Y)) )==""\n"
R"==(#define SHUFFLE_DOWN(X, Y) \ )==""\n"
R"==(as_half2(intel_sub_group_shuffle_down(as_float(X), as_float(X), Y)) )==""\n"
R"==(#elif SIZEX == 4 )==""\n"
R"==(#define SHUFFLE_UP(X, Y) \ )==""\n"
R"==(as_half4(intel_sub_group_shuffle_up(as_float2(X), as_float2(X), Y)) )==""\n"
R"==(#define SHUFFLE_DOWN(X, Y) \ )==""\n"
R"==(as_half4(intel_sub_group_shuffle_down(as_float2(X), as_float2(X), Y)) )==""\n"
R"==(#else )==""\n"
R"==(#define SHUFFLE_UP(X, Y) \ )==""\n"
R"==(as_half8(intel_sub_group_shuffle_up(as_float4(X), as_float4(X), Y)) )==""\n"
R"==(#define SHUFFLE_DOWN(X, Y) \ )==""\n"
R"==(as_half8(intel_sub_group_shuffle_down(as_float4(X), as_float4(X), Y)) )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
// AS_FLOATX(X, Y) / AS_FLOATY(X, Y): load one FLOATX / FLOATY vector from
// the global pointer X at element offset Y.
// NOTE(review): a vector dereference presumably needs the address aligned to
// the vector size; the packing layout is assumed to guarantee this.
R"==(#define AS_FLOATX(X, Y) *((__global FLOATX *)(X + Y)) )==""\n"
R"==(#define AS_FLOATY(X, Y) *((__global FLOATY *)(X + Y)) )==""\n"
// CALC_X(x, a, b, R0, R1, R2, R3): bb = lane x's B vector (SHUFFLE
// broadcasts it to every lane), then R<i>##x = mad(A component i, bb,
// R<i>##x) for each A component this lane holds. `bb` must be declared by
// the caller.
// NOTE(review): the widest branch only uses a.s0 .. a.s3, so with f16 and
// UNROLL_M > 4 * GRX (FLOATX = half8) a.s4 .. a.s7 would be ignored.
R"==(#if UNROLL_M <= 1 * GRX )==""\n"
R"==(#define CALC_X(x, a, b, R0, R1, R2, R3) \ )==""\n"
R"==(bb = SHUFFLE(b, x); \ )==""\n"
R"==(R0##x = mad(a, bb, R0##x); )==""\n"
R"==(#elif UNROLL_M <= 2 * GRX )==""\n"
R"==(#define CALC_X(x, a, b, R0, R1, R2, R3) \ )==""\n"
R"==(bb = SHUFFLE(b, x); \ )==""\n"
R"==(R0##x = mad(a.s0, bb, R0##x); \ )==""\n"
R"==(R1##x = mad(a.s1, bb, R1##x); )==""\n"
R"==(#elif UNROLL_M <= 3 * GRX )==""\n"
R"==(#define CALC_X(x, a, b, R0, R1, R2, R3) \ )==""\n"
R"==(bb = SHUFFLE(b, x); \ )==""\n"
R"==(R0##x = mad(a.s0, bb, R0##x); \ )==""\n"
R"==(R1##x = mad(a.s1, bb, R1##x); \ )==""\n"
R"==(R2##x = mad(a.s2, bb, R2##x); )==""\n"
R"==(#else )==""\n"
R"==(#define CALC_X(x, a, b, R0, R1, R2, R3) \ )==""\n"
R"==(bb = SHUFFLE(b, x); \ )==""\n"
R"==(R0##x = mad(a.s0, bb, R0##x); \ )==""\n"
R"==(R1##x = mad(a.s1, bb, R1##x); \ )==""\n"
R"==(R2##x = mad(a.s2, bb, R2##x); \ )==""\n"
R"==(R3##x = mad(a.s3, bb, R3##x); )==""\n"
R"==(#endif )==""\n"
// CALC: CALC_X for every source lane 0 .. 7, so each k step accumulates this
// lane's A values times the B values of all lanes.
R"==(#define CALC(a, b, R0, R1, R2, R3) \ )==""\n"
R"==(CALC_X(0, a, b, R0, R1, R2, R3); \ )==""\n"
R"==(CALC_X(1, a, b, R0, R1, R2, R3); \ )==""\n"
R"==(CALC_X(2, a, b, R0, R1, R2, R3); \ )==""\n"
R"==(CALC_X(3, a, b, R0, R1, R2, R3); \ )==""\n"
R"==(CALC_X(4, a, b, R0, R1, R2, R3); \ )==""\n"
R"==(CALC_X(5, a, b, R0, R1, R2, R3); \ )==""\n"
R"==(CALC_X(6, a, b, R0, R1, R2, R3); \ )==""\n"
R"==(CALC_X(7, a, b, R0, R1, R2, R3); )==""\n"
// INIT_C(n): declares eight FLOATY accumulators cc<n>0 .. cc<n>7, all set
// to DATA_ZERO.
R"==(#define INIT_C(n) \ )==""\n"
R"==(FLOATY cc##n##0 = DATA_ZERO, cc##n##1 = DATA_ZERO; \ )==""\n"
R"==(FLOATY cc##n##2 = DATA_ZERO, cc##n##3 = DATA_ZERO; \ )==""\n"
R"==(FLOATY cc##n##4 = DATA_ZERO, cc##n##5 = DATA_ZERO; \ )==""\n"
R"==(FLOATY cc##n##6 = DATA_ZERO, cc##n##7 = DATA_ZERO; )==""\n"
// POST_OP(val): with WITH_ELTWISE == 1, replaces val by fwd_eltwise(val,
// eltwise_alpha, eltwise_beta, eltwise_scale) when last_k_block is nonzero;
// otherwise it expands to nothing.
R"==(#if WITH_ELTWISE == 1 )==""\n"
R"==(#define POST_OP(val) \ )==""\n"
R"==(do { \ )==""\n"
R"==(if (last_k_block) \ )==""\n"
R"==(val = fwd_eltwise( \ )==""\n"
R"==(val, eltwise_alpha, eltwise_beta, eltwise_scale); \ )==""\n"
R"==(} while (0) )==""\n"
R"==(#else )==""\n"
R"==(#define POST_OP(val) )==""\n"
R"==(#endif )==""\n"
// UPDATE(c, acc): val = acc when BETA_ZERO is defined, otherwise
// DST_TO_REF(c) + acc; then POST_OP(val) and c = REF_TO_DST(val).
R"==(#ifdef BETA_ZERO )==""\n"
R"==(#define UPDATE(c, acc) \ )==""\n"
R"==(do { \ )==""\n"
R"==(DATA_T val = acc; \ )==""\n"
R"==(POST_OP(val); \ )==""\n"
R"==(c = REF_TO_DST(val); \ )==""\n"
R"==(} while (0) )==""\n"
R"==(#else )==""\n"
R"==(#define UPDATE(c, acc) \ )==""\n"
R"==(do { \ )==""\n"
R"==(DATA_T val = DST_TO_REF(c) + acc; \ )==""\n"
R"==(POST_OP(val); \ )==""\n"
R"==(c = REF_TO_DST(val); \ )==""\n"
R"==(} while (0) )==""\n"
R"==(#endif )==""\n"
// UPDATE_YY(X, Y, R0, R1, R2, R3): if column Y < n, apply UPDATE to
// c[offsetC + i] with R<i> for each of this lane's rows i with i < m, then
// advance offsetC by ldc. X is not used. Relies on m, n, c, offsetC and ldc
// from the kernel scope. The last branch writes at most rows 0 .. 3.
R"==(#if SIZEX == 1 )==""\n"
R"==(#define UPDATE_YY(X, Y, R0, R1, R2, R3) \ )==""\n"
R"==(if (n > (Y)) { \ )==""\n"
R"==(if ((m > 0)) { UPDATE(c[offsetC + 0], R0); } \ )==""\n"
R"==(offsetC += ldc; \ )==""\n"
R"==(} )==""\n"
R"==(#elif SIZEX == 2 )==""\n"
R"==(#define UPDATE_YY(X, Y, R0, R1, R2, R3) \ )==""\n"
R"==(if (n > (Y)) { \ )==""\n"
R"==(if ((m > 0)) { UPDATE(c[offsetC + 0], R0); } \ )==""\n"
R"==(if ((m > 1)) { UPDATE(c[offsetC + 1], R1); } \ )==""\n"
R"==(offsetC += ldc; \ )==""\n"
R"==(} )==""\n"
R"==(#elif SIZEX == 3 )==""\n"
R"==(#define UPDATE_YY(X, Y, R0, R1, R2, R3) \ )==""\n"
R"==(if (n > (Y)) { \ )==""\n"
R"==(if ((m > 0)) { UPDATE(c[offsetC + 0], R0); } \ )==""\n"
R"==(if ((m > 1)) { UPDATE(c[offsetC + 1], R1); } \ )==""\n"
R"==(if ((m > 2)) { UPDATE(c[offsetC + 2], R2); } \ )==""\n"
R"==(offsetC += ldc; \ )==""\n"
R"==(} )==""\n"
R"==(#else )==""\n"
R"==(#define UPDATE_YY(X, Y, R0, R1, R2, R3) \ )==""\n"
R"==(if (n > (Y)) { \ )==""\n"
R"==(if ((m > 0)) { UPDATE(c[offsetC + 0], R0); } \ )==""\n"
R"==(if ((m > 1)) { UPDATE(c[offsetC + 1], R1); } \ )==""\n"
R"==(if ((m > 2)) { UPDATE(c[offsetC + 2], R2); } \ )==""\n"
R"==(if ((m > 3)) { UPDATE(c[offsetC + 3], R3); } \ )==""\n"
R"==(offsetC += ldc; \ )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
// UPDATE_Y(X, R0, R1, R2, R3): stores the SIZEY columns held in R<i>##X,
// i.e. columns X * SIZEY + 0 .. X * SIZEY + SIZEY - 1 (components .s0, .s1,
// ... of each accumulator). The #else branch writes .s0 .. .s3 only.
R"==(#if SIZEY == 1 )==""\n"
R"==(#define UPDATE_Y(X, R0, R1, R2, R3) \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 0, R0##X, R1##X, R2##X, R3##X); )==""\n"
R"==(#elif SIZEY == 2 )==""\n"
R"==(#define UPDATE_Y(X, R0, R1, R2, R3) \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 0, R0##X.s0, R1##X.s0, R2##X.s0, R3##X.s0); \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 1, R0##X.s1, R1##X.s1, R2##X.s1, R3##X.s1); )==""\n"
R"==(#elif SIZEY == 3 )==""\n"
R"==(#define UPDATE_Y(X, R0, R1, R2, R3) \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 0, R0##X.s0, R1##X.s0, R2##X.s0, R3##X.s0); \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 1, R0##X.s1, R1##X.s1, R2##X.s1, R3##X.s1); \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 2, R0##X.s2, R1##X.s2, R2##X.s2, R3##X.s2); )==""\n"
R"==(#else )==""\n"
R"==(#define UPDATE_Y(X, R0, R1, R2, R3) \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 0, R0##X.s0, R1##X.s0, R2##X.s0, R3##X.s0); \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 1, R0##X.s1, R1##X.s1, R2##X.s1, R3##X.s1); \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 2, R0##X.s2, R1##X.s2, R2##X.s2, R3##X.s2); \ )==""\n"
R"==(UPDATE_YY(X, X *SIZEY + 3, R0##X.s3, R1##X.s3, R2##X.s3, R3##X.s3); )==""\n"
R"==(#endif )==""\n"
// ---- Kernel entry point. ----
// idx (group id in dim 0) selects an UNROLL_M-row tile; idy (group id times
// enqueued local size plus local id, dim 1) selects an UNROLL_N-column tile;
// lid (local id in dim 0) selects this lane's UNROLL_M / GRX rows (assumes
// the dim-0 local size equals GRX - TODO confirm against the host code).
// m and n are reduced to what remains for the tile and clamped to
// UNROLL_M / UNROLL_N, then m is reduced by the lane's row start, so the m/n
// guards in UPDATE_YY handle partial edge tiles. A and B are both read from
// `base` at offsetA / offsetB.
// NOTE(review): offsetA and offsetB are int, but UNROLL_M * k * idx and
// UNROLL_N * k * idy are evaluated in long (k is long) and then narrowed;
// very large problems could truncate these offsets.
R"==(__attribute__((intel_reqd_sub_group_size(GRX))) __kernel void gen9_gemm_compute( )==""\n"
R"==(long m, long n, long k, __global DATA_T *base, int offsetA, int offsetB, )==""\n"
R"==(__global DST_DATA_T *c, long offsetC, long ldc, int last_k_block, )==""\n"
R"==(float eltwise_alpha, float eltwise_beta, float eltwise_scale) { )==""\n"
R"==(int idx, idy, lid; )==""\n"
R"==(idx = get_group_id(0); )==""\n"
R"==(idy = get_group_id(1) * get_enqueued_local_size(1) + get_local_id(1); )==""\n"
R"==(lid = get_local_id(0); )==""\n"
R"==(m -= UNROLL_M * idx; )==""\n"
R"==(if (m > UNROLL_M) m = UNROLL_M; )==""\n"
R"==(n -= UNROLL_N * idy; )==""\n"
R"==(if (n > UNROLL_N) n = UNROLL_N; )==""\n"
R"==(m -= UNROLL_M * lid / GRX; )==""\n"
R"==(offsetA += UNROLL_M * k * idx + UNROLL_M * lid / GRX; )==""\n"
R"==(offsetB += UNROLL_N * k * idy + UNROLL_N * lid / GRX; )==""\n"
R"==(offsetC += UNROLL_M * idx + UNROLL_N * ldc * idy + UNROLL_M * lid / GRX; )==""\n"
// Accumulators. NOTE(review): only the cc0*, cc1*, cc2* and cc3* groups are
// passed to CALC / UPDATE_Y below; INIT_C(4) .. INIT_C(7) declare variables
// that are never read.
R"==(INIT_C(0); )==""\n"
R"==(INIT_C(1); )==""\n"
R"==(INIT_C(2); )==""\n"
R"==(INIT_C(3); )==""\n"
R"==(INIT_C(4); )==""\n"
R"==(INIT_C(5); )==""\n"
R"==(INIT_C(6); )==""\n"
R"==(INIT_C(7); )==""\n"
// Main loop: k steps. Each step applies CALC to the current A/B vectors and
// then loads the next ones (UNROLL_M / UNROLL_N elements further on).
// NOTE(review): the load at the end of the last iteration (and the initial
// load when k == 0) reads one block beyond the k blocks that are used;
// presumably the packed buffer is padded for this - confirm.
R"==(FLOATX blockA = AS_FLOATX(base, offsetA); )==""\n"
R"==(offsetA += UNROLL_M; )==""\n"
R"==(FLOATY blockB = AS_FLOATY(base, offsetB); )==""\n"
R"==(offsetB += UNROLL_N; )==""\n"
R"==(for (int l = k; l > 0; l--) { )==""\n"
R"==(FLOATY bb; )==""\n"
R"==(CALC(blockA, blockB, cc0, cc1, cc2, cc3); )==""\n"
R"==(blockB = AS_FLOATY(base, offsetB); )==""\n"
R"==(offsetB += UNROLL_N; )==""\n"
R"==(blockA = AS_FLOATX(base, offsetA); )==""\n"
R"==(offsetA += UNROLL_M; )==""\n"
R"==(} )==""\n"
// Store the tile: UPDATE_Y(0) always, UPDATE_Y(X) for X >= 1 only when
// UNROLL_N >= (X + 1) * SIZEY; per-column n guards live in UPDATE_YY.
R"==(UPDATE_Y(0, cc0, cc1, cc2, cc3); )==""\n"
R"==(#if UNROLL_N >= 2 * SIZEY )==""\n"
R"==(UPDATE_Y(1, cc0, cc1, cc2, cc3); )==""\n"
R"==(#endif )==""\n"
R"==(#if UNROLL_N >= 3 * SIZEY )==""\n"
R"==(UPDATE_Y(2, cc0, cc1, cc2, cc3); )==""\n"
R"==(#endif )==""\n"
R"==(#if UNROLL_N >= 4 * SIZEY )==""\n"
R"==(UPDATE_Y(3, cc0, cc1, cc2, cc3); )==""\n"
R"==(#endif )==""\n"
R"==(#if UNROLL_N >= 5 * SIZEY )==""\n"
R"==(UPDATE_Y(4, cc0, cc1, cc2, cc3); )==""\n"
R"==(#endif )==""\n"
R"==(#if UNROLL_N >= 6 * SIZEY )==""\n"
R"==(UPDATE_Y(5, cc0, cc1, cc2, cc3); )==""\n"
R"==(#endif )==""\n"
R"==(#if UNROLL_N >= 7 * SIZEY )==""\n"
R"==(UPDATE_Y(6, cc0, cc1, cc2, cc3); )==""\n"
R"==(#endif )==""\n"
R"==(#if UNROLL_N >= 8 * SIZEY )==""\n"
R"==(UPDATE_Y(7, cc0, cc1, cc2, cc3); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==()==";
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl