1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_BRGEMM_INNER_PRODUCT_HPP |
18 | #define CPU_X64_BRGEMM_INNER_PRODUCT_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/cpu_inner_product_pd.hpp" |
27 | |
28 | #include "cpu/x64/amx_tile_configure.hpp" |
29 | #include "cpu/x64/brgemm/brgemm.hpp" |
30 | #include "cpu/x64/cpu_barrier.hpp" |
31 | #include "cpu/x64/cpu_reducer.hpp" |
32 | #include "cpu/x64/jit_brgemm_inner_product_utils.hpp" |
33 | #include "cpu/x64/jit_brgemm_post_ops.hpp" |
34 | #include "cpu/x64/jit_brgemm_transpose_utils.hpp" |
35 | #include "cpu/x64/jit_transpose_utils.hpp" |
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | namespace cpu { |
40 | namespace x64 { |
41 | |
// Forward inner product implemented on top of batch-reduce GEMM (BRGEMM)
// kernels. The descriptor (pd_t) pre-builds a BRGEMM descriptor for every
// kernel variant (batch-size tail x accumulator-init x M/N/K tails); the
// primitive's init() then JITs the corresponding kernels.
template <cpu_isa_t isa>
struct brgemm_inner_product_fwd_t : public primitive_t {
    struct pd_t : public cpu_inner_product_fwd_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_inner_product_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgemm:" , isa, "" ),
                brgemm_inner_product_fwd_t);

        // Validates the problem and attributes, fills the BRGEMM IP config
        // (jbgp_), and initializes one BRGEMM descriptor per kernel variant.
        // Returns status::unimplemented when any restriction is violated.
        status_t init(engine_t *engine) {
            using namespace utils;
            using namespace data_type;

            auto src_dt = invariant_src_md()->data_type;
            auto dst_dt = invariant_dst_md()->data_type;
            auto wei_dt = invariant_wei_md()->data_type;
            const bool is_int8 = one_of(src_dt, u8, s8);

            // Runtime scales are accepted only for the int8 flavor.
            using skip_mask_t = primitive_attr_t::skip_mask_t;
            auto skip_mask = skip_mask_t::post_ops;
            if (is_int8) skip_mask |= skip_mask_t::scales_runtime;

            bool ok = is_fwd() && mayiuse(isa)
                    && expect_data_types(src_dt, wei_dt, data_type::undef,
                            dst_dt, data_type::undef)
                    && IMPLICATION(with_bias() && is_int8,
                            one_of(bias_md_.data_type, f32, bf16, s32, s8, u8))
                    && IMPLICATION(with_bias() && !is_int8,
                            one_of(bias_md_.data_type, f32, src_dt))
                    && attr()->has_default_values(skip_mask, dst_dt)
                    && attr()->post_ops_.check_sum_consistent_dt(dst_dt)
                    && !has_zero_dim_memory() && arg_scales_ok();
            if (!ok) return status::unimplemented;

            CHECK(brgemm_inner_product_utils::init_ip_conf(isa, jbgp_, *desc(),
                    src_md_, weights_md_, dst_md_, bias_md_, attr_,
                    dnnl_get_max_threads()));

            // True when an epilogue (bias, scales, post-ops, or data-type
            // conversion) must run after accumulation.
            bool are_post_ops_applicable = one_of(true, jbgp_.with_sum,
                    jbgp_.with_bias, jbgp_.with_scales, jbgp_.with_eltwise,
                    jbgp_.with_binary, jbgp_.acc_dt != jbgp_.dst_dt,
                    jbgp_.signed_input);

            const float alpha = 1.0;
            const float beta = 1.0; // accumulate into C
            const float beta_init = 0.0; // overwrite C on the first K chunk

            // Enumerate all 2^5 kernel variants; invalid/unneeded ones are
            // skipped via get_brg_kernel_idx() returning -1.
            for_(int i_bs = 0; i_bs < 2; i_bs++)
            for_(int i_init = 0; i_init < 2; i_init++)
            for_(int i_M = 0; i_M < 2; i_M++)
            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_K = 0; i_K < 2; i_K++) {
                auto vbeta = (i_init) ? beta_init : beta;
                auto vM = (i_M) ? jbgp_.M_tail : jbgp_.M;
                auto vN = (i_N) ? jbgp_.N_tail : jbgp_.N;
                auto vK = (i_K) ? jbgp_.K_tail : jbgp_.K;
                int bs = get_brg_batchsize(i_bs, i_K);
                int idx = get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
                if (idx < 0) continue;
                brgemm_t &brg = brg_descs_[idx];
                CHECK(brgemm_desc_init(&brg, isa, jbgp_.brg_type, jbgp_.src_dt,
                        jbgp_.wei_dt, false, false, brgemm_row_major, alpha,
                        vbeta, jbgp_.LDA, jbgp_.LDB, jbgp_.LDC, vM, vN, vK));

                // LDD: leading dimension of the destination (unpadded OC).
                auto LDD = jbgp_.oc_without_padding;
                CHECK(brgemm_desc_set_postops(
                        &brg, attr(), &dst_md_, LDD, jbgp_.bia_dt));

                // With IC parallelized over threads the epilogue is applied
                // in a separate reduction pass, so the kernel must be able
                // to skip its accumulation work.
                if (are_post_ops_applicable && jbgp_.nthr_ic_b > 1) {
                    brgemm_attr_t brgattr;
                    brgattr.generate_skip_accumulation = true;
                    CHECK(brgemm_desc_set_attr(&brg, brgattr));
                }
                if (jbgp_.is_amx) {
                    // AMX-specific tuning hints and workspace sizing.
                    brgemm_attr_t brgattr;
                    brgattr.max_bs = bs;
                    brgattr.wary_tail_read = false;
                    brgattr.hint_expected_A_size = jbgp_.mb * jbgp_.ic;
                    brgattr.hint_expected_B_size = jbgp_.oc * jbgp_.ic;
                    brgattr.hint_expected_C_size = jbgp_.mb * jbgp_.oc;
                    brgattr.hint_innermost_loop = brgemm_ld_loop_innermost;
                    brgattr.use_uker = jbgp_.use_uker;
                    brgattr.use_interleave_stores = jbgp_.use_interleave_stores;
                    brgattr.hint_prefetching = jbgp_.hint_prefetching;
                    brgattr.fpmath_mode = attr()->fpmath_mode_;

                    CHECK(brgemm_desc_set_attr(&brg, brgattr));
                    // Track the largest AMX scratch buffer any variant needs.
                    jbgp_.amx_buf_size_per_thread
                            = nstl::max(brg.get_wsp_buffer_size(),
                                    jbgp_.amx_buf_size_per_thread);
                }
            }

            auto scratchpad = scratchpad_registry().registrar();
            brgemm_inner_product_utils::init_scratchpad(scratchpad, jbgp_);
            if (jbgp_.with_scales)
                book_precomputed_scales(scratchpad, attr()->scales_, OC());

            return status::success;
        }

        // Scales are restricted to per-tensor masks, except weights which
        // may additionally use a per-OC mask (1 << 0).
        bool arg_scales_ok() const {
            std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
            bool ok = true;
            ok = ok && attr()->scales_.has_default_values(supported_args);
            for (int arg : supported_args) {
                const auto &mask = attr()->scales_.get(arg).mask_;
                if (arg == DNNL_ARG_WEIGHTS)
                    ok = ok && (mask == 0 || mask == (1 << 0));
                else
                    ok = ok && (mask == 0);
            }
            return ok;
        }

        // Maps a kernel variant to its slot in brg_descs_. Returns -1 for
        // degenerate variants: a zero dimension, zero batch size, or a
        // leading dimension smaller than the corresponding tile extent.
        int get_brg_kernel_idx(bool is_bs_tail, bool do_initialization,
                bool is_M_tail, bool is_N_tail, bool is_K_tail, int bs) const {
            auto vM = (is_M_tail) ? jbgp_.M_tail : jbgp_.M;
            auto vN = (is_N_tail) ? jbgp_.N_tail : jbgp_.N;
            auto vK = (is_K_tail) ? jbgp_.K_tail : jbgp_.K;

            if (vM == 0 || vN == 0 || vK == 0 || bs == 0 || jbgp_.LDA < vK
                    || jbgp_.LDB < vN || jbgp_.LDC < vN)
                return -1;
            return brgemm_inner_product_utils::get_brg_kernel_index(jbgp_,
                    is_bs_tail, do_initialization, is_M_tail, is_N_tail,
                    is_K_tail);
        }

        // Number of K blocks fused into a single BRGEMM call ("batch size")
        // for the given tail combination; K-tail variants always use 1.
        int get_brg_batchsize(bool is_bs_tail, bool is_K_tail) const {
            // When src is copied into a buffer, IC is padded up to a whole
            // number of ic blocks.
            auto adj_ic = jbgp_.use_buffer_a
                    ? utils::rnd_up(jbgp_.ic, jbgp_.ic_block)
                    : jbgp_.ic;
            auto bs = (is_K_tail)
                    ? 1
                    : ((is_bs_tail) ? (adj_ic / jbgp_.K) % jbgp_.gemm_batch_size
                                    : jbgp_.gemm_batch_size);
            return bs;
        }

        // Pre-built BRGEMM descriptors, one per valid kernel variant.
        brgemm_t brg_descs_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
        jit_brgemm_primitive_conf_t jbgp_; // BRGEMM inner-product config
    };

    brgemm_inner_product_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    // JITs one kernel per valid descriptor, records the AMX tile palette
    // per kernel, and creates the optional src copy kernel and the f32
    // accumulator used when IC is split across threads.
    status_t init(engine_t *engine) override {
        for_(int i_bs = 0; i_bs < 2; i_bs++)
        for_(int i_M = 0; i_M < 2; i_M++)
        for_(int i_N = 0; i_N < 2; i_N++)
        for_(int i_K = 0; i_K < 2; i_K++)
        for (int i_init = 0; i_init < 2; i_init++) {
            int bs = pd()->get_brg_batchsize(i_bs, i_K);
            int idx = pd()->get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
            if (idx < 0) continue; // variant not used for this shape

            brgemm_kernel_t *ker = nullptr;
            CHECK(brgemm_kernel_create(&ker, pd()->brg_descs_[idx]));
            CHECK(safe_ptr_assign(brg_kernels_[idx], ker));
            if (pd()->jbgp_.is_amx)
                CHECK(brgemm_init_tiles(
                        pd()->brg_descs_[idx], &brg_kernel_palettes_[idx][0]));
        }
        if (pd()->jbgp_.use_buffer_a)
            CHECK(create_brgemm_copy_to_coarse(copy_src_kernel_, &pd()->jbgp_));
        if (pd()->jbgp_.nthr_ic_b > 1) {
            // Partial results over the IC dimension are reduced in f32.
            CHECK(safe_ptr_assign(
                    acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
            CHECK(acc_ker_->create_kernel());
        }
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    // JIT-ed BRGEMM kernels, indexed by get_brg_kernel_idx().
    std::unique_ptr<brgemm_kernel_t>
            brg_kernels_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
    std::unique_ptr<jit_brgemm_copy_to_coarse_t> copy_src_kernel_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
    // One AMX tile-configuration palette per kernel variant.
    char brg_kernel_palettes_[brgemm_inner_product_utils::
                    max_num_brg_kernels_ip][AMX_PALETTE_SIZE];
};
231 | |
// Backward-data inner product on BRGEMM kernels: computes diff_src from
// diff_dst and weights. Mirrors the forward implementation's kernel-variant
// enumeration; see brgemm_inner_product_fwd_t for the scheme.
template <cpu_isa_t isa>
struct brgemm_inner_product_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_inner_product_bwd_data_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const inner_product_fwd_pd_t *hint_fwd_pd)
            : cpu_inner_product_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgemm_bwd_d:" , isa, "" ),
                brgemm_inner_product_bwd_data_t);

        // Validates the problem, fills jbgp_, and pre-builds a BRGEMM
        // descriptor per kernel variant (A = diff_dst, B = weights).
        status_t init(engine_t *engine) {

            auto diff_src_dt = invariant_src_md()->data_type;
            auto diff_dst_dt = invariant_dst_md()->data_type;
            auto wei_dt = invariant_wei_md()->data_type;

            bool ok = true && desc()->prop_kind == prop_kind::backward_data
                    && !has_zero_dim_memory() && mayiuse(isa)
                    && utils::one_of(diff_dst_dt, data_type::f32,
                            data_type::bf16, data_type::f16)
                    && wei_dt == diff_dst_dt
                    && utils::one_of(diff_src_dt, data_type::f32, diff_dst_dt)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops);
            if (!ok) return status::unimplemented;

            // No bias in backward data; a dummy md satisfies the interface.
            memory_desc_t dummy_bias_md;
            CHECK(brgemm_inner_product_utils::init_ip_conf(isa, jbgp_, *desc(),
                    diff_src_md_, weights_md_, diff_dst_md_, dummy_bias_md,
                    attr_, dnnl_get_max_threads()));

            const float alpha = 1.0;
            const float beta = 1.0; // accumulate into C
            const float beta_init = 0.0; // overwrite C on the first chunk
            // On fp16 ISA with a transposed-weights buffer, the B matrix is
            // kept in f32; otherwise use the weights data type as-is.
            const auto dt_b = isa == avx512_core_fp16 && jbgp_.use_buffer_b
                    ? data_type::f32
                    : wei_dt;

            // Enumerate all 2^5 kernel variants (see fwd for the scheme).
            for_(int i_bs = 0; i_bs < 2; i_bs++)
            for_(int i_init = 0; i_init < 2; i_init++)
            for_(int i_M = 0; i_M < 2; i_M++)
            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_K = 0; i_K < 2; i_K++) {
                auto vbeta = (i_init) ? beta_init : beta;
                auto vM = (i_M) ? jbgp_.M_tail : jbgp_.M;
                auto vN = (i_N) ? jbgp_.N_tail : jbgp_.N;
                auto vK = (i_K) ? jbgp_.K_tail : jbgp_.K;
                int bs = get_brg_batchsize(i_bs, i_K);
                int idx = get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
                if (idx < 0) continue;

                brgemm_t &brg = brg_descs_[idx];
                CHECK(brgemm_desc_init(&brg, isa, jbgp_.brg_type, diff_dst_dt,
                        dt_b, false, false, brgemm_row_major, alpha, vbeta,
                        jbgp_.LDA, jbgp_.LDB, jbgp_.LDC, vM, vN, vK));

                // LDD: leading dimension of diff_src (unpadded IC).
                auto LDD = jbgp_.ic_without_padding;
                CHECK(brgemm_desc_set_postops(
                        &brg, attr(), &diff_src_md_, LDD, jbgp_.bia_dt));
                if (jbgp_.is_amx) {
                    // AMX-specific tuning hints and workspace sizing.
                    brgemm_attr_t brgattr;
                    brgattr.max_bs = bs;
                    brgattr.wary_tail_read = false;
                    brgattr.hint_expected_A_size = jbgp_.mb * jbgp_.oc;
                    brgattr.hint_expected_B_size = jbgp_.oc * jbgp_.ic;
                    brgattr.hint_expected_C_size = jbgp_.mb * jbgp_.ic;
                    brgattr.hint_innermost_loop = brgemm_ld_loop_innermost;
                    brgattr.use_uker = jbgp_.use_uker;
                    brgattr.use_interleave_stores = jbgp_.use_interleave_stores;
                    brgattr.hint_prefetching = jbgp_.hint_prefetching;
                    brgattr.fpmath_mode = attr()->fpmath_mode_;

                    CHECK(brgemm_desc_set_attr(&brg, brgattr));
                    // Track the largest AMX scratch buffer any variant needs.
                    jbgp_.amx_buf_size_per_thread
                            = nstl::max(brg.get_wsp_buffer_size(),
                                    jbgp_.amx_buf_size_per_thread);
                }
            }

            auto scratchpad = scratchpad_registry().registrar();
            brgemm_inner_product_utils::init_scratchpad(scratchpad, jbgp_);

            return status::success;
        }

        // Maps a kernel variant to its slot in brg_descs_; -1 for
        // degenerate variants (zero dims, zero bs, or too-small LDs).
        int get_brg_kernel_idx(bool is_bs_tail, bool do_initialization,
                bool is_M_tail, bool is_N_tail, bool is_K_tail, int bs) const {
            auto vM = (is_M_tail) ? jbgp_.M_tail : jbgp_.M;
            auto vN = (is_N_tail) ? jbgp_.N_tail : jbgp_.N;
            auto vK = (is_K_tail) ? jbgp_.K_tail : jbgp_.K;

            if (vM == 0 || vN == 0 || vK == 0 || bs == 0 || jbgp_.LDA < vK
                    || jbgp_.LDB < vN || jbgp_.LDC < vN)
                return -1;
            return brgemm_inner_product_utils::get_brg_kernel_index(jbgp_,
                    is_bs_tail, do_initialization, is_M_tail, is_N_tail,
                    is_K_tail);
        }

        // Batch size (number of OC blocks reduced per BRGEMM call) for the
        // given tail combination; K-tail variants always use 1.
        int get_brg_batchsize(bool is_bs_tail, bool is_K_tail) const {
            // With a copy buffer for A, OC is padded to whole oc blocks.
            auto adj_oc = jbgp_.use_buffer_a
                    ? utils::rnd_up(jbgp_.oc, jbgp_.oc_block)
                    : jbgp_.oc;
            auto bs = (is_K_tail) ? 1
                                  : ((is_bs_tail) ? (adj_oc / jbgp_.oc_block)
                                                    % jbgp_.nb_oc_blocking
                                                  : jbgp_.nb_oc_blocking);

            return bs;
        }

        // Pre-built BRGEMM descriptors, one per valid kernel variant.
        brgemm_t brg_descs_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
        jit_brgemm_primitive_conf_t jbgp_; // BRGEMM inner-product config
    };

    brgemm_inner_product_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    // JITs kernels for all valid variants plus the optional diff_dst copy
    // kernel, weights transpose kernel, and f32 accumulator for the
    // OC-parallel reduction.
    status_t init(engine_t *engine) override {
        const auto &jbgp = pd()->jbgp_;
        for_(int i_bs = 0; i_bs < 2; i_bs++)
        for_(int i_M = 0; i_M < 2; i_M++)
        for_(int i_N = 0; i_N < 2; i_N++)
        for_(int i_K = 0; i_K < 2; i_K++)
        for (int i_init = 0; i_init < 2; i_init++) {
            int bs = pd()->get_brg_batchsize(i_bs, i_K);
            int idx = pd()->get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
            if (idx < 0) continue; // variant not used for this shape

            brgemm_kernel_t *ker = nullptr;
            CHECK(brgemm_kernel_create(&ker, pd()->brg_descs_[idx]));
            CHECK(safe_ptr_assign(brg_kernels_[idx], ker));
            if (jbgp.is_amx)
                CHECK(brgemm_init_tiles(
                        pd()->brg_descs_[idx], &brg_kernel_palettes_[idx][0]));
        }

        if (pd()->jbgp_.use_buffer_a)
            CHECK(create_brgemm_copy_to_coarse(
                    copy_diff_dst_kernel_, &pd()->jbgp_));
        if (jbgp.use_buffer_b)
            CHECK(create_brgemm_trans_wei(trans_B_kernel_, &pd()->jbgp_));

        if (jbgp.nthr_oc_b > 1) {
            // Partial results over the OC dimension are reduced in f32.
            CHECK(safe_ptr_assign(
                    acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
            CHECK(acc_ker_->create_kernel());
        }

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_data(ctx);
        return status::success;
    }

private:
    void execute_backward_data(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    // JIT-ed BRGEMM kernels, indexed by get_brg_kernel_idx().
    std::unique_ptr<brgemm_kernel_t>
            brg_kernels_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
    std::unique_ptr<jit_brgemm_copy_to_coarse_t> copy_diff_dst_kernel_;
    std::unique_ptr<jit_brgemm_trans_wei_t> trans_B_kernel_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
    // One AMX tile-configuration palette per kernel variant.
    char brg_kernel_palettes_[brgemm_inner_product_utils::
                    max_num_brg_kernels_ip][AMX_PALETTE_SIZE];
};
400 | |
// Backward-weights inner product on BRGEMM kernels: computes diff_weights
// (and optionally diff_bias) from src and diff_dst. Uses the same
// kernel-variant enumeration scheme as the forward implementation.
template <cpu_isa_t isa>
struct brgemm_inner_product_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_inner_product_bwd_weights_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const inner_product_fwd_pd_t *hint_fwd_pd)
            : cpu_inner_product_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgemm_bwd_w:" , isa, "" ),
                brgemm_inner_product_bwd_weights_t);

        // Validates the problem, fills jbgp_, and pre-builds a BRGEMM
        // descriptor per kernel variant (A = transposed src, B = diff_dst).
        status_t init(engine_t *engine) {

            auto src_dt = invariant_src_md()->data_type;
            auto diff_wei_type = invariant_wei_md()->data_type;
            auto diff_dst_type = invariant_dst_md()->data_type;

            bool ok = true && desc()->prop_kind == prop_kind::backward_weights
                    && !has_zero_dim_memory() && mayiuse(isa)
                    && utils::one_of(src_dt, data_type::f32, data_type::bf16,
                            data_type::f16)
                    && diff_dst_type == src_dt
                    && utils::one_of(diff_wei_type, data_type::f32, src_dt)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops);
            if (!ok) return status::unimplemented;

            CHECK(brgemm_inner_product_utils::init_ip_conf(isa, jbgp_, *desc(),
                    src_md_, diff_weights_md_, diff_dst_md_, diff_bias_md_,
                    attr_, dnnl_get_max_threads()));

            const float alpha = 1.0;
            const float beta = 1.0; // accumulate into C
            const float beta_init = 0.0; // overwrite C on the first chunk
            // On fp16 ISA, buffered A/B matrices are kept in f32; otherwise
            // the original src/diff_dst data types are used.
            const auto dt_a = isa == avx512_core_fp16 && jbgp_.use_buffer_a
                    ? data_type::f32
                    : jbgp_.src_dt;
            const auto dt_b = isa == avx512_core_fp16 && jbgp_.use_buffer_b
                    ? data_type::f32
                    : jbgp_.dst_dt;

            // Enumerate all 2^5 kernel variants (see fwd for the scheme).
            for_(int i_bs = 0; i_bs < 2; i_bs++)
            for_(int i_init = 0; i_init < 2; i_init++)
            for_(int i_M = 0; i_M < 2; i_M++)
            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_K = 0; i_K < 2; i_K++) {
                auto vbeta = (i_init) ? beta_init : beta;
                auto vM = (i_M) ? jbgp_.M_tail : jbgp_.M;
                auto vN = (i_N) ? jbgp_.N_tail : jbgp_.N;
                auto vK = (i_K) ? jbgp_.K_tail : jbgp_.K;
                int bs = get_brg_batchsize(i_bs, i_K);
                int idx = get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
                if (idx < 0) continue;
                brgemm_t &brg = brg_descs_[idx];
                CHECK(brgemm_desc_init(&brg, isa, jbgp_.brg_type, dt_a, dt_b,
                        false, false, brgemm_row_major, alpha, vbeta, jbgp_.LDA,
                        jbgp_.LDB, jbgp_.LDC, vM, vN, vK));
                if (jbgp_.is_amx) {
                    // AMX-specific tuning hints and workspace sizing.
                    brgemm_attr_t brgattr;
                    brgattr.max_bs = bs;
                    brgattr.wary_tail_read = false;
                    brgattr.hint_expected_A_size = jbgp_.mb * jbgp_.ic;
                    brgattr.hint_expected_B_size = jbgp_.mb * jbgp_.oc;
                    brgattr.hint_expected_C_size = jbgp_.ic * jbgp_.oc;
                    brgattr.hint_innermost_loop = brgemm_ld_loop_innermost;
                    brgattr.use_uker = jbgp_.use_uker;
                    brgattr.use_interleave_stores = jbgp_.use_interleave_stores;
                    brgattr.hint_prefetching = jbgp_.hint_prefetching;
                    brgattr.fpmath_mode = attr()->fpmath_mode_;

                    CHECK(brgemm_desc_set_attr(&brg, brgattr));
                    // Track the largest AMX scratch buffer any variant needs.
                    jbgp_.amx_buf_size_per_thread
                            = nstl::max(brg.get_wsp_buffer_size(),
                                    jbgp_.amx_buf_size_per_thread);
                }
            }

            auto scratchpad = scratchpad_registry().registrar();
            brgemm_inner_product_utils::init_scratchpad(scratchpad, jbgp_);

            return status::success;
        }

        // Maps a kernel variant to its slot in brg_descs_; -1 for
        // degenerate variants (zero dims, zero bs, or too-small LDs).
        int get_brg_kernel_idx(bool is_bs_tail, bool do_initialization,
                bool is_M_tail, bool is_N_tail, bool is_K_tail, int bs) const {
            auto vM = (is_M_tail) ? jbgp_.M_tail : jbgp_.M;
            auto vN = (is_N_tail) ? jbgp_.N_tail : jbgp_.N;
            auto vK = (is_K_tail) ? jbgp_.K_tail : jbgp_.K;

            if (vM == 0 || vN == 0 || vK == 0 || bs == 0 || jbgp_.LDA < vK
                    || jbgp_.LDB < vN || jbgp_.LDC < vN)
                return -1;
            return brgemm_inner_product_utils::get_brg_kernel_index(jbgp_,
                    is_bs_tail, do_initialization, is_M_tail, is_N_tail,
                    is_K_tail);
        }

        // Batch size (number of os blocks reduced per BRGEMM call) for the
        // given tail combination; K-tail variants always use 1.
        int get_brg_batchsize(bool is_bs_tail, bool is_K_tail) const {
            auto bs = (is_K_tail) ? 1
                                  : ((is_bs_tail) ? (jbgp_.os / jbgp_.os_block)
                                                    % jbgp_.nb_os_blocking
                                                  : jbgp_.nb_os_blocking);
            return bs;
        }

        // Pre-built BRGEMM descriptors, one per valid kernel variant.
        brgemm_t brg_descs_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
        jit_brgemm_primitive_conf_t jbgp_; // BRGEMM inner-product config
    };

    brgemm_inner_product_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}

    // JITs the BRGEMM kernels, the diff-bias reduction kernels (one per
    // K-tail x N-tail combination), and the transpose/VNNI-reorder and
    // accumulation helper kernels required by the chosen configuration.
    status_t init(engine_t *engine) override {
        const auto &jbgp = pd()->jbgp_;
        for_(int i_bs = 0; i_bs < 2; i_bs++)
        for_(int i_M = 0; i_M < 2; i_M++)
        for_(int i_N = 0; i_N < 2; i_N++)
        for_(int i_K = 0; i_K < 2; i_K++)
        for (int i_init = 0; i_init < 2; i_init++) {
            int bs = pd()->get_brg_batchsize(i_bs, i_K);
            int idx = pd()->get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
            if (idx < 0) continue; // variant not used for this shape

            brgemm_kernel_t *ker = nullptr;
            CHECK(brgemm_kernel_create(&ker, pd()->brg_descs_[idx]));
            CHECK(safe_ptr_assign(brg_kernels_[idx], ker));
            if (jbgp.is_amx)
                CHECK(brgemm_init_tiles(
                        pd()->brg_descs_[idx], &brg_kernel_palettes_[idx][0]));

            // Diff-bias kernels are created once per (K-tail, N-tail)
            // combination, reusing this variant's BRGEMM descriptor.
            if (jbgp.with_bias && i_M == 0 && i_init == 0) {
                kernels_db_[i_K][i_N] = nullptr;
                auto db_desc = pd()->brg_descs_[idx];
                db_desc.reduce_dim = (i_K) ? jbgp.K_tail : jbgp.K;
                if (db_desc.reduce_dim > 0 && db_desc.load_dim > 0) {
                    CHECK(safe_ptr_assign(kernels_db_[i_K][i_N],
                            new jit_brgemm_kernel_diff_bias_t(jbgp, db_desc)));
                    CHECK(kernels_db_[i_K][i_N]->create_kernel());
                }
            }
        }
        if (jbgp.is_amx) {
            ext_ic_block_ = jbgp.ic_block_ext;
            ext_oc_block_ = jbgp.oc_block_ext;
        }
        // Matrix A (src) is always transposed for the weights update.
        CHECK(create_brgemm_trans_src(trans_A_kernel_, &pd()->jbgp_));

        if (jbgp.use_buffer_b)
            CHECK(create_brgemm_trans_to_vnni(trans_B_kernel_, &pd()->jbgp_,
                    jit_brgemm_trans_to_vnni_t::matrix_to_transform::matrix_B));

        if (!jbgp.is_amx) {
            // Non-AMX: convert the f32 accumulator to the wei dt if needed.
            if (jbgp.wei_dt != jbgp.acc_dt)
                CHECK(create_brgemm_trans_to_vnni(trans_C_kernel_, &pd()->jbgp_,
                        jit_brgemm_trans_to_vnni_t::matrix_to_transform::
                                matrix_C));
        } else if (utils::one_of(
                           jbgp.wei_dt, data_type::bf16, data_type::f16)) {
            // AMX with low-precision weights: dedicated diff-wei reorder.
            CHECK(create_brgemm_amx_ip_trans_wei(diff_wei_trans_kernel_,
                    &pd()->jbgp_, ext_ic_block_, ext_oc_block_));
        }
        if (jbgp.nthr_mb > 1) {
            // Partial results over the MB dimension are reduced in f32.
            CHECK(safe_ptr_assign(
                    acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
            CHECK(acc_ker_->create_kernel());
        }

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_weights(ctx);
        return status::success;
    }

private:
    // Traversal orders over the (os, ic, oc) chunk loops.
    enum loop_order_t { osc_icc_occ, osc_occ_icc, occ_icc_osc };
    struct thread_info_t;
    // Diff-bias kernels indexed as [is_K_tail][is_N_tail].
    std::unique_ptr<jit_brgemm_kernel_diff_bias_t> kernels_db_[2][2];
    // JIT-ed BRGEMM kernels, indexed by get_brg_kernel_idx().
    std::unique_ptr<brgemm_kernel_t>
            brg_kernels_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
    std::unique_ptr<jit_brgemm_trans_src_t> trans_A_kernel_;
    std::unique_ptr<jit_brgemm_trans_to_vnni_t> trans_B_kernel_;
    std::unique_ptr<jit_brgemm_trans_to_vnni_t> trans_C_kernel_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
    std::unique_ptr<jit_amx_ip_trans_diff_wei> diff_wei_trans_kernel_;

    void execute_backward_weights(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    void compute_diff_weights_and_bias(const thread_info_t *ti) const;
    void reduce_and_convert_diff_weights_and_bias(
            const thread_info_t *ti) const;
    void transform_matrix_a_chunk(char *tr_src, const char *src,
            int trans_batch, int current_m, int current_k) const;
    void transform_matrix_b_chunk(char *tr_diff_dst, const char *diff_dst,
            int trans_batch, int current_col_size, int current_row_size) const;
    void transpose_matrix_c_chunk(const thread_info_t *ti, const int ocb,
            const int icb, int oc_size, int ic_size,
            bool is_reduction = false) const;

    // One AMX tile-configuration palette per kernel variant.
    char brg_kernel_palettes_[brgemm_inner_product_utils::
                    max_num_brg_kernels_ip][AMX_PALETTE_SIZE];
    dim_t get_wei_offset(int ocb, int icb) const;
    char *get_wei_acc_ptr(const thread_info_t *ti, int ocb, int icb,
            int reduction_buf_idx = -1) const;

    // External (user-visible) ic/oc block sizes used by the AMX diff-wei
    // reorder; set only on AMX configurations.
    int ext_ic_block_ = 0;
    int ext_oc_block_ = 0;
};
608 | |
609 | } // namespace x64 |
610 | } // namespace cpu |
611 | } // namespace impl |
612 | } // namespace dnnl |
613 | |
614 | #endif |
615 | |
616 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
617 | |