1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/brgemm/brgemm.hpp" |
18 | #include "cpu/x64/brgemm/brgemm_utils.hpp" |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "common/utils.hpp" |
24 | |
25 | #include "cpu/platform.hpp" |
26 | #include "cpu/x64/cpu_barrier.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | namespace x64 { |
33 | |
34 | using namespace dnnl::impl::status; |
35 | using namespace dnnl::impl::utils; |
36 | |
37 | using namespace prop_kind; |
38 | using namespace data_type; |
39 | using namespace brgemm_utils; |
40 | |
41 | void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, |
42 | const brgemm_batch_element_t *batch, void *ptr_C, void *scratch) { |
43 | brgemm_kernel_params_t brgemm_p; |
44 | |
45 | brgemm_p.batch = batch; |
46 | brgemm_p.ptr_A = nullptr; |
47 | brgemm_p.ptr_B = nullptr; |
48 | brgemm_p.ptr_C = ptr_C; |
49 | brgemm_p.ptr_D = ptr_C; |
50 | brgemm_p.ptr_buf = scratch; |
51 | brgemm_p.ptr_bias = nullptr; |
52 | brgemm_p.do_post_ops = 0; |
53 | brgemm_p.do_apply_comp = 0; |
54 | brgemm_p.skip_accm = 0; |
55 | brgemm_p.BS = bs; |
56 | |
57 | assert(brg_kernel); |
58 | |
59 | (*brg_kernel)(&brgemm_p); |
60 | } |
61 | |
62 | void brgemm_kernel_execute(const brgemm_kernel_t *brg_kernel, int bs, |
63 | const void *addr_A, const void *addr_B, |
64 | const brgemm_batch_element_t *batch, void *ptr_C, void *scratch) { |
65 | brgemm_kernel_params_t brgemm_p; |
66 | |
67 | brgemm_p.batch = batch; |
68 | brgemm_p.ptr_A = addr_A; |
69 | brgemm_p.ptr_B = addr_B; |
70 | brgemm_p.ptr_C = ptr_C; |
71 | brgemm_p.ptr_D = ptr_C; |
72 | brgemm_p.ptr_buf = scratch; |
73 | brgemm_p.ptr_bias = nullptr; |
74 | brgemm_p.do_post_ops = 0; |
75 | brgemm_p.do_apply_comp = 0; |
76 | brgemm_p.skip_accm = 0; |
77 | brgemm_p.BS = bs; |
78 | assert(brg_kernel); |
79 | (*brg_kernel)(&brgemm_p); |
80 | } |
81 | |
82 | void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, |
83 | const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, |
84 | const brgemm_post_ops_data_t &post_ops_data, void *scratch) { |
85 | brgemm_kernel_params_t brgemm_p; |
86 | |
87 | brgemm_p.batch = batch; |
88 | brgemm_p.ptr_A = nullptr; |
89 | brgemm_p.ptr_B = nullptr; |
90 | brgemm_p.ptr_C = ptr_C; |
91 | brgemm_p.ptr_D = ptr_D; |
92 | brgemm_p.ptr_buf = scratch; |
93 | brgemm_p.ptr_bias = post_ops_data.bias; |
94 | brgemm_p.ptr_scales = post_ops_data.scales; |
95 | brgemm_p.do_post_ops |
96 | = post_ops_data.do_only_comp || post_ops_data.do_only_zp_a_val ? 0 |
97 | : 1; |
98 | brgemm_p.do_apply_comp = post_ops_data.do_only_zp_a_val ? 0 : 1; |
99 | brgemm_p.skip_accm = post_ops_data.skip_accumulation ? 1 : 0; |
100 | brgemm_p.BS = bs; |
101 | brgemm_p.zp_a_val = post_ops_data.zp_a_val; |
102 | brgemm_p.post_ops_binary_rhs_arg_vec = post_ops_data.binary_post_ops_rhs; |
103 | brgemm_p.oc_logical_off = post_ops_data.oc_logical_off; |
104 | brgemm_p.dst_row_logical_off = post_ops_data.dst_row_logical_off; |
105 | brgemm_p.data_C_ptr_ = post_ops_data.data_C_ptr_; |
106 | brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off; |
107 | brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations; |
108 | brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations; |
109 | brgemm_p.c_zp_values = post_ops_data.c_zp_values; |
110 | assert(brg_kernel); |
111 | (*brg_kernel)(&brgemm_p); |
112 | } |
113 | |
114 | void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, |
115 | const void *addr_A, const void *addr_B, |
116 | const brgemm_batch_element_t *batch, void *ptr_C, void *ptr_D, |
117 | const brgemm_post_ops_data_t &post_ops_data, void *scratch) { |
118 | brgemm_kernel_params_t brgemm_p; |
119 | |
120 | brgemm_p.batch = batch; |
121 | brgemm_p.ptr_A = addr_A; |
122 | brgemm_p.ptr_B = addr_B; |
123 | brgemm_p.ptr_C = ptr_C; |
124 | brgemm_p.ptr_D = ptr_D; |
125 | brgemm_p.ptr_buf = scratch; |
126 | brgemm_p.ptr_bias = post_ops_data.bias; |
127 | brgemm_p.ptr_scales = post_ops_data.scales; |
128 | brgemm_p.do_post_ops |
129 | = post_ops_data.do_only_comp || post_ops_data.do_only_zp_a_val ? 0 |
130 | : 1; |
131 | brgemm_p.do_apply_comp = post_ops_data.do_only_zp_a_val ? 0 : 1; |
132 | brgemm_p.skip_accm = post_ops_data.skip_accumulation ? 1 : 0; |
133 | brgemm_p.BS = bs; |
134 | brgemm_p.zp_a_val = post_ops_data.zp_a_val; |
135 | brgemm_p.post_ops_binary_rhs_arg_vec = post_ops_data.binary_post_ops_rhs; |
136 | brgemm_p.oc_logical_off = post_ops_data.oc_logical_off; |
137 | brgemm_p.data_C_ptr_ = post_ops_data.data_C_ptr_; |
138 | brgemm_p.dst_row_logical_off = post_ops_data.dst_row_logical_off; |
139 | brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off; |
140 | assert(brg_kernel); |
141 | (*brg_kernel)(&brgemm_p); |
142 | } |
143 | |
144 | status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa, |
145 | brgemm_batch_kind_t type, impl::data_type_t dt_a, |
146 | impl::data_type_t dt_b, bool transA, bool transB, |
147 | brgemm_layout_t layout, float alpha, float beta, dim_t LDA, dim_t LDB, |
148 | dim_t LDC, dim_t M, dim_t N, dim_t K, const brgemm_strides_t *strides) { |
149 | /* |
150 | m - number of rows of the matrix op(A) and number of rows of the matrix C |
151 | n - number of columns of the matrix op(B) and number of columns of the matrix C |
152 | k - number of columns of the matrix op(A) and number of rows of the matrix op(B) |
153 | |
154 | Matrices are in row-major layouts: |
155 | A: lda * m, LDA - lda must be at least max(1, k) |
156 | B: ldb * k, LDB - ldb must be at least max(1, n) |
157 | C: ldc * m, LDC - ldc must be at least max(1, n) |
158 | |
159 | Matrices are in column-major layouts: |
160 | A: lda * k, LDA - lda must be at least max(1, m) |
161 | B: ldb * n, LDB - ldb must be at least max(1, k) |
162 | C: ldc * n, LDC - ldc must be at least max(1, m) |
163 | */ |
164 | if (brg == nullptr) return status::invalid_arguments; |
165 | if (transA || transB) return status::unimplemented; |
166 | |
167 | brgemm_utils::init_brgemm_conf(brg, isa, type, dt_a, dt_b, layout, alpha, |
168 | beta, LDA, LDB, LDC, M, N, K, strides); |
169 | |
170 | if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments; |
171 | bool ldx_check = (brg->is_row_major()) ? (LDA < K) |
172 | : (LDA < M || LDB < K || LDC < M); |
173 | if (ldx_check) return status::invalid_arguments; |
174 | |
175 | if (utils::everyone_is( |
176 | false, brg->is_int8, brg->is_bf16, brg->is_f32, brg->is_f16)) |
177 | return status::unimplemented; |
178 | |
179 | CHECK(brgemm_blocking(brg)); |
180 | |
181 | return status::success; |
182 | } |
183 | |
184 | status_t brdgmm_desc_init(brgemm_t *brg, cpu_isa_t isa, |
185 | brgemm_batch_kind_t type, impl::data_type_t dt_a, |
186 | impl::data_type_t dt_b, bool transA, brgemm_layout_t layout, |
187 | float alpha, float beta, dim_t LDA, dim_t LDC, dim_t M, dim_t N, |
188 | const brgemm_strides_t *strides) { |
189 | |
190 | if (brg == nullptr) return status::invalid_arguments; |
191 | if (transA || layout != brgemm_row_major || alpha != 1.0f || beta != 0.f) |
192 | return status::unimplemented; |
193 | |
194 | brgemm_utils::init_brdgmm_conf(brg, isa, type, dt_a, dt_b, layout, alpha, |
195 | beta, LDA, LDC, M, N, strides); |
196 | |
197 | const bool ldx_check = (LDA < N || LDC < N); |
198 | if (ldx_check) return status::invalid_arguments; |
199 | |
200 | if (utils::everyone_is( |
201 | false, brg->is_int8, brg->is_bf16, brg->is_f32, brg->is_f16)) |
202 | return status::unimplemented; |
203 | |
204 | CHECK(brdgmm_blocking(brg)); |
205 | |
206 | return status::success; |
207 | } |
208 | |
// Attaches post-op configuration (bias, scales, zero points, attr post-ops
// chain) to an initialized brgemm descriptor and validates that the
// requested data types / post-ops are supported by the selected isa.
// May re-run the blocking heuristic when the register budget changes
// (bf16 emulation, src zero points).
status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
        const memory_desc_t *dst_md, int LDD, impl::data_type_t dt_bias) {
    if (!brg || !dst_md) return status::invalid_arguments;

    brg->attr = attr;
    brg->dst_md = dst_md;

    // Bias is considered present whenever its data type is defined.
    brg->with_bias = (dt_bias == data_type::undef) ? false : true;
    brg->dt_bias = dt_bias;
    brg->typesize_bias = (dt_bias == data_type::undef)
            ? 0
            : types::data_type_size(brg->dt_bias);

    brg->LDD = LDD;
    const auto dt_d = dst_md->data_type;

    // check that bias and output data type are supported by isa
    if (!IMPLICATION(one_of(data_type::bf16, dt_bias, dt_d),
                is_superset(brg->isa_impl, avx512_core)
                        || is_superset(brg->isa_impl, avx2_vnni_2)))
        return status::unimplemented;
    if (!IMPLICATION(one_of(data_type::f16, dt_bias, dt_d),
                is_superset(brg->isa_impl, avx512_core_fp16)
                        || is_superset(brg->isa_impl, avx2_vnni_2)))
        return status::unimplemented;
    // check that combination of data types is allowed
    if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
            && (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
                    data_type::f32, data_type::bf16))
            && (!one_of(dt_bias, data_type::undef, data_type::u8, data_type::s8,
                    data_type::s32, data_type::f32, data_type::bf16)))
        return status::unimplemented;
    if ((brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16)
            && (!one_of(dt_d, data_type::bf16, data_type::f32))
            && (!one_of(dt_bias, data_type::undef, data_type::bf16,
                    data_type::f32)))
        return status::unimplemented;
    if ((brg->dt_a == data_type::f32 && brg->dt_b == data_type::f32)
            && (!one_of(dt_d, data_type::f32))
            && (!one_of(dt_bias, data_type::undef, data_type::f32)))
        return status::unimplemented;
    if (!IMPLICATION(brg->is_f16,
                one_of(dt_d, data_type::f32, data_type::f16)
                        && one_of(dt_bias, data_type::undef, data_type::f32,
                                data_type::f16)))
        return status::unimplemented;

    brg->dt_d = dt_d;
    brg->typesize_D = types::data_type_size(brg->dt_d);

    // int8 inputs with a bf16 destination require at least avx512_core_vnni.
    if (!IMPLICATION(
                brg->is_int8 && brg->dt_d == bf16, mayiuse(avx512_core_vnni)))
        return status::unimplemented;

    // Without native bf16 support, the bf16 down-conversion is emulated,
    // which consumes extra vector registers.
    if (brg->is_int8 && brg->dt_d == bf16)
        brg->is_bf16_emu = !mayiuse(avx512_core_bf16);

    // Rerun blocking heuristic due to reduced zmm register count
    if (brg->is_bf16_emu && brg->is_dgmm) CHECK(brdgmm_blocking(brg));

    // No attributes means nothing more to configure.
    if (!brg->attr) return status::success;

    using namespace injector;

    const auto &post_ops = brg->attr->post_ops_;
    const memory_desc_wrapper dst_d(dst_md);

    const auto binary_ind = post_ops.find(primitive_kind::binary);
    brg->with_binary = binary_ind != -1;

    // NOTE: Using brg->isa_impl here is a bit dangerous as it can change
    // before kernel creation, so there is no guarantee that the isa checked
    // here matches the isa used at kernel creation time. For now this can
    // only happen for bf32, where isa during this check is avx512_core and
    // isa at kernel creation time is avx512_core_amx_bf16. It just so
    // happens that the behavior of `post_ops_ok` is identical for those two
    // isas, but there is no guarantee that will always be the case.
    if ((brg->with_binary && !dst_md)
            || !injector::post_ops_ok(
                    post_ops_ok_args_t(brg->isa_impl, {sum, eltwise, binary},
                            post_ops, &dst_d, false /*sum_at_pos_0_only*/,
                            false /*sum_requires_scale_one*/,
                            false /*sum_requires_zp_zero*/,
                            {broadcasting_strategy_t::per_oc,
                                    broadcasting_strategy_t::scalar,
                                    broadcasting_strategy_t::per_mb_spatial,
                                    broadcasting_strategy_t::per_mb_w,
                                    broadcasting_strategy_t::per_w,
                                    broadcasting_strategy_t::no_broadcast})))
        return status::unimplemented;

    // Cache sum post-op parameters; sum data type defaults to the
    // destination type when not specified explicitly.
    const auto sum_idx = post_ops.find(primitive_kind::sum);
    const bool with_sum = sum_idx != -1;
    brg->with_sum = with_sum;
    brg->sum_scale = with_sum ? post_ops.entry_[sum_idx].sum.scale : 0;
    brg->sum_zp = with_sum ? post_ops.entry_[sum_idx].sum.zero_point : 0;
    const auto sum_dt
            = with_sum ? post_ops.entry_[sum_idx].sum.dt : data_type::undef;
    brg->sum_dt = sum_dt != data_type::undef ? sum_dt : dt_d;

    const auto eltwise_ind = post_ops.find(primitive_kind::eltwise);
    brg->with_eltwise = eltwise_ind != -1;

    const auto &src_scales = attr->scales_.get(DNNL_ARG_SRC);
    const auto &wei_scales = attr->scales_.get(DNNL_ARG_WEIGHTS);
    brg->with_scales = !src_scales.has_default_values()
            || !wei_scales.has_default_values();
    if (brg->with_scales) {
        // Note. the current version supports only two different output scale
        // types:
        //     1) common (mask_ = 0)
        //     2) per_n_dim_scale - broadcast across n dimension;
        //        for convolution and inner product primitives it corresponds
        //        to "per_oc" mask_ = 1 << 1; for matmul - to
        //        mask_ = (1 << (ndims - 1))), where ndims is number of
        //        dimensions for original matmul problem
        // So if wei_scales.mask_ != 0 (not common) it's assumed here that scale
        // type is per_n_dim_scale and driver which calls brgemm kernel checked
        // that mask has correct value for this case
        brg->is_oc_scale = wei_scales.mask_ != 0;
    }
    // Only common src scales and no scales on other args are supported.
    const bool scales_ok = src_scales.mask_ == 0
            && attr->scales_.has_default_values(
                    {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS});
    if (!scales_ok) return status::unimplemented;

    // Maps a memory argument's zero-point attribute to the broadcast kind
    // used by the kernel (none vs. per-tensor).
    auto init_zp_type
            = [&](brgemm_broadcast_t &zp_type, int mem_arg) -> status_t {
        auto zero_points = attr->zero_points_;

        // common zero point type is supported for now
        if (!zero_points.common(mem_arg)) return status::unimplemented;

        zp_type = zero_points.has_default_values(mem_arg)
                ? brgemm_broadcast_t::none
                : brgemm_broadcast_t::per_tensor;
        return status::success;
    };

    init_zp_type(brg->zp_type_a, DNNL_ARG_SRC);
    init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS);
    init_zp_type(brg->zp_type_c, DNNL_ARG_DST);

    // src zero points require additional register in brgemm kernel
    if (brg->zp_type_a != brgemm_broadcast_t::none
            || (brg->is_bf16_emu && !brg->is_dgmm))
        CHECK(brgemm_blocking(brg));

    return status::success;
}
359 | |
360 | status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) { |
361 | if (brg == nullptr) return status::invalid_arguments; |
362 | |
363 | // negative padding is not supported |
364 | if (brgattr.max_top_vpad < 0 || brgattr.max_bottom_vpad < 0) |
365 | return status::unimplemented; |
366 | |
367 | if (!brg->is_dgmm) { |
368 | // virtual padding size is restricted by MAX_VPAD value |
369 | if (brgattr.max_top_vpad > brgemm_t::MAX_VPAD |
370 | || brgattr.max_bottom_vpad > brgemm_t::MAX_VPAD) |
371 | return status::unimplemented; |
372 | |
373 | // virtual padding is restricted by bd_block size due to |
374 | // brgemm_kernel implementation. TODO: remove this restriction |
375 | if (brgattr.max_top_vpad > brg->bd_block |
376 | || brgattr.max_bottom_vpad > brg->bd_block) |
377 | return status::unimplemented; |
378 | } |
379 | |
380 | // virtual padding is supported for "brgemm_row_major" layout |
381 | // TODO: remove this restriction |
382 | if ((brgattr.max_top_vpad > 0 || brgattr.max_bottom_vpad > 0) |
383 | && brg->layout != brgemm_row_major) |
384 | return status::unimplemented; |
385 | |
386 | brg->brgattr = brgattr; |
387 | |
388 | if (brgattr.fpmath_mode != fpmath_mode::strict) maybe_try_bf32(brg); |
389 | |
390 | bool hint_blocking_set |
391 | = (brgattr.hint_bd_block != 0 || brgattr.hint_bd_block2 != 0 |
392 | || brgattr.hint_ld_block != 0 || brgattr.hint_ld_block2 != 0 |
393 | || brgattr.hint_load_nt_A != brgemm_hint_nt_undef |
394 | || brgattr.hint_load_nt_B != brgemm_hint_nt_undef); |
395 | if (brg->is_bf16_tmm || hint_blocking_set || brgattr.bd_mask_level |
396 | || brgattr.fpmath_mode != fpmath_mode::strict) { |
397 | if (brg->is_dgmm) |
398 | CHECK(brdgmm_blocking(brg)); |
399 | else |
400 | CHECK(brgemm_blocking(brg)); |
401 | } |
402 | |
403 | brg->LDA2 = (brgattr.LDA2 != 0) ? brgattr.LDA2 : brg->LDA; |
404 | brg->LDB2 = (brgattr.LDB2 != 0) ? brgattr.LDB2 : brg->LDB; |
405 | brg->LDC2_M = (brgattr.LDC2_M != 0) ? brgattr.LDC2_M : brg->LDC; |
406 | brg->LDC2_N = (brgattr.LDC2_N != 0) ? brgattr.LDC2_N : brg->ld_block; |
407 | |
408 | brg->is_blocked = (brg->LDA2 != brg->LDA || brg->LDB2 != brg->LDB |
409 | || brg->LDC2_M != brg->LDC || brg->LDC2_N != brg->ld_block); |
410 | |
411 | if (!IMPLICATION(brg->is_blocked, brg->layout = brgemm_row_major)) |
412 | return status::invalid_arguments; |
413 | |
414 | // virtual padding is not supported for "amx" |
415 | if ((brgattr.max_top_vpad > 0 || brgattr.max_bottom_vpad > 0) |
416 | && (brg->is_tmm)) |
417 | return status::unimplemented; |
418 | |
419 | brg->prfA = brgattr.hint_prfA; |
420 | brg->prfB = brgattr.hint_prfB; |
421 | brg->prfC = brgattr.hint_prfC; |
422 | |
423 | if (brgattr.hint_prefetching == brgemm_kernel_prefetching_t::brgemm_prf1 |
424 | && brg->prfC.dist1 < 0) |
425 | brg->prfC.dist1 = 0; |
426 | if (brgattr.hint_prefetching == brgemm_kernel_prefetching_t::brgemm_prf2 |
427 | && brg->prfC.dist2 < 0) |
428 | brg->prfC.dist2 = 0; |
429 | |
430 | return status::success; |
431 | } |
432 | |
// Instantiates the JIT kernel matching the descriptor: a brdgmm kernel for
// the diagonal flavor, the AMX "uker" when dispatchable, otherwise a common
// kernel templated on isa and vector register type (Tmm/Zmm/Ymm).
// On success *brg_kernel owns the new kernel; caller releases it via
// brgemm_kernel_destroy().
status_t brgemm_kernel_create(
        brgemm_kernel_t **brg_kernel, const brgemm_t &brg) {
    if (!brg_kernel) return status::invalid_arguments;
    *brg_kernel = nullptr;

    if (brg.is_dgmm) {
        // One case per supported isa; Vmm type is derived from isa traits.
#define CASE(isa) \
    case isa: \
        CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel, \
                new brdgmm_kernel_t<isa, typename cpu_isa_traits<isa>::Vmm>( \
                        brg))); \
        break
        switch (brg.isa_impl) {
            CASE(avx512_core_fp16);
            CASE(avx512_core_bf16);
            CASE(avx512_core_vnni);
            CASE(avx512_core);
            CASE(avx2_vnni_2);
            CASE(avx2);
            default: return status::unimplemented;
        }
#undef CASE
    } else if (can_dispatch_uker(&brg)) {
        CHECK(safe_ptr_assign<brgemm_kernel_t>(
                *brg_kernel, new brgemm_amx_uker_t(brg)));
    } else {
        if (brg.is_tmm) {
            // AMX tile-register kernels (fp16 vs. bf16/int8 palettes).
            if (brg.is_f16_tmm) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_amx_fp16,
                                Xbyak::Tmm>(brg)));
            } else {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_amx, Xbyak::Tmm>(
                                brg)));
            }
        } else if (brg.is_zmm) {
            // isa specific instantiations are required because
            // post-ops require template isa param.
            if (brg.isa_impl == avx512_core_fp16) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_fp16,
                                Xbyak::Zmm>(brg)));
            } else if (brg.isa_impl == avx512_core_bf16) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_bf16,
                                Xbyak::Zmm>(brg)));
            } else if (brg.isa_impl == avx512_core_vnni) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core_vnni,
                                Xbyak::Zmm>(brg)));
            } else {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx512_core, Xbyak::Zmm>(
                                brg)));
            }
        } else if (brg.is_ymm) {
            // AVX2-class kernels on Ymm registers.
            if (brg.isa_impl == avx2) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx2, Xbyak::Ymm>(brg)));
            } else if (brg.isa_impl == avx2_vnni) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx2_vnni, Xbyak::Ymm>(
                                brg)));
            } else if (brg.isa_impl == avx2_vnni_2) {
                CHECK(safe_ptr_assign<brgemm_kernel_t>(*brg_kernel,
                        new brgemm_kernel_common_t<avx2_vnni_2, Xbyak::Ymm>(
                                brg)));
            }
        }
    }
    // No branch matched the descriptor -> nothing was created.
    if (!(*brg_kernel)) return status::unimplemented;
    return (*brg_kernel)->create_kernel();
}
507 | |
508 | status_t brgemm_kernel_destroy(brgemm_kernel_t *brg_kernel) { |
509 | delete brg_kernel; |
510 | return status::success; |
511 | } |
512 | |
// Fills the 64-byte AMX tile palette for an AMX (tmm) brgemm descriptor:
// zeroes the buffer, then configures the A, B and C tile rows/columns from
// the descriptor's blocking parameters. Returns unimplemented for
// non-AMX descriptors or blocking shapes the kernel cannot handle.
status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
    constexpr int max_palette_size_in_bytes = 64;

    if (!brg.is_tmm) return status::unimplemented;

    //TODO: Add support of tail processing by reduction dimension
    auto rd_block = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block;
    // bf32 repacks data to bf16, which needs a vnni-aligned reduction block.
    if (brg.is_bf32) rd_block = utils::rnd_up(rd_block, 2 /*vnni_granularity*/);

    palette_config_t *buff = (palette_config_t *)(palette);

    // Clear the whole palette before configuring individual tiles.
    char *_tc = (char *)(buff);
    for (int i = 0; i < max_palette_size_in_bytes; i++)
        _tc[i] = 0;

    // bf32 converts f32 inputs to bf16 on the fly, so tiles are sized for
    // bf16 elements in that case.
    const int typesize_A = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_A;
    const int typesize_B = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_B;

    // Number of elements packed per 4-byte vnni group.
    const int rd_step = 4 / typesize_A;

    // A-tile row width in bytes (reduction dim).
    const auto Ac = typesize_A * rd_block;

    // B-tile row count; guard against a zero C typesize.
    const auto Br = (brg.typesize_C != 0) ? Ac / brg.typesize_C : 0;

    if (brg.ldb_tail && (brg.ld_block2 > 1)) return status::unimplemented;
    // All A, B and C tiles must fit in the fixed AMX tile register budget.
    if (brg.get_num_A_tiles() + brg.get_num_B_tiles()
                    + brg.get_bd_block2() * brg.get_ld_block2()
            > brgemm_t::AMX_TILES_NUM)
        return status::unimplemented;

    // Due to interleaving tileload/tmul we don't support blocking 1x6 and 6x1
    //TODO: update gemm_microkernel_amx to support such blocking
    if (brg.get_bd_block2() >= 6 || brg.get_num_C_tiles() >= 6)
        return status::unimplemented;

    // Configure A tiles; the last one may carry the bd-dimension tail.
    for (int m = 0; m < brg.get_num_A_tiles(); m++) {
        const bool is_bd_tail
                = (brg.bdb_tail && m == (brg.get_num_A_tiles() - 1));
        const auto A_tensor = brg.get_A_tensor(m, is_bd_tail);
        const auto Ar = is_bd_tail ? brg.bdb_tail : brg.bd_block;
        tc_configure_tile(buff, A_tensor, Ar, Ac);
    }

    // Configure B tiles; the last one may carry the ld-dimension tail.
    for (int n = 0; n < brg.get_num_B_tiles(); n++) {
        const bool is_ld_tail
                = (brg.ldb_tail && n == (brg.get_num_B_tiles() - 1));
        const auto B_tensor = brg.get_B_tensor(n, is_ld_tail);
        const auto Bc = (is_ld_tail ? brg.ldb_tail : brg.ld_block) * typesize_B
                * rd_step;
        tc_configure_tile(buff, B_tensor, Br, Bc);
    }

    // Configure the 2D grid of C accumulator tiles, handling tails in both
    // dimensions.
    for (int m = 0; m < brg.get_bd_block2(); m++) {
        const bool is_bd_tail
                = (brg.bdb_tail && m == (brg.get_bd_block2() - 1));
        const auto Cr = is_bd_tail ? brg.bdb_tail : brg.bd_block;
        for (int n = 0; n < brg.get_ld_block2(); n++) {
            const bool is_ld_tail
                    = (brg.ldb_tail && n == (brg.get_ld_block2() - 1));
            const auto Cc = (is_ld_tail ? brg.ldb_tail : brg.ld_block)
                    * brg.typesize_C;
            const auto C_tensor
                    = brg.get_C_tensor(m, n, is_bd_tail, is_ld_tail);
            tc_configure_tile(buff, C_tensor, Cr, Cc);
        }
    }

    buff->palette_id = amx::get_target_palette();

    return status::success;
}
584 | |
585 | } // namespace x64 |
586 | } // namespace cpu |
587 | } // namespace impl |
588 | } // namespace dnnl |
589 | |
590 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
591 | |