1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_PRIMITIVE_CONF_HPP |
18 | #define CPU_X64_JIT_PRIMITIVE_CONF_HPP |
19 | |
20 | #include <queue> |
21 | #include <stdint.h> |
22 | |
23 | #include "common/primitive_attr.hpp" |
24 | #include "cpu/x64/brgemm/brgemm_types.hpp" |
25 | #include "cpu/x64/cpu_isa_traits.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | /* convolution */ |
// Loop-nest orderings selectable for convolution drivers. Enumerator values
// are consecutive starting at zero, in declaration order.
enum conv_loop_order_t {
    loop_cgn = 0,
    loop_gnc = 1,
    loop_ngc = 2,
    loop_gncw = 3,
    loop_cwgn = 4,
    loop_ngcw = 5,
    loop_nhwcg = 6,
    loop_nwcg = 7
};
// Loop-nest orderings for 1x1 convolution drivers. Enumerator values are
// consecutive starting at zero, in declaration order.
enum conv_1x1_loop_order_t {
    loop_rbl = 0,
    loop_rlb = 1,
    loop_lbr = 2,
    loop_lrb = 3,
    loop_blr = 4,
    loop_brl = 5
};
51 | |
// Kind of convolution kernel (embedded vs. explicit broadcast, judging by the
// enumerator names -- confirm against the kernels that consume it).
enum conv_kernel_kind_t {
    embd_bcast = 0,
    expl_bcast = 1
};
// Driver harness selector. Enumerator values are consecutive starting at
// zero, in declaration order.
enum conv_harness_t {
    harness_2d_reduction = 0,
    harness_3d_reduction = 1,
    harness_mb_reduction = 2,
    harness_compute_full_spatial = 3,
    harness_nxc = 4
};
60 | |
// Bit flags passed from drivers to kernels.
//
// The first ten flags each occupy a distinct bit (bits 0..9). The last three
// flags deliberately reuse the numeric values of bits 0..2, so they must not
// be combined with the first group inside the same flag word.
enum {
    FLAG_MB_FIRST = 0x001,
    FLAG_MB_LAST = 0x002,
    FLAG_OC_FIRST = 0x004,
    FLAG_OC_LAST = 0x008,
    FLAG_IC_FIRST = 0x010,
    FLAG_IC_LAST = 0x020,
    FLAG_SP_FIRST = 0x040,
    FLAG_SP_LAST = 0x080,
    FLAG_REDUCE_FIRST = 0x100,
    FLAG_REDUCE_LAST = 0x200,

    // Controls whether the inner kernel skips loading weights data from
    // memory; this needs to happen on the first Group/16 iteration.
    FLAG_ZERO_FILTER = 0x001,
    // Controls whether the inner kernel skips loading bias data from memory.
    FLAG_ZERO_BIAS = 0x002,
    // Controls bias computation during the execution pass.
    FLAG_COMPUTE_BIAS = 0x004,
};
81 | |
// Coarse classification of a memory layout as seen by JIT kernels.
enum class jit_memory_tag_kind_t {
    ncsp = 0,
    nspc = 1,
    blocked = 2,
    undef = 3
};
83 | |
84 | struct jit_conv_conf_t { |
85 | prop_kind_t prop_kind; |
86 | bool has_vnni; |
87 | conv_loop_order_t loop_order; |
88 | conv_harness_t harness; |
89 | |
90 | int simd_w; |
91 | int ndims; |
92 | int mb; |
93 | int ngroups, ic, oc, oc_without_padding, ic_without_padding; |
94 | int id, ih, iw, od, oh, ow; |
95 | int f_pad, l_pad, t_pad; |
96 | int back_pad, r_pad, b_pad; |
97 | int kd, kh, kw; |
98 | int stride_d, stride_h, stride_w; |
99 | int dilate_d, dilate_h, dilate_w; |
100 | format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround |
101 | bool with_bias; |
102 | bool with_sum; |
103 | bool with_eltwise; |
104 | bool with_binary; |
105 | |
106 | data_type_t sum_dt; |
107 | |
108 | bool with_binary_per_oc_bcast; |
109 | bool with_binary_no_bcast; |
110 | |
111 | bool is_fused_conv; |
112 | int dw_conv_buffer_oc; |
113 | |
114 | post_ops_t::entry_t::eltwise_t eltwise; |
115 | post_ops_t post_ops; |
116 | bool is_fast_postops; // maybe skip injector for sum and/or relu |
117 | |
118 | int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b, nthr_oh; |
119 | |
120 | int idp, ihp, iwp, ohp, owp, icp; |
121 | int nb_ic, ic_block; |
122 | int nb_oc, oc_block; |
123 | int nb_iw, iw_block; |
124 | int nb_ow, ow_block; |
125 | int nb_oc_blocking; /* used in jit kernels for nb_oc work blocking taking |
126 | into account vector registers distribution */ |
127 | int nb_oc_blocking_thr_chunk; /* used for distribution of nb_oc work |
128 | within threads */ |
129 | int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work |
130 | int nb_ic_L2; |
131 | int h_blocking; |
132 | int nb_oc_L2; |
133 | int ic_tail, oc_tail, ch_tail; |
134 | int ur_h, ur_w; |
135 | int ur_w_tail, ur_w_blocks; |
136 | int ur_ic, ur_kw; |
137 | bool is_1stconv; |
138 | int nonblk_group_off; |
139 | /* fma avx512_core */ |
140 | conv_kernel_kind_t kernel_kind; |
141 | |
142 | int tr_iw, tr_ih; |
143 | int tr_kw, tr_kh; |
144 | int tr_src_num_guard_elems; |
145 | |
146 | // Transpose buffer management |
147 | size_t tr_src_buf_size, tr_src_buf_count; |
148 | size_t tr_diff_dst_buf_size, tr_diff_dst_buf_count; |
149 | int nthr_mb_work; |
150 | |
151 | int typesize_in; |
152 | int typesize_out; |
153 | int typesize_bia; |
154 | int typesize_acc; |
155 | /* avx512_u8s8u8 */ |
156 | int ic_nb1, ic_nb2; |
157 | int oc_nb1; |
158 | int ur_ow_max, ur_ow, ur_ow_tail; |
159 | int ur_ow_nsteps; |
160 | data_type_t bia_dt; |
161 | /* bf16 data-type for output */ |
162 | data_type_t dst_dt; |
163 | data_type_t src_dt; |
164 | /* bf16 weights update */ |
165 | data_type_t wei_dt; |
166 | data_type_t ddst_dt; |
167 | data_type_t dsrc_dt; |
168 | data_type_t dwei_dt; |
169 | bool expl_bcast; |
170 | bool large_spatial, large_w_filter; |
171 | int is_ic_scale, is_oc_scale; |
172 | int max_regs_ur; // maximum accumulation registers |
173 | // dw conv |
174 | int nb_ch, ch_block, nb_ch_blocking; |
175 | bool is_depthwise, is_fast_depthwise, is_resrc_depthwise; |
176 | int aligned_threads; |
177 | // large spatial |
178 | int ih_blk_size, oh_blk_size; |
179 | // s8s8 convolution |
180 | bool signed_input; |
181 | bool need_saturation; |
182 | float wei_adj_scale; |
183 | // zero-point compensation |
184 | bool src_zero_point; |
185 | int zp_pbuff_size; |
186 | bool dst_zero_point; |
187 | bool zp_src_is_common; // common, otherwise (TODO) per-channel |
188 | bool req_zero_point_buffer; // used for calculating padding compensation |
189 | bool zp_pbuff_outer_compute; // indicates if zp_bbuff is computed in |
190 | |
191 | bool dst_scale; |
192 | |
193 | // a separate parallel region |
194 | int ow_pad, oh_pad, od_pad; // output elements with padding & filter overlap |
195 | |
196 | //output elements requiring zero-point padding compensation |
197 | int f_pad_output, back_pad_output; |
198 | int t_pad_output, b_pad_output; |
199 | int l_pad_output, r_pad_output; |
200 | // The number of output blocks corresponding to {l_pad, no_pad, r_pad} |
201 | int l_pad_blk, no_pad_w_blk, r_pad_blk; |
202 | |
203 | bool od_mid, oh_mid, ow_mid; // indicate if there is overlap between the |
204 | //width and height padded regions |
205 | |
206 | size_t h_blk_limits[5]; // pre-computed limits for output height block |
207 | |
208 | bool uses_permw_transposition; |
209 | bool transpose_src; |
210 | bool transpose_dst; |
211 | int ic_block_step; |
212 | |
213 | cpu_isa_t isa; |
214 | // bf16 bwdw conv |
215 | int tr_ow; |
216 | bool is_hw_transp; // spatial dim height-width transposed |
217 | int spatial_blk_size; // Height/depth block size inside the driver |
218 | bool global_transpose; // diff_dst & src tensors are transposed in one go |
219 | bool use_nt_stores_ddst; // Use non temporal stores in diff_dst transform |
220 | |
221 | // Needed for Intel(R) Advanced Matrix Extensions (Intel(R) AMX) kernels |
222 | bool is_nspc; // activations in nwc, nhwc, or ndhwc layout |
223 | bool is_relo; // reduced lowering optimization |
224 | int nreduce; // used with is_relo |
225 | bool is_pbuffer_strided; // does pbuffer have strided sectors? |
226 | int n_stride_sets; // number of stride sectors (or sets) in pbuffer |
227 | int kw_step; // usually stride_w, unless !is_pbuffer_strided |
228 | int kw_per_tile; // mostly for 1st convs |
229 | // The suffix _int refers to the block sizes of the src and diff_dst tiles, |
230 | // as opposed to the vector registers. This distinction is needed due to |
231 | // support for blocked layout (ie nChw16c) with bf16 data type. |
232 | int ic_block_int, ic_block_int_np, oc_block_int; |
233 | int nb_ic_int, nb_oc_int; |
234 | int nb_ih_blocking, nb_oh_blocking; |
235 | |
236 | int full_tile_width; |
237 | int max_tiles; |
238 | int tile_width; |
239 | int tile_tail; |
240 | int oh_per_tile; |
241 | int iw_blocks, ow_blocks; |
242 | |
243 | int per_one_pstore; |
244 | |
245 | size_t inp_buffer_size; |
246 | size_t wei_buffer_size; |
247 | size_t wsp_buffer_size; |
248 | |
249 | int nb_os; |
250 | int nb_os_blocking; |
251 | int nb_os2_blocking; |
252 | int os_tail; |
253 | int os_blocked; |
254 | int max_width; |
255 | |
256 | bool transform_to_vnni; |
257 | }; |
258 | |
// Returns the filter extent once dilation is applied: consecutive taps are
// (dilation + 1) elements apart, so dilation == 0 yields filter_size itself.
inline int calculate_extended_filter_size(int filter_size, int dilation) {
    const int tap_distance = dilation + 1;
    const int gaps_between_taps = filter_size - 1;
    return gaps_between_taps * tap_distance + 1;
}
263 | |
// Returns how far the last filter application reaches past the end of the
// source along one spatial dimension, given the padding at the start. The
// result is negative when the last window ends before the source does.
inline int calculate_end_padding(int start_padding, int dst_size, int src_size,
        int spatial_stride, int dilated_filter_size) {
    // One-past-the-end position touched by the last window, measured from
    // the start of the padded source.
    const int last_window_end
            = (dst_size - 1) * spatial_stride + dilated_filter_size;
    const int padded_src_end = src_size + start_padding;
    return last_window_end - padded_src_end;
}
269 | |
270 | inline status_t init_tag(format_tag_t &tag, const memory_desc_wrapper &mdw, |
271 | const format_tag_t &tag_value) { |
272 | if (mdw.format_kind() == format_kind::any) return status::unimplemented; |
273 | |
274 | tag = mdw.matches_one_of_tag(tag_value); |
275 | return tag == tag_value ? status::success : status::unimplemented; |
276 | } |
277 | |
278 | struct jit_conv_conf_2x3_wino_t { |
279 | bool has_vnni; |
280 | |
281 | int m; |
282 | int r; |
283 | int alpha; |
284 | int tile_h, tile_w; |
285 | |
286 | int mb; |
287 | int ngroups, ic, oc, oc_without_padding; |
288 | int ih, iw, oh, ow; |
289 | int l_pad, t_pad; |
290 | int r_pad, b_pad; |
291 | int kh, kw; |
292 | int stride_h, stride_w; |
293 | int dilate_h, dilate_w; |
294 | |
295 | int nb_ic, ic_block; |
296 | int nb_oc, oc_block; |
297 | |
298 | int w_block_size, h_block_size; |
299 | |
300 | data_type_t bia_dt; |
301 | data_type_t dst_dt; |
302 | |
303 | int is_oc_scale; |
304 | int typesize_in; |
305 | int typesize_out; |
306 | int typesize_bia; |
307 | int typesize_acc; |
308 | |
309 | format_tag_t src_tag, dst_tag; // temporary workaround |
310 | bool with_bias; |
311 | bool small_mb; |
312 | |
313 | int xb, yb; |
314 | int inp_stride; |
315 | int out_stride; |
316 | int wei_stride; |
317 | int bia_stride; |
318 | |
319 | int M, N, K; |
320 | int m_block, n_block, k_block; |
321 | int n2_block, n_chunks; |
322 | int k2_block, k_chunks; |
323 | |
324 | int mb_block, nb_mb; |
325 | |
326 | size_t size_wino_src, size_wino_wei, size_wino_dst; |
327 | |
328 | int nthr; |
329 | }; |
330 | |
331 | /* |
332 | Winograd sched policy: |
333 | |
334 | Computation Unit: |
335 | W: weights transform |
336 | S: src transform |
337 | D: dst transform |
338 | G: gemm |
339 | |
340 | Thread grouping by: |
341 | i: nb_ic |
342 | o: nb_oc |
343 | t: tile_block |
344 | e: element in tile |
345 | |
346 | Note: 'i' and 'o' are omitted if |
347 | i. not combined with t or |
348 | ii. with discrete transforms |
349 | |
350 | Current policies supported: |
351 | */ |
// Winograd scheduling policies. WSCHED_INVALID is zero; the remaining
// enumerators follow consecutively in declaration order.
enum winograd_sched_t {
    WSCHED_INVALID = 0,

