/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/gemm/gen_gemm.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_traits.hpp"
#include "common/float16.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "gpu/jit/gemm/gemm_walk_orders.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel_common.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

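// Enqueue a single no-copy GEMM kernel launch covering one (m x n x k) block
// of the overall problem.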
status_t gen_gemm_t::launch_nocopy(const gemm_exec_ctx_t &ctx,
        compute::compute_stream_t *compute_stream, const memory_storage_t &a,
        const memory_storage_t &b, const memory_storage_t &c,
        const memory_storage_t *ao, const memory_storage_t *bo,
        const memory_storage_t &co, int po_count,
        const memory_storage_t **po_srcs, int64_t offset_a, int64_t offset_b,
        int64_t offset_c, int32_t offset_co, int32_t *offset_po_src,
        int32_t lda, int32_t ldb, int32_t ldc, int32_t m, int32_t n, int32_t k,
        int32_t k0, float alpha, float beta, int32_t cmask, bool last_k_block,
        bool swapab, bool disable_hilbert) const {

    uint32_t flags = 0;
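    // The kernel may split k across workgroups (kParallel) and/or within a
    // workgroup (kParallelLocal).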
    bool k_parallel
            = (nocopy_info()->kParallel || nocopy_info()->kParallelLocal);

    auto problem = pd()->kernel_desc()->problem();

    auto stride_a0 = int32_t(pd()->desc()->stride_a(0));
    auto stride_b0 = int32_t(pd()->desc()->stride_b(0));
    auto stride_c0 = int32_t(pd()->desc()->stride_c(0));

    auto stride_a1 = int32_t(pd()->desc()->stride_a(1));
    auto stride_b1 = int32_t(pd()->desc()->stride_b(1));
    auto stride_c1 = int32_t(pd()->desc()->stride_c(1));

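    // If A and B were swapped when the problem was set up, swap their batch
    // strides to match.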
    if (swapab) {
        std::swap(stride_a0, stride_b0);
        std::swap(stride_a1, stride_b1);
    }

    if (!last_k_block) flags |= FlagNonfinalKBlock;
    if (cmask & 1) flags |= FlagCOColumn;
    if (cmask & 2) flags |= FlagCORow;

    compute::kernel_arg_list_t arg_list;
    int argn = 0;

    arg_list.set(argn++, a);
    arg_list.set(argn++, b);
    arg_list.set(argn++, c);
    arg_list.set(argn++, offset_a);
    arg_list.set(argn++, offset_b);
    arg_list.set(argn++, offset_c);
    arg_list.set(argn++, lda);
    arg_list.set(argn++, ldb);
    arg_list.set(argn++, ldc);
    arg_list.set(argn++, m);
    arg_list.set(argn++, n);
    arg_list.set(argn++, k);

    set_scalar_arg_cvt(arg_list, argn++, alpha, scalar_type_);
    set_scalar_arg_cvt(arg_list, argn++, beta, scalar_type_);

    if (pd()->with_a_zero_points()) arg_list.set(argn++, *ao);
    if (pd()->with_b_zero_points()) arg_list.set(argn++, *bo);
    if (pd()->with_c_zero_points() || pd()->with_bias()
            || pd()->with_sum_ab()) {
        arg_list.set(argn++, co);
        arg_list.set(argn++, offset_co);
        if (pd()->with_bias()) {
            int32_t ldco = pd()->desc()->ld_bias();
            arg_list.set(argn++, ldco);
        }
    }
    for (int i = 0; i < po_count; i++) {
        if (!po_srcs[i]) continue;
        arg_list.set(argn++, *po_srcs[i]);
        arg_list.set(argn++, offset_po_src[i]);

        if (problem->binaryRow[i] && problem->binaryCol[i])
            arg_list.set(argn++, int32_t(pd()->ld_binary(i)));
    }
    arg_list.set(argn++, flags);
    if (k_parallel) arg_list.set(argn++, k0);

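    // Batched GEMM: pass batch strides for A, B, C, and any batched post-op
    // sources.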
    if (pd()->batch_dims() >= 1) {
        arg_list.set(argn++, stride_a0);
        arg_list.set(argn++, stride_b0);
        arg_list.set(argn++, stride_c0);
        for (int i = 0; i < po_count; i++)
            if (problem->binaryBatch[i])
                arg_list.set(argn++, int32_t(pd()->stride_binary(i, 0)));
    }

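    // With two or more batch dimensions, also pass the size of batch
    // dimension 1 together with a precomputed reciprocal the kernel uses to
    // decompose the combined batch index.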
    if (pd()->batch_dims() >= 2) {
        auto batchSize1 = uint32_t(pd()->desc()->c_desc.dims[1]);
        uint32_t recipBatchSize1 = (uint32_t)utils::div_up(
                uint64_t(0x100000000) << math::ilog2q(batchSize1), batchSize1);
        arg_list.set(argn++, stride_a1);
        arg_list.set(argn++, stride_b1);
        arg_list.set(argn++, stride_c1);
        for (int i = 0; i < po_count; i++)
            if (problem->binaryBatch[i])
                arg_list.set(argn++, int32_t(pd()->stride_binary(i, 1)));
        arg_list.set(argn++, batchSize1);
        arg_list.set(argn++, recipBatchSize1);
    }

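    // Global size counts unroll-sized tiles in m and n; dimension 2 holds the
    // k slices when k is split across workgroups.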
    size_t gws[3] = {0, 0, 1};

    gws[0] = utils::div_up(m, nocopy_info()->unroll[LoopM]);
    gws[1] = utils::div_up(n, nocopy_info()->unroll[LoopN]);
    gws[2] = k_parallel ? nstl::max(1, utils::div_up(k, k0)) : 1;

    size_t lws[3] = {size_t(nocopy_info()->wg[0]), size_t(nocopy_info()->wg[1]),
            size_t(nocopy_info()->wg[2])};

    if (nocopy_info()->isNMK()) {
        std::swap(lws[0], lws[1]);
        std::swap(gws[0], gws[1]);
    }

    if (nocopy_info()->fusedEUs() && (lws[0] > 1))
        gws[0] = utils::rnd_up(gws[0], 2);

    lws[2] = nstl::min(lws[2], gws[2]);

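    // Find the last dimension in which both the global and local sizes
    // exceed 1.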
    int last_non_1 = 2;
    for (; last_non_1 >= 0 && (gws[last_non_1] == 1 || lws[last_non_1] == 1);
            last_non_1--)
        ;

    for (int d = 0; d < 3; d++) {
        if (nocopy_info()->fixedWG() || (gws[d] > lws[d]))
            gws[d] = utils::rnd_up(gws[d], lws[d]);
        else {
            // Workaround to avoid local ID reordering until reqd_walk_group_order implemented in UMD.
            if (pd()->arch_ >= compute::gpu_arch_t::xe_hp && d < last_non_1)
                gws[d] = utils::rnd_up_pow2(gws[d]);
            lws[d] = gws[d];
        }
    }

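    // Apply the kernel's workgroup expansion factor and spread the batch over
    // dimension 2.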
    lws[1] *= nocopy_info()->wgExpand;
    gws[1] *= nocopy_info()->wgExpand;

    gws[2] *= pd()->desc()->batch();

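    // Set up workgroup walk-order arguments (Hilbert or linear ordering).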
    gemm_linear_order_args(arg_list, argn, lws, gws, m, n, disable_hilbert,
            *nocopy_info(), pd()->dev_info_);

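    // Size the dynamically allocated SLM: at least the kernel's base amount,
    // scaled per k slice when k is split within the workgroup.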
    if (nocopy_info()->perKSLM > 0) {
        size_t slm = nocopy_info()->slm;
        if (lws[2] > 1) slm = nstl::max(slm, nocopy_info()->perKSLM * lws[2]);
        arg_list.set(argn++, slm, nullptr);
    }

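    // Dimension 0 was counted in subgroups; convert it to work-items.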
    lws[0] *= nocopy_info()->subgroupSize;
    gws[0] *= nocopy_info()->subgroupSize;

    auto nd_range = compute::nd_range_t(gws, lws);
    return parallel_for(ctx, nd_range, nocopy_kernel_, arg_list);
}

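// Execute the GEMM by partitioning (m, n, k) into blocks and launching the
// no-copy kernel on each block.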
status_t gen_gemm_t::execute(const gemm_exec_ctx_t &ctx) const {
    auto *compute_stream
            = utils::downcast<compute::compute_stream_t *>(ctx.stream());

    const auto d = pd()->desc();
    const auto &problem = *pd()->kernel_desc()->problem();

    const bool swapab = pd()->swap_ab();

    auto a_type = pd()->eff_a_type();
    auto b_type = pd()->eff_b_type();
    auto c_type = d->c_type();

    const auto m = pd()->eff_m();
    const auto n = pd()->eff_n();
    auto k = d->k();

    const bool transa = pd()->eff_transa();
    const bool transb = pd()->eff_transb();

    const auto lda = pd()->eff_lda();
    const auto ldb = pd()->eff_ldb();
    auto ldc = d->ldc();
    auto ldco = pd()->with_bias() ? d->ld_bias() : 0;

    auto alpha = pd()->alpha();
    auto beta = pd()->beta();

    bool k_parallel_global = nocopy_info()->kParallel;
    bool k_parallel_local = nocopy_info()->kParallelLocal;

    auto &a = swapab ? GEMM_CTX_ARG_STORAGE(a) : GEMM_CTX_ARG_STORAGE(b);
    auto &b = swapab ? GEMM_CTX_ARG_STORAGE(b) : GEMM_CTX_ARG_STORAGE(a);
    auto &c = GEMM_CTX_ARG_STORAGE(c);
    auto &c_zp = GEMM_CTX_ARG_STORAGE(c_zero_point);
    auto &bias = GEMM_CTX_ARG_STORAGE(bias);
    auto &sum_ab = GEMM_CTX_ARG_STORAGE(sum_ab);
    auto *co = &c_zp;
    const memory_storage_t *ao = nullptr, *bo = nullptr;

    const memory_storage_t *po_srcs[GEMM_MAX_PO];

    int po_count = pd()->post_ops()->len();
    assert(po_count <= GEMM_MAX_PO);

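    // Bind each post-op to its source buffer: a binary post-op tensor, the
    // bias, or an A/B/C scale.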
    for (int i = 0; i < po_count; i++) {
        auto &src = pd()->binary_srcs()[i];
        switch (src.type) {
            case pd_t::binary_src_t::binary:
                po_srcs[i]
                        = ctx.args()
                                  .exec_args
                                  .at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(src.index)
                                          | DNNL_ARG_SRC_1)
                                  .mem->memory_storage();
                break;
            case pd_t::binary_src_t::bias: po_srcs[i] = &bias; break;
            case pd_t::binary_src_t::scales:
                switch (src.index) {
                    case DNNL_ARG_WEIGHTS:
                        po_srcs[i] = &GEMM_CTX_ARG_STORAGE(a_scales);
                        break;
                    case DNNL_ARG_SRC:
                        po_srcs[i] = &GEMM_CTX_ARG_STORAGE(b_scales);
                        break;
                    case DNNL_ARG_DST:
                        po_srcs[i] = &GEMM_CTX_ARG_STORAGE(c_scales);
                        break;
                    default:
                        po_srcs[i] = nullptr;
                        assert(!"invalid scale type");
                        break;
                }
                break;
            default: po_srcs[i] = nullptr; break;
        }
    }

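    // Convert buffer byte offsets into element offsets for the kernel.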
    size_t off_a0
            = a.offset() / types::data_type_size(a_type) + pd()->dyn_offset_a;
    size_t off_b0
            = b.offset() / types::data_type_size(b_type) + pd()->dyn_offset_b;
    size_t off_c0
            = c.offset() / types::data_type_size(c_type) + pd()->dyn_offset_c;
    size_t off_co0 = 0;

    int32_t po_offsets0[GEMM_MAX_PO] = {0}, po_offsets[GEMM_MAX_PO] = {0};
    for (int i = 0; i < po_count; i++)
        if (po_srcs[i])
            po_offsets0[i] = po_srcs[i]->offset() / problem.Tbinary[i];

    int cmask = 0;

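    // The CO argument carries C zero points, bias, or A/B row/column sums;
    // select the source and the corresponding cmask.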
    if (pd()->with_c_zero_points()) {
        off_co0 = co->offset() / types::data_type_size(c_type)
                + pd()->dyn_offset_co;
        pd()->attr()->zero_points_.get(DNNL_ARG_DST, &cmask);
    } else if (pd()->with_bias()) {
        off_co0 = bias.offset() / types::data_type_size(c_type);
        co = &bias;
        cmask = pd()->bias_cmask();
    } else if (pd()->with_sum_ab()) {
        off_co0 = sum_ab.offset() / types::data_type_size(c_type);
        co = &sum_ab;
        cmask = pd()->sum_ab_cmask();
    }

    if (pd()->with_a_zero_points() || pd()->with_b_zero_points()) {
        ao = &GEMM_CTX_ARG_STORAGE(a_zero_point);
        bo = &GEMM_CTX_ARG_STORAGE(b_zero_point);
        if (swapab) std::swap(ao, bo);
    }

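    // Swapping A and B implicitly transposes C, so exchange the row/column
    // bits of cmask.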
    if (swapab) {
        uint8_t swap_table[4] = {0, 2, 1, 3};
        cmask = (cmask & ~3) | swap_table[cmask & 3];
    }

    status_t status;

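    // Start from the kernel's preferred blocking and adjust it for this
    // problem.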
    auto block_m = nocopy_info()->blocking[0];
    auto block_n = nocopy_info()->blocking[1];
    auto block_k = nocopy_info()->blocking[2];

    bool disable_hilbert = (k <= 64) && nocopy_info()->isHilbert();
    if (disable_hilbert) {
        block_m = nocopy_info()->blockingAlt[0];
        block_n = nocopy_info()->blockingAlt[1];
    }

    if (!utils::one_of(pd()->desc()->c_type(), data_type::f32, data_type::f16))
        block_k = k;
    if (pd()->post_ops()->len() > 0
            && pd()->post_ops()->entry_[0].kind != primitive_kind::sum)
        block_k = k;

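    // When k is split, take the k block from the kernel descriptor (global
    // split) or derive it from the workgroup size (local split).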
    if (k_parallel_global)
        block_k = pd()->kernel_desc()->aux_params()->k0;
    else if (k_parallel_local)
        block_k = utils::div_up(k, nocopy_info()->wg[2]);

    block_m = utils::rnd_up(
            block_m, nocopy_info()->wg[0] * nocopy_info()->unroll[0]);
    block_n = utils::rnd_up(
            block_n, nocopy_info()->wg[1] * nocopy_info()->unroll[1]);
    block_k = utils::rnd_up(block_k, nocopy_info()->unroll[2]);
    block_k = nstl::max(block_k, 2 * nocopy_info()->unroll[2]);

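    // With k splitting, a single launch covers all of k. If beta != 1,
    // pre-scale C with a k = 0 launch so every k slice can then accumulate
    // with beta = 1.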
    int32_t k0 = 1;
    if (k_parallel_local || k_parallel_global) {
        k0 = block_k;
        block_k = nstl::max<dim_t>(k, 1);

        if (k_parallel_global && beta != 1.0f
                && (k > k0 * nocopy_info()->wg[2])) {
            status = launch_nocopy(ctx, compute_stream, a, b, c, ao, bo, *co,
                    po_count, po_srcs, off_a0, off_b0, off_c0, int32_t(off_co0),
                    po_offsets0, lda, ldb, ldc, m, n, 0, 1, 1.0f, beta, 0,
                    false, swapab, true);
            if (status) return status;
            beta = 1.0f;
        }
    }

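    // Blocked loops over k (outermost), m, and n. beta is applied only on the
    // first k block; later k blocks accumulate into C with beta = 1.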
    for (int64_t Bk = 0; Bk < nstl::max<dim_t>(k, 1); Bk += block_k) {
        int64_t size_k = k - Bk;
        bool last_k_block = (size_k <= block_k);
        if (!last_k_block) size_k = block_k;

        for (int64_t Bm = 0; Bm < m; Bm += block_m) {
            int64_t size_m = m - Bm;
            if (size_m > block_m) size_m = block_m;

            auto off_a_src
                    = off_a0 + (!transa ? (Bm + Bk * lda) : (Bk + Bm * lda));

            for (int64_t Bn = 0; Bn < n; Bn += block_n) {
                int64_t size_n = n - Bn;
                if (size_n > block_n) size_n = block_n;

                auto off_b_src = off_b0
                        + (!transb ? (Bk + Bn * ldb) : (Bn + Bk * ldb));

                auto off_c = off_c0 + Bm + Bn * ldc;
                auto off_co = int32_t(off_co0);
                switch (cmask & 3) {
                    case 1: off_co += Bn; break;
                    case 2: off_co += Bm; break;
                    case 3:
                        off_co += isColMajor(problem.CO.layout)
                                ? (Bn * ldco + Bm)
                                : (Bm * ldco + Bn);
                        break;
                }

                for (int i = 0; i < po_count; i++) {
                    po_offsets[i] = po_offsets0[i];
                    bool row = problem.binaryRow[i], col = problem.binaryCol[i];
                    if (row && col) {
                        auto ld = pd()->ld_binary(i);
                        po_offsets[i] += isColMajor(problem.binary[i].layout)
                                ? (Bn * ld + Bm)
                                : (Bm * ld + Bn);
                    } else if (row)
                        po_offsets[i] += Bm;
                    else if (col)
                        po_offsets[i] += Bn;
                }

                float eff_beta = (Bk == 0) ? beta : 1.0f;
                status = launch_nocopy(ctx, compute_stream, a, b, c, ao, bo,
                        *co, po_count, po_srcs, off_a_src, off_b_src, off_c,
                        off_co, po_offsets, lda, ldb, ldc, size_m, size_n,
                        size_k, k0, alpha, eff_beta, cmask, last_k_block,
                        swapab, disable_hilbert);

                if (status) return status;
            }
        }
    }

    return status::success;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s