1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_CONV_COMP_PAD_KERNEL_HPP
18#define CPU_X64_JIT_BRGEMM_CONV_COMP_PAD_KERNEL_HPP
19
20#include "cpu/x64/jit_generator.hpp"
21#include "cpu/x64/jit_primitive_conf.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28namespace jit_avx512_core_brgemm_conv_comp_pad_kernel {
// Plain argument record for the compensation-pad kernel declared in this
// header. NOTE(review): presumably a pointer to this record is the single
// runtime argument of the generated code -- confirm against the kernel's
// load_params() implementation.
//
// The member order is kept stable on purpose: if the generated code addresses
// the fields by offset, reordering them would silently break it.
struct jit_brgemm_conv_comp_pad_call_s {
    // Read-only source buffer.
    const void *ptr_in;

    // Writable destination buffers. NOTE(review): "zp" / "cp" look like
    // zero-point compensation and s8s8 compensation -- verify with the caller.
    void *ptr_zp_out;
    void *ptr_cp_out;

    // Per-call values for the kernel width / height / depth dimensions.
    // NOTE(review): the "_l" suffix presumably means "length" -- TODO confirm.
    size_t kw_l;
    size_t kh_l;
    size_t kd_l;
};
37
38struct jit_avx512_core_brgemm_conv_comp_pad_kernel_t : public jit_generator {
39 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_brgemm_conv_comp_pad_kernel_t)
40
41 using reg64_t = const Xbyak::Reg64;
42
43 jit_avx512_core_brgemm_conv_comp_pad_kernel_t(
44 const jit_brgemm_conv_conf_t &ajcp);
45
46 ~jit_avx512_core_brgemm_conv_comp_pad_kernel_t() = default;
47
48protected:
49 jit_brgemm_conv_conf_t jcp_;
50 const int inp_dsz_;
51 const int out_dsz_;
52 const size_t nb_ic_;
53 const size_t inp_ic_sz_;
54 const size_t inp_kw_sz_;
55 const size_t inp_kh_sz_;
56 const size_t inp_kd_sz_;
57
58 // Register decomposition
59 const reg64_t param1 = abi_param1;
60 const reg64_t reg_in = r15;
61 const reg64_t reg_comp_out = r14;
62 const reg64_t reg_zp_comp_out = r13;
63
64 const reg64_t reg_kd_l = r12;
65 const reg64_t reg_kh_l = r11;
66 const reg64_t reg_kw_l = r10;
67 const reg64_t reg_icb = r9;
68
69 const reg64_t reg_aux_in = r8;
70 const reg64_t reg_aux_kh_in = rbx;
71 const reg64_t reg_aux_kw_in = rsi;
72 const reg64_t reg_tmp = rax;
73
74 Xbyak::Zmm zmm_one_bytes = Xbyak::Zmm(30);
75 Xbyak::Zmm zmm_zp_shift = Xbyak::Zmm(29);
76 Xbyak::Zmm zmm_cp_shift = Xbyak::Zmm(28);
77
78 const int last_ic_block_ = 4;
79 const int n_block2_ = 4;
80 const int m_block2_ = 16;
81 const int n_max_regs_ = 4;
82
83 const Xbyak::Zmm &zmm_tmp_1() const noexcept { return this->zmm31; }
84
85 Xbyak::Zmm accum(const int n_block, const int m, const int n) const;
86 size_t out_oc_offset(const int n) const;
87 size_t inp_ic_offset(
88 const int m_block, const int icb, const int m, const int n) const;
89 int compute_ic_step(
90 const int m_max_regs, const int m_block, const int n_block) const;
91
92 void store_accumulators(const int m_block, const int n_block);
93 void zero_accumulators(const int m_block, const int n_block);
94 void compute(const int ic_step, const int m_block, const int n_block,
95 const int m_tail, const bool is_mb_tail);
96 void icb_loop(const int icb, const int icb_tail, const int ic_step,
97 const int m_block, const int mb_tail, const int n_block);
98 void khw_loop(const int icb, const int icb_tail, const int ic_step,
99 const int m_block, const int mb_tail, const int n_block);
100 void load_params();
101 void generate() override;
102};
103
104} // namespace jit_avx512_core_brgemm_conv_comp_pad_kernel
105
106} // namespace x64
107} // namespace cpu
108} // namespace impl
109} // namespace dnnl
110
111#endif
112
113// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
114