    /* Forward & backward-data */
    // W_S_G_D implements discrete transforms
    WSCHED_DATA_W_S_G_D = 1,
    // W_SGD implements tiled transforms s.t. GEMM could reuse data in L2
    WSCHED_DATA_W_SGD = 2,

    /* Backward-weights */
    WSCHED_WEI_S_D_G_W = 3,
    WSCHED_WEI_SDGtWo = 4,
    WSCHED_WEI_S_D_Giot_W = 5,
    WSCHED_WEI_SDGt_W = 6,
};
367 | |
368 | struct jit_conv_winograd_conf_t : public jit_conv_conf_t { |
369 | int itiles; |
370 | int jtiles; |
371 | int ntiles; |
372 | int ic_simd_block = 16; |
373 | int oc_simd_block = 16; |
374 | int oc_reg_block; |
375 | int ic_reg_block; |
376 | int tile_block; |
377 | int tile_block_ur; |
378 | int nb_tile_block_ur; |
379 | |
380 | bool double_buffering; |
381 | bool with_relu_postsum; |
382 | int zmm_start; |
383 | int nb_reg; |
384 | |
385 | int dimK; |
386 | int dimK_reg_block; |
387 | int dimK_block; |
388 | int dimK_nb_block; |
389 | |
390 | int dimM; |
391 | int dimM_reg_block; |
392 | int dimM_simd_block; |
393 | int dimM_block; |
394 | int dimM_nb_block; |
395 | |
396 | int dimN; |
397 | int dimN_reg_block; |
398 | int dimN_bcast_ur; |
399 | int dimN_block; |
400 | int dimN_nb_block; |
401 | |
402 | winograd_sched_t sched_policy; |
403 | }; |
404 | |
// Per-call argument block handed to JIT convolution kernels.
// NOTE: member order is preserved from the original declaration; value and
// `_prf` companions that were adjacent are now declared on one line.
struct jit_conv_call_s {
    // The four tensor pointers are declared const although one of them is
    // written depending on the propagation kind (a "hack" in the original
    // author's words).
    const void *src; // non-const for backward_data
    const void *dst; // non-const for forward
    const void *filt; // non-const for backward_weights
    const void *bias; // non-const for backward_bias
    const void *src_prf, *dst_prf, *filt_prf, *bias_prf;
    const void *scales;
    const void *acc_s32;
    const void *compensation;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point;
    const int32_t *zero_point_pbuff;
    const int32_t *dst_zero_point;
    const void *tile_cfg, *tile_cfg_tail;
    const void *dst_scale;

