1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <random> |
20 | #include <stdio.h> |
21 | #include <stdlib.h> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "utils/parallel.hpp" |
26 | |
27 | #include "dnnl_common.hpp" |
28 | #include "dnnl_memory.hpp" |
29 | |
30 | #include "binary/binary.hpp" |
31 | #include "matmul/matmul.hpp" |
32 | |
33 | namespace matmul { |
34 | |
35 | void prep_bia_dims(const prb_t *prb, dims_t &bia_dims) { |
36 | bia_dims.resize(prb->ndims); |
37 | for (int d = 0; d < prb->ndims; ++d) |
38 | bia_dims[d] = (prb->bia_mask & (1 << d)) ? prb->dst_dims[d] : 1; |
39 | } |
40 | |
41 | dims_t get_runtime_dims(const dims_t &dims, const dims_mask_t &mask) { |
42 | if (mask.none() || dims.empty()) return dims; |
43 | dims_t runtime_dims; |
44 | runtime_dims.resize(dims.size()); |
45 | for (size_t i = 0; i < dims.size(); ++i) { |
46 | runtime_dims[i] = mask[i] ? DNNL_RUNTIME_DIM_VAL : dims[i]; |
47 | } |
48 | return runtime_dims; |
49 | } |
50 | |
51 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
52 | const prb_t *prb = init_pd_args.prb; |
53 | |
54 | const auto &src_rt_dims |
55 | = get_runtime_dims(prb->src_dims(), prb->src_runtime_dim_mask()); |
56 | const auto &weights_rt_dims = get_runtime_dims( |
57 | prb->weights_dims(), prb->weights_runtime_dim_mask()); |
58 | const auto &dst_rt_dims |
59 | = get_runtime_dims(prb->dst_dims, prb->dst_runtime_dim_mask()); |
60 | |
61 | auto src_d = dnn_mem_t::init_md(prb->ndims, src_rt_dims.data(), |
62 | prb->src_dt(), prb->stag, prb->strides[STRIDES_SRC]); |
63 | auto wei_d = dnn_mem_t::init_md(prb->ndims, weights_rt_dims.data(), |
64 | prb->wei_dt(), prb->wtag, prb->strides[STRIDES_WEI]); |
65 | auto dst_d = dnn_mem_t::init_md(prb->ndims, dst_rt_dims.data(), |
66 | prb->dst_dt(), prb->dtag, prb->strides[STRIDES_DST]); |
67 | |
68 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> bia_d {}; |
69 | if (prb->bia_dt != dnnl_data_type_undef) { |
70 | dims_t bia_dims; |
71 | prep_bia_dims(prb, bia_dims); |
72 | bia_dims = get_runtime_dims(bia_dims, prb->dst_runtime_dim_mask()); |
73 | bia_d = dnn_mem_t::init_md(prb->ndims, bia_dims.data(), prb->bia_dt, |
74 | prb->dst_runtime_dim_mask() != 0 ? tag::abx : tag::any); |
75 | } |
76 | |
77 | attr_args_t attr_args; |
78 | attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dst_dims.data()); |
79 | // Overload PER_OC wei_mask definition for batched case |
80 | auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS); |
81 | if (wei_scale.policy == policy_t::PER_OC) { |
82 | int wei_mask = (1 << (dst_rt_dims.size() - 1)); |
83 | attr_args.prepare_scales( |
84 | prb->attr, DNNL_ARG_WEIGHTS, prb->wei_scales, prb->n, wei_mask); |
85 | } |
86 | auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
87 | create_dnnl_attr(prb->attr, attr_args)); |
88 | |
89 | DNN_SAFE_STATUS(dnnl_matmul_primitive_desc_create(&init_pd_args.pd, |
90 | init_pd_args.engine, src_d, wei_d, bia_d, dst_d, dnnl_attr)); |
91 | |
92 | return dnnl_success; |
93 | } |
94 | |
95 | int init_prim_ref( |
96 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) { |
97 | if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK; |
98 | |
99 | // Create a new copy of prb to avoid potentially corrupting the test by |
100 | // modifying prb in place. |
101 | const auto cpu_bia_dt = prb->bia_dt == dnnl_data_type_undef |
102 | ? dnnl_data_type_undef |
103 | : dnnl_f32; |
104 | const auto cpu_bia_mask |
105 | = prb->bia_dt == dnnl_data_type_undef ? 0 : prb->bia_mask; |
106 | auto cpu_attr = prb->attr; |
107 | update_cpu_ref_attrs(cpu_attr); |
108 | prb_t prb_cpu {*prb, {dnnl_f32}, tag::abx, tag::abx, tag::abx, |
109 | {vdims_t(STRIDES_SIZE)}, cpu_bia_dt, cpu_bia_mask, {0, 0, 0}, |
110 | cpu_attr, prb->ctx_init, prb->ctx_exe}; |
111 | |
112 | init_pd_args_t<prb_t> init_pd_args( |
113 | /* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir, |
114 | /* hint = */ nullptr); |
115 | init_pd(init_pd_args); |
116 | |
117 | benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw; |
118 | fetch_impl(pdw, init_pd_args, /* res = */ nullptr, |
119 | /* is_service_prim = */ true); |
120 | |
121 | dnnl_primitive_t prim_ref_ {}; |
122 | if (pdw) { |
123 | if (query_impl_info(pdw) == "ref:any" ) return OK; |
124 | DNN_SAFE(dnnl_primitive_create(&prim_ref_, pdw), WARN); |
125 | BENCHDNN_PRINT(5, "CPU reference oneDNN implementation: %s\n" , |
126 | query_impl_info(pdw).c_str()); |
127 | } |
128 | prim_ref.reset(prim_ref_); |
129 | return OK; |
130 | } |
131 | |
// Fills `mem_fp` (f32 reference memory) with deterministic pseudo-random
// integer values for tensor `kind`, then reorders them into `mem_dt` (the
// memory used by the tested primitive).
//
// Values are drawn uniformly from
// [cfg.get_range_min(kind), cfg.get_range_max(kind)]. Each value is zeroed
// with probability (1 - density) and rounded to the nearest value
// representable in cfg.get_dt(kind). The very first element is forced to be
// positive. `res` is not used. Returns OK unless the SAFE(..., WARN) guard
// around the reorder fires.
int fill_data(data_kind_t kind, const prb_t *prb, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, res_t *res) {

    const auto nelems = mem_dt.nelems();
    if (nelems == 0) return OK;

    assert(mem_dt.nelems() == mem_fp.nelems());

    cfg_t cfg(prb, {SRC, WEI, BIA, DST});
    cfg_t::density_args_t density_args;
    density_args.data_kind = kind;
    // Reduction length K; presumably lets cfg pick a density that keeps the
    // accumulated result representable -- TODO confirm in cfg_t.
    density_args.n_acc = prb->k;
    const auto density = cfg.get_density(density_args);

    /* Do fixed partitioning to have same filling for any number of threads */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note: we use a different seed for each chunk to avoid
        // repeating patterns. We could use discard(idx_start) too but
        // it has a complexity in O(idx_start). We also add 1 to avoid
        // seeding with 0.
        // The value engine and the mask engine share a seed; the different
        // discard counts offset one stream from the other.
        std::minstd_rand int_seed(kind * nelems + idx_start + 1);
        int_seed.discard(1);
        std::minstd_rand b_seed(kind * nelems + idx_start + 1);
        b_seed.discard(10);

        std::uniform_int_distribution<> gen(
                cfg.get_range_min(kind), cfg.get_range_max(kind));
        std::bernoulli_distribution b_dist(density);

        // make sure the first element is positive
        // NOTE(review): this loops forever if cfg.get_range_max(kind) <= 0;
        // presumably cfg guarantees a positive maximum -- confirm.
        if (idx_start == 0) {
            float val = 0;
            while (val <= 0)
                val = gen(int_seed);
            mem_fp.set_elem(
                    0, round_to_nearest_representable(cfg.get_dt(kind), val));
            idx_start += 1;
        }

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            // For dense data the Bernoulli draw is skipped entirely.
            bool is_one = density == 1.f ? true : b_dist(b_seed);
            // gen() is drawn even when is_one is false, so the value stream
            // does not depend on which elements get zeroed.
            float val = is_one * gen(int_seed);
            mem_fp.set_elem(
                    idx, round_to_nearest_representable(cfg.get_dt(kind), val));
        }
    });

    // For DST, cfg may fill with a data type different from the original one;
    // mem_dt is temporarily switched to the fill type for the reorder and
    // switched back afterwards.
    // NOTE(review): if SAFE returns early on failure, the original data type
    // is not restored -- confirm this is acceptable.
    const bool swap_dt
            = kind == DST && cfg.get_orig_dt(kind) != cfg.get_dt(kind);
    if (swap_dt) mem_dt.set_dt(cfg.get_dt(kind));
    SAFE(mem_dt.reorder(mem_fp), WARN);
    if (swap_dt) mem_dt.set_dt(cfg.get_orig_dt(kind));

    return OK;
}
192 | |
193 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
194 | skip_unimplemented_data_type( |
195 | {prb->src_dt(), prb->wei_dt(), prb->bia_dt, prb->dst_dt()}, |
196 | prb->dir, res); |
197 | skip_unimplemented_sum_po(prb->attr, res, prb->dst_dt()); |
198 | |
199 | if (is_gpu()) { |
200 | // GPU supports only single zero-point per tensor. |
201 | if (prb->attr.zero_points.get(DNNL_ARG_SRC).policy != policy_t::COMMON |
202 | || prb->attr.zero_points.get(DNNL_ARG_DST).policy |
203 | != policy_t::COMMON) { |
204 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
205 | return; |
206 | } |
207 | |
208 | // GPU supports only default sum_dt argument. |
209 | const auto &po = prb->attr.post_ops; |
210 | const int sum_idx = po.find(attr_t::post_ops_t::kind_t::SUM); |
211 | if (sum_idx != -1 && po.entry[sum_idx].sum.dt != dnnl_data_type_undef) { |
212 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
213 | return; |
214 | } |
215 | |
216 | // GPU for x8s8bf16 doesn't support: |
217 | // * Destination zero-point. |
218 | // * Any run-time dimensions. |
219 | // * Any batch dimensions. |
220 | const bool is_x8s8bf16 |
221 | = prb->wei_dt() == dnnl_s8 && prb->dst_dt() == dnnl_bf16; |
222 | const bool rt_dims_are_none = prb->src_runtime_dim_mask().none() |
223 | && prb->weights_runtime_dim_mask().none() |
224 | && prb->dst_runtime_dim_mask().none(); |
225 | const bool x8s8bf16_ok = IMPLICATION(is_x8s8bf16, |
226 | prb->attr.zero_points.get(DNNL_ARG_DST).is_def() |
227 | && rt_dims_are_none && prb->ndims <= 2); |
228 | if (!x8s8bf16_ok) { |
229 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
230 | return; |
231 | } |
232 | |
233 | // GPU supports bf16 bias only for bf16 config, with a single batch dim. |
234 | const bool is_bf16 = prb->src_dt() == dnnl_bf16 |
235 | && prb->wei_dt() == dnnl_bf16 |
236 | && (prb->dst_dt() == dnnl_bf16 || prb->dst_dt() == dnnl_f32); |
237 | const bool bf16_bias_ok = IMPLICATION( |
238 | prb->bia_dt == dnnl_bf16, prb->ndims <= 2 + is_bf16); |
239 | if (!bf16_bias_ok) { |
240 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
241 | return; |
242 | } |
243 | } |
244 | } |
245 | |
246 | void skip_invalid_prb(const prb_t *prb, res_t *res) { |
247 | // Zero-points for non-integral data type does not make sense |
248 | if (!prb->attr.zero_points.is_def() && prb->wei_dt() != dnnl_s8) { |
249 | res->state = SKIPPED, res->reason = INVALID_CASE; |
250 | return; |
251 | } |
252 | |
253 | auto src_rt_mask = prb->src_runtime_dim_mask(); |
254 | auto wei_rt_mask = prb->weights_runtime_dim_mask(); |
255 | auto dst_rt_mask = prb->dst_runtime_dim_mask(); |
256 | |
257 | // Memory layouts must be defined when some dimensions are unknown at pd |
258 | // creation time. |
259 | if ((src_rt_mask.any() && prb->stag == "any" ) |
260 | || (wei_rt_mask.any() && prb->wtag == "any" ) |
261 | || (dst_rt_mask.any() && prb->dtag == "any" )) { |
262 | res->state = SKIPPED, res->reason = INVALID_CASE; |
263 | return; |
264 | } |
265 | |
266 | // Runtime masks for `m`, `k`, and `n` dimensions must be consistent. |
267 | const int m_idx = prb->ndims - 2; |
268 | const int k_idx_src = prb->ndims - 1; |
269 | const int k_idx_wei = prb->ndims - 2; |
270 | const int n_idx = prb->ndims - 1; |
271 | if (src_rt_mask[m_idx] != dst_rt_mask[m_idx] |
272 | || src_rt_mask[k_idx_src] != wei_rt_mask[k_idx_wei] |
273 | || wei_rt_mask[n_idx] != dst_rt_mask[n_idx]) { |
274 | res->state = SKIPPED, res->reason = INVALID_CASE; |
275 | return; |
276 | } |
277 | |
278 | // Runtime masks for batch dimensions must be consistent. |
279 | if (prb->ndims > 2) { |
280 | dims_mask_t batch_rt_mask; |
281 | for (int i = 0; i < prb->ndims - 2; ++i) |
282 | batch_rt_mask[i] = true; |
283 | src_rt_mask &= batch_rt_mask; |
284 | wei_rt_mask &= batch_rt_mask; |
285 | dst_rt_mask &= batch_rt_mask; |
286 | if (src_rt_mask != wei_rt_mask || src_rt_mask != dst_rt_mask) { |
287 | res->state = SKIPPED, res->reason = INVALID_CASE; |
288 | return; |
289 | } |
290 | } |
291 | } |
292 | |
293 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
294 | const args_t &ref_args) { |
295 | const auto dt = prb->get_dt(kind); |
296 | const float trh = dt == dnnl_f32 ? 1e-6f : epsilon_dt(dt); |
297 | cmp.set_threshold(trh); |
298 | cmp.set_zero_trust_percent(90.f); // TODO: why so bad filling? |
299 | } |
300 | |
// Runs one matmul test case. The steps are, in order:
//  1. create the primitive;
//  2. build memory descriptors, recreating those that had runtime dims;
//  3. allocate the tested (`*_dt`, test engine) and f32 reference
//     (`*_fp`, CPU engine) memories;
//  4. fill the inputs and prepare scales, zero-points and binary post-ops;
//  5. execute the primitive;
//  6. in correctness mode, compare DST against the reference;
//  7. return the result of measure_perf().
// Returns OK early when the case is only listed, or is skipped/unimplemented.
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    // `def_md` stays default-constructed; it is the "not queried" marker.
    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> src_md {}, wei_md {}, dst_md {},
            bia_md {}, def_md {};
    // query md if it was defined at pd creation time
    if (prb->src_runtime_dim_mask().none())
        src_md.reset(clone_md(query_md(const_pd, DNNL_ARG_SRC)));
    if (prb->weights_runtime_dim_mask().none())
        wei_md.reset(clone_md(query_md(const_pd, DNNL_ARG_WEIGHTS)));
    // The bias was created with the dst runtime mask in init_pd, so it is
    // queried under the same condition as dst.
    if (prb->dst_runtime_dim_mask().none()) {
        dst_md.reset(clone_md(query_md(const_pd, DNNL_ARG_DST)));
        if (prb->bia_dt != dnnl_data_type_undef)
            bia_md.reset(clone_md(query_md(const_pd, DNNL_ARG_BIAS)));
    }

    // if md is same as default, it means we need to re-create it
    // (with the actual dimensions instead of DNNL_RUNTIME_DIM_VAL).
    const auto &src_dims = prb->src_dims();
    if (dnnl_memory_desc_equal(src_md, def_md)) {
        assert(prb->stag != tag::any);
        src_md = dnn_mem_t::init_md(prb->ndims, src_dims.data(), prb->src_dt(),
                prb->stag, prb->strides[STRIDES_SRC]);
    }

    const auto &weights_dims = prb->weights_dims();
    if (dnnl_memory_desc_equal(wei_md, def_md)) {
        assert(prb->wtag != tag::any);
        wei_md = dnn_mem_t::init_md(prb->ndims, weights_dims.data(),
                prb->wei_dt(), prb->wtag, prb->strides[STRIDES_WEI]);
    }

    if (dnnl_memory_desc_equal(dst_md, def_md)) {
        assert(prb->dtag != tag::any);
        dst_md = dnn_mem_t::init_md(prb->ndims, prb->dst_dims.data(),
                prb->dst_dt(), prb->dtag, prb->strides[STRIDES_DST]);
    }
    // Recreated bias uses tag::abx, as init_pd requests for runtime dims.
    if (prb->bia_dt != dnnl_data_type_undef
            && dnnl_memory_desc_equal(bia_md, def_md)) {
        dims_t bia_dims;
        prep_bia_dims(prb, bia_dims);
        bia_md = dnn_mem_t::init_md(
                prb->ndims, bia_dims.data(), prb->bia_dt, tag::abx);
    }

    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    // Use CPU prim as the reference in GPU testing to reduce testing time.
    // prim_ref stays empty when no such reference is available.
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
    SAFE(init_prim_ref(prim_ref, prb), WARN);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    // Memories for the tested primitive.
    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t bia_dt;
    if (prb->bia_dt != dnnl_data_type_undef)
        bia_dt = dnn_mem_t(bia_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    // f32 reference memories in plain abx layout on the CPU engine.
    const auto fp = dnnl_f32;
    dnn_mem_t src_fp(src_md, fp, tag::abx, ref_engine);
    dnn_mem_t wei_fp(wei_md, fp, tag::abx, ref_engine);
    dnn_mem_t dst_fp(dst_md, fp, tag::abx, ref_engine);
    dnn_mem_t bia_fp;
    if (prb->bia_dt != dnnl_data_type_undef)
        bia_fp = dnn_mem_t(bia_md, fp, tag::abx, ref_engine);
    // Scratchpad is needed only if the CPU reference primitive exists.
    dnn_mem_t scratchpad_fp;
    if (prim_ref)
        scratchpad_fp = dnn_mem_t(
                query_md(query_pd(prim_ref), DNNL_ARG_SCRATCHPAD), ref_engine);

    SAFE(fill_data(SRC, prb, src_dt, src_fp, res), WARN);
    SAFE(fill_data(WEI, prb, wei_dt, wei_fp, res), WARN);
    // DST is an input only when a sum post-op accumulates into it.
    const int sum_idx = prb->attr.post_ops.find(attr_t::post_ops_t::SUM);
    if (sum_idx >= 0) SAFE(fill_data(DST, prb, dst_dt, dst_fp, res), WARN);
    if (prb->bia_dt != dnnl_data_type_undef)
        SAFE(fill_data(BIA, prb, bia_dt, bia_fp, res), WARN);

    // Runtime scales: src sized by K, weights and dst sized by N.
    dnn_mem_t src_scales_fp, src_scales_dt, wei_scales_fp, wei_scales_dt,
            dst_scales_fp, dst_scales_dt;
    maybe_prepare_runtime_scales_v2(src_scales_dt, src_scales_fp,
            prb->attr.scales.get(DNNL_ARG_SRC), prb->k, prb->src_scales);
    maybe_prepare_runtime_scales_v2(wei_scales_dt, wei_scales_fp,
            prb->attr.scales.get(DNNL_ARG_WEIGHTS), prb->n, prb->wei_scales);
    maybe_prepare_runtime_scales_v2(dst_scales_dt, dst_scales_fp,
            prb->attr.scales.get(DNNL_ARG_DST), prb->n, prb->dst_scales);

    // Runtime zero-points: src sized by K, dst by N, weights by a single
    // value taken from the attribute.
    dnn_mem_t src_zp_dt, src_zp_fp, wei_zp_dt, wei_zp_fp, dst_zp_dt, dst_zp_fp;
    const auto &wei_zero_point_val
            = prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).value;
    maybe_prepare_runtime_zero_points_v2(
            src_zp_dt, src_zp_fp, prb->attr, DNNL_ARG_SRC, prb->k, prb->src_zp);
    maybe_prepare_runtime_zero_points_v2(wei_zp_dt, wei_zp_fp, prb->attr,
            DNNL_ARG_WEIGHTS, 1, &(wei_zero_point_val));
    maybe_prepare_runtime_zero_points_v2(
            dst_zp_dt, dst_zp_fp, prb->attr, DNNL_ARG_DST, prb->n, prb->dst_zp);

    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(const_pd, binary_po_args, binary_po_dt,
                 binary_po_fp, /*only_positive=*/false, /*only_integer=*/true),
            WARN);

    args_t args, ref_args;

    args.set(DNNL_ARG_SRC, src_dt);
    args.set(DNNL_ARG_WEIGHTS, wei_dt);
    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_BIAS, bia_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);
    args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_dt);
    args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, wei_zp_dt);
    args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_dt);
    args.set(binary_po_args, binary_po_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    if (is_bench_mode(CORR)) {
        ref_args.set(DNNL_ARG_SRC, src_fp);
        ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
        ref_args.set(DNNL_ARG_BIAS, bia_fp);
        ref_args.set(DNNL_ARG_DST, dst_fp);
        ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp);
        ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_fp);
        ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp);
        ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_fp);
        ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, wei_zp_fp);
        ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_fp);
        ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
        ref_args.set(binary_po_args, binary_po_fp);

        // Only the DST output is compared; prim_ref (possibly empty) is
        // handed over as the fast CPU reference.
        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res, prim_ref);
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}
448 | |
449 | } // namespace matmul |
450 | |