1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <cassert>
18
19#include "common/bfloat16.hpp"
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/math_utils.hpp"
23#include "common/type_helpers.hpp"
24
25#include "cpu/simple_resampling.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30
31using namespace format_tag;
32using namespace resampling_utils;
33using namespace std::placeholders;
34
35using namespace resampling_utils;
36
37namespace {
38
39template <data_type_t src_type, data_type_t dst_type>
40struct simple_resampling_kernel_t : public simple_resampling_base_t {
41 simple_resampling_kernel_t(const resampling_pd_t *pd);
42
43 using src_data_t = typename prec_traits<src_type>::type;
44 using dst_data_t = typename prec_traits<dst_type>::type;
45
46 status_t init() override;
47 status_t execute(const exec_ctx_t &ctx) const override;
48
49private:
50 using interpolate_fn_t = std::function<void(const src_data_t *,
51 dst_data_t *, ref_post_ops_t::args_t &, dim_t, dim_t, dim_t)>;
52
53 void fill_coeffs();
54 void fill_weights();
55 interpolate_fn_t create_nearest() const;
56 interpolate_fn_t create_linear() const;
57 interpolate_fn_t create_bilinear() const;
58 interpolate_fn_t create_trilinear() const;
59
60 // For fwd processing:
61 const bool are_postops_set_;
62 const ref_post_ops_t ref_post_ops_;
63 std::vector<linear_coeffs_t> linear_coeffs_;
64
65 // For bwd processing:
66 std::vector<float> bwd_linear_weights_;
67 std::vector<bwd_linear_coeffs_t> bwd_linear_coeffs_;
68
69 interpolate_fn_t interpolate_fn_;
70};
71
72template <data_type_t src_type, data_type_t dst_type>
73simple_resampling_kernel_t<src_type, dst_type>::simple_resampling_kernel_t(
74 const resampling_pd_t *pd)
75 : simple_resampling_base_t(pd)
76 , are_postops_set_(!(pd_->attr()->post_ops_.entry_.empty()))
77 , ref_post_ops_(pd_->attr()->post_ops_) {
78 if (pd_->is_fwd()) {
79 const memory_desc_wrapper src_d(pd_->src_md());
80 inner_stride_ = src_d.blocking_desc().strides[pd_->ndims() - 1];
81 nsp_outer_ = src_d.nelems(true)
82 / (pd_->ID() * pd_->IH() * pd_->IW() * inner_stride_);
83 stride_d_ = pd_->IH() * pd_->IW() * inner_stride_;
84 stride_h_ = pd_->IW() * inner_stride_;
85 stride_w_ = inner_stride_;
86 } else {
87 const memory_desc_wrapper diff_src_d(pd_->diff_src_md());
88 inner_stride_ = diff_src_d.blocking_desc().strides[pd_->ndims() - 1];
89 nsp_outer_ = diff_src_d.nelems(true)
90 / (pd_->ID() * pd_->IH() * pd_->IW() * inner_stride_);
91 stride_d_ = pd_->OH() * pd_->OW() * inner_stride_;
92 stride_h_ = pd_->OW() * inner_stride_;
93 stride_w_ = inner_stride_;
94 }
95}
96
97template <data_type_t src_type, data_type_t dst_type>
98status_t simple_resampling_kernel_t<src_type, dst_type>::init() {
99 if (pd_->desc()->alg_kind == alg_kind::resampling_nearest)
100 interpolate_fn_ = create_nearest();
101 else {
102 if (pd_->ndims() == 5)
103 interpolate_fn_ = create_trilinear();
104 else if (pd_->ndims() == 4)
105 interpolate_fn_ = create_bilinear();
106 else
107 interpolate_fn_ = create_linear();
108
109 fill_coeffs();
110 if (!pd_->is_fwd()) fill_weights();
111 }
112
113 return status::success;
114}
115
116template <data_type_t src_type, data_type_t dst_type>
117status_t simple_resampling_kernel_t<src_type, dst_type>::execute(
118 const exec_ctx_t &ctx) const {
119 const int OD = pd_->OD();
120 const int OH = pd_->OH();
121 const int OW = pd_->OW();
122 const int ID = pd_->ID();
123 const int IH = pd_->IH();
124 const int IW = pd_->IW();
125
126 if (pd_->is_fwd()) {
127 const auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
128 auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
129
130 parallel_nd(nsp_outer_, OD, OH, [&](dim_t nsp0, dim_t od, dim_t oh) {
131 ref_post_ops_t::args_t postops_args;
132 postops_args.ctx = &ctx;
133 postops_args.dst_md = pd_->dst_md();
134
135 for (dim_t ow = 0; ow < OW; ow++) {
136 const dim_t src_off = nsp0 * ID * IH * IW * inner_stride_;
137 const dim_t dst_off
138 = (nsp0 * OD * OH * OW + od * OH * OW + oh * OW + ow)
139 * inner_stride_;
140
141 postops_args.l_offset = dst_off;
142
143 interpolate_fn_(
144 src + src_off, dst + dst_off, postops_args, od, oh, ow);
145 }
146 });
147 } else {
148 const auto diff_dst = CTX_IN_MEM(const src_data_t *, DNNL_ARG_DIFF_DST);
149 auto diff_src = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DIFF_SRC);
150 ref_post_ops_t::args_t empty_args;
151
152 parallel_nd(nsp_outer_, ID, IH, IW,
153 [&](dim_t nsp, dim_t id, dim_t ih, dim_t iw) {
154 const dim_t diff_dst_off
155 = nsp * OD * OH * OW * inner_stride_;
156 const dim_t diff_src_off
157 = (nsp * ID * IH * IW + id * IH * IW + ih * IW + iw)
158 * inner_stride_;
159 interpolate_fn_(diff_dst + diff_dst_off,
160 diff_src + diff_src_off, empty_args, id, ih, iw);
161 });
162 }
163
164 return status::success;
165}
166
167template <data_type_t src_type, data_type_t dst_type>
168void simple_resampling_kernel_t<src_type, dst_type>::fill_coeffs() {
169 if (pd_->is_fwd()) {
170 linear_coeffs_.reserve(pd_->OD() + pd_->OH() + pd_->OW());
171 for (dim_t od = 0; od < pd_->OD(); od++)
172 linear_coeffs_.emplace_back(
173 linear_coeffs_t(od, pd_->OD(), pd_->ID()));
174 for (dim_t oh = 0; oh < pd_->OH(); oh++)
175 linear_coeffs_.emplace_back(
176 linear_coeffs_t(oh, pd_->OH(), pd_->IH()));
177 for (dim_t ow = 0; ow < pd_->OW(); ow++)
178 linear_coeffs_.emplace_back(
179 linear_coeffs_t(ow, pd_->OW(), pd_->IW()));
180 } else {
181 bwd_linear_coeffs_.reserve(pd_->ID() + pd_->IH() + pd_->IW());
182 for (dim_t id = 0; id < pd_->ID(); id++)
183 bwd_linear_coeffs_.emplace_back(
184 bwd_linear_coeffs_t(id, pd_->OD(), pd_->ID()));
185 for (dim_t ih = 0; ih < pd_->IH(); ih++)
186 bwd_linear_coeffs_.emplace_back(
187 bwd_linear_coeffs_t(ih, pd_->OH(), pd_->IH()));
188 for (dim_t iw = 0; iw < pd_->IW(); iw++)
189 bwd_linear_coeffs_.emplace_back(
190 bwd_linear_coeffs_t(iw, pd_->OW(), pd_->IW()));
191 }
192}
193
194template <data_type_t src_type, data_type_t dst_type>
195void simple_resampling_kernel_t<src_type, dst_type>::fill_weights() {
196 assert(!pd_->is_fwd() && "The function is used in bwd path only.");
197
198 using namespace resampling_utils;
199 bwd_linear_weights_.reserve(2 * (pd_->OD() + pd_->OH() + pd_->OW()));
200 for (dim_t od = 0; od < pd_->OD(); od++) {
201 bwd_linear_weights_.emplace_back(
202 linear_weight(0, od, pd_->OD(), pd_->ID()));
203 bwd_linear_weights_.emplace_back(
204 linear_weight(1, od, pd_->OD(), pd_->ID()));
205 }
206 for (dim_t oh = 0; oh < pd_->OH(); oh++) {
207 bwd_linear_weights_.emplace_back(
208 linear_weight(0, oh, pd_->OH(), pd_->IH()));
209 bwd_linear_weights_.emplace_back(
210 linear_weight(1, oh, pd_->OH(), pd_->IH()));
211 }
212 for (dim_t ow = 0; ow < pd_->OW(); ow++) {
213 bwd_linear_weights_.emplace_back(
214 linear_weight(0, ow, pd_->OW(), pd_->IW()));
215 bwd_linear_weights_.emplace_back(
216 linear_weight(1, ow, pd_->OW(), pd_->IW()));
217 }
218}
219
220template <data_type_t src_type, data_type_t dst_type>
221typename simple_resampling_kernel_t<src_type, dst_type>::interpolate_fn_t
222simple_resampling_kernel_t<src_type, dst_type>::create_nearest() const {
223 if (pd_->is_fwd()) {
224 return [&](const src_data_t *src, dst_data_t *dst,
225 ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh,
226 dim_t ow) {
227 const dim_t id = nearest_idx(od, pd_->OD(), pd_->ID());
228 const dim_t ih = nearest_idx(oh, pd_->OH(), pd_->IH());
229 const dim_t iw = nearest_idx(ow, pd_->OW(), pd_->IW());
230 const dim_t offset
231 = id * stride_d_ + ih * stride_h_ + iw * stride_w_;
232
233 PRAGMA_OMP_SIMD()
234 for (dim_t innermost_el = 0; innermost_el < inner_stride_;
235 innermost_el++) {
236 float res = static_cast<float>(src[offset + innermost_el]);
237
238 if (are_postops_set_) {
239 po_args.dst_val = dst[innermost_el];
240 ref_post_ops_.execute(res, po_args);
241 po_args.l_offset++;
242 }
243
244 dst[innermost_el] = cpu::saturate_and_round<dst_data_t>(res);
245 }
246 };
247 } else {
248 return [&](const src_data_t *diff_dst, dst_data_t *diff_src,
249 ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih,
250 dim_t iw) {
251 auto ow_idx = [&](const float in_idx) -> dim_t {
252 return ceil_idx((in_idx * pd_->OW() / pd_->IW()) - 0.5f);
253 };
254 auto oh_idx = [&](const float in_idx) -> dim_t {
255 return ceil_idx((in_idx * pd_->OH() / pd_->IH()) - 0.5f);
256 };
257 auto od_idx = [&](const float in_idx) -> dim_t {
258 return ceil_idx((in_idx * pd_->OD() / pd_->ID()) - 0.5f);
259 };
260
261 const dim_t ow_start = ow_idx(iw) * stride_w_;
262 const dim_t oh_start = oh_idx(ih) * stride_h_;
263 const dim_t od_start = od_idx(id) * stride_d_;
264 const dim_t ow_end = ow_idx(iw + 1.f) * stride_w_;
265 const dim_t oh_end = oh_idx(ih + 1.f) * stride_h_;
266 const dim_t od_end = od_idx(id + 1.f) * stride_d_;
267
268 PRAGMA_OMP_SIMD()
269 for (dim_t innermost_el = 0; innermost_el < inner_stride_;
270 innermost_el++) {
271 float sum = 0;
272 for_(dim_t od = od_start; od < od_end; od += stride_d_)
273 for_(dim_t oh = oh_start; oh < oh_end; oh += stride_h_)
274 for (dim_t ow = ow_start; ow < ow_end; ow += stride_w_) {
275 sum += static_cast<float>(
276 diff_dst[od + oh + ow + innermost_el]);
277 }
278 diff_src[innermost_el]
279 = cpu::saturate_and_round<dst_data_t>(sum);
280 }
281 };
282 }
283}
284
285template <data_type_t src_type, data_type_t dst_type>
286typename simple_resampling_kernel_t<src_type, dst_type>::interpolate_fn_t
287simple_resampling_kernel_t<src_type, dst_type>::create_linear() const {
288 if (pd_->is_fwd()) {
289 return [&](const src_data_t *src, dst_data_t *dst,
290 ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh,
291 dim_t ow) {
292 const linear_coeffs_t &iw
293 = linear_coeffs_[pd_->OD() + pd_->OH() + ow];
294
295 PRAGMA_OMP_SIMD()
296 for (dim_t innermost_el = 0; innermost_el < inner_stride_;
297 innermost_el++) {
298 float res = 0;
299 for (int k = 0; k < 2; k++)
300 res += static_cast<float>(
301 src[iw.idx[k] * stride_w_ + innermost_el])
302 * iw.wei[k];
303
304 if (are_postops_set_) {
305 po_args.dst_val = dst[innermost_el];
306 ref_post_ops_.execute(res, po_args);
307 po_args.l_offset++;
308 }
309
310 dst[innermost_el] = cpu::saturate_and_round<dst_data_t>(res);
311 }
312 };
313 } else {
314 return [&](const src_data_t *diff_dst, dst_data_t *diff_src,
315 ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih,
316 dim_t iw) {
317 const bwd_linear_coeffs_t &w
318 = bwd_linear_coeffs_[pd_->ID() + pd_->IH() + iw];
319
320 PRAGMA_OMP_SIMD()
321 for (dim_t innermost_el = 0; innermost_el < inner_stride_;
322 innermost_el++) {
323 float sum = 0;
324 for_(int k = 0; k < 2; k++)
325 for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) {
326 sum += static_cast<float>(
327 diff_dst[ow * stride_w_ + innermost_el])
328 * bwd_linear_weights_[2
329 * (pd_->OD() + pd_->OH() + ow)
330 + k];
331 }
332 diff_src[innermost_el]
333 = cpu::saturate_and_round<dst_data_t>(sum);
334 }
335 };
336 }
337}
338
339template <data_type_t src_type, data_type_t dst_type>
340typename simple_resampling_kernel_t<src_type, dst_type>::interpolate_fn_t
341simple_resampling_kernel_t<src_type, dst_type>::create_bilinear() const {
342 if (pd_->is_fwd()) {
343 return [&](const src_data_t *src, dst_data_t *dst,
344 ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh,
345 dim_t ow) {
346 const linear_coeffs_t &ih = linear_coeffs_[pd_->OD() + oh];
347 const linear_coeffs_t &iw
348 = linear_coeffs_[pd_->OD() + pd_->OH() + ow];
349
350 PRAGMA_OMP_SIMD()
351 for (dim_t innermost_el = 0; innermost_el < inner_stride_;
352 innermost_el++) {
353 float res = 0;
354 for_(int j = 0; j < 2; j++)
355 for (int k = 0; k < 2; k++)
356 res += static_cast<float>(src[ih.idx[j] * stride_h_
357 + iw.idx[k] * stride_w_ + innermost_el])
358 * ih.wei[j] * iw.wei[k];
359
360 if (are_postops_set_) {
361 po_args.dst_val = dst[innermost_el];
362 ref_post_ops_.execute(res, po_args);
363 po_args.l_offset++;
364 }
365
366 dst[innermost_el] = cpu::saturate_and_round<dst_data_t>(res);
367 }
368 };
369 } else {
370 return [&](const src_data_t *diff_dst, dst_data_t *diff_src,
371 ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih,
372 dim_t iw) {
373 const bwd_linear_coeffs_t &h = bwd_linear_coeffs_[pd_->ID() + ih];
374 const bwd_linear_coeffs_t &w
375 = bwd_linear_coeffs_[pd_->ID() + pd_->IH() + iw];
376
377 PRAGMA_OMP_SIMD()
378 for (dim_t innermost_el = 0; innermost_el < inner_stride_;
379 innermost_el++) {
380 float sum = 0;
381 for_(int j = 0; j < 2; j++)
382 for_(int k = 0; k < 2; k++)
383 for_(dim_t oh = h.start[j]; oh < h.end[j]; oh++)
384 for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) {
385 sum += static_cast<float>(diff_dst[oh * stride_h_
386 + ow * stride_w_ + innermost_el])
387 * bwd_linear_weights_[2 * (pd_->OD() + oh) + j]
388 * bwd_linear_weights_[2
389 * (pd_->OD() + pd_->OH() + ow)
390 + k];
391 }
392 diff_src[innermost_el]
393 = cpu::saturate_and_round<dst_data_t>(sum);
394 }
395 };
396 }
397}
398
399template <data_type_t src_type, data_type_t dst_type>
400typename simple_resampling_kernel_t<src_type, dst_type>::interpolate_fn_t
401simple_resampling_kernel_t<src_type, dst_type>::create_trilinear() const {
402 if (pd_->is_fwd()) {
403 return [&](const src_data_t *src, dst_data_t *dst,
404 ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh,
405 dim_t ow) {
406 const linear_coeffs_t &id = linear_coeffs_[od];
407 const linear_coeffs_t &ih = linear_coeffs_[pd_->OD() + oh];
408 const linear_coeffs_t &iw
409 = linear_coeffs_[pd_->OD() + pd_->OH() + ow];
410
411 PRAGMA_OMP_SIMD()
412 for (dim_t innermost_el = 0; innermost_el < inner_stride_;
413 innermost_el++) {
414 float res = 0;
415 for_(int i = 0; i < 2; i++)
416 for_(int j = 0; j < 2; j++)
417 for (int k = 0; k < 2; k++)
418 res += static_cast<float>(src[id.idx[i] * stride_d_
419 + ih.idx[j] * stride_h_
420 + iw.idx[k] * stride_w_ + innermost_el])
421 * id.wei[i] * ih.wei[j] * iw.wei[k];
422
423 if (are_postops_set_) {
424 po_args.dst_val = dst[innermost_el];
425 ref_post_ops_.execute(res, po_args);
426 po_args.l_offset++;
427 }
428
429 dst[innermost_el] = cpu::saturate_and_round<dst_data_t>(res);
430 }
431 };
432 } else {
433 return [&](const src_data_t *diff_dst, dst_data_t *diff_src,
434 ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih,
435 dim_t iw) {
436 const bwd_linear_coeffs_t &d = bwd_linear_coeffs_[id];
437 const bwd_linear_coeffs_t &h = bwd_linear_coeffs_[pd_->ID() + ih];
438 const bwd_linear_coeffs_t &w
439 = bwd_linear_coeffs_[pd_->ID() + pd_->IH() + iw];
440
441 PRAGMA_OMP_SIMD()
442 for (dim_t innermost_el = 0; innermost_el < inner_stride_;
443 innermost_el++) {
444 float sum = 0;
445 for_(int i = 0; i < 2; i++)
446 for_(int j = 0; j < 2; j++)
447 for_(int k = 0; k < 2; k++)
448 for_(dim_t od = d.start[i]; od < d.end[i]; od++)
449 for_(dim_t oh = h.start[j]; oh < h.end[j]; oh++)
450 for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) {
451 sum += static_cast<float>(
452 diff_dst[od * stride_d_ + oh * stride_h_
453 + ow * stride_w_ + innermost_el])
454 * bwd_linear_weights_[2 * od + i]
455 * bwd_linear_weights_[2 * (pd_->OD() + oh) + j]
456 * bwd_linear_weights_[2
457 * (pd_->OD() + pd_->OH() + ow)
458 + k];
459 }
460 diff_src[innermost_el]
461 = cpu::saturate_and_round<dst_data_t>(sum);
462 }
463 };
464 }
465}
466
// Explicit instantiations of the kernel for every (source, destination)
// combination of f32, bf16, f16, s32, s8 and u8. The list is grouped by source
// type and covers the same 36 combinations that create_simple_resampling()
// can return.
template struct simple_resampling_kernel_t<data_type::f32, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::bf16, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::f16, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::s32, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::s8, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::u8, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::u8>;
508
509simple_resampling_base_t *create_simple_resampling(const resampling_pd_t *pd,
510 const data_type_t src_dt, const data_type_t dst_dt) {
511 using namespace data_type;
512
513 switch (src_dt) {
514 case f32:
515 switch (dst_dt) {
516 case f32: return new simple_resampling_kernel_t<f32, f32>(pd);
517 case s32: return new simple_resampling_kernel_t<f32, s32>(pd);
518 case bf16: return new simple_resampling_kernel_t<f32, bf16>(pd);
519 case f16: return new simple_resampling_kernel_t<f32, f16>(pd);
520 case s8: return new simple_resampling_kernel_t<f32, s8>(pd);
521 case u8: return new simple_resampling_kernel_t<f32, u8>(pd);
522 default: break;
523 }
524 case s32:
525 switch (dst_dt) {
526 case f32: return new simple_resampling_kernel_t<s32, f32>(pd);
527 case s32: return new simple_resampling_kernel_t<s32, s32>(pd);
528 case bf16: return new simple_resampling_kernel_t<s32, bf16>(pd);
529 case f16: return new simple_resampling_kernel_t<s32, f16>(pd);
530 case s8: return new simple_resampling_kernel_t<s32, s8>(pd);
531 case u8: return new simple_resampling_kernel_t<s32, u8>(pd);
532 default: break;
533 }
534 case bf16:
535 switch (dst_dt) {
536 case f32: return new simple_resampling_kernel_t<bf16, f32>(pd);
537 case s32: return new simple_resampling_kernel_t<bf16, s32>(pd);
538 case bf16:
539 return new simple_resampling_kernel_t<bf16, bf16>(pd);
540 case f16: return new simple_resampling_kernel_t<bf16, f16>(pd);
541 case s8: return new simple_resampling_kernel_t<bf16, s8>(pd);
542 case u8: return new simple_resampling_kernel_t<bf16, u8>(pd);
543 default: break;
544 }
545 case f16:
546 switch (dst_dt) {
547 case f32: return new simple_resampling_kernel_t<f16, f32>(pd);
548 case s32: return new simple_resampling_kernel_t<f16, s32>(pd);
549 case bf16: return new simple_resampling_kernel_t<f16, bf16>(pd);
550 case f16: return new simple_resampling_kernel_t<f16, f16>(pd);
551 case s8: return new simple_resampling_kernel_t<f16, s8>(pd);
552 case u8: return new simple_resampling_kernel_t<f16, u8>(pd);
553 default: break;
554 }
555 case s8:
556 switch (dst_dt) {
557 case f32: return new simple_resampling_kernel_t<s8, f32>(pd);
558 case s32: return new simple_resampling_kernel_t<s8, s32>(pd);
559 case bf16: return new simple_resampling_kernel_t<s8, bf16>(pd);
560 case f16: return new simple_resampling_kernel_t<s8, f16>(pd);
561 case s8: return new simple_resampling_kernel_t<s8, s8>(pd);
562 case u8: return new simple_resampling_kernel_t<s8, u8>(pd);
563 default: break;
564 }
565 case u8:
566 switch (dst_dt) {
567 case f32: return new simple_resampling_kernel_t<u8, f32>(pd);
568 case s32: return new simple_resampling_kernel_t<u8, s32>(pd);
569 case bf16: return new simple_resampling_kernel_t<u8, bf16>(pd);
570 case f16: return new simple_resampling_kernel_t<u8, f16>(pd);
571 case s8: return new simple_resampling_kernel_t<u8, s8>(pd);
572 case u8: return new simple_resampling_kernel_t<u8, u8>(pd);
573 default: break;
574 }
575 default: break;
576 }
577
578 assert(!"Unsupported data type combination.");
579 return nullptr;
580}
581
582} // namespace
583
// Stores the primitive descriptor; the kernel stays null until init() creates
// it.
simple_resampling_fwd_t::simple_resampling_fwd_t(const pd_t *apd)
    : primitive_t(apd), kernel_(nullptr) {}
586
587status_t simple_resampling_fwd_t::init(engine_t *engine) {
588 CHECK(safe_ptr_assign(kernel_,
589 create_simple_resampling(pd(), pd()->src_md()->data_type,
590 pd()->dst_md()->data_type)));
591 return kernel_->init();
592}
593
// Forwards execution to the kernel created by init().
status_t simple_resampling_fwd_t::execute(const exec_ctx_t &ctx) const {
    return kernel_->execute(ctx);
}
597
// Stores the primitive descriptor; the kernel stays null until init() creates
// it.
simple_resampling_bwd_t::simple_resampling_bwd_t(const pd_t *apd)
    : primitive_t(apd), kernel_(nullptr) {}
600
601status_t simple_resampling_bwd_t::init(engine_t *engine) {
602 CHECK(safe_ptr_assign(kernel_,
603 create_simple_resampling(pd(), pd()->diff_dst_md()->data_type,
604 pd()->diff_src_md()->data_type)));
605 return kernel_->init();
606}
607
// Forwards execution to the kernel created by init().
status_t simple_resampling_bwd_t::execute(const exec_ctx_t &ctx) const {
    return kernel_->execute(ctx);
}
611} // namespace cpu
612} // namespace impl
613} // namespace dnnl
614
615// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
616