/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include <assert.h>
#include <float.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/gemm/gemm.hpp"

#include "cpu/binary_injector_utils.hpp"
#include "cpu/matmul/gemm_x8s8s32x_matmul.hpp"
#include "cpu/matmul/matmul_utils.hpp"
#include "cpu/scale_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace matmul {

using namespace data_type;

namespace {
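// Returns true when the s32 GEMM output still needs a separate
// post-processing pass: bias, a dst type other than s32, dst not used
// directly as the accumulator, non-default post-op attributes, or a
// (possibly runtime) dst zero point.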
template <typename pd_t>
bool need_post_processing(const pd_t *pd, float runtime_dst_zero_point = 0.f) {
    return pd->with_bias() || pd->dst_md()->data_type != s32
            || !pd->params().dst_is_acc_
            || !pd->params().pp_attr_.has_default_values()
            || !pd->params().pp_attr_.zero_points_.has_default_values(
                    DNNL_ARG_DST)
            || runtime_dst_zero_point != 0.f;
}
} // namespace

status_t gemm_x8s8s32x_matmul_t::pd_t::init(engine_t *engine) {
    using namespace utils;
    using namespace data_type;

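    // Scales may be set for src, weights and dst. src/dst scales must be
    // per-tensor (mask 0); weights scales may be per-tensor or per-N
    // (per output channel). Combining src scales with per-N weights scales
    // needs a scratchpad whose size depends on N, so that case is rejected
    // when N is a runtime dimension.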
    auto check_attr_scales = [&]() -> bool {
        using namespace data_type;
        const std::vector<int> supported_args
                = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
        bool ok = attr()->scales_.has_default_values(supported_args);
        for (int arg : supported_args) {
            const auto &mask = attr()->scales_.get(arg).mask_;
            if (arg == DNNL_ARG_WEIGHTS)
                ok = ok && (mask == 0 || mask == (1 << (dst_md()->ndims - 1)));
            else
                ok = ok && (mask == 0);
        }
        if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values()
                && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values()
                && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) {
            // This case requires a scratchpad with an unknown size
            if (N() == DNNL_RUNTIME_DIM_VAL) return false;
        }
        return ok;
    };

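    // Only per-tensor (common) zero points are supported.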
    auto check_attr_zero_points
            = [&]() -> bool { return attr()->zero_points_.common(); };

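    // Post-ops are validated against what the inner-product post-processing
    // kernel supports; binary post-ops with per-OC broadcast additionally
    // require memory formats compatible with the GEMM-based implementation.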
    auto check_attr_post_ops = [&]() -> bool {
        using namespace primitive_kind;
        const auto &post_ops = attr()->post_ops_;
        static const bcast_set_t enabled_bcast_strategy {
                broadcasting_strategy_t::scalar,
                broadcasting_strategy_t::per_oc,
                broadcasting_strategy_t::per_oc_spatial,
                broadcasting_strategy_t::per_mb_spatial,
                broadcasting_strategy_t::per_mb_w,
                broadcasting_strategy_t::per_w,
                broadcasting_strategy_t::no_broadcast};
        const bool is_binary_po_per_oc
                = binary_injector_utils::bcast_strategy_present(
                        binary_injector_utils::extract_bcast_strategies(
                                post_ops.entry_, dst_md()),
                        broadcasting_strategy_t::per_oc);
        return cpu::inner_product_utils::post_ops_ok(
                       post_ops, dst_md(), enabled_bcast_strategy)
                && IMPLICATION(is_binary_po_per_oc,
                        gemm_based::check_gemm_binary_per_oc_compatible_formats(
                                *this));
    };

    bool ok = !has_zero_dim_memory() && one_of(src_md()->data_type, s8, u8)
            && weights_md()->data_type == s8 && desc()->accum_data_type == s32
            && one_of(dst_md()->data_type, f32, s32, s8, u8)
            && IMPLICATION(with_bias(),
                    one_of(weights_md(1)->data_type, f32, s32, s8, u8)
                            && is_bias_1xN())
            && attr()->has_default_values(
                    primitive_attr_t::skip_mask_t::scales_runtime
                            | primitive_attr_t::skip_mask_t::zero_points_runtime
                            | primitive_attr_t::skip_mask_t::post_ops
                            | primitive_attr_t::skip_mask_t::sum_dt,
                    dst_md()->data_type)
            && attr_.post_ops_.check_sum_consistent_dt(dst_md()->data_type)
            // need to set up default formats first, so that later checks can
            // be performed properly
            && set_default_formats() && check_attr_scales()
            && check_attr_zero_points() && check_attr_post_ops()
            && gemm_based::check_gemm_compatible_formats(*this)
            && attr_.set_default_formats(dst_md(0)) == status::success;
    if (!ok) return status::unimplemented;

    // set states

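    // The post-processing attributes only need the dst-side quantization
    // parameters: src and weights zero points are compensated before the
    // pp kernel runs, so they are reset to zero here.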
    // copy attributes and drop src and weights zero points
    CHECK(params_.pp_attr_.copy_from(*attr()));
    params_.pp_attr_.zero_points_.set(DNNL_ARG_SRC, 0);
    params_.pp_attr_.zero_points_.set(DNNL_ARG_WEIGHTS, 0);

    params_.gemm_applies_output_scales_ = false;
    params_.gemm_beta_ = 0.f;

    bool do_sum = params_.pp_attr_.post_ops_.find(primitive_kind::sum) >= 0;
    params_.dst_is_acc_
            = utils::one_of(dst_md()->data_type, s32, f32) && !do_sum;

    params_.has_pp_kernel_ = need_post_processing(this);

    nthr_ = dnnl_get_max_threads();
    gemm_based::book_acc_scratchpad(*this, params_, sizeof(int32_t), nthr_);
    auto scratchpad = scratchpad_registry().registrar();
    book_precomputed_scales(scratchpad, attr()->scales_, N());

    return status::success;
}

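// Adds the src/weights zero-point compensation terms to the s32 accumulator
// when the integer GEMM itself could not apply the offsets:
//   (A - zp_A) * (B - zp_B)
//       = A*B - zp_B * sum_k A[m][k] - zp_A * sum_k B[k][n] + zp_A * zp_B * K
// src_comp (row sums of A) is needed only when the weights zero point is
// non-zero, and wei_comp (column sums of B) only when the src zero point is
// non-zero.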
void gemm_x8s8s32x_matmul_t::post_process_src_and_weights_zero_points(
        std::vector<int32_t> &src_comp, std::vector<int32_t> &wei_comp, dim_t M,
        dim_t N, dim_t K, const char *src, dim_t src_s0, dim_t src_s1,
        const int8_t *wei, dim_t wei_s0, dim_t wei_s1, int32_t *acc, int ldc,
        int32_t src_zero_point, int32_t wei_zero_point) const {
    if (wei_zero_point) {
        for_(dim_t m = 0; m < M; ++m)
        for (dim_t k = 0; k < K; ++k) {
            if (k == 0) src_comp[m] = int32_t(0);
            src_comp[m] += src[src_s0 * m + src_s1 * k];
        }
    }

    if (src_zero_point) {
        for_(dim_t k = 0; k < K; ++k)
        for (dim_t n = 0; n < N; ++n) {
            if (k == 0) wei_comp[n] = int32_t(0);
            wei_comp[n] += wei[wei_s0 * k + wei_s1 * n];
        }
    }

    for_(dim_t m = 0; m < M; ++m)
    for (dim_t n = 0; n < N; ++n)
        acc[m * ldc + n] += 0 - src_zero_point * wei_comp[n]
                - wei_zero_point * src_comp[m]
                + src_zero_point * wei_zero_point * (int)K;
}

status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
    using namespace binary_injector_utils;

    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto &po = this->pd()->attr()->post_ops_;
    const auto post_ops_binary_rhs_arg_vec = prepare_binary_args(po, ctx);

    const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
    const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
    const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());

    const int ndims = pd()->ndims();

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    auto &scratchpad = ctx.get_scratchpad_grantor();
    const float *scales = precompute_scales(scratchpad, src_scales, wei_scales,
            dst_d.dims()[ndims - 1], pd()->attr());

    DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINT_VALUE(weights_zero_point, DNNL_ARG_WEIGHTS);
    DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST);

    if (src_d.has_zero_dim() || weights_d.has_zero_dim()
            || dst_d.has_zero_dim())
        return status::success;

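    // gemm_s8x8s32 takes the src and weights zero points as 8-bit offset
    // arguments; the src offset type follows the src data type (s8 or u8).
    // If a runtime zero point does not fit into the corresponding 8-bit
    // type, zero offsets are passed to the GEMM and the compensation is
    // applied manually after the GEMM call.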
    int8_t gemm_off_a_int8 = static_cast<int8_t>(src_zero_point);
    uint8_t gemm_off_a_uint8 = static_cast<uint8_t>(src_zero_point);
    int8_t gemm_off_b = static_cast<int8_t>(weights_zero_point);
    const bool ok = IMPLICATION(src_d.data_type() == data_type::s8,
                            gemm_off_a_int8 == src_zero_point)
            && IMPLICATION(src_d.data_type() == data_type::u8,
                    gemm_off_a_uint8 == src_zero_point)
            && gemm_off_b == weights_zero_point;
    const bool post_process_src_and_weights_zero_points_outside_of_gemm = !ok;
    if (post_process_src_and_weights_zero_points_outside_of_gemm) {
        gemm_off_a_int8 = gemm_off_a_uint8 = gemm_off_b = 0;
    }
    const float dst_zero_point_f32 = static_cast<float>(dst_zero_point);

    matmul_helper_t helper(src_d, weights_d, dst_d);
    const int batch_ndims = ndims - 2;
    dim_t M = helper.M();
    const dim_t N = helper.N();
    const dim_t K = helper.K();
    const dim_t batch = helper.batch();
    const dim_t batch_without_dim0
            = helper.ndims() > 3 ? batch / dst_d.dims()[0] : 0;
    const dim_t batch_without_dim01
            = helper.ndims() > 4 ? batch_without_dim0 / dst_d.dims()[1] : 1;
    const char transA = helper.transA();
    const char transB = helper.transB();
    const dim_t lda = helper.lda();
    const dim_t ldb = helper.ldb();
    const dim_t ldc = helper.ldc();
    const int ldx_dim_idx = pd()->ndims() - 2;
    const dim_t *src_strides = &src_d.blocking_desc().strides[ldx_dim_idx];
    const dim_t *weights_strides
            = &weights_d.blocking_desc().strides[ldx_dim_idx];
    const int nthr = pd()->nthr_;

    const gemm_based::params_t &params = pd()->params();
    const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
            ? helper.can_fuse_src_batch_dims()
            : params.can_fuse_src_batch_dims_;
    const dim_t acc_stride = gemm_based::get_scratchpad_size(
            batch, M, N, can_fuse_src_batch_dims, nthr);
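    // Pick the s32 accumulation buffer: dst itself when it is already
    // s32/f32 and no sum post-op needs the previous dst values, otherwise
    // the booked scratchpad. With runtime dims the scratchpad may not have
    // been booked, so fall back to an explicitly allocated (and later freed)
    // buffer; the two-argument malloc/free used below are oneDNN's aligned
    // allocator wrappers, not the libc functions.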
    bool dst_is_acc = params.dst_is_acc_;
    int32_t *acc = dst_is_acc
            ? reinterpret_cast<int32_t *>(dst)
            : ctx.get_scratchpad_grantor().template get<int32_t>(
                    memory_tracking::names::key_matmul_dst_in_acc_dt);
    // case: dynamic sizes
    bool need_free_acc = false;
    if (acc == nullptr) {
        acc = (int32_t *)malloc(sizeof(int32_t) * acc_stride
                        * ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
                64);

        if (acc == nullptr) return status::out_of_memory;
        need_free_acc = true;
    }

    const float alpha = params.get_gemm_alpha(scales);
    const float beta = params.gemm_beta_;
    const dim_t acc_ldc = dst_is_acc ? ldc : N;
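    // scale_idx_mult is 1 when weights scales are per-N (mask selects the
    // last dim), so the pp kernel strides through the precomputed scales per
    // output channel; it is 0 when a single scale is broadcast.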
    const int scale_idx_mult
            = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_
            == (1 << (ndims - 1));

    std::atomic<status_t> st(status::success);
    // use parallel over batch when a binary post-op has channel broadcast
    // (except batch == 1)
    bool is_binary_po_per_oc = false;
    bool is_binary_po_per_oc_sp = false;
    bool is_binary_po_channel_bcast = false;
    std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
            is_binary_po_channel_bcast)
            = bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
                    broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::per_mb_spatial);
    // if batched, parallelize over batch for per_mb_spatial and per_oc binary
    // post-op broadcast
    const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
            && IMPLICATION(
                    is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
    const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
    if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
        const int src_mask
                = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
        const int wei_mask
                = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims);
        const size_t bia_dt_size = !pd()->with_bias()
                ? 0
                : types::data_type_size(pd()->weights_md(1)->data_type);
        const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
        const size_t work_amount = (size_t)batch * M * N;
        const size_t work_per_batch = (size_t)M * N;

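        // The flattened (batch, M, N) work is split across threads; each
        // thread walks its chunk and issues one GEMM per whole batch, per
        // block of rows, or per partial row, depending on how much of the
        // chunk remains (see the gemm_M/gemm_N selection below).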
        // NOTE: inside the lambda, type-cast variables captured by reference
        // using either the C-like "(type)var" or functional "type(var)"
        // notation in order to avoid a gcc bug with the C++14 standard.
        // Otherwise, capture by value.
        parallel(nthr, [=, &st](int ithr, int nthr) {
            size_t t_work_start {0}, t_work_end {0};
            balance211(work_amount, nthr, ithr, t_work_start, t_work_end);

            dim_t cur_b {0}, cur_m {0}, cur_n {0};
            dims_t s_dims_idx, w_dims_idx, d_dims_idx;
            size_t i_work = t_work_start;

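            // reuse_acc: the accumulator is a separate buffer, so each thread
            // works on its own acc_stride-sized slice; otherwise dst itself
            // is the accumulator and curr_acc is pointed at the current dst
            // offset inside the loop below.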
            const bool reuse_acc = acc != (int32_t *)dst;
            int32_t *curr_acc = reuse_acc ? acc + ithr * acc_stride : nullptr;

            std::vector<int32_t> src_compensation(M, 0);
            std::vector<int32_t> weights_compensation(N, 0);

            // icc 17.0 has a bug with capturing const variables with value
            // known at compilation time in lambdas
            const int32_t gemm_off_c = 0;

            while (i_work < t_work_end) {
                utils::nd_iterator_init(
                        i_work, cur_b, batch, cur_m, M, cur_n, N);

                utils::l_dims_by_l_offset(
                        d_dims_idx, i_work, dst_d.dims(), ndims);

                utils::copy_dims_with_mask(
                        s_dims_idx, d_dims_idx, batch_ndims, src_mask);
                s_dims_idx[ndims - 2] = cur_m;
                s_dims_idx[ndims - 1] = 0; // k idx is always 0

                utils::copy_dims_with_mask(
                        w_dims_idx, d_dims_idx, batch_ndims, wei_mask);
                w_dims_idx[ndims - 2] = 0; // k idx is always 0
                w_dims_idx[ndims - 1] = cur_n;

                const char *curr_src = src + src_d.off_v(s_dims_idx);
                const int8_t *curr_weights
                        = weights + weights_d.off_v(w_dims_idx);
                const dim_t dst_off = dst_d.off_v(d_dims_idx);
                char *curr_dst = dst + dst_dt_size * dst_off;
                if (!reuse_acc) curr_acc = acc + dst_off;

                dim_t gemm_M {0}, gemm_N {0};
                size_t matrix_offset;
                const size_t rem_work = t_work_end - i_work;
                if (rem_work >= work_per_batch && cur_m == 0 && cur_n == 0) {
                    // parallel over batch
                    gemm_M = M;
                    gemm_N = N;
                    matrix_offset = 0;
                } else if (rem_work >= (size_t)N && cur_n == 0) {
                    // parallel over M
                    gemm_M = nstl::min(
                            (size_t)(M - cur_m), (size_t)(rem_work / N));
                    gemm_N = N;
                    matrix_offset = cur_n + cur_m * N;
                } else {
                    // parallel over N
                    gemm_M = 1;
                    gemm_N = nstl::min((size_t)(N - cur_n), rem_work);
                    matrix_offset = cur_n + cur_m * N;
                }

                status_t st_thr = status::runtime_error;
                switch (src_d.data_type()) {
                    case data_type::s8: {
                        const int8_t *curr_src_
                                = reinterpret_cast<const int8_t *>(curr_src);
                        st_thr = gemm_s8x8s32(&transB, &transA, "F", &gemm_N,
                                &gemm_M, &K, &alpha, curr_weights, &ldb,
                                &gemm_off_b, curr_src_, &lda, &gemm_off_a_int8,
                                &beta, curr_acc, &acc_ldc, &gemm_off_c);
                    } break;
                    case data_type::u8: {
                        const uint8_t *curr_src_
                                = reinterpret_cast<const uint8_t *>(curr_src);
                        st_thr = gemm_s8x8s32(&transB, &transA, "F", &gemm_N,
                                &gemm_M, &K, &alpha, curr_weights, &ldb,
                                &gemm_off_b, curr_src_, &lda, &gemm_off_a_uint8,
                                &beta, curr_acc, &acc_ldc, &gemm_off_c);
                    } break;
                    default: assert(!"unsupported data type"); break;
                }

                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                // if igemm cannot handle src and weights zero points
                if (post_process_src_and_weights_zero_points_outside_of_gemm) {
                    post_process_src_and_weights_zero_points(src_compensation,
                            weights_compensation, gemm_M, gemm_N, K, curr_src,
                            src_strides[0], src_strides[1], curr_weights,
                            weights_strides[0], weights_strides[1], curr_acc,
                            acc_ldc, src_zero_point, weights_zero_point);
                }

                bool postops_in_matmul
                        = need_post_processing(pd(), dst_zero_point_f32);
                assert(IMPLICATION(postops_in_matmul, params.has_pp_kernel_));

                if (postops_in_matmul) {
                    const size_t dst_logical_off = i_work;
                    const size_t dim1_off = helper.ndims() > 3
                            ? ((cur_b % batch_without_dim0)
                                    / batch_without_dim01)
                            : cur_m;
                    // offset for case with post-op broadcast_channel
                    const size_t matrix_per_first_batch_off = helper.ndims() > 3
                            ? M * N * (cur_b / batch_without_dim0)
                                    + matrix_offset
                            : 0;
                    const ptrdiff_t oc_off = i_work % N;
                    (*pp_kernel_)(curr_dst, curr_acc,
                            bias + oc_off * bia_dt_size,
                            scales + oc_off * scale_idx_mult, dst_scales[0], 0,
                            dst_logical_off, dim1_off, gemm_M * gemm_N,
                            static_cast<size_t>(N), ldc, &dst_zero_point_f32,
                            post_ops_binary_rhs_arg_vec.data(), dst,
                            matrix_per_first_batch_off, ctx, *pd()->dst_md());
                }
                i_work += gemm_M * gemm_N;
            }
        });
    } else {
        // icc 17.0 has a bug with capturing const variables with value known
        // at compilation time in lambdas
        const int32_t gemm_off_c = 0;

        // collapse batch into M if the weights batch dimensions are broadcast
        M = batch * M;
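        // A single GEMM call now covers all (fused) batches; any
        // post-processing below runs over the flat M x N output.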
        st = status::runtime_error;
        switch (src_d.data_type()) {
            case data_type::s8: {
                const int8_t *src_ = reinterpret_cast<const int8_t *>(src);
                st = gemm_s8x8s32(&transB, &transA, "F", &N, &M, &K, &alpha,
                        weights, &ldb, &gemm_off_b, src_, &lda,
                        &gemm_off_a_int8, &beta, acc, &acc_ldc, &gemm_off_c);
            } break;
            case data_type::u8: {
                const uint8_t *src_ = reinterpret_cast<const uint8_t *>(src);
                st = gemm_s8x8s32(&transB, &transA, "F", &N, &M, &K, &alpha,
                        weights, &ldb, &gemm_off_b, src_, &lda,
                        &gemm_off_a_uint8, &beta, acc, &acc_ldc, &gemm_off_c);
            } break;
            default: assert(!"unsupported data type"); break;
        }

        if (st == status::success) {
            std::vector<int32_t> src_compensation(M, 0);
            std::vector<int32_t> weights_compensation(N, 0);

            // if igemm cannot handle src and weights zero points
            if (post_process_src_and_weights_zero_points_outside_of_gemm) {
                post_process_src_and_weights_zero_points(src_compensation,
                        weights_compensation, M, N, K, src, src_strides[0],
                        src_strides[1], weights, weights_strides[0],
                        weights_strides[1], acc, acc_ldc, src_zero_point,
                        weights_zero_point);
            }

            bool postops_in_matmul
                    = need_post_processing(pd(), dst_zero_point_f32);
            assert(IMPLICATION(postops_in_matmul, params.has_pp_kernel_));

            if (postops_in_matmul) {
                const bool force_sequential = pp_kernel_->sequential_kernel();
                parallel(force_sequential ? 1 : nthr, [&](int ithr, int nthr) {
                    size_t start {}, end {};
                    balance211((size_t)(M * N), nthr, ithr, start, end);
                    const size_t dst_logical_off = start;
                    const size_t dim1_off = start % N;
                    (*pp_kernel_)(dst, acc, bias, scales, dst_scales[0], start,
                            dst_logical_off, dim1_off, end, (size_t)N, ldc,
                            &dst_zero_point_f32,
                            post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx,
                            *pd()->dst_md());
                });
            }
        }
    }
    if (need_free_acc) free(acc);

    return st;
}

} // namespace matmul
} // namespace cpu
} // namespace impl
} // namespace dnnl