1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_CONV_TRANS_KERNEL_HPP
18#define CPU_X64_JIT_BRGEMM_CONV_TRANS_KERNEL_HPP
19
20#include "cpu/x64/jit_generator.hpp"
21#include "cpu/x64/jit_primitive_conf.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28namespace jit_avx512_core_brgemm_conv_trans_kernel {
// Per-call argument block for the brgemm convolution transform kernels
// declared in this header.
//
// NOTE(review): argument structs like this are presumably read by the
// generated code via member byte offsets, so the member order and types
// below are deliberately left untouched -- confirm against the kernel
// implementation before reordering or resizing anything.
struct jit_brgemm_conv_trans_kernel_call_s {
    // Buffer addresses.
    const void *src; // presumably the buffer the kernel reads from
    // NOTE(review): const-qualified even though the name suggests an output
    // buffer -- confirm how the caller/kernel use it.
    const void *dst;

    // Scalar parameters (all size_t).
    size_t owb; // presumably an output-width block index -- TODO confirm
    size_t ic; // presumably an input-channel count/position -- TODO confirm
    size_t t_pad; // presumably the amount of top padding -- TODO confirm
    size_t h_count; // presumably the number of rows to process -- TODO confirm
    size_t b_pad; // presumably the amount of bottom padding -- TODO confirm
};
38
39struct jit_avx512_core_brgemm_conv_trans_kernel_t : public jit_generator {
40 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_brgemm_conv_trans_kernel_t)
41
42 using reg64_t = const Xbyak::Reg64;
43
44 jit_avx512_core_brgemm_conv_trans_kernel_t(
45 const jit_brgemm_conv_conf_t &ajcp, const char *name = jit_name());
46
47 int dst_w(int out_w) const;
48
49protected:
50 jit_brgemm_conv_conf_t jcp;
51 dim_t inp_dsz;
52 dim_t ic_block_sz;
53 dim_t iw_size, dst_w_block, dst_stride;
54 dim_t dst_h_offset, dst_w_offset;
55 dim_t VL, n_vec, n_tail_vec;
56 const reg64_t inp_ptr = r15;
57 const reg64_t dst_ptr = r14;
58
59 const reg64_t aux_inp_ptr = r13;
60 const reg64_t aux_dst_ptr = r12;
61
62 const reg64_t reg_hc = r10;
63
64 const reg64_t reg_ic = r9;
65
66 const reg64_t reg_owb = rdx;
67
68 const reg64_t kh_over = r8;
69 const reg64_t reg_t_pad = rax;
70 const reg64_t reg_b_pad = rbx;
71
72 const reg64_t reg_tmp = rsi;
73
74 const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
75 const Xbyak::Opmask kblock_tail_mask = Xbyak::Opmask(3);
76
77 const Xbyak::Zmm zmm_tmp = Xbyak::Zmm(0);
78 const Xbyak::Zmm zmm_zero = Xbyak::Zmm(1);
79
80 void load(const Xbyak::Xmm &x, const Xbyak::Address &addr);
81
82 void store(const Xbyak::Address &addr, const Xbyak::Xmm &x);
83
84 void zero_ic_block(bool is_ic_tail, dim_t dst_off);
85 void copy_ic_block(
86 bool is_ic_tail, dim_t inp_off, dim_t dst_off, bool do_load);
87 void generate() override;
88 void copy_ow_block(bool is_ic_tail);
89 void copy_ow_block_body(int lpad, int ow_len, int iw_len, bool is_ic_tail);
90
91 int inp_w(int out_w) const;
92 int inp_w(int out_w, int kw) const;
93 int inp_w_start(int owb) const;
94};
95
96struct jit_avx512_core_brgemm_conv_rtus_kernel_t
97 : jit_avx512_core_brgemm_conv_trans_kernel_t {
98 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_brgemm_conv_rtus_kernel_t)
99
100 jit_avx512_core_brgemm_conv_rtus_kernel_t(
101 const jit_brgemm_conv_conf_t &ajcp);
102
103private:
104 void generate() override;
105};
106
107} // namespace jit_avx512_core_brgemm_conv_trans_kernel
108
109} // namespace x64
110} // namespace cpu
111} // namespace impl
112} // namespace dnnl
113
114#endif
115
116// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
117