1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_GEMM_GEN_GEMM_HPP |
18 | #define GPU_JIT_GEMM_GEN_GEMM_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <memory> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/gemm_utils.hpp" |
25 | #include "common/utils.hpp" |
26 | #include "gpu/compute/compute.hpp" |
27 | #include "gpu/compute/kernel.hpp" |
28 | #include "gpu/gemm/gpu_gemm.hpp" |
29 | #include "gpu/jit/gemm/gen_gemm_kernel.hpp" |
30 | #include "gpu/jit/gemm/jit_gemm_pd.hpp" |
31 | #include "gpu/jit/jit_post_op_injector.hpp" |
32 | #include "gpu/primitive_conf.hpp" |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace gpu { |
37 | namespace jit { |
38 | |
39 | struct gen_gemm_t : public gpu_gemm_t { |
40 | struct pd_t : public jit_gemm_pd_t { |
41 | using jit_gemm_pd_t::jit_gemm_pd_t; |
42 | using kernel_desc_t = gen_gemm_nocopy_kernel_desc_t; |
43 | |
44 | DECLARE_COMMON_PD_T("jit:gemm:any" , gen_gemm_t); |
45 | |
46 | status_t init(engine_t *engine) { |
47 | using namespace prop_kind; |
48 | using namespace data_type; |
49 | using namespace primitive_kind; |
50 | using namespace alg_kind; |
51 | using smask_t = primitive_attr_t::skip_mask_t; |
52 | using arch_t = compute::gpu_arch_t; |
53 | |
54 | assert(engine->kind() == engine_kind::gpu); |
55 | auto *compute_engine |
56 | = utils::downcast<compute::compute_engine_t *>(engine); |
57 | |
58 | // LIMITATIONS: |
59 | // - runtime dims are not supported |
60 | bool ok = true; |
61 | |
62 | auto attr_skip_mask = smask_t::scales_runtime | smask_t::post_ops; |
63 | |
64 | dev_info_ = compute_engine->device_info(); |
65 | arch_ = dev_info_->gpu_arch(); |
66 | int stepping = dev_info_->stepping_id(); |
67 | |
68 | ok = set_default_formats(); |
69 | if (!ok) return status::unimplemented; |
70 | |
71 | bool check_lda |
72 | = ((desc()->transa() == dnnl_notrans && desc()->lda() == 1) |
73 | || (desc()->transa() == dnnl_trans)); |
74 | swap_ab_ = (desc()->a_type() == data_type::f16 && desc()->m() == 1 |
75 | && desc()->ldc() == 1 && check_lda); |
76 | |
77 | const auto d = desc(); |
78 | |
79 | if (utils::one_of(d->c_type(), s32, f16, f32, u8, s8) |
80 | && utils::one_of(d->a_type(), u8, s8)) { |
81 | ok = ok && utils::one_of(d->b_type(), u8, s8); |
82 | |
83 | a_zp_ = !attr()->zero_points_.has_default_values(DNNL_ARG_SRC); |
84 | b_zp_ = !attr()->zero_points_.has_default_values( |
85 | DNNL_ARG_WEIGHTS); |
86 | if (swap_ab_) std::swap(a_zp_, b_zp_); |
87 | |
88 | int cmask_a = 0, cmask_b = 0, cmask_c = 0; |
89 | attr()->zero_points_.get(DNNL_ARG_WEIGHTS, &cmask_b); |
90 | attr()->zero_points_.get(DNNL_ARG_SRC, &cmask_a); |
91 | attr()->zero_points_.get(DNNL_ARG_DST, &cmask_c); |
92 | ok &= (cmask_a == 0) && (cmask_b == 0) |
93 | && utils::one_of(cmask_c, 0, 1 << 0, 1 << 1); |
94 | |
95 | attr_skip_mask |= smask_t::zero_points_runtime; |
96 | |
97 | ok = ok |
98 | && IMPLICATION( |
99 | utils::one_of(d->c_type(), f32, s8, u8, f16), |
100 | arch_ >= arch_t::xe_hp); |
101 | } else if (d->a_type() == bf16) { |
102 | ok = ok && d->b_type() == bf16 |
103 | && utils::one_of(d->c_type(), bf16, f32) |
104 | && utils::one_of(d->acc_type, bf16, f32); |
105 | } else { |
106 | ok = ok && utils::one_of(d->a_type(), f32, f16) |
107 | && d->b_type() == d->a_type() |
108 | && utils::one_of(d->acc_type, d->a_type(), f32); |
109 | } |
110 | |
111 | ok = ok && !has_blocks() && batch_dims() <= 2 |
112 | && !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(), |
113 | d->k(), d->lda(), d->ldb(), d->ldc(), d->batch()) |
114 | && IMPLICATION(with_bias(), |
115 | utils::one_of(d->bias_type(), f32, bf16, f16) |
116 | && (d->bias_desc.ndims <= 3) |
117 | && utils::one_of(bias_cmask(), 0, 1, 2, 3)) |
118 | && IMPLICATION(utils::one_of(d->bias_type(), bf16, f16), |
119 | (d->bias_type() == d->c_type())) |
120 | && compute_engine->mayiuse_ngen_kernels() |
121 | && attr()->has_default_values(attr_skip_mask) |
122 | && attr()->output_scales_.mask_ == 0 |
123 | && IMPLICATION(with_sum_ab(), |
124 | !with_bias() |
125 | && (attr()->zero_points_.has_default_values( |
126 | DNNL_ARG_DST))); |
127 | |
128 | auto status = init_post_ops(); |
129 | if (status != status::success) return status; |
130 | |
131 | bool with_binary = (post_ops_.find(binary) != -1); |
132 | |
133 | // check GPU architecture |
134 | ok &= utils::one_of(arch_, arch_t::gen9, arch_t::xe_lp, |
135 | arch_t::xe_hp, arch_t::xe_hpg, arch_t::xe_hpc); |
136 | ok &= IMPLICATION(with_binary, arch_ >= arch_t::xe_hp); |
137 | |
138 | if (!ok) return status::unimplemented; |
139 | |
140 | // size checks for fused reduction kernels |
141 | if (with_sum_ab()) { |
142 | auto mnk = d->m() * d->n() * d->k(); |
143 | if (arch_ == arch_t::xe_hpc && d->a_type() == f32) |
144 | ok &= (mnk <= 256 * 1024 * 1024); |
145 | |
146 | if (!ok) return status::unimplemented; |
147 | } |
148 | |
149 | // choose kernel |
150 | auto co_type = with_bias() |
151 | ? d->bias_type() |
152 | : with_sum_ab() ? d->sum_ab_type |
153 | : (utils::one_of(eff_a_type(), s8, u8) |
154 | ? s32 |
155 | : d->c_type()); |
156 | |
157 | auto acc_type = utils::one_of(eff_a_type(), s8, u8) ? s32 : f32; |
158 | |
159 | if (d->c_type() == f16 && arch_ < compute::gpu_arch_t::xe_hpg) |
160 | acc_type = data_type::f16; |
161 | |
162 | if (types::data_type_size(acc_type) < 4) { |
163 | // Limited post-op support for low-precision accumulation. |
164 | ok &= !with_binary && IMPLICATION(with_sum_, sum_at_begin_); |
165 | } |
166 | |
167 | kernel_desc_t::compute_mode mode = kernel_desc_t::mode_default; |
168 | |
169 | if (attr()->mayidownconvert(f32, tf32)) |
170 | mode = static_cast<decltype(mode)>( |
171 | mode | kernel_desc_t::mode_tf32); |
172 | |
173 | if (attr()->mayidownconvert(f32, bf16)) |
174 | mode = static_cast<decltype(mode)>( |
175 | mode | kernel_desc_t::mode_bf16x1); |
176 | |
177 | status = kernel_desc_.select_kernel(arch_, stepping, |
178 | dev_info_->eu_count(), mode, batch_dims(), eff_transa(), |
179 | eff_transb(), eff_trans_bias(), swap_ab(), |
180 | with_a_zero_points(), with_b_zero_points(), |
181 | with_c_zero_points(), with_bias(), sum_ab(), alpha(), |
182 | beta(), post_ops_, eff_a_type(), eff_b_type(), |
183 | desc()->c_type(), co_type, acc_type, eff_align_a(), |
184 | eff_align_b(), align_c(), eff_m(), eff_n(), d->k(), |
185 | eff_lda(), eff_ldb(), d->ldc(), d->batch()); |
186 | |
187 | if (status != status::success) return status; |
188 | |
189 | // global k-parallel kernels don't support post-ops. |
190 | // use global k-parallel kernels only with f32 accumulation |
191 | bool k_parallel_global = kernel_desc_.driver_info()->kParallel; |
192 | bool with_eltwise = (post_ops_.find(eltwise) != -1); |
193 | |
194 | ok &= IMPLICATION(k_parallel_global, |
195 | !with_bias() && !with_eltwise && !with_binary |
196 | && utils::one_of(d->c_type(), f32, s32)); |
197 | |
198 | if (!ok) return status::unimplemented; |
199 | |
200 | return status::success; |
201 | } |
202 | |
203 | status_t query(query_t what, int idx, void *result) const override { |
204 | switch ((int)what) { |
205 | case (int)query::preferred_gpu_threads_per_eu: { |
206 | int grfs = kernel_desc_.driver_info()->grfCount; |
207 | *(int *)result = (grfs > 128) ? 4 : 8; |
208 | break; |
209 | } |
210 | default: return gpu_gemm_pd_t::query(what, idx, result); |
211 | } |
212 | return status::success; |
213 | } |
214 | |
215 | bool set_default_formats() { |
216 | using namespace data_type; |
217 | using namespace format_tag; |
218 | using arch_t = compute::gpu_arch_t; |
219 | |
220 | auto d = desc(); |
221 | |
222 | auto m = d->m(); |
223 | auto n = d->n(); |
224 | auto k = d->k(); |
225 | auto a_t = d->a_type(); |
226 | auto b_t = d->b_type(); |
227 | auto c_t = d->c_type(); |
228 | auto a_t_sz = types::data_type_size(a_t); |
229 | auto b_t_sz = types::data_type_size(b_t); |
230 | |
231 | bool is_f16 = utils::everyone_is(f16, a_t, b_t, c_t); |
232 | bool is_xe_hp_plus = arch_ >= arch_t::xe_hp; |
233 | |
234 | // Rename memory descriptors following column major format. |
235 | auto &a_desc = desc_.b_desc; |
236 | auto &b_desc = desc_.a_desc; |
237 | auto &c_desc = desc_.c_desc; |
238 | |
239 | memory_desc_wrapper a_mdw(&a_desc); |
240 | memory_desc_wrapper b_mdw(&b_desc); |
241 | memory_desc_wrapper c_mdw(&c_desc); |
242 | |
243 | bool a_any = a_mdw.format_any(); |
244 | bool b_any = b_mdw.format_any(); |
245 | bool c_any = c_mdw.format_any(); |
246 | |
247 | if (!a_any && !is_md_gemm_compatible_plain_format(&a_desc)) |
248 | return false; |
249 | if (!b_any && !is_md_gemm_compatible_plain_format(&b_desc)) |
250 | return false; |
251 | if (!c_any && !is_md_gemm_compatible_plain_format(&c_desc, true)) |
252 | return false; |
253 | |
254 | bool is_a_trans = (desc()->transa() == dnnl_trans); |
255 | bool is_b_trans = (desc()->transb() == dnnl_trans); |
256 | |
257 | auto lda = is_a_trans ? m : k; |
258 | auto ldb = is_b_trans ? k : n; |
259 | |
260 | auto is_aligned = [](dim_t ld, size_t sz, int byte) { |
261 | return ld * sz % byte == 0; |
262 | }; |
263 | |
264 | bool a_4B_aligned = is_aligned(lda, a_t_sz, 4); |
265 | bool b_4B_aligned = is_aligned(ldb, b_t_sz, 4); |
266 | bool ab_4B_aligned = a_4B_aligned && b_4B_aligned; |
267 | |
268 | bool a_tn_4B_aligned = is_aligned(k, a_t_sz, 4); |
269 | bool b_tn_4B_aligned = is_aligned(k, b_t_sz, 4); |
270 | bool ab_tn_4B_aligned = a_tn_4B_aligned && b_tn_4B_aligned; |
271 | |
272 | bool use_tn = (m <= 32 || n <= 32) && !ab_4B_aligned |
273 | && ab_tn_4B_aligned; |
274 | |
275 | bool batch = d->is_batched(); |
276 | |
277 | auto dotrans = batch ? acb : ba; |
278 | auto notrans = batch ? abc : ab; |
279 | |
280 | if (is_f16 && is_xe_hp_plus && use_tn) { |
281 | if (a_any && b_any) { |
282 | CHECK(memory_desc_init_by_tag(a_desc, dotrans)); |
283 | CHECK(memory_desc_init_by_tag(b_desc, notrans)); |
284 | } else if (a_any && !is_b_trans) { |
285 | CHECK(memory_desc_init_by_tag(a_desc, dotrans)); |
286 | } else if (b_any && is_a_trans) { |
287 | CHECK(memory_desc_init_by_tag(b_desc, notrans)); |
288 | } |
289 | } |
290 | |
291 | return gpu_gemm_pd_t::set_default_formats(); |
292 | } |
293 | |
294 | float alpha() const { return 1.0f; } |
295 | |
296 | float beta() const { return beta_; } |
297 | |
298 | bool with_bias() const { |
299 | return desc()->bias_type() != data_type::undef && !bias_via_binary_; |
300 | } |
301 | |
302 | int bias_cmask() const { |
303 | unsigned char to_cmask[8] = {0, 4, 2, 6, 1, 5, 3, 7}; |
304 | assert(unsigned(desc()->bias_mask()) < 8); |
305 | return with_bias() ? to_cmask[desc()->bias_mask() & 7] : -1; |
306 | } |
307 | |
308 | sum_ab_t sum_ab() const { return desc()->sum_ab; } |
309 | |
310 | bool with_sum_ab() const { return sum_ab() != sum_ab::sum_none; } |
311 | |
312 | int sum_ab_cmask() const { |
313 | switch (sum_ab()) { |
314 | default: |
315 | case sum_ab::sum_none: return 0; |
316 | case sum_ab::sum_a_row: return 1; |
317 | case sum_ab::sum_b_col: return 2; |
318 | } |
319 | } |
320 | |
321 | bool with_a_zero_points() const { return a_zp_; } |
322 | bool with_b_zero_points() const { return b_zp_; } |
323 | |
324 | bool with_c_zero_points() const { |
325 | return !attr()->zero_points_.has_default_values(DNNL_ARG_DST); |
326 | } |
327 | |
328 | bool swap_ab() const { return swap_ab_; } |
329 | |
330 | int batch_dims() const { |
331 | return nstl::max(desc()->c_desc.ndims - 2, 0); |
332 | } |
333 | |
334 | int align_a() const { |
335 | return int(utils::max_pow2_div( |
336 | types::data_type_size(desc()->a_type()) * desc()->lda())); |
337 | } |
338 | int align_b() const { |
339 | return int(utils::max_pow2_div( |
340 | types::data_type_size(desc()->b_type()) * desc()->ldb())); |
341 | } |
342 | int align_c() const { |
343 | return int(utils::max_pow2_div( |
344 | types::data_type_size(desc()->c_type()) * desc()->ldc())); |
345 | } |
346 | |
347 | int eff_align_a() const { return !swap_ab() ? align_a() : align_b(); } |
348 | int eff_align_b() const { return !swap_ab() ? align_b() : align_a(); } |
349 | bool eff_transa() const { |
350 | return !swap_ab() ? (desc()->transa() == dnnl_trans) |
351 | : (desc()->transb() == dnnl_notrans); |
352 | } |
353 | bool eff_transb() const { |
354 | return !swap_ab() ? (desc()->transb() == dnnl_trans) : false; |
355 | } |
356 | bool eff_trans_bias() const { |
357 | return swap_ab() ? (desc()->trans_bias() == dnnl_notrans) |
358 | : (desc()->trans_bias() == dnnl_trans); |
359 | } |
360 | dim_t eff_m() const { return !swap_ab() ? desc()->m() : desc()->n(); } |
361 | dim_t eff_n() const { return !swap_ab() ? desc()->n() : desc()->m(); } |
362 | dim_t eff_lda() const { |
363 | return !swap_ab() ? desc()->lda() : desc()->ldb(); |
364 | } |
365 | dim_t eff_ldb() const { |
366 | return !swap_ab() ? desc()->ldb() : desc()->lda(); |
367 | } |
368 | data_type_t eff_a_type() const { |
369 | return !swap_ab() ? desc()->a_type() : desc()->b_type(); |
370 | } |
371 | data_type_t eff_b_type() const { |
372 | return !swap_ab() ? desc()->b_type() : desc()->a_type(); |
373 | } |
374 | const gen_gemm_nocopy_kernel_desc_t *kernel_desc() const { |
375 | return &kernel_desc_; |
376 | } |
377 | |
378 | size_t dyn_offset_a = 0; |
379 | size_t dyn_offset_b = 0; |
380 | size_t dyn_offset_c = 0; |
381 | size_t dyn_offset_co = 0; |
382 | |
383 | bool swap_ab_ = false; |
384 | bool a_zp_ = false, b_zp_ = false; |
385 | |
386 | const compute::device_info_t *dev_info_; |
387 | compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown; |
388 | |
389 | kernel_desc_t kernel_desc_; |
390 | }; |
391 | |
392 | gen_gemm_t(const pd_t *apd) : gpu_gemm_t(apd) {} |
393 | |
394 | status_t init(engine_t *engine) override { return init_nocopy(engine); } |
395 | |
396 | status_t init_nocopy(engine_t *engine) { |
397 | using kernel_t = gen_gemm_kernel_t; |
398 | using namespace data_type; |
399 | |
400 | auto kd = pd()->kernel_desc(); |
401 | kernel_t kernel(*kd); |
402 | |
403 | create_kernel(engine, &nocopy_kernel_, &kernel); |
404 | |
405 | scalar_type_ = kd->scalar_type(); |
406 | |
407 | if (get_verbose() >= 2) { |
408 | auto info = kd->driver_info(); |
409 | printf("onednn_verbose,info,gpu,gemm,kernel:%dx%d,%dx%dx%d\n" , |
410 | info->unroll[LoopM], info->unroll[LoopN], info->wg[LoopM], |
411 | info->wg[LoopN], info->wg[LoopK]); |
412 | } |
413 | |
414 | return status::success; |
415 | } |
416 | |
417 | status_t execute(const gemm_exec_ctx_t &ctx) const override; |
418 | |
419 | private: |
420 | status_t launch_nocopy(const gemm_exec_ctx_t &ctx, |
421 | compute::compute_stream_t *s, const memory_storage_t &a, |
422 | const memory_storage_t &b, const memory_storage_t &c, |
423 | const memory_storage_t *ao, const memory_storage_t *bo, |
424 | const memory_storage_t &co, int po_count, |
425 | const memory_storage_t **po_src, int64_t offset_a, int64_t offset_b, |
426 | int64_t offset_c, int32_t offset_co, int32_t *offset_po_src, |
427 | int32_t lda, int32_t ldb, int32_t ldc, int32_t m, int32_t n, |
428 | int32_t k, int32_t k0, float alpha, float beta, int32_t cmask, |
429 | bool last_k_block, bool swapab, bool disable_hilbert) const; |
430 | |
431 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
432 | const CommonDriverInfo *nocopy_info() const { |
433 | return pd()->kernel_desc()->driver_info(); |
434 | } |
435 | |
436 | compute::kernel_t nocopy_kernel_; |
437 | compute::scalar_type_t scalar_type_; |
438 | }; |
439 | |
440 | } // namespace jit |
441 | } // namespace gpu |
442 | } // namespace impl |
443 | } // namespace dnnl |
444 | #endif |
445 | |
446 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
447 | |