1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "kernel_selector.hpp" |
18 | #include "common/verbose.hpp" |
19 | #include "kernel_evaluator.hpp" |
20 | |
21 | #include <cassert> |
22 | #include <cctype> |
23 | #include <cstring> |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace gpu { |
28 | namespace jit { |
29 | |
// Decide whether a catalog layout string is compatible with a pattern layout
// string. Only the leading character of each string is compared; anything
// after it is ignored.
inline bool layoutMatch(const char *lref, const char *lpattern) {
    const char refLead = *lref;
    const char patternLead = *lpattern;
    return refLead == patternLead;
}
33 | |
// Compare two precision characters while ignoring bit 0x20 (the ASCII case
// bit): the characters match when they differ in no other bit.
inline bool precisionMatch(char pref, char ppattern) {
    const int differingBits = pref ^ ppattern;
    return (differingBits & ~0x20) == 0;
}
38 | |
39 | inline bool precisionMatch(const char *pref, const char *ppattern) { |
40 | bool ok = false; |
41 | ok = ok || (ppattern[0] == '?'); |
42 | ok = ok || precisionMatch(pref[0], ppattern[0]); |
43 | ok = ok || (ppattern[0] == '[' && precisionMatch(pref[0], ppattern[1])); |
44 | if (ok && pref[0] == '[') { |
45 | ok = ok && precisionMatch(pref[1], ppattern[1]) |
46 | && precisionMatch(pref[2], ppattern[2]); |
47 | for (int i = 3; pref[i] != '\0'; i++) { |
48 | if (pref[i] != ppattern[i]) { |
49 | ok = false; |
50 | break; |
51 | } |
52 | } |
53 | } |
54 | return ok; |
55 | } |
56 | |
// Check that the precision named by pref[0] ranks at least as high as the
// precision named by pmin, according to a per-letter table.
//
// Both characters are reduced to their low five bits before lookup, so upper-
// and lower-case letters share an entry. Letters without an entry rank as 0.
// NOTE(review): the table values look like element sizes in bytes -- confirm
// against the precision character definitions.
inline bool precisionMinimumMatch(const char *pref, char pmin) {
    // Constant lookup table: kept static so it is not rebuilt on every call.
    static const uint8_t sizeTable[0x20]
            = {//  A  B  C  D  E  F  G  H  I  J  K  L  M  N  O
                    0, 0, 2, 8, 8, 0, 0, 0, 2, 4, 4, 0, 0, 0, 0, 1,
                    // P  Q  R  S  T  U  V  W  X  Y  Z
                    0, 0, 0, 4, 4, 0, 0, 2, 0, 0, 16, 0, 0, 0, 0, 0};

    return (sizeTable[pref[0] & 0x1F] >= sizeTable[pmin & 0x1F]);
}
66 | |
// Check that the pattern alignment is a multiple of the catalog alignment.
// A catalog alignment of 0 is treated as 1, which every pattern satisfies.
inline bool alignmentMatch(int aref, int apattern) {
    const int divisor = (aref == 0) ? 1 : aref;
    const int remainder = apattern % divisor;
    return remainder == 0;
}
71 | |
// Check every tag character of tref against the set of characters in
// tpattern.
//   - A tag with bit 0x20 clear (upper-case letter) must appear in tpattern.
//   - A tag with bit 0x20 set (lower-case letter) must NOT appear in
//     tpattern; the lookup uses the tag with bit 0x20 cleared, i.e. its
//     upper-case form.
// Returns true when all tags in tref satisfy their rule (trivially true for
// an empty tref).
inline bool tagMatch(const char *tref, const char *tpattern) {
    for (const char *tag = tref; *tag != '\0'; ++tag) {
        const bool mustBePresent = ((*tag & 0x20) == 0);
        const int upperTag = *tag & ~0x20; // clear the case bit
        const bool present = (std::strchr(tpattern, upperTag) != nullptr);
        if (mustBePresent ? !present : present) return false;
    }
    return true;
}
83 | |
84 | bool matches(const kcatalog::Entry &e, const MatchParams &pattern) { |
85 | bool ok = true; |
86 | |
87 | if (e.restrictions.steppingMin >= 0) |
88 | ok = ok && (pattern.stepping >= e.restrictions.steppingMin); |
89 | if (e.restrictions.steppingMax >= 0) |
90 | ok = ok && (pattern.stepping < e.restrictions.steppingMax); |
91 | ok = ok && layoutMatch(e.selector.layouts[0], pattern.selector.layouts[0]); |
92 | ok = ok && layoutMatch(e.selector.layouts[1], pattern.selector.layouts[1]); |
93 | ok = ok && layoutMatch(e.selector.layouts[2], pattern.selector.layouts[2]); |
94 | ok = ok |
95 | && precisionMatch( |
96 | e.selector.precisions[2], pattern.selector.precisions[2]); |
97 | if (pattern.precisionCExt) |
98 | ok = ok |
99 | && precisionMinimumMatch( |
100 | e.selector.precisions[2], pattern.precisionCExt); |
101 | for (int i = 0; i < 3; i++) |
102 | ok = ok |
103 | && alignmentMatch( |
104 | e.restrictions.alignment[i], pattern.alignment[i]); |
105 | ok = ok && tagMatch(e.restrictions.tags, pattern.tags); |
106 | |
107 | for (int i = 0; i < 2; i++) |
108 | if (pattern.unroll[i] > 0) |
109 | ok = ok && (pattern.unroll[i] == e.driverInfo.unroll[i]); |
110 | |
111 | if (!pattern.ignoreSizes) { |
112 | int64_t mnk[3] = {pattern.sizes.m, pattern.sizes.n, pattern.sizes.k}; |
113 | for (int i = 0; i < 3; i++) { |
114 | if (e.restrictions.allowedSizesMin[i] >= 0) |
115 | ok = ok && (mnk[i] >= e.restrictions.allowedSizesMin[i]); |
116 | if (e.restrictions.allowedSizesMax[i] >= 0) |
117 | ok = ok && (mnk[i] <= e.restrictions.allowedSizesMax[i]); |
118 | } |
119 | } |
120 | |
121 | // Should already be matched. |
122 | ok = ok && (e.selector.hw == pattern.selector.hw); |
123 | ok = ok |
124 | && precisionMatch( |
125 | e.selector.precisions[0], pattern.selector.precisions[0]); |
126 | ok = ok |
127 | && precisionMatch( |
128 | e.selector.precisions[1], pattern.selector.precisions[1]); |
129 | |
130 | return ok; |
131 | } |
132 | |
133 | const kcatalog::Entry *select(const kcatalog::Catalog &catalog, |
134 | const MatchParams &pattern, const EvaluateParams &eparams, |
135 | EvaluateAuxOutput &aux) { |
136 | return select(catalog, 1, &pattern, eparams, aux); |
137 | } |
138 | |
139 | const kcatalog::Entry *select(const kcatalog::Catalog &catalog, int npatterns, |
140 | const MatchParams *patterns, const EvaluateParams &eparams, |
141 | EvaluateAuxOutput &aux) { |
142 | double bestScore = std::numeric_limits<double>::infinity(); |
143 | const kcatalog::Entry *bestEntry = nullptr; |
144 | int bestIPattern = -1; |
145 | |
146 | bool verbose = (get_verbose() >= 5); |
147 | |
148 | // TODO: omit evaluation if only one match, if aux output not needed. |
149 | for (int ipattern = 0; ipattern < npatterns; ipattern++) { |
150 | for (auto it = match(catalog, patterns[ipattern]); it; it++) { |
151 | EvaluateAuxOutput thisAux; |
152 | double score = evaluate(*it, eparams, thisAux); |
153 | if (score < bestScore) { |
154 | bestEntry = &*it; |
155 | bestScore = score; |
156 | bestIPattern = ipattern; |
157 | aux = thisAux; |
158 | } |
159 | if (verbose) { |
160 | const auto &info = it->driverInfo; |
161 | printf("onednn_verbose,info,gpu,gemm,consider:%dx%d,%dx%dx%d," |
162 | "score:%f\n" , |
163 | info.unroll[LoopM], info.unroll[LoopN], info.wg[LoopM], |
164 | info.wg[LoopN], info.wg[LoopK], score); |
165 | } |
166 | } |
167 | } |
168 | |
169 | // Late tag checking. If late tags do not match, we abandon the kernel and |
170 | // force the calling code to take another path. |
171 | if (bestEntry |
172 | && !tagMatch(bestEntry->restrictions.tags, |
173 | patterns[bestIPattern].lateTags)) |
174 | return nullptr; |
175 | |
176 | return bestEntry; |
177 | } |
178 | |
179 | template <bool upper> |
180 | const kcatalog::Entry *upper_lower_bound( |
181 | const kcatalog::Catalog &catalog, const kcatalog::Selector &selector) { |
182 | int n = catalog.entryCount; |
183 | const kcatalog::Entry *cur = catalog.entries; |
184 | |
185 | while (n > 0) { |
186 | auto half = n >> 1; |
187 | auto mid = cur + half; |
188 | if (upper ? (*mid <= selector) : (*mid < selector)) { |
189 | cur = mid + 1; |
190 | n = n - half - 1; |
191 | } else |
192 | n = half; |
193 | } |
194 | |
195 | return cur; |
196 | } |
197 | |
198 | const kcatalog::Entry *lower_bound( |
199 | const kcatalog::Catalog &catalog, const kcatalog::Selector &selector) { |
200 | return upper_lower_bound<false>(catalog, selector); |
201 | } |
202 | |
203 | const kcatalog::Entry *upper_bound( |
204 | const kcatalog::Catalog &catalog, const kcatalog::Selector &selector) { |
205 | return upper_lower_bound<true>(catalog, selector); |
206 | } |
207 | |
208 | MatchParams::MatchParams(ngen::HW hw, const GEMMProblem &problem) { |
209 | using namespace kcatalog; |
210 | |
211 | switch (hw) { |
212 | default: assert(!"Unknown architecture" ); |
213 | case ngen::HW::Gen9: selector.hw = kcatalog::HWTagGen9; break; |
214 | case ngen::HW::Gen11: selector.hw = kcatalog::HWTagGen11; break; |
215 | case ngen::HW::Gen12LP: selector.hw = kcatalog::HWTagGen12LP; break; |
216 | case ngen::HW::XeHP: selector.hw = kcatalog::HWTagXeHP; break; |
217 | case ngen::HW::XeHPG: selector.hw = kcatalog::HWTagXeHPG; break; |
218 | case ngen::HW::XeHPC: selector.hw = kcatalog::HWTagXeHPC; break; |
219 | } |
220 | |
221 | auto &C = problem.C; |
222 | auto equivCLayout = C.layout; |
223 | if (isPacked(equivCLayout)) { |
224 | bool colMajor = (C.layout == MatrixLayout::Pc) |
225 | ^ (C.crosspack * problem.Tc > 4); |
226 | equivCLayout = (colMajor ? MatrixLayout::N : MatrixLayout::T); |
227 | } |
228 | |
229 | selector.kernelType = "gemm" ; |
230 | |
231 | std::fill(temp.begin(), temp.end(), '\0'); |
232 | temp[0] = precisionChar(problem.Ta); |
233 | temp[2] = precisionChar(problem.Tb); |
234 | temp[4] = precisionChar(problem.Tc); |
235 | temp[6] = layoutChar(problem.A.layout); |
236 | temp[8] = layoutChar(problem.B.layout); |
237 | temp[10] = layoutChar(equivCLayout); |
238 | selector.precisions[0] = &temp[0]; |
239 | selector.precisions[1] = &temp[2]; |
240 | selector.precisions[2] = &temp[4]; |
241 | selector.layouts[0] = &temp[6]; |
242 | selector.layouts[1] = &temp[8]; |
243 | selector.layouts[2] = &temp[10]; |
244 | |
245 | precisionCExt = precisionChar(problem.Tc_ext); |
246 | |
247 | alignment[0] = problem.A.alignment; |
248 | alignment[1] = problem.B.alignment; |
249 | alignment[2] = problem.C.alignment; |
250 | |
251 | char *tagPtr = &temp[12]; |
252 | lateTags = tagPtr; |
253 | |
254 | // Late-only tags. Don't choose lower-performing kernels |
255 | // just to fuse reductions. Instead do reductions in a separate kernel. |
256 | if (problem.sumA) *tagPtr++ = ReqSumA; |
257 | if (problem.sumB) *tagPtr++ = ReqSumB; |
258 | |
259 | tags = tagPtr; |
260 | |
261 | if (problem.batch != BatchMode::None) { |
262 | *tagPtr++ = ReqBatch; |
263 | if (problem.batchDims > 1) *tagPtr++ = ReqBatchMultiDim; |
264 | } |
265 | |
266 | if (problem.abOffset != ABOffset::None) *tagPtr++ = ReqABOffset; |
267 | } |
268 | |
269 | } // namespace jit |
270 | } // namespace gpu |
271 | } // namespace impl |
272 | } // namespace dnnl |
273 | |