// Embedded OpenCL C source for the kernel conv_fwd_ow_block_x8s8x.
// The text is stored as one C string built from adjacent raw string literals,
// one literal per kernel source line, each followed by a newline literal.
// All // comments in this file sit between those literals, so they are not
// part of the kernel text that the string holds.
// NOTE(review): the license URL line below ends right after http: -- presumably
// a comment-stripping step in the generator removed the rest; confirm.
1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
5const char *xe_lp_conv_fwd_data_ow_block_x8s8x_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/ocl_math_utils.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
22R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
23R"==(#include "gpu/ocl/ocl_zero_points.h" )==""\n"
// IC_NBLOCKS_TAIL: when IC is not a multiple of IC_BLOCK, the leftover channel
// count divided by 4 and rounded up; otherwise fixed to 8.
// NOTE(review): the mask form IC & ~(IC_BLOCK - 1) only equals rounding down to
// a multiple of IC_BLOCK if IC_BLOCK is a power of two -- confirm.
24R"==(#if IC % IC_BLOCK != 0 )==""\n"
25R"==(#define IC_NBLOCKS_TAIL ((IC - (IC & ~(IC_BLOCK - 1)) + 3) / 4) )==""\n"
26R"==(#else )==""\n"
27R"==(#define IC_NBLOCKS_TAIL 8 )==""\n"
28R"==(#endif )==""\n"
// OW_BLOCK selects the width of the accumulator, source block type, global
// block read and local block write (4-wide or 8-wide variants). MMAD_FULL is
// mmad8x4 or mmad8x8; MMAD_TAIL is the mmad_tail function declared through
// DECLARE_MMAD_EMU, which is defined outside this file.
29R"==(#if OW_BLOCK == 4 )==""\n"
30R"==(#define BLOCK 4 )==""\n"
31R"==(#define ACC_DATA_BLOCK int4 )==""\n"
32R"==(#define SRC_DATA_BLOCK_T SRC_MMAD_DATA4_T )==""\n"
33R"==(#define READ_BLOCK intel_sub_group_block_read4 )==""\n"
34R"==(#define WRITE_LOCAL block_write4 )==""\n"
35R"==(DECLARE_MMAD_EMU(mmad_tail, idot4, IC_NBLOCKS_TAIL, 4, SRC_DATA_BLOCK_T, int8, )==""\n"
36R"==(ACC_DATA_BLOCK) )==""\n"
37R"==(#define MMAD_FULL mmad8x4 )==""\n"
38R"==(#define MMAD_TAIL mmad_tail )==""\n"
39R"==(#else )==""\n"
40R"==(#define BLOCK 8 )==""\n"
41R"==(#define ACC_DATA_BLOCK int8 )==""\n"
42R"==(#define SRC_DATA_BLOCK_T SRC_MMAD_DATA8_T )==""\n"
43R"==(#define READ_BLOCK intel_sub_group_block_read8 )==""\n"
44R"==(#define WRITE_LOCAL block_write8 )==""\n"
45R"==(DECLARE_MMAD_EMU(mmad_tail, idot4, IC_NBLOCKS_TAIL, 8, SRC_DATA_BLOCK_T, int8, )==""\n"
46R"==(ACC_DATA_BLOCK) )==""\n"
47R"==(#define MMAD_FULL mmad8x8 )==""\n"
48R"==(#define MMAD_TAIL mmad_tail )==""\n"
49R"==(#endif )==""\n"
// Subgroup block-read helpers. They index the kernel arguments src, wei and
// bias by name, so they can only be expanded inside the kernel body.
50R"==(#define BLOCK_READ_SRC(data, idx) \ )==""\n"
51R"==(data = intel_sub_group_block_read8((__global uint *)&src[idx]); )==""\n"
52R"==(#define BLOCK_READ_WHT_1x32(data, idx) \ )==""\n"
53R"==(data = as_int(intel_sub_group_block_read((__global uint *)&wei[idx])); )==""\n"
54R"==(#define BLOCK_READ_WHT_8x32(data, idx) \ )==""\n"
55R"==(data = as_int8(intel_sub_group_block_read8((__global uint *)&wei[idx])); )==""\n"
// BLOCK_READ_BIA: one 4-wide block read when OC is a multiple of OC_BLOCK.
// Otherwise data is zeroed, whole SUB_GROUP_SIZE rows are block-read while they
// fit below OC - (OC % SUB_GROUP_SIZE), and the final partial row is read per
// lane only by lanes whose id is below OC % SUB_GROUP_SIZE.
// The second variant declares a local int i, so it introduces that name into
// the scope where it is expanded.
56R"==(#if OC % OC_BLOCK == 0 )==""\n"
57R"==(#define BLOCK_READ_BIA(data, idx) \ )==""\n"
58R"==(data = as_float4(intel_sub_group_block_read4((__global uint *)&bias[idx])); )==""\n"
59R"==(#else )==""\n"
60R"==(#define BLOCK_READ_BIA(data, idx) \ )==""\n"
61R"==(data = (float4)0; \ )==""\n"
62R"==(int i; \ )==""\n"
63R"==(for (i = idx; i < idx + OC_BLOCK && i < OC - (OC % SUB_GROUP_SIZE); \ )==""\n"
64R"==(i += SUB_GROUP_SIZE) { \ )==""\n"
65R"==(data[(i - idx) / SUB_GROUP_SIZE] = as_float( \ )==""\n"
66R"==(intel_sub_group_block_read((__global uint *)&bias[i])); \ )==""\n"
67R"==(} \ )==""\n"
68R"==(if ((get_sub_group_local_id() < OC % SUB_GROUP_SIZE) && i < OC \ )==""\n"
69R"==(&& (i - idx) / SUB_GROUP_SIZE < 4) { \ )==""\n"
70R"==(data[(i - idx) / SUB_GROUP_SIZE] \ )==""\n"
71R"==(= as_float(bias[i + get_sub_group_local_id()]); \ )==""\n"
72R"==(} )==""\n"
73R"==(#endif )==""\n"
// BLOCK_READ_SCALES: fills the four components of data from runtime_scales.
// If four full SUB_GROUP_SIZE rows starting at idx fit within OC, a single
// 4-wide block read is used. Otherwise each row is either block-read (row fits),
// read per lane with a bound check against OC, or left at 0 from the
// zero-initialised local_dat array. Uses subg_local_id from the expansion scope.
74R"==(#define BLOCK_READ_SCALES(data, idx) \ )==""\n"
75R"==(if (OC >= idx + (SUB_GROUP_SIZE * 4)) { \ )==""\n"
76R"==(data = as_float4(intel_sub_group_block_read4( \ )==""\n"
77R"==((__global uint *)&runtime_scales[idx])); \ )==""\n"
78R"==(} else { \ )==""\n"
79R"==(float local_dat[4] = {}; \ )==""\n"
80R"==(for (int i = 0; i < 4; ++i) \ )==""\n"
81R"==(if (idx + ((i + 1) * SUB_GROUP_SIZE) <= OC) { \ )==""\n"
82R"==(local_dat[i] = as_float(intel_sub_group_block_read( \ )==""\n"
83R"==((__global uint *)&runtime_scales[idx \ )==""\n"
84R"==(+ (SUB_GROUP_SIZE * i)])); \ )==""\n"
85R"==(} else if (idx + (i * SUB_GROUP_SIZE) + subg_local_id < OC) { \ )==""\n"
86R"==(local_dat[i] = runtime_scales[idx + (SUB_GROUP_SIZE * i) \ )==""\n"
87R"==(+ subg_local_id]; \ )==""\n"
88R"==(} \ )==""\n"
89R"==(data.s0 = local_dat[0]; \ )==""\n"
90R"==(data.s1 = local_dat[1]; \ )==""\n"
91R"==(data.s2 = local_dat[2]; \ )==""\n"
92R"==(data.s3 = local_dat[3]; \ )==""\n"
93R"==(} )==""\n"
// SCALE: the per-OC float4 named scales (declared later inside the kernel),
// the single value runtime_scales[0], or the constant 1.
94R"==(#if SCALES_PER_OC )==""\n"
95R"==(#define SCALE scales )==""\n"
96R"==(#elif SCALES_COMMON )==""\n"
97R"==(#define SCALE runtime_scales[0] )==""\n"
98R"==(#else )==""\n"
99R"==(#define SCALE 1 )==""\n"
100R"==(#endif )==""\n"
// Kernel entry point. Subgroup size and work-group size are fixed by the
// build-time macros SUB_GROUP_SIZE and LWS_0 / LWS_1 / LWS_2.
101R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
102R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) __kernel void )==""\n"
103R"==(conv_fwd_ow_block_x8s8x(const __global SRC_DATA_T *src, )==""\n"
104R"==(const __global char *wei, const __global float *bias, )==""\n"
105R"==(__global DATA_T *dst POST_OP_ARGS, const __global float *runtime_scales, )==""\n"
106R"==(const __global int *src_compensation, const __global int *src_zpoints, )==""\n"
107R"==(const __global int *dst_compensation) { )==""\n"
// Work decomposition: group id 0 -> output-channel groups, group id 1 ->
// spatial groups, group id 2 -> minibatch groups. Inside a work-group the
// subgroup id is split into an oc part (modulo OC_GROUP) and an sp part.
108R"==(const int group_oc = get_group_id(0) * OC_GROUP; )==""\n"
109R"==(const int group_mb = get_group_id(2) * MB_GROUP; )==""\n"
110R"==(const int group_sp = get_group_id(1) * SP_GROUP; )==""\n"
111R"==(const int sub_group_id = get_sub_group_id(); )==""\n"
112R"==(const int subg_local_id = get_sub_group_local_id(); )==""\n"
113R"==(const int oc = (sub_group_id % OC_GROUP); )==""\n"
114R"==(const int sp = (sub_group_id / OC_GROUP); )==""\n"
115R"==(const int g = (group_oc + oc) / OC_NCHUNK; )==""\n"
116R"==(const int group_ic = IC_NCHUNK * g; )==""\n"
// Split the spatial group index into output depth / height / width, then map
// to input coordinates using the strides (SD, SH, SW) and paddings (PD, PH, PW).
117R"==(const int god = group_sp / (OW_PADDED * OH); )==""\n"
118R"==(const int gohw = group_sp % (OW_PADDED * OH); )==""\n"
119R"==(const int goh = gohw / OW_PADDED; )==""\n"
120R"==(const int gow = OW_BLOCK * (gohw % OW_PADDED); )==""\n"
121R"==(const int gid = god * SD; )==""\n"
122R"==(const int gih = goh * SH; )==""\n"
123R"==(const int giw = gow * SW; )==""\n"
124R"==(const int local_ow = OW_BLOCK * sp; )==""\n"
125R"==(const int local_iw = local_ow * SW; )==""\n"
126R"==(const int od = god; )==""\n"
127R"==(const int ow = gow + local_ow; )==""\n"
128R"==(const int oh = goh; )==""\n"
129R"==(const int id = gid - PD; )==""\n"
130R"==(const int iw = giw + local_iw - PW; )==""\n"
131R"==(const int ih = gih - PH; )==""\n"
// Local-memory staging buffer for one input row. S_part and S_work both point
// at this subgroup slot; S_part is additionally shifted by PW positions.
132R"==(__local uint S_slice[SRC_SLM_SIZE]; )==""\n"
133R"==(__local uint *S_part = S_slice + IC_BLOCK / 4 * (sp * SW * OW_BLOCK + PW); )==""\n"
134R"==(__local uint *S_work = S_slice + IC_BLOCK / 4 * (sp * SW * OW_BLOCK); )==""\n"
// Flags describing where this subgroup window sits relative to the input row
// bounds [0, IW); they select which copy / zero-fill paths run below.
135R"==(const bool left_tail = iw < 0; )==""\n"
136R"==(const bool left_nozero_tail = sub_group_id == 0 && iw > -PW; )==""\n"
137R"==(const bool right_tail = (iw + PW + SLM_TAIL >= IW) && (iw + PW < IW); )==""\n"
138R"==(const bool right_nozero_tail )==""\n"
139R"==(= sp == (LWS_1 - 1) && (iw + PW + SLM_TAIL < IW); )==""\n"
140R"==(const bool empty = (iw + PW >= IW); )==""\n"
// Advance dst, src and wei to the first element handled by this subgroup.
141R"==(dst += OC_BLOCK * OD * OH * OW * MB_BLOCK * (group_oc + oc); )==""\n"
142R"==(dst += OC_BLOCK * OD * OH * OW * OC_NCHUNK * G * MB_BLOCK * group_mb; )==""\n"
143R"==(dst += OC_BLOCK * MB_BLOCK * (OW * OH * od + OW * oh + ow); )==""\n"
144R"==(src += IC_BLOCK * ID * IH * IW * MB_BLOCK * group_ic; )==""\n"
145R"==(src += IC_BLOCK * ID * IH * IW * IC_NCHUNK * G * MB_BLOCK * group_mb; )==""\n"
146R"==(src += IC_BLOCK * MB_BLOCK * (IW * IH * id + IW * ih + iw + PW); )==""\n"
147R"==(wei += IC_BLOCK * KD * KH * KW * OC_BLOCK * (group_oc + oc) * IC_NCHUNK; )==""\n"
// One-time zero fill of the local-memory regions selected by the left_tail,
// right_tail and empty flags. These writes happen before the main loop.
148R"==(/* Prepare S_slice tails */ )==""\n"
149R"==(#if PW > 0 )==""\n"
150R"==(if (left_tail) { )==""\n"
151R"==(for (int i = 0; i < PW; i++) { )==""\n"
152R"==(block_write(S_slice + i * 8, 0); )==""\n"
153R"==(} )==""\n"
154R"==(} )==""\n"
155R"==(#endif )==""\n"
156R"==(#if ZERO_TAIL > 0 )==""\n"
157R"==(if (right_tail) { )==""\n"
158R"==(for (int i = SLM_TAIL; i < SW * OW_BLOCK + (KW - 1) * (1 + DW) - PW; )==""\n"
159R"==(i++) { )==""\n"
160R"==(block_write(S_part + i * 8, 0); )==""\n"
161R"==(} )==""\n"
162R"==(} )==""\n"
163R"==(#if SLM_NCHUNK < OW_NCHUNK )==""\n"
164R"==(if (empty) { )==""\n"
165R"==(for (int i = -PW; i < SW * OW_BLOCK + (KW - 1) * (1 + DW) - PW; i++) { )==""\n"
166R"==(block_write(S_part + i * 8, 0); )==""\n"
167R"==(} )==""\n"
168R"==(} )==""\n"
169R"==(#endif )==""\n"
170R"==(#endif )==""\n"
// Four accumulators, one per group of 8 output channels (see the OC > 8 / 16 /
// 24 guards below). Main reduction loops: ic_chunk, then kd, then kh, then kw.
171R"==(ACC_DATA_BLOCK C00 = 0, C01 = 0, C02 = 0, C03 = 0; )==""\n"
172R"==(for (int ic_chunk = 0; ic_chunk < IC_NCHUNK; ic_chunk++) { )==""\n"
173R"==(SRC_DATA_BLOCK_T S0; )==""\n"
174R"==(for (int kd = 0; kd < KD; kd++) { )==""\n"
// Kernel depth tap falls outside [0, ID): skip it, but still step src and wei
// so the pointers stay in sync with the loop counters.
175R"==(if (kd * (1 + DD) + id < 0 || kd * (1 + DD) + id >= ID) { )==""\n"
176R"==(src += IC_BLOCK * MB_BLOCK * IH * IW * (1 + DD); )==""\n"
177R"==(wei += IC_BLOCK * OC_BLOCK * KH * KW; )==""\n"
178R"==(continue; )==""\n"
179R"==(} )==""\n"
180R"==(for (int kh = 0; kh < KH; kh++) { )==""\n"
// Same skip for a kernel height tap outside [0, IH).
181R"==(if (kh * (1 + DH) + ih < 0 || kh * (1 + DH) + ih >= IH) { )==""\n"
182R"==(src += IC_BLOCK * MB_BLOCK * IW * (1 + DH); )==""\n"
183R"==(wei += IC_BLOCK * OC_BLOCK * KW; )==""\n"
184R"==(continue; )==""\n"
185R"==(} )==""\n"
// Local-memory fence before S_slice is overwritten for this (kd, kh) row.
// NOTE(review): this barrier sits after the continue statements above, which
// depend on id / ih -- presumably those are uniform across the work-group
// since they derive from group ids only; confirm.
186R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
187R"==(#if SLM_NCHUNK < OW_NCHUNK )==""\n"
188R"==(if (iw + PW < IW) { )==""\n"
189R"==(#endif )==""\n"
190R"==(#if OW_NCHUNK > LWS_1 )==""\n"
191R"==(/* Copy tails in case of multigroups */ )==""\n"
192R"==(if (ow < OW) { )==""\n"
193R"==(#if PW > 0 )==""\n"
// Left halo: copy the positions left of the window that exist in src
// (negative i indexes relative to the already-advanced src pointer).
194R"==(if (left_nozero_tail) { )==""\n"
195R"==(for (int i = -PW - min(iw, 0); i < 0; i++) { )==""\n"
196R"==(block_write(S_part + i * 8, )==""\n"
197R"==(intel_sub_group_block_read( )==""\n"
198R"==((const __global uint *)(&src[i )==""\n"
199R"==(* IC_BLOCK]))); )==""\n"
200R"==(} )==""\n"
201R"==(} )==""\n"
202R"==(#endif )==""\n"
// Right halo: copy min(buffer_last, src_last) positions from src, then
// zero-fill the rest up to buffer_last.
203R"==(if (right_nozero_tail) { )==""\n"
204R"==(int buffer_last = (KW - 1) * (1 + DW) - PW; )==""\n"
205R"==(int src_last = IW - iw - SW * OW_BLOCK - PW; )==""\n"
206R"==(for (int i = SW * OW_BLOCK; i < SW * OW_BLOCK )==""\n"
207R"==(+ min(buffer_last, src_last); )==""\n"
208R"==(i++) { )==""\n"
209R"==(block_write(S_part + i * 8, )==""\n"
210R"==(intel_sub_group_block_read( )==""\n"
211R"==((const __global uint *)(&src[i )==""\n"
212R"==(* IC_BLOCK]))); )==""\n"
213R"==(} )==""\n"
214R"==(for (int i = SW * OW_BLOCK )==""\n"
215R"==(+ min(buffer_last, src_last); )==""\n"
216R"==(i < SW * OW_BLOCK + buffer_last; i++) { )==""\n"
217R"==(block_write(S_part + i * 8, 0); )==""\n"
218R"==(} )==""\n"
219R"==(} )==""\n"
220R"==(#endif )==""\n"
// Main copy of this subgroup window into local memory: a position-by-position
// copy of SLM_TAIL entries on the right_tail path, otherwise wide copies of
// OW_BLOCK positions at a time through READ_BLOCK / WRITE_LOCAL.
221R"==(#if SLM_TAIL != OW_BLOCK * SW )==""\n"
222R"==(/* Copy last block to SLM */ )==""\n"
223R"==(if (right_tail) { )==""\n"
224R"==(__attribute__((opencl_unroll_hint)) for (int i = 0; )==""\n"
225R"==(i )==""\n"
226R"==(< SLM_TAIL; )==""\n"
227R"==(i++) { )==""\n"
228R"==(block_write(S_part + i * 8, )==""\n"
229R"==(intel_sub_group_block_read( )==""\n"
230R"==((const __global uint *)(&src[i )==""\n"
231R"==(* IC_BLOCK]))); )==""\n"
232R"==(} )==""\n"
233R"==(} else { )==""\n"
234R"==(#endif )==""\n"
235R"==(/* Copy block to SLM */ )==""\n"
236R"==(__attribute__(( )==""\n"
237R"==(opencl_unroll_hint)) for (int i = 0; )==""\n"
238R"==(i < SW * OW_BLOCK; )==""\n"
239R"==(i += OW_BLOCK) { )==""\n"
240R"==(WRITE_LOCAL(S_part + i * 8, )==""\n"
241R"==(READ_BLOCK( )==""\n"
242R"==((const __global uint *)(&src[i )==""\n"
243R"==(* IC_BLOCK]))); )==""\n"
244R"==(} )==""\n"
// The closing braces below pair with the conditionally compiled openers above;
// each #if condition must match the one that emitted the opening brace.
245R"==(#if SLM_TAIL != OW_BLOCK * SW )==""\n"
246R"==(} )==""\n"
247R"==(#endif )==""\n"
248R"==(#if OW_NCHUNK > LWS_1 )==""\n"
249R"==(} )==""\n"
250R"==(#endif )==""\n"
251R"==(#if SLM_NCHUNK < OW_NCHUNK )==""\n"
252R"==(} )==""\n"
253R"==(#endif )==""\n"
// Second fence: all local-memory writes for this row are done before any
// subgroup reads S_work in the kw loop.
254R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
255R"==(for (int kw = 0; kw < KW; kw++) { )==""\n"
// Gather OW_BLOCK source entries for this kw tap from local memory.
256R"==(unroll_for(int i = 0; i < OW_BLOCK; i++) { )==""\n"
257R"==(S0[i] = block_read( )==""\n"
258R"==(S_work + (kw * (1 + DW) + SW * i) * 8); )==""\n"
259R"==(} )==""\n"
260R"==(int8 W0 = 0, W1 = 0, W2 = 0, W3 = 0; )==""\n"
// Tail path, compiled only when IC is not a multiple of IC_BLOCK and taken on
// the last ic_chunk: load only IC_NBLOCKS_TAIL weight rows (the rest of W0..W3
// stay 0 from the initialiser) and use MMAD_TAIL.
261R"==(#if IC % IC_BLOCK != 0 )==""\n"
262R"==(if (ic_chunk == IC_NCHUNK - 1) { )==""\n"
263R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
264R"==(BLOCK_READ_WHT_1x32(W0[i], (i + 0) * IC_BLOCK); )==""\n"
265R"==(if (OC > 8) )==""\n"
266R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
267R"==(BLOCK_READ_WHT_1x32( )==""\n"
268R"==(W1[i], (i + 8) * IC_BLOCK); )==""\n"
269R"==(if (OC > 16) )==""\n"
270R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
271R"==(BLOCK_READ_WHT_1x32( )==""\n"
272R"==(W2[i], (i + 16) * IC_BLOCK); )==""\n"
273R"==(if (OC > 24) )==""\n"
274R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
275R"==(BLOCK_READ_WHT_1x32( )==""\n"
276R"==(W3[i], (i + 24) * IC_BLOCK); )==""\n"
277R"==(C00 = MMAD_TAIL(S0, W0, C00); )==""\n"
278R"==(if (OC > 8) C01 = MMAD_TAIL(S0, W1, C01); )==""\n"
279R"==(if (OC > 16) C02 = MMAD_TAIL(S0, W2, C02); )==""\n"
280R"==(if (OC > 24) C03 = MMAD_TAIL(S0, W3, C03); )==""\n"
281R"==(} else )==""\n"
282R"==(#endif )==""\n"
// Full path: 8-row weight block reads and MMAD_FULL. The OC > 8 / 16 / 24
// guards skip weight groups that have no output channels.
283R"==({ )==""\n"
284R"==(BLOCK_READ_WHT_8x32(W0, 0); )==""\n"
285R"==(if (OC > 8) BLOCK_READ_WHT_8x32(W1, 8 * IC_BLOCK); )==""\n"
286R"==(if (OC > 16) BLOCK_READ_WHT_8x32(W2, 16 * IC_BLOCK); )==""\n"
287R"==(if (OC > 24) BLOCK_READ_WHT_8x32(W3, 24 * IC_BLOCK); )==""\n"
288R"==(C00 = MMAD_FULL(S0, W0, C00); )==""\n"
289R"==(if (OC > 8) C01 = MMAD_FULL(S0, W1, C01); )==""\n"
290R"==(if (OC > 16) C02 = MMAD_FULL(S0, W2, C02); )==""\n"
291R"==(if (OC > 24) C03 = MMAD_FULL(S0, W3, C03); )==""\n"
292R"==(} )==""\n"
293R"==(wei += IC_BLOCK * OC_BLOCK; )==""\n"
294R"==(} )==""\n"
// Pointer bookkeeping: step src to the next kh row, then to the next kd
// plane, then to the next ic_chunk, compensating for the steps already taken.
295R"==(src += IC_BLOCK * MB_BLOCK * IW * (1 + DH); )==""\n"
296R"==(} )==""\n"
297R"==(src += IC_BLOCK * MB_BLOCK * (IH * (1 + DD) - KH * (1 + DH)) * IW; )==""\n"
298R"==(} )==""\n"
299R"==(src += IC_BLOCK * MB_BLOCK * (ID - KD * (1 + DD)) * IH * IW; )==""\n"
300R"==(} )==""\n"
// Source zero-point handling. If the window may touch padding, wei is rewound
// to the start of this subgroup weights and a second pass adds a weight-derived
// term to the accumulators for every tap whose input coordinate is out of
// range. Afterwards the src_compensation values are subtracted from all lanes.
// Unlike the main loop, this pass reads w1..w3 without OC > 8 / 16 / 24 guards.
301R"==(#if WITH_SRC_ZPOINTS )==""\n"
302R"==(const int has_pad_d = id < 0 || id + KD * (1 + DD) >= ID; )==""\n"
303R"==(const int has_pad_h = ih < 0 || ih + KH * (1 + DH) >= IH; )==""\n"
304R"==(const int has_pad_w = iw < 0 || iw + KW * (1 + DW) + OW_BLOCK * SW >= IW; )==""\n"
305R"==(if (has_pad_d || has_pad_h || has_pad_w) { )==""\n"
306R"==(wei -= IC_NCHUNK * KD * KH * KW * IC_BLOCK * OC_BLOCK; )==""\n"
307R"==(for (int ic_chunk = 0; ic_chunk < IC_NCHUNK; ic_chunk++) { )==""\n"
308R"==(#if WITH_SRC_ZPOINTS_PER_IC )==""\n"
309R"==(const int4 z = read_src_zero_points_32c( )==""\n"
310R"==(src_zpoints, (group_ic + ic_chunk) * IC_BLOCK); )==""\n"
311R"==(#else )==""\n"
312R"==(const int z = read_src_zero_point(src_zpoints); )==""\n"
313R"==(#endif )==""\n"
314R"==(for (int kd = 0; kd < KD; kd++) { )==""\n"
315R"==(for (int kh = 0; kh < KH; kh++) { )==""\n"
316R"==(for (int kw = 0; kw < KW; kw++) { )==""\n"
317R"==(int8 w0, w1, w2, w3; )==""\n"
318R"==(BLOCK_READ_WHT_8x32(w0, 0); )==""\n"
319R"==(BLOCK_READ_WHT_8x32(w1, 8 * IC_BLOCK); )==""\n"
320R"==(BLOCK_READ_WHT_8x32(w2, 16 * IC_BLOCK); )==""\n"
321R"==(BLOCK_READ_WHT_8x32(w3, 24 * IC_BLOCK); )==""\n"
322R"==(int4 acc = 0; )==""\n"
// Per-IC zero points go through calc_src_compensation_x32 (defined outside
// this file). With a single zero point, idot4 against 0x01010101 accumulates
// the four packed bytes of each weight word, and the total is multiplied by z.
323R"==(#if WITH_SRC_ZPOINTS_PER_IC )==""\n"
324R"==(acc.s0 += calc_src_compensation_x32(z, w0); )==""\n"
325R"==(acc.s1 += calc_src_compensation_x32(z, w1); )==""\n"
326R"==(acc.s2 += calc_src_compensation_x32(z, w2); )==""\n"
327R"==(acc.s3 += calc_src_compensation_x32(z, w3); )==""\n"
328R"==(#else )==""\n"
329R"==(unroll_for(uint j = 0; j < 8; ++j) { )==""\n"
330R"==(acc.s0 = idot4(0x01010101, w0[j], acc.s0); )==""\n"
331R"==(acc.s1 = idot4(0x01010101, w1[j], acc.s1); )==""\n"
332R"==(acc.s2 = idot4(0x01010101, w2[j], acc.s2); )==""\n"
333R"==(acc.s3 = idot4(0x01010101, w3[j], acc.s3); )==""\n"
334R"==(} )==""\n"
335R"==(acc = z * acc; )==""\n"
336R"==(#endif )==""\n"
// Add the term only for output positions whose input coordinate for this
// (kd, kh, kw) tap lies outside the input tensor.
337R"==(for (int i = 0; i < OW_BLOCK; ++i) { )==""\n"
338R"==(const int id0 = kd * (1 + DD) + id; )==""\n"
339R"==(const int ih0 = kh * (1 + DH) + ih; )==""\n"
340R"==(const int iw0 = kw * (1 + DW) + iw + i * SW; )==""\n"
341R"==(const int is_pad_d = id0 < 0 || id0 >= ID; )==""\n"
342R"==(const int is_pad_h = ih0 < 0 || ih0 >= IH; )==""\n"
343R"==(const int is_pad_w = iw0 < 0 || iw0 >= IW; )==""\n"
344R"==(if (is_pad_d || is_pad_h || is_pad_w) { )==""\n"
345R"==(C00[i] += acc.s0; )==""\n"
346R"==(C01[i] += acc.s1; )==""\n"
347R"==(C02[i] += acc.s2; )==""\n"
348R"==(C03[i] += acc.s3; )==""\n"
349R"==(} )==""\n"
350R"==(} )==""\n"
351R"==(wei += IC_BLOCK * OC_BLOCK; )==""\n"
352R"==(} )==""\n"
353R"==(} )==""\n"
354R"==(} )==""\n"
355R"==(} )==""\n"
356R"==(} )==""\n"
357R"==(int4 src_comp = as_int4(intel_sub_group_block_read4( )==""\n"
358R"==((__global uint *)(&src_compensation[(group_oc + oc) * OC_BLOCK]))); )==""\n"
359R"==(C00 -= src_comp.s0; )==""\n"
360R"==(C01 -= src_comp.s1; )==""\n"
361R"==(C02 -= src_comp.s2; )==""\n"
362R"==(C03 -= src_comp.s3; )==""\n"
363R"==(#endif )==""\n"
// Output stage, run only by subgroups whose first output column is inside OW.
// The macros defined inside this block are ordinary preprocessor definitions;
// they are not scoped to the braces.
364R"==(if (ow < OW) { )==""\n"
365R"==(float4 tmp; )==""\n"
366R"==(DST_DATA4_T dst_pack[BLOCK]; )==""\n"
367R"==(DST_DATA4_T D0[BLOCK]; )==""\n"
368R"==(#if SCALES_PER_OC )==""\n"
369R"==(float4 scales = 1; )==""\n"
370R"==(BLOCK_READ_SCALES(scales, (group_oc + oc) * OC_BLOCK); )==""\n"
371R"==(#endif )==""\n"
// QUANTIZE_ADD_BIAS: with bias, tmp becomes SCALE * (tmp + bia), i.e. the scale
// multiplies the bias too; without bias, tmp is just multiplied by SCALE.
372R"==(#if WITH_BIAS )==""\n"
373R"==(float4 bia; )==""\n"
374R"==(BLOCK_READ_BIA(bia, (group_oc + oc) * OC_BLOCK); )==""\n"
375R"==(#define QUANTIZE_ADD_BIAS() tmp = SCALE * fma(tmp, (float4)1, bia); )==""\n"
376R"==(#else )==""\n"
377R"==(#define QUANTIZE_ADD_BIAS() tmp *= SCALE; )==""\n"
378R"==(#endif )==""\n"
// With a sum post-op, the current dst contents are loaded into D0 first.
// D0 is left uninitialised when WITH_SUM is off, yet PACK_DST still reads
// D[n_i]. NOTE(review): presumably the post-op macro ignores it then; confirm.
379R"==(#if WITH_SUM )==""\n"
380R"==(#if OW_BLOCK == 4 )==""\n"
381R"==(*(DST_DATA16_T *)D0 = BLOCK_READ_DST16(dst); )==""\n"
382R"==(#endif )==""\n"
383R"==(#if OW_BLOCK == 8 )==""\n"
384R"==(*(DST_DATA16_T *)(D0 + 0) = BLOCK_READ_DST16(dst); )==""\n"
385R"==(*(DST_DATA16_T *)(D0 + 4) = BLOCK_READ_DST16(dst + 16 * 8); )==""\n"
386R"==(#endif )==""\n"
387R"==(#endif )==""\n"
// Destination zero points are read through the dst_compensation argument and
// added to tmp after post-ops via ADD_DST_COMPENSATION.
388R"==(#if WITH_DST_ZPOINTS )==""\n"
389R"==(int4 dst_zp = read_dst_zero_points_32c( )==""\n"
390R"==(dst_compensation, (group_oc + oc) * OC_BLOCK); )==""\n"
391R"==(#if !WITH_DST_ZPOINTS_PER_OC && OC % 32 != 0 )==""\n"
392R"==(dst_zp = convert_int4(zero_pad_dst_32c( )==""\n"
393R"==(convert_float4(dst_zp), (group_oc + oc) * OC_BLOCK)); )==""\n"
394R"==(#endif )==""\n"
395R"==(#define ADD_DST_COMPENSATION() tmp += convert_float4(dst_zp); )==""\n"
396R"==(#else )==""\n"
397R"==(#define ADD_DST_COMPENSATION() )==""\n"
398R"==(#endif )==""\n"
399R"==(#if WITH_SRC_ZPOINTS )==""\n"
400R"==(#define ZERO_PAD_DST() tmp = zero_pad_dst_32c(tmp, (group_oc + oc) * OC_BLOCK); )==""\n"
401R"==(#else )==""\n"
402R"==(#define ZERO_PAD_DST() )==""\n"
403R"==(#endif )==""\n"
// PACK gathers element idx of the four accumulators into the float4 tmp;
// CONVERT_PACK converts tmp to the destination type into dst_pack[idx].
404R"==(#define PACK(C0, C1, C2, C3, idx) \ )==""\n"
405R"==(do { \ )==""\n"
406R"==(tmp[0] = convert_float(C0[idx]); \ )==""\n"
407R"==(tmp[1] = convert_float(C1[idx]); \ )==""\n"
408R"==(tmp[2] = convert_float(C2[idx]); \ )==""\n"
409R"==(tmp[3] = convert_float(C3[idx]); \ )==""\n"
410R"==(} while (0) )==""\n"
411R"==(#define CONVERT_PACK(idx) \ )==""\n"
412R"==(do { \ )==""\n"
413R"==(dst_pack[idx] = CONVERT_DST_DATA4_T(tmp); \ )==""\n"
414R"==(} while (0) )==""\n"
// PACK_DST: per output column, in this order: pack, scale and bias, post-ops,
// dst zero point, zero padding, conversion.
415R"==(#define PACK_DST(C0, C1, C2, C3, D) \ )==""\n"
416R"==(do { \ )==""\n"
417R"==(for (int n_i = 0; n_i < OW_BLOCK; ++n_i) { \ )==""\n"
418R"==(PACK(C0, C1, C2, C3, n_i); \ )==""\n"
419R"==(QUANTIZE_ADD_BIAS(); \ )==""\n"
420R"==(const int po_mb = group_mb * MB_BLOCK; \ )==""\n"
421R"==(const int po_oc \ )==""\n"
422R"==(= (group_oc * OC_BLOCK + oc * OC_BLOCK) % (OC * G); \ )==""\n"
423R"==(float4 dni = convert_float4(SUM_TO_REF(AS_SUM_DATA4_T(D[n_i]))); \ )==""\n"
424R"==(APPLY_POST_OPS_TRY_BURST(tmp, float, dni, float, po_mb, 1, po_oc, \ )==""\n"
425R"==(4 * SUB_GROUP_SIZE, subg_local_id); \ )==""\n"
426R"==(ADD_DST_COMPENSATION(); \ )==""\n"
427R"==(ZERO_PAD_DST(); \ )==""\n"
428R"==(CONVERT_PACK(n_i); \ )==""\n"
429R"==(} \ )==""\n"
430R"==(} while (0) )==""\n"
431R"==(PACK_DST(C00, C01, C02, C03, D0); )==""\n"
// Store: when the block would run past OW, only OW_TAIL columns are written
// one at a time; otherwise the whole block goes out in 16-wide block writes.
432R"==(#if OW_TAIL )==""\n"
433R"==(if (ow + OW_BLOCK > OW) { )==""\n"
434R"==(__attribute__((opencl_unroll_hint(OW_TAIL))) for (int i = 0; )==""\n"
435R"==(i < OW_TAIL; )==""\n"
436R"==(i++) { )==""\n"
437R"==(BLOCK_WRITE_DST4(&dst[i * 32], dst_pack[i]); )==""\n"
438R"==(} )==""\n"
439R"==(} else { )==""\n"
440R"==(#endif )==""\n"
441R"==(#if OW_BLOCK == 4 )==""\n"
442R"==(BLOCK_WRITE_DST16(dst, *(DST_DATA16_T *)dst_pack); )==""\n"
443R"==(#endif )==""\n"
444R"==(#if OW_BLOCK == 8 )==""\n"
445R"==(BLOCK_WRITE_DST16(dst, *(DST_DATA16_T *)dst_pack); )==""\n"
446R"==(BLOCK_WRITE_DST16(dst + 16 * 8, *(DST_DATA16_T *)(dst_pack + 4)); )==""\n"
447R"==(#endif )==""\n"
448R"==(#if OW_TAIL )==""\n"
449R"==(} )==""\n"
450R"==(#endif )==""\n"
451R"==(} )==""\n"
452R"==(} )==""\n"
453R"==()==";
454}
455}
456}
457}