1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cassert> |
18 | #include <cfloat> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/math_utils.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | #include "cpu/resampling_utils.hpp" |
26 | |
27 | #include "cpu/ref_resampling.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | |
using namespace resampling_utils;

// Raw byte pointer type used for type-erased tensor access.
using byte = unsigned char;
// Reads the element at `offset` (counted in elements, not bytes) from `base`
// and widens it to float.
using load_fn_t = std::function<float(const byte *base, const dim_t offset)>;
// Converts `val` to the destination data type and writes it at `offset`
// (counted in elements) from `base`.
using store_fn_t
        = std::function<void(const float val, byte *base, const dim_t offset)>;
39 | |
40 | namespace { |
41 | template <data_type_t type> |
42 | load_fn_t create_load() { |
43 | return [](const byte *base, dim_t offset) -> float { |
44 | return static_cast<float>( |
45 | reinterpret_cast<const typename prec_traits<type>::type *>( |
46 | base)[offset]); |
47 | }; |
48 | } |
49 | template <> |
50 | load_fn_t create_load<data_type::f32>() { |
51 | return [](const byte *base, dim_t offset) -> float { |
52 | return reinterpret_cast<const float *>(base)[offset]; |
53 | }; |
54 | } |
55 | template <data_type_t type> |
56 | store_fn_t create_store() { |
57 | using dst_t = typename prec_traits<type>::type; |
58 | return [](const float val, byte *base, const dim_t offset) { |
59 | *reinterpret_cast<dst_t *>(base + sizeof(dst_t) * offset) |
60 | = cpu::saturate_and_round<dst_t>(val); |
61 | }; |
62 | } |
63 | template <> |
64 | store_fn_t create_store<data_type::f32>() { |
65 | return [](const float val, byte *base, const dim_t offset) { |
66 | *reinterpret_cast<float *>(base + sizeof(float) * offset) = val; |
67 | }; |
68 | } |
69 | } // namespace |
70 | |
71 | static load_fn_t create_load(const data_type_t src_dtype) { |
72 | using namespace data_type; |
73 | |
74 | switch (src_dtype) { |
75 | case f32: return create_load<f32>(); |
76 | case s32: return create_load<s32>(); |
77 | case bf16: return create_load<bf16>(); |
78 | case f16: return create_load<f16>(); |
79 | case s8: return create_load<s8>(); |
80 | case u8: return create_load<u8>(); |
81 | default: assert(!"Unsupported data type." ); |
82 | } |
83 | return create_load<f32>(); |
84 | } |
85 | |
86 | static store_fn_t create_store(const data_type_t dst_dtype) { |
87 | using namespace data_type; |
88 | |
89 | switch (dst_dtype) { |
90 | case f32: return create_store<f32>(); |
91 | case s32: return create_store<s32>(); |
92 | case bf16: return create_store<bf16>(); |
93 | case f16: return create_store<f16>(); |
94 | case s8: return create_store<s8>(); |
95 | case u8: return create_store<u8>(); |
96 | default: assert(!"Unsupported data type." ); |
97 | } |
98 | return create_store<f32>(); |
99 | } |
100 | |
101 | static dim_t get_offset( |
102 | const memory_desc_wrapper &data_d, int n, int c, int d, int h, int w) { |
103 | if (data_d.ndims() == 5) return data_d.off(n, c, d, h, w); |
104 | if (data_d.ndims() == 4) return data_d.off(n, c, h, w); |
105 | return data_d.off(n, c, w); |
106 | } |
107 | |
// Forward primitive constructor: the reference post-ops executor is built
// once here from the attribute's post-ops chain.
ref_resampling_fwd_t::ref_resampling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}

