1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_PRIMITIVE_CONF_HPP
18#define CPU_X64_JIT_PRIMITIVE_CONF_HPP
19
20#include <queue>
21#include <stdint.h>
22
23#include "common/primitive_attr.hpp"
24#include "cpu/x64/brgemm/brgemm_types.hpp"
25#include "cpu/x64/cpu_isa_traits.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32/* convolution */
// Loop-nesting choices for the direct-convolution drivers; stored in
// jit_conv_conf_t::loop_order. Values are spelled out explicitly and match
// the declaration order (0..7).
enum conv_loop_order_t {
    loop_cgn = 0,
    loop_gnc = 1,
    loop_ngc = 2,
    loop_gncw = 3,
    loop_cwgn = 4,
    loop_ngcw = 5,
    loop_nhwcg = 6,
    loop_nwcg = 7
};
// Loop-nesting choices for the 1x1 convolution drivers; stored in
// jit_1x1_conv_conf_t::loop_order. Values match the declaration order (0..5).
enum conv_1x1_loop_order_t {
    loop_rbl = 0,
    loop_rlb = 1,
    loop_lbr = 2,
    loop_lrb = 3,
    loop_blr = 4,
    loop_brl = 5
};
51
// Broadcast strategy of a convolution kernel; stored in
// jit_conv_conf_t::kernel_kind.
enum conv_kernel_kind_t {
    embd_bcast = 0,
    expl_bcast = 1
};
// Driver "harness" selection, stored in jit_conv_conf_t::harness and
// jit_brgemm_conv_conf_t::harness. Values match the declaration order (0..4).
enum conv_harness_t {
    harness_2d_reduction = 0,
    harness_3d_reduction = 1,
    harness_mb_reduction = 2,
    harness_compute_full_spatial = 3,
    harness_nxc = 4
};
60
// Bit flags passed from the drivers to the JIT kernels.
//
// NOTE: the second group (FLAG_ZERO_FILTER, FLAG_ZERO_BIAS, FLAG_COMPUTE_BIAS)
// reuses bits 0..2, so its values coincide with FLAG_MB_FIRST, FLAG_MB_LAST
// and FLAG_OC_FIRST. Mixing the two groups in a single flags word would be
// ambiguous.
enum {
    // First/last markers, presumably for the iterations of the named loop.
    FLAG_MB_FIRST = 0x001,
    FLAG_MB_LAST = 0x002,
    FLAG_OC_FIRST = 0x004,
    FLAG_OC_LAST = 0x008,
    FLAG_IC_FIRST = 0x010,
    FLAG_IC_LAST = 0x020,
    FLAG_SP_FIRST = 0x040,
    FLAG_SP_LAST = 0x080,
    FLAG_REDUCE_FIRST = 0x100,
    FLAG_REDUCE_LAST = 0x200,

    // Controls whether the inner kernel skips loading weights data from
    // memory; this needs to happen on the first Group/16 iteration.
    FLAG_ZERO_FILTER = 0x001,
    // Controls whether the inner kernel skips loading bias data from memory.
    FLAG_ZERO_BIAS = 0x002,
    // Controls bias computation during the execution pass.
    FLAG_COMPUTE_BIAS = 0x004,
};
81
// Coarse classification of a memory format tag used by the JIT kernels.
// jit_pool_conf_t::is_plain() treats ncsp and nspc as plain layouts.
enum class jit_memory_tag_kind_t {
    ncsp = 0,
    nspc = 1,
    blocked = 2,
    undef = 3
};
83
// Configuration shared by the x64 JIT convolution kernels and drivers.
//
// The struct declares no constructor and no default member initializers, so
// its scalar fields are indeterminate unless the object is value-initialized
// (e.g. `jit_conv_conf_t jcp {}`). Member order defines the object layout;
// keep it stable.
struct jit_conv_conf_t {
    prop_kind_t prop_kind;
    bool has_vnni;
    conv_loop_order_t loop_order;
    conv_harness_t harness;

    // Problem shape.
    int simd_w;
    int ndims;
    int mb;
    int ngroups, ic, oc, oc_without_padding, ic_without_padding;
    int id, ih, iw, od, oh, ow;
    int f_pad, l_pad, t_pad;
    int back_pad, r_pad, b_pad;
    int kd, kh, kw;
    int stride_d, stride_h, stride_w;
    int dilate_d, dilate_h, dilate_w;
    format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
    bool with_bias;
    bool with_sum;
    bool with_eltwise;
    bool with_binary;

    data_type_t sum_dt;

    bool with_binary_per_oc_bcast;
    bool with_binary_no_bcast;

    bool is_fused_conv;
    int dw_conv_buffer_oc;

    post_ops_t::entry_t::eltwise_t eltwise;
    post_ops_t post_ops;
    bool is_fast_postops; // maybe skip injector for sum and/or relu

    // Threading decomposition.
    int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b, nthr_oh;

