namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// OpenCL C source for the reference resampling kernels, embedded as a single
// NUL-terminated C string:
//   - ref_resampling_fwd is compiled when IS_FWD == 1;
//   - ref_resampling_bwd is compiled when IS_BWD == 1;
//   - nearest vs. linear interpolation is selected by RESAMPLING_ALG_NEAREST.
// The kernel text references macros it does not define itself (GWS_GET_*,
// SRC_OFF/DST_OFF, ID/IH/IW, OD/OH/OW, FD/FH/FW, NDIMS, DT_S32, ...).
// NOTE(review): presumably these come from the two included headers and/or the
// build options passed when the kernel is compiled -- confirm with the caller.
//
// Each kernel source line is a separate raw string literal followed by "\n".
// Adjacent literals are concatenated at compile time. Every literal must end
// exactly at its `)==` terminator with no padding: a space between the `\` and
// the newline of the RS/RE macro definitions would break the preprocessor line
// splice when the OpenCL compiler processes this text.
// NOTE(review): this layout looks machine-generated from a .cl file; if so,
// change the .cl source / generator rather than editing this file by hand.
const char *ref_resampling_kernel = R"==(/*******************************************************************************)==" "\n"
R"==(* Copyright 2019-2022 Intel Corporation)==" "\n"
R"==(*)==" "\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License");)==" "\n"
R"==(* you may not use this file except in compliance with the License.)==" "\n"
R"==(* You may obtain a copy of the License at)==" "\n"
R"==(*)==" "\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0)==" "\n"
R"==(*)==" "\n"
R"==(* Unless required by applicable law or agreed to in writing, software)==" "\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS,)==" "\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.)==" "\n"
R"==(* See the License for the specific language governing permissions and)==" "\n"
R"==(* limitations under the License.)==" "\n"
R"==(*******************************************************************************/)==" "\n"
R"==(#include "gpu/ocl/ocl_post_ops.h")==" "\n"
R"==(#include "gpu/ocl/ocl_types.h")==" "\n"
R"==(#if IS_FWD == 1)==" "\n"
R"==(KERNEL_ATTR)==" "\n"
R"==(__kernel void ref_resampling_fwd()==" "\n"
R"==(__global const DATA_T *src, __global DST_DATA_T *dst POST_OP_ARGS) {)==" "\n"
R"==(const uint mb = GWS_GET_MB();)==" "\n"
R"==(const uint c = GWS_GET_C();)==" "\n"
R"==(const uint od = GWS_GET_OD();)==" "\n"
R"==(const uint oh = GWS_GET_OH();)==" "\n"
R"==(const uint ow = GWS_GET_OW();)==" "\n"
R"==(const float id = (od + .5f) * ID / OD;)==" "\n"
R"==(const float ih = (oh + .5f) * IH / OH;)==" "\n"
R"==(const float iw = (ow + .5f) * IW / OW;)==" "\n"
R"==(float result;)==" "\n"
R"==(const uint dst_index = DST_OFF(mb, c, od, oh, ow);)==" "\n"
R"==(if (mb >= DST_D0 || c >= DST_D1) {)==" "\n"
R"==(dst[dst_index] = TO_DST(0.0f);)==" "\n"
R"==(return;)==" "\n"
R"==(})==" "\n"
R"==(#if RESAMPLING_ALG_NEAREST)==" "\n"
R"==(const uint src_index = SRC_OFF(mb, c, (uint)id, (uint)ih, (uint)iw);)==" "\n"
R"==(result = CONVERT_FLOAT_T(src[src_index]);)==" "\n"
R"==(#else)==" "\n"
R"==(const int id0 = max((int)floor(id - .5f), 0);)==" "\n"
R"==(const int id1 = min((int)ceil(id - .5f), ID - 1);)==" "\n"
R"==(const int ih0 = max((int)floor(ih - .5f), 0);)==" "\n"
R"==(const int ih1 = min((int)ceil(ih - .5f), IH - 1);)==" "\n"
R"==(const int iw0 = max((int)floor(iw - .5f), 0);)==" "\n"
R"==(const int iw1 = min((int)ceil(iw - .5f), IW - 1);)==" "\n"
R"==(const float wd[2] = {1.0f - fabs(id - .5f - id0), fabs(id - .5f - id0)};)==" "\n"
R"==(const float wh[2] = {1.0f - fabs(ih - .5f - ih0), fabs(ih - .5f - ih0)};)==" "\n"
R"==(const float ww[2] = {1.0f - fabs(iw - .5f - iw0), fabs(iw - .5f - iw0)};)==" "\n"
R"==(const int ih_arr[2] = {ih0, ih1};)==" "\n"
R"==(const int iw_arr[2] = {iw0, iw1};)==" "\n"
R"==(float cd[2][2];)==" "\n"
R"==(for_(int i = 0; i < 2; i++))==" "\n"
R"==(for (int j = 0; j < 2; j++))==" "\n"
R"==(cd[i][j] = CONVERT_FLOAT_T()==" "\n"
R"==(src[SRC_OFF(mb, c, id0, ih_arr[i], iw_arr[j])]))==" "\n"
R"==(* wd[0])==" "\n"
R"==(+ CONVERT_FLOAT_T()==" "\n"
R"==(src[SRC_OFF(mb, c, id1, ih_arr[i], iw_arr[j])]))==" "\n"
R"==(* wd[1];)==" "\n"
R"==(float ch[2];)==" "\n"
R"==(for (int i = 0; i < 2; i++))==" "\n"
R"==(ch[i] = cd[0][i] * wh[0] + cd[1][i] * wh[1];)==" "\n"
R"==(result = ch[0] * ww[0] + ch[1] * ww[1];)==" "\n"
R"==(#endif)==" "\n"
R"==(float sum_src;)==" "\n"
R"==(#if WITH_SUM)==" "\n"
R"==(sum_src = DST_TO_REF(dst[dst_index]);)==" "\n"
R"==(#endif)==" "\n"
R"==(#if NDIMS == 3)==" "\n"
R"==(const unsigned po_d2 = ow;)==" "\n"
R"==(const unsigned po_d3 = 0;)==" "\n"
R"==(const unsigned po_d4 = 0;)==" "\n"
R"==(#elif NDIMS == 4)==" "\n"
R"==(const unsigned po_d2 = oh;)==" "\n"
R"==(const unsigned po_d3 = ow;)==" "\n"
R"==(const unsigned po_d4 = 0;)==" "\n"
R"==(#elif NDIMS == 5)==" "\n"
R"==(const unsigned po_d2 = od;)==" "\n"
R"==(const unsigned po_d3 = oh;)==" "\n"
R"==(const unsigned po_d4 = ow;)==" "\n"
R"==(#else)==" "\n"
R"==(const unsigned po_d2 = 0;)==" "\n"
R"==(const unsigned po_d3 = 0;)==" "\n"
R"==(const unsigned po_d4 = 0;)==" "\n"
R"==(#endif)==" "\n"
R"==(APPLY_POST_OPS_SERIAL(result, float, sum_src, float, mb, 1, c, 1, po_d2, 1,)==" "\n"
R"==(po_d3, 1, po_d4, 1, 0, 1);)==" "\n"
R"==(dst[dst_index] = TO_DST(result);)==" "\n"
R"==(})==" "\n"
R"==(#endif)==" "\n"
R"==(#if IS_BWD == 1)==" "\n"
R"==(float linear(int x, int fo, int fi) {)==" "\n"
R"==(return ((x + .5f) * fo / fi) - .5f;)==" "\n"
R"==(})==" "\n"
R"==(KERNEL_ATTR)==" "\n"
R"==(__kernel void ref_resampling_bwd()==" "\n"
R"==(__global DATA_T *diff_src, __global const DST_DATA_T *diff_dst) {)==" "\n"
R"==(#define CEIL(x) max((int)ceil(x), (int)0))==" "\n"
R"==(#define L(x, fo, fi) linear(x, fo, fi))==" "\n"
R"==(#define LS(x, fo, fi) CEIL(L(x, fo, fi)))==" "\n"
R"==(#define RS(x, fo, fi) \)==" "\n"
R"==(L((int)x - 1, fo, fi) < 0 ? 0 : (int)(L((int)x - 1, fo, fi)) + 1)==" "\n"
R"==(#define LE(x, fo, fi, lim) min(CEIL(L(x + 1, fo, fi)), (int)lim))==" "\n"
R"==(#define RE(x, fo, fi, lim) \)==" "\n"
R"==(min((L(x, fo, fi) < 0 ? 0 : (int)(L(x, fo, fi)) + 1), (int)lim))==" "\n"
R"==(const uint mb = GWS_GET_MB();)==" "\n"
R"==(const uint c = GWS_GET_C();)==" "\n"
R"==(const uint id = GWS_GET_ID();)==" "\n"
R"==(const uint ih = GWS_GET_IH();)==" "\n"
R"==(const uint iw = GWS_GET_IW();)==" "\n"
R"==(const uint src_index = SRC_OFF(mb, c, id, ih, iw);)==" "\n"
R"==(if (mb >= DST_D0 || c >= DST_D1) {)==" "\n"
R"==(diff_src[src_index] = TO_DST(0.f);)==" "\n"
R"==(return;)==" "\n"
R"==(})==" "\n"
R"==(#if RESAMPLING_ALG_NEAREST)==" "\n"
R"==(int od_start = CEIL(id * FD - .5f);)==" "\n"
R"==(int oh_start = CEIL(ih * FH - .5f);)==" "\n"
R"==(int ow_start = CEIL(iw * FW - .5f);)==" "\n"
R"==(int od_end = CEIL((id + 1.f) * FD - .5f);)==" "\n"
R"==(int oh_end = CEIL((ih + 1.f) * FH - .5f);)==" "\n"
R"==(int ow_end = CEIL((iw + 1.f) * FW - .5f);)==" "\n"
R"==(float src_val = 0;)==" "\n"
R"==(for (int i = od_start; i < od_end; i++) {)==" "\n"
R"==(for (int j = oh_start; j < oh_end; j++) {)==" "\n"
R"==(for (int k = ow_start; k < ow_end; k++) {)==" "\n"
R"==(const int dst_index = DST_OFF(mb, c, i, j, k);)==" "\n"
R"==(src_val += DST_TO_REF(diff_dst[dst_index]);)==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(#else)==" "\n"
R"==(int left_sd = id == 0 ? 0 : LS(id, OD, ID);)==" "\n"
R"==(int left_sh = ih == 0 ? 0 : LS(ih, OH, IH);)==" "\n"
R"==(int left_sw = iw == 0 ? 0 : LS(iw, OW, IW);)==" "\n"
R"==(int right_sd = RS(id, OD, ID);)==" "\n"
R"==(int right_sh = RS(ih, OH, IH);)==" "\n"
R"==(int right_sw = RS(iw, OW, IW);)==" "\n"
R"==(int left_ed = LE(id, OD, ID, OD);)==" "\n"
R"==(int left_eh = LE(ih, OH, IH, OH);)==" "\n"
R"==(int left_ew = LE(iw, OW, IW, OW);)==" "\n"
R"==(int right_ed = id == (ID - 1) ? OD : RE(id, OD, ID, OD);)==" "\n"
R"==(int right_eh = ih == (IH - 1) ? OH : RE(ih, OH, IH, OH);)==" "\n"
R"==(int right_ew = iw == (IW - 1) ? OW : RE(iw, OW, IW, OW);)==" "\n"
R"==(int od_start[2] = {left_sd, right_sd};)==" "\n"
R"==(int oh_start[2] = {left_sh, right_sh};)==" "\n"
R"==(int ow_start[2] = {left_sw, right_sw};)==" "\n"
R"==(int od_end[2] = {left_ed, right_ed};)==" "\n"
R"==(int oh_end[2] = {left_eh, right_eh};)==" "\n"
R"==(int ow_end[2] = {left_ew, right_ew};)==" "\n"
R"==(float src_val = 0.0f;)==" "\n"
R"==(for (int c1 = 0; c1 < 2; c1++) {)==" "\n"
R"==(for (int c2 = 0; c2 < 2; c2++) {)==" "\n"
R"==(for (int c3 = 0; c3 < 2; c3++) {)==" "\n"
R"==(for (int i = od_start[c1]; i < od_end[c1]; i++) {)==" "\n"
R"==(for (int j = oh_start[c2]; j < oh_end[c2]; j++) {)==" "\n"
R"==(for (int k = ow_start[c3]; k < ow_end[c3]; k++) {)==" "\n"
R"==(float dst_val = DST_TO_REF()==" "\n"
R"==(diff_dst[DST_OFF(mb, c, i, j, k)]);)==" "\n"
R"==(float d = L(i, ID, OD);)==" "\n"
R"==(float h = L(j, IH, OH);)==" "\n"
R"==(float w = L(k, IW, OW);)==" "\n"
R"==(float Wid = c1 == 0 ? 1.f - fabs(d - (int)d))==" "\n"
R"==(: fabs(d - (int)d);)==" "\n"
R"==(float Wih = c2 == 0 ? 1.f - fabs(h - (int)h))==" "\n"
R"==(: fabs(h - (int)h);)==" "\n"
R"==(float Wiw = c3 == 0 ? 1.f - fabs(w - (int)w))==" "\n"
R"==(: fabs(w - (int)w);)==" "\n"
R"==(src_val += dst_val * Wid * Wih * Wiw;)==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(#endif)==" "\n"
R"==(#if DT_S32 == 1)==" "\n"
R"==(diff_src[src_index] = CONVERT_DATA_T(src_val);)==" "\n"
R"==(#else)==" "\n"
R"==(diff_src[src_index] = TO_DATA_T(src_val);)==" "\n"
R"==(#endif)==" "\n"
R"==(})==" "\n"
R"==(#endif)==" "\n"
R"==()==";

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl