1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef KERNEL_CATALOG_HPP
18#define KERNEL_CATALOG_HPP
19
#include <cstddef>
#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

#include "gen_gemm_kernel_common.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace gpu {
29namespace jit {
30
31namespace kcatalog {
32
// There are two versions of the kernel catalog structures:
// - a mutable version, used inside ktool for easier modification;
// - an immutable version, which is an aggregate type and is used when the
//   catalog is loaded.
// Catalog strings are stored as raw C string pointers rather than std::string
// (presumably so catalog data can be initialized statically — TODO confirm).
// Note: inside namespace kcatalog, unqualified `string` names this alias, not
// std::string.
using string = const char *;
// DEFAULT(val) / DEFAULT3(v1, v2, v3) appear to annotate intended default
// member values, but here they expand to nothing, so the structures below
// have no default member initializers. Presumably the mutable (ktool)
// version supplies real defaults — TODO confirm.
// NOTE(review): these generic macro names are never #undef'd in this header,
// so they remain defined in every file that includes it.
#define DEFAULT(val)
#define DEFAULT3(v1, v2, v3)
40
// Conditions under which a catalog entry's kernel may be used.
// This is an aggregate: reordering fields would change the meaning of any
// brace-initialized catalog data.
// NOTE(review): DEFAULT/DEFAULT3 expand to nothing in this header, so the
// values written in them are documentation only; no initializers are emitted.
struct Restrictions {
    // If >= 0, minimum supported stepping (inclusive).
    int steppingMin DEFAULT(-1);
    // If >= 0, maximum supported stepping (exclusive).
    int steppingMax DEFAULT(-1);
    // m/n/k ranges where the kernel is always accepted
    // (see kernel_evaluator.cpp for more details).
    // A value of -1 presumably means "no bound" — confirm there.
    int acceptSizesMin[3] DEFAULT3(-1, -1, -1);
    int acceptSizesMax[3] DEFAULT3(-1, -1, -1);
    // m/n/k ranges outside of which the kernel is always rejected.
    int allowedSizesMin[3] DEFAULT3(-1, -1, -1);
    int allowedSizesMax[3] DEFAULT3(-1, -1, -1);
    // A/B/C alignment requirements (units are not stated here — TODO confirm
    // whether bytes or elements).
    int alignment[3] DEFAULT3(1, 1, 1);
    // Sequence of RestrictionTags characters; see RestrictionTags for entries.
    string tags;
};
56
// Single-character codes that may appear in Restrictions::tags.
// Codes come in pairs: the uppercase letter (Req...) requires the feature and
// the matching lowercase letter (ReqNo...) requires its absence.
enum RestrictionTags : char {
    // Block2D for A, B, C.
    ReqBlock2DA = 'A', ReqNoBlock2DA = 'a',
    ReqBlock2DB = 'B', ReqNoBlock2DB = 'b',
    ReqBlock2DC = 'C', ReqNoBlock2DC = 'c',

    // 64Bit for A, B, C.
    Req64BitA = 'X', ReqNo64BitA = 'x',
    Req64BitB = 'Y', ReqNo64BitB = 'y',
    Req64BitC = 'Z', ReqNo64BitC = 'z',

    // Batching.
    ReqBatch = 'V', ReqNoBatch = 'v',
    ReqBatchMultiDim = 'W', ReqNoBatchMultiDim = 'w',

    // A/B offsets and A/B sums.
    ReqABOffset = 'O', ReqNoABOffset = 'o',
    ReqSumA = 'Q', ReqNoSumA = 'q',
    ReqSumB = 'P', ReqNoSumB = 'p',

    // Custom requirement.
    ReqCustom1 = 'D', ReqNoCustom1 = 'd',
};
83
// Single-character hardware identifiers stored in Selector::hw.
// Selector's ordering compares this character by value.
enum HWTags : char {
    // Gen* tags.
    HWTagGen9 = '9', HWTagGen11 = 'B', HWTagGen12LP = 'C',
    // Xe* tags.
    HWTagXeHP = 'D', HWTagXeHPG = 'E', HWTagXeHPC = 'F',
};
92
93struct Selector {
94 char hw; // see HWTags for entries
95 string kernelType;
96 string precisions[3];
97 string layouts[3];
98
99 friend bool operator<(const Selector &sel1, const Selector &sel2) {
100 auto tupleize = [](const Selector &sel) {
101 return std::make_tuple(sel.hw, sel.precisions[0][0] & 0x1F,
102 sel.layouts[0][0], sel.layouts[1][0]);
103 };
104 return tupleize(sel1) < tupleize(sel2);
105 };
106 friend bool operator>(const Selector &sel1, const Selector &sel2) {
107 return sel2 < sel1;
108 }
109 friend bool operator<=(const Selector &sel1, const Selector &sel2) {
110 return !(sel2 < sel1);
111 }
112 friend bool operator>=(const Selector &sel1, const Selector &sel2) {
113 return !(sel1 < sel2);
114 }
115};
116
// Indices into Model::params. Each model numbers its parameters from 0, so
// enumerators that belong to different models share values. ParamWCount and
// ParamSCount give the number of parameters of each model.
enum : int {
    // Model 'W' parameters
    ParamWPriority = 0,
    ParamWCount,

    // Model 'S' parameters
    ParamS_Cm0 = 0, // Minimum constant overhead, beta = 0
    ParamS_Cm1, // Minimum constant overhead, beta = 1
    ParamS_C00, // Overhead per partial wave, constant coefficient, beta = 0
    ParamS_C01, // Overhead per partial wave, constant coefficient, beta = 1
    ParamS_C10, // Overhead per partial wave, linear coefficient, beta = 0
    ParamS_C11, // Overhead per partial wave, linear coefficient, beta = 1
    ParamS_Ma, // A per-element load cost
    ParamS_Mb, // B per-element load cost
    ParamS_Ef, // Peak efficiency, full waves
    ParamS_Ep0, // Peak efficiency, partial wave, constant coefficient
    ParamS_Ep1, // Peak efficiency, partial wave, linear coefficient
    ParamS_Em, // Load balancing weight factor (0 = ignore load balancing,
            // 1 = full weight for load balancing term)
    ParamS_Fp, // Max sustained frequency ratio nominal freq/actual freq.
    ParamS_Fr0, // FMA count at which frequency starts dropping.
    ParamS_Fr1, // FMA count at which frequency stops dropping.
    ParamSCount,

    // Maximum possible parameter count; used to size Model::params.
    // NOTE(review): this assumes model 'S' has the most parameters
    // (ParamSCount == 15 vs. ParamWCount == 1). Revisit if a model is added.
    MaxParamCount = ParamSCount,
};
143
// Performance-model description attached to a catalog entry.
struct Model {
    // Model type. The parameter-index enum defines layouts for models 'W' and
    // 'S' (presumably the only ids in use — TODO confirm).
    char id;
    // Presumably the number of meaningful leading entries in params — TODO
    // confirm.
    int paramCount;
    // Model parameters. Assumed to be indexed by the ParamW* / ParamS*
    // enumerators that match `id`; sized for the largest model.
    double params[MaxParamCount];
};
149
150struct Entry {
151 Selector selector;
152 Restrictions restrictions;
153 string strategy;
154 CommonDriverInfo driverInfo;
155 Model model;
156
157 friend bool operator<(const Entry &e1, const Entry &e2) {
158 return e1.selector < e2.selector;
159 }
160 friend bool operator>(const Entry &e1, const Entry &e2) {
161 return e1.selector > e2.selector;
162 }
163 friend bool operator<=(const Entry &e1, const Entry &e2) {
164 return e1.selector <= e2.selector;
165 }
166 friend bool operator>=(const Entry &e1, const Entry &e2) {
167 return e1.selector >= e2.selector;
168 }
169 friend bool operator<(const Entry &e, const Selector &s) {
170 return e.selector < s;
171 }
172 friend bool operator>(const Entry &e, const Selector &s) {
173 return e.selector > s;
174 }
175 friend bool operator<=(const Entry &e, const Selector &s) {
176 return e.selector <= s;
177 }
178 friend bool operator>=(const Entry &e, const Selector &s) {
179 return e.selector >= s;
180 }
181 friend bool operator<(const Selector &s, const Entry &e) {
182 return s < e.selector;
183 }
184 friend bool operator>(const Selector &s, const Entry &e) {
185 return s > e.selector;
186 }
187 friend bool operator<=(const Selector &s, const Entry &e) {
188 return s <= e.selector;
189 }
190 friend bool operator>=(const Selector &s, const Entry &e) {
191 return s >= e.selector;
192 }
193};
194
195struct Catalog {
196 static constexpr int currentVersion() { return 1; }
197
198 int version DEFAULT(currentVersion());
199 uint64_t revision DEFAULT(0);
200 int entryCount DEFAULT(0);
201
202 const Entry *entries;
203};
204
205template <size_t n>
206struct FlatCatalog {
207 int version;
208 uint64_t revision;
209 int entryCount;
210 Entry entries[n];
211
212 /* implicit */ operator Catalog() const {
213 Catalog catalog = {version, revision, entryCount, &entries[0]};
214 return catalog;
215 }
216};
217
218} /* namespace kcatalog */
219
220} // namespace jit
221} // namespace gpu
222} // namespace impl
223} // namespace dnnl
224
225#endif /* header guard */
226