    // Padded sizes and blocking.
    int idp, ihp, iwp, ohp, owp, icp;
    int nb_ic, ic_block;
    int nb_oc, oc_block;
    int nb_iw, iw_block;
    int nb_ow, ow_block;
    int nb_oc_blocking; /* used in jit kernels for nb_oc work blocking taking
                           into account vector registers distribution */
    int nb_oc_blocking_thr_chunk; /* used for distribution of nb_oc work
                                     within threads */
    int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work
    int nb_ic_L2;
    int h_blocking;
    int nb_oc_L2;
    int ic_tail, oc_tail, ch_tail;
    int ur_h, ur_w;
    int ur_w_tail, ur_w_blocks;
    int ur_ic, ur_kw;
    bool is_1stconv;
    int nonblk_group_off;
    /* fma avx512_core */
    conv_kernel_kind_t kernel_kind;

    int tr_iw, tr_ih;
    int tr_kw, tr_kh;
    int tr_src_num_guard_elems;

    // Transpose buffer management
    size_t tr_src_buf_size, tr_src_buf_count;
    size_t tr_diff_dst_buf_size, tr_diff_dst_buf_count;
    int nthr_mb_work;

    int typesize_in;
    int typesize_out;
    int typesize_bia;
    int typesize_acc;
    /* avx512_u8s8u8 */
    int ic_nb1, ic_nb2;
    int oc_nb1;
    int ur_ow_max, ur_ow, ur_ow_tail;
    int ur_ow_nsteps;
    data_type_t bia_dt;
    /* bf16 data-type for output */
    data_type_t dst_dt;
    data_type_t src_dt;
    /* bf16 weights update */
    data_type_t wei_dt;
    data_type_t ddst_dt;
    data_type_t dsrc_dt;
    data_type_t dwei_dt;
    bool expl_bcast;
    bool large_spatial, large_w_filter;
    int is_ic_scale, is_oc_scale;
    int max_regs_ur; // maximum accumulation registers
    // dw conv
    int nb_ch, ch_block, nb_ch_blocking;
    bool is_depthwise, is_fast_depthwise, is_resrc_depthwise;
    int aligned_threads;
    // large spatial
    int ih_blk_size, oh_blk_size;
    // s8s8 convolution
    bool signed_input;
    bool need_saturation;
    float wei_adj_scale;
    // zero-point compensation
    bool src_zero_point;
    int zp_pbuff_size;
    bool dst_zero_point;
    bool zp_src_is_common; // common, otherwise (TODO) per-channel
    bool req_zero_point_buffer; // used for calculating padding compensation
    bool zp_pbuff_outer_compute; // indicates if zp_pbuff is computed in
                                 // a separate parallel region

    bool dst_scale;

    int ow_pad, oh_pad, od_pad; // output elements with padding & filter overlap

    // output elements requiring zero-point padding compensation
    int f_pad_output, back_pad_output;
    int t_pad_output, b_pad_output;
    int l_pad_output, r_pad_output;
    // The number of output blocks corresponding to {l_pad, no_pad, r_pad}
    int l_pad_blk, no_pad_w_blk, r_pad_blk;

    bool od_mid, oh_mid, ow_mid; // indicate if there is overlap between the
                                 // width and height padded regions (od_mid
                                 // presumably covers depth as well)

    size_t h_blk_limits[5]; // pre-computed limits for output height block

    bool uses_permw_transposition;
    bool transpose_src;
    bool transpose_dst;
    int ic_block_step;

    cpu_isa_t isa;
    // bf16 bwdw conv
    int tr_ow;
    bool is_hw_transp; // spatial dim height-width transposed
    int spatial_blk_size; // Height/depth block size inside the driver
    bool global_transpose; // diff_dst & src tensors are transposed in one go
    bool use_nt_stores_ddst; // Use non temporal stores in diff_dst transform

    // Needed for Intel(R) Advanced Matrix Extensions (Intel(R) AMX) kernels
    bool is_nspc; // activations in nwc, nhwc, or ndhwc layout
    bool is_relo; // reduced lowering optimization
    int nreduce; // used with is_relo
    bool is_pbuffer_strided; // does pbuffer have strided sectors?
    int n_stride_sets; // number of stride sectors (or sets) in pbuffer
    int kw_step; // usually stride_w, unless !is_pbuffer_strided
    int kw_per_tile; // mostly for 1st convs
    // The suffix _int refers to the block sizes of the src and diff_dst tiles,
    // as opposed to the vector registers. This distinction is needed due to
    // support for blocked layout (ie nChw16c) with bf16 data type.
    int ic_block_int, ic_block_int_np, oc_block_int;
    int nb_ic_int, nb_oc_int;
    int nb_ih_blocking, nb_oh_blocking;

    int full_tile_width;
    int max_tiles;
    int tile_width;
    int tile_tail;
    int oh_per_tile;
    int iw_blocks, ow_blocks;

    int per_one_pstore;

    size_t inp_buffer_size;
    size_t wei_buffer_size;
    size_t wsp_buffer_size;

    int nb_os;
    int nb_os_blocking;
    int nb_os2_blocking;
    int os_tail;
    int os_blocked;
    int max_width;

