1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// OpenCL C source text of the gen9 pooling kernels, stored as a single string
// assembled from adjacent raw string literals. Every piece is followed by a
// separate newline literal, so one file line corresponds to one kernel line.
5const char *gen9_pooling_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2020-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
// Headers pulled in by the kernel source.
20R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// Forward declarations of the vector block read and write helpers. They are
// used by the kernels below and defined later in this same string.
22R"==(inline VECT_DATA_T read_vect_c_block(int idx, const __global DATA_T *ptr, int c, )==""\n"
23R"==(int blocks_stride, int chunks_per_block); )==""\n"
24R"==(inline VECT_INT_T read_vect_c_block_int(int idx, const __global int *ptr, int c, )==""\n"
25R"==(int blocks_stride, int chunks_per_block); )==""\n"
26R"==(inline void write_vect_c_block(int idx, __global DATA_T *ptr, int c, )==""\n"
27R"==(int blocks_stride, int chunks_per_block, VECT_DATA_T block); )==""\n"
28R"==(inline void write_vect_c_block_int(int idx, __global int *ptr, int c, )==""\n"
29R"==(int blocks_stride, int chunks_per_block, VECT_INT_T block); )==""\n"
// USE_FLOATS is always true for DT_BF16 or DT_F16 builds. For other data types
// it is true only for the two averaging algorithms (ALG_AVG_NP, ALG_AVG_P).
30R"==(#if DT_BF16 || DT_F16 )==""\n"
31R"==(#define USE_FLOATS true )==""\n"
32R"==(#else )==""\n"
33R"==(#define USE_FLOATS (ALG_AVG_NP || ALG_AVG_P) )==""\n"
34R"==(#endif )==""\n"
// Forward pooling kernel, present only when IS_FWD is set.
// Each work-item handles output point (od, oh, ow) at channel offset c and
// batch index mb0, and accumulates two vectors, D0 and D1, over the
// KD x KH x KW window. Arguments: src is read, dst is written (and read when
// WITH_SUM is set), ws is written for max pooling in training, batch_id
// selects the MB_BLOCK_SIZE-sized batch slice, POST_OP_ARGS is supplied by
// the post-ops header.
35R"==(#if IS_FWD )==""\n"
36R"==(KERNEL_ATTR )==""\n"
37R"==(__kernel void gen9_pooling_fwd(__global DATA_T *src, __global int *ws, )==""\n"
38R"==(__global DATA_T *dst, const int batch_id POST_OP_ARGS) { )==""\n"
39R"==(const int mb0 = MB_BLOCK_SIZE * batch_id + GWS_GET_MB(); )==""\n"
// With UNROLL_MB the second vector (D1) belongs to batch index mb1; without
// it D1 is vector index 1 at batch index mb0 (see the reads further down).
40R"==(#if UNROLL_MB )==""\n"
41R"==(const int mb1 = mb0 + MB / 2; )==""\n"
42R"==(#endif )==""\n"
43R"==(const int c = GWS_GET_C(); )==""\n"
44R"==(const int od = GWS_GET_OD(); )==""\n"
45R"==(const int oh = GWS_GET_OH(); )==""\n"
46R"==(const int ow = GWS_GET_OW(); )==""\n"
// Strides and chunk counts handed to the block helpers, selected by the
// layout macros USE_MB_C_BLOCK and USE_ONLY_C_BLOCK.
47R"==(#if USE_MB_C_BLOCK )==""\n"
48R"==(const int src_stride = (SRC_SB0 > 1) ? SRC_SB0 : SRC_S0; )==""\n"
49R"==(const int dst_stride = (DST_SB0 > 1) ? DST_SB0 : DST_S0; )==""\n"
50R"==(const int src_chunks_per_c_block = CHUNKS_PER_C_BLOCK; )==""\n"
51R"==(const int dst_chunks_per_c_block = CHUNKS_PER_C_BLOCK; )==""\n"
52R"==(#elif USE_ONLY_C_BLOCK )==""\n"
53R"==(const int src_stride = (SRC_B1 > 1) ? SRC_S1 : SUB_GROUP_SIZE; )==""\n"
54R"==(const int dst_stride = (DST_B1 > 1) ? DST_S1 : SUB_GROUP_SIZE; )==""\n"
55R"==(const int src_chunks_per_c_block )==""\n"
56R"==(= (SRC_B1 > 1) ? (SRC_B1 / SUB_GROUP_SIZE) : 1; )==""\n"
57R"==(const int dst_chunks_per_c_block )==""\n"
58R"==(= (DST_B1 > 1) ? (DST_B1 / SUB_GROUP_SIZE) : 1; )==""\n"
59R"==(#endif )==""\n"
// The workspace is addressed with the same stride and chunking as dst.
60R"==(const int ws_stride = dst_stride; )==""\n"
61R"==(const int ws_chunks_per_c_block = dst_chunks_per_c_block; )==""\n"
// Batch indices at or beyond SRC_D0: write zeros to dst (and to ws for max
// pooling in training) for vector indices 0 and 1, then exit.
62R"==(if (mb0 >= SRC_D0) { )==""\n"
63R"==(VECT_DATA_T dst_zero = DATA_ZERO; )==""\n"
64R"==(VECT_INT_T ws_zero = 0; )==""\n"
65R"==(int off = DST_OFF(mb0, c, od, oh, ow); )==""\n"
66R"==(write_vect_c_block( )==""\n"
67R"==(0, &dst[off], c, dst_stride, dst_chunks_per_c_block, dst_zero); )==""\n"
68R"==(write_vect_c_block( )==""\n"
69R"==(1, &dst[off], c, dst_stride, dst_chunks_per_c_block, dst_zero); )==""\n"
70R"==(#if ALG_MAX && IS_TRAINING )==""\n"
71R"==(write_vect_c_block_int( )==""\n"
72R"==(0, &ws[off], c, ws_stride, ws_chunks_per_c_block, ws_zero); )==""\n"
73R"==(write_vect_c_block_int( )==""\n"
74R"==(1, &ws[off], c, ws_stride, ws_chunks_per_c_block, ws_zero); )==""\n"
75R"==(#endif )==""\n"
76R"==(return; )==""\n"
77R"==(} )==""\n"
// Input coordinate of the first window position. It can be negative; the
// bounds checks inside the loops skip such positions.
78R"==(const int id = od * SD - PD; )==""\n"
79R"==(const int ih = oh * SH - PH; )==""\n"
80R"==(const int iw = ow * SW - PW; )==""\n"
// Accumulators start at DATA_MIN for max pooling and at zero otherwise.
// They are float vectors when USE_FLOATS is set, data-type vectors if not.
81R"==(#if USE_FLOATS )==""\n"
82R"==(VECT_FLOAT_T D0 = ALG_MAX ? CONVERT_FLOAT_T(DATA_MIN) : 0.0f; )==""\n"
83R"==(VECT_FLOAT_T D1 = ALG_MAX ? CONVERT_FLOAT_T(DATA_MIN) : 0.0f; )==""\n"
84R"==(#else )==""\n"
85R"==(VECT_DATA_T D0 = ALG_MAX ? DATA_MIN : DATA_ZERO; )==""\n"
86R"==(VECT_DATA_T D1 = ALG_MAX ? DATA_MIN : DATA_ZERO; )==""\n"
87R"==(#endif )==""\n"
88R"==(VECT_INT_T WS0 = 0, WS1 = 0; )==""\n"
// Walk the window; positions outside [0, ID) x [0, IH) x [0, IW) are skipped.
89R"==(for (int kd = 0; kd < KD; ++kd) { )==""\n"
90R"==(if (id + kd < 0 || id + kd >= ID) continue; )==""\n"
91R"==(for (int kh = 0; kh < KH; ++kh) { )==""\n"
92R"==(if (ih + kh < 0 || ih + kh >= IH) continue; )==""\n"
93R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
94R"==(if (iw + kw < 0 || iw + kw >= IW) continue; )==""\n"
95R"==(int src_off0 = SRC_OFF(mb0, c, id + kd, ih + kh, iw + kw); )==""\n"
96R"==(#if UNROLL_MB )==""\n"
97R"==(int src_off1 = SRC_OFF(mb1, c, id + kd, ih + kh, iw + kw); )==""\n"
98R"==(#endif )==""\n"
// Load the two source vectors S0 and S1 (converted to float if USE_FLOATS).
99R"==(#if USE_FLOATS )==""\n"
100R"==(VECT_FLOAT_T S0 = CONVERT_VECT_FLOAT_T(read_vect_c_block(0, )==""\n"
101R"==(&src[src_off0], c, src_stride, src_chunks_per_c_block)); )==""\n"
102R"==(#if UNROLL_MB )==""\n"
103R"==(VECT_FLOAT_T S1 = CONVERT_VECT_FLOAT_T(read_vect_c_block(0, )==""\n"
104R"==(&src[src_off1], c, src_stride, src_chunks_per_c_block)); )==""\n"
105R"==(#else )==""\n"
106R"==(VECT_FLOAT_T S1 = CONVERT_VECT_FLOAT_T(read_vect_c_block(1, )==""\n"
107R"==(&src[src_off0], c, src_stride, src_chunks_per_c_block)); )==""\n"
108R"==(#endif )==""\n"
109R"==(#else )==""\n"
110R"==(VECT_DATA_T S0 = read_vect_c_block(0, &src[src_off0], c, )==""\n"
111R"==(src_stride, src_chunks_per_c_block); )==""\n"
112R"==(#if UNROLL_MB )==""\n"
113R"==(VECT_DATA_T S1 = read_vect_c_block(0, &src[src_off1], c, )==""\n"
114R"==(src_stride, src_chunks_per_c_block); )==""\n"
115R"==(#else )==""\n"
116R"==(VECT_DATA_T S1 = read_vect_c_block(1, &src[src_off0], c, )==""\n"
117R"==(src_stride, src_chunks_per_c_block); )==""\n"
118R"==(#endif )==""\n"
119R"==(#endif )==""\n"
// Max pooling in training: where the accumulator is less than the new value,
// take the new value and record the flattened window offset
// kd * KH * KW + kh * KW + kw in the workspace vector. Max pooling without
// training only keeps the maximum. All other algorithms accumulate a sum.
120R"==(#if ALG_MAX )==""\n"
121R"==(#if IS_TRAINING )==""\n"
122R"==(VECT_INT_T CMP0 = isless(D0, S0); )==""\n"
123R"==(WS0 = select(WS0, kd * KH * KW + kh * KW + kw, CMP0); )==""\n"
124R"==(D0 = select(D0, S0, CMP0); )==""\n"
125R"==(VECT_INT_T CMP1 = isless(D1, S1); )==""\n"
126R"==(WS1 = select(WS1, kd * KH * KW + kh * KW + kw, CMP1); )==""\n"
127R"==(D1 = select(D1, S1, CMP1); )==""\n"
128R"==(#else )==""\n"
129R"==(D0 = max(D0, S0); )==""\n"
130R"==(D1 = max(D1, S1); )==""\n"
131R"==(#endif )==""\n"
132R"==(#else )==""\n"
133R"==(D0 += S0; )==""\n"
134R"==(D1 += S1; )==""\n"
135R"==(#endif )==""\n"
136R"==(} )==""\n"
137R"==(} )==""\n"
138R"==(} )==""\n"
// ALG_AVG_P divides the sum by the full window size KD * KH * KW.
139R"==(#if ALG_AVG_P )==""\n"
140R"==(D0 = D0 / (KD * KH * KW); )==""\n"
141R"==(D1 = D1 / (KD * KH * KW); )==""\n"
142R"==(#endif )==""\n"
// ALG_AVG_NP divides the sum by the number of window positions that fall
// inside the input, computed from the window extent clipped to the input.
143R"==(#if ALG_AVG_NP )==""\n"
144R"==(const int id_start = max(od * SD - PD, 0); )==""\n"
145R"==(const int ih_start = max(oh * SH - PH, 0); )==""\n"
146R"==(const int iw_start = max(ow * SW - PW, 0); )==""\n"
147R"==(const int id_end = min(od * SD - PD + KD, ID); )==""\n"
148R"==(const int ih_end = min(oh * SH - PH + KH, IH); )==""\n"
149R"==(const int iw_end = min(ow * SW - PW + KW, IW); )==""\n"
150R"==(const int num_summands )==""\n"
151R"==(= (ih_end - ih_start) * (iw_end - iw_start) * (id_end - id_start); )==""\n"
152R"==(D0 = D0 / num_summands; )==""\n"
153R"==(D1 = D1 / num_summands; )==""\n"
154R"==(#endif )==""\n"
155R"==(int dst_off0 = DST_OFF(mb0, c, od, oh, ow); )==""\n"
156R"==(#if UNROLL_MB )==""\n"
157R"==(int dst_off1 = DST_OFF(mb1, c, od, oh, ow); )==""\n"
158R"==(#endif )==""\n"
// sum0 and sum1 receive the current dst contents only when WITH_SUM is set;
// otherwise they are declared but not assigned here.
159R"==(VECT_DATA_T sum0; )==""\n"
160R"==(VECT_DATA_T sum1; )==""\n"
161R"==(#if WITH_SUM )==""\n"
162R"==(sum0 = read_vect_c_block( )==""\n"
163R"==(0, &dst[dst_off0], c, dst_stride, dst_chunks_per_c_block); )==""\n"
164R"==(#if UNROLL_MB )==""\n"
165R"==(sum1 = read_vect_c_block( )==""\n"
166R"==(0, &dst[dst_off1], c, dst_stride, dst_chunks_per_c_block); )==""\n"
167R"==(#else )==""\n"
168R"==(sum1 = read_vect_c_block( )==""\n"
169R"==(1, &dst[dst_off0], c, dst_stride, dst_chunks_per_c_block); )==""\n"
170R"==(#endif )==""\n"
171R"==(#endif )==""\n"
// Post-ops are applied per lane through APPLY_POST_OPS_SERIAL_BINARY_2D on a
// float copy of each element. Lanes whose channel po_oc is at or beyond
// C_WO_PADDING are skipped (scalar case) or continued past (vector case).
172R"==(const int local_id = get_sub_group_local_id(); )==""\n"
173R"==(#if VECT_DT_N == 1 )==""\n"
174R"==(const int po_mb = mb0; )==""\n"
175R"==(const int po_oc = c + local_id; )==""\n"
176R"==(if (po_oc < C_WO_PADDING) { )==""\n"
177R"==(POST_OP_DATA_T po_sum0 = DATA_TO_REF(sum0); )==""\n"
178R"==(float po_D0 = USE_FLOATS ? D0 : CONVERT_FLOAT_T(D0); )==""\n"
179R"==(APPLY_POST_OPS_SERIAL_BINARY_2D( )==""\n"
180R"==(po_D0, float, po_sum0, POST_OP_DATA_T, po_mb, 1, po_oc, 1); )==""\n"
181R"==(D0 = USE_FLOATS ? po_D0 : CONVERT_DATA_T(po_D0); )==""\n"
182R"==(POST_OP_DATA_T po_sum1 = DATA_TO_REF(sum1); )==""\n"
183R"==(float po_D1 = USE_FLOATS ? D1 : CONVERT_FLOAT_T(D1); )==""\n"
184R"==(APPLY_POST_OPS_SERIAL_BINARY_2D( )==""\n"
185R"==(po_D1, float, po_sum1, POST_OP_DATA_T, po_mb, 1, po_oc, 1); )==""\n"
186R"==(D1 = USE_FLOATS ? po_D1 : CONVERT_DATA_T(po_D1); )==""\n"
187R"==(} )==""\n"
188R"==(#else )==""\n"
189R"==(for (int idx = 0; idx < VECT_DT_N; ++idx) { )==""\n"
// Map the vector element index to a channel and a batch index.
190R"==(#if USE_MB_C_BLOCK )==""\n"
191R"==(int c_sub_block_id = idx % CHUNKS_PER_C_BLOCK; )==""\n"
192R"==(int mb_sub_block_id = idx / CHUNKS_PER_C_BLOCK; )==""\n"
193R"==(const int po_oc = c + c_sub_block_id * SUB_GROUP_SIZE + local_id; )==""\n"
194R"==(int po_mb = (mb0 + mb_sub_block_id) % MB; )==""\n"
195R"==(#else )==""\n"
196R"==(const int po_oc = c + idx * SUB_GROUP_SIZE + local_id; )==""\n"
197R"==(int po_mb = mb0; )==""\n"
198R"==(#endif )==""\n"
199R"==(if (po_mb >= MB || po_oc >= C_WO_PADDING) continue; )==""\n"
200R"==(float d0_i = USE_FLOATS ? D0[idx] : CONVERT_FLOAT_T(D0[idx]); )==""\n"
201R"==(POST_OP_DATA_T sum0_i = DATA_TO_REF(sum0[idx]); )==""\n"
202R"==(APPLY_POST_OPS_SERIAL_BINARY_2D( )==""\n"
203R"==(d0_i, float, sum0_i, POST_OP_DATA_T, po_mb, 1, po_oc, 1); )==""\n"
204R"==(D0[idx] = USE_FLOATS ? d0_i : CONVERT_DATA_T(d0_i); )==""\n"
205R"==(float d1_i = USE_FLOATS ? D1[idx] : CONVERT_FLOAT_T(D1[idx]); )==""\n"
206R"==(POST_OP_DATA_T sum1_i = DATA_TO_REF(sum1[idx]); )==""\n"
// The batch index used for the D1 element is advanced by VECT_DT_N.
207R"==(po_mb += VECT_DT_N; )==""\n"
208R"==(APPLY_POST_OPS_SERIAL_BINARY_2D( )==""\n"
209R"==(d1_i, float, sum1_i, POST_OP_DATA_T, po_mb, 1, po_oc, 1); )==""\n"
210R"==(D1[idx] = USE_FLOATS ? d1_i : CONVERT_DATA_T(d1_i); )==""\n"
211R"==(} )==""\n"
212R"==(#endif )==""\n"
// Convert the accumulators to the storage type and write both result vectors.
213R"==(#if USE_FLOATS )==""\n"
214R"==(VECT_DATA_T res0 = CONVERT_VECTOR_DATA_T(D0); )==""\n"
215R"==(VECT_DATA_T res1 = CONVERT_VECTOR_DATA_T(D1); )==""\n"
216R"==(#else )==""\n"
217R"==(VECT_DATA_T res0 = D0; )==""\n"
218R"==(VECT_DATA_T res1 = D1; )==""\n"
219R"==(#endif )==""\n"
220R"==(write_vect_c_block( )==""\n"
221R"==(0, &dst[dst_off0], c, dst_stride, dst_chunks_per_c_block, res0); )==""\n"
222R"==(#if UNROLL_MB )==""\n"
223R"==(write_vect_c_block( )==""\n"
224R"==(0, &dst[dst_off1], c, dst_stride, dst_chunks_per_c_block, res1); )==""\n"
225R"==(#else )==""\n"
226R"==(write_vect_c_block( )==""\n"
227R"==(1, &dst[dst_off0], c, dst_stride, dst_chunks_per_c_block, res1); )==""\n"
228R"==(#endif )==""\n"
// Max pooling in training also stores the recorded window offsets; the
// workspace offsets are the dst offsets.
229R"==(#if ALG_MAX && IS_TRAINING )==""\n"
230R"==(int ws_off0 = dst_off0; )==""\n"
231R"==(#if UNROLL_MB )==""\n"
232R"==(int ws_off1 = dst_off1; )==""\n"
233R"==(#endif )==""\n"
234R"==(write_vect_c_block_int( )==""\n"
235R"==(0, &ws[ws_off0], c, ws_stride, ws_chunks_per_c_block, WS0); )==""\n"
236R"==(#if UNROLL_MB )==""\n"
237R"==(write_vect_c_block_int( )==""\n"
238R"==(0, &ws[ws_off1], c, ws_stride, ws_chunks_per_c_block, WS1); )==""\n"
239R"==(#else )==""\n"
240R"==(write_vect_c_block_int( )==""\n"
241R"==(1, &ws[ws_off0], c, ws_stride, ws_chunks_per_c_block, WS1); )==""\n"
242R"==(#endif )==""\n"
243R"==(#endif )==""\n"
244R"==(} )==""\n"
245R"==(#endif )==""\n"
// Backward pooling kernel, present only when IS_BWD is set.
// Each work-item produces diff_src at input point (id, ih, iw) and channel
// offset c by summing the diff_dst vectors of every output point whose window
// reaches that input point. With UNROLL_MB four batch indices (mb0 to mb3,
// spaced MB / 4 apart) are processed together with vector index 0; without it
// vector indices 0 and 1 at batch index mb0 are processed. ws is read only
// for max pooling.
246R"==(#if IS_BWD )==""\n"
247R"==(KERNEL_ATTR )==""\n"
248R"==(__kernel void gen9_pooling_bwd(__global DATA_T *diff_src, __global int *ws, )==""\n"
249R"==(__global DATA_T *diff_dst) { )==""\n"
250R"==(const int mb0 = GWS_GET_MB(); )==""\n"
251R"==(#if UNROLL_MB )==""\n"
252R"==(const int mb1 = mb0 + MB / 4; )==""\n"
253R"==(const int mb2 = mb1 + MB / 4; )==""\n"
254R"==(const int mb3 = mb2 + MB / 4; )==""\n"
255R"==(#endif )==""\n"
256R"==(const int c = GWS_GET_C(); )==""\n"
257R"==(const int id = GWS_GET_ID(); )==""\n"
258R"==(const int ih = GWS_GET_IH(); )==""\n"
259R"==(const int iw = GWS_GET_IW(); )==""\n"
// Strides and chunk counts handed to the block helpers, selected by the
// layout macros USE_MB_C_BLOCK and USE_ONLY_C_BLOCK.
260R"==(#if USE_MB_C_BLOCK )==""\n"
261R"==(const int src_stride = (SRC_SB0 > 1) ? SRC_SB0 : SRC_S0; )==""\n"
262R"==(const int dst_stride = (DST_SB0 > 1) ? DST_SB0 : DST_S0; )==""\n"
263R"==(const int src_chunks_per_c_block = CHUNKS_PER_C_BLOCK; )==""\n"
264R"==(const int dst_chunks_per_c_block = CHUNKS_PER_C_BLOCK; )==""\n"
265R"==(#elif USE_ONLY_C_BLOCK )==""\n"
266R"==(const int src_stride = (SRC_B1 > 1) ? SRC_S1 : SUB_GROUP_SIZE; )==""\n"
267R"==(const int dst_stride = (DST_B1 > 1) ? DST_S1 : SUB_GROUP_SIZE; )==""\n"
268R"==(const int src_chunks_per_c_block )==""\n"
269R"==(= (SRC_B1 > 1) ? (SRC_B1 / SUB_GROUP_SIZE) : 1; )==""\n"
270R"==(const int dst_chunks_per_c_block )==""\n"
271R"==(= (DST_B1 > 1) ? (DST_B1 / SUB_GROUP_SIZE) : 1; )==""\n"
272R"==(#endif )==""\n"
// The workspace is addressed with the same stride and chunking as diff_dst.
273R"==(const int ws_stride = dst_stride; )==""\n"
274R"==(const int ws_chunks_per_c_block = dst_chunks_per_c_block; )==""\n"
// Float accumulators for the gradient sums; S2 and S3 exist only with
// UNROLL_MB.
275R"==(VECT_FLOAT_T S0 = 0, S1 = 0; )==""\n"
276R"==(#if UNROLL_MB )==""\n"
277R"==(VECT_FLOAT_T S2 = 0, S3 = 0; )==""\n"
278R"==(#endif )==""\n"
// For each window offset, recover the output coordinate. The offset is
// skipped when the stride does not divide the distance evenly or when the
// resulting coordinate lies outside the output.
279R"==(for (int kd = 0; kd < KD; kd++) { )==""\n"
280R"==(int od = (id + PD - kd); )==""\n"
281R"==(if (od % SD != 0) continue; )==""\n"
282R"==(od /= SD; )==""\n"
283R"==(if (od < 0 || od >= OD) continue; )==""\n"
284R"==(for (int kh = 0; kh < KH; kh++) { )==""\n"
285R"==(int oh = (ih + PH - kh); )==""\n"
286R"==(if (oh % SH != 0) continue; )==""\n"
287R"==(oh /= SH; )==""\n"
288R"==(if (oh < 0 || oh >= OH) continue; )==""\n"
289R"==(for (int kw = 0; kw < KW; kw++) { )==""\n"
290R"==(int ow = (iw + PW - kw); )==""\n"
291R"==(if (ow % SW != 0) continue; )==""\n"
292R"==(ow /= SW; )==""\n"
293R"==(if (ow < 0 || ow >= OW) continue; )==""\n"
294R"==(const int dst_off0 = DST_OFF(mb0, c, od, oh, ow); )==""\n"
295R"==(#if UNROLL_MB )==""\n"
296R"==(const int dst_off1 = DST_OFF(mb1, c, od, oh, ow); )==""\n"
297R"==(const int dst_off2 = DST_OFF(mb2, c, od, oh, ow); )==""\n"
298R"==(const int dst_off3 = DST_OFF(mb3, c, od, oh, ow); )==""\n"
299R"==(#endif )==""\n"
// Load the incoming gradient vectors as float.
300R"==(VECT_FLOAT_T D0 = CONVERT_VECT_FLOAT_T( )==""\n"
301R"==(read_vect_c_block(0, &diff_dst[dst_off0], c, dst_stride, )==""\n"
302R"==(dst_chunks_per_c_block)); )==""\n"
303R"==(#if UNROLL_MB )==""\n"
304R"==(VECT_FLOAT_T D1 = CONVERT_VECT_FLOAT_T( )==""\n"
305R"==(read_vect_c_block(0, &diff_dst[dst_off1], c, dst_stride, )==""\n"
306R"==(dst_chunks_per_c_block)); )==""\n"
307R"==(VECT_FLOAT_T D2 = CONVERT_VECT_FLOAT_T( )==""\n"
308R"==(read_vect_c_block(0, &diff_dst[dst_off2], c, dst_stride, )==""\n"
309R"==(dst_chunks_per_c_block)); )==""\n"
310R"==(VECT_FLOAT_T D3 = CONVERT_VECT_FLOAT_T( )==""\n"
311R"==(read_vect_c_block(0, &diff_dst[dst_off3], c, dst_stride, )==""\n"
312R"==(dst_chunks_per_c_block)); )==""\n"
313R"==(#else )==""\n"
314R"==(VECT_FLOAT_T D1 = CONVERT_VECT_FLOAT_T( )==""\n"
315R"==(read_vect_c_block(1, &diff_dst[dst_off0], c, dst_stride, )==""\n"
316R"==(dst_chunks_per_c_block)); )==""\n"
317R"==(#endif )==""\n"
// Max pooling: read the stored window offsets and replace the gradient by
// zero in every lane where the stored offset minus the flattened offset of
// this window position compares as not equal to zero.
318R"==(#if ALG_MAX )==""\n"
319R"==(VECT_INT_T WS0 = read_vect_c_block_int( )==""\n"
320R"==(0, &ws[dst_off0], c, ws_stride, ws_chunks_per_c_block); )==""\n"
321R"==(#if UNROLL_MB )==""\n"
322R"==(VECT_INT_T WS1 = read_vect_c_block_int( )==""\n"
323R"==(0, &ws[dst_off1], c, ws_stride, ws_chunks_per_c_block); )==""\n"
324R"==(VECT_INT_T WS2 = read_vect_c_block_int( )==""\n"
325R"==(0, &ws[dst_off2], c, ws_stride, ws_chunks_per_c_block); )==""\n"
326R"==(VECT_INT_T WS3 = read_vect_c_block_int( )==""\n"
327R"==(0, &ws[dst_off3], c, ws_stride, ws_chunks_per_c_block); )==""\n"
328R"==(#else )==""\n"
329R"==(VECT_INT_T WS1 = read_vect_c_block_int( )==""\n"
330R"==(1, &ws[dst_off0], c, ws_stride, ws_chunks_per_c_block); )==""\n"
331R"==(#endif )==""\n"
332R"==(VECT_INT_T CMP0 = isnotequal( )==""\n"
333R"==(AS_VECT_FLOAT_T(WS0 - kd * KH * KW - kh * KW - kw), )==""\n"
334R"==((VECT_FLOAT_T)0); )==""\n"
335R"==(D0 = select(D0, (VECT_FLOAT_T)0, CMP0); )==""\n"
336R"==(VECT_INT_T CMP1 = isnotequal( )==""\n"
337R"==(AS_VECT_FLOAT_T(WS1 - kd * KH * KW - kh * KW - kw), )==""\n"
338R"==((VECT_FLOAT_T)0); )==""\n"
339R"==(D1 = select(D1, (VECT_FLOAT_T)0, CMP1); )==""\n"
340R"==(#if UNROLL_MB )==""\n"
341R"==(VECT_INT_T CMP2 = isnotequal( )==""\n"
342R"==(AS_VECT_FLOAT_T(WS2 - kd * KH * KW - kh * KW - kw), )==""\n"
343R"==((VECT_FLOAT_T)0); )==""\n"
344R"==(D2 = select(D2, (VECT_FLOAT_T)0, CMP2); )==""\n"
345R"==(VECT_INT_T CMP3 = isnotequal( )==""\n"
346R"==(AS_VECT_FLOAT_T(WS3 - kd * KH * KW - kh * KW - kw), )==""\n"
347R"==((VECT_FLOAT_T)0); )==""\n"
348R"==(D3 = select(D3, (VECT_FLOAT_T)0, CMP3); )==""\n"
349R"==(#endif )==""\n"
350R"==(#endif )==""\n"
// ALG_AVG_NP: divide each gradient by the number of window positions that
// fall inside the input, computed from the window extent clipped to the
// input for this window offset.
351R"==(#if ALG_AVG_NP )==""\n"
352R"==(const int id_start = max(id - kd, 0); )==""\n"
353R"==(const int ih_start = max(ih - kh, 0); )==""\n"
354R"==(const int iw_start = max(iw - kw, 0); )==""\n"
355R"==(const int id_end = min(id - kd + KD, ID); )==""\n"
356R"==(const int ih_end = min(ih - kh + KH, IH); )==""\n"
357R"==(const int iw_end = min(iw - kw + KW, IW); )==""\n"
358R"==(const int num_summands = (ih_end - ih_start) )==""\n"
359R"==(* (iw_end - iw_start) * (id_end - id_start); )==""\n"
360R"==(D0 /= num_summands; )==""\n"
361R"==(D1 /= num_summands; )==""\n"
// Fix: with UNROLL_MB the gradients D2 and D3 are accumulated below as well,
// so they need the same normalization as D0 and D1. Previously they were
// summed undivided.
R"==(#if UNROLL_MB )==""\n"
R"==(D2 /= num_summands; )==""\n"
R"==(D3 /= num_summands; )==""\n"
R"==(#endif )==""\n"
362R"==(#endif )==""\n"
363R"==(S0 += D0; )==""\n"
364R"==(S1 += D1; )==""\n"
365R"==(#if UNROLL_MB )==""\n"
366R"==(S2 += D2; )==""\n"
367R"==(S3 += D3; )==""\n"
368R"==(#endif )==""\n"
369R"==(} )==""\n"
370R"==(} )==""\n"
371R"==(} )==""\n"
// ALG_AVG_P: divide the accumulated sums by the full window size.
372R"==(#if ALG_AVG_P )==""\n"
373R"==(S0 /= KD * KH * KW; )==""\n"
374R"==(S1 /= KD * KH * KW; )==""\n"
375R"==(#if UNROLL_MB )==""\n"
376R"==(S2 /= KD * KH * KW; )==""\n"
377R"==(S3 /= KD * KH * KW; )==""\n"
378R"==(#endif )==""\n"
379R"==(#endif )==""\n"
// Convert the sums to the storage type and write them to diff_src.
380R"==(int src_off0 = SRC_OFF(mb0, c, id, ih, iw); )==""\n"
381R"==(#if UNROLL_MB )==""\n"
382R"==(int src_off1 = SRC_OFF(mb1, c, id, ih, iw); )==""\n"
383R"==(int src_off2 = SRC_OFF(mb2, c, id, ih, iw); )==""\n"
384R"==(int src_off3 = SRC_OFF(mb3, c, id, ih, iw); )==""\n"
385R"==(#endif )==""\n"
386R"==(write_vect_c_block(0, &diff_src[src_off0], c, src_stride, )==""\n"
387R"==(src_chunks_per_c_block, CONVERT_VECTOR_DATA_T(S0)); )==""\n"
388R"==(#if UNROLL_MB )==""\n"
389R"==(write_vect_c_block(0, &diff_src[src_off1], c, src_stride, )==""\n"
390R"==(src_chunks_per_c_block, CONVERT_VECTOR_DATA_T(S1)); )==""\n"
391R"==(write_vect_c_block(0, &diff_src[src_off2], c, src_stride, )==""\n"
392R"==(src_chunks_per_c_block, CONVERT_VECTOR_DATA_T(S2)); )==""\n"
393R"==(write_vect_c_block(0, &diff_src[src_off3], c, src_stride, )==""\n"
394R"==(src_chunks_per_c_block, CONVERT_VECTOR_DATA_T(S3)); )==""\n"
395R"==(#else )==""\n"
396R"==(write_vect_c_block(1, &diff_src[src_off0], c, src_stride, )==""\n"
397R"==(src_chunks_per_c_block, CONVERT_VECTOR_DATA_T(S1)); )==""\n"
398R"==(#endif )==""\n"
399R"==(} )==""\n"
400R"==(#endif )==""\n"
// read_c_block: returns one DATA_T per work-item for the channel chunk that
// starts at channel c.
// When C_W_PADDING is not a multiple of SUB_GROUP_SIZE each lane loads
// ptr[local_id], and lanes at or beyond C_WO_PADDING - c return 0 instead.
// Otherwise the value comes from a sub-group BLOCK_READ at ptr.
401R"==(inline DATA_T read_c_block(const __global DATA_T *ptr, int c) { )==""\n"
402R"==(#if C_W_PADDING % SUB_GROUP_SIZE != 0 )==""\n"
403R"==(int local_id = get_sub_group_local_id(); )==""\n"
404R"==(int tail = C_WO_PADDING - c; )==""\n"
405R"==(return (local_id < tail) ? ptr[local_id] : 0; )==""\n"
406R"==(#else )==""\n"
407R"==(return AS_DATA_T(BLOCK_READ((const __global BLOCK_DATA_T *)ptr)); )==""\n"
408R"==(#endif )==""\n"
409R"==(} )==""\n"
// CALC_VECT_LEN: statement-expression macro giving the number of vector
// elements to read. It yields C_WO_PADDING / SUB_GROUP_SIZE + 1 when
// USE_ONLY_C_BLOCK is 1 and VECT_DT_N is larger than that value, and
// VECT_DT_N in every other case.
410R"==(#define CALC_VECT_LEN() \ )==""\n"
411R"==(({ \ )==""\n"
412R"==(int size; \ )==""\n"
413R"==(if (USE_ONLY_C_BLOCK == 1 \ )==""\n"
414R"==(&& VECT_DT_N > C_WO_PADDING / SUB_GROUP_SIZE + 1) \ )==""\n"
415R"==(size = C_WO_PADDING / SUB_GROUP_SIZE + 1; \ )==""\n"
416R"==(else \ )==""\n"
417R"==(size = VECT_DT_N; \ )==""\n"
418R"==(size; \ )==""\n"
419R"==(}) )==""\n"
// read_vect_c_block: reads the VECT_DATA_T with vector index idx from ptr.
// Returns 0 when idx is at or beyond NVECT.
// Fast path: if blocks_stride equals chunks_per_block * SUB_GROUP_SIZE and
// C_WO_PADDING is a multiple of that product, one VECT_BLOCK_READ at element
// offset idx * VECT_DT_N * SUB_GROUP_SIZE is used.
// Slow path: the first CALC_VECT_LEN() elements are read one at a time with
// read_c_block and the remaining elements up to VECT_DT_N are set to 0.
420R"==(inline VECT_DATA_T read_vect_c_block(int idx, const __global DATA_T *ptr, int c, )==""\n"
421R"==(int blocks_stride, int chunks_per_block) { )==""\n"
422R"==(if (idx >= NVECT) return 0; )==""\n"
423R"==(if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE) )==""\n"
424R"==(&& (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) { )==""\n"
425R"==(return AS_VECT_DATA_T(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)ptr )==""\n"
426R"==(+ idx * VECT_DT_N * SUB_GROUP_SIZE)); )==""\n"
427R"==(} else { )==""\n"
428R"==(VECT_DATA_T ret; )==""\n"
429R"==(for (int i = 0; i < CALC_VECT_LEN(); i++) { )==""\n"
// offset_index is the chunk number counted from ptr; it is split into a chunk
// position inside a block and a block number to form the pointer offset.
430R"==(const int offset_index = (idx * VECT_DT_N + i); )==""\n"
431R"==(const int local_c_block_index = offset_index % chunks_per_block; )==""\n"
432R"==(const int global_c_block_index = offset_index / chunks_per_block; )==""\n"
433R"==(const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE )==""\n"
434R"==(+ global_c_block_index * blocks_stride; )==""\n"
// Channel offset passed on for the tail check in read_c_block.
435R"==(const int c_off )==""\n"
436R"==(= (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE )==""\n"
437R"==(: local_c_block_index * SUB_GROUP_SIZE); )==""\n"
438R"==(#if VECT_DT_N == 1 )==""\n"
439R"==(ret = read_c_block(ptr + ptr_offset, c + c_off); )==""\n"
440R"==(#else )==""\n"
441R"==(ret[i] = read_c_block(ptr + ptr_offset, c + c_off); )==""\n"
442R"==(#endif )==""\n"
443R"==(} )==""\n"
444R"==(#if VECT_DT_N > 1 )==""\n"
445R"==(for (int i = CALC_VECT_LEN(); i < VECT_DT_N; ++i) { )==""\n"
446R"==(ret[i] = 0; )==""\n"
447R"==(} )==""\n"
448R"==(#endif )==""\n"
449R"==(return ret; )==""\n"
450R"==(} )==""\n"
451R"==(} )==""\n"
// read_c_block_int: int counterpart of read_c_block.
// When C_W_PADDING is not a multiple of SUB_GROUP_SIZE each lane loads
// ptr[local_id], and lanes at or beyond C_WO_PADDING - c return 0 instead.
// Otherwise the value comes from intel_sub_group_block_read on ptr viewed as
// uint, reinterpreted as int.
452R"==(inline int read_c_block_int(const __global int *ptr, int c) { )==""\n"
453R"==(#if C_W_PADDING % SUB_GROUP_SIZE != 0 )==""\n"
454R"==(int local_id = get_sub_group_local_id(); )==""\n"
455R"==(int tail = C_WO_PADDING - c; )==""\n"
456R"==(return (local_id < tail) ? ptr[local_id] : 0; )==""\n"
457R"==(#else )==""\n"
458R"==(return as_int(intel_sub_group_block_read((const __global uint *)ptr)); )==""\n"
459R"==(#endif )==""\n"
460R"==(} )==""\n"
// read_vect_c_block_int: reads the VECT_INT_T with vector index idx from ptr.
// Returns 0 when idx is at or beyond NVECT.
// Fast path: if blocks_stride equals chunks_per_block * SUB_GROUP_SIZE and
// C_WO_PADDING is a multiple of that product, one VECT_UINT_READ at element
// offset idx * VECT_DT_N * SUB_GROUP_SIZE is used.
// Slow path: all VECT_DT_N elements are read one at a time with
// read_c_block_int. Unlike read_vect_c_block, this loop is not bounded by
// CALC_VECT_LEN().
461R"==(inline VECT_INT_T read_vect_c_block_int(int idx, const __global int *ptr, int c, )==""\n"
462R"==(int blocks_stride, int chunks_per_block) { )==""\n"
463R"==(if (idx >= NVECT) return 0; )==""\n"
464R"==(if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE) )==""\n"
465R"==(&& (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) { )==""\n"
466R"==(return AS_VECT_INT_T(VECT_UINT_READ( )==""\n"
467R"==((const __global uint *)ptr + idx * VECT_DT_N * SUB_GROUP_SIZE)); )==""\n"
468R"==(} else { )==""\n"
469R"==(VECT_INT_T ret; )==""\n"
470R"==(for (int i = 0; i < VECT_DT_N; i++) { )==""\n"
// offset_index is the chunk number counted from ptr; it is split into a chunk
// position inside a block and a block number to form the pointer offset.
471R"==(const int offset_index = (idx * VECT_DT_N + i); )==""\n"
472R"==(const int local_c_block_index = offset_index % chunks_per_block; )==""\n"
473R"==(const int global_c_block_index = offset_index / chunks_per_block; )==""\n"
474R"==(const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE )==""\n"
475R"==(+ global_c_block_index * blocks_stride; )==""\n"
// Channel offset passed on for the tail check in read_c_block_int.
476R"==(const int c_off )==""\n"
477R"==(= (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE )==""\n"
478R"==(: local_c_block_index * SUB_GROUP_SIZE); )==""\n"
479R"==(#if VECT_DT_N == 1 )==""\n"
480R"==(ret = read_c_block_int(ptr + ptr_offset, c + c_off); )==""\n"
481R"==(#else )==""\n"
482R"==(ret[i] = read_c_block_int(ptr + ptr_offset, c + c_off); )==""\n"
483R"==(#endif )==""\n"
484R"==(} )==""\n"
485R"==(return ret; )==""\n"
486R"==(} )==""\n"
487R"==(} )==""\n"
// write_c_block: stores one DATA_T per work-item for the channel chunk that
// starts at channel c.
// When C_W_PADDING is not a multiple of SUB_GROUP_SIZE only lanes below
// C_WO_PADDING - c store, each to ptr[local_id].
// Otherwise a sub-group BLOCK_WRITE is used: a chunk starting at or beyond
// C_WO_PADDING is written as zeros, and when C_WO_PADDING is not a multiple
// of SUB_GROUP_SIZE the lanes in [C_WO_PADDING - c, C_W_PADDING - c) have
// their value replaced by 0 before the write.
488R"==(inline void write_c_block(__global DATA_T *ptr, int c, DATA_T value) { )==""\n"
489R"==(#if C_W_PADDING % SUB_GROUP_SIZE != 0 )==""\n"
490R"==(int local_id = get_sub_group_local_id(); )==""\n"
491R"==(int tail = C_WO_PADDING - c; )==""\n"
492R"==(if (local_id < tail) ptr[local_id] = value; )==""\n"
493R"==(#else )==""\n"
494R"==(#if C_WO_PADDING % SUB_GROUP_SIZE != 0 )==""\n"
495R"==(int local_id = get_sub_group_local_id(); )==""\n"
496R"==(if (local_id >= C_WO_PADDING - c && local_id < C_W_PADDING - c) value = 0; )==""\n"
497R"==(#endif )==""\n"
498R"==(if (c >= C_WO_PADDING) { )==""\n"
499R"==(BLOCK_WRITE((__global BLOCK_DATA_T *)ptr, )==""\n"
500R"==(AS_BLOCK_DATA_T(CONVERT_DATA_T(DATA_ZERO))); )==""\n"
501R"==(return; )==""\n"
502R"==(} )==""\n"
503R"==(BLOCK_WRITE((__global BLOCK_DATA_T *)ptr, AS_BLOCK_DATA_T(value)); )==""\n"
504R"==(#endif )==""\n"
505R"==(} )==""\n"
// write_vect_c_block: stores block as the VECT_DATA_T with vector index idx
// at ptr. Does nothing when idx is at or beyond NVECT.
// Fast path: if blocks_stride equals chunks_per_block * SUB_GROUP_SIZE and
// C_WO_PADDING is a multiple of that product, one VECT_BLOCK_WRITE at element
// offset idx * VECT_DT_N * SUB_GROUP_SIZE is used.
// Slow path: each of the VECT_DT_N elements is stored with write_c_block.
506R"==(inline void write_vect_c_block(int idx, __global DATA_T *ptr, int c, )==""\n"
507R"==(int blocks_stride, int chunks_per_block, VECT_DATA_T block) { )==""\n"
508R"==(if (idx >= NVECT) return; )==""\n"
509R"==(if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE) )==""\n"
510R"==(&& (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) { )==""\n"
511R"==(VECT_BLOCK_WRITE( )==""\n"
512R"==((__global BLOCK_DATA_T *)ptr + idx * VECT_DT_N * SUB_GROUP_SIZE, )==""\n"
513R"==(AS_VECT_BLOCK_DATA_T(block)); )==""\n"
514R"==(} else { )==""\n"
515R"==(for (int i = 0; i < VECT_DT_N; i++) { )==""\n"
// offset_index is the chunk number counted from ptr; it is split into a chunk
// position inside a block and a block number to form the pointer offset.
516R"==(const int offset_index = (idx * VECT_DT_N + i); )==""\n"
517R"==(const int local_c_block_index = offset_index % chunks_per_block; )==""\n"
518R"==(const int global_c_block_index = offset_index / chunks_per_block; )==""\n"
519R"==(const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE )==""\n"
520R"==(+ global_c_block_index * blocks_stride; )==""\n"
// Channel offset passed on for the tail and padding checks in write_c_block.
521R"==(const int c_off )==""\n"
522R"==(= (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE )==""\n"
523R"==(: local_c_block_index * SUB_GROUP_SIZE); )==""\n"
524R"==(#if VECT_DT_N == 1 )==""\n"
525R"==(write_c_block(ptr + ptr_offset, c + c_off, block); )==""\n"
526R"==(#else )==""\n"
527R"==(write_c_block(ptr + ptr_offset, c + c_off, block[i]); )==""\n"
528R"==(#endif )==""\n"
529R"==(} )==""\n"
530R"==(} )==""\n"
531R"==(} )==""\n"
// write_c_block_int: int counterpart of write_c_block.
// When C_WO_PADDING is not a multiple of SUB_GROUP_SIZE: lanes below
// C_WO_PADDING - c store value to ptr[local_id], lanes below C_W_PADDING - c
// store 0, and the remaining lanes store nothing.
// Otherwise intel_sub_group_block_write is used on ptr viewed as uint: a
// chunk starting at or beyond C_WO_PADDING is written as 0, any other chunk
// receives value.
// Note that the outer preprocessor test here is on C_WO_PADDING, whereas
// write_c_block tests C_W_PADDING first.
532R"==(inline void write_c_block_int(__global int *ptr, int c, int value) { )==""\n"
533R"==(#if C_WO_PADDING % SUB_GROUP_SIZE != 0 )==""\n"
534R"==(int local_id = get_sub_group_local_id(); )==""\n"
535R"==(int tail = C_WO_PADDING - c; )==""\n"
536R"==(if (local_id < tail) )==""\n"
537R"==(ptr[local_id] = value; )==""\n"
538R"==(else if (local_id < C_W_PADDING - c) { )==""\n"
539R"==(ptr[local_id] = 0; )==""\n"
540R"==(} else )==""\n"
541R"==(return; )==""\n"
542R"==(#else )==""\n"
543R"==(if (c >= C_WO_PADDING) { )==""\n"
544R"==(intel_sub_group_block_write((__global uint *)ptr, 0); )==""\n"
545R"==(return; )==""\n"
546R"==(} )==""\n"
547R"==(intel_sub_group_block_write((__global uint *)ptr, as_uint(value)); )==""\n"
548R"==(#endif )==""\n"
549R"==(} )==""\n"
// write_vect_c_block_int: stores block as the VECT_INT_T with vector index
// idx at ptr. Does nothing when idx is at or beyond NVECT.
// Fast path: if blocks_stride equals chunks_per_block * SUB_GROUP_SIZE and
// C_WO_PADDING is a multiple of that product, one VECT_UINT_WRITE at element
// offset idx * VECT_DT_N * SUB_GROUP_SIZE is used.
// Slow path: each of the VECT_DT_N elements is stored with write_c_block_int.
550R"==(inline void write_vect_c_block_int(int idx, __global int *ptr, int c, )==""\n"
551R"==(int blocks_stride, int chunks_per_block, VECT_INT_T block) { )==""\n"
552R"==(if (idx >= NVECT) return; )==""\n"
553R"==(if ((blocks_stride == chunks_per_block * SUB_GROUP_SIZE) )==""\n"
554R"==(&& (C_WO_PADDING % (chunks_per_block * SUB_GROUP_SIZE) == 0)) { )==""\n"
555R"==(VECT_UINT_WRITE((__global uint *)ptr + idx * VECT_DT_N * SUB_GROUP_SIZE, )==""\n"
556R"==(AS_VECT_UINT_T(block)); )==""\n"
557R"==(} else { )==""\n"
558R"==(for (int i = 0; i < VECT_DT_N; i++) { )==""\n"
// offset_index is the chunk number counted from ptr; it is split into a chunk
// position inside a block and a block number to form the pointer offset.
559R"==(const int offset_index = (idx * VECT_DT_N + i); )==""\n"
560R"==(const int local_c_block_index = offset_index % chunks_per_block; )==""\n"
561R"==(const int global_c_block_index = offset_index / chunks_per_block; )==""\n"
562R"==(const int ptr_offset = local_c_block_index * SUB_GROUP_SIZE )==""\n"
563R"==(+ global_c_block_index * blocks_stride; )==""\n"
// Channel offset passed on for the tail and padding checks in
// write_c_block_int.
564R"==(const int c_off )==""\n"
565R"==(= (USE_ONLY_C_BLOCK ? offset_index * SUB_GROUP_SIZE )==""\n"
566R"==(: local_c_block_index * SUB_GROUP_SIZE); )==""\n"
567R"==(#if VECT_DT_N == 1 )==""\n"
568R"==(write_c_block_int(ptr + ptr_offset, c + c_off, block); )==""\n"
569R"==(#else )==""\n"
570R"==(write_c_block_int(ptr + ptr_offset, c + c_off, block[i]); )==""\n"
571R"==(#endif )==""\n"
572R"==(} )==""\n"
573R"==(} )==""\n"
574R"==(} )==""\n"
// Empty final piece; terminates the kernel source string.
575R"==()==";
576}
577}
578}
579}