/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "kernel_evaluator.hpp"

#include <algorithm>
#include <cmath>
#include <limits>

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

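// Integer ceiling-division and round-up-to-multiple helpers used throughout
// the evaluator.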
template <typename T1, typename T2>
static inline T1 divUp(T1 x, T2 y) {
    return (x + y - 1) / y;
}

template <typename T1, typename T2>
static inline T1 alignUp(T1 x, T2 y) {
    return divUp(x, y) * y;
}

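// Scoring model 'W': a static priority-based model. The returned value is
// compared against other candidates' scores by the caller, with lower values
// preferred. The base priority from the catalog is adjusted for k-parallel
// kernels and for launches that oversubscribe the available hardware threads.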
double evaluateW(const kcatalog::Entry &e, const DerivedEvaluateParams &dp,
        EvaluateAuxOutput &aux) {
    static constexpr double maxPriority = 10000.;
    double priority = e.model.params[kcatalog::ParamWPriority];

    if (e.driverInfo.kParallelLocal)
        aux.k0 = alignUp(divUp(dp.sizes.k, e.driverInfo.wg[LoopK]),
                e.driverInfo.unroll[LoopK]);

    if (priority > maxPriority) /* no op */
        ; // Don't adjust very high values -- these are last resort kernels (lowest priority)
    else if (e.driverInfo.kParallel) {
        int wgCountK = std::max(1, int(dp.hwThreadCapacity / dp.threadCount));
        aux.k0 = alignUp(
                divUp(dp.sizes.k, wgCountK), e.driverInfo.unroll[LoopK]);
        if (aux.k0 < dp.sizes.k)
            priority = -priority;
        else
            priority = 2 * maxPriority + priority;
    } else if (dp.threadCount > dp.hwThreadCapacity)
        priority = 2 * maxPriority - priority;

    // Prefer pure f16 over mixed f16-f32 kernels.
    if (e.selector.precisions[2][0] != e.selector.precisions[0][0])
        priority += 4 * maxPriority;

    return priority;
}

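// Scoring model 'S': an analytic performance model with fitted parameters.
// The score is a predicted execution time combining (roughly) a launch/overhead
// term (ctime), a memory-traffic term (mtime), and a thread-execution term
// (etime) that accounts for full and partial waves of hardware threads; the
// final estimate is ctime + max(mtime, etime).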
double evaluateSCore(const kcatalog::Entry &e, const DerivedEvaluateParams &dp,
        EvaluateAuxOutput &aux) {
#define PARAM(p) e.model.params[kcatalog::ParamS_##p]

    auto threads = dp.threadCount;
    auto batch = dp.sizes.batch;
    auto m = dp.sizes.m;
    auto n = dp.sizes.n;
    auto k = dp.sizes.k;
    auto kthread = k;
    auto capacity = dp.hwThreadCapacity;
    auto capacity1 = dp.hwMinThreadsToFill;

    if (e.driverInfo.kParallel)
        kthread = aux.k0;
    else if (e.driverInfo.kParallelLocal) {
        kthread = alignUp(
                divUp(k, e.driverInfo.wg[LoopK]), e.driverInfo.unroll[LoopK]);
        kthread = std::max<decltype(kthread)>(
                kthread, 2 * e.driverInfo.unroll[LoopK]);
        aux.k0 = kthread;
    }

    double threadsFull = std::floor(threads / capacity) * capacity;
    double threadsPartial = threads - threadsFull;
    double partialWaves = std::ceil(threadsPartial / capacity1);
    double npartial = std::ceil(threads / capacity1);

    double C0 = (dp.beta == 0.) ? PARAM(C00) : PARAM(C01);
    double C1 = (dp.beta == 0.) ? PARAM(C10) : PARAM(C11);
    double Cm = (dp.beta == 0.) ? PARAM(Cm0) : PARAM(Cm1);
    double ctime = std::max(Cm, C0 + npartial * C1);

    double mtime = PARAM(Ma) * m + PARAM(Mb) * n;
    mtime *= k;
    mtime *= batch;

    double Ef = PARAM(Ef);
    double etimeFull = Ef * threadsFull;
    double Ep = std::max(Ef,
            PARAM(Ep0)
                    + (PARAM(Ep1) * dp.partialWaveCount)
                            / std::max(partialWaves, 1.));
    double etimePartial = Ep * partialWaves * capacity1;
    double etimeLB = (etimeFull + etimePartial);
    double etimeNoLB = Ef * threads;

    double Em = PARAM(Em);
    if (threads < capacity) Em = 1.;
    double etime = (1 - Em) * etimeNoLB + Em * etimeLB;
    etime *= (e.driverInfo.unroll[LoopM] * e.driverInfo.unroll[LoopN]);
    etime *= kthread;

    if (!dp.effective) {
        double F = PARAM(Fr0) + double(m) * double(n) * double(k) * PARAM(Fr1);
        F = std::max(1.0, std::min(PARAM(Fp), F));
        etime *= F;
    }

    double time = ctime + std::max(mtime, etime);

    return time;
#undef PARAM
}

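// Wrapper around evaluateSCore. For k-parallel kernels, it evaluates two
// candidate k0 splits (targeting roughly one or two full waves of thread
// groups), keeps the better-scoring split in aux.k0, and adds the cost of the
// initial beta-scaling pass when beta != 1.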
double evaluateS(const kcatalog::Entry &e, const DerivedEvaluateParams &dp,
        EvaluateAuxOutput &aux) {
    if (!e.driverInfo.kParallel)
        return evaluateSCore(e, dp, aux);
    else {
        // Consider choosing k0 to get as close as possible to 1 or 2 full waves.
        int wgCountK1 = std::max(1, int(dp.hwThreadCapacity / dp.threadCount));
        int wgCountK2
                = std::max(1, int(2 * dp.hwThreadCapacity / dp.threadCount));

        int k0_1
                = alignUp(divUp(dp.sizes.k, wgCountK1 * e.driverInfo.wg[LoopK]),
                        e.driverInfo.unroll[LoopK]);
        int k0_2
                = alignUp(divUp(dp.sizes.k, wgCountK2 * e.driverInfo.wg[LoopK]),
                        e.driverInfo.unroll[LoopK]);

        k0_1 = std::max(k0_1, 1);
        k0_2 = std::max(k0_2, 1);

        wgCountK1 = std::max<int>(
                1, divUp(dp.sizes.k, k0_1 * e.driverInfo.wg[LoopK]));
        wgCountK2 = std::max<int>(
                1, divUp(dp.sizes.k, k0_2 * e.driverInfo.wg[LoopK]));

        auto dp1 = dp;
        dp1.wgCountK = wgCountK1;
        dp1.threadCount *= wgCountK1;
        aux.k0 = k0_1;

        double score = evaluateSCore(e, dp1, aux);

        if (k0_2 != k0_1) {
            auto dp2 = dp;
            dp2.wgCountK = wgCountK2;
            dp2.threadCount *= wgCountK2;
            aux.k0 = k0_2;

            double score2 = evaluateSCore(e, dp2, aux);
            if (score2 < score)
                score = score2;
            else
                aux.k0 = k0_1;
        }

        // Add the cost of the initial beta-scaling pass if beta != 1.
        if (dp.beta != 1.) {
            auto dp0 = dp;
            dp0.sizes.k = 0;
            score += evaluateSCore(e, dp0, aux);
        }

        return score;
    }
}

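// Returns true if the problem sizes fall inside the entry's explicit
// "always accept" ranges. Negative bounds mean the corresponding limit is
// unspecified; an entry with no accept bounds at all never auto-accepts.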
bool alwaysAccept(const kcatalog::Entry &e, const EvaluateParams &p) {
    int64_t mnk[3] = {p.sizes.m, p.sizes.n, p.sizes.k};
    bool accept = true, hasAccepts = false;

    for (int i = 0; i < 3; i++) {
        if (e.restrictions.acceptSizesMin[i] >= 0) {
            hasAccepts = true;
            accept &= (mnk[i] >= e.restrictions.acceptSizesMin[i]);
        }
        if (e.restrictions.acceptSizesMax[i] >= 0) {
            hasAccepts = true;
            accept &= (mnk[i] <= e.restrictions.acceptSizesMax[i]);
        }
    }

    return hasAccepts && accept;
}

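// Derives launch-related quantities from the problem and the catalog entry:
// workgroup counts in each dimension, padded m/n sizes, total thread count,
// and hardware capacity figures (threads per EU, minimum threads to fill the
// machine) used by the scoring models.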
DerivedEvaluateParams getDerivedParams(
        const kcatalog::Entry &e, const EvaluateParams &p) {
    DerivedEvaluateParams dp;
    static_cast<EvaluateParams &>(dp) = p;

    auto unrollM = e.driverInfo.unroll[LoopM];
    auto unrollN = e.driverInfo.unroll[LoopN];

    auto wgM = e.driverInfo.wg[LoopM];
    auto wgN = e.driverInfo.wg[LoopN];

    auto wgTileM = wgM * unrollM;
    auto wgTileN = wgN * unrollN;

    dp.wgCountM = divUp(p.sizes.m, wgTileM); /* may be adjusted later */
    dp.wgCountN = divUp(p.sizes.n, wgTileN);
    dp.wgCountK = 1;

    if (!e.driverInfo.fixedWG()) {
        if (p.sizes.m < wgTileM) {
            wgM = std::max<int>(divUp(p.sizes.m, unrollM), 1);
            wgTileM = wgM * unrollM;
            dp.wgCountM = 1;
        }
        if (p.sizes.n < wgTileN) {
            wgN = std::max<int>(divUp(p.sizes.n, unrollN), 1);
            wgTileN = wgN * unrollN;
            dp.wgCountN = 1;
        }
    }

    auto threadsPerWG = wgM * wgN * e.driverInfo.wg[LoopK];

    dp.mPad = dp.wgCountM * wgTileM;
    dp.nPad = dp.wgCountN * wgTileN;
    dp.threadCount = double(dp.wgCountM) * double(dp.wgCountN);

    dp.threadCount *= threadsPerWG;
    dp.threadCount *= (dp.wgCountK * p.sizes.batch);

    switch (e.selector.hw) {
        case kcatalog::HWTagGen9:
        case kcatalog::HWTagGen11:
        case kcatalog::HWTagGen12LP: dp.threadsPerEU = 7; break;
        default: dp.threadsPerEU = (e.driverInfo.grfCount > 128) ? 4 : 8; break;
    }

    int ssCount;
    switch (e.selector.hw) {
        case kcatalog::HWTagGen12LP:
        case kcatalog::HWTagXeHP:
        case kcatalog::HWTagXeHPG: ssCount = p.euCount >> 4; break;
        default: ssCount = p.euCount >> 3; break;
    }

    dp.hwThreadCapacity = dp.threadsPerEU * p.euCount;
    dp.hwMinThreadsToFill = threadsPerWG * ssCount;
    dp.partialWaveCount = divUp(dp.hwThreadCapacity, dp.hwMinThreadsToFill);

    return dp;
}

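// Entry points: score a catalog entry for the given problem. Scores are
// compared by the caller with lower values preferred; an unknown model id
// yields NaN, and an entry matching its "always accept" restrictions is
// forced to win by returning -infinity.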
double evaluate(const kcatalog::Entry &e, const EvaluateParams &p,
        EvaluateAuxOutput &aux) {
    return evaluate(e, getDerivedParams(e, p), aux);
}

double evaluate(const kcatalog::Entry &e, const DerivedEvaluateParams &dp,
        EvaluateAuxOutput &aux) {
    double score = 0.;

    switch (e.model.id) {
        case 'S': score = evaluateS(e, dp, aux); break;
        case 'W': score = evaluateW(e, dp, aux); break;
        default: score = std::numeric_limits<double>::quiet_NaN(); break;
    }

    if (alwaysAccept(e, dp)) score = -std::numeric_limits<double>::infinity();

    return score;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl