1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// Source text of the gen9_conv_nhwc_fwd kernel, embedded as a single C string.
// The string is assembled from adjacent raw string literals, one per line of
// kernel source, each followed by an explicit newline literal. Comments placed
// between those literals (like this one) are not part of the string.
// The embedded text uses __kernel and __global qualifiers, so it is presumably
// compiled at run time as OpenCL C -- TODO confirm against the consumer of
// this symbol.
5const char *gen9_conv_nhwc_fwd_data_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2020-2021 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
// Project headers pulled into the kernel source. Names used below that are not
// defined in this file (DATA_T, BLOCK_READ, AS_DATA_T, unroll_for,
// APPLY_POST_OPS_SERIAL_BINARY_2D and so on) presumably come from these
// headers or from build-time definitions -- TODO confirm.
20R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// _BLOCK_READ8 / 4 / 2 / (no suffix): cast ptr to a const __global
// BLOCK_DATA_T pointer, call the matching BLOCK_READ macro and convert the
// result back with the matching AS_DATA macro.
22R"==(#define _BLOCK_READ8(ptr) \ )==""\n"
23R"==(AS_DATA8_T(BLOCK_READ8((const __global BLOCK_DATA_T *)(ptr))) )==""\n"
24R"==(#define _BLOCK_READ4(ptr) \ )==""\n"
25R"==(AS_DATA4_T(BLOCK_READ4((const __global BLOCK_DATA_T *)(ptr))) )==""\n"
26R"==(#define _BLOCK_READ2(ptr) \ )==""\n"
27R"==(AS_DATA2_T(BLOCK_READ2((const __global BLOCK_DATA_T *)(ptr))) )==""\n"
28R"==(#define _BLOCK_READ(ptr) \ )==""\n"
29R"==(AS_DATA_T(BLOCK_READ((const __global BLOCK_DATA_T *)(ptr))) )==""\n"
// _BLOCK_WRITE8 / 4 / 2 / (no suffix): cast ptr to a __global BLOCK_DATA_T
// pointer, convert v with the matching AS_BLOCK_DATA macro and call the
// matching BLOCK_WRITE macro.
30R"==(#define _BLOCK_WRITE8(ptr, v) \ )==""\n"
31R"==(BLOCK_WRITE8((__global BLOCK_DATA_T *)(ptr), AS_BLOCK_DATA8_T(v)) )==""\n"
32R"==(#define _BLOCK_WRITE4(ptr, v) \ )==""\n"
33R"==(BLOCK_WRITE4((__global BLOCK_DATA_T *)(ptr), AS_BLOCK_DATA4_T(v)) )==""\n"
34R"==(#define _BLOCK_WRITE2(ptr, v) \ )==""\n"
35R"==(BLOCK_WRITE2((__global BLOCK_DATA_T *)(ptr), AS_BLOCK_DATA2_T(v)) )==""\n"
36R"==(#define _BLOCK_WRITE(ptr, v) \ )==""\n"
37R"==(BLOCK_WRITE((__global BLOCK_DATA_T *)(ptr), AS_BLOCK_DATA_T(v)) )==""\n"
// ENABLE_KW_BUF: for kernel widths of 5 or more the kernel preloads a row of
// input columns into a private buffer (tempA) once per kh instead of reading
// the source separately for every kw.
38R"==(#define ENABLE_KW_BUF (KW >= 5) )==""\n"
// IS_3D: more than one output depth slice.
39R"==(#define IS_3D (OD > 1) )==""\n"
// KDHW_SIZE: number of kernel taps (depth * height * width).
40R"==(#define KDHW_SIZE (KD * KH * KW) )==""\n"
// HAS_PAD_x (x = D, H, W): true when Px > 0, or when
// Ox * Sx - Px + (Kx - 1) * (1 + Dx) >= Ix. The kernel uses HAS_PAD_D and
// HAS_PAD_H to compile in bounds checks; HAS_PAD_W is not referenced in the
// kernel text visible in this file.
41R"==(#define HAS_PAD_D (PD > 0 || OD * SD - PD + (KD - 1) * (1 + DD) >= ID) )==""\n"
42R"==(#define HAS_PAD_H (PH > 0 || OH * SH - PH + (KH - 1) * (1 + DH) >= IH) )==""\n"
43R"==(#define HAS_PAD_W (PW > 0 || OW * SW - PW + (KW - 1) * (1 + DW) >= IW) )==""\n"
// OC_PAD_BLOCK: OC rounded up to the next multiple of OC_BLOCK.
44R"==(#define OC_PAD_BLOCK (OC % OC_BLOCK ? (OC / OC_BLOCK + 1) * OC_BLOCK : OC) )==""\n"
// BLOCK_READ_BOUND / BLOCK_WRITE_BOUND: the helper functions below only use
// block reads or writes when the unpadded channel count is a multiple of
// these values. NOTE(review): the values look like 4-byte read and 16-byte
// write alignment expressed in elements -- confirm.
// Any data type other than F32 or F16 is rejected at kernel compile time.
45R"==(#if DT_F32 )==""\n"
46R"==(#define BLOCK_READ_BOUND 1 )==""\n"
47R"==(#define BLOCK_WRITE_BOUND 4 )==""\n"
48R"==(#elif DT_F16 )==""\n"
49R"==(#define BLOCK_READ_BOUND 2 )==""\n"
50R"==(#define BLOCK_WRITE_BOUND 8 )==""\n"
51R"==(#else )==""\n"
52R"==(#error "Wrong Data Type" )==""\n"
53R"==(#endif )==""\n"
// read_ic_block: returns the calling work-item's element of one input-channel
// block.
//   ptr - start of the block in global memory
//   off - channel offset of the block; only used to compute the tail length
// The channel count is G_WO_PADDING when IS_DW is set and IC_WO_PADDING
// otherwise.
54R"==(inline DATA_T read_ic_block(const __global DATA_T *ptr, int off) { )==""\n"
55R"==(const int local_id = get_local_id(0); )==""\n"
// IC == 3: work-items with local id below IC load *ptr, all others get 0.
// ptr is not indexed by local_id here; the kernel adds local_id to the
// pointer itself before calling when IC == 3.
56R"==(#if IC == 3 )==""\n"
57R"==(return (local_id < IC) ? *ptr : 0; )==""\n"
58R"==(#else )==""\n"
// Tail handling, compiled in only when the channel count is not a multiple of
// IC_BLOCK: if fewer than IC_BLOCK channels remain from off, work-items at or
// beyond the tail return 0 and the rest do a plain per-element load.
59R"==(#if (IS_DW ? G_WO_PADDING : IC_WO_PADDING) % IC_BLOCK != 0 )==""\n"
60R"==(int tail = (IS_DW ? G_WO_PADDING : IC_WO_PADDING) - off; )==""\n"
61R"==(if (tail < IC_BLOCK) { return (local_id < tail) ? ptr[local_id] : 0; } )==""\n"
62R"==(#endif )==""\n"
// Full block: use a per-element load when the channel count is not a multiple
// of BLOCK_READ_BOUND, otherwise use the block read macro.
63R"==(#if (IS_DW ? G_WO_PADDING : IC_WO_PADDING) % BLOCK_READ_BOUND != 0 )==""\n"
64R"==(return ptr[local_id]; )==""\n"
65R"==(#else )==""\n"
66R"==(return _BLOCK_READ(ptr); )==""\n"
67R"==(#endif )==""\n"
68R"==(#endif )==""\n"
69R"==(} )==""\n"
// read_oc_block: returns the calling work-item's element of one output-channel
// block.
//   ptr - start of the block in global memory
//   off - channel offset of the block; only used to compute the tail length
// The channel count is G_WO_PADDING when IS_DW is set and OC_WO_PADDING
// otherwise.
70R"==(inline DATA_T read_oc_block(const __global DATA_T *ptr, int off) { )==""\n"
71R"==(const int local_id = get_local_id(0); )==""\n"
// Tail handling, compiled in only when the channel count is not a multiple of
// OC_BLOCK: if fewer than OC_BLOCK channels remain from off, work-items at or
// beyond the tail return 0 and the rest do a plain per-element load.
72R"==(#if (IS_DW ? G_WO_PADDING : OC_WO_PADDING) % OC_BLOCK != 0 )==""\n"
73R"==(int tail = (IS_DW ? G_WO_PADDING : OC_WO_PADDING) - off; )==""\n"
74R"==(if (tail < OC_BLOCK) { return (local_id < tail) ? ptr[local_id] : 0; } )==""\n"
75R"==(#endif )==""\n"
// Full block: use a per-element load when the channel count is not a multiple
// of BLOCK_READ_BOUND, otherwise use the block read macro.
76R"==(#if (IS_DW ? G_WO_PADDING : OC_WO_PADDING) % BLOCK_READ_BOUND != 0 )==""\n"
77R"==(return ptr[local_id]; )==""\n"
78R"==(#else )==""\n"
79R"==(return _BLOCK_READ(ptr); )==""\n"
80R"==(#endif )==""\n"
81R"==(} )==""\n"
// write_oc_block: stores the calling work-item's value into one output-channel
// block.
//   ptr   - start of the block in global memory
//   off   - channel offset of the block; only used to compute the tail length
//   value - element to store for this work-item
// The channel count is G_WO_PADDING when IS_DW is set and OC_WO_PADDING
// otherwise.
82R"==(inline void write_oc_block(__global DATA_T *ptr, int off, DATA_T value) { )==""\n"
83R"==(const int local_id = get_local_id(0); )==""\n"
// Tail handling, compiled in only when the channel count is not a multiple of
// OC_BLOCK: if fewer than OC_BLOCK channels remain from off, only work-items
// below the tail store their value and nothing is written past the tail.
84R"==(#if (IS_DW ? G_WO_PADDING : OC_WO_PADDING) % OC_BLOCK != 0 )==""\n"
85R"==(int tail = (IS_DW ? G_WO_PADDING : OC_WO_PADDING) - off; )==""\n"
86R"==(if (tail < OC_BLOCK) { )==""\n"
87R"==(if (local_id < tail) ptr[local_id] = value; )==""\n"
88R"==(return; )==""\n"
89R"==(} )==""\n"
90R"==(#endif )==""\n"
// Full block: use a per-element store when the channel count is not a multiple
// of BLOCK_WRITE_BOUND, otherwise use the block write macro.
91R"==(#if (IS_DW ? G_WO_PADDING : OC_WO_PADDING) % BLOCK_WRITE_BOUND != 0 )==""\n"
92R"==(ptr[local_id] = value; )==""\n"
93R"==(return; )==""\n"
94R"==(#else )==""\n"
95R"==(return _BLOCK_WRITE(ptr, value); )==""\n"
96R"==(#endif )==""\n"
97R"==(} )==""\n"
// multiply_blocks_8x8_ic3: accumulates three fused multiply-adds into *res.
//   res    - accumulator, updated in place
//   blockA - this work-item's input value; the values held by sub-group
//            members 0, 1 and 2 are fetched with intel_sub_group_shuffle
//   blockB - three weight values, one per shuffled input value
// Result: *res += blockB[k] * shuffle(blockA, k) for k = 0..2, applied in
// that order through fma.
98R"==(void multiply_blocks_8x8_ic3(DATA_T *res, DATA_T blockA, const DATA_T *blockB) { )==""\n"
99R"==(*res = fma(blockB[0], intel_sub_group_shuffle(blockA, 0), *res); )==""\n"
100R"==(*res = fma(blockB[1], intel_sub_group_shuffle(blockA, 1), *res); )==""\n"
101R"==(*res = fma(blockB[2], intel_sub_group_shuffle(blockA, 2), *res); )==""\n"
102R"==(} )==""\n"
// multiply_blocks_8x8: accumulates sixteen fused multiply-adds into *res.
//   res      - accumulator, updated in place
//   blockA   - this work-item's input value; values held by other sub-group
//              members are fetched with intel_sub_group_shuffle
//   blockB0  - eight weights paired with shuffle indices 0..7
//   blockB1  - eight weights paired with shuffle indices 8..15
// NOTE(review): shuffle indices reach 15, so this presumably requires a
// sub-group of at least 16 work-items -- confirm against SUB_GROUP_SIZE.
103R"==(void multiply_blocks_8x8( )==""\n"
104R"==(DATA_T *res, DATA_T blockA, DATA8_T blockB0, DATA8_T blockB1) { )==""\n"
// First half: weights blockB0[0..7] against sub-group members 0..7.
105R"==(for (int i = 0; i < 8; i++) { )==""\n"
106R"==(*res = fma(blockB0[i], intel_sub_group_shuffle(blockA, i), *res); )==""\n"
107R"==(} )==""\n"
// Second half: weights blockB1[0..7] against sub-group members 8..15.
108R"==(for (int i = 0; i < 8; i++) { )==""\n"
109R"==(*res = fma(blockB1[i], intel_sub_group_shuffle(blockA, 8 + i), *res); )==""\n"
110R"==(} )==""\n"
111R"==(} )==""\n"
// gen9_conv_nhwc_fwd: forward convolution kernel.
//   src  - input, read-only; indexed below with channels as the innermost
//          dimension (every spatial position spans G * IC_WO_PADDING elements)
//   wei  - weights, read-only
//   bias - per-output-channel bias, read only when WITH_BIAS is set
//   dst  - output; also read back when WITH_SUM is set
//   POST_OP_ARGS - extra arguments supplied by a macro defined elsewhere
// Each work-item accumulates OW_BLOCK output columns for one output channel of
// an OC_BLOCK-wide channel block, then applies post-ops and stores the result.
// The required work-group and sub-group sizes are fixed by the attributes.
112R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) )==""\n"
113R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) __kernel void )==""\n"
114R"==(gen9_conv_nhwc_fwd(const __global DATA_T *src, const __global DATA_T *wei, )==""\n"
115R"==(const __global DATA_T *bias, __global DATA_T *dst POST_OP_ARGS) { )==""\n"
// Macro defined outside this file; by its name it presumably exits early for
// work-groups outside the uniform range -- TODO confirm.
116R"==(MAYBE_SKIP_NON_UNIFORM_WG(); )==""\n"
// Work decomposition: group id 1 is a flat spatial index, group id 2 packs an
// output-channel chunk index (ocb) and the minibatch index (mb).
117R"==(const int sp = get_group_id(1); )==""\n"
118R"==(const int local_id = get_sub_group_local_id(); )==""\n"
119R"==(const int ocb_mb = get_group_id(2); )==""\n"
120R"==(const int ocb = ocb_mb / (MB); )==""\n"
121R"==(const int mb = ocb_mb % (MB); )==""\n"
// oc is the output-channel block index; g is the group and goc the channel
// block index within the group. For depthwise (IS_DW) g is fixed to 0 and
// goc equals oc.
122R"==(#if IS_DW )==""\n"
123R"==(const int oc = get_group_id(0); )==""\n"
124R"==(const int g = 0; )==""\n"
125R"==(const int goc = oc; )==""\n"
126R"==(#else )==""\n"
127R"==(const int oc = (ocb * OCB) / OC_BLOCK + get_group_id(0); )==""\n"
128R"==(const int g = oc / (OC_PAD_BLOCK / OC_BLOCK); )==""\n"
129R"==(const int goc = oc % (OC_PAD_BLOCK / OC_BLOCK); )==""\n"
130R"==(#endif )==""\n"
// Split the flat spatial index into output depth, row and column. oh and ow
// are the first row and column of this work-group's tile; id is the first
// input depth index and may be negative when PD > 0.
131R"==(const int od = IS_3D ? sp / (OWB * OHB) : 0; )==""\n"
132R"==(const int ohw = IS_3D ? sp % (OWB * OHB) : sp; )==""\n"
133R"==(const int id = IS_3D ? od * SD - PD : 0; )==""\n"
134R"==(const int oh = (ohw / OWB) * OH_BLOCK; )==""\n"
135R"==(const int ow = (ohw % OWB) * OW_BLOCK; )==""\n"
// Accumulators, one per output column, zero-initialized. With bias enabled
// every accumulator starts from the bias of this work-item's channel, or from
// DATA_ZERO when the channel lies beyond the unpadded channel count.
136R"==(DATA_T blockC00[OW_BLOCK] = {0}; )==""\n"
137R"==(if (WITH_BIAS) { )==""\n"
138R"==(const int bc_off = oc * OC_BLOCK + local_id; )==""\n"
139R"==(#if IS_DW )==""\n"
140R"==(DATA_T b = (G_WO_PADDING % OC_BLOCK == 0 || bc_off < G_WO_PADDING) )==""\n"
141R"==(#else )==""\n"
142R"==(DATA_T b = (OC_WO_PADDING % OC_BLOCK == 0 || bc_off < OC_WO_PADDING) )==""\n"
143R"==(#endif )==""\n"
144R"==(? bias[bc_off] )==""\n"
145R"==(: DATA_ZERO; )==""\n"
146R"==(unroll_for(int i = 0; i < OW_BLOCK; i++) { blockC00[i] = b; } )==""\n"
147R"==(} )==""\n"
// First input row and column of the receptive field; both may be negative
// when there is padding.
148R"==(int ih = oh * SH - PH; )==""\n"
149R"==(int iw = ow * SW - PW; )==""\n"
// Advance src to (mb, id, ih, iw, g), plus the channel block offset for
// depthwise. With negative id, ih or iw this points before the valid data;
// the loads below are guarded by explicit range checks.
150R"==(src += mb * ID * IH * IW * G * IC_WO_PADDING; )==""\n"
151R"==(src += (id * IH * IW + ih * IW + iw) * G * IC_WO_PADDING; )==""\n"
152R"==(src += g * IC_WO_PADDING; )==""\n"
153R"==(src += (IS_DW ? oc * OC_BLOCK : 0); )==""\n"
// Advance wei to the weights of channel block goc in group g.
154R"==(wei += goc * KDHW_SIZE * OC_BLOCK * IC + g * IC * OC_PAD_BLOCK * KDHW_SIZE; )==""\n"
// With a single kernel tap in depth and height and padding in either, one
// up-front test decides whether the whole tile row is outside the input.
// Only the non-depthwise loop bound below consults this flag.
155R"==(#if (KD == 1 && KH == 1) && (HAS_PAD_D || HAS_PAD_H) )==""\n"
156R"==(const bool dh_out_of_range = (id < 0 || id >= ID || ih < 0 || ih >= IH); )==""\n"
157R"==(#else )==""\n"
158R"==(const bool dh_out_of_range = false; )==""\n"
159R"==(#endif )==""\n"
// Input-channel loop range. Depthwise covers the OC_BLOCK channels starting
// at goc * OC_BLOCK. Otherwise the range is 0..IC, collapsed to an upper bound
// of 1 when IC == 3 and to an empty range when dh_out_of_range is set.
160R"==(#if IS_DW )==""\n"
161R"==(const int icb_min = goc * OC_BLOCK; )==""\n"
162R"==(const int icb_max = icb_min + OC_BLOCK; )==""\n"
163R"==(#else )==""\n"
164R"==(const int icb_min = 0; )==""\n"
165R"==(const int icb_max = dh_out_of_range ? 0 : (IC == 3 ? 1 : IC); )==""\n"
166R"==(#endif )==""\n"
// Main reduction: input-channel blocks, then kernel depth, height and width.
167R"==(for (int icb = icb_min; icb < icb_max; icb += IC_BLOCK) { )==""\n"
168R"==(__attribute__((opencl_unroll_hint(1))) )==""\n"
169R"==(for (int kd = 0; kd < KD; ++kd) { )==""\n"
// Skip kernel depth taps that fall outside the input.
170R"==(#if HAS_PAD_D )==""\n"
171R"==(if (id + kd * (1 + DD) < 0 || id + kd * (1 + DD) >= ID) continue; )==""\n"
172R"==(#endif )==""\n"
173R"==(__attribute__((opencl_unroll_hint(1))) )==""\n"
174R"==(for (int kh = 0; kh < KH; ++kh) { )==""\n"
// Skip kernel height taps that fall outside the input.
175R"==(#if HAS_PAD_H )==""\n"
176R"==(if (ih + kh * (1 + DH) < 0 || ih + kh * (1 + DH) >= IH) )==""\n"
177R"==(continue; )==""\n"
178R"==(#endif )==""\n"
// src1 points at the input row selected by the dilated (kd, kh) tap. For
// IC == 3 every work-item gets its own element offset because read_ic_block
// dereferences the pointer without adding local_id in that case.
179R"==(const __global DATA_T *src1 = src )==""\n"
180R"==(+ kd * (1 + DD) * IH * IW * G * IC_WO_PADDING )==""\n"
181R"==(+ kh * (1 + DH) * IW * G * IC_WO_PADDING; )==""\n"
182R"==(if (IC == 3) src1 += local_id; )==""\n"
// Wide-kernel path: preload SW * OW_BLOCK + KW * (1 + DW) input columns into
// tempA once per row. Columns outside 0..IW-1 keep their zero initializer.
183R"==(#if ENABLE_KW_BUF )==""\n"
184R"==(DATA_T tempA[SW * OW_BLOCK + KW * (1 + DW)] = {0}; )==""\n"
185R"==(__attribute__((opencl_unroll_hint( )==""\n"
186R"==(SW * OW_BLOCK + KW * (1 + DW)))) )==""\n"
187R"==(for (int i = 0; i < SW * OW_BLOCK + KW * (1 + DW); i++) { )==""\n"
188R"==(if ((i + iw) >= 0 && (i + iw) < IW) { )==""\n"
189R"==(tempA[i] = read_ic_block( )==""\n"
190R"==(&src1[i * G * IC_WO_PADDING], icb); )==""\n"
191R"==(} )==""\n"
192R"==(} )==""\n"
193R"==(#endif )==""\n"
194R"==(__attribute__((opencl_unroll_hint(KW))) )==""\n"
195R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
// wei1 points at the weights of tap (kd, kh, kw). The per-tap stride differs
// by variant: IC * OC_BLOCK for IC == 3, OC_BLOCK for depthwise and
// IC_BLOCK * OC_BLOCK otherwise.
196R"==(#if IC == 3 )==""\n"
197R"==(const __global DATA_T *wei1 = wei )==""\n"
198R"==(+ (kd * KH * KW + kh * KW + kw) * IC * OC_BLOCK; )==""\n"
199R"==(#elif IS_DW )==""\n"
200R"==(const __global DATA_T *wei1 )==""\n"
201R"==(= wei + (kd * KH * KW + kh * KW + kw) * OC_BLOCK; )==""\n"
202R"==(#else )==""\n"
203R"==(const __global DATA_T *wei1 = wei )==""\n"
204R"==(+ (kd * KH * KW + kh * KW + kw) * IC_BLOCK )==""\n"
205R"==(* OC_BLOCK; )==""\n"
206R"==(#endif )==""\n"
// blockA holds the input value for each of the OW_BLOCK output columns at
// this tap. It is taken from tempA on the wide-kernel path, otherwise loaded
// directly with a column range check; out-of-range columns stay 0.
207R"==(DATA_T blockA[OW_BLOCK] = {0}; )==""\n"
208R"==(#if ENABLE_KW_BUF )==""\n"
209R"==(__attribute__(( )==""\n"
210R"==(opencl_unroll_hint(OW_BLOCK))) )==""\n"
211R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
212R"==(blockA[i] = tempA[i * SW + kw * (1 + DW)]; )==""\n"
213R"==(} )==""\n"
214R"==(#else )==""\n"
215R"==(__attribute__(( )==""\n"
216R"==(opencl_unroll_hint(OW_BLOCK))) )==""\n"
217R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
218R"==(int iw_off = i * SW + kw * (1 + DW); )==""\n"
219R"==(if (iw + iw_off >= 0 && iw + iw_off < IW) { )==""\n"
220R"==(blockA[i] = read_ic_block( )==""\n"
221R"==(&src1[iw_off * G * IC_WO_PADDING], icb); )==""\n"
222R"==(} )==""\n"
223R"==(} )==""\n"
224R"==(#endif )==""\n"
// IC == 3 variant: one block-read weight per input channel, then a 3-term
// accumulation per output column.
225R"==(#if IC == 3 )==""\n"
226R"==(DATA_T blockB[IC]; )==""\n"
227R"==(__attribute__((opencl_unroll_hint(IC))) )==""\n"
228R"==(for (int i = 0; i < IC; i++) { )==""\n"
229R"==(blockB[i] = _BLOCK_READ(wei1 + i * OC_BLOCK); )==""\n"
230R"==(} )==""\n"
231R"==(__attribute__(( )==""\n"
232R"==(opencl_unroll_hint(OW_BLOCK))) )==""\n"
233R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
234R"==(multiply_blocks_8x8_ic3( )==""\n"
235R"==(&blockC00[i], blockA[i], blockB); )==""\n"
236R"==(} )==""\n"
// Depthwise variant: one weight per work-item and a single fma per output
// column, with no exchange of values between work-items.
237R"==(#elif IS_DW )==""\n"
238R"==(DATA_T blockB = _BLOCK_READ(wei1); )==""\n"
239R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
240R"==(blockC00[i] = fma(blockA[i], blockB, blockC00[i]); )==""\n"
241R"==(} )==""\n"
// General variant: two 8-wide weight reads, the second 8 * OC_BLOCK elements
// after the first, feeding a 16-term accumulation per output column.
// NOTE(review): this implies an input-channel block of 16 -- confirm IC_BLOCK.
242R"==(#else )==""\n"
243R"==(DATA8_T blockB00 = _BLOCK_READ8(wei1); )==""\n"
244R"==(DATA8_T blockB01 = _BLOCK_READ8(wei1 + 8 * OC_BLOCK); )==""\n"
245R"==(__attribute__(( )==""\n"
246R"==(opencl_unroll_hint(OW_BLOCK))) )==""\n"
247R"==(for (int i = 0; i < OW_BLOCK; i++) { )==""\n"
248R"==(multiply_blocks_8x8( )==""\n"
249R"==(&blockC00[i], blockA[i], blockB00, blockB01); )==""\n"
250R"==(} )==""\n"
251R"==(#endif )==""\n"
252R"==(} )==""\n"
253R"==(} )==""\n"
254R"==(} )==""\n"
// Step both pointers to the next input-channel block.
255R"==(src += IC_BLOCK; )==""\n"
256R"==(wei += KDHW_SIZE * IC_BLOCK * OC_BLOCK; )==""\n"
257R"==(} )==""\n"
// Output pointer for (mb, od, oh, ow, g, goc); every spatial position spans
// G * OC_WO_PADDING elements.
258R"==(__global DATA_T *dst_write0 = dst + mb * OD * OH * OW * G * OC_WO_PADDING; )==""\n"
259R"==(dst_write0 += (od * OH * OW + oh * OW + ow) * G * OC_WO_PADDING; )==""\n"
260R"==(dst_write0 += g * OC_WO_PADDING + goc * OC_BLOCK; )==""\n"
// blockS00 receives the existing destination values only when WITH_SUM is
// set, and only for columns inside OW. It is otherwise left uninitialized;
// presumably the post-op macro does not read it in that case -- TODO confirm.
261R"==(DATA_T blockS00[OW_BLOCK]; )==""\n"
262R"==(#if WITH_SUM )==""\n"
263R"==(for (int i = 0; i < min(OW_BLOCK, OW - ow); i++) { )==""\n"
264R"==(blockS00[i] = read_oc_block( )==""\n"
265R"==(&dst_write0[i * G * OC_WO_PADDING], goc * OC_BLOCK); )==""\n"
266R"==(} )==""\n"
267R"==(#endif )==""\n"
// Minibatch and channel coordinates handed to the post-op macro, which is
// defined outside this file.
268R"==(const int po_mb = (mb) % MB; )==""\n"
269R"==(const int po_oc = (oc * OC_BLOCK + local_id) % (OC * G); )==""\n"
270R"==(APPLY_POST_OPS_SERIAL_BINARY_2D( )==""\n"
271R"==(blockC00, DATA_T, blockS00, DATA_T, po_mb, 1, po_oc, 1); )==""\n"
// Store the tile, clipped so that columns at or beyond OW are not written.
272R"==(for (int i = 0; i < min(OW_BLOCK, OW - ow); i++) { )==""\n"
273R"==(write_oc_block(&dst_write0[i * G * OC_WO_PADDING], goc * OC_BLOCK, )==""\n"
274R"==(blockC00[i]); )==""\n"
275R"==(} )==""\n"
276R"==(} )==""\n"
// End of the embedded kernel text, followed by the closing braces of the four
// enclosing namespaces (ocl, gpu, impl, dnnl).
277R"==()==";
278}
279}
280}
281}