/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "utils/parallel.hpp"

#include "matmul/matmul.hpp"

namespace matmul {

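// Reference f32 matmul. For every output element it computes
//     dst[mb, m, n] = sum_k (src[mb, m, k] - src_zp) * (wei[mb, k, n] - wei_zp)
// and then applies scales, bias, post-ops, and dst quantization in a second
// pass. Batch indices honor the broadcast semantics of src and weights.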
void compute_ref_matmul(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
    const int64_t M = prb->m;
    const int64_t N = prb->n;
    const int64_t K = prb->k;
    const int64_t MB = prb->mb;
    const int batch_ndims = dst_m.ndims() - 2;

    // Fast return if any dim is zero. Common logic doesn't apply because of
    // broadcast semantics.
    for (int d = 0; d < dst_m.ndims(); d++) {
        if (prb->src_dims()[d] == 0 || prb->weights_dims()[d] == 0) return;
    }

    const int wei_zero_point = prb->attr.zero_points[DNNL_ARG_WEIGHTS];

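    // Accumulate into an f32 temporary with a plain (abx) layout; scales,
    // bias, and post-ops are applied in a second pass over this buffer.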
    dnn_mem_t dst_tmp(dst_m, dnnl_f32, tag::abx, dst_m.engine());

    const auto src_broadcast_mask = prb->src_broadcast_mask();
    const auto wei_broadcast_mask = prb->weights_broadcast_mask();

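    // First pass: raw dot product over K for each (mb, m, n) point, with
    // source and weights zero-points subtracted from the inputs.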
    benchdnn_parallel_nd(MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
        auto src = (const float *)src_m;
        auto wei = (const float *)wei_m;

        float dst = 0;
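        // Map the dst batch index to the corresponding (possibly broadcast)
        // src and weights batch indices.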
        const int64_t src_mb
                = dst_m.get_scale_idx(mb, src_broadcast_mask, batch_ndims);
        const int64_t wei_mb
                = dst_m.get_scale_idx(mb, wei_broadcast_mask, batch_ndims);
        for (int64_t k = 0; k < K; ++k) {
            auto s = src[src_off_f(prb, src_mb, m, k)];
            maybe_zero_point(prb->attr, s, prb->src_zp, k, DNNL_ARG_SRC);
            dst += s * (wei[wei_off_f(prb, wei_mb, k, n)] - wei_zero_point);
        }
        ((float *)dst_tmp)[dst_off_f(prb, mb, m, n)] = dst;
    });

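    // Second pass: apply src/weights scales, bias, post-ops, and dst
    // quantization on top of the raw accumulator.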
    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    const auto bias_broadcast_mask = prb->bias_broadcast_mask();
    benchdnn_parallel_nd(MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
        size_t dst_off = dst_off_f(prb, mb, m, n);
        float &dst = ((float *)dst_m)[dst_off];

        float tmp = ((float *)dst_tmp)[dst_off];
        maybe_scale(prb->attr, tmp, prb->src_scales, 0, DNNL_ARG_SRC);
        maybe_scale(prb->attr, tmp, prb->wei_scales, n, DNNL_ARG_WEIGHTS);

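        // Bias is broadcast over the dst dimensions according to its mask.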
        if (prb->bia_dt != dnnl_data_type_undef) {
            int64_t bia_off = dst_m.get_scale_idx(dst_off, bias_broadcast_mask);
            float *bia_ptr = (float *)bia_m;
            tmp += bia_ptr[bia_off];
        }

        const auto v_po_vals
                = prepare_po_vals(dst_m, args, v_po_masks, dst_off);

        maybe_post_ops(prb->attr, tmp, dst, v_po_vals);

        maybe_scale(prb->attr, tmp, prb->dst_scales, n, DNNL_ARG_DST, true);
        maybe_zero_point(prb->attr, tmp, prb->dst_zp, n, DNNL_ARG_DST, true);
        dst = tmp;
    });
}

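// Entry point: if a reference primitive is provided, execute it instead;
// otherwise fall back to the plain C++ reference above.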
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prim_ref) {
        SAFE_V(execute_and_wait(prim_ref, args));
        return;
    }

    compute_ref_matmul(prb, args);
}

} // namespace matmul