1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/bfloat16.hpp"
18#include "common/c_types_map.hpp"
19#include "common/dnnl_thread.hpp"
20#include "common/type_helpers.hpp"
21#include "common/utils.hpp"
22
23#include "cpu/x64/jit_generator.hpp"
24
25#include "cpu/x64/jit_avx512_core_resampling.hpp"
26
27#include "utils/jit_io_helper.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32namespace x64 {
33
34using namespace Xbyak;
35
36static bool impl_supports_datatype(data_type_t data_type) {
37 switch (data_type) {
38 case data_type::bf16: return x64::mayiuse(x64::avx512_core);
39 case data_type::f16: return x64::mayiuse(x64::avx512_core_fp16);
40 case data_type::f32:
41 case data_type::s32:
42 case data_type::s8:
43 case data_type::u8: return true;
44 default: return false;
45 }
46}
47
// Computes the byte offset of a kernel-argument field for use in JIT
// addressing (abi_param1 + offset).
#define GET_OFF(field) offsetof(jit_resampling_args_t, field)
// Per-call arguments passed to the generated kernel. The same struct serves
// both propagation kinds; fields are interpreted per direction as noted.
struct jit_resampling_args_t {
    const void *src; // fwd: src bwd: diff_dst (the tensor the kernel reads)
    const void *dst; // fwd: dst bwd: diff_src (the tensor the kernel writes)
    dim_t d; // current depth index; fwd: od bwd: id
    dim_t h; // current height index; fwd: oh bwd: ih
    dim_t w; // current width index; fwd: ow bwd: iw
};
56
57// jit kernels
58namespace {
59
// JIT kernel for nearest/linear resampling, shared by fwd and bwd passes.
// fwd: computes one dst spatial point from src. bwd: computes one diff_src
// spatial point by accumulating every diff_dst point whose interpolation
// touches it (loop bounds are precomputed on the stack). The innermost
// (channel) dimension is processed in simd_w()-wide vectors plus an
// opmask-guarded tail handled by the io helper.
struct jit_avx512_core_resampling_kernel_t
    : public jit_avx512_core_resampling_kernel_base_t {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_resampling)

    jit_avx512_core_resampling_kernel_t(const resampling_pd_t *pd)
        : jit_avx512_core_resampling_kernel_base_t(pd, jit_name())
        , is_saturation_needed_(utils::one_of(dst_data_type(), data_type::u8,
                  data_type::s8, data_type::s32)) {

        // Strides describe the tensor the kernel reads: src for fwd,
        // diff_dst for bwd. For bwd the inner stride is taken from
        // diff_src_md; the bwd pd requires diff_dst to match the same
        // layout tag, so the blocking is identical.
        if (pd_->is_fwd()) {
            const memory_desc_wrapper src_d(pd_->src_md());
            inner_stride_ = src_d.blocking_desc().strides[pd_->ndims() - 1];
            stride_d_ = pd_->IH() * pd_->IW() * inner_stride_;
            stride_h_ = pd_->IW() * inner_stride_;
            stride_w_ = inner_stride_;
        } else {
            const memory_desc_wrapper diff_src_d(pd_->diff_src_md());
            inner_stride_
                    = diff_src_d.blocking_desc().strides[pd_->ndims() - 1];
            stride_d_ = pd_->OH() * pd_->OW() * inner_stride_;
            stride_h_ = pd_->OW() * inner_stride_;
            stride_w_ = inner_stride_;
        }

        // Full vector iterations over the innermost dim, plus the remainder
        // processed under a tail mask.
        number_of_loops_ = (inner_stride_ / simd_w());
        tail_size_ = inner_stride_ % simd_w();
        stack_size_needed_ = 0;

        cpu_isa_t isa
                = mayiuse(avx512_core_bf16) ? avx512_core_bf16 : avx512_core;

        // The io helper performs typed loads/stores with data-type
        // conversion, tail masking, bf16 emulation and (for integral dst)
        // saturation.
        const io::jit_io_multi_dt_helper_t<Zmm>::data_types_t data_types {
                src_data_type(), dst_data_type()};
        std::map<data_type_t, io::io_saturation_conf_t> saturation_conf {};
        if (is_saturation_needed_) {
            saturation_conf.emplace(dst_data_type(),
                    io::io_saturation_conf_t {zmm_zero_saturation_.getIdx(),
                            zmm_saturation_ubound_.getIdx(), reg_tmp});
        }

        io_ = utils::make_unique<io::jit_io_multi_dt_helper_t<Zmm>>(this, isa,
                data_types, io::io_conf_t {},
                io::io_tail_conf_t {
                        simd_w(), tail_size_, k_tail_mask, 0, reg_tmp},
                io::io_emu_bf16_conf_t {}, saturation_conf);
    }

private:
    // How a floating-point source coordinate is snapped to an integer index
    // in count_idx_and_weight_for_linear().
    enum class rounding_mode { none, floor, ceil, rounding_mode_max };

    // Stack-resident loop bounds for one spatial dim of the bwd pass:
    // a loop counter plus [start, end) ranges — two ranges for linear
    // (left- and right-neighbor contributions), one for nearest.
    struct bwd_counting_range_t {
        RegExp loop_counter;
        struct start_t {
            RegExp linear[2];
            RegExp nearest;
        } start;
        struct end_t {
            RegExp linear[2];
            RegExp nearest;
        } end;
    };

    // dst_int_val = trunc(to_round + 0.5). For the coordinates produced in
    // this kernel (>= -0.5 by construction of count_dim_coeff) this equals
    // round-half-away-from-zero. Clobbers `tmp`.
    void round_to_near_away_from_zero(const Reg64 &dst_int_val,
            const Xmm &to_round, const Xmm &zero_point_five, const Xmm &tmp) {
        EvexModifierRounding rm_trunc(EvexModifierRounding::T_RZ_SAE);
        vaddss(tmp, to_round, zero_point_five);
        vcvtss2si(dst_int_val, tmp | rm_trunc);
    }

    // Emits the header of a memory-based `for` loop: initializes the
    // counter from *start and jumps past the body when counter >= *end.
    // Leaves the current counter value in `tmp`.
    void for_begin(Label &begin_loop, Label &end_loop,
            const RegExp &loop_counter, const RegExp &start, const RegExp &end,
            const Reg64 &tmp) {
        // for (initialization; check; incrementation)
        // initialization
        mov(tmp, ptr[start]);
        mov(ptr[loop_counter], tmp);
        L(begin_loop);
        // check
        mov(tmp, ptr[loop_counter]);
        cmp(tmp, ptr[end]);
        jge(end_loop, T_NEAR);
    }

    // Emits the footer of a memory-based `for` loop: increments the counter
    // in memory and jumps back to the loop header.
    void for_end(Label &begin_loop, Label &end_loop, const RegExp &loop_counter,
            const Reg64 &tmp) {
        // incrementation
        mov(tmp, ptr[loop_counter]);
        inc(tmp);
        mov(ptr[loop_counter], tmp);
        jmp(begin_loop, T_NEAR);
        L(end_loop);
    }

    // reg = max(reg, to_cmp). Clobbers reg_tmp.
    void max(const Reg64 &reg, const dim_t to_cmp) {
        mov(reg_tmp, to_cmp);
        cmp(reg, reg_tmp);
        cmovl(reg, reg_tmp);
    }

    // reg = min(reg, to_cmp). Clobbers reg_tmp.
    void min(const Reg64 &reg, const dim_t to_cmp) {
        mov(reg_tmp, to_cmp);
        cmp(reg, reg_tmp);
        cmovg(reg, reg_tmp);
    }

    // Materializes a float immediate into the low lane of `xmm` via `tmp`.
    void move_imm_float_to_xmm(const Xmm &xmm, const Reg64 &tmp, float imm) {
        mov(tmp.cvt32(), float2int(imm));
        vmovd(xmm, tmp.cvt32());
    }

    // Top-level code generation: loads kernel arguments, precomputes either
    // the fwd interpolation coefficients or the bwd stack-based loop
    // ranges, then dispatches to the algorithm-specific emitter.
    void generate() override {
        preamble();

        io_->init_bf16();
        if (is_saturation_needed_) io_->init_saturate_f32({dst_data_type()});
        if (tail_size_) io_->prepare_tail_mask();

        mov(reg_src, ptr[abi_param1 + GET_OFF(src)]);
        mov(reg_dst, ptr[abi_param1 + GET_OFF(dst)]);

        move_imm_float_to_xmm(xmm_zero_point_five, reg_tmp, 0.5f);

        if (pd_->is_fwd()) {
            // Count coeffs: map the current output position into source
            // coordinate space once per spatial dim present.
            if (pd_->ndims() == 5) {
                mov(reg_curr_d, ptr[abi_param1 + GET_OFF(d)]);
                mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                count_dim_coeff(xmm_d_coeff, reg_curr_d, pd_->OD(), pd_->ID());
                count_dim_coeff(xmm_h_coeff, reg_curr_h, pd_->OH(), pd_->IH());
                count_dim_coeff(xmm_w_coeff, reg_curr_w, pd_->OW(), pd_->IW());
            } else if (pd_->ndims() == 4) {
                mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                count_dim_coeff(xmm_h_coeff, reg_curr_h, pd_->OH(), pd_->IH());
                count_dim_coeff(xmm_w_coeff, reg_curr_w, pd_->OW(), pd_->IW());
            } else {
                mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                count_dim_coeff(xmm_w_coeff, reg_curr_w, pd_->OW(), pd_->IW());
            }
        } else {
            if (pd_->desc()->alg_kind == alg_kind::resampling_linear) {
                // Stack:
                // ow_loop_counter:
                // ow_left_start  : 8
                // ow_left_end    : 16
                // ow_right_start : 24
                // ow_right_end   : 32
                // ----- 3 dims -----
                // oh_loop_counter: 40
                // oh_left_start  : 48
                // oh_left_end    : 56
                // oh_right_start : 64
                // oh_right_end   : 72
                // ----- 4 dims -----
                // od_loop_counter: 80
                // od_left_start  : 88
                // od_left_end    : 96
                // od_right_start : 104
                // od_right_end   : 112
                // ----- 5 dims -----

                // 5*size(int64)*nr_of_spatial_dims
                stack_size_needed_ = 5 * 8 * (pd_->ndims() - 2);
                sub(rsp, stack_size_needed_);

                if (pd_->ndims() == 5) {
                    mov(reg_curr_d, ptr[abi_param1 + GET_OFF(d)]);
                    mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                    mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                    count_bwd_counting_range(
                            rsp + 80, od, reg_curr_d, pd_->OD(), pd_->ID());
                    count_bwd_counting_range(
                            rsp + 40, oh, reg_curr_h, pd_->OH(), pd_->IH());
                    count_bwd_counting_range(
                            rsp, ow, reg_curr_w, pd_->OW(), pd_->IW());
                } else if (pd_->ndims() == 4) {
                    mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                    mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                    count_bwd_counting_range(
                            rsp + 40, oh, reg_curr_h, pd_->OH(), pd_->IH());
                    count_bwd_counting_range(
                            rsp, ow, reg_curr_w, pd_->OW(), pd_->IW());
                } else {
                    mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                    count_bwd_counting_range(
                            rsp, ow, reg_curr_w, pd_->OW(), pd_->IW());
                }
            } else {
                // Stack:
                // ow_loop_counter:
                // ow_start       : 8
                // ow_end         : 16
                // oh_loop_counter: 24
                // oh_start       : 32
                // oh_end         : 40
                // od_loop_counter: 48
                // od_start       : 56
                // od_end         : 64

                // 3*size(int64)*max_nr_of_spatial_dims
                // Nearest bwd always allocates ranges for all three dims
                // (nearest_alg() loops over d, h and w unconditionally).
                stack_size_needed_ = 3 * 8 * 3;
                sub(rsp, stack_size_needed_);

                mov(reg_curr_d, ptr[abi_param1 + GET_OFF(d)]);
                mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                count_bwd_counting_range(
                        rsp + 48, od, reg_curr_d, pd_->OD(), pd_->ID());
                count_bwd_counting_range(
                        rsp + 24, oh, reg_curr_h, pd_->OH(), pd_->IH());
                count_bwd_counting_range(
                        rsp, ow, reg_curr_w, pd_->OW(), pd_->IW());
            }
        }

        // Choose algorithm
        if (pd_->desc()->alg_kind == alg_kind::resampling_linear) {
            if (pd_->ndims() == 5) {
                trilinear();
            } else if (pd_->ndims() == 4) {
                bilinear();
            } else {
                linear();
            }
        } else {
            nearest();
        }

        if (!pd_->is_fwd()) add(rsp, stack_size_needed_);
        postamble();
    }

    // xmm_coeff = ((reg_dim + 0.5f) * x_max / y_max) - 0.5f: maps an index
    // from a y_max-sized space into the corresponding (fractional)
    // coordinate of an x_max-sized space. Clobbers reg_tmp and
    // xmm_tmp_factor.
    void count_dim_coeff(const Xmm &xmm_coeff, const Reg64 &reg_dim,
            dim_t y_max, dim_t x_max) {
        // Formula = ((y + 0.5f) * x_max / y_max) - 0.5f
        vcvtsi2ss(xmm_coeff, xmm_coeff, reg_dim); // y
        vaddss(xmm_coeff, xmm_coeff, xmm_zero_point_five); // y + 0.5f

        move_imm_float_to_xmm(xmm_tmp_factor, reg_tmp, (float)x_max);
        vmulss(xmm_coeff, xmm_coeff,
                xmm_tmp_factor); // (y + 0.5f) * x_max
        move_imm_float_to_xmm(xmm_tmp_factor, reg_tmp, (float)y_max);
        vdivss(xmm_coeff, xmm_coeff,
                xmm_tmp_factor); // (y + 0.5f) * x_max / y_max

        vsubss(xmm_coeff, xmm_coeff,
                xmm_zero_point_five); // ((y + 0.5) * x_max / y_max) - 0.5
    }

    // bwd only: for the position held in `curr_position` (in x_max space),
    // computes the stack-resident [start, end) ranges of indices in y_max
    // space whose resampling touches this position, and records the stack
    // addresses in `c_range`. Note: `curr_position` is clobbered (it ends
    // up incremented by 1). Clobbers reg_tmp/reg_tmp_idx and xmm_coeff.
    void count_bwd_counting_range(RegExp stack_position,
            bwd_counting_range_t &c_range, const Reg64 &curr_position,
            dim_t y_max, dim_t x_max) {
        c_range.loop_counter = stack_position;
        if (pd_->desc()->alg_kind == alg_kind::resampling_linear) {
            c_range.start.linear[0] = stack_position + 8;
            c_range.end.linear[0] = stack_position + 16;
            c_range.start.linear[1] = stack_position + 24;
            c_range.end.linear[1] = stack_position + 32;
        } else {
            c_range.start.nearest = stack_position + 8;
            c_range.end.nearest = stack_position + 16;
        }

        EvexModifierRounding rm_ceil(EvexModifierRounding::T_RU_SAE);
        EvexModifierRounding rm_floor(EvexModifierRounding::T_RD_SAE);

        if (pd_->desc()->alg_kind == alg_kind::resampling_linear) {
            // coeff = (pos + 0.5) * y_max / x_max - 0.5
            count_dim_coeff(xmm_coeff, curr_position, x_max, y_max);

            // l_start: x == 0 ? 0 : ceil(coeff)
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_ceil);
            mov(reg_tmp, 0);
            cmp(curr_position, reg_tmp);
            cmove(reg_tmp_idx, reg_tmp);
            mov(ptr[c_range.start.linear[0]], reg_tmp_idx);

            // r_end: x == x_max-1 ? y_max : min(max(0, floor(coeff) + 1), y_max)
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_floor);
            add(reg_tmp_idx, 1);
            max(reg_tmp_idx, 0);
            min(reg_tmp_idx, y_max);
            cmp(curr_position, x_max - 1);
            mov(reg_tmp, y_max);
            cmove(reg_tmp_idx, reg_tmp);
            mov(ptr[c_range.end.linear[1]], reg_tmp_idx);

            // coeff = ((pos-1) + 0.5) * y_max / x_max - 0.5
            sub(curr_position, 1);
            count_dim_coeff(xmm_coeff, curr_position, x_max, y_max);

            // r_start: max(0, floor(coeff) + 1)
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_floor);
            add(reg_tmp_idx, 1);
            max(reg_tmp_idx, 0);
            mov(ptr[c_range.start.linear[1]], reg_tmp_idx);

            // coeff = ((pos+1) + 0.5) * y_max / x_max - 0.5
            add(curr_position, 2);
            count_dim_coeff(xmm_coeff, curr_position, x_max, y_max);

            // l_end: min(ceil(coeff), y_max)
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_ceil);
            min(reg_tmp_idx, y_max);
            mov(ptr[c_range.end.linear[0]], reg_tmp_idx);
        } else {
            float factor = (float)y_max / x_max;

            // start: ceil(pos * factor - 0.5f)
            vcvtsi2ss(xmm_coeff, xmm_coeff, curr_position);
            move_imm_float_to_xmm(xmm_tmp_factor, reg_tmp, factor);
            vmulss(xmm_coeff, xmm_coeff, xmm_tmp_factor);
            vsubss(xmm_coeff, xmm_coeff, xmm_zero_point_five);
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_ceil);
            mov(ptr[c_range.start.nearest], reg_tmp_idx);

            // end: ceil((pos+1) * factor - 0.5f)
            add(curr_position, 1);
            vcvtsi2ss(xmm_coeff, xmm_coeff, curr_position);
            vmulss(xmm_coeff, xmm_coeff, xmm_tmp_factor);
            vsubss(xmm_coeff, xmm_coeff, xmm_zero_point_five);
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_ceil);
            mov(ptr[c_range.end.nearest], reg_tmp_idx);
        }
    }

    // From the fractional coordinate `coeff`, derives the integer neighbor
    // index `idx` (per `rm`) and the broadcasted interpolation weight
    // `weight`. fwd additionally clamps idx into [0, dim_max - 1].
    // Clobbers reg_tmp and xmm_tmp/zmm_tmp.
    void count_idx_and_weight_for_linear(const Xmm &coeff, const Zmm &weight,
            const Reg64 &idx, dim_t dim_max, rounding_mode rm) {
        const Xmm xmm_weight = Xmm(weight.getIdx());
        Reg64 reg_idx_floor;

        if (pd_->is_fwd() && rm == rounding_mode::ceil) {
            EvexModifierRounding rm_ceil(EvexModifierRounding::T_RU_SAE);
            EvexModifierRounding rm_floor(EvexModifierRounding::T_RD_SAE);
            vcvtss2si(idx,
                    coeff | rm_ceil); // ceil(coeff)
            reg_idx_floor = reg_tmp;
            vcvtss2si(reg_idx_floor,
                    coeff | rm_floor); // floor(coeff)

        } else {
            EvexModifierRounding rm_floor(EvexModifierRounding::T_RD_SAE);
            vcvtss2si(idx,
                    coeff | rm_floor); // floor(coeff)
            reg_idx_floor = idx;
        }

        // zmm_tmp aliases xmm_tmp (both register 8), so the scalar subtract
        // below consumes the value converted here.
        vcvtsi2ss(xmm_tmp, xmm_tmp, reg_idx_floor);
        vsubss(xmm_weight, coeff,
                zmm_tmp); // W = coeff - idx
        if (rm == rounding_mode::floor) {
            move_imm_float_to_xmm(xmm_tmp, reg_tmp, 1.0f);
            vsubss(xmm_weight, xmm_tmp,
                    xmm_weight); // W = 1 - (coeff - idx)
        }
        vbroadcastss(weight, xmm_weight);

        if (pd_->is_fwd()) {
            // Clamp the neighbor index to a valid source position.
            if (rm == rounding_mode::ceil) {
                min(idx, dim_max - 1);
            } else if (rm == rounding_mode::floor) {
                max(idx, 0);
            }
        }
    }

    // Emits one weighted-neighbor accumulation for (tri/bi)linear:
    // zmm_dst += src[offset] * (Ww [* Wh [* Wd]]). For bwd the current
    // indices are reloaded from the stack loop counters instead of the
    // clamped fwd indices.
    void linear_alg(int64_t channel_offset, rounding_mode rm_w,
            rounding_mode rm_h = rounding_mode::none,
            rounding_mode rm_d = rounding_mode::none, bool is_tail = false) {
        xor_(reg_offset, reg_offset); // reg_offset = 0

        if (rm_w != rounding_mode::none) {
            // out: Ww, curr_w
            count_idx_and_weight_for_linear(
                    xmm_w_coeff, zmm_weight, reg_curr_w, pd_->IW(), rm_w);
            // curr_w * stride_w_
            if (!pd_->is_fwd()) mov(reg_curr_w, ptr[ow.loop_counter]);
            imul(reg_offset, reg_curr_w, stride_w_);
        }
        if (rm_h != rounding_mode::none) {
            // out: Wh, curr_h
            count_idx_and_weight_for_linear(
                    xmm_h_coeff, zmm_tmp_weight, reg_curr_h, pd_->IH(), rm_h);
            // Ww * Wh
            vmulps(zmm_weight, zmm_weight, zmm_tmp_weight);
            // curr_w * stride_w_ + curr_h * stride_h_
            if (!pd_->is_fwd()) mov(reg_curr_h, ptr[oh.loop_counter]);
            imul(reg_tmp, reg_curr_h, stride_h_);
            add(reg_offset, reg_tmp);
        }
        if (rm_d != rounding_mode::none) {
            // out: Wd, curr_d
            count_idx_and_weight_for_linear(
                    xmm_d_coeff, zmm_tmp_weight, reg_curr_d, pd_->ID(), rm_d);
            // Ww * Wh * Wd
            vmulps(zmm_weight, zmm_weight, zmm_tmp_weight);
            // curr_w * stride_w_ + curr_h * stride_h_ + curr_d * stride_d_
            if (!pd_->is_fwd()) mov(reg_curr_d, ptr[od.loop_counter]);
            imul(reg_tmp, reg_curr_d, stride_d_);
            add(reg_offset, reg_tmp);
        }

        // Element offset -> byte offset, including the channel-block shift.
        add(reg_offset, channel_offset);
        imul(reg_offset, reg_offset, types::data_type_size(src_data_type()));

        // read src
        io_->at(src_data_type())
                ->load(ptr[reg_src + reg_offset], zmm_src, is_tail);

        // mul src, weight
        vmulps(zmm_tmp, zmm_src, zmm_weight);
        vaddps(zmm_dst, zmm_dst, zmm_tmp);
    }

    // 1D linear: fwd sums both neighbors; bwd loops over the two output
    // ranges that reference this input point.
    void linear() {
        int64_t number_of_processed_points = 0;

        auto resample_linear = ([&](bool is_tail) {
            auto call_linear = ([&](int i) {
                linear_alg(number_of_processed_points,
                        i % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_w */,
                        rounding_mode::none /* rounding_mode_h */,
                        rounding_mode::none /* rounding_mode_d */, is_tail);
            });

            // zero dst
            vpxorq(zmm_dst, zmm_dst, zmm_dst);

            if (pd_->is_fwd()) {
                for (int i = 0; i < 2; i++) {
                    call_linear(i);
                }
            } else {
                Label label[2][2];

                for (int i = 0; i < 2; i++) {
                    // for (dim_t ow = w.start[i]; ow < w.end[i]; ow++)
                    for_begin(label[i][0], label[i][1], ow.loop_counter,
                            ow.start.linear[i], ow.end.linear[i], reg_tmp);
                    count_dim_coeff(xmm_w_coeff, reg_tmp, pd_->OW(), pd_->IW());

                    // note i + 1: bwd uses the opposite floor/ceil parity
                    // for range i relative to the fwd pass
                    call_linear(i + 1);

                    for_end(label[i][0], label[i][1], ow.loop_counter, reg_tmp);
                }
            }

            const size_t offset = number_of_processed_points
                    * types::data_type_size(dst_data_type());
            const auto address = ptr[reg_dst + offset];
            // store dst
            io_->at(dst_data_type())->store(zmm_dst, address, is_tail);
        });

        for (unsigned i = 0; i < number_of_loops_;
                i++, number_of_processed_points += simd_w())
            resample_linear(false);

        if (tail_size_ != 0) resample_linear(true);
    }

    // 2D linear: fwd sums the 4 neighbors; bwd loops over the 2x2 output
    // ranges that reference this input point.
    void bilinear() {
        int64_t number_of_processed_points = 0;

        auto resample_linear = ([&](bool is_tail) {
            auto call_linear = ([&](int i, int j) {
                linear_alg(number_of_processed_points,
                        i % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_w */,
                        j % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_h */,
                        rounding_mode::none /* rounding_mode_d */, is_tail);
            });

            // zero dst
            vpxorq(zmm_dst, zmm_dst, zmm_dst);

            if (pd_->is_fwd()) {
                for (int i = 0; i < 2; i++) {
                    for (int j = 0; j < 2; j++) {
                        call_linear(i, j);
                    }
                }
            } else {
                Label label[2][2][4];

                for (int i = 0; i < 2; i++) {
                    for (int j = 0; j < 2; j++) {
                        // for (dim_t ow = w.start[i]; ow < w.end[i]; ow++)
                        for_begin(label[i][j][0], label[i][j][1],
                                ow.loop_counter, ow.start.linear[i],
                                ow.end.linear[i], reg_tmp);
                        count_dim_coeff(
                                xmm_w_coeff, reg_tmp, pd_->OW(), pd_->IW());
                        // for (dim_t oh = h.start[j]; oh < h.end[j]; oh++)
                        for_begin(label[i][j][2], label[i][j][3],
                                oh.loop_counter, oh.start.linear[j],
                                oh.end.linear[j], reg_tmp);
                        count_dim_coeff(
                                xmm_h_coeff, reg_tmp, pd_->OH(), pd_->IH());

                        // note +1: bwd flips the floor/ceil parity
                        call_linear(i + 1, j + 1);

                        for_end(label[i][j][2], label[i][j][3], oh.loop_counter,
                                reg_tmp);
                        for_end(label[i][j][0], label[i][j][1], ow.loop_counter,
                                reg_tmp);
                    }
                }
            }

            const size_t offset = number_of_processed_points
                    * types::data_type_size(dst_data_type());
            const auto address = ptr[reg_dst + offset];
            // store dst
            io_->at(dst_data_type())->store(zmm_dst, address, is_tail);
        });

        for (unsigned i = 0; i < number_of_loops_;
                i++, number_of_processed_points += simd_w())
            resample_linear(false);

        if (tail_size_ != 0) resample_linear(true);
    }

    // 3D linear: fwd sums the 8 neighbors; bwd loops over the 2x2x2 output
    // ranges that reference this input point.
    void trilinear() {
        int64_t number_of_processed_points = 0;

        auto resample_linear = ([&](bool is_tail) {
            auto call_linear = ([&](int i, int j, int k) {
                linear_alg(number_of_processed_points,
                        i % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_w */,
                        j % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_h */,
                        k % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_d */,
                        is_tail);
            });

            // zero dst
            vpxorq(zmm_dst, zmm_dst, zmm_dst);

            if (pd_->is_fwd()) {
                for (int i = 0; i < 2; i++) {
                    for (int j = 0; j < 2; j++) {
                        for (int k = 0; k < 2; k++) {
                            call_linear(i, j, k);
                        }
                    }
                }
            } else {
                Label label[2][2][2][6];

                for (int i = 0; i < 2; i++) {
                    for (int j = 0; j < 2; j++) {
                        for (int k = 0; k < 2; k++) {
                            // for (dim_t ow = w.start[i]; ow < w.end[i]; ow++)
                            for_begin(label[i][j][k][0], label[i][j][k][1],
                                    ow.loop_counter, ow.start.linear[i],
                                    ow.end.linear[i], reg_tmp);
                            count_dim_coeff(
                                    xmm_w_coeff, reg_tmp, pd_->OW(), pd_->IW());
                            // for (dim_t oh = h.start[j]; oh < h.end[j]; oh++)
                            for_begin(label[i][j][k][2], label[i][j][k][3],
                                    oh.loop_counter, oh.start.linear[j],
                                    oh.end.linear[j], reg_tmp);
                            count_dim_coeff(
                                    xmm_h_coeff, reg_tmp, pd_->OH(), pd_->IH());
                            // for (dim_t od = d.start[k]; od < d.end[k]; od++)
                            for_begin(label[i][j][k][4], label[i][j][k][5],
                                    od.loop_counter, od.start.linear[k],
                                    od.end.linear[k], reg_tmp);
                            count_dim_coeff(
                                    xmm_d_coeff, reg_tmp, pd_->OD(), pd_->ID());

                            // note +1: bwd flips the floor/ceil parity
                            call_linear(i + 1, j + 1, k + 1);

                            for_end(label[i][j][k][4], label[i][j][k][5],
                                    od.loop_counter, reg_tmp);
                            for_end(label[i][j][k][2], label[i][j][k][3],
                                    oh.loop_counter, reg_tmp);
                            for_end(label[i][j][k][0], label[i][j][k][1],
                                    ow.loop_counter, reg_tmp);
                        }
                    }
                }
            }

            const size_t offset = number_of_processed_points
                    * types::data_type_size(dst_data_type());
            const auto address = ptr[reg_dst + offset];
            // store dst
            io_->at(dst_data_type())->store(zmm_dst, address, is_tail);
        });

        for (unsigned i = 0; i < number_of_loops_;
                i++, number_of_processed_points += simd_w())
            resample_linear(false);

        if (tail_size_ != 0) resample_linear(true);
    }

    // Emits one nearest-neighbor load/accumulate. fwd rounds the fp coeffs
    // to indices; bwd reads the current indices from the stack counters.
    // For spatial dims absent in the problem the clamp to [0, dim-1] with
    // dim == 1 forces index 0, so unset coeff registers are harmless.
    void nearest_alg(int64_t channel_offset, bool is_tail = false) {
        xor_(reg_offset, reg_offset); // reg_offset = 0

        auto get_idx = ([&](const Reg64 &idx, const Xmm &coeff,
                                dim_t dim_max_size) {
            round_to_near_away_from_zero(idx, coeff, xmm_zero_point_five,
                    xmm_tmp); // round_to_nearest(coeff)
            min(idx, dim_max_size - 1);
            max(idx, 0);
        });

        if (pd_->is_fwd()) {
            get_idx(reg_curr_w, xmm_w_coeff, pd_->IW());
            get_idx(reg_curr_h, xmm_h_coeff, pd_->IH());
            get_idx(reg_curr_d, xmm_d_coeff, pd_->ID());
        } else {
            mov(reg_curr_w, ptr[ow.loop_counter]);
            mov(reg_curr_h, ptr[oh.loop_counter]);
            mov(reg_curr_d, ptr[od.loop_counter]);
        }

        imul(reg_offset, reg_curr_w, stride_w_); // iw * stride_w_
        imul(reg_tmp, reg_curr_h, stride_h_);
        add(reg_offset, reg_tmp); // iw * stride_w_ + ih * stride_h_
        imul(reg_tmp, reg_curr_d, stride_d_);
        add(reg_offset,
                reg_tmp); // iw * stride_w_ + ih * stride_h_ + id * stride_d_

        add(reg_offset,
                channel_offset); // iw * stride_w_ + ih * stride_h_ + id * stride_d_ + channel_offset
        imul(reg_offset, reg_offset,
                types::data_type_size(
                        src_data_type())); // (iw * stride_w_ + ih * stride_h_ + id * stride_d_ + channel_offset)*dt_size

        if (pd_->is_fwd()) {
            // read nearest to dst
            io_->at(src_data_type())
                    ->load(ptr[reg_src + reg_offset], zmm_dst, is_tail);
        } else {
            // add nearest to dst
            io_->at(src_data_type())
                    ->load(ptr[reg_src + reg_offset], zmm_tmp, is_tail);
            vaddps(zmm_dst, zmm_dst, zmm_tmp);
        }
    }

    // Nearest: fwd is a single gather per vector chunk; bwd accumulates
    // over the d/h/w ranges of output points mapping to this input point.
    void nearest() {
        int64_t number_of_processed_points = 0;

        auto resample_nearest = ([&](bool is_tail) {
            // zero dst
            vpxorq(zmm_dst, zmm_dst, zmm_dst);

            if (pd_->is_fwd()) {
                nearest_alg(number_of_processed_points, is_tail);
            } else {
                Label label[6];

                // for (dim_t ow = w.start[i]; ow < w.end[i]; ow++)
                for_begin(label[0], label[1], ow.loop_counter, ow.start.nearest,
                        ow.end.nearest, reg_tmp);
                // for (dim_t oh = h.start[j]; oh < h.end[j]; oh++)
                for_begin(label[2], label[3], oh.loop_counter, oh.start.nearest,
                        oh.end.nearest, reg_tmp);
                // for (dim_t od = d.start[k]; od < d.end[k]; od++)
                for_begin(label[4], label[5], od.loop_counter, od.start.nearest,
                        od.end.nearest, reg_tmp);

                nearest_alg(number_of_processed_points, is_tail);

                for_end(label[4], label[5], od.loop_counter, reg_tmp);
                for_end(label[2], label[3], oh.loop_counter, reg_tmp);
                for_end(label[0], label[1], ow.loop_counter, reg_tmp);
            }

            const size_t offset = number_of_processed_points
                    * types::data_type_size(dst_data_type());
            const auto address = ptr[reg_dst + offset];
            // store dst
            io_->at(dst_data_type())->store(zmm_dst, address, is_tail);
        });

        for (unsigned i = 0; i < number_of_loops_;
                i++, number_of_processed_points += simd_w())
            resample_nearest(false);

        if (tail_size_ != 0) resample_nearest(true);
    }

    // Number of f32 lanes in a zmm register (16).
    static constexpr std::size_t simd_w() {
        return cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    }

    // Register map. Note the intentional aliases (same physical index):
    // xmm_coeff/xmm_d_coeff (4), zmm_tmp/xmm_tmp (8),
    // zmm_tmp_weight/xmm_tmp_factor (9), reg_offset/reg_tmp_idx (r12) —
    // the aliased pairs are never live at the same time.
    Zmm zmm_src = Zmm(1);
    Zmm zmm_dst = Zmm(2);
    Zmm zmm_weight = Zmm(3);
    Xmm xmm_coeff = Xmm(4);
    Xmm xmm_d_coeff = Xmm(4);
    Xmm xmm_h_coeff = Xmm(5);
    Xmm xmm_w_coeff = Xmm(6);
    Xmm xmm_zero_point_five = Xmm(7);
    Zmm zmm_tmp = Zmm(8);
    Xmm xmm_tmp = Xmm(8);
    Zmm zmm_tmp_weight = Zmm(9);
    Xmm xmm_tmp_factor = Xmm(9);
    Zmm zmm_zero_saturation_ = Zmm(10);
    Zmm zmm_saturation_ubound_ = Zmm(11);

    Opmask k_tail_mask = k6;
    Reg64 reg_src = rax;
    Reg64 reg_dst = rbx;
    Reg64 reg_tmp = r8;
    Reg64 reg_curr_d = r9;
    Reg64 reg_curr_h = r10;
    Reg64 reg_curr_w = r11;
    Reg64 reg_offset = r12;
    Reg64 reg_tmp_idx = r12;

    // bwd loop bounds per spatial dim (addresses into the kernel stack).
    bwd_counting_range_t ow;
    bwd_counting_range_t oh;
    bwd_counting_range_t od;

    std::unique_ptr<io::jit_io_multi_dt_helper_t<Zmm>> io_;

    dim_t stride_d_ = 0;
    dim_t stride_h_ = 0;
    dim_t stride_w_ = 0;
    dim_t inner_stride_ = 0;
    unsigned number_of_loops_ = 0; // full simd_w()-wide iterations
    size_t tail_size_ = 0; // leftover channels handled under mask
    bool is_saturation_needed_ = false; // true for u8/s8/s32 dst
    unsigned stack_size_needed_ = 0; // bytes reserved for bwd loop state
};
798
799} // namespace
800
// Base-class ctor: forwards the kernel name to the jit_generator and keeps
// a borrowed (non-owning) pointer to the primitive descriptor.
jit_avx512_core_resampling_kernel_base_t::
        jit_avx512_core_resampling_kernel_base_t(
                const resampling_pd_t *pd, const char *name)
    : jit_generator(name), pd_(pd) {}
