1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "brgemm/brgemm.hpp" |
20 | |
21 | namespace brgemm { |
22 | |
23 | #if defined(DNNL_X64) && DNNL_X64 == 1 && DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE |
24 | |
25 | int64_t src_off_f(const prb_t *prb, int64_t bs, int64_t m, int64_t k) { |
26 | return (m * prb->batch_size + bs) * prb->k + k; |
27 | } |
28 | |
29 | int64_t wei_off_f(const prb_t *prb, int64_t bs, int64_t k, int64_t n) { |
30 | return (bs * prb->k + k) * prb->n + n; |
31 | } |
32 | |
33 | int64_t dst_off_f(const prb_t *prb, int64_t m, int64_t n) { |
34 | return m * prb->n + n; |
35 | } |
36 | |
// Plain-C++ f32 reference for the brgemm driver. Computes, per (m, n):
//   acc_part = alpha * sum_{bs,k} (src - src_zp) * (wei - wei_zp)
//              + (beta != 0 ? beta * acc : 0)
// then applies bias, scales, post-ops and the dst zero-point in that order
// before storing into dst.
void compute_ref_brgemm(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
    const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
    // DNNL_ARG_SRC_1 carries the initial accumulator values (beta input).
    const dnn_mem_t &acc_m = args.find(DNNL_ARG_SRC_1);
    const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
    const dnn_mem_t &ws_m = args.find(DNNL_ARG_WORKSPACE);
    const int64_t BS = prb->batch_size;
    const int64_t M = prb->m;
    const int64_t N = prb->n;
    const int64_t K = prb->k;

    // Using workspace memory as a method to get brgemm attributes.
    using brgemm_attr_t = dnnl::impl::cpu::x64::brgemm_attr_t;
    brgemm_attr_t *brgemm_attr = (brgemm_attr_t *)ws_m;

    // Weights zero-point is a single common value; subtracted per element
    // below. Source zero-point is handled via maybe_zero_point instead.
    const int wei_zero_point = prb->attr.zero_points[DNNL_ARG_WEIGHTS];

    // Intermediate f32 buffer for the accumulation stage; the epilogue
    // below reads from it and writes the final result into dst_m.
    dnn_mem_t dst_tmp(dst_m, dnnl_f32, tag::abx, dst_m.engine());

    const auto alpha = prb->alpha;
    const auto beta = prb->beta;

    if (!brgemm_attr->generate_skip_accumulation) {
        // Full accumulation: batched dot-product over (bs, k) per output.
        benchdnn_parallel_nd(M, N, [&](int64_t m, int64_t n) {
            auto src = (const float *)src_m;
            auto wei = (const float *)wei_m;

            float res = 0;
            for_(int64_t bs = 0; bs < BS; bs++)
            for (int64_t k = 0; k < K; ++k) {
                auto s = src[src_off_f(prb, bs, m, k)];
                // Applies the src zero-point (indexed by k) when set in attr.
                maybe_zero_point(prb->attr, s, prb->src_zp, k, DNNL_ARG_SRC);
                auto w = wei[wei_off_f(prb, bs, k, n)] - wei_zero_point;
                res += alpha * s * w;
            }
            float &dst = ((float *)dst_tmp)[dst_off_f(prb, m, n)];
            float acc = ((float *)acc_m)[dst_off_f(prb, m, n)];
            // beta == 0 must not read-and-scale acc (avoids 0 * garbage).
            dst = res + (beta != 0 ? beta * acc : 0);
        });
    } else {
        // Skip-accumulation mode: kernel contributes nothing; result is
        // just the scaled pre-existing accumulator.
        benchdnn_parallel_nd(M, N, [&](int64_t m, int64_t n) {
            float &dst = ((float *)dst_tmp)[dst_off_f(prb, m, n)];
            float acc = ((float *)acc_m)[dst_off_f(prb, m, n)];
            dst = beta * acc;
        });
    }

    // Scale handling: when weights scales are runtime-supplied, scale by
    // weights arg (per-n below); otherwise fall back to the src scale arg.
    auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
    auto attr_scale_arg = wei_scale.runtime ? DNNL_ARG_WEIGHTS : DNNL_ARG_SRC;

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    // Mask 2 over the 2D (m, n) dst: bias index varies along n only,
    // i.e. one bias value per output column — TODO(review) confirm against
    // get_scale_idx mask convention.
    static constexpr int bias_broadcast_mask = 2;
    // Epilogue: bias -> scales -> post-ops -> dst zero-point, per element.
    benchdnn_parallel_nd(M, N, [&](int64_t m, int64_t n) {
        size_t dst_off = dst_off_f(prb, m, n);
        float &dst = ((float *)dst_m)[dst_off];

        float tmp = ((float *)dst_tmp)[dst_off];
        if (prb->bia_dt != dnnl_data_type_undef) {
            int64_t bia_off = dst_m.get_scale_idx(dst_off, bias_broadcast_mask);
            float *bia_ptr = (float *)bia_m;
            tmp += bia_ptr[bia_off];
        }
        maybe_scale(prb->attr, tmp, prb->scales, n, attr_scale_arg);

        const auto v_po_vals
                = prepare_po_vals(dst_m, args, v_po_masks, dst_off);

        // Note: reads the current `dst` value (needed by the sum post-op)
        // before it is overwritten below.
        maybe_post_ops(prb->attr, tmp, dst, v_po_vals);

        maybe_zero_point(prb->attr, tmp, prb->dst_zp, n, DNNL_ARG_DST, true);
        dst = tmp;
    });
}
111 | |
112 | void compute_ref( |
113 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
114 | if (prim_ref) { |
115 | SAFE_V(execute_and_wait(prim_ref, args)); |
116 | return; |
117 | } |
118 | |
119 | compute_ref_brgemm(prb, args); |
120 | } |
121 | |
122 | #else |
123 | |
// Stub for builds without the x64 CPU runtime (guarded by the #if above):
// there is no brgemm implementation to verify, so the reference is a no-op.
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {}
126 | |
127 | #endif |
128 | |
129 | } // namespace brgemm |
130 | |