1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/bfloat16.hpp" |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/x64/jit_generator.hpp" |
24 | |
25 | #include "cpu/x64/jit_avx512_core_resampling.hpp" |
26 | |
27 | #include "utils/jit_io_helper.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | namespace x64 { |
33 | |
34 | using namespace Xbyak; |
35 | |
36 | static bool impl_supports_datatype(data_type_t data_type) { |
37 | switch (data_type) { |
38 | case data_type::bf16: return x64::mayiuse(x64::avx512_core); |
39 | case data_type::f16: return x64::mayiuse(x64::avx512_core_fp16); |
40 | case data_type::f32: |
41 | case data_type::s32: |
42 | case data_type::s8: |
43 | case data_type::u8: return true; |
44 | default: return false; |
45 | } |
46 | } |
47 | |
// Byte offset of a field inside the runtime argument struct; used by the
// generated code to address fields relative to abi_param1.
#define GET_OFF(field) offsetof(jit_resampling_args_t, field)
// Runtime arguments handed to one invocation of the generated kernel:
// pointers to the tensors plus the spatial coordinates of the single
// point this call processes.
struct jit_resampling_args_t {
    const void *src; // fwd: src bwd: diff_dst
    const void *dst; // fwd: dst bwd: diff_src
    dim_t d; // fwd: od bwd: id
    dim_t h; // fwd: oh bwd: ih
    dim_t w; // fwd: ow bwd: iw
};
56 | |
57 | // jit kernels |
58 | namespace { |
59 | |
// JIT generator for avx512_core resampling (nearest and linear). One call of
// the generated kernel processes the full channel block of a single spatial
// point: fwd interpolates src into one dst point; bwd accumulates the range
// of diff_dst points that map onto one diff_src point. Spatial coordinates
// arrive at runtime via jit_resampling_args_t.
struct jit_avx512_core_resampling_kernel_t
    : public jit_avx512_core_resampling_kernel_base_t {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_resampling)

    jit_avx512_core_resampling_kernel_t(const resampling_pd_t *pd)
        : jit_avx512_core_resampling_kernel_base_t(pd, jit_name())
        , is_saturation_needed_(utils::one_of(dst_data_type(), data_type::u8,
                  data_type::s8, data_type::s32)) {

        // Element strides of the tensor the kernel reads from: src (I* dims)
        // on fwd, diff_dst (O* dims) on bwd. inner_stride_ is the channel
        // block size (innermost blocking stride).
        if (pd_->is_fwd()) {
            const memory_desc_wrapper src_d(pd_->src_md());
            inner_stride_ = src_d.blocking_desc().strides[pd_->ndims() - 1];
            stride_d_ = pd_->IH() * pd_->IW() * inner_stride_;
            stride_h_ = pd_->IW() * inner_stride_;
            stride_w_ = inner_stride_;
        } else {
            const memory_desc_wrapper diff_src_d(pd_->diff_src_md());
            inner_stride_
                    = diff_src_d.blocking_desc().strides[pd_->ndims() - 1];
            stride_d_ = pd_->OH() * pd_->OW() * inner_stride_;
            stride_h_ = pd_->OW() * inner_stride_;
            stride_w_ = inner_stride_;
        }

        // The channel block is processed simd_w() lanes at a time, plus a
        // masked tail of tail_size_ elements.
        number_of_loops_ = (inner_stride_ / simd_w());
        tail_size_ = inner_stride_ % simd_w();
        stack_size_needed_ = 0;

        cpu_isa_t isa
                = mayiuse(avx512_core_bf16) ? avx512_core_bf16 : avx512_core;

        // io_ abstracts dtype-aware vector load/store: bf16/f16 conversion,
        // tail masking through k_tail_mask, and integer saturation on store.
        const io::jit_io_multi_dt_helper_t<Zmm>::data_types_t data_types {
                src_data_type(), dst_data_type()};
        std::map<data_type_t, io::io_saturation_conf_t> saturation_conf {};
        if (is_saturation_needed_) {
            saturation_conf.emplace(dst_data_type(),
                    io::io_saturation_conf_t {zmm_zero_saturation_.getIdx(),
                            zmm_saturation_ubound_.getIdx(), reg_tmp});
        }

        io_ = utils::make_unique<io::jit_io_multi_dt_helper_t<Zmm>>(this, isa,
                data_types, io::io_conf_t {},
                io::io_tail_conf_t {
                        simd_w(), tail_size_, k_tail_mask, 0, reg_tmp},
                io::io_emu_bf16_conf_t {}, saturation_conf);
    }

private:
    // Which neighbor of a fractional coordinate a linear term refers to
    // (none = dimension not interpolated).
    enum class rounding_mode { none, floor, ceil, rounding_mode_max };

    // Stack addresses of the bwd loop bounds for one spatial dim: linear
    // has two [start, end) ranges (left/right neighbor contributions),
    // nearest a single one. Only the fields of the active algorithm are set.
    struct bwd_counting_range_t {
        RegExp loop_counter;
        struct start_t {
            RegExp linear[2];
            RegExp nearest;
        } start;
        struct end_t {
            RegExp linear[2];
            RegExp nearest;
        } end;
    };

    // dst_int_val = trunc(to_round + 0.5); for the non-negative coordinates
    // used here this is round-half-away-from-zero.
    void round_to_near_away_from_zero(const Reg64 &dst_int_val,
            const Xmm &to_round, const Xmm &zero_point_five, const Xmm &tmp) {
        EvexModifierRounding rm_trunc(EvexModifierRounding::T_RZ_SAE);
        vaddss(tmp, to_round, zero_point_five);
        vcvtss2si(dst_int_val, tmp | rm_trunc);
    }

    // Emits the head of "for (i = *start; i < *end; ++i)" with the counter
    // kept in memory at loop_counter; must be paired with for_end().
    void for_begin(Label &begin_loop, Label &end_loop,
            const RegExp &loop_counter, const RegExp &start, const RegExp &end,
            const Reg64 &tmp) {
        // for (initialization; check; incrementation)
        // initialization
        mov(tmp, ptr[start]);
        mov(ptr[loop_counter], tmp);
        L(begin_loop);
        // check
        mov(tmp, ptr[loop_counter]);
        cmp(tmp, ptr[end]);
        jge(end_loop, T_NEAR);
    }

    // Emits the increment-and-backedge tail of the loop opened by for_begin().
    void for_end(Label &begin_loop, Label &end_loop, const RegExp &loop_counter,
            const Reg64 &tmp) {
        // incrementation
        mov(tmp, ptr[loop_counter]);
        inc(tmp);
        mov(ptr[loop_counter], tmp);
        jmp(begin_loop, T_NEAR);
        L(end_loop);
    }

    // reg = max(reg, to_cmp); clobbers reg_tmp.
    void max(const Reg64 &reg, const dim_t to_cmp) {
        mov(reg_tmp, to_cmp);
        cmp(reg, reg_tmp);
        cmovl(reg, reg_tmp);
    }

    // reg = min(reg, to_cmp); clobbers reg_tmp.
    void min(const Reg64 &reg, const dim_t to_cmp) {
        mov(reg_tmp, to_cmp);
        cmp(reg, reg_tmp);
        cmovg(reg, reg_tmp);
    }

    // Materializes a float immediate into the low lane of xmm through tmp.
    void move_imm_float_to_xmm(const Xmm &xmm, const Reg64 &tmp, float imm) {
        mov(tmp.cvt32(), float2int(imm));
        vmovd(xmm, tmp.cvt32());
    }

    // Top-level emission: load runtime args, precompute per-dim fractional
    // coordinates (fwd) or loop ranges on the stack (bwd), then dispatch to
    // the algorithm-specific emitter.
    void generate() override {
        preamble();

        io_->init_bf16();
        if (is_saturation_needed_) io_->init_saturate_f32({dst_data_type()});
        if (tail_size_) io_->prepare_tail_mask();

        mov(reg_src, ptr[abi_param1 + GET_OFF(src)]);
        mov(reg_dst, ptr[abi_param1 + GET_OFF(dst)]);

        move_imm_float_to_xmm(xmm_zero_point_five, reg_tmp, 0.5f);

        if (pd_->is_fwd()) {
            // Count coeffs
            if (pd_->ndims() == 5) {
                mov(reg_curr_d, ptr[abi_param1 + GET_OFF(d)]);
                mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                count_dim_coeff(xmm_d_coeff, reg_curr_d, pd_->OD(), pd_->ID());
                count_dim_coeff(xmm_h_coeff, reg_curr_h, pd_->OH(), pd_->IH());
                count_dim_coeff(xmm_w_coeff, reg_curr_w, pd_->OW(), pd_->IW());
            } else if (pd_->ndims() == 4) {
                mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                count_dim_coeff(xmm_h_coeff, reg_curr_h, pd_->OH(), pd_->IH());
                count_dim_coeff(xmm_w_coeff, reg_curr_w, pd_->OW(), pd_->IW());
            } else {
                mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                count_dim_coeff(xmm_w_coeff, reg_curr_w, pd_->OW(), pd_->IW());
            }
        } else {
            if (pd_->desc()->alg_kind == alg_kind::resampling_linear) {
                // Stack:
                // ow_loop_counter:
                // ow_left_start : 8
                // ow_left_end : 16
                // ow_right_start : 24
                // ow_right_end : 32
                // ----- 3 dims -----
                // oh_loop_counter: 40
                // oh_left_start : 48
                // oh_left_end : 56
                // oh_right_start : 64
                // oh_right_end : 72
                // ----- 4 dims -----
                // od_loop_counter: 80
                // od_left_start : 88
                // od_left_end : 96
                // od_right_start : 104
                // od_right_end : 112
                // ----- 5 dims -----

                // 5*size(int64)*nr_of_spatial_dims
                stack_size_needed_ = 5 * 8 * (pd_->ndims() - 2);
                sub(rsp, stack_size_needed_);

                if (pd_->ndims() == 5) {
                    mov(reg_curr_d, ptr[abi_param1 + GET_OFF(d)]);
                    mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                    mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                    count_bwd_counting_range(
                            rsp + 80, od, reg_curr_d, pd_->OD(), pd_->ID());
                    count_bwd_counting_range(
                            rsp + 40, oh, reg_curr_h, pd_->OH(), pd_->IH());
                    count_bwd_counting_range(
                            rsp, ow, reg_curr_w, pd_->OW(), pd_->IW());
                } else if (pd_->ndims() == 4) {
                    mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                    mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                    count_bwd_counting_range(
                            rsp + 40, oh, reg_curr_h, pd_->OH(), pd_->IH());
                    count_bwd_counting_range(
                            rsp, ow, reg_curr_w, pd_->OW(), pd_->IW());
                } else {
                    mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                    count_bwd_counting_range(
                            rsp, ow, reg_curr_w, pd_->OW(), pd_->IW());
                }
            } else {
                // Stack:
                // ow_loop_counter:
                // ow_start : 8
                // ow_end : 16
                // oh_loop_counter: 24
                // oh_start : 32
                // oh_end : 40
                // od_loop_counter: 48
                // od_start : 56
                // od_end : 64

                // 3*size(int64)*max_nr_of_spatial_dims
                // (nearest always materializes all three dims; absent dims
                // have size 1, so their ranges degenerate to a single index)
                stack_size_needed_ = 3 * 8 * 3;
                sub(rsp, stack_size_needed_);

                mov(reg_curr_d, ptr[abi_param1 + GET_OFF(d)]);
                mov(reg_curr_h, ptr[abi_param1 + GET_OFF(h)]);
                mov(reg_curr_w, ptr[abi_param1 + GET_OFF(w)]);
                count_bwd_counting_range(
                        rsp + 48, od, reg_curr_d, pd_->OD(), pd_->ID());
                count_bwd_counting_range(
                        rsp + 24, oh, reg_curr_h, pd_->OH(), pd_->IH());
                count_bwd_counting_range(
                        rsp, ow, reg_curr_w, pd_->OW(), pd_->IW());
            }
        }

        // Choose algorithm
        if (pd_->desc()->alg_kind == alg_kind::resampling_linear) {
            if (pd_->ndims() == 5) {
                trilinear();
            } else if (pd_->ndims() == 4) {
                bilinear();
            } else {
                linear();
            }
        } else {
            nearest();
        }

        if (!pd_->is_fwd()) add(rsp, stack_size_needed_);
        postamble();
    }

    // Maps an integer position in the y-space to a fractional coordinate in
    // the x-space (align-half-pixel mapping); result in the low lane of
    // xmm_coeff. Clobbers xmm_tmp_factor and reg_tmp.
    void count_dim_coeff(const Xmm &xmm_coeff, const Reg64 &reg_dim,
            dim_t y_max, dim_t x_max) {
        // Formula = ((y + 0.5f) * x_max / y_max) - 0.5f
        vcvtsi2ss(xmm_coeff, xmm_coeff, reg_dim); // y
        vaddss(xmm_coeff, xmm_coeff, xmm_zero_point_five); // y + 0.5f

        move_imm_float_to_xmm(xmm_tmp_factor, reg_tmp, (float)x_max);
        vmulss(xmm_coeff, xmm_coeff,
                xmm_tmp_factor); // (y + 0.5f) * x_max
        move_imm_float_to_xmm(xmm_tmp_factor, reg_tmp, (float)y_max);
        vdivss(xmm_coeff, xmm_coeff,
                xmm_tmp_factor); // (y + 0.5f) * x_max / y_max

        vsubss(xmm_coeff, xmm_coeff,
                xmm_zero_point_five); // ((y + 0.5) * x_max / y_max) - 0.5
    }

    // Computes, on the stack at stack_position, the ranges of output points
    // that contribute to the input point held in curr_position, and records
    // their addresses in c_range. Note: the linear branch leaves
    // curr_position incremented by one (sub 1 then add 2).
    void count_bwd_counting_range(RegExp stack_position,
            bwd_counting_range_t &c_range, const Reg64 &curr_position,
            dim_t y_max, dim_t x_max) {
        c_range.loop_counter = stack_position;
        if (pd_->desc()->alg_kind == alg_kind::resampling_linear) {
            c_range.start.linear[0] = stack_position + 8;
            c_range.end.linear[0] = stack_position + 16;
            c_range.start.linear[1] = stack_position + 24;
            c_range.end.linear[1] = stack_position + 32;
        } else {
            c_range.start.nearest = stack_position + 8;
            c_range.end.nearest = stack_position + 16;
        }

        EvexModifierRounding rm_ceil(EvexModifierRounding::T_RU_SAE);
        EvexModifierRounding rm_floor(EvexModifierRounding::T_RD_SAE);

        if (pd_->desc()->alg_kind == alg_kind::resampling_linear) {
            // coeff = (pos + 0.5) * y_max / x_max - 0.5
            count_dim_coeff(xmm_coeff, curr_position, x_max, y_max);

            // l_start: x == 0 ? 0 : ceil(coeff)
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_ceil);
            mov(reg_tmp, 0);
            cmp(curr_position, reg_tmp);
            cmove(reg_tmp_idx, reg_tmp);
            mov(ptr[c_range.start.linear[0]], reg_tmp_idx);

            // r_end: x == x_max-1 ? y_max : min(max(0, floor(coeff) + 1), y_max)
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_floor);
            add(reg_tmp_idx, 1);
            max(reg_tmp_idx, 0);
            min(reg_tmp_idx, y_max);
            cmp(curr_position, x_max - 1);
            mov(reg_tmp, y_max);
            cmove(reg_tmp_idx, reg_tmp);
            mov(ptr[c_range.end.linear[1]], reg_tmp_idx);

            // coeff = ((pos-1) + 0.5) * y_max / x_max - 0.5
            sub(curr_position, 1);
            count_dim_coeff(xmm_coeff, curr_position, x_max, y_max);

            // r_start: max(0, floor(coeff) + 1)
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_floor);
            add(reg_tmp_idx, 1);
            max(reg_tmp_idx, 0);
            mov(ptr[c_range.start.linear[1]], reg_tmp_idx);

            // coeff = ((pos+1) + 0.5) * y_max / x_max - 0.5
            add(curr_position, 2);
            count_dim_coeff(xmm_coeff, curr_position, x_max, y_max);

            // l_end: min(ceil(coeff), y_max)
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_ceil);
            min(reg_tmp_idx, y_max);
            mov(ptr[c_range.end.linear[0]], reg_tmp_idx);
        } else {
            float factor = (float)y_max / x_max;

            // start: ceil(pos * factor - 0.5f)
            vcvtsi2ss(xmm_coeff, xmm_coeff, curr_position);
            move_imm_float_to_xmm(xmm_tmp_factor, reg_tmp, factor);
            vmulss(xmm_coeff, xmm_coeff, xmm_tmp_factor);
            vsubss(xmm_coeff, xmm_coeff, xmm_zero_point_five);
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_ceil);
            mov(ptr[c_range.start.nearest], reg_tmp_idx);

            // end: ceil((pos+1) * factor - 0.5f)
            add(curr_position, 1);
            vcvtsi2ss(xmm_coeff, xmm_coeff, curr_position);
            vmulss(xmm_coeff, xmm_coeff, xmm_tmp_factor);
            vsubss(xmm_coeff, xmm_coeff, xmm_zero_point_five);
            vcvtss2si(reg_tmp_idx, xmm_coeff | rm_ceil);
            mov(ptr[c_range.end.nearest], reg_tmp_idx);
        }
    }

    // From the fractional coordinate `coeff` derives the neighbor index
    // (into idx) and the broadcast interpolation weight for one dimension.
    // fwd additionally clamps idx into [0, dim_max); on bwd the index is
    // later replaced by the loop counter in linear_alg, so only the weight
    // matters. Clobbers xmm_tmp/reg_tmp.
    void count_idx_and_weight_for_linear(const Xmm &coeff, const Zmm &weight,
            const Reg64 &idx, dim_t dim_max, rounding_mode rm) {
        const Xmm xmm_weight = Xmm(weight.getIdx());
        Reg64 reg_idx_floor;

        if (pd_->is_fwd() && rm == rounding_mode::ceil) {
            EvexModifierRounding rm_ceil(EvexModifierRounding::T_RU_SAE);
            EvexModifierRounding rm_floor(EvexModifierRounding::T_RD_SAE);
            vcvtss2si(idx,
                    coeff | rm_ceil); // ceil(coeff)
            reg_idx_floor = reg_tmp;
            vcvtss2si(reg_idx_floor,
                    coeff | rm_floor); // floor(coeff)

        } else {
            EvexModifierRounding rm_floor(EvexModifierRounding::T_RD_SAE);
            vcvtss2si(idx,
                    coeff | rm_floor); // floor(coeff)
            reg_idx_floor = idx;
        }

        // Weight is always computed against floor(coeff) and flipped for
        // the floor neighbor.
        vcvtsi2ss(xmm_tmp, xmm_tmp, reg_idx_floor);
        vsubss(xmm_weight, coeff,
                zmm_tmp); // W = coeff - idx
        if (rm == rounding_mode::floor) {
            move_imm_float_to_xmm(xmm_tmp, reg_tmp, 1.0f);
            vsubss(xmm_weight, xmm_tmp,
                    xmm_weight); // W = 1 - (coeff - idx)
        }
        vbroadcastss(weight, xmm_weight);

        if (pd_->is_fwd()) {
            if (rm == rounding_mode::ceil) {
                min(idx, dim_max - 1);
            } else if (rm == rounding_mode::floor) {
                max(idx, 0);
            }
        }
    }

    // Accumulates one linear-interpolation term into zmm_dst: computes the
    // combined weight of the selected neighbors, the flattened element
    // offset, loads the source vector and fuses weight * src into zmm_dst.
    void linear_alg(int64_t channel_offset, rounding_mode rm_w,
            rounding_mode rm_h = rounding_mode::none,
            rounding_mode rm_d = rounding_mode::none, bool is_tail = false) {
        xor_(reg_offset, reg_offset); // reg_offset = 0

        if (rm_w != rounding_mode::none) {
            // out: Ww, curr_w
            count_idx_and_weight_for_linear(
                    xmm_w_coeff, zmm_weight, reg_curr_w, pd_->IW(), rm_w);
            // curr_w * stride_w_
            if (!pd_->is_fwd()) mov(reg_curr_w, ptr[ow.loop_counter]);
            imul(reg_offset, reg_curr_w, stride_w_);
        }
        if (rm_h != rounding_mode::none) {
            // out: Wh, curr_h
            count_idx_and_weight_for_linear(
                    xmm_h_coeff, zmm_tmp_weight, reg_curr_h, pd_->IH(), rm_h);
            // Ww * Wh
            vmulps(zmm_weight, zmm_weight, zmm_tmp_weight);
            // curr_w * stride_w_ + curr_h * stride_h_
            if (!pd_->is_fwd()) mov(reg_curr_h, ptr[oh.loop_counter]);
            imul(reg_tmp, reg_curr_h, stride_h_);
            add(reg_offset, reg_tmp);
        }
        if (rm_d != rounding_mode::none) {
            // out: Wd, curr_d
            count_idx_and_weight_for_linear(
                    xmm_d_coeff, zmm_tmp_weight, reg_curr_d, pd_->ID(), rm_d);
            // Ww * Wh * Wd
            vmulps(zmm_weight, zmm_weight, zmm_tmp_weight);
            // curr_w * stride_w_ + curr_h * stride_h_ + curr_d * stride_d_
            if (!pd_->is_fwd()) mov(reg_curr_d, ptr[od.loop_counter]);
            imul(reg_tmp, reg_curr_d, stride_d_);
            add(reg_offset, reg_tmp);
        }

        add(reg_offset, channel_offset);
        imul(reg_offset, reg_offset, types::data_type_size(src_data_type()));

        // read src
        io_->at(src_data_type())
                ->load(ptr[reg_src + reg_offset], zmm_src, is_tail);

        // mul src, weight
        vmulps(zmm_tmp, zmm_src, zmm_weight);
        vaddps(zmm_dst, zmm_dst, zmm_tmp);
    }

    // 1D linear resampling over W. fwd sums the two neighbor terms; bwd
    // loops over both [start, end) diff_dst ranges, recomputing the
    // coordinate per iteration. Note the bwd i+1 shift flips the
    // floor/ceil parity relative to fwd.
    void linear() {
        int64_t number_of_processed_points = 0;

        auto resample_linear = ([&](bool is_tail) {
            auto call_linear = ([&](int i) {
                linear_alg(number_of_processed_points,
                        i % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_w */,
                        rounding_mode::none /* rounding_mode_h */,
                        rounding_mode::none /* rounding_mode_d */, is_tail);
            });

            // zero dst
            vpxorq(zmm_dst, zmm_dst, zmm_dst);

            if (pd_->is_fwd()) {
                for (int i = 0; i < 2; i++) {
                    call_linear(i);
                }
            } else {
                Label label[2][2];

                for (int i = 0; i < 2; i++) {
                    // for (dim_t ow = w.start[i]; ow < w.end[i]; ow++)
                    for_begin(label[i][0], label[i][1], ow.loop_counter,
                            ow.start.linear[i], ow.end.linear[i], reg_tmp);
                    count_dim_coeff(xmm_w_coeff, reg_tmp, pd_->OW(), pd_->IW());

                    call_linear(i + 1);

                    for_end(label[i][0], label[i][1], ow.loop_counter, reg_tmp);
                }
            }

            const size_t offset = number_of_processed_points
                    * types::data_type_size(dst_data_type());
            const auto address = ptr[reg_dst + offset];
            // store dst
            io_->at(dst_data_type())->store(zmm_dst, address, is_tail);
        });

        for (unsigned i = 0; i < number_of_loops_;
                i++, number_of_processed_points += simd_w())
            resample_linear(false);

        if (tail_size_ != 0) resample_linear(true);
    }

    // 2D linear resampling over (H, W): 4 neighbor terms on fwd, 4 nested
    // loop pairs on bwd.
    void bilinear() {
        int64_t number_of_processed_points = 0;

        auto resample_linear = ([&](bool is_tail) {
            auto call_linear = ([&](int i, int j) {
                linear_alg(number_of_processed_points,
                        i % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_w */,
                        j % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_h */,
                        rounding_mode::none /* rounding_mode_d */, is_tail);
            });

            // zero dst
            vpxorq(zmm_dst, zmm_dst, zmm_dst);

            if (pd_->is_fwd()) {
                for (int i = 0; i < 2; i++) {
                    for (int j = 0; j < 2; j++) {
                        call_linear(i, j);
                    }
                }
            } else {
                Label label[2][2][4];

                for (int i = 0; i < 2; i++) {
                    for (int j = 0; j < 2; j++) {
                        // for (dim_t ow = w.start[i]; ow < w.end[i]; ow++)
                        for_begin(label[i][j][0], label[i][j][1],
                                ow.loop_counter, ow.start.linear[i],
                                ow.end.linear[i], reg_tmp);
                        count_dim_coeff(
                                xmm_w_coeff, reg_tmp, pd_->OW(), pd_->IW());
                        // for (dim_t oh = h.start[j]; oh < h.end[j]; oh++)
                        for_begin(label[i][j][2], label[i][j][3],
                                oh.loop_counter, oh.start.linear[j],
                                oh.end.linear[j], reg_tmp);
                        count_dim_coeff(
                                xmm_h_coeff, reg_tmp, pd_->OH(), pd_->IH());

                        call_linear(i + 1, j + 1);

                        for_end(label[i][j][2], label[i][j][3], oh.loop_counter,
                                reg_tmp);
                        for_end(label[i][j][0], label[i][j][1], ow.loop_counter,
                                reg_tmp);
                    }
                }
            }

            const size_t offset = number_of_processed_points
                    * types::data_type_size(dst_data_type());
            const auto address = ptr[reg_dst + offset];
            // store dst
            io_->at(dst_data_type())->store(zmm_dst, address, is_tail);
        });

        for (unsigned i = 0; i < number_of_loops_;
                i++, number_of_processed_points += simd_w())
            resample_linear(false);

        if (tail_size_ != 0) resample_linear(true);
    }

    // 3D linear resampling over (D, H, W): 8 neighbor terms on fwd, 8
    // nested loop triples on bwd.
    void trilinear() {
        int64_t number_of_processed_points = 0;

        auto resample_linear = ([&](bool is_tail) {
            auto call_linear = ([&](int i, int j, int k) {
                linear_alg(number_of_processed_points,
                        i % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_w */,
                        j % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_h */,
                        k % 2 ? rounding_mode::floor
                              : rounding_mode::ceil /* rounding_mode_d */,
                        is_tail);
            });

            // zero dst
            vpxorq(zmm_dst, zmm_dst, zmm_dst);

            if (pd_->is_fwd()) {
                for (int i = 0; i < 2; i++) {
                    for (int j = 0; j < 2; j++) {
                        for (int k = 0; k < 2; k++) {
                            call_linear(i, j, k);
                        }
                    }
                }
            } else {
                Label label[2][2][2][6];

                for (int i = 0; i < 2; i++) {
                    for (int j = 0; j < 2; j++) {
                        for (int k = 0; k < 2; k++) {
                            // for (dim_t ow = w.start[i]; ow < w.end[i]; ow++)
                            for_begin(label[i][j][k][0], label[i][j][k][1],
                                    ow.loop_counter, ow.start.linear[i],
                                    ow.end.linear[i], reg_tmp);
                            count_dim_coeff(
                                    xmm_w_coeff, reg_tmp, pd_->OW(), pd_->IW());
                            // for (dim_t oh = h.start[j]; oh < h.end[j]; oh++)
                            for_begin(label[i][j][k][2], label[i][j][k][3],
                                    oh.loop_counter, oh.start.linear[j],
                                    oh.end.linear[j], reg_tmp);
                            count_dim_coeff(
                                    xmm_h_coeff, reg_tmp, pd_->OH(), pd_->IH());
                            // for (dim_t od = d.start[k]; od < d.end[k]; od++)
                            for_begin(label[i][j][k][4], label[i][j][k][5],
                                    od.loop_counter, od.start.linear[k],
                                    od.end.linear[k], reg_tmp);
                            count_dim_coeff(
                                    xmm_d_coeff, reg_tmp, pd_->OD(), pd_->ID());

                            call_linear(i + 1, j + 1, k + 1);

                            for_end(label[i][j][k][4], label[i][j][k][5],
                                    od.loop_counter, reg_tmp);
                            for_end(label[i][j][k][2], label[i][j][k][3],
                                    oh.loop_counter, reg_tmp);
                            for_end(label[i][j][k][0], label[i][j][k][1],
                                    ow.loop_counter, reg_tmp);
                        }
                    }
                }
            }

            const size_t offset = number_of_processed_points
                    * types::data_type_size(dst_data_type());
            const auto address = ptr[reg_dst + offset];
            // store dst
            io_->at(dst_data_type())->store(zmm_dst, address, is_tail);
        });

        for (unsigned i = 0; i < number_of_loops_;
                i++, number_of_processed_points += simd_w())
            resample_linear(false);

        if (tail_size_ != 0) resample_linear(true);
    }

    // One nearest-neighbor term. fwd rounds each coordinate and clamps it;
    // for problems with fewer spatial dims the corresponding I{H,D} is 1,
    // so the clamp below forces those (otherwise uninitialized) indices to
    // 0. bwd reads the current indices straight from the loop counters.
    void nearest_alg(int64_t channel_offset, bool is_tail = false) {
        xor_(reg_offset, reg_offset); // reg_offset = 0

        auto get_idx = ([&](const Reg64 &idx, const Xmm &coeff,
                                dim_t dim_max_size) {
            round_to_near_away_from_zero(idx, coeff, xmm_zero_point_five,
                    xmm_tmp); // round_to_nearest(coeff)
            min(idx, dim_max_size - 1);
            max(idx, 0);
        });

        if (pd_->is_fwd()) {
            get_idx(reg_curr_w, xmm_w_coeff, pd_->IW());
            get_idx(reg_curr_h, xmm_h_coeff, pd_->IH());
            get_idx(reg_curr_d, xmm_d_coeff, pd_->ID());
        } else {
            mov(reg_curr_w, ptr[ow.loop_counter]);
            mov(reg_curr_h, ptr[oh.loop_counter]);
            mov(reg_curr_d, ptr[od.loop_counter]);
        }

        imul(reg_offset, reg_curr_w, stride_w_); // iw * stride_w_
        imul(reg_tmp, reg_curr_h, stride_h_);
        add(reg_offset, reg_tmp); // iw * stride_w_ + ih * stride_h_
        imul(reg_tmp, reg_curr_d, stride_d_);
        add(reg_offset,
                reg_tmp); // iw * stride_w_ + ih * stride_h_ + id * stride_d_

        add(reg_offset,
                channel_offset); // iw * stride_w_ + ih * stride_h_ + id * stride_d_ + channel_offset
        imul(reg_offset, reg_offset,
                types::data_type_size(
                        src_data_type())); // (iw * stride_w_ + ih * stride_h_ + id * stride_d_ + channel_offset)*dt_size

        if (pd_->is_fwd()) {
            // read nearest to dst
            io_->at(src_data_type())
                    ->load(ptr[reg_src + reg_offset], zmm_dst, is_tail);
        } else {
            // add nearest to dst
            io_->at(src_data_type())
                    ->load(ptr[reg_src + reg_offset], zmm_tmp, is_tail);
            vaddps(zmm_dst, zmm_dst, zmm_tmp);
        }
    }

    // Nearest-neighbor resampling: a single gather on fwd, a triple loop
    // over the contributing diff_dst ranges on bwd.
    void nearest() {
        int64_t number_of_processed_points = 0;

        auto resample_nearest = ([&](bool is_tail) {
            // zero dst
            vpxorq(zmm_dst, zmm_dst, zmm_dst);

            if (pd_->is_fwd()) {
                nearest_alg(number_of_processed_points, is_tail);
            } else {
                Label label[6];

                // for (dim_t ow = w.start[i]; ow < w.end[i]; ow++)
                for_begin(label[0], label[1], ow.loop_counter, ow.start.nearest,
                        ow.end.nearest, reg_tmp);
                // for (dim_t oh = h.start[j]; oh < h.end[j]; oh++)
                for_begin(label[2], label[3], oh.loop_counter, oh.start.nearest,
                        oh.end.nearest, reg_tmp);
                // for (dim_t od = d.start[k]; od < d.end[k]; od++)
                for_begin(label[4], label[5], od.loop_counter, od.start.nearest,
                        od.end.nearest, reg_tmp);

                nearest_alg(number_of_processed_points, is_tail);

                for_end(label[4], label[5], od.loop_counter, reg_tmp);
                for_end(label[2], label[3], oh.loop_counter, reg_tmp);
                for_end(label[0], label[1], ow.loop_counter, reg_tmp);
            }

            const size_t offset = number_of_processed_points
                    * types::data_type_size(dst_data_type());
            const auto address = ptr[reg_dst + offset];
            // store dst
            io_->at(dst_data_type())->store(zmm_dst, address, is_tail);
        });

        for (unsigned i = 0; i < number_of_loops_;
                i++, number_of_processed_points += simd_w())
            resample_nearest(false);

        if (tail_size_ != 0) resample_nearest(true);
    }

    // Number of f32 lanes in one zmm register (16).
    static constexpr std::size_t simd_w() {
        return cpu_isa_traits<avx512_core>::vlen / sizeof(float);
    }

    // Register layout. Several registers intentionally alias because they
    // are used in disjoint phases of the generated code: xmm_coeff and
    // xmm_d_coeff share Xmm(4); zmm_tmp/xmm_tmp are views of reg 8;
    // zmm_tmp_weight/xmm_tmp_factor of reg 9; reg_offset and reg_tmp_idx
    // both map to r12.
    Zmm zmm_src = Zmm(1);
    Zmm zmm_dst = Zmm(2);
    Zmm zmm_weight = Zmm(3);
    Xmm xmm_coeff = Xmm(4);
    Xmm xmm_d_coeff = Xmm(4);
    Xmm xmm_h_coeff = Xmm(5);
    Xmm xmm_w_coeff = Xmm(6);
    Xmm xmm_zero_point_five = Xmm(7);
    Zmm zmm_tmp = Zmm(8);
    Xmm xmm_tmp = Xmm(8);
    Zmm zmm_tmp_weight = Zmm(9);
    Xmm xmm_tmp_factor = Xmm(9);
    Zmm zmm_zero_saturation_ = Zmm(10);
    Zmm zmm_saturation_ubound_ = Zmm(11);

    Opmask k_tail_mask = k6;
    Reg64 reg_src = rax;
    Reg64 reg_dst = rbx;
    Reg64 reg_tmp = r8;
    Reg64 reg_curr_d = r9;
    Reg64 reg_curr_h = r10;
    Reg64 reg_curr_w = r11;
    Reg64 reg_offset = r12;
    Reg64 reg_tmp_idx = r12;

    // Backward loop ranges (addresses into the stack frame set up in
    // generate()).
    bwd_counting_range_t ow;
    bwd_counting_range_t oh;
    bwd_counting_range_t od;

    // Data-type-aware vector load/store helper.
    std::unique_ptr<io::jit_io_multi_dt_helper_t<Zmm>> io_;

    dim_t stride_d_ = 0; // element stride between consecutive d points
    dim_t stride_h_ = 0; // element stride between consecutive h points
    dim_t stride_w_ = 0; // element stride between consecutive w points
    dim_t inner_stride_ = 0; // channel-block size (innermost stride)
    unsigned number_of_loops_ = 0; // full simd_w() chunks per channel block
    size_t tail_size_ = 0; // remaining masked elements
    bool is_saturation_needed_ = false; // dst is s32/s8/u8
    unsigned stack_size_needed_ = 0; // bwd-only scratch frame size
};
798 | |
799 | } // namespace |
800 | |
// Base-class ctor: stores the primitive descriptor and forwards the kernel
// name to jit_generator (used for kernel identification/profiling).
jit_avx512_core_resampling_kernel_base_t::
        jit_avx512_core_resampling_kernel_base_t(
                const resampling_pd_t *pd, const char *name)
    : jit_generator(name), pd_(pd) {}
