1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_CONV_BWD_TRANS_KERNEL_HPP
18#define CPU_X64_JIT_BRGEMM_CONV_BWD_TRANS_KERNEL_HPP
19
20#include "cpu/x64/jit_generator.hpp"
21#include "cpu/x64/jit_primitive_conf.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28namespace jit_avx512_core_brgemm_conv_bwd_trans_kernel {
// Argument bundle handed to the generated transpose kernel on each call.
// NOTE(review): JIT kernels in this codebase typically read such fields via
// offsetof(), so the declaration order and types below are presumably part of
// the kernel's contract. Do not reorder or retype them without updating the
// generated code (TODO confirm against the .cpp).
struct jit_brgemm_conv_bwd_trans_kernel_call_s {
    // Input and output buffer pointers, in that order.
    const void *src, *dst;
    // In declaration order: iwb, oc, t_pad, h_count, b_pad.
    // Names suggest an iw block index, an oc extent, top padding, a row count
    // and bottom padding; confirm the exact meaning in the kernel source.
    size_t iwb, oc, t_pad, h_count, b_pad;
};
38
39struct jit_avx512_core_brgemm_conv_bwd_trans_kernel_t : public jit_generator {
40 DECLARE_CPU_JIT_AUX_FUNCTIONS(
41 jit_avx512_core_brgemm_conv_bwd_trans_kernel_t)
42
43 using reg64_t = const Xbyak::Reg64;
44
45 jit_avx512_core_brgemm_conv_bwd_trans_kernel_t(
46 const jit_brgemm_conv_conf_t &ajcp, const char *name = jit_name());
47
48protected:
49 jit_brgemm_conv_conf_t jcp;
50 dim_t inp_dsz;
51 dim_t oc_block_sz;
52 dim_t ow_size, dst_w_block, dst_stride;
53 dim_t dst_h_offset, dst_w_offset;
54 dim_t VL, n_vec, n_tail_vec;
55 const reg64_t inp_ptr = r15;
56 const reg64_t dst_ptr = r14;
57
58 const reg64_t aux_inp_ptr = r13;
59 const reg64_t aux_dst_ptr = r12;
60
61 const reg64_t reg_hc = r10;
62
63 const reg64_t reg_oc = r9;
64
65 const reg64_t reg_iwb = rdx;
66
67 const reg64_t kh_over = r8;
68 const reg64_t reg_t_pad = rax;
69 const reg64_t reg_b_pad = rbx;
70
71 const reg64_t reg_tmp = rsi;
72
73 const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
74 const Xbyak::Opmask kblock_tail_mask = Xbyak::Opmask(3);
75
76 const Xbyak::Zmm zmm_tmp = Xbyak::Zmm(0);
77 const Xbyak::Zmm zmm_zero = Xbyak::Zmm(1);
78
79 void load(const Xbyak::Xmm &x, const Xbyak::Address &addr);
80
81 void store(const Xbyak::Address &addr, const Xbyak::Xmm &x);
82
83 void zero_oc_block(bool is_oc_tail, dim_t dst_off);
84 void copy_oc_block(
85 bool is_oc_tail, dim_t inp_off, dim_t dst_off, bool do_load);
86 void generate() override;
87 void copy_iw_block(bool is_oc_tail);
88 void copy_iw_block_body(int lpad, int iw_len, int ow_len, bool is_oc_tail);
89
90 int inp_w(int out_w) const;
91 int inp_w_start(int iwb) const;
92};
93
94} // namespace jit_avx512_core_brgemm_conv_bwd_trans_kernel
95} // namespace x64
96} // namespace cpu
97} // namespace impl
98} // namespace dnnl
99
100#endif
101
102// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
103