805
806data_type_t jit_avx512_core_resampling_kernel_base_t::src_data_type() const {
807 if (pd_->is_fwd())
808 return pd_->src_md()->data_type;
809 else
810 return pd_->diff_dst_md()->data_type;
811}
812
813data_type_t jit_avx512_core_resampling_kernel_base_t::dst_data_type() const {
814 if (pd_->is_fwd())
815 return pd_->dst_md()->data_type;
816 else
817 return pd_->diff_src_md()->data_type;
818}
819
// Checks whether the bwd primitive descriptor is implementable: avx512_core
// present, bwd prop kind, no zero-dim tensors, supported data types (f16
// additionally requires avx512_core_fp16 and a plain diff_src layout),
// default attributes, and diff_src/diff_dst sharing one of the supported
// blocked/plain layout tags.
status_t jit_avx512_core_resampling_bwd_t::pd_t::init(engine_t *engine) {
    using namespace format_tag;
    using namespace data_type;
    const bool ok = mayiuse(avx512_core) && !is_fwd() && !has_zero_dim_memory()
            && impl_supports_datatype(diff_dst_md()->data_type)
            && impl_supports_datatype(diff_src_md()->data_type)
            && IMPLICATION(diff_src_md()->data_type == f16,
                    mayiuse(avx512_core_fp16)
                            && memory_desc_wrapper(diff_src_md()).is_plain())
            && set_default_params() == status::success
            && attr()->has_default_values();
    if (!ok) return status::unimplemented;

    // Both gradient tensors must use the same (supported) layout tag, since
    // the kernel derives all strides from diff_src's blocking.
    format_tag_t dat_tag = memory_desc_matches_one_of_tag(*diff_src_md(), nCw8c,
            nChw8c, nCdhw8c, nCw16c, nChw16c, nCdhw16c, nwc, nhwc, ndhwc);
    if (!memory_desc_matches_tag(*diff_dst_md(), dat_tag))
        return status::unimplemented;

    return status::success;
}
840
// Defined out of line: the kernel type owned by this primitive is only
// complete in this translation unit.
jit_avx512_core_resampling_bwd_t::~jit_avx512_core_resampling_bwd_t() = default;
842
// Instantiates and JIT-compiles the kernel for this primitive descriptor.
// Returns an error status if allocation or code generation fails.
status_t jit_avx512_core_resampling_bwd_t::init(engine_t *engine) {
    CHECK(safe_ptr_assign(
            kernel_, new jit_avx512_core_resampling_kernel_t(pd())));
    return kernel_->create_kernel();
}
848
849status_t jit_avx512_core_resampling_bwd_t::execute(
850 const exec_ctx_t &ctx) const {
851
852 const auto diff_dst = CTX_IN_MEM(const unsigned char *, DNNL_ARG_DIFF_DST);
853 auto diff_src = CTX_OUT_MEM(unsigned char *, DNNL_ARG_DIFF_SRC);
854
855 const std::size_t diff_dst_dt_size
856 = types::data_type_size(pd()->diff_dst_md()->data_type);
857 const std::size_t diff_src_dt_size
858 = types::data_type_size(pd()->diff_src_md()->data_type);
859
860 const dim_t OD = pd()->OD();
861 const dim_t OH = pd()->OH();
862 const dim_t OW = pd()->OW();
863 const dim_t ID = pd()->ID();
864 const dim_t IH = pd()->IH();
865 const dim_t IW = pd()->IW();
866
867 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
868 const dim_t inner_stride
869 = diff_src_d.blocking_desc().strides[pd()->ndims() - 1];
870 const dim_t nsp_outer
871 = diff_src_d.nelems(true) / (ID * IH * IW * inner_stride);
872
873 parallel_nd(nsp_outer, ID, IH, IW,
874 [&](dim_t nsp, dim_t id, dim_t ih, dim_t iw) {
875 const dim_t diff_dst_off
876 = nsp * OD * OH * OW * inner_stride * diff_dst_dt_size;
877 const dim_t diff_src_off
878 = (nsp * ID * IH * IW + id * IH * IW + ih * IW + iw)
879 * inner_stride * diff_src_dt_size;
880 jit_resampling_args_t args;
881 args.src = diff_dst + diff_dst_off;
882 args.dst = diff_src + diff_src_off;
883 args.d = id;
884 args.h = ih;
885 args.w = iw;
886 (*kernel_)(&args);
887 });
888
889 return status_t();
890}
891
892} // namespace x64
893} // namespace cpu
894} // namespace impl
895} // namespace dnnl
896