1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// OpenCL C source text of the `gen9_conv_dw_bwd_data` kernel, stored as a
// single string assembled from adjacent raw-string literals (one literal per
// kernel source line, each followed by "\n").
//
// The kernel body is emitted only when BWD_DATA == 1 and comes in two variants
// selected by the VER_16MB16C and VER_8OW16C macros. Every upper-case
// identifier (DATA_T, IC_BLOCK, KH, SW, BLOCK_READ, ...) is a macro that is
// not defined in this string; presumably they come from the build options and
// from "gpu/ocl/ocl_types.h" -- confirm against the host-side code.
//
// Kernel arguments: diff_src is written; wei, diff_dst and (when WITH_BIAS is
// non-zero) bias are read.
5const char *gen9_conv_dw_bwd_data_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2021 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
21R"==(#if BWD_DATA == 1 )==""\n"
22R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) )==""\n"
23R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
24R"==(__kernel void )==""\n"
25R"==(gen9_conv_dw_bwd_data(__global DATA_T *diff_src, __global DATA_T *wei, )==""\n"
26R"==(__global DATA_T *diff_dst, __global DATA_T *bias) { )==""\n"
27R"==(MAYBE_SKIP_NON_UNIFORM_WG(); )==""\n"
// ---- Variant 1 (VER_16MB16C) ------------------------------------------------
// Work decomposition as written below: group id 0 selects the input spatial
// point (id, ih, iw), group id 1 together with the sub-group id selects the
// channel block, and group id 2 selects a chunk of mb_unroll = 16 batch
// entries. Two DATA8_T accumulators hold that chunk (offsets 0 and 8).
28R"==(#if VER_16MB16C == 1 )==""\n"
29R"==(const int mb_unroll = 16; )==""\n"
30R"==(const int ic )==""\n"
31R"==(= get_group_id(1) * (LWS_1 / SUB_GROUP_SIZE) + get_sub_group_id(); )==""\n"
32R"==(const int sp = get_group_id(0); )==""\n"
33R"==(const int sglid = get_sub_group_local_id(); )==""\n"
34R"==(int mb = get_group_id(2) * mb_unroll; )==""\n"
35R"==(const int g = ic * IC_BLOCK; )==""\n"
36R"==(const int gic = 0; )==""\n"
37R"==(const int id = sp / (IW * IH); )==""\n"
38R"==(const int ihw = sp % (IW * IH); )==""\n"
39R"==(const int ih = ihw / IW; )==""\n"
40R"==(const int iw = ihw % IW; )==""\n"
41R"==(diff_dst += mb * OC * G * OD * OH * OW + g * OC * OD * OH * OW * MB_BLOCK; )==""\n"
42R"==(DATA8_T blockC00 = (DATA8_T)DATA_ZERO; )==""\n"
43R"==(DATA8_T blockC01 = (DATA8_T)DATA_ZERO; )==""\n"
// Optional bias: every accumulator lane starts from this work-item's bias
// value; the value is DATA_ZERO when bg_off lies at or beyond G_WO_PADDING.
44R"==(if (WITH_BIAS) { )==""\n"
45R"==(const int bg_off = g * IC + gic * IC_BLOCK + sglid; )==""\n"
46R"==(DATA_T b = (G_WO_PADDING % IC_BLOCK == 0 || bg_off < G_WO_PADDING) )==""\n"
47R"==(? bias[bg_off] )==""\n"
48R"==(: DATA_ZERO; )==""\n"
49R"==(unroll_for(int i = 0; i < 8; ++i) { )==""\n"
50R"==(blockC00[i] = b; )==""\n"
51R"==(blockC01[i] = b; )==""\n"
52R"==(} )==""\n"
53R"==(} )==""\n"
54R"==(wei += gic * KD * KH * KW * OC_BLOCK * IC_BLOCK )==""\n"
55R"==(+ g * IC * OC * KD * KH * KW; )==""\n"
// Filter with more than one tap: visit every (kd, kh, kw) tap and map the
// input point back to the output point reached through that tap. A tap is
// skipped when the mapped coordinate would be negative, is not a multiple of
// the stride, or is at/after the output extent.
56R"==(#if KH != 1 || KW != 1 || KD != 1 )==""\n"
57R"==(for (int kd = 0; kd < KD; ++kd) )==""\n"
58R"==(for (int kh = 0; kh < KH; ++kh) )==""\n"
59R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
60R"==(if (id + PD < kd * (1 + DD) || iw + PW < kw * (1 + DW) )==""\n"
61R"==(|| ih + PH < kh * (1 + DH)) )==""\n"
62R"==(continue; )==""\n"
63R"==(int od = id - kd * (1 + DD) + PD; )==""\n"
64R"==(int ow = iw - kw * (1 + DW) + PW; )==""\n"
65R"==(int oh = ih - kh * (1 + DH) + PH; )==""\n"
66R"==(if (od % SD != 0 || ow % SW != 0 || oh % SH != 0) continue; )==""\n"
67R"==(od /= SD; )==""\n"
68R"==(ow /= SW; )==""\n"
69R"==(oh /= SH; )==""\n"
70R"==(if (od >= OD || oh >= OH || ow >= OW) continue; )==""\n"
71R"==(const __global DATA_T *diff_dst1 = diff_dst )==""\n"
72R"==(+ ow * OC_BLOCK * MB_BLOCK )==""\n"
73R"==(+ oh * OW * OC_BLOCK * MB_BLOCK; )==""\n"
74R"==(diff_dst1 += od * OH * OW * OC_BLOCK * MB_BLOCK; )==""\n"
75R"==(const __global DATA_T *wei1 = wei + kd * KH * KW * OC_BLOCK )==""\n"
76R"==(+ kh * KW * OC_BLOCK + kw * OC_BLOCK; )==""\n"
// Single-tap (1x1x1) filter: no loop; do_ker records whether the one mapped
// output point is on the stride grid. The upper-bound test is compiled in only
// when some padding is non-zero.
77R"==(#else )==""\n"
78R"==(const int kw = 0; )==""\n"
79R"==(int ow = (iw + PW); )==""\n"
80R"==(int oh = (ih + PH); )==""\n"
81R"==(int od = (id + PD); )==""\n"
82R"==(bool do_ker = ow % SW == 0 && oh % SH == 0 && od % SD == 0; )==""\n"
83R"==(ow /= SW; )==""\n"
84R"==(oh /= SH; )==""\n"
85R"==(od /= SD; )==""\n"
86R"==(#if PH != 0 || PW != 0 || PD != 0 )==""\n"
87R"==(do_ker = do_ker && (od < OD && oh < OH && ow < OW); )==""\n"
88R"==(#endif )==""\n"
89R"==(if (do_ker) { )==""\n"
90R"==(const __global DATA_T *diff_dst1 = diff_dst + ow * OC_BLOCK * MB_BLOCK )==""\n"
91R"==(+ oh * OW * OC_BLOCK * MB_BLOCK; )==""\n"
92R"==(diff_dst1 += od * OH * OW * OC_BLOCK * MB_BLOCK; )==""\n"
93R"==(const __global DATA_T *wei1 = wei; )==""\n"
94R"==(#endif )==""\n"
// Helper macros of the kernel text:
//   LOAD_DIFF_DST reads a DATA8_T via BLOCK_READ8 at _diff_dst + mb_chunk*OC_BLOCK.
//   SAVE_SRC_DIFF writes a DATA8_T via BLOCK_WRITE8 at &_diff_src[mb_chunk*IC_BLOCK].
// The write destination is cast to a non-const pointer, matching the
// BLOCK_WRITE call in the VER_8OW16C variant.
95R"==(#define LOAD_DIFF_DST(_block, _diff_dst, mb_chunk) \ )==""\n"
96R"==({ \ )==""\n"
97R"==((_block) = AS_DATA8_T( \ )==""\n"
98R"==(BLOCK_READ8((const __global BLOCK_DATA_T *)((_diff_dst) \ )==""\n"
99R"==(+ (mb_chunk)*OC_BLOCK))); \ )==""\n"
100R"==(} )==""\n"
101R"==(#define SAVE_SRC_DIFF(_block, _diff_src, mb_chunk) \ )==""\n"
102R"==({ \ )==""\n"
103R"==(BLOCK_WRITE8((__global BLOCK_DATA_T *)(&( \ )==""\n"
104R"==(_diff_src)[(mb_chunk)*IC_BLOCK]), \ )==""\n"
105R"==(AS_BLOCK_DATA8_T((_block))); \ )==""\n"
106R"==(} )==""\n"
// Accumulate diff_dst * weight for this tap into both accumulators.
107R"==(DATA8_T blockA0, blockA1; )==""\n"
108R"==(LOAD_DIFF_DST(blockA0, diff_dst1, 0); )==""\n"
109R"==(LOAD_DIFF_DST(blockA1, diff_dst1, 8); )==""\n"
110R"==(DATA_T blockB00 = AS_DATA_T( )==""\n"
111R"==(BLOCK_READ((const __global BLOCK_DATA_T *)wei1)); )==""\n"
112R"==(blockC00 = fma(blockA0, (DATA8_T)blockB00, blockC00); )==""\n"
113R"==(blockC01 = fma(blockA1, (DATA8_T)blockB00, blockC01); )==""\n"
// Closes either the kw loop body or the `if (do_ker)` block opened above.
114R"==(#if KH != 1 || KW != 1 || KD != 1 )==""\n"
115R"==(} )==""\n"
116R"==(#else )==""\n"
117R"==(} )==""\n"
118R"==(#endif )==""\n"
// Store both accumulators to diff_src at this (mb, g, id, ih, iw) position.
119R"==(__global DATA_T *src_write0 = diff_src + mb * IC * G * ID * IH * IW )==""\n"
120R"==(+ gic * ID * IH * IW * IC_BLOCK * MB_BLOCK )==""\n"
121R"==(+ g * IC * ID * IH * IW * MB_BLOCK )==""\n"
122R"==(+ id * IH * IW * IC_BLOCK * MB_BLOCK + ih * IW * IC_BLOCK * MB_BLOCK )==""\n"
123R"==(+ iw * IC_BLOCK * MB_BLOCK; )==""\n"
124R"==(SAVE_SRC_DIFF(blockC00, src_write0, 0); )==""\n"
125R"==(SAVE_SRC_DIFF(blockC01, src_write0, 8); )==""\n"
126R"==(#endif )==""\n"
// ---- Variant 2 (VER_8OW16C) -------------------------------------------------
// Group id 0 selects (id, ih) and a block of IW_BLOCK consecutive input-width
// positions starting at iw; group id 2 selects a single batch entry. One
// DATA_T accumulator is kept per width position.
127R"==(#if VER_8OW16C == 1 )==""\n"
128R"==(const int ic )==""\n"
129R"==(= get_group_id(1) * (LWS_1 / SUB_GROUP_SIZE) + get_sub_group_id(); )==""\n"
130R"==(const int sp = get_group_id(0); )==""\n"
131R"==(const int sglid = get_sub_group_local_id(); )==""\n"
132R"==(const int mb = get_group_id(2); )==""\n"
133R"==(const int g = ic * IC_BLOCK; )==""\n"
134R"==(const int gic = 0; )==""\n"
135R"==(const int id = sp / (IWB * IH); )==""\n"
136R"==(const int ihw = sp % (IWB * IH); )==""\n"
137R"==(const int ih = ihw / IWB; )==""\n"
138R"==(const int iw = (ihw % IWB) * IW_BLOCK; )==""\n"
139R"==(diff_dst += mb * OC * G * OD * OH * OW + g * OC * OD * OH * OW * MB_BLOCK; )==""\n"
140R"==(DATA_T blockC00[IW_BLOCK] = {DATA_ZERO}; )==""\n"
141R"==(if (WITH_BIAS) { )==""\n"
142R"==(const int bg_off = g * IC + gic * IC_BLOCK + sglid; )==""\n"
143R"==(DATA_T b = (G_WO_PADDING % IC_BLOCK == 0 || bg_off < G_WO_PADDING) )==""\n"
144R"==(? bias[bg_off] )==""\n"
145R"==(: DATA_ZERO; )==""\n"
146R"==(unroll_for(int i = 0; i < IW_BLOCK; ++i) { blockC00[i] = b; } )==""\n"
147R"==(} )==""\n"
148R"==(wei += gic * KD * KH * KW * OC_BLOCK * IC_BLOCK )==""\n"
149R"==(+ g * IC * OC * KD * KH * KW; )==""\n"
// Depth and height are resolved once per tap here; the width coordinate is
// resolved per element inside the IW_BLOCK loop further down.
150R"==(#if KH != 1 || KW != 1 || KD != 1 )==""\n"
151R"==(for (int kd = 0; kd < KD; ++kd) )==""\n"
152R"==(for (int kh = 0; kh < KH; ++kh) )==""\n"
153R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
154R"==(if (id + PD < kd * (1 + DD)) continue; )==""\n"
155R"==(if (ih + PH < kh * (1 + DH)) continue; )==""\n"
156R"==(int od = id - kd * (1 + DD) + PD; )==""\n"
157R"==(int oh = ih - kh * (1 + DH) + PH; )==""\n"
158R"==(if (od % SD != 0 || oh % SH != 0) continue; )==""\n"
159R"==(od /= SD; )==""\n"
160R"==(oh /= SH; )==""\n"
161R"==(if (od >= OD || oh >= OH) continue; )==""\n"
162R"==(const __global DATA_T *diff_dst1 = diff_dst )==""\n"
163R"==(+ oh * OW * OC_BLOCK * MB_BLOCK )==""\n"
164R"==(+ od * OH * OW * OC_BLOCK * MB_BLOCK; )==""\n"
165R"==(const __global DATA_T *wei1 = wei + kd * KH * KW * OC_BLOCK )==""\n"
166R"==(+ kh * KW * OC_BLOCK + kw * OC_BLOCK; )==""\n"
167R"==(#else )==""\n"
168R"==(const int kw = 0; )==""\n"
169R"==(int oh = (ih + PH); )==""\n"
170R"==(int od = (id + PD); )==""\n"
171R"==(bool do_ker = od % SD == 0 && oh % SH == 0; )==""\n"
172R"==(oh /= SH; )==""\n"
173R"==(od /= SD; )==""\n"
174R"==(#if PH != 0 || PW != 0 || PD != 0 )==""\n"
175R"==(do_ker = do_ker && (oh < OH && od < OD); )==""\n"
176R"==(#endif )==""\n"
177R"==(if (do_ker) { )==""\n"
178R"==(const __global DATA_T *diff_dst1 = diff_dst )==""\n"
179R"==(+ oh * OW * OC_BLOCK * MB_BLOCK )==""\n"
180R"==(+ od * OH * OW * OC_BLOCK * MB_BLOCK; )==""\n"
181R"==(const __global DATA_T *wei1 = wei; )==""\n"
182R"==(#endif )==""\n"
183R"==(DATA_T blockB00 = AS_DATA_T( )==""\n"
184R"==(BLOCK_READ((const __global BLOCK_DATA_T *)wei1)); )==""\n"
// Gather one diff_dst value per width position. A position contributes
// DATA_ZERO for this tap when its mapped ow would be negative, is not a
// multiple of SW, or is at/after OW.
185R"==(DATA_T blockA[IW_BLOCK]; )==""\n"
186R"==(__attribute__((opencl_unroll_hint(IW_BLOCK))) )==""\n"
187R"==(for (int i = 0; i < IW_BLOCK; i++) { )==""\n"
188R"==(if (iw + i + PW < kw * (1 + DW)) { )==""\n"
189R"==(blockA[i] = DATA_ZERO; )==""\n"
190R"==(continue; )==""\n"
191R"==(} )==""\n"
192R"==(int ow = iw + i - kw * (1 + DW) + PW; )==""\n"
193R"==(if (ow % SW != 0) { )==""\n"
194R"==(blockA[i] = DATA_ZERO; )==""\n"
195R"==(continue; )==""\n"
196R"==(} )==""\n"
197R"==(ow /= SW; )==""\n"
198R"==(if (ow >= OW) { )==""\n"
199R"==(blockA[i] = DATA_ZERO; )==""\n"
200R"==(continue; )==""\n"
201R"==(} )==""\n"
202R"==(blockA[i] = AS_DATA_T( )==""\n"
203R"==(BLOCK_READ((const __global BLOCK_DATA_T *)(&( )==""\n"
204R"==(diff_dst1)[ow * OC_BLOCK]))); )==""\n"
205R"==(} )==""\n"
206R"==(__attribute__((opencl_unroll_hint(IW_BLOCK))) )==""\n"
207R"==(for (int i = 0; i < IW_BLOCK; i++) { )==""\n"
208R"==(blockC00[i] = fma(blockA[i], (DATA_T)blockB00, blockC00[i]); )==""\n"
209R"==(} )==""\n"
// Closes either the kw loop body or the `if (do_ker)` block opened above.
212R"==(#if KH != 1 || KW != 1 || KD != 1 )==""\n"
213R"==(} )==""\n"
214R"==(#else )==""\n"
215R"==(} )==""\n"
216R"==(#endif )==""\n"
// Store one accumulator per width position; positions at or beyond IW are
// skipped.
217R"==(__global DATA_T *src_write0 = diff_src + mb * IC * G * ID * IH * IW )==""\n"
218R"==(+ gic * ID * IH * IW * IC_BLOCK * MB_BLOCK )==""\n"
219R"==(+ g * IC * ID * IH * IW * MB_BLOCK )==""\n"
220R"==(+ id * IH * IW * IC_BLOCK * MB_BLOCK + ih * IW * IC_BLOCK * MB_BLOCK )==""\n"
221R"==(+ iw * IC_BLOCK * MB_BLOCK; )==""\n"
222R"==(for (int i = 0; i < IW_BLOCK; i++) { )==""\n"
223R"==(if (iw + i >= IW) continue; )==""\n"
224R"==(BLOCK_WRITE((__global BLOCK_DATA_T *)(&(src_write0)[i * IC_BLOCK]), )==""\n"
225R"==(AS_BLOCK_DATA_T(blockC00[i])); )==""\n"
226R"==(} )==""\n"
227R"==(#endif )==""\n"
228R"==(} )==""\n"
229R"==(#endif )==""\n"
230R"==()==";
231}
232}
233}
234}