/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <float.h>
#include <math.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "binary/binary.hpp"
#include "matmul/matmul.hpp"

namespace matmul {

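// Derive bias dimensions from the destination dimensions: a dimension is
// taken from dst when the corresponding bit in `bia_mask` is set, and
// broadcast (set to 1) otherwise.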
void prep_bia_dims(const prb_t *prb, dims_t &bia_dims) {
    bia_dims.resize(prb->ndims);
    for (int d = 0; d < prb->ndims; ++d)
        bia_dims[d] = (prb->bia_mask & (1 << d)) ? prb->dst_dims[d] : 1;
}

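// Return a copy of `dims` where every dimension selected by `mask` is
// replaced with DNNL_RUNTIME_DIM_VAL, so that it stays unknown until
// execution time.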
dims_t get_runtime_dims(const dims_t &dims, const dims_mask_t &mask) {
    if (mask.none() || dims.empty()) return dims;
    dims_t runtime_dims;
    runtime_dims.resize(dims.size());
    for (size_t i = 0; i < dims.size(); ++i) {
        runtime_dims[i] = mask[i] ? DNNL_RUNTIME_DIM_VAL : dims[i];
    }
    return runtime_dims;
}

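// Build source, weights, destination (and optional bias) memory descriptors,
// prepare primitive attributes, and create the matmul primitive descriptor.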
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    const auto &src_rt_dims
            = get_runtime_dims(prb->src_dims(), prb->src_runtime_dim_mask());
    const auto &weights_rt_dims = get_runtime_dims(
            prb->weights_dims(), prb->weights_runtime_dim_mask());
    const auto &dst_rt_dims
            = get_runtime_dims(prb->dst_dims, prb->dst_runtime_dim_mask());

    auto src_d = dnn_mem_t::init_md(prb->ndims, src_rt_dims.data(),
            prb->src_dt(), prb->stag, prb->strides[STRIDES_SRC]);
    auto wei_d = dnn_mem_t::init_md(prb->ndims, weights_rt_dims.data(),
            prb->wei_dt(), prb->wtag, prb->strides[STRIDES_WEI]);
    auto dst_d = dnn_mem_t::init_md(prb->ndims, dst_rt_dims.data(),
            prb->dst_dt(), prb->dtag, prb->strides[STRIDES_DST]);

    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> bia_d {};
    if (prb->bia_dt != dnnl_data_type_undef) {
        dims_t bia_dims;
        prep_bia_dims(prb, bia_dims);
        bia_dims = get_runtime_dims(bia_dims, prb->dst_runtime_dim_mask());
        bia_d = dnn_mem_t::init_md(prb->ndims, bia_dims.data(), prb->bia_dt,
                prb->dst_runtime_dim_mask() != 0 ? tag::abx : tag::any);
    }

    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dst_dims.data());
    // Overload PER_OC wei_mask definition for batched case
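    // For matmul, the output channel of weights corresponds to the last (N)
    // dimension of dst, hence the mask sets the bit of the last dimension.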
    auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
    if (wei_scale.policy == policy_t::PER_OC) {
        int wei_mask = (1 << (dst_rt_dims.size() - 1));
        attr_args.prepare_scales(
                prb->attr, DNNL_ARG_WEIGHTS, prb->wei_scales, prb->n, wei_mask);
    }
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    DNN_SAFE_STATUS(dnnl_matmul_primitive_desc_create(&init_pd_args.pd,
            init_pd_args.engine, src_d, wei_d, bia_d, dst_d, dnnl_attr));

    return dnnl_success;
}

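// Create a CPU f32 oneDNN primitive to serve as the reference in GPU
// correctness runs. No primitive is created when the only available CPU
// implementation is the reference one ("ref:any").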
int init_prim_ref(
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) {
    if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK;

    // Create a new copy of prb to avoid potentially corrupting the test by
    // modifying prb in place.
    const auto cpu_bia_dt = prb->bia_dt == dnnl_data_type_undef
            ? dnnl_data_type_undef
            : dnnl_f32;
    const auto cpu_bia_mask
            = prb->bia_dt == dnnl_data_type_undef ? 0 : prb->bia_mask;
    auto cpu_attr = prb->attr;
    update_cpu_ref_attrs(cpu_attr);
    prb_t prb_cpu {*prb, {dnnl_f32}, tag::abx, tag::abx, tag::abx,
            {vdims_t(STRIDES_SIZE)}, cpu_bia_dt, cpu_bia_mask, {0, 0, 0},
            cpu_attr, prb->ctx_init, prb->ctx_exe};

    init_pd_args_t<prb_t> init_pd_args(
            /* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir,
            /* hint = */ nullptr);
    init_pd(init_pd_args);

    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;
    fetch_impl(pdw, init_pd_args, /* res = */ nullptr,
            /* is_service_prim = */ true);

    dnnl_primitive_t prim_ref_ {};
    if (pdw) {
        if (query_impl_info(pdw) == "ref:any") return OK;
        DNN_SAFE(dnnl_primitive_create(&prim_ref_, pdw), WARN);
        BENCHDNN_PRINT(5, "CPU reference oneDNN implementation: %s\n",
                query_impl_info(pdw).c_str());
    }
    prim_ref.reset(prim_ref_);
    return OK;
}

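// Fill a tensor with random integer values from the config-provided range.
// Elements are zeroed out with probability (1 - density), where the density
// accounts for `k` accumulated elements, which helps keep the accumulated
// result representable in the destination data type.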
int fill_data(data_kind_t kind, const prb_t *prb, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, res_t *res) {

    const auto nelems = mem_dt.nelems();
    if (nelems == 0) return OK;

    assert(mem_dt.nelems() == mem_fp.nelems());

    cfg_t cfg(prb, {SRC, WEI, BIA, DST});
    cfg_t::density_args_t density_args;
    density_args.data_kind = kind;
    density_args.n_acc = prb->k;
    const auto density = cfg.get_density(density_args);

    /* Do fixed partitioning to have the same filling for any thread count. */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note: we use a different seed for each chunk to avoid repeating
        // patterns. We could use discard(idx_start) instead, but its
        // complexity is O(idx_start). We also add 1 to avoid seeding with 0.
        std::minstd_rand int_seed(kind * nelems + idx_start + 1);
        int_seed.discard(1);
        std::minstd_rand b_seed(kind * nelems + idx_start + 1);
        b_seed.discard(10);

        std::uniform_int_distribution<> gen(
                cfg.get_range_min(kind), cfg.get_range_max(kind));
        std::bernoulli_distribution b_dist(density);

        // Make sure the first element is positive.
        if (idx_start == 0) {
            float val = 0;
            while (val <= 0)
                val = gen(int_seed);
            mem_fp.set_elem(
                    0, round_to_nearest_representable(cfg.get_dt(kind), val));
            idx_start += 1;
        }

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            bool is_one = density == 1.f ? true : b_dist(b_seed);
            float val = is_one * gen(int_seed);
            mem_fp.set_elem(
                    idx, round_to_nearest_representable(cfg.get_dt(kind), val));
        }
    });

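    // When the config fills DST in a data type different from its original
    // one, temporarily switch `mem_dt` to the filling data type for the
    // reorder, then restore the original data type.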
    const bool swap_dt
            = kind == DST && cfg.get_orig_dt(kind) != cfg.get_dt(kind);
    if (swap_dt) mem_dt.set_dt(cfg.get_dt(kind));
    SAFE(mem_dt.reorder(mem_fp), WARN);
    if (swap_dt) mem_dt.set_dt(cfg.get_orig_dt(kind));

    return OK;
}

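// Skip problems that hit known implementation limitations (mostly on GPU)
// rather than genuine failures.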
void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type(
            {prb->src_dt(), prb->wei_dt(), prb->bia_dt, prb->dst_dt()},
            prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res, prb->dst_dt());

    if (is_gpu()) {
        // GPU supports only single zero-point per tensor.
        if (prb->attr.zero_points.get(DNNL_ARG_SRC).policy != policy_t::COMMON
                || prb->attr.zero_points.get(DNNL_ARG_DST).policy
                        != policy_t::COMMON) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }

        // GPU supports only default sum_dt argument.
        const auto &po = prb->attr.post_ops;
        const int sum_idx = po.find(attr_t::post_ops_t::kind_t::SUM);
        if (sum_idx != -1 && po.entry[sum_idx].sum.dt != dnnl_data_type_undef) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }

        // GPU for x8s8bf16 doesn't support:
        // * Destination zero-point.
        // * Any run-time dimensions.
        // * Any batch dimensions.
        const bool is_x8s8bf16
                = prb->wei_dt() == dnnl_s8 && prb->dst_dt() == dnnl_bf16;
        const bool rt_dims_are_none = prb->src_runtime_dim_mask().none()
                && prb->weights_runtime_dim_mask().none()
                && prb->dst_runtime_dim_mask().none();
        const bool x8s8bf16_ok = IMPLICATION(is_x8s8bf16,
                prb->attr.zero_points.get(DNNL_ARG_DST).is_def()
                        && rt_dims_are_none && prb->ndims <= 2);
        if (!x8s8bf16_ok) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }

        // GPU supports bf16 bias only for bf16 config, with a single batch
        // dim.
        const bool is_bf16 = prb->src_dt() == dnnl_bf16
                && prb->wei_dt() == dnnl_bf16
                && (prb->dst_dt() == dnnl_bf16 || prb->dst_dt() == dnnl_f32);
        const bool bf16_bias_ok = IMPLICATION(
                prb->bia_dt == dnnl_bf16, prb->ndims <= 2 + is_bf16);
        if (!bf16_bias_ok) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }
}

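// Skip problems that are invalid by definition (inconsistent runtime masks,
// undefined layouts with runtime dims, etc.), independently of the
// implementation.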
void skip_invalid_prb(const prb_t *prb, res_t *res) {
    // Zero-points for a non-integral data type do not make sense.
    if (!prb->attr.zero_points.is_def() && prb->wei_dt() != dnnl_s8) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }

    auto src_rt_mask = prb->src_runtime_dim_mask();
    auto wei_rt_mask = prb->weights_runtime_dim_mask();
    auto dst_rt_mask = prb->dst_runtime_dim_mask();

    // Memory layouts must be defined when some dimensions are unknown at pd
    // creation time.
    if ((src_rt_mask.any() && prb->stag == "any")
            || (wei_rt_mask.any() && prb->wtag == "any")
            || (dst_rt_mask.any() && prb->dtag == "any")) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }

    // Runtime masks for `m`, `k`, and `n` dimensions must be consistent.
    const int m_idx = prb->ndims - 2;
    const int k_idx_src = prb->ndims - 1;
    const int k_idx_wei = prb->ndims - 2;
    const int n_idx = prb->ndims - 1;
    if (src_rt_mask[m_idx] != dst_rt_mask[m_idx]
            || src_rt_mask[k_idx_src] != wei_rt_mask[k_idx_wei]
            || wei_rt_mask[n_idx] != dst_rt_mask[n_idx]) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }

    // Runtime masks for batch dimensions must be consistent.
    if (prb->ndims > 2) {
        dims_mask_t batch_rt_mask;
        for (int i = 0; i < prb->ndims - 2; ++i)
            batch_rt_mask[i] = true;
        src_rt_mask &= batch_rt_mask;
        wei_rt_mask &= batch_rt_mask;
        dst_rt_mask &= batch_rt_mask;
        if (src_rt_mask != wei_rt_mask || src_rt_mask != dst_rt_mask) {
            res->state = SKIPPED, res->reason = INVALID_CASE;
            return;
        }
    }
}

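// Configure the correctness comparator: the threshold follows the data type
// epsilon, and a high percentage of zero output values is tolerated since
// the filling sparsifies the inputs.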
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    const auto dt = prb->get_dt(kind);
    const float trh = dt == dnnl_f32 ? 1e-6f : epsilon_dt(dt);
    cmp.set_threshold(trh);
    cmp.set_zero_trust_percent(90.f); // TODO: why so bad filling?
}

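// Main test entry point: create the primitive, prepare and fill all memory
// objects, execute, check correctness against the reference, and measure
// performance.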
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> src_md {}, wei_md {}, dst_md {},
            bia_md {}, def_md {};
    // Query a memory descriptor only if it was fully defined at pd creation
    // time, i.e., its runtime dim mask is empty.
    if (prb->src_runtime_dim_mask().none())
        src_md.reset(clone_md(query_md(const_pd, DNNL_ARG_SRC)));
    if (prb->weights_runtime_dim_mask().none())
        wei_md.reset(clone_md(query_md(const_pd, DNNL_ARG_WEIGHTS)));
    if (prb->dst_runtime_dim_mask().none()) {
        dst_md.reset(clone_md(query_md(const_pd, DNNL_ARG_DST)));
        if (prb->bia_dt != dnnl_data_type_undef)
            bia_md.reset(clone_md(query_md(const_pd, DNNL_ARG_BIAS)));
    }

    // If a memory descriptor still equals the default (empty) one, it was not
    // queried above due to runtime dims and needs to be re-created here with
    // the actual dimensions.
    const auto &src_dims = prb->src_dims();
    if (dnnl_memory_desc_equal(src_md, def_md)) {
        assert(prb->stag != tag::any);
        src_md = dnn_mem_t::init_md(prb->ndims, src_dims.data(), prb->src_dt(),
                prb->stag, prb->strides[STRIDES_SRC]);
    }

    const auto &weights_dims = prb->weights_dims();
    if (dnnl_memory_desc_equal(wei_md, def_md)) {
        assert(prb->wtag != tag::any);
        wei_md = dnn_mem_t::init_md(prb->ndims, weights_dims.data(),
                prb->wei_dt(), prb->wtag, prb->strides[STRIDES_WEI]);
    }

    if (dnnl_memory_desc_equal(dst_md, def_md)) {
        assert(prb->dtag != tag::any);
        dst_md = dnn_mem_t::init_md(prb->ndims, prb->dst_dims.data(),
                prb->dst_dt(), prb->dtag, prb->strides[STRIDES_DST]);
    }
    if (prb->bia_dt != dnnl_data_type_undef
            && dnnl_memory_desc_equal(bia_md, def_md)) {
        dims_t bia_dims;
        prep_bia_dims(prb, bia_dims);
        bia_md = dnn_mem_t::init_md(
                prb->ndims, bia_dims.data(), prb->bia_dt, tag::abx);
    }

    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    // Use CPU prim as the reference in GPU testing to reduce testing time.
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
    SAFE(init_prim_ref(prim_ref, prb), WARN);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t bia_dt;
    if (prb->bia_dt != dnnl_data_type_undef)
        bia_dt = dnn_mem_t(bia_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    const auto fp = dnnl_f32;
    dnn_mem_t src_fp(src_md, fp, tag::abx, ref_engine);
    dnn_mem_t wei_fp(wei_md, fp, tag::abx, ref_engine);
    dnn_mem_t dst_fp(dst_md, fp, tag::abx, ref_engine);
    dnn_mem_t bia_fp;
    if (prb->bia_dt != dnnl_data_type_undef)
        bia_fp = dnn_mem_t(bia_md, fp, tag::abx, ref_engine);
    dnn_mem_t scratchpad_fp;
    if (prim_ref)
        scratchpad_fp = dnn_mem_t(
                query_md(query_pd(prim_ref), DNNL_ARG_SCRATCHPAD), ref_engine);

    SAFE(fill_data(SRC, prb, src_dt, src_fp, res), WARN);
    SAFE(fill_data(WEI, prb, wei_dt, wei_fp, res), WARN);
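    // DST acts as an input only when a sum post-op is present; fill it only
    // in that case.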
    const int sum_idx = prb->attr.post_ops.find(attr_t::post_ops_t::SUM);
    if (sum_idx >= 0) SAFE(fill_data(DST, prb, dst_dt, dst_fp, res), WARN);
    if (prb->bia_dt != dnnl_data_type_undef)
        SAFE(fill_data(BIA, prb, bia_dt, bia_fp, res), WARN);

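    // Prepare memory objects for attribute scales and zero-points, which are
    // passed to the primitive as run-time arguments.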
    dnn_mem_t src_scales_fp, src_scales_dt, wei_scales_fp, wei_scales_dt,
            dst_scales_fp, dst_scales_dt;
    maybe_prepare_runtime_scales_v2(src_scales_dt, src_scales_fp,
            prb->attr.scales.get(DNNL_ARG_SRC), prb->k, prb->src_scales);
    maybe_prepare_runtime_scales_v2(wei_scales_dt, wei_scales_fp,
            prb->attr.scales.get(DNNL_ARG_WEIGHTS), prb->n, prb->wei_scales);
    maybe_prepare_runtime_scales_v2(dst_scales_dt, dst_scales_fp,
            prb->attr.scales.get(DNNL_ARG_DST), prb->n, prb->dst_scales);

    dnn_mem_t src_zp_dt, src_zp_fp, wei_zp_dt, wei_zp_fp, dst_zp_dt, dst_zp_fp;
    const auto &wei_zero_point_val
            = prb->attr.zero_points.get(DNNL_ARG_WEIGHTS).value;
    maybe_prepare_runtime_zero_points_v2(
            src_zp_dt, src_zp_fp, prb->attr, DNNL_ARG_SRC, prb->k, prb->src_zp);
    maybe_prepare_runtime_zero_points_v2(wei_zp_dt, wei_zp_fp, prb->attr,
            DNNL_ARG_WEIGHTS, 1, &(wei_zero_point_val));
    maybe_prepare_runtime_zero_points_v2(
            dst_zp_dt, dst_zp_fp, prb->attr, DNNL_ARG_DST, prb->n, prb->dst_zp);

    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(const_pd, binary_po_args, binary_po_dt,
                 binary_po_fp, /*only_positive=*/false, /*only_integer=*/true),
            WARN);

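    // Bind every memory object, including attribute scales, zero-points, and
    // binary post-op inputs, to its execution argument.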
    args_t args, ref_args;

    args.set(DNNL_ARG_SRC, src_dt);
    args.set(DNNL_ARG_WEIGHTS, wei_dt);
    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_BIAS, bia_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
    args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);
    args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_dt);
    args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, wei_zp_dt);
    args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_dt);
    args.set(binary_po_args, binary_po_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

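    // In correctness mode, compute the reference result (using the CPU
    // oneDNN primitive when available) and compare DST.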
    if (is_bench_mode(CORR)) {
        ref_args.set(DNNL_ARG_SRC, src_fp);
        ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
        ref_args.set(DNNL_ARG_BIAS, bia_fp);
        ref_args.set(DNNL_ARG_DST, dst_fp);
        ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp);
        ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_fp);
        ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp);
        ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_fp);
        ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, wei_zp_fp);
        ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_fp);
        ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
        ref_args.set(binary_po_args, binary_po_fp);

        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res, prim_ref);
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace matmul