namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source text of the xe_lp_nhwc_1x1_conv_fwd_x8s8x kernel.
//
// The program is stored as one NUL-terminated string assembled from adjacent
// string literals: each kernel source line is a raw-string piece followed by
// an ordinary newline literal. Only the text inside the raw-string pieces is
// part of the kernel program; the C++ comments placed between the pieces are
// removed by the C++ compiler and never reach the OpenCL compiler.
//
// The kernel text relies on build-time macros that are not defined here
// (IC, OC, G, SW, SH, SD, PW, PH, PD, SP_BLOCK, IC_BLOCK, OC_BLOCK, LWS_*,
// SUB_GROUP_SIZE, ...). NOTE(review): presumably they are supplied as
// compiler definitions by the host-side primitive - confirm against caller.
const char *xe_lp_nhwc_1x1_conv_fwd_x8s8x_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2020-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* licensed under the apache license, version 2.0 (the "license"); )==""\n"
R"==(* you may not use this file except in compliance with the license. )==""\n"
R"==(* you may obtain a copy of the license at )==""\n"
R"==(* )==""\n"
R"==(* http: )==""\n"
R"==(* )==""\n"
R"==(* unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the license is distributed on an "as is" basis, )==""\n"
R"==(* without warranties or conditions of any kind, either express or implied. )==""\n"
R"==(* see the license for the specific language governing permissions and )==""\n"
R"==(* limitations under the license. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_math_utils.h" )==""\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_scales.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
R"==(#include "gpu/ocl/ocl_zero_points.h" )==""\n"
// Input-channel remainder: IC_NBLOCKS_TAIL is the remainder IC % IC_BLOCK
// rounded up to whole groups of 4, or 8 when IC divides evenly.
R"==(#if IC % IC_BLOCK != 0 )==""\n"
R"==(#define IC_NBLOCKS_TAIL ((IC % IC_BLOCK + 3) / 4) )==""\n"
R"==(#else )==""\n"
R"==(#define IC_NBLOCKS_TAIL 8 )==""\n"
R"==(#endif )==""\n"
R"==(#define IC_TAIL (IC % IC_BLOCK) )==""\n"
R"==(#define HAS_PADDING (PW > 0 || PH > 0 || PD > 0) )==""\n"
// FIX: the third term used to repeat SW == 1, so the depth stride was never
// checked. The flattened spatial indexing selected by USE_SP_BLOCK reads
// consecutive points at a fixed SW * G * IC step, which only matches the
// output order when all three strides are 1 (or when the padded read path,
// which recomputes coordinates per point, is used).
R"==(#define USE_SP_BLOCK ((SW == 1 && SH == 1 && SD == 1) || HAS_PADDING) )==""\n"
R"==(#define SRC_SP (IW * IH * ID) )==""\n"
R"==(#define SRC_SP_STRIDE (G * IC) )==""\n"
R"==(#define SRC_ICB_STRIDE IC_BLOCK )==""\n"
R"==(#define DST_SP (OW * OH * OD) )==""\n"
R"==(#define DST_SP_STRIDE (G * OC) )==""\n"
R"==(#define DST_OCB_STRIDE OC_BLOCK )==""\n"
R"==(#define WEI_BLOCK_STRIDE (4 * 8 * 8 * 4) )==""\n"
R"==(#if DST_DT_S8 || DST_DT_U8 )==""\n"
R"==(#define OC_BLOCK_READ_BOUND 4 )==""\n"
R"==(#define OC_BLOCK_WRITE_BOUND 16 )==""\n"
R"==(#else )==""\n"
R"==(#define OC_BLOCK_READ_BOUND 1 )==""\n"
R"==(#define OC_BLOCK_WRITE_BOUND 4 )==""\n"
R"==(#endif )==""\n"
R"==(#define IC_BLOCK_READ_BOUND 4 )==""\n"
// BLOCK_READ_SRC_Xx32(start, end, d_idx, data, idx, sp_off) fills the uint
// lanes data[_i + d_idx] for _i in [start, end). Three variants are selected
// by the divisibility of IC:
//   - IC % IC_BLOCK == 0: sub-group block read of one uint per point;
//   - IC % 4 == 0: per-work-item uint load, skipped past the channel tail;
//   - otherwise: byte-by-byte load limited to the valid tail bytes.
// With HAS_PADDING every variant defers to PAD_BLOCK_READ, passing the
// spatial offset _i + sp_off.
R"==(#if IC % IC_BLOCK == 0 )==""\n"
R"==(#define BLOCK_READ_SRC_Xx32(start, end, d_idx, data, idx, sp_off) \ )==""\n"
R"==(do { \ )==""\n"
R"==(uint *d = (uint *)&data; \ )==""\n"
R"==(unroll_for(uint _i = (start); _i < (end); ++_i) { \ )==""\n"
R"==(if (HAS_PADDING) { \ )==""\n"
R"==(PAD_BLOCK_READ(d[_i + d_idx], src, sp, _i + sp_off, 0, 0) \ )==""\n"
R"==(} else { \ )==""\n"
R"==(d[_i + d_idx] = AS_SRC_MMAD_DATA_T(intel_sub_group_block_read( \ )==""\n"
R"==((__global uint *)&src[idx + _i * SW * G * IC])); \ )==""\n"
R"==(} \ )==""\n"
R"==(} \ )==""\n"
R"==(} while (0); )==""\n"
R"==(#elif IC % 4 == 0 )==""\n"
R"==(#define BLOCK_READ_SRC_Xx32(start, end, d_idx, data, idx, sp_off) \ )==""\n"
R"==(do { \ )==""\n"
R"==(uint *d = (uint *)&data; \ )==""\n"
R"==(unroll_for(uint _i = (start); _i < (end); ++_i) { \ )==""\n"
R"==(__global uchar *s = &src[idx + _i * SW * G * IC]; \ )==""\n"
R"==(if (ic_block_id < IC_NCHUNK - 1 \ )==""\n"
R"==(|| sg_local_id * 4 < IC_TAIL - IC_TAIL % 4) { \ )==""\n"
R"==(if (HAS_PADDING) { \ )==""\n"
R"==(PAD_BLOCK_READ(d[_i + d_idx], src, sp, _i + sp_off, 1, 0) \ )==""\n"
R"==(} else { \ )==""\n"
R"==(d[_i + d_idx] = *((__global uint *)&s[sg_local_id * 4]); \ )==""\n"
R"==(} \ )==""\n"
R"==(} \ )==""\n"
R"==(} \ )==""\n"
R"==(} while (0); )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK_READ_SRC_Xx32(start, end, d_idx, data, idx, sp_off) \ )==""\n"
R"==(do { \ )==""\n"
R"==(uint *d = (uint *)&data; \ )==""\n"
R"==(unroll_for(uint _i = (start); _i < (end); ++_i) { \ )==""\n"
R"==(__global uchar *s = &src[idx + _i * SW * G * IC]; \ )==""\n"
R"==(uint _j_max = (sg_local_id * 4 < IC_TAIL - IC_TAIL % 4) \ )==""\n"
R"==(|| (ic_block_id < IC_NCHUNK - 1) \ )==""\n"
R"==(? 4 \ )==""\n"
R"==(: (sg_local_id * 4 == IC_TAIL - IC_TAIL % 4 ? IC_TAIL % 4 \ )==""\n"
R"==(: 0); \ )==""\n"
R"==(unroll_for(uint _j = 0; _j < _j_max; ++_j) { \ )==""\n"
R"==(if (HAS_PADDING) { \ )==""\n"
R"==(PAD_BLOCK_READ(d[_i + d_idx], src, sp, _i + sp_off, 2, _j) \ )==""\n"
R"==(} else { \ )==""\n"
R"==(*((uchar *)&d[_i + d_idx] + _j) = s[sg_local_id * 4 + _j]; \ )==""\n"
R"==(} \ )==""\n"
R"==(} \ )==""\n"
R"==(} \ )==""\n"
R"==(} while (0); )==""\n"
R"==(#endif )==""\n"
// First spatial sub-block: 4 points when SP_BLOCK == 4, otherwise 8.
R"==(#if SP_BLOCK == 4 )==""\n"
R"==(#define BLOCK0 4 )==""\n"
R"==(#define ACC_DATA_BLOCK int4 )==""\n"
R"==(#define SRC_DATA_BLOCK_T SRC_MMAD_DATA4_T )==""\n"
R"==(#define BLOCK_READ_SRC BLOCK_READ_SRC_4x32 )==""\n"
R"==(DECLARE_MMAD_EMU(mmad_tail0, idot4, IC_NBLOCKS_TAIL, 4, SRC_DATA_BLOCK_T, int8, )==""\n"
R"==(ACC_DATA_BLOCK) )==""\n"
R"==(#define MMAD_FULL0 mmad8x4 )==""\n"
R"==(#define MMAD_TAIL0 mmad_tail0 )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK0 8 )==""\n"
R"==(#define ACC_DATA_BLOCK int8 )==""\n"
R"==(#define SRC_DATA_BLOCK_T SRC_MMAD_DATA8_T )==""\n"
R"==(#define BLOCK_READ_SRC BLOCK_READ_SRC_8x32 )==""\n"
R"==(DECLARE_MMAD_EMU(mmad_tail0, idot4, IC_NBLOCKS_TAIL, 8, SRC_DATA_BLOCK_T, int8, )==""\n"
R"==(ACC_DATA_BLOCK) )==""\n"
R"==(#define MMAD_FULL0 mmad8x8 )==""\n"
R"==(#define MMAD_TAIL0 mmad_tail0 )==""\n"
R"==(#endif )==""\n"
// Second spatial sub-block: 4 points when SP_BLOCK == 12, otherwise 8.
R"==(#if SP_BLOCK == 12 )==""\n"
R"==(#define BLOCK1 4 )==""\n"
R"==(#define ACC_DATA_BLOCK1 int4 )==""\n"
R"==(#define SRC_DATA_BLOCK_T1 SRC_MMAD_DATA4_T )==""\n"
R"==(#define DST_DATA_BLOCK_T1 uint4 )==""\n"
R"==(#define BLOCK_READ_SRC1 BLOCK_READ_SRC_4x32 )==""\n"
R"==(DECLARE_MMAD_EMU(mmad_tail1, idot4, IC_NBLOCKS_TAIL, 4, SRC_DATA_BLOCK_T1, int8, )==""\n"
R"==(ACC_DATA_BLOCK1) )==""\n"
R"==(#define MMAD_FULL1 mmad8x4 )==""\n"
R"==(#define MMAD_TAIL1 mmad_tail1 )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK1 8 )==""\n"
R"==(#define ACC_DATA_BLOCK1 int8 )==""\n"
R"==(#define SRC_DATA_BLOCK_T1 SRC_MMAD_DATA8_T )==""\n"
R"==(#define DST_DATA_BLOCK_T1 uint8 )==""\n"
R"==(#define BLOCK_READ_SRC1 BLOCK_READ_SRC_8x32 )==""\n"
R"==(DECLARE_MMAD_EMU(mmad_tail1, idot4, IC_NBLOCKS_TAIL, 8, SRC_DATA_BLOCK_T1, int8, )==""\n"
R"==(ACC_DATA_BLOCK1) )==""\n"
R"==(#define MMAD_FULL1 mmad8x8 )==""\n"
R"==(#define MMAD_TAIL1 mmad_tail1 )==""\n"
R"==(#endif )==""\n"
// Weight reads: from the __local staging pointer wei_tmp when INT8_WEI_SLM is
// set, otherwise straight from the __global wei pointer.
R"==(#if INT8_WEI_SLM )==""\n"
R"==(#define BLOCK_READ_WHT_1x32(data, idx) \ )==""\n"
R"==(data = as_int(block_read((__local uint *)&wei_tmp[idx])); )==""\n"
R"==(#define BLOCK_READ_WHT_8x32(data, idx) \ )==""\n"
R"==(data = as_int8(block_read8((__local uint *)&wei_tmp[idx])); )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK_READ_WHT_1x32(data, idx) \ )==""\n"
R"==(data = as_int(intel_sub_group_block_read((__global uint *)&wei[idx])); )==""\n"
R"==(#define BLOCK_READ_WHT_8x32(data, idx) \ )==""\n"
R"==(data = as_int8(intel_sub_group_block_read8((__global uint *)&wei[idx])); )==""\n"
R"==(#endif )==""\n"
// Bias read: one 4-wide block read when OC is a multiple of OC_BLOCK;
// otherwise whole sub-group rows are block-read and the final partial row is
// loaded per work-item, leaving the remaining lanes at zero.
R"==(#if OC % OC_BLOCK == 0 )==""\n"
R"==(#define BLOCK_READ_BIA(data, idx) \ )==""\n"
R"==(data = as_float4(intel_sub_group_block_read4((__global uint *)&bias[idx])); )==""\n"
R"==(#else )==""\n"
R"==(#define BLOCK_READ_BIA(data, idx) \ )==""\n"
R"==(data = (float4)0; \ )==""\n"
R"==(int i; \ )==""\n"
R"==(for (i = idx; i < idx + OC_BLOCK && i < OC - (OC % SUB_GROUP_SIZE); \ )==""\n"
R"==(i += SUB_GROUP_SIZE) { \ )==""\n"
R"==(data[(i - idx) / SUB_GROUP_SIZE] = as_float( \ )==""\n"
R"==(intel_sub_group_block_read((__global uint *)&bias[i])); \ )==""\n"
R"==(} \ )==""\n"
R"==(if ((get_sub_group_local_id() < OC % SUB_GROUP_SIZE) && i < OC \ )==""\n"
R"==(&& (i - idx) / SUB_GROUP_SIZE < 4) { \ )==""\n"
R"==(data[(i - idx) / SUB_GROUP_SIZE] \ )==""\n"
R"==(= as_float(bias[i + get_sub_group_local_id()]); \ )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
// PAD_BLOCK_READ(data, src, sp, i, read_kind, l_off): decomposes the flat
// output index sp + i into (od, oh, ow), maps it to input coordinates using
// the strides and paddings, and stores 0 when the input coordinate falls
// outside the source. read_kind selects the load form: 0 = sub-group block
// read, 1 = per-work-item uint, 2 = single byte at l_off.
R"==(#define PAD_BLOCK_READ(data, src, sp, i, read_kind, l_off) \ )==""\n"
R"==(do { \ )==""\n"
R"==(const int od = (sp + i) / (OW * OH); \ )==""\n"
R"==(const int ohw = (sp + i) % (OW * OH); \ )==""\n"
R"==(const int oh = ohw / OW; \ )==""\n"
R"==(const int ow = (ohw % OW); \ )==""\n"
R"==(const int id = SD * od - PD; \ )==""\n"
R"==(const int ih = SH * oh - PH; \ )==""\n"
R"==(const int iw = SW * ow - PW; \ )==""\n"
R"==(bool pad = ((PW > 0 || PH > 0 || PD > 0) \ )==""\n"
R"==(&& (iw < 0 || ih < 0 || id < 0 || iw >= IW || ih >= IH \ )==""\n"
R"==(|| id >= ID)); \ )==""\n"
R"==(int off = id * IH * IW + ih * IW + iw; \ )==""\n"
R"==(if (read_kind == 0) { \ )==""\n"
R"==(data = pad ? 0 \ )==""\n"
R"==(: AS_SRC_MMAD_DATA_T(intel_sub_group_block_read( \ )==""\n"
R"==((global uint *)&src[off * SRC_SP_STRIDE])); \ )==""\n"
R"==(} else if (read_kind == 1) { \ )==""\n"
R"==(data = pad ? 0 \ )==""\n"
R"==(: *((__global uint *)&src[off * SRC_SP_STRIDE \ )==""\n"
R"==(+ sg_local_id * 4]); \ )==""\n"
R"==(} else if (read_kind == 2) { \ )==""\n"
R"==(*((uchar *)&data + l_off) = pad \ )==""\n"
R"==(? 0 \ )==""\n"
R"==(: src[off * SRC_SP_STRIDE + sg_local_id * 4 + l_off]; \ )==""\n"
R"==(} \ )==""\n"
R"==(} while (0); )==""\n"
R"==(#if SCALES_PER_OC )==""\n"
R"==(#define SCALE scales )==""\n"
R"==(#elif SCALES_COMMON )==""\n"
R"==(#define SCALE runtime_scales[0] )==""\n"
R"==(#else )==""\n"
R"==(#define SCALE 1 )==""\n"
R"==(#endif )==""\n"
// Forward declarations of the destination helpers defined after the kernel.
R"==(void block_read_dst( )==""\n"
R"==(int n, DST_DATA_T *d, const __global DST_DATA_T *dst, int oc_tail); )==""\n"
R"==(void block_write_dst( )==""\n"
R"==(int n, const DST_DATA_T *d, __global DST_DATA_T *dst, int oc_tail); )==""\n"
// Kernel entry point. Work-group ids map to (output-channel block, spatial
// block, minibatch) in dimensions 0, 1 and 2 respectively.
R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) __kernel void )==""\n"
R"==(xe_lp_nhwc_1x1_conv_fwd_x8s8x(const __global SRC_DATA_T *src, )==""\n"
R"==(const __global char *wei, const __global float *bias, )==""\n"
R"==(__global DST_DATA_T *dst POST_OP_ARGS, )==""\n"
R"==(const __global float *runtime_scales, )==""\n"
R"==(const __global int *src_compensation, )==""\n"
R"==(const __global int *dst_compensation) { )==""\n"
R"==(const uint oc_group_id = get_group_id(0); )==""\n"
R"==(const uint sp_group_id = get_group_id(1); )==""\n"
R"==(const uint mb_group_id = get_group_id(2); )==""\n"
R"==(const uint ic_group_id = oc_group_id / OC_NCHUNK * IC_NCHUNK; )==""\n"
R"==(const uint sg_local_id = get_sub_group_local_id(); )==""\n"
R"==(const uint sg_id = get_sub_group_id(); )==""\n"
// Spatial decomposition: with USE_SP_BLOCK the global id indexes blocks of
// SP_BLOCK points over the flattened OD * OH * OW space; otherwise blocks are
// formed along OW only (OWB blocks per row).
R"==(#define OWB ((OW + SP_BLOCK - 1) / SP_BLOCK) )==""\n"
R"==(#if USE_SP_BLOCK )==""\n"
R"==(const uint sp = get_global_id(1) * SP_BLOCK; )==""\n"
R"==(const int sp_local_id = get_local_id(1); )==""\n"
R"==(const uint od = sp / (OH * OW); )==""\n"
R"==(const uint ohw = sp % (OH * OW); )==""\n"
R"==(const uint oh = ohw / OW; )==""\n"
R"==(const uint ow = (ohw % OW); )==""\n"
R"==(#else )==""\n"
R"==(const uint sp = get_global_id(1); )==""\n"
R"==(const int sp_local_id = get_local_id(1); )==""\n"
R"==(const uint od = sp / (OWB * OH); )==""\n"
R"==(const uint ohw = sp % (OWB * OH); )==""\n"
R"==(const uint oh = ohw / OWB; )==""\n"
R"==(const uint ow = (ohw % OWB) * SP_BLOCK; )==""\n"
R"==(#endif )==""\n"
R"==(const uint id = SD * od; )==""\n"
R"==(const uint ih = SH * oh; )==""\n"
R"==(const uint iw = SW * ow; )==""\n"
// Advance the base pointers to this work-group. With padding the spatial
// part of the source offset is left out because PAD_BLOCK_READ recomputes
// it for every point.
R"==(src += mb_group_id * SRC_SP * SRC_SP_STRIDE; )==""\n"
R"==(#if !HAS_PADDING )==""\n"
R"==(src += (id * IH * IW + ih * IW + iw) * SRC_SP_STRIDE; )==""\n"
R"==(#endif )==""\n"
R"==(src += ic_group_id * SRC_ICB_STRIDE; )==""\n"
R"==(dst += mb_group_id * DST_SP * DST_SP_STRIDE; )==""\n"
R"==(dst += (od * OH * OW + oh * OW + ow) * DST_SP_STRIDE; )==""\n"
R"==(dst += oc_group_id * DST_OCB_STRIDE; )==""\n"
R"==(wei += oc_group_id * WEI_BLOCK_STRIDE * IC_NCHUNK; )==""\n"
R"==(ACC_DATA_BLOCK C00 = 0, C01 = 0, C02 = 0, C03 = 0; )==""\n"
R"==(ACC_DATA_BLOCK1 C10 = 0, C11 = 0, C12 = 0, C13 = 0; )==""\n"
R"==(const int oc_tail )==""\n"
R"==(= (oc_group_id + 1) * OC_BLOCK > G * OC ? OC % OC_BLOCK : OC_BLOCK; )==""\n"
// READ_SLM copies one weight block from global memory into the __local
// buffer wei_slm, bracketed by work-group barriers, and points wei_tmp at it.
R"==(#if INT8_WEI_SLM )==""\n"
R"==(#define READ_SLM() \ )==""\n"
R"==(barrier(CLK_LOCAL_MEM_FENCE); \ )==""\n"
R"==(const __global char *wei_copy_from \ )==""\n"
R"==(= wei + sp_local_id * WEI_BLOCK_STRIDE / LWS_1; \ )==""\n"
R"==(__local char *wei_copy_to \ )==""\n"
R"==(= wei_slm + sp_local_id * WEI_BLOCK_STRIDE / LWS_1; \ )==""\n"
R"==(block_write4((__local uint *)wei_copy_to, \ )==""\n"
R"==(intel_sub_group_block_read4((__global uint *)wei_copy_from)); \ )==""\n"
R"==(__local char *wei_tmp = wei_slm; \ )==""\n"
R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
R"==(__local char wei_slm[WEI_BLOCK_STRIDE]; )==""\n"
R"==(#endif )==""\n"
// Main accumulation loop over input-channel chunks: read the source block(s),
// read up to four weight blocks and accumulate into C0x (and C1x when
// SP_BLOCK > 8).
R"==(for (uint ic_block_id = 0; ic_block_id < IC_NCHUNK; ++ic_block_id) { )==""\n"
R"==(#if INT8_WEI_SLM )==""\n"
R"==(READ_SLM() )==""\n"
R"==(#if SP_TAIL )==""\n"
R"==(if (ow < OW) )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
R"==({ )==""\n"
R"==(SRC_DATA_BLOCK_T S0; )==""\n"
R"==(SRC_DATA_BLOCK_T1 S1; )==""\n"
R"==(#if OUT_SP_TAIL )==""\n"
R"==(#if USE_SP_BLOCK )==""\n"
R"==(if (od * OH * OW + oh * OW + ow + SP_BLOCK > DST_SP) { )==""\n"
R"==(#else )==""\n"
R"==(if (ow + SP_BLOCK > OW) { )==""\n"
R"==(#endif )==""\n"
R"==(#if OUT_SP_TAIL < 8 )==""\n"
R"==(S0 = 0; )==""\n"
R"==(BLOCK_READ_SRC_Xx32(0, OUT_SP_TAIL, 0, S0, 0 * IC, 0); )==""\n"
R"==(#else )==""\n"
R"==(BLOCK_READ_SRC_Xx32(0, BLOCK0, 0, S0, 0, 0); )==""\n"
R"==(S1 = 0; )==""\n"
// FIX: the loop index already starts at 8 here, so the spatial offset passed
// to the padded read path must be 0. It used to be BLOCK0, which made
// PAD_BLOCK_READ address point _i + 8 (16 and up) while the unpadded path of
// the same macro reads point _i (8 and up).
R"==(BLOCK_READ_SRC_Xx32(8, OUT_SP_TAIL, -8, S1, 0, 0); )==""\n"
R"==(#endif )==""\n"
R"==(} else )==""\n"
R"==(#endif )==""\n"
R"==({ )==""\n"
R"==(BLOCK_READ_SRC_Xx32(0, BLOCK0, 0, S0, 0 * IC, 0); )==""\n"
R"==(#if SP_BLOCK > 8 )==""\n"
R"==(BLOCK_READ_SRC_Xx32(0, BLOCK1, 0, S1, 8 * SW * G * IC, BLOCK0); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(int8 W0 = 0, W1 = 0, W2 = 0, W3 = 0; )==""\n"
// Last chunk with a channel remainder: read only IC_NBLOCKS_TAIL weight rows
// and use the emulated tail multiply-accumulate.
R"==(#if IC % IC_BLOCK != 0 )==""\n"
R"==(if (ic_block_id == IC_NCHUNK - 1) { )==""\n"
R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
R"==(BLOCK_READ_WHT_1x32(W0[i], (i + 0) * IC_BLOCK); )==""\n"
R"==(if (OC > 8) )==""\n"
R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
R"==(BLOCK_READ_WHT_1x32(W1[i], (i + 8) * IC_BLOCK); )==""\n"
R"==(if (OC > 16) )==""\n"
R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
R"==(BLOCK_READ_WHT_1x32(W2[i], (i + 16) * IC_BLOCK); )==""\n"
R"==(if (OC > 24) )==""\n"
R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
R"==(BLOCK_READ_WHT_1x32(W3[i], (i + 24) * IC_BLOCK); )==""\n"
R"==(C00 = MMAD_TAIL0(S0, W0, C00); )==""\n"
R"==(if (OC > 8) C01 = MMAD_TAIL0(S0, W1, C01); )==""\n"
R"==(if (OC > 16) C02 = MMAD_TAIL0(S0, W2, C02); )==""\n"
R"==(if (OC > 24) C03 = MMAD_TAIL0(S0, W3, C03); )==""\n"
R"==(#if SP_BLOCK > 8 )==""\n"
R"==(C10 = MMAD_TAIL1(S1, W0, C10); )==""\n"
R"==(if (OC > 8) C11 = MMAD_TAIL1(S1, W1, C11); )==""\n"
R"==(if (OC > 16) C12 = MMAD_TAIL1(S1, W2, C12); )==""\n"
R"==(if (OC > 24) C13 = MMAD_TAIL1(S1, W3, C13); )==""\n"
R"==(#endif )==""\n"
R"==(} else )==""\n"
R"==(#endif )==""\n"
R"==({ )==""\n"
R"==(BLOCK_READ_WHT_8x32(W0, 0); )==""\n"
R"==(if (OC > 8) BLOCK_READ_WHT_8x32(W1, 8 * IC_BLOCK); )==""\n"
R"==(if (OC > 16) BLOCK_READ_WHT_8x32(W2, 16 * IC_BLOCK); )==""\n"
R"==(if (OC > 24) BLOCK_READ_WHT_8x32(W3, 24 * IC_BLOCK); )==""\n"
R"==(C00 = MMAD_FULL0(S0, W0, C00); )==""\n"
R"==(if (OC > 8) C01 = MMAD_FULL0(S0, W1, C01); )==""\n"
R"==(if (OC > 16) C02 = MMAD_FULL0(S0, W2, C02); )==""\n"
R"==(if (OC > 24) C03 = MMAD_FULL0(S0, W3, C03); )==""\n"
R"==(#if SP_BLOCK > 8 )==""\n"
R"==(C10 = MMAD_FULL1(S1, W0, C10); )==""\n"
R"==(if (OC > 8) C11 = MMAD_FULL1(S1, W1, C11); )==""\n"
R"==(if (OC > 16) C12 = MMAD_FULL1(S1, W2, C12); )==""\n"
R"==(if (OC > 24) C13 = MMAD_FULL1(S1, W3, C13); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(src += SRC_ICB_STRIDE; )==""\n"
R"==(wei += WEI_BLOCK_STRIDE; )==""\n"
R"==(} )==""\n"
// Source zero-point compensation: subtract the per-output-channel values read
// from src_compensation from every accumulator.
R"==(#if WITH_SRC_ZPOINTS )==""\n"
R"==(int4 src_comp = as_int4(intel_sub_group_block_read4( )==""\n"
R"==((__global uint *)(&src_compensation[oc_group_id * OC_BLOCK]))); )==""\n"
R"==(C00 -= src_comp.s0; )==""\n"
R"==(C01 -= src_comp.s1; )==""\n"
R"==(C02 -= src_comp.s2; )==""\n"
R"==(C03 -= src_comp.s3; )==""\n"
R"==(#if SP_BLOCK > 8 )==""\n"
R"==(C10 -= src_comp.s0; )==""\n"
R"==(C11 -= src_comp.s1; )==""\n"
R"==(C12 -= src_comp.s2; )==""\n"
R"==(C13 -= src_comp.s3; )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
R"==(float4 tmp; )==""\n"
R"==(DST_DATA4_T dst_pack[8]; )==""\n"
R"==(DST_DATA4_T D0[BLOCK0] = {0}; )==""\n"
R"==(DST_DATA4_T D1[BLOCK1] = {0}; )==""\n"
R"==(#if SCALES_PER_OC )==""\n"
R"==(float4 scales; )==""\n"
R"==(block_read_scales( )==""\n"
R"==(&scales, oc_group_id * OC_BLOCK, sg_local_id, runtime_scales); )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_BIAS )==""\n"
R"==(float4 bia; )==""\n"
R"==(BLOCK_READ_BIA(bia, oc_group_id * OC_BLOCK); )==""\n"
R"==(#define QUANTIZE_ADD_BIAS() tmp = SCALE * fma(tmp, (float4)1, bia); )==""\n"
R"==(#else )==""\n"
R"==(#define QUANTIZE_ADD_BIAS() tmp *= SCALE; )==""\n"
R"==(#endif )==""\n"
// With a sum post-op the current destination values are loaded into D0 / D1
// before the store stage; only the valid tail rows are read for a tail block.
R"==(#if WITH_SUM )==""\n"
R"==(#if USE_SP_BLOCK )==""\n"
R"==(if (OUT_SP_TAIL && od * OH * OW + oh * OW + ow + SP_BLOCK > DST_SP) { )==""\n"
R"==(#else )==""\n"
R"==(if (OUT_SP_TAIL && ow + SP_BLOCK > OW) { )==""\n"
R"==(#endif )==""\n"
R"==(#if OUT_SP_TAIL < 8 )==""\n"
R"==(block_read_dst(OUT_SP_TAIL, D0, dst, oc_tail); )==""\n"
R"==(#else )==""\n"
R"==(block_read_dst(BLOCK0, D0, dst, oc_tail); )==""\n"
R"==(block_read_dst(OUT_SP_TAIL - 8, D1, dst + 8 * G * OC, oc_tail); )==""\n"
R"==(#endif )==""\n"
R"==(} else { )==""\n"
R"==(block_read_dst(BLOCK0, D0, dst, oc_tail); )==""\n"
R"==(if (SP_BLOCK > 8) { )==""\n"
R"==(block_read_dst(BLOCK1, D1, dst + 8 * G * OC, oc_tail); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_DST_ZPOINTS )==""\n"
R"==(int4 dst_comp = read_dst_zero_points_32c( )==""\n"
R"==(dst_compensation, oc_group_id * OC_BLOCK); )==""\n"
R"==(#define ADD_DST_COMPENSATION() tmp += convert_float4(dst_comp); )==""\n"
R"==(#else )==""\n"
R"==(#define ADD_DST_COMPENSATION() )==""\n"
R"==(#endif )==""\n"
R"==(#define PACK(C0, C1, C2, C3, idx) \ )==""\n"
R"==(do { \ )==""\n"
R"==(tmp[0] = C0[idx]; \ )==""\n"
R"==(tmp[1] = C1[idx]; \ )==""\n"
R"==(tmp[2] = C2[idx]; \ )==""\n"
R"==(tmp[3] = C3[idx]; \ )==""\n"
R"==(} while (0) )==""\n"
R"==(#define CONVERT_PACK(idx) \ )==""\n"
R"==(do { \ )==""\n"
R"==(dst_pack[idx] = CONVERT_DST_DATA4_T(tmp); \ )==""\n"
R"==(} while (0) )==""\n"
// STORE_DST(n, C0..C3, D, dst_ptr, mb_stride): for each of n spatial points,
// gather the four accumulators into tmp, apply scale and bias, run the
// post-ops per lane, add the destination compensation, convert to the
// destination type and finally write all n rows with block_write_dst.
R"==(#define STORE_DST(n, C0, C1, C2, C3, D, dst_ptr, mb_stride) \ )==""\n"
R"==(do { \ )==""\n"
R"==(for (int n_i = 0; n_i < n; n_i++) { \ )==""\n"
R"==(PACK(C0, C1, C2, C3, n_i); \ )==""\n"
R"==(QUANTIZE_ADD_BIAS(); \ )==""\n"
R"==(for (int didx = 0; didx < 4; ++didx) { \ )==""\n"
R"==(float tmp_i = tmp[didx]; \ )==""\n"
R"==(float dni_i = convert_float(AS_SUM_DATA_T(D[n_i][didx])); \ )==""\n"
R"==(int po_mb; \ )==""\n"
R"==(if (MB_BLOCK == 32) \ )==""\n"
R"==(po_mb = (mb_group_id * MB_BLOCK / 2 + mb_stride * 8 + n_i) \ )==""\n"
R"==(% MB; \ )==""\n"
R"==(else \ )==""\n"
R"==(po_mb = mb_group_id % MB; \ )==""\n"
R"==(const int po_oc = (oc_group_id * OC_BLOCK + sg_local_id \ )==""\n"
R"==(+ didx * SUB_GROUP_SIZE) \ )==""\n"
R"==(% (OC * G); \ )==""\n"
R"==(APPLY_POST_OPS_SERIAL_BINARY_2D( \ )==""\n"
R"==(tmp_i, float, dni_i, float, po_mb, 1, po_oc, 1); \ )==""\n"
R"==(tmp[didx] = tmp_i; \ )==""\n"
R"==(} \ )==""\n"
R"==(ADD_DST_COMPENSATION(); \ )==""\n"
R"==(CONVERT_PACK(n_i); \ )==""\n"
R"==(} \ )==""\n"
R"==(block_write_dst(n, dst_pack, dst_ptr, oc_tail); \ )==""\n"
R"==(} while (0) )==""\n"
// Store stage. In the tail branch the second store receives OUT_SP_TAIL - 8
// rows; when that count is not positive the loops in STORE_DST and
// block_write_dst do not execute.
R"==(#if INT8_WEI_SLM && SP_TAIL )==""\n"
R"==(if (ow < OW) )==""\n"
R"==(#endif )==""\n"
R"==({ )==""\n"
R"==(#if USE_SP_BLOCK )==""\n"
R"==(if (OUT_SP_TAIL && od * OH * OW + oh * OW + ow + SP_BLOCK > DST_SP) { )==""\n"
R"==(#else )==""\n"
R"==(if (OUT_SP_TAIL && ow + SP_BLOCK > OW) { )==""\n"
R"==(#endif )==""\n"
R"==(STORE_DST(min(BLOCK0, OUT_SP_TAIL), C00, C01, C02, C03, D0, dst, 0); )==""\n"
R"==(STORE_DST(OUT_SP_TAIL - 8, C10, C11, C12, C13, D1, dst + 8 * G * OC, )==""\n"
R"==(1); )==""\n"
R"==(} else { )==""\n"
R"==(STORE_DST(BLOCK0, C00, C01, C02, C03, D0, dst, 0); )==""\n"
R"==(if (SP_BLOCK > 8) { )==""\n"
R"==(STORE_DST(BLOCK1, C10, C11, C12, C13, D1, dst + 8 * G * OC, 1); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// block_read_dst: loads n destination rows (4 elements per work-item per row,
// rows G * OC apart) into d. Uses a block read when OC is a multiple of
// OC_BLOCK_WRITE_BOUND and the channel block is full; otherwise reads element
// by element and skips channels at or beyond oc_tail.
R"==(void block_read_dst( )==""\n"
R"==(int n, DST_DATA_T *d, const __global DST_DATA_T *dst, int oc_tail) { )==""\n"
R"==(const int local_id = get_sub_group_local_id(); )==""\n"
R"==(int nelems = n * 4; )==""\n"
R"==(__attribute__((opencl_unroll_hint)) )==""\n"
R"==(for (int i = 0; i < nelems; i += 4) { )==""\n"
R"==(if (OC % OC_BLOCK_WRITE_BOUND == 0 && oc_tail == OC_BLOCK) { )==""\n"
R"==(*((DST_DATA4_T *)&d[i]) = BLOCK_READ_DST4(dst + (i / 4) * G * OC); )==""\n"
R"==(} else { )==""\n"
R"==(unroll_for(int idx = 0; idx < 4; idx++) { )==""\n"
R"==(if (local_id + 8 * idx < oc_tail) { )==""\n"
R"==(d[i + idx] = dst[(i / 4) * G * OC + idx * 8 + local_id]; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// block_write_dst: mirror of block_read_dst; stores n rows from d to dst with
// the same full-block / element-wise split and the same oc_tail guard.
R"==(void block_write_dst( )==""\n"
R"==(int n, const DST_DATA_T *d, __global DST_DATA_T *dst, int oc_tail) { )==""\n"
R"==(const int local_id = get_sub_group_local_id(); )==""\n"
R"==(int nelems = n * 4; )==""\n"
R"==(__attribute__((opencl_unroll_hint)) )==""\n"
R"==(for (int i = 0; i < nelems; i += 4) { )==""\n"
R"==(if (OC % OC_BLOCK_WRITE_BOUND == 0 && oc_tail == OC_BLOCK) { )==""\n"
R"==(BLOCK_WRITE_DST4(dst + (i / 4) * G * OC, *((DST_DATA4_T *)&d[i])); )==""\n"
R"==(} else { )==""\n"
R"==(unroll_for(int idx = 0; idx < 4; idx++) { )==""\n"
R"==(if (local_id + 8 * idx < oc_tail) { )==""\n"
R"==(dst[(i / 4) * G * OC + idx * 8 + local_id] = d[i + idx]; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==()==";
}
}
}
}