ref_resampling_fwd_t::~ref_resampling_fwd_t() = default;
112 | |
113 | void ref_resampling_fwd_t::execute_forward(const exec_ctx_t &ctx) const { |
114 | if (this->pd()->has_zero_dim_memory()) return; |
115 | status_t status = status::success; |
116 | const auto src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC); |
117 | auto dst = CTX_OUT_CLEAN_MEM(byte *, DNNL_ARG_DST, status); |
118 | |
119 | const memory_desc_wrapper src_d(pd()->src_md()); |
120 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
121 | |
122 | const data_type_t src_dt = pd()->src_md()->data_type; |
123 | const data_type_t dst_dt = pd()->dst_md()->data_type; |
124 | |
125 | load_fn_t load_fn = create_load(src_dt); |
126 | store_fn_t store_fn = create_store(dst_dt); |
127 | |
128 | const auto alg = pd()->desc()->alg_kind; |
129 | |
130 | const int MB = pd()->MB(); |
131 | const int C = pd()->C(); |
132 | const int ID = pd()->ID(); |
133 | const int IH = pd()->IH(); |
134 | const int IW = pd()->IW(); |
135 | const int OD = pd()->OD(); |
136 | const int OH = pd()->OH(); |
137 | const int OW = pd()->OW(); |
138 | |
139 | auto lin_interp = [&](float c0, float c1, float w) { |
140 | return c0 * w + c1 * (1 - w); |
141 | }; |
142 | auto bilin_interp = [&](float c00, float c01, float c10, float c11, |
143 | float w0, float w1) { |
144 | return lin_interp( |
145 | lin_interp(c00, c10, w0), lin_interp(c01, c11, w0), w1); |
146 | }; |
147 | auto trilin_interp = [&](float c000, float c001, float c010, float c011, |
148 | float c100, float c101, float c110, float c111, |
149 | float w0, float w1, float w2) { |
150 | return lin_interp(bilin_interp(c000, c010, c100, c110, w0, w1), |
151 | bilin_interp(c001, c011, c101, c111, w0, w1), w2); |
152 | }; |
153 | |
154 | parallel_nd(MB, C, OD, OH, OW, |
155 | [&](dim_t mb, dim_t ch, dim_t od, dim_t oh, dim_t ow) { |
156 | const dim_t data_p_off = get_offset(dst_d, mb, ch, od, oh, ow); |
157 | const dim_t data_l_off |
158 | = (((mb * C + ch) * OD + od) * OH + oh) * OW + ow; |
159 | float res = 0.f; |
160 | |
161 | if (alg == alg_kind::resampling_nearest) { |
162 | const dim_t id = nearest_idx(od, OD, ID); |
163 | const dim_t ih = nearest_idx(oh, OH, IH); |
164 | const dim_t iw = nearest_idx(ow, OW, IW); |
165 | res = load_fn(src, get_offset(src_d, mb, ch, id, ih, iw)); |
166 | } else if (alg == alg_kind::resampling_linear) { |
167 | // Trilinear interpolation (linear interpolation on a 3D spatial |
168 | // tensor) can be expressed as linear interpolation along |
169 | // dimension x followed by interpolation along dimension y and z |
170 | // C011--C11--C111 |
171 | // - - | |
172 | // - - | |
173 | //C001--C01--C111 | |
174 | // - .C - C110 |
175 | // - - - |
176 | // - - - |
177 | //C000--C00--C100 |
178 | auto id = linear_coeffs_t(od, OD, ID); |
179 | auto iw = linear_coeffs_t(ow, OW, IW); |
180 | auto ih = linear_coeffs_t(oh, OH, IH); |
181 | float src_l[8] = {0}; |
182 | for_(int i = 0; i < 2; i++) |
183 | for_(int j = 0; j < 2; j++) |
184 | for (int k = 0; k < 2; k++) { |
185 | src_l[4 * i + 2 * j + k] = load_fn(src, |
186 | get_offset(src_d, mb, ch, id.idx[i], ih.idx[j], |
187 | iw.idx[k])); |
188 | } |
189 | res = trilin_interp(src_l[0], src_l[1], src_l[2], src_l[3], |
190 | src_l[4], src_l[5], src_l[6], src_l[7], id.wei[0], |
191 | ih.wei[0], iw.wei[0]); |
192 | } |
193 | |
194 | ref_post_ops_t::args_t args; |
195 | args.ctx = &ctx; |
196 | args.dst_md = pd()->dst_md(); |
197 | args.l_offset = data_l_off; |
198 | args.dst_val = dst[data_p_off]; |
199 | ref_post_ops_.execute(res, args); |
200 | |
201 | store_fn(res, dst, data_p_off); |
202 | }); |
203 | } |
204 | |
// Backward primitive constructor: no post-ops on the backward pass, so only
// the base primitive is initialized.
ref_resampling_bwd_t::ref_resampling_bwd_t(const pd_t *apd)
    : primitive_t(apd) {}

