1 | namespace dnnl { |
2 | namespace impl { |
3 | namespace gpu { |
4 | namespace ocl { |
// Source text of the vectorized resampling kernel, stored as one C string.
// It is assembled from adjacent raw string literals, one per kernel source
// line, each followed by an explicit newline literal.
// NOTE(review): the license URL line inside the text reads just "http:" -
// presumably truncated when this text was generated; confirm against the
// original kernel source.
5 | const char *vectorized_resampling_kernel = R"==(/******************************************************************************* )==" "\n" |
6 | R"==(* Copyright 2022 Intel Corporation )==" "\n" |
7 | R"==(* )==" "\n" |
8 | R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n" |
9 | R"==(* you may not use this file except in compliance with the License. )==" "\n" |
10 | R"==(* You may obtain a copy of the License at )==" "\n" |
11 | R"==(* )==" "\n" |
12 | R"==(* http: )==" "\n" |
13 | R"==(* )==" "\n" |
14 | R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n" |
15 | R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n" |
16 | R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n" |
17 | R"==(* See the License for the specific language governing permissions and )==" "\n" |
18 | R"==(* limitations under the License. )==" "\n" |
19 | R"==(*******************************************************************************/ )==" "\n" |
20 | R"==(#include "gpu/ocl/ocl_post_ops.h" )==" "\n" |
21 | R"==(#include "gpu/ocl/ocl_types.h" )==" "\n" |
// The rest of the kernel text is compiled only when IS_BWD == 1. It defines two
// helpers: linear() maps an index x to ((x + .5f) * fo / fi) - .5f, and
// ceil_pos() returns ceil(x) clamped below at 0.
22 | R"==(#if IS_BWD == 1 )==" "\n" |
23 | R"==(inline float linear(int x, int fo, int fi) { )==" "\n" |
24 | R"==(return ((x + .5f) * fo / fi) - .5f; )==" "\n" |
25 | R"==(} )==" "\n" |
26 | R"==(inline int ceil_pos(float x) { )==" "\n" |
27 | R"==(return max((int)ceil(x), (int)0); )==" "\n" |
28 | R"==(} )==" "\n" |
// Offset macros for the destination tensor. DST_B*, DST_SB* and DST_S* are not
// defined in this text (presumably supplied by the included headers or build
// options - confirm). Which DST_* index is used for W/H/D depends on NDIMS.
29 | R"==(#define DST_MB_STRIDE(x) (x % DST_B0) * DST_SB0 + (x / DST_B0) * DST_S0 )==" "\n" |
30 | R"==(#define DST_C_STRIDE(x) (x % DST_B1) * DST_SB1 + (x / DST_B1) * DST_S1 )==" "\n" |
31 | R"==(#if NDIMS == 3 )==" "\n" |
32 | R"==(#define OW_STRIDE(x) (x % DST_B2) * DST_SB2 + (x / DST_B2) * DST_S2 )==" "\n" |
33 | R"==(#define OH_STRIDE(x) 0 )==" "\n" |
34 | R"==(#define OD_STRIDE(x) 0 )==" "\n" |
35 | R"==(#elif NDIMS == 4 )==" "\n" |
36 | R"==(#define OW_STRIDE(x) (x % DST_B3) * DST_SB3 + (x / DST_B3) * DST_S3 )==" "\n" |
37 | R"==(#define OH_STRIDE(x) (x % DST_B2) * DST_SB2 + (x / DST_B2) * DST_S2 )==" "\n" |
38 | R"==(#define OD_STRIDE(x) 0 )==" "\n" |
39 | R"==(#elif NDIMS == 5 )==" "\n" |
40 | R"==(#define OW_STRIDE(x) (x % DST_B4) * DST_SB4 + (x / DST_B4) * DST_S4 )==" "\n" |
41 | R"==(#define OH_STRIDE(x) (x % DST_B3) * DST_SB3 + (x / DST_B3) * DST_S3 )==" "\n" |
42 | R"==(#define OD_STRIDE(x) (x % DST_B2) * DST_SB2 + (x / DST_B2) * DST_S2 )==" "\n" |
43 | R"==(#endif )==" "\n" |
// Kernel entry point: vectorized_resampling_bwd(diff_src, diff_dst). It reads
// diff_dst and writes diff_src; the per-work-item coordinates (mb, c, id, ih,
// iw) are all derived from get_global_id(0).
44 | R"==(KERNEL_ATTR )==" "\n" |
45 | R"==(__kernel void vectorized_resampling_bwd( )==" "\n" |
46 | R"==(__global DST_DATA_T *diff_src, __global const DATA_T *diff_dst) { )==" "\n" |
47 | R"==(const uint sglid = get_sub_group_local_id(); )==" "\n" |
48 | R"==(const uint mb = (get_global_id(0) / MB_STRIDE); )==" "\n" |
49 | R"==(const uint c_start = (get_global_id(0) / GWS_SGS_DEFAULT) * GWS_SGS_DEFAULT )==" "\n" |
50 | R"==(* VECT_DT_N % PADDED_C; )==" "\n" |
51 | R"==(const uint c = c_start + sglid; )==" "\n" |
52 | R"==(const uint id = (get_global_id(0) / ID_STRIDE) % ID; )==" "\n" |
53 | R"==(const uint ih = (get_global_id(0) / IH_STRIDE) % IH; )==" "\n" |
54 | R"==(const uint iw = (get_global_id(0) / IW_STRIDE) % IW; )==" "\n" |
55 | R"==(const uint src_index = SRC_OFF(mb, c_start, id, ih, iw); )==" "\n" |
56 | R"==(VECT_DEF_ACC_DATA_T src_val = 0.0f; )==" "\n" |
// Nearest branch: accumulates diff_dst over the [start, end) output range
// computed for each spatial dimension from FD / FH / FW.
57 | R"==(#if RESAMPLING_ALG_NEAREST )==" "\n" |
58 | R"==(if (mb >= MB || c >= C) return; )==" "\n" |
59 | R"==(int od_start = ceil_pos(id * FD - .5f); )==" "\n" |
60 | R"==(int oh_start = ceil_pos(ih * FH - .5f); )==" "\n" |
61 | R"==(int ow_start = ceil_pos(iw * FW - .5f); )==" "\n" |
62 | R"==(int od_end = ceil_pos((id + 1.f) * FD - .5f); )==" "\n" |
63 | R"==(int oh_end = ceil_pos((ih + 1.f) * FH - .5f); )==" "\n" |
64 | R"==(int ow_end = ceil_pos((iw + 1.f) * FW - .5f); )==" "\n" |
65 | R"==(for_(int i = od_start; i < od_end; i++) )==" "\n" |
66 | R"==(for_(int j = oh_start; j < oh_end; j++) )==" "\n" |
67 | R"==(for (int k = ow_start; k < ow_end; k++) { )==" "\n" |
68 | R"==(const int dst_index = DST_OFF(mb, c_start, i, j, k); )==" "\n" |
69 | R"==(#if VECT_DT_N == 1 )==" "\n" |
70 | R"==(src_val += AS_VECT_DEF_ACC_DATA_T(diff_dst[dst_index + sglid]); )==" "\n" |
71 | R"==(#else )==" "\n" |
72 | R"==(for (int i = 0; i < VECT_DT_N && c + GWS_SGS_DEFAULT * i < C; i++) { )==" "\n" |
73 | R"==(src_val[i] += TO_DEF_ACC_DATA_T( )==" "\n" |
74 | R"==(diff_dst[dst_index + sglid + GWS_SGS_DEFAULT * i]); )==" "\n" |
75 | R"==(} )==" "\n" |
76 | R"==(#endif )==" "\n" |
77 | R"==(} )==" "\n" |
// Branch taken when RESAMPLING_ALG_NEAREST is not set: sub-group lanes 0-3
// handle depth, lanes 4-7 height and the remaining lanes width; sglid % 4
// selects which bound each lane computes. The bounds are then gathered from
// lanes 0..11 with sub_group_shuffle.
// NOTE(review): this appears to assume a sub-group of at least 12 lanes - confirm.
78 | R"==(#else )==" "\n" |
79 | R"==(int my_idx; )==" "\n" |
80 | R"==({ )==" "\n" |
81 | R"==(int ix, OX, IX; )==" "\n" |
82 | R"==(if (sglid < 4) { )==" "\n" |
83 | R"==(ix = id; )==" "\n" |
84 | R"==(OX = OD; )==" "\n" |
85 | R"==(IX = ID; )==" "\n" |
86 | R"==(} else if (sglid < 8) { )==" "\n" |
87 | R"==(ix = ih; )==" "\n" |
88 | R"==(OX = OH; )==" "\n" |
89 | R"==(IX = IH; )==" "\n" |
90 | R"==(} else { )==" "\n" |
91 | R"==(ix = iw; )==" "\n" |
92 | R"==(OX = OW; )==" "\n" |
93 | R"==(IX = IW; )==" "\n" |
94 | R"==(} )==" "\n" |
95 | R"==(if (sglid % 4 == 1) { ix -= 1; } )==" "\n" |
96 | R"==(if (sglid % 4 == 2) { ix += 1; } )==" "\n" |
97 | R"==(float idx_intermediate = linear(ix, OX, IX); )==" "\n" |
98 | R"==(if (sglid % 2 == 0) { )==" "\n" |
99 | R"==(my_idx = ceil_pos(idx_intermediate); )==" "\n" |
100 | R"==(} else { )==" "\n" |
101 | R"==(my_idx = (idx_intermediate < 0) ? 0 : (int)idx_intermediate + 1; )==" "\n" |
102 | R"==(} )==" "\n" |
103 | R"==(if (sglid % 4 >= 2) { my_idx = min(my_idx, OX); } )==" "\n" |
104 | R"==(if (sglid % 4 == 0 && ix == 0) { my_idx = 0; } )==" "\n" |
105 | R"==(if (sglid % 4 == 3 && ix == IX - 1) { my_idx = OX; } )==" "\n" |
106 | R"==(} )==" "\n" |
107 | R"==(int od_start[2] )==" "\n" |
108 | R"==(= {sub_group_shuffle(my_idx, 0), sub_group_shuffle(my_idx, 1)}; )==" "\n" |
109 | R"==(int od_end[2] )==" "\n" |
110 | R"==(= {sub_group_shuffle(my_idx, 2), sub_group_shuffle(my_idx, 3)}; )==" "\n" |
111 | R"==(int oh_start[2] )==" "\n" |
112 | R"==(= {sub_group_shuffle(my_idx, 4), sub_group_shuffle(my_idx, 5)}; )==" "\n" |
113 | R"==(int oh_end[2] )==" "\n" |
114 | R"==(= {sub_group_shuffle(my_idx, 6), sub_group_shuffle(my_idx, 7)}; )==" "\n" |
115 | R"==(int ow_start[2] )==" "\n" |
116 | R"==(= {sub_group_shuffle(my_idx, 8), sub_group_shuffle(my_idx, 9)}; )==" "\n" |
117 | R"==(int ow_end[2] )==" "\n" |
118 | R"==(= {sub_group_shuffle(my_idx, 10), sub_group_shuffle(my_idx, 11)}; )==" "\n" |
119 | R"==(const int num_od_left = od_end[0] - od_start[0]; )==" "\n" |
120 | R"==(const int num_od_right = od_end[1] - od_start[1]; )==" "\n" |
121 | R"==(const int num_oh_left = oh_end[0] - oh_start[0]; )==" "\n" |
122 | R"==(const int num_oh_right = oh_end[1] - oh_start[1]; )==" "\n" |
123 | R"==(const int num_ow_left = ow_end[0] - ow_start[0]; )==" "\n" |
124 | R"==(const int num_ow_right = ow_end[1] - ow_start[1]; )==" "\n" |
// Each lane picks one output index from the concatenated left/right ranges of
// D, H and W (selected by its sglid) and stores fabs(x - trunc(x)) of its
// linear() mapping in myres.
// NOTE(review): a lane whose sglid is beyond all six ranges leaves ox / IX / OX
// unassigned before linear() is called; the shuffles below only read lanes
// inside those ranges, so the value looks unused - confirm.
125 | R"==(float myres; )==" "\n" |
126 | R"==({ )==" "\n" |
127 | R"==(int ox, IX, OX; )==" "\n" |
128 | R"==(int offset = sglid; )==" "\n" |
129 | R"==(if (0 <= offset && offset < num_od_left) { )==" "\n" |
130 | R"==(OX = OD; )==" "\n" |
131 | R"==(IX = ID; )==" "\n" |
132 | R"==(ox = od_start[0] + offset; )==" "\n" |
133 | R"==(} )==" "\n" |
134 | R"==(offset -= num_od_left; )==" "\n" |
135 | R"==(if (0 <= offset && offset < num_od_right) { )==" "\n" |
136 | R"==(OX = OD; )==" "\n" |
137 | R"==(IX = ID; )==" "\n" |
138 | R"==(ox = od_start[1] + offset; )==" "\n" |
139 | R"==(} )==" "\n" |
140 | R"==(offset -= num_od_right; )==" "\n" |
141 | R"==(if (0 <= offset && offset < num_oh_left) { )==" "\n" |
142 | R"==(OX = OH; )==" "\n" |
143 | R"==(IX = IH; )==" "\n" |
144 | R"==(ox = oh_start[0] + offset; )==" "\n" |
145 | R"==(} )==" "\n" |
146 | R"==(offset -= num_oh_left; )==" "\n" |
147 | R"==(if (0 <= offset && offset < num_oh_right) { )==" "\n" |
148 | R"==(OX = OH; )==" "\n" |
149 | R"==(IX = IH; )==" "\n" |
150 | R"==(ox = oh_start[1] + offset; )==" "\n" |
151 | R"==(} )==" "\n" |
152 | R"==(offset -= num_oh_right; )==" "\n" |
153 | R"==(if (0 <= offset && offset < num_ow_left) { )==" "\n" |
154 | R"==(OX = OW; )==" "\n" |
155 | R"==(IX = IW; )==" "\n" |
156 | R"==(ox = ow_start[0] + offset; )==" "\n" |
157 | R"==(} )==" "\n" |
158 | R"==(offset -= num_ow_left; )==" "\n" |
159 | R"==(if (0 <= offset && offset < num_ow_right) { )==" "\n" |
160 | R"==(OX = OW; )==" "\n" |
161 | R"==(IX = IW; )==" "\n" |
162 | R"==(ox = ow_start[1] + offset; )==" "\n" |
163 | R"==(} )==" "\n" |
164 | R"==(const float x = linear(ox, IX, OX); )==" "\n" |
165 | R"==(myres = fabs(x - trunc(x)); )==" "\n" |
166 | R"==(} )==" "\n" |
// Weight tables filled via sub_group_shuffle: index [0] stores 1 - myres for
// the left range, index [1] stores myres for the right range.
167 | R"==(float d_list[2][MAX_NUM_D]; )==" "\n" |
168 | R"==(float h_list[2][MAX_NUM_H]; )==" "\n" |
169 | R"==(float w_list[2][MAX_NUM_W]; )==" "\n" |
170 | R"==(for (int d = 0; d < num_od_left; d++) { )==" "\n" |
171 | R"==(d_list[0][d] = 1.0f - sub_group_shuffle(myres, d); )==" "\n" |
172 | R"==(} )==" "\n" |
173 | R"==(int offset = num_od_left; )==" "\n" |
174 | R"==(for (int d = 0; d < num_od_right; d++) { )==" "\n" |
175 | R"==(d_list[1][d] = sub_group_shuffle(myres, d + offset); )==" "\n" |
176 | R"==(} )==" "\n" |
177 | R"==(offset += num_od_right; )==" "\n" |
178 | R"==(for (int h = 0; h < num_oh_left; h++) { )==" "\n" |
179 | R"==(h_list[0][h] = 1.0f - sub_group_shuffle(myres, h + offset); )==" "\n" |
180 | R"==(} )==" "\n" |
181 | R"==(offset += num_oh_left; )==" "\n" |
182 | R"==(for (int h = 0; h < num_oh_right; h++) { )==" "\n" |
183 | R"==(h_list[1][h] = sub_group_shuffle(myres, h + offset); )==" "\n" |
184 | R"==(} )==" "\n" |
185 | R"==(offset += num_oh_right; )==" "\n" |
186 | R"==(for (int w = 0; w < num_ow_left; w++) { )==" "\n" |
187 | R"==(w_list[0][w] = 1.0f - sub_group_shuffle(myres, w + offset); )==" "\n" |
188 | R"==(} )==" "\n" |
189 | R"==(offset += num_ow_left; )==" "\n" |
190 | R"==(for (int w = 0; w < num_ow_right; w++) { )==" "\n" |
191 | R"==(w_list[1][w] = sub_group_shuffle(myres, w + offset); )==" "\n" |
192 | R"==(} )==" "\n" |
// The bounds check and return come after all sub_group_shuffle calls
// (presumably so every lane takes part in the shuffles - confirm). The loops
// then accumulate dst_val * Wid * Wih * Wiw into src_val.
193 | R"==(if (mb >= MB || c >= C) return; )==" "\n" |
194 | R"==(const uint mb_c_off = DST_MB_STRIDE(mb) + DST_C_STRIDE(c_start); )==" "\n" |
195 | R"==(for_(int c1 = 0; c1 < 2; c1++) )==" "\n" |
196 | R"==(for (int od = od_start[c1], i = 0; i < MAX_NUM_D && od < od_end[c1]; )==" "\n" |
197 | R"==(od++, i++) { )==" "\n" |
198 | R"==(const uint d_off = mb_c_off + OD_STRIDE(od); )==" "\n" |
199 | R"==(float Wid = d_list[c1][i]; )==" "\n" |
200 | R"==(for_(int c2 = 0; c2 < 2; c2++) )==" "\n" |
201 | R"==(for (int oh = oh_start[c2], j = 0; j < MAX_NUM_H && oh < oh_end[c2]; )==" "\n" |
202 | R"==(oh++, j++) { )==" "\n" |
203 | R"==(const uint h_off = d_off + OH_STRIDE(oh); )==" "\n" |
204 | R"==(float Wih = h_list[c2][j]; )==" "\n" |
205 | R"==(unroll_for(int c3 = 0; c3 < 2; c3++) )==" "\n" |
206 | R"==(unroll_for(int k = 0, ow = ow_start[c3]; )==" "\n" |
207 | R"==(k < MAX_NUM_W && ow < ow_end[c3]; k++, ow++) { )==" "\n" |
208 | R"==(const uint dst_off = h_off + OW_STRIDE(ow); )==" "\n" |
209 | R"==(#if VECT_DT_N == 1 )==" "\n" |
210 | R"==(VECT_DEF_ACC_DATA_T dst_val )==" "\n" |
211 | R"==(= AS_VECT_DEF_ACC_DATA_T(diff_dst[dst_off + sglid]); )==" "\n" |
212 | R"==(#else )==" "\n" |
213 | R"==(VECT_DEF_ACC_DATA_T dst_val = 0; )==" "\n" |
214 | R"==(for (int idx = 0; )==" "\n" |
215 | R"==(idx < VECT_DT_N && c + GWS_SGS_DEFAULT * idx < C; )==" "\n" |
216 | R"==(idx++) { )==" "\n" |
217 | R"==(dst_val[idx] = TO_DEF_ACC_DATA_T( )==" "\n" |
218 | R"==(diff_dst[dst_off + sglid + GWS_SGS_DEFAULT * idx]); )==" "\n" |
219 | R"==(} )==" "\n" |
220 | R"==(#endif )==" "\n" |
221 | R"==(float Wiw = w_list[c3][k]; )==" "\n" |
222 | R"==(src_val += dst_val * Wid * Wih * Wiw; )==" "\n" |
223 | R"==(} )==" "\n" |
224 | R"==(} )==" "\n" |
225 | R"==(} )==" "\n" |
226 | R"==(#endif )==" "\n" |
// Store the accumulated value(s) to diff_src. When VECT_DT_N is not 1, each
// lane writes elements GWS_SGS_DEFAULT apart while c + GWS_SGS_DEFAULT * i < C.
227 | R"==(#if VECT_DT_N == 1 )==" "\n" |
228 | R"==(diff_src[src_index + sglid] = TO_DST(src_val); )==" "\n" |
229 | R"==(#else )==" "\n" |
230 | R"==(for (int i = 0; i < VECT_DT_N && c + GWS_SGS_DEFAULT * i < C; i++) { )==" "\n" |
231 | R"==(diff_src[src_index + sglid + GWS_SGS_DEFAULT * i] = TO_DST(src_val[i]); )==" "\n" |
232 | R"==(} )==" "\n" |
233 | R"==(#endif )==" "\n" |
234 | R"==(} )==" "\n" |
235 | R"==(#endif )==" "\n" |
236 | R"==()==" ; |
237 | } |
238 | } |
239 | } |
240 | } |