/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_BRGEMM_INNER_PRODUCT_HPP
#define CPU_X64_BRGEMM_INNER_PRODUCT_HPP

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_inner_product_pd.hpp"

#include "cpu/x64/amx_tile_configure.hpp"
#include "cpu/x64/brgemm/brgemm.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/cpu_reducer.hpp"
#include "cpu/x64/jit_brgemm_inner_product_utils.hpp"
#include "cpu/x64/jit_brgemm_post_ops.hpp"
#include "cpu/x64/jit_brgemm_transpose_utils.hpp"
#include "cpu/x64/jit_transpose_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

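// Forward inner product implemented on top of BRGEMM (batch-reduce GEMM)
// kernels. The descriptor (pd_t) pre-creates a family of BRGEMM descriptors
// covering full and tail blocks of the M, N and K dimensions, with and
// without accumulator initialization; init() then JIT-compiles a kernel for
// each valid descriptor.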
template <cpu_isa_t isa>
struct brgemm_inner_product_fwd_t : public primitive_t {
    struct pd_t : public cpu_inner_product_fwd_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_inner_product_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgemm:", isa, ""),
                brgemm_inner_product_fwd_t);

        status_t init(engine_t *engine) {
            using namespace utils;
            using namespace data_type;

            auto src_dt = invariant_src_md()->data_type;
            auto dst_dt = invariant_dst_md()->data_type;
            auto wei_dt = invariant_wei_md()->data_type;
            const bool is_int8 = one_of(src_dt, u8, s8);

            using skip_mask_t = primitive_attr_t::skip_mask_t;
            auto skip_mask = skip_mask_t::post_ops;
            if (is_int8) skip_mask |= skip_mask_t::scales_runtime;

            bool ok = is_fwd() && mayiuse(isa)
                    && expect_data_types(src_dt, wei_dt, data_type::undef,
                            dst_dt, data_type::undef)
                    && IMPLICATION(with_bias() && is_int8,
                            one_of(bias_md_.data_type, f32, bf16, s32, s8, u8))
                    && IMPLICATION(with_bias() && !is_int8,
                            one_of(bias_md_.data_type, f32, src_dt))
                    && attr()->has_default_values(skip_mask, dst_dt)
                    && attr()->post_ops_.check_sum_consistent_dt(dst_dt)
                    && !has_zero_dim_memory() && arg_scales_ok();
            if (!ok) return status::unimplemented;

            CHECK(brgemm_inner_product_utils::init_ip_conf(isa, jbgp_, *desc(),
                    src_md_, weights_md_, dst_md_, bias_md_, attr_,
                    dnnl_get_max_threads()));

            bool are_post_ops_applicable = one_of(true, jbgp_.with_sum,
                    jbgp_.with_bias, jbgp_.with_scales, jbgp_.with_eltwise,
                    jbgp_.with_binary, jbgp_.acc_dt != jbgp_.dst_dt,
                    jbgp_.signed_input);

            const float alpha = 1.0;
            const float beta = 1.0;
            const float beta_init = 0.0;

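            // Pre-create a BRGEMM descriptor for every combination of
            // {batch-size tail, accumulator init, M tail, N tail, K tail}.
            // The beta == 0 variants initialize the accumulator on the first
            // K chunk; the beta == 1 variants accumulate into it on the rest.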
            for_(int i_bs = 0; i_bs < 2; i_bs++)
            for_(int i_init = 0; i_init < 2; i_init++)
            for_(int i_M = 0; i_M < 2; i_M++)
            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_K = 0; i_K < 2; i_K++) {
                auto vbeta = (i_init) ? beta_init : beta;
                auto vM = (i_M) ? jbgp_.M_tail : jbgp_.M;
                auto vN = (i_N) ? jbgp_.N_tail : jbgp_.N;
                auto vK = (i_K) ? jbgp_.K_tail : jbgp_.K;
                int bs = get_brg_batchsize(i_bs, i_K);
                int idx = get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
                if (idx < 0) continue;
                brgemm_t &brg = brg_descs_[idx];
                CHECK(brgemm_desc_init(&brg, isa, jbgp_.brg_type, jbgp_.src_dt,
                        jbgp_.wei_dt, false, false, brgemm_row_major, alpha,
                        vbeta, jbgp_.LDA, jbgp_.LDB, jbgp_.LDC, vM, vN, vK));

                auto LDD = jbgp_.oc_without_padding;
                CHECK(brgemm_desc_set_postops(
                        &brg, attr(), &dst_md_, LDD, jbgp_.bia_dt));

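                // When the IC dimension is split across threads, partial
                // sums are reduced into the destination separately, so the
                // kernels must also provide a code path that skips the
                // in-kernel accumulation step.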
                if (are_post_ops_applicable && jbgp_.nthr_ic_b > 1) {
                    brgemm_attr_t brgattr;
                    brgattr.generate_skip_accumulation = true;
                    CHECK(brgemm_desc_set_attr(&brg, brgattr));
                }
                if (jbgp_.is_amx) {
                    brgemm_attr_t brgattr;
                    brgattr.max_bs = bs;
                    brgattr.wary_tail_read = false;
                    brgattr.hint_expected_A_size = jbgp_.mb * jbgp_.ic;
                    brgattr.hint_expected_B_size = jbgp_.oc * jbgp_.ic;
                    brgattr.hint_expected_C_size = jbgp_.mb * jbgp_.oc;
                    brgattr.hint_innermost_loop = brgemm_ld_loop_innermost;
                    brgattr.use_uker = jbgp_.use_uker;
                    brgattr.use_interleave_stores = jbgp_.use_interleave_stores;
                    brgattr.hint_prefetching = jbgp_.hint_prefetching;
                    brgattr.fpmath_mode = attr()->fpmath_mode_;

                    CHECK(brgemm_desc_set_attr(&brg, brgattr));
                    jbgp_.amx_buf_size_per_thread
                            = nstl::max(brg.get_wsp_buffer_size(),
                                    jbgp_.amx_buf_size_per_thread);
                }
            }

            auto scratchpad = scratchpad_registry().registrar();
            brgemm_inner_product_utils::init_scratchpad(scratchpad, jbgp_);
            if (jbgp_.with_scales)
                book_precomputed_scales(scratchpad, attr()->scales_, OC());

            return status::success;
        }

        bool arg_scales_ok() const {
            std::vector<int> supported_args = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS};
            bool ok = true;
            ok = ok && attr()->scales_.has_default_values(supported_args);
            for (int arg : supported_args) {
                const auto &mask = attr()->scales_.get(arg).mask_;
                if (arg == DNNL_ARG_WEIGHTS)
                    ok = ok && (mask == 0 || mask == (1 << 0));
                else
                    ok = ok && (mask == 0);
            }
            return ok;
        }

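        // Maps a (bs tail, init, M/N/K tail) combination to an index into
        // brg_descs_ / brg_kernels_, or returns -1 for degenerate
        // combinations (a zero-sized dimension, a zero batch size, or a
        // leading dimension smaller than the block it must cover).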
        int get_brg_kernel_idx(bool is_bs_tail, bool do_initialization,
                bool is_M_tail, bool is_N_tail, bool is_K_tail, int bs) const {
            auto vM = (is_M_tail) ? jbgp_.M_tail : jbgp_.M;
            auto vN = (is_N_tail) ? jbgp_.N_tail : jbgp_.N;
            auto vK = (is_K_tail) ? jbgp_.K_tail : jbgp_.K;

            if (vM == 0 || vN == 0 || vK == 0 || bs == 0 || jbgp_.LDA < vK
                    || jbgp_.LDB < vN || jbgp_.LDC < vN)
                return -1;
            return brgemm_inner_product_utils::get_brg_kernel_index(jbgp_,
                    is_bs_tail, do_initialization, is_M_tail, is_N_tail,
                    is_K_tail);
        }

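        // Number of K blocks fused into one BRGEMM call: 1 for the K tail,
        // the leftover count of full K blocks for the batch-size tail, and
        // the full gemm_batch_size otherwise.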
        int get_brg_batchsize(bool is_bs_tail, bool is_K_tail) const {
            auto adj_ic = jbgp_.use_buffer_a
                    ? utils::rnd_up(jbgp_.ic, jbgp_.ic_block)
                    : jbgp_.ic;
            auto bs = (is_K_tail)
                    ? 1
                    : ((is_bs_tail) ? (adj_ic / jbgp_.K) % jbgp_.gemm_batch_size
                                    : jbgp_.gemm_batch_size);
            return bs;
        }

        brgemm_t brg_descs_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
        jit_brgemm_primitive_conf_t jbgp_;
    };

    brgemm_inner_product_fwd_t(const pd_t *apd) : primitive_t(apd) {}

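    // Instantiates a JIT kernel for every valid descriptor pre-computed in
    // pd_t::init(), plus AMX tile palettes and the auxiliary copy and
    // accumulation kernels.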
    status_t init(engine_t *engine) override {
        for_(int i_bs = 0; i_bs < 2; i_bs++)
        for_(int i_M = 0; i_M < 2; i_M++)
        for_(int i_N = 0; i_N < 2; i_N++)
        for_(int i_K = 0; i_K < 2; i_K++)
        for (int i_init = 0; i_init < 2; i_init++) {
            int bs = pd()->get_brg_batchsize(i_bs, i_K);
            int idx = pd()->get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
            if (idx < 0) continue;

            brgemm_kernel_t *ker = nullptr;
            CHECK(brgemm_kernel_create(&ker, pd()->brg_descs_[idx]));
            CHECK(safe_ptr_assign(brg_kernels_[idx], ker));
            if (pd()->jbgp_.is_amx)
                CHECK(brgemm_init_tiles(
                        pd()->brg_descs_[idx], &brg_kernel_palettes_[idx][0]));
        }
        if (pd()->jbgp_.use_buffer_a)
            CHECK(create_brgemm_copy_to_coarse(copy_src_kernel_, &pd()->jbgp_));
        if (pd()->jbgp_.nthr_ic_b > 1) {
            CHECK(safe_ptr_assign(
                    acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
            CHECK(acc_ker_->create_kernel());
        }
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<brgemm_kernel_t>
            brg_kernels_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
    std::unique_ptr<jit_brgemm_copy_to_coarse_t> copy_src_kernel_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
    char brg_kernel_palettes_[brgemm_inner_product_utils::
            max_num_brg_kernels_ip][AMX_PALETTE_SIZE];
};

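// Backward-by-data inner product on BRGEMM kernels: computes diff_src from
// diff_dst and the weights.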
template <cpu_isa_t isa>
struct brgemm_inner_product_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_inner_product_bwd_data_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const inner_product_fwd_pd_t *hint_fwd_pd)
            : cpu_inner_product_bwd_data_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgemm_bwd_d:", isa, ""),
                brgemm_inner_product_bwd_data_t);

        status_t init(engine_t *engine) {
            auto diff_src_dt = invariant_src_md()->data_type;
            auto diff_dst_dt = invariant_dst_md()->data_type;
            auto wei_dt = invariant_wei_md()->data_type;

            bool ok = desc()->prop_kind == prop_kind::backward_data
                    && !has_zero_dim_memory() && mayiuse(isa)
                    && utils::one_of(diff_dst_dt, data_type::f32,
                            data_type::bf16, data_type::f16)
                    && wei_dt == diff_dst_dt
                    && utils::one_of(diff_src_dt, data_type::f32, diff_dst_dt)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops);
            if (!ok) return status::unimplemented;

            memory_desc_t dummy_bias_md;
            CHECK(brgemm_inner_product_utils::init_ip_conf(isa, jbgp_, *desc(),
                    diff_src_md_, weights_md_, diff_dst_md_, dummy_bias_md,
                    attr_, dnnl_get_max_threads()));

            const float alpha = 1.0;
            const float beta = 1.0;
            const float beta_init = 0.0;
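            // On avx512_core_fp16 the f16 weights are staged through an f32
            // buffer (use_buffer_b), so BRGEMM consumes f32 for the B matrix.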
            const auto dt_b = isa == avx512_core_fp16 && jbgp_.use_buffer_b
                    ? data_type::f32
                    : wei_dt;

            for_(int i_bs = 0; i_bs < 2; i_bs++)
            for_(int i_init = 0; i_init < 2; i_init++)
            for_(int i_M = 0; i_M < 2; i_M++)
            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_K = 0; i_K < 2; i_K++) {
                auto vbeta = (i_init) ? beta_init : beta;
                auto vM = (i_M) ? jbgp_.M_tail : jbgp_.M;
                auto vN = (i_N) ? jbgp_.N_tail : jbgp_.N;
                auto vK = (i_K) ? jbgp_.K_tail : jbgp_.K;
                int bs = get_brg_batchsize(i_bs, i_K);
                int idx = get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
                if (idx < 0) continue;

                brgemm_t &brg = brg_descs_[idx];
                CHECK(brgemm_desc_init(&brg, isa, jbgp_.brg_type, diff_dst_dt,
                        dt_b, false, false, brgemm_row_major, alpha, vbeta,
                        jbgp_.LDA, jbgp_.LDB, jbgp_.LDC, vM, vN, vK));

                auto LDD = jbgp_.ic_without_padding;
                CHECK(brgemm_desc_set_postops(
                        &brg, attr(), &diff_src_md_, LDD, jbgp_.bia_dt));
                if (jbgp_.is_amx) {
                    brgemm_attr_t brgattr;
                    brgattr.max_bs = bs;
                    brgattr.wary_tail_read = false;
                    brgattr.hint_expected_A_size = jbgp_.mb * jbgp_.oc;
                    brgattr.hint_expected_B_size = jbgp_.oc * jbgp_.ic;
                    brgattr.hint_expected_C_size = jbgp_.mb * jbgp_.ic;
                    brgattr.hint_innermost_loop = brgemm_ld_loop_innermost;
                    brgattr.use_uker = jbgp_.use_uker;
                    brgattr.use_interleave_stores = jbgp_.use_interleave_stores;
                    brgattr.hint_prefetching = jbgp_.hint_prefetching;
                    brgattr.fpmath_mode = attr()->fpmath_mode_;

                    CHECK(brgemm_desc_set_attr(&brg, brgattr));
                    jbgp_.amx_buf_size_per_thread
                            = nstl::max(brg.get_wsp_buffer_size(),
                                    jbgp_.amx_buf_size_per_thread);
                }
            }

            auto scratchpad = scratchpad_registry().registrar();
            brgemm_inner_product_utils::init_scratchpad(scratchpad, jbgp_);

            return status::success;
        }

        int get_brg_kernel_idx(bool is_bs_tail, bool do_initialization,
                bool is_M_tail, bool is_N_tail, bool is_K_tail, int bs) const {
            auto vM = (is_M_tail) ? jbgp_.M_tail : jbgp_.M;
            auto vN = (is_N_tail) ? jbgp_.N_tail : jbgp_.N;
            auto vK = (is_K_tail) ? jbgp_.K_tail : jbgp_.K;

            if (vM == 0 || vN == 0 || vK == 0 || bs == 0 || jbgp_.LDA < vK
                    || jbgp_.LDB < vN || jbgp_.LDC < vN)
                return -1;
            return brgemm_inner_product_utils::get_brg_kernel_index(jbgp_,
                    is_bs_tail, do_initialization, is_M_tail, is_N_tail,
                    is_K_tail);
        }

        int get_brg_batchsize(bool is_bs_tail, bool is_K_tail) const {
            auto adj_oc = jbgp_.use_buffer_a
                    ? utils::rnd_up(jbgp_.oc, jbgp_.oc_block)
                    : jbgp_.oc;
            auto bs = (is_K_tail) ? 1
                    : ((is_bs_tail) ? (adj_oc / jbgp_.oc_block)
                                            % jbgp_.nb_oc_blocking
                                    : jbgp_.nb_oc_blocking);
            return bs;
        }

        brgemm_t brg_descs_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
        jit_brgemm_primitive_conf_t jbgp_;
    };

    brgemm_inner_product_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        const auto &jbgp = pd()->jbgp_;
        for_(int i_bs = 0; i_bs < 2; i_bs++)
        for_(int i_M = 0; i_M < 2; i_M++)
        for_(int i_N = 0; i_N < 2; i_N++)
        for_(int i_K = 0; i_K < 2; i_K++)
        for (int i_init = 0; i_init < 2; i_init++) {
            int bs = pd()->get_brg_batchsize(i_bs, i_K);
            int idx = pd()->get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
            if (idx < 0) continue;

            brgemm_kernel_t *ker = nullptr;
            CHECK(brgemm_kernel_create(&ker, pd()->brg_descs_[idx]));
            CHECK(safe_ptr_assign(brg_kernels_[idx], ker));
            if (jbgp.is_amx)
                CHECK(brgemm_init_tiles(
                        pd()->brg_descs_[idx], &brg_kernel_palettes_[idx][0]));
        }

        if (jbgp.use_buffer_a)
            CHECK(create_brgemm_copy_to_coarse(
                    copy_diff_dst_kernel_, &pd()->jbgp_));
        if (jbgp.use_buffer_b)
            CHECK(create_brgemm_trans_wei(trans_B_kernel_, &pd()->jbgp_));

        if (jbgp.nthr_oc_b > 1) {
            CHECK(safe_ptr_assign(
                    acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
            CHECK(acc_ker_->create_kernel());
        }

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_data(ctx);
        return status::success;
    }

private:
    void execute_backward_data(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<brgemm_kernel_t>
            brg_kernels_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
    std::unique_ptr<jit_brgemm_copy_to_coarse_t> copy_diff_dst_kernel_;
    std::unique_ptr<jit_brgemm_trans_wei_t> trans_B_kernel_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
    char brg_kernel_palettes_[brgemm_inner_product_utils::
            max_num_brg_kernels_ip][AMX_PALETTE_SIZE];
};

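// Backward-by-weights inner product on BRGEMM kernels: computes diff_weights
// (and, when requested, diff_bias) from src and diff_dst.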
template <cpu_isa_t isa>
struct brgemm_inner_product_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_inner_product_bwd_weights_pd_t {
        pd_t(const inner_product_desc_t *adesc, const primitive_attr_t *attr,
                const inner_product_fwd_pd_t *hint_fwd_pd)
            : cpu_inner_product_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgemm_bwd_w:", isa, ""),
                brgemm_inner_product_bwd_weights_t);

        status_t init(engine_t *engine) {
            auto src_dt = invariant_src_md()->data_type;
            auto diff_wei_type = invariant_wei_md()->data_type;
            auto diff_dst_type = invariant_dst_md()->data_type;

            bool ok = desc()->prop_kind == prop_kind::backward_weights
                    && !has_zero_dim_memory() && mayiuse(isa)
                    && utils::one_of(src_dt, data_type::f32, data_type::bf16,
                            data_type::f16)
                    && diff_dst_type == src_dt
                    && utils::one_of(diff_wei_type, data_type::f32, src_dt)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops);
            if (!ok) return status::unimplemented;

            CHECK(brgemm_inner_product_utils::init_ip_conf(isa, jbgp_, *desc(),
                    src_md_, diff_weights_md_, diff_dst_md_, diff_bias_md_,
                    attr_, dnnl_get_max_threads()));

            const float alpha = 1.0;
            const float beta = 1.0;
            const float beta_init = 0.0;
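            // On avx512_core_fp16, f16 inputs staged through the copy
            // buffers (use_buffer_a / use_buffer_b) are upconverted to f32,
            // which changes the BRGEMM A and B data types accordingly.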
            const auto dt_a = isa == avx512_core_fp16 && jbgp_.use_buffer_a
                    ? data_type::f32
                    : jbgp_.src_dt;
            const auto dt_b = isa == avx512_core_fp16 && jbgp_.use_buffer_b
                    ? data_type::f32
                    : jbgp_.dst_dt;

            for_(int i_bs = 0; i_bs < 2; i_bs++)
            for_(int i_init = 0; i_init < 2; i_init++)
            for_(int i_M = 0; i_M < 2; i_M++)
            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_K = 0; i_K < 2; i_K++) {
                auto vbeta = (i_init) ? beta_init : beta;
                auto vM = (i_M) ? jbgp_.M_tail : jbgp_.M;
                auto vN = (i_N) ? jbgp_.N_tail : jbgp_.N;
                auto vK = (i_K) ? jbgp_.K_tail : jbgp_.K;
                int bs = get_brg_batchsize(i_bs, i_K);
                int idx = get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
                if (idx < 0) continue;
                brgemm_t &brg = brg_descs_[idx];
                CHECK(brgemm_desc_init(&brg, isa, jbgp_.brg_type, dt_a, dt_b,
                        false, false, brgemm_row_major, alpha, vbeta, jbgp_.LDA,
                        jbgp_.LDB, jbgp_.LDC, vM, vN, vK));
                if (jbgp_.is_amx) {
                    brgemm_attr_t brgattr;
                    brgattr.max_bs = bs;
                    brgattr.wary_tail_read = false;
                    brgattr.hint_expected_A_size = jbgp_.mb * jbgp_.ic;
                    brgattr.hint_expected_B_size = jbgp_.mb * jbgp_.oc;
                    brgattr.hint_expected_C_size = jbgp_.ic * jbgp_.oc;
                    brgattr.hint_innermost_loop = brgemm_ld_loop_innermost;
                    brgattr.use_uker = jbgp_.use_uker;
                    brgattr.use_interleave_stores = jbgp_.use_interleave_stores;
                    brgattr.hint_prefetching = jbgp_.hint_prefetching;
                    brgattr.fpmath_mode = attr()->fpmath_mode_;

                    CHECK(brgemm_desc_set_attr(&brg, brgattr));
                    jbgp_.amx_buf_size_per_thread
                            = nstl::max(brg.get_wsp_buffer_size(),
                                    jbgp_.amx_buf_size_per_thread);
                }
            }

            auto scratchpad = scratchpad_registry().registrar();
            brgemm_inner_product_utils::init_scratchpad(scratchpad, jbgp_);

            return status::success;
        }

        int get_brg_kernel_idx(bool is_bs_tail, bool do_initialization,
                bool is_M_tail, bool is_N_tail, bool is_K_tail, int bs) const {
            auto vM = (is_M_tail) ? jbgp_.M_tail : jbgp_.M;
            auto vN = (is_N_tail) ? jbgp_.N_tail : jbgp_.N;
            auto vK = (is_K_tail) ? jbgp_.K_tail : jbgp_.K;

            if (vM == 0 || vN == 0 || vK == 0 || bs == 0 || jbgp_.LDA < vK
                    || jbgp_.LDB < vN || jbgp_.LDC < vN)
                return -1;
            return brgemm_inner_product_utils::get_brg_kernel_index(jbgp_,
                    is_bs_tail, do_initialization, is_M_tail, is_N_tail,
                    is_K_tail);
        }

        int get_brg_batchsize(bool is_bs_tail, bool is_K_tail) const {
            auto bs = (is_K_tail) ? 1
                    : ((is_bs_tail) ? (jbgp_.os / jbgp_.os_block)
                                            % jbgp_.nb_os_blocking
                                    : jbgp_.nb_os_blocking);
            return bs;
        }

        brgemm_t brg_descs_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
        jit_brgemm_primitive_conf_t jbgp_;
    };

    brgemm_inner_product_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}

    status_t init(engine_t *engine) override {
        const auto &jbgp = pd()->jbgp_;
        for_(int i_bs = 0; i_bs < 2; i_bs++)
        for_(int i_M = 0; i_M < 2; i_M++)
        for_(int i_N = 0; i_N < 2; i_N++)
        for_(int i_K = 0; i_K < 2; i_K++)
        for (int i_init = 0; i_init < 2; i_init++) {
            int bs = pd()->get_brg_batchsize(i_bs, i_K);
            int idx = pd()->get_brg_kernel_idx(i_bs, i_init, i_M, i_N, i_K, bs);
            if (idx < 0) continue;

            brgemm_kernel_t *ker = nullptr;
            CHECK(brgemm_kernel_create(&ker, pd()->brg_descs_[idx]));
            CHECK(safe_ptr_assign(brg_kernels_[idx], ker));
            if (jbgp.is_amx)
                CHECK(brgemm_init_tiles(
                        pd()->brg_descs_[idx], &brg_kernel_palettes_[idx][0]));

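            // The diff-bias reduction kernel depends only on the N and K
            // tails, so it is created once per (i_K, i_N) pair, on the first
            // (i_M, i_init) iteration.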
            if (jbgp.with_bias && i_M == 0 && i_init == 0) {
                kernels_db_[i_K][i_N] = nullptr;
                auto db_desc = pd()->brg_descs_[idx];
                db_desc.reduce_dim = (i_K) ? jbgp.K_tail : jbgp.K;
                if (db_desc.reduce_dim > 0 && db_desc.load_dim > 0) {
                    CHECK(safe_ptr_assign(kernels_db_[i_K][i_N],
                            new jit_brgemm_kernel_diff_bias_t(jbgp, db_desc)));
                    CHECK(kernels_db_[i_K][i_N]->create_kernel());
                }
            }
        }
        if (jbgp.is_amx) {
            ext_ic_block_ = jbgp.ic_block_ext;
            ext_oc_block_ = jbgp.oc_block_ext;
        }
        CHECK(create_brgemm_trans_src(trans_A_kernel_, &pd()->jbgp_));

        if (jbgp.use_buffer_b)
            CHECK(create_brgemm_trans_to_vnni(trans_B_kernel_, &pd()->jbgp_,
                    jit_brgemm_trans_to_vnni_t::matrix_to_transform::matrix_B));

        if (!jbgp.is_amx) {
            if (jbgp.wei_dt != jbgp.acc_dt)
                CHECK(create_brgemm_trans_to_vnni(trans_C_kernel_, &pd()->jbgp_,
                        jit_brgemm_trans_to_vnni_t::matrix_to_transform::
                                matrix_C));
        } else if (utils::one_of(
                           jbgp.wei_dt, data_type::bf16, data_type::f16)) {
            CHECK(create_brgemm_amx_ip_trans_wei(diff_wei_trans_kernel_,
                    &pd()->jbgp_, ext_ic_block_, ext_oc_block_));
        }
        if (jbgp.nthr_mb > 1) {
            CHECK(safe_ptr_assign(
                    acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
            CHECK(acc_ker_->create_kernel());
        }

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_weights(ctx);
        return status::success;
    }

private:
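    // Traversal orders of the os/ic/oc chunk loops in the bwd-by-weights
    // driver (e.g. osc_icc_occ: os chunks outermost, then ic chunks, then
    // oc chunks innermost).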
    enum loop_order_t { osc_icc_occ, osc_occ_icc, occ_icc_osc };
    struct thread_info_t;
    std::unique_ptr<jit_brgemm_kernel_diff_bias_t> kernels_db_[2][2];
    std::unique_ptr<brgemm_kernel_t>
            brg_kernels_[brgemm_inner_product_utils::max_num_brg_kernels_ip];
    std::unique_ptr<jit_brgemm_trans_src_t> trans_A_kernel_;
    std::unique_ptr<jit_brgemm_trans_to_vnni_t> trans_B_kernel_;
    std::unique_ptr<jit_brgemm_trans_to_vnni_t> trans_C_kernel_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
    std::unique_ptr<jit_amx_ip_trans_diff_wei> diff_wei_trans_kernel_;

    void execute_backward_weights(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    void compute_diff_weights_and_bias(const thread_info_t *ti) const;
    void reduce_and_convert_diff_weights_and_bias(
            const thread_info_t *ti) const;
    void transform_matrix_a_chunk(char *tr_src, const char *src,
            int trans_batch, int current_m, int current_k) const;
    void transform_matrix_b_chunk(char *tr_diff_dst, const char *diff_dst,
            int trans_batch, int current_col_size, int current_row_size) const;
    void transpose_matrix_c_chunk(const thread_info_t *ti, const int ocb,
            const int icb, int oc_size, int ic_size,
            bool is_reduction = false) const;

    char brg_kernel_palettes_[brgemm_inner_product_utils::
            max_num_brg_kernels_ip][AMX_PALETTE_SIZE];
    dim_t get_wei_offset(int ocb, int icb) const;
    char *get_wei_acc_ptr(const thread_info_t *ti, int ocb, int icb,
            int reduction_buf_idx = -1) const;

    int ext_ic_block_ = 0;
    int ext_oc_block_ = 0;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s