    // ptr to table of void * elements that are pointers to
    // post_op binary src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    // pointer to dst memory (no offset)
    const void *dst_orig;

    size_t oc_l_off_prf;
    const void *dst_orig_prf;

    size_t kd_offset, kd_offset_prf;
    size_t kh_offset, kh_offset_prf;
    size_t os_index_begin, os_index_begin_prf;
    size_t os_index_end, os_index_end_prf;
    size_t kd_padding, kd_padding_prf;
    size_t kh_padding, kh_padding_prf;
    size_t iwb, iwb_prf;
    size_t owb, owb_prf;
    size_t ohb;
    size_t kw_padding;
    size_t channel, channel_prf;
    size_t ic_blocks, oc_blocks;
    size_t ur_w, ur_str_w;
    size_t ch_blocks, ch_blocks_prf;
    size_t reduce_work, reduce_work_prf;
    size_t load_work, load_work_prf;
    size_t l_overflow, r_overflow;
    size_t t_overflow, b_overflow;
    size_t f_overflow, back_overflow;
    size_t last_h;
    size_t tail;
    size_t current_iw;
    size_t is_osb;
    int flags, flags_prf;
    int oc_flag;
    size_t last_ic_block, last_oc_block;
};
482 | |
// Per-call argument block handed to JIT deconvolution kernels.
// NOTE: member order is preserved from the original declaration.
struct jit_deconv_call_s {
    // The four tensor pointers are declared const although one of them is
    // written depending on the propagation kind (a "hack" in the original
    // author's words).
    const void *src; // non-const for backward_data
    const void *dst; // non-const for forward
    const void *filt; // non-const for backward_weights
    const void *bias; // non-const for backward_bias
    const void *scales, *dst_scale;
    const void *compensation;
    const int32_t *zp_src_pad_str_compensation;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point, *dst_zero_point;

