1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_GEMM_GEMM_WALK_ORDERS_HPP |
18 | #define GPU_JIT_GEMM_GEMM_WALK_ORDERS_HPP |
19 | |
#include <cmath>
#include <cstdint>

#include "common/utils.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel.hpp"
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace gpu { |
26 | namespace jit { |
27 | |
// Appends the extra kernel arguments required by linearized GEMM walk orders
// (plain linear, Hilbert-curve, boustrophedon, persistent) and flattens the
// ND-range into a 1D dispatch.
//
// arg_list/argn    kernel argument list and running argument index (in/out);
//                  the order of set() calls below must match the kernel ABI.
// lws, gws         local/global work sizes; on exit gws[0] carries the whole
//                  (possibly capped) group count, gws[1] = lws[1], and gws[2]
//                  is left untouched.
// m, n             GEMM problem sizes; not referenced in this function
//                  (retained for interface stability -- TODO confirm).
// disable_hilbert  if set, passes a large "bail" value instead of 1.
// info             driver configuration: walk order, wg/unroll, GRF count.
// dev_info         device topology used to estimate concurrent thread groups.
inline void gemm_linear_order_args(compute::kernel_arg_list_t &arg_list,
        int &argn, const size_t (&lws)[3], size_t (&gws)[3], int32_t m,
        int32_t n, bool disable_hilbert, const CommonDriverInfo &info,
        const compute::device_info_t *dev_info) {
    // Only kernels compiled for a linearized walk order take these arguments.
    if (!info.isLinearOrder()) return;

    // Workgroup counts along m and n; which gws dimension holds which depends
    // on whether the kernel walks in N-M-K or M-N-K order.
    int m_index = info.isNMK() ? 1 : 0;
    int n_index = info.isNMK() ? 0 : 1;
    auto groups_m = uint32_t(gws[m_index] / lws[m_index]);
    auto groups_n = uint32_t(gws[n_index] / lws[n_index]);
    auto group_count = groups_m * groups_n;

    // Estimate how many thread groups the device can run concurrently:
    //   ss_count      : EU partitions ("ss" presumably = subslice -- inferred
    //                   from the name, confirm against device_info),
    //   thread_per_ss : HW threads per partition; the budget depends on
    //                   large-GRF mode (> 128 registers per thread),
    //   tg_per_ss     : whole thread groups fitting in one partition.
    uint32_t ss_count = dev_info->eu_count() / dev_info->max_eus_per_wg();
    bool large_grf_mode = (info.grfCount > 128);
    uint32_t thread_per_ss = dev_info->hw_threads(large_grf_mode) / ss_count;
    uint32_t thread_per_tg = uint32_t(lws[0] * lws[1] * lws[2]);
    uint32_t tg_per_ss = thread_per_ss / thread_per_tg;
    uint32_t concurrent_tg = tg_per_ss * ss_count;

    // All linear orders receive the 2D group extents first.
    arg_list.set(argn++, groups_m);
    arg_list.set(argn++, groups_n);

    if (info.isHilbert()) {
        // Hilbert-curve walk: pass curve sizing parameters.
        //   vd  : strip width along the longer dimension,
        //   uvd : groups covered by one full strip (shorter dim * vd).
        uint32_t vd = 0, uvd = 0;
        double ratio = double(groups_n) / double(groups_m);
        if (ratio >= 1) {
            // n is the longer dimension: split it into ~2*ratio strips.
            vd = std::ceil(groups_n / std::round(2 * ratio));
            uvd = groups_m * vd;
        } else {
            // m is the longer dimension (mirrored case). Note uvd uses the
            // untagged vd value; the tag below must come after.
            vd = std::ceil(groups_m / std::round(2 / ratio));
            uvd = groups_n * vd;
            // High half of vd set as a flag -- presumably signals the
            // transposed orientation to the kernel; confirm in kernel source.
            vd |= 0xFFFF0000u;
        }

        // Fixed-point reciprocal of uvd: ceil(2^(32+shift) / uvd), so the
        // kernel can divide by uvd with a multiply + shift instead of a
        // hardware divide.
        int shift = std::floor(std::log2(uvd));
        uint32_t uvd_recip
                = uint32_t(utils::div_up(0x100000000ull << shift, uvd));
        // bail threshold: 1 normally; 512 when Hilbert ordering is disabled
        // -- presumably an early-exit knob in the kernel, confirm there.
        uint32_t bail = disable_hilbert ? 512 : 1;

        arg_list.set(argn++, vd);
        arg_list.set(argn++, uvd_recip);
        arg_list.set(argn++, bail);
    } else if (info.isBoustrophedon()) {
        // Boustrophedon (serpentine) walk: pick a slice width so one pass of
        // slices roughly matches the machine's concurrent TG capacity.
        // bias skews the aspect toward the dimension with the larger
        // per-workgroup tile (wg * unroll).
        double bias = double(info.wg[0] * info.unroll[0])
                / double(info.wg[1] * info.unroll[1]);
        // Ideal slice widths when slicing along m (sm) or along n (sn):
        // width * (width / bias-adjusted aspect) ~= concurrent_tg.
        double sm = std::sqrt(concurrent_tg / bias);
        double sn = std::sqrt(concurrent_tg * bias);

        // slice : chosen width; its sign encodes the sliced dimension
        //         (negative => slicing along n, see `slice *= -1` below).
        // thresh: slice-count threshold; sign convention is consumed by the
        //         kernel -- confirm exact meaning against kernel source.
        int32_t slice = 0, thresh = 0;
        bool ok = false;

        // Try slicing the preferred (larger) dimension first, then the other.
        for (bool nslice : {groups_m > groups_n, groups_m <= groups_n}) {
            double s = nslice ? sn : sm;
            auto sf = int(std::floor(s));
            auto sc = int(std::ceil(s));
            // Snap to an adjacent exact divisor of concurrent_tg if one
            // exists, so slices tile the machine evenly.
            if (concurrent_tg % sc == 0) s = sf = sc;
            if (concurrent_tg % (sc + 1) == 0) s = sf = sc = sc + 1;

            // gc : group count in the sliced dimension; gco: the other one.
            int gc = nslice ? groups_n : groups_m;
            int gco = nslice ? groups_m : groups_n;

            // Three candidate schemes: ceil width stepping down (srange 0),
            // ceil width stepping up (srange 1), floor width stepping down.
            for (int srange = 0; srange <= 2 && !ok; srange++) {
                int s0 = (srange < 2) ? sc : sf;
                bool up = (srange == 1);
                // s1 (the neighboring width) is only used as a validity
                // guard here.
                int s1 = s0 + (up ? 1 : -1);
                if (s1 <= 0) continue;

                // Number of slices at width s0 before switching to s1, chosen
                // so the two widths exactly cover gc.
                auto rem = gc % s0;
                if (!rem || up)
                    thresh = gc / s0 - rem;
                else
                    thresh = utils::div_up(gc, s0) - (s0 - rem);

                // Accept only a non-negative threshold with at least two
                // slices' worth of groups in the other dimension.
                ok = (thresh >= 0) && (gco >= 2 * s0);
                slice = s0;
                if (!up) {
                    // Step-down scheme: negate thresh (sign distinguishes the
                    // schemes for the kernel); thresh == 0 degenerates to the
                    // narrower width everywhere.
                    if (thresh > 0)
                        thresh = -thresh;
                    else {
                        slice--;
                        thresh = gc;
                    }
                }
                // Negative slice marks n-direction slicing.
                if (nslice) slice *= -1;
            }

            if (ok) break;
        }

        if (!ok) {
            // Fallback slicing: no exact two-width cover found.
            bool nslice = (groups_m > groups_n);
            double s = nslice ? sn : sm;
            int gc = nslice ? groups_n : groups_m;

            // One big slice if gc is close to the target width; otherwise the
            // widest width dividing gc into near-equal passes.
            if (gc < s * 1.5)
                slice = gc;
            else
                slice = gc / utils::div_up(gc, int(std::round(s)));

            thresh = nstl::max(0, (gc / slice) - (gc % slice));
            if (nslice) slice *= -1;
        }

        // Guard against a degenerate zero slice (e.g. very small problems).
        if (slice == 0) {
            slice = 1;
            thresh = groups_m;
        }

        arg_list.set(argn++, slice);
        arg_list.set(argn++, thresh);
    }

    // Persistent kernels launch at most the number of groups that can be
    // resident at once -- presumably the kernel loops over remaining work
    // itself; confirm against kernel source.
    if (info.isPersistent()) {
        group_count = nstl::min(group_count, concurrent_tg);
        arg_list.set(argn++, concurrent_tg);
    }

    // Flatten the dispatch: all groups on dimension 0, dimension 1 reduced to
    // a single group; gws[2] is deliberately left as the caller set it.
    gws[0] = lws[0] * group_count;
    gws[1] = lws[1];
}
149 | |
150 | } // namespace jit |
151 | } // namespace gpu |
152 | } // namespace impl |
153 | } // namespace dnnl |
154 | #endif |
155 | |