1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <float.h>
19#include <math.h>
20
21#include "common/c_types_map.hpp"
22#include "common/dnnl_thread.hpp"
23#include "common/math_utils.hpp"
24#include "common/type_helpers.hpp"
25
26#include "cpu/cpu_primitive.hpp"
27#include "cpu/ref_io_helper.hpp"
28
29#include "cpu/matmul/matmul_utils.hpp"
30#include "cpu/matmul/ref_matmul.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace matmul {
36
37status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
38 status_t status = status::success;
39 const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
40 const auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
41 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
42 auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
43 CHECK(status);
44
45 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
46 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
47 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
48
49 const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
50 const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
51 const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
52 const auto bia_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));
53
54 if (src_d.has_zero_dim() || weights_d.has_zero_dim()
55 || dst_d.has_zero_dim())
56 return status::success;
57
58 const bool non_default_attrs = !pd()->attr()->has_default_values();
59
60 matmul_helper_t helper(src_d, weights_d, dst_d);
61 const int ndims = pd()->ndims();
62 const int batch_ndims = ndims - 2;
63 const dim_t M = helper.M();
64 const dim_t N = helper.N();
65 const dim_t K = helper.K();
66 const dim_t batch = helper.batch();
67
68 const int src_mask
69 = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
70 const int wei_mask
71 = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims);
72 const int bia_mask
73 = utils::get_dims_mask(dst_d.dims(), bia_d.dims(), ndims);
74
75 // mm kernel
76 auto ker = [&](const dims_t dst_dims_idx, dim_t m, dim_t n) {
77 float acc = 0;
78 dims_t src_dims_idx, weights_dims_idx;
79 utils::copy_dims_with_mask(src_dims_idx, dst_dims_idx, ndims, src_mask);
80 utils::copy_dims_with_mask(
81 weights_dims_idx, dst_dims_idx, ndims, wei_mask);
82 src_dims_idx[ndims - 2] = m;
83 weights_dims_idx[ndims - 1] = n;
84 auto &src_k_dim = src_dims_idx[ndims - 1];
85 auto &wei_k_dim = weights_dims_idx[ndims - 2];
86 for (dim_t k = 0; k < K; ++k) {
87 src_k_dim = k;
88 wei_k_dim = k;
89 const auto src_off = src_d.off_v(src_dims_idx);
90 const auto weights_off = weights_d.off_v(weights_dims_idx);
91 const float s
92 = io::load_float_value(src_d.data_type(), src, src_off);
93 const float w = io::load_float_value(
94 weights_d.data_type(), weights, weights_off);
95 acc += s * w;
96 }
97 return acc;
98 };
99
100 // bias section
101 auto ker_bias = [&](const dims_t &dst_dims_idx) -> float {
102 dims_t bia_dims_idx;
103 utils::copy_dims_with_mask(bia_dims_idx, dst_dims_idx, ndims, bia_mask);
104 const auto bias_off = bia_d.off_v(bia_dims_idx);
105 return io::load_float_value(bia_d.data_type(), bias, bias_off);
106 };
107
108 // arg scales section
109 const auto &attr_scales = pd()->attr()->scales_;
110 const bool with_src_scales
111 = !attr_scales.get(DNNL_ARG_SRC).has_default_values();
112 const bool with_wei_scales
113 = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values();
114 const bool with_dst_scales
115 = !attr_scales.get(DNNL_ARG_DST).has_default_values();
116 const dim_t wei_scale_stride
117 = attr_scales.get(DNNL_ARG_WEIGHTS).mask_ == 0 ? 0 : 1;
118
119 auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type());
120
121 // computations
122 parallel_nd(batch, M, N, [&](dim_t mb, dim_t m, dim_t n) {
123 dims_t dst_dims_idx;
124 // account for M, N dims for index calculations
125 const size_t l_offset = mb * M * N + m * N + n;
126 utils::l_dims_by_l_offset(dst_dims_idx, l_offset, dst_d.dims(), ndims);
127 float d = ker(dst_dims_idx, m, n);
128 if (with_src_scales) d *= src_scales[0];
129 if (with_wei_scales) d *= wei_scales[wei_scale_stride * n];
130 if (bias) d += ker_bias(dst_dims_idx);
131
132 const auto dst_off = dst_d.off_v(dst_dims_idx);
133 if (non_default_attrs) {
134 ref_post_ops_t::args_t args;
135 args.dst_val = io::load_float_value(sum_dt, dst, dst_off);
136 args.ctx = &ctx;
137 args.l_offset = l_offset;
138 args.dst_md = pd()->dst_md();
139 ref_post_ops->execute(d, args);
140 }
141 if (with_dst_scales) d *= dst_scales[0];
142 io::store_float_value(dst_d.data_type(), d, dst, dst_off);
143 utils::dim_iterator(dst_d.dims(), dst_dims_idx, batch_ndims);
144 });
145
146 return status::success;
147}
148
149} // namespace matmul
150} // namespace cpu
151} // namespace impl
152} // namespace dnnl
153