1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/gemm/xe_hp_systolic_gemm.hpp" |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_traits.hpp" |
21 | #include "common/float16.hpp" |
22 | #include "common/impl_registration.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "gpu/jit/gemm/gemm_walk_orders.hpp" |
25 | #include "gpu/jit/utils/ngen_type_bridge.hpp" |
26 | #include "gpu/ocl/gemm/xe_systolic_gemm_copy_kernel.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace gpu { |
31 | namespace jit { |
32 | |
// Primitive-descriptor initialization: checks that the engine, data types,
// attributes, and problem shape are supported by this systolic GEMM
// implementation; returns status::unimplemented otherwise.
status_t xe_hp_systolic_gemm_t::pd_t::init(engine_t *engine) {
    using namespace prop_kind;
    using namespace data_type;
    using namespace primitive_kind;
    using smask_t = primitive_attr_t::skip_mask_t;
    using arch_t = compute::gpu_arch_t;

    assert(engine->kind() == engine_kind::gpu);
    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);

    // This implementation emits nGEN (assembly) kernels and requires the
    // large-GRF execution mode; bail out early if either is unavailable.
    if (!compute_engine->mayiuse_ngen_kernels()) return status::unimplemented;
    if (!compute_engine->mayiuse_large_grf_mode()) return status::unimplemented;

    dev_info_ = compute_engine->device_info();
    auto arch = dev_info_->gpu_arch();

    const auto &d = desc();

    // Floating-point configs: A and B must share a type (bf16 or f16);
    // C is either f32 or the same type as A.
    bool dt_float_ok = (d->a_type() == d->b_type()
            && utils::one_of(d->a_type(), bf16, f16)
            && utils::one_of(d->c_type(), f32, d->a_type()));

    // Integer configs: A and B are int8 (signedness may differ);
    // C may be s32/f32/s8/u8/f16.
    bool dt_int_ok = (utils::one_of(d->a_type(), u8, s8)
            && utils::one_of(d->b_type(), u8, s8)
            && utils::one_of(d->c_type(), s32, f32, s8, u8, f16));

    // Zero points are only considered for the integer configuration.
    if (dt_int_ok) {
        a_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_SRC);
        b_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_WEIGHTS);
        c_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_DST);
    }

    // Select packed/plain memory formats; this also determines
    // unroll_m_/unroll_n_ and the packed_* flags used below.
    bool ok = set_default_formats(d->a_type());
    if (!ok) return status::unimplemented;

    CHECK(attr_.set_default_formats(dst_md(0)));

    // Defer to the no-copy GEMM implementation when the size heuristics
    // predict it will be faster for this problem.
    if (use_nocopy()) return status::unimplemented;

    // LIMITATIONS:
    // - batch is not supported for unpacked inputs.
    // - runtime dims are not supported
    bool limits_ok
            = !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(), d->k());
    if (!packed_a())
        limits_ok = limits_ok && (d->lda() != DNNL_RUNTIME_DIM_VAL)
                && (d->batch() == 1);
    if (!packed_b())
        limits_ok = limits_ok && (d->ldb() != DNNL_RUNTIME_DIM_VAL)
                && (d->batch() == 1);
    if (!packed_c())
        limits_ok = limits_ok && (d->ldc() != DNNL_RUNTIME_DIM_VAL);

    auto attr_skip_mask = smask_t::scales_runtime | smask_t::post_ops;

    if (dt_int_ok) attr_skip_mask |= smask_t::zero_points_runtime;

    // Only Xe-HP-family architectures support this implementation.
    bool arch_ok = utils::one_of(
            arch, arch_t::xe_hp, arch_t::xe_hpg, arch_t::xe_hpc);

    ok = ok && limits_ok && (dt_float_ok || dt_int_ok) && arch_ok
            && compute_engine->mayiuse(compute::device_ext_t::
                            intel_subgroup_split_matrix_multiply_accumulate)
            && attr()->has_default_values(attr_skip_mask)
            && desc()->sum_ab == sum_ab::sum_none
            && IMPLICATION(with_bias(),
                    utils::one_of(d->bias_type(), d->a_type(), f32)
                            && utils::one_of(bias_cmask(), 0, 1, 2, 3));

    auto status = init_post_ops();
    if (status != status::success) return status;

    if (dt_int_ok) {
        // A zero points are incompatible with a pre-packed B, and vice
        // versa (the matching copy kernel is needed in that case).
        ok &= IMPLICATION(a_zp_, !packed_b())
                && IMPLICATION(b_zp_, !packed_a());

        // A/B zero points must be common (per-tensor, mask 0); the C zero
        // point may additionally be per-row or per-column.
        int cmask_a = 0, cmask_b = 0, cmask_c = 0;
        attr()->zero_points_.get(DNNL_ARG_WEIGHTS, &cmask_b);
        attr()->zero_points_.get(DNNL_ARG_SRC, &cmask_a);
        attr()->zero_points_.get(DNNL_ARG_DST, &cmask_c);
        ok &= (cmask_a == 0) && (cmask_b == 0)
                && utils::one_of(cmask_c, 0, 1 << 0, 1 << 1);
    }

    if (!ok) return status::unimplemented;

    return status::success;
}
121 | |
namespace {
// Use no-copy if m*n < mn_limit * mn_limit and k < k_limit.
// Zero means no limit.
// Both tables are indexed as [transa][transb], i.e. NN / NT / TN / TT,
// matching the comment above each initializer.
struct nocopy_table_t {
    int mn_limit[2][2];
    int k_limit[2][2];
};

// Xe-HP/HPG thresholds, f16/bf16 data.
const nocopy_table_t xe_hp_f16_nocopy_table[] = {
        // NN NT TN TT
        {{{2880, 512}, {4096, 1024}}, {{0, 0}, {0, 0}}}};

// Xe-HP/HPG thresholds, int8 data.
const nocopy_table_t xe_hp_x8x8s32_nocopy_table[] = {
        // NN NT TN TT
        {{{1344, 576}, {4800, 384}}, {{0, 0}, {0, 0}}}};

// Xe-HP/HPG thresholds, f16/bf16 data, misaligned leading dimension.
const nocopy_table_t xe_hp_f16_nocopy_bad_ld_table[] = {
        // NN NT TN TT
        {{{288, 320}, {288, 288}}, {{288, 320}, {288, 288}}}};

// Xe-HP/HPG thresholds, int8 data, misaligned leading dimension.
const nocopy_table_t xe_hp_x8x8s32_nocopy_bad_ld_table[] = {
        // NN NT TN TT
        {{{656, 528}, {352, 384}}, {{656, 528}, {352, 384}}}};

// Xe-HPC thresholds, f16/bf16 data.
const nocopy_table_t xe_hpc_f16_nocopy_table[] = {
        // NN NT TN TT
        {{{8192, 8192}, {8192, 8192}}, {{0, 0}, {0, 0}}}};

// Xe-HPC thresholds, int8 data.
const nocopy_table_t xe_hpc_x8x8s32_nocopy_table[] = {
        // NN NT TN TT
        {{{2049, 3000}, {1088, 1024}}, {{0, 0}, {0, 0}}}};

// Xe-HPC thresholds, f16/bf16 data, misaligned leading dimension.
const nocopy_table_t xe_hpc_f16_nocopy_bad_ld_table[] = {
        // NN NT TN TT
        {{{1024, 1024}, {1024, 1024}}, {{0, 0}, {0, 0}}}};

// Xe-HPC thresholds, int8 data, misaligned leading dimension.
const nocopy_table_t xe_hpc_x8x8s32_nocopy_bad_ld_table[] = {
        // NN NT TN TT
        {{{624, 624}, {480, 624}}, {{0, 0}, {0, 0}}}};
} // namespace
162 | |
// Heuristic: decide whether the no-copy GEMM implementation should be
// preferred over this copy-based one for the given problem shape.
// Returns true when this implementation should stand down.
bool xe_hp_systolic_gemm_t::pd_t::use_nocopy() {
    using namespace data_type;

    const auto &d = desc();
    bool xehpc = (dev_info_->gpu_arch() == compute::gpu_arch_t::xe_hpc);

    // Pre-packed inputs can only be handled here; never defer in that case.
    if (any_prepacked_ || (packed_a_ && packed_b_)) return false;

    // Use no-copy for gemv/ger cases.
    if (d->m() <= 1 || d->n() <= 1 || d->k() <= 1) return true;

    // Use no-copy implementation if one matrix is very small.
    if (d->m() < 32 && d->n() < 32) return true;
    if (d->m() < 32 && d->k() < 32) return true;
    if (d->n() < 32 && d->k() < 32) return true;

    // Use no-copy for small/medium sizes.
    if (utils::one_of(d->a_type(), bf16, f16, s8, u8)) {
        // Table lookup dimensions: [arch (hp vs hpc)][bad leading dim]
        // [type: f16 / bf16 / int8]. f16 and bf16 share thresholds.
        // clang-format off
        const nocopy_table_t *all_tables[2][2][3] = {
            {{xe_hp_f16_nocopy_table, xe_hp_f16_nocopy_table, xe_hp_x8x8s32_nocopy_table},
             {xe_hp_f16_nocopy_bad_ld_table, xe_hp_f16_nocopy_bad_ld_table, xe_hp_x8x8s32_nocopy_bad_ld_table}},
            {{xe_hpc_f16_nocopy_table, xe_hpc_f16_nocopy_table, xe_hpc_x8x8s32_nocopy_table},
             {xe_hpc_f16_nocopy_bad_ld_table, xe_hpc_f16_nocopy_bad_ld_table, xe_hpc_x8x8s32_nocopy_bad_ld_table}}
        };
        // clang-format on
        int type_idx = (d->a_type() == f16) ? 0 : (d->a_type() == bf16) ? 1 : 2;
        int arch_idx = xehpc ? 1 : 0;
        bool bad_ld = false;

        // A leading dimension is "bad" when it is not 4-byte aligned;
        // different thresholds apply in that case.
        if (!packed_a_) {
            auto lda_bytes = d->lda() * types::data_type_size(d->a_type());
            bad_ld |= ((lda_bytes & 0x3) != 0);
        }
        if (!packed_b_) {
            auto ldb_bytes = d->ldb() * types::data_type_size(d->b_type());
            bad_ld |= ((ldb_bytes & 0x3) != 0);
        }

        auto table = all_tables[arch_idx][int(bad_ld)][type_idx];
        long mnl = table->mn_limit[d->transa()][d->transb()];
        long kl = table->k_limit[d->transa()][d->transb()];

        // Zero limit means "no limit" (see nocopy_table_t).
        if ((mnl == 0 || d->m() * d->n() < mnl * mnl)
                && (kl == 0 || d->k() < kl))
            return true;
    }

    return false;
}
213 | |
214 | bool xe_hp_systolic_gemm_t::pd_t::set_default_formats(data_type_t dt) { |
215 | using namespace format_tag; |
216 | using new_kd_t = gen_gemm_xe_systolic_kernel_desc_t; |
217 | |
218 | auto sz = types::data_type_size(dt); |
219 | const auto &d = desc(); |
220 | auto arch = dev_info_->gpu_arch(); |
221 | |
222 | auto &a_desc = desc_.b_desc; |
223 | auto &b_desc = desc_.a_desc; |
224 | auto &c_desc = desc_.c_desc; |
225 | |
226 | memory_desc_wrapper a_mdw(&a_desc); |
227 | memory_desc_wrapper b_mdw(&b_desc); |
228 | memory_desc_wrapper c_mdw(&c_desc); |
229 | |
230 | bool a_any = a_mdw.format_any(); |
231 | bool b_any = b_mdw.format_any(); |
232 | bool c_any = c_mdw.format_any(); |
233 | bool batch = d->is_batched(); |
234 | |
235 | if (batch_dims() > 1) return false; |
236 | |
237 | format_tag_t a_packed_tag_16 = undef; |
238 | format_tag_t a_packed_tag_32 = undef; |
239 | format_tag_t a_packed_tag_64 = undef; |
240 | format_tag_t b_packed_tag_16 = undef; |
241 | format_tag_t b_packed_tag_32 = undef; |
242 | format_tag_t b_packed_tag_48 = undef; |
243 | format_tag_t unpacked_tag = batch ? abc : ab; |
244 | |
245 | if (arch == compute::gpu_arch_t::xe_hpc) { |
246 | a_packed_tag_64 = batch ? ((sz == 2) ? aCB4c8b16c2b : aCB4c8b16c4b) |
247 | : ((sz == 2) ? BA4b8a16b2a : BA4b8a16b4a); |
248 | a_packed_tag_16 = batch ? ((sz == 2) ? aCB16c2b : aCB16c4b) |
249 | : ((sz == 2) ? BA16b2a : BA16b4a); |
250 | b_packed_tag_16 = batch ? ((sz == 2) ? aBC16b16c : aBC16b32c) |
251 | : ((sz == 2) ? AB16a16b : AB16a32b); |
252 | } else { |
253 | a_packed_tag_32 = batch ? ((sz == 2) ? aCB4c8b8c2b : aCB4c8b8c4b) |
254 | : ((sz == 2) ? BA4b8a8b2a : BA4b8a8b4a); |
255 | b_packed_tag_48 = batch ? ((sz == 2) ? aBC48b16c : aBC48b32c) |
256 | : ((sz == 2) ? AB48a16b : AB48a32b); |
257 | } |
258 | b_packed_tag_32 = batch ? ((sz == 2) ? aBC32b16c : aBC32b32c) |
259 | : ((sz == 2) ? AB32a16b : AB32a32b); |
260 | |
261 | bool a_prepacked_16 = a_mdw.matches_tag(a_packed_tag_16); |
262 | bool a_prepacked_32 = a_mdw.matches_tag(a_packed_tag_32); |
263 | bool a_prepacked_64 = a_mdw.matches_tag(a_packed_tag_64); |
264 | bool bc_prepacked_16 = b_mdw.matches_tag(b_packed_tag_16) |
265 | || c_mdw.matches_tag(b_packed_tag_16); |
266 | bool bc_prepacked_32 = b_mdw.matches_tag(b_packed_tag_32) |
267 | || c_mdw.matches_tag(b_packed_tag_32); |
268 | bool bc_prepacked_48 = b_mdw.matches_tag(b_packed_tag_48) |
269 | || c_mdw.matches_tag(b_packed_tag_48); |
270 | |
271 | any_prepacked_ = a_prepacked_16 || a_prepacked_32 || a_prepacked_64 |
272 | || bc_prepacked_16 || bc_prepacked_32 || bc_prepacked_48; |
273 | |
274 | unroll_m_ = 0; |
275 | unroll_n_ = 0; |
276 | alt_ = false; |
277 | if (a_prepacked_16) unroll_m_ = 16; |
278 | if (a_prepacked_32) unroll_m_ = 32; |
279 | if (a_prepacked_64) unroll_m_ = 64; |
280 | if (bc_prepacked_16) unroll_n_ = 16; |
281 | if (bc_prepacked_32) unroll_n_ = 32; |
282 | if (bc_prepacked_48) unroll_n_ = 48; |
283 | |
284 | new_kd_t::choose_unrolls(arch, dev_info_->eu_count(), d->a_type(), |
285 | d->b_type(), d->c_type(), d->m(), d->n(), d->k(), d->batch(), |
286 | unroll_m_, unroll_n_, alt_); |
287 | |
288 | format_tag_t a_packed_tag = (unroll_m_ == 64) |
289 | ? a_packed_tag_64 |
290 | : (unroll_m_ == 32) ? a_packed_tag_32 : a_packed_tag_16; |
291 | format_tag_t b_packed_tag = (unroll_n_ == 48) |
292 | ? b_packed_tag_48 |
293 | : (unroll_n_ == 32) ? b_packed_tag_32 : b_packed_tag_16; |
294 | format_tag_t c_packed_tag = b_packed_tag; |
295 | |
296 | packed_a_ = packed_b_ = packed_c_ = false; |
297 | |
298 | if (a_any) { |
299 | if (b_zp_) { |
300 | CHECK(memory_desc_init_by_tag(a_desc, unpacked_tag)); |
301 | } else { |
302 | CHECK(memory_desc_init_by_tag(a_desc, a_packed_tag)); |
303 | auto ld = a_desc.padded_dims[batch ? 1 : 0]; |
304 | ld = nice_ld(ld, int(sz)); |
305 | auto &ostride = a_desc.format_desc.blocking.strides[batch ? 2 : 1]; |
306 | if (batch) { |
307 | auto &bstride = a_desc.format_desc.blocking.strides[0]; |
308 | bstride = (bstride / ostride) * unroll_m_ * ld; |
309 | } |
310 | ostride = unroll_m_ * ld; |
311 | packed_a_ = true; |
312 | } |
313 | } else if (!a_mdw.matches_tag(a_packed_tag) |
314 | && !is_md_gemm_compatible_plain_format(&a_desc)) |
315 | return false; |
316 | |
317 | if (b_any) { |
318 | if (a_zp_) { |
319 | CHECK(memory_desc_init_by_tag(b_desc, unpacked_tag)); |
320 | } else { |
321 | CHECK(memory_desc_init_by_tag(b_desc, b_packed_tag)); |
322 | if (unroll_n_ > 16) { // Bug in zero-padding when unroll_n_ == 16 |
323 | auto ld = b_desc.padded_dims[batch ? 2 : 1]; |
324 | ld = nice_ld(ld, int(sz)); |
325 | auto &ostride |
326 | = b_desc.format_desc.blocking.strides[batch ? 1 : 0]; |
327 | if (batch) { |
328 | auto &bstride = b_desc.format_desc.blocking.strides[0]; |
329 | bstride = (bstride / ostride) * unroll_n_ * ld; |
330 | } |
331 | ostride = unroll_n_ * ld; |
332 | } |
333 | packed_b_ = true; |
334 | } |
335 | } else if (!b_mdw.matches_tag(b_packed_tag) |
336 | && !is_md_gemm_compatible_plain_format(&b_desc)) |
337 | return false; |
338 | |
339 | if (c_any) { |
340 | CHECK(memory_desc_init_by_tag(c_desc, c_packed_tag)); |
341 | if (unroll_n_ > 16) { // Bug in zero-padding when unroll_n_ == 16 |
342 | auto ld = c_desc.padded_dims[batch ? 2 : 1]; |
343 | ld = nice_ld(ld, int(sz)); |
344 | auto &ostride = c_desc.format_desc.blocking.strides[batch ? 1 : 0]; |
345 | if (batch) { |
346 | auto &bstride = c_desc.format_desc.blocking.strides[0]; |
347 | bstride = (bstride / ostride) * unroll_n_ * ld; |
348 | } |
349 | ostride = unroll_n_ * ld; |
350 | } |
351 | packed_c_ = true; |
352 | } else if (!c_mdw.matches_tag(c_packed_tag) |
353 | && !is_md_gemm_compatible_plain_format(&c_desc, true)) |
354 | return false; |
355 | |
356 | packed_a_ = packed_a_ || a_mdw.matches_tag(a_packed_tag); |
357 | packed_b_ = packed_b_ || b_mdw.matches_tag(b_packed_tag); |
358 | packed_c_ = packed_c_ || c_mdw.matches_tag(b_packed_tag); |
359 | |
360 | // No 16x16 copy kernels currently. |
361 | if ((!packed_a_ && unroll_m_ == 16) || (!packed_b_ && unroll_n_ == 16)) |
362 | return false; |
363 | |
364 | return gpu_gemm_pd_t::set_default_formats(); |
365 | } |
366 | |
367 | status_t xe_hp_systolic_gemm_t::init(engine_t *engine) { |
368 | arch_ = pd()->dev_info_->gpu_arch(); |
369 | eu_count_ = pd()->dev_info_->eu_count(); |
370 | |
371 | auto a_type = pd()->desc()->a_type(); |
372 | auto b_type = pd()->desc()->b_type(); |
373 | |
374 | int cmask = -1; |
375 | |
376 | if (pd()->with_c_zero_points()) |
377 | pd()->attr()->zero_points_.get(DNNL_ARG_DST, &cmask); |
378 | else if (pd()->with_bias()) |
379 | cmask = pd()->bias_cmask(); |
380 | |
381 | switch (cmask) { |
382 | case 0: co_kind_ = 'F'; break; |
383 | case (1 << 1): co_kind_ = 'R'; break; |
384 | case (1 << 0): co_kind_ = 'C'; break; |
385 | case 3: co_kind_ = 'M'; break; |
386 | case -1: |
387 | default: co_kind_ = 'N'; break; |
388 | } |
389 | |
390 | // Initialize compute kernels (assembly) |
391 | { |
392 | auto status = init_compute(engine); |
393 | if (status != status::success) return status; |
394 | } |
395 | |
396 | // Initialize copy kernels (OpenCL) |
397 | for (bool copy_b : {false, true}) { |
398 | for (bool clear_sum : {false, true}) { |
399 | if (clear_sum && !pd()->with_ab_zero_points()) continue; |
400 | if (!copy_b ? pd()->packed_a() : pd()->packed_b()) continue; |
401 | |
402 | using copy_kernel_t = ocl::xe_systolic_gemm_copy_kernel_t; |
403 | compute::kernel_ctx_t kernel_ctx; |
404 | |
405 | auto trans |
406 | = !copy_b ? pd()->desc()->transa() : pd()->desc()->transb(); |
407 | auto status = copy_kernel_t::init_kernel_ctx(kernel_ctx, arch_, |
408 | !copy_b ? a_type : b_type, pd()->unroll_n(), copy_b, trans, |
409 | pd()->with_ab_zero_points(), clear_sum); |
410 | if (status != status::success) return status; |
411 | |
412 | create_kernel(engine, ©_kernel_[copy_b][clear_sum], |
413 | copy_kernel_t::name(arch_), kernel_ctx); |
414 | if (!copy_kernel_[copy_b][clear_sum]) return status::runtime_error; |
415 | } |
416 | } |
417 | |
418 | if (get_verbose() >= 2) { |
419 | printf("onednn_verbose,info,gpu,gemm,kernel:%dx%d,%dx%dx%d\n" , |
420 | pd()->unroll_m(), pd()->unroll_n(), compute_info_.wg[LoopM], |
421 | compute_info_.wg[LoopN], compute_info_.wg[LoopK]); |
422 | } |
423 | |
424 | return status::success; |
425 | } |
426 | |
427 | status_t xe_hp_systolic_gemm_t::init_compute(engine_t *engine) { |
428 | using kernel_t = gen_gemm_kernel_t; |
429 | using kd_t = gen_gemm_xe_systolic_kernel_desc_t; |
430 | |
431 | const auto d = pd()->desc(); |
432 | |
433 | auto a_type = d->a_type(); |
434 | auto b_type = d->b_type(); |
435 | auto c_type = d->c_type(); |
436 | auto co_type = pd()->impl_co_type(); |
437 | auto acc_type = pd()->impl_acc_type(); |
438 | |
439 | bool may_k_block |
440 | = (d->k() > kd_t::min_block_k(a_type)) && pd()->allow_k_blocking(); |
441 | bool got_info = false; |
442 | |
443 | auto post_ops = pd()->post_ops(); |
444 | bool with_post_ops = (post_ops->find(primitive_kind::eltwise) != -1) |
445 | || (post_ops->find(primitive_kind::binary) != -1); |
446 | |
447 | kd_t kd_full; |
448 | |
449 | auto status = kd_full.select_kernel(arch_, eu_count_, pd()->with_batch(), |
450 | pd()->packed_c(), pd()->with_a_zero_points(), |
451 | pd()->with_b_zero_points(), pd()->with_c_zero_points(), |
452 | pd()->with_bias(), pd()->alpha(), pd()->beta(), *post_ops, a_type, |
453 | b_type, c_type, co_type, acc_type, d->m(), d->n(), d->k(), |
454 | d->batch(), pd()->unroll_m(), pd()->unroll_n(), pd()->alt()); |
455 | |
456 | if (status != status::success) return status; |
457 | |
458 | problem_ = std::move(*kd_full.problem()); |
459 | |
460 | for (bool first_k_block : {false, true}) { |
461 | for (bool last_k_block : {false, true}) { |
462 | if ((!first_k_block || !last_k_block) && !may_k_block) continue; |
463 | if (may_k_block && last_k_block && !pd()->with_c_zero_points() |
464 | && !with_post_ops) |
465 | kernel_[first_k_block][last_k_block] |
466 | = kernel_[first_k_block][false]; |
467 | else if (may_k_block && first_k_block && pd()->beta() == 1.0f) |
468 | kernel_[first_k_block][last_k_block] |
469 | = kernel_[false][last_k_block]; |
470 | else { |
471 | auto this_beta = pd()->beta(); |
472 | bool this_c_offset = pd()->with_c_zero_points(); |
473 | auto *this_post_ops = pd()->post_ops(); |
474 | post_ops_t no_post_ops; |
475 | |
476 | if (!first_k_block) this_beta = 1.0f; |
477 | if (!last_k_block) { |
478 | this_c_offset = false; |
479 | this_post_ops = &no_post_ops; |
480 | } |
481 | |
482 | kd_t kd; |
483 | |
484 | auto status = kd.select_kernel(arch_, eu_count_, |
485 | pd()->with_batch(), pd()->packed_c(), |
486 | pd()->with_a_zero_points(), pd()->with_b_zero_points(), |
487 | this_c_offset, pd()->with_bias(), pd()->alpha(), |
488 | this_beta, *this_post_ops, a_type, b_type, c_type, |
489 | co_type, acc_type, d->m(), d->n(), d->k(), d->batch(), |
490 | pd()->unroll_m(), pd()->unroll_n(), pd()->alt()); |
491 | |
492 | if (status != status::success) return status; |
493 | |
494 | if (!got_info) { |
495 | compute_info_ = *kd.driver_info(); |
496 | got_info = true; |
497 | } |
498 | |
499 | kernel_t kernel(kd); |
500 | |
501 | create_kernel( |
502 | engine, &kernel_[first_k_block][last_k_block], &kernel); |
503 | |
504 | if (!kernel_[first_k_block][last_k_block]) |
505 | return status::runtime_error; |
506 | } |
507 | } |
508 | } |
509 | |
510 | return status::success; |
511 | } |
512 | |
513 | status_t xe_hp_systolic_gemm_t::init_res_storage( |
514 | engine_t *engine, gpu_resource_t *r) const { |
515 | auto a_type = pd()->desc()->a_type(); |
516 | auto b_type = pd()->desc()->b_type(); |
517 | |
518 | auto m = pd()->desc()->m(); |
519 | auto n = pd()->desc()->n(); |
520 | auto k = pd()->desc()->k(); |
521 | |
522 | int64_t align_m = compute_info_.wgTile(LoopM); |
523 | int64_t align_n = compute_info_.wgTile(LoopN); |
524 | |
525 | auto m_aligned = utils::rnd_up(m, align_m); |
526 | auto n_aligned = utils::rnd_up(n, align_n); |
527 | |
528 | auto max_ldab_packed = max_ld_packed(k); |
529 | |
530 | if (!pd()->packed_a()) { |
531 | memory_storage_t *a_packed_ptr; |
532 | engine->create_memory_storage(&a_packed_ptr, |
533 | m_aligned * max_ldab_packed * types::data_type_size(a_type)); |
534 | if (!a_packed_ptr) return status::runtime_error; |
535 | |
536 | std::unique_ptr<memory_storage_t> a_packed(a_packed_ptr); |
537 | r->add_memory_storage(A_PACKED_, std::move(a_packed)); |
538 | } |
539 | |
540 | if (!pd()->packed_b()) { |
541 | memory_storage_t *b_packed_ptr; |
542 | engine->create_memory_storage(&b_packed_ptr, |
543 | n_aligned * max_ldab_packed * types::data_type_size(b_type)); |
544 | if (!b_packed_ptr) return status::runtime_error; |
545 | |
546 | std::unique_ptr<memory_storage_t> b_packed(b_packed_ptr); |
547 | r->add_memory_storage(B_PACKED_, std::move(b_packed)); |
548 | } |
549 | |
550 | return status::success; |
551 | } |
552 | |
553 | bool xe_hp_systolic_gemm_t::enable_mn_blocking() const { |
554 | return (pd()->desc()->m() >= 8192) && (pd()->desc()->n() >= 8192); |
555 | } |
556 | |
// Computes the (block_m, block_n, block_k) partitioning for the GEMM,
// starting from the driver-suggested blocking and adapting it to the
// actual (tile-aligned) problem dimensions.
std::tuple<int64_t, int64_t, int64_t>
xe_hp_systolic_gemm_t::get_blocking() const {
    int64_t m = pd()->desc()->m();
    int64_t n = pd()->desc()->n();
    int64_t k = pd()->desc()->k();

    int64_t unroll_k = compute_info_.unroll[LoopK];

    // Align m/n to the workgroup tile.
    int64_t align_m = compute_info_.wgTile(LoopM);
    int64_t align_n = compute_info_.wgTile(LoopN);

    m = utils::rnd_up(m, align_m);
    n = utils::rnd_up(n, align_n);

    // Decide on m/n blocking.
    int64_t block_m = compute_info_.blocking[LoopM];
    int64_t block_n = compute_info_.blocking[LoopN];
    int64_t max_block_m = utils::rnd_up(m, align_m);
    int64_t max_block_n = utils::rnd_up(n, align_n);

    if (enable_mn_blocking()) {
        // Rebalance block_m/block_n (keeping the block area roughly
        // constant) when one dimension is small relative to its block.
        if (n <= block_n)
            block_m = (block_m * block_n) / n;
        else if (m <= block_m)
            block_n = (block_m * block_n) / m;
        else if (n < 2 * block_n) {
            block_n = utils::rnd_up(n / 2, align_n);
            block_m = (2 * block_m * block_n) / n;
        } else if (m < 2 * block_m) {
            block_m = utils::rnd_up(m / 2, align_m);
            block_n = (2 * block_m * block_n) / m;
        }

        // Clamp to the problem size and round down to tile alignment.
        block_m = utils::rnd_dn(nstl::min(block_m, max_block_m), align_m);
        block_n = utils::rnd_dn(nstl::min(block_n, max_block_n), align_n);
    } else {
        block_m = m;
        block_n = n;
    }

    // Decide on k blocking: split k into roughly equal chunks, each
    // rounded up to a multiple of unroll_k. If k-blocking is disallowed,
    // a single block covering all of k is used.
    int64_t block_k = compute_info_.blocking[LoopK];
    int64_t nblock_k = utils::div_up(k, block_k);
    nblock_k = nstl::max<int64_t>(nblock_k, 1);
    block_k = utils::div_up(k, nblock_k);
    block_k = nstl::max<dim_t>(block_k, 1);
    block_k = utils::rnd_up(pd()->allow_k_blocking() ? block_k : k, unroll_k);
    block_k = nstl::max<dim_t>(block_k, 1);

    return std::make_tuple(block_m, block_n, block_k);
}
608 | |
// Launches the OpenCL copy kernel that repacks an r x c source matrix
// (plain layout, leading dimension ld_src) into the packed destination
// buffer. copyb selects the B-matrix kernel; otherwise the A-matrix one.
status_t xe_hp_systolic_gemm_t::launch_copy(const gemm_exec_ctx_t &ctx,
        int64_t r, int64_t c, const memory_storage_t &src, int64_t offset_src,
        int64_t ld_src, const memory_storage_t &dst, int32_t offset_dst,
        int32_t ld_dst, bool copyb) const {

    using copy_kernel_t = ocl::xe_systolic_gemm_copy_kernel_t;

    // With A/B zero points, run the sum-clearing kernel on the destination
    // first (presumably resetting the compensation area — see
    // launch_clear_sum).
    if (pd()->with_ab_zero_points()) {
        auto status
                = launch_clear_sum(ctx, r, c, dst, offset_dst, ld_dst, copyb);
        if (status) return status;
    }

    int64_t unroll_k = compute_info_.unroll[LoopK];

    // Row/column alignment of the packed destination: the non-k dimension
    // aligns to the workgroup tile, the k dimension to unroll_k.
    int64_t align_r = 0, align_c = 0;

    if (!copyb) {
        align_r = compute_info_.wgTile(LoopM);
        align_c = unroll_k;
    } else {
        align_r = unroll_k;
        align_c = compute_info_.wgTile(LoopN);
    }

    bool transa = (pd()->desc()->transa() == dnnl_trans);
    bool transb = (pd()->desc()->transb() == dnnl_trans);
    bool trans = !copyb ? transa : transb;

    // [copyb][clear_sum = false]: the plain copy kernel.
    auto &kernel = copy_kernel_[copyb][false];

    assert(kernel);
    // Argument order must match the copy kernel's signature.
    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, r);
    arg_list.set(1, c);
    arg_list.set(2, src);
    arg_list.set(3, offset_src);
    arg_list.set(4, ld_src);
    arg_list.set(5, dst);
    arg_list.set(6, offset_dst);
    arg_list.set(7, ld_dst);

    // One thread per unroll_r x unroll_c tile of the alignment-padded
    // matrix.
    auto elt_size = types::data_type_size(pd()->desc()->a_type());
    size_t r_threads = utils::div_up(utils::rnd_up(r, align_r),
            copy_kernel_t::unroll_r(
                    arch_, elt_size, pd()->unroll_n(), copyb, trans));
    size_t c_threads = utils::div_up(utils::rnd_up(c, align_c),
            copy_kernel_t::unroll_c(
                    arch_, elt_size, pd()->unroll_n(), copyb, trans));
    size_t sg = copy_kernel_t::subgroup_size(arch_, elt_size, copyb, trans);

    // Local size of 16 along one dimension (which one depends on trans).
    size_t r_lsz = trans ? 1 : 16;
    size_t c_lsz = trans ? 16 : 1;

    // Round the global size up to the local size, or shrink the local size
    // for tiny problems.
    if (r_threads > r_lsz)
        r_threads = utils::rnd_up(r_threads, r_lsz);
    else
        r_lsz = r_threads;

    if (c_threads > c_lsz)
        c_threads = utils::rnd_up(c_threads, c_lsz);
    else
        c_lsz = c_threads;

    size_t gws[3] = {r_threads * sg, c_threads, 1};
    size_t lws[3] = {r_lsz * sg, c_lsz, 1};

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}
680 | |
681 | status_t xe_hp_systolic_gemm_t::launch_clear_sum(const gemm_exec_ctx_t &ctx, |
682 | int64_t r, int64_t c, const memory_storage_t &dst, int32_t offset_dst, |
683 | int32_t ld_dst, bool copyb) const { |
684 | |
685 | auto &kernel = copy_kernel_[copyb][true]; |
686 | |
687 | assert(kernel); |
688 | compute::kernel_arg_list_t arg_list; |
689 | arg_list.set(0, r); |
690 | arg_list.set(1, c); |
691 | arg_list.set(2, dst); |
692 | arg_list.set(3, offset_dst); |
693 | arg_list.set(4, ld_dst); |
694 | |
695 | auto elt_size = types::data_type_size(pd()->desc()->a_type()); |
696 | size_t threads = !copyb ? utils::div_up(r, pd()->unroll_m()) |
697 | : utils::div_up(c, pd()->unroll_n()); |
698 | size_t sg = ocl::xe_systolic_gemm_copy_kernel_t::subgroup_size_clear_sum( |
699 | arch_, elt_size, copyb); |
700 | |
701 | size_t gws[3] = {threads * sg, 1, 1}; |
702 | size_t lws[3] = {sg, 1, 1}; |
703 | |
704 | auto nd_range = compute::nd_range_t(gws, lws); |
705 | |
706 | return parallel_for(ctx, nd_range, kernel, arg_list); |
707 | } |
708 | |
// Launches the systolic GEMM compute kernel on packed A (ap) / B (bp).
// first_k_block/last_k_block select the split-k kernel variant (see
// init_compute()); optional arguments (zero points, C offset/bias,
// post-op sources, batch strides) are appended only when enabled.
status_t xe_hp_systolic_gemm_t::launch_compute(const gemm_exec_ctx_t &ctx,
        int32_t m, int32_t n, int32_t k, const memory_storage_t &ap,
        int64_t offset_a, int32_t lda, const memory_storage_t &bp,
        int64_t offset_b, int32_t ldb, const memory_storage_t &c,
        int64_t offset_c, int32_t ldc, float alpha, float beta,
        const memory_storage_t *ao, const memory_storage_t *bo,
        const memory_storage_t &co, int32_t offset_co, int po_count,
        const memory_storage_t **po_srcs, int32_t *offset_po_src,
        bool first_k_block, bool last_k_block, int32_t batch, int32_t stride_a,
        int32_t stride_b, int32_t stride_c) const {

    auto tg_m = compute_info_.wg[LoopM];
    auto tg_n = compute_info_.wg[LoopN];

    auto &kernel = kernel_[first_k_block][last_k_block];

    // Kernel signature (positional; argument order below must match):
    //   kernel void gemm_kernel(global char *Ap, global uchar *Bp, global int *C,
    //                           int k, int ldc,
    //                           long offsetA, long offsetB, long offsetC,
    //                           int m, int n,
    //                           float alpha, float beta,
    //                           int lda, int ldb)

    assert(kernel);

    compute::kernel_arg_list_t arg_list;
    int argn = 0;
    arg_list.set(argn++, ap);
    arg_list.set(argn++, bp);
    arg_list.set(argn++, c);
    arg_list.set(argn++, offset_a);
    arg_list.set(argn++, offset_b);
    arg_list.set(argn++, offset_c);
    arg_list.set(argn++, lda);
    arg_list.set(argn++, ldb);
    arg_list.set(argn++, ldc);
    arg_list.set(argn++, m);
    arg_list.set(argn++, n);
    arg_list.set(argn++, k);
    arg_list.set(argn++, alpha);
    arg_list.set(argn++, beta);

    // Optional arguments: A/B zero points, then C offset/bias.
    if (pd()->with_a_zero_points()) arg_list.set(argn++, *ao);
    if (pd()->with_b_zero_points()) arg_list.set(argn++, *bo);
    if ((pd()->with_bias() || pd()->with_c_zero_points())) {
        arg_list.set(argn++, co);
        arg_list.set(argn++, offset_co);
        if (pd()->with_bias()) {
            int32_t ldco = pd()->desc()->ld_bias();
            arg_list.set(argn++, ldco);
        }
    }
    // Post-op source buffers; 2D binary sources also pass a leading dim.
    for (int i = 0; i < po_count; i++) {
        if (!po_srcs[i]) continue;
        arg_list.set(argn++, *po_srcs[i]);
        arg_list.set(argn++, offset_po_src[i]);

        if (problem_.binaryRow[i] && problem_.binaryCol[i])
            arg_list.set(argn++, int32_t(pd()->ld_binary(i)));
    }

    // Runtime flags describing the C offset kind and k-block position.
    uint32_t flags = 0;
    if (co_kind_ == 'R') flags |= FlagCORow;
    if (co_kind_ == 'C') flags |= FlagCOColumn;
    if (!first_k_block) flags |= FlagNoninitialKBlock;
    if (!last_k_block) flags |= FlagNonfinalKBlock;
    arg_list.set(argn++, flags);

    if (pd()->with_batch()) {
        arg_list.set(argn++, stride_a);
        arg_list.set(argn++, stride_b);
        arg_list.set(argn++, stride_c);
        for (int i = 0; i < po_count; i++)
            if (problem_.binaryBatch[i])
                arg_list.set(argn++, int32_t(pd()->stride_binary(i, 0)));
    }

    // One thread group covers (unroll_m * tg_m) x (unroll_n * tg_n) of C.
    auto thread_m = utils::div_up(m, pd()->unroll_m() * tg_m) * tg_m;
    auto thread_n = utils::div_up(n, pd()->unroll_n() * tg_n) * tg_n;

    if (walk_n_first_) std::swap(thread_m, thread_n);

    size_t gws[3] = {size_t(thread_m), size_t(thread_n), 1};
    size_t lws[3] = {size_t(tg_m), size_t(tg_n), 1};
    if (pd()->with_batch()) gws[2] = batch;

    lws[1] *= compute_info_.wgExpand;
    gws[1] *= compute_info_.wgExpand;

    // Append walk-order arguments and adjust the grid accordingly.
    gemm_linear_order_args(arg_list, argn, lws, gws, m, n, false, compute_info_,
            pd()->dev_info_);

    lws[0] *= compute_info_.subgroupSize;
    gws[0] *= compute_info_.subgroupSize;

    auto nd_range = compute::nd_range_t(gws, lws);

    return parallel_for(ctx, nd_range, kernel, arg_list);
}
808 | |
// Top-level GEMM execution: optionally repack A/B into the systolic-friendly
// layout, then tile the (m, n, k) iteration space into blocks and launch the
// compute kernel once per block, accumulating partial k-results into C.
status_t xe_hp_systolic_gemm_t::execute(const gemm_exec_ctx_t &ctx) const {
    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();
    auto bias_type = pd()->desc()->bias_type();

    auto m = pd()->desc()->m();
    auto n = pd()->desc()->n();
    auto k = pd()->desc()->k();
    auto batch = pd()->desc()->batch();

    bool packed_a = pd()->packed_a();
    bool packed_b = pd()->packed_b();
    bool packed_c = pd()->packed_c();

    // Leading dimensions are only meaningful for non-packed inputs; the copy
    // kernels below produce the packed layout with their own fixed stride.
    auto lda = packed_a ? 0 : pd()->desc()->lda();
    auto ldb = packed_b ? 0 : pd()->desc()->ldb();
    auto ldc = packed_c ? pd()->ldc_packed() : pd()->desc()->ldc();
    auto ldco = pd()->with_bias() ? pd()->desc()->ld_bias() : 0;

    auto stride_a = pd()->desc()->stride_a();
    auto stride_b = pd()->desc()->stride_b();
    auto stride_c = pd()->desc()->stride_c();

    auto alpha = pd()->alpha();
    auto beta = pd()->beta();

    // NOTE(review): A and B storages are intentionally cross-wired (a <-
    // storage "b", b <- storage "a") — presumably mapping oneDNN's row-major
    // convention onto the kernel's column-major view; confirm against pd init.
    auto &a = GEMM_CTX_ARG_STORAGE(b);
    auto &b = GEMM_CTX_ARG_STORAGE(a);
    auto &c = GEMM_CTX_ARG_STORAGE(c);
    auto &c_zp = GEMM_CTX_ARG_STORAGE(c_zero_point);
    auto &bias = GEMM_CTX_ARG_STORAGE(bias);
    // C-offset source defaults to the C zero-point; replaced by bias below
    // when bias is present.
    auto *co = &c_zp;
    memory_storage_t *ao = nullptr, *bo = nullptr;

    // If an input is already packed, use it directly; otherwise use the
    // scratch packed buffers owned by this primitive's resources.
    auto &a_packed = packed_a ? a : CTX_GPU_RES_STORAGE(A_PACKED_);
    auto &b_packed = packed_b ? b : CTX_GPU_RES_STORAGE(B_PACKED_);

    const memory_storage_t *po_srcs[GEMM_MAX_PO];

    int po_count = pd()->post_ops()->len();
    assert(po_count <= GEMM_MAX_PO);

    // Resolve the memory storage backing each post-op source: runtime binary
    // post-op args, the bias tensor, or one of the scale tensors.
    for (int i = 0; i < po_count; i++) {
        auto &src = pd()->binary_srcs()[i];
        switch (src.type) {
            case pd_t::binary_src_t::binary:
                po_srcs[i]
                        = ctx.args()
                                  .exec_args
                                  .at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(src.index)
                                          | DNNL_ARG_SRC_1)
                                  .mem->memory_storage();
                break;
            case pd_t::binary_src_t::bias: po_srcs[i] = &bias; break;
            case pd_t::binary_src_t::scales:
                // Scale args follow the same A/B swap as the matrices above:
                // WEIGHTS scales pair with "a", SRC scales with "b".
                switch (src.index) {
                    case DNNL_ARG_WEIGHTS:
                        po_srcs[i] = &GEMM_CTX_ARG_STORAGE(a_scales);
                        break;
                    case DNNL_ARG_SRC:
                        po_srcs[i] = &GEMM_CTX_ARG_STORAGE(b_scales);
                        break;
                    case DNNL_ARG_DST:
                        po_srcs[i] = &GEMM_CTX_ARG_STORAGE(c_scales);
                        break;
                    default:
                        po_srcs[i] = nullptr;
                        assert(!"invalid scale type" );
                        break;
                }
                break;
            default: po_srcs[i] = nullptr; break;
        }
    }

    // Base element offsets into each buffer (storage offset is in bytes;
    // kernels take element counts, hence the division by data type size).
    size_t off_a0
            = a.offset() / types::data_type_size(a_type) + pd()->dyn_offset_a;
    size_t off_b0
            = b.offset() / types::data_type_size(b_type) + pd()->dyn_offset_b;
    size_t off_c0
            = c.offset() / types::data_type_size(c_type) + pd()->dyn_offset_c;
    size_t off_co0 = 0;

    int32_t po_offsets0[GEMM_MAX_PO] = {0}, po_offsets[GEMM_MAX_PO] = {0};
    for (int i = 0; i < po_count; i++)
        if (po_srcs[i])
            po_offsets0[i] = po_srcs[i]->offset() / problem_.Tbinary[i];

    if (pd()->with_ab_zero_points()) {
        ao = &GEMM_CTX_ARG_STORAGE(a_zero_point);
        bo = &GEMM_CTX_ARG_STORAGE(b_zero_point);
    }

    if (pd()->with_bias()) {
        off_co0 = bias.offset() / types::data_type_size(bias_type);
        co = &bias;
    }

    int64_t block_m = 0, block_n = 0, block_k = 0;
    std::tie(block_m, block_n, block_k) = get_blocking();

    auto ld_packed = get_ld_packed(k);
    auto lda_packed = packed_a ? pd()->lda_packed() : ld_packed;
    auto ldb_packed = packed_b ? pd()->ldb_packed() : ld_packed;

    status_t status;

    // Repack non-packed inputs into the systolic layout before computing.
    // Copy kernels only handle the single-batch case.
    if (!packed_a) {
        assert(batch == 1);
        status = launch_copy(
                ctx, m, k, a, off_a0, lda, a_packed, 0, lda_packed, false);
        if (status) return status;
    }

    if (!packed_b) {
        assert(batch == 1);
        status = launch_copy(
                ctx, k, n, b, off_b0, ldb, b_packed, 0, ldb_packed, true);
        if (status) return status;
    }

    // Blocked triple loop over the k, m, n dimensions. The max(k, 1) bound
    // guarantees at least one iteration when k == 0 (pure beta*C update).
    for (int64_t Bk = 0; Bk < nstl::max<dim_t>(k, 1); Bk += block_k) {
        int64_t size_k = k - Bk;
        bool first_k_block = (Bk == 0);
        bool last_k_block = (size_k <= block_k);
        if (!last_k_block) size_k = block_k;

        for (int64_t Bm = 0; Bm < m; Bm += block_m) {
            int64_t size_m = m - Bm;
            if (size_m > block_m) size_m = block_m;

            // Packed layout stores unroll_m-wide panels, so advancing by one
            // k step moves unroll_m() elements along the panel.
            auto off_a_packed = Bm * lda_packed + Bk * pd()->unroll_m();
            if (packed_a) off_a_packed += off_a0;

            for (int64_t Bn = 0; Bn < n; Bn += block_n) {
                int64_t size_n = n - Bn;
                if (size_n > block_n) size_n = block_n;

                auto off_b_packed = Bn * ldb_packed + Bk * pd()->unroll_n();
                if (packed_b) off_b_packed += off_b0;

                auto off_c = off_c0 + Bm + Bn * ldc;
                auto off_co = int32_t(off_co0);
                // Advance the C-offset (bias/zero-point) pointer into the
                // current block. NOTE(review): 'R'/'C'/'M' appear to mean
                // per-row vector / per-column vector / full matrix — confirm
                // where co_kind_ is assigned.
                switch (co_kind_) {
                    case 'R': off_co += Bm; break;
                    case 'C': off_co += Bn; break;
                    case 'M':
                        off_co += isColMajor(problem_.CO.layout)
                                ? (Bn * ldco + Bm)
                                : (Bm * ldco + Bn);
                        break;
                    default: break;
                }

                // Per-block offsets for each post-op source, depending on
                // whether it broadcasts along rows, columns, or neither.
                for (int i = 0; i < po_count; i++) {
                    po_offsets[i] = po_offsets0[i];
                    bool row = problem_.binaryRow[i],
                         col = problem_.binaryCol[i];
                    if (row && col) {
                        auto ld = pd()->ld_binary(i);
                        po_offsets[i] += isColMajor(problem_.binary[i].layout)
                                ? (Bn * ld + Bm)
                                : (Bm * ld + Bn);
                    } else if (row)
                        po_offsets[i] += Bm;
                    else if (col)
                        po_offsets[i] += Bn;
                }

                // User beta applies only to the first k block; subsequent
                // k blocks accumulate (beta = 1) into the partial C result.
                float this_beta = first_k_block ? beta : 1.0f;
                status = launch_compute(ctx, size_m, size_n, size_k, a_packed,
                        off_a_packed, lda_packed, b_packed, off_b_packed,
                        ldb_packed, c, off_c, ldc, alpha, this_beta, ao, bo,
                        *co, off_co, po_count, po_srcs, po_offsets,
                        first_k_block, last_k_block, batch, stride_a, stride_b,
                        stride_c);
                if (status) return status;
            }
        }
    }

    return status::success;
}
993 | |
994 | } // namespace jit |
995 | } // namespace gpu |
996 | } // namespace impl |
997 | } // namespace dnnl |
998 | |
999 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1000 | |