    bool transform_to_vnni;
};
258
// Returns the extent of a filter along one spatial dimension once dilation
// is applied. `dilation` follows the convention where 0 means "no dilation":
// adjacent taps are (dilation + 1) elements apart, so a filter with
// `filter_size` taps spans (filter_size - 1) such gaps plus the final tap.
inline int calculate_extended_filter_size(int filter_size, int dilation) {
    const int gaps = filter_size - 1;
    const int tap_distance = dilation + 1;
    return gaps * tap_distance + 1;
}
263
// Computes the padding needed after the last source element so that the
// last of `dst_size` outputs (placed `spatial_stride` apart) can apply a
// filter spanning `dilated_filter_size` elements. The result is negative
// when the front-padded source is longer than the outputs require.
inline int calculate_end_padding(int start_padding, int dst_size, int src_size,
        int spatial_stride, int dilated_filter_size) {
    // Input extent reached by the last output, measured from the start of
    // the front-padded input.
    const int covered = (dst_size - 1) * spatial_stride + dilated_filter_size;
    // Input extent actually available before any end padding.
    const int available = src_size + start_padding;
    return covered - available;
}
269
270inline status_t init_tag(format_tag_t &tag, const memory_desc_wrapper &mdw,
271 const format_tag_t &tag_value) {
272 if (mdw.format_kind() == format_kind::any) return status::unimplemented;
273
274 tag = mdw.matches_one_of_tag(tag_value);
275 return tag == tag_value ? status::success : status::unimplemented;
276}
277
// Configuration of the 2x3 Winograd convolution kernels.
// NOTE(review): m, r and alpha presumably are the Winograd F(m, r) parameters
// and the transformed tile size — confirm against the implementation.
// No default member initializers are declared; value-initialize to zero the
// scalar fields.
struct jit_conv_conf_2x3_wino_t {
    bool has_vnni;

    int m;
    int r;
    int alpha;
    int tile_h, tile_w;

    // Problem shape (2D only: no depth dimension is described here).
    int mb;
    int ngroups, ic, oc, oc_without_padding;
    int ih, iw, oh, ow;
    int l_pad, t_pad;
    int r_pad, b_pad;
    int kh, kw;
    int stride_h, stride_w;
    int dilate_h, dilate_w;

    int nb_ic, ic_block;
    int nb_oc, oc_block;

    int w_block_size, h_block_size;

    data_type_t bia_dt;
    data_type_t dst_dt;

    int is_oc_scale;
    int typesize_in;
    int typesize_out;
    int typesize_bia;
    int typesize_acc;

    format_tag_t src_tag, dst_tag; // temporary workaround
    bool with_bias;
    bool small_mb;

    int xb, yb;
    int inp_stride;
    int out_stride;
    int wei_stride;
    int bia_stride;

    // GEMM shape and blocking.
    int M, N, K;
    int m_block, n_block, k_block;
    int n2_block, n_chunks;
    int k2_block, k_chunks;

    int mb_block, nb_mb;

    // Sizes of the Winograd-domain buffers.
    size_t size_wino_src, size_wino_wei, size_wino_dst;

    int nthr;
};
330
/*
 * Winograd scheduling policy.
 *
 * Computation units:
 *   W: weights transform
 *   S: src transform
 *   D: dst transform
 *   G: gemm
 *
 * Thread grouping by:
 *   i: nb_ic
 *   o: nb_oc
 *   t: tile_block
 *   e: element in tile
 *
 * Note: 'i' and 'o' are omitted if
 *   i.  not combined with t, or
 *   ii. with discrete transforms.
 *
 * The policies currently supported are the enumerators below.
 */
enum winograd_sched_t {
    WSCHED_INVALID = 0,

    /* Forward & backward-data */
    /* W_S_G_D implements discrete transforms */
    WSCHED_DATA_W_S_G_D = 1,
    /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2 */
    WSCHED_DATA_W_SGD = 2,

    /* Backward-weights */
    WSCHED_WEI_S_D_G_W = 3,
    WSCHED_WEI_SDGtWo = 4,
    WSCHED_WEI_S_D_Giot_W = 5,
    WSCHED_WEI_SDGt_W = 6,
};
367
// Winograd-specific extension of jit_conv_conf_t.
// Only ic_simd_block and oc_simd_block have default initializers (16); the
// other members added here are indeterminate under default-initialization.
struct jit_conv_winograd_conf_t : public jit_conv_conf_t {
    int itiles;
    int jtiles;
    int ntiles;
    int ic_simd_block = 16;
    int oc_simd_block = 16;
    int oc_reg_block;
    int ic_reg_block;
    int tile_block;
    int tile_block_ur;
    int nb_tile_block_ur;

    bool double_buffering;
    bool with_relu_postsum;
    int zmm_start;
    int nb_reg;

    // GEMM K dimension blocking.
    int dimK;
    int dimK_reg_block;
    int dimK_block;
    int dimK_nb_block;

    // GEMM M dimension blocking.
    int dimM;
    int dimM_reg_block;
    int dimM_simd_block;
    int dimM_block;
    int dimM_nb_block;

