1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// OpenCL C source text of the gen9_conv_nhwc_bwd_data kernel, stored as a
// sequence of adjacent raw string literals (one per kernel source line, each
// followed by an explicit newline literal) that the compiler concatenates
// into a single C string.
5const char *gen9_conv_nhwc_bwd_data_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2020-2021 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
// Kernel prologue. The included header is not visible here; presumably it
// supplies DATA_T, DATA_ZERO, BLOCK_READ and the other type macros used below
// (TODO confirm). CASE_3D is derived from the build-time ID define: 1 when
// ID > 1, otherwise 0.
20R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
21R"==(#if ID > 1 )==""\n"
22R"==(#define CASE_3D 1 )==""\n"
23R"==(#else )==""\n"
24R"==(#define CASE_3D 0 )==""\n"
25R"==(#endif )==""\n"
// TRANSPOSE_1(_block, _col): result of intel_sub_group_shuffle(_block, _col),
// cast to DATA_T.
26R"==(#define TRANSPOSE_1(_block, _col) \ )==""\n"
27R"==((DATA_T)(intel_sub_group_shuffle(_block, _col)) )==""\n"
// FMA1(a, b, c): fma(a, b, c) with every operand cast to DATA_T.
28R"==(#define FMA1(a, b, c) fma((DATA_T)(a), (DATA_T)b, (DATA_T)c) )==""\n"
// MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, _blockB1): adds to _result
// sixteen products. Components s0..s7 of _blockB pair with _blockA shuffled
// from columns 0..7, and components s0..s7 of _blockB1 pair with _blockA
// shuffled from columns 8..15. Despite the 8x8 in the name, sixteen terms are
// accumulated.
// NOTE(review): shuffling from columns up to 15 presumably requires a
// sub-group of at least 16 lanes; confirm against SUB_GROUP_SIZE.
29R"==(#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, _blockB1) \ )==""\n"
30R"==({ \ )==""\n"
31R"==(_result = FMA1(_blockB.s0, TRANSPOSE_1(_blockA, 0), _result); \ )==""\n"
32R"==(_result = FMA1(_blockB.s1, TRANSPOSE_1(_blockA, 1), _result); \ )==""\n"
33R"==(_result = FMA1(_blockB.s2, TRANSPOSE_1(_blockA, 2), _result); \ )==""\n"
34R"==(_result = FMA1(_blockB.s3, TRANSPOSE_1(_blockA, 3), _result); \ )==""\n"
35R"==(_result = FMA1(_blockB.s4, TRANSPOSE_1(_blockA, 4), _result); \ )==""\n"
36R"==(_result = FMA1(_blockB.s5, TRANSPOSE_1(_blockA, 5), _result); \ )==""\n"
37R"==(_result = FMA1(_blockB.s6, TRANSPOSE_1(_blockA, 6), _result); \ )==""\n"
38R"==(_result = FMA1(_blockB.s7, TRANSPOSE_1(_blockA, 7), _result); \ )==""\n"
39R"==(_result = FMA1(_blockB1.s0, TRANSPOSE_1(_blockA, 8), _result); \ )==""\n"
40R"==(_result = FMA1(_blockB1.s1, TRANSPOSE_1(_blockA, 9), _result); \ )==""\n"
41R"==(_result = FMA1(_blockB1.s2, TRANSPOSE_1(_blockA, 10), _result); \ )==""\n"
42R"==(_result = FMA1(_blockB1.s3, TRANSPOSE_1(_blockA, 11), _result); \ )==""\n"
43R"==(_result = FMA1(_blockB1.s4, TRANSPOSE_1(_blockA, 12), _result); \ )==""\n"
44R"==(_result = FMA1(_blockB1.s5, TRANSPOSE_1(_blockA, 13), _result); \ )==""\n"
45R"==(_result = FMA1(_blockB1.s6, TRANSPOSE_1(_blockA, 14), _result); \ )==""\n"
46R"==(_result = FMA1(_blockB1.s7, TRANSPOSE_1(_blockA, 15), _result); \ )==""\n"
47R"==(} )==""\n"
// read_oc_block(ptr, off): returns one DATA_T per sub-group lane, read from
// ptr. The off argument is the channel offset of this block and is only used
// for the tail check.
48R"==(inline DATA_T read_oc_block(const __global DATA_T *ptr, int off) { )==""\n"
49R"==(const int local_id = get_sub_group_local_id(); )==""\n"
// Tail handling, compiled in only when OC_WO_PADDING is not a multiple of
// OC_BLOCK: if fewer than OC_BLOCK channels remain after off, lanes below the
// tail read ptr[lane] and the remaining lanes get DATA_ZERO.
50R"==(#if OC_WO_PADDING % OC_BLOCK != 0 )==""\n"
51R"==(int tail = OC_WO_PADDING - off; )==""\n"
52R"==(if (tail < OC_BLOCK) { )==""\n"
53R"==(return (local_id < tail) ? AS_DATA_T(ptr[local_id]) : DATA_ZERO; )==""\n"
54R"==(} )==""\n"
55R"==(#endif )==""\n"
// Full block: per-lane scalar load when OC_WO_PADDING * sizeof(DATA_T) is not
// a multiple of 4, otherwise a sub-group BLOCK_READ.
// NOTE(review): the 4-byte test presumably mirrors the block-read alignment
// requirement; confirm against the sub-group extension spec.
56R"==(if ((OC_WO_PADDING * sizeof(DATA_T)) % 4 != 0) )==""\n"
57R"==(return ptr[local_id]; )==""\n"
58R"==(else )==""\n"
59R"==(return AS_DATA_T(BLOCK_READ((const __global BLOCK_DATA_T *)(ptr))); )==""\n"
60R"==(} )==""\n"
// write_ic_block(ptr, off, value): stores value from each sub-group lane to
// ptr. The off argument is the channel offset of this block and is only used
// for the tail check.
61R"==(inline void write_ic_block(__global DATA_T *ptr, int off, DATA_T value) { )==""\n"
62R"==(const int local_id = get_sub_group_local_id(); )==""\n"
// Tail handling, compiled in only when IC_WO_PADDING is not a multiple of
// IC_BLOCK: if fewer than IC_BLOCK channels remain after off, only lanes
// below the tail store to ptr[lane]; the other lanes write nothing.
63R"==(#if IC_WO_PADDING % IC_BLOCK != 0 )==""\n"
64R"==(int tail = IC_WO_PADDING - off; )==""\n"
65R"==(if (tail < IC_BLOCK) { )==""\n"
66R"==(if (local_id < tail) ptr[local_id] = value; )==""\n"
67R"==(return; )==""\n"
68R"==(} )==""\n"
69R"==(#endif )==""\n"
// Full block: per-lane scalar store when IC_WO_PADDING * sizeof(DATA_T) is
// not a multiple of 16, otherwise a sub-group BLOCK_WRITE.
// NOTE(review): the 16-byte test presumably mirrors the block-write alignment
// requirement; confirm against the sub-group extension spec.
70R"==(if ((IC_WO_PADDING * sizeof(DATA_T)) % 16 != 0) )==""\n"
71R"==(ptr[local_id] = value; )==""\n"
72R"==(else )==""\n"
73R"==(BLOCK_WRITE((__global BLOCK_DATA_T *)ptr, AS_BLOCK_DATA_T(value)); )==""\n"
74R"==(return; )==""\n"
75R"==(} )==""\n"
// Kernel entry point. Each work-item keeps IW_BLOCK accumulators (blockC00),
// one per consecutive input column starting at iw, for its lane of one
// IC_BLOCK. It reduces over kernel taps and OC blocks and finally stores the
// accumulators to diff_src through write_ic_block. All tensor offsets are
// computed with the channel dimension innermost (stride G * C_WO_PADDING
// between neighbouring columns).
// Zero fill of blockA uses DATA_ZERO, like the rest of the kernel, instead of
// a double-typed 0.0 literal.
76R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) )==""\n"
77R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
78R"==(__kernel void )==""\n"
79R"==(gen9_conv_nhwc_bwd_data(__global DATA_T *diff_src, __global DATA_T *wei, )==""\n"
80R"==(__global DATA_T *diff_dst, __global DATA_T *bias) { )==""\n"
81R"==(MAYBE_SKIP_NON_UNIFORM_WG(); )==""\n"
// Decode work-group ids: group id 1 is the spatial index, group id 2 packs
// the minibatch index with an ICB-sized channel chunk, and group id 0 selects
// the IC_BLOCK inside that chunk. ic counts IC blocks across all groups; g and
// gic split it into group and block-within-group.
82R"==(const int sp = get_group_id(1); )==""\n"
83R"==(const int local_id = get_sub_group_local_id(); )==""\n"
84R"==(const int icb_mb = get_group_id(2); )==""\n"
85R"==(const int mb = icb_mb / (G * IC_PADDED / ICB); )==""\n"
86R"==(const int icb = icb_mb % (G * IC_PADDED / ICB); )==""\n"
87R"==(const int ic = (icb * ICB) / IC_BLOCK + get_group_id(0); )==""\n"
88R"==(const int g = ic / (IC_PADDED / IC_BLOCK); )==""\n"
89R"==(const int gic = ic % (IC_PADDED / IC_BLOCK); )==""\n"
// Split the spatial index into depth (3D only), row, and first column of the
// IW_BLOCK-wide column block.
90R"==(#if CASE_3D )==""\n"
91R"==(const int id = sp / (IWB * IH); )==""\n"
92R"==(const int ihw = sp % (IWB * IH); )==""\n"
93R"==(#else )==""\n"
94R"==(const int id = 0; )==""\n"
95R"==(const int ihw = sp; )==""\n"
96R"==(#endif )==""\n"
97R"==(const int ih = ihw / IWB; )==""\n"
98R"==(const int iw = (ihw % IWB) * IW_BLOCK; )==""\n"
// Advance diff_dst to minibatch mb and group g.
99R"==(diff_dst += mb * OC_WO_PADDING * G * OD * OH * OW + g * OC_WO_PADDING; )==""\n"
// Accumulators start at zero, or at the bias value for this lane channel when
// WITH_BIAS is set (channels past IC_WO_PADDING get DATA_ZERO).
100R"==(DATA_T blockC00[IW_BLOCK] = {DATA_ZERO}; )==""\n"
101R"==(if (WITH_BIAS) { )==""\n"
102R"==(const int bg_off = g * IC; )==""\n"
103R"==(const int bc_off = gic * IC_BLOCK + local_id; )==""\n"
104R"==(DATA_T b = (IC_WO_PADDING % IC_BLOCK == 0 || bc_off < IC_WO_PADDING) )==""\n"
105R"==(? bias[bg_off + bc_off] )==""\n"
106R"==(: DATA_ZERO; )==""\n"
107R"==(unroll_for(int i = 0; i < IW_BLOCK; ++i) { blockC00[i] = b; } )==""\n"
108R"==(} )==""\n"
// Advance wei to this IC block and group.
109R"==(wei += gic * KD * KH * KW * OC_BLOCK * IC_BLOCK )==""\n"
110R"==(+ g * IC_PADDED * OC_PADDED * KD * KH * KW; )==""\n"
// General path (some kernel dimension is not 1): loop over all kernel taps
// and skip a tap when the matching output row or depth would be negative,
// is not a multiple of the stride, or lies past the output extent.
111R"==(#if KH != 1 || KW != 1 || KD != 1 )==""\n"
112R"==(for (int kd = 0; kd < KD; ++kd) )==""\n"
113R"==(for (int kh = 0; kh < KH; ++kh) )==""\n"
114R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
115R"==(if (ih + PH < kh * (1 + DH)) continue; )==""\n"
116R"==(#if CASE_3D )==""\n"
117R"==(if (id + PD < kd * (1 + DD)) continue; )==""\n"
118R"==(int od = id - kd * (1 + DD) + PD; )==""\n"
119R"==(if (od % SD != 0) continue; )==""\n"
120R"==(od /= SD; )==""\n"
121R"==(if (od >= OD) continue; )==""\n"
122R"==(#endif )==""\n"
123R"==(int oh = ih - kh * (1 + DH) + PH; )==""\n"
124R"==(if (oh % SH != 0) continue; )==""\n"
125R"==(oh /= SH; )==""\n"
126R"==(if (oh >= OH) continue; )==""\n"
127R"==(const __global DATA_T *diff_dst1 )==""\n"
128R"==(= diff_dst + oh * OW * G * OC_WO_PADDING; )==""\n"
129R"==(#if CASE_3D )==""\n"
130R"==(diff_dst1 += od * OH * OW * G * OC_WO_PADDING; )==""\n"
131R"==(#endif )==""\n"
132R"==(const __global DATA_T *wei1 = wei )==""\n"
133R"==(#if CASE_3D )==""\n"
134R"==(+ kd * KH * KW * OC_BLOCK * IC_BLOCK )==""\n"
135R"==(#endif )==""\n"
136R"==(+ kh * KW * OC_BLOCK * IC_BLOCK )==""\n"
137R"==(+ kw * OC_BLOCK * IC_BLOCK; )==""\n"
// 1x1x1 path: there is a single tap, so the skip conditions are folded into
// do_ker. The stride and bound checks are compiled in only when strides or
// padding make them necessary.
138R"==(#else )==""\n"
139R"==(int oh = (ih + PH); )==""\n"
140R"==(#if CASE_3D )==""\n"
141R"==(int od = (id + PD); )==""\n"
142R"==(#endif )==""\n"
143R"==(bool do_ker = true; )==""\n"
144R"==(#if SW != 1 || SH != 1 || SD != 1 )==""\n"
145R"==(do_ker = oh % SH == 0; )==""\n"
146R"==(oh /= SH; )==""\n"
147R"==(#if CASE_3D )==""\n"
148R"==(do_ker = do_ker && od % SD == 0; )==""\n"
149R"==(od /= SD; )==""\n"
150R"==(#endif )==""\n"
151R"==(#endif )==""\n"
152R"==(#if PH != 0 || PW != 0 || PD != 0 )==""\n"
153R"==(do_ker = do_ker && (oh < OH); )==""\n"
154R"==(#if CASE_3D )==""\n"
155R"==(do_ker = do_ker && (od < OD); )==""\n"
156R"==(#endif )==""\n"
157R"==(#endif )==""\n"
158R"==(#if SW != 1 || SH != 1 || SD != 1 || PH != 0 || PW != 0 || PD != 0 )==""\n"
159R"==(if (do_ker) { )==""\n"
160R"==(#endif )==""\n"
161R"==(const __global DATA_T *diff_dst1 )==""\n"
162R"==(= diff_dst + oh * OW * G * OC_WO_PADDING; )==""\n"
163R"==(#if CASE_3D )==""\n"
164R"==(diff_dst1 += od * OH * OW * G * OC_WO_PADDING; )==""\n"
165R"==(#endif )==""\n"
166R"==(const __global DATA_T *wei1 = wei; )==""\n"
167R"==(#endif )==""\n"
// Reduction over output channels, OC_BLOCK at a time. Each iteration loads
// two 8-wide weight vectors (the second one 8 * IC_BLOCK elements further).
168R"==(int ocb = 0; )==""\n"
169R"==(do { )==""\n"
170R"==(DATA8_T blockB00 = AS_DATA8_T( )==""\n"
171R"==(BLOCK_READ8((const __global BLOCK_DATA_T *)wei1)); )==""\n"
172R"==(DATA8_T blockB01 = AS_DATA8_T( )==""\n"
173R"==(BLOCK_READ8((const __global BLOCK_DATA_T *)(wei1 )==""\n"
174R"==(+ 8 * IC_BLOCK))); )==""\n"
// Gather the diff_dst value feeding each of the IW_BLOCK input columns.
// Columns with no contributing output position (negative, not aligned to the
// stride, or past OW) are filled with DATA_ZERO.
175R"==(DATA_T blockA[IW_BLOCK]; )==""\n"
176R"==(__attribute__(( )==""\n"
177R"==(opencl_unroll_hint(IW_BLOCK))) )==""\n"
178R"==(for (int i = 0; i < IW_BLOCK; i++) { )==""\n"
179R"==(#if KW != 1 )==""\n"
180R"==(if (iw + i + PW < kw * (1 + DW)) { )==""\n"
181R"==(blockA[i] = DATA_ZERO; )==""\n"
182R"==(continue; )==""\n"
183R"==(} )==""\n"
184R"==(int ow = iw + i - kw * (1 + DW) + PW; )==""\n"
185R"==(#else )==""\n"
186R"==(int ow = iw + i + PW; )==""\n"
187R"==(#endif )==""\n"
188R"==(#if SW != 1 )==""\n"
189R"==(if (ow % SW != 0) { )==""\n"
190R"==(blockA[i] = DATA_ZERO; )==""\n"
191R"==(continue; )==""\n"
192R"==(} )==""\n"
193R"==(ow /= SW; )==""\n"
194R"==(#endif )==""\n"
195R"==(if (ow >= OW) { )==""\n"
196R"==(blockA[i] = DATA_ZERO; )==""\n"
197R"==(continue; )==""\n"
198R"==(} )==""\n"
199R"==(blockA[i] = read_oc_block( )==""\n"
200R"==(&diff_dst1[ow * G * OC_WO_PADDING], ocb); )==""\n"
201R"==(} )==""\n"
// Multiply-accumulate the gathered values with the two weight vectors.
202R"==(__attribute__(( )==""\n"
203R"==(opencl_unroll_hint(IW_BLOCK))) )==""\n"
204R"==(for (int i = 0; i < IW_BLOCK; i++) { )==""\n"
205R"==(MULTIPLY_BLOCKS_8x8( )==""\n"
206R"==(blockC00[i], blockA[i], blockB00, blockB01); )==""\n"
207R"==(} )==""\n"
// Step to the next OC block in both diff_dst and wei.
208R"==(diff_dst1 += OC_BLOCK; )==""\n"
209R"==(wei1 += KD * KH * KW * OC_BLOCK * IC_PADDED; )==""\n"
210R"==(ocb += OC_BLOCK; )==""\n"
211R"==(} while (ocb < OC); )==""\n"
// Close whichever scope was opened above: the tap-loop body on the general
// path or the if (do_ker) block on the 1x1x1 path. Exactly one closing brace
// is emitted when a scope was opened, and none otherwise.
212R"==(#if SW != 1 || SH != 1 || SD != 1 || PH != 0 || PW != 0 || PD != 0 )==""\n"
213R"==(} )==""\n"
214R"==(#else )==""\n"
215R"==(#if KH != 1 || KW != 1 || KD != 1 )==""\n"
216R"==(} )==""\n"
217R"==(#endif )==""\n"
218R"==(#endif )==""\n"
// Store the accumulators to diff_src. The column loop is clipped at IW so a
// partial last column block does not write past the row.
219R"==(__global DATA_T *src_write0 = diff_src )==""\n"
220R"==(+ mb * IC_WO_PADDING * G * ID * IH * IW )==""\n"
221R"==(+ id * IH * IW * G * IC_WO_PADDING + ih * IW * G * IC_WO_PADDING )==""\n"
222R"==(+ iw * G * IC_WO_PADDING + g * IC_WO_PADDING + gic * IC_BLOCK; )==""\n"
223R"==(for (int i = 0; i < min(IW_BLOCK, IW - iw); i++) { )==""\n"
224R"==(write_ic_block(&src_write0[i * G * IC_WO_PADDING], gic * IC_BLOCK, )==""\n"
225R"==(blockC00[i]); )==""\n"
226R"==(} )==""\n"
227R"==(} )==""\n"
// End-of-kernel macro cleanup. The helper macro defined by this kernel source
// is TRANSPOSE_1 (no TRANSPOSE_BLOCK_8 is defined in it), so that is the name
// undefined here together with MULTIPLY_BLOCKS_8x8. The final empty raw
// string terminates the kernel source declaration.
228R"==(#undef TRANSPOSE_1 )==""\n"
229R"==(#undef MULTIPLY_BLOCKS_8x8 )==""\n"
230R"==()==";
231}
232}
233}
234}