1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <atomic> |
18 | |
19 | #include <assert.h> |
20 | #include <float.h> |
21 | #include <math.h> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/dnnl_thread.hpp" |
25 | #include "common/memory_tracking.hpp" |
26 | #include "common/type_helpers.hpp" |
27 | #include "common/utils.hpp" |
28 | |
29 | #include "cpu/cpu_primitive.hpp" |
30 | |
31 | #include "cpu/gemm/gemm.hpp" |
32 | |
33 | #include "cpu/binary_injector_utils.hpp" |
34 | #include "cpu/matmul/gemm_x8s8s32x_matmul.hpp" |
35 | #include "cpu/matmul/matmul_utils.hpp" |
36 | #include "cpu/scale_utils.hpp" |
37 | |
38 | namespace dnnl { |
39 | namespace impl { |
40 | namespace cpu { |
41 | namespace matmul { |
42 | |
43 | using namespace data_type; |
44 | |
45 | namespace { |
46 | template <typename pd_t> |
47 | bool need_post_processing(const pd_t *pd, float runtime_dst_zero_point = 0.f) { |
48 | return pd->with_bias() || pd->dst_md()->data_type != s32 |
49 | || !pd->params().dst_is_acc_ |
50 | || !pd->params().pp_attr_.has_default_values() |
51 | || !pd->params().pp_attr_.zero_points_.has_default_values( |
52 | DNNL_ARG_DST) |
53 | || runtime_dst_zero_point != 0.f; |
54 | } |
55 | } // namespace |
56 | |
57 | status_t gemm_x8s8s32x_matmul_t::pd_t::init(engine_t *engine) { |
58 | using namespace utils; |
59 | using namespace data_type; |
60 | |
61 | auto check_attr_scales = [&]() -> bool { |
62 | using namespace data_type; |
63 | const std::vector<int> supported_args |
64 | = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; |
65 | bool ok = attr()->scales_.has_default_values(supported_args); |
66 | for (int arg : supported_args) { |
67 | const auto &mask = attr()->scales_.get(arg).mask_; |
68 | if (arg == DNNL_ARG_WEIGHTS) |
69 | ok = ok && (mask == 0 || mask == (dst_md()->ndims - 1)); |
70 | else |
71 | ok = ok && (mask == 0); |
72 | } |
73 | if (!attr()->scales_.get(DNNL_ARG_SRC).has_default_values() |
74 | && !attr()->scales_.get(DNNL_ARG_WEIGHTS).has_default_values() |
75 | && attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ != 0) { |
76 | // This case requires scratchpad with unknown size |
77 | if (N() == DNNL_RUNTIME_DIM_VAL) return false; |
78 | } |
79 | return ok; |
80 | }; |
81 | |
82 | auto check_attr_zero_points |
83 | = [&]() -> bool { return attr()->zero_points_.common(); }; |
84 | |
85 | auto check_attr_post_ops = [&]() -> bool { |
86 | using namespace primitive_kind; |
87 | const auto &post_ops = attr()->post_ops_; |
88 | static const bcast_set_t enabled_bcast_strategy { |
89 | broadcasting_strategy_t::scalar, |
90 | broadcasting_strategy_t::per_oc, |
91 | broadcasting_strategy_t::per_oc_spatial, |
92 | broadcasting_strategy_t::per_mb_spatial, |
93 | broadcasting_strategy_t::per_mb_w, |
94 | broadcasting_strategy_t::per_w, |
95 | broadcasting_strategy_t::no_broadcast}; |
96 | const bool is_binary_po_per_oc |
97 | = binary_injector_utils::bcast_strategy_present( |
98 | binary_injector_utils::extract_bcast_strategies( |
99 | post_ops.entry_, dst_md()), |
100 | broadcasting_strategy_t::per_oc); |
101 | return cpu::inner_product_utils::post_ops_ok( |
102 | post_ops, dst_md(), enabled_bcast_strategy) |
103 | && IMPLICATION(is_binary_po_per_oc, |
104 | gemm_based::check_gemm_binary_per_oc_compatible_formats( |
105 | *this)); |
106 | }; |
107 | |
108 | bool ok = !has_zero_dim_memory() && one_of(src_md()->data_type, s8, u8) |
109 | && weights_md()->data_type == s8 && desc()->accum_data_type == s32 |
110 | && one_of(dst_md()->data_type, f32, s32, s8, u8) |
111 | && IMPLICATION(with_bias(), |
112 | one_of(weights_md(1)->data_type, f32, s32, s8, u8) |
113 | && is_bias_1xN()) |
114 | && attr()->has_default_values( |
115 | primitive_attr_t::skip_mask_t::scales_runtime |
116 | | primitive_attr_t::skip_mask_t::zero_points_runtime |
117 | | primitive_attr_t::skip_mask_t::post_ops |
118 | | primitive_attr_t::skip_mask_t::sum_dt, |
119 | dst_md()->data_type) |
120 | && attr_.post_ops_.check_sum_consistent_dt(dst_md()->data_type) |
121 | // need to set up default formats first, so that latter checks can |
122 | // be perfomed properly |
123 | && set_default_formats() && check_attr_scales() |
124 | && check_attr_zero_points() && check_attr_post_ops() |
125 | && gemm_based::check_gemm_compatible_formats(*this) |
126 | && attr_.set_default_formats(dst_md(0)) == status::success; |
127 | if (!ok) return status::unimplemented; |
128 | |
129 | // set states |
130 | |
131 | // copy attributes and drop src and weights zero points |
132 | CHECK(params_.pp_attr_.copy_from(*attr())); |
133 | params_.pp_attr_.zero_points_.set(DNNL_ARG_SRC, 0); |
134 | params_.pp_attr_.zero_points_.set(DNNL_ARG_WEIGHTS, 0); |
135 | |
136 | params_.gemm_applies_output_scales_ = false; |
137 | params_.gemm_beta_ = 0.f; |
138 | |
139 | bool do_sum = params_.pp_attr_.post_ops_.find(primitive_kind::sum) >= 0; |
140 | params_.dst_is_acc_ |
141 | = utils::one_of(dst_md()->data_type, s32, f32) && !do_sum; |
142 | |
143 | params_.has_pp_kernel_ = need_post_processing(this); |
144 | |
145 | nthr_ = dnnl_get_max_threads(); |
146 | gemm_based::book_acc_scratchpad(*this, params_, sizeof(int32_t), nthr_); |
147 | auto scratchpad = scratchpad_registry().registrar(); |
148 | book_precomputed_scales(scratchpad, attr()->scales_, N()); |
149 | |
150 | return status::success; |
151 | } |
152 | |
153 | void gemm_x8s8s32x_matmul_t::post_process_src_and_weights_zero_points( |
154 | std::vector<int32_t> &src_comp, std::vector<int32_t> &wei_comp, dim_t M, |
155 | dim_t N, dim_t K, const char *src, dim_t src_s0, dim_t src_s1, |
156 | const int8_t *wei, dim_t wei_s0, dim_t wei_s1, int32_t *acc, int ldc, |
157 | int32_t src_zero_point, int32_t wei_zero_point) const { |
158 | if (wei_zero_point) { |
159 | for_(dim_t m = 0; m < M; ++m) |
160 | for (dim_t k = 0; k < K; ++k) { |
161 | if (k == 0) src_comp[m] = int32_t(0); |
162 | src_comp[m] += src[src_s0 * m + src_s1 * k]; |
163 | } |
164 | } |
165 | |
166 | if (src_zero_point) { |
167 | for_(dim_t k = 0; k < K; ++k) |
168 | for (dim_t n = 0; n < N; ++n) { |
169 | if (k == 0) wei_comp[n] = int32_t(0); |
170 | wei_comp[n] += wei[wei_s0 * k + wei_s1 * n]; |
171 | } |
172 | } |
173 | |
174 | for_(dim_t m = 0; m < M; ++m) |
175 | for (dim_t n = 0; n < N; ++n) |
176 | acc[m * ldc + n] += 0 - src_zero_point * wei_comp[n] |
177 | - wei_zero_point * src_comp[m] |
178 | + src_zero_point * wei_zero_point * (int)K; |
179 | } |
180 | |
181 | status_t gemm_x8s8s32x_matmul_t::execute_ref(const exec_ctx_t &ctx) const { |
182 | using namespace binary_injector_utils; |
183 | |
184 | auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC); |
185 | auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS); |
186 | auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
187 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
188 | const auto &po = this->pd()->attr()->post_ops_; |
189 | const auto post_ops_binary_rhs_arg_vec = prepare_binary_args(po, ctx); |
190 | |
191 | const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md()); |
192 | const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md()); |
193 | const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md()); |
194 | |
195 | const int ndims = pd()->ndims(); |
196 | |
197 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
198 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
199 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
200 | |
201 | auto &scratchpad = ctx.get_scratchpad_grantor(); |
202 | const float *scales = precompute_scales(scratchpad, src_scales, wei_scales, |
203 | dst_d.dims()[ndims - 1], pd()->attr()); |
204 | |
205 | DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); |
206 | DEFINE_ZERO_POINT_VALUE(weights_zero_point, DNNL_ARG_WEIGHTS); |
207 | DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST); |
208 | |
209 | if (src_d.has_zero_dim() || weights_d.has_zero_dim() |
210 | || dst_d.has_zero_dim()) |
211 | return status::success; |
212 | |
213 | int8_t gemm_off_a_int8 = static_cast<int8_t>(src_zero_point); |
214 | uint8_t gemm_off_a_uint8 = static_cast<uint8_t>(src_zero_point); |
215 | int8_t gemm_off_b = static_cast<int8_t>(weights_zero_point); |
216 | const bool ok = IMPLICATION(src_d.data_type() == data_type::s8, |
217 | gemm_off_a_int8 == src_zero_point) |
218 | && IMPLICATION(src_d.data_type() == data_type::u8, |
219 | gemm_off_a_uint8 == src_zero_point) |
220 | && gemm_off_b == weights_zero_point; |
221 | const bool post_process_src_and_weights_zero_points_outside_of_gemm = !ok; |
222 | if (post_process_src_and_weights_zero_points_outside_of_gemm) { |
223 | gemm_off_a_int8 = gemm_off_a_uint8 = gemm_off_b = 0; |
224 | } |
225 | const float dst_zero_point_f32 = static_cast<float>(dst_zero_point); |
226 | |
227 | matmul_helper_t helper(src_d, weights_d, dst_d); |
228 | const int batch_ndims = ndims - 2; |
229 | dim_t M = helper.M(); |
230 | const dim_t N = helper.N(); |
231 | const dim_t K = helper.K(); |
232 | const dim_t batch = helper.batch(); |
233 | const dim_t batch_without_dim0 |
234 | = helper.ndims() > 3 ? batch / dst_d.dims()[0] : 0; |
235 | const dim_t batch_without_dim01 |
236 | = helper.ndims() > 4 ? batch_without_dim0 / dst_d.dims()[1] : 1; |
237 | const char transA = helper.transA(); |
238 | const char transB = helper.transB(); |
239 | const dim_t lda = helper.lda(); |
240 | const dim_t ldb = helper.ldb(); |
241 | const dim_t ldc = helper.ldc(); |
242 | const int ldx_dim_idx = pd()->ndims() - 2; |
243 | const dim_t *src_strides = &src_d.blocking_desc().strides[ldx_dim_idx]; |
244 | const dim_t *weights_strides |
245 | = &weights_d.blocking_desc().strides[ldx_dim_idx]; |
246 | const int nthr = pd()->nthr_; |
247 | |
248 | const gemm_based::params_t ¶ms = pd()->params(); |
249 | const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides() |
250 | ? helper.can_fuse_src_batch_dims() |
251 | : params.can_fuse_src_batch_dims_; |
252 | const dim_t acc_stride = gemm_based::get_scratchpad_size( |
253 | batch, M, N, can_fuse_src_batch_dims, nthr); |
254 | bool dst_is_acc = params.dst_is_acc_; |
255 | int32_t *acc = dst_is_acc |
256 | ? reinterpret_cast<int32_t *>(dst) |
257 | : ctx.get_scratchpad_grantor().template get<int32_t>( |
258 | memory_tracking::names::key_matmul_dst_in_acc_dt); |
259 | // case: dynamic sizes |
260 | bool need_free_acc = false; |
261 | if (acc == nullptr) { |
262 | acc = (int32_t *)malloc(sizeof(int32_t) * acc_stride |
263 | * ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr), |
264 | 64); |
265 | |
266 | if (acc == nullptr) return status::out_of_memory; |
267 | need_free_acc = true; |
268 | } |
269 | |
270 | const float alpha = params.get_gemm_alpha(scales); |
271 | const float beta = params.gemm_beta_; |
272 | const dim_t acc_ldc = dst_is_acc ? ldc : N; |
273 | const int scale_idx_mult |
274 | = this->pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_ |
275 | == (1 << (ndims - 1)); |
276 | |
277 | std::atomic<status_t> st(status::success); |
278 | // use parallel over batch when binary po with channel bcast |
279 | // (except batch == 1) |
280 | bool is_binary_po_per_oc = false; |
281 | bool is_binary_po_per_oc_sp = false; |
282 | bool is_binary_po_channel_bcast = false; |
283 | std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp, |
284 | is_binary_po_channel_bcast) |
285 | = bcast_strategies_present_tup(po.entry_, pd()->dst_md(), |
286 | broadcasting_strategy_t::per_oc, |
287 | broadcasting_strategy_t::per_oc_spatial, |
288 | broadcasting_strategy_t::per_mb_spatial); |
289 | // if batched, parralel over batch for per_mb_sp and per_oc binary |
290 | // post-op broadcast |
291 | const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast |
292 | && IMPLICATION( |
293 | is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2); |
294 | const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims; |
295 | if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) { |
296 | const int src_mask |
297 | = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims); |
298 | const int wei_mask |
299 | = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims); |
300 | const size_t bia_dt_size = !pd()->with_bias() |
301 | ? 0 |
302 | : types::data_type_size(pd()->weights_md(1)->data_type); |
303 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
304 | const size_t work_amount = (size_t)batch * M * N; |
305 | const size_t work_per_batch = (size_t)M * N; |
306 | |
307 | // NOTE: inside lambda, type cast variables captured by reference using |
308 | // either c-like "(type)var" or functional "type(var)" notation in order |
309 | // to avoid gcc bug with c++14 standard. Otherwise, capture by value. |
310 | parallel(nthr, [=, &st](int ithr, int nthr) { |
311 | size_t t_work_start {0}, t_work_end {0}; |
312 | balance211(work_amount, nthr, ithr, t_work_start, t_work_end); |
313 | |
314 | dim_t cur_b {0}, cur_m {0}, cur_n {0}; |
315 | dims_t s_dims_idx, w_dims_idx, d_dims_idx; |
316 | size_t i_work = t_work_start; |
317 | |
318 | const bool reuse_acc = acc != (int32_t *)dst; |
319 | int32_t *curr_acc = reuse_acc ? acc + ithr * acc_stride : nullptr; |
320 | |
321 | std::vector<int32_t> src_compensation(M, 0); |
322 | std::vector<int32_t> weights_compensation(N, 0); |
323 | |
324 | // icc 17.0 has a bug with capturing const variables with value known |
325 | // at compilation time in lambdas |
326 | const int32_t gemm_off_c = 0; |
327 | |
328 | while (i_work < t_work_end) { |
329 | utils::nd_iterator_init( |
330 | i_work, cur_b, batch, cur_m, M, cur_n, N); |
331 | |
332 | utils::l_dims_by_l_offset( |
333 | d_dims_idx, i_work, dst_d.dims(), ndims); |
334 | |
335 | utils::copy_dims_with_mask( |
336 | s_dims_idx, d_dims_idx, batch_ndims, src_mask); |
337 | s_dims_idx[ndims - 2] = cur_m; |
338 | s_dims_idx[ndims - 1] = 0; // k idx is always 0 |
339 | |
340 | utils::copy_dims_with_mask( |
341 | w_dims_idx, d_dims_idx, batch_ndims, wei_mask); |
342 | w_dims_idx[ndims - 2] = 0; // k idx is always 0 |
343 | w_dims_idx[ndims - 1] = cur_n; |
344 | |
345 | const char *curr_src = src + src_d.off_v(s_dims_idx); |
346 | const int8_t *curr_weights |
347 | = weights + weights_d.off_v(w_dims_idx); |
348 | const dim_t dst_off = dst_d.off_v(d_dims_idx); |
349 | char *curr_dst = dst + dst_dt_size * dst_off; |
350 | if (!reuse_acc) curr_acc = acc + dst_off; |
351 | |
352 | dim_t gemm_M {0}, gemm_N {0}; |
353 | size_t matrix_offset; |
354 | const size_t rem_work = t_work_end - i_work; |
355 | if (rem_work >= work_per_batch && cur_m == 0 && cur_n == 0) { |
356 | // parallel over batch |
357 | gemm_M = M; |
358 | gemm_N = N; |
359 | matrix_offset = 0; |
360 | } else if (rem_work >= (size_t)N && cur_n == 0) { |
361 | // parallel over M |
362 | gemm_M = nstl::min( |
363 | (size_t)(M - cur_m), (size_t)(rem_work / N)); |
364 | gemm_N = N; |
365 | matrix_offset = cur_n + cur_m * N; |
366 | } else { |
367 | // parallel over N |
368 | gemm_M = 1; |
369 | gemm_N = nstl::min((size_t)(N - cur_n), rem_work); |
370 | matrix_offset = cur_n + cur_m * N; |
371 | } |
372 | |
373 | status_t st_thr = status::runtime_error; |
374 | switch (src_d.data_type()) { |
375 | case data_type::s8: { |
376 | const int8_t *curr_src_ |
377 | = reinterpret_cast<const int8_t *>(curr_src); |
378 | st_thr = gemm_s8x8s32(&transB, &transA, "F" , &gemm_N, |
379 | &gemm_M, &K, &alpha, curr_weights, &ldb, |
380 | &gemm_off_b, curr_src_, &lda, &gemm_off_a_int8, |
381 | &beta, curr_acc, &acc_ldc, &gemm_off_c); |
382 | } break; |
383 | case data_type::u8: { |
384 | const uint8_t *curr_src_ |
385 | = reinterpret_cast<const uint8_t *>(curr_src); |
386 | st_thr = gemm_s8x8s32(&transB, &transA, "F" , &gemm_N, |
387 | &gemm_M, &K, &alpha, curr_weights, &ldb, |
388 | &gemm_off_b, curr_src_, &lda, &gemm_off_a_uint8, |
389 | &beta, curr_acc, &acc_ldc, &gemm_off_c); |
390 | } break; |
391 | default: assert(!"unsupported data type" ); break; |
392 | } |
393 | |
394 | if (st_thr != status::success) { |
395 | st = st_thr; |
396 | return; |
397 | } |
398 | |
399 | // if igemm cannot handle src and weights zero points |
400 | if (post_process_src_and_weights_zero_points_outside_of_gemm) { |
401 | post_process_src_and_weights_zero_points(src_compensation, |
402 | weights_compensation, gemm_M, gemm_N, K, curr_src, |
403 | src_strides[0], src_strides[1], curr_weights, |
404 | weights_strides[0], weights_strides[1], curr_acc, |
405 | acc_ldc, src_zero_point, weights_zero_point); |
406 | } |
407 | |
408 | bool postops_in_matmul |
409 | = need_post_processing(pd(), dst_zero_point_f32); |
410 | assert(IMPLICATION(postops_in_matmul, params.has_pp_kernel_)); |
411 | |
412 | if (postops_in_matmul) { |
413 | const size_t dst_logical_off = i_work; |
414 | const size_t dim1_off = helper.ndims() > 3 |
415 | ? ((cur_b % batch_without_dim0) |
416 | / batch_without_dim01) |
417 | : cur_m; |
418 | // offset for case with post-op broadcast_channel |
419 | const size_t matrix_per_first_batch_off = helper.ndims() > 3 |
420 | ? M * N * (cur_b / batch_without_dim0) |
421 | + matrix_offset |
422 | : 0; |
423 | const ptrdiff_t oc_off = i_work % N; |
424 | (*pp_kernel_)(curr_dst, curr_acc, |
425 | bias + oc_off * bia_dt_size, |
426 | scales + oc_off * scale_idx_mult, dst_scales[0], 0, |
427 | dst_logical_off, dim1_off, gemm_M * gemm_N, |
428 | static_cast<size_t>(N), ldc, &dst_zero_point_f32, |
429 | post_ops_binary_rhs_arg_vec.data(), dst, |
430 | matrix_per_first_batch_off, ctx, *pd()->dst_md()); |
431 | } |
432 | i_work += gemm_M * gemm_N; |
433 | } |
434 | }); |
435 | } else { |
436 | // icc 17.0 has a bug with capturing const variables with value known |
437 | // at compilation time in lambdas |
438 | const int32_t gemm_off_c = 0; |
439 | |
440 | // collapse batch into M, if weights batch dimensions are broadcasted. |
441 | M = batch * M; |
442 | status_t st = status::runtime_error; |
443 | switch (src_d.data_type()) { |
444 | case data_type::s8: { |
445 | const int8_t *src_ = reinterpret_cast<const int8_t *>(src); |
446 | st = gemm_s8x8s32(&transB, &transA, "F" , &N, &M, &K, &alpha, |
447 | weights, &ldb, &gemm_off_b, src_, &lda, |
448 | &gemm_off_a_int8, &beta, acc, &acc_ldc, &gemm_off_c); |
449 | } break; |
450 | case data_type::u8: { |
451 | const uint8_t *src_ = reinterpret_cast<const uint8_t *>(src); |
452 | st = gemm_s8x8s32(&transB, &transA, "F" , &N, &M, &K, &alpha, |
453 | weights, &ldb, &gemm_off_b, src_, &lda, |
454 | &gemm_off_a_uint8, &beta, acc, &acc_ldc, &gemm_off_c); |
455 | } break; |
456 | default: assert(!"unsupported data type" ); break; |
457 | } |
458 | |
459 | if (st == status::success) { |
460 | std::vector<int32_t> src_compensation(M, 0); |
461 | std::vector<int32_t> weights_compensation(N, 0); |
462 | |
463 | // if igemm cannot handle src and weights zero points |
464 | if (post_process_src_and_weights_zero_points_outside_of_gemm) { |
465 | post_process_src_and_weights_zero_points(src_compensation, |
466 | weights_compensation, M, N, K, src, src_strides[0], |
467 | src_strides[1], weights, weights_strides[0], |
468 | weights_strides[1], acc, acc_ldc, src_zero_point, |
469 | weights_zero_point); |
470 | } |
471 | |
472 | bool postops_in_matmul |
473 | = need_post_processing(pd(), dst_zero_point_f32); |
474 | assert(IMPLICATION(postops_in_matmul, params.has_pp_kernel_)); |
475 | |
476 | if (postops_in_matmul) { |
477 | const bool force_sequential = pp_kernel_->sequential_kernel(); |
478 | parallel(force_sequential ? 1 : nthr, [&](int ithr, int nthr) { |
479 | size_t start {}, end {}; |
480 | balance211((size_t)(M * N), nthr, ithr, start, end); |
481 | const size_t dst_logical_off = start; |
482 | const size_t dim1_off = start % N; |
483 | (*pp_kernel_)(dst, acc, bias, scales, dst_scales[0], start, |
484 | dst_logical_off, dim1_off, end, (size_t)N, ldc, |
485 | &dst_zero_point_f32, |
486 | post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx, |
487 | *pd()->dst_md()); |
488 | }); |
489 | } |
490 | } |
491 | } |
492 | if (need_free_acc) free(acc); |
493 | |
494 | return st; |
495 | } |
496 | |
497 | } // namespace matmul |
498 | } // namespace cpu |
499 | } // namespace impl |
500 | } // namespace dnnl |
501 | |