1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_GEMM_GEMM_WALK_ORDERS_HPP
18#define GPU_JIT_GEMM_GEMM_WALK_ORDERS_HPP
19
20#include "common/utils.hpp"
21#include "gpu/jit/gemm/gen_gemm_kernel.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace gpu {
26namespace jit {
27
// Append the kernel arguments describing a "linear" (flattened 1D) workgroup
// walk order -- Hilbert-curve or boustrophedon -- and rewrite the ND-range to
// a 1D dispatch. No-op if the kernel was not built for a linear order.
//
//  arg_list        kernel argument list to append to.
//  argn            index of the next free argument slot; advanced in place.
//  lws, gws        local/global work sizes; gws is rewritten at the end so
//                  that dimension 0 spans all workgroups.
//  m, n            GEMM problem dimensions (not referenced in this routine).
//  disable_hilbert if true, pass a "bail" value that short-circuits the
//                  Hilbert ordering on the kernel side.
//  info            driver info: walk-order kind, wg/unroll shape, GRF count.
//  dev_info        device info used to estimate how many thread groups can
//                  be resident on the device at once.
inline void gemm_linear_order_args(compute::kernel_arg_list_t &arg_list,
        int &argn, const size_t (&lws)[3], size_t (&gws)[3], int32_t m,
        int32_t n, bool disable_hilbert, const CommonDriverInfo &info,
        const compute::device_info_t *dev_info) {
    if (!info.isLinearOrder()) return;

    // In NMK layouts the m/n workgroup dimensions are swapped in lws/gws.
    int m_index = info.isNMK() ? 1 : 0;
    int n_index = info.isNMK() ? 0 : 1;
    // Workgroup counts along m and n, and the total workgroup count.
    auto groups_m = uint32_t(gws[m_index] / lws[m_index]);
    auto groups_n = uint32_t(gws[n_index] / lws[n_index]);
    auto group_count = groups_m * groups_n;

    // Estimate how many thread groups the device can run concurrently:
    // per-subslice thread capacity (halved capacity in large-GRF mode is
    // reflected by hw_threads()), divided by threads per group, times the
    // subslice count. Integer division intentionally rounds down.
    uint32_t ss_count = dev_info->eu_count() / dev_info->max_eus_per_wg();
    bool large_grf_mode = (info.grfCount > 128);
    uint32_t thread_per_ss = dev_info->hw_threads(large_grf_mode) / ss_count;
    uint32_t thread_per_tg = uint32_t(lws[0] * lws[1] * lws[2]);
    uint32_t tg_per_ss = thread_per_ss / thread_per_tg;
    uint32_t concurrent_tg = tg_per_ss * ss_count;

    // All linear orders receive the 2D group grid extents first.
    arg_list.set(argn++, groups_m);
    arg_list.set(argn++, groups_n);

    if (info.isHilbert()) {
        // vd:  side length of the blocks the pseudo-Hilbert curve is carved
        //      into, chosen from the grid aspect ratio; when the grid is
        //      taller than wide, the high 16 bits are set as an orientation
        //      flag (kernel-side convention -- confirm against the kernel).
        // uvd: total block count along the curve (grid extent * vd).
        uint32_t vd = 0, uvd = 0;
        double ratio = double(groups_n) / double(groups_m);
        if (ratio >= 1) {
            vd = std::ceil(groups_n / std::round(2 * ratio));
            uvd = groups_m * vd;
        } else {
            vd = std::ceil(groups_m / std::round(2 / ratio));
            uvd = groups_n * vd;
            vd |= 0xFFFF0000u;
        }

        // Fixed-point reciprocal of uvd: ceil(2^(32+shift) / uvd), so the
        // kernel can divide by uvd with a multiply + shift instead of an
        // integer division (invariant-divisor reciprocal technique).
        int shift = std::floor(std::log2(uvd));
        uint32_t uvd_recip
                = uint32_t(utils::div_up(0x100000000ull << shift, uvd));
        // NOTE(review): bail appears to be a kernel-side early-exit knob;
        // 512 (vs. the normal 1) effectively disables Hilbert ordering.
        uint32_t bail = disable_hilbert ? 512 : 1;

        arg_list.set(argn++, vd);
        arg_list.set(argn++, uvd_recip);
        arg_list.set(argn++, bail);
    } else if (info.isBoustrophedon()) {
        // Choose a slice width so that one serpentine slice of workgroups
        // roughly fills the device: sm * sn == concurrent_tg, with the
        // sn/sm ratio matching the workgroup tile's aspect ratio (bias).
        double bias = double(info.wg[0] * info.unroll[0])
                / double(info.wg[1] * info.unroll[1]);
        double sm = std::sqrt(concurrent_tg / bias);
        double sn = std::sqrt(concurrent_tg * bias);

        // slice:  slice width actually used; negated to tell the kernel the
        //         slicing runs along n rather than m (see sign flips below).
        // thresh: boundary between wider and narrower slices so the grid
        //         extent divides up evenly; sign encodes rounding direction.
        int32_t slice = 0, thresh = 0;
        bool ok = false;

        // Try slicing along the preferred dimension first (the larger grid
        // extent), then the other one if no acceptable slicing was found.
        for (bool nslice : {groups_m > groups_n, groups_m <= groups_n}) {
            double s = nslice ? sn : sm;
            auto sf = int(std::floor(s));
            auto sc = int(std::ceil(s));
            // Prefer a slice width that divides concurrent_tg exactly.
            if (concurrent_tg % sc == 0) s = sf = sc;
            if (concurrent_tg % (sc + 1) == 0) s = sf = sc = sc + 1;

            // gc:  grid extent along the sliced dimension.
            // gco: grid extent along the other dimension.
            int gc = nslice ? groups_n : groups_m;
            int gco = nslice ? groups_m : groups_n;

            // Probe candidate widths around s: ceil (down to s0-1), ceil
            // (up to s0+1), then floor (down), stopping at the first fit.
            for (int srange = 0; srange <= 2 && !ok; srange++) {
                int s0 = (srange < 2) ? sc : sf;
                bool up = (srange == 1);
                int s1 = s0 + (up ? 1 : -1);
                if (s1 <= 0) continue;

                // Number of slices taken at width s0 before switching to
                // the adjacent width s1, so the widths sum to exactly gc.
                auto rem = gc % s0;
                if (!rem || up)
                    thresh = gc / s0 - rem;
                else
                    thresh = utils::div_up(gc, s0) - (s0 - rem);

                // Acceptable if the split is consistent and the other
                // dimension is long enough for at least two slices' worth.
                ok = (thresh >= 0) && (gco >= 2 * s0);
                slice = s0;
                if (!up) {
                    // Negative thresh tells the kernel the second width is
                    // s0-1 rather than s0+1; thresh == 0 degenerates to a
                    // uniform width of s0-1 across the whole extent.
                    if (thresh > 0)
                        thresh = -thresh;
                    else {
                        slice--;
                        thresh = gc;
                    }
                }
                // Negative slice == slicing along n (kernel convention).
                if (nslice) slice *= -1;
            }

            if (ok) break;
        }

        if (!ok) {
            // Fallback slicing.
            bool nslice = (groups_m > groups_n);
            double s = nslice ? sn : sm;
            int gc = nslice ? groups_n : groups_m;

            // Use one slice for small grids; otherwise the widest uniform
            // width that covers gc in ceil(gc/round(s)) slices.
            if (gc < s * 1.5)
                slice = gc;
            else
                slice = gc / utils::div_up(gc, int(std::round(s)));

            thresh = nstl::max(0, (gc / slice) - (gc % slice));
            if (nslice) slice *= -1;
        }

        // Guard against a degenerate zero width (e.g. tiny grids): fall
        // back to single-column slices covering all of groups_m.
        if (slice == 0) {
            slice = 1;
            thresh = groups_m;
        }

        arg_list.set(argn++, slice);
        arg_list.set(argn++, thresh);
    }

    if (info.isPersistent()) {
        // Persistent kernels launch at most one group per resident slot and
        // loop over tiles internally; tell them the concurrency they got.
        group_count = nstl::min(group_count, concurrent_tg);
        arg_list.set(argn++, concurrent_tg);
    }

    // Flatten the dispatch: all workgroups along dimension 0, a single
    // group along dimension 1 (gws[2] is left as the caller set it).
    gws[0] = lws[0] * group_count;
    gws[1] = lws[1];
}
149
150} // namespace jit
151} // namespace gpu
152} // namespace impl
153} // namespace dnnl
154#endif
155