1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/dnnl_thread.hpp"
18#include "common/math_utils.hpp"
19#include "cpu/simple_q10n.hpp"
20
21#include "cpu/cpu_primitive.hpp"
22#include "cpu/scale_utils.hpp"
23
24#include "cpu/binary_injector_utils.hpp"
25#include "cpu/gemm/gemm.hpp"
26#include "cpu/gemm_x8s8s32x_inner_product.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31
32using namespace math;
33using namespace format_tag;
34using namespace memory_tracking::names;
35
36status_t gemm_x8s8s32x_inner_product_fwd_t::execute_forward(
37 const exec_ctx_t &ctx) const {
38 auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
39 auto weights = CTX_IN_MEM(const int8_t *, DNNL_ARG_WEIGHTS);
40 auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
41 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
42 const auto post_ops_binary_rhs_arg_vec
43 = binary_injector_utils::prepare_binary_args(
44 this->pd()->attr()->post_ops_, ctx);
45
46 const dim_t MB = pd()->MB();
47 const dim_t OC = pd()->OC();
48 const dim_t IC = pd()->IC();
49
50 const auto &wmd = *pd()->weights_md();
51 const auto &smd = *pd()->src_md();
52 bool wei_tr = wmd.format_desc.blocking.strides[0] != 1;
53 // check if MB is the leading dimension
54 bool src_tr = smd.format_desc.blocking.strides[0] == 1 && IC > 1;
55
56 const dim_t M = OC;
57 const dim_t N = MB;
58 const dim_t K = pd()->IC_total_padded();
59 const int8_t off_a = 0;
60 const int32_t off_c = 0;
61
62 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
63 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
64 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
65
66 auto scratchpad = ctx.get_scratchpad_grantor();
67 const float *scales = precompute_scales(
68 scratchpad, src_scales, wei_scales, OC, pd()->attr());
69
70 int32_t *acc = pd()->dst_is_acc_
71 ? (int32_t *)dst
72 : ctx.get_scratchpad_grantor().template get<int32_t>(
73 key_iprod_int_dat_in_acc_dt);
74
75 const float onef = 1.0, zerof = 0.0;
76
77 if (smd.data_type == data_type::s8) {
78 const int8_t off_b = 0;
79 const int8_t *src_ = reinterpret_cast<const int8_t *>(src);
80 CHECK(gemm_s8x8s32(wei_tr ? "T" : "N", src_tr ? "T" : "N", "F", &M, &N,
81 &K, &onef, weights, wei_tr ? &K : &M, &off_a, src_,
82 src_tr ? &N : &K, &off_b, &zerof, acc, &M, &off_c));
83 } else if (smd.data_type == data_type::u8) {
84 const uint8_t off_b = 0;
85 const uint8_t *src_ = reinterpret_cast<const uint8_t *>(src);
86 CHECK(gemm_s8x8s32(wei_tr ? "T" : "N", src_tr ? "T" : "N", "F", &M, &N,
87 &K, &onef, weights, wei_tr ? &K : &M, &off_a, src_,
88 src_tr ? &N : &K, &off_b, &zerof, acc, &M, &off_c));
89 } else {
90 assert(!"unsupported data type!");
91 }
92
93 if (!pd()->attr()->has_default_values()
94 || pd()->dst_md()->data_type != data_type::s32
95 || pd()->with_bias()) {
96 const bool force_sequential
97 = pp_kernel_->sequential_kernel() || MB * OC < 2000;
98 parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) {
99 size_t start, end;
100 balance211((size_t)(OC * MB), nthr, ithr, start, end);
101 const size_t dst_logical_off = start;
102 const size_t dim1_off = start % OC;
103 (*pp_kernel_)(dst, acc, bias, scales, dst_scales[0], start,
104 dst_logical_off, dim1_off, end, 0, 0, nullptr,
105 post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx,
106 *pd()->dst_md());
107 });
108 }
109
110 return status::success;
111}
112
113} // namespace cpu
114} // namespace impl
115} // namespace dnnl
116