1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "strategy_parser.hpp" |
18 | #include "utils.hpp" |
19 | |
20 | #include <cctype> |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace jit { |
26 | |
27 | using namespace ngen; |
28 | |
29 | bool native64Bit(ngen::HW hw) { |
30 | EmulationStrategy emulate(hw); |
31 | return !emulate.emulate64; |
32 | } |
33 | |
34 | AccessType getAccessType(char c) { |
35 | switch (std::tolower(c)) { |
36 | case 'b': return AccessType::Block; |
37 | case 'p': return AccessType::PseudoBlock; |
38 | case 's': return AccessType::Scattered; |
39 | case 'u': return AccessType::ChannelScattered; |
40 | case 'm': return AccessType::Block2D; |
41 | case 't': return AccessType::Block2DTranspose; |
42 | case 'v': return AccessType::Block2DVNNI; |
43 | default: throw std::runtime_error("Unknown access type." ); |
44 | } |
45 | } |
46 | |
// Maps a 2D-block access-type letter to the letter of its non-2D fallback.
//
// 'm' and 'v' (either case) become 'b'; 't' (either case) becomes 's'.
// Every other character is returned unchanged, including its case.
// Note that the replacement letters are always lowercase.
char downgradeBlock2D(char c) {
    // <cctype> functions require a value representable as unsigned char (or
    // EOF); passing a negative plain char is undefined behavior.
    switch (std::tolower(static_cast<unsigned char>(c))) {
        case 'm':
        case 'v': return 'b';
        case 't': return 's';
        default: return c;
    }
}
55 | |
56 | AddressBase getAddressBase(char c) { |
57 | switch (c) { |
58 | case 'a': return AddressBase::createA64(true); |
59 | case 'c': return AddressBase::createCC(0); |
60 | case 'm': return AddressBase::createSC(0); |
61 | case 's': return AddressBase::createBTS(0); |
62 | default: throw std::runtime_error("Unknown address space." ); |
63 | } |
64 | } |
65 | |
66 | CacheSettingsLSC getCaching(char l1, char l3) { |
67 | if (l1 == 'd' && l3 == 'd') return CacheSettingsLSC::Default; |
68 | |
69 | bool l3cached = (l3 == 'c'); |
70 | switch (l1) { |
71 | case 'u': |
72 | return l3cached ? CacheSettingsLSC::L1UC_L3C |
73 | : CacheSettingsLSC::L1UC_L3UC; |
74 | case 't': |
75 | case 'c': |
76 | return l3cached ? CacheSettingsLSC::L1C_L3C |
77 | : CacheSettingsLSC::L1C_L3UC; |
78 | case 's': |
79 | return l3cached ? CacheSettingsLSC::L1S_L3C |
80 | : CacheSettingsLSC::L1S_L3UC; |
81 | case 'b': |
82 | case 'i': return CacheSettingsLSC::L1IAR_L3C; break; |
83 | default: throw std::runtime_error("Unknown cache setting" ); |
84 | } |
85 | } |
86 | |
87 | void getCaching(std::stringstream &s, MatrixAddressingStrategy &astrategy) { |
88 | auto &cachingR = astrategy.cachingR; |
89 | auto &cachingW = astrategy.cachingW; |
90 | |
91 | cachingR = CacheSettingsLSC::L1C_L3C; |
92 | cachingW = CacheSettingsLSC::L1WB_L3WB; |
93 | |
94 | if (s.peek() == '{') { |
95 | char eat, l1, l3; |
96 | s >> eat >> l1 >> l3 >> eat; |
97 | if (eat != '}' && eat != '/') |
98 | throw std::runtime_error("Invalid caching syntax" ); |
99 | cachingR = getCaching(l1, l3); |
100 | if (eat == '/') { |
101 | s >> l1 >> l3 >> eat; |
102 | if (eat != '}') throw std::runtime_error("Invalid caching syntax" ); |
103 | cachingW = getCaching(l1, l3); |
104 | } |
105 | } |
106 | } |
107 | |
108 | void parseStrategy(const char *str, HW hw, const GEMMProblem &problem, |
109 | GEMMStrategy &strategy) { |
110 | std::stringstream s(str); |
111 | bool overrideFusedLoop = false; |
112 | bool gotSR = false; |
113 | |
114 | char eat, asA, asB, asC, accessA, accessB, accessC; |
115 | char accessAUnaligned = '\0', accessBUnaligned = '\0'; |
116 | char accessAPrefetch = 's', accessBPrefetch = 's', accessCPrefetch = 's'; |
117 | |
118 | s >> std::ws >> asA >> accessA; |
119 | if (s.peek() == '/') s >> eat >> accessAUnaligned; |
120 | s >> strategy.ka_load; |
121 | if (s.peek() == '/') s >> eat >> strategy.ka_load_masked; |
122 | if (s.peek() == 'x') s >> eat >> strategy.A_copies; |
123 | getCaching(s, strategy.A); |
124 | if (s.peek() == '+') { |
125 | strategy.prefetchA = 1; |
126 | s >> eat >> accessAPrefetch >> strategy.ka_prefetch; |
127 | if (s.peek() == ',') s >> eat >> strategy.ka_pfStride; |
128 | if (s.peek() == '@') s >> eat >> strategy.prefetchA; |
129 | if (s.peek() == '/') |
130 | s >> eat >> strategy.prefetchAMasked; |
131 | else |
132 | strategy.prefetchAMasked = strategy.prefetchA; |
133 | getCaching(s, strategy.A_prefetch); |
134 | } |
135 | s >> std::ws >> asB >> accessB; |
136 | if (s.peek() == '/') s >> eat >> accessBUnaligned; |
137 | s >> strategy.kb_load; |
138 | if (s.peek() == '/') s >> eat >> strategy.kb_load_masked; |
139 | if (s.peek() == 'x') s >> eat >> strategy.B_copies; |
140 | getCaching(s, strategy.B); |
141 | if (s.peek() == '+') { |
142 | strategy.prefetchB = 1; |
143 | s >> eat >> accessBPrefetch >> strategy.kb_prefetch; |
144 | if (s.peek() == ',') s >> eat >> strategy.kb_pfStride; |
145 | if (s.peek() == '@') s >> eat >> strategy.prefetchB; |
146 | if (s.peek() == '/') |
147 | s >> eat >> strategy.prefetchBMasked; |
148 | else |
149 | strategy.prefetchBMasked = strategy.prefetchB; |
150 | getCaching(s, strategy.B_prefetch); |
151 | } |
152 | s >> std::ws >> asC >> accessC; |
153 | getCaching(s, strategy.C); |
154 | if (s.peek() == '+') { |
155 | strategy.prefetchC = 1; |
156 | s >> eat >> accessCPrefetch; |
157 | if (s.peek() == '@') s >> eat >> strategy.prefetchC; |
158 | getCaching(s, strategy.C_prefetch); |
159 | } |
160 | |
161 | if (!accessAUnaligned) accessAUnaligned = downgradeBlock2D(accessA); |
162 | if (!accessBUnaligned) accessBUnaligned = downgradeBlock2D(accessB); |
163 | |
164 | strategy.A.base = strategy.A_prefetch.base = getAddressBase(asA); |
165 | strategy.B.base = strategy.B_prefetch.base = getAddressBase(asB); |
166 | strategy.C.base = strategy.C_prefetch.base = getAddressBase(asC); |
167 | strategy.CO.base = (hw >= HW::XeHPC) ? AddressBase::createA64(true) |
168 | : AddressBase::createBTS(0); |
169 | strategy.A.newDP = bool(std::isupper(accessA)); |
170 | strategy.B.newDP = bool(std::isupper(accessB)); |
171 | strategy.C.newDP = bool(std::isupper(accessC)); |
172 | strategy.CO.newDP = strategy.C.newDP; |
173 | strategy.A.accessType = getAccessType(accessA); |
174 | strategy.B.accessType = getAccessType(accessB); |
175 | strategy.C.accessType = getAccessType(accessC); |
176 | strategy.unalignedAccA = getAccessType(accessAUnaligned); |
177 | strategy.unalignedAccB = getAccessType(accessBUnaligned); |
178 | strategy.A.cachingW = CacheSettingsLSC::Default; |
179 | strategy.B.cachingW = CacheSettingsLSC::Default; |
180 | strategy.A_prefetch.prefetch = true; |
181 | strategy.B_prefetch.prefetch = true; |
182 | strategy.C_prefetch.prefetch = true; |
183 | strategy.A_prefetch.newDP = bool(std::isupper(accessAPrefetch)); |
184 | strategy.B_prefetch.newDP = bool(std::isupper(accessBPrefetch)); |
185 | strategy.C_prefetch.newDP = bool(std::isupper(accessCPrefetch)); |
186 | strategy.A_prefetch.accessType = getAccessType(accessAPrefetch); |
187 | strategy.B_prefetch.accessType = getAccessType(accessBPrefetch); |
188 | strategy.C_prefetch.accessType = getAccessType(accessCPrefetch); |
189 | strategy.A_prefetch.cachingW = CacheSettingsLSC::Default; |
190 | strategy.B_prefetch.cachingW = CacheSettingsLSC::Default; |
191 | strategy.C_prefetch.cachingW = CacheSettingsLSC::Default; |
192 | |
193 | strategy.A.padded |= isPacked(problem.A.layout); |
194 | strategy.B.padded |= isPacked(problem.B.layout); |
195 | strategy.A_prefetch.padded |= isPacked(problem.A.layout); |
196 | strategy.B_prefetch.padded |= isPacked(problem.B.layout); |
197 | |
198 | strategy.unroll[LoopK] = 1; |
199 | strategy.checkAdd32 = !native64Bit(hw) || (hw >= HW::XeHPC); |
200 | strategy.altCRemainder |= (strategy.C.accessType == AccessType::Block) |
201 | || strategy.kParallel; |
202 | |
203 | while (!s.eof()) { |
204 | std::string mod; |
205 | s >> mod; |
206 | if (mod == "cs" ) |
207 | strategy.registerScheme = GEMMStrategy::CSeparate; |
208 | else if (mod == "acb" ) |
209 | strategy.registerScheme = GEMMStrategy::ACB; |
210 | else if (mod == "bca" ) |
211 | strategy.registerScheme = GEMMStrategy::BCA; |
212 | else if (mod == "vnc" ) |
213 | strategy.registerScheme = GEMMStrategy::VNC; |
214 | else if (mod == "int" ) |
215 | strategy.registerScheme = GEMMStrategy::ABInterleave; |
216 | else if (mod == "nse" ) |
217 | strategy.registerScheme = GEMMStrategy::NSeparate; |
218 | else if (mod == "vav" ) |
219 | strategy.registerScheme = GEMMStrategy::VAvoid; |
220 | else if (mod.substr(0, 3) == "grf" ) { |
221 | mod.erase(0, 3); |
222 | strategy.GRFs = std::stoi(mod); |
223 | } else if (mod == "sys" ) |
224 | strategy.systolic = true; |
225 | else if (mod == "dw" ) |
226 | strategy.dpasw = true; |
227 | else if (mod == "fs" ) |
228 | strategy.fixedSystolic = strategy.systolic = true; |
229 | else if (mod == "ar" ) |
230 | strategy.altCRemainder = true; |
231 | else if (mod == "sr" ) { |
232 | strategy.altCRemainder = false; |
233 | gotSR = true; |
234 | } else if (mod == "br" ) |
235 | strategy.block2DCRemainder = true; |
236 | else if (mod == "ac" ) |
237 | strategy.cAccumulators = true; |
238 | else if (mod == "el" ) |
239 | strategy.cLoadAhead = true; |
240 | else if (mod == "di" ) |
241 | strategy.delayABInc = true; |
242 | else if (mod == "sc" ) |
243 | strategy.splitCopy = true; |
244 | else if (mod == "sm" ) |
245 | strategy.coopA = CoopSplit::MN; |
246 | else if (mod == "sn" ) |
247 | strategy.coopB = CoopSplit::MN; |
248 | else if (mod == "ni" ) |
249 | strategy.slmUseIncrCopy = false; |
250 | else if (mod == "ek" ) |
251 | strategy.slmEarlyKMask = true; |
252 | else if (mod == "sf" ) |
253 | strategy.strictFence = true; |
254 | else if (mod == "ta" ) |
255 | strategy.slmATrans = true; |
256 | else if (mod == "tb" ) |
257 | strategy.slmBTrans = true; |
258 | else if (mod == "af" ) |
259 | strategy.atomicFMA = true; |
260 | else if (mod == "xaf" ) |
261 | strategy.atomicFMA = strategy.extendedAtomicFMA = true; |
262 | else if (mod == "st" ) |
263 | strategy.stallAfterLoad = true; |
264 | else if (mod == "ch" ) |
265 | strategy.checkAdd32 = true; |
266 | else if (mod == "ws" ) |
267 | strategy.wgInSS = true; |
268 | else if (mod == "wc" ) |
269 | strategy.C.smode = ScatterSIMD::Wide; |
270 | else if (mod == "cc" ) |
271 | strategy.forceCopyC = true; |
272 | else if (mod == "njs" ) |
273 | strategy.jointSplit = false; |
274 | else if (mod == "np" ) { |
275 | strategy.A.padded = strategy.A_prefetch.padded = false; |
276 | strategy.B.padded = strategy.B_prefetch.padded = false; |
277 | } else if (mod == "pab" ) { |
278 | strategy.A.padded = strategy.A_prefetch.padded = true; |
279 | strategy.B.padded = strategy.B_prefetch.padded = true; |
280 | } else if (mod == "pc" ) |
281 | strategy.C.padded = strategy.C_prefetch.padded = true; |
282 | else if (mod == "mnk" ) { |
283 | strategy.loopOrder[0] = LoopM; |
284 | strategy.loopOrder[1] = LoopN; |
285 | strategy.loopOrder[2] = LoopK; |
286 | } else if (mod == "nmk" ) { |
287 | strategy.loopOrder[0] = LoopN; |
288 | strategy.loopOrder[1] = LoopM; |
289 | strategy.loopOrder[2] = LoopK; |
290 | } else if (mod == "fm" ) { |
291 | strategy.fusedLoop = LoopM; |
292 | overrideFusedLoop = true; |
293 | } else if (mod == "fn" ) { |
294 | strategy.fusedLoop = LoopN; |
295 | overrideFusedLoop = true; |
296 | } else if (mod == "rm" ) |
297 | strategy.reverse[LoopM] = true; |
298 | else if (mod == "rn" ) |
299 | strategy.reverse[LoopN] = true; |
300 | else if (mod == "ql" ) |
301 | strategy.skewLocalIDs = true; |
302 | else if (mod == "kb" ) { |
303 | strategy.kParallel = true; |
304 | strategy.C.atomic = true; |
305 | strategy.CO.atomic = problem.sumA || problem.sumB; |
306 | if (strategy.CO.atomic) |
307 | strategy.CO.base = AddressBase::createA64(true); |
308 | } else if (mod == "kr" ) |
309 | strategy.kParallelLocal = true; |
310 | else if (mod == "au" ) |
311 | strategy.C.atomic = true; |
312 | else if (mod == "xp" ) |
313 | strategy.xParallel = true; |
314 | else if (mod == "ff" ) |
315 | strategy.forceWGUpdate = WGFixed; |
316 | else if (mod == "wg" ) { |
317 | char x; |
318 | s >> strategy.wg[LoopM]; |
319 | s >> x; |
320 | s >> strategy.wg[LoopN]; |
321 | strategy.wg[LoopK] = 0; |
322 | if (s.peek() == 'x') s >> x >> strategy.wg[LoopK]; |
323 | } else if (mod == "nb" ) { |
324 | char x; |
325 | s >> strategy.namedBarriers[LoopM]; |
326 | s >> std::ws >> x; |
327 | s >> strategy.namedBarriers[LoopN]; |
328 | } else if (mod == "bo" ) |
329 | strategy.boustrophedon = true; |
330 | else if (mod == "hi" ) |
331 | strategy.hilbertOrder = true; |
332 | else if (mod == "pt" ) |
333 | strategy.persistent = true; |
334 | else if (mod == "pl" ) { |
335 | strategy.A_prefetch.prefetch = false; |
336 | strategy.B_prefetch.prefetch = false; |
337 | strategy.C_prefetch.prefetch = false; |
338 | } else if (mod.length() >= 2) { |
339 | if (mod.substr(0, 2) == "ms" ) |
340 | strategy.mSplitThresh = stoi(mod.substr(2)); |
341 | else if (mod.substr(0, 2) == "ns" ) |
342 | strategy.nSplitThresh = stoi(mod.substr(2)); |
343 | else if (mod.substr(0, 2) == "kc" ) |
344 | strategy.kChain = stoi(mod.substr(2)); |
345 | else if (mod.substr(0, 2) == "ks" ) { |
346 | char eat; |
347 | std::stringstream ms(mod); |
348 | ms >> eat >> eat >> strategy.unrollKSLM; |
349 | if (!ms.eof() && (ms.peek() == '/')) |
350 | ms >> eat >> strategy.unrollKSLMMasked; |
351 | } else if (mod.substr(0, 2) == "sb" ) { |
352 | strategy.barrierFreq = stoi(mod.substr(2)); |
353 | strategy.splitBarrier = true; |
354 | } else |
355 | switch (mod[0]) { |
356 | case 'b': |
357 | if (isdigit(mod[1])) |
358 | strategy.barrierFreq = stoi(mod.substr(1)); |
359 | else { |
360 | LoopType loop; |
361 | switch (mod[1]) { |
362 | case 'm': loop = LoopM; break; |
363 | case 'n': loop = LoopN; break; |
364 | case 'k': loop = LoopK; break; |
365 | default: |
366 | throw std::runtime_error( |
367 | "Unknown strategy modifier." ); |
368 | } |
369 | size_t alt; |
370 | strategy.blocking[loop] = stoi(mod.substr(2), &alt); |
371 | if (strategy.blocking[loop] == 0) |
372 | strategy.blocking[loop] = 16777216; |
373 | alt += 3; |
374 | if (mod.length() > alt) |
375 | strategy.blockingAlt[loop] |
376 | = stoi(mod.substr(alt)); |
377 | } |
378 | break; |
379 | case 'c': { |
380 | mod.erase(0, 1); |
381 | if (mod[0] == 'a') { |
382 | mod.erase(0, 1); |
383 | strategy.slmA = true; |
384 | } |
385 | if (mod[0] == 'b') { |
386 | mod.erase(0, 1); |
387 | strategy.slmB = true; |
388 | } |
389 | std::stringstream ms(mod); |
390 | ms >> strategy.slmBuffers; |
391 | ms >> eat; |
392 | if (!ms.eof()) ms >> strategy.slmCopies; |
393 | break; |
394 | } |
395 | case 'k': { |
396 | char eat; |
397 | std::stringstream ms(mod); |
398 | ms >> eat >> strategy.unroll[LoopK]; |
399 | if (!ms.eof() && (ms.peek() == '/')) |
400 | ms >> eat >> strategy.unrollK_masked; |
401 | break; |
402 | } |
403 | case 'l': strategy.optAlignAB = stoi(mod.substr(1)); break; |
404 | default: |
405 | throw std::runtime_error("Unknown strategy modifier." ); |
406 | } |
407 | } else if (!mod.empty()) |
408 | throw std::runtime_error("Unknown strategy modifier." ); |
409 | } |
410 | |
411 | if (!overrideFusedLoop) { |
412 | if (strategy.fused) { |
413 | if (strategy.wg[LoopM] == 1) |
414 | strategy.fusedLoop = LoopN; |
415 | else if (strategy.wg[LoopN] == 1) |
416 | strategy.fusedLoop = LoopM; |
417 | else |
418 | strategy.fusedLoop = strategy.loopOrder[0]; |
419 | } else |
420 | strategy.fusedLoop = strategy.loopOrder[0]; |
421 | } |
422 | |
423 | if (strategy.ka_pfStride == 0) strategy.ka_pfStride = strategy.ka_prefetch; |
424 | if (strategy.kb_pfStride == 0) strategy.kb_pfStride = strategy.kb_prefetch; |
425 | |
426 | if (strategy.block2DCRemainder && !gotSR) strategy.altCRemainder = true; |
427 | |
428 | int poCount = problem.postOps.len(); |
429 | strategy.binary.resize(poCount); |
430 | for (auto &astrategy : strategy.binary) { |
431 | astrategy.base = (hw >= HW::XeHPC) ? AddressBase::createA64(true) |
432 | : AddressBase::createBTS(0); |
433 | astrategy.newDP = strategy.C.newDP; |
434 | } |
435 | } |
436 | |
437 | void adjustStrategy(HW hw, const GEMMProblem &problem, GEMMStrategy &strategy) { |
438 | auto *gemmAStrategy = &strategy.A, *gemmBStrategy = &strategy.B; |
439 | |
440 | // 2D block accesses use 2D addressing where supported. |
441 | strategy.A.address2D |
442 | |= isBlock2D(strategy.A.accessType) && !isPacked(problem.A.layout); |
443 | strategy.B.address2D |
444 | |= isBlock2D(strategy.B.accessType) && !isPacked(problem.B.layout); |
445 | strategy.C.address2D |
446 | |= isBlock2D(strategy.C.accessType) && !isPacked(problem.C.layout); |
447 | strategy.A_prefetch.address2D |= isBlock2D(strategy.A_prefetch.accessType) |
448 | && !isPacked(problem.A.layout); |
449 | strategy.B_prefetch.address2D |= isBlock2D(strategy.B_prefetch.accessType) |
450 | && !isPacked(problem.B.layout); |
451 | strategy.C_prefetch.address2D |= isBlock2D(strategy.C_prefetch.accessType) |
452 | && !isPacked(problem.C.layout); |
453 | |
454 | // No need to use split remainder handling for 2D block accesses as there's no penalty for masking. |
455 | if (isBlock2D(strategy.A.accessType) |
456 | && (!strategy.prefetchA |
457 | || isBlock2D(strategy.A_prefetch.accessType))) |
458 | strategy.remHandling[LoopM] = RemainderHandling::General; |
459 | if (isBlock2D(strategy.B.accessType) |
460 | && (!strategy.prefetchB |
461 | || isBlock2D(strategy.B_prefetch.accessType))) |
462 | strategy.remHandling[LoopN] = RemainderHandling::General; |
463 | |
464 | // Also don't split remainder handling if padded. |
465 | if (gemmAStrategy->padded) |
466 | strategy.remHandling[LoopM] = RemainderHandling::General; |
467 | if (gemmBStrategy->padded) |
468 | strategy.remHandling[LoopN] = RemainderHandling::General; |
469 | |
470 | // But always use split remainder handling when prefetching C if it _isn't_ block 2D |
471 | // ... in that case there are no C prefetches on the remainder path. |
472 | if (strategy.prefetchC && !isBlock2D(strategy.C_prefetch.accessType)) |
473 | strategy.remHandling[LoopM] = strategy.remHandling[LoopN] |
474 | = RemainderHandling::Split; |
475 | } |
476 | |
477 | } // namespace jit |
478 | } // namespace gpu |
479 | } // namespace impl |
480 | } // namespace dnnl |
481 | |