1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_MATMUL_BRGEMM_MATMUL_HPP
18#define CPU_X64_MATMUL_BRGEMM_MATMUL_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "common/type_helpers.hpp"
23
24#include "cpu/matmul/cpu_matmul_pd.hpp"
25
26#include "cpu/x64/brgemm/brgemm.hpp"
27#include "cpu/x64/cpu_reducer.hpp"
28#include "cpu/x64/matmul/brgemm_matmul_copy_utils.hpp"
29#include "cpu/x64/matmul/brgemm_matmul_utils.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace x64 {
35namespace matmul {
36
37namespace {
38constexpr int max_num_brg_kernels_matmul = 2 * 2 * 2 * 2 * 2;
39
40inline int get_brg_kernel_index(const brgemm_matmul_conf_t &bgmmc,
41 bool is_bs_tail, bool do_initialization, bool is_M_tail, bool is_N_tail,
42 bool is_K_tail, int bs) {
43 auto vM = (is_M_tail) ? bgmmc.M_tail : bgmmc.M_blk;
44 auto vN = (is_N_tail) ? bgmmc.N_tail : bgmmc.N_blk;
45 auto vK = (is_K_tail) ? bgmmc.K_tail : bgmmc.K_blk;
46 if (vM == 0 || vN == 0 || vK == 0 || bs == 0 || bgmmc.LDA < vK
47 || bgmmc.LDB < vN || bgmmc.LDC < vN)
48 return -1;
49
50 int idx = 16 * (int)is_bs_tail + 8 * (int)do_initialization
51 + 4 * (int)is_M_tail + 2 * (int)is_N_tail + (int)is_K_tail;
52
53 assert(idx < max_num_brg_kernels_matmul);
54 return idx;
55}
56
57inline int get_brg_batchsize(
58 const brgemm_matmul_conf_t &bgmmc, bool is_bs_tail, bool is_K_tail) {
59 auto bs = is_K_tail ? 1
60 : is_bs_tail ? bgmmc.brgemm_batch_tail_size
61 : bgmmc.brgemm_batch_size;
62 return bs;
63}
64} // namespace
65
66template <cpu_isa_t isa>
67struct brgemm_matmul_t : public primitive_t {
68 struct pd_t : public ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t {
69 using ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t::cpu_matmul_pd_t;
70
71 DECLARE_COMMON_PD_T(
72 JIT_IMPL_NAME_HELPER("brg:", isa, ""), brgemm_matmul_t);
73
74 status_t init(engine_t *engine);
75 int get_brg_kernel_idx(bool is_bs_tail, bool do_initialization,
76 bool is_M_tail, bool is_N_tail, bool is_K_tail) const {
77 int bs = get_brg_batchsize(bgmmc_, is_bs_tail, is_K_tail);
78 return get_brg_kernel_index(bgmmc_, is_bs_tail, do_initialization,
79 is_M_tail, is_N_tail, is_K_tail, bs);
80 }
81 const brgemm_t &get_brg_desc(int idx) const { return brg_descs_[idx]; }
82 const brgemm_matmul_conf_t &get_brgemm_matmul_conf() const {
83 return bgmmc_;
84 }
85
86 private:
87 brgemm_t brg_descs_[max_num_brg_kernels_matmul];
88 brgemm_matmul_conf_t bgmmc_;
89 };
90
91 brgemm_matmul_t(const pd_t *apd) : primitive_t(apd) {}
92
93 status_t init(engine_t *engine) override;
94 static constexpr data_type_t acc_type = data_type::s32;
95
96 status_t execute(const exec_ctx_t &ctx) const override {
97 return execute_body(ctx);
98 }
99
100private:
101 struct brg_matmul_exec_ctx_t;
102
103 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
104 status_t execute_body(const exec_ctx_t &ctx) const;
105 void compute_kernel(const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr,
106 int b_idx, int m_blk_idx, int n_blk_idx, int k_blk_idx,
107 bool do_init) const;
108 void copy_a_chunk_in_buffer(const brg_matmul_exec_ctx_t &brgmm_ctx,
109 int ithr, int b_idx, int m_blk_idx, int k_blk_idx) const;
110 void copy_b_chunk_in_buffer(const brg_matmul_exec_ctx_t &brgmm_ctx,
111 int ithr, int b_idx, int n_blk_idx, int k_blk_idx) const;
112 void maybe_reduce_partial_results_and_apply_postops(
113 const brg_matmul_exec_ctx_t &brgmm_ctx) const;
114 void accumulate(
115 char *result_ptr, const char *reduce_ptr, size_t size) const;
116
117 std::unique_ptr<brgemm_kernel_t> brg_kernels_[max_num_brg_kernels_matmul];
118 char brg_kernel_palettes_[max_num_brg_kernels_matmul][64];
119 std::unique_ptr<jit_brgemm_matmul_copy_b_t> copy_B_kernel_;
120 std::unique_ptr<jit_brgemm_matmul_copy_a_t> copy_A_kernel_;
121 std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_f32_;
122 std::unique_ptr<cpu_accumulator_1d_t<data_type::s32>> acc_ker_s32_;
123};
124
125} // namespace matmul
126} // namespace x64
127} // namespace cpu
128} // namespace impl
129} // namespace dnnl
130
131#endif
132