/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_transpose_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;
using namespace Xbyak;

#define GET_OFF(x) offsetof(ctx_t, x)

struct jit_trans_iw_ic_int16_t : public jit_trans_src_t, public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_int16_t)
    jit_trans_iw_ic_int16_t(const jit_conv_conf_t *conf)
        : jit_trans_src_t(conf), jit_generator(jit_name()) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }

    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;

    enum {
        typesize = sizeof(int16_t),
        transpose_size = 16,
        small_spatial = 14
    };
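    // typesize: size of one int16 element; transpose_size: the kernel works
    // on 16x16 element tiles; small_spatial: spatial widths at or below this
    // are considered too small to benefit from prefetching (see
    // enable_prefetch below).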
    size_t src_stride = 0, tr_src_stride = 0;
    int tail = 0;
    bool enable_prefetch = false;

    opmask_t kFFFF = k1;
    opmask_t k5555 = k2;
    opmask_t kAAAA = k3;
    opmask_t kAA = k4;
    opmask_t k55 = k5;
    opmask_t kCC = k6;
    opmask_t k33 = k7;
    opmask_t kTail = k1;
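    // Note: kTail aliases kFFFF (both map to k1). This is safe because the
    // store path reloads kTail before each masked store, after all
    // kFFFF-masked loads have been emitted.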

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_src_prf = r10;
    reg64_t reg_tr_src_prf = r11;
    reg64_t reg_loop = r12;
    reg64_t reg_tr_src_tmp = r13;
    reg32_t regw_tmp = r14d;
    reg64_t imm_addr64 = rbx;

    Xbyak::Zmm vidx1 = zmm31;
    Xbyak::Zmm vidx2 = zmm30;
    Xbyak::Zmm vidx3 = zmm29;
    Xbyak::Zmm vidx4 = zmm28;
    Xbyak::Zmm vidx5 = zmm27;
    Xbyak::Zmm zmm_tmp = zmm26;

    void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores);
    void generate() override;
};

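// Transposes an nrows x ic_block tile of int16 elements: loads up to 16 rows,
// interleaves row pairs with vpermw/vpermps, completes the transpose with the
// masked "swap" stages below, and stores the resulting columns, zero-filling
// the l_pad/r_pad output positions.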
void jit_trans_iw_ic_int16_t::transpose(
        int nrows, int l_pad, int r_pad, bool nontemporal_stores) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    auto src_zmm = [=](int i) { return Zmm(i); };

    auto src_ymm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Ymm(i);
    };

    auto load_ymm = [=](int i) {
        vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride));
    };

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    auto kmovd = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovd(k, regw_tmp);
    };

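    // Stores one transposed output row: applies the tail mask and, when
    // l_pad/r_pad are present, zero-fills the padded positions first.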
    auto store = [=](Zmm r, int i) {
        auto padding = [=](Reg64 base, int pad_rows, int pad_tail) {
            // note: pad can be bigger than 16 because of dilation
            const size_t row_offset = 2 * transpose_size * typesize;
            auto zmm_zero = zmm_tmp;
            vpxord(zmm_zero, zmm_zero, zmm_zero);
            for (int i_row = 0; i_row < pad_rows; i_row++) {
                auto addr = EVEX_compress_addr(
                        base, i * tr_src_stride + i_row * row_offset);
                vmovups(addr, zmm_zero);
            }
            if (pad_tail > 0) {
                kmovw(kTail, (1 << pad_tail) - 1);
                base.setOpmaskIdx(kTail.getIdx(), true);
                auto addr = EVEX_compress_addr(
                        base, i * tr_src_stride + pad_rows * row_offset);
                vmovups(addr, zmm_zero);
            }
        };

        mov(reg_tr_src_tmp, reg_tr_src);
        if (l_pad > 0) {
            int store_pad = 2 * transpose_size;
            int pad_rows = l_pad / store_pad;
            int tail = l_pad % store_pad;
            padding(reg_tr_src_tmp, pad_rows, div_up(tail, 2));
            add(reg_tr_src_tmp, (pad_rows * store_pad + tail) * typesize);
        }
        if (r_pad > 0) {
            int addr_shift = nrows - r_pad % 2;
            int store_pad = div_up(r_pad, 2);
            int pad_rows = store_pad / transpose_size;
            add(reg_tr_src_tmp, addr_shift * typesize);
            padding(reg_tr_src_tmp, pad_rows, store_pad % transpose_size);
            sub(reg_tr_src_tmp, addr_shift * typesize);
        }

        int store_tail = rnd_up(nrows, 2);
        kmovw(kTail, (1 << store_tail / 2) - 1);
        auto k = kTail;
        auto base = reg_tr_src_tmp;
        base.setOpmaskIdx(k.getIdx(), true);

        auto addr = EVEX_compress_addr(base, i * tr_src_stride);
        vmovups(addr, r);
    };

    const bool is_layout_nxc = utils::one_of(conf_->src_tag, format_tag::ndhwc,
            format_tag::nhwc, format_tag::nwc);
    const int ic_block = conf_->ic_block;
    const bool is_tail_block = ic_block != 16;
    const int ic_tail = conf_->ic_tail;
    // Assertion below as we need vmovdqu16 for ic_tails.
    // If needed, can be extended by using load_bytes() helper.
    assert(IMPLICATION(ic_tail, mayiuse(avx512_core)));
    if (mayiuse(avx512_core)) {
        if (conf_->stride_w > 1 || nrows % 2 || is_layout_nxc)
            kmovd(kFFFF, (1 << ic_block) - 1);
        if (conf_->stride_w > 1 || is_layout_nxc) kmovd(k33, 0xffff0000);
        if (is_layout_nxc && conf_->ic_tail) {
            Label done;
            cmp(dword[param1 + GET_OFF(ch_work)], ic_block);
            je(done, T_NEAR);
            kmovd(kFFFF, (1 << conf_->ic_tail) - 1);
            kshiftld(k33, kFFFF, 16);
            L(done);
        }

        for (int i = 0; i < nrows / 2; i++) {
            auto zmm_src0 = src_zmm(2 * i);
            if (conf_->stride_w == 1 && !is_layout_nxc) {
                vmovdqu16(zmm_src0,
                        EVEX_compress_addr(reg_src, 2 * i * src_stride));
            } else {
                vmovdqu16(zmm_src0 | kFFFF | T_z,
                        EVEX_compress_addr(reg_src, 2 * i * src_stride));
                if (is_tail_block || ic_tail) {
                    auto zmm_tmp = src_zmm(2 * i + 1);
                    vmovdqu16(zmm_tmp | kFFFF | T_z,
                            EVEX_compress_addr(
                                    reg_src, (2 * i + 1) * src_stride));
                    vinsertf64x4(zmm_src0, zmm_src0, src_ymm(2 * i + 1), 1);
                } else {
                    vmovdqu16(zmm_src0 | k33,
                            EVEX_compress_addr(
                                    reg_src, (2 * i + 1) * src_stride - 32));
                }
            }
            vpermw(zmm_src0, vidx5, zmm_src0);
        }

        // for odd numbers we need to mix row with zeroes
        if (nrows % 2) {
            int i = nrows / 2;
            auto zmm_src0 = src_zmm(2 * i);
            vmovdqu16(zmm_src0 | kFFFF | T_z,
                    EVEX_compress_addr(reg_src, 2 * i * src_stride));
            vpermw(zmm_src0, vidx5, zmm_src0);
        }

        if (conf_->stride_w > 1 || is_layout_nxc) kmovw(k33, 0x33);

        for (int i = rnd_up(nrows, 2); i < 16; i += 2) {
            vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
        }
    } else {
        kmovw(kFFFF, 0xffff);
        // all loads
        for (int i = 0; i < 16; i++) {
            vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
        }

        for (int i = 0; i < nrows / 2; i++) {
            auto src0 = src_ymm(2 * i);
            auto src1 = src_ymm(2 * i + 1);
            auto zmm_src0 = src_zmm(2 * i);
            load_ymm(2 * i);

            vpunpcklwd(src1, src0,
                    EVEX_compress_addr(reg_src, (2 * i + 1) * src_stride));
            vpunpckhwd(src0, src0,
                    EVEX_compress_addr(reg_src, (2 * i + 1) * src_stride));
            vinserti64x4(zmm_src0, zmm_src0, src1, 1);
            vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0);
        }

        // for odd numbers we need to mix row with zeroes
        if (nrows % 2) {
            int i = nrows - 1;
            auto src0 = src_ymm(i);
            auto src1 = src_ymm(i + 1); // zero

            auto zmm_src0 = src_zmm(i);
            vpxor(src1, src1, src1);

            load_ymm(i);
            vpunpckhwd(src0, src0, src1);
            vinserti64x4(zmm_tmp, zmm_tmp, src0, 0);
            vpxor(src0, src0, src0);
            load_ymm(i);
            vpunpcklwd(src1, src0, src1);
            vinserti64x4(zmm_tmp, zmm_tmp, src1, 1);
            vpxord(zmm_src0, zmm_src0, zmm_src0);
            vmovups(zmm_src0, zmm_tmp);
            vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0);
        }
    }

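    // The three "swap" stages below complete the 16x16 transpose as a
    // butterfly network of masked cross-lane permutes, exchanging elements
    // between register pairs at distances of 2, 4 and 8.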
    // swap 1
    for (int i = 0; i < 4; i++) {
        auto zmm0 = src_zmm(4 * i);
        auto zmm1 = src_zmm(4 * i + 2);
        auto tmp0 = src_zmm(4 * i + 1);
        auto tmp1 = src_zmm(4 * i + 3);

        vmovups(tmp0, zmm0);
        vmovups(tmp1, zmm1);

        vpermps(tmp0 | kAAAA, vidx3, zmm1);
        vpermps(tmp1 | k5555, vidx3, zmm0);
    }
    // swap 2
    int base_idx = 0;
    for (int i = 0; i < 2; i++) {
        auto zmm0 = src_zmm(base_idx + 2 * i + 1);
        auto zmm1 = src_zmm(base_idx + 2 * i + 5);

        auto tmp0 = src_zmm(base_idx + 2 * i);
        auto tmp1 = src_zmm(base_idx + 2 * i + 4);

        vmovupd(tmp0, zmm0);
        vmovupd(tmp1, zmm1);

        vpermpd(tmp0 | kAA, vidx2, zmm1);
        vpermpd(tmp1 | k55, vidx2, zmm0);
    }
    base_idx = 8;
    for (int i = 0; i < 2; i++) {
        auto zmm0 = src_zmm(base_idx + 2 * i + 1);
        auto zmm1 = src_zmm(base_idx + 2 * i + 5);

        auto tmp0 = src_zmm(base_idx + 2 * i);
        auto tmp1 = src_zmm(base_idx + 2 * i + 4);

        vmovupd(tmp0, zmm0);
        vmovupd(tmp1, zmm1);

        vpermpd(tmp0 | kAA, vidx2, zmm1);
        vpermpd(tmp1 | k55, vidx2, zmm0);
    }

    // swap 3
    for (int i = 0; i < 4; i++) {
        auto zmm0 = src_zmm(2 * i);
        auto zmm1 = src_zmm(2 * i + 8);

        auto tmp0 = src_zmm(2 * i + 1);
        auto tmp1 = src_zmm(2 * i + 9);

        vmovupd(tmp0, zmm0);
        vmovupd(tmp1, zmm1);

        vpermpd(tmp0 | kCC, vidx1, zmm1);
        vpermpd(tmp1 | k33, vidx1, zmm0);
    }

    // all stores
    for (int i = 0; i < 8; i++)
        vextracti64x4(src_ymm(2 * i), src_zmm(2 * i + 1), 1);

    auto get_vec_idx = [=](int ic_idx) {
        assert(ic_idx < 16 && ic_idx >= 0);
        switch (ic_idx) {
            case 0: return 1;
            case 1: return 0;
            case 2: return 3;
            case 3: return 2;
            case 4: return 9;
            case 5: return 8;
            case 6: return 11;
            case 7: return 10;
            case 8: return 5;
            case 9: return 4;
            case 10: return 7;
            case 11: return 6;
            case 12: return 13;
            case 13: return 12;
            case 14: return 15;
            default: return 14;
        }
    };

    for (int ic = 0; ic < ic_block; ic++)
        store(src_zmm(get_vec_idx(ic)), ic);
}

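// generate() loads the permute tables and opmask constants used by
// transpose() above, then walks the input row in transpose_size-wide steps,
// handling left/right padding in the first and last iterations and striding
// the source pointer when stride_w > 1.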
void jit_trans_iw_ic_int16_t::generate() {
    preamble();

    alignas(64) static constexpr const int64_t idx1[8]
            = {2, 3, 0, 1, 6, 7, 4, 5};
    alignas(64) static constexpr const int64_t idx2[8]
            = {1, 0, 3, 2, 5, 4, 7, 6};
    alignas(64) static constexpr const int32_t idx3[16]
            = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14};
    alignas(64) static constexpr const int32_t idx4[16]
            = {8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7};
    alignas(64) static constexpr const uint16_t idx5[32]
            = {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1, 17,
                    3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31};
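    // idx1/idx2 feed the qword-level permutes (swap 3 / swap 2), idx3 the
    // dword-level swap 1, idx4 the post-unpack dword shuffle, and idx5 the
    // initial word interleave of each pair of source rows.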

    const int ic_block = conf_->ic_block;
    const bool is_layout_nxc = utils::one_of(conf_->src_tag, format_tag::ndhwc,
            format_tag::nhwc, format_tag::nwc);
    const size_t src_mult
            = is_layout_nxc ? conf_->ngroups * conf_->ic : ic_block;
    const int iw = conf_->iw;
    const int tr_iw = conf_->tr_iw;
    const int str_w = conf_->stride_w;
    assert(tr_iw % str_w == 0);
    const int tr_iw_s = tr_iw / str_w;
    assert(transpose_size >= ic_block);

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    kmovw(kFFFF, 0xffff);
    kmovw(k5555, 0x5555);
    kmovw(kAAAA, 0xaaaa);
    kmovw(kAA, 0xaa);
    kmovw(k55, 0x55);
    kmovw(kCC, 0xcc);
    kmovw(k33, 0x33);

    auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa64(z, ptr[imm_addr64]);
    };

    auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa32(z, ptr[imm_addr64]);
    };

    vmovdqa64(vidx1, idx1);
    vmovdqa64(vidx2, idx2);
    vmovdqa32(vidx3, idx3);
    vmovdqa32(vidx4, idx4);
    vmovdqa32(vidx5, (const int32_t *)idx5);

    // Data for every strided case is placed consecutively
    // For 1x1 convolutions with strides we transpose only needed elements
    const auto str_w_end = (conf_->kw == 1) ? 1 : str_w;
    for (int s = 0; s < str_w_end; s++) {
        const int left_pad = div_up(conf_->l_pad - s, str_w);
        const int iw1 = iw + conf_->l_pad;
        const int iw_s = (s < (iw1 % str_w) ? div_up(iw1, str_w) : iw1 / str_w)
                - left_pad;
        const int right_pad = tr_iw_s - iw_s - left_pad;

        const int transposes = utils::div_up(iw_s, transpose_size);
        int loop_iters = nstl::max(0, transposes - 1);
        tail = iw_s - loop_iters * transpose_size;

        src_stride = src_mult * typesize * str_w;
        tr_src_stride = tr_iw * typesize;

        bool nontemporal_stores = false;
        enable_prefetch = iw > small_spatial;

        const size_t src_step = src_mult * transpose_size * str_w * typesize;
        const size_t tr_src_step = transpose_size * typesize;

        mov(reg_src, ptr[param1 + GET_OFF(src)]);
        mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
        mov(reg_src_prf, ptr[param1 + GET_OFF(src_prf)]);
        mov(reg_tr_src_prf, ptr[param1 + GET_OFF(tr_src_prf)]);

        if (str_w > 1) {
            int tr_src_shift = s;
            int src_shift = (str_w - (conf_->l_pad % str_w) + s) % str_w;
            add(reg_src, src_shift * src_mult * typesize);
            add(reg_tr_src, tr_src_shift * tr_iw_s * typesize);
            add(reg_src_prf, src_shift * src_mult * typesize);
            add(reg_tr_src_prf, tr_src_shift * tr_iw_s * typesize);
        }

        if (left_pad > 0 && loop_iters > 0) {
            loop_iters--;
            transpose(transpose_size, left_pad, 0, nontemporal_stores);
            add(reg_src, src_step);
            add(reg_tr_src, tr_src_step + left_pad * typesize);
            add(reg_src_prf, src_step);
            add(reg_tr_src_prf, tr_src_step + left_pad * typesize);
        }

        if (loop_iters) {
            mov(reg_loop, loop_iters);
            Label loop;
            L(loop);
            {
                transpose(transpose_size, 0, 0, nontemporal_stores);
                add(reg_src, src_step);
                add(reg_tr_src, tr_src_step);
                add(reg_src_prf, src_step);
                add(reg_tr_src_prf, tr_src_step);
                sub(reg_loop, 1);
                jnz(loop);
            }
        }
        if (transposes > 1)
            transpose(tail, 0, right_pad, nontemporal_stores);
        else
            transpose(tail, left_pad, right_pad, nontemporal_stores);
    }
    postamble();
}

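// Transposes a diff_dst tile from (ow, oc) layout to the layout consumed by
// the VNNI bf16 backward-by-weights kernels: adjacent ow elements of each
// output channel are interleaved word-by-word (via idx2 in generate()).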
struct jit_trans_ow_oc_t : public jit_trans_dst_t, public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_ow_oc_t)
    jit_trans_ow_oc_t(const jit_conv_conf_t *conf)
        : jit_trans_dst_t(conf), jit_generator(jit_name()) {}

    void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); }

    status_t create_kernel() override { return jit_generator::create_kernel(); }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;
    using zmm = const Xbyak::Zmm;

    enum {
        typesize = sizeof(int16_t),
        transpose_size = 16,
        small_spatial = 14
    };
    size_t src_stride = 0, tr_src_stride = 0;
    int tail = 0;
    bool enable_prefetch = false;

    opmask_t kFF = k1;
    opmask_t mask_lo = k2;
    opmask_t k_oc_tail = k3;

    zmm vidx1 = zmm31;
    zmm vidx2 = zmm30;

    reg64_t reg_src = r8;
    reg64_t reg_tr_src = r9;
    reg64_t reg_src_prf = r10;
    reg64_t reg_tr_src_prf = r11;
    reg64_t reg_loop = r12;
    reg64_t reg_tr_src_tmp = r13;
    reg32_t regw_tmp = r14d;
    reg64_t imm_addr64 = rbx;

    void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores,
            bool do_convert = true);
    void generate() override;
};

// do_convert (default is 'true') is a flag that determines when to do the
// transformation of the input data and when to simply zero out the output data
void jit_trans_ow_oc_t::transpose(int nrows, int l_pad, int r_pad,
        bool nontemporal_stores, bool do_convert) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 16, "Unsupported transpose size");
    if (!nrows) return;

    auto src_zmm = [=](int i) { return Zmm(i); };

    auto src_ymm = [=](int i) {
        assert(i >= 0 && i < 16);
        return Ymm(i);
    };

    auto load_ymm = [=](int i) {
        auto ymm_reg = src_ymm(i);
        auto addr = EVEX_compress_addr(reg_src, i * src_stride);
        if (conf_->oc_tail) {
            ymm_reg = ymm_reg | k_oc_tail | T_z;
            // Assertion below as we need vmovdqu16 for tails.
            // If needed, can be removed by using load_bytes() helper.
            assert(mayiuse(avx512_core));
            vmovdqu16(ymm_reg, addr);
        } else {
            vmovups(ymm_reg, addr);
        }
    };

    auto store = [=](Zmm r, int i) {
        auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride);
        if (nontemporal_stores)
            vmovntps(addr, r);
        else
            vmovups(addr, r);
    };
    const bool is_layout_nxc = utils::one_of(conf_->dst_tag, format_tag::ndhwc,
            format_tag::nhwc, format_tag::nwc);

    if (mayiuse(avx512_core) && !is_layout_nxc) {
        // TODO: adapt for nhwc?
        for (int i = 0; i < nrows / 2; i++) {
            auto zmm_src0 = src_zmm(i);
            if (do_convert) {
                vmovdqu16(zmm_src0,
                        EVEX_compress_addr(reg_src, 2 * i * src_stride));
                vpermw(zmm_src0, vidx2, zmm_src0);
            } else {
                vpxord(zmm_src0, zmm_src0, zmm_src0);
            }
            store(zmm_src0, 2 * i);
        }
        if (r_pad > 0) {
            auto zmm_src0 = src_zmm(29);
            if (do_convert) {
                vmovdqu16(zmm_src0 | mask_lo | T_z,
                        EVEX_compress_addr(reg_src, (nrows - 1) * src_stride));
                vpermw(zmm_src0, vidx2, zmm_src0);
            } else {
                vpxord(zmm_src0, zmm_src0, zmm_src0);
            }
            store(zmm_src0, nrows - 1);
        }
    } else {
        for (int i = 0; i < nrows / 2; i++) {
            auto src0 = src_ymm(2 * i);
            auto src1 = src_ymm(2 * i + 1);
            auto zmm_src0 = src_zmm(2 * i);
            if (do_convert) {
                load_ymm(2 * i);
                if (is_layout_nxc && conf_->oc_tail) {
                    load_ymm(2 * i + 1);
                    auto ymm_tmp = Ymm(30);
                    vpunpcklwd(ymm_tmp, src0, src1);
                    vpunpckhwd(src0, src0, src1);
                    vinserti64x4(zmm_src0, zmm_src0, ymm_tmp, 1);
                } else {
                    vpunpcklwd(src1, src0,
                            EVEX_compress_addr(
                                    reg_src, (2 * i + 1) * src_stride));
                    vpunpckhwd(src0, src0,
                            EVEX_compress_addr(
                                    reg_src, (2 * i + 1) * src_stride));
                    vinserti64x4(zmm_src0, zmm_src0, src1, 1);
                }
                vpermpd(zmm_src0 | kFF, vidx1, zmm_src0);
            } else {
                vpxord(zmm_src0, zmm_src0, zmm_src0);
            }
            store(zmm_src0, 2 * i);
        }
        if (r_pad > 0) {
            auto src0 = src_ymm(nrows - 1);
            auto src1 = src_ymm(nrows);
            auto zmm_src0 = src_zmm(30);
            if (do_convert) {
                load_ymm(nrows - 1);

                vpxor(src1, src1, src1);
                vpunpckhwd(src1, src0, src1);
                vinserti64x4(zmm_src0, zmm_src0, src1, 0);
                vpxor(src1, src1, src1);
                vpunpcklwd(src0, src0, src1);
                vinserti64x4(zmm_src0, zmm_src0, src0, 1);
                vpermpd(zmm_src0 | kFF, vidx1, zmm_src0);
            } else {
                vpxord(zmm_src0, zmm_src0, zmm_src0);
            }
            store(zmm_src0, nrows - 1);
        }
    }
}

void jit_trans_ow_oc_t::generate() {
    preamble();

    alignas(64) static constexpr const int64_t idx1[8]
            = {4, 5, 0, 1, 6, 7, 2, 3};
    alignas(64) static constexpr const int16_t idx2[32]
            = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
                    25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};

    const int oc_block = conf_->oc_block;
    const bool is_layout_nxc = utils::one_of(conf_->dst_tag, format_tag::ndhwc,
            format_tag::nhwc, format_tag::nwc);
    const size_t src_mult
            = is_layout_nxc ? conf_->ngroups * conf_->oc : oc_block;
    const int ow = conf_->ow;
    const int transposes = utils::div_up(ow, transpose_size);
    int loop_iters = nstl::max(0, transposes - 1);
    tail = ow - loop_iters * transpose_size;

    src_stride = src_mult * typesize;
    tr_src_stride = oc_block * typesize;

    bool nontemporal_stores = conf_->use_nt_stores_ddst;
    enable_prefetch = ow > small_spatial;

    const size_t src_step = src_mult * transpose_size * typesize;
    const size_t tr_src_step = (size_t)oc_block * transpose_size * typesize;
    const int right_pad = ow % 2;

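    // Columns between ow + right_pad and tr_ow are zero-filled below by
    // calling transpose(..., do_convert = false), so the consuming kernel can
    // safely read the full tr_ow row.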
    const auto zero_tr_ow = nstl::max(0, conf_->tr_ow - ow - right_pad);

    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_src_prf, ptr[param1 + GET_OFF(src_prf)]);
    mov(reg_tr_src_prf, ptr[param1 + GET_OFF(tr_src_prf)]);

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };
    auto kmovd = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovd(k, regw_tmp);
    };

    kmovw(kFF, 0xFF);
    kmovd(mask_lo, 0x0000ffff);

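    // k_oc_tail starts as all-ones (full oc block) and is narrowed to
    // oc_tail bits only when ch_work says this call processes the tail.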
    if (is_layout_nxc && conf_->oc_tail) {
        Label done;
        kxnorw(k_oc_tail, k_oc_tail, k_oc_tail);
        cmp(dword[param1 + GET_OFF(ch_work)], conf_->oc_block);
        je(done, T_NEAR);
        kmovw(k_oc_tail, (1 << conf_->oc_tail) - 1);
        L(done);
    }

    auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa64(z, ptr[imm_addr64]);
    };

    vmovdqa64(vidx1, idx1);
    vmovdqa64(vidx2, (const int64_t *)idx2);
    if (loop_iters) {
        mov(reg_loop, loop_iters);
        Label loop;
        L(loop);
        {
            transpose(transpose_size, 0, 0, nontemporal_stores);
            add(reg_src, src_step);
            add(reg_tr_src, tr_src_step);
            add(reg_src_prf, src_step);
            add(reg_tr_src_prf, tr_src_step);
            sub(reg_loop, 1);
            jnz(loop);
        }
    }
    transpose(tail, 0, right_pad, nontemporal_stores);
    if (zero_tr_ow) {
        const auto zero_transposes = utils::div_up(zero_tr_ow, transpose_size);
        const auto zero_loop_iters = nstl::max(0, zero_transposes - 1);
        const auto zero_tail = zero_tr_ow - zero_loop_iters * transpose_size;
        const auto zero_right_pad = zero_tr_ow % 2;

        // shift over tail
        auto tr_src_tail_step
                = (size_t)oc_block * (tail + right_pad) * typesize;
        add(reg_tr_src, tr_src_tail_step);
        add(reg_tr_src_prf, tr_src_tail_step);

        // zero the tr_ow - ow
        if (zero_loop_iters) {
            mov(reg_loop, zero_loop_iters);
            Label zero_loop;
            L(zero_loop);
            {
                transpose(transpose_size, 0, 0, nontemporal_stores, false);
                add(reg_tr_src, tr_src_step);
                add(reg_tr_src_prf, tr_src_step);
                sub(reg_loop, 1);
                jnz(zero_loop);
            }
        }
        transpose(zero_tail, 0, zero_right_pad, nontemporal_stores, false);
    }

    postamble();
}

/*
// -------------------------------------------------
// jit_transpose4x16_src
// -------------------------------------------------
*/

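// Transposes a 4x16 tile of f32 elements with masked vpermpd/valignd/vpermps
// shuffles; t0/t1 software prefetches, controlled by tparams, are interleaved
// with the shuffle sequence to hide memory latency.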
void jit_transpose4x16_src::transpose(int nrows) {
    assert(nrows >= 0 && nrows <= transpose_size);
    static_assert(transpose_size == 4, "Unsupported transpose size");
    if (!nrows) return;

    auto pf_src_t0 = [=](int i) {
        if (tparams->src_pf0_distance)
            prefetcht0(EVEX_compress_addr(
                    reg_src, (tparams->src_pf0_distance + i) * src_stride));
    };

    auto pf_tr_src_t0 = [=](int i) {
        if (tparams->tr_src_pf0_distance)
            prefetcht0(EVEX_compress_addr(reg_tr_src,
                    (tparams->tr_src_pf0_distance + i) * src_stride));
    };

    auto pf_src_t1 = [=](int i) {
        if (tparams->src_pf1)
            prefetcht1(EVEX_compress_addr(reg_src_prf, i * src_stride));
    };

    auto pf_tr_src_t1 = [=](int i) {
        if (tparams->tr_src_pf1)
            prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, i * tr_src_stride));
    };

    auto src_zmm = [=](int i) {
        assert(i >= 0 && i < 4);
        return Zmm(i);
    };

    auto tmp_zmm = [=](int i) {
        assert(i >= 0 && i < 4);
        return Zmm(4 + i);
    };

    auto load = [=](int i) {
        vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride));
    };

    auto store = [=](Zmm r, int i) {
        vmovups(EVEX_compress_addr(reg_tr_src, i * tr_src_stride), r);
    };

    auto tmp0 = tmp_zmm(0);
    auto tmp1 = tmp_zmm(1);
    auto tmp2 = tmp_zmm(2);
    auto tmp3 = tmp_zmm(3);

    auto src0 = src_zmm(0);
    auto src1 = src_zmm(1);
    auto src2 = src_zmm(2);
    auto src3 = src_zmm(3);
    for (int i = 0; i < nrows; i++) {
        load(i);
    }

    for (int i = nrows; i < 4; i++) {
        vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
    }

    vmovupd(tmp0, src0);
    vmovupd(tmp1, src1);
    pf_src_t0(0);
    vpermpd(tmp0 | kF0, vidx01, src2);
    vpermpd(tmp1 | kF0, vidx01, src3);

    valignd(src0, src0, src0, 8);
    valignd(src1, src1, src1, 8);
    pf_src_t0(1);
    vmovupd(tmp2, src0);
    vmovupd(tmp3, src1);
    pf_src_t0(2);
    vpermpd(tmp2 | kF0, vidx10, src2);
    vpermpd(tmp3 | kF0, vidx10, src3);
    pf_src_t0(3);

    vmovupd(src0, tmp0);
    pf_src_t1(0);
    vmovupd(src1, tmp2);
    pf_src_t1(1);
    vmovupd(src2, tmp1);
    pf_src_t1(2);
    vmovupd(src3, tmp3);
    pf_src_t1(3);
    vpermpd(src0 | kCC, vidx1, tmp1);
    vpermpd(src1 | kCC, vidx1, tmp3);
    pf_tr_src_t0(0);
    vpermpd(src2 | k33, vidx1, tmp0);
    vpermpd(src3 | k33, vidx1, tmp2);
    pf_tr_src_t0(1);

    vmovupd(tmp0, src0);
    vmovupd(tmp1, src2);
    pf_tr_src_t0(2);
    vmovupd(tmp2, src1);
    vmovupd(tmp3, src3);
    pf_tr_src_t0(3);
    vpermps(tmp0 | kFFFF, vidxP, src0);
    pf_tr_src_t1(0);
    vpermps(tmp1 | kFFFF, vidxP, src2);
    pf_tr_src_t1(1);
    vpermps(tmp2 | kFFFF, vidxP, src1);
    pf_tr_src_t1(3);
    vpermps(tmp3 | kFFFF, vidxP, src3);
    pf_tr_src_t1(4);

    store(tmp0, 0);
    store(tmp1, 1);
    store(tmp2, 2);
    store(tmp3, 3);
}

alignas(64) static constexpr const int64_t idx01[8] = {0, 0, 0, 0, 0, 1, 2, 3};
alignas(64) static constexpr const int64_t idx10[8] = {0, 0, 0, 0, 4, 5, 6, 7};
alignas(64) static constexpr const int64_t idx1[8] = {2, 3, 0, 1, 6, 7, 4, 5};
alignas(64) static constexpr const int32_t idxP[16]
        = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};

void jit_transpose4x16_src::generate() {
    preamble();

    const int ic_block = params->ic_block;
    const int is = params->is;
    int tail = is % transpose_size;

    src_stride = ic_block * typesize;
    assert(src_stride == 64);
    tr_src_stride = ic_block * typesize;

    const int src_step = ic_block * transpose_size * typesize;
    const int tr_src_step = ic_block * transpose_size * typesize;

#define GET_TR_OFF(x) offsetof(jit_src_transpose_s, x)
    mov(reg_loop, ptr[param1 + GET_TR_OFF(size)]);
    mov(reg_src, ptr[param1 + GET_TR_OFF(src)]);
    mov(reg_tr_src, ptr[param1 + GET_TR_OFF(tr_src)]);
    mov(reg_src_prf, ptr[param1 + GET_TR_OFF(src_prf)]);
    mov(reg_tr_src_prf, ptr[param1 + GET_TR_OFF(tr_src_prf)]);
#undef GET_TR_OFF

    auto kmovw = [=](Opmask k, unsigned w) {
        mov(regw_tmp, w);
        jit_generator::kmovw(k, regw_tmp);
    };

    auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa64(z, ptr[imm_addr64]);
    };

    auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
        mov(imm_addr64, reinterpret_cast<size_t>(addr));
        jit_generator::vmovdqa32(z, ptr[imm_addr64]);
    };

    kmovw(kF0, 0xf0); // 11110000
    kmovw(kCC, 0xcc); // 11001100
    kmovw(k33, 0x33); // 00110011
    kmovw(kFFFF, 0xffff); // 1111111111111111

    vmovdqa64(vidx01, idx01);
    vmovdqa64(vidx10, idx10);
    vmovdqa64(vidx1, idx1);
    vmovdqa32(vidxP, idxP);

    Label loop_label;
    Label tail_label;

    cmp(reg_loop, transpose_size);
    jl(tail_label, T_NEAR);

    L(loop_label);
    {
        transpose(transpose_size);
        add(reg_src, src_step);
        add(reg_tr_src, tr_src_step);
        add(reg_src_prf, src_step);
        add(reg_tr_src_prf, tr_src_step);
        sub(reg_loop, transpose_size);
        cmp(reg_loop, transpose_size);
        jge(loop_label, T_NEAR);
    }
    L(tail_label);
    transpose(tail);

    postamble();
}

#undef GET_OFF

#define GET_OFF(field) offsetof(jit_conv_call_s, field)

void jit_diff_wei_trans_to_vnni_t::generate() {
    /* Reorder part of F32 weights tensor
       from [2I][kd][kh][kw][16i][16o] to VNNI format
       [kd][kh][kw][16i][16o][2i] and downconvert it to Bfloat16. */
    const int typesize_out = 2;
    const int typesize_acc = 4;
    const int simd_w = 16;

    using reg64_t = const Xbyak::Reg64;
    const reg64_t &reg_output = r15;
    const reg64_t &org_reg_output = r14;
    const reg64_t &reg_input = r13;
    const reg64_t &reg_input_1 = r12;
    const reg64_t &org_reg_input_1 = r11;
    const reg64_t &reg_input_2 = r10;
    const reg64_t &reg_prm_table = r9;
    const reg64_t &reg_last_ic_block = rax;
    const reg64_t &reg_kd = rsi;
    const reg64_t &reg_kh = abi_not_param1;
    const reg64_t &reg_tmp = rdx;

    const Xbyak::Zmm &zmm_idx = Xbyak::Zmm(31);
    auto get_zmm_src_0 = [&](int ic) { return Xbyak::Zmm(ic); };
    auto get_zmm_src_1 = [&](int ic) { return Xbyak::Zmm(4 + ic); };
    auto get_zmm_bf16 = [&](int ic) { return Xbyak::Zmm(8 + ic); };

    Xbyak::Label prm_table, zero_buffer;
    Xbyak::Label kd_loop_label, kh_loop_label;

    preamble();

    mov(reg_last_ic_block, ptr[abi_param1 + GET_OFF(last_ic_block)]);
    mov(org_reg_input_1, ptr[abi_param1 + GET_OFF(src)]);
    mov(org_reg_output, ptr[abi_param1 + GET_OFF(dst)]);

    mov(reg_prm_table, prm_table);
    vmovups(zmm_idx, ptr[reg_prm_table]);

    xor_(reg_kd, reg_kd);
    L(kd_loop_label);
    {
        mov(reg_output, org_reg_output);
        mov(reg_input_1, org_reg_input_1);
        xor_(reg_kh, reg_kh);
        L(kh_loop_label);
        {
            for (int kw = 0; kw < kw_; kw++) {
                Xbyak::Label last_ic_label, done_ic_label;

                dim_t out_offset
                        = (dim_t)typesize_out * kw * ic_block_ * oc_block_ * 2;
                dim_t inp_1_offset
                        = (dim_t)typesize_acc * kw * ic_block_ * oc_block_;
                dim_t inp_2_offset = (dim_t)typesize_acc
                        * (kd_ * kh_ * kw_ * ic_block_ * oc_block_
                                + kw * ic_block_ * oc_block_);

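                // reg_input_2 points either at the second half of the f32
                // accumulator or, for the last (odd) ic block, at a zero
                // buffer so the second element of each VNNI pair is zero.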
                cmp(reg_last_ic_block, 0);
                jne(last_ic_label, T_NEAR);

                mov(reg_input_2, reg_input_1);
                safe_add(reg_input_2, inp_2_offset, reg_tmp);
                jmp(done_ic_label, T_NEAR);

                L(last_ic_label);
                mov(reg_input_2, zero_buffer);

                L(done_ic_label);

                for (int ocb = 0; ocb < oc_block_; ocb += simd_w) {
                    int ic_count = 0;
                    for (int bc = 0; bc < 2; bc++) {
                        if (!bc) {
                            mov(reg_input, reg_input_1);
                            safe_add(reg_input, inp_1_offset, reg_tmp);
                        } else
                            mov(reg_input, reg_input_2);

                        for (int ic = 0; ic < ic_block_ / 2; ic++) {
                            auto zmm_src_0 = get_zmm_src_0(ic);
                            auto zmm_src_1 = get_zmm_src_1(ic);
                            auto zmm_out = get_zmm_bf16(ic);

                            vmovups(zmm_src_0,
                                    ptr[reg_input
                                            + typesize_acc
                                                    * ((2 * ic + 0) * oc_block_
                                                            + ocb)]);
                            vmovups(zmm_src_1,
                                    ptr[reg_input
                                            + typesize_acc
                                                    * ((2 * ic + 1) * oc_block_
                                                            + ocb)]);
                            if (out_dt_ == data_type::bf16) {
                                vcvtne2ps2bf16(zmm_out, zmm_src_1, zmm_src_0);
                            } else if (out_dt_ == data_type::f16) {
                                vcvtps2phx(Ymm(zmm_src_0.getIdx()), zmm_src_0);
                                vcvtps2phx(Ymm(zmm_src_1.getIdx()), zmm_src_1);
                                vinsertf32x8(zmm_out, zmm_src_0,
                                        Ymm(zmm_src_1.getIdx()), 1);
                            } else {
                                assert(!"unsupported data type");
                            }
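                            // Interleave the two converted rows word-by-word
                            // (via prm_table) into the VNNI [16o][2i] order.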
                            vpermw(zmm_out, zmm_idx, zmm_out);

                            vmovups(ptr[reg_output + out_offset
                                            + typesize_out
                                                    * (ic_count * oc_block_ * 2
                                                            + ocb * 2)],
                                    zmm_out);
                            ic_count++;
                        }
                    }
                }
            }
            safe_add(reg_output,
                    (dim_t)typesize_out * kw_ * 2 * ic_block_ * oc_block_,
                    reg_tmp);
            safe_add(reg_input_1,
                    (dim_t)typesize_acc * kw_ * ic_block_ * oc_block_, reg_tmp);

            add(reg_kh, 1);
            cmp(reg_kh, kh_);
            jl(kh_loop_label, T_NEAR);
        }
        safe_add(org_reg_output,
                (dim_t)typesize_out * kh_ * kw_ * 2 * ic_block_ * oc_block_,
                reg_tmp);
        safe_add(org_reg_input_1,
                (dim_t)typesize_acc * kh_ * kw_ * ic_block_ * oc_block_,
                reg_tmp);

        add(reg_kd, 1);
        cmp(reg_kd, kd_);
        jl(kd_loop_label, T_NEAR);
    }

    postamble();

    align(64);
    L(prm_table);
    const uint16_t prm_array[32]
            = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
                    25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
    for (size_t i = 0; i < 32; ++i)
        dw(prm_array[i]);

    align(64);
    L(zero_buffer);
    const uint16_t zero = 0;
    for (int i = 0; i < typesize_acc * oc_block_ * ic_block_; ++i)
        db(zero);
}

#undef GET_OFF

jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf) {
    if (conf->has_vnni && IMPLICATION(conf->is_1stconv, conf->transpose_src))
        return new jit_trans_iw_ic_int16_t(conf);
    assert(!"unsupported configuration");
    return nullptr;
}

jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf) {
    if (conf->has_vnni) return new jit_trans_ow_oc_t(conf);
    assert(!"unsupported configuration");
    return nullptr;
}
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl