1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// Embedded OpenCL C source for the `conv_nhwc_fwd_x8s8x` kernel.
//
// Storage format:
// - The text is stored as adjacent C++ raw string literals, one per
//   kernel-source line, each followed by "\n".
// - The C++ compiler concatenates them into one NUL-terminated string.
// - The C++ compiler never processes the #include / #define / #if lines
//   inside the string. They are OpenCL preprocessor directives that apply
//   only when this string is itself compiled as a kernel program.
//
// Build-time macros:
// - The kernel text uses many macros it never defines, for example IC, OC,
//   G, IC_BLOCK, OC_BLOCK, OW_BLOCK, SUB_GROUP_SIZE, LWS_0/1/2,
//   SRC_SLM_SIZE, SCALES_PER_OC, SCALES_COMMON, WITH_BIAS and WITH_SUM.
// - Presumably the host code passes them as build options when it compiles
//   the program. TODO confirm against the caller.
// - OW_BLOCK must be 4 or 8. Any other value triggers
//   `#error "Wrong OW_BLOCK"` in the kernel source.
//
// NOTE(review): this looks machine-generated from an OpenCL .cl file.
// Kernel changes presumably belong in that source, not here. Confirm before
// editing.
// NOTE(review): the license URL line inside the string reads only "http:".
// The rest of the URL appears to have been cut off, possibly by comment
// stripping during generation.
// NOTE(review): the kernel parameters src_compensation, src_zpoints and
// dst_compensation are declared but never referenced in the kernel body.
5const char *xe_lp_conv_nhwc_fwd_x8s8x_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2020-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/ocl_math_utils.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
22R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// IC_NBLOCKS_TAIL: when IC % IC_BLOCK != 0, the number of 4-element groups
// (rounded up) in the last, partial IC block. Otherwise it is 8.
23R"==(#if IC % IC_BLOCK != 0 )==""\n"
24R"==(#define IC_NBLOCKS_TAIL ((IC % IC_BLOCK + 3) / 4) )==""\n"
25R"==(#else )==""\n"
26R"==(#define IC_NBLOCKS_TAIL 8 )==""\n"
27R"==(#endif )==""\n"
// OW_BLOCK (4 or 8) selects the block width, the accumulator vector type,
// the block-read width and the full/tail MMAD routines.
28R"==(#if OW_BLOCK == 4 )==""\n"
29R"==(#define BLOCK 4 )==""\n"
30R"==(#define ACC_DATA_BLOCK int4 )==""\n"
31R"==(#define SRC_DATA_BLOCK_T SRC_MMAD_DATA4_T )==""\n"
32R"==(#define READ_BLOCK intel_sub_group_block_read4 )==""\n"
33R"==(#define WRITE_LOCAL block_write4 )==""\n"
34R"==(DECLARE_MMAD_EMU(mmad_tail, idot4, IC_NBLOCKS_TAIL, 4, SRC_DATA_BLOCK_T, int8, )==""\n"
35R"==(ACC_DATA_BLOCK) )==""\n"
36R"==(#define MMAD_FULL mmad8x4 )==""\n"
37R"==(#define MMAD_TAIL mmad_tail )==""\n"
38R"==(#elif OW_BLOCK == 8 )==""\n"
39R"==(#define BLOCK 8 )==""\n"
40R"==(#define ACC_DATA_BLOCK int8 )==""\n"
41R"==(#define SRC_DATA_BLOCK_T SRC_MMAD_DATA8_T )==""\n"
42R"==(#define READ_BLOCK intel_sub_group_block_read8 )==""\n"
43R"==(#define WRITE_LOCAL block_write8 )==""\n"
44R"==(DECLARE_MMAD_EMU(mmad_tail, idot4, IC_NBLOCKS_TAIL, 8, SRC_DATA_BLOCK_T, int8, )==""\n"
45R"==(ACC_DATA_BLOCK) )==""\n"
46R"==(#define MMAD_FULL mmad8x8 )==""\n"
47R"==(#define MMAD_TAIL mmad_tail )==""\n"
48R"==(#else )==""\n"
49R"==(#error "Wrong OW_BLOCK" )==""\n"
50R"==(#endif )==""\n"
51R"==(#define BLOCK_READ_SRC(data, idx) \ )==""\n"
52R"==(data = intel_sub_group_block_read8((__global uint *)&src[idx]); )==""\n"
53R"==(#define BLOCK_READ_WHT_1x32(data, idx) \ )==""\n"
54R"==(data = as_int(intel_sub_group_block_read((__global uint *)&wei[idx])); )==""\n"
55R"==(#define BLOCK_READ_WHT_8x32(data, idx) \ )==""\n"
56R"==(data = as_int8(intel_sub_group_block_read8((__global uint *)&wei[idx])); )==""\n"
57R"==(#define BLOCK_READ_BIA(data, idx) \ )==""\n"
58R"==(data = as_float4(intel_sub_group_block_read4((__global uint *)&bias[idx])); )==""\n"
// SCALE resolves as follows:
// - SCALES_PER_OC: the per-OC `scales` vector read in the kernel epilogue.
// - SCALES_COMMON: runtime_scales[0].
// - Otherwise: the literal 1.
59R"==(#if SCALES_PER_OC )==""\n"
60R"==(#define SCALE scales )==""\n"
61R"==(#elif SCALES_COMMON )==""\n"
62R"==(#define SCALE runtime_scales[0] )==""\n"
63R"==(#else )==""\n"
64R"==(#define SCALE 1 )==""\n"
65R"==(#endif )==""\n"
66R"==(#define BLOCK_READ_BOUND 1 )==""\n"
67R"==(#define BLOCK_WRITE_BOUND 4 )==""\n"
// The next three helpers each move four 8-element slices of one OC block
// (offsets 0, 8, 16 and 24, indexed by the sub-group local id).
// - Tail handling: when OC % OC_BLOCK != 0 and fewer than OC_BLOCK channels
//   remain after `off`, lanes past the remaining count are read as 0 or are
//   not written.
// - write_oc_block4 stores element-wise when OC % BLOCK_WRITE_BOUND (4)
//   != 0. Otherwise it uses a block write.
// NOTE(review): BLOCK_READ_BOUND is 1, so the `OC % BLOCK_READ_BOUND != 0`
// element-wise read branches are never compiled. The block-read path is
// always taken.
68R"==(inline float4 read_bias_scale_block4(const __global float *dst, int off) { )==""\n"
69R"==(const int local_id = get_sub_group_local_id(); )==""\n"
70R"==(#if OC % OC_BLOCK != 0 )==""\n"
71R"==(int tail = OC - off; )==""\n"
72R"==(if (tail < OC_BLOCK) { )==""\n"
73R"==(return (float4)(local_id < tail - 8 * 0 ? dst[0 * 8 + local_id] : 0, )==""\n"
74R"==(local_id < tail - 8 * 1 ? dst[1 * 8 + local_id] : 0, )==""\n"
75R"==(local_id < tail - 8 * 2 ? dst[2 * 8 + local_id] : 0, )==""\n"
76R"==(local_id < tail - 8 * 3 ? dst[3 * 8 + local_id] : 0); )==""\n"
77R"==(} )==""\n"
78R"==(#endif )==""\n"
79R"==(#if OC % BLOCK_READ_BOUND != 0 )==""\n"
80R"==(return (float4)(dst[0 * 8 + local_id], dst[1 * 8 + local_id], )==""\n"
81R"==(dst[2 * 8 + local_id], dst[3 * 8 + local_id]); )==""\n"
82R"==(#else )==""\n"
83R"==(return as_float4(intel_sub_group_block_read4((__global uint *)dst)); )==""\n"
84R"==(#endif )==""\n"
85R"==(} )==""\n"
86R"==(inline DST_DATA4_T read_oc_block4(const __global DATA_T *dst, int off) { )==""\n"
87R"==(const int local_id = get_sub_group_local_id(); )==""\n"
88R"==(#if OC % OC_BLOCK != 0 )==""\n"
89R"==(int tail = OC - off; )==""\n"
90R"==(if (tail < OC_BLOCK) { )==""\n"
91R"==(return (DST_DATA4_T)( )==""\n"
92R"==(local_id < tail - 8 * 0 ? dst[0 * 8 + local_id] : 0, )==""\n"
93R"==(local_id < tail - 8 * 1 ? dst[1 * 8 + local_id] : 0, )==""\n"
94R"==(local_id < tail - 8 * 2 ? dst[2 * 8 + local_id] : 0, )==""\n"
95R"==(local_id < tail - 8 * 3 ? dst[3 * 8 + local_id] : 0); )==""\n"
96R"==(} )==""\n"
97R"==(#endif )==""\n"
98R"==(#if OC % BLOCK_READ_BOUND != 0 )==""\n"
99R"==(return (DST_DATA4_T)(dst[0 * 8 + local_id], dst[1 * 8 + local_id], )==""\n"
100R"==(dst[2 * 8 + local_id], dst[3 * 8 + local_id]); )==""\n"
101R"==(#else )==""\n"
102R"==(return BLOCK_READ_DST4(dst); )==""\n"
103R"==(#endif )==""\n"
104R"==(} )==""\n"
105R"==(inline void write_oc_block4(__global DATA_T *dst, int off, DATA4_T value) { )==""\n"
106R"==(const int local_id = get_sub_group_local_id(); )==""\n"
107R"==(#if OC % OC_BLOCK != 0 )==""\n"
108R"==(int tail = OC - off; )==""\n"
109R"==(if (tail < OC_BLOCK) { )==""\n"
110R"==(if (local_id < tail) dst[0 * 8 + local_id] = value.s0; )==""\n"
111R"==(if (local_id < tail - 8 * 1) dst[1 * 8 + local_id] = value.s1; )==""\n"
112R"==(if (local_id < tail - 8 * 2) dst[2 * 8 + local_id] = value.s2; )==""\n"
113R"==(if (local_id < tail - 8 * 3) dst[3 * 8 + local_id] = value.s3; )==""\n"
114R"==(return; )==""\n"
115R"==(} )==""\n"
116R"==(#endif )==""\n"
117R"==(#if OC % BLOCK_WRITE_BOUND != 0 )==""\n"
118R"==(dst[0 * 8 + local_id] = value.s0; )==""\n"
119R"==(dst[1 * 8 + local_id] = value.s1; )==""\n"
120R"==(dst[2 * 8 + local_id] = value.s2; )==""\n"
121R"==(dst[3 * 8 + local_id] = value.s3; )==""\n"
122R"==(return; )==""\n"
123R"==(#else )==""\n"
124R"==(BLOCK_WRITE_DST4(dst, value); )==""\n"
125R"==(return; )==""\n"
126R"==(#endif )==""\n"
127R"==(} )==""\n"
// write_local_1 copies one uint-sized chunk per sub-group lane from src1
// into the __local buffer S.
// - IC % 4 == 0: a sub-group block read followed by a block write.
// - Otherwise: four element-wise SRC_DATA_T stores per lane.
// NOTE(review): the element-wise path relies on implicit pointer
// conversions. It converts uint* to SRC_DATA_T*, and it assigns a
// `const __global uint *` expression to a non-const pointer. Some OpenCL
// compilers warn on these conversions.
128R"==(inline void write_local_1(__local uint *S, __global SRC_DATA_T *src1) { )==""\n"
129R"==(const int local_id = get_sub_group_local_id(); )==""\n"
130R"==(#if IC % 4 != 0 )==""\n"
131R"==(__local SRC_DATA_T *S1 = S + local_id; )==""\n"
132R"==(__global SRC_DATA_T *src2 = (const __global uint *)(src1) + local_id; )==""\n"
133R"==(S1[0] = src2[0]; )==""\n"
134R"==(S1[1] = src2[1]; )==""\n"
135R"==(S1[2] = src2[2]; )==""\n"
136R"==(S1[3] = src2[3]; )==""\n"
137R"==(#else )==""\n"
138R"==(block_write(S, intel_sub_group_block_read((const __global uint *)(src1))); )==""\n"
139R"==(#endif )==""\n"
140R"==(return; )==""\n"
141R"==(} )==""\n"
// Kernel entry point.
// Work split: each sub-group handles OW_BLOCK consecutive output-width
// positions of one output row for one OC_BLOCK-wide slice of output
// channels.
// Main loop:
// - Source rows are staged in the __local S_slice buffer between
//   barrier(CLK_LOCAL_MEM_FENCE) calls.
// - They are accumulated into C00..C03 with MMAD_FULL. MMAD_TAIL is used
//   instead for the last, partial IC chunk.
// Epilogue (per output pixel and channel):
// 1. QUANTIZE_ADD_BIAS applies SCALE and, when WITH_BIAS is set, the bias.
// 2. With WITH_SUM, existing dst values are read into D0 first.
// 3. Values go through APPLY_POST_OPS_SERIAL_BINARY_2D.
// 4. Results are converted and stored with write_oc_block4.
142R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
143R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) )==""\n"
144R"==(__kernel void )==""\n"
145R"==(conv_nhwc_fwd_x8s8x(const __global SRC_DATA_T *src, const __global char *wei, )==""\n"
146R"==(const __global float *bias, __global DATA_T *dst POST_OP_ARGS, )==""\n"
147R"==(const __global float *runtime_scales, )==""\n"
148R"==(const __global int *src_compensation, const __global int *src_zpoints, )==""\n"
149R"==(const __global int *dst_compensation) { )==""\n"
150R"==(const int group_oc = get_group_id(0) * OC_GROUP; )==""\n"
151R"==(const int group_mb = get_group_id(2) * MB_GROUP; )==""\n"
152R"==(const int group_sp = get_group_id(1) * SP_GROUP; )==""\n"
153R"==(const int sub_group_id = get_sub_group_id(); )==""\n"
154R"==(const int oc = (sub_group_id % OC_GROUP); )==""\n"
155R"==(const int sp = (sub_group_id / OC_GROUP); )==""\n"
156R"==(const int g = (group_oc + oc) / OC_NCHUNK; )==""\n"
157R"==(const int group_ic = IC_NCHUNK * g; )==""\n"
158R"==(const int god = group_sp / (OW_PADDED * OH); )==""\n"
159R"==(const int gohw = group_sp % (OW_PADDED * OH); )==""\n"
160R"==(const int goh = gohw / OW_PADDED; )==""\n"
161R"==(const int gow = OW_BLOCK * (gohw % OW_PADDED); )==""\n"
162R"==(const int gid = god * SD; )==""\n"
163R"==(const int gih = goh * SH; )==""\n"
164R"==(const int giw = gow * SW; )==""\n"
165R"==(const int local_ow = OW_BLOCK * sp; )==""\n"
166R"==(const int local_iw = local_ow * SW; )==""\n"
167R"==(const int od = god; )==""\n"
168R"==(const int ow = gow + local_ow; )==""\n"
169R"==(const int oh = goh; )==""\n"
170R"==(const int id = gid - PD; )==""\n"
171R"==(const int iw = giw + local_iw - PW; )==""\n"
172R"==(const int ih = gih - PH; )==""\n"
173R"==(const int local_id = get_sub_group_local_id(); )==""\n"
174R"==(__local uint S_slice[SRC_SLM_SIZE]; )==""\n"
175R"==(__local uint *S_part = S_slice + IC_BLOCK / 4 * (sp * SW * OW_BLOCK + PW); )==""\n"
176R"==(__local uint *S_work = S_slice + IC_BLOCK / 4 * (sp * SW * OW_BLOCK); )==""\n"
177R"==(const bool left_tail = iw < 0; )==""\n"
178R"==(const bool left_nozero_tail = sub_group_id == 0 && iw > -PW; )==""\n"
179R"==(const bool right_tail = (iw + PW + SLM_TAIL >= IW) && (iw + PW < IW); )==""\n"
180R"==(const bool right_nozero_tail )==""\n"
181R"==(= sp == (LWS_1 - 1) && (iw + PW + SLM_TAIL < IW); )==""\n"
182R"==(const bool empty = (iw + PW >= IW); )==""\n"
183R"==(dst += group_mb * MB_BLOCK * OD * OH * OW * G * OC; )==""\n"
184R"==(dst += (OW * OH * od + OW * oh + ow) * G * OC; )==""\n"
185R"==(dst += OC_BLOCK * (group_oc + oc); )==""\n"
186R"==(src += group_mb * MB_BLOCK * ID * IH * IW * G * IC; )==""\n"
187R"==(src += (IW * IH * id + IW * ih + iw + PW) * G * IC; )==""\n"
188R"==(src += group_ic * IC_BLOCK; )==""\n"
189R"==(wei += IC_BLOCK * KD * KH * KW * OC_BLOCK * (group_oc + oc) * IC_NCHUNK; )==""\n"
190R"==(/* Prepare S_slice tails */ )==""\n"
191R"==(#if PW > 0 )==""\n"
192R"==(if (left_tail) { )==""\n"
193R"==(for (int i = 0; i < PW; i++) { )==""\n"
194R"==(block_write(S_slice + i * 8, 0); )==""\n"
195R"==(} )==""\n"
196R"==(} )==""\n"
197R"==(#endif )==""\n"
198R"==(#if ZERO_TAIL > 0 )==""\n"
199R"==(if (right_tail) { )==""\n"
200R"==(for (int i = SLM_TAIL; i < SW * OW_BLOCK + (KW - 1) * (1 + DW) - PW; )==""\n"
201R"==(i++) { )==""\n"
202R"==(block_write(S_part + i * 8, 0); )==""\n"
203R"==(} )==""\n"
204R"==(} )==""\n"
205R"==(#if SLM_NCHUNK < OW_NCHUNK )==""\n"
206R"==(if (empty) { )==""\n"
207R"==(for (int i = -PW; i < SW * OW_BLOCK + (KW - 1) * (1 + DW) - PW; i++) { )==""\n"
208R"==(block_write(S_part + i * 8, 0); )==""\n"
209R"==(} )==""\n"
210R"==(} )==""\n"
211R"==(#endif )==""\n"
212R"==(#endif )==""\n"
213R"==(ACC_DATA_BLOCK C00 = 0, C01 = 0, C02 = 0, C03 = 0; )==""\n"
214R"==(__attribute__((opencl_unroll_hint(1))) )==""\n"
215R"==(for (int ic_chunk = 0; ic_chunk < IC_NCHUNK; ic_chunk++) { )==""\n"
216R"==(SRC_DATA_BLOCK_T S0; )==""\n"
217R"==(__attribute__((opencl_unroll_hint(1))) )==""\n"
218R"==(for (int kd = 0; kd < KD; kd++) { )==""\n"
219R"==(if (kd * (1 + DD) + id < 0 || kd * (1 + DD) + id >= ID) { )==""\n"
220R"==(wei += IC_BLOCK * OC_BLOCK * KH * KW; )==""\n"
221R"==(continue; )==""\n"
222R"==(} )==""\n"
223R"==(__attribute__((opencl_unroll_hint(1))) )==""\n"
224R"==(for (int kh = 0; kh < KH; kh++) { )==""\n"
225R"==(if (kh * (1 + DH) + ih < 0 || kh * (1 + DH) + ih >= IH) { )==""\n"
226R"==(wei += IC_BLOCK * OC_BLOCK * KW; )==""\n"
227R"==(continue; )==""\n"
228R"==(} )==""\n"
229R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
230R"==(const __global SRC_DATA_T *src1 = src )==""\n"
231R"==(+ kd * (1 + DD) * IH * IW * G * IC )==""\n"
232R"==(+ kh * (1 + DH) * IW * G * IC; )==""\n"
233R"==(#if SLM_NCHUNK < OW_NCHUNK )==""\n"
234R"==(if (iw + PW < IW) { )==""\n"
235R"==(#endif )==""\n"
236R"==(#if OW_NCHUNK > LWS_1 )==""\n"
237R"==(/* Copy tails in case of multigroups */ )==""\n"
238R"==(if (ow < OW) { )==""\n"
239R"==(#if PW > 0 )==""\n"
240R"==(if (left_nozero_tail) { )==""\n"
241R"==(for (int i = -PW - min(iw, 0); i < 0; i++) { )==""\n"
242R"==(write_local_1( )==""\n"
243R"==(S_part + i * 8, src1 + i * G * IC); )==""\n"
244R"==(} )==""\n"
245R"==(} )==""\n"
246R"==(#endif )==""\n"
247R"==(if (right_nozero_tail) { )==""\n"
248R"==(int buffer_last = (KW - 1) * (1 + DW) - PW; )==""\n"
249R"==(int src_last = IW - iw - SW * OW_BLOCK - PW; )==""\n"
250R"==(for (int i = SW * OW_BLOCK; i < SW * OW_BLOCK )==""\n"
251R"==(+ min(buffer_last, src_last); )==""\n"
252R"==(i++) { )==""\n"
253R"==(write_local_1( )==""\n"
254R"==(S_part + i * 8, src1 + i * G * IC); )==""\n"
255R"==(} )==""\n"
256R"==(for (int i = SW * OW_BLOCK )==""\n"
257R"==(+ min(buffer_last, src_last); )==""\n"
258R"==(i < SW * OW_BLOCK + buffer_last; i++) { )==""\n"
259R"==(block_write(S_part + i * 8, 0); )==""\n"
260R"==(} )==""\n"
261R"==(} )==""\n"
262R"==(#endif )==""\n"
263R"==(#if SLM_TAIL != OW_BLOCK * SW )==""\n"
264R"==(/* Copy last block to SLM */ )==""\n"
265R"==(if (right_tail) { )==""\n"
266R"==(__attribute__(( )==""\n"
267R"==(opencl_unroll_hint)) )==""\n"
268R"==(for (int i = 0; i < SLM_TAIL; i++) { )==""\n"
269R"==(write_local_1( )==""\n"
270R"==(S_part + i * 8, src1 + i * G * IC); )==""\n"
271R"==(} )==""\n"
272R"==(} else { )==""\n"
273R"==(#endif )==""\n"
274R"==(__attribute__(( )==""\n"
275R"==(opencl_unroll_hint)) )==""\n"
276R"==(for (int i = 0; i < SW * OW_BLOCK; i++) { )==""\n"
277R"==(write_local_1( )==""\n"
278R"==(S_part + i * 8, src1 + i * G * IC); )==""\n"
279R"==(} )==""\n"
280R"==(#if SLM_TAIL != OW_BLOCK * SW )==""\n"
281R"==(} )==""\n"
282R"==(#endif )==""\n"
283R"==(#if OW_NCHUNK > LWS_1 )==""\n"
284R"==(} )==""\n"
285R"==(#endif )==""\n"
286R"==(#if SLM_NCHUNK < OW_NCHUNK )==""\n"
287R"==(} )==""\n"
288R"==(#endif )==""\n"
289R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
290R"==(__attribute__((opencl_unroll_hint)) )==""\n"
291R"==(for (int kw = 0; kw < KW; kw++) { )==""\n"
292R"==(__attribute__(( )==""\n"
293R"==(opencl_unroll_hint(OW_BLOCK))) )==""\n"
294R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
295R"==(S0[i] = block_read( )==""\n"
296R"==(S_work + (kw * (1 + DW) + SW * i) * 8); )==""\n"
297R"==(} )==""\n"
298R"==(int8 W0 = 0, W1 = 0, W2 = 0, W3 = 0; )==""\n"
299R"==(#if IC % IC_BLOCK != 0 )==""\n"
300R"==(if (ic_chunk == IC_NCHUNK - 1) { )==""\n"
301R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
302R"==(BLOCK_READ_WHT_1x32(W0[i], (i + 0) * IC_BLOCK); )==""\n"
303R"==(if (OC > 8) )==""\n"
304R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
305R"==(BLOCK_READ_WHT_1x32( )==""\n"
306R"==(W1[i], (i + 8) * IC_BLOCK); )==""\n"
307R"==(if (OC > 16) )==""\n"
308R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
309R"==(BLOCK_READ_WHT_1x32( )==""\n"
310R"==(W2[i], (i + 16) * IC_BLOCK); )==""\n"
311R"==(if (OC > 24) )==""\n"
312R"==(unroll_for(int i = 0; i < IC_NBLOCKS_TAIL; ++i) )==""\n"
313R"==(BLOCK_READ_WHT_1x32( )==""\n"
314R"==(W3[i], (i + 24) * IC_BLOCK); )==""\n"
315R"==(C00 = MMAD_TAIL(S0, W0, C00); )==""\n"
316R"==(if (OC > 8) C01 = MMAD_TAIL(S0, W1, C01); )==""\n"
317R"==(if (OC > 16) C02 = MMAD_TAIL(S0, W2, C02); )==""\n"
318R"==(if (OC > 24) C03 = MMAD_TAIL(S0, W3, C03); )==""\n"
319R"==(} else )==""\n"
320R"==(#endif )==""\n"
321R"==({ )==""\n"
322R"==(BLOCK_READ_WHT_8x32(W0, 0); )==""\n"
323R"==(if (OC > 8) BLOCK_READ_WHT_8x32(W1, 8 * IC_BLOCK); )==""\n"
324R"==(if (OC > 16) BLOCK_READ_WHT_8x32(W2, 16 * IC_BLOCK); )==""\n"
325R"==(if (OC > 24) BLOCK_READ_WHT_8x32(W3, 24 * IC_BLOCK); )==""\n"
326R"==(C00 = MMAD_FULL(S0, W0, C00); )==""\n"
327R"==(if (OC > 8) C01 = MMAD_FULL(S0, W1, C01); )==""\n"
328R"==(if (OC > 16) C02 = MMAD_FULL(S0, W2, C02); )==""\n"
329R"==(if (OC > 24) C03 = MMAD_FULL(S0, W3, C03); )==""\n"
330R"==(} )==""\n"
331R"==(wei += IC_BLOCK * OC_BLOCK; )==""\n"
332R"==(} )==""\n"
333R"==(} )==""\n"
334R"==(} )==""\n"
335R"==(src += IC_BLOCK; )==""\n"
336R"==(} )==""\n"
337R"==(if (ow < OW) { )==""\n"
338R"==(float4 tmp = {0, 0, 0, 0}; )==""\n"
339R"==(DST_DATA4_T dst_pack[BLOCK] = {0}; )==""\n"
340R"==(DST_DATA4_T D0[BLOCK] = {0}; )==""\n"
341R"==(#if SCALES_PER_OC )==""\n"
342R"==(float4 scales = {0, 0, 0, 0}; )==""\n"
343R"==(scales = read_bias_scale_block4( )==""\n"
344R"==(runtime_scales + (group_oc + oc) * OC_BLOCK, )==""\n"
345R"==((group_oc + oc) * OC_BLOCK); )==""\n"
346R"==(#endif )==""\n"
347R"==(#if WITH_BIAS )==""\n"
348R"==(float4 bia = {0, 0, 0, 0}; )==""\n"
349R"==(bia = read_bias_scale_block4( )==""\n"
350R"==(bias + (group_oc + oc) * OC_BLOCK, (group_oc + oc) * OC_BLOCK); )==""\n"
351R"==(#define QUANTIZE_ADD_BIAS() tmp = SCALE * fma(tmp, (float4)1, bia); )==""\n"
352R"==(#else )==""\n"
353R"==(#define QUANTIZE_ADD_BIAS() tmp *= SCALE; )==""\n"
354R"==(#endif )==""\n"
355R"==(#if WITH_SUM )==""\n"
356R"==(#if OW_BLOCK >= 4 )==""\n"
357R"==(*(DST_DATA16_T *)D0 = (DST_DATA16_T)( )==""\n"
358R"==(read_oc_block4(dst + G * OC * 0, (group_oc + oc) * OC_BLOCK), )==""\n"
359R"==(read_oc_block4(dst + G * OC * 1, (group_oc + oc) * OC_BLOCK), )==""\n"
360R"==(read_oc_block4(dst + G * OC * 2, (group_oc + oc) * OC_BLOCK), )==""\n"
361R"==(read_oc_block4(dst + G * OC * 3, (group_oc + oc) * OC_BLOCK)); )==""\n"
362R"==(#endif )==""\n"
363R"==(#if OW_BLOCK == 8 )==""\n"
364R"==(*(DST_DATA16_T *)(D0 + 4) = (DST_DATA16_T)( )==""\n"
365R"==(read_oc_block4(dst + G * OC * 4, (group_oc + oc) * OC_BLOCK), )==""\n"
366R"==(read_oc_block4(dst + G * OC * 5, (group_oc + oc) * OC_BLOCK), )==""\n"
367R"==(read_oc_block4(dst + G * OC * 6, (group_oc + oc) * OC_BLOCK), )==""\n"
368R"==(read_oc_block4(dst + G * OC * 7, (group_oc + oc) * OC_BLOCK)); )==""\n"
369R"==(#endif )==""\n"
370R"==(#endif )==""\n"
371R"==(#define PACK(C0, C1, C2, C3, idx) \ )==""\n"
372R"==(do { \ )==""\n"
373R"==(tmp[0] = C0[idx]; \ )==""\n"
374R"==(tmp[1] = C1[idx]; \ )==""\n"
375R"==(tmp[2] = C2[idx]; \ )==""\n"
376R"==(tmp[3] = C3[idx]; \ )==""\n"
377R"==(} while (0) )==""\n"
378R"==(#define CONVERT_PACK(idx) \ )==""\n"
379R"==(do { \ )==""\n"
380R"==(dst_pack[idx] = CONVERT_DST_DATA4_T(tmp); \ )==""\n"
381R"==(} while (0) )==""\n"
382R"==(#define PACK_DST(C0, C1, C2, C3, D) \ )==""\n"
383R"==(do { \ )==""\n"
384R"==(for (int n_i = 0; n_i < OW_BLOCK; n_i++) { \ )==""\n"
385R"==(PACK(C0, C1, C2, C3, n_i); \ )==""\n"
386R"==(QUANTIZE_ADD_BIAS(); \ )==""\n"
387R"==(for (int didx = 0; didx < 4; ++didx) { \ )==""\n"
388R"==(float tmp_i = tmp[didx]; \ )==""\n"
389R"==(float dni_i = convert_float(AS_SUM_DATA_T(D[n_i][didx])); \ )==""\n"
390R"==(int po_mb; \ )==""\n"
391R"==(po_mb = group_mb % MB; \ )==""\n"
392R"==(const int po_oc = ((group_oc + oc) * OC_BLOCK + local_id \ )==""\n"
393R"==(+ didx * SUB_GROUP_SIZE) \ )==""\n"
394R"==(% (OC * G); \ )==""\n"
395R"==(APPLY_POST_OPS_SERIAL_BINARY_2D( \ )==""\n"
396R"==(tmp_i, float, dni_i, float, po_mb, 1, po_oc, 1); \ )==""\n"
397R"==(tmp[didx] = tmp_i; \ )==""\n"
398R"==(} \ )==""\n"
399R"==(CONVERT_PACK(n_i); \ )==""\n"
400R"==(} \ )==""\n"
401R"==(} while (0) )==""\n"
402R"==(PACK_DST(C00, C01, C02, C03, D0); )==""\n"
403R"==(#if OW_TAIL )==""\n"
404R"==(if (ow + OW_BLOCK > OW) { )==""\n"
405R"==(__attribute__((opencl_unroll_hint(OW_TAIL))) )==""\n"
406R"==(for (int i = 0; i < OW_TAIL; i++) { )==""\n"
407R"==(write_oc_block4(dst + i * G * OC, (group_oc + oc) * OC_BLOCK, )==""\n"
408R"==(dst_pack[i]); )==""\n"
409R"==(} )==""\n"
410R"==(} else { )==""\n"
411R"==(#endif )==""\n"
412R"==(#if OW_BLOCK >= 4 )==""\n"
413R"==(write_oc_block4( )==""\n"
414R"==(dst + G * OC * 0, (group_oc + oc) * OC_BLOCK, dst_pack[0]); )==""\n"
415R"==(write_oc_block4( )==""\n"
416R"==(dst + G * OC * 1, (group_oc + oc) * OC_BLOCK, dst_pack[1]); )==""\n"
417R"==(write_oc_block4( )==""\n"
418R"==(dst + G * OC * 2, (group_oc + oc) * OC_BLOCK, dst_pack[2]); )==""\n"
419R"==(write_oc_block4( )==""\n"
420R"==(dst + G * OC * 3, (group_oc + oc) * OC_BLOCK, dst_pack[3]); )==""\n"
421R"==(#endif )==""\n"
422R"==(#if OW_BLOCK == 8 )==""\n"
423R"==(write_oc_block4( )==""\n"
424R"==(dst + G * OC * 4, (group_oc + oc) * OC_BLOCK, dst_pack[4]); )==""\n"
425R"==(write_oc_block4( )==""\n"
426R"==(dst + G * OC * 5, (group_oc + oc) * OC_BLOCK, dst_pack[5]); )==""\n"
427R"==(write_oc_block4( )==""\n"
428R"==(dst + G * OC * 6, (group_oc + oc) * OC_BLOCK, dst_pack[6]); )==""\n"
429R"==(write_oc_block4( )==""\n"
430R"==(dst + G * OC * 7, (group_oc + oc) * OC_BLOCK, dst_pack[7]); )==""\n"
431R"==(#endif )==""\n"
432R"==(#if OW_TAIL )==""\n"
433R"==(} )==""\n"
434R"==(#endif )==""\n"
435R"==(} )==""\n"
436R"==(} )==""\n"
437R"==()==";
438}
439}
440}
441}