    // ptr to table of void * elements that are pointers to post_op binary
    // src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // pointer to dst memory (no offset)
    const void *dst_orig;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    size_t t_overflow, b_overflow;
    size_t f_overflow, back_overflow;
    size_t kh_padding, kd_padding;
    size_t oc_blocks;
};
515 | |
// Per-call argument block for the depthwise convolution kernel.
// NOTE: member order is preserved from the original declaration.
struct jit_dw_conv_call_s {
    const void *input, *output;
    const void *filter, *bias;
    size_t kh_count;
    size_t oh_count, oh_index;
    size_t filter_pad_off;
    // Flags passed by driver execution to inner kernel
    unsigned char exec_flags;
};
528 | |
// Per-call argument block for Winograd transform kernels: six size_t values
// followed by seven mutable buffer pointers.
// NOTE: member order is preserved from the original declaration.
struct jit_wino_transform_call_s {
    size_t tile_block, tile_block_ur, nb_tile_block_ur;
    size_t tile_count;
    size_t tj, ti;
    void *src, *dst;
    void *Mw, *M;
    void *T, *G;
    void *bias;
};
544 | |
545 | struct jit_1x1_conv_conf_t { |
546 | prop_kind_t prop_kind; |
547 | bool has_vnni; |
548 | |
549 | int ndims; |
550 | int mb; |
551 | int ngroups, ic, oc, oc_without_padding, ic_without_padding; |
552 | int id, ih, iw, od, oh, ow; |
553 | int f_pad, t_pad, l_pad; |
554 | int kd, kh, kw; |
555 | int stride_d, stride_h, stride_w; |
556 | format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround |
557 | bool with_bias; |
558 | bool with_sum; |
559 | bool with_eltwise; |
560 | bool with_binary; |
561 | bool with_dw_conv; |
562 | |
563 | post_ops_t post_ops; |
564 | |
565 | dim_t is, os; |
566 | int ic_block, oc_block; |
567 | |
568 | int ur, ur_tail; |
569 | |
570 | dim_t reduce_dim; |
571 | int reduce_block, nb_reduce, nb_reduce_blocking, nb_reduce_blocking_max; |
572 | int load_dim, load_block, nb_load, nb_load_blocking, nb_load_blocking_max, |
573 | nb_load_chunk; |
574 | dim_t bcast_dim; |
575 | int bcast_block, nb_bcast, nb_bcast_blocking, nb_bcast_blocking_max; |
576 | |
577 | int reduce_loop_unroll; |
578 | dim_t reduce_loop_bcast_step; |
579 | int reduce_loop_load_step; |
580 | int load_loop_load_step, load_loop_iter_step; |
581 | int bcast_loop_output_step, bcast_loop_output_substep; |
582 | int bcast_loop_bcast_step, bcast_loop_bcast_substep; |
583 | int load_grp_count; |
584 | conv_1x1_loop_order_t loop_order; |
585 | bool use_vmovntps; |
586 | /* avx512 core */ |
587 | bool expl_bcast; |
588 | |
589 | int typesize_in; |
590 | int typesize_out; |
591 | int typesize_bia; |
592 | int typesize_acc; |
593 | |
594 | bool transpose_src; |
595 | int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b; |
596 | int is_oc_scale; |
597 | data_type_t bia_dt; |
598 | data_type_t dst_dt; |
599 | data_type_t sum_dt; |
600 | bool signed_input; |
601 | float wei_adj_scale; |
602 | // zero-point compensation |
603 | bool src_zero_point; |
604 | bool dst_zero_point; |
605 | bool zp_src_is_common; // common, otherwise (TODO) per-channel |
606 | |
607 | bool dst_scale; |
608 | |
609 | cpu_isa_t isa; |
610 | bool uses_permw_transposition; |
611 | }; |
612 | |
// Per-call argument block handed to JIT 1x1 convolution kernels.
// NOTE: member order is preserved from the original declaration.
struct jit_1x1_conv_call_s {
    const void *bcast_data, *load_data, *output_data;
    // used in forward and backward_weights only
    const void *bias_data;
    const void *acc_s32;
    const void *scales;
    const void *compensation;
    const void *store_buffer;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point, *dst_zero_point;
    const void *dst_scale;

    // ptr to table of void * elements that are pointers to
    // post_op binary src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    // logical (# of elems) offset to the processed pixel
    // (for non-broadcasting policy)
    size_t dst_l_off;
    // pointer to dst memory (no offset)
    const void *dst_orig;

    size_t load_dim, bcast_dim, reduce_dim;

    // used in backward_weights only
    size_t output_stride;

