1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_COMMON_1X1_CONV_KERNEL_HPP
18#define CPU_X64_JIT_AVX512_COMMON_1X1_CONV_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/memory_tracking.hpp"
22
23#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
24#include "cpu/x64/jit_generator.hpp"
25#include "cpu/x64/jit_primitive_conf.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
33 jit_avx512_common_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp,
34 const primitive_attr_t &attr, const memory_desc_t &dst_md);
35
36 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_1x1_conv_kernel)
37
38 static status_t init_conf(jit_1x1_conv_conf_t &jcp,
39 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
40 const memory_desc_wrapper &weights_d,
41 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
42 int nthreads, bool reduce_src);
43
44 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
45 const jit_1x1_conv_conf_t &jcp);
46
47 jit_1x1_conv_conf_t jcp;
48 const primitive_attr_t &attr_;
49
50private:
51 std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
52 postops_injector_;
53
54 constexpr static int isa_simd_width_
55 = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
56 using reg64_t = const Xbyak::Reg64;
57 using zmm_t = const Xbyak::Zmm;
58
59 reg64_t reg_bcast_data = r8;
60 reg64_t reg_load_data = r10;
61 reg64_t reg_output_data = r9;
62 reg64_t aux_reg_bcast_data = r14;
63 reg64_t aux1_reg_bcast_data = rbx;
64 reg64_t aux_reg_load_data = r15;
65 reg64_t imm_addr64 = aux_reg_load_data;
66 reg64_t aux_reg_output_data = abi_not_param1;
67 reg64_t reg_load_loop_work = rsi;
68 reg64_t reg_reduce_loop_work = r11;
69 reg64_t reg_bcast_loop_iter = rdx;
70 reg64_t reduce_loop_iter = abi_param1;
71 reg64_t reg_reduce_pos_flag = rax;
72 reg64_t reg_output_stride = r13;
73 reg64_t reg_bias_data = r12;
74 reg64_t reg_relu_ns = r13;
75 reg64_t reg_bcast_loop_work = aux1_reg_bcast_data;
76 reg64_t reg_load_dim_tail_mask = aux_reg_load_data;
77 reg64_t reg_long_offt = reg_bcast_data;
78
79 Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31);
80 Xbyak::Opmask k_load_dim_mask = Xbyak::Opmask(2);
81 Xbyak::Opmask k_load_dim_tail_mask = Xbyak::Opmask(3);
82
83 constexpr static int reg64_size_ = sizeof(int64_t);
84 constexpr static int reg_bcast_loop_work_offt = 0;
85 constexpr static int reg_binary_post_op_acc_off = 1 * reg64_size_;
86 constexpr static int reg_abi_param1_backup = 2 * reg64_size_;
87 constexpr static int reg_bcast_data_off = 3 * reg64_size_;
88 constexpr static int stack_space_needed = 4 * reg64_size_;
89
90 void bcast_loop(int load_loop_blk);
91 void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
92
93 inline size_t get_output_offset(
94 const bool is_out_layout_nxc, const int i_load, const int i_ur) {
95 const size_t i_load_shift = is_out_layout_nxc
96 ? jcp.load_block
97 : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block;
98 const size_t i_ur_shift
99 = is_out_layout_nxc ? jcp.load_dim : jcp.load_block;
100 return jcp.typesize_out * (i_load * i_load_shift + i_ur * i_ur_shift);
101 }
102
103 Xbyak::Address output_ptr(
104 const bool out_layout_nxc, const int i_load, const int i_ur);
105 void apply_postops(const bool is_out_layout_nxc, const int load_loop_blk,
106 const int ur);
107 void generate() override;
108 static void balance(jit_1x1_conv_conf_t &jcp);
109};
110
111} // namespace x64
112} // namespace cpu
113} // namespace impl
114} // namespace dnnl
115
116#endif
117