// Namespace dnnl::impl::gpu::ocl.
1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// The declaration that follows stores the OpenCL C source of the
// gen9_conv_bwd_weights kernel as one C string. It is written as a sequence
// of adjacent raw string literals, one per kernel source line, each followed
// by a separate newline literal.
// NOTE(review): the layout suggests this text is generated from a kernel
// source file - confirm before editing it by hand.
5const char *gen9_conv_bwd_weights_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2021 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
// Kernel preamble: shared headers, 2D/3D selection, padding predicates and
// diff_dst block-read helpers.
// ocl_types.h is included twice, once before and once after DT_UNDEF is
// defined.
// NOTE(review): presumably DT_UNDEF tells the shared headers that no generic
// data type is configured for this kernel - confirm against ocl_types.h.
20R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
21R"==(#define DT_UNDEF )==""\n"
22R"==(#include "gpu/ocl/ocl_math_utils.h" )==""\n"
23R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// CASE_3D is 1 only when OD > 1, otherwise 0.
24R"==(#if OD > 1 )==""\n"
25R"==(#define CASE_3D 1 )==""\n"
26R"==(#else )==""\n"
27R"==(#define CASE_3D 0 )==""\n"
28R"==(#endif )==""\n"
// HAS_PAD_D / HAS_PAD_H / HAS_PAD_W are true when either of the two padding
// macros for that axis (Px or Px_R) is non-zero.
29R"==(#define HAS_PAD_D (PD != 0 || PD_R != 0) )==""\n"
30R"==(#define HAS_PAD_H (PH != 0 || PH_R != 0) )==""\n"
31R"==(#define HAS_PAD_W (PW != 0 || PW_R != 0) )==""\n"
// Sub-group block reads of diff_dst, reinterpreted according to the
// destination data type. The f32 branch defines only the scalar
// BLOCK_READ_DST; BLOCK_READ_DST8 is defined here for bf16 only.
// NOTE(review): the VER_8OW16C path uses BLOCK_READ_DST8 unconditionally, so
// for f32 it has to come from one of the included headers - confirm.
32R"==(#if DST_DT_F32 )==""\n"
33R"==(#define BLOCK_READ_DST(ptr) \ )==""\n"
34R"==(as_float(intel_sub_group_block_read((__global uint *)ptr)) )==""\n"
35R"==(#elif DST_DT_BF16 )==""\n"
36R"==(#define BLOCK_READ_DST(ptr) \ )==""\n"
37R"==(as_ushort(intel_sub_group_block_read_us((__global ushort *)ptr)) )==""\n"
38R"==(#define BLOCK_READ_DST8(ptr) \ )==""\n"
39R"==(as_ushort8(intel_sub_group_block_read_us8((__global ushort *)ptr)) )==""\n"
40R"==(#endif )==""\n"
// Everything from here to the matching #endif is compiled only when
// BWD_WEIGHTS == 1.
41R"==(#if BWD_WEIGHTS == 1 )==""\n"
// The work-group size is fixed to LWS_0 x LWS_1 x LWS_2. The sub-group size is
// pinned to SUB_GROUP_SIZE when either kernel variant is selected.
42R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) )==""\n"
43R"==(#if VER_16MB16C == 1 || VER_8OW16C == 1 )==""\n"
44R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
45R"==(#endif )==""\n"
// Kernel entry point.
//   src       - SRC_DATA_T input; only read by this kernel.
//   diff_wei  - atomic_float output, updated with atomic_add_global.
//   diff_bias - atomic_float output, updated with atomic_add_global when
//               WITH_BIAS == 1.
//   diff_dst  - DST_DATA_T input; only read by this kernel.
46R"==(__kernel void )==""\n"
47R"==(gen9_conv_bwd_weights(__global SRC_DATA_T *src, )==""\n"
48R"==(volatile __global atomic_float *diff_wei, )==""\n"
49R"==(volatile __global atomic_float *diff_bias, )==""\n"
50R"==(__global DST_DATA_T *diff_dst) { )==""\n"
// MAYBE_SKIP_NON_UNIFORM_WG is not defined in this file.
// NOTE(review): presumably it makes surplus work-groups exit early - confirm.
51R"==(MAYBE_SKIP_NON_UNIFORM_WG(); )==""\n"
// Kernel body variant selected by VER_16MB16C == 1.
52R"==(#if VER_16MB16C == 1 )==""\n"
// Global id 1 enumerates kernel taps; split it into kd, kh and kw
// (kd is 0 unless CASE_3D).
53R"==(const uint ksp = get_global_id(1); )==""\n"
54R"==(#if CASE_3D )==""\n"
55R"==(const uint kd = ksp / (KW * KH); )==""\n"
56R"==(const uint khw = ksp % (KW * KH); )==""\n"
57R"==(#else )==""\n"
58R"==(const uint khw = ksp; )==""\n"
59R"==(const uint kd = 0; )==""\n"
60R"==(#endif )==""\n"
61R"==(const uint kh = khw / KW; )==""\n"
62R"==(const uint kw = khw % KW; )==""\n"
63R"==(const uint sglid = get_sub_group_local_id(); )==""\n"
// Global id 2 packs a work chunk index (split further below into a spatial
// and a minibatch part) together with an (icb, ocb) pair.
64R"==(const uint chunk = get_global_id(2) / ((IC / ICB) * (OC / OCB)); )==""\n"
65R"==(const uint icb_ocb = get_global_id(2) % ((IC / ICB) * (OC / OCB)); )==""\n"
66R"==(const uint icb = icb_ocb % (IC / ICB); )==""\n"
67R"==(const uint ocb = icb_ocb / (IC / ICB); )==""\n"
// IS_DW: g is fixed to 0, each sub-group gets one oc index and ic equals oc.
// Otherwise global id 0 is decoded into g, oc and ic.
68R"==(#if IS_DW )==""\n"
69R"==(const uint g = 0; )==""\n"
70R"==(const uint oc )==""\n"
71R"==(= get_group_id(0) * (LWS_0 / SUB_GROUP_SIZE) + get_sub_group_id(); )==""\n"
72R"==(const uint ic = oc; )==""\n"
73R"==(#else )==""\n"
74R"==(const uint g_ic_oc = get_global_id(0); )==""\n"
75R"==(const uint g = g_ic_oc / (OCB * (ICB / IC_BLOCK)); )==""\n"
76R"==(const uint io = g_ic_oc % (OCB * (ICB / IC_BLOCK)); )==""\n"
77R"==(const uint oc = (io % OCB) / OC_BLOCK + ocb * (OCB / OC_BLOCK); )==""\n"
78R"==(const uint ic = io / OCB + icb * (ICB / IC_BLOCK); )==""\n"
79R"==(#endif )==""\n"
// Split the chunk index and derive the first od / oh / ow and the minibatch
// range [mb, mb_end) handled by this work-item.
80R"==(const uint sp_chunk = chunk % OSP_CHUNK; )==""\n"
81R"==(const uint mb_chunk = chunk / OSP_CHUNK; )==""\n"
82R"==(const uint oh_nb = (OH + OHB - 1) / OHB; )==""\n"
83R"==(const uint ow_nb = (OW + OWB - 1) / OWB; )==""\n"
84R"==(const uint od_beg = (sp_chunk / ow_nb) / oh_nb * ODB; )==""\n"
85R"==(const uint oh_beg = (sp_chunk / ow_nb) % oh_nb * OHB; )==""\n"
86R"==(const uint ow_beg = (sp_chunk % ow_nb) * OWB; )==""\n"
87R"==(const uint mb = mb_chunk * (MB_CHUNK_SIZE); )==""\n"
88R"==(const uint mb_end = min((mb_chunk + 1) * (MB_CHUNK_SIZE), (uint)MB); )==""\n"
// Bias is accumulated only for kernel tap (0, 0, 0) and, unless IS_DW, only
// for ic == 0.
89R"==(const bool do_bias = (ic == 0 || IS_DW) && kh == 0 && kw == 0 && kd == 0; )==""\n"
// Move the base pointers to this work-item's channel block and group. The
// minibatch offset is applied to src here; for diff_dst it is applied later,
// when diff_dst1_ is formed.
90R"==(src += ic * ID * IH * IW * IC_BLOCK * MB_BLOCK + mb * IC * G * ID * IH * IW )==""\n"
91R"==(+ g * IC * ID * IH * IW * MB_BLOCK; )==""\n"
92R"==(diff_dst += oc * OD * OH * OW * OC_BLOCK * MB_BLOCK )==""\n"
93R"==(+ g * OC * OD * OH * OW * MB_BLOCK; )==""\n"
94R"==(#if WITH_BIAS == 1 )==""\n"
95R"==(diff_bias += g * OC + oc * OC_BLOCK + sglid; )==""\n"
96R"==(float bias_loc = 0.0f; )==""\n"
97R"==(#endif )==""\n"
// Private accumulators: one float for IS_DW, two float8 (16 values per lane)
// otherwise.
98R"==(#if IS_DW )==""\n"
99R"==(float blockC00 = 0.0f; )==""\n"
100R"==(#else )==""\n"
101R"==(float8 blockC00 = 0.0f; )==""\n"
102R"==(float8 blockC01 = 0.0f; )==""\n"
103R"==(#endif )==""\n"
// When MB != MB_CHUNK * MB_BLOCK a do/while loop walks omb from mb up to
// mb_end in steps of MB_BLOCK; otherwise a single pass is made at mb.
// Both tensors are read through float pointers in this variant, whatever
// SRC_DATA_T and DST_DATA_T are.
// NOTE(review): presumably this variant is only dispatched for f32 - confirm.
104R"==(#if MB != (MB_CHUNK * MB_BLOCK) )==""\n"
105R"==(uint omb = mb; )==""\n"
106R"==(do { )==""\n"
107R"==(const __global float *diff_dst1_ )==""\n"
108R"==(= diff_dst + omb * OC * G * OD * OH * OW; )==""\n"
109R"==(#else )==""\n"
110R"==(const __global float *diff_dst1_ = diff_dst + mb * OC * G * OD * OH * OW; )==""\n"
111R"==(#endif )==""\n"
// Visit every output point of this work-item's spatial tile, clipped to the
// tensor bounds.
112R"==(for (uint od = od_beg; od < min(od_beg + ODB, (uint)OD); od++) { )==""\n"
113R"==(for (uint oh = oh_beg; oh < min(oh_beg + OHB, (uint)OH); oh++) { )==""\n"
114R"==(for (uint ow = ow_beg; ow < min(ow_beg + OWB, (uint)OW); ow++) { )==""\n"
115R"==(const __global float *diff_dst1 = diff_dst1_ )==""\n"
116R"==(+ od * OH * OW * OC_BLOCK * MB_BLOCK )==""\n"
117R"==(+ oh * OW * OC_BLOCK * MB_BLOCK )==""\n"
118R"==(+ ow * OC_BLOCK * MB_BLOCK; )==""\n"
// Input coordinates matching this output point and kernel tap.
// NOTE(review): ih, iw and id are uint, so the comparisons against 0 below
// can never be true; a coordinate that would be negative wraps to a large
// value and is rejected by the upper-bound tests instead (assuming IW, IH
// and ID expand to int constants - confirm).
119R"==(const uint ih = oh * SH - PH + kh * (1 + DH); )==""\n"
120R"==(const uint iw = ow * SW - PW + kw * (1 + DW); )==""\n"
121R"==(#if CASE_3D )==""\n"
122R"==(const uint id = od * SD - PD + kd * (1 + DD); )==""\n"
123R"==(#endif )==""\n"
124R"==(if (iw < 0 || ih < 0 || iw >= IW || ih >= IH )==""\n"
125R"==(#if CASE_3D )==""\n"
126R"==(|| id < 0 || id >= ID )==""\n"
127R"==(#endif )==""\n"
128R"==() { )==""\n"
// Input position is out of range: nothing is added to the weight
// accumulators, but the bias sum still takes both 8-wide block reads of
// diff_dst (at diff_dst1 and at diff_dst1 + 8 * OC_BLOCK).
129R"==(#if WITH_BIAS == 1 )==""\n"
130R"==(if (do_bias) { )==""\n"
131R"==(float8 blockB )==""\n"
132R"==(= as_float8(intel_sub_group_block_read8(( )==""\n"
133R"==(const __global uint *)(diff_dst1))); )==""\n"
134R"==(for (int i = 0; i < 8; i++) )==""\n"
135R"==(bias_loc += blockB[i]; )==""\n"
136R"==(blockB = as_float8(intel_sub_group_block_read8( )==""\n"
137R"==((const __global uint *)(diff_dst1 )==""\n"
138R"==(+ 8 * OC_BLOCK))); )==""\n"
139R"==(for (int i = 0; i < 8; i++) )==""\n"
140R"==(bias_loc += blockB[i]; )==""\n"
141R"==(} )==""\n"
142R"==(#endif )==""\n"
143R"==(continue; )==""\n"
144R"==(} )==""\n"
145R"==(const __global float *src1 = src )==""\n"
146R"==(+ ih * IW * IC_BLOCK * MB_BLOCK )==""\n"
147R"==(+ iw * IC_BLOCK * MB_BLOCK; )==""\n"
148R"==(#if CASE_3D )==""\n"
149R"==(src1 += id * IH * IW * IC_BLOCK * MB_BLOCK; )==""\n"
150R"==(#endif )==""\n"
// TRANSPOSE_8 builds a float8 whose k-th component is _block[_row] taken from
// sub-group lane _col + k, for k = 0..7.
151R"==(#define TRANSPOSE_8(_block, _row, _col) \ )==""\n"
152R"==((float8)(intel_sub_group_shuffle(_block[_row], 0 + _col), \ )==""\n"
153R"==(intel_sub_group_shuffle(_block[_row], 1 + _col), \ )==""\n"
154R"==(intel_sub_group_shuffle(_block[_row], 2 + _col), \ )==""\n"
155R"==(intel_sub_group_shuffle(_block[_row], 3 + _col), \ )==""\n"
156R"==(intel_sub_group_shuffle(_block[_row], 4 + _col), \ )==""\n"
157R"==(intel_sub_group_shuffle(_block[_row], 5 + _col), \ )==""\n"
158R"==(intel_sub_group_shuffle(_block[_row], 6 + _col), \ )==""\n"
159R"==(intel_sub_group_shuffle(_block[_row], 7 + _col)) )==""\n"
// FMA8 is a float8 fused multiply-add; a scalar first operand is broadcast by
// the float8 cast.
160R"==(#define FMA8(a, b, c) fma((float8)(a), (float8)b, (float8)c) )==""\n"
// MULTIPLY_BLOCKS_8x8 adds to _result, for each component r of _blockB, that
// component times component r of _blockA gathered from lanes col..col+7.
161R"==(#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, col) \ )==""\n"
162R"==({ \ )==""\n"
163R"==(_result = FMA8(_blockB.s0, TRANSPOSE_8(_blockA, 0, col), _result); \ )==""\n"
164R"==(_result = FMA8(_blockB.s1, TRANSPOSE_8(_blockA, 1, col), _result); \ )==""\n"
165R"==(_result = FMA8(_blockB.s2, TRANSPOSE_8(_blockA, 2, col), _result); \ )==""\n"
166R"==(_result = FMA8(_blockB.s3, TRANSPOSE_8(_blockA, 3, col), _result); \ )==""\n"
167R"==(_result = FMA8(_blockB.s4, TRANSPOSE_8(_blockA, 4, col), _result); \ )==""\n"
168R"==(_result = FMA8(_blockB.s5, TRANSPOSE_8(_blockA, 5, col), _result); \ )==""\n"
169R"==(_result = FMA8(_blockB.s6, TRANSPOSE_8(_blockA, 6, col), _result); \ )==""\n"
170R"==(_result = FMA8(_blockB.s7, TRANSPOSE_8(_blockA, 7, col), _result); \ )==""\n"
171R"==(} )==""\n"
// IS_DW: per lane, multiply matching src and diff_dst values and accumulate
// into the scalar blockC00; both 8-wide halves are loaded up front.
172R"==(#if IS_DW )==""\n"
173R"==(float8 blockA = as_float8(intel_sub_group_block_read8( )==""\n"
174R"==((const __global uint *)(src1))); )==""\n"
175R"==(float8 blockA1 = as_float8(intel_sub_group_block_read8( )==""\n"
176R"==((const __global uint *)(src1 + 8 * IC_BLOCK))); )==""\n"
177R"==(float8 blockB = as_float8(intel_sub_group_block_read8( )==""\n"
178R"==((const __global uint *)(diff_dst1))); )==""\n"
179R"==(float8 blockB1 = as_float8(intel_sub_group_block_read8( )==""\n"
180R"==((const __global uint *)(diff_dst1 + 8 * OC_BLOCK))); )==""\n"
181R"==(for (int i = 0; i < 8; i++) { )==""\n"
182R"==(blockC00 = fma(blockA[i], blockB[i], blockC00); )==""\n"
183R"==(} )==""\n"
184R"==(#if WITH_BIAS == 1 )==""\n"
185R"==(for (int i = 0; i < 8; i++) )==""\n"
186R"==(bias_loc += blockB[i]; )==""\n"
187R"==(#endif )==""\n"
188R"==(for (int i = 0; i < 8; i++) { )==""\n"
189R"==(blockC00 = fma(blockA1[i], blockB1[i], blockC00); )==""\n"
190R"==(} )==""\n"
191R"==(#if WITH_BIAS == 1 )==""\n"
192R"==(for (int i = 0; i < 8; i++) )==""\n"
193R"==(bias_loc += blockB1[i]; )==""\n"
194R"==(#endif )==""\n"
// Otherwise: two passes of 8 values each. Every pass feeds both accumulators,
// blockC00 using lanes 0..7 of blockA and blockC01 using lanes 8..15.
195R"==(#else )==""\n"
196R"==(float8 blockA = as_float8(intel_sub_group_block_read8( )==""\n"
197R"==((const __global uint *)(src1))); )==""\n"
198R"==(float8 blockB = as_float8(intel_sub_group_block_read8( )==""\n"
199R"==((const __global uint *)(diff_dst1))); )==""\n"
200R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0); )==""\n"
201R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockA, blockB, 8); )==""\n"
202R"==(#if WITH_BIAS == 1 )==""\n"
203R"==(for (int i = 0; i < 8; i++) )==""\n"
204R"==(bias_loc += blockB[i]; )==""\n"
205R"==(#endif )==""\n"
// Second pass: the next 8 values, 8 * IC_BLOCK and 8 * OC_BLOCK elements on.
206R"==(blockA = as_float8(intel_sub_group_block_read8( )==""\n"
207R"==((const __global uint *)(src1 + 8 * IC_BLOCK))); )==""\n"
208R"==(blockB = as_float8(intel_sub_group_block_read8( )==""\n"
209R"==((const __global uint *)(diff_dst1 + 8 * OC_BLOCK))); )==""\n"
210R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0); )==""\n"
211R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockA, blockB, 8); )==""\n"
212R"==(#if WITH_BIAS == 1 )==""\n"
213R"==(for (int i = 0; i < 8; i++) )==""\n"
214R"==(bias_loc += blockB[i]; )==""\n"
215R"==(#endif )==""\n"
216R"==(#endif )==""\n"
217R"==(} )==""\n"
218R"==(} )==""\n"
219R"==(} )==""\n"
// Next minibatch block: src is advanced here, while diff_dst1_ is recomputed
// from omb at the top of the loop.
220R"==(#if MB != (MB_CHUNK * MB_BLOCK) )==""\n"
221R"==(omb += MB_BLOCK; )==""\n"
222R"==(src += IC * G * ID * IH * IW * MB_BLOCK; )==""\n"
223R"==(} while (omb < mb_end); )==""\n"
224R"==(#endif )==""\n"
// Add the partial bias atomically, but only from lanes whose channel index is
// below G_WO_PADDING (IS_DW) or OC_WO_PADDING.
225R"==(#if WITH_BIAS == 1 )==""\n"
226R"==(if (do_bias )==""\n"
227R"==(&& oc * OC_BLOCK + sglid < (IS_DW ? G_WO_PADDING : OC_WO_PADDING)) )==""\n"
228R"==(atomic_add_global(diff_bias, bias_loc); )==""\n"
229R"==(#endif )==""\n"
// Add the accumulated products into diff_wei with atomic_add_global.
// IS_DW writes one value per lane; otherwise each lane writes 16 values,
// OC_BLOCK elements apart, at lane offset sglid.
230R"==(#if IS_DW )==""\n"
231R"==(diff_wei += oc * KD * KH * KW * OC_BLOCK + kd * KH * KW * OC_BLOCK )==""\n"
232R"==(+ kh * KW * OC_BLOCK + kw * OC_BLOCK; )==""\n"
233R"==(atomic_add_global(diff_wei + sglid, blockC00); )==""\n"
234R"==(#else )==""\n"
235R"==(diff_wei += ic * OC * KD * KH * KW * IC_BLOCK )==""\n"
236R"==(+ oc * KD * KH * KW * IC_BLOCK * OC_BLOCK )==""\n"
237R"==(+ kd * KH * KW * IC_BLOCK * OC_BLOCK + kh * KW * IC_BLOCK * OC_BLOCK )==""\n"
238R"==(+ kw * IC_BLOCK * OC_BLOCK + g * OC * IC * KD * KH * KW; )==""\n"
239R"==(for (int i = 0; i < 8; i++) )==""\n"
240R"==(atomic_add_global(diff_wei + i * OC_BLOCK + sglid, blockC00[i]); )==""\n"
241R"==(for (int i = 0; i < 8; i++) )==""\n"
242R"==(atomic_add_global(diff_wei + (8 + i) * OC_BLOCK + sglid, blockC01[i]); )==""\n"
243R"==(#endif )==""\n"
244R"==(#endif )==""\n"
// Kernel body variant selected by VER_8OW16C == 1.
245R"==(#if VER_8OW16C == 1 )==""\n"
// HAS_PAD_W is already defined unconditionally near the top of the kernel
// source with a different body. Undefine it first so that the variant-specific
// definition below replaces it cleanly instead of conflicting with it.
R"==(#undef HAS_PAD_W )==""\n"
246R"==(#define HAS_PAD_W (PW > 0 || OW * SW - PW + (KW - 1) * (1 + DW) >= IW) )==""\n"
247R"==(const int sglid = get_sub_group_local_id(); )==""\n"
// IC == 3: every lane gets its own ksp (16 per global id 1), and ksp also
// encodes the input channel (ICX == 3). Otherwise ksp is global id 1 and
// ICX == 1.
248R"==(#if IC == 3 )==""\n"
249R"==(const int ksp = get_global_id(1) * 16 + sglid; )==""\n"
250R"==(#else )==""\n"
251R"==(const int ksp = get_global_id(1); )==""\n"
252R"==(#endif )==""\n"
253R"==(const int ICX = IC == 3 ? 3 : 1; )==""\n"
// Split ksp into kd, kh and kw (kd is 0 unless CASE_3D).
254R"==(#if CASE_3D )==""\n"
255R"==(const int kd = ksp / (KW * KH * ICX); )==""\n"
256R"==(const int khw = ksp % (KW * KH * ICX); )==""\n"
257R"==(#else )==""\n"
258R"==(const int khw = ksp; )==""\n"
259R"==(const int kd = 0; )==""\n"
260R"==(#endif )==""\n"
261R"==(const int kh = khw / (KW * ICX); )==""\n"
262R"==(const int kw = (khw % (KW * ICX)) % KW; )==""\n"
// Global id 2 packs the work chunk index (low part, modulo NCHUNK) with an
// (icb, ocb) pair.
263R"==(const int chunk = get_global_id(2) % NCHUNK; )==""\n"
264R"==(const int icb_ocb = get_global_id(2) / NCHUNK; )==""\n"
265R"==(const int icb = icb_ocb % (IC / ICB); )==""\n"
266R"==(const int ocb = icb_ocb / (IC / ICB); )==""\n"
// IS_DW: g is fixed to 0, each sub-group gets one oc index and ic equals oc.
// Otherwise global id 0 is decoded into g, oc and ic; for IC == 3 the input
// channel comes from ksp instead.
// NOTE(review): g and io are split using OC and IC, while oc and ic are then
// decoded with OCB and ICB - confirm this matches the host-side global size.
267R"==(#if IS_DW )==""\n"
268R"==(const int g = 0; )==""\n"
269R"==(const int oc )==""\n"
270R"==(= get_group_id(0) * (LWS_0 / SUB_GROUP_SIZE) + get_sub_group_id(); )==""\n"
271R"==(const int ic = oc; )==""\n"
272R"==(#else )==""\n"
273R"==(const int g_ic_oc = get_global_id(0); )==""\n"
274R"==(const int g = g_ic_oc / (OC * (IC / IC_BLOCK)); )==""\n"
275R"==(const int io = g_ic_oc % (OC * (IC / IC_BLOCK)); )==""\n"
276R"==(const int oc = (io % OCB) / OC_BLOCK + ocb * (OCB / OC_BLOCK); )==""\n"
277R"==(const int ic = (IC == 3) ? (khw % (KW * ICX)) / KW )==""\n"
278R"==(: (io / OCB + icb * (ICB / IC_BLOCK)); )==""\n"
279R"==(#endif )==""\n"
// Split the chunk index and derive the first od / oh / ow and the minibatch
// range [mb, mb_end) handled by this work-item.
280R"==(const int sp_chunk = chunk % OSP_CHUNK; )==""\n"
281R"==(const int mb_chunk = chunk / OSP_CHUNK; )==""\n"
282R"==(const int ow_nb = (OW + OWB - 1) / OWB; )==""\n"
283R"==(const int oh_nb = (OH + OHB - 1) / OHB; )==""\n"
284R"==(const int od_beg = ((sp_chunk / ow_nb) / oh_nb) * ODB; )==""\n"
285R"==(const int oh_beg = ((sp_chunk / ow_nb) % oh_nb) * OHB; )==""\n"
286R"==(const int ow_beg = (sp_chunk % ow_nb) * OWB; )==""\n"
287R"==(const int mb = mb_chunk * MB_CHUNK_SIZE; )==""\n"
288R"==(const int mb_end = min((mb_chunk + 1) * MB_CHUNK_SIZE, MB); )==""\n"
// Bias is accumulated only where global id 1 is 0 (IC == 3), or only for
// kernel tap (0, 0, 0) and, unless IS_DW, ic == 0.
289R"==(#if IC == 3 )==""\n"
290R"==(const bool do_bias = get_global_id(1) == 0; )==""\n"
291R"==(#else )==""\n"
292R"==(const bool do_bias = (ic == 0 || IS_DW) && kh == 0 && kw == 0 && kd == 0; )==""\n"
293R"==(#endif )==""\n"
294R"==(const int OW_LOOP_BLOCK = 8; )==""\n"
// Move the base pointers to this work-item's group, first minibatch and
// channel block. For IC == 3 the channel offset is applied at load time
// instead.
295R"==(#if IC == 3 )==""\n"
296R"==(src += mb * IC * G * ID * IH * IW + g * IC * ID * IH * IW * MB_BLOCK; )==""\n"
297R"==(#else )==""\n"
298R"==(src += ic * ID * IH * IW * IC_BLOCK * MB_BLOCK + mb * IC * G * ID * IH * IW )==""\n"
299R"==(+ g * IC * ID * IH * IW * MB_BLOCK; )==""\n"
300R"==(#endif )==""\n"
301R"==(diff_dst += oc * OD * OH * OW * OC_BLOCK * MB_BLOCK )==""\n"
302R"==(+ g * OC * OD * OH * OW * MB_BLOCK; )==""\n"
303R"==(#if WITH_BIAS == 1 )==""\n"
304R"==(diff_bias += g * OC + oc * OC_BLOCK + sglid; )==""\n"
305R"==(float bias_loc = 0.0f; )==""\n"
306R"==(#endif )==""\n"
// Private accumulators: one float for IS_DW, two float8 otherwise
// (including IC == 3).
307R"==(#if IC == 3 )==""\n"
308R"==(float8 blockC00 = 0.0f; )==""\n"
309R"==(float8 blockC01 = 0.0f; )==""\n"
310R"==(#elif IS_DW )==""\n"
311R"==(float blockC00 = 0.0f; )==""\n"
312R"==(#else )==""\n"
313R"==(float8 blockC00 = 0.0f; )==""\n"
314R"==(float8 blockC01 = 0.0f; )==""\n"
315R"==(#endif )==""\n"
// One iteration per minibatch entry; src is advanced at the end of each
// iteration while diff_dst1_ is recomputed from omb.
316R"==(for (int omb = mb; omb < mb_end; omb++) { )==""\n"
317R"==(const __global DST_DATA_T *diff_dst1_ )==""\n"
318R"==(= diff_dst + omb * OC * G * OD * OH * OW; )==""\n"
319R"==(for (int od = od_beg; od < min(od_beg + ODB, OD); od++) )==""\n"
320R"==(for (int oh = oh_beg; oh < min(oh_beg + OHB, OH); oh++) { )==""\n"
321R"==(const __global DST_DATA_T *diff_dst1 = diff_dst1_ )==""\n"
322R"==(+ od * OH * OW * OC_BLOCK + oh * OW * OC_BLOCK; )==""\n"
// skip is set when the input row (and, for CASE_3D, the input plane) for this
// output row and kernel tap lies outside the input. The tests are the
// rearranged forms of ih < 0, ih >= IH, id < 0 and id >= ID.
323R"==(bool skip = false; )==""\n"
324R"==(if (oh * SH + kh * (1 + DH) < PH )==""\n"
325R"==(|| oh * SH + kh * (1 + DH) >= IH + PH )==""\n"
326R"==(#if CASE_3D )==""\n"
327R"==(|| od * SD + kd * (1 + DD) < PD )==""\n"
328R"==(|| od * SD + kd * (1 + DD) >= ID + PD )==""\n"
329R"==(#endif )==""\n"
330R"==() { )==""\n"
331R"==(skip = true; )==""\n"
332R"==(} )==""\n"
333R"==(const int id = od * SD - PD + kd * (1 + DD); )==""\n"
334R"==(const int ih = oh * SH - PH + kh * (1 + DH); )==""\n"
335R"==(__global SRC_DATA_T *src1; )==""\n"
// Main loop: full OW_BLOCK-wide groups of output columns, up to the last
// multiple of OW_BLOCK that fits in OW.
336R"==(for (int ow = ow_beg; )==""\n"
337R"==(ow < min(ow_beg + OWB, (OW / OW_BLOCK) * OW_BLOCK); )==""\n"
338R"==(ow += OW_BLOCK) { )==""\n"
339R"==(const int iw = ow * SW - PW + kw * (1 + DW); )==""\n"
340R"==(src1 = src + id * IH * IW * IC_BLOCK + ih * IW * IC_BLOCK )==""\n"
341R"==(+ iw * IC_BLOCK; )==""\n"
// TRANSPOSE_8 builds a float8 whose k-th component is _block[_row] taken from
// sub-group lane _col + k, for k = 0..7. Here the expansion is wrapped in
// braces, so the float8 cast that FMA8 applies to its second argument turns
// it into a brace-initialised float8.
342R"==(#define TRANSPOSE_8(_block, _row, _col) \ )==""\n"
343R"==({ \ )==""\n"
344R"==((float8)(intel_sub_group_shuffle(_block[_row], 0 + _col), \ )==""\n"
345R"==(intel_sub_group_shuffle(_block[_row], 1 + _col), \ )==""\n"
346R"==(intel_sub_group_shuffle(_block[_row], 2 + _col), \ )==""\n"
347R"==(intel_sub_group_shuffle(_block[_row], 3 + _col), \ )==""\n"
348R"==(intel_sub_group_shuffle(_block[_row], 4 + _col), \ )==""\n"
349R"==(intel_sub_group_shuffle(_block[_row], 5 + _col), \ )==""\n"
350R"==(intel_sub_group_shuffle(_block[_row], 6 + _col), \ )==""\n"
351R"==(intel_sub_group_shuffle(_block[_row], 7 + _col)) \ )==""\n"
352R"==(} )==""\n"
// FMA8 is a float8 fused multiply-add; a scalar first operand is broadcast by
// the float8 cast.
353R"==(#define FMA8(a, b, c) fma((float8)(a), (float8)b, (float8)c) )==""\n"
// MULTIPLY_BLOCKS_8x8 adds to _result, for each component r of _blockB, that
// component times component r of _blockA gathered from lanes col..col+7.
354R"==(#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, col) \ )==""\n"
355R"==({ \ )==""\n"
356R"==(_result = FMA8(_blockB.s0, TRANSPOSE_8(_blockA, 0, col), _result); \ )==""\n"
357R"==(_result = FMA8(_blockB.s1, TRANSPOSE_8(_blockA, 1, col), _result); \ )==""\n"
358R"==(_result = FMA8(_blockB.s2, TRANSPOSE_8(_blockA, 2, col), _result); \ )==""\n"
359R"==(_result = FMA8(_blockB.s3, TRANSPOSE_8(_blockA, 3, col), _result); \ )==""\n"
360R"==(_result = FMA8(_blockB.s4, TRANSPOSE_8(_blockA, 4, col), _result); \ )==""\n"
361R"==(_result = FMA8(_blockB.s5, TRANSPOSE_8(_blockA, 5, col), _result); \ )==""\n"
362R"==(_result = FMA8(_blockB.s6, TRANSPOSE_8(_blockA, 6, col), _result); \ )==""\n"
363R"==(_result = FMA8(_blockB.s7, TRANSPOSE_8(_blockA, 7, col), _result); \ )==""\n"
364R"==(} )==""\n"
// Load 8 src values into blockA, one per output column i at input column
// iw + i * SW. A skipped row gives all zeros; with HAS_PAD_W, columns outside
// [0, IW) give zero. IC == 3 reads scalars through SRC_TO_REF at this lane's
// channel; the other case uses one sub-group block read per column.
365R"==(float8 blockA, blockB; )==""\n"
366R"==(#if IC == 3 )==""\n"
367R"==(if (skip) { )==""\n"
368R"==(blockA = 0.0f; )==""\n"
369R"==(} else { )==""\n"
370R"==(for (int i = 0; i < 8; i++) { )==""\n"
371R"==(if (HAS_PAD_W )==""\n"
372R"==(&& (iw + i * SW < 0 || iw + i * SW >= IW)) )==""\n"
373R"==(blockA[i] = 0; )==""\n"
374R"==(else )==""\n"
375R"==(blockA[i] = SRC_TO_REF( )==""\n"
376R"==(src1[ic * ID * IH * IW + i * SW]); )==""\n"
377R"==(} )==""\n"
378R"==(} )==""\n"
379R"==(#else )==""\n"
380R"==(if (skip) { )==""\n"
381R"==(blockA = 0.0f; )==""\n"
382R"==(} else { )==""\n"
383R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
384R"==(if (HAS_PAD_W )==""\n"
385R"==(&& (iw + i * SW < 0 || iw + i * SW >= IW)) { )==""\n"
386R"==(blockA[i] = 0; )==""\n"
387R"==(} else { )==""\n"
388R"==(blockA[i] = as_float(intel_sub_group_block_read( )==""\n"
389R"==((const __global uint *)(&src1[i )==""\n"
390R"==(* IC_BLOCK * SW]))); )==""\n"
391R"==(} )==""\n"
392R"==(} )==""\n"
393R"==(} )==""\n"
394R"==(#endif )==""\n"
// diff_dst is always loaded, also for skipped rows, because the bias sum
// below needs it.
395R"==(blockB = DST_TO_REF8( )==""\n"
396R"==(BLOCK_READ_DST8(diff_dst1 + ow * OC_BLOCK)); )==""\n"
// IC == 3 passes the operands in the opposite order (diff_dst is the gathered
// block); IS_DW is a per-lane multiply-accumulate.
397R"==(#if IC == 3 )==""\n"
398R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockB, blockA, 0); )==""\n"
399R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockB, blockA, 8); )==""\n"
400R"==(#elif IS_DW )==""\n"
401R"==(for (int i = 0; i < OW_LOOP_BLOCK; i++) { )==""\n"
402R"==(blockC00 = fma(blockA[i], blockB[i], blockC00); )==""\n"
403R"==(} )==""\n"
404R"==(#else )==""\n"
405R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0); )==""\n"
406R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockA, blockB, 8); )==""\n"
407R"==(#endif )==""\n"
408R"==(#if WITH_BIAS == 1 )==""\n"
409R"==(for (int i = 0; i < OW_LOOP_BLOCK; i++) { )==""\n"
410R"==(bias_loc += blockB[i]; )==""\n"
411R"==(} )==""\n"
412R"==(#endif )==""\n"
413R"==(} )==""\n"
// Tail loop: output columns past the last full OW_BLOCK group. Only the first
// min(OW_LOOP_BLOCK, OW - ow) components of blockA and blockB are filled and
// used. id, ih and src1 are redeclared here with the same expressions as in
// the enclosing scope.
414R"==(for (int ow = (OW / OW_BLOCK) * OW_BLOCK; )==""\n"
415R"==(ow < min(ow_beg + OWB, OW); ow += OW_LOOP_BLOCK) { )==""\n"
416R"==(const int id = od * SD - PD + kd * (1 + DD); )==""\n"
417R"==(const int ih = oh * SH - PH + kh * (1 + DH); )==""\n"
418R"==(const int iw = ow * SW - PW + kw * (1 + DW); )==""\n"
419R"==(__global SRC_DATA_T *src1; )==""\n"
420R"==(float8 blockA, blockB; )==""\n"
421R"==(src1 = src + id * IH * IW * IC_BLOCK + ih * IW * IC_BLOCK )==""\n"
422R"==(+ iw * IC_BLOCK; )==""\n"
423R"==(#if IC == 3 )==""\n"
424R"==(if (skip) { )==""\n"
425R"==(blockA = 0.0f; )==""\n"
426R"==(} else { )==""\n"
427R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) { )==""\n"
428R"==(if (HAS_PAD_W )==""\n"
429R"==(&& (iw + i * SW < 0 || iw + i * SW >= IW)) )==""\n"
430R"==(blockA[i] = 0; )==""\n"
431R"==(else )==""\n"
432R"==(blockA[i] = SRC_TO_REF( )==""\n"
433R"==(src1[ic * ID * IH * IW + i * SW]); )==""\n"
434R"==(} )==""\n"
435R"==(} )==""\n"
436R"==(#else )==""\n"
437R"==(if (skip) { )==""\n"
438R"==(blockA = 0.0f; )==""\n"
439R"==(} else { )==""\n"
440R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) { )==""\n"
441R"==(if (HAS_PAD_W )==""\n"
442R"==(&& (iw + i * SW < 0 || iw + i * SW >= IW)) { )==""\n"
443R"==(blockA[i] = 0; )==""\n"
444R"==(} else { )==""\n"
445R"==(blockA[i] = as_float(intel_sub_group_block_read( )==""\n"
446R"==((const __global uint *)(&src1[i )==""\n"
447R"==(* IC_BLOCK * SW]))); )==""\n"
448R"==(} )==""\n"
449R"==(} )==""\n"
450R"==(} )==""\n"
451R"==(#endif )==""\n"
// The tail reads diff_dst one column at a time instead of with an 8-wide read.
452R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) { )==""\n"
453R"==(blockB[i] = DST_TO_REF(BLOCK_READ_DST( )==""\n"
454R"==((&diff_dst1[(ow + i) * OC_BLOCK]))); )==""\n"
455R"==(} )==""\n"
// Same products as in the main loop, written out per column so that only the
// valid columns contribute.
456R"==(#if IC == 3 )==""\n"
457R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) { )==""\n"
458R"==(blockC00 = FMA8( )==""\n"
459R"==(blockA[i], TRANSPOSE_8(blockB, i, 0), blockC00); )==""\n"
460R"==(blockC01 = FMA8( )==""\n"
461R"==(blockA[i], TRANSPOSE_8(blockB, i, 8), blockC01); )==""\n"
462R"==(} )==""\n"
463R"==(#elif IS_DW )==""\n"
464R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) { )==""\n"
465R"==(blockC00 = fma(blockA[i], blockB[i], blockC00); )==""\n"
466R"==(} )==""\n"
467R"==(#else )==""\n"
468R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) { )==""\n"
469R"==(blockC00 = FMA8( )==""\n"
470R"==(blockB[i], TRANSPOSE_8(blockA, i, 0), blockC00); )==""\n"
471R"==(blockC01 = FMA8( )==""\n"
472R"==(blockB[i], TRANSPOSE_8(blockA, i, 8), blockC01); )==""\n"
473R"==(} )==""\n"
474R"==(#endif )==""\n"
475R"==(#if WITH_BIAS == 1 )==""\n"
476R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) )==""\n"
477R"==(bias_loc += blockB[i]; )==""\n"
478R"==(#endif )==""\n"
479R"==(} )==""\n"
480R"==(} )==""\n"
// Advance src to the next minibatch entry.
481R"==(src += G * IC * ID * IH * IW * MB_BLOCK; )==""\n"
482R"==(} )==""\n"
// Add the partial bias atomically, but only from lanes whose channel index is
// below G_WO_PADDING (IS_DW) or OC_WO_PADDING.
483R"==(#if WITH_BIAS == 1 )==""\n"
484R"==(if (do_bias )==""\n"
485R"==(&& oc * OC_BLOCK + sglid < (IS_DW ? G_WO_PADDING : OC_WO_PADDING)) )==""\n"
486R"==(atomic_add_global(diff_bias, bias_loc); )==""\n"
487R"==(#endif )==""\n"
// Add the accumulated products into diff_wei with atomic_add_global.
// IC == 3: the offset depends on this lane's own ksp, lanes whose ksp is past
// KH * KW * KD * IC return without writing (after the bias update above), and
// the remaining lanes write 16 consecutive values.
488R"==(#if IC == 3 )==""\n"
489R"==(diff_wei += ic * OC_BLOCK + oc * KD * KH * KW * IC * OC_BLOCK )==""\n"
490R"==(+ g * OC * IC * KD * KH * KW + kd * KH * KW * IC * OC_BLOCK )==""\n"
491R"==(+ kh * KW * IC * OC_BLOCK + kw * IC * OC_BLOCK; )==""\n"
492R"==(if (ksp >= KH * KW * KD * IC) return; )==""\n"
493R"==(for (int i = 0; i < 8; i++) )==""\n"
494R"==(atomic_add_global(diff_wei + i, blockC00[i]); )==""\n"
495R"==(for (int i = 0; i < 8; i++) )==""\n"
496R"==(atomic_add_global(diff_wei + 8 + i, blockC01[i]); )==""\n"
// IS_DW writes one value per lane.
497R"==(#elif IS_DW )==""\n"
498R"==(diff_wei += oc * KD * KH * KW * OC_BLOCK + kd * KH * KW * OC_BLOCK )==""\n"
499R"==(+ kh * KW * OC_BLOCK + kw * OC_BLOCK; )==""\n"
500R"==(atomic_add_global(diff_wei + sglid, blockC00); )==""\n"
// Otherwise each lane writes 16 values, OC_BLOCK elements apart, at lane
// offset sglid.
501R"==(#else )==""\n"
502R"==(diff_wei += ic * OC * KD * KH * KW * IC_BLOCK )==""\n"
503R"==(+ oc * KD * KH * KW * IC_BLOCK * OC_BLOCK )==""\n"
504R"==(+ kd * KH * KW * IC_BLOCK * OC_BLOCK + kh * KW * IC_BLOCK * OC_BLOCK )==""\n"
505R"==(+ kw * IC_BLOCK * OC_BLOCK + g * OC * IC * KD * KH * KW; )==""\n"
506R"==(for (int i = 0; i < 8; i++) )==""\n"
507R"==(atomic_add_global(diff_wei + i * OC_BLOCK + sglid, blockC00[i]); )==""\n"
508R"==(for (int i = 0; i < 8; i++) )==""\n"
509R"==(atomic_add_global(diff_wei + (8 + i) * OC_BLOCK + sglid, blockC01[i]); )==""\n"
510R"==(#endif )==""\n"
511R"==(#endif )==""\n"
// Closes the body of gen9_conv_bwd_weights.
512R"==(} )==""\n"
// Closes the BWD_WEIGHTS == 1 guard.
513R"==(#endif )==""\n"
// Final, empty piece of the kernel source string; ends the declaration.
514R"==()==";
// Close namespaces ocl, gpu, impl and dnnl, innermost first.
515}
516}
517}
518}