1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils/parallel.hpp"
18
19#include "brgemm/brgemm.hpp"
20
21namespace brgemm {
22
23#if defined(DNNL_X64) && DNNL_X64 == 1 && DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
24
25int64_t src_off_f(const prb_t *prb, int64_t bs, int64_t m, int64_t k) {
26 return (m * prb->batch_size + bs) * prb->k + k;
27}
28
29int64_t wei_off_f(const prb_t *prb, int64_t bs, int64_t k, int64_t n) {
30 return (bs * prb->k + k) * prb->n + n;
31}
32
33int64_t dst_off_f(const prb_t *prb, int64_t m, int64_t n) {
34 return m * prb->n + n;
35}
36
37void compute_ref_brgemm(const prb_t *prb, const args_t &args) {
38 const dnn_mem_t &src_m = args.find(DNNL_ARG_SRC);
39 const dnn_mem_t &wei_m = args.find(DNNL_ARG_WEIGHTS);
40 const dnn_mem_t &bia_m = args.find(DNNL_ARG_BIAS);
41 const dnn_mem_t &acc_m = args.find(DNNL_ARG_SRC_1);
42 const dnn_mem_t &dst_m = args.find(DNNL_ARG_DST);
43 const dnn_mem_t &ws_m = args.find(DNNL_ARG_WORKSPACE);
44 const int64_t BS = prb->batch_size;
45 const int64_t M = prb->m;
46 const int64_t N = prb->n;
47 const int64_t K = prb->k;
48
49 // Using workspace memory as a method to get brgemm attributes.
50 using brgemm_attr_t = dnnl::impl::cpu::x64::brgemm_attr_t;
51 brgemm_attr_t *brgemm_attr = (brgemm_attr_t *)ws_m;
52
53 const int wei_zero_point = prb->attr.zero_points[DNNL_ARG_WEIGHTS];
54
55 dnn_mem_t dst_tmp(dst_m, dnnl_f32, tag::abx, dst_m.engine());
56
57 const auto alpha = prb->alpha;
58 const auto beta = prb->beta;
59
60 if (!brgemm_attr->generate_skip_accumulation) {
61 benchdnn_parallel_nd(M, N, [&](int64_t m, int64_t n) {
62 auto src = (const float *)src_m;
63 auto wei = (const float *)wei_m;
64
65 float res = 0;
66 for_(int64_t bs = 0; bs < BS; bs++)
67 for (int64_t k = 0; k < K; ++k) {
68 auto s = src[src_off_f(prb, bs, m, k)];
69 maybe_zero_point(prb->attr, s, prb->src_zp, k, DNNL_ARG_SRC);
70 auto w = wei[wei_off_f(prb, bs, k, n)] - wei_zero_point;
71 res += alpha * s * w;
72 }
73 float &dst = ((float *)dst_tmp)[dst_off_f(prb, m, n)];
74 float acc = ((float *)acc_m)[dst_off_f(prb, m, n)];
75 dst = res + (beta != 0 ? beta * acc : 0);
76 });
77 } else {
78 benchdnn_parallel_nd(M, N, [&](int64_t m, int64_t n) {
79 float &dst = ((float *)dst_tmp)[dst_off_f(prb, m, n)];
80 float acc = ((float *)acc_m)[dst_off_f(prb, m, n)];
81 dst = beta * acc;
82 });
83 }
84
85 auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
86 auto attr_scale_arg = wei_scale.runtime ? DNNL_ARG_WEIGHTS : DNNL_ARG_SRC;
87
88 auto v_po_masks = prb->attr.post_ops.get_po_masks();
89 static constexpr int bias_broadcast_mask = 2;
90 benchdnn_parallel_nd(M, N, [&](int64_t m, int64_t n) {
91 size_t dst_off = dst_off_f(prb, m, n);
92 float &dst = ((float *)dst_m)[dst_off];
93
94 float tmp = ((float *)dst_tmp)[dst_off];
95 if (prb->bia_dt != dnnl_data_type_undef) {
96 int64_t bia_off = dst_m.get_scale_idx(dst_off, bias_broadcast_mask);
97 float *bia_ptr = (float *)bia_m;
98 tmp += bia_ptr[bia_off];
99 }
100 maybe_scale(prb->attr, tmp, prb->scales, n, attr_scale_arg);
101
102 const auto v_po_vals
103 = prepare_po_vals(dst_m, args, v_po_masks, dst_off);
104
105 maybe_post_ops(prb->attr, tmp, dst, v_po_vals);
106
107 maybe_zero_point(prb->attr, tmp, prb->dst_zp, n, DNNL_ARG_DST, true);
108 dst = tmp;
109 });
110}
111
112void compute_ref(
113 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
114 if (prim_ref) {
115 SAFE_V(execute_and_wait(prim_ref, args));
116 return;
117 }
118
119 compute_ref_brgemm(prb, args);
120}
121
122#else
123
124void compute_ref(
125 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {}
126
127#endif
128
129} // namespace brgemm
130