/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "utils/parallel.hpp"

#include "matmul/matmul.hpp"

namespace matmul {

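// Plain f32 reference matmul:
//     dst(mb, m, n) = sum_k (src(mb, m, k) - src_zp) * (wei(mb, k, n) - wei_zp)
// followed by scales, bias, post-ops and the dst-side scale/zero-point.
// Batch dimensions of src and weights may be broadcast against dst.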
void compute_ref_matmul(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
    const int64_t M = prb->m;
    const int64_t N = prb->n;
    const int64_t K = prb->k;
    const int64_t MB = prb->mb;
    const int batch_ndims = dst_m.ndims() - 2;

    // Fast return if any dimension is zero. The common zero-dim handling
    // doesn't apply here because of matmul's broadcast semantics.
    for (int d = 0; d < dst_m.ndims(); d++) {
        if (prb->src_dims()[d] == 0 || prb->weights_dims()[d] == 0) return;
    }

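    // The weights zero-point is read once up front and used as a single
    // common value for the whole tensor.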
    const int wei_zero_point = prb->attr.zero_points[DNNL_ARG_WEIGHTS];

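    // Accumulate into an f32 copy of dst with a plain layout; scales, bias
    // and post-ops are applied to it in a second pass below.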
    dnn_mem_t dst_tmp(dst_m, dnnl_f32, tag::abx, dst_m.engine());

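    // Broadcast masks map a dst batch index onto the corresponding src and
    // weights batch indices when their batch dims are broadcast.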
    const auto src_broadcast_mask = prb->src_broadcast_mask();
    const auto wei_broadcast_mask = prb->weights_broadcast_mask();

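    // First pass: raw accumulation over K with src and weights zero-points
    // applied to the inputs.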
    benchdnn_parallel_nd(MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
        auto src = (const float *)src_m;
        auto wei = (const float *)wei_m;

        float dst = 0; // f32 accumulator for this (mb, m, n) element.
        const int64_t src_mb
                = dst_m.get_scale_idx(mb, src_broadcast_mask, batch_ndims);
        const int64_t wei_mb
                = dst_m.get_scale_idx(mb, wei_broadcast_mask, batch_ndims);
        for (int64_t k = 0; k < K; ++k) {
            auto s = src[src_off_f(prb, src_mb, m, k)];
            maybe_zero_point(prb->attr, s, prb->src_zp, k, DNNL_ARG_SRC);
            dst += s * (wei[wei_off_f(prb, wei_mb, k, n)] - wei_zero_point);
        }
        ((float *)dst_tmp)[dst_off_f(prb, mb, m, n)] = dst;
    });

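    // Second pass: apply src/weights scales, add bias, run post-ops, then
    // apply the dst scale and zero-point, and store the result into dst.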
    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    const auto bias_broadcast_mask = prb->bias_broadcast_mask();
    benchdnn_parallel_nd(MB, M, N, [&](int64_t mb, int64_t m, int64_t n) {
        const size_t dst_off = dst_off_f(prb, mb, m, n);
        float &dst = ((float *)dst_m)[dst_off];

        float tmp = ((float *)dst_tmp)[dst_off];
        maybe_scale(prb->attr, tmp, prb->src_scales, 0, DNNL_ARG_SRC);
        maybe_scale(prb->attr, tmp, prb->wei_scales, n, DNNL_ARG_WEIGHTS);

        if (prb->bia_dt != dnnl_data_type_undef) {
            const int64_t bia_off
                    = dst_m.get_scale_idx(dst_off, bias_broadcast_mask);
            const float *bia_ptr = (const float *)bia_m;
            tmp += bia_ptr[bia_off];
        }

        const auto v_po_vals
                = prepare_po_vals(dst_m, args, v_po_masks, dst_off);

        maybe_post_ops(prb->attr, tmp, dst, v_po_vals);

        maybe_scale(prb->attr, tmp, prb->dst_scales, n, DNNL_ARG_DST, true);
        maybe_zero_point(prb->attr, tmp, prb->dst_zp, n, DNNL_ARG_DST, true);
        dst = tmp;
    });
}

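// Entry point: when a reference primitive is provided (prim_ref), execute it
// to compute the reference result; otherwise fall back to the handwritten
// loops above.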
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    if (prim_ref) {
        SAFE_V(execute_and_wait(prim_ref, args));
        return;
    }

    compute_ref_matmul(prb, args);
}

} // namespace matmul