    size_t first_last_flag;
};
646 | |
647 | struct jit_pool_conf_t { |
648 | int ndims; |
649 | int mb, c, c_without_padding; |
650 | int id, ih, iw, od, oh, ow; |
651 | int stride_d, stride_h, stride_w; |
652 | int kd, kh, kw; |
653 | int f_pad, t_pad, l_pad; |
654 | alg_kind_t alg; |
655 | bool is_training; |
656 | bool pad_w_is_null; |
657 | bool is_backward; |
658 | bool simple_alg; |
659 | bool is_c_padded; |
660 | data_type_t ind_dt; |
661 | |
662 | int c_block, c_tail, nb_c; |
663 | int ur_bc, ur_bc_tail; |
664 | int ur_c, ur_c_tail; |
665 | int ur; |
666 | size_t tail[4]; |
667 | bool safe_c_tail; |
668 | data_type_t src_dt; |
669 | data_type_t dst_dt; |
670 | |
671 | int dt_size; |
672 | bool is_bf16; |
673 | bool is_f16; |
674 | jit_memory_tag_kind_t tag_kind; |
675 | bool is_plain() const { |
676 | return (tag_kind == jit_memory_tag_kind_t::ncsp |
677 | || tag_kind == jit_memory_tag_kind_t::nspc); |
678 | } |
679 | |
680 | cpu_isa_t isa; |
681 | post_ops_t post_ops; |
682 | bool with_postops; |
683 | bool with_eltwise; |
684 | bool with_binary; |
685 | int nthr; |
686 | memory_desc_t tmp_md; |
687 | }; |
688 | |
// Per-call argument block handed to JIT pooling kernels.
// NOTE: member order is preserved from the original declaration.
struct jit_pool_call_s {
    const void *src, *dst, *indices;
    const void *src_prf, *dst_prf, *indices_prf;
    const void *post_ops_binary_rhs_arg_vec;
    const void *dst_orig;
    const void *dst_po_helper;
    size_t zero_ih, zero_id;
    const void *zero_ptr;
    size_t kd_padding, kh_padding;
    size_t kh_padding_shift, kd_padding_shift;
    size_t kw_padding;
    const void *init_value;
    float ker_area_h;
    // number of channel blocks to process
    size_t ur_bc;
    // number of channel blocks already processed
    size_t b_c;
};
712 | |
713 | struct jit_resampling_conf_t { |
714 | unsigned ndims = 0; |
715 | |
716 | unsigned c = 0; |
717 | unsigned id = 0, ih = 0, iw = 0; |
718 | unsigned od = 0, oh = 0, ow = 0; |
719 | |
720 | unsigned stride_d = 0; |
721 | unsigned stride_h = 0; |
722 | unsigned stride_w = 0; |
723 | unsigned inner_stride = 0; |
724 | |
725 | // The linear algorithm is an approximation of the point |
726 | // value based on the limit values. For one dimension, |
727 | // the approximation is based on the line, for two |
728 | // dimensions it will be a rectangle, and for three |
729 | // dimensions it will be a cuboid. Therefore, |
730 | // the possible variants for the number of corners are 2, 4, 8. |
731 | unsigned number_of_corners = 0; |
732 | |
733 | bool is_data_size_bigger_than_L3 = false; |
734 | bool is_saturation_needed = false; |
735 | data_type_t src_data_type = data_type::undef; |
736 | data_type_t dst_data_type = data_type::undef; |
737 | size_t src_dt_size = 0; |
738 | size_t dst_dt_size = 0; |
739 | size_t output_data_size = 0; |
740 | size_t el_size_of_indices = 0; |
741 | |
742 | bool is_blocked_8_format = false; |
743 | format_tag_t src_tag = format_tag::undef; |
744 | jit_memory_tag_kind_t tag_kind = jit_memory_tag_kind_t::undef; |
745 | alg_kind_t alg = alg_kind::undef; |
746 | |
747 | cpu_isa_t isa = isa_undef; |
748 | |
749 | post_ops_t post_ops = post_ops_t(); |
750 | bool with_postops = false; |
751 | bool with_eltwise = false; |
752 | bool with_binary = false; |
753 | bool with_sum = false; |
754 | std::queue<float> sum_scales; |
755 | }; |
756 | |
// Per-call argument block handed to JIT resampling kernels. All members
// default to zero / nullptr.
// NOTE: member order is preserved from the original declaration.
struct jit_resampling_call_s {
    size_t batch_of_sp_points_to_process {0};

    const void *src {nullptr};
    const void *dst {nullptr};
    const void *indices {nullptr};
    const void *weights {nullptr};
    const void *post_ops_binary_rhs_arg_vec {nullptr};
    const void *dst_orig {nullptr};

    size_t c_offset {0};

    size_t src_offset_top {0};
    size_t src_offset_bottom {0};
    size_t src_offset_front {0};
    size_t src_offset_back {0};

