1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/brgemm/brgemm_utils.hpp" |
18 | #include "cpu/x64/brgemm/jit_brdgmm_kernel.hpp" |
19 | |
20 | #include "cpu/x64/cpu_isa_traits.hpp" |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/dnnl_thread.hpp" |
24 | #include "common/nstl.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
33 | using namespace dnnl::impl::utils; |
34 | |
// Identifiers for the AMX bd/ld tile-decomposition strategies that
// brgemm_blocking() tries in order (2x2 first, then 3x1 variants).
// Values start at 101 so they cannot be confused with other small enums.
enum {
    decomposition_2x2 = 101,
    decomposition_3x1_3,
    decomposition_3x1_2,
    undefined,
};
41 | |
42 | impl::data_type_t get_accum_datatype(brgemm_t *brg) { |
43 | // this assert should check if 'init_kernel_datatype()' was previously |
44 | // called. |
45 | assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16); |
46 | return brg->is_int8 ? data_type::s32 : data_type::f32; |
47 | } |
48 | |
49 | void init_kernel_datatype( |
50 | brgemm_t *brg, impl::data_type_t dt_a, impl::data_type_t dt_b) { |
51 | assert(dt_a != data_type::undef && dt_b != data_type::undef); |
52 | brg->is_int8 = utils::one_of(dt_a, data_type::u8, data_type::s8) |
53 | && utils::one_of(dt_b, data_type::u8, data_type::s8); |
54 | brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16); |
55 | brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32); |
56 | brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b); |
57 | assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16); |
58 | } |
59 | |
60 | void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha, |
61 | float beta, const brgemm_strides_t *strides) { |
62 | brg->beta = beta; |
63 | brg->alpha = alpha; |
64 | brg->type = type; |
65 | brg->with_bias = false; |
66 | brg->with_eltwise = false; |
67 | brg->with_sum = false; |
68 | brg->sum_scale = 0; |
69 | brg->sum_zp = 0; |
70 | brg->with_scales = false; |
71 | |
72 | if (strides != nullptr) { |
73 | brg->stride_a = strides->stride_a; |
74 | brg->stride_b = strides->stride_b; |
75 | } else { |
76 | brg->stride_a = brg->stride_b = 0; |
77 | } |
78 | } |
79 | |
80 | namespace brgemm_utils { |
81 | |
82 | bool can_dispatch_uker(const brgemm_t *brg) { |
83 | return brg->is_tmm && brg->type == brgemm_addr && brg->brgattr.use_uker |
84 | && !brg->brgattr.generate_skip_accumulation; |
85 | } |
86 | |
87 | void maybe_try_bf32(brgemm_t *brg) { |
88 | const bool try_bf32 = brg->is_f32 |
89 | && brg->brgattr.fpmath_mode == fpmath_mode::bf16 |
90 | && utils::one_of(brg->isa_user, isa_undef, avx512_core_amx) |
91 | && mayiuse(avx512_core_amx); |
92 | if (try_bf32) { |
93 | const bool is_tmm = brg->is_tmm; |
94 | brg->is_tmm = true; |
95 | if (can_dispatch_uker(brg) /*Requires is_amx to be true*/) { |
96 | brg->is_bf32 = true; |
97 | } else { |
98 | brg->is_bf32 = false; |
99 | // Restore |
100 | brg->is_tmm = is_tmm; |
101 | } |
102 | } |
103 | } |
104 | |
// Selects brg->isa_impl, the highest-priority ISA that is both available on
// the machine and compatible with the user-requested ISA, based on the
// kernel datatype flags. utils::map() returns the value paired with the
// first true predicate, so entries are listed from most to least preferred;
// isa_undef is the fallback when nothing matches.
void set_isa_impl(brgemm_t *brg) {
    auto is_isa_ok = [&](cpu_isa_t isa) {
        return mayiuse(isa) &&
                // maybe IMPLICATION(brg->isa_user != isa_undef,
                // is_superset(brg->isa_user, isa)), but the API is not clear.
                one_of(brg->isa_user, isa_undef, isa);
    };

    if (brg->is_bf32) {
        // bf32 is implemented exclusively via AMX bf16 tiles.
        brg->isa_impl = avx512_core_amx;
    } else if (brg->is_f32) {
        brg->isa_impl = utils::map(true, isa_undef,
                is_isa_ok(avx512_core) || is_isa_ok(avx512_core_amx) /*bf32*/,
                avx512_core, is_isa_ok(avx2), avx2,
                // Allow avx512_core_fp16 isa in case of a f16 primitive that
                // is implemented using pre-conversion of inputs to f32.
                // This is needed to support f16 binary post-ops.
                is_isa_ok(avx512_core_fp16), avx512_core_fp16, is_isa_ok(avx2),
                avx2);
    } else if (brg->is_bf16) {
        brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx),
                avx512_core_amx, is_isa_ok(avx512_core_bf16), avx512_core_bf16,
                is_isa_ok(avx2_vnni_2), avx2_vnni_2);
    } else if (brg->is_f16) {
        brg->isa_impl
                = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx_fp16),
                        avx512_core_amx_fp16, is_isa_ok(avx512_core_fp16),
                        avx512_core_fp16, is_isa_ok(avx2_vnni_2), avx2_vnni_2);
    } else if (brg->is_int8) {
        brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_amx),
                avx512_core_amx, is_isa_ok(avx512_core_vnni), avx512_core_vnni,
                is_isa_ok(avx2_vnni), avx2_vnni);
    }
}
139 | |
140 | void set_brg_vmm(brgemm_t *brg) { |
141 | brg->is_tmm = brg->is_int8_tmm || brg->is_bf16_tmm || brg->is_f16_tmm |
142 | || brg->is_bf32; |
143 | brg->is_zmm = !brg->is_tmm && mayiuse(avx512_core) |
144 | && is_superset(brg->isa_impl, avx512_core); |
145 | brg->is_ymm |
146 | = !brg->is_zmm && mayiuse(avx2) && is_superset(brg->isa_impl, avx2); |
147 | } |
148 | |
149 | status_t brgemm_blocking(brgemm_t *brg) { |
150 | |
151 | set_isa_impl(brg); |
152 | if (brg->isa_impl == isa_undef) return status::unimplemented; |
153 | set_brg_vmm(brg); |
154 | if (!(brg->is_tmm || brg->is_zmm || brg->is_ymm)) |
155 | return status::unimplemented; |
156 | |
157 | if (!brg->is_tmm) { |
158 | const int simd_w = is_superset(brg->isa_impl, avx512_core) ? 16 : 8; |
159 | brg->ld_block = simd_w; |
160 | brg->ldb = brg->load_dim / brg->ld_block; |
161 | brg->ldb_tail = brg->load_dim % brg->ld_block; |
162 | |
163 | brg->ld_block2 = 4; // (M < 9) ? 2 : 4 | TODO - fix this for INT8 |
164 | brg->ldb2 = brg->ldb / brg->ld_block2; |
165 | brg->ldb2_tail = brg->ldb % brg->ld_block2; |
166 | |
167 | if (brg->ldb2 == 0) brg->ld_block2 = nstl::max(1, brg->ldb2_tail); |
168 | brg->embd_bcst = brg->is_f32 |
169 | && (brg->ldb2_tail <= 1 && brg->ldb2 == 0) |
170 | /*only avx512 or more can bcast*/ |
171 | && is_superset(brg->isa_impl, avx512_core); |
172 | |
173 | const auto ld_block |
174 | = (brg->ldb2 != 0) ? brg->ld_block2 : brg->ldb2_tail; |
175 | const auto adj_ld_block = (ld_block == 0) ? (ld_block + 1) : ld_block; |
176 | |
177 | const int max_isa_regs |
178 | = is_superset(brg->isa_impl, avx512_core) ? 32 : 16; |
179 | const int max_bcst_regs = 1; |
180 | const bool req_compensation = brg->req_s8s8_compensation |
181 | || brg->zp_type_a != brgemm_broadcast_t::none; |
182 | const bool req_zp_a_comp_pads |
183 | = (brg->req_cal_comp_pads || brg->brgattr.max_top_vpad > 0 |
184 | || brg->brgattr.max_bottom_vpad > 0) |
185 | && brg->zp_type_a != brgemm_broadcast_t::none; |
186 | const auto max_regs = max_isa_regs - (adj_ld_block + max_bcst_regs); |
187 | auto max_block |
188 | = (brg->embd_bcst ? max_regs - 4 |
189 | : ((brg->beta == 1.f || brg->beta == 0.f) |
190 | ? max_regs |
191 | : max_regs - 1)); |
192 | max_block -= req_compensation; |
193 | max_block -= req_zp_a_comp_pads; |
194 | if (!is_superset(brg->isa_impl, avx512_core) && brg->ldb_tail > 0 |
195 | && brg->beta != 0.f) |
196 | --max_block; |
197 | if (req_zp_a_comp_pads) max_block = nstl::min(max_block, max_regs - 5); |
198 | if (brg->is_bf16_emu) |
199 | max_block |
200 | = nstl::min(max_block, 28); // bf16_emu only for avx512_core |
201 | max_block /= adj_ld_block; |
202 | const int min_block = 1; |
203 | float best_bd_block_eff = 0.f; |
204 | brg->bd_block = 1; |
205 | for (int bd_block = max_block; bd_block >= min_block; bd_block--) { |
206 | const auto bd_block_disb = static_cast<float>(brg->bcast_dim) |
207 | / rnd_up(brg->bcast_dim, bd_block); |
208 | const auto brgemm_microkernel_eff |
209 | = (static_cast<float>(adj_ld_block) * bd_block) |
210 | / (((adj_ld_block) + bd_block) * max_block); |
211 | const auto bd_block_eff = bd_block_disb * brgemm_microkernel_eff; |
212 | |
213 | float = static_cast<float>(brg->typesize_A) |
214 | * (bd_block * brg->reduce_dim); |
215 | if (block_foot_print <= static_cast<float>( |
216 | platform::get_per_core_cache_size(1)) |
217 | && (bd_block_eff > best_bd_block_eff)) { |
218 | brg->bd_block = bd_block; |
219 | best_bd_block_eff = bd_block_eff; |
220 | } |
221 | } |
222 | brg->bdb = brg->bcast_dim / brg->bd_block; |
223 | brg->bdb_tail = brg->bcast_dim % brg->bd_block; |
224 | |
225 | const int rd_unroll = 4; |
226 | const int vnni_granularity |
227 | = (brg->is_f16 && brg->isa_impl == avx512_core_fp16) |
228 | ? 1 |
229 | : data_type_vnni_granularity(brg->dt_a); |
230 | brg->rd_block = rd_unroll * vnni_granularity; |
231 | brg->rdb = brg->reduce_dim / brg->rd_block; |
232 | brg->rdb_tail = brg->reduce_dim % brg->rd_block; |
233 | |
234 | brg->is_M_tail = false; |
235 | } else { |
236 | // Blocking configuration for AMX |
237 | const int max_width = 16, min_width = 1; |
238 | brg->ld_block = 16; |
239 | brg->ldb = brg->load_dim / brg->ld_block; |
240 | brg->ldb_tail = brg->load_dim % brg->ld_block; |
241 | |
242 | auto find_bd_block_for_bd_mask = [&]() { |
243 | const auto bd_mask_size = brg->bcast_dim; |
244 | if (brg->brgattr.bd_mask_level != 2 || bd_mask_size == 0) |
245 | return false; |
246 | |
247 | const auto sm_buffer = brg->brgattr.bd_mask; |
248 | auto min_bdb = INT_MAX; |
249 | const auto start_bd_block = nstl::min(max_width, brg->bcast_dim); |
250 | auto best_bd_block = start_bd_block; |
251 | for (auto bd_block = start_bd_block; bd_block > 0; bd_block--) { |
252 | auto bdb = 0; |
253 | for (int i = 0; i < bd_mask_size;) { |
254 | if (brg->brgattr.bd_mask_level == 2 && sm_buffer[i] == 0) { |
255 | i++; |
256 | } else { |
257 | i += bd_block; |
258 | if (i > brg->bcast_dim) { |
259 | // bcast_dim not divided by bd_block |
260 | bdb = INT_MAX; |
261 | } else |
262 | bdb++; |
263 | } |
264 | } |
265 | if (bdb < min_bdb) { |
266 | min_bdb = bdb; |
267 | best_bd_block = bd_block; |
268 | } |
269 | } |
270 | brg->bd_block = best_bd_block; |
271 | brg->bdb_tail = 0; |
272 | brg->bdb = min_bdb; |
273 | return true; |
274 | }; |
275 | |
276 | auto set_decomposition_by_ld = [&]() { |
277 | if (brg->bd_block2 == 1 && brg->ldb > 0 && brg->ldb_tail == 0) { |
278 | if (brg->ldb % 3 == 0) |
279 | brg->ld_block2 = 3; |
280 | else if (brg->ldb % 2 == 0) |
281 | brg->ld_block2 = 2; |
282 | else |
283 | brg->ld_block2 = 1; |
284 | } else { |
285 | brg->ld_block2 |
286 | = (brg->ldb > 0 && brg->ldb % 2 == 0 |
287 | && brg->ldb_tail == 0 && brg->bd_block2 < 3) |
288 | ? 2 |
289 | : 1; |
290 | } |
291 | brg->ldb2 = brg->ldb / brg->ld_block2; |
292 | brg->ldb2_tail = brg->ldb % brg->ld_block2; |
293 | |
294 | // Re-adjust the bd_block2 if possible |
295 | if (brg->ld_block2 == 1 && !brg->is_M_tail && brg->ldb_tail == 0) { |
296 | brg->bd_block2 = (brg->bdb >= 3) ? 3 : (brg->bdb >= 2) ? 2 : 1; |
297 | brg->bdb2 = brg->bdb / brg->bd_block2; |
298 | brg->bdb2_tail = (brg->bd_block2 == 1) |
299 | ? brg->bdb |
300 | : brg->bdb % brg->bd_block2; |
301 | } |
302 | }; |
303 | |
304 | auto try_3x1_decomposition = [&](int width_step) { |
305 | brg->is_M_tail = false; |
306 | if (brg->bcast_dim > (width_step - 1) * max_width |
307 | && brg->bcast_dim < width_step * max_width |
308 | && brg->ldb_tail == 0) { |
309 | if (!find_bd_block_for_bd_mask()) { |
310 | brg->bd_block = max_width; |
311 | brg->bdb = div_up(brg->bcast_dim, brg->bd_block); |
312 | brg->bdb_tail = brg->bcast_dim % brg->bd_block; |
313 | brg->is_M_tail = true; |
314 | } |
315 | brg->bd_block2 = width_step; |
316 | brg->bdb2 = brg->bdb / brg->bd_block2; |
317 | brg->bdb2_tail = brg->bdb % brg->bd_block2; |
318 | set_decomposition_by_ld(); |
319 | return true; |
320 | } |
321 | return false; |
322 | }; |
323 | |
324 | auto try_2x2_decomposition = [&]() { |
325 | if (!find_bd_block_for_bd_mask()) { |
326 | for (int m_block = max_width; m_block >= min_width; m_block--) { |
327 | if (brg->bcast_dim % m_block == 0) { |
328 | brg->bd_block = m_block; |
329 | break; |
330 | } |
331 | } |
332 | if (brg->bd_block == 1) { |
333 | brg->bd_block = nstl::min(max_width, brg->bcast_dim); |
334 | brg->bdb_tail = brg->bcast_dim % max_width; |
335 | for (int i = max_width; i >= min_width; i--) { |
336 | const auto i_tail = brg->bcast_dim % i; |
337 | if (i_tail > brg->bdb_tail || i_tail == 0) { |
338 | brg->bd_block = i; |
339 | brg->bdb_tail = i_tail; |
340 | if (i_tail == 0) break; |
341 | } |
342 | } |
343 | } |
344 | brg->bdb = brg->bcast_dim / brg->bd_block; |
345 | brg->bdb_tail = brg->bcast_dim % brg->bd_block; |
346 | } |
347 | |
348 | brg->bd_block2 = (brg->bdb >= 2) ? 2 : 1; |
349 | brg->bdb2 = brg->bdb / brg->bd_block2; |
350 | brg->bdb2_tail = (brg->bd_block2 == 1) ? brg->bdb |
351 | : brg->bdb % brg->bd_block2; |
352 | |
353 | brg->is_M_tail = false; |
354 | |
355 | set_decomposition_by_ld(); |
356 | |
357 | return !(brg->ld_block2 == 1 || brg->bd_block2 == 1 |
358 | || brg->bd_block < 8); |
359 | }; |
360 | |
361 | bool is_decomposition_defined = false; |
362 | for (int i = decomposition_2x2; i != undefined; i++) { |
363 | switch (i) { |
364 | case decomposition_2x2: |
365 | is_decomposition_defined = try_2x2_decomposition(); |
366 | break; |
367 | case decomposition_3x1_3: |
368 | is_decomposition_defined = try_3x1_decomposition(3); |
369 | break; |
370 | case decomposition_3x1_2: |
371 | is_decomposition_defined = try_3x1_decomposition(2); |
372 | break; |
373 | default: assert(!"invalid value" ); break; |
374 | }; |
375 | if (is_decomposition_defined) break; |
376 | } |
377 | if (!is_decomposition_defined) try_2x2_decomposition(); |
378 | |
379 | const bool try_load_nt_A = (brg->brgattr.hint_innermost_loop |
380 | == brgemm_bd_loop_innermost); |
381 | const bool try_load_nt_B = (brg->brgattr.hint_innermost_loop |
382 | == brgemm_ld_loop_innermost); |
383 | const bool try_load_nt |
384 | = (static_cast<size_t>(brg->typesize_A) |
385 | * brg->brgattr.hint_expected_A_size |
386 | + static_cast<size_t>(brg->typesize_B) |
387 | * brg->brgattr.hint_expected_B_size |
388 | + static_cast<size_t>(brg->typesize_C) |
389 | * brg->brgattr.hint_expected_C_size) |
390 | >= platform::get_per_core_cache_size(1); |
391 | brg->load_nt_A = try_load_nt_A && try_load_nt; |
392 | brg->load_nt_B = try_load_nt_B && try_load_nt; |
393 | |
394 | auto recalc_bd_block = [&](int new_bd_block) { |
395 | if (new_bd_block == 0) return; |
396 | brg->bd_block = new_bd_block; |
397 | brg->bdb = div_up(brg->bcast_dim, brg->bd_block); |
398 | brg->bdb_tail = brg->bcast_dim % brg->bd_block; |
399 | brg->is_M_tail = (brg->bdb_tail != 0); |
400 | }; |
401 | |
402 | auto recalc_bd_block2 = [&](int new_bd_block2) { |
403 | if (new_bd_block2 == 0) return; |
404 | brg->bd_block2 = new_bd_block2; |
405 | if (brg->bdb_tail && brg->bd_block2 > 1) brg->bd_block2--; |
406 | auto full_bd_blocks = brg->bdb - (brg->bdb_tail != 0 ? 1 : 0); |
407 | brg->bdb2 = full_bd_blocks / brg->bd_block2; |
408 | brg->bdb2_tail = full_bd_blocks % brg->bd_block2; |
409 | }; |
410 | |
411 | auto recalc_ld_block = [&](int new_ld_block) { |
412 | if (new_ld_block == 0) return; |
413 | brg->ld_block = new_ld_block; |
414 | brg->ldb = brg->load_dim / brg->ld_block; |
415 | brg->ldb_tail = brg->load_dim % brg->ld_block; |
416 | }; |
417 | |
418 | auto recalc_ld_block2 = [&](int new_ld_block2) { |
419 | if (new_ld_block2 == 0) return; |
420 | brg->ld_block2 = new_ld_block2; |
421 | if (brg->ldb_tail && brg->ld_block2 > 1) brg->ld_block2--; |
422 | brg->ldb2 = brg->ldb / brg->ld_block2; |
423 | brg->ldb2_tail = brg->ldb % brg->ld_block2; |
424 | }; |
425 | |
426 | recalc_bd_block2(brg->bd_block2); |
427 | recalc_ld_block2(brg->ld_block2); |
428 | |
429 | // check hints for blocking parameters |
430 | recalc_bd_block(brg->brgattr.hint_bd_block); |
431 | recalc_bd_block2(brg->brgattr.hint_bd_block2); |
432 | recalc_ld_block(brg->brgattr.hint_ld_block); |
433 | recalc_ld_block2(brg->brgattr.hint_ld_block2); |
434 | |
435 | if (brg->brgattr.hint_load_nt_A != brgemm_hint_nt_undef) |
436 | brg->load_nt_A |
437 | = (brg->brgattr.hint_load_nt_A == brgemm_hint_nt_true); |
438 | if (brg->brgattr.hint_load_nt_B != brgemm_hint_nt_undef) |
439 | brg->load_nt_B |
440 | = (brg->brgattr.hint_load_nt_B == brgemm_hint_nt_true); |
441 | |
442 | const auto max_rd_block |
443 | = (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 32 |
444 | : 64; |
445 | const auto rd_block_step |
446 | = (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 2 : 4; |
447 | // TODO: if rd_block calculated is very small then maybe it makes |
448 | // sense to use 1x2 or 2x1 blocking with supporting rd_block |
449 | // and rdb_tail |
450 | brg->rd_block = rd_block_step; |
451 | for (int i = max_rd_block; i > 0; i -= rd_block_step) { |
452 | if (brg->reduce_dim % i == 0) { |
453 | brg->rd_block = i; |
454 | break; |
455 | } |
456 | } |
457 | brg->rdb = brg->reduce_dim / brg->rd_block; |
458 | brg->rdb_tail = brg->reduce_dim % brg->rd_block; |
459 | |
460 | // Remove these guards in the future (add tail processing by reduction |
461 | // dimension) |
462 | if (!IMPLICATION(brg->rdb > 0 && brg->rdb_tail, brg->is_bf32)) |
463 | return status::unimplemented; |
464 | if (!IMPLICATION( |
465 | (brg->rdb_tail |
466 | % ((brg->is_bf16_tmm || brg->is_f16_tmm) ? 2 : 4)) |
467 | != 0, |
468 | brg->is_bf32)) |
469 | return status::unimplemented; |
470 | |
471 | //TODO: check this condition |
472 | brg->interleave_tilestores_ = brg->beta == 0 |
473 | && (brg->brgattr.use_interleave_stores |
474 | && (brg->bd_block2 * brg->ld_block2 == 4) |
475 | && !brg->brgattr.var_bs) |
476 | ? true |
477 | : false; |
478 | } |
479 | |
480 | return status::success; |
481 | } |
482 | |
// Computes blocking for the batch-reduce diagonal GEMM (brdgmm) kernel.
// There is no reduction dimension here: M maps to bd_* fields and N to
// ld_* fields; blocking is chosen to fill the accumulator vmm budget.
status_t brdgmm_blocking(brgemm_t *brg) {

    if (brg->isa_impl == isa_undef) return status::unimplemented;

    // The "fast vnni int8" variant permutes the destination, which costs
    // one extra vmm for the permutation table.
    const int requires_permute_dst_vmm = brg->isa_impl == avx512_core_vnni
            && jit_brdgmm_kernel_base_t<avx512_core_vnni,
                    Xbyak::Zmm>::is_fast_vnni_int8(*brg);
    // Registers left for accumulation after reserving auxiliary vmms
    // (bf16 emulation needs 4, otherwise 2) plus the optional permute vmm.
    const int max_vregs = isa_num_vregs(brg->isa_impl);
    const int aux_vregs
            = nstl::max(brg->is_bf16_emu * 4, 2) + requires_permute_dst_vmm;
    const int max_acc_vmms = max_vregs - aux_vregs;
    const int simd_w = isa_max_vlen(brg->isa_impl) / brg->typesize_C;
    const bool is_avx2_vnni_2_xf16
            = (brg->is_bf16 || brg->is_f16) && brg->isa_impl == avx2_vnni_2;

    auto &M = brg->bcast_dim;
    auto &N = brg->load_dim;

    // In current implementation of dgmm, there is no reduce dim.
    // Aliases below name the brg blocking fields in dgmm terms.
    auto &m_block1 = brg->bd_block;
    auto &nb_m_block1 = brg->bdb;
    auto &m_block1_tail = brg->bdb_tail;
    auto &m_block2 = brg->bd_block2;
    auto &nb_m_block2 = brg->bdb2;
    auto &m_block2_tail = brg->bdb2_tail;

    auto &n_block1 = brg->ld_block;
    auto &nb_n_block1 = brg->ldb;
    auto &n_block1_tail = brg->ldb_tail;
    auto &n_block2 = brg->ld_block2;
    auto &nb_n_block2 = brg->ldb2;
    auto &n_block2_tail = brg->ldb2_tail;

    // begin blocking
    // for avx2_vnni_2_xf16, instead of processing a n_block1 at once, it is
    // processed as even/odd pair.
    const int n_block1_num_steps = is_avx2_vnni_2_xf16 ? 2 : 1;
    n_block1 = n_block1_num_steps * simd_w;
    nb_n_block1 = div_up(N, n_block1);
    n_block1_tail = N % n_block1;

    // Second-level N blocking: at most 4 accumulator vmms worth of columns.
    const int max_n_block2_vmms = 4;
    const int max_n_block2 = max_n_block2_vmms / n_block1_num_steps;
    n_block2 = nstl::min(max_n_block2, nb_n_block1);
    nb_n_block2 = div_up(nb_n_block1, n_block2);
    n_block2_tail = nb_n_block1 % n_block2;

    // M blocking: one row per block1; block2 consumes the remaining
    // accumulator budget.
    m_block1 = 1;
    nb_m_block1 = M / m_block1;
    m_block1_tail = M % m_block1;
    m_block2 = nstl::min(
            nb_m_block1, max_acc_vmms / (n_block2 * n_block1_num_steps));
    nb_m_block2 = div_up(nb_m_block1, m_block2);
    m_block2_tail = nb_m_block1 % m_block2;

    return status::success;
}
540 | |
541 | void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, |
542 | impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, |
543 | float alpha, float beta, dim_t LDA, dim_t LDB, dim_t LDC, dim_t M, |
544 | dim_t N, dim_t K, const brgemm_strides_t *strides, bool is_bf32) { |
545 | |
546 | init_common_conf(brg, type, alpha, beta, strides); |
547 | |
548 | brg->layout = layout; |
549 | |
550 | brg->dt_a = brg->is_row_major() ? dt_a : dt_b; |
551 | brg->dt_b = brg->is_row_major() ? dt_b : dt_a; |
552 | init_kernel_datatype(brg, brg->dt_a, brg->dt_b); |
553 | |
554 | brg->dt_c = get_accum_datatype(brg); |
555 | brg->dt_d = brg->dt_c; |
556 | brg->dt_bias = brg->dt_c; |
557 | |
558 | brg->typesize_A = types::data_type_size(brg->dt_a); |
559 | brg->typesize_B = types::data_type_size(brg->dt_b); |
560 | brg->typesize_C = types::data_type_size(brg->dt_c); |
561 | brg->typesize_D = types::data_type_size(brg->dt_d); |
562 | |
563 | brg->isa_user = isa; |
564 | set_isa_impl(brg); |
565 | brg->is_int8_tmm = brg->is_int8 && brg->isa_impl == avx512_core_amx; |
566 | brg->is_bf16_tmm = brg->is_bf16 && brg->isa_impl == avx512_core_amx; |
567 | brg->is_f16_tmm = brg->is_f16 && brg->isa_impl == avx512_core_amx_fp16; |
568 | brg->is_bf32 = is_bf32 |
569 | && utils::one_of(brg->isa_user, isa_undef, avx512_core_amx) |
570 | && mayiuse(avx512_core_amx); |
571 | set_brg_vmm(brg); // TODO: Investigate if it is really needed here. |
572 | brg->req_s8s8_compensation |
573 | = brg->is_int8 && !brg->is_int8_tmm && brg->dt_a == data_type::s8; |
574 | |
575 | brg->LDA = (brg->is_row_major()) ? static_cast<int>(LDA) |
576 | : static_cast<int>(LDB); |
577 | brg->LDB = (brg->is_row_major()) ? static_cast<int>(LDB) |
578 | : static_cast<int>(LDA); |
579 | brg->LDC = static_cast<int>(LDC); |
580 | brg->LDD = static_cast<int>(LDC); |
581 | |
582 | brg->bcast_dim |
583 | = (brg->is_row_major()) ? static_cast<int>(M) : static_cast<int>(N); |
584 | brg->load_dim |
585 | = (brg->is_row_major()) ? static_cast<int>(N) : static_cast<int>(M); |
586 | brg->reduce_dim = static_cast<int>(K); |
587 | |
588 | brg->bd_block2 = 0; |
589 | brg->bdb2 = 0; |
590 | brg->bdb2_tail = 0; |
591 | |
592 | const bool is_b_in_vnni_format = !( |
593 | brg->dt_b == data_type::f16 && brg->isa_impl == avx512_core_fp16); |
594 | brg->ld_step |
595 | = is_b_in_vnni_format ? data_type_vnni_granularity(brg->dt_b) : 1; |
596 | |
597 | const bool has_no_vnni_compute_instruction |
598 | = (brg->is_f16 |
599 | && one_of(brg->isa_impl, avx2_vnni_2, avx512_core_fp16)) |
600 | || (brg->is_bf16 && brg->isa_impl == avx2_vnni_2); |
601 | brg->rd_step = has_no_vnni_compute_instruction |
602 | ? 1 |
603 | : data_type_vnni_granularity(brg->dt_b); |
604 | } |
605 | |
606 | void init_brdgmm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type, |
607 | impl::data_type_t dt_a, impl::data_type_t dt_b, brgemm_layout_t layout, |
608 | float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N, |
609 | const brgemm_strides_t *strides) { |
610 | |
611 | init_common_conf(brg, type, alpha, beta, strides); |
612 | |
613 | brg->layout = layout; |
614 | |
615 | brg->dt_a = dt_a; |
616 | brg->dt_b = dt_b; |
617 | init_kernel_datatype(brg, brg->dt_a, brg->dt_b); |
618 | |
619 | brg->dt_c = get_accum_datatype(brg); |
620 | brg->dt_d = brg->dt_c; |
621 | brg->dt_bias = brg->dt_c; |
622 | |
623 | brg->typesize_A = types::data_type_size(brg->dt_a); |
624 | brg->typesize_B = types::data_type_size(brg->dt_b); |
625 | brg->typesize_C = types::data_type_size(brg->dt_c); |
626 | brg->typesize_D = types::data_type_size(brg->dt_d); |
627 | |
628 | brg->isa_user = isa; |
629 | auto is_isa_ok = [&](cpu_isa_t isa) { |
630 | return mayiuse(isa) && one_of(brg->isa_user, isa_undef, isa); |
631 | }; |
632 | |
633 | if (brg->is_f32) { |
634 | brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core), |
635 | avx512_core, is_isa_ok(avx2), avx2); |
636 | } else if (brg->is_bf16) { |
637 | brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_bf16), |
638 | avx512_core_bf16, is_isa_ok(avx2_vnni_2), avx2_vnni_2); |
639 | } else if (brg->is_f16) { |
640 | brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_fp16), |
641 | avx512_core_fp16, is_isa_ok(avx2_vnni_2), avx2_vnni_2); |
642 | } else if (brg->is_int8) { |
643 | brg->isa_impl = utils::map(true, isa_undef, is_isa_ok(avx512_core_vnni), |
644 | avx512_core_vnni, is_isa_ok(avx2_vnni), avx2_vnni); |
645 | } |
646 | |
647 | brg->is_bf16_tmm = brg->is_bf16 && mayiuse(avx512_core_amx); |
648 | brg->is_dgmm = true; |
649 | |
650 | brg->LDA = static_cast<int>(LDA); |
651 | brg->LDC = static_cast<int>(LDC); |
652 | brg->LDD = static_cast<int>(LDC); |
653 | |
654 | brg->bcast_dim = M; |
655 | brg->load_dim = N; |
656 | } |
657 | |
658 | } // namespace brgemm_utils |
659 | } // namespace x64 |
660 | } // namespace cpu |
661 | } // namespace impl |
662 | } // namespace dnnl |
663 | |
664 | //vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
665 | |