// OpenCL C source text of the gen9_gemm_compute_x8x8s32 kernel, stored as one
// C string. The string is assembled from adjacent raw string literals, one per
// kernel source line, each followed by a newline literal.
// NOTE(review): this looks like generated output that embeds a .cl file -
// presumably changes belong in the original kernel source; confirm.
// NOTE(review): the license line that reads only http: appears truncated -
// presumably the embedding step dropped everything after the double slash;
// verify against the original kernel source.
1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
5const char *gen9_gemm_nocopy_x8x8s32_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
// These include directives are part of the kernel text, so the C++
// preprocessor never sees them.
// NOTE(review): presumably they supply fwd_eltwise and the ATTR_A0 and ATTR_B0
// macros used further down; confirm against those headers.
20R"==(#include "gpu/ocl/gemm/ocl_gemm_attrs.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
22R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// GRX is forced to 8 here. It is the required sub-group size of the kernel and
// the divisor used when mapping the local id lid to a row offset.
23R"==(#undef GRX )==""\n"
24R"==(#define GRX 8 )==""\n"
// Element types, selected by one of the defines S8S8, U8S8, S8U8 or U8U8.
// FLOATA and FLOATB are char or uchar. SHUFFLE reinterprets a 4-element B
// vector as one int, passes it through intel_sub_group_shuffle with lane
// index Y, and reinterprets the result with the signedness of FLOATB.
25R"==(#if defined(S8S8) )==""\n"
26R"==(#define FLOATA char )==""\n"
27R"==(#define FLOATA2 char2 )==""\n"
28R"==(#define FLOATA4 char4 )==""\n"
29R"==(#define FLOATB char )==""\n"
30R"==(#define FLOATB4 char4 )==""\n"
31R"==(#define SHUFFLE(X, Y) as_char4(intel_sub_group_shuffle(as_int(X), Y)) )==""\n"
32R"==(#endif )==""\n"
33R"==(#if defined(U8S8) )==""\n"
34R"==(#define FLOATA uchar )==""\n"
35R"==(#define FLOATA2 uchar2 )==""\n"
36R"==(#define FLOATA4 uchar4 )==""\n"
37R"==(#define FLOATB char )==""\n"
38R"==(#define FLOATB4 char4 )==""\n"
39R"==(#define SHUFFLE(X, Y) as_char4(intel_sub_group_shuffle(as_int(X), Y)) )==""\n"
40R"==(#endif )==""\n"
41R"==(#if defined(S8U8) )==""\n"
42R"==(#define FLOATA char )==""\n"
43R"==(#define FLOATA2 char2 )==""\n"
44R"==(#define FLOATA4 char4 )==""\n"
45R"==(#define FLOATB uchar )==""\n"
46R"==(#define FLOATB4 uchar4 )==""\n"
47R"==(#define SHUFFLE(X, Y) as_uchar4(intel_sub_group_shuffle(as_int(X), Y)) )==""\n"
48R"==(#endif )==""\n"
49R"==(#if defined(U8U8) )==""\n"
50R"==(#define FLOATA uchar )==""\n"
51R"==(#define FLOATA2 uchar2 )==""\n"
52R"==(#define FLOATA4 uchar4 )==""\n"
53R"==(#define FLOATB uchar )==""\n"
54R"==(#define FLOATB4 uchar4 )==""\n"
55R"==(#define SHUFFLE(X, Y) as_uchar4(intel_sub_group_shuffle(as_int(X), Y)) )==""\n"
56R"==(#endif )==""\n"
// Accumulators and the C matrix use int.
57R"==(#define FLOATC int )==""\n"
58R"==(#define FLOATC4 int4 )==""\n"
// POST_OP(val): when WITH_ELTWISE is 1 and apply_eltwise is nonzero, val is
// replaced by fwd_eltwise(val, alpha, beta, scale); otherwise it expands to
// nothing.
59R"==(#if WITH_ELTWISE == 1 )==""\n"
60R"==(#define POST_OP(val) \ )==""\n"
61R"==(do { \ )==""\n"
62R"==(if (apply_eltwise) \ )==""\n"
63R"==(val = fwd_eltwise( \ )==""\n"
64R"==(val, eltwise_alpha, eltwise_beta, eltwise_scale); \ )==""\n"
65R"==(} while (0) )==""\n"
66R"==(#else )==""\n"
67R"==(#define POST_OP(val) )==""\n"
68R"==(#endif )==""\n"
// COMPUTE_C(X, Y, CO_IDX): computes (beta ? X : 0) + Y as a float, runs
// POST_OP on it, adds co[CO_IDX] when apply_co is set, and stores the result
// back into X through convert_int_sat_rte.
69R"==(#define COMPUTE_C(X, Y, CO_IDX) \ )==""\n"
70R"==(do { \ )==""\n"
71R"==(float val = (!beta ? 0 : (X)) + (Y); \ )==""\n"
72R"==(POST_OP(val); \ )==""\n"
73R"==((X) = convert_int_sat_rte(val + (apply_co ? co[CO_IDX] : 0)); \ )==""\n"
74R"==(} while (0) )==""\n"
// ADD_EACH(X, OFF): if column X + OFF is inside n, updates up to four rows
// c[0..3] of that column from accumulator lane s(OFF) plus the xa row term and
// the xb column term, then advances xb by one and c by ldc.
// The three variants differ only in how co is indexed:
//   FF      - co[0] for every element, co is never advanced
//   CC      - co[0] for every row, co advances once per column
//   default - co[0..3], one entry per row
75R"==(#ifdef FF )==""\n"
76R"==(#define ADD_EACH(X, OFF) \ )==""\n"
77R"==(do { \ )==""\n"
78R"==(if (n > X + OFF) { \ )==""\n"
79R"==(if (m > 0) \ )==""\n"
80R"==(COMPUTE_C(c[0], sc[X / 4 + 0].s##OFF + xa[0] + xb[0], 0); \ )==""\n"
81R"==(if (m > 1) \ )==""\n"
82R"==(COMPUTE_C(c[1], sc[X / 4 + 4].s##OFF + xa[1] + xb[0], 0); \ )==""\n"
83R"==(if (m > 2) \ )==""\n"
84R"==(COMPUTE_C(c[2], sc[X / 4 + 8].s##OFF + xa[2] + xb[0], 0); \ )==""\n"
85R"==(if (m > 3) \ )==""\n"
86R"==(COMPUTE_C(c[3], sc[X / 4 + 12].s##OFF + xa[3] + xb[0], 0); \ )==""\n"
87R"==(xb++; \ )==""\n"
88R"==(c += ldc; \ )==""\n"
89R"==(} \ )==""\n"
90R"==(} while (0) )==""\n"
91R"==(#elif defined CC )==""\n"
92R"==(#define ADD_EACH(X, OFF) \ )==""\n"
93R"==(do { \ )==""\n"
94R"==(if (n > X + OFF) { \ )==""\n"
95R"==(if (m > 0) \ )==""\n"
96R"==(COMPUTE_C(c[0], sc[X / 4 + 0].s##OFF + xa[0] + xb[0], 0); \ )==""\n"
97R"==(if (m > 1) \ )==""\n"
98R"==(COMPUTE_C(c[1], sc[X / 4 + 4].s##OFF + xa[1] + xb[0], 0); \ )==""\n"
99R"==(if (m > 2) \ )==""\n"
100R"==(COMPUTE_C(c[2], sc[X / 4 + 8].s##OFF + xa[2] + xb[0], 0); \ )==""\n"
101R"==(if (m > 3) \ )==""\n"
102R"==(COMPUTE_C(c[3], sc[X / 4 + 12].s##OFF + xa[3] + xb[0], 0); \ )==""\n"
103R"==(xb++; \ )==""\n"
104R"==(c += ldc; \ )==""\n"
105R"==(co++; \ )==""\n"
106R"==(} \ )==""\n"
107R"==(} while (0) )==""\n"
108R"==(#else )==""\n"
109R"==(#define ADD_EACH(X, OFF) \ )==""\n"
110R"==(do { \ )==""\n"
111R"==(if (n > X + OFF) { \ )==""\n"
112R"==(if (m > 0) \ )==""\n"
113R"==(COMPUTE_C(c[0], sc[X / 4 + 0].s##OFF + xa[0] + xb[0], 0); \ )==""\n"
114R"==(if (m > 1) \ )==""\n"
115R"==(COMPUTE_C(c[1], sc[X / 4 + 4].s##OFF + xa[1] + xb[0], 1); \ )==""\n"
116R"==(if (m > 2) \ )==""\n"
117R"==(COMPUTE_C(c[2], sc[X / 4 + 8].s##OFF + xa[2] + xb[0], 2); \ )==""\n"
118R"==(if (m > 3) \ )==""\n"
119R"==(COMPUTE_C(c[3], sc[X / 4 + 12].s##OFF + xa[3] + xb[0], 3); \ )==""\n"
120R"==(xb++; \ )==""\n"
121R"==(c += ldc; \ )==""\n"
122R"==(} \ )==""\n"
123R"==(} while (0) )==""\n"
124R"==(#endif )==""\n"
// ADD_SCALE(X): runs ADD_EACH for the four vector lanes s0..s3, that is for
// columns X through X + 3.
125R"==(#define ADD_SCALE(X) \ )==""\n"
126R"==(do { \ )==""\n"
127R"==(ADD_EACH(X, 0); \ )==""\n"
128R"==(ADD_EACH(X, 1); \ )==""\n"
129R"==(ADD_EACH(X, 2); \ )==""\n"
130R"==(ADD_EACH(X, 3); \ )==""\n"
131R"==(} while (0) )==""\n"
// ACCUMULATE_1(a, b): 4-element dot product with every operand cast to FLOATC
// before multiplying.
// ACCUMULATE(a, b0..b3): four such dot products packed into one FLOATC4.
132R"==(#define ACCUMULATE_1(a, b) \ )==""\n"
133R"==(((FLOATC)a.s0 * (FLOATC)b.s0) + ((FLOATC)a.s1 * (FLOATC)b.s1) \ )==""\n"
134R"==(+ ((FLOATC)a.s2 * (FLOATC)b.s2) + ((FLOATC)a.s3 * (FLOATC)b.s3) )==""\n"
135R"==(#define ACCUMULATE(a, b0, b1, b2, b3) \ )==""\n"
136R"==((FLOATC4)(ACCUMULATE_1(a, b0), ACCUMULATE_1(a, b1), ACCUMULATE_1(a, b2), \ )==""\n"
137R"==(ACCUMULATE_1(a, b3)) )==""\n"
// Per-work-group extents in M and N; they scale the group ids gdx and gdy when
// offsets into A, B, C and co are computed below.
// NOTE(review): UNROLL_M, UNROLL_N and UNROLL_K are not defined in this text -
// presumably supplied as build options; confirm.
138R"==(#define GROUPSIZE_M (6 * UNROLL_M) )==""\n"
139R"==(#define GROUPSIZE_N (4 * UNROLL_N) )==""\n"
// Kernel entry point. Takes global A, B, C and co buffers with element offsets
// and leading dimensions, the sizes m, n, k, a beta flag, the ao and bo
// pointers (not referenced in the body), two local scratch buffers sa and sb,
// and the eltwise post-op parameters.
140R"==(__attribute__((intel_reqd_sub_group_size(GRX))) kernel void )==""\n"
141R"==(gen9_gemm_compute_x8x8s32(global FLOATA *a, global FLOATB *b, global FLOATC *c, )==""\n"
142R"==(long offsetA, long offsetB, long offsetC, long lda, long ldb, long ldc, )==""\n"
143R"==(long m, long n, long k, int beta, global int *ao, global int *bo, )==""\n"
144R"==(global int *co, long offsetCO, int apply_co, local FLOATA *sa, )==""\n"
145R"==(local FLOATB *sb, int apply_eltwise, float eltwise_alpha, )==""\n"
146R"==(float eltwise_beta, float eltwise_scale) { )==""\n"
// kk is k rounded up to a multiple of UNROLL_K.
// NOTE(review): the mask form assumes UNROLL_K is a power of two; confirm.
147R"==(long kk = (k + UNROLL_K - 1) & ~(UNROLL_K - 1); )==""\n"
148R"==(long i, j, l, ll; )==""\n"
149R"==(global FLOATC *c_ori; )==""\n"
// Dimension 0 is the lane id; dimensions 1 and 2 index the M and N directions.
150R"==(int lid = get_local_id(0); )==""\n"
151R"==(int idx = get_local_id(1); )==""\n"
152R"==(int idy = get_local_id(2); )==""\n"
153R"==(long gdx = get_group_id(1); )==""\n"
154R"==(long gdy = get_group_id(2); )==""\n"
155R"==(long szx = get_local_size(1); )==""\n"
156R"==(long szy = get_local_size(2); )==""\n"
// Apply the base offsets, then point c at the first element this work-item
// writes: row from idx, gdx and lid, column from idy and gdy.
157R"==(a += offsetA; )==""\n"
158R"==(b += offsetB; )==""\n"
159R"==(c += offsetC + UNROLL_M * idx + GROUPSIZE_M * gdx + UNROLL_M * lid / GRX )==""\n"
160R"==(+ (UNROLL_N * idy + GROUPSIZE_N * gdy) * ldc; )==""\n"
// c_ori is assigned here but not read anywhere in this kernel text.
161R"==(c_ori = c; )==""\n"
// Position co: with RR it follows the row offset, with CC the column offset;
// with neither define only offsetCO is applied.
162R"==(if (apply_co) { )==""\n"
163R"==(co += offsetCO; )==""\n"
164R"==(#ifdef RR )==""\n"
165R"==(co += GROUPSIZE_M * gdx + UNROLL_M * idx + UNROLL_M * lid / GRX; )==""\n"
166R"==(#endif )==""\n"
167R"==(#ifdef CC )==""\n"
168R"==(co += GROUPSIZE_N * gdy + UNROLL_N * idy; )==""\n"
169R"==(#endif )==""\n"
170R"==(} )==""\n"
// The start of sa and sb is reused as FLOATC arrays xa and xb holding one
// correction term per row and per column; sa and sb are then advanced past
// that area so the packed matrix elements follow it.
171R"==(__local FLOATC *xa = (__local FLOATC *)sa; )==""\n"
172R"==(sa += UNROLL_M * szx * sizeof(FLOATC); )==""\n"
173R"==(__local FLOATC *xb = (__local FLOATC *)sb; )==""\n"
174R"==(sb += UNROLL_N * szy * sizeof(FLOATC); )==""\n"
// cid0 is the flattened id of this work-item inside the work-group and ctotal
// the work-group size; together they stride the two copy loops below.
175R"==(int cid0 = (idy * szx + idx) * get_local_size(0) + lid; )==""\n"
176R"==(int ctotal = get_local_size(0) * szx * szy; )==""\n"
// Copy phase for A: each iteration packs one row of this group's A panel into
// sa, writing zero for elements outside m or k, and stores
// xa[cid] = ATTR_B0 * (negated sum of the row).
// With NN or NT consecutive k elements are lda apart, otherwise adjacent.
177R"==(for (int cid = cid0; cid < szx * UNROLL_M; cid += ctotal) { )==""\n"
178R"==(long sa_moffset = (cid & ~(UNROLL_M - 1)) * kk )==""\n"
179R"==(+ (cid & (UNROLL_M - 1)) * UNROLL_K; )==""\n"
180R"==(long i = cid + GROUPSIZE_M * gdx; )==""\n"
181R"==(FLOATC sumA = 0; )==""\n"
182R"==(#if defined(NN) || defined(NT) )==""\n"
183R"==(long a_offset = i; )==""\n"
184R"==(#else )==""\n"
185R"==(long a_offset = i * lda; )==""\n"
186R"==(#endif )==""\n"
187R"==(for (l = 0; l < kk; l += UNROLL_K) { )==""\n"
188R"==(for (ll = 0; ll < UNROLL_K; ll++) { )==""\n"
189R"==(FLOATA a_val = (((i < m) && (l + ll < k)) ? a[a_offset] : 0); )==""\n"
190R"==(sa[sa_moffset + l * UNROLL_M + ll] = a_val; )==""\n"
191R"==(sumA -= a_val; )==""\n"
192R"==(#if defined(NN) || defined(NT) )==""\n"
193R"==(a_offset += lda; )==""\n"
194R"==(#else )==""\n"
195R"==(a_offset++; )==""\n"
196R"==(#endif )==""\n"
197R"==(} )==""\n"
198R"==(} )==""\n"
199R"==(xa[cid] = (FLOATC)ATTR_B0 * sumA; )==""\n"
200R"==(} )==""\n"
// Copy phase for B: same scheme for one column of the B panel into sb.
// sumB starts at ATTR_B0 * k and each element is subtracted from it, so
// xb[cid] = ATTR_A0 * (ATTR_B0 * k - sum of the column).
// With NN or TN consecutive k elements are adjacent, otherwise ldb apart.
201R"==(for (int cid = cid0; cid < szy * UNROLL_N; cid += ctotal) { )==""\n"
202R"==(long sb_noffset = (cid & ~(UNROLL_N - 1)) * kk )==""\n"
203R"==(+ (cid & (UNROLL_N - 1)) * UNROLL_K; )==""\n"
204R"==(long j = cid + GROUPSIZE_N * gdy; )==""\n"
205R"==(FLOATC sumB = (FLOATC)ATTR_B0 * k; )==""\n"
206R"==(#if defined(NN) || defined(TN) )==""\n"
207R"==(long b_offset = j * ldb; )==""\n"
208R"==(#else )==""\n"
209R"==(long b_offset = j; )==""\n"
210R"==(#endif )==""\n"
211R"==(for (l = 0; l < kk; l += UNROLL_K) { )==""\n"
212R"==(for (ll = 0; ll < UNROLL_K; ll++) { )==""\n"
213R"==(FLOATB b_val = (((j < n) && (l + ll < k)) ? b[b_offset] : 0); )==""\n"
214R"==(sb[sb_noffset + l * UNROLL_N + ll] = b_val; )==""\n"
215R"==(sumB -= b_val; )==""\n"
216R"==(#if defined(NN) || defined(TN) )==""\n"
217R"==(b_offset++; )==""\n"
218R"==(#else )==""\n"
219R"==(b_offset += ldb; )==""\n"
220R"==(#endif )==""\n"
221R"==(} )==""\n"
222R"==(} )==""\n"
223R"==(xb[cid] = (FLOATC)ATTR_A0 * sumB; )==""\n"
224R"==(} )==""\n"
// Work-group barrier with a local-memory fence, separating the copy phase
// from the compute phase that reads sa, sb, xa and xb.
225R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
// Reduce m and n to the extent left for this work-item tile, capped at
// UNROLL_M and UNROLL_N. Work-items with nothing to compute return here, which
// is after the barrier above.
226R"==(m -= GROUPSIZE_M * gdx + UNROLL_M * idx; )==""\n"
227R"==(if (m > UNROLL_M) m = UNROLL_M; )==""\n"
228R"==(n -= GROUPSIZE_N * gdy + UNROLL_N * idy; )==""\n"
229R"==(if (n > UNROLL_N) n = UNROLL_N; )==""\n"
230R"==(if ((m <= 0) || (n <= 0)) return; )==""\n"
// Shift m and the local pointers by the per-lane row offset
// UNROLL_M * lid / GRX; sb is instead offset by UNROLL_K * lid.
231R"==(m -= UNROLL_M * lid / GRX; )==""\n"
232R"==(sa += UNROLL_M * kk * idx + UNROLL_M * UNROLL_K * lid / GRX; )==""\n"
233R"==(sb += UNROLL_N * kk * idy + UNROLL_K * lid; )==""\n"
234R"==(xa += UNROLL_M * idx + UNROLL_M * lid / GRX; )==""\n"
235R"==(xb += UNROLL_N * idy; )==""\n"
// Private accumulators, zero-initialised. The loop below indexes sc up to 15.
// NOTE(review): that requires UNROLL_M * UNROLL_N / GRX / 4 to be at least 16;
// confirm against the build options.
236R"==(FLOATC4 sc[UNROLL_M * UNROLL_N / GRX / 4] = {0}; )==""\n"
// Main loop over k in steps of UNROLL_K. Each step loads four FLOATA4 vectors
// from sa, then for every ll loads one FLOATB4 from sb, reads the copies held
// by lanes 0..3 and 4..7 through SHUFFLE, and accumulates 4x4 products into sc.
237R"==(for (l = 0; l < kk; l += UNROLL_K) { )==""\n"
238R"==(FLOATA4 a0, a1, a2, a3; )==""\n"
239R"==(FLOATB4 bb, b0, b1, b2, b3; )==""\n"
240R"==(a0 = ((__local FLOATA4 *)sa)[0]; )==""\n"
241R"==(a1 = ((__local FLOATA4 *)sa)[1]; )==""\n"
242R"==(a2 = ((__local FLOATA4 *)sa)[2]; )==""\n"
243R"==(a3 = ((__local FLOATA4 *)sa)[3]; )==""\n"
244R"==(for (ll = 0; ll < GRX / 4; ll++) { )==""\n"
245R"==(bb = ((__local FLOATB4 *)sb)[0]; )==""\n"
246R"==(b0 = SHUFFLE(bb, 0); )==""\n"
247R"==(b1 = SHUFFLE(bb, 1); )==""\n"
248R"==(b2 = SHUFFLE(bb, 2); )==""\n"
249R"==(b3 = SHUFFLE(bb, 3); )==""\n"
250R"==(sc[ll * 2 + 0] += ACCUMULATE(a0, b0, b1, b2, b3); )==""\n"
251R"==(sc[ll * 2 + 4] += ACCUMULATE(a1, b0, b1, b2, b3); )==""\n"
252R"==(sc[ll * 2 + 8] += ACCUMULATE(a2, b0, b1, b2, b3); )==""\n"
253R"==(sc[ll * 2 + 12] += ACCUMULATE(a3, b0, b1, b2, b3); )==""\n"
254R"==(b0 = SHUFFLE(bb, 4); )==""\n"
255R"==(b1 = SHUFFLE(bb, 5); )==""\n"
256R"==(b2 = SHUFFLE(bb, 6); )==""\n"
257R"==(b3 = SHUFFLE(bb, 7); )==""\n"
258R"==(sc[ll * 2 + 1] += ACCUMULATE(a0, b0, b1, b2, b3); )==""\n"
259R"==(sc[ll * 2 + 5] += ACCUMULATE(a1, b0, b1, b2, b3); )==""\n"
260R"==(sc[ll * 2 + 9] += ACCUMULATE(a2, b0, b1, b2, b3); )==""\n"
261R"==(sc[ll * 2 + 13] += ACCUMULATE(a3, b0, b1, b2, b3); )==""\n"
262R"==(sb += UNROLL_N * GRX / 4; )==""\n"
263R"==(} )==""\n"
264R"==(sa += UNROLL_M * UNROLL_K; )==""\n"
265R"==(} )==""\n"
// Write-back: four groups of four columns, each adding the xa and xb terms
// and the optional co and eltwise handling through COMPUTE_C.
266R"==(ADD_SCALE(0); )==""\n"
267R"==(ADD_SCALE(4); )==""\n"
268R"==(ADD_SCALE(8); )==""\n"
269R"==(ADD_SCALE(12); )==""\n"
270R"==(} )==""\n"
271R"==()==";
272}
273}
274}
275}