    // GEMM N dimension blocking.
    int dimN;
    int dimN_reg_block;
    int dimN_bcast_ur;
    int dimN_block;
    int dimN_nb_block;

    winograd_sched_t sched_policy;
};
404
// Runtime argument block handed to a generated convolution kernel.
// NOTE(review): generated code presumably reads these members by byte offset,
// so member order and types are effectively part of the kernel ABI — do not
// reorder or retype without updating the kernels. The *_prf members
// presumably carry the values for the next (prefetched) iteration — confirm.
struct jit_conv_call_s {
    const void *src; /* hack, non-const for backward_data */
    const void *dst; /* hack, non-const for forward */
    const void *filt; /* hack, non-const for backward_weights */
    const void *bias; /* hack, non-const for backward_bias */
    const void *src_prf;
    const void *dst_prf;
    const void *filt_prf;
    const void *bias_prf;
    const void *scales;
    const void *acc_s32;
    const void *compensation;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point;
    const int32_t *zero_point_pbuff;
    const int32_t *dst_zero_point;
    const void *tile_cfg;
    const void *tile_cfg_tail;
    const void *dst_scale;

    // ptr to table of void * elements that are pointers to
    // post_op binary src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    const void *dst_orig; // pointer to dst memory (no offset)

    size_t oc_l_off_prf;
    const void *dst_orig_prf;

    size_t kd_offset;
    size_t kd_offset_prf;
    size_t kh_offset;
    size_t kh_offset_prf;
    size_t os_index_begin;
    size_t os_index_begin_prf;
    size_t os_index_end;
    size_t os_index_end_prf;
    size_t kd_padding;
    size_t kd_padding_prf;
    size_t kh_padding;
    size_t kh_padding_prf;
    size_t iwb;
    size_t iwb_prf;
    size_t owb;
    size_t owb_prf;
    size_t ohb;
    size_t kw_padding;
    size_t channel;
    size_t channel_prf;
    size_t ic_blocks;
    size_t oc_blocks;
    size_t ur_w;
    size_t ur_str_w;
    size_t ch_blocks;
    size_t ch_blocks_prf;
    size_t reduce_work;
    size_t reduce_work_prf;
    size_t load_work;
    size_t load_work_prf;
    size_t l_overflow;
    size_t r_overflow;
    size_t t_overflow;
    size_t b_overflow;
    size_t f_overflow;
    size_t back_overflow;
    size_t last_h;
    size_t tail;
    size_t current_iw;
    size_t is_osb;
    // presumably a combination of the FLAG_* bits declared in this header
    int flags;
    int flags_prf;
    int oc_flag;
    size_t last_ic_block;
    size_t last_oc_block;
};
482
// Runtime argument block handed to a generated deconvolution kernel.
// NOTE(review): presumably read by byte offset from generated code; treat
// member order and types as part of the kernel ABI.
struct jit_deconv_call_s {
    const void *src; /* hack, non-const for backward_data */
    const void *dst; /* hack, non-const for forward */
    const void *filt; /* hack, non-const for backward_weights */
    const void *bias; /* hack, non-const for backward_bias */
    const void *scales;
    const void *dst_scale;
    const void *compensation;
    const int32_t *zp_src_pad_str_compensation;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point;
    const int32_t *dst_zero_point;

    /*
     * ptr to table of void * elements that are pointers to post_op binary
     * src1 tensors
     */
    const void *post_ops_binary_rhs_arg_vec;
    const void *dst_orig; /* pointer to dst memory (no offset) */
    /*
     * logical (# of elems) offset to the processed output channel
     * (for broadcasting [1,OC,1,1])
     */
    size_t oc_l_off;
    size_t t_overflow;
    size_t b_overflow;
    size_t f_overflow;
    size_t back_overflow;
    size_t kh_padding;
    size_t kd_padding;
    size_t oc_blocks;
};
515
// Runtime argument block handed to a generated depthwise-convolution kernel.
// NOTE(review): presumably read by byte offset from generated code; treat
// member order and types as part of the kernel ABI.
struct jit_dw_conv_call_s {
    const void *input;
    const void *output;
    const void *filter;
    const void *bias;
    size_t kh_count;
    size_t oh_count;
    size_t oh_index;
    size_t filter_pad_off;
    /* Flags passed by driver execution to inner kernel (a single byte, so only
       bits 0..7 are representable; possibly the FLAG_ZERO_FILTER /
       FLAG_ZERO_BIAS / FLAG_COMPUTE_BIAS bits — confirm) */
    unsigned char
            exec_flags;
};
528
// Runtime argument block for the Winograd transform kernels.
// NOTE(review): presumably read by byte offset from generated code; treat
// member order and types as part of the kernel ABI. Unlike the other call
// structs, the data pointers here are non-const.
struct jit_wino_transform_call_s {
    size_t tile_block;
    size_t tile_block_ur;
    size_t nb_tile_block_ur;
    size_t tile_count;
    size_t tj;
    size_t ti;
    void *src;
    void *dst;
    void *Mw;
    void *M;
    void *T;
    void *G;
    void *bias;
};
544
// Configuration shared by the x64 JIT 1x1 convolution kernels.
// The 1x1 problem is expressed in terms of reduce / load / bcast dimensions
// (see the fields below). No default member initializers are declared.
// NOTE(review): load_dim is `int` while reduce_dim and bcast_dim are `dim_t`.
struct jit_1x1_conv_conf_t {
    prop_kind_t prop_kind;
    bool has_vnni;

