1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <float.h> |
19 | #include <math.h> |
20 | |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/math_utils.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | |
26 | #include "cpu/cpu_primitive.hpp" |
27 | #include "cpu/ref_io_helper.hpp" |
28 | |
29 | #include "cpu/matmul/matmul_utils.hpp" |
30 | #include "cpu/matmul/ref_matmul.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace matmul { |
36 | |
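// Reference matmul: a slow but data-type- and format-agnostic implementation.
// Per dst element it computes
//   dst = dst_scale * post_ops(src_scale * wei_scale * (src x weights) + bias)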
37 | status_t ref_matmul_t::execute_ref(const exec_ctx_t &ctx) const { |
38 | status_t status = status::success; |
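    // Raw, type-erased buffers; the actual data types come from the memory
    // descriptors and are resolved at load/store time.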
39 | const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
40 | const auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS); |
41 | const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
42 | auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status); |
43 | CHECK(status); |
44 | |
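    // Runtime quantization scales (each resolves to a single 1.f entry when
    // the corresponding scales attribute is not set).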
45 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
46 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
47 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
48 | |
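    // Memory descriptor wrappers; ctx.memory_mdw() resolves runtime
    // dimensions from the execution context.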
49 | const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md()); |
50 | const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md()); |
51 | const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md()); |
52 | const auto bia_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1)); |
53 | |
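    // A zero dimension anywhere means an empty tensor: nothing to compute.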
54 | if (src_d.has_zero_dim() || weights_d.has_zero_dim() |
55 | || dst_d.has_zero_dim()) |
56 | return status::success; |
57 | |
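    // True when any attribute (scales, post-ops, ...) deviates from the
    // defaults; gates the post-ops call in the main loop.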
58 | const bool non_default_attrs = !pd()->attr()->has_default_values(); |
59 | |
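    // GEMM view of the problem: dst(batch, M, N) = src(batch, M, K) x
    // weights(batch, K, N), with the leading ndims - 2 dims folded into batch.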
60 | matmul_helper_t helper(src_d, weights_d, dst_d); |
61 | const int ndims = pd()->ndims(); |
62 | const int batch_ndims = ndims - 2; |
63 | const dim_t M = helper.M(); |
64 | const dim_t N = helper.N(); |
65 | const dim_t K = helper.K(); |
66 | const dim_t batch = helper.batch(); |
67 | |
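    // Broadcast masks: bit i is set when the i-th dim of the given tensor
    // matches the corresponding dst dim. copy_dims_with_mask() keeps the dst
    // index for matching dims and forces 0 for broadcast ones; e.g. src dims
    // {1, M, K} against dst dims {B, M, N} broadcast src across the batch.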
68 | const int src_mask |
69 | = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims); |
70 | const int wei_mask |
71 | = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims); |
72 | const int bia_mask |
73 | = utils::get_dims_mask(dst_d.dims(), bia_d.dims(), ndims); |
74 | |
    // Matmul kernel: computes a single dst element as the f32 dot product
    // over K of src[..., m, :] and weights[..., :, n], at the (possibly
    // broadcast) batch index derived from dst_dims_idx.
76 | auto ker = [&](const dims_t dst_dims_idx, dim_t m, dim_t n) { |
77 | float acc = 0; |
78 | dims_t src_dims_idx, weights_dims_idx; |
79 | utils::copy_dims_with_mask(src_dims_idx, dst_dims_idx, ndims, src_mask); |
80 | utils::copy_dims_with_mask( |
81 | weights_dims_idx, dst_dims_idx, ndims, wei_mask); |
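        // Fix the M coordinate on src and the N coordinate on weights; the
        // shared K coordinate is advanced in the loop below.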
82 | src_dims_idx[ndims - 2] = m; |
83 | weights_dims_idx[ndims - 1] = n; |
84 | auto &src_k_dim = src_dims_idx[ndims - 1]; |
85 | auto &wei_k_dim = weights_dims_idx[ndims - 2]; |
86 | for (dim_t k = 0; k < K; ++k) { |
87 | src_k_dim = k; |
88 | wei_k_dim = k; |
89 | const auto src_off = src_d.off_v(src_dims_idx); |
90 | const auto weights_off = weights_d.off_v(weights_dims_idx); |
91 | const float s |
92 | = io::load_float_value(src_d.data_type(), src, src_off); |
93 | const float w = io::load_float_value( |
94 | weights_d.data_type(), weights, weights_off); |
95 | acc += s * w; |
96 | } |
97 | return acc; |
98 | }; |
99 | |
    // Bias kernel: loads the (possibly broadcast) bias value for a given
    // dst index.
101 | auto ker_bias = [&](const dims_t &dst_dims_idx) -> float { |
102 | dims_t bia_dims_idx; |
103 | utils::copy_dims_with_mask(bia_dims_idx, dst_dims_idx, ndims, bia_mask); |
104 | const auto bias_off = bia_d.off_v(bia_dims_idx); |
105 | return io::load_float_value(bia_d.data_type(), bias, bias_off); |
106 | }; |
107 | |
    // Argument scales: src and dst scales are per-tensor; weights scales are
    // either per-tensor (stride 0) or per-N (stride 1), depending on the mask.
109 | const auto &attr_scales = pd()->attr()->scales_; |
110 | const bool with_src_scales |
111 | = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); |
112 | const bool with_wei_scales |
113 | = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); |
114 | const bool with_dst_scales |
115 | = !attr_scales.get(DNNL_ARG_DST).has_default_values(); |
116 | const dim_t wei_scale_stride |
117 | = attr_scales.get(DNNL_ARG_WEIGHTS).mask_ == 0 ? 0 : 1; |
118 | |
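    // Data type used to reload the current dst value for the sum post-op
    // (defaults to the dst data type when the post-op does not override it).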
119 | auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type()); |
120 | |
    // Main computation: one parallel work item per dst element (mb, m, n).
122 | parallel_nd(batch, M, N, [&](dim_t mb, dim_t m, dim_t n) { |
123 | dims_t dst_dims_idx; |
        // Flat logical offset of this dst element; account for the M and N
        // dims when converting it to the full multi-dim index below.
125 | const size_t l_offset = mb * M * N + m * N + n; |
126 | utils::l_dims_by_l_offset(dst_dims_idx, l_offset, dst_d.dims(), ndims); |
127 | float d = ker(dst_dims_idx, m, n); |
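        // Dequantize: apply the per-tensor src scale and the per-tensor or
        // per-N weights scale; bias is added to the dequantized value.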
128 | if (with_src_scales) d *= src_scales[0]; |
129 | if (with_wei_scales) d *= wei_scales[wei_scale_stride * n]; |
130 | if (bias) d += ker_bias(dst_dims_idx); |
131 | |
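        // Run post-ops (eltwise, binary, sum, ...) on the f32 value; the
        // current dst value is loaded so the sum post-op can accumulate it.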
132 | const auto dst_off = dst_d.off_v(dst_dims_idx); |
133 | if (non_default_attrs) { |
134 | ref_post_ops_t::args_t args; |
135 | args.dst_val = io::load_float_value(sum_dt, dst, dst_off); |
136 | args.ctx = &ctx; |
137 | args.l_offset = l_offset; |
138 | args.dst_md = pd()->dst_md(); |
139 | ref_post_ops->execute(d, args); |
140 | } |
141 | if (with_dst_scales) d *= dst_scales[0]; |
142 | io::store_float_value(dst_d.data_type(), d, dst, dst_off); |
144 | }); |
145 | |
146 | return status::success; |
147 | } |
148 | |
149 | } // namespace matmul |
150 | } // namespace cpu |
151 | } // namespace impl |
152 | } // namespace dnnl |
153 | |