/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cctype>

#include "common/gemm_utils.hpp"
#include "common/impl_registration.hpp"
#include "gpu/compute/device_info.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel.hpp"
#include "gpu/jit/gemm/kernel_catalog.hpp"
#include "gpu/jit/gemm/kernel_selector.hpp"
#include "gpu/jit/gemm/strategy_parser.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {
#define _CATALOG_ gemm_catalog
#include "gpu/jit/gemm/kernel.db"
;
#undef _CATALOG_

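// Finalize the kernel descriptor once a catalog entry has been selected:
// align the problem description with the entry and build the kernel strategy.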
status_t gen_gemm_kernel_desc_t::finalize() {
    // Update problem alignments to match catalog entry.
    if (!isPacked(problem_.A.layout)) {
        problem_.A.setAlignment(
                std::max(problem_.Ta.size(), entry_->driverInfo.alignment[0]));
    }

    if (!isPacked(problem_.B.layout)) {
        problem_.B.setAlignment(
                std::max(problem_.Tb.size(), entry_->driverInfo.alignment[1]));
    }
    if (!isPacked(problem_.C.layout)) {
        problem_.C.setAlignment(std::max(
                problem_.Tc_ext.size(), entry_->driverInfo.alignment[2]));
    }

    problem_.CO.setAlignment(problem_.Tco.size());

    // Parse strategy string.
    strategy_ = GEMMStrategy(hw_, stepping_);
    strategy_.unroll[LoopM] = entry_->driverInfo.unroll[LoopM];
    strategy_.unroll[LoopN] = entry_->driverInfo.unroll[LoopN];
    parseStrategy(entry_->strategy, hw_, problem_, strategy_);
    adjustStrategy(hw_, problem_, strategy_);

    // Always use variable beta for global k-parallel kernels.
    if (strategy_.kParallel) problem_.beta_real = Scalar<double>();

    // Omit periodic barriers when k is small.
    if (strategy_.barrierFreq > 0 && k_ >= 0 && k_ < 2 * strategy_.barrierFreq)
        strategy_.barrierFreq = 0;

    // Disable linear ordering and persistent threads if the GEMM doesn't fill the GPU.
    if (m_ >= 0 && n_ >= 0 && eu_count_ >= 0) {
        int wg_tile_m = strategy_.wg[LoopM] * strategy_.unroll[LoopM];
        int wg_tile_n = strategy_.wg[LoopN] * strategy_.unroll[LoopN];
        if (wg_tile_m > 0 && wg_tile_n > 0) {
            dim_t thread_count = utils::div_up(m_, wg_tile_m)
                    * utils::div_up(n_, wg_tile_n) * strategy_.wg[LoopM]
                    * strategy_.wg[LoopN] * std::max(strategy_.wg[LoopK], 1);
            dim_t thread_gpu = eu_count_
                    * compute::device_info_t::threads_per_eu(
                            arch_, strategy_.GRFs > 128);
            if (thread_count <= thread_gpu)
                strategy_.persistent = strategy_.hilbertOrder
                        = strategy_.boustrophedon = false;
        }
    }

    strategy_.preflight(hw_, problem_);

    update_driver_info();

    return status::success;
}

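// Recompute driver_info_ for the finalized problem/strategy by dispatching to
// the kernel generator for the target hardware generation.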
void gen_gemm_kernel_desc_t::update_driver_info() {
#define ARCH_DISPATCH(arch) \
    case ngen::HW::arch: \
        driver_info_ = gemm_kernel_generator_t<ngen::HW::arch>::driverInfo( \
                problem_, strategy_); \
        break;

    switch (hw_) {
        REG_GEN9_ISA(ARCH_DISPATCH(Gen9))
        REG_XELP_ISA(ARCH_DISPATCH(XeLP))
        REG_XEHP_ISA(ARCH_DISPATCH(XeHP))
        REG_XEHPG_ISA(ARCH_DISPATCH(XeHPG))
        REG_XEHPC_ISA(ARCH_DISPATCH(XeHPC))
        default:
            assert(!"Unsupported architecture");
            driver_info_ = entry_->driverInfo;
            break;
    }
#undef ARCH_DISPATCH
}

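// Copy post-ops into the kernel problem descriptor. For each binary post-op,
// record the source tensor type, broadcast pattern (row/column/batch), and
// addressing information; non-binary entries get placeholder values.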
status_t gen_gemm_kernel_desc_t::transfer_post_ops(
        const post_ops_t &post_ops, bool swap_ab) {
    if (post_ops.len() > 0) {
        problem_.postOps = post_ops;
        int po_count = post_ops.len();
        problem_.Tbinary.reserve(po_count);
        problem_.binary.reserve(po_count);
        problem_.binaryRow.reserve(po_count);
        problem_.binaryCol.reserve(po_count);
        problem_.binaryBatch.reserve(po_count);

        if (problem_.Ta == Type::f16) problem_.Ts = Type::f32;

        for (int i = 0; i < po_count; i++) {
            const auto &entry = post_ops.entry_[i];
            if (entry.kind != primitive_kind::binary) {
                problem_.Tbinary.push_back(Type::invalid);
                problem_.binaryRow.push_back(false);
                problem_.binaryCol.push_back(false);
                problem_.binaryBatch.push_back(false);
                problem_.binary.push_back(MatrixAddressing {});
                continue;
            }

            const auto &src_md = entry.binary.src1_desc;
            memory_desc_wrapper src_mdw(src_md);

            int ndims = src_mdw.ndims();
            auto T = convert_dnnl_to_kernel_type(src_mdw.data_type());
            int nr = (ndims >= 1) ? src_mdw.dims()[ndims - 1] : 1;
            int nc = (ndims >= 2) ? src_mdw.dims()[ndims - 2] : 1;
            bool trans = false;

            if (src_mdw.ndims() >= 2) {
                if (src_md.format_kind != format_kind::blocked
                        || !is_md_gemm_compatible_plain_format(&src_md, false))
                    return status::unimplemented;
                trans = (src_md.format_desc.blocking.strides[ndims - 1] > 1);
            }

            if (swap_ab) {
                trans = !trans;
                std::swap(nr, nc);
            }

            problem_.Tbinary.push_back(T);
            problem_.binaryRow.push_back(nr > 1);
            problem_.binaryCol.push_back(nc > 1);
            problem_.binaryBatch.push_back(ndims >= 3);

            MatrixAddressing atype;
            atype.layout = trans ? MatrixLayout::T : MatrixLayout::N;
            atype.crosspack = 1;
            atype.setAlignment(T.size());

            problem_.binary.push_back(atype);
        }
    }

    return status::success;
}

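// Select a non-copy GEMM kernel from the catalog for the given architecture,
// data types, sizes, and attributes, then finalize the kernel descriptor.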
status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
        int stepping, int eu_count, compute_mode mode, int batch_dims,
        bool trans_a, bool trans_b, bool trans_co, bool swap_ab, bool a_offset,
        bool b_offset, bool c_offset, bool bias, sum_ab_t reduce_ab,
        float alpha, float beta, const post_ops_t &post_ops, data_type_t a_type,
        data_type_t b_type, data_type_t c_type, data_type_t co_type,
        data_type_t acc_type, int align_a, int align_b, int align_c, dim_t m,
        dim_t n, dim_t k, dim_t lda, dim_t ldb, dim_t ldc, dim_t batch) {
    using namespace ngen;
    using namespace kcatalog;

    arch_ = arch;
    hw_ = convert_dnnl_arch_to_hw(arch);
    stepping_ = stepping;
    m_ = m;
    n_ = n;
    k_ = k;
    eu_count_ = eu_count;
    a_offset_ = a_offset;
    b_offset_ = b_offset;

    align_a = nstl::max(align_a, int(types::data_type_size(a_type)));
    align_b = nstl::max(align_b, int(types::data_type_size(b_type)));
    align_c = nstl::max(align_c, int(types::data_type_size(c_type)));

    // Set up problem structure.
    problem_.Ta = problem_.Ta_ext = convert_dnnl_to_kernel_type(a_type);
    problem_.Tb = problem_.Tb_ext = convert_dnnl_to_kernel_type(b_type);
    problem_.Tc = convert_dnnl_to_kernel_type(acc_type);
    problem_.Tco = convert_dnnl_to_kernel_type(co_type);
    problem_.Tc_ext = convert_dnnl_to_kernel_type(c_type);
    problem_.Ts = problem_.Tc;
    problem_.A.layout = trans_a ? MatrixLayout::T : MatrixLayout::N;
    problem_.B.layout = trans_b ? MatrixLayout::T : MatrixLayout::N;
    problem_.C.layout = MatrixLayout::N;
    problem_.A.crosspack = problem_.B.crosspack = problem_.C.crosspack = 1;
    problem_.A.packSize = problem_.B.packSize = problem_.C.packSize = 0;
    problem_.A.setAlignment(align_a);
    problem_.B.setAlignment(align_b);
    problem_.C.setAlignment(align_c);
    if (batch_dims > 0) {
        problem_.batch = BatchMode::Strided;
        problem_.batchDims = batch_dims;
    }
    if (a_offset || b_offset) problem_.abOffset = ABOffset::Calc;

    if (problem_.Ta.isInteger()) problem_.Ts = Type::f32;

    if (alpha == 1.0f) problem_.alpha_real = alpha;
    if (beta == 0.0f || beta == 1.0f) problem_.beta_real = beta;

    auto status = transfer_post_ops(post_ops, swap_ab);
    if (status != status::success) return status;

    if (c_offset || bias || reduce_ab != sum_ab::sum_none) {
        assert(!(c_offset && bias));
        if (bias) problem_.cOffset = COffset::Pre;
        if (c_offset) problem_.cOffset = COffset::Post;
        problem_.CO.crosspack = 1;
        problem_.CO.alignment = problem_.C.alignment;
        problem_.CO.layout = trans_co ? MatrixLayout::T : MatrixLayout::N;
    }

    problem_.sumA = (reduce_ab == sum_ab::sum_b_col);
    problem_.sumB = (reduce_ab == sum_ab::sum_a_row);

    // Select a kernel from the catalog.
    MatchParams match_params[3];
    int npatterns = 1;

    match_params[0] = MatchParams(hw_, problem_);

    match_params[0].sizes.m = m;
    match_params[0].sizes.n = n;
    match_params[0].sizes.k = k;
    match_params[0].sizes.batch = batch;
    match_params[0].stepping = stepping;

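    // Mark matrices whose leading dimension is at least 64 bytes as eligible
    // for 2D block access when matching catalog entries.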
    auto tags = const_cast<char *>(match_params[0].tags);
    while (*tags)
        tags++;
    if (lda * problem_.Ta >= 64) *tags++ = kcatalog::ReqBlock2DA;
    if (ldb * problem_.Tb >= 64) *tags++ = kcatalog::ReqBlock2DB;
    if (ldc * problem_.Tc_ext >= 64) *tags++ = kcatalog::ReqBlock2DC;

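    // Add alternate match patterns that allow f32 inputs to use tf32 or bf16
    // arithmetic internally, when the compute mode permits it.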
    if ((mode & mode_tf32)
            && utils::everyone_is(Type::f32, problem_.Ta, problem_.Tb)) {
        match_params[npatterns] = match_params[0];
        match_params[npatterns].selector.precisions[0] = "T";
        match_params[npatterns].selector.precisions[1] = "T";
        npatterns++;
    }

    if ((mode & mode_bf16x1)
            && utils::everyone_is(Type::f32, problem_.Ta, problem_.Tb)) {
        match_params[npatterns] = match_params[0];
        match_params[npatterns].selector.precisions[0] = "[SB]";
        match_params[npatterns].selector.precisions[1] = "[SB]";
        npatterns++;
    }

    EvaluateParams eval_params;

    eval_params.sizes = match_params[0].sizes;
    eval_params.beta = (post_ops.len() > 0) ? 0.0f : beta;
    eval_params.euCount = eu_count;

    entry_ = select(
            gemm_catalog, npatterns, match_params, eval_params, aux_params_);

    if (!entry_) return status::unimplemented;

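    // Apply the type downconversions implied by the selected catalog entry.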
    if (mode & mode_tf32) {
        if (entry_->selector.precisions[0][0] == 'T')
            problem_.Ta = problem_.Ta_ext = Type::tf32;
        if (entry_->selector.precisions[1][0] == 'T')
            problem_.Tb = problem_.Tb_ext = Type::tf32;
    }

    if (mode & mode_bf16x1) {
        if (entry_->selector.precisions[0][0] == '[') problem_.Ta = Type::bf16;
        if (entry_->selector.precisions[1][0] == '[') problem_.Tb = Type::bf16;
    }

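    // If the driver blocks the k dimension and k spans multiple blocks, C is
    // updated in several passes, so beta cannot be baked in unless it is 1.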
    auto block_k = entry_->driverInfo.blocking[LoopK];
    if (block_k > 0 && k > block_k && beta != 1.0f)
        problem_.beta_real = Scalar<double>();

    return finalize();
}

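// Select a kernel from the catalog for the Xe systolic GEMM implementation,
// which operates on packed A/B matrices.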
status_t gen_gemm_xe_systolic_kernel_desc_t::select_kernel(
        compute::gpu_arch_t arch, int eu_count, int batch_dims, bool packed_c,
        bool a_offset, bool b_offset, bool c_offset, bool bias, float alpha,
        float beta, const post_ops_t &post_ops, data_type_t a_type,
        data_type_t b_type, data_type_t c_type, data_type_t co_type,
        data_type_t acc_type, dim_t m, dim_t n, dim_t k, dim_t batch,
        int unroll_m, int unroll_n, bool alt) {
    using namespace ngen;
    using namespace kcatalog;

    arch_ = arch;
    hw_ = convert_dnnl_arch_to_hw(arch);
    m_ = m;
    n_ = n;
    k_ = k;
    eu_count_ = eu_count;
    a_offset_ = a_offset;
    b_offset_ = b_offset;

    if (!utils::one_of(hw_, HW::XeHP, HW::XeHPG, HW::XeHPC))
        return status::unimplemented;

    bool xehpc = (hw_ == HW::XeHPC);

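    // Systolic array geometry: osys is the number of output channels per
    // systolic operation, ksys the systolic depth in elements (32 bytes), and
    // csys the number of elements per dword (the A crosspack).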
    auto osys = xehpc ? 16 : 8;
    auto ksys = int(32 / types::data_type_size(a_type));
    auto csys = int(4 / types::data_type_size(a_type));

    problem_.Ta = problem_.Ta_ext = convert_dnnl_to_kernel_type(a_type);
    problem_.Tb = problem_.Tb_ext = convert_dnnl_to_kernel_type(b_type);
    problem_.Tc = convert_dnnl_to_kernel_type(acc_type);
    problem_.Tco = convert_dnnl_to_kernel_type(co_type);
    problem_.Tc_ext = convert_dnnl_to_kernel_type(c_type);
    problem_.Ts = Type::f32;
    problem_.A.layout = MatrixLayout::PackedColumns;
    problem_.B.layout = MatrixLayout::PackedRows;
    problem_.C.layout = MatrixLayout::N;
    problem_.A.crosspack = csys;
    problem_.B.crosspack = ksys;
    problem_.C.crosspack = 1;
    problem_.A.packSize = unroll_m;
    problem_.B.packSize = unroll_n;
    problem_.C.packSize = 0;
    if (osys < unroll_m) {
        problem_.A.tileR = osys;
        problem_.A.tileC = ksys;
    }
    problem_.A.setAlignment(32);
    problem_.B.setAlignment(32);
    problem_.C.setAlignment(int(types::data_type_size(c_type)));
    if (packed_c) problem_.C = problem_.B;
    if (batch_dims > 0) {
        problem_.batch = BatchMode::Strided;
        problem_.batchDims = batch_dims;
    }
    if (a_offset || b_offset) problem_.abOffset = ABOffset::Load;
    if (alpha == 1.0f) problem_.alpha_real = alpha;
    if (beta == 0.0f || beta == 1.0f) problem_.beta_real = beta;

    auto status = transfer_post_ops(post_ops);
    if (status != status::success) return status;

    if (c_offset) problem_.cOffset = COffset::Post;

    if (bias) {
        if (problem_.cOffset != COffset::None) return status::unimplemented;
        problem_.cOffset = COffset::Pre;
    }

    if (problem_.cOffset != COffset::None) {
        problem_.CO.crosspack = 1;
        problem_.CO.alignment = problem_.C.alignment;
    }

    // Find it in the catalog.
    MatchParams match_params(hw_, problem_);

    match_params.sizes.m = m;
    match_params.sizes.n = n;
    match_params.sizes.k = k;
    match_params.sizes.batch = batch;
    match_params.unroll[LoopM] = unroll_m;
    match_params.unroll[LoopN] = unroll_n;

    const char alt_tag[2] = {kcatalog::ReqCustom1, '\0'};
    if (alt) match_params.lateTags = match_params.tags = &alt_tag[0];

    EvaluateParams eval_params;

    eval_params.sizes = match_params.sizes;
    eval_params.beta = beta;
    eval_params.euCount = eu_count;

    entry_ = select(gemm_catalog, match_params, eval_params, aux_params_);

    if (!entry_) return status::unimplemented;

    return finalize();
}

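// Heuristically choose unroll (tile) sizes for the systolic GEMM based on
// architecture, data types, and problem size, unless the caller has already
// fixed them; also decide whether to use the alternate ("alt") kernel variant.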
void gen_gemm_xe_systolic_kernel_desc_t::choose_unrolls(
        compute::gpu_arch_t arch, int eu_count, data_type_t a_type,
        data_type_t b_type, data_type_t c_type, dim_t m, dim_t n, dim_t k,
        dim_t batch, int &unroll_m, int &unroll_n, bool &alt) {

    using namespace data_type;

    alt = false;

    switch (arch) {
        case compute::gpu_arch_t::xe_hp:
        case compute::gpu_arch_t::xe_hpg:
            if (unroll_m == 0) unroll_m = 32;
            if (unroll_n == 0) unroll_n = (m * n >= 6144 * eu_count) ? 48 : 32;

            if (unroll_n == 48) alt = (m * n >= 13824 * eu_count);
            break;
        case compute::gpu_arch_t::xe_hpc:
            if (utils::one_of(a_type, f16, bf16)) {
                if (unroll_m != 0)
                    unroll_n = (unroll_m > 16) ? 32 : 16;
                else if (unroll_n != 0)
                    unroll_m = (unroll_n > 16) ? 64 : 16;
                else if (m * n < 4096 * eu_count)
                    unroll_m = unroll_n = 16;
                else {
                    unroll_m = 64;
                    unroll_n = 32;
                }
            } else {
                unroll_m = 64;
                unroll_n = 32;
            }
            break;
        default: assert(!"Unsupported architecture.");
    }
}

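// Build the kernel argument interface for the generated GEMM kernel. Optional
// arguments (A/B offsets, post-op tensors, batch strides, work-ordering
// parameters, SLM pointer) are declared only when the problem or strategy
// requires them.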
void gen_gemm_kernel_t::init_interface() {
    using namespace ngen;

    auto &problem = *desc()->problem();
    auto &strategy = *desc()->strategy();

    interface_ = NEOInterfaceHandler {desc()->hw_};
    auto s_type_ngen = problem.Ts.ngen();

    interface_.newArgument("A", ExternalArgumentType::GlobalPtr);
    interface_.newArgument("B", ExternalArgumentType::GlobalPtr);
    interface_.newArgument("C", ExternalArgumentType::GlobalPtr);
    interface_.newArgument("offset_A", DataType::q);
    interface_.newArgument("offset_B", DataType::q);
    interface_.newArgument("offset_C", DataType::q);
    interface_.newArgument("lda", DataType::d);
    interface_.newArgument("ldb", DataType::d);
    interface_.newArgument("ldc", DataType::d);
    interface_.newArgument("m", DataType::d);
    interface_.newArgument("n", DataType::d);
    interface_.newArgument("k", DataType::d);
    interface_.newArgument("alpha_real", s_type_ngen);
    interface_.newArgument("beta_real", s_type_ngen);
    if (problem.abOffset != ABOffset::None) {
        if (!desc()->a_offset_ && !desc()->b_offset_)
            interface_.newArgument("abo", DataType::ud);
        else {
            if (desc()->a_offset_)
                interface_.newArgument(
                        "ao_ptr", ExternalArgumentType::GlobalPtr);
            if (desc()->b_offset_)
                interface_.newArgument(
                        "bo_ptr", ExternalArgumentType::GlobalPtr);
        }
    }
    if (problem.cOffset != COffset::None || problem.sumA || problem.sumB) {
        interface_.newArgument("CO", ExternalArgumentType::GlobalPtr);
        interface_.newArgument("offset_CO", DataType::d);
        if (problem.cOffset == COffset::Pre)
            interface_.newArgument("ldco", DataType::d);
    }
    for (int i = 0; i < problem.postOps.len(); i++) {
        if (problem.postOps.entry_[i].kind != primitive_kind::binary) continue;
        auto bname = "binary" + std::to_string(i);
        interface_.newArgument(bname, ExternalArgumentType::GlobalPtr);
        interface_.newArgument("offset_" + bname, DataType::d);
        if (problem.binaryRow[i] && problem.binaryCol[i])
            interface_.newArgument("ld" + bname, DataType::d);
    }
    interface_.newArgument("flags", DataType::ud);
    if (strategy.kParallel || strategy.kParallelLocal)
        interface_.newArgument("k0", DataType::d);
    if (problem.batch == BatchMode::Strided) {
        if (problem.batchDims > 1) {
            interface_.newArgument("stride_A1", DataType::d);
            interface_.newArgument("stride_B1", DataType::d);
            interface_.newArgument("stride_C1", DataType::d);
            for (int i = 0; i < problem.postOps.len(); i++)
                if (problem.postOps.entry_[i].kind == primitive_kind::binary
                        && problem.binaryBatch[i])
                    interface_.newArgument(
                            "stride1_binary" + std::to_string(i), DataType::d);
        }
        interface_.newArgument("stride_A", DataType::d);
        interface_.newArgument("stride_B", DataType::d);
        interface_.newArgument("stride_C", DataType::d);
        for (int i = 0; i < problem.postOps.len(); i++)
            if (problem.postOps.entry_[i].kind == primitive_kind::binary
                    && problem.binaryBatch[i])
                interface_.newArgument(
                        "stride_binary" + std::to_string(i), DataType::d);
        if (problem.batchDims > 1) {
            interface_.newArgument("batch_size1", DataType::ud);
            interface_.newArgument("recip_batch_size1", DataType::ud);
        }
    }
    if (strategy.linearOrder()) {
        interface_.newArgument("group_count_m", DataType::ud);
        interface_.newArgument("group_count_n", DataType::ud);
    }
    if (strategy.hilbertOrder) {
        interface_.newArgument("hilbert_vd", DataType::ud);
        interface_.newArgument("hilbert_uvd_recip", DataType::ud);
        interface_.newArgument("hilbert_bail", DataType::ud);
    } else if (strategy.boustrophedon) {
        interface_.newArgument("bslice", DataType::d);
        interface_.newArgument("bthresh", DataType::d);
    }
    if (strategy.persistent)
        interface_.newArgument("group_stride", DataType::ud);
    if (strategy.variableSLM())
        interface_.newArgument("local_mem", ExternalArgumentType::LocalPtr);

    interface_.externalName(kernel_name());
}

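// Generate the GEMM kernel for the descriptor's target hardware and return it
// as an OpenCL kernel object (nullptr if the architecture is unsupported).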
cl_kernel gen_gemm_kernel_t::get_kernel(
        cl_context context, cl_device_id device) {
    cl_kernel ocl_kernel = nullptr;

    init_interface();

#define ARCH_DISPATCH(arch) \
    case ngen::HW::arch: { \
        gemm_kernel_generator_t<ngen::HW::arch> generator; \
        generator.setStepping(desc()->stepping_); \
        generator.gemm(*desc()->problem(), *desc()->strategy(), interface_); \
        ocl_kernel = generator.getKernel(context, device); \
        break; \
    }

    switch (desc()->hw_) {
        REG_GEN9_ISA(ARCH_DISPATCH(Gen9))
        REG_XELP_ISA(ARCH_DISPATCH(XeLP))
        REG_XEHP_ISA(ARCH_DISPATCH(XeHP))
        REG_XEHPG_ISA(ARCH_DISPATCH(XeHPG))
        REG_XEHPC_ISA(ARCH_DISPATCH(XeHPC))
        default: assert(!"Unsupported architecture"); break;
    }

    return ocl_kernel;

#undef ARCH_DISPATCH
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl