1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <numeric>
19
20#include "oneapi/dnnl/dnnl_debug.h"
21
22#include "common/c_types_map.hpp"
23#include "common/dnnl_thread.hpp"
24#include "common/memory_desc_wrapper.hpp"
25#include "common/nstl.hpp"
26#include "common/primitive.hpp"
27#include "common/type_helpers.hpp"
28#include "common/utils.hpp"
29
30#include "cpu/cpu_primitive.hpp"
31#include "cpu/reorder/cpu_reorder_pd.hpp"
32#include "cpu/x64/jit_uni_reorder.hpp"
33
34#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
35#include "cpu/x64/jit_generator.hpp"
36
37// #define TR_DEBUG
38#if defined(TR_DEBUG)
39#define DEBUg(...) \
40 do { \
41 __VA_ARGS__ \
42 } while (0)
43#else
44#define DEBUg(...)
45#endif
46#define DEBUG(...) DEBUg(__VA_ARGS__)
47
48#ifdef _WIN32
49/* seems like s_addr is a reserved macro on Windows */
50#undef s_addr
51constexpr static bool is_windows = true;
52#else
53constexpr static bool is_windows = false;
54#endif
55
56using namespace Xbyak;
57using namespace dnnl::impl::types;
58
59namespace dnnl {
60namespace impl {
61namespace cpu {
62namespace x64 {
63
64namespace tr {
65
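// Keeps the problem addressable with small offsets: for every node, the
// product of its size, its stride and the data type size has to stay below
// 2^31 on both the input and the output side, so that offsets fit into the
// 32-bit arithmetic used by the generated kernel.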
66static bool prb_has_small_strides(const prb_t &prb) {
67 constexpr ptrdiff_t max_stride = (1LL << 31) - 1;
68 for (int d = 0; d < prb.ndims; ++d) {
69 const ptrdiff_t cms = max_stride / prb.nodes[d].n;
70 const bool small_strides = true
71 && prb.nodes[d].is < cms / (int)data_type_size(prb.itype)
72 && prb.nodes[d].os < cms / (int)data_type_size(prb.otype);
73 if (!small_strides) return false;
74 }
75 return true;
76}
77
78/** Minimal reasonable/desirable kernel size.
79 * The constant might be used to determine how a problem should be split
80 * between kernel and threading driver. */
81const size_t ker_prb_size_min = 64;
82
83/* kernel */
84struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator {
85 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32)
86
87 void operator()(const call_param_t *c) const override {
88 jit_generator::operator()(c);
89 }
90 void operator()(const tail_call_param_t *c) const override {
91 jit_generator::operator()(c);
92 }
93
94 status_t create_kernel() override { return jit_generator::create_kernel(); }
95
96 enum class scale_arg_t { NONE, SRC, DST };
97
98 enum {
99 len_unroll_max = 256,
100 ndims_jit_loop_max = 3,
101 };
102
103 struct simple_impl_desc_t {
104 int ndims_full_unroll = 0;
105 int len_last_dim_unroll = 0;
106 int tail_len_unroll = 0;
107 int len_unroll = 0;
108 };
109
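// PARAM(x) resolves the address of field x of the runtime call parameters.
// When a tail is present the kernel receives a tail_call_param_t, which embeds
// the regular parameters as its base_params member, so that offset is added
// first; otherwise abi_param1 points directly at a call_param_t.
// TAIL_PARAM(x) accesses the tail-specific fields of tail_call_param_t.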
110#define PARAM(x) \
111 prb_.is_tail_present \
112 ? ptr[abi_param1 + offsetof(tail_call_param_t, base_params) \
113 + offsetof(call_param_t, x)] \
114 : ptr[abi_param1 + offsetof(call_param_t, x)]
115#define TAIL_PARAM(x) ptr[abi_param1 + offsetof(tail_call_param_t, x)]
116
117 static bool simple_impl_desc_init(
118 const prb_t &prb, simple_impl_desc_t *desc) {
119 const int ndims = prb.ndims;
120
121 int ndims_full_unroll = 0;
122 int len_last_dim_unroll = 1;
123 int tail_len_unroll = 0;
124 int len_unroll = 1;
125
        // This routine determines how many values the kernel can unroll.
        // If a tail is present, only a single node is unrolled (a possible
        // improvement). If there is no tail, the kernel can unroll several
        // nodes without any loops.
        // ndims_full_unroll - how many nodes will be unrolled fully
        // len_last_dim_unroll - how much of the last unrolled node will be unrolled
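        //
        // For illustration (hypothetical node sizes): with nodes of sizes
        // {4, 8, 16} and len_unroll_max = 256, the first two nodes are fully
        // unrolled (len_unroll = 32) and the third is partially unrolled by a
        // factor of 8, giving ndims_full_unroll = 2 and len_unroll = 256.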
132 if (prb.is_tail_present) {
133 ndims_full_unroll = 1;
134 len_unroll = prb.nodes[0].n;
135 tail_len_unroll = prb.nodes[0].is_zero_pad_needed
136 ? 0
137 : static_cast<int>(prb.nodes[0].tail_size);
138 } else {
139 for (int d = 0; d < ndims; ++d) {
140 const auto &node = prb.nodes[d];
141 if (len_unroll * node.n <= len_unroll_max) {
142 ndims_full_unroll++;
143 len_unroll *= node.n;
144 } else {
145 len_last_dim_unroll = len_unroll_max / len_unroll;
146 while (node.n % len_last_dim_unroll)
147 --len_last_dim_unroll;
148 len_unroll *= len_last_dim_unroll;
149 break;
150 }
151 }
152 }
153
154 if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) return false;
155
156 if (desc) {
157 desc->ndims_full_unroll = ndims_full_unroll;
158 desc->len_last_dim_unroll = len_last_dim_unroll;
159 desc->tail_len_unroll = tail_len_unroll;
160 desc->len_unroll = len_unroll;
161 }
162
163 return true;
164 }
165
166 static bool applicable(const prb_t &p) {
167 using namespace data_type;
168
169 bool ok = true && p.ndims > 0
170 && utils::one_of(p.itype, f32, bf16, f16, s32, s8, u8)
171 && utils::one_of(p.otype, f32, bf16, f16, s32, s8, u8)
172 && IMPLICATION(utils::one_of(p.itype, bf16, f16),
173 utils::one_of(p.otype, s8, u8, f32, bf16, f16))
174 && IMPLICATION(utils::one_of(p.otype, bf16, f16),
175 utils::one_of(p.itype, s8, u8, f32, bf16, f16))
176 && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
177 && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
178 && simple_impl_desc_init(p, nullptr) && mayiuse(sse41)
179 && IMPLICATION(utils::one_of(bf16, p.itype, p.otype),
180 mayiuse(avx512_core) || mayiuse(avx2_vnni_2))
181 && IMPLICATION(utils::one_of(f16, p.itype, p.otype),
182 mayiuse(avx512_core_fp16) || mayiuse(avx2_vnni_2))
183 && prb_has_small_strides(p);
184
185 return ok;
186 }
187
188 Address i_addr(int i_off) {
189 return ptr[reg_ptr_in_ + reg_off_in_ + i_off * itype_sz_];
190 }
191
192 Address o_addr(int o_off, bool with_type_multiplier = true) {
193 if (with_type_multiplier)
194 return ptr[reg_ptr_out_ + reg_off_out_ + o_off * otype_sz_];
195 else
196 return ptr[reg_ptr_out_ + reg_off_out_ + o_off];
197 }
198
199 Address src_s_addr(int s_off) {
200 return ptr[reg_ptr_src_scales_ + reg_off_scale_ + s_off * stype_sz_];
201 }
202
203 Address dst_s_addr(int s_off) {
204 return ptr[reg_ptr_dst_scales_ + reg_off_scale_ + s_off * stype_sz_];
205 }
206
207 Address c_addr(int c_off) {
208 return ptr[reg_ptr_comp_ + reg_off_comp_ + c_off * sizeof(int32_t)];
209 }
210
211 Address data_chunk_addr(int node_id) {
212 return ptr[abi_param1 + offsetof(tail_call_param_t, curr_data_chunks)
213 + sizeof(int64_t) * (node_id)];
214 }
215
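    // Advances the input/output/scale/compensation offsets by one step of
    // step_size elements in the flattened iteration space. The index off is
    // treated as a mixed-radix number over the node sizes: the strides of a
    // dimension are added once and, whenever that dimension wraps around, its
    // full contribution (n * stride) is subtracted before carrying over to
    // the next dimension.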
216 void step(int off, int prev_i_off, int prev_o_off, int prev_s_off,
217 int prev_c_off, int &i_off, int &o_off, int &s_off, int &c_off,
218 int step_size = 1) {
219 i_off = prev_i_off;
220 o_off = prev_o_off;
221 s_off = prev_s_off;
222 c_off = prev_c_off;
223
224 if (off == 0) return;
225
226 int start_dim = 0, dims_prod = 1;
227 for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim)
228 dims_prod *= prb_.n(start_dim);
229 assert(start_dim < prb_.ndims);
230 off /= step_size;
231
232 for (int dim_id = start_dim; dim_id < prb_.ndims; ++dim_id) {
233 i_off += prb_.is(dim_id);
234 o_off += prb_.os(dim_id);
235 s_off += prb_.ss(dim_id);
236 c_off += prb_.cs(dim_id);
237
238 if (off % prb_.n(dim_id)) break;
239
240 i_off += -prb_.n(dim_id) * prb_.is(dim_id);
241 o_off += -prb_.n(dim_id) * prb_.os(dim_id);
242 s_off += -prb_.n(dim_id) * prb_.ss(dim_id);
243 c_off += -prb_.n(dim_id) * prb_.cs(dim_id);
244
245 off /= prb_.n(dim_id);
246
247 if (off == 0) break; /* FIXME: is it really required? */
248 }
249 }
250
251 void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off,
252 int step_size = 1) {
253 int dummy = 0;
254 step(off, prev_i_off, prev_o_off, dummy, dummy, i_off, o_off, dummy,
255 dummy, step_size);
256 }
257
258 void tr8x8_avx2(int i_off, int o_off) {
259 using namespace data_type;
260
261 const auto cvt2ps
262 = [=](const Ymm dst, const Operand &src, data_type_t idt) {
263 switch (idt) {
264 case f32:
265 if (src.isMEM() || src.getIdx() != dst.getIdx())
266 vmovups(dst, src);
267 break;
268 case bf16:
269 vpmovzxwd(dst, src);
270 vpslld(dst, dst, 0x10);
271 break;
272 case f16:
273 if (is_superset(isa_, avx512_core_fp16)) {
274 if (src.isMEM())
275 vcvtph2psx(dst, src);
276 else
277 vcvtph2psx(dst, Xmm(src.getIdx()));
278 } else if (is_superset(isa_, avx2_vnni_2)) {
279 if (src.isMEM())
280 vcvtph2ps(dst, src);
281 else
282 vcvtph2ps(dst, Xmm(src.getIdx()));
283 } else
284 assert(!"invalid isa");
285 break;
286 case s32: vcvtdq2ps(dst, src); break;
287 case s8:
288 vpmovsxbd(dst, src);
289 vcvtdq2ps(dst, dst);
290 break;
291 case u8:
292 vpmovzxbd(dst, src);
293 vcvtdq2ps(dst, dst);
294 break;
295 default: assert(!"unreachable");
296 }
297 };
298
299 const auto cvt2odt = [=](const Ymm ymm, data_type_t odt,
300 data_type_t idt) {
301 const Xmm xmm = Xmm(ymm.getIdx());
302 switch (odt) {
303 case bf16:
304 if (utils::one_of(idt, f32, f16, s8, u8)) {
305 if (idt != f32) cvt2ps(ymm, ymm, idt);
306 if (is_superset(isa_, avx2_vnni_2)) {
307 vcvtneps2bf16(
308 Xmm(ymm.getIdx()), ymm, Xbyak::VexEncoding);
309 } else if (mayiuse(avx512_core_bf16)) {
310 vcvtneps2bf16(Xmm(ymm.getIdx()), ymm);
311 } else {
312 bf16_emu_->vcvtneps2bf16(
313 Ymm(ymm.getIdx()), Zmm(ymm.getIdx()));
314 }
315 }
316 break;
317 case f16:
318 if (utils::one_of(idt, f32, bf16, s8, u8)) {
319 if (idt != f32) cvt2ps(ymm, ymm, idt);
320 vcvtps2ph(Xmm(ymm.getIdx()), ymm, _op_mxcsr);
321 }
322 break;
323 case s32:
324 if (idt == f32)
325 vcvtps2dq(ymm, ymm);
326 else if (idt == s8)
327 vpmovsxbd(ymm, ymm);
328 else if (idt == u8)
329 vpmovzxbd(ymm, ymm);
330 break;
331 case s8:
332 if (utils::one_of(idt, bf16, f16)) cvt2ps(ymm, ymm, idt);
333 if (utils::one_of(idt, f32, bf16, f16)) vcvtps2dq(ymm, ymm);
334 if (utils::one_of(idt, bf16, f16, f32, s32)) {
335 if (mayiuse(avx512_core)) {
336 vpmovsdb(xmm, ymm);
337 } else {
338 vpackssdw(ymm, ymm, ymm_zero_);
339 vpermq(ymm, ymm, 0x58);
340 vpacksswb(ymm, ymm, ymm_zero_);
341 }
342 }
343 if (idt == u8) vpminub(ymm, ymm, ymm_8x127b_);
344 break;
345 case u8:
346 if (utils::one_of(idt, bf16, f16)) cvt2ps(ymm, ymm, idt);
347 if (utils::one_of(idt, f32, bf16, f16)) vcvtps2dq(ymm, ymm);
348 if (utils::one_of(idt, bf16, f16, f32, s32)) {
349 if (mayiuse(avx512_core)) {
350 vpmaxsd(ymm, ymm, ymm_zero_);
351 vpmovusdb(xmm, ymm);
352 } else {
353 vpackssdw(ymm, ymm, ymm_zero_);
354 vpermq(ymm, ymm, 0x58);
355 vpackuswb(ymm, ymm, ymm_zero_);
356 }
357 }
358 if (idt == s8) vpmaxsb(ymm, ymm, ymm_zero_);
359 break;
360 default: assert(!"unreachable");
361 }
362 };
363
364 auto load = [=](const Ymm ymm, const Address &addr, int size) {
365 const Xmm xmm = Xmm(ymm.getIdx());
366 switch (size) {
367 case 32: vmovups(ymm, addr); break;
368 case 16: vmovups(xmm, addr); break;
369 case 8: vmovsd(xmm, addr); break;
370 default: assert(!"unreachable");
371 }
372 };
373
374 auto store = [=](const Address &addr, const Ymm ymm, int size) {
375 const Xmm xmm = Xmm(ymm.getIdx());
376 switch (size) {
377 case 32: vmovups(addr, ymm); break;
378 case 16: vmovups(addr, xmm); break;
379 case 8: vmovsd(addr, xmm); break;
380 default: assert(!"unreachable");
381 }
382 };
383
384 const int unroll = 8;
385
386 const bool interim_f32 = (prb_.itype != f32)
387 || utils::one_of(f32, prb_.itype, prb_.otype);
388
389 const bool need_saturation
390 = (utils::one_of(prb_.otype, u8, s8, s32) && interim_f32);
391
392 for (int i = 0; i < unroll; i++) {
393 const int node_0_input_stride = prb_.is(0);
394 load(Ymm(i), i_addr(i_off + i * node_0_input_stride),
395 unroll * itype_sz_);
396
397 if (interim_f32) cvt2ps(Ymm(i), Ymm(i), prb_.itype);
398 }
399
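        // Standard three-stage 8x8 f32 transpose: 32-bit interleaves
        // (vunpcklps/vunpckhps), 64-bit shuffles (vshufps) and 128-bit lane
        // permutes (vperm2f128); each stage doubles the size of the already
        // transposed element groups (1 -> 2 -> 4 -> 8).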
400 for (int i = 0; i < unroll / 2; i++) {
401 vunpcklps(Ymm(unroll + i), Ymm(2 * i), Ymm(2 * i + 1));
402 vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1));
403 }
404
405 const unsigned int lfloat = 0x44;
406 const unsigned int ufloat = 0xee;
407 for (int i = 0; i < unroll / 2; i++) {
408 const int j = i % 2 == 0 ? unroll + i : i - 1;
409 vshufps(Ymm(unroll / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat);
410 vshufps(Ymm(unroll / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat);
411 }
412
413 const unsigned int lquad = 0x20;
414 for (int i = 0; i < unroll / 2; i++)
415 vperm2f128(Ymm(i), Ymm(unroll / 2 + i), Ymm(unroll + i), lquad);
416
417 const unsigned int uquad = 0x31;
418 for (int i = unroll / 2; i < unroll; i++)
419 vperm2f128(Ymm(i), Ymm(i), Ymm(unroll / 2 + i), uquad);
420
421 if (need_saturation) {
422 init_saturate_f32(ymm_zero_, ymm_saturation_ubound_, reg_tmp_,
423 interim_f32 ? f32 : prb_.itype, prb_.otype);
424 for (int i = 0; i < unroll; i++)
425 saturate_f32(
426 Ymm(i), ymm_zero_, ymm_saturation_ubound_, prb_.otype);
427 }
428
429 for (int i = 0; i < unroll; i++) {
430 const int node_1_output_stride = prb_.os(1);
431 if (prb_.otype != f32)
432 cvt2odt(Ymm(i), prb_.otype, interim_f32 ? f32 : prb_.itype);
433 store(o_addr(o_off + i * node_1_output_stride), Ymm(i),
434 unroll * otype_sz_);
435 }
436 }
437
438 bool can_do_tr8x8() {
439 using namespace data_type;
440
441 static constexpr int desirable_node_size = 8;
442 static constexpr int desirable_stride = 1;
443
        // This processing relies on swapping the two innermost dimensions.
        // Therefore, the input stride of the second node and the output
        // stride of the first node have to be equal to 1.
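        // For example (hypothetical layout), an 8x8 sub-problem with nodes
        // {n = 8, is = 8, os = 1} and {n = 8, is = 1, os = 8} satisfies these
        // conditions: 8 rows of 8 contiguous input elements are loaded,
        // transposed in registers and stored as 8 contiguous output rows.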
447
448 return mayiuse(avx2) && prb_.ndims >= 2
449 && ((utils::one_of(prb_.itype, u8, s8, s32, f32, bf16, f16)
450 && utils::one_of(
451 prb_.otype, u8, s8, s32, f32, bf16, f16)))
452 && utils::everyone_is(desirable_node_size, prb_.n(0), prb_.n(1))
453 && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(1))
454 && !prb_.is_tail_present
455 && prb_.src_scale_type == scale_type_t::NONE
456 && prb_.dst_scale_type == scale_type_t::NONE
457 && prb_.beta == 0.f;
458 }
459
460 bool process_unroll_tr8x8(const int ndims, const int len) {
461 if (!can_do_tr8x8()) return false;
462
463 const int step_size = prb_.n(0) * prb_.n(1);
464 int i_off = 0, o_off = 0;
465 for (int off = 0; off < len; off += step_size) {
466 step(off, i_off, o_off, i_off, o_off, step_size);
467 tr8x8_avx2(i_off, o_off);
468 }
469
470 return true;
471 }
472
473 template <cpu_isa_t isa>
474 bool process_direct_copy(const int ndims, const int len) {
475 using namespace data_type;
476
477 static constexpr int desirable_stride = 1;
478 using Vmm = typename cpu_isa_traits<isa>::Vmm;
479 const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz_;
480
481 // TODO: support tail_processing for direct copy
482
483 const bool do_src_zp = prb_.req_src_zp;
484 const bool do_dst_zp = prb_.req_dst_zp;
485 const bool zp_applicable = IMPLICATION(
486 (do_src_zp || do_dst_zp), utils::one_of(prb_.itype, s32, f32));
487 const bool can_do = true && mayiuse(isa)
488 && compensation_needed_ == false
489 && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(0))
490 && (false || (prb_.itype == prb_.otype ? zp_applicable : false)
491 || (prb_.itype == s32 && prb_.otype == f32)
492 || (prb_.itype == f32 && prb_.otype == s32))
493 && len % simd_w == 0 && prb_.n(0) % len == 0
494 && !prb_.is_tail_present
495 && prb_.src_scale_type == scale_type_t::NONE
496 && prb_.dst_scale_type == scale_type_t::NONE
497 && prb_.beta == 0.f;
498 if (!can_do) return false;
499
500 static constexpr int vmm_zp_last_idx = 15;
501 const auto vmm_src_zp
502 = Vmm(do_dst_zp ? vmm_zp_last_idx - 1 : vmm_zp_last_idx);
503 if (do_src_zp) {
504 uni_vpbroadcastd(vmm_src_zp, PARAM(src_zp));
505 uni_vcvtdq2ps(vmm_src_zp, vmm_src_zp);
506 }
507 const auto vmm_dst_zp = Vmm(vmm_zp_last_idx);
508 if (do_dst_zp) {
509 uni_vpbroadcastd(vmm_dst_zp, PARAM(dst_zp));
510 uni_vcvtdq2ps(vmm_dst_zp, vmm_dst_zp);
511 }
512
513 const auto apply_zp_ps = [&](const Vmm vmm) {
514 if (do_src_zp) uni_vsubps(vmm, vmm, vmm_src_zp);
515 if (do_dst_zp) uni_vaddps(vmm, vmm, vmm_dst_zp);
516 };
517
518 for (int off = 0; off < len;) {
519 // TODO: we need extra reg for proper saturation if otype == s32
520 int unroll
521 = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w);
522 unroll = (do_src_zp || do_dst_zp)
523 ? nstl::min(unroll, 16 - do_src_zp - do_dst_zp)
524 : unroll;
525
526 for (int ur = 0; ur < unroll; ++ur) {
527 const auto vmm = Vmm(ur);
528 uni_vmovups(vmm, i_addr(off + ur * simd_w));
529 }
530
531 if (prb_.itype != prb_.otype) {
532 for (int ur = 0; ur < unroll; ++ur) {
533 const auto vmm = Vmm(ur);
534 if (prb_.itype == s32 && prb_.otype == f32) {
535 uni_vcvtdq2ps(vmm, vmm);
536 apply_zp_ps(vmm);
537 } else if (prb_.itype == f32 && prb_.otype == s32) {
538 apply_zp_ps(vmm);
539 uni_vcvtps2dq(vmm, vmm);
540 } else
541 assert(!"unreachable");
542 }
543 } else if (do_src_zp || do_dst_zp) {
544 for (int ur = 0; ur < unroll; ++ur) {
545 const auto vmm = Vmm(ur);
546 if (prb_.otype == f32) {
547 apply_zp_ps(vmm);
548 } else if (prb_.otype == s32) {
549 uni_vcvtdq2ps(vmm, vmm);
550 apply_zp_ps(vmm);
551 uni_vcvtps2dq(vmm, vmm);
552 }
553 }
554 }
555
556 for (int ur = 0; ur < unroll; ++ur) {
557 const auto vmm = Vmm(ur);
558 uni_vmovups(o_addr(off + ur * simd_w), vmm);
559 }
560
561 off += unroll * simd_w;
562 }
563
564 return true;
565 }
566
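    // Processes one unrolled group of registers at xmm granularity. The flow
    // below is: load (optionally zero-padded for tails), optional conversion
    // to f32, src zero-point, src scales, beta accumulation, dst scales,
    // dst zero-point, scale_adjust, saturation, compensation accumulation,
    // conversion to the output data type and store.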
567 void process_unroll_generic_step(int reg_unroll, const int *i_off,
568 const int *o_off, const int *s_off, const int *c_off,
569 const int *zero_padding, const bool tail_processing) {
570 using namespace data_type;
571
572 const auto cvt2ps
573 = [=](const Xmm dst, const Operand &src, data_type_t idt) {
574 Xmm dst_pure = Xmm(dst.getIdx());
575 switch (idt) {
576 case f32:
577 if (src.isMEM() || src.getIdx() != dst.getIdx())
578 uni_vmovups(dst, src);
579 break;
580 case bf16:
581 if (mayiuse(avx)) {
582 vpmovzxwd(dst, src);
583 vpslld(dst, dst, 0x10);
584 break;
585 } else
                                  assert(!"unreachable");
587 case f16: vcvtph2ps(dst, src); break;
588 case s32: uni_vcvtdq2ps(dst, src); break;
589 case s8:
590 uni_vpmovsxbd(dst, src);
591 uni_vcvtdq2ps(dst_pure, dst);
592 break;
593 case u8:
594 uni_vpmovzxbd(dst, src);
595 uni_vcvtdq2ps(dst_pure, dst);
596 break;
597 default: assert(!"unreachable");
598 }
599 };
600
601 const auto cvt2odt = [=](const Xmm xmm, data_type_t odt,
602 data_type_t idt) {
603 switch (odt) {
604 case bf16:
605 if (!mayiuse(avx)) assert(!"unreachable");
606 if (utils::one_of(idt, f32, f16, s8, u8)) {
607 if (idt != f32) cvt2ps(xmm, xmm, idt);
608 if (is_superset(isa_, avx2_vnni_2)) {
609 vcvtneps2bf16(xmm, xmm, Xbyak::VexEncoding);
610 } else if (mayiuse(avx512_core_bf16)) {
611 vcvtneps2bf16(xmm, xmm);
612 } else {
613 bf16_emu_->vcvtneps2bf16(
614 Ymm(xmm.getIdx()), Zmm(xmm.getIdx()));
615 }
616 }
617 break;
618 case f16:
619 if (!mayiuse(avx)) assert(!"unreachable");
620 if (utils::one_of(idt, f32, bf16, s8, u8)) {
621 if (idt != f32) cvt2ps(xmm, xmm, idt);
622 vcvtps2ph(xmm, xmm, _op_mxcsr);
623 }
624 break;
625 case s32:
626 if (idt == f32)
627 uni_vcvtps2dq(xmm, xmm);
628 else if (idt == s8)
629 uni_vpmovsxbd(xmm, xmm);
630 else if (idt == u8)
631 uni_vpmovzxbd(xmm, xmm);
632 break;
633 case s8:
634 if (utils::one_of(idt, bf16, f16)) cvt2ps(xmm, xmm, idt);
635 if (utils::one_of(idt, f32, bf16, f16))
636 uni_vcvtps2dq(xmm, xmm);
637 if (utils::one_of(idt, bf16, f16, f32, s32)) {
638 if (mayiuse(avx512_core)) {
639 vpmovsdb(xmm, xmm);
640 } else {
641 uni_vpackssdw(xmm, xmm, xmm_zero_);
642 uni_vpacksswb(xmm, xmm, xmm_zero_);
643 }
644 }
645 if (idt == u8) uni_vpminub(xmm, xmm, xmm_4x127b_);
646 break;
647 case u8:
648 if (utils::one_of(idt, bf16, f16)) cvt2ps(xmm, xmm, idt);
649 if (utils::one_of(idt, f32, bf16, f16))
650 uni_vcvtps2dq(xmm, xmm);
651 if (utils::one_of(idt, bf16, f16, f32, s32)) {
652 if (mayiuse(avx512_core)) {
653 vpmaxsd(xmm, xmm, xmm_zero_);
654 vpmovusdb(xmm, xmm);
655 } else {
656 uni_vpackssdw(xmm, xmm, xmm_zero_);
657 uni_vpackuswb(xmm, xmm, xmm_zero_);
658 }
659 }
660 if (idt == s8) uni_vpmaxsb(xmm, xmm, xmm_zero_);
661 break;
662 default: assert(!"unreachable");
663 }
664 };
665
666 auto load = [=](const Xmm xmm, const Address &addr, int size) {
667 switch (size) {
668 case 16: uni_vmovups(xmm, addr); break;
669 case 8: uni_vmovsd(xmm, addr); break;
670 case 4: uni_vmovss(xmm, addr); break;
671 case 2: uni_vpinsrw(xmm, xmm, addr, 0x0); break;
672 case 1: uni_vpinsrb(xmm, xmm, addr, 0x0); break;
673 default: assert(!"unreachable");
674 }
675 };
676
677 auto load_bytes
678 = [=](const Xmm xmm, const Address &addr, int size, int imm) {
679 switch (size) {
680 case 4: uni_vpinsrd(xmm, xmm, addr, imm); break;
681 case 2: uni_vpinsrw(xmm, xmm, addr, imm); break;
682 case 1: uni_vpinsrb(xmm, xmm, addr, imm); break;
683 default: assert(!"unreachable");
684 }
685 };
686
687 auto store = [=](const Address &addr, const Xmm xmm, int size) {
688 switch (size) {
689 case 16: uni_vmovups(addr, xmm); break;
690 case 8: uni_vmovsd(addr, xmm); break;
691 case 4: uni_vmovss(addr, xmm); break;
692 case 2: uni_vpextrw(addr, xmm, 0x0); break;
693 case 1: uni_vpextrb(addr, xmm, 0x0); break;
694 default: assert(!"unreachable");
695 }
696 };
697
698 /* check whether loading 4 values at once is possible */
699 static constexpr int xmm_vlen = 4;
700 bool can_load_xmm = reg_unroll % xmm_vlen == 0;
701 for (int ur = 1; ur < reg_unroll; ++ur)
702 if (i_off[ur] != i_off[ur - 1] + 1) {
703 can_load_xmm = false;
704 break;
705 }
706 const int load_step = can_load_xmm ? xmm_vlen : 1;
707
708 /* check whether storing 4 values at once is possible */
709 bool can_store_xmm = reg_unroll % xmm_vlen == 0;
710 for (int ur = 1; ur < reg_unroll; ++ur)
711 if (o_off[ur] != o_off[ur - 1] + 1) {
712 can_store_xmm = false;
713 break;
714 }
715 const int ur_step = can_store_xmm ? 4 : 1;
716 const int load_tail_step
717 = !can_load_xmm && can_store_xmm ? ur_step : load_step;
718
719 const bool interim_f32 = interim_f32_needed();
720
721 const bool need_saturation
722 = (utils::one_of(prb_.otype, u8, s8, s32) && interim_f32);
723
724 std::vector<int> store_masks;
725 if (tail_processing) {
726 for (int ur = 0; ur < reg_unroll; ur += load_tail_step) {
727 uni_vpxor(Xmm(ur), Xmm(ur), Xmm(ur));
728 store_masks.push_back(0);
729 for (int r = 0; r < load_tail_step; ++r) {
730 if (zero_padding[ur + r] == 0) {
731 store_masks.back() += 1 << r;
732 load_bytes(
733 Xmm(ur), i_addr(i_off[ur + r]), itype_sz_, r);
734 }
735 }
736 }
737 } else {
738 if (!can_load_xmm && can_store_xmm) {
739 assert(ur_step == xmm_vlen);
740 /* load with stride */
741 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
742 for (int r = 0; r < ur_step; ++r) {
743 load_bytes(
744 Xmm(ur), i_addr(i_off[ur + r]), itype_sz_, r);
745 }
746 }
747 } else {
748 for (int ur = 0; ur < reg_unroll; ur += load_step) {
749 load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz_);
750 }
751 }
752 }
753
754 /* xmm[:] <-- (f32)xmm[:] */
755 if (interim_f32) {
756 const int cvt_step = nstl::max(load_step, ur_step);
757 for (int ur = 0; ur < reg_unroll; ur += cvt_step)
758 cvt2ps(Xmm(ur), Xmm(ur), prb_.itype);
759 }
760
761 if (can_load_xmm && !can_store_xmm) {
762 // transposition on the fly
763 const bool fast_return = prb_.src_scale_type != scale_type_t::MANY
764 && prb_.dst_scale_type != scale_type_t::MANY
765 && prb_.beta == 0.f;
766 if (fast_return) {
767 if (prb_.src_scale_type == scale_type_t::COMMON)
768 for (int ur = 0; ur < reg_unroll; ur += load_step)
769 uni_vmulps(Xmm(ur), Xmm(ur), xmm_src_scales_);
770 if (prb_.dst_scale_type == scale_type_t::COMMON)
771 for (int ur = 0; ur < reg_unroll; ur += load_step)
772 uni_vmulps(Xmm(ur), Xmm(ur), xmm_dst_scales_);
773 if (prb_.otype != f32) {
774 init_saturate_f32(xmm_zero_, xmm_saturation_ubound_,
775 reg_tmp_, interim_f32 ? f32 : prb_.itype,
776 prb_.otype);
777 for (int ur = 0; ur < reg_unroll; ur += load_step) {
778 if (need_saturation)
779 saturate_f32(Xmm(ur), xmm_zero_,
780 xmm_saturation_ubound_, prb_.otype);
781 cvt2odt(Xmm(ur), prb_.otype,
782 interim_f32 ? f32 : prb_.itype);
783 }
784 }
785 for (int ur = 0; ur < reg_unroll; ur += load_step) {
786 for (int r = 0; r < load_step; ++r) {
787 if (otype_sz_ == 4)
788 uni_vpextrd(o_addr(o_off[ur + r]), Xmm(ur), r);
789 else if (otype_sz_ == 2)
790 uni_vpextrw(o_addr(o_off[ur + r]), Xmm(ur), r);
791 else
792 uni_vpextrb(o_addr(o_off[ur + r]), Xmm(ur), r);
793 }
794 }
795 return;
796 }
797
798 /* scatter elements of xmm into 4 xmms */
799 if (itype_sz_ == 4 || interim_f32) {
800 for (int ur = 0; ur < reg_unroll; ur += load_step)
801 for (int r = 1; r < load_step; ++r) {
802 uni_vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
803 }
804 } else {
805 for (int ur = 0; ur < reg_unroll; ur += load_step)
806 for (int r = 1; r < load_step; ++r) {
807 if (mayiuse(avx))
808 vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur),
809 itype_sz_ * r);
810 else {
811 movups(Xmm(ur + r), Xmm(ur));
812 palignr(Xmm(ur + r), Xmm(ur), itype_sz_ * r);
813 }
814 }
815 }
816 }
817
818 /* src zero point application */
819 if (prb_.req_src_zp) {
820 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
821 const auto xmm = Xmm(ur);
822 if (interim_f32)
823 uni_vsubps(xmm, xmm, xmm_src_zp_);
824 else
825 uni_vpsubd(xmm, xmm, xmm_src_zp_);
826 }
827 }
828
829 /* scale and beta processing */
830 if (can_store_xmm) {
831 const auto apply_scales = [&](const Xmm &vreg_scales,
832 scale_arg_t scale_arg,
833 scale_type_t scale_type) {
834 if (scale_type == scale_type_t::COMMON) {
835 for (int ur = 0; ur < reg_unroll; ur += ur_step)
836 uni_vmulps(Xmm(ur), Xmm(ur), vreg_scales);
837 } else if (scale_type == scale_type_t::MANY) {
838 enum class scale_load_type_t { bcast, load, gather };
839
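                    // Pick the cheapest way to bring the per-element scales
                    // into the register: a broadcast if all offsets in the
                    // group are equal, a vector load if they are consecutive,
                    // otherwise gather them one by one with vpinsrd.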
840 uni_vpxor(vreg_scales, vreg_scales, vreg_scales);
841 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
842 scale_load_type_t scale_load_type
843 = scale_load_type_t::bcast; // the best case
844
845 for (int r = ur + 1; r < ur + ur_step; ++r)
846 if (s_off[r] != s_off[r - 1] + 0)
847 scale_load_type = scale_load_type_t::load;
848
849 if (scale_load_type == scale_load_type_t::bcast
850 && !tail_processing) {
851 uni_vbroadcastss(vreg_scales,
852 scale_arg == scale_arg_t::SRC
853 ? src_s_addr(s_off[ur])
854 : dst_s_addr(s_off[ur]));
855 uni_vmulps(Xmm(ur), Xmm(ur), vreg_scales);
856 continue;
857 }
858
859 // bcast doesn't work, the next try -- load
860 for (int r = ur + 1; r < ur + ur_step; ++r)
861 if (s_off[r] != s_off[r - 1] + 1)
862 scale_load_type = scale_load_type_t::gather;
863
864 if (scale_load_type == scale_load_type_t::load
865 && !tail_processing) {
866 uni_vmovups(vreg_scales,
867 scale_arg == scale_arg_t::SRC
868 ? src_s_addr(s_off[ur])
869 : dst_s_addr(s_off[ur]));
870 uni_vmulps(Xmm(ur), Xmm(ur), vreg_scales);
871 continue;
872 }
873
                        // load doesn't work either,
                        // so gather the scale factors one by one
876 for (int r = ur; r < ur + ur_step; ++r) {
877 if (zero_padding[r] == 0 || !tail_processing)
878 uni_vpinsrd(vreg_scales, vreg_scales,
879 scale_arg == scale_arg_t::SRC
880 ? src_s_addr(s_off[r])
881 : dst_s_addr(s_off[r]),
882 r - ur);
883 }
884 uni_vmulps(Xmm(ur), Xmm(ur), vreg_scales);
885 }
886 }
887 };
888 /* xmm <-- src_scales * xmm[:] */
889 apply_scales(
890 xmm_src_scales_, scale_arg_t::SRC, prb_.src_scale_type);
891
892 /* xmm[:] <-- beta * dst + xmm[:] */
893 assert(prb_.beta == 0.f || prb_.beta == 1.f);
894 if (prb_.beta == 1.f) {
895 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
896 if (prb_.otype == f32) {
                        /* non-VEX instructions other than movups do not
                         * support unaligned memory operands. */
899 if (mayiuse(avx)) {
900 vaddps(Xmm(ur), o_addr(o_off[ur]));
901 } else {
902 /* register xmm(1) is unused */
903 movups(Xmm(1), o_addr(o_off[ur]));
904 addps(Xmm(ur), Xmm(1));
905 }
906 } else {
907 cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype);
908 uni_vaddps(Xmm(ur), Xmm(ur), Xmm(1));
909 }
910 }
911 }
912
913 /* dst <-- dst_scales * xmm[:] */
914 apply_scales(
915 xmm_dst_scales_, scale_arg_t::DST, prb_.dst_scale_type);
916 } else {
917 const auto apply_scales
918 = [&](const Xmm &vreg_scales, scale_arg_t scale_arg,
919 scale_type_t scale_type) {
920 if (scale_type == scale_type_t::COMMON) {
921 for (int ur = 0; ur < reg_unroll; ur += ur_step)
922 uni_vmulss(Xmm(ur), Xmm(ur), vreg_scales);
923 } else if (scale_type == scale_type_t::MANY) {
924 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
925 if (zero_padding[ur] == 0 || !tail_processing)
926 uni_vmulss(Xmm(ur), Xmm(ur),
927 scale_arg == scale_arg_t::SRC
928 ? src_s_addr(s_off[ur])
929 : dst_s_addr(s_off[ur]));
930 }
931 }
932 };
933
934 /* xmm[0] <-- src_scales * xmm[0] */
935 apply_scales(
936 xmm_src_scales_, scale_arg_t::SRC, prb_.src_scale_type);
937
938 /* xmm[0] <-- beta * dst + xmm[0] */
939 assert(prb_.beta == 0.f || prb_.beta == 1.f);
940 if (prb_.beta == 1.f) {
941 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
942 if (prb_.otype == f32) {
943 uni_vaddss(Xmm(ur), Xmm(ur), o_addr(o_off[ur]));
944 } else {
945 if (prb_.otype == s32) {
946 uni_vmovss(xmm_tmp_, o_addr(o_off[ur]));
947 } else if (utils::one_of(prb_.otype, s8, u8)) {
948 uni_vpinsrb(
949 xmm_tmp_, xmm_tmp_, o_addr(o_off[ur]), 0x0);
950 } else if (utils::one_of(prb_.otype, bf16, f16)) {
951 uni_vpinsrw(
952 xmm_tmp_, xmm_tmp_, o_addr(o_off[ur]), 0x0);
953 } else {
954 assert(!"unsupported o_type");
955 }
956 cvt2ps(xmm_tmp_, xmm_tmp_, prb_.otype);
957 uni_vaddps(Xmm(ur), Xmm(ur), xmm_tmp_);
958 }
959 }
960 }
961
962 /* dst <-- dst_scales * xmm[0] */
963 apply_scales(
964 xmm_dst_scales_, scale_arg_t::DST, prb_.dst_scale_type);
965 }
966
967 /* dst zero point application */
968 if (prb_.req_dst_zp) {
969 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
970 const auto xmm = Xmm(ur);
971 if (interim_f32)
972 uni_vaddps(xmm, xmm, xmm_dst_zp_);
973 else
974 uni_vpaddd(xmm, xmm, xmm_dst_zp_);
975 }
976 }
977
978 /* adjust scale application */
979 if (prb_.scale_adjust != 1.f) {
980 uni_vmovd(xmm_tmp_, reg_scale_adjust_);
981 uni_vpshufd(xmm_tmp_, xmm_tmp_, 0x0);
982 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
983 uni_vmulps(Xmm(ur), Xmm(ur), xmm_tmp_);
984 }
985 }
986
987 if (need_saturation) {
988 init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, reg_tmp_, f32,
989 prb_.otype, compensation_needed_);
990 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
991 saturate_f32(Xmm(ur), xmm_zero_, xmm_saturation_ubound_,
992 prb_.otype, compensation_needed_);
993 }
994
995 // reset back xmm_zero_ if needed.
996 if (compensation_needed_ && (prb_.req_src_zp || prb_.req_dst_zp))
997 uni_vxorps(xmm_zero_, xmm_zero_, xmm_zero_);
998 }
999
1000 if (compensation_needed_) {
1001 const bool mayiuse_avx2 = mayiuse(avx2);
1002 const auto uni_vpaddd_wrapper
1003 = [&](const Xmm xmm, const Address &addr) {
1004 if (mayiuse_avx2)
1005 vpaddd(xmm, xmm, addr);
1006 else {
                              // ISAs below avx2 require the paddd memory operand to be aligned
1008 assert(xmm.getIdx() != xmm_tmp_.getIdx());
1009 uni_vmovups(xmm_tmp_, addr);
1010 paddd(xmm, xmm_tmp_);
1011 }
1012 };
1013 if (can_store_xmm) {
1014 enum class comp_load_type_t { bcast, load, gather };
1015
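                // The per-point compensation is accumulated with the same
                // bcast/load/gather strategy as the scales above; in the
                // bcast case the four lanes are first reduced with vphaddd
                // before being added to a single compensation entry.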
1016 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
1017
1018 bool all_ip_padding_one = true;
1019 bool all_ip_padding_zero = true;
1020 for (int r = ur; r < ur + ur_step; r++) {
1021 if (zero_padding[r] != 1)
1022 all_ip_padding_one = false;
1023 else
1024 all_ip_padding_zero = false;
1025 }
1026 if (all_ip_padding_one) continue;
1027
1028 comp_load_type_t comp_load_type = comp_load_type_t::bcast;
1029
1030 for (int r = ur + 1; r < ur + ur_step; ++r)
1031 if (c_off[r] != c_off[r - 1] + 0) {
1032 comp_load_type = comp_load_type_t::load;
1033 break;
1034 }
1035
1036 if (comp_load_type == comp_load_type_t::bcast
1037 && all_ip_padding_zero) {
1038 // xmm_compensation is used for reduction.
1039 uni_vcvtps2dq(xmm_compensation, Xmm(ur));
1040 uni_vphaddd(xmm_compensation, xmm_compensation,
1041 xmm_compensation);
1042 uni_vphaddd(xmm_compensation, xmm_compensation,
1043 xmm_compensation);
1044 const auto comp_addr = c_addr(c_off[ur]);
1045 uni_vmovss(xmm_tmp_, comp_addr);
1046 uni_vpaddd(xmm_tmp_, xmm_tmp_, xmm_compensation);
1047 uni_vmovss(comp_addr, xmm_tmp_);
1048 continue;
1049 }
1050
1051 if (comp_load_type == comp_load_type_t::load)
1052 for (int r = ur + 1; r < ur + ur_step; ++r)
1053 if (c_off[r] != c_off[r - 1] + 1) {
1054 comp_load_type = comp_load_type_t::gather;
1055 break;
1056 }
1057
1058 if (comp_load_type == comp_load_type_t::load
1059 && all_ip_padding_zero) {
1060 const auto comp_addr = c_addr(c_off[ur]);
1061 uni_vcvtps2dq(xmm_compensation, Xmm(ur));
1062 uni_vpaddd_wrapper(xmm_compensation, comp_addr);
1063 uni_vmovups(comp_addr, xmm_compensation);
1064 continue;
1065 }
1066
1067 uni_vcvtps2dq(xmm_compensation, Xmm(ur));
1068 for (int r = ur; r < ur + ur_step; ++r) {
1069 if (zero_padding[r] == 0 || !tail_processing) {
1070 uni_vshufps(xmm_tmp_, xmm_compensation,
1071 xmm_compensation, r);
1072 const Reg32 reg_tmp_32 = reg_tmp_.cvt32();
1073 uni_vmovd(reg_tmp_32, xmm_tmp_);
1074 const auto comp_addr = c_addr(c_off[r]);
1075 add(comp_addr, reg_tmp_32);
1076 }
1077 }
1078 }
1079 } else {
1080 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
1081 if (zero_padding[ur] == 0 || !tail_processing) {
1082 const auto comp_addr = c_addr(c_off[ur]);
1083 uni_vcvtps2dq(xmm_compensation, Xmm(ur));
1084 uni_vpaddd_wrapper(xmm_compensation, comp_addr);
1085 uni_vmovss(comp_addr, xmm_compensation);
1086 }
1087 }
1088 }
1089 }
1090
1091 for (int ur = 0; ur < reg_unroll; ur += ur_step) {
1092 if (prb_.req_src_zp || prb_.req_dst_zp) {
1093 const bool use_store_masks = !store_masks.empty();
1094 if (use_store_masks) {
1095 const auto mask = ~store_masks[ur / ur_step];
1096 uni_vblendps(Xmm(ur), Xmm(ur), xmm_zero_, mask);
1097 }
1098 }
1099 if (prb_.otype != f32)
1100 cvt2odt(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype);
1101
1102 store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz_);
1103 }
1104 }
1105
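    // An intermediate f32 representation is required whenever the data cannot
    // be passed through as raw integers: either side is f32, scales or beta
    // are applied, zero points are applied and the copy is not s32 -> s32,
    // compensation is computed from a non-f32 input, or scale_adjust is used.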
1106 bool interim_f32_needed() {
1107 using namespace data_type;
1108
1109 return utils::one_of(f32, prb_.itype, prb_.otype)
1110 || prb_.src_scale_type != scale_type_t::NONE
1111 || prb_.dst_scale_type != scale_type_t::NONE || prb_.beta != 0.f
1112 || ((prb_.req_src_zp || prb_.req_dst_zp)
1113 ? !(prb_.itype == s32 && prb_.otype == s32)
1114 : false)
1115 || (prb_.itype != f32 && compensation_needed_)
1116 || prb_.scale_adjust != 1.f;
1117 }
1118
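    // Generates the unrolled body for len elements in blocks of blk
    // registers. Offsets are computed incrementally: two banks of blk entries
    // are kept (curr flips between 0 and 1) so that the first offset of a new
    // block can be derived from the last offset of the previous one.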
1119 void process_unroll_generic(
1120 const int ndims, int len, const bool tail_processing) {
1121 assert(IMPLICATION(prb_.nodes[0].tail_size > 0,
1122 len == static_cast<int>(prb_.nodes[0].n)
1123 || len == static_cast<int>(prb_.nodes[0].tail_size)));
1124
1125 const int blk = 8;
1126
1127 int i_off[2 * blk] = {0};
1128 int o_off[2 * blk] = {0};
1129 int s_off[2 * blk] = {0};
1130 int c_off[2 * blk] = {0};
1131
1132 int curr = 0; // will switch between 0 and 1
1133
1134 const bool interim_f32 = interim_f32_needed();
1135
1136 if (prb_.req_src_zp) {
1137 uni_vbroadcastss(xmm_src_zp_, PARAM(src_zp));
1138 if (interim_f32) uni_vcvtdq2ps(xmm_src_zp_, xmm_src_zp_);
1139 }
1140 if (prb_.req_dst_zp) {
1141 uni_vbroadcastss(xmm_dst_zp_, PARAM(dst_zp));
1142 if (interim_f32) uni_vcvtdq2ps(xmm_dst_zp_, xmm_dst_zp_);
1143 }
1144
1145 for (int off = 0; off < len; off += blk) {
1146 const int reg_unroll = nstl::min(off + blk, len) - off;
1147 int zero_padding[blk] = {0};
1148 const auto curr_blk = curr * blk;
1149
1150 /* compute offsets and tail*/
1151 for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) {
1152 const int ur_c = curr_blk + ur;
1153 const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur
1154 const bool is_tail
1155 = off + ur >= static_cast<int>(prb_.nodes[0].tail_size);
1156 step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p],
1157 c_off[ur_p], i_off[ur_c], o_off[ur_c], s_off[ur_c],
1158 c_off[ur_c]);
1159 if (tail_processing && is_tail) zero_padding[ur] = 1;
1160 }
1161
1162 process_unroll_generic_step(reg_unroll, i_off + curr_blk,
1163 o_off + curr_blk, s_off + curr_blk, c_off + curr_blk,
1164 zero_padding, tail_processing);
1165
1166 curr = 1 - curr;
1167 }
1168 }
1169
1170 void compute_ker(
1171 const int ndims, const int len_unroll, const bool tail_processing) {
1172 bool optimized = false;
1173 optimized = optimized || process_direct_copy<avx>(ndims, len_unroll)
1174 || process_direct_copy<sse41>(ndims, len_unroll)
1175 || process_unroll_tr8x8(ndims, len_unroll);
1176 if (!optimized)
1177 process_unroll_generic(ndims, len_unroll, tail_processing);
1178 }
1179
1180 void loop_begin(Label &l, Reg64 reg_cnt, int len) {
1181 mov(reg_cnt, len);
1182 L(l);
1183 }
1184
1185 void check_if_this_is_last_chunk(const Reg64 reg_curr_chunk, int node_id) {
        // Chunks are numbered backwards, i.e.:
        // [0] -> [node_size]
        // [1] -> [node_size - 1]
        // ...
        // [node_size - 1] -> [1]

        // It is done this way because it is easier to decrement a counter and
        // check whether it is equal to zero than to increment it and check
        // whether it is equal to node_size.
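        // For example, for a node of size 4 the loop counter takes the values
        // 4, 3, 2, 1, so a chunk value of 1 identifies the last chunk.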
1195 static constexpr int64_t last_chunk = 1;
1196 cmp(reg_curr_chunk, last_chunk);
1197 }
1198
1199 void zero_dst_memory(const int bytes_to_zeroing) {
1200 static constexpr int num_of_bytes_in_xmm = 128 / 8;
1201
1202 const int xmms_to_zeroing
1203 = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).quot;
1204 const int tail_to_zeroing
1205 = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).rem;
1206
1207 uni_vpxor(xmm_tmp_, xmm_tmp_, xmm_tmp_);
1208
1209 if (xmms_to_zeroing > 0) {
1210 Label loop;
1211
1212 mov(reg_tmp_, xmms_to_zeroing);
1213 L(loop);
1214 uni_vmovups(o_addr(0), xmm_tmp_);
1215 add(reg_off_out_, num_of_bytes_in_xmm);
1216 dec(reg_tmp_);
1217 jnz(loop);
1218 }
1219
1220 for (int i = 0; i < tail_to_zeroing; i++)
1221 uni_vpextrb(o_addr(i, false), xmm_tmp_, 0);
1222
1223 // Restore dst offset to initial value
1224 if (xmms_to_zeroing > 0)
1225 sub(reg_off_out_, num_of_bytes_in_xmm * xmms_to_zeroing);
1226 }
1227
1228 void finalize_tail_loop(int i_step, int o_step, int s_step, int c_step,
1229 const int curr_node_id) {
1230 static constexpr int empty_chunk_info = -1;
1231
1232 mov(reg_tmp_, empty_chunk_info);
1233 mov(data_chunk_addr(curr_node_id), reg_tmp_);
1234
1235 const int padded_area = prb_.nodes[curr_node_id].n
1236 - prb_.nodes[curr_node_id].tail_size;
1237
1238 if (prb_.nodes[curr_node_id].is_zero_pad_needed) {
1239 int num_of_zero_padded_values = padded_area;
1240 for (int i = curr_node_id - 1; i >= 0; i--) {
1241 num_of_zero_padded_values *= prb_.nodes[i].n;
1242 }
1243
1244 const int bytes_to_zeroing = num_of_zero_padded_values * otype_sz_;
1245 zero_dst_memory(bytes_to_zeroing);
1246 }
1247
        // This function is called by loop_end. The section at the end of
        // loop_end restores the offset values based on len, which is equal to
        // prb.nodes[x].n. However, when the tail is handled, the offsets have
        // been shifted only prb.nodes[x].tail_size times. Therefore, this
        // function has to additionally shift the offsets by the zero-padded
        // area.
1256 add(reg_off_in_, padded_area * i_step * itype_sz_);
1257 add(reg_off_out_, padded_area * o_step * otype_sz_);
1258 if (prb_.src_scale_type == scale_type_t::MANY
1259 || prb_.dst_scale_type == scale_type_t::MANY)
1260 add(reg_off_scale_, padded_area * s_step * stype_sz_);
1261 if (compensation_needed_)
1262 add(reg_off_comp_, padded_area * c_step * sizeof(int32_t));
1263 }
1264
1265 void loop_end(Label &l, const Reg64 reg_cnt, int len, int i_step,
1266 int o_step, int s_step, int c_step, const int curr_node_id) {
1267 add(reg_off_in_, i_step * itype_sz_);
1268 add(reg_off_out_, o_step * otype_sz_);
1269 if (prb_.src_scale_type == scale_type_t::MANY
1270 || prb_.dst_scale_type == scale_type_t::MANY)
1271 add(reg_off_scale_, s_step * stype_sz_);
1272 if (compensation_needed_) add(reg_off_comp_, c_step * sizeof(int32_t));
1273
1274 dec(reg_cnt);
1275 jnz(l);
1276
1277 if (prb_.tail(curr_node_id) != 0) {
1278 Label if_end;
1279
            // The stack holds the information on whether the node
            // was processed with a tail or not.
1282 pop(reg_tmp_);
1283
1284 cmp(reg_tmp_, with_tail_info_);
1285 jne(if_end, T_NEAR);
1286 finalize_tail_loop(i_step, o_step, s_step, c_step, curr_node_id);
1287 L(if_end);
1288 }
1289
        // Restore the offsets to their initial values, i.e. the values
        // they had before the loop execution.
1292 sub(reg_off_in_, len * i_step * itype_sz_);
1293 sub(reg_off_out_, len * o_step * otype_sz_);
1294 if (prb_.src_scale_type == scale_type_t::MANY
1295 || prb_.dst_scale_type == scale_type_t::MANY)
1296 sub(reg_off_scale_, len * s_step * stype_sz_);
1297 if (compensation_needed_)
1298 sub(reg_off_comp_, len * c_step * sizeof(int32_t));
1299 }
1300
1301 void compute_blk_ker(const simple_impl_desc_t &desc) {
1302 static constexpr bool with_tail_processing = true;
1303 Label no_last_chunk, end_label;
1304 int omp_ndims = prb_.full_ndims - prb_.ndims;
1305
1306 if (prb_.nodes[0].tail_size > 0) {
1307 if (!prb_.nodes[0].is_parent_empty()) {
1308 const int parent_node_id = prb_.nodes[0].parent_node_id;
1309 mov(reg_tmp_, data_chunk_addr(parent_node_id));
1310 check_if_this_is_last_chunk(reg_tmp_, parent_node_id);
1311 jne(no_last_chunk, T_NEAR);
1312 }
1313
1314 const int len_unroll = desc.tail_len_unroll > 0
1315 ? desc.tail_len_unroll
1316 : desc.len_unroll;
1317 compute_ker(omp_ndims, len_unroll, with_tail_processing);
1318 jmp(end_label, T_NEAR);
1319 }
1320
1321 L(no_last_chunk);
1322 compute_ker(omp_ndims, desc.len_unroll, !with_tail_processing);
1323 L(end_label);
1324 }
1325
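    // Recursively emits the nest of runtime loops around the unrolled block
    // kernel: level jit_loop (at most ndims_jit_loop_max) wraps the levels
    // below it, and level 0 emits compute_blk_ker. For nodes that may carry a
    // tail, the loop bound is chosen at runtime depending on whether the
    // parent node is processing its last chunk, and that choice is pushed
    // onto the stack so that loop_end can restore the offsets correctly.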
1326 void create_loops(const simple_impl_desc_t &desc,
1327 const std::array<const Reg64, 3> &reg_cnt, int jit_loop) {
1328 assert(jit_loop <= ndims_jit_loop_max);
1329
1330 if (jit_loop > 0) {
1331 const int nfu = desc.ndims_full_unroll;
1332 const int unroll_factor
1333 = jit_loop == 1 ? desc.len_last_dim_unroll : 1;
1334 const int curr_node_id = nfu + (jit_loop - 1);
1335 const int parent_node_id = prb_.nodes[curr_node_id].parent_node_id;
1336 const int tail_size = prb_.tail(curr_node_id) / unroll_factor;
1337 const int node_size = prb_.n(curr_node_id) / unroll_factor;
1338 const Reg64 reg_loop_cnt = reg_cnt[jit_loop - 1];
1339 const bool curr_node_has_tail = prb_.tail(curr_node_id) != 0;
1340 Label loop, if_no_tail, if_end;
1341
1342 if (curr_node_has_tail) {
1343 if (prb_.nodes[curr_node_id].is_parent_empty()) {
1344 mov(reg_loop_cnt, tail_size);
1345 // Put info that node is being processed with tail.
1346 mov(reg_tmp_, with_tail_info_);
1347 push(reg_tmp_);
1348 } else {
1349 mov(reg_tmp_, data_chunk_addr(parent_node_id));
1350 check_if_this_is_last_chunk(reg_tmp_, parent_node_id);
1351 jne(if_no_tail, T_NEAR);
1352 mov(reg_loop_cnt, tail_size);
1353 // Put info that node is being processed with tail.
1354 mov(reg_tmp_, with_tail_info_);
1355 push(reg_tmp_);
1356 jmp(if_end, T_NEAR);
1357
1358 L(if_no_tail);
1359 mov(reg_loop_cnt, node_size);
1360 // Put info that node is being processed without tail.
1361 mov(reg_tmp_, without_tail_info_);
1362 push(reg_tmp_);
1363 L(if_end);
1364 }
1365 }
1366
1367 if (prb_.is_tail_in_one_of_child_nodes(curr_node_id)) {
1368 if (!curr_node_has_tail) {
1369 mov(reg_loop_cnt, node_size);
1370 mov(data_chunk_addr(curr_node_id), reg_loop_cnt);
1371 }
1372 L(loop);
1373 if (!prb_.nodes[curr_node_id].is_parent_empty()) {
1374 Label if_no_tail_in_child_node;
1375 mov(reg_tmp_, data_chunk_addr(parent_node_id));
1376 check_if_this_is_last_chunk(reg_tmp_, parent_node_id);
1377 jne(if_no_tail_in_child_node, T_NEAR);
1378 mov(data_chunk_addr(curr_node_id), reg_loop_cnt);
1379 L(if_no_tail_in_child_node);
1380 } else {
1381 mov(data_chunk_addr(curr_node_id), reg_loop_cnt);
1382 }
1383 } else if (curr_node_has_tail) {
1384 L(loop);
1385 } else {
1386 loop_begin(loop, reg_loop_cnt, node_size);
1387 }
1388
1389 create_loops(desc, reg_cnt, jit_loop - 1);
1390
1391 loop_end(loop, reg_loop_cnt, node_size,
1392 prb_.is(curr_node_id) * unroll_factor,
1393 prb_.os(curr_node_id) * unroll_factor,
1394 prb_.ss(curr_node_id) * unroll_factor,
1395 prb_.cs(curr_node_id) * unroll_factor, curr_node_id);
1396 } else {
1397 compute_blk_ker(desc);
1398 }
1399 }
1400
1401 bool simple_impl() {
1402 simple_impl_desc_t d;
1403 if (!simple_impl_desc_init(prb_, &d)) return false;
1404
1405 xor_(reg_off_in_, reg_off_in_);
1406 xor_(reg_off_out_, reg_off_out_);
1407 if (prb_.src_scale_type == scale_type_t::MANY
1408 || prb_.dst_scale_type == scale_type_t::MANY)
1409 xor_(reg_off_scale_, reg_off_scale_);
1410 if (compensation_needed_) xor_(reg_off_comp_, reg_off_comp_);
1411
1412 std::array<const Reg64, 3> reg_cnt({{r15, r14, r13}});
1413
1414 const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
1415 create_loops(d, reg_cnt, n_jit_loops);
1416
1417 return true;
1418 }
1419
1420 void impl() {
1421 if (simple_impl()) return;
1422 assert(!"no implementation available");
1423 }
1424
1425 jit_uni_reorder_kernel_f32_t(const desc_t &desc)
1426 : kernel_t(desc)
1427 , jit_generator(jit_name())
1428 , isa_(get_max_cpu_isa())
1429 , bf16_emu_(nullptr) {
1430 assert(!utils::one_of(isa_, isa_undef, isa_all));
1431 itype_sz_ = data_type_size(prb_.itype);
1432 otype_sz_ = data_type_size(prb_.otype);
1433 stype_sz_ = sizeof(float);
1434 if (prb_.otype == data_type::bf16 && !mayiuse(avx512_core_bf16)
1435 && !mayiuse(avx2_vnni_2)) {
1436 bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
1437 bf16_emu_reserv_1_, bf16_emu_reserv_2_, bf16_emu_reserv_3_,
1438 bf16_emu_scratch_, bf16_emu_reserv_4_);
1439 }
1440 }
1441
1442 void generate() override {
1443 Label end_of_kernel;
1444
1445 preamble();
1446
1447 if (bf16_emu_) bf16_emu_->init_vcvtneps2bf16();
1448
1449 if (prb_.src_scale_type == scale_type_t::COMMON) {
1450 auto reg_ptr_src_scales__tmp = reg_ptr_in_;
1451 mov(reg_ptr_src_scales__tmp, PARAM(src_scales));
1452 uni_vbroadcastss(xmm_src_scales_, ptr[reg_ptr_src_scales__tmp]);
1453 } else if (prb_.src_scale_type == scale_type_t::MANY) {
1454 mov(reg_ptr_src_scales_, PARAM(src_scales));
1455 }
1456
1457 if (prb_.dst_scale_type == scale_type_t::COMMON) {
1458 auto reg_ptr_dst_scales__tmp = reg_ptr_in_;
1459 mov(reg_ptr_dst_scales__tmp, PARAM(dst_scales));
1460 uni_vbroadcastss(xmm_dst_scales_, ptr[reg_ptr_dst_scales__tmp]);
1461 } else if (prb_.dst_scale_type == scale_type_t::MANY) {
1462 mov(reg_ptr_dst_scales_, PARAM(dst_scales));
1463 }
1464
1465 if (compensation_needed_)
1466 mov(reg_ptr_comp_, PARAM(compensation_scratch));
1467 if (prb_.scale_adjust == 0.5f) { mov(reg_scale_adjust_, 0x3f000000); }
1468 mov(reg_ptr_in_, PARAM(in));
1469 mov(reg_ptr_out_, PARAM(out));
1470
1471 bool is_tail_in_drv_dims = false;
1472 for (int i = prb_.ndims; i < prb_.full_ndims; i++)
1473 if (prb_.nodes[i].tail_size > 0) {
1474 is_tail_in_drv_dims = true;
1475 break;
1476 }
1477
1478 if (is_tail_in_drv_dims) {
1479 Label reorder_kernel;
1480
1481 mov(reg_tmp_, TAIL_PARAM(skip_kernel_execution));
1482 cmp(reg_tmp_, static_cast<int64_t>(true));
1483 je(end_of_kernel, T_NEAR);
1484
1485 mov(reg_tmp_, TAIL_PARAM(zeroing_data));
1486 cmp(reg_tmp_, static_cast<int64_t>(false));
1487 je(reorder_kernel, T_NEAR);
            // If the zeroing_data flag is set, all dst memory
            // is zeroed and nothing more is done.
1490 int bytes_to_zeroing = otype_sz_;
1491 for (int i = 0; i < prb_.ndims; i++) {
1492 bytes_to_zeroing *= prb_.nodes[i].n;
1493 }
1494 xor_(reg_off_out_, reg_off_out_);
1495 zero_dst_memory(bytes_to_zeroing);
1496 jmp(end_of_kernel, T_NEAR);
1497 L(reorder_kernel);
1498 }
1499
1500 if (can_do_tr8x8()) {
1501 vxorps(ymm_zero_, ymm_zero_, ymm_zero_);
1502
1503 if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
1504 mov(reg_tmp_, 0x7f7f7f7f7f7f7f7f);
1505 uni_vmovq(Xmm(ymm_8x127b_.getIdx()), reg_tmp_);
1506 }
1507 } else {
1508 uni_vxorps(xmm_zero_, xmm_zero_, xmm_zero_);
1509
1510 if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
1511 mov(reg_tmp_.cvt32(), 0x7f7f7f7f);
1512 movd(xmm_4x127b_, reg_tmp_.cvt32());
1513 }
1514 }
1515
1516 impl();
1517
1518 L(end_of_kernel);
1519 postamble();
1520 }
1521
1522 ~jit_uni_reorder_kernel_f32_t() override = default;
1523
1524#undef TAIL_PARAM
1525#undef PARAM
1526
1527private:
1528 static constexpr int64_t with_tail_info_ = static_cast<int64_t>(true);
1529 static constexpr int64_t without_tail_info_ = static_cast<int64_t>(false);
1530
1531 int itype_sz_;
1532 int otype_sz_;
1533 int stype_sz_;
1534
1535 const cpu_isa_t isa_;
1536
1537 const Reg64 reg_ptr_in_ = rsi;
1538 const Reg64 reg_ptr_out_ = rdx;
1539 const Reg64 reg_ptr_src_scales_ = abi_not_param1;
1540 const Reg64 reg_ptr_dst_scales_ = r12;
1541 const Reg64 reg_ptr_comp_ = rbx;
1542 const Reg32 &reg_scale_adjust_ = ebp;
1543
1544 const Reg64 reg_off_in_ = r8;
1545 const Reg64 reg_off_out_ = r9;
1546 const Reg64 reg_off_scale_ = r10;
1547 const Reg64 reg_off_comp_ = r11;
1548 // r13-r15 are reserved for creating loops over compute kernels...
1549
1550 const Reg64 reg_tmp_ = rax;
1551
1552 const Xmm xmm_src_scales_ = xmm15;
1553 const Xmm xmm_dst_scales_ = xmm11;
1554 const Xmm xmm_zero_ = xmm14;
1555 const Xmm xmm_4x127b_ = xmm13; // TODO: unite with ymm_zero_
1556 const Ymm ymm_zero_ = ymm14;
1557 const Ymm ymm_8x127b_ = ymm13;
1558 const Xmm xmm_tmp_ = xmm12;
1559 const Xmm xmm_src_zp_ = xmm9;
1560 const Xmm xmm_dst_zp_ = xmm10;
1561 const Xmm xmm_compensation = xmm8;
1562 const Xmm xmm_saturation_ubound_ = xmm12;
1563 const Ymm ymm_saturation_ubound_ = ymm12;
1564
1565 /* bf16 support on SKX */
1566 std::unique_ptr<bf16_emulation_t> bf16_emu_;
1567 const Zmm bf16_emu_reserv_1_ = Zmm(16);
1568 const Zmm bf16_emu_reserv_2_ = Zmm(17);
1569 const Reg64 bf16_emu_scratch_ = reg_tmp_;
1570 const Zmm bf16_emu_reserv_3_ = Zmm(18);
1571 const Zmm bf16_emu_reserv_4_ = Zmm(19);
1572};
1573
// Separate class without the unroll/threading burden
1575struct jit_single_blk_kernel_t : public jit_generator {
1576 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_single_blk_kernel)
1577 static bool applicable(const prb_t &p) {
1578 using namespace data_type;
1579
1580 bool ok = p.ndims >= 2 && mayiuse(avx2)
1581 && p.src_scale_type == scale_type_t::NONE
1582 && p.dst_scale_type == scale_type_t::NONE
1583 && utils::one_of(p.itype, f32) && utils::one_of(p.otype, f32)
1584 && utils::everyone_is(0, p.ioff, p.ooff) && p.beta == 0.f
1585 && prb_has_small_strides(p);
1586 if (!ok) return false;
1587
1588 int64_t n0 = p.nodes[0].n;
1589 auto i0 = p.nodes[0].is;
1590 auto o0 = p.nodes[0].os;
1591 int64_t n1 = p.nodes[1].n;
1592 auto i1 = p.nodes[1].is;
1593 auto o1 = p.nodes[1].os;
1594
1595 /*
         * for a plain-to-8c transpose, the nodes would look like:
1597 * n is os
1598 * m 1 8
1599 * 8 m 1
1600 * or
1601 * 8 m 1
1602 * m 1 8
1603 */
1604 ok = (utils::one_of(n0, 8, 16) || utils::one_of(n1, 8, 16))
1605 && ((i0 == 1 && o1 == 1 && n0 == i1 && o0 == n1)
1606 || (o0 == 1 && i1 == 1 && n0 == o1 && i0 == n1));
1607 if (!ok) return false;
1608
1609 // Do not handle transpose of dimensions other than last 2
1610 for (int i = 2; i < p.ndims; ++i) {
1611 if (p.nodes[i].is != p.nodes[i].os) {
1612 ok = false;
1613 break;
1614 }
1615 }
1616
1617 return ok;
1618 }
1619
1620 jit_single_blk_kernel_t(const tr::prb_t &prb)
1621 : jit_generator(jit_name())
1622 , prb_(prb)
1623 , itype_sz_(data_type_size(prb_.itype))
1624 , otype_sz_(data_type_size(prb_.otype))
1625 , block_sz(prb.nodes[0].n) {}
1626
1627 void generate() override {
1628 auto input_stride
1629 = prb_.nodes[0].is != 1 ? prb_.nodes[0].is : prb_.nodes[1].is;
1630 auto output_stride
1631 = prb_.nodes[0].os != 1 ? prb_.nodes[0].os : prb_.nodes[1].os;
1632
1633 Label tail_processing;
1634
1635 const auto load_zp = [&](const Ymm ymm_zp, const Reg64 reg_zp) {
1636 const Xmm xmm_zp = Xmm(ymm_zp.getIdx());
1637 uni_vmovq(xmm_zp, reg_zp);
1638 uni_vpbroadcastd(ymm_zp, xmm_zp);
1639 uni_vcvtdq2ps(ymm_zp, ymm_zp);
1640 };
1641
1642 preamble();
1643
1644 if (prb_.req_src_zp) load_zp(ymm_src_zp, reg_src_zp);
1645
1646 if (prb_.req_dst_zp) load_zp(ymm_dst_zp, reg_dst_zp);
1647
1648 cmp(reg_ptr_tail, true);
1649 je(tail_processing, T_NEAR);
1650
1651 if (block_sz == 8) {
1652 gen_ker8x8(0, 0, input_stride, output_stride, 8, 8);
1653 block_sz = 8;
1654 } else if (block_sz == 16) {
1655 gen_ker16x16_in_8x8(input_stride, output_stride);
1656 block_sz = 16;
1657 } else {
1658 assert(!"unimplemented");
1659 }
1660
1661 postamble();
1662
1663 L(tail_processing);
1664
1665 if (block_sz == 8) {
1666 auto i_tail = input_stride % 8 != 0 ? input_stride % 8 : 8;
1667 auto o_tail = output_stride % 8 != 0 ? output_stride % 8 : 8;
1668 if (i_tail != o_tail) {
1669 auto t_mask = i_tail == 8 ? o_tail : i_tail;
1670 gen_setmask(t_mask);
1671 gen_ker8x8(0, 0, input_stride, output_stride, i_tail, o_tail);
1672 }
1673 } else if (block_sz == 16) {
1674 auto i_tail = input_stride % 16 != 0 ? input_stride % 16 : 16;
1675 auto o_tail = output_stride % 16 != 0 ? output_stride % 16 : 16;
1676 if (i_tail != o_tail) {
1677 auto t_mask = i_tail == 16 ? o_tail : i_tail;
1678 t_mask %= 8;
1679 if (t_mask != 0) gen_setmask(t_mask);
1680 gen_ker16x16_in_8x8(
1681 input_stride, output_stride, i_tail, o_tail);
1682 }
1683 } else {
1684 assert(!"unimplemented");
1685 }
1686
1687 postamble();
1688 }
1689
1690 void gen_loadu(const Ymm ymm, const Address &addr, int size) {
1691 Xmm xmm(ymm.getIdx());
1692 switch (size) {
1693 case 32: vmovups(ymm, addr); break;
1694 case 16: vmovups(xmm, addr); break;
1695 default: assert(!"unreachable");
1696 }
1697 }
1698
1699 void gen_storeu(const Address &addr, const Ymm ymm, int size) {
1700 Xmm xmm(ymm.getIdx());
1701 switch (size) {
1702 case 32: vmovups(addr, ymm); break;
1703 case 16: vmovups(addr, xmm); break;
1704 default: assert(!"unreachable");
1705 }
1706 }
1707
1708 void gen_maskloadu(
1709 const Ymm ymm, const Address &addr, const Ymm mask, int size) {
1710 Xmm xmm(ymm.getIdx());
1711 Xmm mask128(mask.getIdx());
1712 switch (size) {
1713 case 32: vmaskmovps(ymm, mask, addr); break;
1714 case 16: vmaskmovps(xmm, mask128, addr); break;
1715 default: assert(!"unreachable");
1716 }
1717 }
1718
1719 void gen_maskstoreu(
1720 const Address &addr, const Ymm ymm, const Ymm mask, int size) {
1721 Xmm xmm(ymm.getIdx());
1722 Xmm mask128(mask.getIdx());
1723 switch (size) {
1724 case 32: vmaskmovps(addr, mask, ymm); break;
1725 case 16: vmaskmovps(addr, mask128, xmm); break;
1726 default: assert(!"unreachable");
1727 }
1728 }
1729
1730 // Register allocation xmm0~11
1731 void gen_transpose_8x8() {
1732 constexpr int lane = 8;
1733 for (int i = 0; i < lane / 2; i++) {
1734 vunpcklps(Ymm(lane + i), Ymm(2 * i), Ymm(2 * i + 1));
1735 vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1));
1736 }
1737
1738 const unsigned int lfloat = 0x44;
1739 const unsigned int ufloat = 0xee;
1740 for (int i = 0; i < lane / 2; i++) {
1741 int j = i % 2 == 0 ? lane + i : i - 1;
1742 vshufps(Ymm(lane / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat);
1743 vshufps(Ymm(lane / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat);
1744 }
1745
1746 const unsigned int lquad = 0x20;
1747 for (int i = 0; i < lane / 2; i++)
1748 vperm2f128(Ymm(i), Ymm(lane / 2 + i), Ymm(lane + i), lquad);
1749
1750 const unsigned int uquad = 0x31;
1751 for (int i = lane / 2; i < lane; i++)
1752 vperm2f128(Ymm(i), Ymm(i), Ymm(lane / 2 + i), uquad);
1753 }
1754
1755 // keep order nchw -> nChw()C
1756 // or nChw()C -> nchw
1757 void gen_setmask(int mask) {
1758 // all 0, all 1
1759 vxorps(ymm_tmp, ymm_tmp, ymm_tmp);
1760 vpcmpeqd(ymm_mask, ymm_mask, ymm_mask);
1761 // shift by mask to have tail nelems in ymm_mask
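        // e.g. mask == 3 gives in_mask == 0xF8: dwords 0..2 keep the all-ones
        // pattern (elements that will be loaded/stored), while dwords 3..7
        // are blended with zeros and thus masked off.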
1762 const uint8_t in_mask = 0xFF << mask;
1763 vpblendd(ymm_mask, ymm_mask, ymm_tmp, in_mask);
1764 }
1765
    // TODO: mark parameters with type information
    // XXX: !
    // offsets are given in bytes
    // strides are given in numbers of elements
1770 //
1771 // Gen specific 8x8 transform respect to certain tail condition
1772 void gen_tr8x8(int i_off, int o_off, int input_stride, int output_stride,
1773 int in_tail, int out_tail) {
1774 constexpr int lane = 8;
1775
1776 if (in_tail == 0 || out_tail == 0) return;
1777
1778 for (int i = 0; i < out_tail; ++i) {
1779 if (in_tail != lane) {
1780 gen_maskloadu(Ymm(i),
1781 ptr[reg_ptr_in_ + i_off + i * input_stride * itype_sz_],
1782 ymm_mask, lane * itype_sz_);
1783 } else {
1784 gen_loadu(Ymm(i),
1785 ptr[reg_ptr_in_ + i_off + i * input_stride * itype_sz_],
1786 lane * itype_sz_);
1787 }
1788 if (prb_.req_src_zp) { vsubps(Ymm(i), Ymm(i), ymm_src_zp); }
1789 }
1790
1791 gen_transpose_8x8();
1792
1793 for (int i = 0; i < in_tail; ++i) {
1794 if (prb_.req_dst_zp) { vaddps(Ymm(i), Ymm(i), ymm_dst_zp); }
1795 if (out_tail == lane) {
1796 gen_storeu(ptr[reg_ptr_out_ + o_off
1797 + i * output_stride * otype_sz_],
1798 Ymm(i), lane * otype_sz_);
1799 } else {
1800 gen_maskstoreu(ptr[reg_ptr_out_ + o_off
1801 + i * output_stride * otype_sz_],
1802 Ymm(i), ymm_mask, lane * otype_sz_);
1803 }
1804 }
1805 }
1806
    // tail: 0 ~ 8
    // Either in_tail or out_tail may be smaller than 8, but not both at once.
1809 void gen_ker8x8(int i_off, int o_off, int input_stride, int output_stride,
1810 int in_tail, int out_tail) {
1811 gen_tr8x8(i_off, o_off, input_stride, output_stride, in_tail, out_tail);
1812 }
1813
1814 void gen_ker16x16_in_8x8(int input_stride, int output_stride) {
1815 const auto lane = 16;
1816 const auto sub_lane = lane / 2;
1817 gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, sub_lane);
1818 gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_,
1819 input_stride, output_stride, sub_lane, sub_lane);
1820 gen_tr8x8(sub_lane * itype_sz_, output_stride * sub_lane * otype_sz_,
1821 input_stride, output_stride, sub_lane, sub_lane);
1822 gen_tr8x8((input_stride * sub_lane + sub_lane) * itype_sz_,
1823 (output_stride * sub_lane + sub_lane) * otype_sz_, input_stride,
1824 output_stride, sub_lane, sub_lane);
1825 }
1826
    // Tail can be 1 ~ 16; implemented with AVX2 for now.
1828 void gen_ker16x16_in_8x8(
1829 int input_stride, int output_stride, int in_tail, int out_tail) {
1830 constexpr auto lane = 16;
1831 constexpr auto sub_lane = lane / 2;
1832 auto tail = in_tail != lane ? in_tail : out_tail;
1833
1834 const auto l_tail = tail < sub_lane ? tail : sub_lane;
1835 const auto u_tail = tail < sub_lane ? 0 : tail - sub_lane;
1836
1837 if (tail == in_tail) {
1838 gen_tr8x8(0, 0, input_stride, output_stride, l_tail, sub_lane);
1839 gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_,
1840 input_stride, output_stride, l_tail, sub_lane);
1841 gen_tr8x8(sub_lane * itype_sz_,
1842 output_stride * sub_lane * otype_sz_, input_stride,
1843 output_stride, u_tail, sub_lane);
1844 gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane),
1845 otype_sz_ * (output_stride * sub_lane + sub_lane),
1846 input_stride, output_stride, u_tail, sub_lane);
1847 } else {
1848 gen_tr8x8(0, 0, input_stride, output_stride, sub_lane, l_tail);
1849 gen_tr8x8(input_stride * sub_lane * itype_sz_, sub_lane * otype_sz_,
1850 input_stride, output_stride, sub_lane, u_tail);
            gen_tr8x8(sub_lane * itype_sz_,
                    output_stride * sub_lane * otype_sz_, input_stride,
                    output_stride, sub_lane, l_tail);
1854 gen_tr8x8(itype_sz_ * (input_stride * sub_lane + sub_lane),
1855 otype_sz_ * (output_stride * sub_lane + sub_lane),
1856 input_stride, output_stride, sub_lane, u_tail);
1857 }
1858 }
1859
1860private:
1861 // 6 ~ 12
1862 constexpr static int xmm_save_for_windows = is_windows ? 7 : 0;
1863 constexpr static int xmm_save_start_from = 6;
1864 constexpr static int xmm_width = 16;
1865
1866 void preamble() {
1867 if (is_windows) {
1868 // retrieve 5th function call argument from call stack
1869 static constexpr int param5 = 0x8;
1870 mov(reg_dst_zp, ptr[rsp + param5]);
1871 sub(rsp, xmm_save_for_windows * xmm_width);
1872 for (int i = 0; i < xmm_save_for_windows; ++i) {
1873 uni_vmovdqu(ptr[rsp + i * xmm_width],
1874 Xbyak::Xmm(xmm_save_start_from + i));
1875 }
1876 }
1877 }
1878
1879 void postamble() {
1880 if (is_windows) {
1881 for (int i = 0; i < xmm_save_for_windows; ++i)
1882 uni_vmovdqu(Xbyak::Xmm(xmm_save_start_from + i),
1883 ptr[rsp + i * xmm_width]);
1884 add(rsp, xmm_save_for_windows * xmm_width);
1885 }
1886 uni_vzeroupper();
1887 ret();
1888 }
1889
1890 const prb_t &prb_;
1891
1892 int itype_sz_;
1893 int otype_sz_;
1894 int block_sz;
1895
1896 Reg64 reg_ptr_in_ = abi_param1;
1897 Reg64 reg_ptr_out_ = abi_param2;
    // The tail flag is a bool, so only the low byte of its ABI register is used.
1899 Reg8 reg_ptr_tail = is_windows ? r8b : dl;
1900 Reg64 reg_src_zp = abi_param4;
1901 Reg64 reg_dst_zp = is_windows ? r10 : r8;
1902
1903 Ymm ymm_mask = ymm12;
1904 Ymm ymm_tmp = ymm0;
1905 Ymm ymm_src_zp = ymm14;
1906 Ymm ymm_dst_zp = ymm15;
1907};
1908
1909status_t kernel_t::desc_init(
1910 kernel_t::desc_t &desc, const prb_t &prb, int ndims_ker_max) {
1911 desc.prb = prb;
1912 desc.prb.ioff = desc.prb.ooff = 0;
1913
1914 if (ndims_ker_max > prb.ndims) return status::invalid_arguments;
1915
1916 auto ndims_ker_max_f = [&]() {
1917 size_t cur_size = 1;
1918 for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n)
1919 if (cur_size >= ker_prb_size_min) return d;
1920 return prb.ndims;
1921 };
1922
1923 if (ndims_ker_max <= 0) ndims_ker_max = ndims_ker_max_f();
1924
1925 /* traverse through kernel implementations */
1926 /* TODO: find a better way to do that... */
1927 desc.id = 0;
1928 for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) {
1929 desc.prb.ndims = ndims_ker;
1930 if (jit_uni_reorder_kernel_f32_t::applicable(desc.prb))
1931 return status::success;
1932 }
1933
1934 return status::unimplemented;
1935}
1936
1937kernel_t *kernel_t::create(const kernel_t::desc_t &desc) {
1938 switch (desc.id) {
1939 case 0: return new jit_uni_reorder_kernel_f32_t(desc);
1940 default: assert(!"unknown kernel id"); return nullptr;
1941 }
1942
1943 return nullptr;
1944}
1945
1946} // namespace tr
1947
1948static void prb_block_for_cache(tr::prb_t &prb) {
    /* If the strides for the 0th and 1st nodes are cache friendly,
     * then blocking can be skipped altogether. */
1951 static constexpr int num_elems_thr = 16;
1952 const bool stride_cache_friendly
1953 = ((prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > num_elems_thr)
1954 || (prb.ndims > 1 && prb.nodes[1].is % num_elems_thr == 0
1955 && prb.nodes[1].n > num_elems_thr))
1956 && !prb.is_tail_present;
1957
1958 // performance improvement for shapes with large inner-most dimension
1959 const size_t L1_cache_sz
1960 = size_t(3) * platform::get_per_core_cache_size(1) / 4;
1961 const size_t itype_sz_ = data_type_size(prb.itype);
1962 const size_t inner_block_sz = prb.nodes[0].n * itype_sz_;
1963 const bool requires_inner_blocking = inner_block_sz > L1_cache_sz
1964 // 'is_tail_present' is not supported for cache_blocking when
1965 // asymmetric_comp is executed.
1966 && IMPLICATION(prb.req_asymmetric_comp, !prb.is_tail_present);
1967
1968 const bool cache_blocking_needed
1969 = stride_cache_friendly || requires_inner_blocking;
1970 if (!cache_blocking_needed) return;
1971
1972 int unit_input_stride_idx = -1;
1973 for (auto idx = 0; idx < prb.ndims; ++idx) {
1974 if (prb.nodes[idx].is == 1) unit_input_stride_idx = idx;
1975 }
1976
1977 /* Re-prioritize the sequential read over sequential write:
1978 * /-> [n0:is0:1][16n1:1:osk]...
1979 * [n0:is0:1]...[nk:1:osk] --> or
1980 * \-> [16n1:1:osk][n0:is0:1]... */
1981 if (unit_input_stride_idx != -1) {
1982 const auto output_stride = prb.nodes[unit_input_stride_idx].os;
1983 const auto num_elems = prb.nodes[unit_input_stride_idx].n;
1984
1985 const bool split_needed = (num_elems > num_elems_thr)
1986 && (num_elems % num_elems_thr == 0);
1987 const int move_location = (output_stride % 4 != 0) ? 0 : 1;
1988 if (split_needed)
1989 prb_node_split(prb, unit_input_stride_idx, num_elems_thr);
1990
        /* Because of the cache-unfriendly nature of the unit-output-stride
         * node, move the unit-input-stride node to (or near) the front. */
1993 if (unit_input_stride_idx != move_location)
1994 prb_node_move(prb, unit_input_stride_idx, move_location);
1995 }
1996
1997 /* Potentially, split the node with os=1 in two and pull in the node with
1998 * is=1 between them for better cache reuse:
1999 * [n0:is0:1][n1:1:os1] --> [16n0:is0:1][n1:1:os1][n0/16:is0*16:16] */
2000 if (prb.ndims >= 2 && prb.nodes[0].os == 1 && prb.nodes[1].is == 1) {
2001 const auto num_elems = prb.nodes[0].n;
2002
2003 const bool split_needed = (num_elems > num_elems_thr)
2004 && (num_elems % num_elems_thr == 0);
2005 if (split_needed) {
2006 prb_node_split(prb, 0, num_elems_thr);
2007 prb_node_move(prb, 1, 2);
2008
2009 // Update node information
2010 prb_node_dependency(prb);
2011
            // Heuristic: looping over the unrolled dims should maximize
            // reuse of the already-cached data; empirically, choosing the
            // smallest dim among the remaining ones (from 2 up to ndims)
            // gives good results.
2015 constexpr int new_position = 2;
2016 const auto dim_beg_it = std::begin(prb.nodes);
2017 const auto dim_two_it = dim_beg_it + new_position;
2018 const auto dim_last_it = dim_beg_it + prb.ndims;
2019 const auto min_n_node_it = std::min_element(dim_two_it, dim_last_it,
2020 [](const tr::node_t &lhs, const tr::node_t &rhs) {
2021 return lhs.n < rhs.n;
2022 });
2023 const auto min_idx = std::distance(dim_beg_it, min_n_node_it);
            // Check whether the min_idx node is a parent of a node with tail
            // processing, which is currently unsupported (i.e. tail
            // processing can be handled only at the innermost dimension).
2027 bool inner_block_has_tail = false;
2028 for (int idx = min_idx - 1; idx >= new_position; idx--) {
2029 if (prb.nodes[idx].parent_node_id == min_idx) {
2030 inner_block_has_tail = true;
2031 break;
2032 }
2033 }
2034
2035 if (min_idx > new_position && (!inner_block_has_tail))
2036 prb_node_move(prb, min_idx, new_position);
2037 }
2038 }
2039}
2040
/** Finds the maximum number of dimensions the kernel should process and
 * optionally splits one of the dimensions to achieve a better balance between
 * the parallel driver and the kernel. */
2044static void prb_thread_kernel_balance(
2045 tr::prb_t &prb, int &ndims_ker_max, int nthr) {
2046 size_t size_total = 1;
2047 for (int d = 0; d < prb.ndims; ++d)
2048 size_total *= prb.nodes[d].n;
2049
    /* The general expression for size_drv_thr can be written as
     * size_drv_thr = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1)
     * where FC and VC are fixed and variable costs respectively.
     * Though for now, the below heuristic seems to be good enough */
2054 const size_t size_drv_thr = (nthr > 1) ? 16 * nthr : 1;
2055
2056 /* size_drv_min is the minimal size for the parallel
2057 * driver required for good parallelization */
2058 const size_t size_drv_min
2059 = nstl::min<size_t>(size_drv_thr, utils::div_up(size_total, 1024));
2060
    /* kdims -- # of dimensions processed by the kernel
     * size_ker_cur -- product of the dimensions processed by the kernel
     * size_drv_cur -- product of the dimensions processed by the driver */
2064
2065 int kdims = prb.ndims;
2066 size_t size_drv_cur = 1;
2067 for (; kdims > 1 && size_drv_cur < size_drv_min; --kdims)
2068 size_drv_cur *= prb.nodes[kdims - 1].n;
2069
2070 size_t size_ker_cur = 1;
2071 for (int d = 0; d < kdims; ++d)
2072 size_ker_cur *= prb.nodes[d].n;
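    /* Worked example (illustrative numbers): nodes with n == {8, 16, 32, 64}
     * and nthr == 16 give size_total == 262144, size_drv_thr == 256 and
     * size_drv_min == min(256, div_up(262144, 1024)) == 256. Walking from
     * the outermost node: 64 < 256, then 64 * 32 == 2048 >= 256, so
     * kdims == 2: the driver parallelizes over the two outer nodes and the
     * kernel processes the two inner ones (size_ker_cur == 128 >= 64). */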
2073
2074 /* Initially kdims is chosen so that size_drv_cur >= size_drv_min.
2075 *
2076 * It might happen that for chosen kdims the size_ker_cur is too small
2077 * (less than tr::ker_prb_size_min). In that case try to split the
2078 * innermost driver dimension into two, to increase size_ker_cur. */
2079 const bool want_borrow_ker_from_drv = kdims < prb.ndims
2080 && size_ker_cur < tr::ker_prb_size_min
2081 && size_drv_cur > size_drv_min;
2082 if (want_borrow_ker_from_drv) {
2083 /* size_want_borrow is the minimal size, so that:
2084 * o) size_ker_cur * size_want_borrow >= tr::ker_prb_size_min
2085 * o) current innermost driver dimension is divisible by
2086 * size_want_borrow (so that we can evenly split that
2087 * dimension into two)
2088 *
2089 * In the worst case the minimal size_want_borrow is equal
2090 * to the innermost driver dimension itself. In that case
2091 * we will sacrifice it in favor of kernel (is it fine?). */
2092 size_t size_want_borrow
2093 = utils::div_up(tr::ker_prb_size_min, size_ker_cur);
2094 for (; prb.nodes[kdims].n % size_want_borrow; ++size_want_borrow)
2095 ;
2096
2097 if (size_want_borrow != prb.nodes[kdims].n)
2098 prb_node_split(prb, kdims, size_want_borrow);
2099 kdims += 1;
2100 }
2101
2102 /* On the other hand it might happen that for chosen kdims
2103 * the size_drv_cur is too small (less than size_drv_min). In that case
2104 * try to split the outermost kernel dimension into two, to increase
2105 * size_drv_cur. */
2106 const bool want_borrow_drv_from_ker = size_ker_cur > tr::ker_prb_size_min
2107 && size_drv_cur < size_drv_min;
2108 if (want_borrow_drv_from_ker) {
2109 size_t size_want_borrow = utils::div_up(size_drv_min, size_drv_cur);
2110 for (; prb.nodes[kdims - 1].n % size_want_borrow; ++size_want_borrow)
2111 ;
2112
2113 if (size_want_borrow != prb.nodes[kdims - 1].n)
2114 prb_node_split(
2115 prb, kdims - 1, prb.nodes[kdims - 1].n / size_want_borrow);
2116 }
2117
2118 ndims_ker_max = kdims;
2119
2120 if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) {
2121 DEBUG({
2122 printf("split: ");
2123 prb_dump(prb);
2124 printf("ndims_ker_max = %d\n", ndims_ker_max);
2125 });
2126 }
2127}
2128
2129status_t jit_uni_reorder_t::pd_t::init(
2130 engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
2131 CHECK(cpu_reorder_pd_t::init(engine, src_engine, dst_engine));
2132
2133 CHECK(init_scratchpad());
2134
2135 return status::success;
2136}
2137
2138status_t jit_uni_reorder_t::pd_t::init_scratchpad() {
2139 auto scratchpad = scratchpad_registry().registrar();
2140
2141 const bool compensation_needed
2142 = prb_.req_s8s8_comp || prb_.req_asymmetric_comp;
2143 if (compensation_needed) {
2144 const memory_desc_wrapper od(dst_md());
2145 const auto G = with_groups_ ? od.padded_dims()[0] : 1;
2146 const auto N = od.padded_dims()[with_groups_ ? 1 : 0];
        static constexpr int cache_line_size = 16; // in int32_t elements (64 bytes)
2148 const auto wspace_per_thr_size
2149 = utils::rnd_up(G * N, cache_line_size) * sizeof(int32_t);
2150
2151 const auto compensation_reduce_size = wspace_per_thr_size * nthr_;
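        // E.g. (hypothetical sizes) G == 1, N == 100: rnd_up(100, 16) == 112
        // int32_t elements == 448 bytes per thread, times nthr_ threads.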
2152
        // Every thread gets its own scratchpad slice covering all G * N values.
2154 scratchpad.template book<int32_t>(
2155 memory_tracking::names::key_reorder_space,
2156 compensation_reduce_size);
2157 }
2158
2159 const memory_desc_wrapper input_d(src_md());
2160 int scales_mask = -1;
2161 bool is_set = false;
2162 CHECK(attr()->scales_.get(DNNL_ARG_DST, &scales_mask, &is_set));
2163
2164 if (is_set && scales_mask > 0) {
2165 get_D_values(input_d, scales_mask, nullptr, &D_mask_, nullptr);
2166 if (D_mask_ > 1) {
2167 scratchpad.template book<float>(
2168 memory_tracking::names::key_reorder_precomputed_dst_scales,
2169 D_mask_);
2170 }
2171 }
2172
2173 return status::success;
2174}
2175
2176status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd,
2177 engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine,
2178 const memory_desc_t *src_md, engine_t *dst_engine,
2179 const memory_desc_t *dst_md) {
2180 auto prb = tr::prb_t();
2181
2182 status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr);
2183 if (prb_init_status != status::success) return prb_init_status;
2184
2185 prb_block_for_cache(prb);
2186 DEBUG({
2187 printf("cache: ");
2188 prb_dump(prb);
2189 });
2190
2191 int ndims_ker_max {};
2192 int nthr = dnnl_get_max_threads();
2193 prb_thread_kernel_balance(prb, ndims_ker_max, nthr);
2194
2195 if (prb.is_tail_present) prb_node_dependency(prb);
2196
2197 tr::kernel_t::desc_t ker_desc;
2198 status_t ker_init_status
2199 = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max);
2200 if (ker_init_status != status::success) return ker_init_status;
2201
2202 const int ndims_driver = prb.ndims - ker_desc.prb.ndims;
2203 if (ndims_driver > jit_uni_reorder_t::ndims_driver_max)
2204 return status::unimplemented;
2205
2206 DEBUG({
2207 printf("ker : ");
2208 prb_dump(ker_desc.prb);
2209 });
2210
2211 auto _pd = new pd_t(
2212 attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md);
2213 if (_pd == nullptr) return status::out_of_memory;
2214
2215 _pd->nthr_ = nthr;
2216 _pd->prb_ = prb;
2217 _pd->with_groups_
2218 = prb.compensation_mask == tr::prb_t::comp_mask_with_groups;
2219 if (_pd->init(engine, src_engine, dst_engine) != status::success) {
2220 delete _pd;
2221 return status::unimplemented;
2222 }
2223 _pd->ker_desc_ = ker_desc;
2224 _pd->init_scratchpad_md();
2225
2226 return safe_ptr_assign(*reorder_pd, _pd);
2227}
2228
2229void jit_uni_reorder_t::omp_driver_0d(int off, const char *in, char *out,
2230 const float *src_scales, const float *dst_scales, int src_zp,
2231 int dst_zp, int32_t *compensation_scratch) const {
2232 const tr::prb_t &prb = pd()->prb_;
2233
2234 tr::call_param_t base_params;
2235 base_params.in = in;
2236 base_params.out = out;
2237 base_params.src_scales = src_scales;
2238 base_params.dst_scales = dst_scales;
2239 base_params.src_zp = src_zp;
2240 base_params.dst_zp = dst_zp;
2241 base_params.compensation_scratch = compensation_scratch;
2242
2243 if (prb.is_tail_present) {
2244 tr::tail_call_param_t tail_params;
2245 tail_params.base_params = base_params;
2246
2247 static constexpr int omp_ndims = 0;
2248 fill_curr_data_chunks(prb, off, nullptr, omp_ndims, tail_params);
2249
2250 (*kernel_)(&tail_params);
2251 } else {
2252 (*kernel_)(&base_params);
2253 }
2254}
2255
2256void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off,
2257 const char *in, char *out, const float *src_scales,
2258 const float *dst_scales, int src_zp, int dst_zp,
2259 int32_t *compensation_scratch) const {
2260 const tr::prb_t &prb = pd()->prb_;
2261 const tr::node_t *ns = prb.nodes + off;
2262 for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) {
2263 tr::call_param_t base_params;
2264 base_params.in = in + d0 * ns[0].is * data_type_size(prb.itype);
2265 base_params.out = out + d0 * ns[0].os * data_type_size(prb.otype);
2266 base_params.src_scales = src_scales + d0 * ns[0].ss;
2267 base_params.dst_scales = dst_scales + d0 * ns[0].ss;
2268 base_params.src_zp = src_zp;
2269 base_params.dst_zp = dst_zp;
2270 base_params.compensation_scratch = compensation_scratch + d0 * ns[0].cs;
2271
2272 if (prb.is_tail_present) {
2273 tr::tail_call_param_t tail_params;
2274 tail_params.base_params = base_params;
2275
2276 static constexpr int omp_ndims = 1;
2277 const ptrdiff_t omp_data_chunks[omp_ndims] = {d0};
2278 fill_curr_data_chunks(
2279 prb, off, omp_data_chunks, omp_ndims, tail_params);
2280
2281 (*kernel_)(&tail_params);
2282 } else {
2283 (*kernel_)(&base_params);
2284 }
2285 });
2286}
2287
2288void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off,
2289 const char *in, char *out, const float *src_scales,
2290 const float *dst_scales, int src_zp, int dst_zp,
2291 int32_t *compensation_scratch) const {
2292 const tr::prb_t &prb = pd()->prb_;
2293 const tr::node_t *ns = prb.nodes + off;
2294 for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
2295 [&](ptrdiff_t d1, ptrdiff_t d0) {
2296 tr::call_param_t base_params;
2297 base_params.in = in
2298 + (d0 * ns[0].is + d1 * ns[1].is)
2299 * data_type_size(prb.itype);
2300 base_params.out = out
2301 + (d0 * ns[0].os + d1 * ns[1].os)
2302 * data_type_size(prb.otype);
2303 base_params.src_scales
2304 = src_scales + d0 * ns[0].ss + d1 * ns[1].ss;
2305 base_params.dst_scales
2306 = dst_scales + d0 * ns[0].ss + d1 * ns[1].ss;
2307 base_params.src_zp = src_zp;
2308 base_params.dst_zp = dst_zp;
2309 base_params.compensation_scratch
2310 = compensation_scratch + d0 * ns[0].cs + d1 * ns[1].cs;
2311
2312 if (prb.is_tail_present) {
2313 tr::tail_call_param_t tail_params;
2314 tail_params.base_params = base_params;
2315
2316 static constexpr int omp_ndims = 2;
2317 const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1};
2318 fill_curr_data_chunks(
2319 prb, off, omp_data_chunks, omp_ndims, tail_params);
2320
2321 (*kernel_)(&tail_params);
2322 } else {
2323 (*kernel_)(&base_params);
2324 }
2325 });
2326}
2327
2328void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off,
2329 const char *in, char *out, const float *src_scales,
2330 const float *dst_scales, int src_zp, int dst_zp,
2331 int32_t *compensation_scratch) const {
2332 const tr::prb_t &prb = pd()->prb_;
2333 const tr::node_t *ns = prb.nodes + off;
2334 for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n,
2335 (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
2336 tr::call_param_t base_params;
2337 base_params.in = in
2338 + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is)
2339 * data_type_size(prb.itype);
2340 base_params.out = out
2341 + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os)
2342 * data_type_size(prb.otype);
2343 base_params.src_scales = src_scales + d0 * ns[0].ss
2344 + d1 * ns[1].ss + d2 * ns[2].ss;
2345 base_params.dst_scales = dst_scales + d0 * ns[0].ss
2346 + d1 * ns[1].ss + d2 * ns[2].ss;
2347 base_params.src_zp = src_zp;
2348 base_params.dst_zp = dst_zp;
2349 base_params.compensation_scratch = compensation_scratch
2350 + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs;
2351
2352 if (prb.is_tail_present) {
2353 tr::tail_call_param_t tail_params;
2354 tail_params.base_params = base_params;
2355
2356 static constexpr int omp_ndims = 3;
2357 const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1, d2};
2358 fill_curr_data_chunks(
2359 prb, off, omp_data_chunks, omp_ndims, tail_params);
2360
2361 (*kernel_)(&tail_params);
2362 } else {
2363 (*kernel_)(&base_params);
2364 }
2365 });
2366}
2367
2368void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off,
2369 const char *in, char *out, const float *src_scales,
2370 const float *dst_scales, int src_zp, int dst_zp,
2371 int32_t *compensation_scratch) const {
2372 const tr::prb_t &prb = pd()->prb_;
2373 const tr::node_t *ns = prb.nodes + off;
2374 for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n,
2375 (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
2376 [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
2377 tr::call_param_t base_params;
2378 base_params.in = in
2379 + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is
2380 + d3 * ns[3].is)
2381 * data_type_size(prb.itype);
2382 base_params.out = out
2383 + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os
2384 + d3 * ns[3].os)
2385 * data_type_size(prb.otype);
2386 base_params.src_scales = src_scales + d0 * ns[0].ss
2387 + d1 * ns[1].ss + d2 * ns[2].ss + d3 * ns[3].ss;
2388 base_params.dst_scales = dst_scales + d0 * ns[0].ss
2389 + d1 * ns[1].ss + d2 * ns[2].ss + d3 * ns[3].ss;
2390 base_params.src_zp = src_zp;
2391 base_params.dst_zp = dst_zp;
2392 base_params.compensation_scratch = compensation_scratch
2393 + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs
2394 + d3 * ns[3].cs;
2395
2396 if (prb.is_tail_present) {
2397 tr::tail_call_param_t tail_params;
2398 tail_params.base_params = base_params;
2399
2400 static constexpr int omp_ndims = 4;
2401 const ptrdiff_t omp_data_chunks[omp_ndims]
2402 = {d0, d1, d2, d3};
2403 fill_curr_data_chunks(
2404 prb, off, omp_data_chunks, omp_ndims, tail_params);
2405
2406 (*kernel_)(&tail_params);
2407 } else {
2408 (*kernel_)(&base_params);
2409 }
2410 });
2411}
2412
2413void jit_uni_reorder_t::omp_driver(const char *in, char *out,
2414 const float *src_scales, const float *dst_scales, int src_zp,
2415 int dst_zp, const memory_tracking::grantor_t &scratchpad) const {
2416 in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype);
2417 out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype);
2418
2419 DEBUG({
2420 printf("prb : ");
2421 tr::prb_dump(pd()->prb_);
2422 });
2423 DEBUG({
2424 printf("ker : ");
2425 tr::prb_dump(pd()->ker_desc_.prb);
2426 });
2427
2428 int ndims = pd()->prb_.ndims;
2429 int ndims_ker = pd()->ker_desc_.prb.ndims;
2430 const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp;
2431 const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp;
2432 const bool req_compensation = req_s8s8_comp || req_asymmetric_comp;
2433 assert(ndims - ndims_ker <= ndims_driver_max);
2434
2435 int32_t *compensation_reduce_scratch = scratchpad.template get<int32_t>(
2436 memory_tracking::names::key_reorder_space);
2437
2438 const memory_desc_wrapper od(pd()->dst_md());
2439 const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1;
2440 const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0];
    static constexpr int cache_line_size = 16; // in int32_t elements (64 bytes)
2442 const auto wspace_per_thr_size = utils::rnd_up(G * N, cache_line_size);
2443 const auto wspace_per_thr_bytes = wspace_per_thr_size * sizeof(int32_t);
2444
2445 if (ndims - ndims_ker == 0) {
2446 if (req_compensation)
2447 std::memset(compensation_reduce_scratch, 0, wspace_per_thr_bytes);
2448
2449 omp_driver_0d(ndims_ker, in, out, src_scales, dst_scales, src_zp,
2450 dst_zp, compensation_reduce_scratch);
2451 } else {
2452 parallel(pd()->nthr_, [&](const int ithr, const int nthr) {
2453 int32_t *compensation_scratch = nullptr;
2454 if (req_compensation) {
2455 compensation_scratch = &compensation_reduce_scratch[ithr
2456 * wspace_per_thr_size];
2457 std::memset(compensation_scratch, 0, wspace_per_thr_bytes);
2458 }
2459
2460 switch (ndims - ndims_ker) {
2461 case 1:
2462 omp_driver_1d(ithr, nthr, ndims_ker, in, out, src_scales,
2463 dst_scales, src_zp, dst_zp, compensation_scratch);
2464 break;
2465 case 2:
2466 omp_driver_2d(ithr, nthr, ndims_ker, in, out, src_scales,
2467 dst_scales, src_zp, dst_zp, compensation_scratch);
2468 break;
2469 case 3:
2470 omp_driver_3d(ithr, nthr, ndims_ker, in, out, src_scales,
2471 dst_scales, src_zp, dst_zp, compensation_scratch);
2472 break;
2473 case 4:
2474 omp_driver_4d(ithr, nthr, ndims_ker, in, out, src_scales,
2475 dst_scales, src_zp, dst_zp, compensation_scratch);
2476 break;
2477 default: assert(!"unimplemented");
2478 }
2479 });
2480 }
2481
    // Reduction of intermediate compensation results to the final output.
2483 if (req_compensation) {
2484 const int nthr = ndims - ndims_ker == 0 ? 1 : pd()->nthr_;
2485 reduce_compensation(
2486 out, compensation_reduce_scratch, nthr, wspace_per_thr_size);
2487 }
2488}
2489
2490void jit_uni_reorder_t::reduce_compensation(char *out,
2491 const int32_t *compensation_reduce_scratch, const int nthr,
2492 const dim_t wspace_per_thr_size) const {
2493
2494 const memory_desc_wrapper od(pd()->dst_md());
2495 const size_t offset = od.size() - od.additional_buffer_size();
2496
2497 static constexpr auto comp_dt_size = sizeof(int32_t);
2498 static constexpr int32_t comp_s8s8_shift = 128;
2499
    // Note: we do not need to explicitly zero out the compensation buffer,
    // as the per-thread buffers are already zeroed out in the padded area.
2502 const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1;
2503 const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0];
2504 const auto GN = G * N;
2505 const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp;
2506 const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp;
2507 const size_t zp_offset
2508 = offset + (pd()->prb_.req_s8s8_comp ? GN * comp_dt_size : 0);
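    // Compensation layout at the end of the destination buffer (when both
    // kinds are requested): GN int32_t s8s8-compensation values at `offset`,
    // followed by GN int32_t zero-point-compensation values at `zp_offset`.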
2509
2510 parallel_nd(GN, [&](int idx) {
2511 int32_t acc = 0;
2512 for (int ithr = 0; ithr < nthr; ithr++) {
2513 acc -= compensation_reduce_scratch[ithr * wspace_per_thr_size
2514 + idx];
2515 }
2516 if (req_s8s8_comp) {
2517 int32_t *out_comp = reinterpret_cast<int32_t *>(&out[offset]);
2518 out_comp[idx] = comp_s8s8_shift * acc;
2519 }
2520 if (req_asymmetric_comp) {
2521 int32_t *out_asym_comp
2522 = reinterpret_cast<int32_t *>(&out[zp_offset]);
2523 out_asym_comp[idx] = acc;
2524 }
2525 });
2526}
2527
2528void jit_uni_reorder_t::fill_curr_data_chunks(const tr::prb_t &prb,
2529 const int off, const ptrdiff_t *omp_data_chunks, const int omp_ndims,
2530 tr::tail_call_param_t &c) const {
    // Chunks are numbered backwards, i.e.:
    // [0] -> [node_size]
    // [1] -> [node_size - 1]
    // ...
    // [node_size - 1] -> [1]
    //
    // This is done because, in the JIT kernel, it is easier to decrement a
    // counter and compare it with zero than to increment it and compare it
    // with node_size.
2540
2541 static constexpr int64_t empty_chunk_info = -1;
2542 static constexpr int64_t last_chunk = 1;
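    // Example: for a node with n == 5 and no tail, the driver chunk d == 2
    // is stored as 5 - 2 == 3, and the last chunk d == 4 is stored as 1
    // (== last_chunk), matching the decrement-and-compare-with-zero scheme
    // described above.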
2543
2544 for (int curr_node_id = prb.ndims - 1; curr_node_id >= 0; curr_node_id--) {
2545 const int parent_node_id = prb.nodes[curr_node_id].parent_node_id;
2546 const bool is_drv_processing_this_node
2547 = curr_node_id >= off && curr_node_id <= off + omp_ndims - 1;
2548 const bool is_tail_processing
2549 = prb.is_tail_in_one_of_child_nodes(curr_node_id)
2550 || prb.nodes[curr_node_id].tail_size > 0;
2551
2552 if (is_drv_processing_this_node && is_tail_processing) {
2553 const int inner_idx = curr_node_id - off;
2554 assert(inner_idx < omp_ndims);
2555 const int64_t node_size = prb.nodes[curr_node_id].tail_size > 0
2556 ? prb.nodes[curr_node_id].tail_size
2557 : prb.nodes[curr_node_id].n;
2558 const int64_t data_chunk = node_size - omp_data_chunks[inner_idx];
2559
2560 if (!prb.nodes[curr_node_id].is_parent_empty()) {
2561 const bool is_parent_chunk_last
2562 = c.curr_data_chunks[parent_node_id] == last_chunk;
2563 c.curr_data_chunks[curr_node_id]
2564 = is_parent_chunk_last ? data_chunk : empty_chunk_info;
2565 c.zeroing_data = static_cast<int64_t>(
2566 is_parent_chunk_last && data_chunk <= 0);
2567 } else {
2568 c.curr_data_chunks[curr_node_id] = data_chunk;
2569 c.zeroing_data = static_cast<int64_t>(data_chunk <= 0);
2570 }
2571 c.skip_kernel_execution = static_cast<int64_t>(c.zeroing_data
2572 && !prb.nodes[curr_node_id].is_zero_pad_needed);
2573 if (c.zeroing_data || c.skip_kernel_execution) break;
2574 } else
2575 c.curr_data_chunks[curr_node_id] = empty_chunk_info;
2576 }
2577}
2578
2579status_t jit_uni_reorder_t::init(engine_t *engine) {
2580 CHECK(safe_ptr_assign(kernel_, tr::kernel_t::create(pd()->ker_desc_)));
2581 return kernel_->create_kernel();
2582}
2583
2584status_t jit_uni_reorder_t::execute(const exec_ctx_t &ctx) const {
2585 const auto &scratchpad = ctx.get_scratchpad_grantor();
2586
2587 auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM);
2588 auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO);
2589 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
2590 DEFINE_ARG_SCALES_BUFFER(dst_scales_, DNNL_ARG_DST);
2591
2592 const float *dst_scales = pd()->precompute_scales(
2593 scratchpad, pd()->attr(), pd()->D_mask_, dst_scales_);
2594 assert(dst_scales);
2595
2596 DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM);
2597 DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO);
2598
2599 omp_driver(in, out, src_scales, dst_scales, src_zp, dst_zp, scratchpad);
2600
2601 return status::success;
2602}
2603
2604status_t jit_blk_reorder_t::pd_t::create(reorder_pd_t **reorder_pd,
2605 engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine,
2606 const memory_desc_t *src_md, engine_t *dst_engine,
2607 const memory_desc_t *dst_md) {
2608 auto prb = tr::prb_t();
2609
2610 status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr);
2611 if (prb_init_status != status::success) return prb_init_status;
    // Only jit_uni_reorder supports tail processing for now.
    // TODO: add tail processing support to jit_blk_reorder.
2614 if (prb.is_tail_present) return status::unimplemented;
2615
2616 prb_tile_normalize(prb);
2617 DEBUG({
2618 printf("tile : ");
2619 prb_dump(prb);
2620 });
2621
2622 if (!tr::jit_single_blk_kernel_t::applicable(prb)) {
2623 return status::unimplemented;
2624 }
2625
2626 auto _pd = new pd_t(
2627 attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md);
2628 if (_pd == nullptr) return status::out_of_memory;
2629 _pd->prb_ = prb;
2630 if (_pd->init(engine, src_engine, dst_engine) != status::success) {
2631 delete _pd;
2632 return status::unimplemented;
2633 }
2634 _pd->init_scratchpad_md();
2635
2636 return safe_ptr_assign(*reorder_pd, _pd);
2637}
2638
2639void jit_blk_reorder_t::pd_t::prb_tile_normalize(tr::prb_t &p) {
2640 if (!utils::one_of(p.nodes[0].n, 8ul, 16ul)
2641 && utils::one_of(p.nodes[1].n, 8ul, 16ul)) {
2642 nstl::swap(p.nodes[0], p.nodes[1]);
2643 }
2644}
2645
2646jit_blk_reorder_t::jit_blk_reorder_t(const pd_t *apd) : primitive_t(apd) {}
2647jit_blk_reorder_t::~jit_blk_reorder_t() = default;
2648
2649status_t jit_blk_reorder_t::init(engine_t *engine) {
2650 kernel_ = utils::make_unique<tr::jit_single_blk_kernel_t>(pd()->prb_);
2651 return kernel_->create_kernel();
2652}
2653
2654status_t jit_blk_reorder_t::execute(const exec_ctx_t &ctx) const {
2655 const auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM);
2656 auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO);
2657 DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM);
2658 DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO);
2659
    // The kernel handles 2-dimensional tiles; a tail is possible.
2661 auto &prb = this->pd()->prb_;
2662 ptrdiff_t BH = 1;
2663 for (int i = 2; i < prb.ndims; ++i) {
2664 BH *= prb.nodes[i].n;
2665 }
2666
2667 auto block_sz = prb.n(0);
2668 auto n1 = prb.n(1);
2669 auto i1 = prb.is(1);
2670 auto o1 = prb.os(1);
2671 auto FL = (n1 + block_sz - 1) / block_sz;
2672 auto bh_stride = BH == 1 ? 0 : prb.is(2);
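    // E.g. (hypothetical shape) nodes n == {16, 100, 3}: block_sz == 16,
    // FL == div_up(100, 16) == 7 tiles along dim 1, BH == 3 outer
    // iterations; the last tile of each row is a tail of 100 - 6 * 16 == 4
    // elements, which the kernel is told about via `n1 - fl_b < block_sz`.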
2673
2674 auto itype_sz_ = data_type_size(pd()->prb_.itype);
2675 auto otype_sz_ = data_type_size(pd()->prb_.otype);
2676
2677 parallel_nd(BH, FL, [&](dim_t bh, dim_t fl) {
2678 auto fl_b = fl * block_sz;
2679 auto bh_b = bh_stride * bh;
2680 auto *i = in + (bh_b + fl_b * i1) * itype_sz_;
2681 auto *o = out + (bh_b + fl_b * o1) * otype_sz_;
2682 (*kernel_)(i, o, n1 - fl_b < block_sz, src_zp, dst_zp);
2683 });
2684
2685 return status::success;
2686}
2687
2688} // namespace x64
2689} // namespace cpu
2690} // namespace impl
2691} // namespace dnnl
2692