/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <float.h>
#include <functional>
#include <math.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

// TODO: refactor the driver to avoid using extra flags of a memory descriptor.
#include "src/common/memory_desc.hpp"

#include "tests/test_isa_common.hpp"

#include "utils/parallel.hpp"
#include "utils/parser.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "brgemm/brgemm.hpp"

#if defined(DNNL_X64) && DNNL_X64 == 1 && DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
template <>
struct dnnl_api_traits<dnnl::impl::cpu::x64::brgemm_kernel_t *> {
    static void destroy(dnnl::impl::cpu::x64::brgemm_kernel_t *t) {
        DNN_SAFE_V(dnnl::impl::cpu::x64::brgemm_kernel_destroy(t));
    }
};
#endif

namespace brgemm {

#if defined(DNNL_X64) && DNNL_X64 == 1 && DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE

/// Initializes BRGEMM attributes from an input string.
///
/// @param brgattr Output BRGEMM attributes.
/// @param prb Problem descriptor; the input string is taken from
///     `prb->brgemm_attr` and has the format KEY:VALUE[+KEY:VALUE[...]].
///     `KEY` matches the exact name of a `brgemm_attr_t` member, and `VALUE`
///     follows that member's type; enum and boolean types are treated as
///     integers.
///
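/// For example, the string "use_uker:1+hint_prfA_dist1:4" sets
/// `brgattr->use_uker = 1` and `brgattr->hint_prfA.dist1 = 4`.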
dnnl_status_t brgemm_attr_init(
        dnnl::impl::cpu::x64::brgemm_attr_t *brgattr, const prb_t *prb) {
    using namespace dnnl::impl::cpu::x64;

    const auto &str = prb->brgemm_attr;
    if (str.empty()) return dnnl_success;

    size_t entry_pos = 0;
    while (entry_pos != std::string::npos) {
        auto key_value_str = parser::get_substr(str, entry_pos, '+');
        size_t value_pos = 0;
        auto key_str = parser::get_substr(key_value_str, value_pos, ':');
        auto value_str = parser::get_substr(key_value_str, value_pos, '\0');

#define PROCESS_SETTING_KEY_VAL(setting, key) \
    if (key_str.compare(STRINGIFY(key)) == 0) \
        brgattr->setting = std::stoi(value_str);

#define PROCESS_KEY_VAL(setting) PROCESS_SETTING_KEY_VAL(setting, setting)
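        // For instance, PROCESS_KEY_VAL(use_uker) expands to:
        //     if (key_str.compare("use_uker") == 0)
        //         brgattr->use_uker = std::stoi(value_str);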

        // TODO: `max_top_vpad` and `max_bottom_vpad` do not affect anything
        // in the kernel call or reference computation so far, since the
        // batch_element_t struct is not adjusted to incorporate different
        // pad values.
        // PROCESS_KEY_VAL(max_top_vpad);
        // PROCESS_KEY_VAL(max_bottom_vpad);
        PROCESS_KEY_VAL(hint_expected_A_size);
        PROCESS_KEY_VAL(hint_expected_B_size);
        PROCESS_KEY_VAL(hint_expected_C_size);
        PROCESS_KEY_VAL(wary_tail_read);
        PROCESS_KEY_VAL(generate_skip_accumulation);
        // TODO: `bd_mask` can't be passed to the kernel at this moment;
        // that's why `bd_mask_level` has to stay `0` until it's enabled.
        // PROCESS_KEY_VAL(bd_mask_level);
        PROCESS_KEY_VAL(use_uker);
        PROCESS_KEY_VAL(use_interleave_stores);
        PROCESS_KEY_VAL(postops_only);
        PROCESS_KEY_VAL(hint_bd_block);
        PROCESS_KEY_VAL(hint_bd_block2);
        PROCESS_KEY_VAL(hint_ld_block);
        PROCESS_KEY_VAL(hint_ld_block2);

        PROCESS_SETTING_KEY_VAL(hint_prfA.dist1, hint_prfA_dist1);
        PROCESS_SETTING_KEY_VAL(hint_prfA.dist2, hint_prfA_dist2);
        PROCESS_SETTING_KEY_VAL(hint_prfB.dist1, hint_prfB_dist1);
        PROCESS_SETTING_KEY_VAL(hint_prfB.dist2, hint_prfB_dist2);
        PROCESS_SETTING_KEY_VAL(hint_prfC.dist1, hint_prfC_dist1);
        PROCESS_SETTING_KEY_VAL(hint_prfC.dist2, hint_prfC_dist2);

#undef PROCESS_SETTING_KEY_VAL
#undef PROCESS_KEY_VAL

        if (key_str.find(STRINGIFY(hint_innermost_loop)) != std::string::npos)
            brgattr->hint_innermost_loop
                    = static_cast<brgemm_kernel_innermost_loop_t>(
                            std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_loop_order)) != std::string::npos)
            brgattr->hint_loop_order = static_cast<brgemm_kernel_loop_order_t>(
                    std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_prefetching)) != std::string::npos)
            brgattr->hint_prefetching
                    = static_cast<brgemm_kernel_prefetching_t>(
                            std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_load_nt_A)) != std::string::npos)
            brgattr->hint_load_nt_A = static_cast<brgemm_kernel_hint_nt_t>(
                    std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_load_nt_B)) != std::string::npos)
            brgattr->hint_load_nt_B = static_cast<brgemm_kernel_hint_nt_t>(
                    std::stoi(value_str));
    }

    // `max_bs` is handled directly through the driver interface.
    brgattr->max_bs = prb->batch_size;

    // `fpmath_mode` is handled directly through the driver interface.
    brgattr->fpmath_mode = prb->attr.fpmath_mode;

    return dnnl_success;
}

std::string prepare_wei_format_string(
        dnnl_data_type_t dt, int64_t n, bool is_vnni_layout) {
    // `dt` affects the choice of last inner block (for VNNI-friendliness).
    // `n` affects the choice of B block.
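    // In oneDNN format-tag terms, `BA16a` orders the outer blocks as B (n),
    // then A (k), with an inner block of 16 on the K dimension; the `<N>b`
    // piece appended below sets the B block, and a trailing `2a`/`4a` adds
    // the VNNI granularity. E.g., bf16 with n = 64 and a VNNI layout yields
    // "BA16a64b2a".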
    std::string wtag("BA16a");
    switch (n) {
        case 64: wtag += "64b"; break;
        case 48: wtag += "48b"; break;
        case 32: wtag += "32b"; break;
        default:
            if (n <= 16)
                wtag += "16b";
            else
                // Round `n` up to the nearest multiple of 16 to get a valid
                // B block; when `n` is already divisible by 16 this equals
                // `n` itself.
                wtag += std::to_string(16 * div_up(n, 16)) + "b";
            break;
    }
    if (is_vnni_layout) {
        switch (dt) {
            case dnnl_f32: break;
            case dnnl_f16:
            case dnnl_bf16: wtag += "2a"; break;
            case dnnl_s8: wtag += "4a"; break;
            default: assert(!"unsupported data type");
        }
    }

    return wtag;
}

int fill_data(data_kind_t kind, const prb_t *prb, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, res_t *res) {

    const auto nelems = mem_dt.nelems();
    if (nelems == 0) return OK;

    assert(mem_dt.nelems() == mem_fp.nelems());

    cfg_t cfg(prb, {SRC, WEI, BIA, DST});
    cfg_t::density_args_t density_args;
    density_args.data_kind = kind;
    density_args.n_acc = prb->k;
    const auto density = cfg.get_density(density_args);
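    // `density` is the probability that a generated element is non-zero.
    // Presumably, the config lowers it for long accumulation chains
    // (`n_acc = k`) so that low-precision accumulation stays exact.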

    /* Do fixed partitioning to have the same filling for any number of
     * threads. */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note: we use a different seed for each chunk to avoid
        // repeating patterns. We could use discard(idx_start) too, but
        // its complexity is O(idx_start). We also add 1 to avoid
        // seeding with 0.
        std::minstd_rand int_seed(kind * nelems + idx_start + 1);
        int_seed.discard(1);
        std::minstd_rand b_seed(kind * nelems + idx_start + 1);
        b_seed.discard(10);

        std::uniform_int_distribution<> gen(
                cfg.get_range_min(kind), cfg.get_range_max(kind));
        std::bernoulli_distribution b_dist(density);

        // Make sure the first element is positive.
        if (idx_start == 0) {
            float val = 0;
            while (val <= 0)
                val = gen(int_seed);
            mem_fp.set_elem(
                    0, round_to_nearest_representable(cfg.get_dt(kind), val));
            idx_start += 1;
        }

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            bool is_one = density == 1.f ? true : b_dist(b_seed);
            float val = is_one * gen(int_seed);
            mem_fp.set_elem(
                    idx, round_to_nearest_representable(cfg.get_dt(kind), val));
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type(
            {prb->src_dt(), prb->wei_dt(), prb->bia_dt, prb->dst_dt()},
            prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res, prb->dst_dt());
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {
    const bool is_src_zp = !prb->attr.zero_points.is_def(DNNL_ARG_SRC);
    const bool is_dst_zp = !prb->attr.zero_points.is_def(DNNL_ARG_DST);

    // Only runtime zero points are supported by this driver.
    const bool is_runtime_src_zp = prb->attr.zero_points.runtime(DNNL_ARG_SRC);
    const bool is_runtime_dst_zp = prb->attr.zero_points.runtime(DNNL_ARG_DST);
    const bool is_static_zp = (is_src_zp && !is_runtime_src_zp)
            || (is_dst_zp && !is_runtime_dst_zp);
    if (is_static_zp) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }

    // The AMX kernel supports SRC zero points only in the unrolled kernel,
    // and only for values of 0 or 1.
    // Note: this check must be done here since the zero point value in the
    // brgemm API is a runtime argument.
    // TODO: remove once the AMX kernel fully supports zero points.
    const bool is_amx = dnnl::mayiuse(dnnl_cpu_isa_avx512_core_amx);
    const int src_zp_value = prb->attr.zero_points.get(DNNL_ARG_SRC).value;
    if (is_amx && is_src_zp && src_zp_value != 0 && src_zp_value != 1) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }
}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    const auto dt = prb->get_dt(kind);
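    // For f32 a small non-zero threshold should cover accumulation-order
    // differences against the reference; lower-precision types allow one
    // machine epsilon.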
    const float trh = dt == dnnl_f32 ? 1e-6f : epsilon_dt(dt);
    cmp.set_threshold(trh);
    cmp.set_zero_trust_percent(90.f); // TODO: why does filling give so many zeros?
}

// A special wrapper needed to match internal infrastructure.
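// The trailing stream and execution-args parameters exist only to match the
// `perf_function_t` signature used by `measure_perf()` below; they are
// ignored since the kernel is invoked directly.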
dnnl_status_t brgemm_kernel_execute_postops_wrapper(
        const dnnl::impl::cpu::x64::brgemm_kernel_t *brgemm_kernel,
        int batch_size,
        const dnnl::impl::cpu::x64::brgemm_batch_element_t *batch_element,
        void *acc_ptr, void *dst_ptr,
        const dnnl::impl::cpu::x64::brgemm_post_ops_data_t &post_ops_data,
        void *scratchpad_ptr, const dnnl_stream_t &stream,
        const std::vector<dnnl_exec_arg_t> &dnnl_args) {
    brgemm_kernel_execute_postops(brgemm_kernel, batch_size, batch_element,
            acc_ptr, dst_ptr, post_ops_data, scratchpad_ptr);
    return dnnl_success;
}

int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    skip_start(res);
    if (res->state == SKIPPED) return OK;

    // Need this here since brgemm has no primitive creation step.
    skip_invalid_prb(prb, res);
    if (res->state == SKIPPED) return OK;

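    // When there is no bias, attributes are default, and the accumulator and
    // destination data types match, the kernel can accumulate directly into
    // the destination buffer.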
    bool use_dst_as_acc = false;
    if (prb->bia_dt == dnnl_data_type_undef && prb->acc_dt() == prb->dst_dt()
            && prb->attr.is_def(/* skip_fmpath = */ true))
        use_dst_as_acc = true;

    // Fuse the batch size into the K dimension, following how the library
    // uses the kernel batch size setting.
    const dnnl_dims_t src_dims = {prb->m, prb->k * prb->batch_size};
    const dnnl_dims_t wei_dims = {prb->k * prb->batch_size, prb->n};
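    // E.g., m = 4, k = 8, and batch_size = 2 yield a 4x16 source and a 16xN
    // weights tensor.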

    dims_t src_strides = {prb->get_lda(), 1};
    dims_t dst_strides = {prb->get_ldd(), 1};
    dims_t acc_strides = use_dst_as_acc ? dst_strides : dims_t();

    auto dst_md = dnn_mem_t::init_md(prb->ndims, prb->dst_dims.data(),
            prb->dst_dt(), prb->dtag, dst_strides);

    using namespace dnnl::impl::cpu::x64;

    brgemm_t brgemm_desc;
    // Supports only the address model for now, as the batch kind only affects
    // the way memory is passed to the `brgemm_batch_element_t` object.
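    // With `brgemm_addr`, each batch element carries explicit A and B
    // pointers, as opposed to the offset- or stride-based batch kinds.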
    brgemm_batch_kind_t batch_kind = brgemm_batch_kind_t::brgemm_addr;
    brgemm_layout_t layout = brgemm_layout_t::brgemm_row_major;

    // Pass `isa_undef` for now since internals work with isa bits rather than
    // isa values directly, which causes misalignment between the public enum
    // and internal values.
    // TODO: re-consider enabling isa values.
    const auto isa_undef = cpu_isa_t::isa_undef;

    // Create BRGeMM descriptor, analogous to primitive descriptor creation.
    const auto status_init = brgemm_desc_init(&brgemm_desc, isa_undef,
            batch_kind, prb->src_dt(), prb->wei_dt(), false /* transA */,
            false /* transB */, layout, prb->alpha, prb->beta, prb->get_lda(),
            prb->get_ldb(), prb->get_ldc(use_dst_as_acc), prb->m, prb->n,
            prb->k, nullptr /* strides */);
    check_dnnl_status(status_init, prb, res);
    if (res->state == SKIPPED) return OK;
    // Unconditionally skip remaining unimplemented cases.
    // TODO: remove this and add a SAFE check above.
    if (status_init != dnnl_success)
        return res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED, OK;

    attr_args_t attr_args;
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    SAFE(check_dnnl_status(brgemm_desc_set_postops(&brgemm_desc, dnnl_attr,
                                   dst_md, prb->get_ldd(), prb->bia_dt),
                 prb, res),
            WARN);
    if (res->state == SKIPPED) return OK;

    brgemm_attr_t brgemm_attr;
    DNN_SAFE(brgemm_attr_init(&brgemm_attr, prb), WARN);
    SAFE(check_dnnl_status(
                 brgemm_desc_set_attr(&brgemm_desc, brgemm_attr), prb, res),
            WARN);
    if (res->state == SKIPPED) return OK;

    // Create BRGeMM kernel, analogous to primitive creation.
    // `ctx_init` can be used here to select the core type on hetero-ISA
    // systems with TBB.
    brgemm_kernel_t *brgemm_kernel_ = nullptr;
    {
        auto brgemm_kernel_addr = &brgemm_kernel_;
        DNN_SAFE(create_in_thr_ctx(prb->ctx_init, brgemm_kernel_create,
                         brgemm_kernel_addr, brgemm_desc),
                WARN);
    }
    auto brgemm_kernel = make_benchdnn_dnnl_wrapper(brgemm_kernel_);

    const auto is_tmm = brgemm_desc.is_tmm;
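    // TMM (AMX) kernels compute on tile registers, which must be configured
    // before the first kernel execution; the palette buffer encodes that
    // tile configuration.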
    if (is_tmm) {
        char palette[AMX_PALETTE_SIZE] = {};
        DNN_SAFE(brgemm_init_tiles(brgemm_desc, palette), WARN);
        DNN_SAFE(amx_tile_configure(palette), WARN);
    }

    auto src_md = dnn_mem_t::init_md(
            prb->ndims, src_dims, prb->src_dt(), prb->stag, src_strides);

    // Create a weights memory descriptor with a VNNI-friendly format.
    // Note: LDB is not passed here because it's very difficult to incorporate
    // a stride on top of blocking - the oneDNN API doesn't provide any calls
    // to support both options together. A submemory descriptor, which is the
    // only way to create such a memory descriptor, can't return the size of
    // the memory. Thus, it would require two memories, and we would need to
    // pass a memory handle from the bigger one (where LDB is an actual dim
    // value) to the smaller one, but a reorder bug results in an error.
    const auto wtag = prepare_wei_format_string(
            prb->wei_dt(), prb->n, brgemm_desc.is_b_data_layout_vnni());
    BENCHDNN_PRINT(6, "wtag: %s\n", wtag.c_str());
    auto wei_md = dnn_mem_t::init_md(prb->ndims, wei_dims, prb->wei_dt(), wtag);

    const size_t wei_offset_s8s8 = dnnl_memory_desc_get_size(wei_md);
    // Prepare and assign `extra` for wei_md when s8s8 compensation or source
    // zero-point reduction values are needed.
    dnnl::impl::memory_extra_desc_t wei_md_extra {};
    wei_md_extra.flags = dnnl::impl::memory_extra_flags::none;
    if (prb->get_dt(SRC) == dnnl_s8 && prb->get_dt(WEI) == dnnl_s8) {
        wei_md_extra.flags
                |= dnnl::impl::memory_extra_flags::compensation_conv_s8s8;
        wei_md_extra.compensation_mask = 2; // N dimension
    }
    static_cast<dnnl_memory_desc_t>(wei_md)->extra = wei_md_extra;

    const size_t wei_offset_zp = wei_offset_s8s8
            + (wei_md_extra.flags != dnnl::impl::memory_extra_flags::none
                            ? prb->get_ldb() * sizeof(int32_t)
                            : 0);
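    // The resulting weights buffer layout is: [weights data | s8s8
    // compensation (ldb int32 values, when present) | src zero-point
    // compensation].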

    const bool need_src_comp = !prb->attr.zero_points.is_def(DNNL_ARG_SRC);
    if (need_src_comp) {
        wei_md_extra.flags |= dnnl::impl::memory_extra_flags::
                compensation_conv_asymmetric_src;
        wei_md_extra.asymm_compensation_mask = 2; // N dimension
    }
    static_cast<dnnl_memory_desc_t>(wei_md)->extra = wei_md_extra;

    // Same as dst_md but with a pre-defined data type according to the
    // documentation.
    auto acc_md = dnn_mem_t::init_md(prb->ndims, prb->dst_dims.data(),
            prb->acc_dt(), tag::abx, acc_strides);

    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> bia_md {};
    if (prb->bia_dt != dnnl_data_type_undef) {
        const dnnl_dims_t bia_dims = {1, prb->n};
        bia_md = dnn_mem_t::init_md(
                prb->ndims, bia_dims, prb->bia_dt, tag::abx);
    }

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t acc_dt(acc_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t bia_dt;
    if (prb->bia_dt != dnnl_data_type_undef)
        bia_dt = dnn_mem_t(bia_md, test_engine);

    dnn_mem_t src_fp(src_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t wei_fp(wei_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t acc_fp(acc_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine);
    dnn_mem_t bia_fp;
    if (prb->bia_dt != dnnl_data_type_undef)
        bia_fp = dnn_mem_t(bia_md, dnnl_f32, tag::abx, ref_engine);

    SAFE(fill_data(SRC, prb, src_dt, src_fp, res), WARN);
    SAFE(fill_data(WEI, prb, wei_dt, wei_fp, res), WARN);
    const int sum_idx = prb->attr.post_ops.find(attr_t::post_ops_t::SUM);
    if ((prb->beta != 0) || brgemm_attr.generate_skip_accumulation) {
        SAFE(fill_data(DST, prb, acc_dt, acc_fp, res), WARN);
        // Beta requires the same values for the reference and the kernel.
        if (use_dst_as_acc) {
            SAFE(dst_fp.reorder(acc_fp), WARN);
            SAFE(dst_dt.reorder(dst_fp), WARN);
        }
    }
    if (sum_idx >= 0) SAFE(fill_data(DST, prb, dst_dt, dst_fp, res), WARN);
    if (prb->bia_dt != dnnl_data_type_undef)
        SAFE(fill_data(BIA, prb, bia_dt, bia_fp, res), WARN);

    dnn_mem_t src_zero_points_m, wei_zero_points_m, dst_zero_points_m;
    const auto &wei_zero_point_val
            = prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).value;
    maybe_prepare_runtime_zero_points(
            src_zero_points_m, prb->attr, DNNL_ARG_SRC, prb->k, prb->src_zp);
    maybe_prepare_runtime_zero_points(wei_zero_points_m, prb->attr,
            DNNL_ARG_WEIGHTS, 1, &(wei_zero_point_val));
    maybe_prepare_runtime_zero_points(
            dst_zero_points_m, prb->attr, DNNL_ARG_DST, prb->n, prb->dst_zp);

    // "Library" args are needed to get dst for comparison.
    // "Reference" args are used as usual.
    args_t args, ref_args;
    args.set(DNNL_ARG_DST, dst_dt);

    std::vector<brgemm_batch_element_t> v_batch_element(prb->batch_size);
    const char *src_ptr = (const char *)src_dt;
    const char *wei_ptr = (const char *)wei_dt;
    // Note: batch_size is incorporated into the K dimension, which is why
    // each source batch has an offset of `k`. Weights are a more complicated
    // case: they are in a double-blocked format, which becomes triple-blocked
    // for bf16 and int8 to be VNNI-friendly. Because of this and the
    // batch_size incorporation, the offsets below DO NOT work when K is not
    // divisible by the K block size and batch_size > 1. This can't be handled
    // properly when the batch size is fused, but fusing makes it easier to
    // enable the s8s8 and zero-point compensation cases.
    int block_size = 0;
    switch (prb->wei_dt()) {
        case dnnl_f32: block_size = 16; break;
        case dnnl_f16: block_size = 16; break;
        case dnnl_bf16: block_size = 32; break;
        case dnnl_s8: block_size = 64; break;
        default: break;
    }
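    // Cast to void to silence unused-variable warnings in NDEBUG builds
    // where the asserts below compile away.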
    (void)block_size;
    assert(block_size > 1);
    assert(IMPLICATION(prb->batch_size > 1, prb->k % block_size == 0));

    const int64_t src_batch_offset = prb->k;
    const int64_t wei_batch_offset = prb->get_ldb() * prb->k;
    BENCHDNN_PRINT(6, "src_batch_offset=%ld wei_batch_offset=%ld\n",
            (long)src_batch_offset, (long)wei_batch_offset);

    for (size_t i = 0; i < v_batch_element.size(); i++) {
        v_batch_element[i].ptr.A
                = src_ptr + i * src_batch_offset * src_dt.sizeof_dt();
        v_batch_element[i].ptr.B
                = wei_ptr + i * wei_batch_offset * wei_dt.sizeof_dt();
    }

    // BRGeMM takes a single output-scales pointer but relies on a combination
    // of the per-argument scales attributes. This helps to reuse attributes
    // from primitives but requires pre-computing
    // oscale[:] = src_scale * wei_scale[:].
    dnn_mem_t scales;
    auto src_scale = prb->attr.scales.get(DNNL_ARG_SRC);
    auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
    auto attr_scale = wei_scale.runtime ? wei_scale : src_scale;
    maybe_prepare_runtime_scales(scales, attr_scale, prb->n, prb->scales);
    // Handle the `common` output-scale policy separately since the
    // implementation always expects a vector of scales in that case.
    std::vector<float> v16_scales(16, prb->scales[0]);
    const float *scales_ptr = attr_scale.policy == policy_t::COMMON
            ? v16_scales.data()
            : (const float *)scales;

    char *acc_ptr = (char *)acc_dt;

    const int32_t *dst_zp_ptr = (const int32_t *)dst_zero_points_m;
    char *src_comp_ptr = (char *)wei_dt + wei_offset_zp;
    int32_t zp_a_val = !prb->attr.zero_points.is_def(DNNL_ARG_SRC)
            ? src_zero_points_m.get_elem(0)
            : 0;

    if (!prb->attr.zero_points.is_def(DNNL_ARG_WEIGHTS)) {
        // TODO: weights zero point is not supported yet. It requires enabling
        // the f32 -> u8 reorder with compensation on the library side. When
        // enabled, it produces incorrect results for cases with K = 1; likely
        // there's a bug inside. Postpone supporting it.
        return res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED, OK;
    }

    if (prb->attr.post_ops.binary_index() >= 0) {
        // TODO: binary post-op is not supported yet.
        return res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED, OK;
    }

    brgemm_post_ops_data_t post_ops_data(
            /* bias */ (const char *)bia_dt,
            /* scales */ scales_ptr, /* binary_post_ops_rhs */ nullptr,
            /* oc_logical_off */ 0, /* dst_row_logical_off */ 0,
            /* data_C_ptr_ */ acc_ptr, /* first_mb_matrix_addr_off */ 0,
            /* a_zp_compensations */ src_comp_ptr,
            /* b_zp_compensations */ nullptr,
            /* c_zp_values */ dst_zp_ptr,
            /* skip_accumulation */ brgemm_attr.generate_skip_accumulation,
            /* zp_a_val */ zp_a_val,
            /* do_only_comp */ false,
            /* do_only_zp_a_val */ false);

    auto scratchpad_size = brgemm_desc.get_wsp_buffer_size();
    std::vector<char> scratchpad(scratchpad_size);
    // Note: hardware lacking native s8s8 support expects the compensation
    // buffer to be passed through the scratchpad argument in the post-ops
    // execution call.
    const bool need_hidden_compensation = scratchpad_size == 0
            && prb->get_dt(SRC) == dnnl_s8 && prb->get_dt(WEI) == dnnl_s8;
    char *scratchpad_ptr = need_hidden_compensation
            ? ((char *)wei_dt + wei_offset_s8s8)
            : scratchpad.data();

    char *dst_ptr = (char *)dst_dt;
    if (use_dst_as_acc) acc_ptr = dst_ptr;

    brgemm_kernel_execute_postops(brgemm_kernel, prb->batch_size,
            v_batch_element.data(), acc_ptr, dst_ptr, post_ops_data,
            scratchpad_ptr);
    if (res) res->state = EXECUTED;

    if (is_bench_mode(CORR)) {
        ref_args.set(DNNL_ARG_SRC, src_fp);
        ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
        if (prb->bia_dt != dnnl_data_type_undef)
            ref_args.set(DNNL_ARG_BIAS, bia_fp);
        ref_args.set(DNNL_ARG_DST, dst_fp);
        // Pass accumulator values for the `generate_skip_accumulation` check.
        ref_args.set(DNNL_ARG_SRC_1, acc_fp);
        // A hack to pass brgemm attributes to the reference since some
        // members change the computation flow for correctness validation.
        dnn_mem_t workspace(src_md, ref_engine, {false, (void *)&brgemm_attr});
        ref_args.set(DNNL_ARG_WORKSPACE, workspace);

        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
    }

    // Create a binding that matches the internal interface to run performance
    // measurements.
    perf_function_t perf_func = std::bind(brgemm_kernel_execute_postops_wrapper,
            brgemm_kernel_, prb->batch_size, v_batch_element.data(), acc_ptr,
            dst_ptr, post_ops_data, scratchpad_ptr, std::placeholders::_1,
            std::placeholders::_2);
    measure_perf(prb->ctx_exe, res, perf_func, args);

    if (is_tmm) DNN_SAFE(amx_tile_release(), WARN);

    return OK;
}

#else

int doit(const prb_t *prb, res_t *res) {
    (void)prb;
    (void)res;
    return OK;
}

#endif

} // namespace brgemm