/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <bitset>

#include "common/c_types_map.hpp"

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_uni_resampling_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;
using namespace format_tag;
using tag_kind = jit_memory_tag_kind_t;

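// GET_OFF(field) resolves to the byte offset of `field` inside the
// jit_resampling_call_s argument structure, so kernel arguments can be
// addressed as ptr[reg_param + GET_OFF(field)].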
#define GET_OFF(field) offsetof(jit_resampling_call_s, field)

template <cpu_isa_t isa, typename Vmm>
jit_uni_resampling_kernel_t<isa, Vmm>::jit_uni_resampling_kernel_t(
        const jit_resampling_conf_t &conf, const memory_desc_t *dst_md)
    : jit_uni_resampling_kernel_base_t(conf)
    , tail_size_(calculate_tail_size())
    , io_(this, conf_.isa, {conf_.src_data_type, conf_.dst_data_type},
              {can_movntps_be_used()},
              io::io_tail_conf_t {simd_w_, tail_size_, k_tail_mask_,
                      vmm_tail_mask_.getIdx(), reg_tmp_},
              io::io_emu_bf16_conf_t {vmm_bf16_emu_1_, vmm_bf16_emu_2_,
                      vmm_bf16_emu_3_, reg_tmp_, vmm_bf16_emu_4_},
              create_saturation_vmm_map(),
              io::io_gather_conf_t {simd_w_, k_full_mask_,
                      vmm_full_mask_.getIdx(), reg_tmp_, reg_tmp1_,
                      vmm_tmp_gather_.getIdx()}) {
    if (conf_.with_postops) {
        memory_desc_wrapper dst_d = memory_desc_wrapper(*dst_md);

        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        const binary_injector::rhs_arg_static_params_t rhs_sp {
                static_cast<size_t>(vmm_post_op_helper_.getIdx()), r14, r15,
                r13, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), dst_d,
                tail_size_, k_tail_mask_, use_exact_tail_scalar_bcast};

        const bcast_set_t accepted_broadcasts
                = {broadcasting_strategy_t::scalar,
                        broadcasting_strategy_t::per_oc,
                        broadcasting_strategy_t::per_oc_spatial};
        const binary_injector::static_params_t bsp {
                reg_param, accepted_broadcasts, rhs_sp};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<isa, Vmm>>(
                this, conf_.post_ops, bsp);

        std::tie(any_binary_postop_is_per_oc_bcast_type_,
                any_binary_postop_is_per_oc_sp_bcast_type_)
                = binary_injector_utils::bcast_strategies_present_tup(
                        conf_.post_ops.entry_, dst_d,
                        broadcasting_strategy_t::per_oc,
                        broadcasting_strategy_t::per_oc_spatial);
    }
}

template <cpu_isa_t isa, typename Vmm>
bool jit_uni_resampling_kernel_t<isa, Vmm>::can_movntps_be_used() const {
    const std::size_t alignment = simd_w_ * conf_.dst_dt_size;

    assert(alignment > 0 && conf_.dst_dt_size > 0
            && "Incorrect output data type size.");

    bool are_data_filling_register_fully = false;
    switch (conf_.dst_data_type) {
        case data_type::f32:
        case data_type::s32: are_data_filling_register_fully = true; break;
        case data_type::f16:
        case data_type::bf16:
            are_data_filling_register_fully = is_xmm_ ? false : true;
            break;
        case data_type::s8:
        case data_type::u8:
            are_data_filling_register_fully = is_zmm_ ? true : false;
            break;
        default: assert(!"Unsupported data type."); break;
    }

    // movntps can be used when:
    // 1) There is no tail, because movntps cannot store data with a mask.
    //    The blocked format is an exception to this rule: it has a padded
    //    area, so the io operations never need masks there.
    // 2) The data fills the register completely. Example: a Zmm register can
    //    hold sixteen f32 values, so sixteen values can be computed at a time,
    //    but when those sixteen values are stored as i8 they occupy only an
    //    Xmm register. With a Ymm, eight f32 values can be held, and the
    //    corresponding i8 data is only 64 bits, which does not fill Xmm, Ymm,
    //    or Zmm.
    // 3) The memory operand must be aligned on a 16-byte (128-bit version),
    //    32-byte (VEX.256 encoded version) or 64-byte (EVEX.512 encoded
    //    version) boundary, otherwise a general-protection exception (#GP)
    //    is generated.
    // 4) The instruction is supported and the register is fully filled with
    //    data.
    // 5) The data is big enough to profit from non-temporal stores.
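    // Worked example of the alignment check below (values chosen for
    // illustration, assuming an avx512_core build with an f32 destination):
    // simd_w_ = 16 and dst_dt_size = 4, so alignment = 64 bytes; non-temporal
    // stores are considered only if output_data_size is a multiple of 64 and
    // there is no tail (or the format is blocked).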
    bool can_use_movntps = false;
    if (is_superset(conf_.isa, avx512_core) || conf_.dst_dt_size % 4 == 0)
        can_use_movntps = conf_.is_data_size_bigger_than_L3
                && are_data_filling_register_fully
                && conf_.output_data_size % alignment == 0
                && (tail_size_ == 0 || conf_.tag_kind == tag_kind::blocked);

    return can_use_movntps;
}

template <cpu_isa_t isa, typename Vmm>
std::size_t jit_uni_resampling_kernel_t<isa, Vmm>::calculate_tail_size() const {
    std::size_t tail_size = 0;

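    // The tail is the number of elements left after processing full
    // simd_w_-wide vectors: counted over channels for the nspc and blocked
    // layouts and over output spatial points for the ncsp layout.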
    if (conf_.tag_kind == tag_kind::nspc
            || conf_.tag_kind == tag_kind::blocked) {
        tail_size = conf_.c % simd_w_;
    } else if (conf_.tag_kind == tag_kind::ncsp) {
        if (conf_.alg == alg_kind::resampling_nearest)
            tail_size = conf_.ow % simd_w_;
        else
            tail_size = (conf_.od * conf_.oh * conf_.ow) % simd_w_;
    } else
        assert(!"Incorrect memory tag passed to resampling primitive.");

    return tail_size;
}

template <cpu_isa_t isa, typename Vmm>
int jit_uni_resampling_kernel_t<isa, Vmm>::get_channels_to_compute_without_tail(
        const bool is_tail_in_blocked_format) const {
    assert(utils::one_of(conf_.tag_kind, tag_kind::blocked, tag_kind::nspc)
            && "Incorrect memory tag.");

    int c_to_compute_without_tail = 0;

    if (conf_.tag_kind == tag_kind::blocked && is_tail_in_blocked_format) {
        // Example:
        // c = 27
        // c_block = 16
        // simd_w = 4
        // result = ((27 % 16) / 4) * 4 = (11 / 4) * 4 = 2 * 4 = 8
        c_to_compute_without_tail
                = ((conf_.c % conf_.inner_stride) / simd_w_) * simd_w_;
    } else
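        // Without a blocked-format tail the whole inner block is covered by
        // full vectors, e.g. inner_stride = 16, simd_w = 4
        // -> (16 / 4) * 4 = 16.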
        c_to_compute_without_tail = (conf_.inner_stride / simd_w_) * simd_w_;

    return c_to_compute_without_tail;
}

template <cpu_isa_t isa, typename Vmm>
std::map<data_type_t, io::io_saturation_conf_t>
jit_uni_resampling_kernel_t<isa, Vmm>::create_saturation_vmm_map() const {

    std::map<data_type_t, io::io_saturation_conf_t> saturation_map {};

    if (conf_.is_saturation_needed)
        saturation_map.emplace(conf_.dst_data_type,
                io::io_saturation_conf_t {vmm_zero_saturation_.getIdx(),
                        vmm_saturation_ubound_.getIdx(), reg_tmp_});

    return saturation_map;
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa,
        Vmm>::get_params_for_linear_in_c_oriented_format() {
    mov(reg_src_ftl_, ptr[reg_param + GET_OFF(src)]);
    add(reg_src_ftl_, ptr[reg_param + GET_OFF(src_offset_front)]);
    add(reg_src_ftl_, ptr[reg_param + GET_OFF(src_offset_top)]);
    mov(reg_src_ftr_, reg_src_ftl_);

    if (conf_.ndims == 4 || conf_.ndims == 5) {
        uni_vbroadcastss(weight_top_, ptr[reg_param + GET_OFF(weight_top)]);
        uni_vbroadcastss(
                weight_bottom_, ptr[reg_param + GET_OFF(weight_bottom)]);
        mov(reg_src_fbl_, ptr[reg_param + GET_OFF(src)]);
        add(reg_src_fbl_, ptr[reg_param + GET_OFF(src_offset_front)]);
        add(reg_src_fbl_, ptr[reg_param + GET_OFF(src_offset_bottom)]);
        mov(reg_src_fbr_, reg_src_fbl_);
    }
    if (conf_.ndims == 5) {
        uni_vbroadcastss(weight_front_, ptr[reg_param + GET_OFF(weight_front)]);
        uni_vbroadcastss(weight_back_, ptr[reg_param + GET_OFF(weight_back)]);
        mov(reg_src_btl_, ptr[reg_param + GET_OFF(src)]);
        add(reg_src_btl_, ptr[reg_param + GET_OFF(src_offset_back)]);
        add(reg_src_btl_, ptr[reg_param + GET_OFF(src_offset_top)]);
        mov(reg_src_btr_, reg_src_btl_);

        mov(reg_src_bbl_, ptr[reg_param + GET_OFF(src)]);
        add(reg_src_bbl_, ptr[reg_param + GET_OFF(src_offset_back)]);
        add(reg_src_bbl_, ptr[reg_param + GET_OFF(src_offset_bottom)]);
        mov(reg_src_bbr_, reg_src_bbl_);
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::preserve_zero_padding_in_post_ops(
        const int data_idx) {
    Vmm vmm_data(data_idx);
    const Vmm vmm_zeros(vmm_tmp_.getIdx());

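    // Keep the tail lanes of vmm_data and reset all remaining lanes to zero,
    // so that eltwise post-ops do not overwrite the zero padding of the
    // blocked layout.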
    uni_vxorps(vmm_zeros, vmm_zeros, vmm_zeros);
    if (is_superset(conf_.isa, avx512_core))
        vblendmps(vmm_data | k_tail_mask_, vmm_zeros, vmm_data);
    else {
        std::bitset<8> tail_mask((1 << tail_size_) - 1);
        tail_mask.flip();
        uni_vblendps(vmm_data, vmm_data, vmm_zeros, tail_mask.to_ulong());
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::apply_sum(
        const int data_idx, const bool is_tail) {
    if (conf_.with_sum) {
        assert(!conf_.sum_scales.empty()
                && "No scales for sum post operation.");
        const auto sum_injector = [this, data_idx, is_tail]() {
            const Vmm vmm_prev_dst(vmm_tmp_.getIdx());
            const Vmm vmm_dst(data_idx);

            // Zeroing previous dst is needed to preserve zero padding.
            if (is_tail && conf_.tag_kind == tag_kind::blocked)
                uni_vxorps(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst);

            io_.at(conf_.dst_data_type)
                    ->load(ptr[reg_dst_], vmm_prev_dst, is_tail);
            const float sum_scale = sum_scales_.front();
            if (sum_scale == 1.f)
                uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
            else {
                const Xmm xmm_sum_scale = Xmm(vmm_sum_scale_.getIdx());

                // For the linear algorithm on 5D shapes there are not enough
                // gpr registers to dedicate one as a temporary. Therefore,
                // when a temporary is needed, its current value has to be
                // saved and restored after all required operations are done.
                if (conf_.alg == alg_kind::resampling_linear
                        && conf_.ndims == 5)
                    push(reg_tmp1_);
                mov(reg_tmp1_.cvt32(), float2int(sum_scale));
                uni_vmovd(xmm_sum_scale, reg_tmp1_.cvt32());
                if (conf_.alg == alg_kind::resampling_linear
                        && conf_.ndims == 5)
                    pop(reg_tmp1_);
                uni_vbroadcastss(vmm_sum_scale_, xmm_sum_scale);
                uni_vfmadd231ps(vmm_dst, vmm_prev_dst, vmm_sum_scale_);
            }
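            // Rotate the scales queue so that a subsequent sum post-op (if
            // any) picks up its own scale.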
            sum_scales_.push(sum_scale);
            sum_scales_.pop();
        };
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::apply_postops(
        const int data_idx, const bool is_tail, const Reg64 *reg_c) {
    binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
    const bool is_preserving_zero_padding_needed = is_tail && conf_.with_eltwise
            && conf_.tag_kind == tag_kind::blocked;
    bool update_c_offset = false;

    if (conf_.with_sum) apply_sum(data_idx, is_tail);

    if (conf_.with_binary) {
        if (any_binary_postop_is_per_oc_bcast_type_
                || any_binary_postop_is_per_oc_sp_bcast_type_) {
            rhs_arg_params.vmm_idx_to_out_reg.emplace(data_idx, reg_dst_);
        }
        if (is_tail) { rhs_arg_params.vmm_tail_idx_.emplace(data_idx); }
    }

    if (update_c_offset) add(reg_c_offset, *reg_c);
    postops_injector_->compute_vector(data_idx, rhs_arg_params);
    if (is_preserving_zero_padding_needed)
        preserve_zero_padding_in_post_ops(data_idx);
    if (update_c_offset) sub(reg_c_offset, *reg_c);
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::preserve_zero_padding(
        const int c_to_compute_without_tail, const bool is_tail) {
    const int c_to_compute_with_tail
            = is_tail ? utils::rnd_up(tail_size_, simd_w_) : 0;
    const int c_to_zeroing = conf_.inner_stride - c_to_compute_without_tail
            - c_to_compute_with_tail;
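    // Example of the arithmetic above (values chosen for illustration):
    // inner_stride = 16, simd_w = 4, c_to_compute_without_tail = 8, tail
    // rounded up to 4 -> c_to_zeroing = 16 - 8 - 4 = 4, i.e. one simd_w-wide
    // vector of zeros is stored.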

    if (c_to_zeroing > 0) {
        assert(c_to_zeroing % simd_w_ == 0);
        const Vmm vmm_zeros(vmm_tmp_.getIdx());

        for (int c = 0; c < c_to_zeroing; c += simd_w_) {
            uni_vxorps(vmm_zeros, vmm_zeros, vmm_zeros);
            const auto dst_address = ptr[reg_dst_ + c * conf_.dst_dt_size];
            io_.at(conf_.dst_data_type)->store(vmm_zeros, dst_address, false);
        }

        add(reg_dst_, c_to_zeroing * conf_.dst_dt_size);
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::interpolate_c_oriented_format(
        const c_oriented_generation_fn_t &generation_fn) {
    const unsigned c_with_padding = utils::rnd_up(conf_.c, conf_.inner_stride);
    const unsigned padding_size_to_preserve = c_with_padding - conf_.c;

    if (padding_size_to_preserve > 0 && conf_.tag_kind == tag_kind::blocked) {
        Label tail_label;
        Label end_label;
        cmp(reg_c_offset, utils::rnd_dn(conf_.c, conf_.inner_stride));
        je(tail_label, T_NEAR);
        generation_fn(false /*is_tail_in_blocked_format*/);
        jmp(end_label, T_NEAR);
        L(tail_label);
        generation_fn(true /*is_tail_in_blocked_format*/);
        L(end_label);
    } else {
        generation_fn(false /*is_tail_in_blocked_format*/);
    }
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::nearest_ncsp_format() {
    const Reg64 &reg_indices_h = reg_aux_src_0_;
    const Reg64 &reg_indices_w = reg_aux_src_1_;
    const Reg64 &reg_src_shifted = reg_aux_src_2_;
    const Reg64 &reg_oh = reg_tmp1_;

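    // The indices buffer holds conf_.oh precomputed source offsets for the
    // output rows followed by conf_.ow offsets for the output columns;
    // reg_indices_h walks the first part while reg_indices_w is advanced past
    // it below.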
    auto nearest_interpolation = ([&](bool is_tail) {
        uni_vmovdqu(vmm_indices_, ptr[reg_indices_w]);
        io_.at(conf_.src_data_type)
                ->gather(reg_src_shifted, vmm_indices_, vmm_src_, is_tail);
        if (conf_.with_postops) apply_postops(vmm_src_.getIdx(), is_tail);
        io_.at(conf_.dst_data_type)->store(vmm_src_, ptr[reg_dst_], is_tail);
    });

    mov(reg_indices_h, reg_indices_);
    mov(reg_indices_w, reg_indices_);
    add(reg_indices_w, conf_.oh * conf_.el_size_of_indices);

    Label oh_loop_begin, oh_loop_end;
    Label ow_loop_begin, ow_loop_end;
    xor_(reg_oh, reg_oh);

    L(oh_loop_begin);
    {
        cmp(reg_oh, conf_.oh);
        jge(oh_loop_end, T_NEAR);
        push(reg_oh);

        mov(reg_work_, conf_.ow);
        mov(reg_src_shifted, reg_src_);
        xor_(reg_tmp_, reg_tmp_);
        mov(reg_tmp_.cvt32(), dword[reg_indices_h]);
        add(reg_src_shifted, reg_tmp_);

        push(reg_indices_w);

        L(ow_loop_begin);
        {
            cmp(reg_work_, simd_w_);
            jl(ow_loop_end, T_NEAR);

            nearest_interpolation(false);

            add(reg_dst_, simd_w_ * conf_.dst_dt_size);
            add(reg_indices_w, simd_w_ * conf_.el_size_of_indices);
            sub(reg_work_, simd_w_);

            jmp(ow_loop_begin, T_NEAR);
        }
        L(ow_loop_end);

        if (tail_size_ > 0) {
            nearest_interpolation(true);
            add(reg_dst_, tail_size_ * conf_.dst_dt_size);
        }

        add(reg_indices_h, conf_.el_size_of_indices);
        pop(reg_indices_w);
        pop(reg_oh);
        add(reg_oh, 1);
        jmp(oh_loop_begin);
    }
    L(oh_loop_end);
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::nearest_c_oriented_format(
        const bool is_tail_in_blocked_format) {
    const int c_to_compute_without_tail
            = get_channels_to_compute_without_tail(is_tail_in_blocked_format);
    const bool insert_tail_processing_code
            = (conf_.tag_kind == tag_kind::nspc && tail_size_ > 0)
            || is_tail_in_blocked_format;

    const Reg64 &reg_c = reg_tmp_;
    const Reg64 &reg_src_shifted = reg_aux_src_0_;

    auto nearest_interpolation = [&](const bool is_tail) {
        const bool load_and_store_with_tail
                = is_tail && conf_.tag_kind == tag_kind::nspc;

        io_.at(conf_.src_data_type)
                ->load(ptr[reg_src_shifted], vmm_src_,
                        load_and_store_with_tail);
        if (conf_.with_postops)
            apply_postops(vmm_src_.getIdx(), is_tail, &reg_c);
        io_.at(conf_.dst_data_type)
                ->store(vmm_src_, ptr[reg_dst_], load_and_store_with_tail);
    };

    Label loop_begin, loop_end;

    L(loop_begin);
    {
        cmp(reg_work_, 1);
        jl(loop_end, T_NEAR);

        mov(reg_src_shifted, reg_src_);
        mov(reg_tmp1_.cvt32(), dword[reg_indices_]);
        add(reg_src_shifted, reg_tmp1_);

        Label c_loop_begin, c_loop_end;
        xor_(reg_c, reg_c);
        L(c_loop_begin);
        {
            cmp(reg_c, c_to_compute_without_tail);
            je(c_loop_end, T_NEAR);

            nearest_interpolation(false);

            add(reg_src_shifted, simd_w_ * conf_.src_dt_size);
            add(reg_dst_, simd_w_ * conf_.dst_dt_size);

            add(reg_c, simd_w_);
            jmp(c_loop_begin, T_NEAR);
        }
        L(c_loop_end);

        if (insert_tail_processing_code) {
            if (tail_size_ > 0) {
                nearest_interpolation(true);
                if (conf_.tag_kind == tag_kind::nspc)
                    add(reg_dst_, tail_size_ * conf_.dst_dt_size);
                else if (conf_.tag_kind == tag_kind::blocked) {
                    add(reg_dst_, simd_w_ * conf_.dst_dt_size);
                }
            }

            if (conf_.tag_kind == tag_kind::blocked)
                preserve_zero_padding(
                        c_to_compute_without_tail, is_tail_in_blocked_format);
        }

        add(reg_indices_, conf_.el_size_of_indices);

        dec(reg_work_);
        jmp(loop_begin, T_NEAR);
    }
    L(loop_end);
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::linear_ncsp_format() {
    const unsigned indices_stride
            = conf_.ow * conf_.oh * conf_.od * conf_.el_size_of_indices;
    const unsigned weights_stride
            = conf_.ow * conf_.oh * conf_.od * sizeof(float);

    auto linear_interpolation = [&](const bool is_tail) {
        const Vmm vmm_dst(vmm_idx(0));
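        // vmm_dst aliases the register that receives the first corner, so the
        // weighted sum of all corners is accumulated directly into it below.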

        for (unsigned i = 0; i < conf_.number_of_corners; i++) {
            uni_vmovdqu(vmm_indices_, ptr[reg_indices_ + i * indices_stride]);
            io_.at(conf_.src_data_type)
                    ->gather(reg_src_, vmm_indices_, Vmm(vmm_idx(i)), is_tail);
        }

        uni_vmovups(vmm_weights_, ptr[reg_weights]);
        uni_vmulps(vmm_dst, vmm_dst, vmm_weights_);
        for (unsigned i = 1; i < conf_.number_of_corners; i++) {
            uni_vmovups(vmm_weights_, ptr[reg_weights + i * weights_stride]);
            uni_vfmadd231ps(vmm_dst, Vmm(vmm_idx(i)), vmm_weights_);
        }

        if (conf_.with_postops) apply_postops(vmm_idx(0), is_tail);

        if (conf_.is_saturation_needed && conf_.ndims == 5
                && !is_superset(conf_.isa, avx512_core)) {
            // When saturation is needed for a 5D shape and only 16 Vmm
            // registers are available, there are no free registers to keep
            // the saturation values, so the saturation initialization has to
            // be repeated before every store operation.
            io_.init_saturate_f32({conf_.dst_data_type});
        }

        io_.at(conf_.dst_data_type)->store(vmm_dst, ptr[reg_dst_], is_tail);
    };

    Label loop_begin, loop_end;

    L(loop_begin);
    {
        cmp(reg_work_, simd_w_);
        jl(loop_end, T_NEAR);

        linear_interpolation(false);

        add(reg_dst_, simd_w_ * conf_.dst_dt_size);
        add(reg_weights, simd_w_ * sizeof(float));
        add(reg_indices_, simd_w_ * conf_.el_size_of_indices);
        sub(reg_work_, simd_w_);

        jmp(loop_begin, T_NEAR);
    }
    L(loop_end);

    if (tail_size_ > 0) linear_interpolation(true);
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::linear_c_oriented_format(
        const bool is_tail_in_blocked_format) {
    const int c_to_compute_without_tail
            = get_channels_to_compute_without_tail(is_tail_in_blocked_format);
    const bool insert_tail_processing_code
            = (conf_.tag_kind == tag_kind::nspc && tail_size_ > 0)
            || is_tail_in_blocked_format;

    const Reg64 &reg_c = reg_tmp_;
    const Reg64 &reg_index_left = reg_tmp_;
    const Reg64 &reg_index_right = reg_tmp_;
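    // reg_c, reg_index_left and reg_index_right all alias reg_tmp_; they are
    // never live at the same time, so a single gpr is enough for all three.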

    const std::vector<std::reference_wrapper<const Reg64>> src_regs
            = {reg_src_ftl_, reg_src_ftr_, reg_src_fbl_, reg_src_fbr_,
                    reg_src_btl_, reg_src_btr_, reg_src_bbl_, reg_src_bbr_};
    const std::vector<std::reference_wrapper<const Vmm>> src_vmms
            = {src_ftl_, src_ftr_, src_fbl_, src_fbr_, src_btl_, src_btr_,
                    src_bbl_, src_bbr_};

    assert(src_regs.size() >= conf_.number_of_corners
            && src_vmms.size() >= conf_.number_of_corners);

    auto linear_interpolation = [&](const Reg64 &reg_c, const bool is_tail) {
        const bool load_and_store_with_tail
                = is_tail && conf_.tag_kind == tag_kind::nspc;

        for (unsigned i = 0; i < conf_.number_of_corners; i++) {
            io_.at(conf_.src_data_type)
                    ->load(ptr[src_regs[i].get()], src_vmms[i].get(),
                            load_and_store_with_tail);
        }

        // w_d[0]*(w_h[0]*(src[0][0][0]*w_w[0] + src[0][0][1]*w_w[1]) +
        //         w_h[1]*(src[0][1][0]*w_w[0] + src[0][1][1]*w_w[1]))
        // +
        // w_d[1]*(w_h[0]*(src[1][0][0]*w_w[0] + src[1][0][1]*w_w[1]) +
        //         w_h[1]*(src[1][1][0]*w_w[0] + src[1][1][1]*w_w[1]))
        uni_vmulps(src_ftl_, src_ftl_, weight_left_);
        uni_vfmadd231ps(src_ftl_, src_ftr_, weight_right_);
        if (conf_.ndims == 4 || conf_.ndims == 5) {
            uni_vmulps(src_fbl_, src_fbl_, weight_left_);
            uni_vfmadd231ps(src_fbl_, src_fbr_, weight_right_);
            uni_vmulps(src_ftl_, src_ftl_, weight_top_);
            uni_vfmadd231ps(src_ftl_, src_fbl_, weight_bottom_);
        }
        if (conf_.ndims == 5) {
            uni_vmulps(src_btl_, src_btl_, weight_left_);
            uni_vfmadd231ps(src_btl_, src_btr_, weight_right_);
            uni_vmulps(src_bbl_, src_bbl_, weight_left_);
            uni_vfmadd231ps(src_bbl_, src_bbr_, weight_right_);
            uni_vmulps(src_btl_, src_btl_, weight_top_);
            uni_vfmadd231ps(src_btl_, src_bbl_, weight_bottom_);
            uni_vmulps(src_ftl_, src_ftl_, weight_front_);
            uni_vfmadd231ps(src_ftl_, src_btl_, weight_back_);
        }

        if (conf_.with_postops)
            apply_postops(src_ftl_.getIdx(), is_tail, &reg_c);

        if (conf_.is_saturation_needed && conf_.ndims == 5
                && !is_superset(conf_.isa, avx512_core)) {
            // When saturation is needed for a 5D shape and only 16 Vmm
            // registers are available, there are no free registers to keep
            // the saturation values, so the saturation initialization has to
            // be repeated before every store operation.
            push(reg_tmp_);
            io_.init_saturate_f32({conf_.dst_data_type});
            pop(reg_tmp_);
        }

        io_.at(conf_.dst_data_type)
                ->store(src_ftl_, ptr[reg_dst_], load_and_store_with_tail);
    };

    xor_(reg_index_left, reg_index_left);

    Label loop_begin, loop_end;
    L(loop_begin);
    {
        cmp(reg_work_, 1);
        jl(loop_end, T_NEAR);

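        // The corner pointers are advanced inside the channel loop below, so
        // save them here and restore them afterwards (popped in reverse
        // order).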
        for (unsigned i = 0; i < conf_.number_of_corners; i++) {
            push(src_regs[i]);
        }

        mov(reg_index_left.cvt32(), dword[reg_indices_]);
        for (unsigned i = 0; i < conf_.number_of_corners / 2; i++) {
            add(src_regs[2 * i], reg_index_left);
        }
        mov(reg_index_right.cvt32(),
                dword[reg_indices_ + conf_.el_size_of_indices]);
        for (unsigned i = 0; i < conf_.number_of_corners / 2; i++) {
            add(src_regs[2 * i + 1], reg_index_right);
        }

        uni_vbroadcastss(weight_left_, ptr[reg_weights]);
        uni_vbroadcastss(weight_right_, ptr[reg_weights + sizeof(float)]);

        Label c_loop_begin, c_loop_end;
        xor_(reg_c, reg_c);
        L(c_loop_begin);
        {
            cmp(reg_c, c_to_compute_without_tail);
            je(c_loop_end, T_NEAR);

            linear_interpolation(reg_c, false);
            add(reg_dst_, simd_w_ * conf_.dst_dt_size);

            for (unsigned i = 0; i < conf_.number_of_corners; i++)
                add(src_regs[i], simd_w_ * conf_.src_dt_size);

            add(reg_c, simd_w_);
            jmp(c_loop_begin, T_NEAR);
        }
        L(c_loop_end);

        if (insert_tail_processing_code) {
            if (tail_size_ > 0) {
                linear_interpolation(reg_c, true);
                if (conf_.tag_kind == tag_kind::nspc)
                    add(reg_dst_, tail_size_ * conf_.dst_dt_size);
                else if (conf_.tag_kind == tag_kind::blocked) {
                    add(reg_dst_, simd_w_ * conf_.dst_dt_size);
                }
            }

            if (conf_.tag_kind == tag_kind::blocked)
                preserve_zero_padding(
                        c_to_compute_without_tail, is_tail_in_blocked_format);
        }

        // Each loop iteration reads two values (for the left and right
        // corners) from both the weights and the indices tables. The two
        // values lie next to each other in memory, so both addresses are
        // advanced by two elements.
        add(reg_indices_, 2 * conf_.el_size_of_indices);
        add(reg_weights, 2 * sizeof(float));

        for (unsigned i = 0; i < conf_.number_of_corners; i++) {
            pop(src_regs[(conf_.number_of_corners - 1) - i]);
        }

        dec(reg_work_);
        jmp(loop_begin, T_NEAR);
    }
    L(loop_end);
}

template <cpu_isa_t isa, typename Vmm>
void jit_uni_resampling_kernel_t<isa, Vmm>::generate() {
    preamble();

    io_.init_bf16();
    if (conf_.is_saturation_needed)
        io_.init_saturate_f32({conf_.dst_data_type});
    // For the blocked format the tail mask needs to be prepared only when
    // post-ops are used, because then there is a chance that the zero padding
    // would not be preserved.
    if (tail_size_ > 0
            && (conf_.tag_kind != tag_kind::blocked || conf_.with_postops))
        io_.prepare_tail_mask();
    if (is_superset(conf_.isa, avx2) && conf_.tag_kind == tag_kind::ncsp) {
        io_.init_full_mask();
        io_.prepare_full_mask();
    }

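    // Load the arguments that are common to all algorithms from the
    // jit_resampling_call_s structure pointed to by reg_param.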
    mov(reg_dst_, ptr[reg_param + GET_OFF(dst)]);
    mov(reg_work_, ptr[reg_param + GET_OFF(batch_of_sp_points_to_process)]);
    mov(reg_indices_, ptr[reg_param + GET_OFF(indices)]);
    mov(reg_c_offset, ptr[reg_param + GET_OFF(c_offset)]);

    if (conf_.alg == alg_kind::resampling_nearest) {
        mov(reg_src_, ptr[reg_param + GET_OFF(src)]);
        if (conf_.tag_kind == tag_kind::ncsp) {
            nearest_ncsp_format();
        } else if (conf_.tag_kind == tag_kind::nspc
                || conf_.tag_kind == tag_kind::blocked) {
            interpolate_c_oriented_format(
                    [&](const bool is_tail_in_blocked_format) {
                        nearest_c_oriented_format(is_tail_in_blocked_format);
                    });
        }
    } else if (conf_.alg == alg_kind::resampling_linear) {
        mov(reg_weights, ptr[reg_param + GET_OFF(weights)]);
        if (conf_.tag_kind == tag_kind::ncsp) {
            mov(reg_src_, ptr[reg_param + GET_OFF(src)]);
            linear_ncsp_format();
        } else if (conf_.tag_kind == tag_kind::nspc
                || conf_.tag_kind == tag_kind::blocked) {
            get_params_for_linear_in_c_oriented_format();
            interpolate_c_oriented_format(
                    [&](const bool is_tail_in_blocked_format) {
                        linear_c_oriented_format(is_tail_in_blocked_format);
                    });
        }
    }

    postamble();

    if (conf_.with_eltwise && postops_injector_)
        postops_injector_->prepare_table();
}

template struct jit_uni_resampling_kernel_t<avx512_core_fp16, Zmm>;
template struct jit_uni_resampling_kernel_t<avx512_core, Zmm>;
template struct jit_uni_resampling_kernel_t<avx512_core, Ymm>;
template struct jit_uni_resampling_kernel_t<avx, Ymm>;
template struct jit_uni_resampling_kernel_t<avx, Xmm>;
template struct jit_uni_resampling_kernel_t<sse41, Xmm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl