1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_AMX_1X1_CONV_KERNEL_HPP
18#define CPU_X64_JIT_AVX512_CORE_AMX_1X1_CONV_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/memory_tracking.hpp"
22
23#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
24#include "cpu/x64/jit_generator.hpp"
25#include "cpu/x64/jit_primitive_conf.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32struct jit_avx512_core_amx_1x1_fwd_kernel_t : public jit_generator {
33 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_amx_1x1_fwd_kernel_t)
34
35 jit_avx512_core_amx_1x1_fwd_kernel_t(const jit_conv_conf_t &ajcp,
36 const primitive_attr_t &attr, const memory_desc_t &dst_md);
37
38 static status_t init_conf(jit_conv_conf_t &jcp,
39 const convolution_desc_t &cd, memory_desc_t &src_pd,
40 memory_desc_t &weights_pd, memory_desc_t &dst_pd,
41 memory_desc_t &bias_pd, primitive_attr_t &attr, int nthreads);
42 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
43 const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
44
45 // Tile-registers decomposition
46 enum { C_BASE = 0, W_BASE = 6, I_BASE = 4 };
47
48 void tile_configure(char *tcgf_buff);
49
50 jit_conv_conf_t jcp;
51 const primitive_attr_t &attr_;
52
53private:
54 constexpr static int isa_simd_width_
55 = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
56 std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
57 postops_injector_;
58
59 enum {
60 zmm_idx_limit_bf16 = 29,
61 zmm_idx_limit_int8 = 27,
62 };
63
64 int row_count_ = 0;
65 int buf_count_ = 0;
66 bool is_store_done_ = false;
67 bool is_buffer_empty_ = true;
68 bool check_last_sb_ = false;
69 bool last_oc_block_flag_ = false;
70
71 /* data regs */
72 const Xbyak::Reg64 inp_ptr = r15;
73 const Xbyak::Reg64 wei_ptr = r14;
74 const Xbyak::Reg64 out_ptr = r13;
75 const Xbyak::Reg64 wsp_ptr = r12;
76
77 const Xbyak::Reg64 reg_bias = r11;
78 const Xbyak::Reg64 reg_ptr_scales = r10;
79 const Xbyak::Reg64 reg_ptr_sum_scale = r9;
80 const Xbyak::Reg64 reg_ptr_sum_zp = rax;
81 const Xbyak::Reg64 aux_reg_saturation = reg_ptr_sum_scale;
82 const Xbyak::Reg64 reg_last_h = r8;
83
84 const Xbyak::Reg64 stride_seq = rbx;
85 const Xbyak::Reg64 stride_nhwc = rsi;
86 const Xbyak::Reg64 reg_tmp = abi_not_param1;
87
88 const Xbyak::Reg64 reg_oc_blocks = rdx;
89 const Xbyak::Reg64 reg_is_osb = rsi;
90 const Xbyak::Reg64 reg_postop = abi_not_param1;
91 const Xbyak::Reg64 reg_scratch = reg_bias;
92 const Xbyak::Reg64 reg_tilebuff = reg_ptr_scales;
93 /* zero-point */
94 const Xbyak::Reg64 reg_zp_compensation = reg_last_h;
95 const Xbyak::Reg64 reg_src_zero_point = reg_oc_blocks;
96 const Xbyak::Reg64 reg_dst_zero_point = rax;
97
98 /* scale */
99 const Xbyak::Reg64 reg_ptr_dst_scale = reg_ptr_scales;
100
101 const Xbyak::Zmm zmm_bias = zmm31;
102 const Xbyak::Zmm zmm_saturation = zmm_bias;
103 const Xbyak::Zmm zmm_zero = zmm30;
104 const Xbyak::Zmm zmm_prev_dst = zmm29;
105 const Xbyak::Zmm zmm_sum_zp = zmm26;
106 /* zero-point */
107 const Xbyak::Zmm zmm_zp = zmm29;
108 const Xbyak::Zmm zmm_src_zp = zmm28;
109 const Xbyak::Zmm zmm_dst_zp = zmm27;
110
111 const Xbyak::Reg64 bin_injector_helper_reg_1 = r14;
112 const Xbyak::Reg64 bin_injector_helper_reg_2 = r15;
113 const Xbyak::Reg64 bin_injector_helper_reg_3 = r11;
114
115 const Xbyak::Opmask ktail_mask = k2;
116
117 bool is_bf16() const;
118
119 void init_runtime_counters();
120
121 int get_out_tensor(int h, int i) const;
122 int get_inp_tensor(int h) const;
123 int get_wei_tensor(int i) const;
124 int get_ic_tail() const;
125
126 size_t out_h_shift() const;
127 size_t out_w_shift() const;
128 size_t inp_offset(int ih, int iw, int icb) const;
129 size_t out_row_offset(int h, int w, int ocb) const;
130
131 void prepare_output();
132
133 void cvt2ps(data_type_t type_in, const Xbyak::Zmm ymm_in,
134 const Xbyak::Operand &op, bool mask_flag);
135 Xbyak::Zmm zmm_out(const int idx) {
136 const int upper_limit
137 = is_bf16() ? zmm_idx_limit_bf16 : zmm_idx_limit_int8;
138 assert(upper_limit > idx);
139 MAYBE_UNUSED(upper_limit);
140 return Xbyak::Zmm(idx);
141 }
142 Xbyak::Zmm zmm_mask(
143 const Xbyak::Zmm zmm_in, bool mask_flag, bool store = false);
144 Xbyak::Ymm ymm_mask(
145 const Xbyak::Ymm ymm_in, bool mask_flag, bool store = false);
146
147 void update_buffer_pointers();
148 void interleave_store();
149 void apply_sum(const Xbyak::Zmm zmm_out, const float *p_sum_scale,
150 const int32_t *p_sum_zp, const Xbyak::Address &addr,
151 const bool mask_flag);
152 void apply_postops(const Xbyak::Zmm zmm_out, const float *p_sum_scale,
153 const int32_t *p_sum_zp, const Xbyak::Address &addr,
154 const size_t off, const bool mask_flag);
155 static bool is_fast_postops(const jit_conv_conf_t &jcp);
156 void store_output_vectors_int8(int ocb, int osb);
157 void store_output_vector_int8(
158 const Xbyak::Zmm zmm_out, int ocb, int h, int w);
159 inline void store_output_ymm_bf16(
160 const int idx, const Xbyak::Address &addr, const bool mask_flag);
161 void store_output_vectors_bf16(int ocb, int osb);
162 void store_output_vector_bf16(
163 const Xbyak::Zmm zmm_out, int ocb, int h, int w);
164 void store_output_vectors(int ocb, int osb);
165 void store_output_vector(const Xbyak::Zmm zmm_out, int ocb, int h, int w);
166 void store_output(bool do_store, bool is_tail);
167 void icb_loop(bool do_store);
168 void osb_loop(int nb_os = 1);
169
170 void generate() override;
171};
172
173} // namespace x64
174} // namespace cpu
175} // namespace impl
176} // namespace dnnl
177
178#endif
179