// OpenCL C source text for the vectorized resampling backward kernel, kept as
// one C string inside namespace dnnl::impl::gpu::ocl.
//
// The string is assembled from adjacent raw string literals (delimiter ==),
// each followed by an ordinary literal holding a newline escape, so every
// source line of the kernel ends up on its own line in the final string.
// The comments added here sit between those literals and are not part of
// the string.
//
// NOTE(review): this layout looks machine generated from a separate kernel
// source file - confirm before hand editing, changes here may be overwritten.
// NOTE(review): the license line that reads http: appears truncated; the rest
// of the URL was presumably dropped by a comment stripping step - confirm.
1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
5const char *vectorized_resampling_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// Everything from here to the final #endif is compiled only when IS_BWD == 1.
22R"==(#if IS_BWD == 1 )==""\n"
// linear(x, fo, fi): returns ((x + .5f) * fo / fi) - .5f.
23R"==(inline float linear(int x, int fo, int fi) { )==""\n"
24R"==(return ((x + .5f) * fo / fi) - .5f; )==""\n"
25R"==(} )==""\n"
// ceil_pos(x): ceil(x) converted to int, with negative results raised to 0.
26R"==(inline int ceil_pos(float x) { )==""\n"
27R"==(return max((int)ceil(x), (int)0); )==""\n"
28R"==(} )==""\n"
// Offset helpers for diff_dst. Each one adds (x % DST_Bn) * DST_SBn and
// (x / DST_Bn) * DST_Sn for one dimension n. The dimension index used for
// W, H and D depends on NDIMS; spatial dimensions that do not exist for the
// given NDIMS expand to 0.
// NOTE(review): neither the macro arguments nor the macro bodies are wrapped
// in parentheses. The uses in this kernel pass plain identifiers and only add
// the results together, so the expansion is as intended there; any new use
// with a compound argument or a multiplication around the macro would not be.
29R"==(#define DST_MB_STRIDE(x) (x % DST_B0) * DST_SB0 + (x / DST_B0) * DST_S0 )==""\n"
30R"==(#define DST_C_STRIDE(x) (x % DST_B1) * DST_SB1 + (x / DST_B1) * DST_S1 )==""\n"
31R"==(#if NDIMS == 3 )==""\n"
32R"==(#define OW_STRIDE(x) (x % DST_B2) * DST_SB2 + (x / DST_B2) * DST_S2 )==""\n"
33R"==(#define OH_STRIDE(x) 0 )==""\n"
34R"==(#define OD_STRIDE(x) 0 )==""\n"
35R"==(#elif NDIMS == 4 )==""\n"
36R"==(#define OW_STRIDE(x) (x % DST_B3) * DST_SB3 + (x / DST_B3) * DST_S3 )==""\n"
37R"==(#define OH_STRIDE(x) (x % DST_B2) * DST_SB2 + (x / DST_B2) * DST_S2 )==""\n"
38R"==(#define OD_STRIDE(x) 0 )==""\n"
39R"==(#elif NDIMS == 5 )==""\n"
40R"==(#define OW_STRIDE(x) (x % DST_B4) * DST_SB4 + (x / DST_B4) * DST_S4 )==""\n"
41R"==(#define OH_STRIDE(x) (x % DST_B3) * DST_SB3 + (x / DST_B3) * DST_S3 )==""\n"
42R"==(#define OD_STRIDE(x) (x % DST_B2) * DST_SB2 + (x / DST_B2) * DST_S2 )==""\n"
43R"==(#endif )==""\n"
// Kernel entry point. Reads diff_dst and writes diff_src.
// All coordinates are derived from get_global_id(0):
//   mb      = gid / MB_STRIDE
//   c_start = ((gid / GWS_SGS_DEFAULT) * GWS_SGS_DEFAULT * VECT_DT_N) % PADDED_C
//             (the operators are evaluated left to right, as grouped here)
//   c       = c_start + sub-group local id
//   id, ih, iw = (gid / I?_STRIDE) % I?
// src_val is the accumulator that is written to diff_src at the end.
44R"==(KERNEL_ATTR )==""\n"
45R"==(__kernel void vectorized_resampling_bwd( )==""\n"
46R"==(__global DST_DATA_T *diff_src, __global const DATA_T *diff_dst) { )==""\n"
47R"==(const uint sglid = get_sub_group_local_id(); )==""\n"
48R"==(const uint mb = (get_global_id(0) / MB_STRIDE); )==""\n"
49R"==(const uint c_start = (get_global_id(0) / GWS_SGS_DEFAULT) * GWS_SGS_DEFAULT )==""\n"
50R"==(* VECT_DT_N % PADDED_C; )==""\n"
51R"==(const uint c = c_start + sglid; )==""\n"
52R"==(const uint id = (get_global_id(0) / ID_STRIDE) % ID; )==""\n"
53R"==(const uint ih = (get_global_id(0) / IH_STRIDE) % IH; )==""\n"
54R"==(const uint iw = (get_global_id(0) / IW_STRIDE) % IW; )==""\n"
55R"==(const uint src_index = SRC_OFF(mb, c_start, id, ih, iw); )==""\n"
56R"==(VECT_DEF_ACC_DATA_T src_val = 0.0f; )==""\n"
// Branch taken when RESAMPLING_ALG_NEAREST is set.
// Work-items with mb >= MB or c >= C leave immediately. For the others, each
// spatial range runs from ceil_pos(i * F - .5f) up to, but not including,
// ceil_pos((i + 1) * F - .5f), and every diff_dst element in the resulting
// od/oh/ow box is added to src_val without any weighting.
57R"==(#if RESAMPLING_ALG_NEAREST )==""\n"
58R"==(if (mb >= MB || c >= C) return; )==""\n"
59R"==(int od_start = ceil_pos(id * FD - .5f); )==""\n"
60R"==(int oh_start = ceil_pos(ih * FH - .5f); )==""\n"
61R"==(int ow_start = ceil_pos(iw * FW - .5f); )==""\n"
62R"==(int od_end = ceil_pos((id + 1.f) * FD - .5f); )==""\n"
63R"==(int oh_end = ceil_pos((ih + 1.f) * FH - .5f); )==""\n"
64R"==(int ow_end = ceil_pos((iw + 1.f) * FW - .5f); )==""\n"
65R"==(for_(int i = od_start; i < od_end; i++) )==""\n"
66R"==(for_(int j = oh_start; j < oh_end; j++) )==""\n"
67R"==(for (int k = ow_start; k < ow_end; k++) { )==""\n"
68R"==(const int dst_index = DST_OFF(mb, c_start, i, j, k); )==""\n"
69R"==(#if VECT_DT_N == 1 )==""\n"
70R"==(src_val += AS_VECT_DEF_ACC_DATA_T(diff_dst[dst_index + sglid]); )==""\n"
71R"==(#else )==""\n"
// NOTE(review): this inner loop declares its own i, which hides the od loop
// variable i inside the loop body. dst_index is computed before the inner
// loop starts, so the hidden outer i is not read while it is shadowed.
72R"==(for (int i = 0; i < VECT_DT_N && c + GWS_SGS_DEFAULT * i < C; i++) { )==""\n"
73R"==(src_val[i] += TO_DEF_ACC_DATA_T( )==""\n"
74R"==(diff_dst[dst_index + sglid + GWS_SGS_DEFAULT * i]); )==""\n"
75R"==(} )==""\n"
76R"==(#endif )==""\n"
77R"==(} )==""\n"
// Branch taken when RESAMPLING_ALG_NEAREST is not set.
// Step 1: every lane computes one boundary index my_idx.
//   lanes 0-3 work on the D axis, lanes 4-7 on H, all remaining lanes on W.
//   Within a group of four, lane % 4 selects the input position fed to
//   linear(): 0 and 3 use ix, 1 uses ix - 1, 2 uses ix + 1.
//   Even lanes round with ceil_pos; odd lanes use 0 for negative values and
//   (int)value + 1 otherwise.
//   Lanes with lane % 4 >= 2 are clamped to OX, lane % 4 == 0 is forced to 0
//   when ix == 0, and lane % 4 == 3 is forced to OX when ix == IX - 1.
78R"==(#else )==""\n"
79R"==(int my_idx; )==""\n"
80R"==({ )==""\n"
81R"==(int ix, OX, IX; )==""\n"
82R"==(if (sglid < 4) { )==""\n"
83R"==(ix = id; )==""\n"
84R"==(OX = OD; )==""\n"
85R"==(IX = ID; )==""\n"
86R"==(} else if (sglid < 8) { )==""\n"
87R"==(ix = ih; )==""\n"
88R"==(OX = OH; )==""\n"
89R"==(IX = IH; )==""\n"
90R"==(} else { )==""\n"
91R"==(ix = iw; )==""\n"
92R"==(OX = OW; )==""\n"
93R"==(IX = IW; )==""\n"
94R"==(} )==""\n"
95R"==(if (sglid % 4 == 1) { ix -= 1; } )==""\n"
96R"==(if (sglid % 4 == 2) { ix += 1; } )==""\n"
97R"==(float idx_intermediate = linear(ix, OX, IX); )==""\n"
98R"==(if (sglid % 2 == 0) { )==""\n"
99R"==(my_idx = ceil_pos(idx_intermediate); )==""\n"
100R"==(} else { )==""\n"
101R"==(my_idx = (idx_intermediate < 0) ? 0 : (int)idx_intermediate + 1; )==""\n"
102R"==(} )==""\n"
103R"==(if (sglid % 4 >= 2) { my_idx = min(my_idx, OX); } )==""\n"
104R"==(if (sglid % 4 == 0 && ix == 0) { my_idx = 0; } )==""\n"
105R"==(if (sglid % 4 == 3 && ix == IX - 1) { my_idx = OX; } )==""\n"
106R"==(} )==""\n"
// Step 2: collect the twelve boundary values from lanes 0 to 11.
// Index [0] of each pair is referred to as left and index [1] as right by the
// num_*_left / num_*_right counters below. For each axis the start pair comes
// from the first two lanes of its group and the end pair from the last two.
// NOTE(review): reading lanes up to 11 assumes a sub-group of at least 12
// work-items - TODO confirm against the host side configuration.
107R"==(int od_start[2] )==""\n"
108R"==(= {sub_group_shuffle(my_idx, 0), sub_group_shuffle(my_idx, 1)}; )==""\n"
109R"==(int od_end[2] )==""\n"
110R"==(= {sub_group_shuffle(my_idx, 2), sub_group_shuffle(my_idx, 3)}; )==""\n"
111R"==(int oh_start[2] )==""\n"
112R"==(= {sub_group_shuffle(my_idx, 4), sub_group_shuffle(my_idx, 5)}; )==""\n"
113R"==(int oh_end[2] )==""\n"
114R"==(= {sub_group_shuffle(my_idx, 6), sub_group_shuffle(my_idx, 7)}; )==""\n"
115R"==(int ow_start[2] )==""\n"
116R"==(= {sub_group_shuffle(my_idx, 8), sub_group_shuffle(my_idx, 9)}; )==""\n"
117R"==(int ow_end[2] )==""\n"
118R"==(= {sub_group_shuffle(my_idx, 10), sub_group_shuffle(my_idx, 11)}; )==""\n"
119R"==(const int num_od_left = od_end[0] - od_start[0]; )==""\n"
120R"==(const int num_od_right = od_end[1] - od_start[1]; )==""\n"
121R"==(const int num_oh_left = oh_end[0] - oh_start[0]; )==""\n"
122R"==(const int num_oh_right = oh_end[1] - oh_start[1]; )==""\n"
123R"==(const int num_ow_left = ow_end[0] - ow_start[0]; )==""\n"
124R"==(const int num_ow_right = ow_end[1] - ow_start[1]; )==""\n"
// Step 3: every lane computes one fractional value myres.
// The lane id is matched against six consecutive ranges, in the order
// D left, D right, H left, H right, W left, W right, by subtracting each
// count from offset in turn. The matching range supplies the output position
// ox and the axis sizes; myres is then fabs(x - trunc(x)) for
// x = linear(ox, IX, OX).
// NOTE(review): a lane whose id is not below the sum of the six counts matches
// no range, so ox, IX and OX are read without having been assigned. The fill
// loops in step 4 only shuffle from lanes below that sum, so such values do
// not appear to be consumed - confirm.
125R"==(float myres; )==""\n"
126R"==({ )==""\n"
127R"==(int ox, IX, OX; )==""\n"
128R"==(int offset = sglid; )==""\n"
129R"==(if (0 <= offset && offset < num_od_left) { )==""\n"
130R"==(OX = OD; )==""\n"
131R"==(IX = ID; )==""\n"
132R"==(ox = od_start[0] + offset; )==""\n"
133R"==(} )==""\n"
134R"==(offset -= num_od_left; )==""\n"
135R"==(if (0 <= offset && offset < num_od_right) { )==""\n"
136R"==(OX = OD; )==""\n"
137R"==(IX = ID; )==""\n"
138R"==(ox = od_start[1] + offset; )==""\n"
139R"==(} )==""\n"
140R"==(offset -= num_od_right; )==""\n"
141R"==(if (0 <= offset && offset < num_oh_left) { )==""\n"
142R"==(OX = OH; )==""\n"
143R"==(IX = IH; )==""\n"
144R"==(ox = oh_start[0] + offset; )==""\n"
145R"==(} )==""\n"
146R"==(offset -= num_oh_left; )==""\n"
147R"==(if (0 <= offset && offset < num_oh_right) { )==""\n"
148R"==(OX = OH; )==""\n"
149R"==(IX = IH; )==""\n"
150R"==(ox = oh_start[1] + offset; )==""\n"
151R"==(} )==""\n"
152R"==(offset -= num_oh_right; )==""\n"
153R"==(if (0 <= offset && offset < num_ow_left) { )==""\n"
154R"==(OX = OW; )==""\n"
155R"==(IX = IW; )==""\n"
156R"==(ox = ow_start[0] + offset; )==""\n"
157R"==(} )==""\n"
158R"==(offset -= num_ow_left; )==""\n"
159R"==(if (0 <= offset && offset < num_ow_right) { )==""\n"
160R"==(OX = OW; )==""\n"
161R"==(IX = IW; )==""\n"
162R"==(ox = ow_start[1] + offset; )==""\n"
163R"==(} )==""\n"
164R"==(const float x = linear(ox, IX, OX); )==""\n"
165R"==(myres = fabs(x - trunc(x)); )==""\n"
166R"==(} )==""\n"
// Step 4: build per-axis weight lists by shuffling myres out of the lanes
// that computed it, walking the same six ranges in the same order.
// Left entries store 1.0f - myres, right entries store myres unchanged.
// NOTE(review): these loops are bounded by the num_* counts only, while the
// arrays hold MAX_NUM_D / MAX_NUM_H / MAX_NUM_W entries per side; this assumes
// the counts never exceed those sizes - TODO confirm.
167R"==(float d_list[2][MAX_NUM_D]; )==""\n"
168R"==(float h_list[2][MAX_NUM_H]; )==""\n"
169R"==(float w_list[2][MAX_NUM_W]; )==""\n"
170R"==(for (int d = 0; d < num_od_left; d++) { )==""\n"
171R"==(d_list[0][d] = 1.0f - sub_group_shuffle(myres, d); )==""\n"
172R"==(} )==""\n"
173R"==(int offset = num_od_left; )==""\n"
174R"==(for (int d = 0; d < num_od_right; d++) { )==""\n"
175R"==(d_list[1][d] = sub_group_shuffle(myres, d + offset); )==""\n"
176R"==(} )==""\n"
177R"==(offset += num_od_right; )==""\n"
178R"==(for (int h = 0; h < num_oh_left; h++) { )==""\n"
179R"==(h_list[0][h] = 1.0f - sub_group_shuffle(myres, h + offset); )==""\n"
180R"==(} )==""\n"
181R"==(offset += num_oh_left; )==""\n"
182R"==(for (int h = 0; h < num_oh_right; h++) { )==""\n"
183R"==(h_list[1][h] = sub_group_shuffle(myres, h + offset); )==""\n"
184R"==(} )==""\n"
185R"==(offset += num_oh_right; )==""\n"
186R"==(for (int w = 0; w < num_ow_left; w++) { )==""\n"
187R"==(w_list[0][w] = 1.0f - sub_group_shuffle(myres, w + offset); )==""\n"
188R"==(} )==""\n"
189R"==(offset += num_ow_left; )==""\n"
190R"==(for (int w = 0; w < num_ow_right; w++) { )==""\n"
191R"==(w_list[1][w] = sub_group_shuffle(myres, w + offset); )==""\n"
192R"==(} )==""\n"
// The out of range check comes only after all shuffles in this branch,
// presumably so that every lane still takes part in them - confirm.
193R"==(if (mb >= MB || c >= C) return; )==""\n"
// Step 5: accumulate. For each side c1/c2/c3 (0 = left, 1 = right) and each
// output position in the matching range, load diff_dst and add
// dst_val * Wid * Wih * Wiw to src_val. The loops also stop once the list
// index reaches MAX_NUM_D / MAX_NUM_H / MAX_NUM_W.
194R"==(const uint mb_c_off = DST_MB_STRIDE(mb) + DST_C_STRIDE(c_start); )==""\n"
195R"==(for_(int c1 = 0; c1 < 2; c1++) )==""\n"
196R"==(for (int od = od_start[c1], i = 0; i < MAX_NUM_D && od < od_end[c1]; )==""\n"
197R"==(od++, i++) { )==""\n"
198R"==(const uint d_off = mb_c_off + OD_STRIDE(od); )==""\n"
199R"==(float Wid = d_list[c1][i]; )==""\n"
200R"==(for_(int c2 = 0; c2 < 2; c2++) )==""\n"
201R"==(for (int oh = oh_start[c2], j = 0; j < MAX_NUM_H && oh < oh_end[c2]; )==""\n"
202R"==(oh++, j++) { )==""\n"
203R"==(const uint h_off = d_off + OH_STRIDE(oh); )==""\n"
204R"==(float Wih = h_list[c2][j]; )==""\n"
205R"==(unroll_for(int c3 = 0; c3 < 2; c3++) )==""\n"
206R"==(unroll_for(int k = 0, ow = ow_start[c3]; )==""\n"
207R"==(k < MAX_NUM_W && ow < ow_end[c3]; k++, ow++) { )==""\n"
208R"==(const uint dst_off = h_off + OW_STRIDE(ow); )==""\n"
209R"==(#if VECT_DT_N == 1 )==""\n"
210R"==(VECT_DEF_ACC_DATA_T dst_val )==""\n"
211R"==(= AS_VECT_DEF_ACC_DATA_T(diff_dst[dst_off + sglid]); )==""\n"
212R"==(#else )==""\n"
// Vector components whose channel index would reach C are not loaded and
// keep the initial value 0.
213R"==(VECT_DEF_ACC_DATA_T dst_val = 0; )==""\n"
214R"==(for (int idx = 0; )==""\n"
215R"==(idx < VECT_DT_N && c + GWS_SGS_DEFAULT * idx < C; )==""\n"
216R"==(idx++) { )==""\n"
217R"==(dst_val[idx] = TO_DEF_ACC_DATA_T( )==""\n"
218R"==(diff_dst[dst_off + sglid + GWS_SGS_DEFAULT * idx]); )==""\n"
219R"==(} )==""\n"
220R"==(#endif )==""\n"
221R"==(float Wiw = w_list[c3][k]; )==""\n"
222R"==(src_val += dst_val * Wid * Wih * Wiw; )==""\n"
223R"==(} )==""\n"
224R"==(} )==""\n"
225R"==(} )==""\n"
226R"==(#endif )==""\n"
// Store the result, shared by both algorithm branches. With VECT_DT_N == 1 a
// single element is written at src_index + sglid; otherwise up to VECT_DT_N
// elements are written GWS_SGS_DEFAULT apart, stopping before the channel
// index reaches C.
227R"==(#if VECT_DT_N == 1 )==""\n"
228R"==(diff_src[src_index + sglid] = TO_DST(src_val); )==""\n"
229R"==(#else )==""\n"
230R"==(for (int i = 0; i < VECT_DT_N && c + GWS_SGS_DEFAULT * i < C; i++) { )==""\n"
231R"==(diff_src[src_index + sglid + GWS_SGS_DEFAULT * i] = TO_DST(src_val[i]); )==""\n"
232R"==(} )==""\n"
233R"==(#endif )==""\n"
234R"==(} )==""\n"
235R"==(#endif )==""\n"
236R"==()==";
237}
238}
239}
240}