    float weight_top {0.0f};
    float weight_bottom {0.0f};
    float weight_front {0.0f};
    float weight_back {0.0f};
};
779 | |
780 | struct jit_brdgmm_conv_conf_t { |
781 | |
782 | int nthr; |
783 | int mb, ngroups, ic, oc; |
784 | int ih, iw, oh, ow; |
785 | int l_pad, r_pad, t_pad, b_pad; |
786 | int kh, kw; |
787 | int stride_h, stride_w; |
788 | int nb_ch, ch_block, chb_tail; |
789 | int nb_ch_blocking; |
790 | int ow_block, ow_tail, nb_ow; |
791 | // idx of jit kernel when mutiple jit kernels are used in a primitive. |
792 | int chb_tail_idx, ow_tail_idx, nb_ch_blocking_idx; |
793 | int adjusted_batch_size; |
794 | |
795 | bool with_bias; |
796 | bool with_post_ops; |
797 | bool with_scale; |
798 | bool is_oc_scale; |
799 | |
800 | data_type_t src_dt; |
801 | data_type_t wei_dt; |
802 | data_type_t bia_dt; |
803 | data_type_t dst_dt; |
804 | |
805 | brgemm_batch_kind_t batch_kind; |
806 | |
807 | size_t src_dsz; |
808 | size_t wei_dsz; |
809 | size_t bia_dsz; |
810 | size_t dst_dsz; |
811 | |
812 | cpu_isa_t isa; |
813 | }; |
814 | |
// Loop-nest orderings for the brgemm-based convolution driver.
enum conv_brgemm_loop_order_t {
    loop_ndhwgc = 0,
    loop_ngcdhw = 1,
};
819 | |
// Execution strategy for the brgemm-based convolution; zero means "not set".
enum conv_brgemm_exec_type_t {
    exec_undefined = 0,
    exec_base = 1,
    exec_trans = 2,
    exec_vpad = 3,
};
826 | |
827 | struct jit_brgemm_conv_conf_t { |
828 | cpu_isa_t isa; |
829 | prop_kind_t prop_kind; |
830 | conv_brgemm_loop_order_t loop_order; |
831 | conv_harness_t harness; |
832 | int simd_w, acc_simd_w, amx_w, amx_h; |
833 | int ndims; |
834 | int mb; |
835 | int ngroups, ic, oc, oc_without_padding, ic_without_padding; |
836 | |
837 | int od_block, oh_block, nb_od, |
838 | nb_oh; // blocking - included in parallelization |
839 | int id_block, ih_block, nb_id, nb_ih; |
840 | dim_t inp_buffer_size, inp_buffer_mask_size; |
841 | conv_brgemm_exec_type_t exec_type; |
842 | |
843 | int id, ih, iw, od, oh, ow, os, is, idp, ihp, iwp, icp, odp, ohp, owp, ocp; |
844 | int f_pad, l_pad, t_pad; |
845 | int back_pad, r_pad, b_pad; |
846 | int l_ovf, r_ovf, t_ovf, b_ovf, f_ovf, back_ovf; |
847 | int kd, kh, kw; |
848 | int ext_kd, ext_kh, ext_kw; |
849 | int kd_block, kh_block, kw_block, kd_block_pad, kh_block_pad, kw_block_pad; |
850 | int stride_d, stride_h, stride_w; |
851 | int dilate_d, dilate_h, dilate_w; |
852 | format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround |
853 | bool with_bias; |
854 | bool with_sum; |
855 | bool with_eltwise; |
856 | bool with_binary; |
857 | |
858 | bool is_fused_conv; |
859 | bool is_is_blocking; |
860 | bool is_os_blocking; |
861 | bool is_rtus; |
862 | int nb_ic, ic_block; |
863 | int nb_oc, oc_block; |
864 | int nb_iw, iw_block, iw_tail; |
865 | int nb_ow, ow_block, ow_tail; |
866 | int nb_is, is_block; |
867 | int nb_os, os_block; |
868 | int nb_oc_blocking; |
869 | int nb_ic_blocking; |
870 | int nb_is_blocking; |
871 | int nb_os_blocking; |
872 | |
873 | data_type_t src_dt; |
874 | data_type_t dst_dt; |
875 | data_type_t wei_dt; |
876 | data_type_t acc_dt; |
877 | data_type_t bia_dt; |
878 | size_t src_dsz; |
879 | size_t wei_dsz; |
880 | size_t dst_dsz; |
881 | size_t acc_dsz; |
882 | size_t bia_dsz; |
883 | |
884 | bool use_buffer; |
885 | dim_t buffer_size; |
886 | dim_t ker_ranges_size; |
887 | dim_t comp_a_buffer_size; |
888 | dim_t s8s8_comp_buffer_size; |
889 | |
890 | bool with_scales; |
891 | int is_ic_scale, is_oc_scale; |
892 | |
893 | int LDA, LDB, LDC, LDD; |
894 | int M, N, K, M_tail, N_tail, K_tail; |
895 | // M for brgemm kernel. For use_store_mask it is usually greater than M (M_tail). Otherwise it is equal to M (M_tail) |
896 | int brgM, brgM_tail; |
897 | int gemm_batch_size, adjusted_batch_size; |
898 | brgemm_batch_kind_t brg_type; |
899 | // strides for brg_type == brgemm_strd |
900 | dim_t brg_stride_a, brg_stride_b; |
901 | int nthr; |
902 | |
903 | int max_batch; |
904 | int max_vpad; |
905 | int amx_buf_size_per_thread; |
906 | |
907 | bool wei_plain; |
908 | bool is_ic_padded, is_oc_padded; |
909 | int kw_sets, kh_sets; |
910 | bool copy_block_only; |
911 | bool amx_tile_load_xx; |
912 | int use_M_mask; |
913 | int oskip, iskip; |
914 | bool brgemm_bd_loop_innermost; |
915 | |
916 | bool use_uker; |
917 | bool var_bs {false}; |
918 | bool use_interleave_stores; |
919 | brgemm_kernel_prefetching_t hint_prefetching; |
920 | bool is_1x1; |
921 | bool s8s8_avx512; |
922 | bool src_zero_point; |
923 | bool dst_zero_point; |
924 | bool req_brg_comp_pad; |
925 | bool req_cal_comp_pad; |
926 | bool is_bf32; |
927 | bool comp_with_vpads; |
928 | |
929 | int nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b, nthr_oh; |
930 | bool transform_to_vnni; |
931 | bool has_vnni; |
932 | int ic_tail, oc_tail; |
933 | size_t tr_src_buf_size, tr_src_buf_count; |
934 | size_t tr_diff_dst_buf_size, tr_diff_dst_buf_count; |
935 | int tr_src_num_guard_elems; |
936 | bool global_transpose; // diff_dst & src tensors are transposed in one go |
937 | int nthr_mb_work; |
938 | int tr_iw, tr_ow; |
939 | int spatial_blk_size; // Height/depth block size inside the driver |
940 | int typesize_in; |
941 | int typesize_out; |
942 | bool tr_ocb_chunk = false; |
943 | bool tr_icb_chunk = false; |
944 | }; |
945 | |
946 | struct jit_shuffle_conf_t { |
947 | unsigned ndims = 0; |
948 | |
949 | unsigned mb = 0, c = 0, d = 0, h = 0, w = 0, sp = 0; |
950 | |
951 | unsigned stride_mb = 0; |
952 | unsigned blk_size = 0; |
953 | unsigned group_size = 0; |
954 | unsigned axis = 0; |
955 | unsigned axis_size = 0; |
956 | unsigned simd_tail = 0; |
957 | unsigned simd_w = 0; |
958 | |
959 | jit_memory_tag_kind_t tag_kind = jit_memory_tag_kind_t::undef; |
960 | data_type_t data_type = data_type::undef; |
961 | size_t dt_size = 0; |
962 | unsigned el_size_of_indices = 0; |
963 | dim_t c_split_size = 0; |
964 | dim_t sp_split_size = 0; |
965 | |
966 | cpu_isa_t isa = isa_undef; |
967 | }; |
968 | |
969 | struct jit_shuffle_call_s { |
970 | const void *src = nullptr; |
971 | void *dst = nullptr; |
972 | const void *input_off_ptr = nullptr; |
973 | |
974 | dim_t cb_loop_size |
975 | = 0; // number of loop iterations over corresponding C batches |
976 | bool is_padded_block = false; |
977 | }; |
978 | |
// Layout-driven flavour of a binary operation; `none` is zero.
enum class binary_op_t : unsigned {
    none = 0,
    c_blocked = 1,
    n_spatial_c = 2,
    n_c_spatial = 3
};
980 | |
// Broadcast flavour of a binary operation.
enum class binary_bcast_t : unsigned {
    none = 0, // tensor operation
    scalar = 1,
    per_batch = 2,
    per_c = 3,
    per_w = 4
};
988 | |
989 | struct jit_binary_conf_t { |
990 | binary_op_t op_type = binary_op_t::none; |
991 | binary_bcast_t bcast_type = binary_bcast_t::none; |
992 | bool do_scale_src0 = false; |
993 | bool do_scale_src1 = false; |
994 | bool do_sum = false; |
995 | bool with_eltwise = false; |
996 | bool with_binary = false; |
997 | bool with_postops = false; |
998 | float sum_scale = 0.f; |
999 | bool use_stride_src1 = false; |
1000 | bool broadcast_src1_value = false; |
1001 | bool use_stride_rhs_postops = false; |
1002 | bool postops_per_oc_broadcast_exists = false; |
1003 | bool is_i8 = false; |
1004 | bool is_bf16 = false; |
1005 | bool is_f16 = false; |
1006 | bool is_src_different_layouts = false; |
1007 | dim_t outer_dims = 1; |
1008 | int src1_stride = 1; |
1009 | int not_bcasted_sp_dims = 0; |
1010 | cpu_isa_t isa = isa_undef; |
1011 | |
1012 | data_type_t src0_type = data_type::undef; |
1013 | data_type_t src1_type = data_type::undef; |
1014 | data_type_t dst_type = data_type::undef; |
1015 | }; |
1016 | |
// Per-call argument block handed to JIT binary kernels.
// keep all sizes at 8 bytes -- jit code expects this
// NOTE: member order is preserved from the original declaration.
struct jit_binary_call_s {
    const void *src0;
    const void *src1;
    const void *dst;
    const void *indices;
    const float *scales_src0;
    const float *scales_src1;
    size_t spat_offt_count;
    const void *post_ops_binary_rhs_arg_vec;
    size_t src1_stride_range;
    const void *dst_orig;
};
1026 | |
1027 | struct jit_reduction_conf_t { |
1028 | data_type_t src_type = data_type::undef; |
1029 | data_type_t dst_type = data_type::undef; |
1030 | data_type_t acc_type = data_type::undef; |
1031 | |
1032 | std::size_t src_dt_size = 0; |
1033 | std::size_t dst_dt_size = 0; |
1034 | std::size_t acc_dt_size = 0; |
1035 | |
1036 | alg_kind_t alg = alg_kind::undef; |
1037 | cpu_isa_t isa = isa_undef; |
1038 | |
1039 | dim_t idle_size = 0; |
1040 | dim_t reduce_size = 0; |
1041 | |
1042 | bool is_saturation_needed = false; |
1043 | |
1044 | post_ops_t post_ops = post_ops_t(); |
1045 | bool with_postops = false; |
1046 | bool with_eltwise = false; |
1047 | bool with_binary = false; |
1048 | bool with_sum = false; |
1049 | std::queue<float> sum_scales; |
1050 | }; |
1051 | |
// Per-call argument block handed to JIT reduction kernels. All pointers
// default to nullptr; only dst is a mutable pointer.
// NOTE: member order is preserved from the original declaration.
struct jit_reduction_call_s {
    const void *src {nullptr};
    void *dst {nullptr};
    const void *post_ops_binary_rhs_arg_vec {nullptr};
    const void *dst_orig {nullptr};
};
1058 | |
1059 | } // namespace x64 |
1060 | } // namespace cpu |
1061 | } // namespace impl |
1062 | } // namespace dnnl |
1063 | |
1064 | #endif |
1065 | |