1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cassert> |
18 | |
19 | #include "common/bfloat16.hpp" |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/math_utils.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | #include "cpu/simple_resampling.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | |
31 | using namespace format_tag; |
32 | using namespace resampling_utils; |
33 | using namespace std::placeholders; |
34 | |
35 | using namespace resampling_utils; |
36 | |
37 | namespace { |
38 | |
// Resampling kernel specialized for one (input, output) data type pair.
// In the forward pass it reads `src` and writes `dst`; in the backward pass it
// reads `diff_dst` and writes `diff_src` (see execute()).
//
// src_type: data type of the tensor the kernel reads.
// dst_type: data type of the tensor the kernel writes.
template <data_type_t src_type, data_type_t dst_type>
struct simple_resampling_kernel_t : public simple_resampling_base_t {
    // Caches strides and post-op state derived from `pd`.
    simple_resampling_kernel_t(const resampling_pd_t *pd);

    using src_data_t = typename prec_traits<src_type>::type;
    using dst_data_t = typename prec_traits<dst_type>::type;

    // Selects the interpolation functor and fills the tables it needs.
    status_t init() override;
    // Runs the selected interpolation functor over the whole tensor.
    status_t execute(const exec_ctx_t &ctx) const override;

private:
    // Signature of an interpolation functor: (input pointer, output pointer,
    // post-op arguments, three spatial indices of the output point in
    // depth/height/width order).
    using interpolate_fn_t = std::function<void(const src_data_t *,
            dst_data_t *, ref_post_ops_t::args_t &, dim_t, dim_t, dim_t)>;

    // Fills linear_coeffs_ (fwd) or bwd_linear_coeffs_ (bwd).
    void fill_coeffs();
    // Fills bwd_linear_weights_; used in the backward path only.
    void fill_weights();
    // Factories returning the functor for the forward or backward pass,
    // depending on pd_->is_fwd().
    interpolate_fn_t create_nearest() const;
    interpolate_fn_t create_linear() const;
    interpolate_fn_t create_bilinear() const;
    interpolate_fn_t create_trilinear() const;

    // For fwd processing:
    // NOTE: these two are initialized in the constructor's initializer list;
    // keep their declaration order in sync with it.
    const bool are_postops_set_;
    const ref_post_ops_t ref_post_ops_;
    // One entry per output index, stored as OD depth entries, then OH height
    // entries, then OW width entries.
    std::vector<linear_coeffs_t> linear_coeffs_;

    // For bwd processing:
    // Two weights per output index, stored in depth, height, width order.
    std::vector<float> bwd_linear_weights_;
    // One entry per input index, stored as ID depth entries, then IH height
    // entries, then IW width entries.
    std::vector<bwd_linear_coeffs_t> bwd_linear_coeffs_;

    // Functor chosen by init() and invoked by execute().
    interpolate_fn_t interpolate_fn_;
};
71 | |
72 | template <data_type_t src_type, data_type_t dst_type> |
73 | simple_resampling_kernel_t<src_type, dst_type>::simple_resampling_kernel_t( |
74 | const resampling_pd_t *pd) |
75 | : simple_resampling_base_t(pd) |
76 | , are_postops_set_(!(pd_->attr()->post_ops_.entry_.empty())) |
77 | , ref_post_ops_(pd_->attr()->post_ops_) { |
78 | if (pd_->is_fwd()) { |
79 | const memory_desc_wrapper src_d(pd_->src_md()); |
80 | inner_stride_ = src_d.blocking_desc().strides[pd_->ndims() - 1]; |
81 | nsp_outer_ = src_d.nelems(true) |
82 | / (pd_->ID() * pd_->IH() * pd_->IW() * inner_stride_); |
83 | stride_d_ = pd_->IH() * pd_->IW() * inner_stride_; |
84 | stride_h_ = pd_->IW() * inner_stride_; |
85 | stride_w_ = inner_stride_; |
86 | } else { |
87 | const memory_desc_wrapper diff_src_d(pd_->diff_src_md()); |
88 | inner_stride_ = diff_src_d.blocking_desc().strides[pd_->ndims() - 1]; |
89 | nsp_outer_ = diff_src_d.nelems(true) |
90 | / (pd_->ID() * pd_->IH() * pd_->IW() * inner_stride_); |
91 | stride_d_ = pd_->OH() * pd_->OW() * inner_stride_; |
92 | stride_h_ = pd_->OW() * inner_stride_; |
93 | stride_w_ = inner_stride_; |
94 | } |
95 | } |
96 | |
97 | template <data_type_t src_type, data_type_t dst_type> |
98 | status_t simple_resampling_kernel_t<src_type, dst_type>::init() { |
99 | if (pd_->desc()->alg_kind == alg_kind::resampling_nearest) |
100 | interpolate_fn_ = create_nearest(); |
101 | else { |
102 | if (pd_->ndims() == 5) |
103 | interpolate_fn_ = create_trilinear(); |
104 | else if (pd_->ndims() == 4) |
105 | interpolate_fn_ = create_bilinear(); |
106 | else |
107 | interpolate_fn_ = create_linear(); |
108 | |
109 | fill_coeffs(); |
110 | if (!pd_->is_fwd()) fill_weights(); |
111 | } |
112 | |
113 | return status::success; |
114 | } |
115 | |
116 | template <data_type_t src_type, data_type_t dst_type> |
117 | status_t simple_resampling_kernel_t<src_type, dst_type>::execute( |
118 | const exec_ctx_t &ctx) const { |
119 | const int OD = pd_->OD(); |
120 | const int OH = pd_->OH(); |
121 | const int OW = pd_->OW(); |
122 | const int ID = pd_->ID(); |
123 | const int IH = pd_->IH(); |
124 | const int IW = pd_->IW(); |
125 | |
126 | if (pd_->is_fwd()) { |
127 | const auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
128 | auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); |
129 | |
130 | parallel_nd(nsp_outer_, OD, OH, [&](dim_t nsp0, dim_t od, dim_t oh) { |
131 | ref_post_ops_t::args_t postops_args; |
132 | postops_args.ctx = &ctx; |
133 | postops_args.dst_md = pd_->dst_md(); |
134 | |
135 | for (dim_t ow = 0; ow < OW; ow++) { |
136 | const dim_t src_off = nsp0 * ID * IH * IW * inner_stride_; |
137 | const dim_t dst_off |
138 | = (nsp0 * OD * OH * OW + od * OH * OW + oh * OW + ow) |
139 | * inner_stride_; |
140 | |
141 | postops_args.l_offset = dst_off; |
142 | |
143 | interpolate_fn_( |
144 | src + src_off, dst + dst_off, postops_args, od, oh, ow); |
145 | } |
146 | }); |
147 | } else { |
148 | const auto diff_dst = CTX_IN_MEM(const src_data_t *, DNNL_ARG_DIFF_DST); |
149 | auto diff_src = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DIFF_SRC); |
150 | ref_post_ops_t::args_t empty_args; |
151 | |
152 | parallel_nd(nsp_outer_, ID, IH, IW, |
153 | [&](dim_t nsp, dim_t id, dim_t ih, dim_t iw) { |
154 | const dim_t diff_dst_off |
155 | = nsp * OD * OH * OW * inner_stride_; |
156 | const dim_t diff_src_off |
157 | = (nsp * ID * IH * IW + id * IH * IW + ih * IW + iw) |
158 | * inner_stride_; |
159 | interpolate_fn_(diff_dst + diff_dst_off, |
160 | diff_src + diff_src_off, empty_args, id, ih, iw); |
161 | }); |
162 | } |
163 | |
164 | return status::success; |
165 | } |
166 | |
167 | template <data_type_t src_type, data_type_t dst_type> |
168 | void simple_resampling_kernel_t<src_type, dst_type>::fill_coeffs() { |
169 | if (pd_->is_fwd()) { |
170 | linear_coeffs_.reserve(pd_->OD() + pd_->OH() + pd_->OW()); |
171 | for (dim_t od = 0; od < pd_->OD(); od++) |
172 | linear_coeffs_.emplace_back( |
173 | linear_coeffs_t(od, pd_->OD(), pd_->ID())); |
174 | for (dim_t oh = 0; oh < pd_->OH(); oh++) |
175 | linear_coeffs_.emplace_back( |
176 | linear_coeffs_t(oh, pd_->OH(), pd_->IH())); |
177 | for (dim_t ow = 0; ow < pd_->OW(); ow++) |
178 | linear_coeffs_.emplace_back( |
179 | linear_coeffs_t(ow, pd_->OW(), pd_->IW())); |
180 | } else { |
181 | bwd_linear_coeffs_.reserve(pd_->ID() + pd_->IH() + pd_->IW()); |
182 | for (dim_t id = 0; id < pd_->ID(); id++) |
183 | bwd_linear_coeffs_.emplace_back( |
184 | bwd_linear_coeffs_t(id, pd_->OD(), pd_->ID())); |
185 | for (dim_t ih = 0; ih < pd_->IH(); ih++) |
186 | bwd_linear_coeffs_.emplace_back( |
187 | bwd_linear_coeffs_t(ih, pd_->OH(), pd_->IH())); |
188 | for (dim_t iw = 0; iw < pd_->IW(); iw++) |
189 | bwd_linear_coeffs_.emplace_back( |
190 | bwd_linear_coeffs_t(iw, pd_->OW(), pd_->IW())); |
191 | } |
192 | } |
193 | |
194 | template <data_type_t src_type, data_type_t dst_type> |
195 | void simple_resampling_kernel_t<src_type, dst_type>::fill_weights() { |
196 | assert(!pd_->is_fwd() && "The function is used in bwd path only." ); |
197 | |
198 | using namespace resampling_utils; |
199 | bwd_linear_weights_.reserve(2 * (pd_->OD() + pd_->OH() + pd_->OW())); |
200 | for (dim_t od = 0; od < pd_->OD(); od++) { |
201 | bwd_linear_weights_.emplace_back( |
202 | linear_weight(0, od, pd_->OD(), pd_->ID())); |
203 | bwd_linear_weights_.emplace_back( |
204 | linear_weight(1, od, pd_->OD(), pd_->ID())); |
205 | } |
206 | for (dim_t oh = 0; oh < pd_->OH(); oh++) { |
207 | bwd_linear_weights_.emplace_back( |
208 | linear_weight(0, oh, pd_->OH(), pd_->IH())); |
209 | bwd_linear_weights_.emplace_back( |
210 | linear_weight(1, oh, pd_->OH(), pd_->IH())); |
211 | } |
212 | for (dim_t ow = 0; ow < pd_->OW(); ow++) { |
213 | bwd_linear_weights_.emplace_back( |
214 | linear_weight(0, ow, pd_->OW(), pd_->IW())); |
215 | bwd_linear_weights_.emplace_back( |
216 | linear_weight(1, ow, pd_->OW(), pd_->IW())); |
217 | } |
218 | } |
219 | |
220 | template <data_type_t src_type, data_type_t dst_type> |
221 | typename simple_resampling_kernel_t<src_type, dst_type>::interpolate_fn_t |
222 | simple_resampling_kernel_t<src_type, dst_type>::create_nearest() const { |
223 | if (pd_->is_fwd()) { |
224 | return [&](const src_data_t *src, dst_data_t *dst, |
225 | ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh, |
226 | dim_t ow) { |
227 | const dim_t id = nearest_idx(od, pd_->OD(), pd_->ID()); |
228 | const dim_t ih = nearest_idx(oh, pd_->OH(), pd_->IH()); |
229 | const dim_t iw = nearest_idx(ow, pd_->OW(), pd_->IW()); |
230 | const dim_t offset |
231 | = id * stride_d_ + ih * stride_h_ + iw * stride_w_; |
232 | |
233 | PRAGMA_OMP_SIMD() |
234 | for (dim_t innermost_el = 0; innermost_el < inner_stride_; |
235 | innermost_el++) { |
236 | float res = static_cast<float>(src[offset + innermost_el]); |
237 | |
238 | if (are_postops_set_) { |
239 | po_args.dst_val = dst[innermost_el]; |
240 | ref_post_ops_.execute(res, po_args); |
241 | po_args.l_offset++; |
242 | } |
243 | |
244 | dst[innermost_el] = cpu::saturate_and_round<dst_data_t>(res); |
245 | } |
246 | }; |
247 | } else { |
248 | return [&](const src_data_t *diff_dst, dst_data_t *diff_src, |
249 | ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih, |
250 | dim_t iw) { |
251 | auto ow_idx = [&](const float in_idx) -> dim_t { |
252 | return ceil_idx((in_idx * pd_->OW() / pd_->IW()) - 0.5f); |
253 | }; |
254 | auto oh_idx = [&](const float in_idx) -> dim_t { |
255 | return ceil_idx((in_idx * pd_->OH() / pd_->IH()) - 0.5f); |
256 | }; |
257 | auto od_idx = [&](const float in_idx) -> dim_t { |
258 | return ceil_idx((in_idx * pd_->OD() / pd_->ID()) - 0.5f); |
259 | }; |
260 | |
261 | const dim_t ow_start = ow_idx(iw) * stride_w_; |
262 | const dim_t oh_start = oh_idx(ih) * stride_h_; |
263 | const dim_t od_start = od_idx(id) * stride_d_; |
264 | const dim_t ow_end = ow_idx(iw + 1.f) * stride_w_; |
265 | const dim_t oh_end = oh_idx(ih + 1.f) * stride_h_; |
266 | const dim_t od_end = od_idx(id + 1.f) * stride_d_; |
267 | |
268 | PRAGMA_OMP_SIMD() |
269 | for (dim_t innermost_el = 0; innermost_el < inner_stride_; |
270 | innermost_el++) { |
271 | float sum = 0; |
272 | for_(dim_t od = od_start; od < od_end; od += stride_d_) |
273 | for_(dim_t oh = oh_start; oh < oh_end; oh += stride_h_) |
274 | for (dim_t ow = ow_start; ow < ow_end; ow += stride_w_) { |
275 | sum += static_cast<float>( |
276 | diff_dst[od + oh + ow + innermost_el]); |
277 | } |
278 | diff_src[innermost_el] |
279 | = cpu::saturate_and_round<dst_data_t>(sum); |
280 | } |
281 | }; |
282 | } |
283 | } |
284 | |
285 | template <data_type_t src_type, data_type_t dst_type> |
286 | typename simple_resampling_kernel_t<src_type, dst_type>::interpolate_fn_t |
287 | simple_resampling_kernel_t<src_type, dst_type>::create_linear() const { |
288 | if (pd_->is_fwd()) { |
289 | return [&](const src_data_t *src, dst_data_t *dst, |
290 | ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh, |
291 | dim_t ow) { |
292 | const linear_coeffs_t &iw |
293 | = linear_coeffs_[pd_->OD() + pd_->OH() + ow]; |
294 | |
295 | PRAGMA_OMP_SIMD() |
296 | for (dim_t innermost_el = 0; innermost_el < inner_stride_; |
297 | innermost_el++) { |
298 | float res = 0; |
299 | for (int k = 0; k < 2; k++) |
300 | res += static_cast<float>( |
301 | src[iw.idx[k] * stride_w_ + innermost_el]) |
302 | * iw.wei[k]; |
303 | |
304 | if (are_postops_set_) { |
305 | po_args.dst_val = dst[innermost_el]; |
306 | ref_post_ops_.execute(res, po_args); |
307 | po_args.l_offset++; |
308 | } |
309 | |
310 | dst[innermost_el] = cpu::saturate_and_round<dst_data_t>(res); |
311 | } |
312 | }; |
313 | } else { |
314 | return [&](const src_data_t *diff_dst, dst_data_t *diff_src, |
315 | ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih, |
316 | dim_t iw) { |
317 | const bwd_linear_coeffs_t &w |
318 | = bwd_linear_coeffs_[pd_->ID() + pd_->IH() + iw]; |
319 | |
320 | PRAGMA_OMP_SIMD() |
321 | for (dim_t innermost_el = 0; innermost_el < inner_stride_; |
322 | innermost_el++) { |
323 | float sum = 0; |
324 | for_(int k = 0; k < 2; k++) |
325 | for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) { |
326 | sum += static_cast<float>( |
327 | diff_dst[ow * stride_w_ + innermost_el]) |
328 | * bwd_linear_weights_[2 |
329 | * (pd_->OD() + pd_->OH() + ow) |
330 | + k]; |
331 | } |
332 | diff_src[innermost_el] |
333 | = cpu::saturate_and_round<dst_data_t>(sum); |
334 | } |
335 | }; |
336 | } |
337 | } |
338 | |
339 | template <data_type_t src_type, data_type_t dst_type> |
340 | typename simple_resampling_kernel_t<src_type, dst_type>::interpolate_fn_t |
341 | simple_resampling_kernel_t<src_type, dst_type>::create_bilinear() const { |
342 | if (pd_->is_fwd()) { |
343 | return [&](const src_data_t *src, dst_data_t *dst, |
344 | ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh, |
345 | dim_t ow) { |
346 | const linear_coeffs_t &ih = linear_coeffs_[pd_->OD() + oh]; |
347 | const linear_coeffs_t &iw |
348 | = linear_coeffs_[pd_->OD() + pd_->OH() + ow]; |
349 | |
350 | PRAGMA_OMP_SIMD() |
351 | for (dim_t innermost_el = 0; innermost_el < inner_stride_; |
352 | innermost_el++) { |
353 | float res = 0; |
354 | for_(int j = 0; j < 2; j++) |
355 | for (int k = 0; k < 2; k++) |
356 | res += static_cast<float>(src[ih.idx[j] * stride_h_ |
357 | + iw.idx[k] * stride_w_ + innermost_el]) |
358 | * ih.wei[j] * iw.wei[k]; |
359 | |
360 | if (are_postops_set_) { |
361 | po_args.dst_val = dst[innermost_el]; |
362 | ref_post_ops_.execute(res, po_args); |
363 | po_args.l_offset++; |
364 | } |
365 | |
366 | dst[innermost_el] = cpu::saturate_and_round<dst_data_t>(res); |
367 | } |
368 | }; |
369 | } else { |
370 | return [&](const src_data_t *diff_dst, dst_data_t *diff_src, |
371 | ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih, |
372 | dim_t iw) { |
373 | const bwd_linear_coeffs_t &h = bwd_linear_coeffs_[pd_->ID() + ih]; |
374 | const bwd_linear_coeffs_t &w |
375 | = bwd_linear_coeffs_[pd_->ID() + pd_->IH() + iw]; |
376 | |
377 | PRAGMA_OMP_SIMD() |
378 | for (dim_t innermost_el = 0; innermost_el < inner_stride_; |
379 | innermost_el++) { |
380 | float sum = 0; |
381 | for_(int j = 0; j < 2; j++) |
382 | for_(int k = 0; k < 2; k++) |
383 | for_(dim_t oh = h.start[j]; oh < h.end[j]; oh++) |
384 | for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) { |
385 | sum += static_cast<float>(diff_dst[oh * stride_h_ |
386 | + ow * stride_w_ + innermost_el]) |
387 | * bwd_linear_weights_[2 * (pd_->OD() + oh) + j] |
388 | * bwd_linear_weights_[2 |
389 | * (pd_->OD() + pd_->OH() + ow) |
390 | + k]; |
391 | } |
392 | diff_src[innermost_el] |
393 | = cpu::saturate_and_round<dst_data_t>(sum); |
394 | } |
395 | }; |
396 | } |
397 | } |
398 | |
399 | template <data_type_t src_type, data_type_t dst_type> |
400 | typename simple_resampling_kernel_t<src_type, dst_type>::interpolate_fn_t |
401 | simple_resampling_kernel_t<src_type, dst_type>::create_trilinear() const { |
402 | if (pd_->is_fwd()) { |
403 | return [&](const src_data_t *src, dst_data_t *dst, |
404 | ref_post_ops_t::args_t &po_args, dim_t od, dim_t oh, |
405 | dim_t ow) { |
406 | const linear_coeffs_t &id = linear_coeffs_[od]; |
407 | const linear_coeffs_t &ih = linear_coeffs_[pd_->OD() + oh]; |
408 | const linear_coeffs_t &iw |
409 | = linear_coeffs_[pd_->OD() + pd_->OH() + ow]; |
410 | |
411 | PRAGMA_OMP_SIMD() |
412 | for (dim_t innermost_el = 0; innermost_el < inner_stride_; |
413 | innermost_el++) { |
414 | float res = 0; |
415 | for_(int i = 0; i < 2; i++) |
416 | for_(int j = 0; j < 2; j++) |
417 | for (int k = 0; k < 2; k++) |
418 | res += static_cast<float>(src[id.idx[i] * stride_d_ |
419 | + ih.idx[j] * stride_h_ |
420 | + iw.idx[k] * stride_w_ + innermost_el]) |
421 | * id.wei[i] * ih.wei[j] * iw.wei[k]; |
422 | |
423 | if (are_postops_set_) { |
424 | po_args.dst_val = dst[innermost_el]; |
425 | ref_post_ops_.execute(res, po_args); |
426 | po_args.l_offset++; |
427 | } |
428 | |
429 | dst[innermost_el] = cpu::saturate_and_round<dst_data_t>(res); |
430 | } |
431 | }; |
432 | } else { |
433 | return [&](const src_data_t *diff_dst, dst_data_t *diff_src, |
434 | ref_post_ops_t::args_t &po_args, dim_t id, dim_t ih, |
435 | dim_t iw) { |
436 | const bwd_linear_coeffs_t &d = bwd_linear_coeffs_[id]; |
437 | const bwd_linear_coeffs_t &h = bwd_linear_coeffs_[pd_->ID() + ih]; |
438 | const bwd_linear_coeffs_t &w |
439 | = bwd_linear_coeffs_[pd_->ID() + pd_->IH() + iw]; |
440 | |
441 | PRAGMA_OMP_SIMD() |
442 | for (dim_t innermost_el = 0; innermost_el < inner_stride_; |
443 | innermost_el++) { |
444 | float sum = 0; |
445 | for_(int i = 0; i < 2; i++) |
446 | for_(int j = 0; j < 2; j++) |
447 | for_(int k = 0; k < 2; k++) |
448 | for_(dim_t od = d.start[i]; od < d.end[i]; od++) |
449 | for_(dim_t oh = h.start[j]; oh < h.end[j]; oh++) |
450 | for (dim_t ow = w.start[k]; ow < w.end[k]; ow++) { |
451 | sum += static_cast<float>( |
452 | diff_dst[od * stride_d_ + oh * stride_h_ |
453 | + ow * stride_w_ + innermost_el]) |
454 | * bwd_linear_weights_[2 * od + i] |
455 | * bwd_linear_weights_[2 * (pd_->OD() + oh) + j] |
456 | * bwd_linear_weights_[2 |
457 | * (pd_->OD() + pd_->OH() + ow) |
458 | + k]; |
459 | } |
460 | diff_src[innermost_el] |
461 | = cpu::saturate_and_round<dst_data_t>(sum); |
462 | } |
463 | }; |
464 | } |
465 | } |
466 | |
// Explicit instantiations of the kernel for every (input, output) data type
// pair that create_simple_resampling() can dispatch to: each of f32, bf16,
// f16, s32, s8 and u8 as input combined with each of them as output.
// Keep this list in sync with the dispatch in create_simple_resampling().
template struct simple_resampling_kernel_t<data_type::f32, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::f32, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::bf16, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::bf16, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::f16, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::f16, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::s32, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::s32, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::s8, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::s8, data_type::u8>;

