/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_MATMUL_BRGEMM_MATMUL_UTILS_HPP
#define CPU_X64_MATMUL_BRGEMM_MATMUL_UTILS_HPP

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"

#include "cpu/x64/brgemm/brgemm.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace matmul {

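// A matmul tensor's last two dims are the matrix dims (M x K, K x N or
// M x N); all leading dims are treated as batch dims.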
constexpr int max_batch_ndims = DNNL_MAX_NDIMS - 2;

struct brgemm_matmul_bcast_desc_t {

    brgemm_matmul_bcast_desc_t()
        : bcast_mask(0)
        , first_bcast_dim(-1)
        , last_bcast_dim(-1)
        , first_bcast_dim_to_last_batch_dim_prod(1)
        , bcast_dims_prod(1)
        , batch_dims {0}
        , gb_off {0} {}

    void set_params(const dims_t &inp_dims, const dims_t &dst_d_dims,
            int batch_ndims, dim_t batch) {
        const int ndims = batch_ndims;
        first_bcast_dim_to_last_batch_dim_prod = batch;
        for (int d = 0; d < ndims; ++d) {
            batch_dims[d] = dst_d_dims[d];
            gb_off[d] = (d == 0 ? batch : gb_off[d - 1]) / dst_d_dims[d];
            if (dst_d_dims[d] != 1 && inp_dims[d] == 1) { // broadcast
                const int mask = 1 << (ndims - 1);
                bcast_mask |= (mask >> d);
                if (first_bcast_dim == -1) {
                    first_bcast_dim = d;
                    if (d == 0) // the broadcast dim is the outermost one
                        first_bcast_dim_to_last_batch_dim_prod = batch;
                }
                last_bcast_dim = d;
                bcast_dims_prod *= dst_d_dims[d];
            }
            if (first_bcast_dim == -1) // first broadcast dim not reached yet
                first_bcast_dim_to_last_batch_dim_prod /= dst_d_dims[d];
        }
    }
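
    // Illustrative walk-through (hypothetical shapes): with batch_ndims = 3,
    // dst_d_dims = {2, 3, 4} (batch = 24) and inp_dims = {2, 1, 4}, only
    // d = 1 is a broadcast dim, so set_params() yields:
    //   bcast_mask = 0b010 (bit (ndims - 1 - d) set per broadcast dim)
    //   first_bcast_dim = last_bcast_dim = 1, bcast_dims_prod = 3
    //   first_bcast_dim_to_last_batch_dim_prod = 24 / 2 = 12 (product of the
    //       batch dims from the first broadcast dim to the last batch dim)
    //   gb_off = {12, 4, 1} (row-major strides over the dst batch dims)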

    int bcast_mask; // bit (ndims - 1 - d) is 1 for broadcast dims, 0 otherwise

    int first_bcast_dim;
    int last_bcast_dim;

    dim_t first_bcast_dim_to_last_batch_dim_prod;
    dim_t bcast_dims_prod;

    dim_t batch_dims[max_batch_ndims];
    dim_t gb_off[max_batch_ndims]; // generalized batch offset
};

struct brgemm_matmul_conf_t {
    int ndims, batch_ndims;
    dim_t M, N, K, batch, batch_without_first_dim;
    dim_t M_blk, N_blk, K_blk, M_tail, N_tail, K_tail;
    int M_chunk_size, N_chunk_size;
    dim_t LDA, LDB, LDC, LDD;
    int brgemm_batch_size, brgemm_batch_tail_size;
    int wei_n_blk, wei_k_blk;
    brgemm_batch_kind_t brg_type;

    cpu_isa_t isa;

    format_tag_t src_tag, wei_tag, dst_tag, bia_tag;
    bool with_bias;
    bool with_sum;
    bool with_eltwise;
    bool with_binary;
    bool with_scales;
    bool s8s8_compensation_required;
    bool is_oscale_per_n;
    brgemm_broadcast_t src_zp_type;
    brgemm_broadcast_t wei_zp_type;
    brgemm_broadcast_t dst_zp_type;

    bool use_buffer_a;
    bool use_buffer_a_tail_only;
    bool use_buffer_b;
    bool use_buffer_c;

    brgemm_matmul_bcast_desc_t bcast_A_desc;
    brgemm_matmul_bcast_desc_t bcast_B_desc;

    data_type_t src_dt;
    data_type_t dst_dt;
    data_type_t wei_dt;
    data_type_t acc_dt;
    data_type_t bia_dt;
    int nthr;
    int nthr_k;

    // Auxiliary values for init_brgemm_matmul_conf() and execute()
    dim_t a_dt_sz, b_dt_sz, c_dt_sz, acc_dt_sz, bias_dt_sz;

    // Size of the transposed-buffer data type when it differs from the
    // source data type (e.g. BF32 implementations have to down-convert
    // FP32 inputs to BF16 in the copy buffers)
    dim_t tr_a_dt_sz, tr_b_dt_sz;

    int M_chunks;
    int N_chunks;
    int K_chunks;
    int num_M_blocks;
    int num_N_blocks;
    dim_t M_chunk_elems;
    dim_t N_chunk_elems;
    dim_t K_chunk_elems;

    // Pre-calculated memory strides for each tensor
    dim_t A_strides[3];
    dim_t B_strides[3];
    dim_t C_strides[3];
    dim_t buffer_c_chunk_sz;
    dim_t buffer_c_per_thread_sz;

    dim_t A_ptr_shift_b;
    dim_t B_ptr_shift_b;
    dim_t C_ptr_shift_b;
    dim_t copy_A_src_stride;
    dim_t copy_B_wei_stride;

    dim_t buffer_a_chunk_sz;
    dim_t buffer_a_chunk_shift_along_m;
    dim_t buffer_a_per_thread_sz;

    dim_t buffer_b_chunk_sz;
    dim_t buffer_b_per_thread_sz;
    dim_t s8s8_comp_ithr_str;
    dim_t s8s8_comp_b_str;
    dim_t s8s8_comp_n_str;
    bool has_zero_point_a, has_zero_point_b, has_zero_point_c;
    bool post_ops_applicable;
    bool transposed_A;
    bool blocked_B;

    dim_t zp_a_comp_shift_n;
    dim_t zp_a_comp_elems_per_thr;

    dim_t zp_b_comp_result_shift_m;
    dim_t zp_b_comp_buffer_start;
    dim_t zp_b_comp_buffer_shift_m;
    dim_t zp_b_comp_elems_per_thr;

    int wsp_tile_per_thr_bytes;
    int brgemm_batch_element_per_thr_sz;
    bool is_amx;

    int required_k_granularity;
    bool is_bf32 = false;
    bool req_wei_vnni_downconvert = false;
};

struct brgemm_matmul_conf_utils_t {

    brgemm_matmul_conf_utils_t(brgemm_matmul_conf_t &bgmmc, const cpu_isa_t isa,
            const primitive_attr_t &attr, bool A_any_layout, bool B_any_layout,
            bool C_any_layout, bool bias_any_layout);

    inline bool check_b_layout_blocked_by_n(format_tag_t matrix_b_tag) const {
        return blocked_B_layouts_allowed
                && utils::one_of(matrix_b_tag, blocked_64n_B_layout_tag,
                        blocked_48n_B_layout_tag, blocked_32n_B_layout_tag,
                        blocked_16n_B_layout_tag);
    }

    inline bool get_blocked_B() const {
        return blocked_B_layouts_allowed
                && check_b_layout_blocked_by_n(bgmmc.wei_tag);
    }

    inline bool use_buffer_b(bool use_heuristic = true) const {
        if (bgmmc.is_amx)
            // Use the B buffer for AMX when:
            // - the weights are not already in a blocked layout, or
            // - bf32: blocked weights still need the VNNI down-conversion
            return IMPLICATION(!wei_down_convert_to_vnni(), !bgmmc.blocked_B);

        // Values based on the measured performance difference
        // between the plain and copy-to-blocked routines.
        const bool big_LDB = bgmmc.N > 256;
        const bool is_pow2 = math::is_pow2(bgmmc.N);
        const bool use_copy_buffer = IMPLICATION(
                this->is_f32(), use_heuristic && (big_LDB && is_pow2));
        return (this->is_f16() && bgmmc.isa == avx512_core_fp16)
                || (use_copy_buffer && this->check_is_plain(bgmmc.wei_tag))
                || this->check_is_transposed(bgmmc.wei_tag)
                || (bgmmc.wei_tag == format_tag::acbd)
                || (bgmmc.wei_tag == format_tag::adbc);
    }
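
    // For instance (hypothetical sizes): an f32 problem with plain-layout B
    // and N = 512 (> 256 and a power of two) takes the copy-to-blocked path
    // above, while the same problem with N = 500 keeps B in plain layout.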

    inline dim_t get_actual_LDB() const {
        if (bgmmc.wei_tag == format_tag::acbd && !bgmmc.use_buffer_b) {
            assert(bgmmc.b_dt_sz == bgmmc.tr_b_dt_sz);
            return bgmmc.B_strides[1] / bgmmc.b_dt_sz;
        }
        const bool use_blocked_LDB = bgmmc.is_amx || bgmmc.use_buffer_b
                || bgmmc.wei_tag != plain_tensor_layout_tag;
        return use_blocked_LDB ? bgmmc.wei_n_blk : bgmmc.N;
    }
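
    // E.g. (hypothetical sizes): a plain-layout f32 case with no B copy
    // buffer on a non-AMX ISA gets LDB = N (the row stride of B), while
    // AMX, buffered, or blocked cases get LDB = wei_n_blk.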

    inline bool maybe_low_brg_blocking() const {
        // Check whether M_blk is a prime number between 32 and 64
        const bool is_prime_num
                = utils::one_of(bgmmc.M_blk, 37, 41, 43, 47, 53, 59, 61);
        const bool maybe_ldb_tail = bgmmc.N % 16 != 0;
        return is_prime_num && IMPLICATION(bgmmc.M_blk < 48, maybe_ldb_tail);
    }
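
    // E.g. M_blk = 37 (< 48) qualifies only when N has a 16-element tail
    // (say N = 100, since 100 % 16 != 0), while M_blk = 53 qualifies for
    // any N.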

    inline bool check_n_blk_fixed() const { return n_blk_fixed; }

    inline bool check_is_transposed(format_tag_t tag) const {
        return tag == transposed_tensor_layout_tag;
    }

    inline bool check_is_plain(format_tag_t tag) const {
        return tag == plain_tensor_layout_tag;
    }

    inline bool is_f32() const { return f32_dt; }

    inline bool is_bf16() const { return bf16_dt; }

    inline bool is_f16() const { return f16_dt; }

    inline bool is_int8() const { return int8_dt; }

    inline bool is_bf32() const { return bf32_dt; }

    inline bool is_int8_with_bf16_dst() const {
        return this->is_int8() && bgmmc.dst_dt == data_type::bf16;
    }

    inline bool wei_down_convert_to_vnni() const {
        return bf32_dt && get_blocked_B();
    }

    inline bool is_any_B_layout() const { return B_any_layout; }

    status_t set_or_check_B_tag(
            memory_desc_t &B_md, bool init_n_tag = true) const;
    status_t update_and_check_B_tag(memory_desc_t &B_md, int n_blk_size) const;
    status_t set_or_check_tags(memory_desc_t &A_md, memory_desc_t &C_md,
            memory_desc_t &bias_md) const;
    status_t set_B_flags(memory_desc_t &B_md) const;
    format_tag_t pick_blocked_B_layout(int n_blk) const;

private:
    brgemm_matmul_conf_t &bgmmc;

    const bool f32_dt, bf16_dt, f16_dt, int8_dt, bf32_dt;
    const bool A_any_layout;
    const bool B_any_layout;
    const bool C_any_layout;
    const bool bias_any_layout;

    const format_tag_t plain_tensor_layout_tag;
    const format_tag_t transposed_tensor_layout_tag;
    const format_tag_t blocked_64n_B_layout_tag, blocked_48n_B_layout_tag,
            blocked_32n_B_layout_tag, blocked_16n_B_layout_tag;
    const bool blocked_B_layouts_allowed;
    const bool n_blk_fixed;
};

void init_aux_values(brgemm_matmul_conf_t &bgmmc,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d,
        const memory_desc_wrapper &dst_d);

status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
        const matmul_desc_t &mmd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr);

void init_scratchpad(memory_tracking::registrar_t &scratchpad,
        const brgemm_matmul_conf_t &bgmmc);

int get_default_n_block(format_tag_t matrix_b_tag);
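
// Typical initialization flow (a sketch of how the brgemm matmul primitive
// is expected to use these helpers, not a contract): init_brgemm_matmul_conf()
// selects the blocking and memory layouts, init_aux_values() derives strides
// and per-thread buffer sizes from the final memory descriptors, and
// init_scratchpad() registers the temporary buffers implied by that config.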

} // namespace matmul
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif