1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <functional> |
19 | #include <math.h> |
20 | #include <random> |
21 | #include <stdio.h> |
22 | #include <stdlib.h> |
23 | |
24 | #include "oneapi/dnnl/dnnl.h" |
25 | |
26 | // TODO: refactor the driver to avoid using extra flags of a memory descriptor. |
27 | #include "src/common/memory_desc.hpp" |
28 | |
29 | #include "tests/test_isa_common.hpp" |
30 | |
31 | #include "utils/parallel.hpp" |
32 | #include "utils/parser.hpp" |
33 | |
34 | #include "dnnl_common.hpp" |
35 | #include "dnnl_memory.hpp" |
36 | |
37 | #include "brgemm/brgemm.hpp" |
38 | |
39 | #if defined(DNNL_X64) && DNNL_X64 == 1 && DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE |
// Teaches the benchdnn RAII wrapper (`make_benchdnn_dnnl_wrapper`) how to
// release a raw BRGeMM kernel pointer: destruction is routed through the
// library's `brgemm_kernel_destroy` call, and any error is reported via
// DNN_SAFE_V.
template <>
struct dnnl_api_traits<dnnl::impl::cpu::x64::brgemm_kernel_t *> {
    static void destroy(dnnl::impl::cpu::x64::brgemm_kernel_t *t) {
        DNN_SAFE_V(dnnl::impl::cpu::x64::brgemm_kernel_destroy(t));
    }
};
46 | #endif |
47 | |
48 | namespace brgemm { |
49 | |
50 | #if defined(DNNL_X64) && DNNL_X64 == 1 && DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE |
51 | |
/// Initializes BRGEMM attributes from an input string.
///
/// @param brgattr Output BRGEMM attributes.
/// @param prb Problem descriptor; its `brgemm_attr` member is the input
///     string of values in the format: KEY:VALUE[+KEY:VALUE[...]].
///     `KEY` follows exact name of the brgemm_attr_t object members and their
///     `VALUE` follow the member type. enum and boolean types are treated as
///     integers.
/// @returns dnnl_success (parse failures surface as std::stoi exceptions).
///
dnnl_status_t brgemm_attr_init(
        dnnl::impl::cpu::x64::brgemm_attr_t *brgattr, const prb_t *prb) {
    using namespace dnnl::impl::cpu::x64;

    const auto &str = prb->brgemm_attr;
    if (str.empty()) return dnnl_success;

    // Walk the `+`-separated entries; each entry is split once more into
    // KEY and VALUE at the `:` delimiter.
    size_t entry_pos = 0;
    while (entry_pos != std::string::npos) {
        auto key_value_str = parser::get_substr(str, entry_pos, '+');
        size_t value_pos = 0;
        auto key_str = parser::get_substr(key_value_str, value_pos, ':');
        auto value_str = parser::get_substr(key_value_str, value_pos, '\0');

// Assigns `brgattr->setting` when the entry key matches `key` exactly.
// The value is parsed as an integer (enums/booleans are passed as ints).
#define PROCESS_SETTING_KEY_VAL(setting, key) \
    if (key_str.compare(STRINGIFY(key)) == 0) \
        brgattr->setting = std::stoi(value_str);

// Shortcut for members whose key string coincides with the member name.
#define PROCESS_KEY_VAL(setting) PROCESS_SETTING_KEY_VAL(setting, setting)

        // TODO: `max_top_vpad` and `max_bottom_vpad` do not affect anything in
        // the kernel call and reference computation so far since
        // batch_element_t struct is not adjusted to incorporate different pad
        // values.
        // PROCESS_KEY_VAL(max_top_vpad);
        // PROCESS_KEY_VAL(max_bottom_vpad);
        PROCESS_KEY_VAL(hint_expected_A_size);
        PROCESS_KEY_VAL(hint_expected_B_size);
        PROCESS_KEY_VAL(hint_expected_C_size);
        PROCESS_KEY_VAL(wary_tail_read);
        PROCESS_KEY_VAL(generate_skip_accumulation);
        // TODO: `bd_mask` can't be passed to the kernel at this moment, that's
        // why `bd_mask_level` has to stay `0` for now until it's enabled.
        // PROCESS_KEY_VAL(bd_mask_level);
        PROCESS_KEY_VAL(use_uker);
        PROCESS_KEY_VAL(use_interleave_stores);
        PROCESS_KEY_VAL(postops_only);
        PROCESS_KEY_VAL(hint_bd_block);
        PROCESS_KEY_VAL(hint_bd_block2);
        PROCESS_KEY_VAL(hint_ld_block);
        PROCESS_KEY_VAL(hint_ld_block2);

        // Nested members need a flattened key spelling, hence the two-arg
        // macro form.
        PROCESS_SETTING_KEY_VAL(hint_prfA.dist1, hint_prfA_dist1);
        PROCESS_SETTING_KEY_VAL(hint_prfA.dist2, hint_prfA_dist2);
        PROCESS_SETTING_KEY_VAL(hint_prfB.dist1, hint_prfB_dist1);
        PROCESS_SETTING_KEY_VAL(hint_prfB.dist2, hint_prfB_dist2);
        PROCESS_SETTING_KEY_VAL(hint_prfC.dist1, hint_prfC_dist1);
        PROCESS_SETTING_KEY_VAL(hint_prfC.dist2, hint_prfC_dist2);

#undef PROCESS_SETTING_KEY_VAL
#undef PROCESS_KEY_VAL

        // Enum-typed members can't go through the macros above since the
        // parsed integer must be cast to the member's enum type explicitly.
        // Keys are matched by substring here.
        if (key_str.find(STRINGIFY(hint_innermost_loop)) != std::string::npos)
            brgattr->hint_innermost_loop
                    = static_cast<brgemm_kernel_innermost_loop_t>(
                            std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_loop_order)) != std::string::npos)
            brgattr->hint_loop_order = static_cast<brgemm_kernel_loop_order_t>(
                    std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_prefetching)) != std::string::npos)
            brgattr->hint_prefetching
                    = static_cast<brgemm_kernel_prefetching_t>(
                            std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_load_nt_A)) != std::string::npos)
            brgattr->hint_load_nt_A = static_cast<brgemm_kernel_hint_nt_t>(
                    std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_load_nt_B)) != std::string::npos)
            brgattr->hint_load_nt_B = static_cast<brgemm_kernel_hint_nt_t>(
                    std::stoi(value_str));
    }

    // `max_bs` is handled directly through the driver interface.
    brgattr->max_bs = prb->batch_size;

    // `fpmath_mode` is handled directly through the driver interface.
    brgattr->fpmath_mode = prb->attr.fpmath_mode;

    return dnnl_success;
}
139 | |
140 | std::string prepare_wei_format_string( |
141 | dnnl_data_type_t dt, int64_t n, bool is_vnni_layout) { |
142 | // `dt` affects the choice of last inner block (for VNNI-friendliness). |
143 | // `n` affects the choice of B block. |
144 | std::string wtag("BA16a" ); |
145 | switch (n) { |
146 | case 64: wtag += "64b" ; break; |
147 | case 48: wtag += "48b" ; break; |
148 | case 32: wtag += "32b" ; break; |
149 | default: |
150 | if (n <= 16) |
151 | wtag += "16b" ; |
152 | else { |
153 | if (n % 16 != 0) { |
154 | wtag += std::to_string(n) + "b" ; |
155 | } else |
156 | wtag += std::to_string(16 * div_up(n, 16)) + "b" ; |
157 | } |
158 | break; |
159 | } |
160 | if (is_vnni_layout) { |
161 | switch (dt) { |
162 | case dnnl_f32: break; |
163 | case dnnl_f16: |
164 | case dnnl_bf16: wtag += "2a" ; break; |
165 | case dnnl_s8: wtag += "4a" ; break; |
166 | default: assert(!"unsupported data type" ); |
167 | } |
168 | } |
169 | |
170 | return wtag; |
171 | } |
172 | |
// Fills both the library (`mem_dt`) and reference (`mem_fp`) buffers for the
// tensor `kind` with random integer values from the config-provided range,
// sparsified according to the config density (which depends on the number of
// accumulations `prb->k` — presumably to keep the accumulation exactly
// representable; see cfg_t for the policy).
//
// Filling is deterministic and independent of the number of threads: the
// tensor is split into 16 fixed chunks, each seeded from (kind, nelems,
// chunk start offset).
int fill_data(data_kind_t kind, const prb_t *prb, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, res_t *res) {

    const auto nelems = mem_dt.nelems();
    if (nelems == 0) return OK;

    assert(mem_dt.nelems() == mem_fp.nelems());

    cfg_t cfg(prb, {SRC, WEI, BIA, DST});
    cfg_t::density_args_t density_args;
    density_args.data_kind = kind;
    density_args.n_acc = prb->k;
    const auto density = cfg.get_density(density_args);

    /* Do fixed partitioning to have same filling for any number of threads */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note: we use a different seed for each chunk to avoid
        // repeating patterns. We could use discard(idx_start) too but
        // it has a complexity in O(idx_start). We also add 1 to avoid
        // seeding with 0.
        std::minstd_rand int_seed(kind * nelems + idx_start + 1);
        int_seed.discard(1);
        // Separate engine for the sparsity decision so value and density
        // streams do not interfere.
        std::minstd_rand b_seed(kind * nelems + idx_start + 1);
        b_seed.discard(10);

        std::uniform_int_distribution<> gen(
                cfg.get_range_min(kind), cfg.get_range_max(kind));
        std::bernoulli_distribution b_dist(density);

        // make sure the first element is positive
        if (idx_start == 0) {
            float val = 0;
            while (val <= 0)
                val = gen(int_seed);
            mem_fp.set_elem(
                    0, round_to_nearest_representable(cfg.get_dt(kind), val));
            idx_start += 1;
        }

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            // `is_one` zeroes out a fraction of elements per `density`;
            // skip the Bernoulli draw entirely when the tensor is dense.
            bool is_one = density == 1.f ? true : b_dist(b_seed);
            float val = is_one * gen(int_seed);
            mem_fp.set_elem(
                    idx, round_to_nearest_representable(cfg.get_dt(kind), val));
        }
    });

    // Convert the f32 reference values into the library memory's data type.
    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}
229 | |
// Marks `res` as skipped for configurations the library does not implement:
// unsupported data-type combinations for this direction and unsupported
// sum post-op configurations. Pure delegation to common benchdnn helpers.
void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type(
            {prb->src_dt(), prb->wei_dt(), prb->bia_dt, prb->dst_dt()},
            prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res, prb->dst_dt());
}
236 | |
237 | void skip_invalid_prb(const prb_t *prb, res_t *res) { |
238 | const bool is_src_zp = !prb->attr.zero_points.is_def(DNNL_ARG_SRC); |
239 | const bool is_dst_zp = !prb->attr.zero_points.is_def(DNNL_ARG_DST); |
240 | |
241 | // Only runtime zero points are supported by this driver |
242 | const bool is_runtime_src_zp = prb->attr.zero_points.runtime(DNNL_ARG_SRC); |
243 | const bool is_runtime_dst_zp = prb->attr.zero_points.runtime(DNNL_ARG_DST); |
244 | const bool is_static_zp = (is_src_zp && !is_runtime_src_zp) |
245 | || (is_dst_zp && !is_runtime_dst_zp); |
246 | if (is_static_zp) { |
247 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
248 | return; |
249 | } |
250 | |
251 | // AMX kernel only supports SRC zero points in unrolled kernel, |
252 | // and only for values of 0 or 1. |
253 | // Note: this check must be done here due to the fact that zero point value |
254 | // in brgemm API is a runtime argument. |
255 | // TODO: remove once AMX kernel fully supports zero points. |
256 | const bool is_amx = dnnl::mayiuse(dnnl_cpu_isa_avx512_core_amx); |
257 | const int src_zp_value = prb->attr.zero_points.get(DNNL_ARG_SRC).value; |
258 | if (is_amx && is_src_zp && src_zp_value != 0 && src_zp_value != 1) { |
259 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
260 | return; |
261 | } |
262 | } |
263 | |
264 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
265 | const args_t &ref_args) { |
266 | const auto dt = prb->get_dt(kind); |
267 | const float trh = dt == dnnl_f32 ? 1e-6f : epsilon_dt(dt); |
268 | cmp.set_threshold(trh); |
269 | cmp.set_zero_trust_percent(90.f); // TODO: why so bad filling? |
270 | } |
271 | |
// A special wrapper needed to match internal infrastructure.
// The trailing `stream` and `dnnl_args` parameters exist only to satisfy the
// perf-measurement callback signature (see `perf_func` in doit) and are
// intentionally unused; the call is forwarded to the raw kernel entry point.
dnnl_status_t brgemm_kernel_execute_postops_wrapper(
        const dnnl::impl::cpu::x64::brgemm_kernel_t *brgemm_kernel,
        int batch_size,
        const dnnl::impl::cpu::x64::brgemm_batch_element_t *batch_element,
        void *acc_ptr, void *dst_ptr,
        const dnnl::impl::cpu::x64::brgemm_post_ops_data_t &post_ops_data,
        void *scratchpad_ptr, const dnnl_stream_t &stream,
        const std::vector<dnnl_exec_arg_t> &dnnl_args) {
    brgemm_kernel_execute_postops(brgemm_kernel, batch_size, batch_element,
            acc_ptr, dst_ptr, post_ops_data, scratchpad_ptr);
    return dnnl_success;
}
285 | |
286 | int doit(const prb_t *prb, res_t *res) { |
287 | if (bench_mode == LIST) return res->state = LISTED, OK; |
288 | |
289 | skip_start(res); |
290 | if (res->state == SKIPPED) return OK; |
291 | |
292 | // Need this here as brgemm has no primitive creation step |
293 | skip_invalid_prb(prb, res); |
294 | if (res->state == SKIPPED) return OK; |
295 | |
296 | bool use_dst_as_acc = false; |
297 | if (prb->bia_dt == dnnl_data_type_undef && prb->acc_dt() == prb->dst_dt() |
298 | && prb->attr.is_def(/* skip_fmpath = */ true)) |
299 | use_dst_as_acc = true; |
300 | |
301 | // Fuse batch size into K dimension which follows the library usage of the |
302 | // kernel batch size setting. |
303 | const dnnl_dims_t src_dims = {prb->m, prb->k * prb->batch_size}; |
304 | const dnnl_dims_t wei_dims = {prb->k * prb->batch_size, prb->n}; |
305 | |
306 | dims_t src_strides = {prb->get_lda(), 1}; |
307 | dims_t dst_strides = {prb->get_ldd(), 1}; |
308 | dims_t acc_strides = use_dst_as_acc ? dst_strides : dims_t(); |
309 | |
310 | auto dst_md = dnn_mem_t::init_md(prb->ndims, prb->dst_dims.data(), |
311 | prb->dst_dt(), prb->dtag, dst_strides); |
312 | |
313 | using namespace dnnl::impl::cpu::x64; |
314 | |
315 | brgemm_t brgemm_desc; |
316 | // Supports only address model for now as only affects the way memory is |
317 | // passed to `brgemm_batch_element_t` object. |
318 | brgemm_batch_kind_t batch_kind = brgemm_batch_kind_t::brgemm_addr; |
319 | brgemm_layout_t layout = brgemm_layout_t::brgemm_row_major; |
320 | |
321 | // Pass `isa_undef` for now since internal work with it or rather isa bits |
322 | // than isa values directly which causes misalignment between public enum |
323 | // and internal values. |
324 | // TODO: re-consider enabling isa values. |
325 | const auto isa_undef = cpu_isa_t::isa_undef; |
326 | |
327 | // Create BRGeMM descriptor, analogous to primitive descriptor creation |
328 | const auto status_init = brgemm_desc_init(&brgemm_desc, isa_undef, |
329 | batch_kind, prb->src_dt(), prb->wei_dt(), false /* transA */, |
330 | false /* transB */, layout, prb->alpha, prb->beta, prb->get_lda(), |
331 | prb->get_ldb(), prb->get_ldc(use_dst_as_acc), prb->m, prb->n, |
332 | prb->k, nullptr /* strides */); |
333 | check_dnnl_status(status_init, prb, res); |
334 | if (res->state == SKIPPED) return OK; |
335 | // Unconditionally skip remaining unimplemented cases. |
336 | // TODO: remove this and add a SAFE check above. |
337 | if (status_init != dnnl_success) |
338 | return res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED, OK; |
339 | |
340 | attr_args_t attr_args; |
341 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
342 | create_dnnl_attr(prb->attr, attr_args)); |
343 | |
344 | SAFE(check_dnnl_status(brgemm_desc_set_postops(&brgemm_desc, dnnl_attr, |
345 | dst_md, prb->get_ldd(), prb->bia_dt), |
346 | prb, res), |
347 | WARN); |
348 | if (res->state == SKIPPED) return OK; |
349 | |
350 | brgemm_attr_t brgemm_attr; |
351 | DNN_SAFE(brgemm_attr_init(&brgemm_attr, prb), WARN); |
352 | SAFE(check_dnnl_status( |
353 | brgemm_desc_set_attr(&brgemm_desc, brgemm_attr), prb, res), |
354 | WARN); |
355 | if (res->state == SKIPPED) return OK; |
356 | |
357 | // Create BRGeMM kernel, analogous to primitive creation. |
358 | // ctx_init can here be used to select core type on hetero ISA with |
359 | // tbb |
360 | brgemm_kernel_t *brgemm_kernel_; |
361 | { |
362 | auto brgemm_kernel_addr = &brgemm_kernel_; |
363 | DNN_SAFE(create_in_thr_ctx(prb->ctx_init, brgemm_kernel_create, |
364 | brgemm_kernel_addr, brgemm_desc), |
365 | WARN); |
366 | } |
367 | auto brgemm_kernel = make_benchdnn_dnnl_wrapper(brgemm_kernel_); |
368 | |
369 | const auto is_tmm = brgemm_desc.is_tmm; |
370 | if (is_tmm) { |
371 | char palette[AMX_PALETTE_SIZE] = {}; |
372 | DNN_SAFE(brgemm_init_tiles(brgemm_desc, palette), WARN); |
373 | DNN_SAFE(amx_tile_configure(palette), WARN); |
374 | } |
375 | |
376 | auto src_md = dnn_mem_t::init_md( |
377 | prb->ndims, src_dims, prb->src_dt(), prb->stag, src_strides); |
378 | |
379 | // Create weights memory descriptor with VNNI-friendly format. |
380 | // Note: LDB is not passed here. This is because it's super difficult to |
381 | // incorporate stride on top of blocking - oneDNN API doesn't provide any |
382 | // calls to support both options together. Submemory descriptor, which is |
383 | // the only one who can create such memory desc, can't return the size of |
384 | // memory. Thus, it requires two memories and we need to pass a memory |
385 | // handle from bigger one (where LDB is an actual dim value) to smaller, but |
386 | // there's some reorder bug resulting in an error. |
387 | const auto wtag = prepare_wei_format_string( |
388 | prb->wei_dt(), prb->n, brgemm_desc.is_b_data_layout_vnni()); |
389 | BENCHDNN_PRINT(6, "wtag: %s\n" , wtag.c_str()); |
390 | auto wei_md = dnn_mem_t::init_md(prb->ndims, wei_dims, prb->wei_dt(), wtag); |
391 | |
392 | const size_t wei_offset_s8s8 = dnnl_memory_desc_get_size(wei_md); |
393 | // Prepare and assign extra for wei_md when s8s8 compensation, or source |
394 | // zero point reduction values are needed. |
395 | dnnl::impl::memory_extra_desc_t {}; |
396 | wei_md_extra.flags = dnnl::impl::memory_extra_flags::none; |
397 | if (prb->get_dt(SRC) == dnnl_s8 && prb->get_dt(WEI) == dnnl_s8) { |
398 | wei_md_extra.flags |
399 | |= dnnl::impl::memory_extra_flags::compensation_conv_s8s8; |
400 | wei_md_extra.compensation_mask = 2; // N dimension |
401 | } |
402 | static_cast<dnnl_memory_desc_t>(wei_md)->extra = wei_md_extra; |
403 | |
404 | const size_t wei_offset_zp = wei_offset_s8s8 |
405 | + (wei_md_extra.flags != dnnl::impl::memory_extra_flags::none |
406 | ? prb->get_ldb() * sizeof(int32_t) |
407 | : 0); |
408 | |
409 | const bool need_src_comp = !prb->attr.zero_points.is_def(DNNL_ARG_SRC); |
410 | if (need_src_comp) { |
411 | wei_md_extra.flags |= dnnl::impl::memory_extra_flags:: |
412 | compensation_conv_asymmetric_src; |
413 | wei_md_extra.asymm_compensation_mask = 2; // N dimension |
414 | } |
415 | static_cast<dnnl_memory_desc_t>(wei_md)->extra = wei_md_extra; |
416 | |
417 | // Same as dst_md but with a pre-defined data type according to doc. |
418 | auto acc_md = dnn_mem_t::init_md(prb->ndims, prb->dst_dims.data(), |
419 | prb->acc_dt(), tag::abx, acc_strides); |
420 | |
421 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> bia_md {}; |
422 | if (prb->bia_dt != dnnl_data_type_undef) { |
423 | const dnnl_dims_t bia_dims = {1, prb->n}; |
424 | bia_md = dnn_mem_t::init_md( |
425 | prb->ndims, bia_dims, prb->bia_dt, tag::abx); |
426 | } |
427 | |
428 | const auto &test_engine = get_test_engine(); |
429 | const auto &ref_engine = get_cpu_engine(); |
430 | |
431 | dnn_mem_t src_dt(src_md, test_engine); |
432 | dnn_mem_t wei_dt(wei_md, test_engine); |
433 | dnn_mem_t acc_dt(acc_md, test_engine); |
434 | dnn_mem_t dst_dt(dst_md, test_engine); |
435 | dnn_mem_t bia_dt; |
436 | if (prb->bia_dt != dnnl_data_type_undef) |
437 | bia_dt = dnn_mem_t(bia_md, test_engine); |
438 | |
439 | dnn_mem_t src_fp(src_md, dnnl_f32, tag::abx, ref_engine); |
440 | dnn_mem_t wei_fp(wei_md, dnnl_f32, tag::abx, ref_engine); |
441 | dnn_mem_t acc_fp(acc_md, dnnl_f32, tag::abx, ref_engine); |
442 | dnn_mem_t dst_fp(dst_md, dnnl_f32, tag::abx, ref_engine); |
443 | dnn_mem_t bia_fp; |
444 | if (prb->bia_dt != dnnl_data_type_undef) |
445 | bia_fp = dnn_mem_t(bia_md, dnnl_f32, tag::abx, ref_engine); |
446 | |
447 | SAFE(fill_data(SRC, prb, src_dt, src_fp, res), WARN); |
448 | SAFE(fill_data(WEI, prb, wei_dt, wei_fp, res), WARN); |
449 | const int sum_idx = prb->attr.post_ops.find(attr_t::post_ops_t::SUM); |
450 | if ((prb->beta != 0) || brgemm_attr.generate_skip_accumulation) { |
451 | SAFE(fill_data(DST, prb, acc_dt, acc_fp, res), WARN); |
452 | // Beta requires same values for reference and the kernel. |
453 | if (use_dst_as_acc) { |
454 | dst_fp.reorder(acc_fp); |
455 | dst_dt.reorder(dst_fp); |
456 | } |
457 | } |
458 | if (sum_idx >= 0) SAFE(fill_data(DST, prb, dst_dt, dst_fp, res), WARN); |
459 | if (prb->bia_dt != dnnl_data_type_undef) |
460 | SAFE(fill_data(BIA, prb, bia_dt, bia_fp, res), WARN); |
461 | |
462 | dnn_mem_t src_zero_points_m, wei_zero_points_m, dst_zero_points_m; |
463 | const auto &wei_zero_point_val |
464 | = prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).value; |
465 | maybe_prepare_runtime_zero_points( |
466 | src_zero_points_m, prb->attr, DNNL_ARG_SRC, prb->k, prb->src_zp); |
467 | maybe_prepare_runtime_zero_points(wei_zero_points_m, prb->attr, |
468 | DNNL_ARG_WEIGHTS, 1, &(wei_zero_point_val)); |
469 | maybe_prepare_runtime_zero_points( |
470 | dst_zero_points_m, prb->attr, DNNL_ARG_DST, prb->n, prb->dst_zp); |
471 | |
472 | // "Library" args are needed to get dst for comparison. |
473 | // "Reference" are used as usual. |
474 | args_t args, ref_args; |
475 | args.set(DNNL_ARG_DST, dst_dt); |
476 | |
477 | std::vector<brgemm_batch_element_t> v_batch_element(prb->batch_size); |
478 | const char *src_ptr = (const char *)src_dt; |
479 | const char *wei_ptr = (const char *)wei_dt; |
480 | // Note: batch_size is incorporated into K dimension. |
481 | // That's why each source batch has an offset of `k`. |
482 | // Weights have more complicated case. Weights are in double-blocked format, |
483 | // which becomes triple-blocked for bf16 and int8 to become VNNI-friendly. |
484 | // Because of this and batch_size incorporation, offsets below DO NOT work |
485 | // with K not divisible by K block size and batch_size > 1. |
486 | // The problem is it can't be handled properly when batch size is fused, |
487 | // but this allows enable s8s8 and zero-points compensation cases easier. |
488 | int block_size = 0; |
489 | switch (prb->wei_dt()) { |
490 | case dnnl_f32: block_size = 16; break; |
491 | case dnnl_f16: block_size = 16; break; |
492 | case dnnl_bf16: block_size = 32; break; |
493 | case dnnl_s8: block_size = 64; break; |
494 | default: break; |
495 | } |
496 | (void)block_size; |
497 | assert(block_size > 1); |
498 | assert(IMPLICATION(prb->batch_size > 1, prb->k % block_size == 0)); |
499 | |
500 | const int64_t src_batch_offset = prb->k; |
501 | const int64_t wei_batch_offset = prb->get_ldb() * prb->k; |
502 | BENCHDNN_PRINT(6, "src_batch_offset=%ld wei_batch_offset=%ld\n" , |
503 | (long)src_batch_offset, (long)wei_batch_offset); |
504 | |
505 | for (size_t i = 0; i < v_batch_element.size(); i++) { |
506 | v_batch_element[i].ptr.A |
507 | = src_ptr + i * src_batch_offset * src_dt.sizeof_dt(); |
508 | v_batch_element[i].ptr.B |
509 | = wei_ptr + i * wei_batch_offset * wei_dt.sizeof_dt(); |
510 | } |
511 | |
512 | // Brgemm takes single pointer oscale, but relies on a combination of arg |
513 | // scales attributes. This helps to reuse attributes from primitives, but |
514 | // requires them to pre-compute oscale = src_scale * wei_scale[:] |
515 | dnn_mem_t scales; |
516 | auto src_scale = prb->attr.scales.get(DNNL_ARG_SRC); |
517 | auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS); |
518 | auto attr_scale = wei_scale.runtime ? wei_scale : src_scale; |
519 | maybe_prepare_runtime_scales(scales, attr_scale, prb->n, prb->scales); |
520 | // Handle output scale common policy separately since the implementation |
521 | // always expects them to be of vector length in case of `common` policy. |
522 | std::vector<float> v16_scales(16, prb->scales[0]); |
523 | const float *scales_ptr = attr_scale.policy == policy_t::COMMON |
524 | ? v16_scales.data() |
525 | : (const float *)scales; |
526 | |
527 | char *acc_ptr = (char *)acc_dt; |
528 | |
529 | const int32_t *dst_zp_ptr = (const int32_t *)dst_zero_points_m; |
530 | char *src_comp_ptr = (char *)wei_dt + wei_offset_zp; |
531 | int32_t zp_a_val = !prb->attr.zero_points.is_def(DNNL_ARG_SRC) |
532 | ? src_zero_points_m.get_elem(0) |
533 | : 0; |
534 | |
535 | if (!prb->attr.zero_points.is_def(DNNL_ARG_WEIGHTS)) { |
536 | // TODO: weights zero point is not supported yet. |
537 | // It requires enabling f32 -> u8 reorder with compensation on the |
538 | // library side. When enabled, it produces incorrect results for cases |
539 | // with K=1. Likely there's a bug inside. Postpone supporting it. |
540 | return res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED, OK; |
541 | } |
542 | |
543 | if (prb->attr.post_ops.binary_index() >= 0) { |
544 | // TODO: binary post-op is not supported yet. |
545 | return res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED, OK; |
546 | } |
547 | |
548 | brgemm_post_ops_data_t post_ops_data( |
549 | /* bias */ (const char *)bia_dt, |
550 | /* scales */ scales_ptr, /* binary_post_ops_rhs */ nullptr, |
551 | /* oc_logical_off */ 0, /* dst_row_logical_off */ 0, |
552 | /* data_C_ptr_ */ acc_ptr, /* first_mb_matrix_addr_off */ 0, |
553 | /* a_zp_compensations */ src_comp_ptr, |
554 | /* b_zp_compensations */ nullptr, |
555 | /* c_zp_values */ dst_zp_ptr, |
556 | /* skip_accumulation */ brgemm_attr.generate_skip_accumulation, |
557 | /* zp_a_val */ zp_a_val, |
558 | /* do_only_comp */ false, |
559 | /* do_only_zp_a_val */ false); |
560 | |
561 | auto scratchpad_size = brgemm_desc.get_wsp_buffer_size(); |
562 | std::vector<char> scratchpad(scratchpad_size); |
563 | // Note: hardware lacking native s8s8 support expects compensation buffer |
564 | // passed through a scratchpad argument in postops execution call. |
565 | const bool need_hidden_compensation = scratchpad_size == 0 |
566 | && prb->get_dt(SRC) == dnnl_s8 && prb->get_dt(WEI) == dnnl_s8; |
567 | char *scratchpad_ptr = need_hidden_compensation |
568 | ? ((char *)wei_dt + wei_offset_s8s8) |
569 | : scratchpad.data(); |
570 | |
571 | char *dst_ptr = (char *)dst_dt; |
572 | if (use_dst_as_acc) acc_ptr = dst_ptr; |
573 | |
574 | brgemm_kernel_execute_postops(brgemm_kernel, prb->batch_size, |
575 | v_batch_element.data(), acc_ptr, dst_ptr, post_ops_data, |
576 | scratchpad_ptr); |
577 | if (res) res->state = EXECUTED; |
578 | |
579 | if (is_bench_mode(CORR)) { |
580 | ref_args.set(DNNL_ARG_SRC, src_fp); |
581 | ref_args.set(DNNL_ARG_WEIGHTS, wei_fp); |
582 | if (prb->bia_dt != dnnl_data_type_undef) |
583 | ref_args.set(DNNL_ARG_BIAS, bia_fp); |
584 | ref_args.set(DNNL_ARG_DST, dst_fp); |
585 | // Passing accumulator values for `generate_skip_accumulation` check. |
586 | ref_args.set(DNNL_ARG_SRC_1, acc_fp); |
587 | // A hack to pass brgemm attributes to reference since some members |
588 | // change the computation flow for correctness validation. |
589 | dnn_mem_t workspace(src_md, ref_engine, {false, (void *)&brgemm_attr}); |
590 | ref_args.set(DNNL_ARG_WORKSPACE, workspace); |
591 | |
592 | check_correctness(prb, {DST}, args, ref_args, setup_cmp, res); |
593 | } |
594 | |
595 | // Create a bind to match internals to run performance measurements. |
596 | perf_function_t perf_func = std::bind(brgemm_kernel_execute_postops_wrapper, |
597 | brgemm_kernel_, prb->batch_size, v_batch_element.data(), acc_ptr, |
598 | dst_ptr, post_ops_data, scratchpad_ptr, std::placeholders::_1, |
599 | std::placeholders::_2); |
600 | measure_perf(prb->ctx_exe, res, perf_func, args); |
601 | |
602 | if (is_tmm) DNN_SAFE(amx_tile_release(), WARN); |
603 | |
604 | return OK; |
605 | } |
606 | |
607 | #else |
608 | |
// Stub for builds without the x64 CPU BRGeMM API (non-x64 targets or
// DNNL_RUNTIME_NONE): nothing to run, report success.
int doit(const prb_t *prb, res_t *res) {
    return OK;
}
612 | |
613 | #endif |
614 | |
615 | } // namespace brgemm |
616 | |