    // Problem shape.
    int ndims;
    int mb;
    int ngroups, ic, oc, oc_without_padding, ic_without_padding;
    int id, ih, iw, od, oh, ow;
    int f_pad, t_pad, l_pad;
    int kd, kh, kw;
    int stride_d, stride_h, stride_w;
    format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
    bool with_bias;
    bool with_sum;
    bool with_eltwise;
    bool with_binary;
    bool with_dw_conv;

    post_ops_t post_ops;

    dim_t is, os;
    int ic_block, oc_block;

    int ur, ur_tail;

    // Reduce / load / bcast decomposition and blocking.
    dim_t reduce_dim;
    int reduce_block, nb_reduce, nb_reduce_blocking, nb_reduce_blocking_max;
    int load_dim, load_block, nb_load, nb_load_blocking, nb_load_blocking_max,
            nb_load_chunk;
    dim_t bcast_dim;
    int bcast_block, nb_bcast, nb_bcast_blocking, nb_bcast_blocking_max;

    // Loop steps used by the generated code.
    int reduce_loop_unroll;
    dim_t reduce_loop_bcast_step;
    int reduce_loop_load_step;
    int load_loop_load_step, load_loop_iter_step;
    int bcast_loop_output_step, bcast_loop_output_substep;
    int bcast_loop_bcast_step, bcast_loop_bcast_substep;
    int load_grp_count;
    conv_1x1_loop_order_t loop_order;
    bool use_vmovntps;
    /* avx512 core */
    bool expl_bcast;

    int typesize_in;
    int typesize_out;
    int typesize_bia;
    int typesize_acc;

    bool transpose_src;
    int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
    int is_oc_scale;
    data_type_t bia_dt;
    data_type_t dst_dt;
    data_type_t sum_dt;
    bool signed_input;
    float wei_adj_scale;
    // zero-point compensation
    bool src_zero_point;
    bool dst_zero_point;
    bool zp_src_is_common; // common, otherwise (TODO) per-channel

    bool dst_scale;

    cpu_isa_t isa;
    bool uses_permw_transposition;
};
612
// Runtime argument block handed to a generated 1x1 convolution kernel.
// NOTE(review): presumably read by byte offset from generated code; treat
// member order and types as part of the kernel ABI.
struct jit_1x1_conv_call_s {
    const void *bcast_data;
    const void *load_data;
    const void *output_data;
    const void *bias_data; // used in forward and backward_weights only
    const void *acc_s32;
    const void *scales;
    const void *compensation;
    const void *store_buffer;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point;
    const int32_t *dst_zero_point;
    const void *dst_scale;

    // ptr to table of void * elements that are pointers to
    // post_op binary src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    // logical (# of elems) offset to the processed pixel
    // (for non-broadcasting policy)
    size_t dst_l_off;
    const void *dst_orig; // pointer to dst memory (no offset)

    size_t load_dim;
    size_t bcast_dim;
    size_t reduce_dim;

    size_t output_stride; // used in backward_weights only

    // presumably a combination of FLAG_* bits — confirm with the kernels
    size_t first_last_flag;
};
646
// Configuration of the x64 JIT pooling kernels. No default member
// initializers are declared.
struct jit_pool_conf_t {
    // Problem shape.
    int ndims;
    int mb, c, c_without_padding;
    int id, ih, iw, od, oh, ow;
    int stride_d, stride_h, stride_w;
    int kd, kh, kw;
    int f_pad, t_pad, l_pad;
    alg_kind_t alg;
    bool is_training;
    bool pad_w_is_null;
    bool is_backward;
    bool simple_alg;
    bool is_c_padded;
    data_type_t ind_dt;

    // Channel blocking and unrolling.
    int c_block, c_tail, nb_c;
    int ur_bc, ur_bc_tail;
    int ur_c, ur_c_tail;
    int ur;
    size_t tail[4];
    bool safe_c_tail;
    data_type_t src_dt;
    data_type_t dst_dt;

    int dt_size;
    bool is_bf16;
    bool is_f16;
    jit_memory_tag_kind_t tag_kind;
    // True for the ncsp and nspc layout kinds; false for blocked and undef.
    bool is_plain() const {
        return (tag_kind == jit_memory_tag_kind_t::ncsp
                || tag_kind == jit_memory_tag_kind_t::nspc);
    }

