1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_MATMUL_GEMM_X8S8S32X_MATMUL_HPP
18#define CPU_MATMUL_GEMM_X8S8S32X_MATMUL_HPP
19
20#include <assert.h>
21
22#include <memory>
23
24#include "common/c_types_map.hpp"
25#include "common/primitive.hpp"
26#include "common/type_helpers.hpp"
27
28#include "cpu/gemm_inner_product_utils.hpp"
29
30#include "cpu/matmul/cpu_matmul_pd.hpp"
31#include "cpu/matmul/gemm_based_common.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36namespace matmul {
37
38struct gemm_x8s8s32x_matmul_t : public primitive_t {
39 struct pd_t : public cpu_matmul_pd_t {
40 using cpu_matmul_pd_t::cpu_matmul_pd_t;
41
42 DECLARE_COMMON_PD_T("gemm:jit", gemm_x8s8s32x_matmul_t);
43
44 status_t init(engine_t *engine);
45 const gemm_based::params_t &params() const { return params_; }
46
47 int nthr_; // To not exceed the limit in execute used for set up.
48
49 private:
50 gemm_based::params_t params_;
51 };
52
53 gemm_x8s8s32x_matmul_t(const pd_t *apd) : primitive_t(apd) {}
54
55 status_t init(engine_t *engine) override {
56 if (pd()->params().has_pp_kernel_) {
57 const bool has_runtime_dims
58 = memory_desc_wrapper(pd()->dst_md()).has_runtime_dims();
59 const int nthr = pd()->nthr_;
60 const dim_t batch = pd()->batch();
61 const dim_t M = pd()->M();
62
63 // mb value is calculated based on work-sharing using
64 // balance211 in execute()
65 dim_t mb = DNNL_RUNTIME_DIM_VAL;
66 if (!has_runtime_dims && ((batch * M) % nthr == 0)) {
67 const dim_t m_per_thr = nstl::max<dim_t>(1, (batch * M) / nthr);
68 if (m_per_thr >= M && m_per_thr % M == 0) {
69 mb = M;
70 } else if (m_per_thr < M && M % m_per_thr == 0) {
71 mb = m_per_thr;
72 }
73 }
74
75 CHECK(safe_ptr_assign(pp_kernel_,
76 inner_product_utils::pp_kernel_t::create(pd()->N(), mb,
77 pd()->ldc(), &pd()->params().pp_attr_,
78 pd()->desc()->bias_desc.data_type,
79 pd()->desc()->accum_data_type, pd()->dst_md(),
80 false)));
81 return pp_kernel_->create_kernel();
82 }
83 return status::success;
84 }
85
86 status_t execute(const exec_ctx_t &ctx) const override {
87 return execute_ref(ctx);
88 }
89
90private:
91 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
92 status_t execute_ref(const exec_ctx_t &ctx) const;
93 void post_process_src_and_weights_zero_points(
94 std::vector<int32_t> &src_comp, std::vector<int32_t> &wei_comp,
95 dim_t M, dim_t N, dim_t K, const char *src, dim_t src_s0,
96 dim_t src_s1, const int8_t *wei, dim_t wei_s0, dim_t wei_s1,
97 int32_t *acc, int ldc, int32_t src_zero_point,
98 int32_t wei_zero_point) const;
99
100 std::unique_ptr<inner_product_utils::pp_kernel_t> pp_kernel_;
101};
102
103} // namespace matmul
104} // namespace cpu
105} // namespace impl
106} // namespace dnnl
107
108#endif
109