1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_BRGEMM_TRANSPOSE_UTILS_HPP |
18 | #define CPU_X64_JIT_BRGEMM_TRANSPOSE_UTILS_HPP |
19 | |
20 | #include "cpu/x64/jit_brgemm_primitive_conf.hpp" |
21 | #include "cpu/x64/jit_generator.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace x64 { |
27 | |
28 | struct jit_brgemm_trans_src_t { |
29 | struct ctx_t { |
30 | const void *src; |
31 | const void *tr_src; |
32 | |
33 | dim_t current_gemm_batch; |
34 | dim_t current_M, current_K; |
35 | }; |
36 | |
37 | virtual void operator()(ctx_t *ctx) = 0; |
38 | virtual status_t create_kernel() = 0; |
39 | |
40 | jit_brgemm_trans_src_t(const jit_brgemm_primitive_conf_t *conf) |
41 | : conf_(conf) {} |
42 | virtual ~jit_brgemm_trans_src_t() {} |
43 | |
44 | const jit_brgemm_primitive_conf_t *conf_; |
45 | }; |
46 | |
47 | struct jit_brgemm_copy_to_coarse_t : public jit_generator { |
48 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_copy_to_coarse_t) |
49 | |
50 | struct ctx_t { |
51 | const void *data; |
52 | const void *tr_data; |
53 | |
54 | dim_t os_work; |
55 | dim_t last_row_blk; |
56 | }; |
57 | |
58 | void operator()(ctx_t *ctx) { jit_generator::operator()(ctx); } |
59 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
60 | |
61 | jit_brgemm_copy_to_coarse_t(const jit_brgemm_primitive_conf_t *conf) |
62 | : jit_generator(jit_name()) |
63 | , conf_(conf) |
64 | , typesize_(sizeof(float) / data_type_vnni_granularity(conf_->wei_dt)) |
65 | , is_fwd_dir_(utils::one_of(conf_->prop_kind, |
66 | prop_kind::forward_training, prop_kind::forward_inference)) |
67 | , row_block_size_(is_fwd_dir_ ? conf_->ic_block : conf_->oc_block) |
68 | , row_size_(is_fwd_dir_ ? conf_->ic_without_padding |
69 | : conf_->oc_without_padding) |
70 | , tr_row_size_(conf_->LDA) |
71 | , row_granularity_(granularity_in_bytes / typesize_) |
72 | , row_step_(zmm_size_in_bytes / typesize_) |
73 | , data_stride_(row_size_ * typesize_) |
74 | , tr_data_stride_(tr_row_size_ * typesize_) { |
75 | |
76 | // Kernel is supposed to be called under the following constraints |
77 | assert(is_superset(conf_->isa, avx512_core_amx)); |
78 | assert(row_size_ % row_granularity_ != 0); |
79 | |
80 | MAYBE_UNUSED(row_granularity_); |
81 | } |
82 | ~jit_brgemm_copy_to_coarse_t() {} |
83 | |
84 | private: |
85 | enum { |
86 | zmm_size_in_bytes = 64, |
87 | row_loop_unroll = 16, |
88 | granularity_in_bytes = 4, |
89 | }; |
90 | |
91 | const jit_brgemm_primitive_conf_t *conf_; |
92 | const int typesize_; |
93 | const bool is_fwd_dir_; |
94 | const int row_block_size_, row_size_, tr_row_size_, row_granularity_, |
95 | row_step_; |
96 | const dim_t data_stride_, tr_data_stride_; |
97 | |
98 | inline size_t addr_offset(int row_idx) { |
99 | return row_idx * row_step_ * typesize_; |
100 | } |
101 | |
102 | inline Xbyak::Zmm get_zmm_copy(int row_idx) const { |
103 | assert(row_idx >= 0 && row_idx < row_loop_unroll); |
104 | return Xbyak::Zmm(row_idx); |
105 | } |
106 | |
107 | const Xbyak::Zmm zmm_zero = Xbyak::Zmm(row_loop_unroll); |
108 | const Xbyak::Zmm zmm_row_tail = Xbyak::Zmm(row_loop_unroll + 1); |
109 | |
110 | const Xbyak::Opmask reg_m_full_row_tail_load = k7; |
111 | const Xbyak::Opmask reg_m_full_row_tail_store = k6; |
112 | const Xbyak::Opmask reg_m_last_row_tail_load = k5; |
113 | const Xbyak::Opmask reg_m_last_row_tail_store = k4; |
114 | |
115 | const Xbyak::Reg64 reg_data = rax; |
116 | const Xbyak::Reg64 reg_tr_data = rbx; |
117 | |
118 | const Xbyak::Reg64 reg_os_work = r11; |
119 | const Xbyak::Reg64 reg_last_row_blk = r12; |
120 | const Xbyak::Reg64 reg_tail_mask = r13; |
121 | |
122 | void copy_os_loop(); |
123 | void copy_row_loop(); |
124 | |
125 | void copy_row_blks(int num_row_blks); |
126 | void copy_row_tail(bool is_last_iteration, int row_offset); |
127 | void zero_out_rows(); |
128 | |
129 | void set_full_row_tail_masks(); |
130 | void set_last_row_tail_masks(); |
131 | |
132 | void generate() override; |
133 | }; |
134 | |
135 | struct jit_brgemm_trans_to_vnni_t { |
136 | struct ctx_t { |
137 | const void *src; |
138 | const void *tr_src; |
139 | |
140 | dim_t current_gemm_batch; |
141 | dim_t current_col_size, current_row_size; |
142 | }; |
143 | |
144 | typedef enum matrix_to_transform { |
145 | matrix_B, |
146 | matrix_C |
147 | } matrix_to_transform_t; |
148 | |
149 | virtual void operator()(ctx_t *ctx) = 0; |
150 | virtual status_t create_kernel() = 0; |
151 | |
152 | jit_brgemm_trans_to_vnni_t(const jit_brgemm_primitive_conf_t *conf, |
153 | matrix_to_transform_t matrix_to_transform) |
154 | : conf_(conf), matrix_to_transform_(matrix_to_transform) {} |
155 | virtual ~jit_brgemm_trans_to_vnni_t() {} |
156 | |
157 | const jit_brgemm_primitive_conf_t *conf_; |
158 | matrix_to_transform_t matrix_to_transform_; |
159 | }; |
160 | |
161 | struct jit_brgemm_trans_wei_t { |
162 | struct ctx_t { |
163 | const void *src; |
164 | const void *tr_src; |
165 | |
166 | dim_t current_gemm_batch; |
167 | dim_t current_N, current_K; |
168 | }; |
169 | |
170 | virtual void operator()(ctx_t *ctx) = 0; |
171 | virtual status_t create_kernel() = 0; |
172 | |
173 | jit_brgemm_trans_wei_t(const jit_brgemm_primitive_conf_t *conf) |
174 | : conf_(conf) {} |
175 | virtual ~jit_brgemm_trans_wei_t() {} |
176 | |
177 | const jit_brgemm_primitive_conf_t *conf_; |
178 | }; |
179 | |
180 | struct jit_amx_ip_trans_diff_wei { |
181 | struct ctx_t { |
182 | const void *src; |
183 | const void *dst; |
184 | |
185 | size_t last_oc_block; |
186 | size_t last_ic_block; |
187 | }; |
188 | |
189 | virtual void operator()(ctx_t *ctx) = 0; |
190 | virtual status_t create_kernel() = 0; |
191 | |
192 | jit_amx_ip_trans_diff_wei(const jit_brgemm_primitive_conf_t *jbgp, |
193 | const int ext_ic_block, const int ext_oc_block) |
194 | : jbgp_(jbgp) |
195 | , ext_ic_block_(ext_ic_block) |
196 | , ext_oc_block_(ext_oc_block) {} |
197 | |
198 | virtual ~jit_amx_ip_trans_diff_wei() {} |
199 | |
200 | const jit_brgemm_primitive_conf_t *jbgp_; |
201 | |
202 | int ext_ic_block_ = 0; |
203 | int ext_oc_block_ = 0; |
204 | }; |
205 | |
206 | status_t create_brgemm_trans_src( |
207 | std::unique_ptr<jit_brgemm_trans_src_t> &trans_ker, |
208 | const jit_brgemm_primitive_conf_t *conf); |
209 | status_t create_brgemm_copy_to_coarse( |
210 | std::unique_ptr<jit_brgemm_copy_to_coarse_t> ©_ker, |
211 | const jit_brgemm_primitive_conf_t *conf); |
212 | status_t create_brgemm_trans_to_vnni( |
213 | std::unique_ptr<jit_brgemm_trans_to_vnni_t> &trans_ker, |
214 | const jit_brgemm_primitive_conf_t *conf, |
215 | jit_brgemm_trans_to_vnni_t::matrix_to_transform_t matrix_to_transform); |
216 | status_t create_brgemm_trans_wei( |
217 | std::unique_ptr<jit_brgemm_trans_wei_t> &trans_ker, |
218 | const jit_brgemm_primitive_conf_t *conf); |
219 | status_t create_brgemm_amx_ip_trans_wei( |
220 | std::unique_ptr<jit_amx_ip_trans_diff_wei> &trans_ker, |
221 | const jit_brgemm_primitive_conf_t *conf, const int ext_ic_block, |
222 | const int ext_oc_block); |
223 | } // namespace x64 |
224 | } // namespace cpu |
225 | } // namespace impl |
226 | } // namespace dnnl |
227 | |
228 | #endif |
229 | |