/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cassert>
#include <cfloat>
#include <functional>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"

#include "cpu/resampling_utils.hpp"

#include "cpu/ref_resampling.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace resampling_utils;

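// The reference implementation computes in f32 regardless of the memory data
// type; the type-erased callbacks below load or store a single element at a
// given element offset, converting between the memory type and f32.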
using byte = unsigned char;
using load_fn_t = std::function<float(const byte *base, const dim_t offset)>;
using store_fn_t
        = std::function<void(const float val, byte *base, const dim_t offset)>;

namespace {
template <data_type_t type>
load_fn_t create_load() {
    return [](const byte *base, dim_t offset) -> float {
        return static_cast<float>(
                reinterpret_cast<const typename prec_traits<type>::type *>(
                        base)[offset]);
    };
}
template <>
load_fn_t create_load<data_type::f32>() {
    return [](const byte *base, dim_t offset) -> float {
        return reinterpret_cast<const float *>(base)[offset];
    };
}
template <data_type_t type>
store_fn_t create_store() {
    using dst_t = typename prec_traits<type>::type;
    return [](const float val, byte *base, const dim_t offset) {
        *reinterpret_cast<dst_t *>(base + sizeof(dst_t) * offset)
                = cpu::saturate_and_round<dst_t>(val);
    };
}
template <>
store_fn_t create_store<data_type::f32>() {
    return [](const float val, byte *base, const dim_t offset) {
        *reinterpret_cast<float *>(base + sizeof(float) * offset) = val;
    };
}
} // namespace

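// Runtime dispatch to the typed load functor above; the f32 fallback after
// the switch keeps release builds, where the assert compiles away, on a
// defined code path for unexpected data types.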
static load_fn_t create_load(const data_type_t src_dtype) {
    using namespace data_type;

    switch (src_dtype) {
        case f32: return create_load<f32>();
        case s32: return create_load<s32>();
        case bf16: return create_load<bf16>();
        case f16: return create_load<f16>();
        case s8: return create_load<s8>();
        case u8: return create_load<u8>();
        default: assert(!"Unsupported data type.");
    }
    return create_load<f32>();
}

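// Runtime dispatch to the typed store functor, mirroring create_load above.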
static store_fn_t create_store(const data_type_t dst_dtype) {
    using namespace data_type;

    switch (dst_dtype) {
        case f32: return create_store<f32>();
        case s32: return create_store<s32>();
        case bf16: return create_store<bf16>();
        case f16: return create_store<f16>();
        case s8: return create_store<s8>();
        case u8: return create_store<u8>();
        default: assert(!"Unsupported data type.");
    }
    return create_store<f32>();
}

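// Map a logical (n, c, d, h, w) index to a physical element offset, dropping
// the unused spatial dimensions for 4D (n, c, h, w) and 3D (n, c, w) tensors.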
static dim_t get_offset(
        const memory_desc_wrapper &data_d, int n, int c, int d, int h, int w) {
    if (data_d.ndims() == 5) return data_d.off(n, c, d, h, w);
    if (data_d.ndims() == 4) return data_d.off(n, c, h, w);
    return data_d.off(n, c, w);
}

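// The post-ops executor is constructed once from the primitive attribute and
// reused across all execute calls.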
ref_resampling_fwd_t::ref_resampling_fwd_t(const pd_t *apd)
    : primitive_t(apd), ref_post_ops_(pd()->attr()->post_ops_) {}

ref_resampling_fwd_t::~ref_resampling_fwd_t() = default;

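// Forward pass: for each destination point, either pick the nearest source
// element (resampling_nearest) or blend up to 8 neighboring source elements
// (resampling_linear), apply post-ops, and store in the destination type.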
void ref_resampling_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    if (this->pd()->has_zero_dim_memory()) return;
    status_t status = status::success;
    const auto src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_CLEAN_MEM(byte *, DNNL_ARG_DST, status);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());

    const data_type_t src_dt = pd()->src_md()->data_type;
    const data_type_t dst_dt = pd()->dst_md()->data_type;

    load_fn_t load_fn = create_load(src_dt);
    // A separate loader typed on dst is needed to read back destination
    // values for the sum post-op.
    load_fn_t load_dst_fn = create_load(dst_dt);
    store_fn_t store_fn = create_store(dst_dt);

    const auto alg = pd()->desc()->alg_kind;

    const int MB = pd()->MB();
    const int C = pd()->C();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();

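    // Interpolation helpers; the weight w belongs to the first coefficient,
    // so lin_interp(c0, c1, w) = c0 * w + c1 * (1 - w). The bilinear and
    // trilinear variants compose the 1D case per dimension.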
    auto lin_interp = [&](float c0, float c1, float w) {
        return c0 * w + c1 * (1 - w);
    };
    auto bilin_interp = [&](float c00, float c01, float c10, float c11,
                                float w0, float w1) {
        return lin_interp(
                lin_interp(c00, c10, w0), lin_interp(c01, c11, w0), w1);
    };
    auto trilin_interp = [&](float c000, float c001, float c010, float c011,
                                 float c100, float c101, float c110,
                                 float c111, float w0, float w1, float w2) {
        return lin_interp(bilin_interp(c000, c010, c100, c110, w0, w1),
                bilin_interp(c001, c011, c101, c111, w0, w1), w2);
    };

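    // One destination element per iteration: data_p_off is its physical
    // offset in dst, data_l_off its logical (dense) offset as consumed by
    // the post-ops machinery.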
    parallel_nd(MB, C, OD, OH, OW,
            [&](dim_t mb, dim_t ch, dim_t od, dim_t oh, dim_t ow) {
                const dim_t data_p_off = get_offset(dst_d, mb, ch, od, oh, ow);
                const dim_t data_l_off
                        = (((mb * C + ch) * OD + od) * OH + oh) * OW + ow;
                float res = 0.f;

                if (alg == alg_kind::resampling_nearest) {
                    const dim_t id = nearest_idx(od, OD, ID);
                    const dim_t ih = nearest_idx(oh, OH, IH);
                    const dim_t iw = nearest_idx(ow, OW, IW);
                    res = load_fn(src, get_offset(src_d, mb, ch, id, ih, iw));
                } else if (alg == alg_kind::resampling_linear) {
                    // Trilinear interpolation (linear interpolation on a 3D
                    // spatial tensor) can be expressed as linear interpolation
                    // along dimension x followed by interpolation along
                    // dimensions y and z:
                    //     C011--C11--C111
                    //     -          - |
                    //     -          - |
                    //C001--C01--C101  |
                    //  -     .C   -  C110
                    //  -          -  -
                    //  -          - -
                    //C000--C00--C100
                    auto id = linear_coeffs_t(od, OD, ID);
                    auto iw = linear_coeffs_t(ow, OW, IW);
                    auto ih = linear_coeffs_t(oh, OH, IH);
                    float src_l[8] = {0};
                    for_(int i = 0; i < 2; i++)
                    for_(int j = 0; j < 2; j++)
                    for (int k = 0; k < 2; k++) {
                        src_l[4 * i + 2 * j + k] = load_fn(src,
                                get_offset(src_d, mb, ch, id.idx[i], ih.idx[j],
                                        iw.idx[k]));
                    }
                    res = trilin_interp(src_l[0], src_l[1], src_l[2], src_l[3],
                            src_l[4], src_l[5], src_l[6], src_l[7], id.wei[0],
                            ih.wei[0], iw.wei[0]);
                }

                ref_post_ops_t::args_t args;
                args.ctx = &ctx;
                args.dst_md = pd()->dst_md();
                args.l_offset = data_l_off;
                // Read the current dst value through the data-type-aware
                // loader; indexing the raw byte pointer directly would
                // misread any dst type wider than one byte.
                args.dst_val = load_dst_fn(dst, data_p_off);
                ref_post_ops_.execute(res, args);

                store_fn(res, dst, data_p_off);
            });
}

ref_resampling_bwd_t::ref_resampling_bwd_t(const pd_t *apd)
    : primitive_t(apd) {}

ref_resampling_bwd_t::~ref_resampling_bwd_t() = default;

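// Backward pass, written gather-style: rather than scattering each diff_dst
// element to the source positions that produced it, every diff_src element
// accumulates the diff_dst elements that reference it, which keeps the
// parallel loop over diff_src race-free.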
void ref_resampling_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
    if (this->pd()->has_zero_dim_memory()) return;
    status_t status = status::success;
    const auto diff_dst = CTX_IN_MEM(const byte *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_CLEAN_MEM(byte *, DNNL_ARG_DIFF_SRC, status);

    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const data_type_t diff_dst_dt = pd()->diff_dst_md()->data_type;
    const data_type_t diff_src_dt = pd()->diff_src_md()->data_type;

    load_fn_t load_fn = create_load(diff_dst_dt);
    store_fn_t store_fn = create_store(diff_src_dt);

    const auto alg = pd()->desc()->alg_kind;

    const int MB = pd()->MB();
    const int C = pd()->C();
    const int ID = pd()->ID();
    const int IH = pd()->IH();
    const int IW = pd()->IW();
    const int OD = pd()->OD();
    const int OH = pd()->OH();
    const int OW = pd()->OW();

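    // Inverting the forward nearest mapping: the destination indices that
    // land on source index id form the half-open interval
    // [ceil(id * OD / ID - 0.5), ceil((id + 1) * OD / ID - 0.5)), and
    // likewise for the h and w dimensions.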
    if (alg == alg_kind::resampling_nearest) {
        parallel_nd(MB, C, ID, IH, IW,
                [&](dim_t mb, dim_t ch, dim_t id, dim_t ih, dim_t iw) {
                    const dim_t od_start
                            = ceil_idx(((float)id * OD / ID) - 0.5f);
                    const dim_t oh_start
                            = ceil_idx(((float)ih * OH / IH) - 0.5f);
                    const dim_t ow_start
                            = ceil_idx(((float)iw * OW / IW) - 0.5f);

                    const dim_t od_end
                            = ceil_idx(((id + 1.f) * OD / ID) - 0.5f);
                    const dim_t oh_end
                            = ceil_idx(((ih + 1.f) * OH / IH) - 0.5f);
                    const dim_t ow_end
                            = ceil_idx(((iw + 1.f) * OW / IW) - 0.5f);

                    float ds = 0;
                    for_(dim_t od = od_start; od < od_end; od++)
                    for_(dim_t oh = oh_start; oh < oh_end; oh++)
                    for (dim_t ow = ow_start; ow < ow_end; ow++)
                        ds += load_fn(diff_dst,
                                get_offset(diff_dst_d, mb, ch, od, oh, ow));
                    store_fn(ds, diff_src,
                            get_offset(diff_src_d, mb, ch, id, ih, iw));
                });
    } else {
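        // Linear backward: source element id contributes to every destination
        // element that uses it as the left (side 0) or right (side 1)
        // neighbor; bwd_linear_coeffs_t supplies the half-open destination
        // ranges per side, and linear_weight recomputes the forward
        // interpolation weight for each (side, destination) pair.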
        parallel_nd(MB, C, ID, IH, IW,
                [&](dim_t mb, dim_t ch, dim_t id, dim_t ih, dim_t iw) {
                    bwd_linear_coeffs_t d(id, OD, ID);
                    bwd_linear_coeffs_t h(ih, OH, IH);
                    bwd_linear_coeffs_t w(iw, OW, IW);

                    float ds = 0;
                    for_(int i = 0; i < 2; i++)
                    for_(int j = 0; j < 2; j++)
                    for_(int k = 0; k < 2; k++)
                    for_(dim_t od = d.start[i]; od < d.end[i]; od++)
                    for_(dim_t oh = h.start[j]; oh < h.end[j]; oh++)
                    for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) {
                        const float weight_d = linear_weight(i, od, OD, ID);
                        const float weight_h = linear_weight(j, oh, OH, IH);
                        const float weight_w = linear_weight(k, ow, OW, IW);

                        float dd = load_fn(diff_dst,
                                get_offset(diff_dst_d, mb, ch, od, oh, ow));
                        ds += dd * weight_d * weight_h * weight_w;
                    }
                    store_fn(ds, diff_src,
                            get_offset(diff_src_d, mb, ch, id, ih, iw));
                });
    }
}

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s