ref_resampling_bwd_t::~ref_resampling_bwd_t() = default;
209 | |
// Reference backward resampling: for every input (diff_src) element,
// accumulates the output gradients (diff_dst) that the forward pass would
// have derived from it, weighted appropriately for the linear algorithm.
void ref_resampling_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
    if (this->pd()->has_zero_dim_memory()) return;
    status_t status = status::success;
    const auto diff_dst = CTX_IN_MEM(const byte *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_CLEAN_MEM(byte *, DNNL_ARG_DIFF_SRC, status);

    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const data_type_t diff_dst_dt = pd()->diff_dst_md()->data_type;
    const data_type_t diff_src_dt = pd()->diff_src_md()->data_type;

    // Type-erased element accessors: load reads diff_dst elements as float,
    // store converts the accumulated float back to the diff_src data type.
    load_fn_t load_fn = create_load(diff_dst_dt);
    store_fn_t store_fn = create_store(diff_src_dt);

    const auto alg = pd()->desc()->alg_kind;

    const int MB = pd()->MB();
    const int C = pd()->C();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();

    if (alg == alg_kind::resampling_nearest) {
        // Nearest: one input element is the nearest neighbor of a contiguous
        // half-open range of output coordinates per dimension. The bounds
        // below invert the forward mapping (out -> round(out * I / O)), so
        // range [x_start, x_end) covers exactly the outputs that map to x.
        parallel_nd(MB, C, ID, IH, IW,
                [&](dim_t mb, dim_t ch, dim_t id, dim_t ih, dim_t iw) {
                    const dim_t od_start
                            = ceil_idx(((float)id * OD / ID) - 0.5f);
                    const dim_t oh_start
                            = ceil_idx(((float)ih * OH / IH) - 0.5f);
                    const dim_t ow_start
                            = ceil_idx(((float)iw * OW / IW) - 0.5f);

                    const dim_t od_end
                            = ceil_idx(((id + 1.f) * OD / ID) - 0.5f);
                    const dim_t oh_end
                            = ceil_idx(((ih + 1.f) * OH / IH) - 0.5f);
                    const dim_t ow_end
                            = ceil_idx(((iw + 1.f) * OW / IW) - 0.5f);

                    // Sum every output gradient in the covered 3D range.
                    float ds = 0;
                    for_(dim_t od = od_start; od < od_end; od++)
                    for_(dim_t oh = oh_start; oh < oh_end; oh++)
                    for (dim_t ow = ow_start; ow < ow_end; ow++)
                        ds += load_fn(diff_dst,
                                get_offset(diff_dst_d, mb, ch, od, oh, ow));
                    store_fn(ds, diff_src,
                            get_offset(diff_src_d, mb, ch, id, ih, iw));
                });
    } else {
        // Linear: an input element contributes to outputs as either the
        // "left" (i/j/k == 0) or "right" (== 1) interpolation endpoint per
        // dimension; bwd_linear_coeffs_t supplies the output index range for
        // each side (see resampling_utils).
        parallel_nd(MB, C, ID, IH, IW,
                [&](dim_t mb, dim_t ch, dim_t id, dim_t ih, dim_t iw) {
                    bwd_linear_coeffs_t d(id, OD, ID);
                    bwd_linear_coeffs_t h(ih, OH, IH);
                    bwd_linear_coeffs_t w(iw, OW, IW);

                    // Accumulate gradients over all 8 side combinations,
                    // weighting each output gradient by the product of the
                    // per-dimension interpolation weights.
                    float ds = 0;
                    for_(int i = 0; i < 2; i++)
                    for_(int j = 0; j < 2; j++)
                    for_(int k = 0; k < 2; k++)
                    for_(dim_t od = d.start[i]; od < d.end[i]; od++)
                    for_(dim_t oh = h.start[j]; oh < h.end[j]; oh++)
                    for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) {
                        const float weight_d = linear_weight(i, od, OD, ID);
                        const float weight_h = linear_weight(j, oh, OH, IH);
                        const float weight_w = linear_weight(k, ow, OW, IW);

                        float dd = load_fn(diff_dst,
                                get_offset(diff_dst_d, mb, ch, od, oh, ow));
                        ds += dd * weight_d * weight_h * weight_w;
                    }
                    store_fn(ds, diff_src,
                            get_offset(diff_src_d, mb, ch, id, ih, iw));
                });
    }
}
289 | |
290 | } // namespace cpu |
291 | } // namespace impl |
292 | } // namespace dnnl |
293 | |
294 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
295 | |