1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "strategy_parser.hpp"
18#include "utils.hpp"
19
20#include <cctype>
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace jit {
26
27using namespace ngen;
28
29bool native64Bit(ngen::HW hw) {
30 EmulationStrategy emulate(hw);
31 return !emulate.emulate64;
32}
33
34AccessType getAccessType(char c) {
35 switch (std::tolower(c)) {
36 case 'b': return AccessType::Block;
37 case 'p': return AccessType::PseudoBlock;
38 case 's': return AccessType::Scattered;
39 case 'u': return AccessType::ChannelScattered;
40 case 'm': return AccessType::Block2D;
41 case 't': return AccessType::Block2DTranspose;
42 case 'v': return AccessType::Block2DVNNI;
43 default: throw std::runtime_error("Unknown access type.");
44 }
45}
46
// Maps a 2D-block access letter to the letter of a non-2D fallback:
//   'm' / 'v' (any case) -> 'b'
//   't'       (any case) -> 's'
// Every other character is returned unchanged (original case preserved).
char downgradeBlock2D(char c) {
    // std::tolower requires a value representable as unsigned char (or EOF);
    // passing a negative plain char is undefined behavior, so convert first.
    switch (std::tolower(static_cast<unsigned char>(c))) {
        case 'm':
        case 'v': return 'b';
        case 't': return 's';
        default: return c;
    }
}
55
56AddressBase getAddressBase(char c) {
57 switch (c) {
58 case 'a': return AddressBase::createA64(true);
59 case 'c': return AddressBase::createCC(0);
60 case 'm': return AddressBase::createSC(0);
61 case 's': return AddressBase::createBTS(0);
62 default: throw std::runtime_error("Unknown address space.");
63 }
64}
65
66CacheSettingsLSC getCaching(char l1, char l3) {
67 if (l1 == 'd' && l3 == 'd') return CacheSettingsLSC::Default;
68
69 bool l3cached = (l3 == 'c');
70 switch (l1) {
71 case 'u':
72 return l3cached ? CacheSettingsLSC::L1UC_L3C
73 : CacheSettingsLSC::L1UC_L3UC;
74 case 't':
75 case 'c':
76 return l3cached ? CacheSettingsLSC::L1C_L3C
77 : CacheSettingsLSC::L1C_L3UC;
78 case 's':
79 return l3cached ? CacheSettingsLSC::L1S_L3C
80 : CacheSettingsLSC::L1S_L3UC;
81 case 'b':
82 case 'i': return CacheSettingsLSC::L1IAR_L3C; break;
83 default: throw std::runtime_error("Unknown cache setting");
84 }
85}
86
87void getCaching(std::stringstream &s, MatrixAddressingStrategy &astrategy) {
88 auto &cachingR = astrategy.cachingR;
89 auto &cachingW = astrategy.cachingW;
90
91 cachingR = CacheSettingsLSC::L1C_L3C;
92 cachingW = CacheSettingsLSC::L1WB_L3WB;
93
94 if (s.peek() == '{') {
95 char eat, l1, l3;
96 s >> eat >> l1 >> l3 >> eat;
97 if (eat != '}' && eat != '/')
98 throw std::runtime_error("Invalid caching syntax");
99 cachingR = getCaching(l1, l3);
100 if (eat == '/') {
101 s >> l1 >> l3 >> eat;
102 if (eat != '}') throw std::runtime_error("Invalid caching syntax");
103 cachingW = getCaching(l1, l3);
104 }
105 }
106}
107
108void parseStrategy(const char *str, HW hw, const GEMMProblem &problem,
109 GEMMStrategy &strategy) {
110 std::stringstream s(str);
111 bool overrideFusedLoop = false;
112 bool gotSR = false;
113
114 char eat, asA, asB, asC, accessA, accessB, accessC;
115 char accessAUnaligned = '\0', accessBUnaligned = '\0';
116 char accessAPrefetch = 's', accessBPrefetch = 's', accessCPrefetch = 's';
117
118 s >> std::ws >> asA >> accessA;
119 if (s.peek() == '/') s >> eat >> accessAUnaligned;
120 s >> strategy.ka_load;
121 if (s.peek() == '/') s >> eat >> strategy.ka_load_masked;
122 if (s.peek() == 'x') s >> eat >> strategy.A_copies;
123 getCaching(s, strategy.A);
124 if (s.peek() == '+') {
125 strategy.prefetchA = 1;
126 s >> eat >> accessAPrefetch >> strategy.ka_prefetch;
127 if (s.peek() == ',') s >> eat >> strategy.ka_pfStride;
128 if (s.peek() == '@') s >> eat >> strategy.prefetchA;
129 if (s.peek() == '/')
130 s >> eat >> strategy.prefetchAMasked;
131 else
132 strategy.prefetchAMasked = strategy.prefetchA;
133 getCaching(s, strategy.A_prefetch);
134 }
135 s >> std::ws >> asB >> accessB;
136 if (s.peek() == '/') s >> eat >> accessBUnaligned;
137 s >> strategy.kb_load;
138 if (s.peek() == '/') s >> eat >> strategy.kb_load_masked;
139 if (s.peek() == 'x') s >> eat >> strategy.B_copies;
140 getCaching(s, strategy.B);
141 if (s.peek() == '+') {
142 strategy.prefetchB = 1;
143 s >> eat >> accessBPrefetch >> strategy.kb_prefetch;
144 if (s.peek() == ',') s >> eat >> strategy.kb_pfStride;
145 if (s.peek() == '@') s >> eat >> strategy.prefetchB;
146 if (s.peek() == '/')
147 s >> eat >> strategy.prefetchBMasked;
148 else
149 strategy.prefetchBMasked = strategy.prefetchB;
150 getCaching(s, strategy.B_prefetch);
151 }
152 s >> std::ws >> asC >> accessC;
153 getCaching(s, strategy.C);
154 if (s.peek() == '+') {
155 strategy.prefetchC = 1;
156 s >> eat >> accessCPrefetch;
157 if (s.peek() == '@') s >> eat >> strategy.prefetchC;
158 getCaching(s, strategy.C_prefetch);
159 }
160
161 if (!accessAUnaligned) accessAUnaligned = downgradeBlock2D(accessA);
162 if (!accessBUnaligned) accessBUnaligned = downgradeBlock2D(accessB);
163
164 strategy.A.base = strategy.A_prefetch.base = getAddressBase(asA);
165 strategy.B.base = strategy.B_prefetch.base = getAddressBase(asB);
166 strategy.C.base = strategy.C_prefetch.base = getAddressBase(asC);
167 strategy.CO.base = (hw >= HW::XeHPC) ? AddressBase::createA64(true)
168 : AddressBase::createBTS(0);
169 strategy.A.newDP = bool(std::isupper(accessA));
170 strategy.B.newDP = bool(std::isupper(accessB));
171 strategy.C.newDP = bool(std::isupper(accessC));
172 strategy.CO.newDP = strategy.C.newDP;
173 strategy.A.accessType = getAccessType(accessA);
174 strategy.B.accessType = getAccessType(accessB);
175 strategy.C.accessType = getAccessType(accessC);
176 strategy.unalignedAccA = getAccessType(accessAUnaligned);
177 strategy.unalignedAccB = getAccessType(accessBUnaligned);
178 strategy.A.cachingW = CacheSettingsLSC::Default;
179 strategy.B.cachingW = CacheSettingsLSC::Default;
180 strategy.A_prefetch.prefetch = true;
181 strategy.B_prefetch.prefetch = true;
182 strategy.C_prefetch.prefetch = true;
183 strategy.A_prefetch.newDP = bool(std::isupper(accessAPrefetch));
184 strategy.B_prefetch.newDP = bool(std::isupper(accessBPrefetch));
185 strategy.C_prefetch.newDP = bool(std::isupper(accessCPrefetch));
186 strategy.A_prefetch.accessType = getAccessType(accessAPrefetch);
187 strategy.B_prefetch.accessType = getAccessType(accessBPrefetch);
188 strategy.C_prefetch.accessType = getAccessType(accessCPrefetch);
189 strategy.A_prefetch.cachingW = CacheSettingsLSC::Default;
190 strategy.B_prefetch.cachingW = CacheSettingsLSC::Default;
191 strategy.C_prefetch.cachingW = CacheSettingsLSC::Default;
192
193 strategy.A.padded |= isPacked(problem.A.layout);
194 strategy.B.padded |= isPacked(problem.B.layout);
195 strategy.A_prefetch.padded |= isPacked(problem.A.layout);
196 strategy.B_prefetch.padded |= isPacked(problem.B.layout);
197
198 strategy.unroll[LoopK] = 1;
199 strategy.checkAdd32 = !native64Bit(hw) || (hw >= HW::XeHPC);
200 strategy.altCRemainder |= (strategy.C.accessType == AccessType::Block)
201 || strategy.kParallel;
202
203 while (!s.eof()) {
204 std::string mod;
205 s >> mod;
206 if (mod == "cs")
207 strategy.registerScheme = GEMMStrategy::CSeparate;
208 else if (mod == "acb")
209 strategy.registerScheme = GEMMStrategy::ACB;
210 else if (mod == "bca")
211 strategy.registerScheme = GEMMStrategy::BCA;
212 else if (mod == "vnc")
213 strategy.registerScheme = GEMMStrategy::VNC;
214 else if (mod == "int")
215 strategy.registerScheme = GEMMStrategy::ABInterleave;
216 else if (mod == "nse")
217 strategy.registerScheme = GEMMStrategy::NSeparate;
218 else if (mod == "vav")
219 strategy.registerScheme = GEMMStrategy::VAvoid;
220 else if (mod.substr(0, 3) == "grf") {
221 mod.erase(0, 3);
222 strategy.GRFs = std::stoi(mod);
223 } else if (mod == "sys")
224 strategy.systolic = true;
225 else if (mod == "dw")
226 strategy.dpasw = true;
227 else if (mod == "fs")
228 strategy.fixedSystolic = strategy.systolic = true;
229 else if (mod == "ar")
230 strategy.altCRemainder = true;
231 else if (mod == "sr") {
232 strategy.altCRemainder = false;
233 gotSR = true;
234 } else if (mod == "br")
235 strategy.block2DCRemainder = true;
236 else if (mod == "ac")
237 strategy.cAccumulators = true;
238 else if (mod == "el")
239 strategy.cLoadAhead = true;
240 else if (mod == "di")
241 strategy.delayABInc = true;
242 else if (mod == "sc")
243 strategy.splitCopy = true;
244 else if (mod == "sm")
245 strategy.coopA = CoopSplit::MN;
246 else if (mod == "sn")
247 strategy.coopB = CoopSplit::MN;
248 else if (mod == "ni")
249 strategy.slmUseIncrCopy = false;
250 else if (mod == "ek")
251 strategy.slmEarlyKMask = true;
252 else if (mod == "sf")
253 strategy.strictFence = true;
254 else if (mod == "ta")
255 strategy.slmATrans = true;
256 else if (mod == "tb")
257 strategy.slmBTrans = true;
258 else if (mod == "af")
259 strategy.atomicFMA = true;
260 else if (mod == "xaf")
261 strategy.atomicFMA = strategy.extendedAtomicFMA = true;
262 else if (mod == "st")
263 strategy.stallAfterLoad = true;
264 else if (mod == "ch")
265 strategy.checkAdd32 = true;
266 else if (mod == "ws")
267 strategy.wgInSS = true;
268 else if (mod == "wc")
269 strategy.C.smode = ScatterSIMD::Wide;
270 else if (mod == "cc")
271 strategy.forceCopyC = true;
272 else if (mod == "njs")
273 strategy.jointSplit = false;
274 else if (mod == "np") {
275 strategy.A.padded = strategy.A_prefetch.padded = false;
276 strategy.B.padded = strategy.B_prefetch.padded = false;
277 } else if (mod == "pab") {
278 strategy.A.padded = strategy.A_prefetch.padded = true;
279 strategy.B.padded = strategy.B_prefetch.padded = true;
280 } else if (mod == "pc")
281 strategy.C.padded = strategy.C_prefetch.padded = true;
282 else if (mod == "mnk") {
283 strategy.loopOrder[0] = LoopM;
284 strategy.loopOrder[1] = LoopN;
285 strategy.loopOrder[2] = LoopK;
286 } else if (mod == "nmk") {
287 strategy.loopOrder[0] = LoopN;
288 strategy.loopOrder[1] = LoopM;
289 strategy.loopOrder[2] = LoopK;
290 } else if (mod == "fm") {
291 strategy.fusedLoop = LoopM;
292 overrideFusedLoop = true;
293 } else if (mod == "fn") {
294 strategy.fusedLoop = LoopN;
295 overrideFusedLoop = true;
296 } else if (mod == "rm")
297 strategy.reverse[LoopM] = true;
298 else if (mod == "rn")
299 strategy.reverse[LoopN] = true;
300 else if (mod == "ql")
301 strategy.skewLocalIDs = true;
302 else if (mod == "kb") {
303 strategy.kParallel = true;
304 strategy.C.atomic = true;
305 strategy.CO.atomic = problem.sumA || problem.sumB;
306 if (strategy.CO.atomic)
307 strategy.CO.base = AddressBase::createA64(true);
308 } else if (mod == "kr")
309 strategy.kParallelLocal = true;
310 else if (mod == "au")
311 strategy.C.atomic = true;
312 else if (mod == "xp")
313 strategy.xParallel = true;
314 else if (mod == "ff")
315 strategy.forceWGUpdate = WGFixed;
316 else if (mod == "wg") {
317 char x;
318 s >> strategy.wg[LoopM];
319 s >> x;
320 s >> strategy.wg[LoopN];
321 strategy.wg[LoopK] = 0;
322 if (s.peek() == 'x') s >> x >> strategy.wg[LoopK];
323 } else if (mod == "nb") {
324 char x;
325 s >> strategy.namedBarriers[LoopM];
326 s >> std::ws >> x;
327 s >> strategy.namedBarriers[LoopN];
328 } else if (mod == "bo")
329 strategy.boustrophedon = true;
330 else if (mod == "hi")
331 strategy.hilbertOrder = true;
332 else if (mod == "pt")
333 strategy.persistent = true;
334 else if (mod == "pl") {
335 strategy.A_prefetch.prefetch = false;
336 strategy.B_prefetch.prefetch = false;
337 strategy.C_prefetch.prefetch = false;
338 } else if (mod.length() >= 2) {
339 if (mod.substr(0, 2) == "ms")
340 strategy.mSplitThresh = stoi(mod.substr(2));
341 else if (mod.substr(0, 2) == "ns")
342 strategy.nSplitThresh = stoi(mod.substr(2));
343 else if (mod.substr(0, 2) == "kc")
344 strategy.kChain = stoi(mod.substr(2));
345 else if (mod.substr(0, 2) == "ks") {
346 char eat;
347 std::stringstream ms(mod);
348 ms >> eat >> eat >> strategy.unrollKSLM;
349 if (!ms.eof() && (ms.peek() == '/'))
350 ms >> eat >> strategy.unrollKSLMMasked;
351 } else if (mod.substr(0, 2) == "sb") {
352 strategy.barrierFreq = stoi(mod.substr(2));
353 strategy.splitBarrier = true;
354 } else
355 switch (mod[0]) {
356 case 'b':
357 if (isdigit(mod[1]))
358 strategy.barrierFreq = stoi(mod.substr(1));
359 else {
360 LoopType loop;
361 switch (mod[1]) {
362 case 'm': loop = LoopM; break;
363 case 'n': loop = LoopN; break;
364 case 'k': loop = LoopK; break;
365 default:
366 throw std::runtime_error(
367 "Unknown strategy modifier.");
368 }
369 size_t alt;
370 strategy.blocking[loop] = stoi(mod.substr(2), &alt);
371 if (strategy.blocking[loop] == 0)
372 strategy.blocking[loop] = 16777216;
373 alt += 3;
374 if (mod.length() > alt)
375 strategy.blockingAlt[loop]
376 = stoi(mod.substr(alt));
377 }
378 break;
379 case 'c': {
380 mod.erase(0, 1);
381 if (mod[0] == 'a') {
382 mod.erase(0, 1);
383 strategy.slmA = true;
384 }
385 if (mod[0] == 'b') {
386 mod.erase(0, 1);
387 strategy.slmB = true;
388 }
389 std::stringstream ms(mod);
390 ms >> strategy.slmBuffers;
391 ms >> eat;
392 if (!ms.eof()) ms >> strategy.slmCopies;
393 break;
394 }
395 case 'k': {
396 char eat;
397 std::stringstream ms(mod);
398 ms >> eat >> strategy.unroll[LoopK];
399 if (!ms.eof() && (ms.peek() == '/'))
400 ms >> eat >> strategy.unrollK_masked;
401 break;
402 }
403 case 'l': strategy.optAlignAB = stoi(mod.substr(1)); break;
404 default:
405 throw std::runtime_error("Unknown strategy modifier.");
406 }
407 } else if (!mod.empty())
408 throw std::runtime_error("Unknown strategy modifier.");
409 }
410
411 if (!overrideFusedLoop) {
412 if (strategy.fused) {
413 if (strategy.wg[LoopM] == 1)
414 strategy.fusedLoop = LoopN;
415 else if (strategy.wg[LoopN] == 1)
416 strategy.fusedLoop = LoopM;
417 else
418 strategy.fusedLoop = strategy.loopOrder[0];
419 } else
420 strategy.fusedLoop = strategy.loopOrder[0];
421 }
422
423 if (strategy.ka_pfStride == 0) strategy.ka_pfStride = strategy.ka_prefetch;
424 if (strategy.kb_pfStride == 0) strategy.kb_pfStride = strategy.kb_prefetch;
425
426 if (strategy.block2DCRemainder && !gotSR) strategy.altCRemainder = true;
427
428 int poCount = problem.postOps.len();
429 strategy.binary.resize(poCount);
430 for (auto &astrategy : strategy.binary) {
431 astrategy.base = (hw >= HW::XeHPC) ? AddressBase::createA64(true)
432 : AddressBase::createBTS(0);
433 astrategy.newDP = strategy.C.newDP;
434 }
435}
436
437void adjustStrategy(HW hw, const GEMMProblem &problem, GEMMStrategy &strategy) {
438 auto *gemmAStrategy = &strategy.A, *gemmBStrategy = &strategy.B;
439
440 // 2D block accesses use 2D addressing where supported.
441 strategy.A.address2D
442 |= isBlock2D(strategy.A.accessType) && !isPacked(problem.A.layout);
443 strategy.B.address2D
444 |= isBlock2D(strategy.B.accessType) && !isPacked(problem.B.layout);
445 strategy.C.address2D
446 |= isBlock2D(strategy.C.accessType) && !isPacked(problem.C.layout);
447 strategy.A_prefetch.address2D |= isBlock2D(strategy.A_prefetch.accessType)
448 && !isPacked(problem.A.layout);
449 strategy.B_prefetch.address2D |= isBlock2D(strategy.B_prefetch.accessType)
450 && !isPacked(problem.B.layout);
451 strategy.C_prefetch.address2D |= isBlock2D(strategy.C_prefetch.accessType)
452 && !isPacked(problem.C.layout);
453
454 // No need to use split remainder handling for 2D block accesses as there's no penalty for masking.
455 if (isBlock2D(strategy.A.accessType)
456 && (!strategy.prefetchA
457 || isBlock2D(strategy.A_prefetch.accessType)))
458 strategy.remHandling[LoopM] = RemainderHandling::General;
459 if (isBlock2D(strategy.B.accessType)
460 && (!strategy.prefetchB
461 || isBlock2D(strategy.B_prefetch.accessType)))
462 strategy.remHandling[LoopN] = RemainderHandling::General;
463
464 // Also don't split remainder handling if padded.
465 if (gemmAStrategy->padded)
466 strategy.remHandling[LoopM] = RemainderHandling::General;
467 if (gemmBStrategy->padded)
468 strategy.remHandling[LoopN] = RemainderHandling::General;
469
470 // But always use split remainder handling when prefetching C if it _isn't_ block 2D
471 // ... in that case there are no C prefetches on the remainder path.
472 if (strategy.prefetchC && !isBlock2D(strategy.C_prefetch.accessType))
473 strategy.remHandling[LoopM] = strategy.remHandling[LoopN]
474 = RemainderHandling::Split;
475}
476
477} // namespace jit
478} // namespace gpu
479} // namespace impl
480} // namespace dnnl
481