1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_MATMUL_BRGEMM_MATMUL_HPP |
18 | #define CPU_X64_MATMUL_BRGEMM_MATMUL_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | #include "cpu/matmul/cpu_matmul_pd.hpp" |
25 | |
26 | #include "cpu/x64/brgemm/brgemm.hpp" |
27 | #include "cpu/x64/cpu_reducer.hpp" |
28 | #include "cpu/x64/matmul/brgemm_matmul_copy_utils.hpp" |
29 | #include "cpu/x64/matmul/brgemm_matmul_utils.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | namespace matmul { |
36 | |
37 | namespace { |
38 | constexpr int max_num_brg_kernels_matmul = 2 * 2 * 2 * 2 * 2; |
39 | |
40 | inline int get_brg_kernel_index(const brgemm_matmul_conf_t &bgmmc, |
41 | bool is_bs_tail, bool do_initialization, bool is_M_tail, bool is_N_tail, |
42 | bool is_K_tail, int bs) { |
43 | auto vM = (is_M_tail) ? bgmmc.M_tail : bgmmc.M_blk; |
44 | auto vN = (is_N_tail) ? bgmmc.N_tail : bgmmc.N_blk; |
45 | auto vK = (is_K_tail) ? bgmmc.K_tail : bgmmc.K_blk; |
46 | if (vM == 0 || vN == 0 || vK == 0 || bs == 0 || bgmmc.LDA < vK |
47 | || bgmmc.LDB < vN || bgmmc.LDC < vN) |
48 | return -1; |
49 | |
50 | int idx = 16 * (int)is_bs_tail + 8 * (int)do_initialization |
51 | + 4 * (int)is_M_tail + 2 * (int)is_N_tail + (int)is_K_tail; |
52 | |
53 | assert(idx < max_num_brg_kernels_matmul); |
54 | return idx; |
55 | } |
56 | |
57 | inline int get_brg_batchsize( |
58 | const brgemm_matmul_conf_t &bgmmc, bool is_bs_tail, bool is_K_tail) { |
59 | auto bs = is_K_tail ? 1 |
60 | : is_bs_tail ? bgmmc.brgemm_batch_tail_size |
61 | : bgmmc.brgemm_batch_size; |
62 | return bs; |
63 | } |
64 | } // namespace |
65 | |
// BRGEMM-based matmul primitive for the given x64 ISA. Owns the JIT-generated
// batch-reduce GEMM kernels (one per block/tail variant) plus the A/B copy
// kernels and accumulators used for partial-result reduction.
template <cpu_isa_t isa>
struct brgemm_matmul_t : public primitive_t {
    // Primitive descriptor: holds the matmul configuration (bgmmc_) and the
    // pre-initialized BRGEMM descriptors for every kernel variant.
    struct pd_t : public ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t {
        using ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t::cpu_matmul_pd_t;

        DECLARE_COMMON_PD_T(
                JIT_IMPL_NAME_HELPER("brg:" , isa, "" ), brgemm_matmul_t);

        status_t init(engine_t *engine);
        // Maps a (bs-tail, init, M/N/K-tail) variant to the matching slot in
        // brg_descs_; returns -1 when the variant needs no kernel (see
        // get_brg_kernel_index above).
        int get_brg_kernel_idx(bool is_bs_tail, bool do_initialization,
                bool is_M_tail, bool is_N_tail, bool is_K_tail) const {
            int bs = get_brg_batchsize(bgmmc_, is_bs_tail, is_K_tail);
            return get_brg_kernel_index(bgmmc_, is_bs_tail, do_initialization,
                    is_M_tail, is_N_tail, is_K_tail, bs);
        }
        // Precondition: idx is a valid index returned by get_brg_kernel_idx.
        const brgemm_t &get_brg_desc(int idx) const { return brg_descs_[idx]; }
        const brgemm_matmul_conf_t &get_brgemm_matmul_conf() const {
            return bgmmc_;
        }

    private:
        brgemm_t brg_descs_[max_num_brg_kernels_matmul];
        brgemm_matmul_conf_t bgmmc_;
    };

    brgemm_matmul_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override;
    // Accumulation data type for integer (int8) matmul paths.
    static constexpr data_type_t acc_type = data_type::s32;

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_body(ctx);
    }

private:
    // Per-execution context (thread decomposition, buffers); defined in the
    // implementation file.
    struct brg_matmul_exec_ctx_t;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    status_t execute_body(const exec_ctx_t &ctx) const;
    // Runs one BRGEMM kernel on the (batch, M, N, K) block identified by the
    // given indices; do_init selects the accumulator-initializing variant.
    void compute_kernel(const brg_matmul_exec_ctx_t &brgmm_ctx, int ithr,
            int b_idx, int m_blk_idx, int n_blk_idx, int k_blk_idx,
            bool do_init) const;
    // Copy/reorder a chunk of matrix A (resp. B) into the thread-local
    // blocked buffer expected by the BRGEMM kernels.
    void copy_a_chunk_in_buffer(const brg_matmul_exec_ctx_t &brgmm_ctx,
            int ithr, int b_idx, int m_blk_idx, int k_blk_idx) const;
    void copy_b_chunk_in_buffer(const brg_matmul_exec_ctx_t &brgmm_ctx,
            int ithr, int b_idx, int n_blk_idx, int k_blk_idx) const;
    void maybe_reduce_partial_results_and_apply_postops(
            const brg_matmul_exec_ctx_t &brgmm_ctx) const;
    // Element-wise add of reduce_ptr into result_ptr via the f32/s32
    // accumulator kernels below.
    void accumulate(
            char *result_ptr, const char *reduce_ptr, size_t size) const;

    std::unique_ptr<brgemm_kernel_t> brg_kernels_[max_num_brg_kernels_matmul];
    // 64-byte configuration blob per kernel — presumably the AMX tile
    // palette loaded before kernel execution; confirm against the .cpp.
    char brg_kernel_palettes_[max_num_brg_kernels_matmul][64];
    std::unique_ptr<jit_brgemm_matmul_copy_b_t> copy_B_kernel_;
    std::unique_ptr<jit_brgemm_matmul_copy_a_t> copy_A_kernel_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_f32_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::s32>> acc_ker_s32_;
};
124 | |
125 | } // namespace matmul |
126 | } // namespace x64 |
127 | } // namespace cpu |
128 | } // namespace impl |
129 | } // namespace dnnl |
130 | |
131 | #endif |
132 | |