namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source of the xe_lp_1x1_conv_fwd_x8s8x kernel, embedded as one C
// string. The string is assembled from adjacent raw string literals, one per
// kernel source line, each followed by an explicit newline literal.
//
// Every upper-case identifier used for shapes and feature switches (IC, OC,
// MB, G, IC_BLOCK, OC_BLOCK, MB_BLOCK, SP_BLOCK, SW/SH/SD, PW/PH/PD, LWS_*,
// WITH_BIAS, WITH_SUM, INT8_WEI_SLM, ...) is a preprocessor macro that is not
// defined inside this string.
// NOTE(review): presumably they are passed as build options by the host side
// and the type and mmad helpers come from the included headers - confirm.
//
// The comments below are C++ comments placed between the string pieces; they
// are not part of the kernel text.
const char *xe_lp_1x1_conv_fwd_data_x8s8x_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* licensed under the apache license, version 2.0 (the "license"); )==""\n"
R"==(* you may not use this file except in compliance with the license. )==""\n"
R"==(* you may obtain a copy of the license at )==""\n"
R"==(* )==""\n"
R"==(* http: )==""\n"
R"==(* )==""\n"
R"==(* unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the license is distributed on an "as is" basis, )==""\n"
R"==(* without warranties or conditions of any kind, either express or implied. )==""\n"
R"==(* see the license for the specific language governing permissions and )==""\n"
R"==(* limitations under the license. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_math_utils.h" )==""\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_scales.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
R"==(#include "gpu/ocl/ocl_zero_points.h" )==""\n"
// IC_NBLOCKS_TAIL: number of 4-channel groups needed to cover the IC remainder
// in the last input-channel block; 8 when IC is a multiple of IC_BLOCK.
R"==(#if IC % IC_BLOCK != 0 )==""\n"
R"==(#define IC_NBLOCKS_TAIL ((IC - (IC & ~(IC_BLOCK - 1)) + 3) / 4) )==""\n"
R"==(#else )==""\n"
R"==(#define IC_NBLOCKS_TAIL 8 )==""\n"
R"==(#endif )==""\n"
// Geometry and stride helpers.
// USE_SP_BLOCK selects flat indexing over the whole output volume; it needs
// all three strides (W, H and D) to be 1, or padding to be present. The
// expansion is fully parenthesised so it can be negated safely.
// The destination strides are expressed in OC_BLOCK units, matching the
// dst + 8 * OC_BLOCK stepping used further down.
R"==(#define OWB ((OW + SP_BLOCK - 1) / SP_BLOCK) )==""\n"
R"==(#define HAS_PADDING (PW > 0 || PH > 0 || PD > 0) )==""\n"
R"==(#define USE_SP_BLOCK ((SW == 1 && SH == 1 && SD == 1) || HAS_PADDING) )==""\n"
R"==(#define SRC_SP (IW * IH * ID) )==""\n"
R"==(#define SRC_MB_STRIDE IC_BLOCK )==""\n"
R"==(#define SRC_SP_STRIDE (SRC_MB_STRIDE * MB_BLOCK) )==""\n"
R"==(#define SRC_ICB_STRIDE (SRC_SP_STRIDE * SRC_SP) )==""\n"
R"==(#define DST_SP (OW * OH * OD) )==""\n"
R"==(#define DST_MB_STRIDE OC_BLOCK )==""\n"
R"==(#define DST_SP_STRIDE (DST_MB_STRIDE * MB_BLOCK) )==""\n"
R"==(#define DST_OCB_STRIDE (DST_SP_STRIDE * DST_SP) )==""\n"
R"==(#define WEI_BLOCK_STRIDE (4 * 8 * 8 * 4) )==""\n"
// Source readers. The single wide block read is only used when MB_BLOCK is 32
// or when there is no padding in any of W, H, D and all strides are 1.
// Otherwise each of the 4 or 8 rows is read on its own, through
// PAD_BLOCK_READ when padding is present.
R"==(#if (MB_BLOCK == 32) \ )==""\n"
R"==(|| ((PW == 0 && PH == 0 && PD == 0) \ )==""\n"
R"==(&& (SW == 1 && SH == 1 && SD == 1)) )==""\n"
R"==(#define BLOCK_READ_SRC_4x32(data, idx, sp_off) \ )==""\n"
R"==(data = AS_SRC_MMAD_DATA4_T( \ )==""\n"
R"==(intel_sub_group_block_read4((__global uint *)&src[idx])); )==""\n"
R"==(#define BLOCK_READ_SRC_8x32(data, idx, sp_off) \ )==""\n"
R"==(data = AS_SRC_MMAD_DATA8_T( \ )==""\n"
R"==(intel_sub_group_block_read8((__global uint *)&src[idx])); )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK_READ_SRC_4x32(data, idx, sp_off) \ )==""\n"
R"==(do { \ )==""\n"
R"==(unroll_for(uint _i = 0; _i < 4; ++_i) { \ )==""\n"
R"==(if (HAS_PADDING) { \ )==""\n"
R"==(PAD_BLOCK_READ(data[_i], src, sp, _i + sp_off); \ )==""\n"
R"==(} else { \ )==""\n"
R"==(data[_i] = AS_SRC_MMAD_DATA_T(intel_sub_group_block_read( \ )==""\n"
R"==((__global uint *)&src[idx + _i * SW * IC_BLOCK])); \ )==""\n"
R"==(} \ )==""\n"
R"==(} \ )==""\n"
R"==(} while (0); )==""\n"
R"==(#define BLOCK_READ_SRC_8x32(data, idx, sp_off) \ )==""\n"
R"==(do { \ )==""\n"
R"==(unroll_for(uint _i = 0; _i < 8; ++_i) { \ )==""\n"
R"==(if (HAS_PADDING) { \ )==""\n"
R"==(PAD_BLOCK_READ(data[_i], src, sp, _i + sp_off); \ )==""\n"
R"==(} else { \ )==""\n"
R"==(data[_i] = AS_SRC_MMAD_DATA_T(intel_sub_group_block_read( \ )==""\n"
R"==((__global uint *)&src[idx + _i * SW * IC_BLOCK])); \ )==""\n"
R"==(} \ )==""\n"
R"==(} \ )==""\n"
R"==(} while (0); )==""\n"
R"==(#endif )==""\n"
// First accumulator group (C00..C03): 4 rows when SP_BLOCK is 4, else 8 rows.
R"==(#if SP_BLOCK == 4 )==""\n"
R"==(#define BLOCK0 4 )==""\n"
R"==(#define ACC_DATA_BLOCK int4 )==""\n"
R"==(#define SRC_DATA_BLOCK_T SRC_MMAD_DATA4_T )==""\n"
R"==(#define BLOCK_READ_SRC BLOCK_READ_SRC_4x32 )==""\n"
R"==(DECLARE_MMAD_EMU(mmad_tail0, idot4, IC_NBLOCKS_TAIL, 4, SRC_DATA_BLOCK_T, int8, )==""\n"
R"==(ACC_DATA_BLOCK) )==""\n"
R"==(#define MMAD_FULL0 mmad8x4 )==""\n"
R"==(#define MMAD_TAIL0 mmad_tail0 )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK0 8 )==""\n"
R"==(#define ACC_DATA_BLOCK int8 )==""\n"
R"==(#define SRC_DATA_BLOCK_T SRC_MMAD_DATA8_T )==""\n"
R"==(#define BLOCK_READ_SRC BLOCK_READ_SRC_8x32 )==""\n"
R"==(DECLARE_MMAD_EMU(mmad_tail0, idot4, IC_NBLOCKS_TAIL, 8, SRC_DATA_BLOCK_T, int8, )==""\n"
R"==(ACC_DATA_BLOCK) )==""\n"
R"==(#define MMAD_FULL0 mmad8x8 )==""\n"
R"==(#define MMAD_TAIL0 mmad_tail0 )==""\n"
R"==(#endif )==""\n"
// Second accumulator group (C10..C13): 4 rows when SP_BLOCK is 12, else 8.
R"==(#if SP_BLOCK == 12 )==""\n"
R"==(#define BLOCK1 4 )==""\n"
R"==(#define ACC_DATA_BLOCK1 int4 )==""\n"
R"==(#define SRC_DATA_BLOCK_T1 SRC_MMAD_DATA4_T )==""\n"
R"==(#define DST_DATA_BLOCK_T1 uint4 )==""\n"
R"==(#define BLOCK_READ_SRC1 BLOCK_READ_SRC_4x32 )==""\n"
R"==(DECLARE_MMAD_EMU(mmad_tail1, idot4, IC_NBLOCKS_TAIL, 4, SRC_DATA_BLOCK_T1, int8, )==""\n"
R"==(ACC_DATA_BLOCK1) )==""\n"
R"==(#define MMAD_FULL1 mmad8x4 )==""\n"
R"==(#define MMAD_TAIL1 mmad_tail1 )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK1 8 )==""\n"
R"==(#define ACC_DATA_BLOCK1 int8 )==""\n"
R"==(#define SRC_DATA_BLOCK_T1 SRC_MMAD_DATA8_T )==""\n"
R"==(#define DST_DATA_BLOCK_T1 uint8 )==""\n"
R"==(#define BLOCK_READ_SRC1 BLOCK_READ_SRC_8x32 )==""\n"
R"==(DECLARE_MMAD_EMU(mmad_tail1, idot4, IC_NBLOCKS_TAIL, 8, SRC_DATA_BLOCK_T1, int8, )==""\n"
R"==(ACC_DATA_BLOCK1) )==""\n"
R"==(#define MMAD_FULL1 mmad8x8 )==""\n"
R"==(#define MMAD_TAIL1 mmad_tail1 )==""\n"
R"==(#endif )==""\n"
// Weights readers: from the __local staging pointer wei_tmp when INT8_WEI_SLM
// is set, otherwise directly from the __global wei pointer.
R"==(#if INT8_WEI_SLM )==""\n"
R"==(#define BLOCK_READ_WHT_1x32(data, idx) \ )==""\n"
R"==(data = as_int(block_read((__local uint *)&wei_tmp[idx])); )==""\n"
R"==(#define BLOCK_READ_WHT_8x32(data, idx) \ )==""\n"
R"==(data = as_int8(block_read8((__local uint *)&wei_tmp[idx])); )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK_READ_WHT_1x32(data, idx) \ )==""\n"
R"==(data = as_int(intel_sub_group_block_read((__global uint *)&wei[idx])); )==""\n"
R"==(#define BLOCK_READ_WHT_8x32(data, idx) \ )==""\n"
R"==(data = as_int8(intel_sub_group_block_read8((__global uint *)&wei[idx])); )==""\n"
R"==(#endif )==""\n"
// Bias reader. When OC is not a multiple of OC_BLOCK the bias is read one
// sub-group row at a time and the final partial row is read per lane, so that
// no element at index OC or above is touched; unread lanes stay 0.
R"==(#if OC % OC_BLOCK == 0 )==""\n"
R"==(#define BLOCK_READ_BIA(data, idx) \ )==""\n"
R"==(data = as_float4(intel_sub_group_block_read4((__global uint *)&bias[idx])); )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK_READ_BIA(data, idx) \ )==""\n"
R"==(data = (float4)0; \ )==""\n"
R"==(int i; \ )==""\n"
R"==(for (i = idx; i < idx + OC_BLOCK && i < OC - (OC % SUB_GROUP_SIZE); \ )==""\n"
R"==(i += SUB_GROUP_SIZE) { \ )==""\n"
R"==(data[(i - idx) / SUB_GROUP_SIZE] = as_float( \ )==""\n"
R"==(intel_sub_group_block_read((__global uint *)&bias[i])); \ )==""\n"
R"==(} \ )==""\n"
R"==(if (i < OC && (i - idx) / SUB_GROUP_SIZE < 4 \ )==""\n"
R"==(&& get_sub_group_local_id() < OC % SUB_GROUP_SIZE) { \ )==""\n"
R"==(data[(i - idx) / SUB_GROUP_SIZE] \ )==""\n"
R"==(= as_float(bias[i + get_sub_group_local_id()]); \ )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
// PAD_BLOCK_READ: maps output position sp + i to input coordinates using the
// strides and paddings, yields 0 when that position falls outside the input,
// and otherwise block-reads one row from src at the computed offset.
R"==(#define PAD_BLOCK_READ(data, src, sp, i) \ )==""\n"
R"==(do { \ )==""\n"
R"==(const int od = (sp + i) / (OW * OH); \ )==""\n"
R"==(const int ohw = (sp + i) % (OW * OH); \ )==""\n"
R"==(const int oh = ohw / OW; \ )==""\n"
R"==(const int ow = (ohw % OW); \ )==""\n"
R"==(const int id = SD * od - PD; \ )==""\n"
R"==(const int ih = SH * oh - PH; \ )==""\n"
R"==(const int iw = SW * ow - PW; \ )==""\n"
R"==(bool pad = ((PW > 0 || PH > 0 || PD > 0) \ )==""\n"
R"==(&& (iw < 0 || ih < 0 || id < 0 || iw >= IW || ih >= IH \ )==""\n"
R"==(|| id >= ID)); \ )==""\n"
R"==(int off = id * IH * IW + ih * IW + iw; \ )==""\n"
R"==(data = pad ? 0 \ )==""\n"
R"==(: AS_SRC_MMAD_DATA_T(intel_sub_group_block_read( \ )==""\n"
R"==((global uint *)&src[off * SRC_SP_STRIDE])); \ )==""\n"
R"==(} while (0); )==""\n"
// SCALE: per-channel vector (local variable scales), a single runtime value,
// or the constant 1.
R"==(#if SCALES_PER_OC )==""\n"
R"==(#define SCALE scales )==""\n"
R"==(#elif SCALES_COMMON )==""\n"
R"==(#define SCALE runtime_scales[0] )==""\n"
R"==(#else )==""\n"
R"==(#define SCALE 1 )==""\n"
R"==(#endif )==""\n"
// Forward declarations of the destination helpers defined after the kernel.
R"==(void block_read_dst(int n, DST_DATA_T *d, const __global DST_DATA_T *dst); )==""\n"
R"==(void block_write_dst(int n, const DST_DATA_T *d, __global DST_DATA_T *dst); )==""\n"
// Kernel entry point.
R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) __kernel void )==""\n"
R"==(xe_lp_1x1_conv_fwd_x8s8x(const __global SRC_DATA_T *src, )==""\n"
R"==(const __global char *wei, const __global float *bias, )==""\n"
R"==(__global DST_DATA_T *dst POST_OP_ARGS, )==""\n"
R"==(const __global float *runtime_scales, )==""\n"
R"==(const __global int *src_compensation, )==""\n"
R"==(const __global int *dst_compensation) { )==""\n"
// Work-group dimension 0 indexes output-channel blocks, 1 spatial, 2 minibatch.
R"==(const uint oc_group_id = get_group_id(0); )==""\n"
R"==(const uint sp_group_id = get_group_id(1); )==""\n"
R"==(const uint mb_group_id = get_group_id(2); )==""\n"
R"==(const uint ic_group_id = oc_group_id / OC_NCHUNK * IC_NCHUNK; )==""\n"
R"==(const uint sg_local_id = get_sub_group_local_id(); )==""\n"
R"==(const uint sg_id = get_sub_group_id(); )==""\n"
// Decompose the spatial index into (od, oh, ow): either a flat index over the
// whole output volume, or blocks of SP_BLOCK along the width only.
R"==(#if USE_SP_BLOCK )==""\n"
R"==(const uint sp = get_global_id(1) * SP_BLOCK; )==""\n"
R"==(const int sp_local_id = get_local_id(1); )==""\n"
R"==(const uint od = sp / (OH * OW); )==""\n"
R"==(const uint ohw = sp % (OH * OW); )==""\n"
R"==(const uint oh = ohw / OW; )==""\n"
R"==(const uint ow = (ohw % OW); )==""\n"
R"==(#else )==""\n"
R"==(const uint sp = get_global_id(1); )==""\n"
R"==(const int sp_local_id = get_local_id(1); )==""\n"
R"==(const uint od = sp / (OWB * OH); )==""\n"
R"==(const uint ohw = sp % (OWB * OH); )==""\n"
R"==(const uint oh = ohw / OWB; )==""\n"
R"==(const uint ow = (ohw % OWB) * SP_BLOCK; )==""\n"
R"==(#endif )==""\n"
R"==(const uint id = SD * od; )==""\n"
R"==(const uint ih = SH * oh; )==""\n"
R"==(const uint iw = SW * ow; )==""\n"
// Advance src, dst and wei to the blocks handled by this work-group. With
// padding the spatial part of the src offset is applied per row inside
// PAD_BLOCK_READ instead of here.
R"==(#if MB_BLOCK == 32 )==""\n"
R"==(src += (mb_group_id % 2) * MB_BLOCK / 2 * SRC_MB_STRIDE; )==""\n"
R"==(src += (mb_group_id / 2) * SRC_ICB_STRIDE * IC_NCHUNK * G; )==""\n"
R"==(#else )==""\n"
R"==(src += mb_group_id * SRC_ICB_STRIDE * IC_NCHUNK * G; )==""\n"
R"==(#endif )==""\n"
R"==(src += ic_group_id * SRC_ICB_STRIDE; )==""\n"
R"==(#if !HAS_PADDING )==""\n"
R"==(src += (id * IH * IW + ih * IW + iw) * SRC_SP_STRIDE; )==""\n"
R"==(#endif )==""\n"
R"==(#if MB_BLOCK == 32 )==""\n"
R"==(dst += (mb_group_id % 2) * MB_BLOCK / 2 * DST_MB_STRIDE; )==""\n"
R"==(dst += (mb_group_id / 2) * DST_OCB_STRIDE * OC_NCHUNK * G; )==""\n"
R"==(#else )==""\n"
R"==(dst += mb_group_id * DST_OCB_STRIDE * OC_NCHUNK * G; )==""\n"
R"==(#endif )==""\n"
R"==(dst += oc_group_id * DST_OCB_STRIDE; )==""\n"
R"==(dst += (od * OH * OW + oh * OW + ow) * DST_SP_STRIDE; )==""\n"
R"==(wei += oc_group_id * WEI_BLOCK_STRIDE * IC_NCHUNK; )==""\n"
// Integer accumulators: C0x for the first row group, C1x for the second; the
// last digit selects one of up to four groups of output channels.
R"==(ACC_DATA_BLOCK C00 = 0, C01 = 0, C02 = 0, C03 = 0; )==""\n"
R"==(ACC_DATA_BLOCK1 C10 = 0, C11 = 0, C12 = 0, C13 = 0; )==""\n"
// READ_SLM: between two local-memory barriers, every work-item along
// dimension 1 copies its slice of the current weights block from wei into the
// __local buffer wei_slm, and wei_tmp is pointed at that buffer.
R"==(#if INT8_WEI_SLM )==""\n"
R"==(#define READ_SLM() \ )==""\n"
R"==(barrier(CLK_LOCAL_MEM_FENCE); \ )==""\n"
R"==(const __global char *wei_copy_from \ )==""\n"
R"==(= wei + sp_local_id * WEI_BLOCK_STRIDE / LWS_1; \ )==""\n"
R"==(__local char *wei_copy_to \ )==""\n"
R"==(= wei_slm + sp_local_id * WEI_BLOCK_STRIDE / LWS_1; \ )==""\n"
R"==(block_write4((__local uint *)wei_copy_to, \ )==""\n"
R"==(intel_sub_group_block_read4((__global uint *)wei_copy_from)); \ )==""\n"
R"==(__local char *wei_tmp = wei_slm; \ )==""\n"
R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
R"==(__local char wei_slm[WEI_BLOCK_STRIDE]; )==""\n"
R"==(#endif )==""\n"
// Main reduction loop over input-channel blocks.
R"==(for (uint ic_block_id = 0; ic_block_id < IC_NCHUNK; ++ic_block_id) { )==""\n"
// The break condition depends only on the group id and compile-time values,
// so all work-items of a work-group take it together.
R"==(if (MB % MB_BLOCK != 0 && MB_BLOCK == 32 )==""\n"
R"==(&& mb_group_id * MB_BLOCK / 2 > MB) { )==""\n"
R"==(break; )==""\n"
R"==(} )==""\n"
// With SLM staging all work-items run READ_SLM (it contains barriers); only
// the compute part below is skipped for columns at or beyond OW.
R"==(#if INT8_WEI_SLM )==""\n"
R"==(READ_SLM() )==""\n"
R"==(#if SP_TAIL )==""\n"
R"==(if (ow < OW) )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
R"==({ )==""\n"
R"==(SRC_DATA_BLOCK_T S0 = 0; )==""\n"
R"==(SRC_DATA_BLOCK_T1 S1 = 0; )==""\n"
// Spatial tail: read only the OUT_SP_TAIL valid rows, the rest stay 0.
R"==(#if OUT_SP_TAIL )==""\n"
R"==(#if USE_SP_BLOCK )==""\n"
R"==(if (od * OH * OW + oh * OW + ow + SP_BLOCK > DST_SP) { )==""\n"
R"==(#else )==""\n"
R"==(if (ow + SP_BLOCK > OW) { )==""\n"
R"==(#endif )==""\n"
R"==(#if OUT_SP_TAIL < 8 )==""\n"
R"==(for (int _i = 0; _i < OUT_SP_TAIL; ++_i) { )==""\n"
R"==(if (HAS_PADDING) { )==""\n"
R"==(PAD_BLOCK_READ(S0[_i], src, sp, _i); )==""\n"
R"==(} else { )==""\n"
R"==(S0[_i] = AS_SRC_MMAD_DATA_T(intel_sub_group_block_read( )==""\n"
R"==((__global uint *)&src[_i * SW * IC_BLOCK])); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#else )==""\n"
R"==(BLOCK_READ_SRC(S0, 0 * IC_BLOCK, 0); )==""\n"
R"==(for (int _i = 8; _i < OUT_SP_TAIL; ++_i) { )==""\n"
R"==(if (HAS_PADDING) { )==""\n"
R"==(PAD_BLOCK_READ(S1[_i - 8], src, sp, _i) )==""\n"
R"==(} else { )==""\n"
R"==(S1[_i - 8] = AS_SRC_MMAD_DATA_T( )==""\n"
R"==(intel_sub_group_block_read((__global uint )==""\n"
R"==(*)&src[_i * SW * IC_BLOCK])); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(} else )==""\n"
R"==(#endif )==""\n"
R"==({ )==""\n"
R"==(BLOCK_READ_SRC(S0, 0 * IC_BLOCK, 0); )==""\n"
R"==(#if (MB_BLOCK == 32 && MB > 8) )==""\n"
R"==(BLOCK_READ_SRC1(S1, 8 * IC_BLOCK, 0); )==""\n"
R"==(#elif SP_BLOCK > 8 )==""\n"
R"==(BLOCK_READ_SRC1(S1, 8 * SW * IC_BLOCK, 8); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(int8 W0 = 0, W1 = 0, W2 = 0, W3 = 0; )==""\n"
// Last input-channel block when IC is not a multiple of IC_BLOCK: load only
// IC_NBLOCKS_TAIL weight rows and use the emulated tail mmad.
R"==(#if IC % IC_BLOCK != 0 )==""\n"
R"==(if (ic_block_id == IC_NCHUNK - 1) { )==""\n"
R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
R"==(BLOCK_READ_WHT_1x32(W0[i], (i + 0) * IC_BLOCK); )==""\n"
R"==(if (OC > 8) )==""\n"
R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
R"==(BLOCK_READ_WHT_1x32(W1[i], (i + 8) * IC_BLOCK); )==""\n"
R"==(if (OC > 16) )==""\n"
R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
R"==(BLOCK_READ_WHT_1x32(W2[i], (i + 16) * IC_BLOCK); )==""\n"
R"==(if (OC > 24) )==""\n"
R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
R"==(BLOCK_READ_WHT_1x32(W3[i], (i + 24) * IC_BLOCK); )==""\n"
R"==(C00 = MMAD_TAIL0(S0, W0, C00); )==""\n"
R"==(if (OC > 8) C01 = MMAD_TAIL0(S0, W1, C01); )==""\n"
R"==(if (OC > 16) C02 = MMAD_TAIL0(S0, W2, C02); )==""\n"
R"==(if (OC > 24) C03 = MMAD_TAIL0(S0, W3, C03); )==""\n"
R"==(#if (MB_BLOCK == 32 && MB > 8) || SP_BLOCK > 8 )==""\n"
R"==(C10 = MMAD_TAIL1(S1, W0, C10); )==""\n"
R"==(if (OC > 8) C11 = MMAD_TAIL1(S1, W1, C11); )==""\n"
R"==(if (OC > 16) C12 = MMAD_TAIL1(S1, W2, C12); )==""\n"
R"==(if (OC > 24) C13 = MMAD_TAIL1(S1, W3, C13); )==""\n"
R"==(#endif )==""\n"
R"==(} else )==""\n"
R"==(#endif )==""\n"
// Full input-channel block.
R"==({ )==""\n"
R"==(BLOCK_READ_WHT_8x32(W0, 0); )==""\n"
R"==(if (OC > 8) BLOCK_READ_WHT_8x32(W1, 8 * IC_BLOCK); )==""\n"
R"==(if (OC > 16) BLOCK_READ_WHT_8x32(W2, 16 * IC_BLOCK); )==""\n"
R"==(if (OC > 24) BLOCK_READ_WHT_8x32(W3, 24 * IC_BLOCK); )==""\n"
R"==(C00 = MMAD_FULL0(S0, W0, C00); )==""\n"
R"==(if (OC > 8) C01 = MMAD_FULL0(S0, W1, C01); )==""\n"
R"==(if (OC > 16) C02 = MMAD_FULL0(S0, W2, C02); )==""\n"
R"==(if (OC > 24) C03 = MMAD_FULL0(S0, W3, C03); )==""\n"
R"==(#if (MB_BLOCK == 32 && MB > 8) || SP_BLOCK > 8 )==""\n"
R"==(C10 = MMAD_FULL1(S1, W0, C10); )==""\n"
R"==(if (OC > 8) C11 = MMAD_FULL1(S1, W1, C11); )==""\n"
R"==(if (OC > 16) C12 = MMAD_FULL1(S1, W2, C12); )==""\n"
R"==(if (OC > 24) C13 = MMAD_FULL1(S1, W3, C13); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(src += SRC_ICB_STRIDE; )==""\n"
R"==(wei += WEI_BLOCK_STRIDE; )==""\n"
R"==(} )==""\n"
// Subtract the per-output-channel source zero-point compensation.
R"==(#if WITH_SRC_ZPOINTS )==""\n"
R"==(int4 src_comp = as_int4(intel_sub_group_block_read4( )==""\n"
R"==((__global uint *)(&src_compensation[oc_group_id * OC_BLOCK]))); )==""\n"
R"==(C00 -= src_comp.s0; )==""\n"
R"==(C01 -= src_comp.s1; )==""\n"
R"==(C02 -= src_comp.s2; )==""\n"
R"==(C03 -= src_comp.s3; )==""\n"
R"==(#if (MB_BLOCK == 32 && MB > 8) || SP_BLOCK > 8 )==""\n"
R"==(C10 -= src_comp.s0; )==""\n"
R"==(C11 -= src_comp.s1; )==""\n"
R"==(C12 -= src_comp.s2; )==""\n"
R"==(C13 -= src_comp.s3; )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
// Epilogue state: tmp holds one row as float4, dst_pack the converted rows,
// D0 and D1 the previous destination values used as the sum operand.
R"==(float4 tmp; )==""\n"
R"==(DST_DATA4_T dst_pack[8]; )==""\n"
R"==(DST_DATA4_T D0[BLOCK0] = {0}; )==""\n"
R"==(DST_DATA4_T D1[BLOCK1] = {0}; )==""\n"
R"==(#if SCALES_PER_OC )==""\n"
R"==(float4 scales; )==""\n"
R"==(block_read_scales( )==""\n"
R"==(&scales, oc_group_id * OC_BLOCK, sg_local_id, runtime_scales); )==""\n"
R"==(#endif )==""\n"
// QUANTIZE_ADD_BIAS: with bias the scale multiplies (tmp + bia); without bias
// tmp is only scaled.
R"==(#if WITH_BIAS )==""\n"
R"==(float4 bia; )==""\n"
R"==(BLOCK_READ_BIA(bia, oc_group_id * OC_BLOCK); )==""\n"
R"==(#define QUANTIZE_ADD_BIAS() tmp = SCALE * fma(tmp, (float4)1, bia); )==""\n"
R"==(#else )==""\n"
R"==(#define QUANTIZE_ADD_BIAS() tmp *= SCALE; )==""\n"
R"==(#endif )==""\n"
// Load the existing destination values for the sum post-op, limited to the
// valid rows in the spatial tail.
R"==(#if WITH_SUM )==""\n"
R"==(#if USE_SP_BLOCK )==""\n"
R"==(if (OUT_SP_TAIL && od * OH * OW + oh * OW + ow + SP_BLOCK > DST_SP) { )==""\n"
R"==(#else )==""\n"
R"==(if (OUT_SP_TAIL && ow + SP_BLOCK > OW) { )==""\n"
R"==(#endif )==""\n"
R"==(#if OUT_SP_TAIL < 8 )==""\n"
R"==(block_read_dst(OUT_SP_TAIL, D0, dst); )==""\n"
R"==(#else )==""\n"
R"==(block_read_dst(BLOCK0, D0, dst); )==""\n"
R"==(block_read_dst(OUT_SP_TAIL - 8, D1, dst + 8 * OC_BLOCK); )==""\n"
R"==(#endif )==""\n"
R"==(} else { )==""\n"
R"==(block_read_dst(BLOCK0, D0, dst); )==""\n"
R"==(if (SP_BLOCK > 8 || (MB_BLOCK == 32 && MB > 8)) { )==""\n"
R"==(block_read_dst(BLOCK1, D1, dst + 8 * OC_BLOCK); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
// Destination zero points, added after the post-ops.
R"==(#if WITH_DST_ZPOINTS )==""\n"
R"==(int4 dst_zp = read_dst_zero_points_32c( )==""\n"
R"==(dst_compensation, oc_group_id * OC_BLOCK); )==""\n"
R"==(#if !WITH_DST_ZPOINTS_PER_OC && OC % 32 != 0 )==""\n"
R"==(dst_zp = convert_int4( )==""\n"
R"==(zero_pad_dst_32c(convert_float4(dst_zp), oc_group_id * OC_BLOCK)); )==""\n"
R"==(#endif )==""\n"
R"==(#define ADD_DST_COMPENSATION() tmp += convert_float4(dst_zp); )==""\n"
R"==(#else )==""\n"
R"==(#define ADD_DST_COMPENSATION() )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_SRC_ZPOINTS )==""\n"
R"==(#define ZERO_PAD_DST() tmp = zero_pad_dst_32c(tmp, oc_group_id * OC_BLOCK); )==""\n"
R"==(#else )==""\n"
R"==(#define ZERO_PAD_DST() )==""\n"
R"==(#endif )==""\n"
// PACK gathers row idx of four accumulators into tmp; CONVERT_PACK converts
// tmp to the destination type; STORE_DST runs scale, bias, post-ops and zero
// points for n rows and writes them out. Rows whose minibatch index po_mb is
// not below MB skip the arithmetic but are still converted and written.
R"==(#define PACK(C0, C1, C2, C3, idx) \ )==""\n"
R"==(do { \ )==""\n"
R"==(tmp[0] = C0[idx]; \ )==""\n"
R"==(tmp[1] = C1[idx]; \ )==""\n"
R"==(tmp[2] = C2[idx]; \ )==""\n"
R"==(tmp[3] = C3[idx]; \ )==""\n"
R"==(} while (0) )==""\n"
R"==(#define CONVERT_PACK(idx) \ )==""\n"
R"==(do { \ )==""\n"
R"==(dst_pack[idx] = CONVERT_DST_DATA4_T(tmp); \ )==""\n"
R"==(} while (0) )==""\n"
R"==(#define STORE_DST(n, C0, C1, C2, C3, D, dst_ptr, mb_stride) \ )==""\n"
R"==(do { \ )==""\n"
R"==(for (int n_i = 0; n_i < n; n_i++) { \ )==""\n"
R"==(PACK(C0, C1, C2, C3, n_i); \ )==""\n"
R"==(int po_mb; \ )==""\n"
R"==(if (MB_BLOCK == 32) \ )==""\n"
R"==(po_mb = (mb_group_id * MB_BLOCK / 2 + mb_stride * 8 + n_i); \ )==""\n"
R"==(else \ )==""\n"
R"==(po_mb = mb_group_id % MB; \ )==""\n"
R"==(if (MB % MB_BLOCK == 0 || po_mb < MB) { \ )==""\n"
R"==(QUANTIZE_ADD_BIAS(); \ )==""\n"
R"==(const int po_oc = (oc_group_id * OC_BLOCK) % (OC * G); \ )==""\n"
R"==(float4 dni = convert_float4(AS_SUM_DATA4_T(D[n_i])); \ )==""\n"
R"==(APPLY_POST_OPS_TRY_BURST(tmp, float, dni, float, po_mb, 1, \ )==""\n"
R"==(po_oc, 4 * SUB_GROUP_SIZE, sg_local_id); \ )==""\n"
R"==(ADD_DST_COMPENSATION(); \ )==""\n"
R"==(ZERO_PAD_DST(); \ )==""\n"
R"==(} \ )==""\n"
R"==(CONVERT_PACK(n_i); \ )==""\n"
R"==(} \ )==""\n"
R"==(block_write_dst(n, dst_pack, dst_ptr); \ )==""\n"
R"==(} while (0) )==""\n"
// Write the results. In the spatial tail only OUT_SP_TAIL rows are stored; a
// non-positive row count makes STORE_DST a no-op because its loops do not run.
R"==(#if INT8_WEI_SLM && SP_TAIL )==""\n"
R"==(if (ow < OW) )==""\n"
R"==(#endif )==""\n"
R"==({ )==""\n"
R"==(#if USE_SP_BLOCK )==""\n"
R"==(if (OUT_SP_TAIL && od * OH * OW + oh * OW + ow + SP_BLOCK > DST_SP) { )==""\n"
R"==(#else )==""\n"
R"==(if (OUT_SP_TAIL && ow + SP_BLOCK > OW) { )==""\n"
R"==(#endif )==""\n"
R"==(STORE_DST(min(BLOCK0, OUT_SP_TAIL), C00, C01, C02, C03, D0, dst, 0); )==""\n"
R"==(STORE_DST(OUT_SP_TAIL - 8, C10, C11, C12, C13, D1, )==""\n"
R"==(dst + 8 * OC_BLOCK, 1); )==""\n"
R"==(} else { )==""\n"
R"==(STORE_DST(BLOCK0, C00, C01, C02, C03, D0, dst, 0); )==""\n"
R"==(if (SP_BLOCK > 8 || MB_BLOCK == 32) )==""\n"
R"==(STORE_DST( )==""\n"
R"==(BLOCK1, C10, C11, C12, C13, D1, dst + 8 * OC_BLOCK, 1); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// block_read_dst: reads n rows of 4 elements each from dst into d, using
// 16-, 8- and 4-element block reads for the largest chunks first. Every 4
// elements of d correspond to a step of 32 elements in dst (i * 8).
// NOTE(review): the factor 8 presumably equals SUB_GROUP_SIZE - confirm.
R"==(void block_read_dst(int n, DST_DATA_T *d, const __global DST_DATA_T *dst) { )==""\n"
R"==(int nelems = n * 4; )==""\n"
R"==(__attribute__((opencl_unroll_hint)) )==""\n"
R"==(for (int i = 0; i < nelems / 16 * 16; i += 16) { )==""\n"
R"==(*((DST_DATA16_T *)&d[i]) = BLOCK_READ_DST16(dst + i * 8); )==""\n"
R"==(} )==""\n"
R"==(__attribute__((opencl_unroll_hint)) )==""\n"
R"==(for (int i = nelems / 16 * 16; i < nelems / 8 * 8; i += 8) { )==""\n"
R"==(*((DST_DATA8_T *)&d[i]) = BLOCK_READ_DST8(dst + i * 8); )==""\n"
R"==(} )==""\n"
R"==(__attribute__((opencl_unroll_hint)) )==""\n"
R"==(for (int i = nelems / 8 * 8; i < nelems; i += 4) { )==""\n"
R"==(*((DST_DATA4_T *)&d[i]) = BLOCK_READ_DST4(dst + i * 8); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// block_write_dst: the mirror of block_read_dst; writes n rows of 4 elements
// each from d to dst with 16-, 8- and 4-element block writes.
R"==(void block_write_dst(int n, const DST_DATA_T *d, __global DST_DATA_T *dst) { )==""\n"
R"==(int nelems = n * 4; )==""\n"
R"==(__attribute__((opencl_unroll_hint)) )==""\n"
R"==(for (int i = 0; i < nelems / 16 * 16; i += 16) { )==""\n"
R"==(BLOCK_WRITE_DST16(dst + i * 8, *((DST_DATA16_T *)&d[i])); )==""\n"
R"==(} )==""\n"
R"==(__attribute__((opencl_unroll_hint)) )==""\n"
R"==(for (int i = nelems / 16 * 16; i < nelems / 8 * 8; i += 8) { )==""\n"
R"==(BLOCK_WRITE_DST8(dst + i * 8, *((DST_DATA8_T *)&d[i])); )==""\n"
R"==(} )==""\n"
R"==(__attribute__((opencl_unroll_hint)) )==""\n"
R"==(for (int i = nelems / 8 * 8; i < nelems; i += 4) { )==""\n"
R"==(BLOCK_WRITE_DST4(dst + i * 8, *((DST_DATA4_T *)&d[i])); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==()==";
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl