1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/gemm/gen_gemm.hpp" |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/dnnl_traits.hpp" |
20 | #include "common/float16.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | #include "gpu/jit/gemm/gemm_walk_orders.hpp" |
24 | #include "gpu/jit/gemm/gen_gemm_kernel_common.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace gpu { |
29 | namespace jit { |
30 | |
31 | status_t gen_gemm_t::launch_nocopy(const gemm_exec_ctx_t &ctx, |
32 | compute::compute_stream_t *compute_stream, const memory_storage_t &a, |
33 | const memory_storage_t &b, const memory_storage_t &c, |
34 | const memory_storage_t *ao, const memory_storage_t *bo, |
35 | const memory_storage_t &co, int po_count, |
36 | const memory_storage_t **po_srcs, int64_t offset_a, int64_t offset_b, |
37 | int64_t offset_c, int32_t offset_co, int32_t *offset_po_src, |
38 | int32_t lda, int32_t ldb, int32_t ldc, int32_t m, int32_t n, int32_t k, |
39 | int32_t k0, float alpha, float beta, int32_t cmask, bool last_k_block, |
40 | bool swapab, bool disable_hilbert) const { |
41 | |
42 | uint32_t flags = 0; |
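    // The k dimension may be split either across workgroups (kParallel) or
    // across threads within a single workgroup (kParallelLocal).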
43 | bool k_parallel |
44 | = (nocopy_info()->kParallel || nocopy_info()->kParallelLocal); |
45 | |
46 | auto problem = pd()->kernel_desc()->problem(); |
47 | |
48 | auto stride_a0 = int32_t(pd()->desc()->stride_a(0)); |
49 | auto stride_b0 = int32_t(pd()->desc()->stride_b(0)); |
50 | auto stride_c0 = int32_t(pd()->desc()->stride_c(0)); |
51 | |
52 | auto stride_a1 = int32_t(pd()->desc()->stride_a(1)); |
53 | auto stride_b1 = int32_t(pd()->desc()->stride_b(1)); |
54 | auto stride_c1 = int32_t(pd()->desc()->stride_c(1)); |
55 | |
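    // If A and B were swapped, their batch strides swap with them.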
56 | if (swapab) { |
57 | std::swap(stride_a0, stride_b0); |
58 | std::swap(stride_a1, stride_b1); |
59 | } |
60 | |
61 | if (!last_k_block) flags |= FlagNonfinalKBlock; |
62 | if (cmask & 1) flags |= FlagCOColumn; |
63 | if (cmask & 2) flags |= FlagCORow; |
64 | |
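    // Assemble the kernel argument list. Argument order and presence must
    // match the interface of the generated kernel exactly.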
65 | compute::kernel_arg_list_t arg_list; |
66 | int argn = 0; |
67 | |
68 | arg_list.set(argn++, a); |
69 | arg_list.set(argn++, b); |
70 | arg_list.set(argn++, c); |
71 | arg_list.set(argn++, offset_a); |
72 | arg_list.set(argn++, offset_b); |
73 | arg_list.set(argn++, offset_c); |
74 | arg_list.set(argn++, lda); |
75 | arg_list.set(argn++, ldb); |
76 | arg_list.set(argn++, ldc); |
77 | arg_list.set(argn++, m); |
78 | arg_list.set(argn++, n); |
79 | arg_list.set(argn++, k); |
80 | |
81 | set_scalar_arg_cvt(arg_list, argn++, alpha, scalar_type_); |
82 | set_scalar_arg_cvt(arg_list, argn++, beta, scalar_type_); |
83 | |
84 | if (pd()->with_a_zero_points()) arg_list.set(argn++, *ao); |
85 | if (pd()->with_b_zero_points()) arg_list.set(argn++, *bo); |
86 | if (pd()->with_c_zero_points() || pd()->with_bias() |
87 | || pd()->with_sum_ab()) { |
88 | arg_list.set(argn++, co); |
89 | arg_list.set(argn++, offset_co); |
90 | if (pd()->with_bias()) { |
91 | int32_t ldco = pd()->desc()->ld_bias(); |
92 | arg_list.set(argn++, ldco); |
93 | } |
94 | } |
95 | for (int i = 0; i < po_count; i++) { |
96 | if (!po_srcs[i]) continue; |
97 | arg_list.set(argn++, *po_srcs[i]); |
98 | arg_list.set(argn++, offset_po_src[i]); |
99 | |
100 | if (problem->binaryRow[i] && problem->binaryCol[i]) |
101 | arg_list.set(argn++, int32_t(pd()->ld_binary(i))); |
102 | } |
103 | arg_list.set(argn++, flags); |
104 | if (k_parallel) arg_list.set(argn++, k0); |
105 | |
106 | if (pd()->batch_dims() >= 1) { |
107 | arg_list.set(argn++, stride_a0); |
108 | arg_list.set(argn++, stride_b0); |
109 | arg_list.set(argn++, stride_c0); |
110 | for (int i = 0; i < po_count; i++) |
111 | if (problem->binaryBatch[i]) |
112 | arg_list.set(argn++, int32_t(pd()->stride_binary(i, 0))); |
113 | } |
114 | |
115 | if (pd()->batch_dims() >= 2) { |
116 | auto batchSize1 = uint32_t(pd()->desc()->c_desc.dims[1]); |
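        // Precompute a fixed-point reciprocal of the batch size so the kernel
        // can replace division by batchSize1 with a multiply and shift
        // (shift amount: 32 + ilog2(batchSize1)).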
117 | uint32_t recipBatchSize1 = (uint32_t)utils::div_up( |
118 | uint64_t(0x100000000) << math::ilog2q(batchSize1), batchSize1); |
119 | arg_list.set(argn++, stride_a1); |
120 | arg_list.set(argn++, stride_b1); |
121 | arg_list.set(argn++, stride_c1); |
122 | for (int i = 0; i < po_count; i++) |
123 | if (problem->binaryBatch[i]) |
124 | arg_list.set(argn++, int32_t(pd()->stride_binary(i, 1))); |
125 | arg_list.set(argn++, batchSize1); |
126 | arg_list.set(argn++, recipBatchSize1); |
127 | } |
128 | |
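    // Global size: one thread per m/n unroll tile. K-parallel kernels also
    // split k into ceil(k / k0) chunks.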
129 | size_t gws[3] = {0, 0, 1}; |
130 | |
131 | gws[0] = utils::div_up(m, nocopy_info()->unroll[LoopM]); |
132 | gws[1] = utils::div_up(n, nocopy_info()->unroll[LoopN]); |
133 | gws[2] = k_parallel ? nstl::max(1, utils::div_up(k, k0)) : 1; |
134 | |
135 | size_t lws[3] = {size_t(nocopy_info()->wg[0]), size_t(nocopy_info()->wg[1]), |
136 | size_t(nocopy_info()->wg[2])}; |
137 | |
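    // Kernels walking tiles in n-m-k order expect the m and n dispatch
    // dimensions exchanged.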
138 | if (nocopy_info()->isNMK()) { |
139 | std::swap(lws[0], lws[1]); |
140 | std::swap(gws[0], gws[1]); |
141 | } |
142 | |
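    // Fused-EU kernels pair threads in dimension 0; keep its global size even.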
143 | if (nocopy_info()->fusedEUs() && (lws[0] > 1)) |
144 | gws[0] = utils::rnd_up(gws[0], 2); |
145 | |
146 | lws[2] = nstl::min(lws[2], gws[2]); |
147 | |
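    // Find the last dimension in which both global and local sizes exceed 1;
    // it bounds the power-of-two rounding workaround below.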
148 | int last_non_1 = 2; |
149 | for (; last_non_1 >= 0 && (gws[last_non_1] == 1 || lws[last_non_1] == 1); |
150 | last_non_1--) |
151 | ; |
152 | |
153 | for (int d = 0; d < 3; d++) { |
154 | if (nocopy_info()->fixedWG() || (gws[d] > lws[d])) |
155 | gws[d] = utils::rnd_up(gws[d], lws[d]); |
156 | else { |
            // Workaround to avoid local ID reordering until
            // reqd_walk_group_order is implemented in the UMD.
158 | if (pd()->arch_ >= compute::gpu_arch_t::xe_hp && d < last_non_1) |
159 | gws[d] = utils::rnd_up_pow2(gws[d]); |
160 | lws[d] = gws[d]; |
161 | } |
162 | } |
163 | |
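    // wgExpand adds extra cooperating threads per workgroup in dimension 1;
    // local and global sizes scale together, leaving the tile grid unchanged.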
164 | lws[1] *= nocopy_info()->wgExpand; |
165 | gws[1] *= nocopy_info()->wgExpand; |
166 | |
167 | gws[2] *= pd()->desc()->batch(); |
168 | |
169 | gemm_linear_order_args(arg_list, argn, lws, gws, m, n, disable_hilbert, |
170 | *nocopy_info(), pd()->dev_info_); |
171 | |
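    // Kernels using SLM for k-reduction take a dynamic SLM allocation sized
    // for all local k slices.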
172 | if (nocopy_info()->perKSLM > 0) { |
173 | size_t slm = nocopy_info()->slm; |
174 | if (lws[2] > 1) slm = nstl::max(slm, nocopy_info()->perKSLM * lws[2]); |
175 | arg_list.set(argn++, slm, nullptr); |
176 | } |
177 | |
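    // Dimension 0 has been counted in subgroups so far; convert to work-items.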
178 | lws[0] *= nocopy_info()->subgroupSize; |
179 | gws[0] *= nocopy_info()->subgroupSize; |
180 | |
181 | auto nd_range = compute::nd_range_t(gws, lws); |
182 | return parallel_for(ctx, nd_range, nocopy_kernel_, arg_list); |
183 | } |
184 | |
185 | status_t gen_gemm_t::execute(const gemm_exec_ctx_t &ctx) const { |
186 | auto *compute_stream |
187 | = utils::downcast<compute::compute_stream_t *>(ctx.stream()); |
188 | |
189 | const auto d = pd()->desc(); |
190 | const auto &problem = *pd()->kernel_desc()->problem(); |
191 | |
192 | const bool swapab = pd()->swap_ab(); |
193 | |
194 | auto a_type = pd()->eff_a_type(); |
195 | auto b_type = pd()->eff_b_type(); |
196 | auto c_type = d->c_type(); |
197 | |
198 | const auto m = pd()->eff_m(); |
199 | const auto n = pd()->eff_n(); |
200 | auto k = d->k(); |
201 | |
202 | const bool transa = pd()->eff_transa(); |
203 | const bool transb = pd()->eff_transb(); |
204 | |
205 | const auto lda = pd()->eff_lda(); |
206 | const auto ldb = pd()->eff_ldb(); |
207 | auto ldc = d->ldc(); |
208 | auto ldco = pd()->with_bias() ? d->ld_bias() : 0; |
209 | |
210 | auto alpha = pd()->alpha(); |
211 | auto beta = pd()->beta(); |
212 | |
213 | bool k_parallel_global = nocopy_info()->kParallel; |
214 | bool k_parallel_local = nocopy_info()->kParallelLocal; |
215 | |
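    // Note the cross mapping: unless swapab is set, the kernel's a is the
    // descriptor's B and vice versa (the kernel computes the transposed
    // problem); swapab selects the direct mapping instead.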
216 | auto &a = swapab ? GEMM_CTX_ARG_STORAGE(a) : GEMM_CTX_ARG_STORAGE(b); |
217 | auto &b = swapab ? GEMM_CTX_ARG_STORAGE(b) : GEMM_CTX_ARG_STORAGE(a); |
218 | auto &c = GEMM_CTX_ARG_STORAGE(c); |
219 | auto &c_zp = GEMM_CTX_ARG_STORAGE(c_zero_point); |
220 | auto &bias = GEMM_CTX_ARG_STORAGE(bias); |
221 | auto &sum_ab = GEMM_CTX_ARG_STORAGE(sum_ab); |
222 | auto *co = &c_zp; |
223 | memory_storage_t *ao = nullptr, *bo = nullptr; |
224 | |
225 | const memory_storage_t *po_srcs[GEMM_MAX_PO]; |
226 | |
227 | int po_count = pd()->post_ops()->len(); |
228 | assert(po_count <= GEMM_MAX_PO); |
229 | |
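    // Resolve the source buffer for each post-op: binary post-ops read
    // user-provided tensors, while bias and scales are routed through the
    // same binary mechanism.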
230 | for (int i = 0; i < po_count; i++) { |
231 | auto &src = pd()->binary_srcs()[i]; |
232 | switch (src.type) { |
233 | case pd_t::binary_src_t::binary: |
234 | po_srcs[i] |
235 | = ctx.args() |
236 | .exec_args |
237 | .at(DNNL_ARG_ATTR_MULTIPLE_POST_OP(src.index) |
238 | | DNNL_ARG_SRC_1) |
239 | .mem->memory_storage(); |
240 | break; |
241 | case pd_t::binary_src_t::bias: po_srcs[i] = &bias; break; |
242 | case pd_t::binary_src_t::scales: |
243 | switch (src.index) { |
244 | case DNNL_ARG_WEIGHTS: |
245 | po_srcs[i] = &GEMM_CTX_ARG_STORAGE(a_scales); |
246 | break; |
247 | case DNNL_ARG_SRC: |
248 | po_srcs[i] = &GEMM_CTX_ARG_STORAGE(b_scales); |
249 | break; |
250 | case DNNL_ARG_DST: |
251 | po_srcs[i] = &GEMM_CTX_ARG_STORAGE(c_scales); |
252 | break; |
253 | default: |
254 | po_srcs[i] = nullptr; |
255 | assert(!"invalid scale type" ); |
256 | break; |
257 | } |
258 | break; |
259 | default: po_srcs[i] = nullptr; break; |
260 | } |
261 | } |
262 | |
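    // Convert byte offsets from the storages into element offsets in each
    // tensor's data type.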
263 | size_t off_a0 |
264 | = a.offset() / types::data_type_size(a_type) + pd()->dyn_offset_a; |
265 | size_t off_b0 |
266 | = b.offset() / types::data_type_size(b_type) + pd()->dyn_offset_b; |
267 | size_t off_c0 |
268 | = c.offset() / types::data_type_size(c_type) + pd()->dyn_offset_c; |
269 | size_t off_co0 = 0; |
270 | |
271 | int32_t po_offsets0[GEMM_MAX_PO] = {0}, po_offsets[GEMM_MAX_PO] = {0}; |
272 | for (int i = 0; i < po_count; i++) |
273 | if (po_srcs[i]) |
274 | po_offsets0[i] = po_srcs[i]->offset() / problem.Tbinary[i]; |
275 | |
276 | int cmask = 0; |
277 | |
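    // cmask records along which dimensions the C offset vector (zero point,
    // bias, or sum_ab reduction) varies.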
278 | if (pd()->with_c_zero_points()) { |
279 | off_co0 = co->offset() / types::data_type_size(c_type) |
280 | + pd()->dyn_offset_co; |
281 | pd()->attr()->zero_points_.get(DNNL_ARG_DST, &cmask); |
282 | } else if (pd()->with_bias()) { |
283 | off_co0 = bias.offset() / types::data_type_size(c_type); |
284 | co = &bias; |
285 | cmask = pd()->bias_cmask(); |
286 | } else if (pd()->with_sum_ab()) { |
287 | off_co0 = sum_ab.offset() / types::data_type_size(c_type); |
288 | co = &sum_ab; |
289 | cmask = pd()->sum_ab_cmask(); |
290 | } |
291 | |
292 | if (pd()->with_a_zero_points() || pd()->with_b_zero_points()) { |
293 | ao = &GEMM_CTX_ARG_STORAGE(a_zero_point); |
294 | bo = &GEMM_CTX_ARG_STORAGE(b_zero_point); |
295 | if (swapab) std::swap(ao, bo); |
296 | } |
297 | |
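    // Swapping A and B transposes C, so exchange the row and column bits of
    // cmask (1 <-> 2; 0 and 3 are unaffected).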
298 | if (swapab) { |
299 | uint8_t swap_table[4] = {0, 2, 1, 3}; |
300 | cmask = (cmask & ~3) | swap_table[cmask & 3]; |
301 | } |
302 | |
    status_t status = status::success;
304 | |
305 | auto block_m = nocopy_info()->blocking[0]; |
306 | auto block_n = nocopy_info()->blocking[1]; |
307 | auto block_k = nocopy_info()->blocking[2]; |
308 | |
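    // For small k, Hilbert-curve walk ordering is disabled and the alternate
    // blocking takes effect.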
309 | bool disable_hilbert = (k <= 64) && nocopy_info()->isHilbert(); |
310 | if (disable_hilbert) { |
311 | block_m = nocopy_info()->blockingAlt[0]; |
312 | block_n = nocopy_info()->blockingAlt[1]; |
313 | } |
314 | |
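    // Splitting k across multiple launches accumulates partial sums in C;
    // disable it when C is low-precision or a non-sum leading post-op would
    // be applied to partial results.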
315 | if (!utils::one_of(pd()->desc()->c_type(), data_type::f32, data_type::f16)) |
316 | block_k = k; |
317 | if (pd()->post_ops()->len() > 0 |
318 | && pd()->post_ops()->entry_[0].kind != primitive_kind::sum) |
319 | block_k = k; |
320 | |
321 | if (k_parallel_global) |
322 | block_k = pd()->kernel_desc()->aux_params()->k0; |
323 | else if (k_parallel_local) |
324 | block_k = utils::div_up(k, nocopy_info()->wg[2]); |
325 | |
326 | block_m = utils::rnd_up( |
327 | block_m, nocopy_info()->wg[0] * nocopy_info()->unroll[0]); |
328 | block_n = utils::rnd_up( |
329 | block_n, nocopy_info()->wg[1] * nocopy_info()->unroll[1]); |
330 | block_k = utils::rnd_up(block_k, nocopy_info()->unroll[2]); |
331 | block_k = nstl::max(block_k, 2 * nocopy_info()->unroll[2]); |
332 | |
333 | int32_t k0 = 1; |
334 | if (k_parallel_local || k_parallel_global) { |
335 | k0 = block_k; |
336 | block_k = nstl::max<dim_t>(k, 1); |
337 | |
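        // With global k-parallelism, multiple workgroups accumulate into the
        // same C tiles. Apply beta once up front with a k = 0 launch so the
        // accumulating launches can run with beta = 1.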
338 | if (k_parallel_global && beta != 1.0f |
339 | && (k > k0 * nocopy_info()->wg[2])) { |
            status = launch_nocopy(ctx, compute_stream, a, b, c, ao, bo, *co,
                    po_count, po_srcs, off_a0, off_b0, off_c0, int32_t(off_co0),
                    po_offsets0, lda, ldb, ldc, m, n, 0, 1, 1.0f, beta, 0,
                    false, swapab, true);
            if (status) return status;
            beta = 1.0f;
345 | } |
346 | } |
347 | |
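    // Launch one kernel per (block_m x block_n x block_k) chunk, adjusting
    // the base offsets of A, B, C, the C offset vector, and post-op sources.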
348 | for (int64_t Bk = 0; Bk < nstl::max<dim_t>(k, 1); Bk += block_k) { |
349 | int64_t size_k = k - Bk; |
350 | bool last_k_block = (size_k <= block_k); |
351 | if (!last_k_block) size_k = block_k; |
352 | |
353 | for (int64_t Bm = 0; Bm < m; Bm += block_m) { |
354 | int64_t size_m = m - Bm; |
355 | if (size_m > block_m) size_m = block_m; |
356 | |
357 | auto off_a_src |
358 | = off_a0 + (!transa ? (Bm + Bk * lda) : (Bk + Bm * lda)); |
359 | |
360 | for (int64_t Bn = 0; Bn < n; Bn += block_n) { |
361 | int64_t size_n = n - Bn; |
362 | if (size_n > block_n) size_n = block_n; |
363 | |
364 | auto off_b_src = off_b0 |
365 | + (!transb ? (Bk + Bn * ldb) : (Bn + Bk * ldb)); |
366 | |
367 | auto off_c = off_c0 + Bm + Bn * ldc; |
368 | auto off_co = int32_t(off_co0); |
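                // Advance the C offset vector for this block: bit 0 of cmask
                // means it varies along n, bit 1 along m, and both bits a
                // full 2D surface.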
369 | switch (cmask & 3) { |
370 | case 1: off_co += Bn; break; |
371 | case 2: off_co += Bm; break; |
372 | case 3: |
373 | off_co += isColMajor(problem.CO.layout) |
374 | ? (Bn * ldco + Bm) |
375 | : (Bm * ldco + Bn); |
376 | break; |
377 | } |
378 | |
379 | for (int i = 0; i < po_count; i++) { |
380 | po_offsets[i] = po_offsets0[i]; |
381 | bool row = problem.binaryRow[i], col = problem.binaryCol[i]; |
382 | if (row && col) { |
383 | auto ld = pd()->ld_binary(i); |
384 | po_offsets[i] += isColMajor(problem.binary[i].layout) |
385 | ? (Bn * ld + Bm) |
386 | : (Bm * ld + Bn); |
387 | } else if (row) |
388 | po_offsets[i] += Bm; |
389 | else if (col) |
390 | po_offsets[i] += Bn; |
391 | } |
392 | |
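                // Only the first k chunk applies the caller's beta; later
                // chunks accumulate on top of the partial result.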
393 | float eff_beta = (Bk == 0) ? beta : 1.0f; |
394 | status = launch_nocopy(ctx, compute_stream, a, b, c, ao, bo, |
395 | *co, po_count, po_srcs, off_a_src, off_b_src, off_c, |
396 | off_co, po_offsets, lda, ldb, ldc, size_m, size_n, |
397 | size_k, k0, alpha, eff_beta, cmask, last_k_block, |
398 | swapab, disable_hilbert); |
399 | |
400 | if (status) return status; |
401 | } |
402 | } |
403 | } |
404 | |
405 | return status::success; |
406 | } |
407 | |
408 | } // namespace jit |
409 | } // namespace gpu |
410 | } // namespace impl |
411 | } // namespace dnnl |
412 | |
413 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
414 | |