1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <bitset> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | |
21 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
22 | #include "cpu/x64/jit_generator.hpp" |
23 | #include "cpu/x64/jit_uni_resampling_kernel.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | using namespace Xbyak; |
31 | using namespace format_tag; |
32 | using tag_kind = jit_memory_tag_kind_t; |
33 | |
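// GET_OFF yields the byte offset of a field within the runtime argument
// structure (jit_resampling_call_s), used to address kernel arguments
// relative to reg_param.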
34 | #define GET_OFF(field) offsetof(jit_resampling_call_s, field) |
35 | |
36 | template <cpu_isa_t isa, typename Vmm> |
37 | jit_uni_resampling_kernel_t<isa, Vmm>::jit_uni_resampling_kernel_t( |
38 | const jit_resampling_conf_t &conf, const memory_desc_t *dst_md) |
39 | : jit_uni_resampling_kernel_base_t(conf) |
40 | , tail_size_(calculate_tail_size()) |
41 | , io_(this, conf_.isa, {conf_.src_data_type, conf_.dst_data_type}, |
42 | {can_movntps_be_used()}, |
43 | io::io_tail_conf_t {simd_w_, tail_size_, k_tail_mask_, |
44 | vmm_tail_mask_.getIdx(), reg_tmp_}, |
45 | io::io_emu_bf16_conf_t {vmm_bf16_emu_1_, vmm_bf16_emu_2_, |
46 | vmm_bf16_emu_3_, reg_tmp_, vmm_bf16_emu_4_}, |
47 | create_saturation_vmm_map(), |
48 | io::io_gather_conf_t {simd_w_, k_full_mask_, |
49 | vmm_full_mask_.getIdx(), reg_tmp_, reg_tmp1_, |
50 | vmm_tmp_gather_.getIdx()}) { |
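    // The binary post-ops injector needs the destination descriptor to
    // resolve broadcast strategies, so it is created only when post-ops
    // are actually present.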
51 | if (conf_.with_postops) { |
52 | memory_desc_wrapper dst_d = memory_desc_wrapper(*dst_md); |
53 | |
54 | static constexpr bool preserve_gpr = true; |
55 | static constexpr bool preserve_vmm = false; |
56 | static constexpr bool use_exact_tail_scalar_bcast = true; |
57 | |
58 | const binary_injector::rhs_arg_static_params_t rhs_sp { |
59 | static_cast<size_t>(vmm_post_op_helper_.getIdx()), r14, r15, |
60 | r13, preserve_gpr, preserve_vmm, |
61 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), dst_d, |
62 | tail_size_, k_tail_mask_, use_exact_tail_scalar_bcast}; |
63 | |
64 | const bcast_set_t accepted_broadcasts |
65 | = {broadcasting_strategy_t::scalar, |
66 | broadcasting_strategy_t::per_oc, |
67 | broadcasting_strategy_t::per_oc_spatial}; |
68 | const binary_injector::static_params_t bsp { |
69 | reg_param, accepted_broadcasts, rhs_sp}; |
70 | |
71 | postops_injector_ = utils::make_unique< |
72 | injector::jit_uni_postops_injector_t<isa, Vmm>>( |
73 | this, conf_.post_ops, bsp); |
74 | |
75 | std::tie(any_binary_postop_is_per_oc_bcast_type_, |
76 | any_binary_postop_is_per_oc_sp_bcast_type_) |
77 | = binary_injector_utils::bcast_strategies_present_tup( |
78 | conf_.post_ops.entry_, dst_d, |
79 | broadcasting_strategy_t::per_oc, |
80 | broadcasting_strategy_t::per_oc_spatial); |
81 | } |
82 | } |
83 | |
84 | template <cpu_isa_t isa, typename Vmm> |
85 | bool jit_uni_resampling_kernel_t<isa, Vmm>::can_movntps_be_used() const { |
86 | const std::size_t alignment = simd_w_ * conf_.dst_dt_size; |
87 | |
88 | assert(alignment > 0 && conf_.dst_dt_size > 0 |
89 | && "Incorrect output data type size." ); |
90 | |
91 | bool are_data_filling_register_fully = false; |
92 | switch (conf_.dst_data_type) { |
93 | case data_type::f32: |
94 | case data_type::s32: are_data_filling_register_fully = true; break; |
95 | case data_type::f16: |
96 | case data_type::bf16: |
            are_data_filling_register_fully = !is_xmm_;
98 | break; |
99 | case data_type::s8: |
100 | case data_type::u8: |
            are_data_filling_register_fully = is_zmm_;
102 | break; |
103 | default: assert(!"Unsupported data type." ); break; |
104 | } |
105 | |
    // When movntps can be used:
    // 1) There is no tail, because movntps cannot store data under a mask.
    // The blocked format is an exception to this rule: it has a padded area,
    // so its io operations do not use masks there.
    // 2) The data fills the register completely. Example: a Zmm register can
    // hold sixteen f32 values, so sixteen values can be computed at a time,
    // but after conversion to i8 the same sixteen values fit in an Xmm
    // register. With a Ymm register eight f32 values can be held, yet the
    // converted i8 data is only 64 bits wide and fully fills neither a Zmm,
    // a Ymm, nor an Xmm register.
    // 3) The memory operand is aligned on a 16-byte (128-bit version),
    // 32-byte (VEX.256 encoded version) or 64-byte (EVEX.512 encoded version)
    // boundary, otherwise a general-protection exception (#GP) is generated.
    // 4) The store instruction is supported for the given ISA and data type.
    // 5) The data is large enough to profit from non-temporal stores.
120 | bool can_use_movntps = false; |
121 | if (is_superset(conf_.isa, avx512_core) || conf_.dst_dt_size % 4 == 0) |
122 | can_use_movntps = conf_.is_data_size_bigger_than_L3 |
123 | && are_data_filling_register_fully |
124 | && conf_.output_data_size % alignment == 0 |
125 | && (tail_size_ == 0 || conf_.tag_kind == tag_kind::blocked); |
126 | |
127 | return can_use_movntps; |
128 | } |
129 | |
130 | template <cpu_isa_t isa, typename Vmm> |
131 | std::size_t jit_uni_resampling_kernel_t<isa, Vmm>::calculate_tail_size() const { |
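    // The tail is what remains after processing full simd_w_-wide vectors.
    // For nspc and blocked layouts the vectorized dimension is the channel
    // dimension; for ncsp it is the spatial one. Example: c = 17 with
    // simd_w_ = 8 leaves an nspc tail of 1.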
132 | std::size_t tail_size = 0; |
133 | |
134 | if (conf_.tag_kind == tag_kind::nspc |
135 | || conf_.tag_kind == tag_kind::blocked) { |
136 | tail_size = conf_.c % simd_w_; |
137 | } else if (conf_.tag_kind == tag_kind::ncsp) { |
138 | if (conf_.alg == alg_kind::resampling_nearest) |
139 | tail_size = conf_.ow % simd_w_; |
140 | else |
141 | tail_size = (conf_.od * conf_.oh * conf_.ow) % simd_w_; |
142 | } else |
143 | assert(!"Incorrect memory tag passed to resampling primitive." ); |
144 | |
145 | return tail_size; |
146 | } |
147 | |
148 | template <cpu_isa_t isa, typename Vmm> |
149 | int jit_uni_resampling_kernel_t<isa, Vmm>::get_channels_to_compute_without_tail( |
150 | const bool is_tail_in_blocked_format) const { |
151 | assert(utils::one_of(conf_.tag_kind, tag_kind::blocked, tag_kind::nspc) |
152 | && "Incorrect memory tag." ); |
153 | |
154 | int c_to_compute_without_tail = 0; |
155 | |
156 | if (conf_.tag_kind == tag_kind::blocked && is_tail_in_blocked_format) { |
157 | // Example: |
158 | // c = 27 |
159 | // c_block = 16 |
160 | // simd_w = 4 |
161 | // result = ((27 % 16) / 4) * 4 = (11 / 4) * 4 = 2 * 4 = 8 |
162 | c_to_compute_without_tail |
163 | = ((conf_.c % conf_.inner_stride) / simd_w_) * simd_w_; |
164 | } else |
165 | c_to_compute_without_tail = (conf_.inner_stride / simd_w_) * simd_w_; |
166 | |
167 | return c_to_compute_without_tail; |
168 | } |
169 | |
170 | template <cpu_isa_t isa, typename Vmm> |
171 | std::map<data_type_t, io::io_saturation_conf_t> |
172 | jit_uni_resampling_kernel_t<isa, Vmm>::create_saturation_vmm_map() const { |
173 | |
174 | std::map<data_type_t, io::io_saturation_conf_t> saturation_map {}; |
175 | |
176 | if (conf_.is_saturation_needed) |
177 | saturation_map.emplace(conf_.dst_data_type, |
178 | io::io_saturation_conf_t {vmm_zero_saturation_.getIdx(), |
179 | vmm_saturation_ubound_.getIdx(), reg_tmp_}); |
180 | |
181 | return saturation_map; |
182 | } |
183 | |
184 | template <cpu_isa_t isa, typename Vmm> |
185 | void jit_uni_resampling_kernel_t<isa, |
186 | Vmm>::get_params_for_linear_in_c_oriented_format() { |
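    // Corner-register naming: f/b stands for front/back (depth), t/b for
    // top/bottom (height), l/r for left/right (width). Each register is set
    // to src shifted by the depth and height offsets of its corner; the
    // width shift is applied later from the indices table.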
187 | mov(reg_src_ftl_, ptr[reg_param + GET_OFF(src)]); |
188 | add(reg_src_ftl_, ptr[reg_param + GET_OFF(src_offset_front)]); |
189 | add(reg_src_ftl_, ptr[reg_param + GET_OFF(src_offset_top)]); |
190 | mov(reg_src_ftr_, reg_src_ftl_); |
191 | |
192 | if (conf_.ndims == 4 || conf_.ndims == 5) { |
193 | uni_vbroadcastss(weight_top_, ptr[reg_param + GET_OFF(weight_top)]); |
194 | uni_vbroadcastss( |
195 | weight_bottom_, ptr[reg_param + GET_OFF(weight_bottom)]); |
196 | mov(reg_src_fbl_, ptr[reg_param + GET_OFF(src)]); |
197 | add(reg_src_fbl_, ptr[reg_param + GET_OFF(src_offset_front)]); |
198 | add(reg_src_fbl_, ptr[reg_param + GET_OFF(src_offset_bottom)]); |
199 | mov(reg_src_fbr_, reg_src_fbl_); |
200 | } |
201 | if (conf_.ndims == 5) { |
202 | uni_vbroadcastss(weight_front_, ptr[reg_param + GET_OFF(weight_front)]); |
203 | uni_vbroadcastss(weight_back_, ptr[reg_param + GET_OFF(weight_back)]); |
204 | mov(reg_src_btl_, ptr[reg_param + GET_OFF(src)]); |
205 | add(reg_src_btl_, ptr[reg_param + GET_OFF(src_offset_back)]); |
206 | add(reg_src_btl_, ptr[reg_param + GET_OFF(src_offset_top)]); |
207 | mov(reg_src_btr_, reg_src_btl_); |
208 | |
209 | mov(reg_src_bbl_, ptr[reg_param + GET_OFF(src)]); |
210 | add(reg_src_bbl_, ptr[reg_param + GET_OFF(src_offset_back)]); |
211 | add(reg_src_bbl_, ptr[reg_param + GET_OFF(src_offset_bottom)]); |
212 | mov(reg_src_bbr_, reg_src_bbl_); |
213 | } |
214 | } |
215 | |
216 | template <cpu_isa_t isa, typename Vmm> |
217 | void jit_uni_resampling_kernel_t<isa, Vmm>::preserve_zero_padding_in_post_ops( |
218 | const int data_idx) { |
219 | Vmm vmm_data(data_idx); |
220 | const Vmm vmm_zeros(vmm_tmp_.getIdx()); |
221 | |
222 | uni_vxorps(vmm_zeros, vmm_zeros, vmm_zeros); |
223 | if (is_superset(conf_.isa, avx512_core)) |
224 | vblendmps(vmm_data | k_tail_mask_, vmm_zeros, vmm_data); |
225 | else { |
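        // Build an immediate blend mask with the padded lanes set, so
        // uni_vblendps takes zeros for the padding while keeping the
        // computed values in the tail lanes.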
226 | std::bitset<8> tail_mask((1 << tail_size_) - 1); |
227 | tail_mask.flip(); |
228 | uni_vblendps(vmm_data, vmm_data, vmm_zeros, tail_mask.to_ulong()); |
229 | } |
230 | } |
231 | |
232 | template <cpu_isa_t isa, typename Vmm> |
233 | void jit_uni_resampling_kernel_t<isa, Vmm>::apply_sum( |
234 | const int data_idx, const bool is_tail) { |
235 | if (conf_.with_sum) { |
236 | assert(!conf_.sum_scales.empty() |
237 | && "No scales for sum post operation." ); |
238 | const auto sum_injector = [this, data_idx, is_tail]() { |
239 | const Vmm vmm_prev_dst(vmm_tmp_.getIdx()); |
240 | const Vmm vmm_dst(data_idx); |
241 | |
242 | // Zeroing previous dst is needed to preserve zero padding. |
243 | if (is_tail && conf_.tag_kind == tag_kind::blocked) |
244 | uni_vxorps(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst); |
245 | |
246 | io_.at(conf_.dst_data_type) |
247 | ->load(ptr[reg_dst_], vmm_prev_dst, is_tail); |
248 | const float sum_scale = sum_scales_.front(); |
249 | if (sum_scale == 1.f) |
250 | uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst); |
251 | else { |
252 | const Xmm xmm_sum_scale = Xmm(vmm_sum_scale_.getIdx()); |
253 | |
                // For the linear algorithm on a 5D shape there are not
                // enough gpr registers left to claim a temporary one.
                // Therefore, if a tmp register is needed, its contents
                // must be saved first and restored after all required
                // operations have executed.
259 | if (conf_.alg == alg_kind::resampling_linear |
260 | && conf_.ndims == 5) |
261 | push(reg_tmp1_); |
262 | mov(reg_tmp1_.cvt32(), float2int(sum_scale)); |
263 | uni_vmovd(xmm_sum_scale, reg_tmp1_.cvt32()); |
264 | if (conf_.alg == alg_kind::resampling_linear |
265 | && conf_.ndims == 5) |
266 | pop(reg_tmp1_); |
267 | uni_vbroadcastss(vmm_sum_scale_, xmm_sum_scale); |
268 | uni_vfmadd231ps(vmm_dst, vmm_prev_dst, vmm_sum_scale_); |
269 | } |
270 | sum_scales_.push(sum_scale); |
271 | sum_scales_.pop(); |
272 | }; |
273 | postops_injector_->set_lambda_injector( |
274 | primitive_kind::sum, sum_injector); |
275 | } |
276 | } |
277 | |
278 | template <cpu_isa_t isa, typename Vmm> |
279 | void jit_uni_resampling_kernel_t<isa, Vmm>::apply_postops( |
280 | const int data_idx, const bool is_tail, const Reg64 *reg_c) { |
281 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
282 | const bool is_preserving_zero_padding_needed = is_tail && conf_.with_eltwise |
283 | && conf_.tag_kind == tag_kind::blocked; |
284 | bool update_c_offset = false; |
285 | |
286 | if (conf_.with_sum) apply_sum(data_idx, is_tail); |
287 | |
288 | if (conf_.with_binary) { |
289 | if (any_binary_postop_is_per_oc_bcast_type_ |
290 | || any_binary_postop_is_per_oc_sp_bcast_type_) { |
291 | rhs_arg_params.vmm_idx_to_out_reg.emplace(data_idx, reg_dst_); |
292 | } |
293 | if (is_tail) { rhs_arg_params.vmm_tail_idx_.emplace(data_idx); } |
294 | } |
295 | |
296 | if (update_c_offset) add(reg_c_offset, *reg_c); |
297 | postops_injector_->compute_vector(data_idx, rhs_arg_params); |
298 | if (is_preserving_zero_padding_needed) |
299 | preserve_zero_padding_in_post_ops(data_idx); |
300 | if (update_c_offset) sub(reg_c_offset, *reg_c); |
301 | } |
302 | |
303 | template <cpu_isa_t isa, typename Vmm> |
304 | void jit_uni_resampling_kernel_t<isa, Vmm>::preserve_zero_padding( |
305 | const int c_to_compute_without_tail, const bool is_tail) { |
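    // Everything between the last computed channel and the end of the inner
    // block is padding and must be zeroed explicitly to keep the blocked
    // format invariant. Example: c = 27, inner_stride = 16, simd_w = 4 ->
    // 8 channels computed without tail, 4 with tail, 16 - 8 - 4 = 4 left
    // to zero.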
306 | const int c_to_compute_with_tail |
307 | = is_tail ? utils::rnd_up(tail_size_, simd_w_) : 0; |
308 | const int c_to_zeroing = conf_.inner_stride - c_to_compute_without_tail |
309 | - c_to_compute_with_tail; |
310 | |
311 | if (c_to_zeroing > 0) { |
312 | assert(c_to_zeroing % simd_w_ == 0); |
313 | const Vmm vmm_zeros(vmm_tmp_.getIdx()); |
314 | |
315 | for (int c = 0; c < c_to_zeroing; c += simd_w_) { |
316 | uni_vxorps(vmm_zeros, vmm_zeros, vmm_zeros); |
317 | const auto dst_address = ptr[reg_dst_ + c * conf_.dst_dt_size]; |
318 | io_.at(conf_.dst_data_type)->store(vmm_zeros, dst_address, false); |
319 | } |
320 | |
321 | add(reg_dst_, c_to_zeroing * conf_.dst_dt_size); |
322 | } |
323 | } |
324 | |
325 | template <cpu_isa_t isa, typename Vmm> |
326 | void jit_uni_resampling_kernel_t<isa, Vmm>::interpolate_c_oriented_format( |
327 | const c_oriented_generation_fn_t &generation_fn) { |
328 | const unsigned c_with_padding = utils::rnd_up(conf_.c, conf_.inner_stride); |
329 | const unsigned padding_size_to_preserve = c_with_padding - conf_.c; |
330 | |
331 | if (padding_size_to_preserve > 0 && conf_.tag_kind == tag_kind::blocked) { |
332 | Label tail_label; |
333 | Label end_label; |
334 | cmp(reg_c_offset, utils::rnd_dn(conf_.c, conf_.inner_stride)); |
335 | je(tail_label, T_NEAR); |
336 | generation_fn(false /*is_tail_in_blocked_format*/); |
337 | jmp(end_label, T_NEAR); |
338 | L(tail_label); |
339 | generation_fn(true /*is_tail_in_blocked_format*/); |
340 | L(end_label); |
341 | } else { |
342 | generation_fn(false /*is_tail_in_blocked_format*/); |
343 | } |
344 | } |
345 | |
346 | template <cpu_isa_t isa, typename Vmm> |
347 | void jit_uni_resampling_kernel_t<isa, Vmm>::nearest_ncsp_format() { |
348 | const Reg64 ®_indices_h = reg_aux_src_0_; |
349 | const Reg64 ®_indices_w = reg_aux_src_1_; |
350 | const Reg64 ®_src_shifted = reg_aux_src_2_; |
351 | const Reg64 ®_oh = reg_tmp1_; |
352 | |
    auto nearest_interpolation = [&](const bool is_tail) {
354 | uni_vmovdqu(vmm_indices_, ptr[reg_indices_w]); |
355 | io_.at(conf_.src_data_type) |
356 | ->gather(reg_src_shifted, vmm_indices_, vmm_src_, is_tail); |
357 | if (conf_.with_postops) apply_postops(vmm_src_.getIdx(), is_tail); |
358 | io_.at(conf_.dst_data_type)->store(vmm_src_, ptr[reg_dst_], is_tail); |
359 | }); |
360 | |
361 | mov(reg_indices_h, reg_indices_); |
362 | mov(reg_indices_w, reg_indices_); |
363 | add(reg_indices_w, conf_.oh * conf_.el_size_of_indices); |
364 | |
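    // The outer loop walks output rows (oh); the inner loop walks output
    // columns (ow) in simd_w_-wide steps. The indices table stores one
    // input offset per output row followed by one per output column, read
    // through reg_indices_h and reg_indices_w respectively.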
365 | Label oh_loop_begin, oh_loop_end; |
366 | Label ow_loop_begin, ow_loop_end; |
367 | xor_(reg_oh, reg_oh); |
368 | |
369 | L(oh_loop_begin); |
370 | { |
371 | cmp(reg_oh, conf_.oh); |
372 | jge(oh_loop_end, T_NEAR); |
373 | push(reg_oh); |
374 | |
375 | mov(reg_work_, conf_.ow); |
376 | mov(reg_src_shifted, reg_src_); |
377 | xor_(reg_tmp_, reg_tmp_); |
378 | mov(reg_tmp_.cvt32(), dword[reg_indices_h]); |
379 | add(reg_src_shifted, reg_tmp_); |
380 | |
381 | push(reg_indices_w); |
382 | |
383 | L(ow_loop_begin); |
384 | { |
385 | cmp(reg_work_, simd_w_); |
386 | jl(ow_loop_end, T_NEAR); |
387 | |
388 | nearest_interpolation(false); |
389 | |
390 | add(reg_dst_, simd_w_ * conf_.dst_dt_size); |
391 | add(reg_indices_w, simd_w_ * conf_.el_size_of_indices); |
392 | sub(reg_work_, simd_w_); |
393 | |
394 | jmp(ow_loop_begin, T_NEAR); |
395 | } |
396 | L(ow_loop_end); |
397 | |
398 | if (tail_size_ > 0) { |
399 | nearest_interpolation(true); |
400 | add(reg_dst_, tail_size_ * conf_.dst_dt_size); |
401 | } |
402 | |
403 | add(reg_indices_h, conf_.el_size_of_indices); |
404 | pop(reg_indices_w); |
405 | pop(reg_oh); |
406 | add(reg_oh, 1); |
407 | jmp(oh_loop_begin); |
408 | } |
409 | L(oh_loop_end); |
410 | } |
411 | |
412 | template <cpu_isa_t isa, typename Vmm> |
413 | void jit_uni_resampling_kernel_t<isa, Vmm>::nearest_c_oriented_format( |
414 | const bool is_tail_in_blocked_format) { |
415 | const int c_to_compute_without_tail |
416 | = get_channels_to_compute_without_tail(is_tail_in_blocked_format); |
    const bool insert_tail_processing_code
418 | = (conf_.tag_kind == tag_kind::nspc && tail_size_ > 0) |
419 | || is_tail_in_blocked_format; |
420 | |
421 | const Reg64 ®_c = reg_tmp_; |
422 | const Reg64 ®_src_shifted = reg_aux_src_0_; |
423 | |
424 | auto nearest_interpolation = [&](const bool is_tail) { |
425 | const bool load_and_store_with_tail |
426 | = is_tail && conf_.tag_kind == tag_kind::nspc; |
427 | |
428 | io_.at(conf_.src_data_type) |
429 | ->load(ptr[reg_src_shifted], vmm_src_, |
430 | load_and_store_with_tail); |
431 | if (conf_.with_postops) |
432 | apply_postops(vmm_src_.getIdx(), is_tail, ®_c); |
433 | io_.at(conf_.dst_data_type) |
434 | ->store(vmm_src_, ptr[reg_dst_], load_and_store_with_tail); |
435 | }; |
436 | |
437 | Label loop_begin, loop_end; |
438 | |
439 | L(loop_begin); |
440 | { |
441 | cmp(reg_work_, 1); |
442 | jl(loop_end, T_NEAR); |
443 | |
444 | mov(reg_src_shifted, reg_src_); |
445 | mov(reg_tmp1_.cvt32(), dword[reg_indices_]); |
446 | add(reg_src_shifted, reg_tmp1_); |
447 | |
448 | Label c_loop_begin, c_loop_end; |
449 | xor_(reg_c, reg_c); |
450 | L(c_loop_begin); |
451 | { |
452 | cmp(reg_c, c_to_compute_without_tail); |
453 | je(c_loop_end, T_NEAR); |
454 | |
455 | nearest_interpolation(false); |
456 | |
457 | add(reg_src_shifted, simd_w_ * conf_.src_dt_size); |
458 | add(reg_dst_, simd_w_ * conf_.dst_dt_size); |
459 | |
460 | add(reg_c, simd_w_); |
461 | jmp(c_loop_begin, T_NEAR); |
462 | } |
463 | L(c_loop_end); |
464 | |
        if (insert_tail_processing_code) {
466 | if (tail_size_ > 0) { |
467 | nearest_interpolation(true); |
468 | if (conf_.tag_kind == tag_kind::nspc) |
469 | add(reg_dst_, tail_size_ * conf_.dst_dt_size); |
470 | else if (conf_.tag_kind == tag_kind::blocked) { |
471 | add(reg_dst_, simd_w_ * conf_.dst_dt_size); |
472 | } |
473 | } |
474 | |
475 | if (conf_.tag_kind == tag_kind::blocked) |
476 | preserve_zero_padding( |
477 | c_to_compute_without_tail, is_tail_in_blocked_format); |
478 | } |
479 | |
480 | add(reg_indices_, conf_.el_size_of_indices); |
481 | |
482 | dec(reg_work_); |
483 | jmp(loop_begin, T_NEAR); |
484 | } |
485 | L(loop_end); |
486 | } |
487 | |
488 | template <cpu_isa_t isa, typename Vmm> |
489 | void jit_uni_resampling_kernel_t<isa, Vmm>::linear_ncsp_format() { |
490 | const unsigned indices_stride |
491 | = conf_.ow * conf_.oh * conf_.od * conf_.el_size_of_indices; |
492 | const unsigned weights_stride |
493 | = conf_.ow * conf_.oh * conf_.od * sizeof(float); |
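    // The indices and weights tables hold one entry per output point for
    // each interpolation corner, laid out corner-major, hence the
    // od * oh * ow stride between consecutive corners.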
494 | |
495 | auto linear_interpolation = [&](const bool is_tail) { |
496 | const Vmm vmm_dst(vmm_idx(0)); |
497 | |
498 | for (unsigned i = 0; i < conf_.number_of_corners; i++) { |
499 | uni_vmovdqu(vmm_indices_, ptr[reg_indices_ + i * indices_stride]); |
500 | io_.at(conf_.src_data_type) |
501 | ->gather(reg_src_, vmm_indices_, Vmm(vmm_idx(i)), is_tail); |
502 | } |
503 | |
504 | uni_vmovups(vmm_weights_, ptr[reg_weights]); |
505 | uni_vmulps(vmm_dst, vmm_dst, vmm_weights_); |
506 | for (unsigned i = 1; i < conf_.number_of_corners; i++) { |
507 | uni_vmovups(vmm_weights_, ptr[reg_weights + i * weights_stride]); |
508 | uni_vfmadd231ps(vmm_dst, Vmm(vmm_idx(i)), vmm_weights_); |
509 | } |
510 | |
511 | if (conf_.with_postops) apply_postops(vmm_idx(0), is_tail); |
512 | |
513 | if (conf_.is_saturation_needed && conf_.ndims == 5 |
514 | && !is_superset(conf_.isa, avx512_core)) { |
            // When saturation is needed for a 5D shape and only
            // 16 Vmm registers are available, there is no room to
            // keep the saturation bounds resident in registers.
            // That is why the saturation initialization has to be
            // repeated before every store operation.
520 | io_.init_saturate_f32({conf_.dst_data_type}); |
521 | } |
522 | |
523 | io_.at(conf_.dst_data_type)->store(vmm_dst, ptr[reg_dst_], is_tail); |
524 | }; |
525 | |
526 | Label loop_begin, loop_end; |
527 | |
528 | L(loop_begin); |
529 | { |
530 | cmp(reg_work_, simd_w_); |
531 | jl(loop_end, T_NEAR); |
532 | |
533 | linear_interpolation(false); |
534 | |
535 | add(reg_dst_, simd_w_ * conf_.dst_dt_size); |
536 | add(reg_weights, simd_w_ * sizeof(float)); |
537 | add(reg_indices_, simd_w_ * conf_.el_size_of_indices); |
538 | sub(reg_work_, simd_w_); |
539 | |
540 | jmp(loop_begin, T_NEAR); |
541 | } |
542 | L(loop_end); |
543 | |
544 | if (tail_size_ > 0) linear_interpolation(true); |
545 | } |
546 | |
547 | template <cpu_isa_t isa, typename Vmm> |
548 | void jit_uni_resampling_kernel_t<isa, Vmm>::linear_c_oriented_format( |
549 | const bool is_tail_in_blocked_format) { |
550 | const int c_to_compute_without_tail |
551 | = get_channels_to_compute_without_tail(is_tail_in_blocked_format); |
    const bool insert_tail_processing_code
553 | = (conf_.tag_kind == tag_kind::nspc && tail_size_ > 0) |
554 | || is_tail_in_blocked_format; |
555 | |
556 | const Reg64 ®_c = reg_tmp_; |
557 | const Reg64 ®_index_left = reg_tmp_; |
558 | const Reg64 ®_index_right = reg_tmp_; |
559 | |
560 | const std::vector<std::reference_wrapper<const Reg64>> src_regs |
561 | = {reg_src_ftl_, reg_src_ftr_, reg_src_fbl_, reg_src_fbr_, |
562 | reg_src_btl_, reg_src_btr_, reg_src_bbl_, reg_src_bbr_}; |
563 | const std::vector<std::reference_wrapper<const Vmm>> src_vmms |
564 | = {src_ftl_, src_ftr_, src_fbl_, src_fbr_, src_btl_, src_btr_, |
565 | src_bbl_, src_bbr_}; |
566 | |
567 | assert(src_regs.size() >= conf_.number_of_corners |
568 | && src_vmms.size() >= conf_.number_of_corners); |
569 | |
570 | auto linear_interpolation = [&](const Reg64 ®_c, const bool is_tail) { |
571 | const bool load_and_store_with_tail |
572 | = is_tail && conf_.tag_kind == tag_kind::nspc; |
573 | |
574 | for (unsigned i = 0; i < conf_.number_of_corners; i++) { |
575 | io_.at(conf_.src_data_type) |
576 | ->load(ptr[src_regs[i].get()], src_vmms[i].get(), |
577 | load_and_store_with_tail); |
578 | } |
579 | |
580 | // w_d[0]*(w_h[0]*(src[0][0][0]*w_w[0] + src[0][0][1]*w_w[1]) + |
581 | // w_h[1]*(src[0][1][0]*w_w[0] + src[0][1][1]*w_w[1])) |
582 | // + |
583 | // w_d[1]*(w_h[0]*(src[1][0][0]*w_w[0] + src[1][0][1]*w_w[1]) + |
584 | // w_h[1]*(src[1][1][0]*w_w[0] + src[1][1][1]*w_w[1])) |
585 | uni_vmulps(src_ftl_, src_ftl_, weight_left_); |
586 | uni_vfmadd231ps(src_ftl_, src_ftr_, weight_right_); |
587 | if (conf_.ndims == 4 || conf_.ndims == 5) { |
588 | uni_vmulps(src_fbl_, src_fbl_, weight_left_); |
589 | uni_vfmadd231ps(src_fbl_, src_fbr_, weight_right_); |
590 | uni_vmulps(src_ftl_, src_ftl_, weight_top_); |
591 | uni_vfmadd231ps(src_ftl_, src_fbl_, weight_bottom_); |
592 | } |
593 | if (conf_.ndims == 5) { |
594 | uni_vmulps(src_btl_, src_btl_, weight_left_); |
595 | uni_vfmadd231ps(src_btl_, src_btr_, weight_right_); |
596 | uni_vmulps(src_bbl_, src_bbl_, weight_left_); |
597 | uni_vfmadd231ps(src_bbl_, src_bbr_, weight_right_); |
598 | uni_vmulps(src_btl_, src_btl_, weight_top_); |
599 | uni_vfmadd231ps(src_btl_, src_bbl_, weight_bottom_); |
600 | uni_vmulps(src_ftl_, src_ftl_, weight_front_); |
601 | uni_vfmadd231ps(src_ftl_, src_btl_, weight_back_); |
602 | } |
603 | |
604 | if (conf_.with_postops) |
605 | apply_postops(src_ftl_.getIdx(), is_tail, ®_c); |
606 | |
607 | if (conf_.is_saturation_needed && conf_.ndims == 5 |
608 | && !is_superset(conf_.isa, avx512_core)) { |
            // When saturation is needed for a 5D shape and only
            // 16 Vmm registers are available, there is no room to
            // keep the saturation bounds resident in registers.
            // That is why the saturation initialization has to be
            // repeated before every store operation.
614 | push(reg_tmp_); |
615 | io_.init_saturate_f32({conf_.dst_data_type}); |
616 | pop(reg_tmp_); |
617 | } |
618 | |
619 | io_.at(conf_.dst_data_type) |
620 | ->store(src_ftl_, ptr[reg_dst_], load_and_store_with_tail); |
621 | }; |
622 | |
623 | xor_(reg_index_left, reg_index_left); |
624 | |
625 | Label loop_begin, loop_end; |
626 | L(loop_begin); |
627 | { |
628 | cmp(reg_work_, 1); |
629 | jl(loop_end, T_NEAR); |
630 | |
631 | for (unsigned i = 0; i < conf_.number_of_corners; i++) { |
632 | push(src_regs[i]); |
633 | } |
634 | |
635 | mov(reg_index_left.cvt32(), dword[reg_indices_]); |
636 | for (unsigned i = 0; i < conf_.number_of_corners / 2; i++) { |
637 | add(src_regs[2 * i], reg_index_left); |
638 | } |
639 | mov(reg_index_right.cvt32(), |
640 | dword[reg_indices_ + conf_.el_size_of_indices]); |
641 | for (unsigned i = 0; i < conf_.number_of_corners / 2; i++) { |
642 | add(src_regs[2 * i + 1], reg_index_right); |
643 | } |
644 | |
645 | uni_vbroadcastss(weight_left_, ptr[reg_weights]); |
646 | uni_vbroadcastss(weight_right_, ptr[reg_weights + sizeof(float)]); |
647 | |
648 | Label c_loop_begin, c_loop_end; |
649 | xor_(reg_c, reg_c); |
650 | L(c_loop_begin); |
651 | { |
652 | cmp(reg_c, c_to_compute_without_tail); |
653 | je(c_loop_end, T_NEAR); |
654 | |
655 | linear_interpolation(reg_c, false); |
656 | add(reg_dst_, simd_w_ * conf_.dst_dt_size); |
657 | |
658 | for (unsigned i = 0; i < conf_.number_of_corners; i++) |
659 | add(src_regs[i], simd_w_ * conf_.src_dt_size); |
660 | |
661 | add(reg_c, simd_w_); |
662 | jmp(c_loop_begin, T_NEAR); |
663 | } |
664 | L(c_loop_end); |
665 | |
        if (insert_tail_processing_code) {
667 | if (tail_size_ > 0) { |
668 | linear_interpolation(reg_c, true); |
669 | if (conf_.tag_kind == tag_kind::nspc) |
670 | add(reg_dst_, tail_size_ * conf_.dst_dt_size); |
671 | else if (conf_.tag_kind == tag_kind::blocked) { |
672 | add(reg_dst_, simd_w_ * conf_.dst_dt_size); |
673 | } |
674 | } |
675 | |
676 | if (conf_.tag_kind == tag_kind::blocked) |
677 | preserve_zero_padding( |
678 | c_to_compute_without_tail, is_tail_in_blocked_format); |
679 | } |
680 | |
        // Each loop iteration reads two values (the left and the right
        // corner) from both the weights and the indices tables. The two
        // values lie one after the other in memory, so both pointers are
        // advanced by two elements.
685 | add(reg_indices_, 2 * conf_.el_size_of_indices); |
686 | add(reg_weights, 2 * sizeof(float)); |
687 | |
688 | for (unsigned i = 0; i < conf_.number_of_corners; i++) { |
689 | pop(src_regs[(conf_.number_of_corners - 1) - i]); |
690 | } |
691 | |
692 | dec(reg_work_); |
693 | jmp(loop_begin, T_NEAR); |
694 | } |
695 | L(loop_end); |
696 | } |
697 | |
698 | template <cpu_isa_t isa, typename Vmm> |
699 | void jit_uni_resampling_kernel_t<isa, Vmm>::generate() { |
700 | preamble(); |
701 | |
702 | io_.init_bf16(); |
703 | if (conf_.is_saturation_needed) |
704 | io_.init_saturate_f32({conf_.dst_data_type}); |
    // Preparing the tail mask is needed for the blocked format as well,
    // because there is a chance that padding will not be preserved when
    // the user applies post-ops.
708 | if (tail_size_ > 0 |
709 | && (conf_.tag_kind != tag_kind::blocked || conf_.with_postops)) |
710 | io_.prepare_tail_mask(); |
711 | if (is_superset(conf_.isa, avx2) && conf_.tag_kind == tag_kind::ncsp) { |
712 | io_.init_full_mask(); |
713 | io_.prepare_full_mask(); |
714 | } |
715 | |
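    // Load the runtime arguments common to all execution paths; the src and
    // weights pointers are loaded per-algorithm below.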
716 | mov(reg_dst_, ptr[reg_param + GET_OFF(dst)]); |
717 | mov(reg_work_, ptr[reg_param + GET_OFF(batch_of_sp_points_to_process)]); |
718 | mov(reg_indices_, ptr[reg_param + GET_OFF(indices)]); |
719 | mov(reg_c_offset, ptr[reg_param + GET_OFF(c_offset)]); |
720 | |
721 | if (conf_.alg == alg_kind::resampling_nearest) { |
722 | mov(reg_src_, ptr[reg_param + GET_OFF(src)]); |
723 | if (conf_.tag_kind == tag_kind::ncsp) { |
724 | nearest_ncsp_format(); |
725 | } else if (conf_.tag_kind == tag_kind::nspc |
726 | || conf_.tag_kind == tag_kind::blocked) { |
727 | interpolate_c_oriented_format( |
728 | [&](const bool is_tail_in_blocked_format) { |
729 | nearest_c_oriented_format(is_tail_in_blocked_format); |
730 | }); |
731 | } |
732 | } else if (conf_.alg == alg_kind::resampling_linear) { |
733 | mov(reg_weights, ptr[reg_param + GET_OFF(weights)]); |
734 | if (conf_.tag_kind == tag_kind::ncsp) { |
735 | mov(reg_src_, ptr[reg_param + GET_OFF(src)]); |
736 | linear_ncsp_format(); |
737 | } else if (conf_.tag_kind == tag_kind::nspc |
738 | || conf_.tag_kind == tag_kind::blocked) { |
739 | get_params_for_linear_in_c_oriented_format(); |
740 | interpolate_c_oriented_format( |
741 | [&](const bool is_tail_in_blocked_format) { |
742 | linear_c_oriented_format(is_tail_in_blocked_format); |
743 | }); |
744 | } |
745 | } |
746 | |
747 | postamble(); |
748 | |
749 | if (conf_.with_eltwise && postops_injector_) |
750 | postops_injector_->prepare_table(); |
751 | } |
752 | |
753 | template struct jit_uni_resampling_kernel_t<avx512_core_fp16, Zmm>; |
754 | template struct jit_uni_resampling_kernel_t<avx512_core, Zmm>; |
755 | template struct jit_uni_resampling_kernel_t<avx512_core, Ymm>; |
756 | template struct jit_uni_resampling_kernel_t<avx, Ymm>; |
757 | template struct jit_uni_resampling_kernel_t<avx, Xmm>; |
758 | template struct jit_uni_resampling_kernel_t<sse41, Xmm>; |
759 | |
760 | } // namespace x64 |
761 | } // namespace cpu |
762 | } // namespace impl |
763 | } // namespace dnnl |
764 | |