template struct simple_resampling_kernel_t<data_type::u8, data_type::f32>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::bf16>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::f16>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::s32>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::s8>;
template struct simple_resampling_kernel_t<data_type::u8, data_type::u8>;
508 | |
509 | simple_resampling_base_t *create_simple_resampling(const resampling_pd_t *pd, |
510 | const data_type_t src_dt, const data_type_t dst_dt) { |
511 | using namespace data_type; |
512 | |
513 | switch (src_dt) { |
514 | case f32: |
515 | switch (dst_dt) { |
516 | case f32: return new simple_resampling_kernel_t<f32, f32>(pd); |
517 | case s32: return new simple_resampling_kernel_t<f32, s32>(pd); |
518 | case bf16: return new simple_resampling_kernel_t<f32, bf16>(pd); |
519 | case f16: return new simple_resampling_kernel_t<f32, f16>(pd); |
520 | case s8: return new simple_resampling_kernel_t<f32, s8>(pd); |
521 | case u8: return new simple_resampling_kernel_t<f32, u8>(pd); |
522 | default: break; |
523 | } |
524 | case s32: |
525 | switch (dst_dt) { |
526 | case f32: return new simple_resampling_kernel_t<s32, f32>(pd); |
527 | case s32: return new simple_resampling_kernel_t<s32, s32>(pd); |
528 | case bf16: return new simple_resampling_kernel_t<s32, bf16>(pd); |
529 | case f16: return new simple_resampling_kernel_t<s32, f16>(pd); |
530 | case s8: return new simple_resampling_kernel_t<s32, s8>(pd); |
531 | case u8: return new simple_resampling_kernel_t<s32, u8>(pd); |
532 | default: break; |
533 | } |
534 | case bf16: |
535 | switch (dst_dt) { |
536 | case f32: return new simple_resampling_kernel_t<bf16, f32>(pd); |
537 | case s32: return new simple_resampling_kernel_t<bf16, s32>(pd); |
538 | case bf16: |
539 | return new simple_resampling_kernel_t<bf16, bf16>(pd); |
540 | case f16: return new simple_resampling_kernel_t<bf16, f16>(pd); |
541 | case s8: return new simple_resampling_kernel_t<bf16, s8>(pd); |
542 | case u8: return new simple_resampling_kernel_t<bf16, u8>(pd); |
543 | default: break; |
544 | } |
545 | case f16: |
546 | switch (dst_dt) { |
547 | case f32: return new simple_resampling_kernel_t<f16, f32>(pd); |
548 | case s32: return new simple_resampling_kernel_t<f16, s32>(pd); |
549 | case bf16: return new simple_resampling_kernel_t<f16, bf16>(pd); |
550 | case f16: return new simple_resampling_kernel_t<f16, f16>(pd); |
551 | case s8: return new simple_resampling_kernel_t<f16, s8>(pd); |
552 | case u8: return new simple_resampling_kernel_t<f16, u8>(pd); |
553 | default: break; |
554 | } |
555 | case s8: |
556 | switch (dst_dt) { |
557 | case f32: return new simple_resampling_kernel_t<s8, f32>(pd); |
558 | case s32: return new simple_resampling_kernel_t<s8, s32>(pd); |
559 | case bf16: return new simple_resampling_kernel_t<s8, bf16>(pd); |
560 | case f16: return new simple_resampling_kernel_t<s8, f16>(pd); |
561 | case s8: return new simple_resampling_kernel_t<s8, s8>(pd); |
562 | case u8: return new simple_resampling_kernel_t<s8, u8>(pd); |
563 | default: break; |
564 | } |
565 | case u8: |
566 | switch (dst_dt) { |
567 | case f32: return new simple_resampling_kernel_t<u8, f32>(pd); |
568 | case s32: return new simple_resampling_kernel_t<u8, s32>(pd); |
569 | case bf16: return new simple_resampling_kernel_t<u8, bf16>(pd); |
570 | case f16: return new simple_resampling_kernel_t<u8, f16>(pd); |
571 | case s8: return new simple_resampling_kernel_t<u8, s8>(pd); |
572 | case u8: return new simple_resampling_kernel_t<u8, u8>(pd); |
573 | default: break; |
574 | } |
575 | default: break; |
576 | } |
577 | |
578 | assert(!"Unsupported data type combination." ); |
579 | return nullptr; |
580 | } |
581 | |
582 | } // namespace |
583 | |
584 | simple_resampling_fwd_t::simple_resampling_fwd_t(const pd_t *apd) |
585 | : primitive_t(apd), kernel_(nullptr) {} |
586 | |
587 | status_t simple_resampling_fwd_t::init(engine_t *engine) { |
588 | CHECK(safe_ptr_assign(kernel_, |
589 | create_simple_resampling(pd(), pd()->src_md()->data_type, |
590 | pd()->dst_md()->data_type))); |
591 | return kernel_->init(); |
592 | } |
593 | |
594 | status_t simple_resampling_fwd_t::execute(const exec_ctx_t &ctx) const { |
595 | return kernel_->execute(ctx); |
596 | } |
597 | |
598 | simple_resampling_bwd_t::simple_resampling_bwd_t(const pd_t *apd) |
599 | : primitive_t(apd), kernel_(nullptr) {} |
600 | |
601 | status_t simple_resampling_bwd_t::init(engine_t *engine) { |
602 | CHECK(safe_ptr_assign(kernel_, |
603 | create_simple_resampling(pd(), pd()->diff_dst_md()->data_type, |
604 | pd()->diff_src_md()->data_type))); |
605 | return kernel_->init(); |
606 | } |
607 | |
608 | status_t simple_resampling_bwd_t::execute(const exec_ctx_t &ctx) const { |
609 | return kernel_->execute(ctx); |
610 | } |
611 | } // namespace cpu |
612 | } // namespace impl |
613 | } // namespace dnnl |
614 | |
615 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
616 | |