1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/nstl.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | #include "cpu/x64/cpu_barrier.hpp" |
22 | #include "cpu/x64/jit_generator.hpp" |
23 | |
24 | #include "cpu/x64/jit_transpose_utils.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | using namespace dnnl::impl::utils; |
32 | using namespace Xbyak; |
33 | |
34 | #define GET_OFF(x) offsetof(ctx_t, x) |
35 | |
36 | struct jit_trans_iw_ic_int16_t : public jit_trans_src_t, public jit_generator { |
37 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_int16_t) |
38 | jit_trans_iw_ic_int16_t(const jit_conv_conf_t *conf) |
39 | : jit_trans_src_t(conf), jit_generator(jit_name()) {} |
40 | |
41 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
42 | |
43 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
44 | |
45 | private: |
46 | using reg64_t = const Xbyak::Reg64; |
47 | using reg32_t = const Xbyak::Reg32; |
48 | using opmask_t = const Xbyak::Opmask; |
49 | |
50 | enum { |
51 | typesize = sizeof(int16_t), |
52 | transpose_size = 16, |
53 | small_spatial = 14 |
54 | }; |
55 | size_t src_stride = 0, tr_src_stride = 0; |
56 | int tail = 0; |
57 | bool enable_prefetch = false; |
58 | |
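    // Opmask names encode the bit pattern loaded in generate(): for example
    // k5555 <- 0x5555 (even lanes) and kAAAA <- 0xaaaa (odd lanes). Note
    // that kTail aliases kFFFF (both map to k1), so tail stores clobber the
    // full-row mask.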
59 | opmask_t kFFFF = k1; |
60 | opmask_t k5555 = k2; |
61 | opmask_t kAAAA = k3; |
62 | opmask_t kAA = k4; |
63 | opmask_t k55 = k5; |
64 | opmask_t kCC = k6; |
65 | opmask_t k33 = k7; |
66 | opmask_t kTail = k1; |
67 | |
68 | reg64_t reg_src = r8; |
69 | reg64_t reg_tr_src = r9; |
70 | reg64_t reg_src_prf = r10; |
71 | reg64_t reg_tr_src_prf = r11; |
72 | reg64_t reg_loop = r12; |
73 | reg64_t reg_tr_src_tmp = r13; |
74 | reg32_t regw_tmp = r14d; |
75 | reg64_t imm_addr64 = rbx; |
76 | |
77 | Xbyak::Zmm vidx1 = zmm31; |
78 | Xbyak::Zmm vidx2 = zmm30; |
79 | Xbyak::Zmm vidx3 = zmm29; |
80 | Xbyak::Zmm vidx4 = zmm28; |
81 | Xbyak::Zmm vidx5 = zmm27; |
82 | Xbyak::Zmm zmm_tmp = zmm26; |
83 | |
84 | void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores); |
85 | void generate() override; |
86 | }; |
87 | |
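// Transposes an (nrows x 16) tile of int16 elements. Pairs of rows are
// first interleaved word-wise (vpermw / vpunpck), then three butterfly
// stages ("swap 1".."swap 3") exchange progressively wider element groups
// across the 16 zmm registers, so each register ends up holding one
// transposed output row (see get_vec_idx for the final mapping).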
88 | void jit_trans_iw_ic_int16_t::transpose( |
89 | int nrows, int l_pad, int r_pad, bool nontemporal_stores) { |
90 | assert(nrows >= 0 && nrows <= transpose_size); |
    static_assert(transpose_size == 16, "Unsupported transpose size");
92 | if (!nrows) return; |
93 | |
94 | auto src_zmm = [=](int i) { return Zmm(i); }; |
95 | |
96 | auto src_ymm = [=](int i) { |
97 | assert(i >= 0 && i < 16); |
98 | return Ymm(i); |
99 | }; |
100 | |
101 | auto load_ymm = [=](int i) { |
102 | vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride)); |
103 | }; |
104 | |
105 | auto kmovw = [=](Opmask k, unsigned w) { |
106 | mov(regw_tmp, w); |
107 | jit_generator::kmovw(k, regw_tmp); |
108 | }; |
109 | |
110 | auto kmovd = [=](Opmask k, unsigned w) { |
111 | mov(regw_tmp, w); |
112 | jit_generator::kmovd(k, regw_tmp); |
113 | }; |
114 | |
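    // Stores one transposed output row i. The nested `padding` helper
    // zero-fills whole 2 * transpose_size chunks plus a masked tail, which
    // covers left/right padding that may exceed 16 elements due to dilation.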
115 | auto store = [=](Zmm r, int i) { |
116 | auto padding = [=](Reg64 base, int pad_rows, int pad_tail) { |
117 | // note: pad can be bigger than 16 because of dilation |
118 | const size_t row_offset = 2 * transpose_size * typesize; |
119 | auto zmm_zero = zmm_tmp; |
120 | vpxord(zmm_zero, zmm_zero, zmm_zero); |
121 | for (int i_row = 0; i_row < pad_rows; i_row++) { |
122 | auto addr = EVEX_compress_addr( |
123 | base, i * tr_src_stride + i_row * row_offset); |
124 | vmovups(addr, zmm_zero); |
125 | } |
126 | if (pad_tail > 0) { |
127 | kmovw(kTail, (1 << pad_tail) - 1); |
128 | base.setOpmaskIdx(kTail.getIdx(), true); |
129 | auto addr = EVEX_compress_addr( |
130 | base, i * tr_src_stride + pad_rows * row_offset); |
131 | vmovups(addr, zmm_zero); |
132 | } |
133 | }; |
134 | |
135 | mov(reg_tr_src_tmp, reg_tr_src); |
136 | if (l_pad > 0) { |
137 | int store_pad = 2 * transpose_size; |
138 | int pad_rows = l_pad / store_pad; |
139 | int tail = l_pad % store_pad; |
140 | padding(reg_tr_src_tmp, pad_rows, div_up(tail, 2)); |
141 | add(reg_tr_src_tmp, (pad_rows * store_pad + tail) * typesize); |
142 | } |
143 | if (r_pad > 0) { |
144 | int addr_shift = nrows - r_pad % 2; |
145 | int store_pad = div_up(r_pad, 2); |
146 | int pad_rows = store_pad / transpose_size; |
147 | add(reg_tr_src_tmp, addr_shift * typesize); |
148 | padding(reg_tr_src_tmp, pad_rows, store_pad % transpose_size); |
149 | sub(reg_tr_src_tmp, addr_shift * typesize); |
150 | } |
151 | |
152 | int store_tail = rnd_up(nrows, 2); |
153 | kmovw(kTail, (1 << store_tail / 2) - 1); |
154 | auto k = kTail; |
155 | auto base = reg_tr_src_tmp; |
156 | base.setOpmaskIdx(k.getIdx(), true); |
157 | |
158 | auto addr = EVEX_compress_addr(base, i * tr_src_stride); |
159 | vmovups(addr, r); |
160 | }; |
161 | |
162 | const bool is_layout_nxc = utils::one_of(conf_->src_tag, format_tag::ndhwc, |
163 | format_tag::nhwc, format_tag::nwc); |
164 | const int ic_block = conf_->ic_block; |
165 | const bool is_tail_block = ic_block != 16; |
166 | const int ic_tail = conf_->ic_tail; |
    // The assertion below is needed because handling ic tails requires
    // vmovdqu16. If needed, this can be extended by using the load_bytes()
    // helper.
169 | assert(IMPLICATION(ic_tail, mayiuse(avx512_core))); |
170 | if (mayiuse(avx512_core)) { |
171 | if (conf_->stride_w > 1 || nrows % 2 || is_layout_nxc) |
172 | kmovd(kFFFF, (1 << ic_block) - 1); |
173 | if (conf_->stride_w > 1 || is_layout_nxc) kmovd(k33, 0xffff0000); |
174 | if (is_layout_nxc && conf_->ic_tail) { |
175 | Label done; |
176 | cmp(dword[param1 + GET_OFF(ch_work)], ic_block); |
177 | je(done, T_NEAR); |
178 | kmovd(kFFFF, (1 << conf_->ic_tail) - 1); |
179 | kshiftld(k33, kFFFF, 16); |
180 | L(done); |
181 | } |
182 | |
183 | for (int i = 0; i < nrows / 2; i++) { |
184 | auto zmm_src0 = src_zmm(2 * i); |
185 | if (conf_->stride_w == 1 && !is_layout_nxc) { |
186 | vmovdqu16(zmm_src0, |
187 | EVEX_compress_addr(reg_src, 2 * i * src_stride)); |
188 | } else { |
189 | vmovdqu16(zmm_src0 | kFFFF | T_z, |
190 | EVEX_compress_addr(reg_src, 2 * i * src_stride)); |
191 | if (is_tail_block || ic_tail) { |
192 | auto zmm_tmp = src_zmm(2 * i + 1); |
193 | vmovdqu16(zmm_tmp | kFFFF | T_z, |
194 | EVEX_compress_addr( |
195 | reg_src, (2 * i + 1) * src_stride)); |
196 | vinsertf64x4(zmm_src0, zmm_src0, src_ymm(2 * i + 1), 1); |
197 | } else { |
198 | vmovdqu16(zmm_src0 | k33, |
199 | EVEX_compress_addr( |
200 | reg_src, (2 * i + 1) * src_stride - 32)); |
201 | } |
202 | } |
203 | vpermw(zmm_src0, vidx5, zmm_src0); |
204 | } |
205 | |
        // for an odd number of rows, pair the last row with zeroes
207 | if (nrows % 2) { |
208 | int i = nrows / 2; |
209 | auto zmm_src0 = src_zmm(2 * i); |
210 | vmovdqu16(zmm_src0 | kFFFF | T_z, |
211 | EVEX_compress_addr(reg_src, 2 * i * src_stride)); |
212 | vpermw(zmm_src0, vidx5, zmm_src0); |
213 | } |
214 | |
215 | if (conf_->stride_w > 1 || is_layout_nxc) kmovw(k33, 0x33); |
216 | |
217 | for (int i = rnd_up(nrows, 2); i < 16; i += 2) { |
218 | vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); |
219 | } |
220 | } else { |
221 | kmovw(kFFFF, 0xffff); |
222 | // all loads |
223 | for (int i = 0; i < 16; i++) { |
224 | vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); |
225 | } |
226 | |
227 | for (int i = 0; i < nrows / 2; i++) { |
228 | auto src0 = src_ymm(2 * i); |
229 | auto src1 = src_ymm(2 * i + 1); |
230 | auto zmm_src0 = src_zmm(2 * i); |
231 | load_ymm(2 * i); |
232 | |
233 | vpunpcklwd(src1, src0, |
234 | EVEX_compress_addr(reg_src, (2 * i + 1) * src_stride)); |
235 | vpunpckhwd(src0, src0, |
236 | EVEX_compress_addr(reg_src, (2 * i + 1) * src_stride)); |
237 | vinserti64x4(zmm_src0, zmm_src0, src1, 1); |
238 | vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); |
239 | } |
240 | |
        // for an odd number of rows, pair the last row with zeroes
242 | if (nrows % 2) { |
243 | int i = nrows - 1; |
244 | auto src0 = src_ymm(i); |
245 | auto src1 = src_ymm(i + 1); // zero |
246 | |
247 | auto zmm_src0 = src_zmm(i); |
248 | vpxor(src1, src1, src1); |
249 | |
250 | load_ymm(i); |
251 | vpunpckhwd(src0, src0, src1); |
252 | vinserti64x4(zmm_tmp, zmm_tmp, src0, 0); |
253 | vpxor(src0, src0, src0); |
254 | load_ymm(i); |
255 | vpunpcklwd(src1, src0, src1); |
256 | vinserti64x4(zmm_tmp, zmm_tmp, src1, 1); |
257 | vpxord(zmm_src0, zmm_src0, zmm_src0); |
258 | vmovups(zmm_src0, zmm_tmp); |
259 | vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0); |
260 | } |
261 | } |
262 | |
263 | // swap 1 |
264 | for (int i = 0; i < 4; i++) { |
265 | auto zmm0 = src_zmm(4 * i); |
266 | auto zmm1 = src_zmm(4 * i + 2); |
267 | auto tmp0 = src_zmm(4 * i + 1); |
268 | auto tmp1 = src_zmm(4 * i + 3); |
269 | |
270 | vmovups(tmp0, zmm0); |
271 | vmovups(tmp1, zmm1); |
272 | |
273 | vpermps(tmp0 | kAAAA, vidx3, zmm1); |
274 | vpermps(tmp1 | k5555, vidx3, zmm0); |
275 | } |
276 | // swap 2 |
    int base_idx = 0;
279 | for (int i = 0; i < 2; i++) { |
280 | auto zmm0 = src_zmm(base_idx + 2 * i + 1); |
281 | auto zmm1 = src_zmm(base_idx + 2 * i + 5); |
282 | |
283 | auto tmp0 = src_zmm(base_idx + 2 * i); |
284 | auto tmp1 = src_zmm(base_idx + 2 * i + 4); |
285 | |
286 | vmovupd(tmp0, zmm0); |
287 | vmovupd(tmp1, zmm1); |
288 | |
289 | vpermpd(tmp0 | kAA, vidx2, zmm1); |
290 | vpermpd(tmp1 | k55, vidx2, zmm0); |
291 | } |
292 | base_idx = 8; |
293 | for (int i = 0; i < 2; i++) { |
294 | auto zmm0 = src_zmm(base_idx + 2 * i + 1); |
295 | auto zmm1 = src_zmm(base_idx + 2 * i + 5); |
296 | |
297 | auto tmp0 = src_zmm(base_idx + 2 * i); |
298 | auto tmp1 = src_zmm(base_idx + 2 * i + 4); |
299 | |
300 | vmovupd(tmp0, zmm0); |
301 | vmovupd(tmp1, zmm1); |
302 | |
303 | vpermpd(tmp0 | kAA, vidx2, zmm1); |
304 | vpermpd(tmp1 | k55, vidx2, zmm0); |
305 | } |
306 | |
307 | // swap 3 |
308 | for (int i = 0; i < 4; i++) { |
309 | auto zmm0 = src_zmm(2 * i); |
310 | auto zmm1 = src_zmm(2 * i + 8); |
311 | |
312 | auto tmp0 = src_zmm(2 * i + 1); |
313 | auto tmp1 = src_zmm(2 * i + 9); |
314 | |
315 | vmovupd(tmp0, zmm0); |
316 | vmovupd(tmp1, zmm1); |
317 | |
318 | vpermpd(tmp0 | kCC, vidx1, zmm1); |
319 | vpermpd(tmp1 | k33, vidx1, zmm0); |
320 | } |
321 | |
322 | // all stores |
323 | for (int i = 0; i < 8; i++) |
324 | vextracti64x4(src_ymm(2 * i), src_zmm(2 * i + 1), 1); |
325 | |
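    // After the three swap stages the transposed rows live in a permuted set
    // of registers; get_vec_idx maps a logical output row (ic index) to the
    // physical zmm that holds it.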
326 | auto get_vec_idx = [=](int ic_idx) { |
327 | assert(ic_idx < 16 && ic_idx >= 0); |
328 | switch (ic_idx) { |
329 | case 0: return 1; |
330 | case 1: return 0; |
331 | case 2: return 3; |
332 | case 3: return 2; |
333 | case 4: return 9; |
334 | case 5: return 8; |
335 | case 6: return 11; |
336 | case 7: return 10; |
337 | case 8: return 5; |
338 | case 9: return 4; |
339 | case 10: return 7; |
340 | case 11: return 6; |
341 | case 12: return 13; |
342 | case 13: return 12; |
343 | case 14: return 15; |
344 | default: return 14; |
345 | } |
346 | }; |
347 | |
348 | for (int ic = 0; ic < ic_block; ic++) |
349 | store(src_zmm(get_vec_idx(ic)), ic); |
350 | } |
351 | |
352 | void jit_trans_iw_ic_int16_t::generate() { |
353 | preamble(); |
354 | |
355 | alignas(64) static constexpr const int64_t idx1[8] |
356 | = {2, 3, 0, 1, 6, 7, 4, 5}; |
357 | alignas(64) static constexpr const int64_t idx2[8] |
358 | = {1, 0, 3, 2, 5, 4, 7, 6}; |
359 | alignas(64) static constexpr const int32_t idx3[16] |
360 | = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14}; |
361 | alignas(64) static constexpr const int32_t idx4[16] |
362 | = {8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7}; |
363 | alignas(64) static constexpr const uint16_t idx5[32] |
364 | = {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1, 17, |
365 | 3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31}; |
366 | |
367 | const int ic_block = conf_->ic_block; |
368 | const bool is_layout_nxc = utils::one_of(conf_->src_tag, format_tag::ndhwc, |
369 | format_tag::nhwc, format_tag::nwc); |
370 | const size_t src_mult |
371 | = is_layout_nxc ? conf_->ngroups * conf_->ic : ic_block; |
372 | const int iw = conf_->iw; |
373 | const int tr_iw = conf_->tr_iw; |
374 | const int str_w = conf_->stride_w; |
375 | assert(tr_iw % str_w == 0); |
376 | const int tr_iw_s = tr_iw / str_w; |
377 | assert(transpose_size >= ic_block); |
378 | |
379 | auto kmovw = [=](Opmask k, unsigned w) { |
380 | mov(regw_tmp, w); |
381 | jit_generator::kmovw(k, regw_tmp); |
382 | }; |
383 | |
384 | kmovw(kFFFF, 0xffff); |
385 | kmovw(k5555, 0x5555); |
386 | kmovw(kAAAA, 0xaaaa); |
387 | kmovw(kAA, 0xaa); |
388 | kmovw(k55, 0x55); |
389 | kmovw(kCC, 0xcc); |
390 | kmovw(k33, 0x33); |
391 | |
392 | auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { |
393 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
394 | jit_generator::vmovdqa64(z, ptr[imm_addr64]); |
395 | }; |
396 | |
397 | auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { |
398 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
399 | jit_generator::vmovdqa32(z, ptr[imm_addr64]); |
400 | }; |
401 | |
402 | vmovdqa64(vidx1, idx1); |
403 | vmovdqa64(vidx2, idx2); |
404 | vmovdqa32(vidx3, idx3); |
405 | vmovdqa32(vidx4, idx4); |
406 | vmovdqa32(vidx5, (const int32_t *)idx5); |
407 | |
    // Data for each stride phase is placed consecutively.
    // For 1x1 convolutions with strides we transpose only the needed
    // elements.
410 | const auto str_w_end = (conf_->kw == 1) ? 1 : str_w; |
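    // Roughly, each iteration s below handles one stride residue class of
    // input columns; left/right padding is recomputed per phase and each
    // phase writes a tr_iw_s-wide slice of the transposed buffer.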
411 | for (int s = 0; s < str_w_end; s++) { |
412 | const int left_pad = div_up(conf_->l_pad - s, str_w); |
413 | const int iw1 = iw + conf_->l_pad; |
414 | const int iw_s = (s < (iw1 % str_w) ? div_up(iw1, str_w) : iw1 / str_w) |
415 | - left_pad; |
416 | const int right_pad = tr_iw_s - iw_s - left_pad; |
417 | |
418 | const int transposes = utils::div_up(iw_s, transpose_size); |
419 | int loop_iters = nstl::max(0, transposes - 1); |
420 | tail = iw_s - loop_iters * transpose_size; |
421 | |
422 | src_stride = src_mult * typesize * str_w; |
423 | tr_src_stride = tr_iw * typesize; |
424 | |
425 | bool nontemporal_stores = false; |
        enable_prefetch = iw > small_spatial;
427 | |
428 | const size_t src_step = src_mult * transpose_size * str_w * typesize; |
429 | const size_t tr_src_step = transpose_size * typesize; |
430 | |
431 | mov(reg_src, ptr[param1 + GET_OFF(src)]); |
432 | mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
433 | mov(reg_src_prf, ptr[param1 + GET_OFF(src_prf)]); |
434 | mov(reg_tr_src_prf, ptr[param1 + GET_OFF(tr_src_prf)]); |
435 | |
436 | if (str_w > 1) { |
437 | int tr_src_shift = s; |
438 | int src_shift = (str_w - (conf_->l_pad % str_w) + s) % str_w; |
439 | add(reg_src, src_shift * src_mult * typesize); |
440 | add(reg_tr_src, tr_src_shift * tr_iw_s * typesize); |
441 | add(reg_src_prf, src_shift * src_mult * typesize); |
442 | add(reg_tr_src_prf, tr_src_shift * tr_iw_s * typesize); |
443 | } |
444 | |
445 | if (left_pad > 0 && loop_iters > 0) { |
446 | loop_iters--; |
447 | transpose(transpose_size, left_pad, 0, nontemporal_stores); |
448 | add(reg_src, src_step); |
449 | add(reg_tr_src, tr_src_step + left_pad * typesize); |
450 | add(reg_src_prf, src_step); |
451 | add(reg_tr_src_prf, tr_src_step + left_pad * typesize); |
452 | } |
453 | |
454 | if (loop_iters) { |
455 | mov(reg_loop, loop_iters); |
456 | Label loop; |
457 | L(loop); |
458 | { |
459 | transpose(transpose_size, 0, 0, nontemporal_stores); |
460 | add(reg_src, src_step); |
461 | add(reg_tr_src, tr_src_step); |
462 | add(reg_src_prf, src_step); |
463 | add(reg_tr_src_prf, tr_src_step); |
464 | sub(reg_loop, 1); |
465 | jnz(loop); |
466 | } |
467 | } |
468 | if (transposes > 1) |
469 | transpose(tail, 0, right_pad, nontemporal_stores); |
470 | else |
471 | transpose(tail, left_pad, right_pad, nontemporal_stores); |
472 | } |
473 | postamble(); |
474 | } |
475 | |
476 | struct jit_trans_ow_oc_t : public jit_trans_dst_t, public jit_generator { |
477 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_ow_oc_t) |
478 | jit_trans_ow_oc_t(const jit_conv_conf_t *conf) |
479 | : jit_trans_dst_t(conf), jit_generator(jit_name()) {} |
480 | |
481 | void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } |
482 | |
483 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
484 | |
485 | private: |
486 | using reg64_t = const Xbyak::Reg64; |
487 | using reg32_t = const Xbyak::Reg32; |
488 | using opmask_t = const Xbyak::Opmask; |
489 | using zmm = const Xbyak::Zmm; |
490 | |
491 | enum { |
492 | typesize = sizeof(int16_t), |
493 | transpose_size = 16, |
494 | small_spatial = 14 |
495 | }; |
496 | size_t src_stride = 0, tr_src_stride = 0; |
497 | int tail = 0; |
498 | bool enable_prefetch = false; |
499 | |
500 | opmask_t kFF = k1; |
501 | opmask_t mask_lo = k2; |
502 | opmask_t k_oc_tail = k3; |
503 | |
504 | zmm vidx1 = zmm31; |
505 | zmm vidx2 = zmm30; |
506 | |
507 | reg64_t reg_src = r8; |
508 | reg64_t reg_tr_src = r9; |
509 | reg64_t reg_src_prf = r10; |
510 | reg64_t reg_tr_src_prf = r11; |
511 | reg64_t reg_loop = r12; |
512 | reg64_t reg_tr_src_tmp = r13; |
513 | reg32_t regw_tmp = r14d; |
514 | reg64_t imm_addr64 = rbx; |
515 | |
516 | void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores, |
517 | bool do_convert = true); |
518 | void generate() override; |
519 | }; |
520 | |
// do_convert (default 'true') selects whether the input data is transformed
// or the output data is simply zero-filled.
523 | void jit_trans_ow_oc_t::transpose(int nrows, int l_pad, int r_pad, |
524 | bool nontemporal_stores, bool do_convert) { |
525 | assert(nrows >= 0 && nrows <= transpose_size); |
    static_assert(transpose_size == 16, "Unsupported transpose size");
527 | if (!nrows) return; |
528 | |
529 | auto src_zmm = [=](int i) { return Zmm(i); }; |
530 | |
531 | auto src_ymm = [=](int i) { |
532 | assert(i >= 0 && i < 16); |
533 | return Ymm(i); |
534 | }; |
535 | |
536 | auto load_ymm = [=](int i) { |
537 | auto ymm_reg = src_ymm(i); |
538 | auto addr = EVEX_compress_addr(reg_src, i * src_stride); |
539 | if (conf_->oc_tail) { |
540 | ymm_reg = ymm_reg | k_oc_tail | T_z; |
            // The assertion below is needed because handling oc tails
            // requires vmovdqu16. If needed, it can be removed by using the
            // load_bytes() helper.
543 | assert(mayiuse(avx512_core)); |
544 | vmovdqu16(ymm_reg, addr); |
545 | } else { |
546 | vmovups(ymm_reg, addr); |
547 | } |
548 | }; |
549 | |
550 | auto store = [=](Zmm r, int i) { |
551 | auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride); |
552 | if (nontemporal_stores) |
553 | vmovntps(addr, r); |
554 | else |
555 | vmovups(addr, r); |
556 | }; |
557 | const bool is_layout_nxc = utils::one_of(conf_->dst_tag, format_tag::ndhwc, |
558 | format_tag::nhwc, format_tag::nwc); |
559 | |
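    // Fast path below: with the blocked layout (src_stride = oc_block *
    // typesize bytes), one zmm load covers two consecutive ow rows and a
    // single vpermw with vidx2 interleaves them; the nxc/tail path instead
    // combines ymm loads via vpunpck{l,h}wd + vinserti64x4.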
560 | if (mayiuse(avx512_core) && !is_layout_nxc) { |
        // TODO: adapt for nhwc?
562 | for (int i = 0; i < nrows / 2; i++) { |
563 | auto zmm_src0 = src_zmm(i); |
564 | if (do_convert) { |
565 | vmovdqu16(zmm_src0, |
566 | EVEX_compress_addr(reg_src, 2 * i * src_stride)); |
567 | vpermw(zmm_src0, vidx2, zmm_src0); |
568 | } else { |
569 | vpxord(zmm_src0, zmm_src0, zmm_src0); |
570 | } |
571 | store(zmm_src0, 2 * i); |
572 | } |
573 | if (r_pad > 0) { |
574 | auto zmm_src0 = src_zmm(29); |
575 | if (do_convert) { |
576 | vmovdqu16(zmm_src0 | mask_lo | T_z, |
577 | EVEX_compress_addr(reg_src, (nrows - 1) * src_stride)); |
578 | vpermw(zmm_src0, vidx2, zmm_src0); |
579 | } else { |
580 | vpxord(zmm_src0, zmm_src0, zmm_src0); |
581 | } |
582 | store(zmm_src0, nrows - 1); |
583 | } |
584 | } else { |
585 | for (int i = 0; i < nrows / 2; i++) { |
586 | auto src0 = src_ymm(2 * i); |
587 | auto src1 = src_ymm(2 * i + 1); |
588 | auto zmm_src0 = src_zmm(2 * i); |
589 | if (do_convert) { |
590 | load_ymm(2 * i); |
591 | if (is_layout_nxc && conf_->oc_tail) { |
592 | load_ymm(2 * i + 1); |
593 | auto ymm_tmp = Ymm(30); |
594 | vpunpcklwd(ymm_tmp, src0, src1); |
595 | vpunpckhwd(src0, src0, src1); |
596 | vinserti64x4(zmm_src0, zmm_src0, ymm_tmp, 1); |
597 | } else { |
598 | vpunpcklwd(src1, src0, |
599 | EVEX_compress_addr( |
600 | reg_src, (2 * i + 1) * src_stride)); |
601 | vpunpckhwd(src0, src0, |
602 | EVEX_compress_addr( |
603 | reg_src, (2 * i + 1) * src_stride)); |
604 | vinserti64x4(zmm_src0, zmm_src0, src1, 1); |
605 | } |
606 | vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); |
607 | } else { |
608 | vpxord(zmm_src0, zmm_src0, zmm_src0); |
609 | } |
610 | store(zmm_src0, 2 * i); |
611 | } |
612 | if (r_pad > 0) { |
613 | auto src0 = src_ymm(nrows - 1); |
614 | auto src1 = src_ymm(nrows); |
615 | auto zmm_src0 = src_zmm(30); |
616 | if (do_convert) { |
617 | load_ymm(nrows - 1); |
618 | |
619 | vpxor(src1, src1, src1); |
620 | vpunpckhwd(src1, src0, src1); |
621 | vinserti64x4(zmm_src0, zmm_src0, src1, 0); |
622 | vpxor(src1, src1, src1); |
623 | vpunpcklwd(src0, src0, src1); |
624 | vinserti64x4(zmm_src0, zmm_src0, src0, 1); |
625 | vpermpd(zmm_src0 | kFF, vidx1, zmm_src0); |
626 | } else { |
627 | vpxord(zmm_src0, zmm_src0, zmm_src0); |
628 | } |
629 | store(zmm_src0, nrows - 1); |
630 | } |
631 | } |
632 | } |
633 | |
634 | void jit_trans_ow_oc_t::generate() { |
635 | preamble(); |
636 | |
637 | alignas(64) static constexpr const int64_t idx1[8] |
638 | = {4, 5, 0, 1, 6, 7, 2, 3}; |
639 | alignas(64) static constexpr const int16_t idx2[32] |
640 | = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, |
641 | 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; |
642 | |
643 | const int oc_block = conf_->oc_block; |
644 | const bool is_layout_nxc = utils::one_of(conf_->dst_tag, format_tag::ndhwc, |
645 | format_tag::nhwc, format_tag::nwc); |
646 | const size_t src_mult |
647 | = is_layout_nxc ? conf_->ngroups * conf_->oc : oc_block; |
648 | const int ow = conf_->ow; |
649 | const int transposes = utils::div_up(ow, transpose_size); |
650 | int loop_iters = nstl::max(0, transposes - 1); |
651 | tail = ow - loop_iters * transpose_size; |
652 | |
653 | src_stride = src_mult * typesize; |
654 | tr_src_stride = oc_block * typesize; |
655 | |
656 | bool nontemporal_stores = conf_->use_nt_stores_ddst; |
657 | enable_prefetch = ow > small_spatial; |
658 | |
659 | const size_t src_step = src_mult * transpose_size * typesize; |
660 | const size_t tr_src_step = (size_t)oc_block * transpose_size * typesize; |
661 | const int right_pad = ow % 2; |
662 | |
663 | const auto zero_tr_ow = nstl::max(0, conf_->tr_ow - ow - right_pad); |
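    // Columns [ow + right_pad, tr_ow) of the transposed buffer carry no
    // source data; they are zero-filled below by extra transpose() passes
    // with do_convert = false.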
664 | |
665 | mov(reg_src, ptr[param1 + GET_OFF(src)]); |
666 | mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); |
667 | mov(reg_src_prf, ptr[param1 + GET_OFF(src_prf)]); |
668 | mov(reg_tr_src_prf, ptr[param1 + GET_OFF(tr_src_prf)]); |
669 | |
670 | auto kmovw = [=](Opmask k, unsigned w) { |
671 | mov(regw_tmp, w); |
672 | jit_generator::kmovw(k, regw_tmp); |
673 | }; |
674 | auto kmovd = [=](Opmask k, unsigned w) { |
675 | mov(regw_tmp, w); |
676 | jit_generator::kmovd(k, regw_tmp); |
677 | }; |
678 | |
679 | kmovw(kFF, 0xFF); |
680 | kmovd(mask_lo, 0x0000ffff); |
681 | |
682 | if (is_layout_nxc && conf_->oc_tail) { |
683 | Label done; |
684 | kxnorw(k_oc_tail, k_oc_tail, k_oc_tail); |
685 | cmp(dword[param1 + GET_OFF(ch_work)], conf_->oc_block); |
686 | je(done, T_NEAR); |
687 | kmovw(k_oc_tail, (1 << conf_->oc_tail) - 1); |
688 | L(done); |
689 | } |
690 | |
691 | auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { |
692 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
693 | jit_generator::vmovdqa64(z, ptr[imm_addr64]); |
694 | }; |
695 | |
696 | vmovdqa64(vidx1, idx1); |
697 | vmovdqa64(vidx2, (const int64_t *)idx2); |
698 | if (loop_iters) { |
699 | mov(reg_loop, loop_iters); |
700 | Label loop; |
701 | L(loop); |
702 | { |
703 | transpose(transpose_size, 0, 0, nontemporal_stores); |
704 | add(reg_src, src_step); |
705 | add(reg_tr_src, tr_src_step); |
706 | add(reg_src_prf, src_step); |
707 | add(reg_tr_src_prf, tr_src_step); |
708 | sub(reg_loop, 1); |
709 | jnz(loop); |
710 | } |
711 | } |
712 | transpose(tail, 0, right_pad, nontemporal_stores); |
713 | if (zero_tr_ow) { |
714 | const auto zero_transposes = utils::div_up(zero_tr_ow, transpose_size); |
715 | const auto zero_loop_iters = nstl::max(0, zero_transposes - 1); |
716 | const auto zero_tail = zero_tr_ow - zero_loop_iters * transpose_size; |
717 | const auto zero_right_pad = zero_tr_ow % 2; |
718 | |
719 | // shift over tail |
720 | auto tr_src_tail_step |
721 | = (size_t)oc_block * (tail + right_pad) * typesize; |
722 | add(reg_tr_src, tr_src_tail_step); |
723 | add(reg_tr_src_prf, tr_src_tail_step); |
724 | |
        // zero-fill the remaining tr_ow - ow columns
726 | if (zero_loop_iters) { |
727 | mov(reg_loop, zero_loop_iters); |
728 | Label zero_loop; |
729 | L(zero_loop); |
730 | { |
731 | transpose(transpose_size, 0, 0, nontemporal_stores, false); |
732 | add(reg_tr_src, tr_src_step); |
733 | add(reg_tr_src_prf, tr_src_step); |
734 | sub(reg_loop, 1); |
735 | jnz(zero_loop); |
736 | } |
737 | } |
738 | transpose(zero_tail, 0, zero_right_pad, nontemporal_stores, false); |
739 | } |
740 | |
741 | postamble(); |
742 | } |
743 | |
744 | /* |
745 | // ------------------------------------------------- |
746 | // jit_transpose4x16_src |
747 | // ------------------------------------------------- |
748 | */ |
749 | |
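// Transposes an (nrows x 16) f32 tile held in zmm0..zmm3: masked vpermpd /
// valignd / vpermps steps rearrange the four rows, with software prefetches
// (pf_*) interleaved between the shuffles.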
750 | void jit_transpose4x16_src::transpose(int nrows) { |
751 | assert(nrows >= 0 && nrows <= transpose_size); |
    static_assert(transpose_size == 4, "Unsupported transpose size");
753 | if (!nrows) return; |
754 | |
755 | auto pf_src_t0 = [=](int i) { |
756 | if (tparams->src_pf0_distance) |
757 | prefetcht0(EVEX_compress_addr( |
758 | reg_src, (tparams->src_pf0_distance + i) * src_stride)); |
759 | }; |
760 | |
761 | auto pf_tr_src_t0 = [=](int i) { |
762 | if (tparams->tr_src_pf0_distance) |
763 | prefetcht0(EVEX_compress_addr(reg_tr_src, |
764 | (tparams->tr_src_pf0_distance + i) * src_stride)); |
765 | }; |
766 | |
767 | auto pf_src_t1 = [=](int i) { |
768 | if (tparams->src_pf1) |
769 | prefetcht1(EVEX_compress_addr(reg_src_prf, i * src_stride)); |
770 | }; |
771 | |
772 | auto pf_tr_src_t1 = [=](int i) { |
773 | if (tparams->tr_src_pf1) |
774 | prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, i * tr_src_stride)); |
775 | }; |
776 | |
777 | auto src_zmm = [=](int i) { |
778 | assert(i >= 0 && i < 4); |
779 | return Zmm(i); |
780 | }; |
781 | |
782 | auto tmp_zmm = [=](int i) { |
783 | assert(i >= 0 && i < 4); |
784 | return Zmm(4 + i); |
785 | }; |
786 | |
787 | auto load = [=](int i) { |
788 | vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride)); |
789 | }; |
790 | |
791 | auto store = [=](Zmm r, int i) { |
792 | vmovups(EVEX_compress_addr(reg_tr_src, i * tr_src_stride), r); |
793 | }; |
794 | |
795 | auto tmp0 = tmp_zmm(0); |
796 | auto tmp1 = tmp_zmm(1); |
797 | auto tmp2 = tmp_zmm(2); |
798 | auto tmp3 = tmp_zmm(3); |
799 | |
800 | auto src0 = src_zmm(0); |
801 | auto src1 = src_zmm(1); |
802 | auto src2 = src_zmm(2); |
803 | auto src3 = src_zmm(3); |
804 | for (int i = 0; i < nrows; i++) { |
805 | load(i); |
806 | } |
807 | |
    for (int i = nrows; i < 4; i++) {
809 | vpxord(src_zmm(i), src_zmm(i), src_zmm(i)); |
810 | } |
811 | |
812 | vmovupd(tmp0, src0); |
813 | vmovupd(tmp1, src1); |
814 | pf_src_t0(0); |
815 | vpermpd(tmp0 | kF0, vidx01, src2); |
816 | vpermpd(tmp1 | kF0, vidx01, src3); |
817 | |
818 | valignd(src0, src0, src0, 8); |
819 | valignd(src1, src1, src1, 8); |
820 | pf_src_t0(1); |
821 | vmovupd(tmp2, src0); |
822 | vmovupd(tmp3, src1); |
823 | pf_src_t0(2); |
824 | vpermpd(tmp2 | kF0, vidx10, src2); |
825 | vpermpd(tmp3 | kF0, vidx10, src3); |
826 | pf_src_t0(3); |
827 | |
828 | vmovupd(src0, tmp0); |
829 | pf_src_t1(0); |
830 | vmovupd(src1, tmp2); |
831 | pf_src_t1(1); |
832 | vmovupd(src2, tmp1); |
833 | pf_src_t1(2); |
834 | vmovupd(src3, tmp3); |
835 | pf_src_t1(3); |
836 | vpermpd(src0 | kCC, vidx1, tmp1); |
837 | vpermpd(src1 | kCC, vidx1, tmp3); |
838 | pf_tr_src_t0(0); |
839 | vpermpd(src2 | k33, vidx1, tmp0); |
840 | vpermpd(src3 | k33, vidx1, tmp2); |
841 | pf_tr_src_t0(1); |
842 | |
843 | vmovupd(tmp0, src0); |
844 | vmovupd(tmp1, src2); |
845 | pf_tr_src_t0(2); |
846 | vmovupd(tmp2, src1); |
847 | vmovupd(tmp3, src3); |
848 | pf_tr_src_t0(3); |
849 | vpermps(tmp0 | kFFFF, vidxP, src0); |
850 | pf_tr_src_t1(0); |
851 | vpermps(tmp1 | kFFFF, vidxP, src2); |
852 | pf_tr_src_t1(1); |
853 | vpermps(tmp2 | kFFFF, vidxP, src1); |
854 | pf_tr_src_t1(3); |
855 | vpermps(tmp3 | kFFFF, vidxP, src3); |
856 | pf_tr_src_t1(4); |
857 | |
858 | store(tmp0, 0); |
859 | store(tmp1, 1); |
860 | store(tmp2, 2); |
861 | store(tmp3, 3); |
862 | } |
863 | |
864 | alignas(64) static constexpr const int64_t idx01[8] = {0, 0, 0, 0, 0, 1, 2, 3}; |
865 | alignas(64) static constexpr const int64_t idx10[8] = {0, 0, 0, 0, 4, 5, 6, 7}; |
866 | alignas(64) static constexpr const int64_t idx1[8] = {2, 3, 0, 1, 6, 7, 4, 5}; |
867 | alignas(64) static constexpr const int32_t idxP[16] |
868 | = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; |
869 | |
870 | void jit_transpose4x16_src::generate() { |
871 | preamble(); |
872 | |
873 | const int ic_block = params->ic_block; |
874 | const int is = params->is; |
875 | int tail = is % transpose_size; |
876 | |
877 | src_stride = ic_block * typesize; |
878 | assert(src_stride == 64); |
879 | tr_src_stride = ic_block * typesize; |
880 | |
881 | const int src_step = ic_block * transpose_size * typesize; |
882 | const int tr_src_step = ic_block * transpose_size * typesize; |
883 | |
884 | #define GET_TR_OFF(x) offsetof(jit_src_transpose_s, x) |
885 | mov(reg_loop, ptr[param1 + GET_TR_OFF(size)]); |
886 | mov(reg_src, ptr[param1 + GET_TR_OFF(src)]); |
887 | mov(reg_tr_src, ptr[param1 + GET_TR_OFF(tr_src)]); |
888 | mov(reg_src_prf, ptr[param1 + GET_TR_OFF(src_prf)]); |
889 | mov(reg_tr_src_prf, ptr[param1 + GET_TR_OFF(tr_src_prf)]); |
890 | #undef GET_TR_OFF |
891 | |
892 | auto kmovw = [=](Opmask k, unsigned w) { |
893 | mov(regw_tmp, w); |
894 | jit_generator::kmovw(k, regw_tmp); |
895 | }; |
896 | |
897 | auto vmovdqa64 = [=](Zmm z, const int64_t *addr) { |
898 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
899 | jit_generator::vmovdqa64(z, ptr[imm_addr64]); |
900 | }; |
901 | |
902 | auto vmovdqa32 = [=](Zmm z, const int32_t *addr) { |
903 | mov(imm_addr64, reinterpret_cast<size_t>(addr)); |
904 | jit_generator::vmovdqa32(z, ptr[imm_addr64]); |
905 | }; |
906 | |
907 | kmovw(kF0, 0xf0); // 11110000 |
908 | kmovw(kCC, 0xcc); // 11001100 |
909 | kmovw(k33, 0x33); // 00110011 |
910 | kmovw(kFFFF, 0xffff); // 1111111111111111 |
911 | |
912 | vmovdqa64(vidx01, idx01); |
913 | vmovdqa64(vidx10, idx10); |
914 | vmovdqa64(vidx1, idx1); |
915 | vmovdqa32(vidxP, idxP); |
916 | |
917 | Label loop_label; |
918 | Label tail_label; |
919 | |
920 | cmp(reg_loop, transpose_size); |
921 | jl(tail_label, T_NEAR); |
922 | |
923 | L(loop_label); |
924 | { |
925 | transpose(transpose_size); |
926 | add(reg_src, src_step); |
927 | add(reg_tr_src, tr_src_step); |
928 | add(reg_src_prf, src_step); |
929 | add(reg_tr_src_prf, tr_src_step); |
930 | sub(reg_loop, transpose_size); |
931 | cmp(reg_loop, transpose_size); |
932 | jge(loop_label, T_NEAR); |
933 | } |
934 | L(tail_label); |
935 | transpose(tail); |
936 | |
937 | postamble(); |
938 | } |
939 | |
940 | #undef GET_OFF |
941 | |
942 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
943 | |
944 | void jit_diff_wei_trans_to_vnni_t::generate() { |
    /* Reorder a part of the f32 weights tensor
       from [2I][kd][kh][kw][16i][16o] to VNNI format
       [kd][kh][kw][16i][16o][2i] and down-convert it to bf16 or f16. */
948 | const int typesize_out = 2; |
949 | const int typesize_acc = 4; |
950 | const int simd_w = 16; |
951 | |
952 | using reg64_t = const Xbyak::Reg64; |
953 | const reg64_t ®_output = r15; |
954 | const reg64_t &org_reg_output = r14; |
955 | const reg64_t ®_input = r13; |
956 | const reg64_t ®_input_1 = r12; |
957 | const reg64_t &org_reg_input_1 = r11; |
958 | const reg64_t ®_input_2 = r10; |
959 | const reg64_t ®_prm_table = r9; |
960 | const reg64_t ®_last_ic_block = rax; |
961 | const reg64_t ®_kd = rsi; |
962 | const reg64_t ®_kh = abi_not_param1; |
963 | const reg64_t ®_tmp = rdx; |
964 | |
965 | const Xbyak::Zmm &zmm_idx = Xbyak::Zmm(31); |
966 | auto get_zmm_src_0 = [&](int ic) { return Xbyak::Zmm(ic); }; |
967 | auto get_zmm_src_1 = [&](int ic) { return Xbyak::Zmm(4 + ic); }; |
968 | auto get_zmm_bf16 = [&](int ic) { return Xbyak::Zmm(8 + ic); }; |
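    // Per ic iteration, Zmm(ic) and Zmm(ic + 4) hold the two f32 source
    // rows, Zmm(ic + 8) receives the converted result, and zmm31 keeps the
    // vpermw index table for the whole kernel.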
969 | |
970 | Xbyak::Label prm_table, zero_buffer; |
971 | Xbyak::Label kd_loop_label, kh_loop_label; |
972 | |
973 | preamble(); |
974 | |
975 | mov(reg_last_ic_block, ptr[abi_param1 + GET_OFF(last_ic_block)]); |
976 | mov(org_reg_input_1, ptr[abi_param1 + GET_OFF(src)]); |
977 | mov(org_reg_output, ptr[abi_param1 + GET_OFF(dst)]); |
978 | |
979 | mov(reg_prm_table, prm_table); |
980 | vmovups(zmm_idx, ptr[reg_prm_table]); |
981 | |
982 | xor_(reg_kd, reg_kd); |
983 | L(kd_loop_label); |
984 | { |
985 | mov(reg_output, org_reg_output); |
986 | mov(reg_input_1, org_reg_input_1); |
987 | xor_(reg_kh, reg_kh); |
988 | L(kh_loop_label); |
989 | { |
990 | for (int kw = 0; kw < kw_; kw++) { |
991 | Xbyak::Label last_ic_label, done_ic_label; |
992 | |
993 | dim_t out_offset |
994 | = (dim_t)typesize_out * kw * ic_block_ * oc_block_ * 2; |
995 | dim_t inp_1_offset |
996 | = (dim_t)typesize_acc * kw * ic_block_ * oc_block_; |
997 | dim_t inp_2_offset = (dim_t)typesize_acc |
998 | * (kd_ * kh_ * kw_ * ic_block_ * oc_block_ |
999 | + kw * ic_block_ * oc_block_); |
1000 | |
1001 | cmp(reg_last_ic_block, 0); |
1002 | jne(last_ic_label, T_NEAR); |
1003 | |
1004 | mov(reg_input_2, reg_input_1); |
1005 | safe_add(reg_input_2, inp_2_offset, reg_tmp); |
1006 | jmp(done_ic_label, T_NEAR); |
1007 | |
1008 | L(last_ic_label); |
1009 | mov(reg_input_2, zero_buffer); |
1010 | |
1011 | L(done_ic_label); |
1012 | |
1013 | for (int ocb = 0; ocb < oc_block_; ocb += simd_w) { |
1014 | int ic_count = 0; |
1015 | for (int bc = 0; bc < 2; bc++) { |
1016 | if (!bc) { |
1017 | mov(reg_input, reg_input_1); |
1018 | safe_add(reg_input, inp_1_offset, reg_tmp); |
1019 | } else |
1020 | mov(reg_input, reg_input_2); |
1021 | |
1022 | for (int ic = 0; ic < ic_block_ / 2; ic++) { |
1023 | auto zmm_src_0 = get_zmm_src_0(ic); |
1024 | auto zmm_src_1 = get_zmm_src_1(ic); |
1025 | auto zmm_out = get_zmm_bf16(ic); |
1026 | |
1027 | vmovups(zmm_src_0, |
1028 | ptr[reg_input |
1029 | + typesize_acc |
1030 | * ((2 * ic + 0) * oc_block_ |
1031 | + ocb)]); |
1032 | vmovups(zmm_src_1, |
1033 | ptr[reg_input |
1034 | + typesize_acc |
1035 | * ((2 * ic + 1) * oc_block_ |
1036 | + ocb)]); |
1037 | if (out_dt_ == data_type::bf16) { |
1038 | vcvtne2ps2bf16(zmm_out, zmm_src_1, zmm_src_0); |
1039 | } else if (out_dt_ == data_type::f16) { |
1040 | vcvtps2phx(Ymm(zmm_src_0.getIdx()), zmm_src_0); |
1041 | vcvtps2phx(Ymm(zmm_src_1.getIdx()), zmm_src_1); |
1042 | vinsertf32x8(zmm_out, zmm_src_0, |
1043 | Ymm(zmm_src_1.getIdx()), 1); |
1044 | } else { |
1045 | assert(!"unsupported data type" ); |
1046 | } |
1047 | vpermw(zmm_out, zmm_idx, zmm_out); |
1048 | |
1049 | vmovups(ptr[reg_output + out_offset |
1050 | + typesize_out |
1051 | * (ic_count * oc_block_ * 2 |
1052 | + ocb * 2)], |
1053 | zmm_out); |
1054 | ic_count++; |
1055 | } |
1056 | } |
1057 | } |
1058 | } |
1059 | safe_add(reg_output, |
1060 | (dim_t)typesize_out * kw_ * 2 * ic_block_ * oc_block_, |
1061 | reg_tmp); |
1062 | safe_add(reg_input_1, |
1063 | (dim_t)typesize_acc * kw_ * ic_block_ * oc_block_, reg_tmp); |
1064 | |
1065 | add(reg_kh, 1); |
1066 | cmp(reg_kh, kh_); |
1067 | jl(kh_loop_label, T_NEAR); |
1068 | } |
1069 | safe_add(org_reg_output, |
1070 | (dim_t)typesize_out * kh_ * kw_ * 2 * ic_block_ * oc_block_, |
1071 | reg_tmp); |
1072 | safe_add(org_reg_input_1, |
1073 | (dim_t)typesize_acc * kh_ * kw_ * ic_block_ * oc_block_, |
1074 | reg_tmp); |
1075 | |
1076 | add(reg_kd, 1); |
1077 | cmp(reg_kd, kd_); |
1078 | jl(kd_loop_label, T_NEAR); |
1079 | } |
1080 | |
1081 | postamble(); |
1082 | |
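    // prm_table is the same word-interleave pattern (0, 16, 1, 17, ...)
    // used by the transpose kernels above: vpermw with it zips the even-ic
    // and odd-ic halves of zmm_out into [16o][2i] VNNI pairs.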
1083 | align(64); |
1084 | L(prm_table); |
1085 | const uint16_t prm_array[32] |
1086 | = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, |
1087 | 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}; |
1088 | for (size_t i = 0; i < 32; ++i) |
1089 | dw(prm_array[i]); |
1090 | |
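    // zero_buffer stands in for the second [I] chunk of the [2I] source
    // tensor when last_ic_block != 0, so the interleave reads zeros instead
    // of out-of-bounds data.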
1091 | align(64); |
1092 | L(zero_buffer); |
1093 | const uint16_t zero = 0; |
1094 | for (int i = 0; i < typesize_acc * oc_block_ * ic_block_; ++i) |
1095 | db(zero); |
1096 | } |
1097 | |
1098 | #undef GET_OFF |
1099 | |
1100 | jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf) { |
1101 | if (conf->has_vnni && IMPLICATION(conf->is_1stconv, conf->transpose_src)) |
1102 | return new jit_trans_iw_ic_int16_t(conf); |
1103 | assert(!"unsupported configuration" ); |
1104 | return nullptr; |
1105 | } |
1106 | |
1107 | jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf) { |
1108 | if (conf->has_vnni) return new jit_trans_ow_oc_t(conf); |
1109 | assert(!"unsupported configuration" ); |
1110 | return nullptr; |
1111 | } |
1112 | } // namespace x64 |
1113 | } // namespace cpu |
1114 | } // namespace impl |
1115 | } // namespace dnnl |
1116 | |