/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cctype>

#include "common/gemm_utils.hpp"
#include "common/impl_registration.hpp"
#include "gpu/compute/device_info.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel.hpp"
#include "gpu/jit/gemm/kernel_catalog.hpp"
#include "gpu/jit/gemm/kernel_selector.hpp"
#include "gpu/jit/gemm/strategy_parser.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

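// Instantiate the embedded GEMM kernel catalog (kernel.db) under the name
// gemm_catalog.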
#define _CATALOG_ gemm_catalog
#include "gpu/jit/gemm/kernel.db"
;
#undef _CATALOG_

status_t gen_gemm_kernel_desc_t::finalize() {
    // Update problem alignments to match catalog entry.
    if (!isPacked(problem_.A.layout)) {
        problem_.A.setAlignment(
                std::max(problem_.Ta.size(), entry_->driverInfo.alignment[0]));
    }

    if (!isPacked(problem_.B.layout)) {
        problem_.B.setAlignment(
                std::max(problem_.Tb.size(), entry_->driverInfo.alignment[1]));
    }

    if (!isPacked(problem_.C.layout)) {
        problem_.C.setAlignment(std::max(
                problem_.Tc_ext.size(), entry_->restrictions.alignment[2]));
    }

    problem_.CO.setAlignment(problem_.Tco.size());

    // Parse strategy string.
    strategy_ = GEMMStrategy(hw_, stepping_);
    strategy_.unroll[LoopM] = entry_->driverInfo.unroll[LoopM];
    strategy_.unroll[LoopN] = entry_->driverInfo.unroll[LoopN];
    parseStrategy(entry_->strategy, hw_, problem_, strategy_);
    adjustStrategy(hw_, problem_, strategy_);

    // Always use variable beta for global k-parallel kernels.
    if (strategy_.kParallel) problem_.beta_real = Scalar<double>();

    // Omit periodic barriers when k is small.
    if (strategy_.barrierFreq > 0 && k_ >= 0 && k_ < 2 * strategy_.barrierFreq)
        strategy_.barrierFreq = 0;

    // Disable linear ordering and persistent threads if the GEMM doesn't fill
    // the GPU.
    if (m_ >= 0 && n_ >= 0 && eu_count_ >= 0) {
        int wg_tile_m = strategy_.wg[LoopM] * strategy_.unroll[LoopM];
        int wg_tile_n = strategy_.wg[LoopN] * strategy_.unroll[LoopN];
        if (wg_tile_m > 0 && wg_tile_n > 0) {
            dim_t thread_count = utils::div_up(m_, wg_tile_m)
                    * utils::div_up(n_, wg_tile_n) * strategy_.wg[LoopM]
                    * strategy_.wg[LoopN] * std::max(strategy_.wg[LoopK], 1);
            dim_t thread_gpu = eu_count_
                    * compute::device_info_t::threads_per_eu(
                            arch_, strategy_.GRFs > 128);
            if (thread_count <= thread_gpu)
                strategy_.persistent = strategy_.hilbertOrder
                        = strategy_.boustrophedon = false;
        }
    }

    strategy_.preflight(hw_, problem_);

    update_driver_info();

    return status::success;
}

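// Recompute driver info from the finalized problem/strategy, dispatching on
// the target hardware generation. If the ISA is not registered in this build,
// fall back to the catalog entry's precomputed driver info.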
void gen_gemm_kernel_desc_t::update_driver_info() {
#define ARCH_DISPATCH(arch) \
    case ngen::HW::arch: \
        driver_info_ = gemm_kernel_generator_t<ngen::HW::arch>::driverInfo( \
                problem_, strategy_); \
        break;

    switch (hw_) {
        REG_GEN9_ISA(ARCH_DISPATCH(Gen9))
        REG_XELP_ISA(ARCH_DISPATCH(XeLP))
        REG_XEHP_ISA(ARCH_DISPATCH(XeHP))
        REG_XEHPG_ISA(ARCH_DISPATCH(XeHPG))
        REG_XEHPC_ISA(ARCH_DISPATCH(XeHPC))
        default:
            assert(!"Unsupported architecture");
            driver_info_ = entry_->driverInfo;
            break;
    }
#undef ARCH_DISPATCH
}

status_t gen_gemm_kernel_desc_t::transfer_post_ops(
        const post_ops_t &post_ops, bool swap_ab) {
    if (post_ops.len() > 0) {
        problem_.postOps = post_ops;
        int po_count = post_ops.len();
        problem_.Tbinary.reserve(po_count);
        problem_.binary.reserve(po_count);
        problem_.binaryRow.reserve(po_count);
        problem_.binaryCol.reserve(po_count);
        problem_.binaryBatch.reserve(po_count);

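        // With post-ops present, use an f32 intermediate type for f16 inputs
        // so post-ops are applied at full precision.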
        if (problem_.Ta == Type::f16) problem_.Ts = Type::f32;

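        // Record the type and layout of each post-op. Non-binary post-ops get
        // placeholder entries so the per-post-op arrays stay index-aligned.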
        for (int i = 0; i < po_count; i++) {
            const auto &entry = post_ops.entry_[i];
            if (entry.kind != primitive_kind::binary) {
                problem_.Tbinary.push_back(Type::invalid);
                problem_.binaryRow.push_back(false);
                problem_.binaryCol.push_back(false);
                problem_.binaryBatch.push_back(false);
                problem_.binary.push_back(MatrixAddressing {});
                continue;
            }

            const auto &src_md = entry.binary.src1_desc;
            memory_desc_wrapper src_mdw(src_md);

            int ndims = src_mdw.ndims();
            auto T = convert_dnnl_to_kernel_type(src_mdw.data_type());
            int nr = (ndims >= 1) ? src_mdw.dims()[ndims - 1] : 1;
            int nc = (ndims >= 2) ? src_mdw.dims()[ndims - 2] : 1;
            bool trans = false;

            if (src_mdw.ndims() >= 2) {
                if (src_md.format_kind != format_kind::blocked
                        || !is_md_gemm_compatible_plain_format(&src_md, false))
                    return status::unimplemented;
                trans = (src_md.format_desc.blocking.strides[ndims - 1] > 1);
            }

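            // If A and B are swapped (C is computed transposed), swap the
            // binary source's row/column roles to match.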
            if (swap_ab) {
                trans = !trans;
                std::swap(nr, nc);
            }

            problem_.Tbinary.push_back(T);
            problem_.binaryRow.push_back(nr > 1);
            problem_.binaryCol.push_back(nc > 1);
            problem_.binaryBatch.push_back(ndims >= 3);

            MatrixAddressing atype;
            atype.layout = trans ? MatrixLayout::T : MatrixLayout::N;
            atype.crosspack = 1;
            atype.setAlignment(T.size());

            problem_.binary.push_back(atype);
        }
    }

    return status::success;
}

status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,
        int stepping, int eu_count, compute_mode mode, int batch_dims,
        bool trans_a, bool trans_b, bool trans_co, bool swap_ab, bool a_offset,
        bool b_offset, bool c_offset, bool bias, sum_ab_t reduce_ab,
        float alpha, float beta, const post_ops_t &post_ops, data_type_t a_type,
        data_type_t b_type, data_type_t c_type, data_type_t co_type,
        data_type_t acc_type, int align_a, int align_b, int align_c, dim_t m,
        dim_t n, dim_t k, dim_t lda, dim_t ldb, dim_t ldc, dim_t batch) {
    using namespace ngen;
    using namespace kcatalog;

    arch_ = arch;
    hw_ = convert_dnnl_arch_to_hw(arch);
    stepping_ = stepping;
    m_ = m;
    n_ = n;
    k_ = k;
    eu_count_ = eu_count;
    a_offset_ = a_offset;
    b_offset_ = b_offset;

    align_a = nstl::max(align_a, int(types::data_type_size(a_type)));
    align_b = nstl::max(align_b, int(types::data_type_size(b_type)));
    align_c = nstl::max(align_c, int(types::data_type_size(c_type)));

    // Set up problem structure.
    problem_.Ta = problem_.Ta_ext = convert_dnnl_to_kernel_type(a_type);
    problem_.Tb = problem_.Tb_ext = convert_dnnl_to_kernel_type(b_type);
    problem_.Tc = convert_dnnl_to_kernel_type(acc_type);
    problem_.Tco = convert_dnnl_to_kernel_type(co_type);
    problem_.Tc_ext = convert_dnnl_to_kernel_type(c_type);
    problem_.Ts = problem_.Tc;
    problem_.A.layout = trans_a ? MatrixLayout::T : MatrixLayout::N;
    problem_.B.layout = trans_b ? MatrixLayout::T : MatrixLayout::N;
    problem_.C.layout = MatrixLayout::N;
    problem_.A.crosspack = problem_.B.crosspack = problem_.C.crosspack = 1;
    problem_.A.packSize = problem_.B.packSize = problem_.C.packSize = 0;
    problem_.A.setAlignment(align_a);
    problem_.B.setAlignment(align_b);
    problem_.C.setAlignment(align_c);
    if (batch_dims > 0) {
        problem_.batch = BatchMode::Strided;
        problem_.batchDims = batch_dims;
    }
    if (a_offset || b_offset) problem_.abOffset = ABOffset::Calc;

    if (problem_.Ta.isInteger()) problem_.Ts = Type::f32;

    if (alpha == 1.0f) problem_.alpha_real = alpha;
    if (beta == 0.0f || beta == 1.0f) problem_.beta_real = beta;

    auto status = transfer_post_ops(post_ops, swap_ab);
    if (status != status::success) return status;

    if (c_offset || bias || reduce_ab != sum_ab::sum_none) {
        assert(!(c_offset && bias));
        if (bias) problem_.cOffset = COffset::Pre;
        if (c_offset) problem_.cOffset = COffset::Post;
        problem_.CO.crosspack = 1;
        problem_.CO.alignment = problem_.C.alignment;
        problem_.CO.layout = trans_co ? MatrixLayout::T : MatrixLayout::N;
    }

    problem_.sumA = (reduce_ab == sum_ab::sum_b_col);
    problem_.sumB = (reduce_ab == sum_ab::sum_a_row);

    // Select a kernel from the catalog.
    MatchParams match_params[3];
    int npatterns = 1;

    match_params[0] = MatchParams(hw_, problem_);

    match_params[0].sizes.m = m;
    match_params[0].sizes.n = n;
    match_params[0].sizes.k = k;
    match_params[0].sizes.batch = batch;
    match_params[0].stepping = stepping;

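    // Mark matrices whose leading dimension spans at least 64 bytes; this
    // allows the selector to choose kernels that require 2D block messages
    // for those matrices.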
    auto tags = const_cast<char *>(match_params[0].tags);
    while (*tags)
        tags++;
    if (lda * problem_.Ta >= 64) *tags++ = kcatalog::ReqBlock2DA;
    if (ldb * problem_.Tb >= 64) *tags++ = kcatalog::ReqBlock2DB;
    if (ldc * problem_.Tc >= 64) *tags++ = kcatalog::ReqBlock2DC;

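    // When enabled, also allow f32 problems to match tf32 or bf16x1 kernels
    // by adding alternate match patterns with relaxed precision selectors.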
    if ((mode & mode_tf32)
            && utils::everyone_is(Type::f32, problem_.Ta, problem_.Tb)) {
        match_params[npatterns] = match_params[0];
        match_params[npatterns].selector.precisions[0] = "T";
        match_params[npatterns].selector.precisions[1] = "T";
        npatterns++;
    }

    if ((mode & mode_bf16x1)
            && utils::everyone_is(Type::f32, problem_.Ta, problem_.Tb)) {
        match_params[npatterns] = match_params[0];
        match_params[npatterns].selector.precisions[0] = "[SB]";
        match_params[npatterns].selector.precisions[1] = "[SB]";
        npatterns++;
    }

    EvaluateParams eval_params;

    eval_params.sizes = match_params[0].sizes;
    eval_params.beta = (post_ops.len() > 0) ? 0.0f : beta;
    eval_params.euCount = eu_count;

    entry_ = select(
            gemm_catalog, npatterns, match_params, eval_params, aux_params_);

    if (!entry_) return status::unimplemented;

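    // If an alternate-precision entry was chosen, update the problem types to
    // match it.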
    if (mode & mode_tf32) {
        if (entry_->selector.precisions[0][0] == 'T')
            problem_.Ta = problem_.Ta_ext = Type::tf32;
        if (entry_->selector.precisions[1][0] == 'T')
            problem_.Tb = problem_.Tb_ext = Type::tf32;
    }

    if (mode & mode_bf16x1) {
        if (entry_->selector.precisions[0][0] == '[') problem_.Ta = Type::bf16;
        if (entry_->selector.precisions[1][0] == '[') problem_.Tb = Type::bf16;
    }

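    // If the driver will split k into blocks for this kernel, beta should be
    // applied only once, so treat any beta other than 1 as a run-time value.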
    auto block_k = entry_->driverInfo.blocking[LoopK];
    if (block_k > 0 && k > block_k && beta != 1.0f)
        problem_.beta_real = Scalar<double>();

    return finalize();
}

status_t gen_gemm_xe_systolic_kernel_desc_t::select_kernel(
        compute::gpu_arch_t arch, int eu_count, int batch_dims, bool packed_c,
        bool a_offset, bool b_offset, bool c_offset, bool bias, float alpha,
        float beta, const post_ops_t &post_ops, data_type_t a_type,
        data_type_t b_type, data_type_t c_type, data_type_t co_type,
        data_type_t acc_type, dim_t m, dim_t n, dim_t k, dim_t batch,
        int unroll_m, int unroll_n, bool alt) {
    using namespace ngen;
    using namespace kcatalog;

    arch_ = arch;
    hw_ = convert_dnnl_arch_to_hw(arch);
    m_ = m;
    n_ = n;
    k_ = k;
    eu_count_ = eu_count;
    a_offset_ = a_offset;
    b_offset_ = b_offset;

    if (!utils::one_of(hw_, HW::XeHP, HW::XeHPG, HW::XeHPC))
        return status::unimplemented;

    bool xehpc = (hw_ == HW::XeHPC);

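    // Systolic (DPAS) parameters derived from the element size: osys = output
    // rows per systolic operation, ksys = k depth in elements, csys = A
    // crosspack (elements per channel).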
    auto osys = xehpc ? 16 : 8;
    auto ksys = int(32 / types::data_type_size(a_type));
    auto csys = int(4 / types::data_type_size(a_type));

    problem_.Ta = problem_.Ta_ext = convert_dnnl_to_kernel_type(a_type);
    problem_.Tb = problem_.Tb_ext = convert_dnnl_to_kernel_type(b_type);
    problem_.Tc = convert_dnnl_to_kernel_type(acc_type);
    problem_.Tco = convert_dnnl_to_kernel_type(co_type);
    problem_.Tc_ext = convert_dnnl_to_kernel_type(c_type);
    problem_.Ts = Type::f32;
    problem_.A.layout = MatrixLayout::PackedColumns;
    problem_.B.layout = MatrixLayout::PackedRows;
    problem_.C.layout = MatrixLayout::N;
    problem_.A.crosspack = csys;
    problem_.B.crosspack = ksys;
    problem_.C.crosspack = 1;
    problem_.A.packSize = unroll_m;
    problem_.B.packSize = unroll_n;
    problem_.C.packSize = 0;
    if (osys < unroll_m) {
        problem_.A.tileR = osys;
        problem_.A.tileC = ksys;
    }
    problem_.A.setAlignment(32);
    problem_.B.setAlignment(32);
    problem_.C.setAlignment(int(types::data_type_size(c_type)));
    if (packed_c) problem_.C = problem_.B;
    if (batch_dims > 0) {
        problem_.batch = BatchMode::Strided;
        problem_.batchDims = batch_dims;
    }
    if (a_offset || b_offset) problem_.abOffset = ABOffset::Load;
    if (alpha == 1.0f) problem_.alpha_real = alpha;
    if (beta == 0.0f || beta == 1.0f) problem_.beta_real = beta;

    auto status = transfer_post_ops(post_ops);
    if (status != status::success) return status;

    if (c_offset) problem_.cOffset = COffset::Post;

    if (bias) {
        if (problem_.cOffset != COffset::None) return status::unimplemented;
        problem_.cOffset = COffset::Pre;
    }

    if (problem_.cOffset != COffset::None) {
        problem_.CO.crosspack = 1;
        problem_.CO.alignment = problem_.C.alignment;
    }

    // Find it in the catalog.
    MatchParams match_params(hw_, problem_);

    match_params.sizes.m = m;
    match_params.sizes.n = n;
    match_params.sizes.k = k;
    match_params.sizes.batch = batch;
    match_params.unroll[LoopM] = unroll_m;
    match_params.unroll[LoopN] = unroll_n;

    const char alt_tag[2] = {kcatalog::ReqCustom1, '\0'};
    if (alt) match_params.lateTags = match_params.tags = &alt_tag[0];

    EvaluateParams eval_params;

    eval_params.sizes = match_params.sizes;
    eval_params.beta = beta;
    eval_params.euCount = eu_count;

    entry_ = select(gemm_catalog, match_params, eval_params, aux_params_);

    if (!entry_) return status::unimplemented;

    return finalize();
}

void gen_gemm_xe_systolic_kernel_desc_t::choose_unrolls(
        compute::gpu_arch_t arch, int eu_count, data_type_t a_type,
        data_type_t b_type, data_type_t c_type, dim_t m, dim_t n, dim_t k,
        dim_t batch, int &unroll_m, int &unroll_n, bool &alt) {

    using namespace data_type;

    alt = false;

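    // Choose systolic tile sizes: start from caller-provided unrolls if any,
    // otherwise pick based on architecture, data type, and how well m*n fills
    // the available EUs. 'alt' requests the alternate (ReqCustom1) kernel
    // variant.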
    switch (arch) {
        case compute::gpu_arch_t::xe_hp:
        case compute::gpu_arch_t::xe_hpg:
            if (unroll_m == 0) unroll_m = 32;
            if (unroll_n == 0) unroll_n = (m * n >= 6144 * eu_count) ? 48 : 32;

            if (unroll_n == 48) alt = (m * n >= 13824 * eu_count);
            break;
        case compute::gpu_arch_t::xe_hpc:
            if (utils::one_of(a_type, f16, bf16)) {
                if (unroll_m != 0)
                    unroll_n = (unroll_m > 16) ? 32 : 16;
                else if (unroll_n != 0)
                    unroll_m = (unroll_n > 16) ? 64 : 16;
                else if (m * n < 4096 * eu_count)
                    unroll_m = unroll_n = 16;
                else {
                    unroll_m = 64;
                    unroll_n = 32;
                }
            } else {
                unroll_m = 64;
                unroll_n = 32;
            }
            break;
        default: assert(!"Unsupported architecture.");
    }
}

void gen_gemm_kernel_t::init_interface() {
    using namespace ngen;

    auto &problem = *desc()->problem();
    auto &strategy = *desc()->strategy();

    interface_ = NEOInterfaceHandler {desc()->hw_};
    auto s_type_ngen = problem.Ts.ngen();

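    // Declare the kernel arguments. This order must line up with the argument
    // list built by the host-side GEMM driver at execution time.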
    interface_.newArgument("A", ExternalArgumentType::GlobalPtr);
    interface_.newArgument("B", ExternalArgumentType::GlobalPtr);
    interface_.newArgument("C", ExternalArgumentType::GlobalPtr);
    interface_.newArgument("offset_A", DataType::q);
    interface_.newArgument("offset_B", DataType::q);
    interface_.newArgument("offset_C", DataType::q);
    interface_.newArgument("lda", DataType::d);
    interface_.newArgument("ldb", DataType::d);
    interface_.newArgument("ldc", DataType::d);
    interface_.newArgument("m", DataType::d);
    interface_.newArgument("n", DataType::d);
    interface_.newArgument("k", DataType::d);
    interface_.newArgument("alpha_real", s_type_ngen);
    interface_.newArgument("beta_real", s_type_ngen);
    if (problem.abOffset != ABOffset::None) {
        if (!desc()->a_offset_ && !desc()->b_offset_)
            interface_.newArgument("abo", DataType::ud);
        else {
            if (desc()->a_offset_)
                interface_.newArgument(
                        "ao_ptr", ExternalArgumentType::GlobalPtr);
            if (desc()->b_offset_)
                interface_.newArgument(
                        "bo_ptr", ExternalArgumentType::GlobalPtr);
        }
    }
    if (problem.cOffset != COffset::None || problem.sumA || problem.sumB) {
        interface_.newArgument("CO", ExternalArgumentType::GlobalPtr);
        interface_.newArgument("offset_CO", DataType::d);
        if (problem.cOffset == COffset::Pre)
            interface_.newArgument("ldco", DataType::d);
    }
    for (int i = 0; i < problem.postOps.len(); i++) {
        if (problem.postOps.entry_[i].kind != primitive_kind::binary) continue;
        auto bname = "binary" + std::to_string(i);
        interface_.newArgument(bname, ExternalArgumentType::GlobalPtr);
        interface_.newArgument("offset_" + bname, DataType::d);
        if (problem.binaryRow[i] && problem.binaryCol[i])
            interface_.newArgument("ld" + bname, DataType::d);
    }
    interface_.newArgument("flags", DataType::ud);
    if (strategy.kParallel || strategy.kParallelLocal)
        interface_.newArgument("k0", DataType::d);
    if (problem.batch == BatchMode::Strided) {
        if (problem.batchDims > 1) {
            interface_.newArgument("stride_A1", DataType::d);
            interface_.newArgument("stride_B1", DataType::d);
            interface_.newArgument("stride_C1", DataType::d);
            for (int i = 0; i < problem.postOps.len(); i++)
                if (problem.postOps.entry_[i].kind == primitive_kind::binary
                        && problem.binaryBatch[i])
                    interface_.newArgument(
                            "stride1_binary" + std::to_string(i), DataType::d);
        }
        interface_.newArgument("stride_A", DataType::d);
        interface_.newArgument("stride_B", DataType::d);
        interface_.newArgument("stride_C", DataType::d);
        for (int i = 0; i < problem.postOps.len(); i++)
            if (problem.postOps.entry_[i].kind == primitive_kind::binary
                    && problem.binaryBatch[i])
                interface_.newArgument(
                        "stride_binary" + std::to_string(i), DataType::d);
        if (problem.batchDims > 1) {
            interface_.newArgument("batch_size1", DataType::ud);
            interface_.newArgument("recip_batch_size1", DataType::ud);
        }
    }
    if (strategy.linearOrder()) {
        interface_.newArgument("group_count_m", DataType::ud);
        interface_.newArgument("group_count_n", DataType::ud);
    }
    if (strategy.hilbertOrder) {
        interface_.newArgument("hilbert_vd", DataType::ud);
        interface_.newArgument("hilbert_uvd_recip", DataType::ud);
        interface_.newArgument("hilbert_bail", DataType::ud);
    } else if (strategy.boustrophedon) {
        interface_.newArgument("bslice", DataType::d);
        interface_.newArgument("bthresh", DataType::d);
    }
    if (strategy.persistent)
        interface_.newArgument("group_stride", DataType::ud);
    if (strategy.variableSLM())
        interface_.newArgument("local_mem", ExternalArgumentType::LocalPtr);

    interface_.externalName(kernel_name());
}

cl_kernel gen_gemm_kernel_t::get_kernel(
        cl_context context, cl_device_id device) {
    cl_kernel ocl_kernel = nullptr;

    init_interface();

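    // Generate the kernel for the target architecture: instantiate the
    // nGEN-based GEMM generator, emit the code for the selected
    // problem/strategy, and retrieve it as an OpenCL kernel for the given
    // context and device.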
#define ARCH_DISPATCH(arch) \
    case ngen::HW::arch: { \
        gemm_kernel_generator_t<ngen::HW::arch> generator; \
        generator.setStepping(desc()->stepping_); \
        generator.gemm(*desc()->problem(), *desc()->strategy(), interface_); \
        ocl_kernel = generator.getKernel(context, device); \
        break; \
    }

    switch (desc()->hw_) {
        REG_GEN9_ISA(ARCH_DISPATCH(Gen9))
        REG_XELP_ISA(ARCH_DISPATCH(XeLP))
        REG_XEHP_ISA(ARCH_DISPATCH(XeHP))
        REG_XEHPG_ISA(ARCH_DISPATCH(XeHPG))
        REG_XEHPC_ISA(ARCH_DISPATCH(XeHPC))
        default: assert(!"Unsupported architecture"); break;
    }

    return ocl_kernel;

#undef ARCH_DISPATCH
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl