/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "cpu/x64/jit_avx512_core_f32_wino_conv_2x3.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::format_kind;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace Xbyak;

/// SRC TRANSFORMS /////////////////////////////////////////////////////////////
struct jit_avx512_core_f32_wino_conv_2x3_src_trans_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_wino_conv_2x3_src_trans_t)

    jit_conv_conf_2x3_wino_t jcp;

    struct call_params_t {
        const void *src;
        const void *wino_src;
        const void *v_y_masks;
        const void *v_x_masks;
    };

    jit_avx512_core_f32_wino_conv_2x3_src_trans_t(
            const jit_conv_conf_2x3_wino_t &ajcp, const primitive_attr_t &attr)
        : jit_generator(jit_name()), jcp(ajcp) {}

    void generate() override;
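
    // Register layout used by generate(): the 16 values of the 4x4 input tile
    // live in zmm31..zmm16 (vreg_inp), the row-pass temporaries in
    // zmm15..zmm0 (vreg_tmp), and vreg_out() deliberately aliases vreg_inp()
    // so the column pass overwrites input values that are no longer needed.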
    Zmm vreg_inp(int i) const {
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(31 - i);
    }

    Zmm vreg_tmp(int i) const {
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(15 - i);
    }

    Zmm vreg_out(int i) const {
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(31 - i);
    }

    Opmask y_mask = Opmask(1);
    Opmask r_mask = Opmask(2);
    Opmask x_mask(int id) {
        assert(id < 4);
        return Opmask(3 + id);
    }

    Reg64 reg_ptr_v_y_masks = r12;
    Reg64 reg_ptr_v_x_masks = r11;

    Reg64 reg_aux_ptr_src = r10;
    Reg64 reg_aux_ptr_dst = r9;

    Reg64 reg_ic_block = r8;
};

void jit_avx512_core_f32_wino_conv_2x3_src_trans_t::generate() {
    Label ic_block_label;

    const int load_block = 16;
    int out_offset = 0, inp_offset = 0;
    preamble();

#define READ_PARAM(reg, field) \
    mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_aux_ptr_src, src);
    READ_PARAM(reg_aux_ptr_dst, wino_src);
    READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
    READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
#undef READ_PARAM

    for (int i = 0; i < jcp.alpha; i++) {
        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
    }
    mov(reg_ic_block, jcp.ic / load_block);
    L(ic_block_label);
    {
        for (int y = 0; y < jcp.alpha; y++) {
            kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
            for (int x = 0; x < jcp.alpha; x++) {
                Zmm zmm = vreg_inp(y * jcp.alpha + x);

                vxorps(zmm, zmm, zmm);
                kandw(r_mask, y_mask, x_mask(x));
                inp_offset = sizeof(float)
                        * ((-jcp.t_pad + y) * jcp.iw * load_block
                                + (-jcp.l_pad + x) * load_block);
                vmovups(zmm | r_mask,
                        EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
            }
        }
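        // Input transform for Winograd F(2x2, 3x3): V = B^T * d * B, done in
        // two passes: the first loop combines elements along x within each
        // row (d * B), the second along y within each column (B^T * (d * B)).
        // The coefficients correspond to the usual
        // B^T = [[1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0], [0, 1, 0, -1]]
        // (comment added as a reading aid).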
        for (int y = 0; y < jcp.alpha; y++) {
            vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0),
                    vreg_inp(y * jcp.alpha + 2));
            vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1),
                    vreg_inp(y * jcp.alpha + 2));
            vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2),
                    vreg_inp(y * jcp.alpha + 1));
            vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1),
                    vreg_inp(y * jcp.alpha + 3));
        }
        for (int x = 0; x < jcp.alpha; x++) {
            vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0),
                    vreg_tmp(x + jcp.alpha * 2));
            vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
                    vreg_tmp(x + jcp.alpha * 2));
            vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2),
                    vreg_tmp(x + jcp.alpha * 1));
            vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
                    vreg_tmp(x + jcp.alpha * 3));
        }

        for (int i = 0; i < 16; i++) {
            out_offset = sizeof(float) * (jcp.inp_stride * i);
            vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset),
                    vreg_out(i));
        }

        add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block);
        add(reg_aux_ptr_dst, sizeof(float) * load_block);
    }
    dec(reg_ic_block);
    cmp(reg_ic_block, 0);
    jg(ic_block_label, T_NEAR);
    postamble();
}

/// DST TRANSFORMS /////////////////////////////////////////////////////////////
struct jit_avx512_core_f32_wino_conv_2x3_dst_trans_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_wino_conv_2x3_dst_trans_t)

    jit_conv_conf_2x3_wino_t jcp;
    const primitive_attr_t &attr_;

    struct call_params_t {
        const void *wino_dst;
        const void *dst;
        const void *v_y_masks;
        const void *v_x_masks;

        const void *bias;
    };

    jit_avx512_core_f32_wino_conv_2x3_dst_trans_t(
            const jit_conv_conf_2x3_wino_t &ajcp, const primitive_attr_t &attr)
        : jit_generator(jit_name()), jcp(ajcp), attr_(attr) {}

    void generate() override;
    bool maybe_relu(int position);

    Zmm vreg_inp(int i) const { // 16
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(31 - i);
    }

    Zmm vreg_stg(int id) const { // 8
        const int id_reg_stg = jcp.alpha * jcp.alpha + id;
        assert(id_reg_stg < jcp.alpha * jcp.alpha + 8);
        return Zmm(31 - id_reg_stg);
    }

    Zmm vreg_out(int id) const { // 4
        const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
        assert(id_reg_out < jcp.alpha * jcp.alpha + 12);
        return Zmm(31 - id_reg_out);
    }

    Zmm vreg_tmp(int id) const { // 2
        const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
        assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14);
        return Zmm(31 - id_reg_tmp);
    }

    Zmm vreg_zero = Zmm(0);
    Zmm vreg_prev_dst = Zmm(0);
    Zmm vreg_bias = Zmm(2);

    Opmask y_mask = Opmask(1);
    Opmask r_mask = Opmask(2);
    Opmask x_mask(int id) {
        assert(id < 4);
        return Opmask(3 + id);
    }

    Reg64 reg_ptr_v_y_masks = r12;
    Reg64 reg_ptr_v_x_masks = r11;

    Reg64 reg_aux_ptr_src = r10;
    Reg64 reg_aux_ptr_dst = r9;

    Reg64 reg_oc_block = r8;

    Reg64 reg_ptr_bias = rbx;
    Reg64 reg_ptr_sum_scale = rdx;
};

bool jit_avx512_core_f32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) {
    using namespace primitive_kind;
    const auto &p = attr_.post_ops_;

    if (position == 0) {
        /* relu before sum */
        return false || p.contain(eltwise, 0);
    } else if (position == 1) {
        /* relu after sum */
        const int sum_idx
                = p.contain(sum, 0) ? 0 : (p.contain(sum, 1) ? 1 : -1);
        if (sum_idx == -1) return false;

        return false || p.contain(eltwise, sum_idx + 1);
    }

    return false;
}

void jit_avx512_core_f32_wino_conv_2x3_dst_trans_t::generate() {
    Label oc_block_label;

    const int load_block = 16;

    auto loop_body = [=]() {
        const auto &p = attr_.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const float *p_sum_scale
                = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
        if (p_sum_scale && *p_sum_scale != 1.f)
            mov(reg_ptr_sum_scale, (size_t)p_sum_scale);

        for (int i = 0; i < 16; i++) {
            int internal_offset = sizeof(float) * jcp.out_stride * i;
            vmovups(vreg_inp(i),
                    EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
        }
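        // Output transform for F(2x2, 3x3): Y = A^T * M * A with the usual
        // A^T = [[1, 1, 1, 0], [0, 1, -1, -1]] (comment added as a reading
        // aid). The first loop reduces along x within each row, the second
        // along y, leaving the 2x2 output tile in vreg_out(0..3).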
        for (int y = 0; y < jcp.alpha; y++) {
            vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1));
            vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2));

            vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2));
            vsubps(vreg_stg(y * 2 + 1), vreg_tmp(1), vreg_inp(y * 4 + 3));
        }
        for (int x = 0; x < jcp.m; x++) {
            vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x + 2 * 1));
            vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x + 2 * 2));

            vsubps(vreg_tmp(1), vreg_stg(x + 2 * 1), vreg_stg(x + 2 * 2));
            vsubps(vreg_out(x + 2), vreg_tmp(1), vreg_stg(x + 2 * 3));
        }

        if (jcp.with_bias) {
            auto bias_addr = ptr[reg_ptr_bias];
            vmovups(vreg_bias, bias_addr);
        }
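        // Per output element: add bias, optionally apply ReLU before the sum
        // post-op, optionally accumulate the previous dst value (scaled by
        // sum.scale), optionally apply ReLU after the sum, and finally do a
        // masked store so out-of-bounds rows/columns of the 2x2 tile are
        // skipped.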
        for (int y = 0; y < jcp.m; y++) {
            kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
            for (int x = 0; x < jcp.m; x++) {
                kandw(r_mask, y_mask, x_mask(x));

                int i = y * jcp.m + x;
                int offset = sizeof(float)
                        * (y * jcp.ow * jcp.oc_block + x * jcp.oc_block);
                Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);

                Zmm zmm = vreg_out(i);
                if (jcp.with_bias) vaddps(zmm, zmm, vreg_bias);

                if (maybe_relu(0)) {
                    vxorps(vreg_zero, vreg_zero, vreg_zero);
                    vmaxps(zmm, vreg_zero, zmm);
                }
                if (p_sum_scale) { // post_op: sum
                    vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
                    vmovups(vreg_prev_dst | r_mask, addr);
                    if (*p_sum_scale == 1.f)
                        vaddps(zmm, vreg_prev_dst);
                    else
                        vfmadd231ps(
                                zmm, vreg_prev_dst, zword_b[reg_ptr_sum_scale]);
                }
                if (maybe_relu(1)) {
                    vxorps(vreg_zero, vreg_zero, vreg_zero);
                    vmaxps(zmm, vreg_zero, zmm);
                }

                vmovups(addr, zmm | r_mask);
            }
        }
    };

    preamble();

#define READ_PARAM(reg, field) \
    mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_aux_ptr_src, wino_dst);
    READ_PARAM(reg_aux_ptr_dst, dst);
    READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
    READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
    READ_PARAM(reg_ptr_bias, bias);
#undef READ_PARAM

    for (int i = 0; i < jcp.alpha * jcp.alpha; i++)
        vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i));

    for (int i = 0; i < jcp.alpha; i++)
        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);

    int oc_blocks = jcp.oc / load_block;
    mov(reg_oc_block, oc_blocks);
    L(oc_block_label);
    {
        loop_body();
        add(reg_aux_ptr_src, sizeof(float) * load_block);
        add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block);

        add(reg_ptr_bias, jcp.typesize_bia * load_block);
    }
    dec(reg_oc_block);
    cmp(reg_oc_block, 0);
    jg(oc_block_label, T_NEAR);

    sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block);

    postamble();
}

/// GEMM kernel ////////////////////////////////////////////////////////////////
struct jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t)
    jit_conv_conf_2x3_wino_t jcp;

    struct call_params_t {
        const void *src;
        const void *dst;
        const void *wei;
        const void *dst_b;
    };

    void generate() override;
    static bool post_ops_ok(
            jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr);

    jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t(
            const jit_conv_conf_2x3_wino_t &ajcp, const primitive_attr_t &attr)
        : jit_generator(jit_name()), jcp(ajcp) {}

    static status_t init_conf(jit_conv_conf_2x3_wino_t &jcp,
            const convolution_desc_t &cd, memory_desc_t &src_md,
            memory_desc_t &weights_md, memory_desc_t &dst_md,
            memory_desc_t &bias_md, const primitive_attr_t &attr,
            memory_desc_t &expect_wei_md);
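
    // Register layout: the n2_block * m_block accumulators vreg_out(n, m)
    // occupy the top of the register file (zmm31 downwards), the weight
    // vectors vreg_wei(i) sit just below them, and zmm0..zmm2 are kept free
    // for the broadcast source value and scratch.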
    Zmm vreg_out(int n, int m) const {
        const int id_reg_out = n * jcp.m_block + m;
        assert(id_reg_out < jcp.n2_block * jcp.m_block);
        return Zmm(31 - id_reg_out);
    }
    Zmm vreg_wei(int i) const {
        assert(31 - jcp.n2_block * jcp.m_block - i > 1);
        return Zmm(31 - jcp.n2_block * jcp.m_block - i);
    }

    Zmm vreg_src = Zmm(0);
    Zmm vreg_one = Zmm(1);
    Zmm vreg_tmp = Zmm(2);

    Reg64 reg_ptr_src = r15;

    Reg64 reg_aux_dst = r12;
    Reg64 reg_aux_dst2 = r11;
    Reg64 reg_aux_wei = r10;
    Reg64 reg_aux_wei2 = r9;
    Reg64 reg_aux_src = r8;
    Reg64 reg_aux_src2 = rax;

    Reg64 reg_mb = rbx;
    Reg64 reg_nnb = rdx;
    Reg64 reg_K = rsi;
};

bool jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t::post_ops_ok(
        jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
    using namespace primitive_kind;
    const auto &p = attr.post_ops_;

    auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };

    switch (p.len()) {
        case 0: return true;
        case 1: return is_relu(0) || p.contain(sum, 0);
        case 2:
            return (p.contain(sum, 0) && is_relu(1))
                    || (p.contain(sum, 1) && is_relu(0));
        case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
        default: return false;
    }
}

void jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t::generate() {
    Label nnb_loop_label, K_loop_label, mb_loop_label;

    preamble();
#define READ_PARAM(reg, field) \
    mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src, src);
    READ_PARAM(reg_aux_dst, dst);
    READ_PARAM(reg_aux_wei, wei);
#undef READ_PARAM

    if (!jcp.small_mb) {
        mov(reg_nnb, jcp.n_chunks);
        L(nnb_loop_label);
    }
    mov(reg_aux_dst2, reg_aux_dst);
    mov(reg_aux_src, reg_ptr_src);
    mov(reg_mb, jcp.M / jcp.m_block);
    L(mb_loop_label);
    {
        for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
            for (int m = 0; m < jcp.m_block; m++) {
                vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m));
            }
        }
        mov(reg_aux_src2, reg_aux_src);
        mov(reg_aux_wei2, reg_aux_wei);
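
        // K loop: for each of the k2_block input channels in the current
        // chunk, load n2_block weight vectors, then broadcast one src scalar
        // per output row m and FMA it into the accumulators; i.e. each
        // Winograd component is computed as an (M x N) GEMM with K = ic.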
        mov(reg_K, jcp.k_chunks);
        L(K_loop_label);
        {
            int wei_offset = 0;
            for (int _i = 0; _i < jcp.k2_block; _i++) {
                for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
                    if (jcp.small_mb) {
                        int wei_offset = sizeof(float)
                                * ((nb2 * jcp.nb_ic * jcp.ic_block
                                           * jcp.oc_block)
                                        + _i * jcp.oc_block);
                        vmovups(vreg_wei(nb2),
                                EVEX_compress_addr(reg_aux_wei2, wei_offset));
                    } else {
                        vmovups(vreg_wei(nb2),
                                EVEX_compress_addr(reg_aux_wei2,
                                        sizeof(float) * wei_offset));
                        wei_offset += jcp.oc_block;
                    }
                }
                for (int m = 0; m < jcp.m_block; m++) {
                    int inp_offset = sizeof(float) * (m * jcp.K + _i);
                    if (jcp.n2_block > 1) {
                        vbroadcastss(vreg_src,
                                EVEX_compress_addr(reg_aux_src2, inp_offset));
                        for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
                            vfmadd231ps(
                                    vreg_out(nb2, m), vreg_wei(nb2), vreg_src);
                    } else {
                        vfmadd231ps(vreg_out(0, m), vreg_wei(0),
                                EVEX_compress_addr(
                                        reg_aux_src2, inp_offset, true));
                    }
                }
            }
            add(reg_aux_src2, sizeof(float) * jcp.ic_block);
            if (jcp.small_mb)
                add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block);
            else
                add(reg_aux_wei2,
                        sizeof(float) * jcp.k2_block * jcp.n2_block
                                * jcp.oc_block);
        }
        dec(reg_K);
        cmp(reg_K, 0);
        jg(K_loop_label, T_NEAR);

        for (int m = 0; m < jcp.m_block; m++) {
            for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
                int offset = sizeof(float) * (m * jcp.N + nb2 * jcp.oc_block);
                vmovups(EVEX_compress_addr(reg_aux_dst2, offset),
                        vreg_out(nb2, m));
            }
        }
        add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K);
        add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N);
    }
    dec(reg_mb);
    cmp(reg_mb, 0);
    jg(mb_loop_label, T_NEAR);

    if (!jcp.small_mb) {
        add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block);
        add(reg_aux_wei,
                sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block
                        * jcp.oc_block);

        dec(reg_nnb);
        cmp(reg_nnb, 0);
        jg(nnb_loop_label, T_NEAR);
    }
    postamble();
}

namespace {
bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
    return jcp.mb >= 4;
}
} // namespace

status_t jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t ::init_conf(
        jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
        memory_desc_t &src_md, memory_desc_t &wei_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, const primitive_attr_t &attr,
        memory_desc_t &expect_wei_md) {
    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper wei_d(&wei_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    // This kernel only supports 2D convolutions.
    if (src_d.ndims() != 4) return status::unimplemented;

    const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;

    jcp.nthr = dnnl_get_max_threads();

    jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = wei_d.dims()[with_groups + 2];
    jcp.kw = wei_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
    jcp.r_pad = calculate_end_padding(
            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
    jcp.b_pad = calculate_end_padding(
            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);

    jcp.m = 2;
    jcp.r = 3;
    jcp.alpha = jcp.m + jcp.r - 1;
    int simdw = 16;

    format_tag_t dat_tag = format_tag::nChw16c;
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);

    if (jcp.src_tag != dat_tag) return status::unimplemented;
    if (jcp.dst_tag != dat_tag) return status::unimplemented;

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    if (!post_ops_ok(jcp, attr)) return status::unimplemented;

    bool ok_to_pad_channels = jcp.ngroups == 1;
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simdw);
        jcp.ic = rnd_up(jcp.ic, simdw);
    }

    if (!(mayiuse(avx512_core))) return status::unimplemented;

    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
                is_winograd_faster_than_direct(jcp)))
        return status::unimplemented;

    if (src_d.data_type() != data_type::f32) return status::unimplemented;
    if (wei_d.data_type() != data_type::f32) return status::unimplemented;
    if (dst_d.data_type() != data_type::f32) return status::unimplemented;

    jcp.ic_block = simdw;
    jcp.oc_block = simdw;

    bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1
            && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
            && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0
            && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad
            && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0
            && jcp.l_pad < 2 && jcp.l_pad >= 0;
    if (!ok) return status::unimplemented;

    const int L2_capacity
            = platform::get_per_core_cache_size(2) / sizeof(float);
    const int L3_capacity
            = platform::get_per_core_cache_size(3) * jcp.nthr / sizeof(float);
    int a = jcp.alpha;
    int aa = a * a;
    int mb = jcp.mb;
    int ic = jcp.ic;
    int oc = jcp.oc;
    int ih = jcp.ih;
    int iw = jcp.iw;
    auto wei_sz = (float)aa * ic * oc;
    auto inp_sz = (float)mb * ih * iw * ic;
    auto sp_sz = (float)mb * ih * iw;

    /* Heuristics. The constants '28' and '196' below are empirical values
       observed from benchmarking data. */
    jcp.small_mb = wei_sz / inp_sz > 5;

    if (mb > nstl::min(jcp.nthr, 28)
            || (!jcp.small_mb
                    && (wei_sz >= 0.9f * L2_capacity
                            || inp_sz > L2_capacity * jcp.nthr + L3_capacity))
            || (jcp.small_mb && sp_sz > 196))
        return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;

    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic = jcp.ic / jcp.ic_block;

    const int skx_free_regs = 30;
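
    // Choose m_block / n2_block such that the n2_block * m_block accumulators
    // plus n2_block weight registers fit into the skx_free_regs (30) zmm
    // registers assumed free, scoring candidates with a simple
    // register-efficiency ratio.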
    auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block,
                                    int &n2_block, float &reg_eff) {
        M = (xb * yb) / jcp.alpha;
        int max_m_block = m_block = nstl::min(M, skx_free_regs);
        int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs);
        reg_eff = 0;
        for (int im = max_m_block; im > 0; im--) {
            for (int in2 = max_n2_block; in2 > 0; in2--) {
                int used_regs = in2 * im + in2;
                float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f;
                if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs
                        || cur_reg_eff <= reg_eff)
                    continue;
                reg_eff = cur_reg_eff;
                m_block = im;
                n2_block = in2;
            }
        }
    };

    int oh = jcp.oh;
    int ow = jcp.ow;
    int nb_oc = jcp.nb_oc;
    int Z = ic + oc;
    int Y = ic * oc;
    const int L3_cap_per_core
            = platform::get_per_core_cache_size(3) / sizeof(float);

    /* Select the xb/yb spatial tile sizes: score each candidate by work,
       thread, register, and memory efficiency and keep the best one. */
    int min_yb = jcp.alpha;
    int min_xb = jcp.alpha;
    int max_yb = nstl::max(min_yb, rnd_up(ih, 2));
    int max_xb = nstl::max(min_xb, rnd_up(iw, 2));
    float best_eff = 0.f;
    for (int ix = max_xb; ix >= min_xb; ix -= 2) {
        if (rnd_up(ow, ix) < iw - 2) continue;
        for (int iy = max_yb; iy >= min_yb; iy -= 2) {
            if (rnd_up(oh, iy) < ih - 2) continue;
            int ex_y = rnd_up(oh, iy);
            int ex_x = rnd_up(ow, ix);
            float work_eff = (float)(ih * iw) / (ex_y * ex_x);

            int M, m_block, n2_b;
            float reg_eff, thr_eff, par_eff, mem_eff, req_mem;

            find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff);

            /* outer parallelization */
            int nblocks = mb * div_up(ih, iy) * div_up(iw, ix);
            thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);

            mem_eff = 1.f;
            req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y;
            if (req_mem > L2_capacity / 2) {
                if (req_mem > ((L2_capacity + L3_cap_per_core) * 4) / 7)
                    mem_eff /= (n2_b + 1) / 2.f;
                else
                    mem_eff /= (n2_b + 1) / 3.f;
            }

            float outer_eff = thr_eff + work_eff + reg_eff + mem_eff;

            /* inner parallelization */
            int bsz = iy * ix / a;
            int gemmw = aa * (nb_oc / n2_b);
            int bsz_r = rnd_up(bsz, jcp.nthr);
            int gemmw_r = rnd_up(gemmw, jcp.nthr);
            thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y);

            req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic;
            mem_eff = nstl::min(1.f, L2_capacity / req_mem);
            int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr));
            int oc_per_thr
                    = nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr));
            req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z;
            if (req_mem > L2_capacity) mem_eff = 0.1f;
            par_eff = 1 / (2.f * nblocks);

            float inner_eff = thr_eff + work_eff + mem_eff + par_eff;

            float eff = jcp.small_mb ? inner_eff : outer_eff;
            if (eff > best_eff) {
                best_eff = eff;
                jcp.yb = iy;
                jcp.xb = ix;
                jcp.M = M;
                jcp.m_block = m_block;
                jcp.n2_block = n2_b;
            }
        }
    }

    assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);

    jcp.inp_stride = jcp.M * jcp.ic;
    jcp.out_stride = jcp.M * jcp.oc;
    jcp.wei_stride = jcp.ic * jcp.oc;
    jcp.bia_stride = jcp.oc;

    jcp.N = jcp.oc;
    jcp.K = jcp.ic;

    jcp.n_block = jcp.oc_block;
    jcp.k_block = jcp.ic_block;

    assert(jcp.M % jcp.m_block == 0);
    assert(jcp.nb_oc % jcp.n2_block == 0);

    jcp.n_chunks = jcp.nb_oc / jcp.n2_block;
    jcp.k2_block = jcp.ic_block;
    jcp.k_chunks = jcp.K / jcp.k2_block;

    /* Re-create the weights memory descriptor and set the Winograd
       (wino) blocked format. */
    expect_wei_md.format_kind = format_kind::wino;
    expect_wei_md.data_type = data_type::f32;
    wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
    wd.wino_format = jcp.small_mb ? wino_memory_format_t::wino_wei_aaOio
                                  : wino_memory_format_t::wino_wei_aaOBiOo;
    wd.r = jcp.r;
    wd.alpha = jcp.alpha;
    wd.ic = jcp.ic;
    wd.oc = jcp.oc;
    wd.ic_block = jcp.ic_block;
    wd.oc_block = jcp.oc_block;
    wd.oc2_block = jcp.n2_block;
    wd.ic2_block = 1;
    wd.adj_scale = 1.f;
    size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
    wd.size = max_size;

    return status::success;
}
////////////////////////////////////////////////////////////////////////////////

status_t jit_avx512_core_f32_wino_conv_2x3_fwd_t ::pd_t::jit_conf(
        memory_desc_t &expect_wei_md) {
    return jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t::init_conf(jcp_,
            *this->desc(), this->src_md_, this->weights_md_, this->dst_md_,
            this->bias_md_, *this->attr(), expect_wei_md);
}

jit_avx512_core_f32_wino_conv_2x3_fwd_t::
        jit_avx512_core_f32_wino_conv_2x3_fwd_t(const pd_t *apd)
    : primitive_t(apd) {}

status_t jit_avx512_core_f32_wino_conv_2x3_fwd_t::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_,
            new jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t(
                    pd()->jcp_, *pd()->attr())));
    CHECK(safe_ptr_assign(src_trans_,
            new jit_avx512_core_f32_wino_conv_2x3_src_trans_t(
                    pd()->jcp_, *pd()->attr())));
    CHECK(safe_ptr_assign(dst_trans_,
            new jit_avx512_core_f32_wino_conv_2x3_dst_trans_t(
                    pd()->jcp_, *pd()->attr())));
    CHECK(kernel_->create_kernel());
    CHECK(src_trans_->create_kernel());
    CHECK(dst_trans_->create_kernel());
    return status::success;
}

jit_avx512_core_f32_wino_conv_2x3_fwd_t::
        ~jit_avx512_core_f32_wino_conv_2x3_fwd_t()
        = default;

void jit_avx512_core_f32_wino_conv_2x3_fwd_t::execute_forward_mbN(
        const float *src, const float *wei, const float *bia, float *dst,
        const memory_tracking::grantor_t &scratchpad) const {
    const auto &jcp = kernel_->jcp;

    const size_t wino_size_offset
            = (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2)
            + (pd()->jcp_.xb);
    const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16;
    const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16;

    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bia = padded_bias;
    }

    auto ptr_V = scratchpad.get<float>(key_wino_V);
    auto ptr_M = scratchpad.get<float>(key_wino_M);
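
    // Outer parallelization: each thread owns a whole yb x xb tile of one
    // mini-batch image and runs the complete pipeline on it (source
    // transform, 16 per-component GEMMs, destination transform), using its
    // private wino_src / wino_dst scratch slices.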
    parallel_nd_ext(jcp.nthr, jcp.mb, div_up(jcp.oh, jcp.yb),
            div_up(jcp.ow, jcp.xb),
            [&](dim_t ithr, dim_t nthr, dim_t mb, dim_t tile_y_b,
                    dim_t tile_x_b) {
                assert(nthr <= jcp.nthr);
                MAYBE_UNUSED(nthr);

                int tile_y = tile_y_b * jcp.yb;
                int tile_x = tile_x_b * jcp.xb;

                auto wino_src = ptr_V + size_wino_src * ithr;
                auto wino_dst = ptr_M + size_wino_dst * ithr;

                auto src_trans_p
                        = jit_avx512_core_f32_wino_conv_2x3_src_trans_t ::
                                call_params_t();
                auto dst_trans_p
                        = jit_avx512_core_f32_wino_conv_2x3_dst_trans_t ::
                                call_params_t();
                auto gemm_p = jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t ::
                        call_params_t();

                /* transformation of input tensor to winograd domain */
                for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
                    for (int x_in_block = 0; x_in_block < jcp.xb;
                            x_in_block += 2) {

                        unsigned short v_y_masks[4], v_x_masks[4];

                        int y = y_in_block + tile_y;
                        int x = x_in_block + tile_x;
                        int m = (y_in_block / 2) * (jcp.xb / 2)
                                + (x_in_block / 2);

                        int v_ys = nstl::max(0, jcp.t_pad - y);
                        int v_ye = nstl::min(jcp.alpha,
                                nstl::max(0, jcp.ih + jcp.t_pad - y));

                        int v_xs = nstl::max(0, jcp.l_pad - x);
                        int v_xe = nstl::min(jcp.alpha,
                                nstl::max(0, jcp.iw + jcp.l_pad - x));

#pragma unroll(4)
                        for (int i = 0; i < jcp.alpha; i++) {
                            v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
                            v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
                        }
                        auto local_s = src
                                + (dim_t)mb * jcp.nb_ic * jcp.ih * jcp.iw
                                        * jcp.ic_block
                                + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
                        auto local_w = wino_src + m * jcp.ic;

                        src_trans_p.src = local_s;
                        src_trans_p.wino_src = local_w;
                        src_trans_p.v_y_masks = v_y_masks;
                        src_trans_p.v_x_masks = v_x_masks;

                        (*src_trans_)(&src_trans_p);
                    }
                }
                /* gemms */
                for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
                    int offset = (tile_ij + ithr) % 16;
                    gemm_p.src = wino_src + jcp.inp_stride * offset;
                    gemm_p.dst = wino_dst + jcp.out_stride * offset;
                    gemm_p.wei = wei + jcp.wei_stride * offset;

                    (*kernel_)(&gemm_p);
                }

                /* transformation from winograd domain to output tensor */
                for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
                    for (int x_in_block = 0; x_in_block < jcp.xb;
                            x_in_block += 2) {
                        unsigned short v_y_masks[2], v_x_masks[2];

                        int y = y_in_block + tile_y;
                        int x = x_in_block + tile_x;
                        int m = (y_in_block / 2) * (jcp.xb / 2)
                                + (x_in_block / 2);

#pragma unroll(2)
                        for (int i = 0; i < jcp.m; i++) {
                            v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
                            v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
                        }
                        auto local_d = dst
                                + (dim_t)mb * jcp.nb_oc * jcp.oh * jcp.ow
                                        * jcp.oc_block
                                + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
                        auto local_w = wino_dst + m * jcp.oc;

                        dst_trans_p.dst = local_d;
                        dst_trans_p.wino_dst = local_w;
                        dst_trans_p.v_y_masks = v_y_masks;
                        dst_trans_p.v_x_masks = v_x_masks;
                        dst_trans_p.bias = bia;

                        (*dst_trans_)(&dst_trans_p);
                    }
                }
            });
}

void jit_avx512_core_f32_wino_conv_2x3_fwd_t::execute_forward_small_mb(
        const float *src, const float *wei, const float *bia, float *dst,
        const memory_tracking::grantor_t &scratchpad) const {
    const auto &jcp = kernel_->jcp;

    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bia = padded_bias;
    }

    auto ptr_V = scratchpad.get<float>(key_wino_V);
    auto ptr_M = scratchpad.get<float>(key_wino_M);
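
    // Small mini-batch path: tiles are processed sequentially and the
    // parallelism sits inside each tile instead: the source and destination
    // transforms are parallelized over 2x2 sub-tiles, and the GEMMs over
    // (Winograd component, oc chunk) pairs.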
    for_(int mb = 0; mb < jcp.mb; mb++)
    for_(int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb)
    for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
        /* transformation of input tensor to winograd domain */
        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
                [&](dim_t y_in_block_b, dim_t x_in_block_b) {
                    int y_in_block = y_in_block_b * 2;
                    int x_in_block = x_in_block_b * 2;

                    auto src_trans_p
                            = jit_avx512_core_f32_wino_conv_2x3_src_trans_t ::
                                    call_params_t();

                    unsigned short v_y_masks[4], v_x_masks[4];

                    int y = y_in_block + tile_y;
                    int x = x_in_block + tile_x;
                    int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);

                    int v_ys = nstl::max(0, jcp.t_pad - y);
                    int v_ye = nstl::min(
                            jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));

                    int v_xs = nstl::max(0, jcp.l_pad - x);
                    int v_xe = nstl::min(
                            jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));

#pragma unroll(4)
                    for (int i = 0; i < jcp.alpha; i++) {
                        v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
                        v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
                    }
                    auto local_s = src
                            + (dim_t)mb * jcp.nb_ic * jcp.ih * jcp.iw
                                    * jcp.ic_block
                            + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
                    auto local_w = ptr_V + m * jcp.ic;

                    src_trans_p.src = local_s;
                    src_trans_p.wino_src = local_w;
                    src_trans_p.v_y_masks = v_y_masks;
                    src_trans_p.v_x_masks = v_x_masks;

                    (*src_trans_)(&src_trans_p);
                });

        /* gemms */
        parallel_nd(16, jcp.n_chunks, [&](dim_t tile_ij, dim_t nnb) {
            auto gemm_p = jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t ::
                    call_params_t();

            gemm_p.src = ptr_V + jcp.inp_stride * tile_ij;
            gemm_p.dst = ptr_M + jcp.out_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block;
            gemm_p.wei = wei + jcp.wei_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block * jcp.K;

            (*kernel_)(&gemm_p);
        });

        /* transformation from winograd domain to output tensor */

        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
                [&](dim_t y_in_block_b, dim_t x_in_block_b) {
                    int y_in_block = y_in_block_b * 2;
                    int x_in_block = x_in_block_b * 2;

                    auto dst_trans_p
                            = jit_avx512_core_f32_wino_conv_2x3_dst_trans_t ::
                                    call_params_t();

                    unsigned short v_y_masks[2], v_x_masks[2];

                    int y = y_in_block + tile_y;
                    int x = x_in_block + tile_x;
                    int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);

#pragma unroll(2)
                    for (int i = 0; i < jcp.m; i++) {
                        v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
                        v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
                    }
                    auto local_d = dst
                            + (dim_t)mb * jcp.nb_oc * jcp.oh * jcp.ow
                                    * jcp.oc_block
                            + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
                    auto local_w = ptr_M + m * jcp.oc;

                    dst_trans_p.dst = local_d;
                    dst_trans_p.wino_dst = local_w;
                    dst_trans_p.v_y_masks = v_y_masks;
                    dst_trans_p.v_x_masks = v_x_masks;

                    dst_trans_p.bias = bia;

                    (*dst_trans_)(&dst_trans_p);
                });
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl