1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// OpenCL C program text for the gen9_conv_nhwc_bwd_weights kernel, stored as
// one C string. The text is assembled from adjacent raw string literals
// (delimiter ==), each followed by an explicit newline literal, so the
// compiler concatenates all pieces into a single null-terminated buffer.
//
// Layout of the embedded program:
//   - license header and two includes,
//   - CASE_3D, DIV_UP and RND_UP helper macros (always present),
//   - read_ic_block, read_oc_block and the kernel itself, all guarded by
//     BWD_WEIGHTS == 1.
//
// The embedded text uses many names it does not define (ID, IH, IW, OD, OH,
// OW, KD, KH, KW, SD, SH, SW, PD, PH, PW, DD, DH, DW, G, IC, OC,
// OC_WO_PADDING, IC_BLOCK, OC_BLOCK, ICB, OCB, ODB, OHB, OWB, OW_BLOCK,
// NCHUNK, OSP_CHUNK, MB, MB_CHUNK_SIZE, LWS_0, LWS_1, LWS_2, SUB_GROUP_SIZE,
// IS_DW, WITH_BIAS, BWD_WEIGHTS, MAYBE_SKIP_NON_UNIFORM_WG, atomic_float,
// atomic_add_global). Presumably these come from compiler definitions or from
// the two included headers - TODO confirm against the code that builds this
// program.
//
// NOTE(review): this file looks machine-generated from a kernel source file;
// confirm before hand-editing, since edits here may be overwritten.
// NOTE(review): one line of the embedded license header ends right after
// http: - the URL looks truncated, possibly by a comment-stripping step.
// Confirm against the generator.
5const char *gen9_conv_nhwc_bwd_weights_f32_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2020-2021 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/ocl_math_utils.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// CASE_3D is 1 when ID > 1 and 0 otherwise. DIV_UP is ceiling division and
// RND_UP rounds a up to the next multiple of b.
22R"==(#if ID > 1 )==""\n"
23R"==(#define CASE_3D 1 )==""\n"
24R"==(#else )==""\n"
25R"==(#define CASE_3D 0 )==""\n"
26R"==(#endif )==""\n"
27R"==(#define DIV_UP(a, b) (((a) + (b)-1) / (b)) )==""\n"
28R"==(#define RND_UP(a, b) (DIV_UP(a, b) * (b)) )==""\n"
29R"==(#if BWD_WEIGHTS == 1 )==""\n"
// read_ic_block: returns one float per work-item from ptr. When the channel
// count (G if IS_DW, else IC) is not a multiple of IC_BLOCK and fewer than
// IC_BLOCK channels remain past off, work-items whose sub-group local id is
// below the remainder read ptr[id] and the rest get 0.0f. Otherwise the value
// comes from a sub-group block read reinterpreted as float.
30R"==(inline float read_ic_block(const __global float *ptr, int off) { )==""\n"
31R"==(#if (IS_DW ? G : IC) % IC_BLOCK != 0 )==""\n"
32R"==(int tail = (IS_DW ? G : IC) - off; )==""\n"
33R"==(if (tail < IC_BLOCK) { )==""\n"
34R"==(const int sglid = get_sub_group_local_id(); )==""\n"
35R"==(return (sglid < tail) ? ptr[sglid] : 0.0f; )==""\n"
36R"==(} )==""\n"
37R"==(#endif )==""\n"
38R"==(return as_float(intel_sub_group_block_read((const __global uint *)ptr)); )==""\n"
39R"==(} )==""\n"
// read_oc_block: same scheme as read_ic_block, but the channel count is
// G if IS_DW, else OC_WO_PADDING, and the block size is OC_BLOCK.
40R"==(inline float read_oc_block(const __global float *ptr, int off) { )==""\n"
41R"==(#if (IS_DW ? G : OC_WO_PADDING) % OC_BLOCK != 0 )==""\n"
42R"==(int tail = (IS_DW ? G : OC_WO_PADDING) - off; )==""\n"
43R"==(if (tail < OC_BLOCK) { )==""\n"
44R"==(const int sglid = get_sub_group_local_id(); )==""\n"
45R"==(return (sglid < tail) ? ptr[sglid] : 0.0f; )==""\n"
46R"==(} )==""\n"
47R"==(#endif )==""\n"
48R"==(return as_float(intel_sub_group_block_read((const __global uint *)ptr)); )==""\n"
49R"==(} )==""\n"
// Kernel entry point. Reads src and diff_dst, accumulates products of their
// values in private accumulators, then adds the accumulators into diff_wei
// with atomic_add_global. When WITH_BIAS == 1 it also sums diff_dst values and
// adds that sum into diff_bias the same way. The atomic adds imply several
// work-items can target the same output element - presumably the chunks
// selected from global id 2; verify against the host-side dispatch.
50R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) )==""\n"
51R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
52R"==(__kernel void )==""\n"
53R"==(gen9_conv_nhwc_bwd_weights(__global float *src, )==""\n"
54R"==(volatile __global atomic_float *diff_wei, )==""\n"
55R"==(volatile __global atomic_float *diff_bias, __global float *diff_dst) { )==""\n"
56R"==(MAYBE_SKIP_NON_UNIFORM_WG(); )==""\n"
// Global id 1 is split into the kernel offsets kd, kh and kw (kd is 0 unless
// CASE_3D).
57R"==(const int ksp = get_global_id(1); )==""\n"
58R"==(#if CASE_3D )==""\n"
59R"==(const int kd = ksp / (KW * KH); )==""\n"
60R"==(const int khw = ksp % (KW * KH); )==""\n"
61R"==(#else )==""\n"
62R"==(const int khw = ksp; )==""\n"
63R"==(const int kd = 0; )==""\n"
64R"==(#endif )==""\n"
65R"==(const int kh = khw / KW; )==""\n"
66R"==(const int kw = khw % KW; )==""\n"
67R"==(const int sglid = get_sub_group_local_id(); )==""\n"
// Global id 2 is split into a work chunk index and the icb / ocb channel
// block indices.
68R"==(const int chunk = get_global_id(2) % NCHUNK; )==""\n"
69R"==(const int icb_ocb = get_global_id(2) / NCHUNK; )==""\n"
70R"==(const int icb = icb_ocb % DIV_UP(IC, ICB); )==""\n"
71R"==(const int ocb = icb_ocb / DIV_UP(IC, ICB); )==""\n"
72R"==(const int ic_padded = RND_UP(IC, IC_BLOCK); )==""\n"
73R"==(const int oc_padded = RND_UP(OC, OC_BLOCK); )==""\n"
// Channel block selection. The IS_DW path derives oc from the work-group and
// sub-group ids and uses ic = oc with g = 0. The other path splits global id 0
// into g, oc and ic; ic is forced to 0 when IC == 3.
74R"==(#if IS_DW )==""\n"
75R"==(const int g = 0; )==""\n"
76R"==(const int oc )==""\n"
77R"==(= get_group_id(0) * (LWS_0 / SUB_GROUP_SIZE) + get_sub_group_id(); )==""\n"
78R"==(const int ic = oc; )==""\n"
79R"==(#else )==""\n"
80R"==(const int g_ic_oc = get_global_id(0); )==""\n"
81R"==(const int g = g_ic_oc / (oc_padded * DIV_UP(IC, IC_BLOCK)); )==""\n"
82R"==(const int io = g_ic_oc % (oc_padded * DIV_UP(IC, IC_BLOCK)); )==""\n"
83R"==(const int oc = (io % OCB) / OC_BLOCK + ocb * (OCB / OC_BLOCK); )==""\n"
84R"==(const int ic = (IC == 3) ? 0 : (io / OCB + icb * (ICB / IC_BLOCK)); )==""\n"
85R"==(#endif )==""\n"
// The chunk index is split into a spatial tile origin (od_beg, oh_beg, ow_beg)
// and a minibatch range [mb, mb_end).
86R"==(const int sp_chunk = chunk % OSP_CHUNK; )==""\n"
87R"==(const int mb_chunk = chunk / OSP_CHUNK; )==""\n"
88R"==(const int ow_nb = (OW + OWB - 1) / OWB; )==""\n"
89R"==(const int oh_nb = (OH + OHB - 1) / OHB; )==""\n"
90R"==(const int od_beg = ((sp_chunk / ow_nb) / oh_nb) * ODB; )==""\n"
91R"==(const int oh_beg = ((sp_chunk / ow_nb) % oh_nb) * OHB; )==""\n"
92R"==(const int ow_beg = (sp_chunk % ow_nb) * OWB; )==""\n"
93R"==(const int mb = mb_chunk * MB_CHUNK_SIZE; )==""\n"
94R"==(const int mb_end = min((mb_chunk + 1) * MB_CHUNK_SIZE, MB); )==""\n"
// do_bias selects the work-items that write the bias result: kernel offset
// (0, 0, 0) and, unless IS_DW, ic == 0.
95R"==(const bool do_bias = (ic == 0 || IS_DW) && kh == 0 && kw == 0 && kd == 0; )==""\n"
// Move the base pointers to the first minibatch image, group and channel
// block handled by this work-item.
96R"==(src += mb * ID * IH * IW * G * IC; )==""\n"
97R"==(src += g * IC + ic * IC_BLOCK; )==""\n"
98R"==(diff_dst += g * OC_WO_PADDING + oc * OC_BLOCK; )==""\n"
99R"==(#if WITH_BIAS == 1 )==""\n"
100R"==(diff_bias += g * OC_WO_PADDING + oc * OC_BLOCK + sglid; )==""\n"
101R"==(float bias_loc = 0.0f; )==""\n"
102R"==(#endif )==""\n"
// Private accumulators: one float8 when IC == 3, one scalar for IS_DW,
// otherwise two float8 values.
103R"==(#if IC == 3 )==""\n"
104R"==(float8 blockC00 = 0.0f; )==""\n"
105R"==(#elif IS_DW )==""\n"
106R"==(float blockC00 = 0.0f; )==""\n"
107R"==(#else )==""\n"
108R"==(float8 blockC00 = 0.0f; )==""\n"
109R"==(float8 blockC01 = 0.0f; )==""\n"
110R"==(#endif )==""\n"
// Main accumulation loop over minibatch, then od and oh within the tile.
111R"==(for (int omb = mb; omb < mb_end; omb++) { )==""\n"
112R"==(const __global float *diff_dst1_ )==""\n"
113R"==(= diff_dst + omb * OD * OH * OW * G * OC_WO_PADDING; )==""\n"
114R"==(for (int od = od_beg; od < min(od_beg + ODB, OD); od++) )==""\n"
115R"==(for (int oh = oh_beg; oh < min(oh_beg + OHB, OH); oh++) { )==""\n"
116R"==(const __global float *diff_dst1 = diff_dst1_ )==""\n"
117R"==(+ (od * OH * OW + oh * OW) * G * OC_WO_PADDING; )==""\n"
// Rows (and depth slices when CASE_3D) whose source coordinate falls outside
// [0, IH) or [0, ID) add nothing to the weight accumulators. For those only
// the bias sum is updated (when do_bias) before continuing with the next row.
118R"==(if (oh * SH + kh * (1 + DH) < PH )==""\n"
119R"==(|| oh * SH + kh * (1 + DH) >= IH + PH )==""\n"
120R"==(#if CASE_3D )==""\n"
121R"==(|| od * SD + kd * (1 + DD) < PD )==""\n"
122R"==(|| od * SD + kd * (1 + DD) >= ID + PD )==""\n"
123R"==(#endif )==""\n"
124R"==() { )==""\n"
125R"==(#if WITH_BIAS == 1 )==""\n"
126R"==(if (do_bias) { )==""\n"
127R"==(for (int ow = ow_beg; ow < ow_beg + OWB; )==""\n"
128R"==(ow += OW_BLOCK) { )==""\n"
129R"==(float8 blockB; )==""\n"
130R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
131R"==(if (ow + i >= OW) { )==""\n"
132R"==(blockB[i] = 0.0; )==""\n"
133R"==(} else { )==""\n"
134R"==(blockB[i] = read_oc_block( )==""\n"
135R"==(&diff_dst1[(ow + i) * G )==""\n"
136R"==(* OC_WO_PADDING], )==""\n"
137R"==(oc * OC_BLOCK); )==""\n"
138R"==(} )==""\n"
139R"==(} )==""\n"
140R"==(for (int i = 0; i < OW_BLOCK; i++) )==""\n"
141R"==(bias_loc += blockB[i]; )==""\n"
142R"==(} )==""\n"
143R"==(} )==""\n"
144R"==(#endif )==""\n"
145R"==(continue; )==""\n"
146R"==(} )==""\n"
// Walk the tile row in steps of OW_BLOCK output columns; id, ih and iw are the
// source coordinates that pair with output column ow for this kernel offset.
147R"==(for (int ow = ow_beg; ow < ow_beg + OWB; ow += OW_BLOCK) { )==""\n"
148R"==(const int id = od * SD - PD + kd * (1 + DD); )==""\n"
149R"==(const int ih = oh * SH - PH + kh * (1 + DH); )==""\n"
150R"==(const int iw = ow * SW - PW + kw * (1 + DW); )==""\n"
151R"==(__global float *src1 )==""\n"
152R"==(= src + (id * IH * IW + ih * IW + iw) * G * IC; )==""\n"
// TRANSPOSE_8 builds a float8 from eight sub-group shuffles of component _row
// of _block, taken from sub-group ids _col through _col + 7. FMA8 is a float8
// fused multiply-add. MULTIPLY_BLOCKS_8x8 applies FMA8 once per component of
// _blockB, accumulating into _result.
153R"==(#define TRANSPOSE_8(_block, _row, _col) \ )==""\n"
154R"==({ \ )==""\n"
155R"==((float8)(intel_sub_group_shuffle(_block[_row], 0 + _col), \ )==""\n"
156R"==(intel_sub_group_shuffle(_block[_row], 1 + _col), \ )==""\n"
157R"==(intel_sub_group_shuffle(_block[_row], 2 + _col), \ )==""\n"
158R"==(intel_sub_group_shuffle(_block[_row], 3 + _col), \ )==""\n"
159R"==(intel_sub_group_shuffle(_block[_row], 4 + _col), \ )==""\n"
160R"==(intel_sub_group_shuffle(_block[_row], 5 + _col), \ )==""\n"
161R"==(intel_sub_group_shuffle(_block[_row], 6 + _col), \ )==""\n"
162R"==(intel_sub_group_shuffle(_block[_row], 7 + _col)) \ )==""\n"
163R"==(} )==""\n"
164R"==(#define FMA8(a, b, c) fma((float8)(a), (float8)b, (float8)c) )==""\n"
165R"==(#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, col) \ )==""\n"
166R"==({ \ )==""\n"
167R"==(_result = FMA8(_blockB.s0, TRANSPOSE_8(_blockA, 0, col), _result); \ )==""\n"
168R"==(_result = FMA8(_blockB.s1, TRANSPOSE_8(_blockA, 1, col), _result); \ )==""\n"
169R"==(_result = FMA8(_blockB.s2, TRANSPOSE_8(_blockA, 2, col), _result); \ )==""\n"
170R"==(_result = FMA8(_blockB.s3, TRANSPOSE_8(_blockA, 3, col), _result); \ )==""\n"
171R"==(_result = FMA8(_blockB.s4, TRANSPOSE_8(_blockA, 4, col), _result); \ )==""\n"
172R"==(_result = FMA8(_blockB.s5, TRANSPOSE_8(_blockA, 5, col), _result); \ )==""\n"
173R"==(_result = FMA8(_blockB.s6, TRANSPOSE_8(_blockA, 6, col), _result); \ )==""\n"
174R"==(_result = FMA8(_blockB.s7, TRANSPOSE_8(_blockA, 7, col), _result); \ )==""\n"
175R"==(} )==""\n"
176R"==(float8 blockA, blockB; )==""\n"
// Load blockA from src. Source columns outside [0, IW) yield 0. In the
// IC == 3 path only work-items with sub-group local id below IC load data and
// all others hold zeros; the general path goes through read_ic_block.
177R"==(#if IC == 3 )==""\n"
178R"==(if (sglid < IC) { )==""\n"
179R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
180R"==(if (iw + i * SW < 0 || iw + i * SW >= IW) { )==""\n"
181R"==(blockA[i] = 0; )==""\n"
182R"==(} else { )==""\n"
183R"==(blockA[i] = src1[i * SW * G * IC + sglid]; )==""\n"
184R"==(} )==""\n"
185R"==(} )==""\n"
186R"==(} else { )==""\n"
187R"==(blockA = 0.0f; )==""\n"
188R"==(} )==""\n"
189R"==(#else )==""\n"
190R"==(__attribute__((opencl_unroll_hint(8))) )==""\n"
191R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
192R"==(if (iw + i * SW < 0 || iw + i * SW >= IW) { )==""\n"
193R"==(blockA[i] = 0; )==""\n"
194R"==(} else { )==""\n"
195R"==(blockA[i] = read_ic_block( )==""\n"
196R"==(&src1[i * SW * G * IC], ic * IC_BLOCK); )==""\n"
197R"==(} )==""\n"
198R"==(} )==""\n"
199R"==(#endif )==""\n"
// Load blockB from diff_dst; output columns at or beyond OW yield 0.
200R"==(__attribute__((opencl_unroll_hint(8))) )==""\n"
201R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
202R"==(if (ow + i >= OW) { )==""\n"
203R"==(blockB[i] = 0.0; )==""\n"
204R"==(} else { )==""\n"
205R"==(blockB[i] = read_oc_block( )==""\n"
206R"==(&diff_dst1[(ow + i) * G * OC_WO_PADDING], )==""\n"
207R"==(oc * OC_BLOCK); )==""\n"
208R"==(} )==""\n"
209R"==(} )==""\n"
// Accumulate the products. The IS_DW path is an element-wise fma over the
// OW_BLOCK columns; the other paths use the 8x8 block multiply, with a second
// call at column offset 8 in the general case.
210R"==(#if IC == 3 )==""\n"
211R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0); )==""\n"
212R"==(#elif IS_DW )==""\n"
213R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
214R"==(blockC00 = fma(blockA[i], blockB[i], blockC00); )==""\n"
215R"==(} )==""\n"
216R"==(#else )==""\n"
217R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0); )==""\n"
218R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockA, blockB, 8); )==""\n"
219R"==(#endif )==""\n"
// NOTE(review): this bias sum runs over a fixed 8 components, while the
// padding-row path earlier sums OW_BLOCK components and blockB is filled only
// up to OW_BLOCK. The two agree only if OW_BLOCK is 8 - confirm.
220R"==(#if WITH_BIAS == 1 )==""\n"
221R"==(for (int i = 0; i < 8; i++) )==""\n"
222R"==(bias_loc += blockB[i]; )==""\n"
223R"==(#endif )==""\n"
224R"==(} )==""\n"
225R"==(} )==""\n"
// Step src forward by one minibatch image.
226R"==(src += ID * IH * IW * G * IC; )==""\n"
227R"==(} )==""\n"
// Publish the bias sum; work-items whose channel index is at or beyond the
// channel count (G if IS_DW, else OC_WO_PADDING) are skipped.
228R"==(#if WITH_BIAS == 1 )==""\n"
229R"==(if (do_bias && oc * OC_BLOCK + sglid < (IS_DW ? G : OC_WO_PADDING)) )==""\n"
230R"==(atomic_add_global(diff_bias, bias_loc); )==""\n"
231R"==(#endif )==""\n"
// Offset diff_wei to the block owned by this work-item and add the
// accumulators atomically. The IC == 3 path writes only the first three
// components of blockC00; the general path writes all 16 accumulated values.
232R"==(#if IC == 3 )==""\n"
233R"==(diff_wei += g * oc_padded * ic_padded * KD * KH * KW; )==""\n"
234R"==(diff_wei += oc * KD * KH * KW * ic_padded * OC_BLOCK; )==""\n"
235R"==(diff_wei += (kd * KH * KW + kh * KW + kw) * ic_padded * OC_BLOCK; )==""\n"
236R"==(for (int i = 0; i < 3; i++) )==""\n"
237R"==(atomic_add_global(diff_wei + i * OC_BLOCK + sglid, blockC00[i]); )==""\n"
238R"==(#elif IS_DW )==""\n"
239R"==(diff_wei += oc * KD * KH * KW * OC_BLOCK; )==""\n"
240R"==(diff_wei += (kd * KH * KW + kh * KW + kw) * OC_BLOCK; )==""\n"
241R"==(atomic_add_global(diff_wei + sglid, blockC00); )==""\n"
242R"==(#else )==""\n"
243R"==(diff_wei += g * ic_padded * oc_padded * KD * KH * KW; )==""\n"
244R"==(diff_wei += ic * oc_padded * KD * KH * KW * IC_BLOCK; )==""\n"
245R"==(diff_wei += oc * KD * KH * KW * IC_BLOCK * OC_BLOCK; )==""\n"
246R"==(diff_wei += (kd * KH * KW + kh * KW + kw) * IC_BLOCK * OC_BLOCK; )==""\n"
247R"==(for (int i = 0; i < 8; i++) )==""\n"
248R"==(atomic_add_global(diff_wei + i * OC_BLOCK + sglid, blockC00[i]); )==""\n"
249R"==(for (int i = 0; i < 8; i++) )==""\n"
250R"==(atomic_add_global(diff_wei + (8 + i) * OC_BLOCK + sglid, blockC01[i]); )==""\n"
251R"==(#endif )==""\n"
252R"==(} )==""\n"
253R"==(#endif )==""\n"
254R"==()==";
255}
256}
257}
258}