    cpu_isa_t isa;
    post_ops_t post_ops;
    bool with_postops;
    bool with_eltwise;
    bool with_binary;
    int nthr;
    memory_desc_t tmp_md;
};
688
// Runtime argument block handed to a generated pooling kernel.
// NOTE(review): presumably read by byte offset from generated code; treat
// member order and types as part of the kernel ABI.
struct jit_pool_call_s {
    const void *src;
    const void *dst;
    const void *indices;
    const void *src_prf;
    const void *dst_prf;
    const void *indices_prf;
    const void *post_ops_binary_rhs_arg_vec;
    const void *dst_orig;
    const void *dst_po_helper;
    size_t zero_ih;
    size_t zero_id;
    const void *zero_ptr;
    size_t kd_padding;
    size_t kh_padding;
    size_t kh_padding_shift;
    size_t kd_padding_shift;
    size_t kw_padding;
    const void *init_value;
    float ker_area_h;
    size_t ur_bc; // contains number of channel blocks to process
    size_t b_c; // contains number of channel blocks already processed
};
712
// Configuration of the x64 JIT resampling kernels. Every member has a default
// initializer, so a default-constructed object has zero sizes, undefined
// data types / tags / algorithm, isa_undef and no post-ops.
struct jit_resampling_conf_t {
    unsigned ndims = 0;

    unsigned c = 0;
    unsigned id = 0, ih = 0, iw = 0;
    unsigned od = 0, oh = 0, ow = 0;

    unsigned stride_d = 0;
    unsigned stride_h = 0;
    unsigned stride_w = 0;
    unsigned inner_stride = 0;

    // The linear algorithm is an approximation of the point value based on
    // the limit values. For one dimension the approximation is based on a
    // line segment (2 corners), for two dimensions on a rectangle (4 corners)
    // and for three dimensions on a cuboid (8 corners). Therefore, the
    // possible values for the number of corners are 2, 4 and 8.
    unsigned number_of_corners = 0;

    bool is_data_size_bigger_than_L3 = false;
    bool is_saturation_needed = false;
    data_type_t src_data_type = data_type::undef;
    data_type_t dst_data_type = data_type::undef;
    size_t src_dt_size = 0;
    size_t dst_dt_size = 0;
    size_t output_data_size = 0;
    size_t el_size_of_indices = 0;

    bool is_blocked_8_format = false;
    format_tag_t src_tag = format_tag::undef;
    jit_memory_tag_kind_t tag_kind = jit_memory_tag_kind_t::undef;
    alg_kind_t alg = alg_kind::undef;

    cpu_isa_t isa = isa_undef;

    post_ops_t post_ops = post_ops_t();
    bool with_postops = false;
    bool with_eltwise = false;
    bool with_binary = false;
    bool with_sum = false;
    // Scales of the sum post-ops, presumably in post-op order (FIFO).
    std::queue<float> sum_scales;
};
756
// Runtime argument block handed to a generated resampling kernel; every
// member defaults to zero / nullptr.
// NOTE(review): presumably read by byte offset from generated code; treat
// member order and types as part of the kernel ABI. The src_offset_* and
// weight_* members presumably serve the linear algorithm — confirm.
struct jit_resampling_call_s {
    size_t batch_of_sp_points_to_process = 0;

    const void *src = nullptr;
    const void *dst = nullptr;
    const void *indices = nullptr;
    const void *weights = nullptr;
    const void *post_ops_binary_rhs_arg_vec = nullptr;
    const void *dst_orig = nullptr;

    size_t c_offset = 0;

    size_t src_offset_top = 0;
    size_t src_offset_bottom = 0;
    size_t src_offset_front = 0;
    size_t src_offset_back = 0;

    float weight_top = 0.0f;
    float weight_bottom = 0.0f;
    float weight_front = 0.0f;
    float weight_back = 0.0f;
};
779
// Configuration of the brdgmm-based (batch-reduce) depthwise convolution.
// Only 2D shapes are described (no depth fields). No default member
// initializers are declared.
struct jit_brdgmm_conv_conf_t {

    int nthr;
    int mb, ngroups, ic, oc;
    int ih, iw, oh, ow;
    int l_pad, r_pad, t_pad, b_pad;
    int kh, kw;
    int stride_h, stride_w;
    int nb_ch, ch_block, chb_tail;
    int nb_ch_blocking;
    int ow_block, ow_tail, nb_ow;
    // idx of jit kernel when multiple jit kernels are used in a primitive.
    int chb_tail_idx, ow_tail_idx, nb_ch_blocking_idx;
    int adjusted_batch_size;

    bool with_bias;
    bool with_post_ops;
    bool with_scale;
    bool is_oc_scale;

    data_type_t src_dt;
    data_type_t wei_dt;
    data_type_t bia_dt;
    data_type_t dst_dt;

    brgemm_batch_kind_t batch_kind;

    // Element sizes (in bytes, presumably) of the data types above.
    size_t src_dsz;
    size_t wei_dsz;
    size_t bia_dsz;
    size_t dst_dsz;

