1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_BRGEMM_CONV_COMP_PAD_KERNEL_HPP |
18 | #define CPU_X64_JIT_BRGEMM_CONV_COMP_PAD_KERNEL_HPP |
19 | |
20 | #include "cpu/x64/jit_generator.hpp" |
21 | #include "cpu/x64/jit_primitive_conf.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace x64 { |
27 | |
28 | namespace jit_avx512_core_brgemm_conv_comp_pad_kernel { |
// Argument pack handed to the generated comp_pad kernel at call time.
//
// Field order (and therefore memory layout) is deliberately kept as-is:
// generated code commonly reads argument structs like this one by byte
// offset. NOTE(review): confirm against the offset-based accesses in the
// kernel's .cpp before reordering, inserting or removing fields.
struct jit_brgemm_conv_comp_pad_call_s {
    void const *ptr_in; // input buffer; the kernel only sees it as read-only
    void *ptr_zp_out; // NOTE(review): presumably zero-point compensation output
    void *ptr_cp_out; // NOTE(review): presumably compensation output
    size_t kw_l; // NOTE(review): presumably a kw-dimension limit/count
    size_t kh_l; // NOTE(review): presumably a kh-dimension limit/count
    size_t kd_l; // NOTE(review): presumably a kd-dimension limit/count
};
37 | |
38 | struct jit_avx512_core_brgemm_conv_comp_pad_kernel_t : public jit_generator { |
39 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_brgemm_conv_comp_pad_kernel_t) |
40 | |
41 | using reg64_t = const Xbyak::Reg64; |
42 | |
43 | jit_avx512_core_brgemm_conv_comp_pad_kernel_t( |
44 | const jit_brgemm_conv_conf_t &ajcp); |
45 | |
46 | ~jit_avx512_core_brgemm_conv_comp_pad_kernel_t() = default; |
47 | |
48 | protected: |
49 | jit_brgemm_conv_conf_t jcp_; |
50 | const int inp_dsz_; |
51 | const int out_dsz_; |
52 | const size_t nb_ic_; |
53 | const size_t inp_ic_sz_; |
54 | const size_t inp_kw_sz_; |
55 | const size_t inp_kh_sz_; |
56 | const size_t inp_kd_sz_; |
57 | |
58 | // Register decomposition |
59 | const reg64_t param1 = abi_param1; |
60 | const reg64_t reg_in = r15; |
61 | const reg64_t reg_comp_out = r14; |
62 | const reg64_t reg_zp_comp_out = r13; |
63 | |
64 | const reg64_t reg_kd_l = r12; |
65 | const reg64_t reg_kh_l = r11; |
66 | const reg64_t reg_kw_l = r10; |
67 | const reg64_t reg_icb = r9; |
68 | |
69 | const reg64_t reg_aux_in = r8; |
70 | const reg64_t reg_aux_kh_in = rbx; |
71 | const reg64_t reg_aux_kw_in = rsi; |
72 | const reg64_t reg_tmp = rax; |
73 | |
74 | Xbyak::Zmm zmm_one_bytes = Xbyak::Zmm(30); |
75 | Xbyak::Zmm zmm_zp_shift = Xbyak::Zmm(29); |
76 | Xbyak::Zmm zmm_cp_shift = Xbyak::Zmm(28); |
77 | |
78 | const int last_ic_block_ = 4; |
79 | const int n_block2_ = 4; |
80 | const int m_block2_ = 16; |
81 | const int n_max_regs_ = 4; |
82 | |
83 | const Xbyak::Zmm &zmm_tmp_1() const noexcept { return this->zmm31; } |
84 | |
85 | Xbyak::Zmm accum(const int n_block, const int m, const int n) const; |
86 | size_t out_oc_offset(const int n) const; |
87 | size_t inp_ic_offset( |
88 | const int m_block, const int icb, const int m, const int n) const; |
89 | int compute_ic_step( |
90 | const int m_max_regs, const int m_block, const int n_block) const; |
91 | |
92 | void store_accumulators(const int m_block, const int n_block); |
93 | void zero_accumulators(const int m_block, const int n_block); |
94 | void compute(const int ic_step, const int m_block, const int n_block, |
95 | const int m_tail, const bool is_mb_tail); |
96 | void icb_loop(const int icb, const int icb_tail, const int ic_step, |
97 | const int m_block, const int mb_tail, const int n_block); |
98 | void khw_loop(const int icb, const int icb_tail, const int ic_step, |
99 | const int m_block, const int mb_tail, const int n_block); |
100 | void load_params(); |
101 | void generate() override; |
102 | }; |
103 | |
104 | } // namespace jit_avx512_core_brgemm_conv_comp_pad_kernel |
105 | |
106 | } // namespace x64 |
107 | } // namespace cpu |
108 | } // namespace impl |
109 | } // namespace dnnl |
110 | |
111 | #endif |
112 | |
113 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
114 | |