1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <float.h>
19#include <math.h>
20
21#include "common/c_types_map.hpp"
22#include "common/dnnl_thread.hpp"
23#include "common/math_utils.hpp"
24#include "common/type_helpers.hpp"
25
26#include "cpu/cpu_primitive.hpp"
27#include "cpu/ref_io_helper.hpp"
28#include "cpu/simple_q10n.hpp"
29
30#include "cpu/matmul/matmul_utils.hpp"
31#include "cpu/matmul/ref_matmul_int8.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36namespace matmul {
37
38status_t ref_matmul_int8_t::execute_ref(const exec_ctx_t &ctx) const {
39 status_t status = status::success;
40 const auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
41 const auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
42 const auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
43 auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
44 CHECK(status);
45
46 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
47 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
48 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
49
50 DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
51 DEFINE_ZERO_POINT_VALUE(weights_zero_point, DNNL_ARG_WEIGHTS);
52 DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
53
54 const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
55 const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
56 const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());
57 const auto bia_d = ctx.memory_mdw(DNNL_ARG_BIAS, pd()->weights_md(1));
58
59 if (src_d.has_zero_dim() || weights_d.has_zero_dim()
60 || dst_d.has_zero_dim())
61 return status::success;
62
63 const bool non_default_attrs = !pd()->attr()->has_default_values();
64
65 matmul_helper_t helper(src_d, weights_d, dst_d);
66 const int ndims = pd()->ndims();
67 const int batch_ndims = ndims - 2;
68 const dim_t M = helper.M();
69 const dim_t N = helper.N();
70 const dim_t K = helper.K();
71 const dim_t batch = helper.batch();
72
73 const int src_mask
74 = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
75 const int wei_mask
76 = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims);
77 const int bia_mask
78 = utils::get_dims_mask(dst_d.dims(), bia_d.dims(), ndims);
79
80 // zp_idx_mult = 1 for per_dim1 zero points and 0, otherwise
81 const int src_zp_idx_mult
82 = !pd()->attr()->zero_points_.common(DNNL_ARG_SRC);
83 const int dst_zp_idx_mult
84 = !pd()->attr()->zero_points_.common(DNNL_ARG_DST);
85
86 // mm kernel
87 auto ker = [&](const dims_t dst_dims_idx, dim_t m, dim_t n) {
88 int acc = 0;
89 dims_t src_dims_idx, weights_dims_idx;
90 utils::copy_dims_with_mask(src_dims_idx, dst_dims_idx, ndims, src_mask);
91 utils::copy_dims_with_mask(
92 weights_dims_idx, dst_dims_idx, ndims, wei_mask);
93 src_dims_idx[ndims - 2] = m;
94 weights_dims_idx[ndims - 1] = n;
95 auto &src_k_dim = src_dims_idx[ndims - 1];
96 auto &wei_k_dim = weights_dims_idx[ndims - 2];
97 for (dim_t k = 0; k < K; ++k) {
98 src_k_dim = k;
99 wei_k_dim = k;
100 const auto src_off = src_d.off_v(src_dims_idx);
101 const auto weights_off = weights_d.off_v(weights_dims_idx);
102 int s = io::load_int_value(src_d.data_type(), src, src_off);
103 int w = io::load_int_value(
104 weights_d.data_type(), weights, weights_off);
105 if (src_zero_point) {
106 const int src_zp = io::load_int_value(
107 data_type::s32, src_zero_point, src_zp_idx_mult * k);
108 s -= src_zp;
109 }
110 if (weights_zero_point) { w -= weights_zero_point; }
111 acc += s * w;
112 }
113 return acc;
114 };
115
116 // bias section
117 auto ker_bias = [&](const dims_t &dst_dims_idx) -> float {
118 dims_t bia_dims_idx;
119 utils::copy_dims_with_mask(bia_dims_idx, dst_dims_idx, ndims, bia_mask);
120 const auto bias_off = bia_d.off_v(bia_dims_idx);
121 return io::load_float_value(bia_d.data_type(), bias, bias_off);
122 };
123
124 // arg scales section
125 const auto &attr_scales = pd()->attr()->scales_;
126 const bool with_src_scales
127 = !attr_scales.get(DNNL_ARG_SRC).has_default_values();
128 const bool with_wei_scales
129 = !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values();
130 const bool with_dst_scales
131 = !attr_scales.get(DNNL_ARG_DST).has_default_values();
132 const dim_t wei_scale_stride
133 = attr_scales.get(DNNL_ARG_WEIGHTS).mask_ == 0 ? 0 : 1;
134
135 auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type());
136
137 // computations
138 parallel_nd(batch, M, N, [&](dim_t mb, dim_t m, dim_t n) {
139 dims_t dst_dims_idx;
140 // account for M, N dims for index calculations
141 const size_t l_offset = mb * M * N + m * N + n;
142 utils::l_dims_by_l_offset(dst_dims_idx, l_offset, dst_d.dims(), ndims);
143 int acc = ker(dst_dims_idx, m, n);
144 float d = static_cast<int>(acc);
145 if (with_src_scales) d *= src_scales[0];
146 if (with_wei_scales) d *= wei_scales[wei_scale_stride * n];
147 if (bias) d += ker_bias(dst_dims_idx);
148
149 const auto dst_off = dst_d.off_v(dst_dims_idx);
150 if (non_default_attrs) {
151
152 ref_post_ops_t::args_t args;
153 args.dst_val = io::load_float_value(sum_dt, dst, dst_off);
154 args.ctx = &ctx;
155 args.l_offset = l_offset;
156 args.dst_md = pd()->dst_md();
157 ref_post_ops->execute(d, args);
158
159 if (with_dst_scales) d *= dst_scales[0];
160 if (dst_zero_point) {
161 const int dst_zp = io::load_int_value(
162 data_type::s32, dst_zero_point, dst_zp_idx_mult * n);
163 d += dst_zp;
164 }
165 }
166 io::store_float_value(dst_d.data_type(), d, dst, dst_off);
167 utils::dim_iterator(dst_d.dims(), dst_dims_idx, batch_ndims);
168 });
169
170 return status::success;
171}
172
173} // namespace matmul
174} // namespace cpu
175} // namespace impl
176} // namespace dnnl
177