1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "kernel_selector.hpp"
18#include "common/verbose.hpp"
19#include "kernel_evaluator.hpp"
20
21#include <cassert>
22#include <cctype>
23#include <cstring>
24
25namespace dnnl {
26namespace impl {
27namespace gpu {
28namespace jit {
29
// Compares a catalog entry's layout string with the requested layout string.
// Only the leading layout character is compared, which is a sufficient check
// for now.
inline bool layoutMatch(const char *lref, const char *lpattern) {
    const char refLayout = *lref;
    const char wantLayout = *lpattern;
    return refLayout == wantLayout;
}
33
// Case-insensitive comparison of two single precision characters.
// Clearing bit 0x20 folds ASCII lowercase letters onto their uppercase
// counterparts, so e.g. 's' and 'S' compare equal.
inline bool precisionMatch(char pref, char ppattern) {
    return (pref & ~0x20) == (ppattern & ~0x20);
}

// Checks whether the catalog precision string `pref` is compatible with the
// requested precision string `ppattern`.
//
// The leading character is accepted if any of the following holds:
//   - the pattern is the wildcard '?';
//   - it equals the pattern's leading character (case-insensitive);
//   - the pattern is bracketed ('[' ...) and its first inner character
//     equals the reference's leading character.
// If the reference itself is bracketed ("[xy...]"), its two inner characters
// must also match the pattern's (case-insensitive), and every remaining
// character must match exactly.
inline bool precisionMatch(const char *pref, const char *ppattern) {
    const bool ok = (ppattern[0] == '?') || precisionMatch(pref[0], ppattern[0])
            || (ppattern[0] == '[' && precisionMatch(pref[0], ppattern[1]));
    if (!ok || pref[0] != '[') return ok;

    // Bracketed reference: compare the inner characters. Bail out on the
    // first mismatch. The previous version kept scanning after a failure,
    // which could index past the terminator of a short pattern such as "?".
    if (!precisionMatch(pref[1], ppattern[1])
            || !precisionMatch(pref[2], ppattern[2]))
        return false;
    for (int i = 3; pref[i] != '\0'; i++)
        if (pref[i] != ppattern[i]) return false;
    return true;
}
56
// Returns true when the size ranking of precision character `pref[0]` is at
// least that of `pmin`. Characters are looked up by their low five bits, so
// upper- and lowercase letters are treated alike. Characters with no table
// entry (including '[') rank as 0.
inline bool precisionMinimumMatch(const char *pref, char pmin) {
    // Size per precision character (presumably in bytes), indexed by
    // (c & 0x1F). Slot 0 is '@', so 'A' is slot 1 and 'P' is slot 16.
    // The table is a compile-time constant rather than rebuilt on every call.
    static constexpr uint8_t sizeTable[0x20] = {
            // @  A  B  C  D  E  F  G  H  I  J  K  L  M  N  O
            0, 0, 2, 8, 8, 0, 0, 0, 2, 4, 4, 0, 0, 0, 0, 1,
            // P  Q  R  S  T  U  V  W  X  Y  Z
            0, 0, 0, 4, 4, 0, 0, 2, 0, 0, 16, 0, 0, 0, 0, 0};

    return sizeTable[pref[0] & 0x1F] >= sizeTable[pmin & 0x1F];
}
66
// Returns true when the requested alignment `apattern` is a multiple of the
// entry's required alignment `aref`. A required alignment of 0 is treated as
// 1, meaning no requirement.
inline bool alignmentMatch(int aref, int apattern) {
    const int required = (aref != 0) ? aref : 1;
    return apattern % required == 0;
}
71
// Checks an entry's tag string `tref` against the tag string `tpattern`.
// Each character of `tref` is interpreted by its ASCII case:
//   - uppercase tag: the tag must be present in `tpattern`;
//   - lowercase tag: the uppercase form must NOT be present in `tpattern`.
// Returns true only if every tag in `tref` satisfies its rule. An empty
// `tref` always matches.
// NOTE(review): assumes `tpattern` stores its tags in uppercase.
inline bool tagMatch(const char *tref, const char *tpattern) {
    for (auto c = tref; *c; c++) {
        int cu = *c & ~0x20; // toupper(c): clear the ASCII lowercase bit
        bool match = (std::strchr(tpattern, cu) != nullptr);
        bool wantMatch = (*c & 0x20) == 0; // uppercase tag -> must be present
        if (match != wantMatch) return false;
    }
    return true;
}
83
84bool matches(const kcatalog::Entry &e, const MatchParams &pattern) {
85 bool ok = true;
86
87 if (e.restrictions.steppingMin >= 0)
88 ok = ok && (pattern.stepping >= e.restrictions.steppingMin);
89 if (e.restrictions.steppingMax >= 0)
90 ok = ok && (pattern.stepping < e.restrictions.steppingMax);
91 ok = ok && layoutMatch(e.selector.layouts[0], pattern.selector.layouts[0]);
92 ok = ok && layoutMatch(e.selector.layouts[1], pattern.selector.layouts[1]);
93 ok = ok && layoutMatch(e.selector.layouts[2], pattern.selector.layouts[2]);
94 ok = ok
95 && precisionMatch(
96 e.selector.precisions[2], pattern.selector.precisions[2]);
97 if (pattern.precisionCExt)
98 ok = ok
99 && precisionMinimumMatch(
100 e.selector.precisions[2], pattern.precisionCExt);
101 for (int i = 0; i < 3; i++)
102 ok = ok
103 && alignmentMatch(
104 e.restrictions.alignment[i], pattern.alignment[i]);
105 ok = ok && tagMatch(e.restrictions.tags, pattern.tags);
106
107 for (int i = 0; i < 2; i++)
108 if (pattern.unroll[i] > 0)
109 ok = ok && (pattern.unroll[i] == e.driverInfo.unroll[i]);
110
111 if (!pattern.ignoreSizes) {
112 int64_t mnk[3] = {pattern.sizes.m, pattern.sizes.n, pattern.sizes.k};
113 for (int i = 0; i < 3; i++) {
114 if (e.restrictions.allowedSizesMin[i] >= 0)
115 ok = ok && (mnk[i] >= e.restrictions.allowedSizesMin[i]);
116 if (e.restrictions.allowedSizesMax[i] >= 0)
117 ok = ok && (mnk[i] <= e.restrictions.allowedSizesMax[i]);
118 }
119 }
120
121 // Should already be matched.
122 ok = ok && (e.selector.hw == pattern.selector.hw);
123 ok = ok
124 && precisionMatch(
125 e.selector.precisions[0], pattern.selector.precisions[0]);
126 ok = ok
127 && precisionMatch(
128 e.selector.precisions[1], pattern.selector.precisions[1]);
129
130 return ok;
131}
132
133const kcatalog::Entry *select(const kcatalog::Catalog &catalog,
134 const MatchParams &pattern, const EvaluateParams &eparams,
135 EvaluateAuxOutput &aux) {
136 return select(catalog, 1, &pattern, eparams, aux);
137}
138
139const kcatalog::Entry *select(const kcatalog::Catalog &catalog, int npatterns,
140 const MatchParams *patterns, const EvaluateParams &eparams,
141 EvaluateAuxOutput &aux) {
142 double bestScore = std::numeric_limits<double>::infinity();
143 const kcatalog::Entry *bestEntry = nullptr;
144 int bestIPattern = -1;
145
146 bool verbose = (get_verbose() >= 5);
147
148 // TODO: omit evaluation if only one match, if aux output not needed.
149 for (int ipattern = 0; ipattern < npatterns; ipattern++) {
150 for (auto it = match(catalog, patterns[ipattern]); it; it++) {
151 EvaluateAuxOutput thisAux;
152 double score = evaluate(*it, eparams, thisAux);
153 if (score < bestScore) {
154 bestEntry = &*it;
155 bestScore = score;
156 bestIPattern = ipattern;
157 aux = thisAux;
158 }
159 if (verbose) {
160 const auto &info = it->driverInfo;
161 printf("onednn_verbose,info,gpu,gemm,consider:%dx%d,%dx%dx%d,"
162 "score:%f\n",
163 info.unroll[LoopM], info.unroll[LoopN], info.wg[LoopM],
164 info.wg[LoopN], info.wg[LoopK], score);
165 }
166 }
167 }
168
169 // Late tag checking. If late tags do not match, we abandon the kernel and
170 // force the calling code to take another path.
171 if (bestEntry
172 && !tagMatch(bestEntry->restrictions.tags,
173 patterns[bestIPattern].lateTags))
174 return nullptr;
175
176 return bestEntry;
177}
178
179template <bool upper>
180const kcatalog::Entry *upper_lower_bound(
181 const kcatalog::Catalog &catalog, const kcatalog::Selector &selector) {
182 int n = catalog.entryCount;
183 const kcatalog::Entry *cur = catalog.entries;
184
185 while (n > 0) {
186 auto half = n >> 1;
187 auto mid = cur + half;
188 if (upper ? (*mid <= selector) : (*mid < selector)) {
189 cur = mid + 1;
190 n = n - half - 1;
191 } else
192 n = half;
193 }
194
195 return cur;
196}
197
198const kcatalog::Entry *lower_bound(
199 const kcatalog::Catalog &catalog, const kcatalog::Selector &selector) {
200 return upper_lower_bound<false>(catalog, selector);
201}
202
203const kcatalog::Entry *upper_bound(
204 const kcatalog::Catalog &catalog, const kcatalog::Selector &selector) {
205 return upper_lower_bound<true>(catalog, selector);
206}
207
// Builds catalog match parameters describing `problem` on hardware `hw`.
//
// The precision/layout strings referenced by `selector` and the tag strings
// `lateTags` and `tags` all point into this object's own `temp` buffer.
// NOTE(review): a copied MatchParams would therefore still point into the
// source object's buffer unless the copy operations re-point them. Confirm
// against the class definition before copying these objects.
MatchParams::MatchParams(ngen::HW hw, const GEMMProblem &problem) {
    using namespace kcatalog;

    // Map the nGEN hardware enum onto the catalog hardware tag.
    // NOTE: `default` comes before the Gen9 case and has no `break`. With
    // asserts disabled, an unknown architecture falls through to Gen9.
    switch (hw) {
        default: assert(!"Unknown architecture");
        case ngen::HW::Gen9: selector.hw = kcatalog::HWTagGen9; break;
        case ngen::HW::Gen11: selector.hw = kcatalog::HWTagGen11; break;
        case ngen::HW::Gen12LP: selector.hw = kcatalog::HWTagGen12LP; break;
        case ngen::HW::XeHP: selector.hw = kcatalog::HWTagXeHP; break;
        case ngen::HW::XeHPG: selector.hw = kcatalog::HWTagXeHPG; break;
        case ngen::HW::XeHPC: selector.hw = kcatalog::HWTagXeHPC; break;
    }

    // For selection, a packed C layout is replaced by an equivalent plain
    // N or T layout. colMajor is true for Pc and false for other packed
    // layouts, inverted when C.crosspack * problem.Tc exceeds 4.
    // (Presumably that product is the crosspack width in bytes; TODO confirm.)
    auto &C = problem.C;
    auto equivCLayout = C.layout;
    if (isPacked(equivCLayout)) {
        bool colMajor = (C.layout == MatrixLayout::Pc)
                ^ (C.crosspack * problem.Tc > 4);
        equivCLayout = (colMajor ? MatrixLayout::N : MatrixLayout::T);
    }

    selector.kernelType = "gemm";

    // Lay out NUL-terminated one-character strings in `temp`:
    //   [0] A precision, [2] B precision, [4] C precision,
    //   [6] A layout,    [8] B layout,    [10] C layout.
    // The odd slots stay '\0' from the fill below, terminating each string.
    std::fill(temp.begin(), temp.end(), '\0');
    temp[0] = precisionChar(problem.Ta);
    temp[2] = precisionChar(problem.Tb);
    temp[4] = precisionChar(problem.Tc);
    temp[6] = layoutChar(problem.A.layout);
    temp[8] = layoutChar(problem.B.layout);
    temp[10] = layoutChar(equivCLayout);
    selector.precisions[0] = &temp[0];
    selector.precisions[1] = &temp[2];
    selector.precisions[2] = &temp[4];
    selector.layouts[0] = &temp[6];
    selector.layouts[1] = &temp[8];
    selector.layouts[2] = &temp[10];

    precisionCExt = precisionChar(problem.Tc_ext);

    alignment[0] = problem.A.alignment;
    alignment[1] = problem.B.alignment;
    alignment[2] = problem.C.alignment;

    // Tag strings start at temp[12]. Late-only tags are written first, and
    // `tags` begins right after them. As a result, `lateTags` spans the late
    // tags followed by all regular tags, while `tags` spans only the regular
    // ones. Both strings are NUL-terminated by the earlier fill.
    // NOTE(review): assumes `temp` holds at least 12 + (number of tags) + 1
    // characters.
    char *tagPtr = &temp[12];
    lateTags = tagPtr;

    // Late-only tags. Don't choose lower-performing kernels
    // just to fuse reductions. Instead do reductions in a separate kernel.
    if (problem.sumA) *tagPtr++ = ReqSumA;
    if (problem.sumB) *tagPtr++ = ReqSumB;

    tags = tagPtr;

    if (problem.batch != BatchMode::None) {
        *tagPtr++ = ReqBatch;
        if (problem.batchDims > 1) *tagPtr++ = ReqBatchMultiDim;
    }

    if (problem.abOffset != ABOffset::None) *tagPtr++ = ReqABOffset;
}
268
269} // namespace jit
270} // namespace gpu
271} // namespace impl
272} // namespace dnnl
273