1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <bitset> |
18 | #include <cassert> |
19 | |
20 | #include "common/bfloat16.hpp" |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/resampling_utils.hpp" |
27 | |
28 | #include "cpu/x64/jit_generator.hpp" |
29 | #include "cpu/x64/jit_uni_resampling.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
36 | using namespace resampling_utils; |
37 | |
38 | static cpu_isa_t get_supported_isa(bool is_plain) { |
39 | if (is_plain && mayiuse(avx512_core_fp16)) return avx512_core_fp16; |
40 | if (mayiuse(avx512_core_bf16)) return avx512_core_bf16; |
41 | if (mayiuse(avx512_core)) return avx512_core; |
42 | if (mayiuse(avx2)) return avx2; |
43 | if (mayiuse(avx)) return avx; |
44 | if (mayiuse(sse41)) return sse41; |
45 | |
46 | return isa_undef; |
47 | } |
48 | |
49 | static bool impl_supports_datatype(data_type_t data_type) { |
50 | switch (data_type) { |
51 | case data_type::bf16: return x64::mayiuse(x64::avx512_core); |
52 | case data_type::f16: return x64::mayiuse(x64::avx512_core_fp16); |
53 | case data_type::f32: |
54 | case data_type::s32: |
55 | case data_type::s8: |
56 | case data_type::u8: return true; |
57 | default: return false; |
58 | } |
59 | } |
60 | |
61 | status_t jit_uni_resampling_fwd_t::pd_t::init(engine_t *engine) { |
62 | using namespace data_type; |
63 | using sm = primitive_attr_t::skip_mask_t; |
64 | |
65 | const memory_desc_wrapper src_d(src_md()); |
66 | const memory_desc_wrapper dst_d(dst_md()); |
67 | |
68 | conf_.src_data_type = src_md()->data_type; |
69 | conf_.dst_data_type = dst_md()->data_type; |
70 | |
71 | fill_format_tag_info(); |
72 | conf_.isa = get_supported_isa(src_d.is_plain()); |
73 | |
74 | const bool ok = is_fwd() && !has_zero_dim_memory() |
75 | && conf_.src_tag != format_tag::undef |
76 | && set_default_params(conf_.src_tag) == status::success |
77 | && impl_supports_datatype(conf_.src_data_type) |
78 | && impl_supports_datatype(conf_.dst_data_type) |
79 | && IMPLICATION(conf_.src_data_type == f16, src_d.is_plain()) |
80 | && attr()->has_default_values(sm::post_ops, conf_.dst_data_type) |
81 | && attr_.set_default_formats(dst_md(0)) == status::success; |
82 | if (!ok) return status::unimplemented; |
83 | |
84 | if (!memory_desc_matches_tag(*dst_md(), conf_.src_tag)) |
85 | return status::unimplemented; |
86 | |
87 | conf_.alg = desc()->alg_kind; |
88 | conf_.c = C(); |
89 | conf_.od = OD(); |
90 | conf_.oh = OH(); |
91 | conf_.ow = OW(); |
92 | conf_.id = ID(); |
93 | conf_.ih = IH(); |
94 | conf_.iw = IW(); |
95 | conf_.ndims = ndims(); |
96 | |
97 | if (conf_.alg == alg_kind::resampling_linear) |
98 | conf_.number_of_corners = pow(2, conf_.ndims - 2); |
99 | |
100 | conf_.src_dt_size = types::data_type_size(conf_.src_data_type); |
101 | conf_.dst_dt_size = types::data_type_size(conf_.dst_data_type); |
102 | |
103 | conf_.is_saturation_needed |
104 | = utils::one_of(conf_.dst_data_type, s32, s8, u8); |
105 | |
106 | const size_t L3_size = static_cast<size_t>(dnnl_get_current_num_threads()) |
107 | * platform::get_per_core_cache_size(3); |
108 | const size_t input_data_size = src_d.nelems(true) * conf_.src_dt_size; |
109 | const size_t output_data_size = dst_d.nelems(true) * conf_.dst_dt_size; |
110 | const size_t whole_data_size = input_data_size + output_data_size; |
111 | conf_.output_data_size = output_data_size; |
112 | conf_.is_data_size_bigger_than_L3 |
113 | = L3_size > 0 ? whole_data_size > L3_size : false; |
114 | |
115 | conf_.el_size_of_indices = sizeof(unsigned); |
116 | |
117 | conf_.inner_stride = src_d.blocking_desc().strides[ndims() - 1]; |
118 | conf_.stride_d = IH() * IW() * conf_.inner_stride * conf_.src_dt_size; |
119 | conf_.stride_h = IW() * conf_.inner_stride * conf_.src_dt_size; |
120 | conf_.stride_w = conf_.inner_stride * conf_.src_dt_size; |
121 | |
122 | const std::vector<injector::post_op_type> accepted_post_ops |
123 | = {injector::sum, injector::eltwise, injector::binary}; |
124 | static constexpr bool sum_at_0_pos_only = false; |
125 | static constexpr bool sum_requires_scale_one = false; |
126 | static constexpr bool sum_requires_zp_zero = true; |
127 | const bcast_set_t accepted_broadcasts |
128 | = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc, |
129 | broadcasting_strategy_t::per_oc_spatial}; |
130 | injector::post_ops_ok_args_t post_ops_args(conf_.isa, accepted_post_ops, |
131 | attr()->post_ops_, &dst_d, sum_at_0_pos_only, |
132 | sum_requires_scale_one, sum_requires_zp_zero, accepted_broadcasts); |
133 | if (!post_ops_ok(post_ops_args)) return status::unimplemented; |
134 | |
135 | conf_.post_ops = attr()->post_ops_; |
136 | |
137 | static constexpr bool require_scale_one = false; |
138 | conf_.with_eltwise = conf_.with_binary = conf_.with_sum = false; |
139 | for (const auto &entry : conf_.post_ops.entry_) { |
140 | if (entry.is_eltwise()) { |
141 | conf_.with_eltwise = true; |
142 | } else if (entry.is_binary()) { |
143 | conf_.with_binary = true; |
144 | } else if (entry.is_sum(require_scale_one) && entry.sum.scale != 0.f) { |
145 | conf_.with_sum = true; |
146 | conf_.sum_scales.push(entry.sum.scale); |
147 | } |
148 | } |
149 | conf_.with_postops |
150 | = conf_.with_eltwise || conf_.with_binary || conf_.with_sum; |
151 | |
152 | return status::success; |
153 | } |
154 | |
155 | void jit_uni_resampling_fwd_t::pd_t::fill_format_tag_info() { |
156 | using namespace format_tag; |
157 | |
158 | const format_tag_t blocked_16_format = memory_desc_matches_one_of_tag( |
159 | *src_md(), nCw16c, nChw16c, nCdhw16c); |
160 | const format_tag_t blocked_8_format |
161 | = memory_desc_matches_one_of_tag(*src_md(), nCw8c, nChw8c, nCdhw8c); |
162 | const format_tag_t nspc_format |
163 | = memory_desc_matches_one_of_tag(*src_md(), nwc, nhwc, ndhwc); |
164 | const format_tag_t ncsp_format |
165 | = memory_desc_matches_one_of_tag(*src_md(), ncw, nchw, ncdhw); |
166 | |
167 | if (blocked_16_format != undef) { |
168 | conf_.tag_kind = jit_memory_tag_kind_t::blocked; |
169 | conf_.src_tag = blocked_16_format; |
170 | } else if (blocked_8_format != undef) { |
171 | conf_.is_blocked_8_format = true; |
172 | conf_.tag_kind = jit_memory_tag_kind_t::blocked; |
173 | conf_.src_tag = blocked_8_format; |
174 | } else if (nspc_format != undef) { |
175 | conf_.tag_kind = jit_memory_tag_kind_t::nspc; |
176 | conf_.src_tag = nspc_format; |
177 | } else if (ncsp_format != undef) { |
178 | conf_.tag_kind = jit_memory_tag_kind_t::ncsp; |
179 | conf_.src_tag = ncsp_format; |
180 | } else { |
181 | conf_.tag_kind = jit_memory_tag_kind_t::undef; |
182 | conf_.src_tag = undef; |
183 | } |
184 | } |
185 | |
186 | status_t jit_uni_resampling_fwd_t::get_proper_kernel_for_avx512( |
187 | const memory_desc_t *dst_md, const jit_resampling_conf_t &conf) { |
188 | const format_tag_t blocked_8_tag = utils::pick(conf.ndims - 3, |
189 | format_tag::nCw8c, format_tag::nChw8c, format_tag::nCdhw8c); |
190 | if (is_superset(conf.isa, avx512_core_fp16)) |
191 | return safe_ptr_assign(kernel_, |
192 | new jit_uni_resampling_kernel_t<avx512_core_fp16, Xbyak::Zmm>( |
193 | conf, dst_md)); |
194 | |
195 | if (memory_desc_matches_tag(*pd()->src_md(), blocked_8_tag)) { |
196 | return safe_ptr_assign(kernel_, |
197 | new jit_uni_resampling_kernel_t<avx512_core, Xbyak::Ymm>( |
198 | conf, dst_md)); |
199 | } |
200 | |
201 | return safe_ptr_assign(kernel_, |
202 | new jit_uni_resampling_kernel_t<avx512_core, Xbyak::Zmm>( |
203 | conf, dst_md)); |
204 | } |
205 | |
206 | status_t jit_uni_resampling_fwd_t::get_proper_kernel_for_avx( |
207 | const memory_desc_t *dst_md, const jit_resampling_conf_t &conf) { |
208 | using namespace data_type; |
209 | |
210 | const bool is_src_i8 = utils::one_of(conf.src_data_type, s8, u8); |
211 | const bool is_dst_i8 = utils::one_of(conf.dst_data_type, s8, u8); |
212 | if (is_src_i8 || is_dst_i8) |
213 | return safe_ptr_assign(kernel_, |
214 | new jit_uni_resampling_kernel_t<avx, Xbyak::Xmm>(conf, dst_md)); |
215 | |
216 | return safe_ptr_assign(kernel_, |
217 | new jit_uni_resampling_kernel_t<avx, Xbyak::Ymm>(conf, dst_md)); |
218 | } |
219 | |
220 | status_t jit_uni_resampling_fwd_t::get_proper_kernel_for_sse( |
221 | const memory_desc_t *dst_md, const jit_resampling_conf_t &conf) { |
222 | return safe_ptr_assign(kernel_, |
223 | new jit_uni_resampling_kernel_t<sse41, Xbyak::Xmm>(conf, dst_md)); |
224 | } |
225 | |
226 | status_t jit_uni_resampling_fwd_t::init(engine_t *engine) { |
227 | using namespace format_tag; |
228 | |
229 | const memory_desc_t *dst_md = pd()->dst_md(); |
230 | const jit_resampling_conf_t &conf = pd()->get_conf(); |
231 | |
232 | if (is_superset(conf.isa, avx512_core)) |
233 | CHECK(get_proper_kernel_for_avx512(dst_md, conf)); |
234 | else if (is_superset(conf.isa, avx)) |
235 | CHECK(get_proper_kernel_for_avx(dst_md, conf)); |
236 | else if (conf.isa == sse41) { |
237 | CHECK(get_proper_kernel_for_sse(dst_md, conf)); |
238 | } else { |
239 | assert(!"Unsupported isa." ); |
240 | return status::runtime_error; |
241 | } |
242 | |
243 | CHECK(kernel_->create_kernel()); |
244 | |
245 | return fill_data_for_interpolation(); |
246 | } |
247 | |
248 | status_t jit_uni_resampling_fwd_t::fill_data_for_interpolation() { |
249 | switch (pd()->desc()->alg_kind) { |
250 | case alg_kind::resampling_nearest: return fill_data_for_nearest(); |
251 | case alg_kind::resampling_linear: return fill_data_for_linear(); |
252 | default: |
253 | assert(!"Invalid resampling algorithm." ); |
254 | return status::invalid_arguments; |
255 | } |
256 | } |
257 | |
258 | status_t jit_uni_resampling_fwd_t::fill_data_for_nearest() { |
259 | // In kernel is used vmovdqu to get indices. This instruction don't have |
260 | // tail processing possibilities on sse41 and avx. To avoid problems |
261 | // with that, OW is aligned to simd width, because indices for ow |
262 | // are read in the kernel. |
263 | indices_.reserve(pd()->OD() + pd()->OH() |
264 | + utils::rnd_up(pd()->OW(), kernel_->get_simd_w())); |
265 | |
266 | for (dim_t od = 0; od < pd()->OD(); od++) { |
267 | const int offset_id = nearest_idx(od, pd()->OD(), pd()->ID()) |
268 | * pd()->get_conf().stride_d; |
269 | indices_.emplace_back(offset_id); |
270 | } |
271 | for (dim_t oh = 0; oh < pd()->OH(); oh++) { |
272 | const int offset_ih = nearest_idx(oh, pd()->OH(), pd()->IH()) |
273 | * pd()->get_conf().stride_h; |
274 | indices_.emplace_back(offset_ih); |
275 | } |
276 | for (dim_t ow = 0; ow < pd()->OW(); ow++) { |
277 | const int offset_iw = nearest_idx(ow, pd()->OW(), pd()->IW()) |
278 | * pd()->get_conf().stride_w; |
279 | indices_.emplace_back(offset_iw); |
280 | } |
281 | |
282 | return status::success; |
283 | } |
284 | |
// Precomputes source offsets (indices_) and blending weights (weights_)
// for linear interpolation. The in-memory layout of both buffers depends
// on the tag kind and must match what the kernel expects.
status_t jit_uni_resampling_fwd_t::fill_data_for_linear() {
    using namespace resampling_utils;

    const unsigned number_of_corners = pd()->get_conf().number_of_corners;
    const unsigned stride_w = pd()->get_conf().stride_w;
    const unsigned stride_h = pd()->get_conf().stride_h;
    const unsigned stride_d = pd()->get_conf().stride_d;

    unsigned num_of_elements = 0;
    if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::ncsp) {
        // The kernel uses vmovdqu to load indices. That instruction has no
        // tail processing on sse41 and avx, so the number of spatial points
        // is rounded up to the simd width to keep the vectorized reads
        // within the buffers (all of them are read in the kernel).
        num_of_elements = number_of_corners
                * utils::rnd_up(pd()->OD() * pd()->OH() * pd()->OW(),
                        kernel_->get_simd_w());

        indices_.resize(num_of_elements);
        weights_.resize(num_of_elements);

        // Per-corner planes: corner i occupies the half-open range
        // [i * stride, i * stride + OD*OH*OW) in both buffers.
        const size_t indices_stride = pd()->OW() * pd()->OH() * pd()->OD();
        const size_t weights_stride = pd()->OW() * pd()->OH() * pd()->OD();

        parallel_nd(pd()->OD(), pd()->OH(), [&](dim_t od, dim_t oh) {
            // Left/right neighbor indices and weights along d and h.
            const linear_coeffs_t coeffs_id(od, pd()->OD(), pd()->ID());
            const linear_coeffs_t coeffs_ih(oh, pd()->OH(), pd()->IH());

            for (dim_t ow = 0; ow < pd()->OW(); ow++) {
                // Flattened output spatial position within one corner plane.
                const size_t offset
                        = od * pd()->OH() * pd()->OW() + oh * pd()->OW() + ow;

                const linear_coeffs_t coeffs_iw(ow, pd()->OW(), pd()->IW());

                for (unsigned i = 0; i < number_of_corners; i++) {
                    // Bits of i select the neighbor per axis:
                    // bit 2 -> d, bit 1 -> h, bit 0 -> w.
                    std::bitset<3> corners(i);
                    indices_[i * indices_stride + offset]
                            = coeffs_id.idx[corners.test(2)] * stride_d
                            + coeffs_ih.idx[corners.test(1)] * stride_h
                            + coeffs_iw.idx[corners.test(0)] * stride_w;
                    // Tri-linear weight is the product of per-axis weights.
                    weights_[i * weights_stride + offset]
                            = coeffs_id.wei[corners.test(2)]
                            * coeffs_ih.wei[corners.test(1)]
                            * coeffs_iw.wei[corners.test(0)];
                }
            }
        });
    } else if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::nspc
            || pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) {
        // Per-axis storage: two entries (left/right) per output coordinate,
        // sections laid out as [w | h | d].
        num_of_elements = 2 * (pd()->OD() + pd()->OH() + pd()->OW());

        indices_.resize(num_of_elements);
        weights_.resize(num_of_elements);

        unsigned *indices_w = &indices_[0];
        unsigned *indices_h = &indices_[2 * pd()->OW()];
        unsigned *indices_d = &indices_[2 * (pd()->OW() + pd()->OH())];
        float *weights_w = &weights_[0];
        float *weights_h = &weights_[2 * pd()->OW()];
        float *weights_d = &weights_[2 * (pd()->OW() + pd()->OH())];

        for (dim_t ow = 0; ow < pd()->OW(); ow++) {
            const linear_coeffs_t coeffs_iw(ow, pd()->OW(), pd()->IW());

            // Along w, the left and right corners are interleaved
            // (stored one after the other) because the kernel reads
            // these values pairwise, which simplifies the reads and
            // makes the operation faster.
            weights_w[2 * ow] = coeffs_iw.wei[0];
            weights_w[2 * ow + 1] = coeffs_iw.wei[1];
            indices_w[2 * ow] = coeffs_iw.idx[0] * stride_w;
            indices_w[2 * ow + 1] = coeffs_iw.idx[1] * stride_w;
        }

        // Along h and d the two corners are stored as two consecutive
        // planes of size OH (resp. OD) instead of being interleaved.
        for (dim_t oh = 0; oh < pd()->OH(); oh++) {
            const linear_coeffs_t coeffs_ih(oh, pd()->OH(), pd()->IH());

            weights_h[oh] = coeffs_ih.wei[0];
            weights_h[pd()->OH() + oh] = coeffs_ih.wei[1];
            indices_h[oh] = coeffs_ih.idx[0] * stride_h;
            indices_h[pd()->OH() + oh] = coeffs_ih.idx[1] * stride_h;
        }

        for (dim_t od = 0; od < pd()->OD(); od++) {
            const linear_coeffs_t coeffs_id(od, pd()->OD(), pd()->ID());

            weights_d[od] = coeffs_id.wei[0];
            weights_d[pd()->OD() + od] = coeffs_id.wei[1];
            indices_d[od] = coeffs_id.idx[0] * stride_d;
            indices_d[pd()->OD() + od] = coeffs_id.idx[1] * stride_d;
        }
    } else {
        assert(!"Invalid memory format kind." );
        return status::invalid_arguments;
    }

    return status::success;
}
383 | |
384 | status_t jit_uni_resampling_fwd_t::execute(const exec_ctx_t &ctx) const { |
385 | const auto src = CTX_IN_MEM(const uint8_t *, DNNL_ARG_SRC); |
386 | auto dst = CTX_OUT_MEM(uint8_t *, DNNL_ARG_DST); |
387 | |
388 | const std::vector<const void *> post_ops_binary_rhs_arg_vec |
389 | = binary_injector::prepare_binary_args( |
390 | pd()->get_conf().post_ops, ctx); |
391 | |
392 | switch (pd()->desc()->alg_kind) { |
393 | case alg_kind::resampling_nearest: |
394 | return interpolate_nearest(src, dst, post_ops_binary_rhs_arg_vec); |
395 | case alg_kind::resampling_linear: |
396 | return interpolate_linear(src, dst, post_ops_binary_rhs_arg_vec); |
397 | default: |
398 | assert(!"Invalid resampling algorithm." ); |
399 | return status::invalid_arguments; |
400 | } |
401 | } |
402 | |
// Runs the JIT kernel over the whole tensor for nearest interpolation.
// `indices_` was filled by fill_data_for_nearest() as consecutive
// [d | h | w] sections of precomputed source byte offsets.
status_t jit_uni_resampling_fwd_t::interpolate_nearest(const uint8_t *src,
        uint8_t *dst, const std::vector<const void *> &post_ops_args) const {
    const size_t src_dt_size = pd()->get_conf().src_dt_size;
    const size_t dst_dt_size = pd()->get_conf().dst_dt_size;
    const size_t inner_stride = pd()->get_conf().inner_stride;

    const dim_t MB = pd()->MB();
    const dim_t C = pd()->C();
    // Number of channel blocks of size inner_stride (rounded up to cover C).
    const dim_t CB = utils::div_up(C, inner_stride);
    const dim_t nsp_outer = MB * CB;
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();

    // Section bases inside the flat indices_ buffer (see
    // fill_data_for_nearest for the layout).
    const unsigned *indices_d = &indices_[0];
    const unsigned *indices_h = &indices_[OD];
    const unsigned *indices_w = &indices_[OD + OH];

    if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::ncsp) {
        // Plain channels-first: parallelize over (mb, c, od); the kernel
        // handles the remaining (oh, ow) space via the h/w index sections.
        parallel_nd(MB, C, OD, [&](dim_t mb, dim_t c, dim_t od) {
            // src_off is in bytes: indices_d already carries byte offsets.
            const dim_t src_off
                    = (mb * C + c) * ID * IH * IW * src_dt_size + indices_d[od];
            const dim_t dst_off = ((mb * C + c) * OD * OH * OW + od * OH * OW)
                    * dst_dt_size;

            jit_resampling_call_s args = jit_resampling_call_s();
            args.src = src + src_off;
            args.dst = dst + dst_off;
            args.dst_orig = dst;
            args.indices = &indices_h[0];
            args.post_ops_binary_rhs_arg_vec = post_ops_args.data();
            args.c_offset = static_cast<size_t>(c);

            (*kernel_)(&args);
        });
    } else if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::nspc
            || pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) {
        // Channels-last/blocked: one kernel call processes a row of OW
        // points, each covering inner_stride channels.
        parallel_nd(nsp_outer, OD, OH, [&](dim_t nsp, dim_t od, dim_t oh) {
            const dim_t src_off
                    = nsp * ID * IH * IW * inner_stride * src_dt_size
                    + indices_d[od] + indices_h[oh];
            const dim_t dst_off = ((nsp * OD + od) * OH + oh) * OW
                    * inner_stride * dst_dt_size;

            // Channel-block index within the current batch element.
            const size_t cb = std::div(nsp, CB).rem;

            jit_resampling_call_s args = jit_resampling_call_s();
            args.batch_of_sp_points_to_process = OW;
            args.src = src + src_off;
            args.dst = dst + dst_off;
            args.dst_orig = dst;
            args.indices = &indices_w[0];
            args.post_ops_binary_rhs_arg_vec = post_ops_args.data();
            args.c_offset = static_cast<size_t>(cb * inner_stride);

            (*kernel_)(&args);
        });
    } else {
        assert(!"Invalid memory format kind." );
        return status::invalid_arguments;
    }

    return status::success;
}
470 | |
// Runs the JIT kernel over the whole tensor for linear interpolation,
// feeding it the indices/weights precomputed by fill_data_for_linear()
// (whose buffer layout this code mirrors).
status_t jit_uni_resampling_fwd_t::interpolate_linear(const uint8_t *src,
        uint8_t *dst, const std::vector<const void *> &post_ops_args) const {
    const size_t src_dt_size = pd()->get_conf().src_dt_size;
    const size_t dst_dt_size = pd()->get_conf().dst_dt_size;
    const size_t inner_stride = pd()->get_conf().inner_stride;

    const dim_t MB = pd()->MB();
    const dim_t C = pd()->C();
    // Number of channel blocks of size inner_stride (rounded up to cover C).
    const dim_t CB = utils::div_up(C, inner_stride);
    const dim_t nsp_outer = MB * CB;
    const dim_t OD = pd()->OD();
    const dim_t OH = pd()->OH();
    const dim_t OW = pd()->OW();
    const dim_t ID = pd()->ID();
    const dim_t IH = pd()->IH();
    const dim_t IW = pd()->IW();

    if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::ncsp) {
        // Plain channels-first: one kernel call covers the full spatial
        // space of a single (mb, c) plane; per-corner planes of indices
        // and weights are passed whole.
        parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
            const dim_t src_off = (mb * C + c) * ID * IH * IW * src_dt_size;
            const dim_t dst_off = (mb * C + c) * OD * OH * OW * dst_dt_size;

            jit_resampling_call_s args = jit_resampling_call_s();
            args.batch_of_sp_points_to_process = OW * OH * OD;
            args.src = src + src_off;
            args.dst = dst + dst_off;
            args.dst_orig = dst;
            args.indices = &indices_[0];
            args.weights = &weights_[0];
            args.post_ops_binary_rhs_arg_vec = post_ops_args.data();
            args.c_offset = static_cast<size_t>(c);

            (*kernel_)(&args);
        });
    } else if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::nspc
            || pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) {
        // Section bases inside the flat per-axis buffers: w data is
        // interleaved at the front, followed by the h (top/bottom) and
        // d (front/back) corner planes - see fill_data_for_linear.
        const unsigned *indices_top = &indices_[2 * OW];
        const unsigned *indices_bottom = &indices_[2 * OW + OH];
        const unsigned *indices_front = &indices_[2 * (OW + OH)];
        const unsigned *indices_back = &indices_[2 * (OW + OH) + OD];
        const float *weights_top = &weights_[2 * OW];
        const float *weights_bottom = &weights_[2 * OW + OH];
        const float *weights_front = &weights_[2 * (OW + OH)];
        const float *weights_back = &weights_[2 * (OW + OH) + OD];

        parallel_nd(nsp_outer, OD, OH, [&](dim_t nsp, dim_t od, dim_t oh) {
            const dim_t src_off
                    = nsp * ID * IH * IW * inner_stride * src_dt_size;
            const dim_t dst_off = (((nsp * OD + od) * OH + oh) * OW)
                    * inner_stride * dst_dt_size;

            // Channel-block index within the current batch element.
            const size_t cb = std::div(nsp, CB).rem;

            jit_resampling_call_s args = jit_resampling_call_s();
            args.batch_of_sp_points_to_process = OW;
            args.src = src + src_off;
            args.dst = dst + dst_off;
            args.dst_orig = dst;
            args.indices = &indices_[0];
            args.weights = &weights_[0];
            args.post_ops_binary_rhs_arg_vec = post_ops_args.data();
            args.c_offset = static_cast<size_t>(cb * inner_stride);

            // Scalar h/d corner data for this (od, oh) is passed directly;
            // the kernel iterates only along w.
            args.src_offset_front = indices_front[od];
            args.src_offset_back = indices_back[od];
            args.src_offset_top = indices_top[oh];
            args.src_offset_bottom = indices_bottom[oh];
            args.weight_front = weights_front[od];
            args.weight_back = weights_back[od];
            args.weight_top = weights_top[oh];
            args.weight_bottom = weights_bottom[oh];

            (*kernel_)(&args);
        });
    } else {
        assert(!"Invalid memory format kind." );
        return status::invalid_arguments;
    }

    return status::success;
}
552 | |
553 | } // namespace x64 |
554 | } // namespace cpu |
555 | } // namespace impl |
556 | } // namespace dnnl |
557 | |