/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_MATMUL_BRGEMM_MATMUL_UTILS_HPP
#define CPU_X64_MATMUL_BRGEMM_MATMUL_UTILS_HPP

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/brgemm/brgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace matmul {

constexpr int max_batch_ndims = DNNL_MAX_NDIMS - 2;

struct brgemm_matmul_bcast_desc_t {

    brgemm_matmul_bcast_desc_t()
        : bcast_mask(0)
        , first_bcast_dim(-1)
        , last_bcast_dim(-1)
        , first_bcast_dim_to_last_batch_dim_prod(1)
        , bcast_dims_prod(1)
        , batch_dims {0}
        , gb_off {0} {}

    void set_params(const dims_t &inp_dims, const dims_t &dst_d_dims,
            int batch_ndims, dim_t batch) {
        const int ndims = batch_ndims;
        first_bcast_dim_to_last_batch_dim_prod = batch;
        for (int d = 0; d < ndims; ++d) {
            batch_dims[d] = dst_d_dims[d];
            gb_off[d] = (d == 0 ? batch : gb_off[d - 1]) / dst_d_dims[d];
            if (dst_d_dims[d] != 1 && inp_dims[d] == 1) { // broadcast
                const int mask = 1 << (ndims - 1);
                bcast_mask |= (mask >> d);
                if (first_bcast_dim == -1) {
                    first_bcast_dim = d;
                    if (d == 0) // broadcast_dim == B0
                        first_bcast_dim_to_last_batch_dim_prod = batch;
                }
                last_bcast_dim = d;
                bcast_dims_prod *= dst_d_dims[d];
            }
            if (first_bcast_dim == -1) // broadcast_dim > B0
                first_bcast_dim_to_last_batch_dim_prod /= dst_d_dims[d];
        }
    }

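    // Illustrative example (hypothetical shapes, for exposition only): with
    // dst batch dims {2, 3, 4}, input batch dims {2, 1, 4} and batch = 24,
    // set_params() yields bcast_mask = 0b010, first_bcast_dim =
    // last_bcast_dim = 1, bcast_dims_prod = 3, gb_off = {12, 4, 1} and
    // first_bcast_dim_to_last_batch_dim_prod = 12.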
    // Bit (ndims - 1 - d) is set when batch dim d is broadcast, clear otherwise
    int bcast_mask;

    int first_bcast_dim;
    int last_bcast_dim;

    dim_t first_bcast_dim_to_last_batch_dim_prod;
    dim_t bcast_dims_prod;

    dim_t batch_dims[max_batch_ndims];
    dim_t gb_off[max_batch_ndims]; // generalized batch offset
};

struct brgemm_matmul_conf_t {
    int ndims, batch_ndims;
    dim_t M, N, K, batch, batch_without_first_dim;
    dim_t M_blk, N_blk, K_blk, M_tail, N_tail, K_tail;
    int M_chunk_size, N_chunk_size;
    dim_t LDA, LDB, LDC, LDD;
    int brgemm_batch_size, brgemm_batch_tail_size;
    int wei_n_blk, wei_k_blk;
    brgemm_batch_kind_t brg_type;

    cpu_isa_t isa;

    format_tag_t src_tag, wei_tag, dst_tag, bia_tag;
    bool with_bias;
    bool with_sum;
    bool with_eltwise;
    bool with_binary;
    bool with_scales;
    bool s8s8_compensation_required;
    bool is_oscale_per_n;
    brgemm_broadcast_t src_zp_type;
    brgemm_broadcast_t wei_zp_type;
    brgemm_broadcast_t dst_zp_type;

    bool use_buffer_a;
    bool use_buffer_a_tail_only;
    bool use_buffer_b;
    bool use_buffer_c;

    brgemm_matmul_bcast_desc_t bcast_A_desc;
    brgemm_matmul_bcast_desc_t bcast_B_desc;

    data_type_t src_dt;
    data_type_t dst_dt;
    data_type_t wei_dt;
    data_type_t acc_dt;
    data_type_t bia_dt;
    int nthr;
    int nthr_k;

    // Auxiliary values for init_config() and execute()
    dim_t a_dt_sz, b_dt_sz, c_dt_sz, acc_dt_sz, bias_dt_sz;

    // Datatype sizes for the transposed copy buffers when they differ from
    // a_dt_sz / b_dt_sz (e.g. BF32 implementations down-convert FP32 inputs
    // to BF16 in the copy routines)
    dim_t tr_a_dt_sz, tr_b_dt_sz;

    int M_chunks;
    int N_chunks;
    int K_chunks;
    int num_M_blocks;
    int num_N_blocks;
    dim_t M_chunk_elems;
    dim_t N_chunk_elems;
    dim_t K_chunk_elems;

    // Pre-calculated memory strides for each tensor
    dim_t A_strides[3];
    dim_t B_strides[3];
    dim_t C_strides[3];
    dim_t buffer_c_chunk_sz;
    dim_t buffer_c_per_thread_sz;

    dim_t A_ptr_shift_b;
    dim_t B_ptr_shift_b;
    dim_t C_ptr_shift_b;
    dim_t copy_A_src_stride;
    dim_t copy_B_wei_stride;

    dim_t buffer_a_chunk_sz;
    dim_t buffer_a_chunk_shift_along_m;
    dim_t buffer_a_per_thread_sz;

    dim_t buffer_b_chunk_sz;
    dim_t buffer_b_per_thread_sz;
    dim_t s8s8_comp_ithr_str;
    dim_t s8s8_comp_b_str;
    dim_t s8s8_comp_n_str;
    bool has_zero_point_a, has_zero_point_b, has_zero_point_c;
    bool post_ops_applicable;
    bool transposed_A;
    bool blocked_B;

    dim_t zp_a_comp_shift_n;
    dim_t zp_a_comp_elems_per_thr;

    dim_t zp_b_comp_result_shift_m;
    dim_t zp_b_comp_buffer_start;
    dim_t zp_b_comp_buffer_shift_m;
    dim_t zp_b_comp_elems_per_thr;

    int wsp_tile_per_thr_bytes;
    int brgemm_batch_element_per_thr_sz;
    bool is_amx;

    int required_k_granularity;
    bool is_bf32 = false;
    bool req_wei_vnni_downconvert = false;
};

struct brgemm_matmul_conf_utils_t {

    brgemm_matmul_conf_utils_t(brgemm_matmul_conf_t &bgmmc, const cpu_isa_t isa,
            const primitive_attr_t &attr, bool A_any_layout, bool B_any_layout,
            bool C_any_layout, bool bias_any_layout);

    inline bool check_b_layout_blocked_by_n(format_tag_t matrix_b_tag) const {
        return blocked_B_layouts_allowed
                && utils::one_of(matrix_b_tag, blocked_64n_B_layout_tag,
                        blocked_48n_B_layout_tag, blocked_32n_B_layout_tag,
                        blocked_16n_B_layout_tag);
    }

    inline bool get_blocked_B() const {
        return blocked_B_layouts_allowed
                && check_b_layout_blocked_by_n(bgmmc.wei_tag);
    }

    inline bool use_buffer_b(bool use_heuristic = true) const {
        if (bgmmc.is_amx)
            // use b_buffer for AMX when:
            // - not bf32 && using non-blocked weights
            // - is bf32
            return IMPLICATION(!wei_down_convert_to_vnni(), !bgmmc.blocked_B);

        // Values based on measured performance difference
        // between plain and copy-to-blocked routine.
        size_t big_LDB = bgmmc.N > 256;
        bool is_pow2 = math::is_pow2(bgmmc.N);
        bool use_copy_buffer = IMPLICATION(
                this->is_f32(), use_heuristic && (big_LDB && is_pow2));
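        // Illustrative example (hypothetical shapes): for an f32 problem with
        // a plain B layout and N = 512 (a power of two larger than 256), the
        // heuristic above selects the copy-to-blocked path.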
        return (this->is_f16() && bgmmc.isa == avx512_core_fp16)
                || (use_copy_buffer && this->check_is_plain(bgmmc.wei_tag))
                || this->check_is_transposed(bgmmc.wei_tag)
                || (bgmmc.wei_tag == format_tag::acbd)
                || (bgmmc.wei_tag == format_tag::adbc);
    }

    inline dim_t get_actual_LDB() const {
        if (bgmmc.wei_tag == format_tag::acbd && !bgmmc.use_buffer_b) {
            assert(bgmmc.b_dt_sz == bgmmc.tr_b_dt_sz);
            return bgmmc.B_strides[1] / bgmmc.b_dt_sz;
        }
        bool use_blocked_LDB = bgmmc.is_amx || bgmmc.use_buffer_b
                || bgmmc.wei_tag != plain_tensor_layout_tag;
        return use_blocked_LDB ? bgmmc.wei_n_blk : bgmmc.N;
    }

    inline bool maybe_low_brg_blocking() const {
        // Check if M_blk is a prime number from 32 to 64
        const bool is_prime_num
                = utils::one_of(bgmmc.M_blk, 37, 41, 43, 47, 53, 59, 61);
        const bool maybe_ldb_tail = bgmmc.N % 16;
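        // Illustrative example (hypothetical values): M_blk = 37 and N = 100
        // gives maybe_ldb_tail = true (100 % 16 != 0), so this returns true.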
        return is_prime_num && IMPLICATION(bgmmc.M_blk < 48, maybe_ldb_tail);
    }

    inline bool check_n_blk_fixed() const { return n_blk_fixed; }

    inline bool check_is_transposed(format_tag_t tag) const {
        return tag == transposed_tensor_layout_tag;
    }

    inline bool check_is_plain(format_tag_t tag) const {
        return tag == plain_tensor_layout_tag;
    }

    inline bool is_f32() const { return f32_dt; }

    inline bool is_bf16() const { return bf16_dt; }

    inline bool is_f16() const { return f16_dt; }

    inline bool is_int8() const { return int8_dt; }

    inline bool is_bf32() const { return bf32_dt; }

    inline bool is_int8_with_bf16_dst() const {
        return this->is_int8() && bgmmc.dst_dt == data_type::bf16;
    }

    inline bool wei_down_convert_to_vnni() const {
        return bf32_dt && get_blocked_B();
    }

    inline bool is_any_B_layout() const { return B_any_layout; }

    status_t set_or_check_B_tag(
            memory_desc_t &B_md, bool init_n_tag = true) const;
    status_t update_and_check_B_tag(memory_desc_t &B_md, int n_blk_size) const;
    status_t set_or_check_tags(memory_desc_t &A_md, memory_desc_t &C_md,
            memory_desc_t &bias_md) const;
    status_t set_B_flags(memory_desc_t &B_md) const;
    format_tag_t pick_blocked_B_layout(int n_blk) const;

private:
    brgemm_matmul_conf_t &bgmmc;

    const bool f32_dt, bf16_dt, f16_dt, int8_dt, bf32_dt;
    const bool A_any_layout;
    const bool B_any_layout;
    const bool C_any_layout;
    const bool bias_any_layout;

    const format_tag_t plain_tensor_layout_tag;
    const format_tag_t transposed_tensor_layout_tag;
    const format_tag_t blocked_64n_B_layout_tag, blocked_48n_B_layout_tag,
            blocked_32n_B_layout_tag, blocked_16n_B_layout_tag;
    const bool blocked_B_layouts_allowed;
    const bool n_blk_fixed;
};

void init_aux_values(brgemm_matmul_conf_t &bgmmc,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d,
        const memory_desc_wrapper &dst_d);

status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
        const matmul_desc_t &mmd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr);

void init_scratchpad(memory_tracking::registrar_t &scratchpad,
        const brgemm_matmul_conf_t &bgmmc);
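
// Expected call order (a sketch of typical usage, not a contract): the matmul
// primitive descriptor fills bgmmc via init_brgemm_matmul_conf() and then
// books the temporary buffers it needs (A/B/C copy buffers, zero-point and
// s8s8 compensation, AMX workspace) via init_scratchpad().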

int get_default_n_block(format_tag_t matrix_b_tag);

} // namespace matmul
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif