1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <float.h> |
19 | #include <math.h> |
20 | |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/dnnl_thread.hpp" |
23 | #include "common/math_utils.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | |
26 | #include "cpu/cpu_primitive.hpp" |
27 | #include "cpu/ref_io_helper.hpp" |
28 | #include "cpu/simple_q10n.hpp" |
29 | |
30 | #include "cpu/matmul/matmul_utils.hpp" |
31 | #include "cpu/matmul/ref_matmul_int8.hpp" |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | namespace matmul { |
37 | |
38 | status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const { |
39 | status_t status = status::success; |
40 | const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
41 | const auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS); |
42 | const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
43 | auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status); |
44 | CHECK(status); |
45 | |
46 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
47 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
48 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
49 | |
50 | DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); |
51 | DEFINE_ZERO_POINT_VALUE(weights_zero_point, DNNL_ARG_WEIGHTS); |
52 | DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); |
53 | |
54 | const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md()); |
55 | const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md()); |
56 | const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md()); |
57 | const auto bia_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1)); |
58 | |
59 | if (src_d.has_zero_dim() || weights_d.has_zero_dim() |
60 | || dst_d.has_zero_dim()) |
61 | return status::success; |
62 | |
63 | const bool non_default_attrs = !pd()->attr()->has_default_values(); |
64 | |
65 | matmul_helper_t helper(src_d, weights_d, dst_d); |
66 | const int ndims = pd()->ndims(); |
67 | const int batch_ndims = ndims - 2; |
68 | const dim_t M = helper.M(); |
69 | const dim_t N = helper.N(); |
70 | const dim_t K = helper.K(); |
71 | const dim_t batch = helper.batch(); |
72 | |
73 | const int src_mask |
74 | = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims); |
75 | const int wei_mask |
76 | = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims); |
77 | const int bia_mask |
78 | = utils::get_dims_mask(dst_d.dims(), bia_d.dims(), ndims); |
79 | |
80 | // zp_idx_mult = 1 for per_dim1 zero points and 0, otherwise |
81 | const int src_zp_idx_mult |
82 | = !pd()->attr()->zero_points_.common(DNNL_ARG_SRC); |
83 | const int dst_zp_idx_mult |
84 | = !pd()->attr()->zero_points_.common(DNNL_ARG_DST); |
85 | |
86 | // mm kernel |
87 | auto ker = [&](const dims_t dst_dims_idx, dim_t m, dim_t n) { |
88 | int acc = 0; |
89 | dims_t src_dims_idx, weights_dims_idx; |
90 | utils::copy_dims_with_mask(src_dims_idx, dst_dims_idx, ndims, src_mask); |
91 | utils::copy_dims_with_mask( |
92 | weights_dims_idx, dst_dims_idx, ndims, wei_mask); |
93 | src_dims_idx[ndims - 2] = m; |
94 | weights_dims_idx[ndims - 1] = n; |
95 | auto &src_k_dim = src_dims_idx[ndims - 1]; |
96 | auto &wei_k_dim = weights_dims_idx[ndims - 2]; |
97 | for (dim_t k = 0; k < K; ++k) { |
98 | src_k_dim = k; |
99 | wei_k_dim = k; |
100 | const auto src_off = src_d.off_v(src_dims_idx); |
101 | const auto weights_off = weights_d.off_v(weights_dims_idx); |
102 | int s = io::load_int_value(src_d.data_type(), src, src_off); |
103 | int w = io::load_int_value( |
104 | weights_d.data_type(), weights, weights_off); |
105 | if (src_zero_point) { |
106 | const int src_zp = io::load_int_value( |
107 | data_type::s32, src_zero_point, src_zp_idx_mult * k); |
108 | s -= src_zp; |
109 | } |
110 | if (weights_zero_point) { w -= weights_zero_point; } |
111 | acc += s * w; |
112 | } |
113 | return acc; |
114 | }; |
115 | |
116 | // bias section |
117 | auto ker_bias = [&](const dims_t &dst_dims_idx) -> float { |
118 | dims_t bia_dims_idx; |
119 | utils::copy_dims_with_mask(bia_dims_idx, dst_dims_idx, ndims, bia_mask); |
120 | const auto bias_off = bia_d.off_v(bia_dims_idx); |
121 | return io::load_float_value(bia_d.data_type(), bias, bias_off); |
122 | }; |
123 | |
124 | // arg scales section |
125 | const auto &attr_scales = pd()->attr()->scales_; |
126 | const bool with_src_scales |
127 | = !attr_scales.get(DNNL_ARG_SRC).has_default_values(); |
128 | const bool with_wei_scales |
129 | = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values(); |
130 | const bool with_dst_scales |
131 | = !attr_scales.get(DNNL_ARG_DST).has_default_values(); |
132 | const dim_t wei_scale_stride |
133 | = attr_scales.get(DNNL_ARG_WEIGHTS).mask_ == 0 ? 0 : 1; |
134 | |
135 | auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type()); |
136 | |
137 | // computations |
138 | parallel_nd(batch, M, N, [&](dim_t mb, dim_t m, dim_t n) { |
139 | dims_t dst_dims_idx; |
140 | // account for M, N dims for index calculations |
141 | const size_t l_offset = mb * M * N + m * N + n; |
142 | utils::l_dims_by_l_offset(dst_dims_idx, l_offset, dst_d.dims(), ndims); |
143 | int acc = ker(dst_dims_idx, m, n); |
144 | float d = static_cast<int>(acc); |
145 | if (with_src_scales) d *= src_scales[0]; |
146 | if (with_wei_scales) d *= wei_scales[wei_scale_stride * n]; |
147 | if (bias) d += ker_bias(dst_dims_idx); |
148 | |
149 | const auto dst_off = dst_d.off_v(dst_dims_idx); |
150 | if (non_default_attrs) { |
151 | |
152 | ref_post_ops_t::args_t args; |
153 | args.dst_val = io::load_float_value(sum_dt, dst, dst_off); |
154 | args.ctx = &ctx; |
155 | args.l_offset = l_offset; |
156 | args.dst_md = pd()->dst_md(); |
157 | ref_post_ops->execute(d, args); |
158 | |
159 | if (with_dst_scales) d *= dst_scales[0]; |
160 | if (dst_zero_point) { |
161 | const int dst_zp = io::load_int_value( |
162 | data_type::s32, dst_zero_point, dst_zp_idx_mult * n); |
163 | d += dst_zp; |
164 | } |
165 | } |
166 | io::store_float_value(dst_d.data_type(), d, dst, dst_off); |
167 | utils::dim_iterator(dst_d.dims(), dst_dims_idx, batch_ndims); |
168 | }); |
169 | |
170 | return status::success; |
171 | } |
172 | |
173 | } // namespace matmul |
174 | } // namespace cpu |
175 | } // namespace impl |
176 | } // namespace dnnl |
177 | |