1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
5const char *gen9_conv_fwd_data_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2020-2021 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
22R"==(#include "gpu/ocl/offsets.h" )==""\n"
// Kernel text: BASE_OC_TAIL = SUB_GROUP_SIZE - (OC - OC_WO_PADDING).
// It feeds the w_oc_tail / loc_oc_tail computations in the DST_NCHW branches
// of read_dst_block and write_dst_block.
23R"==(#define BASE_OC_TAIL (SUB_GROUP_SIZE - (OC - OC_WO_PADDING)) )==""\n"
// Kernel text: _BLOCK_READ_DST(v, n_c, ptr)
// Loads n_c destination elements from ptr and stores them, converted to the
// compute type, through v.
//   n_c == 8 / 4 / 2 : vload8 / vload4 / vload2 followed by CONVERT_DATA8_T /
//                      CONVERT_DATA4_T / CONVERT_DATA2_T
//   n_c == 1         : scalar dereference followed by CONVERT_DATA_T
// Any other n_c matches no branch, so nothing is loaded.
// The expansion is a bare sequence of if statements (no do/while wrapper) and
// already ends with a semicolon, so it must not be used as the body of an
// unbraced if/else.
24R"==(#define _BLOCK_READ_DST(v, n_c, ptr) \ )==""\n"
25R"==(if (n_c == 8) \ )==""\n"
26R"==((*(DATA8_T *)(v)) = CONVERT_DATA8_T( \ )==""\n"
27R"==(vload8(0, (const __global DST_DATA_T *)(ptr))); \ )==""\n"
28R"==(if (n_c == 4) \ )==""\n"
29R"==((*(DATA4_T *)(v)) = CONVERT_DATA4_T( \ )==""\n"
30R"==(vload4(0, (const __global DST_DATA_T *)(ptr))); \ )==""\n"
31R"==(if (n_c == 2) \ )==""\n"
32R"==((*(DATA2_T *)(v)) = CONVERT_DATA2_T( \ )==""\n"
33R"==(vload2(0, (const __global DST_DATA_T *)(ptr))); \ )==""\n"
34R"==(if (n_c == 1) \ )==""\n"
35R"==((*(DATA_T *)(v)) = CONVERT_DATA_T(*(__global DST_DATA_T *)(ptr)); )==""\n"
// Kernel text: _BLOCK_READ8 / _BLOCK_READ4 / _BLOCK_READ2 / _BLOCK_READ
// Thin wrappers: cast ptr to const __global BLOCK_DATA_T *, call the matching
// BLOCK_READ8 / BLOCK_READ4 / BLOCK_READ2 / BLOCK_READ, then reinterpret the
// result as DATA8_T / DATA4_T / DATA2_T / DATA_T with AS_DATA*_T.
// BLOCK_READ*, BLOCK_DATA_T and AS_DATA*_T are not defined in this section
// (presumably they come from the included gpu/ocl/ocl_types.h; confirm).
36R"==(#define _BLOCK_READ8(ptr) \ )==""\n"
37R"==(AS_DATA8_T(BLOCK_READ8((const __global BLOCK_DATA_T *)(ptr))) )==""\n"
38R"==(#define _BLOCK_READ4(ptr) \ )==""\n"
39R"==(AS_DATA4_T(BLOCK_READ4((const __global BLOCK_DATA_T *)(ptr))) )==""\n"
40R"==(#define _BLOCK_READ2(ptr) \ )==""\n"
41R"==(AS_DATA2_T(BLOCK_READ2((const __global BLOCK_DATA_T *)(ptr))) )==""\n"
42R"==(#define _BLOCK_READ(ptr) \ )==""\n"
43R"==(AS_DATA_T(BLOCK_READ((const __global BLOCK_DATA_T *)(ptr))) )==""\n"
// Kernel text: _BLOCK_WRITE8 / _BLOCK_WRITE4 / _BLOCK_WRITE2 / _BLOCK_WRITE
// Three mutually exclusive definitions are selected by the preprocessor:
//   1. DST_DT_S8 defined and DST_NCHW non-zero: the value is converted with
//      CONVERT_DST_DATA*_T and stored with vstore8/4/2 (or a plain scalar
//      store for the single-element form).
//   2. DST_DT_S8 defined, DST_NCHW zero: the converted value is stored with
//      BLOCK_WRITE_DST8/4/2 or BLOCK_WRITE_DST.
//   3. DST_DT_S8 not defined: the value is reinterpreted with
//      AS_BLOCK_DATA*_T and stored with BLOCK_WRITE8/4/2 or BLOCK_WRITE.
// Note that DST_DT_S8 is tested with #ifdef (definedness) whereas DST_NCHW is
// tested with #if (value).
// Every expansion already ends with a semicolon.
44R"==(#ifdef DST_DT_S8 )==""\n"
45R"==(#if DST_NCHW )==""\n"
46R"==(#define _BLOCK_WRITE8(ptr, v) \ )==""\n"
47R"==(vstore8(CONVERT_DST_DATA8_T(v), 0, (__global DST_DATA_T *)(ptr)); )==""\n"
48R"==(#define _BLOCK_WRITE4(ptr, v) \ )==""\n"
49R"==(vstore4(CONVERT_DST_DATA4_T(v), 0, (__global DST_DATA_T *)(ptr)); )==""\n"
50R"==(#define _BLOCK_WRITE2(ptr, v) \ )==""\n"
51R"==(vstore2(CONVERT_DST_DATA2_T(v), 0, (__global DST_DATA_T *)(ptr)); )==""\n"
52R"==(#define _BLOCK_WRITE(ptr, v) \ )==""\n"
53R"==(*(__global DST_DATA_T *)(ptr) = CONVERT_DST_DATA_T(v); )==""\n"
54R"==(#else )==""\n"
55R"==(#define _BLOCK_WRITE8(ptr, v) \ )==""\n"
56R"==(BLOCK_WRITE_DST8((__global DST_DATA_T *)(ptr), CONVERT_DST_DATA8_T(v)); )==""\n"
57R"==(#define _BLOCK_WRITE4(ptr, v) \ )==""\n"
58R"==(BLOCK_WRITE_DST4((__global DST_DATA_T *)(ptr), CONVERT_DST_DATA4_T(v)); )==""\n"
59R"==(#define _BLOCK_WRITE2(ptr, v) \ )==""\n"
60R"==(BLOCK_WRITE_DST2((__global DST_DATA_T *)(ptr), CONVERT_DST_DATA2_T(v)); )==""\n"
61R"==(#define _BLOCK_WRITE(ptr, v) \ )==""\n"
62R"==(BLOCK_WRITE_DST((__global DST_DATA_T *)(ptr), CONVERT_DST_DATA_T(v)); )==""\n"
63R"==(#endif )==""\n"
64R"==(#else )==""\n"
65R"==(#define _BLOCK_WRITE8(ptr, v) \ )==""\n"
66R"==(BLOCK_WRITE8((__global BLOCK_DATA_T *)(ptr), AS_BLOCK_DATA8_T(v)); )==""\n"
67R"==(#define _BLOCK_WRITE4(ptr, v) \ )==""\n"
68R"==(BLOCK_WRITE4((__global BLOCK_DATA_T *)(ptr), AS_BLOCK_DATA4_T(v)); )==""\n"
69R"==(#define _BLOCK_WRITE2(ptr, v) \ )==""\n"
70R"==(BLOCK_WRITE2((__global BLOCK_DATA_T *)(ptr), AS_BLOCK_DATA2_T(v)); )==""\n"
71R"==(#define _BLOCK_WRITE(ptr, v) \ )==""\n"
72R"==(BLOCK_WRITE((__global BLOCK_DATA_T *)(ptr), AS_BLOCK_DATA_T(v)); )==""\n"
73R"==(#endif )==""\n"
// Kernel text: derived configuration macros. All inputs (OD, IC, KW, SW, ...)
// are expected to be supplied as build-time definitions from outside this
// section.
//   IS_3D          : OD > 1
//   IS_1STCONV     : IC == 3
//   HAS_PAD_D/H/W  : padding on that axis is positive, or
//                    O * S - P + (K - 1) * (1 + D) reaches the input extent
//   ENABLE_SRC_BUF : MB_BLOCK == 1 and KW >= 3; selects the path that first
//                    stages a source row in a private buffer
//   W_VEC / C_VEC  : W_VEC needs IS_1STCONV, ENABLE_SRC_BUF, KW >= 5 and
//                    SW <= 3; C_VEC is its negation
//   IC_INNER       : min(IC, 16) under C_VEC, otherwise 1
//   IC_OUTER       : IC_BLOCK / IC_INNER
//   OC_OUTER       : OC_BLOCK / 16
//   IW_BLOCK       : SW * (OW_BLOCK - 1) + (KW - 1) * (1 + DW) + 1, i.e. the
//                    number of input columns spanned by OW_BLOCK outputs
//   OW_INNER       : 1 under C_VEC, otherwise 16; IW_INNER is the same value
//   OW_OUTER/IW_OUTER : OW_BLOCK / IW_BLOCK divided by the inner size,
//                    rounded up
//   C_SIZE         : MB_BLOCK * OC_OUTER * OW_BLOCK, the accumulator length
// The #error guard rejects OW_BLOCK >= 32 because unrolled_read and
// unrolled_write only decompose their count into 16, 8, 4, 2 and 1.
74R"==(#define IS_3D (OD > 1) )==""\n"
75R"==(#define IS_1STCONV (IC == 3) )==""\n"
76R"==(#define HAS_PAD_D (PD > 0 || OD * SD - PD + (KD - 1) * (1 + DD) >= ID) )==""\n"
77R"==(#define HAS_PAD_H (PH > 0 || OH * SH - PH + (KH - 1) * (1 + DH) >= IH) )==""\n"
78R"==(#define HAS_PAD_W (PW > 0 || OW * SW - PW + (KW - 1) * (1 + DW) >= IW) )==""\n"
79R"==(#define ENABLE_SRC_BUF (MB_BLOCK == 1 && KW >= 3) )==""\n"
80R"==(#define W_VEC (IS_1STCONV && ENABLE_SRC_BUF && KW >= 5 && SW <= 3) )==""\n"
81R"==(#define C_VEC (!W_VEC) )==""\n"
82R"==(#define IC_INNER (C_VEC ? (IC < 16 ? IC : 16) : 1) )==""\n"
83R"==(#define IC_OUTER (IC_BLOCK / IC_INNER) )==""\n"
84R"==(#define OC_OUTER (OC_BLOCK / 16) )==""\n"
85R"==(#define IW_BLOCK (SW * (OW_BLOCK - 1) + (KW - 1) * (1 + DW) + 1) )==""\n"
86R"==(#define OW_INNER (C_VEC ? 1 : 16) )==""\n"
87R"==(#define IW_INNER OW_INNER )==""\n"
88R"==(#define OW_OUTER ((OW_BLOCK + OW_INNER - 1) / OW_INNER) )==""\n"
89R"==(#define IW_OUTER ((IW_BLOCK + IW_INNER - 1) / IW_INNER) )==""\n"
90R"==(#define C_SIZE (MB_BLOCK * OC_OUTER * OW_BLOCK) )==""\n"
91R"==(#if OW_BLOCK >= 32 )==""\n"
92R"==(#error "Block is too big for unrolled_read and unrolled_write." )==""\n"
93R"==(#endif )==""\n"
// Kernel text: src_idx_c_vec(mb_block, ic_outer, ow_outer)
// Flattens the triple into a linear source-tile index for the C_VEC case.
// With SRC_16N16C the order is ic_outer, ow_outer, mb_block (mb_block varies
// fastest); otherwise it is mb_block, ic_outer, ow_outer (ow_outer fastest).
94R"==(int src_idx_c_vec(int mb_block, int ic_outer, int ow_outer) { )==""\n"
95R"==(if (SRC_16N16C) )==""\n"
96R"==(return ic_outer * OW_BLOCK * MB_BLOCK + ow_outer * MB_BLOCK + mb_block; )==""\n"
97R"==(return mb_block * IC_OUTER * OW_BLOCK + ic_outer * OW_BLOCK + ow_outer; )==""\n"
98R"==(} )==""\n"
// Kernel text: src_idx_w_vec(mb_block, ic_outer, ow_outer)
// Linear source-tile index for the W_VEC case: mb_block, ic_outer, ow_outer
// with ow_outer varying fastest and OW_OUTER as the row length.
99R"==(int src_idx_w_vec(int mb_block, int ic_outer, int ow_outer) { )==""\n"
100R"==(return mb_block * IC_OUTER * OW_OUTER + ic_outer * OW_OUTER + ow_outer; )==""\n"
101R"==(} )==""\n"
// Kernel text: src_idx(mb_block, ic_outer, ow_outer)
// Dispatches to src_idx_c_vec when C_VEC is true and to src_idx_w_vec
// otherwise.
102R"==(int src_idx(int mb_block, int ic_outer, int ow_outer) { )==""\n"
103R"==(if (C_VEC) return src_idx_c_vec(mb_block, ic_outer, ow_outer); )==""\n"
104R"==(return src_idx_w_vec(mb_block, ic_outer, ow_outer); )==""\n"
105R"==(} )==""\n"
// Kernel text: src_buf_idx_c_vec(mb_block, ic_outer, iw_outer)
// Linear index into the staged source buffer for the C_VEC case.
// With SRC_16N16C the order is mb_block, ic_outer, iw_outer (iw_outer
// fastest); otherwise it is ic_outer, iw_outer, mb_block (mb_block fastest).
// NOTE(review): this is the opposite layout choice from src_idx_c_vec. In
// the code visible here the buffer is only filled and read when
// ENABLE_SRC_BUF holds, which requires MB_BLOCK == 1; with mb_block == 0 both
// formulas reduce to ic_outer * IW_OUTER + iw_outer, so the difference has no
// effect. Confirm before relaxing the MB_BLOCK == 1 requirement.
106R"==(int src_buf_idx_c_vec(int mb_block, int ic_outer, int iw_outer) { )==""\n"
107R"==(if (SRC_16N16C) )==""\n"
108R"==(return mb_block * IC_OUTER * IW_OUTER + ic_outer * IW_OUTER + iw_outer; )==""\n"
109R"==(return ic_outer * IW_OUTER * MB_BLOCK + iw_outer * MB_BLOCK + mb_block; )==""\n"
110R"==(} )==""\n"
// Kernel text: src_buf_idx_w_vec(mb_block, ic_outer, iw_outer)
// Linear index into the staged source buffer for the W_VEC case: mb_block,
// ic_outer, iw_outer with iw_outer varying fastest.
111R"==(int src_buf_idx_w_vec(int mb_block, int ic_outer, int iw_outer) { )==""\n"
112R"==(return mb_block * IC_OUTER * IW_OUTER + ic_outer * IW_OUTER + iw_outer; )==""\n"
113R"==(} )==""\n"
// Kernel text: src_buf_idx(mb_block, ic_outer, iw_outer)
// Dispatches to src_buf_idx_c_vec when C_VEC is true and to
// src_buf_idx_w_vec otherwise.
114R"==(int src_buf_idx(int mb_block, int ic_outer, int iw_outer) { )==""\n"
115R"==(if (C_VEC) return src_buf_idx_c_vec(mb_block, ic_outer, iw_outer); )==""\n"
116R"==(return src_buf_idx_w_vec(mb_block, ic_outer, iw_outer); )==""\n"
117R"==(} )==""\n"
// Kernel text: wei_idx_c_vec(oc_outer, ic_outer)
// Returns the first slot of a run of IC_INNER consecutive weight values for
// the given output-channel group and input-channel group. Callers such as
// multiply_row add ic_inner to the result. ic_outer is a group index, not a
// channel offset.
118R"==(int wei_idx_c_vec(int oc_outer, int ic_outer) { )==""\n"
119R"==(return ic_outer * OC_OUTER * IC_INNER + oc_outer * IC_INNER; )==""\n"
120R"==(} )==""\n"
// Kernel text: wei_idx_w_vec(oc_outer, ic_outer)
// Weight-tile index for the W_VEC case: oc_outer major, ic_outer minor.
121R"==(int wei_idx_w_vec(int oc_outer, int ic_outer) { )==""\n"
122R"==(return oc_outer * IC_OUTER + ic_outer; )==""\n"
123R"==(} )==""\n"
// Kernel text: wei_idx(oc_outer, ic_outer)
// Dispatches to wei_idx_c_vec when C_VEC is true and to wei_idx_w_vec
// otherwise.
124R"==(int wei_idx(int oc_outer, int ic_outer) { )==""\n"
125R"==(if (C_VEC) return wei_idx_c_vec(oc_outer, ic_outer); )==""\n"
126R"==(return wei_idx_w_vec(oc_outer, ic_outer); )==""\n"
127R"==(} )==""\n"
// Kernel text: dst_idx(mb_block, oc_outer, ow_block)
// Linear index into the accumulator tile (length C_SIZE).
// With DST_16N16C or DST_32N16C the order is oc_outer, ow_block, mb_block
// (mb_block fastest); otherwise it is mb_block, oc_outer, ow_block (ow_block
// fastest).
128R"==(int dst_idx(int mb_block, int oc_outer, int ow_block) { )==""\n"
129R"==(if (DST_16N16C || DST_32N16C) )==""\n"
130R"==(return oc_outer * OW_BLOCK * MB_BLOCK + ow_block * MB_BLOCK + mb_block; )==""\n"
131R"==(return mb_block * OC_OUTER * OW_BLOCK + oc_outer * OW_BLOCK + ow_block; )==""\n"
132R"==(} )==""\n"
// Kernel text: copy(dst, src, n)
// Element-wise copy of the first n entries of src into dst. The loop counter
// is named i, so arguments whose expressions mention an outer i would be
// captured by it. Not referenced anywhere in the portion of the kernel
// visible here.
133R"==(#define copy(dst, src, n) \ )==""\n"
134R"==(do { \ )==""\n"
135R"==(for (int i = 0; i < (n); i++) \ )==""\n"
136R"==((dst)[i] = (src)[i]; \ )==""\n"
137R"==(} while (0) )==""\n"
// Kernel text: unrolled_read(n, block, ptr)
// Reads n values into block using the widest block reads available. n is
// decomposed by its bits: bit 16 issues two 8-wide reads, bits 8 / 4 / 2 / 1
// issue one 8 / 4 / 2 / 1-wide read each. A chunk that starts at element k
// is stored at block + k and fetched from ptr + k * 16 (the factor 16 is
// presumably the sub-group width; confirm against SUB_GROUP_SIZE).
// Only bits 1..16 are examined, so n must be below 32; the OW_BLOCK guard
// earlier in the kernel text enforces this for OW_BLOCK-sized reads.
// n is evaluated several times.
138R"==(#define unrolled_read(n, block, ptr) \ )==""\n"
139R"==(do { \ )==""\n"
140R"==(if ((n)&16) { \ )==""\n"
141R"==(*((DATA8_T *)(block)) = _BLOCK_READ8((ptr)); \ )==""\n"
142R"==(*((DATA8_T *)((block) + 8)) = _BLOCK_READ8((ptr) + 8 * 16); \ )==""\n"
143R"==(} \ )==""\n"
144R"==(if ((n)&8) \ )==""\n"
145R"==(*((DATA8_T *)((block) + ((n) & ~15))) \ )==""\n"
146R"==(= _BLOCK_READ8((ptr) + ((n) & ~15) * 16); \ )==""\n"
147R"==(if ((n)&4) \ )==""\n"
148R"==(*((DATA4_T *)((block) + ((n) & ~7))) \ )==""\n"
149R"==(= _BLOCK_READ4((ptr) + ((n) & ~7) * 16); \ )==""\n"
150R"==(if ((n)&2) \ )==""\n"
151R"==(*((DATA2_T *)((block) + ((n) & ~3))) \ )==""\n"
152R"==(= _BLOCK_READ2((ptr) + ((n) & ~3) * 16); \ )==""\n"
153R"==(if ((n)&1) \ )==""\n"
154R"==(*((block) + ((n) & ~1)) = _BLOCK_READ((ptr) + ((n) & ~1) * 16); \ )==""\n"
155R"==(} while (0) )==""\n"
// Kernel text: strided_read(block, ptr, n, stride)
// Produces one value per work-item in block[0].
//   n == 16 and stride == 1 : a single _BLOCK_READ from ptr
//   otherwise               : the work-item whose sub-group local id is
//                             below n loads ptr[id * stride]; the remaining
//                             work-items receive 0
// The fallback branch declares its own local named sglid inside the branch
// scope.
156R"==(#define strided_read(block, ptr, n, stride) \ )==""\n"
157R"==(do { \ )==""\n"
158R"==(if ((n) == 16 && (stride) == 1) { \ )==""\n"
159R"==((block)[0] = _BLOCK_READ((ptr)); \ )==""\n"
160R"==(} else { \ )==""\n"
161R"==(int sglid = get_sub_group_local_id(); \ )==""\n"
162R"==((block)[0] = (sglid < (n)) ? (ptr)[sglid * (stride)] : 0; \ )==""\n"
163R"==(} \ )==""\n"
164R"==(} while (0) )==""\n"
// Kernel text: unrolled_write(n, block, ptr)
// Mirror of unrolled_read: stores n values from block using _BLOCK_WRITE8 /
// 4 / 2 / 1, selected by the bits of n (bit 16 issues two 8-wide writes).
// Only bits 1..16 are examined, so a count of 32 or more is not handled.
// Two variants exist:
//   DST_NCHW non-zero : the 8 / 4 / 2 / 1 chunks advance ptr by the element
//                       offset itself
//   otherwise         : every chunk advances ptr by element offset * 16
// NOTE(review): in the DST_NCHW variant the second store of the 16 branch
// still uses ptr + 8 * 16 while the other branches do not scale by 16. The
// only DST_NCHW call site visible here passes n == 1, so that branch is not
// exercised; confirm the intended offset before using larger counts.
165R"==(#if DST_NCHW )==""\n"
166R"==(#define unrolled_write(n, block, ptr) \ )==""\n"
167R"==(do { \ )==""\n"
168R"==(if ((n)&16) { \ )==""\n"
169R"==(_BLOCK_WRITE8((ptr), *(DATA8_T *)((block))); \ )==""\n"
170R"==(_BLOCK_WRITE8((ptr) + 8 * 16, *(DATA8_T *)((block) + 8)); \ )==""\n"
171R"==(} \ )==""\n"
172R"==(if ((n)&8) \ )==""\n"
173R"==(_BLOCK_WRITE8( \ )==""\n"
174R"==((ptr) + ((n) & ~15), *(DATA8_T *)((block) + ((n) & ~15))); \ )==""\n"
175R"==(if ((n)&4) \ )==""\n"
176R"==(_BLOCK_WRITE4( \ )==""\n"
177R"==((ptr) + ((n) & ~7), *(DATA4_T *)((block) + ((n) & ~7))); \ )==""\n"
178R"==(if ((n)&2) \ )==""\n"
179R"==(_BLOCK_WRITE2( \ )==""\n"
180R"==((ptr) + ((n) & ~3), *(DATA2_T *)((block) + ((n) & ~3))); \ )==""\n"
181R"==(if ((n)&1) _BLOCK_WRITE((ptr) + ((n) & ~1), *((block) + ((n) & ~1))); \ )==""\n"
182R"==(} while (0) )==""\n"
183R"==(#else )==""\n"
184R"==(#define unrolled_write(n, block, ptr) \ )==""\n"
185R"==(do { \ )==""\n"
186R"==(if ((n)&16) { \ )==""\n"
187R"==(_BLOCK_WRITE8((ptr), *(DATA8_T *)((block))); \ )==""\n"
188R"==(_BLOCK_WRITE8((ptr) + 8 * 16, *(DATA8_T *)((block) + 8)); \ )==""\n"
189R"==(} \ )==""\n"
190R"==(if ((n)&8) \ )==""\n"
191R"==(_BLOCK_WRITE8((ptr) + ((n) & ~15) * 16, \ )==""\n"
192R"==(*(DATA8_T *)((block) + ((n) & ~15))); \ )==""\n"
193R"==(if ((n)&4) \ )==""\n"
194R"==(_BLOCK_WRITE4((ptr) + ((n) & ~7) * 16, \ )==""\n"
195R"==(*(DATA4_T *)((block) + ((n) & ~7))); \ )==""\n"
196R"==(if ((n)&2) \ )==""\n"
197R"==(_BLOCK_WRITE2((ptr) + ((n) & ~3) * 16, \ )==""\n"
198R"==(*(DATA2_T *)((block) + ((n) & ~3))); \ )==""\n"
199R"==(if ((n)&1) \ )==""\n"
200R"==(_BLOCK_WRITE((ptr) + ((n) & ~1) * 16, *((block) + ((n) & ~1))); \ )==""\n"
201R"==(} while (0) )==""\n"
202R"==(#endif )==""\n"
// Kernel text: multiply_row(C, A, B, mb_block, oc_outer, ic_outer, ow_outer)
// Accumulates one source value per input channel into a single accumulator
// slot. For ic_inner in [0, min(IC_WO_PADDING, IC_INNER)):
//   C[c_off] = fma(shuffle(A, ic_inner), B[b_off + ic_inner], C[c_off])
// where b_off = wei_idx(oc_outer, ic_outer), c_off = dst_idx(mb_block,
// oc_outer, ow_outer), and intel_sub_group_shuffle fetches the value of A
// held by sub-group lane ic_inner.
// b_off, c_off and ic_inner are declared inside the macro scope.
203R"==(#define multiply_row(C, A, B, mb_block, oc_outer, ic_outer, ow_outer) \ )==""\n"
204R"==(do { \ )==""\n"
205R"==(int b_off = wei_idx((oc_outer), (ic_outer)); \ )==""\n"
206R"==(int c_off = dst_idx((mb_block), (oc_outer), (ow_outer)); \ )==""\n"
207R"==(for (int ic_inner = 0; ic_inner < min(IC_WO_PADDING, IC_INNER); \ )==""\n"
208R"==(ic_inner++) { \ )==""\n"
209R"==((C)[c_off] = fma(intel_sub_group_shuffle((A), ic_inner), \ )==""\n"
210R"==((B)[b_off + ic_inner], (C)[c_off]); \ )==""\n"
211R"==(} \ )==""\n"
212R"==(} while (0) )==""\n"
// Kernel text: read_src_and_multiply_w16c_dense(ptr, iw, kw, C, B)
// Dense source path with no width bounds checks. For every mb_block and
// ic_outer the output columns are processed in chunks of up to 8:
//   1. unrolled_read loads ow_bound source values starting at column
//      ow_block + kw * (1 + DW) (the offset does not include SW),
//   2. multiply_row accumulates each loaded value into C for every oc_outer.
// The iw argument is accepted but not used, and the local idx is computed
// but never read.
213R"==(#define read_src_and_multiply_w16c_dense(ptr, iw, kw, C, B) \ )==""\n"
214R"==(do { \ )==""\n"
215R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) \ )==""\n"
216R"==(for (int ic_outer = 0; ic_outer < IC_OUTER; ic_outer++) { \ )==""\n"
217R"==(int idx = src_idx(mb_block, ic_outer, 0); \ )==""\n"
218R"==(for (int ow_block = 0; ow_block < OW_BLOCK; ow_block += 8) { \ )==""\n"
219R"==(int ow_bound = min(8, OW_BLOCK - ow_block); \ )==""\n"
220R"==(DATA_T A[8]; \ )==""\n"
221R"==(int off = src_off(mb_block, ic_outer * IC_INNER, 0, 0, \ )==""\n"
222R"==(ow_block + (kw) * (1 + DW)); \ )==""\n"
223R"==(unrolled_read(ow_bound, A, &(ptr)[off]); \ )==""\n"
224R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) { \ )==""\n"
225R"==(for (int i = 0; i < ow_bound; i++) { \ )==""\n"
226R"==(multiply_row((C), A[i], (B), mb_block, oc_outer, \ )==""\n"
227R"==(ic_outer, ow_block + i); \ )==""\n"
228R"==(} \ )==""\n"
229R"==(} \ )==""\n"
230R"==(} \ )==""\n"
231R"==(} \ )==""\n"
232R"==(} while (0) )==""\n"
// Kernel text: read_src_and_multiply_16n16c(ptr, iw, kw, do_w_check, C, B)
// Source path that vectorises over the minibatch block. For each output
// column:
//   - iw_off = ow_block * SW + kw * (1 + DW)
//   - when do_w_check and HAS_PAD_W are both true, columns with iw + iw_off
//     outside [0, IW) are skipped
//   - for each ic_outer, minibatch entries are loaded in chunks of up to 8
//     with unrolled_read and accumulated with multiply_row for every oc_outer
233R"==(#define read_src_and_multiply_16n16c(ptr, iw, kw, do_w_check, C, B) \ )==""\n"
234R"==(do { \ )==""\n"
235R"==(for (int ow_block = 0; ow_block < OW_BLOCK; ow_block++) { \ )==""\n"
236R"==(int iw_off = ow_block * SW + (kw) * (1 + DW); \ )==""\n"
237R"==(if ((do_w_check) && HAS_PAD_W \ )==""\n"
238R"==(&& ((iw) + iw_off < 0 || (iw) + iw_off >= IW)) \ )==""\n"
239R"==(continue; \ )==""\n"
240R"==(for (int ic_outer = 0; ic_outer < IC_OUTER; ic_outer++) \ )==""\n"
241R"==(__attribute__((opencl_unroll_hint)) /* attr:no-format */ \ )==""\n"
242R"==(for (int mb_block = 0; mb_block < MB_BLOCK; \ )==""\n"
243R"==(mb_block += 8) { \ )==""\n"
244R"==(int mb_bound = min(8, MB_BLOCK - mb_block); \ )==""\n"
245R"==(DATA_T A[8]; \ )==""\n"
246R"==(int off = src_off( \ )==""\n"
247R"==(mb_block, ic_outer * IC_INNER, 0, 0, iw_off); \ )==""\n"
248R"==(unrolled_read(mb_bound, A, &(ptr)[off]); \ )==""\n"
249R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) { \ )==""\n"
250R"==(for (int i = 0; i < mb_bound; i++) { \ )==""\n"
251R"==(multiply_row((C), A[i], (B), mb_block + i, \ )==""\n"
252R"==(oc_outer, ic_outer, ow_block); \ )==""\n"
253R"==(} \ )==""\n"
254R"==(} \ )==""\n"
255R"==(} \ )==""\n"
256R"==(} \ )==""\n"
257R"==(} while (0) )==""\n"
// Kernel text: read_src_and_multiply_common(ptr, iw, kw, do_w_check, C, B)
// Generic source path, one source value at a time. For each output column i:
//   - iw_off = i * SW + kw * (1 + DW); the column is skipped when
//     do_w_check and HAS_PAD_W are true and iw + iw_off is outside [0, IW)
//   - for each ic_outer and mb_block, strided_read loads IC_INNER channels
//     (one per lane) using the channel stride src_off(0, 1, 0, 0, 0)
//   - for each oc_outer the same fma / intel_sub_group_shuffle accumulation
//     as multiply_row is performed inline
258R"==(#define read_src_and_multiply_common(ptr, iw, kw, do_w_check, C, B) \ )==""\n"
259R"==(do { \ )==""\n"
260R"==(for (int i = 0; i < OW_BLOCK; i++) { \ )==""\n"
261R"==(int iw_off = i * SW + (kw) * (1 + DW); \ )==""\n"
262R"==(if ((do_w_check) && HAS_PAD_W \ )==""\n"
263R"==(&& ((iw) + iw_off < 0 || (iw) + iw_off >= IW)) \ )==""\n"
264R"==(continue; \ )==""\n"
265R"==(for (int ic_outer = 0; ic_outer < IC_OUTER; ic_outer++) { \ )==""\n"
266R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) { \ )==""\n"
267R"==(int off = src_off( \ )==""\n"
268R"==(mb_block, ic_outer * IC_INNER, 0, 0, iw_off); \ )==""\n"
269R"==(DATA_T A; \ )==""\n"
270R"==(strided_read(&A, &(ptr)[off], IC_INNER, \ )==""\n"
271R"==(src_off(0, 1, 0, 0, 0)); \ )==""\n"
272R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) { \ )==""\n"
273R"==(int b_off = wei_idx(oc_outer, ic_outer); \ )==""\n"
274R"==(int c_off = dst_idx(mb_block, oc_outer, i); \ )==""\n"
275R"==(for (int ic_inner = 0; \ )==""\n"
276R"==(ic_inner < min(IC_WO_PADDING, IC_INNER); \ )==""\n"
277R"==(ic_inner++) { \ )==""\n"
278R"==((C)[c_off] = fma( \ )==""\n"
279R"==(intel_sub_group_shuffle(A, ic_inner), \ )==""\n"
280R"==((B)[b_off + ic_inner], (C)[c_off]); \ )==""\n"
281R"==(} \ )==""\n"
282R"==(} \ )==""\n"
283R"==(} \ )==""\n"
284R"==(} \ )==""\n"
285R"==(} \ )==""\n"
286R"==(} while (0) )==""\n"
// Kernel text: read_src_and_multiply(ptr, iw, kw, do_w_check, C, B)
// Selects the source path:
//   - SRC_W16C and either no width check is requested or (no W padding and
//     SW == 1)                 : read_src_and_multiply_w16c_dense
//   - otherwise, SRC_16N16C    : read_src_and_multiply_16n16c
//   - otherwise                : read_src_and_multiply_common
// do_w_check is not forwarded to the dense path, which performs no bounds
// checks.
287R"==(#define read_src_and_multiply(ptr, iw, kw, do_w_check, C, B) \ )==""\n"
288R"==(do { \ )==""\n"
289R"==(if (SRC_W16C && (!(do_w_check) || (!HAS_PAD_W && SW == 1))) { \ )==""\n"
290R"==(read_src_and_multiply_w16c_dense((ptr), (iw), (kw), (C), (B)); \ )==""\n"
291R"==(} else if (SRC_16N16C) { \ )==""\n"
292R"==(read_src_and_multiply_16n16c( \ )==""\n"
293R"==((ptr), (iw), (kw), (do_w_check), (C), (B)); \ )==""\n"
294R"==(} else { \ )==""\n"
295R"==(read_src_and_multiply_common( \ )==""\n"
296R"==((ptr), (iw), (kw), (do_w_check), (C), (B)); \ )==""\n"
297R"==(} \ )==""\n"
298R"==(} while (0) )==""\n"
// Kernel text: read_src_buf(buf, ptr, iw)
// Stages one source row into buf. For each iw_outer:
//   - iw_inner is 0 under C_VEC and the sub-group local id otherwise
//   - iw_block = iw_outer * IW_INNER + iw_inner
//   - when HAS_PAD_W is true and iw + iw_block is outside [0, IW) the entry
//     is skipped, leaving whatever buf already held (the caller visible in
//     loop_ic_outermost zero-initialises the buffer)
//   - under C_VEC each entry is filled by strided_read over IC_INNER
//     channels; otherwise each work-item loads its own element ptr[off]
// Entries are addressed with src_buf_idx(mb_block, ic_outer, iw_outer).
299R"==(#define read_src_buf(buf, ptr, iw) \ )==""\n"
300R"==(do { \ )==""\n"
301R"==(for (int iw_outer = 0; iw_outer < IW_OUTER; iw_outer++) { \ )==""\n"
302R"==(int iw_inner = (C_VEC ? 0 : get_sub_group_local_id()); \ )==""\n"
303R"==(int iw_block = iw_outer * IW_INNER + iw_inner; \ )==""\n"
304R"==(if (HAS_PAD_W && ((iw) + iw_block < 0 || (iw) + iw_block >= IW)) \ )==""\n"
305R"==(continue; \ )==""\n"
306R"==(for (int ic_outer = 0; ic_outer < IC_OUTER; ic_outer++) { \ )==""\n"
307R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) { \ )==""\n"
308R"==(int off = src_off( \ )==""\n"
309R"==(mb_block, ic_outer * IC_INNER, 0, 0, iw_block); \ )==""\n"
310R"==(int idx = src_buf_idx(mb_block, ic_outer, iw_outer); \ )==""\n"
311R"==(if (C_VEC) { \ )==""\n"
312R"==(strided_read(&(buf)[idx], &(ptr)[off], IC_INNER, \ )==""\n"
313R"==(src_off(0, 1, 0, 0, 0)); \ )==""\n"
314R"==(} else { \ )==""\n"
315R"==((buf)[idx] = (ptr)[off]; \ )==""\n"
316R"==(} \ )==""\n"
317R"==(} \ )==""\n"
318R"==(} \ )==""\n"
319R"==(} \ )==""\n"
320R"==(} while (0) )==""\n"
// Kernel text: read_wei_block(block, ptr)
// Loads the weights for one kernel tap into the private array block. For
// every oc_outer the input channels [0, min(IC_WO_PADDING, IC_BLOCK)) are
// fetched in chunks of up to 16 with unrolled_read, starting at
// wei_off(0, oc_outer * 16, ic_block, 0, 0, 0).
// The destination slot is wei_idx(oc_outer, ic_block / 16). wei_idx expects
// an input-channel GROUP index (multiply_row and
// read_src_and_multiply_common pass ic_outer in [0, IC_OUTER)), whereas
// ic_block is a channel offset advancing in steps of 16. Passing the raw
// offset was only correct for the first chunk; any later chunk was stored
// far beyond the slot the consumers read and beyond the end of block. The
// chunk number is passed instead, which is unchanged for the first chunk.
321R"==(#define read_wei_block(block, ptr) \ )==""\n"
322R"==(do { \ )==""\n"
323R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) { \ )==""\n"
324R"==(int ic_bound = min(IC_WO_PADDING, IC_BLOCK); \ )==""\n"
325R"==(for (int ic_block = 0; ic_block < ic_bound; ic_block += 16) { \ )==""\n"
326R"==(int off = wei_off(0, oc_outer * 16, ic_block, 0, 0, 0); \ )==""\n"
327R"==(int idx = wei_idx(oc_outer, ic_block / 16); \ )==""\n"
328R"==(unrolled_read(min(16, ic_bound - ic_block), &(block)[idx], \ )==""\n"
329R"==(&(ptr)[off]); \ )==""\n"
330R"==(} \ )==""\n"
331R"==(} \ )==""\n"
332R"==(} while (0) )==""\n"
// Kernel text: read_wei_and_multiply_c_vec(wei, kw, C, A)
// Buffered-source path for the C_VEC case. For every oc_outer the input
// channels [0, min(IC_WO_PADDING, IC_BLOCK)) are processed in chunks of up
// to 8:
//   1. unrolled_read loads the chunk of weights into the local array B
//      (valid entries are B[0] .. B[ic_bound2 - 1]),
//   2. for each mb_block and output column, the staged source entry
//      A[src_buf_idx(mb_block, 0, iw_off)] is combined with B[i] through
//      fma, taking the source value from sub-group lane ic_block + i.
// NOTE(review): src_buf_idx is always called with ic_outer == 0, so only the
// first input-channel group of the buffer is addressed; presumably IC_OUTER
// is 1 whenever this path is selected. Confirm.
333R"==(#define read_wei_and_multiply_c_vec(wei, kw, C, A) \ )==""\n"
334R"==(do { \ )==""\n"
335R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) { \ )==""\n"
336R"==(int ic_bound1 = min(IC_WO_PADDING, IC_BLOCK); \ )==""\n"
337R"==(for (int ic_block = 0; ic_block < ic_bound1; ic_block += 8) { \ )==""\n"
338R"==(int ic_bound2 = min(8, ic_bound1 - ic_block); \ )==""\n"
339R"==(int off = wei_off(0, oc_outer * 16, ic_block, 0, 0, 0); \ )==""\n"
340R"==(DATA_T B[8]; \ )==""\n"
341R"==(unrolled_read(ic_bound2, B, &(wei)[off]); \ )==""\n"
342R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) { \ )==""\n"
343R"==(for (int ow_block = 0; ow_block < OW_BLOCK; ow_block++) { \ )==""\n"
344R"==(int iw_off = ow_block * SW + (kw) * (1 + DW); \ )==""\n"
345R"==(int buf_idx = src_buf_idx(mb_block, 0, iw_off); \ )==""\n"
346R"==(int c_off = dst_idx(mb_block, oc_outer, ow_block); \ )==""\n"
347R"==(for (int i = 0; i < ic_bound2; i++) { \ )==""\n"
348R"==((C)[c_off] \ )==""\n"
349R"==(= fma(intel_sub_group_shuffle( \ )==""\n"
350R"==((A)[buf_idx], ic_block + i), \ )==""\n"
351R"==(B[i], (C)[c_off]); \ )==""\n"
352R"==(} \ )==""\n"
353R"==(} \ )==""\n"
354R"==(} \ )==""\n"
355R"==(} \ )==""\n"
356R"==(} \ )==""\n"
357R"==(} while (0) )==""\n"
// Kernel text: shuffle_a_value(mb_block, ic_block, ow_outer, ow_inner, kw, A)
// Fetches, from the width-vectorised source buffer A, the input value that
// output column ow_outer * OW_INNER + ow_inner needs for kernel tap kw.
//   - iw_off0 / iw_outer0 describe the first column of the ow_outer group,
//     iw_off / iw_outer describe the requested column
//   - up to SW + 1 consecutive buffer entries starting at
//     src_buf_idx(mb_block, ic_block, iw_outer0) are gathered into the
//     4-component A_shuf (entries at or beyond IW_OUTER are replaced by 0)
//   - the vector is shuffled from sub-group lane iw_off % IW_INNER and the
//     component iw_outer - iw_outer0 is returned
// A_shuf has 4 components, so SW + 1 must not exceed 4; the W_VEC condition
// (SW <= 3) matches that limit.
358R"==(DATA_T shuffle_a_value(int mb_block, int ic_block, int ow_outer, int ow_inner, )==""\n"
359R"==(int kw, const DATA_T *A) { )==""\n"
360R"==(int iw_off0 = ow_outer * OW_INNER * SW + kw * (1 + DW); )==""\n"
361R"==(int iw_outer0 = iw_off0 / IW_INNER; )==""\n"
362R"==(int buf_idx0 = src_buf_idx(mb_block, ic_block, iw_outer0); )==""\n"
363R"==(int iw_off = iw_off0 + ow_inner * SW; )==""\n"
364R"==(int iw_outer = iw_off / IW_INNER; )==""\n"
365R"==(DATA4_T A_shuf = 0; )==""\n"
366R"==(for (int i = 0; i < SW + 1; i++) { )==""\n"
367R"==(A_shuf[i] = (iw_outer0 + i < IW_OUTER) ? A[buf_idx0 + i] : 0; )==""\n"
368R"==(} )==""\n"
369R"==(A_shuf = AS_DATA4_T(intel_sub_group_shuffle( )==""\n"
370R"==(AS_BLOCK_DATA4_T(A_shuf), iw_off % IW_INNER)); )==""\n"
371R"==(return A_shuf[iw_outer - iw_outer0]; )==""\n"
372R"==(} )==""\n"
// Kernel text: read_wei_and_multiply_w_vec(wei, kw, C, A)
// Buffered-source path for the W_VEC case. For every oc_outer the input
// channels [0, min(IC_WO_PADDING, IC_BLOCK)) are processed in chunks of up
// to 8:
//   1. unrolled_read loads the chunk of weights into the local array B, so
//      the valid entries are B[0] .. B[ic_bound2 - 1],
//   2. for each mb_block, channel in the chunk and output column,
//      shuffle_a_value fetches the matching source value (addressed with the
//      block-wide channel index ic_block + ic_inner) and fma accumulates it
//      into C[dst_idx(mb_block, oc_outer, column)].
// B is indexed with the chunk-local ic_inner. The previous index
// ic_block + ic_inner ran past the 8-element array for every chunk after
// the first; for the first chunk (ic_block == 0) both forms are identical,
// and read_wei_and_multiply_c_vec already uses the chunk-local form.
373R"==(#define read_wei_and_multiply_w_vec(wei, kw, C, A) \ )==""\n"
374R"==(do { \ )==""\n"
375R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) { \ )==""\n"
376R"==(int ic_bound1 = min(IC_WO_PADDING, IC_BLOCK); \ )==""\n"
377R"==(for (int ic_block = 0; ic_block < ic_bound1; ic_block += 8) { \ )==""\n"
378R"==(int ic_bound2 = min(8, ic_bound1 - ic_block); \ )==""\n"
379R"==(int off = wei_off(0, oc_outer * 16, ic_block, 0, 0, 0); \ )==""\n"
380R"==(DATA_T B[8]; \ )==""\n"
381R"==(unrolled_read(ic_bound2, B, &(wei)[off]); \ )==""\n"
382R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) \ )==""\n"
383R"==(/* IC_INNER is 1 with W vectorization. */ \ )==""\n"
384R"==(for (int ic_inner = 0; ic_inner < ic_bound2; ic_inner++) \ )==""\n"
385R"==(for (int ow_outer = 0; ow_outer < OW_OUTER; \ )==""\n"
386R"==(ow_outer++) { \ )==""\n"
387R"==(int ow_bound = min( \ )==""\n"
388R"==(OW_INNER, OW_BLOCK - ow_outer * OW_INNER); \ )==""\n"
389R"==(for (int i = 0; i < ow_bound; i++) { \ )==""\n"
390R"==(DATA_T A_val = shuffle_a_value(mb_block, \ )==""\n"
391R"==(ic_block + ic_inner, ow_outer, i, \ )==""\n"
392R"==((kw), (A)); \ )==""\n"
393R"==(int c_off = dst_idx(mb_block, oc_outer, \ )==""\n"
394R"==(ow_outer * OW_INNER + i); \ )==""\n"
395R"==((C)[c_off] = fma(A_val, \ )==""\n"
396R"==(B[ic_inner], (C)[c_off]); \ )==""\n"
397R"==(} \ )==""\n"
398R"==(} \ )==""\n"
399R"==(} \ )==""\n"
400R"==(} \ )==""\n"
401R"==(} while (0) )==""\n"
// Kernel text: read_wei_and_multiply(wei, kw, C, A)
// Dispatches to read_wei_and_multiply_w_vec when W_VEC is true and to
// read_wei_and_multiply_c_vec otherwise.
402R"==(#define read_wei_and_multiply(wei, kw, C, A) \ )==""\n"
403R"==(do { \ )==""\n"
404R"==(if (W_VEC) \ )==""\n"
405R"==(read_wei_and_multiply_w_vec((wei), (kw), (C), (A)); \ )==""\n"
406R"==(else \ )==""\n"
407R"==(read_wei_and_multiply_c_vec((wei), (kw), (C), (A)); \ )==""\n"
408R"==(} while (0) )==""\n"
// Kernel text: read_dst_block(block, ptr, ow)
// Loads existing destination values into the private tile block.
// Besides its arguments the macro reads sglid and ocb from the scope in
// which it is expanded, and uses ow without parentheses in several places.
//
// DST_NCHW branch:
//   - n_channels = min(min(C_SIZE, OW - ow), OW_BLOCK)
//   - w_oc_tail / loc_oc_tail are derived from BASE_OC_TAIL, OC_WO_PADDING,
//     ocb and OC_BLOCK; oc_loop is C_SIZE, or n_channels for lanes at or
//     beyond the tail
//   - each iteration reads one element with _BLOCK_READ_DST at
//     dst_off(0, oc_tail_idx, 0, 0, oc_outer % n_channels) into
//     block[oc_outer]
//   - the surrounding mb_block loop does not use mb_block in the offset or
//     index, so every pass reads the same elements
//
// Other layouts: ow_bound limits the columns to those inside OW, then for
// each mb_block, oc_outer and ow_block:
//   - DST_W32C   : an inner loop reads every in-range column of the block at
//                  a stride of 32 elements from the column-0 offset; this
//                  inner loop is repeated for each ow_block iteration
//   - DST_32N32C : the channel offset uses oc_outer * 32, 16 values are read
//                  at a stride of 32, and the loop variable mb_block is
//                  advanced by 31 when it is non-zero
//   - otherwise  : a single _BLOCK_READ into block[dst_idx(...)]
// NOTE(review): the inner off / idx declarations in the W32C and 32N32C
// branches shadow the ones computed just before the if chain.
409R"==(#define read_dst_block(block, ptr, ow) \ )==""\n"
410R"==(do { \ )==""\n"
411R"==(if (DST_NCHW) { \ )==""\n"
412R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) { \ )==""\n"
413R"==(int n_channels = min(min(C_SIZE, OW - ow), OW_BLOCK); \ )==""\n"
414R"==(bool w_oc_tail = BASE_OC_TAIL > 0 \ )==""\n"
415R"==(&& OC_WO_PADDING - (ocb ? ocb : SUB_GROUP_SIZE) \ )==""\n"
416R"==(< OC_BLOCK; \ )==""\n"
417R"==(int loc_oc_tail = w_oc_tail ? BASE_OC_TAIL : SUB_GROUP_SIZE; \ )==""\n"
418R"==(int oc_loop = (!w_oc_tail || sglid < loc_oc_tail) \ )==""\n"
419R"==(? C_SIZE \ )==""\n"
420R"==(: n_channels; \ )==""\n"
421R"==(for (int oc_outer = 0; oc_outer < oc_loop; oc_outer += 1) { \ )==""\n"
422R"==(int oc_tail_idx = ((sglid < loc_oc_tail \ )==""\n"
423R"==(&& oc_outer >= n_channels) \ )==""\n"
424R"==(? (sglid + (16 * (C_SIZE / OW_BLOCK - 1))) \ )==""\n"
425R"==(: sglid); \ )==""\n"
426R"==(int off = dst_off( \ )==""\n"
427R"==(0, oc_tail_idx, 0, 0, (oc_outer % n_channels)); \ )==""\n"
428R"==(int idx = oc_outer; \ )==""\n"
429R"==(_BLOCK_READ_DST(&(block)[idx], 1, &(ptr)[off]); \ )==""\n"
430R"==(} \ )==""\n"
431R"==(} \ )==""\n"
432R"==(} else { \ )==""\n"
433R"==(int ow_bound = (OW % OW_BLOCK == 0) ? OW_BLOCK \ )==""\n"
434R"==(: min(OW_BLOCK, OW - (ow)); \ )==""\n"
435R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) \ )==""\n"
436R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) \ )==""\n"
437R"==(for (int ow_block = 0; ow_block < ow_bound; ow_block++) { \ )==""\n"
438R"==(int off = dst_off( \ )==""\n"
439R"==(mb_block, oc_outer * 16, 0, 0, ow_block); \ )==""\n"
440R"==(int idx = dst_idx(mb_block, oc_outer, ow_block); \ )==""\n"
441R"==(if (DST_W32C) { \ )==""\n"
442R"==(int off = dst_off( \ )==""\n"
443R"==(mb_block, oc_outer * 16, 0, 0, 0); \ )==""\n"
444R"==(int idx = dst_idx(mb_block, oc_outer, 0); \ )==""\n"
445R"==(for (int i = 0; i < OW_BLOCK && ow + i < OW; \ )==""\n"
446R"==(++i) { \ )==""\n"
447R"==(*((DATA_T *)&(block)[idx + i]) \ )==""\n"
448R"==(= CONVERT_DATA_T(BLOCK_READ_DST( \ )==""\n"
449R"==(&(ptr)[off + (32 * i)])); \ )==""\n"
450R"==(} \ )==""\n"
451R"==(} else if (DST_32N32C) { \ )==""\n"
452R"==(int off = dst_off(mb_block, \ )==""\n"
453R"==(oc_outer * (DST_32N32C ? 32 : 16), 0, 0, \ )==""\n"
454R"==(ow_block); \ )==""\n"
455R"==(int idx = dst_idx(mb_block, oc_outer, ow_block); \ )==""\n"
456R"==(mb_block ? mb_block += 31 : mb_block; \ )==""\n"
457R"==(for (int i = 0; i < 16; ++i) { \ )==""\n"
458R"==(*((DATA_T *)&(block)[idx + i]) \ )==""\n"
459R"==(= CONVERT_DATA_T(BLOCK_READ_DST( \ )==""\n"
460R"==(&(ptr)[off + (32 * i)])); \ )==""\n"
461R"==(} \ )==""\n"
462R"==(} else { \ )==""\n"
463R"==((block)[idx] = _BLOCK_READ(&(ptr)[off]); \ )==""\n"
464R"==(} \ )==""\n"
465R"==(} \ )==""\n"
466R"==(} \ )==""\n"
467R"==(} while (0) )==""\n"
// Kernel text: write_dst_block(block, ptr, ow)
// Stores the private tile block to the destination. Besides its arguments
// the macro reads sglid, oc and ocb from the scope in which it is expanded.
// Exactly one of the following branches runs; if none matches nothing is
// written.
//
//   DST_W16C or DST_W32C, per (mb_block, oc_outer):
//     - DST_W32C: one element per in-range column, 32 elements apart
//     - full block (OW % OW_BLOCK == 0 or ow + OW_BLOCK <= OW):
//       unrolled_write of OW_BLOCK values
//     - otherwise: unrolled_write of the OW % OW_BLOCK remaining values
//
//   DST_16N16C or DST_32N32C, per in-range column and oc_outer:
//     - mb_block advances by 32 (32N32C) or 16; the channel offset uses
//       oc_outer * 32 or oc_outer * 16 accordingly
//     - DST_32N32C with MB > 8: C_SIZE single elements, 32 elements apart
//     - otherwise: unrolled_write of min(16, MB_BLOCK) values
//
//   DST_32N16C, per in-range column and oc_outer:
//     - unrolled_write of min(32, MB_BLOCK) values
//     NOTE(review): unrolled_write only tests bits 16, 8, 4, 2 and 1 of its
//     count, so a count of exactly 32 would store nothing. Confirm that
//     MB_BLOCK stays below 32 for this layout.
//
//   DST_NCHW and sglid < OC_WO_PADDING - oc:
//     - same n_channels / w_oc_tail / oc_tail_idx arithmetic as the
//       DST_NCHW branch of read_dst_block, writing one element at a time
//     - the local idx is assigned but the store indexes block with oc_outer
//       directly; mb_block is not used inside its loop
468R"==(#define write_dst_block(block, ptr, ow) \ )==""\n"
469R"==(do { \ )==""\n"
470R"==(if (DST_W16C || DST_W32C) { \ )==""\n"
471R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) \ )==""\n"
472R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) { \ )==""\n"
473R"==(int off = dst_off(mb_block, oc_outer * 16, 0, 0, 0); \ )==""\n"
474R"==(int idx = dst_idx(mb_block, oc_outer, 0); \ )==""\n"
475R"==(if (DST_W32C) { \ )==""\n"
476R"==(for (int i = 0; i < OW_BLOCK && ow + i < OW; ++i) { \ )==""\n"
477R"==(unrolled_write(1, &(block)[idx + i], \ )==""\n"
478R"==(&(ptr)[off + (32 * i)]); \ )==""\n"
479R"==(} \ )==""\n"
480R"==(} else if (OW % OW_BLOCK == 0 || (ow) + OW_BLOCK <= OW) { \ )==""\n"
481R"==(unrolled_write(OW_BLOCK, &(block)[idx], &(ptr)[off]); \ )==""\n"
482R"==(} else { \ )==""\n"
483R"==(unrolled_write( \ )==""\n"
484R"==(OW % OW_BLOCK, &(block)[idx], &(ptr)[off]); \ )==""\n"
485R"==(} \ )==""\n"
486R"==(} \ )==""\n"
487R"==(} else if (DST_16N16C || DST_32N32C) { \ )==""\n"
488R"==(int ow_bound = (OW % OW_BLOCK == 0) ? OW_BLOCK \ )==""\n"
489R"==(: min(OW_BLOCK, OW - (ow)); \ )==""\n"
490R"==(for (int ow_block = 0; ow_block < ow_bound; ow_block++) \ )==""\n"
491R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) \ )==""\n"
492R"==(for (int mb_block = 0; mb_block < MB_BLOCK; \ )==""\n"
493R"==(mb_block += (DST_32N32C ? 32 : 16)) { \ )==""\n"
494R"==(int off = dst_off(mb_block, \ )==""\n"
495R"==(oc_outer * (DST_32N32C ? 32 : 16), 0, 0, \ )==""\n"
496R"==(ow_block); \ )==""\n"
497R"==(int idx = dst_idx(mb_block, oc_outer, ow_block); \ )==""\n"
498R"==(if (DST_32N32C && MB > 8) { \ )==""\n"
499R"==(for (int i = 0; i < C_SIZE; ++i) { \ )==""\n"
500R"==(unrolled_write(1, &(block)[idx + i], \ )==""\n"
501R"==(&(ptr)[off + (32 * i)]); \ )==""\n"
502R"==(} \ )==""\n"
503R"==(} else { \ )==""\n"
504R"==(unrolled_write(min(16, MB_BLOCK), &(block)[idx], \ )==""\n"
505R"==(&(ptr)[off]); \ )==""\n"
506R"==(} \ )==""\n"
507R"==(} \ )==""\n"
508R"==(} else if (DST_32N16C) { \ )==""\n"
509R"==(int ow_bound = (OW % OW_BLOCK == 0) ? OW_BLOCK \ )==""\n"
510R"==(: min(OW_BLOCK, OW - (ow)); \ )==""\n"
511R"==(for (int ow_block = 0; ow_block < ow_bound; ow_block++) \ )==""\n"
512R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) \ )==""\n"
513R"==(for (int mb_block = 0; mb_block < MB_BLOCK; \ )==""\n"
514R"==(mb_block += 32) { \ )==""\n"
515R"==(int off = dst_off( \ )==""\n"
516R"==(mb_block, oc_outer * 16, 0, 0, ow_block); \ )==""\n"
517R"==(int idx = dst_idx(mb_block, oc_outer, ow_block); \ )==""\n"
518R"==(unrolled_write(min(32, MB_BLOCK), &(block)[idx], \ )==""\n"
519R"==(&(ptr)[off]); \ )==""\n"
520R"==(} \ )==""\n"
521R"==(} else if (DST_NCHW && sglid < OC_WO_PADDING - oc) { \ )==""\n"
522R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) { \ )==""\n"
523R"==(int n_channels = min(min(C_SIZE, OW - ow), OW_BLOCK); \ )==""\n"
524R"==(bool w_oc_tail = BASE_OC_TAIL > 0 \ )==""\n"
525R"==(&& OC_WO_PADDING - (ocb ? ocb : SUB_GROUP_SIZE) \ )==""\n"
526R"==(< OC_BLOCK; \ )==""\n"
527R"==(int loc_oc_tail = w_oc_tail ? BASE_OC_TAIL : SUB_GROUP_SIZE; \ )==""\n"
528R"==(int oc_loop = (!w_oc_tail || sglid < loc_oc_tail) \ )==""\n"
529R"==(? C_SIZE \ )==""\n"
530R"==(: n_channels; \ )==""\n"
531R"==(for (int oc_outer = 0; oc_outer < oc_loop; oc_outer += 1) { \ )==""\n"
532R"==(int oc_tail_idx = ((sglid < loc_oc_tail \ )==""\n"
533R"==(&& oc_outer >= n_channels) \ )==""\n"
534R"==(? (sglid + (16 * (C_SIZE / OW_BLOCK - 1))) \ )==""\n"
535R"==(: sglid); \ )==""\n"
536R"==(int off = dst_off( \ )==""\n"
537R"==(0, oc_tail_idx, 0, 0, (oc_outer % n_channels)); \ )==""\n"
538R"==(int idx = oc_outer; \ )==""\n"
539R"==(unrolled_write(1, &(block)[oc_outer], &(ptr)[off]); \ )==""\n"
540R"==(} \ )==""\n"
541R"==(} \ )==""\n"
542R"==(} \ )==""\n"
543R"==(} while (0) )==""\n"
// Kernel text: loop_ic_outermost(src, wei, C, id, ih, iw)
// Reduction loop nest ordered ic, kd, kh, kw (input-channel blocks
// outermost). The ic, kd and kh loops carry an unroll hint of 1; the kw loop
// is hinted to unroll KW times.
//   - kd / kh taps that fall outside [0, ID) / [0, IH) are skipped when
//     HAS_PAD_D / HAS_PAD_H is true
//   - per (kd, kh) a zero-initialised private buffer A_buf is declared and,
//     when ENABLE_SRC_BUF is true, filled with read_src_buf
//   - per kw: with ENABLE_SRC_BUF the weights are streamed and multiplied by
//     read_wei_and_multiply; otherwise read_wei_block loads B and
//     read_src_and_multiply is called with do_w_check == 1
//   - after each ic block the macro advances its src and wei ARGUMENTS by
//     one IC_BLOCK, so the pointers passed by the caller are modified
544R"==(#define loop_ic_outermost(src, wei, C, id, ih, iw) \ )==""\n"
545R"==(do { \ )==""\n"
546R"==(__attribute__((opencl_unroll_hint(1))) /* attr:no-format */ \ )==""\n"
547R"==(for (int ic = 0; ic < IC; ic += IC_BLOCK) { \ )==""\n"
548R"==(__attribute__((opencl_unroll_hint(1))) /* attr:no-format */ \ )==""\n"
549R"==(for (int kd = 0; kd < KD; kd++) { \ )==""\n"
550R"==(if (HAS_PAD_D \ )==""\n"
551R"==(&& ((id) + kd * (1 + DD) < 0 \ )==""\n"
552R"==(|| (id) + kd * (1 + DD) >= ID)) \ )==""\n"
553R"==(continue; \ )==""\n"
554R"==(__attribute__((opencl_unroll_hint(1))) /* attr:no-format */ \ )==""\n"
555R"==(for (int kh = 0; kh < KH; kh++) { \ )==""\n"
556R"==(if (HAS_PAD_H \ )==""\n"
557R"==(&& ((ih) + kh * (1 + DH) < 0 \ )==""\n"
558R"==(|| (ih) + kh * (1 + DH) >= IH)) \ )==""\n"
559R"==(continue; \ )==""\n"
560R"==(const __global DATA_T *src1 = (src) \ )==""\n"
561R"==(+ src_off(0, 0, kd * (1 + DD), kh * (1 + DH), 0); \ )==""\n"
562R"==(DATA_T A_buf[MB_BLOCK * IC_OUTER * IW_OUTER] = {0}; \ )==""\n"
563R"==(if (ENABLE_SRC_BUF) read_src_buf(A_buf, src1, (iw)); \ )==""\n"
564R"==(__attribute__( \ )==""\n"
565R"==((opencl_unroll_hint(KW))) /* attr:no-format */ \ )==""\n"
566R"==(for (int kw = 0; kw < KW; kw++) { \ )==""\n"
567R"==(const __global DATA_T *wei1 \ )==""\n"
568R"==(= (wei) + wei_off(0, 0, 0, kd, kh, kw); \ )==""\n"
569R"==(DATA_T B[IC_OUTER * OC_OUTER * IC_INNER]; \ )==""\n"
570R"==(if (ENABLE_SRC_BUF) { \ )==""\n"
571R"==(read_wei_and_multiply(wei1, kw, (C), A_buf); \ )==""\n"
572R"==(} else { \ )==""\n"
573R"==(read_wei_block(B, wei1); \ )==""\n"
574R"==(read_src_and_multiply(src1, (iw), kw, 1, (C), B); \ )==""\n"
575R"==(} \ )==""\n"
576R"==(} \ )==""\n"
577R"==(} \ )==""\n"
578R"==(} \ )==""\n"
579R"==((src) += src_off(0, IC_BLOCK, 0, 0, 0); \ )==""\n"
580R"==((wei) += wei_off(0, 0, IC_BLOCK, 0, 0, 0); \ )==""\n"
581R"==(} \ )==""\n"
582R"==(} while (0) )==""\n"
// Kernel text: loop_kdhw_outermost(src, wei, C, id, ih, iw)
// Reduction loop nest ordered kd, kh, kw, ic (kernel taps outermost).
//   - kd / kh / kw taps are skipped when the corresponding HAS_PAD_* is true
//     and the tap position falls outside the input extent; the kw test uses
//     only iw + kw * (1 + DW), i.e. the first column of the block
//   - src1 is offset by kd and kh only; the kw offset is applied inside
//     read_src_and_multiply
//   - per ic block: read_wei_block loads B, read_src_and_multiply is called
//     with do_w_check == 0, then the local pointers src1 and wei1 advance by
//     one IC_BLOCK (the src and wei arguments themselves are not modified)
583R"==(#define loop_kdhw_outermost(src, wei, C, id, ih, iw) \ )==""\n"
584R"==(do { \ )==""\n"
585R"==(for (int kd = 0; kd < KD; kd++) { \ )==""\n"
586R"==(if (HAS_PAD_D \ )==""\n"
587R"==(&& ((id) + kd * (1 + DD) < 0 \ )==""\n"
588R"==(|| (id) + kd * (1 + DD) >= ID)) \ )==""\n"
589R"==(continue; \ )==""\n"
590R"==(for (int kh = 0; kh < KH; kh++) { \ )==""\n"
591R"==(if (HAS_PAD_H \ )==""\n"
592R"==(&& ((ih) + kh * (1 + DH) < 0 \ )==""\n"
593R"==(|| (ih) + kh * (1 + DH) >= IH)) \ )==""\n"
594R"==(continue; \ )==""\n"
595R"==(for (int kw = 0; kw < KW; kw++) { \ )==""\n"
596R"==(if (HAS_PAD_W \ )==""\n"
597R"==(&& ((iw) + kw * (1 + DW) < 0 \ )==""\n"
598R"==(|| (iw) + kw * (1 + DW) >= IW)) \ )==""\n"
599R"==(continue; \ )==""\n"
600R"==(/* XXX: kw offset is applied in read_src_and_multiply(). */ \ )==""\n"
601R"==(const __global DATA_T *src1 = (src) \ )==""\n"
602R"==(+ src_off(0, 0, kd * (1 + DD), kh * (1 + DH), 0); \ )==""\n"
603R"==(const __global DATA_T *wei1 \ )==""\n"
604R"==(= (wei) + wei_off(0, 0, 0, kd, kh, kw); \ )==""\n"
605R"==(__attribute__((opencl_unroll_hint)) /* attr:no-format */ \ )==""\n"
606R"==(for (int ic = 0; ic < IC; ic += IC_BLOCK) { \ )==""\n"
607R"==(DATA_T B[IC_OUTER * OC_OUTER * IC_INNER]; \ )==""\n"
608R"==(read_wei_block(B, wei1); \ )==""\n"
609R"==(read_src_and_multiply(src1, (iw), kw, 0, (C), B); \ )==""\n"
610R"==(src1 += src_off(0, IC_BLOCK, 0, 0, 0); \ )==""\n"
611R"==(wei1 += wei_off(0, 0, IC_BLOCK, 0, 0, 0); \ )==""\n"
612R"==(} \ )==""\n"
613R"==(} \ )==""\n"
614R"==(} \ )==""\n"
615R"==(} \ )==""\n"
616R"==(} while (0) )==""\n"
// ---------------------------------------------------------------------------
// OpenCL kernel gen9_conv_fwd(src, wei, bia, dst, POST_OP_ARGS).
//
// Each work-item derives its (g, oc, mb, od, oh, ow) coordinates from the
// work-group / sub-group ids, optionally seeds the accumulator C from bia,
// accumulates via loop_kdhw_outermost or loop_ic_outermost, optionally
// applies post-ops (using the previous dst contents in S when WITH_SUM is
// set) and finally stores C through write_dst_block.
//
// All upper-case identifiers (LWS_*, OCB, OC_BLOCK, MB_BLOCK, OW_BLOCK, ...)
// are compile-time macros supplied from outside this part of the file.
// NOTE: these are host-side C++ comments placed between adjacent string
// literals; they are not part of the kernel source string.
// ---------------------------------------------------------------------------
617R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) )==""\n"
618R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) __kernel void )==""\n"
619R"==(gen9_conv_fwd(const __global DATA_T *src, const __global DATA_T *wei, )==""\n"
620R"==(const __global DATA_T *bia, __global DST_DATA_T *dst POST_OP_ARGS) { )==""\n"
// MAYBE_SKIP_NON_UNIFORM_WG is defined elsewhere; presumably it returns early
// for work-groups that must not run -- TODO confirm against its definition.
621R"==(MAYBE_SKIP_NON_UNIFORM_WG(); )==""\n"
622R"==(int sglid = get_sub_group_local_id(); )==""\n"
// Dimension 0: a flat index built from the group id and the sub-group id,
// then split into g and ocb using OCB / OC_BLOCK as the divisor.
623R"==(int g_ocb = get_group_id(0) * (LWS_0 / SUB_GROUP_SIZE) + get_sub_group_id(); )==""\n"
624R"==(int g = g_ocb / (OCB / OC_BLOCK); )==""\n"
625R"==(int ocb = g_ocb % (OCB / OC_BLOCK) * OC_BLOCK; )==""\n"
// Dimension 1: the group id is split into a flat spatial index (odhw) and an
// offset omb (a multiple of MB_BLOCK).
626R"==(int odhw = get_group_id(1) / (OMB / MB_BLOCK); )==""\n"
627R"==(int omb = get_group_id(1) % (OMB / MB_BLOCK) * MB_BLOCK; )==""\n"
// od is only extracted from odhw when IS_3D is set; otherwise od is 0 and
// odhw is used directly as ohw.
628R"==(int ohw = IS_3D ? odhw % (OWB * OHB) : odhw; )==""\n"
629R"==(int od = IS_3D ? odhw / (OWB * OHB) : 0; )==""\n"
630R"==(int oh = (ohw / OWB) * OH_BLOCK; )==""\n"
631R"==(int ow = (ohw % OWB) * OW_BLOCK; )==""\n"
// Dimension 2: the group id is split by MB / OMB into ocb_idx and omb_idx,
// which are scaled by OCB / OMB and added to ocb / omb to form oc and mb.
632R"==(int ocb_idx_omb_idx = get_group_id(2); )==""\n"
633R"==(int ocb_idx = ocb_idx_omb_idx / (MB / OMB); )==""\n"
634R"==(int omb_idx = ocb_idx_omb_idx % (MB / OMB); )==""\n"
635R"==(int oc = ocb_idx * OCB + ocb; )==""\n"
636R"==(int mb = omb_idx * OMB + omb; )==""\n"
// Source coordinates of the first tap: output coordinate times S{H,W,D} minus
// P{H,W,D} (presumably stride and padding -- TODO confirm). May be negative.
637R"==(int ih = oh * SH - PH; )==""\n"
638R"==(int iw = ow * SW - PW; )==""\n"
639R"==(int id = od * SD - PD; )==""\n"
// Accumulator C, zero-initialized: a vector type for f16 with C_SIZE 8 or 16,
// a plain array of C_SIZE elements otherwise.
640R"==(#if DT_F16 && C_SIZE == 8 )==""\n"
641R"==(DATA8_T C = 0; )==""\n"
642R"==(#elif DT_F16 && C_SIZE == 16 )==""\n"
643R"==(DATA16_T C = 0; )==""\n"
644R"==(#else )==""\n"
645R"==(DATA_T C[C_SIZE] = {0}; )==""\n"
646R"==(#endif )==""\n"
// Rebase the three pointers to this work-item's starting element.
647R"==(src += src_off(mb, g * IC, id, ih, iw); )==""\n"
648R"==(wei += wei_off(g, oc, 0, 0, 0, 0); )==""\n"
649R"==(dst += dst_off(mb, g * OC + oc, od, oh, ow); )==""\n"
// Early exit for DST_32N32C / DST_W32C with OC_BLOCK not a multiple of 32 when
// oc > OC_WO_PADDING: the still-zero accumulator is stored and the kernel
// returns without touching src, wei or bia.
650R"==(if ((DST_32N32C || DST_W32C) && (OC_BLOCK % 32 != 0) )==""\n"
651R"==(&& oc > OC_WO_PADDING) { )==""\n"
652R"==(write_dst_block((DATA_T *)(&C), dst, ow); )==""\n"
653R"==(return; )==""\n"
654R"==(} )==""\n"
// Bias: every C element addressed by dst_idx(mb_block, oc_outer, ow_block) is
// seeded with bia[g * OC + oc + oc_outer * 16 + sglid].
// NOTE(review): the literal 16 presumably equals SUB_GROUP_SIZE -- confirm.
655R"==(if (WITH_BIAS) { )==""\n"
656R"==(for (int mb_block = 0; mb_block < MB_BLOCK; mb_block++) { )==""\n"
657R"==(for (int oc_outer = 0; oc_outer < OC_OUTER; oc_outer++) { )==""\n"
658R"==(const int bg_off = g * OC; )==""\n"
659R"==(const int bc_off = oc + oc_outer * 16 + sglid; )==""\n"
// f16 data or u8/s8 destination: C is only written when the guard holds, so
// guarded-out elements keep their initial zero. The guard uses && instead of
// || when OC_WO_PADDING == SUB_GROUP_SIZE and the destination is u8/s8.
660R"==(#if (DT_F16 || DST_DT_U8 || DST_DT_S8) )==""\n"
661R"==(#if OC_WO_PADDING == SUB_GROUP_SIZE && (DST_DT_U8 || DST_DT_S8) )==""\n"
662R"==(if (OC_WO_PADDING % OC_BLOCK == 0 && bc_off < OC_WO_PADDING) { )==""\n"
663R"==(#else )==""\n"
664R"==(if (OC_WO_PADDING % OC_BLOCK == 0 || bc_off < OC_WO_PADDING) { )==""\n"
665R"==(#endif )==""\n"
666R"==(for (int ow_block = 0; ow_block < OW_BLOCK; ow_block++) { )==""\n"
667R"==(const int c_off = dst_idx(mb_block, oc_outer, ow_block); )==""\n"
668R"==(C[c_off] = bia[bg_off + bc_off]; )==""\n"
669R"==(} )==""\n"
670R"==(} )==""\n"
// All other type combinations: C is always written, with DATA_ZERO taking the
// place of the bias value when the guard fails.
671R"==(#else )==""\n"
672R"==(for (int ow_block = 0; ow_block < OW_BLOCK; ow_block++) { )==""\n"
673R"==(const int c_off = dst_idx(mb_block, oc_outer, ow_block); )==""\n"
674R"==(C[c_off] = (OC_WO_PADDING % OC_BLOCK == 0 )==""\n"
675R"==(|| bc_off < OC_WO_PADDING) )==""\n"
676R"==(? bia[bg_off + bc_off] )==""\n"
677R"==(: DATA_ZERO; )==""\n"
678R"==(} )==""\n"
679R"==(#endif )==""\n"
680R"==(} )==""\n"
681R"==(} )==""\n"
682R"==(} )==""\n"
// DO_MUL gates the accumulation loops. It is always true except when
// OC_WO_PADDING == SUB_GROUP_SIZE with a u8/s8 destination, where it requires
// OC_BLOCK * g_ocb < OC_WO_PADDING.
683R"==(#if OC_WO_PADDING == SUB_GROUP_SIZE && (DST_DT_U8 || DST_DT_S8) )==""\n"
684R"==(const bool DO_MUL = ((OC_BLOCK * g_ocb) < OC_WO_PADDING); )==""\n"
685R"==(#else )==""\n"
686R"==(const bool DO_MUL = true; )==""\n"
687R"==(#endif )==""\n"
// Loop-order selection: loop_kdhw_outermost when OW_BLOCK == 1, otherwise
// loop_ic_outermost. Both accumulate into C; neither runs when DO_MUL is false.
688R"==(if (OW_BLOCK == 1 && DO_MUL) { )==""\n"
689R"==(loop_kdhw_outermost(src, wei, C, id, ih, iw); )==""\n"
690R"==(} else if (DO_MUL) { )==""\n"
691R"==(loop_ic_outermost(src, wei, C, id, ih, iw); )==""\n"
692R"==(} )==""\n"
// S is declared without an initializer and is only filled from dst when
// WITH_SUM is set.
// NOTE(review): the post-op code below reads S regardless of WITH_SUM;
// presumably those values are unused without a sum post-op -- confirm.
693R"==(DATA_T S[MB_BLOCK * OC_OUTER * OW_BLOCK]; )==""\n"
694R"==(if (WITH_SUM) { read_dst_block(S, dst, ow); } )==""\n"
695R"==(#if WITH_POST_OP )==""\n"
// OW_BLOCK == 1: a single post-op invocation over the whole of C / S, covering
// MB_BLOCK entries from po_mb and OC_OUTER * SUB_GROUP_SIZE entries from po_oc.
696R"==(if (OW_BLOCK == 1) { )==""\n"
697R"==(const int po_mb = mb; )==""\n"
698R"==(const int po_oc = (g * OC + oc) % (OC * G); )==""\n"
699R"==(APPLY_POST_OPS_TRY_BURST(C, DATA_T, S, DATA_T, po_mb, MB_BLOCK, po_oc, )==""\n"
700R"==(OC_OUTER * SUB_GROUP_SIZE, sglid); )==""\n"
// OW_BLOCK != 1: for each (mb_idx, oc_idx) pair the OW_BLOCK values of C and S
// are converted to float scratch arrays, post-ops are applied to those, and
// the results are assigned back into C.
701R"==(} else { )==""\n"
702R"==(unroll_for(int mb_idx = 0; mb_idx < MB_BLOCK; ++mb_idx) { )==""\n"
703R"==(unroll_for(int oc_idx = 0; oc_idx < OC_OUTER; ++oc_idx) { )==""\n"
// Both coordinates are wrapped with a modulo (by MB and by OC * G).
704R"==(const int po_mb = (mb + mb_idx) % MB; )==""\n"
705R"==(const int po_oc = (g * OC + oc + (oc_idx * 16)) % (OC * G); )==""\n"
706R"==(float acc_ow[OW_BLOCK]; )==""\n"
707R"==(float sum_ow[OW_BLOCK]; )==""\n"
// Gather: element index is mb_idx * OC_OUTER * OW_BLOCK + oc_idx * OW_BLOCK
// + ow_idx in both C and S.
708R"==(unroll_for(int ow_idx = 0; ow_idx < OW_BLOCK; ++ow_idx) { )==""\n"
709R"==(acc_ow[ow_idx] )==""\n"
710R"==(= CONVERT_FLOAT_T(C[mb_idx * OC_OUTER * OW_BLOCK )==""\n"
711R"==(+ oc_idx * OW_BLOCK + ow_idx]); )==""\n"
712R"==(sum_ow[ow_idx] )==""\n"
713R"==(= CONVERT_FLOAT_T(S[mb_idx * OC_OUTER * OW_BLOCK )==""\n"
714R"==(+ oc_idx * OW_BLOCK + ow_idx]); )==""\n"
715R"==(} )==""\n"
716R"==(APPLY_POST_OPS_TRY_BURST(acc_ow, float, sum_ow, float, po_mb, 1, )==""\n"
717R"==(po_oc, SUB_GROUP_SIZE, sglid); )==""\n"
// Scatter the float results back to the same C positions they were read from.
718R"==(unroll_for(int ow_idx = 0; ow_idx < OW_BLOCK; ++ow_idx) { )==""\n"
719R"==(C[mb_idx * OC_OUTER * OW_BLOCK + oc_idx * OW_BLOCK + ow_idx] )==""\n"
720R"==(= acc_ow[ow_idx]; )==""\n"
721R"==(} )==""\n"
722R"==(} )==""\n"
723R"==(} )==""\n"
724R"==(} )==""\n"
725R"==(#endif )==""\n"
// Final store of the accumulator through write_dst_block.
726R"==(write_dst_block((DATA_T *)(&C), dst, ow); )==""\n"
727R"==(} )==""\n"
728R"==()==";
729}
730}
731}
732}