/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/gemm/xe_hp_systolic_gemm.hpp"

#include "common/c_types_map.hpp"
#include "common/dnnl_traits.hpp"
#include "common/float16.hpp"
#include "common/impl_registration.hpp"
#include "common/type_helpers.hpp"
#include "gpu/jit/gemm/gemm_walk_orders.hpp"
#include "gpu/jit/utils/ngen_type_bridge.hpp"
#include "gpu/ocl/gemm/xe_systolic_gemm_copy_kernel.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

status_t xe_hp_systolic_gemm_t::pd_t::init(engine_t *engine) {
    using namespace prop_kind;
    using namespace data_type;
    using namespace primitive_kind;
    using smask_t = primitive_attr_t::skip_mask_t;
    using arch_t = compute::gpu_arch_t;

    assert(engine->kind() == engine_kind::gpu);
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);

    if (!compute_engine->mayiuse_ngen_kernels()) return status::unimplemented;
    if (!compute_engine->mayiuse_large_grf_mode()) return status::unimplemented;

    dev_info_ = compute_engine->device_info();
    auto arch = dev_info_->gpu_arch();

    const auto &d = desc();

    bool dt_float_ok = (d->a_type() == d->b_type()
            && utils::one_of(d->a_type(), bf16, f16)
            && utils::one_of(d->c_type(), f32, d->a_type()));

    bool dt_int_ok = (utils::one_of(d->a_type(), u8, s8)
            && utils::one_of(d->b_type(), u8, s8)
            && utils::one_of(d->c_type(), s32, f32, s8, u8, f16));

    if (dt_int_ok) {
        a_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_SRC);
        b_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS);
        c_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_DST);
    }

    bool ok = set_default_formats(d->a_type());
    if (!ok) return status::unimplemented;

    CHECK(attr_.set_default_formats(dst_md(0)));

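    // Defer to the no-copy GEMM implementations when the size heuristic below
    // predicts that copy-based systolic GEMM will not pay off.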
    if (use_nocopy()) return status::unimplemented;

    // LIMITATIONS:
    // - batch is not supported for unpacked inputs.
    // - runtime dims are not supported
    bool limits_ok
            = !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(), d->k());
    if (!packed_a())
        limits_ok = limits_ok && (d->lda() != DNNL_RUNTIME_DIM_VAL)
                && (d->batch() == 1);
    if (!packed_b())
        limits_ok = limits_ok && (d->ldb() != DNNL_RUNTIME_DIM_VAL)
                && (d->batch() == 1);
    if (!packed_c())
        limits_ok = limits_ok && (d->ldc() != DNNL_RUNTIME_DIM_VAL);

    auto attr_skip_mask = smask_t::scales_runtime | smask_t::post_ops;

    if (dt_int_ok) attr_skip_mask |= smask_t::zero_points_runtime;

    bool arch_ok = utils::one_of(
            arch, arch_t::xe_hp, arch_t::xe_hpg, arch_t::xe_hpc);

    ok = ok && limits_ok && (dt_float_ok || dt_int_ok) && arch_ok
            && compute_engine->mayiuse(compute::device_ext_t::
                            intel_subgroup_split_matrix_multiply_accumulate)
            && attr()->has_default_values(attr_skip_mask)
            && desc()->sum_ab == sum_ab::sum_none
            && IMPLICATION(with_bias(),
                    utils::one_of(d->bias_type(), d->a_type(), f32)
                            && utils::one_of(bias_cmask(), 0, 1, 2, 3));

    auto status = init_post_ops();
    if (status != status::success) return status;

    if (dt_int_ok) {
        ok &= IMPLICATION(a_zp_, !packed_b())
                && IMPLICATION(b_zp_, !packed_a());

        int cmask_a = 0, cmask_b = 0, cmask_c = 0;
        attr()->zero_points_.get(DNNL_ARG_WEIGHTS, &cmask_b);
        attr()->zero_points_.get(DNNL_ARG_SRC, &cmask_a);
        attr()->zero_points_.get(DNNL_ARG_DST, &cmask_c);
        ok &= (cmask_a == 0) && (cmask_b == 0)
                && utils::one_of(cmask_c, 0, 1 << 0, 1 << 1);
    }

    if (!ok) return status::unimplemented;

    return status::success;
}

namespace {
// Use no-copy if m*n < mn_limit * mn_limit and k < k_limit.
// Zero means no limit.
struct nocopy_table_t {
    int mn_limit[2][2];
    int k_limit[2][2];
};

const nocopy_table_t xe_hp_f16_nocopy_table[] = {
        // NN NT TN TT
        {{{2880, 512}, {4096, 1024}}, {{0, 0}, {0, 0}}}};

const nocopy_table_t xe_hp_x8x8s32_nocopy_table[] = {
        // NN NT TN TT
        {{{1344, 576}, {4800, 384}}, {{0, 0}, {0, 0}}}};

const nocopy_table_t xe_hp_f16_nocopy_bad_ld_table[] = {
        // NN NT TN TT
        {{{288, 320}, {288, 288}}, {{288, 320}, {288, 288}}}};

const nocopy_table_t xe_hp_x8x8s32_nocopy_bad_ld_table[] = {
        // NN NT TN TT
        {{{656, 528}, {352, 384}}, {{656, 528}, {352, 384}}}};

const nocopy_table_t xe_hpc_f16_nocopy_table[] = {
        // NN NT TN TT
        {{{8192, 8192}, {8192, 8192}}, {{0, 0}, {0, 0}}}};

const nocopy_table_t xe_hpc_x8x8s32_nocopy_table[] = {
        // NN NT TN TT
        {{{2049, 3000}, {1088, 1024}}, {{0, 0}, {0, 0}}}};

const nocopy_table_t xe_hpc_f16_nocopy_bad_ld_table[] = {
        // NN NT TN TT
        {{{1024, 1024}, {1024, 1024}}, {{0, 0}, {0, 0}}}};

const nocopy_table_t xe_hpc_x8x8s32_nocopy_bad_ld_table[] = {
        // NN NT TN TT
        {{{624, 624}, {480, 624}}, {{0, 0}, {0, 0}}}};
} // namespace

bool xe_hp_systolic_gemm_t::pd_t::use_nocopy() {
    using namespace data_type;

    const auto &d = desc();
    bool xehpc = (dev_info_->gpu_arch() == compute::gpu_arch_t::xe_hpc);

    if (any_prepacked_ || (packed_a_ && packed_b_)) return false;

    // Use no-copy for gemv/ger cases.
    if (d->m() <= 1 || d->n() <= 1 || d->k() <= 1) return true;

    // Use no-copy implementation if one matrix is very small.
    if (d->m() < 32 && d->n() < 32) return true;
    if (d->m() < 32 && d->k() < 32) return true;
    if (d->n() < 32 && d->k() < 32) return true;

    // Use no-copy for small/medium sizes.
    if (utils::one_of(d->a_type(), bf16, f16, s8, u8)) {
        // clang-format off
        const nocopy_table_t *all_tables[2][2][3] = {
            {{xe_hp_f16_nocopy_table, xe_hp_f16_nocopy_table, xe_hp_x8x8s32_nocopy_table},
             {xe_hp_f16_nocopy_bad_ld_table, xe_hp_f16_nocopy_bad_ld_table, xe_hp_x8x8s32_nocopy_bad_ld_table}},
            {{xe_hpc_f16_nocopy_table, xe_hpc_f16_nocopy_table, xe_hpc_x8x8s32_nocopy_table},
             {xe_hpc_f16_nocopy_bad_ld_table, xe_hpc_f16_nocopy_bad_ld_table, xe_hpc_x8x8s32_nocopy_bad_ld_table}}
        };
        // clang-format on
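        // Note: all_tables is indexed as [arch][bad_ld][type]; within each
        // table, mn_limit/k_limit are indexed as [transa][transb]. The f16
        // tables are reused for bf16 (type indices 0 and 1).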
        int type_idx = (d->a_type() == f16) ? 0 : (d->a_type() == bf16) ? 1 : 2;
        int arch_idx = xehpc ? 1 : 0;
        bool bad_ld = false;

        if (!packed_a_) {
            auto lda_bytes = d->lda() * types::data_type_size(d->a_type());
            bad_ld |= ((lda_bytes & 0x3) != 0);
        }
        if (!packed_b_) {
            auto ldb_bytes = d->ldb() * types::data_type_size(d->b_type());
            bad_ld |= ((ldb_bytes & 0x3) != 0);
        }

        auto table = all_tables[arch_idx][int(bad_ld)][type_idx];
        long mnl = table->mn_limit[d->transa()][d->transb()];
        long kl = table->k_limit[d->transa()][d->transb()];

        if ((mnl == 0 || d->m() * d->n() < mnl * mnl)
                && (kl == 0 || d->k() < kl))
            return true;
    }

    return false;
}

bool xe_hp_systolic_gemm_t::pd_t::set_default_formats(data_type_t dt) {
    using namespace format_tag;
    using new_kd_t = gen_gemm_xe_systolic_kernel_desc_t;

    auto sz = types::data_type_size(dt);
    const auto &d = desc();
    auto arch = dev_info_->gpu_arch();

    auto &a_desc = desc_.b_desc;
    auto &b_desc = desc_.a_desc;
    auto &c_desc = desc_.c_desc;

    memory_desc_wrapper a_mdw(&a_desc);
    memory_desc_wrapper b_mdw(&b_desc);
    memory_desc_wrapper c_mdw(&c_desc);

    bool a_any = a_mdw.format_any();
    bool b_any = b_mdw.format_any();
    bool c_any = c_mdw.format_any();
    bool batch = d->is_batched();

    if (batch_dims() > 1) return false;

    format_tag_t a_packed_tag_16 = undef;
    format_tag_t a_packed_tag_32 = undef;
    format_tag_t a_packed_tag_64 = undef;
    format_tag_t b_packed_tag_16 = undef;
    format_tag_t b_packed_tag_32 = undef;
    format_tag_t b_packed_tag_48 = undef;
    format_tag_t unpacked_tag = batch ? abc : ab;

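    // Packed layouts are keyed by the A/B unroll they correspond to
    // (16/32/48/64); the actual tags are picked further below once
    // choose_unrolls() has settled on unroll_m_/unroll_n_.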
    if (arch == compute::gpu_arch_t::xe_hpc) {
        a_packed_tag_64 = batch ? ((sz == 2) ? aCB4c8b16c2b : aCB4c8b16c4b)
                                : ((sz == 2) ? BA4b8a16b2a : BA4b8a16b4a);
        a_packed_tag_16 = batch ? ((sz == 2) ? aCB16c2b : aCB16c4b)
                                : ((sz == 2) ? BA16b2a : BA16b4a);
        b_packed_tag_16 = batch ? ((sz == 2) ? aBC16b16c : aBC16b32c)
                                : ((sz == 2) ? AB16a16b : AB16a32b);
    } else {
        a_packed_tag_32 = batch ? ((sz == 2) ? aCB4c8b8c2b : aCB4c8b8c4b)
                                : ((sz == 2) ? BA4b8a8b2a : BA4b8a8b4a);
        b_packed_tag_48 = batch ? ((sz == 2) ? aBC48b16c : aBC48b32c)
                                : ((sz == 2) ? AB48a16b : AB48a32b);
    }
    b_packed_tag_32 = batch ? ((sz == 2) ? aBC32b16c : aBC32b32c)
                            : ((sz == 2) ? AB32a16b : AB32a32b);

    bool a_prepacked_16 = a_mdw.matches_tag(a_packed_tag_16);
    bool a_prepacked_32 = a_mdw.matches_tag(a_packed_tag_32);
    bool a_prepacked_64 = a_mdw.matches_tag(a_packed_tag_64);
    bool bc_prepacked_16 = b_mdw.matches_tag(b_packed_tag_16)
            || c_mdw.matches_tag(b_packed_tag_16);
    bool bc_prepacked_32 = b_mdw.matches_tag(b_packed_tag_32)
            || c_mdw.matches_tag(b_packed_tag_32);
    bool bc_prepacked_48 = b_mdw.matches_tag(b_packed_tag_48)
            || c_mdw.matches_tag(b_packed_tag_48);

    any_prepacked_ = a_prepacked_16 || a_prepacked_32 || a_prepacked_64
            || bc_prepacked_16 || bc_prepacked_32 || bc_prepacked_48;

    unroll_m_ = 0;
    unroll_n_ = 0;
    alt_ = false;
    if (a_prepacked_16) unroll_m_ = 16;
    if (a_prepacked_32) unroll_m_ = 32;
    if (a_prepacked_64) unroll_m_ = 64;
    if (bc_prepacked_16) unroll_n_ = 16;
    if (bc_prepacked_32) unroll_n_ = 32;
    if (bc_prepacked_48) unroll_n_ = 48;

    new_kd_t::choose_unrolls(arch, dev_info_->eu_count(), d->a_type(),
            d->b_type(), d->c_type(), d->m(), d->n(), d->k(), d->batch(),
            unroll_m_, unroll_n_, alt_);

    format_tag_t a_packed_tag = (unroll_m_ == 64)
            ? a_packed_tag_64
            : (unroll_m_ == 32) ? a_packed_tag_32 : a_packed_tag_16;
    format_tag_t b_packed_tag = (unroll_n_ == 48)
            ? b_packed_tag_48
            : (unroll_n_ == 32) ? b_packed_tag_32 : b_packed_tag_16;
    format_tag_t c_packed_tag = b_packed_tag;

    packed_a_ = packed_b_ = packed_c_ = false;

    if (a_any) {
        if (b_zp_) {
            CHECK(memory_desc_init_by_tag(a_desc, unpacked_tag));
        } else {
            CHECK(memory_desc_init_by_tag(a_desc, a_packed_tag));
            auto ld = a_desc.padded_dims[batch ? 1 : 0];
            ld = nice_ld(ld, int(sz));
            auto &ostride = a_desc.format_desc.blocking.strides[batch ? 2 : 1];
            if (batch) {
                auto &bstride = a_desc.format_desc.blocking.strides[0];
                bstride = (bstride / ostride) * unroll_m_ * ld;
            }
            ostride = unroll_m_ * ld;
            packed_a_ = true;
        }
    } else if (!a_mdw.matches_tag(a_packed_tag)
            && !is_md_gemm_compatible_plain_format(&a_desc))
        return false;

    if (b_any) {
        if (a_zp_) {
            CHECK(memory_desc_init_by_tag(b_desc, unpacked_tag));
        } else {
            CHECK(memory_desc_init_by_tag(b_desc, b_packed_tag));
            if (unroll_n_ > 16) { // Bug in zero-padding when unroll_n_ == 16
                auto ld = b_desc.padded_dims[batch ? 2 : 1];
                ld = nice_ld(ld, int(sz));
                auto &ostride
                        = b_desc.format_desc.blocking.strides[batch ? 1 : 0];
                if (batch) {
                    auto &bstride = b_desc.format_desc.blocking.strides[0];
                    bstride = (bstride / ostride) * unroll_n_ * ld;
                }
                ostride = unroll_n_ * ld;
            }
            packed_b_ = true;
        }
    } else if (!b_mdw.matches_tag(b_packed_tag)
            && !is_md_gemm_compatible_plain_format(&b_desc))
        return false;

    if (c_any) {
        CHECK(memory_desc_init_by_tag(c_desc, c_packed_tag));
        if (unroll_n_ > 16) { // Bug in zero-padding when unroll_n_ == 16
            auto ld = c_desc.padded_dims[batch ? 2 : 1];
            ld = nice_ld(ld, int(sz));
            auto &ostride = c_desc.format_desc.blocking.strides[batch ? 1 : 0];
            if (batch) {
                auto &bstride = c_desc.format_desc.blocking.strides[0];
                bstride = (bstride / ostride) * unroll_n_ * ld;
            }
            ostride = unroll_n_ * ld;
        }
        packed_c_ = true;
    } else if (!c_mdw.matches_tag(c_packed_tag)
            && !is_md_gemm_compatible_plain_format(&c_desc, true))
        return false;

    packed_a_ = packed_a_ || a_mdw.matches_tag(a_packed_tag);
    packed_b_ = packed_b_ || b_mdw.matches_tag(b_packed_tag);
    packed_c_ = packed_c_ || c_mdw.matches_tag(b_packed_tag);

    // No 16x16 copy kernels currently.
    if ((!packed_a_ && unroll_m_ == 16) || (!packed_b_ && unroll_n_ == 16))
        return false;

    return gpu_gemm_pd_t::set_default_formats();
}

status_t xe_hp_systolic_gemm_t::init(engine_t *engine) {
    arch_ = pd()->dev_info_->gpu_arch();
    eu_count_ = pd()->dev_info_->eu_count();

    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();

    int cmask = -1;

    if (pd()->with_c_zero_points())
        pd()->attr()->zero_points_.get(DNNL_ARG_DST, &cmask);
    else if (pd()->with_bias())
        cmask = pd()->bias_cmask();

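    // Translate the bias/C zero-point mask into the offset broadcast kind
    // expected by the kernels: 'F' = single fixed value, 'R' = row vector,
    // 'C' = column vector, 'M' = full matrix, 'N' = none.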
    switch (cmask) {
        case 0: co_kind_ = 'F'; break;
        case (1 << 1): co_kind_ = 'R'; break;
        case (1 << 0): co_kind_ = 'C'; break;
        case 3: co_kind_ = 'M'; break;
        case -1:
        default: co_kind_ = 'N'; break;
    }

    // Initialize compute kernels (assembly)
    {
        auto status = init_compute(engine);
        if (status != status::success) return status;
    }

    // Initialize copy kernels (OpenCL)
    for (bool copy_b : {false, true}) {
        for (bool clear_sum : {false, true}) {
            if (clear_sum && !pd()->with_ab_zero_points()) continue;
            if (!copy_b ? pd()->packed_a() : pd()->packed_b()) continue;

            using copy_kernel_t = ocl::xe_systolic_gemm_copy_kernel_t;
            compute::kernel_ctx_t kernel_ctx;

            auto trans
                    = !copy_b ? pd()->desc()->transa() : pd()->desc()->transb();
            auto status = copy_kernel_t::init_kernel_ctx(kernel_ctx, arch_,
                    !copy_b ? a_type : b_type, pd()->unroll_n(), copy_b, trans,
                    pd()->with_ab_zero_points(), clear_sum);
            if (status != status::success) return status;

            create_kernel(engine, &copy_kernel_[copy_b][clear_sum],
                    copy_kernel_t::name(arch_), kernel_ctx);
            if (!copy_kernel_[copy_b][clear_sum]) return status::runtime_error;
        }
    }

    if (get_verbose() >= 2) {
        printf("onednn_verbose,info,gpu,gemm,kernel:%dx%d,%dx%dx%d\n",
                pd()->unroll_m(), pd()->unroll_n(), compute_info_.wg[LoopM],
                compute_info_.wg[LoopN], compute_info_.wg[LoopK]);
    }

    return status::success;
}

status_t xe_hp_systolic_gemm_t::init_compute(engine_t *engine) {
    using kernel_t = gen_gemm_kernel_t;
    using kd_t = gen_gemm_xe_systolic_kernel_desc_t;

    const auto d = pd()->desc();

    auto a_type = d->a_type();
    auto b_type = d->b_type();
    auto c_type = d->c_type();
    auto co_type = pd()->impl_co_type();
    auto acc_type = pd()->impl_acc_type();

    bool may_k_block
            = (d->k() > kd_t::min_block_k(a_type)) && pd()->allow_k_blocking();
    bool got_info = false;

    auto post_ops = pd()->post_ops();
    bool with_post_ops = (post_ops->find(primitive_kind::eltwise) != -1)
            || (post_ops->find(primitive_kind::binary) != -1);

    kd_t kd_full;

    auto status = kd_full.select_kernel(arch_, eu_count_, pd()->with_batch(),
            pd()->packed_c(), pd()->with_a_zero_points(),
            pd()->with_b_zero_points(), pd()->with_c_zero_points(),
            pd()->with_bias(), pd()->alpha(), pd()->beta(), *post_ops, a_type,
            b_type, c_type, co_type, acc_type, d->m(), d->n(), d->k(),
            d->batch(), pd()->unroll_m(), pd()->unroll_n(), pd()->alt());

    if (status != status::success) return status;

    problem_ = std::move(*kd_full.problem());

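    // Build up to four kernel variants indexed by [first_k_block][last_k_block].
    // When k-blocking is possible, non-first blocks accumulate with beta = 1
    // and non-last blocks skip C zero points and post-ops; variants that would
    // be identical share the same kernel object.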
    for (bool first_k_block : {false, true}) {
        for (bool last_k_block : {false, true}) {
            if ((!first_k_block || !last_k_block) && !may_k_block) continue;
            if (may_k_block && last_k_block && !pd()->with_c_zero_points()
                    && !with_post_ops)
                kernel_[first_k_block][last_k_block]
                        = kernel_[first_k_block][false];
            else if (may_k_block && first_k_block && pd()->beta() == 1.0f)
                kernel_[first_k_block][last_k_block]
                        = kernel_[false][last_k_block];
            else {
                auto this_beta = pd()->beta();
                bool this_c_offset = pd()->with_c_zero_points();
                auto *this_post_ops = pd()->post_ops();
                post_ops_t no_post_ops;

                if (!first_k_block) this_beta = 1.0f;
                if (!last_k_block) {
                    this_c_offset = false;
                    this_post_ops = &no_post_ops;
                }

                kd_t kd;

                auto status = kd.select_kernel(arch_, eu_count_,
                        pd()->with_batch(), pd()->packed_c(),
                        pd()->with_a_zero_points(), pd()->with_b_zero_points(),
                        this_c_offset, pd()->with_bias(), pd()->alpha(),
                        this_beta, *this_post_ops, a_type, b_type, c_type,
                        co_type, acc_type, d->m(), d->n(), d->k(), d->batch(),
                        pd()->unroll_m(), pd()->unroll_n(), pd()->alt());

                if (status != status::success) return status;

                if (!got_info) {
                    compute_info_ = *kd.driver_info();
                    got_info = true;
                }

                kernel_t kernel(kd);

                create_kernel(
                        engine, &kernel_[first_k_block][last_k_block], &kernel);

                if (!kernel_[first_k_block][last_k_block])
                    return status::runtime_error;
            }
        }
    }

    return status::success;
}

status_t xe_hp_systolic_gemm_t::init_res_storage(
        engine_t *engine, gpu_resource_t *r) const {
    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();

    auto m = pd()->desc()->m();
    auto n = pd()->desc()->n();
    auto k = pd()->desc()->k();

    int64_t align_m = compute_info_.wgTile(LoopM);
    int64_t align_n = compute_info_.wgTile(LoopN);

    auto m_aligned = utils::rnd_up(m, align_m);
    auto n_aligned = utils::rnd_up(n, align_n);

    auto max_ldab_packed = max_ld_packed(k);

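    // Allocate scratch space for the packed copies of A/B when the user did
    // not supply pre-packed matrices; the copy kernels fill these buffers at
    // execution time.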
    if (!pd()->packed_a()) {
        memory_storage_t *a_packed_ptr;
        engine->create_memory_storage(&a_packed_ptr,
                m_aligned * max_ldab_packed * types::data_type_size(a_type));
        if (!a_packed_ptr) return status::runtime_error;

        std::unique_ptr<memory_storage_t> a_packed(a_packed_ptr);
        r->add_memory_storage(A_PACKED_, std::move(a_packed));
    }

    if (!pd()->packed_b()) {
        memory_storage_t *b_packed_ptr;
        engine->create_memory_storage(&b_packed_ptr,
                n_aligned * max_ldab_packed * types::data_type_size(b_type));
        if (!b_packed_ptr) return status::runtime_error;

        std::unique_ptr<memory_storage_t> b_packed(b_packed_ptr);
        r->add_memory_storage(B_PACKED_, std::move(b_packed));
    }

    return status::success;
}

bool xe_hp_systolic_gemm_t::enable_mn_blocking() const {
    return (pd()->desc()->m() >= 8192) && (pd()->desc()->n() >= 8192);
}

std::tuple<int64_t, int64_t, int64_t>
xe_hp_systolic_gemm_t::get_blocking() const {
    int64_t m = pd()->desc()->m();
    int64_t n = pd()->desc()->n();
    int64_t k = pd()->desc()->k();

    int64_t unroll_k = compute_info_.unroll[LoopK];

    int64_t align_m = compute_info_.wgTile(LoopM);
    int64_t align_n = compute_info_.wgTile(LoopN);

    m = utils::rnd_up(m, align_m);
    n = utils::rnd_up(n, align_n);

    // Decide on m/n blocking.
    int64_t block_m = compute_info_.blocking[LoopM];
    int64_t block_n = compute_info_.blocking[LoopN];
    int64_t max_block_m = utils::rnd_up(m, align_m);
    int64_t max_block_n = utils::rnd_up(n, align_n);

    if (enable_mn_blocking()) {
        if (n <= block_n)
            block_m = (block_m * block_n) / n;
        else if (m <= block_m)
            block_n = (block_m * block_n) / m;
        else if (n < 2 * block_n) {
            block_n = utils::rnd_up(n / 2, align_n);
            block_m = (2 * block_m * block_n) / n;
        } else if (m < 2 * block_m) {
            block_m = utils::rnd_up(m / 2, align_m);
            block_n = (2 * block_m * block_n) / m;
        }

        block_m = utils::rnd_dn(nstl::min(block_m, max_block_m), align_m);
        block_n = utils::rnd_dn(nstl::min(block_n, max_block_n), align_n);
    } else {
        block_m = m;
        block_n = n;
    }

    // Decide on k blocking.
    int64_t block_k = compute_info_.blocking[LoopK];
    int64_t nblock_k = utils::div_up(k, block_k);
    nblock_k = nstl::max<int64_t>(nblock_k, 1);
    block_k = utils::div_up(k, nblock_k);
    block_k = nstl::max<dim_t>(block_k, 1);
    block_k = utils::rnd_up(pd()->allow_k_blocking() ? block_k : k, unroll_k);
    block_k = nstl::max<dim_t>(block_k, 1);

    return std::make_tuple(block_m, block_n, block_k);
}

status_t xe_hp_systolic_gemm_t::launch_copy(const gemm_exec_ctx_t &ctx,
        int64_t r, int64_t c, const memory_storage_t &src, int64_t offset_src,
        int64_t ld_src, const memory_storage_t &dst, int32_t offset_dst,
        int32_t ld_dst, bool copyb) const {

    using copy_kernel_t = ocl::xe_systolic_gemm_copy_kernel_t;

    if (pd()->with_ab_zero_points()) {
        auto status
                = launch_clear_sum(ctx, r, c, dst, offset_dst, ld_dst, copyb);
        if (status) return status;
    }

    int64_t unroll_k = compute_info_.unroll[LoopK];

    int64_t align_r = 0, align_c = 0;

    if (!copyb) {
        align_r = compute_info_.wgTile(LoopM);
        align_c = unroll_k;
    } else {
        align_r = unroll_k;
        align_c = compute_info_.wgTile(LoopN);
    }

    bool transa = (pd()->desc()->transa() == dnnl_trans);
    bool transb = (pd()->desc()->transb() == dnnl_trans);
    bool trans = !copyb ? transa : transb;

    auto &kernel = copy_kernel_[copyb][false];

    assert(kernel);
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, r);
    arg_list.set(1, c);
    arg_list.set(2, src);
    arg_list.set(3, offset_src);
    arg_list.set(4, ld_src);
    arg_list.set(5, dst);
    arg_list.set(6, offset_dst);
    arg_list.set(7, ld_dst);

    auto elt_size = types::data_type_size(pd()->desc()->a_type());
    size_t r_threads = utils::div_up(utils::rnd_up(r, align_r),
            copy_kernel_t::unroll_r(
                    arch_, elt_size, pd()->unroll_n(), copyb, trans));
    size_t c_threads = utils::div_up(utils::rnd_up(c, align_c),
            copy_kernel_t::unroll_c(
                    arch_, elt_size, pd()->unroll_n(), copyb, trans));
    size_t sg = copy_kernel_t::subgroup_size(arch_, elt_size, copyb, trans);

    size_t r_lsz = trans ? 1 : 16;
    size_t c_lsz = trans ? 16 : 1;

    if (r_threads > r_lsz)
        r_threads = utils::rnd_up(r_threads, r_lsz);
    else
        r_lsz = r_threads;

    if (c_threads > c_lsz)
        c_threads = utils::rnd_up(c_threads, c_lsz);
    else
        c_lsz = c_threads;

    size_t gws[3] = {r_threads * sg, c_threads, 1};
    size_t lws[3] = {r_lsz * sg, c_lsz, 1};

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}

status_t xe_hp_systolic_gemm_t::launch_clear_sum(const gemm_exec_ctx_t &ctx,
        int64_t r, int64_t c, const memory_storage_t &dst, int32_t offset_dst,
        int32_t ld_dst, bool copyb) const {

    auto &kernel = copy_kernel_[copyb][true];

    assert(kernel);
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, r);
    arg_list.set(1, c);
    arg_list.set(2, dst);
    arg_list.set(3, offset_dst);
    arg_list.set(4, ld_dst);

    auto elt_size = types::data_type_size(pd()->desc()->a_type());
    size_t threads = !copyb ? utils::div_up(r, pd()->unroll_m())
                            : utils::div_up(c, pd()->unroll_n());
    size_t sg = ocl::xe_systolic_gemm_copy_kernel_t::subgroup_size_clear_sum(
            arch_, elt_size, copyb);

    size_t gws[3] = {threads * sg, 1, 1};
    size_t lws[3] = {sg, 1, 1};

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}

status_t xe_hp_systolic_gemm_t::launch_compute(const gemm_exec_ctx_t &ctx,
        int32_t m, int32_t n, int32_t k, const memory_storage_t &ap,
        int64_t offset_a, int32_t lda, const memory_storage_t &bp,
        int64_t offset_b, int32_t ldb, const memory_storage_t &c,
        int64_t offset_c, int32_t ldc, float alpha, float beta,
        const memory_storage_t *ao, const memory_storage_t *bo,
        const memory_storage_t &co, int32_t offset_co, int po_count,
        const memory_storage_t **po_srcs, int32_t *offset_po_src,
        bool first_k_block, bool last_k_block, int32_t batch, int32_t stride_a,
        int32_t stride_b, int32_t stride_c) const {

    auto tg_m = compute_info_.wg[LoopM];
    auto tg_n = compute_info_.wg[LoopN];

    auto &kernel = kernel_[first_k_block][last_k_block];

    // kernel void gemm_kernel(global char *Ap, global uchar *Bp, global int *C,
    //                         int k, int ldc,
    //                         long offsetA, long offsetB, long offsetC,
    //                         int m, int n,
    //                         float alpha, float beta,
    //                         int lda, int ldb)

    assert(kernel);

    compute::kernel_arg_list_t arg_list;
    int argn = 0;
    arg_list.set(argn++, ap);
    arg_list.set(argn++, bp);
    arg_list.set(argn++, c);
    arg_list.set(argn++, offset_a);
    arg_list.set(argn++, offset_b);
    arg_list.set(argn++, offset_c);
    arg_list.set(argn++, lda);
    arg_list.set(argn++, ldb);
    arg_list.set(argn++, ldc);
    arg_list.set(argn++, m);
    arg_list.set(argn++, n);
    arg_list.set(argn++, k);
    arg_list.set(argn++, alpha);
    arg_list.set(argn++, beta);

    if (pd()->with_a_zero_points()) arg_list.set(argn++, *ao);
    if (pd()->with_b_zero_points()) arg_list.set(argn++, *bo);
    if ((pd()->with_bias() || pd()->with_c_zero_points())) {
        arg_list.set(argn++, co);
        arg_list.set(argn++, offset_co);
        if (pd()->with_bias()) {
            int32_t ldco = pd()->desc()->ld_bias();
            arg_list.set(argn++, ldco);
        }
    }
    for (int i = 0; i < po_count; i++) {
        if (!po_srcs[i]) continue;
        arg_list.set(argn++, *po_srcs[i]);
        arg_list.set(argn++, offset_po_src[i]);

        if (problem_.binaryRow[i] && problem_.binaryCol[i])
            arg_list.set(argn++, int32_t(pd()->ld_binary(i)));
    }

    uint32_t flags = 0;
    if (co_kind_ == 'R') flags |= FlagCORow;
    if (co_kind_ == 'C') flags |= FlagCOColumn;
    if (!first_k_block) flags |= FlagNoninitialKBlock;
    if (!last_k_block) flags |= FlagNonfinalKBlock;
    arg_list.set(argn++, flags);

    if (pd()->with_batch()) {
        arg_list.set(argn++, stride_a);
        arg_list.set(argn++, stride_b);
        arg_list.set(argn++, stride_c);
        for (int i = 0; i < po_count; i++)
            if (problem_.binaryBatch[i])
                arg_list.set(argn++, int32_t(pd()->stride_binary(i, 0)));
    }

    auto thread_m = utils::div_up(m, pd()->unroll_m() * tg_m) * tg_m;
    auto thread_n = utils::div_up(n, pd()->unroll_n() * tg_n) * tg_n;

    if (walk_n_first_) std::swap(thread_m, thread_n);

    size_t gws[3] = {size_t(thread_m), size_t(thread_n), 1};
    size_t lws[3] = {size_t(tg_m), size_t(tg_n), 1};
    if (pd()->with_batch()) gws[2] = batch;

    lws[1] *= compute_info_.wgExpand;
    gws[1] *= compute_info_.wgExpand;

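    // Append the workgroup walk-order arguments and finalize the dispatch
    // order (see gemm_walk_orders.hpp); the global range may be reshaped here.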
    gemm_linear_order_args(arg_list, argn, lws, gws, m, n, false, compute_info_,
            pd()->dev_info_);

    lws[0] *= compute_info_.subgroupSize;
    gws[0] *= compute_info_.subgroupSize;

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}

status_t xe_hp_systolic_gemm_t::execute(const gemm_exec_ctx_t &ctx) const {
    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();
    auto bias_type = pd()->desc()->bias_type();

    auto m = pd()->desc()->m();
    auto n = pd()->desc()->n();
    auto k = pd()->desc()->k();
    auto batch = pd()->desc()->batch();

    bool packed_a = pd()->packed_a();
    bool packed_b = pd()->packed_b();
    bool packed_c = pd()->packed_c();

    auto lda = packed_a ? 0 : pd()->desc()->lda();
    auto ldb = packed_b ? 0 : pd()->desc()->ldb();
    auto ldc = packed_c ? pd()->ldc_packed() : pd()->desc()->ldc();
    auto ldco = pd()->with_bias() ? pd()->desc()->ld_bias() : 0;

    auto stride_a = pd()->desc()->stride_a();
    auto stride_b = pd()->desc()->stride_b();
    auto stride_c = pd()->desc()->stride_c();

    auto alpha = pd()->alpha();
    auto beta = pd()->beta();

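    // A and B are intentionally exchanged here, mirroring the swapped
    // a_desc/b_desc in set_default_formats(): the kernels see the problem
    // with the roles of A and B reversed.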
    auto &a = GEMM_CTX_ARG_STORAGE(b);
    auto &b = GEMM_CTX_ARG_STORAGE(a);
    auto &c = GEMM_CTX_ARG_STORAGE(c);
    auto &c_zp = GEMM_CTX_ARG_STORAGE(c_zero_point);
    auto &bias = GEMM_CTX_ARG_STORAGE(bias);
    auto *co = &c_zp;
    memory_storage_t *ao = nullptr, *bo = nullptr;

    auto &a_packed = packed_a ? a : CTX_GPU_RES_STORAGE(A_PACKED_);
    auto &b_packed = packed_b ? b : CTX_GPU_RES_STORAGE(B_PACKED_);

    const memory_storage_t *po_srcs[GEMM_MAX_PO];

    int po_count = pd()->post_ops()->len();
    assert(po_count <= GEMM_MAX_PO);

    for (int i = 0; i < po_count; i++) {
        auto &src = pd()->binary_srcs()[i];
        switch (src.type) {
            case pd_t::binary_src_t::binary:
                po_srcs[i]
                        = ctx.args()
                                  .exec_args
                                  .at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(src.index)
                                          | DNNL_ARG_SRC_1)
                                  .mem->memory_storage();
                break;
            case pd_t::binary_src_t::bias: po_srcs[i] = &bias; break;
            case pd_t::binary_src_t::scales:
                switch (src.index) {
                    case DNNL_ARG_WEIGHTS:
                        po_srcs[i] = &GEMM_CTX_ARG_STORAGE(a_scales);
                        break;
                    case DNNL_ARG_SRC:
                        po_srcs[i] = &GEMM_CTX_ARG_STORAGE(b_scales);
                        break;
                    case DNNL_ARG_DST:
                        po_srcs[i] = &GEMM_CTX_ARG_STORAGE(c_scales);
                        break;
                    default:
                        po_srcs[i] = nullptr;
                        assert(!"invalid scale type");
                        break;
                }
                break;
            default: po_srcs[i] = nullptr; break;
        }
    }

    size_t off_a0
            = a.offset() / types::data_type_size(a_type) + pd()->dyn_offset_a;
    size_t off_b0
            = b.offset() / types::data_type_size(b_type) + pd()->dyn_offset_b;
    size_t off_c0
            = c.offset() / types::data_type_size(c_type) + pd()->dyn_offset_c;
    size_t off_co0 = 0;

    int32_t po_offsets0[GEMM_MAX_PO] = {0}, po_offsets[GEMM_MAX_PO] = {0};
    for (int i = 0; i < po_count; i++)
        if (po_srcs[i])
            po_offsets0[i] = po_srcs[i]->offset() / problem_.Tbinary[i];

    if (pd()->with_ab_zero_points()) {
        ao = &GEMM_CTX_ARG_STORAGE(a_zero_point);
        bo = &GEMM_CTX_ARG_STORAGE(b_zero_point);
    }

    if (pd()->with_bias()) {
        off_co0 = bias.offset() / types::data_type_size(bias_type);
        co = &bias;
    }

    int64_t block_m = 0, block_n = 0, block_k = 0;
    std::tie(block_m, block_n, block_k) = get_blocking();

    auto ld_packed = get_ld_packed(k);
    auto lda_packed = packed_a ? pd()->lda_packed() : ld_packed;
    auto ldb_packed = packed_b ? pd()->ldb_packed() : ld_packed;

    status_t status;

    if (!packed_a) {
        assert(batch == 1);
        status = launch_copy(
                ctx, m, k, a, off_a0, lda, a_packed, 0, lda_packed, false);
        if (status) return status;
    }

    if (!packed_b) {
        assert(batch == 1);
        status = launch_copy(
                ctx, k, n, b, off_b0, ldb, b_packed, 0, ldb_packed, true);
        if (status) return status;
    }

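    // Main blocked loop: k blocks outermost, then m and n blocks. Only the
    // first k block applies the user's beta; later k blocks accumulate into C
    // with beta = 1, and C zero points/post-ops are deferred to the last k
    // block via the kernel variants built in init_compute().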
    for (int64_t Bk = 0; Bk < nstl::max<dim_t>(k, 1); Bk += block_k) {
        int64_t size_k = k - Bk;
        bool first_k_block = (Bk == 0);
        bool last_k_block = (size_k <= block_k);
        if (!last_k_block) size_k = block_k;

        for (int64_t Bm = 0; Bm < m; Bm += block_m) {
            int64_t size_m = m - Bm;
            if (size_m > block_m) size_m = block_m;

            auto off_a_packed = Bm * lda_packed + Bk * pd()->unroll_m();
            if (packed_a) off_a_packed += off_a0;

            for (int64_t Bn = 0; Bn < n; Bn += block_n) {
                int64_t size_n = n - Bn;
                if (size_n > block_n) size_n = block_n;

                auto off_b_packed = Bn * ldb_packed + Bk * pd()->unroll_n();
                if (packed_b) off_b_packed += off_b0;

                auto off_c = off_c0 + Bm + Bn * ldc;
                auto off_co = int32_t(off_co0);
                switch (co_kind_) {
                    case 'R': off_co += Bm; break;
                    case 'C': off_co += Bn; break;
                    case 'M':
                        off_co += isColMajor(problem_.CO.layout)
                                ? (Bn * ldco + Bm)
                                : (Bm * ldco + Bn);
                        break;
                    default: break;
                }

                for (int i = 0; i < po_count; i++) {
                    po_offsets[i] = po_offsets0[i];
                    bool row = problem_.binaryRow[i],
                         col = problem_.binaryCol[i];
                    if (row && col) {
                        auto ld = pd()->ld_binary(i);
                        po_offsets[i] += isColMajor(problem_.binary[i].layout)
                                ? (Bn * ld + Bm)
                                : (Bm * ld + Bn);
                    } else if (row)
                        po_offsets[i] += Bm;
                    else if (col)
                        po_offsets[i] += Bn;
                }

                float this_beta = first_k_block ? beta : 1.0f;
                status = launch_compute(ctx, size_m, size_n, size_k, a_packed,
                        off_a_packed, lda_packed, b_packed, off_b_packed,
                        ldb_packed, c, off_c, ldc, alpha, this_beta, ao, bo,
                        *co, off_co, po_count, po_srcs, po_offsets,
                        first_k_block, last_k_block, batch, stride_a, stride_b,
                        stride_c);
                if (status) return status;
            }
        }
    }

    return status::success;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s