namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source text for the reference pooling kernels, stored as a single
// C string:
//   - ref_pooling_fwd is part of the text guarded by `#if IS_FWD`,
//   - ref_pooling_bwd is part of the text guarded by `#if IS_BWD`.
//
// The string is built from adjacent literals that the compiler concatenates.
// Each fragment R"==(<text> )==" is followed by "\n", so each fragment
// contributes exactly one line of kernel source. Each raw literal also keeps
// a trailing space before the newline. The "==" raw-string delimiter lets
// the text contain quotes and parentheses without escaping. Comments placed
// between fragments (like this one) are not part of the string.
//
// None of the configuration macros used in the text are defined here. This
// includes DATA_T, DST_DATA_T, SRC_D0/SRC_D1, KD/KH/KW, SD/SH/SW, PD/PH/PW,
// DD/DH/DW, ID/IH/IW, OD/OH/OW, ALG_MAX/ALG_AVG_P/ALG_AVG_NP, IS_TRAINING,
// WITH_SUM, NDIMS, DT_BF16, DST_DT_*, GWS_GET_*, SRC_OFF/DST_OFF and
// APPLY_POST_OPS_SERIAL. Presumably they come from the two #included
// headers and from build options set by the host code; confirm against the
// caller.
//
// NOTE(review): this file looks like generated output, with every kernel
// line wrapped as a raw-string fragment. If so, edit the kernel source it
// was generated from instead. Confirm before hand-editing.
// NOTE(review): the URL line of the embedded license header reads only
// "http:", which looks truncated, possibly by "//"-comment stripping during
// generation. It sits inside an OpenCL block comment, so it does not affect
// kernel compilation.
const char *ref_pooling_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http: )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// Kernel text: DST_DATA_MIN is the lowest value for the destination data
// type. It is used below as the starting value of the max-pooling
// accumulator:
//   f32 -> -FLT_MAX, s8 -> CHAR_MIN, u8 -> 0, f16 -> -HALF_MAX,
//   any other type -> DATA_MIN.
R"==(#if DST_DT_F32 )==""\n"
R"==(#define DST_DATA_MIN -FLT_MAX )==""\n"
R"==(#elif DST_DT_S8 )==""\n"
R"==(#define DST_DATA_MIN CHAR_MIN )==""\n"
R"==(#elif DST_DT_U8 )==""\n"
R"==(#define DST_DATA_MIN 0 )==""\n"
R"==(#elif DST_DT_F16 )==""\n"
R"==(#define DST_DATA_MIN -HALF_MAX )==""\n"
R"==(#else )==""\n"
R"==(#define DST_DATA_MIN DATA_MIN )==""\n"
R"==(#endif )==""\n"
// Kernel text: forward pass. Each work item handles one output point
// (mb, oc, od, oh, ow), with the indices taken from the GWS_GET_* macros.
// It writes dst[dst_off] exactly once.
R"==(#if IS_FWD )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_pooling_fwd(__global DATA_T *src, __global int *ws, )==""\n"
R"==(__global DST_DATA_T *dst POST_OP_ARGS) { )==""\n"
R"==(const int mb = GWS_GET_MB(); )==""\n"
R"==(const int oc = GWS_GET_OC(); )==""\n"
R"==(const int od = GWS_GET_OD(); )==""\n"
R"==(const int oh = GWS_GET_OH(); )==""\n"
R"==(const int ow = GWS_GET_OW(); )==""\n"
// Work items with mb/oc past SRC_D0/SRC_D1 only zero their dst element
// (and, for max-pooling training, their ws element), then return.
// Presumably this covers the padded tail of a blocked layout.
R"==(if (mb >= SRC_D0 || oc >= SRC_D1) { )==""\n"
R"==(const uint dst_off = DST_OFF(mb, oc, od, oh, ow); )==""\n"
R"==(dst[dst_off] = TO_DST(0.f); )==""\n"
R"==(#if ALG_MAX && IS_TRAINING )==""\n"
R"==(ws[dst_off] = 0; )==""\n"
R"==(#endif )==""\n"
R"==(return; )==""\n"
R"==(} )==""\n"
// ws stores the winning kernel tap of max pooling, encoded as
// kd * KH * KW + kh * KW + kw. A value of -1 means no in-bounds tap has
// been seen yet. The accumulator d starts at 0 for average pooling and at
// DST_DATA_MIN for max pooling, typed DEF_ACC_DATA_T when DT_BF16 and
// float otherwise.
R"==(const uint dst_off = DST_OFF(mb, oc, od, oh, ow); )==""\n"
R"==(#if ALG_MAX && IS_TRAINING )==""\n"
R"==(ws[dst_off] = -1; )==""\n"
R"==(#endif )==""\n"
R"==(#if ALG_AVG_P || ALG_AVG_NP )==""\n"
R"==(float d = 0; )==""\n"
R"==(#endif )==""\n"
R"==(#if ALG_MAX && DT_BF16 )==""\n"
R"==(DEF_ACC_DATA_T d = TO_DEF_ACC_DATA_T(DST_DATA_MIN); )==""\n"
R"==(#elif ALG_MAX )==""\n"
R"==(float d = DST_DATA_MIN; )==""\n"
R"==(#endif )==""\n"
// Walk the kernel window. Tap k of output o lands at input coordinate
// o * stride - pad + k * (dilation + 1), so DD/DH/DW == 0 means a dense
// window. Taps outside the input are skipped.
// - Max: ws first records the first in-bounds tap, then moves to any tap
//   whose value is strictly larger than the current d.
// - Average: the in-bounds values are summed into d.
R"==(for (int kd = 0; kd < KD; ++kd) { )==""\n"
R"==(const int id = od * SD - PD + kd * (DD + 1); )==""\n"
R"==(if (id < 0 || id >= ID) continue; )==""\n"
R"==(for (int kh = 0; kh < KH; ++kh) { )==""\n"
R"==(const int ih = oh * SH - PH + kh * (DH + 1); )==""\n"
R"==(if (ih < 0 || ih >= IH) continue; )==""\n"
R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
R"==(const int iw = ow * SW - PW + kw * (DW + 1); )==""\n"
R"==(if (iw < 0 || iw >= IW) continue; )==""\n"
R"==(int src_off = SRC_OFF(mb, oc, id, ih, iw); )==""\n"
R"==(#if ALG_MAX )==""\n"
R"==(#if IS_TRAINING )==""\n"
R"==(if (ws[dst_off] < 0) ws[dst_off] = kd * KH * KW + kh * KW + kw; )==""\n"
R"==(#endif )==""\n"
R"==(#if DT_BF16 )==""\n"
R"==(DEF_ACC_DATA_T s = DATA_TO_REF(src[src_off]); )==""\n"
R"==(#else )==""\n"
R"==(float s = DATA_TO_REF(src[src_off]); )==""\n"
R"==(#endif )==""\n"
R"==(if (s > d) { )==""\n"
R"==(d = s; )==""\n"
R"==(#if IS_TRAINING )==""\n"
R"==(ws[dst_off] = kd * KH * KW + kh * KW + kw; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(#else )==""\n"
R"==(d += DATA_TO_REF(src[src_off]); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// Max: if no tap was in bounds, ws falls back to tap 0.
// Average: with ALG_AVG_P, divide by the full kernel volume KD*KW*KH
// (out-of-bounds taps count). Otherwise, divide by the number of in-bounds
// taps:
// - *_start is the first input coordinate covered by the dilated window.
// - *_end is one past the last covered coordinate.
// - *_excluded counts the taps that fall before 0 or at/after the input
//   extent.
// NOTE(review): num_summands would be 0 if every tap fell outside the input,
// which would divide by zero here. Presumably such shapes are rejected by
// the host code; confirm.
R"==(#if ALG_MAX )==""\n"
R"==(#if IS_TRAINING )==""\n"
R"==(if (ws[dst_off] < 0) ws[dst_off] = 0; )==""\n"
R"==(#endif )==""\n"
R"==(#else )==""\n"
R"==(#if ALG_AVG_P )==""\n"
R"==(const int num_summands = KD * KW * KH; )==""\n"
R"==(#else )==""\n"
R"==(const int id_start = od * SD - PD; )==""\n"
R"==(const int ih_start = oh * SH - PH; )==""\n"
R"==(const int iw_start = ow * SW - PW; )==""\n"
R"==(const int id_end = od * SD - PD + (KD - 1) * DD + KD; )==""\n"
R"==(const int ih_end = oh * SH - PH + (KH - 1) * DH + KH; )==""\n"
R"==(const int iw_end = ow * SW - PW + (KW - 1) * DW + KW; )==""\n"
R"==(const int id_start_excluded )==""\n"
R"==(= id_start < 0 ? (0 - id_start - 1) / (DD + 1) + 1 : 0; )==""\n"
R"==(const int ih_start_excluded )==""\n"
R"==(= ih_start < 0 ? (0 - ih_start - 1) / (DH + 1) + 1 : 0; )==""\n"
R"==(const int iw_start_excluded )==""\n"
R"==(= iw_start < 0 ? (0 - iw_start - 1) / (DW + 1) + 1 : 0; )==""\n"
R"==(const int id_end_excluded )==""\n"
R"==(= id_end > ID ? (id_end - ID - 1) / (DD + 1) + 1 : 0; )==""\n"
R"==(const int ih_end_excluded )==""\n"
R"==(= ih_end > IH ? (ih_end - IH - 1) / (DH + 1) + 1 : 0; )==""\n"
R"==(const int iw_end_excluded )==""\n"
R"==(= iw_end > IW ? (iw_end - IW - 1) / (DW + 1) + 1 : 0; )==""\n"
R"==(const int num_summands = (KD - id_start_excluded - id_end_excluded) )==""\n"
R"==(* (KH - ih_start_excluded - ih_end_excluded) )==""\n"
R"==(* (KW - iw_start_excluded - iw_end_excluded); )==""\n"
R"==(#endif )==""\n"
R"==(d /= num_summands; )==""\n"
R"==(#endif )==""\n"
// Post-ops and store:
// - tmp holds the pooled value.
// - With WITH_SUM, the previous dst value is passed as the sum source.
// - po_d2..po_d4 carry the output spatial coordinates according to NDIMS
//   (unused slots are 0) and are handed, with mb and oc, to
//   APPLY_POST_OPS_SERIAL (defined elsewhere).
// - The result is stored via TO_DST.
// NOTE(review): dst is DST_DATA_T but is read back with DATA_TO_REF, the
// src-type conversion. That is fine if the two types match or the macro is
// type-agnostic; confirm for mixed src/dst data types.
// NOTE(review): sum_src stays uninitialized when WITH_SUM is 0. Presumably
// APPLY_POST_OPS_SERIAL does not read it in that case.
R"==(#if DT_BF16 )==""\n"
R"==(POST_OP_DATA_T tmp = d; )==""\n"
R"==(#else )==""\n"
R"==(POST_OP_DATA_T tmp = DATA_TO_REF(d); )==""\n"
R"==(#endif )==""\n"
R"==(POST_OP_DATA_T sum_src; )==""\n"
R"==(#if WITH_SUM )==""\n"
R"==(sum_src = DATA_TO_REF(dst[dst_off]); )==""\n"
R"==(#endif )==""\n"
R"==(#if NDIMS == 3 )==""\n"
R"==(const unsigned po_d2 = ow; )==""\n"
R"==(const unsigned po_d3 = 0; )==""\n"
R"==(const unsigned po_d4 = 0; )==""\n"
R"==(#elif NDIMS == 4 )==""\n"
R"==(const unsigned po_d2 = oh; )==""\n"
R"==(const unsigned po_d3 = ow; )==""\n"
R"==(const unsigned po_d4 = 0; )==""\n"
R"==(#elif NDIMS == 5 )==""\n"
R"==(const unsigned po_d2 = od; )==""\n"
R"==(const unsigned po_d3 = oh; )==""\n"
R"==(const unsigned po_d4 = ow; )==""\n"
R"==(#else )==""\n"
R"==(const unsigned po_d2 = 0; )==""\n"
R"==(const unsigned po_d3 = 0; )==""\n"
R"==(const unsigned po_d4 = 0; )==""\n"
R"==(#endif )==""\n"
R"==(APPLY_POST_OPS_SERIAL(tmp, POST_OP_DATA_T, sum_src, POST_OP_DATA_T, mb, 1, )==""\n"
R"==(oc, 1, po_d2, 1, po_d3, 1, po_d4, 1, 0, 1); )==""\n"
R"==(dst[dst_off] = TO_DST(tmp); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
// Kernel text: backward pass. Each work item handles one diff_src point
// (mb, oc, id, ih, iw) and gathers from every output that reads it.
// - For each kernel tap it solves id = od * SD - PD + kd * (DD + 1) for od,
//   and likewise for h and w.
// - A tap is skipped when the division by the stride is not exact or the
//   output index is out of range. A negative numerator with zero remainder
//   gives a negative od, which the range check rejects.
// - Max: only the tap recorded in ws for that output contributes.
// - Average: diff_dst is divided by the same divisor the forward pass
//   computes (full kernel volume for ALG_AVG_P, in-bounds tap count for
//   ALG_AVG_NP).
// Contributions accumulate in float and are stored with CONVERT_DATA_T.
R"==(#if IS_BWD )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_pooling_bwd(__global DATA_T *diff_src, __global int *ws, )==""\n"
R"==(__global DST_DATA_T *diff_dst) { )==""\n"
R"==(const int mb = GWS_GET_MB(); )==""\n"
R"==(const int oc = GWS_GET_OC(); )==""\n"
R"==(const int id = GWS_GET_ID(); )==""\n"
R"==(const int ih = GWS_GET_IH(); )==""\n"
R"==(const int iw = GWS_GET_IW(); )==""\n"
R"==(float s = 0; )==""\n"
R"==(for (int kd = 0; kd < KD; ++kd) { )==""\n"
R"==(int _od = id + PD - kd * (DD + 1); )==""\n"
R"==(if (_od % SD != 0) continue; )==""\n"
R"==(int od = _od / SD; )==""\n"
R"==(if (od < 0 || od >= OD) continue; )==""\n"
R"==(for (int kh = 0; kh < KH; ++kh) { )==""\n"
R"==(int _oh = ih + PH - kh * (DH + 1); )==""\n"
R"==(if (_oh % SH != 0) continue; )==""\n"
R"==(int oh = _oh / SH; )==""\n"
R"==(if (oh < 0 || oh >= OH) continue; )==""\n"
R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
R"==(int _ow = iw + PW - kw * (DW + 1); )==""\n"
R"==(if (_ow % SW != 0) continue; )==""\n"
R"==(int ow = _ow / SW; )==""\n"
R"==(if (ow < 0 || ow >= OW) continue; )==""\n"
R"==(const uint dst_off = DST_OFF(mb, oc, od, oh, ow); )==""\n"
// Decode the ws tap index (kd * KH * KW + kh * KW + kw) written by the
// forward pass. Skip this tap unless it is the recorded maximum.
R"==(#if ALG_MAX )==""\n"
R"==(const int index = ws[dst_off]; )==""\n"
R"==(const int hw = index % (KW * KH); )==""\n"
R"==(const int w_kd = index / (KW * KH); )==""\n"
R"==(const int w_kw = hw % KW; )==""\n"
R"==(const int w_kh = hw / KW; )==""\n"
R"==(if (w_kd != kd || w_kh != kh || w_kw != kw) continue; )==""\n"
R"==(#endif )==""\n"
R"==(#if ALG_MAX )==""\n"
R"==(const int denom = 1; )==""\n"
R"==(#elif ALG_AVG_P )==""\n"
R"==(const int denom = KD * KH * KW; )==""\n"
R"==(#elif ALG_AVG_NP )==""\n"
R"==(const int id_start = od * SD - PD; )==""\n"
R"==(const int ih_start = oh * SH - PH; )==""\n"
R"==(const int iw_start = ow * SW - PW; )==""\n"
R"==(const int id_end = od * SD - PD + (KD - 1) * DD + KD; )==""\n"
R"==(const int ih_end = oh * SH - PH + (KH - 1) * DH + KH; )==""\n"
R"==(const int iw_end = ow * SW - PW + (KW - 1) * DW + KW; )==""\n"
R"==(const int id_start_excluded )==""\n"
R"==(= id_start < 0 ? (0 - id_start - 1) / (DD + 1) + 1 : 0; )==""\n"
R"==(const int ih_start_excluded )==""\n"
R"==(= ih_start < 0 ? (0 - ih_start - 1) / (DH + 1) + 1 : 0; )==""\n"
R"==(const int iw_start_excluded )==""\n"
R"==(= iw_start < 0 ? (0 - iw_start - 1) / (DW + 1) + 1 : 0; )==""\n"
R"==(const int id_end_excluded )==""\n"
R"==(= id_end > ID ? (id_end - ID - 1) / (DD + 1) + 1 : 0; )==""\n"
R"==(const int ih_end_excluded )==""\n"
R"==(= ih_end > IH ? (ih_end - IH - 1) / (DH + 1) + 1 : 0; )==""\n"
R"==(const int iw_end_excluded )==""\n"
R"==(= iw_end > IW ? (iw_end - IW - 1) / (DW + 1) + 1 : 0; )==""\n"
R"==(const int denom = (KD - id_start_excluded - id_end_excluded) )==""\n"
R"==(* (KH - ih_start_excluded - ih_end_excluded) )==""\n"
R"==(* (KW - iw_start_excluded - iw_end_excluded); )==""\n"
R"==(#endif )==""\n"
R"==(s += DATA_TO_REF(diff_dst[dst_off]) / denom; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(uint diff_src_offset = SRC_OFF(mb, oc, id, ih, iw); )==""\n"
R"==(diff_src[diff_src_offset] = CONVERT_DATA_T(s); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==()==";
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl