1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <numeric> |
19 | |
20 | #include "oneapi/dnnl/dnnl_debug.h" |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/dnnl_thread.hpp" |
24 | #include "common/memory_desc_wrapper.hpp" |
25 | #include "common/nstl.hpp" |
26 | #include "common/primitive.hpp" |
27 | #include "common/type_helpers.hpp" |
28 | #include "common/utils.hpp" |
29 | |
30 | #include "cpu/cpu_primitive.hpp" |
31 | #include "cpu/reorder/cpu_reorder_pd.hpp" |
32 | #include "cpu/x64/jit_uni_reorder.hpp" |
33 | |
34 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
35 | #include "cpu/x64/jit_generator.hpp" |
36 | |
37 | // #define TR_DEBUG |
38 | #if defined(TR_DEBUG) |
39 | #define DEBUg(...) \ |
40 | do { \ |
41 | __VA_ARGS__ \ |
42 | } while (0) |
43 | #else |
44 | #define DEBUg(...) |
45 | #endif |
46 | #define DEBUG(...) DEBUg(__VA_ARGS__) |
47 | |
48 | #ifdef _WIN32 |
49 | /* seems like s_addr is a reserved macro on Windows */ |
50 | #undef s_addr |
51 | constexpr static bool is_windows = true; |
52 | #else |
53 | constexpr static bool is_windows = false; |
54 | #endif |
55 | |
56 | using namespace Xbyak; |
57 | using namespace dnnl::impl::types; |
58 | |
59 | namespace dnnl { |
60 | namespace impl { |
61 | namespace cpu { |
62 | namespace x64 { |
63 | |
64 | namespace tr { |
65 | |
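/* The kernel computes element offsets in 32-bit arithmetic, so for every
 * node the span stride * n * data_type_size must stay below 2^31.
 * Illustration (assumed sizes): for an f32 node with n = 1024,
 * cms = (2^31 - 1) / 1024 = 2097151, so strides below
 * 2097151 / 4 = 524287 elements pass the check. */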
66 | static bool prb_has_small_strides(const prb_t &prb) { |
67 | constexpr ptrdiff_t max_stride = (1LL << 31) - 1; |
68 | for (int d = 0; d < prb.ndims; ++d) { |
69 | const ptrdiff_t cms = max_stride / prb.nodes[d].n; |
70 | const bool small_strides = true |
71 | && prb.nodes[d].is < cms / (int)data_type_size(prb.itype) |
72 | && prb.nodes[d].os < cms / (int)data_type_size(prb.otype); |
73 | if (!small_strides) return false; |
74 | } |
75 | return true; |
76 | } |
77 | |
78 | /** Minimal reasonable/desirable kernel size. |
79 | * The constant might be used to determine how a problem should be split |
80 | * between kernel and threading driver. */ |
81 | const size_t ker_prb_size_min = 64; |
82 | |
83 | /* kernel */ |
84 | struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { |
85 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32) |
86 | |
87 | void operator()(const call_param_t *c) const override { |
88 | jit_generator::operator()(c); |
89 | } |
90 | void operator()(const tail_call_param_t *c) const override { |
91 | jit_generator::operator()(c); |
92 | } |
93 | |
94 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
95 | |
96 | enum class scale_arg_t { NONE, SRC, DST }; |
97 | |
98 | enum { |
99 | len_unroll_max = 256, |
100 | ndims_jit_loop_max = 3, |
101 | }; |
102 | |
103 | struct simple_impl_desc_t { |
104 | int ndims_full_unroll = 0; |
105 | int len_last_dim_unroll = 0; |
106 | int tail_len_unroll = 0; |
107 | int len_unroll = 0; |
108 | }; |
109 | |
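// PARAM(x) computes the address of kernel argument 'x' relative to
// abi_param1. When a tail is present the kernel receives a
// tail_call_param_t, which embeds the regular call_param_t as its
// base_params member, hence the extra offsetof in that case.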
110 | #define PARAM(x) \ |
111 | prb_.is_tail_present \ |
112 | ? ptr[abi_param1 + offsetof(tail_call_param_t, base_params) \ |
113 | + offsetof(call_param_t, x)] \ |
114 | : ptr[abi_param1 + offsetof(call_param_t, x)] |
115 | #define TAIL_PARAM(x) ptr[abi_param1 + offsetof(tail_call_param_t, x)] |
116 | |
117 | static bool simple_impl_desc_init( |
118 | const prb_t &prb, simple_impl_desc_t *desc) { |
119 | const int ndims = prb.ndims; |
120 | |
121 | int ndims_full_unroll = 0; |
122 | int len_last_dim_unroll = 1; |
123 | int tail_len_unroll = 0; |
124 | int len_unroll = 1; |
125 | |
        // This code determines how many values the kernel can unroll.
        // If a tail is present, the kernel unrolls only a single node
        // (possible improvement). If there is no tail, the kernel can
        // unroll a few nodes without any loops etc.
        // ndims_full_unroll - how many nodes will be unrolled
        // len_last_dim_unroll - which piece of the last unrolled node
        // will be unrolled
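        // Illustration (no tail), assuming nodes n = {4, 8, 16} and
        // len_unroll_max = 256: the first two nodes unroll fully
        // (len_unroll = 32), the third does not fit (32 * 16 = 512 > 256),
        // so len_last_dim_unroll = 256 / 32 = 8, which divides 16, giving
        // a final len_unroll of 32 * 8 = 256 with ndims_full_unroll = 2.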
132 | if (prb.is_tail_present) { |
133 | ndims_full_unroll = 1; |
134 | len_unroll = prb.nodes[0].n; |
135 | tail_len_unroll = prb.nodes[0].is_zero_pad_needed |
136 | ? 0 |
137 | : static_cast<int>(prb.nodes[0].tail_size); |
138 | } else { |
139 | for (int d = 0; d < ndims; ++d) { |
140 | const auto &node = prb.nodes[d]; |
141 | if (len_unroll * node.n <= len_unroll_max) { |
142 | ndims_full_unroll++; |
143 | len_unroll *= node.n; |
144 | } else { |
145 | len_last_dim_unroll = len_unroll_max / len_unroll; |
146 | while (node.n % len_last_dim_unroll) |
147 | --len_last_dim_unroll; |
148 | len_unroll *= len_last_dim_unroll; |
149 | break; |
150 | } |
151 | } |
152 | } |
153 | |
154 | if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) return false; |
155 | |
156 | if (desc) { |
157 | desc->ndims_full_unroll = ndims_full_unroll; |
158 | desc->len_last_dim_unroll = len_last_dim_unroll; |
159 | desc->tail_len_unroll = tail_len_unroll; |
160 | desc->len_unroll = len_unroll; |
161 | } |
162 | |
163 | return true; |
164 | } |
165 | |
166 | static bool applicable(const prb_t &p) { |
167 | using namespace data_type; |
168 | |
169 | bool ok = true && p.ndims > 0 |
170 | && utils::one_of(p.itype, f32, bf16, f16, s32, s8, u8) |
171 | && utils::one_of(p.otype, f32, bf16, f16, s32, s8, u8) |
172 | && IMPLICATION(utils::one_of(p.itype, bf16, f16), |
173 | utils::one_of(p.otype, s8, u8, f32, bf16, f16)) |
174 | && IMPLICATION(utils::one_of(p.otype, bf16, f16), |
175 | utils::one_of(p.itype, s8, u8, f32, bf16, f16)) |
176 | && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ |
177 | && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ |
178 | && simple_impl_desc_init(p, nullptr) && mayiuse(sse41) |
179 | && IMPLICATION(utils::one_of(bf16, p.itype, p.otype), |
180 | mayiuse(avx512_core) || mayiuse(avx2_vnni_2)) |
181 | && IMPLICATION(utils::one_of(f16, p.itype, p.otype), |
182 | mayiuse(avx512_core_fp16) || mayiuse(avx2_vnni_2)) |
183 | && prb_has_small_strides(p); |
184 | |
185 | return ok; |
186 | } |
187 | |
188 | Address i_addr(int i_off) { |
189 | return ptr[reg_ptr_in_ + reg_off_in_ + i_off * itype_sz_]; |
190 | } |
191 | |
192 | Address o_addr(int o_off, bool with_type_multiplier = true) { |
193 | if (with_type_multiplier) |
194 | return ptr[reg_ptr_out_ + reg_off_out_ + o_off * otype_sz_]; |
195 | else |
196 | return ptr[reg_ptr_out_ + reg_off_out_ + o_off]; |
197 | } |
198 | |
199 | Address src_s_addr(int s_off) { |
200 | return ptr[reg_ptr_src_scales_ + reg_off_scale_ + s_off * stype_sz_]; |
201 | } |
202 | |
203 | Address dst_s_addr(int s_off) { |
204 | return ptr[reg_ptr_dst_scales_ + reg_off_scale_ + s_off * stype_sz_]; |
205 | } |
206 | |
207 | Address c_addr(int c_off) { |
208 | return ptr[reg_ptr_comp_ + reg_off_comp_ + c_off * sizeof(int32_t)]; |
209 | } |
210 | |
211 | Address data_chunk_addr(int node_id) { |
212 | return ptr[abi_param1 + offsetof(tail_call_param_t, curr_data_chunks) |
213 | + sizeof(int64_t) * (node_id)]; |
214 | } |
215 | |
216 | void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, |
217 | int prev_c_off, int &i_off, int &o_off, int &s_off, int &c_off, |
218 | int step_size = 1) { |
219 | i_off = prev_i_off; |
220 | o_off = prev_o_off; |
221 | s_off = prev_s_off; |
222 | c_off = prev_c_off; |
223 | |
224 | if (off == 0) return; |
225 | |
226 | int start_dim = 0, dims_prod = 1; |
227 | for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) |
228 | dims_prod *= prb_.n(start_dim); |
229 | assert(start_dim < prb_.ndims); |
230 | off /= step_size; |
231 | |
232 | for (int dim_id = start_dim; dim_id < prb_.ndims; ++dim_id) { |
233 | i_off += prb_.is(dim_id); |
234 | o_off += prb_.os(dim_id); |
235 | s_off += prb_.ss(dim_id); |
236 | c_off += prb_.cs(dim_id); |
237 | |
238 | if (off % prb_.n(dim_id)) break; |
239 | |
240 | i_off += -prb_.n(dim_id) * prb_.is(dim_id); |
241 | o_off += -prb_.n(dim_id) * prb_.os(dim_id); |
242 | s_off += -prb_.n(dim_id) * prb_.ss(dim_id); |
243 | c_off += -prb_.n(dim_id) * prb_.cs(dim_id); |
244 | |
245 | off /= prb_.n(dim_id); |
246 | |
247 | if (off == 0) break; /* FIXME: is it really required? */ |
248 | } |
249 | } |
250 | |
251 | void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, |
252 | int step_size = 1) { |
253 | int dummy = 0; |
254 | step(off, prev_i_off, prev_o_off, dummy, dummy, i_off, o_off, dummy, |
255 | dummy, step_size); |
256 | } |
257 | |
258 | void tr8x8_avx2(int i_off, int o_off) { |
259 | using namespace data_type; |
260 | |
261 | const auto cvt2ps |
262 | = [=](const Ymm dst, const Operand &src, data_type_t idt) { |
263 | switch (idt) { |
264 | case f32: |
265 | if (src.isMEM() || src.getIdx() != dst.getIdx()) |
266 | vmovups(dst, src); |
267 | break; |
268 | case bf16: |
269 | vpmovzxwd(dst, src); |
270 | vpslld(dst, dst, 0x10); |
271 | break; |
272 | case f16: |
273 | if (is_superset(isa_, avx512_core_fp16)) { |
274 | if (src.isMEM()) |
275 | vcvtph2psx(dst, src); |
276 | else |
277 | vcvtph2psx(dst, Xmm(src.getIdx())); |
278 | } else if (is_superset(isa_, avx2_vnni_2)) { |
279 | if (src.isMEM()) |
280 | vcvtph2ps(dst, src); |
281 | else |
282 | vcvtph2ps(dst, Xmm(src.getIdx())); |
283 | } else |
284 | assert(!"invalid isa" ); |
285 | break; |
286 | case s32: vcvtdq2ps(dst, src); break; |
287 | case s8: |
288 | vpmovsxbd(dst, src); |
289 | vcvtdq2ps(dst, dst); |
290 | break; |
291 | case u8: |
292 | vpmovzxbd(dst, src); |
293 | vcvtdq2ps(dst, dst); |
294 | break; |
295 | default: assert(!"unreachable" ); |
296 | } |
297 | }; |
298 | |
299 | const auto cvt2odt = [=](const Ymm ymm, data_type_t odt, |
300 | data_type_t idt) { |
301 | const Xmm xmm = Xmm(ymm.getIdx()); |
302 | switch (odt) { |
303 | case bf16: |
304 | if (utils::one_of(idt, f32, f16, s8, u8)) { |
305 | if (idt != f32) cvt2ps(ymm, ymm, idt); |
306 | if (is_superset(isa_, avx2_vnni_2)) { |
307 | vcvtneps2bf16( |
308 | Xmm(ymm.getIdx()), ymm, Xbyak::VexEncoding); |
309 | } else if (mayiuse(avx512_core_bf16)) { |
310 | vcvtneps2bf16(Xmm(ymm.getIdx()), ymm); |
311 | } else { |
312 | bf16_emu_->vcvtneps2bf16( |
313 | Ymm(ymm.getIdx()), Zmm(ymm.getIdx())); |
314 | } |
315 | } |
316 | break; |
317 | case f16: |
318 | if (utils::one_of(idt, f32, bf16, s8, u8)) { |
319 | if (idt != f32) cvt2ps(ymm, ymm, idt); |
320 | vcvtps2ph(Xmm(ymm.getIdx()), ymm, _op_mxcsr); |
321 | } |
322 | break; |
323 | case s32: |
324 | if (idt == f32) |
325 | vcvtps2dq(ymm, ymm); |
326 | else if (idt == s8) |
327 | vpmovsxbd(ymm, ymm); |
328 | else if (idt == u8) |
329 | vpmovzxbd(ymm, ymm); |
330 | break; |
331 | case s8: |
332 | if (utils::one_of(idt, bf16, f16)) cvt2ps(ymm, ymm, idt); |
333 | if (utils::one_of(idt, f32, bf16, f16)) vcvtps2dq(ymm, ymm); |
334 | if (utils::one_of(idt, bf16, f16, f32, s32)) { |
335 | if (mayiuse(avx512_core)) { |
336 | vpmovsdb(xmm, ymm); |
337 | } else { |
338 | vpackssdw(ymm, ymm, ymm_zero_); |
339 | vpermq(ymm, ymm, 0x58); |
340 | vpacksswb(ymm, ymm, ymm_zero_); |
341 | } |
342 | } |
343 | if (idt == u8) vpminub(ymm, ymm, ymm_8x127b_); |
344 | break; |
345 | case u8: |
346 | if (utils::one_of(idt, bf16, f16)) cvt2ps(ymm, ymm, idt); |
347 | if (utils::one_of(idt, f32, bf16, f16)) vcvtps2dq(ymm, ymm); |
348 | if (utils::one_of(idt, bf16, f16, f32, s32)) { |
349 | if (mayiuse(avx512_core)) { |
350 | vpmaxsd(ymm, ymm, ymm_zero_); |
351 | vpmovusdb(xmm, ymm); |
352 | } else { |
353 | vpackssdw(ymm, ymm, ymm_zero_); |
354 | vpermq(ymm, ymm, 0x58); |
355 | vpackuswb(ymm, ymm, ymm_zero_); |
356 | } |
357 | } |
358 | if (idt == s8) vpmaxsb(ymm, ymm, ymm_zero_); |
359 | break; |
360 | default: assert(!"unreachable" ); |
361 | } |
362 | }; |
363 | |
364 | auto load = [=](const Ymm ymm, const Address &addr, int size) { |
365 | const Xmm xmm = Xmm(ymm.getIdx()); |
366 | switch (size) { |
367 | case 32: vmovups(ymm, addr); break; |
368 | case 16: vmovups(xmm, addr); break; |
369 | case 8: vmovsd(xmm, addr); break; |
370 | default: assert(!"unreachable" ); |
371 | } |
372 | }; |
373 | |
374 | auto store = [=](const Address &addr, const Ymm ymm, int size) { |
375 | const Xmm xmm = Xmm(ymm.getIdx()); |
376 | switch (size) { |
377 | case 32: vmovups(addr, ymm); break; |
378 | case 16: vmovups(addr, xmm); break; |
379 | case 8: vmovsd(addr, xmm); break; |
380 | default: assert(!"unreachable" ); |
381 | } |
382 | }; |
383 | |
384 | const int unroll = 8; |
385 | |
386 | const bool interim_f32 = (prb_.itype != f32) |
387 | || utils::one_of(f32, prb_.itype, prb_.otype); |
388 | |
389 | const bool need_saturation |
390 | = (utils::one_of(prb_.otype, u8, s8, s32) && interim_f32); |
391 | |
392 | for (int i = 0; i < unroll; i++) { |
393 | const int node_0_input_stride = prb_.is(0); |
394 | load(Ymm(i), i_addr(i_off + i * node_0_input_stride), |
395 | unroll * itype_sz_); |
396 | |
397 | if (interim_f32) cvt2ps(Ymm(i), Ymm(i), prb_.itype); |
398 | } |
399 | |
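        // Classic three-stage in-register 8x8 transpose:
        // 1) vunpck{l,h}ps interleaves adjacent rows at dword granularity,
        // 2) vshufps (selectors 0x44/0xee) merges them into qword pairs,
        // 3) vperm2f128 (0x20/0x31) exchanges the 128-bit halves.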
400 | for (int i = 0; i < unroll / 2; i++) { |
401 | vunpcklps(Ymm(unroll + i), Ymm(2 * i), Ymm(2 * i + 1)); |
402 | vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1)); |
403 | } |
404 | |
405 | const unsigned int lfloat = 0x44; |
406 | const unsigned int ufloat = 0xee; |
407 | for (int i = 0; i < unroll / 2; i++) { |
408 | const int j = i % 2 == 0 ? unroll + i : i - 1; |
409 | vshufps(Ymm(unroll / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat); |
410 | vshufps(Ymm(unroll / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat); |
411 | } |
412 | |
413 | const unsigned int lquad = 0x20; |
414 | for (int i = 0; i < unroll / 2; i++) |
415 | vperm2f128(Ymm(i), Ymm(unroll / 2 + i), Ymm(unroll + i), lquad); |
416 | |
417 | const unsigned int uquad = 0x31; |
418 | for (int i = unroll / 2; i < unroll; i++) |
419 | vperm2f128(Ymm(i), Ymm(i), Ymm(unroll / 2 + i), uquad); |
420 | |
421 | if (need_saturation) { |
422 | init_saturate_f32(ymm_zero_, ymm_saturation_ubound_, reg_tmp_, |
423 | interim_f32 ? f32 : prb_.itype, prb_.otype); |
424 | for (int i = 0; i < unroll; i++) |
425 | saturate_f32( |
426 | Ymm(i), ymm_zero_, ymm_saturation_ubound_, prb_.otype); |
427 | } |
428 | |
429 | for (int i = 0; i < unroll; i++) { |
430 | const int node_1_output_stride = prb_.os(1); |
431 | if (prb_.otype != f32) |
432 | cvt2odt(Ymm(i), prb_.otype, interim_f32 ? f32 : prb_.itype); |
433 | store(o_addr(o_off + i * node_1_output_stride), Ymm(i), |
434 | unroll * otype_sz_); |
435 | } |
436 | } |
437 | |
438 | bool can_do_tr8x8() { |
439 | using namespace data_type; |
440 | |
441 | static constexpr int desirable_node_size = 8; |
442 | static constexpr int desirable_stride = 1; |
443 | |
        // This processing relies on swapping the two innermost dimensions.
        // Therefore, the input stride of the second node and the output
        // stride of the first node have to be equal to 1.
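        // E.g. an 8x8 transpose of the two innermost dimensions:
        //   node 0: n = 8, is = 8, os = 1
        //   node 1: n = 8, is = 1, os = 8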
447 | |
448 | return mayiuse(avx2) && prb_.ndims >= 2 |
449 | && ((utils::one_of(prb_.itype, u8, s8, s32, f32, bf16, f16) |
450 | && utils::one_of( |
451 | prb_.otype, u8, s8, s32, f32, bf16, f16))) |
452 | && utils::everyone_is(desirable_node_size, prb_.n(0), prb_.n(1)) |
453 | && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(1)) |
454 | && !prb_.is_tail_present |
455 | && prb_.src_scale_type == scale_type_t::NONE |
456 | && prb_.dst_scale_type == scale_type_t::NONE |
457 | && prb_.beta == 0.f; |
458 | } |
459 | |
460 | bool process_unroll_tr8x8(const int ndims, const int len) { |
461 | if (!can_do_tr8x8()) return false; |
462 | |
463 | const int step_size = prb_.n(0) * prb_.n(1); |
464 | int i_off = 0, o_off = 0; |
465 | for (int off = 0; off < len; off += step_size) { |
466 | step(off, i_off, o_off, i_off, o_off, step_size); |
467 | tr8x8_avx2(i_off, o_off); |
468 | } |
469 | |
470 | return true; |
471 | } |
472 | |
473 | template <cpu_isa_t isa> |
474 | bool process_direct_copy(const int ndims, const int len) { |
475 | using namespace data_type; |
476 | |
477 | static constexpr int desirable_stride = 1; |
478 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
479 | const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz_; |
480 | |
481 | // TODO: support tail_processing for direct copy |
482 | |
483 | const bool do_src_zp = prb_.req_src_zp; |
484 | const bool do_dst_zp = prb_.req_dst_zp; |
485 | const bool zp_applicable = IMPLICATION( |
486 | (do_src_zp || do_dst_zp), utils::one_of(prb_.itype, s32, f32)); |
487 | const bool can_do = true && mayiuse(isa) |
488 | && compensation_needed_ == false |
489 | && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(0)) |
490 | && (false || (prb_.itype == prb_.otype ? zp_applicable : false) |
491 | || (prb_.itype == s32 && prb_.otype == f32) |
492 | || (prb_.itype == f32 && prb_.otype == s32)) |
493 | && len % simd_w == 0 && prb_.n(0) % len == 0 |
494 | && !prb_.is_tail_present |
495 | && prb_.src_scale_type == scale_type_t::NONE |
496 | && prb_.dst_scale_type == scale_type_t::NONE |
497 | && prb_.beta == 0.f; |
498 | if (!can_do) return false; |
499 | |
500 | static constexpr int vmm_zp_last_idx = 15; |
501 | const auto vmm_src_zp |
502 | = Vmm(do_dst_zp ? vmm_zp_last_idx - 1 : vmm_zp_last_idx); |
503 | if (do_src_zp) { |
504 | uni_vpbroadcastd(vmm_src_zp, PARAM(src_zp)); |
505 | uni_vcvtdq2ps(vmm_src_zp, vmm_src_zp); |
506 | } |
507 | const auto vmm_dst_zp = Vmm(vmm_zp_last_idx); |
508 | if (do_dst_zp) { |
509 | uni_vpbroadcastd(vmm_dst_zp, PARAM(dst_zp)); |
510 | uni_vcvtdq2ps(vmm_dst_zp, vmm_dst_zp); |
511 | } |
512 | |
513 | const auto apply_zp_ps = [&](const Vmm vmm) { |
514 | if (do_src_zp) uni_vsubps(vmm, vmm, vmm_src_zp); |
515 | if (do_dst_zp) uni_vaddps(vmm, vmm, vmm_dst_zp); |
516 | }; |
517 | |
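        // With no scales allowed on this path, the zero-point semantics
        // reduce to dst = (src - src_zp) + dst_zp, applied in f32.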
518 | for (int off = 0; off < len;) { |
519 | // TODO: we need extra reg for proper saturation if otype == s32 |
520 | int unroll |
521 | = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w); |
522 | unroll = (do_src_zp || do_dst_zp) |
523 | ? nstl::min(unroll, 16 - do_src_zp - do_dst_zp) |
524 | : unroll; |
525 | |
526 | for (int ur = 0; ur < unroll; ++ur) { |
527 | const auto vmm = Vmm(ur); |
528 | uni_vmovups(vmm, i_addr(off + ur * simd_w)); |
529 | } |
530 | |
531 | if (prb_.itype != prb_.otype) { |
532 | for (int ur = 0; ur < unroll; ++ur) { |
533 | const auto vmm = Vmm(ur); |
534 | if (prb_.itype == s32 && prb_.otype == f32) { |
535 | uni_vcvtdq2ps(vmm, vmm); |
536 | apply_zp_ps(vmm); |
537 | } else if (prb_.itype == f32 && prb_.otype == s32) { |
538 | apply_zp_ps(vmm); |
539 | uni_vcvtps2dq(vmm, vmm); |
540 | } else |
541 | assert(!"unreachable" ); |
542 | } |
543 | } else if (do_src_zp || do_dst_zp) { |
544 | for (int ur = 0; ur < unroll; ++ur) { |
545 | const auto vmm = Vmm(ur); |
546 | if (prb_.otype == f32) { |
547 | apply_zp_ps(vmm); |
548 | } else if (prb_.otype == s32) { |
549 | uni_vcvtdq2ps(vmm, vmm); |
550 | apply_zp_ps(vmm); |
551 | uni_vcvtps2dq(vmm, vmm); |
552 | } |
553 | } |
554 | } |
555 | |
556 | for (int ur = 0; ur < unroll; ++ur) { |
557 | const auto vmm = Vmm(ur); |
558 | uni_vmovups(o_addr(off + ur * simd_w), vmm); |
559 | } |
560 | |
561 | off += unroll * simd_w; |
562 | } |
563 | |
564 | return true; |
565 | } |
566 | |
567 | void process_unroll_generic_step(int reg_unroll, const int *i_off, |
568 | const int *o_off, const int *s_off, const int *c_off, |
569 | const int *zero_padding, const bool tail_processing) { |
570 | using namespace data_type; |
571 | |
572 | const auto cvt2ps |
573 | = [=](const Xmm dst, const Operand &src, data_type_t idt) { |
574 | Xmm dst_pure = Xmm(dst.getIdx()); |
575 | switch (idt) { |
576 | case f32: |
577 | if (src.isMEM() || src.getIdx() != dst.getIdx()) |
578 | uni_vmovups(dst, src); |
579 | break; |
580 | case bf16: |
581 | if (mayiuse(avx)) { |
582 | vpmovzxwd(dst, src); |
583 | vpslld(dst, dst, 0x10); |
584 | break; |
585 | } else |
586 | assert("unreachable!" ); |
587 | case f16: vcvtph2ps(dst, src); break; |
588 | case s32: uni_vcvtdq2ps(dst, src); break; |
589 | case s8: |
590 | uni_vpmovsxbd(dst, src); |
591 | uni_vcvtdq2ps(dst_pure, dst); |
592 | break; |
593 | case u8: |
594 | uni_vpmovzxbd(dst, src); |
595 | uni_vcvtdq2ps(dst_pure, dst); |
596 | break; |
597 | default: assert(!"unreachable" ); |
598 | } |
599 | }; |
600 | |
601 | const auto cvt2odt = [=](const Xmm xmm, data_type_t odt, |
602 | data_type_t idt) { |
603 | switch (odt) { |
604 | case bf16: |
                    if (!mayiuse(avx)) assert(!"unreachable");
606 | if (utils::one_of(idt, f32, f16, s8, u8)) { |
607 | if (idt != f32) cvt2ps(xmm, xmm, idt); |
608 | if (is_superset(isa_, avx2_vnni_2)) { |
609 | vcvtneps2bf16(xmm, xmm, Xbyak::VexEncoding); |
610 | } else if (mayiuse(avx512_core_bf16)) { |
611 | vcvtneps2bf16(xmm, xmm); |
612 | } else { |
613 | bf16_emu_->vcvtneps2bf16( |
614 | Ymm(xmm.getIdx()), Zmm(xmm.getIdx())); |
615 | } |
616 | } |
617 | break; |
618 | case f16: |
                    if (!mayiuse(avx)) assert(!"unreachable");
620 | if (utils::one_of(idt, f32, bf16, s8, u8)) { |
621 | if (idt != f32) cvt2ps(xmm, xmm, idt); |
622 | vcvtps2ph(xmm, xmm, _op_mxcsr); |
623 | } |
624 | break; |
625 | case s32: |
626 | if (idt == f32) |
627 | uni_vcvtps2dq(xmm, xmm); |
628 | else if (idt == s8) |
629 | uni_vpmovsxbd(xmm, xmm); |
630 | else if (idt == u8) |
631 | uni_vpmovzxbd(xmm, xmm); |
632 | break; |
633 | case s8: |
634 | if (utils::one_of(idt, bf16, f16)) cvt2ps(xmm, xmm, idt); |
635 | if (utils::one_of(idt, f32, bf16, f16)) |
636 | uni_vcvtps2dq(xmm, xmm); |
637 | if (utils::one_of(idt, bf16, f16, f32, s32)) { |
638 | if (mayiuse(avx512_core)) { |
639 | vpmovsdb(xmm, xmm); |
640 | } else { |
641 | uni_vpackssdw(xmm, xmm, xmm_zero_); |
642 | uni_vpacksswb(xmm, xmm, xmm_zero_); |
643 | } |
644 | } |
645 | if (idt == u8) uni_vpminub(xmm, xmm, xmm_4x127b_); |
646 | break; |
647 | case u8: |
648 | if (utils::one_of(idt, bf16, f16)) cvt2ps(xmm, xmm, idt); |
649 | if (utils::one_of(idt, f32, bf16, f16)) |
650 | uni_vcvtps2dq(xmm, xmm); |
651 | if (utils::one_of(idt, bf16, f16, f32, s32)) { |
652 | if (mayiuse(avx512_core)) { |
653 | vpmaxsd(xmm, xmm, xmm_zero_); |
654 | vpmovusdb(xmm, xmm); |
655 | } else { |
656 | uni_vpackssdw(xmm, xmm, xmm_zero_); |
657 | uni_vpackuswb(xmm, xmm, xmm_zero_); |
658 | } |
659 | } |
660 | if (idt == s8) uni_vpmaxsb(xmm, xmm, xmm_zero_); |
661 | break; |
662 | default: assert(!"unreachable" ); |
663 | } |
664 | }; |
665 | |
666 | auto load = [=](const Xmm xmm, const Address &addr, int size) { |
667 | switch (size) { |
668 | case 16: uni_vmovups(xmm, addr); break; |
669 | case 8: uni_vmovsd(xmm, addr); break; |
670 | case 4: uni_vmovss(xmm, addr); break; |
671 | case 2: uni_vpinsrw(xmm, xmm, addr, 0x0); break; |
672 | case 1: uni_vpinsrb(xmm, xmm, addr, 0x0); break; |
673 | default: assert(!"unreachable" ); |
674 | } |
675 | }; |
676 | |
677 | auto load_bytes |
678 | = [=](const Xmm xmm, const Address &addr, int size, int imm) { |
679 | switch (size) { |
680 | case 4: uni_vpinsrd(xmm, xmm, addr, imm); break; |
681 | case 2: uni_vpinsrw(xmm, xmm, addr, imm); break; |
682 | case 1: uni_vpinsrb(xmm, xmm, addr, imm); break; |
683 | default: assert(!"unreachable" ); |
684 | } |
685 | }; |
686 | |
687 | auto store = [=](const Address &addr, const Xmm xmm, int size) { |
688 | switch (size) { |
689 | case 16: uni_vmovups(addr, xmm); break; |
690 | case 8: uni_vmovsd(addr, xmm); break; |
691 | case 4: uni_vmovss(addr, xmm); break; |
692 | case 2: uni_vpextrw(addr, xmm, 0x0); break; |
693 | case 1: uni_vpextrb(addr, xmm, 0x0); break; |
694 | default: assert(!"unreachable" ); |
695 | } |
696 | }; |
697 | |
698 | /* check whether loading 4 values at once is possible */ |
699 | static constexpr int xmm_vlen = 4; |
700 | bool can_load_xmm = reg_unroll % xmm_vlen == 0; |
701 | for (int ur = 1; ur < reg_unroll; ++ur) |
702 | if (i_off[ur] != i_off[ur - 1] + 1) { |
703 | can_load_xmm = false; |
704 | break; |
705 | } |
706 | const int load_step = can_load_xmm ? xmm_vlen : 1; |
707 | |
708 | /* check whether storing 4 values at once is possible */ |
709 | bool can_store_xmm = reg_unroll % xmm_vlen == 0; |
710 | for (int ur = 1; ur < reg_unroll; ++ur) |
711 | if (o_off[ur] != o_off[ur - 1] + 1) { |
712 | can_store_xmm = false; |
713 | break; |
714 | } |
715 | const int ur_step = can_store_xmm ? 4 : 1; |
716 | const int load_tail_step |
717 | = !can_load_xmm && can_store_xmm ? ur_step : load_step; |
718 | |
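        // Four cases follow: both sides contiguous (vector loads and
        // stores), strided loads with contiguous stores (elements gathered
        // into an xmm via load_bytes), contiguous loads with strided
        // stores (the "transposition on the fly" path), and fully scalar
        // otherwise.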
719 | const bool interim_f32 = interim_f32_needed(); |
720 | |
721 | const bool need_saturation |
722 | = (utils::one_of(prb_.otype, u8, s8, s32) && interim_f32); |
723 | |
724 | std::vector<int> store_masks; |
725 | if (tail_processing) { |
726 | for (int ur = 0; ur < reg_unroll; ur += load_tail_step) { |
727 | uni_vpxor(Xmm(ur), Xmm(ur), Xmm(ur)); |
728 | store_masks.push_back(0); |
729 | for (int r = 0; r < load_tail_step; ++r) { |
730 | if (zero_padding[ur + r] == 0) { |
731 | store_masks.back() += 1 << r; |
732 | load_bytes( |
733 | Xmm(ur), i_addr(i_off[ur + r]), itype_sz_, r); |
734 | } |
735 | } |
736 | } |
737 | } else { |
738 | if (!can_load_xmm && can_store_xmm) { |
739 | assert(ur_step == xmm_vlen); |
740 | /* load with stride */ |
741 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
742 | for (int r = 0; r < ur_step; ++r) { |
743 | load_bytes( |
744 | Xmm(ur), i_addr(i_off[ur + r]), itype_sz_, r); |
745 | } |
746 | } |
747 | } else { |
748 | for (int ur = 0; ur < reg_unroll; ur += load_step) { |
749 | load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz_); |
750 | } |
751 | } |
752 | } |
753 | |
754 | /* xmm[:] <-- (f32)xmm[:] */ |
755 | if (interim_f32) { |
756 | const int cvt_step = nstl::max(load_step, ur_step); |
757 | for (int ur = 0; ur < reg_unroll; ur += cvt_step) |
758 | cvt2ps(Xmm(ur), Xmm(ur), prb_.itype); |
759 | } |
760 | |
761 | if (can_load_xmm && !can_store_xmm) { |
762 | // transposition on the fly |
763 | const bool fast_return = prb_.src_scale_type != scale_type_t::MANY |
764 | && prb_.dst_scale_type != scale_type_t::MANY |
765 | && prb_.beta == 0.f; |
766 | if (fast_return) { |
767 | if (prb_.src_scale_type == scale_type_t::COMMON) |
768 | for (int ur = 0; ur < reg_unroll; ur += load_step) |
769 | uni_vmulps(Xmm(ur), Xmm(ur), xmm_src_scales_); |
770 | if (prb_.dst_scale_type == scale_type_t::COMMON) |
771 | for (int ur = 0; ur < reg_unroll; ur += load_step) |
772 | uni_vmulps(Xmm(ur), Xmm(ur), xmm_dst_scales_); |
773 | if (prb_.otype != f32) { |
774 | init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, |
775 | reg_tmp_, interim_f32 ? f32 : prb_.itype, |
776 | prb_.otype); |
777 | for (int ur = 0; ur < reg_unroll; ur += load_step) { |
778 | if (need_saturation) |
779 | saturate_f32(Xmm(ur), xmm_zero_, |
780 | xmm_saturation_ubound_, prb_.otype); |
781 | cvt2odt(Xmm(ur), prb_.otype, |
782 | interim_f32 ? f32 : prb_.itype); |
783 | } |
784 | } |
785 | for (int ur = 0; ur < reg_unroll; ur += load_step) { |
786 | for (int r = 0; r < load_step; ++r) { |
787 | if (otype_sz_ == 4) |
788 | uni_vpextrd(o_addr(o_off[ur + r]), Xmm(ur), r); |
789 | else if (otype_sz_ == 2) |
790 | uni_vpextrw(o_addr(o_off[ur + r]), Xmm(ur), r); |
791 | else |
792 | uni_vpextrb(o_addr(o_off[ur + r]), Xmm(ur), r); |
793 | } |
794 | } |
795 | return; |
796 | } |
797 | |
798 | /* scatter elements of xmm into 4 xmms */ |
799 | if (itype_sz_ == 4 || interim_f32) { |
800 | for (int ur = 0; ur < reg_unroll; ur += load_step) |
801 | for (int r = 1; r < load_step; ++r) { |
802 | uni_vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r); |
803 | } |
804 | } else { |
805 | for (int ur = 0; ur < reg_unroll; ur += load_step) |
806 | for (int r = 1; r < load_step; ++r) { |
807 | if (mayiuse(avx)) |
808 | vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), |
809 | itype_sz_ * r); |
810 | else { |
811 | movups(Xmm(ur + r), Xmm(ur)); |
812 | palignr(Xmm(ur + r), Xmm(ur), itype_sz_ * r); |
813 | } |
814 | } |
815 | } |
816 | } |
817 | |
818 | /* src zero point application */ |
819 | if (prb_.req_src_zp) { |
820 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
821 | const auto xmm = Xmm(ur); |
822 | if (interim_f32) |
823 | uni_vsubps(xmm, xmm, xmm_src_zp_); |
824 | else |
825 | uni_vpsubd(xmm, xmm, xmm_src_zp_); |
826 | } |
827 | } |
828 | |
829 | /* scale and beta processing */ |
830 | if (can_store_xmm) { |
831 | const auto apply_scales = [&](const Xmm &vreg_scales, |
832 | scale_arg_t scale_arg, |
833 | scale_type_t scale_type) { |
834 | if (scale_type == scale_type_t::COMMON) { |
835 | for (int ur = 0; ur < reg_unroll; ur += ur_step) |
836 | uni_vmulps(Xmm(ur), Xmm(ur), vreg_scales); |
837 | } else if (scale_type == scale_type_t::MANY) { |
838 | enum class scale_load_type_t { bcast, load, gather }; |
839 | |
840 | uni_vpxor(vreg_scales, vreg_scales, vreg_scales); |
841 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
842 | scale_load_type_t scale_load_type |
843 | = scale_load_type_t::bcast; // the best case |
844 | |
845 | for (int r = ur + 1; r < ur + ur_step; ++r) |
846 | if (s_off[r] != s_off[r - 1] + 0) |
847 | scale_load_type = scale_load_type_t::load; |
848 | |
849 | if (scale_load_type == scale_load_type_t::bcast |
850 | && !tail_processing) { |
851 | uni_vbroadcastss(vreg_scales, |
852 | scale_arg == scale_arg_t::SRC |
853 | ? src_s_addr(s_off[ur]) |
854 | : dst_s_addr(s_off[ur])); |
855 | uni_vmulps(Xmm(ur), Xmm(ur), vreg_scales); |
856 | continue; |
857 | } |
858 | |
                    // bcast doesn't work; the next candidate is a vector load
860 | for (int r = ur + 1; r < ur + ur_step; ++r) |
861 | if (s_off[r] != s_off[r - 1] + 1) |
862 | scale_load_type = scale_load_type_t::gather; |
863 | |
864 | if (scale_load_type == scale_load_type_t::load |
865 | && !tail_processing) { |
866 | uni_vmovups(vreg_scales, |
867 | scale_arg == scale_arg_t::SRC |
868 | ? src_s_addr(s_off[ur]) |
869 | : dst_s_addr(s_off[ur])); |
870 | uni_vmulps(Xmm(ur), Xmm(ur), vreg_scales); |
871 | continue; |
872 | } |
873 | |
                    // load doesn't work either,
                    // so gather the scale factors one by one
876 | for (int r = ur; r < ur + ur_step; ++r) { |
877 | if (zero_padding[r] == 0 || !tail_processing) |
878 | uni_vpinsrd(vreg_scales, vreg_scales, |
879 | scale_arg == scale_arg_t::SRC |
880 | ? src_s_addr(s_off[r]) |
881 | : dst_s_addr(s_off[r]), |
882 | r - ur); |
883 | } |
884 | uni_vmulps(Xmm(ur), Xmm(ur), vreg_scales); |
885 | } |
886 | } |
887 | }; |
888 | /* xmm <-- src_scales * xmm[:] */ |
889 | apply_scales( |
890 | xmm_src_scales_, scale_arg_t::SRC, prb_.src_scale_type); |
891 | |
892 | /* xmm[:] <-- beta * dst + xmm[:] */ |
893 | assert(prb_.beta == 0.f || prb_.beta == 1.f); |
894 | if (prb_.beta == 1.f) { |
895 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
896 | if (prb_.otype == f32) { |
                        /* non-VEX instructions other than movups do not
                         * support unaligned memory operands. */
899 | if (mayiuse(avx)) { |
900 | vaddps(Xmm(ur), o_addr(o_off[ur])); |
901 | } else { |
902 | /* register xmm(1) is unused */ |
903 | movups(Xmm(1), o_addr(o_off[ur])); |
904 | addps(Xmm(ur), Xmm(1)); |
905 | } |
906 | } else { |
907 | cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype); |
908 | uni_vaddps(Xmm(ur), Xmm(ur), Xmm(1)); |
909 | } |
910 | } |
911 | } |
912 | |
913 | /* dst <-- dst_scales * xmm[:] */ |
914 | apply_scales( |
915 | xmm_dst_scales_, scale_arg_t::DST, prb_.dst_scale_type); |
916 | } else { |
917 | const auto apply_scales |
918 | = [&](const Xmm &vreg_scales, scale_arg_t scale_arg, |
919 | scale_type_t scale_type) { |
920 | if (scale_type == scale_type_t::COMMON) { |
921 | for (int ur = 0; ur < reg_unroll; ur += ur_step) |
922 | uni_vmulss(Xmm(ur), Xmm(ur), vreg_scales); |
923 | } else if (scale_type == scale_type_t::MANY) { |
924 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
925 | if (zero_padding[ur] == 0 || !tail_processing) |
926 | uni_vmulss(Xmm(ur), Xmm(ur), |
927 | scale_arg == scale_arg_t::SRC |
928 | ? src_s_addr(s_off[ur]) |
929 | : dst_s_addr(s_off[ur])); |
930 | } |
931 | } |
932 | }; |
933 | |
934 | /* xmm[0] <-- src_scales * xmm[0] */ |
935 | apply_scales( |
936 | xmm_src_scales_, scale_arg_t::SRC, prb_.src_scale_type); |
937 | |
938 | /* xmm[0] <-- beta * dst + xmm[0] */ |
939 | assert(prb_.beta == 0.f || prb_.beta == 1.f); |
940 | if (prb_.beta == 1.f) { |
941 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
942 | if (prb_.otype == f32) { |
943 | uni_vaddss(Xmm(ur), Xmm(ur), o_addr(o_off[ur])); |
944 | } else { |
945 | if (prb_.otype == s32) { |
946 | uni_vmovss(xmm_tmp_, o_addr(o_off[ur])); |
947 | } else if (utils::one_of(prb_.otype, s8, u8)) { |
948 | uni_vpinsrb( |
949 | xmm_tmp_, xmm_tmp_, o_addr(o_off[ur]), 0x0); |
950 | } else if (utils::one_of(prb_.otype, bf16, f16)) { |
951 | uni_vpinsrw( |
952 | xmm_tmp_, xmm_tmp_, o_addr(o_off[ur]), 0x0); |
953 | } else { |
954 | assert(!"unsupported o_type" ); |
955 | } |
956 | cvt2ps(xmm_tmp_, xmm_tmp_, prb_.otype); |
957 | uni_vaddps(Xmm(ur), Xmm(ur), xmm_tmp_); |
958 | } |
959 | } |
960 | } |
961 | |
962 | /* dst <-- dst_scales * xmm[0] */ |
963 | apply_scales( |
964 | xmm_dst_scales_, scale_arg_t::DST, prb_.dst_scale_type); |
965 | } |
966 | |
967 | /* dst zero point application */ |
968 | if (prb_.req_dst_zp) { |
969 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
970 | const auto xmm = Xmm(ur); |
971 | if (interim_f32) |
972 | uni_vaddps(xmm, xmm, xmm_dst_zp_); |
973 | else |
974 | uni_vpaddd(xmm, xmm, xmm_dst_zp_); |
975 | } |
976 | } |
977 | |
978 | /* adjust scale application */ |
979 | if (prb_.scale_adjust != 1.f) { |
980 | uni_vmovd(xmm_tmp_, reg_scale_adjust_); |
981 | uni_vpshufd(xmm_tmp_, xmm_tmp_, 0x0); |
982 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
983 | uni_vmulps(Xmm(ur), Xmm(ur), xmm_tmp_); |
984 | } |
985 | } |
986 | |
987 | if (need_saturation) { |
988 | init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, reg_tmp_, f32, |
989 | prb_.otype, compensation_needed_); |
990 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
991 | saturate_f32(Xmm(ur), xmm_zero_, xmm_saturation_ubound_, |
992 | prb_.otype, compensation_needed_); |
993 | } |
994 | |
995 | // reset back xmm_zero_ if needed. |
996 | if (compensation_needed_ && (prb_.req_src_zp || prb_.req_dst_zp)) |
997 | uni_vxorps(xmm_zero_, xmm_zero_, xmm_zero_); |
998 | } |
999 | |
1000 | if (compensation_needed_) { |
1001 | const bool mayiuse_avx2 = mayiuse(avx2); |
1002 | const auto uni_vpaddd_wrapper |
1003 | = [&](const Xmm xmm, const Address &addr) { |
1004 | if (mayiuse_avx2) |
1005 | vpaddd(xmm, xmm, addr); |
1006 | else { |
1007 | //isas < avx2 demand paddd instruction addr to be aligned |
1008 | assert(xmm.getIdx() != xmm_tmp_.getIdx()); |
1009 | uni_vmovups(xmm_tmp_, addr); |
1010 | paddd(xmm, xmm_tmp_); |
1011 | } |
1012 | }; |
1013 | if (can_store_xmm) { |
1014 | enum class comp_load_type_t { bcast, load, gather }; |
1015 | |
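                // The compensation scratchpad accumulates integer sums of
                // the output values: a horizontal reduction when all four
                // lanes target one slot (bcast), a vector add when slots
                // are contiguous (load), and lane-by-lane adds otherwise
                // (gather).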
1016 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
1017 | |
1018 | bool all_ip_padding_one = true; |
1019 | bool all_ip_padding_zero = true; |
1020 | for (int r = ur; r < ur + ur_step; r++) { |
1021 | if (zero_padding[r] != 1) |
1022 | all_ip_padding_one = false; |
1023 | else |
1024 | all_ip_padding_zero = false; |
1025 | } |
1026 | if (all_ip_padding_one) continue; |
1027 | |
1028 | comp_load_type_t comp_load_type = comp_load_type_t::bcast; |
1029 | |
1030 | for (int r = ur + 1; r < ur + ur_step; ++r) |
1031 | if (c_off[r] != c_off[r - 1] + 0) { |
1032 | comp_load_type = comp_load_type_t::load; |
1033 | break; |
1034 | } |
1035 | |
1036 | if (comp_load_type == comp_load_type_t::bcast |
1037 | && all_ip_padding_zero) { |
1038 | // xmm_compensation is used for reduction. |
1039 | uni_vcvtps2dq(xmm_compensation, Xmm(ur)); |
1040 | uni_vphaddd(xmm_compensation, xmm_compensation, |
1041 | xmm_compensation); |
1042 | uni_vphaddd(xmm_compensation, xmm_compensation, |
1043 | xmm_compensation); |
1044 | const auto comp_addr = c_addr(c_off[ur]); |
1045 | uni_vmovss(xmm_tmp_, comp_addr); |
1046 | uni_vpaddd(xmm_tmp_, xmm_tmp_, xmm_compensation); |
1047 | uni_vmovss(comp_addr, xmm_tmp_); |
1048 | continue; |
1049 | } |
1050 | |
1051 | if (comp_load_type == comp_load_type_t::load) |
1052 | for (int r = ur + 1; r < ur + ur_step; ++r) |
1053 | if (c_off[r] != c_off[r - 1] + 1) { |
1054 | comp_load_type = comp_load_type_t::gather; |
1055 | break; |
1056 | } |
1057 | |
1058 | if (comp_load_type == comp_load_type_t::load |
1059 | && all_ip_padding_zero) { |
1060 | const auto comp_addr = c_addr(c_off[ur]); |
1061 | uni_vcvtps2dq(xmm_compensation, Xmm(ur)); |
1062 | uni_vpaddd_wrapper(xmm_compensation, comp_addr); |
1063 | uni_vmovups(comp_addr, xmm_compensation); |
1064 | continue; |
1065 | } |
1066 | |
1067 | uni_vcvtps2dq(xmm_compensation, Xmm(ur)); |
1068 | for (int r = ur; r < ur + ur_step; ++r) { |
1069 | if (zero_padding[r] == 0 || !tail_processing) { |
1070 | uni_vshufps(xmm_tmp_, xmm_compensation, |
1071 | xmm_compensation, r); |
1072 | const Reg32 reg_tmp_32 = reg_tmp_.cvt32(); |
1073 | uni_vmovd(reg_tmp_32, xmm_tmp_); |
1074 | const auto comp_addr = c_addr(c_off[r]); |
1075 | add(comp_addr, reg_tmp_32); |
1076 | } |
1077 | } |
1078 | } |
1079 | } else { |
1080 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
1081 | if (zero_padding[ur] == 0 || !tail_processing) { |
1082 | const auto comp_addr = c_addr(c_off[ur]); |
1083 | uni_vcvtps2dq(xmm_compensation, Xmm(ur)); |
1084 | uni_vpaddd_wrapper(xmm_compensation, comp_addr); |
1085 | uni_vmovss(comp_addr, xmm_compensation); |
1086 | } |
1087 | } |
1088 | } |
1089 | } |
1090 | |
1091 | for (int ur = 0; ur < reg_unroll; ur += ur_step) { |
1092 | if (prb_.req_src_zp || prb_.req_dst_zp) { |
1093 | const bool use_store_masks = !store_masks.empty(); |
1094 | if (use_store_masks) { |
1095 | const auto mask = ~store_masks[ur / ur_step]; |
1096 | uni_vblendps(Xmm(ur), Xmm(ur), xmm_zero_, mask); |
1097 | } |
1098 | } |
1099 | if (prb_.otype != f32) |
1100 | cvt2odt(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype); |
1101 | |
1102 | store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz_); |
1103 | } |
1104 | } |
1105 | |
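    /* An f32 intermediate is needed whenever the conversion is not a plain
     * integer copy: f32 on either side, any scaling, beta accumulation,
     * zero points outside the s32->s32 case, compensation with a non-f32
     * input, or a scale_adjust factor. */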
1106 | bool interim_f32_needed() { |
1107 | using namespace data_type; |
1108 | |
1109 | return utils::one_of(f32, prb_.itype, prb_.otype) |
1110 | || prb_.src_scale_type != scale_type_t::NONE |
1111 | || prb_.dst_scale_type != scale_type_t::NONE || prb_.beta != 0.f |
1112 | || ((prb_.req_src_zp || prb_.req_dst_zp) |
1113 | ? !(prb_.itype == s32 && prb_.otype == s32) |
1114 | : false) |
1115 | || (prb_.itype != f32 && compensation_needed_) |
1116 | || prb_.scale_adjust != 1.f; |
1117 | } |
1118 | |
1119 | void process_unroll_generic( |
1120 | const int ndims, int len, const bool tail_processing) { |
1121 | assert(IMPLICATION(prb_.nodes[0].tail_size > 0, |
1122 | len == static_cast<int>(prb_.nodes[0].n) |
1123 | || len == static_cast<int>(prb_.nodes[0].tail_size))); |
1124 | |
1125 | const int blk = 8; |
1126 | |
1127 | int i_off[2 * blk] = {0}; |
1128 | int o_off[2 * blk] = {0}; |
1129 | int s_off[2 * blk] = {0}; |
1130 | int c_off[2 * blk] = {0}; |
1131 | |
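        // Offsets are computed incrementally from the previous element,
        // so two blocks of offsets are kept (double-buffered): ur_p below
        // may still refer to entries of the previously processed block.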
1132 | int curr = 0; // will switch between 0 and 1 |
1133 | |
1134 | const bool interim_f32 = interim_f32_needed(); |
1135 | |
1136 | if (prb_.req_src_zp) { |
1137 | uni_vbroadcastss(xmm_src_zp_, PARAM(src_zp)); |
1138 | if (interim_f32) uni_vcvtdq2ps(xmm_src_zp_, xmm_src_zp_); |
1139 | } |
1140 | if (prb_.req_dst_zp) { |
1141 | uni_vbroadcastss(xmm_dst_zp_, PARAM(dst_zp)); |
1142 | if (interim_f32) uni_vcvtdq2ps(xmm_dst_zp_, xmm_dst_zp_); |
1143 | } |
1144 | |
1145 | for (int off = 0; off < len; off += blk) { |
1146 | const int reg_unroll = nstl::min(off + blk, len) - off; |
1147 | int zero_padding[blk] = {0}; |
1148 | const auto curr_blk = curr * blk; |
1149 | |
            /* compute offsets and tail */
1151 | for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) { |
1152 | const int ur_c = curr_blk + ur; |
1153 | const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur |
1154 | const bool is_tail |
1155 | = off + ur >= static_cast<int>(prb_.nodes[0].tail_size); |
1156 | step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p], |
1157 | c_off[ur_p], i_off[ur_c], o_off[ur_c], s_off[ur_c], |
1158 | c_off[ur_c]); |
1159 | if (tail_processing && is_tail) zero_padding[ur] = 1; |
1160 | } |
1161 | |
1162 | process_unroll_generic_step(reg_unroll, i_off + curr_blk, |
1163 | o_off + curr_blk, s_off + curr_blk, c_off + curr_blk, |
1164 | zero_padding, tail_processing); |
1165 | |
1166 | curr = 1 - curr; |
1167 | } |
1168 | } |
1169 | |
1170 | void compute_ker( |
1171 | const int ndims, const int len_unroll, const bool tail_processing) { |
1172 | bool optimized = false; |
1173 | optimized = optimized || process_direct_copy<avx>(ndims, len_unroll) |
1174 | || process_direct_copy<sse41>(ndims, len_unroll) |
1175 | || process_unroll_tr8x8(ndims, len_unroll); |
1176 | if (!optimized) |
1177 | process_unroll_generic(ndims, len_unroll, tail_processing); |
1178 | } |
1179 | |
1180 | void loop_begin(Label &l, Reg64 reg_cnt, int len) { |
1181 | mov(reg_cnt, len); |
1182 | L(l); |
1183 | } |
1184 | |
1185 | void check_if_this_is_last_chunk(const Reg64 reg_curr_chunk, int node_id) { |
        // Chunks are numbered backwards, i.e.:
        // [0] -> [node_size]
        // [1] -> [node_size - 1]
        // ...
        // [node_size - 1] -> [1]

        // It is done like this because it is easier to decrement a counter
        // and check whether it equals zero than to increment it and check
        // whether it equals node_size.
1195 | static constexpr int64_t last_chunk = 1; |
1196 | cmp(reg_curr_chunk, last_chunk); |
1197 | } |
1198 | |
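    /* Zero bytes_to_zeroing bytes of dst starting at the current output
     * offset: a loop of full 16-byte xmm stores followed by per-byte
     * stores for the remainder; reg_off_out_ is restored afterwards. */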
1199 | void zero_dst_memory(const int bytes_to_zeroing) { |
1200 | static constexpr int num_of_bytes_in_xmm = 128 / 8; |
1201 | |
1202 | const int xmms_to_zeroing |
1203 | = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).quot; |
1204 | const int tail_to_zeroing |
1205 | = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).rem; |
1206 | |
1207 | uni_vpxor(xmm_tmp_, xmm_tmp_, xmm_tmp_); |
1208 | |
1209 | if (xmms_to_zeroing > 0) { |
1210 | Label loop; |
1211 | |
1212 | mov(reg_tmp_, xmms_to_zeroing); |
1213 | L(loop); |
1214 | uni_vmovups(o_addr(0), xmm_tmp_); |
1215 | add(reg_off_out_, num_of_bytes_in_xmm); |
1216 | dec(reg_tmp_); |
1217 | jnz(loop); |
1218 | } |
1219 | |
1220 | for (int i = 0; i < tail_to_zeroing; i++) |
1221 | uni_vpextrb(o_addr(i, false), xmm_tmp_, 0); |
1222 | |
1223 | // Restore dst offset to initial value |
1224 | if (xmms_to_zeroing > 0) |
1225 | sub(reg_off_out_, num_of_bytes_in_xmm * xmms_to_zeroing); |
1226 | } |
1227 | |
1228 | void finalize_tail_loop(int i_step, int o_step, int s_step, int c_step, |
1229 | const int curr_node_id) { |
1230 | static constexpr int empty_chunk_info = -1; |
1231 | |
1232 | mov(reg_tmp_, empty_chunk_info); |
1233 | mov(data_chunk_addr(curr_node_id), reg_tmp_); |
1234 | |
1235 | const int padded_area = prb_.nodes[curr_node_id].n |
1236 | - prb_.nodes[curr_node_id].tail_size; |
1237 | |
1238 | if (prb_.nodes[curr_node_id].is_zero_pad_needed) { |
1239 | int num_of_zero_padded_values = padded_area; |
1240 | for (int i = curr_node_id - 1; i >= 0; i--) { |
1241 | num_of_zero_padded_values *= prb_.nodes[i].n; |
1242 | } |
1243 | |
1244 | const int bytes_to_zeroing = num_of_zero_padded_values * otype_sz_; |
1245 | zero_dst_memory(bytes_to_zeroing); |
1246 | } |
1247 | |
        // This function is called by loop_end. At the end of loop_end
        // there is a section responsible for restoring the offset values.
        // That restore is based on the len value, which equals
        // prb.nodes[x].n. When this function runs, the offsets have been
        // shifted only prb.nodes[x].tail_size times, so they have to be
        // advanced here by the zero-padded area as well.
1256 | add(reg_off_in_, padded_area * i_step * itype_sz_); |
1257 | add(reg_off_out_, padded_area * o_step * otype_sz_); |
1258 | if (prb_.src_scale_type == scale_type_t::MANY |
1259 | || prb_.dst_scale_type == scale_type_t::MANY) |
1260 | add(reg_off_scale_, padded_area * s_step * stype_sz_); |
1261 | if (compensation_needed_) |
1262 | add(reg_off_comp_, padded_area * c_step * sizeof(int32_t)); |
1263 | } |
1264 | |
1265 | void loop_end(Label &l, const Reg64 reg_cnt, int len, int i_step, |
1266 | int o_step, int s_step, int c_step, const int curr_node_id) { |
1267 | add(reg_off_in_, i_step * itype_sz_); |
1268 | add(reg_off_out_, o_step * otype_sz_); |
1269 | if (prb_.src_scale_type == scale_type_t::MANY |
1270 | || prb_.dst_scale_type == scale_type_t::MANY) |
1271 | add(reg_off_scale_, s_step * stype_sz_); |
1272 | if (compensation_needed_) add(reg_off_comp_, c_step * sizeof(int32_t)); |
1273 | |
1274 | dec(reg_cnt); |
1275 | jnz(l); |
1276 | |
1277 | if (prb_.tail(curr_node_id) != 0) { |
1278 | Label if_end; |
1279 | |
            // The stack holds the information on whether the node
            // was processed with a tail or not.
1282 | pop(reg_tmp_); |
1283 | |
1284 | cmp(reg_tmp_, with_tail_info_); |
1285 | jne(if_end, T_NEAR); |
1286 | finalize_tail_loop(i_step, o_step, s_step, c_step, curr_node_id); |
1287 | L(if_end); |
1288 | } |
1289 | |
        // Restore offsets to their initial values, i.e. the values
        // before the loop execution.
1292 | sub(reg_off_in_, len * i_step * itype_sz_); |
1293 | sub(reg_off_out_, len * o_step * otype_sz_); |
1294 | if (prb_.src_scale_type == scale_type_t::MANY |
1295 | || prb_.dst_scale_type == scale_type_t::MANY) |
1296 | sub(reg_off_scale_, len * s_step * stype_sz_); |
1297 | if (compensation_needed_) |
1298 | sub(reg_off_comp_, len * c_step * sizeof(int32_t)); |
1299 | } |
1300 | |
1301 | void compute_blk_ker(const simple_impl_desc_t &desc) { |
1302 | static constexpr bool with_tail_processing = true; |
1303 | Label no_last_chunk, end_label; |
1304 | int omp_ndims = prb_.full_ndims - prb_.ndims; |
1305 | |
1306 | if (prb_.nodes[0].tail_size > 0) { |
1307 | if (!prb_.nodes[0].is_parent_empty()) { |
1308 | const int parent_node_id = prb_.nodes[0].parent_node_id; |
1309 | mov(reg_tmp_, data_chunk_addr(parent_node_id)); |
1310 | check_if_this_is_last_chunk(reg_tmp_, parent_node_id); |
1311 | jne(no_last_chunk, T_NEAR); |
1312 | } |
1313 | |
1314 | const int len_unroll = desc.tail_len_unroll > 0 |
1315 | ? desc.tail_len_unroll |
1316 | : desc.len_unroll; |
1317 | compute_ker(omp_ndims, len_unroll, with_tail_processing); |
1318 | jmp(end_label, T_NEAR); |
1319 | } |
1320 | |
1321 | L(no_last_chunk); |
1322 | compute_ker(omp_ndims, desc.len_unroll, !with_tail_processing); |
1323 | L(end_label); |
1324 | } |
1325 | |
1326 | void create_loops(const simple_impl_desc_t &desc, |
1327 | const std::array<const Reg64, 3> ®_cnt, int jit_loop) { |
1328 | assert(jit_loop <= ndims_jit_loop_max); |
1329 | |
1330 | if (jit_loop > 0) { |
1331 | const int nfu = desc.ndims_full_unroll; |
1332 | const int unroll_factor |
1333 | = jit_loop == 1 ? desc.len_last_dim_unroll : 1; |
1334 | const int curr_node_id = nfu + (jit_loop - 1); |
1335 | const int parent_node_id = prb_.nodes[curr_node_id].parent_node_id; |
1336 | const int tail_size = prb_.tail(curr_node_id) / unroll_factor; |
1337 | const int node_size = prb_.n(curr_node_id) / unroll_factor; |
1338 | const Reg64 reg_loop_cnt = reg_cnt[jit_loop - 1]; |
1339 | const bool curr_node_has_tail = prb_.tail(curr_node_id) != 0; |
1340 | Label loop, if_no_tail, if_end; |
1341 | |
1342 | if (curr_node_has_tail) { |
1343 | if (prb_.nodes[curr_node_id].is_parent_empty()) { |
1344 | mov(reg_loop_cnt, tail_size); |
1345 | // Put info that node is being processed with tail. |
1346 | mov(reg_tmp_, with_tail_info_); |
1347 | push(reg_tmp_); |
1348 | } else { |
1349 | mov(reg_tmp_, data_chunk_addr(parent_node_id)); |
1350 | check_if_this_is_last_chunk(reg_tmp_, parent_node_id); |
1351 | jne(if_no_tail, T_NEAR); |
1352 | mov(reg_loop_cnt, tail_size); |
1353 | // Put info that node is being processed with tail. |
1354 | mov(reg_tmp_, with_tail_info_); |
1355 | push(reg_tmp_); |
1356 | jmp(if_end, T_NEAR); |
1357 | |
1358 | L(if_no_tail); |
1359 | mov(reg_loop_cnt, node_size); |
1360 | // Put info that node is being processed without tail. |
1361 | mov(reg_tmp_, without_tail_info_); |
1362 | push(reg_tmp_); |
1363 | L(if_end); |
1364 | } |
1365 | } |
1366 | |
1367 | if (prb_.is_tail_in_one_of_child_nodes(curr_node_id)) { |
1368 | if (!curr_node_has_tail) { |
1369 | mov(reg_loop_cnt, node_size); |
1370 | mov(data_chunk_addr(curr_node_id), reg_loop_cnt); |
1371 | } |
1372 | L(loop); |
1373 | if (!prb_.nodes[curr_node_id].is_parent_empty()) { |
1374 | Label if_no_tail_in_child_node; |
1375 | mov(reg_tmp_, data_chunk_addr(parent_node_id)); |
1376 | check_if_this_is_last_chunk(reg_tmp_, parent_node_id); |
1377 | jne(if_no_tail_in_child_node, T_NEAR); |
1378 | mov(data_chunk_addr(curr_node_id), reg_loop_cnt); |
1379 | L(if_no_tail_in_child_node); |
1380 | } else { |
1381 | mov(data_chunk_addr(curr_node_id), reg_loop_cnt); |
1382 | } |
1383 | } else if (curr_node_has_tail) { |
1384 | L(loop); |
1385 | } else { |
1386 | loop_begin(loop, reg_loop_cnt, node_size); |
1387 | } |
1388 | |
1389 | create_loops(desc, reg_cnt, jit_loop - 1); |
1390 | |
1391 | loop_end(loop, reg_loop_cnt, node_size, |
1392 | prb_.is(curr_node_id) * unroll_factor, |
1393 | prb_.os(curr_node_id) * unroll_factor, |
1394 | prb_.ss(curr_node_id) * unroll_factor, |
1395 | prb_.cs(curr_node_id) * unroll_factor, curr_node_id); |
1396 | } else { |
1397 | compute_blk_ker(desc); |
1398 | } |
1399 | } |
1400 | |
1401 | bool simple_impl() { |
1402 | simple_impl_desc_t d; |
1403 | if (!simple_impl_desc_init(prb_, &d)) return false; |
1404 | |
1405 | xor_(reg_off_in_, reg_off_in_); |
1406 | xor_(reg_off_out_, reg_off_out_); |
1407 | if (prb_.src_scale_type == scale_type_t::MANY |
1408 | || prb_.dst_scale_type == scale_type_t::MANY) |
1409 | xor_(reg_off_scale_, reg_off_scale_); |
1410 | if (compensation_needed_) xor_(reg_off_comp_, reg_off_comp_); |
1411 | |
1412 | std::array<const Reg64, 3> reg_cnt({{r15, r14, r13}}); |
1413 | |
1414 | const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; |
1415 | create_loops(d, reg_cnt, n_jit_loops); |
1416 | |
1417 | return true; |
1418 | } |
1419 | |
1420 | void impl() { |
1421 | if (simple_impl()) return; |
1422 | assert(!"no implementation available" ); |
1423 | } |
1424 | |
1425 | jit_uni_reorder_kernel_f32_t(const desc_t &desc) |
1426 | : kernel_t(desc) |
1427 | , jit_generator(jit_name()) |
1428 | , isa_(get_max_cpu_isa()) |
1429 | , bf16_emu_(nullptr) { |
1430 | assert(!utils::one_of(isa_, isa_undef, isa_all)); |
1431 | itype_sz_ = data_type_size(prb_.itype); |
1432 | otype_sz_ = data_type_size(prb_.otype); |
1433 | stype_sz_ = sizeof(float); |
1434 | if (prb_.otype == data_type::bf16 && !mayiuse(avx512_core_bf16) |
1435 | && !mayiuse(avx2_vnni_2)) { |
1436 | bf16_emu_ = utils::make_unique<bf16_emulation_t>(this, |
1437 | bf16_emu_reserv_1_, bf16_emu_reserv_2_, bf16_emu_reserv_3_, |
1438 | bf16_emu_scratch_, bf16_emu_reserv_4_); |
1439 | } |
1440 | } |
1441 | |
1442 | void generate() override { |
1443 | Label end_of_kernel; |
1444 | |
1445 | preamble(); |
1446 | |
1447 | if (bf16_emu_) bf16_emu_->init_vcvtneps2bf16(); |
1448 | |
1449 | if (prb_.src_scale_type == scale_type_t::COMMON) { |
1450 | auto reg_ptr_src_scales__tmp = reg_ptr_in_; |
1451 | mov(reg_ptr_src_scales__tmp, PARAM(src_scales)); |
1452 | uni_vbroadcastss(xmm_src_scales_, ptr[reg_ptr_src_scales__tmp]); |
1453 | } else if (prb_.src_scale_type == scale_type_t::MANY) { |
1454 | mov(reg_ptr_src_scales_, PARAM(src_scales)); |
1455 | } |
1456 | |
1457 | if (prb_.dst_scale_type == scale_type_t::COMMON) { |
1458 | auto reg_ptr_dst_scales__tmp = reg_ptr_in_; |
1459 | mov(reg_ptr_dst_scales__tmp, PARAM(dst_scales)); |
1460 | uni_vbroadcastss(xmm_dst_scales_, ptr[reg_ptr_dst_scales__tmp]); |
1461 | } else if (prb_.dst_scale_type == scale_type_t::MANY) { |
1462 | mov(reg_ptr_dst_scales_, PARAM(dst_scales)); |
1463 | } |
1464 | |
1465 | if (compensation_needed_) |
1466 | mov(reg_ptr_comp_, PARAM(compensation_scratch)); |
1467 | if (prb_.scale_adjust == 0.5f) { mov(reg_scale_adjust_, 0x3f000000); } |
1468 | mov(reg_ptr_in_, PARAM(in)); |
1469 | mov(reg_ptr_out_, PARAM(out)); |
1470 | |
1471 | bool is_tail_in_drv_dims = false; |
1472 | for (int i = prb_.ndims; i < prb_.full_ndims; i++) |
1473 | if (prb_.nodes[i].tail_size > 0) { |
1474 | is_tail_in_drv_dims = true; |
1475 | break; |
1476 | } |
1477 | |
1478 | if (is_tail_in_drv_dims) { |
1479 | Label reorder_kernel; |
1480 | |
1481 | mov(reg_tmp_, TAIL_PARAM(skip_kernel_execution)); |
1482 | cmp(reg_tmp_, static_cast<int64_t>(true)); |
1483 | je(end_of_kernel, T_NEAR); |
1484 | |
1485 | mov(reg_tmp_, TAIL_PARAM(zeroing_data)); |
1486 | cmp(reg_tmp_, static_cast<int64_t>(false)); |
1487 | je(reorder_kernel, T_NEAR); |
            // If the zeroing_data flag is set, all dst memory
            // is zeroed and nothing more is done.
1490 | int bytes_to_zeroing = otype_sz_; |
1491 | for (int i = 0; i < prb_.ndims; i++) { |
1492 | bytes_to_zeroing *= prb_.nodes[i].n; |
1493 | } |
1494 | xor_(reg_off_out_, reg_off_out_); |
1495 | zero_dst_memory(bytes_to_zeroing); |
1496 | jmp(end_of_kernel, T_NEAR); |
1497 | L(reorder_kernel); |
1498 | } |
1499 | |
1500 | if (can_do_tr8x8()) { |
1501 | vxorps(ymm_zero_, ymm_zero_, ymm_zero_); |
1502 | |
1503 | if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { |
1504 | mov(reg_tmp_, 0x7f7f7f7f7f7f7f7f); |
1505 | uni_vmovq(Xmm(ymm_8x127b_.getIdx()), reg_tmp_); |
1506 | } |
1507 | } else { |
1508 | uni_vxorps(xmm_zero_, xmm_zero_, xmm_zero_); |
1509 | |
1510 | if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) { |
1511 | mov(reg_tmp_.cvt32(), 0x7f7f7f7f); |
1512 | movd(xmm_4x127b_, reg_tmp_.cvt32()); |
1513 | } |
1514 | } |
1515 | |
1516 | impl(); |
1517 | |
1518 | L(end_of_kernel); |
1519 | postamble(); |
1520 | } |
1521 | |
1522 | ~jit_uni_reorder_kernel_f32_t() override = default; |
1523 | |
1524 | #undef TAIL_PARAM |
1525 | #undef PARAM |
1526 | |
1527 | private: |
1528 | static constexpr int64_t with_tail_info_ = static_cast<int64_t>(true); |
1529 | static constexpr int64_t without_tail_info_ = static_cast<int64_t>(false); |
1530 | |
1531 | int itype_sz_; |
1532 | int otype_sz_; |
1533 | int stype_sz_; |
1534 | |
1535 | const cpu_isa_t isa_; |
1536 | |
1537 | const Reg64 reg_ptr_in_ = rsi; |
1538 | const Reg64 reg_ptr_out_ = rdx; |
1539 | const Reg64 reg_ptr_src_scales_ = abi_not_param1; |
1540 | const Reg64 reg_ptr_dst_scales_ = r12; |
1541 | const Reg64 reg_ptr_comp_ = rbx; |
1542 | const Reg32 ®_scale_adjust_ = ebp; |
1543 | |
1544 | const Reg64 reg_off_in_ = r8; |
1545 | const Reg64 reg_off_out_ = r9; |
1546 | const Reg64 reg_off_scale_ = r10; |
1547 | const Reg64 reg_off_comp_ = r11; |
1548 | // r13-r15 are reserved for creating loops over compute kernels... |
1549 | |
1550 | const Reg64 reg_tmp_ = rax; |
1551 | |
1552 | const Xmm xmm_src_scales_ = xmm15; |
1553 | const Xmm xmm_dst_scales_ = xmm11; |
1554 | const Xmm xmm_zero_ = xmm14; |
1555 | const Xmm xmm_4x127b_ = xmm13; // TODO: unite with ymm_zero_ |
1556 | const Ymm ymm_zero_ = ymm14; |
1557 | const Ymm ymm_8x127b_ = ymm13; |
1558 | const Xmm xmm_tmp_ = xmm12; |
1559 | const Xmm xmm_src_zp_ = xmm9; |
1560 | const Xmm xmm_dst_zp_ = xmm10; |
1561 | const Xmm xmm_compensation = xmm8; |
1562 | const Xmm xmm_saturation_ubound_ = xmm12; |
1563 | const Ymm ymm_saturation_ubound_ = ymm12; |
1564 | |
1565 | /* bf16 support on SKX */ |
1566 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
1567 | const Zmm bf16_emu_reserv_1_ = Zmm(16); |
1568 | const Zmm bf16_emu_reserv_2_ = Zmm(17); |
1569 | const Reg64 bf16_emu_scratch_ = reg_tmp_; |
1570 | const Zmm bf16_emu_reserv_3_ = Zmm(18); |
1571 | const Zmm bf16_emu_reserv_4_ = Zmm(19); |
1572 | }; |
1573 | |
// Separate class with no unroll/threading burden
1575 | struct jit_single_blk_kernel_t : public jit_generator { |
1576 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_single_blk_kernel) |
1577 | static bool applicable(const prb_t &p) { |
1578 | using namespace data_type; |
1579 | |
1580 | bool ok = p.ndims >= 2 && mayiuse(avx2) |
1581 | && p.src_scale_type == scale_type_t::NONE |
1582 | && p.dst_scale_type == scale_type_t::NONE |
1583 | && utils::one_of(p.itype, f32) && utils::one_of(p.otype, f32) |
1584 | && utils::everyone_is(0, p.ioff, p.ooff) && p.beta == 0.f |
1585 | && prb_has_small_strides(p); |
1586 | if (!ok) return false; |
1587 | |
1588 | int64_t n0 = p.nodes[0].n; |
1589 | auto i0 = p.nodes[0].is; |
1590 | auto o0 = p.nodes[0].os; |
1591 | int64_t n1 = p.nodes[1].n; |
1592 | auto i1 = p.nodes[1].is; |
1593 | auto o1 = p.nodes[1].os; |
1594 | |
1595 | /* |
1596 | * for a transpose of plain to 8c case, nodes would be like: |
1597 | * n is os |
1598 | * m 1 8 |
1599 | * 8 m 1 |
1600 | * or |
1601 | * 8 m 1 |
1602 | * m 1 8 |
1603 | */ |
1604 | ok = (utils::one_of(n0, 8, 16) || utils::one_of(n1, 8, 16)) |
1605 | && ((i0 == 1 && o1 == 1 && n0 == i1 && o0 == n1) |
1606 | || (o0 == 1 && i1 == 1 && n0 == o1 && i0 == n1)); |
1607 | if (!ok) return false; |
1608 | |
1609 | // Do not handle transpose of dimensions other than last 2 |
1610 | for (int i = 2; i < p.ndims; ++i) { |
1611 | if (p.nodes[i].is != p.nodes[i].os) { |
1612 | ok = false; |
1613 | break; |
1614 | } |
1615 | } |
1616 | |
1617 | return ok; |
1618 | } |
1619 | |
1620 | jit_single_blk_kernel_t(const tr::prb_t &prb) |
1621 | : jit_generator(jit_name()) |
1622 | , prb_(prb) |
1623 | , itype_sz_(data_type_size(prb_.itype)) |
1624 | , otype_sz_(data_type_size(prb_.otype)) |
1625 | , block_sz(prb.nodes[0].n) {} |
1626 | |
1627 | void generate() override { |
1628 | auto input_stride |
1629 | = prb_.nodes[0].is != 1 ? prb_.nodes[0].is : prb_.nodes[1].is; |
1630 | auto output_stride |
1631 | = prb_.nodes[0].os != 1 ? prb_.nodes[0].os : prb_.nodes[1].os; |
1632 | |
1633 | Label tail_processing; |
1634 | |
1635 | const auto load_zp = [&](const Ymm ymm_zp, const Reg64 reg_zp) { |
1636 | const Xmm xmm_zp = Xmm(ymm_zp.getIdx()); |
1637 | uni_vmovq(xmm_zp, reg_zp); |
1638 | uni_vpbroadcastd(ymm_zp, xmm_zp); |
1639 | uni_vcvtdq2ps(ymm_zp, ymm_zp); |
1640 | }; |
1641 | |
1642 | preamble(); |
1643 | |
1644 | if (prb_.req_src_zp) load_zp(ymm_src_zp, reg_src_zp); |
1645 | |
1646 | if (prb_.req_dst_zp) load_zp(ymm_dst_zp, reg_dst_zp); |
1647 | |
1648 | cmp(reg_ptr_tail, true); |
1649 | je(tail_processing, T_NEAR); |
1650 | |
if (block_sz == 8) {
gen_ker8x8(0, 0, input_stride, output_stride, 8, 8);
} else if (block_sz == 16) {
gen_ker16x16_in_8x8(input_stride, output_stride);
} else {
assert(!"unimplemented");
}
1660 | |
1661 | postamble(); |
1662 | |
1663 | L(tail_processing); |
1664 | |
1665 | if (block_sz == 8) { |
1666 | auto i_tail = input_stride % 8 != 0 ? input_stride % 8 : 8; |
1667 | auto o_tail = output_stride % 8 != 0 ? output_stride % 8 : 8; |
1668 | if (i_tail != o_tail) { |
1669 | auto t_mask = i_tail == 8 ? o_tail : i_tail; |
1670 | gen_setmask(t_mask); |
1671 | gen_ker8x8(0, 0, input_stride, output_stride, i_tail, o_tail); |
1672 | } |
1673 | } else if (block_sz == 16) { |
1674 | auto i_tail = input_stride % 16 != 0 ? input_stride % 16 : 16; |
1675 | auto o_tail = output_stride % 16 != 0 ? output_stride % 16 : 16; |
1676 | if (i_tail != o_tail) { |
1677 | auto t_mask = i_tail == 16 ? o_tail : i_tail; |
1678 | t_mask %= 8; |
1679 | if (t_mask != 0) gen_setmask(t_mask); |
1680 | gen_ker16x16_in_8x8( |
1681 | input_stride, output_stride, i_tail, o_tail); |
1682 | } |
1683 | } else { |
1684 | assert(!"unimplemented" ); |
1685 | } |
1686 | |
1687 | postamble(); |
1688 | } |
1689 | |
1690 | void gen_loadu(const Ymm ymm, const Address &addr, int size) { |
1691 | Xmm xmm(ymm.getIdx()); |
1692 | switch (size) { |
1693 | case 32: vmovups(ymm, addr); break; |
1694 | case 16: vmovups(xmm, addr); break; |
1695 | default: assert(!"unreachable" ); |
1696 | } |
1697 | } |
1698 | |
1699 | void gen_storeu(const Address &addr, const Ymm ymm, int size) { |
1700 | Xmm xmm(ymm.getIdx()); |
1701 | switch (size) { |
1702 | case 32: vmovups(addr, ymm); break; |
1703 | case 16: vmovups(addr, xmm); break; |
1704 | default: assert(!"unreachable" ); |
1705 | } |
1706 | } |
1707 | |
1708 | void gen_maskloadu( |
1709 | const Ymm ymm, const Address &addr, const Ymm mask, int size) { |
1710 | Xmm xmm(ymm.getIdx()); |
1711 | Xmm mask128(mask.getIdx()); |
1712 | switch (size) { |
1713 | case 32: vmaskmovps(ymm, mask, addr); break; |
1714 | case 16: vmaskmovps(xmm, mask128, addr); break; |
1715 | default: assert(!"unreachable" ); |
1716 | } |
1717 | } |
1718 | |
1719 | void gen_maskstoreu( |
1720 | const Address &addr, const Ymm ymm, const Ymm mask, int size) { |
1721 | Xmm xmm(ymm.getIdx()); |
1722 | Xmm mask128(mask.getIdx()); |
1723 | switch (size) { |
1724 | case 32: vmaskmovps(addr, mask, ymm); break; |
1725 | case 16: vmaskmovps(addr, mask128, xmm); break; |
1726 | default: assert(!"unreachable" ); |
1727 | } |
1728 | } |
1729 | |
// Register allocation: ymm0~ymm11 (ymm8~ymm11 act as temporaries)
1731 | void gen_transpose_8x8() { |
1732 | constexpr int lane = 8; |
1733 | for (int i = 0; i < lane / 2; i++) { |
1734 | vunpcklps(Ymm(lane + i), Ymm(2 * i), Ymm(2 * i + 1)); |
1735 | vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1)); |
1736 | } |
1737 | |
1738 | const unsigned int lfloat = 0x44; |
1739 | const unsigned int ufloat = 0xee; |
1740 | for (int i = 0; i < lane / 2; i++) { |
1741 | int j = i % 2 == 0 ? lane + i : i - 1; |
1742 | vshufps(Ymm(lane / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat); |
1743 | vshufps(Ymm(lane / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat); |
1744 | } |
1745 | |
1746 | const unsigned int lquad = 0x20; |
1747 | for (int i = 0; i < lane / 2; i++) |
1748 | vperm2f128(Ymm(i), Ymm(lane / 2 + i), Ymm(lane + i), lquad); |
1749 | |
1750 | const unsigned int uquad = 0x31; |
1751 | for (int i = lane / 2; i < lane; i++) |
1752 | vperm2f128(Ymm(i), Ymm(i), Ymm(lane / 2 + i), uquad); |
1753 | } |
1754 | |
// keep order nchw -> nChw()c (with () being the block size, 8 or 16)
// or nChw()c -> nchw
void gen_setmask(int mask) {
// ymm_tmp = all zeros, ymm_mask = all ones
vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
vpcmpeqd(ymm_mask, ymm_mask, ymm_mask);
// blend so that only the first `mask` lanes of ymm_mask remain set
const uint8_t in_mask = 0xFF << mask;
vpblendd(ymm_mask, ymm_mask, ymm_tmp, in_mask);
}
1765 | |
// TODO: mark parameters with type information
// XXX: !
// offsets are given in bytes
// strides are given in element counts
//
// Generates a specific 8x8 transform with respect to the given tail
// condition.
1772 | void gen_tr8x8(int i_off, int o_off, int input_stride, int output_stride, |
1773 | int in_tail, int out_tail) { |
1774 | constexpr int lane = 8; |
1775 | |
1776 | if (in_tail == 0 || out_tail == 0) return; |
1777 | |
1778 | for (int i = 0; i < out_tail; ++i) { |
1779 | if (in_tail != lane) { |
1780 | gen_maskloadu(Ymm(i), |
1781 | ptr[reg_ptr_in_ + i_off + i * input_stride * itype_sz_], |
1782 | ymm_mask, lane * itype_sz_); |
1783 | } else { |
1784 | gen_loadu(Ymm(i), |
1785 | ptr[reg_ptr_in_ + i_off + i * input_stride * itype_sz_], |
1786 | lane * itype_sz_); |
1787 | } |
1788 | if (prb_.req_src_zp) { vsubps(Ymm(i), Ymm(i), ymm_src_zp); } |
1789 | } |
1790 | |
1791 | gen_transpose_8x8(); |
1792 | |
1793 | for (int i = 0; i < in_tail; ++i) { |
1794 | if (prb_.req_dst_zp) { vaddps(Ymm(i), Ymm(i), ymm_dst_zp); } |
1795 | if (out_tail == lane) { |
1796 | gen_storeu(ptr[reg_ptr_out_ + o_off |
1797 | + i * output_stride * otype_sz_], |
1798 | Ymm(i), lane * otype_sz_); |
1799 | } else { |
1800 | gen_maskstoreu(ptr[reg_ptr_out_ + o_off |
1801 | + i * output_stride * otype_sz_], |
1802 | Ymm(i), ymm_mask, lane * otype_sz_); |
1803 | } |
1804 | } |
1805 | } |
1806 | |
// tail: 0 ~ 8
// supported: at most one of in_tail and out_tail differs from 8
1809 | void gen_ker8x8(int i_off, int o_off, int input_stride, int output_stride, |
1810 | int in_tail, int out_tail) { |
1811 | gen_tr8x8(i_off, o_off, input_stride, output_stride, in_tail, out_tail); |
1812 | } |
1813 | |
1814 | void gen_ker16x16_in_8x8(int input_stride, int output_stride) { |
1815 | const auto lane = 16; |
1816 | const auto sub_lane = lane / 2; |
1817 | gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, sub_lane); |
1818 | gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_, |
1819 | input_stride, output_stride, sub_lane, sub_lane); |
1820 | gen_tr8x8(sub_lane * itype_sz_, output_stride * sub_lane * otype_sz_, |
1821 | input_stride, output_stride, sub_lane, sub_lane); |
1822 | gen_tr8x8((input_stride * sub_lane + sub_lane) * itype_sz_, |
1823 | (output_stride * sub_lane + sub_lane) * otype_sz_, input_stride, |
1824 | output_stride, sub_lane, sub_lane); |
1825 | } |
1826 | |
// tail can be 1 ~ 16; uses avx2 for now
1828 | void gen_ker16x16_in_8x8( |
1829 | int input_stride, int output_stride, int in_tail, int out_tail) { |
1830 | constexpr auto lane = 16; |
1831 | constexpr auto sub_lane = lane / 2; |
1832 | auto tail = in_tail != lane ? in_tail : out_tail; |
1833 | |
1834 | const auto l_tail = tail < sub_lane ? tail : sub_lane; |
1835 | const auto u_tail = tail < sub_lane ? 0 : tail - sub_lane; |
1836 | |
1837 | if (tail == in_tail) { |
1838 | gen_tr8x8(0, 0, input_stride, output_stride, l_tail, sub_lane); |
1839 | gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_, |
1840 | input_stride, output_stride, l_tail, sub_lane); |
1841 | gen_tr8x8(sub_lane * itype_sz_, |
1842 | output_stride * sub_lane * otype_sz_, input_stride, |
1843 | output_stride, u_tail, sub_lane); |
1844 | gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane), |
1845 | otype_sz_ * (output_stride * sub_lane + sub_lane), |
1846 | input_stride, output_stride, u_tail, sub_lane); |
1847 | } else { |
1848 | gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, l_tail); |
1849 | gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_, |
1850 | input_stride, output_stride, sub_lane, u_tail); |
gen_tr8x8(sub_lane * itype_sz_,
output_stride * sub_lane * otype_sz_, input_stride,
output_stride, sub_lane, l_tail);
1854 | gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane), |
1855 | otype_sz_ * (output_stride * sub_lane + sub_lane), |
1856 | input_stride, output_stride, sub_lane, u_tail); |
1857 | } |
1858 | } |
1859 | |
1860 | private: |
1861 | // 6 ~ 12 |
1862 | constexpr static int xmm_save_for_windows = is_windows ? 7 : 0; |
1863 | constexpr static int xmm_save_start_from = 6; |
1864 | constexpr static int xmm_width = 16; |
1865 | |
1866 | void preamble() { |
1867 | if (is_windows) { |
1868 | // retrieve 5th function call argument from call stack |
1869 | static constexpr int param5 = 0x8; |
1870 | mov(reg_dst_zp, ptr[rsp + param5]); |
1871 | sub(rsp, xmm_save_for_windows * xmm_width); |
1872 | for (int i = 0; i < xmm_save_for_windows; ++i) { |
1873 | uni_vmovdqu(ptr[rsp + i * xmm_width], |
1874 | Xbyak::Xmm(xmm_save_start_from + i)); |
1875 | } |
1876 | } |
1877 | } |
1878 | |
1879 | void postamble() { |
1880 | if (is_windows) { |
1881 | for (int i = 0; i < xmm_save_for_windows; ++i) |
1882 | uni_vmovdqu(Xbyak::Xmm(xmm_save_start_from + i), |
1883 | ptr[rsp + i * xmm_width]); |
1884 | add(rsp, xmm_save_for_windows * xmm_width); |
1885 | } |
1886 | uni_vzeroupper(); |
1887 | ret(); |
1888 | } |
1889 | |
1890 | const prb_t &prb_; |
1891 | |
1892 | int itype_sz_; |
1893 | int otype_sz_; |
1894 | int block_sz; |
1895 | |
1896 | Reg64 reg_ptr_in_ = abi_param1; |
1897 | Reg64 reg_ptr_out_ = abi_param2; |
// bool is passed as a 1-byte value in the low byte of the argument register
1899 | Reg8 reg_ptr_tail = is_windows ? r8b : dl; |
1900 | Reg64 reg_src_zp = abi_param4; |
1901 | Reg64 reg_dst_zp = is_windows ? r10 : r8; |
1902 | |
1903 | Ymm ymm_mask = ymm12; |
1904 | Ymm ymm_tmp = ymm0; |
1905 | Ymm ymm_src_zp = ymm14; |
1906 | Ymm ymm_dst_zp = ymm15; |
1907 | }; |
1908 | |
1909 | status_t kernel_t::desc_init( |
1910 | kernel_t::desc_t &desc, const prb_t &prb, int ndims_ker_max) { |
1911 | desc.prb = prb; |
1912 | desc.prb.ioff = desc.prb.ooff = 0; |
1913 | |
1914 | if (ndims_ker_max > prb.ndims) return status::invalid_arguments; |
1915 | |
1916 | auto ndims_ker_max_f = [&]() { |
1917 | size_t cur_size = 1; |
1918 | for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n) |
1919 | if (cur_size >= ker_prb_size_min) return d; |
1920 | return prb.ndims; |
1921 | }; |
1922 | |
1923 | if (ndims_ker_max <= 0) ndims_ker_max = ndims_ker_max_f(); |
1924 | |
1925 | /* traverse through kernel implementations */ |
1926 | /* TODO: find a better way to do that... */ |
1927 | desc.id = 0; |
1928 | for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) { |
1929 | desc.prb.ndims = ndims_ker; |
1930 | if (jit_uni_reorder_kernel_f32_t::applicable(desc.prb)) |
1931 | return status::success; |
1932 | } |
1933 | |
1934 | return status::unimplemented; |
1935 | } |
1936 | |
1937 | kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { |
1938 | switch (desc.id) { |
1939 | case 0: return new jit_uni_reorder_kernel_f32_t(desc); |
1940 | default: assert(!"unknown kernel id" ); return nullptr; |
1941 | } |
1942 | |
1943 | return nullptr; |
1944 | } |
1945 | |
1946 | } // namespace tr |
1947 | |
1948 | static void prb_block_for_cache(tr::prb_t &prb) { |
/* If the strides of the 0th and 1st nodes are cache friendly,
 * then blocking can be dispensed with altogether! */
1951 | static constexpr int num_elems_thr = 16; |
1952 | const bool stride_cache_friendly |
1953 | = ((prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > num_elems_thr) |
1954 | || (prb.ndims > 1 && prb.nodes[1].is % num_elems_thr == 0 |
1955 | && prb.nodes[1].n > num_elems_thr)) |
1956 | && !prb.is_tail_present; |
1957 | |
1958 | // performance improvement for shapes with large inner-most dimension |
1959 | const size_t L1_cache_sz |
1960 | = size_t(3) * platform::get_per_core_cache_size(1) / 4; |
1961 | const size_t itype_sz_ = data_type_size(prb.itype); |
1962 | const size_t inner_block_sz = prb.nodes[0].n * itype_sz_; |
1963 | const bool requires_inner_blocking = inner_block_sz > L1_cache_sz |
1964 | // 'is_tail_present' is not supported for cache_blocking when |
1965 | // asymmetric_comp is executed. |
1966 | && IMPLICATION(prb.req_asymmetric_comp, !prb.is_tail_present); |
1967 | |
1968 | const bool cache_blocking_needed |
1969 | = stride_cache_friendly || requires_inner_blocking; |
1970 | if (!cache_blocking_needed) return; |
1971 | |
1972 | int unit_input_stride_idx = -1; |
1973 | for (auto idx = 0; idx < prb.ndims; ++idx) { |
1974 | if (prb.nodes[idx].is == 1) unit_input_stride_idx = idx; |
1975 | } |
1976 | |
1977 | /* Re-prioritize the sequential read over sequential write: |
1978 | * /-> [n0:is0:1][16n1:1:osk]... |
1979 | * [n0:is0:1]...[nk:1:osk] --> or |
1980 | * \-> [16n1:1:osk][n0:is0:1]... */ |
1981 | if (unit_input_stride_idx != -1) { |
1982 | const auto output_stride = prb.nodes[unit_input_stride_idx].os; |
1983 | const auto num_elems = prb.nodes[unit_input_stride_idx].n; |
1984 | |
1985 | const bool split_needed = (num_elems > num_elems_thr) |
1986 | && (num_elems % num_elems_thr == 0); |
1987 | const int move_location = (output_stride % 4 != 0) ? 0 : 1; |
1988 | if (split_needed) |
1989 | prb_node_split(prb, unit_input_stride_idx, num_elems_thr); |
1990 | |
1991 | /* Because of cache-unfriendly nature of unit-output stride node, let |
1992 | * us move unit-input stride node on or near front! */ |
1993 | if (unit_input_stride_idx != move_location) |
1994 | prb_node_move(prb, unit_input_stride_idx, move_location); |
1995 | } |
1996 | |
1997 | /* Potentially, split the node with os=1 in two and pull in the node with |
1998 | * is=1 between them for better cache reuse: |
1999 | * [n0:is0:1][n1:1:os1] --> [16n0:is0:1][n1:1:os1][n0/16:is0*16:16] */ |
2000 | if (prb.ndims >= 2 && prb.nodes[0].os == 1 && prb.nodes[1].is == 1) { |
2001 | const auto num_elems = prb.nodes[0].n; |
2002 | |
2003 | const bool split_needed = (num_elems > num_elems_thr) |
2004 | && (num_elems % num_elems_thr == 0); |
2005 | if (split_needed) { |
2006 | prb_node_split(prb, 0, num_elems_thr); |
2007 | prb_node_move(prb, 1, 2); |
2008 | |
2009 | // Update node information |
2010 | prb_node_dependency(prb); |
2011 | |
// Heuristic: looping over the unrolled dims should maximize reuse
// of the already-cached data; empirically, choosing the smallest dim
// among the remaining ones (from 2 up to ndims) gives good results.
2015 | constexpr int new_position = 2; |
2016 | const auto dim_beg_it = std::begin(prb.nodes); |
2017 | const auto dim_two_it = dim_beg_it + new_position; |
2018 | const auto dim_last_it = dim_beg_it + prb.ndims; |
2019 | const auto min_n_node_it = std::min_element(dim_two_it, dim_last_it, |
2020 | [](const tr::node_t &lhs, const tr::node_t &rhs) { |
2021 | return lhs.n < rhs.n; |
2022 | }); |
2023 | const auto min_idx = std::distance(dim_beg_it, min_n_node_it); |
2024 | // check if min_idx node is parent of node with tail processing which |
2025 | // is currently unsupported (i.e. tail processing can only be handled |
2026 | // at the inner-most dimension) |
2027 | bool inner_block_has_tail = false; |
2028 | for (int idx = min_idx - 1; idx >= new_position; idx--) { |
2029 | if (prb.nodes[idx].parent_node_id == min_idx) { |
2030 | inner_block_has_tail = true; |
2031 | break; |
2032 | } |
2033 | } |
2034 | |
2035 | if (min_idx > new_position && (!inner_block_has_tail)) |
2036 | prb_node_move(prb, min_idx, new_position); |
2037 | } |
2038 | } |
2039 | } |
2040 | |
2041 | /** finds the maximum number of dimension the kernel should process and |
2042 | * optionally splits one of the dimension to achieve better balance between |
2043 | * parallel driver and the kernel. */ |
2044 | static void prb_thread_kernel_balance( |
2045 | tr::prb_t &prb, int &ndims_ker_max, int nthr) { |
2046 | size_t size_total = 1; |
2047 | for (int d = 0; d < prb.ndims; ++d) |
2048 | size_total *= prb.nodes[d].n; |
2049 | |
/* The general expression for size_drv_thr can be written as
 * size_drv_thr = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1)
 * where FC and VC are fixed and variable costs respectively.
 * Though for now, the heuristic below seems to be good enough. */
2054 | const size_t size_drv_thr = (nthr > 1) ? 16 * nthr : 1; |
2055 | |
2056 | /* size_drv_min is the minimal size for the parallel |
2057 | * driver required for good parallelization */ |
2058 | const size_t size_drv_min |
2059 | = nstl::min<size_t>(size_drv_thr, utils::div_up(size_total, 1024)); |
2060 | |
2061 | /* kdims -- # of dimensions processed by a kernel |
2062 | * size_ker_cur -- product of the dimension processed by a kernel |
2063 | * size_drv_cur -- product of the dimension processed by a driver */ |
2064 | |
2065 | int kdims = prb.ndims; |
2066 | size_t size_drv_cur = 1; |
2067 | for (; kdims > 1 && size_drv_cur < size_drv_min; --kdims) |
2068 | size_drv_cur *= prb.nodes[kdims - 1].n; |
2069 | |
2070 | size_t size_ker_cur = 1; |
2071 | for (int d = 0; d < kdims; ++d) |
2072 | size_ker_cur *= prb.nodes[d].n; |
2073 | |
2074 | /* Initially kdims is chosen so that size_drv_cur >= size_drv_min. |
2075 | * |
2076 | * It might happen that for chosen kdims the size_ker_cur is too small |
2077 | * (less than tr::ker_prb_size_min). In that case try to split the |
2078 | * innermost driver dimension into two, to increase size_ker_cur. */ |
2079 | const bool want_borrow_ker_from_drv = kdims < prb.ndims |
2080 | && size_ker_cur < tr::ker_prb_size_min |
2081 | && size_drv_cur > size_drv_min; |
2082 | if (want_borrow_ker_from_drv) { |
2083 | /* size_want_borrow is the minimal size, so that: |
2084 | * o) size_ker_cur * size_want_borrow >= tr::ker_prb_size_min |
2085 | * o) current innermost driver dimension is divisible by |
2086 | * size_want_borrow (so that we can evenly split that |
2087 | * dimension into two) |
2088 | * |
2089 | * In the worst case the minimal size_want_borrow is equal |
2090 | * to the innermost driver dimension itself. In that case |
2091 | * we will sacrifice it in favor of kernel (is it fine?). */ |
2092 | size_t size_want_borrow |
2093 | = utils::div_up(tr::ker_prb_size_min, size_ker_cur); |
2094 | for (; prb.nodes[kdims].n % size_want_borrow; ++size_want_borrow) |
2095 | ; |
2096 | |
2097 | if (size_want_borrow != prb.nodes[kdims].n) |
2098 | prb_node_split(prb, kdims, size_want_borrow); |
2099 | kdims += 1; |
2100 | } |
2101 | |
2102 | /* On the other hand it might happen that for chosen kdims |
2103 | * the size_drv_cur is too small (less than size_drv_min). In that case |
2104 | * try to split the outermost kernel dimension into two, to increase |
2105 | * size_drv_cur. */ |
2106 | const bool want_borrow_drv_from_ker = size_ker_cur > tr::ker_prb_size_min |
2107 | && size_drv_cur < size_drv_min; |
2108 | if (want_borrow_drv_from_ker) { |
2109 | size_t size_want_borrow = utils::div_up(size_drv_min, size_drv_cur); |
2110 | for (; prb.nodes[kdims - 1].n % size_want_borrow; ++size_want_borrow) |
2111 | ; |
2112 | |
2113 | if (size_want_borrow != prb.nodes[kdims - 1].n) |
2114 | prb_node_split( |
2115 | prb, kdims - 1, prb.nodes[kdims - 1].n / size_want_borrow); |
2116 | } |
2117 | |
2118 | ndims_ker_max = kdims; |
2119 | |
2120 | if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { |
2121 | DEBUG({ |
2122 | printf("split: " ); |
2123 | prb_dump(prb); |
2124 | printf("ndims_ker_max = %d\n" , ndims_ker_max); |
2125 | }); |
2126 | } |
2127 | } |
2128 | |
2129 | status_t jit_uni_reorder_t::pd_t::init( |
2130 | engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { |
2131 | CHECK(cpu_reorder_pd_t::init(engine, src_engine, dst_engine)); |
2132 | |
2133 | CHECK(init_scratchpad()); |
2134 | |
2135 | return status::success; |
2136 | } |
2137 | |
2138 | status_t jit_uni_reorder_t::pd_t::init_scratchpad() { |
2139 | auto scratchpad = scratchpad_registry().registrar(); |
2140 | |
2141 | const bool compensation_needed |
2142 | = prb_.req_s8s8_comp || prb_.req_asymmetric_comp; |
2143 | if (compensation_needed) { |
2144 | const memory_desc_wrapper od(dst_md()); |
2145 | const auto G = with_groups_ ? od.padded_dims()[0] : 1; |
2146 | const auto N = od.padded_dims()[with_groups_ ? 1 : 0]; |
static constexpr int cache_line_size = 16; // in int32_t elements (64 bytes)
2148 | const auto wspace_per_thr_size |
2149 | = utils::rnd_up(G * N, cache_line_size) * sizeof(int32_t); |
2150 | |
2151 | const auto compensation_reduce_size = wspace_per_thr_size * nthr_; |
2152 | |
2153 | // Every thread gets its own scratchpad space for each N. |
2154 | scratchpad.template book<int32_t>( |
2155 | memory_tracking::names::key_reorder_space, |
2156 | compensation_reduce_size); |
2157 | } |
2158 | |
2159 | const memory_desc_wrapper input_d(src_md()); |
2160 | int scales_mask = -1; |
2161 | bool is_set = false; |
2162 | CHECK(attr()->scales_.get(DNNL_ARG_DST, &scales_mask, &is_set)); |
2163 | |
2164 | if (is_set && scales_mask > 0) { |
2165 | get_D_values(input_d, scales_mask, nullptr, &D_mask_, nullptr); |
2166 | if (D_mask_ > 1) { |
2167 | scratchpad.template book<float>( |
2168 | memory_tracking::names::key_reorder_precomputed_dst_scales, |
2169 | D_mask_); |
2170 | } |
2171 | } |
2172 | |
2173 | return status::success; |
2174 | } |
2175 | |
2176 | status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, |
2177 | engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, |
2178 | const memory_desc_t *src_md, engine_t *dst_engine, |
2179 | const memory_desc_t *dst_md) { |
2180 | auto prb = tr::prb_t(); |
2181 | |
2182 | status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); |
2183 | if (prb_init_status != status::success) return prb_init_status; |
2184 | |
2185 | prb_block_for_cache(prb); |
2186 | DEBUG({ |
2187 | printf("cache: " ); |
2188 | prb_dump(prb); |
2189 | }); |
2190 | |
2191 | int ndims_ker_max {}; |
2192 | int nthr = dnnl_get_max_threads(); |
2193 | prb_thread_kernel_balance(prb, ndims_ker_max, nthr); |
2194 | |
2195 | if (prb.is_tail_present) prb_node_dependency(prb); |
2196 | |
2197 | tr::kernel_t::desc_t ker_desc; |
2198 | status_t ker_init_status |
2199 | = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); |
2200 | if (ker_init_status != status::success) return ker_init_status; |
2201 | |
2202 | const int ndims_driver = prb.ndims - ker_desc.prb.ndims; |
2203 | if (ndims_driver > jit_uni_reorder_t::ndims_driver_max) |
2204 | return status::unimplemented; |
2205 | |
2206 | DEBUG({ |
2207 | printf("ker : " ); |
2208 | prb_dump(ker_desc.prb); |
2209 | }); |
2210 | |
2211 | auto _pd = new pd_t( |
2212 | attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); |
2213 | if (_pd == nullptr) return status::out_of_memory; |
2214 | |
2215 | _pd->nthr_ = nthr; |
2216 | _pd->prb_ = prb; |
2217 | _pd->with_groups_ |
2218 | = prb.compensation_mask == tr::prb_t::comp_mask_with_groups; |
2219 | if (_pd->init(engine, src_engine, dst_engine) != status::success) { |
2220 | delete _pd; |
2221 | return status::unimplemented; |
2222 | } |
2223 | _pd->ker_desc_ = ker_desc; |
2224 | _pd->init_scratchpad_md(); |
2225 | |
2226 | return safe_ptr_assign(*reorder_pd, _pd); |
2227 | } |
2228 | |
2229 | void jit_uni_reorder_t::omp_driver_0d(int off, const char *in, char *out, |
2230 | const float *src_scales, const float *dst_scales, int src_zp, |
2231 | int dst_zp, int32_t *compensation_scratch) const { |
2232 | const tr::prb_t &prb = pd()->prb_; |
2233 | |
2234 | tr::call_param_t base_params; |
2235 | base_params.in = in; |
2236 | base_params.out = out; |
2237 | base_params.src_scales = src_scales; |
2238 | base_params.dst_scales = dst_scales; |
2239 | base_params.src_zp = src_zp; |
2240 | base_params.dst_zp = dst_zp; |
2241 | base_params.compensation_scratch = compensation_scratch; |
2242 | |
2243 | if (prb.is_tail_present) { |
2244 | tr::tail_call_param_t tail_params; |
2245 | tail_params.base_params = base_params; |
2246 | |
2247 | static constexpr int omp_ndims = 0; |
2248 | fill_curr_data_chunks(prb, off, nullptr, omp_ndims, tail_params); |
2249 | |
2250 | (*kernel_)(&tail_params); |
2251 | } else { |
2252 | (*kernel_)(&base_params); |
2253 | } |
2254 | } |
2255 | |
2256 | void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, |
2257 | const char *in, char *out, const float *src_scales, |
2258 | const float *dst_scales, int src_zp, int dst_zp, |
2259 | int32_t *compensation_scratch) const { |
2260 | const tr::prb_t &prb = pd()->prb_; |
2261 | const tr::node_t *ns = prb.nodes + off; |
2262 | for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { |
2263 | tr::call_param_t base_params; |
2264 | base_params.in = in + d0 * ns[0].is * data_type_size(prb.itype); |
2265 | base_params.out = out + d0 * ns[0].os * data_type_size(prb.otype); |
2266 | base_params.src_scales = src_scales + d0 * ns[0].ss; |
2267 | base_params.dst_scales = dst_scales + d0 * ns[0].ss; |
2268 | base_params.src_zp = src_zp; |
2269 | base_params.dst_zp = dst_zp; |
2270 | base_params.compensation_scratch = compensation_scratch + d0 * ns[0].cs; |
2271 | |
2272 | if (prb.is_tail_present) { |
2273 | tr::tail_call_param_t tail_params; |
2274 | tail_params.base_params = base_params; |
2275 | |
2276 | static constexpr int omp_ndims = 1; |
2277 | const ptrdiff_t omp_data_chunks[omp_ndims] = {d0}; |
2278 | fill_curr_data_chunks( |
2279 | prb, off, omp_data_chunks, omp_ndims, tail_params); |
2280 | |
2281 | (*kernel_)(&tail_params); |
2282 | } else { |
2283 | (*kernel_)(&base_params); |
2284 | } |
2285 | }); |
2286 | } |
2287 | |
2288 | void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, |
2289 | const char *in, char *out, const float *src_scales, |
2290 | const float *dst_scales, int src_zp, int dst_zp, |
2291 | int32_t *compensation_scratch) const { |
2292 | const tr::prb_t &prb = pd()->prb_; |
2293 | const tr::node_t *ns = prb.nodes + off; |
2294 | for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, |
2295 | [&](ptrdiff_t d1, ptrdiff_t d0) { |
2296 | tr::call_param_t base_params; |
2297 | base_params.in = in |
2298 | + (d0 * ns[0].is + d1 * ns[1].is) |
2299 | * data_type_size(prb.itype); |
2300 | base_params.out = out |
2301 | + (d0 * ns[0].os + d1 * ns[1].os) |
2302 | * data_type_size(prb.otype); |
2303 | base_params.src_scales |
2304 | = src_scales + d0 * ns[0].ss + d1 * ns[1].ss; |
2305 | base_params.dst_scales |
2306 | = dst_scales + d0 * ns[0].ss + d1 * ns[1].ss; |
2307 | base_params.src_zp = src_zp; |
2308 | base_params.dst_zp = dst_zp; |
2309 | base_params.compensation_scratch |
2310 | = compensation_scratch + d0 * ns[0].cs + d1 * ns[1].cs; |
2311 | |
2312 | if (prb.is_tail_present) { |
2313 | tr::tail_call_param_t tail_params; |
2314 | tail_params.base_params = base_params; |
2315 | |
2316 | static constexpr int omp_ndims = 2; |
2317 | const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1}; |
2318 | fill_curr_data_chunks( |
2319 | prb, off, omp_data_chunks, omp_ndims, tail_params); |
2320 | |
2321 | (*kernel_)(&tail_params); |
2322 | } else { |
2323 | (*kernel_)(&base_params); |
2324 | } |
2325 | }); |
2326 | } |
2327 | |
2328 | void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, |
2329 | const char *in, char *out, const float *src_scales, |
2330 | const float *dst_scales, int src_zp, int dst_zp, |
2331 | int32_t *compensation_scratch) const { |
2332 | const tr::prb_t &prb = pd()->prb_; |
2333 | const tr::node_t *ns = prb.nodes + off; |
2334 | for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, |
2335 | (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { |
2336 | tr::call_param_t base_params; |
2337 | base_params.in = in |
2338 | + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) |
2339 | * data_type_size(prb.itype); |
2340 | base_params.out = out |
2341 | + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) |
2342 | * data_type_size(prb.otype); |
2343 | base_params.src_scales = src_scales + d0 * ns[0].ss |
2344 | + d1 * ns[1].ss + d2 * ns[2].ss; |
2345 | base_params.dst_scales = dst_scales + d0 * ns[0].ss |
2346 | + d1 * ns[1].ss + d2 * ns[2].ss; |
2347 | base_params.src_zp = src_zp; |
2348 | base_params.dst_zp = dst_zp; |
2349 | base_params.compensation_scratch = compensation_scratch |
2350 | + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs; |
2351 | |
2352 | if (prb.is_tail_present) { |
2353 | tr::tail_call_param_t tail_params; |
2354 | tail_params.base_params = base_params; |
2355 | |
2356 | static constexpr int omp_ndims = 3; |
2357 | const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1, d2}; |
2358 | fill_curr_data_chunks( |
2359 | prb, off, omp_data_chunks, omp_ndims, tail_params); |
2360 | |
2361 | (*kernel_)(&tail_params); |
2362 | } else { |
2363 | (*kernel_)(&base_params); |
2364 | } |
2365 | }); |
2366 | } |
2367 | |
2368 | void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off, |
2369 | const char *in, char *out, const float *src_scales, |
2370 | const float *dst_scales, int src_zp, int dst_zp, |
2371 | int32_t *compensation_scratch) const { |
2372 | const tr::prb_t &prb = pd()->prb_; |
2373 | const tr::node_t *ns = prb.nodes + off; |
2374 | for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, |
2375 | (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, |
2376 | [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { |
2377 | tr::call_param_t base_params; |
2378 | base_params.in = in |
2379 | + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is |
2380 | + d3 * ns[3].is) |
2381 | * data_type_size(prb.itype); |
2382 | base_params.out = out |
2383 | + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os |
2384 | + d3 * ns[3].os) |
2385 | * data_type_size(prb.otype); |
2386 | base_params.src_scales = src_scales + d0 * ns[0].ss |
2387 | + d1 * ns[1].ss + d2 * ns[2].ss + d3 * ns[3].ss; |
2388 | base_params.dst_scales = dst_scales + d0 * ns[0].ss |
2389 | + d1 * ns[1].ss + d2 * ns[2].ss + d3 * ns[3].ss; |
2390 | base_params.src_zp = src_zp; |
2391 | base_params.dst_zp = dst_zp; |
2392 | base_params.compensation_scratch = compensation_scratch |
2393 | + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs |
2394 | + d3 * ns[3].cs; |
2395 | |
2396 | if (prb.is_tail_present) { |
2397 | tr::tail_call_param_t tail_params; |
2398 | tail_params.base_params = base_params; |
2399 | |
2400 | static constexpr int omp_ndims = 4; |
2401 | const ptrdiff_t omp_data_chunks[omp_ndims] |
2402 | = {d0, d1, d2, d3}; |
2403 | fill_curr_data_chunks( |
2404 | prb, off, omp_data_chunks, omp_ndims, tail_params); |
2405 | |
2406 | (*kernel_)(&tail_params); |
2407 | } else { |
2408 | (*kernel_)(&base_params); |
2409 | } |
2410 | }); |
2411 | } |
2412 | |
2413 | void jit_uni_reorder_t::omp_driver(const char *in, char *out, |
2414 | const float *src_scales, const float *dst_scales, int src_zp, |
2415 | int dst_zp, const memory_tracking::grantor_t &scratchpad) const { |
2416 | in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); |
2417 | out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); |
2418 | |
2419 | DEBUG({ |
2420 | printf("prb : " ); |
2421 | tr::prb_dump(pd()->prb_); |
2422 | }); |
2423 | DEBUG({ |
2424 | printf("ker : " ); |
2425 | tr::prb_dump(pd()->ker_desc_.prb); |
2426 | }); |
2427 | |
2428 | int ndims = pd()->prb_.ndims; |
2429 | int ndims_ker = pd()->ker_desc_.prb.ndims; |
2430 | const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; |
2431 | const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; |
2432 | const bool req_compensation = req_s8s8_comp || req_asymmetric_comp; |
2433 | assert(ndims - ndims_ker <= ndims_driver_max); |
2434 | |
2435 | int32_t *compensation_reduce_scratch = scratchpad.template get<int32_t>( |
2436 | memory_tracking::names::key_reorder_space); |
2437 | |
2438 | const memory_desc_wrapper od(pd()->dst_md()); |
2439 | const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; |
2440 | const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0]; |
static constexpr int cache_line_size = 16; // in int32_t elements (64 bytes)
2442 | const auto wspace_per_thr_size = utils::rnd_up(G * N, cache_line_size); |
2443 | const auto wspace_per_thr_bytes = wspace_per_thr_size * sizeof(int32_t); |
2444 | |
2445 | if (ndims - ndims_ker == 0) { |
2446 | if (req_compensation) |
2447 | std::memset(compensation_reduce_scratch, 0, wspace_per_thr_bytes); |
2448 | |
2449 | omp_driver_0d(ndims_ker, in, out, src_scales, dst_scales, src_zp, |
2450 | dst_zp, compensation_reduce_scratch); |
2451 | } else { |
2452 | parallel(pd()->nthr_, [&](const int ithr, const int nthr) { |
2453 | int32_t *compensation_scratch = nullptr; |
2454 | if (req_compensation) { |
2455 | compensation_scratch = &compensation_reduce_scratch[ithr |
2456 | * wspace_per_thr_size]; |
2457 | std::memset(compensation_scratch, 0, wspace_per_thr_bytes); |
2458 | } |
2459 | |
2460 | switch (ndims - ndims_ker) { |
2461 | case 1: |
2462 | omp_driver_1d(ithr, nthr, ndims_ker, in, out, src_scales, |
2463 | dst_scales, src_zp, dst_zp, compensation_scratch); |
2464 | break; |
2465 | case 2: |
2466 | omp_driver_2d(ithr, nthr, ndims_ker, in, out, src_scales, |
2467 | dst_scales, src_zp, dst_zp, compensation_scratch); |
2468 | break; |
2469 | case 3: |
2470 | omp_driver_3d(ithr, nthr, ndims_ker, in, out, src_scales, |
2471 | dst_scales, src_zp, dst_zp, compensation_scratch); |
2472 | break; |
2473 | case 4: |
2474 | omp_driver_4d(ithr, nthr, ndims_ker, in, out, src_scales, |
2475 | dst_scales, src_zp, dst_zp, compensation_scratch); |
2476 | break; |
2477 | default: assert(!"unimplemented" ); |
2478 | } |
2479 | }); |
2480 | } |
2481 | |
// Reduction of intermediate compensation results into the final output.
2483 | if (req_compensation) { |
2484 | const int nthr = ndims - ndims_ker == 0 ? 1 : pd()->nthr_; |
2485 | reduce_compensation( |
2486 | out, compensation_reduce_scratch, nthr, wspace_per_thr_size); |
2487 | } |
2488 | } |
2489 | |
2490 | void jit_uni_reorder_t::reduce_compensation(char *out, |
2491 | const int32_t *compensation_reduce_scratch, const int nthr, |
2492 | const dim_t wspace_per_thr_size) const { |
2493 | |
2494 | const memory_desc_wrapper od(pd()->dst_md()); |
2495 | const size_t offset = od.size() - od.additional_buffer_size(); |
2496 | |
2497 | static constexpr auto comp_dt_size = sizeof(int32_t); |
2498 | static constexpr int32_t comp_s8s8_shift = 128; |
2499 | |
// Note: we do not need to explicitly zero out the compensation buffer,
// as the per-thread buffers are already zeroed out in the padded area.
2502 | const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; |
2503 | const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0]; |
2504 | const auto GN = G * N; |
2505 | const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; |
2506 | const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; |
2507 | const size_t zp_offset |
2508 | = offset + (pd()->prb_.req_s8s8_comp ? GN * comp_dt_size : 0); |
2509 | |
2510 | parallel_nd(GN, [&](int idx) { |
2511 | int32_t acc = 0; |
2512 | for (int ithr = 0; ithr < nthr; ithr++) { |
2513 | acc -= compensation_reduce_scratch[ithr * wspace_per_thr_size |
2514 | + idx]; |
2515 | } |
2516 | if (req_s8s8_comp) { |
2517 | int32_t *out_comp = reinterpret_cast<int32_t *>(&out[offset]); |
2518 | out_comp[idx] = comp_s8s8_shift * acc; |
2519 | } |
2520 | if (req_asymmetric_comp) { |
2521 | int32_t *out_asym_comp |
2522 | = reinterpret_cast<int32_t *>(&out[zp_offset]); |
2523 | out_asym_comp[idx] = acc; |
2524 | } |
2525 | }); |
2526 | } |
2527 | |
2528 | void jit_uni_reorder_t::fill_curr_data_chunks(const tr::prb_t &prb, |
2529 | const int off, const ptrdiff_t *omp_data_chunks, const int omp_ndims, |
2530 | tr::tail_call_param_t &c) const { |
// Chunks are numbered backwards, i.e.:
// [0] -> [node_size]
// [1] -> [node_size - 1]
// ...
// [node_size - 1] -> [1]

// This is done because, in the jit kernel, it is easier to decrement
// a counter and compare it with zero than to increment it and compare
// it with node_size.
2540 | |
2541 | static constexpr int64_t empty_chunk_info = -1; |
2542 | static constexpr int64_t last_chunk = 1; |
2543 | |
2544 | for (int curr_node_id = prb.ndims - 1; curr_node_id >= 0; curr_node_id--) { |
2545 | const int parent_node_id = prb.nodes[curr_node_id].parent_node_id; |
2546 | const bool is_drv_processing_this_node |
2547 | = curr_node_id >= off && curr_node_id <= off + omp_ndims - 1; |
2548 | const bool is_tail_processing |
2549 | = prb.is_tail_in_one_of_child_nodes(curr_node_id) |
2550 | || prb.nodes[curr_node_id].tail_size > 0; |
2551 | |
2552 | if (is_drv_processing_this_node && is_tail_processing) { |
2553 | const int inner_idx = curr_node_id - off; |
2554 | assert(inner_idx < omp_ndims); |
2555 | const int64_t node_size = prb.nodes[curr_node_id].tail_size > 0 |
2556 | ? prb.nodes[curr_node_id].tail_size |
2557 | : prb.nodes[curr_node_id].n; |
2558 | const int64_t data_chunk = node_size - omp_data_chunks[inner_idx]; |
2559 | |
2560 | if (!prb.nodes[curr_node_id].is_parent_empty()) { |
2561 | const bool is_parent_chunk_last |
2562 | = c.curr_data_chunks[parent_node_id] == last_chunk; |
2563 | c.curr_data_chunks[curr_node_id] |
2564 | = is_parent_chunk_last ? data_chunk : empty_chunk_info; |
2565 | c.zeroing_data = static_cast<int64_t>( |
2566 | is_parent_chunk_last && data_chunk <= 0); |
2567 | } else { |
2568 | c.curr_data_chunks[curr_node_id] = data_chunk; |
2569 | c.zeroing_data = static_cast<int64_t>(data_chunk <= 0); |
2570 | } |
2571 | c.skip_kernel_execution = static_cast<int64_t>(c.zeroing_data |
2572 | && !prb.nodes[curr_node_id].is_zero_pad_needed); |
2573 | if (c.zeroing_data || c.skip_kernel_execution) break; |
2574 | } else |
2575 | c.curr_data_chunks[curr_node_id] = empty_chunk_info; |
2576 | } |
2577 | } |
2578 | |
2579 | status_t jit_uni_reorder_t::init(engine_t *engine) { |
2580 | CHECK(safe_ptr_assign(kernel_, tr::kernel_t::create(pd()->ker_desc_))); |
2581 | return kernel_->create_kernel(); |
2582 | } |
2583 | |
2584 | status_t jit_uni_reorder_t::execute(const exec_ctx_t &ctx) const { |
2585 | const auto &scratchpad = ctx.get_scratchpad_grantor(); |
2586 | |
2587 | auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM); |
2588 | auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); |
2589 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
2590 | DEFINE_ARG_SCALES_BUFFER(dst_scales_, DNNL_ARG_DST); |
2591 | |
2592 | const float *dst_scales = pd()->precompute_scales( |
2593 | scratchpad, pd()->attr(), pd()->D_mask_, dst_scales_); |
2594 | assert(dst_scales); |
2595 | |
2596 | DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM); |
2597 | DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO); |
2598 | |
2599 | omp_driver(in, out, src_scales, dst_scales, src_zp, dst_zp, scratchpad); |
2600 | |
2601 | return status::success; |
2602 | } |
2603 | |
2604 | status_t jit_blk_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, |
2605 | engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, |
2606 | const memory_desc_t *src_md, engine_t *dst_engine, |
2607 | const memory_desc_t *dst_md) { |
2608 | auto prb = tr::prb_t(); |
2609 | |
2610 | status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); |
2611 | if (prb_init_status != status::success) return prb_init_status; |
2612 | // only uni_reorder supports tail processing now |
2613 | // TODO: Add tail processing support in blk_reorder |
2614 | if (prb.is_tail_present) return status::unimplemented; |
2615 | |
2616 | prb_tile_normalize(prb); |
2617 | DEBUG({ |
2618 | printf("tile : " ); |
2619 | prb_dump(prb); |
2620 | }); |
2621 | |
2622 | if (!tr::jit_single_blk_kernel_t::applicable(prb)) { |
2623 | return status::unimplemented; |
2624 | } |
2625 | |
2626 | auto _pd = new pd_t( |
2627 | attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); |
2628 | if (_pd == nullptr) return status::out_of_memory; |
2629 | _pd->prb_ = prb; |
2630 | if (_pd->init(engine, src_engine, dst_engine) != status::success) { |
2631 | delete _pd; |
2632 | return status::unimplemented; |
2633 | } |
2634 | _pd->init_scratchpad_md(); |
2635 | |
2636 | return safe_ptr_assign(*reorder_pd, _pd); |
2637 | } |
2638 | |
2639 | void jit_blk_reorder_t::pd_t::prb_tile_normalize(tr::prb_t &p) { |
2640 | if (!utils::one_of(p.nodes[0].n, 8ul, 16ul) |
2641 | && utils::one_of(p.nodes[1].n, 8ul, 16ul)) { |
2642 | nstl::swap(p.nodes[0], p.nodes[1]); |
2643 | } |
2644 | } |
2645 | |
2646 | jit_blk_reorder_t::jit_blk_reorder_t(const pd_t *apd) : primitive_t(apd) {} |
2647 | jit_blk_reorder_t::~jit_blk_reorder_t() = default; |
2648 | |
2649 | status_t jit_blk_reorder_t::init(engine_t *engine) { |
2650 | kernel_ = utils::make_unique<tr::jit_single_blk_kernel_t>(pd()->prb_); |
2651 | return kernel_->create_kernel(); |
2652 | } |
2653 | |
2654 | status_t jit_blk_reorder_t::execute(const exec_ctx_t &ctx) const { |
2655 | const auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM); |
2656 | auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); |
2657 | DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM); |
2658 | DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO); |
2659 | |
// the kernel handles 2-dimensional tiles; a tail is possible
2661 | auto &prb = this->pd()->prb_; |
2662 | ptrdiff_t BH = 1; |
2663 | for (int i = 2; i < prb.ndims; ++i) { |
2664 | BH *= prb.nodes[i].n; |
2665 | } |
2666 | |
2667 | auto block_sz = prb.n(0); |
2668 | auto n1 = prb.n(1); |
2669 | auto i1 = prb.is(1); |
2670 | auto o1 = prb.os(1); |
2671 | auto FL = (n1 + block_sz - 1) / block_sz; |
2672 | auto bh_stride = BH == 1 ? 0 : prb.is(2); |
2673 | |
2674 | auto itype_sz_ = data_type_size(pd()->prb_.itype); |
2675 | auto otype_sz_ = data_type_size(pd()->prb_.otype); |
2676 | |
2677 | parallel_nd(BH, FL, [&](dim_t bh, dim_t fl) { |
2678 | auto fl_b = fl * block_sz; |
2679 | auto bh_b = bh_stride * bh; |
2680 | auto *i = in + (bh_b + fl_b * i1) * itype_sz_; |
2681 | auto *o = out + (bh_b + fl_b * o1) * otype_sz_; |
2682 | (*kernel_)(i, o, n1 - fl_b < block_sz, src_zp, dst_zp); |
2683 | }); |
2684 | |
2685 | return status::success; |
2686 | } |
2687 | |
2688 | } // namespace x64 |
2689 | } // namespace cpu |
2690 | } // namespace impl |
2691 | } // namespace dnnl |
2692 | |