    cpu_isa_t isa;
};
814
// Loop-nesting choice of the brgemm convolution driver; stored in
// jit_brgemm_conv_conf_t::loop_order.
enum conv_brgemm_loop_order_t {
    loop_ndhwgc = 0,
    loop_ngcdhw = 1,
};
819
// Execution path of the brgemm convolution; stored in
// jit_brgemm_conv_conf_t::exec_type. Zero means "not chosen yet".
enum conv_brgemm_exec_type_t {
    exec_undefined = 0,
    exec_base = 1,
    exec_trans = 2,
    exec_vpad = 3,
};
826
// Configuration of the brgemm-based x64 convolution implementations.
// Only var_bs, tr_ocb_chunk and tr_icb_chunk have default initializers
// (false); all other scalar members are indeterminate unless the object is
// value-initialized.
struct jit_brgemm_conv_conf_t {
    cpu_isa_t isa;
    prop_kind_t prop_kind;
    conv_brgemm_loop_order_t loop_order;
    conv_harness_t harness;
    int simd_w, acc_simd_w, amx_w, amx_h;
    int ndims;
    int mb;
    int ngroups, ic, oc, oc_without_padding, ic_without_padding;

    int od_block, oh_block, nb_od,
            nb_oh; // blocking - included in parallelization
    int id_block, ih_block, nb_id, nb_ih;
    dim_t inp_buffer_size, inp_buffer_mask_size;
    conv_brgemm_exec_type_t exec_type;

    // Problem shape (plain and padded sizes).
    int id, ih, iw, od, oh, ow, os, is, idp, ihp, iwp, icp, odp, ohp, owp, ocp;
    int f_pad, l_pad, t_pad;
    int back_pad, r_pad, b_pad;
    int l_ovf, r_ovf, t_ovf, b_ovf, f_ovf, back_ovf;
    int kd, kh, kw;
    int ext_kd, ext_kh, ext_kw;
    int kd_block, kh_block, kw_block, kd_block_pad, kh_block_pad, kw_block_pad;
    int stride_d, stride_h, stride_w;
    int dilate_d, dilate_h, dilate_w;
    format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
    bool with_bias;
    bool with_sum;
    bool with_eltwise;
    bool with_binary;

    bool is_fused_conv;
    bool is_is_blocking;
    bool is_os_blocking;
    bool is_rtus;
    int nb_ic, ic_block;
    int nb_oc, oc_block;
    int nb_iw, iw_block, iw_tail;
    int nb_ow, ow_block, ow_tail;
    int nb_is, is_block;
    int nb_os, os_block;
    int nb_oc_blocking;
    int nb_ic_blocking;
    int nb_is_blocking;
    int nb_os_blocking;

    data_type_t src_dt;
    data_type_t dst_dt;
    data_type_t wei_dt;
    data_type_t acc_dt;
    data_type_t bia_dt;
    size_t src_dsz;
    size_t wei_dsz;
    size_t dst_dsz;
    size_t acc_dsz;
    size_t bia_dsz;

    bool use_buffer;
    dim_t buffer_size;
    dim_t ker_ranges_size;
    dim_t comp_a_buffer_size;
    dim_t s8s8_comp_buffer_size;

    bool with_scales;
    int is_ic_scale, is_oc_scale;

    // brgemm kernel parameters.
    int LDA, LDB, LDC, LDD;
    int M, N, K, M_tail, N_tail, K_tail;
    // M for brgemm kernel. For use_store_mask it is usually greater than M
    // (M_tail). Otherwise it is equal to M (M_tail).
    // NOTE(review): there is no use_store_mask member; this likely refers to
    // use_M_mask below — confirm.
    int brgM, brgM_tail;
    int gemm_batch_size, adjusted_batch_size;
    brgemm_batch_kind_t brg_type;
    // strides for brg_type == brgemm_strd
    dim_t brg_stride_a, brg_stride_b;
    int nthr;

    int max_batch;
    int max_vpad;
    int amx_buf_size_per_thread;

    bool wei_plain;
    bool is_ic_padded, is_oc_padded;
    int kw_sets, kh_sets;
    bool copy_block_only;
    bool amx_tile_load_xx;
    int use_M_mask;
    int oskip, iskip;
    bool brgemm_bd_loop_innermost;

    bool use_uker;
    bool var_bs {false};
    bool use_interleave_stores;
    brgemm_kernel_prefetching_t hint_prefetching;
    bool is_1x1;
    bool s8s8_avx512;
    bool src_zero_point;
    bool dst_zero_point;
    bool req_brg_comp_pad;
    bool req_cal_comp_pad;
    bool is_bf32;
    bool comp_with_vpads;

