1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "kernel_evaluator.hpp" |
18 | |
19 | #include <algorithm> |
20 | #include <cmath> |
21 | #include <limits> |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace gpu { |
26 | namespace jit { |
27 | |
/// Ceiling division: for x >= 0 and y > 0 this is the smallest q with
/// q * y >= x. The result is converted to the type of x.
template <typename T1, typename T2>
static inline T1 divUp(T1 x, T2 y) {
    const auto biased = x + y - 1;
    return biased / y;
}

/// Round x up to a multiple of y (for x >= 0, y > 0); the result has the
/// type of x.
template <typename T1, typename T2>
static inline T1 alignUp(T1 x, T2 y) {
    const T1 multiples = divUp(x, y);
    return multiples * y;
}
37 | |
38 | double evaluateW(const kcatalog::Entry &e, const DerivedEvaluateParams &dp, |
39 | EvaluateAuxOutput &aux) { |
40 | static constexpr double maxPriority = 10000.; |
41 | double priority = e.model.params[kcatalog::ParamWPriority]; |
42 | |
43 | if (e.driverInfo.kParallelLocal) |
44 | aux.k0 = alignUp(divUp(dp.sizes.k, e.driverInfo.wg[LoopK]), |
45 | e.driverInfo.unroll[LoopK]); |
46 | |
47 | if (priority > maxPriority) /* no op */ |
48 | ; // Don't adjust very high values -- these are last resort kernels (lowest priority) |
49 | else if (e.driverInfo.kParallel) { |
50 | int wgCountK = std::max(1, int(dp.hwThreadCapacity / dp.threadCount)); |
51 | aux.k0 = alignUp( |
52 | divUp(dp.sizes.k, wgCountK), e.driverInfo.unroll[LoopK]); |
53 | if (aux.k0 < dp.sizes.k) |
54 | priority = -priority; |
55 | else |
56 | priority = 2 * maxPriority + priority; |
57 | } else if (dp.threadCount > dp.hwThreadCapacity) |
58 | priority = 2 * maxPriority - priority; |
59 | |
60 | // Prefer pure f16 over mixed f16-f32 kernels. |
61 | if (e.selector.precisions[2][0] != e.selector.precisions[0][0]) |
62 | priority += 4 * maxPriority; |
63 | |
64 | return priority; |
65 | } |
66 | |
/// Analytic cost estimate for one configuration under the 'S' model.
///
/// The result is ctime + max(mtime, etime). All coefficients come from the
/// entry's ParamS_* model parameters:
///  - ctime = max(Cm, C0 + npartial * C1). The C* coefficients are chosen by
///    whether beta == 0. NOTE(review): presumably a fixed/overhead cost.
///  - mtime = (Ma * m + Mb * n) * k * batch. NOTE(review): presumably a
///    memory-traffic term.
///  - etime blends a wave-based (load-balanced) estimate with a plain
///    per-thread estimate using weight Em. Em is forced to 1 when the thread
///    count is below hardware capacity. The blend is scaled by the m x n
///    unroll and by the k extent per thread. When dp.effective is false it
///    is also scaled by a factor clamped to at most Fp and at least 1.
///
/// @param e    Catalog entry whose model parameters are used.
/// @param dp   Derived parameters. evaluateS passes a threadCount already
///             multiplied by the k workgroup count for kParallel entries.
/// @param aux  For kParallel entries, aux.k0 is READ as the per-thread k and
///             must be set by the caller (evaluateS does this). For
///             kParallelLocal entries it is WRITTEN with the per-thread k.
/// @return Estimated time. evaluateS keeps the smaller of two estimates, so
///         lower is better.
double evaluateSCore(const kcatalog::Entry &e, const DerivedEvaluateParams &dp,
        EvaluateAuxOutput &aux) {
#define PARAM(p) e.model.params[kcatalog::ParamS_##p]

    auto threads = dp.threadCount;
    auto batch = dp.sizes.batch;
    auto m = dp.sizes.m;
    auto n = dp.sizes.n;
    auto k = dp.sizes.k;
    // k extent processed by each thread (all of k unless k is split).
    auto kthread = k;
    auto capacity = dp.hwThreadCapacity;
    auto capacity1 = dp.hwMinThreadsToFill;

    if (e.driverInfo.kParallel)
        kthread = aux.k0; // Chosen by the caller.
    else if (e.driverInfo.kParallelLocal) {
        // k split across the workgroup's k dimension, aligned to the k
        // unroll, with a floor of two unrolls per thread.
        kthread = alignUp(
                divUp(k, e.driverInfo.wg[LoopK]), e.driverInfo.unroll[LoopK]);
        kthread = std::max<decltype(kthread)>(
                kthread, 2 * e.driverInfo.unroll[LoopK]);
        aux.k0 = kthread;
    }

    // Threads filling complete waves of `capacity` hardware threads.
    double threadsFull = std::floor(threads / capacity) * capacity;
    // Leftover threads in the final, partial wave.
    double threadsPartial = threads - threadsFull;
    // The partial wave counted in hwMinThreadsToFill-sized groups (rounded up).
    double partialWaves = std::ceil(threadsPartial / capacity1);
    // All threads counted in hwMinThreadsToFill-sized groups (rounded up).
    double npartial = std::ceil(threads / capacity1);

    double C0 = (dp.beta == 0.) ? PARAM(C00) : PARAM(C01);
    double C1 = (dp.beta == 0.) ? PARAM(C10) : PARAM(C11);
    double Cm = (dp.beta == 0.) ? PARAM(Cm0) : PARAM(Cm1);
    double ctime = std::max(Cm, C0 + npartial * C1);

    double mtime = PARAM(Ma) * m + PARAM(Mb) * n;
    mtime *= k;
    mtime *= batch;

    // Per-thread cost for full waves (Ef) and for the partial wave (Ep, never
    // below Ef).
    double Ef = PARAM(Ef);
    double etimeFull = Ef * threadsFull;
    double Ep = std::max(Ef,
            PARAM(Ep0)
                    + (PARAM(Ep1) * dp.partialWaveCount)
                            / std::max(partialWaves, 1.));
    double etimePartial = Ep * partialWaves * capacity1;
    double etimeLB = (etimeFull + etimePartial);
    double etimeNoLB = Ef * threads;

    // Em weights the wave-based estimate against the plain per-thread one.
    double Em = PARAM(Em);
    if (threads < capacity) Em = 1.;
    double etime = (1 - Em) * etimeNoLB + Em * etimeLB;
    etime *= (e.driverInfo.unroll[LoopM] * e.driverInfo.unroll[LoopN]);
    etime *= kthread;

    if (!dp.effective) {
        double F = PARAM(Fr0) + double(m) * double(n) * double(k) * PARAM(Fr1);
        F = std::max(1.0, std::min(PARAM(Fp), F));
        etime *= F;
    }

    double time = ctime + std::max(mtime, etime);

    return time;
#undef PARAM
}
131 | |
132 | double evaluateS(const kcatalog::Entry &e, const DerivedEvaluateParams &dp, |
133 | EvaluateAuxOutput &aux) { |
134 | if (!e.driverInfo.kParallel) |
135 | return evaluateSCore(e, dp, aux); |
136 | else { |
137 | // Consider choosing k0 to get as close as possible to 1 or 2 full waves. |
138 | int wgCountK1 = std::max(1, int(dp.hwThreadCapacity / dp.threadCount)); |
139 | int wgCountK2 |
140 | = std::max(1, int(2 * dp.hwThreadCapacity / dp.threadCount)); |
141 | |
142 | int k0_1 |
143 | = alignUp(divUp(dp.sizes.k, wgCountK1 * e.driverInfo.wg[LoopK]), |
144 | e.driverInfo.unroll[LoopK]); |
145 | int k0_2 |
146 | = alignUp(divUp(dp.sizes.k, wgCountK2 * e.driverInfo.wg[LoopK]), |
147 | e.driverInfo.unroll[LoopK]); |
148 | |
149 | k0_1 = std::max(k0_1, 1); |
150 | k0_2 = std::max(k0_2, 1); |
151 | |
152 | wgCountK1 = std::max<int>( |
153 | 1, divUp(dp.sizes.k, k0_1 * e.driverInfo.wg[LoopK])); |
154 | wgCountK2 = std::max<int>( |
155 | 1, divUp(dp.sizes.k, k0_2 * e.driverInfo.wg[LoopK])); |
156 | |
157 | auto dp1 = dp; |
158 | dp1.wgCountK = wgCountK1; |
159 | dp1.threadCount *= wgCountK1; |
160 | aux.k0 = k0_1; |
161 | |
162 | double score = evaluateSCore(e, dp1, aux); |
163 | |
164 | if (k0_2 != k0_1) { |
165 | auto dp2 = dp; |
166 | dp2.wgCountK = wgCountK2; |
167 | dp2.threadCount *= wgCountK2; |
168 | aux.k0 = k0_2; |
169 | |
170 | double score2 = evaluateSCore(e, dp2, aux); |
171 | if (score2 < score) |
172 | score = score2; |
173 | else |
174 | aux.k0 = k0_1; |
175 | } |
176 | |
177 | // Add cost of initial beta scaling if not 1. |
178 | if (dp.beta != 1.) { |
179 | auto dp0 = dp; |
180 | dp0.sizes.k = 0; |
181 | score += evaluateSCore(e, dp0, aux); |
182 | } |
183 | |
184 | return score; |
185 | } |
186 | } |
187 | |
188 | bool alwaysAccept(const kcatalog::Entry &e, const EvaluateParams &p) { |
189 | int64_t mnk[3] = {p.sizes.m, p.sizes.n, p.sizes.k}; |
190 | bool accept = true, hasAccepts = false; |
191 | |
192 | for (int i = 0; i < 3; i++) { |
193 | if (e.restrictions.acceptSizesMin[i] >= 0) { |
194 | hasAccepts = true; |
195 | accept &= (mnk[i] >= e.restrictions.acceptSizesMin[i]); |
196 | } |
197 | if (e.restrictions.acceptSizesMax[i] >= 0) { |
198 | hasAccepts = true; |
199 | accept &= (mnk[i] <= e.restrictions.acceptSizesMax[i]); |
200 | } |
201 | } |
202 | |
203 | return hasAccepts && accept; |
204 | } |
205 | |
206 | DerivedEvaluateParams getDerivedParams( |
207 | const kcatalog::Entry &e, const EvaluateParams &p) { |
208 | DerivedEvaluateParams dp; |
209 | static_cast<EvaluateParams &>(dp) = p; |
210 | |
211 | auto unrollM = e.driverInfo.unroll[LoopM]; |
212 | auto unrollN = e.driverInfo.unroll[LoopN]; |
213 | |
214 | auto wgM = e.driverInfo.wg[LoopM]; |
215 | auto wgN = e.driverInfo.wg[LoopN]; |
216 | |
217 | auto wgTileM = wgM * unrollM; |
218 | auto wgTileN = wgN * unrollN; |
219 | |
220 | dp.wgCountM = divUp(p.sizes.m, wgTileM); /* may be adjusted later */ |
221 | dp.wgCountN = divUp(p.sizes.n, wgTileN); |
222 | dp.wgCountK = 1; |
223 | |
224 | if (!e.driverInfo.fixedWG()) { |
225 | if (p.sizes.m < wgTileM) { |
226 | wgM = std::max<int>(divUp(p.sizes.m, unrollM), 1); |
227 | wgTileM = wgM * unrollM; |
228 | dp.wgCountM = 1; |
229 | } |
230 | if (p.sizes.n < wgTileN) { |
231 | wgN = std::max<int>(divUp(p.sizes.n, unrollN), 1); |
232 | wgTileN = wgN * unrollN; |
233 | dp.wgCountN = 1; |
234 | } |
235 | } |
236 | |
237 | auto threadsPerWG = wgM * wgN * e.driverInfo.wg[LoopK]; |
238 | |
239 | dp.mPad = dp.wgCountM * wgTileM; |
240 | dp.nPad = dp.wgCountN * wgTileN; |
241 | dp.threadCount = double(dp.wgCountM) * double(dp.wgCountN); |
242 | |
243 | dp.threadCount *= threadsPerWG; |
244 | dp.threadCount *= (dp.wgCountK * p.sizes.batch); |
245 | |
246 | switch (e.selector.hw) { |
247 | case kcatalog::HWTagGen9: |
248 | case kcatalog::HWTagGen11: |
249 | case kcatalog::HWTagGen12LP: dp.threadsPerEU = 7; break; |
250 | default: dp.threadsPerEU = (e.driverInfo.grfCount > 128) ? 4 : 8; break; |
251 | } |
252 | |
253 | int ssCount; |
254 | switch (e.selector.hw) { |
255 | case kcatalog::HWTagGen12LP: |
256 | case kcatalog::HWTagXeHP: |
257 | case kcatalog::HWTagXeHPG: ssCount = p.euCount >> 4; break; |
258 | default: ssCount = p.euCount >> 3; break; |
259 | } |
260 | |
261 | dp.hwThreadCapacity = dp.threadsPerEU * p.euCount; |
262 | dp.hwMinThreadsToFill = threadsPerWG * ssCount; |
263 | dp.partialWaveCount = divUp(dp.hwThreadCapacity, dp.hwMinThreadsToFill); |
264 | |
265 | return dp; |
266 | } |
267 | |
268 | double evaluate(const kcatalog::Entry &e, const EvaluateParams &p, |
269 | EvaluateAuxOutput &aux) { |
270 | return evaluate(e, getDerivedParams(e, p), aux); |
271 | } |
272 | |
273 | double evaluate(const kcatalog::Entry &e, const DerivedEvaluateParams &dp, |
274 | EvaluateAuxOutput &aux) { |
275 | double score = 0.; |
276 | |
277 | switch (e.model.id) { |
278 | case 'S': score = evaluateS(e, dp, aux); break; |
279 | case 'W': score = evaluateW(e, dp, aux); break; |
280 | default: score = std::numeric_limits<double>::quiet_NaN(); break; |
281 | } |
282 | |
283 | if (alwaysAccept(e, dp)) score = -std::numeric_limits<double>::infinity(); |
284 | |
285 | return score; |
286 | } |
287 | |
288 | } // namespace jit |
289 | } // namespace gpu |
290 | } // namespace impl |
291 | } // namespace dnnl |
292 | |