1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_BRGEMM_TRANSPOSE_UTILS_HPP
18#define CPU_X64_JIT_BRGEMM_TRANSPOSE_UTILS_HPP
19
20#include "cpu/x64/jit_brgemm_primitive_conf.hpp"
21#include "cpu/x64/jit_generator.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28struct jit_brgemm_trans_src_t {
29 struct ctx_t {
30 const void *src;
31 const void *tr_src;
32
33 dim_t current_gemm_batch;
34 dim_t current_M, current_K;
35 };
36
37 virtual void operator()(ctx_t *ctx) = 0;
38 virtual status_t create_kernel() = 0;
39
40 jit_brgemm_trans_src_t(const jit_brgemm_primitive_conf_t *conf)
41 : conf_(conf) {}
42 virtual ~jit_brgemm_trans_src_t() {}
43
44 const jit_brgemm_primitive_conf_t *conf_;
45};
46
47struct jit_brgemm_copy_to_coarse_t : public jit_generator {
48 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_copy_to_coarse_t)
49
50 struct ctx_t {
51 const void *data;
52 const void *tr_data;
53
54 dim_t os_work;
55 dim_t last_row_blk;
56 };
57
58 void operator()(ctx_t *ctx) { jit_generator::operator()(ctx); }
59 status_t create_kernel() override { return jit_generator::create_kernel(); }
60
61 jit_brgemm_copy_to_coarse_t(const jit_brgemm_primitive_conf_t *conf)
62 : jit_generator(jit_name())
63 , conf_(conf)
64 , typesize_(sizeof(float) / data_type_vnni_granularity(conf_->wei_dt))
65 , is_fwd_dir_(utils::one_of(conf_->prop_kind,
66 prop_kind::forward_training, prop_kind::forward_inference))
67 , row_block_size_(is_fwd_dir_ ? conf_->ic_block : conf_->oc_block)
68 , row_size_(is_fwd_dir_ ? conf_->ic_without_padding
69 : conf_->oc_without_padding)
70 , tr_row_size_(conf_->LDA)
71 , row_granularity_(granularity_in_bytes / typesize_)
72 , row_step_(zmm_size_in_bytes / typesize_)
73 , data_stride_(row_size_ * typesize_)
74 , tr_data_stride_(tr_row_size_ * typesize_) {
75
76 // Kernel is supposed to be called under the following constraints
77 assert(is_superset(conf_->isa, avx512_core_amx));
78 assert(row_size_ % row_granularity_ != 0);
79
80 MAYBE_UNUSED(row_granularity_);
81 }
82 ~jit_brgemm_copy_to_coarse_t() {}
83
84private:
85 enum {
86 zmm_size_in_bytes = 64,
87 row_loop_unroll = 16,
88 granularity_in_bytes = 4,
89 };
90
91 const jit_brgemm_primitive_conf_t *conf_;
92 const int typesize_;
93 const bool is_fwd_dir_;
94 const int row_block_size_, row_size_, tr_row_size_, row_granularity_,
95 row_step_;
96 const dim_t data_stride_, tr_data_stride_;
97
98 inline size_t addr_offset(int row_idx) {
99 return row_idx * row_step_ * typesize_;
100 }
101
102 inline Xbyak::Zmm get_zmm_copy(int row_idx) const {
103 assert(row_idx >= 0 && row_idx < row_loop_unroll);
104 return Xbyak::Zmm(row_idx);
105 }
106
107 const Xbyak::Zmm zmm_zero = Xbyak::Zmm(row_loop_unroll);
108 const Xbyak::Zmm zmm_row_tail = Xbyak::Zmm(row_loop_unroll + 1);
109
110 const Xbyak::Opmask reg_m_full_row_tail_load = k7;
111 const Xbyak::Opmask reg_m_full_row_tail_store = k6;
112 const Xbyak::Opmask reg_m_last_row_tail_load = k5;
113 const Xbyak::Opmask reg_m_last_row_tail_store = k4;
114
115 const Xbyak::Reg64 reg_data = rax;
116 const Xbyak::Reg64 reg_tr_data = rbx;
117
118 const Xbyak::Reg64 reg_os_work = r11;
119 const Xbyak::Reg64 reg_last_row_blk = r12;
120 const Xbyak::Reg64 reg_tail_mask = r13;
121
122 void copy_os_loop();
123 void copy_row_loop();
124
125 void copy_row_blks(int num_row_blks);
126 void copy_row_tail(bool is_last_iteration, int row_offset);
127 void zero_out_rows();
128
129 void set_full_row_tail_masks();
130 void set_last_row_tail_masks();
131
132 void generate() override;
133};
134
135struct jit_brgemm_trans_to_vnni_t {
136 struct ctx_t {
137 const void *src;
138 const void *tr_src;
139
140 dim_t current_gemm_batch;
141 dim_t current_col_size, current_row_size;
142 };
143
144 typedef enum matrix_to_transform {
145 matrix_B,
146 matrix_C
147 } matrix_to_transform_t;
148
149 virtual void operator()(ctx_t *ctx) = 0;
150 virtual status_t create_kernel() = 0;
151
152 jit_brgemm_trans_to_vnni_t(const jit_brgemm_primitive_conf_t *conf,
153 matrix_to_transform_t matrix_to_transform)
154 : conf_(conf), matrix_to_transform_(matrix_to_transform) {}
155 virtual ~jit_brgemm_trans_to_vnni_t() {}
156
157 const jit_brgemm_primitive_conf_t *conf_;
158 matrix_to_transform_t matrix_to_transform_;
159};
160
161struct jit_brgemm_trans_wei_t {
162 struct ctx_t {
163 const void *src;
164 const void *tr_src;
165
166 dim_t current_gemm_batch;
167 dim_t current_N, current_K;
168 };
169
170 virtual void operator()(ctx_t *ctx) = 0;
171 virtual status_t create_kernel() = 0;
172
173 jit_brgemm_trans_wei_t(const jit_brgemm_primitive_conf_t *conf)
174 : conf_(conf) {}
175 virtual ~jit_brgemm_trans_wei_t() {}
176
177 const jit_brgemm_primitive_conf_t *conf_;
178};
179
180struct jit_amx_ip_trans_diff_wei {
181 struct ctx_t {
182 const void *src;
183 const void *dst;
184
185 size_t last_oc_block;
186 size_t last_ic_block;
187 };
188
189 virtual void operator()(ctx_t *ctx) = 0;
190 virtual status_t create_kernel() = 0;
191
192 jit_amx_ip_trans_diff_wei(const jit_brgemm_primitive_conf_t *jbgp,
193 const int ext_ic_block, const int ext_oc_block)
194 : jbgp_(jbgp)
195 , ext_ic_block_(ext_ic_block)
196 , ext_oc_block_(ext_oc_block) {}
197
198 virtual ~jit_amx_ip_trans_diff_wei() {}
199
200 const jit_brgemm_primitive_conf_t *jbgp_;
201
202 int ext_ic_block_ = 0;
203 int ext_oc_block_ = 0;
204};
205
206status_t create_brgemm_trans_src(
207 std::unique_ptr<jit_brgemm_trans_src_t> &trans_ker,
208 const jit_brgemm_primitive_conf_t *conf);
209status_t create_brgemm_copy_to_coarse(
210 std::unique_ptr<jit_brgemm_copy_to_coarse_t> &copy_ker,
211 const jit_brgemm_primitive_conf_t *conf);
212status_t create_brgemm_trans_to_vnni(
213 std::unique_ptr<jit_brgemm_trans_to_vnni_t> &trans_ker,
214 const jit_brgemm_primitive_conf_t *conf,
215 jit_brgemm_trans_to_vnni_t::matrix_to_transform_t matrix_to_transform);
216status_t create_brgemm_trans_wei(
217 std::unique_ptr<jit_brgemm_trans_wei_t> &trans_ker,
218 const jit_brgemm_primitive_conf_t *conf);
219status_t create_brgemm_amx_ip_trans_wei(
220 std::unique_ptr<jit_amx_ip_trans_diff_wei> &trans_ker,
221 const jit_brgemm_primitive_conf_t *conf, const int ext_ic_block,
222 const int ext_oc_block);
223} // namespace x64
224} // namespace cpu
225} // namespace impl
226} // namespace dnnl
227
228#endif
229