    // Backward-by-weights specific settings.
    int nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b, nthr_oh;
    bool transform_to_vnni;
    bool has_vnni;
    int ic_tail, oc_tail;
    size_t tr_src_buf_size, tr_src_buf_count;
    size_t tr_diff_dst_buf_size, tr_diff_dst_buf_count;
    int tr_src_num_guard_elems;
    bool global_transpose; // diff_dst & src tensors are transposed in one go
    int nthr_mb_work;
    int tr_iw, tr_ow;
    int spatial_blk_size; // Height/depth block size inside the driver
    int typesize_in;
    int typesize_out;
    bool tr_ocb_chunk = false;
    bool tr_icb_chunk = false;
};
945
946struct jit_shuffle_conf_t {
947 unsigned ndims = 0;
948
949 unsigned mb = 0, c = 0, d = 0, h = 0, w = 0, sp = 0;
950
951 unsigned stride_mb = 0;
952 unsigned blk_size = 0;
953 unsigned group_size = 0;
954 unsigned axis = 0;
955 unsigned axis_size = 0;
956 unsigned simd_tail = 0;
957 unsigned simd_w = 0;
958
959 jit_memory_tag_kind_t tag_kind = jit_memory_tag_kind_t::undef;
960 data_type_t data_type = data_type::undef;
961 size_t dt_size = 0;
962 unsigned el_size_of_indices = 0;
963 dim_t c_split_size = 0;
964 dim_t sp_split_size = 0;
965
966 cpu_isa_t isa = isa_undef;
967};
968
// Runtime argument block handed to a generated shuffle kernel; every member
// defaults to nullptr / zero / false.
// NOTE(review): presumably read by byte offset from generated code; treat
// member order and types as part of the kernel ABI.
struct jit_shuffle_call_s {
    const void *src = nullptr;
    void *dst = nullptr;
    const void *input_off_ptr = nullptr;

    // number of loop iterations over corresponding C batches
    dim_t cb_loop_size = 0;
    bool is_padded_block = false;
};
978
// Layout-driven operation flavour of the binary kernel; stored in
// jit_binary_conf_t::op_type. `none` (0) is the default.
enum class binary_op_t : unsigned {
    none = 0,
    c_blocked = 1,
    n_spatial_c = 2,
    n_c_spatial = 3
};
980
// Broadcast kind of the second binary operand; stored in
// jit_binary_conf_t::bcast_type.
enum class binary_bcast_t : unsigned {
    none = 0, // tensor operation
    scalar = 1,
    per_batch = 2,
    per_c = 3,
    per_w = 4
};
988
989struct jit_binary_conf_t {
990 binary_op_t op_type = binary_op_t::none;
991 binary_bcast_t bcast_type = binary_bcast_t::none;
992 bool do_scale_src0 = false;
993 bool do_scale_src1 = false;
994 bool do_sum = false;
995 bool with_eltwise = false;
996 bool with_binary = false;
997 bool with_postops = false;
998 float sum_scale = 0.f;
999 bool use_stride_src1 = false;
1000 bool broadcast_src1_value = false;
1001 bool use_stride_rhs_postops = false;
1002 bool postops_per_oc_broadcast_exists = false;
1003 bool is_i8 = false;
1004 bool is_bf16 = false;
1005 bool is_f16 = false;
1006 bool is_src_different_layouts = false;
1007 dim_t outer_dims = 1;
1008 int src1_stride = 1;
1009 int not_bcasted_sp_dims = 0;
1010 cpu_isa_t isa = isa_undef;
1011
1012 data_type_t src0_type = data_type::undef;
1013 data_type_t src1_type = data_type::undef;
1014 data_type_t dst_type = data_type::undef;
1015};
1016
// Runtime argument block handed to a generated binary kernel.
// Every member is a pointer or a size_t, so on 64-bit targets each slot is
// 8 bytes wide; the comment below records that generated code relies on it.
struct jit_binary_call_s {
    // keep all sizes at 8 bytes -- jit code expects this
    const void *src0, *src1, *dst, *indices;
    const float *scales_src0, *scales_src1;
    size_t spat_offt_count;
    const void *post_ops_binary_rhs_arg_vec;
    size_t src1_stride_range;
    const void *dst_orig;
};
1026
// Configuration of the x64 JIT reduction kernel. Every member has a default
// initializer: undefined data types and algorithm, zero sizes, isa_undef and
// no post-ops.
struct jit_reduction_conf_t {
    data_type_t src_type = data_type::undef;
    data_type_t dst_type = data_type::undef;
    data_type_t acc_type = data_type::undef;

    std::size_t src_dt_size = 0;
    std::size_t dst_dt_size = 0;
    std::size_t acc_dt_size = 0;

    alg_kind_t alg = alg_kind::undef;
    cpu_isa_t isa = isa_undef;

    dim_t idle_size = 0;
    dim_t reduce_size = 0;

    bool is_saturation_needed = false;

    post_ops_t post_ops = post_ops_t();
    bool with_postops = false;
    bool with_eltwise = false;
    bool with_binary = false;
    bool with_sum = false;
    // Scales of the sum post-ops, presumably in post-op order (FIFO).
    std::queue<float> sum_scales;
};
1051
// Runtime argument block handed to a generated reduction kernel; every
// pointer defaults to nullptr.
// NOTE(review): presumably read by byte offset from generated code; treat
// member order and types as part of the kernel ABI.
struct jit_reduction_call_s {
    const void *src = nullptr;
    void *dst = nullptr;
    const void *post_ops_binary_rhs_arg_vec = nullptr;
    const void *dst_orig = nullptr;
};
1058
1059} // namespace x64
1060} // namespace cpu
1061} // namespace impl
1062} // namespace dnnl
1063
1064#endif
1065