1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_PRIMITIVE_CONF_HPP
18#define CPU_X64_JIT_BRGEMM_PRIMITIVE_CONF_HPP
19
20#include "cpu/x64/brgemm/brgemm_types.hpp"
21#include "cpu/x64/jit_primitive_conf.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28struct jit_brgemm_primitive_conf_t {
29 prop_kind_t prop_kind;
30 conv_loop_order_t loop_order;
31 conv_harness_t harness;
32 int simd_w;
33 int ndims;
34 int mb;
35 int ngroups, ic, oc, oc_without_padding, ic_without_padding;
36 int id, ih, iw, od, oh, ow, os;
37 int f_pad, l_pad, t_pad;
38 int back_pad, r_pad, b_pad;
39 int kd, kh, kw;
40 int stride_d, stride_h, stride_w;
41 int dilate_d, dilate_h, dilate_w;
42 format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
43 bool is_wei_layout_any;
44 bool with_bias;
45 bool with_sum;
46 bool with_eltwise;
47 bool with_binary;
48 bool with_scales;
49 bool signed_input;
50 int nb_ic, ic_block, ic_block_ext;
51 int nb_oc, oc_block, oc_block_ext;
52 int nb_iw, iw_block;
53 int nb_ow, ow_block;
54 int nb_os, os_block;
55 int nb_oc_blocking;
56 int nb_ic_blocking;
57 int nb_os_blocking;
58
59 data_type_t src_dt;
60 data_type_t dst_dt;
61 data_type_t wei_dt;
62 data_type_t acc_dt;
63 data_type_t bia_dt;
64
65 bool is_amx;
66 bool use_buffer;
67 bool use_buffer_a;
68 bool use_buffer_b;
69 bool is_bf32;
70
71 int is_oc_scale;
72
73 int LDA, LDB, LDC, LDD;
74 int M, N, K, M_tail, N_tail, K_tail;
75 int gemm_batch_size, adjusted_batch_size;
76 brgemm_batch_kind_t brg_type;
77 int num_gemm_kernels;
78 int nthr, nthr_mb, nthr_oc_b, nthr_ic_b;
79
80 cpu_isa_t isa;
81 bool ip_bwd_d_global_b_transpose;
82 bool use_uker;
83 bool use_interleave_stores;
84 int amx_buf_size_per_thread;
85 brgemm_kernel_prefetching_t hint_prefetching
86 = brgemm_kernel_prefetching_t::brgemm_prf_default;
87 bool ip_bwd_w_local_buffers_for_input_tensors;
88};
89
90} // namespace x64
91} // namespace cpu
92} // namespace impl
93} // namespace dnnl
94
95#endif
96