805 | |
806 | data_type_t jit_avx512_core_resampling_kernel_base_t::src_data_type() const { |
807 | if (pd_->is_fwd()) |
808 | return pd_->src_md()->data_type; |
809 | else |
810 | return pd_->diff_dst_md()->data_type; |
811 | } |
812 | |
813 | data_type_t jit_avx512_core_resampling_kernel_base_t::dst_data_type() const { |
814 | if (pd_->is_fwd()) |
815 | return pd_->dst_md()->data_type; |
816 | else |
817 | return pd_->diff_src_md()->data_type; |
818 | } |
819 | |
820 | status_t jit_avx512_core_resampling_bwd_t::pd_t::init(engine_t *engine) { |
821 | using namespace format_tag; |
822 | using namespace data_type; |
823 | const bool ok = mayiuse(avx512_core) && !is_fwd() && !has_zero_dim_memory() |
824 | && impl_supports_datatype(diff_dst_md()->data_type) |
825 | && impl_supports_datatype(diff_src_md()->data_type) |
826 | && IMPLICATION(diff_src_md()->data_type == f16, |
827 | mayiuse(avx512_core_fp16) |
828 | && memory_desc_wrapper(diff_src_md()).is_plain()) |
829 | && set_default_params() == status::success |
830 | && attr()->has_default_values(); |
831 | if (!ok) return status::unimplemented; |
832 | |
833 | format_tag_t dat_tag = memory_desc_matches_one_of_tag(*diff_src_md(), nCw8c, |
834 | nChw8c, nCdhw8c, nCw16c, nChw16c, nCdhw16c, nwc, nhwc, ndhwc); |
835 | if (!memory_desc_matches_tag(*diff_dst_md(), dat_tag)) |
836 | return status::unimplemented; |
837 | |
838 | return status::success; |
839 | } |
840 | |
// Defaulted out-of-line — presumably so the smart-pointer member kernel_
// can be destroyed where the kernel type is complete (it is defined only in
// this translation unit); confirm against the class declaration.
jit_avx512_core_resampling_bwd_t::~jit_avx512_core_resampling_bwd_t() = default;
842 | |
// Instantiates the backward kernel generator and JIT-compiles it.
// Returns an error status if allocation or code generation fails.
status_t jit_avx512_core_resampling_bwd_t::init(engine_t *engine) {
    CHECK(safe_ptr_assign(
            kernel_, new jit_avx512_core_resampling_kernel_t(pd())));
    return kernel_->create_kernel();
}
848 | |
849 | status_t jit_avx512_core_resampling_bwd_t::execute( |
850 | const exec_ctx_t &ctx) const { |
851 | |
852 | const auto diff_dst = CTX_IN_MEM(const unsigned char *, DNNL_ARG_DIFF_DST); |
853 | auto diff_src = CTX_OUT_MEM(unsigned char *, DNNL_ARG_DIFF_SRC); |
854 | |
855 | const std::size_t diff_dst_dt_size |
856 | = types::data_type_size(pd()->diff_dst_md()->data_type); |
857 | const std::size_t diff_src_dt_size |
858 | = types::data_type_size(pd()->diff_src_md()->data_type); |
859 | |
860 | const dim_t OD = pd()->OD(); |
861 | const dim_t OH = pd()->OH(); |
862 | const dim_t OW = pd()->OW(); |
863 | const dim_t ID = pd()->ID(); |
864 | const dim_t IH = pd()->IH(); |
865 | const dim_t IW = pd()->IW(); |
866 | |
867 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
868 | const dim_t inner_stride |
869 | = diff_src_d.blocking_desc().strides[pd()->ndims() - 1]; |
870 | const dim_t nsp_outer |
871 | = diff_src_d.nelems(true) / (ID * IH * IW * inner_stride); |
872 | |
873 | parallel_nd(nsp_outer, ID, IH, IW, |
874 | [&](dim_t nsp, dim_t id, dim_t ih, dim_t iw) { |
875 | const dim_t diff_dst_off |
876 | = nsp * OD * OH * OW * inner_stride * diff_dst_dt_size; |
877 | const dim_t diff_src_off |
878 | = (nsp * ID * IH * IW + id * IH * IW + ih * IW + iw) |
879 | * inner_stride * diff_src_dt_size; |
880 | jit_resampling_args_t args; |
881 | args.src = diff_dst + diff_dst_off; |
882 | args.dst = diff_src + diff_src_off; |
883 | args.d = id; |
884 | args.h = ih; |
885 | args.w = iw; |
886 | (*kernel_)(&args); |
887 | }); |
888 | |
889 | return status_t(); |
890 | } |
891 | |
892 | } // namespace x64 |
893 | } // namespace cpu |
894 | } // namespace impl |
895 | } // namespace dnnl |
896 | |