namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source of the reference convolution kernels, embedded as a single
// string: every kernel source line is one raw-string piece followed by "\n".
//
// The source defines up to three kernels, each guarded by a preprocessor flag
// (presumably passed as OpenCL build options by the host code -- confirm):
//   IS_FWD   -> ref_convolution_fwd
//   IS_BWD_D -> ref_convolution_bwd_data
//   IS_BWD_W -> ref_convolution_bwd_weights
// Shape/stride/padding/dilation macros (IC, OC, KD, SD, PD, DD, ...), the
// GWS_GET_* work-item index getters, CONV_*_OFF offset macros and the
// *_TO_REF / TO_* conversion macros are not defined in this string; they are
// expected from the build options and the included ocl_types.h /
// ocl_post_ops.h headers.
const char *ref_convolution_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
R"==(#undef SRC_OFF )==""\n"
R"==(#undef WEI_OFF )==""\n"
R"==(#undef DST_OFF )==""\n"
R"==(#define SRC_OFF CONV_SRC_OFF )==""\n"
R"==(#define WEI_OFF CONV_WEI_OFF )==""\n"
R"==(#define DST_OFF CONV_DST_OFF )==""\n"
R"==(#if IS_FWD )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_convolution_fwd(const __global SRC_DATA_T *src, )==""\n"
R"==(const __global WEI_DATA_T *wei, const __global BIA_DATA_T *bias, )==""\n"
R"==(__global DST_DATA_T *dst POST_OP_ARGS, const __global float *src_scales, )==""\n"
R"==(const __global float *wei_scales, const __global float *dst_scales, )==""\n"
R"==(const __global int *src_zpoints, const __global int *dst_zpoints) { )==""\n"
R"==(src += SRC_OFFSET0; )==""\n"
R"==(dst += DST_OFFSET0; )==""\n"
R"==(const int n = GWS_GET_MB(); )==""\n"
R"==(const int oc = GWS_GET_OC(); )==""\n"
R"==(const int g = GWS_GET_G(); )==""\n"
R"==(const int od = GWS_GET_OD(); )==""\n"
R"==(const int oh = GWS_GET_OH(); )==""\n"
R"==(const int ow = GWS_GET_OW(); )==""\n"
R"==(ACC_DATA_T d = 0; )==""\n"
R"==(for (int ic = 0; ic < IC; ++ic) )==""\n"
R"==(for (int kd = 0; kd < KD; ++kd) )==""\n"
R"==(for (int kh = 0; kh < KH; ++kh) )==""\n"
R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
R"==(const int id = od * SD - PD + kd * (1 + DD); )==""\n"
R"==(const int ih = oh * SH - PH + kh * (1 + DH); )==""\n"
R"==(const int iw = ow * SW - PW + kw * (1 + DW); )==""\n"
R"==(if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 )==""\n"
R"==(|| iw >= IW) )==""\n"
R"==(continue; )==""\n"
R"==(const uint src_off = SRC_OFF(n, g * IC + ic, id, ih, iw); )==""\n"
R"==(const uint wei_off = WEI_OFF(g, oc, ic, kd, kh, kw); )==""\n"
R"==(d += SRC_TO_REF(src[src_off]) * WEI_TO_REF(wei[wei_off]); )==""\n"
R"==(#if WITH_SRC_ZPOINTS )==""\n"
R"==(const int src_zp )==""\n"
R"==(= src_zpoints[WITH_SRC_ZPOINTS_PER_IC ? g * IC + ic )==""\n"
R"==(: 0]; )==""\n"
R"==(d -= src_zp * WEI_TO_REF(wei[wei_off]); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(POST_OP_DATA_T tmp = d; )==""\n"
R"==(#if WITH_SRC_SCALES )==""\n"
R"==(tmp *= src_scales[0]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_WEI_SCALES )==""\n"
R"==(#if WEI_SCALES_MASK == 0 )==""\n"
R"==(tmp *= wei_scales[0]; )==""\n"
R"==(#else )==""\n"
R"==(tmp *= wei_scales[g * OC + oc]; )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_BIAS )==""\n"
R"==(tmp += (POST_OP_DATA_T)BIA_TO_REF(bias[g * OC + oc]); )==""\n"
R"==(#endif )==""\n"
R"==(POST_OP_DATA_T sum_src; )==""\n"
R"==(#if WITH_SUM )==""\n"
R"==(sum_src = (POST_OP_DATA_T)SUM_TO_REF( )==""\n"
R"==(AS_SUM_DATA_T(dst[DST_OFF(n, g * OC + oc, od, oh, ow)])); )==""\n"
R"==(#endif )==""\n"
R"==(#if NDIMS == 3 )==""\n"
R"==(const unsigned po_d2 = ow; )==""\n"
R"==(const unsigned po_d3 = 0; )==""\n"
R"==(const unsigned po_d4 = 0; )==""\n"
R"==(#elif NDIMS == 4 )==""\n"
R"==(const unsigned po_d2 = oh; )==""\n"
R"==(const unsigned po_d3 = ow; )==""\n"
R"==(const unsigned po_d4 = 0; )==""\n"
R"==(#elif NDIMS == 5 )==""\n"
R"==(const unsigned po_d2 = od; )==""\n"
R"==(const unsigned po_d3 = oh; )==""\n"
R"==(const unsigned po_d4 = ow; )==""\n"
R"==(#else )==""\n"
R"==(const unsigned po_d2 = 0; )==""\n"
R"==(const unsigned po_d3 = 0; )==""\n"
R"==(const unsigned po_d4 = 0; )==""\n"
R"==(#endif )==""\n"
R"==(APPLY_POST_OPS_SERIAL(tmp, POST_OP_DATA_T, sum_src, POST_OP_DATA_T, n, 1, )==""\n"
R"==(g * OC + oc, 1, po_d2, 1, po_d3, 1, po_d4, 1, 0, 1); )==""\n"
R"==(#if WITH_DST_SCALES )==""\n"
R"==(tmp /= dst_scales[0]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_DST_ZPOINTS )==""\n"
R"==(const int dst_zp = dst_zpoints[WITH_DST_ZPOINTS_PER_OC ? g * OC + oc : 0]; )==""\n"
R"==(tmp += dst_zp; )==""\n"
R"==(#endif )==""\n"
R"==(dst[DST_OFF(n, g * OC + oc, od, oh, ow)] = TO_DST(tmp); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(#if IS_BWD_D )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_convolution_bwd_data(__global SRC_DATA_T *diff_src, )==""\n"
R"==(const __global WEI_DATA_T *wei, const __global DST_DATA_T *diff_dst, )==""\n"
R"==(const __global BIA_DATA_T *bias POST_OP_ARGS, )==""\n"
R"==(const __global float *src_scales, const __global float *wei_scales, )==""\n"
R"==(const __global float *dst_scales, const __global int *src_zpoints, )==""\n"
R"==(const __global int *dst_zpoints) { )==""\n"
R"==(const int n = GWS_GET_MB(); )==""\n"
R"==(const int ic = GWS_GET_IC(); )==""\n"
R"==(const int g = GWS_GET_G(); )==""\n"
R"==(const int id = GWS_GET_ID(); )==""\n"
R"==(const int ih = GWS_GET_IH(); )==""\n"
R"==(const int iw = GWS_GET_IW(); )==""\n"
// The accumulator starts at zero; the bias is added further below, after
// the src/wei scales (see the WITH_BIAS block following the scales).
R"==(ACC_DATA_T d = 0; )==""\n"
R"==(for_(int oc = 0; oc < OC; ++oc) )==""\n"
R"==(for_(int kd = 0; kd < KD; ++kd) )==""\n"
R"==(for_(int kh = 0; kh < KH; ++kh) )==""\n"
R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
R"==(if (iw + PW < kw * (1 + DW) || ih + PH < kh * (1 + DH) )==""\n"
R"==(|| id + PD < kd * (1 + DD)) )==""\n"
R"==(continue; )==""\n"
R"==(int ow = iw - kw * (1 + DW) + PW; )==""\n"
R"==(int oh = ih - kh * (1 + DH) + PH; )==""\n"
R"==(int od = id - kd * (1 + DD) + PD; )==""\n"
R"==(if (ow % SW != 0 || oh % SH != 0 || od % SD != 0) continue; )==""\n"
R"==(ow /= SW; )==""\n"
R"==(oh /= SH; )==""\n"
R"==(od /= SD; )==""\n"
R"==(if (oh < OH && ow < OW && od < OD) { )==""\n"
R"==(const uint dst_off = DST_OFF(n, g * OC + oc, od, oh, ow); )==""\n"
R"==(const uint wei_off = WEI_OFF(g, oc, ic, kd, kh, kw); )==""\n"
R"==(d += DST_TO_REF(diff_dst[dst_off]) * WEI_TO_REF(wei[wei_off]); )==""\n"
R"==(#if WITH_SRC_ZPOINTS )==""\n"
R"==(const int src_zp )==""\n"
R"==(= src_zpoints[WITH_SRC_ZPOINTS_PER_IC ? g * OC + oc : 0]; )==""\n"
R"==(d -= src_zp * WEI_TO_REF(wei[wei_off]); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(float sum_src; )==""\n"
R"==(#if WITH_SUM )==""\n"
R"==(sum_src = convert_float( )==""\n"
R"==(SRC_TO_REF(diff_src[SRC_OFF(n, g * IC + ic, id, ih, iw)])); )==""\n"
R"==(#endif )==""\n"
R"==(float accumulator = convert_float(d); )==""\n"
R"==(#if WITH_SRC_SCALES )==""\n"
R"==(accumulator *= src_scales[0]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_WEI_SCALES )==""\n"
R"==(#if WEI_SCALES_MASK == 0 )==""\n"
R"==(accumulator *= wei_scales[0]; )==""\n"
R"==(#else )==""\n"
R"==(accumulator *= wei_scales[g * IC + ic]; )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
// Bias is applied after the src/wei scales, unscaled and in f32, exactly as
// the forward kernel does. Seeding the accumulator with it (the previous
// behavior) multiplied the bias by the scales and, when ACC_DATA_T is an
// integer type (presumably the quantized paths -- confirm), truncated it.
R"==(#if WITH_BIAS )==""\n"
R"==(accumulator += (float)BIA_TO_REF(bias[g * IC + ic]); )==""\n"
R"==(#endif )==""\n"
R"==(#if NDIMS == 3 )==""\n"
R"==(const unsigned po_d2 = iw; )==""\n"
R"==(const unsigned po_d3 = 0; )==""\n"
R"==(const unsigned po_d4 = 0; )==""\n"
R"==(#elif NDIMS == 4 )==""\n"
R"==(const unsigned po_d2 = ih; )==""\n"
R"==(const unsigned po_d3 = iw; )==""\n"
R"==(const unsigned po_d4 = 0; )==""\n"
R"==(#elif NDIMS == 5 )==""\n"
R"==(const unsigned po_d2 = id; )==""\n"
R"==(const unsigned po_d3 = ih; )==""\n"
R"==(const unsigned po_d4 = iw; )==""\n"
R"==(#else )==""\n"
R"==(const unsigned po_d2 = 0; )==""\n"
R"==(const unsigned po_d3 = 0; )==""\n"
R"==(const unsigned po_d4 = 0; )==""\n"
R"==(#endif )==""\n"
R"==(APPLY_POST_OPS_SERIAL(accumulator, float, sum_src, float, n, 1, g *IC + ic, )==""\n"
R"==(1, po_d2, 1, po_d3, 1, po_d4, 1, 0, 1); )==""\n"
R"==(#if WITH_DST_SCALES )==""\n"
R"==(accumulator /= dst_scales[0]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_DST_ZPOINTS )==""\n"
R"==(const int dst_zp = dst_zpoints[WITH_DST_ZPOINTS_PER_OC ? g * IC + ic : 0]; )==""\n"
R"==(accumulator += dst_zp; )==""\n"
R"==(#endif )==""\n"
R"==(diff_src[SRC_OFF(n, g * IC + ic, id, ih, iw)] = TO_SRC(accumulator); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(#if IS_BWD_W )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_convolution_bwd_weights(const __global SRC_DATA_T *src, )==""\n"
R"==(__global WEI_DATA_T *diff_wei, __global BIA_DATA_T *diff_bias, )==""\n"
R"==(const __global DST_DATA_T *diff_dst) { )==""\n"
R"==(const int g = GWS_GET_G(); )==""\n"
R"==(const int ic = GWS_GET_IC(); )==""\n"
R"==(const int oc = GWS_GET_OC(); )==""\n"
R"==(const int kd = GWS_GET_KD(); )==""\n"
R"==(const int kh = GWS_GET_KH(); )==""\n"
R"==(const int kw = GWS_GET_KW(); )==""\n"
R"==(#if WITH_BIAS )==""\n"
// Only the work-item at ic == 0 and kernel origin (kd == kh == kw == 0)
// reduces diff_dst into diff_bias for its (g, oc); logical && is intended
// here (the previous bitwise & only worked by precedence accident).
R"==(if (ic == 0 && kh == 0 && kw == 0 && kd == 0) { )==""\n"
R"==(ACC_DATA_T d = 0.0; )==""\n"
R"==(for (int n = 0; n < MB; ++n) )==""\n"
R"==(for (int od = 0; od < OD; ++od) )==""\n"
R"==(for (int oh = 0; oh < OH; ++oh) )==""\n"
R"==(for (int ow = 0; ow < OW; ++ow) { )==""\n"
R"==(d += DST_TO_REF( )==""\n"
R"==(diff_dst[DST_OFF(n, g * OC + oc, od, oh, ow)]); )==""\n"
R"==(} )==""\n"
R"==(diff_bias[g * OC + oc] = TO_BIA(d); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(ACC_DATA_T dw = 0.0; )==""\n"
R"==(for (int n = 0; n < MB; ++n) )==""\n"
R"==(for (int od = 0; od < OD; ++od) )==""\n"
R"==(for (int oh = 0; oh < OH; ++oh) )==""\n"
R"==(for (int ow = 0; ow < OW; ++ow) { )==""\n"
R"==(if (ow * SW + kw * (1 + DW) < PW )==""\n"
R"==(|| oh * SH + kh * (1 + DH) < PH )==""\n"
R"==(|| od * SD + kd * (1 + DD) < PD )==""\n"
R"==(|| ow * SW + kw * (1 + DW) >= IW + PW )==""\n"
R"==(|| oh * SH + kh * (1 + DH) >= IH + PH )==""\n"
R"==(|| od * SD + kd * (1 + DD) >= ID + PD) )==""\n"
R"==(continue; )==""\n"
R"==(int id = od * SD - PD + kd * (1 + DD); )==""\n"
R"==(int ih = oh * SH - PH + kh * (1 + DH); )==""\n"
R"==(int iw = ow * SW - PW + kw * (1 + DW); )==""\n"
R"==(dw += DST_TO_REF( )==""\n"
R"==(diff_dst[DST_OFF(n, g * OC + oc, od, oh, ow)]) )==""\n"
R"==(* SRC_TO_REF( )==""\n"
R"==(src[SRC_OFF(n, g * IC + ic, id, ih, iw)]); )==""\n"
R"==(} )==""\n"
R"==(diff_wei[WEI_OFF(g, oc, ic, kd, kh, kw)] = TO_WEI(dw); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==()==";
}
}
}
}