1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_BRGEMM_PRIMITIVE_CONF_HPP |
18 | #define CPU_X64_JIT_BRGEMM_PRIMITIVE_CONF_HPP |
19 | |
20 | #include "cpu/x64/brgemm/brgemm_types.hpp" |
21 | #include "cpu/x64/jit_primitive_conf.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace x64 { |
27 | |
28 | struct jit_brgemm_primitive_conf_t { |
29 | prop_kind_t prop_kind; |
30 | conv_loop_order_t loop_order; |
31 | conv_harness_t harness; |
32 | int simd_w; |
33 | int ndims; |
34 | int mb; |
35 | int ngroups, ic, oc, oc_without_padding, ic_without_padding; |
36 | int id, ih, iw, od, oh, ow, os; |
37 | int f_pad, l_pad, t_pad; |
38 | int back_pad, r_pad, b_pad; |
39 | int kd, kh, kw; |
40 | int stride_d, stride_h, stride_w; |
41 | int dilate_d, dilate_h, dilate_w; |
42 | format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround |
43 | bool is_wei_layout_any; |
44 | bool with_bias; |
45 | bool with_sum; |
46 | bool with_eltwise; |
47 | bool with_binary; |
48 | bool with_scales; |
49 | bool signed_input; |
50 | int nb_ic, ic_block, ic_block_ext; |
51 | int nb_oc, oc_block, oc_block_ext; |
52 | int nb_iw, iw_block; |
53 | int nb_ow, ow_block; |
54 | int nb_os, os_block; |
55 | int nb_oc_blocking; |
56 | int nb_ic_blocking; |
57 | int nb_os_blocking; |
58 | |
59 | data_type_t src_dt; |
60 | data_type_t dst_dt; |
61 | data_type_t wei_dt; |
62 | data_type_t acc_dt; |
63 | data_type_t bia_dt; |
64 | |
65 | bool is_amx; |
66 | bool use_buffer; |
67 | bool use_buffer_a; |
68 | bool use_buffer_b; |
69 | bool is_bf32; |
70 | |
71 | int is_oc_scale; |
72 | |
73 | int LDA, LDB, LDC, LDD; |
74 | int M, N, K, M_tail, N_tail, K_tail; |
75 | int gemm_batch_size, adjusted_batch_size; |
76 | brgemm_batch_kind_t brg_type; |
77 | int num_gemm_kernels; |
78 | int nthr, nthr_mb, nthr_oc_b, nthr_ic_b; |
79 | |
80 | cpu_isa_t isa; |
81 | bool ip_bwd_d_global_b_transpose; |
82 | bool use_uker; |
83 | bool use_interleave_stores; |
84 | int amx_buf_size_per_thread; |
85 | brgemm_kernel_prefetching_t hint_prefetching |
86 | = brgemm_kernel_prefetching_t::brgemm_prf_default; |
87 | bool ip_bwd_w_local_buffers_for_input_tensors; |
88 | }; |
89 | |
90 | } // namespace x64 |
91 | } // namespace cpu |
92 | } // namespace impl |
93 | } // namespace dnnl |
94 | |
95 | #endif |
96 | |