1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <array> |
18 | #include <cstddef> |
19 | #include <functional> |
20 | #include <numeric> |
21 | #include <stdexcept> |
22 | #include <vector> |
23 | |
24 | #include "common/impl_registration.hpp" |
25 | #include "gpu/jit/gemm/gen_gemm_kernel_generator.hpp" |
26 | #include "gpu/jit/gemm/loop_sequencer.hpp" |
27 | #include "gpu/jit/gemm/utils.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace gpu { |
32 | namespace jit { |
33 | |
34 | using namespace ngen; |
35 | using namespace ngen::utils; |
36 | using dnnl::impl::utils::one_of; |
37 | using ngen::utils::log2; |
38 | |
39 | using std::complex; |
40 | using std::vector; |
41 | |
42 | #define MOCK_BARRIERS |
43 | |
44 | class need_vflag : public std::runtime_error { |
45 | public: |
    need_vflag() : std::runtime_error("Need virtual flag registers") {}
47 | }; |
48 | |
49 | class stub_exception : public std::runtime_error { |
50 | public: |
51 | stub_exception() |
52 | : std::runtime_error("Functionality not yet implemented" ) {} |
53 | }; |
54 | |
55 | class hw_unsupported_exception : public std::runtime_error { |
56 | public: |
57 | hw_unsupported_exception() |
58 | : std::runtime_error("Unsupported in hardware" ) {} |
59 | }; |
60 | |
61 | [[noreturn]] static void hw_unsupported() { |
62 | throw hw_unsupported_exception(); |
63 | } |
64 | |
65 | [[noreturn]] static void stub() { |
66 | throw stub_exception(); |
67 | } |
68 | |
69 | // Helpers |
70 | template <typename U> |
71 | static inline Immediate cast(Type T, U val) { |
72 | switch (T) { |
73 | case Type::f16: return half(val); |
74 | case Type::f32: return float(val); |
75 | case Type::u8: return uint8_t(val); |
76 | case Type::s8: return int8_t(val); |
77 | case Type::u16: return uint16_t(val); |
78 | case Type::s16: return int16_t(val); |
79 | case Type::u32: return uint32_t(val); |
80 | case Type::s32: return int32_t(val); |
81 | case Type::u64: return uint64_t(val); |
82 | case Type::s64: return int64_t(val); |
83 | case Type::bf16: |
84 | case Type::tf32: |
85 | default: stub(); |
86 | } |
87 | } |
88 | |
89 | static inline Immediate cast(Type T, Scalar<double> val) { |
90 | return cast(T, double(val)); |
91 | } |
92 | |
93 | bool Type::isSubsetOf(Type T) const { |
94 | if (*this == T) return true; |
95 | |
96 | if (isInteger() && T == bf16) return false; |
97 | |
98 | return (size() < T.size()); |
99 | } |
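
// Illustrative cases of the rule above (a sketch of the intended semantics):
//   Type::u8.isSubsetOf(Type::s16)  -> true  (smaller size, exact embedding)
//   Type::s32.isSubsetOf(Type::f32) -> false (same size; not every s32 value is exact in f32)
//   Type::s8.isSubsetOf(Type::bf16) -> false (integer -> bf16 is excluded explicitly)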
100 | |
101 | constexpr bool operator==(const RegData &rd, int i) { |
102 | return false; |
103 | } |
104 | constexpr bool operator==(const RegData &rd, const Immediate &i) { |
105 | return false; |
106 | } |
107 | constexpr bool operator!=(const RegData &rd, int i) { |
108 | return true; |
109 | } |
110 | constexpr bool operator!=(const RegData &rd, const Immediate &i) { |
111 | return true; |
112 | } |
113 | |
114 | void noop() {} |
115 | |
116 | static inline constexpr bool isGen9IGEMM(HW hw, Type Ta, Type Tb, Type Tc) { |
117 | return (hw < HW::Gen12LP && Ta.size() == 1 && Tb.size() == 1 |
118 | && Tc.size() == 4); |
119 | } |
120 | |
121 | template <typename T> |
122 | static inline constexpr int elementsPerGRF(HW hw) { |
123 | return GRF::bytes(hw) / sizeof(T); |
124 | } |
125 | |
126 | static inline constexpr int elementsPerGRF(HW hw, Type T) { |
127 | return GRF::bytes(hw) / T; |
128 | } |
129 | |
130 | static inline constexpr int elementsPerGRF(HW hw, DataType dt) { |
131 | return GRF::bytes(hw) / getBytes(dt); |
132 | } |
133 | |
134 | static inline bool canSwizzle(HW hw, DataType dt) { |
135 | if (hw < HW::XeHP) return true; |
136 | |
137 | switch (dt) { |
138 | case DataType::b: |
139 | case DataType::ub: |
140 | case DataType::w: |
141 | case DataType::uw: |
142 | case DataType::d: |
143 | case DataType::ud: return true; |
144 | case DataType::q: |
145 | case DataType::uq: return (hw >= HW::XeHPC); |
146 | default: return false; |
147 | } |
148 | } |
149 | |
150 | static inline bool canSwizzle(HW hw, Type T) { |
151 | return canSwizzle(hw, T.ngen()); |
152 | } |
153 | |
154 | static inline bool hasNativeAtomicAdd(HW hw, Type T, |
155 | const MatrixAddressing &atype, |
156 | const MatrixAddressingStrategy &astrategy) { |
157 | bool floatAtomics = (astrategy.base.getModel() == ModelA64); |
158 | if (astrategy.newDP) |
159 | floatAtomics |= (astrategy.base.getModel() != ModelSLM); |
160 | |
161 | if (T.isInteger()) |
162 | return true; |
163 | else if (T == Type::f32) |
164 | return floatAtomics && (hw >= HW::XeHP); |
165 | else |
166 | return false; |
167 | } |
168 | |
169 | static inline int slmCapacity(HW hw) { |
170 | switch (hw) { |
171 | case HW::Gen9: |
172 | case HW::Gen11: return 65536; |
173 | case HW::Gen12LP: |
174 | case HW::XeHP: |
175 | case HW::XeHPG: |
176 | case HW::XeHPC: return 131072; |
177 | default: return 0; |
178 | } |
179 | } |
180 | |
181 | static inline int threadsPerEU(HW hw, const CommonStrategy &strategy) { |
182 | if (hw >= HW::XeHP) |
183 | return (strategy.GRFs > 128) ? 4 : 8; |
184 | else |
185 | return 7; |
186 | } |
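
// Note: XeHP and later run 4 threads per EU instead of 8 when the kernel uses
// the large (256-register) GRF file, as reflected above.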
187 | |
188 | static inline int eusPerSubslice(HW hw) { |
189 | switch (hw) { |
190 | case HW::Gen9: |
191 | case HW::Gen11: |
192 | case HW::XeHPC: return 8; |
193 | case HW::Gen12LP: |
194 | case HW::XeHP: |
195 | case HW::XeHPG: return 16; |
196 | default: return 0; |
197 | } |
198 | } |
199 | |
200 | static inline bool canDualGRF( |
201 | HW hw, DataType dt, const CommonStrategy &strategy) { |
202 | return (strategy.dualGRF && (elementsPerGRF(hw, dt) < 32)); |
203 | } |
204 | |
205 | // Perform a binary register-wise operation. |
206 | template <typename F> |
207 | static inline void map(HW hw, DataType dt, const GRFMultirange &r1, |
208 | const GRFMultirange &r2, const CommonStrategy &strategy, F f) { |
209 | int ne = elementsPerGRF(hw, dt); |
210 | int rstride = canDualGRF(hw, dt, strategy) ? 2 : 1; |
211 | int len = r1.getLen(); |
212 | |
213 | for (int rr = 0; rr < len;) { |
214 | int nr = std::min<int>(len - rr, rstride); |
215 | if (!r1.contiguous(rr, nr) || !r2.contiguous(rr, nr)) nr = 1; |
216 | f(nr * ne, r1[rr].retype(dt), r2[rr].retype(dt)); |
217 | rr += nr; |
218 | } |
219 | } |
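
// Illustrative use of the binary map (a sketch; r1/r2 stand for any two
// congruent ranges):
//   map(hw, DataType::f, r1, r2, strategy,
//           [&](int esize, GRF reg1, GRF reg2) { add(esize, reg1, reg1, reg2); });
// The functor receives the execution size in elements plus one GRF per range,
// already retyped to dt; dual-GRF chunks are used where the strategy allows.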
220 | |
221 | // Perform a ternary register-wise operation. |
222 | template <typename F> |
223 | static inline void map(HW hw, DataType dt, const GRFMultirange &r1, |
224 | const GRFMultirange &r2, const GRFMultirange &r3, |
225 | const CommonStrategy &strategy, F f) { |
226 | int ne = elementsPerGRF(hw, dt); |
227 | int rstride = canDualGRF(hw, dt, strategy) ? 2 : 1; |
228 | int len = r1.getLen(); |
229 | |
230 | for (int rr = 0; rr < len;) { |
231 | int nr = std::min<int>(len - rr, rstride); |
232 | if (!r1.contiguous(rr, nr) || !r2.contiguous(rr, nr) |
233 | || !r3.contiguous(rr, nr)) |
234 | nr = 1; |
235 | f(nr * ne, r1[rr].retype(dt), r2[rr].retype(dt), r3[rr].retype(dt)); |
236 | rr += nr; |
237 | } |
238 | } |
239 | |
240 | // Perform a quaternary register-wise operation. |
241 | template <typename F> |
242 | static inline void map(HW hw, DataType dt, const GRFMultirange &r1, |
243 | const GRFMultirange &r2, const GRFMultirange &r3, |
244 | const GRFMultirange &r4, const CommonStrategy &strategy, F f) { |
245 | int ne = elementsPerGRF(hw, dt); |
246 | int rstride = canDualGRF(hw, dt, strategy) ? 2 : 1; |
247 | int len = r1.getLen(); |
248 | |
249 | for (int rr = 0; rr < len;) { |
250 | int nr = std::min<int>(len - rr, rstride); |
251 | if (!r1.contiguous(rr, nr) || !r2.contiguous(rr, nr) |
252 | || !r3.contiguous(rr, nr) || !r4.contiguous(rr, nr)) |
253 | nr = 1; |
254 | f(nr * ne, r1[rr].retype(dt), r2[rr].retype(dt), r3[rr].retype(dt), |
255 | r4[rr].retype(dt)); |
256 | rr += nr; |
257 | } |
258 | } |
259 | |
260 | // Perform a unary register-wise operation on a register block. |
261 | template <typename F> |
262 | static inline void map(HW hw, DataType dt, const GRFMultirange ®s, |
263 | const vector<RegisterBlock> &layout, const CommonStrategy &strategy, |
264 | F f, int cxComponent = -1) { |
265 | int curReg = 0, curOff = 0, curBytes = 0; |
266 | auto ebytes = getBytes(dt); |
267 | |
268 | auto map1 = [&]() { |
269 | curOff &= -ebytes; |
270 | curBytes &= -ebytes; |
271 | while (curBytes) { |
272 | int maxBytes; |
273 | if (curOff & (GRF::bytes(hw) - 1)) |
274 | maxBytes = GRF::bytes(hw) - curOff; |
275 | else |
276 | maxBytes = (canDualGRF(hw, dt, strategy) ? 2 : 1) |
277 | * GRF::bytes(hw); |
278 | |
279 | auto nbytes = rounddown_pow2(std::min(maxBytes, curBytes)); |
280 | auto ne = std::min<int>(32, nbytes / ebytes); |
281 | nbytes = ne * ebytes; |
282 | |
283 | auto reg = regs[curOff >> GRF::log2Bytes(hw)].sub( |
284 | (curOff & (GRF::bytes(hw) - 1)) / ebytes, dt)(1); |
285 | |
286 | f(ne, reg); |
287 | |
288 | curBytes -= nbytes; |
289 | curOff += nbytes; |
290 | } |
291 | }; |
292 | |
293 | for (auto &block : layout) { |
294 | int endReg |
295 | = (curOff + curBytes + block.bytes - 1) >> GRF::log2Bytes(hw); |
296 | if ((block.offsetBytes == curOff + curBytes) |
297 | && regs.contiguous(curReg, endReg - curReg + 1)) |
298 | curBytes += block.bytes; |
299 | else { |
300 | map1(); |
301 | curOff = block.offsetBytes; |
302 | curReg = curOff >> GRF::log2Bytes(hw); |
303 | curBytes = block.bytes; |
304 | } |
305 | } |
306 | |
307 | map1(); |
308 | } |
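
// Note: blocks that are adjacent both in the layout and in the register file
// are fused before invoking f, so the functor sees maximal contiguous runs
// (split only at GRF or dual-GRF boundaries), not one call per RegisterBlock.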
309 | |
310 | template <typename T, typename F> |
311 | static inline void map(HW hw, const GRFMultirange &r1, const GRFMultirange &r2, |
312 | const CommonStrategy &strategy, F f) { |
313 | map(hw, getDataType<T>(), r1, r2, strategy, f); |
314 | } |
315 | |
316 | template <typename T, typename F> |
317 | static inline void map(HW hw, const GRFMultirange &r1, const GRFMultirange &r2, |
318 | const GRFMultirange &r3, const CommonStrategy &strategy, F f) { |
319 | map(hw, getDataType<T>(), r1, r2, r3, strategy, f); |
320 | } |
321 | |
322 | template <typename T, typename F> |
323 | static inline void map(HW hw, const GRFMultirange ®s, |
324 | const vector<RegisterBlock> &layout, const CommonStrategy &strategy, |
325 | F f) { |
326 | map(hw, getDataType<T>(), regs, layout, strategy, f); |
327 | } |
328 | |
329 | template <typename... Targs> |
330 | static inline void map(HW hw, Type T, Targs... args) { |
331 | map(hw, T.ngen(), args...); |
332 | } |
333 | |
334 | // Move subregister to another pipe. |
335 | static inline void movePipes(Subregister &s, bool sizeCanChange = true) { |
336 | DataType type = s.getType(); |
337 | |
338 | switch (type) { |
339 | case DataType::bf: |
340 | case DataType::hf: type = DataType::uw; break; |
341 | case DataType::tf32: |
342 | case DataType::f: type = DataType::ud; break; |
343 | case DataType::df: |
344 | if (sizeCanChange) type = DataType::ud; |
345 | break; |
346 | case DataType::w: |
347 | case DataType::uw: type = DataType::hf; break; |
348 | case DataType::d: |
349 | case DataType::ud: type = DataType::f; break; |
350 | case DataType::q: |
351 | case DataType::uq: |
352 | if (sizeCanChange) type = DataType::f; |
353 | break; |
354 | default: break; |
355 | } |
356 | |
357 | s = s.reinterpret(0, type); |
358 | } |
359 | |
360 | // Move subregister to integer pipe. |
361 | static inline void moveToIntPipe(Subregister &s) { |
362 | DataType type = s.getType(); |
363 | |
364 | switch (type) { |
365 | case DataType::bf: |
366 | case DataType::hf: type = DataType::uw; break; |
367 | case DataType::q: |
368 | case DataType::uq: |
369 | case DataType::f: |
370 | case DataType::tf32: |
371 | case DataType::df: type = DataType::ud; break; |
372 | default: break; |
373 | } |
374 | |
375 | s = s.reinterpret(0, type); |
376 | } |
377 | |
378 | static inline Type sintType(Type T) { |
379 | switch (T) { |
380 | case Type::s8: |
381 | case Type::u8: return Type::s8; |
382 | case Type::bf16: |
383 | case Type::f16: |
384 | case Type::u16: |
385 | case Type::s16: return Type::s16; |
386 | case Type::f32: |
387 | case Type::s32: |
388 | case Type::u32: return Type::s32; |
389 | default: return Type::invalid; |
390 | } |
391 | } |
392 | |
393 | // Move register region to integer pipe. |
394 | static inline void moveToIntPipe(int esize, RegData &s) { |
395 | switch (s.getType()) { |
396 | case DataType::bf: |
397 | case DataType::hf: s.setType(DataType::uw); break; |
398 | case DataType::q: |
399 | case DataType::uq: |
400 | case DataType::tf32: |
401 | case DataType::f: s.setType(DataType::ud); break; |
402 | case DataType::df: |
403 | s.setType(DataType::uq); |
404 | EmulationImplementation::makeDWPair(s, esize); |
405 | break; |
406 | default: break; |
407 | } |
408 | } |
409 | |
410 | void RegisterBlock::calcBytes( |
411 | Type T, const MatrixAddressingStrategy &astrategy) { |
412 | if (astrategy.newDP && astrategy.prefetch) |
413 | bytes = 0; |
414 | else |
415 | calcBytes(T); |
416 | } |
417 | |
418 | void RegisterBlock::calcBytes(Type T) { |
419 | if (cxComponent != Interleaved) T = T.real(); |
420 | bytes = align_up(colMajor ? nc : nr, crosspack) * ld * T; |
421 | if (isLoadBlock() && msgRegs == 0) |
422 | msgRegs = (bytes + (1 << log2GRFBytes) - 1) >> log2GRFBytes; |
423 | } |
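
// Worked example (assuming 32-byte GRFs): a column-major f32 block with
// nr = 8, nc = 4, ld = 8, crosspack = 1 occupies
//   align_up(4, 1) * 8 * 4 = 128 bytes = 4 registers.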
424 | |
425 | int RegisterBlock::nregs() const { |
426 | auto grfBytes = (1 << log2GRFBytes); |
427 | if (offsetBytes & (grfBytes - 1)) stub(); |
428 | return (bytes + grfBytes - 1) >> log2GRFBytes; |
429 | } |
430 | |
431 | int RegisterBlock::offsetReg() const { |
432 | auto grfBytes = (1 << log2GRFBytes); |
433 | if (offsetBytes & (grfBytes - 1)) stub(); |
434 | return offsetBytes >> log2GRFBytes; |
435 | } |
436 | |
437 | void RegisterBlock::simplify(Type T) { |
438 | // If block is completely crosspacked, convert to equivalent layout without crosspack. |
439 | if (crosspack == (colMajor ? nc : nr) && isLargeCrosspack(T, crosspack)) { |
440 | auto od = colMajor ? nr : nc; |
441 | if (ld == od) { |
442 | colMajor = !colMajor; |
443 | ld = crosspack; |
444 | crosspack = 1; |
445 | } |
446 | } |
447 | } |
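
// Example: a column-major block with crosspack == nc and ld == nr holds the
// same register contents as a row-major block with ld = crosspack and
// crosspack = 1; simplify() rewrites it into that cheaper description.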
448 | |
449 | GRFMultirange GRFMultirange::subrange( |
450 | HW hw, Type T, const RegisterBlock &block) const { |
451 | int ne = elementsPerGRF(hw, T); |
452 | int ldGRFs = div_up(block.ld, ne); |
453 | int ldUsedGRFs = div_up(block.colMajor ? block.nr : block.nc, ne); |
454 | int td = block.colMajor ? block.nc : block.nr; |
455 | |
456 | if (ldUsedGRFs >= ldGRFs) |
457 | return subrange(block.offsetReg(), block.nregs()); |
458 | else { |
459 | int offReg = block.offsetReg(); |
460 | GRFMultirange result = subrange(offReg, ldUsedGRFs); |
461 | for (int y = 1; y < td; y++) { |
462 | offReg += ldGRFs; |
463 | result.append(subrange(offReg, ldUsedGRFs)); |
464 | } |
465 | return result; |
466 | } |
467 | } |
468 | |
469 | // Make a RegisterBlock smaller by contracting the leading dimension, if possible. |
470 | void RegisterBlock::compact(Type T) { |
471 | auto newLD = std::max<int>( |
472 | roundup_pow2(colMajor ? nr : nc), (1 << log2GRFBytes) / T); |
473 | if (newLD < ld) { |
474 | ld = newLD; |
475 | calcBytes(T); |
476 | } |
477 | } |
478 | |
479 | static inline bool isTransposing(AccessType atype) { |
480 | if (atype == AccessType::Scattered) return true; |
481 | if (atype == AccessType::ChannelScattered) return true; |
482 | if (atype == AccessType::Block2DTranspose) return true; |
483 | return false; |
484 | } |
485 | |
486 | Subregister SubregisterPair::getReg(int idx) const { |
487 | auto r = regs[idx & 1]; |
488 | if (negative) r = -r; |
489 | return r; |
490 | } |
491 | |
492 | Subregister SubregisterPair::getRegAvoiding(HW hw, const RegData &rd) const { |
493 | if (Bundle::same_bank(hw, rd, regs[0])) |
494 | return getReg(1); |
495 | else |
496 | return getReg(0); |
497 | } |
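
// The two copies of a SubregisterPair are kept in opposite register banks
// (see duplicateScalar below), so choosing the copy whose bank differs from
// rd lets the consuming instruction avoid a same-bank read conflict.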
498 | |
499 | inline namespace { |
500 | template <typename T> |
501 | struct ACHelper { |
502 | static T avoidConflict(HW hw, const T &x, const RegData &other) { |
503 | return x; |
504 | } |
505 | }; |
506 | template <> |
507 | struct ACHelper<SubregisterPair> { |
508 | static Subregister avoidConflict( |
509 | HW hw, const SubregisterPair &x, const RegData &other) { |
510 | return x.getRegAvoiding(hw, other); |
511 | } |
512 | }; |
513 | template <typename T> |
514 | struct ACHelper<Scalar<T>> { |
515 | static Subregister avoidConflict( |
516 | HW hw, const Scalar<T> &x, const RegData &other) { |
517 | return x.getRegAvoiding(hw, other); |
518 | } |
519 | }; |
520 | } // namespace |
521 | template <typename T> |
522 | decltype(ACHelper<T>::avoidConflict(HW::Unknown, std::declval<T>(), RegData())) |
523 | avoidConflict(HW hw, const T &x, const RegData &other) { |
524 | return ACHelper<T>::avoidConflict(hw, x, other); |
525 | } |
526 | |
527 | FlagRegister VirtualFlag::toPhysical() const { |
528 | if (n == 2) |
529 | return FlagRegister(idx >> 1); |
530 | else |
531 | return FlagRegister::createFromIndex(idx); |
532 | } |
533 | |
534 | VirtualFlag VirtualFlagAllocator::allocVirtual(int n) { |
535 | if (!free) throw out_of_registers_exception(); |
536 | if (n > 2) stub(); |
537 | |
538 | uint32_t bmask = free; |
539 | if (n == 2) bmask = (bmask & (bmask >> 1)) & 0x55555555; |
540 | int base = bsf(bmask); |
541 | |
542 | VirtualFlag vflag {base, n}; |
543 | claim(vflag); |
544 | |
545 | return vflag; |
546 | } |
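
// For n == 2 the mask trick above keeps only even-indexed free bits whose odd
// neighbor is also free. Example: free = 0b1101 -> (free & (free >> 1)) =
// 0b0100, still 0b0100 after masking with 0x55555555, so pair {2, 3} is used.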
547 | |
548 | FlagRegister VirtualFlagAllocator::tryAlloc(int n) { |
549 | auto vflag = allocVirtual(n); |
550 | if (isVirtual(vflag)) { |
551 | release(vflag); |
552 | return FlagRegister {}; |
553 | } |
554 | |
555 | lock(vflag); |
556 | |
557 | return vflag.toPhysical(); |
558 | } |
559 | |
560 | FlagRegister VirtualFlagAllocator::alloc(int n) { |
561 | auto flag = tryAlloc(n); |
562 | if (flag.isInvalid()) throw out_of_registers_exception(); |
563 | |
564 | return flag; |
565 | } |
566 | |
567 | FlagRegister VirtualFlagAllocator::assignPhysical(VirtualFlag vflag) { |
568 | VirtualFlag pflag; |
569 | |
570 | // Is it already a physical flag register? |
571 | if (!isVirtual(vflag)) { |
572 | pflag = vflag; |
573 | } else { |
574 | // It's virtual. Starting at nextPhys, find an unlocked flag register. |
575 | for (int i = nextPhys; i < nextPhys + nflag; i++) { |
576 | if (i & (vflag.n - 1)) continue; |
577 | auto idx = i & (nflag - 1); |
578 | if (!(locked & mask(idx, vflag.n))) { |
579 | nextPhys = (idx + vflag.n) & (nflag - 1); |
580 | pflag = VirtualFlag {idx, vflag.n}; |
581 | break; |
582 | } |
583 | } |
584 | } |
585 | |
586 | if (!pflag) throw out_of_registers_exception(); |
587 | |
588 | return pflag.toPhysical(); |
589 | } |
590 | |
591 | static inline RegData getMaskFlag(VirtualFlag vflag, CommonState &state) { |
592 | if (state.vflagStorage.isValid()) |
593 | return state.vflagStorage[vflag.idx].reinterpret( |
594 | 0, vflag.n == 2 ? DataType::ud : DataType::uw); |
595 | else if (!state.raVFlag.isVirtual(vflag)) { |
596 | auto pflag = vflag.toPhysical(); |
597 | state.usePhysicalFlag(pflag); |
598 | return pflag; |
599 | } else |
600 | throw need_vflag(); |
601 | } |
602 | |
603 | template <HW hw> |
604 | FlagRegister gemm_kernel_generator_t<hw>::getPhysicalFlag( |
605 | VirtualFlag vflag, CommonState &state) { |
606 | VirtualFlag pflag; |
607 | |
608 | if (state.vflagStorage.isValid()) { |
609 | // Check if virtual flag is currently active. |
610 | int pidx = -1; |
611 | for (int i = 0; i < FlagRegister::subcount(hw); i++) |
612 | if (state.activeVFlags[i] == vflag) pidx = i; |
613 | |
614 | // If flag is not currently active, load it into a physical flag. |
615 | if (pidx == -1) { |
616 | auto freg = state.raVFlag.assignPhysical(vflag); |
617 | pidx = freg.index(); |
618 | mov(1, freg, getMaskFlag(vflag, state)); |
619 | for (int i = 0; i < int(vflag.n); i++) |
620 | state.activeVFlags[pidx + i] = vflag; |
621 | } |
622 | |
623 | pflag = VirtualFlag {pidx, vflag.n}; |
624 | } else { |
625 | if (state.raVFlag.isVirtual(vflag)) throw need_vflag(); |
626 | |
627 | pflag = vflag; |
628 | } |
629 | |
630 | return pflag.toPhysical(); |
631 | } |
632 | |
633 | template <HW hw> |
634 | void gemm_kernel_generator_t<hw>::allocVFlagStorage( |
635 | const CommonStrategy &strategy, CommonState &state) { |
636 | state.vflagStorage |
637 | = state.ra.alloc(getHint(HintType::LongTerm, strategy)).uw(); |
638 | } |
639 | |
640 | TokenAllocator::TokenAllocator(HW hw) { |
641 | free = (1ull << tokenCount(hw)) - 1; |
642 | } |
643 | |
644 | int8_t TokenAllocator::tryAlloc() { |
645 | if (free) { |
646 | int8_t token = bsf(free); |
647 | free &= ~(1 << token); |
648 | return token; |
649 | } else |
650 | return -1; |
651 | } |
652 | |
653 | /************************/ |
654 | /* Pseudo-instructions. */ |
655 | /************************/ |
656 | |
657 | // goto instruction with Gen12 semantics. |
658 | template <HW hw> |
659 | void gemm_kernel_generator_t<hw>::goto12(const InstructionModifier &mod, |
660 | Label &jip, Label &uip, bool branchCtrl) { |
661 | InstructionModifier mmod = mod; |
662 | if (hw == HW::Gen9 && !branchCtrl) { |
663 | if (mmod.getPredCtrl() == PredCtrl::None) stub(); |
664 | mmod.setPredInv(!mmod.isPredInv()); |
665 | } |
666 | goto_(mmod, jip, uip, branchCtrl); |
667 | } |
668 | |
669 | // Compare to zero. |
670 | template <HW hw> |
671 | void gemm_kernel_generator_t<hw>::cmp0( |
672 | const InstructionModifier &mod, RegData src0) { |
673 | mov(mod, null.retype(src0.getType()), abs(src0)); |
674 | } |
675 | |
// Scale then add: dst <- src0 + src1 * (numerator / denominator), rounding up.
// If exact = true, ensure src1 * numerator / denominator is integral when src1
//   is an immediate.
678 | template <HW hw> |
679 | void gemm_kernel_generator_t<hw>::addScaled(const InstructionModifier &mod, |
680 | const RegData &dst, int src0, const RegData &src1, int numerator, |
681 | int denominator, CommonState &state, bool exact) { |
682 | if (!is_zero_or_pow2(numerator)) stub(); |
683 | if (!is_zero_or_pow2(denominator)) stub(); |
684 | |
685 | if (numerator == denominator) { |
686 | (src0 != 0) ? add(mod, dst, src1, src0) |
687 | : (src1 != dst) ? mov(mod, dst, src1) : noop(); |
688 | } else if (numerator > denominator) { |
689 | (src0 == 0) ? mulConstant(mod, dst, src1, numerator / denominator) |
690 | : mad(mod, dst, src0, src1, numerator / denominator); |
691 | } else if ((numerator * 2) == denominator) |
692 | avg(mod, dst, src1, src0 * 2); |
693 | else { |
694 | add(mod, dst, src1, ((src0 + 1) * denominator / numerator) - 1); |
695 | asr(mod, dst, dst, log2(denominator) - log2(numerator)); |
696 | } |
697 | } |
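
// Worked example of the final branch (numerator = 1, denominator = 4,
// src0 = 0):
//   dst = (src1 + 3) >> 2 == ceil(src1 / 4),
// matching the rounding-up contract above.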
698 | |
699 | template <HW hw> |
700 | void gemm_kernel_generator_t<hw>::addScaled(const InstructionModifier &mod, |
701 | const RegData &dst, const RegData &src0, const RegData &src1, |
702 | int numerator, int denominator, CommonState &state, bool exact) { |
703 | if (!is_zero_or_pow2(numerator)) stub(); |
704 | if (!is_zero_or_pow2(denominator)) stub(); |
705 | |
706 | if (numerator == denominator) |
707 | add(mod, dst, src1, src0); |
708 | else if (numerator > denominator) |
709 | mad(mod, dst, src0, src1, numerator / denominator); |
710 | else { |
711 | auto temp = state.ra.alloc_sub(src1.getType()); |
712 | if (exact) |
713 | asr(mod, temp, src1, log2(denominator) - log2(numerator)); |
714 | else { |
715 | add(mod, temp, src1, (denominator / numerator) - 1); |
716 | asr(mod, temp, temp, log2(denominator) - log2(numerator)); |
717 | } |
718 | add(mod, dst, temp, src0); |
719 | state.ra.safeRelease(temp); |
720 | } |
721 | } |
722 | |
723 | template <HW hw> |
724 | void gemm_kernel_generator_t<hw>::addScaled(const InstructionModifier &mod, |
725 | const RegData &dst, const RegData &src0, int src1, int numerator, |
726 | int denominator, CommonState &state, bool exact) { |
727 | if (!is_zero_or_pow2(numerator)) stub(); |
728 | if (!is_zero_or_pow2(denominator)) stub(); |
729 | if (exact && ((numerator * src1) % denominator)) |
730 | throw std::runtime_error("Misaligned immediate value." ); |
731 | add(mod, dst, src0, (numerator * src1) / denominator); |
732 | } |
733 | |
// Synchronize on all pipes and out-of-order operations.
735 | template <HW hw> |
736 | void gemm_kernel_generator_t<hw>::syncall() { |
737 | if (hw == HW::Gen12LP) |
738 | sync.allwr(SWSB(1)); |
739 | else if (hw >= HW::XeHP) |
740 | sync.allwr(SWSB<AllPipes>(1)); |
741 | } |
742 | |
743 | // Multiply by a constant, optimizing for power-of-2 constants. |
744 | template <HW hw> |
745 | template <typename DT> |
746 | void gemm_kernel_generator_t<hw>::mulConstant(const InstructionModifier &mod, |
747 | const RegData &dst, const RegData &src0, int32_t src1) { |
748 | if (src1 == 0) |
749 | mov<DT>(mod, dst, uint16_t(0)); |
750 | else if (src1 == 1) { |
751 | if (dst != src0) mov<DT>(mod, dst, src0); |
752 | } else if (src1 == -1) |
753 | mov<DT>(mod, dst, -src0); |
754 | else if (is_zero_or_pow2(src1)) |
755 | shl<DT>(mod, dst, src0, uint16_t(log2(src1))); |
756 | else if (src1 >= 0x10000) |
757 | mul<DT>(mod, dst, src0, uint32_t(src1)); |
758 | else if (src1 < -0x8000) |
759 | mul<DT>(mod, dst, src0, int32_t(src1)); |
760 | else if (src1 > 0) |
761 | mul<DT>(mod, dst, src0, uint16_t(src1)); |
762 | else |
763 | mul<DT>(mod, dst, src0, int16_t(src1)); |
764 | } |
765 | |
766 | // Three-argument add. |
767 | template <HW hw> |
768 | template <typename DT, typename S0, typename S2> |
769 | void gemm_kernel_generator_t<hw>::eadd3(const InstructionModifier &mod, |
770 | const RegData &dst, const S0 &src0, const RegData &src1, |
771 | const S2 &src2) { |
772 | if ((hw >= HW::XeHP) && !(dst.getOffset() & 1)) |
773 | add3<DT>(mod, dst, src0, src1, src2); |
774 | else { |
775 | add<DT>(mod, dst, src1, src0); |
776 | add<DT>(mod, dst, dst, src2); |
777 | } |
778 | } |
779 | |
780 | template <HW hw> |
781 | template <typename DT> |
782 | void gemm_kernel_generator_t<hw>::emov(const ngen::InstructionModifier &mod, |
783 | ngen::RegData dst, ngen::RegData src0, const CommonStrategy &strategy, |
784 | CommonState &state) { |
785 | EmulationImplementation::applyDefaultType<DT>(dst); |
786 | EmulationImplementation::applyDefaultType<DT>(src0); |
787 | |
788 | if (dst.getType() == DataType::tf32 && src0.getType() == DataType::tf32) { |
789 | dst.setType(DataType::f); |
790 | src0.setType(DataType::f); |
791 | } |
792 | |
793 | if (hw >= HW::XeHP |
794 | && one_of(src0.getType(), DataType::hf, DataType::f, DataType::bf) |
795 | && src0.getType() == dst.getType() |
796 | && ((src0.getHS() != dst.getHS()) |
797 | || (src0.getOffset() != dst.getOffset()))) { |
798 | moveToIntPipe(mod.getExecSize(), dst); |
799 | moveToIntPipe(mod.getExecSize(), src0); |
800 | } |
801 | |
802 | if (hw < HW::XeHP && dst.getType() == DataType::f |
803 | && src0.getType() == DataType::bf) { |
804 | dst.setType(DataType::ud); |
805 | src0.setType(DataType::uw); |
806 | shl(mod, dst, src0, 16); |
807 | } else |
808 | EmulationImplementation::emov(*this, mod, dst, src0, strategy.emulate); |
809 | } |
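
// (The bf16 -> f32 path above works because bf16 is the upper 16 bits of an
//  IEEE f32, so a 16-bit left shift is an exact upconversion.)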
810 | |
811 | template <HW hw> |
812 | template <typename DT> |
813 | void gemm_kernel_generator_t<hw>::eadd(const InstructionModifier &mod, |
814 | const RegData &dst, const RegData &src0, const RegData &src1, |
815 | const CommonStrategy &strategy, CommonState &state) { |
816 | if (dst.getType() == DataType::f && src0.getType() == DataType::f |
817 | && src1.getType() == DataType::bf && src1.getHS() != 1) { |
818 | GRF alloced, temp = state.emulate.temp[0]; |
819 | if (temp.isInvalid()) temp = alloced = state.ra.alloc(); |
820 | |
821 | auto src1UW = src1; |
822 | src1UW.setType(DataType::uw); |
823 | mov(mod, temp.uw(0)(1), src1UW); |
824 | add(mod, dst, src0, temp.bf(0)(1)); |
825 | |
826 | state.ra.safeRelease(alloced); |
827 | } else |
828 | EmulationImplementation::eadd<DT>( |
829 | *this, mod, dst, src0, src1, strategy.emulate, state.emulate); |
830 | } |
831 | |
832 | template <HW hw> |
833 | template <typename S0, typename S2> |
834 | void gemm_kernel_generator_t<hw>::emad(const InstructionModifier &mod, |
835 | const RegData &dst, const S0 &src0, const RegData &src1, const S2 &src2, |
836 | const CommonStrategy &strategy, CommonState &state) { |
837 | auto dstType = dst.getType(); |
838 | if ((hw >= HW::Gen10 && !(dst.getByteOffset() & 7) |
839 | && !one_of(dstType, DataType::q, DataType::uq) |
840 | && !one_of(src2.getType(), DataType::d, DataType::ud)) |
841 | || one_of(dstType, DataType::hf, DataType::f, DataType::df)) { |
842 | mad(mod, dst, src0, src1, src2); |
843 | } else { |
844 | auto ttype = (isSigned(src1.getType()) || isSigned(src2.getType())) |
845 | ? DataType::d |
846 | : DataType::ud; |
847 | auto temp = state.ra.alloc_sub(ttype); |
848 | emul(mod, temp, src1, src2, strategy, state); |
849 | eadd(mod, dst, temp, src0, strategy, state); |
850 | state.ra.safeRelease(temp); |
851 | } |
852 | } |
853 | |
854 | template <HW hw> |
855 | template <typename S0> |
856 | void gemm_kernel_generator_t<hw>::emad(const InstructionModifier &mod, |
857 | const RegData &dst, const S0 &src0, const RegData &src1, int32_t src2, |
858 | const CommonStrategy &strategy, CommonState &state) { |
859 | auto dstType = dst.getType(); |
860 | if (src2 == 0) |
861 | emov(mod, dst, src0, strategy, state); |
862 | else if (src2 == 1) |
863 | eadd(mod, dst, src1, src0, strategy, state); |
864 | else if (hw >= HW::Gen10 && !(dst.getByteOffset() & 7) |
865 | && (src2 >= -0x8000 && src2 < 0x10000) |
866 | && !one_of(dstType, DataType::q, DataType::uq)) { |
867 | mad(mod, dst, src0, src1, src2); |
868 | } else { |
869 | auto ttype = isSigned(src1.getType()) ? DataType::d : DataType::ud; |
870 | Subregister tempScalar; |
871 | GRFRange tempGRFs; |
872 | RegData temp; |
873 | if (mod.getExecSize() == 1) |
874 | temp = tempScalar = state.ra.alloc_sub(ttype); |
875 | else { |
876 | tempGRFs = state.ra.alloc_range(2); |
877 | temp = tempGRFs[0].retype(ttype); |
878 | } |
879 | emulConstant(mod, temp, src1, src2, strategy, state); |
880 | eadd(mod, dst, temp, src0, strategy, state); |
881 | state.ra.safeRelease(tempScalar); |
882 | state.ra.safeRelease(tempGRFs); |
883 | } |
884 | } |
885 | |
886 | template <HW hw> |
887 | template <typename DT> |
888 | void gemm_kernel_generator_t<hw>::emath(const InstructionModifier &mod, |
889 | MathFunction fc, const RegData &dst, const RegData &src0, |
890 | const GEMMStrategy &strategy, CommonState &state) { |
891 | if (hw == HW::XeHP && strategy.systolic && mod.getExecSize() <= 8) { |
892 | // Workaround for DPAS + SIMD8 EM hang: use SIMD16 arithmetic. |
893 | auto mod16 = mod; |
894 | mod16.setExecSize(16); |
895 | |
896 | auto temp = state.ra.alloc_range(2); |
897 | auto tt = temp[0].retype(src0.getType()); |
898 | |
899 | mov(mod.getExecSize(), tt, src0); |
900 | math(mod16, fc, tt, tt); |
901 | mov(mod.getExecSize(), dst, tt); |
902 | |
903 | state.ra.safeRelease(temp); |
904 | } else |
905 | math(mod, fc, dst, src0); |
906 | } |
907 | |
908 | template <HW hw> |
909 | void gemm_kernel_generator_t<hw>::ejmpi(InstructionModifier mod, Label &dst) { |
910 | if (hw == HW::XeHPC && mod.getPredCtrl() == PredCtrl::anyv |
911 | && !mod.isPredInv()) { |
912 | mod.setPredCtrl(PredCtrl::Normal); |
913 | jmpi(mod, dst); |
914 | auto flag = mod.getFlagReg(); |
915 | flag.setBase(flag.getBase() ^ 1); |
916 | mod.setFlagReg(flag); |
917 | jmpi(mod, dst); |
918 | } else |
919 | jmpi(mod, dst); |
920 | } |
921 | |
922 | /********************/ |
923 | /* Utility routines */ |
924 | /********************/ |
925 | |
926 | // Modulo by constant value. |
927 | template <HW hw> |
928 | template <typename DT> |
929 | void gemm_kernel_generator_t<hw>::mod(const Subregister &dst, |
930 | const Subregister &src, uint16_t modulus, |
931 | const CommonStrategy &strategy, CommonState &state) { |
932 | if (is_zero_or_pow2(modulus)) |
933 | and_<DT>(1, dst, src, modulus - 1); |
934 | else if (strategy.emulate.emulate64 && (hw <= HW::Gen12LP)) |
935 | math<DT>(1, MathFunction::irem, dst, src, modulus); |
936 | else { |
937 | alignDown<DT>(dst, src, modulus, strategy, state); |
938 | add<DT>(1, dst, src, -dst); |
939 | } |
940 | } |
941 | |
942 | // Return both (a % b) and a - (a % b). |
943 | template <HW hw> |
944 | template <typename DT> |
945 | void gemm_kernel_generator_t<hw>::modExt(const Subregister &dstMod, |
946 | const Subregister &dstMultiple, const Subregister &src, |
947 | uint16_t modulus, const CommonStrategy &strategy, CommonState &state) { |
948 | if (is_zero_or_pow2(modulus)) { |
949 | and_<DT>(1, dstMultiple, src, ~uint32_t(modulus - 1)); |
950 | and_<DT>(1, dstMod, src, modulus - 1); |
951 | } else if (strategy.emulate.emulate64 && (hw <= HW::Gen12LP)) { |
952 | math<DT>(1, MathFunction::irem, dstMod, src, modulus); |
953 | add<DT>(1, dstMultiple, src, -dstMod); |
954 | } else { |
955 | alignDown<DT>(dstMultiple, src, modulus, strategy, state); |
956 | add<DT>(1, dstMod, src, -dstMultiple); |
957 | } |
958 | } |
959 | |
960 | // Divide an unsigned value by a constant, rounding down. |
961 | template <HW hw> |
962 | template <typename DT> |
963 | void gemm_kernel_generator_t<hw>::divDown(const ngen::Subregister &dst, |
964 | const ngen::Subregister &src, uint16_t divisor, |
965 | const CommonStrategy &strategy, CommonState &state) { |
966 | if (is_zero_or_pow2(divisor)) |
967 | shr<DT>(1, dst, src, log2(divisor)); |
968 | else if (strategy.emulate.emulate64 && (hw <= HW::Gen12LP)) |
969 | math<DT>(1, MathFunction::iqot, dst, src, uint32_t(divisor)); |
970 | else { |
971 | // Replace integer division with multiplication by reciprocal + shift. |
972 | // Valid for numerators <= 2^31. |
973 | int shift = ngen::utils::bsr(divisor); |
974 | uint32_t recip32 |
975 | = ((uint64_t(0x100000000) << shift) + divisor - 1) / divisor; |
976 | emul32High(1, dst, src, recip32); |
977 | shr(1, dst, dst, shift); |
978 | } |
979 | } |
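
// Worked example: divisor = 3 gives shift = 1 and
// recip32 = ceil(2^33 / 3) = 0xAAAAAAAB, so the sequence computes
//   dst = ((src * 0xAAAAAAAB) >> 32) >> 1,
// which equals src / 3 for any src <= 2^31.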
980 | |
981 | // Align an unsigned value down to a multiple of align. |
982 | template <HW hw> |
983 | template <typename DT> |
984 | void gemm_kernel_generator_t<hw>::alignDown(const Subregister &dst, |
985 | const Subregister &src, uint16_t align, const CommonStrategy &strategy, |
986 | CommonState &state) { |
987 | if (is_zero_or_pow2(align)) |
988 | and_<DT>(1, dst, src, uint32_t(-align)); |
989 | else { |
990 | divDown(dst, src, align, strategy, state); |
991 | mul(1, dst, dst, align); |
992 | } |
993 | } |
994 | |
995 | // Align an unsigned value up to a multiple of align. |
996 | template <HW hw> |
997 | template <typename DT> |
998 | void gemm_kernel_generator_t<hw>::alignUp(const Subregister &dst, |
999 | const Subregister &src, uint16_t align, const CommonStrategy &strategy, |
1000 | CommonState &state) { |
1001 | add<DT>(1, dst, src, uint16_t(align - 1)); |
1002 | alignDown<DT>(dst, dst, align, strategy, state); |
1003 | } |
1004 | |
1005 | // Non-constant integer division. |
1006 | // Requires an auxiliary constant: ceiling(2^(32 + s) / denom), where s = floor(log2(denom)). |
1007 | template <HW hw> |
1008 | template <typename DT> |
1009 | void gemm_kernel_generator_t<hw>::divDown(const Subregister &dst, |
1010 | const Subregister &src0, const Subregister &src1, |
1011 | const Subregister &src1Recip, const FlagRegister &flag, |
1012 | const CommonStrategy &strategy, CommonState &state) { |
1013 | auto shift = state.ra.alloc_sub<uint32_t>(); |
1014 | auto pop = state.ra.alloc_sub<uint16_t>(); |
1015 | cbit(1, pop, src1); |
1016 | fbh(1, shift, src1); |
1017 | cmp(1 | gt | flag, pop, 1); |
1018 | add(1, shift, -shift, 31); |
1019 | emul32High(1 | flag, dst, src0, src1Recip); |
1020 | shr(1 | ~flag, dst, src0, shift); |
1021 | shr(1 | flag, dst, dst, shift); |
1022 | state.ra.safeRelease(shift); |
1023 | state.ra.safeRelease(pop); |
1024 | } |
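
// The cbit/fbh pair chooses the path at run time: a power-of-2 divisor
// (popcount == 1) needs only a single right shift by
// shift = 31 - fbh(src1) == log2(src1); any other divisor multiplies by the
// precomputed reciprocal first, then applies the same shift.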
1025 | |
// Simple do-while loop helper: emits the backward conditional branch at the end of a loop.
1027 | template <HW hw> |
1028 | void gemm_kernel_generator_t<hw>::simtDoWhileLoop( |
1029 | const InstructionModifier &mod, Label &dest) { |
1030 | Label next; |
1031 | |
1032 | goto12(mod, next, dest, true); |
1033 | mark(next); |
1034 | join(mod.getExecSize()); |
1035 | } |
1036 | |
1037 | // Barrier with SLM fence. |
1038 | template <HW hw> |
1039 | void gemm_kernel_generator_t<hw>::slmBarrier( |
1040 | const GRF &temp, const GRF &r0_info) { |
1041 | if (hw >= HW::Gen11) { |
1042 | slmfence(temp, r0_info); |
1043 | if (hw < HW::Gen12LP) mov<uint32_t>(8, null, temp); |
1044 | } |
1045 | barrier(temp, r0_info); |
1046 | } |
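
// (On pre-Gen12LP parts, the dummy read of temp into null stalls until the
//  fence's writeback returns, ensuring the fence completes before the barrier;
//  globalMemBarrier below relies on the same idiom.)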
1047 | |
1048 | // Barrier with global memory fence. |
1049 | template <HW hw> |
1050 | void gemm_kernel_generator_t<hw>::globalMemBarrier( |
1051 | const GRF &temp, const GRF &r0_info) { |
1052 | memfence(temp, r0_info); |
1053 | if (hw < HW::Gen12LP) mov<uint32_t>(8, null, temp); |
1054 | barrier(temp, r0_info); |
1055 | } |
1056 | |
1057 | // Pause for a short period of time. |
1058 | template <HW hw> |
1059 | void gemm_kernel_generator_t<hw>::pause(const CommonStrategy &strategy) { |
1060 | if (hw >= HW::Gen11) |
1061 | mov(1 | Switch, tm0[4], strategy.pauseCycles); |
1062 | else |
1063 | for (int i = 0; i < 8; i++) |
1064 | mov<uint32_t>(8 | Switch, null, acc0); |
1065 | } |
1066 | |
1067 | // Create a copy of a SubregisterPair in the other bank. |
1068 | template <HW hw> |
1069 | void gemm_kernel_generator_t<hw>::duplicateScalar( |
1070 | SubregisterPair &val, CommonState &state) { |
1071 | auto reg0 = val.getReg(0); |
1072 | |
1073 | if (reg0 != val.getReg(1)) return; |
1074 | |
1075 | auto bundle = Bundle::locate(hw, reg0); |
1076 | auto reg1 = state.ra.alloc_sub( |
1077 | reg0.getType(), Bundle(bundle.bank_id ^ 1, Bundle::any)); |
1078 | |
1079 | mov(1, reg1, reg0); |
1080 | val = SubregisterPair(reg0, reg1); |
1081 | } |
1082 | |
1083 | template <HW hw> |
1084 | void gemm_kernel_generator_t<hw>::deduplicateScalar( |
1085 | SubregisterPair &val, CommonState &state) { |
1086 | auto reg0 = val.getReg(0), reg1 = val.getReg(1); |
1087 | if (reg0 != reg1) { |
1088 | state.ra.release(reg1); |
1089 | val = SubregisterPair(reg0); |
1090 | } |
1091 | } |
1092 | |
1093 | // Create a copy of a scalar subregister in the other bank. |
1094 | template <HW hw> |
1095 | template <typename T> |
1096 | void gemm_kernel_generator_t<hw>::duplicateScalar( |
1097 | Scalar<T> &val, CommonState &state) { |
1098 | if (!val.fixed()) duplicateScalar(val.getPair(), state); |
1099 | } |
1100 | |
1101 | // Create multiple versions of the input subregister reg, shifted by amounts specified by the shifts bitmask. |
1102 | // The input subregister is used for one of the versions. |
1103 | template <HW hw> |
1104 | MultishiftSubregister gemm_kernel_generator_t<hw>::multishift( |
1105 | const Subregister ®, unsigned int shifts, |
1106 | const CommonStrategy &strategy, CommonState &state, Bundle hint) { |
1107 | MultishiftSubregister ms; |
1108 | |
1109 | while (shifts != 0) { |
1110 | int shift = bsr(shifts); |
1111 | shifts &= ~(1 << shift); |
1112 | |
1113 | if (shifts != 0) { |
1114 | Subregister s = state.ra.alloc_sub(reg.getType(), hint); |
1115 | ms.set(shift, s); |
1116 | eshr(1, s, reg, shift, strategy, state); |
1117 | } else { |
1118 | ms.set(shift, reg); |
1119 | if (shift > 0) eshr(1, reg, reg, shift, strategy, state); |
1120 | } |
1121 | } |
1122 | |
1123 | return ms; |
1124 | } |
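
// Example: shifts = 0b1010 produces ms[3] = reg >> 3 in a freshly allocated
// subregister and ms[1] = reg >> 1 computed in place, reusing reg itself for
// the smallest requested shift.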
1125 | |
1126 | // Get ID of fused thread (0/1), multiplied by a scaling factor. Assumes r1 has not been overwritten, |
1127 | // or state.lid0 is set to a subregister containing local ID 0 (divided by the subgroup size). |
1128 | template <HW hw> |
1129 | void gemm_kernel_generator_t<hw>::getFusedID(int scale, |
1130 | const CommonProblem &problem, const CommonStrategy &strategy, |
1131 | CommonState &state) { |
1132 | if (strategy.fused) { |
1133 | state.fusedID = state.ra.alloc_sub<uint16_t>( |
1134 | getHint(HintType::LongTerm, strategy)); |
1135 | if (state.lid0.isValid()) { |
1136 | if (is_zero_or_pow2(scale) && scale > 1 |
1137 | && (state.fusedID.getOffset() & 3) == 0) |
1138 | bfi2(1, state.fusedID, scale, state.lid0, 0); |
1139 | else { |
1140 | and_(1, state.fusedID, state.lid0, 1); |
1141 | mulConstant(1, state.fusedID, state.fusedID, scale); |
1142 | } |
1143 | } else if (is_zero_or_pow2(scale)) { |
1144 | int shift = log2(scale) - log2(strategy.subgroupSize); |
1145 | Subregister lid0 = r1.uw(0); |
1146 | |
1147 | if (shift > 0) |
1148 | shl(1, state.fusedID, lid0, uint16_t(shift)); |
1149 | else if (shift < 0) |
1150 | shr(1, state.fusedID, lid0, uint16_t(-shift)); |
1151 | |
1152 | and_(1, state.fusedID, (shift == 0) ? lid0 : state.fusedID, |
1153 | uint16_t(scale)); |
1154 | } else { |
1155 | shr(1, state.fusedID, r1.uw(0), |
1156 | uint16_t(log2(strategy.subgroupSize))); |
1157 | and_(1, state.fusedID, state.fusedID, uint16_t(1)); |
1158 | mulConstant(1, state.fusedID, state.fusedID, uint16_t(scale)); |
1159 | } |
1160 | } |
1161 | } |
1162 | |
1163 | // Move r0 information to another register if configured. |
1164 | template <HW hw> |
1165 | void gemm_kernel_generator_t<hw>::moveR0( |
1166 | const CommonStrategy &strategy, CommonState &state) { |
1167 | if (state.movedR0) return; |
1168 | if (state.r0_info.isInvalid()) { |
1169 | switch (strategy.moveR0) { |
1170 | case MoveR0::None: |
1171 | state.r0_info = r0.ud(); |
1172 | state.movedR0 = true; |
1173 | return; |
1174 | case MoveR0::Acc: state.r0_info = acc0.ud(); break; |
1175 | case MoveR0::Addr: state.r0_info = a0.ud(); break; |
1176 | case MoveR0::GRF: |
1177 | state.r0_info |
1178 | = state.ra.alloc(getHint(HintType::R0Info, strategy)); |
1179 | break; |
1180 | } |
1181 | } |
1182 | |
1183 | mov<uint32_t>(8, state.r0_info, r0); |
1184 | |
1185 | if (!strategy.sipR0WA) state.ra.release(r0); |
1186 | |
1187 | state.movedR0 = true; |
1188 | } |
1189 | |
1190 | template <HW hw> |
1191 | void gemm_kernel_generator_t<hw>::moveR0( |
1192 | const GEMMStrategy &strategy, GEMMState &state) { |
1193 | if (state.movedR0) return; |
1194 | if (strategy.moveR0 == MoveR0::GRF) { |
1195 | if (strategy.registerScheme == GEMMStrategy::ACB |
1196 | || strategy.registerScheme == GEMMStrategy::BCA) { |
1197 | state.r0_info = r127; |
1198 | state.ra.claim(r127); |
1199 | } |
1200 | } |
1201 | moveR0(static_cast<CommonStrategy>(strategy), state); |
1202 | } |
1203 | |
1204 | // Call a functor needing the r0 header in a GRF. |
1205 | template <HW hw> |
1206 | template <typename F> |
1207 | void gemm_kernel_generator_t<hw>::useR0(CommonState &state, F f) { |
1208 | if (state.r0_info.isARF()) { |
1209 | auto r0_info = state.ra.alloc(); |
1210 | mov<uint32_t>(8, r0_info, state.r0_info); |
1211 | f(r0_info); |
1212 | state.ra.safeRelease(r0_info); |
1213 | } else |
1214 | f(GRF {state.r0_info.getBase()}); |
1215 | } |
1216 | |
// Divide out the subgroup size from the x local size and local ID.
1218 | template <HW hw> |
1219 | void gemm_kernel_generator_t<hw>::removeSG(const CommonProblem &problem, |
1220 | const CommonStrategy &strategy, const CommonState &state) { |
1221 | uint16_t sss = log2(strategy.subgroupSize); |
1222 | |
1223 | auto localSize0 = interface.getLocalSize(0); |
1224 | auto localID0 = interface.getLocalID(0); |
1225 | |
1226 | shr(1, localSize0, localSize0, sss); |
1227 | shr(1, localID0.uw(0), localID0.uw(0), sss); |
1228 | } |
1229 | |
1230 | // Swap bit 0 of local ID x and y if needed so that threads are ordered according to specified EU fusion. |
1231 | template <HW hw> |
1232 | void gemm_kernel_generator_t<hw>::reorderFusedEUs(const GEMMProblem &problem, |
1233 | const GEMMStrategy &strategy, GEMMState &state) { |
1234 | if (!strategy.fused) return; |
1235 | |
1236 | if (strategy.loopOrder[0] != strategy.fusedLoop) { |
1237 | auto temp = state.ra.alloc_sub<uint32_t>(); |
1238 | and_(1, temp, state.inputs.localIDN.ud(), uint16_t(1)); |
1239 | bfi2(1, state.inputs.localIDN.ud(), uint16_t(1), |
1240 | state.inputs.localIDM.ud(), state.inputs.localIDN.ud()); |
1241 | bfi2(1, state.inputs.localIDM.ud(), uint16_t(1), temp, |
1242 | state.inputs.localIDM.ud()); |
1243 | state.ra.safeRelease(temp); |
1244 | } |
1245 | } |
1246 | |
1247 | template <HW hw> |
1248 | Subregister gemm_kernel_generator_t<hw>::copySubregister( |
1249 | const Subregister ®, CommonState &state, Bundle hint) { |
1250 | auto copy = state.ra.alloc_sub(reg.getType(), hint); |
1251 | mov(1, copy, reg); |
1252 | return copy; |
1253 | } |
1254 | |
1255 | // Set a matrix to zero. |
1256 | template <HW hw> |
1257 | void gemm_kernel_generator_t<hw>::zeroMatrix( |
1258 | const GRFMultirange &r, const CommonStrategy &strategy) { |
1259 | map<uint32_t>(hw, r, r, strategy, |
1260 | [&](int esize, GRF reg, GRF _) { mov(esize, reg, uint16_t(0)); }); |
1261 | } |
1262 | |
1263 | // Release fused remainder-related state variables. |
1264 | template <HW hw> |
1265 | void gemm_kernel_generator_t<hw>::releaseFusedRemainders(GEMMState &state) { |
1266 | state.ra.safeRelease(state.remFusedStorage); |
1267 | state.remaindersFused[LoopM] = Subregister {}; |
1268 | state.remaindersFused[LoopN] = Subregister {}; |
1269 | } |
1270 | |
1271 | template <HW hw> |
1272 | void gemm_kernel_generator_t<hw>::saveMNLocalIDs( |
1273 | const GEMMStrategy &strategy, GEMMState &state) { |
1274 | state.lidStorage = state.ra.alloc_sub<uint32_t>( |
1275 | getHint(HintType::LongTerm, strategy)); |
1276 | state.lidM = state.lidStorage.uw(0); |
1277 | state.lidN = state.lidStorage.uw(1); |
1278 | mov(1, state.lidM, state.inputs.localIDM); |
1279 | mov(1, state.lidN, state.inputs.localIDN); |
1280 | } |
1281 | |
1282 | template <HW hw> |
1283 | void gemm_kernel_generator_t<hw>::saveKLocalIDSize( |
1284 | const GEMMStrategy &strategy, GEMMState &state) { |
1285 | state.lidszKStorage = state.ra.alloc_sub<uint32_t>( |
1286 | getHint(HintType::LongTerm, strategy)); |
1287 | state.lidK = state.lidszKStorage.uw(0); |
1288 | state.lszK = state.lidszKStorage.uw(1); |
1289 | mov(1, state.lidK, state.inputs.localIDK); |
1290 | mov(1, state.lszK, state.inputs.localSizeK); |
1291 | } |
1292 | |
1293 | template <HW hw> |
1294 | void gemm_kernel_generator_t<hw>::releaseSavedMNLocalIDs(GEMMState &state) { |
1295 | state.ra.safeRelease(state.lidStorage); |
1296 | state.lidStorage = invalid; |
1297 | state.lidM = invalid; |
1298 | state.lidN = invalid; |
1299 | } |
1300 | |
// Clear read suppression data on ALU pipes.
1302 | template <HW hw> |
1303 | void gemm_kernel_generator_t<hw>::doReadSuppressionWA( |
1304 | const CommonStrategy &strategy, CommonState &state) { |
1305 | GRF temp; |
1306 | bool freeTemp = false; |
1307 | |
1308 | if (!strategy.readSuppressionWA) return; |
1309 | |
1310 | if (state.r0_info.isValid() && !state.r0_info.isARF()) |
1311 | temp = GRF(state.r0_info.getBase()); |
1312 | else { |
1313 | temp = state.ra.try_alloc(); |
1314 | if (temp.isValid()) |
1315 | freeTemp = true; |
1316 | else |
1317 | temp = r0; |
1318 | } |
1319 | |
1320 | csel<int16_t>(8, temp, temp, temp, temp); |
1321 | csel<float>(8, temp, temp, temp, temp); |
1322 | |
1323 | if (freeTemp) state.ra.safeRelease(temp); |
1324 | } |
1325 | |
1326 | // Get minimum row/column granularity for a matrix in memory. |
1327 | static void getGranularities( |
1328 | const MatrixAddressing &atype, int &rgran, int &cgran) { |
1329 | auto &xgran = isColMajor(atype.layout) ? cgran : rgran; |
1330 | auto &ygran = isColMajor(atype.layout) ? rgran : cgran; |
1331 | rgran = std::max<int>(atype.tileR, 1); |
1332 | cgran = std::max<int>(atype.tileC, 1); |
1333 | xgran = std::max<int>(xgran, atype.crosspack); |
1334 | if (isPacked(atype.layout)) ygran = std::max<int>(ygran, atype.packSize); |
1335 | } |
1336 | |
1337 | // Common register allocator hints. |
1338 | template <HW hw> |
1339 | Bundle gemm_kernel_generator_t<hw>::getHint(HintType type) { |
1340 | switch (type) { |
1341 | case HintType::Bank0: return Bundle(0, Bundle::any); |
1342 | case HintType::Bank1: return Bundle(1, Bundle::any); |
1343 | default: break; |
1344 | } |
1345 | |
1346 | switch (hw) { |
1347 | case HW::Gen9: |
1348 | case HW::Gen10: |
1349 | case HW::Gen11: |
1350 | switch (type) { |
1351 | case HintType::TempComp0: return Bundle(0, 1); |
1352 | case HintType::TempComp1: return Bundle(1, 1); |
1353 | case HintType::LongTerm: return Bundle(Bundle::any, 0); |
1354 | case HintType::LongTerm0: return Bundle(0, 0); |
1355 | case HintType::LongTerm1: return Bundle(1, 0); |
1356 | default: break; |
1357 | } |
1358 | break; |
1359 | case HW::Gen12LP: |
1360 | case HW::XeHP: |
1361 | case HW::XeHPG: |
1362 | case HW::XeHPC: |
1363 | switch (type) { |
1364 | case HintType::LongTerm0: return Bundle(0, Bundle::any); |
1365 | case HintType::LongTerm1: return Bundle(1, Bundle::any); |
1366 | default: break; |
1367 | } |
            break;
        default: break;
1369 | } |
1370 | |
1371 | return Bundle(); |
1372 | } |
1373 | |
1374 | template <HW hw> |
1375 | Bundle gemm_kernel_generator_t<hw>::getHint( |
1376 | HintType type, const CommonStrategy &strategy) { |
1377 | return getHint(type); |
1378 | } |
1379 | |
1380 | // GEMM register allocation hints. |
1381 | template <HW hw> |
1382 | Bundle gemm_kernel_generator_t<hw>::getHint( |
1383 | HintType type, const GEMMStrategy &strategy) { |
1384 | switch (hw) { |
1385 | case HW::Gen9: |
1386 | case HW::Gen10: |
1387 | case HW::Gen11: |
1388 | switch (strategy.registerScheme) { |
1389 | case GEMMStrategy::CSeparate: |
1390 | switch (type) { |
1391 | case HintType::A0Broadcast: |
1392 | case HintType::A0: return Bundle(1, 0); |
1393 | case HintType::A1Broadcast: |
1394 | case HintType::A1: return Bundle(0, 0); |
1395 | case HintType::B0Broadcast: |
1396 | case HintType::B0: return Bundle(0, 0); |
1397 | case HintType::B1Broadcast: |
1398 | case HintType::B1: return Bundle(1, 0); |
1399 | case HintType::C: return Bundle(0, 1); |
1400 | case HintType::CLoad: return Bundle(1, 0); |
1401 | default: break; |
1402 | } |
1403 | break; |
1404 | case GEMMStrategy::ACB: |
1405 | switch (type) { |
1406 | case HintType::A0Broadcast: |
1407 | case HintType::A0: return Bundle(1, 0); |
1408 | case HintType::A1Broadcast: |
1409 | case HintType::A1: return Bundle(0, 0); |
1410 | case HintType::B0Broadcast: |
1411 | case HintType::B0: return Bundle(0, 1); |
1412 | case HintType::B1Broadcast: |
1413 | case HintType::B1: return Bundle(1, 1); |
1414 | case HintType::C: return Bundle(0, 0); |
1415 | case HintType::CLoad: return Bundle(); |
1416 | default: break; |
1417 | } |
1418 | break; |
1419 | case GEMMStrategy::BCA: |
1420 | switch (type) { |
1421 | case HintType::A0Broadcast: |
1422 | case HintType::A0: return Bundle(0, 1); |
1423 | case HintType::A1Broadcast: |
1424 | case HintType::A1: return Bundle(1, 1); |
1425 | case HintType::B0Broadcast: |
1426 | case HintType::B0: return Bundle(1, 0); |
1427 | case HintType::B1Broadcast: |
1428 | case HintType::B1: return Bundle(0, 0); |
1429 | case HintType::C: return Bundle(0, 0); |
1430 | case HintType::CLoad: return Bundle(); |
1431 | default: break; |
1432 | } |
1433 | break; |
1434 | default: break; |
1435 | } |
1436 | break; |
1437 | case HW::Gen12LP: |
1438 | case HW::XeHP: |
1439 | case HW::XeHPG: |
1440 | case HW::XeHPC: |
1441 | switch (strategy.registerScheme) { |
1442 | case GEMMStrategy::CSeparate: |
1443 | switch (type) { |
1444 | case HintType::A0Broadcast: |
1445 | case HintType::A0: return Bundle(1, Bundle::any); |
1446 | case HintType::A1Broadcast: |
1447 | case HintType::A1: return Bundle(0, Bundle::any); |
1448 | case HintType::B0Broadcast: |
1449 | case HintType::B0: return Bundle(0, Bundle::any); |
1450 | case HintType::B1Broadcast: |
1451 | case HintType::B1: return Bundle(1, Bundle::any); |
1452 | case HintType::C: return Bundle(0, 0); |
1453 | case HintType::CLoad: return Bundle(1, Bundle::any); |
1454 | default: break; |
1455 | } |
1456 | break; |
1457 | case GEMMStrategy::ACB: |
1458 | case GEMMStrategy::BCA: |
1459 | if (strategy.systolic) switch (type) { |
1460 | case HintType::A0: |
1461 | case HintType::B0: return Bundle(0, Bundle::any); |
1462 | case HintType::A1: |
1463 | case HintType::B1: return Bundle(1, Bundle::any); |
1464 | case HintType::A0Broadcast: |
1465 | case HintType::B0Broadcast: |
1466 | return Bundle(1, Bundle::any); |
1467 | case HintType::A1Broadcast: |
1468 | case HintType::B1Broadcast: |
1469 | return Bundle(0, Bundle::any); |
1470 | case HintType::C: return Bundle(0, Bundle::any); |
1471 | default: break; |
1472 | } |
1473 | /* else fall through */ |
1474 | case GEMMStrategy::VNC: |
1475 | switch (type) { |
1476 | case HintType::A0: |
1477 | case HintType::B0: return Bundle(1, Bundle::any); |
1478 | case HintType::A1: |
1479 | case HintType::B1: return Bundle(0, Bundle::any); |
1480 | case HintType::A0Broadcast: |
1481 | case HintType::B0Broadcast: |
1482 | return Bundle(0, Bundle::any); |
1483 | case HintType::A1Broadcast: |
1484 | case HintType::B1Broadcast: |
1485 | return Bundle(1, Bundle::any); |
1486 | case HintType::C: return Bundle(0, Bundle::any); |
1487 | default: break; |
1488 | } |
1489 | break; |
1490 | case GEMMStrategy::ABInterleave: |
1491 | switch (type) { |
1492 | case HintType::A0: |
1493 | case HintType::A1: |
1494 | case HintType::A0Broadcast: |
1495 | case HintType::A1Broadcast: return Bundle(1, 0); |
1496 | case HintType::B0: |
1497 | case HintType::B1: |
1498 | case HintType::B0Broadcast: |
1499 | case HintType::B1Broadcast: return Bundle(1, 4); |
1500 | case HintType::C: return Bundle(0, Bundle::any); |
1501 | default: break; |
1502 | } |
1503 | break; |
1504 | case GEMMStrategy::NSeparate: |
1505 | switch (type) { |
1506 | case HintType::A0: |
1507 | case HintType::B0: return Bundle(1, Bundle::any); |
1508 | case HintType::A1: |
1509 | case HintType::B1: return Bundle(0, Bundle::any); |
1510 | case HintType::A0Broadcast: |
1511 | case HintType::B0Broadcast: |
1512 | case HintType::A1Broadcast: |
1513 | case HintType::B1Broadcast: return Bundle(); |
1514 | case HintType::C: return Bundle(0, Bundle::any); |
1515 | case HintType::C1: return Bundle(1, Bundle::any); |
1516 | default: break; |
1517 | } |
1518 | break; |
1519 | case GEMMStrategy::VAvoid: |
1520 | switch (type) { |
1521 | case HintType::A0: |
1522 | case HintType::B0: return Bundle(0, Bundle::any); |
1523 | case HintType::A1: |
1524 | case HintType::B1: return Bundle(1, Bundle::any); |
1525 | case HintType::A0Broadcast: |
1526 | case HintType::B0Broadcast: |
1527 | case HintType::A1Broadcast: |
1528 | case HintType::B1Broadcast: |
1529 | return Bundle(1, Bundle::any); |
1530 | case HintType::C: return Bundle(0, Bundle::any); |
1531 | case HintType::C1: return Bundle(1, Bundle::any); |
1532 | default: break; |
1533 | } |
1534 | break; |
1535 | } |
1536 | break; |
1537 | default: break; |
1538 | } |
1539 | |
1540 | return getHint(type); |
1541 | } |
1542 | |
1543 | // Copy kernel register allocation hints. |
1544 | template <HW hw> |
1545 | Bundle gemm_kernel_generator_t<hw>::getHint( |
1546 | HintType type, const CopyStrategy &strategy) { |
1547 | switch (hw) { |
1548 | case HW::Gen9: |
1549 | case HW::Gen10: |
1550 | case HW::Gen11: |
1551 | case HW::Gen12LP: |
1552 | case HW::XeHP: |
1553 | case HW::XeHPG: |
1554 | case HW::XeHPC: |
1555 | switch (type) { |
1556 | case HintType::S: return Bundle(); |
1557 | case HintType::D: return Bundle(); |
1558 | case HintType::SAddr: return Bundle(); |
1559 | case HintType::DAddr: return Bundle(); |
1560 | default: break; |
1561 | } |
1562 | break; |
1563 | default: break; |
1564 | } |
1565 | |
1566 | return getHint(type); |
1567 | } |
1568 | |
1569 | static inline void safeReleaseRanges( |
1570 | vector<GRFRange> &ranges, CommonState &state) { |
1571 | for (auto &a : ranges) |
1572 | state.ra.safeRelease(a); |
1573 | ranges.clear(); |
1574 | } |
1575 | |
1576 | static inline void releaseRanges( |
1577 | const vector<GRFRange> &ranges, CommonState &state) { |
1578 | for (auto &a : ranges) |
1579 | state.ra.release(a); |
1580 | } |
1581 | |
1582 | static inline void reclaimRanges( |
1583 | const vector<GRFRange> &ranges, CommonState &state) { |
1584 | for (auto &a : ranges) |
1585 | state.ra.claim(a); |
1586 | } |
1587 | |
1588 | static inline void safeRelease(SubregisterPair &pair, CommonState &state) { |
1589 | state.ra.release(pair.getReg(0)); |
1590 | state.ra.release(pair.getReg(1)); |
1591 | pair.invalidate(); |
1592 | } |
1593 | |
1594 | static inline void safeReleaseRanges( |
1595 | GRFMultirange &ranges, CommonState &state) { |
1596 | safeReleaseRanges(ranges.ranges, state); |
1597 | ranges.ranges.clear(); |
1598 | } |
1599 | |
1600 | static inline void safeReleaseRanges( |
1601 | vector<GRFMultirange> &ranges, CommonState &state) { |
1602 | for (auto &a : ranges) |
1603 | safeReleaseRanges(a, state); |
1604 | ranges.clear(); |
1605 | } |
1606 | |
1607 | static inline void releaseRanges( |
1608 | const GRFMultirange &ranges, CommonState &state) { |
1609 | releaseRanges(ranges.ranges, state); |
1610 | } |
1611 | |
1612 | static inline void releaseRanges( |
1613 | const vector<GRFMultirange> &ranges, CommonState &state) { |
1614 | for (auto &a : ranges) |
1615 | releaseRanges(a, state); |
1616 | } |
1617 | |
1618 | static inline void reclaimRanges( |
1619 | const GRFMultirange &ranges, CommonState &state) { |
1620 | reclaimRanges(ranges.ranges, state); |
1621 | } |
1622 | |
1623 | // Reclaim a list of GRF multiranges. |
1624 | static inline void reclaimRanges( |
1625 | const vector<GRFMultirange> &ranges, CommonState &state) { |
1626 | for (auto &a : ranges) |
1627 | reclaimRanges(a, state); |
1628 | } |
1629 | |
1630 | /***********************\ |
1631 | |* Load/store support. *| |
1632 | \***********************/ |
1633 | |
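// Count the number of elements that are consecutive in memory for an r x c block of matrix atype.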
1634 | static int consecutiveElements( |
1635 | Type T, int r, int c, const MatrixAddressing &atype) { |
1636 | int x = isColMajor(atype.layout) ? r : c; |
1637 | int y = isColMajor(atype.layout) ? c : r; |
1638 | |
1639 | if (isPacked(atype.layout)) { |
1640 | int effTileX = (atype.layout == MatrixLayout::Pc) ? atype.tileR |
1641 | : atype.tileC; |
1642 | int effTileY = (atype.layout == MatrixLayout::Pc) ? atype.tileC |
1643 | : atype.tileR; |
1644 | if (!effTileX) effTileX = atype.packSize; |
1645 | if (!effTileY) effTileY = atype.crosspack; |
1646 | |
1647 | if (y % effTileY == 0) { |
1648 | if (x == atype.packSize) |
1649 | return x * y; |
1650 | else if (x % effTileX == 0) |
1651 | return x * effTileY; |
1652 | } |
1653 | if (y % atype.crosspack == 0) |
1654 | return std::min(x, effTileX) * atype.crosspack; |
1655 | } |
1656 | |
1657 | return x; |
1658 | } |
1659 | |
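// Check whether a block access must be emulated with scattered (pseudoblock) messages
//   due to alignment, masking, writability, or atomicity restrictions.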
1660 | static bool needsPseudoblock(HW hw, Type T, int r, int c, |
1661 | const MatrixAddressing &atype, |
1662 | const MatrixAddressingStrategy &astrategy, bool writable, bool masked) { |
1663 | auto consecutive = consecutiveElements(T, r, c, atype); |
1664 | bool dwAligned = (atype.alignment & 0x3) == 0; |
1665 | bool owAligned = (atype.alignment & 0xF) == 0; |
1666 | bool pseudo = !dwAligned || ((consecutive * T) & 0x3) |
1667 | || (writable && ((consecutive * T) & 0xF) && !astrategy.newDP) |
1668 | || (writable && !owAligned) |
1669 | || (writable && masked && (T.size() & 3)) |
1670 | || (masked && !owAligned |
1671 | && (hw >= HW::XeHP |
1672 | || astrategy.base.getModel() != ModelA64)) |
1673 | || (hw >= HW::XeHPC && masked) |
1674 | || (hw >= HW::XeHPC && !astrategy.padded && !astrategy.newDP |
1675 | && ((r * c * T) & 0xF)) |
1676 | || astrategy.atomic |
1677 | || (isColMajor(atype.layout) ? c : r) % atype.crosspack |
1678 | || ((astrategy.base.getModel() == ModelSLM) |
1679 | && (hw < HW::Gen11 || !owAligned)); |
1680 | |
1681 | return pseudo; |
1682 | } |
1683 | |
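// Check whether a pseudoblock access should use surface (channel scattered) messages.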
1684 | static bool pseudoblockUseSurface(const MatrixAddressing &atype, |
1685 | const MatrixAddressingStrategy &astrategy, const RegisterBlock &block) { |
1686 | return (astrategy.base.getModel() == ModelSLM) && (block.ebytes == 4) |
1687 | && !astrategy.atomic; |
1688 | } |
1689 | |
1690 | // Get effective access type to use when setting up addresses. |
1691 | static AccessType effectiveAccessType(const MatrixAddressing &atype, |
1692 | const MatrixAddressingStrategy &astrategy, const RegisterBlock &block) { |
1693 | auto type = astrategy.accessType; |
1694 | if (!block.isLoadBlock()) return type; |
1695 | if (type == AccessType::Block && block.ebytes < 16 && block.extra) |
1696 | type = AccessType::PseudoBlock; |
1697 | else if (type == AccessType::Scattered |
1698 | && astrategy.base.getModel() == ModelSLM && block.ebytes == 4 |
1699 | && !astrategy.newDP) |
1700 | type = AccessType::ChannelScattered; |
1701 | else if (type == AccessType::ChannelScattered && block.ebytes != 4) |
1702 | type = AccessType::Scattered; |
1703 | return type; |
1704 | } |
1705 | |
1706 | // Get effective access type to use when performing loads/stores. |
1707 | static AccessType implAccessType(const MatrixAddressing &atype, |
1708 | const MatrixAddressingStrategy &astrategy, const RegisterBlock &block) { |
1709 | auto type = effectiveAccessType(atype, astrategy, block); |
1710 | if (type == AccessType::PseudoBlock) |
1711 | type = pseudoblockUseSurface(atype, astrategy, block) |
1712 | ? AccessType::ChannelScattered |
1713 | : AccessType::Scattered; |
1714 | return type; |
1715 | } |
1716 | |
1717 | // Count the number of address/header GRFs required by a RegisterBlock. |
1718 | static inline int addrGRFCount(const MatrixAddressing &atype, |
1719 | const MatrixAddressingStrategy &astrategy, const RegisterBlock &block) { |
1720 | // Non-load blocks don't get address registers. |
1721 | if (!block.isLoadBlock()) return 0; |
1722 | |
1723 | switch (effectiveAccessType(atype, astrategy, block)) { |
1724 | case AccessType::Scattered: |
1725 | case AccessType::ChannelScattered: |
1726 | case AccessType::PseudoBlock: { |
1727 | auto bytesPerAddr = (astrategy.base.getModel() == ModelA64) ? 8 : 4; |
1728 | auto baseSIMD = std::max<int>(block.simdSize, 8); |
1729 | auto log2Bytes = block.log2GRFBytes; |
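            // Round address bytes up to whole GRFs, e.g. 16 A64 addresses = 128 bytes = 4 GRFs with 32-byte GRFs.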
1730 | return (bytesPerAddr * baseSIMD + (1 << log2Bytes) - 1) |
1731 | >> log2Bytes; |
1732 | } |
1733 | case AccessType::Block: |
1734 | case AccessType::Block2D: |
1735 | case AccessType::Block2DTranspose: |
1736 | case AccessType::Block2DVNNI: return 1; |
1737 | } |
1738 |     throw std::runtime_error("Invalid addressing."); |
1739 | } |
1740 | |
1741 | // Attempt to allocate address registers for a layout. Returns true if successful. |
1742 | static bool tryAllocAddrRegs(vector<GRFRange> &addrRegs, |
1743 | const vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
1744 | const MatrixAddressingStrategy &astrategy, CommonState &state, |
1745 | Bundle hint = Bundle()) { |
1746 | auto nblocks = int(layout.size()); |
1747 | bool ok = true; |
1748 | |
1749 | addrRegs.resize(nblocks); |
1750 | |
1751 | for (int l = 0; l < nblocks && ok; l++) { |
1752 | addrRegs[l] = state.ra.try_alloc_range( |
1753 | addrGRFCount(atype, astrategy, layout[l]), hint); |
1754 | ok &= addrRegs[l].isValid(); |
1755 | } |
1756 | |
1757 | if (!ok) { |
1758 | for (auto ®s : addrRegs) |
1759 | state.ra.safeRelease(regs); |
1760 | addrRegs.clear(); |
1761 | } |
1762 | |
1763 | return ok; |
1764 | } |
1765 | |
1766 | // Allocate address registers for a layout. |
1767 | static void allocAddrRegs(vector<GRFRange> &addrRegs, |
1768 | const vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
1769 | const MatrixAddressingStrategy &astrategy, CommonState &state, |
1770 | Bundle hint = Bundle()) { |
1771 | if (!tryAllocAddrRegs(addrRegs, layout, atype, astrategy, state, hint)) |
1772 | throw out_of_registers_exception(); |
1773 | } |
1774 | |
1775 | // Check if a layout is completely column-major. |
1776 | static inline bool isLayoutColMajor(const vector<RegisterBlock> &layout) { |
1777 |     if (layout.size() == 0) throw std::runtime_error("Empty layout."); |
1778 | return layout[0] |
1779 | .colMajor; // All layouts we create are homogeneous currently. |
1780 | } |
1781 | |
1782 | // Get the matrix size represented by a layout. |
1783 | static inline void getLayoutDims( |
1784 | const vector<RegisterBlock> &layout, int &m, int &n) { |
1785 | // For now all layouts are sorted so last block is in lower-right corner. |
1786 |     if (layout.size() == 0) throw std::runtime_error("Empty layout."); |
1787 | auto &last = layout[layout.size() - 1]; |
1788 | m = last.offsetR + last.nr; |
1789 | n = last.offsetC + last.nc; |
1790 | } |
1791 | |
1792 | // Check if every block in a layout has the given crosspack, with no padding. |
1793 | static inline bool hasFullCrosspack( |
1794 | const vector<RegisterBlock> &layout, int crosspack) { |
1795 | if (layout.size() == 0) return true; |
1796 | if (layout[0].crosspack |
1797 | != crosspack) // Only need to check first block of layout currently. |
1798 | return false; |
1799 | for (const auto &block : layout) |
1800 | if ((block.colMajor ? block.nc : block.nr) % crosspack) return false; |
1801 | return true; |
1802 | } |
1803 | |
1804 | // Check if the layout is tiled with the given tiling. |
1805 | static inline bool hasTiling( |
1806 | const vector<RegisterBlock> &layout, int tileR, int tileC) { |
1807 | for (auto &block : layout) { |
1808 | if (tileR > 0) |
1809 | if (block.offsetR / tileR != (block.offsetR + block.nr - 1) / tileR) |
1810 | return false; |
1811 | if (tileC > 0) |
1812 | if (block.offsetC / tileC != (block.offsetC + block.nc - 1) / tileC) |
1813 | return false; |
1814 | } |
1815 | return true; |
1816 | } |
1817 | |
1818 | // Check if a layout has row fragmenting. |
1819 | static bool hasRowFragmenting(const vector<RegisterBlock> &layout) { |
1820 | for (auto &block : layout) |
1821 | if (block.rowFragment) return true; |
1822 | return false; |
1823 | } |
1824 | |
1825 | // Check if a layout has column fragmenting. |
1826 | static bool hasColumnFragmenting(const vector<RegisterBlock> &layout) { |
1827 | for (auto &block : layout) |
1828 | if (block.colFragment) return true; |
1829 | return false; |
1830 | } |
1831 | |
1832 | // Check if a layout has remainders enabled. |
1833 | static bool hasRemainders(const vector<RegisterBlock> &layout, |
1834 | bool remainderR = true, bool remainderC = true) { |
1835 | for (auto &block : layout) |
1836 | if ((remainderR && block.remainderR) |
1837 | || (remainderC && block.remainderC)) |
1838 | return true; |
1839 | return false; |
1840 | } |
1841 | |
1842 | // Check if a layout has any kind of fragmenting. |
1843 | static bool hasFragmenting(const vector<RegisterBlock> &layout) { |
1844 | for (auto &block : layout) |
1845 | if (block.rowFragment || block.colFragment) return true; |
1846 | return false; |
1847 | } |
1848 | |
1849 | // Check if a layout has any masking. |
1850 | static bool hasMasking(const vector<RegisterBlock> &layout) { |
1851 | for (auto &block : layout) |
1852 | if (block.rowMask || block.colMask || block.flag) return true; |
1853 | return false; |
1854 | } |
1855 | |
1856 | // Check if a layout has any flag registers assigned. |
1857 | static bool hasFlags(const vector<RegisterBlock> &layout) { |
1858 | for (auto &block : layout) |
1859 | if (block.flag) return true; |
1860 | return false; |
1861 | } |
1862 | |
1863 | // Find the maximum load message size in a layout, in registers. |
1864 | static inline int getMaxLoadBlock(const vector<RegisterBlock> &layout) { |
1865 | int result = 0; |
1866 | for (auto &l : layout) |
1867 | result = std::max<int>(result, l.msgRegs); |
1868 | return result; |
1869 | } |
1870 | |
1871 | // Count the number of registers needed by a register layout. |
1872 | static inline int getRegCount(const vector<RegisterBlock> &layout) { |
1873 | if (layout.empty()) return 0; |
1874 | |
1875 | int lastByte = 0; |
1876 | for (auto &l : layout) |
1877 | lastByte = std::max(lastByte, l.offsetBytes + l.bytes); |
1878 | |
1879 | int log2Bytes = layout[0].log2GRFBytes; |
1880 | return (lastByte + (1 << log2Bytes) - 1) >> log2Bytes; |
1881 | } |
1882 | |
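// Get the subregister offset of the base address within an address register header
//   (DWord 2 for legacy block messages on non-A64 bases, 0 otherwise).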
1883 | static int getAddr0Offset(const RegisterBlock &block, |
1884 | const MatrixAddressing &atype, |
1885 | const MatrixAddressingStrategy &astrategy) { |
1886 | if (astrategy.newDP) return 0; |
1887 | if (astrategy.base.getModel() == ModelA64) return 0; |
1888 | if (effectiveAccessType(atype, astrategy, block) == AccessType::Block) |
1889 | return 2; |
1890 | return 0; |
1891 | } |
1892 | |
1893 | // Get a subregister containing the (shifted) address of the (0,0) entry of a layout. |
1894 | static Subregister getOriginAddr(const vector<RegisterBlock> &layout, |
1895 | const vector<GRFRange> &addrRegs, const MatrixAddressing &atype, |
1896 | const MatrixAddressingStrategy &astrategy, int *shiftOut = nullptr) { |
1897 | bool a64 = (astrategy.base.getModel() == ModelA64); |
1898 | |
1899 | for (size_t b = 0; b < layout.size(); b++) { |
1900 | const auto &block = layout[b]; |
1901 | if ((block.offsetR != 0) || (block.offsetC != 0)) continue; |
1902 | |
1903 | int off = getAddr0Offset(block, atype, astrategy); |
1904 | |
1905 | if (shiftOut) *shiftOut = block.addrShift; |
1906 | return addrRegs[b][0].sub(off, a64 ? DataType::uq : DataType::ud); |
1907 | } |
1908 | |
1909 | if (shiftOut) *shiftOut = 0; |
1910 | return Subregister(); |
1911 | } |
1912 | |
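// Get the maximum SIMD width for scattered messages.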
1913 | static inline int maxScatteredSIMD( |
1914 | HW hw, const MatrixAddressingStrategy &astrategy) { |
1915 | if (astrategy.newDP) return GRF::bytes(hw) >> 1; |
1916 | return 16; |
1917 | } |
1918 | |
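// Get the minimum native SIMD width for scattered messages.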
1919 | static inline int minScatteredSIMD( |
1920 | HW hw, const MatrixAddressingStrategy &astrategy) { |
1921 | if (hw == HW::XeHPC) return 16; |
1922 | return maxScatteredSIMD(hw, astrategy) >> 1; |
1923 | } |
1924 | |
1925 | // Get width and height parameters for underlying 2D block load message. |
1926 | static void getBlock2DWH(int &w, int &h, const MatrixAddressing &atype, |
1927 | const RegisterBlock &block, int *outMultiX = nullptr) { |
1928 | int multiX = 1; |
1929 | w = isColMajor(atype.layout) ? block.nr : block.nc; |
1930 | h = isColMajor(atype.layout) ? block.nc : block.nr; |
1931 | w = (w * block.extra) / block.ebytes; |
1932 | if (isPacked(atype.layout)) { |
1933 | int maxW = 64 / block.ebytes; |
1934 | multiX = div_up(w, maxW); |
1935 | w /= multiX; |
1936 | h *= multiX; |
1937 | } |
1938 | if (outMultiX) *outMultiX = multiX; |
1939 | } |
1940 | |
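// Check whether a matrix's data will be column-major in registers,
//   accounting for transposing accesses and large crosspack.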
1941 | static bool isRegisterColMajor(Type T, const MatrixAddressing &atype, |
1942 | const MatrixAddressingStrategy &astrategy) { |
1943 | return isColMajor(atype.layout) ^ isTransposing(astrategy.accessType) |
1944 | ^ isLargeCrosspack(T, atype.crosspack); |
1945 | } |
1946 | |
1947 | // Set up a RegisterBlock structure. |
1948 | template <HW hw> |
1949 | bool gemm_kernel_generator_t<hw>::getBlockInfo(Type T, |
1950 | const MatrixAddressing &atype, |
1951 | const MatrixAddressingStrategy &astrategy, int r, int c, |
1952 | bool remainderR, bool remainderC, bool writable, bool avoidFragment, |
1953 | int maxRBlock, int maxCBlock, int &rblock, int &cblock, |
1954 | RegisterBlock &block) { |
1955 | bool prefetch = astrategy.prefetch; |
1956 | int R = rounddown_pow2(r); |
1957 | int C = rounddown_pow2(c); |
1958 | |
1959 | if (maxRBlock == 0) maxRBlock = r; |
1960 | if (maxCBlock == 0) maxCBlock = c; |
1961 | |
1962 | if (isPacked(atype.layout)) { |
1963 | // Don't cross nonconsecutive tiles in a packed layout. |
1964 | bool cm = isColMajor(atype.layout) |
1965 | ^ isTransposing(astrategy.accessType); |
1966 | if (cm) { |
1967 | if (maxRBlock < atype.packSize && atype.tileC > 0) |
1968 | maxCBlock = std::min<int>(maxCBlock, atype.tileC); |
1969 | } else { |
1970 | if (maxCBlock < atype.packSize && atype.tileR > 0) |
1971 | maxRBlock = std::min<int>(maxRBlock, atype.tileR); |
1972 | } |
1973 | } |
1974 | |
1975 | // Set default parameters. |
1976 | block.colMajor = isColMajor(atype.layout); |
1977 | block.splitComplex = false; |
1978 | block.cxComponent = RegisterBlock::Interleaved; |
1979 | block.crosspack = 1; |
1980 | block.rowMask = MaskInfo::None(); |
1981 | block.colMask = MaskInfo::None(); |
1982 | block.rowFragment = 0; |
1983 | block.colFragment = 0; |
1984 | block.remainderR = remainderR; |
1985 | block.remainderC = remainderC; |
1986 | block.noRowsOK = false; |
1987 | block.noColsOK = false; |
1988 | block.descRemR = false; |
1989 | block.descRemC = false; |
1990 | block.descAssigned = false; |
1991 | block.addrShift = 0; |
1992 | block.writable = writable; |
1993 | block.clearFlag(); |
1994 | block.log2GRFBytes = GRF::log2Bytes(hw); |
1995 | block.msgRegs = 0; |
1996 | block.bytes = 0; |
1997 | block.hasNoLoad = false; |
1998 | |
1999 | auto &vrmask = block.rowMask.variable; |
2000 | auto &vcmask = block.colMask.variable; |
2001 | |
2002 | vrmask.rsize = 0; |
2003 | vcmask.rsize = 0; |
2004 | |
2005 | auto accessType = astrategy.accessType; |
2006 | |
2007 | switch (accessType) { |
2008 | case AccessType::ChannelScattered: |
2009 | case AccessType::Scattered: { |
2010 | bool channelScattered |
2011 | = (accessType == AccessType::ChannelScattered); |
2012 | |
2013 | // Detect large crosspack case. |
2014 | bool largeCP = isLargeCrosspack(T, atype.crosspack); |
2015 | int effCP = largeCP ? 1 : atype.crosspack; |
2016 | |
2017 | // Scattered read/write messages effectively transpose DW/QW matrices. |
2018 | block.colMajor = !block.colMajor ^ largeCP; |
2019 | |
2020 | // Let X be the contiguous dimension, Y the scattered dimension (in memory). |
2021 | int *xblock, *yblock; |
2022 | int maxXBlock, maxYBlock; |
2023 | int X, Y; |
2024 | bool remainderX, remainderY; |
2025 | int tileX, tileY; |
2026 | auto &vxmask = block.colMajor ? vcmask : vrmask; |
2027 | auto &vymask = block.colMajor ? vrmask : vcmask; |
2028 | auto &fragment |
2029 | = block.colMajor ? block.colFragment : block.rowFragment; |
2030 | auto smode = astrategy.smode; |
2031 | |
2032 | if (block.colMajor) { |
2033 | Y = R; |
2034 | X = C; |
2035 | yblock = &rblock; |
2036 | xblock = &cblock; |
2037 | maxYBlock = maxRBlock; |
2038 | maxXBlock = maxCBlock; |
2039 | remainderY = remainderR; |
2040 | remainderX = remainderC; |
2041 | tileY = atype.tileR; |
2042 | tileX = atype.tileC; |
2043 | } else { |
2044 | X = R; |
2045 | Y = C; |
2046 | xblock = &rblock; |
2047 | yblock = &cblock; |
2048 | maxXBlock = maxRBlock; |
2049 | maxYBlock = maxCBlock; |
2050 | remainderX = remainderR; |
2051 | remainderY = remainderC; |
2052 | tileX = atype.tileR; |
2053 | tileY = atype.tileC; |
2054 | } |
2055 | |
2056 | // Allowed accesses: |
2057 | // A64 Essentially max 256 bytes. |
2058 | // 8 slots x (1,2,4,8) dwords [Gen12/surface: 1,2,4] |
2059 | // 8 slots x (1,2,4) qwords |
2060 | // 16 slots x (1,2,4) dwords |
2061 | // 16 slots x (1,2) qwords |
2062 | // Others 8 slots x 1 dword |
2063 | // 16 slots x 1 dword |
2064 | // Slot counts doubled for 64-byte GRFs. |
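            //       (Example: 16 slots x 4 dwords = 256 bytes, the A64 maximum.)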
2065 | |
2066 | // Native (col major in memory) matrix block sizes, as a result: |
2067 | // SIMD8: 1x8 2x4 4x2 8x1 (count 1) 2x8 4x8 8x8 [others] |
2068 | // SIMD16: 1x16 2x8 4x4 8x2 16x1 (count 1) 2x16 4x16 |
2069 | // Other layouts are possible too but require dummy (non-load) blocks. |
2070 | // Only kx8 and kx16 are supported for now for {4,8}-byte types. |
2071 | // For 16-byte types, only 1x4 and 1x8 are supported. |
2072 | |
2073 | auto maxSIMD = maxScatteredSIMD(hw, astrategy); |
2074 | auto minSIMD = minScatteredSIMD(hw, astrategy); |
2075 | |
2076 | auto Xc = (avoidFragment && remainderX) ? 1 : X; |
2077 | bool byte = (atype.alignment < 4) || (Xc * T * effCP < 4); |
2078 | bool a64 = (astrategy.base.getModel() == ModelA64); |
2079 | |
2080 | channelScattered |= byte; |
2081 | |
2082 | bool qword = (T.size() >= 8 && !channelScattered && !prefetch |
2083 | && (a64 || astrategy.newDP)); |
2084 | if (astrategy.atomic |
2085 | && hasNativeAtomicAdd(hw, T.real(), atype, astrategy)) |
2086 | qword &= (T.real().size() >= 8); |
2087 | int width = qword ? 8 : 4; |
2088 | block.ebytes = byte ? 1 : width; |
2089 | block.crosspack = std::max<int>(1, width / T); |
2090 | int consecutive = std::max<int>(1, T.size() / width); |
2091 | |
2092 | if (prefetch) consecutive = 1; |
2093 | |
2094 | if (block.ebytes == 4 && astrategy.base.getModel() == ModelSLM |
2095 | && !astrategy.newDP) |
2096 | channelScattered = true; |
2097 | |
2098 | bool simd1 = !a64 && !channelScattered; |
2099 | simd1 &= !astrategy.newDP; |
2100 | |
2101 | // Handle source crosspack. |
2102 | int uncrosspack = 1; |
2103 | if (effCP > 1) { |
2104 | if (effCP == block.crosspack) { |
2105 | block.crosspack = 1; |
2106 | uncrosspack = effCP; |
2107 | } else |
2108 | stub(); |
2109 | } |
2110 | |
2111 | // Try to fit a native matrix block size to X and Y. |
2112 | auto slots = std::min(Y, maxYBlock) * consecutive / uncrosspack; |
2113 | if (prefetch) { |
2114 | // Prefetch only: maximize Y usage. |
2115 | block.simdSize = maxSIMD; |
2116 | } else if (smode == ScatterSIMD::Narrow |
2117 | || (smode == ScatterSIMD::Default |
2118 | && block.ebytes * minSIMD > GRF::bytes(hw))) { |
2119 | // Maximize X usage because we always have at least 2 consecutive GRFs. |
2120 | block.simdSize |
2121 | = (slots >= maxSIMD && X <= 2) ? maxSIMD : minSIMD; |
2122 | } else { |
2123 | // Otherwise, try to maximize Y usage (larger SIMD, worse memory access). |
2124 | block.simdSize = maxSIMD; |
2125 | } |
2126 | block.simdSize |
2127 | = std::min<int>(block.simdSize, rounddown_pow2(slots)); |
2128 | |
2129 | bool no8x8DW = isGen12; |
2130 | bool no16x4QW = false; |
2131 | |
2132 | no8x8DW &= !astrategy.newDP; |
2133 | if (hw == HW::XeHPG && astrategy.newDP) |
2134 | no8x8DW = no16x4QW |
2135 | = true; // Not supported on 512 EU A0. OK on later steppings. |
2136 | no16x4QW |= (!astrategy.newDP && GRF::bytes(hw) == 64); |
2137 | |
2138 | int hwMaxXBlock; |
2139 | |
2140 | if (prefetch) |
2141 | hwMaxXBlock = 64 / T; |
2142 | else if (consecutive > 1) |
2143 | hwMaxXBlock = 1; |
2144 | else if (byte) |
2145 | hwMaxXBlock = remainderX ? 1 : block.crosspack; |
2146 | else if (simd1) |
2147 | hwMaxXBlock = block.crosspack; |
2148 | else if (a64 && astrategy.atomic) |
2149 | hwMaxXBlock = block.crosspack; |
2150 | else if (channelScattered || (block.ebytes == 4 && no8x8DW) |
2151 | || (block.ebytes == 8 && no16x4QW) |
2152 | || (block.simdSize == maxSIMD)) |
2153 | hwMaxXBlock = 16 / T; |
2154 | else |
2155 | hwMaxXBlock = 32 / T; |
2156 | |
2157 | maxXBlock = std::min(maxXBlock, hwMaxXBlock); |
2158 | |
2159 | if (tileX > 0) maxXBlock = std::min(maxXBlock, tileX); |
2160 | |
2161 | *xblock = std::min<int>(X, maxXBlock); |
2162 | block.count = *xblock; |
2163 | |
2164 | *yblock = block.simdSize * uncrosspack / consecutive; |
2165 | if (tileY > 0 && tileY < *yblock) stub(); |
2166 | |
2167 | if (prefetch) |
2168 | block.count = 1; |
2169 | else if (byte) |
2170 | block.count *= T.size(); |
2171 | else |
2172 | block.count = std::max<int>(1, block.count / block.crosspack); |
2173 | |
2174 | // LD is determined by actual # of SIMD slots in HW. But for X = 1 we may |
2175 | // shrink the LD to avoid allocating unnecessary registers. |
2176 | auto ldSIMD = block.simdSize; |
2177 | if (*xblock > 1 || (minSIMD * block.ebytes <= GRF::bytes(hw))) |
2178 | ldSIMD = std::max<int>(ldSIMD, minSIMD); |
2179 | block.ld = ldSIMD * uncrosspack / consecutive; |
2180 | |
2181 | // Handle remainder. Masking handles Y remainders. |
2182 | if (remainderY) { |
2183 | vymask.isFixed = false; |
2184 | vymask.bitRep = consecutive; |
2185 | vymask.maskRep = 1; |
2186 | vymask.rsize = *yblock; |
2187 | vymask.rdivide = 1; |
2188 | } |
2189 | |
2190 | // X remainders require fragmenting. Channel scattered float doesn't need complete fragmenting. |
2191 | // (ditto for regular scattered float with new dataport messages.) |
2192 | // Otherwise, fragment 2 is possible for DWord+ types but not implemented. |
2193 | if (remainderX && !prefetch) { |
2194 | if (avoidFragment && !remainderY && *xblock == 1) { |
2195 | vxmask.isFixed = false; |
2196 | vxmask.bitRep = 16; |
2197 | vxmask.maskRep = 1; |
2198 | vxmask.rsize = 1; |
2199 | vxmask.rdivide = 1; |
2200 | } else if ((channelScattered || astrategy.newDP) |
2201 | && block.crosspack == 1 && block.ebytes == T.size()) { |
2202 | fragment = std::min(*xblock, 4); |
2203 | if (block.colMajor) // Clang can't handle the ternary operator equivalent of this. |
2204 | block.descRemC = true; |
2205 | else |
2206 | block.descRemR = true; |
2207 | } else |
2208 | fragment = 1; |
2209 | } |
2210 | |
2211 | block.extra = consecutive; |
2212 | |
2213 | // BTS scattered accesses are addressed by elements. |
2214 | if (!astrategy.newDP && !channelScattered |
2215 | && !astrategy.base.isStateless()) |
2216 | block.addrShift = log2(block.ebytes); |
2217 | |
2218 | break; |
2219 | } |
2220 | case AccessType::Block: |
2221 | case AccessType::PseudoBlock: { |
2222 | // Three types of block messages: |
2223 | // block_oword: 16 byte align, BLK masking (= dw except ow channel on R Gen9 only -- silently ignore, can't fault) |
2224 | // aligned_oword: 4 byte align, no masking, read only |
2225 | // block_hword: [Gen9-12LP] A64; 4 byte align R, BLKCM masking (= dw but can do ow channel on Gen9 only) |
2226 | // A64; 16 byte align W |
2227 | // [XeHP] A64/BTS; 32 byte align R/W |
2228 | // New dataport messages support {DW, QW}x{1...64} with DW/QW alignment, no masking. |
2229 | // |
2230 | // Prefer block_hword in all cases. When block_hword can't be used: |
2231 | // Use oword if alignment can be assured (i.e. packed row/column layout, or oword-sized scalar) |
2232 | // Otherwise, use aligned oword. load/storeMatrixBlock will emit an error if masking/stores attempted. |
2233 | // |
2234 | // Pseudoblock messages have similar layouts, but are limited to |
2235 | // {8,16}x{dw,qw} sizes, so lengths 8,16 allowed for float, 4,8,16 for double. |
2236 | |
2237 | bool colMajor = block.colMajor; |
2238 | bool effCM = colMajor ^ isLargeCrosspack(T, atype.crosspack); |
2239 | auto consecutive = consecutiveElements(T, r, c, atype); |
2240 | bool masking = (effCM ? remainderR : remainderC); |
2241 | bool bytePartialCP |
2242 | = (T.size() & 3) && ((colMajor ? C : R) % atype.crosspack); |
2243 | bool byte = (atype.alignment & 3) || (consecutive * T & 3) |
2244 | || bytePartialCP || ((T.size() & 3) && writable && masking); |
2245 | bool byte1PerSlot = byte && (bytePartialCP || masking); |
2246 | bool pseudo = (accessType == AccessType::PseudoBlock) |
2247 | | needsPseudoblock( |
2248 | hw, T, R, C, atype, astrategy, writable, masking); |
2249 | int maxElements = 0; |
2250 | int maskGranularity = 1; |
2251 | int maxSIMD = maxScatteredSIMD(hw, astrategy); |
2252 | bool oword = false, aoword = false; |
2253 | int npack = 0; |
2254 | bool canQW = false, mustQW = false; |
2255 | |
2256 | bool a32 = (astrategy.base.getModel() == ModelA32); |
2257 | bool a64 = (astrategy.base.getModel() == ModelA64); |
2258 | bool sc = (astrategy.base.getModel() == ModelSC); |
2259 | bool slm = (astrategy.base.getModel() == ModelSLM); |
2260 | |
2261 | if (!pseudo && byte) return false; |
2262 | |
2263 | if (astrategy.newDP && !pseudo) { |
2264 | bool qword = ((atype.alignment | (consecutive * T)) % 8 == 0); |
2265 | block.ebytes = qword ? 8 : 4; |
2266 | maxElements = (64 * block.ebytes) / T; |
2267 | } else if (!pseudo) { |
2268 | int maxCount = 8; |
2269 | oword = !a64; |
2270 | aoword = ((atype.alignment & 0xF) != 0) || sc; |
2271 | if (hw > HW::Gen12LP) { |
2272 | oword |= (atype.alignment & 0x1F) != 0; |
2273 | if (slm) maxCount = 16; |
2274 | } |
2275 | block.ebytes = oword ? 16 : 32; |
2276 | maxElements = maxCount * block.ebytes / T; |
2277 | maskGranularity = 4; // Block accesses mask by dwords |
2278 | } else { |
2279 | bool nativeAtomic = astrategy.atomic |
2280 | && hasNativeAtomicAdd(hw, T.real(), atype, astrategy); |
2281 | canQW = ((atype.alignment | (consecutive * T)) % 8 == 0); |
2282 | if (!astrategy.newDP) canQW &= !byte && a64; |
2283 |                 // QW SLM atomics are implemented in XeHPC, but seeing functionality issues. |
2284 |                 if (slm && astrategy.atomic) canQW = false; |
2287 | if (remainderR || remainderC) canQW &= (T.size() % 8 == 0); |
2288 | if (nativeAtomic) canQW = mustQW = (T.real().size() >= 8); |
2289 | auto stride = canQW ? 8 : 4; |
2290 | auto maxNPack = byte1PerSlot ? 1 : std::max<int>(1, stride / T); |
2291 | int simdCap = maxSIMD; |
2292 | if (astrategy.atomic && !nativeAtomic) simdCap = 16; |
2293 | maxElements = simdCap * maxNPack; |
2294 | if (T.size() > stride) maxElements = maxElements * stride / T; |
2295 | } |
2296 | |
2297 | auto maxABlock = maxElements / (byte1PerSlot ? 1 : atype.crosspack); |
2298 | |
2299 | auto choosePackedRCBlock = [=, &block](int &xblock, int &yblock, |
2300 | int tileX, int tileY, int X, |
2301 | int Y) { |
2302 | xblock = std::min<int>(maxABlock, X); |
2303 | |
2304 | if (tileX) { |
2305 | int ntileY = tileY ? (maxElements / (xblock * tileY)) : 0; |
2306 | if (xblock < atype.packSize || Y < tileY || ntileY == 0) |
2307 | xblock = std::min<int>(xblock, tileX); |
2308 | } |
2309 | if ((tileX ? tileX : atype.packSize) <= xblock) { |
2310 | yblock = std::min<int>(maxElements / xblock, Y); |
2311 | if (yblock < atype.crosspack |
2312 | && isLargeCrosspack(T, atype.crosspack)) { |
2313 | yblock = atype.crosspack; |
2314 | xblock = std::min<int>(xblock, maxElements / yblock); |
2315 | } |
2316 | if (tileY > 0 && yblock > tileY) |
2317 | yblock = align_down(yblock, tileY); |
2318 | } else |
2319 | yblock = atype.crosspack; // Remainder loop: no longer packed in memory |
2320 | |
2321 | block.crosspack = atype.crosspack; |
2322 | Y = div_up(Y, atype.crosspack); |
2323 | }; |
2324 | |
2325 | switch (atype.layout) { |
2326 | case MatrixLayout::Pc: |
2327 | choosePackedRCBlock( |
2328 | rblock, cblock, atype.tileR, atype.tileC, R, C); |
2329 | break; |
2330 | case MatrixLayout::N: |
2331 | if (atype.crosspack > 1) stub(); |
2332 | if (atype.tileR == R && R <= maxElements) { |
2333 | cblock = std::min<int>(maxElements / R, C); |
2334 | rblock = R; |
2335 | } else { |
2336 | cblock = 1; |
2337 | rblock = std::min<int>(maxElements, R); |
2338 | } |
2339 | break; |
2340 | case MatrixLayout::Pr: |
2341 | choosePackedRCBlock( |
2342 | cblock, rblock, atype.tileC, atype.tileR, C, R); |
2343 | break; |
2344 | case MatrixLayout::T: |
2345 | if (atype.crosspack > 1) stub(); |
2346 | if (atype.tileC == C && C <= maxElements) { |
2347 |                         rblock = std::min<int>(maxElements / C, R); |
2348 | cblock = C; |
2349 | } else { |
2350 | rblock = 1; |
2351 | cblock = std::min<int>(maxElements, C); |
2352 | } |
2353 | break; |
2354 | } |
2355 | |
2356 | rblock = std::min(rblock, maxRBlock); |
2357 | cblock = std::min(cblock, maxCBlock); |
2358 | |
2359 | if (pseudo) { |
2360 | bool qword = mustQW |
2361 | || (canQW && (rblock * cblock * T >= 4 * maxSIMD)); |
2362 | npack = std::max<int>(1, (qword ? 8 : 4) / T); |
2363 | if (byte1PerSlot) { |
2364 | if (isLargeCrosspack(T, block.crosspack)) { |
2365 | if (block.crosspack == (colMajor ? cblock : rblock)) |
2366 | block.colMajor = colMajor = effCM; |
2367 | else |
2368 | stub(); |
2369 | } |
2370 | block.crosspack = npack; |
2371 | npack = 1; |
2372 | (effCM ? cblock : rblock) = 1; |
2373 | } |
2374 | maskGranularity = qword ? 8 : byte1PerSlot ? T.size() : 4; |
2375 | } |
2376 | |
2377 | if (remainderR) { |
2378 | if (effCM) { |
2379 | // rblock cannot be more than 16 dwords = 64 bytes for masking |
2380 | // except for pseudo-block |
2381 | int rblockLimit = pseudo ? rblock : 64 / T; |
2382 | |
2383 | if (avoidFragment) |
2384 | rblock = std::min<int>(rblock, rblockLimit); |
2385 | if (rblock > rblockLimit) |
2386 | block.rowFragment = rblockLimit; |
2387 | else { |
2388 | // For sizeof(T) < maskGranularity, this is a bit of a cheat. |
2389 | // |
2390 | // As long as we do not need to write to this matrix, we can read |
2391 | // in maskGranularity-sized chunks knowing we will never cross a page boundary. |
2392 | |
2393 | if (writable && (T.size() & (maskGranularity - 1))) |
2394 | return false; |
2395 | if (!pseudo && oword && aoword) hw_unsupported(); |
2396 | |
2397 | if (!pseudo |
2398 | && !(isPacked(atype.layout) |
2399 | && (atype.packSize == rblock))) |
2400 | cblock = 1; |
2401 | |
2402 | vrmask.isFixed = false; |
2403 | vrmask.rsize = rblock; |
2404 | vrmask.bitRep |
2405 | = std::max<int>(T.size() / maskGranularity, 1); |
2406 | vrmask.maskRep = cblock; |
2407 | vrmask.rdivide = std::max<int>(maskGranularity / T, 1); |
2408 | } |
2409 | } else { |
2410 | if (avoidFragment && !remainderC) { |
2411 | // No native masking in this dimension. One mask/row. |
2412 | rblock = 1; |
2413 | vrmask.isFixed = false; |
2414 | vrmask.bitRep = 16; |
2415 | vrmask.maskRep = 1; |
2416 | vrmask.rdivide = 1; |
2417 | vrmask.rsize = 1; |
2418 | } else { |
2419 | // Fragment it. Could actually handle rowFragment = 2 by changing descriptor. |
2420 | block.rowFragment = 1; |
2421 | } |
2422 | } |
2423 | } |
2424 | |
2425 | if (remainderC) { |
2426 | if (!effCM) { |
2427 | // cblock cannot be more than 16 dwords = 64 bytes except for pseudo-block |
2428 | int cblockLimit = pseudo ? cblock : 64 / T; |
2429 | |
2430 | if (avoidFragment) |
2431 | cblock = std::min<int>(cblock, cblockLimit); |
2432 | if (cblock > cblockLimit) |
2433 | block.colFragment = cblockLimit; |
2434 | else { |
2435 | if (writable && (T.size() & (maskGranularity - 1))) |
2436 | return false; |
2437 | if (!pseudo && oword && aoword) hw_unsupported(); |
2438 | |
2439 | if (!pseudo |
2440 | && !(isPacked(atype.layout) |
2441 | && (atype.packSize == cblock))) |
2442 | rblock = 1; |
2443 | |
2444 | vcmask.isFixed = false; |
2445 | vcmask.rsize = cblock; |
2446 | vcmask.bitRep |
2447 | = std::max<int>(T.size() / maskGranularity, 1); |
2448 | vcmask.maskRep = rblock; |
2449 | vcmask.rdivide = std::max<int>(maskGranularity / T, 1); |
2450 | } |
2451 | } else { |
2452 | if (avoidFragment && !remainderR) { |
2453 | // No native masking in this dimension. One mask/column. |
2454 | cblock = 1; |
2455 | vcmask.isFixed = false; |
2456 | vcmask.bitRep = 16; |
2457 | vcmask.maskRep = 1; |
2458 | vcmask.rdivide = 1; |
2459 | vcmask.rsize = 1; |
2460 | } else { |
2461 | // Fragment it. Could actually handle colFragment = 2 by changing descriptor. |
2462 | block.colFragment = 1; |
2463 | } |
2464 | } |
2465 | } |
2466 | |
2467 | int nbytes = (rblock * cblock) * T; |
2468 | block.simdSize |
2469 | = clamp(roundup_pow2(nbytes) / maskGranularity, 1, maxSIMD); |
2470 | block.ld = colMajor ? rblock : cblock; |
2471 | if (!pseudo) { |
2472 | if (astrategy.newDP) block.simdSize = 1; |
2473 | block.count = div_up(nbytes, block.ebytes); |
2474 | block.extra = aoword; |
2475 | if (block.ebytes == 16 && !(a32 || a64) |
2476 | && !aoword) // BTS/SLM oword loads are oword-addressed. |
2477 | block.addrShift = 4; |
2478 | } else { |
2479 | block.count = byte ? std::min(nbytes, npack * T) : 1; |
2480 | block.ebytes = byte ? 1 : maskGranularity; |
2481 | block.extra = 1; |
2482 | if (!(a32 || a64 |
2483 | || pseudoblockUseSurface(atype, astrategy, block) |
2484 | || astrategy.atomic)) |
2485 | block.addrShift = log2(block.ebytes); |
2486 | } |
2487 | if (astrategy.newDP) block.addrShift = 0; |
2488 | break; |
2489 | } |
2490 | case AccessType::Block2D: |
2491 | case AccessType::Block2DTranspose: |
2492 | case AccessType::Block2DVNNI: { |
2493 | // bytes * array length <= 8 |
2494 | // width * array length <= 64 bytes |
2495 | // => width <= 1 GRF |
2496 | // height <= 32 (load) 8 (store) |
2497 | // array length = 1 for store, transpose |
2498 | // |
2499 | // normal: width >= 4 bytes |
2500 | // transpose: d32 only |
2501 | // vnni: d8/d16 only, height >= 4 bytes |
2502 | bool transpose = (accessType == AccessType::Block2DTranspose); |
2503 | bool vnni = (accessType == AccessType::Block2DVNNI); |
2504 | |
2505 | bool memCM = block.colMajor; |
2506 | block.colMajor ^= transpose; |
2507 | auto X = memCM ? R : C; |
2508 | auto Y = memCM ? C : R; |
2509 | auto &xblock = memCM ? rblock : cblock; |
2510 | auto &yblock = memCM ? cblock : rblock; |
2511 | auto maxXBlock = memCM ? maxRBlock : maxCBlock; |
2512 | auto maxYBlock = memCM ? maxCBlock : maxRBlock; |
2513 | |
2514 | if (hw != HW::XeHPC || !astrategy.newDP) hw_unsupported(); |
2515 | |
2516 | // Choose underlying type. |
2517 | auto Tblock = T; |
2518 | if (transpose) { |
2519 | if (atype.alignment % 4) hw_unsupported(); |
2520 | if (Tblock.size() > 8) hw_unsupported(); |
2521 | if (Tblock.size() > 4) { |
2522 | if (hw == HW::XeHPC && getStepping() < SteppingPVCXTB0) |
2523 | hw_unsupported(); |
2524 | Tblock = Type::u64; |
2525 | maxXBlock = std::min(maxXBlock, 4); |
2526 | maxYBlock = 8; |
2527 | } else { |
2528 | Tblock = Type::u32; |
2529 | maxXBlock = std::min(maxXBlock, (8 * Tblock) / T); |
2530 | } |
2531 | } else if (vnni) { |
2532 | if (atype.alignment % 8) hw_unsupported(); |
2533 | if (Tblock.size() >= 4) hw_unsupported(); |
2534 | if ((Y * Tblock) % 4) hw_unsupported(); |
2535 | maxXBlock = std::min(maxXBlock, 16); |
2536 | } else { |
2537 | if (atype.alignment % 8) hw_unsupported(); |
2538 | if (Tblock.size() > 8) Tblock = Type::u64; |
2539 | block.crosspack = atype.crosspack; |
2540 | } |
2541 | if ((X * T) % 4) hw_unsupported(); |
2542 | |
2543 | // Reinterpret X/maxXBlock to underlying type. |
2544 | maxXBlock = (maxXBlock * T) / Tblock; |
2545 | auto X_logical = X; |
2546 | X = (X * T) / Tblock; |
2547 | |
2548 | // Carve out a maximal allowed block size. |
2549 | xblock = std::min(X, 64 / Tblock); |
2550 | xblock = std::max(xblock, 4 / Tblock); |
2551 | int yblockLimit = writable ? 8 : 32; |
2552 | |
2553 | if (isPacked(atype.layout) && 2 * xblock <= X |
2554 | && X_logical == atype.packSize) { |
2555 |                 // Split logical x dimension into multiple spans to accommodate width restriction. |
2556 | if (astrategy.address2D) stub(); |
2557 | int multiX = X / xblock; |
2558 | xblock *= multiX; |
2559 | yblockLimit /= multiX; |
2560 | } |
2561 | |
2562 | yblock = std::min({maxYBlock, Y, yblockLimit}); |
2563 | |
2564 | if (transpose && Tblock.size() == 8 && yblock != 8) |
2565 | hw_unsupported(); |
2566 | |
2567 | // Choose # of blocks. In postprocessLayout, this RegisterBlock will be |
2568 | // split into one RegisterBlock for each block in the array. |
2569 | int count = 1; |
2570 | if (!(writable || transpose)) { |
2571 | count = rounddown_pow2(xblock / maxXBlock); |
2572 | count = std::min({count, 8 / Tblock, 64 / xblock}); |
2573 | count = std::max(count, 1); |
2574 | } |
2575 | xblock = std::min(xblock, maxXBlock * count); |
2576 | |
2577 | // Crosspack calculation. |
2578 | int crosspack = (transpose || vnni) ? std::max(1, 4 / T) : 1; |
2579 | if (atype.crosspack == 1) |
2580 | block.crosspack = crosspack; |
2581 | else if (atype.crosspack == crosspack) |
2582 | block.crosspack = 1; |
2583 | else |
2584 | return false; |
2585 | |
2586 | // Convert size from underlying type to our actual type. |
2587 | xblock = (xblock * Tblock) / T; |
2588 | |
2589 | block.simdSize = 1; |
2590 | block.ld = roundup_pow2(transpose ? yblock : xblock); |
2591 | block.ebytes = Tblock.size(); |
2592 | block.count = count; |
2593 | block.extra = T.size(); |
2594 | auto bytes = align_up((block.colMajor ? cblock : rblock) / count, |
2595 | block.crosspack) |
2596 | * block.ld * count * T; |
2597 | block.msgRegs = GRF::bytesToGRFs(hw, bytes); |
2598 | break; |
2599 | } |
2600 | } |
2601 | |
2602 | // The mask moduli are almost always rblock/cblock. |
2603 | // Also, clamp mask reps to ensure mask length does not exceed SIMD size. |
2604 | if (block.rowMask && !block.rowMask.fixed.isFixed) { |
2605 | if (vrmask.rsize == 0) vrmask.rsize = rblock; |
2606 | vrmask.maskRep = std::min<int>(vrmask.maskRep, |
2607 | std::max<int>(1, |
2608 | vrmask.rdivide * block.simdSize |
2609 | / (vrmask.bitRep * vrmask.rsize))); |
2610 | block.noRowsOK = true; // All-zero masks are always OK. |
2611 | } |
2612 | if (block.colMask && !block.colMask.fixed.isFixed) { |
2613 | if (vcmask.rsize == 0) vcmask.rsize = cblock; |
2614 | vcmask.maskRep = std::min<int>(vcmask.maskRep, |
2615 | std::max<int>(1, |
2616 | vcmask.rdivide * block.simdSize |
2617 | / (vcmask.bitRep * vcmask.rsize))); |
2618 | block.noColsOK = true; |
2619 | } |
2620 | |
2621 | return true; |
2622 | } |
2623 | |
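// Attempt to add remainder masking to a block without changing its layout. Returns true if successful.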
2624 | template <HW hw> |
2625 | bool gemm_kernel_generator_t<hw>::tryAddMasking(Type T, RegisterBlock &block, |
2626 | bool remainderR, bool remainderC, const MatrixAddressing &atype, |
2627 | const MatrixAddressingStrategy &astrategy) { |
2628 | auto blockNew = block; |
2629 | blockNew.remainderR |= remainderR; |
2630 | blockNew.remainderC |= remainderC; |
2631 | |
2632 | auto curAccessType = implAccessType(atype, astrategy, block); |
2633 | |
2634 | if (curAccessType == AccessType::Block) { |
2635 | if (astrategy.newDP) return false; |
2636 | if (hw >= HW::XeHPC) return false; |
2637 | } |
2638 | |
2639 | bool remChanged = (block.colMajor ? (remainderR && !block.remainderR) |
2640 | : (remainderC && !block.remainderC)); |
2641 | |
2642 | if (remChanged && !isBlock2D(curAccessType)) { |
2643 | int rblock, cblock; |
2644 | if (!getBlockInfo(T, atype, astrategy, block.nr, block.nc, |
2645 | blockNew.remainderR, blockNew.remainderC, block.writable, |
2646 | true, 0, 0, rblock, cblock, blockNew)) |
2647 | return false; |
2648 | if (rblock != block.nr || cblock != block.nc) return false; |
2649 | if (implAccessType(atype, astrategy, blockNew) != curAccessType) |
2650 | return false; |
2651 | if (curAccessType != AccessType::Block) { |
2652 | if (blockNew.ebytes != block.ebytes) return false; |
2653 | if (blockNew.ebytes == 1 && blockNew.count != block.count) |
2654 | return false; |
2655 | } |
2656 | } |
2657 | |
2658 | block = blockNew; |
2659 | return true; |
2660 | } |
2661 | |
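// Attempt to add remainder masking to each block in a layout, without changing the layout. Returns true if successful.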
2662 | template <HW hw> |
2663 | bool gemm_kernel_generator_t<hw>::tryAddMasking(Type T, |
2664 | vector<RegisterBlock> &layout, bool remainderR, bool remainderC, |
2665 | const MatrixAddressing &atype, |
2666 | const MatrixAddressingStrategy &astrategy) { |
2667 | auto layoutNew = layout; |
2668 | for (auto &block : layoutNew) { |
2669 | if (!tryAddMasking(T, block, remainderR, remainderC, atype, astrategy)) |
2670 | return false; |
2671 | } |
2672 | std::swap(layout, layoutNew); |
2673 | return true; |
2674 | } |
2675 | |
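// Add remainder masking to each block in a layout, failing if any block cannot be masked in place.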
2676 | template <HW hw> |
2677 | void gemm_kernel_generator_t<hw>::addMasking(Type T, |
2678 | vector<RegisterBlock> &layout, bool remainderR, bool remainderC, |
2679 | const MatrixAddressing &atype, |
2680 | const MatrixAddressingStrategy &astrategy) { |
2681 | for (auto &block : layout) |
2682 | if (!tryAddMasking(T, block, remainderR, remainderC, atype, astrategy)) |
2683 | stub(); |
2684 | } |
2685 | |
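// Add remainder masking to a layout, recreating the layout and its address registers
//   if masking cannot be added in place.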
2686 | template <HW hw> |
2687 | void gemm_kernel_generator_t<hw>::addMasking(Type T, |
2688 | vector<RegisterBlock> &layout, vector<GRFRange> &addrs, |
2689 | const Subregister &ld, bool remainderR, bool remainderC, |
2690 | const MatrixAddressing &atype, |
2691 | const MatrixAddressingStrategy &astrategy, |
2692 | const CommonStrategy &strategy, CommonState &state, int dataRegs) { |
2693 | // Check if masking can be trivially enabled without changing the layout. |
2694 | if (tryAddMasking(T, layout, remainderR, remainderC, atype, astrategy)) |
2695 | return; |
2696 | |
2697 | // If not, tear down the old layout and create a new one in its place, recalculating address registers. |
2698 | vector<RegisterBlock> layoutNew; |
2699 | int r, c; |
2700 | bool remR = remainderR || hasRemainders(layout, true, false); |
2701 | bool remC = remainderC || hasRemainders(layout, false, true); |
2702 | getLayoutDims(layout, r, c); |
2703 | if (!getRegLayout(T, layoutNew, r, c, remR, remC, false, true, 0, 0, atype, |
2704 | astrategy)) |
2705 | stub(); |
2706 | if (dataRegs < 0) dataRegs = getRegCount(layout); |
2707 | if (getRegCount(layoutNew) > dataRegs) stub(); |
2708 | if (isLayoutColMajor(layoutNew) != isLayoutColMajor(layout)) stub(); |
2709 | |
2710 | int shift = 0; |
2711 | auto addr0 = getOriginAddr(layout, addrs, atype, astrategy, &shift); |
2712 | std::swap(layout, layoutNew); |
2713 | if (shift > 0) shl(1, addr0, addr0, shift); |
2714 | safeReleaseRanges(addrs, state); |
2715 | state.ra.claim(addr0); |
2716 | |
2717 | Address2DParams params2D {}; |
2718 | if (astrategy.address2D) stub(); |
2719 | allocAddrRegs(addrs, layout, atype, astrategy, state); |
2720 | setupAddr(T, addrs, addr0, layout, ld, atype, astrategy, strategy, state, |
2721 | params2D); |
2722 | |
2723 | state.ra.safeRelease(addr0); |
2724 | } |
2725 | |
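// Get the subblock of a block covering rows/columns [x1, x2), updating message parameters to match.
//   Returns false if not possible.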
2726 | template <HW hw> |
2727 | bool gemm_kernel_generator_t<hw>::getSubblock(Type T, RegisterBlock &blockDst, |
2728 | const RegisterBlock &blockSrc, bool column, int x1, int x2, |
2729 | int x1Unclamped, int x2Unclamped, bool overrunOK, |
2730 | const MatrixAddressing &atype, |
2731 | const MatrixAddressingStrategy &astrategy) { |
2732 | auto effAccessType = effectiveAccessType(atype, astrategy, blockSrc); |
2733 | blockDst = blockSrc; |
2734 | |
2735 | auto &ns = (column ? blockDst.nc : blockDst.nr); |
2736 | auto &nt = (column ? blockDst.nr : blockDst.nc); |
2737 | int oldNS = ns; |
2738 | |
2739 | (column ? blockDst.offsetC : blockDst.offsetR) += x1; |
2740 | ns = x2 - x1; |
2741 | |
2742 | if ((ns == oldNS) && (overrunOK || !blockSrc.hasNoLoad)) return true; |
2743 | |
2744 | if (blockSrc.colMajor == column) { |
2745 | if (x1 % blockSrc.crosspack) return false; |
2746 | |
2747 | blockDst.offsetBytes += (x1 * blockSrc.bytes) / oldNS; |
2748 | |
2749 | if (blockSrc.isLoadBlock()) switch (effAccessType) { |
2750 | case AccessType::Scattered: |
2751 | case AccessType::ChannelScattered: |
2752 | blockDst.count = x2 - x1; |
2753 | if (blockDst.ebytes == 1) |
2754 | blockDst.count *= T.size(); |
2755 | else if (blockDst.splitComplex) |
2756 | blockDst.count *= 2; |
2757 | else if (T.size() < blockDst.ebytes) { |
2758 | // Extra alignment path with small types. |
2759 | // Check to see if we can still use this element size, |
2760 |                         // if not, downgrade to scattered byte. |
2761 | // Note for surface accesses this requires shifting the addresses back. |
2762 | auto bcount = blockDst.count * T; |
2763 | if (bcount % 4) { |
2764 | blockDst.ebytes = 1; |
2765 | blockDst.addrShift = 0; |
2766 | blockDst.count = bcount; |
2767 | if (blockDst.count > 4) stub(); |
2768 | } else |
2769 | blockDst.count = bcount >> 2; |
2770 | } |
2771 | break; |
2772 | case AccessType::Block: |
2773 | case AccessType::PseudoBlock: { |
2774 | auto offBytes = x1 * nt * T; |
2775 | if (offBytes % blockDst.ebytes) return false; |
2776 | auto reqBytes = (x2 - x1) * nt * T; |
2777 | auto align = std::min<int>( |
2778 | blockDst.ebytes, blockDst.simdSize * 4); |
2779 | if (!overrunOK && (reqBytes & (align - 1))) return false; |
2780 | auto ncount = div_up(reqBytes, blockDst.ebytes); |
2781 | auto count = roundup_pow2(ncount); |
2782 | if (!overrunOK && (count != ncount)) return false; |
2783 | if (effAccessType == AccessType::Block) |
2784 | blockDst.count = count; |
2785 | else |
2786 | blockDst.simdSize = std::max(1, count / blockDst.count); |
2787 | break; |
2788 | } |
2789 | case AccessType::Block2D: break; |
2790 | case AccessType::Block2DTranspose: |
2791 | case AccessType::Block2DVNNI: |
2792 | int crosspack = std::max(1, 4 / blockDst.ebytes); |
2793 | if (x1 % crosspack || x2 % crosspack) return false; |
2794 | break; |
2795 | } |
2796 | |
2797 | blockDst.calcBytes(T, astrategy); |
2798 | } else { |
2799 | blockDst.offsetBytes += x1 * T * blockSrc.crosspack; |
2800 | |
2801 | if (blockSrc.isLoadBlock()) switch (effAccessType) { |
2802 | case AccessType::Block: |
2803 | case AccessType::PseudoBlock: { |
2804 | // Update count and mask information. |
2805 | // Beware, cheat: with DW-aligned sub-DW types, true block may be downgraded to byte PseudoBlock, |
2806 | // which requires 2 address registers, though only 1 is used, and only 1 may be allocated. |
2807 | int rblock, cblock; |
2808 | (void)getBlockInfo(T, atype, astrategy, blockDst.nr, |
2809 | blockDst.nc, blockDst.remainderR, |
2810 | blockDst.remainderC, blockDst.writable, false, 0, 0, |
2811 | rblock, cblock, blockDst); |
2812 | blockDst.simplify(T); |
2813 | break; |
2814 | } |
2815 | case AccessType::Scattered: |
2816 | case AccessType::ChannelScattered: { |
2817 | if (T.size() > blockDst.ebytes) return false; |
2818 | if (x1 != 0) return false; |
2819 | if (!is_zero_or_pow2(x2)) return false; |
2820 | |
2821 | blockDst.simdSize = div_up(ns * T, blockDst.ebytes); |
2822 | |
2823 | auto minSIMD = minScatteredSIMD(hw, astrategy); |
2824 | if (blockDst.simdSize <= minSIMD |
2825 | && blockSrc.simdSize > minSIMD) { |
2826 | if (blockDst.count > 1 && blockDst.ebytes > 1) |
2827 | return false; |
2828 | blockDst.ld >>= 1; |
2829 | } |
2830 | break; |
2831 | } |
2832 | case AccessType::Block2D: |
2833 | case AccessType::Block2DTranspose: |
2834 | case AccessType::Block2DVNNI: |
2835 | if (ns != oldNS) |
2836 | stub(); // Can do this, but not implemented. |
2837 | if (blockDst.simdSize != 0) // Recompute block array length. |
2838 | blockDst.count = div_up(x2Unclamped, |
2839 | isColMajor(atype.layout) ? blockDst.nr |
2840 | : blockDst.nc); |
2841 | // TODO: need to recompute ld |
2842 | break; |
2843 | } |
2844 | |
2845 | blockDst.calcBytes(T, astrategy); |
2846 | } |
2847 | |
2848 | return true; |
2849 | } |
2850 | |
2851 | // Get list of subblocks intersecting rows/columns [x1, x2). |
2852 | template <HW hw> |
2853 | bool gemm_kernel_generator_t<hw>::getSubblocks(Type T, |
2854 | vector<RegisterBlock> &sublayout, const vector<RegisterBlock> &layout, |
2855 | bool column, int x1, int x2, bool overrunOK, |
2856 | const MatrixAddressing &atype, |
2857 | const MatrixAddressingStrategy &astrategy) { |
2858 | auto RegisterBlock::*nq = column ? &RegisterBlock::nc : &RegisterBlock::nr; |
2859 | auto RegisterBlock::*offsetQ |
2860 | = column ? &RegisterBlock::offsetC : &RegisterBlock::offsetR; |
2861 | |
2862 | sublayout.clear(); |
2863 | |
2864 | for (auto &block : layout) { |
2865 | int qq1Unclamped = x1 - block.*offsetQ; |
2866 | int qq2Unclamped = x2 - block.*offsetQ; |
2867 | int qq1 = clamp<int>(qq1Unclamped, 0, block.*nq); |
2868 | int qq2 = clamp<int>(qq2Unclamped, 0, block.*nq); |
2869 | if (qq2 > qq1) { |
2870 | RegisterBlock subblock; |
2871 | if (!getSubblock(T, subblock, block, column, qq1, qq2, qq1Unclamped, |
2872 | qq2Unclamped, overrunOK, atype, astrategy)) { |
2873 | status << "Could not make subblock." << status_stream::endl; |
2874 | return false; |
2875 | } |
2876 | sublayout.push_back(subblock); |
2877 | } |
2878 | } |
2879 | return true; |
2880 | } |
2881 | |
2882 | // Get list of subblocks intersecting rows/columns [x1, x2), and associated address registers and/or indices. |
2883 | // Returns false if fragmenting failed, or an address register doesn't match a previous one. |
2884 | template <HW hw> |
2885 | bool gemm_kernel_generator_t<hw>::getSubblocks(Type T, |
2886 | vector<RegisterBlock> &sublayout, vector<GRFRange> *subaddrs, |
2887 | vector<int> *indices, const vector<RegisterBlock> &layout, |
2888 | const vector<GRFRange> *addrs, bool column, int x1, int x2, |
2889 | bool overrunOK, const MatrixAddressing &atype, |
2890 | const MatrixAddressingStrategy &astrategy) { |
2891 | auto RegisterBlock::*nq = column ? &RegisterBlock::nc : &RegisterBlock::nr; |
2892 | auto RegisterBlock::*offsetQ |
2893 | = column ? &RegisterBlock::offsetC : &RegisterBlock::offsetR; |
2894 | |
2895 | if (subaddrs) subaddrs->clear(); |
2896 | if (indices) indices->clear(); |
2897 | sublayout.clear(); |
2898 | |
2899 | for (int b = 0; b < int(layout.size()); b++) { |
2900 | auto &block = layout[b]; |
2901 | int qq1Unclamped = x1 - block.*offsetQ; |
2902 | int qq2Unclamped = x2 - block.*offsetQ; |
2903 | int qq1 = clamp<int>(qq1Unclamped, 0, block.*nq); |
2904 | int qq2 = clamp<int>(qq2Unclamped, 0, block.*nq); |
2905 | if (qq2 > qq1) { |
2906 | RegisterBlock subblock; |
2907 | if (!getSubblock(T, subblock, block, column, qq1, qq2, qq1Unclamped, |
2908 | qq2Unclamped, overrunOK, atype, astrategy)) { |
2909 | status << "Could not make subblock." << status_stream::endl; |
2910 | return false; |
2911 | } |
2912 | if (subblock.offsetR != block.offsetR |
2913 | || subblock.offsetC != block.offsetC) { |
2914 | status << "Subblock is not aligned to parent block." |
2915 | << status_stream::endl; |
2916 | return false; |
2917 | } |
2918 | if (subaddrs) subaddrs->push_back((*addrs)[b]); |
2919 | if (indices) indices->push_back(int(b)); |
2920 | sublayout.push_back(subblock); |
2921 | } |
2922 | } |
2923 | return true; |
2924 | } |
2925 | |
2926 | // Get list of subblocks intersecting rows/columns [x1, x2), and associated address registers. |
2927 | // Returns false if fragmenting failed, or an address register doesn't match a previous one. |
2928 | template <HW hw> |
2929 | bool gemm_kernel_generator_t<hw>::getSubblocks(Type T, |
2930 | vector<RegisterBlock> &sublayout, vector<GRFRange> &subaddrs, |
2931 | const vector<RegisterBlock> &layout, const vector<GRFRange> &addrs, |
2932 | bool column, int x1, int x2, bool overrunOK, |
2933 | const MatrixAddressing &atype, |
2934 | const MatrixAddressingStrategy &astrategy) { |
2935 | return getSubblocks(T, sublayout, &subaddrs, nullptr, layout, &addrs, |
2936 | column, x1, x2, overrunOK, atype, astrategy); |
2937 | } |
2938 | |
2939 | // Get list of subblocks intersecting rows/columns [x1, x2), and indices of associated address registers. |
2940 | // Returns false if fragmenting failed, or an address register doesn't match a previous one. |
2941 | template <HW hw> |
2942 | bool gemm_kernel_generator_t<hw>::getSubblocks(Type T, |
2943 | vector<RegisterBlock> &sublayout, vector<int> &indices, |
2944 | const vector<RegisterBlock> &layout, bool column, int x1, int x2, |
2945 | bool overrunOK, const MatrixAddressing &atype, |
2946 | const MatrixAddressingStrategy &astrategy) { |
2947 | return getSubblocks(T, sublayout, nullptr, &indices, layout, nullptr, |
2948 | column, x1, x2, overrunOK, atype, astrategy); |
2949 | } |
2950 | |
2951 | // Adjust address registers as needed for a newly-created subblock. |
2952 | template <HW hw> |
2953 | void gemm_kernel_generator_t<hw>::adjustSubblockAddrs(Type T, |
2954 | const vector<RegisterBlock> &sublayout, |
2955 | const vector<GRFRange> &subaddrs, const vector<RegisterBlock> &layout, |
2956 | const vector<GRFRange> &addrs, const MatrixAddressing &atype, |
2957 | const MatrixAddressingStrategy &astrategy, |
2958 | const CommonStrategy &strategy, const CommonState &state) { |
2959 | bool a64 = (astrategy.base.getModel() == ModelA64); |
2960 | |
2961 | auto nsubs = int(sublayout.size()); |
2962 | auto nblocks = int(layout.size()); |
2963 | |
2964 | for (int isub = 0; isub < nsubs; isub++) { |
2965 | // Find parent block by comparing address registers. |
2966 | auto &subaddr = subaddrs[isub]; |
2967 | const RegisterBlock *pptr = nullptr; |
2968 | for (int i = 0; i < nblocks; i++) { |
2969 | if (addrs[i].getBase() == subaddr.getBase()) { |
2970 | pptr = &layout[i]; |
2971 | break; |
2972 | } |
2973 | } |
2974 | if (!pptr) stub(); |
2975 | |
2976 | auto &block = *pptr; |
2977 | auto &subblock = sublayout[isub]; |
2978 | |
2979 | auto off = getAddr0Offset(block, atype, astrategy); |
2980 | auto suboff = getAddr0Offset(subblock, atype, astrategy); |
2981 | |
2982 | // Perform any necessary shifts/moves. Moves are only for non-A64 block->pseudoblock settings. |
2983 | if (suboff != off) { |
2984 | if (subblock.simdSize != 1) |
2985 | stub(); // Need to prepare more pseudoblock addresses. |
2986 | mov<uint32_t>(1, subaddr[0][suboff], subaddr[0][off]); |
2987 | } |
2988 | if (subblock.addrShift != block.addrShift) { |
2989 | map(hw, a64 ? Type::u64 : Type::u32, subaddr, subaddr, strategy, |
2990 | [&](int simd, GRF r, GRF _) { |
2991 | auto shift = block.addrShift - subblock.addrShift; |
2992 | (shift > 0) ? eshl(simd, r, r, +shift, strategy, state) |
2993 | : eshr(simd, r, r, -shift, strategy, state); |
2994 | }); |
2995 | } |
2996 | |
2997 | if (isBlock2D(astrategy.accessType)) { |
2998 | // Adjust 2D block header as needed. |
2999 | int bw, bh; |
3000 | bool memCM = isColMajor(atype.layout); |
3001 | auto RegisterBlock::*nw |
3002 | = memCM ? &RegisterBlock::nr : &RegisterBlock::nc; |
3003 | auto RegisterBlock::*nh |
3004 | = memCM ? &RegisterBlock::nc : &RegisterBlock::nr; |
3005 | bool remW = memCM ? subblock.remainderR : subblock.remainderC; |
3006 | bool remH = memCM ? subblock.remainderC : subblock.remainderR; |
3007 | getBlock2DWH(bw, bh, atype, subblock); |
3008 | |
3009 | if (!astrategy.address2D) { |
3010 | if (subblock.*nw != block.*nw |
3011 | || subblock.count != block.count) { |
3012 | int newW = bw * subblock.count * subblock.ebytes - 1; |
3013 | remW ? min_(1, subaddr[0].ud(2), subaddr[0].ud(2), newW) |
3014 | : mov(1, subaddr[0].ud(2), newW); |
3015 | } |
3016 | if (subblock.*nh != block.*nh) { |
3017 | int newH = bh * subblock.ebytes - 1; |
3018 | remH ? min_(1, subaddr[0].ud(3), subaddr[0].ud(3), newH) |
3019 | : mov(1, subaddr[0].ud(3), newH); |
3020 | } |
3021 | } |
3022 | if (subblock.nr != block.nr || subblock.nc != block.nc |
3023 | || subblock.count != block.count) |
3024 | mov(1, subaddr[0].ud(7), |
3025 | (bw - 1) | ((bh - 1) << 8) |
3026 | | ((subblock.count - 1) << 16)); |
3027 | } |
3028 | } |
3029 | } |
3030 | |
3031 | // Split 2D block array loads into multiple blocks. |
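// Illustrative example (dimensions assumed): a column-major block array load
// with nr = 32 and count = 4 splits into four nr = 8 blocks whose offsetR
// advances by 8 each time. Only the first block keeps a nonzero simdSize,
// since the original message already fetches data for all four blocks.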
3032 | static inline void postprocessLayout2D(vector<RegisterBlock> &layout, |
3033 | const MatrixAddressing &atype, |
3034 | const MatrixAddressingStrategy &astrategy) { |
3035 | if (!isBlock2D(astrategy.accessType)) return; |
3036 | |
3037 | int maxCount = 1; |
3038 | for (auto &block : layout) |
3039 | maxCount = std::max(maxCount, int(block.count)); |
3040 | if (maxCount == 1) return; |
3041 | |
3042 | vector<RegisterBlock> xlayout; |
3043 | xlayout.reserve(layout.size() * maxCount); |
3044 | |
3045 | bool memCM = isColMajor(atype.layout); |
3046 | auto RegisterBlock::*nx = memCM ? &RegisterBlock::nr : &RegisterBlock::nc; |
3047 | auto RegisterBlock::*offsetX |
3048 | = memCM ? &RegisterBlock::offsetR : &RegisterBlock::offsetC; |
3049 | |
3050 | for (auto &block : layout) { |
3051 | auto nblock = block; |
3052 | nblock.*nx /= block.count; |
3053 | if (!isTransposing(astrategy.accessType)) nblock.ld /= block.count; |
3054 | |
3055 | for (int i = 0; i < block.count; i++) { |
3056 | xlayout.push_back(nblock); |
3057 | nblock.*offsetX += nblock.*nx; |
            nblock.simdSize = 0; // Blocks after the first do not need loads.
3059 | } |
3060 | } |
3061 | |
3062 | std::swap(layout, xlayout); |
3063 | } |
3064 | |
3065 | // Split blocks that span multiple tiles. Requires each tile to be contained within a single block. |
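// Illustrative example (dimensions assumed): with tileR = tileC = 8, a
// column-major 16x8 block spans two 8x8 tiles, so it is split into two
// blocks with offsetR = 0 and 8, each with ld = 8; blocks after the first
// get simdSize = 0 as they need no separate loads.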
3066 | static inline void postprocessLayoutMultitile(Type T, |
3067 | vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
3068 | const MatrixAddressingStrategy &astrategy) { |
3069 | if (!atype.tileR || !atype.tileC) return; |
3070 | if (isLargeCrosspack(T, atype.crosspack)) return; |
3071 | |
3072 | bool needToSplit = false; |
3073 | for (const auto &block : layout) |
3074 | needToSplit |= (block.colMajor ? (block.nr > atype.tileR) |
3075 | : (block.nc > atype.tileC)); |
3076 | |
3077 | if (!needToSplit) return; |
3078 | |
3079 | vector<RegisterBlock> xlayout; |
3080 | xlayout.reserve(layout.size()); |
3081 | |
3082 | for (const auto &block : layout) { |
3083 | auto nx = block.colMajor ? &RegisterBlock::nr : &RegisterBlock::nc; |
3084 | auto ny = block.colMajor ? &RegisterBlock::nc : &RegisterBlock::nr; |
3085 | auto offsetX = block.colMajor ? &RegisterBlock::offsetR |
3086 | : &RegisterBlock::offsetC; |
3087 | auto offsetY = block.colMajor ? &RegisterBlock::offsetC |
3088 | : &RegisterBlock::offsetR; |
3089 | auto tileX = block.colMajor ? atype.tileR : atype.tileC; |
3090 | auto tileY = block.colMajor ? atype.tileC : atype.tileR; |
3091 | |
3092 | if (block.*nx == tileX) { |
3093 | xlayout.push_back(block); |
3094 | continue; |
3095 | } |
3096 | |
3097 | if (block.*nx % tileX || block.*offsetX % tileX || block.*ny % tileY |
3098 | || block.*offsetY % tileY) |
3099 | stub(); |
3100 | if (isTransposing(astrategy.accessType)) stub(); |
3101 | |
3102 | auto nblock = block; |
3103 | nblock.*nx = tileX; |
3104 | nblock.*ny = tileY; |
3105 | nblock.ld = tileX; |
3106 | |
3107 | for (int j = 0; j < block.*ny / tileY; j++) { |
3108 | for (int i = 0; i < block.*nx / tileX; i++) { |
3109 | nblock.*offsetX = block.*offsetX + i * tileX; |
3110 | nblock.*offsetY = block.*offsetY + j * tileY; |
3111 | xlayout.push_back(nblock); |
3112 | nblock.simdSize = 0; |
3113 | } |
3114 | } |
3115 | } |
3116 | |
3117 | std::swap(layout, xlayout); |
3118 | } |
3119 | |
3120 | // Split large crosspack blocks into smaller pieces so that they can be transposed. |
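// Illustrative example (dimensions assumed): a column-major block with
// nc = 16 and crosspack = 4 splits into four single-crosspack slices
// (nc = 4) whose offsetC advances by 4; only the first slice keeps its
// simdSize.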
3121 | static inline void postprocessLayoutLargeCP(Type T, |
3122 | vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
3123 | const MatrixAddressingStrategy &astrategy) { |
3124 | if (!isLargeCrosspack(T, atype.crosspack)) return; |
3125 | |
3126 | bool haveLargeCP = false; |
3127 | for (const auto &block : layout) { |
3128 | haveLargeCP |= isLargeCrosspack(T, block.crosspack); |
3129 | if (haveLargeCP) break; |
3130 | } |
3131 | |
3132 | if (!haveLargeCP) return; |
3133 | |
3134 | vector<RegisterBlock> xlayout; |
3135 | xlayout.reserve(layout.size()); |
3136 | |
3137 | for (const auto &block : layout) { |
3138 | if (!isLargeCrosspack(T, block.crosspack)) |
3139 | xlayout.push_back(block); |
3140 | else { |
3141 | auto ny = block.colMajor ? &RegisterBlock::nc : &RegisterBlock::nr; |
3142 | auto offsetY = block.colMajor ? &RegisterBlock::offsetC |
3143 | : &RegisterBlock::offsetR; |
3144 | |
3145 | if (block.*ny % block.crosspack) return; |
3146 | int blocks = (block.*ny / block.crosspack); |
3147 | auto nblock = block; |
3148 | nblock.*ny = block.crosspack; |
3149 | nblock.simplify(T); |
3150 | for (int i = 0; i < blocks; i++) { |
3151 | xlayout.push_back(nblock); |
3152 | nblock.simdSize = 0; |
3153 | nblock.*offsetY += nblock.*ny; |
3154 | } |
3155 | } |
3156 | } |
3157 | |
3158 | std::swap(layout, xlayout); |
3159 | } |
3160 | |
3161 | // Remove unneeded blocks from a dpasw src2 layout. |
3162 | static inline void postprocessLayoutDPASW(vector<RegisterBlock> &layout, |
3163 | const MatrixAddressing &atype, |
3164 | const MatrixAddressingStrategy &astrategy) { |
3165 | if (!astrategy.dpasw) return; |
3166 | |
3167 | vector<RegisterBlock> nlayout; |
3168 | nlayout.reserve(layout.size() / 2); |
3169 | |
3170 | bool cm = isLayoutColMajor(layout); |
3171 | auto tile = cm ? astrategy.tileC : astrategy.tileR; |
3172 | auto offsetX = cm ? &RegisterBlock::offsetC : &RegisterBlock::offsetR; |
3173 | |
3174 | for (const auto &block : layout) |
3175 | if ((block.*offsetX % (2 * tile)) < tile) nlayout.push_back(block); |
3176 | |
3177 | layout = std::move(nlayout); |
3178 | } |
3179 | |
3180 | static inline void postprocessLayout(Type T, vector<RegisterBlock> &layout, |
3181 | const MatrixAddressing &atype, |
3182 | const MatrixAddressingStrategy &astrategy) { |
3183 | postprocessLayout2D(layout, atype, astrategy); |
3184 | postprocessLayoutMultitile(T, layout, atype, astrategy); |
3185 | postprocessLayoutLargeCP(T, layout, atype, astrategy); |
3186 | postprocessLayoutDPASW(layout, atype, astrategy); |
3187 | } |
3188 | |
3189 | // Add a submatrix to a register layout. |
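// Illustrative example (dimensions assumed): for a 13x7 submatrix with chosen
// block size 8x4, the main grid covers the 8x4 region at offset (0,0); the
// row remainder (5x4 at offset (8,0)) and the column remainder (13x3 at
// offset (0,4)) are then handled by the recursive calls below.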
3190 | template <HW hw> |
3191 | bool gemm_kernel_generator_t<hw>::addToRegLayout(Type T, |
3192 | std::vector<RegisterBlock> &layout, int nr, int nc, int roff, int coff, |
3193 | bool remainderR, bool remainderC, bool writable, bool avoidFragment, |
3194 | int maxRBlock, int maxCBlock, const MatrixAddressing &atype, |
3195 | const MatrixAddressingStrategy &astrategy) { |
3196 | int rblock, cblock; |
3197 | RegisterBlock blockTemplate; |
3198 | if (!getBlockInfo(T, atype, astrategy, nr, nc, remainderR, remainderC, |
3199 | writable, avoidFragment, maxRBlock, maxCBlock, rblock, cblock, |
3200 | blockTemplate)) |
3201 | return false; /* Cannot handle requested block and remainder. */ |
3202 | |
3203 | if (rblock == 0 || cblock == 0) return false; |
3204 | |
3205 | blockTemplate.nr = rblock; |
3206 | blockTemplate.nc = cblock; |
3207 | |
3208 | for (int q = 0; q < T.components(); q++) { |
3209 | blockTemplate.component = q; |
3210 | if (isColMajor(atype.layout)) { |
3211 | // Order blocks in column-major fashion. |
3212 | for (int c = 0; c + cblock <= nc; c += cblock) { |
3213 | for (int r = 0; r + rblock <= nr; r += rblock) { |
3214 | auto thisBlock = blockTemplate; |
3215 | |
3216 | thisBlock.offsetR = r + roff; |
3217 | thisBlock.offsetC = c + coff; |
3218 | |
3219 | layout.push_back(thisBlock); |
3220 | } |
3221 | } |
3222 | } else { |
3223 | // Order blocks in row-major fashion. |
3224 | for (int r = 0; r + rblock <= nr; r += rblock) { |
3225 | for (int c = 0; c + cblock <= nc; c += cblock) { |
3226 | auto thisBlock = blockTemplate; |
3227 | |
3228 | thisBlock.offsetR = r + roff; |
3229 | thisBlock.offsetC = c + coff; |
3230 | |
3231 | layout.push_back(thisBlock); |
3232 | } |
3233 | } |
3234 | } |
3235 | } |
3236 | |
3237 | // Handle remainder recursively, checking for infinite recursion. |
3238 | int rrem = nr % rblock; |
3239 | int crem = nc % cblock; |
3240 | |
3241 | status << "Register layout: " << nr << 'x' << nc << " -> blocks " << rblock |
3242 | << 'x' << cblock << " remainder " << rrem << 'x' << crem |
3243 | << status_stream::endl; |
3244 | |
3245 | bool success = true; |
3246 | if (rrem || crem) { |
3247 | if ((nr == rrem || rrem == 0) && (nc == crem || crem == 0)) { |
3248 | status << "Cannot load/store requested matrix block size." |
3249 | << status_stream::endl; |
3250 | success = false; |
3251 | } else { |
3252 | if (rrem) |
3253 | success &= addToRegLayout(T, layout, rrem, nc - crem, nr - rrem, |
3254 | 0, remainderR, remainderC, writable, avoidFragment, |
3255 | maxRBlock, maxCBlock, atype, astrategy); |
3256 | if (crem) |
3257 | success &= addToRegLayout(T, layout, nr, crem, 0, nc - crem, |
3258 | remainderR, remainderC, writable, avoidFragment, |
3259 | maxRBlock, maxCBlock, atype, astrategy); |
3260 | } |
3261 | } |
3262 | return success; |
3263 | } |
3264 | |
3265 | // Add a submatrix (contiguous in memory) to a block-accessed register layout. |
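// Illustrative example (legacy oword path, sizes assumed): with ebytes = 16
// and maxBBytes = 128, a 320-byte load is carved into chunks of 128, 128,
// and 64 bytes; each chunk becomes one true load block plus zero or more
// no-load blocks covering the additional rows/columns it spans.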
3266 | template <HW hw> |
3267 | bool gemm_kernel_generator_t<hw>::add1DBlockToRegLayout(Type T, |
3268 | vector<RegisterBlock> &layout, int r, int c, bool writable, |
3269 | const MatrixAddressing &atype, |
3270 | const MatrixAddressingStrategy &astrategy) { |
    // Skip pseudoblock cases (possible to support, but not currently implemented).
3272 | if (needsPseudoblock(hw, T, r, c, atype, astrategy, writable, false)) |
3273 | return false; |
3274 | |
    // Get total number of bytes to load. No masking is supported, so bail out
    // (return false) if the byte count isn't aligned to the message size
    // (16 bytes = 1 oword for legacy messages, 4 bytes for LSC).
3277 | int nbytes = r * c * T * T.components(); |
3278 | int align = 16; |
3279 | if (astrategy.newDP) align = 4; |
3280 | |
3281 | if (nbytes & (align - 1)) return false; |
3282 | |
3283 | // Get block info. |
3284 | int maxBBytes = 0; |
3285 | int ebytes = 0; |
    int extra = 0;
3287 | int addrShift = 0; |
3288 | int maxSIMD = 1; |
3289 | |
3290 | if (astrategy.newDP) { |
3291 | bool qword = (nbytes | atype.alignment) % 8 == 0; |
3292 | ebytes = qword ? 8 : 4; |
3293 | maxBBytes = ebytes * 64; |
3294 | } else { |
3295 | bool a64 = (astrategy.base.getModel() == ModelA64); |
3296 | bool oword = !a64; |
3297 | bool aoword = (astrategy.base.getModel() |
3298 | == ModelSC); // SC only does aligned oword |
3299 | if (hw >= HW::XeHP) oword |= ((atype.alignment & 0x1F) != 0); |
3300 | |
3301 | extra = aoword; |
3302 | ebytes = oword ? 16 : 32; |
3303 | maxBBytes = oword ? 128 : 256; |
3304 | if (astrategy.base.getModel() == ModelSLM && hw >= HW::XeHP) |
3305 | maxBBytes = 256; |
3306 | addrShift = (!a64 && oword && !aoword) ? 4 : 0; |
3307 | maxSIMD = 16; |
3308 | } |
3309 | |
3310 | // Get normalized dimensions. |
3311 | bool colMajor = isColMajor(atype.layout); |
3312 | int x = colMajor ? r : c; |
3313 | auto crosspack = atype.crosspack; |
3314 | |
3315 | // Counters for current x and y positions. |
3316 | int cx = 0, cy = 0; |
3317 | |
3318 | while (nbytes > 0) { |
3319 | // Carve out the largest chunk possible. |
3320 | int bbytes = std::min<int>(maxBBytes, rounddown_pow2(nbytes)); |
3321 | int belems = bbytes / T; |
3322 | |
3323 | // Create a true load block for first (possibly partial) row/column. |
        // Then, create additional no-load blocks for any further (possibly partial)
        // rows/columns until the block is exhausted.
3326 | bool first = true; |
3327 | while (belems > 0) { |
3328 | int nxRem = belems / crosspack; |
3329 | int nx = std::min<int>(nxRem, x - cx); |
3330 | if (nx <= 0) stub(); |
3331 | if (cy % crosspack) return false; |
3332 | |
3333 | RegisterBlock block; |
3334 | |
3335 | block.ld = nx; |
3336 | (colMajor ? block.nr : block.nc) = nx; |
3337 | (colMajor ? block.nc : block.nr) = crosspack; |
3338 | (colMajor ? block.offsetR : block.offsetC) = cx; |
3339 | (colMajor ? block.offsetC : block.offsetR) = cy; |
3340 | block.component = 0; |
3341 | block.colMajor = colMajor; |
3342 | block.splitComplex = false; |
3343 | block.cxComponent = RegisterBlock::Interleaved; |
3344 | |
3345 | if (first) { |
3346 | block.ebytes = ebytes; |
3347 | block.count = div_up(bbytes, ebytes); |
3348 | block.simdSize = std::min(maxSIMD, roundup_pow2(bbytes) >> 2); |
3349 | } else |
3350 | block.ebytes = block.count = block.simdSize = 0; |
3351 | |
3352 | block.extra = extra; |
3353 | block.clearFlag(); |
3354 | block.colMask = MaskInfo::None(); |
3355 | block.rowMask = MaskInfo::None(); |
3356 | block.colFragment = 0; |
3357 | block.rowFragment = 0; |
3358 | block.log2GRFBytes = GRF::log2Bytes(hw); |
3359 | |
3360 | block.crosspack = crosspack; |
3361 | block.remainderR = false; |
3362 | block.remainderC = false; |
3363 | block.noRowsOK = false; |
3364 | block.noColsOK = false; |
3365 | block.descRemR = false; |
3366 | block.descRemC = false; |
3367 | block.descAssigned = false; |
3368 | block.addrShift = addrShift; |
3369 | block.hasNoLoad = false; |
3370 | block.msgRegs = std::max(1, bbytes >> GRF::log2Bytes(hw)); |
3371 | |
3372 | if (first && cx == 0 && (nxRem % x) == 0) { |
3373 | // Shortcut: one register block can represent this block access. |
3374 | int ny = belems / x; |
3375 | (colMajor ? block.nc : block.nr) = ny; |
3376 | cy += ny; |
3377 | belems = 0; |
3378 | } else { |
3379 | cx += nx; |
3380 | belems -= nx * crosspack; |
3381 | if (cx == x) { |
3382 | cy += crosspack; |
3383 | cx = 0; |
3384 | } |
3385 | block.hasNoLoad = first && (belems > 0); |
3386 | first = false; |
3387 | } |
3388 | |
3389 | layout.push_back(block); |
3390 | } |
3391 | |
3392 | nbytes -= bbytes; |
3393 | } |
3394 | |
3395 | return true; |
3396 | } |
3397 | |
3398 | static inline int getPartialCrosspack(size_t sizeofT, |
3399 | const MatrixAddressing &atype, const RegisterBlock &block) { |
3400 | if (block.ebytes == 1 && !isLargeCrosspack(sizeofT, atype.crosspack)) |
3401 | return div_up(atype.crosspack, block.colMajor ? block.nc : block.nr); |
3402 | else |
3403 | return 1; |
3404 | } |
3405 | |
3406 | // Get linear element offset in tiled layout (both register and memory) |
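// Worked example (values assumed): column-major 8x8 matrix, 4x4 tiles,
// crosspack 1. Element (i,j) = (5,2) lies in the tile starting at row 4, so
// one full 16-element tile precedes it; its offset within the tile is
// 1 + 2*4 = 9, giving untile() = 16 + 9 = 25.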
3407 | static int untile(Type T, const MatrixAddressing &atype, int component, int i, |
3408 | int j, int r, int c, int tileR, int tileC, bool reverse = false) { |
3409 | bool cm = isColMajor(atype.layout) ^ reverse; |
3410 | |
3411 | if (isPacked(atype.layout)) (cm ? r : c) = atype.packSize; |
3412 | |
3413 | int cpR = cm ? 1 : atype.crosspack; |
3414 | int cpC = cm ? atype.crosspack : 1; |
3415 | |
3416 | if (tileR == 0) tileR = r; |
3417 | if (tileC == 0) tileC = c; |
3418 | |
3419 | int rstride = cm ? tileC : c; |
3420 | int cstride = cm ? r : tileR; |
3421 | int rtstride = cm ? cpC : tileC; |
3422 | int ctstride = cm ? tileR : cpR; |
3423 | |
3424 | rstride *= T.components(); |
3425 | cstride *= T.components(); |
3426 | |
3427 | int iTile = i % tileR; |
3428 | int jTile = j % tileC; |
3429 | i -= iTile; |
3430 | j -= jTile; |
3431 | int iCP = iTile % cpR; |
3432 | int jCP = jTile % cpC; |
3433 | iTile -= iCP; |
3434 | jTile -= jCP; |
3435 | int idx = i * rstride + j * cstride + tileR * tileC * component |
3436 | + iTile * rtstride + jTile * ctstride + iCP + jCP; |
3437 | return idx; |
3438 | } |
3439 | |
3440 | static int untile(Type T, const MatrixAddressing &atype, |
3441 | const RegisterBlock &block, int r, int c, int tileR, int tileC, |
3442 | bool reverse = false) { |
3443 | return untile(T, atype, block.component, block.offsetR, block.offsetC, r, c, |
3444 | tileR, tileC, reverse); |
3445 | } |
3446 | |
3447 | static int untile(Type T, const MatrixAddressing &atype, |
3448 | const RegisterBlock &block, int r, int c, bool reverse = false) { |
3449 | return untile(T, atype, block, r, c, atype.tileR, atype.tileC, reverse); |
3450 | } |
3451 | |
3452 | static int untile(Type T, const MatrixAddressing &atype, int component, int i, |
3453 | int j, int r, int c, bool reverse = false) { |
3454 | return untile( |
3455 | T, atype, component, i, j, r, c, atype.tileR, atype.tileC, reverse); |
3456 | } |
3457 | |
3458 | // Split A/B matrix between threads. |
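// Illustrative example for CoopSplit::Linear (values assumed): splitting a
// 32x16 A matrix (tileR = 8, crosspack 1) among 4 threads gives
// selems = 512/4 = 128 elements per thread. A whole-k slab needs
// tileK * mn = 16*32 = 512 elements (too large), but tileK * tileMN = 128
// fits exactly, so each thread takes an 8x16 slab (splitR = 8, splitC = 16).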
3459 | static inline void coopSplit(bool isA, int &splitR, int &splitC, int r, int c, |
3460 | CoopSplit stype, int threads, const MatrixAddressing &atype) { |
3461 | auto &mn = isA ? r : c; |
3462 | auto &k = isA ? c : r; |
3463 | auto &splitMN = isA ? splitR : splitC; |
3464 | auto &splitK = isA ? splitC : splitR; |
3465 | auto tileMN = isA ? atype.tileR : atype.tileC; |
3466 | auto tileK = isA ? atype.tileC : atype.tileR; |
3467 | |
3468 | bool ok = false; |
3469 | |
3470 | switch (stype) { |
3471 | case CoopSplit::K: |
3472 | ok = (k % threads == 0); |
3473 | splitMN = mn; |
3474 | splitK = k / threads; |
3475 | break; |
3476 | case CoopSplit::MN: |
3477 | ok = (mn % threads == 0); |
3478 | splitMN = mn / threads; |
3479 | splitK = k; |
3480 | break; |
3481 | case CoopSplit::Linear: { |
3482 | int elems = r * c; |
3483 | ok = (elems % threads == 0); |
3484 | int selems = elems / threads; |
3485 | int cp = atype.crosspack; |
3486 | |
3487 | if (!tileK) tileK = k; |
3488 | if (!tileMN) tileMN = mn; |
3489 | |
3490 | // First try splitting into tiles in k dimension. |
3491 | if (selems >= (tileK * mn)) { |
3492 | ok &= (selems % (tileK * mn) == 0); |
3493 | splitMN = mn; |
3494 | splitK = k / threads; |
3495 | break; |
3496 | } |
3497 | |
3498 | ok &= (threads % (k / tileK) == 0); |
3499 | if (!ok) break; |
3500 | threads /= (k / tileK); |
3501 | |
3502 | // Then try splitting into tiles in m/n dimensions as well. |
3503 | if (selems >= (tileK * tileMN)) { |
3504 | ok &= (selems % (tileK * tileMN) == 0); |
3505 | splitMN = mn / threads; |
3506 | splitK = tileK; |
3507 | break; |
3508 | } |
3509 | |
3510 | ok &= (threads % (mn / tileMN) == 0); |
3511 | if (!ok) break; |
3512 | threads /= (mn / tileMN); |
3513 | |
3514 | // Then try splitting each tile in the k dimension. |
3515 | if (selems >= (cp * tileMN)) { |
3516 | ok &= (selems % (cp * tileMN) == 0); |
3517 | splitMN = tileMN; |
3518 | splitK = tileK / threads; |
3519 | break; |
3520 | } |
3521 | |
3522 | ok &= (threads % (tileK / cp) == 0); |
3523 | if (!ok) break; |
3524 | threads /= (tileK / cp); |
3525 | |
3526 | // Finally try splitting in the m/n dimensions. |
3527 | ok &= (selems % cp == 0); |
3528 | splitMN = tileMN / threads; |
3529 | splitK = cp; |
3530 | break; |
3531 | } |
3532 | } |
3533 | |
3534 | if (!ok) |
3535 | throw std::runtime_error( |
3536 | "Cooperative operation cannot be split evenly between " |
3537 | "threads." ); |
3538 | } |
3539 | |
3540 | // Re-order a layout so that registers appear in appropriate order |
3541 | // (row or column major) |
3542 | static void sortRegLayout(Type T, vector<RegisterBlock> &layout, int r, int c, |
3543 | const MatrixAddressing &atype, |
3544 | const MatrixAddressingStrategy &astrategy, bool reverse = false) { |
3545 | auto order = [=](const RegisterBlock &block) { |
3546 | return untile(T, atype, block, r, c, astrategy.tileR, astrategy.tileC, |
3547 | reverse); |
3548 | }; |
3549 | |
3550 | std::sort(layout.begin(), layout.end(), |
3551 | [&](const RegisterBlock &b1, const RegisterBlock &b2) { |
3552 | return (order(b1) < order(b2)); |
3553 | }); |
3554 | } |
3555 | |
3556 | static void finalizeLayout(HW hw, Type T, vector<RegisterBlock> &layout, |
3557 | const MatrixAddressing &atype, |
3558 | const MatrixAddressingStrategy &astrategy) { |
3559 | int offsetBytes = 0; |
3560 | for (auto &block : layout) { |
3561 | if (block.isLoadBlock() || isBlock2D(astrategy.accessType)) |
3562 | offsetBytes = alignup_pow2(offsetBytes, GRF::bytes(hw)); |
3563 | block.calcBytes(T, astrategy); |
3564 | block.offsetBytes = offsetBytes; |
3565 | offsetBytes += block.bytes; |
3566 | block.simplify(T); |
3567 | } |
3568 | } |
3569 | |
3570 | // Create a register layout for a matrix. |
3571 | template <HW hw> |
3572 | bool gemm_kernel_generator_t<hw>::getRegLayout(Type T, |
3573 | vector<RegisterBlock> &layout, int r, int c, bool remainderR, |
3574 | bool remainderC, bool writable, bool avoidFragment, int maxRBlock, |
3575 | int maxCBlock, const MatrixAddressing &atype, |
3576 | const MatrixAddressingStrategy &astrategy, bool reverseOrder) { |
3577 | bool success = false; |
3578 | |
3579 | layout.clear(); |
3580 | |
3581 | // Tiling handling. |
3582 | if (astrategy.tileR > 0) |
3583 | maxRBlock = (maxRBlock == 0) ? astrategy.tileR |
3584 | : gcd(int(astrategy.tileR), maxRBlock); |
3585 | if (astrategy.tileC > 0) |
3586 | maxCBlock = (maxCBlock == 0) ? astrategy.tileC |
                                     : gcd(int(astrategy.tileC), maxCBlock);
3588 | |
3589 | // Two separate strategies for creating register layout: |
3590 | // - standard 2D partitioning |
3591 | // - special 1D partitioning for block access to packed inputs. |
3592 | if (((atype.layout == MatrixLayout::Pc && atype.packSize == r) |
3593 | || (atype.layout == MatrixLayout::Pr && atype.packSize == c)) |
3594 | && (astrategy.accessType == AccessType::Block) && !remainderR |
3595 | && !remainderC && !atype.tileR && !atype.tileC |
3596 | && (maxRBlock >= r || maxRBlock == 0) |
3597 | && (maxCBlock >= c || maxCBlock == 0)) { |
3598 | success = add1DBlockToRegLayout( |
3599 | T, layout, r, c, writable, atype, astrategy); |
3600 | } |
3601 | if (!success) { |
3602 | success = addToRegLayout(T, layout, r, c, 0, 0, remainderR, remainderC, |
3603 | writable, avoidFragment, maxRBlock, maxCBlock, atype, |
3604 | astrategy); |
3605 | sortRegLayout(T, layout, r, c, atype, astrategy, reverseOrder); |
3606 | postprocessLayout(T, layout, atype, astrategy); |
3607 | } |
3608 | if (!success) return false; |
3609 | |
3610 | finalizeLayout(hw, T, layout, atype, astrategy); |
3611 | |
3612 | return true; |
3613 | } |
3614 | |
3615 | // Create a register layout for a uniform matrix not backed by memory. |
3616 | template <HW hw> |
3617 | void gemm_kernel_generator_t<hw>::makeUnbackedRegLayout(Type T, |
3618 | vector<RegisterBlock> &layout, int r, int c, bool colMajor, |
3619 | int crosspack, int tileR, int tileC, bool allowPartialRegs, |
3620 | bool fullySplitCx) { |
3621 | auto block = RegisterBlock(); |
3622 | |
3623 | if ((colMajor ? c : r) % crosspack) stub(); |
3624 | layout.clear(); |
3625 | |
3626 | if (tileR <= 0) tileR = r; |
3627 | if (tileC <= 0) tileC = c; |
3628 | |
3629 | int offsetBytes = 0; |
3630 | int qCXMin = -1, qCXMax = -1; |
3631 | |
3632 | for (int qCX = qCXMin; qCX <= qCXMax; qCX++) { |
3633 | for (int q = 0; q < T.components(); q++) { |
3634 | for (int i = 0; i < r; i += tileR) { |
3635 | for (int j = 0; j < c; j += tileC) { |
3636 | block.log2GRFBytes = GRF::log2Bytes(hw); |
3637 | block.nr = std::min(r - i, tileR); |
3638 | block.nc = std::min(c - j, tileC); |
3639 | block.ld = colMajor ? tileR : tileC; |
3640 | if (!allowPartialRegs) |
3641 | block.ld = align_up(block.ld, elementsPerGRF(hw, T)); |
3642 | block.offsetR = i; |
3643 | block.offsetC = j; |
3644 | block.colMajor = colMajor; |
3645 | block.crosspack = crosspack; |
3646 | block.offsetBytes = offsetBytes; |
3647 | block.splitComplex = false; |
3648 | block.cxComponent = qCX; |
3649 | block.component = q; |
3650 | |
3651 | block.calcBytes(T); |
3652 | offsetBytes += block.bytes; |
3653 | |
3654 | block.remainderR = false; |
3655 | block.remainderC = false; |
3656 | block.simdSize = 0; // Not backed by memory. |
3657 | |
3658 | layout.push_back(block); |
3659 | } |
3660 | } |
3661 | } |
3662 | } |
3663 | } |
3664 | |
3665 | // Attempt to create a 2D block layout that matches an existing layout. |
3666 | // Currently only generates regular/transpose 2D block (no VNNI support). |
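// Illustrative example (layout assumed): four consecutive single-GRF blocks in
// a column-major register layout, with offsetC = 0..3, matching nr, and
// offsetBytes advancing by one GRF each, merge into a single nr x 4 region
// that is re-emitted as one 2D block load.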
3667 | template <HW hw> |
3668 | bool gemm_kernel_generator_t<hw>::upgradeLayoutToBlock2D(Type T, |
3669 | const vector<RegisterBlock> &layoutSrc, vector<RegisterBlock> &layout2D, |
3670 | bool remainderR, bool remainderC, bool writable, |
3671 | const MatrixAddressing &atype, |
3672 | const MatrixAddressingStrategy &astrategy) { |
3673 | layout2D.clear(); |
3674 | layout2D.reserve(layoutSrc.size()); |
3675 | |
3676 | if (layoutSrc.empty()) return true; |
3677 | if (isPacked(atype.layout)) return false; |
3678 | |
3679 | bool transpose = isTransposing(astrategy.accessType); |
3680 | bool regCM = isLayoutColMajor(layoutSrc); |
3681 | |
3682 | if (transpose) { |
        if (T.size() == 8) {
            if (getStepping() < SteppingPVCXTB0) return false;
        } else if (T.size() != 4)
            return false;
3687 | } |
3688 | |
3689 | int r0 = -1, c0 = -1, b0 = -1; |
3690 | int nr = 0, nc = 0; |
3691 | bool ok = true; |
3692 | |
3693 | auto make2DBlock = [&] { |
3694 | if (r0 < 0 || c0 < 0) return; |
3695 | ok = ok |
3696 | && addToRegLayout(T, layout2D, nr, nc, r0, c0, remainderR, |
3697 | remainderC, writable, false, 0, 0, atype, astrategy); |
3698 | }; |
3699 | |
3700 | for (size_t i = 0; i < layoutSrc.size(); i++) { |
3701 | auto &block = layoutSrc[i]; |
3702 | unsigned omask = GRF::bytes(hw) - 1; |
3703 | |
3704 | if ((block.offsetBytes & omask) || (block.bytes & omask)) return false; |
3705 | if (block.nregs() > 1) |
3706 | return false; /* don't split block into multiple 2D block loads */ |
3707 | |
3708 | bool consecutive = (block.offsetBytes == (b0 + GRF::bytes(hw))); |
3709 | if (regCM && block.offsetC == c0 + nc && consecutive && nr == block.nr) |
3710 | nc++; |
3711 | else if (!regCM && block.offsetR == r0 + nr && consecutive |
3712 | && nc == block.nc) |
3713 | nr++; |
3714 | else { |
3715 | make2DBlock(); |
3716 | r0 = block.offsetR; |
3717 | c0 = block.offsetC; |
3718 | nr = block.nr; |
3719 | nc = block.nc; |
3720 | } |
3721 | b0 = block.offsetBytes; |
3722 | } |
3723 | |
3724 | make2DBlock(); |
3725 | |
3726 | int r = 0, c = 0; |
3727 | getLayoutDims(layoutSrc, r, c); |
3728 | sortRegLayout(T, layout2D, r, c, atype, astrategy); |
3729 | postprocessLayout(T, layout2D, atype, astrategy); |
3730 | finalizeLayout(hw, T, layout2D, atype, astrategy); |
3731 | |
3732 | return ok; |
3733 | } |
3734 | |
3735 | // Find the subregister in a RegisterBlock corresponding to element at offset (rr,cc), |
3736 | // as well as the contiguous elements following it (nelems). |
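// Worked example (values assumed): a column-major f32 block with nr = 8,
// nc = 4, crosspack = 1, ld = 8, offsetBytes = 0, on hardware with 64-byte
// GRFs (16 f32 elements each). Element (rr,cc) = (3,2) is at element offset
// 3 + 2*8 = 19, i.e. subregister 3 of the block's second GRF, with
// nelems = 8 - 3 = 5 contiguous elements remaining in its column.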
3737 | static Subregister findBlockReg(Type T, const RegisterBlock &block, int rr, |
3738 | int cc, const GRFMultirange ®s, int &nelems, int cxComponent = -1, |
3739 | int component = 0) { |
3740 | auto Te = T; |
3741 | const int ne = (1 << block.log2GRFBytes) / Te; |
3742 | |
3743 | if (rr < 0 || rr >= block.nr || cc < 0 || cc >= block.nc |
3744 | || component != block.component |
3745 | || !one_of(block.cxComponent, -1, cxComponent)) |
3746 | throw std::runtime_error("Requested out-of-bounds element." ); |
3747 | |
3748 | int crosspack = block.crosspack; |
3749 | int elFixed, elLD; |
3750 | if (block.colMajor) { |
3751 | int ccx = cc % crosspack; |
3752 | elFixed = ccx + (rr * crosspack); |
3753 | elLD = cc - ccx; |
3754 | nelems = block.nr - rr; |
3755 | } else { |
3756 | int rrx = rr % crosspack; |
3757 | elFixed = rrx + (cc * crosspack); |
3758 | elLD = (rr - rrx); |
3759 | nelems = block.nc - cc; |
3760 | } |
3761 | |
3762 | int el = elFixed + elLD * block.ld; |
3763 | el += block.offsetBytes / Te; |
3764 | int reg = el / ne; |
3765 | int subreg = el % ne; |
3766 | |
3767 | return regs[reg].sub(subreg, Te.ngen()); |
3768 | } |
3769 | |
3770 | // Find the subregister in a layout corresponding to element (r,c), as well as the |
3771 | // associated block, and the number of contiguous elements following it (nelems). |
3772 | static Subregister findBlockReg(Type T, const vector<RegisterBlock> &layout, |
3773 | int r, int c, const GRFMultirange ®s, int &nelems, |
3774 | const RegisterBlock *&block, int cxComponent = -1, int component = 0) { |
3775 | int ecomponent = component; |
3776 | for (auto &l : layout) { |
3777 | int rr = r - l.offsetR; |
3778 | int cc = c - l.offsetC; |
3779 | if (rr >= 0 && rr < l.nr && cc >= 0 && cc < l.nc |
3780 | && ecomponent == l.component |
3781 | && one_of(l.cxComponent, cxComponent, |
3782 | RegisterBlock::Interleaved)) { |
3783 | block = &l; |
3784 | return findBlockReg( |
3785 | T, l, rr, cc, regs, nelems, cxComponent, component); |
3786 | } |
3787 | } |
3788 | |
3789 | throw std::runtime_error( |
3790 | "Could not find requested matrix element in layout." ); |
3791 | } |
3792 | |
3793 | // Match the register offsets in one register layout to another, reference layout. |
3794 | // Returns true if successful. If not successful, the layout is unchanged. |
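// The reference layout is probed via findBlockReg over a dummy register range;
// the subregister returned there determines where each block of 'layout' must
// be placed (via offsetBytes) to overlay 'layoutRef' exactly.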
3795 | static bool matchLayouts(Type T, vector<RegisterBlock> &layout, |
3796 | const vector<RegisterBlock> &layoutRef) { |
3797 | vector<RegisterBlock> nlayout = layout; |
3798 | |
3799 | if (getRegCount(layoutRef) >= 256) return false; |
3800 | |
3801 | for (auto &nblock : nlayout) { |
3802 | int nelems; |
3803 | const RegisterBlock *blockRef; |
3804 | auto sr = findBlockReg(T, layoutRef, nblock.offsetR, nblock.offsetC, |
3805 | GRFRange(0, 254), nelems, blockRef); |
3806 | |
3807 | // Check: |
3808 | // 1. Does this register block's offset match the reference block's offset? |
3809 | if (sr.getByteOffset() |
3810 | != (nblock.offsetBytes & ((1 << nblock.log2GRFBytes) - 1))) |
3811 | return false; |
3812 | |
3813 | // 2. Is there any free space in the register block? |
3814 | if (nblock.nr * nblock.nc * T != nblock.bytes) return false; |
3815 | |
3816 | // 3. Does this register block's data layout match the reference block's layout? |
3817 | if (blockRef->colMajor != nblock.colMajor) return false; |
3818 | if (blockRef->crosspack != nblock.crosspack) return false; |
3819 | |
3820 | // 4. Does this register block fit inside the reference block? |
3821 | auto RegisterBlock::*nx |
3822 | = nblock.colMajor ? &RegisterBlock::nr : &RegisterBlock::nc; |
3823 | auto RegisterBlock::*ny |
3824 | = nblock.colMajor ? &RegisterBlock::nc : &RegisterBlock::nr; |
3825 | |
3826 | if (nblock.*nx < blockRef->*nx) { |
3827 | if (nblock.*ny > 1) return false; |
3828 | } else if (nblock.*nx == blockRef->*nx) { |
3829 | if (nblock.*ny > blockRef->*ny) return false; |
3830 | } else |
3831 | return false; |
3832 | |
3833 | if (nblock.*ny > 1 && (nblock.ld != blockRef->ld)) return false; |
3834 | |
3835 | // It's compatible. Point this register block where it belongs. |
3836 | nblock.offsetBytes |
3837 | = (sr.getBase() << nblock.log2GRFBytes) + sr.getByteOffset(); |
3838 | } |
3839 | |
3840 | std::swap(nlayout, layout); |
3841 | return true; |
3842 | } |
3843 | |
3844 | // Like matchLayouts but allows either layout to change to match the other. |
3845 | static bool matchLayoutsBidirectional(Type T, vector<RegisterBlock> &layout1, |
3846 | vector<RegisterBlock> &layout2) { |
3847 | return matchLayouts(T, layout1, layout2) |
3848 | || matchLayouts(T, layout2, layout1); |
3849 | } |
3850 | |
3851 | static bool allocateTokens(const vector<RegisterBlock> &layout, |
3852 | const GRFMultirange ®s, CommonState &state, |
3853 | const vector<GRFRange> &addrs = vector<GRFRange>()) { |
3854 | bool success = true; |
3855 | size_t origSize = state.tokenMap.size(); |
3856 | auto saveTA = state.tokenAllocator; |
3857 | |
3858 | for (size_t l = 0; l < layout.size(); l++) { |
3859 | auto token = state.tokenAllocator.tryAlloc(); |
3860 | if (token < 0) |
3861 | success = false; |
3862 | else { |
3863 | auto regKey = !regs.empty() ? regs[layout[l].offsetReg()].getBase() |
3864 | : addrs[l].getBase(); |
3865 | state.tokenMap.push_back(std::make_pair(regKey, token)); |
3866 | } |
3867 | } |
3868 | |
3869 | if (!success) { |
3870 | state.tokenAllocator = saveTA; |
3871 | state.tokenMap.resize(origSize); |
3872 | } |
3873 | |
3874 | return success; |
3875 | } |
3876 | |
3877 | static void clearTokenAllocations(HW hw, CommonState &state) { |
3878 | state.tokenMap.clear(); |
3879 | state.tokenAllocator = TokenAllocator(hw); |
3880 | } |
3881 | |
3882 | template <HW hw> |
3883 | void gemm_kernel_generator_t<hw>::setupTeardownLoadStoreDesc(bool setup, |
3884 | const vector<RegisterBlock> &layout, const CommonStrategy &strategy, |
3885 | CommonState &state) { |
3886 | if (strategy.emulate.emulateDWxDW) { |
3887 | auto nconstants = (hw >= HW::XeHPG) ? 3 : 2; |
3888 | |
3889 | if (setup) |
3890 | for (int s = 0; s < nconstants; s++) { |
3891 | state.lsDescConstant[s] = state.ra.alloc_sub<uint32_t>(); |
3892 | mov(1, state.lsDescConstant[s], uint32_t(0x00100040 << s)); |
3893 | } |
3894 | else |
3895 | for (int s = 0; s < nconstants; s++) |
3896 | state.ra.safeRelease(state.lsDescConstant[s]); |
3897 | } |
3898 | } |
3899 | |
3900 | // Output code for loading address register(s) with load/store message descriptors for remainders. |
3901 | template <HW hw> |
3902 | void gemm_kernel_generator_t<hw>::loadLoadStoreDescriptors(bool load, |
3903 | bool store, RegisterBlock &block, const Subregister &count, |
3904 | const MatrixAddressing &atype, |
3905 | const MatrixAddressingStrategy &astrategy, |
3906 | const CommonStrategy &strategy, CommonState &state) { |
3907 | MessageDescriptor descLoad; // a0.0:ud |
3908 | MessageDescriptor descStore; // a0.2 (a0.0 if no loads) |
3909 | ExtendedMessageDescriptor exdescLoad; |
3910 | ExtendedMessageDescriptor exdescStore; // a0.1 |
3911 | |
3912 | Subregister t1 = state.ra.alloc_sub<uint32_t>(); |
3913 | Subregister t2 = state.ra.alloc_sub<uint32_t>(); |
3914 | |
3915 | if (astrategy.newDP) switch (astrategy.accessType) { |
3916 | case AccessType::ChannelScattered: |
3917 | case AccessType::Scattered: { |
3918 | bool channel = (astrategy.accessType |
3919 | == AccessType::ChannelScattered); |
3920 | |
3921 | encodeLoadDescriptors(hw, descLoad, exdescLoad, block.simdSize, |
3922 | r0, getDataSpecLSC(atype, astrategy, block, false), |
3923 | astrategy.base, null); |
3924 | encodeStoreDescriptors(hw, descStore, exdescStore, |
3925 | block.simdSize, |
3926 | getDataSpecLSC(atype, astrategy, block, true), |
3927 | astrategy.base, null); |
3928 | descLoad.cmask.cmask = 0; // also vectSize |
3929 | descStore.cmask.cmask = 0; |
3930 | exdescStore.parts.extMessageLen = 0; |
3931 | descLoad.parts.responseLen = 0; |
3932 | |
3933 | int underlyingSIMD = std::max<int>( |
3934 | block.simdSize, maxScatteredSIMD(hw, astrategy) >> 1); |
3935 | int log2GRFs = log2(underlyingSIMD * block.ebytes) |
3936 | - GRF::log2Bytes(hw); |
3937 | int log2Components = int(block.splitComplex); |
3938 | |
3939 | if (channel) mov(1, t2, 0x1000 << log2Components); |
3940 | mul(1, t1, state.lsDescConstant[log2GRFs + log2Components], |
3941 | count.uw()); |
3942 | channel ? shl(1, t2, t2, count) |
3943 | : shl(1, t2, count, 12 + log2Components); |
3944 | if (store) or_(1, a0.ud(1), t1.uw(0), exdescStore.all); |
3945 | add(1, t1.uw(0), t2, -0x1000); |
3946 | if (load) or_(1, a0.ud(0), t1, descLoad.all); |
3947 | if (store) or_(1, a0.ud(load ? 2 : 0), t1.uw(0), descStore.all); |
3948 | break; |
3949 | } |
3950 | default: hw_unsupported(); |
3951 | } |
3952 | else |
3953 | switch (astrategy.accessType) { |
3954 | case AccessType::ChannelScattered: { |
3955 | encodeLoadDescriptors(hw, descLoad, exdescLoad, block.simdSize, |
3956 | r0, surface_dword(ChannelMask::rgba), astrategy.base, |
3957 | null); |
3958 | encodeStoreDescriptors(hw, descStore, exdescStore, |
3959 | block.simdSize, surface_dword(ChannelMask::rgba), |
3960 | astrategy.base, null); |
3961 | descLoad.surface.cmask = 0; // |
3962 | descStore.surface.cmask = 0; // Fields to fill in. |
3963 | exdescStore.parts.extMessageLen = 0; // |
3964 | descLoad.parts.responseLen = 0; |
3965 | |
3966 | int log2Components = int(block.splitComplex); |
3967 | int shift = int(block.simdSize == 16) + log2Components; |
3968 | auto bitmask = uint16_t(0x0F00 << log2Components); |
3969 | |
3970 | if (strategy.emulate.emulateDWxDW) |
3971 | mul(1, t1, state.lsDescConstant[shift], count.uw()); |
3972 | else |
3973 | mul(1, t1, count, uint32_t(0x00100040) << shift); |
3974 | mov(1, t2, bitmask); |
3975 | if (store) or_(1, a0.ud(1), t1.uw(0), exdescStore.all); |
3976 | shl(1, t2, t2, count); |
3977 | and_(1, t1.uw(0), t2, bitmask); |
3978 | if (load) or_(1, a0.ud(0), t1, descLoad.all); |
3979 | if (store) or_(1, a0.ud(load ? 2 : 0), t1.uw(0), descStore.all); |
3980 | break; |
3981 | } |
3982 | default: hw_unsupported(); |
3983 | } |
3984 | |
3985 | state.ra.safeRelease(t1); |
3986 | state.ra.safeRelease(t2); |
3987 | block.sfid = exdescLoad.all; |
3988 | } |
3989 | |
3990 | template <HW hw> |
3991 | InstructionModifier gemm_kernel_generator_t<hw>::getRegisterBlockMask( |
3992 | const RegisterBlock &block, CommonState &state) { |
3993 | InstructionModifier result; |
3994 | |
3995 | if (block.flag) { |
3996 | result |= getPhysicalFlag(block.flag, state); |
3997 | if (hw == HW::XeHPC) { |
3998 | if (block.flagAll) result |= all; |
3999 | if (block.flagAny) result |= any; |
4000 | } else if (block.flagAll) |
4001 | result |= (block.simdSize > 8) ? all16h : all8h; |
4002 | else if (block.flagAny) |
4003 | result |= (block.simdSize > 8) ? any16h : any8h; |
4004 | } |
4005 | |
4006 | return result; |
4007 | } |
4008 | |
4009 | // Check if a block occupies a contiguous portion of registers in the given GRFMultirange. |
4010 | // If so, return index of the block's first register in the range. |
4011 | static inline int contiguityCheck( |
4012 | HW hw, const RegisterBlock &block, const GRFMultirange &range) { |
4013 | auto offsetBytes = block.offsetBytes; |
4014 | if (offsetBytes & (GRF::bytes(hw) - 1)) |
4015 | if (block.isLoadBlock()) stub(); |
4016 | auto offsetReg = offsetBytes >> GRF::log2Bytes(hw); |
4017 | auto lastReg = GRF::bytesToGRFs(hw, offsetBytes + block.bytes); |
4018 | if (!range.contiguous(offsetReg, lastReg - offsetReg)) stub(); |
4019 | |
4020 | return offsetReg; |
4021 | } |
4022 | |
4023 | static DataSizeLSC getDataSizeLSC(int ebytes, bool pad32) { |
4024 | switch (ebytes) { |
4025 | case 8: return DataSizeLSC::D64; |
4026 | case 4: return DataSizeLSC::D32; |
4027 | case 2: return pad32 ? DataSizeLSC::D16U32 : DataSizeLSC::D16; |
4028 | case 1: return pad32 ? DataSizeLSC::D8U32 : DataSizeLSC::D8; |
4029 | } |
4030 | throw std::runtime_error("Invalid data size" ); |
4031 | } |
4032 | |
4033 | template <HW hw> |
4034 | DataSpecLSC gemm_kernel_generator_t<hw>::getDataSpecLSC( |
4035 | AccessType access, const RegisterBlock &block) { |
4036 | switch (access) { |
4037 | case AccessType::ChannelScattered: { |
4038 | static const ChannelMask cmasks[4] = {ChannelMask::r, |
4039 | ChannelMask::rg, ChannelMask::rgb, ChannelMask::rgba}; |
4040 | if (block.ebytes != 4) hw_unsupported(); |
4041 | return D32 | cmasks[block.count - 1]; |
4042 | } |
4043 | case AccessType::Scattered: |
4044 | if (block.ebytes == 8) return D64(block.count); |
4045 | if (block.ebytes == 4) return D32(block.count); |
4046 | if (block.ebytes == 1) return getDataSizeLSC(block.count, true); |
4047 | hw_unsupported(); |
4048 | case AccessType::Block: |
4049 | if (block.ebytes == 8) return D64T(block.count); |
4050 | if (block.ebytes == 4) return D32T(block.count); |
4051 | hw_unsupported(); |
4052 | default: stub(); |
4053 | } |
4054 | } |
4055 | |
4056 | template <HW hw> |
4057 | DataSpecLSC gemm_kernel_generator_t<hw>::getDataSpecLSC( |
4058 | const MatrixAddressing &atype, |
4059 | const MatrixAddressingStrategy &astrategy, const RegisterBlock &block, |
4060 | bool write) { |
4061 | return getDataSpecLSC(implAccessType(atype, astrategy, block), block) |
4062 | | (write ? astrategy.cachingW : astrategy.cachingR); |
4063 | } |
4064 | |
4065 | // Output code for prefetching a matrix chunk (XeHPG+). |
4066 | template <HW hw> |
4067 | void gemm_kernel_generator_t<hw>::prefetchMatrix( |
4068 | const vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
4069 | const MatrixAddressingStrategy &astrategy, |
4070 | const vector<GRFRange> &addrs, const CommonStrategy &strategy, |
4071 | CommonState &state) { |
4072 | auto nblocks = int(layout.size()); |
4073 | |
4074 | for (int l = 0; l < nblocks; l++) |
4075 | loadMatrixBlock(null, layout[l], atype, astrategy, addrs[l], strategy, |
4076 | state, false); |
4077 | } |
4078 | |
4079 | // Output code for loading a matrix chunk into registers. |
4080 | template <HW hw> |
4081 | void gemm_kernel_generator_t<hw>::loadMatrix(const GRFMultirange &dest, |
4082 | const vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
4083 | const MatrixAddressingStrategy &astrategy, |
4084 | const vector<GRFRange> &addrs, const CommonStrategy &strategy, |
4085 | CommonState &state, bool zeroMask) { |
4086 | auto nblocks = int(layout.size()); |
4087 | |
4088 | if (astrategy.prefetch && astrategy.newDP) { |
4089 | prefetchMatrix(layout, atype, astrategy, addrs, strategy, state); |
4090 | return; |
4091 | } |
4092 | |
4093 | if (strategy.readSuppressionWA && (hasFlags(layout) || !getDefaultNoMask())) |
4094 | doReadSuppressionWA(strategy, state); |
4095 | |
4096 | for (int l = 0; l < nblocks; l++) { |
4097 | auto offsetReg = contiguityCheck(hw, layout[l], dest); |
4098 | loadMatrixBlock(dest[offsetReg], layout[l], atype, astrategy, addrs[l], |
4099 | strategy, state, zeroMask); |
4100 | } |
4101 | } |
4102 | |
4103 | // Output code for loading a single matrix block into registers. |
4104 | template <HW hw> |
4105 | void gemm_kernel_generator_t<hw>::loadMatrixBlock(const Register &dest, |
4106 | const RegisterBlock &block, const MatrixAddressing &atype, |
4107 | const MatrixAddressingStrategy &astrategy, const GRFRange &addr, |
4108 | const CommonStrategy &strategy, CommonState &state, bool zeroMask) { |
4109 | InstructionModifier maskMod; |
4110 | InstructionModifier mod = block.simdSize; |
4111 | |
4112 | // Zero SIMD size blocks are filled as part of another load. Skip them. |
4113 | if (!block.isLoadBlock()) return; |
4114 | |
4115 | // Get mask to apply, if any. |
4116 | auto mask = getRegisterBlockMask(block, state); |
4117 | maskMod |= mask; |
4118 | mod |= mask; |
4119 | |
4120 | // Look up preassigned token. |
4121 | for (auto &entry : state.tokenMap) { |
4122 | if (entry.first == dest.getBase() || entry.first == addr.getBase()) { |
4123 | mod |= SBID(entry.second); |
4124 | break; |
4125 | } |
4126 | } |
4127 | |
4128 | if (astrategy.newDP) switch (implAccessType(atype, astrategy, block)) { |
4129 | case AccessType::Block: |
4130 | case AccessType::Scattered: |
4131 | case AccessType::ChannelScattered: { |
4132 | auto spec = getDataSpecLSC(atype, astrategy, block, false); |
4133 | if (block.descAssigned) { |
4134 | MessageDescriptor desc; |
4135 | ExtendedMessageDescriptor exdesc; |
4136 | encodeLoadDescriptors(hw, desc, exdesc, block.simdSize, r0, |
4137 | spec, astrategy.base, null); |
4138 | send(mod, static_cast<SharedFunction>(block.sfid), dest, |
4139 | addr, null, exdesc.all, a0[0]); |
4140 | } else { |
4141 | load(mod, dest, spec, astrategy.base, addr[0]); |
4142 | } |
4143 | break; |
4144 | } |
4145 | case AccessType::Block2D: |
4146 | case AccessType::Block2DTranspose: |
4147 | case AccessType::Block2DVNNI: { |
4148 | int w = 0, h = 0; |
4149 | getBlock2DWH(w, h, atype, block); |
4150 | auto spec = block_2d(getDataSizeLSC(block.ebytes, false), w, h, |
4151 | block.count) |
4152 | | astrategy.cachingR; |
4153 | if (astrategy.accessType == AccessType::Block2DTranspose) |
4154 | spec |= transpose; |
4155 | if (astrategy.accessType == AccessType::Block2DVNNI) |
4156 | spec |= vnni; |
4157 | load(mod, dest, spec, astrategy.base, addr); |
4158 | break; |
4159 | } |
4160 | default: stub(); |
4161 | } |
4162 | else if (block.descAssigned) |
4163 | send(mod, static_cast<SharedFunction>(block.sfid), dest, addr, null, |
4164 | block.sfid, a0[0]); |
4165 | else |
4166 | switch (implAccessType(atype, astrategy, block)) { |
4167 | case AccessType::ChannelScattered: { |
4168 | static const ChannelMask cmasks[4] = {ChannelMask::r, |
4169 | ChannelMask::rg, ChannelMask::rgb, ChannelMask::rgba}; |
4170 | if (block.ebytes != 4) stub(); |
4171 | load(mod, dest, surface_dword(cmasks[block.count - 1]), |
4172 | astrategy.base, addr); |
4173 | break; |
4174 | } |
4175 | case AccessType::Scattered: |
4176 | if (block.ebytes == 8) |
4177 | load(mod, dest, scattered_qword(block.count), |
4178 | astrategy.base, addr); |
4179 | else if (block.ebytes == 4) |
4180 | load(mod, dest, scattered_dword(block.count), |
4181 | astrategy.base, addr); |
4182 | else if (block.ebytes == 1) |
4183 | load(mod, dest, scattered_byte(block.count), astrategy.base, |
4184 | addr); |
4185 | else |
4186 | hw_unsupported(); |
4187 | break; |
4188 | case AccessType::Block: |
4189 | if (block.ebytes == 32) |
4190 | load(mod, dest, block_hword(block.count), astrategy.base, |
4191 | addr); |
4192 | else if (block.ebytes == 16 && !block.extra) |
4193 | load(mod, dest, block_oword(block.count), astrategy.base, |
4194 | addr); |
4195 | else if (block.ebytes == 16) |
4196 | load(mod, dest, aligned_block_oword(block.count), |
4197 | astrategy.base, addr); |
4198 | else |
4199 | hw_unsupported(); |
4200 | if (zeroMask && (astrategy.base.getModel() == ModelBTS)) { |
4201 | if (block.flag) |
4202 | mov<uint32_t>(block.simdSize | ~maskMod, dest, 0); |
4203 | if (block.simdSize <= 2) mov<uint32_t>(2, dest[2](1), 0); |
4204 | if (block.simdSize <= 1) mov<uint32_t>(1, dest[1], 0); |
4205 | } |
4206 | break; |
4207 | default: stub(); |
4208 | } |
4209 | } |
4210 | |
4211 | // Output code for storing a matrix chunk from registers. |
4212 | template <HW hw> |
4213 | void gemm_kernel_generator_t<hw>::storeMatrix(const GRFMultirange &src, |
4214 | const vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
4215 | const MatrixAddressingStrategy &astrategy, |
4216 | const vector<GRFRange> &addrs, const CommonStrategy &strategy, |
4217 | CommonState &state) { |
4218 | auto nblocks = int(layout.size()); |
4219 | |
4220 | for (int l = 0; l < nblocks; l++) { |
4221 | auto offsetReg = contiguityCheck(hw, layout[l], src); |
4222 | storeMatrixBlock(src[offsetReg], layout[l], atype, astrategy, addrs[l], |
4223 | strategy, state); |
4224 | } |
4225 | } |
4226 | |
4227 | // Output code for storing a matrix block from registers. |
4228 | template <HW hw> |
4229 | void gemm_kernel_generator_t<hw>::storeMatrixBlock(const GRF &src, |
4230 | const RegisterBlock &block, const MatrixAddressing &atype, |
4231 | const MatrixAddressingStrategy &astrategy, const GRFRange &addr, |
4232 | const CommonStrategy &strategy, CommonState &state) { |
    InstructionModifier mod = block.simdSize;
4235 | |
4236 | // Zero SIMD size blocks are filled as part of another store. Skip them. |
4237 | if (!block.isLoadBlock()) return; |
4238 | |
4239 | // Get mask to apply, if any. |
4240 | mod |= getRegisterBlockMask(block, state); |
4241 | |
4242 | // Look up preassigned token. |
4243 | for (auto &entry : state.tokenMap) { |
4244 | if (entry.first == src.getBase()) { |
4245 | mod |= SBID(entry.second); |
4246 | break; |
4247 | } |
4248 | } |
4249 | |
4250 | if (block.descAssigned) |
4251 | send(mod, static_cast<SharedFunction>(block.sfid), null, addr, src, |
4252 | a0.ud(1), a0.ud(0)); |
4253 | else if (astrategy.newDP) |
4254 | switch (implAccessType(atype, astrategy, block)) { |
4255 | case AccessType::Block: |
4256 | case AccessType::Scattered: |
4257 | case AccessType::ChannelScattered: { |
4258 | auto spec = getDataSpecLSC(atype, astrategy, block, true); |
4259 | store(mod, spec, astrategy.base, addr[0], src); |
4260 | break; |
4261 | } |
4262 | case AccessType::Block2D: |
4263 | case AccessType::Block2DTranspose: |
4264 | case AccessType::Block2DVNNI: { |
4265 | int w = 0, h = 0; |
4266 | getBlock2DWH(w, h, atype, block); |
4267 | auto spec = block_2d(getDataSizeLSC(block.ebytes, false), w, h, |
4268 | block.count) |
4269 | | astrategy.cachingW; |
4270 | if (astrategy.accessType == AccessType::Block2DTranspose) |
4271 | spec |= transpose; |
4272 | if (astrategy.accessType == AccessType::Block2DVNNI) |
4273 | spec |= vnni; |
4274 | store(mod, spec, astrategy.base, addr, src); |
4275 | break; |
4276 | } |
4277 | default: stub(); |
4278 | } |
4279 | else |
4280 | switch (implAccessType(atype, astrategy, block)) { |
4281 | case AccessType::ChannelScattered: { |
4282 | static const ChannelMask cmasks[4] = {ChannelMask::r, |
4283 | ChannelMask::rg, ChannelMask::rgb, ChannelMask::rgba}; |
4284 | if (block.ebytes != 4) stub(); |
4285 | store(mod, surface_dword(cmasks[block.count - 1]), |
4286 | astrategy.base, addr, src); |
4287 | break; |
4288 | } |
4289 | case AccessType::Scattered: |
4290 | if (block.ebytes == 8) |
4291 | store(mod, scattered_qword(block.count), astrategy.base, |
4292 | addr, src); |
4293 | else if (block.ebytes == 4) |
4294 | store(mod, scattered_dword(block.count), astrategy.base, |
4295 | addr, src); |
4296 | else if (block.ebytes == 1) |
4297 | store(mod, scattered_byte(block.count), astrategy.base, |
4298 | addr, src); |
4299 | else |
4300 | hw_unsupported(); |
4301 | break; |
4302 | case AccessType::Block: |
4303 | if (block.ebytes == 32) |
4304 | store(mod, block_hword(block.count), astrategy.base, addr, |
4305 | src); |
4306 | else if (block.ebytes == 16 && !block.extra) |
4307 | store(mod, block_oword(block.count), astrategy.base, addr, |
4308 | src); |
4309 | else |
4310 | hw_unsupported(); |
4311 | break; |
4312 | default: stub(); |
4313 | } |
4314 | } |
4315 | |
4316 | // Atomic addition of a matrix in registers. |
4317 | template <HW hw> |
4318 | void gemm_kernel_generator_t<hw>::atomicAddMatrix(Type T, |
4319 | const GRFMultirange &src, const vector<RegisterBlock> &layout, |
4320 | const MatrixAddressing &atype, |
4321 | const MatrixAddressingStrategy &astrategy, |
4322 | const vector<GRFRange> &addrs, const CommonProblem &problem, |
4323 | const CommonStrategy &strategy, CommonState &state) { |
4324 | auto nblocks = int(layout.size()); |
4325 | |
4326 | if (strategy.readSuppressionWA && (hasFlags(layout) || !getDefaultNoMask())) |
4327 | doReadSuppressionWA(strategy, state); |
4328 | |
4329 | for (int l = 0; l < nblocks; l++) { |
4330 | auto offsetReg = contiguityCheck(hw, layout[l], src); |
4331 | atomicAddMatrixBlock(T, src[offsetReg], layout[l], atype, astrategy, |
4332 | addrs[l], problem, strategy, state); |
4333 | } |
4334 | } |
4335 | |
4336 | template <HW hw> |
4337 | void gemm_kernel_generator_t<hw>::atomicAddMatrixBlock(Type T, const GRF &src, |
4338 | const RegisterBlock &block, const MatrixAddressing &atype, |
4339 | const MatrixAddressingStrategy &astrategy, const GRFRange &addr, |
4340 | const CommonProblem &problem, const CommonStrategy &strategy, |
4341 | CommonState &state) { |
4342 | InstructionModifier maskMod; |
4343 | |
4344 | if (!block.isLoadBlock()) return; |
4345 | if (block.descAssigned) stub(); |
4346 | |
4347 | maskMod |= getRegisterBlockMask(block, state); |
4348 | |
4349 | // SIMD16 A64 atomics are emulated with 2x SIMD8. |
4350 | bool a64 = (astrategy.base.getModel() == ModelA64); |
4351 | int hsize = a64 ? 2 : 1; |
4352 | int simd = block.simdSize; |
4353 | if (!astrategy.newDP && a64) simd = std::min(simd, 8); |
4354 | if (hw >= HW::XeHPC && block.ebytes < 8 && block.simdSize == 16 |
4355 | && simd == 8) |
4356 | stub(); // Can't split data GRFs. |
4357 | auto nreg = block.nregs(); |
4358 | auto nregReal = (nreg * simd) / block.simdSize; |
4359 | |
4360 | auto specLSC = D32; |
4361 | if (astrategy.newDP) |
4362 | specLSC = getDataSpecLSC(atype, astrategy, block, true); |
4363 | |
4364 | switch (implAccessType(atype, astrategy, block)) { |
4365 | case AccessType::Scattered: |
4366 | case AccessType::ChannelScattered: |
4367 | if (hasNativeAtomicAdd(hw, T.real(), atype, astrategy)) { |
4368 | auto curSrc = src; |
4369 | for (int eoff = 0, hoff = 0; eoff < block.simdSize; |
4370 | eoff += simd, hoff += hsize, curSrc += nregReal) { |
4371 | auto mod = simd | maskMod | ExecutionOffset(eoff); |
4372 | if (block.ebytes != T.real().size()) stub(); |
4373 | if (astrategy.newDP) |
4374 | atomic(T.isFP() ? AtomicOp::fadd : AtomicOp::add, mod, |
4375 | specLSC, astrategy.base, addr[hoff], curSrc); |
4376 | else |
4377 | switch (T.real()) { |
4378 | case Type::f32: |
4379 | atomic(AtomicOp::fadd, mod, scattered_dword(), |
4380 | astrategy.base, addr[hoff], curSrc); |
4381 | break; |
4382 | case Type::u64: |
4383 | case Type::s64: |
4384 | atomic(AtomicOp::add, mod, scattered_qword(), |
4385 | astrategy.base, addr[hoff], curSrc); |
4386 | break; |
4387 | case Type::u32: |
4388 | case Type::s32: |
4389 | atomic(AtomicOp::add, mod, scattered_dword(), |
4390 | astrategy.base, addr[hoff], curSrc); |
4391 | break; |
4392 | case Type::u16: |
4393 | case Type::s16: |
4394 | if (hw < HW::Gen12LP) hw_unsupported(); |
4395 | atomic(AtomicOp::add, mod, scattered_word(), |
4396 | astrategy.base, addr[hoff], curSrc); |
4397 | break; |
4398 | default: stub(); |
4399 | } |
4400 | } |
4401 | } else { |
4402 | // Emulated atomic addition with a compare-and-swap loop. |
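                // Sketch of the loop below (pseudocode, names assumed):
                //   old = load(addr);
                //   do {
                //       new  = old + src;   // desired value
                //       save = old;         // expected value
                //       old  = cmpwr(addr, expected = save, desired = new);
                //   } while (old != save on any active channel);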
4403 | auto rOldNew = state.eatomicAddRegs[0]; |
4404 | auto rSave = state.eatomicAddRegs[1]; |
4405 | auto rOld = rOldNew[0]; |
4406 | auto rNew = rOldNew[nregReal]; |
4407 | auto flagToDo = getPhysicalFlag(state.vflagEAtomicAdd, state); |
4408 | |
4409 | if (block.simdSize > 16) stub(); // Need 32 channels. |
4410 | if (astrategy.newDP) |
4411 | load(block.simdSize | maskMod, rOld, specLSC, |
4412 | astrategy.base, addr[0]); |
4413 | else if (astrategy.base.getModel() == ModelA64) { |
4414 | if (block.ebytes == 2) |
4415 | load(block.simdSize | maskMod, rOld, scattered_byte(2), |
4416 | astrategy.base, addr); |
4417 | else if (block.ebytes == 4) |
4418 | load(block.simdSize | maskMod, rOld, scattered_dword(), |
4419 | astrategy.base, addr); |
4420 | else if (block.ebytes == 8) |
4421 | load(block.simdSize | maskMod, rOld, scattered_qword(), |
4422 | astrategy.base, addr); |
4423 | } else { |
4424 | if (block.ebytes == 2) |
4425 | load(block.simdSize | maskMod, rOld, scattered_byte(2), |
4426 | astrategy.base, addr); |
4427 | else if (block.ebytes == 4) |
4428 | load(block.simdSize | maskMod, rOld, |
4429 | surface_dword(ChannelMask::r), astrategy.base, |
4430 | addr); |
4431 | else if (block.ebytes == 8) |
4432 | stub(); // needs cmpwr2 |
4433 | } |
4434 | Label labelMask; |
4435 | |
4436 | // Save off high half of data when emulating SIMD16. |
4437 | if (block.simdSize > simd) |
4438 | mov<uint32_t>(nregReal * 8, rOld.advance(nreg), |
4439 | rOld.advance(nregReal)); |
4440 | |
4441 | if (block.flag) { |
4442 | if_(16 | getPhysicalFlag(block.flag, state), labelMask); |
4443 | setDefaultNoMask(false); |
4444 | } |
4445 | |
4446 | and_(1 | NoMask, flagToDo, ce0, |
4447 | uint16_t((1 << block.simdSize) - 1)); |
4448 | |
4449 | auto curSrc = src; |
4450 | |
4451 | for (int eoff = 0, hoff = 0; eoff < block.simdSize; |
4452 | eoff += simd, hoff += hsize) { |
4453 | auto eoMod = ExecutionOffset(eoff); |
4454 | |
4455 | Label labelCmpXchgLoop; |
4456 | mark(labelCmpXchgLoop); |
4457 | |
4458 | auto dt = T.ngen(); |
4459 | add(int(simd * block.ebytes / T.real()) | eoMod | NoMask, |
4460 | rNew.retype(dt), rOld.retype(dt), |
4461 | curSrc.retype(dt)); |
4462 | mov<uint32_t>((simd * block.ebytes / 4) | eoMod | NoMask, |
4463 | rSave, rOld); |
4464 | |
4465 | auto atomicMod = simd | flagToDo | eoMod; |
4466 | auto cmpMod = simd | flagToDo | ne | flagToDo | eoMod; |
4467 | |
4468 | if (astrategy.newDP) |
4469 | atomic(AtomicOp::cmpwr, atomicMod, rOld, specLSC, |
4470 | astrategy.base, addr[hoff], rOld); |
4471 | else |
4472 | switch (block.ebytes) { |
4473 | case 2: |
4474 | if (hw < HW::Gen12LP) hw_unsupported(); |
4475 | atomic(AtomicOp::cmpwr, atomicMod, rOld, |
4476 | scattered_word(), astrategy.base, |
4477 | addr[hoff], rOld); |
4478 | break; |
4479 | case 4: |
4480 | atomic(AtomicOp::cmpwr, atomicMod, rOld, |
4481 | scattered_dword(), astrategy.base, |
4482 | addr[hoff], rOld); |
4483 | break; |
4484 | case 8: |
4485 | atomic(AtomicOp::cmpwr, atomicMod, rOld, |
4486 | scattered_qword(), astrategy.base, |
4487 | addr[hoff], rOld); |
4488 | break; |
4489 | default: stub(); |
4490 | } |
4491 | |
4492 | if (block.ebytes == 2) |
4493 | cmp<uint16_t>(cmpMod, rSave[0][0](2), rOld[0](2)); |
4494 | else if (block.ebytes == 4) |
4495 | cmp<uint32_t>(cmpMod, rSave, rOld); |
4496 | else if (block.ebytes == 8) { |
4497 | if (strategy.emulate.emulate64) { |
4498 | cmp<uint32_t>(simd | ne | flagToDo | eoMod, |
4499 | rSave[0][0](2), rOld[0](2)); |
4500 | cmp<uint32_t>( |
4501 | simd | ~flagToDo | ne | flagToDo | eoMod, |
4502 | rSave[0][1](2), rOld[1](2)); |
4503 | } else |
4504 | cmp<uint64_t>(cmpMod, rSave, rOld); |
4505 | } else |
4506 | stub(); |
4507 | |
                if (hw == HW::XeHPC)
                    simtDoWhileLoop(16 | flagToDo | any, labelCmpXchgLoop);
                else if (strategy.fused)
                    simtDoWhileLoop(16 | flagToDo | any16h, labelCmpXchgLoop);
                else if (eoff == 0 && simd == 8)
                    jmpi(1 | flagToDo | any8h, labelCmpXchgLoop);
                else
                    jmpi(1 | flagToDo | any16h, labelCmpXchgLoop);
4518 | |
4519 | rOld += 2 * nregReal; |
4520 | rNew += 2 * nregReal; |
4521 | curSrc += nregReal; |
4522 | } |
4523 | |
4524 | if (block.flag) { |
4525 | mark(labelMask); |
4526 | setDefaultNoMask(true); |
4527 | endif(16); |
4528 | } |
4529 | } |
4530 | break; |
4531 | default: hw_unsupported(); |
4532 | } |
4533 | } |
4534 | |
4535 | // Allocate temporary registers for emulating atomic addition. |
4536 | static inline void allocEAtomicAddRegs(HW hw, Type T, |
4537 | const vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
4538 | const MatrixAddressingStrategy &astrategy, CommonState &state, |
4539 | const FlagRegister &flag = FlagRegister()) { |
4540 | if (hasNativeAtomicAdd(hw, T.real(), atype, astrategy)) return; |
4541 | |
4542 | int maxNReg = 0; |
4543 | for (const auto &block : layout) |
4544 | maxNReg = std::max(maxNReg, block.nregs()); |
4545 | |
4546 | if (maxNReg == 0) return; |
4547 | |
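    // eatomicAddRegs[0] holds the old/new value pairs used by the
    //  compare-exchange loop (hence twice the layout size), while
    //  eatomicAddRegs[1] holds the saved copy used to detect intervening writes.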
4548 | state.eatomicAddRegs[0] = state.ra.alloc_range(maxNReg * 2); |
4549 | state.eatomicAddRegs[1] = state.ra.alloc_range(maxNReg); |
4550 | state.vflagEAtomicAdd |
4551 | = flag.isValid() ? flag : state.raVFlag.allocVirtual(); |
4552 | } |
4553 | |
4554 | // Free temporary registers for emulating atomic addition. |
4555 | static inline void freeEAtomicAddRegs( |
4556 | CommonState &state, const FlagRegister &flag = FlagRegister()) { |
4557 | state.ra.safeRelease(state.eatomicAddRegs[0]); |
4558 | state.ra.safeRelease(state.eatomicAddRegs[1]); |
4559 | if (flag.isInvalid()) state.raVFlag.release(state.vflagEAtomicAdd); |
4560 | } |
4561 | |
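// Release the flag registers held by a list of mask assignments, leaving the
//  assignments themselves in place. If 'start' is specified, only assignments
//  at index 'start' and above are affected.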
4562 | static inline void releaseMaskAssignments(vector<MaskAssignment> &assignments, |
4563 | CommonState &state, int start = 0) { |
4564 | for (size_t an = start; an < assignments.size(); an++) |
4565 | state.raVFlag.release(assignments[an].flag); |
4566 | |
4567 | state.wipeActiveVFlags(); |
4568 | } |
4569 | |
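// Re-claim the flag registers for a list of mask assignments
//  (the inverse of releaseMaskAssignments).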
4570 | static inline void reclaimMaskAssignments(vector<MaskAssignment> &assignments, |
4571 | CommonState &state, int start = 0) { |
4572 | for (size_t an = start; an < assignments.size(); an++) |
4573 | state.raVFlag.claim(assignments[an].flag); |
4574 | } |
4575 | |
// Release all masks in a list of mask assignments and remove them from the
//  list. If 'start' is specified, only the assignments at index 'start' and
//  above will be released.
4578 | static inline void safeReleaseMaskAssignments( |
4579 | vector<MaskAssignment> &assignments, CommonState &state, |
4580 | int start = 0) { |
4581 | releaseMaskAssignments(assignments, state, start); |
4582 | assignments.resize(start); |
4583 | } |
4584 | |
4585 | // Assign mask registers to a register layout. |
4586 | // The assignments parameter is both input and output: |
4587 | // existing assignments will be reused if compatible, and new assignments |
4588 | // created as necessary. |
4589 | template <HW hw> |
4590 | bool gemm_kernel_generator_t<hw>::assignMasks( |
4591 | std::vector<RegisterBlock> &layout, LoopType rloop, LoopType cloop, |
4592 | vector<MaskAssignment> &assignments, const CommonStrategy &strategy, |
4593 | CommonState &state, bool retryVirtual) { |
4594 | // Loop through layout, collecting masks. |
4595 | // - For each unique mask+loop+offset, allocate an index (flag reg) |
4596 | // - Store new assignment if unique and update flag reg in layout. |
4597 | // - For now, simultaneous row and column masks are not supported. |
4598 | bool retry = false; |
4599 | do { |
4600 | auto nassignOriginal = int(assignments.size()); |
4601 | bool outOfRegs = retry = false; |
4602 | |
4603 | for (RegisterBlock &l : layout) { |
4604 | MaskAssignment thisAssignment; |
4605 | |
4606 | if (l.rowMask) { |
4607 | if (l.colMask) stub(); |
4608 | |
4609 | thisAssignment.mask = l.rowMask; |
4610 | thisAssignment.offset = l.offsetR; |
4611 | thisAssignment.var = rloop; |
4612 | } else if (l.colMask) { |
4613 | thisAssignment.mask = l.colMask; |
4614 | thisAssignment.offset = l.offsetC; |
4615 | thisAssignment.var = cloop; |
4616 | } else { |
4617 | l.clearFlag(); |
4618 | continue; |
4619 | } |
4620 | |
4621 | // Look for compatible mask. |
4622 | bool gotMask = false; |
4623 | for (auto &a : assignments) { |
4624 | if (a.compatible(thisAssignment)) { |
4625 | l.flag = a.flag; |
4626 | gotMask = true; |
4627 | break; |
4628 | } |
4629 | } |
4630 | |
4631 | if (!gotMask) { |
4632 | // No compatible mask, so make a new assignment. |
4633 | thisAssignment.flag |
4634 | = state.raVFlag.allocVirtual((l.simdSize + 0xF) >> 4); |
4635 | assignments.push_back(thisAssignment); |
4636 | if (state.raVFlag.isVirtual(thisAssignment.flag) |
4637 | && state.vflagStorage.isInvalid()) { |
4638 | outOfRegs = true; |
4639 | break; |
4640 | } |
4641 | l.flag = thisAssignment.flag; |
4642 | } |
4643 | } |
4644 | |
4645 | if (outOfRegs) { |
4646 | // Not enough (virtual) flag registers! Free any masks we added to the list. |
4647 | safeReleaseMaskAssignments(assignments, state, nassignOriginal); |
4648 | if (retryVirtual && state.vflagStorage.isInvalid()) { |
4649 | status << "Not enough flag registers available. Retrying with " |
4650 | "virtual flags." |
4651 | << status_stream::endl; |
4652 | allocVFlagStorage(strategy, state); |
4653 | retry = true; |
4654 | } else { |
4655 | status << "Not enough flag registers available." |
4656 | << status_stream::endl; |
4657 | return false; |
4658 | } |
4659 | } |
4660 | } while (retry); |
4661 | |
4662 | return true; |
4663 | } |
4664 | |
4665 | // Output code for loading a mask into a flag register. |
4666 | template <HW hw> |
4667 | void gemm_kernel_generator_t<hw>::loadMask(MaskAssignment assignment, |
4668 | Subregister index, const CommonStrategy &strategy, CommonState &state, |
4669 | int offset) { |
4670 | auto flagIdx = assignment.flag; |
4671 | RegData flag = getMaskFlag(flagIdx, state); |
4672 | |
4673 | if (assignment.mask.fixed.isFixed) { |
4674 | // Load fixed mask. Easy. |
4675 | mov(1, flag, uint16_t(assignment.mask.fixed.value)); |
4676 | } else { |
4677 | // Load a variable mask, which requires some minor bit-twiddling. |
4678 | auto &vmask = assignment.mask.variable; |
4679 | |
4680 | uint32_t rsizeScaled = vmask.rsize / vmask.rdivide; |
4681 | uint32_t fullMask |
4682 | = (1ul << (vmask.bitRep * vmask.maskRep * rsizeScaled)) - 1; |
4683 | uint32_t rep1Mask = (1ul << (vmask.bitRep * rsizeScaled)) - 1; |
4684 | uint32_t repMultiplier = fullMask / rep1Mask; |
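        // Example: bitRep = 1, maskRep = 2, rsizeScaled = 8 gives
        //  rep1Mask = 0xFF, fullMask = 0xFFFF, repMultiplier = 0x101;
        //  multiplying an 8-bit mask by 0x101 replicates it into 16 bits.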
4685 | |
4686 | auto flagType = flag.getType(); |
4687 | auto mask0Type = getBytes(flagType) >= 4 ? DataType::uq : flagType; |
4688 | |
4689 | auto temp = state.ra.alloc_sub(flagType, getHint(HintType::Bank0)); |
4690 | auto mask0 = state.ra.alloc_sub(mask0Type, getHint(HintType::Bank1)); |
4691 | auto mask = mask0.reinterpret(0, flagType); |
4692 | auto mindex = index; |
4693 | |
4694 | if (vmask.rdivide > 1) { |
4695 | if (!is_zero_or_pow2(vmask.rdivide)) stub(); |
4696 | add(1, temp, mindex, -offset + vmask.rdivide - 1); |
4697 | shr(1, temp, temp, uint16_t(log2(vmask.rdivide))); |
4698 | mindex = temp; |
4699 | offset = 0; |
4700 | } |
4701 | if (vmask.bitRep > 1) { |
4702 | if (offset > 0) { |
4703 | add(1, temp, mindex, -offset); |
4704 | mindex = temp; |
4705 | offset = 0; |
4706 | } |
4707 | mulConstant(1, temp, mindex, vmask.bitRep); |
4708 | mindex = temp; |
4709 | } |
4710 | uint16_t tshift = vmask.bitRep |
4711 | * (rsizeScaled |
4712 | + div_up(assignment.offset + offset, vmask.rdivide)); |
4713 | add(1 | sat, temp, -mindex, tshift); |
4714 | if (tshift >= 32) |
4715 | min_(1, temp, temp, |
4716 | vmask.bitRep |
4717 | * rsizeScaled); // Ensure shift count doesn't overflow. |
4718 | emov(1, mask0, rep1Mask, strategy, state); |
4719 | if (vmask.maskRep == 1) { |
4720 | bool twoStage = (!flag.isARF() && getBytes(mask0Type) > 4); |
4721 | auto flag1 = twoStage ? mask0 : flag; |
4722 | vmask.reverse ? shl(1, flag1, mask0, temp) |
4723 | : shr(1, flag1, mask0, temp); |
4724 | if (twoStage) mov(1, flag, mask); |
4725 | } else { |
4726 | vmask.reverse ? stub() // need shl + and |
4727 | : shr(1, mask0, mask0, temp); |
4728 | if (repMultiplier & 0x10000) mov(1, mask.uw(1), mask.uw(0)); |
4729 | mul(1, flag, mask, uint16_t(repMultiplier)); |
4730 | } |
4731 | |
4732 | state.ra.safeRelease(temp); |
4733 | state.ra.safeRelease(mask0); |
4734 | } |
4735 | } |
4736 | |
4737 | // Output code for loading all masks in a mask assignment to flag registers. |
4738 | template <HW hw> |
4739 | void gemm_kernel_generator_t<hw>::loadMasks( |
4740 | const vector<MaskAssignment> &assignments, Subregister (&indices)[3], |
4741 | const CommonStrategy &strategy, CommonState &state, int start) { |
4742 | for (size_t an = start; an < assignments.size(); an++) { |
4743 | auto &a = assignments[an]; |
4744 | auto av = static_cast<int>(a.var); |
4745 | loadMask(a, indices[av], strategy, state); |
4746 | } |
4747 | } |
4748 | |
4749 | template <HW hw> |
4750 | void gemm_kernel_generator_t<hw>::loadMasks( |
4751 | const vector<MaskAssignment> &assignments, Subregister (&indices)[3], |
4752 | int (&offsets)[3], const CommonStrategy &strategy, CommonState &state, |
4753 | int start) { |
4754 | for (size_t an = start; an < assignments.size(); an++) { |
4755 | auto &a = assignments[an]; |
4756 | auto av = static_cast<int>(a.var); |
4757 | loadMask(a, indices[av], strategy, state, offsets[av]); |
4758 | } |
4759 | } |
4760 | |
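// Extend the cached index vector (the sequence 0, 1, 2, ...) to hold at
//  least n 16-bit entries.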
4761 | template <HW hw> |
4762 | void gemm_kernel_generator_t<hw>::extendIndexVec(int n, CommonState &state) { |
4763 | auto &indexVec = state.indexVec; |
4764 | auto &ivEntries = state.ivEntries; |
4765 | |
4766 | if (n > ivEntries) { |
4767 | int simd = GRF::bytes(hw) >> 1; |
4768 | int nregs = div_up(n, simd); |
4769 | int cregs = indexVec.getLen(); |
4770 | if (nregs > cregs) |
4771 | indexVec.ranges.push_back(state.ra.alloc_range(nregs - cregs)); |
4772 | if (ivEntries == 0) { |
4773 | mov<uint16_t>(8, indexVec[0][0](1), |
4774 | Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7)); |
4775 | ivEntries = 8; |
4776 | } |
4777 | if (n > 8 && ivEntries < 16) { |
4778 | mov<uint16_t>(8, indexVec[0][8](1), |
4779 | Immediate::uv(8, 9, 10, 11, 12, 13, 14, 15)); |
4780 | ivEntries = 16; |
4781 | } |
4782 | if (GRF::bytes(hw) > 32 && n > 16 && ivEntries < 32) { |
4783 | add<uint16_t>(16, indexVec[0][16](1), indexVec[0].uw(0)(1), 16); |
4784 | ivEntries = 32; |
4785 | } |
4786 | if (n > ivEntries) { |
4787 | for (int e = std::max(cregs, 1); e < nregs; e++) |
4788 | add<uint16_t>(simd, indexVec[e], indexVec[0], simd * e); |
4789 | ivEntries = nregs * simd; |
4790 | } |
4791 | } |
4792 | } |
4793 | |
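// Retrieve the subregister holding entry n of the cached index vector,
//  extending the vector as needed.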
4794 | template <HW hw> |
4795 | Subregister gemm_kernel_generator_t<hw>::accessIndexVec( |
4796 | int n, CommonState &state) { |
    if (n >= state.ivEntries) extendIndexVec(n + 1, state);
4798 | |
4799 | int simd = GRF::bytes(hw) >> 1; |
4800 | return state.indexVec[n / simd].uw(n % simd); |
4801 | } |
4802 | |
4803 | static inline void releaseIndexVec(CommonState &state) { |
4804 | safeReleaseRanges(state.indexVec, state); |
4805 | state.ivEntries = 0; |
4806 | } |
4807 | |
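// Create a vector of multiples of the leading dimension (0, ld, 2*ld, ...)
//  for address setup, in 32-bit or 64-bit precision as requested.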
4808 | template <HW hw> |
4809 | LDMultiples gemm_kernel_generator_t<hw>::createLDMultiples(bool a64, |
4810 | int nmultiples, const Subregister &ld, const CommonStrategy &strategy, |
4811 | CommonState &state) { |
4812 | int simd = GRF::bytes(hw) >> (a64 ? 3 : 2); |
4813 | int nregs = div_up(nmultiples, simd); |
4814 | auto r = state.ra.try_alloc_range(nregs); |
4815 | |
4816 | GRF tempHi = state.emulate.temp[0], tempLo = state.emulate.temp[1]; |
4817 | bool freeTempHi = false, freeTempLo = false; |
4818 | if (a64) { |
4819 | if (tempHi.isInvalid()) { |
4820 | tempHi = state.ra.alloc(); |
4821 | freeTempHi = true; |
4822 | } |
4823 | if (tempLo.isInvalid()) { |
4824 | tempLo = state.ra.alloc(); |
4825 | freeTempLo = true; |
4826 | } |
4827 | } |
4828 | |
4829 | if (r.isValid()) { |
4830 | extendIndexVec(nmultiples, state); |
4831 | for (int i = 0; i < nregs; i += 2) { |
4832 | auto thisSIMD = simd * std::min(nregs - i, 2); |
4833 | auto iv = accessIndexVec(simd * i, state)(1); |
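            // For A64, form 64-bit products of ld and the index vector using
            //  the mul/mach accumulator sequence, interleaving the low and
            //  high dwords into qword lanes.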
4834 | if (a64) { |
4835 | mul<uint32_t>(thisSIMD, acc0, ld, iv); |
4836 | mach<uint32_t>(thisSIMD, tempHi, ld, Immediate::ud(0)); |
4837 | mov<uint32_t>(thisSIMD, tempLo, acc0); |
4838 | mov<uint32_t>(thisSIMD, r[i][1](2), tempHi); |
4839 | mov<uint32_t>(thisSIMD, r[i][0](2), tempLo); |
4840 | } else |
4841 | mul<uint32_t>(thisSIMD, r[i], ld, iv); |
4842 | } |
4843 | } |
4844 | |
4845 | if (freeTempHi) state.ra.safeRelease(tempHi); |
4846 | if (freeTempLo) state.ra.safeRelease(tempLo); |
4847 | |
4848 | LDMultiples result; |
4849 | result.range = r; |
4850 | result.a64 = a64; |
4851 | |
4852 | return result; |
4853 | } |
4854 | |
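// Retrieve a subregister holding the multiple n*ld from an LDMultiples
//  structure, returning an invalid Subregister if it is not available.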
4855 | template <HW hw> |
4856 | Subregister gemm_kernel_generator_t<hw>::findLDMultiple( |
4857 | const LDMultiples &multiples, bool a64, int n, |
4858 | const CommonStrategy &strategy, CommonState &state) { |
4859 | int simd = GRF::bytes(hw) >> (multiples.a64 ? 3 : 2); |
4860 | int off = (n / simd), sub = (n % simd); |
4861 | |
4862 | if (multiples.range.isInvalid()) return Subregister(); |
4863 | if (off < 0 || off >= multiples.range.getLen()) return Subregister(); |
4864 | if (a64 && !multiples.a64) return Subregister(); |
4865 | |
4866 | return !multiples.a64 ? multiples.range[off].ud(sub) |
4867 | : a64 ? multiples.range[off].uq(sub) |
4868 | : multiples.range[off].ud(2 * sub); |
4869 | } |
4870 | |
4871 | static inline void releaseLDMultiples( |
4872 | LDMultiples &multiples, CommonState &state) { |
4873 | state.ra.safeRelease(multiples.range); |
4874 | multiples.a64 = false; |
4875 | } |
4876 | |
4877 | // Ugly helpers handling address shifts. constexpr if would clean this all up. |
4878 | template <HW hw> |
4879 | template <typename BO> |
4880 | typename std::enable_if<!std::is_base_of<RegData, BO>::value, BO>::type |
4881 | gemm_kernel_generator_t<hw>::startShift( |
4882 | const BO &ptr, int shift, CommonState &state) { |
4883 | return ptr >> shift; |
4884 | } |
4885 | |
4886 | template <HW hw> |
4887 | Subregister gemm_kernel_generator_t<hw>::startShift( |
4888 | const MultishiftSubregister &ptr, int shift, CommonState &state) { |
4889 | return ptr >> shift; |
4890 | } |
4891 | |
4892 | template <HW hw> |
4893 | SubregisterPair gemm_kernel_generator_t<hw>::startShift( |
4894 | const SubregisterPair &ptr, int shift, CommonState &state) { |
4895 | if (shift == 0) |
4896 | return ptr; |
4897 | else |
4898 | return SubregisterPair(startShift(ptr.getReg(0), shift, state)); |
4899 | } |
4900 | |
4901 | template <HW hw> |
4902 | template <typename BO> |
4903 | typename std::enable_if<std::is_base_of<RegData, BO>::value, BO>::type |
4904 | gemm_kernel_generator_t<hw>::startShift( |
4905 | const BO &ptr, int shift, CommonState &state) { |
4906 | BO ptrShifted = ptr; |
4907 | |
4908 | // Shift pointer as necessary. |
4909 | if (shift > 0) { |
4910 | ptrShifted = state.ra.alloc_sub(ptr.getType()); |
4911 | shr(1, ptrShifted, ptr, shift); |
4912 | } |
4913 | |
4914 | return ptrShifted; |
4915 | } |
4916 | |
4917 | template <HW hw> |
4918 | template <typename BO, typename BI> |
4919 | typename std::enable_if<!std::is_base_of<RegData, BO>::value>::type |
4920 | gemm_kernel_generator_t<hw>::doneShift( |
4921 | const BO &ptr, const BI &ptrShifted, int shift, CommonState &state) {} |
4922 | |
4923 | template <HW hw> |
4924 | template <typename BO, typename BI> |
4925 | typename std::enable_if<std::is_base_of<RegData, BO>::value>::type |
4926 | gemm_kernel_generator_t<hw>::doneShift( |
4927 | const BO &ptr, const BI &ptrShifted, int shift, CommonState &state) { |
4928 | if (shift > 0) state.ra.release(ptrShifted); |
4929 | } |
4930 | |
4931 | template <HW hw> |
4932 | void gemm_kernel_generator_t<hw>::doneShift(const SubregisterPair &ptr, |
4933 | const SubregisterPair &ptrShifted, int shift, CommonState &state) { |
4934 | if (shift > 0) doneShift(ptr.getReg(0), ptrShifted.getReg(0), shift, state); |
4935 | } |
4936 | |
4937 | static inline bool canIncAddr(const RegisterBlock &blockSrc, |
4938 | const RegisterBlock &blockDst, const MatrixAddressing &atype, |
4939 | const MatrixAddressingStrategy &astrategy) { |
4940 | if (!blockSrc.isLoadBlock() || !blockDst.isLoadBlock()) return false; |
4941 | if (effectiveAccessType(atype, astrategy, blockDst) == AccessType::Block |
4942 | && effectiveAccessType(atype, astrategy, blockSrc) |
4943 | == AccessType::Block) |
4944 | return true; |
4945 | if (isBlock2D(astrategy.accessType)) |
4946 | return (blockSrc.nr == blockDst.nr && blockSrc.nc == blockDst.nc); |
4947 | return (blockSrc.simdSize >= blockDst.simdSize); |
4948 | } |
4949 | |
4950 | // Output code for setting up address/header GRFs for a single block, given |
4951 | // the base pointer (a Subregister, MultishiftSubregister or integer) and leading dimension. |
4952 | template <HW hw> |
4953 | template <typename BO> |
4954 | void gemm_kernel_generator_t<hw>::setupAddr(const GRFRange &addr, const BO &ptr, |
4955 | const RegisterBlock &block, const Subregister &bld, size_t sizeofT, |
4956 | const MatrixAddressing &atype, |
4957 | const MatrixAddressingStrategy &astrategy, |
4958 | const CommonStrategy &strategy, CommonState &state, |
4959 | const Address2DParams ¶ms, LDMultiples ldMultiples) { |
4960 | bool a64 = astrategy.base.getModel() == ModelA64; |
4961 | |
4962 | // Nothing to do for non-load blocks. |
4963 | if (!block.isLoadBlock()) return; |
4964 | |
4965 | auto effAccessType = effectiveAccessType(atype, astrategy, block); |
4966 | switch (effAccessType) { |
4967 | case AccessType::Scattered: |
4968 | case AccessType::ChannelScattered: |
4969 | case AccessType::PseudoBlock: { |
4970 | int simdSize = block.simdSize; |
4971 | auto consecutive = block.extra; |
4972 | bool pseudo = (effAccessType == AccessType::PseudoBlock); |
4973 | auto Tptr = a64 ? DataType::uq : DataType::ud; |
4974 | int ne = elementsPerGRF(hw, Tptr); |
4975 | int preshift = 0; |
4976 | |
4977 | auto oldIndexVec = state.indexVec; |
4978 | auto oldIVEntries = state.ivEntries; |
4979 | |
4980 | if (!pseudo && !isPacked(atype.layout)) { |
4981 | // Get pointers to successive rows/columns, strided by ld. |
4982 | bool allocLDMultiples = false; |
4983 | if (ldMultiples.range.isInvalid()) { |
4984 | ldMultiples = createLDMultiples( |
4985 | a64, simdSize, bld, strategy, state); |
4986 | allocLDMultiples = true; |
4987 | } else |
4988 | (void)findLDMultiple( |
4989 | ldMultiples, a64, simdSize - 1, strategy, state); |
4990 | |
4991 | for (int r = 0; r < addr.getLen(); r += 2) { |
4992 | int nr = std::min(2, addr.getLen() - r); |
4993 | int simd = nr * ne; |
4994 | auto ld0 = findLDMultiple(ldMultiples, a64, |
4995 | r * ne / consecutive, strategy, state); |
4996 | auto ldStride = (ldMultiples.a64 && !a64) ? 2 : 1; |
4997 | auto ldR = ld0(ldStride, consecutive, 0); |
4998 | auto addrR = addr[r].retype(Tptr); |
4999 | if (a64 && consecutive > 1 && hw >= HW::XeHP |
5000 | && !strategy.emulate |
5001 | .emulate64) { /* no swizzle in L pipe */ |
5002 | mov(simd, addr[r].ud(0)(2), |
5003 | ld0.ud(0)(ldStride * 2, consecutive, 0)); |
5004 | mov(simd, addr[r].ud(1)(2), |
5005 | ld0.ud(1)(ldStride * 2, consecutive, 0)); |
5006 | if (ptr != 0) add(simd, addrR, addrR, ptr); |
5007 | } else if (ptr != 0) |
5008 | eadd(simd, addrR, ptr, ldR, strategy, state); |
5009 | else |
5010 | emov(simd, addrR, ldR, strategy, state); |
5011 | } |
5012 | if (allocLDMultiples) releaseLDMultiples(ldMultiples, state); |
5013 | } else { |
5014 | // Get pointers to successive elements, with constant stride. |
5015 | extendIndexVec(simdSize, state); |
5016 | auto iv = accessIndexVec(0, state)(1, consecutive, 0); |
5017 | uint16_t stride; |
5018 | preshift = block.addrShift; |
5019 | auto ptrShifted = startShift(ptr, block.addrShift, state); |
5020 | |
5021 | if (pseudo) |
5022 | stride = (block.ebytes * block.count |
5023 | * getPartialCrosspack( |
5024 | sizeofT, atype, block)) |
5025 | >> preshift; |
5026 | else { |
5027 | int tile = isColMajor(atype.layout) ? atype.tileR |
5028 | : atype.tileC; |
5029 | if (tile == 0) tile = atype.packSize; |
5030 | int psElems = (isLargeCrosspack(sizeofT, atype.crosspack) |
5031 | ? 1 |
5032 | : tile) |
5033 | * atype.crosspack; |
5034 | stride = uint16_t(psElems * sizeofT) >> preshift; |
5035 | } |
5036 | |
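                // On XeHP+, write the 32-bit index*stride products into the
                //  low dwords of the qword address lanes (udStride = 2), so
                //  the following eadd can consume them in place.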
5037 | if (a64) { |
5038 | int udStride = (hw >= HW::XeHP) ? 2 : 1; |
5039 | int simd1 = std::min(2 * ne, simdSize); |
5040 | int simd2 = simdSize - simd1; |
5041 | if (udStride == 2 && simd2) { |
5042 | auto iv2 = accessIndexVec(simd1 / consecutive, state)( |
5043 | 1, consecutive, 0); |
5044 | mulConstant( |
5045 | simd2, addr[2].ud(0)(udStride), iv2, stride); |
5046 | mulConstant(simd1, addr[0].ud(0)(udStride), iv, stride); |
5047 | } else |
5048 | mulConstant( |
5049 | simdSize, addr[0].ud(0)(udStride), iv, stride); |
5050 | if (simd2) |
5051 | eadd(simd2, addr[2].uq(), ptrShifted, |
5052 | addr[udStride].ud(0)(udStride), strategy, |
5053 | state); |
5054 | eadd(simd1, addr[0].uq(), ptrShifted, |
5055 | addr[0].ud(0)(udStride), strategy, state); |
5056 | } else if (ptrShifted != 0) { |
5057 | if (consecutive > 1) { |
5058 | mulConstant<uint32_t>(simdSize, addr, iv, stride); |
5059 | add<uint32_t>(simdSize, addr, addr, ptrShifted); |
5060 | } else |
5061 | emad(simdSize, addr[0].ud(), ptrShifted, iv, |
5062 | int32_t(stride), strategy, state); |
5063 | } else |
5064 | mulConstant<uint32_t>(simdSize, addr, iv, stride); |
5065 | |
5066 | doneShift(ptr, ptrShifted, block.addrShift, state); |
5067 | } |
5068 | |
5069 | // Add offsets for consecutive elements in scattered accesses. |
5070 | if (consecutive > 1) { |
5071 | if ((consecutive - 1) * block.ebytes >= 0x10) stub(); |
5072 | if (consecutive > 4) stub(); |
5073 | uint8_t incs[4]; |
5074 | for (int idx = 0; idx < 4; idx++) |
5075 | incs[idx] |
5076 | = (block.ebytes * (idx % consecutive)) >> preshift; |
5077 | |
5078 | if (!a64) { |
5079 | auto incImm = Immediate::uv( |
5080 | incs[0], 0, incs[1], 0, incs[2], 0, incs[3], 0); |
5081 | add<uint32_t>(simdSize, addr, addr, incImm); |
5082 | } else { |
5083 | if (consecutive > 2) stub(); |
5084 | auto incImm |
5085 | = Immediate::uv(incs[0], 0, 0, 0, incs[1], 0, 0, 0); |
5086 | auto temp = state.ra.alloc_range(2); |
5087 | mov<uint32_t>( |
5088 | 2 * elementsPerGRF<uint32_t>(hw), temp, incImm); |
5089 | map(hw, Tptr, addr, addr, strategy, |
5090 | [&](int simd, GRF r1, GRF _) { |
5091 | eadd<uint64_t>(simd, r1, r1, temp[0].ud(0)(2), |
5092 | strategy, state); |
5093 | }); |
5094 | state.ra.safeRelease(temp); |
5095 | } |
5096 | } |
5097 | |
5098 | // Scale if needed. |
5099 | if (block.addrShift > preshift) |
5100 | shr<uint32_t>(simdSize, addr, addr, block.addrShift - preshift); |
5101 | |
5102 | // Restore original cached index vector in case we extended it. |
5103 | releaseRanges(state.indexVec, state); |
5104 | state.indexVec = oldIndexVec; |
5105 | state.ivEntries = oldIVEntries; |
5106 | reclaimRanges(state.indexVec, state); |
5107 | break; |
5108 | } |
5109 | case AccessType::Block: |
5110 | if (astrategy.base.getModel() == ModelA64) { |
5111 | emov(1, addr[0].uq(0), ptr, strategy, state); |
5112 | // Disable OWord channel mode on SKL. |
5113 | if (block.ebytes == 32 && hw < HW::Gen10) |
5114 | mov(1, addr[0].ud(5), uint32_t(0x80000000)); |
5115 | } else if (astrategy.newDP) { |
5116 | mov(1, addr[0].ud(0), ptr); |
5117 | } else if (block.addrShift > 0) |
5118 | shr(1, addr[0].ud(2), ptr, block.addrShift); |
5119 | else |
5120 | mov(1, addr[0].ud(2), ptr); |
5121 | break; |
5122 | case AccessType::Block2D: |
5123 | case AccessType::Block2DTranspose: |
5124 | case AccessType::Block2DVNNI: |
5125 | if (astrategy.base.getModel() != ModelA64) hw_unsupported(); |
5126 | |
5127 | // Assemble some information. |
5128 | bool memCM = isColMajor(atype.layout); |
5129 | int bw, bh, multiX; |
5130 | getBlock2DWH(bw, bh, atype, block, &multiX); |
5131 | |
5132 | auto iremR = params.remR, iremC = params.remC; |
5133 | if (!block.remainderR) iremR.invalidate(); |
5134 | if (!block.remainderC) iremC.invalidate(); |
5135 | |
5136 | auto remW = memCM ? iremR : iremC; |
5137 | auto remH = memCM ? iremC : iremR; |
5138 | auto &nx = memCM ? params.rows : params.cols; |
5139 | auto &ny = memCM ? params.cols : params.rows; |
5140 | auto fixedX = memCM ? params.fixedRows : params.fixedCols; |
5141 | auto fixedY = memCM ? params.fixedCols : params.fixedRows; |
5142 | auto &offX = memCM ? params.offR : params.offC; |
5143 | auto &offY = memCM ? params.offC : params.offR; |
5144 | auto boffX = memCM ? block.offsetR : block.offsetC; |
5145 | auto boffY = memCM ? block.offsetC : block.offsetR; |
5146 | |
5147 | boffX *= uint8_t(sizeofT); |
5148 | if (boffX % block.ebytes) stub(); |
5149 | boffX /= block.ebytes; |
5150 | |
5151 | // If the base address may not be 64b-aligned (128b pre-B4), |
5152 | // we need to emit code to align it down and offset x/width appropriately. |
5153 | int baseAlign = (getStepping() >= SteppingPVCXTB4 ? 64 : 128); |
5154 | bool doBaseAdjust = (atype.alignment & (baseAlign - 1)) != 0; |
5155 | if (doBaseAdjust && !astrategy.address2D) stub(); |
5156 | Subregister baStorage, baseAdjust, baseAdjustElems; |
5157 | |
5158 | if (!astrategy.address2D) mov(4, addr[0].ud(4)(1), 0u); |
5159 | |
5160 | if (doBaseAdjust) { |
5161 | baStorage = state.ra.alloc_sub<uint32_t>(); |
5162 | baseAdjust = baStorage.uw(0); |
5163 | baseAdjustElems = baStorage.uw(1); |
5164 | if (!offX.isValid()) baseAdjustElems = addr[0].ud(5); |
5165 | |
5166 | and_(1, baseAdjust, ptr.ud(0), baseAlign - 1); |
5167 | and_(1, addr[0].ud(0), ptr.ud(0), ~uint32_t(baseAlign - 1)); |
5168 | mov(1, addr[0].ud(1), ptr.ud(1)); |
5169 | if (block.ebytes > 1) |
5170 | shr(1, baseAdjustElems, baseAdjust, log2(block.ebytes)); |
5171 | else |
5172 | baseAdjustElems = baseAdjust; |
5173 | } else |
5174 | emov(1, addr[0].uq(0), ptr, strategy, state); |
5175 | |
5176 | if (astrategy.address2D) { |
5177 | if (params.rows.isInvalid() && params.fixedRows == 0) |
                    throw std::runtime_error("Unknown matrix size.");
5179 | |
5180 | nx.isValid() ? mad(1, addr[0].ud(2), -1, nx, sizeofT) |
5181 | : mov(1, addr[0].ud(2), fixedX * sizeofT - 1); |
5182 | ny.isValid() ? add(1, addr[0].ud(3), ny, -1) |
5183 | : mov(1, addr[0].ud(3), fixedY - 1); |
5184 | offX.isValid() ? addScaled(1, addr[0].ud(5), boffX, offX, |
5185 | int(sizeofT), block.ebytes, state) |
5186 | : doBaseAdjust |
5187 | ? add(1, addr[0].ud(5), baseAdjustElems, boffX) |
5188 | : mov(1, addr[0].ud(5), boffX); |
5189 | offY.isValid() ? add(1, addr[0].ud(6), offY, boffY) |
5190 | : mov(1, addr[0].ud(6), boffY); |
5191 | if (doBaseAdjust) { |
5192 | add(1, addr[0].ud(2), addr[0].ud(2), baseAdjust); |
5193 | if (offX.isValid()) |
5194 | add(1, addr[0].ud(5), addr[0].ud(5), baseAdjustElems); |
5195 | } |
5196 | if (sizeofT < 4) |
5197 | or_(1, addr[0].ud(2), addr[0].ud(2), |
5198 | 3); // Width must be 4-byte-aligned. |
5199 | } else if (remW.isInvalid() && remH.isInvalid()) |
5200 | emov(1, addr[0].uq(1), |
5201 | uint64_t(bw * block.count * block.ebytes - 1) |
5202 | | (uint64_t(bh * block.ebytes - 1) << 32), |
5203 | strategy, state); |
5204 | else { |
5205 | if (remW.isValid() && multiX > 1) stub(); |
5206 | remW.isValid() ? mad(1, addr[0].ud(2), -1, remW.uw(), sizeofT) |
5207 | : mov(1, addr[0].ud(2), |
5208 | bw * block.count * block.ebytes - 1); |
5209 | remH.isValid() ? mad(1, addr[0].ud(3), -1, remH.uw(), multiX) |
5210 | : mov(1, addr[0].ud(3), bh - 1); |
5211 | if (remW.isValid() && sizeofT < 4) |
5212 | or_(1, addr[0].ud(2), addr[0].ud(2), 3); |
5213 | } |
5214 | |
5215 | if (isPacked(atype.layout)) { |
5216 | auto pitch = bw * block.count * block.ebytes; |
5217 | if (pitch < 64 || pitch & 0xF) hw_unsupported(); |
5218 | mov(1, addr[0].ud(4), pitch - 1); |
5219 | } else |
5220 | add(1, addr[0].ud(4), bld, -1); |
5221 | |
5222 | mov(1, addr[0].ud(7), |
5223 | (bw - 1) | ((bh - 1) << 8) | ((block.count - 1) << 16)); |
5224 | |
5225 | state.ra.safeRelease(baStorage); |
5226 | break; |
5227 | } |
5228 | } |
5229 | |
// Increment an address block by a combination of a fixed byte offset and a
//  multiple of the leading dimension.
5231 | template <HW hw> |
5232 | void gemm_kernel_generator_t<hw>::offsetAddr(const GRFRange &addrDst, |
5233 | const GRFRange &addrSrc, const RegisterBlock &blockDst, |
5234 | const RegisterBlock &blockSrc, int offsetFixed, int offsetLD, |
5235 | const Subregister &ld, const MatrixAddressing &atype, |
5236 | const MatrixAddressingStrategy &astrategy, |
5237 | const CommonStrategy &strategy, CommonState &state, |
5238 | const LDMultiples &ldMultiples) { |
5239 | bool a64 = (astrategy.base.getModel() == ModelA64); |
5240 | |
5241 | if (astrategy.address2D) stub(); |
5242 | |
5243 | if (offsetLD == 0) { |
5244 | if (offsetFixed != 0) |
5245 | incAddr(addrDst, addrSrc, offsetFixed, blockDst, blockSrc, atype, |
5246 | astrategy, strategy, state); |
5247 | } else { |
5248 | // Reuse ld * offsetLD calculation if available. |
5249 | auto ldInc |
5250 | = findLDMultiple(ldMultiples, a64, offsetLD, strategy, state); |
5251 | |
5252 | if (ldInc.isValid() && offsetFixed == 0) |
5253 | incAddr(addrDst, addrSrc, (offsetLD == 1) ? ld : ldInc, blockDst, |
5254 | blockSrc, atype, astrategy, strategy, state); |
5255 | else { |
5256 | Subregister incAlloc |
5257 | = state.ra.alloc_sub(a64 ? DataType::uq : DataType::ud); |
5258 | auto inc = incAlloc; |
5259 | |
5260 | if (ldInc.isInvalid()) { |
5261 | if (offsetLD == 1) |
5262 | ldInc = ld; |
5263 | else { |
5264 | emulConstant(1, inc, ld, offsetLD, strategy, state); |
5265 | ldInc = inc; |
5266 | } |
5267 | } |
5268 | if (offsetFixed != 0) |
5269 | eadd(1, inc, ldInc, offsetFixed, strategy, state); |
5270 | else |
5271 | inc = ldInc; |
5272 | incAddr(addrDst, addrSrc, inc, blockDst, blockSrc, atype, astrategy, |
5273 | strategy, state); |
5274 | |
5275 | state.ra.safeRelease(incAlloc); |
5276 | } |
5277 | } |
5278 | } |
5279 | |
5280 | // Output code for initializing address/header GRFs for one block based on another block's headers. |
5281 | template <HW hw> |
5282 | void gemm_kernel_generator_t<hw>::setupAddrRel(Type T, const GRFRange &addrDst, |
5283 | const GRFRange &addrSrc, const RegisterBlock &blockDst, |
5284 | const RegisterBlock &blockSrc, const vector<RegisterBlock> &layout, |
5285 | const Subregister &ld, const MatrixAddressing &atype, |
5286 | const MatrixAddressingStrategy &astrategy, |
5287 | const CommonStrategy &strategy, CommonState &state, |
5288 | const LDMultiples &ldMultiples) { |
5289 | int deltaR = blockDst.offsetR - blockSrc.offsetR; |
5290 | int deltaC = blockDst.offsetC - blockSrc.offsetC; |
5291 | |
5292 | if (astrategy.address2D) |
5293 | incAddr(addrDst, addrSrc, Subregister(), deltaR, deltaC, blockDst, |
5294 | blockSrc, atype, astrategy, strategy, state); |
5295 | else { |
5296 | int offsetFixed = 0, offsetLD = 0, r = 0, c = 0; |
5297 | |
5298 | if (isPacked(atype.layout)) getLayoutDims(layout, r, c); |
5299 | |
5300 | switch (atype.layout) { |
5301 | case MatrixLayout::N: |
5302 | offsetFixed = deltaR; |
5303 | offsetLD = deltaC; |
5304 | break; |
5305 | case MatrixLayout::T: |
5306 | offsetFixed = deltaC; |
5307 | offsetLD = deltaR; |
5308 | break; |
5309 | case MatrixLayout::Pc: |
5310 | case MatrixLayout::Pr: |
5311 | offsetFixed = untile(T, atype, blockDst, r, c) |
5312 | - untile(T, atype, blockSrc, r, c); |
5313 | break; |
5314 | } |
5315 | |
5316 | offsetFixed *= T.size(); |
5317 | |
5318 | offsetAddr(addrDst, addrSrc, blockDst, blockSrc, offsetFixed, offsetLD, |
5319 | ld, atype, astrategy, strategy, state, ldMultiples); |
5320 | } |
5321 | } |
5322 | |
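// Find a block in [start, end) whose address headers can be used as a base
//  for incrementing to 'block', preferring one sharing a row or column offset.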
5323 | static inline int findBaseBlock(const RegisterBlock &block, |
5324 | const vector<RegisterBlock> &layout, int start, int end, |
5325 | const MatrixAddressing &atype, |
5326 | const MatrixAddressingStrategy &astrategy) { |
5327 | int bbase = -1; |
5328 | for (int bb = start; bb < end; bb++) { |
5329 | auto &other = layout[bb]; |
5330 | if (canIncAddr(other, block, atype, astrategy)) { |
5331 | if (bbase < 0) bbase = bb; |
5332 | if (other.offsetR == block.offsetR |
5333 | || other.offsetC == block.offsetC) |
5334 | return bb; // "Best fit" |
5335 | } |
5336 | } |
5337 | return bbase; |
5338 | } |
5339 | |
5340 | // Output code for initializing address/header GRFs for an entire register layout. |
5341 | // ptr is an integer, Subregister, or MultishiftSubregister holding the base pointer/offset. |
5342 | template <HW hw> |
5343 | template <typename BO> |
5344 | void gemm_kernel_generator_t<hw>::setupAddr(Type T, |
5345 | const vector<GRFRange> &addr, const BO &ptr, |
5346 | const vector<RegisterBlock> &layout, const Subregister &ld, |
5347 | const MatrixAddressing &atype, |
5348 | const MatrixAddressingStrategy &astrategy, |
5349 | const CommonStrategy &strategy, CommonState &state, |
5350 | const Address2DParams ¶ms, const LDMultiples &ldMultiples) { |
5351 | auto nblocks = int(layout.size()); |
5352 | |
5353 | for (int b = 0; b < nblocks; b++) { |
5354 | auto &block = layout[b]; |
5355 | |
5356 | // Skip non-load blocks. |
5357 | if (!block.isLoadBlock()) continue; |
5358 | |
5359 | auto bparams = params; |
5360 | Subregister tempRem; |
5361 | if (isBlock2D(astrategy.accessType) && !astrategy.address2D) { |
5362 | tempRem = state.ra.alloc_sub<uint32_t>(); |
5363 | if (bparams.remR.isValid()) bparams.remR = tempRem.uw(0); |
5364 | if (bparams.remC.isValid()) bparams.remC = tempRem.uw(1); |
5365 | if (bparams.remR.isValid() && block.offsetR) |
5366 | add(1 | sat, bparams.remR, params.remR, -block.offsetR); |
5367 | if (bparams.remC.isValid() && block.offsetC) |
5368 | add(1 | sat, bparams.remC, params.remC, -block.offsetC); |
5369 | if (bparams.remR.isValid()) |
5370 | min_(1, bparams.remR, |
5371 | block.offsetR ? bparams.remR : params.remR, block.nr); |
5372 | if (bparams.remC.isValid()) |
5373 | min_(1, bparams.remC, |
5374 | block.offsetC ? bparams.remC : params.remC, block.nc); |
5375 | } |
5376 | // Look for a block to base this one off of. |
5377 | int bbase = findBaseBlock(block, layout, 0, b, atype, astrategy); |
5378 | |
5379 | if (bbase < 0) { |
5380 | // No base address, set up a new base address. |
5381 | setupAddr(addr[b], ptr, block, ld, T.size(), atype, astrategy, |
5382 | strategy, state, bparams, ldMultiples); |
5383 | state.ra.safeRelease(tempRem); |
5384 | } |
5385 | |
5386 | // Increment as appropriate. |
5387 | if (bbase >= 0) { |
5388 | setupAddrRel(T, addr[b], addr[bbase], block, layout[bbase], layout, |
5389 | ld, atype, astrategy, strategy, state, ldMultiples); |
5390 | } else if (!astrategy.address2D) { |
5391 | int offsetFixed = 0, offsetLD = 0, r = 0, c = 0; |
5392 | if (isPacked(atype.layout)) getLayoutDims(layout, r, c); |
5393 | switch (atype.layout) { |
5394 | case MatrixLayout::N: |
5395 | offsetFixed = block.offsetR; |
5396 | offsetLD = block.offsetC; |
5397 | break; |
5398 | case MatrixLayout::T: |
5399 | offsetFixed = block.offsetC; |
5400 | offsetLD = block.offsetR; |
5401 | break; |
5402 | case MatrixLayout::Pc: |
5403 | case MatrixLayout::Pr: |
5404 | offsetFixed = untile(T, atype, block, r, c); |
5405 | break; |
5406 | } |
5407 | |
5408 | offsetFixed *= T.size(); |
5409 | |
5410 | offsetAddr(addr[b], addr[b], block, block, offsetFixed, offsetLD, |
5411 | ld, atype, astrategy, strategy, state, ldMultiples); |
5412 | } |
5413 | } |
5414 | } |
5415 | |
5416 | // Output code for incrementing the pointers for a given block by a specified # of bytes. |
5417 | // The amount may be an immediate, Subregister, or MultishiftSubregister. |
5418 | template <HW hw> |
5419 | template <typename I, typename Ir, typename Ic> |
5420 | void gemm_kernel_generator_t<hw>::incAddr(const GRFRange &addrDst, |
5421 | const GRFRange &addrSrc, I inc, Ir incR, Ic incC, |
5422 | const RegisterBlock &layoutDst, const RegisterBlock &layoutSrc, |
5423 | const MatrixAddressing &atype, |
5424 | const MatrixAddressingStrategy &astrategy, |
5425 | const CommonStrategy &strategy, CommonState &state) { |
5426 | auto incShifted = startShift(inc, layoutDst.addrShift, state); |
5427 | |
5428 | incAddrShifted(addrDst, addrSrc, incShifted, incR, incC, layoutDst, |
5429 | layoutSrc, atype, astrategy, strategy, state); |
5430 | |
5431 | doneShift(inc, incShifted, layoutDst.addrShift, state); |
5432 | } |
5433 | |
5434 | template <HW hw> |
5435 | template <typename I> |
5436 | void gemm_kernel_generator_t<hw>::incAddr(const GRFRange &addrDst, |
5437 | const GRFRange &addrSrc, I inc, const RegisterBlock &layoutDst, |
5438 | const RegisterBlock &layoutSrc, const MatrixAddressing &atype, |
5439 | const MatrixAddressingStrategy &astrategy, |
5440 | const CommonStrategy &strategy, CommonState &state) { |
5441 | if (astrategy.address2D) stub(); |
5442 | incAddr(addrDst, addrSrc, inc, Subregister(), Subregister(), layoutDst, |
5443 | layoutSrc, atype, astrategy, strategy, state); |
5444 | } |
5445 | |
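// Output code for incrementing the pointers for a given block by a
//  pre-shifted amount.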
5446 | template <HW hw> |
5447 | template <typename I, typename Ir, typename Ic> |
5448 | void gemm_kernel_generator_t<hw>::incAddrShifted(const GRFRange &addrDst, |
5449 | const GRFRange &addrSrc, I inc, Ir incR, Ic incC, |
5450 | const RegisterBlock &layoutDst, const RegisterBlock &layoutSrc, |
5451 | const MatrixAddressing &atype, |
5452 | const MatrixAddressingStrategy &astrategy, |
5453 | const CommonStrategy &strategy, CommonState &state) { |
5454 | // Handle non-load blocks. |
5455 | if (!layoutDst.isLoadBlock()) return; |
5456 | if (!layoutSrc.isLoadBlock()) stub(); |
5457 | |
5458 | if (layoutDst.addrShift != layoutSrc.addrShift) stub(); |
5459 | |
5460 | auto cinc = avoidConflict(hw, inc, addrSrc[0]); |
5461 | auto cincR = avoidConflict(hw, incR, addrSrc[0]); |
5462 | auto cincC = avoidConflict(hw, incC, addrSrc[0]); |
5463 | |
5464 | switch (effectiveAccessType(atype, astrategy, layoutSrc)) { |
5465 | case AccessType::PseudoBlock: |
5466 | if (layoutSrc.ebytes != layoutDst.ebytes) stub(); |
5467 | // fall through |
5468 | case AccessType::ChannelScattered: |
5469 | case AccessType::Scattered: { |
5470 | int naddrDst = layoutDst.simdSize; |
5471 | int naddrSrc = layoutSrc.simdSize; |
5472 | if (naddrDst > naddrSrc) stub(); |
5473 | if (astrategy.base.getModel() == ModelA64) { |
5474 | auto simd = 2 * elementsPerGRF(hw, Type::u64); |
5475 | for (int ar = 0; naddrDst > 0; ar += 2, naddrDst -= simd) |
5476 | eadd<uint64_t>(std::min(naddrDst, simd), addrDst[ar], |
5477 | addrSrc[ar], avoidConflict(hw, inc, addrSrc[ar]), |
5478 | strategy, state); |
5479 | } else |
5480 | add<uint32_t>(naddrDst, addrDst[0], addrSrc[0], cinc); |
5481 | break; |
5482 | } |
5483 | case AccessType::Block: |
5484 | if (astrategy.base.getModel() == ModelA64) { |
5485 | eadd(1, addrDst[0].uq(0), addrSrc[0].uq(0), cinc, strategy, |
5486 | state); |
5487 | if (addrDst != addrSrc && layoutDst.ebytes == 32 |
5488 | && hw < HW::Gen10) |
5489 | mov(1, addrDst[0].ud(5), |
5490 | uint32_t( |
5491 | 0x80000000)); // Disable OWord channel mode on SKL. |
5492 | } else if (astrategy.newDP) { |
5493 | add(1, addrDst[0].ud(0), addrSrc[0].ud(0), cinc); |
5494 | } else |
5495 | add(1, addrDst[0].ud(2), addrSrc[0].ud(2), cinc); |
5496 | break; |
5497 | case AccessType::Block2D: |
5498 | case AccessType::Block2DTranspose: |
5499 | case AccessType::Block2DVNNI: |
5500 | if (addrDst != addrSrc) mov<uint32_t>(8, addrDst[0], addrSrc[0]); |
5501 | if (astrategy.address2D) { |
5502 | if (isColMajor(atype.layout)) { |
5503 | if (cincR != 0) |
5504 | addScaled(1, addrDst[0].d(5), addrDst[0].d(5), cincR, |
5505 | layoutDst.extra, layoutDst.ebytes, state, true); |
5506 | if (cincC != 0) |
5507 | add(1, addrDst[0].d(6), addrDst[0].d(6), cincC); |
5508 | } else { |
5509 | if (cincC != 0) |
5510 | addScaled(1, addrDst[0].d(5), addrDst[0].d(5), cincC, |
5511 | layoutDst.extra, layoutDst.ebytes, state, true); |
5512 | if (cincR != 0) |
5513 | add(1, addrDst[0].d(6), addrDst[0].d(6), cincR); |
5514 | } |
5515 | } else |
5516 | eadd(1, addrDst[0].uq(0), addrSrc[0].uq(0), cinc, strategy, |
5517 | state); |
5518 | break; |
5519 | } |
5520 | } |
5521 | |
5522 | // Output code for incrementing all pointers for a register layout by a specified # of bytes. |
5523 | // The amount may be an immediate or a subregister. |
5524 | template <HW hw> |
5525 | template <typename I, typename Ir, typename Ic> |
5526 | void gemm_kernel_generator_t<hw>::incAddr(const vector<GRFRange> &addr, I inc, |
5527 | Ir incR, Ic incC, const vector<RegisterBlock> &layout, |
5528 | const MatrixAddressing &atype, |
5529 | const MatrixAddressingStrategy &astrategy, |
5530 | const CommonStrategy &strategy, CommonState &state) { |
5531 | auto nblocks = int(layout.size()); |
5532 | |
5533 | for (int b = 0; b < nblocks; b++) |
5534 | incAddr(addr[b], addr[b], inc, incR, incC, layout[b], layout[b], atype, |
5535 | astrategy, strategy, state); |
5536 | } |
5537 | |
5538 | template <HW hw> |
5539 | template <typename I> |
5540 | void gemm_kernel_generator_t<hw>::incAddr(const vector<GRFRange> &addr, I inc, |
5541 | const vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
5542 | const MatrixAddressingStrategy &astrategy, |
5543 | const CommonStrategy &strategy, CommonState &state) { |
5544 | if (astrategy.address2D) stub(); |
5545 | incAddr(addr, inc, Subregister(), Subregister(), layout, atype, astrategy, |
5546 | strategy, state); |
5547 | } |
5548 | |
5549 | template <HW hw> |
5550 | template <typename I, typename Ir, typename Ic> |
5551 | void gemm_kernel_generator_t<hw>::incAddrShifted(const vector<GRFRange> &addr, |
5552 | I inc, Ir incR, Ic incC, const vector<RegisterBlock> &layout, |
5553 | const MatrixAddressing &atype, |
5554 | const MatrixAddressingStrategy &astrategy, |
5555 | const CommonStrategy &strategy, CommonState &state) { |
5556 | auto nblocks = int(layout.size()); |
5557 | |
5558 | for (int b = 0; b < nblocks; b++) |
5559 | incAddrShifted(addr[b], addr[b], inc, incR, incC, layout[b], layout[b], |
5560 | atype, astrategy, strategy, state); |
5561 | } |
5562 | |
5563 | template <HW hw> |
5564 | template <typename I> |
5565 | void gemm_kernel_generator_t<hw>::incAddrShifted(const vector<GRFRange> &addr, |
5566 | I inc, const vector<RegisterBlock> &layout, |
5567 | const MatrixAddressing &atype, |
5568 | const MatrixAddressingStrategy &astrategy, |
5569 | const CommonStrategy &strategy, CommonState &state) { |
5570 | if (astrategy.address2D) stub(); |
5571 | incAddrShifted(addr, inc, Subregister(), Subregister(), layout, atype, |
5572 | astrategy, strategy, state); |
5573 | } |
5574 | |
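// Map unsigned integer types to their signed counterparts so that increments
//  can be negated when decrementing (see incDecAddr below).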
5575 | template <typename T> |
5576 | struct NegativeType { |
5577 | typedef T type; |
5578 | }; |
5579 | template <> |
5580 | struct NegativeType<uint8_t> { |
5581 | typedef int8_t type; |
5582 | }; |
5583 | template <> |
5584 | struct NegativeType<uint16_t> { |
5585 | typedef int16_t type; |
5586 | }; |
5587 | template <> |
5588 | struct NegativeType<uint32_t> { |
5589 | typedef int32_t type; |
5590 | }; |
5591 | template <> |
5592 | struct NegativeType<int> { |
5593 | typedef int32_t type; |
5594 | }; |
5595 | template <> |
5596 | struct NegativeType<int64_t> { |
5597 | typedef int32_t type; |
5598 | }; |
5599 | |
5600 | // Output code for incrementing or decrementing all pointers for a register layout by a specified # of bytes. |
5601 | // The amount may be an immediate or a MultishiftSubregister. |
5602 | template <HW hw> |
5603 | template <typename A, typename I, typename Ir, typename Ic> |
5604 | void gemm_kernel_generator_t<hw>::incDecAddr(const A &addr, I inc, Ir incR, |
5605 | Ic incC, const vector<RegisterBlock> &layout, |
5606 | const MatrixAddressing &atype, |
5607 | const MatrixAddressingStrategy &astrategy, |
5608 | const CommonStrategy &strategy, CommonState &state, bool decrement) { |
5609 | typename NegativeType<I>::type signedInc = decrement ? -inc : inc; |
5610 | typename NegativeType<Ir>::type signedIncR = decrement ? -incR : incR; |
5611 | typename NegativeType<Ic>::type signedIncC = decrement ? -incC : incC; |
5612 | |
5613 | incAddr(addr, signedInc, signedIncR, signedIncC, layout, atype, astrategy, |
5614 | strategy, state); |
5615 | } |
5616 | |
5617 | template <HW hw> |
5618 | template <typename A, typename I> |
5619 | void gemm_kernel_generator_t<hw>::incDecAddr(const A &addr, I inc, |
5620 | const vector<RegisterBlock> &layout, const MatrixAddressing &atype, |
5621 | const MatrixAddressingStrategy &astrategy, |
5622 | const CommonStrategy &strategy, CommonState &state, bool decrement) { |
5623 | if (astrategy.address2D) stub(); |
5624 | incDecAddr(addr, inc, Subregister(), Subregister(), layout, atype, |
5625 | astrategy, strategy, state, decrement); |
5626 | } |
5627 | |
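// Output code for updating the width/height fields of a block 2D address
//  header to account for row/column remainders.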
5628 | template <HW hw> |
5629 | void gemm_kernel_generator_t<hw>::setAddrRemainder(Type T, const GRFRange &addr, |
5630 | const RegisterBlock &block, const Subregister &remR, |
5631 | const Subregister &remC, const MatrixAddressing &atype, |
5632 | const MatrixAddressingStrategy &astrategy, |
5633 | const CommonStrategy &strategy, CommonState &state) { |
5634 | if (!isBlock2D(astrategy.accessType) || astrategy.address2D) return; |
5635 | |
5636 | auto tempRem = state.ra.alloc_sub<uint32_t>(); |
5637 | Subregister thisRemR = remR, thisRemC = remC; |
5638 | |
5639 | auto memCM = isColMajor(atype.layout); |
5640 | auto &remW = memCM ? thisRemR : thisRemC; |
5641 | auto &remH = memCM ? thisRemC : thisRemR; |
5642 | int bw, bh, multiX; |
5643 | getBlock2DWH(bw, bh, atype, block, &multiX); |
5644 | |
5645 | if (!block.remainderR) thisRemR.invalidate(); |
5646 | if (!block.remainderC) thisRemC.invalidate(); |
5647 | if (thisRemR.isValid()) thisRemR = tempRem.uw(0); |
5648 | if (thisRemC.isValid()) thisRemC = tempRem.uw(1); |
5649 | if (thisRemR.isValid() && block.offsetR) |
5650 | add(1 | sat, thisRemR, remR, -block.offsetR); |
5651 | if (thisRemC.isValid() && block.offsetC) |
5652 | add(1 | sat, thisRemC, remC, -block.offsetC); |
5653 | if (thisRemR.isValid()) |
5654 | min_(1, thisRemR, block.offsetR ? thisRemR : remR, block.nr); |
5655 | if (thisRemC.isValid()) |
5656 | min_(1, thisRemC, block.offsetC ? thisRemC : remC, block.nc); |
5657 | |
5658 | if (remW.isValid()) { |
5659 | if (block.count > 1 || multiX > 1) stub(); |
5660 | mad(1, addr[0].ud(2), -1, remW.uw(), T.size()); |
5661 | } |
5662 | if (remH.isValid()) mad(1, addr[0].ud(3), -1, remH.uw(), T.size() * multiX); |
5663 | if (remW.isValid() && T.size() < 4) or_(1, addr[0].ud(2), addr[0].ud(2), 3); |
5664 | |
5665 | state.ra.safeRelease(tempRem); |
5666 | } |
5667 | |
5668 | template <HW hw> |
5669 | void gemm_kernel_generator_t<hw>::setAddrRemainder(Type T, |
5670 | const vector<GRFRange> &addr, const vector<RegisterBlock> &layout, |
5671 | const Subregister &remR, const Subregister &remC, |
5672 | const MatrixAddressing &atype, |
5673 | const MatrixAddressingStrategy &astrategy, |
5674 | const CommonStrategy &strategy, CommonState &state) { |
5675 | auto nblocks = int(layout.size()); |
5676 | |
5677 | for (int b = 0; b < nblocks; b++) |
5678 | setAddrRemainder(T, addr[b], layout[b], remR, remC, atype, astrategy, |
5679 | strategy, state); |
5680 | } |
5681 | |
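// Set up (setup = true) or release (setup = false) GRFs holding a remask:
//  a per-element all-ones/all-zeros pattern for zeroing elements at or
//  beyond the remainder remQ.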
5682 | template <HW hw> |
5683 | void gemm_kernel_generator_t<hw>::setupTeardownRemask(Type T, int index, |
5684 | bool setup, int nq, const Subregister &remQ, |
5685 | const CommonStrategy &strategy, CommonState &state, int fixedOffQ, |
5686 | const Subregister &variableOffQ) { |
5687 | if (setup) { |
5688 | auto masks = state.remaskRegs[index] = state.ra.alloc_range( |
5689 | div_up(T.size(), 2) * div_up(nq * 2, GRF::bytes(hw))); |
5690 | int ne16 = elementsPerGRF(hw, Type::u16); |
5691 | int n16 = std::min(nq, ne16); |
5692 | int ne = elementsPerGRF(hw, T); |
5693 | auto flag = state.raVFlag.tryAlloc((n16 > 16) ? 2 : 1); |
5694 | bool useCMP = flag.isValid() |
5695 | && (T.size() < 4); // apparent issues with 4b sequence |
5696 | |
5697 | auto effRemQ = remQ; |
5698 | bool freeEffRemQ = false; |
5699 | bool haveVariableOff = variableOffQ.isValid(); |
5700 | bool haveFixedOff = (fixedOffQ != 0); |
5701 | |
5702 | if (haveVariableOff || haveFixedOff) { |
5703 | freeEffRemQ = true; |
5704 | effRemQ = state.ra.alloc_sub<uint32_t>(); |
5705 | |
5706 | if (haveVariableOff && haveFixedOff) |
5707 | eadd3(1, effRemQ, remQ, -variableOffQ, -fixedOffQ); |
5708 | else if (haveVariableOff) |
5709 | add(1, effRemQ, remQ, -variableOffQ); |
5710 | else |
5711 | add(1, effRemQ, remQ, -fixedOffQ); |
5712 | } |
5713 | |
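        // Build the mask from (index - remQ) per element: an arithmetic shift
        //  right by 15 then yields all-ones for in-bounds elements
        //  (index < remQ) and all-zeros otherwise.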
5714 | mov<uint16_t>(8, masks[0][0](1), Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7)); |
5715 | if (nq > 8) |
5716 | mov<uint16_t>(8, masks[0][8](1), |
5717 | Immediate::uv(8, 9, 10, 11, 12, 13, 14, 15)); |
5718 | if (GRF::bytes(hw) > 32 && nq > 16) |
5719 | add<uint16_t>(16, masks[0][16](1), masks[0][0](1), 16); |
5720 | add<uint16_t>(n16, masks[0], masks[0], -effRemQ.w()); |
5721 | if (!useCMP) |
5722 | for (int q0 = n16; q0 < nq; q0 += n16) |
5723 | add<uint16_t>(n16, masks[q0 / n16], masks[0], q0); |
5724 | |
5725 | switch (T.size()) { |
5726 | case 1: |
5727 | case 2: |
5728 | if (useCMP) { |
5729 | for (int q0 = n16; q0 < nq; q0 += n16) |
5730 | cmp<int16_t>(n16 | lt | flag, masks[q0 / n16], masks[0], |
5731 | -q0); |
5732 | asr<int16_t>(n16, masks[0], masks[0], 15); |
5733 | } else { |
5734 | map(hw, Type::s16, masks, masks, strategy, |
5735 | [=](int simd, const RegData &r1, const RegData &) { |
5736 | asr(simd, r1, r1, 15); |
5737 | }); |
5738 | } |
5739 | if (T.size() == 1) |
5740 | for (int q0 = 0; q0 < nq; q0 += n16) |
5741 | mov(n16, masks[q0 / ne].ub(q0 % ne)(1), |
5742 | masks[q0 / n16].ub(1)(2)); |
5743 | break; |
5744 | case 4: |
5745 | for (int qq0 = div_up(nq, ne16) - 1; qq0 >= 1; qq0--) { |
5746 | useCMP ? cmp(ne16 | lt | flag, masks[qq0 * 2].d(), |
5747 | masks[qq0].w(), -qq0 * ne16) |
5748 | : asr(ne16, masks[qq0 * 2].d(), masks[qq0].w(), 15); |
5749 | } |
5750 | if (nq > (ne16 / 2)) |
5751 | asr(ne16 / 2, masks[1].d(), masks[0].w(ne16 / 2)(1), 15); |
5752 | asr(ne16 / 2, masks[0].d(), masks[0].w(), 15); |
5753 | break; |
5754 | default: stub(); |
5755 | } |
5756 | |
5757 | if (freeEffRemQ) state.ra.safeRelease(effRemQ); |
5758 | state.raVFlag.safeRelease(flag); |
5759 | } else |
5760 | state.ra.safeRelease(state.remaskRegs[index]); |
5761 | } |
5762 | |
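// Apply previously set up remask registers to a register layout, zeroing
//  elements past the remainder in the given dimension.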
5763 | template <HW hw> |
5764 | void gemm_kernel_generator_t<hw>::remaskLayout(Type T, int index, bool column, |
5765 | const std::vector<RegisterBlock> &layout, const GRFMultirange ®s, |
5766 | const CommonStrategy &strategy, CommonState &state, int offset) { |
5767 | for (auto &block : layout) { |
5768 | auto crosspack = block.crosspack; |
5769 | bool colMajor = block.colMajor; |
5770 | int nx = colMajor ? block.nr : block.nc; |
5771 | int ny = colMajor ? block.nc : block.nr; |
5772 | |
5773 | for (int y0 = 0; y0 < ny; y0 += crosspack) { |
5774 | for (int x0 = 0; x0 < nx;) { |
5775 | auto ii0 = colMajor ? x0 : y0; |
5776 | auto jj0 = colMajor ? y0 : x0; |
5777 | auto i0 = ii0 + block.offsetR; |
5778 | auto j0 = jj0 + block.offsetC; |
5779 | |
5780 | int ne; |
5781 | auto sub = findBlockReg( |
5782 | T, block, ii0, jj0, regs, ne, -1, block.component); |
5783 | |
5784 | auto necp = ne * crosspack; |
5785 | necp = std::min(necp, 2 * elementsPerGRF(hw, T)); |
5786 | if ((necp * T) & 3) stub(); |
5787 | |
5788 | int mstride; |
5789 | Type mtype = Type::u32; |
5790 | |
5791 | if (colMajor != column && crosspack == 1) |
5792 | mstride = 1; |
5793 | else if (colMajor != column && crosspack == 4 / T) |
5794 | mstride = 1, mtype = sintType(T); |
5795 | else if (colMajor == column && crosspack == 4 / T) |
5796 | mstride = 0; |
5797 | else |
5798 | stub(); |
5799 | |
5800 | int moff = (offset + (column ? j0 : i0)) * T / mtype; |
5801 | int mreg = moff / elementsPerGRF(hw, mtype); |
5802 | int msub = moff % elementsPerGRF(hw, mtype); |
5803 | |
5804 | and_<uint32_t>((necp * T) / 4, sub.ud()(1), sub.ud()(1), |
5805 | state.remaskRegs[index][mreg].sub(msub, mtype.ngen())( |
5806 | mstride)); |
5807 | x0 += necp / crosspack; |
5808 | } |
5809 | } |
5810 | } |
5811 | } |
5812 | |
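// Check whether a block needs remasking in registers, i.e. whether its mask
//  granularity is coarser than the element size, so that remainders cannot
//  be handled by load/store masks alone.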
5813 | static bool needsRemask(Type T, bool column, const RegisterBlock &block, |
5814 | const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) { |
5815 | if (!ignoreMasks) |
5816 | if (column ? !block.remainderC : !block.remainderR) return false; |
5817 | |
5818 | bool block2D = isBlock2D(astrategy.accessType); |
5819 | |
5820 | int maskGranularity = block.ebytes; |
5821 | if (block.ebytes >= 16) maskGranularity = 4; |
5822 | if (block2D) maskGranularity = std::max(maskGranularity, 4); |
5823 | if (ignoreMasks && !(block2D && astrategy.address2D)) maskGranularity = 256; |
5824 | |
5825 | return (T.size() < maskGranularity); |
5826 | } |
5827 | |
5828 | static bool needsRemask(Type T, bool column, |
5829 | const vector<RegisterBlock> &layout, |
5830 | const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) { |
5831 | for (auto &block : layout) |
5832 | if (needsRemask(T, column, block, astrategy, ignoreMasks)) return true; |
5833 | return false; |
5834 | } |
5835 | |
5836 | // The systolic array performs a series of GEMVs with a single fixed-size matrix. |
5837 | // The size of the matrix is osys x ksys with vectors of size ksys x 1. |
5838 | // The number of GEMVs (with same matrix) is given by the (variable) repeat count. |
5839 | struct SystolicParams { |
5840 | int opsPerChan; // # of FMAs/stage |
5841 | int sdepth; // Number of stages (systolic depth) |
5842 | int rcountMax; // Maximum repeat count (# of RHS) |
5843 | int ksys; // Total number of FMAs |
5844 | int osys; // Output vector length |
5845 | }; |
5846 | |
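// Return the systolic array parameters (depth, ops per channel, etc.)
//  for a given problem and strategy.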
5847 | static inline SystolicParams systolicParams( |
5848 | HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) { |
5849 | SystolicParams params; |
5850 | params.opsPerChan = std::max( |
5851 | 1, std::min(4 / problem.Ta.real(), 4 / problem.Tb.real())); |
5852 | params.sdepth = 8; |
5853 | params.ksys = params.sdepth * params.opsPerChan; |
5854 | params.osys = GRF::bytes(hw) / std::max(problem.Tc.real().size(), 4); |
5855 | params.rcountMax = 8; |
5856 | |
5857 | if (hw == HW::XeHPC) { |
5858 | // Workaround for src2 read suppression bug (TODO PVC-B: remove WA) |
5859 | bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C); |
5860 | if (strategy.unroll[cColMajor ? LoopN : LoopM] == 8) |
5861 | params.rcountMax = 4; |
5862 | } |
5863 | |
5864 | return params; |
5865 | } |
5866 | |
// Return minimum # of outer products performed at once.
5868 | static inline int minOuterProductCount( |
5869 | HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) { |
5870 | auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc; |
5871 | if (strategy.systolic) { |
5872 | auto params = systolicParams(hw, problem, strategy); |
5873 | return params.ksys; |
5874 | } |
5875 | if (Ta.real().size() == 1 && Tb.real().size() == 1 && Tc.real().size() == 4 |
5876 | && (hw >= HW::Gen12LP)) |
5877 | return 4; |
5878 | return 1; |
5879 | } |
5880 | |
// Return total # of outer products performed at once, including the k-chain factor.
5882 | static inline int outerProductCount( |
5883 | HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) { |
5884 | return minOuterProductCount(hw, problem, strategy) * strategy.kChain; |
5885 | } |
5886 | |
5887 | // Get the A and B crosspacks needed by the kernel. 0 indicates any crosspack is OK. |
5888 | static std::tuple<int, int> targetKernelCrosspack( |
5889 | HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) { |
5890 | int opBatch = minOuterProductCount(hw, problem, strategy); |
5891 | bool aColMajor = isRegisterColMajor(problem.Ta, problem.A, strategy.A); |
5892 | bool bColMajor = isRegisterColMajor(problem.Tb, problem.B, strategy.B); |
5893 | bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C); |
5894 | |
5895 | if (strategy.systolic) { |
5896 | return cColMajor |
5897 | ? std::make_tuple(std::max(1, 4 / problem.Ta.real()), 1) |
5898 | : std::make_tuple(1, std::max(1, 4 / problem.Tb.real())); |
5899 | } |
5900 | if (opBatch == 1) { |
5901 | return cColMajor ? std::make_tuple(1, 0) : std::make_tuple(0, 1); |
5902 | } else { |
5903 | bool bcastOK = cColMajor ? bColMajor : !aColMajor; |
5904 | |
5905 | return cColMajor ? std::make_tuple(opBatch, bcastOK ? 1 : opBatch) |
5906 | : std::make_tuple(bcastOK ? 1 : opBatch, opBatch); |
5907 | } |
5908 | } |
5909 | |
5910 | // Get the A and B crosspacks to use for SLM data. |
5911 | static std::tuple<int, int> targetSLMCrosspack( |
5912 | HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) { |
5913 | int opBatch = minOuterProductCount(hw, problem, strategy); |
5914 | |
5915 | if (strategy.systolic) { |
5916 | bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C); |
5917 | return cColMajor |
5918 | ? std::make_tuple(std::max(1, 4 / problem.Ta.size()), opBatch) |
5919 | : std::make_tuple(opBatch, std::max(1, 4 / problem.Tb.size())); |
5920 | } |
5921 | return std::make_tuple(opBatch, opBatch); |
5922 | } |
5923 | |
5924 | // Get the A and B tiling needed by the kernel. |
5925 | // Return value is in the format {A_tileR, A_tileC, B_tileR, B_tileC}. |
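//  (e.g., a systolic f16 kernel on XeHPC with column-major C typically gives {16, 0, 16, 0})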
5926 | static std::tuple<int, int, int, int> targetKernelTiling( |
5927 | HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) { |
5928 | if (strategy.systolic) { |
5929 | auto params = systolicParams(hw, problem, strategy); |
5930 | bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C); |
5931 | auto tileO_V = params.osys; |
5932 | auto tileI_N = params.ksys; |
5933 | if (strategy.unroll[cColMajor ? LoopN : LoopM] == 1) tileI_N = 0; |
5934 | return cColMajor ? std::make_tuple(tileO_V, 0, tileI_N, 0) |
5935 | : std::make_tuple(0, tileI_N, 0, tileO_V); |
5936 | } |
5937 | return std::make_tuple(0, 0, 0, 0); |
5938 | } |
5939 | |
5940 | // Do one outer product (k = 1 slice) of A*B, updating C. ha and hb are the |
5941 | // k indices within the A and B chunks, respectively. A_copy, B_copy are the |
5942 | // indices of the A, B copies to use. |
5943 | template <HW hw> |
5944 | void gemm_kernel_generator_t<hw>::outerProduct(int h, int ha, int hb, |
5945 | int opCount, const vector<RegisterBlock> &A_layout, |
5946 | const vector<RegisterBlock> &B_layout, const GRFMultirange &A_regs, |
5947 | const GRFMultirange &B_regs, const GEMMProblem &problem, |
5948 | const GEMMStrategy &strategy, GEMMState &state) { |
5949 | auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc; |
5950 | |
5951 | if (strategy.systolic) { |
5952 | outerProductSystolic(h, ha, hb, A_layout, B_layout, A_regs, B_regs, |
5953 | problem, strategy, state); |
5954 | return; |
5955 | } |
5956 | if (isGen9IGEMM(hw, Ta, Tb, Tc)) { |
5957 | outerProductGen9IGEMM(ha, hb, A_layout, B_layout, A_regs, B_regs, |
5958 | problem, strategy, state); |
5959 | return; |
5960 | } |
5961 | |
5962 | bool mixedMode = ((Tc.real() == Type::f32) |
5963 | && (Ta.real() != Type::f32 || Tb.real() != Type::f32)); |
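    // dp4a does a 4-way byte dot product per channel, accumulating into
    //  32-bit C: C += A.b0*B.b0 + A.b1*B.b1 + A.b2*B.b2 + A.b3*B.b3.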
5964 | bool useDP4A = (Ta.size() == 1 && Tb.size() == 1 && Tc.size() == 4 |
5965 | && hw >= HW::Gen12LP); |
5966 | |
5967 | int minOPCount = minOuterProductCount(hw, problem, strategy); |
5968 | int kChain = std::min(strategy.kChain, opCount); |
5969 | int aCP, bCP; |
5970 | std::tie(aCP, bCP) = targetKernelCrosspack(hw, problem, strategy); |
5971 | |
5972 | int accNum = 0; |
5973 | Subregister Clast; |
5974 | int nec = elementsPerGRF(hw, Tc); |
5975 | bool globalCM = isLayoutColMajor(state.C_layout); |
5976 | int fmaSIMD = strategy.fmaSIMD; |
5977 | |
5978 | bool csplit = false, mixedRC = false; |
5979 | int icompCount = 1, ocompCount = 1, ivcompCount = 1, ovcompCount = 1; |
5980 | |
5981 | bool bfloat16WA = (Tc.real() == Type::f32) |
5982 | && ((globalCM ? Tb : Ta).real() == Type::bf16); |
5983 | |
5984 | // Emit an FMA instruction. |
5985 | auto outputFMA = [&](const InstructionModifier &mod, const Subregister &A, |
5986 | const Subregister &B, const Subregister &C, |
5987 | const RegData &bcastSrc, bool colMajor, int hh, |
5988 | bool ivfirst, bool ivlast) { |
5989 | auto Cacc = AccumulatorRegister(accNum).sub(0, Tc.real().ngen()); |
5990 | auto Csrc = (hh == 0 && ivfirst) ? C : Cacc; |
5991 | auto Cdst = (hh == opCount - minOPCount && ivlast) ? C : Cacc; |
5992 | if (useDP4A) { |
5993 | auto Ar = A.reinterpret( |
5994 | 0, isSigned(A.getType()) ? DataType::d : DataType::ud); |
5995 | auto Br = B.reinterpret( |
5996 | 0, isSigned(B.getType()) ? DataType::d : DataType::ud); |
5997 | |
5998 | colMajor ? dp4a(mod, Cdst(1), Csrc(1), Ar(1), Br(0)) |
5999 | : dp4a(mod, Cdst(1), Csrc(1), Br(1), Ar(0)); |
6000 | } else if (C.isARF() && hw < HW::XeHP) { |
6001 | colMajor ? mac(mod, C(1), A(1), bcastSrc) |
6002 | : mac(mod, C(1), bcastSrc, B(1)); |
6003 | } else { |
6004 | // On Gen12, always put broadcast in src2 for better bank conflict avoidance. |
6005 | colMajor ? mad(mod, Cdst(1), Csrc(1), A(1), bcastSrc) |
6006 | : (hw < HW::Gen12LP) |
6007 | ? mad(mod, Cdst(1), Csrc(1), bcastSrc, B(1)) |
6008 | : mad(mod, Cdst(1), Csrc(1), B(1), bcastSrc); |
6009 | } |
6010 | }; |
6011 | |
6012 | ha = align_down(ha, opCount); |
6013 | hb = align_down(hb, opCount); |
6014 | |
6015 | // Decide whether to loop in column or row major order. |
6016 | // x = vectorized dimension |
6017 | // y = non-vectorized dimension |
6018 | int nx = globalCM ? strategy.unroll[LoopM] : strategy.unroll[LoopN]; |
6019 | int ny = globalCM ? strategy.unroll[LoopN] : strategy.unroll[LoopM]; |
6020 | int nx1 = (mixedMode || state.broadcast) ? nx : fmaSIMD; |
6021 | |
6022 | // Prepare for chaining FMAs through accumulator registers. |
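    //  Each FMA takes accPerFMA accumulator registers (at least 2 on integer
    //  pipes); independentAccs bounds how many chains can be in flight at once.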
6023 | int necAcc = nec * (csplit ? 2 : 1); |
6024 | int accCount = AccumulatorRegister::count(hw, strategy.GRFs, Tc.ngen()); |
6025 | int accPerFMA = div_up(std::min(nx, fmaSIMD), necAcc); |
6026 | int minAccPerFMA = Tc.isFP() ? 1 : 2; |
6027 | accPerFMA = std::max(accPerFMA, minAccPerFMA); |
6028 | int independentAccs = div_up(accCount, accPerFMA); |
6029 | |
6030 | int nx1i = 1, ny1 = 1; |
6031 | if (kChain > 1) { |
6032 | if (independentAccs < icompCount) hw_unsupported(); |
6033 | int indepAccComp = div_up(independentAccs, icompCount); |
6034 | |
6035 | nx1i = std::min(nx1, indepAccComp * fmaSIMD); |
6036 | ny1 = div_up(indepAccComp, div_up(nx1i, fmaSIMD)); |
6037 | } |
6038 | |
6039 | GRFRange broadcastRegs = state.broadcast_regs; |
6040 | Subregister lastBcastBase; |
6041 | |
    // Last A/B blocks found.
6043 | const RegisterBlock *A_blockLast = nullptr, *B_blockLast = nullptr; |
6044 | |
6045 | for (int x0 = 0; x0 < nx; x0 += nx1) { |
6046 | for (int ovcomp = 0; ovcomp < ovcompCount; ovcomp++) { |
6047 | for (int ocomp = 0; ocomp < ocompCount; ocomp++) { |
6048 | for (int y0 = 0; y0 < ny; y0 += ny1) { |
6049 | for (int x1 = 0; x1 < nx1 && (x0 + x1) < nx;) { |
6050 | int x1New = x1; |
6051 | for (int ivcomp = 0; ivcomp < ivcompCount; ivcomp++) { |
6052 | for (int hh = 0; hh < opCount; hh += minOPCount) { |
6053 | accNum = 0; |
6054 | for (int y1 = 0; y1 < ny1 && y0 + y1 < ny; |
6055 | y1++) { |
6056 | for (int x1i = x1; (x1i < x1 + nx1i) |
6057 | && (x0 + x1i < nx);) { |
6058 | auto x = x0 + x1i; |
6059 | auto y = y0 + y1; |
6060 | auto i = globalCM ? x : y; |
6061 | auto j = globalCM ? y : x; |
6062 | auto hha = ha + hh; |
6063 | auto hhb = hb + hh; |
6064 | |
6065 | int fmaCount = 1; |
6066 | |
6067 | for (int icomp = 0; icomp < icompCount; |
6068 | icomp++) { |
6069 | // Find the appropriate A and B registers. |
6070 | int na, nb; |
6071 | int vcomp = ivcomp + ovcomp; |
6072 | int ncomp = (vcomp ^ ocomp) + icomp; |
6073 | int compA |
6074 | = globalCM ? vcomp : ncomp; |
6075 | int compB |
6076 | = globalCM ? ncomp : vcomp; |
6077 | |
6078 | const RegisterBlock *A_block, |
6079 | *B_block; |
6080 | Subregister A = findBlockReg(Ta, |
6081 | A_layout, i, hha, A_regs, |
6082 | na, A_block, compA); |
6083 | Subregister B = findBlockReg(Tb, |
6084 | B_layout, hhb, j, B_regs, |
6085 | nb, B_block, compB); |
6086 | |
6087 | // Check for expected crosspack. |
6088 | if (globalCM ? (aCP |
6089 | && A_block->crosspack |
6090 | != aCP) |
6091 | : (bCP |
6092 | && B_block->crosspack |
6093 | != bCP)) |
6094 | stub(); |
6095 | |
6096 | // Check if we should specify {Atomic}. |
6097 | bool atomic = (strategy.atomicFMA |
6098 | && (A_block == A_blockLast) |
6099 | && (B_block |
6100 | == B_blockLast)); |
6101 | A_blockLast = A_block; |
6102 | B_blockLast = B_block; |
6103 | |
6104 | // Find the appropriate C register. |
6105 | int C_buffer = csplit |
6106 | ? 0 |
6107 | : (icomp + ocomp); |
6108 | int compC = csplit ? ocomp : 0; |
6109 | int nc; |
6110 | const RegisterBlock *C_block; |
6111 | Subregister C = findBlockReg(Tc, |
6112 | state.C_layout, i, j, |
6113 | state.C_regs[C_buffer], nc, |
6114 | C_block, compC); |
6115 | if (C_block->crosspack > 1) stub(); |
6116 | |
6117 | // Swap out C register for an accumulator, if necessary. |
6118 | if (strategy.cAccumulators) { |
6119 | auto C_roff = C.getBase() |
6120 | - state.C_regs[0] |
6121 | .ranges[0] |
6122 | .getBase(); |
6123 | if (C_roff < state.C_accCount) |
6124 | C = AccumulatorRegister( |
6125 | C_roff) |
6126 | .sub(C.getOffset(), |
6127 | Tc.ngen()); |
6128 | } |
6129 | |
6130 | InstructionModifier mod; |
6131 | |
6132 | // Use requested execution size if possible, but limited to available elements. |
6133 | // Decide on kernel type based on register block layouts. |
6134 | bool canColMajor |
6135 | = (A_block->colMajor |
6136 | && globalCM); |
6137 | bool canRowMajor |
6138 | = (!B_block->colMajor |
6139 | && !globalCM); |
6140 | bool colMajor = globalCM; |
6141 | |
6142 | if (!canColMajor && !canRowMajor) |
6143 | fmaCount = 1; |
6144 | else if (canColMajor) |
6145 | fmaCount = rounddown_pow2( |
6146 | std::min({fmaSIMD, na, |
6147 | nc})); |
6148 | else |
6149 | fmaCount = rounddown_pow2( |
6150 | std::min({fmaSIMD, nb, |
6151 | nc})); |
6152 | |
6153 | int simdSize = fmaCount; |
6154 | |
6155 | // Crosspacked kernels: ensure broadcast matrix is contiguous in k. |
6156 | if (minOPCount > 1) { |
6157 | bool nativeDir = (globalCM |
6158 | ? B_block->colMajor |
6159 | : !A_block->colMajor); |
6160 | auto bcastCrosspack |
6161 | = (globalCM ? B_block |
6162 | : A_block) |
6163 | ->crosspack; |
6164 | if (nativeDir) { |
6165 | if ((globalCM ? nb : na) |
6166 | < minOPCount) |
6167 | stub(); |
6168 | if (bcastCrosspack > 1) |
6169 | stub(); |
6170 | } else { |
6171 | if (bcastCrosspack |
6172 | % minOPCount) |
6173 | stub(); |
6174 | } |
6175 | } |
6176 | |
6177 | // Add Atomic if appropriate. |
6178 | if (atomic) mod |= Atomic; |
6179 | |
6180 | // Handle broadcast duties. |
6181 | Subregister bcastSrcSub |
6182 | = colMajor ? B : A; |
6183 | RegData bcastSrc = bcastSrcSub; |
6184 | |
6185 | if (state.broadcast) { |
6186 | |
6187 | // Broadcast if necessary: pair of doubles (doubleWA) or single elements. |
6188 | int nbcast = strategy.doubleWA |
6189 | ? 2 |
6190 | : 1; |
6191 | int hs = strategy.doubleWA |
6192 | ? 0 |
6193 | : nbcast; |
6194 | |
6195 | auto bcastType |
6196 | = bcastSrc.getType(); |
6197 | Subregister bcastBase |
6198 | = bcastSrcSub; |
6199 | bcastBase.setOffset( |
6200 | bcastBase.getOffset() |
6201 | & ~(nbcast - 1)); |
6202 | |
6203 | if (bcastBase |
6204 | != lastBcastBase) { |
6205 | auto bcastRegion = bcastBase( |
6206 | 0, nbcast, |
6207 | (nbcast > 1) ? 1 |
6208 | : 0); |
6209 | if (bfloat16WA) { |
6210 | // Upconvert to f32 during broadcast. |
6211 | bcastRegion.setType( |
6212 | DataType::uw); |
6213 | shl(simdSize, |
6214 | broadcastRegs[0] |
6215 | .ud(), |
6216 | bcastRegion, |
6217 | 16); |
6218 | } else { |
6219 | moveToIntPipe(simdSize, |
6220 | bcastRegion); |
6221 | mov(simdSize * bcastSrc.getBytes() |
6222 | / bcastRegion |
6223 | .getBytes(), |
6224 | broadcastRegs[0].retype( |
6225 | bcastRegion |
6226 | .getType()), |
6227 | bcastRegion); |
6228 | } |
6229 | } |
6230 | if (bfloat16WA) |
6231 | bcastType = DataType::f; |
6232 | bcastSrc = broadcastRegs[0].sub( |
6233 | bcastSrc.getOffset() |
6234 | & (nbcast - 1), |
6235 | bcastType)(hs); |
6236 | lastBcastBase = bcastBase; |
6237 | } |
6238 | |
6239 | bool ivfirst |
6240 | = mixedRC || (ivcomp == 0); |
6241 | bool ivlast = mixedRC |
6242 | || (ivcomp |
6243 | == ivcompCount - 1); |
6244 | |
6245 | // Finally, perform the long-awaited FMA. |
6246 | outputFMA(simdSize | mod, A, B, C, |
6247 | bcastSrc, colMajor, hh, |
6248 | ivfirst, ivlast); |
6249 | Clast = C; |
6250 | |
6251 | if (kChain > 1 |
6252 | && accNum >= accCount) |
6253 | stub(); |
6254 | accNum += std::max(minAccPerFMA, |
6255 | div_up(fmaCount, necAcc)); |
6256 | } /* icomp */ |
6257 | |
6258 | x1i += fmaCount; |
6259 | x1New = x1i; |
6260 | } /* x1i */ |
6261 | } /* y1 */ |
6262 | } /* hh */ |
6263 | } /* ivcomp */ |
6264 | x1 = x1New; |
6265 | } /* x1 */ |
6266 | } /* y0 */ |
6267 | } /* ocomp */ |
6268 | } /* ovcomp */ |
6269 | } /* x0 */ |
6270 | } |
6271 | |
6272 | template <HW hw> |
6273 | void gemm_kernel_generator_t<hw>::outerProductGen9IGEMM(int ha, int hb, |
6274 | const vector<RegisterBlock> &A_layout, |
6275 | const vector<RegisterBlock> &B_layout, const GRFMultirange &A_regs, |
6276 | const GRFMultirange &B_regs, const GEMMProblem &problem, |
6277 | const GEMMStrategy &strategy, GEMMState &state) { |
6278 | auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc; |
6279 | DataType tempType |
6280 | = (Ta.isSigned() || Tb.isSigned()) ? DataType::w : DataType::uw; |
6281 | |
6282 | struct AddItem { |
6283 | int simd; |
6284 | RegData dest, src0, src1; |
6285 | }; |
6286 | std::vector<AddItem> adds; |
6287 | |
6288 | auto replayAdds = [&]() { |
6289 | for (auto &item : adds) |
6290 | add(item.simd, item.dest, item.src0, item.src1); |
6291 | adds.clear(); |
6292 | }; |
6293 | |
6294 | bool globalCM = isLayoutColMajor(state.C_layout); |
6295 | |
6296 | // Decide whether to loop in column or row major order. |
6297 | int nx = globalCM ? strategy.unroll[LoopM] : strategy.unroll[LoopN]; |
6298 | int ny = globalCM ? strategy.unroll[LoopN] : strategy.unroll[LoopM]; |
6299 | |
6300 | int tidx = 0; |
6301 | for (int y = 0; y < ny; y++) { |
6302 | for (int x = 0; x < nx;) { |
6303 | auto i = globalCM ? x : y; |
6304 | auto j = globalCM ? y : x; |
6305 | |
6306 | int fmaCount; |
6307 | |
6308 | // Find the appropriate A and B registers. |
6309 | int na, nb; |
6310 | const RegisterBlock *A_block, *B_block; |
6311 | Subregister A |
6312 | = findBlockReg(Ta, A_layout, i, ha, A_regs, na, A_block); |
6313 | Subregister B |
6314 | = findBlockReg(Tb, B_layout, hb, j, B_regs, nb, B_block); |
6315 | |
            // Find the appropriate C register. TODO: remainders.
6317 | int nc; |
6318 | const RegisterBlock *C_block; |
6319 | Subregister C = findBlockReg( |
6320 | Tc, state.C_layout, i, j, state.C_regs[0], nc, C_block); |
6321 | |
6322 | // No C crosspack support. |
6323 | auto cpA = A_block->crosspack, cpB = B_block->crosspack; |
6324 | if (C_block->crosspack > 1) stub(); |
6325 | |
6326 | // Swap out C register for an accumulator, if necessary. |
6327 | auto C_roff = C.getBase() - state.C_regs[0].ranges[0].getBase(); |
6328 | if (C_roff < state.C_accCount) |
6329 | C = AccumulatorRegister(C_roff).sub(C.getOffset(), Tc.ngen()); |
6330 | |
6331 | // Use requested execution size if possible, but limited to available elements. |
6332 | // Decide the kernel type based on register block layouts. |
6333 | bool canColMajor = (A_block->colMajor && C_block->colMajor); |
6334 | bool canRowMajor = (!B_block->colMajor && !C_block->colMajor); |
6335 | bool colMajor; |
6336 | |
6337 | if (!canColMajor && !canRowMajor) { |
6338 | colMajor = true; |
6339 | fmaCount = 1; |
6340 | } else if (canColMajor) { |
6341 | colMajor = true; |
6342 | fmaCount = na; |
6343 | } else { |
6344 | colMajor = false; |
6345 | fmaCount = nb; |
6346 | } |
            fmaCount = rounddown_pow2(std::min({fmaCount, strategy.fmaSIMD,
                    nc, elementsPerGRF<int16_t>(hw)}));
6349 | |
6350 | auto temp = state.tempMul_regs[tidx++]; |
6351 | |
6352 | if (C.isARF()) { |
6353 | if (colMajor) |
6354 | mac(fmaCount, C(1), A(cpA), B(0)); |
6355 | else |
6356 | mac(fmaCount, C(1), A(0), B(cpB)); |
6357 | } else { |
6358 | if (colMajor) |
6359 | mul(fmaCount, temp[0].sub(0, tempType)(2), A(cpA), B(0)); |
6360 | else |
6361 | mul(fmaCount, temp[0].sub(0, tempType)(2), A(0), B(cpB)); |
6362 | |
6363 | adds.push_back( |
6364 | {fmaCount, C(1), C(1), temp[0].sub(0, tempType)(2)}); |
6365 | } |
6366 | |
6367 | if (tidx >= int(state.tempMul_regs.size())) { |
6368 | tidx = 0; |
6369 | replayAdds(); |
6370 | } |
6371 | |
6372 | x += fmaCount; |
6373 | } |
6374 | } |
6375 | |
6376 | replayAdds(); |
6377 | |
6378 | // A4B4 outer product (4 temporary GRFs per 2 C registers) - 2/3 SP |
6379 | // |
6380 | // mul (32) temp0.0:w<1> A.0:b<32;16,2> B.0:b<32;16,2> - EM |
6381 | // mul (32) temp2.0:w<1> A.1:b<32;16,2> B.1:b<32;16,2> - FPU |
6382 | // add (16) C.0:d<1> C.0:d<8;8,1> temp0.0:w<16;8,2> - EM |
6383 | // add (16) C.0:d<1> C.0:d<8;8,1> temp0.1:w<16;8,2> - FPU |
6384 | // add (16) C.0:d<1> C.0:d<8;8,1> temp2.0:w<16;8,2> - EM |
6385 | // add (16) C.0:d<1> C.0:d<8;8,1> temp2.1:w<16;8,2> - FPU |
6386 | |
6387 | // Faster A4B4 outer product a la non-VNNI (4 temporary GRFs per 2 C registers) - 4/5 SP |
6388 | // |
6389 | // mul (32) temp0.0:w<1> A.0:b<32;16,2> B.0:b<32;16,2> - EM |
6390 | // mul (32) temp2.0:w<1> A.1:b<32;16,2> B.1:b<32;16,2> - FPU |
6391 | // add (32) (sat) temp0.0:w<1> temp0.0:w<1> temp2.0:w<1> - EM/FPU |
6392 | // add (16) C.0:d<1> C.0:d<8;8,1> temp0.0:w<16;8,2> - EM |
6393 | // add (16) C.0:d<1> C.0:d<8;8,1> temp0.1:w<16;8,2> - FPU |
6394 | } |
6395 | |
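// Return the distance from r2 to r1 in elements of r1's type
//  (e.g., with 64-byte GRFs, r4.4:f is 16 + 4 = 20 elements past r3.0:f).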
6396 | static int elementDiff(HW hw, const RegData &r1, const RegData &r2) { |
6397 | return elementsPerGRF(hw, r1.getType()) * (r1.getBase() - r2.getBase()) |
6398 | + (r1.getOffset() - r2.getOffset()); |
6399 | } |
6400 | |
6401 | // Accumulate multiple outer products using the systolic array. |
6402 | template <HW hw> |
6403 | void gemm_kernel_generator_t<hw>::outerProductSystolic(int h, int ha, int hb, |
6404 | const vector<RegisterBlock> &A_layout, |
6405 | const vector<RegisterBlock> &B_layout, const GRFMultirange &A_regs, |
6406 | const GRFMultirange &B_regs, const GEMMProblem &problem, |
6407 | const GEMMStrategy &strategy, GEMMState &state) { |
6408 | auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc; |
6409 | bool globalCM = isLayoutColMajor(state.C_layout); |
6410 | auto params = systolicParams(hw, problem, strategy); |
6411 | auto ksys = params.ksys; |
6412 | auto osys = params.osys; |
6413 | auto sdepth = params.sdepth; |
6414 | auto rcountMax = params.rcountMax; |
6415 | int dpaswTile = strategy.dpasw |
6416 | ? (globalCM ? strategy.B.tileC : strategy.A.tileR) |
6417 | : 0; |
6418 | bool rsFix = (strategy.readSuppressionWA |
6419 | && hasMasking(globalCM ? A_layout : B_layout)); |
6420 | bool canAtomicNon8x8 |
6421 | = (hw >= HW::XeHPC) && (getStepping() >= SteppingPVCXTB0); |
6422 | |
6423 | bool sum = globalCM ? state.systolicSumA : state.systolicSumB; |
6424 | |
6425 | RegisterBlock sumBlock; |
6426 | sumBlock.colMajor = globalCM; |
6427 | sumBlock.crosspack = 1; |
6428 | |
6429 | // dpas processes ksys outer products at once. |
6430 | ha = align_down(ha, ksys); |
6431 | hb = align_down(hb, ksys); |
6432 | |
6433 | // Decide whether to loop in column or row major order, to facilitate macro sequences. |
6434 | // x is the non-accumulating dimension of dpas src2 (N matrix) |
6435 | // y is the non-accumulating dimension of dpas src1 (V matrix) |
6436 | int nx = strategy.unroll[globalCM ? LoopN : LoopM]; |
6437 | int ny = strategy.unroll[globalCM ? LoopM : LoopN]; |
6438 | |
6439 | int yinc = osys; |
6440 | |
6441 | const int compA = 0, compB = 0, compC = 0; |
6442 | const int incompCount = 1, oncompCount = 1; |
6443 | |
6444 | for (int y = 0; y < ny; y += yinc) { |
6445 | Subregister A0, B0, C0; |
6446 | int rcount = 0, x0 = 0; |
6447 | |
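        // Issue all accumulated dpas(w) instructions, splitting the repeat
        //  count into powers of two (e.g., rcount = 7 issues rcounts 4, 2, 1).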
6448 | auto issueDPAS = [&](bool last) { |
6449 | while (rcount > 0) { |
6450 | InstructionModifier mod = osys; |
6451 | |
6452 | bool useDPASW = strategy.dpasw && x0 < nx; |
6453 | auto rc2 = rounddown_pow2(rcount); |
6454 | auto rc = rc2 * (useDPASW ? 2 : 1); |
6455 | auto &V0 = globalCM ? A0 : B0; |
6456 | auto &N0 = globalCM ? B0 : A0; |
6457 | |
6458 | if (rsFix) { |
6459 | GRF v0GRF {V0.getBase()}; |
6460 | mov<uint32_t>(8, v0GRF, v0GRF); |
6461 | rsFix = false; |
6462 | } |
6463 | |
6464 | if (strategy.atomicFMA) |
6465 | if (!(last && (rc2 == rcount))) |
6466 | if (rc == 8 || canAtomicNon8x8) mod |= Atomic; |
6467 | |
6468 | useDPASW ? dpasw(mod, sdepth, rc, C0, C0, V0, N0) |
6469 | : dpas(mod, sdepth, rc, C0, C0, V0, N0); |
6470 | |
6471 | rcount -= rc2; |
6472 | x0 += rc; |
6473 | N0.setBase(N0.getBase() + rc2); |
6474 | C0.setBase(C0.getBase() + rc2); |
6475 | } |
6476 | }; |
6477 | |
6478 | for (int oncomp = 0; oncomp < oncompCount; oncomp++) { |
6479 | for (int x = 0; x < nx + sum; x++) { |
6480 | for (int incomp = 0; incomp < incompCount; incomp++) { |
6481 | // Find the appropriate A and B registers. |
6482 | int na, nb, nc; |
6483 | const RegisterBlock *A_block, *B_block, *C_block; |
6484 | Subregister A, B, C; |
6485 | |
6486 | const int cxCompA = -1, cxCompB = -1, cxCompC = -1, |
6487 | cBuffer = 0; |
6488 | |
6489 | if (x < nx) { |
6490 | if (strategy.dpasw |
6491 | && (x % (2 * dpaswTile) >= dpaswTile)) |
6492 | continue; |
6493 | |
6494 | int i = globalCM ? y : x; |
6495 | int j = globalCM ? x : y; |
6496 | |
6497 | A = findBlockReg(Ta, A_layout, i, ha, A_regs, na, |
6498 | A_block, cxCompA, compA); |
6499 | B = findBlockReg(Tb, B_layout, hb, j, B_regs, nb, |
6500 | B_block, cxCompB, compB); |
6501 | C = findBlockReg(Tc, state.C_layout, i, j, |
6502 | state.C_regs[cBuffer], nc, C_block, cxCompC, |
6503 | compC); |
6504 | } else if (state.systolicSumA) { |
6505 | A = findBlockReg( |
6506 | Ta, A_layout, y, ha, A_regs, na, A_block); |
6507 | B = state.sysSumAll1s[0]; |
6508 | nb = elementsPerGRF(hw, Tb); |
6509 | B_block = &sumBlock; |
6510 | C = findBlockReg(Tc, state.As_layout, y, 0, |
6511 | state.As_regs, nc, C_block); |
6512 | } else { |
6513 | A = state.sysSumAll1s[0]; |
6514 | na = elementsPerGRF(hw, Ta); |
6515 | A_block = &sumBlock; |
6516 | B = findBlockReg( |
6517 | Tb, B_layout, hb, y, B_regs, nb, B_block); |
6518 | C = findBlockReg(Tc, state.Bs_layout, 0, y, |
6519 | state.Bs_regs, nc, C_block); |
6520 | } |
6521 | |
6522 | int nv = globalCM ? na : nb; |
6523 | int nn = globalCM ? nb : na; |
6524 | |
6525 | // Verify DPAS requirements. |
6526 | if (globalCM) { |
6527 | if (A_block->crosspack * Ta.real().size() |
6528 | != std::max(4, Ta.real().size())) |
6529 | stub(); |
6530 | if (B_block->crosspack > 1) stub(); |
6531 | } else { |
6532 | if (B_block->crosspack * Tb.real().size() |
6533 | != std::max(4, Tb.real().size())) |
6534 | stub(); |
6535 | if (A_block->crosspack > 1) stub(); |
6536 | } |
6537 | if (A_block->colMajor != globalCM |
6538 | || B_block->colMajor != globalCM) |
6539 | stub(); |
6540 | if (C_block->crosspack > 1) stub(); |
6541 | |
6542 | if (nv != osys) stub(); |
6543 | if (nn < ksys) stub(); |
6544 | |
6545 | // Check if current DPAS can be fused with the previous one. |
6546 | bool chain = false; |
6547 | if (A0.isValid()) { |
6548 | chain = globalCM |
6549 | ? (elementDiff(hw, B, B0) == (x - x0) * ksys) |
6550 | : (elementDiff(hw, A, A0) == (x - x0) * ksys); |
6551 | chain = chain |
6552 | && (elementDiff(hw, C, C0) == (x - x0) * osys); |
6553 | chain = chain && (rcount < rcountMax); |
6554 | if (strategy.dpasw) |
6555 | chain = chain && x < nx |
6556 | && (x % (2 * dpaswTile) > 0); |
6557 | } |
6558 | |
6559 | if (chain) |
6560 | rcount++; |
6561 | else { |
6562 | if (strategy.dpasw && x < nx && rcount > 0 |
6563 | && rcount != dpaswTile) |
6564 | stub(); |
6565 | if (A0.isValid()) issueDPAS(false); |
6566 | A0 = A; |
6567 | B0 = B; |
6568 | C0 = C; |
6569 | rcount = 1; |
6570 | A0.setType(Ta.ngen()); |
6571 | B0.setType(Tb.ngen()); |
6572 | C0.setType(Tc.ngen()); |
6573 | x0 = x; |
6574 | } |
6575 | } /* incomp loop */ |
6576 | } /* x loop */ |
6577 | } /* oncomp loop */ |
6578 | |
6579 | bool finishChain = !strategy.extendedAtomicFMA || (y + osys >= ny); |
6580 | issueDPAS(finishChain); |
6581 | } /* y loop */ |
6582 | } |
6583 | |
6584 | // Decide whether to use the legacy post-op injector inside C update. |
6585 | // Needed if we can't convert C to f32 in-place, but doesn't support binary post-ops. |
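//  (e.g., needed for an f16 or int8 C matrix with an eltwise post-op)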
6586 | static inline bool useEltwiseInjector(const GEMMProblem &problem) { |
6587 | return problem.hasPostOp() && (problem.Tc.size() < 4); |
6588 | } |
6589 | |
6590 | // Perform C update operation on C_acc, given original C data in C_load. |
6591 | // All inputs and outputs are assumed to be of type problem.Ts. |
6592 | template <HW hw> |
6593 | void gemm_kernel_generator_t<hw>::updateC(const GRFMultirange &C_acc, |
6594 | const GRFMultirange &C_accSwap, const GRFMultirange &C_load, |
6595 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
6596 | auto &alphar = problem.alpha_real; |
6597 | auto &betar = problem.beta_real; |
6598 | bool alpha1 = (alphar == 1); |
6599 | bool alphaM1 = (alphar == -1); |
6600 | bool beta1 = (betar == 1); |
6601 | bool beta0 = (betar == 0); |
6602 | bool betaM1 = (betar == -1); |
6603 | |
6604 | #define FOR_EACH_C(f) \ |
6605 | do { \ |
6606 | map(hw, state.Tacc.real(), C_load, C_acc, strategy, \ |
6607 | [&](int esize, GRF loaded, GRF acc) { f; }); \ |
6608 | } while (false) |
6609 | |
6610 | #define FOR_EACH_C_CX(f) \ |
6611 | do { \ |
6612 | map(hw, state.Tacc.real(), C_load, C_acc, C_accSwap, strategy, \ |
6613 | [&](int esize, GRF loaded, GRF acc, GRF accswap) { f; }); \ |
6614 | } while (false) |
6615 | |
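    // Compute acc := alpha * acc + beta * loaded, folding alpha/beta = +/-1
    //  into source modifiers where possible.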
6616 | if (!beta0) { |
6617 | if (alpha1 || alphaM1) { |
6618 | if (beta1) |
6619 | FOR_EACH_C(add(esize, acc, loaded, alpha1 ? acc : -acc)); |
6620 | else if (betaM1) |
6621 | FOR_EACH_C(add(esize, acc, -loaded, alpha1 ? acc : -acc)); |
6622 | else if (betar.fixed()) |
6623 | stub(); // beta should be put in a register first. |
6624 | else { |
6625 | if (!strategy.doubleWA) |
6626 | FOR_EACH_C(mad(esize, acc, alpha1 ? acc : -acc, loaded, |
6627 | betar.getRegAvoiding(hw, loaded))); |
6628 | else { |
6629 | FOR_EACH_C(mul(esize, loaded, loaded, |
6630 | betar.getRegAvoiding(hw, loaded))); |
6631 | FOR_EACH_C(add(esize, acc, loaded, alpha1 ? acc : -acc)); |
6632 | } |
6633 | } |
6634 | } else { |
6635 | bool neg = false; |
6636 | if (!beta1) { |
6637 | if (betaM1) |
6638 | neg = true; |
6639 | else if (!betar.fixed()) |
6640 | FOR_EACH_C(mul(esize, loaded, loaded, |
6641 | betar.getRegAvoiding(hw, acc))); |
6642 | else |
6643 | stub(); |
6644 | } |
6645 | if (alphar.fixed()) |
6646 | stub(); // alpha should be put in a register first. |
6647 | else { |
6648 | if (!strategy.doubleWA) |
6649 | FOR_EACH_C(mad(esize, acc, neg ? -loaded : loaded, acc, |
6650 | alphar.getRegAvoiding(hw, acc))); |
6651 | else { |
6652 | FOR_EACH_C(mul( |
6653 | esize, acc, acc, alphar.getRegAvoiding(hw, acc))); |
6654 | FOR_EACH_C(add(esize, acc, neg ? -loaded : loaded, acc)); |
6655 | } |
6656 | } |
6657 | } |
6658 | } else if (alphaM1) |
6659 | FOR_EACH_C(mov(esize, acc, -acc)); |
6660 | else if (alpha1) |
6661 | /* no op */; |
6662 | else if (alphar.fixed()) |
6663 | stub(); // alpha should be put in a register first. |
6664 | else { |
6665 | FOR_EACH_C(mul(esize, acc, acc, alphar.getRegAvoiding(hw, acc))); |
6666 | } |
6667 | |
6668 | if (useEltwiseInjector(problem)) { |
6669 | Label labelPostOpDone; |
6670 | bool allocFlag = state.flagAP.isInvalid(); |
6671 | auto flagNonfinal = allocFlag ? state.raVFlag.alloc() : state.flagAP; |
6672 | and_(1 | nz | flagNonfinal, null.ud(), state.inputs.flags, |
6673 | FlagNonfinalKBlock); |
6674 | jmpi(1 | flagNonfinal, labelPostOpDone); |
6675 | if (allocFlag) state.raVFlag.safeRelease(flagNonfinal); |
6676 | if (state.Tacc != Type::f32 || !postOpInjector) stub(); |
6677 | for (const auto &range : C_acc.ranges) |
6678 | postOpInjector->compute(range); |
6679 | mark(labelPostOpDone); |
6680 | } |
6681 | |
6682 | #undef FOR_EACH_C |
6683 | #undef FOR_EACH_C_CX |
6684 | } |
6685 | |
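// Reblock a register layout (layoutSrc) so that each output block lies within
// a single block of a reference layout (layoutRef). On success,
// blockMap[i]..blockMap[i+1]-1 index the output blocks overlapping layoutRef[i].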
6686 | template <HW hw> |
6687 | bool gemm_kernel_generator_t<hw>::reblockLayout(Type Tdst, |
6688 | vector<int32_t> &blockMap, vector<RegisterBlock> &layoutDst, |
6689 | const vector<RegisterBlock> &layoutRef, |
6690 | const vector<RegisterBlock> &layoutSrc, const MatrixAddressing &atype, |
6691 | const MatrixAddressingStrategy &astrategy) { |
6692 | auto nblockRef = layoutRef.size(); |
6693 | layoutDst.clear(); |
6694 | layoutDst.reserve(nblockRef); |
6695 | blockMap.clear(); |
6696 | blockMap.reserve(nblockRef + 1); |
6697 | blockMap.push_back(0); |
6698 | for (auto &blockRef : layoutRef) { |
6699 | RegisterBlock blockDst, blockMid; |
6700 | for (auto &blockSrc : layoutSrc) { |
6701 | int rr1 = blockRef.offsetR - blockSrc.offsetR, |
6702 | rr2 = rr1 + blockRef.nr; |
6703 | int cc1 = blockRef.offsetC - blockSrc.offsetC, |
6704 | cc2 = cc1 + blockRef.nc; |
6705 | if (rr1 >= blockSrc.nr || rr2 <= 0) continue; |
6706 | if (cc1 >= blockSrc.nc || cc2 <= 0) continue; |
6707 | rr1 = std::max(rr1, 0); |
6708 | cc1 = std::max(cc1, 0); |
6709 | rr2 = std::min(rr2, int(blockSrc.nr)); |
6710 | cc2 = std::min(cc2, int(blockSrc.nc)); |
6711 | if (!getSubblock(Tdst, blockMid, blockSrc, false, rr1, rr2, rr1, |
6712 | rr2, true, atype, astrategy)) |
6713 | return false; |
6714 | if (!getSubblock(Tdst, blockDst, blockMid, true, cc1, cc2, cc1, cc2, |
6715 | true, atype, astrategy)) |
6716 | return false; |
6717 | layoutDst.push_back(blockDst); |
6718 | } |
6719 | blockMap.push_back(int32_t(layoutDst.size())); |
6720 | } |
6721 | return true; |
6722 | } |
6723 | |
6724 | // Update an entire C layout. |
6725 | template <HW hw> |
6726 | void gemm_kernel_generator_t<hw>::updateCLayout( |
6727 | const vector<RegisterBlock> &layoutExt, const GRFRange (&C_addr0)[2], |
6728 | const RegisterBlock &C_block0, COperation op, GEMMProblem &problem, |
6729 | GEMMStrategy &strategy, GEMMState &state) { |
6730 | #define FOR_EACH_C for (int q = 0; q < C_count; q++) |
6731 | auto Tc = problem.Tc, Tc_ext = problem.Tc_ext, Ts = problem.Ts; |
6732 | bool loadOnly = (op == COperation::Load); |
6733 | bool beta0 = problem.beta0(); |
6734 | bool needLoad = (!beta0 && !loadOnly); |
6735 | bool copyC = state.copyC; |
6736 | int C_count = (op == COperation::UpdateStore) ? state.C_count : 1; |
6737 | |
6738 | auto nblocks = int(layoutExt.size()); |
6739 | bool haveDescs = layoutExt[0].descAssigned; |
6740 | |
6741 | vector<GRFRange>(&C_addrs)[2] = state.C_addrs; |
    GRFMultirange C_extRange, C_copyRange;
6743 | GRFMultirange &C_accRange = state.C_regs[0]; |
6744 | auto &C_extRegs = C_extRange.ranges; |
6745 | auto &C_copyRegs = C_copyRange.ranges; |
6746 | vector<GRFRange> C_convertRegs; |
6747 | |
6748 | for (int q = 0; q < C_count; q++) |
        C_addrs[q].clear();
6750 | |
6751 | // Map layout to blocks in internal C layout as needed. |
6752 | vector<RegisterBlock> layout; |
6753 | vector<int> blockMap; |
6754 | if (copyC) { |
6755 | if (!reblockLayout(Tc, blockMap, layout, layoutExt, state.C_layout, |
6756 | problem.C, strategy.C)) |
6757 | stub(); |
6758 | } else { |
6759 | layout = layoutExt; |
6760 | blockMap.resize(nblocks + 1); |
6761 | for (int i = 0; i <= nblocks; i++) |
6762 | blockMap[i] = i; |
6763 | } |
6764 | |
6765 | // Prepare for late C conversion. |
6766 | bool lateCConvert = (!loadOnly && !strategy.C.atomic |
6767 | && problem.needsTsConvert() && state.Tacc != Ts); |
6768 | bool copyCLoad = needLoad && (copyC || lateCConvert); |
6769 | if (lateCConvert && Tc.isComplex()) stub(); |
6770 | |
6771 | // Load as much of C as is possible at a time, given register space. |
6772 | for (int lstart = 0; lstart < nblocks;) { |
6773 | int lend; |
6774 | |
6775 | // Allocate address and data registers for C updating. If allocator chokes, |
6776 | // proceed with the registers we were able to allocate. |
6777 | // |
6778 | // At the same time, build up three layouts for this chunk of C: |
6779 | // sublayoutExt: C data to be loaded/stored |
6780 | // sublayoutCopy: copied C data |
6781 | // sublayoutAcc: C data in accumulators |
6782 | bool allocOK = true; |
6783 | auto tryAlloc = [&](int regs, Bundle hint = Bundle()) { |
6784 | auto range = state.ra.try_alloc_range(regs, hint); |
6785 | allocOK &= range.isValid(); |
6786 | return range; |
6787 | }; |
6788 | |
6789 | vector<RegisterBlock> sublayoutExt, sublayoutCopy, sublayoutAcc; |
6790 | size_t sublayoutCopySize = 0; |
6791 | int bytes = 0, bytesConvert = 0; |
6792 | int tokens = 0, maxTokens = 256; |
6793 | if (needLoad && hw >= HW::Gen12LP) maxTokens = tokenCount(hw); |
6794 | |
6795 | for (lend = lstart; (lend < nblocks) && (tokens < maxTokens); |
6796 | lend++, tokens++) { |
6797 | auto li0 = blockMap[lend], li1 = blockMap[lend + 1]; |
6798 | int expand |
6799 | = lateCConvert ? div_up(Ts.size(), state.Tacc.size()) : 1; |
6800 | |
6801 | if (copyCLoad) |
6802 | for (int li = li0; li < li1; li++) { |
6803 | auto block = layout[li]; |
6804 | block.compact(state.Tacc); |
6805 | block.offsetBytes = bytesConvert; |
6806 | bytesConvert += block.nregs() * expand * GRF::bytes(hw); |
6807 | sublayoutCopy.push_back(block); |
6808 | } |
6809 | |
6810 | auto blockExt = layoutExt[lend]; |
6811 | auto naddr = addrGRFCount(problem.C, strategy.C, blockExt); |
6812 | FOR_EACH_C C_addrs[q].push_back( |
6813 | (blockExt.offsetR == 0 && blockExt.offsetC == 0) |
6814 | ? C_addr0[q] |
6815 | : tryAlloc(naddr)); |
6816 | if (needLoad || copyC) |
6817 | C_extRegs.push_back(tryAlloc( |
6818 | blockExt.nregs(), getHint(HintType::CLoad, strategy))); |
6819 | if (copyCLoad) |
6820 | for (int li = li0; li < li1; li++) |
6821 | C_copyRegs.push_back(tryAlloc( |
6822 | sublayoutCopy[li - li0 + sublayoutCopySize].nregs() |
6823 | * expand, |
6824 | getHint(HintType::CLoad, strategy))); |
6825 | if (lateCConvert) |
6826 | for (int li = li0; li < li1; li++) |
6827 | C_convertRegs.push_back( |
6828 | tryAlloc(layout[li].nregs() * expand)); |
6829 | if (!allocOK) break; |
6830 | |
6831 | blockExt.offsetBytes = bytes; |
6832 | bytes += blockExt.nregs() * GRF::bytes(hw); |
6833 | sublayoutExt.push_back(blockExt); |
6834 | |
6835 | sublayoutCopySize = sublayoutCopy.size(); |
6836 | } |
6837 | |
6838 | sublayoutCopy.resize(sublayoutCopySize); |
6839 | |
6840 | int listart = blockMap[lstart]; |
6841 | int liend = blockMap[lend]; |
6842 | |
6843 | sublayoutAcc.reserve(liend - listart); |
6844 | for (int l = listart; l < liend; l++) |
6845 | sublayoutAcc.push_back(layout[l]); |
6846 | |
6847 | // Set up C addresses relative to prior blocks. |
6848 | for (int l = lstart; l < lend; l++) { |
6849 | auto &block = sublayoutExt[l - lstart]; |
6850 | int bbase = findBaseBlock( |
6851 | block, sublayoutExt, 0, l - lstart, problem.C, strategy.C); |
6852 | FOR_EACH_C { |
6853 | auto &blockSrc = (bbase >= 0) ? sublayoutExt[bbase] : C_block0; |
6854 | auto &addrSrc = (bbase >= 0) ? C_addrs[q][bbase] : C_addr0[q]; |
6855 | setupAddrRel(Tc_ext, C_addrs[q][l - lstart], addrSrc, block, |
6856 | blockSrc, state.C_layout, state.inputs.ldc[q], |
6857 | problem.C, strategy.C, strategy, state, |
6858 | state.ldcMultiples[q]); |
6859 | } |
6860 | } |
6861 | |
6862 | if (strategy.C.atomic) { |
6863 | // Atomic update. |
6864 | // Alpha scaling is done earlier; beta scaling isn't supported. |
6865 | if (!problem.alpha1() || !problem.beta1()) stub(); |
6866 | if (copyC) |
6867 | if (!copyRegisters(state.Tacc, Tc_ext, sublayoutAcc, |
6868 | sublayoutExt, C_accRange, C_extRange, 0, 0, false, |
6869 | strategy, state)) |
6870 | stub(); |
6871 | |
6872 | auto &sublayoutSrc = copyC ? sublayoutExt : sublayoutAcc; |
6873 | auto &C_srcRange = copyC ? C_extRange : C_accRange; |
6874 | FOR_EACH_C atomicAddMatrix(Tc_ext, C_srcRange, sublayoutSrc, |
6875 | problem.C, strategy.C, C_addrs[q], problem, strategy, |
6876 | state); |
6877 | } else { |
6878 | // Data types before and after scaling phase. |
6879 | auto Tacc_final = Tc; |
6880 | if (op == COperation::Update |
6881 | || (op == COperation::UpdateStore && copyC)) |
6882 | Tacc_final = state.Tacc; |
6883 | |
6884 | // Regular update. |
6885 | auto Tload = Tc_ext; |
6886 | if (!beta0 || loadOnly) { |
6887 | // Set up a0.0 descriptor for loads if needed. |
6888 | if (lstart > 0 && haveDescs) mov(1, a0.ud(0), a0.ud(3)); |
6889 | |
6890 | // Load C data. |
6891 | auto &sublayoutLoad |
6892 | = (loadOnly && !copyC) ? sublayoutAcc : sublayoutExt; |
6893 | auto &C_loadRange |
6894 | = (loadOnly && !copyC) ? C_accRange : C_extRange; |
6895 | loadMatrix(C_loadRange, sublayoutLoad, problem.C, strategy.C, |
6896 | C_addrs[0], strategy, state); |
6897 | |
6898 | // Set up a0.0 descriptor for stores (and save load descriptors) if needed. |
6899 | if (haveDescs && !loadOnly) { |
6900 | if (lend < nblocks) mov(1, a0.ud(3), a0.ud(0)); |
6901 | mov(1, a0.ud(0), a0.ud(2)); |
6902 | } |
6903 | |
6904 | // Copy loaded data as needed. |
6905 | if (copyCLoad) { |
6906 | auto &sublayoutDst |
6907 | = loadOnly ? sublayoutAcc : sublayoutCopy; |
6908 | auto &C_dstRange = loadOnly ? C_accRange : C_copyRange; |
6909 | Tload = lateCConvert ? Ts : state.Tacc; |
6910 | if (!copyRegisters(Tc_ext, Tload, sublayoutExt, |
6911 | sublayoutDst, C_extRange, C_dstRange, 0, 0, |
6912 | false, strategy, state)) |
6913 | stub(); |
6914 | } |
6915 | } |
6916 | |
6917 | // Late C conversion. |
6918 | auto originalTacc = state.Tacc; |
6919 | if (lateCConvert) { |
6920 | for (int li = listart; li < liend; li++) { |
6921 | auto C_acc = state.C_regs[0].subrange( |
6922 | hw, state.Tacc, layout[li]); |
6923 | copyRegisterBlock(state.Tacc, Ts, layout[li], layout[li], |
6924 | C_acc, C_convertRegs[li - listart], 0, 0, strategy, |
6925 | state); |
6926 | } |
6927 | state.Tacc = Ts; |
6928 | } |
6929 | |
6930 | // Alpha/beta scaling and optional fp32<->int32 conversion. |
6931 | if (!loadOnly) |
6932 | for (int phase = 0; phase < 3; phase++) { |
6933 | vector<GRFMultirange> C_accs, C_accSwaps, C_loads; |
6934 | C_accs.reserve(liend - listart); |
6935 | C_accSwaps.reserve(liend - listart); |
6936 | C_loads.reserve(liend - listart); |
6937 | |
6938 | for (int li = listart; li < liend; li++) { |
6939 | GRFMultirange C_acc0 = state.C_regs[0].subrange( |
6940 | hw, state.Tacc, layout[li]); |
6941 | GRFMultirange C_acc = lateCConvert |
6942 | ? C_convertRegs[li - listart] |
6943 | : C_acc0; |
6944 | GRFMultirange C_accSwap; |
6945 | GRFMultirange C_load = beta0 |
6946 | ? C_acc |
6947 | : copyCLoad ? C_copyRegs[li - listart] |
6948 | : C_extRegs[li - listart]; |
6949 | switch (phase) { |
6950 | case 0: |
6951 | if (!beta0) |
6952 | convert(C_load, Tload, state.Tacc, problem, |
6953 | strategy, state); |
6954 | break; |
6955 | case 1: { |
6956 | C_accs.push_back(C_acc); |
6957 | C_accSwaps.push_back(C_accSwap); |
6958 | C_loads.push_back(C_load); |
6959 | } break; |
6960 | case 2: |
6961 | if (lateCConvert) |
6962 | copyRegisterBlock(state.Tacc, Tacc_final, |
6963 | layout[li], layout[li], C_acc, |
6964 | C_acc0, 0, 0, strategy, state); |
6965 | else |
6966 | convert(C_acc, state.Tacc, Tacc_final, |
6967 | problem, strategy, state); |
6968 | break; |
6969 | } |
6970 | } |
6971 | |
6972 | if (phase == 1) { |
6973 | std::vector<int> order(liend - listart); |
6974 | std::iota(order.begin(), order.end(), 0); |
6975 | std::sort( |
6976 | order.begin(), order.end(), [&](int a, int b) { |
6977 | auto *rangeA = &C_accs[a], |
6978 | *rangeB = &C_accs[b]; |
6979 | return (*rangeA)[0].getBase() |
6980 | < (*rangeB)[0].getBase(); |
6981 | }); |
6982 | GRFMultirange C_accsSorted, C_accSwapsSorted, |
6983 | C_loadsSorted; |
6984 | std::vector<RegisterBlock> C_accSortedLayout; |
6985 | |
6986 | bool remaskC_M = isPacked(problem.C.layout) |
6987 | && (strategy.remHandling[LoopM] |
6988 | != RemainderHandling::Ignore); |
6989 | bool remaskC_N = isPacked(problem.C.layout) |
6990 | && (strategy.remHandling[LoopN] |
6991 | != RemainderHandling::Ignore); |
6992 | |
6993 | for (int i = 0; i < (liend - listart); i++) { |
6994 | if (remaskC_M || remaskC_N) { |
6995 | auto block = layout[listart + order[i]]; |
6996 | block.offsetBytes = C_accsSorted.getLen() |
6997 | << GRF::log2Bytes(hw); |
6998 | C_accSortedLayout.push_back(block); |
6999 | } |
7000 | |
7001 | C_accsSorted.append(C_accs[order[i]]); |
7002 | C_accSwapsSorted.append(C_accSwaps[order[i]]); |
7003 | C_loadsSorted.append(C_loads[order[i]]); |
7004 | } |
7005 | |
7006 | updateC(C_accsSorted, C_accSwapsSorted, C_loadsSorted, |
7007 | problem, strategy, state); |
7008 | |
7009 | if (remaskC_M) |
7010 | remaskLayout(state.Tacc, 0, false, |
7011 | C_accSortedLayout, C_accsSorted, strategy, |
7012 | state); |
7013 | if (remaskC_N) |
7014 | remaskLayout(state.Tacc, 1, true, C_accSortedLayout, |
7015 | C_accsSorted, strategy, state); |
7016 | } |
7017 | } |
7018 | |
7019 | state.Tacc = Tacc_final; |
7020 | |
7021 | // Store updated data. |
7022 | if (op == COperation::UpdateStore) { |
7023 | if (copyC) |
7024 | if (!copyRegisters(state.Tacc, Tc_ext, sublayoutAcc, |
7025 | sublayoutExt, C_accRange, C_extRange, 0, 0, |
7026 | false, strategy, state)) |
7027 | stub(); |
7028 | |
7029 | auto &sublayoutSrc = copyC ? sublayoutExt : sublayoutAcc; |
7030 | auto &C_srcRange = copyC ? C_extRange : C_accRange; |
7031 | FOR_EACH_C storeMatrix(C_srcRange, sublayoutSrc, problem.C, |
7032 | strategy.C, C_addrs[q], strategy, state); |
7033 | } |
7034 | |
7035 | state.Tacc = originalTacc; |
7036 | } |
7037 | |
7038 | // Free address and data registers, including C accumulators that are no longer used... |
7039 | // ... except C_addr0. I need that! |
7040 | FOR_EACH_C safeReleaseRanges(C_addrs[q], state); |
7041 | safeReleaseRanges(C_extRange, state); |
7042 | safeReleaseRanges(C_copyRange, state); |
7043 | safeReleaseRanges(C_convertRegs, state); |
7044 | if (op == COperation::UpdateStore) |
7045 | for (int li = listart; li < liend; li++) |
7046 | for (int b = 0; b < state.C_buffers; b++) |
7047 | releaseRanges(state.C_regs[b].subrange( |
7048 | hw, state.Tacc, layout[li]), |
7049 | state); |
7050 | FOR_EACH_C state.ra.claim(C_addr0[q]); |
7051 | |
7052 | // Check for forward progress. |
7053 | if (lend == lstart) throw out_of_registers_exception(); |
7054 | lstart = lend; |
7055 | } |
7056 | |
7057 | // Re-claim all the C registers we freed, so as not to disturb the caller's RegisterAllocator. |
7058 | reclaimRanges(state.C_regs[0], state); |
7059 | #undef FOR_EACH_C |
7060 | } |
7061 | |
7062 | // Assign runtime-computed descriptor information to all blocks in this layout. |
7063 | // Returns true if successful; false if not all blocks in layout are compatible. |
7064 | static inline bool assignAllDescs(vector<RegisterBlock> &layout) { |
7065 | for (auto &block : layout) { |
7066 | if (block.simdSize != layout[0].simdSize) return false; |
7067 | block.descAssigned = true; |
7068 | block.sfid = layout[0].sfid; |
7069 | } |
7070 | |
7071 | return true; |
7072 | } |
7073 | |
7074 | // Output code for standard C remainder handling. |
7075 | template <HW hw> |
7076 | bool gemm_kernel_generator_t<hw>::doStdCRemainder( |
7077 | vector<RegisterBlock> &layoutExt, |
7078 | vector<RegisterBlock> &layoutExtUnmasked, bool inside, bool columns[2], |
7079 | StdCRemType remTypes[2], bool fragments[2], bool fragPositives[2], |
7080 | int fragSizes[2], const GRFRange (&C_addr0)[2], |
7081 | const GRFRange (&C_addr0Unmasked)[2], COperation op, |
7082 | vector<MaskAssignment> &masks, GEMMProblem &problem, |
7083 | GEMMStrategy &strategy, GEMMState state) { |
7084 | auto Tc_ext = problem.Tc_ext; |
7085 | auto column = columns[inside]; |
7086 | LoopType loop = column ? LoopN : LoopM; |
7087 | auto remType = remTypes[loop]; |
7088 | auto fragment = fragments[loop]; |
7089 | auto fragPositive = fragPositives[loop]; |
7090 | auto fragSize = fragSizes[loop]; |
7091 | auto unroll = strategy.unroll[loop]; |
7092 | auto remainder = state.remainders[loop]; |
7093 | |
7094 | bool canEOT = !state.isNested && (op == COperation::UpdateStore); |
7095 | |
7096 | Label lEnd; |
7097 | |
7098 | // The "q" dimension is the one whose remainder we are currently handling. |
7099 | auto RegisterBlock::*nq = column ? &RegisterBlock::nc : &RegisterBlock::nr; |
7100 | auto RegisterBlock::*offsetQ |
7101 | = column ? &RegisterBlock::offsetC : &RegisterBlock::offsetR; |
7102 | |
7103 | // Status message. |
7104 | status << "C remainder handling (" << char('m' + column) << ") " << remType; |
    if (fragment) status << ", fragment";
    if (fragPositive) status << ", no empty accesses";
7107 | status << status_stream::endl; |
7108 | |
7109 | // Allocate temporaries for emulated atomic addition if needed. |
7110 | if (!inside && strategy.C.atomic) |
7111 | allocEAtomicAddRegs( |
7112 | hw, Tc_ext, layoutExt, problem.C, strategy.C, state); |
7113 | |
7114 | // Handle a subproblem. Return true if successful. |
7115 | auto descend = [&](vector<RegisterBlock> &sublayoutExt, |
7116 | vector<RegisterBlock> &sublayoutExtUnmasked, |
7117 | bool full = false) -> bool { |
7118 | bool success = true; |
7119 | auto nMasksOriginal = int(masks.size()); |
7120 | |
7121 | if (remType == StdCRemType::Mask) { |
7122 | if (!full) { |
7123 | // Assign and load any extra masks needed. |
7124 | if (!assignMasks( |
7125 | sublayoutExt, LoopM, LoopN, masks, strategy, state)) |
7126 | return false; |
7127 | loadMasks(masks, state.remainders, strategy, state, |
7128 | nMasksOriginal); |
7129 | sublayoutExtUnmasked.clear(); |
7130 | } else { |
7131 | // Clear out mask assignments in this dimension. |
7132 | for (auto &block : layoutExt) |
7133 | block.clearFlag(); |
7134 | } |
7135 | } |
7136 | |
7137 | // Recursively handle subproblem. |
7138 | if (!inside) |
7139 | success = doStdCRemainder(sublayoutExt, sublayoutExtUnmasked, true, |
7140 | columns, remTypes, fragments, fragPositives, fragSizes, |
7141 | C_addr0, C_addr0Unmasked, op, masks, problem, strategy, |
7142 | state); |
7143 | else if (sublayoutExtUnmasked.empty()) |
7144 | updateCLayout(sublayoutExt, C_addr0, state.C_layoutExt[0], op, |
7145 | problem, strategy, state); |
7146 | else |
7147 | updateCLayout(sublayoutExtUnmasked, C_addr0Unmasked, |
7148 | state.C_layoutExtUnmasked[0], op, problem, strategy, state); |
7149 | |
7150 | // Free any new masks. |
7151 | if (remType == StdCRemType::Mask) |
7152 | safeReleaseMaskAssignments(masks, state, nMasksOriginal); |
7153 | return success; |
7154 | }; |
7155 | |
7156 | // Exit remainder handling. |
7157 | auto done = [&]() { |
7158 | if (!canEOT) |
7159 | jmpi(1, lEnd); |
7160 | else |
7161 | epilogue(strategy, state); |
7162 | }; |
7163 | |
7164 | // Main code. |
7165 | bool success = false; |
7166 | pushStream(); |
7167 | |
7168 | if (!fragment) { |
7169 | // If descriptor-based remainders requested, all blocks should be smaller than fragSize. |
7170 | // Load descriptors based on total remainder in this (rare) case. |
7171 | if (remType == StdCRemType::Descriptor) { |
7172 | loadLoadStoreDescriptors(!problem.beta0(), true, layoutExt[0], |
7173 | remainder, problem.C, strategy.C, strategy, state); |
7174 | if (!assignAllDescs(layoutExt) |
7175 | || !assignAllDescs(layoutExtUnmasked)) |
7176 | goto failed; |
7177 | } |
7178 | if (inside && !layoutExtUnmasked.empty() |
7179 | && layoutExt.size() == state.C_layoutExt.size()) { |
7180 | // If unmasked layout is available, implement full remainder case specially. |
7181 | const bool useSIMTFlow = strategy.fused |
7182 | && (strategy.fusedLoop == loop |
7183 | || strategy.fusedLoop == LoopAny); |
7184 | Label labelRem, labelDone; |
7185 | |
7186 | if (useSIMTFlow) { |
7187 | cmp(16 | ge | state.flagAP, remainder, unroll); |
7188 | if_(16 | state.flagAP, labelRem, labelDone); |
7189 | } else if (strategy.fused) { |
7190 | cmp(1 | ge | state.flagAP, remainder, unroll); |
7191 | jmpi(1 | ~state.flagAP, labelRem); |
7192 | } else { |
7193 | // No flag registers guaranteed -- use a jump table. |
7194 | auto tempQ = state.ra.alloc_sub<uint64_t>(); |
7195 | auto temp = tempQ.ud(0); |
7196 | |
7197 | add(1 | sat, temp, remainder, -unroll + 1); |
7198 | isGen12 ? mad(1, temp, 16, temp, 16) : shl(1, temp, temp, 4); |
7199 | jmpi(1, temp.d()); |
7200 | jmpi(1, labelRem); |
7201 | |
7202 | state.ra.safeRelease(tempQ); |
7203 | } |
7204 | |
7205 | status << "Code for full " << char('m' + column) << " remainder" |
7206 | << status_stream::endl; |
7207 | if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed; |
7208 | |
7209 | useSIMTFlow ? else_(16, labelDone) : jmpi(1, labelDone); |
7210 | mark(labelRem); |
7211 | |
7212 | status << "Code for generic " << char('m' + column) << " remainder" |
7213 | << status_stream::endl; |
7214 | if (!descend(layoutExt, layoutExtUnmasked)) goto failed; |
7215 | |
7216 | mark(labelDone); |
7217 | if (useSIMTFlow) endif(16); |
7218 | } else { |
7219 | // Otherwise, nothing else to do: go down a level. |
7220 | if (!descend(layoutExt, layoutExtUnmasked)) goto failed; |
7221 | } |
7222 | } else { |
7223 | // Use SIMT control flow if remainders could be different between fused threads or if jump tables disabled. |
7224 | const bool useSIMTFlow = strategy.noJumpTables |
7225 | || (strategy.fused |
7226 | && (strategy.fusedLoop == loop |
7227 | || strategy.fusedLoop == LoopAny)); |
7228 | |
7229 | // Fix up fragment size (fragSize). |
7230 | // - Check that every block starts at a multiple of fragSize; if not fall back on fragSize 1. |
7231 | // - Max fragment size is 16. |
7232 | // - Should check unmasked layout, but it will have the same kind of fragmenting as the masked layout. |
7233 | fragSize = std::min<int>(fragSize, 16); |
7234 | for (auto &block : layoutExt) { |
7235 | if (block.*offsetQ % fragSize) { |
7236 | fragSize = 1; |
7237 | break; |
7238 | } |
7239 | } |
7240 | |
7241 | // There are two strategies for fragmenting for remainder handling: |
7242 | // fragSize = 1: Try to get the largest blocks as possible. These are always fragPositive. |
7243 | // fragSize > 1: Always use blocks of size fragSize in the q dimension. |
7244 | if (fragSize == 1) { |
7245 | if (!useSIMTFlow) { |
7246 | // SIMD control flow, using a jump table. |
7247 | Subregister temp = state.ra.alloc_sub<uint32_t>(); |
7248 | vector<Label> rlabels(unroll); |
7249 | |
7250 | // Generate jump table. |
7251 | shl(1, temp, remainder, |
7252 | uint16_t(4)); // Multiply by instruction length. |
7253 | if (isGen12) // Gen12+ jmpi is relative to current IP. |
7254 | add(1, temp, temp, uint16_t(16)); |
7255 | jmpi(1, temp.d()); // Indexed jump into jump table. |
7256 | for (int r = 0; r < unroll; r++) |
7257 | jmpi(1, rlabels[r]); |
7258 | |
7259 | // Full remainder case: continue downward. |
7260 | status << "Code for full " << char('m' + column) << " remainder" |
7261 | << status_stream::endl; |
7262 | if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed; |
7263 | inside ? jmpi(1, rlabels[0]) : done(); |
7264 | |
7265 | // Remainder handling. |
7266 | vector<bool> qdone(unroll, false); |
7267 | qdone[0] = true; |
7268 | int qnext = 0; |
7269 | for (int nqtodo = unroll - 2; nqtodo >= 0; nqtodo--) { |
7270 | // Decide which q to do. |
7271 | int q; |
7272 | if (qnext > 0) |
7273 | q = qnext; |
7274 | else { |
7275 | for (q = unroll - 1; q >= 0; q--) |
7276 | if (!qdone[q]) break; |
7277 | } |
7278 | |
7279 | status << "Code for " << char('m' + column) << " remainder " |
7280 | << q << status_stream::endl; |
7281 | |
7282 | mark(rlabels[q]); |
7283 | |
7284 | // Figure out how many rows/columns to take. |
7285 | int chunkSize = q & ~(q - 1); // = 1 << lowest set bit |
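                    // (e.g., q = 12 gives chunkSize = 4, covering rows/columns 8-11)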
7286 | |
7287 | // Look through all blocks in this row/column, and reduce chunk size if appropriate. |
7288 | for (auto &block : layoutExt) { |
7289 | if (!block.isLoadBlock()) |
7290 | stub(); // Dummy blocks should be replaced by real ones... |
7291 | int qq = q |
7292 | - block.*offsetQ; // Note q = 1 + last row/column. |
7293 | if (qq > 0 && qq <= block.*nq) |
7294 | chunkSize = std::min<int>(chunkSize, qq); |
7295 | } |
7296 | |
7297 | // With chunk size chosen, get rows/columns [q - chunkSize, q) of intersecting blocks. |
7298 | vector<RegisterBlock> C_subblocksExt, |
7299 | C_subblocksExtUnmasked; |
7300 | if (!getSubblocks(Tc_ext, C_subblocksExt, layoutExt, column, |
7301 | q - chunkSize, q, false, problem.C, strategy.C)) |
7302 | goto failed; |
7303 | if (!layoutExtUnmasked.empty()) |
7304 | if (!getSubblocks(Tc_ext, C_subblocksExtUnmasked, |
7305 | layoutExtUnmasked, column, q - chunkSize, q, |
7306 | false, problem.C, strategy.C)) |
7307 | goto failed; |
7308 | |
7309 | // Perform the requested update. |
7310 | if (!descend(C_subblocksExt, C_subblocksExtUnmasked)) |
7311 | goto failed; |
7312 | |
7313 | // Go to next remainder handler, or return. |
7314 | qdone[q] = true; |
7315 | qnext = q - chunkSize; |
7316 | if (nqtodo > 0) { |
7317 | if (qnext == 0 && canEOT) |
7318 | epilogue(strategy, state); |
7319 | else if (qdone[qnext]) { |
7320 | jmpi(1, rlabels[qnext]); |
7321 | qnext = 0; |
7322 | } |
7323 | } |
7324 | } |
7325 | mark(rlabels[0]); |
7326 | |
7327 | state.ra.safeRelease(temp); |
7328 | } else { |
7329 | // SIMT control flow: massively nested if-else. |
7330 | |
7331 | // Handle remainder in the range [q0, q1). |
7332 | std::function<bool(int, int)> handleRemainder |
7333 | = [&](int q0, int q1) -> bool { |
7334 | Label labelElse, labelEndif; |
7335 | |
7336 | int qChunk = rounddown_pow2(q1 - q0 - 1); |
7337 | |
7338 | if (qChunk == 0) qChunk = 1; |
7339 | |
7340 | status << "Code for " << char('m' + column) |
7341 | << " remainders " << q0 << " - " << (q1 - 1) |
7342 | << status_stream::endl; |
7343 | |
7344 | if (q1 - q0 > 1) { |
7345 | cmp(16 | ge | state.flagAP, remainder, |
7346 | uint16_t(q0 + qChunk)); |
7347 | if_(16 | state.flagAP, |
7348 | (qChunk > 1) ? labelElse : labelEndif, |
7349 | labelEndif); |
7350 | } |
7351 | |
7352 | vector<RegisterBlock> C_subblocksExt, |
7353 | C_subblocksExtUnmasked; |
7354 | if (!getSubblocks(Tc_ext, C_subblocksExt, layoutExt, column, |
7355 | q0, q0 + qChunk, false, problem.C, strategy.C)) |
7356 | return false; |
7357 | if (!layoutExtUnmasked.empty()) |
7358 | if (!getSubblocks(Tc_ext, C_subblocksExtUnmasked, |
7359 | layoutExtUnmasked, column, q0, q0 + qChunk, |
7360 | false, problem.C, strategy.C)) |
7361 | return false; |
7362 | |
7363 | if (!descend(C_subblocksExt, C_subblocksExtUnmasked)) |
7364 | return false; |
7365 | |
7366 | if (q1 - q0 > 1) { |
7367 | if (qChunk > 1) { |
7368 | if (!handleRemainder(q0 + qChunk, q1)) return false; |
7369 | |
7370 | else_(16, labelEndif); |
7371 | mark(labelElse); |
7372 | |
7373 | if (!handleRemainder(q0, q0 + qChunk)) return false; |
7374 | } |
7375 | |
7376 | mark(labelEndif); |
7377 | endif(16); |
7378 | } |
7379 | |
7380 | return true; |
7381 | }; |
7382 | |
7383 | Label labelRem, labelRemDone, labelDone; |
7384 | |
7385 | cmp(16 | ge | state.flagAP, remainder, uint16_t(unroll)); |
7386 | if_(16 | state.flagAP, labelRem, labelDone); |
7387 | |
7388 | status << "Code for " << char('m' + column) << " full remainder" |
7389 | << status_stream::endl; |
7390 | if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed; |
7391 | |
7392 | else_(16, labelDone); |
7393 | mark(labelRem); |
7394 | |
7395 | if (!handleRemainder(0, unroll)) goto failed; |
7396 | |
7397 | mark(labelDone); |
7398 | endif(16); |
7399 | setDefaultNoMask(true); |
7400 | } |
7401 | } else { |
7402 | auto handleRemainderFP = [&](int q0, int q1) -> bool { |
7403 | // Get rows/columns [q0, q1) of intersecting blocks. |
7404 | vector<RegisterBlock> C_subblocksExt, C_subblocksExtUnmasked; |
7405 | if (!getSubblocks(Tc_ext, C_subblocksExt, layoutExt, column, q0, |
7406 | q1, false, problem.C, strategy.C)) |
7407 | return false; |
7408 | if (!layoutExtUnmasked.empty()) |
7409 | if (!getSubblocks(Tc_ext, C_subblocksExtUnmasked, |
7410 | layoutExtUnmasked, column, q0, q1, false, |
7411 | problem.C, strategy.C)) |
7412 | return false; |
7413 | |
7414 | if (remType == StdCRemType::Descriptor) { |
7415 | // Load address registers for subsequent loads and stores. |
7416 | Subregister rcount = state.ra.alloc_sub<uint32_t>(); |
7417 | Subregister mremainder = remainder; |
7418 | |
7419 | if (q0 != 0) { |
7420 | add(1 | sat, rcount, mremainder, int16_t(-q0)); |
7421 | mremainder = rcount; |
7422 | } |
7423 | if (q1 < unroll) { |
7424 | min_(1, rcount, mremainder, uint16_t(fragSize)); |
7425 | mremainder = rcount; |
7426 | } |
7427 | |
7428 | loadLoadStoreDescriptors(!problem.beta0(), true, |
7429 | C_subblocksExt[0], mremainder, problem.C, |
7430 | strategy.C, strategy, state); |
7431 | if (!assignAllDescs(C_subblocksExt) |
7432 | || !assignAllDescs(C_subblocksExtUnmasked)) |
7433 | return false; |
7434 | |
7435 | state.ra.safeRelease(rcount); |
7436 | } |
7437 | |
7438 | // Perform the requested update. |
7439 | return descend(C_subblocksExt, C_subblocksExtUnmasked); |
7440 | }; |
7441 | |
7442 | if (!useSIMTFlow) { |
7443 | // SIMD control flow, possibly using a jump table. |
7444 | int N = div_up(unroll, fragSize); |
7445 | vector<Label> rlabels(N); // Targets for jump table. |
7446 | Label rdone; |
7447 | |
7448 | // Create a jump table, if needed. |
7449 | if (fragPositive) { |
7450 | Subregister t1 = state.ra.alloc_sub<uint32_t>(); |
7451 | Subregister t2 = state.ra.alloc_sub<uint32_t>(); |
7452 | |
7453 | add(1 | sat, t2, remainder, int16_t(-unroll + 1)); |
7454 | add(1, t1, remainder, |
7455 | int16_t(-1 + (isGen12 ? fragSize : 0))); |
7456 | add(1, t1, t1, |
7457 | t2); // Increment index if remainder == unroll. |
7458 | if (fragSize < 16) // Precondition: fragSize <= 16. |
                        mulConstant(1, t1, t1,
                                16 / fragSize); // Scale by jmpi instruction length (16 bytes, uncompacted).
7461 | and_(1, t1, t1, |
7462 | uint16_t(0xFFF0)); // Mask off unwanted bits. |
7463 | jmpi(1, t1.d()); // Indexed jump into jump table. |
7464 | for (int r = 0; r < N; r++) |
7465 | jmpi(1, rlabels[r]); |
7466 | |
7467 | state.ra.safeRelease(t2); |
7468 | state.ra.safeRelease(t1); |
7469 | } |
7470 | |
7471 | // Full loop. |
7472 | status << "Code for " << char('m' + column) << " full remainder" |
7473 | << status_stream::endl; |
7474 | if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed; |
7475 | inside ? jmpi(1, rdone) : done(); |
7476 | |
7477 | // Remainder handling. |
7478 | for (int r = N - 1; r >= 0; r--) { |
7479 | int q0 = r * fragSize; |
7480 | int q1 = std::min<int>(q0 + fragSize, unroll); |
7481 | |
7482 | status << "Code for " << char('m' + column) |
7483 | << " remainders " << q0 + 1 << " - " << q1 |
7484 | << status_stream::endl; |
7485 | |
7486 | mark(rlabels[r]); |
7487 | |
7488 | if (!handleRemainderFP(q0, q1)) goto failed; |
7489 | } |
7490 | |
7491 | if (inside) mark(rdone); |
7492 | } else { |
7493 | // SIMT control flow version. |
7494 | Label labelRem, labelRemDone, labelDone; |
7495 | |
7496 | cmp(16 | ge | state.flagAP, remainder, uint16_t(unroll)); |
7497 | if_(16 | state.flagAP, labelRem, labelDone); |
7498 | |
7499 | status << "Code for " << char('m' + column) << " full remainder" |
7500 | << status_stream::endl; |
7501 | if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed; |
7502 | |
7503 | else_(16, labelDone); |
7504 | mark(labelRem); |
7505 | |
7506 | for (int q0 = 0; q0 < unroll; q0 += fragSize) { |
7507 | int q1 = std::min<int>(q0 + fragSize, unroll); |
7508 | |
7509 | cmp(16 | le | state.flagAP, remainder, uint16_t(q0)); |
7510 | goto12(16 | state.flagAP, labelRemDone); |
7511 | status << "Code for " << char('m' + column) |
7512 | << " remainders " << q0 + 1 << " - " << q1 |
7513 | << status_stream::endl; |
7514 | |
7515 | if (!handleRemainderFP(q0, q1)) goto failed; |
7516 | } |
7517 | |
7518 | mark(labelRemDone); |
7519 | join(16); |
7520 | |
7521 | mark(labelDone); |
7522 | endif(16); |
7523 | } |
7524 | } |
7525 | } |
7526 | |
7527 | // Success! |
7528 | success = true; |
7529 | failed: |
7530 | |
7531 | mark(lEnd); |
7532 | success ? appendCurrentStream() : discardStream(); |
7533 | |
7534 | if (!inside && strategy.C.atomic) freeEAtomicAddRegs(state); |
7535 | |
7536 | return success; |
7537 | } |
7538 | |
7539 | // Alternate code path for C remainder handling, based on a simple double loop |
7540 | // and indirect addressing. |
7541 | template <HW hw> |
7542 | void gemm_kernel_generator_t<hw>::doAlternateCRemainder(COperation op, |
7543 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
7544 | auto Tc = problem.Tc, Tc_ext = problem.Tc_ext; |
7545 | int C_count = (op == COperation::UpdateStore) ? state.C_count : 1; |
7546 | #define FOR_EACH_C for (int q = 0; q < C_count; q++) |
7547 | #define FOR_EACH_C_REV for (int q = C_count - 1; q >= 0; q--) |
7548 | |
7549 | bool lateYLoopCheck = false; |
7550 | |
7551 | bool surface = !strategy.C.base.isStateless(); |
7552 | bool loadOnly = (op == COperation::Load); |
7553 | |
7554 | // Vector length in inner loop. |
7555 | const auto nbytes = 64; |
7556 | auto nec = nbytes / Tc; |
7557 | |
7558 | // 1- and 2-byte types must be padded to 4 bytes. |
7559 | bool byte_access = (Tc_ext.size() < 4); |
7560 | if (byte_access) nec = nbytes >> 2; |
7561 | |
    // Types of 8 bytes or more can use scattered qword messages (for now, only for atomic updates).
7563 | bool nativeAtomic = strategy.C.atomic |
7564 | && hasNativeAtomicAdd(hw, Tc_ext.real(), problem.C, strategy.C); |
7565 | bool qword = !((nativeAtomic ? Tc_ext.real() : Tc_ext).size() & 7) |
7566 | && strategy.C.atomic; |
7567 | int rshift = qword ? 3 : 2; // log2(data stride in regs) |
7568 | int rsimd = 64 >> rshift; |
7569 | |
7570 | auto &block0 = state.C_layout[0]; |
7571 | bool cColMajorMem = isColMajor(problem.C.layout); |
7572 | bool cColMajorReg = block0.colMajor; |
7573 | bool transpose = (cColMajorReg != cColMajorMem); |
7574 | if (isPacked(problem.C.layout)) stub(); |
7575 | |
7576 | // x is the contiguous dimension (in registers), y is the other dimension. |
7577 | auto LoopX = cColMajorReg ? LoopM : LoopN; |
7578 | auto LoopY = cColMajorReg ? LoopN : LoopM; |
7579 | int unrollX = strategy.unroll[LoopX]; |
7580 | int unrollY = strategy.unroll[LoopY]; |
7581 | |
    // Check the layout:
    //  - C is a contiguous block of registers.
    //  - nx must be divisible by 2 (unpacked) GRFs, unless the x unroll is
    //      less than 2 GRFs, or there's an extra GRF at the end of C.
    //  - Register offsets must lie on a uniform 2D grid.
    //  - All blocks must share the same ordering (row/column major).
    // Otherwise, use the non-uniform path and load GRFs indirectly.
7589 | |
7590 | auto Tcx = Tc; |
7591 | bool uniform = true; |
7592 | int16_t xByteInc = 0, yByteInc = 0; |
7593 | bool cAtEnd = (state.C_regs[0][state.C_regs[0].getLen() - 1].getBase() + 1) |
7594 | >= strategy.GRFs; |
7595 | |
7596 | if (state.C_regs[0].ranges.size() != 1) uniform = false; |
7597 | |
7598 | for (auto &block : state.C_layout) { |
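        // Infer the x/y byte strides of the register grid from each block's
        // size and offsets; any block that doesn't fit the inferred grid
        // forces the non-uniform (indirect lookup) path below.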
7599 | if (block.colMajor != block0.colMajor) stub(); |
7600 | |
7601 | int nx = cColMajorReg ? block.nr : block.nc; |
7602 | int ny = cColMajorReg ? block.nc : block.nr; |
7603 | int ox = cColMajorReg ? block.offsetR : block.offsetC; |
7604 | int oy = cColMajorReg ? block.offsetC : block.offsetR; |
7605 | |
7606 | ox /= nec; |
7607 | |
7608 | if ((nx & (nec - 1)) && cAtEnd) uniform = false; |
7609 | |
7610 | if (xByteInc == 0 && nx > nec) xByteInc = nec * Tcx; |
7611 | if (yByteInc == 0 && ny > 1) yByteInc = block.ld * Tc; |
7612 | |
7613 | if (block.offsetBytes != ox * xByteInc + oy * yByteInc) { |
7614 | if (xByteInc == 0 && ox > 0) |
7615 | xByteInc = (block.offsetBytes - oy * yByteInc) / ox; |
7616 | else if (yByteInc == 0 && oy > 0) |
7617 | yByteInc = (block.offsetBytes - ox * xByteInc) / oy; |
7618 | else |
7619 | uniform = false; |
7620 | } |
7621 | } |
7622 | |
7623 | GRFRange bases; |
7624 | bool nonuniformSubs = false; |
7625 | |
7626 | if (!uniform) { |
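        // Non-uniform path: build a table with one entry per (x, y) tile of C,
        // holding either its GRF base index (when all subregister offsets are
        // zero) or its full byte offset, for indirect register addressing.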
7627 | uint8_t baseIndices[256] = {0}; |
7628 | uint16_t offIndices[256] = {0}; |
7629 | |
7630 | if (state.Tacc.size() == 1) stub(); |
7631 | |
7632 | xByteInc = div_up(nec * Tcx, GRF::bytes(hw)); |
7633 | int nec1 = nec / xByteInc; |
7634 | yByteInc = div_up(unrollX, nec1); |
7635 | |
7636 | for (int y = 0; y < unrollY; y++) { |
7637 | for (int xx = 0; xx < yByteInc; xx++) { |
7638 | auto x = xx * nec1; |
7639 | auto i = cColMajorReg ? x : y; |
7640 | auto j = cColMajorReg ? y : x; |
7641 | const RegisterBlock *blockPtr; |
7642 | int ne; |
7643 | auto sub = findBlockReg(Tc, state.C_layout, i, j, |
7644 | state.C_regs[0], ne, blockPtr, 0); |
7645 | nonuniformSubs |= (sub.getOffset() != 0); |
7646 | if (ne < std::min(nec1, unrollX - x)) stub(); |
7647 | baseIndices[y * yByteInc + xx] = sub.getBase(); |
7648 | offIndices[y * yByteInc + xx] |
7649 | = sub.getByteOffset() + sub.getBase() * GRF::bytes(hw); |
7650 | } |
7651 | } |
7652 | |
7653 | if (nonuniformSubs) { |
7654 | xByteInc *= 2; |
7655 | yByteInc *= 2; |
7656 | } |
7657 | |
7658 | bases = state.ra.alloc_range( |
7659 | div_up(unrollY * yByteInc, GRF::bytes(hw))); |
7660 | bool haveDF = !strategy.emulate.emulate64; |
7661 | haveDF |= (hw == HW::XeHPC); |
7662 | if (haveDF) { |
7663 | for (int i = 0; i < unrollY * yByteInc; i += 8) { |
7664 | auto sub = bases[i / GRF::bytes(hw)].df( |
7665 | (i % GRF::bytes(hw)) / 8); |
7666 | auto data = nonuniformSubs |
7667 | ? reinterpret_cast<double *>(&offIndices[i / 2]) |
7668 | : reinterpret_cast<double *>(&baseIndices[i]); |
7669 | mov(1, sub, *data); |
7670 | } |
7671 | } else { |
7672 | for (int i = 0; i < unrollY * yByteInc; i += 4) { |
7673 | auto sub = bases[i / GRF::bytes(hw)].ud( |
7674 | (i % GRF::bytes(hw)) / 4); |
7675 | auto data = nonuniformSubs |
7676 | ? reinterpret_cast<uint32_t *>(&offIndices[i / 2]) |
7677 | : reinterpret_cast<uint32_t *>(&baseIndices[i]); |
7678 | mov(1, sub, *data); |
7679 | } |
7680 | } |
7681 | } |
7682 | |
7683 | // Claim flags. |
7684 | auto saveFlagAP = state.flagAP; |
7685 | state.raVFlag.safeRelease(state.flagAP); |
7686 | state.raVFlag.claim(f0[0]); |
7687 | state.raVFlag.claim(f0[1]); |
7688 | state.raVFlag.claim(f1[0]); |
7689 | |
7690 | // Clear f0[1] for any16h trick. |
7691 | if (strategy.fused && !lateYLoopCheck) mov(1, f0[1], uint16_t(0)); |
7692 | |
7693 | // Update C with scattered accesses. |
7694 | // Get mask and set up header. |
    GRFRange header[2]; // C address headers, one per C matrix.
7696 | auto hregs = (surface ? 1 : 2) * (qword ? 1 : 2); |
7697 | FOR_EACH_C header[q] = state.ra.alloc_range(hregs); |
7698 | Subregister temp = state.ra.alloc_sub<uint32_t>(); |
7699 | Subregister mask = state.ra.alloc_sub<uint32_t>(); |
7700 | Subregister xIndex = state.remainders[LoopX]; |
7701 | |
7702 | GRF indexVec, ivContig, ivScatter; |
7703 | |
7704 | indexVec = state.ra.alloc(); |
7705 | indexVec.setType(DataType::w); |
7706 | mov(8, indexVec[0](1), Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7)); |
7707 | if (rsimd > 8) |
7708 | mov(8, indexVec[8](1), Immediate::uv(8, 9, 10, 11, 12, 13, 14, 15)); |
7709 | |
7710 | auto oshift = std::min<int>(rshift, Tc_ext.log2Size()); |
7711 | |
7712 | // Prepare x mask in f1.0 and prepare header for loads/stores. |
7713 | if (Tc_ext.size() > 4) { |
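        // Each element spans multiple address slots; convert the x extent
        // into slot units.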
7714 | mulConstant(1, temp, xIndex, uint16_t(Tc_ext.size() >> rshift)); |
7715 | xIndex = temp; |
7716 | } |
7717 | |
7718 | ivScatter = indexVec; |
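    // If elements are wider than the address stride, split each scattered
    // offset into a scaled element index (ivScatter) plus a contiguous
    // sub-element byte offset (ivContig).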
7719 | bool splitScatter = transpose && (Tc_ext.log2Size() > rshift); |
7720 | if (splitScatter) { |
7721 | ivContig = state.ra.alloc(); |
7722 | ivContig.setType(DataType::w); |
7723 | auto shift = Tc_ext.log2Size() - rshift; |
7724 | auto m = (1 << shift) - 1; |
7725 | |
7726 | asr(16, ivScatter, indexVec, uint16_t(shift)); |
7727 | mov(16, ivContig, |
7728 | Immediate::uv((0 & m) << rshift, (1 & m) << rshift, |
7729 | (2 & m) << rshift, (3 & m) << rshift, (4 & m) << rshift, |
7730 | (5 & m) << rshift, (6 & m) << rshift, |
7731 | (7 & m) << rshift)); |
7732 | } |
7733 | |
7734 | add(1, temp, xIndex, int16_t(-1)); |
7735 | FOR_EACH_C transpose |
7736 | ? mul(rsimd, header[q][0].d(), state.inputs.ldc[q], ivScatter) |
7737 | : shl(rsimd, header[q][0].d(), indexVec, uint16_t(oshift)); |
7738 | FOR_EACH_C if (splitScatter) |
7739 | add(rsimd, header[q][0].d(), header[q][0].d(), ivContig); |
7740 | |
7741 | int hs = 1; |
    bool header4 = !qword && !surface; // Address header spans 4 GRFs.
7743 | int neq = elementsPerGRF(hw, DataType::uq); |
7744 | |
7745 | header4 &= (GRF::bytes(hw) < 64); |
7746 | if (hw >= HW::XeHP && !surface) { |
7747 | if (header4) |
7748 | FOR_EACH_C mov<uint32_t>(2 * neq, header[q][2][0](2), header[q][1]); |
7749 | FOR_EACH_C mov<uint32_t>(neq, header[q][1][0](2), header[q][0][neq](1)); |
7750 | FOR_EACH_C mov<uint32_t>(neq, header[q][0][0](2), header[q][0][0](1)); |
7751 | hs = 2; |
7752 | } |
7753 | |
7754 | and_(1, temp, ~temp, uint16_t(rsimd - 1)); |
7755 | FOR_EACH_C surface |
7756 | ? add(rsimd, header[q][0].d(), header[q][0].d(), state.effC[q]) |
7757 | : header4 ? eadd(8, header[q][2].uq(), header[q][hs].d(0)(hs), |
7758 | state.effC[q], strategy, state) |
7759 | : noop(); |
7760 | mov(1, mask, uint16_t((1 << rsimd) - 1)); |
7761 | FOR_EACH_C if (!surface) eadd(2 * neq, header[q][0].uq(), |
7762 | header[q][0].d(0)(hs), state.effC[q], strategy, state); |
7763 | shr(1, f1[0], mask, temp); |
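    // f1.0 now holds ((1 << rsimd) - 1) >> ((-xIndex) & (rsimd - 1)):
    // the lane mask covering the valid slots of the final partial x chunk.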
7764 | |
7765 | state.ra.safeRelease(mask); |
7766 | state.ra.safeRelease(temp); |
7767 | state.ra.safeRelease(ivContig); |
7768 | |
7769 | // Synthesize double loop updating 2 GRFs (indirectly addressed) at a time. |
7770 | GRF ix = state.ra.alloc(); |
7771 | Subregister ix_init = state.ra.alloc_sub<uint16_t>(); |
7772 | Subregister iy = state.ra.alloc_sub<int16_t>(); |
7773 | Subregister cXInc[2], cYInc[2]; |
7774 | FOR_EACH_C cYInc[q] = state.ra.alloc_sub<int32_t>(); |
7775 | Label yLoop, xLoop; |
7776 | GRFRange Cacc = state.ra.alloc_range(2); |
7777 | GRFRange CaccSwap {}; |
7778 | GRFRange Cload |
7779 | = state.ra.alloc_range(2, getHint(HintType::CLoad, strategy)); |
7780 | |
7781 | if (transpose) FOR_EACH_C { |
7782 | cXInc[q] = state.ra.alloc_sub<int32_t>(); |
7783 | mulConstant(1, cXInc[q], state.inputs.ldc[q], nec); |
7784 | } |
7785 | |
7786 | add(1, ix_init, state.remainders[LoopX], int16_t(-1)); |
7787 | mov(1, iy, state.remainders[LoopY]); |
7788 | shr(1, ix_init, ix_init, uint16_t(log2(nec))); |
7789 | |
7790 | if (uniform) |
7791 | mov(1, a0[0], state.C_regs[0][0].getBase() * GRF::bytes(hw)); |
7792 | else |
7793 | mov(1, a0[0], bases.getBase() * GRF::bytes(hw)); |
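    // a0.0 holds the byte offset of the current row: into the GRF file itself
    // (uniform case) or into the base/offset table (non-uniform case). The x
    // loop walks a copy in a0.1 by xByteInc; the y loop advances a0.0 by yByteInc.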
7794 | |
7795 | add(1, cYInc[0], ix_init, uint16_t(1)); |
7796 | mulConstant(1, cYInc[0], cYInc[0], |
7797 | uint16_t(nec * (!transpose ? Tc_ext.size() : 1))); |
7798 | if (!transpose) |
7799 | FOR_EACH_C_REV add(1, cYInc[q], -cYInc[0], state.inputs.ldc[q]); |
7800 | else { |
7801 | FOR_EACH_C_REV mul(1, cYInc[q], state.inputs.ldc[q], cYInc[0].w()); |
7802 | FOR_EACH_C_REV add(1, cYInc[q], -cYInc[q], uint16_t(Tc_ext.size())); |
7803 | } |
7804 | |
7805 | mark(yLoop); |
7806 | mov<uint16_t>(16, ix, ix_init); |
7807 | if (!lateYLoopCheck) add(1 | gt | f0[1], iy, iy, int16_t(-1)); |
7808 | mov(1, a0[1], a0[0]); |
7809 | |
7810 | mark(xLoop); |
7811 | add<int16_t>(16 | ge | f0[0], ix, ix, int16_t(-1)); |
7812 | |
    // Update. The anyv predicate is a trick to apply the generated x mask (f1.0) on the
    //  last iteration of the x loop, and no mask (0xFFFF) on the other iterations.
7815 | InstructionModifier mod; |
7816 | mod = mod | f0[0] | anyv; |
7817 | |
7818 | // Alas, no anyv on PVC. |
7819 | if (hw == HW::XeHPC) { |
7820 | mov(1 | ~f0[0], f0[0], f1[0]); |
7821 | mod = InstructionModifier() | f0[0]; |
7822 | } |
7823 | |
7824 | if (!uniform) { |
7825 | nonuniformSubs ? mov(xByteInc, a0[2](1), indirect[a0[1]].uw()) |
7826 | : shl(xByteInc, a0[2](1), indirect[a0[1]].ub(), |
7827 | GRF::log2Bytes(hw)); |
7828 | } |
7829 | |
7830 | if (!loadOnly) { |
7831 | if (uniform) switch (state.Tacc.size()) { |
7832 | case 1: mov<uint32_t>(16, Cacc, indirect[a0[1]].ub()); break; |
7833 | case 2: mov<uint32_t>(16, Cacc, indirect[a0[1]].uw()); break; |
7834 | default: mov<uint32_t>(16, Cacc, indirect[a0[1]]); break; |
7835 | } |
7836 | else if (xByteInc == 1) |
7837 | switch (state.Tacc.size()) { |
7838 | case 2: mov<uint32_t>(16, Cacc, indirect[a0[2]].uw()); break; |
7839 | default: mov<uint32_t>(16, Cacc, indirect[a0[2]]); break; |
7840 | } |
7841 | else |
7842 | switch (state.Tacc.size()) { |
7843 | case 2: |
7844 | mov<uint32_t>( |
7845 | 16, Cacc, indirect[a0[2]].uw(0)(16 / xByteInc, 1)); |
7846 | break; |
7847 | default: |
7848 | mov<uint32_t>( |
7849 | 16, Cacc, indirect[a0[2]].ud(0)(16 / xByteInc, 1)); |
7850 | break; |
7851 | } |
7852 | } |
7853 | |
7854 | if (strategy.C.atomic) { |
7855 | // Atomic update. Requires beta = 1, alpha prescaled. |
7856 | if (!problem.alpha1() && !problem.beta1()) stub(); |
7857 | if (C_count > 1) stub(); |
7858 | if (op != COperation::UpdateStore) stub(); |
7859 | |
7860 | std::vector<RegisterBlock> layout {1}; |
7861 | auto &block = layout[0]; |
7862 | block.ebytes = qword ? 8 : Tc_ext.real().size(); |
7863 | block.simdSize = rsimd; |
7864 | block.clearFlag(); |
7865 | block.bytes = 64; |
7866 | block.extra = 1; |
7867 | block.count = 1; |
7868 | block.log2GRFBytes = GRF::log2Bytes(hw); |
7869 | |
7870 | allocEAtomicAddRegs( |
7871 | hw, Tc_ext, layout, problem.C, strategy.C, state, f1[1]); |
7872 | |
7873 | Label labelEndAtomic; |
7874 | if_(16 | mod, labelEndAtomic); |
7875 | setDefaultNoMask(false); |
7876 | atomicAddMatrixBlock(Tc_ext, Cacc, block, problem.C, strategy.C, |
7877 | header[0], problem, strategy, state); |
7878 | setDefaultNoMask(true); |
7879 | mark(labelEndAtomic); |
7880 | endif(16); |
7881 | |
7882 | freeEAtomicAddRegs(state, f1[1]); |
7883 | } else { |
7884 | // Late C conversion, if needed. |
7885 | auto originalTacc = state.Tacc; |
7886 | if (problem.needsTsConvert() && state.Tacc != problem.Ts) { |
7887 | convert(Cacc, state.Tacc, problem.Ts, problem, strategy, state); |
7888 | state.Tacc = problem.Ts; |
7889 | } |
7890 | |
7891 | // Regular update. |
7892 | if (loadOnly || !problem.beta0()) { |
7893 | doReadSuppressionWA(strategy, state); |
7894 | if (strategy.C.newDP) { |
7895 | !byte_access ? load(16 | mod, Cload, D32 | strategy.C.cachingR, |
7896 | strategy.C.base, header[0]) |
7897 | : (Tc_ext.size() == 2) |
7898 | ? load(16 | mod, Cload, |
7899 | D16U32 | strategy.C.cachingR, |
7900 | strategy.C.base, header[0]) |
7901 | : load(16 | mod, Cload, |
7902 | D8U32 | strategy.C.cachingR, |
7903 | strategy.C.base, header[0]); |
7904 | } else { |
7905 | byte_access |
7906 | ? load(16 | mod, Cload, scattered_byte(Tc_ext.size()), |
7907 | strategy.C.base, header[0]) |
7908 | : !surface ? load(16 | mod, Cload, scattered_dword(), |
7909 | strategy.C.base, header[0]) |
7910 | : load(16 | mod, Cload, |
7911 | surface_dword(ChannelMask::r), |
7912 | strategy.C.base, header[0]); |
7913 | } |
7914 | } |
7915 | |
7916 | if (!loadOnly) { |
7917 | auto Tc_out = (op == COperation::UpdateStore) ? problem.Tc_ext |
7918 | : state.Tacc; |
7919 | if (!problem.beta0()) |
7920 | convert(Cload, problem.Tc_ext, state.Tacc, problem, strategy, |
7921 | state); |
7922 | updateC(Cacc, CaccSwap, Cload, problem, strategy, state); |
7923 | convert(Cacc, state.Tacc, Tc_out, problem, strategy, state); |
7924 | } |
7925 | |
7926 | if (op != COperation::UpdateStore) { |
7927 | auto src = (op == COperation::Load) ? Cload : Cacc; |
7928 | if (uniform) switch (Tc.size()) { |
7929 | case 1: |
7930 | mov<uint32_t>(16 | mod, indirect[a0[1]].ub(), src); |
7931 | break; |
7932 | case 2: |
7933 | mov<uint32_t>(16 | mod, indirect[a0[1]].uw(), src); |
7934 | break; |
7935 | default: |
7936 | mov<uint32_t>(16 | mod, indirect[a0[1]], src); |
7937 | break; |
7938 | } |
7939 | else if (xByteInc == 1) |
7940 | switch (state.Tacc.size()) { |
7941 | case 2: |
7942 | mov<uint32_t>(16 | mod, indirect[a0[2]].uw(), src); |
7943 | break; |
7944 | default: |
7945 | mov<uint32_t>(16 | mod, indirect[a0[2]], src); |
7946 | break; |
7947 | } |
7948 | else if (xByteInc == 2) |
7949 | switch (state.Tacc.size()) { |
7950 | case 2: |
7951 | mov<uint32_t>(8 | mod, indirect[a0[2]].uw(), src); |
7952 | mov<uint32_t>(8 | mod | M8, indirect[a0[3]].uw(), |
7953 | src.sub(hw, 8, DataType::ud)(1)); |
7954 | break; |
7955 | default: |
7956 | mov<uint32_t>(8 | mod, indirect[a0[2]].ud(), src); |
7957 | mov<uint32_t>(8 | mod | M8, indirect[a0[3]].ud(), |
7958 | src.sub(hw, 8, DataType::ud)(1)); |
7959 | break; |
7960 | } |
7961 | else |
7962 | stub(); |
7963 | } else |
7964 | FOR_EACH_C { |
7965 | if (strategy.C.newDP) { |
7966 | !byte_access ? store(16 | mod, D32 | strategy.C.cachingW, |
7967 | strategy.C.base, header[q], Cacc) |
7968 | : (Tc_ext.size() == 2) |
7969 | ? store(16 | mod, |
7970 | D16U32 | strategy.C.cachingW, |
7971 | strategy.C.base, header[q], Cacc) |
7972 | : store(16 | mod, |
7973 | D8U32 | strategy.C.cachingW, |
7974 | strategy.C.base, header[q], Cacc); |
7975 | } else { |
7976 | byte_access ? store(16 | mod, scattered_byte(Tc_ext.size()), |
7977 | strategy.C.base, header[q], Cacc) |
7978 | : !surface |
7979 | ? store(16 | mod, scattered_dword(), |
7980 | strategy.C.base, header[q], Cacc) |
7981 | : store(16 | mod, |
7982 | surface_dword(ChannelMask::r), |
7983 | strategy.C.base, header[q], Cacc); |
7984 | } |
7985 | } |
7986 | |
7987 | state.Tacc = originalTacc; |
7988 | } |
7989 | |
7990 | if (hw == HW::XeHPC) cmp<int16_t>(1 | ge | f0[0], ix, 0); |
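    // (The cmp above regenerates the x loop condition on XeHPC, where the
    //  anyv workaround overwrote f0[0] inside the loop body.)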
7991 | |
7992 | add(1, a0[1], a0[1], xByteInc); |
7993 | if (!transpose) { |
7994 | uint16_t inc = nec * Tc_ext; |
7995 | if (!surface) { |
7996 | FOR_EACH_C eadd<uint64_t>(std::min(2 * neq, rsimd), header[q][0], |
7997 | header[q][0], inc, strategy, state); |
7998 | if (header4) |
7999 | FOR_EACH_C eadd<uint64_t>( |
8000 | 8, header[q][2], header[q][2], inc, strategy, state); |
8001 | } else |
8002 | FOR_EACH_C add<uint32_t>(rsimd, header[q][0], header[q][0], inc); |
8003 | } else { |
8004 | if (!surface) { |
8005 | FOR_EACH_C eadd<uint64_t>(std::min(2 * neq, rsimd), header[q][0], |
8006 | header[q][0], cXInc[q], strategy, state); |
8007 | if (header4) |
8008 | FOR_EACH_C eadd<uint64_t>(8, header[q][2], header[q][2], |
8009 | cXInc[q], strategy, state); |
8010 | } else |
8011 | FOR_EACH_C add<uint32_t>( |
8012 | rsimd, header[q][0], header[q][0], cXInc[q]); |
8013 | } |
8014 | |
8015 | // Bottom of x loop. |
8016 | // Fused threads must use SIMT control flow instructions. |
8017 | strategy.fused ? simtDoWhileLoop(16 | f0[0], xLoop) |
8018 | : jmpi(1 | f0[0], xLoop); |
8019 | |
8020 | if (lateYLoopCheck) add(1 | gt | f0[1], iy, iy, int16_t(-1)); |
8021 | add(1, a0[0], a0[0], yByteInc); |
8022 | if (!surface) { |
8023 | FOR_EACH_C eadd<uint64_t>(std::min(2 * neq, rsimd), header[q][0], |
8024 | header[q][0], cYInc[q], strategy, state); |
8025 | if (header4) |
8026 | FOR_EACH_C eadd<uint64_t>( |
8027 | 8, header[q][2], header[q][2], cYInc[q], strategy, state); |
8028 | } else |
8029 | FOR_EACH_C add<uint32_t>(rsimd, header[q][0], header[q][0], cYInc[q]); |
8030 | |
8031 | // Bottom of y loop. |
8032 | // The any16h is a trick: only the lowest bit of f0[1] is updated when decrementing iy, |
8033 | // but we want to apply it to all channels. |
8034 | strategy.fused ? simtDoWhileLoop(16 | f0[1] | any16h, yLoop) |
8035 | : jmpi(1 | f0[1], yLoop); |
8036 | |
8037 | // Cleanup. |
8038 | state.raVFlag.release(f0[0]); |
8039 | state.raVFlag.release(f0[1]); |
8040 | state.raVFlag.release(f1[0]); |
8041 | state.ra.safeRelease(bases); |
8042 | |
8043 | state.ra.safeRelease(indexVec); |
8044 | state.ra.safeRelease(Cload); |
8045 | state.ra.safeRelease(CaccSwap); |
8046 | state.ra.safeRelease(Cacc); |
8047 | FOR_EACH_C state.ra.safeRelease(cXInc[q]); |
8048 | FOR_EACH_C state.ra.safeRelease(cYInc[q]); |
8049 | state.ra.safeRelease(iy); |
8050 | state.ra.safeRelease(ix); |
8051 | state.ra.safeRelease(ix_init); |
8052 | FOR_EACH_C state.ra.safeRelease(header[q]); |
8053 | |
8054 | state.flagAP = saveFlagAP; |
8055 | if (state.flagAP.isValid()) state.raVFlag.claim(state.flagAP); |
8056 | |
8057 | #undef FOR_EACH_C |
8058 | } |
8059 | |
8060 | // Prepare for GEMM k loop with m/n masked A/B accesses. Returns true if ka_lda/kb_ldb need recalculating. |
8061 | template <HW hw> |
8062 | bool gemm_kernel_generator_t<hw>::gemmPrepMaskedAB( |
8063 | const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
8064 | bool recalc = false; |
8065 | bool shrinkUK = false; |
8066 | if (!strategy.A.padded |
8067 | && (strategy.remHandling[LoopM] != RemainderHandling::Ignore)) { |
8068 | shrinkUK = true; |
8069 | if (strategy.ka_load > strategy.ka_load_masked) { |
8070 | status << "Downgrading ka_load: " << strategy.ka_load << " -> " |
8071 | << strategy.ka_load_masked << status_stream::endl; |
8072 | strategy.ka_load = strategy.ka_load_masked; |
8073 | strategy.kChain = gcd(strategy.kChain, strategy.ka_load); |
8074 | recalc = true; |
8075 | } |
8076 | // Avoid access patterns that can't be handled by masking. |
8077 | if (isBlock2D(strategy.A.accessType) || strategy.unroll[LoopM] == 1) |
8078 | noop(); |
8079 | else if (problem.A.layout == MatrixLayout::T |
8080 | && !isTransposing(strategy.A.accessType)) |
8081 | strategy.A.accessType = strategy.A.base.isStateless() |
8082 | ? AccessType::Scattered |
8083 | : AccessType::ChannelScattered; |
8084 | else if (problem.A.layout != MatrixLayout::T |
8085 | && isTransposing(strategy.A.accessType)) |
8086 | strategy.A.accessType = AccessType::Block; |
8087 | strategy.slmATrans = false; |
8088 | strategy.prefetchA = strategy.prefetchAMasked; |
8089 | } |
8090 | if (!strategy.B.padded |
8091 | && (strategy.remHandling[LoopN] != RemainderHandling::Ignore)) { |
8092 | shrinkUK = true; |
8093 | if (strategy.kb_load > strategy.kb_load_masked) { |
8094 | status << "Downgrading kb_load: " << strategy.kb_load << " -> " |
8095 | << strategy.kb_load_masked << status_stream::endl; |
8096 | strategy.kb_load = strategy.kb_load_masked; |
8097 | strategy.kChain = gcd(strategy.kChain, strategy.kb_load); |
8098 | recalc = true; |
8099 | } |
8100 | // Avoid access patterns that can't be handled by masking. |
8101 | if (isBlock2D(strategy.B.accessType) || strategy.unroll[LoopN] == 1) |
8102 | noop(); |
8103 | else if (problem.B.layout == MatrixLayout::N |
8104 | && !isTransposing(strategy.B.accessType)) |
8105 | strategy.B.accessType = strategy.B.base.isStateless() |
8106 | ? AccessType::Scattered |
8107 | : AccessType::ChannelScattered; |
8108 | else if (problem.B.layout != MatrixLayout::N |
8109 | && isTransposing(strategy.B.accessType)) |
8110 | strategy.B.accessType = AccessType::Block; |
8111 | strategy.slmBTrans = false; |
8112 | strategy.prefetchB = strategy.prefetchBMasked; |
8113 | } |
8114 | if (shrinkUK && (strategy.unrollK_masked > 0) |
8115 | && (strategy.unroll[LoopK] > strategy.unrollK_masked)) { |
8116 | status << "Downgrading k unroll: " << strategy.unroll[LoopK] << " -> " |
8117 | << strategy.unrollK_masked << status_stream::endl; |
8118 | strategy.unroll[LoopK] = strategy.unrollK_masked; |
8119 | } |
8120 | if (shrinkUK && (strategy.unrollKSLMMasked > 0) |
8121 | && (strategy.unrollKSLM > strategy.unrollKSLMMasked)) { |
8122 | status << "Downgrading SLM k chunk size: " << strategy.unrollKSLM |
8123 | << " -> " << strategy.unrollKSLMMasked << status_stream::endl; |
8124 | strategy.unrollKSLM = strategy.unrollKSLMMasked; |
8125 | } |
8126 | return recalc; |
8127 | } |
8128 | |
8129 | // Generate the GEMM kernel body. If it fails (due to excessive masking, say), return false. |
8130 | template <HW hw> |
8131 | bool gemm_kernel_generator_t<hw>::gemmBody( |
8132 | GEMMProblem problem, GEMMStrategy strategy, GEMMState state) { |
8133 | bool a2D = strategy.A.address2D; |
8134 | bool b2D = strategy.B.address2D; |
8135 | bool c2D = strategy.C.address2D; |
8136 | |
8137 | // Record whether we are in the first row/column for fused sum kernels. |
8138 | if (problem.sumA || problem.sumB) { |
8139 | if (problem.sumA && problem.sumB) stub(); |
8140 | auto &flags = state.inputs.flags; |
8141 | auto &y0 = problem.sumA ? state.j0 : state.i0; |
8142 | |
8143 | if (flags.isInvalid()) { |
8144 | flags = state.ra.alloc_sub<uint32_t>( |
8145 | getHint(HintType::LongTerm, strategy)); |
8146 | cmp(1 | eq | state.flagAP, flags, y0, 0); |
8147 | and_(1, flags, flags, FlagStoreSums); |
8148 | } else { |
8149 | cmp(1 | eq | state.flagAP, y0, 0); |
8150 | or_(1 | state.flagAP, flags, flags, FlagStoreSums); |
8151 | } |
8152 | } |
8153 | |
8154 | // Release variables that are no longer needed. |
8155 | bool saveIJ0 = false; |
8156 | saveIJ0 |= problem.hasBinaryPostOp(); |
8157 | if (!a2D && !c2D && !saveIJ0) state.ra.safeRelease(state.i0); |
8158 | if (!b2D && !c2D && !saveIJ0) state.ra.safeRelease(state.j0); |
8159 | if (!a2D && !b2D) state.ra.safeRelease(state.h0); |
8160 | if (!strategy.altCRemainder) releaseFusedRemainders(state); |
8161 | state.ra.safeRelease(state.remaindersWG[LoopM]); |
8162 | state.ra.safeRelease(state.remaindersWG[LoopN]); |
8163 | |
8164 | // If A/B are masked, check if we need to change ka_load/kb_load. If so, recalculate lda_ka/ldb_kb. |
8165 | if (gemmPrepMaskedAB(problem, strategy, state)) |
8166 | gemmCalcIncrements(problem, strategy, state); |
8167 | |
8168 | // Disable C prefetch in remainder handling if it needs masks/fragmenting. |
8169 | if (strategy.remHandling[LoopM] != RemainderHandling::Ignore |
8170 | || strategy.remHandling[LoopN] != RemainderHandling::Ignore) { |
8171 | if (strategy.C.base.isStateless() && !strategy.C.padded |
8172 | && strategy.prefetchC |
8173 | && !isBlock2D(strategy.C_prefetch.accessType)) { |
8174 | status << "Auto-disabling C prefetch in masked region" |
8175 | << status_stream::endl; |
8176 | strategy.prefetchC = 0; |
8177 | if (state.effCp != state.effC[0]) state.ra.safeRelease(state.effCp); |
8178 | } |
8179 | } |
8180 | |
8181 | // Try generating kernel body with current strategy. |
8182 | bool success = false; |
8183 | pushStream(); |
8184 | try { |
8185 | success = gemmBodyInternal(problem, strategy, state); |
8186 | } catch (...) { lastException = std::current_exception(); } |
8187 | success ? appendCurrentStream() : discardStream(); |
8188 | |
8189 | return success; |
8190 | } |
8191 | |
8192 | // Allocate nreg registers in chunks of a given size. |
8193 | static inline GRFMultirange chunkAlloc(int nreg, int chunk, Bundle hint, |
8194 | BundleGroup mask, CommonState &state) { |
8195 | GRFMultirange r; |
8196 | for (; nreg > 0; nreg -= chunk) { |
8197 | auto nr = std::min(nreg, chunk); |
8198 | r.ranges.push_back(state.ra.alloc_range(nr, hint, mask)); |
8199 | } |
8200 | return r; |
8201 | } |
8202 | |
8203 | static inline GRFMultirange chunkAlloc( |
8204 | int nreg, int chunk, Bundle hint, CommonState &state) { |
8205 | return chunkAlloc(nreg, chunk, hint, BundleGroup::AllBundles(), state); |
8206 | } |
8207 | |
8208 | // Allocate register layout in individual chunks. |
8209 | static inline GRFMultirange trySplitAlloc(HW hw, Type T, |
8210 | const vector<RegisterBlock> &layout, std::array<Bundle, 2> hints, |
8211 | BundleGroup mask, CommonState &state, int copies = 1) { |
8212 | auto oddHint = Bundle(0, 0).group_size(hw) * elementsPerGRF(hw, T); |
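    // Blocks whose leading row/column offset falls in an odd bundle group take
    // the second hint, so that neighboring chunks alternate register banks.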
8213 | |
8214 | GRFMultirange r; |
8215 | struct Request { |
8216 | int length, offset, index, hint; |
8217 | }; |
8218 | vector<Request> requests; |
8219 | requests.reserve(layout.size()); |
8220 | |
8221 | for (auto &block : layout) { |
8222 | if (block.isLoadBlock()) { |
8223 | int hint = ((block.colMajor ? block.offsetR : block.offsetC) |
8224 | & oddHint) |
8225 | != 0; |
8226 | requests.push_back({block.msgRegs, block.offsetReg(), 0, hint}); |
8227 | } |
8228 | } |
8229 | |
8230 | if (requests.empty() && !layout.empty()) |
8231 | for (auto &block : layout) { |
8232 | // No memory backing for layout. Split by rows/columns if possible. |
8233 | int hint = ((block.colMajor ? block.offsetR : block.offsetC) |
8234 | & oddHint) |
8235 | != 0; |
8236 | auto &ny = block.colMajor ? block.nc : block.nr; |
8237 | int xElems = (block.ld * block.crosspack * T); |
8238 | int xGRFs = xElems / elementsPerGRF(hw, T); |
8239 | if (xElems % elementsPerGRF(hw, T)) |
8240 | requests.push_back({block.nregs(), block.offsetReg(), 0, |
8241 | hint}); /* can't split */ |
8242 | else |
8243 | for (int y = 0, off = block.offsetReg(); y < ny; |
8244 | y += block.crosspack, off += xGRFs) |
8245 | requests.push_back({xGRFs, off, 0, hint}); |
8246 | } |
8247 | |
8248 | // Figure out which order the ranges belong in. |
8249 | std::sort(requests.begin(), requests.end(), |
8250 | [](const Request &r1, const Request &r2) { |
8251 | return (r1.offset < r2.offset); |
8252 | }); |
8253 | for (size_t i = 0; i < requests.size(); i++) |
8254 | requests[i].index = int(i); |
8255 | |
8256 | // Sort again and allocate largest to smallest. |
8257 | std::sort(requests.begin(), requests.end(), |
8258 | [](const Request &r1, const Request &r2) { |
8259 | return (r1.length > r2.length) |
8260 | || (r1.length == r2.length && r1.offset < r2.offset); |
8261 | }); |
8262 | r.ranges.resize(requests.size() * copies); |
8263 | |
8264 | bool ok = true; |
8265 | for (size_t i = 0; i < requests.size(); i++) { |
8266 | for (int c = 0; c < copies; c++) { |
8267 | auto newRange = state.ra.try_alloc_range( |
8268 | requests[i].length, hints[requests[i].hint], mask); |
8269 | r.ranges[requests[i].index + c * requests.size()] = newRange; |
8270 | ok &= newRange.isValid(); |
8271 | } |
8272 | } |
8273 | |
8274 | if (!ok) { |
8275 | for (auto &rr : r.ranges) |
8276 | state.ra.release(rr); |
8277 | r.ranges.clear(); |
8278 | } |
8279 | |
8280 | return r; |
8281 | } |
8282 | |
8283 | static inline GRFMultirange splitAlloc(HW hw, Type T, |
8284 | const vector<RegisterBlock> &layout, std::array<Bundle, 2> hints, |
8285 | BundleGroup mask, CommonState &state, int copies = 1) { |
8286 | auto r = trySplitAlloc(hw, T, layout, hints, mask, state, copies); |
8287 | if (r.empty() && !layout.empty()) throw out_of_registers_exception(); |
8288 | return r; |
8289 | } |
8290 | |
8291 | // Allocate register ranges for A/B/C. |
8292 | template <HW hw> |
8293 | void gemm_kernel_generator_t<hw>::gemmAllocRegs( |
8294 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
8295 | // Summary: order of allocations is important. |
8296 | auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc; |
8297 | |
8298 | auto A_copies = strategy.A_copies; |
8299 | auto B_copies = strategy.B_copies; |
8300 | int A_regCount = getRegCount(state.A_layout); |
8301 | int Ar_regCount = getRegCount(state.Ar_layout); |
8302 | int B_regCount = getRegCount(state.B_layout); |
8303 | int Br_regCount = getRegCount(state.Br_layout); |
8304 | int C_regCountPerBuffer = getRegCount(state.C_layout); |
8305 | int C_regCount = state.C_buffers * C_regCountPerBuffer; |
8306 | GRFMultirange C_regs; |
8307 | |
8308 | bool globalCM = isLayoutColMajor(state.C_layout); |
8309 | |
8310 | auto hintA0 = globalCM ? HintType::A0 : HintType::A0Broadcast; |
8311 | auto hintB0 = !globalCM ? HintType::B0 : HintType::B0Broadcast; |
8312 | |
8313 | auto Tv = globalCM ? Ta : Tb; |
8314 | auto Tn = !globalCM ? Ta : Tb; |
8315 | |
8316 | auto &V_layout = globalCM ? state.A_layout : state.B_layout; |
8317 | auto &Vr_layout = globalCM ? state.Ar_layout : state.Br_layout; |
8318 | auto &V_regs = globalCM ? state.A_regs : state.B_regs; |
8319 | auto &Vr_regs = globalCM ? state.Ar_regs : state.Br_regs; |
8320 | auto V_copies = globalCM ? A_copies : B_copies; |
8321 | auto V_regCount = globalCM ? A_regCount : B_regCount; |
8322 | auto Vr_regCount = globalCM ? Ar_regCount : Br_regCount; |
8323 | auto &N_layout = !globalCM ? state.A_layout : state.B_layout; |
8324 | auto &Nr_layout = !globalCM ? state.Ar_layout : state.Br_layout; |
8325 | auto &N_regs = !globalCM ? state.A_regs : state.B_regs; |
8326 | auto &Nr_regs = !globalCM ? state.Ar_regs : state.Br_regs; |
8327 | auto N_copies = !globalCM ? A_copies : B_copies; |
8328 | auto N_regCount = !globalCM ? A_regCount : B_regCount; |
8329 | auto Nr_regCount = !globalCM ? Ar_regCount : Br_regCount; |
8330 | |
8331 | const auto &C_layout = state.C_layout; |
8332 | const auto &C_layoutExt = state.C_layoutExtUnmasked.empty() |
8333 | ? state.C_layoutExt |
8334 | : state.C_layoutExtUnmasked; |
8335 | |
8336 | int C_chunk = state.copyC ? 1 : getMaxLoadBlock(C_layoutExt); |
8337 | C_chunk = alignup_pow2(C_chunk, Bundle(0, 0).group_size(hw) * 2); |
8338 | if (strategy.systolic) C_chunk = std::max(C_chunk, 8); |
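    // C is allocated in chunks at least as large as the largest C load/store
    // block, so that each C access can stay within a contiguous range.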
8339 | |
8340 | state.C_accCount = strategy.cAccumulators |
8341 | ? AccumulatorRegister::count(hw, strategy.GRFs, Tc.ngen()) |
8342 | : 0; |
8343 | |
8344 | state.A_regs.resize(A_copies); |
8345 | state.B_regs.resize(B_copies); |
8346 | |
8347 | switch (strategy.registerScheme) { |
8348 | case GEMMStrategy::CSeparate: { |
8349 | // Standard allocation (Gen9-11). A and B allocated together in lower half of registers. |
8350 | // Interleave allocation of A and B to minimize wasted registers. Test the waters to find out |
8351 | // whether to try bank 0 or 1 first. |
8352 | int bases[2]; |
8353 | for (int bank = 0; bank < 2; bank++) { |
8354 | auto r = state.ra.alloc_range(4, Bundle(bank, Bundle::any)); |
8355 | bases[bank] = r.getBase(); |
8356 | state.ra.safeRelease(r); |
8357 | } |
8358 | |
8359 | // Order of the banks. |
8360 | int banks[2]; |
8361 | banks[0] = (bases[1] < bases[0]) ? 1 : 0; |
8362 | banks[1] = 1 - banks[0]; |
8363 | |
8364 | // Allocate all the registers needed from bank 0, then all the registers needed from bank 1. |
8365 | for (int bank : banks) { |
8366 | if (getHint(hintA0, strategy).bank_id == bank) { |
8367 | for (int copy = 0; copy < A_copies; copy++) |
8368 | state.A_regs[copy] = state.ra.alloc_range( |
8369 | A_regCount, getHint(hintA0, strategy)); |
8370 | if (state.broadcast && !globalCM) |
8371 | state.broadcast_regs = state.ra.alloc_range( |
8372 | 2, getHint(hintA0, strategy)); |
8373 | if (Ar_regCount > 0) |
8374 | state.Ar_regs = state.ra.alloc_range( |
8375 | Ar_regCount, getHint(hintA0, strategy)); |
8376 | } |
8377 | |
8378 | if (getHint(hintB0, strategy).bank_id == bank) { |
8379 | for (int copy = 0; copy < B_copies; copy++) |
8380 | state.B_regs[copy] = state.ra.alloc_range( |
8381 | B_regCount, getHint(hintB0, strategy)); |
8382 | if (state.broadcast && globalCM) |
8383 | state.broadcast_regs = state.ra.alloc_range( |
8384 | 2, getHint(hintB0, strategy)); |
8385 | if (Br_regCount > 0) |
8386 | state.Br_regs = state.ra.alloc_range( |
8387 | Br_regCount, getHint(hintB0, strategy)); |
8388 | } |
8389 | } |
8390 | |
8391 | C_regs = state.ra.alloc_range(C_regCount - state.C_accCount, |
8392 | getHint(HintType::C, strategy)); |
8393 | break; |
8394 | } |
8395 | case GEMMStrategy::ACB: |
8396 | if (state.broadcast && !globalCM) |
8397 | state.broadcast_regs |
8398 | = state.ra.alloc_range(2, getHint(hintA0, strategy)); |
8399 | |
8400 | for (int copy = 0; copy < A_copies; copy++) |
8401 | state.A_regs[copy] = state.ra.alloc_range( |
8402 | A_regCount, getHint(hintA0, strategy)); |
8403 | if (Ar_regCount > 0) |
8404 | state.Ar_regs = state.ra.alloc_range( |
8405 | Ar_regCount, getHint(hintA0, strategy)); |
8406 | |
8407 | C_regs = state.ra.alloc_range(C_regCount - state.C_accCount, |
8408 | getHint(HintType::C, strategy)); |
8409 | |
8410 | for (int copy = 0; copy < B_copies; copy++) |
8411 | state.B_regs[copy] = state.ra.alloc_range( |
8412 | B_regCount, getHint(hintB0, strategy)); |
8413 | if (Br_regCount > 0) |
8414 | state.Br_regs = state.ra.alloc_range( |
8415 | Br_regCount, getHint(hintB0, strategy)); |
8416 | |
8417 | if (state.broadcast && globalCM) |
8418 | state.broadcast_regs |
8419 | = state.ra.alloc_range(2, getHint(hintB0, strategy)); |
8420 | break; |
8421 | case GEMMStrategy::BCA: |
8422 | if (state.broadcast && !globalCM) |
8423 | state.broadcast_regs |
8424 | = state.ra.alloc_range(2, getHint(hintA0, strategy)); |
8425 | |
8426 | for (int copy = 0; copy < B_copies; copy++) |
8427 | state.B_regs[copy] = state.ra.alloc_range( |
8428 | B_regCount, getHint(hintB0, strategy)); |
8429 | if (Br_regCount > 0) |
8430 | state.Br_regs = state.ra.alloc_range( |
8431 | Br_regCount, getHint(hintB0, strategy)); |
8432 | |
8433 | C_regs = state.ra.alloc_range(C_regCount - state.C_accCount, |
8434 | getHint(HintType::C, strategy)); |
8435 | |
8436 | for (int copy = 0; copy < A_copies; copy++) |
8437 | state.A_regs[copy] = state.ra.alloc_range( |
8438 | A_regCount, getHint(hintA0, strategy)); |
8439 | if (Ar_regCount > 0) |
8440 | state.Ar_regs = state.ra.alloc_range( |
8441 | Ar_regCount, getHint(hintA0, strategy)); |
8442 | |
8443 | if (state.broadcast && globalCM) |
8444 | state.broadcast_regs |
8445 | = state.ra.alloc_range(2, getHint(hintB0, strategy)); |
8446 | break; |
8447 | case GEMMStrategy::VNC: { |
8448 | if (hw < HW::Gen12LP) stub(); |
8449 | |
8450 | // Gen12+. Assign non-broadcast input matrix (V), then broadcast input matrix (N), then C. |
8451 | auto unrollVBytes |
8452 | = strategy.unroll[globalCM ? LoopM : LoopN] * Tv.size(); |
8453 | auto unrollNBytes |
8454 | = strategy.unroll[globalCM ? LoopN : LoopM] * Tn.size(); |
8455 | auto regUnrollV = div_up(unrollVBytes, GRF::bytes(hw)); |
8456 | auto regUnrollN = div_up(unrollNBytes, GRF::bytes(hw)); |
8457 | auto hintV = getHint(HintType::A0, strategy); |
8458 | auto hintN = getHint( |
8459 | (regUnrollN == 1) ? HintType::A0 : HintType::A0Broadcast, |
8460 | strategy); // Put V and N in same bundle if we can avoid N<->C conflicts. |
8461 | auto hintC = getHint(HintType::C, strategy); |
8462 | GRFRange tempPadding; |
8463 | |
8464 | for (int copy = 0; copy < V_copies; copy++) |
8465 | V_regs[copy] = state.ra.alloc_range(V_regCount, hintV); |
8466 | if (Vr_regCount > 0) |
8467 | Vr_regs = state.ra.alloc_range(Vr_regCount, hintV); |
8468 | |
8469 | N_regs[0] = state.ra.alloc_range(N_regCount, hintN); |
8470 | |
8471 | // Check if A * B outer product 0 has a bank conflict. If so, move N to avoid this. |
8472 | auto stride = Bundle(0, 0).stride(hw); |
8473 | auto offN = (N_regs[0][0].getBase() - V_regs[0][0].getBase()) |
8474 | & (stride - 1); |
8475 | auto offNMin = offN - ((regUnrollV - 1) & ~1); |
8476 | auto offNMax = offN + regUnrollN - 1; |
8477 | if (offNMax >= stride) offNMax -= stride, offNMin -= stride; |
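            // offNMin/offNMax bound N's bundle offset relative to V over the
            // first outer product; if the window can reach offset 0 (a
            // conflict), pad the allocation to push N clear of V.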
8478 | if (offNMin <= 0) { |
8479 | unsigned obAlign = Bundle(0, 0).group_size(hw); |
8480 | if (hintN.bank_id != Bundle::any) obAlign *= 2; |
8481 | offNMax = alignup_pow2(offNMax, obAlign); |
8482 | safeReleaseRanges(N_regs[0], state); |
8483 | tempPadding = state.ra.alloc_range(offNMax, hintN); |
8484 | N_regs[0] = state.ra.alloc_range(N_regCount, hintN); |
8485 | } |
8486 | |
8487 | for (int copy = 1; copy < N_copies; copy++) |
8488 | N_regs[copy] = state.ra.alloc_range(N_regCount, hintN); |
8489 | if (Nr_regCount > 0) |
8490 | Nr_regs = state.ra.alloc_range(Nr_regCount, hintN); |
8491 | |
8492 | state.ra.safeRelease(tempPadding); |
8493 | |
8494 | C_regs = state.ra.alloc_range(C_regCount - state.C_accCount, hintC); |
8495 | break; |
8496 | } |
8497 | case GEMMStrategy::ABInterleave: { |
8498 | // Gen12+. Interleave A and B, place C afterward. |
8499 | if (hw < HW::Gen12LP) stub(); |
8500 | auto chunk = Bundle(0, 0).stride(hw) >> 1; |
8501 | |
8502 | // Test allocation. Put A earlier if it has more registers. |
8503 | int A_regTotal = A_regCount * A_copies + Ar_regCount; |
8504 | int B_regTotal = B_regCount * B_copies + Br_regCount; |
8505 | auto hintA = getHint(HintType::A0, strategy); |
8506 | auto hintB = getHint(HintType::B0, strategy); |
8507 | auto hintC = getHint(HintType::C, strategy); |
8508 | auto testA = state.ra.alloc_range(8, hintA); |
8509 | auto testB = state.ra.alloc_range(8, hintB); |
8510 | if ((testA.getBase() < testB.getBase()) |
8511 | == (A_regTotal < B_regTotal)) |
8512 | std::swap(hintA, hintB); |
8513 | state.ra.safeRelease(testA); |
8514 | state.ra.safeRelease(testB); |
8515 | |
8516 | for (int copy = 0; copy < A_copies; copy++) |
8517 | state.A_regs[copy] |
8518 | = chunkAlloc(A_regCount, chunk, hintA, state); |
8519 | if (Ar_regCount > 0) |
8520 | state.Ar_regs = chunkAlloc(Ar_regCount, chunk, hintA, state); |
8521 | for (int copy = 0; copy < B_copies; copy++) |
8522 | state.B_regs[copy] |
8523 | = chunkAlloc(B_regCount, chunk, hintB, state); |
8524 | if (Br_regCount > 0) |
8525 | state.Br_regs = chunkAlloc(Br_regCount, chunk, hintB, state); |
8526 | C_regs = state.ra.alloc_range(C_regCount - state.C_accCount, hintC); |
8527 | break; |
8528 | } |
8529 | case GEMMStrategy::NSeparate: { |
            // Broadcast matrix (N) has dedicated bundle(s) (both banks);
            // V and C start in opposite banks in the other bundles.
8532 | if (hw < HW::Gen12LP) stub(); |
8533 | if (state.C_accCount > 0) stub(); |
8534 | |
8535 | int bundles = Bundle::bundle_count(hw) * Bundle::bank_count(hw); |
8536 | int bregsConsecutive = Bundle(0, 0).group_size(hw); |
8537 | int bregs = strategy.GRFs / bundles; |
8538 | int N_chunk = getMaxLoadBlock(N_layout); |
8539 | int N_nregs = Nr_regCount + N_regCount * N_copies; |
8540 | int N_nbundles = std::max( |
8541 | div_up(N_chunk, bregsConsecutive), div_up(N_nregs, bregs)); |
8542 | BundleGroup N_bundles(hw), VC_bundles(hw); |
8543 | |
8544 | auto hintV0 = getHint(HintType::A0, strategy); |
8545 | auto hintV1 = getHint(HintType::A1, strategy); |
8546 | auto hintN = getHint(HintType::A0Broadcast, strategy); |
8547 | auto hintC0 = getHint(HintType::C, strategy); |
8548 | auto hintC1 = getHint(HintType::C1, strategy); |
8549 | |
8550 | // Give bundles starting at the end to broadcast matrix. |
8551 | for (int bundle = Bundle::bundle_count(hw) - 1; bundle >= 0; |
8552 | bundle--) { |
8553 | for (int bank = Bundle::bank_count(hw) - 1; bank >= 0; bank--) { |
8554 | if (N_nbundles-- > 0) |
8555 | N_bundles |= Bundle(bank, bundle); |
8556 | else |
8557 | VC_bundles |= Bundle(bank, bundle); |
8558 | } |
8559 | } |
8560 | |
8561 | for (int copy = 0; copy < V_copies; copy++) |
8562 | V_regs[copy] = splitAlloc( |
8563 | hw, Tv, V_layout, {hintV0, hintV1}, VC_bundles, state); |
8564 | if (Vr_regCount > 0) |
8565 | Vr_regs = splitAlloc( |
8566 | hw, Tv, Vr_layout, {hintV0, hintV1}, VC_bundles, state); |
8567 | if (!strategy.systolic) |
8568 | C_regs = trySplitAlloc(hw, Tc, C_layout, {hintC0, hintC1}, |
8569 | VC_bundles, state, state.C_buffers); |
8570 | if (C_regs.empty()) |
8571 | C_regs = chunkAlloc( |
8572 | C_regCount, C_chunk, hintC0, VC_bundles, state); |
8573 | for (int copy = 0; copy < N_copies; copy++) |
8574 | N_regs[copy] = splitAlloc( |
8575 | hw, Tn, N_layout, {hintN, hintN}, N_bundles, state); |
8576 | if (Nr_regCount > 0) |
8577 | Nr_regs = splitAlloc( |
8578 | hw, Tn, Nr_layout, {hintN, hintN}, N_bundles, state); |
8579 | break; |
8580 | } |
8581 | case GEMMStrategy::VAvoid: { |
            // Broadcast matrix (N) has a dedicated starting bank.
            // V and C share starting banks, but C allocations are chosen to
            // miss the matching V allocations.
8584 | auto hintV = getHint(HintType::A0, strategy); |
8585 | auto hintN = getHint(HintType::A0Broadcast, strategy); |
8586 | auto hintC = getHint(HintType::C, strategy); |
8587 | |
8588 | for (int copy = 0; copy < N_copies; copy++) |
8589 | N_regs[copy] = state.ra.alloc_range(N_regCount, hintN); |
8590 | if (Nr_regCount > 0) |
8591 | Nr_regs = state.ra.alloc_range(Nr_regCount, hintN); |
8592 | |
8593 | for (int copy = 0; copy < V_copies; copy++) |
8594 | V_regs[copy] = state.ra.alloc_range(V_regCount, hintV); |
8595 | if (Vr_regCount > 0) |
8596 | Vr_regs = state.ra.alloc_range(Vr_regCount, hintV); |
8597 | |
8598 | int nv; |
8599 | const RegisterBlock *V_block; |
8600 | int V_rows, V_cols; |
8601 | getLayoutDims( |
8602 | Vr_regCount > 0 ? Vr_layout : V_layout, V_rows, V_cols); |
8603 | int kv = globalCM ? V_cols : V_rows; |
8604 | |
8605 | int minOPCount = minOuterProductCount(hw, problem, strategy); |
8606 | int lastMN0 = -1; |
8607 | int sliceRegs = 0; |
8608 | BundleGroup V_bundles(hw); |
8609 | |
8610 | vector<GRFMultirange> C_extra(state.C_buffers - 1); |
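            // Allocate C registers one x-slice at a time, steering each slice
            // away from the bundles holding the V registers it multiplies with
            // (C_bundles = ~V_bundles below).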
8611 | auto allocSlice = [&]() { |
8612 | if (sliceRegs <= 0) return; |
8613 | auto C_bundles = ~V_bundles; |
8614 | |
8615 | C_regs.append(chunkAlloc( |
8616 | sliceRegs, C_chunk, hintC, C_bundles, state)); |
8617 | for (int copy = 1; copy < state.C_buffers; copy++) |
8618 | C_extra[copy - 1].append(chunkAlloc( |
8619 | sliceRegs, C_chunk, hintC, C_bundles, state)); |
8620 | |
8621 | sliceRegs = 0; |
8622 | }; |
8623 | |
8624 | for (const auto &block : C_layout) { |
8625 | int mn0 = globalCM ? block.offsetR : block.offsetC; |
8626 | if (mn0 == lastMN0) { |
8627 | sliceRegs += block.nregs(); |
8628 | continue; |
8629 | } |
8630 | |
8631 | allocSlice(); |
8632 | |
8633 | V_bundles = BundleGroup(hw); |
8634 | for (int h0 = 0; h0 < kv; h0 += minOPCount) { |
8635 | int r = globalCM ? mn0 : h0; |
8636 | int c = globalCM ? h0 : mn0; |
8637 | int comp = 0; |
8638 | if (Vr_regCount == 0) |
8639 | for (int copy = 0; copy < V_copies; copy++) { |
8640 | auto V0 = findBlockReg(Tv, V_layout, r, c, |
8641 | V_regs[copy], nv, V_block, 0, comp); |
8642 | V_bundles |= Bundle::locate(hw, V0); |
8643 | } |
8644 | else { |
8645 | auto V0 = findBlockReg(Tv, Vr_layout, r, c, Vr_regs, nv, |
8646 | V_block, 0, comp); |
8647 | V_bundles |= Bundle::locate(hw, V0); |
8648 | } |
8649 | } |
8650 | |
8651 | lastMN0 = mn0; |
8652 | sliceRegs = block.nregs(); |
8653 | } |
8654 | |
8655 | allocSlice(); |
8656 | |
8657 | for (int copy = 1; copy < state.C_buffers; copy++) |
8658 | C_regs.append(C_extra[copy - 1]); |
8659 | } |
8660 | } |
8661 | |
8662 | // Assign C_regs, adding in GRFs (in place of accumulators) to use later. |
8663 | state.C_regs.resize(state.C_buffers); |
8664 | |
8665 | auto it = C_regs.ranges.begin(); |
8666 | int off = -state.C_accCount; |
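    // off starts at -C_accCount so the first C_accCount register slots, which
    // live in accumulators rather than GRFs, are counted into the first
    // buffer's range.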
8667 | for (int buf = 0; buf < state.C_buffers; buf++) { |
8668 | for (int todo = C_regCountPerBuffer; todo > 0;) { |
8669 | if (it == C_regs.ranges.end()) |
8670 | throw std::runtime_error("Not enough C registers allocated." ); |
8671 | int left = it->getLen() - off; |
8672 | int take = std::min(left, todo); |
8673 | state.C_regs[buf].ranges.push_back( |
8674 | GRFRange(it->getBase() + off, take)); |
8675 | todo -= take; |
8676 | off += take; |
8677 | if (off >= it->getLen()) off = 0, it++; |
8678 | } |
8679 | } |
8680 | |
8681 | // Allocate registers for SLM copies. |
8682 | state.Ai_regs.resize(strategy.slmCopies); |
8683 | state.Bi_regs.resize(strategy.slmCopies); |
8684 | if (strategy.slmA) |
8685 | for (int q = 0; q < strategy.slmCopies; q++) |
8686 | state.Ai_regs[q] = state.ra.alloc_range(state.Ai_regCount); |
8687 | if (strategy.slmB) |
8688 | for (int q = 0; q < strategy.slmCopies; q++) |
8689 | state.Bi_regs[q] = state.ra.alloc_range(state.Bi_regCount); |
8690 | |
8691 | // Allocate registers for A/B sums. |
8692 | state.As_regs = state.ra.alloc_range(getRegCount(state.As_layout)); |
8693 | state.Bs_regs = state.ra.alloc_range(getRegCount(state.Bs_layout)); |
8694 | |
8695 | // Allocate registers for A/B prefetch. |
8696 | state.Ap_regs = state.ra.alloc_range(getRegCount(state.Ap_layout)); |
8697 | state.Bp_regs = state.ra.alloc_range(getRegCount(state.Bp_layout)); |
8698 | |
8699 | // Allocate multiplication temporaries for Gen9 IGEMM, in pairs. |
8700 | if (isGen9IGEMM(hw, Ta, Tb, Tc)) { |
8701 | auto &temps = state.tempMul_regs; |
8702 | for (int ntemp = 0; ntemp < 2; ntemp++) { |
8703 | auto range = state.ra.try_alloc_range(2); |
8704 | if (range.isValid()) |
8705 | temps.push_back(range); |
8706 | else if (temps.empty()) |
8707 | throw out_of_registers_exception(); |
8708 | else |
8709 | break; |
8710 | } |
8711 | } |
8712 | } |
8713 | |
8714 | template <HW hw> |
8715 | void gemm_kernel_generator_t<hw>::gemmAllocAoBoRegs( |
8716 | const GEMMStrategy &strategy, GEMMState &state) { |
8717 | bool allocAo = false, allocBo = false; |
8718 | |
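    // Reuse the A/B load registers for the SLM repack (Ao/Bo) when possible:
    // single copy, no repack-ahead, and enough space in the load registers.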
8719 | if (strategy.slmA && state.Ao_regs.empty() && !state.aioShare) { |
8720 | allocAo = true; |
8721 | if (strategy.slmRepackAhead == 0 && strategy.A_copies == 1) { |
8722 | auto nreg = getRegCount(state.Ao_layout); |
8723 | auto &defaultRegs = state.A_regs[0]; |
8724 | allocAo = (defaultRegs.getLen() < nreg); |
8725 | |
8726 | if (!allocAo) { |
8727 | state.Ao_regs = defaultRegs; |
8728 | state.aoReuseA = true; |
8729 | } |
8730 | } |
8731 | } |
8732 | |
8733 | if (strategy.slmB && state.Bo_regs.empty() && !state.bioShare) { |
8734 | allocBo = true; |
8735 | if (strategy.slmRepackAhead == 0 && strategy.B_copies == 1) { |
8736 | auto nreg = getRegCount(state.Bo_layout); |
8737 | auto &defaultRegs = state.B_regs[0]; |
8738 | allocBo = (defaultRegs.getLen() < nreg); |
8739 | |
8740 | if (!allocBo) { |
8741 | state.Bo_regs = defaultRegs; |
8742 | state.boReuseB = true; |
8743 | } |
8744 | } |
8745 | } |
8746 | |
8747 | if (allocAo && !state.allocedAo) { |
8748 | state.allocedAo = true; |
8749 | state.Ao_regs = state.ra.alloc_range(getRegCount(state.Ao_layout)); |
8750 | } |
8751 | |
8752 | if (allocBo && !state.allocedBo) { |
8753 | state.allocedBo = true; |
8754 | state.Bo_regs = state.ra.alloc_range(getRegCount(state.Bo_layout)); |
8755 | } |
8756 | } |
8757 | |
8758 | // Prepare layout for row/column sum matrices, and any needed auxiliary registers. |
8759 | template <HW hw> |
8760 | void gemm_kernel_generator_t<hw>::makeSumLayout(bool column, Type Tsrc, |
8761 | const vector<RegisterBlock> &srcLayout, Type Tdst, |
8762 | vector<RegisterBlock> &dstLayout, const CommonStrategy &strategy, |
8763 | CommonState &state) { |
8764 | bool canDP4A = (hw >= HW::Gen12LP) && one_of(Tsrc, Type::s8, Type::u8) |
8765 | && one_of(Tdst, Type::s32, Type::u32); |
8766 | bool cm = isLayoutColMajor(srcLayout); |
8767 | bool hReduce = (column == cm); |
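    // hReduce: the reduction direction runs along the in-register (horizontal)
    // dimension, so intermediate sums need a horizontal add afterward.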
8768 | bool needAll1s = false; |
8769 | int m, n, cp = 1; |
8770 | |
8771 | getLayoutDims(srcLayout, m, n); |
8772 | auto &rdim = column ? m : n; |
8773 | |
8774 | if (Tsrc.size() == Tdst.size()) cp = srcLayout[0].crosspack; |
8775 | |
8776 | if (hReduce) { |
8777 | if (canDP4A && hasFullCrosspack(srcLayout, 1)) { |
8778 | rdim /= 4; |
8779 | needAll1s = true; |
8780 | if (rdim & 1) rdim <<= 1; // Ensure dp4a dest offset is even. |
8781 | } |
8782 | } else { |
8783 | if (canDP4A && hasFullCrosspack(srcLayout, 4)) needAll1s |= (rdim >= 4); |
8784 | rdim = 1; |
8785 | cp = 1; |
8786 | } |
8787 | |
8788 | bool partials = canSwizzle(hw, Tdst); |
8789 | makeUnbackedRegLayout(Tdst, dstLayout, m, n, cm, cp, 0, 0, partials); |
8790 | |
8791 | // Prepare all-1s immediate for dp4a. |
8792 | if (needAll1s && state.all1s.isInvalid()) { |
8793 | state.all1s = state.ra.alloc_sub( |
8794 | Tdst.ngen(), getHint(HintType::LongTerm, strategy)); |
8795 | mov(1, state.all1s, 0x01010101); |
8796 | } |
8797 | } |
8798 | |
8799 | // Accumulate row/column sums. |
8800 | template <HW hw> |
8801 | void gemm_kernel_generator_t<hw>::accumulateSum(bool column, Type Tsrc, |
8802 | const GRFMultirange &srcRegs, const vector<RegisterBlock> &srcLayout, |
8803 | Type Tdst, const GRFMultirange &dstRegs, |
8804 | const vector<RegisterBlock> &dstLayout, const CommonStrategy &strategy, |
8805 | CommonState &state, int q0, int q1) { |
8806 | bool canDP4A = (hw >= HW::Gen12LP) && one_of(Tsrc, Type::s8, Type::u8) |
8807 | && one_of(Tdst, Type::s32, Type::u32); |
8808 | |
8809 | bool cm = isLayoutColMajor(srcLayout); |
8810 | if (cm != isLayoutColMajor(dstLayout)) stub(); |
8811 | |
8812 | int m, n; |
8813 | getLayoutDims(srcLayout, m, n); |
8814 | |
8815 | // x: consecutive dimension in src; y: strided dimension in src |
8816 | auto nx = cm ? m : n; |
8817 | auto ny = cm ? n : m; |
8818 | |
8819 | int x0 = 0, y0 = 0; |
8820 | int x1 = nx, y1 = ny; |
8821 | |
8822 | if (q1 >= 0) ((column == cm) ? x1 : y1) = q1; |
8823 | if (q0 >= 0) ((column == cm) ? x0 : y0) = q0; |
8824 | |
8825 | // Two cases to handle: |
8826 | // hReduce = false: Good case; no reduction. Sum is vector of size mx1 or 1xn. |
8827 | // hReduce = true: Bad case; needs reduction later, although with dp4a some reduction can be done now. |
8828 | bool hReduce = (column == cm); |
8829 | |
8830 | int yinc = 1; |
8831 | int reduce = (canDP4A && hReduce) ? 4 : 1; |
8832 | if (x0 % reduce || x1 % reduce) stub(); |
8833 | |
8834 | GRFRange temp; |
8835 | |
8836 | for (int y = y0; y < y1; y += yinc) { |
8837 | for (int x = x0; x < x1;) { |
8838 | int isrc, jsrc, idst, jdst, nsrc, ndst; |
8839 | const RegisterBlock *blockSrc, *blockDst; |
8840 | |
8841 | isrc = cm ? x : y; |
8842 | jsrc = cm ? y : x; |
8843 | if (!hReduce) { |
8844 | idst = cm ? x : 0; |
8845 | jdst = cm ? 0 : x; |
8846 | } else { |
8847 | idst = cm ? x / reduce : y; |
8848 | jdst = cm ? y : x / reduce; |
8849 | } |
8850 | |
8851 | Subregister srcBase = findBlockReg( |
8852 | Tsrc, srcLayout, isrc, jsrc, srcRegs, nsrc, blockSrc); |
8853 | Subregister dstBase = findBlockReg( |
8854 | Tdst, dstLayout, idst, jdst, dstRegs, ndst, blockDst); |
8855 | nsrc = std::min(nsrc, x1 - x); |
8856 | int neMax = elementsPerGRF(hw, Tdst) * 2; |
8857 | if (Tdst == Type::f32 && Tsrc.size() < 4) neMax /= 2; |
8858 | auto ne = std::min({nsrc / reduce, ndst, neMax}); |
8859 | |
8860 | auto src = srcBase(blockSrc->crosspack); |
8861 | auto dst = dstBase(blockDst->crosspack); |
8862 | |
8863 | bool hsMatch = (src.getHS() * Tsrc == dst.getHS() * Tdst); |
8864 | if (Tsrc == Type::bf16 && Tdst == Type::f32) |
8865 | hsMatch = (src.getHS() == 1) && (dst.getHS() == 1); |
8866 | |
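            // If the source can't be swizzled to match the destination's
            // offset and stride, stage it through a temporary via the integer
            // pipe first.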
8867 | if (!canSwizzle(hw, Tsrc) && ne > 1 |
8868 | && (srcBase.getOffset() != dstBase.getOffset() |
8869 | || !hsMatch)) { |
8870 | if (temp.isInvalid()) temp = state.ra.alloc_range(2); |
8871 | auto srcI = src; |
8872 | int tmpHS |
8873 | = std::max<int>(1, (blockDst->crosspack * Tdst) / Tsrc); |
8874 | if (Tsrc == Type::bf16 && Tdst == Type::f32) |
8875 | tmpHS = blockDst->crosspack; |
8876 | auto tmpBase = temp[0].sub( |
8877 | dst.getByteOffset() / Tsrc.real(), src.getType()); |
8878 | auto tmp = tmpBase(tmpHS); |
8879 | auto tmpI = tmp; |
8880 | moveToIntPipe(ne, srcI); |
8881 | moveToIntPipe(ne, tmpI); |
8882 | mov(ne, tmpI, srcI); |
8883 | src = tmp; |
8884 | srcBase = tmpBase; |
8885 | } |
8886 | |
8887 | if (Tsrc == Type::f16 && Tdst == Type::f32 && hw >= HW::Gen12LP) { |
8888 | if (temp.isInvalid()) temp = state.ra.alloc_range(2); |
8889 | if (src.getHS() < 2) stub(); |
8890 | auto tmpF = temp[0].sub(src.getByteOffset() / Type::f32, |
8891 | DataType::f)(src.getHS() / 2); |
8892 | mov(ne, tmpF, src); |
8893 | src = tmpF; |
8894 | } |
8895 | |
8896 | if (canDP4A) { |
8897 | auto srcDP4A |
8898 | = Tsrc.isSigned() ? srcBase.d()(1) : srcBase.ud()(1); |
8899 | if (!hReduce && blockSrc->crosspack == 4) { |
8900 | yinc = std::min(y1 - y, 4); |
8901 | if (yinc == 4) |
8902 | dp4a(ne, dst, dst, srcDP4A, state.all1s); |
8903 | else if (yinc == 1) |
8904 | add(ne, dst, srcBase(4), dst); |
8905 | else |
8906 | dp4a(ne, dst, dst, srcDP4A, |
8907 | 0x01010101 & ((1 << (yinc * 8)) - 1)); |
8908 | } else if (hReduce && blockSrc->crosspack == 1) { |
8909 | if (Tsrc.isSigned()) |
8910 | dp4a(ne, dst, dst, srcDP4A, state.all1s); |
8911 | else { |
8912 | // Workaround for suspected HW issue. |
8913 | dst.setType(DataType::ud); |
8914 | dp4a(ne, dst, dst, srcDP4A, state.all1s.ud()); |
8915 | } |
8916 | } |
8917 | } else |
8918 | eadd(ne, dst, dst, src, strategy, state); |
8919 | |
8920 | x += ne * reduce; |
8921 | } |
8922 | } |
8923 | |
8924 | state.ra.safeRelease(temp); |
8925 | } |
8926 | |
8927 | template <HW hw> |
8928 | void gemm_kernel_generator_t<hw>::setupTeardownAccumulateSumSystolic(bool setup, |
8929 | Type T, const GEMMProblem &problem, const GEMMStrategy &strategy, |
8930 | GEMMState &state) { |
8931 | auto &sysSumAll1s = state.sysSumAll1s; |
8932 | |
8933 | if (setup) { |
8934 | if (sysSumAll1s.isInvalid()) { |
8935 | sysSumAll1s = state.ra.alloc(); |
8936 | sysSumAll1s.setType(T.ngen()); |
8937 | |
8938 | int ne = elementsPerGRF(hw, T); |
8939 | if (T == Type::s8 || T == Type::u8) |
8940 | mov(ne / 4, sysSumAll1s.ud(), uint32_t(0x01010101)); |
8941 | else if (T == Type::bf16) |
8942 | mov(ne, sysSumAll1s.uw(), uint16_t(0x3F80)); |
8943 | else |
8944 | mov(ne, sysSumAll1s.retype(T.arithmetic().ngen()), |
8945 | cast(T.arithmetic(), 1.0)); |
8946 | } |
8947 | } else |
8948 | state.ra.safeRelease(sysSumAll1s); |
8949 | } |
8950 | |
8951 | // Horizontally add intermediate sums if needed. |
8952 | template <HW hw> |
8953 | void gemm_kernel_generator_t<hw>::horizontalAdd(bool column, Type T, |
8954 | const GRFMultirange ®s, vector<RegisterBlock> &layout, |
8955 | CommonState &state) { |
8956 | bool cm = isLayoutColMajor(layout); |
8957 | if (cm != column) return; // Nothing to do. |
8958 | |
8959 | int m, n, cp; |
8960 | getLayoutDims(layout, m, n); |
8961 | cp = layout[0].crosspack; |
8962 | |
8963 | int nx = cm ? m : n; |
8964 | int ny = cm ? n : m; |
8965 | int ne = elementsPerGRF(hw, T); |
8966 | bool swizzleOK = canSwizzle(hw, T); |
8967 | |
8968 | GRF tempGRF; |
8969 | if (!swizzleOK && nx > 1) tempGRF = state.ra.alloc(); |
8970 | |
8971 | int nsLimit = (2 * elementsPerGRF(hw, T)) / cp; |
8972 | |
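    // Tree reduction: each pass folds the upper half of the x range onto the
    // lower half, halving the extent until a single row/column remains.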
8973 | for (int chunk = roundup_pow2(nx) >> 1; chunk > 0; chunk >>= 1) { |
8974 | for (int y = 0; y < ny; y += cp) { |
8975 | for (int x = chunk; x < (chunk * 2) && x < nx;) { |
8976 | int i = cm ? x : y; |
8977 | int j = cm ? y : x; |
8978 | int ns, nb; |
8979 | const RegisterBlock *block; |
8980 | Subregister shifted |
8981 | = findBlockReg(T, layout, i, j, regs, ns, block); |
8982 | |
8983 | ns = std::min({ns, chunk, nsLimit}); |
8984 | (cm ? i : j) -= chunk; |
8985 | Subregister base |
8986 | = findBlockReg(T, layout, i, j, regs, nb, block); |
8987 | |
8988 | auto dest = base; |
8989 | if (chunk == 1) dest = regs[y / ne].sub(y % ne, T.ngen()); |
8990 | |
8991 | int ne = ns * cp; |
8992 | |
8993 | if (!swizzleOK && chunk * cp > 1 |
8994 | && shifted.getOffset() != base.getOffset()) { |
8995 | auto temp = tempGRF.sub(base.getOffset(), T.ngen()); |
8996 | auto tempI = temp; |
8997 | auto shiftedI = shifted; |
8998 | moveToIntPipe(tempI); |
8999 | moveToIntPipe(shiftedI); |
9000 | mov(ne, tempI(1), shiftedI(1)); |
9001 | if (base == dest) |
9002 | add(ne, base(1), base(1), temp(1)); |
9003 | else |
9004 | for (int q = 0; q < ne; q++) { |
9005 | add(1, dest, base, temp); |
9006 | dest.setOffset(dest.getOffset() + 1); |
9007 | base.setOffset(base.getOffset() + 1); |
9008 | temp.setOffset(temp.getOffset() + 1); |
9009 | } |
9010 | } else |
9011 | add(ne, dest(1), base(1), shifted(1)); |
9012 | |
9013 | x += ns; |
9014 | } |
9015 | } |
9016 | } |
9017 | |
9018 | state.ra.safeRelease(tempGRF); |
9019 | |
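    // The reduced dimension now has extent 1; rebuild the layout to match.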
9020 | (cm ? m : n) = 1; |
9021 | makeUnbackedRegLayout(T, layout, m, n, !cm, 1); |
9022 | } |
9023 | |
9024 | // Get final A/B sums. For SLM copy kernels, this requires accumulating each thread's contributions. |
9025 | template <HW hw> |
9026 | bool gemm_kernel_generator_t<hw>::gemmFinalizeSums(const GEMMProblem &problem, |
9027 | const GEMMStrategy &strategy, GEMMState &state) { |
9028 | bool doA = problem.needsASums(); |
9029 | bool doB = problem.needsBSums(); |
9030 | bool doASLM = state.slmASums && (strategy.wg[LoopN] > 1); |
9031 | bool doBSLM = state.slmBSums && (strategy.wg[LoopM] > 1); |
9032 | |
9033 | if (!doA && !doB) return true; |
9034 | |
9035 | auto Tc = problem.Tc; |
9036 | auto unrollM = strategy.unroll[LoopM]; |
9037 | auto unrollN = strategy.unroll[LoopN]; |
9038 | bool ok = true; |
9039 | |
9040 | int ms = 0, ns = 0; |
9041 | if (doA) getLayoutDims(state.As_layout, ms, ns); |
9042 | bool reduceAs = (ns > 1); |
9043 | if (doB) getLayoutDims(state.Bs_layout, ms, ns); |
9044 | bool reduceBs = (ms > 1); |
9045 | |
9046 | if (reduceAs && doA && !doASLM) |
9047 | horizontalAdd(false, Tc, state.As_regs, state.As_layout, state); |
9048 | if (reduceBs && doB && !doBSLM) |
9049 | horizontalAdd(true, Tc, state.Bs_regs, state.Bs_layout, state); |
9050 | |
9051 | if (!doASLM && !doBSLM) return true; |
9052 | |
9053 | if (state.effCoopA == CoopSplit::Linear |
9054 | || state.effCoopB == CoopSplit::Linear) |
9055 | stub(); |
9056 | bool A_coopSplitM = (state.effCoopA == CoopSplit::MN); |
9057 | bool B_coopSplitN = (state.effCoopB == CoopSplit::MN); |
9058 | |
9059 | GRFMultirange *ABs_regs[2] = {&state.As_regs, &state.Bs_regs}; |
9060 | bool AB_coopSplitMN[2] = {A_coopSplitM, B_coopSplitN}; |
9061 | vector<RegisterBlock> *ABs_layout[2] = {&state.As_layout, &state.Bs_layout}; |
9062 | |
9063 | vector<RegisterBlock> ABs_layoutSLM[2]; |
9064 | MatrixAddressing ABs_SLM[2]; |
9065 | MatrixAddressingStrategy ABs_strategySLM[2]; |
9066 | MatrixAddressingStrategy ABs_strategySLMAtomic[2]; |
9067 | vector<GRFRange> ABs_addrs[2]; |
9068 | GRF temp = state.ra.alloc(); |
9069 | FlagRegister leader[2]; |
9070 | Subregister ABs_base[2]; |
9071 | |
9072 | if (state.r0_info.isARF()) stub(); |
9073 | GRF r0_info {state.r0_info.getBase()}; |
9074 | |
9075 | // Plan: |
9076 | // 1) First thread of each m/n-block (leader) stores its sums in SLM; barrier |
9077 | // 2) Remaining threads atomically add their sums to the first; barrier |
9078 | // 3) All threads read final sums |
9079 | // For scattered SLM write kernels, threads have accumulated disjoint parts |
9080 | // of the sums, so the second step isn't needed. However, each thread needs |
9081 | // to do a horizontal reduction first. |
9082 | |
9083 | // Wait for previous SLM reads to complete. |
9084 | // In the meantime, finish sum reduction if necessary. |
9085 | status << "Finalize A/B sums" << status_stream::endl; |
9086 | |
9087 | if (hw >= HW::Gen11) slmfence(temp, r0_info); |
9088 | MOCK_BARRIERS barriersignal(temp, r0_info); |
9089 | |
9090 | if (doASLM && A_coopSplitM) |
9091 | horizontalAdd(false, Tc, state.As_regs, state.As_layout, state); |
9092 | if (doBSLM && B_coopSplitN) |
9093 | horizontalAdd(true, Tc, state.Bs_regs, state.Bs_layout, state); |
9094 | |
9095 | MOCK_BARRIERS barrierwait(); |
9096 | |
9097 | auto step1 = [&](bool isB, int r, int c) { |
9098 | ABs_SLM[isB].setAlignment(r * c * Tc); |
9099 | ABs_SLM[isB].crosspack = 1; |
9100 | ABs_SLM[isB].layout = !isB ? MatrixLayout::Pc : MatrixLayout::Pr; |
9101 | ABs_SLM[isB].packSize = r * c; |
9102 | // Use pseudoblock to share address registers between regular and atomic accesses. |
9103 | ABs_strategySLMAtomic[isB].base = AddressBase::createSLM(); |
9104 | ABs_strategySLMAtomic[isB].padded = true; |
9105 | ABs_strategySLMAtomic[isB].accessType = AB_coopSplitMN[isB] |
9106 | ? AccessType::Block |
9107 | : AccessType::PseudoBlock; |
9108 | ABs_strategySLMAtomic[isB].atomic = !AB_coopSplitMN[isB]; |
9109 | ABs_strategySLMAtomic[isB].newDP = (hw >= HW::XeHPG); |
9110 | ABs_strategySLM[isB] = ABs_strategySLMAtomic[isB]; |
9111 | ABs_strategySLM[isB].atomic = false; |
9112 | |
9113 | ok = ok |
9114 | && getRegLayout(Tc, ABs_layoutSLM[isB], r, c, false, false, |
9115 | true, true, 0, 0, ABs_SLM[isB], |
9116 | ABs_strategySLMAtomic[isB]) |
9117 | && matchLayouts(Tc, ABs_layoutSLM[isB], *ABs_layout[isB]); |
9118 | |
9119 | Subregister adjBase = ABs_base[isB] = state.ra.alloc_sub<uint32_t>(); |
9120 | uint32_t slmOffset |
9121 | = (isB && doASLM) ? (unrollM * strategy.wg[LoopM] * Tc) : 0; |
9122 | |
9123 | !isB ? mulConstant(1, ABs_base[isB], state.lidM, unrollM * Tc) |
9124 | : mulConstant(1, ABs_base[isB], state.lidN, unrollN * Tc); |
9125 | |
9126 | if (strategy.kParallelLocal) { |
9127 | slmOffset *= strategy.wg[LoopK]; |
9128 | int perK = !isB ? strategy.wg[LoopM] * unrollM * Tc |
9129 | : strategy.wg[LoopN] * unrollN * Tc; |
9130 | emad(1, ABs_base[isB], ABs_base[isB], state.lidK, perK, strategy, |
9131 | state); |
9132 | } |
9133 | |
9134 | if (slmOffset != 0) add(1, ABs_base[isB], ABs_base[isB], slmOffset); |
9135 | |
9136 | if (AB_coopSplitMN[isB]) { |
9137 | adjBase = state.ra.alloc_sub<uint32_t>(); |
9138 | !isB ? mulConstant(1, adjBase, state.lidN, state.ma_slm * Tc) |
9139 | : mulConstant(1, adjBase, state.lidM, state.nb_slm * Tc); |
9140 | add(1, adjBase, adjBase, ABs_base[isB]); |
9141 | } |
9142 | |
9143 | allocAddrRegs(ABs_addrs[isB], ABs_layoutSLM[isB], ABs_SLM[isB], |
9144 | ABs_strategySLMAtomic[isB], state); |
9145 | setupAddr(Tc, ABs_addrs[isB], adjBase, ABs_layoutSLM[isB], |
9146 | Subregister(), ABs_SLM[isB], ABs_strategySLMAtomic[isB], |
9147 | strategy, state); |
9148 | |
9149 | if (AB_coopSplitMN[isB]) state.ra.safeRelease(adjBase); |
9150 | |
9151 | Label labelNoStore; |
9152 | if (!AB_coopSplitMN[isB]) { |
9153 | leader[isB] = state.raVFlag.alloc(); |
9154 | cmp(16 | eq | leader[isB], !isB ? state.lidN : state.lidM, 0); |
9155 | if_(16 | leader[isB], labelNoStore); |
9156 | } |
9157 | storeMatrix(*ABs_regs[isB], ABs_layoutSLM[isB], ABs_SLM[isB], |
9158 | ABs_strategySLM[isB], ABs_addrs[isB], strategy, state); |
9159 | if (!AB_coopSplitMN[isB]) { |
9160 | mark(labelNoStore); |
9161 | endif(16); |
9162 | } |
9163 | }; |
9164 | |
9165 | bool barrier2 = false; |
9166 | auto step2 = [&](bool isB) { |
9167 | allocEAtomicAddRegs(hw, Tc, ABs_layoutSLM[isB], ABs_SLM[isB], |
9168 | ABs_strategySLMAtomic[isB], state, state.flagAP); |
9169 | |
9170 | Label labelNoAdd; |
9171 | if_(16 | ~leader[isB], labelNoAdd); |
9172 | atomicAddMatrix(Tc, *ABs_regs[isB], ABs_layoutSLM[isB], ABs_SLM[isB], |
9173 | ABs_strategySLMAtomic[isB], ABs_addrs[isB], problem, strategy, |
9174 | state); |
9175 | mark(labelNoAdd); |
9176 | endif(16); |
9177 | barrier2 = true; |
9178 | |
9179 | freeEAtomicAddRegs(state, state.flagAP); |
9180 | }; |
9181 | |
9182 | auto step3 = [&](bool isB, int r, int c) { |
9183 | if (AB_coopSplitMN[isB]) { |
9184 | safeReleaseRanges(ABs_addrs[isB], state); |
9185 | ABs_SLM[isB].packSize = r * c; |
9186 | ABs_SLM[isB].setAlignment(r * c * Tc); |
9187 | ABs_strategySLM[isB].accessType = AccessType::Block; |
9188 | ok = ok |
9189 | && getRegLayout(Tc, ABs_layoutSLM[isB], r, c, false, false, |
9190 | false, true, 0, 0, ABs_SLM[isB], |
9191 | ABs_strategySLM[isB]); |
9192 | |
9193 | auto nregs = getRegCount(ABs_layoutSLM[isB]); |
9194 | if (nregs > ABs_regs[isB]->getLen()) { |
9195 | safeReleaseRanges(*ABs_regs[isB], state); |
9196 | *ABs_regs[isB] = state.ra.alloc_range(nregs); |
9197 | } |
9198 | |
9199 | allocAddrRegs(ABs_addrs[isB], ABs_layoutSLM[isB], ABs_SLM[isB], |
9200 | ABs_strategySLM[isB], state); |
9201 | setupAddr(Tc, ABs_addrs[isB], ABs_base[isB], ABs_layoutSLM[isB], |
9202 | Subregister(), ABs_SLM[isB], ABs_strategySLM[isB], strategy, |
9203 | state); |
9204 | } |
9205 | loadMatrix(*ABs_regs[isB], ABs_layoutSLM[isB], ABs_SLM[isB], |
9206 | ABs_strategySLM[isB], ABs_addrs[isB], strategy, state); |
9207 | *ABs_layout[isB] = std::move(ABs_layoutSLM[isB]); |
9208 | }; |
9209 | |
9210 | if (doASLM) step1(false, state.ma_slm, 1); |
9211 | if (doBSLM) step1(true, 1, state.nb_slm); |
9212 | |
9213 | MOCK_BARRIERS slmBarrier(temp, r0_info); |
9214 | |
9215 | if (doASLM && !A_coopSplitM) step2(false); |
9216 | if (doBSLM && !B_coopSplitN) step2(true); |
9217 | |
9218 | MOCK_BARRIERS if (barrier2) slmBarrier(temp, r0_info); |
9219 | |
9220 | if (doASLM) step3(false, unrollM, 1); |
9221 | if (doBSLM) step3(true, 1, unrollN); |
9222 | |
9223 | state.ra.safeRelease(temp); |
9224 | state.ra.safeRelease(ABs_base[0]); |
9225 | state.ra.safeRelease(ABs_base[1]); |
9226 | state.raVFlag.safeRelease(leader[0]); |
9227 | state.raVFlag.safeRelease(leader[1]); |
9228 | safeReleaseRanges(ABs_addrs[0], state); |
9229 | safeReleaseRanges(ABs_addrs[1], state); |
9230 | |
9231 | return ok; |
9232 | } |
9233 | |
9234 | // Convert register range to a new type. |
// If the types have different sizes, elements of the smaller type are assumed
// to be strided at the width of the larger type.
9237 | template <HW hw> |
9238 | void gemm_kernel_generator_t<hw>::convert(const GRFMultirange &range, Type Told, |
9239 | Type Tnew, const GEMMProblem &problem, const GEMMStrategy &strategy, |
9240 | GEMMState &state) { |
9241 | if (Told == Tnew) return; |
9242 | |
9243 | if (hw == HW::Gen9 && Told == Type::f32 && !Tnew.isFP()) { |
9244 | // Gen9: round to nearest before downconvert (not done by mov). |
9245 | map(hw, Told, range, range, strategy, |
9246 | [&](int esize, GRF r, GRF _) { rnde(esize, r.f(), r.f()); }); |
9247 | } |
9248 | |
9249 | int maxLS = std::max(Told.log2Size(), Tnew.log2Size()); |
9250 | int hsOld = 1 << (maxLS - Told.log2Size()); |
9251 | int hsNew = 1 << (maxLS - Tnew.log2Size()); |
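    // Example: an in-place f16 -> f32 upconvert gives hsOld = 2 (f16 data at
    // dword pitch) and hsNew = 1.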
9252 | auto Tmax = (Told.size() < Tnew.size()) ? Tnew : Told; |
9253 | |
9254 | InstructionModifier mod; |
9255 | if (Told != Tnew && Tnew.isInteger() && Tnew.size() <= Told.size()) |
9256 | mod = mod | sat; |
9257 | |
9258 | map(hw, Tmax, range, range, strategy, [&](int esize, GRF r, GRF _) { |
9259 | emov(esize | mod, r.sub(0, Tnew.ngen())(hsNew), |
9260 | r.sub(0, Told.ngen())(hsOld), strategy, state); |
9261 | }); |
9262 | } |
9263 | |
// Convert C accumulator registers to a new type. Returns true if successful,
// or false if the old and new types have different sizes.
9265 | template <HW hw> |
9266 | bool gemm_kernel_generator_t<hw>::gemmConvertC(Type Tnew, |
9267 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
9268 | GEMMState &state) { |
9269 | auto Told = state.Tacc; |
9270 | int ncomp = (problem.Tc.isComplex() && state.C_buffers == 2 |
9271 | && state.cSwapActive) |
9272 | ? 2 |
9273 | : 1; |
9274 | |
9275 | if (Tnew.size() != Told.size()) return false; |
9276 | |
9277 | for (int comp = 0; comp < ncomp; comp++) |
9278 | convert(state.C_regs[comp], Told, Tnew, problem, strategy, state); |
9279 | |
9280 | state.Tacc = Tnew; |
9281 | |
9282 | return true; |
9283 | } |
9284 | |
9285 | // Perform beta scaling. |
9286 | template <HW hw> |
9287 | void gemm_kernel_generator_t<hw>::gemmBetaScale( |
9288 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
9289 | Label labelBetaDone; |
9290 | |
9291 | auto Ts = problem.Ts; |
9292 | auto &betar = problem.beta_real; |
9293 | |
9294 | if (state.beta1.isValid()) { |
9295 | if (strategy.fused) { |
9296 | cmp(16 | lt | state.flagAP, null.d(), state.beta1, int16_t(0)); |
9297 | goto12(16 | state.flagAP, labelBetaDone); |
9298 | } else { |
9299 | cmp(1 | lt | state.flagAP, null.d(), state.beta1, int16_t(0)); |
9300 | jmpi(1 | state.flagAP, labelBetaDone); |
9301 | } |
9302 | } |
9303 | |
9304 | gemmConvertC(problem.Ts, problem, strategy, state); |
9305 | |
9306 | if (betar != 1) { |
9307 | map(hw, Ts.real(), state.C_regs[0], state.C_regs[0], strategy, |
9308 | [&](int esize, GRF acc, GRF _) { |
9309 | betar.fixed() ? mul(esize, acc, acc, cast(Ts.real(), betar)) |
9310 | : mul(esize, acc, acc, |
9311 | betar.getRegAvoiding(hw, acc)); |
9312 | }); |
9313 | } |
9314 | |
9315 | gemmConvertC(problem.Tc, problem, strategy, state); |
9316 | |
9317 | mark(labelBetaDone); |
9318 | |
9319 | if (state.beta1.isValid() && strategy.fused) join(16); |
9320 | } |
9321 | |
9322 | template <HW hw> |
9323 | void gemm_kernel_generator_t<hw>::binaryOp(BinaryOp op, int simd, |
9324 | const ngen::RegData &dst, const ngen::RegData &src0, |
9325 | const ngen::RegData &src1) { |
9326 | switch (op) { |
9327 | case BinaryOp::Add: add(simd, dst, src0, src1); break; |
9328 | case BinaryOp::Sub: add(simd, dst, src0, -src1); break; |
9329 | case BinaryOp::Mul: mul(simd, dst, src0, src1); break; |
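        // Division is lowered to inv+mul by callers (see gemmBinaryOpC).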
9330 | case BinaryOp::Div: stub(); |
9331 | case BinaryOp::Min: min_(simd, dst, src0, src1); break; |
9332 | case BinaryOp::Max: max_(simd, dst, src0, src1); break; |
9333 | } |
9334 | } |
9335 | |
9336 | // Apply binary operation to C with a scalar operand. |
9337 | template <HW hw> |
9338 | void gemm_kernel_generator_t<hw>::gemmScalarBinaryOpC(BinaryOp op, |
9339 | const Subregister &offset, const GEMMProblem &problem, |
9340 | const GEMMStrategy &strategy, GEMMState &state) { |
9341 | auto offsetTc = offset.reinterpret(0, state.Tacc.ngen()); |
9342 | if (offset != offsetTc) emov(1, offsetTc, offset, strategy, state); |
9343 | |
9344 | map(hw, state.Tacc, state.C_regs[0], state.C_layout, strategy, |
9345 | [&](int simd, const RegData &r) { |
9346 | binaryOp(op, simd, r, r, offsetTc); |
9347 | }); |
9348 | } |
9349 | |
9350 | // Apply binary operation to C with a vector operand, optionally multiplied by a scalar. |
9351 | template <HW hw> |
9352 | void gemm_kernel_generator_t<hw>::gemmVectorBinaryOpC(BinaryOp op, bool column, |
9353 | const GRFMultirange &offsets, const Subregister &scale, |
9354 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
9355 | GEMMState &state, Type Tco, vector<RegisterBlock> CO_layout, int y0, |
9356 | int y1) { |
9357 | auto Tacc = state.Tacc; |
9358 | auto ne = elementsPerGRF(hw, Tacc); |
9359 | auto globalCM = isLayoutColMajor(state.C_layout); |
9360 | auto unrollX = strategy.unroll[globalCM ? LoopM : LoopN]; |
9361 | auto unrollY = strategy.unroll[globalCM ? LoopN : LoopM]; |
9362 | auto crosspack = CO_layout.empty() ? 1 : CO_layout[0].crosspack; |
9363 | auto stride = [&]() { return (column == globalCM) ? 0 : crosspack; }; |
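    // A zero stride broadcasts one offset element across a C vector; a
    // crosspack stride steps through the offsets elementwise.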
9364 | const GRFMultirange *offsetsPtr = &offsets; |
9365 | |
9366 | if (Tco == Type::invalid) Tco = Tacc; |
9367 | |
9368 | bool needRepack = (Tacc != Tco); |
9369 | needRepack |= (stride() > 1 && hw >= HW::XeHP && Tacc.isFP()); |
9370 | |
9371 | GRFMultirange repackOffsets; |
9372 | if (needRepack) { |
9373 | // Repack data to unit stride as float pipe can't swizzle. |
9374 | vector<RegisterBlock> repackLayout; |
9375 | int r = column ? 1 : strategy.unroll[LoopM]; |
9376 | int c = !column ? 1 : strategy.unroll[LoopN]; |
9377 | makeUnbackedRegLayout(Tacc, repackLayout, r, c, !column); |
9378 | repackOffsets = state.ra.alloc_range(getRegCount(repackLayout)); |
9379 | copyRegisters(Tco, Tacc, CO_layout, repackLayout, offsets, |
9380 | repackOffsets, 0, 0, false, strategy, state); |
9381 | crosspack = 1; |
9382 | offsetsPtr = &repackOffsets; |
9383 | } |
9384 | |
9385 | if (y0 < 0) y0 = 0; |
9386 | if (y1 < 0) y1 = unrollY; |
9387 | |
9388 | for (int y = y0; y < y1; y++) { |
9389 | for (int x = 0; x < unrollX;) { |
9390 | auto i = globalCM ? x : y; |
9391 | auto j = globalCM ? y : x; |
9392 | int nc; |
9393 | const RegisterBlock *C_block; |
9394 | Subregister C = findBlockReg( |
9395 | Tacc, state.C_layout, i, j, state.C_regs[0], nc, C_block); |
9396 | |
9397 | nc = std::min({nc, strategy.fmaSIMD / crosspack, 2 * ne}); |
9398 | auto nco = (column ? j : i) * crosspack; |
9399 | auto offBase = (*offsetsPtr)[nco / ne].sub(nco % ne, Tacc.ngen()); |
9400 | if (scale.isValid()) { |
9401 | if (op != BinaryOp::Add) stub(); |
9402 | mad(nc, C(1), C(1), offBase(stride()), scale); |
9403 | } else |
9404 | binaryOp(op, nc, C(1), C(1), offBase(stride())); |
9405 | |
9406 | x += nc; |
9407 | } |
9408 | } |
9409 | |
9410 | safeReleaseRanges(repackOffsets, state); |
9411 | } |
9412 | |
9413 | // Apply binary operation to C. |
9414 | template <HW hw> |
9415 | bool gemm_kernel_generator_t<hw>::gemmBinaryOpC(BinaryOp op, bool row, |
9416 | bool column, Type Tco, MatrixAddressing CO, |
9417 | MatrixAddressingStrategy CO_strategy, Subregister base, Subregister ld, |
9418 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
9419 | GEMMState &state) { |
9420 | std::vector<GRFRange> CO_addrs; |
9421 | std::vector<RegisterBlock> CO_layout; |
9422 | std::vector<MaskAssignment> masks; |
9423 | auto globalCM = isLayoutColMajor(state.C_layout); |
9424 | |
9425 | bool recip = false; |
9426 | if (op == BinaryOp::Div) { |
9427 | // Implement div as inv+mul for speed, especially when broadcasting. |
9428 | recip = true; |
9429 | op = BinaryOp::Mul; |
9430 | if (!one_of(Tco, Type::f32, Type::f16)) stub(); |
9431 | } |
9432 | |
9433 | bool matrix = row && column; |
9434 | if (matrix) { |
9435 | // Matrix case implemented as loop over rows/columns, depending on C's layout. |
9436 | row &= globalCM; |
9437 | column &= !globalCM; |
9438 | CO_strategy.accessType = (isColMajor(CO.layout) == row) |
9439 | ? AccessType::Block |
9440 | : CO_strategy.base.isStateless() ? AccessType::Scattered |
9441 | : AccessType::ChannelScattered; |
9442 | } else { |
9443 | CO.layout = column ? MatrixLayout::T : MatrixLayout::N; |
9444 | CO_strategy.accessType = AccessType::Block; |
9445 | } |
9446 | |
9447 | bool coColMajor = isColMajor(CO.layout); |
9448 | |
9449 | auto cor = row ? strategy.unroll[LoopM] : 1; |
9450 | auto coc = column ? strategy.unroll[LoopN] : 1; |
9451 | auto remR = row && !CO_strategy.padded; |
9452 | auto remC = column && !CO_strategy.padded; |
9453 | |
9454 | if (!getRegLayout(Tco, CO_layout, cor, coc, remR, remC, false, true, 0, 0, |
9455 | CO, CO_strategy)) |
9456 | return false; |
9457 | |
9458 | auto CO_regs = state.ra.alloc_range(getRegCount(CO_layout)); |
9459 | |
9460 | allocAddrRegs(CO_addrs, CO_layout, CO, CO_strategy, state); |
9461 | setupAddr(Tco, CO_addrs, base, CO_layout, ld, CO, CO_strategy, strategy, |
9462 | state); |
9463 | |
9464 | if (!assignMasks(CO_layout, LoopM, LoopN, masks, strategy, state, true)) |
9465 | return false; |
9466 | |
9467 | loadMasks(masks, state.remainders, strategy, state); |
9468 | |
9469 | if (matrix) { |
9470 | auto LoopY = globalCM ? LoopN : LoopM; |
9471 | auto unrollY = strategy.unroll[LoopY]; |
9472 | auto remY = state.remainders[LoopY]; |
9473 | Label lDone; |
9474 | bool simtCF = strategy.fused && (strategy.fusedLoop == LoopY); |
9475 | int simt = simtCF ? 16 : 1; |
9476 | |
9477 | if (!CO_strategy.padded) cmp(simt | gt | state.flagAP, remY, 0); |
9478 | |
9479 | for (int y = 0; y < unrollY; y++) { |
9480 | if (!CO_strategy.padded) { |
9481 | simtCF ? goto12(16 | ~state.flagAP, lDone) |
9482 | : jmpi(1 | ~state.flagAP, lDone); |
9483 | } |
9484 | loadMatrix(CO_regs, CO_layout, CO, CO_strategy, CO_addrs, strategy, |
9485 | state); |
9486 | if (recip) |
9487 | map(hw, Tco, CO_regs, CO_regs, strategy, |
9488 | [&](int simd, GRF r, GRF) { inv(simd, r, r); }); |
9489 | if (!CO_strategy.padded && (y + 1 < unrollY)) |
9490 | cmp(simt | gt | state.flagAP, remY, y + 1); |
9491 | if (coColMajor == globalCM) |
9492 | incAddr(CO_addrs, ld, int(row), int(column), CO_layout, CO, |
9493 | CO_strategy, strategy, state); |
9494 | else |
9495 | incAddr(CO_addrs, Tco.size(), int(row), int(column), CO_layout, |
9496 | CO, CO_strategy, strategy, state); |
9497 | |
9498 | gemmVectorBinaryOpC(op, column, CO_regs, Subregister(), problem, |
9499 | strategy, state, Tco, CO_layout, y, y + 1); |
9500 | } |
9501 | |
9502 | mark(lDone); |
9503 | if (simtCF) join(16); |
9504 | } else { |
9505 | loadMatrix( |
9506 | CO_regs, CO_layout, CO, CO_strategy, CO_addrs, strategy, state); |
9507 | if (recip) |
9508 | map(hw, Tco, CO_regs, CO_regs, strategy, |
9509 | [&](int simd, GRF r, GRF) { inv(simd, r, r); }); |
9510 | |
9511 | if (!row && !column) |
9512 | gemmScalarBinaryOpC(op, CO_regs[0].sub(0, Tco.ngen()), problem, |
9513 | strategy, state); |
9514 | else |
9515 | gemmVectorBinaryOpC(op, column, CO_regs, Subregister(), problem, |
9516 | strategy, state, Tco, CO_layout); |
9517 | } |
9518 | |
9519 | safeReleaseMaskAssignments(masks, state); |
9520 | state.ra.safeRelease(CO_regs); |
9521 | safeReleaseRanges(CO_addrs, state); |
9522 | |
9523 | return true; |
9524 | } |
9525 | |
9526 | // Check kernel input for desired C offset and apply it. |
9527 | template <HW hw> |
9528 | bool gemm_kernel_generator_t<hw>::gemmApplyCOffsetDispatch( |
9529 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
9530 | GEMMState &state) { |
9531 | Label labelCOColumn, labelCORow, labelCOMatrix, labelCODone; |
9532 | bool doMatrix = problem.allowMatrixOffset(); |
9533 | auto Tco = problem.Tco; |
9534 | auto &CO = problem.CO; |
9535 | auto &CO_strategy = strategy.CO; |
9536 | auto &effCO = state.effCO; |
9537 | auto &ldco = state.inputs.ldco; |
9538 | |
9539 | bool ok = true; |
9540 | |
9541 | if (state.flagSwizzle.isValid()) state.raVFlag.release(state.flagSwizzle); |
9542 | |
9543 | auto flagNonfinal = state.raVFlag.alloc(); |
9544 | auto flagCOC = state.raVFlag.alloc(); |
9545 | auto flagCOR = state.raVFlag.alloc(); |
9546 | |
9547 | and_(1 | nz | flagNonfinal, null.ud(), state.inputs.flags, |
9548 | FlagNonfinalKBlock); |
9549 | and_(1 | nz | flagCOC, null.ud(), state.inputs.flags, FlagCOColumn); |
9550 | and_(1 | nz | flagCOR, null.ud(), state.inputs.flags, FlagCORow); |
9551 | jmpi(1 | flagNonfinal, labelCODone); |
9552 | jmpi(1 | flagCOC, labelCOColumn); |
9553 | jmpi(1 | flagCOR, labelCORow); |
9554 | |
9555 | state.raVFlag.safeRelease(flagNonfinal); |
9556 | state.raVFlag.safeRelease(flagCOC); |
9557 | state.raVFlag.safeRelease(flagCOR); |
9558 | |
9559 | if (state.flagSwizzle.isValid()) state.raVFlag.claim(state.flagSwizzle); |
9560 | |
9561 | status << "Applying fixed C offset" << status_stream::endl; |
9562 | ok = ok |
9563 | && gemmBinaryOpC(BinaryOp::Add, false, false, Tco, CO, CO_strategy, |
9564 | effCO, ldco, problem, strategy, state); |
9565 | jmpi(1, labelCODone); |
9566 | |
9567 | mark(labelCOColumn); |
9568 | if (doMatrix) jmpi(1 | flagCOR, labelCOMatrix); |
9569 | status << "Applying column-wise C offset" << status_stream::endl; |
9570 | ok = ok |
9571 | && gemmBinaryOpC(BinaryOp::Add, false, true, Tco, CO, CO_strategy, |
9572 | effCO, ldco, problem, strategy, state); |
9573 | jmpi(1, labelCODone); |
9574 | |
9575 | mark(labelCORow); |
9576 | status << "Applying row-wise C offset" << status_stream::endl; |
9577 | ok = ok |
9578 | && gemmBinaryOpC(BinaryOp::Add, true, false, Tco, CO, CO_strategy, |
9579 | effCO, ldco, problem, strategy, state); |
9580 | |
9581 | if (doMatrix) { |
9582 | jmpi(1, labelCODone); |
9583 | |
9584 | mark(labelCOMatrix); |
9585 | status << "Applying matrix C offset" << status_stream::endl; |
9586 | ok = ok |
9587 | && gemmBinaryOpC(BinaryOp::Add, true, true, Tco, CO, |
9588 | CO_strategy, effCO, ldco, problem, strategy, state); |
9589 | } |
9590 | |
9591 | mark(labelCODone); |
9592 | |
9593 | if (!strategy.persistent) { |
9594 | state.ra.safeRelease(ldco); |
9595 | state.ra.safeRelease(effCO); |
9596 | } |
9597 | |
9598 | return ok; |
9599 | } |
9600 | |
9601 | static inline BinaryOp dnnlToBinaryOp(alg_kind_t kind) { |
9602 | switch (kind) { |
9603 | case alg_kind::binary_add: return BinaryOp::Add; |
9604 | case alg_kind::binary_sub: return BinaryOp::Sub; |
9605 | case alg_kind::binary_mul: return BinaryOp::Mul; |
9606 | case alg_kind::binary_div: return BinaryOp::Div; |
9607 | case alg_kind::binary_min: return BinaryOp::Min; |
9608 | case alg_kind::binary_max: return BinaryOp::Max; |
9609 | default: stub(); |
9610 | } |
9611 | } |
9612 | |
9613 | template <HW hw> |
9614 | void gemm_kernel_generator_t<hw>::gemmLoadBinaryOpArgs( |
9615 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
9616 | GEMMState &state) { |
9617 | if (hw < HW::XeHP) stub(); |
9618 | |
9619 | std::vector<ngen::Subregister *> argList; |
9620 | argList.reserve(state.inputs.binaryOffsets.size() * 5); |
9621 | |
9622 | for (auto &r : state.inputs.binarySrcs) |
9623 | if (r.isValid()) argList.push_back(&r); |
9624 | for (auto &r : state.inputs.binaryOffsets) |
9625 | if (r.isValid()) argList.push_back(&r); |
9626 | for (auto &r : state.inputs.binaryLDs) |
9627 | if (r.isValid()) argList.push_back(&r); |
9628 | |
9629 | if (problem.batch == BatchMode::Strided) { |
9630 | for (auto &rs : state.inputs.binaryStrides) |
9631 | for (auto &r : rs) |
9632 | if (r.isValid()) argList.push_back(&r); |
9633 | } |
9634 | |
9635 | int loadBase = interface.getArgLoadBase().getBase(); |
9636 | int nGRFs = 0; |
9637 | for (auto arg : argList) { |
9638 | int base = arg->getBase(); |
9639 | if (base < loadBase) stub(); |
9640 | nGRFs = std::max(nGRFs, base - loadBase + 1); |
9641 | } |
9642 | |
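    // Reload the argument GRFs en masse, then rebase each argument
    // subregister to point into the newly allocated range.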
9643 | auto temp = state.ra.alloc(); |
9644 | auto args = state.ra.alloc_range(nGRFs); |
9645 | |
9646 | if (state.r0_info.isARF() || state.r0_info.getBase() != 0) stub(); |
9647 | loadargs(args, nGRFs, temp); |
9648 | |
9649 | int grfOffset = args.getBase() - loadBase; |
9650 | |
9651 | state.ra.release(args); |
9652 | |
9653 | for (auto arg : argList) { |
9654 | arg->setBase(arg->getBase() + grfOffset); |
9655 | state.ra.claim(*arg); |
9656 | } |
9657 | |
9658 | state.ra.safeRelease(temp); |
9659 | } |
9660 | |
9661 | template <HW hw> |
9662 | void gemm_kernel_generator_t<hw>::gemmApplyPostOps(int poMin, int poMax, |
9663 | const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
9664 | if (poMin >= poMax) return; |
9665 | |
9666 | Label lSkip; |
9667 | and_(1 | nz | state.flagAP, null.ud(), state.inputs.flags, |
9668 | FlagNonfinalKBlock); |
9669 | |
9670 | (void)gemmConvertC(problem.Ts, problem, strategy, state); |
9671 | |
9672 | jmpi(1 | state.flagAP, lSkip); |
9673 | |
9674 | // Binary preparations: load binary-related args + calculate starting addresses |
9675 | if (problem.hasBinaryPostOp() && state.effBinary.empty()) { |
9676 | int poCount = problem.postOps.len(); |
9677 | auto &entries = problem.postOps.entry_; |
9678 | |
9679 | gemmLoadBinaryOpArgs(problem, strategy, state); |
9680 | |
9681 | #define FOR_EACH_BINARY \ |
9682 | for (int i = 0; i < poCount; i++) \ |
9683 | if (entries[i].is_binary()) |
9684 | |
9685 | FOR_EACH_BINARY { |
9686 | const auto &ld = state.inputs.binaryLDs[i]; |
9687 | auto T = problem.Tbinary[i]; |
9688 | if (ld.isValid()) |
9689 | emulConstant(1, ld, ld, T.size(), strategy, state); |
9690 | emulConstant(1, state.inputs.binaryOffsets[i], |
9691 | state.inputs.binaryOffsets[i], T.size(), strategy, state); |
9692 | if (problem.batch == BatchMode::Strided) |
9693 | for (int b = 0; b < problem.batchDims; b++) { |
9694 | const auto &stride = state.inputs.binaryStrides[i][b]; |
9695 | if (stride.isValid()) |
9696 | emulConstant( |
9697 | 1, stride, stride, T.size(), strategy, state); |
9698 | } |
9699 | } |
9700 | |
9701 | for (int b = 0; b < 2; b++) { |
9702 | FOR_EACH_BINARY { |
9703 | const auto &stride = state.inputs.binaryStrides[i][b]; |
9704 | if (stride.isValid()) |
9705 | emul(1, stride, stride, state.batchID[b], strategy, state); |
9706 | } |
9707 | } |
9708 | |
9709 | for (int b = 0; b < 2; b++) { |
9710 | FOR_EACH_BINARY { |
9711 | auto &offsetStride = state.inputs.binaryStrides[i][b]; |
9712 | if (offsetStride.isValid()) |
9713 | eadd(1, state.inputs.binaryOffsets[i], |
9714 | state.inputs.binaryOffsets[i], offsetStride, |
9715 | strategy, state); |
9716 | state.ra.safeRelease(offsetStride); |
9717 | } |
9718 | } |
9719 | |
9720 | gemmOffsetABC(true, state.i0, state.j0, state.h0, problem, strategy, |
9721 | state, false, false, false, true); |
9722 | |
9723 | state.effBinary.resize(poCount); |
9724 | |
9725 | FOR_EACH_BINARY { |
9726 | if (strategy.binary[i].base.isStateless()) { |
9727 | state.effBinary[i] = state.inputs.binarySrcs[i]; |
9728 | eadd(1, state.effBinary[i], state.inputs.binarySrcs[i], |
9729 | state.inputs.binaryOffsets[i], strategy, state); |
9730 | state.ra.safeRelease(state.inputs.binaryOffsets[i]); |
9731 | } else |
9732 | state.effBinary[i] = state.inputs.binaryOffsets[i]; |
9733 | } |
9734 | #undef FOR_EACH_BINARY |
9735 | } |
9736 | |
9737 | // Apply post-ops to all of C. |
9738 | for (int i = poMin; i < poMax; i++) { |
9739 | auto &entry = problem.postOps.entry_[i]; |
9740 | switch (entry.kind) { |
9741 | case primitive_kind::eltwise: { |
9742 | using Injector = jit_eltwise_injector_f32<hw>; |
9743 | if (state.Tacc != Type::f32) stub(); |
9744 | |
9745 | int euCount = 0; /* only used for a DG2 W/A for conv */ |
9746 | auto &ee = entry.eltwise; |
9747 | Injector injector {this, ee.alg, ee.alpha, ee.beta, ee.scale, |
9748 | euCount, GRFRange(), problem.postOpFwd}; |
9749 | |
9750 | auto scratch = state.ra.try_alloc_range( |
9751 | injector.preferred_scratch_regs()); |
9752 | if (scratch.isInvalid()) |
9753 | scratch = state.ra.alloc_range(injector.min_scratch_regs()); |
9754 | |
9755 | injector.set_scratch(scratch); |
9756 | injector.prepare(); |
9757 | for (auto &rr : state.C_regs[0].ranges) |
9758 | injector.compute(rr); |
9759 | break; |
9760 | } |
9761 | case primitive_kind::binary: { |
9762 | auto &ld = state.inputs.binaryLDs[i]; |
9763 | auto &eff = state.effBinary[i]; |
9764 | auto op = dnnlToBinaryOp(entry.binary.alg); |
9765 | |
9766 | gemmBinaryOpC(op, problem.binaryRow[i], problem.binaryCol[i], |
9767 | problem.Tbinary[i], problem.binary[i], |
9768 | strategy.binary[i], eff, ld, problem, strategy, state); |
9769 | |
9770 | state.ra.safeRelease(ld); |
9771 | state.ra.safeRelease(eff); |
9772 | break; |
9773 | } |
9774 | default: stub(); |
9775 | } |
9776 | } |
9777 | |
9778 | mark(lSkip); |
9779 | } |
9780 | |
9781 | // Calculate addresses of A/B sums in packed input data. Sums are stored at the end of each panel. |
9782 | template <HW hw> |
9783 | void gemm_kernel_generator_t<hw>::gemmCalcABOffsetAddrs( |
9784 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
9785 | GEMMState &state) { |
9786 | auto &effAs = state.effAs; |
9787 | auto &effBs = state.effBs; |
9788 | |
9789 | auto Tc = problem.Tc; |
9790 | auto unrollM = strategy.unroll[LoopM]; |
9791 | auto unrollN = strategy.unroll[LoopN]; |
9792 | |
9793 | if (effAs.isInvalid()) effAs = state.ra.alloc_sub(state.effA.getType()); |
9794 | if (effBs.isInvalid()) effBs = state.ra.alloc_sub(state.effB.getType()); |
9795 | |
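    // Sums occupy the trailing Tc-sized slice of each packed panel:
    //   effAs = effA + unrollM * lda - unrollM * Tc
    //   effBs = effB + unrollN * ldb - unrollN * Tc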
9796 | mulConstant(1, effAs.ud(), state.inputs.lda, unrollM); |
9797 | mulConstant(1, effBs.ud(), state.inputs.ldb, unrollN); |
9798 | add(1, effAs.ud(), effAs.ud(), -unrollM * Tc); |
9799 | add(1, effBs.ud(), effBs.ud(), -unrollN * Tc); |
9800 | eadd(1, effAs, effAs.ud(), state.effA, strategy, state); |
9801 | eadd(1, effBs, effBs.ud(), state.effB, strategy, state); |
9802 | } |
9803 | |
9804 | // Load A/B sums from packed input data. |
9805 | template <HW hw> |
9806 | bool gemm_kernel_generator_t<hw>::gemmLoadABOffset(const GEMMProblem &problem, |
9807 | const GEMMStrategy &strategy, GEMMState &state) { |
9808 | if (problem.abOffset != ABOffset::Load) return true; |
9809 | |
9810 | auto Tc = problem.Tc; |
9811 | auto unrollM = strategy.unroll[LoopM]; |
9812 | auto unrollN = strategy.unroll[LoopN]; |
9813 | |
9814 | MatrixAddressing As = problem.A, Bs = problem.B; |
9815 | As.crosspack = 1; |
9816 | Bs.crosspack = 1; |
9817 | As.tileR = As.tileC = 0; |
9818 | Bs.tileR = Bs.tileC = 0; |
9819 | |
9820 | MatrixAddressingStrategy As_strategy = strategy.A, Bs_strategy = strategy.B; |
9821 | As_strategy.accessType = AccessType::Block; |
9822 | Bs_strategy.accessType = AccessType::Block; |
9823 | As_strategy.tileR = As_strategy.tileC = 0; |
9824 | Bs_strategy.tileR = Bs_strategy.tileC = 0; |
9825 | As_strategy.dpasw = Bs_strategy.dpasw = false; |
9826 | |
9827 | bool ok = true; |
9828 | ok = ok |
9829 | && getRegLayout(Tc, state.As_layout, unrollM, 1, false, false, |
9830 | false, true, 0, 0, As, As_strategy); |
9831 | ok = ok |
9832 | && getRegLayout(Tc, state.Bs_layout, 1, unrollN, false, false, |
9833 | false, true, 0, 0, Bs, Bs_strategy); |
9834 | if (!ok) return false; |
9835 | |
9836 | state.As_regs = state.ra.alloc_range(getRegCount(state.As_layout)); |
9837 | state.Bs_regs = state.ra.alloc_range(getRegCount(state.Bs_layout)); |
9838 | |
9839 | vector<GRFRange> As_addrs, Bs_addrs; |
9840 | allocAddrRegs(As_addrs, state.As_layout, As, As_strategy, state); |
9841 | allocAddrRegs(Bs_addrs, state.Bs_layout, Bs, Bs_strategy, state); |
9842 | |
9843 | if (state.effAs.isInvalid()) |
9844 | gemmCalcABOffsetAddrs(problem, strategy, state); |
9845 | |
9846 | setupAddr(Tc, As_addrs, state.effAs, state.As_layout, Subregister(), As, |
9847 | As_strategy, strategy, state); |
9848 | setupAddr(Tc, Bs_addrs, state.effBs, state.Bs_layout, Subregister(), Bs, |
9849 | Bs_strategy, strategy, state); |
9850 | |
9851 | loadMatrix(state.As_regs, state.As_layout, As, As_strategy, As_addrs, |
9852 | strategy, state); |
9853 | loadMatrix(state.Bs_regs, state.Bs_layout, Bs, Bs_strategy, Bs_addrs, |
9854 | strategy, state); |
9855 | |
9856 | state.ra.safeRelease(state.effAs); |
9857 | state.ra.safeRelease(state.effBs); |
9858 | safeReleaseRanges(As_addrs, state); |
9859 | safeReleaseRanges(Bs_addrs, state); |
9860 | |
9861 | return true; |
9862 | } |
9863 | |
9864 | // Apply contributions from A/B offsets to C matrix, using previously loaded/computed |
9865 | // A row sums and B column sums. |
9866 | template <HW hw> |
9867 | void gemm_kernel_generator_t<hw>::gemmApplyABOffset(const GEMMProblem &problem, |
9868 | const GEMMStrategy &strategy, GEMMState &state) { |
9869 | if (problem.abOffset == ABOffset::None) return; |
9870 | |
    // Two steps (O = all-1s matrix):
9872 | // 1) C += A * O * bo |
9873 | // 2) C += (O * B + bo * k) * ao |
9874 | // TODO: combine C adds into add3 on XeHP+. |
9875 | auto temp = state.ra.alloc_sub(problem.Tc.ngen()); |
9876 | mul(1, temp, state.k, state.inputs.bo); |
9877 | |
9878 | bool noFMA = (hw == HW::Gen9); |
9879 | if (noFMA) { |
9880 | map(hw, problem.Tc, state.Bs_regs, state.Bs_layout, strategy, |
9881 | [&](int ne, RegData r) { add(ne, r, r, temp); }); |
9882 | map(hw, problem.Tc, state.As_regs, state.As_layout, strategy, |
9883 | [&](int ne, RegData r) { mul(ne, r, r, state.inputs.bo); }); |
9884 | map(hw, problem.Tc, state.Bs_regs, state.Bs_layout, strategy, |
9885 | [&](int ne, RegData r) { mul(ne, r, r, state.inputs.ao); }); |
9886 | } else { |
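        // Fold both scalings into one mad per element: Bs = ao * (Bs + bo * k).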
9887 | mul(1, temp, temp, state.inputs.ao); |
9888 | map(hw, problem.Tc, state.Bs_regs, state.Bs_layout, strategy, |
9889 | [&](int ne, RegData r) { |
9890 | mad(ne, r, temp, r, state.inputs.ao); |
9891 | }); |
9892 | } |
9893 | state.ra.safeRelease(temp); |
9894 | |
9895 | gemmVectorBinaryOpC(BinaryOp::Add, false, state.As_regs, |
9896 | noFMA ? Subregister() : state.inputs.bo, problem, strategy, state, |
9897 | problem.Tc, state.As_layout); |
9898 | gemmVectorBinaryOpC(BinaryOp::Add, true, state.Bs_regs, Subregister(), |
9899 | problem, strategy, state, problem.Tc, state.Bs_layout); |
9900 | |
9901 | safeReleaseRanges(state.As_regs, state); |
9902 | safeReleaseRanges(state.Bs_regs, state); |
9903 | if (!strategy.persistent) { |
9904 | state.ra.safeRelease(state.inputs.ao); |
9905 | state.ra.safeRelease(state.inputs.bo); |
9906 | } |
9907 | state.As_layout.clear(); |
9908 | state.Bs_layout.clear(); |
9909 | } |
9910 | |
9911 | // Store A/B sum data into CO. |
9912 | template <HW hw> |
9913 | void gemm_kernel_generator_t<hw>::gemmUpdateSums(const GEMMProblem &problem, |
9914 | const GEMMStrategy &strategy, GEMMState &state) { |
9915 | bool sumA = problem.sumA, sumB = problem.sumB; |
9916 | |
9917 | if (!sumA && !sumB) return; |
9918 | if (sumA && sumB) stub(); // can only store one of the two in CO. |
9919 | |
9920 | auto Tc = problem.Tc; |
9921 | auto Tco = problem.Tco; |
9922 | auto cor = sumA ? strategy.unroll[LoopM] : 1; |
9923 | auto coc = sumB ? strategy.unroll[LoopN] : 1; |
9924 | bool atomic = strategy.CO.atomic; |
9925 | bool checkBeta0 = problem.checkBeta0 && !problem.beta_real.fixed(); |
9926 | bool checkBeta1 = atomic && !problem.beta1(); |
9927 | |
9928 | auto CO = problem.CO; |
9929 | auto CO_strategy = strategy.CO; |
9930 | std::vector<GRFRange> CO_addrs; |
9931 | std::vector<RegisterBlock> CO_layout; |
9932 | std::vector<MaskAssignment> masks; |
9933 | GRFMultirange CO_regs; |
9934 | CO_strategy.accessType = AccessType::Block; |
9935 | FlagRegister flagBeta0; |
9936 | |
9937 | auto &Xs_regs = sumA ? state.As_regs : state.Bs_regs; |
9938 | auto &Xs_layout = sumA ? state.As_layout : state.Bs_layout; |
9939 | |
9940 | int Xs_nregs = getRegCount(Xs_layout); |
9941 | auto Xs_usedRegs = Xs_regs.subrange(0, Xs_nregs); |
9942 | |
9943 | CO.layout = sumA ? MatrixLayout::N : MatrixLayout::T; |
9944 | |
9945 | auto remR = sumA && !strategy.CO.padded; |
9946 | auto remC = sumB && !strategy.CO.padded; |
9947 | |
9948 | if (!getRegLayout(Tco, CO_layout, cor, coc, remR, remC, true, true, 0, 0, |
9949 | CO, CO_strategy)) |
9950 | stub(); |
9951 | |
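    // If the A/B sum layout already matches the CO layout and no type
    // conversion is needed, the sum registers can be stored directly.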
9952 | bool share = (Tc == Tco) && matchLayouts(Tc, CO_layout, Xs_layout); |
9953 | |
9954 | Label skipStore; |
9955 | and_(16 | ne | state.flagAP, null.ud(), state.inputs.flags, FlagStoreSums); |
9956 | if_(16 | state.flagAP, skipStore); |
9957 | |
9958 | if (checkBeta0) { |
9959 | flagBeta0 = state.raVFlag.alloc(); |
9960 | cmp0(1 | eq | flagBeta0, problem.beta_real.getReg(0)); |
9961 | } |
9962 | if (checkBeta1) |
9963 | cmp(1 | eq | state.flagAP, problem.beta_real.getReg(0), |
9964 | cast(problem.Ts, 1.0)); |
9965 | |
9966 | allocAddrRegs(CO_addrs, CO_layout, CO, CO_strategy, state); |
9967 | setupAddr(Tco, CO_addrs, state.effCO, CO_layout, Subregister(), CO, |
9968 | CO_strategy, strategy, state); |
9969 | |
9970 | if (!assignMasks(CO_layout, LoopM, LoopN, masks, strategy, state, true)) |
9971 | stub(); |
9972 | |
9973 | loadMasks(masks, state.remainders, strategy, state); |
9974 | |
9975 | if (!problem.alpha1()) |
9976 | map(hw, Tc, Xs_usedRegs, Xs_usedRegs, strategy, |
9977 | [&](int esize, GRF acc, GRF _) { |
9978 | auto &alphar = problem.alpha_real; |
9979 | alphar.fixed() |
9980 | ? mul(esize, acc, acc, cast(Tc.real(), alphar)) |
9981 | : mul(esize, acc, acc, |
9982 | alphar.getRegAvoiding(hw, acc)); |
9983 | }); |
9984 | |
9985 | if (!problem.beta0() && !(problem.beta1() && atomic)) { |
9986 | Label lSkipUpdate; |
9987 | auto CO_regsLoad = state.ra.alloc_range(getRegCount(CO_layout)); |
9988 | auto CO_regsLoadConv = CO_regsLoad; |
9989 | |
9990 | if (checkBeta0) jmpi(1 | flagBeta0, lSkipUpdate); |
9991 | if (checkBeta1) jmpi(1 | state.flagAP, lSkipUpdate); |
9992 | |
9993 | loadMatrix(CO_regsLoad, CO_layout, CO, CO_strategy, CO_addrs, strategy, |
9994 | state); |
9995 | |
9996 | if (!share) { |
9997 | CO_regsLoadConv = state.ra.alloc_range(Xs_nregs); |
9998 | copyRegisters(Tco, Tc, CO_layout, Xs_layout, CO_regsLoad, |
9999 | CO_regsLoadConv, 0, 0, false, strategy, state); |
10000 | } |
10001 | |
10002 | auto &betar = problem.beta_real; |
10003 | |
10004 | map(hw, Tc, Xs_usedRegs, CO_regsLoadConv, strategy, |
10005 | [&](int esize, GRF acc, GRF loaded) { |
10006 | if (betar == 1) |
10007 | add(esize, acc, acc, loaded); |
10008 | else if (betar.fixed()) |
10009 | mad(esize, acc, acc, loaded, cast(Tc.real(), betar)); |
10010 | else |
10011 | mad(esize, acc, acc, loaded, |
10012 | betar.getRegAvoiding(hw, acc)); |
10013 | }); |
10014 | |
10015 | state.ra.safeRelease(CO_regsLoad); |
10016 | if (!share) state.ra.safeRelease(CO_regsLoadConv); |
10017 | |
10018 | if (checkBeta0 || checkBeta1) { |
10019 | state.wipeActiveVFlags(); |
10020 | mark(lSkipUpdate); |
10021 | } |
10022 | } |
10023 | |
10024 | if (!share) { |
10025 | CO_regs = state.ra.alloc_range(getRegCount(CO_layout)); |
10026 | copyRegisters(Tc, Tco, Xs_layout, CO_layout, Xs_regs, CO_regs, 0, 0, |
10027 | false, strategy, state); |
10028 | safeReleaseRanges(Xs_regs, state); |
10029 | } |
10030 | |
10031 | auto &effCO_regs = share ? Xs_regs : CO_regs; |
10032 | if (atomic) { |
10033 | Label lStore, lDone; |
10034 | if (checkBeta1) { |
10035 | if (!CO_strategy.base.isStateless() && !CO_strategy.newDP) |
10036 | stub(); /* need to shift addresses */ |
10037 | jmpi(1 | ~state.flagAP, lStore); |
10038 | } |
10039 | allocEAtomicAddRegs( |
10040 | hw, Tco, CO_layout, CO, CO_strategy, state, state.flagAP); |
10041 | atomicAddMatrix(Tco, effCO_regs, CO_layout, CO, CO_strategy, CO_addrs, |
10042 | problem, strategy, state); |
10043 | freeEAtomicAddRegs(state, state.flagAP); |
10044 | if (checkBeta1) { |
10045 | state.wipeActiveVFlags(); |
10046 | jmpi(1, lDone); |
10047 | mark(lStore); |
10048 | storeMatrix(effCO_regs, CO_layout, CO, CO_strategy, CO_addrs, |
10049 | strategy, state); |
10050 | mark(lDone); |
10051 | } |
10052 | } else |
10053 | storeMatrix(effCO_regs, CO_layout, CO, CO_strategy, CO_addrs, strategy, |
10054 | state); |
10055 | |
10056 | mark(skipStore); |
10057 | endif(16); |
10058 | |
10059 | safeReleaseMaskAssignments(masks, state); |
10060 | |
10061 | if (!share) safeReleaseRanges(CO_regs, state); |
10062 | safeReleaseRanges(state.As_regs, state); |
10063 | safeReleaseRanges(state.Bs_regs, state); |
10064 | state.raVFlag.safeRelease(flagBeta0); |
10065 | state.As_layout.clear(); |
10066 | state.Bs_layout.clear(); |
10067 | } |
10068 | |
10069 | // Generate code for summing C across k dimension through SLM. |
10070 | template <HW hw> |
10071 | void gemm_kernel_generator_t<hw>::gemmKReduce(const GEMMProblem &problem, |
10072 | const GEMMStrategy &strategy, GEMMState &state) { |
10073 | auto Tc = problem.Tc; |
10074 | Label lDone; |
10075 | |
    // Early exit if nothing to do. All branching is scalar, since there is no
    // fusing in the k dimension.
10077 | cmp(1 | le | state.flagAP, state.lszK, 1); |
10078 | jmpi(1 | state.flagAP, lDone); |
10079 | |
10080 | status << "k reduction through SLM" << status_stream::endl; |
10081 | cmp(1 | eq | state.flagAP, state.lidK, 0); |
10082 | |
10083 | auto C_regs = state.C_regs[0]; |
10084 | |
10085 | // Reduce A/B sums at the same time. |
10086 | if (problem.sumA) C_regs.append(state.As_regs); |
10087 | if (problem.sumB) C_regs.append(state.Bs_regs); |
10088 | |
10089 | // In general SLM isn't large enough to do the reduction in one step. |
10090 | // Slice C into pieces that will fit. |
10091 | int maxMNThreads = strategy.wg[LoopM] * strategy.wg[LoopN]; |
10092 | if (maxMNThreads <= 0) |
        throw std::runtime_error("Max workgroup size not specified");
10094 | |
10095 | int regs = C_regs.getLen(); |
10096 | int sliceRegs = int(gemmPerKSLMSize(problem, strategy) |
10097 | / (maxMNThreads * GRF::bytes(hw))); |
10098 | if (sliceRegs <= 0) |
        throw std::runtime_error("Not enough SLM for k reduction");
10100 | sliceRegs = std::min<int>(sliceRegs, C_regs.getLen()); |
10101 | |
10102 | // Temporaries. |
10103 | auto kt = state.ra.alloc_sub<int32_t>(); |
10104 | auto flagKTLoop = state.raVFlag.alloc(); |
10105 | auto barrierTemp = state.ra.alloc(); |
10106 | |
10107 | if (state.r0_info.isARF()) stub(); |
10108 | GRF r0_info {state.r0_info.getBase()}; |
10109 | |
10110 | bool initialBarrier = (strategy.slmBuffers > 0 || strategy.persistent); |
10111 | MOCK_BARRIERS if (initialBarrier) barriersignal(barrierTemp, r0_info); |
10112 | |
10113 | // Set up addressing. |
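    //   addr0 = (lidM + lidN*wgM + lidK*wgM*wgN) * sliceRegs * GRF::bytes(hw)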
10114 | auto addr0 = state.ra.alloc_sub<uint32_t>(); |
10115 | emad(1, addr0, state.lidM, state.lidN, strategy.wg[LoopM], strategy, state); |
10116 | emad(1, addr0, addr0, state.lidK, strategy.wg[LoopM] * strategy.wg[LoopN], |
10117 | strategy, state); |
10118 | mulConstant(1, addr0, addr0, sliceRegs * GRF::bytes(hw)); |
10119 | |
10120 | int unrollKSLMStride = strategy.wg[LoopM] * strategy.wg[LoopN] * sliceRegs |
10121 | * GRF::bytes(hw); |
10122 | Subregister unrollKSLMReturn = state.ra.alloc_sub<int32_t>(); |
10123 | |
10124 | mulConstant(1, unrollKSLMReturn, -state.lszK, unrollKSLMStride); |
10125 | |
10126 | MatrixAddressing C_slm; |
10127 | MatrixAddressingStrategy C_slmStrategy; |
10128 | |
10129 | C_slm.layout = MatrixLayout::Pc; |
10130 | C_slm.packSize = elementsPerGRF(hw, Tc); |
10131 | C_slm.crosspack = 1; |
10132 | C_slm.setAlignment(GRF::bytes(hw)); |
10133 | |
10134 | C_slmStrategy.base = SLM; |
10135 | C_slmStrategy.accessType = AccessType::Block; |
10136 | C_slmStrategy.padded = true; |
10137 | if (hw >= HW::XeHPG) C_slmStrategy.newDP = true; |
10138 | |
10139 | GRFRange C_load; |
10140 | vector<RegisterBlock> C_slmLayout; |
10141 | vector<GRFRange> C_slmAddrs; |
10142 | |
10143 | // Find maximum # registers of C we can transfer to/from SLM at once. |
10144 | int maxContig = rounddown_pow2(regs); |
10145 | for (; maxContig > 1; maxContig >>= 1) { |
10146 | bool ok = true; |
10147 | for (int offsetReg = 0; offsetReg < regs; offsetReg += maxContig) { |
10148 | int nr = std::min(regs - offsetReg, maxContig); |
10149 | if (!C_regs.contiguous(offsetReg, nr)) { |
10150 | ok = false; |
10151 | break; |
10152 | } |
10153 | } |
10154 | if (ok) break; |
10155 | } |
10156 | |
10157 | // Allocate address and data registers, automatically shrinking sliceRegs if |
10158 | // there are not enough registers. |
10159 | for (; sliceRegs > 0; sliceRegs = rounddown_pow2(sliceRegs - 1)) { |
10160 | bool ok = true; |
10161 | |
10162 | C_load = state.ra.try_alloc_range(sliceRegs); |
10163 | ok = ok && C_load.isValid(); |
10164 | |
10165 | if (!getRegLayout(Tc, C_slmLayout, elementsPerGRF(hw, Tc), sliceRegs, |
10166 | false, false, true, true, 0, maxContig, C_slm, |
10167 | C_slmStrategy)) |
10168 | stub(); |
10169 | ok = ok |
10170 | && tryAllocAddrRegs( |
10171 | C_slmAddrs, C_slmLayout, C_slm, C_slmStrategy, state); |
10172 | |
10173 | if (ok) break; |
10174 | |
10175 | state.ra.safeRelease(C_load); |
10176 | } |
10177 | |
10178 | if (sliceRegs <= 0) throw out_of_registers_exception(); |
10179 | |
10180 | setupAddr(Tc, C_slmAddrs, addr0, C_slmLayout, Subregister(), C_slm, |
10181 | C_slmStrategy, strategy, state); |
10182 | |
10183 | MOCK_BARRIERS if (initialBarrier) barrierwait(); |
10184 | |
10185 | // Loop over slices. |
10186 | for (int rr = 0; rr < regs; rr += sliceRegs) { |
10187 | Label lSkipWrite, lSkipReduce, lTop; |
10188 | |
10189 | int nreg = std::min(sliceRegs, regs - rr); |
10190 | auto C_range = C_regs.subrange(rr, nreg); |
10191 | |
10192 | MOCK_BARRIERS if (rr > 0) slmBarrier(barrierTemp, r0_info); |
10193 | |
10194 | // Trim down SLM layout for final loop. |
10195 | if (nreg < sliceRegs) { |
10196 | vector<RegisterBlock> sublayout; |
10197 | vector<GRFRange> subaddrs; |
10198 | if (!getSubblocks(Tc, sublayout, subaddrs, C_slmLayout, C_slmAddrs, |
10199 | true, 0, nreg, true, C_slm, C_slmStrategy)) |
10200 | stub(); |
10201 | std::swap(sublayout, C_slmLayout); |
10202 | std::swap(subaddrs, C_slmAddrs); |
10203 | } |
10204 | |
10205 | // Non-leaders write to SLM. |
10206 | jmpi(1 | state.flagAP, lSkipWrite); |
10207 | storeMatrix(C_range, C_slmLayout, C_slm, C_slmStrategy, C_slmAddrs, |
10208 | strategy, state); |
10209 | mark(lSkipWrite); |
10210 | |
10211 | MOCK_BARRIERS slmBarrier(barrierTemp, r0_info); |
10212 | |
10213 | // Leader reads SLM data and accumulates C. |
10214 | jmpi(1 | ~state.flagAP, lSkipReduce); |
10215 | add(1, kt, state.lszK, -1); |
10216 | incAddr(C_slmAddrs, unrollKSLMStride, C_slmLayout, C_slm, C_slmStrategy, |
10217 | strategy, state); |
10218 | |
10219 | mark(lTop); |
10220 | add(1 | gt | flagKTLoop, kt, kt, -1); |
10221 | loadMatrix(C_load, C_slmLayout, C_slm, C_slmStrategy, C_slmAddrs, |
10222 | strategy, state); |
10223 | incAddr(C_slmAddrs, unrollKSLMStride, C_slmLayout, C_slm, C_slmStrategy, |
10224 | strategy, state); |
10225 | map(hw, Tc.real(), C_range, C_load, strategy, |
10226 | [&](int simd, GRF r1, GRF r2) { add(simd, r1, r1, r2); }); |
10227 | jmpi(1 | flagKTLoop, lTop); |
10228 | |
10229 | if (rr + nreg < regs) |
10230 | incAddr(C_slmAddrs, unrollKSLMReturn, C_slmLayout, C_slm, |
10231 | C_slmStrategy, strategy, state); |
10232 | |
10233 | mark(lSkipReduce); |
10234 | } |
10235 | |
    // Non-leaders will not update C.
10237 | mov(1 | ~state.flagAP, state.remainders[LoopM], 0); |
10238 | mov(1 | ~state.flagAP, state.remainders[LoopN], 0); |
10239 | |
10240 | state.raVFlag.safeRelease(flagKTLoop); |
10241 | state.ra.safeRelease(C_load); |
10242 | state.ra.safeRelease(kt); |
10243 | state.ra.safeRelease(unrollKSLMReturn); |
10244 | state.ra.safeRelease(addr0); |
10245 | state.ra.safeRelease(barrierTemp); |
10246 | safeReleaseRanges(C_slmAddrs, state); |
10247 | |
10248 | mark(lDone); |
10249 | } |
10250 | |
10251 | template <HW hw> |
10252 | void gemm_kernel_generator_t<hw>::gemmPrefetchC( |
10253 | const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
10254 | auto Tc_ext = problem.Tc_ext; |
10255 | bool checkBeta0 = problem.checkBeta0 && !problem.beta_real.fixed(); |
10256 | bool checkIDK = strategy.kParallelLocal; |
10257 | |
10258 | releaseRanges(state.Ap_regs, state); |
10259 | releaseRanges(state.Bp_regs, state); |
10260 | |
10261 | status << "Prefetch C" << status_stream::endl; |
10262 | |
10263 | if (checkBeta0) { |
10264 | cmp0(1 | eq | state.flagAP, problem.beta_real.getReg(0)); |
10265 | } |
10266 | |
10267 | Address2DParams Cp_params; |
10268 | if (strategy.C.address2D) { |
10269 | Cp_params.rows = state.inputs.m; |
10270 | Cp_params.cols = state.inputs.n; |
10271 | Cp_params.offR = state.i0; |
10272 | Cp_params.offC = state.j0; |
10273 | } else { |
10274 | Cp_params.rows = state.remainders[LoopM]; |
10275 | Cp_params.cols = state.remainders[LoopN]; |
10276 | } |
10277 | Cp_params.remR = state.remainders[LoopM]; |
10278 | Cp_params.remC = state.remainders[LoopN]; |
10279 | |
10280 | bool oldAdd32 = strategy.emulate.emulate64_add32; |
10281 | strategy.emulate.emulate64_add32 = false; |
10282 | |
10283 | gemmCacheLDCMultiples(problem, strategy, state, 1); |
10284 | |
10285 | if (checkIDK) { |
10286 | if (checkBeta0) |
10287 | cmp(1 | ~state.flagAP | gt | state.flagAP, state.lidK, 0); |
10288 | else |
10289 | cmp(1 | gt | state.flagAP, state.lidK, 0); |
10290 | } |
10291 | |
10292 | allocAddrRegs(state.Cp_addrs, state.Cp_layout, problem.C, |
10293 | strategy.C_prefetch, state); |
10294 | setupAddr(Tc_ext, state.Cp_addrs, state.effCp, state.Cp_layout, |
10295 | state.inputs.ldc[0], problem.C, strategy.C_prefetch, strategy, |
10296 | state, Cp_params, state.ldcMultiples[0]); |
10297 | |
10298 | Label lSkipPrefetchC; |
10299 | if (checkBeta0 || checkIDK) jmpi(1 | state.flagAP, lSkipPrefetchC); |
10300 | |
10301 | state.Cp_regs = state.ra.alloc_range(getRegCount(state.Cp_layout)); |
10302 | |
10303 | loadMatrix(state.Cp_regs, state.Cp_layout, problem.C, strategy.C_prefetch, |
10304 | state.Cp_addrs, strategy, state); |
10305 | |
10306 | safeReleaseRanges(state.Cp_regs, state); |
10307 | safeReleaseRanges(state.Cp_addrs, state); |
10308 | if (state.effCp != state.effC[0]) state.ra.safeRelease(state.effCp); |
10309 | |
10310 | releaseLDMultiples(state.ldcMultiples[0], state); |
10311 | releaseIndexVec(state); |
10312 | |
10313 | if (checkBeta0 || checkIDK) mark(lSkipPrefetchC); |
10314 | |
10315 | strategy.emulate.emulate64_add32 = oldAdd32; |
10316 | |
10317 | reclaimRanges(state.Ap_regs, state); |
10318 | reclaimRanges(state.Bp_regs, state); |
10319 | } |
10320 | |
10321 | // Generate code for checking whether 32-bit address arithmetic can be used inside k loop. |
10322 | // Assumes leading dimensions have not been shifted yet. |
10323 | template <HW hw> |
10324 | void gemm_kernel_generator_t<hw>::gemmCheck32( |
10325 | const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
10326 | if (!strategy.checkAdd32) return; |
10327 | |
10328 | bool checkA = (strategy.A.base.getModel() == ModelA64); |
10329 | bool checkB = (strategy.B.base.getModel() == ModelA64); |
10330 | if (!checkA && !checkB) return; |
10331 | |
10332 | auto &m = state.inputs.m; |
10333 | auto &n = state.inputs.n; |
10334 | auto &k = state.fullK.isValid() ? state.fullK : state.inputs.k; |
10335 | auto &lda = state.inputs.lda; |
10336 | auto &ldb = state.inputs.ldb; |
10337 | auto temp1GRF = state.ra.alloc(); |
10338 | auto temp2GRF = state.ra.alloc(); |
    // Only need one :ud subregister, but GRF-align it for mach.
    auto temp1 = temp1GRF.ud(0);
10341 | auto temp2 = temp2GRF.ud(0); |
10342 | auto temp3 = temp2GRF.ud(4); |
10343 | auto flag = state.raVFlag.alloc(); |
10344 | |
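    // For each matrix in global (A64) memory, flag if base + offset overflows 32 bits,
    //   or if the high 32 bits of the ld * extent product are nonzero.
    // state.add64 = 1 if 64-bit address arithmetic is needed, 0 otherwise.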
10345 | if (checkA) { |
10346 | add(1, temp2, state.effA.ud(), state.offsetA.ud()); |
        // Conservatively estimate upper bound for size of A.
        switch (problem.A.layout) {
10349 | case MatrixLayout::N: emul32High(1, temp1, lda, k); break; |
10350 | case MatrixLayout::T: emul32High(1, temp1, lda, m); break; |
10351 | case MatrixLayout::Pc: { |
10352 | if (strategy.fixedWG(problem)) |
10353 | add(1, temp3, m, |
10354 | uint16_t(strategy.wg[LoopM] * strategy.unroll[LoopM] |
10355 | - 1)); |
10356 | else |
10357 | emad(1, temp3, m, state.inputs.localSizeM, |
10358 | strategy.unroll[LoopM], strategy, state); |
10359 | emul32High(1, temp1, lda, temp3); |
10360 | break; |
10361 | } |
10362 | default: stub(); |
10363 | } |
10364 | add(1 | ov | flag, temp2, acc0.ud(0), temp2); |
10365 | cmp(1 | ~flag | ne | flag, temp1, uint16_t(0)); |
10366 | } |
10367 | |
10368 | if (checkB) { |
10369 | add(1, temp2, state.effB.ud(), state.offsetB.ud()); |
10370 | switch (problem.B.layout) { |
10371 | case MatrixLayout::T: emul32High(1, temp1, ldb, k); break; |
10372 | case MatrixLayout::N: emul32High(1, temp1, ldb, n); break; |
10373 | case MatrixLayout::Pr: { |
10374 | if (strategy.fixedWG(problem)) |
10375 | add(1, temp3, n, |
10376 | uint16_t(strategy.wg[LoopN] * strategy.unroll[LoopN] |
10377 | - 1)); |
10378 | else |
10379 | emad(1, temp3, n, state.inputs.localSizeN, |
10380 | strategy.unroll[LoopN], strategy, state); |
10381 | emul32High(1, temp1, ldb, temp3); |
10382 | break; |
10383 | } |
10384 | default: stub(); |
10385 | } |
10386 | InstructionModifier mod = 1; |
10387 | if (checkA) mod |= ~flag; |
10388 | add(mod | ov | flag, temp2, acc0.ud(0), temp2); |
10389 | cmp(1 | ~flag | ne | flag, temp1, uint16_t(0)); |
10390 | } |
10391 | |
10392 | state.add64 = state.ra.alloc_sub<uint16_t>(); |
10393 | and_(1, state.add64, flag, 1u); |
10394 | state.raVFlag.safeRelease(flag); |
10395 | |
10396 | state.ra.safeRelease(temp1GRF); |
10397 | temp1 = invalid; |
10398 | state.ra.safeRelease(temp2GRF); |
10399 | temp2 = invalid; |
10400 | temp3 = invalid; |
10401 | } |
10402 | |
10403 | // Increment A pointer after load, inside GEMM k loop. |
10404 | template <HW hw> |
10405 | void gemm_kernel_generator_t<hw>::gemmAIncrementInternal(Type Ta, |
10406 | const std::vector<RegisterBlock> &layout, |
10407 | const std::vector<GRFRange> &addrs, const MatrixAddressing &A, |
10408 | const MatrixAddressingStrategy &A_strategy, int ka_inc, |
10409 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10410 | GEMMState &state, int ha) { |
10411 | if (ka_inc == 0) |
10412 | /* no-op */; |
10413 | else if (A_strategy.address2D) |
10414 | incDecAddr(addrs, Subregister(), 0, ka_inc, layout, A, A_strategy, |
10415 | strategy, state, problem.backward()); |
10416 | else if (A.layout == MatrixLayout::N) { |
10417 | SubregisterPair lda_ka; |
10418 | bool release = false; |
10419 | // Use cached lda * ka_inc if available, otherwise calculate on the fly. |
10420 | if (ka_inc == 1) |
10421 | lda_ka = state.lda; |
10422 | else if (state.lda_ka.isValid() && ka_inc == state.ka_cached) |
10423 | lda_ka = state.lda_ka; |
10424 | else if (state.lda_ka_prefetch.isValid() |
10425 | && ka_inc == strategy.ka_pfStride) |
10426 | lda_ka = state.lda_ka_prefetch; |
10427 | else { |
10428 | lda_ka = state.ra.alloc_sub<int32_t>(); |
10429 | emulConstant(1, lda_ka, state.inputs.lda, ka_inc, strategy, state); |
10430 | release = true; |
10431 | } |
10432 | incDecAddr(addrs, lda_ka, layout, A, A_strategy, strategy, state, |
10433 | problem.backward()); |
10434 | if (release) state.ra.safeRelease(lda_ka); |
10435 | } else { |
10436 | int incA; |
10437 | switch (A.layout) { |
10438 | case MatrixLayout::Pc: |
10439 | incA = untile(Ta, A, 0, 0, ha + ka_inc, A.packSize, |
10440 | strategy.unrollKSLM) |
10441 | - untile(Ta, A, 0, 0, ha, A.packSize, |
10442 | strategy.unrollKSLM); |
10443 | break; |
10444 | case MatrixLayout::T: incA = ka_inc; break; |
10445 | default: stub(); |
10446 | } |
10447 | incDecAddr(addrs, uint16_t(incA * Ta), layout, A, A_strategy, strategy, |
10448 | state, problem.backward()); |
10449 | } |
10450 | } |
10451 | |
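// Variant for an increment supplied in multiple pre-shifted forms (MultishiftSubregister).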
10452 | template <HW hw> |
10453 | void gemm_kernel_generator_t<hw>::gemmAIncrementInternal(Type Ta, |
10454 | const std::vector<RegisterBlock> &layout, |
10455 | const std::vector<GRFRange> &addrs, const MatrixAddressing &A, |
10456 | const MatrixAddressingStrategy &A_strategy, |
10457 | const MultishiftSubregister &ka_inc, const GEMMProblem &problem, |
10458 | const GEMMStrategy &strategy, GEMMState &state, int ha) { |
10459 | incDecAddr(addrs, ka_inc, layout, A, A_strategy, strategy, state, |
10460 | problem.backward()); |
10461 | } |
10462 | |
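// Variant for a run-time (Subregister) increment.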
10463 | template <HW hw> |
10464 | void gemm_kernel_generator_t<hw>::gemmAIncrementInternal(Type Ta, |
10465 | const std::vector<RegisterBlock> &layout, |
10466 | const std::vector<GRFRange> &addrs, const MatrixAddressing &A, |
10467 | const MatrixAddressingStrategy &A_strategy, const Subregister &ka_inc, |
10468 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10469 | GEMMState &state, int ha) { |
10470 | incDecAddr(addrs, ka_inc, 0, ka_inc, layout, A, A_strategy, strategy, state, |
10471 | problem.backward()); |
10472 | } |
10473 | |
10474 | template <HW hw> |
10475 | template <typename I> |
10476 | void gemm_kernel_generator_t<hw>::gemmAIncrement(Type Ta, |
10477 | const std::vector<RegisterBlock> &layout, |
10478 | const std::vector<GRFRange> &addrs, const MatrixAddressing &A, |
10479 | const MatrixAddressingStrategy &A_strategy, I ka_inc, |
10480 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10481 | GEMMState &state, int ha) { |
10482 | gemmAIncrementInternal(Ta, layout, addrs, A, A_strategy, ka_inc, problem, |
10483 | strategy, state, ha); |
10484 | } |
10485 | |
10486 | // A load for GEMM k loop. |
10487 | template <HW hw> |
10488 | void gemm_kernel_generator_t<hw>::gemmALoad(const GRFMultirange ®s, |
10489 | const std::vector<RegisterBlock> &layout, |
10490 | const std::vector<GRFRange> &addrs, const MatrixAddressing &A, |
10491 | const MatrixAddressingStrategy &A_strategy, const GEMMProblem &problem, |
10492 | const GEMMStrategy &strategy, GEMMState &state) { |
10493 | loadMatrix(regs, layout, A, A_strategy, addrs, strategy, state); |
10494 | } |
10495 | |
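// A load for GEMM k loop, followed by address increment.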
10496 | template <HW hw> |
10497 | template <typename I> |
10498 | void gemm_kernel_generator_t<hw>::gemmALoadInc(Type Ta, |
10499 | const GRFMultirange ®s, const std::vector<RegisterBlock> &layout, |
10500 | const std::vector<GRFRange> &addrs, const MatrixAddressing &A, |
10501 | const MatrixAddressingStrategy &A_strategy, I ka_inc, |
10502 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10503 | GEMMState &state) { |
10504 | gemmALoad(regs, layout, addrs, A, A_strategy, problem, strategy, state); |
10505 | gemmAIncrement( |
10506 | Ta, layout, addrs, A, A_strategy, ka_inc, problem, strategy, state); |
10507 | } |
10508 | |
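// Increment B pointer after load, inside GEMM k loop. The following routines are
//   the B counterparts of the A increment/load routines above.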
10509 | template <HW hw> |
10510 | void gemm_kernel_generator_t<hw>::gemmBIncrementInternal(Type Tb, |
10511 | const std::vector<RegisterBlock> &layout, |
10512 | const std::vector<GRFRange> &addrs, const MatrixAddressing &B, |
10513 | const MatrixAddressingStrategy &B_strategy, int kb_inc, |
10514 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10515 | GEMMState &state, int hb) { |
10516 | if (kb_inc == 0) |
10517 | /* no-op */; |
10518 | else if (B_strategy.address2D) |
10519 | incDecAddr(addrs, Subregister(), kb_inc, 0, layout, B, B_strategy, |
10520 | strategy, state, problem.backward()); |
10521 | else if (B.layout == MatrixLayout::T) { |
10522 | SubregisterPair ldb_kb; |
10523 | bool release = false; |
10524 | if (kb_inc == 1) |
10525 | ldb_kb = state.ldb; |
10526 | else if (state.ldb_kb.isValid() && kb_inc == state.kb_cached) |
10527 | ldb_kb = state.ldb_kb; |
10528 | else if (state.ldb_kb_prefetch.isValid() |
10529 | && kb_inc == strategy.kb_pfStride) |
10530 | ldb_kb = state.ldb_kb_prefetch; |
10531 | else { |
10532 | ldb_kb = state.ra.alloc_sub<int32_t>(); |
10533 | emulConstant(1, ldb_kb, state.inputs.ldb, kb_inc, strategy, state); |
10534 | release = true; |
10535 | } |
10536 | incDecAddr(addrs, ldb_kb, layout, B, B_strategy, strategy, state, |
10537 | problem.backward()); |
10538 | if (release) state.ra.safeRelease(ldb_kb); |
10539 | } else { |
10540 | int incB; |
10541 | switch (B.layout) { |
10542 | case MatrixLayout::Pr: |
10543 | incB = untile(Tb, B, 0, hb + kb_inc, 0, strategy.unrollKSLM, |
10544 | B.packSize) |
10545 | - untile(Tb, B, 0, hb, 0, strategy.unrollKSLM, |
10546 | B.packSize); |
10547 | break; |
10548 | case MatrixLayout::N: incB = kb_inc; break; |
10549 | default: stub(); |
10550 | } |
10551 | incDecAddr(addrs, uint16_t(incB * Tb), layout, B, B_strategy, strategy, |
10552 | state, problem.backward()); |
10553 | } |
10554 | } |
10555 | |
10556 | template <HW hw> |
10557 | void gemm_kernel_generator_t<hw>::gemmBIncrementInternal(Type Tb, |
10558 | const std::vector<RegisterBlock> &layout, |
10559 | const std::vector<GRFRange> &addrs, const MatrixAddressing &B, |
10560 | const MatrixAddressingStrategy &B_strategy, |
10561 | const MultishiftSubregister &kb_inc, const GEMMProblem &problem, |
10562 | const GEMMStrategy &strategy, GEMMState &state, int hb) { |
10563 | incDecAddr(addrs, kb_inc, layout, B, B_strategy, strategy, state, |
10564 | problem.backward()); |
10565 | } |
10566 | |
10567 | template <HW hw> |
10568 | void gemm_kernel_generator_t<hw>::gemmBIncrementInternal(Type Tb, |
10569 | const std::vector<RegisterBlock> &layout, |
10570 | const std::vector<GRFRange> &addrs, const MatrixAddressing &B, |
10571 | const MatrixAddressingStrategy &B_strategy, const Subregister &kb_inc, |
10572 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10573 | GEMMState &state, int hb) { |
10574 | incDecAddr(addrs, kb_inc, kb_inc, 0, layout, B, B_strategy, strategy, state, |
10575 | problem.backward()); |
10576 | } |
10577 | |
10578 | template <HW hw> |
10579 | template <typename I> |
10580 | void gemm_kernel_generator_t<hw>::gemmBIncrement(Type Tb, |
10581 | const std::vector<RegisterBlock> &layout, |
10582 | const std::vector<GRFRange> &addrs, const MatrixAddressing &B, |
10583 | const MatrixAddressingStrategy &B_strategy, I kb_inc, |
10584 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10585 | GEMMState &state, int hb) { |
10586 | gemmBIncrementInternal(Tb, layout, addrs, B, B_strategy, kb_inc, problem, |
10587 | strategy, state, hb); |
10588 | } |
10589 | |
10590 | // B load for GEMM k loop. |
10591 | template <HW hw> |
10592 | void gemm_kernel_generator_t<hw>::gemmBLoad(const GRFMultirange ®s, |
10593 | const std::vector<RegisterBlock> &layout, |
10594 | const std::vector<GRFRange> &addrs, const MatrixAddressing &B, |
10595 | const MatrixAddressingStrategy &B_strategy, const GEMMProblem &problem, |
10596 | const GEMMStrategy &strategy, GEMMState &state) { |
10597 | loadMatrix(regs, layout, B, B_strategy, addrs, strategy, state); |
10598 | } |
10599 | |
10600 | template <HW hw> |
10601 | template <typename I> |
10602 | void gemm_kernel_generator_t<hw>::gemmBLoadInc(Type Tb, |
10603 | const GRFMultirange ®s, const std::vector<RegisterBlock> &layout, |
10604 | const std::vector<GRFRange> &addrs, const MatrixAddressing &B, |
10605 | const MatrixAddressingStrategy &B_strategy, I kb_inc, |
10606 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10607 | GEMMState &state) { |
10608 | gemmBLoad(regs, layout, addrs, B, B_strategy, problem, strategy, state); |
10609 | gemmBIncrement( |
10610 | Tb, layout, addrs, B, B_strategy, kb_inc, problem, strategy, state); |
10611 | } |
10612 | |
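// Load Ai/Bi (the SLM copy source) for the GEMM k loop, with remainder handling.
// In the incremental path, each k slice is loaded under a countdown on kSLMX, so
//   slices past the k remainder are skipped; optionally the destination is pre-zeroed
//   and/or the data copied incrementally to the Ao/Bo repack registers.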
10613 | template <HW hw> |
10614 | template <bool doA> |
10615 | void gemm_kernel_generator_t<hw>::gemmAiBiRemLoadInc(bool incremental, |
10616 | bool incrementalCopy, bool keepAddrTogether, bool willRemask, |
10617 | const Subregister &kSLMX, const GRFMultirange &Xi_regs, |
10618 | const vector<RegisterBlock> &Xi_layout, |
10619 | const vector<GRFRange> &Xi_addrs, |
10620 | const vector<vector<RegisterBlock>> &Xi_layoutK, |
10621 | const vector<vector<GRFRange>> &Xi_addrsK, const GRFMultirange &Xo_regs, |
10622 | const vector<RegisterBlock> &Xo_layout, const MatrixAddressing &Xi, |
10623 | const MatrixAddressingStrategy &Xi_strategy, const GEMMProblem &problem, |
10624 | const GEMMStrategy &strategy, GEMMState &state) { |
10625 | auto T = doA ? problem.Ta : problem.Tb; |
10626 | auto T_ext = doA ? problem.Ta_ext : problem.Tb_ext; |
10627 | auto kx_slm = doA ? state.ka_slm : state.kb_slm; |
10628 | |
10629 | auto unrollKSLM = strategy.unrollKSLM; |
10630 | |
10631 | bool prezero = !willRemask |
10632 | && ((doA ? state.slmASums : state.slmBSums) |
10633 | || (minOuterProductCount(hw, problem, strategy) > 1)); |
10634 | |
10635 | if (!incremental) { |
10636 | if (prezero) zeroMatrix(Xi_regs, strategy); |
10637 | doA ? gemmALoad(Xi_regs, Xi_layout, Xi_addrs, Xi, Xi_strategy, problem, |
10638 | strategy, state) |
10639 | : gemmBLoad(Xi_regs, Xi_layout, Xi_addrs, Xi, Xi_strategy, problem, |
10640 | strategy, state); |
10641 | } else { |
10642 | bool simtCF = strategy.fused |
10643 | && (strategy.fusedLoop == (doA ? LoopN : LoopM)); |
10644 | int simt = simtCF ? 16 : 1; |
10645 | Label done; |
10646 | |
10647 | keepAddrTogether &= (Xi_addrsK.size() > 1); |
10648 | |
10649 | cmp(simt | gt | state.flagAP, kSLMX, 0); |
10650 | add(1, kSLMX, kSLMX, (kx_slm > 1) ? -1 : -unrollKSLM); |
10651 | |
10652 | if (prezero) zeroMatrix(incrementalCopy ? Xo_regs : Xi_regs, strategy); |
10653 | |
10654 | for (int hh = 0; hh < kx_slm; hh++) { |
10655 | int hhRem = kx_slm - hh - 1; |
10656 | |
10657 | simtCF ? goto12(16 | ~state.flagAP, done) |
10658 | : jmpi(1 | ~state.flagAP, done); |
10659 | |
10660 | if (hhRem > 0) { |
10661 | cmp(simt | gt | state.flagAP, kSLMX, 0); |
10662 | add(1, kSLMX, kSLMX, |
10663 | (hhRem == 1) ? -(unrollKSLM - kx_slm + 1) : -1); |
10664 | } |
10665 | |
10666 | int hh_eff = problem.backward() ? (kx_slm - 1 - hh) : hh; |
10667 | int hh_layout = hh_eff; |
10668 | int hh_addr = hh_eff; |
10669 | |
10670 | if (Xi_layoutK.size() == 1) hh_layout = 0; |
10671 | if (Xi_addrsK.size() == 1) hh_addr = 0; |
10672 | |
10673 | // OPTIMIZEME: delay inc if kx_slm = 1 |
10674 | auto kx_inc = (Xi_addrsK.size() > 1) |
10675 | ? unrollKSLM |
10676 | : ((hh + 1) != kx_slm) ? 1 : (unrollKSLM - kx_slm + 1); |
10677 | |
10678 | if (keepAddrTogether) kx_inc = 0; |
10679 | |
10680 | doA ? gemmALoadInc(T_ext, Xi_regs, Xi_layoutK[hh_layout], |
10681 | Xi_addrsK[hh_addr], Xi, Xi_strategy, kx_inc, problem, |
10682 | strategy, state) |
10683 | : gemmBLoadInc(T_ext, Xi_regs, Xi_layoutK[hh_layout], |
10684 | Xi_addrsK[hh_addr], Xi, Xi_strategy, kx_inc, problem, |
10685 | strategy, state); |
10686 | |
10687 | if (incrementalCopy) { |
10688 | int rr_eff = doA ? 0 : hh_eff; |
10689 | int cc_eff = doA ? hh_eff : 0; |
10690 | copyRegisters(T_ext, T, Xi_layoutK[hh_layout], Xo_layout, |
10691 | Xi_regs, Xo_regs, rr_eff, cc_eff, false, strategy, |
10692 | state); |
10693 | } |
10694 | } |
10695 | |
10696 | mark(done); |
10697 | if (simtCF) join(16); |
10698 | |
10699 | if (keepAddrTogether) { |
10700 | doA ? gemmAIncrement(T_ext, Xi_layout, Xi_addrs, Xi, Xi_strategy, |
10701 | unrollKSLM, problem, strategy, state) |
10702 | : gemmBIncrement(T_ext, Xi_layout, Xi_addrs, Xi, Xi_strategy, |
10703 | unrollKSLM, problem, strategy, state); |
10704 | } |
10705 | } |
10706 | } |
10707 | |
10708 | // Calculate A offset for SLM copies or cooperative prefetches for this local ID. |
10709 | template <HW hw> |
10710 | void gemm_kernel_generator_t<hw>::gemmCalcWorkshareAOffset(Subregister &off, |
10711 | Subregister &offR, Subregister &offC, const MatrixAddressing &A, |
10712 | const MatrixAddressingStrategy &A_strategy, int ma, int ka, |
10713 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10714 | GEMMState &state) { |
10715 | bool splitM = (state.effCoopA == CoopSplit::MN); |
10716 | bool splitLinear = (state.effCoopA == CoopSplit::Linear); |
10717 | |
10718 | if (A_strategy.address2D) { |
10719 | if (splitLinear) stub(); |
10720 | if (splitM) { |
10721 | offR = state.ra.alloc_sub<uint32_t>( |
10722 | getHint(HintType::TempComp0, strategy)); |
10723 | mulConstant(1, offR, state.lidN, ma); |
10724 | } else { |
10725 | offC = state.ra.alloc_sub<uint32_t>( |
10726 | getHint(HintType::TempComp0, strategy)); |
10727 | mulConstant(1, offC, state.lidN, ka); |
10728 | } |
10729 | } else { |
10730 | auto Ta_ext = problem.Ta_ext; |
10731 | off = state.ra.alloc_sub<uint32_t>( |
10732 | getHint(HintType::TempComp0, strategy)); |
10733 | |
10734 | switch (A.layout) { |
10735 | case MatrixLayout::Pc: |
10736 | mulConstant(1, off, state.lidN, ma * ka * Ta_ext); |
10737 | break; |
10738 | case MatrixLayout::T: |
10739 | if (splitLinear) stub(); |
10740 | if (splitM) { |
10741 | mul(1, off, state.inputs.lda, state.lidN); |
10742 | mulConstant(1, off, off, ma); |
10743 | } else |
10744 | mulConstant(1, off, state.lidN, ka * Ta_ext); |
10745 | break; |
10746 | case MatrixLayout::N: |
10747 | if (splitLinear) stub(); |
10748 | if (splitM) |
10749 | mulConstant(1, off, state.lidN, ma * Ta_ext); |
10750 | else { |
10751 | mul(1, off, state.inputs.lda, state.lidN); |
10752 | mulConstant(1, off, off, ka); |
10753 | } |
10754 | break; |
10755 | default: stub(); |
10756 | } |
10757 | } |
10758 | } |
10759 | |
10760 | // Calculate B offset for SLM copies or cooperative prefetches for this local ID. |
10761 | template <HW hw> |
10762 | void gemm_kernel_generator_t<hw>::gemmCalcWorkshareBOffset(Subregister &off, |
10763 | Subregister &offR, Subregister &offC, const MatrixAddressing &B, |
10764 | const MatrixAddressingStrategy &B_strategy, int kb, int nb, |
10765 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10766 | GEMMState &state) { |
10767 | bool splitN = (state.effCoopB == CoopSplit::MN); |
10768 | bool splitLinear = (state.effCoopB == CoopSplit::Linear); |
10769 | |
10770 | if (B_strategy.address2D) { |
10771 | if (splitLinear) stub(); |
10772 | if (splitN) { |
10773 | offC = state.ra.alloc_sub<uint32_t>( |
10774 | getHint(HintType::TempComp0, strategy)); |
10775 | mulConstant(1, offC, state.lidM, nb); |
10776 | } else { |
10777 | offR = state.ra.alloc_sub<uint32_t>( |
10778 | getHint(HintType::TempComp0, strategy)); |
10779 | mulConstant(1, offR, state.lidM, kb); |
10780 | } |
10781 | } else { |
10782 | auto Tb_ext = problem.Tb_ext; |
10783 | off = state.ra.alloc_sub<uint32_t>( |
10784 | getHint(HintType::TempComp0, strategy)); |
10785 | |
10786 | switch (B.layout) { |
10787 | case MatrixLayout::Pr: |
10788 | mulConstant(1, off, state.lidM, nb * kb * Tb_ext); |
10789 | break; |
10790 | case MatrixLayout::N: |
10791 | if (splitLinear) stub(); |
10792 | if (splitN) { |
10793 | mul(1, off, state.inputs.ldb, state.lidM); |
10794 | mulConstant(1, off, off, nb); |
10795 | } else |
10796 | mulConstant(1, off, state.lidM, kb * Tb_ext); |
10797 | break; |
10798 | case MatrixLayout::T: |
10799 | if (splitLinear) stub(); |
10800 | if (splitN) |
10801 | mulConstant(1, off, state.lidM, nb * Tb_ext); |
10802 | else { |
10803 | mul(1, off, state.inputs.ldb, state.lidM); |
10804 | mulConstant(1, off, off, kb); |
10805 | } |
10806 | break; |
10807 | default: stub(); |
10808 | } |
10809 | } |
10810 | } |
10811 | |
10812 | // Remask incoming global data for SLM copies. |
10813 | template <HW hw> |
10814 | void gemm_kernel_generator_t<hw>::gemmSLMRemask(bool remaskA, bool remaskB, |
10815 | GRFMultirange &Ao_regs, GRFMultirange &Bo_regs, int kOffset, |
10816 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10817 | GEMMState &state) { |
10818 | if (problem.backward()) stub(); |
10819 | |
10820 | auto Ta = problem.Ta, Tb = problem.Tb; |
10821 | |
10822 | bool oremaskA = remaskA && (state.effCoopA == CoopSplit::K); |
10823 | bool oremaskB = remaskB && (state.effCoopB == CoopSplit::K); |
10824 | bool noshareRemask = (oremaskA || oremaskB) |
10825 | || (remaskA && remaskB && Ta.size() != Tb.size()); |
10826 | int aRemaskLen = state.ka_slm; |
10827 | int bRemaskLen = state.kb_slm; |
10828 | |
10829 | Subregister offK_A, offK_B; |
10830 | if (oremaskA) { |
10831 | offK_A = state.ra.alloc_sub<uint32_t>(); |
10832 | mulConstant(1, offK_A, state.lidN, state.ka_slm); |
10833 | } |
10834 | |
10835 | if (oremaskB) { |
10836 | offK_B = state.ra.alloc_sub<uint32_t>(); |
10837 | mulConstant(1, offK_B, state.lidM, state.kb_slm); |
10838 | } |
10839 | |
10840 | if (!noshareRemask && remaskA && remaskB) |
10841 | aRemaskLen = bRemaskLen = std::max(aRemaskLen, bRemaskLen); |
10842 | |
10843 | if (remaskA) { |
10844 | setupTeardownRemask(Ta, 1, true, aRemaskLen, state.K, strategy, state, |
10845 | kOffset, offK_A); |
10846 | remaskLayout(Ta, 1, true, state.Ao_layout, Ao_regs, strategy, state); |
10847 | if (noshareRemask || !remaskB) |
10848 | setupTeardownRemask(Ta, 1, false, aRemaskLen, state.K, strategy, |
10849 | state, kOffset, offK_A); |
10850 | } |
10851 | |
10852 | if (remaskB) { |
10853 | if (noshareRemask || !remaskA) |
10854 | setupTeardownRemask(Tb, 1, true, bRemaskLen, state.K, strategy, |
10855 | state, kOffset, offK_B); |
10856 | remaskLayout(Tb, 1, false, state.Bo_layout, Bo_regs, strategy, state); |
10857 | setupTeardownRemask(Tb, 1, false, bRemaskLen, state.K, strategy, state, |
10858 | kOffset, offK_B); |
10859 | } |
10860 | } |
10861 | |
10862 | // Calculate kSLMA/kSLMB -- countdown variables for SLM copies. |
10863 | template <HW hw> |
10864 | void gemm_kernel_generator_t<hw>::gemmCalcKSLM(const Subregister &kSLM, |
10865 | const Subregister &lid, int kgran, int kdiv, int krep, |
10866 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
10867 | GEMMState &state) { |
10868 | if (kdiv == 1) |
10869 | mov(1, kSLM, state.K); |
10870 | else { |
10871 | auto modLID = lid; |
10872 | if (krep > 1) { |
10873 | if (!is_zero_or_pow2(krep)) stub(); |
10874 | modLID = state.ra.alloc_sub<uint16_t>(); |
10875 | shr(1, modLID, lid, log2(krep)); |
10876 | } |
10877 | if (!problem.backward()) |
10878 | emad(1, kSLM, state.K.w(), -modLID, kgran, strategy, state); |
10879 | else { |
10880 | emad(1, kSLM, strategy.unrollKSLM - kgran, -modLID, kgran, strategy, |
10881 | state); |
10882 | add(1, kSLM, state.K, -kSLM); |
10883 | } |
10884 | if (krep > 1) state.ra.safeRelease(modLID); |
10885 | } |
10886 | } |
10887 | |
10888 | // Calculate barrier count for a k loop. |
10889 | template <HW hw> |
10890 | void gemm_kernel_generator_t<hw>::gemmCalcKLoopBarrierCount(Subregister &count, |
10891 | const Subregister &k, int cooldown, const GEMMProblem &problem, |
10892 | const GEMMStrategy &strategy, GEMMState &state) { |
10893 | int barrierFreq = strategy.barrierFreq; |
10894 | int unrollK = strategy.unroll[LoopK]; |
10895 | int unrollKSLM = strategy.unrollKSLM; |
10896 | |
10897 | if (count.isInvalid()) count = state.ra.alloc_sub<uint32_t>(); |
10898 | |
10899 | if (barrierFreq > 0) { |
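        // count = max(0, k + barrierFreq - cooldown - unrollK) / barrierFreq,
        //   i.e., one barrier per barrierFreq iterations of k past the cooldown.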
10900 | if (!is_zero_or_pow2(barrierFreq)) stub(); |
10901 | |
10902 | if (strategy.splitBarrier && cooldown > 0) |
10903 | cmp(1 | ge | state.flagAP, k, cooldown); |
10904 | add(1 | sat, count, k, barrierFreq - cooldown - unrollK); |
10905 | shr(1, count, count, uint16_t(log2(barrierFreq))); |
10906 | if (strategy.splitBarrier) { |
10907 | (cooldown > 0) ? add(1 | state.flagAP, count, count, 1) |
10908 | : add(1, count, count, 1); |
10909 | } |
10910 | } else if (strategy.slmBuffers > 0) { |
10911 | if (!is_zero_or_pow2(unrollKSLM)) stub(); |
10912 | |
10913 | if (strategy.slmBuffers == 1) { |
10914 | add(1 | sat, count, k, unrollKSLM - 1); |
10915 | if (unrollKSLM == 2) |
10916 | and_(1, count, count, ~uint32_t(1)); |
10917 | else { |
10918 | shr(1, count, count, uint16_t(log2(unrollKSLM))); |
10919 | shl(1, count, count, 1); |
10920 | } |
10921 | } else { |
10922 | add(1 | sat, count, k, unrollKSLM - 1); |
10923 | shr(1, count, count, uint16_t(log2(unrollKSLM))); |
10924 | } |
10925 | } else |
10926 | mov(1, count, 0); |
10927 | } |
10928 | |
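// Maximum number of extra barriers that may be needed on the k loop remainder path.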
static int maxExtraKLoopRemBarriers(const GEMMStrategy &strategy) {
10930 | if (strategy.slmBuffers == 2) |
10931 | return div_up(strategy.unroll[LoopK], strategy.unrollKSLM); |
10932 | return 0; |
10933 | } |
10934 | |
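// Clone the k = 0 slice of an Ai/Bi layout across all kx_slm slices, giving each
//   slice its own registers. Used when the original layout cannot be split along k
//   in place.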
10935 | static void makeAiBiKCloneLayout(HW hw, bool isA, |
10936 | vector<RegisterBlock> &Xi_layout, |
10937 | vector<vector<RegisterBlock>> &Xi_layoutK, |
10938 | vector<GRFMultirange> &Xi_regsRem, int kx_slm, |
10939 | const GEMMStrategy &strategy, GEMMState &state) { |
10940 | auto regCountK = getRegCount(Xi_layoutK[0]); |
10941 | auto regCount = regCountK * kx_slm; |
10942 | auto offsetK = isA ? &RegisterBlock::offsetC : &RegisterBlock::offsetR; |
10943 | |
10944 | Xi_layout = Xi_layoutK[0]; |
10945 | |
10946 | for (int h1 = 1; h1 < kx_slm; h1++) { |
10947 | Xi_layoutK[h1] = Xi_layoutK[h1 - 1]; |
10948 | for (auto &block : Xi_layoutK[h1]) { |
10949 | block.offsetBytes += regCountK * GRF::bytes(hw); |
10950 | |
10951 | auto oblock = block; |
10952 | oblock.*offsetK += h1; |
10953 | Xi_layout.push_back(std::move(oblock)); |
10954 | } |
10955 | } |
10956 | |
    int extraRegs = regCount - Xi_regsRem[0].getLen();
10958 | if (extraRegs > 0) { |
10959 | for (int q = 0; q < strategy.slmCopies; q++) |
10960 | Xi_regsRem[q].append(state.ra.alloc_range(extraRegs)); |
10961 | } |
10962 | } |
10963 | |
10964 | // Prepare for inner loop. Returns true on success. |
10965 | template <HW hw> |
10966 | bool gemm_kernel_generator_t<hw>::kLoopSetup(const GEMMProblem &problem, |
10967 | const GEMMStrategy &strategy, GEMMState &state) { |
10968 | auto Ta = problem.Ta, Tb = problem.Tb; |
10969 | auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext; |
10970 | auto Ta_load = state.Ta_load, Tb_load = state.Tb_load; |
10971 | |
10972 | auto minOPCount = minOuterProductCount(hw, problem, strategy); |
10973 | auto unrollM = strategy.unroll[LoopM]; |
10974 | auto unrollN = strategy.unroll[LoopN]; |
10975 | |
10976 | state.barrierReady = false; |
10977 | |
10978 | // Get A/B named barrier IDs. |
    auto &barrierHeaderM = state.barrierHeaderM;
    auto &barrierHeaderN = state.barrierHeaderN;
10981 | auto &barrierM = state.barrierM; |
10982 | auto &barrierN = state.barrierN; |
10983 | bool nbM = (strategy.slmA || strategy.barrierFreq) |
10984 | && strategy.namedBarriers[LoopM]; |
10985 | bool nbN = (strategy.slmB || strategy.barrierFreq) |
10986 | && strategy.namedBarriers[LoopN]; |
10987 | |
10988 | if (nbM) { |
10989 | barrierHeaderM = state.ra.alloc(); |
10990 | |
10991 | // Get fN.0 subregister for use with sync.bar. |
10992 | barrierM = state.raVFlag.alloc(2); // TODO: unlock these flag registers. |
10993 | state.raVFlag.release(FlagRegister {barrierM.getARFBase(), 1}); |
10994 | barrierM = FlagRegister {barrierM.getARFBase(), 0}; |
10995 | |
10996 | if (!is_zero_or_pow2(strategy.wg[LoopM]) |
10997 | || !is_zero_or_pow2(strategy.namedBarriers[LoopM])) |
10998 | stub(); |
10999 | shr(1, barrierHeaderM.uw(4), state.lidM, |
11000 | log2(strategy.wg[LoopM]) - log2(strategy.namedBarriers[LoopM])); |
11001 | } |
11002 | if (nbN) { |
11003 | barrierHeaderN = state.ra.alloc(); |
11004 | barrierN = state.raVFlag.alloc(2); |
11005 | state.raVFlag.release(FlagRegister {barrierN.getARFBase(), 1}); |
11006 | barrierN = FlagRegister {barrierN.getARFBase(), 0}; |
11007 | |
11008 | if (!is_zero_or_pow2(strategy.wg[LoopN]) |
11009 | || !is_zero_or_pow2(strategy.namedBarriers[LoopN])) |
11010 | stub(); |
11011 | shr(1, barrierHeaderN.uw(4), state.lidN, |
11012 | log2(strategy.wg[LoopN]) - log2(strategy.namedBarriers[LoopN])); |
11013 | } |
11014 | if (nbM) { |
11015 | int threadsPerMBar = strategy.wg[LoopM] * strategy.wg[LoopN] |
11016 | / strategy.namedBarriers[LoopM]; |
11017 | mov(1, barrierHeaderM.uw(5), threadsPerMBar | (threadsPerMBar << 8)); |
11018 | } |
11019 | if (nbN) { |
11020 | int threadsPerNBar = strategy.wg[LoopM] * strategy.wg[LoopN] |
11021 | / strategy.namedBarriers[LoopN]; |
11022 | mov(1, barrierHeaderN.uw(5), threadsPerNBar | (threadsPerNBar << 8)); |
11023 | } |
11024 | if (nbM && nbN) |
11025 | add(1, barrierHeaderN.uw(4), barrierHeaderN.uw(4), |
11026 | strategy.namedBarriers[LoopM]); |
11027 | if (nbM) mov(1, barrierM, barrierHeaderM.uw(4)); |
11028 | if (nbN) mov(1, barrierN, barrierHeaderN.uw(4)); |
11029 | |
11030 | // Get tokens for barriers/fences. |
11031 | for (int q = 0; q < 2; q++) { |
11032 | state.tokenBarrierFence[q] = -1; |
11033 | state.modBarrierFence[q] = InstructionModifier {}; |
11034 | } |
11035 | |
11036 | if (hw >= HW::Gen12LP) { |
11037 | if (strategy.needsBarrier()) |
11038 | state.tokenBarrierFence[0] = state.tokenAllocator.tryAlloc(); |
11039 | if (nbM && nbN) |
11040 | state.tokenBarrierFence[1] = state.tokenAllocator.tryAlloc(); |
11041 | for (int q = 0; q < 2; q++) |
11042 | if (state.tokenBarrierFence[q] >= 0) |
11043 | state.modBarrierFence[q] = SBID(state.tokenBarrierFence[q]); |
11044 | } |
11045 | |
11046 | // Remainder load preparations. |
11047 | auto &ka_loadRem = state.ka_loadRem, &kb_loadRem = state.kb_loadRem; |
11048 | ka_loadRem = 1, kb_loadRem = 1; |
11049 | |
11050 | // For packed layouts, extend remainder loads to encompass a full logical block. |
11051 | int ignore; |
11052 | getGranularities(problem.A, ignore, ka_loadRem); |
11053 | getGranularities(problem.B, kb_loadRem, ignore); |
11054 | |
11055 | // With 2D block loads, extend k unroll to at least a full block (array). |
11056 | bool a2D = isBlock2D(strategy.A.accessType); |
11057 | bool b2D = isBlock2D(strategy.B.accessType); |
11058 | bool ai2D = strategy.slmA && isBlock2D(state.Ai_strategy.accessType); |
11059 | bool bi2D = strategy.slmB && isBlock2D(state.Bi_strategy.accessType); |
11060 | if (a2D || ai2D) { |
11061 | ka_loadRem = state.A_layout[0].nc; |
11062 | if (!isColMajor(problem.A.layout)) |
11063 | ka_loadRem *= state.A_layout[0].count; |
11064 | } |
11065 | if (b2D || bi2D) { |
11066 | kb_loadRem = state.B_layout[0].nr; |
11067 | if (isColMajor(problem.B.layout)) kb_loadRem *= state.B_layout[0].count; |
11068 | } |
11069 | |
11070 | // Fragment the A, B layouts into smaller blocks (usually 1 row/column) for remainder loads. |
11071 | if (!getSubblocks(Ta_load, state.A_layoutRem, state.A_addrsRem, |
11072 | state.A_layout, state.A_addrs, true, 0, ka_loadRem, |
11073 | strategy.A.padded, problem.A, strategy.A)) |
11074 | return false; |
11075 | if (!getSubblocks(Tb_load, state.B_layoutRem, state.B_addrsRem, |
11076 | state.B_layout, state.B_addrs, false, 0, kb_loadRem, |
11077 | strategy.B.padded, problem.B, strategy.B)) |
11078 | return false; |
11079 | |
11080 | // Add k masking. |
11081 | if (a2D && (ka_loadRem > 1)) |
11082 | addMasking( |
11083 | Ta_load, state.A_layoutRem, false, true, problem.A, strategy.A); |
11084 | if (b2D && (kb_loadRem > 1)) |
11085 | addMasking( |
11086 | Tb_load, state.B_layoutRem, true, false, problem.B, strategy.B); |
11087 | |
11088 | // Ai/Bi remainders. |
11089 | auto &Ai_layoutRem = state.Ai_layoutRem, &Bi_layoutRem = state.Bi_layoutRem; |
11090 | auto &Ai_layoutK = state.Ai_layoutK, &Bi_layoutK = state.Bi_layoutK; |
11091 | auto &Ai_addrsRem = state.Ai_addrsRem, &Bi_addrsRem = state.Bi_addrsRem; |
11092 | auto &Ai_addrsK = state.Ai_addrsK, &Bi_addrsK = state.Bi_addrsK; |
11093 | auto &Ai_regsRem = state.Ai_regsRem, &Bi_regsRem = state.Bi_regsRem; |
11094 | auto &Ao_regsRem = state.Ao_regsRem, &Bo_regsRem = state.Bo_regsRem; |
11095 | auto &Ai_hasKRem = state.Ai_hasKRem, &Bi_hasKRem = state.Bi_hasKRem; |
11096 | auto &Ai_lateKRem = state.Ai_lateKRem, &Bi_lateKRem = state.Bi_lateKRem; |
11097 | auto &Ai_remIncrCopy = state.Ai_remIncrCopy, |
11098 | &Bi_remIncrCopy = state.Bi_remIncrCopy; |
11099 | auto &Ai_incrementalRem = state.Ai_incrementalRem, |
11100 | &Bi_incrementalRem = state.Bi_incrementalRem; |
11101 | auto &aioShareRem = state.aioShareRem, &bioShareRem = state.bioShareRem; |
11102 | int ka_slm = state.ka_slm, kb_slm = state.kb_slm; |
11103 | |
11104 | Ai_layoutRem = state.Ai_layout; |
11105 | Bi_layoutRem = state.Bi_layout; |
11106 | Ai_addrsRem = state.Ai_addrs; |
11107 | Bi_addrsRem = state.Bi_addrs; |
11108 | Ai_regsRem = state.Ai_regs; |
11109 | Bi_regsRem = state.Bi_regs; |
11110 | Ao_regsRem = state.Ao_regs; |
11111 | Bo_regsRem = state.Bo_regs; |
11112 | |
11113 | Ai_hasKRem = Ai_lateKRem = false; |
11114 | Bi_hasKRem = Bi_lateKRem = false; |
11115 | Ai_remIncrCopy = Bi_remIncrCopy = false; |
11116 | |
11117 | if (ai2D && (ka_loadRem > 1) && state.Ai_strategy.address2D) { |
11118 | Ai_hasKRem = true; |
11119 | addMasking(Ta_ext, state.Ai_layoutRem, false, true, state.Ai, |
11120 | state.Ai_strategy); |
11121 | } |
11122 | |
11123 | if (bi2D && (kb_loadRem > 1) && state.Bi_strategy.address2D) { |
11124 | Bi_hasKRem = true; |
11125 | addMasking(Tb_ext, state.Bi_layoutRem, true, false, state.Bi, |
11126 | state.Bi_strategy); |
11127 | } |
11128 | |
11129 | if (strategy.slmA && !Ai_hasKRem) |
11130 | Ai_lateKRem |= !isRegisterColMajor(Ta_ext, state.Ai, state.Ai_strategy); |
11131 | if (strategy.slmB && !Bi_hasKRem) |
11132 | Bi_lateKRem |= isRegisterColMajor(Tb_ext, state.Bi, state.Bi_strategy); |
11133 | |
11134 | Ai_incrementalRem |
11135 | = strategy.slmA && !state.Ai_hasKRem && !state.Ai_lateKRem; |
11136 | Bi_incrementalRem |
11137 | = strategy.slmB && !state.Bi_hasKRem && !state.Bi_lateKRem; |
11138 | aioShareRem = state.aioShare; |
11139 | bioShareRem = state.bioShare; |
11140 | |
11141 | if (Ai_incrementalRem) { |
        // Prepare to split the Ai layout in the k dimension. If it can't be done
        //   in place, either redo the layout or copy Ai->Ao incrementally.
11144 | Ai_layoutK.resize(ka_slm); |
11145 | Ai_addrsK.resize(ka_slm); |
11146 | for (int h = 0; h < ka_slm; h++) { |
11147 | bool success = false; |
11148 | |
11149 | if (h < int(Ai_addrsK.size())) { |
11150 | success = getSubblocks(Ta_ext, Ai_layoutK[h], Ai_addrsK[h], |
11151 | Ai_layoutRem, state.Ai_addrs, true, h, h + 1, |
11152 | state.Ai_strategy.padded, state.Ai, state.Ai_strategy); |
11153 | } |
11154 | |
11155 | if (!success && h == 0) stub(); |
11156 | |
11157 | if (!success) { |
11158 | // Maybe the subblock is OK, but we didn't get an address register. Try again without |
11159 | // asking for address registers. |
11160 | Ai_addrsK.resize(1); |
11161 | success = getSubblocks(Ta_ext, Ai_layoutK[h], Ai_layoutRem, |
11162 | true, h, h + 1, state.Ai_strategy.padded, state.Ai, |
11163 | state.Ai_strategy); |
11164 | } |
11165 | |
11166 | if (!success) { |
11167 | // Can't make a subblock. Will need a new layout or an incremental copy. |
11168 | if (strategy.slmUseIncrCopy) { |
11169 | Ai_remIncrCopy = true; |
11170 | Ai_layoutK.resize(1); |
11171 | } else |
11172 | makeAiBiKCloneLayout(hw, true, Ai_layoutRem, Ai_layoutK, |
11173 | Ai_regsRem, ka_slm, strategy, state); |
11174 | |
11175 | aioShareRem = false; |
11176 | if (state.aioShare || state.aoReuseA) |
11177 | Ao_regsRem = state.ra.alloc_range( |
11178 | getRegCount(state.Ao_layout)); |
11179 | break; |
11180 | } |
11181 | } |
11182 | } |
11183 | |
11184 | if (Bi_incrementalRem) { |
11185 | Bi_layoutK.resize(kb_slm); |
11186 | Bi_addrsK.resize(kb_slm); |
11187 | for (int h = 0; h < kb_slm; h++) { |
11188 | bool success = false; |
11189 | |
11190 | if (h < int(Bi_addrsK.size())) { |
11191 | success = getSubblocks(Tb_ext, Bi_layoutK[h], Bi_addrsK[h], |
11192 | Bi_layoutRem, state.Bi_addrs, false, h, h + 1, |
11193 | state.Bi_strategy.padded, state.Bi, state.Bi_strategy); |
11194 | } |
11195 | |
11196 | if (!success && h == 0) stub(); |
11197 | |
11198 | if (!success) { |
11199 | Bi_addrsK.resize(1); |
11200 | success = getSubblocks(Tb_ext, Bi_layoutK[h], Bi_layoutRem, |
11201 | false, h, h + 1, state.Bi_strategy.padded, state.Bi, |
11202 | state.Bi_strategy); |
11203 | } |
11204 | |
11205 | if (!success) { |
11206 | if (strategy.slmUseIncrCopy) { |
11207 | Bi_remIncrCopy = true; |
11208 | Bi_layoutK.resize(1); |
11209 | } else |
11210 | makeAiBiKCloneLayout(hw, false, Bi_layoutRem, Bi_layoutK, |
11211 | Bi_regsRem, kb_slm, strategy, state); |
11212 | |
11213 | bioShareRem = false; |
11214 | if (state.bioShare || state.boReuseB) |
11215 | Bo_regsRem = state.ra.alloc_range( |
11216 | getRegCount(state.Bo_layout)); |
11217 | break; |
11218 | } |
11219 | } |
11220 | } |
11221 | |
11222 | // Allocate repack registers if we need to assemble multiple loads for |
11223 | // each outer product calculation. |
11224 | // TODO: allow allocation to overlap unneeded A/B registers. |
11225 | auto &repackARem = state.repackARem, &repackBRem = state.repackBRem; |
11226 | auto &ka_repackRem = state.ka_repackRem, &kb_repackRem = state.kb_repackRem; |
11227 | |
11228 | repackARem = state.repackA; |
11229 | repackBRem = state.repackB; |
11230 | ka_repackRem = state.repackA ? ka_loadRem : 0; |
11231 | kb_repackRem = state.repackB ? kb_loadRem : 0; |
11232 | if (minOPCount > 1) { |
11233 | int crosspackA, crosspackB, tileM_A, tileK_A, tileK_B, tileN_B; |
11234 | std::tie(crosspackA, crosspackB) |
11235 | = targetKernelCrosspack(hw, problem, strategy); |
11236 | std::tie(tileM_A, tileK_A, tileK_B, tileN_B) |
11237 | = targetKernelTiling(hw, problem, strategy); |
11238 | |
11239 | if (ka_loadRem < minOPCount) { |
11240 | ka_repackRem = minOPCount; |
11241 | if (!repackARem) { |
11242 | makeUnbackedRegLayout(Ta, state.Ar_layout, unrollM, |
11243 | ka_repackRem, isLayoutColMajor(state.A_layout), |
11244 | crosspackA, tileM_A, tileK_A); |
11245 | state.Ar_regs |
11246 | = state.ra.alloc_range(getRegCount(state.Ar_layout), |
11247 | getHint(HintType::A0, strategy)); |
11248 | repackARem = true; |
11249 | } |
11250 | } |
11251 | if (kb_loadRem < minOPCount) { |
11252 | kb_repackRem = minOPCount; |
11253 | if (!repackBRem) { |
11254 | makeUnbackedRegLayout(Tb, state.Br_layout, kb_repackRem, |
11255 | unrollN, isLayoutColMajor(state.B_layout), crosspackB, |
11256 | tileK_B, tileN_B); |
11257 | state.Br_regs |
11258 | = state.ra.alloc_range(getRegCount(state.Br_layout), |
11259 | getHint(HintType::B0, strategy)); |
11260 | repackBRem = true; |
11261 | } |
11262 | } |
11263 | } |
11264 | |
11265 | state.remActiveA = state.remActiveB = state.remActiveSLM = false; |
11266 | state.slmRemaskA = state.slmRemaskB = false; |
11267 | state.firstKLoopSegment = true; |
11268 | |
11269 | return true; |
11270 | } |
11271 | |
11272 | // Set up A/B addresses again in order to prepare for a non-remainder k loop. |
11273 | // Optionally offset addresses in the k dimension. |
11274 | template <HW hw> |
11275 | template <typename I> |
11276 | void gemm_kernel_generator_t<hw>::kLoopReset(const I &kOffset, |
11277 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
11278 | GEMMState &state) { |
11279 | auto Ta = problem.Ta, Ta_ext = problem.Ta_ext, Ta_load = state.Ta_load; |
11280 | auto Tb = problem.Tb, Tb_ext = problem.Tb_ext, Tb_load = state.Tb_load; |
11281 | auto globalA = strategy.slmA ? state.Ai : problem.A; |
11282 | auto globalB = strategy.slmB ? state.Bi : problem.B; |
11283 | auto globalAStrategy = strategy.slmA ? state.Ai_strategy : strategy.A; |
11284 | auto globalBStrategy = strategy.slmB ? state.Bi_strategy : strategy.B; |
11285 | auto globalAParams = strategy.slmA ? state.Ai_params : state.A_params; |
11286 | auto globalBParams = strategy.slmB ? state.Bi_params : state.B_params; |
11287 | |
11288 | if (kOffset.isValid()) { |
11289 | auto &A_offC = globalAParams.offC; |
11290 | auto A_offCOrig = A_offC; |
11291 | if (globalAStrategy.address2D) { |
11292 | if (A_offC == state.h0) |
11293 | A_offC = state.ra.alloc_sub<int32_t>( |
11294 | getHint(HintType::LongTerm, strategy)); |
11295 | A_offCOrig.isValid() ? add(1, A_offC, A_offCOrig, kOffset) |
11296 | : mov(1, A_offC, kOffset); |
11297 | } else |
11298 | gemmOffsetAk(kOffset, strategy.slmA ? state.effAi : state.effA, |
11299 | globalA, problem, strategy, state); |
11300 | |
11301 | if (strategy.prefetchA) { |
11302 | if (strategy.A_prefetch.address2D) { |
11303 | auto &Ap_offC = state.Ap_params.offC; |
11304 | auto Ap_offCOrig = Ap_offC; |
11305 | if (Ap_offC == A_offCOrig) |
11306 | Ap_offC = A_offC; |
11307 | else { |
11308 | if (Ap_offC == state.h0) |
11309 | Ap_offC = state.ra.alloc_sub<int32_t>( |
11310 | getHint(HintType::LongTerm, strategy)); |
11311 | Ap_offCOrig.isValid() |
11312 | ? add(1, Ap_offC, Ap_offCOrig, kOffset) |
11313 | : mov(1, Ap_offC, kOffset); |
11314 | } |
11315 | } else if (state.effAp != state.effA) |
11316 | gemmOffsetAk(kOffset, state.effAp, globalA, problem, strategy, |
11317 | state); |
11318 | } |
11319 | |
11320 | auto &B_offR = globalBParams.offR; |
11321 | auto B_offROrig = B_offR; |
11322 | if (globalBStrategy.address2D) { |
11323 | if (B_offR == A_offCOrig) |
11324 | B_offR = A_offC; |
11325 | else { |
11326 | if (B_offR == state.h0) |
11327 | B_offR = state.ra.alloc_sub<int32_t>( |
11328 | getHint(HintType::LongTerm, strategy)); |
11329 | B_offROrig.isValid() ? add(1, B_offR, B_offROrig, kOffset) |
11330 | : mov(1, B_offR, kOffset); |
11331 | } |
11332 | } else |
            gemmOffsetBk(kOffset, strategy.slmB ? state.effBi : state.effB,
                    globalB, problem, strategy, state);
11335 | |
11336 | if (strategy.prefetchB) { |
11337 | if (strategy.B_prefetch.address2D) { |
11338 | auto &Bp_offR = state.Bp_params.offR; |
11339 | auto Bp_offROrig = Bp_offR; |
11340 | if (Bp_offR == B_offROrig) |
11341 | Bp_offR = B_offR; |
11342 | else { |
11343 | if (Bp_offR == state.h0) |
11344 | Bp_offR = state.ra.alloc_sub<int32_t>( |
11345 | getHint(HintType::LongTerm, strategy)); |
11346 | Bp_offROrig.isValid() |
11347 | ? add(1, Bp_offR, Bp_offROrig, kOffset) |
11348 | : mov(1, Bp_offR, kOffset); |
11349 | } |
11350 | } else if (state.effBp != state.effB) |
11351 | gemmOffsetBk(kOffset, state.effBp, globalB, problem, strategy, |
11352 | state); |
11353 | } |
11354 | } |
11355 | |
11356 | gemmCacheLDABMultiples(problem, strategy, state); |
11357 | setupAddr(Ta_ext, state.Ap_addrs, state.effAp, state.Ap_layout, |
11358 | state.inputs.lda, globalA, strategy.A_prefetch, strategy, state, |
11359 | state.Ap_params, state.ldaMultiples); |
11360 | setupAddr(Tb_ext, state.Bp_addrs, state.effBp, state.Bp_layout, |
11361 | state.inputs.ldb, globalB, strategy.B_prefetch, strategy, state, |
11362 | state.Bp_params, state.ldbMultiples); |
11363 | setupAddr(Ta_ext, state.Ai_addrs, state.effAi, state.Ai_layout, |
11364 | state.inputs.lda, state.Ai, state.Ai_strategy, strategy, state, |
11365 | state.Ai_params, state.ldaMultiples); |
11366 | setupAddr(Tb_ext, state.Bi_addrs, state.effBi, state.Bi_layout, |
11367 | state.inputs.ldb, state.Bi, state.Bi_strategy, strategy, state, |
11368 | state.Bi_params, state.ldbMultiples); |
11369 | setupAddr(Ta, state.Ao_addrs, state.effAo, state.Ao_layout, Subregister(), |
11370 | state.Ao, state.Ao_strategy, strategy, state); |
11371 | setupAddr(Tb, state.Bo_addrs, state.effBo, state.Bo_layout, Subregister(), |
11372 | state.Bo, state.Bo_strategy, strategy, state); |
11373 | setupAddr(Ta_load, state.A_addrs, state.effA, state.A_layout, |
11374 | state.inputs.lda, problem.A, strategy.A, strategy, state, |
11375 | state.A_params, state.ldaMultiples); |
11376 | setupAddr(Tb_load, state.B_addrs, state.effB, state.B_layout, |
11377 | state.inputs.ldb, problem.B, strategy.B, strategy, state, |
11378 | state.B_params, state.ldbMultiples); |
11379 | releaseLDMultiples(state.ldaMultiples, state); |
11380 | releaseLDMultiples(state.ldbMultiples, state); |
11381 | releaseIndexVec(state); |
11382 | |
11383 | gemmCalcIncrements(problem, strategy, state); |
11384 | } |
11385 | |
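// Allocate a GRF for the k loop barrier header, if not already allocated.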
11386 | template <HW hw> |
void gemm_kernel_generator_t<hw>::kLoopAllocBarrierHeader(GEMMState &state) {
11388 | if (state.barrierHeader.isInvalid()) { |
11389 | state.barrierHeader = state.ra.alloc(); |
11390 | state.barrierReady = false; |
11391 | } |
11392 | } |
11393 | |
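// Retrieve the k loop barrier header, initializing it from r0 on first use.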
11394 | template <HW hw> |
GRF gemm_kernel_generator_t<hw>::kLoopGetBarrierHeader(GEMMState &state) {
11396 | kLoopAllocBarrierHeader(state); |
11397 | if (!state.barrierReady) { |
11398 | if (state.r0_info.isARF()) stub(); |
11399 | barrierheader(state.barrierHeader, GRF {state.r0_info.getBase()}); |
11400 | state.barrierReady = true; |
11401 | } |
11402 | |
11403 | return state.barrierHeader; |
11404 | } |
11405 | |
11406 | // Create one step of a sequence of inner loops for a GEMM-like kernel. |
11407 | template <HW hw> |
11408 | void gemm_kernel_generator_t<hw>::kLoop(KLoop type, const GEMMProblem &problem, |
11409 | GEMMStrategy &strategy, GEMMState &state) { |
11410 | auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc; |
11411 | auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext; |
11412 | auto Ta_load = state.Ta_load, Tb_load = state.Tb_load; |
11413 | |
11414 | bool cLoadAhead = strategy.cLoadAhead; |
11415 | auto opCountMain = outerProductCount(hw, problem, strategy); |
11416 | auto minOPCount = minOuterProductCount(hw, problem, strategy); |
11417 | auto opCountRem = minOPCount; |
11418 | |
11419 | auto A_copies = strategy.A_copies; |
11420 | auto B_copies = strategy.B_copies; |
11421 | auto slmCopies = strategy.slmCopies; |
11422 | auto slmBuffers = strategy.slmBuffers; |
11423 | auto ka_loadMain = strategy.ka_load, ka_loadRem = state.ka_loadRem; |
11424 | auto kb_loadMain = strategy.kb_load, kb_loadRem = state.kb_loadRem; |
11425 | auto ka_pfStride = strategy.ka_pfStride; |
11426 | auto kb_pfStride = strategy.kb_pfStride; |
11427 | bool slmA = strategy.slmA; |
11428 | bool slmB = strategy.slmB; |
11429 | bool slmASums = state.slmASums; |
11430 | bool slmBSums = state.slmBSums; |
11431 | bool a2D = isBlock2D(strategy.A.accessType); |
11432 | bool b2D = isBlock2D(strategy.B.accessType); |
11433 | bool ai2D = strategy.slmA && isBlock2D(state.Ai_strategy.accessType); |
11434 | bool bi2D = strategy.slmB && isBlock2D(state.Bi_strategy.accessType); |
11435 | auto unrollM = strategy.unroll[LoopM]; |
11436 | auto unrollN = strategy.unroll[LoopN]; |
11437 | auto unrollK = strategy.unroll[LoopK]; |
11438 | auto unrollKSLM = strategy.unrollKSLM; |
11439 | auto ka_slm = state.ka_slm; |
11440 | auto kb_slm = state.kb_slm; |
11441 | bool calcASums = problem.needsASums(); |
11442 | bool calcBSums = problem.needsBSums(); |
11443 | bool readA = true, readB = true; |
11444 | |
11445 | bool Ai_incrementalRem = state.Ai_incrementalRem; |
11446 | bool Bi_incrementalRem = state.Bi_incrementalRem; |
11447 | bool Ai_remIncrCopy = state.Ai_remIncrCopy; |
11448 | bool Bi_remIncrCopy = state.Bi_remIncrCopy; |
11449 | bool Ai_lateKRem = state.Ai_lateKRem; |
11450 | bool Bi_lateKRem = state.Bi_lateKRem; |
11451 | |
11452 | bool &remActiveA = state.remActiveA, &remActiveB = state.remActiveB; |
11453 | bool &remActiveSLM = state.remActiveSLM; |
11454 | auto &kMasksSLM = state.kMasksSLM; |
11455 | bool &slmRemaskA = state.slmRemaskA, &slmRemaskB = state.slmRemaskB; |
11456 | bool lateKLoopCheck = state.lateKLoopCheck; |
11457 | |
11458 | bool needBarrier = (slmA || slmB || strategy.barrierFreq > 0); |
11459 | bool nbM = (strategy.slmA || strategy.barrierFreq) |
11460 | && strategy.namedBarriers[LoopM]; |
11461 | bool nbN = (strategy.slmB || strategy.barrierFreq) |
11462 | && strategy.namedBarriers[LoopN]; |
11463 | |
11464 | bool needSLMReset = false; |
11465 | |
11466 | int curPhase; |
11467 | int lastThresh = 0; |
11468 | |
11469 | // Get r0 information where needed. |
11470 | GRF r0_info; |
11471 | if (needBarrier) { |
11472 | if (state.r0_info.isARF()) stub(); |
11473 | r0_info = GRF {state.r0_info.getBase()}; |
11474 | } |
11475 | |
11476 | // Unified barrier and SLM fence handling for k loop. |
11477 | auto &modBarrierFence = state.modBarrierFence; |
    auto &barrierHeader = state.barrierHeader;
11479 | auto &barrierReady = state.barrierReady; |
11480 | |
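    // Get a temporary GRF for fence/barrier messages, falling back on the barrier
    //   header if no registers are free.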
11481 | auto getFenceTemp = [&]() { |
11482 | auto temp = state.ra.try_alloc(); |
11483 | if (temp.isValid()) return temp; |
11484 | if (barrierHeader.isValid()) { |
11485 | barrierReady = false; |
11486 | return barrierHeader; |
11487 | } |
11488 | throw ngen::out_of_registers_exception(); |
11489 | }; |
11490 | |
11491 | auto releaseFenceTemp = [&](GRF temp) { |
11492 | if (temp != barrierHeader) state.ra.release(temp); |
11493 | }; |
11494 | |
11495 | GRF slmFenceTemp; |
11496 | auto slmFenceIssue = [&]() { |
11497 | if (hw >= HW::Gen11) { |
11498 | slmFenceTemp = getFenceTemp(); |
11499 | slmfence(modBarrierFence[0], slmFenceTemp, r0_info); |
11500 | releaseFenceTemp(slmFenceTemp); |
11501 | } |
11502 | }; |
11503 | |
11504 | auto slmFenceWait = [&]() { |
11505 | if (hw >= HW::Gen12LP) |
11506 | wrdep(slmFenceTemp); |
11507 | else if (hw >= HW::Gen11) |
11508 | mov<uint32_t>(8, null, slmFenceTemp); |
11509 | }; |
11510 | |
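    // Normal: full barrier (signal + wait); Signal/Wait: the two halves of a split barrier.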
11511 | enum class KBarrierType { Normal, Signal, Wait }; |
11512 | auto kLoopBarrier = [&](bool withSLMFence, |
11513 | KBarrierType type = KBarrierType::Normal) { |
11514 | withSLMFence &= (hw >= HW::Gen11); // No SLM fences needed on Gen9. |
11515 | |
11516 | if (withSLMFence && type == KBarrierType::Wait) { |
11517 | auto temp = getFenceTemp(); |
11518 | slmfence(modBarrierFence[0], temp, r0_info); |
11519 | (hw >= HW::Gen12LP) ? wrdep(temp) : mov<uint32_t>(8, null, temp); |
11520 | releaseFenceTemp(temp); |
11521 | } |
11522 | |
11523 | if (!nbM && !nbN) { |
11524 | if (type != KBarrierType::Wait) { |
11525 | kLoopAllocBarrierHeader(state); |
11526 | auto temp = getFenceTemp(); |
11527 | if (withSLMFence) { |
11528 | slmfence(modBarrierFence[0], temp, r0_info); |
11529 | (hw >= HW::Gen12LP) ? wrdep(temp) |
11530 | : mov<uint32_t>(8, null, temp); |
11531 | } |
                auto header = kLoopGetBarrierHeader(state);
11533 | barriermsg(modBarrierFence[0], header); |
11534 | releaseFenceTemp(temp); |
11535 | } |
11536 | if (type != KBarrierType::Signal) barrierwait(); |
11537 | } else { |
11538 | if (type != KBarrierType::Wait) { |
11539 | if (withSLMFence) { |
11540 | auto temp = getFenceTemp(); |
11541 | slmfence(temp, r0_info); |
11542 | wrdep(temp); |
11543 | releaseFenceTemp(temp); |
11544 | } |
11545 | if (nbM) barriermsg(modBarrierFence[0], state.barrierHeaderM); |
11546 | if (nbN) |
11547 | barriermsg( |
11548 | modBarrierFence[nbM ? 1 : 0], state.barrierHeaderN); |
11549 | } |
11550 | if (type != KBarrierType::Signal) { |
11551 | if (nbM) sync.bar(state.barrierM); |
11552 | if (nbN) sync.bar(state.barrierN); |
11553 | } |
11554 | } |
11555 | }; |
11556 | |
11557 | bool mustActivateRemainder = false; |
11558 | |
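    // Switch A/B loads between the main and remainder paths: swap in the fragmented
    //   layouts, adjust addresses, and recalculate lda/ldb increments.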
11559 | auto activateABRemainder = [&](bool active, bool doA, bool doB) { |
11560 | if (remActiveA == active) doA = false; |
11561 | if (remActiveB == active) doB = false; |
11562 | if (!active && ((doA && remActiveA) || (doB && remActiveB))) stub(); |
11563 | if (!doA && !doB) return; |
11564 | |
11565 | if (doA) remActiveA = active; |
11566 | if (doB) remActiveB = active; |
11567 | |
11568 | // Adjust A/B/Ai/Bi addresses if needed. |
11569 | if (doA) |
11570 | adjustSubblockAddrs(Ta_load, state.A_layoutRem, state.A_addrsRem, |
11571 | state.A_layout, state.A_addrs, problem.A, strategy.A, |
11572 | strategy, state); |
11573 | if (doB) |
11574 | adjustSubblockAddrs(Tb_load, state.B_layoutRem, state.B_addrsRem, |
11575 | state.B_layout, state.B_addrs, problem.B, strategy.B, |
11576 | strategy, state); |
11577 | |
11578 | if (doA && strategy.slmA && (state.effCoopA == CoopSplit::K) && !ai2D) { |
11579 | vector<RegisterBlock> tempLayout; |
11580 | vector<GRFRange> tempAddrs; |
11581 | if (!getSubblocks(Ta_ext, tempLayout, tempAddrs, state.Ai_layout, |
11582 | state.Ai_addrs, true, 0, 1, state.Ai_strategy.padded, |
11583 | state.Ai, state.Ai_strategy)) |
11584 | stub(); |
11585 | adjustSubblockAddrs(Ta_ext, tempLayout, tempAddrs, state.Ai_layout, |
11586 | state.Ai_addrs, state.Ai, state.Ai_strategy, strategy, |
11587 | state); |
11588 | } |
11589 | if (doB && strategy.slmB && (state.effCoopB == CoopSplit::K) && !bi2D) { |
11590 | vector<RegisterBlock> tempLayout; |
11591 | vector<GRFRange> tempAddrs; |
11592 | if (!getSubblocks(Tb_ext, tempLayout, tempAddrs, state.Bi_layout, |
11593 | state.Bi_addrs, false, 0, 1, state.Bi_strategy.padded, |
11594 | state.Bi, state.Bi_strategy)) |
11595 | stub(); |
11596 | adjustSubblockAddrs(Tb_ext, tempLayout, tempAddrs, state.Bi_layout, |
11597 | state.Bi_addrs, state.Bi, state.Bi_strategy, strategy, |
11598 | state); |
11599 | } |
11600 | |
11601 | if (doA && a2D && (ka_loadRem > 1)) |
11602 | setAddrRemainder(Ta_load, state.A_addrsRem, state.A_layoutRem, |
11603 | Subregister(), state.K, problem.A, strategy.A, strategy, |
11604 | state); |
11605 | if (doB && b2D && (kb_loadRem > 1)) |
11606 | setAddrRemainder(Tb_load, state.B_addrsRem, state.B_layoutRem, |
11607 | state.K, Subregister(), problem.B, strategy.B, strategy, |
11608 | state); |
11609 | |
11610 | // Recalculate lda_ka/ldb_kb if needed. |
11611 | gemmCalcIncrements( |
11612 | problem, strategy, state, ka_loadRem, kb_loadRem, doA, doB); |
11613 | }; |
11614 | |
11615 | Subregister kSLMStorage; |
11616 | Subregister kSLMA, kSLMB; // k remainders for k-split SLM loads |
11617 | |
11618 | auto resetKSLM = [&]() { |
11619 | state.ra.safeRelease(kSLMStorage); |
11620 | kSLMA = kSLMB = invalid; |
11621 | }; |
11622 | |
11623 | auto activateSLMRemainder = [&](bool active, int kOffset = 0) { |
11624 | // Calculate or recalculate SLM k remainders as needed. |
11625 | if (active && kSLMStorage.isInvalid()) { |
11626 | if (Ai_incrementalRem || Bi_incrementalRem) |
11627 | kSLMStorage = state.ra.alloc_sub<uint32_t>(); |
11628 | |
11629 | if (Ai_incrementalRem) { |
11630 | kSLMA = kSLMStorage.w(0); |
11631 | int kgran, kdiv, krep; |
11632 | switch (state.effCoopA) { |
11633 | case CoopSplit::MN: |
11634 | kgran = unrollKSLM; |
11635 | kdiv = 1; |
11636 | krep = strategy.wg[LoopN]; |
11637 | break; |
11638 | case CoopSplit::K: |
11639 | kgran = ka_slm; |
11640 | kdiv = strategy.wg[LoopN]; |
11641 | krep = 1; |
11642 | break; |
11643 | case CoopSplit::Linear: |
11644 | kgran = std::max(state.Ai.crosspack, state.Ai.tileC); |
11645 | kdiv = unrollKSLM / kgran; |
11646 | krep = strategy.wg[LoopN] / kdiv; |
11647 | break; |
11648 | default: stub(); |
11649 | } |
11650 | gemmCalcKSLM(kSLMA, state.lidN, kgran, kdiv, krep, problem, |
11651 | strategy, state); |
11652 | } |
11653 | |
11654 | if (Bi_incrementalRem) { |
11655 | kSLMB = kSLMStorage.w(1); |
11656 | int kgran, kdiv, krep; |
11657 | switch (state.effCoopB) { |
11658 | case CoopSplit::MN: |
11659 | kgran = unrollKSLM; |
11660 | kdiv = 1; |
11661 | krep = strategy.wg[LoopM]; |
11662 | break; |
11663 | case CoopSplit::K: |
11664 | kgran = kb_slm; |
11665 | kdiv = strategy.wg[LoopM]; |
11666 | krep = 1; |
11667 | break; |
11668 | case CoopSplit::Linear: |
11669 | kgran = std::max(state.Bi.crosspack, state.Bi.tileR); |
11670 | kdiv = unrollKSLM / kgran; |
11671 | krep = strategy.wg[LoopM] / kdiv; |
11672 | break; |
11673 | default: stub(); |
11674 | } |
11675 | gemmCalcKSLM(kSLMB, state.lidM, kgran, kdiv, krep, problem, |
11676 | strategy, state); |
11677 | } |
11678 | |
11679 | if ((Ai_incrementalRem || Bi_incrementalRem) && kOffset != 0) |
11680 | add(2, kSLMStorage.w()(1), kSLMStorage.w()(1), kOffset); |
11681 | } |
11682 | |
11683 | // k mask information. |
11684 | Subregister rems[3] |
11685 | = {state.remainders[LoopM], state.remainders[LoopN], state.K}; |
11686 | int offsets[3] = {0, 0, -kOffset}; |
11687 | |
11688 | // If not changing between main loop and remainder, update k masks as needed and return. |
11689 | if (remActiveSLM == active) { |
11690 | if (active) { |
11691 | state.wipeActiveVFlags(); |
11692 | loadMasks(kMasksSLM, rems, offsets, strategy, state); |
11693 | } |
11694 | return; |
11695 | } |
11696 | |
11697 | // Not possible to deactivate remainder path with late k remainder. |
11698 | if (!active && remActiveSLM && (Ai_lateKRem || Bi_lateKRem)) stub(); |
11699 | remActiveSLM = active; |
11700 | |
11701 | // Start using k masks if needed. |
11702 | if (Ai_lateKRem && !state.Ai_strategy.padded) { |
11703 | state.Ai_layoutRem = state.Ai_layout; |
11704 | state.Ai_addrsRem = state.Ai_addrs; |
11705 | addMasking(Ta_ext, state.Ai_layoutRem, state.Ai_addrsRem, |
11706 | state.inputs.lda, false, true, state.Ai, state.Ai_strategy, |
11707 | strategy, state, state.Ai_regCount); |
11708 | if (!assignMasks(state.Ai_layoutRem, LoopM, LoopK, kMasksSLM, |
11709 | strategy, state, true)) |
11710 | stub(); |
11711 | if (state.aioShare && state.Ao_regsRem.empty() |
11712 | && state.Ai_layoutRem[0].crosspack |
11713 | != state.Ai_layout[0].crosspack) { |
11714 | state.aioShareRem = false; |
11715 | state.Ao_regsRem |
11716 | = state.ra.alloc_range(getRegCount(state.Ao_layout)); |
11717 | } |
11718 | } |
11719 | if (Bi_lateKRem && !state.Bi_strategy.padded) { |
11720 | state.Bi_layoutRem = state.Bi_layout; |
11721 | state.Bi_addrsRem = state.Bi_addrs; |
11722 | addMasking(Tb_ext, state.Bi_layoutRem, state.Bi_addrsRem, |
11723 | state.inputs.ldb, true, false, state.Bi, state.Bi_strategy, |
11724 | strategy, state, state.Bi_regCount); |
11725 | if (!assignMasks(state.Bi_layoutRem, LoopK, LoopN, kMasksSLM, |
11726 | strategy, state, true)) |
11727 | stub(); |
11728 | if (state.bioShare && state.Bo_regsRem.empty() |
11729 | && state.Bi_layoutRem[0].crosspack |
11730 | != state.Bi_layout[0].crosspack) { |
11731 | state.bioShareRem = false; |
11732 | state.Bo_regsRem |
11733 | = state.ra.alloc_range(getRegCount(state.Bo_layout)); |
11734 | } |
11735 | } |
11736 | |
11737 | if (problem.backward()) |
11738 | for (auto &mask : kMasksSLM) |
11739 | mask.reverse(unrollKSLM); |
11740 | |
11741 | if (!state.vflagStorage.isValid()) { |
11742 | bool needVFlags = false; |
11743 | for (const auto &mask : kMasksSLM) |
11744 | needVFlags |= state.raVFlag.isVirtual(mask.flag); |
11745 | if (needVFlags) allocVFlagStorage(strategy, state); |
11746 | } |
11747 | loadMasks(kMasksSLM, rems, offsets, strategy, state); |
11748 | |
11749 | bool mayAccessAllK = (minOPCount > 1) || problem.sumA || problem.sumB; |
11750 | bool asIfMaskedAi = Ai_lateKRem && state.Ai_strategy.padded; |
11751 | bool asIfMaskedBi = Bi_lateKRem && state.Bi_strategy.padded; |
11752 | slmRemaskA = slmA && mayAccessAllK && !Ai_remIncrCopy |
11753 | && needsRemask(Ta_ext, true, state.Ai_layoutRem, |
11754 | state.Ai_strategy, asIfMaskedAi); |
11755 | slmRemaskB = slmB && mayAccessAllK && !Bi_remIncrCopy |
11756 | && needsRemask(Tb_ext, false, state.Bi_layoutRem, |
11757 | state.Bi_strategy, asIfMaskedBi); |
11758 | }; |
11759 | |
11760 | // Get state.K, the loop counter. |
11761 | // The caller may initialize state.K, in which case its value on entry is the loop count. |
11762 | // Otherwise, it is initialized from state.k. |
11763 | auto kInput = state.k; |
11764 | bool matchBarriers = (strategy.kParallelLocal && needBarrier); |
11765 | bool saveK = state.isNested || (problem.abOffset != ABOffset::None) |
11766 | || matchBarriers; |
11767 | bool incomingK = state.K.isValid(); |
11768 | |
11769 | if (!incomingK) state.K = saveK ? state.ra.alloc_sub<int32_t>() : kInput; |
11770 | |
11771 | if (saveK && !incomingK) mov(1, state.K, kInput); |
11772 | |
11773 | if (state.firstKLoopSegment) { |
11774 | // Zero out A/B sums if needed. |
11775 | if (calcASums) zeroMatrix(state.As_regs, strategy); |
11776 | if (calcBSums) zeroMatrix(state.Bs_regs, strategy); |
11777 | |
11778 | // Zero out C, if not loading ahead of time. |
11779 | if (!cLoadAhead) { |
11780 | for (int i = 0; i < state.C_accCount; i += 2) |
11781 | mov<uint32_t>(2 * elementsPerGRF<uint32_t>(hw), |
11782 | AccumulatorRegister(i), uint16_t(0)); |
11783 | |
11784 | for (int buf = 0; buf < state.C_buffers; buf++) |
11785 | zeroMatrix(state.C_regs[buf], strategy); |
11786 | } |
11787 | } |
11788 | |
11789 | LoopSequencer ls; |
11790 | using namespace loop_sequencer; |
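      | // The loop sequencer schedules the actions that make up the unrolled |
      | // k loop. Roughly: every(n) runs an action once per n iterations, |
      | // lookahead(n) issues it n iterations early, duration(n) is how many |
      | // iterations it covers, variants(n) rotates among n round-robin copies, |
      | // and checkOptional()/unconditional() control whether it may be skipped. |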
11791 | |
11792 | int slmBufferLA = 0; |
11793 | switch (slmBuffers) { |
11794 | case 0: |
11795 | case 1: slmBufferLA = 0; break; |
11796 | case 2: |
11797 | case 3: slmBufferLA = 1; break; |
11798 | case 4: slmBufferLA = 2; break; |
11799 | default: stub(); |
11800 | } |
11801 | |
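      | // Lookaheads: how many k iterations ahead of their use each load or |
      | // store must issue, given multiple register copies and, for SLM, the |
      | // distance between a store and its reload across slmBufferLA buffers. |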
11802 | int lookaheadALoad = ka_loadMain * (A_copies - 1); |
11803 | int lookaheadBLoad = kb_loadMain * (B_copies - 1); |
11804 | int lookaheadALoadRem = ka_loadRem * (A_copies - 1); |
11805 | int lookaheadBLoadRem = kb_loadRem * (B_copies - 1); |
11806 | int lookaheadSLMLoad = unrollKSLM * (slmCopies - 1) + unrollKSLM - 1; |
11807 | int lookaheadSLMStore = unrollKSLM * slmBufferLA + 1; |
11808 | |
11809 | if (slmA && slmB) { |
11810 | if (lookaheadALoad != lookaheadBLoad) stub(); |
11811 | if (lookaheadALoadRem != lookaheadBLoadRem) stub(); |
11812 | if (ka_loadMain != kb_loadMain && lookaheadALoad != lookaheadALoadRem) |
11813 | stub(); |
11814 | } |
11815 | |
11816 | int lookaheadSLMReload = slmA ? lookaheadALoad : lookaheadBLoad; |
11817 | int lookaheadSLMReloadRem = slmA ? lookaheadALoadRem : lookaheadBLoadRem; |
11818 | int durationSLMMainLoad = std::max(slmA * ka_loadMain, slmB * kb_loadMain); |
11819 | |
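      | // An iteration is "in remainder" for a given block size once fewer |
      | // iterations remain in the loop than are left in its current block. |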
11820 | auto A_remActive = [&](Iteration h) { |
11821 | return (h.remaining() < ka_loadMain - (h % ka_loadMain)); |
11822 | }; |
11823 | auto B_remActive = [&](Iteration h) { |
11824 | return (h.remaining() < kb_loadMain - (h % kb_loadMain)); |
11825 | }; |
11826 | auto slmRemActive = [&](Iteration h) { |
11827 | return (h.remaining() < unrollKSLM - (h % unrollKSLM)); |
11828 | }; |
11829 | auto opRemActive = [&](Iteration h) { |
11830 | return (h.remaining() < opCountMain - (h % opCountMain)); |
11831 | }; |
11832 | auto repackA = [&](Iteration h) { |
11833 | return A_remActive(h) ? state.repackARem : state.repackA; |
11834 | }; |
11835 | auto repackB = [&](Iteration h) { |
11836 | return B_remActive(h) ? state.repackBRem : state.repackB; |
11837 | }; |
11838 | auto ka_load = [&](Iteration h) { |
11839 | return A_remActive(h) ? ka_loadRem : ka_loadMain; |
11840 | }; |
11841 | auto kb_load = [&](Iteration h) { |
11842 | return B_remActive(h) ? kb_loadRem : kb_loadMain; |
11843 | }; |
11844 | auto A_copy = [&](Iteration h) { return (h / ka_load(h)) % A_copies; }; |
11845 | auto B_copy = [&](Iteration h) { return (h / kb_load(h)) % B_copies; }; |
11846 | auto A_regs = [&](Iteration h) -> GRFMultirange & { |
11847 | return state.A_regs[A_copy(h)]; |
11848 | }; |
11849 | auto B_regs = [&](Iteration h) -> GRFMultirange & { |
11850 | return state.B_regs[B_copy(h)]; |
11851 | }; |
11852 | auto A_layout = [&](Iteration h) -> vector<RegisterBlock> & { |
11853 | return A_remActive(h) ? state.A_layoutRem : state.A_layout; |
11854 | }; |
11855 | auto B_layout = [&](Iteration h) -> vector<RegisterBlock> & { |
11856 | return B_remActive(h) ? state.B_layoutRem : state.B_layout; |
11857 | }; |
11858 | auto Ar_regs = [&](Iteration h) -> GRFMultirange & { |
11859 | return repackA(h) ? state.Ar_regs : A_regs(h); |
11860 | }; |
11861 | auto Br_regs = [&](Iteration h) -> GRFMultirange & { |
11862 | return repackB(h) ? state.Br_regs : B_regs(h); |
11863 | }; |
11864 | auto Ar_layout = [&](Iteration h) -> vector<RegisterBlock> & { |
11865 | return repackA(h) ? state.Ar_layout : A_layout(h); |
11866 | }; |
11867 | auto Br_layout = [&](Iteration h) -> vector<RegisterBlock> & { |
11868 | return repackB(h) ? state.Br_layout : B_layout(h); |
11869 | }; |
11870 | auto slmCopy = [&](Iteration h) { return (h / unrollKSLM) % slmCopies; }; |
11871 | auto slmBuffer = [&](Iteration h) { return (h / unrollKSLM) % slmBuffers; }; |
11872 | auto Ai_layout = [&](Iteration h) -> vector<RegisterBlock> & { |
11873 | return slmRemActive(h) ? state.Ai_layoutRem : state.Ai_layout; |
11874 | }; |
11875 | auto Bi_layout = [&](Iteration h) -> vector<RegisterBlock> & { |
11876 | return slmRemActive(h) ? state.Bi_layoutRem : state.Bi_layout; |
11877 | }; |
11878 | auto Ai_addrs = [&](Iteration h) -> vector<GRFRange> & { |
11879 | return slmRemActive(h) ? state.Ai_addrsRem : state.Ai_addrs; |
11880 | }; |
11881 | auto Bi_addrs = [&](Iteration h) -> vector<GRFRange> & { |
11882 | return slmRemActive(h) ? state.Bi_addrsRem : state.Bi_addrs; |
11883 | }; |
11884 | auto Ai_allRegs = [&](Iteration h) -> vector<GRFMultirange> & { |
11885 | return slmRemActive(h) ? state.Ai_regsRem : state.Ai_regs; |
11886 | }; |
11887 | auto Bi_allRegs = [&](Iteration h) -> vector<GRFMultirange> & { |
11888 | return slmRemActive(h) ? state.Bi_regsRem : state.Bi_regs; |
11889 | }; |
11890 | auto Ai_regs = [&](Iteration h) -> GRFMultirange & { |
11891 | return Ai_allRegs(h)[slmCopy(h)]; |
11892 | }; |
11893 | auto Bi_regs = [&](Iteration h) -> GRFMultirange & { |
11894 | return Bi_allRegs(h)[slmCopy(h)]; |
11895 | }; |
11896 | auto Ao_regs = [&](Iteration h) -> GRFMultirange & { |
11897 | return slmRemActive(h) ? state.Ao_regsRem : state.Ao_regs; |
11898 | }; |
11899 | auto Bo_regs = [&](Iteration h) -> GRFMultirange & { |
11900 | return slmRemActive(h) ? state.Bo_regsRem : state.Bo_regs; |
11901 | }; |
11902 | auto effAo_regs = [&](Iteration h) -> GRFMultirange & { |
11903 | return Ao_regs(h).empty() ? Ai_regs(h) : Ao_regs(h); |
11904 | }; |
11905 | auto effBo_regs = [&](Iteration h) -> GRFMultirange & { |
11906 | return Bo_regs(h).empty() ? Bi_regs(h) : Bo_regs(h); |
11907 | }; |
11908 | auto aioShare = [&](Iteration h) { |
11909 | return slmRemActive(h) ? state.aioShareRem : state.aioShare; |
11910 | }; |
11911 | auto bioShare = [&](Iteration h) { |
11912 | return slmRemActive(h) ? state.bioShareRem : state.bioShare; |
11913 | }; |
11914 | auto opCount = [&](Iteration h) { |
11915 | return opRemActive(h) ? opCountRem : opCountMain; |
11916 | }; |
11917 | auto nothing = [&](Iteration h) {}; |
11918 | |
11919 | // Dummy task to extend k unroll if needed. |
11920 | ls.schedule(every(unrollK) | checkOptional(), nothing); |
11921 | |
11922 | // A prefetch. |
11923 | auto reqPFA = every(ka_pfStride) |
11924 | | duration( |
11925 | strategy.cooperativePF ? ka_pfStride : strategy.ka_prefetch) |
11926 | | lookahead(strategy.prefetchA); |
11927 | |
11928 | if (strategy.prefetchA && !strategy.slmA && readA) { |
11929 | ls.schedule(reqPFA, [&](Iteration h) { |
11930 | auto &A_global = strategy.slmA ? state.Ai : problem.A; |
11931 | gemmALoad(state.Ap_regs, state.Ap_layout, state.Ap_addrs, A_global, |
11932 | strategy.A_prefetch, problem, strategy, state); |
11933 | }); |
11934 | } |
11935 | |
11936 | // B prefetch. |
11937 | auto reqPFB = every(kb_pfStride) |
11938 | | duration( |
11939 | strategy.cooperativePF ? kb_pfStride : strategy.kb_prefetch) |
11940 | | lookahead(strategy.prefetchB); |
11941 | |
11942 | if (strategy.prefetchB && !strategy.slmB && readB) { |
11943 | ls.schedule(reqPFB, [&](Iteration h) { |
11944 | auto &B_global = strategy.slmB ? state.Bi : problem.B; |
11945 | gemmBLoad(state.Bp_regs, state.Bp_layout, state.Bp_addrs, B_global, |
11946 | strategy.B_prefetch, problem, strategy, state); |
11947 | }); |
11948 | } |
11949 | |
11950 | // SLM loads. |
11951 | auto reqSLMLoad = every(unrollKSLM) | variants(slmCopies) |
11952 | | lookahead( |
11953 | lookaheadSLMLoad + lookaheadSLMStore + lookaheadSLMReload); |
11954 | auto reqSLMLoadABRem = every(unrollKSLM) | variants(slmCopies) |
11955 | | lookahead(lookaheadSLMLoad + lookaheadSLMStore |
11956 | + lookaheadSLMReloadRem); |
11957 | auto reqSLMStore = every(unrollKSLM) | variants(slmCopies) |
11958 | | lookahead(lookaheadSLMStore + lookaheadSLMReload) |
11959 | | duration(durationSLMMainLoad); |
11960 | auto reqSLMStoreABRem = every(unrollKSLM) | variants(slmCopies) |
11961 | | lookahead(lookaheadSLMStore + lookaheadSLMReloadRem); |
11962 | |
11963 | if ((slmA || slmB) && mustActivateRemainder) { |
11964 | ls.schedule({{reqSLMLoad | duration(unrollKSLM), nothing}, |
11965 | {reqSLMLoad | unconditional(), [&](Iteration h) { |
11966 | activateSLMRemainder(true, h.counterOffset()); |
11967 | }}}); |
11968 | } |
11969 | |
11970 | auto doSLMRemLoad = [&](Iteration h) { |
11971 | activateSLMRemainder(true, h.counterOffset()); |
11972 | if (slmA) |
11973 | gemmAiBiRemLoadInc<true>(Ai_incrementalRem, Ai_remIncrCopy, |
11974 | needSLMReset, slmRemaskA, kSLMA, Ai_regs(h), |
11975 | state.Ai_layoutRem, state.Ai_addrsRem, state.Ai_layoutK, |
11976 | state.Ai_addrsK, state.Ao_regsRem, state.Ao_layout, |
11977 | state.Ai, state.Ai_strategy, problem, strategy, state); |
11978 | if (slmB) |
11979 | gemmAiBiRemLoadInc<false>(Bi_incrementalRem, Bi_remIncrCopy, |
11980 | needSLMReset, slmRemaskB, kSLMB, Bi_regs(h), |
11981 | state.Bi_layoutRem, state.Bi_addrsRem, state.Bi_layoutK, |
11982 | state.Bi_addrsK, state.Bo_regsRem, state.Bo_layout, |
11983 | state.Bi, state.Bi_strategy, problem, strategy, state); |
11984 | if (Ai_incrementalRem || Bi_incrementalRem) lastThresh = 0; |
11985 | }; |
11986 | |
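      | // Schedule the SLM loads: the first entry covers whole-block loads in |
      | // the main loop; the later entries appear to be the remainder-path |
      | // alternatives, using incremental loading where enabled. |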
11987 | if (slmA || slmB) { |
11988 | ls.schedule({{reqSLMLoad | duration(unrollKSLM), |
11989 | [&](Iteration h) { |
11990 | activateSLMRemainder(false); |
11991 | if (slmA) |
11992 | gemmALoad(Ai_regs(h), state.Ai_layout, |
11993 | state.Ai_addrs, state.Ai, |
11994 | state.Ai_strategy, problem, |
11995 | strategy, state); |
11996 | if (slmB) |
11997 | gemmBLoad(Bi_regs(h), state.Bi_layout, |
11998 | state.Bi_addrs, state.Bi, |
11999 | state.Bi_strategy, problem, |
12000 | strategy, state); |
12001 | }}, |
12002 | {reqSLMLoad | duration(durationSLMMainLoad), doSLMRemLoad}, |
12003 | {reqSLMLoadABRem, doSLMRemLoad}}); |
12004 | } |
12005 | |
12006 | // Read suppression workaround (W/A) for fused EU architectures. |
12007 | bool rswaA = strategy.readSuppressionWA && (A_copies == 1) |
12008 | && ((ka_loadMain <= opCountMain) || state.repackA) |
12009 | && hasMasking(state.A_layout); |
12010 | bool rswaB = strategy.readSuppressionWA && (B_copies == 1) |
12011 | && ((kb_loadMain <= opCountMain) || state.repackB) |
12012 | && hasMasking(state.B_layout); |
12013 | bool rswaARem = strategy.readSuppressionWA && (A_copies == 1) |
12014 | && ((ka_loadRem <= opCountRem) || state.repackARem) |
12015 | && hasMasking(state.A_layoutRem); |
12016 | bool rswaBRem = strategy.readSuppressionWA && (B_copies == 1) |
12017 | && ((kb_loadRem <= opCountRem) || state.repackBRem) |
12018 | && hasMasking(state.B_layoutRem); |
12019 | |
12020 | Iteration A_lastRSWA; |
12021 | bool haveA_lastRSWA = false; |
12022 | |
12023 | bool saveRSWA; |
12024 | auto disableRSWA = [&]() { |
12025 | saveRSWA = strategy.readSuppressionWA; |
12026 | strategy.readSuppressionWA = false; |
12027 | }; |
12028 | auto restoreRSWA = [&]() { strategy.readSuppressionWA = saveRSWA; }; |
12029 | |
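      | // Issue the read suppression workaround at most once per iteration: |
      | // if A already applied it at this h, the B load skips it. |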
12030 | auto doRSWA_A = [&](Iteration h) { |
12031 | A_lastRSWA = h; |
12032 | haveA_lastRSWA = true; |
12033 | doReadSuppressionWA(strategy, state); |
12034 | }; |
12035 | |
12036 | auto doRSWA_B = [&](Iteration h) { |
12037 | if (!(haveA_lastRSWA && A_lastRSWA == h)) |
12038 | doReadSuppressionWA(strategy, state); |
12039 | haveA_lastRSWA = false; |
12040 | }; |
12041 | |
12042 | // A/B load scheduling. |
12043 | auto reqLoadA = every(ka_loadMain) | duration(ka_loadMain) |
12044 | | variants(A_copies) | lookahead(lookaheadALoad); |
12045 | auto reqLoadARem = every(ka_loadRem) | variants(A_copies) |
12046 | | lookahead(lookaheadALoadRem); |
12047 | auto reqLoadAPrezero = every(minOPCount) | variants(A_copies) |
12048 | | lookahead(state.repackARem ? 0 : lookaheadALoadRem); |
12049 | |
12050 | auto reqLoadB = every(kb_loadMain) | duration(kb_loadMain) |
12051 | | variants(B_copies) | lookahead(lookaheadBLoad); |
12052 | auto reqLoadBRem = every(kb_loadRem) | variants(B_copies) |
12053 | | lookahead(lookaheadBLoadRem); |
12054 | auto reqLoadBPrezero = every(minOPCount) | variants(B_copies) |
12055 | | lookahead(state.repackBRem ? 0 : lookaheadBLoadRem); |
12056 | |
12057 | // A/B prezeroing for partial remainder loads with multi-k outer products. |
12058 | bool prezeroARem = !slmA && (ka_loadRem < minOPCount) && readA; |
12059 | bool prezeroBRem = !slmB && (kb_loadRem < minOPCount) && readB; |
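      | // (A partial k load leaves stale data in the unwritten positions of the |
      | // load block; zeroing the registers first makes those positions |
      | // contribute nothing to the multi-k outer product.) |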
12060 | |
12061 | if (prezeroARem && prezeroBRem && Ta.isInteger() && Tb.isInteger() |
12062 | && !calcASums && !calcBSums) { |
12063 | // Only need to pre-zero one operand for integer A/B. Choose the smaller one. |
12064 | if (unrollM >= unrollN) |
12065 | prezeroARem = false; |
12066 | else |
12067 | prezeroBRem = false; |
12068 | } |
12069 | |
12070 | if (prezeroARem) |
12071 | ls.schedule({{reqLoadA, nothing}, |
12072 | {reqLoadAPrezero, [&](Iteration h) { |
12073 | zeroMatrix(state.repackARem ? state.Ar_regs : A_regs(h), |
12074 | strategy); |
12075 | }}}); |
12076 | |
12077 | if (prezeroBRem) |
12078 | ls.schedule({{reqLoadB, nothing}, |
12079 | {reqLoadBPrezero, [&](Iteration h) { |
12080 | zeroMatrix(state.repackBRem ? state.Br_regs : B_regs(h), |
12081 | strategy); |
12082 | }}}); |
12083 | |
12084 | // A/B enforced remainder preparations. |
12085 | if (mustActivateRemainder) { |
12086 | ls.schedule({{reqLoadA, nothing}, |
12087 | {reqLoadARem | unconditional(), [&](Iteration h) { |
12088 | activateABRemainder(true, true, false); |
12089 | }}}); |
12090 | ls.schedule({{reqLoadB, nothing}, |
12091 | {reqLoadBRem | unconditional(), [&](Iteration h) { |
12092 | activateABRemainder(true, false, true); |
12093 | }}}); |
12094 | } |
12095 | |
12096 | // A loads. |
12097 | if (readA) |
12098 | ls.schedule({{reqLoadA, |
12099 | [&](Iteration h) { |
12100 | if (rswaA) doRSWA_A(h); |
12101 | disableRSWA(); |
12102 | activateABRemainder(false, true, false); |
12103 | gemmALoad(A_regs(h), state.A_layout, |
12104 | state.A_addrs, problem.A, strategy.A, |
12105 | problem, strategy, state); |
12106 | restoreRSWA(); |
12107 | }}, |
12108 | {reqLoadARem, [&](Iteration h) { |
12109 | if (rswaARem) doRSWA_A(h); |
12110 | disableRSWA(); |
12111 | activateABRemainder(true, true, false); |
12112 | gemmALoad(A_regs(h), state.A_layoutRem, state.A_addrsRem, |
12113 | problem.A, strategy.A, problem, strategy, state); |
12114 | restoreRSWA(); |
12115 | }}}); |
12116 | |
12117 | // B loads. |
12118 | if (readB) |
12119 | ls.schedule({{reqLoadB, |
12120 | [&](Iteration h) { |
12121 | if (rswaB) doRSWA_B(h); |
12122 | disableRSWA(); |
12123 | activateABRemainder(false, false, true); |
12124 | gemmBLoad(B_regs(h), state.B_layout, |
12125 | state.B_addrs, problem.B, strategy.B, |
12126 | problem, strategy, state); |
12127 | restoreRSWA(); |
12128 | }}, |
12129 | {reqLoadBRem, [&](Iteration h) { |
12130 | if (rswaBRem) doRSWA_B(h); |
12131 | disableRSWA(); |
12132 | activateABRemainder(true, false, true); |
12133 | gemmBLoad(B_regs(h), state.B_layoutRem, state.B_addrsRem, |
12134 | problem.B, strategy.B, problem, strategy, state); |
12135 | restoreRSWA(); |
12136 | }}}); |
12137 | |
12138 | // Stalls to promote thread switches. |
12139 | auto reqStall = every(lcm(ka_loadMain, kb_loadMain)) | checkOptional(); |
12140 | |
12141 | if (strategy.stallAfterLoad) |
12142 | ls.schedule(reqStall, [&](Iteration h) { |
12143 | if (hw < HW::Gen12LP) |
12144 | mov<uint32_t>(1 | Switch, null, 0); |
12145 | else if (Tc.isInteger()) { |
12146 | mov<float>(1, null, 0.0f); |
12147 | sync.nop(SWSB<float>(1)); |
12148 | } else { |
12149 | mov<uint32_t>(1, null, 0); |
12150 | sync.nop(SWSB<uint32_t>(1)); |
12151 | } |
12152 | }); |
12153 | |
12154 | // k decrement and loop check. |
12155 | auto reqLoopCheck = every(unrollK) | duration(unrollK); |
12156 | |
12157 | if (lateKLoopCheck) |
12158 | reqLoopCheck = reqLoopCheck.delay( |
12159 | unrollK - std::min(ka_loadMain, kb_loadMain)); |
12160 | |
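      | // Decrement K by unrollK, setting a flag when iterations remain; the |
      | // LoopEnd callback branches back on it. Emitted in the main loop only. |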
12161 | ls.schedule_if( |
12162 | reqLoopCheck, |
12163 | [&](Iteration h) { |
12164 | add(1 | gt | f0[0], state.K, state.K, -unrollK); |
12165 | }, |
12166 | [&](Iteration h) { |
12167 | return (curPhase == LoopSequencer::PhaseMainLoop); |
12168 | }); |
12169 | |
12170 | // SLM store address increments. |
12171 | auto doSLMStoreInc = [&](Iteration h) { |
12172 | int kIncSLMStore |
12173 | = (slmBuffer(h) == slmBuffers - 1) ? -(slmBuffers - 1) : +1; |
12174 | kIncSLMStore *= unrollKSLM; |
12175 | if (slmA) |
12176 | gemmAIncrement(Ta, state.Ao_layout, state.Ao_addrs, state.Ao, |
12177 | state.Ao_strategy, kIncSLMStore, problem, strategy, state); |
12178 | if (slmB) |
12179 | gemmBIncrement(Tb, state.Bo_layout, state.Bo_addrs, state.Bo, |
12180 | state.Bo_strategy, kIncSLMStore, problem, strategy, state); |
12181 | }; |
12182 | |
12183 | if (strategy.slmBuffers >= 2) { |
12184 | ls.schedule({{(reqSLMStore | duration(durationSLMMainLoad)).delay(1), |
12185 | doSLMStoreInc}, |
12186 | {reqSLMStoreABRem.delay(1), doSLMStoreInc}}); |
12187 | } |
12188 | |
12189 | // SLM load address increments. |
12190 | int delaySLMInc = strategy.delayABInc ? (unrollKSLM >> 1) : 0; |
12191 | |
12192 | auto doSLMLoadInc = [&](Iteration h) { |
12193 | bool fullLoad = (h.remaining() >= (unrollKSLM - delaySLMInc)); |
12194 | if (slmA && (fullLoad || !Ai_incrementalRem)) |
12195 | gemmAIncrement(Ta_ext, Ai_layout(h), Ai_addrs(h), state.Ai, |
12196 | state.Ai_strategy, unrollKSLM, problem, strategy, state); |
12197 | if (slmB && (fullLoad || !Bi_incrementalRem)) |
12198 | gemmBIncrement(Tb_ext, Bi_layout(h), Bi_addrs(h), state.Bi, |
12199 | state.Bi_strategy, unrollKSLM, problem, strategy, state); |
12200 | }; |
12201 | |
12202 | auto checkSLMLoadInc = [&](Iteration h) { |
12203 | bool fullLoad = (h.remaining() >= (unrollKSLM - delaySLMInc)); |
12204 | return (slmA && (fullLoad || !Ai_incrementalRem)) |
12205 | || (slmB && (fullLoad || !Bi_incrementalRem)); |
12206 | }; |
12207 | |
12208 | if (slmA || slmB) { |
12209 | ls.schedule_if({{(reqSLMLoad | duration(durationSLMMainLoad)) |
12210 | .delay(delaySLMInc), |
12211 | doSLMLoadInc, checkSLMLoadInc}, |
12212 | {reqSLMLoadABRem.delay(delaySLMInc), doSLMLoadInc, |
12213 | checkSLMLoadInc}}); |
12214 | } |
12215 | |
12216 | // A prefetch address increment. |
12217 | int delayAPFInc = strategy.delayABInc ? (ka_pfStride >> 1) : 0; |
12218 | |
12219 | if (strategy.prefetchA && !slmA && readA) { |
12220 | ls.schedule(reqPFA.delay(delayAPFInc), [&](Iteration h) { |
12221 | gemmAIncrement(Ta_ext, state.Ap_layout, state.Ap_addrs, problem.A, |
12222 | strategy.A_prefetch, ka_pfStride, problem, strategy, state); |
12223 | }); |
12224 | } |
12225 | |
12226 | // B prefetch address increment. |
12227 | int delayBPFInc = strategy.delayABInc ? (kb_pfStride >> 1) : 0; |
12228 | |
12229 | if (strategy.prefetchB && !slmB && readB) { |
12230 | ls.schedule(reqPFB.delay(delayBPFInc), [&](Iteration h) { |
12231 | gemmBIncrement(Tb_ext, state.Bp_layout, state.Bp_addrs, problem.B, |
12232 | strategy.B_prefetch, kb_pfStride, problem, strategy, state); |
12233 | }); |
12234 | } |
12235 | |
12236 | // A address increment. |
12237 | int delayAInc |
12238 | = (strategy.delayABInc && A_copies > 1) ? (ka_loadMain >> 1) : 0; |
12239 | |
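      | // When A streams from SLM, its addresses cycle through a ring of |
      | // slmBuffers buffers; when an increment would step past the end of the |
      | // ring, wrap around by subtracting the full ring size. |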
12240 | auto ka_inc = [&](Iteration h) { |
12241 | auto inc = ka_load(h); |
12242 | if (slmA) { |
12243 | int kWraparound = unrollKSLM * slmBuffers; |
12244 | if ((h + inc) % kWraparound < inc) inc -= kWraparound; |
12245 | } |
12246 | return inc; |
12247 | }; |
12248 | |
12249 | if (readA) |
12250 | ls.schedule({{reqLoadA.delay(delayAInc), |
12251 | [&](Iteration h) { |
12252 | gemmAIncrement(Ta_load, state.A_layout, |
12253 | state.A_addrs, problem.A, strategy.A, |
12254 | ka_inc(h), problem, strategy, state); |
12255 | }}, |
12256 | {reqLoadARem, [&](Iteration h) { |
12257 | gemmAIncrement(Ta_load, state.A_layoutRem, |
12258 | state.A_addrsRem, problem.A, strategy.A, ka_inc(h), |
12259 | problem, strategy, state, h % unrollKSLM); |
12260 | }}}); |
12261 | |
12262 | // B address increment. |
12263 | int delayBInc |
12264 | = (strategy.delayABInc && B_copies > 1) ? (kb_loadMain >> 1) : 0; |
12265 | |
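      | // B addresses wrap around the SLM buffer ring the same way as A above. |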
12266 | auto kb_inc = [&](Iteration h) { |
12267 | auto inc = kb_load(h); |
12268 | if (slmB) { |
12269 | int kWraparound = unrollKSLM * slmBuffers; |
12270 | if ((h + inc) % kWraparound < inc) inc -= kWraparound; |
12271 | } |
12272 | return inc; |
12273 | }; |
12274 | |
12275 | if (readB) |
12276 | ls.schedule({{reqLoadB.delay(delayBInc), |
12277 | [&](Iteration h) { |
12278 | gemmBIncrement(Tb_load, state.B_layout, |
12279 | state.B_addrs, problem.B, strategy.B, |
12280 | kb_inc(h), problem, strategy, state); |
12281 | }}, |
12282 | {reqLoadBRem, [&](Iteration h) { |
12283 | gemmBIncrement(Tb_load, state.B_layoutRem, |
12284 | state.B_addrsRem, problem.B, strategy.B, kb_inc(h), |
12285 | problem, strategy, state, h % unrollKSLM); |
12286 | }}}); |
12287 | |
12288 | // A/B remasking in k dimension, during remainder handling. |
12289 | bool remaskA = !slmA && readA && (minOPCount > 1) |
12290 | && needsRemask(Ta_load, true, state.A_layoutRem, strategy.A); |
12291 | bool remaskB = !slmB && readB && (minOPCount > 1) |
12292 | && needsRemask(Tb_load, false, state.B_layoutRem, strategy.B); |
12293 | |
12294 | if (remaskA && remaskB && Ta.isInteger() && Tb.isInteger() && !calcASums |
12295 | && !calcBSums) { |
12296 | // Only need to remask one operand for integer A/B. Choose the smaller one. |
12297 | if (unrollM >= unrollN) |
12298 | remaskA = false; |
12299 | else |
12300 | remaskB = false; |
12301 | } |
12302 | |
12303 | auto Tremask = remaskA ? Ta_load : Tb_load; |
12304 | if (remaskA && remaskB && Ta_load.size() != Tb_load.size()) stub(); |
12305 | if ((remaskA || remaskB) && problem.backward()) stub(); |
12306 | |
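      | // Remask data stays valid for lcm(ka_loadRem, kb_loadRem) iterations, |
      | // after which it is torn down and rebuilt at the new k offset. |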
12307 | int remaskPeriod = lcm(ka_loadRem, kb_loadRem); |
12308 | auto reqRemaskSetup = every(remaskPeriod); |
12309 | auto reqRemaskA = every(ka_loadRem) | variants(A_copies); |
12310 | auto reqRemaskB = every(kb_loadRem) | variants(B_copies); |
12311 | |
12312 | if (remaskA || remaskB) |
12313 | ls.schedule({{reqRemaskSetup | duration(remaskPeriod), nothing}, |
12314 | {reqRemaskSetup, [&](Iteration h) { |
12315 | setupTeardownRemask(Tremask, 0, false, remaskPeriod, |
12316 | state.K, strategy, state); |
12317 | setupTeardownRemask(Tremask, 0, true, remaskPeriod, |
12318 | state.K, strategy, state, -h.counterOffset()); |
12319 | }}}); |
12320 | |
12321 | if (remaskA) |
12322 | ls.schedule({{reqLoadA, nothing}, |
12323 | {reqRemaskA, [&](Iteration h) { |
12324 | remaskLayout(Ta_load, 0, true, state.A_layoutRem, |
12325 | A_regs(h), strategy, state, h % remaskPeriod); |
12326 | }}}); |
12327 | |
12328 | if (remaskB) |
12329 | ls.schedule({{reqLoadB, nothing}, |
12330 | {reqRemaskB, [&](Iteration h) { |
12331 | remaskLayout(Tb_load, 0, false, state.B_layoutRem, |
12332 | B_regs(h), strategy, state, h % remaskPeriod); |
12333 | }}}); |
12334 | |
12335 | // A/B repacking. |
12336 | auto reqRepackA = every(ka_loadMain) | variants(A_copies); |
12337 | auto reqRepackARem = every(ka_loadRem) | variants(A_copies); |
12338 | bool convertA = (Ta != Ta_load) && (Ta.size() == Ta_load.size()); |
12339 | |
12340 | if (state.repackA || state.repackARem || convertA) |
12341 | if (readA) |
12342 | ls.schedule({{reqRepackA, |
12343 | [&](Iteration h) { |
12344 | if (state.repackA) |
12345 | copyRegisters(Ta_load, Ta, |
12346 | state.A_layout, |
12347 | state.Ar_layout, A_regs(h), |
12348 | state.Ar_regs, 0, 0, false, |
12349 | strategy, state); |
12350 | else if (convertA) |
12351 | convert(A_regs(h), Ta_load, Ta, |
12352 | problem, strategy, state); |
12353 | }}, |
12354 | {reqRepackARem, [&](Iteration h) { |
12355 | if (state.repackARem) |
12356 | copyRegisters(Ta_load, Ta, state.A_layoutRem, |
12357 | state.Ar_layout, A_regs(h), state.Ar_regs, |
12358 | 0, h % state.ka_repackRem, false, strategy, |
12359 | state); |
12360 | else if (convertA) |
12361 | convert(A_regs(h), Ta_load, Ta, problem, strategy, |
12362 | state); |
12363 | }}}); |
12364 | |
12365 | auto reqRepackB = every(kb_loadMain) | variants(B_copies); |
12366 | auto reqRepackBRem = every(kb_loadRem) | variants(B_copies); |
12367 | bool convertB = (Tb != Tb_load) && (Tb.size() == Tb_load.size()); |
12368 | |
12369 | if (state.repackB || state.repackBRem || convertB) |
12370 | if (readB) |
12371 | ls.schedule({{reqRepackB, |
12372 | [&](Iteration h) { |
12373 | if (state.repackB) |
12374 | copyRegisters(Tb_load, Tb, |
12375 | state.B_layout, |
12376 | state.Br_layout, B_regs(h), |
12377 | state.Br_regs, 0, 0, false, |
12378 | strategy, state); |
12379 | else if (convertB) |
12380 | convert(B_regs(h), Tb_load, Tb, |
12381 | problem, strategy, state); |
12382 | }}, |
12383 | {reqRepackBRem, [&](Iteration h) { |
12384 | if (state.repackBRem) |
12385 | copyRegisters(Tb_load, Tb, state.B_layoutRem, |
12386 | state.Br_layout, B_regs(h), state.Br_regs, |
12387 | h % state.kb_repackRem, 0, false, strategy, |
12388 | state); |
12389 | else if (convertB) |
12390 | convert(B_regs(h), Tb_load, Tb, problem, strategy, |
12391 | state); |
12392 | }}}); |
12393 | |
12394 | // Outer product(s). |
12395 | // If outer products are batched across k (dp4a/dpas/k-chaining), trigger every opCount loops. |
12396 | auto reqOP = every(minOPCount) | lookahead(-(minOPCount - 1)); |
12397 | |
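      | // Accumulate A/B sums over a full load block when the layout's k |
      | // dimension is contiguous; otherwise, accumulate per outer product. |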
12398 | int ka_sumMain |
12399 | = !isLayoutColMajor(state.A_layout) ? ka_loadMain : opCountMain; |
12400 | int kb_sumMain |
12401 | = isLayoutColMajor(state.B_layout) ? kb_loadMain : opCountMain; |
12402 | |
12403 | ls.schedule(reqOP, [&](Iteration h) { |
12404 | auto oc = opCount(h); |
12405 | auto hNext = h + minOPCount; |
12406 | if (hNext % oc != 0) return; |
12407 | |
12408 | int ka = ka_load(h), kb = kb_load(h); |
12409 | int ha = h % ka; |
12410 | int hb = h % kb; |
12411 | if (problem.backward()) { |
12412 | ha = ka - 1 - ha; |
12413 | hb = kb - 1 - hb; |
12414 | } |
12415 | |
12416 | auto &layoutA = Ar_layout(h); |
12417 | auto &layoutB = Br_layout(h); |
12418 | auto ®sA = Ar_regs(h); |
12419 | auto ®sB = Br_regs(h); |
12420 | |
12421 | outerProduct(h, ha, hb, oc, layoutA, layoutB, regsA, regsB, problem, |
12422 | strategy, state); |
12423 | |
12424 | if (calcASums && !slmASums && !state.systolicSumA) { |
12425 | int ka_sum = (curPhase == LoopSequencer::PhaseMainLoop) ? ka_sumMain |
12426 | : oc; |
12427 | int ha0 = ha - oc + minOPCount; |
12428 | if (ha0 % ka_sum == 0) |
12429 | accumulateSum(false, Ta, regsA, layoutA, Tc, state.As_regs, |
12430 | state.As_layout, strategy, state, ha0, ha0 + ka_sum); |
12431 | } |
12432 | |
12433 | if (calcBSums && !slmBSums && !state.systolicSumB) { |
12434 | int kb_sum = (curPhase == LoopSequencer::PhaseMainLoop) ? kb_sumMain |
12435 | : oc; |
12436 | int hb0 = hb - oc + minOPCount; |
12437 | if (hb0 % kb_sum == 0) |
12438 | accumulateSum(true, Tb, regsB, layoutB, Tc, state.Bs_regs, |
12439 | state.Bs_layout, strategy, state, hb0, hb0 + kb_sum); |
12440 | } |
12441 | }); |
12442 | |
12443 | // SLM data repacking and remasking. |
12444 | auto reqSLMRepack = every(unrollKSLM) | variants(slmCopies) |
12445 | | lookahead(lookaheadSLMStore + lookaheadSLMReload |
12446 | + strategy.slmRepackAhead) |
12447 | | duration(durationSLMMainLoad); |
12448 | auto reqSLMRepackABRem = every(unrollKSLM) | variants(slmCopies) |
12449 | | lookahead(lookaheadSLMStore + lookaheadSLMReloadRem |
12450 | + strategy.slmRepackAhead); |
12451 | |
12452 | auto slmConvertA = [&](Iteration h) { |
12453 | return slmA && aioShare(h) && (Ta != Ta_ext) |
12454 | && (Ta.size() == Ta_ext.size()); |
12455 | }; |
12456 | auto slmConvertB = [&](Iteration h) { |
12457 | return slmB && bioShare(h) && (Tb != Tb_ext) |
12458 | && (Tb.size() == Tb_ext.size()); |
12459 | }; |
12460 | |
12461 | auto doSLMRepack = [&](Iteration h) { |
12462 | if (slmA && !aioShare(h) && !(slmRemActive(h) && Ai_remIncrCopy)) |
12463 | copyRegisters(Ta_ext, Ta, Ai_layout(h), state.Ao_layout, Ai_regs(h), |
12464 | Ao_regs(h), 0, 0, false, strategy, state); |
12465 | else if (slmConvertA(h)) |
12466 | convert(Ai_regs(h), Ta_ext, Ta, problem, strategy, state); |
12467 | |
12468 | if (slmB && !bioShare(h) && !(slmRemActive(h) && Bi_remIncrCopy)) |
12469 | copyRegisters(Tb_ext, Tb, Bi_layout(h), state.Bo_layout, Bi_regs(h), |
12470 | Bo_regs(h), 0, 0, false, strategy, state); |
12471 | else if (slmConvertB(h)) |
12472 | convert(Bi_regs(h), Tb_ext, Tb, problem, strategy, state); |
12473 | |
12474 | if (slmRemActive(h) && (slmRemaskA || slmRemaskB)) { |
12475 | releaseMaskAssignments(kMasksSLM, |
12476 | state); // Not in use -- can temporarily free these. |
12477 | gemmSLMRemask(slmRemaskA, slmRemaskB, effAo_regs(h), effBo_regs(h), |
12478 | -h.counterOffset(), problem, strategy, state); |
12479 | reclaimMaskAssignments(kMasksSLM, state); |
12480 | } |
12481 | }; |
12482 | |
12483 | auto checkSLMRepack = [&](Iteration h) { |
12484 | return (slmA && !aioShare(h) && !(slmRemActive(h) && Ai_remIncrCopy)) |
12485 | || (slmB && !bioShare(h) |
12486 | && !(slmRemActive(h) && Bi_remIncrCopy)) |
12487 | || (slmRemActive(h) && (slmRemaskA || slmRemaskB)) |
12488 | || slmConvertA(h) || slmConvertB(h); |
12489 | }; |
12490 | |
12491 | if (slmA || slmB) { |
12492 | ls.schedule_if({{reqSLMRepack, doSLMRepack, checkSLMRepack}, |
12493 | {reqSLMRepackABRem, doSLMRepack, checkSLMRepack}}); |
12494 | } |
12495 | |
12496 | // SLM stores and synchronization. |
12497 | auto reqSLMAfterStore = every(unrollKSLM) | variants(slmCopies) |
12498 | | lookahead(lookaheadSLMStore + lookaheadSLMReload - unrollKSLM) |
12499 | | duration(durationSLMMainLoad); |
12500 | auto reqSLMAfterStore2 = every(unrollKSLM) | variants(slmCopies) |
12501 | | lookahead(lookaheadSLMStore + lookaheadSLMReload - 2 * unrollKSLM) |
12502 | | duration(durationSLMMainLoad); |
12503 | auto reqSLMAfterStoreABRem = every(unrollKSLM) | variants(slmCopies) |
12504 | | lookahead(lookaheadSLMStore + lookaheadSLMReloadRem - unrollKSLM); |
12505 | auto reqSLMAfterStoreABRem2 = every(unrollKSLM) | variants(slmCopies) |
12506 | | lookahead( |
12507 | lookaheadSLMStore + lookaheadSLMReloadRem - 2 * unrollKSLM); |
12508 | |
12509 | auto slm1x2xFencedBarrier = [&]() { |
12510 | // For DG2+, before 1x/2x buffered stores, we must ensure prior SLM reads are complete. |
12511 | // Use a fence for >2x global buffering. |
12512 | // For 2x global buffering, use SWSB since loaded data will be used shortly. |
12513 | // For 1x global buffering, loaded data has already been consumed. |
12514 | if (hw < HW::XeHPG && !strategy.strictFence) |
12515 | kLoopBarrier(false); |
12516 | else if ((A_copies > 2 || B_copies > 2) && !strategy.slmFenceWARWA) |
12517 | kLoopBarrier(true); |
12518 | else { |
12519 | if (slmA && A_copies > 1) wrdepRanges(state.A_regs); |
12520 | if (slmB && B_copies > 1) wrdepRanges(state.B_regs); |
12521 | kLoopBarrier(false); |
12522 | } |
12523 | }; |
12524 | |
12525 | auto doSLMAfterStore2 = [&](Iteration h) { |
12526 | switch (slmBuffers) { |
12527 | case 1: |
12528 | case 2: |
12529 | case 3: break; |
12530 | case 4: kLoopBarrier(false, KBarrierType::Wait); break; |
12531 | default: stub(); |
12532 | } |
12533 | }; |
12534 | |
12535 | auto doSLMAfterStore = [&](Iteration h) { |
12536 | switch (slmBuffers) { |
12537 | case 1: break; |
12538 | case 2: slm1x2xFencedBarrier(); break; |
12539 | case 3: kLoopBarrier(false, KBarrierType::Wait); break; |
12540 | case 4: |
12541 | // TEMP: move me earlier. |
12542 | slmFenceIssue(); |
12543 | // |
12544 | slmFenceWait(); |
12545 | if (strategy.slmFenceWARWA) { |
12546 | // Work around buggy SLM fence by ensuring SLM reads complete. |
12547 | if (slmA && A_copies > 1) wrdepRanges(state.A_regs); |
12548 | if (slmB && B_copies > 1) wrdepRanges(state.B_regs); |
12549 | } |
12550 | kLoopBarrier(false, KBarrierType::Signal); |
12551 | break; |
12552 | } |
12553 | }; |
12554 | |
12555 | auto doSLMStore = [&](Iteration h) { |
12556 | if (!slmA && !slmB) return; |
12557 | |
12558 | switch (slmBuffers) { |
12559 | case 1: slm1x2xFencedBarrier(); break; |
12560 | case 2: |
12561 | case 3: |
12562 | case 4: break; |
12563 | default: stub(); |
12564 | } |
12565 | |
12566 | if (slmA) |
12567 | storeMatrix(effAo_regs(h), state.Ao_layout, state.Ao, |
12568 | state.Ao_strategy, state.Ao_addrs, strategy, state); |
12569 | if (slmB) |
12570 | storeMatrix(effBo_regs(h), state.Bo_layout, state.Bo, |
12571 | state.Bo_strategy, state.Bo_addrs, strategy, state); |
12572 | |
12573 | if (slmASums) |
12574 | accumulateSum(false, Ta, effAo_regs(h), state.Ao_layout, Tc, |
12575 | state.As_regs, state.As_layout, strategy, state); |
12576 | if (slmBSums) |
12577 | accumulateSum(true, Tb, effBo_regs(h), state.Bo_layout, Tc, |
12578 | state.Bs_regs, state.Bs_layout, strategy, state); |
12579 | |
12580 | switch (slmBuffers) { |
12581 | case 1: kLoopBarrier(true); break; |
12582 | case 2: |
12583 | slmFenceIssue(); |
12584 | slmFenceWait(); |
12585 | break; |
12586 | case 3: |
12587 | if (strategy.slmFenceWARWA) { |
12588 | // Work around buggy SLM fence by ensuring SLM reads complete. |
12589 | // Should be moved later, just before the barrier. |
12590 | if (slmA && A_copies > 1) wrdepRanges(state.A_regs); |
12591 | if (slmB && B_copies > 1) wrdepRanges(state.B_regs); |
12592 | } |
12593 | kLoopBarrier(true, KBarrierType::Signal); |
12594 | break; |
12595 | case 4: break; |
12596 | default: stub(); |
12597 | } |
12598 | }; |
12599 | |
12600 | if (slmBuffers > 0) { |
12601 | if (slmBuffers >= 4) |
12602 | ls.schedule({{reqSLMAfterStore2, doSLMAfterStore2}, |
12603 | {reqSLMAfterStoreABRem2, doSLMAfterStore2}}); |
12604 | |
12605 | if (slmBuffers >= 2) |
12606 | ls.schedule({{reqSLMAfterStore, doSLMAfterStore}, |
12607 | {reqSLMAfterStoreABRem, doSLMAfterStore}}); |
12608 | |
12609 | ls.schedule( |
12610 | {{reqSLMStore, doSLMStore}, {reqSLMStoreABRem, doSLMStore}}); |
12611 | } |
12612 | |
12613 | // Save pre-loop state. |
12614 | auto statePreLoop = state; |
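      | // (Restored at PhaseShortLoop below so the short k loop re-derives its |
      | // remainder state from scratch.) |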
12615 | |
12616 | using CT = LoopSequencer::CallbackType; |
12617 | |
12618 | Label lTop, lBottom; |
12619 | std::vector<Label> labels; |
12620 | |
12621 | ls.analyze(); |
12622 | |
12623 | if (ls.getUnroll() != unrollK) |
12624 | stub(); // Auto-calculated unroll should match unrollK from strategy. |
12625 | |
12626 | // Prepare to set aside outer loop iterations for periodic barriers, if needed. |
12627 | Subregister outerK; |
12628 | if (strategy.barrierFreq > 0) outerK = state.ra.alloc_sub<uint32_t>(); |
12629 | |
12630 | // Prepare to peel loops for C prefetch, if needed. |
12631 | int prefetchCPeelLoops = -1; |
12632 | Subregister pfCPeelK; |
12633 | if (strategy.prefetchC > 0) { |
12634 | prefetchCPeelLoops = div_up( |
12635 | std::max(0, strategy.prefetchC - ls.getCooldown()), unrollK); |
12636 | if (prefetchCPeelLoops > 0) pfCPeelK = state.ra.alloc_sub<uint32_t>(); |
12637 | } |
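      | // (pfCPeelK iterations are held back from K at LoopStart, then returned |
      | // together with the C prefetch once the main loop body drains.) |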
12638 | |
12639 | auto resetForNewLoop = [&]() { |
12640 | resetKSLM(); |
12641 | lastThresh = 0; |
12642 | haveA_lastRSWA = false; |
12643 | state.ra.safeRelease(barrierHeader); |
12644 | setupTeardownRemask( |
12645 | Tremask, 0, false, remaskPeriod, state.K, strategy, state); |
12646 | }; |
12647 | |
12648 | // Main events in lifetime of loop. |
12649 | ls.setCallback(CT::OffsetCounter, |
12650 | [&](int offset, int) { add(1, state.K, state.K, offset); }); |
12651 | ls.setCallback(CT::LoopStart, [&](int unroll, int) { |
12652 | cmp(1 | le | state.flagAP, state.K, 0); |
12653 | if (prefetchCPeelLoops > 0) { |
12654 | min_(1, pfCPeelK, state.K, prefetchCPeelLoops * unrollK); |
12655 | add(1, state.K, state.K, -pfCPeelK); |
12656 | } |
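      | // Split K into inner segments of at most barrierFreq iterations; |
      | // outerK holds the iterations still to run after the next barrier. |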
12657 | if (strategy.barrierFreq > 0) { |
12658 | add(1 | sat, outerK, state.K, -strategy.barrierFreq); |
12659 | min_(1, state.K, state.K, strategy.barrierFreq); |
12660 | if (strategy.splitBarrier) |
12661 | kLoopBarrier(false, KBarrierType::Signal); |
12662 | } |
12663 | if (hw >= HW::Gen12LP) sync.nop(SWSB(Pipe::A, 1)); |
12664 | jmpi(1 | state.flagAP, lBottom); |
12665 | mark(lTop); |
12666 | state.wipeActiveVFlags(); |
12667 | }); |
12668 | ls.setCallback(CT::LoopEnd, [&](int, int) { |
12669 | jmpi(1 | state.flagAP, lTop); |
12670 | if (strategy.barrierFreq > 0) { |
12671 | add(1, state.K, state.K, outerK); |
12672 | add(1 | sat, outerK, outerK, int16_t(-strategy.barrierFreq)); |
12673 | add(1 | gt | state.flagAP, state.K, state.K, -outerK); |
12674 | if (strategy.splitBarrier) { |
12675 | kLoopBarrier(false, KBarrierType::Wait); |
12676 | kLoopBarrier(false, KBarrierType::Signal); |
12677 | } else |
12678 | kLoopBarrier(false); |
12679 | jmpi(1 | state.flagAP, lTop); |
12680 | } |
12681 | if (prefetchCPeelLoops > 0) { |
12682 | add(1 | gt | state.flagAP, state.K, state.K, pfCPeelK); |
12683 | mov(1, pfCPeelK, 0); |
12684 | gemmPrefetchC(problem, strategy, state); |
12685 | jmpi(1 | state.flagAP, lTop); |
12686 | } |
12687 | mark(lBottom); |
12688 | state.wipeActiveVFlags(); |
12689 | }); |
12690 | ls.setCallback(CT::JumpIfLT, [&](int thresh, int label) { |
12691 | if (size_t(label) >= labels.size()) labels.resize(label + 1); |
12692 | if (thresh != lastThresh) cmp(1 | lt | state.flagAP, state.K, thresh); |
12693 | jmpi(1 | state.flagAP, labels[label]); |
12694 | lastThresh = thresh; |
12695 | }); |
12696 | ls.setCallback(CT::JumpTarget, [&](int label, int) { |
12697 | mark(labels[label]); |
12698 | state.wipeActiveVFlags(); |
12699 | }); |
12700 | ls.setCallback(CT::Jump, [&](int label, int) { |
12701 | if (size_t(label) >= labels.size()) labels.resize(label + 1); |
12702 | jmpi(1, labels[label]); |
12703 | }); |
12704 | ls.setCallback(CT::NotifyPhase, [&](int phase, int) { |
12705 | curPhase = phase; |
12706 | switch (phase) { |
12707 | case LoopSequencer::PhaseWarmup: |
12708 | status << "k loop warmup" << status_stream::endl; |
12709 | break; |
12710 | case LoopSequencer::PhaseMainLoop: |
12711 | status << "Main k loop" << status_stream::endl; |
12712 | break; |
12713 | case LoopSequencer::PhaseMainPathEnd: |
12714 | if (strategy.barrierFreq > 0 && strategy.splitBarrier) |
12715 | kLoopBarrier(false, KBarrierType::Wait); |
12716 | break; |
12717 | case LoopSequencer::PhaseCooldown: |
12718 | if (prefetchCPeelLoops == 0) |
12719 | gemmPrefetchC(problem, strategy, state); |
12720 | if (lateKLoopCheck) state.raVFlag.lock(state.flagAP); |
12721 | haveA_lastRSWA = false; |
12722 | status << "k loop cooldown" << status_stream::endl; |
12723 | break; |
12724 | case LoopSequencer::PhaseShortLoop: |
12725 | if (strategy.prefetchC > 0) |
12726 | gemmPrefetchC(problem, strategy, state); |
12727 | status << "Short k loop" << status_stream::endl; |
12728 | remActiveA = remActiveB = remActiveSLM = false; |
12729 | resetForNewLoop(); |
12730 | state = statePreLoop; |
12731 | break; |
12732 | case LoopSequencer::PhaseRemainder: |
12733 | status << "k loop remainder" << status_stream::endl; |
12734 | break; |
12735 | default: break; |
12736 | } |
12737 | }); |
12738 | |
12739 | // Early C prefetch. |
12740 | if (strategy.prefetchC < 0) gemmPrefetchC(problem, strategy, state); |
12741 | |
12742 | // Generate k loop. |
12743 | if (lateKLoopCheck) state.raVFlag.unlock(state.flagAP); |
12744 | syncall(); /* Avoid unnecessary SWSB dependencies entering loop. */ |
12745 | ls.materialize(); |
12746 | |
12747 | // Release barrier header from short k loop. |
12748 | state.ra.safeRelease(barrierHeader); |
12749 | |
12750 | // Additional barriers to match other threads' barrier count, if other threads might have different k. |
12751 | if (matchBarriers) { |
12752 | status << "Match barrier counts between threads" << status_stream::endl; |
12753 | Subregister myBarriers, k0Barriers; |
12754 | Label lSkipExtraBarriers, lExtraBarrierLoop; |
12755 | int maxExtraBarriers = maxExtraKLoopRemBarriers(strategy); |
12756 | |
12757 | if (strategy.barrierFreq > 0 && prefetchCPeelLoops > 0) stub(); |
12758 | |
12759 | gemmCalcKLoopBarrierCount(k0Barriers, state.inputs.k0, ls.getCooldown(), |
12760 | problem, strategy, state); |
12761 | gemmCalcKLoopBarrierCount(myBarriers, state.k, ls.getCooldown(), |
12762 | problem, strategy, state); |
12763 | if (maxExtraBarriers > 0) |
12764 | add(1, k0Barriers, k0Barriers, maxExtraBarriers); |
12765 | add(1 | sat | le | state.flagAP, myBarriers.ud(), k0Barriers, |
12766 | -myBarriers); |
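      | // myBarriers now holds the number of extra barriers this thread owes |
      | // (saturated at zero); skip the loop if there are none. |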
12767 | (void)kLoopGetBarrierHeader(state); |
12768 | jmpi(1 | state.flagAP, lSkipExtraBarriers); |
12769 | |
12770 | mark(lExtraBarrierLoop); |
12771 | { |
12772 | add(1 | gt | state.flagAP, myBarriers, myBarriers, -1); |
12773 | kLoopBarrier(false); |
12774 | jmpi(1 | state.flagAP, lExtraBarrierLoop); |
12775 | } |
12776 | mark(lSkipExtraBarriers); |
12777 | |
12778 | state.ra.safeRelease(myBarriers); |
12779 | state.ra.safeRelease(k0Barriers); |
12780 | if (!strategy.persistent) state.ra.safeRelease(state.inputs.k0); |
12781 | } |
12782 | |
12783 | // Free resources that are no longer needed. |
12784 | state.ra.safeRelease(outerK); |
12785 | state.ra.safeRelease(pfCPeelK); |
12786 | setupTeardownRemask( |
12787 | Tremask, 0, false, remaskPeriod, state.K, strategy, state); |
12788 | |
12789 | state.firstKLoopSegment = false; |
12790 | } |
12791 | |
12792 | template <HW hw> |
12793 | void gemm_kernel_generator_t<hw>::kLoopTeardown(const GEMMProblem &problem, |
12794 | const GEMMStrategy &strategy, GEMMState &state) { |
12795 | if (state.K != state.k) state.ra.safeRelease(state.K); |
12796 | state.barrierReady = false; |
12797 | state.ra.safeRelease(state.barrierHeader); |
12798 | state.ra.safeRelease(state.barrierHeaderM); |
12799 | state.ra.safeRelease(state.barrierHeaderN); |
12800 | state.raVFlag.safeRelease(state.barrierM); |
12801 | state.raVFlag.safeRelease(state.barrierN); |
12802 | safeReleaseMaskAssignments(state.kMasksSLM, state); |
12803 | safeReleaseRanges(state.Ao_regsRem, state); |
12804 | safeReleaseRanges(state.Bo_regsRem, state); |
12805 | state.tokenAllocator.safeRelease(state.tokenBarrierFence[0]); |
12806 | state.tokenAllocator.safeRelease(state.tokenBarrierFence[1]); |
12807 | } |
12808 | |
12809 | // Create 1-segment inner loop for a GEMM-like kernel. |
12810 | template <HW hw> |
12811 | bool gemm_kernel_generator_t<hw>::kLoopSingle(KLoop type, |
12812 | const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
12813 | bool ok = kLoopSetup(problem, strategy, state); |
12814 | if (ok) { |
12815 | kLoop(type, problem, strategy, state); |
12816 | kLoopTeardown(problem, strategy, state); |
12817 | } |
12818 | return ok; |
12819 | } |
12820 | |
12821 | template <HW hw> |
12822 | bool gemm_kernel_generator_t<hw>::gemmKLoop( |
12823 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
12824 | return kLoopSingle(KLoop::GEMM, problem, strategy, state); |
12825 | } |
12826 | |
12827 | // Decide whether C layout needs m/n remainder handling. |
12828 | static inline void getCRemainders(const GEMMProblem &problem, |
12829 | const GEMMStrategy &strategy, bool &remM_C, bool &remN_C) { |
12830 | bool remainderM |
12831 | = (strategy.remHandling[LoopM] != RemainderHandling::Ignore); |
12832 | bool remainderN |
12833 | = (strategy.remHandling[LoopN] != RemainderHandling::Ignore); |
12834 | |
12835 | int C_mgran, C_ngran; |
12836 | getGranularities(problem.C, C_mgran, C_ngran); |
12837 | |
12838 | remM_C = remainderM && !strategy.C.padded && !strategy.altCRemainder |
12839 | && (C_mgran < strategy.unroll[LoopM]); |
12840 | remN_C = remainderN && !strategy.C.padded && !strategy.altCRemainder |
12841 | && (C_ngran < strategy.unroll[LoopN]); |
12842 | } |
12843 | |
12844 | static inline bool needsKLoopReset(const GEMMProblem &problem) { |
12845 | return false; |
12846 | } |
12847 | |
12848 | // Setup for C accumulation. |
12849 | // NOTE: modifies problem/strategy/state. |
12850 | template <HW hw> |
12851 | bool gemm_kernel_generator_t<hw>::gemmAccumulateCSetup( |
12852 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
12853 | auto &Ta = problem.Ta, &Tb = problem.Tb, Tc = problem.Tc; |
12854 | auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext, |
12855 | Tc_ext = problem.Tc_ext; |
12856 | auto &Ta_load = state.Ta_load, &Tb_load = state.Tb_load; |
12857 | |
12858 | bool cLoadAhead = strategy.cLoadAhead; |
12859 | auto unrollM = strategy.unroll[LoopM]; |
12860 | auto unrollN = strategy.unroll[LoopN]; |
12861 | auto unrollK = strategy.unroll[LoopK]; |
12862 | |
12863 | // Decide what remainder handling needs to be done. |
12864 | bool remainderM = strategy.remHandling[LoopM] != RemainderHandling::Ignore; |
12865 | bool remainderN = strategy.remHandling[LoopN] != RemainderHandling::Ignore; |
12866 | bool remainderK = strategy.remHandling[LoopK] != RemainderHandling::Ignore; |
12867 | bool remM_A = remainderM && !strategy.A.padded; |
12868 | bool remK_A = false; |
12869 | bool remK_B = false; |
12870 | bool remN_B = remainderN && !strategy.B.padded; |
12871 | bool remM_C, remN_C; |
12872 | getCRemainders(problem, strategy, remM_C, remN_C); |
12873 | bool remM_Ce = remM_C; |
12874 | bool remN_Ce = remN_C; |
12875 | |
12876 | if (state.copyC) remM_C = remN_C = false; |
12877 | |
12878 | auto globalA = problem.A; |
12879 | auto globalB = problem.B; |
12880 | |
12881 | // 2D addressing parameters. |
12882 | auto &A_params = state.A_params, &B_params = state.B_params; |
12883 | auto &Ai_params = state.Ai_params, &Bi_params = state.Bi_params; |
12884 | auto &Ap_params = state.Ap_params, &Bp_params = state.Bp_params; |
12885 | A_params.rows = state.inputs.m; |
12886 | A_params.cols = state.fullK; |
12887 | A_params.offR = state.i0; |
12888 | A_params.offC = state.h0; |
12889 | A_params.remR = state.remainders[LoopM]; |
12890 | B_params.rows = state.fullK; |
12891 | B_params.cols = state.inputs.n; |
12892 | B_params.offR = state.h0; |
12893 | B_params.offC = state.j0; |
12894 | B_params.remC = state.remainders[LoopN]; |
12895 | Ai_params = A_params, Bi_params = B_params; |
12896 | Ap_params = A_params, Bp_params = B_params; |
12897 | |
12898 | // Decide which dimensions to split for WG-cooperative operations (SLM copy, cooperative PF). |
12899 | state.effCoopA = effCoopSplitA(problem, strategy); |
12900 | state.effCoopB = effCoopSplitB(problem, strategy); |
12901 | |
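      | // With an m/n remainder, avoid MN-split SLM copies unless 2D block |
      | // access can mask them natively; otherwise fall back to a K split with |
      | // plain block/scattered access. |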
12902 | if (strategy.slmA && (state.effCoopA != CoopSplit::K) && remM_A |
12903 | && !isBlock2D(strategy.A.accessType)) { |
12904 | strategy.A.accessType = isColMajor(problem.A.layout) |
12905 | ? AccessType::Block |
12906 | : AccessType::Scattered; |
12907 | state.effCoopA = CoopSplit::K; |
12908 | } |
12909 | |
12910 | if (strategy.slmB && (state.effCoopB != CoopSplit::K) && remN_B |
12911 | && !isBlock2D(strategy.B.accessType)) { |
12912 | strategy.B.accessType = !isColMajor(problem.B.layout) |
12913 | ? AccessType::Block |
12914 | : AccessType::Scattered; |
12915 | state.effCoopB = CoopSplit::K; |
12916 | } |
12917 | |
12918 | // Prepare layouts for prefetch. |
12919 | bool remM_Cp = remM_C && strategy.C.base.isStateless(); |
12920 | bool remN_Cp = remN_C && strategy.C.base.isStateless(); |
12921 | |
12922 | state.ma_prefetch = state.ka_prefetch = state.kb_prefetch |
12923 | = state.nb_prefetch = 0; |
12924 | if (strategy.prefetchA) |
12925 | coopSplit(true, state.ma_prefetch, state.ka_prefetch, unrollM, |
12926 | strategy.ka_prefetch, state.effCoopA, strategy.wg[LoopN], |
12927 | problem.A); |
12928 | if (strategy.prefetchB) |
12929 | coopSplit(false, state.kb_prefetch, state.nb_prefetch, |
12930 | strategy.kb_prefetch, unrollN, state.effCoopB, |
12931 | strategy.wg[LoopM], problem.B); |
12932 | |
12933 | if (strategy.prefetchA |
12934 | && !getRegLayout(Ta_ext, state.Ap_layout, state.ma_prefetch, |
12935 | state.ka_prefetch, remM_A, remK_A, false, true, 0, 0, |
12936 | problem.A, strategy.A_prefetch)) |
12937 | return false; |
12938 | if (strategy.prefetchB |
12939 | && !getRegLayout(Tb_ext, state.Bp_layout, state.kb_prefetch, |
12940 | state.nb_prefetch, remK_B, remN_B, false, true, 0, 0, |
12941 | problem.B, strategy.B_prefetch)) |
12942 | return false; |
12943 | if (strategy.prefetchC |
12944 | && !getRegLayout(Tc_ext, state.Cp_layout, unrollM, unrollN, remM_Cp, |
12945 | remN_Cp, false, true, 0, 0, problem.C, strategy.C_prefetch)) |
12946 | return false; |
12947 | |
12948 | if (hasMasking(state.Cp_layout) || hasFragmenting(state.Cp_layout)) stub(); |
12949 | |
12950 | // Prepare addresses for prefetch. |
12951 | if (strategy.cooperativePF && strategy.prefetchA) { |
12952 | Subregister offAp; |
12953 | gemmCalcWorkshareAOffset(offAp, Ap_params.offR, Ap_params.offC, |
12954 | problem.A, strategy.A_prefetch, state.ma_prefetch, |
12955 | state.ka_prefetch, problem, strategy, state); |
12956 | if (strategy.A_prefetch.address2D) { |
12957 | if (A_params.offR.isValid()) |
12958 | add(1, Ap_params.offR, Ap_params.offR, A_params.offR); |
12959 | if (A_params.offC.isValid()) |
12960 | add(1, Ap_params.offC, Ap_params.offC, A_params.offC); |
12961 | } else { |
12962 | auto inEffAp = state.effAp; |
12963 | if (state.effA == state.effAp) |
12964 | state.effAp = state.ra.alloc_sub(state.effA.getType()); |
12965 | eadd(1, state.effAp, inEffAp, offAp, strategy, state); |
12966 | } |
12967 | state.ra.safeRelease(offAp); |
12968 | } |
12969 | if (strategy.cooperativePF && strategy.prefetchB) { |
12970 | Subregister offBp; |
12971 | gemmCalcWorkshareBOffset(offBp, Bp_params.offR, Bp_params.offC, |
12972 | problem.B, strategy.B_prefetch, state.kb_prefetch, |
12973 | state.nb_prefetch, problem, strategy, state); |
12974 | if (strategy.B_prefetch.address2D) { |
12975 | if (B_params.offR.isValid()) |
12976 | add(1, Bp_params.offR, Bp_params.offR, B_params.offR); |
12977 | if (B_params.offC.isValid()) |
12978 | add(1, Bp_params.offC, Bp_params.offC, B_params.offC); |
12979 | } else { |
12980 | auto inEffBp = state.effBp; |
12981 | if (state.effB == state.effBp) |
12982 | state.effBp = state.ra.alloc_sub(state.effB.getType()); |
12983 | eadd(1, state.effBp, inEffBp, offBp, strategy, state); |
12984 | } |
12985 | state.ra.safeRelease(offBp); |
12986 | } |
12987 | |
12988 | // Prepare layouts and starting addresses for SLM copies and adjust problem. |
12989 | if (strategy.slmBuffers > 0) { |
12990 | int A_slmCP, B_slmCP; |
12991 | int A_tileR, A_tileC, B_tileR, B_tileC; |
12992 | std::tie(A_slmCP, B_slmCP) = targetSLMCrosspack(hw, problem, strategy); |
12993 | std::tie(A_tileR, A_tileC, B_tileR, B_tileC) |
12994 | = targetKernelTiling(hw, problem, strategy); |
12995 | auto opCount = outerProductCount(hw, problem, strategy); |
12996 | |
12997 | if (strategy.slmA) { |
12998 | coopSplit(true, state.ma_slm, state.ka_slm, unrollM, |
12999 | strategy.unrollKSLM, state.effCoopA, strategy.wg[LoopN], |
13000 | problem.A); |
13001 | |
13002 | if (state.ma_slm < unrollM) { |
13003 | remM_A = false; |
13004 | remK_A = remainderK && strategy.slmEarlyKMask; |
13005 | } |
13006 | if (strategy.slmATrans) { |
13007 | A_slmCP = state.ka_slm; |
13008 | if (strategy.ka_load % A_slmCP) |
13009 | throw std::runtime_error( |
13010 | "ka_load must be a multiple of ka_slm" ); |
13011 | } |
13012 | if ((state.ka_slm < A_slmCP) && (strategy.unrollKSLM != A_slmCP) |
13013 | && (A_tileC != A_slmCP)) |
13014 | throw std::runtime_error( |
13015 | "ka_slm must be a multiple of crosspack, or unrollKSLM " |
13016 | "= crosspack." ); |
13017 | |
13018 | // Layout in from memory... |
13019 | state.Ai = problem.A; |
13020 | state.Ai_strategy = strategy.A; |
13021 | if (state.Ai_strategy.dpasw) { |
13022 | state.Ai_strategy.dpasw = false; |
13023 | state.Ai_strategy.tileR = 0; |
13024 | } |
13025 | |
13026 | // ... layout out to SLM. |
13027 | state.Ao.layout = MatrixLayout::Pc; |
13028 | state.Ao.packSize = unrollM; |
13029 | state.Ao.crosspack = A_slmCP; |
13030 | state.Ao.setAlignment(state.Ao.packSize * Ta); |
13031 | state.Ao.tileR = A_tileR; |
13032 | state.Ao.tileC = (A_tileC || !A_tileR) |
13033 | ? A_tileC |
13034 | : std::max(opCount, strategy.ka_load); |
13035 | |
13036 | bool colMajorIn |
13037 | = isRegisterColMajor(Ta_ext, state.Ai, state.Ai_strategy); |
13038 | bool colMajorSLM = !isLargeCrosspack(Ta, A_slmCP); |
13039 | state.Ao_strategy.base = SLM; |
13040 | state.Ao_strategy.accessType = (colMajorIn == colMajorSLM) |
13041 | ? AccessType::Block |
13042 | : AccessType::Scattered; |
13043 | state.Ao_strategy.smode = ScatterSIMD::Default; |
13044 | |
13045 | if (state.Ai.layout == MatrixLayout::N |
13046 | && state.Ai_strategy.accessType == AccessType::Block2DVNNI |
13047 | && isLargeCrosspack(Ta, A_slmCP)) { |
13048 | state.Ao_strategy.accessType = AccessType::ChannelScattered; |
13049 | state.Ao_strategy.smode = ScatterSIMD::Narrow; |
13050 | } |
13051 | state.Ao_strategy.padded = true; |
13052 | state.Ao_strategy.atomic = false; |
13053 | state.Ao_strategy.address2D = false; |
13054 | state.Ao_strategy.newDP = (hw >= HW::XeHPG); |
13055 | state.Ao_strategy.cachingW = CacheSettingsLSC::Default; |
13056 | |
13057 | // Layout in from memory... |
13058 | if (!getRegLayout(Ta_ext, state.Ai_layout, state.ma_slm, |
13059 | state.ka_slm, remM_A, remK_A, false, true, 0, 0, |
13060 | state.Ai, state.Ai_strategy)) |
13061 | return false; |
13062 | |
13063 | // ... layout out to SLM... |
13064 | remM_A = remK_A = false; |
13065 | if (!getRegLayout(Ta, state.Ao_layout, state.ma_slm, state.ka_slm, |
13066 | remM_A, remK_A, true, true, 0, 0, state.Ao, |
13067 | state.Ao_strategy)) |
13068 | return false; |
13069 | |
13070 | // ... and layout back from SLM. |
13071 | problem.A = state.Ao; |
13072 | strategy.A.base = SLM; |
13073 | strategy.A.accessType = AccessType::Block; |
13074 | strategy.A.address2D = false; |
13075 | strategy.A.newDP = (hw >= HW::XeHPG); |
13076 | strategy.A.cachingR = CacheSettingsLSC::Default; |
13077 | Ta_load = Ta; |
13078 | state.aioShare = Ta.size() == Ta_ext.size() |
13079 | && Ta.components() == Ta_ext.components() |
13080 | && matchLayoutsBidirectional( |
13081 | Ta, state.Ai_layout, state.Ao_layout); |
13082 | |
13083 | // If we will add k-masking later, check if extra registers are needed. |
13084 | state.Ai_regCount = getRegCount(state.Ai_layout); |
13085 | if (!remK_A && remainderK && !state.Ai_strategy.address2D |
13086 | && !isRegisterColMajor( |
13087 | Ta_ext, state.Ai, state.Ai_strategy)) { |
13088 | std::vector<RegisterBlock> Ai_layoutKMasked; |
13089 | if (getRegLayout(Ta_ext, Ai_layoutKMasked, state.ma_slm, |
13090 | state.ka_slm, remM_A, true, false, true, 0, 0, |
13091 | state.Ai, state.Ai_strategy)) |
13092 | state.Ai_regCount = std::max( |
13093 | state.Ai_regCount, getRegCount(Ai_layoutKMasked)); |
13094 | } |
13095 | |
13096 | // Offset A addresses in and out. |
13097 | state.effAi = state.effA; |
13098 | state.effA = state.ra.alloc_sub<uint32_t>( |
13099 | getHint(HintType::LongTerm, strategy)); |
13100 | state.effAo = state.ra.alloc_sub<uint32_t>( |
13101 | getHint(HintType::LongTerm, strategy)); |
13102 | |
13103 | auto temp = state.ra.alloc_sub<uint32_t>( |
13104 | getHint(HintType::TempComp0, strategy)); |
13105 | Subregister temp2; |
13106 | |
13107 | uint32_t noff, noffTile, tileSplit = 1; |
13108 | |
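// noff is the element offset between successive threads' slices of the
// packed SLM buffer. If the work division cuts layout tiles into tileSplit
// pieces, starting offsets stop being linear in the local ID, and noffTile
// (the offset of one whole tile) is needed to reconstruct them.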
13109 | switch (state.effCoopA) { |
13110 | case CoopSplit::Linear: |
13111 | // FIXME: assumes compatible tiling between global and SLM layouts. |
13112 | noff = state.ma_slm * state.ka_slm; |
13113 | break; |
13114 | case CoopSplit::MN: |
13115 | noff = untile(Ta, state.Ao, 0, state.ma_slm, 0, |
13116 | state.Ao.packSize, strategy.unrollKSLM); |
13117 | if (state.ma_slm < state.Ao.tileR |
13118 | && state.Ao.tileR < state.Ao.packSize) { |
// m division splits tiles -- starting offsets are no longer a linear sequence.
13120 | if (state.Ao.tileR % state.ma_slm) stub(); |
13121 | tileSplit = state.Ao.tileR / state.ma_slm; |
13122 | noffTile = untile(Ta, state.Ao, 0, state.Ao.tileR, 0, |
13123 | state.Ao.packSize, strategy.unrollKSLM); |
13124 | } |
13125 | break; |
13126 | case CoopSplit::K: |
13127 | noff = untile(Ta, state.Ao, 0, 0, state.ka_slm, |
13128 | state.Ao.packSize, strategy.unrollKSLM); |
13129 | if (state.ka_slm < state.Ao.tileC |
13130 | && state.Ao.tileC < strategy.unrollKSLM) { |
// k division splits tiles -- starting offsets are no longer a linear sequence.
13132 | if (state.Ao.tileC % state.ka_slm) stub(); |
13133 | tileSplit = state.Ao.tileC / state.ka_slm; |
13134 | noffTile = untile(Ta, state.Ao, 0, 0, state.Ao.tileC, |
13135 | state.Ao.packSize, strategy.unrollKSLM); |
13136 | } |
13137 | break; |
13138 | default: stub(); |
13139 | } |
13140 | |
13141 | int32_t A_slmStride |
13142 | = strategy.slmABufBlockSize(problem) * strategy.slmBuffers; |
13143 | |
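// Compute this thread's offset into the SLM buffer. For tileSplit > 1 the
// desired offset is t*noffTile + (lid - t*tileSplit)*noff, where
// t = lid / tileSplit is the tile index; rearranged, this is
// t*(noffTile - tileSplit*noff) + lid*noff, which is exactly what the
// shr/mulConstant/emad sequence below computes (scaled by element size).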
13144 | if (tileSplit > 1) { |
13145 | if (!is_zero_or_pow2(tileSplit)) stub(); |
13146 | shr(1, temp, state.lidN, log2(tileSplit)); |
13147 | } |
13148 | gemmCalcWorkshareAOffset(temp2, Ai_params.offR, Ai_params.offC, |
13149 | state.Ai, state.Ai_strategy, state.ma_slm, state.ka_slm, |
13150 | problem, strategy, state); |
13151 | if (tileSplit > 1) { |
13152 | mulConstant(1, temp, temp, (noffTile - noff * tileSplit) * Ta); |
13153 | emad(1, temp, temp, state.lidN, noff * Ta, strategy, state); |
13154 | } else |
13155 | mulConstant(1, temp, state.lidN, noff * Ta); |
13156 | mulConstant(1, state.effA, state.lidM, A_slmStride); |
13157 | if (strategy.wg[LoopK] > 1) |
13158 | emad(1, state.effA, state.effA, state.lidK, |
13159 | A_slmStride * strategy.wg[LoopM], strategy, state); |
13160 | if (state.Ai_strategy.address2D) { |
13161 | if (Ai_params.offR != A_params.offR && A_params.offR.isValid()) |
13162 | add(1, Ai_params.offR, Ai_params.offR, A_params.offR); |
13163 | if (Ai_params.offC != A_params.offC && A_params.offC.isValid()) |
13164 | add(1, Ai_params.offC, Ai_params.offC, A_params.offC); |
13165 | } else |
13166 | eadd(1, state.effAi, state.effAi, temp2, strategy, state); |
13167 | add(1, state.effAo, state.effA, temp); |
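// For backward (k-decrementing) problems, begin reads at the last
// ka_load-wide chunk of the unrollKSLM-deep SLM buffer.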
13168 | if (problem.backward()) |
13169 | add(1, state.effA, state.effA, |
13170 | (strategy.unrollKSLM - strategy.ka_load) * unrollM |
13171 | * Ta); |
13172 | |
13173 | state.ra.safeRelease(temp2); |
13174 | state.ra.safeRelease(temp); |
13175 | } |
13176 | if (strategy.slmB) { |
13177 | coopSplit(false, state.kb_slm, state.nb_slm, strategy.unrollKSLM, |
13178 | unrollN, state.effCoopB, strategy.wg[LoopM], problem.B); |
13179 | |
13180 | if (state.nb_slm < unrollN) { |
13181 | remN_B = false; |
13182 | remK_B = remainderK && strategy.slmEarlyKMask; |
13183 | } |
13184 | if (strategy.slmBTrans) { |
13185 | B_slmCP = state.kb_slm; |
13186 | if (strategy.kb_load % B_slmCP) |
throw std::runtime_error(
"kb_load must be a multiple of kb_slm");
13189 | } |
13190 | if ((state.kb_slm < B_slmCP) && (strategy.unrollKSLM != B_slmCP) |
13191 | && (B_tileR != B_slmCP)) |
throw std::runtime_error(
"kb_slm must be a multiple of crosspack, or unrollKSLM "
"= crosspack.");
13195 | |
13196 | // Layout in from memory... |
13197 | state.Bi = problem.B; |
13198 | state.Bi_strategy = strategy.B; |
13199 | if (state.Bi_strategy.dpasw) { |
13200 | state.Bi_strategy.dpasw = false; |
13201 | state.Bi_strategy.tileC = 0; |
13202 | } |
13203 | |
13204 | // ... layout out to SLM. |
13205 | state.Bo.layout = MatrixLayout::Pr; |
13206 | state.Bo.packSize = unrollN; |
13207 | state.Bo.crosspack = B_slmCP; |
13208 | state.Bo.setAlignment(state.Bo.packSize * Tb); |
13209 | state.Bo.tileR = (B_tileR || !B_tileC) |
13210 | ? B_tileR |
13211 | : std::max(opCount, strategy.kb_load); |
13212 | state.Bo.tileC = B_tileC; |
13213 | |
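// As on the A side: block stores to SLM if memory and SLM major orders
// agree, scattered stores to transpose otherwise.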
13214 | bool colMajorIn |
13215 | = isRegisterColMajor(Tb_ext, state.Bi, state.Bi_strategy); |
13216 | bool colMajorSLM = isLargeCrosspack(Tb, B_slmCP); |
13217 | state.Bo_strategy.base = SLM; |
13218 | state.Bo_strategy.accessType = (colMajorIn == colMajorSLM) |
13219 | ? AccessType::Block |
13220 | : AccessType::Scattered; |
13221 | state.Bo_strategy.smode = ScatterSIMD::Default; |
13222 | |
13223 | if (state.Bi.layout == MatrixLayout::T |
13224 | && state.Bi_strategy.accessType == AccessType::Block2DVNNI |
13225 | && isLargeCrosspack(Tb, B_slmCP)) { |
13226 | state.Bo_strategy.accessType = AccessType::ChannelScattered; |
13227 | state.Bo_strategy.smode = ScatterSIMD::Narrow; |
13228 | } |
13229 | state.Bo_strategy.padded = true; |
13230 | state.Bo_strategy.atomic = false; |
13231 | state.Bo_strategy.address2D = false; |
13232 | state.Bo_strategy.newDP = (hw >= HW::XeHPG); |
13233 | state.Bo_strategy.cachingW = CacheSettingsLSC::Default; |
13234 | |
13235 | // Layout in from memory... |
13236 | if (!getRegLayout(Tb_ext, state.Bi_layout, state.kb_slm, |
13237 | state.nb_slm, remK_B, remN_B, false, true, 0, 0, |
13238 | state.Bi, state.Bi_strategy)) |
13239 | return false; |
13240 | |
13241 | // ... layout out to SLM... |
13242 | remK_B = remN_B = false; |
13243 | if (!getRegLayout(Tb, state.Bo_layout, state.kb_slm, state.nb_slm, |
13244 | remK_B, remN_B, true, true, 0, 0, state.Bo, |
13245 | state.Bo_strategy)) |
13246 | return false; |
13247 | |
13248 | // ... and layout back from SLM. |
13249 | problem.B = state.Bo; |
13250 | strategy.B.base = SLM; |
13251 | strategy.B.accessType = AccessType::Block; |
13252 | strategy.B.address2D = false; |
13253 | strategy.B.newDP = (hw >= HW::XeHPG); |
13254 | strategy.B.cachingR = CacheSettingsLSC::Default; |
13255 | Tb_load = Tb; |
13256 | state.bioShare = Tb.size() == Tb_ext.size() |
13257 | && Tb.components() == Tb_ext.components() |
13258 | && matchLayoutsBidirectional( |
13259 | Tb, state.Bi_layout, state.Bo_layout); |
13260 | |
// If we add k-masking later, check whether extra registers are needed.
13262 | state.Bi_regCount = getRegCount(state.Bi_layout); |
13263 | if (!remK_B && remainderK && !state.Bi_strategy.address2D |
13264 | && isRegisterColMajor( |
13265 | Tb_ext, state.Bi, state.Bi_strategy)) { |
13266 | std::vector<RegisterBlock> Bi_layoutKMasked; |
13267 | if (getRegLayout(Tb_ext, Bi_layoutKMasked, state.kb_slm, |
13268 | state.nb_slm, true, remN_B, false, true, 0, 0, |
13269 | state.Bi, state.Bi_strategy)) |
13270 | state.Bi_regCount = std::max( |
13271 | state.Bi_regCount, getRegCount(Bi_layoutKMasked)); |
13272 | } |
13273 | |
13274 | // Offset B addresses in and out. |
13275 | state.effBi = state.effB; |
13276 | state.effB = state.ra.alloc_sub<uint32_t>( |
13277 | getHint(HintType::LongTerm, strategy)); |
13278 | state.effBo = state.ra.alloc_sub<uint32_t>( |
13279 | getHint(HintType::LongTerm, strategy)); |
13280 | |
13281 | auto temp = state.ra.alloc_sub<uint32_t>( |
13282 | getHint(HintType::TempComp0, strategy)); |
13283 | Subregister temp2; |
13284 | |
13285 | uint32_t moff, moffTile, tileSplit = 1; |
13286 | |
13287 | switch (state.effCoopB) { |
13288 | case CoopSplit::Linear: |
13289 | moff = state.kb_slm * state.nb_slm; |
13290 | break; |
13291 | case CoopSplit::MN: |
13292 | moff = untile(Tb, state.Bo, 0, 0, state.nb_slm, |
13293 | strategy.unrollKSLM, state.Bo.packSize); |
13294 | if (state.nb_slm < state.Bo.tileC |
13295 | && state.Bo.tileC < state.Bo.packSize) { |
13296 | if (state.Bo.tileC % state.nb_slm) stub(); |
13297 | tileSplit = state.Bo.tileC / state.nb_slm; |
13298 | moffTile = untile(Tb, state.Bo, 0, 0, state.Bo.tileC, |
13299 | strategy.unrollKSLM, state.Bo.packSize); |
13300 | } |
13301 | break; |
13302 | case CoopSplit::K: |
13303 | moff = untile(Tb, state.Bo, 0, state.kb_slm, 0, |
13304 | strategy.unrollKSLM, state.Bo.packSize); |
13305 | if (state.kb_slm < state.Bo.tileR) { |
13306 | if (state.Bo.tileR % state.kb_slm) stub(); |
13307 | tileSplit = state.Bo.tileR / state.kb_slm; |
13308 | moffTile = untile(Tb, state.Bo, 0, state.Bo.tileR, 0, |
13309 | strategy.unrollKSLM, state.Bo.packSize); |
13310 | } |
13311 | break; |
13312 | default: stub(); |
13313 | } |
13314 | |
13315 | int32_t B_slmStride |
13316 | = strategy.slmBBufBlockSize(problem) * strategy.slmBuffers; |
13317 | |
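// Same tile-aware offset decomposition as on the A side, with lidM taking
// the role of lidN.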
13318 | if (tileSplit > 1) { |
13319 | if (!is_zero_or_pow2(tileSplit)) stub(); |
13320 | shr(1, temp, state.lidM, log2(tileSplit)); |
13321 | } |
13322 | gemmCalcWorkshareBOffset(temp2, Bi_params.offR, Bi_params.offC, |
13323 | state.Bi, state.Bi_strategy, state.kb_slm, state.nb_slm, |
13324 | problem, strategy, state); |
13325 | if (tileSplit > 1) { |
13326 | mulConstant(1, temp, temp, (moffTile - moff * tileSplit) * Tb); |
13327 | emad(1, temp, temp, state.lidM, moff * Tb, strategy, state); |
13328 | } else |
13329 | mulConstant(1, temp, state.lidM, moff * Tb); |
13330 | mulConstant(1, state.effB, state.lidN, B_slmStride); |
13331 | if (strategy.wg[LoopK] > 1) |
13332 | emad(1, state.effB, state.effB, state.lidK, |
13333 | B_slmStride * strategy.wg[LoopN], strategy, state); |
13334 | if (state.Bi_strategy.address2D) { |
13335 | if (Bi_params.offR != B_params.offR && B_params.offR.isValid()) |
13336 | add(1, Bi_params.offR, Bi_params.offR, B_params.offR); |
13337 | if (Bi_params.offC != B_params.offC && B_params.offC.isValid()) |
13338 | add(1, Bi_params.offC, Bi_params.offC, B_params.offC); |
13339 | } else |
13340 | eadd(1, state.effBi, state.effBi, temp2, strategy, state); |
13341 | if (strategy.slmABufSize(problem) > 0) |
13342 | add(1, state.effB, state.effB, strategy.slmABufSize(problem)); |
13343 | add(1, state.effBo, state.effB, temp); |
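// Likewise, start B reads at the last kb_load-wide chunk for backward
// problems.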
13344 | if (problem.backward()) |
13345 | add(1, state.effB, state.effB, |
13346 | (strategy.unrollKSLM - strategy.kb_load) * unrollN |
13347 | * Tb); |
13348 | |
13349 | state.ra.safeRelease(temp2); |
13350 | state.ra.safeRelease(temp); |
13351 | } |
13352 | } |
13353 | |
13354 | // Starting address adjustments for DPASW. |
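// With dpasw, each fused thread pair cooperates on a systolic block, each
// thread supplying half of it; accordingly, the odd thread of the pair
// (lid & 1) starts tileC columns of B (column-major C) or tileR rows of A
// (row-major C) further in.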
13355 | bool cColMajor = isRegisterColMajor(Tc_ext, problem.C, strategy.C); |
13356 | if (strategy.dpasw) { |
13357 | if (cColMajor) { |
13358 | int t = strategy.B.tileC; |
13359 | and_(1 | nz | state.flagAP, null.uw(), state.lidM, 1); |
13360 | switch (problem.B.layout) { |
13361 | case MatrixLayout::N: |
13362 | emad(1 | state.flagAP, state.effB, state.effB, |
13363 | state.inputs.ldb, Immediate::w(t), strategy, state); |
13364 | break; |
13365 | case MatrixLayout::T: |
13366 | eadd(1 | state.flagAP, state.effB, state.effB, t * Tb_load, |
13367 | strategy, state); |
13368 | break; |
13369 | case MatrixLayout::Pr: |
13370 | eadd(1 | state.flagAP, state.effB, state.effB, |
13371 | untile(Tb_load, problem.B, 0, 0, t, unrollK, |
13372 | unrollN) |
13373 | * Tb_load, |
13374 | strategy, state); |
13375 | break; |
13376 | default: stub(); |
13377 | } |
13378 | } else { |
13379 | int t = strategy.A.tileR; |
13380 | and_(1 | nz | state.flagAP, null.uw(), state.lidN, 1); |
13381 | switch (problem.A.layout) { |
13382 | case MatrixLayout::T: |
13383 | emad(1 | state.flagAP, state.effA, state.effA, |
13384 | state.inputs.lda, Immediate::w(t), strategy, state); |
13385 | break; |
13386 | case MatrixLayout::N: |
13387 | eadd(1 | state.flagAP, state.effA, state.effA, t * Ta_load, |
13388 | strategy, state); |
13389 | break; |
13390 | case MatrixLayout::Pc: |
13391 | eadd(1 | state.flagAP, state.effA, state.effA, |
13392 | untile(Ta_load, problem.A, 0, t, 0, unrollM, |
13393 | unrollK) |
13394 | * Ta_load, |
13395 | strategy, state); |
13396 | break; |
13397 | default: stub(); |
13398 | } |
13399 | } |
13400 | } |
13401 | |
13402 | // Get register layouts for A/B/C. |
13403 | if (!getRegLayout(Ta_load, state.A_layout, unrollM, strategy.ka_load, |
13404 | remM_A, remK_A, false, true, 0, 0, problem.A, strategy.A)) |
13405 | return false; |
13406 | if (!getRegLayout(Tb_load, state.B_layout, strategy.kb_load, unrollN, |
13407 | remK_B, remN_B, false, true, 0, 0, problem.B, strategy.B)) |
13408 | return false; |
13409 | |
13410 | if (state.copyC) { |
13411 | makeUnbackedRegLayout(Tc, state.C_layout, unrollM, unrollN, cColMajor, |
13412 | 1, strategy.C.tileR, strategy.C.tileC, true); |
13413 | if (!getRegLayout(Tc_ext, state.C_layoutExt, unrollM, unrollN, remM_Ce, |
13414 | remN_Ce, true, false, 0, 0, problem.C, state.Cext_strategy)) |
13415 | return false; |
13416 | } else { |
13417 | if (!getRegLayout(Tc, state.C_layout, unrollM, unrollN, remM_C, remN_C, |
13418 | true, false, 0, 0, problem.C, strategy.C)) |
13419 | return false; |
13420 | } |
13421 | |
13422 | if (!strategy.altCRemainder && (remM_Ce || remN_Ce)) { |
// Try preparing the C layout without masking (may reduce memory accesses).
// Only use it if it's compatible with the masked layout and saves on send instructions.
13425 | auto &layoutExt = state.copyC ? state.C_layoutExt : state.C_layout; |
13426 | (void)getRegLayout(Tc_ext, state.C_layoutExtUnmasked, unrollM, unrollN, |
13427 | false, false, true, false, 0, 0, problem.C, |
13428 | state.Cext_strategy); |
13429 | if (state.C_layoutExtUnmasked.size() == layoutExt.size() |
13430 | || (!state.copyC |
13431 | && !matchLayouts( |
13432 | Tc, layoutExt, state.C_layoutExtUnmasked))) |
13433 | state.C_layoutExtUnmasked.clear(); |
13434 | } |
13435 | |
13436 | if (!state.copyC) state.C_layoutExt = state.C_layout; |
13437 | |
13438 | if (hasRowFragmenting(state.A_layout) |
13439 | || hasColumnFragmenting(state.B_layout)) { |
status << "Can't fragment A or B.\n";
13441 | return false; |
13442 | } |
13443 | |
13444 | // Prepare to repack A/B if needed. |
13445 | int crosspackA, crosspackB, tileM_A, tileK_A, tileK_B, tileN_B; |
13446 | std::tie(crosspackA, crosspackB) |
13447 | = targetKernelCrosspack(hw, problem, strategy); |
13448 | std::tie(tileM_A, tileK_A, tileK_B, tileN_B) |
13449 | = targetKernelTiling(hw, problem, strategy); |
13450 | |
13451 | state.repackA |
13452 | |= (crosspackA && !hasFullCrosspack(state.A_layout, crosspackA)) |
13453 | || !hasTiling(state.A_layout, tileM_A, tileK_A); |
13454 | state.repackB |
13455 | |= (crosspackB && !hasFullCrosspack(state.B_layout, crosspackB)) |
13456 | || !hasTiling(state.B_layout, tileK_B, tileN_B); |
13457 | |
13458 | state.repackA |= (Ta.size() != Ta_ext.size() |
13459 | || Ta.components() != Ta_ext.components()) |
13460 | && !strategy.slmA; |
13461 | state.repackB |= (Tb.size() != Tb_ext.size() |
13462 | || Tb.components() != Tb_ext.components()) |
13463 | && !strategy.slmB; |
13464 | |
13465 | if (crosspackA == 0) crosspackA = 1; |
13466 | if (crosspackB == 0) crosspackB = 1; |
13467 | |
13468 | bool splitA = false, splitB = false; |
13469 | |
13470 | if (state.repackA) |
13471 | makeUnbackedRegLayout(Ta, state.Ar_layout, unrollM, strategy.ka_load, |
13472 | isLayoutColMajor(state.A_layout), crosspackA, tileM_A, tileK_A, |
13473 | true, splitA); |
13474 | if (state.repackB) |
13475 | makeUnbackedRegLayout(Tb, state.Br_layout, strategy.kb_load, unrollN, |
13476 | isLayoutColMajor(state.B_layout), crosspackB, tileK_B, tileN_B, |
13477 | true, splitB); |
13478 | |
13479 | // Prepare layouts for row/column sum calculation. |
13480 | bool globalCM = isLayoutColMajor(state.C_layout); |
13481 | if (problem.needsASums()) { |
13482 | state.systolicSumA = strategy.systolic && globalCM; |
13483 | state.slmASums = strategy.slmA && !state.systolicSumA; |
13484 | |
13485 | auto As_srcLayout = state.slmASums |
13486 | ? state.Ao_layout |
13487 | : state.repackA ? state.Ar_layout : state.A_layout; |
13488 | makeSumLayout( |
13489 | false, Ta, As_srcLayout, Tc, state.As_layout, strategy, state); |
13490 | if (state.systolicSumA) |
13491 | setupTeardownAccumulateSumSystolic( |
13492 | true, Tb, problem, strategy, state); |
13493 | } |
13494 | if (problem.needsBSums()) { |
13495 | state.systolicSumB = strategy.systolic && !globalCM; |
13496 | state.slmBSums = strategy.slmB && !state.systolicSumB; |
13497 | |
13498 | auto Bs_srcLayout = state.slmBSums |
13499 | ? state.Bo_layout |
13500 | : state.repackB ? state.Br_layout : state.B_layout; |
13501 | makeSumLayout( |
13502 | true, Tb, Bs_srcLayout, Tc, state.Bs_layout, strategy, state); |
13503 | if (state.systolicSumB) |
13504 | setupTeardownAccumulateSumSystolic( |
13505 | true, Ta, problem, strategy, state); |
13506 | } |
13507 | |
13508 | // Round up needed A/B flag registers; hold off on C. |
13509 | // Try first without virtual flags and retry if needed. |
13510 | // m/n cooperative SLM copies use k masking, so skip those masks for now. |
13511 | auto &masks = state.AB_masks; |
13512 | |
13513 | auto assignAllMasks = [&]() { |
13514 | return assignMasks(state.A_layout, LoopM, LoopK, masks, strategy, state) |
13515 | && assignMasks( |
13516 | state.Ap_layout, LoopM, LoopK, masks, strategy, state) |
13517 | && assignMasks( |
13518 | state.B_layout, LoopK, LoopN, masks, strategy, state) |
13519 | && assignMasks( |
13520 | state.Bp_layout, LoopK, LoopN, masks, strategy, state) |
13521 | && ((state.effCoopA != CoopSplit::K) |
13522 | || assignMasks(state.Ai_layout, LoopM, LoopK, masks, |
13523 | strategy, state)) |
13524 | && ((state.effCoopB != CoopSplit::K) |
13525 | || assignMasks(state.Bi_layout, LoopK, LoopN, masks, |
13526 | strategy, state)); |
13527 | }; |
13528 | |
13529 | state.lateKLoopCheck = false; |
13530 | bool success = assignAllMasks(); |
13531 | if (!success && state.vflagStorage.isInvalid()) { |
13532 | status << "Retrying with virtual flags." << status_stream::endl; |
13533 | allocVFlagStorage(strategy, state); |
13534 | success = assignAllMasks(); |
13535 | state.lateKLoopCheck = true; |
13536 | } |
13537 | |
13538 | if (!success) return false; |
13539 | |
13540 | loadMasks(masks, state.remainders, strategy, state); |
13541 | |
13542 | // Temporary: move add64 out of the way (later: general cramming). |
13543 | if (state.add64.isValid()) { |
13544 | auto oldAdd64 = state.add64; |
13545 | state.ra.safeRelease(state.add64); |
13546 | state.add64 = state.ra.alloc_sub<uint32_t>(); |
13547 | if (oldAdd64 != state.add64) mov(1, state.add64, oldAdd64); |
13548 | } |
13549 | |
13550 | // Allocate data registers. |
13551 | gemmAllocRegs(problem, strategy, state); |
13552 | gemmAllocAoBoRegs(strategy, state); |
13553 | |
13554 | // Allocate address registers for A/B loads. We don't need C addresses yet. |
13555 | allocAddrRegs(state.A_addrs, state.A_layout, problem.A, strategy.A, state); |
13556 | allocAddrRegs(state.B_addrs, state.B_layout, problem.B, strategy.B, state); |
13557 | allocAddrRegs(state.Ap_addrs, state.Ap_layout, globalA, strategy.A_prefetch, |
13558 | state); |
13559 | allocAddrRegs(state.Bp_addrs, state.Bp_layout, globalB, strategy.B_prefetch, |
13560 | state); |
13561 | allocAddrRegs(state.Ai_addrs, state.Ai_layout, state.Ai, state.Ai_strategy, |
13562 | state); |
13563 | allocAddrRegs(state.Bi_addrs, state.Bi_layout, state.Bi, state.Bi_strategy, |
13564 | state); |
13565 | allocAddrRegs(state.Ao_addrs, state.Ao_layout, state.Ao, state.Ao_strategy, |
13566 | state); |
13567 | allocAddrRegs(state.Bo_addrs, state.Bo_layout, state.Bo, state.Bo_strategy, |
13568 | state); |
13569 | |
13570 | // Free up some C registers temporarily for use in address calculations. |
13571 | releaseRanges(state.C_regs, state); |
13572 | |
13573 | // Set up address registers. |
13574 | gemmCacheLDABMultiples(problem, strategy, state); |
13575 | setupAddr(Ta_ext, state.Ap_addrs, state.effAp, state.Ap_layout, |
13576 | state.inputs.lda, globalA, strategy.A_prefetch, strategy, state, |
13577 | Ap_params, state.ldaMultiples); |
13578 | setupAddr(Tb_ext, state.Bp_addrs, state.effBp, state.Bp_layout, |
13579 | state.inputs.ldb, globalB, strategy.B_prefetch, strategy, state, |
13580 | Bp_params, state.ldbMultiples); |
13581 | setupAddr(Ta_ext, state.Ai_addrs, state.effAi, state.Ai_layout, |
13582 | state.inputs.lda, state.Ai, state.Ai_strategy, strategy, state, |
13583 | Ai_params, state.ldaMultiples); |
13584 | setupAddr(Tb_ext, state.Bi_addrs, state.effBi, state.Bi_layout, |
13585 | state.inputs.ldb, state.Bi, state.Bi_strategy, strategy, state, |
13586 | Bi_params, state.ldbMultiples); |
13587 | setupAddr(Ta, state.Ao_addrs, state.effAo, state.Ao_layout, Subregister(), |
13588 | state.Ao, state.Ao_strategy, strategy, state); |
13589 | setupAddr(Tb, state.Bo_addrs, state.effBo, state.Bo_layout, Subregister(), |
13590 | state.Bo, state.Bo_strategy, strategy, state); |
13591 | setupAddr(Ta_load, state.A_addrs, state.effA, state.A_layout, |
13592 | state.inputs.lda, problem.A, strategy.A, strategy, state, A_params, |
13593 | state.ldaMultiples); |
13594 | setupAddr(Tb_load, state.B_addrs, state.effB, state.B_layout, |
13595 | state.inputs.ldb, problem.B, strategy.B, strategy, state, B_params, |
13596 | state.ldbMultiples); |
13597 | |
13598 | // Free unneeded registers after address setup. |
13599 | if (!needsKLoopReset(problem)) { |
13600 | if (!state.isNested) { |
13601 | state.ra.safeRelease(state.h0); |
13602 | if (strategy.A.address2D |
13603 | && (!strategy.prefetchA || strategy.A_prefetch.address2D)) |
13604 | state.ra.safeRelease(state.inputs.lda); |
13605 | if (strategy.B.address2D |
13606 | && (!strategy.prefetchB || strategy.B_prefetch.address2D)) |
13607 | state.ra.safeRelease(state.inputs.ldb); |
13608 | if (!problem.hasBinaryPostOp()) |
13609 | if (!strategy.C.address2D |
13610 | && (!strategy.prefetchC |
13611 | || !strategy.C_prefetch.address2D)) { |
13612 | state.ra.safeRelease(state.i0); |
13613 | state.ra.safeRelease(state.j0); |
13614 | } |
13615 | } |
13616 | if (state.Ai_strategy.address2D) { |
13617 | if (Ai_params.offR != A_params.offR) |
13618 | state.ra.safeRelease(Ai_params.offR); |
13619 | if (Ai_params.offC != A_params.offC) |
13620 | state.ra.safeRelease(Ai_params.offC); |
13621 | } |
13622 | if (state.Bi_strategy.address2D) { |
13623 | if (Bi_params.offR != B_params.offR) |
13624 | state.ra.safeRelease(Bi_params.offR); |
13625 | if (Bi_params.offC != B_params.offC) |
13626 | state.ra.safeRelease(Bi_params.offC); |
13627 | } |
13628 | if (strategy.A_prefetch.address2D) { |
13629 | if (Ap_params.offR != A_params.offR) |
13630 | state.ra.safeRelease(Ap_params.offR); |
13631 | if (Ap_params.offC != A_params.offC) |
13632 | state.ra.safeRelease(Ap_params.offC); |
13633 | } |
13634 | if (strategy.B_prefetch.address2D) { |
13635 | if (Bp_params.offR != B_params.offR) |
13636 | state.ra.safeRelease(Bp_params.offR); |
13637 | if (Bp_params.offC != B_params.offC) |
13638 | state.ra.safeRelease(Bp_params.offC); |
13639 | } |
13640 | |
13641 | if (!one_of(state.effAp, state.effA, state.effAi)) |
13642 | state.ra.safeRelease(state.effAp); |
13643 | if (!one_of(state.effBp, state.effB, state.effBi)) |
13644 | state.ra.safeRelease(state.effBp); |
13645 | } |
13646 | |
13647 | releaseLDMultiples(state.ldaMultiples, state); |
13648 | releaseLDMultiples(state.ldbMultiples, state); |
13649 | releaseIndexVec(state); |
13650 | |
13651 | reclaimRanges(state.C_regs, state); |
13652 | |
13653 | // Allocate tokens. |
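// (SBID tokens provide out-of-order software scoreboarding on Gen12+. If
// there aren't enough for all in-flight accesses, run the k loop without
// them.)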
13654 | success = true; |
13655 | for (int q = 0; q < strategy.A_copies; q++) |
13656 | success &= allocateTokens(state.A_layout, state.A_regs[q], state); |
13657 | for (int q = 0; q < strategy.B_copies; q++) |
13658 | success &= allocateTokens(state.B_layout, state.B_regs[q], state); |
13659 | for (int q = 0; q < strategy.slmCopies; q++) { |
13660 | if (strategy.slmA) |
13661 | success &= allocateTokens(state.Ai_layout, state.Ai_regs[q], state); |
13662 | if (strategy.slmB) |
13663 | success &= allocateTokens(state.Bi_layout, state.Bi_regs[q], state); |
13664 | } |
13665 | if (strategy.slmA && !state.aioShare) |
13666 | success &= allocateTokens(state.Ao_layout, state.Ao_regs, state); |
13667 | if (strategy.slmB && !state.bioShare) |
13668 | success &= allocateTokens(state.Bo_layout, state.Bo_regs, state); |
13669 | success &= allocateTokens( |
13670 | state.Ap_layout, state.Ap_regs, state, state.Ap_addrs); |
13671 | success &= allocateTokens( |
13672 | state.Bp_layout, state.Bp_regs, state, state.Bp_addrs); |
13673 | if (!success) { |
13674 | if (hw >= HW::Gen12LP) |
13675 | status << "Not enough tokens for k loop." << status_stream::endl; |
13676 | clearTokenAllocations(hw, state); |
13677 | } |
13678 | |
13679 | // Load C now if configured. |
13680 | // - temporarily free A/B data regs to use as C headers |
13681 | // - do beta scaling |
13682 | if (cLoadAhead) { |
13683 | if (problem.checkBeta0 && !problem.beta_real.fixed()) stub(); |
13684 | if (state.C_accCount > 0) stub(); |
13685 | if (strategy.kParallelLocal) stub(); |
13686 | |
13687 | releaseRanges(state.A_regs, state); |
13688 | releaseRanges(state.B_regs, state); |
13689 | if (!state.Ar_regs.empty()) releaseRanges(state.Ar_regs, state); |
13690 | if (!state.Br_regs.empty()) releaseRanges(state.Br_regs, state); |
13691 | |
13692 | status << "Loading C" << status_stream::endl; |
13693 | gemmAccessC(COperation::Load, problem, strategy, state); |
13694 | |
13695 | gemmBetaScale(problem, strategy, state); |
13696 | if (!state.Br_regs.empty()) reclaimRanges(state.Br_regs, state); |
13697 | if (!state.Ar_regs.empty()) reclaimRanges(state.Ar_regs, state); |
13698 | reclaimRanges(state.B_regs, state); |
13699 | reclaimRanges(state.A_regs, state); |
13700 | } |
13701 | |
13702 | for (int q = 0; q < state.C_count; q++) |
13703 | releaseLDMultiples(state.ldcMultiples[q], state); |
13704 | releaseIndexVec(state); |
13705 | |
13706 | // Release 64-bit emulation registers as they aren't needed in the inner loop. |
13707 | // Could also move r0 to acc here. |
13708 | if (state.emulate.temp[0].isValid()) { |
13709 | for (int q = 0; q < 2; q++) { |
13710 | state.emulate64TempSave[q] = state.emulate.temp[q]; |
13711 | state.ra.safeRelease(state.emulate.temp[q]); |
13712 | } |
13713 | if (GRF::bytes(hw) == 64) { |
13714 | // Need a whole flag register to do emulated SIMD16 arithmetic. |
13715 | state.emulate.flag = state.raVFlag.alloc(); |
13716 | state.emulate.flagOffset = 0; |
13717 | } else { |
13718 | state.emulate.flag = state.flagAP; |
13719 | state.emulate.flagOffset = 8; |
13720 | state.lateKLoopCheck = false; |
13721 | } |
13722 | } |
13723 | |
13724 | return true; |
13725 | } |
13726 | |
13727 | template <HW hw> |
13728 | void gemm_kernel_generator_t<hw>::gemmAccumulateCTeardown( |
13729 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
13730 | // We're done with A and B. Free their address, data, and flag registers. |
// Also done with the loop counter.
13732 | safeReleaseMaskAssignments(state.AB_masks, state); |
13733 | safeReleaseRanges(state.A_addrs, state); |
13734 | safeReleaseRanges(state.B_addrs, state); |
13735 | safeReleaseRanges(state.Ai_addrs, state); |
13736 | safeReleaseRanges(state.Bi_addrs, state); |
13737 | safeReleaseRanges(state.Ao_addrs, state); |
13738 | safeReleaseRanges(state.Bo_addrs, state); |
13739 | safeReleaseRanges(state.Ap_addrs, state); |
13740 | safeReleaseRanges(state.Bp_addrs, state); |
13741 | |
13742 | safeReleaseRanges(state.A_regs, state); |
13743 | safeReleaseRanges(state.Ar_regs, state); |
13744 | safeReleaseRanges(state.Ai_regs, state); |
13745 | safeReleaseRanges(state.Ao_regs, state); |
13746 | safeReleaseRanges(state.Ap_regs, state); |
13747 | safeReleaseRanges(state.B_regs, state); |
13748 | safeReleaseRanges(state.Br_regs, state); |
13749 | safeReleaseRanges(state.Bi_regs, state); |
13750 | safeReleaseRanges(state.Bo_regs, state); |
13751 | safeReleaseRanges(state.Bp_regs, state); |
13752 | state.ra.safeRelease(state.broadcast_regs); |
13753 | safeReleaseRanges(state.tempMul_regs, state); |
13754 | clearTokenAllocations(hw, state); |
13755 | |
13756 | state.A_layout.clear(); |
13757 | state.B_layout.clear(); |
13758 | state.Ai_layout.clear(); |
13759 | state.Bi_layout.clear(); |
13760 | state.Ao_layout.clear(); |
13761 | state.Bo_layout.clear(); |
13762 | state.Ar_layout.clear(); |
13763 | state.Br_layout.clear(); |
13764 | state.Ap_layout.clear(); |
13765 | state.Bp_layout.clear(); |
13766 | state.Cp_layout.clear(); |
13767 | |
13768 | if (state.systolicSumA || state.systolicSumB) |
13769 | setupTeardownAccumulateSumSystolic( |
13770 | false, Type::invalid, problem, strategy, state); |
13771 | |
13772 | // Restore effA/B if needed. |
13773 | bool restoreEffAB = false; |
13774 | |
13775 | restoreEffAB |= (problem.abOffset == ABOffset::Load); |
13776 | |
13777 | if (restoreEffAB) { |
13778 | Subregister aiOff, biOff; |
13779 | Subregister aiOffR, biOffR; |
13780 | Subregister aiOffC, biOffC; |
13781 | if (strategy.slmA) |
13782 | gemmCalcWorkshareAOffset(aiOff, aiOffR, aiOffC, state.Ai, |
13783 | state.Ai_strategy, state.ma_slm, state.ka_slm, problem, |
13784 | strategy, state); |
13785 | if (strategy.slmB) |
13786 | gemmCalcWorkshareBOffset(biOff, biOffR, biOffC, state.Bi, |
13787 | state.Bi_strategy, state.kb_slm, state.nb_slm, problem, |
13788 | strategy, state); |
13789 | if (strategy.slmA) |
13790 | eadd(1, state.effAi, state.effAi, -aiOff, strategy, state); |
13791 | if (strategy.slmB) |
13792 | eadd(1, state.effBi, state.effBi, -biOff, strategy, state); |
13793 | |
13794 | state.ra.safeRelease(aiOff); |
13795 | state.ra.safeRelease(biOff); |
13796 | state.ra.safeRelease(aiOffR); |
13797 | state.ra.safeRelease(biOffR); |
13798 | state.ra.safeRelease(aiOffC); |
13799 | state.ra.safeRelease(biOffC); |
13800 | } |
13801 | |
13802 | // Restore A/B addresses and strategies that were modified by SLM copies. |
13803 | if (strategy.slmA) { |
13804 | state.ra.safeRelease(state.effA); |
13805 | state.ra.safeRelease(state.effAo); |
13806 | state.effA = state.effAi; |
13807 | state.effAi = invalid; |
13808 | state.Ta_load = problem.Ta_ext; |
13809 | problem.A = state.Ai; |
13810 | strategy.A = state.Ai_strategy; |
13811 | } |
13812 | if (strategy.slmB) { |
13813 | state.ra.safeRelease(state.effB); |
13814 | state.ra.safeRelease(state.effBo); |
13815 | state.effB = state.effBi; |
13816 | state.effBi = invalid; |
13817 | state.Tb_load = problem.Tb_ext; |
13818 | problem.B = state.Bi; |
13819 | strategy.B = state.Bi_strategy; |
13820 | } |
13821 | |
13822 | // Put accumulators with the rest of C. |
13823 | if (state.C_accCount > 0) { |
13824 | // Reclaim the bottom registers of C. |
13825 | reclaimRanges(state.C_regs[0], state); |
13826 | |
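// Partial C sums kept in hardware accumulators during the inner loop are
// copied back into the bottom of C_regs, two GRFs per mov.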
13827 | auto e = elementsPerGRF<uint32_t>(hw); |
13828 | for (int i = 0; i < state.C_accCount; i += 2) |
13829 | mov<uint32_t>(2 * e, state.C_regs[0][i], AccumulatorRegister(i)); |
13830 | } |
13831 | |
13832 | // Restore emulation registers. |
13833 | if (state.emulate64TempSave[0].isValid()) { |
13834 | for (int q = 0; q < 2; q++) { |
13835 | state.emulate.temp[q] = state.emulate64TempSave[q]; |
13836 | if (state.emulate64TempSave[q].isValid()) |
13837 | state.ra.claim(state.emulate64TempSave[q]); |
13838 | } |
13839 | if (GRF::bytes(hw) == 64) state.raVFlag.release(state.emulate.flag); |
13840 | state.emulate.flag = invalid; |
13841 | state.emulate.flagOffset = 0; |
13842 | } |
13843 | } |
13844 | |
13845 | // Perform the body of the GEMM computation, updating a block of C. |
13846 | template <HW hw> |
13847 | bool gemm_kernel_generator_t<hw>::gemmAccumulateC( |
13848 | GEMMProblem &problem_, GEMMStrategy &strategy_, GEMMState &state) { |
13849 | if (strategy_.fixedSystolic) { |
13850 | if (problem_.sumA || problem_.sumB |
13851 | || problem_.abOffset == ABOffset::Calc) |
13852 | stub(); |
13853 | return strategy_.splitCopy |
13854 | ? sysgemm2AccumulateC(problem_, strategy_, state) |
13855 | : sysgemmAccumulateC(problem_, strategy_, state); |
13856 | } |
13857 | |
13858 | auto problem = problem_; |
13859 | auto strategy = strategy_; |
13860 | |
13861 | if (!gemmAccumulateCSetup(problem, strategy, state)) return false; |
13862 | |
13863 | // Synthesize k loop. If configured, choose between 32-bit adds and 64-bit adds. |
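// state.add64 is a runtime flag computed earlier: nonzero means 64-bit
// address updates are required. Emitting both loop variants and branching
// on the flag lets small problems take the cheaper 32-bit path.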
13864 | if (strategy.checkAdd32 && state.add64.isValid()) { |
13865 | Label loop64, done; |
13866 | bool success = true; |
13867 | |
13868 | cmp(1 | ne | state.flagAP, state.add64, uint16_t(0)); |
13869 | jmpi(1 | state.flagAP, loop64); |
13870 | state.ra.safeRelease(state.add64); |
13871 | |
13872 | status << "k loop: 32-bit address update" << status_stream::endl; |
13873 | strategy.emulate.emulate64_add32 = true; |
13874 | auto substate32 = state; |
13875 | success &= gemmKLoop(problem, strategy, substate32); |
13876 | jmpi(1, done); |
13877 | |
13878 | mark(loop64); |
13879 | status << "k loop: 64-bit address update" << status_stream::endl; |
13880 | strategy.emulate.emulate64_add32 = false; |
13881 | success &= gemmKLoop(problem, strategy, state); |
13882 | |
13883 | mark(done); |
13884 | if (!success) return false; |
13885 | } else { |
13886 | state.ra.safeRelease(state.add64); |
13887 | if (!gemmKLoop(problem, strategy, state)) return false; |
13888 | } |
13889 | |
13890 | gemmAccumulateCTeardown(problem, strategy, state); |
13891 | |
13892 | return true; |
13893 | } |
13894 | |
13895 | template <HW hw> |
13896 | void gemm_kernel_generator_t<hw>::setupCAddr0(GRFRange (&C_addr0)[2], |
13897 | GRFRange (&C_addr0Unmasked)[2], const vector<RegisterBlock> &C_layout, |
13898 | const vector<RegisterBlock> &C_layoutUnmasked, int C_count, |
13899 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state, |
13900 | const Address2DParams *params) { |
13901 | Address2DParams defaultParams; |
13902 | if (!params) { |
13903 | defaultParams.rows = state.inputs.m; |
13904 | defaultParams.cols = state.inputs.n; |
13905 | defaultParams.offR = state.i0; |
13906 | defaultParams.offC = state.j0; |
13907 | defaultParams.remR = state.remainders[LoopM]; |
13908 | defaultParams.remC = state.remainders[LoopN]; |
13909 | params = &defaultParams; |
13910 | } |
13911 | for (int q = 0; q < C_count; q++) { |
13912 | C_addr0[q] = state.ra.alloc_range( |
13913 | addrGRFCount(problem.C, strategy.C, C_layout[0])); |
13914 | setupAddr(C_addr0[q], state.effC[q], C_layout[0], state.inputs.ldc[q], |
13915 | problem.Tc.size(), problem.C, strategy.C, strategy, state, |
13916 | *params, state.ldcMultiples[q]); |
13917 | } |
13918 | if (!C_layoutUnmasked.empty()) |
13919 | for (int q = 0; q < C_count; q++) { |
13920 | C_addr0Unmasked[q] = state.ra.alloc_range( |
13921 | addrGRFCount(problem.C, strategy.C, C_layoutUnmasked[0])); |
13922 | setupAddr(C_addr0Unmasked[q], state.effC[q], C_layoutUnmasked[0], |
13923 | state.inputs.ldc[q], problem.Tc.size(), problem.C, |
13924 | strategy.C, strategy, state, *params, |
13925 | state.ldcMultiples[q]); |
13926 | } |
13927 | } |
13928 | |
13929 | template <HW hw> |
13930 | bool gemm_kernel_generator_t<hw>::gemmUpdateC( |
13931 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
13932 | |
13933 | auto Ts = problem.Ts; |
13934 | |
13935 | status << "C update" << status_stream::endl; |
13936 | |
13937 | auto &alphar = problem.alpha_real; |
13938 | auto &betar = problem.beta_real; |
13939 | |
13940 | if (strategy.cLoadAhead) { |
13941 | betar = 0; |
13942 | if (!problem.alpha1()) stub(); |
13943 | } |
13944 | |
13945 | // C early offset. |
13946 | if (problem.cOffset == COffset::Pre) |
13947 | if (!gemmApplyCOffsetDispatch(problem, strategy, state)) return false; |
13948 | |
13949 | // Prepare legacy eltwise postop injector if configured. |
13950 | GRFRange postOpScratch; |
13951 | if (useEltwiseInjector(problem)) { |
13952 | if (problem.hasBinaryPostOp()) stub(); |
13953 | |
13954 | const int eu_count = 0; |
13955 | postOpInjector.reset(new Injector(this, problem.Ts.get_dnnl_type(), |
13956 | problem.postOps, eu_count, GRFRange(), problem.postOpFwd)); |
13957 | if (!postOpInjector) stub(); |
13958 | |
13959 | postOpScratch = state.ra.try_alloc_range( |
13960 | postOpInjector->preferred_scratch_regs()); |
13961 | if (postOpScratch.isInvalid()) |
13962 | postOpScratch |
13963 | = state.ra.alloc_range(postOpInjector->min_scratch_regs()); |
13964 | postOpInjector->set_scratch(postOpScratch); |
13965 | } |
13966 | |
13967 | // Convert C to the type of alpha/beta if needed and if possible (no data size change). |
13968 | // If not possible, must be done at a lower level during C update. |
13969 | bool successfulConvert = true; |
13970 | |
13971 | if (problem.needsTsConvert()) |
13972 | successfulConvert = gemmConvertC(Ts, problem, strategy, state); |
13973 | |
13974 | // Scale by alpha now if alpha and beta are both nontrivial. Todo: move above beta = 0 check, |
13975 | // handle double precision correctly (load alpha to register first). |
13976 | // Also scale if atomically updating C or for split-complex. |
13977 | bool nontrivialAlpha = !problem.alpha1() && !problem.alphaM1(); |
13978 | bool forceScale = !problem.alpha1() && strategy.C.atomic; |
13979 | |
13980 | if (!problem.alpha1() && problem.hasBinaryPostOp()) { |
13981 | forceScale = true; |
13982 | if (!successfulConvert) stub(); |
13983 | } |
13984 | |
13985 | if (successfulConvert |
13986 | && ((nontrivialAlpha && (!problem.beta1() || strategy.doubleWA)) |
13987 | || forceScale)) { |
13988 | |
13989 | if (alphar == -1) { |
13990 | map(hw, Ts.real(), state.C_regs[0], state.C_regs[0], strategy, |
13991 | [&](int esize, GRF acc, GRF _) { mov(esize, acc, -acc); }); |
13992 | } else if (alphar != 1) { |
13993 | map(hw, Ts.real(), state.C_regs[0], state.C_regs[0], strategy, |
13994 | [&](int esize, GRF acc, GRF _) { |
13995 | alphar.fixed() |
13996 | ? mul(esize, acc, acc, cast(Ts.real(), alphar)) |
13997 | : mul(esize, acc, acc, |
13998 | alphar.getRegAvoiding(hw, acc)); |
13999 | }); |
14000 | } |
14001 | |
14002 | alphar = 1; |
14003 | } |
14004 | |
14005 | // Do the actual updating. |
14006 | if (!gemmAccessC(COperation::UpdateStore, problem, strategy, state)) |
14007 | return false; |
14008 | |
14009 | // Postop cleanup. |
14010 | if (useEltwiseInjector(problem)) { |
14011 | postOpInjector.reset(); |
14012 | state.ra.safeRelease(postOpScratch); |
14013 | } |
14014 | |
14015 | // Free C data and layout. |
14016 | safeReleaseRanges(state.C_regs, state); |
14017 | state.C_layout.clear(); |
14018 | state.C_layoutExt.clear(); |
14019 | |
14020 | state.raVFlag.safeRelease(state.flagSwizzle); |
14021 | |
14022 | // Success! |
14023 | return true; |
14024 | } |
14025 | |
14026 | // Load from, update, and/or store to C, with complete remainder handling. |
14027 | // If op == COperation::Load, only load C. |
14028 | // If op == COperation::Update, load and update C. |
14029 | // If op == COperation::UpdateStore, perform full C update with alpha/beta scaling. Unless state.isNested == true, assumed |
14030 | // to be the conclusion of the kernel. |
14031 | template <HW hw> |
14032 | bool gemm_kernel_generator_t<hw>::gemmAccessC(COperation op, |
14033 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
14034 | Label labelStdCRemainder, labelAltCRemainder, labelBlock2DCRemainder, |
14035 | labelCRemDone, labelSkip; |
14036 | |
14037 | int C_count = (op == COperation::UpdateStore) ? state.C_count : 1; |
14038 | bool remainderM |
14039 | = (strategy.remHandling[LoopM] != RemainderHandling::Ignore); |
14040 | bool remainderN |
14041 | = (strategy.remHandling[LoopN] != RemainderHandling::Ignore); |
14042 | bool remM_C, remN_C; |
14043 | getCRemainders(problem, strategy, remM_C, remN_C); |
14044 | bool altCRemainder = strategy.altCRemainder && !strategy.C.padded |
14045 | && (remainderM || remainderN || problem.gemmt()); |
14046 | bool block2DCRemainder = strategy.block2DCRemainder && !strategy.C.padded |
14047 | && (remainderM || remainderN); |
14048 | bool stdCRemainder = !(altCRemainder |
14049 | && (strategy.remHandling[LoopM] |
14050 | == RemainderHandling::KnownRemainder) |
14051 | && (strategy.remHandling[LoopN] |
14052 | == RemainderHandling::KnownRemainder)); |
14053 | |
14054 | if ((op != COperation::UpdateStore) && strategy.C.atomic) stub(); |
14055 | |
14056 | if (state.allowEmptyC && (remainderM || remainderN)) { |
14057 | if (!state.isNested) stub(); |
14058 | int simt = strategy.fused ? 16 : 1; |
14059 | cmp(simt | le | f0[0], null.ud(), state.remainders[LoopM], 0); |
14060 | cmp(simt | le | f1[0], null.ud(), state.remainders[LoopN], 0); |
14061 | strategy.fused ? goto12(16 | f0[0] | anyv, labelSkip) |
14062 | : ejmpi(1 | f0[0] | anyv, labelSkip); |
14063 | } |
14064 | |
14065 | bool splitUpdateStore = (problem.cOffset == COffset::Post); |
14066 | |
14067 | // New post-op path: do all post-ops up to sum, if any. |
14068 | int poSum = 0; |
14069 | bool newPostOps = !useEltwiseInjector(problem); |
14070 | if (op == COperation::UpdateStore && newPostOps) { |
14071 | for (poSum = 0; poSum < problem.postOps.len(); poSum++) |
14072 | if (problem.postOps.entry_[poSum].kind == primitive_kind::sum) |
14073 | break; |
14074 | gemmApplyPostOps(0, poSum, problem, strategy, state); |
14075 | splitUpdateStore |= (poSum + 1 < problem.postOps.len()); |
14076 | } |
14077 | |
14078 | if (op == COperation::UpdateStore && splitUpdateStore) { |
14079 | // C postoffset is implemented by splitting the update and store steps. |
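// Sequence: (1) load/update C with the original alpha and beta; (2) apply
// any post-ops following the sum; (3) apply the C offset; (4) store with
// alpha = 1, beta = 0 so no further scaling occurs.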
14080 | bool ok = true; |
14081 | bool oldAllowEmptyC = state.allowEmptyC; |
14082 | state.allowEmptyC = false; |
14083 | |
14084 | if (!(problem.alpha1() && problem.beta0())) |
14085 | ok = ok |
14086 | && gemmAccessC( |
14087 | COperation::Update, problem, strategy, state); |
14088 | |
14089 | auto storeProblem = problem; |
14090 | storeProblem.cOffset = COffset::None; |
14091 | storeProblem.alpha_real = 1; |
14092 | storeProblem.alpha_imag = 0; |
14093 | storeProblem.beta_real = 0; |
14094 | storeProblem.beta_imag = 0; |
14095 | |
// Do any post-sum post-ops.
14097 | if (newPostOps) |
14098 | gemmApplyPostOps( |
14099 | poSum + 1, problem.postOps.len(), problem, strategy, state); |
14100 | storeProblem.postOps = post_ops_t {}; |
14101 | |
14102 | if (problem.cOffset == COffset::Post) { |
14103 | gemmConvertC(problem.Tc, problem, strategy, state); |
14104 | ok = ok && gemmApplyCOffsetDispatch(problem, strategy, state); |
14105 | } |
14106 | |
14107 | ok = ok |
14108 | && gemmAccessC( |
14109 | COperation::UpdateStore, storeProblem, strategy, state); |
14110 | |
14111 | state.allowEmptyC = oldAllowEmptyC; |
14112 | if (ok && state.allowEmptyC && (remainderM || remainderN)) { |
14113 | mark(labelSkip); |
14114 | if (strategy.fused) join(16); |
14115 | } |
14116 | return ok; |
14117 | } |
14118 | |
14119 | auto leave = [&] { |
14120 | if (state.isNested || (op != COperation::UpdateStore)) |
14121 | jmpi(1, labelCRemDone); |
14122 | else |
14123 | epilogue(strategy, state); |
14124 | }; |
14125 | |
14126 | if (stdCRemainder) { |
// Check to see if we should jump to the alternate C remainder handling path, when enabled:
// - if this is a remainder kernel
14129 | // - for triangular updates, if the diagonal crosses this block. |
14130 | // When fusing, check diagonal for thread 0 for (fused in n) upper/m lower, thread 1 for n lower/m upper. |
14131 | if (altCRemainder || block2DCRemainder) { |
14132 | if (remainderM || remainderN) { |
14133 | cmp(1 | lt | f0[0], null.ud(), state.remaindersFused[LoopM], |
14134 | strategy.unroll[LoopM]); |
14135 | cmp(1 | lt | f1[0], null.ud(), state.remaindersFused[LoopN], |
14136 | strategy.unroll[LoopN]); |
14137 | } |
14138 | |
14139 | auto &remLabel = block2DCRemainder ? labelBlock2DCRemainder |
14140 | : labelAltCRemainder; |
14141 | if (remainderM || remainderN) ejmpi(1 | f0[0] | anyv, remLabel); |
14142 | } |
14143 | |
14144 | if (block2DCRemainder && !altCRemainder) mark(labelStdCRemainder); |
14145 | |
14146 | // Release the all-purpose flag temporarily to free up flag registers if it won't be needed. |
14147 | auto saveFlagAP = state.flagAP; |
14148 | if (!problem.hasPostOp()) |
14149 | if (!strategy.fused && !strategy.noJumpTables |
14150 | && state.emulate.flag != state.flagAP) |
14151 | state.raVFlag.safeRelease(state.flagAP); |
14152 | |
14153 | // Decide on the C remainder handling strategy. |
14154 | bool fragments[2] = {false, false}; |
14155 | bool fragPositives[2] = {true, true}; |
14156 | int fragSizes[2] = {1 << 16, 1 << 16}; |
14157 | |
14158 | // Check for fragmenting. |
14159 | auto &C_layoutExt = state.C_layoutExt; |
14160 | auto &C_layoutExtUnmasked = state.C_layoutExtUnmasked; |
14161 | bool remDescs[2] = {false, false}; |
14162 | bool remMasks[2] = {false, false}; |
14163 | |
14164 | // Loop over rows (rc = 0) and columns (rc = 1). |
14165 | for (int rc = 0; rc < 2; rc++) { |
14166 | if (!(rc ? remN_C : remM_C)) |
14167 | continue; // Skip if not doing remainder handling in this dimension. |
14168 | |
14169 | for (auto &l : C_layoutExt) { |
14170 | auto qFragment = rc ? l.colFragment : l.rowFragment; |
14171 | bool qZeroOK = rc ? l.noColsOK : l.noRowsOK; |
14172 | bool qMasked = rc ? (bool)l.colMask : (bool)l.rowMask; |
14173 | bool qDescRem = rc ? l.descRemC : l.descRemR; |
14174 | |
14175 | if (qFragment > 0) { |
14176 | fragments[rc] = true; |
14177 | fragSizes[rc] = std::min<int>(fragSizes[rc], qFragment); |
14178 | if (qZeroOK) fragPositives[rc] = false; |
14179 | |
14180 | if (qFragment > 1) { |
14181 | remDescs[rc] |= qDescRem; |
14182 | remMasks[rc] |= !qDescRem; |
14183 | } |
14184 | } else |
14185 | remMasks[rc] |= qMasked; |
14186 | } |
14187 | } |
14188 | |
14189 | // Disable fragmentation if fragment size is bigger than unroll. |
14190 | fragments[0] &= fragSizes[0] < strategy.unroll[LoopM]; |
14191 | fragments[1] &= fragSizes[1] < strategy.unroll[LoopN]; |
14192 | |
14193 | // Sanity check the requirements. |
14194 | if ((remDescs[0] && remMasks[0]) || (remDescs[1] && remMasks[1])) { |
14195 | status << "Different remainder types mixed in C layout." |
14196 | << status_stream::endl; |
14197 | return false; |
14198 | } |
14199 | if (remMasks[0] && remMasks[1]) { |
14200 | status << "Both dimensions are masked (not supported)." |
14201 | << status_stream::endl; |
14202 | return false; |
14203 | } |
14204 | if (remDescs[0] && remDescs[1]) { |
14205 | status << "Both dimensions use descriptors (not supported)." |
14206 | << status_stream::endl; |
14207 | return false; |
14208 | } |
14209 | |
14210 | // Set remainder handling types. |
14211 | StdCRemType remTypes[2] = {StdCRemType::Ignore, StdCRemType::Ignore}; |
14212 | for (int rc = 0; rc < 2; rc++) { |
14213 | if (remDescs[rc]) |
14214 | remTypes[rc] = StdCRemType::Descriptor; |
14215 | else if (remMasks[rc]) |
14216 | remTypes[rc] = StdCRemType::Mask; |
14217 | } |
14218 | |
14219 | // Decide whether to do m or n first. Criteria, in order of priority: |
14220 | // - Do an ignored dimension first. |
14221 | // - Do a fragmented dimension first. |
14222 | // - Do descriptors first. |
14223 | // - Do whichever dimension of C is strided first. |
14224 | bool nFirst; |
14225 | if (remTypes[0] == StdCRemType::Ignore |
14226 | || remTypes[1] == StdCRemType::Ignore) |
14227 | nFirst = (remTypes[1] == StdCRemType::Ignore); |
14228 | else if (fragments[0] != fragments[1]) |
14229 | nFirst = fragments[1]; |
14230 | else if (remDescs[0] || remDescs[1]) |
14231 | nFirst = remDescs[1]; |
14232 | else |
14233 | nFirst = (problem.C.layout == MatrixLayout::N); |
14234 | |
14235 | // Cache ldc multiples. |
14236 | gemmCacheLDCMultiples(problem, strategy, state); |
14237 | |
14238 | // Prepare for load/store descriptor generation. |
14239 | if (remDescs[0] || remDescs[1]) |
14240 | setupTeardownLoadStoreDesc(true, C_layoutExt, strategy, state); |
14241 | |
14242 | // Set up address for the beginning of C. |
14243 | GRFRange C_addr0[2], C_addr0Unmasked[2]; |
14244 | setupCAddr0(C_addr0, C_addr0Unmasked, C_layoutExt, C_layoutExtUnmasked, |
14245 | C_count, problem, strategy, state); |
14246 | |
14247 | // Try to load C masks. If that fails, fragment the masked dimension down to the size of current blocks. |
14248 | vector<MaskAssignment> masks; |
14249 | if (!assignMasks(C_layoutExt, LoopM, LoopN, masks, strategy, state)) { |
14250 | for (int rc = 0; rc < 2; rc++) { |
14251 | if (remMasks[rc]) { |
14252 | fragments[rc] = true; |
14253 | fragSizes[rc] = rc ? C_layoutExt[0].nc : C_layoutExt[0].nr; |
14254 | } |
14255 | } |
14256 | } else |
14257 | loadMasks(masks, state.remainders, strategy, state); |
14258 | |
14259 | // Call the remainder handling routine. If it fails, try again, switching M and N. |
// If that still fails, try again with complete fragmentation if partial
// fragmentation was attempted the first time.
14262 | bool columns[2] = {nFirst, !nFirst}; |
14263 | bool switchedColumns[2] = {!nFirst, nFirst}; |
14264 | do { |
14265 | if (doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false, |
14266 | columns, remTypes, fragments, fragPositives, fragSizes, |
14267 | C_addr0, C_addr0Unmasked, op, masks, problem, strategy, |
14268 | state)) |
14269 | break; |
14270 | if (doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false, |
14271 | switchedColumns, remTypes, fragments, fragPositives, |
14272 | fragSizes, C_addr0, C_addr0Unmasked, op, masks, problem, |
14273 | strategy, state)) |
14274 | break; |
14275 | |
14276 | if ((fragments[0] && (fragSizes[0] > 1)) |
14277 | || (fragments[1] && (fragSizes[1] > 1))) { |
14278 | fragSizes[0] = fragSizes[1] = 1; |
14279 | |
14280 | if (doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false, |
14281 | columns, remTypes, fragments, fragPositives, |
14282 | fragSizes, C_addr0, C_addr0Unmasked, op, masks, |
14283 | problem, strategy, state)) |
14284 | break; |
14285 | if (doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false, |
14286 | switchedColumns, remTypes, fragments, fragPositives, |
14287 | fragSizes, C_addr0, C_addr0Unmasked, op, masks, |
14288 | problem, strategy, state)) |
14289 | break; |
14290 | } |
14291 | return false; |
14292 | } while (false); |
14293 | |
14294 | // Free cached ldc multiples. |
14295 | for (int q = 0; q < state.C_count; q++) |
14296 | releaseLDMultiples(state.ldcMultiples[q], state); |
14297 | releaseIndexVec(state); |
14298 | |
14299 | // Free address header for block 0. |
14300 | for (int q = 0; q < C_count; q++) |
14301 | state.ra.safeRelease(C_addr0[q]); |
14302 | |
14303 | // Free C mask registers. |
14304 | safeReleaseMaskAssignments(masks, state); |
14305 | |
14306 | // Clean up after load/store descriptor generation. |
14307 | if (remDescs[0] || remDescs[1]) |
14308 | setupTeardownLoadStoreDesc(false, C_layoutExt, strategy, state); |
14309 | |
14310 | // Restore all-purpose flag. |
14311 | state.flagAP = saveFlagAP; |
14312 | state.raVFlag.claim(saveFlagAP); |
14313 | |
14314 | // Leave. |
14315 | if (block2DCRemainder || altCRemainder) leave(); |
14316 | } |
14317 | |
14318 | // Do block 2D C remainder handling if enabled. |
14319 | if (block2DCRemainder) { |
14320 | mark(labelBlock2DCRemainder); |
14321 | |
14322 | // Check for transposition. |
14323 | bool memCM = isColMajor(problem.C.layout); |
14324 | bool regCM = isLayoutColMajor(state.C_layout); |
14325 | bool doTranspose = (memCM != regCM); |
14326 | |
14327 | // Check if alignment requirements are met. |
14328 | auto Tc = problem.Tc; |
14329 | uint16_t align = doTranspose ? 4 : 8; |
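// For each C matrix, check that the base pointer and ldc meet the 2D block
// alignment requirement, and that sub-dword stores cover whole dwords in
// the leading dimension; if not, jump to the standard or alternate
// remainder path.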
14330 | for (int q = 0; q < state.C_count; q++) { |
14331 | bool checkAlign = (problem.C.alignment % align) != 0; |
14332 | bool checkWidth |
14333 | = (q == 0 && Tc.size() < 4 && op != COperation::Load); |
14334 | auto &labelNonBlock2DRem |
14335 | = altCRemainder ? labelAltCRemainder : labelStdCRemainder; |
14336 | |
14337 | if (checkAlign) { |
14338 | and_(1 | nz | f0[0], null.uw(), state.effC[q].uw(), align - 1); |
14339 | and_(1 | nz | f1[0], null.uw(), state.inputs.ldc[q].uw(), |
14340 | align - 1); |
14341 | } |
14342 | if (checkWidth) |
14343 | and_(1 | nz | f0[1], null.uw(), |
14344 | state.remainders[isColMajor(problem.C.layout) ? LoopM |
14345 | : LoopN], |
14346 | (4 / Tc) - 1); |
14347 | if (checkAlign) ejmpi(1 | f0[0] | anyv, labelNonBlock2DRem); |
14348 | if (checkWidth) jmpi(1 | f0[1], labelNonBlock2DRem); |
14349 | } |
14350 | |
14351 | // Rustle up a new layout. |
14352 | // Match existing C layout if possible and deemed worthwhile. |
14353 | vector<RegisterBlock> C_layout2D; |
14354 | auto modProblem = problem; |
14355 | auto modStrategy = strategy; |
14356 | auto modState = state; |
14357 | |
14358 | modProblem.C.setAlignment(align); |
14359 | modStrategy.C = state.Cext_strategy; |
14360 | modStrategy.C.newDP = true; |
14361 | modStrategy.C.address2D = true; |
14362 | modStrategy.C.accessType = doTranspose ? AccessType::Block2DTranspose |
14363 | : AccessType::Block2D; |
14364 | |
14365 | auto &C_layout = modState.C_layout; |
14366 | auto &C_layoutExt = modState.C_layoutExt; |
14367 | auto &C_layoutExtUnmasked = modState.C_layoutExtUnmasked; |
14368 | |
14369 | bool inplace = !state.copyC |
14370 | && upgradeLayoutToBlock2D(Tc, C_layoutExt, C_layout2D, remM_C, |
14371 | remN_C, op != COperation::Load, modProblem.C, |
14372 | modStrategy.C); |
14373 | |
14374 | for (auto &block : C_layout2D) |
14375 | inplace = inplace |
14376 | && state.C_regs[0].contiguous( |
14377 | block.offsetReg(), block.nregs()); |
14378 | |
14379 | if (inplace) |
14380 | C_layoutExt = std::move(C_layout2D); |
14381 | else { |
14382 | modState.copyC = true; |
14383 | if (!getRegLayout(problem.Tc_ext, C_layoutExt, |
14384 | strategy.unroll[LoopM], strategy.unroll[LoopN], remM_C, |
14385 | remN_C, true, false, 0, 0, modProblem.C, modStrategy.C)) |
14386 | stub(); |
14387 | } |
14388 | |
14389 | C_layoutExtUnmasked.clear(); |
14390 | |
14391 | for (auto &block : C_layout) |
14392 | block.simdSize = 0; /* unlink C layout from memory */ |
14393 | |
14394 | // Do the update. |
14395 | Address2DParams params; |
14396 | params.rows = state.remainders[LoopM]; |
14397 | params.cols = state.remainders[LoopN]; |
14398 | |
14399 | GRFRange C_addr0[2], C_addr0Unmasked[2]; |
14400 | setupCAddr0(C_addr0, C_addr0Unmasked, C_layoutExt, C_layoutExtUnmasked, |
14401 | C_count, modProblem, modStrategy, modState, ¶ms); |
14402 | |
14403 | bool columns[2] = {false, true}; |
14404 | StdCRemType remTypes[2] = {StdCRemType::Ignore, StdCRemType::Ignore}; |
14405 | bool fragments[2] = {false, false}; |
14406 | bool fragPositives[2] = {false, false}; |
14407 | int fragSizes[2] = {1 << 16, 1 << 16}; |
14408 | vector<MaskAssignment> masks; |
14409 | |
14410 | if (!doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false, columns, |
14411 | remTypes, fragments, fragPositives, fragSizes, C_addr0, |
14412 | C_addr0Unmasked, op, masks, modProblem, modStrategy, |
14413 | modState)) |
14414 | stub(); |
14415 | |
14416 | if (altCRemainder) leave(); |
14417 | } |
14418 | |
14419 | // Do alternate C remainder handling if enabled. |
14420 | if (altCRemainder) { |
14421 | mark(labelAltCRemainder); |
14422 | doAlternateCRemainder(op, problem, strategy, state); |
14423 | } |
14424 | |
14425 | if (altCRemainder || block2DCRemainder) mark(labelCRemDone); |
14426 | |
14427 | if (state.allowEmptyC && (remainderM || remainderN)) { |
14428 | mark(labelSkip); |
14429 | if (strategy.fused) join(16); |
14430 | } |
14431 | |
14432 | return true; /* Successful! */ |
14433 | } |
14434 | |
14435 | template <HW hw> |
14436 | bool gemm_kernel_generator_t<hw>::gemmBodyInternal( |
14437 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
14438 | auto Tc = problem.Tc; |
14439 | |
14440 | auto unrollM = strategy.unroll[LoopM]; |
14441 | auto unrollN = strategy.unroll[LoopN]; |
14442 | auto &remM = state.remainders[LoopM]; |
14443 | auto &remN = state.remainders[LoopN]; |
14444 | |
14445 | // Accumulate C with panel*panel multiply. |
14446 | if (!gemmAccumulateC(problem, strategy, state)) return false; |
14447 | |
14448 | // Add A/B offsets. |
14449 | gemmLoadABOffset(problem, strategy, state); |
14450 | if (!gemmFinalizeSums(problem, strategy, state)) return false; |
14451 | gemmApplyABOffset(problem, strategy, state); |
14452 | |
14453 | // If C is packed, update remainders and prepare to mask out border regions. |
14454 | bool remaskC_M = isPacked(problem.C.layout) |
14455 | && (strategy.remHandling[LoopM] != RemainderHandling::Ignore); |
14456 | bool remaskC_N = isPacked(problem.C.layout) |
14457 | && (strategy.remHandling[LoopN] != RemainderHandling::Ignore); |
14458 | |
14459 | if (remaskC_M || remaskC_N) { |
14460 | if (remaskC_M) |
14461 | setupTeardownRemask(Tc, 0, true, unrollM, remM, strategy, state); |
14462 | if (remaskC_N) |
14463 | setupTeardownRemask(Tc, 1, true, unrollN, remN, strategy, state); |
14464 | |
14465 | int C_mgran, C_ngran; |
14466 | getGranularities(problem.C, C_mgran, C_ngran); |
14467 | if (!remaskC_M || C_mgran == unrollM) C_mgran = 1; |
14468 | if (!remaskC_N || C_ngran == unrollN) C_ngran = 1; |
14469 | if (!is_zero_or_pow2(C_mgran)) stub(); |
14470 | if (!is_zero_or_pow2(C_ngran)) stub(); |
14471 | |
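        // Round each remainder up to a multiple of its power-of-2 granularity g:
        //   rem = (rem + (g - 1)) & ~(g - 1),  e.g. g = 4 takes rem = 5 to 8.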
14472 | if (C_mgran > 1) add(1, remM, remM, C_mgran - 1); |
14473 | if (C_ngran > 1) add(1, remN, remN, C_ngran - 1); |
14474 | if (C_mgran > 1) and_(1, remM, remM, uint32_t(~(C_mgran - 1))); |
14475 | if (C_ngran > 1) and_(1, remN, remN, uint32_t(~(C_ngran - 1))); |
14476 | } |
14477 | |
14478 | // Local k reduction. |
14479 | if (strategy.kParallelLocal) gemmKReduce(problem, strategy, state); |
14480 | |
14481 | // Late exit. |
14482 | bool lateExit = state.doLateExit; |
14483 | Label labelLateExit; |
14484 | |
14485 | if (lateExit) { |
14486 | int simt = strategy.fused ? 16 : 1; |
14487 | |
14488 | cmp(simt | le | f0[0], state.remainders[LoopM], uint16_t(0)); |
14489 | cmp(simt | le | f1[0], state.remainders[LoopN], uint16_t(0)); |
14490 | |
14491 | InstructionModifier cond = simt | f0[0] | anyv; |
14492 | |
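        // Take the late exit if either remainder is empty; anyv branches when
        // f0[0] or f1[0] is set.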
14493 | strategy.fused ? goto12(cond, labelLateExit) |
14494 | : ejmpi(cond, labelLateExit); |
14495 | } |
14496 | |
14497 | gemmUpdateSums(problem, strategy, state); |
14498 | |
    // Update C. If configured, dispatch now between the regular beta path and
    // the special beta = 0 / beta = 1 paths.
14500 | bool checkBeta0 = problem.checkBeta0 && !problem.beta_real.fixed(); |
14501 | bool checkBeta1 = strategy.checkBeta1 && !problem.beta_real.fixed(); |
14502 | bool checkTRMMBeta1 = state.beta1.isValid(); |
14503 | |
14504 | if (checkTRMMBeta1 && (checkBeta0 || checkBeta1)) stub(); |
14505 | |
14506 | if (!checkBeta0 && !checkBeta1 && !checkTRMMBeta1) { |
14507 | if (!gemmUpdateC(problem, strategy, state)) return false; |
14508 | } else { |
14509 | Label labelBeta0, labelBeta1, labelBetaDone; |
14510 | InstructionModifier mod0 = 1 | f0[0]; |
14511 | InstructionModifier mod1 = 1 | f0[1]; |
14512 | bool simtCF1 = false; |
14513 | |
14514 | if (checkBeta0) { cmp0(1 | eq | f0[0], problem.beta_real.getReg(0)); } |
14515 | |
14516 | if (checkBeta1) { |
14517 | cmp(1 | eq | f0[1], problem.beta_real.getReg(0), |
14518 | cast(problem.Ts, 1.0)); |
14519 | } |
14520 | |
14521 | if (checkBeta0) jmpi(mod0, labelBeta0); |
14522 | |
14523 | if (checkBeta1 || checkTRMMBeta1) { |
14524 | simtCF1 ? if_(mod1, labelBeta1, labelBetaDone) |
14525 | : jmpi(mod1, labelBeta1); |
14526 | } |
14527 | |
14528 | // Regular update. |
14529 | { |
14530 | auto subproblem = problem; |
14531 | auto substrategy = strategy; |
14532 | auto substate = state; |
14533 | |
14534 | if (strategy.C.atomic && !strategy.C.base.isStateless() |
14535 | && !strategy.C.newDP) |
14536 | stub(); /* need to shift addresses */ |
14537 | substrategy.C.atomic = false; |
14538 | |
14539 | if (!gemmUpdateC(subproblem, substrategy, substate)) return false; |
14540 | } |
14541 | |
14542 | simtCF1 ? else_(16, labelBetaDone) |
14543 | : state.isNested ? jmpi(1, labelBetaDone) |
14544 | : epilogue(strategy, state); |
14545 | |
14546 | // beta = 1 update. |
14547 | if (checkBeta1 || checkTRMMBeta1) { |
14548 | status << "Special path: beta = 1" << status_stream::endl; |
14549 | mark(labelBeta1); |
14550 | |
14551 | auto subproblem = problem; |
14552 | auto substate = state; |
14553 | |
14554 | subproblem.beta_real = 1; |
14555 | subproblem.beta_imag = 0; |
14556 | |
14557 | if (!gemmUpdateC(subproblem, strategy, substate)) return false; |
14558 | |
14559 | if (checkBeta0) { |
14560 | (simtCF1 || state.isNested) ? jmpi(1, labelBetaDone) |
14561 | : epilogue(strategy, state); |
14562 | } |
14563 | } |
14564 | |
14565 | // beta = 0 update. |
14566 | if (checkBeta0) { |
14567 | status << "Special path: beta = 0" << status_stream::endl; |
14568 | mark(labelBeta0); |
14569 | |
14570 | auto subproblem = problem; |
14571 | auto substrategy = strategy; |
14572 | auto substate = state; |
14573 | |
14574 | subproblem.beta_real = 0; |
14575 | subproblem.beta_imag = 0; |
14576 | |
14577 | substrategy.C.atomic = false; |
14578 | |
14579 | if (!gemmUpdateC(subproblem, substrategy, substate)) return false; |
14580 | } |
14581 | |
14582 | mark(labelBetaDone); |
14583 | if (simtCF1) endif(16); |
14584 | } |
14585 | |
14586 | // Cleanup. |
14587 | if (remaskC_M) |
14588 | setupTeardownRemask(Tc, 0, false, unrollM, remM, strategy, state); |
14589 | if (remaskC_N) |
14590 | setupTeardownRemask(Tc, 1, false, unrollN, remN, strategy, state); |
14591 | |
14592 | if (lateExit) { |
14593 | mark(labelLateExit); |
14594 | if (strategy.fused) join(16); |
14595 | } |
14596 | |
14597 | return true; |
14598 | } |
14599 | |
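// Choose the effective cooperative split used when a workgroup loads A together.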
14600 | template <HW hw> |
14601 | CoopSplit gemm_kernel_generator_t<hw>::effCoopSplitA( |
14602 | const GEMMProblem &problem, const GEMMStrategy &strategy) { |
14603 | if (isPacked(problem.A.layout)) |
14604 | return CoopSplit::Linear; |
14605 | else if (!isRegisterColMajor(problem.Ta_ext, problem.A, strategy.A) |
14606 | && (strategy.unroll[LoopM] % strategy.wg[LoopN] == 0) |
14607 | && !isBlock2D(strategy.A.accessType)) |
14608 | return CoopSplit::MN; |
14609 | else |
14610 | return strategy.coopA; |
14611 | } |
14612 | |
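// Choose the effective cooperative split used when a workgroup loads B together.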
14613 | template <HW hw> |
14614 | CoopSplit gemm_kernel_generator_t<hw>::effCoopSplitB( |
14615 | const GEMMProblem &problem, const GEMMStrategy &strategy) { |
14616 | if (isPacked(problem.B.layout)) |
14617 | return CoopSplit::Linear; |
14618 | else if (isRegisterColMajor(problem.Tb_ext, problem.B, strategy.B) |
14619 | && (strategy.unroll[LoopN] % strategy.wg[LoopM] == 0) |
14620 | && !isBlock2D(strategy.B.accessType)) |
14621 | return CoopSplit::MN; |
14622 | else |
14623 | return strategy.coopB; |
14624 | } |
14625 | |
14626 | // Check whether all threads in a thread group should stay together in m/n remainder handling. |
14627 | template <HW hw> |
14628 | bool gemm_kernel_generator_t<hw>::wgRemCheck( |
14629 | const GEMMProblem &problem, const GEMMStrategy &strategy) { |
14630 | return (strategy.slmA && (effCoopSplitA(problem, strategy) != CoopSplit::K) |
14631 | && (strategy.remHandling[LoopM] != RemainderHandling::Ignore) |
14632 | && !strategy.A.padded) |
14633 | || (strategy.slmB |
14634 | && (effCoopSplitB(problem, strategy) != CoopSplit::K) |
14635 | && (strategy.remHandling[LoopN] |
14636 | != RemainderHandling::Ignore) |
14637 | && !strategy.B.padded) |
14638 | || strategy.kParallelLocal |
14639 | || ((strategy.barrierFreq > 0 || strategy.cooperativePF) |
14640 | && (strategy.prefetchA || strategy.prefetchB |
14641 | || strategy.prefetchC)); |
14642 | } |
14643 | |
14644 | // Do outer-level m/n remainder handling. |
14645 | template <HW hw> |
14646 | template <typename Problem> |
14647 | bool gemm_kernel_generator_t<hw>::mnRemainderHandling(LoopType loop, |
14648 | Problem &problem, GEMMStrategy &strategy, GEMMState &state, |
14649 | bool (gemm_kernel_generator_t<hw>::*func)( |
14650 | Problem, GEMMStrategy, GEMMState)) { |
14651 | auto method = strategy.remHandling[loop]; |
14652 | auto &unroll = strategy.unroll[loop]; |
14653 | auto mn = (loop == LoopM) ? state.inputs.m : state.inputs.n; |
14654 | auto splitThresh |
14655 | = (loop == LoopM) ? strategy.mSplitThresh : strategy.nSplitThresh; |
14656 | |
14657 | Label label_done; |
14658 | |
14659 | auto originalCheckAdd32 = strategy.checkAdd32; |
14660 | |
14661 | if (method == RemainderHandling::Split) { |
14662 | Label label_remainder; |
14663 | |
14664 | // Jump to remainder loop if needed. |
        // If threads are fused in this direction, factor the fused ID into the calculation.
14666 | if (wgRemCheck(problem, strategy)) |
14667 | cmp(1 | lt | f0[0], null.d(), state.remaindersWG[loop], |
14668 | uint16_t(unroll * strategy.wg[loop])); |
14669 | else |
14670 | cmp(1 | lt | f0[0], null.d(), state.remaindersFused[loop], |
14671 | uint16_t(unroll)); |
14672 | |
14673 | if (splitThresh) { |
14674 | cmp(1 | lt | f1[0], null.d(), mn, int32_t(splitThresh)); |
14675 | ejmpi(1 | f0[0] | anyv, label_remainder); |
14676 | } else |
14677 | jmpi(1 | f0[0], label_remainder); |
14678 | |
14679 | // First generate code that ignores remainder handling. |
14680 | GEMMStrategy substrategy = strategy; |
14681 | substrategy.remHandling[loop] = RemainderHandling::Ignore; |
14682 | |
14683 | status << "Generating " |
14684 | << "MNK" [static_cast<int>(loop)] |
14685 | << " non-remainder kernel for unroll " << unroll << '.' |
14686 | << status_stream::endl; |
14687 | if (!(this->*func)(problem, substrategy, state)) { |
14688 | status << "Non-remainder kernel failed, aborting." |
14689 | << status_stream::endl; |
14690 | return false; |
14691 | } |
14692 | |
14693 | // Return, unless this is part of a larger computation, in which case jump to end. |
14694 | if (state.isNested) |
14695 | jmpi(1, label_done); |
14696 | else |
14697 | epilogue(strategy, state); |
14698 | |
14699 | mark(label_remainder); |
14700 | |
14701 | strategy.checkAdd32 = false; |
14702 | } |
14703 | |
14704 | // OK, great! Now try to create remainder-handling code. |
14705 | status << "Attempting to generate " |
14706 | << "MNK" [static_cast<int>(loop)] << " general kernel for unroll " |
14707 | << unroll << '.' << status_stream::endl; |
14708 | bool success = (this->*func)(problem, strategy, state); |
14709 | |
14710 | strategy.checkAdd32 = originalCheckAdd32; |
14711 | if (success) { |
14712 | mark(label_done); |
14713 | return true; |
14714 | } |
14715 | |
14716 | #ifndef ALLOW_REMAINDERS |
14717 | // Disable remainder code for now. |
14718 | return false; |
14719 | #else |
14720 | auto &bound = (loop == LoopN) ? state.inputs.n : state.inputs.m; |
14721 | auto &index = (loop == LoopN) ? state.j0 : state.i0; |
14722 | auto &remainders = state.remainders[loop]; |
14723 | |
14724 | if (method == RemainderHandling::Ignore) |
14725 | throw std::runtime_error("Could not generate kernel." ); |
14726 | |
14727 | // It failed, so break up the loop into the next smaller power of 2 along this dimension, |
14728 | // plus the remainder (recursively). |
14729 | Label label_next_rem; |
14730 | |
14731 | if (unroll == 1) { |
14732 | // No more splitting to do. |
14733 | // We don't know if this was originally split, so just output a warning. |
14734 | status << "NOTE: Split remainder handling is required for loop " |
14735 | << "MNK" [static_cast<int>(loop)] << '.' << status_stream::endl; |
14736 | return true; |
14737 | } |
14738 | int chunkSize = rounddown_pow2(unroll - 1); |
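    // e.g. unroll = 24 gives chunkSize = 16; the remaining 8 is handled recursively.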
14739 | |
14740 | // Jump to next remainder loop if needed. |
14741 | pushStream(); |
14742 | { |
14743 | cmp(1 | lt | state.flagAP, null.d(), remainders, chunkSize); |
14744 | jmpi(1 | state.flagAP, label_next_rem); |
14745 | |
14746 | { |
14747 | GEMMStrategy substrategy = strategy; |
14748 | GEMMState substate = state; |
14749 | substrategy.remHandling[loop] = RemainderHandling::Ignore; |
14750 | substrategy.unroll[loop] = chunkSize; |
14751 | substate.isNested = true; |
14752 | status << "Generating " |
14753 | << "MNK" [static_cast<int>(loop)] |
14754 | << " remainder kernel with unroll " << chunkSize << '.' |
14755 | << status_stream::endl; |
14756 | if (!(this->*func)(problem, substrategy, substate)) { |
14757 | discardStream(); |
14758 | return false; |
14759 | } |
14760 | } |
14761 | |
14762 | // Adjust remainder. |
14763 | add(1, remainders, remainders, -chunkSize); |
14764 | |
14765 | // Adjust pointers as needed. |
        //   LoopM:  A += chunkSize       (N)      chunkSize * lda (T, Pc)
        //           C += chunkSize       (N, Pr)  chunkSize * ldc (T, Pc)
        //   LoopN:  B += chunkSize * ldb (N, Pr)  chunkSize       (T)
        //           C += chunkSize * ldc (N, Pr)  chunkSize       (T, Pc)
14769 | switch (loop) { |
14770 | case LoopM: |
14771 | if (problem.A.layout == MatrixLayout::N) |
14772 | eadd(1, state.effA, state.effA, chunkSize * Ta, strategy, |
14773 | state); |
14774 | else { |
14775 | Subregister temp = state.ra.alloc_sub<uint32_t>(); |
14776 | mulConstant(1, temp, state.inputs.lda, chunkSize); |
14777 | eadd(1, state.effA, state.effA, temp, strategy, state); |
14778 | state.ra.safeRelease(temp); |
14779 | } |
                if (problem.C.layout == MatrixLayout::N
                        || problem.C.layout == MatrixLayout::Pr)
                    eadd(1, state.effC, state.effC, chunkSize * Tc, strategy,
                            state);
                else {
                    Subregister temp = state.ra.alloc_sub<uint32_t>();
                    mulConstant(1, temp, state.inputs.ldc[0], chunkSize);
                    eadd(1, state.effC, state.effC, temp, strategy, state);
                    state.ra.safeRelease(temp);
                }
14790 | break; |
14791 | case LoopN: |
14792 | if (problem.B.layout == MatrixLayout::T) |
14793 | eadd(1, state.effB, state.effB, chunkSize * Tb, strategy, |
14794 | state); |
14795 | else { |
14796 | Subregister temp = state.ra.alloc_sub<uint32_t>(); |
14797 | mulConstant(1, temp, state.inputs.ldb, chunkSize); |
14798 | eadd(1, state.effB, state.effB, temp, strategy, state); |
14799 | state.ra.safeRelease(temp); |
14800 | } |
14801 | if (problem.C.layout == MatrixLayout::T |
14802 | || problem.C.layout == MatrixLayout::Pc) |
14803 | eadd(1, state.effC, state.effC, chunkSize * Tc, strategy, |
14804 | state); |
                else {
                    Subregister temp = state.ra.alloc_sub<uint32_t>();
                    mulConstant(1, temp, state.inputs.ldc[0], chunkSize);
                    eadd(1, state.effC, state.effC, temp, strategy, state);
                    state.ra.safeRelease(temp);
                }
14811 | break; |
14812 | } |
14813 | |
14814 | mark(label_next_rem); |
14815 | |
14816 | // Handle the remainder recursively. |
14817 | { |
14818 | GEMMStrategy substrategy = strategy; |
14819 | substrategy.remHandling[loop] = RemainderHandling::General; |
14820 | substrategy.unroll[loop] -= chunkSize; |
14821 | if (!mnRemainderHandling(loop, problem, substrategy, state, func)) { |
14822 | discardStream(); |
14823 | return false; |
14824 | } |
14825 | } |
14826 | } /* end stream */ |
14827 | |
14828 | appendCurrentStream(); |
14829 | |
14830 | return true; /* success */ |
14831 | #endif |
14832 | } |
14833 | |
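// Do outer-level m and n remainder handling together, splitting both loops at once.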
14834 | template <HW hw> |
14835 | template <typename Problem> |
14836 | bool gemm_kernel_generator_t<hw>::mnJointSplitRemainderHandling( |
14837 | Problem &problem, GEMMStrategy &strategy, GEMMState &state, |
14838 | bool (gemm_kernel_generator_t<hw>::*func)( |
14839 | Problem, GEMMStrategy, GEMMState)) { |
14840 | Label label_done, label_remainder; |
14841 | bool success = false; |
14842 | |
14843 | auto unrollM = strategy.unroll[LoopM]; |
14844 | auto unrollN = strategy.unroll[LoopN]; |
14845 | |
14846 | pushStream(); |
14847 | do { |
        // Jump to the remainder kernel if needed:
        //  - if m/n are below the split thresholds (when enabled)
        //  - if this workgroup/thread does not cover a full unroll tile.
14851 | bool wgCheck = wgRemCheck(problem, strategy); |
14852 | |
14853 | if (strategy.mSplitThresh && strategy.nSplitThresh) { |
14854 | cmp(1 | lt | f0[0], null.d(), state.inputs.m, |
14855 | int32_t(strategy.mSplitThresh)); |
14856 | cmp(1 | lt | f1[0], null.d(), state.inputs.n, |
14857 | int32_t(strategy.nSplitThresh)); |
14858 | ejmpi(1 | f0[0] | anyv, label_remainder); |
14859 | } else if (strategy.mSplitThresh) { |
14860 | cmp(1 | lt | f0[0], null.d(), state.inputs.m, |
14861 | int32_t(strategy.mSplitThresh)); |
14862 | jmpi(1 | f0[0], label_remainder); |
14863 | } else if (strategy.nSplitThresh) { |
14864 | cmp(1 | lt | f0[0], null.d(), state.inputs.n, |
14865 | int32_t(strategy.nSplitThresh)); |
14866 | jmpi(1 | f0[0], label_remainder); |
14867 | } |
14868 | if (wgCheck) { |
14869 | cmp(1 | lt | f0[0], null.d(), state.remaindersWG[LoopM], |
14870 | uint16_t(unrollM * strategy.wg[LoopM])); |
14871 | cmp(1 | lt | f1[0], null.d(), state.remaindersWG[LoopN], |
14872 | uint16_t(unrollN * strategy.wg[LoopN])); |
14873 | } else { |
14874 | cmp(1 | lt | f0[0], null.d(), state.remaindersFused[LoopM], |
14875 | uint16_t(unrollM)); |
14876 | cmp(1 | lt | f1[0], null.d(), state.remaindersFused[LoopN], |
14877 | uint16_t(unrollN)); |
14878 | } |
14879 | ejmpi(1 | f0[0] | anyv, label_remainder); |
14880 | |
14881 | // First generate code that ignores remainder handling. |
14882 | GEMMStrategy substrategy = strategy; |
14883 | substrategy.remHandling[LoopM] = RemainderHandling::Ignore; |
14884 | substrategy.remHandling[LoopN] = RemainderHandling::Ignore; |
14885 | |
14886 | status << "Generating MN non-remainder kernel." << status_stream::endl; |
14887 | if (!(this->*func)(problem, substrategy, state)) { |
14888 | status << "Non-remainder kernel failed, aborting." |
14889 | << status_stream::endl; |
14890 | break; |
14891 | } |
14892 | |
14893 | // Return, unless this is part of a larger computation, in which case jump to end. |
14894 | if (state.isNested) |
14895 | jmpi(1, label_done); |
14896 | else |
14897 | epilogue(strategy, state); |
14898 | |
14899 | mark(label_remainder); |
14900 | |
14901 | // Finally, generate remainder handling kernel. |
14902 | substrategy = strategy; |
14903 | substrategy.remHandling[LoopM] = substrategy.remHandling[LoopN] |
14904 | = (wgCheck ? RemainderHandling::General |
14905 | : RemainderHandling::KnownRemainder); |
14906 | substrategy.checkAdd32 = false; |
14907 | status << "Generating MN general kernel." << status_stream::endl; |
14908 | success = (this->*func)(problem, substrategy, state); |
14909 | |
14910 | mark(label_done); |
14911 | } while (false); |
14912 | |
14913 | success ? appendCurrentStream() : discardStream(); |
14914 | |
14915 | return success; |
14916 | } |
14917 | |
14918 | // Handle outer-level m edge cases. |
14919 | template <HW hw> |
14920 | bool gemm_kernel_generator_t<hw>::gemmMEdge( |
14921 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
14922 | if (strategy.jointSplit |
14923 | && strategy.remHandling[LoopM] == RemainderHandling::Split |
14924 | && strategy.remHandling[LoopN] == RemainderHandling::Split) |
14925 | return mnJointSplitRemainderHandling(problem, strategy, state, |
14926 | &gemm_kernel_generator_t<hw>::gemmBody); |
14927 | else |
14928 | return mnRemainderHandling(LoopM, problem, strategy, state, |
14929 | &gemm_kernel_generator_t<hw>::gemmNEdge); |
14930 | } |
14931 | |
14932 | // Handle outer-level n edge cases. |
14933 | template <HW hw> |
14934 | bool gemm_kernel_generator_t<hw>::gemmNEdge( |
14935 | GEMMProblem problem, GEMMStrategy strategy, GEMMState state) { |
14936 | return mnRemainderHandling(LoopN, problem, strategy, state, |
14937 | &gemm_kernel_generator_t<hw>::gemmBody); |
14938 | } |
14939 | |
14940 | // Initialize the interface. |
14941 | template <HW hw> |
14942 | void gemm_kernel_generator_t<hw>::gemmInitInterface(GEMMProblem &problem, |
14943 | GEMMStrategy &strategy, GEMMState &state, bool inSK) { |
14944 | Subregister localSize[3]; |
14945 | GRF localID[3]; |
14946 | Subregister tgids[3] |
14947 | = {r0.ud(1), r0.ud(6), r0.ud(7)}; // X, Y, Z threadgroup IDs |
14948 | |
14949 | if (strategy.systolic) interface.requireDPAS(); |
14950 | if (strategy.C.atomic) interface.requireGlobalAtomics(); |
14951 | if (strategy.barrierFreq > 0) interface.requireBarrier(); |
14952 | |
14953 | auto slmSize = gemmSLMSize(problem, strategy); |
14954 | auto slmPerK = gemmPerKSLMSize(problem, strategy); |
14955 | if (slmSize > 0 || slmPerK > 0) { |
14956 | status << "SLM usage: " << slmSize / 1024. << 'k'; |
14957 | if (slmPerK) status << " (" << slmPerK / 1024. << "k per-k)" ; |
14958 | status << status_stream::endl; |
14959 | if (!slmPerK) interface.requireSLM(slmSize); |
14960 | interface.requireBarrier(); |
14961 | } |
14962 | |
14963 | if (strategy.fixedWG(problem)) { |
14964 | auto wgX = strategy.wg[strategy.loopOrder[0]]; |
14965 | auto wgY = strategy.wg[strategy.loopOrder[1]]; |
14966 | auto wgZ = strategy.wg[strategy.loopOrder[2]]; |
14967 | if (strategy.splitCopy) wgY *= 2; |
14968 | if (wgZ <= 1) |
14969 | interface.requireWorkgroup(strategy.subgroupSize * wgX, wgY, wgZ); |
14970 | } |
14971 | interface.requireWalkOrder(0, 1, 2); |
14972 | |
14973 | bool needStatelessWrites = strategy.C.base.isStateless(); |
14974 | if (problem.sumA || problem.sumB) |
14975 | needStatelessWrites |= strategy.CO.base.isStateless(); |
14976 | |
14977 | interface.requireStatelessWrites(needStatelessWrites); |
14978 | |
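    // Reserve named barriers along m and n where SLM copies or periodic
    // barriers require them.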
14979 | int nb = int(strategy.slmA || strategy.barrierFreq) |
14980 | * strategy.namedBarriers[LoopM] |
14981 | + int(strategy.slmB || strategy.barrierFreq) |
14982 | * strategy.namedBarriers[LoopN]; |
14983 | if (nb) interface.requireBarriers(nb); |
14984 | |
14985 | interface.finalize(); |
14986 | |
14987 | for (int dim = 0; dim < 3; dim++) { |
14988 | localID[dim] = interface.getLocalID(dim); |
14989 | localSize[dim] = interface.getLocalSize(dim); |
14990 | } |
14991 | |
14992 | // Get input arguments. |
14993 | state.inputs.base = interface.getArgumentIfExists("base" ); |
14994 | auto baseSurface = interface.getArgumentSurfaceIfExists("base" ); |
14995 | if (state.inputs.base.isInvalid() |
14996 | && baseSurface == InterfaceHandler::noSurface) { |
14997 | state.inputs.A = interface.getArgumentIfExists("A" ); |
14998 | state.inputs.B = interface.getArgumentIfExists("B" ); |
14999 | state.inputs.C[1] = interface.getArgumentIfExists("P" ); |
15000 | state.inputs.surfaceA = interface.getArgumentSurfaceIfExists("A" ); |
15001 | state.inputs.surfaceB = interface.getArgumentSurfaceIfExists("B" ); |
15002 | state.inputs.surfaceC[1] = interface.getArgumentSurfaceIfExists("P" ); |
15003 | } else { |
15004 | state.inputs.A = state.inputs.B = state.inputs.base; |
15005 | state.inputs.surfaceA = state.inputs.surfaceB = baseSurface; |
15006 | if (interface.getArgumentIfExists("offset_P" ).isValid()) { |
15007 | state.inputs.C[1] = state.inputs.base; |
15008 | state.inputs.surfaceC[1] = state.inputs.surfaceA; |
15009 | } |
15010 | } |
15011 | |
15012 | state.inputs.C[0] = interface.getArgumentIfExists("C" ); |
15013 | state.inputs.surfaceC[0] = interface.getArgumentSurfaceIfExists("C" ); |
15014 | state.C_count = state.inputs.C[1].isValid() ? 2 : 1; |
15015 | if (problem.usesCO()) { |
15016 | state.inputs.CO = interface.getArgumentIfExists("CO" ); |
15017 | state.inputs.surfaceCO = interface.getArgumentSurfaceIfExists("CO" ); |
15018 | } |
15019 | |
15020 | if (problem.abOffset != ABOffset::None) { |
15021 | state.inputs.abo = interface.getArgumentIfExists("abo" ); |
15022 | if (state.inputs.abo.isValid()) { |
            // A/B offsets are two words packed into a single dword argument.
15024 | state.inputs.ao = state.inputs.abo.w(0); |
15025 | state.inputs.bo = state.inputs.abo.w(1); |
15026 | } else { |
15027 | state.inputs.ao = interface.getArgumentIfExists("ao" ); |
15028 | state.inputs.bo = interface.getArgumentIfExists("bo" ); |
15029 | } |
15030 | state.inputs.aoPtr = interface.getArgumentIfExists("ao_ptr" ); |
15031 | state.inputs.boPtr = interface.getArgumentIfExists("bo_ptr" ); |
15032 | } |
15033 | state.inputs.offsetA = interface.getArgumentIfExists("offset_A" ); |
15034 | state.inputs.offsetB = interface.getArgumentIfExists("offset_B" ); |
15035 | state.inputs.offsetC[0] = interface.getArgumentIfExists("offset_C" ); |
15036 | state.inputs.offsetC[1] = interface.getArgumentIfExists("offset_P" ); |
15037 | state.inputs.offsetCO = interface.getArgumentIfExists("offset_CO" ); |
15038 | if (problem.batch == BatchMode::Strided) { |
15039 | state.inputs.strideA[0] = interface.getArgumentIfExists("stride_A" ); |
15040 | state.inputs.strideB[0] = interface.getArgumentIfExists("stride_B" ); |
15041 | state.inputs.strideC[0] = interface.getArgumentIfExists("stride_C" ); |
15042 | if (problem.batchDims > 1) { |
15043 | state.inputs.strideA[1] |
15044 | = interface.getArgumentIfExists("stride_A1" ); |
15045 | state.inputs.strideB[1] |
15046 | = interface.getArgumentIfExists("stride_B1" ); |
15047 | state.inputs.strideC[1] |
15048 | = interface.getArgumentIfExists("stride_C1" ); |
15049 | state.inputs.batchSize1 |
15050 | = interface.getArgumentIfExists("batch_size1" ); |
15051 | state.inputs.recipBatchSize1 |
15052 | = interface.getArgumentIfExists("recip_batch_size1" ); |
15053 | } |
15054 | } else if (problem.batch == BatchMode::Nonstrided) |
15055 | state.inputs.offsetBatch |
15056 | = interface.getArgumentIfExists("offset_batch" ); |
15057 | else if (problem.batch == BatchMode::Variable) { |
15058 | state.inputs.incr_a_array |
15059 | = interface.getArgumentIfExists("incr_a_array" ); |
15060 | state.inputs.incr_b_array |
15061 | = interface.getArgumentIfExists("incr_b_array" ); |
15062 | } |
15063 | state.inputs.lda = interface.getArgumentIfExists("lda" ); |
15064 | state.inputs.ldb = interface.getArgumentIfExists("ldb" ); |
15065 | state.inputs.ldc[0] = interface.getArgumentIfExists("ldc" ); |
15066 | state.inputs.ldc[1] = interface.getArgumentIfExists("ldp" ); |
15067 | state.inputs.ldco = interface.getArgumentIfExists("ldco" ); |
15068 | state.inputs.m = interface.getArgumentIfExists("m" ); |
15069 | state.inputs.n = interface.getArgumentIfExists("n" ); |
15070 | state.inputs.k = interface.getArgumentIfExists("k" ); |
15071 | state.inputs.k0 = interface.getArgumentIfExists("k0" ); |
15072 | state.inputs.alpha_real = interface.getArgumentIfExists("alpha_real" ); |
15073 | state.inputs.alpha_imag = interface.getArgumentIfExists("alpha_imag" ); |
15074 | state.inputs.beta_real = interface.getArgumentIfExists("beta_real" ); |
15075 | state.inputs.beta_imag = interface.getArgumentIfExists("beta_imag" ); |
15076 | if (problem.batch == BatchMode::Variable) { |
15077 | state.inputs.alpha_array = interface.getArgumentIfExists("alpha_array" ); |
15078 | state.inputs.beta_array = interface.getArgumentIfExists("beta_array" ); |
15079 | state.inputs.incr_alpha = interface.getArgumentIfExists("incr_alpha" ); |
15080 | state.inputs.incr_beta = interface.getArgumentIfExists("incr_beta" ); |
15081 | } |
15082 | state.inputs.diagA = interface.getArgumentIfExists("diag_A" ); |
15083 | state.inputs.diagB = interface.getArgumentIfExists("diag_B" ); |
15084 | state.inputs.diagC = interface.getArgumentIfExists("diag_C" ); |
15085 | state.inputs.flags = interface.getArgumentIfExists("flags" ); |
15086 | |
15087 | if (strategy.linearOrder()) { |
15088 | state.inputs.groupCountM = interface.getArgument("group_count_m" ); |
15089 | state.inputs.groupCountN = interface.getArgument("group_count_n" ); |
15090 | } |
15091 | if (strategy.hilbertOrder) { |
15092 | state.inputs.hilbertVD = interface.getArgumentIfExists("hilbert_vd" ); |
15093 | state.inputs.hilbertUVDRecip |
15094 | = interface.getArgumentIfExists("hilbert_uvd_recip" ); |
15095 | state.inputs.hilbertBail |
15096 | = interface.getArgumentIfExists("hilbert_bail" ); |
15097 | } else if (strategy.boustrophedon) { |
15098 | state.inputs.bslice = interface.getArgument("bslice" ); |
15099 | state.inputs.bthresh = interface.getArgument("bthresh" ); |
15100 | } |
15101 | if (strategy.persistent) { |
15102 | state.inputs.groupCountMN |
15103 | = interface.getArgumentIfExists("group_count" ); |
15104 | state.inputs.groupStride = interface.getArgument("group_stride" ); |
15105 | } |
15106 | |
15107 | int po_count = problem.postOps.len(); |
15108 | state.inputs.binarySrcs.resize(po_count); |
15109 | state.inputs.binaryOffsets.resize(po_count); |
15110 | state.inputs.binaryLDs.resize(po_count); |
15111 | state.inputs.binaryStrides.resize(po_count); |
15112 | state.inputs.binarySurfaces.resize(po_count); |
15113 | for (int i = 0; i < po_count; i++) { |
15114 | std::string srcName = "binary" + std::to_string(i); |
15115 | state.inputs.binarySrcs[i] = interface.getArgumentIfExists(srcName); |
15116 | state.inputs.binarySurfaces[i] |
15117 | = interface.getArgumentSurfaceIfExists(srcName); |
15118 | state.inputs.binaryOffsets[i] |
15119 | = interface.getArgumentIfExists("offset_" + srcName); |
15120 | state.inputs.binaryLDs[i] |
15121 | = interface.getArgumentIfExists("ld" + srcName); |
15122 | if (problem.batch == BatchMode::Strided) { |
15123 | state.inputs.binaryStrides[i][0] |
15124 | = interface.getArgumentIfExists("stride_" + srcName); |
15125 | if (problem.batchDims > 1) |
15126 | state.inputs.binaryStrides[i][1] |
15127 | = interface.getArgumentIfExists("stride1_" + srcName); |
15128 | } |
15129 | } |
15130 | |
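    // Permute threadgroup IDs, local IDs, and local sizes from dispatch order
    // into (m, n, k) loop order.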
15131 | Subregister tgids_reordered[3]; |
15132 | GRF lids_reordered[3]; |
15133 | Subregister lszs_reordered[3]; |
15134 | |
15135 | for (int l = 0; l < 3; l++) { |
15136 | int i = static_cast<int>(strategy.loopOrder[l]); |
15137 | tgids_reordered[i] = tgids[l]; |
15138 | lids_reordered[i] = localID[l]; |
15139 | lszs_reordered[i] = localSize[l]; |
15140 | } |
15141 | state.inputs.groupIDM = tgids_reordered[0]; |
15142 | state.inputs.groupIDN = tgids_reordered[1]; |
15143 | state.inputs.groupIDK = tgids_reordered[2]; |
15144 | state.inputs.localIDM = lids_reordered[0]; |
15145 | state.inputs.localIDN = lids_reordered[1]; |
15146 | state.inputs.localIDK = lids_reordered[2]; |
15147 | state.inputs.localSizeM = lszs_reordered[0]; |
15148 | state.inputs.localSizeN = lszs_reordered[1]; |
15149 | state.inputs.localSizeK = lszs_reordered[2]; |
15150 | |
15151 | if (strategy.linearOrder()) { |
15152 | state.inputs.groupIDMN = tgids[0]; |
15153 | state.inputs.groupIDM = invalid; |
15154 | state.inputs.groupIDN = invalid; |
15155 | } |
15156 | |
15157 | // Downgrade offsets to 32 bits for non-A64 accesses. |
15158 | if (strategy.A.base.getModel() != ModelA64) |
15159 | state.inputs.offsetA = state.inputs.offsetA.d(); |
15160 | if (strategy.B.base.getModel() != ModelA64) |
15161 | state.inputs.offsetB = state.inputs.offsetB.d(); |
15162 | if (strategy.C.base.getModel() != ModelA64) |
15163 | for (int q = 0; q < state.C_count; q++) |
15164 | state.inputs.offsetC[q] = state.inputs.offsetC[q].d(); |
15165 | if (problem.usesCO() && strategy.CO.base.getModel() != ModelA64) |
15166 | state.inputs.offsetCO = state.inputs.offsetCO.d(); |
15167 | for (auto &off : state.inputs.binaryOffsets) |
15168 | off = off.d(); |
15169 | |
15170 | // For now, reinterpret m/n/k/ld/diag variables to 32-bit if they are 64-bit. |
15171 | state.inputs.m = state.inputs.m.d(); |
15172 | state.inputs.n = state.inputs.n.d(); |
15173 | state.inputs.k = state.inputs.k.d(); |
15174 | state.inputs.lda = state.inputs.lda.ud(); |
15175 | state.inputs.ldb = state.inputs.ldb.ud(); |
15176 | for (int q = 0; q < state.C_count; q++) |
15177 | state.inputs.ldc[q] = state.inputs.ldc[q].ud(); |
15178 | state.inputs.ldco = state.inputs.ldco.ud(); |
15179 | state.inputs.diagA = state.inputs.diagA.d(); |
15180 | state.inputs.diagB = state.inputs.diagB.d(); |
15181 | state.inputs.diagC = state.inputs.diagC.d(); |
15182 | |
15183 | // Claim registers. |
15184 | for (int i = 0; i < 4; i++) |
15185 | state.ra.claim(r0.uq(i)); |
15186 | |
15187 | if (strategy.A.base.isStateless()) state.ra.claim(state.inputs.A); |
15188 | if (strategy.B.base.isStateless()) state.ra.claim(state.inputs.B); |
15189 | if (strategy.C.base.isStateless()) |
15190 | for (int q = 0; q < state.C_count; q++) |
15191 | state.ra.claim(state.inputs.C[q]); |
15192 | |
15193 | if (problem.abOffset != ABOffset::None) { |
15194 | if (state.inputs.ao.isValid()) state.ra.claim(state.inputs.ao); |
15195 | if (state.inputs.bo.isValid()) state.ra.claim(state.inputs.bo); |
15196 | if (state.inputs.aoPtr.isValid()) state.ra.claim(state.inputs.aoPtr); |
15197 | if (state.inputs.boPtr.isValid()) state.ra.claim(state.inputs.boPtr); |
15198 | } |
15199 | |
15200 | if (problem.usesCO()) { |
15201 | if (strategy.CO.base.isStateless()) state.ra.claim(state.inputs.CO); |
15202 | state.ra.claim(state.inputs.offsetCO); |
15203 | } |
15204 | |
15205 | state.ra.claim(state.inputs.offsetA); |
15206 | state.ra.claim(state.inputs.offsetB); |
15207 | for (int q = 0; q < state.C_count; q++) |
15208 | state.ra.claim(state.inputs.offsetC[q]); |
15209 | state.ra.claim(state.inputs.lda); |
15210 | state.ra.claim(state.inputs.ldb); |
15211 | for (int q = 0; q < state.C_count; q++) |
15212 | state.ra.claim(state.inputs.ldc[q]); |
15213 | if (problem.allowMatrixOffset()) state.ra.claim(state.inputs.ldco); |
15214 | state.ra.claim(state.inputs.m); |
15215 | state.ra.claim(state.inputs.n); |
15216 | state.ra.claim(state.inputs.k); |
15217 | if (strategy.kParallel || strategy.kParallelLocal) |
15218 | state.ra.claim(state.inputs.k0); |
15219 | |
15220 | if (!problem.alpha_real.fixed()) state.ra.claim(state.inputs.alpha_real); |
15221 | if (!problem.beta_real.fixed()) state.ra.claim(state.inputs.beta_real); |
15222 | |
15223 | if (!inSK) { |
15224 | state.ra.claim(state.inputs.localIDM); |
15225 | state.ra.claim(state.inputs.localIDN); |
15226 | if (!strategy.fixedWG(problem)) { |
15227 | state.ra.claim(state.inputs.localSizeM); |
15228 | state.ra.claim(state.inputs.localSizeN); |
15229 | } else |
15230 | state.inputs.localSizeM = state.inputs.localSizeN = invalid; |
15231 | if (strategy.kParallel || strategy.kParallelLocal) { |
15232 | state.ra.claim(state.inputs.localIDK); |
15233 | state.ra.claim(state.inputs.localSizeK); |
15234 | } |
15235 | } |
15236 | |
15237 | if (state.inputs.flags.isValid()) state.ra.claim(state.inputs.flags); |
15238 | |
15239 | if (problem.batch == BatchMode::Strided) { |
15240 | for (int i = 0; i < problem.batchDims; i++) { |
15241 | state.ra.claim(state.inputs.strideA[i]); |
15242 | state.ra.claim(state.inputs.strideB[i]); |
15243 | state.ra.claim(state.inputs.strideC[i]); |
15244 | } |
15245 | if (problem.batchDims > 1) { |
15246 | state.ra.claim(state.inputs.batchSize1); |
15247 | state.ra.claim(state.inputs.recipBatchSize1); |
15248 | } |
15249 | state.ra.claim(state.inputs.groupIDK); |
15250 | } else if (problem.batch == BatchMode::Nonstrided) { |
15251 | state.ra.claim(state.inputs.offsetBatch); |
15252 | state.ra.claim(state.inputs.groupIDK); |
15253 | } else if (problem.batch == BatchMode::Variable) { |
15254 | state.ra.claim(state.inputs.incr_a_array); |
15255 | state.ra.claim(state.inputs.incr_b_array); |
15256 | state.ra.claim(state.inputs.alpha_array); |
15257 | state.ra.claim(state.inputs.beta_array); |
15258 | state.ra.claim(state.inputs.incr_alpha); |
15259 | state.ra.claim(state.inputs.incr_beta); |
15260 | state.ra.claim(state.inputs.groupIDK); |
15261 | } |
15262 | |
15263 | if (strategy.linearOrder()) { |
15264 | state.ra.claim(state.inputs.groupCountM); |
15265 | state.ra.claim(state.inputs.groupCountN); |
15266 | } |
15267 | |
    if (strategy.hilbertOrder) {
        state.ra.claim(state.inputs.hilbertVD);
        state.ra.claim(state.inputs.hilbertUVDRecip);
        state.ra.claim(state.inputs.hilbertBail);
    }
15275 | |
15276 | if (strategy.boustrophedon) { |
15277 | state.ra.claim(state.inputs.bslice); |
15278 | state.ra.claim(state.inputs.bthresh); |
15279 | } |
15280 | |
15281 | if (strategy.persistent) { |
15282 | state.ra.claim(state.inputs.groupStride); |
15283 | if (state.inputs.groupCountMN.isValid()) |
15284 | state.ra.claim(state.inputs.groupCountMN); |
15285 | } |
15286 | |
15287 | // Binary-related arguments are not claimed here, but instead |
15288 | // are reloaded later in the kernel when needed. |
15289 | } |
15290 | |
15291 | // Return amount of SLM needed by a GEMM kernel. |
15292 | template <HW hw> |
15293 | size_t gemm_kernel_generator_t<hw>::gemmSLMSize( |
15294 | const GEMMProblem &problem, const GEMMStrategy &strategy) { |
15295 | // Space needed by SLM copies. |
15296 | size_t slmSize |
15297 | = strategy.slmABufSize(problem) + strategy.slmBBufSize(problem); |
15298 | |
15299 | // Space needed for row/column sum reduction/sharing. |
15300 | if ((problem.needsASums() && strategy.slmA) |
15301 | || (problem.needsBSums() && strategy.slmB)) { |
15302 | slmSize = std::max<size_t>(slmSize, |
15303 | (strategy.unroll[LoopM] * strategy.wg[LoopM] |
15304 | + strategy.unroll[LoopN] * strategy.wg[LoopN]) |
15305 | * problem.Tc); |
15306 | } |
15307 | |
15308 | return slmSize; |
15309 | } |
15310 | |
15311 | // Return amount of per-k SLM needed by a GEMM kernel. |
15312 | template <HW hw> |
15313 | size_t gemm_kernel_generator_t<hw>::gemmPerKSLMSize( |
15314 | const GEMMProblem &problem, const GEMMStrategy &strategy) { |
15315 | size_t slmSize = 0; |
15316 | |
15317 | // Space needed for local k reduction (as much as possible). |
15318 | if (strategy.kParallelLocal) { |
15319 | // Calculate max SLM usage that doesn't reduce thread count. |
15320 | int mnThreads = strategy.wg[LoopM] * strategy.wg[LoopN]; |
15321 | if (mnThreads <= 0) stub(); |
15322 | int concurrentK = std::max( |
15323 | 1, threadsPerEU(hw, strategy) * eusPerSubslice(hw) / mnThreads); |
15324 | slmSize = rounddown_pow2(slmCapacity(hw) / concurrentK); |
15325 | } |
15326 | |
15327 | return slmSize; |
15328 | } |
15329 | |
15330 | // Initialize the state structure. |
15331 | template <HW hw> |
15332 | void gemm_kernel_generator_t<hw>::gemmInitState(GEMMProblem &problem, |
15333 | GEMMStrategy &strategy, GEMMState &state, bool inSK) { |
15334 | auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc; |
15335 | |
15336 | if (!state.fusedGEMM.active) { |
15337 | initState(problem, strategy, state); |
15338 | gemmInitInterface(problem, strategy, state, inSK); |
15339 | state.isNested |= strategy.fused; |
15340 | state.isNested |= strategy.persistent; |
15341 | } |
15342 | |
15343 | state.effA = strategy.A.base.isStateless() ? state.inputs.A |
15344 | : state.inputs.offsetA.d(); |
15345 | state.effB = strategy.B.base.isStateless() ? state.inputs.B |
15346 | : state.inputs.offsetB.d(); |
15347 | for (int q = 0; q < state.C_count; q++) { |
15348 | state.effC[q] = strategy.C.base.isStateless() |
15349 | ? state.inputs.C[q] |
15350 | : state.inputs.offsetC[q].d(); |
15351 | } |
15352 | if (problem.usesCO()) { |
15353 | state.effCO = strategy.CO.base.isStateless() |
15354 | ? state.inputs.CO |
15355 | : state.inputs.offsetCO.d(); |
15356 | } |
15357 | |
15358 | if (!problem.alpha_real.fixed()) |
15359 | problem.alpha_real = state.inputs.alpha_real; |
15360 | if (!problem.beta_real.fixed()) problem.beta_real = state.inputs.beta_real; |
15361 | |
15362 | state.offsetA = state.inputs.offsetA; |
15363 | state.offsetB = state.inputs.offsetB; |
15364 | for (int q = 0; q < state.C_count; q++) |
15365 | state.offsetC[q] = state.inputs.offsetC[q]; |
15366 | state.offsetCO = state.inputs.offsetCO; |
15367 | |
15368 | state.flagAP = state.raVFlag.alloc(); |
15369 | |
15370 | state.allocEmulate64Temp(strategy.emulate); |
15371 | |
15372 | state.Ta_load = problem.Ta_ext; |
15373 | state.Tb_load = problem.Tb_ext; |
15374 | |
15375 | state.Tacc = problem.Tc; |
15376 | state.copyC = (problem.Tc != problem.Tc_ext) |
15377 | || (!strategy.altCRemainder && (Tc.size() < 4)) |
15378 | || strategy.forceCopyC; |
15379 | |
15380 | state.broadcast = strategy.doubleWA; |
15381 | |
15382 | bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C); |
15383 | state.broadcast |= (Tc == Type::f32 && (cColMajor ? Tb : Ta) == Type::bf16); |
15384 | |
15385 | state.Cext_strategy = strategy.C; |
15386 | state.Cext_strategy.tileR = state.Cext_strategy.tileC = 0; |
15387 | |
15388 | state.lidM = state.inputs.localIDM[0]; |
15389 | state.lidN = state.inputs.localIDN[0]; |
15390 | if (strategy.kParallel || strategy.kParallelLocal) |
15391 | state.lidK = state.inputs.localIDK[0]; |
15392 | |
15393 | state.diagC = state.inputs.diagC; |
15394 | state.k = state.inputs.k; |
15395 | |
15396 | state.lda = state.inputs.lda; |
15397 | state.ldb = state.inputs.ldb; |
15398 | } |
15399 | |
15400 | // Offset A pointer in k dimension by a constant value. |
15401 | template <HW hw> |
15402 | void gemm_kernel_generator_t<hw>::gemmOffsetAk(int h, const Subregister &effA, |
15403 | const MatrixAddressing &globalA, const GEMMProblem &problem, |
15404 | const GEMMStrategy &strategy, GEMMState &state) { |
15405 | auto Ta_ext = problem.Ta_ext; |
15406 | if (h) switch (globalA.layout) { |
15407 | case MatrixLayout::N: |
15408 | emad(1, effA, effA, state.inputs.lda, Immediate::w(h), strategy, |
15409 | state); |
15410 | break; |
15411 | case MatrixLayout::T: |
15412 | eadd(1, effA, effA, h * Ta_ext, strategy, state); |
15413 | break; |
15414 | case MatrixLayout::Pc: |
15415 | eadd(1, effA, effA, h * globalA.packSize * Ta_ext, strategy, |
15416 | state); |
15417 | break; |
15418 | default: stub(); |
15419 | } |
15420 | } |
15421 | |
15422 | // Offset A pointer in k dimension by a variable value. |
15423 | template <HW hw> |
15424 | void gemm_kernel_generator_t<hw>::gemmOffsetAk(const Subregister &h, |
15425 | const Subregister &effA, const MatrixAddressing &globalA, |
15426 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
15427 | GEMMState &state) { |
15428 | auto Ta_ext = problem.Ta_ext; |
15429 | switch (globalA.layout) { |
15430 | case MatrixLayout::N: |
15431 | emad(1, effA, effA, state.inputs.lda, h, strategy, state); |
15432 | break; |
15433 | case MatrixLayout::T: |
15434 | emad(1, effA, effA, h, Ta_ext.size(), strategy, state); |
15435 | break; |
15436 | case MatrixLayout::Pc: |
15437 | emad(1, effA, effA, h, globalA.packSize * Ta_ext, strategy, state); |
15438 | break; |
15439 | default: stub(); |
15440 | } |
15441 | } |
15442 | |
15443 | // Offset B pointer in k dimension by a constant value. |
15444 | template <HW hw> |
15445 | void gemm_kernel_generator_t<hw>::gemmOffsetBk(int h, const Subregister &effB, |
15446 | const MatrixAddressing &globalB, const GEMMProblem &problem, |
15447 | const GEMMStrategy &strategy, GEMMState &state) { |
15448 | auto Tb_ext = problem.Tb_ext; |
15449 | if (h) switch (globalB.layout) { |
15450 | case MatrixLayout::N: |
15451 | eadd(1, effB, effB, h * Tb_ext, strategy, state); |
15452 | break; |
15453 | case MatrixLayout::Pr: |
15454 | eadd(1, effB, effB, h * globalB.packSize * Tb_ext, strategy, |
15455 | state); |
15456 | break; |
15457 | case MatrixLayout::T: |
15458 | emad(1, effB, effB, state.inputs.ldb, Immediate::w(h), strategy, |
15459 | state); |
15460 | break; |
15461 | default: stub(); |
15462 | } |
15463 | } |
15464 | |
15465 | // Offset B pointer in k dimension by a variable value. |
15466 | template <HW hw> |
15467 | void gemm_kernel_generator_t<hw>::gemmOffsetBk(const Subregister &h, |
15468 | const Subregister &effB, const MatrixAddressing &globalB, |
15469 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
15470 | GEMMState &state) { |
15471 | auto Tb_ext = problem.Tb_ext; |
15472 | switch (globalB.layout) { |
15473 | case MatrixLayout::T: |
15474 | emad(1, effB, effB, state.inputs.ldb, h, strategy, state); |
15475 | break; |
15476 | case MatrixLayout::N: |
15477 | emad(1, effB, effB, h, Tb_ext.size(), strategy, state); |
15478 | break; |
15479 | case MatrixLayout::Pr: |
15480 | emad(1, effB, effB, h, globalB.packSize * Tb_ext, strategy, state); |
15481 | break; |
15482 | default: stub(); |
15483 | } |
15484 | } |
15485 | |
15486 | // Adjust A, B, C to start at (i0, j0). |
15487 | // initial is true to adjust offset_{A,B,C}, false to adjust A,B,C pointers. |
15488 | template <HW hw> |
15489 | void gemm_kernel_generator_t<hw>::gemmOffsetABC(bool initial, Subregister i0, |
15490 | Subregister j0, Subregister h0, const GEMMProblem &problem, |
15491 | const GEMMStrategy &strategy, GEMMState &state, bool doA, bool doB, |
15492 | bool doC, bool doBinary) { |
15493 | auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext, |
15494 | Tc_ext = problem.Tc_ext, Tco = problem.Tco; |
15495 | auto &offsetA = initial ? state.offsetA : state.effA; |
15496 | auto &offsetB = initial ? state.offsetB : state.effB; |
15497 | auto &offsetC0 = initial ? state.offsetC[0] : state.effC[0]; |
15498 | auto &offsetAp = initial ? state.offsetAp : state.effAp; |
15499 | auto &offsetBp = initial ? state.offsetBp : state.effBp; |
15500 | auto &offsetCp = initial ? state.offsetCp : state.effCp; |
15501 | auto offsetCO = initial ? state.offsetCO : state.effCO; |
15502 | bool doCO = doC && (problem.cOffset != COffset::None); |
15503 | |
15504 | Subregister tempQ0 = state.ra.alloc_sub<int64_t>( |
15505 | getHint(HintType::TempComp0, strategy)); |
15506 | Subregister tempQ1 = state.ra.alloc_sub<int64_t>( |
15507 | getHint(HintType::TempComp1, strategy)); |
15508 | |
15509 | bool a2D = strategy.A.address2D; |
15510 | bool b2D = strategy.B.address2D; |
15511 | bool c2D = strategy.C.address2D; |
15512 | bool ap2D = strategy.prefetchA ? strategy.A_prefetch.address2D : a2D; |
15513 | bool bp2D = strategy.prefetchB ? strategy.B_prefetch.address2D : b2D; |
15514 | bool cp2D = strategy.prefetchC ? strategy.C_prefetch.address2D : c2D; |
15515 | |
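    // Block 2D accesses take their position from 2D address parameters, so no
    // pointer adjustment is needed when both the access and its prefetch are 2D.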
15516 | if (a2D && ap2D) doA = false; |
15517 | if (b2D && bp2D) doB = false; |
15518 | if (c2D && cp2D) doC = false; |
15519 | |
15520 | if (doA && (a2D != ap2D)) { |
15521 | if (!initial) stub(); |
15522 | if (offsetAp.isInvalid()) { |
15523 | offsetAp = state.ra.alloc_sub( |
15524 | offsetA.getType(), getHint(HintType::LongTerm, strategy)); |
            emov(1, offsetAp, offsetA, strategy, state);
15526 | } else if (a2D && !ap2D) |
15527 | std::swap(offsetA, offsetAp); |
15528 | } |
15529 | if (doB && (b2D != bp2D)) { |
15530 | if (!initial) stub(); |
15531 | if (offsetBp.isInvalid()) { |
15532 | offsetBp = state.ra.alloc_sub( |
15533 | offsetB.getType(), getHint(HintType::LongTerm, strategy)); |
15534 | emov(1, offsetBp, offsetB, strategy, state); |
15535 | } else if (b2D && !bp2D) |
15536 | std::swap(offsetB, offsetBp); |
15537 | } |
15538 | if (doC && (c2D != cp2D)) { |
15539 | if (!initial) stub(); |
15540 | if (offsetCp.isInvalid()) { |
15541 | offsetCp = state.ra.alloc_sub( |
15542 | offsetC0.getType(), getHint(HintType::LongTerm, strategy)); |
15543 | emov(1, offsetCp, offsetC0, strategy, state); |
15544 | } else if (c2D && !cp2D) |
15545 | std::swap(offsetC0, offsetCp); |
15546 | } |
15547 | |
15548 | // To do: interleave code. |
    //  A  += i0            (N)      i0 * lda      (T, Pc)
    //  B  += j0 * ldb      (N, Pr)  j0            (T)
    //  C  += i0 + j0 * ldc (N, Pr)  j0 + i0 * ldc (T, Pc)
    //  CO += i0 (row offsets)       j0 (col offsets)
15553 | if (doA && i0.isValid()) { |
15554 | if (problem.A.layout == MatrixLayout::Nontranspose) |
15555 | emad(1, offsetA, offsetA, i0, Ta_ext.size(), strategy, state); |
15556 | else { |
15557 | emul(1, tempQ1, i0, state.inputs.lda, strategy, state); |
15558 | eadd(1, offsetA, offsetA, tempQ1.reinterpret(0, offsetA.getType()), |
15559 | strategy, state); |
15560 | } |
15561 | } |
15562 | |
15563 | if (doB && j0.isValid()) { |
15564 | if (problem.B.layout == MatrixLayout::Transpose) |
15565 | emad(1, offsetB, offsetB, j0, Tb_ext.size(), strategy, state); |
15566 | else { |
15567 | emul(1, tempQ0, j0, state.inputs.ldb, strategy, state); |
15568 | eadd(1, offsetB, offsetB, tempQ0.reinterpret(0, offsetB.getType()), |
15569 | strategy, state); |
15570 | } |
15571 | } |
15572 | |
15573 | FlagRegister flagCOR, flagCOC; |
15574 | if (doCO) { |
15575 | flagCOR = state.raVFlag.alloc(); |
15576 | flagCOC = state.raVFlag.alloc(); |
15577 | and_(1 | nz | flagCOC, null.ud(), state.inputs.flags, FlagCOColumn); |
15578 | and_(1 | nz | flagCOR, null.ud(), state.inputs.flags, FlagCORow); |
15579 | } |
15580 | if (doC) { |
15581 | for (int q = 0; q < state.C_count; q++) { |
15582 | auto offsetC = initial ? state.offsetC[q] : state.effC[q]; |
15583 | |
15584 | Subregister x, y; |
15585 | int xstride = Tc_ext.size(); |
15586 | switch (problem.C.layout) { |
15587 | case MatrixLayout::Pr: |
15588 | xstride *= strategy.unroll[LoopN]; /* fall through */ |
15589 | case MatrixLayout::N: |
15590 | x = i0; |
15591 | y = j0; |
15592 | break; |
15593 | case MatrixLayout::Pc: |
15594 | xstride *= strategy.unroll[LoopM]; /* fall through */ |
15595 | case MatrixLayout::T: |
15596 | x = j0; |
15597 | y = i0; |
15598 | break; |
15599 | } |
15600 | emad(1, offsetC, offsetC, x, xstride, strategy, state); |
15601 | emul(1, tempQ0, y, state.inputs.ldc[q], strategy, state); |
15602 | eadd(1, offsetC, offsetC, tempQ0.reinterpret(0, offsetC.getType()), |
15603 | strategy, state); // Gen12: Use add3. |
15604 | } |
15605 | } |
15606 | if (doCO) { |
15607 | Label lNoMatrixOffset, lDone; |
15608 | if (problem.allowMatrixOffset()) { |
15609 | auto x0 = isColMajor(problem.CO.layout) ? i0 : j0; |
15610 | auto y0 = isColMajor(problem.CO.layout) ? j0 : i0; |
15611 | jmpi(1 | ~flagCOC, lNoMatrixOffset); |
15612 | jmpi(1 | ~flagCOR, lNoMatrixOffset); |
15613 | emad(1, offsetCO, offsetCO, x0, Tco.size(), strategy, state); |
15614 | emul(1, tempQ0, y0, state.inputs.ldco, strategy, state); |
15615 | eadd(1, offsetCO, offsetCO, |
15616 | tempQ0.reinterpret(0, offsetCO.getType()), strategy, state); |
15617 | jmpi(1, lDone); |
15618 | mark(lNoMatrixOffset); |
15619 | } |
15620 | emad(1 | flagCOC, offsetCO, offsetCO, j0, Tco.size(), strategy, state); |
15621 | emad(1 | flagCOR, offsetCO, offsetCO, i0, Tco.size(), strategy, state); |
15622 | state.raVFlag.safeRelease(flagCOR); |
15623 | state.raVFlag.safeRelease(flagCOC); |
15624 | if (problem.allowMatrixOffset()) mark(lDone); |
15625 | } |
15626 | if (doBinary) |
15627 | for (int i = 0; i < problem.postOps.len(); i++) { |
15628 | if (!problem.postOps.entry_[i].is_binary()) continue; |
15629 | bool row = problem.binaryRow[i], col = problem.binaryCol[i]; |
15630 | auto T = problem.Tbinary[i]; |
15631 | auto &ld = state.inputs.binaryLDs[i]; |
15632 | auto offset = initial ? state.inputs.binaryOffsets[i] |
15633 | : state.effBinary[i]; |
15634 | if (row && col) { |
15635 | auto x0 = isColMajor(problem.binary[i].layout) ? i0 : j0; |
15636 | auto y0 = isColMajor(problem.binary[i].layout) ? j0 : i0; |
15637 | emad(1, offset, offset, x0, T.size(), strategy, state); |
15638 | emul(1, tempQ0, y0, ld, strategy, state); |
15639 | eadd(1, offset, offset, tempQ0.reinterpret(0, offset.getType()), |
15640 | strategy, state); |
15641 | } else if (row) |
15642 | emad(1, offset, offset, i0, T.size(), strategy, state); |
15643 | else if (col) |
15644 | emad(1, offset, offset, j0, T.size(), strategy, state); |
15645 | } |
15646 | if (doC && problem.sumA) |
15647 | emad(1, offsetCO, offsetCO, i0, Tco.size(), strategy, state); |
15648 | if (doC && problem.sumB) |
15649 | emad(1, offsetCO, offsetCO, j0, Tco.size(), strategy, state); |
15650 | |
    // When k blocking is enabled (or for certain triangular source kernels):
    //   A += h0 * lda (N)   h0 (T)        h0 * mb (Pc)
    //   B += h0 (N)         h0 * ldb (T)  h0 * nb (Pr)
15654 | if (!h0.isInvalid()) { |
15655 | if (doA) switch (problem.A.layout) { |
15656 | case MatrixLayout::Nontranspose: |
15657 | emul(1, tempQ1, h0, state.inputs.lda, strategy, state); |
15658 | eadd(1, offsetA, offsetA, |
15659 | tempQ1.reinterpret(0, offsetA.getType()), strategy, |
15660 | state); |
15661 | break; |
15662 | case MatrixLayout::Transpose: |
15663 | emad(1, offsetA, offsetA, h0, Ta_ext.size(), strategy, |
15664 | state); |
15665 | break; |
15666 | case MatrixLayout::PackedColumns: |
15667 | emad(1, offsetA, offsetA, h0, |
15668 | strategy.unroll[LoopM] * Ta_ext, strategy, state); |
15669 | break; |
15670 | default: stub(); |
15671 | } |
15672 | if (doB) switch (problem.B.layout) { |
15673 | case MatrixLayout::Nontranspose: |
15674 | emad(1, offsetB, offsetB, h0, Tb_ext.size(), strategy, |
15675 | state); |
15676 | break; |
15677 | case MatrixLayout::Transpose: |
15678 | emul(1, tempQ0, h0, state.inputs.ldb, strategy, state); |
15679 | eadd(1, offsetB, offsetB, |
15680 | tempQ0.reinterpret(0, offsetB.getType()), strategy, |
15681 | state); |
15682 | break; |
15683 | case MatrixLayout::PackedRows: |
15684 | emad(1, offsetB, offsetB, h0, |
15685 | strategy.unroll[LoopN] * Tb_ext, strategy, state); |
15686 | break; |
15687 | default: stub(); |
15688 | } |
15689 | } |
15690 | |
15691 | state.ra.safeRelease(tempQ0); |
15692 | state.ra.safeRelease(tempQ1); |
15693 | |
15694 | if (doA && a2D && !ap2D) std::swap(offsetA, offsetAp); |
15695 | if (doB && b2D && !bp2D) std::swap(offsetB, offsetBp); |
15696 | if (doC && c2D && !cp2D) std::swap(offsetC0, offsetCp); |
15697 | } |
15698 | |
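// Adjust A/B/C pointers (and alpha/beta for variable batch) to the batch
// selected by the k-dimension group ID.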
15699 | template <ngen::HW hw> |
15700 | void gemm_kernel_generator_t<hw>::gemmOffsetBatchABC(const GEMMProblem &problem, |
15701 | const GEMMStrategy &strategy, GEMMState &state) { |
15702 | auto Ts = problem.Ts; |
15703 | |
15704 | // Non-strided batch support. |
15705 | if (problem.batch == BatchMode::Nonstrided) { |
15706 | auto temp = state.ra.alloc().uq(); |
15707 | auto boffset = state.ra.alloc_sub<uint32_t>(); |
15708 | |
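        // A, B, and C come from separate pointer arrays indexed by batch:
        // compute the qword offset (index * 8 bytes) and enable three channels
        // (0b111) for the scattered load of all three pointers at once.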
15709 | add(1, boffset, state.inputs.offsetBatch, state.inputs.groupIDK); |
15710 | mov(1, state.flagAP, 0x7); |
15711 | shl(1, boffset, boffset, uint16_t(3)); |
15712 | |
15713 | eadd(1, temp[0], state.inputs.A, boffset, strategy, state); |
15714 | eadd(1, temp[1], state.inputs.B, boffset, strategy, state); |
15715 | eadd(1, temp[2], state.inputs.C[0], boffset, strategy, state); |
15716 | |
15717 | load(4 | state.flagAP, temp, scattered_qword(1), strategy.A.base, temp); |
15718 | |
15719 | emov(1, state.effA, temp[0], strategy, state); |
15720 | emov(1, state.effB, temp[1], strategy, state); |
15721 | emov(1, state.effC[0], temp[2], strategy, state); |
15722 | |
15723 | state.ra.safeRelease(temp); |
15724 | state.ra.safeRelease(boffset); |
15725 | if (!strategy.persistent) |
15726 | state.ra.safeRelease(state.inputs.offsetBatch); |
15727 | } |
15728 | |
15729 | // Strided batch support. |
15730 | if (problem.batch == BatchMode::Strided) { |
15731 | for (int b = 0; b < problem.batchDims; b++) { |
15732 | emul(1, state.inputs.strideA[b], state.inputs.strideA[b], |
15733 | state.batchID[b], strategy, state); |
15734 | emul(1, state.inputs.strideB[b], state.inputs.strideB[b], |
15735 | state.batchID[b], strategy, state); |
15736 | emul(1, state.inputs.strideC[b], state.inputs.strideC[b], |
15737 | state.batchID[b], strategy, state); |
15738 | } |
15739 | |
15740 | for (int b = 0; b < problem.batchDims; b++) { |
15741 | eadd(1, state.offsetA, state.offsetA, state.inputs.strideA[b], |
15742 | strategy, state); |
15743 | eadd(1, state.offsetB, state.offsetB, state.inputs.strideB[b], |
15744 | strategy, state); |
15745 | for (int q = 0; q < state.C_count; q++) { |
15746 | auto offsetC = state.offsetC[q]; |
15747 | eadd(1, offsetC, offsetC, state.inputs.strideC[b], strategy, |
15748 | state); |
15749 | } |
15750 | if (!strategy.persistent) { |
15751 | state.ra.safeRelease(state.inputs.strideA[b]); |
15752 | state.ra.safeRelease(state.inputs.strideB[b]); |
15753 | state.ra.safeRelease(state.inputs.strideC[b]); |
15754 | } |
15755 | } |
15756 | } |
15757 | |
15758 | // Non-strided variable batch support. |
15759 | if (problem.batch == BatchMode::Variable) { |
15760 | auto tempA = state.ra.alloc().uq(); |
15761 | auto tempB = state.ra.alloc().uq(); |
15762 | auto tempC = state.ra.alloc().uq(); |
15763 | auto tempIDK = state.ra.alloc().ud(); |
15764 | auto offset_scalar = state.ra.alloc().uq(); |
15765 | auto offset_pointer = state.ra.alloc().uq(); |
15766 | auto tempAlphaReal = state.ra.alloc().uq(); |
15767 | auto tempAlphaImag = state.ra.alloc().uq(); |
15768 | auto tempBetaReal = state.ra.alloc().uq(); |
15769 | auto tempBetaImag = state.ra.alloc().uq(); |
15770 | |
15771 | eshl(1, tempIDK, state.inputs.groupIDK, uint16_t(log2(Ts.size())), |
15772 | strategy, state); |
15773 | |
15774 | // load and set alpha |
15775 | emul(1, offset_scalar, tempIDK, state.inputs.incr_alpha.uw(), strategy, |
15776 | state); |
15777 | eadd(1, tempAlphaReal[0], state.inputs.alpha_array, offset_scalar, |
15778 | strategy, state); |
15779 | if (Ts.isComplex()) { |
15780 | eadd(1, tempAlphaImag[0], tempAlphaReal[0], Ts.real().size(), |
15781 | strategy, state); |
15782 | } |
15783 | if (Ts.real().size() == 4) { |
15784 | load(1, tempAlphaReal, scattered_dword(1), A64, tempAlphaReal); |
15785 | if (Ts.isComplex()) { |
15786 | load(1, tempAlphaImag, scattered_dword(1), A64, tempAlphaImag); |
15787 | } |
15788 | } else if (Ts.real().size() == 2) { |
15789 | load(1, tempAlphaReal, scattered_byte(2), A64, tempAlphaReal); |
15790 | if (Ts.isComplex()) { |
15791 | load(1, tempAlphaImag, scattered_byte(2), A64, tempAlphaImag); |
15792 | } |
15793 | } else { |
15794 | load(1, tempAlphaReal, scattered_qword(1), A64, tempAlphaReal); |
15795 | if (Ts.isComplex()) { |
15796 | load(1, tempAlphaImag, scattered_qword(1), A64, tempAlphaImag); |
15797 | } |
15798 | } |
15799 | mov(1, state.inputs.alpha_real, tempAlphaReal.sub(0, Ts.real().ngen())); |
15800 | if (Ts.isComplex()) |
15801 | mov(1, state.inputs.alpha_imag, |
15802 | tempAlphaImag.sub(0, Ts.real().ngen())); |
15803 | // end load and set alpha |
15804 | |
15805 | // load and set beta |
15806 | emul(1, offset_scalar, tempIDK, state.inputs.incr_beta.uw(), strategy, |
15807 | state); |
15808 | eadd(1, tempBetaReal[0], state.inputs.beta_array, offset_scalar, |
15809 | strategy, state); |
15810 | if (Ts.isComplex()) { |
15811 | eadd(1, tempBetaImag[0], tempBetaReal[0], Ts.real().size(), |
15812 | strategy, state); |
15813 | } |
15814 | if (Ts.real().size() == 4) { |
15815 | load(1, tempBetaReal, scattered_dword(1), A64, tempBetaReal); |
15816 | if (Ts.isComplex()) { |
15817 | load(1, tempBetaImag, scattered_dword(1), A64, tempBetaImag); |
15818 | } |
15819 | } else if (Ts.real().size() == 2) { |
15820 | load(1, tempBetaReal, scattered_byte(2), A64, tempBetaReal); |
15821 | if (Ts.isComplex()) { |
15822 | load(1, tempBetaImag, scattered_byte(2), A64, tempBetaImag); |
15823 | } |
15824 | } else { |
15825 | load(1, tempBetaReal, scattered_qword(1), A64, tempBetaReal); |
15826 | if (Ts.isComplex()) { |
15827 | load(1, tempBetaImag, scattered_qword(1), A64, tempBetaImag); |
15828 | } |
15829 | } |
15830 | mov(1, state.inputs.beta_real, tempBetaReal.sub(0, Ts.real().ngen())); |
15831 | if (Ts.isComplex()) |
15832 | mov(1, state.inputs.beta_imag, |
15833 | tempBetaImag.sub(0, Ts.real().ngen())); |
15834 | // end load and set beta |
15835 | |
15836 | eshl(1, tempIDK, state.inputs.groupIDK, uint16_t(3), strategy, state); |
15837 | emul(1, offset_pointer, tempIDK, state.inputs.incr_a_array.uw(), |
15838 | strategy, state); |
15839 | eadd(1, tempA[0], state.inputs.A, offset_pointer, strategy, state); |
15840 | load(1, tempA, scattered_qword(1), strategy.A.base, tempA); |
15841 | |
15842 | emul(1, offset_pointer, tempIDK, state.inputs.incr_b_array.uw(), |
15843 | strategy, state); |
15844 | eadd(1, tempB[0], state.inputs.B, offset_pointer, strategy, state); |
15845 | load(1, tempB, scattered_qword(1), strategy.B.base, tempB); |
15846 | |
15847 | eadd(1, tempC[0], state.inputs.C[0], tempIDK, strategy, state); |
15848 | load(1, tempC, scattered_qword(1), strategy.C.base, tempC); |
15849 | |
15850 | emov(1, state.effA, tempA, strategy, state); |
15851 | emov(1, state.effB, tempB, strategy, state); |
15852 | emov(1, state.effC[0], tempC, strategy, state); |
15853 | |
15854 | state.ra.safeRelease(tempA); |
15855 | state.ra.safeRelease(tempB); |
15856 | state.ra.safeRelease(tempC); |
15857 | state.ra.safeRelease(tempIDK); |
15858 | state.ra.safeRelease(offset_scalar); |
15859 | state.ra.safeRelease(offset_pointer); |
15860 | state.ra.safeRelease(tempAlphaReal); |
15861 | state.ra.safeRelease(tempAlphaImag); |
15862 | state.ra.safeRelease(tempBetaReal); |
15863 | state.ra.safeRelease(tempBetaImag); |
15864 | if (!strategy.persistent) { |
15865 | state.ra.safeRelease(state.inputs.incr_a_array); |
15866 | state.ra.safeRelease(state.inputs.incr_b_array); |
15867 | state.ra.safeRelease(state.inputs.incr_alpha); |
15868 | state.ra.safeRelease(state.inputs.incr_beta); |
15869 | state.ra.safeRelease(state.inputs.alpha_array); |
15870 | state.ra.safeRelease(state.inputs.beta_array); |
15871 | } |
15872 | } |
15873 | } |
15874 | |
// Prepare for persistent GEMM by folding offsets into A/B/C pointers (if stateless),
// or saving offsets (if stateful).
15877 | template <HW hw> |
15878 | void gemm_kernel_generator_t<hw>::gemmFoldOffsets(const GEMMProblem &problem, |
15879 | const GEMMStrategy &strategy, GEMMState &state) { |
15880 | auto foldOrSave |
15881 | = [&](const MatrixAddressingStrategy &sX, Subregister &inputX, |
15882 | Subregister &offsetX, const Subregister &inputOffsetX, |
15883 | Subregister &saveOffsetX, bool newInput = false) { |
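                // Stateless access: fold the byte offset into the base pointer
                // itself and zero the offset register (widening it to 64 bits
                // first if needed). Stateful access: leave the surface alone
                // and keep a long-term copy of the raw input offset instead.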
15884 | if (sX.base.isStateless()) { |
15885 | auto oldInputX = inputX; |
15886 | if (newInput) |
15887 | inputX = state.ra.alloc_sub(DataType::uq, |
15888 | getHint(HintType::LongTerm, strategy)); |
15889 | eadd(1, inputX, oldInputX, offsetX, strategy, state); |
15890 | if (getBytes(offsetX.getType()) < 8) { |
15891 | state.ra.safeRelease(offsetX); |
15892 | offsetX = state.ra.alloc_sub(DataType::uq, |
15893 | getHint(HintType::LongTerm, strategy)); |
15894 | } |
15895 | emov(1, offsetX, 0, strategy, state); |
15896 | } else { |
15897 | offsetX = state.ra.alloc_sub(offsetX.getType(), |
15898 | getHint(HintType::LongTerm, strategy)); |
15899 | mov(1, offsetX, inputOffsetX); |
15900 | } |
15901 | saveOffsetX = offsetX; |
15902 | }; |
15903 | |
15904 | bool deduplicateAB = (state.inputs.A == state.inputs.B); |
15905 | |
15906 | foldOrSave(strategy.A, state.inputs.A, state.offsetA, state.inputs.offsetA, |
15907 | state.saveOffsetA, deduplicateAB); |
15908 | foldOrSave(strategy.B, state.inputs.B, state.offsetB, state.inputs.offsetB, |
15909 | state.saveOffsetB); |
15910 | for (int q = 0; q < state.C_count; q++) |
15911 | foldOrSave(strategy.C, state.inputs.C[q], state.offsetC[q], |
15912 | state.inputs.offsetC[q], |
                state.saveOffsetC[q]); // TODO: initialize for HPL.
15914 | if (problem.usesCO()) |
15915 | foldOrSave(strategy.CO, state.inputs.CO, state.offsetCO, |
15916 | state.inputs.offsetCO, state.saveOffsetCO); |
15917 | |
15918 | if (deduplicateAB) state.effA = state.inputs.A; |
15919 | } |
15920 | |
15921 | // Restore input offsets from saved copies, for persistent GEMM. |
15922 | template <HW hw> |
15923 | void gemm_kernel_generator_t<hw>::gemmRestoreOffsets(const GEMMProblem &problem, |
15924 | const GEMMStrategy &strategy, GEMMState &state) { |
15925 | auto zeroOrRestore = [&](const MatrixAddressingStrategy &sX, |
15926 | const Subregister &offsetX, |
15927 | const Subregister &inputOffsetX) { |
15928 | if (sX.base.isStateless()) |
15929 | emov(1, offsetX, 0, strategy, state); |
15930 | else |
15931 | mov(1, offsetX, inputOffsetX); |
15932 | }; |
15933 | |
15934 | zeroOrRestore(strategy.A, state.saveOffsetA, state.inputs.offsetA); |
15935 | zeroOrRestore(strategy.B, state.saveOffsetB, state.inputs.offsetB); |
15936 | for (int q = 0; q < state.C_count; q++) |
15937 | zeroOrRestore( |
15938 | strategy.C, state.saveOffsetC[q], state.inputs.offsetC[q]); |
15939 | if (problem.usesCO()) |
15940 | zeroOrRestore(strategy.CO, state.saveOffsetCO, state.inputs.offsetCO); |
15941 | } |
15942 | |
15943 | template <HW hw> |
15944 | void gemm_kernel_generator_t<hw>::gemmSetupABC(const GEMMProblem &problem, |
15945 | const GEMMStrategy &strategy, GEMMState &state) { |
15946 | if (strategy.persistent) { |
15947 | state.effA = state.offsetA; |
15948 | state.effB = state.offsetB; |
15949 | for (int q = 0; q < state.C_count; q++) |
15950 | state.effC[q] = state.offsetC[q]; |
15951 | state.effCO = state.offsetCO; |
15952 | } |
15953 | |
15954 | // Add offsets to A, B, C base pointers for stateless accesses. |
15955 | if (strategy.C.base.isStateless()) { |
15956 | for (int q = 0; q < state.C_count; q++) { |
15957 | auto Csrc = state.inputs.C[q]; |
15958 | if ((q > 0) && strategy.C.base.isStateless() |
15959 | && state.inputs.base.isValid()) |
15960 | state.effC[q] = state.inputs.C[q] |
15961 | = state.ra.alloc_sub<uint64_t>( |
15962 | getHint(HintType::LongTerm, strategy)); |
15963 | |
15964 | eadd(1, state.effC[q], Csrc, state.offsetC[q], strategy, state); |
15965 | if (strategy.persistent) |
15966 | state.offsetC[q] = invalid; |
15967 | else |
15968 | state.ra.safeRelease(state.offsetC[q]); |
15969 | } |
15970 | } |
15971 | |
15972 | if (problem.usesCO() && strategy.CO.base.isStateless()) { |
15973 | eadd(1, state.effCO, state.inputs.CO, state.offsetCO, strategy, state); |
15974 | if (strategy.persistent) |
15975 | state.offsetCO = invalid; |
15976 | else |
15977 | state.ra.safeRelease(state.offsetCO); |
15978 | } |
15979 | |
15980 | if (state.offsetAp.isValid()) { |
15981 | if (strategy.A.base.isStateless()) { |
15982 | state.effAp = state.ra.alloc_sub<uint64_t>( |
15983 | getHint(HintType::LongTerm, strategy)); |
15984 | eadd(1, state.effAp, state.inputs.A, state.offsetAp, strategy, |
15985 | state); |
15986 | state.ra.safeRelease(state.offsetAp); |
15987 | } else |
15988 | state.effAp = state.offsetAp; |
15989 | } |
15990 | |
15991 | if (state.offsetBp.isValid()) { |
15992 | if (strategy.B.base.isStateless()) { |
15993 | state.effBp = state.ra.alloc_sub<uint64_t>( |
15994 | getHint(HintType::LongTerm, strategy)); |
15995 | eadd(1, state.effBp, state.inputs.B, state.offsetBp, strategy, |
15996 | state); |
15997 | state.ra.safeRelease(state.offsetBp); |
15998 | } else |
15999 | state.effBp = state.offsetBp; |
16000 | } |
16001 | |
16002 | if (state.offsetCp.isValid()) { |
16003 | if (strategy.C.base.isStateless()) { |
16004 | state.effCp = state.ra.alloc_sub<uint64_t>( |
16005 | getHint(HintType::LongTerm, strategy)); |
16006 | eadd(1, state.effCp, state.inputs.C[0], state.offsetCp, strategy, |
16007 | state); |
16008 | state.ra.safeRelease(state.offsetCp); |
16009 | } else |
16010 | state.effCp = state.offsetCp; |
16011 | } |
16012 | |
16013 | if (strategy.A.base.isStateless()) { |
16014 | auto Asrc = state.inputs.A; |
16015 | if (strategy.B.base.isStateless() && (state.effA == state.effB)) |
16016 | state.effA = state.inputs.A = state.ra.alloc_sub<uint64_t>( |
16017 | getHint(HintType::LongTerm, strategy)); |
16018 | |
16019 | eadd(1, state.effA, Asrc, state.offsetA, strategy, state); |
16020 | if (strategy.persistent) |
16021 | state.offsetA = invalid; |
16022 | else |
16023 | state.ra.safeRelease(state.offsetA); |
16024 | } |
16025 | |
16026 | if (strategy.B.base.isStateless()) { |
16027 | eadd(1, state.effB, state.inputs.B, state.offsetB, strategy, state); |
16028 | if (strategy.persistent) |
16029 | state.offsetB = invalid; |
16030 | else |
16031 | state.ra.safeRelease(state.offsetB); |
16032 | } |
16033 | |
16034 | if (strategy.prefetchA && state.effAp.isInvalid()) state.effAp = state.effA; |
16035 | if (strategy.prefetchB && state.effBp.isInvalid()) state.effBp = state.effB; |
16036 | if (strategy.prefetchC && state.effCp.isInvalid()) |
16037 | state.effCp = state.effC[0]; |
16038 | } |
16039 | |
16040 | // Get (possibly multidimensional) batch IDs. |
16041 | template <HW hw> |
16042 | void gemm_kernel_generator_t<hw>::gemmGetBatchIDs(const GEMMProblem &problem, |
16043 | const GEMMStrategy &strategy, GEMMState &state) { |
16044 | switch (problem.batchDims) { |
16045 | case 0: break; |
16046 | case 1: state.batchID[0] = state.inputs.groupIDK; break; |
16047 | case 2: { |
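            // Decompose the linear batch ID:
            //   batchID[1] = groupIDK / batchSize1  (divDown via reciprocal)
            //   batchID[0] = groupIDK - batchID[1] * batchSize1
            //              = groupIDK % batchSize1.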
16048 | state.batchID[0] = state.ra.alloc_sub<uint32_t>(); |
16049 | state.batchID[1] = state.ra.alloc_sub<uint32_t>(); |
16050 | divDown(state.batchID[1], state.inputs.groupIDK, |
16051 | state.inputs.batchSize1, state.inputs.recipBatchSize1, |
16052 | state.flagAP, strategy, state); |
16053 | emul(1, state.batchID[0], state.batchID[1], state.inputs.batchSize1, |
16054 | strategy, state); |
16055 | add(1, state.batchID[0], -state.batchID[0], state.inputs.groupIDK); |
16056 | if (!strategy.persistent) { |
16057 | state.ra.safeRelease(state.inputs.batchSize1); |
16058 | state.ra.safeRelease(state.inputs.recipBatchSize1); |
16059 | } |
16060 | break; |
16061 | } |
16062 | default: stub(); |
16063 | } |
16064 | } |
16065 | |
16066 | template <HW hw> |
16067 | void gemm_kernel_generator_t<hw>::gemmReleaseBatchIDs( |
16068 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
16069 | GEMMState &state) { |
16070 | if (problem.batch != BatchMode::Strided) return; |
16071 | if (problem.batchDims == 1 && state.r0_info == r0) return; |
16072 | if (problem.hasBinaryPostOp()) return; |
16073 | for (int b = 0; b < problem.batchDims; b++) |
16074 | state.ra.safeRelease(state.batchID[b]); |
16075 | } |
16076 | |
16077 | // Convert linear index to 2D index in a Hilbert curve-like fashion. |
16078 | template <HW hw> |
16079 | void gemm_kernel_generator_t<hw>::gemmHilbertlikeOrder( |
16080 | const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
16081 | bool triangular = false; |
16082 | bool rectangular = !triangular && state.inputs.hilbertVD.isValid(); |
16083 | |
16084 | auto storage = state.ra.alloc(); |
16085 | auto u = storage.ud(0); |
16086 | auto v = storage.ud(1); |
16087 | auto uh = storage.ud(2); |
16088 | auto vh = storage.ud(3); |
16089 | auto a = storage.ud(4); |
16090 | auto b = storage.ud(5); |
16091 | /* auto isNormal = storage.ud(6); */ // not used directly |
16092 | auto isReversed = storage.ud(7); |
16093 | int soff = storage.getBase() * GRF::bytes(hw); |
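    // The traversal state (u/v, uh/vh, a/b, flags) lives in a single GRF;
    // a0 holds byte offsets into it so the indirect movi instructions below
    // can, in effect, swap the u/v and a/b halves of the state in one
    // instruction whenever the recursion changes direction.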
16094 | |
16095 | auto storage2 = state.ra.alloc_range(2); |
16096 | auto nbu = storage2[0].ud(0); |
16097 | auto nbv = storage2[0].ud(1); |
16098 | auto np1 = storage2[0].ud(2); |
16099 | auto bv1 = storage2[0].ud(3); |
16100 | auto uv1 = storage2[0].ud(4); |
16101 | auto temp3 = storage2[0].ud(5); |
16102 | auto uo = storage2[0].ud(6); |
16103 | /* auto vo = storage2[0].ud(7); */ // not used directly |
16104 | auto temp = storage2[1].ud(0); |
16105 | auto temp2 = storage2[1].ud(1); |
16106 | auto qrem = storage2[1].ud(2); |
16107 | auto qqot = storage2[1].ud(4); |
16108 | auto q = storage2[1].ud(6); |
16109 | auto ud = storage2[1].ud(7); |
16110 | |
16111 | auto bu = f1[0], bv = f1[1]; |
16112 | |
16113 | auto vd = state.inputs.hilbertVD; |
16114 | auto uvdRecip = state.inputs.hilbertUVDRecip; |
16115 | auto hilbertBail = state.inputs.hilbertBail; |
16116 | |
16117 | auto any8 = (hw == HW::XeHPC) ? any : any8h; |
16118 | auto any16 = (hw == HW::XeHPC) ? any : any16h; |
16119 | bool avoidAny2 = (hw == HW::XeHPC); |
16120 | |
16121 | auto jumpAny2 = [&](InstructionModifier mod, Label &l) { |
16122 | if (avoidAny2) { |
16123 | mod.setExecSize(16); |
16124 | goto12(mod | any16, l); |
16125 | } else |
16126 | jmpi(mod | any2h, l); |
16127 | }; |
16128 | |
16129 | Label lTriangularTop, lTriangularExit, lTriangularBypass; |
16130 | Label lRecursiveTop, lRecursiveEnd; |
16131 | |
16132 | // NB: Sequence assumes group counts fit in 16 bits. |
16133 | status << "Hilbert-like ordering" << status_stream::endl; |
16134 | if (avoidAny2) mov(1, f0[0], 0); |
16135 | if (rectangular) |
16136 | mov(1, f0[1], |
16137 | vd.uw(1)); // High word of vd = 0xFFFF -> start splitting in x |
16138 | else if (triangular) |
16139 | cmp(1 | ne | f0[0], state.inputs.diagC, 0); |
16140 | mov(1, u, state.inputs.groupCountM); |
16141 | mov(1, v, state.inputs.groupCountN); |
16142 | mov(4, a0, Immediate::uv(4, 0, 12, 8, 0, 0, 0, 0)); |
16143 | mov(1, f1.ud(0), 0); // bu = bv = false |
16144 | mov(1, np1, triangular ? 0xFFFFFFFF : 0); |
16145 | if (triangular) |
16146 | cmp(1 | ~f0[0] | ne | f0[0], state.inputs.m, state.inputs.n); |
16147 | else |
16148 | cmp(2 | le | f0[0], u(1), hilbertBail); |
16149 | mov(1, q, state.inputs.groupIDMN); |
16150 | add(4, a0[4](1), a0[0](1), 16); |
16151 | if (!rectangular && !triangular) |
16152 | emad(1, uv1, -1, u.uw(), v.uw(), strategy, state); |
16153 | mov(8, a.uw()(1), |
16154 | Immediate::uv(0x00010000)); // a = b = 0, normal = 1, reversed = 0; |
16155 | if (soff >= 512) { |
16156 | add(4, a0, a0, soff); |
16157 | soff = 0; |
16158 | } |
16159 | if (triangular) |
16160 | jmpi(1 | f0[0], lTriangularBypass); |
16161 | else |
16162 | jumpAny2(1 | f0[0], lRecursiveEnd); |
16163 | |
16164 | // Rectangular partitioning step. Break dispatch into blocks of roughly desired aspect ratio. |
16165 | if (rectangular) { |
16166 | auto uvd = uv1; |
16167 | movi(8 | f0[1], storage.ud(), indirect[a0].ud(soff)(1)); |
16168 | mul(1, uvd, u, vd.uw()); |
16169 | divDown(nbv, q, uvd, uvdRecip, f0[0], strategy, state); |
16170 | and_(1 | ne | bv, bv1, nbv, 1); |
16171 | mul(1, temp, uvd, nbv.uw()); |
16172 | mul(1, b, vd.uw(), nbv.uw()); |
16173 | add(1, q, q, -temp); // no DWxW with source modifiers |
16174 | add(1, v, v, -b); |
16175 | avg(1, ud, u, -bv1); |
16176 | min_(1, v, v, vd.uw()); |
16177 | avg(1, uh, u, 0); |
16178 | mul(1, temp, v.uw(), ud.uw()); |
16179 | cmp(1 | ge | bu, nbu, q, temp); |
16180 | add(1 | bu, q, q, -temp); |
16181 | cmp(1 | ne | bu, nbu, nbu.d(), -bv1.d()); // {bu,nbu} ^= bv1 |
16182 | sel(1 | bu, a, uh, 0); |
16183 | avg(1, u, u, nbu.d()); |
16184 | movi(8 | ~bu | any8, storage.ud(), indirect[a0].ud(soff)(1)); |
16185 | cmp(2 | le | f0[0], u(1), hilbertBail); |
16186 | sel(1 | ~bu, np1, -bv1, 0); |
16187 | emad(1, uv1, -1, u.uw(), v.uw(), strategy, state); |
16188 | mov(1, f1.ud(0), 0); // bu = bv = false |
16189 | jumpAny2(1 | f0[0], lRecursiveEnd); |
16190 | } |
16191 | |
16192 | // Recursive partitioning. Each step breaks the current block |
16193 | // into 2x2 subblocks and follows the block we are currently in. |
16194 | // Exit when one dimension is less than hilbertBail. |
16195 | mark(lRecursiveTop); |
16196 | { |
16197 | avg(2, uh(1), u(1), 0); |
16198 | add(1 | bv, q, uv1, -q); |
16199 | |
16200 | mul(1, temp, u.uw(), vh.uw()); |
16201 | cmp(1 | ge | bv, nbv, q, temp); |
16202 | mov(2, uo(1), u(1)); |
16203 | add(1 | bv, q, uv1, -q); |
16204 | avg(1, v, v, nbv.d()); |
16205 | mul(1, temp, uh.uw(), v.uw()); |
16206 | cmp(1 | ge | bu, nbu, q, temp); |
16207 | add(1 | bu, q, q, -temp); |
16208 | avg(1, u, u, nbu.d()); |
16209 | |
16210 | xor_(2, temp(1), nbu(1), np1); |
16211 | avg(2, uo(1), uo(1), np1.d()); |
16212 | xor_(1 | bv, np1, np1, ~nbu); |
16213 | and_(2, uo(1), uo(1), temp(1)); |
16214 | emad(1, uv1, -1, u.uw(), v.uw(), strategy, state); |
16215 | add(2, a(1), a(1), uo(1)); |
16216 | |
16217 | cmp(2 | le | f0[0], u(1), hilbertBail); |
16218 | movi(8 | ~bu | any8, storage.ud(), indirect[a0].ud(soff)(1)); |
16219 | |
16220 | if (avoidAny2) |
16221 | goto12(16 | ~f0[0] | any16, lRecursiveEnd, lRecursiveTop, true); |
16222 | else |
16223 | jmpi(1 | ~f0[0] | any2h, lRecursiveTop); |
16224 | } |
16225 | mark(lRecursiveEnd); |
16226 | if (avoidAny2) join(16); |
16227 | |
16228 | cmp(8 | ne | f0[0], isReversed, 0); |
16229 | movi(8 | f0[0], storage.ud(), indirect[a0].ud(soff)(1)); |
16230 | |
16231 | // Regular 2D traversal over final block. |
16232 | bool nmk = (strategy.loopOrder[0] == LoopN); |
16233 | auto divisor = nmk ? v : u; |
16234 | |
16235 | if (hw < HW::Gen12LP) { |
16236 | irem(1, qrem, q, divisor); |
16237 | iqot(1, qqot, q, divisor); |
16238 | } else { |
16239 | auto bias = temp.f(); |
16240 | auto divisorFP = temp2.f(); |
16241 | auto qFP = temp3.f(); |
16242 | mov(1, divisorFP, divisor); |
16243 | mov(1, qFP, q); |
16244 | mov(1, bias, -0.499996185302734375f); // -1/2 + 2^(-18) |
16245 | einv(1, divisorFP, divisorFP, strategy, state); |
16246 | add(1, divisorFP.ud(), divisorFP.ud(), 2); |
16247 | mad(1, qqot.f(), bias, qFP, divisorFP); |
16248 | mov(1, qqot, qqot.f()); |
16249 | mad(1, qrem, q, -qqot.uw(), divisor.uw()); |
16250 | } |
16251 | |
16252 | // Reassign m/n group IDs. |
16253 | if (!strategy.persistent) { |
16254 | state.inputs.groupIDM = state.inputs.groupCountM; |
16255 | state.inputs.groupIDN = state.inputs.groupCountN; |
16256 | state.inputs.groupCountM = invalid; |
16257 | state.inputs.groupCountN = invalid; |
16258 | } else { |
16259 | state.inputs.groupIDM = state.ra.alloc_sub<uint32_t>(); |
16260 | state.inputs.groupIDN = state.ra.alloc_sub<uint32_t>(); |
16261 | } |
16262 | |
16263 | add(1, state.inputs.groupIDM, a, nmk ? qqot : qrem); |
16264 | add(1, state.inputs.groupIDN, b, nmk ? qrem : qqot); |
16265 | |
16266 | state.ra.safeRelease(storage); |
16267 | state.ra.safeRelease(storage2); |
16268 | if (!strategy.persistent) { |
16269 | state.ra.safeRelease(state.inputs.hilbertVD); |
16270 | state.ra.safeRelease(state.inputs.hilbertUVDRecip); |
16271 | state.ra.safeRelease(state.inputs.hilbertBail); |
16272 | } |
16273 | } |
16274 | |
16275 | // Convert linear index to 2D index in a boustrophedon pattern. |
16276 | template <HW hw> |
16277 | void gemm_kernel_generator_t<hw>::gemmBoustrophedonOrder( |
16278 | const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
16279 | auto storage = state.ra.alloc_range(4); |
16280 | auto u = storage[0].ud(0); |
16281 | auto s = storage[0].ud(1); |
16282 | auto v = storage[0].ud(2); |
16283 | auto s1 = storage[0].ud(3); |
16284 | auto i = storage[0].ud(4); |
16285 | auto j = storage[0].ud(5); |
16286 | auto i0 = storage[0].ud(6); |
16287 | auto two = storage[0].f(7); |
16288 | auto numFP = storage[1].f(0); |
16289 | auto islice = storage[1].ud(1); |
16290 | auto qot = storage[1].ud(2); |
16291 | auto rem = storage[1].ud(4); |
16292 | auto ithresh = storage[1].ud(6); |
16293 | auto temp0 = storage[2].ud(0); |
16294 | auto temp1 = storage[2].ud(2); |
16295 | auto temp2 = storage[2].ud(4); |
16296 | auto bias = storage[3].f(0); |
16297 | auto denomFP = storage[3].f(2); |
16298 | auto q = storage[3].ud(4); |
16299 | auto qFP = storage[3].f(6); |
16300 | |
    // Slice width/height in WGs. Sign interpretation:
    //   + means slice in m dimension, - means n dimension.
    auto s0 = state.inputs.bslice;
    // Slice size adjustment threshold:
    //   + means increase slice size by 1 starting with this row/column,
    //   - means decrease slice size by 1 starting with this row/column.
    auto thresh = state.inputs.bthresh;
16307 | |
16308 | auto &groupCountM = state.inputs.groupCountM; |
16309 | auto &groupCountN = state.inputs.groupCountN; |
16310 | auto idMN = state.inputs.groupIDMN; |
16311 | |
16312 | Label lBegin, lEnd, lDone, lBeginTri2, lEndTri2, lTricalc1, lTricalc2, |
16313 | lTricalcOut; |
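
    // divqot: qot = num / denom, rem = num % denom, without an integer
    // divider. The "large" variant switches cr0 to round toward -infinity and
    // patches up a possibly negative remainder; the small variant reuses the
    // biased-reciprocal trick from gemmHilbertlikeOrder.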
16314 | |
16315 | auto divqot = [&](const Subregister &num, const Subregister &denom, |
16316 | bool large) { |
16317 | if (hw < HW::Gen12LP) { |
16318 | irem(1, rem, num, denom); |
16319 | iqot(1, qot, num, denom); |
16320 | } else if (large) { |
16321 | // denom <= 0x400000, qot < 2^16 |
16322 | or_(1, cr0[0], cr0[0], 0x20); // round toward -inf |
16323 | mov(1, denomFP, denom); |
16324 | mov(1, numFP, -num); |
16325 | einv(1, denomFP, denomFP, strategy, state); |
16326 | add(1, denomFP.ud(), denomFP.ud(), 2); |
16327 | mul(1, qot.f(), -numFP, denomFP); |
16328 | mov(1, qot, qot.f()); |
16329 | mad(1 | lt | f1[1], rem.d(), num, denom, -qot.uw()); |
16330 | add(1 | f1[1], rem, rem, denom); |
16331 | add(1 | f1[1], qot, qot, -1); |
16332 | and_(1, cr0[0], cr0[0], ~0x30); |
16333 | } else { |
16334 | // denom <= 0x40, qot < 2^16 |
16335 | mov(1, denomFP, denom); |
16336 | mov(1, numFP, num); |
16337 | mov(1, bias, -0.499996185302734375f); // -1/2 + 2^(-18) |
16338 | einv(1, denomFP, denomFP, strategy, state); |
16339 | add(1, denomFP.ud(), denomFP.ud(), 2); |
16340 | mad(1, qot.f(), bias, numFP, denomFP); |
16341 | mov(1, qot, qot.f()); |
16342 | mad(1, rem, num, -qot.uw(), denom.uw()); |
16343 | } |
16344 | }; |
16345 | |
16346 | auto ecsel |
16347 | = [&](const InstructionModifier &mod, |
16348 | const InstructionModifier &cmod, const FlagRegister &flag, |
16349 | const RegData &dst, const RegData &src0, |
16350 | const RegData &src1, const RegData &src2) { |
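                    // Fall back to cmp + sel where csel is unavailable (Gen9)
                    // or the destination subregister is misaligned for csel.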
16351 | if (hw == HW::Gen9 || dst.getByteOffset() & 7) { |
16352 | cmp(mod | cmod | flag, src2, 0); |
16353 | sel(mod | flag, dst, src0, src1); |
16354 | } else |
16355 | csel(mod | cmod | flag, dst, src0, src1, src2); |
16356 | }; |
16357 | |
16358 | // NB: Sequence assumes group counts fit in 16 bits. |
16359 | status << "Boustrophedon ordering" << status_stream::endl; |
16360 | |
16361 | mul(1, ithresh, abs(thresh.w()), abs(s0.w())); |
16362 | cmp(1 | ge | f1[0], thresh, 0); |
16363 | ecsel(1, lt, f0[0], v, groupCountM, groupCountN, s0); |
16364 | ecsel(1, ge, f0[0], u, groupCountM, groupCountN, s0); |
16365 | |
16366 | emad(1, temp0, idMN, -v.uw(), ithresh.uw(), strategy, state); |
16367 | cmp(1 | ge | f0[0], temp2.d(), temp0.d(), 0); |
16368 | ecsel(1, ge, f0[1], q, temp0, idMN, temp0.d()); |
16369 | |
16370 | if (hw == HW::XeHPC) { |
16371 | add(1, s1, abs(s0), 1); |
16372 | add(1 | ~f0[0], s1, abs(s0), temp2.d()); |
16373 | add(1 | ~f1[0], s1, abs(s0), temp2.d()); |
16374 | } else { |
16375 | add(1, s1, abs(s0), temp2.d()); |
16376 | add(1 | f0[0] | allv, s1, abs(s0), 1); |
16377 | } |
16378 | |
16379 | mul(1, temp1, s1.uw(), v.uw()); |
16380 | |
16381 | divqot(q, temp1, true); |
16382 | |
16383 | mul(1, i0, qot.uw(), s1.uw()); |
16384 | mov(1, islice, qot); |
16385 | add(1 | f0[0], i0, i0, ithresh); |
16386 | mov(1, q, rem); |
16387 | add(1 | sat, temp0, u, -i0); |
16388 | min_(1, s, s1, temp0); |
16389 | add(1 | f0[0], islice, islice, abs(thresh)); |
16390 | |
16391 | mul(1, temp2, s.uw(), s.uw()); |
16392 | emad(1, temp1, temp1, -s.uw(), s.uw(), strategy, state); |
16393 | |
16394 | cmp(1 | gt | f0[0], i0, 0); // not first row? |
16395 | cmp(1 | lt | f0[1], s1, temp0); // not last row? |
16396 | |
16397 | if (hw == HW::XeHPC) { |
16398 | cmp(1 | f0[0] | lt | f0[0], q, temp2); // beginning of row? |
16399 | cmp(1 | f0[1] | ge | f0[1], q, temp1); // end of row? |
16400 | } else { |
16401 | cmp(1 | lt | f1[0], q, temp2); // beginning of row? |
16402 | cmp(1 | ge | f1[1], q, temp1); // end of row? |
16403 | } |
16404 | |
16405 | mov(1, two, 2.0f); |
16406 | mov(1, bias, 1.25f); |
16407 | |
16408 | if (hw == HW::XeHPC) { |
16409 | jmpi(1 | f0[0], lBegin); |
16410 | jmpi(1 | f0[1], lEnd); |
16411 | } else { |
16412 | jmpi(1 | f0[0] | allv, lBegin); |
16413 | jmpi(1 | f0[1] | allv, lEnd); |
16414 | } |
16415 | |
16416 | { |
16417 | divqot(q, s, false); |
16418 | |
16419 | add(1, i, i0, rem); |
16420 | mov(1, j, qot); |
16421 | } |
16422 | |
16423 | jmpi(1, lDone); |
16424 | |
16425 | mark(lBegin); |
16426 | { |
16427 | avg(1, temp0, temp2, -s); // s(s-1)/2 |
16428 | mov(1, f1.ud(0), 0xFFFF); |
16429 | cmp(1 | lt | f0[0], q, temp0); |
16430 | jmpi(1 | ~f0[0], lBeginTri2); |
16431 | |
16432 | eadd3(1, q, temp0, -q, -1); |
16433 | jmpi(1, lTricalc1); |
16434 | |
16435 | mark(lBeginTri2); |
16436 | add(1, q, q, -temp0); |
16437 | jmpi(1, lTricalc2); |
16438 | } |
16439 | |
16440 | mark(lEnd); |
16441 | { |
16442 | add(1, q, q, -temp1); |
16443 | avg(1, temp0, temp2, s); // s(s+1)/2 |
16444 | mov(1, f1.ud(0), 0); |
16445 | cmp(1 | lt | f0[0], q, temp0); |
16446 | jmpi(1 | ~f0[0], lEndTri2); |
16447 | |
16448 | eadd3(1, q, temp0, -q, -1); |
16449 | mark(lTricalc2); |
16450 | { |
16451 | mov(1, qFP, q); |
16452 | mad(1, qFP, bias, qFP, two); |
16453 | esqt(1, qFP, qFP, strategy, state); |
16454 | if (hw == HW::Gen9) rnde(1, qFP, qFP); |
16455 | mov(1, j, qFP); |
16456 | mul(1, temp0, j.uw(), j.uw()); |
16457 | avg(1, temp0, temp0, -j); |
16458 | add(1, j, j, -1); |
16459 | add(1, i, q, -temp0); |
16460 | } |
16461 | jmpi(1, lTricalcOut); |
16462 | |
16463 | mark(lEndTri2); |
16464 | add(1, q, q, -temp0); |
16465 | mark(lTricalc1); |
16466 | { |
16467 | mov(1, qFP, q); |
16468 | mad(1, qFP, bias, qFP, two); |
16469 | esqt(1, qFP, qFP, strategy, state); |
16470 | if (hw == HW::Gen9) rnde(1, qFP, qFP); |
16471 | mov(1, i, qFP); |
16472 | mul(1, temp0, i.uw(), i.uw()); |
16473 | avg(1, temp0, temp0, -i); |
16474 | add(1, j, q, -temp0); |
16475 | } |
16476 | |
16477 | mark(lTricalcOut); |
16478 | eadd3(1 | f1[0], i, s, -i, -1); |
16479 | eadd3(1 | ~f1[0], j, v, -j, -1); |
16480 | add(1, i, i, i0); |
16481 | } |
16482 | |
16483 | // Reassign m/n group IDs. |
16484 | mark(lDone); |
16485 | |
16486 | if (!strategy.persistent) { |
16487 | state.inputs.groupIDM = state.inputs.groupCountM; |
16488 | state.inputs.groupIDN = state.inputs.groupCountN; |
16489 | state.inputs.groupCountM = invalid; |
16490 | state.inputs.groupCountN = invalid; |
16491 | } else { |
16492 | state.inputs.groupIDM = state.ra.alloc_sub<uint32_t>(); |
16493 | state.inputs.groupIDN = state.ra.alloc_sub<uint32_t>(); |
16494 | } |
16495 | |
16496 | and_(1 | ne | f1[1], null.ud(), islice, 1); |
16497 | eadd3(1 | f1[1], j, v, -j, -1); |
16498 | ecsel(1, ge, f0[0], state.inputs.groupIDM, i, j, s0); |
16499 | ecsel(1, lt, f0[0], state.inputs.groupIDN, i, j, s0); |
16500 | |
16501 | state.ra.safeRelease(storage); |
16502 | if (!strategy.persistent) { |
16503 | state.ra.safeRelease(state.inputs.bslice); |
16504 | state.ra.safeRelease(state.inputs.bthresh); |
16505 | } |
16506 | } |
16507 | |
16508 | // Reverse m/n loops if requested. |
16509 | template <HW hw> |
16510 | void gemm_kernel_generator_t<hw>::gemmReverseLoops(const GEMMProblem &problem, |
16511 | const GEMMStrategy &strategy, GEMMState &state) { |
16512 | for (LoopType l : {LoopM, LoopN}) |
16513 | if (strategy.reverse[l]) { |
16514 | bool fusedL = strategy.fused && (l == strategy.fusedLoop); |
16515 | auto q = (l == LoopM) ? state.inputs.m : state.inputs.n; |
16516 | auto q0 = (l == LoopM) ? state.i0 : state.j0; |
16517 | auto q0Align = state.ra.alloc_sub<uint32_t>(); |
16518 | auto temp = state.ra.alloc_sub<uint32_t>(); |
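
            // Reflect the starting index: the reversed q0 is the start of the
            // last aligned block minus the old q0 (adjusted so fixed
            // workgroups or fused thread pairs stay together), applied only
            // where the result is nonnegative.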
16519 | |
16520 | add(1, q0Align, q, -1); |
16521 | if (strategy.fixedWG(problem)) { |
16522 | mod(temp, q0, strategy.wg[l] * strategy.unroll[l], strategy, |
16523 | state); |
16524 | alignDown(q0Align, q0Align, strategy.wg[l] * strategy.unroll[l], |
16525 | strategy, state); |
16526 | shl(1, temp, temp, 1); |
16527 | eadd3(1 | ge | f0[0], q0Align.d(), q0Align, -q0, temp); |
16528 | mov(1 | f0[0], q0, q0Align); |
16529 | } else if (fusedL) { |
16530 | shl(1, temp, state.fusedID, 1); |
16531 | alignDown(q0Align, q0Align, 2 * strategy.unroll[l], strategy, |
16532 | state); |
16533 | eadd3(1 | ge | f0[0], q0Align.d(), q0Align, -q0, temp); |
16534 | mov(1 | f0[0], q0, q0Align); |
16535 | } else { |
16536 | alignDown( |
16537 | q0Align, q0Align, strategy.unroll[l], strategy, state); |
16538 | cmp(1 | le | f0[0], q0, q0Align); |
16539 | add(1 | f0[0], q0, q0Align, -q0); |
16540 | } |
16541 | state.ra.safeRelease(temp); |
16542 | state.ra.safeRelease(q0Align); |
16543 | } |
16544 | } |
16545 | |
16546 | // Reorder local IDs as needed. |
16547 | template <HW hw> |
16548 | void gemm_kernel_generator_t<hw>::gemmReorderLocalIDs( |
16549 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
16550 | GEMMState &state) { |
16551 | if (strategy.fixedSystolic) |
16552 | sysgemmReorderLocalIDs(problem, strategy, state); |
16553 | |
16554 | if (strategy.skewLocalIDs) { |
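        // Rotate each thread's inner local ID by a function of its outer
        // local ID (modulo the inner workgroup size) so that neighboring
        // threads start on different rows/columns. wgI must be a power of two
        // so the wrap-around is a single AND.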
16555 | if (!strategy.fixedWG(problem)) stub(); |
16556 | auto wgI = strategy.wg[strategy.loopOrder[0]]; |
16557 | auto adjustEvery = div_up(eusPerSubslice(hw), wgI); |
16558 | bool innerM = strategy.loopOrder[0] == LoopM; |
16559 | auto lidI = innerM ? state.lidM : state.lidN; |
16560 | auto lidO = innerM ? state.lidN : state.lidM; |
16561 | auto temp = state.ra.alloc_sub<uint16_t>(); |
16562 | auto slidO = lidO; |
16563 | |
16564 | if (adjustEvery > 1) { |
16565 | shr(1, temp, lidO, log2(adjustEvery)); |
16566 | slidO = temp; |
16567 | } |
16568 | |
16569 | if (strategy.fused) |
16570 | emad(1, lidI, lidI, slidO, 2, strategy, state); |
16571 | else |
16572 | add(1, lidI, lidI, slidO); |
16573 | |
16574 | if (!is_zero_or_pow2(wgI)) stub(); |
16575 | |
16576 | and_(1, lidI, lidI, wgI - 1); |
16577 | |
16578 | state.ra.safeRelease(temp); |
16579 | } |
16580 | } |
16581 | |
16582 | // Convert leading dimension and offset inputs to bytes. |
16583 | template <ngen::HW hw> |
16584 | void gemm_kernel_generator_t<hw>::gemmScaleInputs(const GEMMProblem &problem, |
16585 | const GEMMStrategy &strategy, GEMMState &state) { |
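    // All leading dimensions, offsets, and batch strides arrive in elements;
    // convert each to bytes here (e.g. lda_bytes = lda * sizeof(Ta_ext)) so
    // that later address arithmetic never needs to rescale.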
16586 | auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext, |
16587 | Tc_ext = problem.Tc_ext, Tco = problem.Tco; |
16588 | |
16589 | emulConstant(1, state.inputs.lda, state.inputs.lda, Ta_ext.size(), strategy, |
16590 | state); |
16591 | if (state.inputs.ldb != state.inputs.lda) |
16592 | emulConstant(1, state.inputs.ldb, state.inputs.ldb, Tb_ext.size(), |
16593 | strategy, state); |
16594 | for (int q = 0; q < state.C_count; q++) |
16595 | emulConstant(1, state.inputs.ldc[q], state.inputs.ldc[q], Tc_ext.size(), |
16596 | strategy, state); |
16597 | if (state.inputs.ldco.isValid()) |
16598 | emulConstant(1, state.inputs.ldco, state.inputs.ldco, Tco.size(), |
16599 | strategy, state); |
16600 | |
16601 | { |
16602 | emulConstant(1, state.inputs.offsetA, state.inputs.offsetA, |
16603 | Ta_ext.size(), strategy, state); |
16604 | emulConstant(1, state.inputs.offsetB, state.inputs.offsetB, |
16605 | Tb_ext.size(), strategy, state); |
16606 | for (int q = 0; q < state.C_count; q++) |
16607 | emulConstant(1, state.inputs.offsetC[q], state.inputs.offsetC[q], |
16608 | Tc_ext.size(), strategy, state); |
16609 | if (problem.usesCO()) |
16610 | emulConstant(1, state.inputs.offsetCO, state.inputs.offsetCO, |
16611 | Tco.size(), strategy, state); |
16612 | } |
16613 | |
16614 | if (problem.batch == BatchMode::Strided) |
16615 | for (int b = 0; b < problem.batchDims; b++) { |
16616 | emulConstant(1, state.inputs.strideA[b], state.inputs.strideA[b], |
16617 | Ta_ext.size(), strategy, state); |
16618 | emulConstant(1, state.inputs.strideB[b], state.inputs.strideB[b], |
16619 | Tb_ext.size(), strategy, state); |
16620 | emulConstant(1, state.inputs.strideC[b], state.inputs.strideC[b], |
16621 | Tc_ext.size(), strategy, state); |
16622 | } |
16623 | } |
16624 | |
16625 | // Cache multiples of lda/ldb for later address calculations. |
16626 | template <HW hw> |
16627 | void gemm_kernel_generator_t<hw>::gemmCacheLDABMultiples( |
16628 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
16629 | GEMMState &state) { |
16630 | int na = 0, nb = 0; |
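
    // na/nb = number of consecutive lda/ldb multiples that address setup will
    // want: the k load/prefetch depth along the leading dimension for
    // N-layout A (T-layout B), or the m/n unroll for the transposed cases,
    // capped at the maximum scattered SIMD width.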
16631 | |
16632 | if (!strategy.A.address2D) switch (problem.A.layout) { |
16633 | case MatrixLayout::N: |
16634 | na = std::max(strategy.ka_load, strategy.ka_prefetch); |
16635 | break; |
16636 | case MatrixLayout::T: |
16637 | na = strategy.unroll[LoopM]; |
16638 | if (isTransposing(strategy.A.accessType)) |
16639 | na = std::min(na, maxScatteredSIMD(hw, strategy.A)); |
16640 | break; |
16641 | default: break; |
16642 | } |
16643 | |
16644 | if (!strategy.B.address2D) switch (problem.B.layout) { |
16645 | case MatrixLayout::T: |
16646 | nb = std::max(strategy.kb_load, strategy.kb_prefetch); |
16647 | break; |
16648 | case MatrixLayout::N: |
16649 | nb = strategy.unroll[LoopN]; |
16650 | if (isTransposing(strategy.B.accessType)) |
16651 | nb = std::min(nb, maxScatteredSIMD(hw, strategy.B)); |
16652 | break; |
16653 | default: break; |
16654 | } |
16655 | |
16656 | if (na <= 2) na = 0; |
16657 | if (nb <= 2) nb = 0; |
16658 | |
16659 | if (na || nb) extendIndexVec(std::max(na, nb), state); |
16660 | |
16661 | if (na) { |
16662 | bool a64 = (strategy.A.base.getModel() == ModelA64); |
16663 | state.ldaMultiples |
16664 | = createLDMultiples(a64, na, state.lda, strategy, state); |
16665 | } |
16666 | |
16667 | if (nb) { |
16668 | bool a64 = (strategy.B.base.getModel() == ModelA64); |
16669 | state.ldbMultiples |
16670 | = createLDMultiples(a64, nb, state.ldb, strategy, state); |
16671 | } |
16672 | } |
16673 | |
16674 | // Cache multiples of ldc for later address calculations. |
16675 | template <HW hw> |
16676 | void gemm_kernel_generator_t<hw>::gemmCacheLDCMultiples( |
16677 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
16678 | GEMMState &state, bool prefetch) { |
16679 | if ((prefetch ? strategy.C_prefetch : strategy.C).address2D) return; |
16680 | |
16681 | int nc = 0; |
16682 | switch (problem.C.layout) { |
16683 | case MatrixLayout::N: nc = strategy.unroll[LoopN]; break; |
16684 | case MatrixLayout::T: nc = strategy.unroll[LoopM]; break; |
16685 | default: break; |
16686 | } |
16687 | |
16688 | if (nc <= 2) return; |
16689 | |
16690 | bool a64 = (strategy.C.base.getModel() == ModelA64); |
16691 | int C_count = prefetch ? 1 : state.C_count; |
16692 | for (int q = 0; q < C_count; q++) |
16693 | state.ldcMultiples[q] = createLDMultiples( |
16694 | a64, nc, state.inputs.ldc[q], strategy, state); |
16695 | } |
16696 | |
16697 | // GEMM kernel generation interface. |
16698 | template <HW hw> |
16699 | void gemm_kernel_generator_t<hw>::gemm(GEMMProblem problem, |
16700 | GEMMStrategy strategy, const InterfaceHandler &interface_) { |
16701 | GEMMState state(hw); |
16702 | interface = interface_; |
16703 | gemm(problem, strategy, state); |
16704 | } |
16705 | |
16706 | template <HW hw> |
16707 | void gemm_kernel_generator_t<hw>::gemm( |
16708 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
16709 | bool inFusedGEMM = state.fusedGEMM.active; |
16710 | bool anyKParallel = strategy.kParallelLocal || strategy.kParallel; |
16711 | |
16712 | Label labelKernelDone, labelReentry; |
16713 | |
16714 | // By default, don't use dispatch mask. |
16715 | setDefaultNoMask(); |
16716 | setDefaultAutoSWSB(); |
16717 | |
16718 | // Set up. |
16719 | gemmInitState(problem, strategy, state); |
16720 | |
16721 | // Transfer surface indices to strategy AddressBases. |
16722 | if (!strategy.A.base.isStateless()) |
16723 | strategy.A.base.setIndex(state.inputs.surfaceA); |
16724 | if (!strategy.B.base.isStateless()) |
16725 | strategy.B.base.setIndex(state.inputs.surfaceB); |
16726 | if (!strategy.C.base.isStateless()) { |
16727 | strategy.C.base.setIndex(state.inputs.surfaceC[0]); |
16728 | if (state.C_count > 1) stub(); |
16729 | } |
16730 | if (problem.usesCO() && !strategy.CO.base.isStateless()) |
16731 | strategy.CO.base.setIndex(state.inputs.surfaceCO); |
16732 | |
16733 | for (size_t i = 0; i < strategy.binary.size(); i++) |
16734 | if (!strategy.binary[i].base.isStateless()) |
16735 | strategy.binary[i].base.setIndex(state.inputs.binarySurfaces[i]); |
16736 | |
16737 | // Prologue. |
16738 | if (!inFusedGEMM) prologue(strategy); |
16739 | |
16740 | // Grab fused ID if needed, and multiply by unroll. |
16741 | getFusedID(strategy.unroll[strategy.fusedLoop], problem, strategy, state); |
16742 | |
16743 | if (!inFusedGEMM) { |
16744 | // Divide out subgroup size from local size 0 and local ID 0, and reorder threads for fusing if needed. |
16745 | removeSG(problem, strategy, state); |
16746 | reorderFusedEUs(problem, strategy, state); |
16747 | |
16748 | } /* !inFusedGEMM */ |
16749 | |
16750 | // Check for copy or compute kernel. |
16751 | if (strategy.splitCopy) { |
16752 | state.isCompute = state.ra.alloc_sub<uint32_t>( |
16753 | getHint(HintType::LongTerm, strategy)); |
16754 | auto localIDY = (strategy.loopOrder[1] == LoopN) |
16755 | ? state.inputs.localIDN |
16756 | : state.inputs.localIDM; |
16757 | auto wgY = strategy.wg[strategy.loopOrder[1]]; |
16758 | cmp(1 | ge | f1[1], state.isCompute, localIDY, wgY); |
16759 | if (is_zero_or_pow2(wgY)) |
16760 | and_(1, localIDY, localIDY, wgY - 1); |
16761 | else |
16762 | add(1 | f1[1], localIDY, localIDY, -wgY); |
16763 | } |
16764 | |
16765 | // Scale LDs/offsets. |
16766 | gemmScaleInputs(problem, strategy, state); |
16767 | |
16768 | // Local ID handling and saving. |
16769 | gemmReorderLocalIDs(problem, strategy, state); |
16770 | |
16771 | if (strategy.needsMNLocalIDs()) saveMNLocalIDs(strategy, state); |
16772 | |
16773 | if (strategy.needsKLocalIDs()) saveKLocalIDSize(strategy, state); |
16774 | |
16775 | // Save full k size if needed. |
16776 | bool anyAB2D = strategy.A.address2D || strategy.B.address2D |
16777 | || (strategy.prefetchA && strategy.A_prefetch.address2D) |
16778 | || (strategy.prefetchB && strategy.B_prefetch.address2D); |
16779 | if (anyKParallel) { |
16780 | if (strategy.persistent || anyAB2D) { |
16781 | state.fullK = state.ra.alloc_sub<uint32_t>( |
16782 | getHint(HintType::LongTerm, strategy)); |
16783 | mov(1, state.fullK, state.inputs.k); |
16784 | } |
16785 | } else |
16786 | state.fullK = state.inputs.k; |
16787 | |
16788 | // Load ao/bo from memory if needed. |
16789 | if (problem.abOffset != ABOffset::None && state.inputs.abo.isInvalid()) { |
16790 | state.inputs.abo = state.ra.alloc_sub<uint32_t>( |
16791 | getHint(HintType::LongTerm, strategy)); |
16792 | state.inputs.ao = state.inputs.abo.w(0); |
16793 | state.inputs.bo = state.inputs.abo.w(1); |
16794 | |
16795 | auto loadABO = [&](const ngen::Subregister &xo, |
16796 | ngen::Subregister &xoPtr) { |
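            // With no pointer provided, the offset is zero. Otherwise, load
            // one dword from memory (LSC D32 with L1/L3 caching on XeHPG+)
            // and keep its negation, as it is applied subtractively later.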
16797 | if (xoPtr.isInvalid()) |
16798 | mov(1, xo, 0); |
16799 | else { |
                auto header = state.ra.alloc_range(2);
16801 | auto data = state.ra.alloc(); |
16802 | emov<uint64_t>(1, header, xoPtr, strategy, state); |
16803 | (hw >= HW::XeHPG) |
16804 | ? load(1, data, D32 | CacheSettingsLSC::L1C_L3C, A64, |
16805 | header) |
16806 | : load(1, data, scattered_dword(1), A64, header); |
16807 | mov(1, xo, -data.w(0)); |
16808 | state.ra.safeRelease(xoPtr); |
16809 | state.ra.safeRelease(header); |
16810 | state.ra.safeRelease(data); |
16811 | } |
16812 | }; |
16813 | |
16814 | loadABO(state.inputs.ao, state.inputs.aoPtr); |
16815 | loadABO(state.inputs.bo, state.inputs.boPtr); |
16816 | } |
16817 | |
16818 | // Persistent thread preparation and re-entry. |
16819 | if (strategy.persistent) { |
16820 | if (!strategy.linearOrder()) stub(); |
16821 | if (problem.batch != BatchMode::None) |
16822 | stub(); // need to wrangle groupIDK also |
16823 | |
16824 | auto newGroupIDMN = state.ra.alloc_sub<uint32_t>( |
16825 | getHint(HintType::LongTerm, strategy)); |
16826 | mov(1, newGroupIDMN, state.inputs.groupIDMN); |
16827 | state.inputs.groupIDMN = newGroupIDMN; |
16828 | |
16829 | gemmFoldOffsets(problem, strategy, state); |
16830 | |
16831 | mark(labelReentry); |
16832 | } |
16833 | |
16834 | // Group ID remapping. |
16835 | if (strategy.hilbertOrder) |
16836 | gemmHilbertlikeOrder(problem, strategy, state); |
16837 | else if (strategy.boustrophedon) |
16838 | gemmBoustrophedonOrder(problem, strategy, state); |
16839 | |
16840 | // Batch handling. |
16841 | gemmGetBatchIDs(problem, strategy, state); |
16842 | |
16843 | // Compute offset for A, B, C for non-strided and strided batch. |
16844 | gemmOffsetBatchABC(problem, strategy, state); |
16845 | |
16846 | // 32-bit add check. TODO: move out of persistent loop for non-batch. |
16847 | gemmCheck32(problem, strategy, state); |
16848 | |
16849 | // Calculate i0, j0, h0 -- the initial i/j/h indices for this thread. |
16850 | bool needH0 = anyKParallel; |
16851 | |
16852 | state.i0 = state.ra.alloc_sub<uint32_t>( |
16853 | getHint(HintType::TempComp0, strategy)); |
16854 | state.j0 = state.ra.alloc_sub<uint32_t>( |
16855 | getHint(HintType::TempComp1, strategy)); |
16856 | if (needH0) |
16857 | state.h0 = state.ra.alloc_sub<uint32_t>( |
16858 | getHint(HintType::TempComp0, strategy)); |
16859 | |
16860 | bool wgCheck = wgRemCheck(problem, strategy); |
16861 | bool gemmtBarriers = problem.gemmt() && strategy.needsBarrier(); |
16862 | |
16863 | Subregister idM, idN, idK; |
16864 | Subregister wgI0, wgJ0; |
16865 | |
16866 | idM = state.ra.alloc_sub<uint32_t>(getHint(HintType::TempComp1, strategy)); |
16867 | idN = state.ra.alloc_sub<uint32_t>(getHint(HintType::TempComp0, strategy)); |
16868 | if (strategy.kParallel) |
16869 | idK = state.ra.alloc_sub<uint32_t>( |
16870 | getHint(HintType::TempComp0, strategy)); |
16871 | |
16872 | if (strategy.fixedWG(problem)) { |
16873 | mulConstant(1, idM, state.inputs.groupIDM, strategy.wg[LoopM]); |
16874 | mulConstant(1, idN, state.inputs.groupIDN, strategy.wg[LoopN]); |
16875 | if (strategy.kParallel) |
16876 | mulConstant(1, idK, state.inputs.groupIDK, strategy.wg[LoopK]); |
16877 | } else { |
16878 | mul(1, idM, state.inputs.groupIDM, state.inputs.localSizeM.uw()); |
16879 | mul(1, idN, state.inputs.groupIDN, state.inputs.localSizeN.uw()); |
16880 | if (strategy.kParallel) |
16881 | mul(1, idK, state.inputs.groupIDK, state.inputs.localSizeK.uw()); |
16882 | } |
16883 | |
16884 | if (wgCheck || gemmtBarriers) { |
16885 | wgI0 = state.ra.alloc_sub<uint32_t>( |
16886 | getHint(HintType::TempComp0, strategy)); |
16887 | wgJ0 = state.ra.alloc_sub<uint32_t>( |
16888 | getHint(HintType::TempComp1, strategy)); |
16889 | mulConstant(1, wgI0, idM, strategy.unroll[LoopM]); |
16890 | mulConstant(1, wgJ0, idN, strategy.unroll[LoopN]); |
16891 | } |
16892 | |
16893 | add(1, idM, idM, state.lidM); |
16894 | add(1, idN, idN, state.lidN); |
16895 | if (strategy.kParallel) add(1, idK, idK, state.lidK); |
16896 | |
16897 | mulConstant(1, state.i0, idM, strategy.unroll[LoopM]); |
16898 | mulConstant(1, state.j0, idN, strategy.unroll[LoopN]); |
16899 | |
16900 | if (strategy.kParallel) |
16901 | emul(1, state.h0, idK, state.inputs.k0, strategy, state); |
16902 | else if (strategy.kParallelLocal) |
16903 | mul(1, state.h0, state.inputs.k0, state.lidK); |
16904 | |
16905 | gemmReverseLoops(problem, strategy, state); |
16906 | |
16907 | state.ra.safeRelease(idM); |
16908 | state.ra.safeRelease(idN); |
16909 | state.ra.safeRelease(idK); |
16910 | if (!strategy.persistent) { |
16911 | state.ra.safeRelease(state.inputs.localSizeM); |
16912 | state.ra.safeRelease(state.inputs.localSizeN); |
16913 | } |
16914 | if (anyKParallel) { |
16915 | state.ra.safeRelease(state.inputs.localIDK); |
16916 | if (!strategy.persistent) state.ra.safeRelease(state.inputs.localSizeK); |
16917 | } |
16918 | if (strategy.linearOrder() || strategy.persistent) { |
16919 | state.ra.safeRelease(state.inputs.groupIDM); |
16920 | state.ra.safeRelease(state.inputs.groupIDN); |
16921 | } |
16922 | |
16923 | moveR0(strategy, state); |
16924 | |
16925 | // Adjust k range as needed. |
16926 | if (anyKParallel) { |
16927 | add(1, state.inputs.k, |
16928 | strategy.persistent ? state.fullK : state.inputs.k, -state.h0); |
16929 | min_(1, state.inputs.k, state.inputs.k, state.inputs.k0); |
16930 | |
16931 | bool keepK0 = false; |
16932 | keepK0 |= strategy.kParallelLocal |
16933 | && (strategy.barrierFreq > 0 || strategy.slmBuffers > 0); |
16934 | keepK0 |= strategy.persistent; |
16935 | |
16936 | if (!keepK0) state.ra.safeRelease(state.inputs.k0); |
16937 | } |
16938 | |
16939 | state.ra.safeRelease(state.inputs.localIDM); |
16940 | state.ra.safeRelease(state.inputs.localIDN); |
16941 | if (!strategy.needsMNLocalIDs()) state.lidM = state.lidN = invalid; |
16942 | |
16943 | // Compute workgroup remainders if needed. |
16944 | if (wgCheck) { |
16945 | state.remaindersWG[LoopM] = state.ra.alloc_sub<uint32_t>( |
16946 | getHint(HintType::TempComp1, strategy)); |
16947 | state.remaindersWG[LoopN] = state.ra.alloc_sub<uint32_t>( |
16948 | getHint(HintType::TempComp0, strategy)); |
16949 | add(1 | sat, state.remaindersWG[LoopM], -wgI0, state.inputs.m); |
16950 | add(1 | sat, state.remaindersWG[LoopN], -wgJ0, state.inputs.n); |
16951 | } |
16952 | state.ra.safeRelease(wgI0); |
16953 | state.ra.safeRelease(wgJ0); |
16954 | |
16955 | // Compute base addresses for A, B, C. |
16956 | gemmOffsetABC(true, state.i0, state.j0, state.h0, problem, strategy, state); |
16957 | |
16958 | gemmSetupABC(problem, strategy, state); |
16959 | gemmSubkernel(problem, strategy, state); |
16960 | |
16961 | mark(labelKernelDone); |
16962 | |
16963 | // Persistent thread loop. Advance group ID and re-enter kernel if there's more work to do. |
16964 | if (strategy.persistent) { |
16965 | status << "Persistent loop" << status_stream::endl; |
16966 | if (state.inputs.groupCountMN.isInvalid()) { |
16967 | state.inputs.groupCountMN = state.ra.alloc_sub<uint32_t>( |
16968 | getHint(HintType::LongTerm, strategy)); |
16969 | emul(1, state.inputs.groupCountMN, state.inputs.groupCountM, |
16970 | state.inputs.groupCountN, strategy, state); |
16971 | } |
16972 | |
16973 | add(1, state.inputs.groupIDMN, state.inputs.groupIDMN, |
16974 | state.inputs.groupStride); |
16975 | cmp(1 | lt | state.flagAP, state.inputs.groupIDMN, |
16976 | state.inputs.groupCountMN); |
16977 | |
16978 | state.ra.safeRelease(state.inputs.groupCountMN); |
16979 | gemmRestoreOffsets(problem, strategy, state); |
16980 | |
16981 | if (strategy.slmBuffers > 0) { |
16982 | auto temp = state.ra.alloc(); |
16983 | useR0(state, [&](const GRF &r0_info) { |
16984 | MOCK_BARRIERS barrier(temp, r0_info); |
16985 | }); |
16986 | state.ra.safeRelease(temp); |
16987 | } |
16988 | |
16989 | jmpi(1 | state.flagAP, labelReentry); |
16990 | } |
16991 | |
16992 | if (!inFusedGEMM) { |
16993 | epilogue(strategy, state); |
16994 | padding(); |
16995 | } |
16996 | } |
16997 | |
16998 | template <HW hw> |
16999 | SubregisterPair gemm_kernel_generator_t<hw>::allocIncrement( |
17000 | const GEMMStrategy &strategy, CommonState &state) { |
17001 | if (strategy.avoidIncConflicts) |
17002 | return SubregisterPair(state.ra.alloc_sub<uint32_t>( |
17003 | getHint(HintType::LongTerm0, strategy)), |
17004 | state.ra.alloc_sub<uint32_t>( |
17005 | getHint(HintType::LongTerm1, strategy))); |
17006 | else |
17007 | return SubregisterPair(state.ra.alloc_sub<uint32_t>( |
17008 | getHint(HintType::LongTerm, strategy))); |
17009 | } |
17010 | |
17011 | // Calculate and cache lda_ka (= lda * ka) and ldb_kb (= ldb * kb) as necessary. |
17012 | template <HW hw> |
17013 | void gemm_kernel_generator_t<hw>::gemmCalcIncrements(const GEMMProblem &problem, |
17014 | const GEMMStrategy &strategy, GEMMState &state, int ka_load, |
17015 | int kb_load, bool doA, bool doB) { |
17016 | int nr = strategy.avoidIncConflicts ? 2 : 1; |
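
    // With avoidIncConflicts, each increment is duplicated into two registers
    // with different bank hints (LongTerm0/LongTerm1) so that consecutive
    // address updates can alternate between them, avoiding bank conflicts.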
17017 | |
17018 | if (ka_load == 0) ka_load = strategy.ka_inc(); |
17019 | if (kb_load == 0) kb_load = strategy.kb_inc(); |
17020 | |
    // If A is non-transposed (N layout), we need lda * ka_load * elementSize.
17022 | if (doA && (problem.A.layout == MatrixLayout::N)) { |
17023 | if (!strategy.A.address2D) { |
17024 | if (ka_load > 1) { |
17025 | if (state.lda_ka.isInvalid()) |
17026 | state.lda_ka = allocIncrement(strategy, state); |
17027 | for (int i = 0; i < nr; i++) |
17028 | emulConstant(1, state.lda_ka.getReg(i), state.inputs.lda, |
17029 | ka_load, strategy, state); |
17030 | state.ka_cached = ka_load; |
17031 | } else if (strategy.avoidIncConflicts) |
17032 | duplicateScalar(state.lda, state); |
17033 | } |
17034 | if (strategy.prefetchA && !strategy.A_prefetch.address2D |
17035 | && (strategy.ka_pfStride != ka_load || strategy.A.address2D)) { |
17036 | if (strategy.ka_pfStride > 1) { |
17037 | if (state.lda_ka_prefetch.isInvalid()) |
17038 | state.lda_ka_prefetch = allocIncrement(strategy, state); |
17039 | for (int i = 0; i < nr; i++) |
17040 | emulConstant(1, state.lda_ka_prefetch.getReg(i), |
17041 | state.inputs.lda, strategy.ka_pfStride, strategy, |
17042 | state); |
17043 | } else if (strategy.avoidIncConflicts) |
17044 | duplicateScalar(state.lda, state); |
17045 | } |
17046 | } |
17047 | |
    // Similarly for B if it is transposed (T layout).
17049 | if (doB && (problem.B.layout == MatrixLayout::T)) { |
17050 | if (!strategy.B.address2D) { |
17051 | if (kb_load > 1) { |
17052 | if (state.ldb_kb.isInvalid()) |
17053 | state.ldb_kb = allocIncrement(strategy, state); |
17054 | for (int i = 0; i < nr; i++) |
17055 | emulConstant(1, state.ldb_kb.getReg(i), state.inputs.ldb, |
17056 | kb_load, strategy, state); |
17057 | state.kb_cached = kb_load; |
17058 | } else if (strategy.avoidIncConflicts) |
17059 | duplicateScalar(state.ldb, state); |
17060 | } |
17061 | if (strategy.prefetchB && !strategy.B_prefetch.address2D |
17062 | && (strategy.kb_pfStride != kb_load || strategy.B.address2D)) { |
17063 | if (strategy.kb_pfStride > 1) { |
17064 | if (state.ldb_kb_prefetch.isInvalid()) |
17065 | state.ldb_kb_prefetch = allocIncrement(strategy, state); |
17066 | for (int i = 0; i < nr; i++) |
17067 | emulConstant(1, state.ldb_kb_prefetch.getReg(i), |
17068 | state.inputs.ldb, strategy.kb_pfStride, strategy, |
17069 | state); |
17070 | } else if (strategy.avoidIncConflicts) |
17071 | duplicateScalar(state.ldb, state); |
17072 | } |
17073 | } |
17074 | } |
17075 | |
17076 | template <HW hw> |
17077 | void gemm_kernel_generator_t<hw>::gemmDowngradeAccess( |
17078 | const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) { |
17079 | bool applyOffsetA = false, applyOffsetB = false; |
17080 | |
17081 | strategy.A.accessType = strategy.unalignedAccA; |
17082 | strategy.B.accessType = strategy.unalignedAccB; |
17083 | |
17084 | if (strategy.A.address2D && !isBlock2D(strategy.A.accessType)) { |
17085 | applyOffsetA = true; |
17086 | strategy.A.address2D = false; |
17087 | } |
17088 | |
17089 | if (strategy.B.address2D && !isBlock2D(strategy.B.accessType)) { |
17090 | applyOffsetB = true; |
17091 | strategy.B.address2D = false; |
17092 | } |
17093 | |
17094 | if (applyOffsetA || applyOffsetB) |
17095 | gemmOffsetABC(false, state.i0, state.j0, state.h0, problem, strategy, |
17096 | state, applyOffsetA, applyOffsetB, false); |
17097 | } |
17098 | |
17099 | template <HW hw> |
17100 | void gemm_kernel_generator_t<hw>::gemmSubkernel( |
17101 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState state) { |
17102 | Label labelSubkernelDone, labelSubkernelEarlyExit; |
17103 | |
17104 | status << "Begin subkernel: unroll " << strategy.unroll[LoopM] << 'x' |
17105 | << strategy.unroll[LoopN] << status_stream::endl; |
17106 | |
17107 | // Calculate remainders for m/n loops: clamp(m - i0, 0, unrollM), clamp(n - j0, 0, unrollN). |
17108 | // Careful with this clamping, because unroll may change in remainder handling. |
17109 | bool remM = (strategy.remHandling[LoopM] != RemainderHandling::Ignore); |
17110 | bool remN = (strategy.remHandling[LoopN] != RemainderHandling::Ignore); |
17111 | bool fusedremM = remM && strategy.fused && (strategy.fusedLoop == LoopM); |
17112 | bool fusedremN = remN && strategy.fused && (strategy.fusedLoop == LoopN); |
17113 | |
17114 | state.doLateExit = strategy.lateExit(); |
17115 | bool earlyExit = !state.doLateExit; |
17116 | |
17117 | if (fusedremM || fusedremN) { |
17118 | state.remFusedStorage = state.ra.alloc_sub<uint32_t>(); |
17119 | add(1, state.remFusedStorage, -state.fusedID, |
17120 | uint16_t(strategy.unroll[strategy.fusedLoop])); |
17121 | } |
17122 | if (remM || !earlyExit) { |
17123 | state.remaindersFused[LoopM] = state.remainders[LoopM] |
17124 | = state.ra.alloc_sub<uint32_t>( |
17125 | getHint(HintType::LongTerm, strategy)); |
17126 | InstructionModifier mod = 1 | sat; |
17127 | if (!fusedremM && earlyExit) mod = mod | le | f0[1]; |
17128 | add(mod, state.remainders[LoopM], -state.i0, state.inputs.m); |
17129 | } |
17130 | if (remN || !earlyExit) { |
17131 | state.remaindersFused[LoopN] = state.remainders[LoopN] |
17132 | = state.ra.alloc_sub<uint32_t>( |
17133 | getHint(HintType::LongTerm, strategy)); |
17134 | InstructionModifier mod = 1 | sat; |
17135 | if (!fusedremN && earlyExit) mod = mod | le | f1[1]; |
17136 | add(mod, state.remainders[LoopN], -state.j0, state.inputs.n); |
17137 | } |
17138 | if (fusedremM || fusedremN) { |
17139 | state.remaindersFused[strategy.fusedLoop] = state.remFusedStorage; |
17140 | add(1 | sat, state.remFusedStorage, -state.remFusedStorage, |
17141 | state.remainders[strategy.fusedLoop]); |
17142 | if (earlyExit) { |
17143 | cmp(1 | le | (fusedremM ? f0[1] : f1[1]), null.d(), |
17144 | state.remainders[strategy.fusedLoop].d(), -state.fusedID); |
17145 | state.allowEmptyC = true; |
17146 | } |
17147 | } |
17148 | if (remM) |
17149 | min_(1, state.remainders[LoopM], state.remainders[LoopM], |
17150 | uint16_t(strategy.unroll[LoopM])); |
17151 | if (remN) |
17152 | min_(1, state.remainders[LoopN], state.remainders[LoopN], |
17153 | uint16_t(strategy.unroll[LoopN])); |
17154 | |
17155 | gemmCalcIncrements(problem, strategy, state); |
17156 | |
17157 | // Early exit if nothing to do. Keep fused threads together. |
17158 | if (earlyExit && (remM || remN)) { |
17159 | InstructionModifier cond; |
17160 | if (remM && remN) |
17161 | cond = 1 | f0[1] | anyv; |
17162 | else if (remM) |
17163 | cond = 1 | f0[1]; |
17164 | else |
17165 | cond = 1 | f1[1]; |
17166 | |
17167 | if (state.fusedGEMM.active) |
17168 | and_(16 | nz | state.fusedGEMM.needLateGEMMDone, null.uw(), |
17169 | state.inputs.flags.uw(), FlagEarlyFusedGEMMDone); |
17170 | |
17171 | auto &label = state.fusedGEMM.active ? labelSubkernelEarlyExit |
17172 | : labelSubkernelDone; |
17173 | |
17174 | ejmpi(cond, label); |
17175 | } |
17176 | |
17177 | // Create the kernel body. If enabled, create two versions, one with A/B more aligned. |
17178 | bool success; |
17179 | if (!strategy.optAlignAB) |
17180 | success = gemmMEdge(problem, strategy, state); |
17181 | else { |
17182 | // Check alignment of effA, effB, lda, and ldb. |
17183 | Label labelUnaligned; |
17184 | uint16_t mask = (strategy.optAlignAB - 1); |
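        // problem.X.alignment is a static guarantee, so only attributes not
        // already known to be aligned get runtime tests of their low bits;
        // packed layouts skip the leading-dimension checks entirely.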
17185 | bool check_lda = !isPacked(problem.A.layout); |
17186 | bool check_ldb = !isPacked(problem.B.layout); |
17187 | if (problem.A.alignment & mask) { |
17188 | and_(1 | nz | f0[0], null.uw(), state.effA.uw(), mask); |
17189 | if (check_lda) |
17190 | and_(1 | nz | f1[0], null.uw(), state.inputs.lda.uw(), mask); |
17191 | } |
17192 | if (problem.B.alignment & mask) { |
17193 | and_(1 | nz | f0[1], null.uw(), state.effB.uw(), mask); |
17194 | if (check_ldb) |
17195 | and_(1 | nz | f1[1], null.uw(), state.inputs.ldb.uw(), mask); |
17196 | } |
17197 | if (problem.A.alignment & mask) { |
17198 | InstructionModifier amod = check_lda ? 1 | f0[0] | anyv : 1 | f0[0]; |
17199 | ejmpi(amod, labelUnaligned); |
17200 | } |
17201 | if (problem.B.alignment & mask) { |
17202 | InstructionModifier bmod = check_ldb ? 1 | f0[1] | anyv : 1 | f0[1]; |
17203 | ejmpi(bmod, labelUnaligned); |
17204 | } |
17205 | |
17206 | auto alignedProblem = problem; |
17207 | alignedProblem.A.setAlignment( |
17208 | std::max<int>(problem.A.alignment, strategy.optAlignAB)); |
17209 | alignedProblem.B.setAlignment( |
17210 | std::max<int>(problem.B.alignment, strategy.optAlignAB)); |
17211 | |
17212 | status << "Aligned A/B" << status_stream::endl; |
17213 | success = gemmMEdge(alignedProblem, strategy, state); |
17214 | |
17215 | if (!success && lastException) std::rethrow_exception(lastException); |
17216 | |
17217 | state.isNested ? jmpi(1, labelSubkernelDone) |
17218 | : epilogue(strategy, state); |
17219 | |
17220 | mark(labelUnaligned); |
17221 | |
17222 | auto modStrategy = strategy; |
17223 | |
17224 | gemmDowngradeAccess(problem, modStrategy, state); |
17225 | |
17226 | status << "Unaligned A/B" << status_stream::endl; |
17227 | if (!gemmMEdge(problem, modStrategy, state)) { |
17228 | modStrategy.checkAdd32 |
17229 | = false; // Don't optimize additions on this (slow) path to reduce code size. |
17230 | status << "Reducing register usage" << status_stream::endl; |
17231 | success = success && modStrategy.minimize(hw, problem); |
17232 | |
17233 | gemmCalcIncrements(problem, modStrategy, |
17234 | state); // Recalculate lda_ka/ldb_kb as they have changed. |
17235 | |
17236 | success = success && gemmMEdge(problem, modStrategy, state); |
17237 | } |
17238 | } |
17239 | |
17240 | if (!success) |
17241 | lastException ? std::rethrow_exception(lastException) |
                      : throw std::runtime_error("Could not generate kernel.");
17243 | |
17244 | mark(labelSubkernelDone); |
17245 | |
17246 | if (state.fusedGEMM.active) { |
17247 | mov(1, state.fusedGEMM.needLateGEMMDone, 0); |
17248 | mark(labelSubkernelEarlyExit); |
17249 | } |
17250 | |
17251 | safeRelease(state.lda_ka, state); |
17252 | safeRelease(state.ldb_kb, state); |
17253 | safeRelease(state.lda_ka_prefetch, state); |
17254 | safeRelease(state.ldb_kb_prefetch, state); |
17255 | } |
17256 | |
17257 | template <HW hw> |
17258 | void gemm_kernel_generator_t<hw>::gemmSuperkernelInitState( |
17259 | GEMMSuperkernelProblem &problem, GEMMSuperkernelStrategy &strategy, |
17260 | GEMMSuperkernelState &state) { |
17261 | if (strategy.persistent) interface.requireGlobalAtomics(); |
17262 | |
17263 | gemmInitState(problem, strategy.substrategies[0], state, true); |
17264 | |
17265 | state.isNested |= strategy.persistent; |
17266 | |
    state.inputsSK.surfacePlan = interface.getArgumentSurface("plan");
    state.inputsSK.planCount = interface.getArgument("plan_count");
17269 | state.inputsSK.localID = interface.getLocalID(0); |
17270 | state.inputsSK.localSize = interface.getLocalSize(0); |
17271 | |
17272 | state.ra.claim(state.inputsSK.localID); |
17273 | state.ra.claim(state.inputsSK.localSize); |
17274 | state.ra.claim(state.inputsSK.planCount); |
17275 | } |
17276 | |
17277 | // Create a GEMM superkernel. |
17278 | template <HW hw> |
17279 | void gemm_kernel_generator_t<hw>::gemmSuperkernel( |
17280 | GEMMSuperkernelProblem problem, GEMMSuperkernelStrategy strategy, |
17281 | const InterfaceHandler &interface_) { |
17282 | auto &strategy0 = strategy.substrategies[0]; |
17283 | bool persistent = strategy.persistent; |
17284 | |
17285 | GEMMSuperkernelState state(hw); |
17286 | |
17287 | // Set up. |
17288 | setDefaultNoMask(); |
17289 | setDefaultAutoSWSB(); |
17290 | interface = interface_; |
17291 | gemmSuperkernelInitState(problem, strategy, state); |
17292 | state.ra.safeRelease(state.inputs.localIDN); |
17293 | state.ra.safeRelease(state.inputs.localSizeN); |
17294 | |
17295 | for (auto &ss : strategy.substrategies) { |
17296 | if (!ss.A.base.isStateless()) ss.A.base.setIndex(state.inputs.surfaceA); |
17297 | if (!ss.B.base.isStateless()) ss.B.base.setIndex(state.inputs.surfaceB); |
17298 | if (!ss.C.base.isStateless()) |
17299 | ss.C.base.setIndex(state.inputs.surfaceC[0]); |
17300 | } |
17301 | |
17302 | // Prevent unhelpful layouts. |
17303 | if (problem.A.layout == MatrixLayout::PackedRows) stub(); |
17304 | if (problem.B.layout == MatrixLayout::PackedColumns) stub(); |
17305 | |
17306 | Label loopSK, loopSKEnd; |
17307 | |
17308 | // Prologue. |
17309 | prologue(strategy0); |
17310 | |
17311 | // Grab fused ID if needed. |
17312 | getFusedID(1, problem, strategy0, state); |
17313 | |
17314 | // Get my plan ID and convert to offset in plan. |
17315 | auto idX = r0.ud(1); |
    auto header = state.ra.alloc();
17317 | auto poff = header.ud(2); |
17318 | constexpr uint16_t eltSz = 8; |
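    // Each plan entry is eltSz = 8 bytes: (i0, j0) stored as two dwords.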
17319 | |
17320 | auto temp = state.ra.alloc_sub<uint32_t>(); |
17321 | |
17322 | mulConstant(1, temp, state.inputsSK.planCount, strategy.subgroupSize()); |
17323 | mul(1, poff, idX, state.inputsSK.localSize); |
17324 | add(1, poff, poff, state.inputsSK.localID.uw(0)); |
17325 | cmp<uint32_t>(1 | ge | f0[0], poff, temp); |
17326 | if (eltSz < strategy.subgroupSize()) |
17327 | shr(1, poff, poff, log2(strategy.subgroupSize() / eltSz)); |
17328 | else if (eltSz > strategy.subgroupSize()) |
17329 | mulConstant(1, poff, poff, eltSz / strategy.subgroupSize()); |
17330 | |
17331 | state.ra.safeRelease(temp); |
17332 | state.ra.safeRelease(state.inputsSK.localID); |
17333 | state.ra.safeRelease(state.inputsSK.localSize); |
17334 | |
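    // In persistent mode, the first plan entry holds the shared plan counter
    //  (atomically incremented below), so data entries begin at offset eltSz.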
17335 | if (persistent) add(1, poff, poff, eltSz); |
17336 | |
17337 | // Move r0 to acc0 if configured. |
17338 | moveR0(strategy0, state); |
17339 | |
17340 | // Quick exit for extra threads (uniform WG). |
17341 | jmpi(1 | f0[0], loopSKEnd); |
17342 | |
17343 | // Retrieve plan element. |
17344 | auto pdata = state.ra.alloc(getHint(HintType::TempComp0, strategy0)); |
17345 | load(8, pdata, aligned_block_oword(1), Surface(state.inputsSK.surfacePlan), |
17346 | header); |
17347 | state.ra.safeRelease(header); |
17348 | |
17349 | gemmScaleInputs(problem, strategy0, state); // Scale inputs while waiting. |
17350 | |
17351 | state.i0 = pdata.d(0); |
17352 | state.j0 = pdata.d(1); |
17353 | |
17354 | state.ra.safeRelease(pdata); |
17355 | state.ra.claim(state.i0); |
17356 | state.ra.claim(state.j0); |
17357 | |
17358 | auto flagKID0 = f1[0]; |
17359 | auto flagKID1 = f1[1]; |
17360 | |
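    // The sign bits of (i0, j0) select the subkernel (multiM/multiN): latch
    //  them into flags, then clear both sign bits with a single SIMD2 and_.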
17361 | if (strategy.multiM) cmp(1 | lt | flagKID0, null.d(), state.i0, 0); |
17362 | if (strategy.multiN) cmp(1 | lt | flagKID1, null.d(), state.j0, 0); |
17363 | and_(2, state.i0.ud()(1), state.i0.ud()(1), uint32_t(0x7FFFFFFF)); |
17364 | |
17365 | // Initial offset of A/B/C. |
17366 | gemmOffsetABC( |
17367 | true, state.i0, state.j0, Subregister(), problem, strategy0, state); |
17368 | gemmSetupABC(problem, strategy0, state); |
17369 | |
17370 | // Save i0, j0 for later. |
17371 | state.last_i0 = state.ra.alloc_sub<int32_t>( |
17372 | getHint(HintType::LongTerm, strategy0)); |
17373 | state.last_j0 = state.ra.alloc_sub<int32_t>( |
17374 | getHint(HintType::LongTerm, strategy0)); |
17375 | mov(1, state.last_i0, state.i0); |
17376 | mov(1, state.last_j0, state.j0); |
17377 | |
17378 | // Top of superkernel loop. |
17379 | status << "Begin superkernel loop" << status_stream::endl; |
17380 | mark(loopSK); |
17381 | { |
17382 | // Dispatch appropriate kernel, supporting up to 4 subkernels. |
17383 | int kidx = 0; |
17384 | Label labelM1, labelM0N1, labelM1N1, labelKernelDone; |
17385 | if (strategy.multiM) jmpi(1 | flagKID0, labelM1); |
17386 | if (strategy.multiN) jmpi(1 | flagKID1, labelM0N1); |
17387 | |
17388 | gemmSubkernel(problem, strategy.substrategies[kidx++], state); |
17389 | |
17390 | if (strategy.multiN) { |
17391 | jmpi(1, labelKernelDone); |
17392 | mark(labelM0N1); |
17393 | gemmSubkernel(problem, strategy.substrategies[kidx++], state); |
17394 | } |
17395 | |
17396 | if (strategy.multiM) { |
17397 | jmpi(1, labelKernelDone); |
17398 | |
17399 | mark(labelM1); |
17400 | if (strategy.multiN) jmpi(1 | flagKID1, labelM1N1); |
17401 | |
17402 | gemmSubkernel(problem, strategy.substrategies[kidx++], state); |
17403 | |
17404 | if (strategy.multiN) { |
17405 | jmpi(1, labelKernelDone); |
17406 | mark(labelM1N1); |
17407 | gemmSubkernel(problem, strategy.substrategies[kidx++], state); |
17408 | } |
17409 | } |
17410 | |
17411 | mark(labelKernelDone); |
17412 | |
17413 | if (persistent) { |
17414 | // Get next plan element via atomic increment of plan ID counter. |
            auto header = state.ra.alloc();
17416 | auto nextID |
17417 | = state.ra.alloc(getHint(HintType::TempComp1, strategy0)); |
17418 | auto pdata |
17419 | = state.ra.alloc(getHint(HintType::TempComp0, strategy0)); |
17420 | |
17421 | mov<uint32_t>(8, header, uint16_t(0)); |
17422 | atomic(AtomicOp::inc, 1, nextID, scattered_dword(), |
17423 | Surface(state.inputsSK.surfacePlan), header); |
17424 | |
17425 | // Load next plan element, or exit if no more work. |
17426 | mulConstant<uint32_t>(1, header[2], nextID[0], eltSz); |
17427 | cmp<uint32_t>( |
17428 | 1 | ge | f0[0], null, nextID[0], state.inputsSK.planCount); |
17429 | add<uint32_t>(1, header[2], header[2], eltSz); |
17430 | |
17431 | jmpi(1 | f0[0], loopSKEnd); |
17432 | |
17433 | load(8, pdata, aligned_block_oword(1), |
17434 | Surface(state.inputsSK.surfacePlan), header); |
17435 | state.ra.safeRelease(header); |
17436 | state.ra.safeRelease(nextID); |
17437 | |
17438 | // Load next (i0, j0) and kernel IDs. |
17439 | auto in_i0 = pdata.d(0); |
17440 | auto in_j0 = pdata.d(1); |
17441 | |
17442 | if (strategy.multiM) cmp(1 | lt | flagKID0, null.d(), in_i0, 0); |
17443 | if (strategy.multiN) cmp(1 | lt | flagKID1, null.d(), in_j0, 0); |
17444 | and_(1, state.i0.ud(), in_i0.ud(), uint32_t(0x7FFFFFFF)); |
17445 | and_(1, state.j0.ud(), in_j0.ud(), uint32_t(0x7FFFFFFF)); |
17446 | |
17447 | // Get difference in i0 and j0... |
17448 | add(1, in_i0, state.i0, -state.last_i0); |
17449 | add(1, in_j0, state.j0, -state.last_j0); |
17450 | |
17451 | // ... save current (i0, j0) for later... |
17452 | mov(1, state.last_i0, state.i0); |
17453 | mov(1, state.last_j0, state.j0); |
17454 | |
17455 | // ...and offset A, B, C appropriately. |
17456 | gemmOffsetABC(false, in_i0, in_j0, Subregister(), problem, |
17457 | strategy0, state); |
17458 | |
17459 | state.ra.safeRelease(pdata); |
17460 | |
17461 | state.ra.safeRelease(state.i0); |
17462 | state.ra.safeRelease(state.j0); |
17463 | |
17464 | // Ready for the next kernel. |
17465 | jmpi(1, loopSK); |
17466 | } |
17467 | } |
17468 | mark(loopSKEnd); |
17469 | |
17470 | epilogue(strategy.substrategies[0], state); |
17471 | padding(); |
17472 | } |
17473 | |
17474 | // Get driver information from this strategy. |
17475 | template <HW hw> |
17476 | CommonDriverInfo gemm_kernel_generator_t<hw>::driverInfo( |
17477 | const GEMMProblem &problem, const GEMMStrategy &strategy) { |
17478 | CommonDriverInfo info; |
17479 | |
17480 | info.subgroupSize = strategy.subgroupSize; |
17481 | info.fusedLoop = strategy.fused ? strategy.fusedLoop : LoopNone; |
17482 | info.grfCount = strategy.GRFs; |
17483 | for (int d = 0; d < 3; d++) { |
17484 | info.loopOrder[d] = strategy.loopOrder[d]; |
17485 | info.blocking[d] = strategy.blocking[d]; |
17486 | info.blockingAlt[d] = strategy.blockingAlt[d]; |
17487 | info.unroll[d] = strategy.unroll[d]; |
17488 | info.wg[d] = strategy.wg[d]; |
17489 | } |
17490 | info.wgExpand = strategy.splitCopy ? 2 : 1; |
17491 | if (strategy.hilbertOrder) { |
17492 | info.loopOrder[0] = (info.loopOrder[0] == LoopN) ? LoopMNHilbertNMK |
17493 | : LoopMNHilbertMNK; |
17494 | info.loopOrder[1] = LoopNone; |
17495 | } else if (strategy.boustrophedon) { |
17496 | info.loopOrder[0] = (info.loopOrder[0] == LoopN) |
17497 | ? LoopMNBoustrophedonNMK |
17498 | : LoopMNBoustrophedonMNK; |
17499 | info.loopOrder[1] = LoopNone; |
17500 | } |
17501 | if (strategy.persistent) |
17502 | info.loopOrder[0] |
17503 | = static_cast<LoopType>(info.loopOrder[0] | LoopPersistent); |
17504 | if (problem.batch == BatchMode::None && !strategy.kParallelLocal) |
17505 | info.loopOrder[2] = LoopNone; |
17506 | info.wgUpdate = strategy.getWGType(problem); |
17507 | info.kRemainderHandling |
17508 | = (strategy.remHandling[LoopK] != RemainderHandling::Ignore); |
17509 | info.kParallel = strategy.kParallel; |
17510 | info.kParallelLocal = strategy.kParallelLocal; |
17511 | info.slm = int(gemmSLMSize(problem, strategy)); |
17512 | info.perKSLM = int(gemmPerKSLMSize(problem, strategy)); |
17513 | info.alignment[0] = problem.A.alignment; |
17514 | info.alignment[1] = problem.B.alignment; |
17515 | info.alignment[2] = problem.C.alignment; |
17516 | info.support4GB[0] = (strategy.A.base.getModel() == ModelA64); |
17517 | info.support4GB[1] = (strategy.B.base.getModel() == ModelA64); |
17518 | info.support4GB[2] = (strategy.C.base.getModel() == ModelA64); |
17519 | |
17520 | return info; |
17521 | } |
17522 | |
17523 | template <HW hw> |
17524 | CommonDriverInfo gemm_kernel_generator_t<hw>::driverInfo( |
17525 | const GEMMSuperkernelProblem &problem, const GEMMStrategy &strategy) { |
17526 | auto info = driverInfo(static_cast<GEMMProblem>(problem), strategy); |
17527 | return info; |
17528 | } |
17529 | |
17530 | // Return the maximum possible k size for copied SLM data. |
17531 | int GEMMStrategy::maxKSLM(const GEMMProblem &problem, bool isA) const { |
17532 | return unrollKSLM; |
17533 | } |
17534 | |
17535 | // Validate a GEMM strategy, correcting settings as necessary. |
17536 | void GEMMStrategy::preflight(HW hw, const GEMMProblem &problem) { |
17537 | auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc; |
17538 | auto Ta_real = Ta.real(); |
17539 | auto Tb_real = Tb.real(); |
17540 | auto Tc_real = Tc.real(); |
17541 | |
17542 | // Addressing preflight. |
17543 | |
17544 | if (C.atomic && !C.base.isStateless() && !C.newDP) C.forceA64(); |
17545 | |
17546 | slmA &= (slmBuffers > 0); |
17547 | slmB &= (slmBuffers > 0); |
17548 | |
17549 | A.preflight(hw); |
17550 | B.preflight(hw); |
17551 | C.preflight(hw); |
17552 | A_prefetch.preflight(hw); |
17553 | B_prefetch.preflight(hw); |
17554 | C_prefetch.preflight(hw); |
17555 | |
17556 | bool globalCM = isRegisterColMajor(problem.Tc, problem.C, C); |
17557 | |
17558 | // Default SIMD setting. |
17559 | if (fmaSIMD == 0) { |
17560 | fmaSIMD = std::min(32, |
17561 | 2 * GRF::bytes(hw) |
17562 | / std::max<int>({Ta.size(), Tb.size(), Tc.size()})); |
17563 | if (hw == HW::Gen9 && Ta_real.size() == 1 && Tb_real.size() == 1 |
17564 | && Tc_real.size() == 4) |
17565 | fmaSIMD = 32; |
17566 | } |
17567 | |
17568 | slmFenceWARWA |= (hw == HW::XeHPG); |
17569 | |
17570 | if (problem.batch != BatchMode::None) { |
17571 | persistent = false; |
17572 | kParallel = false; |
17573 | } |
17574 | |
17575 | if (coopA == CoopSplit::K && slmATrans) coopA = CoopSplit::MN; |
17576 | if (coopB == CoopSplit::K && slmBTrans) coopB = CoopSplit::MN; |
17577 | |
17578 | checkBeta1 |= C.atomic && !problem.beta1(); |
17579 | |
17580 | // Fixed systolic kernel handling. |
17581 | if (fixedSystolic) { |
17582 | if (wg[LoopM] == 0) wg[LoopM] = 4; |
17583 | if (wg[LoopN] == 0) wg[LoopN] = 4; |
17584 | bool doubleM = (wg[LoopM] == 8); |
17585 | |
17586 | slmCopies = (slmCopies == 3) ? 3 : 1; |
17587 | slmBuffers = (splitCopy || doubleM) ? 4 : 3; |
17588 | slmA = slmB = true; |
17589 | GRFs = 256; |
17590 | altCRemainder = false; |
17591 | loopOrder[0] = LoopM; |
17592 | loopOrder[1] = LoopN; |
17593 | loopOrder[2] = LoopK; |
17594 | A.accessType = B.accessType = AccessType::Block; |
17595 | ka_load = kb_load = 32 / Ta_real; |
17596 | dpasw = true; |
17597 | } |
17598 | |
17599 | dpasw &= fused; |
17600 | |
    // Accumulator usage: 64-bit emulation, or k chaining, or extra C registers, or storage for r0 header.
    // Priority: k chaining > extra C registers > r0 header storage.
    //           64-bit emulation > r0 header storage.
17604 | if (hw <= HW::Gen9) kChain = 1; |
17605 | cAccumulators &= (kChain == 1); |
17606 | |
17607 | bool emulateNeedsAcc = emulate.emulate64 || emulate.emulateDWxDW; |
17608 | if (moveR0 == MoveR0::Acc) |
17609 | if (cAccumulators || emulateNeedsAcc || xParallel || (kChain > 1) |
17610 | || barrierFreq) |
17611 | moveR0 = MoveR0::None; |
17612 | |
17613 | // Mixed mode restrictions: |
17614 | // - mixed hf/f is max SIMD 8 on Gen9 |
17615 | // - mixed hf/f is not allowed on Gen12 |
17616 | // - mixed bf/f is max SIMD 8 on ATS+ |
17617 | if ((Tc_real == Type::f32) |
17618 | && (Ta_real != Type::f32 || Tb_real != Type::f32)) |
17619 | fmaSIMD = std::min(fmaSIMD, GRF::bytes(hw) >> 2); |
17620 | |
    // Paths without jump tables use SIMT control flow, as do atomic
    //  reductions; disable single program flow (SPF) in both cases.
17622 | spf &= !noJumpTables; |
17623 | spf &= !C.atomic; |
17624 | |
17625 | checkAdd32 &= !emulate.emulate64_add32; |
17626 | checkAdd32 &= (A.base.isStateless() || B.base.isStateless()); |
17627 | checkAdd32 &= !(A.address2D && B.address2D |
17628 | && (!prefetchA || A_prefetch.address2D) |
17629 | && (!prefetchB || B_prefetch.address2D)); |
17630 | |
17631 | int opCount = outerProductCount(hw, problem, *this); |
17632 | int minOPCount = minOuterProductCount(hw, problem, *this); |
17633 | int ukAlign = opCount; |
17634 | |
17635 | if (kParallelLocal) moveR0 = MoveR0::None; |
17636 | |
17637 | // SLM copy logic. |
17638 | int slmVersions = std::max(1, lcm(slmCopies, slmBuffers)); |
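    //  (slmVersions: distinct SLM buffer/copy phases in the copy pipeline.)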
17639 | if (slmBuffers > 0) { |
17640 | moveR0 = MoveR0::None; |
17641 | barrierFreq = 0; |
17642 | if (wg[LoopM] <= 0 || wg[LoopN] <= 0) |
            throw std::runtime_error("Workgroup sizes required.");
17644 | if (slmA) ukAlign = lcm(ukAlign, wg[LoopN] * slmVersions); |
17645 | if (slmB) ukAlign = lcm(ukAlign, wg[LoopM] * slmVersions); |
17646 | slmUseIncrCopy &= (slmCopies == 1); |
17647 | } |
17648 | |
    // ka/kb_load wrangling.
17650 | if (ka_load_masked == 0) ka_load_masked = ka_load; |
17651 | if (kb_load_masked == 0) kb_load_masked = kb_load; |
17652 | |
17653 | if (!slmA) { |
17654 | ka_load = align_up(ka_load, opCount); |
17655 | ka_load_masked = align_up(ka_load_masked, minOPCount); |
17656 | } |
17657 | if (!slmB) { |
17658 | kb_load = align_up(kb_load, opCount); |
17659 | kb_load_masked = align_up(kb_load_masked, minOPCount); |
17660 | } |
17661 | |
17662 | // Systolic handling. |
17663 | if (systolic) { |
17664 | auto params = systolicParams(hw, problem, *this); |
17665 | |
17666 | ukAlign = lcm(ukAlign, params.ksys); |
17667 | auto tileX = params.osys; |
17668 | (globalCM ? C.tileR : C.tileC) = tileX; |
17669 | if (unroll[globalCM ? LoopM : LoopN] > tileX) forceCopyC = true; |
17670 | } |
17671 | |
17672 | // Prefetch handling. |
17673 | cooperativePF &= (prefetchA || prefetchB); |
17674 | |
17675 | if (problem.beta0()) prefetchC = 0; |
17676 | |
17677 | // Propagate tiling requests to strategy. |
17678 | int tileM_A, tileK_A, tileK_B, tileN_B; |
17679 | std::tie(tileM_A, tileK_A, tileK_B, tileN_B) |
17680 | = targetKernelTiling(hw, problem, *this); |
17681 | if (A.accessType != AccessType::Block) { |
17682 | if (tileM_A && !A.tileR) A.tileR = tileM_A; |
17683 | if (tileK_A && !A.tileC) A.tileC = tileK_A; |
17684 | } |
17685 | if (B.accessType != AccessType::Block) { |
17686 | if (tileK_B && !B.tileR) B.tileR = tileK_B; |
17687 | if (tileN_B && !B.tileC) B.tileC = tileN_B; |
17688 | } |
17689 | |
17690 | if (dpasw) { |
17691 | auto params = systolicParams(hw, problem, *this); |
17692 | if (globalCM) { |
17693 | if (!fusedM()) stub(); |
17694 | B.dpasw = true; |
17695 | B.tileC = std::max( |
17696 | 1, std::min(unroll[LoopN], params.rcountMax) / 2); |
17697 | } else { |
17698 | if (!fusedN()) stub(); |
17699 | A.dpasw = true; |
17700 | A.tileR = std::max( |
17701 | 1, std::min(unroll[LoopM], params.rcountMax) / 2); |
17702 | } |
17703 | } |
17704 | |
17705 | // Always use 1D addressing for packed inputs. |
17706 | A.address2D &= !isPacked(problem.A.layout); |
17707 | B.address2D &= !isPacked(problem.B.layout); |
17708 | |
17709 | // k unroll wrangling. |
17710 | ukAlign = lcm(ukAlign, A_copies * ka_load); |
17711 | ukAlign = lcm(ukAlign, B_copies * kb_load); |
17712 | if (slmCopies > 1) { |
17713 | ukAlign = lcm(ukAlign, slmCopies * ka_load); |
17714 | ukAlign = lcm(ukAlign, slmCopies * kb_load); |
17715 | } |
17716 | if (ka_pfStride) ukAlign = lcm(ukAlign, ka_pfStride); |
17717 | if (kb_pfStride) ukAlign = lcm(ukAlign, kb_pfStride); |
17718 | |
17719 | int minUnrollKSLM = 1; |
17720 | if (unrollKSLM > 0) |
17721 | minUnrollKSLM = unrollKSLM; |
17722 | else { |
17723 | if (slmA) minUnrollKSLM = lcm(minUnrollKSLM, ka_load); |
17724 | if (slmB) minUnrollKSLM = lcm(minUnrollKSLM, kb_load); |
17725 | } |
17726 | |
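    // The k unroll must also span a whole number of SLM version cycles.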
17727 | ukAlign = align_up(ukAlign, minUnrollKSLM * slmVersions); |
17728 | |
17729 | unroll[LoopK] = align_up(unroll[LoopK], ukAlign); |
17730 | barrierFreq = align_up(barrierFreq, unroll[LoopK]); |
17731 | |
17732 | if (unrollKSLM == 0) unrollKSLM = unroll[LoopK] / slmVersions; |
17733 | |
17734 | if (fixedSystolic) unroll[LoopK] = unrollKSLM = 32 / Ta_real; |
17735 | |
17736 | barrierFreq = align_up(barrierFreq, unroll[LoopK]); |
17737 | |
17738 | int kChunkA = (problem.A.tileC ? problem.A.tileC : problem.A.crosspack); |
17739 | int kChunkB = (problem.B.tileR ? problem.B.tileR : problem.B.crosspack); |
17740 | if (unroll[LoopK] <= std::min(kChunkA, kChunkB)) |
17741 | remHandling[LoopK] = RemainderHandling::Ignore; |
17742 | |
17743 | // Default blocking. |
17744 | bool isZ = problem.Tc.size() >= 16; |
17745 | auto defaultMBlock = isZ ? 2048 : 4096; |
17746 | if (hw >= HW::XeHP) defaultMBlock *= 2; |
17747 | auto defaultNBlock = defaultMBlock; |
17748 | auto defaultMNBlockNonHilbert = defaultMBlock; |
17749 | |
17750 | /* No more than (2^16 - 1) workgroups in m/n dimensions for linear orders, plus a huge safety margin. */ |
17751 | if (linearOrder()) { |
17752 | defaultMBlock = 16384 * unroll[LoopM]; |
17753 | defaultNBlock = 16384 * unroll[LoopN]; |
17754 | } |
17755 | |
17756 | if (blocking[LoopM] <= 0) blocking[LoopM] = defaultMBlock; |
17757 | if (blocking[LoopN] <= 0) blocking[LoopN] = defaultNBlock; |
17758 | if (blocking[LoopK] <= 0) { |
17759 | int points = 1; |
17760 | if (slmA || (problem.A.layout != MatrixLayout::T)) points++; |
17761 | if (slmB || (problem.B.layout != MatrixLayout::N)) points++; |
17762 | blocking[LoopK] = std::min(2048, (2048 * points) / problem.Ta); |
17763 | } |
17764 | |
17765 | auto defaultBlockAltK = blocking[LoopK]; |
17766 | if (hw < HW::XeHPC) |
17767 | if (hw >= HW::XeHP) defaultBlockAltK = std::min(defaultBlockAltK, 1024); |
17768 | |
17769 | if (blockingAlt[LoopM] <= 0) blockingAlt[LoopM] = defaultMNBlockNonHilbert; |
17770 | if (blockingAlt[LoopN] <= 0) blockingAlt[LoopN] = defaultMNBlockNonHilbert; |
17771 | if (blockingAlt[LoopK] <= 0) blockingAlt[LoopK] = defaultBlockAltK; |
17772 | |
17773 | // Default workgroups. |
17774 | auto defaultWGX = 2, defaultWGY = 8; |
17775 | |
17776 | if (wg[loopOrder[0]] <= 0) wg[loopOrder[0]] = defaultWGX; |
17777 | if (wg[loopOrder[1]] <= 0) wg[loopOrder[1]] = defaultWGY; |
17778 | if (wg[LoopK] <= 0) { |
17779 | if (kParallelLocal) |
17780 | wg[LoopK] = (threadsPerEU(hw, *this) * eusPerSubslice(hw)) |
17781 | / (wg[LoopM] * wg[LoopN]); |
17782 | else |
17783 | wg[LoopK] = 1; |
17784 | } |
17785 | |
17786 | kParallelLocal &= (wg[LoopK] > 1); |
17787 | |
17788 | skewLocalIDs &= (wg[LoopM] * wg[LoopN] > eusPerSubslice(hw)); |
17789 | |
17790 | if (skewLocalIDs) forceWGUpdate = WGFixed; |
17791 | |
17792 | avoidIncConflicts &= (hw >= HW::XeHP); |
17793 | |
17794 | CommonStrategy::preflight(hw, problem); |
17795 | } |
17796 | |
17797 | // Reduce register pressure. Returns true if successful. |
17798 | bool GEMMStrategy::minimize(HW hw, const GEMMProblem &problem) { |
17799 | bool better = false; |
17800 | auto minOPCount = minOuterProductCount(hw, problem, *this); |
17801 | auto ka_load_best_min = std::max<int>({1, 4 / problem.Ta, minOPCount}); |
17802 | auto kb_load_best_min = std::max<int>({1, 4 / problem.Tb, minOPCount}); |
17803 | |
    // Reduce ka/b_load down to suggested minimums (not requiring crosspack).
17805 | if (ka_load > ka_load_best_min) { |
17806 | ka_load = ka_load_best_min; |
17807 | better = true; |
17808 | } |
17809 | if (kb_load > kb_load_best_min) { |
17810 | kb_load = kb_load_best_min; |
17811 | better = true; |
17812 | } |
17813 | |
17814 | // Reduce A/B copies. |
17815 | A_copies = B_copies = 1; |
17816 | |
17817 | // Remove k chaining. |
17818 | kChain = 1; |
17819 | |
17820 | // Reduce k unroll for SLM copies. |
17821 | if (slmA || slmB) { |
17822 | auto oldUK = unroll[LoopK]; |
17823 | unroll[LoopK] = 1; |
17824 | unrollKSLM = 0; |
17825 | preflight(hw, problem); |
17826 | better |= (unroll[LoopK] < oldUK); |
17827 | } |
17828 | |
17829 | if (better) return better; |
17830 | |
17831 | // Reduce ka/b_load to absolute minimum if that failed. |
17832 | if (ka_load > minOPCount) { |
17833 | ka_load = minOPCount; |
17834 | better = true; |
17835 | } |
17836 | if (kb_load > minOPCount) { |
17837 | kb_load = minOPCount; |
17838 | better = true; |
17839 | } |
17840 | |
17841 | return better; |
17842 | } |
17843 | |
17844 | // Validate a GEMM superkernel strategy, correcting settings as necessary. |
17845 | void GEMMSuperkernelStrategy::preflight(HW hw, const GEMMProblem &problem) { |
    if (substrategies.empty())
        throw std::runtime_error("No substrategies for superkernel.");
17848 | auto subgroupSize = substrategies[0].subgroupSize; |
17849 | for (auto &ss : substrategies) { |
17850 | ss.insideSK = true; |
17851 | ss.preflight(hw, problem); |
17852 | if (ss.subgroupSize != subgroupSize) |
            throw std::runtime_error("Incompatible subgroup sizes.");
17854 | } |
17855 | } |
17856 | |
17857 | void MatrixAddressingStrategy::preflight(HW hw) { |
17858 | newDP |= isBlock2D(accessType); |
17859 | if (prefetch && newDP && cachingR == CacheSettingsLSC::Default) |
17860 | cachingR = CacheSettingsLSC::L1C_L3C; |
17861 | |
17862 | if (accessType == AccessType::ChannelScattered && base.isStateless() |
17863 | && !newDP) |
17864 | base = AddressBase::createBTS(0); |
17865 | } |
17866 | |
17867 | void MatrixAddressingStrategy::forceA64() { |
17868 | base = AddressBase::createA64(true); |
17869 | if (accessType == AccessType::ChannelScattered && !newDP) |
17870 | accessType = AccessType::Scattered; |
17871 | } |
17872 | |
17873 | /**********************************************************************/ |
17874 | /* Fixed Systolic GEMM (XeHP/XeHPG) */ |
17875 | /**********************************************************************/ |
17876 | namespace sysgemm { |
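// Fixed GRF assignments for the systolic kernel. Note that the extra copy
//  ranges (A_copy1/2, B_copy1/2) overlap C_regs.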
17877 | static GRFRange A_copy0 = GRF(40) - GRF(47); |
17878 | static GRFRange B_copy0 = GRF(2) - GRF(13); |
17879 | static GRFRange A_regs = GRF(48) - GRF(63); |
17880 | static GRFRange B_regs = GRF(14) - GRF(37); |
17881 | static GRFRange C_regs = GRF(64) - GRF(255); |
17882 | static GRFRange A_copy1 = GRF(96) - GRF(103); |
17883 | static GRFRange B_copy1 = GRF(104) - GRF(111); |
17884 | static GRFRange A_copy2 = GRF(144) - GRF(151); |
17885 | static GRFRange B_copy2 = GRF(152) - GRF(159); |
17886 | static GRFRange A_copy[3] = {A_copy0, A_copy1, A_copy2}; |
17887 | static GRFRange B_copy[3] = {B_copy0, B_copy1, B_copy2}; |
17888 | static GRF addr0 = GRF(1); |
17889 | static GRF addr1 = GRF(38); |
17890 | static GRF addr2 = GRF(39); |
17891 | static GRF addr3 = GRF(0); |
17892 | static Subregister A_ptr64 = addr1.uq(3); |
17893 | static Subregister B_ptr64 = addr2.uq(3); |
17894 | static Subregister C_ptr64 = addr2.uq(2); |
17895 | static Subregister slmAOffsetLoad = addr1.uw(8); // offsets in OWords |
17896 | static Subregister slmBOffsetLoad = addr1.uw(9); |
17897 | static Subregister slmAOffsetStore = addr1.uw(10); |
17898 | static Subregister slmBOffsetStore = addr1.uw(11); |
17899 | static Subregister slmAOffsetLoadInit = addr1.uw(6); |
17900 | static Subregister slmBOffsetLoadInit = addr1.uw(7); |
17901 | static Subregister slmAOffsetStoreInit = addr2.uw(6); |
17902 | static Subregister slmBOffsetStoreInit = addr2.uw(7); |
17903 | static Subregister kCounter = AccumulatorRegister(2).d(0); |
17904 | static Subregister barrierVal = AddressRegister(0).ud(0); |
17905 | static constexpr int accStride = 48; |
17906 | } // namespace sysgemm |
17907 | |
17908 | template <HW hw> |
17909 | bool gemm_kernel_generator_t<hw>::sysgemmAccumulateC( |
17910 | GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state) { |
17911 | using namespace sysgemm; |
17912 | auto params = systolicParams(hw, problem, strategy); |
17913 | auto unrollM = strategy.unroll[LoopM]; |
17914 | auto unrollN = strategy.unroll[LoopN]; |
17915 | auto wgM = strategy.wg[LoopM]; |
17916 | auto wgN = strategy.wg[LoopN]; |
17917 | auto localIDM = state.lidM; |
17918 | auto localIDN = state.lidN; |
17919 | bool doubleM = (wgM == 8); |
17920 | bool surfaceAB = !strategy.A.base.isStateless(); |
17921 | bool surfaceC = !strategy.C.base.isStateless(); |
17922 | |
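    // Fixed systolic preconditions: 32x{32,48} C tiles, wgM in {4,8},
    //  wgN == 4, and A/B packed to match the systolic parameters; anything
    //  else stubs out.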
17923 | if (unrollM != 32) stub(); |
17924 | if (unrollN != 32 && unrollN != 48) stub(); |
17925 | if (wgM != 4 && wgM != 8) stub(); |
17926 | if (wgN != 4) stub(); |
17927 | if (strategy.A.base.getModel() != strategy.B.base.getModel()) stub(); |
17928 | if (problem.A.layout != MatrixLayout::Pc) stub(); |
17929 | if (problem.A.crosspack != params.opsPerChan) stub(); |
17930 | if (problem.A.tileR != params.osys) stub(); |
17931 | if (problem.A.tileC != params.ksys) stub(); |
17932 | if (problem.B.layout != MatrixLayout::Pr) stub(); |
17933 | if (problem.B.crosspack != params.ksys) stub(); |
17934 | if (problem.B.tileR != 0 || problem.B.tileC != 0) stub(); |
17935 | |
17936 | state.ra.claim(C_regs); |
17937 | |
17938 | // Adjust A/B addresses and SLM offsets. |
17939 | auto tempStorage = C_regs[0]; |
17940 | auto suboffsetA = tempStorage.ud(0); |
17941 | auto suboffsetB = tempStorage.ud(1); |
17942 | auto tempA = tempStorage.ud(2); |
17943 | auto wlidM = tempStorage.uw(6); |
17944 | auto tempB = tempStorage.ud(4); |
17945 | auto suboffsetBl = tempStorage.ud(5); |
17946 | |
17947 | if (doubleM) { |
17948 | and_(1, wlidM, localIDM, 3); |
17949 | and_(1 | ne | f1[1], null.uw(), localIDM, 4); |
17950 | } |
17951 | and_(1 | ne | state.flagAP, null.uw(), localIDM, 1); |
17952 | mulConstant(1, suboffsetA, localIDN, unrollM * (32 / 4)); |
17953 | if (doubleM) { |
17954 | mulConstant(1, suboffsetB, wlidM, unrollN * (32 / 4)); |
17955 | mulConstant(1, suboffsetBl, localIDM, unrollN * (32 / 4)); |
17956 | } else { |
17957 | mulConstant(1, suboffsetB, localIDM, unrollN * (32 / 4)); |
17958 | suboffsetBl = suboffsetB; |
17959 | } |
17960 | |
17961 | auto A_ptr = A_ptr64, B_ptr = B_ptr64, C_ptr = C_ptr64; |
17962 | if (surfaceAB) { |
17963 | if (!strategy.A.newDP || !strategy.B.newDP) stub(); |
17964 | A_ptr = A_ptr.ud(); |
17965 | B_ptr = B_ptr.ud(); |
17966 | } |
17967 | if (surfaceC) C_ptr = C_ptr.ud(); |
17968 | |
17969 | eadd(1, A_ptr, state.effA, suboffsetA, strategy, state); |
17970 | eadd(1, B_ptr, state.effB, suboffsetBl, strategy, state); |
17971 | emov(1, C_ptr, state.effC[0], strategy, state); |
17972 | |
17973 | shr(2, suboffsetA(1), suboffsetA(1), 4); |
17974 | |
17975 | mul(1, tempA, localIDM, (unrollM * 36) / 16); |
17976 | mad(1, tempB, (wgM * unrollM * 36) / 16, localIDN, (unrollN * 32) / 16); |
17977 | |
17978 | mov(1, slmAOffsetLoadInit.uw(), tempA.uw()); |
17979 | add(1 | state.flagAP, slmBOffsetLoadInit.uw(), tempB.uw(), |
17980 | (unrollN / 2) * (32 / 16)); |
17981 | mov(1 | ~state.flagAP, slmBOffsetLoadInit.uw(), tempB.uw()); |
17982 | add(1, slmAOffsetStoreInit.uw(), tempA.uw(), suboffsetA.uw()); |
17983 | add(1, slmBOffsetStoreInit.uw(), tempB.uw(), suboffsetB.uw()); |
17984 | mov(2, slmAOffsetLoad(1), slmAOffsetLoadInit(1)); |
17985 | |
17986 | // Marshal data needed later into acc2 for safekeeping. |
17987 | auto saveData = state.ra.alloc_range(2); |
17988 | auto kLoops = saveData[0].d(0); |
17989 | auto ldc = saveData[0].ud(1); |
17990 | auto flags = saveData[0].ud(2); |
17991 | auto k = saveData[0].ud(3); |
17992 | auto remM = saveData[0].uw(8); |
17993 | auto remN = saveData[0].uw(9); |
17994 | auto abo = saveData[0].ud(5); |
17995 | auto ao = saveData[0].w(10); |
17996 | auto bo = saveData[0].w(11); |
17997 | auto alpha = saveData[0].ud(6).reinterpret(0, problem.Ts.ngen()); |
17998 | auto beta = saveData[0].ud(7).reinterpret(0, problem.Ts.ngen()); |
17999 | auto remFusedStorage = saveData[1].ud(0); |
18000 | auto diagC = saveData[1].ud(1); |
18001 | auto saveI0 = saveData[1].ud(1); |
18002 | auto effCO = saveData[1].uq(1); |
18003 | auto saveJ0 = saveData[1].ud(3); |
18004 | auto slotAB = saveData[1].ud(4); |
18005 | auto effAs = saveData[1].uq(2).reinterpret(0, state.effA.getType()); |
18006 | auto effBs = saveData[1].uq(3).reinterpret(0, state.effB.getType()); |
18007 | |
18008 | if (state.r0_info != acc0.ud()) mov<uint32_t>(8, acc0, state.r0_info); |
18009 | |
18010 | add(1, kLoops, state.k, params.ksys - 1); |
18011 | mov(1, ldc, state.inputs.ldc[0]); |
18012 | if (state.inputs.flags.isValid()) mov(1, flags, state.inputs.flags); |
18013 | mov(1, k, state.k); |
18014 | if (state.remainders[LoopM].isValid()) |
18015 | mov(1, remM, state.remainders[LoopM]); |
18016 | if (state.remainders[LoopN].isValid()) |
18017 | mov(1, remN, state.remainders[LoopN]); |
18018 | if (state.inputs.abo.isValid()) |
18019 | mov(1, abo, state.inputs.abo); |
18020 | else { |
18021 | if (state.inputs.ao.isValid()) mov(1, ao, state.inputs.ao); |
18022 | if (state.inputs.bo.isValid()) mov(1, bo, state.inputs.bo); |
18023 | } |
18024 | if (state.inputs.alpha_real.isValid()) |
18025 | mov(1, alpha, state.inputs.alpha_real); |
18026 | if (state.inputs.beta_real.isValid()) mov(1, beta, state.inputs.beta_real); |
18027 | shr(1, kLoops, kLoops, log2(params.ksys)); |
18028 | if (state.remFusedStorage.isValid()) |
18029 | mov(1, remFusedStorage, state.remFusedStorage); |
18030 | if (state.diagC.isValid()) mov(1, diagC, state.diagC); |
18031 | if (state.effCO.isValid()) { |
18032 | effCO = effCO.reinterpret(0, state.effCO.getType()); |
18033 | emov(1, effCO, state.effCO, strategy, state); |
18034 | } |
18035 | if (problem.hasBinaryPostOp()) { |
18036 | if (state.diagC.isValid()) stub(); |
18037 | if (state.effCO.isValid() && effCO.getBytes() > 4) stub(); |
18038 | mov(1, saveI0, state.i0); |
18039 | mov(1, saveJ0, state.j0); |
18040 | } |
18041 | if (problem.abOffset != ABOffset::None) { |
18042 | state.effAs = effAs; |
18043 | state.effBs = effBs; |
18044 | gemmCalcABOffsetAddrs(problem, strategy, state); |
18045 | } |
18046 | if (state.fusedGEMM.slotA.isValid()) { |
18047 | if (problem.abOffset != ABOffset::None) |
18048 | stub(); // Not enough room in acc2. |
18049 | mov(1, slotAB, state.fusedGEMM.slotA.ud()); |
18050 | } |
18051 | |
18052 | releaseSavedMNLocalIDs(state); |
18053 | state.ra.safeRelease(state.effA); |
18054 | state.ra.safeRelease(state.effB); |
18055 | state.ra.safeRelease(state.effC[0]); |
18056 | state.ra.safeRelease(state.inputs.lda); |
18057 | state.ra.safeRelease(state.inputs.ldb); |
18058 | |
18059 | state.ra.release(state.inputs.ldc[0]); |
18060 | state.ra.release(state.k); |
18061 | state.ra.release(state.remainders[LoopM]); |
18062 | state.ra.release(state.remainders[LoopN]); |
18063 | state.ra.release(state.inputs.abo); |
18064 | state.ra.release(state.inputs.ao); |
18065 | state.ra.release(state.inputs.bo); |
18066 | state.ra.release(state.inputs.alpha_real); |
18067 | state.ra.release(state.inputs.beta_real); |
18068 | state.ra.release(state.remFusedStorage); |
18069 | state.ra.release(state.diagC); |
18070 | state.ra.release(state.effCO); |
18071 | state.ra.release(state.fusedGEMM.slotA); |
18072 | state.ra.release(state.fusedGEMM.slotB); |
18073 | |
18074 | if (state.r0_info.isARF()) stub(); |
18075 | GRF r0_info {state.r0_info.getBase()}; |
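    // Derive the barrier message payload from the r0 header.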
18076 | if (hw >= HW::XeHPG) { |
18077 | mov(1, barrierVal.uw(0), Immediate::uw(0)); |
18078 | mov(2, barrierVal.ub(2)(1), r0_info.ub(11)(0)); |
18079 | } else |
18080 | and_(1, barrierVal, r0_info.ud(2), 0x7F000000); |
18081 | |
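    // Park the marshaled values in acc2 for the duration of the k loop.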
18082 | mov<float>(16, acc2, saveData[0]); |
18083 | |
18084 | sync.nop(SWSB<AllPipes>(1)); |
18085 | |
18086 | if (!doubleM) |
18087 | sysgemmKLoop(problem, strategy, state); |
18088 | else { |
18089 | Label oddB, done; |
18090 | jmpi(1 | f1[1], oddB); |
18091 | sysgemmKLoop4(problem, strategy, state, false); |
18092 | jmpi(1, done); |
18093 | mark(oddB); |
18094 | sysgemmKLoop4(problem, strategy, state, true); |
18095 | mark(done); |
18096 | } |
18097 | |
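    // k loop complete: restore the marshaled values from acc2.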
18098 | mov<float>(16, saveData[0], acc2); |
18099 | |
18100 | state.effC[0] = C_ptr; |
18101 | state.inputs.ldc[0] = ldc; |
18102 | if (state.inputs.flags.isValid()) state.inputs.flags = flags; |
18103 | state.k = k; |
18104 | if (state.remainders[LoopM].isValid()) state.remainders[LoopM] = remM; |
18105 | if (state.remainders[LoopN].isValid()) state.remainders[LoopN] = remN; |
18106 | if (state.inputs.abo.isValid()) state.inputs.abo = abo; |
18107 | if (state.inputs.ao.isValid()) state.inputs.ao = ao; |
18108 | if (state.inputs.bo.isValid()) state.inputs.bo = bo; |
18109 | if (state.inputs.alpha_real.isValid()) { |
18110 | state.inputs.alpha_real = alpha; |
18111 | if (!problem.alpha_real.fixed()) problem.alpha_real = alpha; |
18112 | } |
18113 | if (state.inputs.beta_real.isValid()) { |
18114 | state.inputs.beta_real = beta; |
18115 | if (!problem.beta_real.fixed()) problem.beta_real = beta; |
18116 | } |
18117 | if (state.remFusedStorage.isValid()) { |
18118 | state.remFusedStorage = remFusedStorage; |
18119 | state.remaindersFused[LoopM] = state.remainders[LoopM]; |
18120 | state.remaindersFused[LoopN] = state.remainders[LoopN]; |
18121 | state.remaindersFused[strategy.fusedLoop] = remFusedStorage; |
18122 | } |
18123 | if (state.diagC.isValid()) state.diagC = diagC; |
18124 | if (state.effCO.isValid()) state.effCO = effCO; |
18125 | if (state.fusedGEMM.slotA.isValid()) { |
18126 | state.fusedGEMM.slotA = slotAB.uw(0); |
18127 | state.fusedGEMM.slotB = slotAB.uw(1); |
18128 | } |
18129 | if (problem.hasBinaryPostOp()) { |
18130 | state.i0 = saveI0; |
18131 | state.j0 = saveJ0; |
18132 | } |
18133 | |
18134 | state.ra.claim(C_ptr); |
18135 | |
18136 | // Set up C internal layout and registers. |
18137 | state.C_regs.resize(1); |
18138 | state.C_regs[0] = C_regs; |
18139 | state.C_layout.clear(); |
18140 | state.C_layout.reserve((unrollM / 8) * (unrollN / 4)); |
18141 | for (int j0 = 0; j0 < unrollN; j0 += 4) { |
18142 | for (int i0 = 0; i0 < unrollM; i0 += 8) { |
18143 | RegisterBlock block; |
18144 | block.log2GRFBytes = GRF::log2Bytes(hw); |
18145 | block.colMajor = true; |
18146 | block.splitComplex = false; |
18147 | block.cxComponent = RegisterBlock::Interleaved; |
18148 | block.nr = block.ld = 8; |
18149 | block.nc = 4; |
18150 | block.component = 0; |
18151 | block.offsetR = i0; |
18152 | block.offsetC = j0; |
18153 | block.crosspack = 1; |
18154 | block.bytes = 8 * 4 * problem.Tc.size(); |
18155 | block.simdSize = 0; |
18156 | |
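            // Column blocks interleave in the accumulators: the first half of
            //  the columns lands in even 4-GRF slots, the second half in odd
            //  slots (offset by 4 GRFs).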
18157 | int j0Interleaved = j0 << 1; |
18158 | if (j0Interleaved >= unrollN) j0Interleaved += 4 - unrollN; |
18159 | |
18160 | block.offsetBytes |
18161 | = (accStride * i0 / 8 + j0Interleaved) * GRF::bytes(hw); |
18162 | state.C_layout.push_back(block); |
18163 | } |
18164 | } |
18165 | |
18166 | // Set up C external layout. |
18167 | state.copyC = true; |
18168 | bool remM_Ce, remN_Ce; |
18169 | getCRemainders(problem, strategy, remM_Ce, remN_Ce); |
18170 | |
18171 | if (!getRegLayout(problem.Tc_ext, state.C_layoutExt, unrollM, unrollN, |
18172 | remM_Ce, remN_Ce, true, false, 0, 0, problem.C, |
18173 | state.Cext_strategy)) |
18174 | return false; |
18175 | if (remM_Ce || remN_Ce) |
18176 | (void)getRegLayout(problem.Tc_ext, state.C_layoutExtUnmasked, unrollM, |
18177 | unrollN, false, false, true, false, 0, 0, problem.C, |
18178 | state.Cext_strategy); |
18179 | |
18180 | if (state.r0_info != acc0.ud()) mov<uint32_t>(8, state.r0_info, acc0); |
18181 | |
18182 | return true; // Success! |
18183 | } |
18184 | |
18185 | template <HW hw> |
18186 | void gemm_kernel_generator_t<hw>::sysgemmKLoop(const GEMMProblem &problem, |
18187 | const GEMMStrategy &strategy, GEMMState &state) { |
18188 | using namespace sysgemm; |
18189 | Label top, bottom, skipMain, remTop, remBottom; |
18190 | |
18191 | auto nbBarrierWait = [&]() { |
18192 | if (!strategy.slmAltBarriers) barrierwait(); |
18193 | }; |
18194 | auto nbStoreSignal = [&](bool forceFence = false) { |
18195 | if (!strategy.slmAltBarriers) |
18196 | sysgemmStoreSignal(problem, strategy, state, forceFence); |
18197 | }; |
18198 | auto storeSignal = [&](bool forceFence = false) { |
18199 | sysgemmStoreSignal(problem, strategy, state, forceFence); |
18200 | }; |
18201 | auto copyLoad = [&](int storeBuffer, bool useC = false) { |
18202 | sysgemmCopyLoad(problem, strategy, state, storeBuffer, useC); |
18203 | }; |
18204 | auto copyStore = [&](int storeBuffer, bool first = false) { |
18205 | sysgemmCopyStore(problem, strategy, state, storeBuffer, first); |
18206 | }; |
18207 | auto multiply = [&](int buffer, bool lastMultiply = false) { |
18208 | sysgemmMultiply(problem, strategy, state, buffer, lastMultiply); |
18209 | }; |
18210 | |
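    // Software-pipelined k loop: global loads (L#), SLM stores (S#), and
    //  multiplies (M#) rotate through the buffers, hand-scheduled with
    //  explicit SWSB; barrier signals/waits keep the workgroup in step.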
18211 | bool oldDefaultAutoSWSB = getDefaultAutoSWSB(); |
18212 | setDefaultAutoSWSB(false); |
18213 | |
18214 | if (strategy.slmCopies == 1) { |
18215 | cmp(1 | lt | f1[1], kCounter, 3); |
18216 | add(1 | le | f0[1], kCounter, kCounter, -5); |
18217 | |
18218 | jmpi(1 | f1[1], skipMain); |
18219 | |
18220 | copyLoad(0, true); // L0 -> C |
18221 | copyLoad(1); // L1 |
18222 | copyStore(0, true); // S0 <- C |
18223 | storeSignal(true); // Signal 0 ready |
18224 | zeroMatrix(C_regs, strategy); |
18225 | sync.nop(SWSB<AllPipes>(1)); |
18226 | copyStore(1); // S1 |
18227 | |
18228 | nbBarrierWait(); // Wait 0 ready |
18229 | nbStoreSignal(); // Signal 1 ready |
18230 | |
18231 | jmpi(1 | f0[1], bottom); // Zero-trip loop check |
18232 | |
18233 | mark(top); |
18234 | add(1 | gt | f0[1], kCounter, kCounter, -3); |
18235 | |
18236 | copyLoad(2); // L2 |
18237 | multiply(0); // M0 |
18238 | nbBarrierWait(); // Wait 0 ready |
18239 | copyStore(2); // S2 |
18240 | nbStoreSignal(); // Signal 2 ready |
18241 | |
18242 | copyLoad(0); // L0 |
18243 | multiply(1); // M1 |
18244 | nbBarrierWait(); // Wait 2 ready |
18245 | copyStore(0); // S0 |
18246 | nbStoreSignal(); // Signal 0 ready |
18247 | |
18248 | copyLoad(1); // L1 |
18249 | multiply(2); // M2 |
18250 | nbBarrierWait(); // Wait 0 ready |
18251 | copyStore(1); // S1 |
18252 | nbStoreSignal(); // Signal 1 ready |
18253 | |
18254 | jmpi(1 | f0[1], top); |
18255 | mark(bottom); |
18256 | |
18257 | copyLoad(2); // L2 |
18258 | multiply(0); // M0 |
18259 | nbBarrierWait(); // Wait 1 ready |
18260 | copyStore(2); // S2 |
18261 | nbStoreSignal(); // Signal 2 ready |
18262 | |
18263 | multiply(1); // M1 |
18264 | |
18265 | nbBarrierWait(); // Wait 2 ready |
18266 | |
18267 | multiply(2, true); // M2 |
18268 | |
18269 | add(1 | le | f0[1], kCounter, kCounter, 2); |
18270 | jmpi(1 | f0[1], remBottom); |
18271 | jmpi(1, remTop); |
18272 | |
18273 | mark(skipMain); |
18274 | |
18275 | zeroMatrix(C_regs, strategy); |
18276 | add(1, kCounter, kCounter, 5); |
18277 | |
18278 | mov(2, slmAOffsetStore(1), slmAOffsetStoreInit(1)); |
18279 | sync.nop(SWSB<AllPipes>(1)); |
18280 | |
18281 | mark(remTop); |
18282 | |
18283 | cmp(1 | lt | f0[1], kCounter, 2); |
18284 | copyLoad(0); |
18285 | copyStore(0); |
18286 | storeSignal(true); |
18287 | nbBarrierWait(); |
18288 | multiply(0, true); |
18289 | |
18290 | jmpi(1 | f0[1], remBottom); |
18291 | copyLoad(1); |
18292 | copyStore(1); |
18293 | storeSignal(true); |
18294 | nbBarrierWait(); |
18295 | multiply(1, true); |
18296 | |
18297 | mark(remBottom); |
18298 | } else if (strategy.slmCopies == 3) { |
18299 | // Triple-buffered global memory load + SLM pipeline. |
18300 | cmp(1 | lt | f1[1], kCounter, 4); |
18301 | add(1 | le | f0[1], kCounter, kCounter, -6); |
18302 | |
18303 | jmpi(1 | f1[1], skipMain); |
18304 | |
18305 | copyLoad(0); // L0 |
18306 | copyLoad(1); // L1 |
18307 | copyLoad(2); // L2 |
18308 | copyStore(0, true); // S0 |
18309 | storeSignal(true); // Signal 0 ready |
18310 | zeroMatrix(C_regs, strategy); |
18311 | copyLoad(0); // L0 |
18312 | sync.nop(SWSB<uint32_t>(1)); |
18313 | copyStore(1); // S1 |
18314 | |
18315 | nbBarrierWait(); // Wait 0 ready |
18316 | nbStoreSignal(); // Signal 1 ready |
18317 | |
18318 | jmpi(1 | f0[1], bottom); // Zero-trip loop check |
18319 | |
18320 | mark(top); |
18321 | add(1 | gt | f0[1], kCounter, kCounter, -3); |
18322 | |
18323 | copyLoad(1); // L1 |
18324 | multiply(0); // M0 |
18325 | nbBarrierWait(); // Wait 0 ready |
18326 | copyStore(2); // S2 |
18327 | nbStoreSignal(); // Signal 2 ready |
18328 | |
18329 | copyLoad(2); // L2 |
18330 | multiply(1); // M1 |
18331 | nbBarrierWait(); // Wait 2 ready |
18332 | copyStore(0); // S0 |
18333 | nbStoreSignal(); // Signal 0 ready |
18334 | |
18335 | copyLoad(0); // L0 |
18336 | multiply(2); // M2 |
18337 | nbBarrierWait(); // Wait 0 ready |
18338 | copyStore(1); // S1 |
18339 | nbStoreSignal(); // Signal 1 ready |
18340 | |
18341 | jmpi(1 | f0[1], top); |
18342 | mark(bottom); |
18343 | |
18344 | multiply(0); // M0 |
18345 | nbBarrierWait(); // Wait 1 ready |
18346 | copyStore(2); // S2 |
18347 | nbStoreSignal(); // Signal 2 ready |
18348 | |
18349 | multiply(1); // M1 |
18350 | nbBarrierWait(); // Wait 2 ready |
18351 | copyStore(0); // S0 |
18352 | nbStoreSignal(); // Signal 0 ready |
18353 | |
18354 | multiply(2); // M2 |
18355 | |
18356 | nbBarrierWait(); // Wait 0 ready |
18357 | |
18358 | multiply(0, true); // M0 |
18359 | |
18360 | add(1 | le | f0[1], kCounter, kCounter, 2); |
18361 | jmpi(1 | f0[1], remBottom); |
18362 | jmpi(1, remTop); |
18363 | |
18364 | mark(skipMain); |
18365 | |
18366 | zeroMatrix(C_regs, strategy); |
18367 | add(1 | le | f0[1], kCounter, kCounter, 5); |
18368 | |
18369 | mov(2, slmAOffsetStore(1), slmAOffsetStoreInit(1)); |
18370 | sync.nop(SWSB<uint32_t>(1)); |
18371 | |
18372 | copyLoad(0); |
18373 | copyStore(0); |
18374 | storeSignal(true); |
18375 | nbBarrierWait(); |
18376 | multiply(0, true); |
18377 | |
18378 | jmpi(1 | f0[1], remBottom); |
18379 | |
18380 | mark(remTop); |
18381 | |
18382 | cmp(1 | lt | f0[1], kCounter, 2); |
18383 | |
18384 | copyLoad(1); |
18385 | copyStore(1); |
18386 | storeSignal(true); |
18387 | nbBarrierWait(); |
18388 | multiply(1, true); |
18389 | |
18390 | jmpi(1 | f0[1], remBottom); |
18391 | |
18392 | copyLoad(2); |
18393 | copyStore(2); |
18394 | storeSignal(true); |
18395 | nbBarrierWait(); |
18396 | multiply(2, true); |
18397 | |
18398 | mark(remBottom); |
18399 | } else |
18400 | stub(); |
18401 | |
18402 | sync.allwr(); |
18403 | setDefaultAutoSWSB(oldDefaultAutoSWSB); |
18404 | } |
18405 | |
18406 | template <HW hw> |
18407 | void gemm_kernel_generator_t<hw>::sysgemmKLoop4(const GEMMProblem &problem, |
18408 | const GEMMStrategy &strategy, GEMMState &state, bool oddB) { |
18409 | using namespace sysgemm; |
18410 | auto &depAddr = state.sysgemm.depAddr; |
18411 | |
18412 | Label top, bottom, skipMain, done; |
18413 | Label skipLoad0, skipLoad1, skipLoad2; |
18414 | Label skipStore0, skipStore1, skipStore2; |
18415 | Label sskipLoad12, sskipStore1, sskipStore2Load3, sskipStore3; |
18416 | |
18417 | auto clearDepAddr = [&]() { |
18418 | for (int i = 0; i < 4; i++) |
18419 | depAddr[i] = InstructionModifier(); |
18420 | }; |
18421 | auto storeSignal |
18422 | = [&]() { sysgemmStoreSignal(problem, strategy, state, true); }; |
18423 | auto copyLoad = [&](int storeBuffer, int useC = 0, bool forceLoadB = false, |
18424 | RegData flagLoadB = RegData()) { |
18425 | sysgemmCopyLoad4(problem, strategy, state, storeBuffer, |
18426 | ((storeBuffer & 1) != oddB) || forceLoadB, useC, flagLoadB); |
18427 | }; |
18428 | auto copyLoadRem = [&](int storeBuffer, RegData flagLoadB) { |
18429 | sysgemmCopyLoad4(problem, strategy, state, storeBuffer, |
18430 | (storeBuffer & 1) != oddB, 0, flagLoadB); |
18431 | }; |
18432 | auto copyStore = [&](int storeBuffer, int useC = 0, int useC_B = 0) { |
18433 | sysgemmCopyStore4(problem, strategy, state, storeBuffer, |
18434 | (storeBuffer & 1) == oddB, useC, useC_B); |
18435 | }; |
18436 | auto multiply = [&](int buffer, bool firstMultiply = false) { |
18437 | sysgemmMultiply4(problem, strategy, state, buffer, firstMultiply); |
18438 | }; |
18439 | auto multiplyRem = [&](int buffer, RegData flagWaitLoad, RegData flagSignal, |
18440 | bool firstMultiply = false) { |
18441 | sysgemmMultiply4(problem, strategy, state, buffer, firstMultiply, |
18442 | flagWaitLoad, flagSignal, &done); |
18443 | }; |
18444 | auto slmRead = [&]() { |
18445 | mov(1 | depAddr[0], addr0.ud(2), slmAOffsetLoad); |
18446 | mov(1 | depAddr[1], addr1.ud(2), slmBOffsetLoad); |
18447 | add(1 | depAddr[2], addr2.ud(2), slmBOffsetLoad, 8 * 32 / 16); |
18448 | add(1 | depAddr[3], addr3.ud(2), slmBOffsetLoad, 16 * 32 / 16); |
18449 | |
18450 | load(16 | SWSB<AllPipes>(sb3, 4), A_regs[0], block_oword(16), SLM, |
18451 | addr0); |
18452 | load(16 | SWSB<AllPipes>(sb0, 3), B_regs[0], block_oword(16), SLM, |
18453 | addr1); |
18454 | load(16 | SWSB<AllPipes>(sb1, 2), B_regs[8], block_oword(16), SLM, |
18455 | addr2); |
18456 | load(16 | SWSB<AllPipes>(sb2, 1), B_regs[16], block_oword(16), SLM, |
18457 | addr3); |
18458 | depAddr[0] = sb3.src; |
18459 | depAddr[1] = sb0.src; |
18460 | depAddr[2] = sb1.src; |
18461 | depAddr[3] = sb2.src; |
18462 | |
18463 | add(1 | depAddr[0], addr0.ud(2), slmAOffsetLoad, 8 * 32 / 16); |
18464 | load(16 | SWSB<AllPipes>(sb4, 1), A_regs[8], block_oword(16), SLM, |
18465 | addr0); |
18466 | depAddr[0] = sb4.src; |
18467 | }; |
18468 | |
18469 | bool oldDefaultAutoSWSB = getDefaultAutoSWSB(); |
18470 | setDefaultAutoSWSB(false); |
18471 | |
18472 | clearDepAddr(); |
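    // kCounter is biased so the loop's adds/compares yield the exit and
    //  remainder flags directly; f1[1] routes short k to the small-k path.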
18473 | mov(1, f1.ud(), 0); |
18474 | mov(1, f0.ud(), 0); |
18475 | cmp(1 | lt | f1[1], kCounter, 4); |
18476 | add(1 | le | f0[1], kCounter, kCounter, -7); |
18477 | |
18478 | jmpi(1 | f1[1], skipMain); |
18479 | |
    status << "Main path, " << (oddB ? "odd B" : "even B")
           << status_stream::endl;
18482 | |
18483 | copyLoad(0, 1, true); // L0 -> C1 |
18484 | copyLoad(1, 2); // L1 -> C2 |
18485 | copyLoad(2); // L2 |
18486 | copyStore(0, 1, 1); // S0 <- C1 |
18487 | storeSignal(); |
18488 | copyStore(1, 2, 1); // S1 <- C2 |
18489 | barrierwait(); |
18490 | slmRead(); |
18491 | storeSignal(); |
18492 | copyStore(2, 0, 2); // S2 |
18493 | if (!oddB) sync.allrd(0x3000); |
18494 | zeroMatrix(C_regs, strategy); |
18495 | sync.allrd(SWSB<AllPipes>(1)); |
18496 | |
18497 | jmpi(1 | f0[1], bottom); // Zero-trip loop check |
18498 | |
18499 | depAddr[0] = sb8.src; |
18500 | depAddr[1] = !oddB ? sb9.src : sb0.src; |
18501 | depAddr[2] = !oddB ? sb10.src : sb4.src; |
18502 | depAddr[3] = sb3.src; |
18503 | |
18504 | mark(top); |
18505 | add(1 | gt | f0[1], kCounter, kCounter, -4); |
18506 | |
18507 | copyLoad(3); |
18508 | multiply(0); |
18509 | copyStore(3); |
18510 | |
18511 | copyLoad(0); |
18512 | multiply(1); |
18513 | copyStore(0); |
18514 | |
18515 | copyLoad(1); |
18516 | multiply(2); |
18517 | copyStore(1); |
18518 | |
18519 | copyLoad(2); |
18520 | multiply(3); |
18521 | copyStore(2); |
18522 | |
18523 | jmpi(1 | f0[1], top); |
18524 | mark(bottom); |
18525 | |
18526 | cmp(1 | gt | f0[0], kCounter, -4 + 1); |
18527 | cmp(1 | gt | f0[1], kCounter, -4 + 2); |
18528 | cmp(1 | gt | f1[0], kCounter, -4 + 3); |
18529 | |
18530 | copyLoadRem(3, f0[0]); |
18531 | multiply(0); |
18532 | copyStore(3); |
18533 | |
18534 | sync.allrd(); |
18535 | jmpi(1 | ~f0[0], skipLoad0); |
18536 | copyLoadRem(0, f0[1]); |
18537 | mark(skipLoad0); |
18538 | multiply(1); |
18539 | sync.allrd(); |
18540 | jmpi(1 | ~f0[0], skipStore0); |
18541 | copyStore(0); |
18542 | mark(skipStore0); |
18543 | |
18544 | sync.allrd(); |
18545 | jmpi(1 | ~f0[1], skipLoad1); |
18546 | copyLoadRem(1, f1[0]); |
18547 | mark(skipLoad1); |
18548 | multiplyRem(2, FlagRegister(), f0[0]); |
18549 | sync.allrd(); |
18550 | jmpi(1 | ~f0[1], skipStore1); |
18551 | copyStore(1); |
18552 | mark(skipStore1); |
18553 | |
18554 | sync.allrd(); |
18555 | jmpi(1 | ~f1[0], skipLoad2); |
18556 | copyLoadRem(2, null); |
18557 | mark(skipLoad2); |
18558 | multiplyRem(3, f0[0], f0[1]); |
18559 | sync.allrd(); |
18560 | jmpi(1 | ~f1[0], skipStore2); |
18561 | copyStore(2); |
18562 | mark(skipStore2); |
18563 | |
18564 | multiplyRem(0, f0[1], f1[0]); |
18565 | multiplyRem(1, f1[0], null); |
18566 | multiplyRem(2, null, null); |
18567 | |
18568 | jmpi(1, done); |
18569 | |
    status << "Small-k path, " << (oddB ? "odd B" : "even B")
           << status_stream::endl;
18572 | |
18573 | clearDepAddr(); |
18574 | mark(skipMain); |
18575 | |
18576 | // Short loops: special case for 1-4 unrolls |
18577 | cmp(1 | gt | f0[0], kCounter, -7 + 1); |
18578 | cmp(1 | gt | f0[1], kCounter, -7 + 2); |
18579 | cmp(1 | gt | f1[0], kCounter, -7 + 3); |
18580 | |
18581 | auto flagLoadB0 = oddB ? f0[0] : FlagRegister(); |
18582 | copyLoad(0, 1, true, flagLoadB0); |
18583 | sync.allrd(); |
18584 | jmpi(1 | ~f0[0], sskipLoad12); |
18585 | copyLoad(1, 2, false, f0[1]); |
18586 | sync.allrd(); |
18587 | jmpi(1 | ~f0[1], sskipLoad12); |
18588 | copyLoadRem(2, f1[0]); |
18589 | mark(sskipLoad12); |
18590 | copyStore(0, 1, 1); |
18591 | storeSignal(); |
18592 | sync.allrd(); |
18593 | jmpi(1 | ~f0[0], sskipStore1); |
18594 | copyStore(1, 2, 1); |
18595 | mark(sskipStore1); |
18596 | barrierwait(); |
18597 | slmRead(); |
18598 | sync.allrd(); |
18599 | jmpi(1 | ~f0[0], sskipStore2Load3); |
18600 | storeSignal(); |
18601 | sync.allrd(); |
18602 | jmpi(1 | ~f0[1], sskipStore2Load3); |
18603 | copyStore(2, 0, 2); |
18604 | sync.allrd(); |
18605 | jmpi(1 | ~f1[0], sskipStore2Load3); |
18606 | copyLoadRem(3, null); |
18607 | mark(sskipStore2Load3); |
18608 | multiplyRem(0, f0[0], f0[1], true); |
18609 | jmpi(1 | ~f0[0], done); |
18610 | sync.allrd(); |
18611 | jmpi(1 | ~f1[0], sskipStore3); |
18612 | copyStore(3); |
18613 | mark(sskipStore3); |
18614 | multiplyRem(1, f0[1], f1[0]); |
18615 | multiplyRem(2, f1[0], null); |
18616 | multiplyRem(3, null, null); |
18617 | |
18618 | mark(done); |
18619 | |
18620 | sync.allwr(); |
18621 | setDefaultAutoSWSB(oldDefaultAutoSWSB); |
18622 | } |
18623 | |
18624 | template <HW hw> |
18625 | void gemm_kernel_generator_t<hw>::sysgemmStoreSignal(const GEMMProblem &problem, |
18626 | const GEMMStrategy &strategy, GEMMState &state, bool forceFence) { |
18627 | using namespace sysgemm; |
18628 | auto &depAddr = state.sysgemm.depAddr; |
18629 | |
18630 | if (!strategy.slmAltBarriers || forceFence) { |
        // Signal SLM data ready once the memory fence returns (asynchronous).
18632 | sync.nop(depAddr[0]); |
18633 | sysgemmBarrierPrep(depAddr[3], addr3); |
18634 | |
18635 | slmfence(SWSB<AllPipes>(sb15, 1), addr0); |
18636 | barriermsg(sb15, addr3); |
18637 | depAddr[0] = InstructionModifier(); |
18638 | depAddr[3] = sb15.src; |
18639 | } else { |
18640 | sysgemmBarrierPrep(depAddr[3], addr3); |
18641 | barriermsg(SWSB<AllPipes>(sb15, 1), addr3); |
18642 | depAddr[3] = sb15.src; |
18643 | } |
18644 | } |
18645 | |
18646 | template <HW hw> |
18647 | void gemm_kernel_generator_t<hw>::sysgemmCopyLoad(const GEMMProblem &problem, |
18648 | const GEMMStrategy &strategy, GEMMState &state, int storeBuffer, |
18649 | bool useC) { |
18650 | using namespace sysgemm; |
18651 | auto &depAddr = state.sysgemm.depAddr; |
18652 | |
18653 | bool surface = !strategy.A.base.isStateless(); |
18654 | bool emulate64 = strategy.emulate.emulate64; |
18655 | int unrollM = strategy.unroll[LoopM]; |
18656 | int unrollN = strategy.unroll[LoopN]; |
18657 | |
18658 | auto A_ptr = A_ptr64, B_ptr = B_ptr64; |
18659 | if (surface) { |
18660 | A_ptr = A_ptr.ud(); |
18661 | B_ptr = B_ptr.ud(); |
18662 | emulate64 = false; |
18663 | } |
18664 | |
18665 | // Load new A and B and increment load pointers. |
18666 | if (surface) { |
18667 | sync(SyncFunction::nop, SWSB<uint32_t>(1)); |
18668 | mov(1 | depAddr[0], addr0.ud(0), A_ptr); |
18669 | mov(1 | depAddr[1], addr1.ud(0), B_ptr); |
18670 | add(1 | depAddr[2], addr2.ud(0), B_ptr, 8 * 32); |
18671 | } else if (!emulate64) { |
18672 | sync(SyncFunction::nop, SWSB<uint64_t>(1)); |
18673 | mov(1 | depAddr[0], addr0.uq(0), A_ptr); |
18674 | mov(1 | depAddr[1], addr1.uq(0), B_ptr); |
18675 | add(1 | depAddr[2], addr2.uq(0), B_ptr, 8 * 32); |
18676 | } else { |
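        // Emulated 64-bit addressing: do the low-dword add with the overflow
        // (ov) flag enabled, then propagate the carry into the high dword
        // under that flag.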
18677 | sync(SyncFunction::nop, SWSB<uint32_t>(1)); |
18678 | mov(1 | depAddr[2], addr2.ud(1), B_ptr.ud(1)); |
18679 | add(1 | ov | f1[1], addr2.ud(0), B_ptr.ud(0), 8 * 32); |
18680 | mov(2 | depAddr[0], addr0.ud(0)(1), A_ptr.ud()(1)); |
18681 | mov(2 | depAddr[1], addr1.ud(0)(1), B_ptr.ud()(1)); |
18682 | add(1 | f1[1] | SWSB(4), addr2.ud(1), addr2.ud(1), 1); |
18683 | } |
18684 | |
18685 | if (useC) { |
18686 | if (surface) { |
18687 | load(1 | SWSB<AllPipes>(sb11, 3), C_regs[0], |
18688 | D64T(32) | strategy.A.cachingR, strategy.A.base, addr0); |
18689 | load(1 | SWSB<AllPipes>(sb12, 2), C_regs[8], |
18690 | D64T(32) | strategy.B.cachingR, strategy.B.base, addr1); |
18691 | if (strategy.unroll[LoopN] > 32) |
18692 | load(1 | SWSB<AllPipes>(sb13, 1), C_regs[16], |
18693 | D64T(16) | strategy.B.cachingR, strategy.B.base, addr2); |
18694 | } else { |
18695 | load(16 | SWSB<AllPipes>(sb11, 3), C_regs[0], block_hword(8), A64, |
18696 | addr0); |
18697 | load(16 | SWSB<AllPipes>(sb12, 2), C_regs[8], block_hword(8), A64, |
18698 | addr1); |
18699 | if (strategy.unroll[LoopN] > 32) |
18700 | load(16 | SWSB<AllPipes>(sb13, 1), C_regs[16], block_hword(4), |
18701 | A64, addr2); |
18702 | } |
18703 | depAddr[0] = sb11.src; |
18704 | depAddr[1] = sb12.src; |
18705 | if (strategy.unroll[LoopN] > 32) depAddr[2] = sb13.src; |
18706 | if (strategy.simulation) sync.allrd(0x3000); |
18707 | } else { |
18708 | // Stronger than necessary dependencies... can load as soon as prev. store inputs are read. |
18709 | int loadBuffer = (strategy.slmCopies == 3) ? storeBuffer : 0; |
18710 | int t0 = 8 + loadBuffer * 2; |
18711 | SBID token0 {t0}, token1 {t0 + 1}, token2 {t0 + 2}; |
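        // Tokens 8-14 are carved out per load buffer (a trio per buffer at
        // stride 2, so adjacent buffers share one token). The 0x6 << t0 mask
        // below covers the two B-load tokens.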
18712 | |
18713 | if (surface) { |
18714 | load(1 | SWSB<AllPipes>(token0, 3), A_copy[loadBuffer][0], |
18715 | D64T(32) | strategy.A.cachingR, strategy.A.base, addr0); |
18716 | load(1 | SWSB<AllPipes>(token1, 2), B_copy[loadBuffer][0], |
18717 | D64T(32) | strategy.B.cachingR, strategy.B.base, addr1); |
18718 | if (strategy.unroll[LoopN] > 32) |
18719 | load(1 | SWSB<AllPipes>(token2, 1), B_copy[loadBuffer][8], |
18720 | D64T(16) | strategy.B.cachingR, strategy.B.base, addr2); |
18721 | } else { |
18722 | load(16 | SWSB<AllPipes>(token0, 3), A_copy[loadBuffer][0], |
18723 | block_hword(8), A64, addr0); |
18724 | load(16 | SWSB<AllPipes>(token1, 2), B_copy[loadBuffer][0], |
18725 | block_hword(8), A64, addr1); |
18726 | if (strategy.unroll[LoopN] > 32) |
18727 | load(16 | SWSB<AllPipes>(token2, 1), B_copy[loadBuffer][8], |
18728 | block_hword(4), A64, addr2); |
18729 | } |
18730 | depAddr[0] = token0.src; |
18731 | depAddr[1] = token1.src; |
18732 | if (strategy.unroll[LoopN] > 32) depAddr[2] = token2.src; |
18733 | if (strategy.simulation) sync.allrd(0x6 << t0); |
18734 | } |
18735 | |
18736 | if (!emulate64) { |
18737 | add(1 | SWSB(3), A_ptr, A_ptr, unrollM * 32); |
18738 | add(1 | SWSB(3), B_ptr, B_ptr, unrollN * 32); |
18739 | } else { |
18740 | add(1 | ov | f1[0] | SWSB(3), A_ptr.ud(0), A_ptr.ud(0), unrollM * 32); |
18741 | add(1 | ov | f1[1] | SWSB(3), B_ptr.ud(0), B_ptr.ud(0), unrollN * 32); |
18742 | add(1 | f1[0], A_ptr.ud(1), A_ptr.ud(1), 1); |
18743 | add(1 | f1[1], B_ptr.ud(1), B_ptr.ud(1), 1); |
18744 | } |
18745 | } |
18746 | |
18747 | template <HW hw> |
18748 | void gemm_kernel_generator_t<hw>::sysgemmCopyLoad4(const GEMMProblem &problem, |
18749 | const GEMMStrategy &strategy, GEMMState &state, int storeBuffer, |
18750 | bool loadB, int useC, RegData flagLoadB) { |
18751 | using namespace sysgemm; |
18752 | auto &depAddr = state.sysgemm.depAddr; |
18753 | |
18754 | bool emulate64 = strategy.emulate.emulate64; |
18755 | bool surface = !strategy.A.base.isStateless(); |
18756 | int unrollM = strategy.unroll[LoopM]; |
18757 | int unrollN = strategy.unroll[LoopN]; |
18758 | int x48 = (unrollN > 32); |
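    // flagLoadB is tri-state: an invalid flag loads B unconditionally, a
    // null register skips the B load entirely, and a valid flag predicates
    // the B loads (any16h).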
18759 | InstructionModifier loadBMod; |
18760 | loadB &= !(flagLoadB.isValid() && flagLoadB.isNull()); |
18761 | if (flagLoadB.isValid()) |
18762 | loadBMod = loadBMod | static_cast<FlagRegister &>(flagLoadB) | any16h; |
18763 | |
18764 | // Load new A and B and increment load pointers. |
18765 | auto A_ptr = A_ptr64, B_ptr = B_ptr64; |
18766 | if (surface) { |
18767 | A_ptr = A_ptr.ud(); |
18768 | B_ptr = B_ptr.ud(); |
18769 | emulate64 = false; |
18770 | } |
18771 | |
18772 | if (surface) { |
18773 | sync.nop(SWSB(Pipe::I, 1)); |
18774 | mov(1 | depAddr[0], addr0.ud(0), A_ptr); |
18775 | if (loadB) { |
18776 | mov(1 | depAddr[1], addr1.ud(0), B_ptr); |
18777 | add(1 | depAddr[2], addr2.ud(0), B_ptr, 8 * 32); |
18778 | } |
18779 | } else if (!emulate64) { |
18780 | sync.nop(SWSB(Pipe::L, 1)); |
18781 | mov(1 | depAddr[0], addr0.uq(0), A_ptr); |
18782 | if (loadB) { |
18783 | mov(1 | depAddr[1], addr1.uq(0), B_ptr); |
18784 | add(1 | depAddr[2], addr2.uq(0), B_ptr, 8 * 32); |
18785 | } |
18786 | } else { |
18787 | sync.nop(SWSB(Pipe::I, 1)); |
18788 | if (loadB) { |
18789 | mov(1 | depAddr[2], addr2.ud(1), B_ptr.ud(1)); |
18790 | add(1 | ov | f1[1], addr2.ud(0), B_ptr.ud(0), 8 * 32); |
18791 | } |
18792 | mov(2 | depAddr[0], addr0.ud(0)(1), A_ptr.ud()(1)); |
18793 | if (loadB) { |
18794 | mov(2 | depAddr[1], addr1.ud(0)(1), B_ptr.ud()(1)); |
18795 | add(1 | f1[1], addr2.ud(1), addr2.ud(1), 1); |
18796 | } |
18797 | } |
18798 | |
18799 | SBID tokenA(0), tokenB0(0), tokenB1(0); |
18800 | GRF dstA, dstB0, dstB1; |
18801 | |
18802 | if (useC) { |
18803 | tokenA = SBID((useC == 1) ? 5 : 11); |
18804 | tokenB0 = SBID((useC == 1) ? 6 : 12); |
18805 | tokenB1 = SBID((useC == 1) ? 7 : 13); |
18806 | int C_off = (useC == 1) ? 0 : 20; |
18807 | dstA = C_regs[C_off + 0]; |
18808 | dstB0 = C_regs[C_off + 8]; |
18809 | dstB1 = C_regs[C_off + 16]; |
18810 | } else { |
18811 | // Stronger than necessary dependencies... can load as soon as prev. store inputs are read. |
18812 | int loadBuffer = (strategy.slmCopies == 3) ? storeBuffer : 0; |
18813 | int t0 = 8 + loadBuffer * 2; |
18814 | tokenA = SBID(t0 + 0); |
18815 | tokenB0 = SBID(t0 + 1); |
18816 | tokenB1 = SBID(t0 + 2); |
18817 | dstA = A_copy[loadBuffer][0]; |
18818 | dstB0 = B_copy[loadBuffer][0]; |
18819 | dstB1 = B_copy[loadBuffer][8]; |
18820 | } |
18821 | |
18822 | if (surface) { |
18823 | load(1 | tokenA | SWSB<AllPipes>(1 + loadB * (1 + x48)), dstA, |
18824 | D64T(32) | strategy.A.cachingR, strategy.A.base, addr0); |
18825 | if (loadB) { |
18826 | load(1 | tokenB0 | loadBMod | SWSB<AllPipes>(1 + x48), dstB0, |
18827 | D64T(32) | strategy.B.cachingR, strategy.B.base, addr1); |
18828 | if (x48) |
18829 | load(1 | tokenB1 | loadBMod | SWSB<AllPipes>(1), dstB1, |
18830 | D64T(16) | strategy.B.cachingR, strategy.B.base, addr2); |
18831 | } |
18832 | } else { |
18833 | load(16 | tokenA | SWSB<AllPipes>(1 + loadB * (1 + x48)), dstA, |
18834 | block_hword(8), A64, addr0); |
18835 | if (loadB) { |
18836 | load(16 | tokenB0 | loadBMod | SWSB<AllPipes>(1 + x48), dstB0, |
18837 | block_hword(8), A64, addr1); |
18838 | if (x48) |
18839 | load(16 | tokenB1 | loadBMod | SWSB<AllPipes>(1), dstB1, |
18840 | block_hword(4), A64, addr2); |
18841 | } |
18842 | } |
18843 | depAddr[0] = tokenA.src; |
18844 | if (loadB) { |
18845 | depAddr[1] = tokenB0.src; |
18846 | if (x48) depAddr[2] = tokenB1.src; |
18847 | } |
18848 | if (strategy.simulation) { |
18849 | uint16_t tmask = (1 << tokenA.getID()); |
18850 | if (loadB) tmask |= (1 << tokenB0.getID()) | (1 << tokenB1.getID()); |
18851 | sync.allrd(tmask); |
18852 | } |
18853 | |
18854 | if (!emulate64) { |
18855 | add(1, A_ptr, A_ptr, unrollM * 32); |
18856 | if (loadB) add(1, B_ptr, B_ptr, 2 * unrollN * 32); |
18857 | } else { |
18858 | add(1 | ov | f1[1], A_ptr.ud(0), A_ptr.ud(0), unrollM * 32); |
18859 | if (loadB) |
18860 | add(1 | ov | f1[1] | M8, B_ptr.ud(0), B_ptr.ud(0), |
18861 | 2 * unrollN * 32); |
18862 | add(1 | f1[1], A_ptr.ud(1), A_ptr.ud(1), 1); |
18863 | if (loadB) add(1 | f1[1] | M8, B_ptr.ud(1), B_ptr.ud(1), 1); |
18864 | } |
18865 | } |
18866 | |
18867 | template <HW hw> |
18868 | void gemm_kernel_generator_t<hw>::sysgemmCopyStore(const GEMMProblem &problem, |
18869 | const GEMMStrategy &strategy, GEMMState &state, int storeBuffer, |
18870 | bool first) { |
18871 | using namespace sysgemm; |
18872 | auto &depAddr = state.sysgemm.depAddr; |
18873 | |
18874 | auto aoffset = first ? slmAOffsetStoreInit : slmAOffsetStore; |
18875 | auto boffset = first ? slmBOffsetStoreInit : slmBOffsetStore; |
18876 | |
18877 | // Store A and B and advance store pointers to next buffer. |
18878 | mov(1 | depAddr[0], addr0.ud(2), aoffset); |
18879 | mov(1 | depAddr[1], addr1.ud(2), boffset); |
18880 | add(1 | depAddr[2], addr2.ud(2), boffset, 8 * 32 / 16); |
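    // SLM offsets are kept in oword (16-byte) units; 8 * 32 / 16 advances by
    // eight 32-byte registers.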
18881 | |
18882 | if (first && strategy.slmCopies == 1) { |
18883 | store(16 | SWSB<AllPipes>(sb11, 3), block_oword(16), SLM, addr0, |
18884 | C_regs[0]); |
18885 | store(16 | SWSB<AllPipes>(sb12, 2), block_oword(16), SLM, addr1, |
18886 | C_regs[8]); |
18887 | if (strategy.unroll[LoopN] > 32) |
18888 | store(16 | SWSB<AllPipes>(sb13, 1), block_oword(8), SLM, addr2, |
18889 | C_regs[16]); |
18890 | depAddr[0] = sb11.src; |
18891 | depAddr[1] = sb12.src; |
18892 | if (strategy.unroll[LoopN] > 32) depAddr[2] = sb13.src; |
18893 | if (strategy.simulation) sync.allrd(0x3000); |
18894 | } else { |
18895 | int loadBuffer = (strategy.slmCopies == 3) ? storeBuffer : 0; |
18896 | int t0 = 8 + loadBuffer * 2; |
18897 | SBID token0 {t0}, token1 {t0 + 1}, token2 {t0 + 2}; |
18898 | |
18899 | store(16 | SWSB<AllPipes>(token0, 3), block_oword(16), SLM, addr0, |
18900 | A_copy[loadBuffer][0]); |
18901 | store(16 | SWSB<AllPipes>(token1, 2), block_oword(16), SLM, addr1, |
18902 | B_copy[loadBuffer][0]); |
18903 | if (strategy.unroll[LoopN] > 32) |
18904 | store(16 | SWSB<AllPipes>(token2, 1), block_oword(8), SLM, addr2, |
18905 | B_copy[loadBuffer][8]); |
18906 | depAddr[0] = token0.src; |
18907 | depAddr[1] = token1.src; |
18908 | if (strategy.unroll[LoopN] > 32) depAddr[2] = token2.src; |
18909 | if (strategy.simulation) sync.allrd(0x6 << t0); |
18910 | } |
18911 | |
18912 | if (storeBuffer == 2) |
18913 | mov(2, slmAOffsetStore(1), slmAOffsetStoreInit(1)); |
18914 | else |
18915 | add(2, slmAOffsetStore(1), aoffset(1), |
18916 | strategy.slmSysgemmBlockSize() / 16); |
18917 | } |
18918 | |
18919 | template <HW hw> |
18920 | void gemm_kernel_generator_t<hw>::sysgemmCopyStore4(const GEMMProblem &problem, |
18921 | const GEMMStrategy &strategy, GEMMState &state, int storeBuffer, |
18922 | bool storeB, int useC, int useC_B) { |
18923 | using namespace sysgemm; |
18924 | auto &depAddr = state.sysgemm.depAddr; |
18925 | bool first = (useC == 1); |
18926 | bool x48 = (strategy.unroll[LoopN] > 32); |
18927 | |
18928 | auto aoffset = first ? slmAOffsetStoreInit : slmAOffsetStore; |
18929 | auto boffset = first ? slmBOffsetStoreInit : slmBOffsetStore; |
18930 | |
18931 | // Store A and B and advance store pointers to next buffer. |
18932 | mov(1 | depAddr[0], addr0.ud(2), aoffset); |
18933 | if (storeB) { |
18934 | mov(1 | depAddr[1], addr1.ud(2), boffset); |
18935 | if (x48) add(1 | depAddr[2], addr2.ud(2), boffset, 8 * 32 / 16); |
18936 | } |
18937 | |
18938 | int loadBuffer = (strategy.slmCopies == 3) ? storeBuffer : 0; |
18939 | int t0 = 8 + loadBuffer * 2; |
18940 | auto tokenA = SBID(t0 + 0); |
18941 | auto tokenB0 = SBID(t0 + 1); |
18942 | auto tokenB1 = SBID(t0 + 2); |
18943 | auto srcA = A_copy[loadBuffer][0]; |
18944 | auto srcB0 = B_copy[loadBuffer][0]; |
18945 | auto srcB1 = B_copy[loadBuffer][8]; |
18946 | |
18947 | if (useC) { |
18948 | tokenA = SBID((useC == 1) ? 5 : 11); |
18949 | int C_off = (useC == 1) ? 0 : 20; |
18950 | srcA = C_regs[C_off + 0]; |
18951 | } |
18952 | |
18953 | if (useC_B) { |
18954 | tokenB0 = SBID((useC_B == 1) ? 6 : 12); |
18955 | tokenB1 = SBID((useC_B == 1) ? 7 : 13); |
18956 | int C_off = (useC_B == 1) ? 0 : 20; |
18957 | srcB0 = C_regs[C_off + 8]; |
18958 | srcB1 = C_regs[C_off + 16]; |
18959 | } |
18960 | |
18961 | store(16 | tokenA | SWSB<AllPipes>(1 + storeB * (1 + x48)), block_oword(16), |
18962 | SLM, addr0, srcA); |
18963 | if (storeB) { |
18964 | store(16 | tokenB0 | SWSB<AllPipes>(1 + x48), block_oword(16), SLM, |
18965 | addr1, srcB0); |
18966 | if (x48) |
18967 | store(16 | tokenB1 | SWSB<AllPipes>(1), block_oword(8), SLM, addr2, |
18968 | srcB1); |
18969 | } |
18970 | |
18971 | depAddr[0] = tokenA.src; |
18972 | if (storeB) { |
18973 | depAddr[1] = tokenB0.src; |
18974 | if (x48) depAddr[2] = tokenB1.src; |
18975 | } |
18976 | if (strategy.simulation) { |
18977 | uint16_t tmask = (1 << tokenA.getID()); |
18978 | if (storeB) tmask |= (1 << tokenB0.getID()) | (1 << tokenB1.getID()); |
18979 | sync.allrd(tmask); |
18980 | } |
18981 | |
18982 | if (storeBuffer == 3) |
18983 | mov(2, slmAOffsetStore(1), slmAOffsetStoreInit(1)); |
18984 | else |
18985 | add(2, slmAOffsetStore(1), aoffset(1), |
18986 | strategy.slmSysgemmBlockSize() / 16); |
18987 | } |
18988 | |
18989 | template <HW hw> |
18990 | void gemm_kernel_generator_t<hw>::sysgemmMultiply(const GEMMProblem &problem, |
18991 | const GEMMStrategy &strategy, GEMMState &state, int buffer, |
18992 | bool lastMultiply) { |
18993 | using namespace sysgemm; |
18994 | auto &depAddr = state.sysgemm.depAddr; |
18995 | |
18996 | // Load half of A (16x32) -- hopefully broadcast from SLM to this row -- and half of B, interleaved. |
18997 | InstructionModifier swsb = lastMultiply ? SWSB(1) : depAddr[0]; |
18998 | |
18999 | mov(1 | swsb, addr0.ud(2), slmAOffsetLoad); |
19000 | mov(1 | depAddr[1], addr1.ud(2), slmBOffsetLoad); |
19001 | add(1 | depAddr[2], addr2.ud(2), slmBOffsetLoad, 8 * 32 / 16); |
19002 | add(1 | depAddr[3], addr3.ud(2), slmBOffsetLoad, 16 * 32 / 16); |
19003 | |
19004 | if (strategy.slmAltBarriers) barrierwait(); |
19005 | |
19006 | if (strategy.simulation) sync(SyncFunction::nop, SWSB<int64_t>(1)); |
19007 | sync.nop(sb5.src); |
19008 | load(16 | SWSB<AllPipes>(sb3, 4), A_regs[0], block_oword(16), SLM, addr0); |
19009 | load(16 | SWSB<AllPipes>(sb0, 3), B_regs[0], block_oword(16), SLM, addr1); |
19010 | load(16 | SWSB<AllPipes>(sb1, 2), B_regs[8], block_oword(16), SLM, addr2); |
19011 | if (strategy.unroll[LoopN] > 32) |
19012 | load(16 | SWSB<AllPipes>(sb2, 1), B_regs[16], block_oword(16), SLM, |
19013 | addr3); |
19014 | |
19015 | add(1 | sb3.src, addr0.ud(2), slmAOffsetLoad, 8 * 32 / 16); |
19016 | add(1 | sb0.src, addr1.ud(2), slmAOffsetLoad, 16 * 32 / 16); |
19017 | add(1 | sb1.src, addr2.ud(2), slmAOffsetLoad, 24 * 32 / 16); |
19018 | load(16 | SWSB<AllPipes>(sb4, 3), A_regs[8], block_oword(16), SLM, addr0); |
19019 | |
19020 | // Wait for A data to load. |
19021 | sync.allwr(0x18); |
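    //  (0x18 = sb3 | sb4, the tokens for the two A loads above.)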
19022 | |
19023 | if (strategy.slmAltBarriers && !lastMultiply) { |
19024 | sysgemmBarrierPrep(sb2.src, addr3); |
19025 | barriermsg(SWSB<AllPipes>(sb15, 1), addr3); |
19026 | } |
19027 | |
19028 | // Rows 0-7 |
19029 | sysgemmMultiplyChunk( |
19030 | problem, strategy, false, 0, 0, true, false, sb0.dst, sb3); |
19031 | |
19032 | // Rows 8-15 |
19033 | sysgemmMultiplyChunk(problem, strategy, false, 8, 8, false, false, |
19034 | InstructionModifier(), sb4); |
19035 | |
19036 | // Load third quarter of A (8x32) |
19037 | load(16 | SWSB<AllPipes>(sb3, 2), A_regs[0], block_oword(16), SLM, addr1); |
19038 | |
19039 | // Rows 16-23 |
19040 | sysgemmMultiplyChunk( |
19041 | problem, strategy, false, 0, 16, false, false, sb3.dst); |
19042 | |
19043 | // Load last quarter of A (8x32) |
19044 | load(16 | SWSB<AllPipes>(sb4, 1), A_regs[8], block_oword(16), SLM, addr2); |
19045 | |
19046 | // Increment A and B to next buffer. |
19047 | swsb = strategy.simulation ? InstructionModifier(sb3.src) |
19048 | : InstructionModifier(); |
19049 | if (buffer == 2) |
19050 | mov(2 | swsb, slmAOffsetLoad(1), slmAOffsetLoadInit(1)); |
19051 | else |
19052 | add(2 | swsb, slmAOffsetLoad(1), slmAOffsetLoad(1), |
19053 | strategy.slmSysgemmBlockSize() / 16); |
19054 | |
19055 | // Rows 24-31 |
19056 | sysgemmMultiplyChunk( |
19057 | problem, strategy, false, 8, 24, false, false, sb4.dst, sb5); |
19058 | |
19059 | // Remember dependencies for address registers. |
19060 | depAddr[0] = InstructionModifier {}; |
19061 | depAddr[1] = sb3.src; |
19062 | depAddr[2] = sb4.src; |
19063 | depAddr[3] = strategy.slmAltBarriers ? sb15.src : sb2.src; |
19064 | } |
19065 | |
19066 | template <HW hw> |
19067 | void gemm_kernel_generator_t<hw>::sysgemmMultiply4(const GEMMProblem &problem, |
19068 | const GEMMStrategy &strategy, GEMMState &state, int buffer, |
19069 | bool firstMultiply, RegData flagWaitLoad, RegData flagSignal, |
19070 | Label *labelDone) { |
19071 | using namespace sysgemm; |
19072 | auto &depAddr = state.sysgemm.depAddr; |
19073 | bool x48 = (strategy.unroll[LoopN] > 32); |
19074 | uint16_t slmStride = strategy.slmSysgemmBlockSize() / 16; |
19075 | int16_t slmAdvance = ((buffer == 3) ? -3 : 1) * slmStride; |
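    // Advance one SLM buffer per multiply; buffer 3 wraps back to buffer 0 by
    // backing up three strides.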
19076 | |
19077 | InstructionModifier loadMod {}, signalMod {}; |
19078 | bool cooldownWaitLoad = flagWaitLoad.isValid(); |
19079 | bool cooldownSignal = flagSignal.isValid(); |
19080 | bool doWaitLoad = !cooldownWaitLoad || !flagWaitLoad.isNull(); |
19081 | bool doSignal = !cooldownSignal || !flagSignal.isNull(); |
19082 | auto fWaitLoad = static_cast<FlagRegister &>(flagWaitLoad); |
19083 | auto fSignal = static_cast<FlagRegister &>(flagSignal); |
19084 | if (doWaitLoad && cooldownWaitLoad) loadMod = loadMod | fWaitLoad | any16h; |
19085 | if (doSignal && cooldownSignal) signalMod = signalMod | fSignal | any16h; |
19086 | |
19087 | // Fence. |
19088 | if (doSignal) { |
19089 | sync.nop(depAddr[0]); |
19090 | slmfence(sb15 | signalMod, addr0); |
19091 | depAddr[0] = sb15.dst; |
19092 | } |
19093 | |
19094 | // Rows 0-7. Upper half of A (16x32) is already loaded. |
19095 | sync.nop(sb3.dst); |
19096 | depAddr[3] = InstructionModifier(); |
19097 | sysgemmMultiplyChunk( |
19098 | problem, strategy, firstMultiply, 0, 0, true, false, sb0.dst, sb3); |
19099 | |
19100 | // Prepare addresses for loading lower half of A, and part of B |
19101 | add(1 | depAddr[1], addr1.ud(2), slmAOffsetLoad, 16 * 32 / 16); |
19102 | add(1 | depAddr[2], addr2.ud(2), slmAOffsetLoad, 24 * 32 / 16); |
19103 | sysgemmBarrierPrep(depAddr[3], addr3); |
19104 | |
19105 | // Rows 8-15. |
19106 | sysgemmMultiplyChunk( |
19107 | problem, strategy, firstMultiply, 8, 8, false, false, sb4.dst, sb4); |
19108 | |
19109 | // Load lower half of A (16x32) -- hopefully broadcast from SLM to this row. |
19110 | load(16 | SWSB<AllPipes>(sb3, 3), A_regs[0], block_oword(16), SLM, addr1); |
19111 | load(16 | SWSB<AllPipes>(sb4, 2), A_regs[8], block_oword(16), SLM, addr2); |
19112 | depAddr[1] = sb3.src; |
19113 | depAddr[2] = sb4.src; |
19114 | |
19115 | // Rows 16-23. |
19116 | sysgemmMultiplyChunk(problem, strategy, firstMultiply, 0, 16, false, false, |
19117 | sb3.dst, sb3); |
19118 | depAddr[1] = InstructionModifier(); |
19119 | |
19120 | // Address prep, part 2. |
19121 | add(1 | depAddr[1], addr1.ud(2), slmBOffsetLoad, slmAdvance + 0 * 32 / 16); |
19122 | add(1 | depAddr[2], addr2.ud(2), slmBOffsetLoad, slmAdvance + 8 * 32 / 16); |
19123 | if (x48) |
19124 | add(1 | depAddr[0], addr0.ud(2), slmBOffsetLoad, |
19125 | slmAdvance |
19126 | + 16 * 32 |
19127 | / 16); // consider moving after next dpasw block |
19128 | |
19129 | // Rows 24-31. |
19130 | sysgemmMultiplyChunk( |
19131 | problem, strategy, firstMultiply, 8, 24, false, true, sb4.dst, sb2); |
19132 | |
19133 | if (doWaitLoad) { |
19134 | if (cooldownWaitLoad) jmpi(1 | ~fWaitLoad, *labelDone); |
19135 | |
19136 | // Split barrier. |
19137 | barrierwait(); |
19138 | if (doSignal) { |
19139 | barriermsg(SWSB<AllPipes>(sb15, x48 ? 4 : 3) | signalMod, addr3); |
19140 | depAddr[3] = sb15.src; |
19141 | } |
19142 | |
19143 | // Load next B data and upper 16x32 of A. |
19144 | load(16 | SWSB<AllPipes>(sb0, x48 ? 3 : 2), B_regs[0], block_oword(16), |
19145 | SLM, addr1); |
19146 | load(16 | SWSB<AllPipes>(sb1, x48 ? 2 : 1), B_regs[8], block_oword(16), |
19147 | SLM, addr2); |
19148 | if (x48) |
19149 | load(16 | SWSB<AllPipes>(sb2, 1), B_regs[16], block_oword(16), SLM, |
19150 | addr0); |
19151 | depAddr[1] = sb0.src; |
19152 | depAddr[2] = sb1.src; |
19153 | if (x48) depAddr[0] = sb2.src; |
19154 | |
19155 | add(1 | depAddr[3], addr3.ud(2), slmAOffsetLoad, |
19156 | slmAdvance + 0 * 32 / 16); |
19157 | add(1 | depAddr[2], addr2.ud(2), slmAOffsetLoad, |
19158 | slmAdvance + 8 * 32 / 16); |
19159 | InstructionModifier swsb; |
19160 | if (strategy.simulation) swsb = sb2.src; |
19161 | if (buffer == 3) |
19162 | mov(2 | swsb, slmAOffsetLoad(1), slmAOffsetLoadInit(1)); |
19163 | else |
19164 | add(2 | swsb, slmAOffsetLoad(1), slmAOffsetLoad(1), slmAdvance); |
19165 | |
19166 | load(16 | SWSB<AllPipes>(sb3, 3), A_regs[0], block_oword(16), SLM, |
19167 | addr3); |
19168 | load(16 | SWSB<AllPipes>(sb4, 2), A_regs[8], block_oword(16), SLM, |
19169 | addr2); |
19170 | depAddr[3] = sb3.src; |
19171 | depAddr[2] = sb4.src; |
19172 | } |
19173 | } |
19174 | |
19175 | template <HW hw> |
19176 | void gemm_kernel_generator_t<hw>::sysgemmMultiplyChunk( |
19177 | const GEMMProblem &problem, const GEMMStrategy &strategy, bool first, |
19178 | int ao, int i0, bool waitB, bool prepB, |
19179 | const InstructionModifier &swsb0, const InstructionModifier &swsbEnd) { |
19180 | using namespace sysgemm; |
19181 | int co = i0 * 6; |
19182 | |
19183 | auto dpaswTyped |
19184 | = [&](InstructionModifier mod, uint8_t sdepth, uint8_t rcount, |
19185 | const GRF &cReg, const GRF &aReg, const GRF &bReg) { |
19186 | auto A = aReg.retype(problem.Ta.ngen()); |
19187 | auto B = bReg.retype(problem.Tb.ngen()); |
19188 | auto C = cReg.retype(problem.Tc.ngen()); |
19189 | first ? dpasw(mod, sdepth, rcount, C, |
19190 | null.retype(problem.Tc.ngen()), A, B) |
19191 | : dpasw(mod, sdepth, rcount, C, C, A, B); |
19192 | }; |
19193 | |
19194 | if (strategy.unroll[LoopN] > 32) { |
19195 | if (waitB) { |
19196 | dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao], |
19197 | B_regs[0]); |
19198 | dpaswTyped(8, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]); |
19199 | dpaswTyped(8 | sb1.dst | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], |
19200 | B_regs[8]); |
19201 | dpaswTyped(8, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]); |
19202 | dpaswTyped(8 | sb2.dst | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], |
19203 | B_regs[16]); |
19204 | dpaswTyped( |
19205 | 8 | swsbEnd, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]); |
19206 | } else if (prepB) { |
19207 | dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao], |
19208 | B_regs[0]); |
19209 | dpaswTyped(8 | sb0, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]); |
19210 | dpaswTyped( |
19211 | 8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]); |
19212 | dpaswTyped(8 | sb1, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]); |
19213 | dpaswTyped( |
19214 | 8 | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], B_regs[16]); |
19215 | dpaswTyped( |
19216 | 8 | swsbEnd, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]); |
19217 | } else { |
19218 | dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao], |
19219 | B_regs[0]); |
19220 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]); |
19221 | dpaswTyped( |
19222 | 8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]); |
19223 | dpaswTyped( |
19224 | 8 | Atomic, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]); |
19225 | dpaswTyped( |
19226 | 8 | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], B_regs[16]); |
19227 | dpaswTyped( |
19228 | 8 | swsbEnd, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]); |
19229 | } |
19230 | } else { |
19231 | if (waitB) { |
19232 | dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao], |
19233 | B_regs[0]); |
19234 | dpaswTyped(8, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]); |
19235 | dpaswTyped(8 | sb1.dst | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], |
19236 | B_regs[8]); |
19237 | dpaswTyped( |
19238 | 8 | swsbEnd, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]); |
19239 | } else if (prepB) { |
19240 | dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao], |
19241 | B_regs[0]); |
19242 | dpaswTyped(8 | sb0, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]); |
19243 | dpaswTyped( |
19244 | 8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]); |
19245 | dpaswTyped( |
19246 | 8 | swsbEnd, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]); |
19247 | } else { |
19248 | dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao], |
19249 | B_regs[0]); |
19250 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]); |
19251 | dpaswTyped( |
19252 | 8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]); |
19253 | dpaswTyped( |
19254 | 8 | swsbEnd, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]); |
19255 | } |
19256 | } |
19257 | } |
19258 | |
19259 | template <HW hw> |
19260 | void gemm_kernel_generator_t<hw>::sysgemmBarrierPrep( |
19261 | const InstructionModifier &swsb, const GRF &) { |
19262 | using namespace sysgemm; |
19263 | mov<uint32_t>(1 | swsb, header[2], barrierVal); |
19264 | } |
19265 | |
19266 | template <HW hw> |
19267 | void gemm_kernel_generator_t<hw>::sysgemmReorderLocalIDs( |
19268 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
19269 | GEMMState &state) { |
19270 | if (strategy.splitCopy) return; |
19271 | if (strategy.wg[LoopM] != 8) return; |
19272 | |
19273 | auto storage = state.ra.alloc_sub<uint64_t>(); |
19274 | auto temp = storage.uw(0); |
19275 | auto temp2 = storage.uw(2); |
19276 | auto lidM = state.inputs.localIDM; |
19277 | auto lidN = state.inputs.localIDN; |
19278 | |
19279 | // Remap local IDs so that upper 4x4 threads come first, then lower 4x4 threads. |
19280 | bfi2(1, temp, 0x08, lidN, lidM); |
19281 | shr(1, temp2, lidN, 1); |
19282 | shr(1, lidN, temp, 2); |
19283 | bfi2(1, lidM, 0x04, temp2, lidM); |
19284 | |
19285 | state.ra.safeRelease(storage); |
19286 | } |
19287 | |
19288 | namespace sysgemm2 { |
19289 | namespace x48 { |
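// Fixed register assignments for the unrollN = 48 compute path. These
// systolic kernels claim whole ranges up front instead of allocating
// registers dynamically.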
19290 | static GRFRange A_regs = GRF(32) - GRF(63); |
19291 | static GRFRange B_regs = GRF(2) - GRF(25); |
19292 | static GRFRange C_regs = GRF(64) - GRF(255); |
19293 | static Subregister B_addr[3] = {GRF(26).ud(2), GRF(27).ud(2), GRF(1).ud(2)}; |
19294 | static Subregister A_addr[4] |
19295 | = {GRF(28).ud(2), GRF(29).ud(2), GRF(30).ud(2), GRF(31).ud(2)}; |
static GRF barrierHeader = GRF(0); // (identifier missing in source; name assumed)
19297 | } // namespace x48 |
19298 | |
19299 | namespace x32 { |
19300 | static GRFRange A_regs[2] = {GRF(32) - GRF(63), GRF(96) - GRF(127)}; |
19301 | static GRFRange B_regs[2] = {GRF(2) - GRF(17), GRF(66) - GRF(81)}; |
19302 | static GRFRange C_regs = GRF(128) - GRF(255); |
19303 | static Subregister B_addr[2][2] |
19304 | = {{GRF(26).ud(2), GRF(27).ud(2)}, {GRF(90).ud(2), GRF(91).ud(2)}}; |
19305 | static Subregister A_addr[2][4] |
19306 | = {{GRF(28).ud(2), GRF(29).ud(2), GRF(30).ud(2), GRF(31).ud(2)}, |
19307 | {GRF(92).ud(2), GRF(93).ud(2), GRF(94).ud(2), GRF(95).ud(2)}}; |
static GRF barrierHeader = GRF(0); // (identifiers missing in source; names assumed)
static GRF fenceHeader = GRF(64);
19310 | } // namespace x32 |
19311 | |
19312 | static GRFRange copyInputs = GRF(254) - GRF(255); |
19313 | static Subregister A_copyLoadAddr0 = GRF(254).uq(0); |
19314 | static Subregister A_copyLoadAddrSurf0 = GRF(254).ud(2); |
19315 | static Subregister slmAOff = GRF(254).d(4); |
19316 | static Subregister lda = GRF(254).ud(6); |
19317 | static Subregister B_copyLoadAddr0 = GRF(255).uq(0); |
19318 | static Subregister B_copyLoadAddrSurf0 = GRF(255).ud(2); |
19319 | static Subregister slmBOff[2] = {GRF(255).d(4), GRF(255).d(5)}; |
19320 | static Subregister ldb = GRF(255).ud(6); |
19321 | |
19322 | static Subregister kCounter = AccumulatorRegister(2).d(0); |
19323 | static Subregister barrierVal = AddressRegister(0).ud(0); |
19324 | } // namespace sysgemm2 |
19325 | |
19326 | template <HW hw> |
19327 | bool gemm_kernel_generator_t<hw>::sysgemm2AccumulateC( |
19328 | GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state) { |
19329 | using namespace sysgemm2; |
19330 | auto params = systolicParams(hw, problem, strategy); |
19331 | auto unrollM = strategy.unroll[LoopM]; |
19332 | auto unrollN = strategy.unroll[LoopN]; |
19333 | auto localIDM = state.lidM; |
19334 | auto localIDN = state.lidN; |
19335 | auto C_regs = (unrollN == 48) ? x48::C_regs : x32::C_regs; |
19336 | |
19337 | if (unrollM != 16 && unrollM != 32) stub(); |
19338 | if (unrollN != 32 && unrollN != 48) stub(); |
19339 | if (isPacked(problem.A.layout)) { |
19340 | if (problem.A.crosspack != params.opsPerChan) stub(); |
19341 | if (problem.A.tileR != params.osys) stub(); |
19342 | if (problem.A.tileC != params.ksys) stub(); |
19343 | } |
19344 | if (isPacked(problem.B.layout)) { |
19345 | if (problem.B.crosspack != params.ksys) stub(); |
19346 | if (problem.B.tileR != 0 || problem.B.tileC != 0) stub(); |
19347 | } |
19348 | |
19349 | state.ra.claim(C_regs); |
19350 | |
19351 | // Check whether this thread will do copy or compute. Copy threads get priority (lower thread #). |
19352 | auto flagCompute = f1[1]; |
19353 | mov(1, flagCompute, state.isCompute.uw()); |
19354 | state.ra.safeRelease(state.isCompute); |
19355 | |
19356 | // Calculate A/B addresses and SLM offsets. |
19357 | auto tempStorage = C_regs[0]; |
19358 | auto suboffsetA = tempStorage.ud(0); |
19359 | auto suboffsetB = tempStorage.ud(1); |
19360 | auto tempB = tempStorage.ud(2); |
19361 | auto ldaUnrollM4 = tempStorage.ud(3); |
19362 | auto aInc = tempStorage.ud(4); |
19363 | auto ldbUnrollN4 = tempStorage.ud(5); |
19364 | auto bInc = tempStorage.ud(6); |
19365 | |
19366 | if (problem.A.layout == MatrixLayout::T) |
19367 | mulConstant(1, ldaUnrollM4, state.inputs.lda, unrollM / 4); |
19368 | if (problem.B.layout == MatrixLayout::N) |
19369 | mulConstant(1, ldbUnrollN4, state.inputs.ldb, unrollN / 4); |
19370 | if (!isPacked(problem.A.layout)) mov(1, lda, state.inputs.lda); |
19371 | if (!isPacked(problem.B.layout)) mov(1, ldb, state.inputs.ldb); |
19372 | |
19373 | and_(1 | ne | state.flagAP, null.uw(), localIDM, 1); |
19374 | |
19375 | switch (problem.A.layout) { |
19376 | case MatrixLayout::Pc: |
19377 | mulConstant(1, aInc, localIDN, unrollM * (32 / 4)); |
19378 | break; |
19379 | case MatrixLayout::N: |
19380 | mulConstant(1, aInc, localIDN, unrollM * problem.Ta / 4); |
19381 | break; |
19382 | case MatrixLayout::T: mul(1, aInc, ldaUnrollM4, localIDN.uw()); break; |
19383 | default: stub(); |
19384 | } |
19385 | switch (problem.B.layout) { |
19386 | case MatrixLayout::Pr: |
19387 | mulConstant(1, bInc, localIDM, unrollN * (32 / 4)); |
19388 | break; |
19389 | case MatrixLayout::N: mul(1, bInc, ldbUnrollN4, localIDM.uw()); break; |
19390 | case MatrixLayout::T: |
19391 | mulConstant(1, bInc, localIDM, unrollN * problem.Tb / 4); |
19392 | break; |
19393 | default: stub(); |
19394 | } |
19395 | |
19396 | mulConstant(1, suboffsetA, localIDN, unrollM * (32 / 4) / 16); |
19397 | mulConstant(1, suboffsetB, localIDM, unrollN * (32 / 4) / 16); |
19398 | |
19399 | if (strategy.A.base.isStateless()) |
19400 | eadd(1, A_copyLoadAddr0, state.effA, aInc, strategy, state); |
19401 | else |
19402 | add(1, A_copyLoadAddrSurf0, state.effA, aInc); |
19403 | |
19404 | if (strategy.B.base.isStateless()) |
19405 | eadd(1, B_copyLoadAddr0, state.effB, bInc, strategy, state); |
19406 | else |
19407 | add(1, B_copyLoadAddrSurf0, state.effB, bInc); |
19408 | |
19409 | mad(1, tempB, (4 * unrollM * 36) / 16, localIDN, (unrollN * 32) / 16); |
19410 | |
19411 | mul(1, x48::A_addr[0], localIDM, (unrollM * 36) / 16); |
19412 | add(1 | state.flagAP, x48::B_addr[0], tempB, (unrollN / 2) * (32 / 16)); |
19413 | mov(1 | ~state.flagAP, x48::B_addr[0], tempB); |
19414 | |
19415 | add(1, slmAOff, x48::A_addr[0], suboffsetA); |
19416 | add(1, slmBOff[0], tempB, suboffsetB); |
19417 | add3(1, slmBOff[1], tempB, suboffsetB, 8 * 32 / 16); |
19418 | |
19419 | // Marshal data needed later into acc2 for safekeeping. |
19420 | auto saveData = state.ra.alloc_range(2); |
19421 | auto kLoops = saveData[0].d(0); |
19422 | auto ldc = saveData[0].ud(1); |
19423 | auto flags = saveData[0].ud(2); |
19424 | auto k = saveData[0].ud(3); |
19425 | auto remM = saveData[0].uw(8); |
19426 | auto remN = saveData[0].uw(9); |
19427 | auto abo = saveData[0].ud(5); |
19428 | auto ao = saveData[0].w(10); |
19429 | auto bo = saveData[0].w(11); |
19430 | auto alpha = saveData[0].ud(6).reinterpret(0, problem.Ts.ngen()); |
19431 | auto beta = saveData[0].ud(7).reinterpret(0, problem.Ts.ngen()); |
19432 | auto remFusedStorage = saveData[1].ud(0); |
19433 | auto diagC = saveData[1].ud(1); |
19434 | auto saveI0 = saveData[1].ud(1); |
19435 | auto effCO = saveData[1].uq(1); |
19436 | auto saveJ0 = saveData[1].ud(3); |
19437 | auto C_ptr = saveData[1].uq(2); |
19438 | auto slotAB = saveData[1].ud(6); |
19439 | auto effAs = a0.ud(4); // dwords 4-5 |
19440 | auto effBs = a0.ud(6); // dwords 6-7 |
19441 | |
19442 | if (state.r0_info != acc0.ud()) mov<uint32_t>(8, acc0, state.r0_info); |
19443 | |
19444 | add(1, kLoops, state.k, params.ksys - 1); |
19445 | mov(1, ldc, state.inputs.ldc[0]); |
19446 | emov(1, C_ptr, state.effC[0], strategy, state); |
19447 | if (state.inputs.flags.isValid()) mov(1, flags, state.inputs.flags); |
19448 | mov(1, k, state.k); |
19449 | if (state.remainders[LoopM].isValid()) |
19450 | mov(1, remM, state.remainders[LoopM]); |
19451 | if (state.remainders[LoopN].isValid()) |
19452 | mov(1, remN, state.remainders[LoopN]); |
19453 | if (state.inputs.abo.isValid()) |
19454 | mov(1, abo, state.inputs.abo); |
19455 | else { |
19456 | if (state.inputs.ao.isValid()) mov(1, ao, state.inputs.ao); |
19457 | if (state.inputs.bo.isValid()) mov(1, bo, state.inputs.bo); |
19458 | } |
19459 | if (state.inputs.alpha_real.isValid()) |
19460 | mov(1, alpha, state.inputs.alpha_real); |
19461 | if (state.inputs.beta_real.isValid()) mov(1, beta, state.inputs.beta_real); |
19462 | shr(1, kLoops, kLoops, log2(params.ksys)); |
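    // Together with the add of ksys - 1 above, this computes
    // kLoops = ceil(k / ksys).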
19463 | if (state.remFusedStorage.isValid()) |
19464 | mov(1, remFusedStorage, state.remFusedStorage); |
19465 | if (state.diagC.isValid()) mov(1, diagC, state.diagC); |
19466 | if (state.effCO.isValid()) { |
19467 | effCO = effCO.reinterpret(0, state.effCO.getType()); |
19468 | emov(1, effCO, state.effCO, strategy, state); |
19469 | } |
19470 | if (problem.hasBinaryPostOp()) { |
19471 | if (state.diagC.isValid()) stub(); |
19472 | if (state.effCO.isValid() && effCO.getBytes() > 4) stub(); |
19473 | mov(1, saveI0, state.i0); |
19474 | mov(1, saveJ0, state.j0); |
19475 | } |
19476 | if (problem.abOffset != ABOffset::None) { |
19477 | GRF temp = state.ra.alloc(); |
19478 | state.effAs = temp.uq(0).reinterpret(0, state.effA.getType()); |
19479 | state.effBs = temp.uq(1).reinterpret(0, state.effB.getType()); |
19480 | gemmCalcABOffsetAddrs(problem, strategy, state); |
19481 | mov<uint32_t>(4, effAs(1), temp); |
19482 | state.ra.safeRelease(temp); |
19483 | } |
19484 | if (state.fusedGEMM.slotA.isValid()) |
19485 | mov(1, slotAB, state.fusedGEMM.slotA.ud()); |
19486 | |
19487 | if (state.isNested) { |
19488 | // To do: replace with sel |
19489 | mov(2 | ~flagCompute, remM(1), 0); |
19490 | mov(1 | ~flagCompute, remFusedStorage, 0); |
19491 | } |
19492 | |
19493 | releaseSavedMNLocalIDs(state); |
19494 | state.ra.safeRelease(state.effA); |
19495 | state.ra.safeRelease(state.effB); |
19496 | state.ra.safeRelease(state.effC[0]); |
19497 | state.ra.safeRelease(state.inputs.lda); |
19498 | state.ra.safeRelease(state.inputs.ldb); |
19499 | |
19500 | state.ra.release(state.inputs.ldc[0]); |
19501 | state.ra.release(state.k); |
19502 | state.ra.release(state.remainders[LoopM]); |
19503 | state.ra.release(state.remainders[LoopN]); |
19504 | state.ra.release(state.inputs.abo); |
19505 | state.ra.release(state.inputs.ao); |
19506 | state.ra.release(state.inputs.bo); |
19507 | state.ra.release(state.inputs.alpha_real); |
19508 | state.ra.release(state.inputs.beta_real); |
19509 | state.ra.release(state.remFusedStorage); |
19510 | state.ra.release(state.diagC); |
19511 | state.ra.release(state.effCO); |
19512 | state.ra.release(state.fusedGEMM.slotA); |
19513 | state.ra.release(state.fusedGEMM.slotB); |
19514 | |
19515 | if (state.r0_info.isARF()) stub(); |
19516 | GRF r0_info {state.r0_info.getBase()}; |
19517 | if (hw >= HW::XeHPG) { |
19518 | mov(1, barrierVal.uw(0), Immediate::uw(0)); |
19519 | mov(2, barrierVal.ub(2)(1), r0_info.ub(11)(0)); |
19520 | } else |
19521 | and_(1, barrierVal, r0_info.ud(2), 0x7F000000); |
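    // barrierVal holds the barrier message header: XeHPG+ copies the barrier
    // ID bytes from r0 directly; older hardware masks the ID field in place.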
19522 | |
19523 | mov<float>(16, acc2, saveData[0]); |
19524 | |
19525 | Label labelCompute, labelDone; |
19526 | |
19527 | jmpi(1 | f1[1], labelCompute); |
19528 | sysgemm2KLoopCopy(problem, strategy, state); |
19529 | if (state.isNested) { |
19530 | jmpi(1, labelDone); |
19531 | } else |
19532 | epilogue(strategy, state); |
19533 | mark(labelCompute); |
19534 | sysgemm2KLoopCompute(problem, strategy, state); |
19535 | mark(labelDone); |
19536 | |
19537 | mov<float>(16, saveData[0], acc2); |
19538 | |
19539 | state.effC[0] = C_ptr; |
19540 | state.inputs.ldc[0] = ldc; |
19541 | if (state.inputs.flags.isValid()) state.inputs.flags = flags; |
19542 | state.k = k; |
19543 | if (state.remainders[LoopM].isValid()) state.remainders[LoopM] = remM; |
19544 | if (state.remainders[LoopN].isValid()) state.remainders[LoopN] = remN; |
19545 | if (state.inputs.abo.isValid()) state.inputs.abo = abo; |
19546 | if (state.inputs.ao.isValid()) state.inputs.ao = ao; |
19547 | if (state.inputs.bo.isValid()) state.inputs.bo = bo; |
19548 | if (state.inputs.alpha_real.isValid()) { |
19549 | state.inputs.alpha_real = alpha; |
19550 | if (!problem.alpha_real.fixed()) problem.alpha_real = alpha; |
19551 | } |
19552 | if (state.inputs.beta_real.isValid()) { |
19553 | state.inputs.beta_real = beta; |
19554 | if (!problem.beta_real.fixed()) problem.beta_real = beta; |
19555 | } |
19556 | if (state.remFusedStorage.isValid()) { |
19557 | state.remFusedStorage = remFusedStorage; |
19558 | state.remaindersFused[LoopM] = state.remainders[LoopM]; |
19559 | state.remaindersFused[LoopN] = state.remainders[LoopN]; |
19560 | state.remaindersFused[strategy.fusedLoop] = remFusedStorage; |
19561 | } |
19562 | if (state.diagC.isValid()) state.diagC = diagC; |
19563 | if (state.effCO.isValid()) state.effCO = effCO; |
19564 | if (state.fusedGEMM.slotA.isValid()) { |
19565 | state.fusedGEMM.slotA = slotAB.uw(0); |
19566 | state.fusedGEMM.slotB = slotAB.uw(1); |
19567 | } |
19568 | if (problem.abOffset != ABOffset::None) { |
19569 | auto tas = state.effAs.getType(); |
19570 | auto tbs = state.effBs.getType(); |
19571 | state.effAs = state.ra.alloc_sub(tas); |
19572 | state.effBs = state.ra.alloc_sub(tbs); |
19573 | mov<uint32_t>(getDwords(tas), state.effAs.ud()(1), effAs(1)); |
19574 | mov<uint32_t>(getDwords(tbs), state.effBs.ud()(1), effBs(1)); |
19575 | } |
19576 | if (problem.hasBinaryPostOp()) { |
19577 | state.i0 = saveI0; |
19578 | state.j0 = saveJ0; |
19579 | } |
19580 | |
19581 | // Set up C internal layout and registers. |
19582 | state.C_regs.resize(1); |
19583 | state.C_regs[0] = C_regs; |
19584 | state.C_layout.clear(); |
19585 | state.C_layout.reserve((unrollM / 8) * (unrollN / 4)); |
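    // C is stored column-interleaved: 4-column blocks alternate between the
    // low and high halves of each row's register span. For unrollN = 48,
    // j0 = 0, 4, ..., 20 maps to j0Interleaved = 0, 8, ..., 40, and j0 = 24
    // wraps around to 4.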
19586 | for (int j0 = 0; j0 < unrollN; j0 += 4) { |
19587 | for (int i0 = 0; i0 < unrollM; i0 += 8) { |
19588 | RegisterBlock block; |
19589 | block.log2GRFBytes = GRF::log2Bytes(hw); |
19590 | block.colMajor = true; |
19591 | block.splitComplex = false; |
19592 | block.cxComponent = RegisterBlock::Interleaved; |
19593 | block.nr = block.ld = 8; |
19594 | block.nc = 4; |
19595 | block.component = 0; |
19596 | block.offsetR = i0; |
19597 | block.offsetC = j0; |
19598 | block.crosspack = 1; |
19599 | block.bytes = 8 * 4 * problem.Tc.size(); |
19600 | block.simdSize = 0; |
19601 | |
19602 | int j0Interleaved = j0 << 1; |
19603 | if (j0Interleaved >= unrollN) j0Interleaved += 4 - unrollN; |
19604 | |
19605 | block.offsetBytes |
19606 | = (unrollN * i0 / 8 + j0Interleaved) * GRF::bytes(hw); |
19607 | state.C_layout.push_back(block); |
19608 | } |
19609 | } |
19610 | |
19611 | // Set up C external layout. |
19612 | state.copyC = true; |
19613 | bool remM_Ce, remN_Ce; |
19614 | getCRemainders(problem, strategy, remM_Ce, remN_Ce); |
19615 | |
19616 | if (!getRegLayout(problem.Tc_ext, state.C_layoutExt, unrollM, unrollN, |
19617 | remM_Ce, remN_Ce, true, false, 0, 0, problem.C, |
19618 | state.Cext_strategy)) |
19619 | return false; |
19620 | if (remM_Ce || remN_Ce) |
19621 | (void)getRegLayout(problem.Tc_ext, state.C_layoutExtUnmasked, unrollM, |
19622 | unrollN, false, false, true, false, 0, 0, problem.C, |
19623 | state.Cext_strategy); |
19624 | |
19625 | if (state.r0_info != acc0.ud()) mov<uint32_t>(8, state.r0_info, acc0); |
19626 | |
19627 | return true; // Success! |
19628 | } |
19629 | |
19630 | template <HW hw> |
19631 | void gemm_kernel_generator_t<hw>::sysgemm2KLoopCopy(const GEMMProblem &problem, |
19632 | const GEMMStrategy &strategy, GEMMState &state) { |
19633 | using namespace sysgemm2; |
19634 | |
19635 | Label top, bottom, smallK, skipSmallK, reenterSmallK, done; |
19636 | |
19637 | bool surfaceA = !strategy.A.base.isStateless(); |
19638 | bool surfaceB = !strategy.B.base.isStateless(); |
19639 | auto unrollM = strategy.unroll[LoopM]; |
19640 | auto unrollN = strategy.unroll[LoopN]; |
19641 | bool _32x = (unrollM == 32); |
19642 | bool x48 = (unrollN == 48); |
19643 | bool int8 = problem.Ta.size() == 1; |
19644 | |
19645 | int globalBuffers = strategy.A_copies; |
19646 | auto slmStride = strategy.slmSysgemmBlockSize() / 16; |
19647 | |
19648 | if (globalBuffers < 2) stub(); |
19649 | if (globalBuffers > 5) stub(); |
19650 | |
19651 | auto saveRA = state.ra; |
19652 | state.ra.release(r0 - r127); |
19653 | state.ra.release(r128 - r255); |
19654 | state.ra.claim(copyInputs); |
19655 | |
19656 | // Register and token allocation. |
19657 | int aTokens = 0, aLoadAddrs = 0, aStoreAddrs = 1; |
19658 | int bTokens = 0, bLoadAddrs = 0, bStoreAddrs = 1 + x48; |
19659 | int aiRegCount = unrollM / 4, biRegCount = unrollN / 4; |
19660 | bool aRepack = false, bRepack = false; |
19661 | |
19662 | if (problem.A.alignment & 3 || problem.B.alignment & 3) stub(); |
19663 | |
19664 | switch (problem.A.layout) { |
19665 | case MatrixLayout::Pc: |
19666 | aTokens = aLoadAddrs = (surfaceA && _32x) ? 2 : 1; |
19667 | break; |
19668 | case MatrixLayout::N: |
19669 | if (!surfaceA) stub(); |
19670 | aTokens = int8 ? 2 : 1; |
19671 | aLoadAddrs = aTokens * 2; |
19672 | aRepack = true; |
19673 | break; |
19674 | case MatrixLayout::T: |
19675 | if (!surfaceA) stub(); |
19676 | aTokens = aLoadAddrs = 2; |
19677 | break; |
19678 | default: stub(); |
19679 | } |
19680 | |
19681 | switch (problem.B.layout) { |
19682 | case MatrixLayout::Pr: |
19683 | bTokens = bLoadAddrs = (surfaceB ? 2 : 1) + x48; |
19684 | break; |
19685 | case MatrixLayout::N: |
19686 | if (!surfaceB) stub(); |
19687 | bTokens = 2; |
19688 | bLoadAddrs = x48 ? 4 : 2; |
19689 | if (x48) biRegCount = 16; |
19690 | bRepack = true; |
19691 | break; |
19692 | case MatrixLayout::T: |
19693 | if (!surfaceB) stub(); |
19694 | bTokens = (int8 || x48) ? 2 : 1; |
19695 | bLoadAddrs = bTokens * 2; |
19696 | bRepack = true; |
19697 | break; |
19698 | default: stub(); |
19699 | } |
19700 | |
19701 | int tokenStride = aTokens + bTokens; |
19702 | if (tokenStride * globalBuffers > 15) |
19703 | throw std::runtime_error("Not enough tokens available." ); |
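    // Hardware exposes 16 SWSB tokens; capping at 15 presumably keeps sb15
    // free for barrier/fence signaling, as in the sysgemm path above.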
19704 | |
19705 | auto &Ai_regs = state.Ai_regs; |
19706 | auto &Bi_regs = state.Bi_regs; |
19707 | auto &Ao_regs = state.Ao_regs; |
19708 | auto &Bo_regs = state.Bo_regs; |
19709 | auto &Ai_addrs = state.Ai_addrs; |
19710 | auto &Bi_addrs = state.Bi_addrs; |
19711 | auto &Ao_addrs = state.Ao_addrs; |
19712 | auto &Bo_addrs = state.Bo_addrs; |
19713 | GRFRange ldaMultiples, ldbMultiples; |
19714 | FlagRegister flag12; |
19715 | GRF A_swizzle, B_swizzle; |
19716 | Subregister lda16, ldb16, ldaK, ldbK; |
19717 | Subregister Ai_advance, Bi_advance; |
    GRF copyBarrierHeader = state.ra.alloc();
    GRF copyFenceHeader = state.ra.alloc(); // (identifier missing in source; name assumed)
19720 | GRF slmBase = state.ra.alloc().d(); |
19721 | GRF temp = state.ra.alloc().d(); |
19722 | |
19723 | Ai_regs.reserve(globalBuffers); |
19724 | Bi_regs.reserve(globalBuffers); |
19725 | Ai_addrs.reserve(globalBuffers); |
19726 | Bi_addrs.reserve(globalBuffers); |
19727 | for (int i = 0; i < globalBuffers; i++) { |
19728 | Ai_regs.push_back(state.ra.alloc_range(aiRegCount)); |
19729 | Bi_regs.push_back(state.ra.alloc_range(biRegCount)); |
19730 | Ai_addrs.push_back(state.ra.alloc_range(aLoadAddrs)); |
19731 | Bi_addrs.push_back(state.ra.alloc_range(bLoadAddrs)); |
19732 | } |
19733 | |
19734 | if (aRepack) Ao_regs = state.ra.alloc_range(8); |
19735 | if (bRepack) Bo_regs = state.ra.alloc_range(unrollN / 4); |
19736 | Ao_addrs.push_back(state.ra.alloc_range(aStoreAddrs)); |
19737 | Bo_addrs.push_back(state.ra.alloc_range(bStoreAddrs)); |
19738 | |
19739 | if (state.emulate.temp[0].isValid()) { |
19740 | for (int q = 0; q < 2; q++) |
19741 | state.ra.safeRelease(state.emulate.temp[q]); |
19742 | state.emulate.flag = f1[1]; |
19743 | state.emulate.flagOffset = 8; |
19744 | } |
19745 | |
19746 | // Address initialization. |
19747 | if (surfaceA && isPacked(problem.A.layout)) |
19748 | shr(1, A_copyLoadAddrSurf0, A_copyLoadAddrSurf0, 4); |
19749 | if (surfaceB && isPacked(problem.B.layout)) |
19750 | shr(1, B_copyLoadAddrSurf0, B_copyLoadAddrSurf0, 4); |
19751 | |
19752 | mov(1, slmBase[0], 0); |
19753 | mov(1, slmBase[1], -4 * slmStride); |
19754 | |
    // Build a table of multiples {0, ld, 2*ld, ..., (n-1)*ld} used to form
    // per-row address vectors.
    auto makeLDMultiples = [&](GRFRange &multiples, const Subregister &ld,
                                       int n) {
19757 | multiples = state.ra.alloc_range(n / 8); |
19758 | mov<uint16_t>(8, multiples[0], Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7)); |
19759 | if (n > 8) |
19760 | mov<uint16_t>(8, multiples[1], |
19761 | Immediate::uv(8, 9, 10, 11, 12, 13, 14, 15)); |
19762 | mul<uint32_t>(8, multiples[0], ld, multiples[0].uw()); |
19763 | if (n > 8) mul<uint32_t>(8, multiples[1], ld, multiples[1].uw()); |
19764 | }; |
19765 | |
19766 | switch (problem.A.layout) { |
19767 | case MatrixLayout::N: |
19768 | lda16 = state.ra.alloc_sub<uint32_t>(); |
19769 | Ai_advance = state.ra.alloc_sub<uint32_t>(); |
19770 | if (!int8) { |
19771 | A_swizzle = state.ra.alloc().uw(); |
19772 | mov(4, A_swizzle[0](1), Immediate::uv(0, 4, 2, 6, 0, 0, 0, 0)); |
19773 | add(4, A_swizzle[4](1), A_swizzle[0](1), 64); |
19774 | add(8, A_swizzle[8](1), A_swizzle[0](1), 128); |
19775 | } |
19776 | makeLDMultiples(ldaMultiples, lda, 16); |
19777 | mulConstant(1, lda16, lda, 16); |
19778 | if (int8) { |
19779 | ldaK = state.ra.alloc_sub<uint32_t>(); |
19780 | mulConstant(1, ldaK, lda, 32); |
19781 | } else |
19782 | ldaK = lda16; |
19783 | mulConstant(1, Ai_advance, lda, (int8 ? 32 : 16) * globalBuffers); |
19784 | break; |
19785 | case MatrixLayout::T: makeLDMultiples(ldaMultiples, lda, 8); break; |
19786 | default: break; |
19787 | } |
19788 | |
19789 | switch (problem.B.layout) { |
19790 | case MatrixLayout::N: |
19791 | B_swizzle = state.ra.alloc().uw(); |
19792 | mov(8, B_swizzle[0](1), Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7)); |
19793 | if (x48) { |
19794 | flag12 = state.raVFlag.alloc(); |
19795 | mov(1, flag12, 0x0FFF); |
19796 | } |
19797 | mad(8, B_swizzle[8](1), 4, B_swizzle[0](1), x48 ? 64 : 32); |
19798 | mulConstant(8, B_swizzle[0](1), B_swizzle[0](1), x48 ? 64 : 32); |
19799 | makeLDMultiples(ldbMultiples, ldb, x48 ? 16 : 8); |
19800 | break; |
19801 | case MatrixLayout::T: |
19802 | makeLDMultiples(ldbMultiples, ldb, 16); |
19803 | if (int8) { |
19804 | ldb16 = state.ra.alloc_sub<uint32_t>(); |
19805 | mulConstant(1, ldb16, ldb, 16); |
19806 | } |
19807 | Bi_advance = state.ra.alloc_sub<uint32_t>(); |
19808 | ldbK = state.ra.alloc_sub<uint32_t>(); |
19809 | mulConstant(1, Bi_advance, ldb, (int8 ? 32 : 16) * globalBuffers); |
            mulConstant(1, ldbK, ldb, int8 ? 32 : 16);
            break;
        default: break;
19812 | } |
19813 | |
19814 | for (int i = 0; i < globalBuffers; i++) |
19815 | switch (problem.A.layout) { |
19816 | case MatrixLayout::Pc: |
19817 | if (surfaceA) { |
19818 | add(1, Ai_addrs[i][0].ud(2), A_copyLoadAddrSurf0, |
19819 | i * unrollM * 32 / 16); |
19820 | if (_32x) |
19821 | add(1, Ai_addrs[i][1].ud(2), A_copyLoadAddrSurf0, |
19822 | (i * unrollM * 32 + 4 * 32) / 16); |
19823 | } else |
19824 | eadd(1, Ai_addrs[i][0].uq(0), A_copyLoadAddr0, |
19825 | i * unrollM * 32, strategy, state); |
19826 | break; |
19827 | case MatrixLayout::N: |
19828 | if (i == 0) { |
19829 | add<uint32_t>(16, Ai_addrs[i][0], A_copyLoadAddrSurf0, |
19830 | ldaMultiples); |
19831 | if (int8) |
19832 | add3<uint32_t>(16, Ai_addrs[i][2], A_copyLoadAddrSurf0, |
19833 | ldaMultiples, lda16); |
19834 | } else { |
19835 | add<uint32_t>(16, Ai_addrs[i][0], Ai_addrs[i - 1][0], ldaK); |
19836 | if (int8) |
19837 | add<uint32_t>( |
19838 | 16, Ai_addrs[i][2], Ai_addrs[i - 1][2], ldaK); |
19839 | } |
19840 | break; |
19841 | case MatrixLayout::T: |
19842 | add3<uint32_t>(8, Ai_addrs[i][0], A_copyLoadAddrSurf0, |
19843 | ldaMultiples, i * 32 + 0); |
19844 | add3<uint32_t>(8, Ai_addrs[i][1], A_copyLoadAddrSurf0, |
19845 | ldaMultiples, i * 32 + 16); |
19846 | break; |
19847 | default: stub(); |
19848 | } |
19849 | |
19850 | for (int i = 0; i < globalBuffers; i++) |
19851 | switch (problem.B.layout) { |
19852 | case MatrixLayout::Pr: |
19853 | if (surfaceB) { |
19854 | add(1, Bi_addrs[i][0].ud(2), B_copyLoadAddrSurf0, |
19855 | i * unrollN * 32 / 16); |
19856 | add(1, Bi_addrs[i][1].ud(2), B_copyLoadAddrSurf0, |
19857 | (i * unrollN * 32 + 4 * 32) / 16); |
19858 | if (x48) |
19859 | add(1, Bi_addrs[i][2].ud(2), B_copyLoadAddrSurf0, |
19860 | (i * unrollN * 32 + 8 * 32) / 16); |
19861 | } else { |
19862 | eadd(1, Bi_addrs[i][0].uq(0), B_copyLoadAddr0, |
19863 | i * unrollN * 32, strategy, state); |
19864 | if (x48) |
19865 | eadd(1, Bi_addrs[i][1].uq(0), B_copyLoadAddr0, |
19866 | i * unrollN * 32 + 8 * 32, strategy, state); |
19867 | } |
19868 | break; |
19869 | case MatrixLayout::N: |
19870 | add3<uint32_t>(x48 ? 16 : 8, Bi_addrs[i][0], |
19871 | B_copyLoadAddrSurf0, ldbMultiples, i * 32 + 0); |
19872 | add3<uint32_t>(x48 ? 16 : 8, Bi_addrs[i][x48 ? 2 : 1], |
19873 | B_copyLoadAddrSurf0, ldbMultiples, i * 32 + 16); |
19874 | break; |
19875 | case MatrixLayout::T: |
19876 | if (i == 0) { |
19877 | add<uint32_t>(16, Bi_addrs[i][0], B_copyLoadAddrSurf0, |
19878 | ldbMultiples); |
19879 | if (int8) |
19880 | add3<uint32_t>(16, Bi_addrs[i][2], B_copyLoadAddrSurf0, |
19881 | ldbMultiples, ldb16); |
19882 | else if (x48) |
19883 | add3<uint32_t>(16, Bi_addrs[i][2], B_copyLoadAddrSurf0, |
19884 | ldbMultiples, 16); |
19885 | } else { |
19886 | add<uint32_t>(16, Bi_addrs[i][0], Bi_addrs[i - 1][0], ldbK); |
19887 | if (int8 || x48) |
19888 | add<uint32_t>( |
19889 | 16, Bi_addrs[i][2], Bi_addrs[i - 1][2], ldbK); |
19890 | } |
19891 | break; |
19892 | default: stub(); |
19893 | } |
19894 | |
19895 | sysgemmBarrierPrep(InstructionModifier(), copyBarrierHeader); |
19896 | |
19897 | mov(2, slmBase[4](1), slmBase[0](1)); |
19898 | |
19899 | // Main logic. |
19900 | auto copyLoad = [&](int buffer) { |
19901 | int atbase = tokenStride * buffer; |
19902 | int btbase = atbase + aTokens; |
19903 | switch (problem.A.layout) { |
19904 | case MatrixLayout::Pc: |
19905 | if (surfaceA) { |
19906 | load(16 | SBID(atbase + 0), Ai_regs[buffer][0], |
19907 | block_oword(8), strategy.A.base, |
19908 | Ai_addrs[buffer][0]); |
19909 | if (_32x) |
19910 | load(16 | SBID(atbase + 1), Ai_regs[buffer][4], |
19911 | block_oword(8), strategy.A.base, |
19912 | Ai_addrs[buffer][1]); |
19913 | } else |
19914 | load(16 | SBID(atbase + 0), Ai_regs[buffer][0], |
19915 | block_hword(_32x ? 8 : 4), strategy.A.base, |
19916 | Ai_addrs[buffer]); |
19917 | break; |
19918 | case MatrixLayout::N: |
19919 | if (int8) { |
19920 | load(16 | SBID(atbase + 0), Ai_regs[buffer][0], |
19921 | surface_dword(ChannelMask::rg), strategy.A.base, |
19922 | Ai_addrs[buffer][0]); |
19923 | load(16 | SBID(atbase + 1), Ai_regs[buffer][4], |
19924 | surface_dword(ChannelMask::rg), strategy.A.base, |
19925 | Ai_addrs[buffer][2]); |
19926 | } else |
19927 | load(16 | SBID(atbase + 0), Ai_regs[buffer][0], |
19928 | surface_dword(ChannelMask::rgba), strategy.A.base, |
19929 | Ai_addrs[buffer][0]); |
19930 | break; |
19931 | case MatrixLayout::T: |
19932 | load(8 | SBID(atbase + 0), Ai_regs[buffer][0], |
19933 | surface_dword(ChannelMask::rgba), strategy.A.base, |
19934 | Ai_addrs[buffer][0]); |
19935 | load(8 | SBID(atbase + 1), Ai_regs[buffer][4], |
19936 | surface_dword(ChannelMask::rgba), strategy.A.base, |
19937 | Ai_addrs[buffer][1]); |
19938 | break; |
19939 | default: stub(); |
19940 | } |
19941 | |
19942 | switch (problem.B.layout) { |
19943 | case MatrixLayout::Pr: |
19944 | if (surfaceB) { |
19945 | load(16 | SBID(btbase + 0), Bi_regs[buffer][0], |
19946 | block_oword(8), strategy.B.base, |
19947 | Bi_addrs[buffer][0]); |
19948 | load(16 | SBID(btbase + 1), Bi_regs[buffer][4], |
19949 | block_oword(8), strategy.B.base, |
19950 | Bi_addrs[buffer][1]); |
19951 | if (x48) |
19952 | load(16 | SBID(btbase + 2), Bi_regs[buffer][8], |
19953 | block_oword(8), strategy.B.base, |
19954 | Bi_addrs[buffer][2]); |
19955 | } else { |
19956 | load(16 | SBID(btbase + 0), Bi_regs[buffer][0], |
19957 | block_hword(8), strategy.B.base, |
19958 | Bi_addrs[buffer][0]); |
19959 | if (x48) |
19960 | load(16 | SBID(btbase + 1), Bi_regs[buffer][8], |
19961 | block_hword(4), strategy.B.base, |
19962 | Bi_addrs[buffer][1]); |
19963 | } |
19964 | break; |
19965 | case MatrixLayout::N: |
19966 | if (x48) { |
19967 | load(16 | SBID(btbase + 0) | flag12 | any4h, |
19968 | Bi_regs[buffer][0], |
19969 | surface_dword(ChannelMask::rgba), strategy.B.base, |
19970 | Bi_addrs[buffer][0]); |
19971 | load(16 | SBID(btbase + 1) | flag12 | any4h, |
19972 | Bi_regs[buffer][8], |
19973 | surface_dword(ChannelMask::rgba), strategy.B.base, |
19974 | Bi_addrs[buffer][2]); |
19975 | } else { |
19976 | load(8 | SBID(btbase + 0), Bi_regs[buffer][0], |
19977 | surface_dword(ChannelMask::rgba), strategy.B.base, |
19978 | Bi_addrs[buffer][0]); |
19979 | load(8 | SBID(btbase + 1), Bi_regs[buffer][4], |
19980 | surface_dword(ChannelMask::rgba), strategy.B.base, |
19981 | Bi_addrs[buffer][1]); |
19982 | } |
19983 | break; |
19984 | case MatrixLayout::T: |
19985 | if (int8) { |
19986 | auto cmask = x48 ? ChannelMask::rgb : ChannelMask::rg; |
19987 | load(16 | SBID(btbase + 0), Bi_regs[buffer][0], |
19988 | surface_dword(cmask), strategy.B.base, |
19989 | Bi_addrs[buffer][0]); |
19990 | load(16 | SBID(btbase + 1), Bi_regs[buffer][x48 ? 6 : 4], |
19991 | surface_dword(cmask), strategy.B.base, |
19992 | Bi_addrs[buffer][2]); |
19993 | } else { |
19994 | load(16 | SBID(btbase + 0), Bi_regs[buffer][0], |
19995 | surface_dword(ChannelMask::rgba), strategy.B.base, |
19996 | Bi_addrs[buffer][0]); |
19997 | if (x48) |
19998 | load(16 | SBID(btbase + 1), Bi_regs[buffer][8], |
19999 | surface_dword(ChannelMask::rg), strategy.B.base, |
20000 | Bi_addrs[buffer][2]); |
20001 | } |
20002 | break; |
20003 | default: stub(); |
20004 | } |
20005 | }; |
20006 | |
20007 | auto copyRepack = [&](int buffer) { |
20008 | int atbase = tokenStride * buffer; |
20009 | int btbase = atbase + aTokens; |
20010 | |
20011 | switch (problem.A.layout) { |
20012 | case MatrixLayout::N: |
20013 | if (int8) { |
20014 | for (int j = 0; j < 4; j++) { |
20015 | int reg = (j >> 1); |
20016 | int sub = (j & 1) << 4; |
20017 | mov<uint8_t>(16, Ao_regs[j + 0][0](1), |
20018 | Ai_regs[buffer][reg + 0][sub](1, 4, 4)); |
20019 | mov<uint8_t>(16, Ao_regs[j + 4][0](1), |
20020 | Ai_regs[buffer][reg + 4][sub](1, 4, 4)); |
20021 | mov<uint8_t>(16, Ao_regs[j + 0][16](1), |
20022 | Ai_regs[buffer][reg + 2][sub](1, 4, 4)); |
20023 | mov<uint8_t>(16, Ao_regs[j + 4][16](1), |
20024 | Ai_regs[buffer][reg + 6][sub](1, 4, 4)); |
20025 | } |
20026 | } else { |
20027 | // a0: 0 4 2 6 64 68 66 70... |
20028 | add(16, a0, A_swizzle, Ai_regs[buffer][0].getBase() * 32); |
20029 | setDefaultAutoSWSB(false); |
20030 | sync.allwr(0b1 << atbase); |
20031 | for (int j = 0; j < 8; j++) |
20032 | mov<uint16_t>( |
20033 | 16, Ao_regs[j], indirect[a0].uw(j * 8)(1, 0)); |
20034 | setDefaultAutoSWSB(true); |
20035 | } |
20036 | break; |
20037 | default: break; |
20038 | } |
20039 | |
20040 | switch (problem.B.layout) { |
20041 | case MatrixLayout::N: |
20042 | // a0 (x32): 0 32 64 96 128 160 192 228 4 36 68... |
20043 | // a0 (x48): 0 64 128 192 256 320 384 448 4 68 132 196... |
20044 | add(16, a0, B_swizzle, Bi_regs[buffer][0].getBase() * 32); |
20045 | setDefaultAutoSWSB(false); |
20046 | sync.allwr(0b11 << btbase); |
20047 | for (int j = 0; j < unrollN / 4; j += 2) // 2 cols at a time |
20048 | mov<uint32_t>(16, Bo_regs[j], indirect[a0].ud(j * 4)(1, 0)); |
20049 | setDefaultAutoSWSB(true); |
20050 | break; |
20051 | case MatrixLayout::T: |
20052 | if (int8) { |
20053 | for (int j = 0; j < unrollN / 4; j++) |
20054 | mov<uint8_t>(16, Bo_regs[j][0](1), |
20055 | Bi_regs[buffer][(j & ~3) >> 1][j & 3](4)); |
20056 | for (int j = 0; j < unrollN / 4; j++) |
20057 | mov<uint8_t>(16, Bo_regs[j][16](1), |
20058 | Bi_regs[buffer][(x48 ? 6 : 4) + ((j & ~3) >> 1)] |
20059 | [j & 3](4)); |
20060 | } else { |
20061 | for (int j = 0; j < unrollN / 4; j++) |
20062 | mov<uint16_t>(16, Bo_regs[j], |
20063 | Bi_regs[buffer][j & ~1][j & 1](2)); |
20064 | } |
20065 | break; |
20066 | default: break; |
20067 | } |
20068 | }; |
20069 | |
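    // copyStore: repack one buffer's A/B data as needed, then write it to SLM
    // as block stores, reusing the buffer's SBID tokens (atbase/btbase) to
    // track completion.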
20070 | auto copyStore = [&](int buffer) { |
20071 | int atbase = tokenStride * buffer; |
20072 | int btbase = atbase + aTokens; |
20073 | |
20074 | auto A_regs = aRepack ? Ao_regs : Ai_regs[buffer]; |
20075 | auto B_regs = bRepack ? Bo_regs : Bi_regs[buffer]; |
20076 | |
20077 | copyRepack(buffer); |
20078 | |
20079 | auto b1 = (surfaceB && isPacked(problem.B.layout)) ? 2 : 1; |
20080 | |
20081 | store(16 | SBID(atbase + 0), block_oword(_32x ? 16 : 8), SLM, |
20082 | Ao_addrs[0][0], A_regs[0]); |
20083 | store(16 | SBID(btbase + 0), block_oword(16), SLM, Bo_addrs[0][0], |
20084 | B_regs[0]); |
20085 | if (x48) |
20086 | store(16 | SBID(btbase + b1), block_oword(8), SLM, Bo_addrs[0][1], |
20087 | B_regs[8]); |
20088 | }; |
20089 | |
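    // advanceLoad: step this buffer's global addresses forward by one full
    // pipeline round (globalBuffers k-blocks). Surface addresses update the
    // 32-bit offset in header dword 2; stateless (A64) addresses need a
    // 64-bit eadd.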
20090 | auto advanceLoad = [&](int buffer) { |
20091 | switch (problem.A.layout) { |
20092 | case MatrixLayout::Pc: |
20093 | if (surfaceA) { |
20094 | add(1, Ai_addrs[buffer][0].ud(2), Ai_addrs[buffer][0].ud(2), |
20095 | globalBuffers * 32 * unrollM / 16); |
20096 | if (_32x) |
20097 | add(1, Ai_addrs[buffer][1].ud(2), |
20098 | Ai_addrs[buffer][1].ud(2), |
20099 | globalBuffers * 32 * unrollM / 16); |
20100 | } else |
20101 | eadd(1, Ai_addrs[buffer][0].uq(0), |
20102 | Ai_addrs[buffer][0].uq(0), |
20103 | globalBuffers * 32 * unrollM, strategy, state); |
20104 | break; |
20105 | case MatrixLayout::N: |
20106 | add<uint32_t>(16, Ai_addrs[buffer][0], Ai_addrs[buffer][0], |
20107 | Ai_advance); |
20108 | if (int8) |
20109 | add<uint32_t>(16, Ai_addrs[buffer][2], Ai_addrs[buffer][2], |
20110 | Ai_advance); |
20111 | break; |
20112 | case MatrixLayout::T: |
20113 | add<uint32_t>(8, Ai_addrs[buffer][0], Ai_addrs[buffer][0], |
20114 | 32 * globalBuffers); |
20115 | add<uint32_t>(8, Ai_addrs[buffer][1], Ai_addrs[buffer][1], |
20116 | 32 * globalBuffers); |
20117 | break; |
20118 | default: stub(); |
20119 | } |
20120 | |
20121 | switch (problem.B.layout) { |
20122 | case MatrixLayout::Pr: |
20123 | if (surfaceB) { |
20124 | add(1, Bi_addrs[buffer][0].ud(2), Bi_addrs[buffer][0].ud(2), |
20125 | globalBuffers * 32 * unrollN / 16); |
20126 | add(1, Bi_addrs[buffer][1].ud(2), Bi_addrs[buffer][1].ud(2), |
20127 | globalBuffers * 32 * unrollN / 16); |
20128 | if (x48) |
20129 | add(1, Bi_addrs[buffer][2].ud(2), |
20130 | Bi_addrs[buffer][2].ud(2), |
20131 | globalBuffers * 32 * unrollN / 16); |
20132 | } else { |
20133 | eadd(1, Bi_addrs[buffer][0].uq(0), |
20134 | Bi_addrs[buffer][0].uq(0), |
20135 | globalBuffers * 32 * unrollN, strategy, state); |
20136 | if (x48) |
20137 | eadd(1, Bi_addrs[buffer][1].uq(0), |
20138 | Bi_addrs[buffer][1].uq(0), |
20139 | globalBuffers * 32 * unrollN, strategy, state); |
20140 | } |
20141 | break; |
20142 | case MatrixLayout::N: |
20143 | add<uint32_t>(16, Bi_addrs[buffer][0], Bi_addrs[buffer][0], |
20144 | 32 * globalBuffers); |
20145 | if (x48) |
20146 | add<uint32_t>(16, Bi_addrs[buffer][2], Bi_addrs[buffer][2], |
20147 | 32 * globalBuffers); |
20148 | break; |
20149 | case MatrixLayout::T: |
20150 | add<uint32_t>(16, Bi_addrs[buffer][0], Bi_addrs[buffer][0], |
20151 | Bi_advance); |
20152 | if (int8 || x48) |
20153 | add<uint32_t>(16, Bi_addrs[buffer][2], Bi_addrs[buffer][2], |
20154 | Bi_advance); |
20155 | break; |
20156 | default: stub(); |
20157 | } |
20158 | }; |
20159 | |
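    // advanceStore: compute SLM store addresses for the next buffer and
    // advance slmBase, using csel to wrap around the circular SLM region.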
20160 | auto advanceStore = [&](int buffer = -1) { |
20161 | add(2, temp, slmBase, slmStride); |
20162 | add(1, Ao_addrs[0][0].ud(2), slmBase, slmAOff); |
20163 | add(1, Bo_addrs[0][0].ud(2), slmBase, slmBOff[0]); |
20164 | if (x48) add(1, Bo_addrs[0][1].ud(2), slmBase, slmBOff[1]); |
20165 | |
20166 | csel(2 | ge | f0[0], slmBase, slmBase[4](1, 1, 0), temp[0](1, 1, 0), |
20167 | temp[1]); |
20168 | }; |
20169 | |
20170 | auto fence = [&]() { slmfence(sb15, copyFenceHeader, copyFenceHeader); }; |
20171 | |
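    // Split barrier: wait on the signal from the previous phase, then
    // immediately re-signal, keeping copy and compute threads one phase apart.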
20172 | auto splitBarrier = [&]() { |
20173 | barrierwait(); |
20174 | barriermsg(sb15, copyBarrierHeader); |
20175 | }; |
20176 | |
20177 | // Warmup. |
20178 | if (globalBuffers > 1) cmp(1 | gt | f0[0], kCounter, 1); |
20179 | if (globalBuffers > 2) cmp(1 | gt | f0[1], kCounter, 2); |
20180 | if (globalBuffers > 3) cmp(1 | gt | f1[0], kCounter, 3); |
20181 | if (globalBuffers > 4) cmp(1 | gt | f1[1], kCounter, 4); |
20182 | if (globalBuffers > 1) { |
20183 | copyLoad(0); |
20184 | jmpi(1 | ~f0[0], smallK); |
20185 | } |
20186 | if (globalBuffers > 2) { |
20187 | copyLoad(1); |
20188 | jmpi(1 | ~f0[1], smallK); |
20189 | } |
20190 | if (globalBuffers > 3) { |
20191 | copyLoad(2); |
20192 | jmpi(1 | ~f1[0], smallK); |
20193 | } |
20194 | if (globalBuffers > 4) { |
20195 | copyLoad(3); |
20196 | jmpi(1 | ~f1[1], smallK); |
20197 | } |
20198 | |
20199 | auto flagLast = FlagRegister::createFromIndex(globalBuffers - 2); |
20200 | cmp(1 | le | flagLast, kCounter, globalBuffers); |
20201 | copyLoad(globalBuffers - 1); |
20202 | jmpi(1 | flagLast, smallK); |
20203 | |
20204 | add(1 | gt | f0[0], kCounter, kCounter, -2 * globalBuffers); |
20205 | |
20206 | advanceStore(); |
20207 | advanceLoad(0); |
20208 | if (globalBuffers > 1) advanceLoad(1); |
20209 | if (globalBuffers > 2) advanceLoad(2); |
20210 | if (globalBuffers > 3) advanceLoad(3); |
20211 | |
20212 | copyStore(0); |
20213 | if (globalBuffers > 4) advanceLoad(4); |
20214 | |
20215 | fence(); |
20216 | advanceStore(0); |
20217 | copyLoad(0); |
20218 | barriermsg(sb15, copyBarrierHeader); |
20219 | copyStore(1); |
20220 | |
20221 | advanceLoad(0); |
20222 | |
20223 | jmpi(1 | ~f0[0], bottom); |
20224 | sync.nop(SWSB<AllPipes>(1)); |
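    // Steady-state copy loop: each trip handles globalBuffers k-blocks,
    // overlapping the global load of one buffer with the SLM store of the
    // next, separated by a fence and split barrier.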
20225 | mark(top); |
20226 | { |
20227 | add(1 | gt | f0[0], kCounter, kCounter, -globalBuffers); |
20228 | for (int i = 0; i < globalBuffers; i++) { |
20229 | fence(); |
20230 | advanceStore((i + 1) % globalBuffers); |
20231 | copyLoad((i + 1) % globalBuffers); // move after barrier? |
20232 | splitBarrier(); |
20233 | copyStore((i + 2) % globalBuffers); |
20234 | advanceLoad((i + 1) % globalBuffers); |
20235 | } |
20236 | } |
20237 | jmpi(1 | f0[0], top); |
20238 | mark(bottom); |
20239 | |
20240 | if (globalBuffers > 1) cmp(1 | gt | f0[0], kCounter, -globalBuffers + 1); |
20241 | if (globalBuffers > 2) cmp(1 | gt | f0[1], kCounter, -globalBuffers + 2); |
20242 | if (globalBuffers > 3) cmp(1 | gt | f1[0], kCounter, -globalBuffers + 3); |
20243 | if (globalBuffers > 4) cmp(1 | gt | f1[1], kCounter, -globalBuffers + 4); |
20244 | |
20245 | // Cooldown loop. All buffers but #1 loaded. |
20246 | for (int i = 1; i < globalBuffers; i++) { |
20247 | Label skipLoad; |
20248 | fence(); |
20249 | advanceStore(i); |
20250 | jmpi(1 | ~FlagRegister::createFromIndex(i - 1), skipLoad); |
20251 | copyLoad(i); |
20252 | mark(skipLoad); |
20253 | splitBarrier(); |
20254 | copyStore((i + 1) % globalBuffers); |
20255 | advanceLoad(i); |
20256 | } |
20257 | |
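    // Small-k path: fewer k-blocks than pipeline buffers; store what was
    // loaded during warmup and rejoin the predicated drain sequence below.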
20258 | jmpi(1, skipSmallK); |
20259 | mark(smallK); |
20260 | |
20261 | advanceStore(); |
20262 | copyStore(0); |
20263 | |
20264 | fence(); |
20265 | advanceStore(0); |
20266 | barriermsg(sb15, copyBarrierHeader); |
20267 | jmpi(1, reenterSmallK); |
20268 | |
20269 | mark(skipSmallK); |
20270 | |
20271 | for (int i = 0; i < globalBuffers - 1; i++) { |
20272 | fence(); |
20273 | advanceStore(i); |
20274 | splitBarrier(); |
20275 | if (i == 0) mark(reenterSmallK); |
20276 | jmpi(1 | ~FlagRegister::createFromIndex(i), done); |
20277 | copyStore(i + 1); |
20278 | } |
20279 | |
20280 | fence(); |
20281 | splitBarrier(); |
20282 | |
20283 | mark(done); |
20284 | barrierwait(); |
20285 | |
20286 | state.ra = saveRA; |
20287 | } |
20288 | |
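// Compute-side k-loop for the split copy/compute systolic kernel: set up SLM
// read addresses, exchange warmup barrier signals with the copy side, run a
// 4x-unrolled main loop of sysgemm2Multiply, then drain the remaining
// iterations under flag predication.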
20289 | template <HW hw> |
20290 | void gemm_kernel_generator_t<hw>::sysgemm2KLoopCompute( |
20291 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
20292 | GEMMState &state) { |
20293 | using namespace sysgemm2; |
20294 | |
20295 | Label top, remainder, done; |
20296 | bool _32x = (strategy.unroll[LoopM] == 32); |
20297 | bool x48 = (strategy.unroll[LoopN] == 48); |
20298 | bool keepBarHdr = strategy.skipFence || !x48; |
20299 | auto slmStride = strategy.slmSysgemmBlockSize() / 16; |
    auto barrierHeader = x32::barrierHeader;
20301 | |
20302 | mov(1, f0.ud(0), 0); |
20303 | mov(1, f1.ud(0), 0); |
20304 | if (x48) { |
20305 | using namespace x48; |
20306 | add(1, A_addr[1], A_addr[0], 8 * 32 / 16); |
20307 | if (_32x) { |
20308 | add(1, A_addr[2], A_addr[0], 16 * 32 / 16); |
20309 | add(1, A_addr[3], A_addr[0], 24 * 32 / 16); |
20310 | } |
20311 | add(1, B_addr[1], B_addr[0], 8 * 32 / 16); |
20312 | add(1, B_addr[2], B_addr[0], 16 * 32 / 16); |
20313 | } else { |
20314 | using namespace x32; |
20315 | add(1, A_addr[0][1], A_addr[0][0], 8 * 32 / 16); |
20316 | if (_32x) { |
20317 | add(1, A_addr[0][2], A_addr[0][0], 16 * 32 / 16); |
20318 | add(1, A_addr[0][3], A_addr[0][0], 24 * 32 / 16); |
20319 | } |
20320 | add(1, A_addr[1][0], A_addr[0][0], 0 * 32 / 16 + slmStride); |
20321 | add(1, A_addr[1][1], A_addr[0][0], 8 * 32 / 16 + slmStride); |
20322 | if (_32x) { |
20323 | add(1, A_addr[1][2], A_addr[0][0], 16 * 32 / 16 + slmStride); |
20324 | add(1, A_addr[1][3], A_addr[0][0], 24 * 32 / 16 + slmStride); |
20325 | } |
20326 | add(1, B_addr[0][1], B_addr[0][0], 8 * 32 / 16); |
20327 | add(1, B_addr[1][0], B_addr[0][0], 0 * 32 / 16 + slmStride); |
20328 | add(1, B_addr[1][1], B_addr[0][0], 8 * 32 / 16 + slmStride); |
20329 | } |
20330 | |
20331 | if (keepBarHdr) sysgemmBarrierPrep(InstructionModifier(), barrierHeader); |
20332 | |
20333 | // Warmup: signal, split barrier, load |
20334 | cmp(1 | gt | f1[1], kCounter, 1); |
20335 | add(1 | gt | f0[0], kCounter, kCounter, -5); |
20336 | |
20337 | if (!keepBarHdr) sysgemmBarrierPrep(InstructionModifier(), barrierHeader); |
20338 | barriermsg(sb15, barrierHeader); |
20339 | barrierwait(); |
20340 | barriermsg(sb15 | f1[1], barrierHeader); |
20341 | |
20342 | bool oldDefaultAutoSWSB = getDefaultAutoSWSB(); |
20343 | setDefaultAutoSWSB(false); |
20344 | sync.nop(SWSB<AllPipes>(1)); |
20345 | |
20346 | load(16 | sb0, x48::A_regs[0], block_oword(16), SLM, x48::A_addr[0]); |
20347 | load(16 | sb1, x48::A_regs[8], block_oword(16), SLM, x48::A_addr[1]); |
20348 | if (_32x) { |
20349 | load(16 | sb2, x48::A_regs[16], block_oword(16), SLM, x48::A_addr[2]); |
20350 | load(16 | sb3, x48::A_regs[24], block_oword(16), SLM, x48::A_addr[3]); |
20351 | } |
20352 | load(16 | sb4, x48::B_regs[0], block_oword(16), SLM, x48::B_addr[0]); |
20353 | load(16 | sb5, x48::B_regs[8], block_oword(16), SLM, x48::B_addr[1]); |
20354 | if (x48) |
20355 | load(16 | sb6, x48::B_regs[16], block_oword(16), SLM, x48::B_addr[2]); |
20356 | |
20357 | zeroMatrix(x48 ? x48::C_regs : x32::C_regs, strategy); |
20358 | |
20359 | jmpi(1 | ~f0[0], remainder); |
20360 | |
20361 | mark(top); |
20362 | { |
20363 | add(1 | gt | f0[0], kCounter, kCounter, -4); |
20364 | sysgemm2Multiply(problem, strategy, state, 0); |
20365 | sysgemm2Multiply(problem, strategy, state, 1); |
20366 | sysgemm2Multiply(problem, strategy, state, 2); |
20367 | sysgemm2Multiply(problem, strategy, state, 3); |
20368 | } |
20369 | jmpi(1 | f0[0], top); |
20370 | |
20371 | mark(remainder); |
20372 | |
20373 | cmp(1 | gt | f0[0], kCounter, 1 - 5); |
20374 | cmp(1 | gt | f0[1], kCounter, 2 - 5); |
20375 | cmp(1 | gt | f1[0], kCounter, 3 - 5); |
20376 | cmp(1 | gt | f1[1], kCounter, 4 - 5); |
20377 | sysgemm2Multiply(problem, strategy, state, 0, true, f0[0], f0[1]); |
20378 | jmpi(1 | ~f0[0], done); |
20379 | |
20380 | sysgemm2Multiply(problem, strategy, state, 1, true, f0[1], f1[0]); |
20381 | jmpi(1 | ~f0[1], done); |
20382 | |
20383 | sysgemm2Multiply(problem, strategy, state, 2, true, f1[0], f1[1]); |
20384 | jmpi(1 | ~f1[0], done); |
20385 | |
20386 | sysgemm2Multiply(problem, strategy, state, 3, true, f1[1]); |
20387 | jmpi(1 | ~f1[1], done); |
20388 | |
20389 | sysgemm2Multiply(problem, strategy, state, 0, true); |
20390 | |
20391 | mark(done); |
20392 | |
20393 | setDefaultAutoSWSB(oldDefaultAutoSWSB); |
20394 | } |
20395 | |
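// One k-block multiply from SLM, dispatched on the N unroll (48 vs. 32).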
20396 | template <HW hw> |
20397 | void gemm_kernel_generator_t<hw>::sysgemm2Multiply(const GEMMProblem &problem, |
20398 | const GEMMStrategy &strategy, GEMMState &state, int slmBuffer, |
20399 | bool cooldown, FlagRegister flagWaitLoad, FlagRegister flagSignal) { |
20400 | if (strategy.unroll[LoopN] == 48) |
20401 | sysgemm2MultiplyX48(problem, strategy, state, slmBuffer, cooldown, |
20402 | flagWaitLoad, flagSignal); |
20403 | else |
20404 | sysgemm2MultiplyX32(problem, strategy, state, slmBuffer, cooldown, |
20405 | flagWaitLoad, flagSignal); |
20406 | } |
20407 | |
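// One k-block multiply for the 32x48 tile: four 8-row dpasw chunks,
// interleaved with the split barrier and SLM reloads of the next buffer's
// A/B data (predicated by flagWaitLoad/flagSignal during cooldown).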
20408 | template <HW hw> |
20409 | void gemm_kernel_generator_t<hw>::sysgemm2MultiplyX48( |
20410 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
20411 | GEMMState &state, int slmBuffer, bool cooldown, |
20412 | FlagRegister flagWaitLoad, FlagRegister flagSignal) { |
20413 | using namespace sysgemm2; |
20414 | using namespace sysgemm2::x48; |
20415 | |
20416 | auto slmStride = strategy.slmSysgemmBlockSize() / 16; |
20417 | int16_t advance = ((slmBuffer == 3) ? -3 : 1) * slmStride; |
20418 | InstructionModifier loadMod {}, signalMod {}; |
20419 | bool doWaitLoad = !cooldown || flagWaitLoad.isValid(); |
20420 | bool doSignal = !cooldown || flagSignal.isValid(); |
20421 | if (cooldown) { |
20422 | if (doWaitLoad) loadMod = loadMod | flagWaitLoad | any16h; |
20423 | if (doSignal) signalMod = signalMod | flagSignal | any16h; |
20424 | } |
20425 | |
20426 | if (strategy.unroll[LoopM] != 32) stub(); |
20427 | |
20428 | if (doWaitLoad) { |
20429 | Label skipWait; |
20430 | if (cooldown) jmpi(1 | ~flagWaitLoad, skipWait); |
20431 | |
20432 | // SLM fence |
20433 | if (!strategy.skipFence && doSignal) |
20434 | slmfence(sb15 | signalMod, headerTemp, headerTemp); |
20435 | |
20436 | // Barrier wait |
20437 | barrierwait(); |
20438 | |
20439 | // Barrier signal |
20440 | if (doSignal) { |
20441 | if (!strategy.skipFence) { |
20442 | sysgemmBarrierPrep(sb15.dst | signalMod, headerTemp); |
20443 | barriermsg(sb15 | SWSB<AllPipes>(1) | signalMod, headerTemp); |
20444 | } else |
20445 | barriermsg(sb15 | signalMod, headerTemp); |
20446 | } |
20447 | |
20448 | if (cooldown) mark(skipWait); |
20449 | } |
20450 | |
20451 | // Advance A0 address (note dst) |
20452 | if (doWaitLoad) add(1 | sb0.dst, A_addr[0], A_addr[0], advance); |
20453 | |
20454 | // Rows 0-7 |
20455 | sysgemm2MultiplyChunkX48(problem, strategy, 0); |
20456 | |
20457 | // Advance A1 address |
20458 | if (doWaitLoad) add(1 | sb1.src, A_addr[1], A_addr[1], advance); |
20459 | |
20460 | // Rows 8-15 |
20461 | sysgemm2MultiplyChunkX48(problem, strategy, 1); |
20462 | |
20463 | if (doWaitLoad) { |
20464 | // Load new A0 |
20465 | load(16 | sb0 | loadMod, A_regs[0], block_oword(16), SLM, A_addr[0]); |
20466 | |
20467 | // Advance B, A2, A3 addresses |
20468 | add(1 | sb4.src, B_addr[0], B_addr[0], advance); |
20469 | add(1 | sb5.src, B_addr[1], B_addr[1], advance); |
20470 | add(1 | sb6.src, B_addr[2], B_addr[2], advance); |
20471 | |
20472 | add(1 | sb2.src, A_addr[2], A_addr[2], advance); |
20473 | add(1 | sb3.src, A_addr[3], A_addr[3], advance); |
20474 | } |
20475 | |
20476 | // Rows 16-23 |
20477 | sysgemm2MultiplyChunkX48(problem, strategy, 2); |
20478 | |
20479 | // Load new A1 |
20480 | if (doWaitLoad) |
20481 | load(16 | sb1 | loadMod, A_regs[8], block_oword(16), SLM, A_addr[1]); |
20482 | |
20483 | // Rows 24-31 |
20484 | sysgemm2MultiplyChunkX48(problem, strategy, 3); |
20485 | |
20486 | if (doWaitLoad) { |
20487 | // Load new B data |
20488 | load(16 | sb4 | loadMod, B_regs[0], block_oword(16), SLM, B_addr[0]); |
20489 | load(16 | sb5 | loadMod, B_regs[8], block_oword(16), SLM, B_addr[1]); |
20490 | load(16 | sb6 | loadMod, B_regs[16], block_oword(16), SLM, B_addr[2]); |
20491 | |
20492 | // Load new A2,A3 |
20493 | load(16 | sb2 | loadMod, A_regs[16], block_oword(16), SLM, A_addr[2]); |
20494 | load(16 | sb3 | loadMod, A_regs[24], block_oword(16), SLM, A_addr[3]); |
20495 | } |
20496 | } |
20497 | |
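// One k-block multiply for the Nx32 tile, double-buffered in registers by
// SLM buffer parity (odd/even), with address advances and next-buffer SLM
// loads predicated during cooldown.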
20498 | template <HW hw> |
20499 | void gemm_kernel_generator_t<hw>::sysgemm2MultiplyX32( |
20500 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
20501 | GEMMState &state, int slmBuffer, bool cooldown, |
20502 | FlagRegister flagWaitLoad, FlagRegister flagSignal) { |
20503 | using namespace sysgemm2; |
20504 | using namespace sysgemm2::x32; |
20505 | |
20506 | auto slmStride = strategy.slmSysgemmBlockSize() / 16; |
20507 | int16_t advance = ((slmBuffer >= 2) ? -2 : 2) * slmStride; |
20508 | InstructionModifier loadMod {}, signalMod {}; |
20509 | bool doWaitLoad = !cooldown || flagWaitLoad.isValid(); |
20510 | bool doSignal = !cooldown || flagSignal.isValid(); |
20511 | if (cooldown) { |
20512 | if (doWaitLoad) loadMod = loadMod | flagWaitLoad | any16h; |
20513 | if (doSignal) signalMod = signalMod | flagSignal | any16h; |
20514 | } |
20515 | bool _32x = (strategy.unroll[LoopM] == 32); |
20516 | bool odd = (slmBuffer & 1); |
20517 | int tokenBase = odd ? 8 : 0; |
20518 | int otokenBase = odd ? 0 : 8; |
20519 | |
20520 | if (doWaitLoad) { |
20521 | Label skipWait; |
20522 | if (cooldown) jmpi(1 | ~flagWaitLoad, skipWait); |
20523 | |
20524 | // SLM fence |
20525 | if (!strategy.skipFence && doSignal) |
20526 | slmfence(sb15 | signalMod, fenceHeader, fenceHeader); |
20527 | |
20528 | add(1 | SBID(tokenBase + 0).src, A_addr[odd][0], A_addr[odd][0], |
20529 | advance); // TODO: reuse src0 |
20530 | add(1 | SBID(tokenBase + 4).src, B_addr[odd][0], B_addr[odd][0], |
20531 | advance); |
20532 | add(1 | SBID(tokenBase + 5).src, B_addr[odd][1], B_addr[odd][1], |
20533 | advance); |
20534 | add(1 | SBID(tokenBase + 1).src, A_addr[odd][1], A_addr[odd][1], |
20535 | advance); |
20536 | if (_32x) { |
20537 | add(1 | SBID(tokenBase + 2).src, A_addr[odd][2], A_addr[odd][2], |
20538 | advance); |
20539 | add(1 | SBID(tokenBase + 3).src, A_addr[odd][3], A_addr[odd][3], |
20540 | advance); |
20541 | } |
20542 | |
20543 | // Barrier wait |
20544 | barrierwait(); |
20545 | |
20546 | if (hw >= HW::XeHPG) { |
20547 | // Wait for SLM loads to return before signaling. |
20548 | sync.allwr(0x3F << tokenBase); |
20549 | } |
20550 | |
20551 | // Barrier signal |
20552 | if (doSignal) barriermsg(sb15 | signalMod, barrierHeader); |
20553 | |
20554 | if (cooldown) mark(skipWait); |
20555 | } |
20556 | |
20557 | if (doWaitLoad) { |
20558 | load(16 | SBID(otokenBase + 0) | loadMod, A_regs[!odd][0], |
20559 | block_oword(16), SLM, A_addr[!odd][0]); |
20560 | load(16 | SBID(otokenBase + 4) | loadMod, B_regs[!odd][0], |
20561 | block_oword(16), SLM, B_addr[!odd][0]); |
20562 | load(16 | SBID(otokenBase + 5) | loadMod, B_regs[!odd][8], |
20563 | block_oword(16), SLM, B_addr[!odd][1]); |
20564 | load(16 | SBID(otokenBase + 1) | loadMod, A_regs[!odd][8], |
20565 | block_oword(16), SLM, A_addr[!odd][1]); |
20566 | if (_32x) { |
20567 | load(16 | SBID(otokenBase + 2) | loadMod, A_regs[!odd][16], |
20568 | block_oword(16), SLM, A_addr[!odd][2]); |
20569 | load(16 | SBID(otokenBase + 3) | loadMod, A_regs[!odd][24], |
20570 | block_oword(16), SLM, A_addr[!odd][3]); |
20571 | } |
20572 | } |
20573 | |
20574 | sysgemm2MultiplyChunkX32(problem, strategy, 0, odd); |
20575 | sysgemm2MultiplyChunkX32(problem, strategy, 1, odd); |
20576 | if (_32x) { |
20577 | sysgemm2MultiplyChunkX32(problem, strategy, 2, odd); |
20578 | sysgemm2MultiplyChunkX32(problem, strategy, 3, odd); |
20579 | } |
20580 | } |
20581 | |
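// Multiply an 8-row chunk of A (rows 8*chunkA...8*chunkA+7) by the full
// 48-column B tile. The first chunk (waitB) consumes the B load tokens
// (sb4-sb6); the last chunk (prepB) re-arms them on its dpasw instructions so
// the next B loads can safely overwrite B_regs. {Atomic} chains consecutive
// dpasw instructions into a single systolic sequence.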
20582 | template <HW hw> |
20583 | void gemm_kernel_generator_t<hw>::sysgemm2MultiplyChunkX48( |
20584 | const GEMMProblem &problem, const GEMMStrategy &strategy, int chunkA) { |
20585 | using namespace sysgemm2; |
20586 | using namespace sysgemm2::x48; |
20587 | int ao = chunkA * 8; |
20588 | int co = ao * 6; |
20589 | bool waitB = (chunkA == 0); |
20590 | bool prepB = (chunkA == 3); |
20591 | SBID sbA(chunkA); |
20592 | |
20593 | auto dpaswTyped = [&](InstructionModifier mod, uint8_t sdepth, |
20594 | uint8_t rcount, const GRF &cReg, const GRF &aReg, |
20595 | const GRF &bReg) { |
20596 | dpasw(mod, sdepth, rcount, cReg.retype(problem.Tc.ngen()), |
20597 | cReg.retype(problem.Tc.ngen()), aReg.retype(problem.Ta.ngen()), |
20598 | bReg.retype(problem.Tb.ngen())); |
20599 | }; |
20600 | |
20601 | if (waitB) { |
20602 | /* sync.nop(sbA.dst); */ // arranged by caller |
20603 | dpaswTyped( |
20604 | 8 | sb4.dst | Atomic, 8, 8, C_regs[co], A_regs[ao], B_regs[0]); |
20605 | dpaswTyped(8, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]); |
20606 | dpaswTyped(8 | sb5.dst | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], |
20607 | B_regs[8]); |
20608 | dpaswTyped(8, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]); |
20609 | dpaswTyped(8 | sb6.dst | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], |
20610 | B_regs[16]); |
20611 | dpaswTyped(8 | sbA, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]); |
20612 | } else if (prepB) { |
20613 | dpaswTyped( |
20614 | 8 | sbA.dst | Atomic, 8, 8, C_regs[co], A_regs[ao], B_regs[0]); |
20615 | dpaswTyped(8 | sb4, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]); |
20616 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]); |
20617 | dpaswTyped(8 | sb5, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]); |
20618 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], B_regs[16]); |
20619 | dpaswTyped(8 | sb6, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]); |
20620 | } else { |
20621 | dpaswTyped( |
20622 | 8 | sbA.dst | Atomic, 8, 8, C_regs[co], A_regs[ao], B_regs[0]); |
20623 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]); |
20624 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]); |
20625 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]); |
20626 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], B_regs[16]); |
20627 | dpaswTyped(8 | sbA, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]); |
20628 | } |
20629 | } |
20630 | |
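// 32-column counterpart of sysgemm2MultiplyChunkX48, selecting register
// buffers and SBID token bases by SLM buffer parity.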
20631 | template <HW hw> |
20632 | void gemm_kernel_generator_t<hw>::sysgemm2MultiplyChunkX32( |
20633 | const GEMMProblem &problem, const GEMMStrategy &strategy, int chunkA, |
20634 | bool odd) { |
20635 | using namespace sysgemm2; |
20636 | using namespace sysgemm2::x32; |
20637 | int ao = chunkA * 8; |
20638 | int co = ao * 4; |
20639 | int nchunks = strategy.unroll[LoopM] / 8; |
20640 | bool waitB = (chunkA == 0); |
20641 | bool prepB = (chunkA == nchunks - 1); |
20642 | int tokenBase = odd ? 8 : 0; |
20643 | SBID sbA(tokenBase + chunkA); |
20644 | SBID sbB0(tokenBase + 4); |
20645 | SBID sbB1(tokenBase + 5); |
20646 | |
20647 | auto dpaswTyped = [&](InstructionModifier mod, uint8_t sdepth, |
20648 | uint8_t rcount, const GRF &cReg, const GRF &aReg, |
20649 | const GRF &bReg) { |
20650 | dpasw(mod, sdepth, rcount, cReg.retype(problem.Tc.ngen()), |
20651 | cReg.retype(problem.Tc.ngen()), aReg.retype(problem.Ta.ngen()), |
20652 | bReg.retype(problem.Tb.ngen())); |
20653 | }; |
20654 | |
20655 | if (waitB) { |
20656 | sync.nop(sbA.dst); |
20657 | dpaswTyped(8 | sbB0.dst | Atomic, 8, 8, C_regs[co], A_regs[odd][ao], |
20658 | B_regs[odd][0]); |
20659 | dpaswTyped(8, 8, 8, C_regs[co + 8], A_regs[odd][ao], B_regs[odd][4]); |
20660 | dpaswTyped(8 | sbB1.dst | Atomic, 8, 8, C_regs[co + 16], |
20661 | A_regs[odd][ao], B_regs[odd][8]); |
20662 | dpaswTyped(8 | sbA, 8, 8, C_regs[co + 24], A_regs[odd][ao], |
20663 | B_regs[odd][12]); |
20664 | } else if (prepB) { |
20665 | dpaswTyped(8 | sbA.dst | Atomic, 8, 8, C_regs[co], A_regs[odd][ao], |
20666 | B_regs[odd][0]); |
20667 | dpaswTyped(8 | sbB0, 8, 8, C_regs[co + 8], A_regs[odd][ao], |
20668 | B_regs[odd][4]); |
20669 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 16], A_regs[odd][ao], |
20670 | B_regs[odd][8]); |
20671 | dpaswTyped(8 | sbB1, 8, 8, C_regs[co + 24], A_regs[odd][ao], |
20672 | B_regs[odd][12]); |
20673 | } else { |
20674 | dpaswTyped(8 | sbA.dst | Atomic, 8, 8, C_regs[co], A_regs[odd][ao], |
20675 | B_regs[odd][0]); |
20676 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 8], A_regs[odd][ao], |
20677 | B_regs[odd][4]); |
20678 | dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 16], A_regs[odd][ao], |
20679 | B_regs[odd][8]); |
20680 | dpaswTyped(8 | sbA, 8, 8, C_regs[co + 24], A_regs[odd][ao], |
20681 | B_regs[odd][12]); |
20682 | } |
20683 | } |
20684 | |
20685 | /**********************************************************************/ |
20686 | /* Copy Kernels */ |
20687 | /**********************************************************************/ |
20688 | |
20689 | // Initialize the interface and claim arguments. |
20690 | template <HW hw> |
20691 | void gemm_kernel_generator_t<hw>::copyInitInterface( |
20692 | CopyProblem &problem, CopyStrategy &strategy, CopyState &state) { |
20693 | if (strategy.barrierFreq > 0) interface.requireBarrier(); |
20694 | |
20695 | interface.finalize(); |
20696 | |
20697 | // Get input register assignments. |
    state.inputs.S = interface.getArgumentIfExists("S");
    state.inputs.D = interface.getArgumentIfExists("D");
    state.inputs.surfaceS = interface.getArgumentSurfaceIfExists("S");
    state.inputs.surfaceD = interface.getArgumentSurfaceIfExists("D");
    state.inputs.offsetS = interface.getArgument("offset_S");
    state.inputs.offsetD = interface.getArgument("offset_D");
    state.inputs.lds = interface.getArgument("lds");
    state.inputs.ldd = interface.getArgumentIfExists("ldd");
    state.inputs.m = interface.getArgument("m");
    state.inputs.n = interface.getArgument("n");
    state.inputs.alpha_real = interface.getArgumentIfExists("alpha_real");
    state.inputs.alpha_imag = interface.getArgumentIfExists("alpha_imag");
    state.inputs.diag = interface.getArgumentIfExists("diag");
    state.inputs.blockZ = interface.getArgumentIfExists("block_z");
20712 | |
20713 | state.inputs.localIDW = interface.getLocalID(0); |
20714 | state.inputs.localSizeW = interface.getLocalSize(0); |
20715 | if (strategy.zParallel) { |
20716 | state.inputs.localIDZ = interface.getLocalID(1); |
20717 | state.inputs.localSizeZ = interface.getLocalSize(1); |
20718 | } |
20719 | |
20720 | state.inputs.groupIDW = r0.ud(1); |
20721 | if (strategy.zParallel) state.inputs.groupIDZ = r0.ud(6); |
20722 | |
20723 | // Downgrade offset variables to 32-bit for non-A64 accesses. |
20724 | if (strategy.S.base.getModel() != ModelA64) |
20725 | state.inputs.offsetS = state.inputs.offsetS.d(); |
20726 | if (strategy.D.base.getModel() != ModelA64) |
20727 | state.inputs.offsetD = state.inputs.offsetD.d(); |
20728 | |
20729 | // For now, reinterpret m/n/ld/diag variables to 32-bit if they are 64-bit. |
20730 | state.inputs.m = state.inputs.m.d(); |
20731 | state.inputs.n = state.inputs.n.d(); |
20732 | state.inputs.lds = state.inputs.lds.ud(); |
20733 | if (state.inputs.ldd.isValid()) state.inputs.ldd = state.inputs.ldd.ud(); |
20734 | if (state.inputs.diag.isValid()) state.inputs.diag = state.inputs.diag.d(); |
20735 | |
20736 | // Claim inputs. |
20737 | for (int i = 0; i < 4; i++) |
20738 | state.ra.claim(r0.uq(i)); |
20739 | |
20740 | if (strategy.S.base.isStateless()) state.ra.claim(state.inputs.S); |
20741 | if (strategy.D.base.isStateless()) state.ra.claim(state.inputs.D); |
20742 | |
20743 | state.ra.claim(state.inputs.offsetS); |
20744 | state.ra.claim(state.inputs.offsetD); |
20745 | state.ra.claim(state.inputs.lds); |
20746 | if (state.inputs.ldd.isValid()) state.ra.claim(state.inputs.ldd); |
20747 | state.ra.claim(state.inputs.m); |
20748 | state.ra.claim(state.inputs.n); |
20749 | if (state.inputs.diag.isValid()) state.ra.claim(state.inputs.diag); |
20750 | if (!problem.alpha_real.fixed()) state.ra.claim(state.inputs.alpha_real); |
20751 | if (problem.Td.isComplex() && !problem.alpha_imag.fixed()) |
20752 | state.ra.claim(state.inputs.alpha_imag); |
20753 | |
20754 | state.ra.claim(state.inputs.localIDW); |
20755 | state.ra.claim(state.inputs.localSizeW); |
20756 | if (strategy.zParallel) { |
20757 | state.ra.claim(state.inputs.localIDZ); |
20758 | state.ra.claim(state.inputs.localSizeZ); |
20759 | } |
20760 | |
20761 | if (strategy.zParallel) state.ra.claim(state.inputs.blockZ); |
20762 | } |
20763 | |
20764 | // Initialize the state structure. |
20765 | template <HW hw> |
20766 | void gemm_kernel_generator_t<hw>::copyInitState( |
20767 | CopyProblem &problem, CopyStrategy &strategy, CopyState &state) { |
20768 | if (!state.fusedGEMM.active) { |
20769 | initState(problem, strategy, state); |
20770 | copyInitInterface(problem, strategy, state); |
20771 | state.isNested = false; |
20772 | } |
20773 | |
20774 | state.effS = strategy.S.base.isStateless() ? state.inputs.S |
20775 | : state.inputs.offsetS.d(); |
20776 | state.effD = strategy.D.base.isStateless() ? state.inputs.D |
20777 | : state.inputs.offsetD.d(); |
20778 | |
20779 | if (!problem.alpha_real.fixed()) |
20780 | problem.alpha_real = state.inputs.alpha_real; |
20781 | if (problem.Td.isComplex() && !problem.alpha_imag.fixed()) |
20782 | problem.alpha_imag = state.inputs.alpha_imag; |
20783 | |
20784 | state.flagAP = state.raVFlag.alloc(); |
20785 | |
20786 | state.allocEmulate64Temp(strategy.emulate); |
20787 | } |
20788 | |
20789 | // Copy kernel generation interface. |
20790 | template <HW hw> |
20791 | void gemm_kernel_generator_t<hw>::copy(CopyProblem problem, |
20792 | CopyStrategy strategy, const InterfaceHandler &interface_) { |
20793 | interface = interface_; |
20794 | CopyState state(hw); |
20795 | copy(problem, strategy, state); |
20796 | } |
20797 | |
20798 | template <HW hw> |
20799 | void gemm_kernel_generator_t<hw>::copy( |
20800 | CopyProblem &problem, CopyStrategy &strategy, CopyState &state) { |
20801 | bool inFused = state.fusedGEMM.active; |
20802 | auto unrollW = strategy.unrollW(); |
20803 | |
20804 | // Check layouts. |
20805 | if (!isPacked(problem.D.layout)) stub(); |
20806 | |
20807 | if (strategy.zParallel && problem.sum) stub(); |
20808 | |
20809 | // By default, don't use dispatch mask. |
20810 | setDefaultNoMask(); |
20811 | setDefaultAutoSWSB(); |
20812 | |
20813 | // Set up. |
20814 | copyInitState(problem, strategy, state); |
20815 | |
20816 | if (!strategy.S.base.isStateless()) |
20817 | strategy.S.base.setIndex(state.inputs.surfaceS); |
20818 | if (!strategy.D.base.isStateless()) |
20819 | strategy.D.base.setIndex(state.inputs.surfaceD); |
20820 | |
20821 | // Prologue. |
20822 | if (!inFused) prologue(strategy); |
20823 | |
20824 | // Grab fused ID if needed. |
20825 | getFusedID(unrollW, problem, strategy, state); |
20826 | |
20827 | // Calculate w0, the starting row/column for this thread. |
    // This is the first x (if xLoop = false) or y (xLoop = true) value.
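    // In both cases below, w0 = globalIDW * unrollW / idWScale; idWScale is
    // the subgroup size for standalone kernels and 1 for fused kernels.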
20829 | state.w0 = state.ra.alloc_sub<uint32_t>( |
20830 | getHint(HintType::TempComp0, strategy)); |
20831 | if (strategy.zParallel) |
20832 | state.z0 = state.ra.alloc_sub<uint32_t>( |
20833 | getHint(HintType::TempComp1, strategy)); |
20834 | |
20835 | auto globalIDW = state.ra.alloc_sub<uint32_t>( |
20836 | getHint(HintType::TempComp1, strategy)); |
20837 | auto globalIDZ = state.ra.alloc_sub<uint32_t>( |
20838 | getHint(HintType::TempComp0, strategy)); |
20839 | |
20840 | int idWScale = inFused ? 1 : strategy.subgroupSize; |
20841 | bool multiple = (unrollW % idWScale) == 0; |
20842 | |
20843 | if (strategy.wgW > 0) |
20844 | mulConstant( |
20845 | 1, globalIDW, state.inputs.groupIDW, strategy.wgW * idWScale); |
20846 | else |
20847 | mul(1, globalIDW, state.inputs.groupIDW, state.inputs.localSizeW.uw()); |
20848 | if (strategy.zParallel) { |
20849 | if (strategy.wgZ > 0) |
20850 | mulConstant(1, globalIDZ, state.inputs.groupIDZ, strategy.wgZ); |
20851 | else |
20852 | mul(1, globalIDZ, state.inputs.groupIDZ, |
20853 | state.inputs.localSizeZ.uw()); |
20854 | } |
20855 | add(1, globalIDW, globalIDW, state.inputs.localIDW.uw(0)); |
20856 | if (strategy.zParallel && (strategy.wgZ != 1)) |
20857 | add(1, globalIDZ, globalIDZ, state.inputs.localIDZ.uw(0)); |
20858 | if (multiple) |
20859 | mulConstant(1, state.w0, globalIDW, unrollW / idWScale); |
20860 | else { |
20861 | mulConstant(1, state.w0, globalIDW, unrollW); |
20862 | shr(1, state.w0, state.w0, log2(idWScale)); |
20863 | } |
20864 | if (strategy.zParallel) |
20865 | emul(1, state.z0, globalIDZ, state.inputs.blockZ, strategy, state); |
20866 | |
20867 | state.ra.safeRelease(globalIDW); |
20868 | state.ra.safeRelease(globalIDZ); |
20869 | state.ra.safeRelease(state.inputs.localIDW); |
20870 | state.ra.safeRelease(state.inputs.localIDZ); |
20871 | state.ra.safeRelease(state.inputs.localSizeW); |
20872 | state.ra.safeRelease(state.inputs.localSizeZ); |
20873 | |
20874 | // Move r0 to acc0 if configured. |
20875 | moveR0(strategy, state); |
20876 | |
20877 | // Copy our slice. |
20878 | copySlice(problem, strategy, state); |
20879 | |
20880 | if (!inFused) { |
20881 | epilogue(strategy, state); |
20882 | |
20883 | padding(); |
20884 | } |
20885 | } |
20886 | |
20887 | // Calculate or recalculate lds_sl/ldd_dl as needed. |
20888 | template <HW hw> |
20889 | void gemm_kernel_generator_t<hw>::copyCalcIncrements(const CopyProblem &problem, |
20890 | const CopyStrategy &strategy, CopyState &state, int s_load, |
20891 | int d_load) { |
    // S: increment of lds * s_load needed for N->Pc, T->Pr [!xLoop]; N->Pr, T->Pc [xLoop]
    // D: no increment needed (always packed) [!xLoop]; ldd * d_load [xLoop]
20894 | bool sStrided |
20895 | = (isColMajor(problem.S.layout) == isColMajor(problem.D.layout)) |
20896 | ^ strategy.xLoop; |
20897 | |
20898 | if (sStrided || problem.reflecting()) { |
20899 | if (s_load == 0) s_load = strategy.s_load; |
20900 | if (s_load > 1) { |
20901 | if (state.lds_sl.isInvalid()) { |
20902 | state.lds_sl = state.ra.alloc_sub<uint32_t>(); |
20903 | s_load *= problem.Ts.size(); |
20904 | } |
20905 | emulConstant( |
20906 | 1, state.lds_sl, state.inputs.lds, s_load, strategy, state); |
20907 | } |
20908 | } |
20909 | |
20910 | if (strategy.xLoop) { |
20911 | if (d_load == 0) d_load = strategy.d_load; |
20912 | if (d_load > 1) { |
20913 | if (state.ldd_dl.isInvalid()) { |
20914 | state.ldd_dl = state.ra.alloc_sub<uint32_t>(); |
20915 | d_load *= problem.Td.size(); |
20916 | } |
20917 | emulConstant( |
20918 | 1, state.ldd_dl, state.inputs.ldd, d_load, strategy, state); |
20919 | } |
20920 | } |
20921 | } |
20922 | |
// Generate the copy kernel body for one thread's slice of the matrix.
20924 | template <HW hw> |
20925 | void gemm_kernel_generator_t<hw>::copySlice( |
20926 | CopyProblem &problem, CopyStrategy &strategy, CopyState &state) { |
20927 | auto Ts = problem.Ts, Td = problem.Td; |
20928 | Label labelExit; |
20929 | Subregister lddSrc; |
20930 | |
20931 | // If ldd not specified, use y. |
20932 | if (state.inputs.ldd.isInvalid()) { |
20933 | state.inputs.ldd = lddSrc = (problem.D.layout == MatrixLayout::Pc) |
20934 | ? state.inputs.n |
20935 | : state.inputs.m; |
20936 | if (problem.D.crosspack > 1 || problem.sum || strategy.zParallel) { |
20937 | state.inputs.ldd = state.ra.alloc_sub<uint32_t>( |
20938 | getHint(HintType::LongTerm, strategy)); |
20939 | mov(1, state.inputs.ldd, lddSrc); |
20940 | lddSrc = invalid; |
20941 | } |
20942 | if (problem.D.crosspack > 1) { |
20943 | add(1, state.inputs.ldd, state.inputs.ldd, problem.D.crosspack - 1); |
20944 | and_(1, state.inputs.ldd, state.inputs.ldd, |
20945 | ~uint32_t(problem.D.crosspack - 1)); |
20946 | } |
20947 | if (problem.sum) |
20948 | add(1, state.inputs.ldd, state.inputs.ldd, |
20949 | problem.Tsum.size() / problem.Td.size()); |
20950 | } |
20951 | |
20952 | // Duplicate alpha if configured. |
20953 | if (strategy.duplicateAlpha) { duplicateScalar(problem.alpha_real, state); } |
20954 | |
20955 | // For fused kernels, compute 2 * unrollW - fusedID for use in several places. |
20956 | Subregister unrollWRem; |
20957 | if (strategy.fused) { |
20958 | unrollWRem = state.ra.alloc_sub<uint32_t>( |
20959 | getHint(HintType::TempComp0, strategy)); |
20960 | add(1, unrollWRem, -state.fusedID, uint16_t(2 * strategy.unrollW())); |
20961 | } |
20962 | |
20963 | // Align code paths. |
20964 | bool mLoop = isColMajor(problem.D.layout) == strategy.xLoop; |
20965 | auto z = mLoop ? state.inputs.m : state.inputs.n; |
20966 | Subregister z0; |
20967 | |
20968 | // Handle z blocking. |
20969 | if (strategy.zParallel) { |
20970 | z0 = state.z0; |
20971 | add(1 | le | f0[1], z, z, -z0); |
20972 | min_(1, z, z, state.inputs.blockZ); |
20973 | state.ra.safeRelease(state.inputs.blockZ); |
20974 | } |
20975 | |
20976 | // Compute base addresses for S, D. |
20977 | // S += w0 + z0 * lds (N->Pc, T->Pr) z0 + w0 * lds (N->Pr, T->Pc) [swapped if xLoop = true] |
20978 | bool sStrided |
20979 | = (isColMajor(problem.S.layout) == isColMajor(problem.D.layout)) |
20980 | ^ strategy.xLoop; |
20981 | auto incC = sStrided ? state.w0 : z0; |
20982 | auto incS = sStrided ? z0 : state.w0; |
20983 | |
20984 | if (incC.isValid()) |
20985 | eadd(1, state.inputs.offsetS, state.inputs.offsetS, incC, strategy, |
20986 | state); |
20987 | if (incS.isValid()) { |
20988 | Subregister temp = state.ra.alloc_sub(state.inputs.offsetS.getType(), |
20989 | getHint(HintType::TempComp1, strategy)); |
20990 | emul(1, temp, incS, state.inputs.lds, strategy, state); |
20991 | eadd(1, state.inputs.offsetS, state.inputs.offsetS, temp, strategy, |
20992 | state); |
20993 | state.ra.safeRelease(temp); |
20994 | } |
20995 | |
20996 | // Quick exit if no work to do. |
20997 | if (strategy.zParallel) jmpi(1 | f0[1], labelExit); |
20998 | |
20999 | // D += align_up(x0, unroll) * ldd + y0 * unroll + (x0 % unroll) * crosspack |
21000 | { |
21001 | Subregister temp0 = state.ra.alloc_sub(state.inputs.offsetD.getType(), |
21002 | getHint(HintType::TempComp0, strategy)); |
21003 | Subregister temp1 = state.ra.alloc_sub(state.inputs.offsetD.getType(), |
21004 | getHint(HintType::TempComp1, strategy)); |
21005 | Subregister temp2 = state.ra.alloc_sub<uint32_t>( |
21006 | getHint(HintType::TempComp0, strategy)); |
21007 | auto x0 = strategy.xLoop ? z0 : state.w0; |
21008 | auto y0 = strategy.xLoop ? state.w0 : z0; |
21009 | bool splitX = strategy.unrollX < problem.D.packSize; |
21010 | |
21011 | if (x0.isValid()) { |
21012 | if (splitX) { |
21013 | modExt(temp2, temp1.ud(), x0, problem.D.packSize, strategy, |
21014 | state); |
21015 | emul(1, temp0, temp1.ud(), state.inputs.ldd, strategy, state); |
21016 | mulConstant(1, temp2, temp2, problem.D.crosspack); |
21017 | } else |
21018 | emul(1, temp0, x0, state.inputs.ldd, strategy, state); |
21019 | } |
21020 | if (y0.isValid()) |
21021 | emulConstant(1, temp1, y0, problem.D.packSize, strategy, state); |
21022 | if (x0.isValid()) |
21023 | eadd(1, state.inputs.offsetD, state.inputs.offsetD, temp0, strategy, |
21024 | state); |
21025 | if (y0.isValid()) |
21026 | eadd(1, state.inputs.offsetD, state.inputs.offsetD, temp1, strategy, |
21027 | state); |
21028 | if (x0.isValid() && splitX) |
21029 | eadd(1, state.inputs.offsetD, state.inputs.offsetD, temp2, strategy, |
21030 | state); |
21031 | |
21032 | state.ra.safeRelease(temp0); |
21033 | state.ra.safeRelease(temp1); |
21034 | state.ra.safeRelease(temp2); |
21035 | } |
21036 | |
21037 | state.ra.safeRelease(z0); |
21038 | state.z0 = invalid; |
21039 | |
21040 | // Calculate increments. |
21041 | copyCalcIncrements(problem, strategy, state); |
21042 | |
21043 | // Calculate remainders for w loop as needed. |
21044 | if (!strategy.xLoop |
21045 | && (strategy.remHandlingX != RemainderHandling::Ignore)) { |
21046 | auto x = (problem.D.layout == MatrixLayout::Pc) ? state.inputs.m |
21047 | : state.inputs.n; |
21048 | state.remainderX = state.ra.alloc_sub<uint32_t>(); |
21049 | add(1 | sat, state.remainderX, -state.w0, x); |
21050 | if (strategy.remHandlingX == RemainderHandling::Split) { |
21051 | if (strategy.fused) |
21052 | cmp(1 | lt | state.flagAP, null.ud(), state.remainderX, |
21053 | unrollWRem); |
21054 | else |
21055 | cmp(1 | lt | state.flagAP, null.ud(), state.remainderX, |
21056 | strategy.unrollX); |
21057 | mov(1 | ~state.flagAP, state.remainderX, strategy.unrollX); |
21058 | } else |
21059 | min_(1, state.remainderX, state.remainderX, strategy.unrollX); |
21060 | } |
21061 | if (strategy.xLoop |
21062 | && (strategy.remHandlingY != RemainderHandling::Ignore)) { |
21063 | auto y = (problem.D.layout == MatrixLayout::Pc) ? state.inputs.n |
21064 | : state.inputs.m; |
21065 | state.remainderY = state.ra.alloc_sub<uint32_t>(); |
21066 | add(1 | sat, state.remainderY, -state.w0, y); |
21067 | if (strategy.remHandlingY == RemainderHandling::Split) { |
21068 | if (strategy.fused) |
21069 | cmp(1 | lt | state.flagAP, null.ud(), state.remainderY, |
21070 | unrollWRem); |
21071 | else |
21072 | cmp(1 | lt | state.flagAP, null.ud(), state.remainderY, |
21073 | strategy.unrollY); |
21074 | mov(1 | ~state.flagAP, state.remainderY, strategy.unrollY); |
21075 | } else |
21076 | min_(1, state.remainderY, state.remainderY, strategy.unrollY); |
21077 | } |
21078 | |
21079 | // Convert lds to bytes. |
21080 | emulConstant( |
21081 | 1, state.inputs.lds, state.inputs.lds, Ts.size(), strategy, state); |
21082 | |
21083 | // Add offsets to base pointers for stateless accesses. |
21084 | emulConstant(1, state.inputs.offsetS, state.inputs.offsetS, Ts.size(), |
21085 | strategy, state); |
21086 | emulConstant(1, state.inputs.offsetD, state.inputs.offsetD, Td.size(), |
21087 | strategy, state); |
21088 | |
21089 | if (strategy.S.base.isStateless()) { |
21090 | eadd(1, state.inputs.S, state.inputs.S, state.inputs.offsetS, strategy, |
21091 | state); |
21092 | |
21093 | state.ra.safeRelease(state.inputs.offsetS); |
21094 | } else |
21095 | state.effS1 = state.offsetS1; |
21096 | |
21097 | if (strategy.D.base.isStateless()) { |
21098 | eadd(1, state.inputs.D, state.inputs.D, state.inputs.offsetD, strategy, |
21099 | state); |
21100 | state.ra.safeRelease(state.inputs.offsetD); |
21101 | } |
21102 | |
21103 | state.ra.safeRelease(unrollWRem); |
21104 | |
21105 | if (!copyBody(problem, strategy, state)) { |
        lastException ? std::rethrow_exception(lastException)
                      : throw std::runtime_error("Could not generate kernel.");
21108 | } |
21109 | |
21110 | mark(labelExit); |
21111 | } |
21112 | |
21113 | // Wrapper around copyBodyRemCheck, checking for optimally-aligned S. |
21114 | template <HW hw> |
21115 | bool gemm_kernel_generator_t<hw>::copyBody( |
21116 | CopyProblem &problem, CopyStrategy &strategy, CopyState &state) { |
21117 | if (!is_zero_or_pow2(strategy.optionalAlignS)) stub(); |
21118 | |
21119 | bool success; |
21120 | |
21121 | if (strategy.optionalAlignS == 0) |
21122 | success = copyBodyRemCheck(problem, strategy, state); |
21123 | else { |
21124 | Label labelUnaligned, labelEnd; |
21125 | |
21126 | status << "S alignment check" << status_stream::endl; |
21127 | and_(1 | nz | f0[1], null.uw(), state.effS.uw(), |
21128 | uint16_t(strategy.optionalAlignS - 1)); |
21129 | and_(1 | nz | f1[1], null.uw(), state.inputs.lds.uw(), |
21130 | uint16_t(strategy.optionalAlignS - 1)); |
21131 | ejmpi(1 | f0[1] | anyv, labelUnaligned); |
21132 | |
21133 | auto modProblem = problem; |
21134 | modProblem.S.setAlignment(strategy.optionalAlignS); |
21135 | |
21136 | status << "S aligned to " << strategy.optionalAlignS << ':' |
21137 | << status_stream::endl; |
21138 | success = copyBodyRemCheck(modProblem, strategy, state); |
21139 | |
21140 | if (state.isNested) |
21141 | jmpi(1, labelEnd); |
21142 | else |
21143 | epilogue(strategy, state); |
21144 | |
21145 | mark(labelUnaligned); |
21146 | |
21147 | status << "S unaligned" << status_stream::endl; |
21148 | success = success && copyBodyRemCheck(problem, strategy, state); |
21149 | |
21150 | mark(labelEnd); |
21151 | } |
21152 | |
21153 | return success; |
21154 | } |
21155 | |
21156 | // Wrapper around copyBodyInternal, handling split remainders. |
21157 | template <HW hw> |
21158 | bool gemm_kernel_generator_t<hw>::copyBodyRemCheck( |
21159 | CopyProblem &problem, CopyStrategy &strategy, CopyState &state) { |
21160 | auto CopyStrategy::*remHandlingW |
21161 | = (strategy.xLoop ? &CopyStrategy::remHandlingY |
21162 | : &CopyStrategy::remHandlingX); |
21163 | bool wSplit = strategy.*remHandlingW == RemainderHandling::Split; |
21164 | bool success; |
21165 | |
21166 | if (!wSplit) |
21167 | success = copyBodyInternal(problem, strategy, state); |
21168 | else { |
21169 | CopyStrategy modStrategy = strategy; |
21170 | Label wRemBegin, wRemEnd; |
21171 | jmpi(1 | state.flagAP, wRemBegin); |
21172 | |
21173 | status << "Generating " |
21174 | << "xy" [strategy.xLoop] << " non-remainder kernel" |
21175 | << status_stream::endl; |
21176 | modStrategy.*remHandlingW = RemainderHandling::Ignore; |
21177 | success = copyBodyInternal(problem, modStrategy, state); |
21178 | |
21179 | if (state.isNested) |
21180 | jmpi(1, wRemEnd); |
21181 | else |
21182 | epilogue(strategy, state); |
21183 | |
21184 | modStrategy.*remHandlingW = RemainderHandling::KnownRemainder; |
21185 | |
21186 | bool recalc = false; |
21187 | |
21188 | if (strategy.xLoop && !isTransposing(modStrategy.D.accessType) |
21189 | && !isLargeCrosspack(problem.Td, problem.D.crosspack)) { |
21190 | // Change D access to use scattered stores so masking is possible. |
21191 | modStrategy.D.accessType = AccessType::Scattered; |
21192 | modStrategy.S.accessType = isTransposing(modStrategy.S.accessType) |
21193 | ? AccessType::Block |
21194 | : AccessType::Scattered; |
21195 | } |
21196 | if (!strategy.xLoop && !strategy.S.padded) { |
21197 | // Check if we need to change s_load/d_load. |
21198 | if (strategy.s_load > strategy.s_load_masked) { |
21199 | status << "Downgrading s_load: " << strategy.s_load << " -> " |
21200 | << strategy.s_load_masked << status_stream::endl; |
21201 | modStrategy.s_load = strategy.s_load_masked; |
21202 | recalc = true; |
21203 | } |
21204 | if (strategy.d_load > strategy.d_load_masked) { |
21205 | status << "Downgrading d_load: " << strategy.d_load << " -> " |
21206 | << strategy.d_load_masked << status_stream::endl; |
21207 | modStrategy.d_load = strategy.d_load_masked; |
21208 | recalc = true; |
21209 | } |
21210 | } |
21211 | |
21212 | status << "Generating " |
21213 | << "xy" [strategy.xLoop] << " remainder kernel" |
21214 | << status_stream::endl; |
21215 | mark(wRemBegin); |
21216 | if (recalc) copyCalcIncrements(problem, modStrategy, state); |
21217 | success = success && copyBodyInternal(problem, modStrategy, state); |
21218 | mark(wRemEnd); |
21219 | } |
21220 | |
21221 | return success; |
21222 | } |
21223 | |
21224 | // Body of copy kernel. |
21225 | template <HW hw> |
21226 | bool gemm_kernel_generator_t<hw>::copyBodyInternal( |
21227 | CopyProblem &problem, CopyStrategy &strategy, CopyState &state) { |
21228 | Label lZLoopBegin, lZLoopEnd; |
21229 | constexpr auto SD_copies = 1; |
21230 | vector<MaskAssignment> masks; |
21231 | bool share; |
21232 | |
21233 | auto Ts = problem.Ts, Td = problem.Td, Tsum = problem.Tsum; |
21234 | const bool byColumn = isColMajor(problem.D.layout); |
21235 | const bool sStrided |
21236 | = (isColMajor(problem.S.layout) == isColMajor(problem.D.layout)) |
21237 | ^ strategy.xLoop; |
21238 | const bool mLoop = isColMajor(problem.D.layout) == strategy.xLoop; |
21239 | |
21240 | const bool reflecting = false; |
21241 | const bool triRemOnly = false; |
21242 | |
21243 | auto crosspack = problem.D.crosspack; |
21244 | |
21245 | // Release w0 -- no longer needed. |
21246 | state.ra.safeRelease(state.w0); |
21247 | |
21248 | // Get flag register for complex swizzles for XeHP+. |
21249 | if (hw >= HW::XeHP && Ts.isComplex()) { |
21250 | state.flagSwizzle = state.raVFlag.alloc(); |
21251 | state.raVFlag.unlock(state.flagSwizzle); |
21252 | } |
21253 | |
21254 | MatrixAddressingStrategy S_strategyReflected = strategy.S; |
21255 | vector<RegisterBlock> S_layoutReflected; |
21256 | |
21257 | // Decide what remainder handling needs to be done. |
21258 | bool remainderX = (strategy.remHandlingX != RemainderHandling::Ignore); |
21259 | bool remainderY = (strategy.remHandlingY != RemainderHandling::Ignore); |
21260 | bool remainderZ = strategy.xLoop ? remainderX : remainderY; |
21261 | |
21262 | bool checkYRem1 = strategy.xLoop && remainderY && strategy.unrollY == 1; |
21263 | VirtualFlag flagYRem1; |
21264 | |
21265 | remainderY &= !checkYRem1; |
21266 | |
21267 | // Get register layouts for S and D. |
21268 | int nms, nmd, nns, nnd; |
21269 | auto setup = [&](int s_load, int d_load, Subregister S_addr0, |
21270 | Subregister S1_addr0, Subregister D_addr0, |
21271 | bool handleRemZ) -> bool { |
21272 | bool remM = remainderX && (!strategy.xLoop || handleRemZ); |
21273 | bool remN = remainderY && (strategy.xLoop || handleRemZ); |
21274 | Subregister remainders[3] |
21275 | = {state.remainderX, state.remainderY, Subregister {}}; |
21276 | |
21277 | if (!strategy.xLoop) { |
21278 | nmd = nms = strategy.unrollX; |
21279 | nnd = d_load; |
21280 | nns = s_load; |
21281 | } else { |
21282 | nnd = nns = strategy.unrollY; |
21283 | nmd = d_load; |
21284 | nms = s_load; |
21285 | } |
21286 | |
21287 | if (!byColumn) { |
21288 | std::swap(nms, nns); |
21289 | std::swap(nmd, nnd); |
21290 | std::swap(remM, remN); |
21291 | std::swap(remainders[0], remainders[1]); |
21292 | } |
21293 | |
21294 | auto remM_S = remM && !strategy.S.padded; |
21295 | auto remN_S = remN && !strategy.S.padded; |
21296 | auto remM_D = remM && !strategy.D.padded && !byColumn; |
21297 | auto remN_D = remN && !strategy.D.padded && byColumn; |
21298 | |
21299 | auto sMaxRBlock = 0; |
21300 | auto sMaxCBlock = 0; |
21301 | |
21302 | if (!getRegLayout(Ts, state.S_layout, nms, nns, remM_S, remN_S, false, |
21303 | true, sMaxRBlock, sMaxCBlock, problem.S, strategy.S)) |
21304 | return false; |
21305 | if (!getRegLayout(Td, state.D_layout, nmd, nnd, remM_D, remN_D, true, |
21306 | true, 0, 0, problem.D, strategy.D)) |
21307 | return false; |
21308 | |
21309 | if (hasFragmenting(state.S_layout) || hasFragmenting(state.D_layout)) { |
21310 | status << "Fragmenting not supported." << status_stream::endl; |
21311 | return false; |
21312 | } |
21313 | |
21314 | bool success = true; |
21315 | if (checkYRem1) { |
21316 | flagYRem1 = state.raVFlag.allocVirtual(); |
21317 | success &= !(state.raVFlag.isVirtual(flagYRem1) |
21318 | && state.vflagStorage.isInvalid()); |
21319 | } |
21320 | |
21321 | // Find and load any needed mask registers. |
21322 | success = success |
21323 | && assignMasks( |
21324 | state.S_layout, LoopM, LoopN, masks, strategy, state) |
21325 | && assignMasks( |
21326 | state.D_layout, LoopM, LoopN, masks, strategy, state); |
21327 | |
21328 | if (!success && state.vflagStorage.isInvalid()) { |
21329 | status << "Retrying with virtual flags." << status_stream::endl; |
21330 | allocVFlagStorage(strategy, state); |
21331 | success = assignMasks(state.S_layout, LoopM, LoopN, masks, strategy, |
21332 | state) |
21333 | && assignMasks(state.D_layout, LoopM, LoopN, masks, |
21334 | strategy, state); |
21335 | } |
21336 | |
21337 | if (!success) return false; |
21338 | |
21339 | loadMasks(masks, remainders, strategy, state); |
21340 | |
21341 | if (!strategy.xLoop && !remM_D && !remN_D |
21342 | && strategy.remHandlingX != RemainderHandling::Ignore) { |
21343 | // Find a mask to use for destination layout for y loop remainders. |
21344 | VirtualFlag flag; |
21345 | bool found = false; |
21346 | for (auto &mask : masks) |
21347 | if (mask.var == (byColumn ? LoopM : LoopN) && mask.offset == 0) |
21348 | flag = mask.flag, found = true; |
21349 | if (!found) stub(); |
21350 | for (auto &block : state.D_layout) { |
21351 | if (block.simdSize > 16) stub(); |
21352 | block.flag = flag; |
21353 | block.flagAny = true; |
21354 | } |
21355 | } else if (checkYRem1) { |
21356 | // Create mask for y remainder for x-loop kernels with unrollY == 1, and |
21357 | // apply it by hand to both source and destination. |
21358 | RegData regYRem1 = getMaskFlag(flagYRem1, state); |
21359 | FlagRegister testFlag; |
21360 | |
21361 | testFlag = regYRem1.isARF() |
21362 | ? reinterpret_cast<FlagRegister &>(regYRem1) |
21363 | : f0[1]; |
21364 | |
21365 | cmp(16 | gt | testFlag, state.remainderY, 0); |
21366 | |
21367 | for (auto &mask : masks) |
21368 | mov(1 | ~testFlag, getMaskFlag(mask.flag, state), 0); |
21369 | if (!regYRem1.isARF()) mov(1, regYRem1, testFlag); |
21370 | |
21371 | for (auto &block : state.S_layout) |
21372 | if (!block.flag) block.flag = flagYRem1; |
21373 | for (auto &block : state.D_layout) { |
21374 | if (block.simdSize > 16) stub(); |
21375 | block.flag = flagYRem1; |
21376 | } |
21377 | } |
21378 | |
21379 | // Match source layout to destination layout if possible, so that they can share registers. |
21380 | share = (Ts == Td) && (s_load == d_load) |
21381 | && matchLayoutsBidirectional( |
21382 | Ts, state.S_layout, state.D_layout); |
21383 | |
21384 | // Allocate address registers. |
21385 | allocAddrRegs(state.S_addrs, state.S_layout, problem.S, strategy.S, |
21386 | state, |
21387 | getHint(share ? HintType::DAddr : HintType::SAddr, strategy)); |
21388 | allocAddrRegs(state.D_addrs, state.D_layout, problem.D, strategy.D, |
21389 | state, getHint(HintType::DAddr, strategy)); |
21390 | |
21391 | // Set up address registers. |
21392 | setupAddr(Ts, state.S_addrs, S_addr0, state.S_layout, state.inputs.lds, |
21393 | problem.S, strategy.S, strategy, state); |
21394 | setupAddr(Td, state.D_addrs, D_addr0, state.D_layout, state.inputs.ldd, |
21395 | problem.D, strategy.D, strategy, state); |
21396 | |
21397 | // Allocate data registers. |
21398 | int S_regCount = getRegCount(state.S_layout); |
21399 | int D_regCount = getRegCount(state.D_layout); |
21400 | |
21401 | state.D_regs = state.ra.alloc_range( |
21402 | D_regCount, getHint(HintType::D, strategy)); |
21403 | state.S_regs = share ? state.D_regs |
21404 | : state.ra.alloc_range(S_regCount, |
21405 | getHint(HintType::S, strategy)); |
21406 | |
21407 | // Prepare for summation. |
21408 | // Clean up previous sums if any, and try to reuse their registers. |
21409 | // Allocate and zero new sum registers as needed. |
21410 | if (problem.sum) { |
21411 | if (strategy.xLoop) stub(); |
21412 | |
21413 | vector<RegisterBlock> Ds_layout; |
21414 | makeSumLayout(!byColumn, Td, state.D_layout, Tsum, Ds_layout, |
21415 | strategy, state); |
21416 | |
21417 | bool alloc = state.Ds_layout.empty() |
21418 | || !matchLayouts(Tsum, Ds_layout, state.Ds_layout); |
21419 | if (!state.Ds_layout.empty() && alloc) { |
21420 | horizontalAdd(!byColumn, Tsum, state.Ds_regs.back(), |
21421 | state.Ds_layout, state); |
21422 | alloc = !matchLayouts(Tsum, Ds_layout, state.Ds_layout); |
21423 | } |
21424 | if (alloc) { |
21425 | state.Ds_layout = std::move(Ds_layout); |
21426 | auto Ds_regs |
21427 | = state.ra.alloc_range(getRegCount(state.Ds_layout)); |
21428 | zeroMatrix(Ds_regs, strategy); |
21429 | state.Ds_regs.push_back(Ds_regs); |
21430 | } |
21431 | } |
21432 | |
21433 | return true; |
21434 | }; |
21435 | |
21436 | auto cleanup = [&]() { |
21437 | state.raVFlag.safeRelease(flagYRem1); |
21438 | safeReleaseMaskAssignments(masks, state); |
21439 | safeReleaseRanges(state.S_addrs, state); |
21440 | safeReleaseRanges(state.D_addrs, state); |
21441 | |
21442 | state.ra.safeRelease(state.S_regs); |
21443 | state.ra.safeRelease(state.D_regs); |
21444 | // Sum registers not freed here. |
21445 | |
21446 | state.S_layout.clear(); |
21447 | state.D_layout.clear(); |
21448 | }; |
21449 | |
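    // doSLoad: load the next s_load panel of S and advance the strided and
    // fixed S address sets. When checkRem is set, S registers are pre-zeroed
    // and the whole step is skipped once the z counter is at or below z0.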
21450 | auto doSLoad = [&](const vector<RegisterBlock> &layout, |
21451 | const vector<RegisterBlock> &layoutReflect, |
21452 | const vector<GRFRange> &addrs, |
21453 | const vector<GRFRange>(&addrSrcs)[2], int z0, |
21454 | int s_load, int S_copy, bool checkRem) { |
21455 | bool unlockAP = false; |
21456 | Label skipLoad; |
21457 | checkRem &= (z0 > 0); |
21458 | |
21459 | if (checkRem) { |
21460 | zeroMatrix(state.S_regs, strategy); |
21461 | unlockAP = !state.raVFlag.lock(state.flagAP); |
21462 | state.usePhysicalFlag(state.flagAP); |
21463 | cmp(1 | le | state.flagAP, state.Z, uint16_t(z0)); |
21464 | jmpi(1 | state.flagAP, skipLoad); |
21465 | } |
21466 | |
21467 | { |
21468 | loadMatrix(state.S_regs, layout, problem.S, strategy.S, addrs, |
21469 | strategy, state); |
21470 | } |
21471 | |
21472 | auto addrsFixed = reflecting ? &addrSrcs[0] : &addrs; |
        auto addrsStrided = reflecting ? &addrSrcs[1] : nullptr;
21474 | auto layoutFixed = &layout; |
21475 | auto layoutStrided = &layoutReflect; |
21476 | |
21477 | if (sStrided) { |
21478 | std::swap(addrsFixed, addrsStrided); |
21479 | std::swap(layoutFixed, layoutStrided); |
21480 | } |
21481 | |
21482 | { |
21483 | if (addrsStrided) |
21484 | incAddr(*addrsStrided, |
21485 | (s_load == 1) ? state.inputs.lds : state.lds_sl, |
21486 | *layoutStrided, problem.S, strategy.S, strategy, state); |
21487 | if (addrsFixed) |
21488 | incAddr(*addrsFixed, uint16_t(s_load * Ts), *layoutFixed, |
21489 | problem.S, strategy.S, strategy, state); |
21490 | } |
21491 | if (checkRem) { |
21492 | if (unlockAP) state.raVFlag.unlock(state.flagAP); |
21493 | mark(skipLoad); |
21494 | } |
21495 | }; |
21496 | |
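    // Store the current D tile, accumulating running sums if requested, then
    //  advance the D addresses by the appropriate pitch (load distance, pack
    //  size, or leading dimension).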
21497 | auto doDStore = [&](const vector<RegisterBlock> &layout, |
21498 | const vector<GRFRange> &addrs, int d_load, |
21499 | int D_copy) { |
21500 | storeMatrix(state.D_regs, layout, problem.D, strategy.D, addrs, |
21501 | strategy, state); |
21502 | if (problem.sum) |
21503 | accumulateSum(!byColumn, Td, state.D_regs, layout, Tsum, |
21504 | state.Ds_regs.back(), state.Ds_layout, strategy, state); |
21505 | if (strategy.xLoop) { |
21506 | if (d_load >= strategy.unrollX) |
21507 | incAddr(addrs, state.ldd_dl, layout, problem.D, strategy.D, |
21508 | strategy, state); |
21509 | else |
21510 | incAddr(addrs, uint16_t(d_load * Td), layout, problem.D, |
21511 | strategy.D, strategy, state); |
21512 | } else { |
21513 | auto D_tileX = byColumn ? problem.D.tileR : problem.D.tileC; |
21514 | auto D_tileY = byColumn ? problem.D.tileC : problem.D.tileR; |
21515 | auto effPS = (d_load < D_tileY) ? D_tileX : problem.D.packSize; |
21516 | incAddr(addrs, uint16_t(d_load * effPS * Td), layout, problem.D, |
21517 | strategy.D, strategy, state); |
21518 | } |
21519 | }; |
21520 | |
21521 | // Start generating code. |
21522 | |
21523 | // Reuse z for the loop counter. |
21524 | // If z unroll > 1, the loop counter will be offset by (unrollZ - 1) during the main loop, |
21525 | // unless there's no z remainder. |
21526 | // For triangular-ended copies, offset by an additional unrollW [2x unrollX if fused] to push triangular handling to remainder loop. |
21527 | state.Z = mLoop ? state.inputs.m : state.inputs.n; |
21528 | |
21529 | auto unrollZ = strategy.unrollZ(); |
21530 | auto offsetZ = (remainderZ || triRemOnly) ? (unrollZ - 1) : 0; |
21531 | |
21532 | if (offsetZ == 0) |
21533 | cmp(1 | le | state.flagAP, null.d(), state.Z, int16_t(0)); |
21534 | else |
21535 | add(1 | le | state.flagAP, state.Z, state.Z, int16_t(-offsetZ)); |
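    // flagAP is now set iff no full main-loop iterations remain; it gates the
    //  early jump to the remainder loop below.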
21536 | |
21537 | // Get flag register and loop counter for barrier check if needed. |
21538 | FlagRegister flagBarrier; |
21539 | Subregister bcount; |
21540 | if (strategy.barrierFreq > 0) { |
21541 | flagBarrier = state.raVFlag.alloc(); |
21542 | |
        // Can use main loop counter if barrierFreq and unrollZ are both powers of 2.
21544 | if (!is_zero_or_pow2(strategy.barrierFreq * unrollZ)) { |
21545 | bcount = state.ra.alloc_sub<uint32_t>(); |
21546 | mov(1, bcount, uint16_t(strategy.barrierFreq)); |
21547 | } |
21548 | } |
21549 | |
21550 | // Setup for main loop. |
21551 | if (!setup(strategy.s_load, strategy.d_load, state.effS, state.effS1, |
21552 | state.effD, false)) |
21553 | return false; |
21554 | |
21555 | bool lateZLoopCheck = state.vflagStorage.isValid(); |
21556 | if (lateZLoopCheck) { |
21557 | // Release flags for use by vflags. Note flagReflect is not released. |
21558 | state.raVFlag.unlock(state.flagAP); |
21559 | if (flagBarrier.isValid()) state.raVFlag.unlock(flagBarrier); |
21560 | } |
21561 | |
21562 | // Bail to remainder loop if no main loops. |
21563 | jmpi(1 | state.flagAP, lZLoopEnd); |
21564 | |
21565 | // Loop check code. |
21566 | auto zLoopCheck = [&](int unrollZ, bool enableBarriers) { |
21567 | // Use the all-purpose flag for z loop query. |
21568 | add(1 | gt | state.flagAP, state.Z, state.Z, int16_t(-unrollZ)); |
21569 | |
21570 | // Check for barrier if requested. |
21571 | if (enableBarriers) { |
21572 | if (bcount.isInvalid()) |
21573 | and_(1 | ze | flagBarrier, null.ud(), state.Z, |
21574 | uint16_t(unrollZ * strategy.barrierFreq - unrollZ)); |
21575 | else |
21576 | add(1 | ze | flagBarrier, bcount, bcount, int16_t(-1)); |
21577 | } |
21578 | }; |
21579 | |
    // Lambdas used in zLoopBody (moved outside to work around a GCC bug).
21581 | auto mulAlphaFixed = [&](int esize, RegData r) { |
21582 | mul(esize, r, r, problem.alpha_real.getRegAvoiding(hw, r)); |
21583 | }; |
21584 | |
21585 | auto mulAlpha = [&](int esize, RegData r) { |
21586 | mul(esize, r, r, cast(Ts.real(), problem.alpha_real)); |
21587 | }; |
21588 | |
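    // Negate data by flipping sign bits (XOR with the precomputed mask in
    //  state.signChange), avoiding a multiply.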
21589 | auto signChange = [&](int esize, RegData r) { |
21590 | auto ne = elementsPerGRF<uint32_t>(hw); |
21591 | xor_<uint32_t>(esize, r, r, |
21592 | (ne < esize) ? state.signChange[0](0, ne, 1) |
21593 | : state.signChange[0](1)); |
21594 | }; |
21595 | |
21596 | // z loop: returns true on success. |
21597 | int S_copy = 0, D_copy = 0; |
21598 | auto zLoopBody = [&](const vector<RegisterBlock> &S_layout, |
21599 | const vector<RegisterBlock> &S_layoutReflected, |
21600 | const vector<RegisterBlock> &D_layout, |
21601 | const vector<GRFRange> &S_addrs, |
21602 | const vector<GRFRange>(&S_addrSrcs)[2], |
21603 | const vector<GRFRange> &D_addrs, int unrollZ, |
21604 | int s_load, int d_load, bool enableBarriers, |
21605 | bool enableTri, bool needSRem = false, |
21606 | bool noLoop = false) { |
21607 | int us = s_load, ud = 0; |
21608 | int uZLoopCheck = noLoop ? -1 : lateZLoopCheck ? (unrollZ - 1) : 0; |
21609 | bool dMasked = hasMasking(D_layout); |
21610 | |
21611 | for (int u = 0; u < unrollZ; u++, us++, ud++) { |
21612 | // Maintain us (= u % s_load) and ud (= u % d_load) counters. |
21613 | bool loadS = false; |
21614 | if (us == s_load) { |
21615 | us = 0; |
21616 | loadS = true; |
21617 | } |
21618 | |
21619 | if (ud == d_load) ud = 0; |
21620 | bool storeD = ((ud + 1) == d_load); |
21621 | |
21622 | // Test loop counter on first iteration (lateZLoopCheck == false) |
21623 | if ((u == uZLoopCheck) && !lateZLoopCheck) |
21624 | zLoopCheck(unrollZ, enableBarriers); |
21625 | |
21626 | // Load S every s_load loops, and copy as necessary. |
21627 | if (loadS) { |
21628 | doSLoad(S_layout, S_layoutReflected, S_addrs, S_addrSrcs, u, |
21629 | s_load, S_copy, needSRem); |
21630 | |
21631 | // Copy S registers to D registers, or perform in-place scaling/transposition. |
21632 | if (!share) { |
21633 | int dOffR = 0, dOffC = 0; |
21634 | (byColumn ? dOffC : dOffR) = ud; |
21635 | |
21636 | if (!copyRegisters(Ts, Td, S_layout, D_layout, state.S_regs, |
21637 | state.D_regs, dOffR, dOffC, problem.alpha_real, |
21638 | problem.alpha_imag, problem.conjugate, strategy, |
21639 | state)) |
21640 | return false; |
21641 | } else { |
21642 | if (!problem.alpha_real.fixed()) |
21643 | map(hw, Ts.real(), state.S_regs, S_layout, strategy, |
21644 | mulAlphaFixed); |
21645 | else if ((problem.alpha_real != 1) |
21646 | && (problem.alpha_real != -1)) |
21647 | map(hw, Ts.real(), state.S_regs, S_layout, strategy, |
21648 | mulAlpha); |
21649 | if (problem.conjugate || (problem.alpha_real == -1)) |
21650 | map<uint32_t>(hw, state.S_regs, S_layout, strategy, |
21651 | signChange); |
21652 | } |
21653 | |
21654 | // Advance S copy counter. |
21655 | if (++S_copy == SD_copies) S_copy = 0; |
21656 | } |
21657 | |
21658 | // Test loop counter on last iteration (lateZLoopCheck == true) if D unmasked. |
21659 | if ((u == uZLoopCheck) && lateZLoopCheck && !dMasked) |
21660 | zLoopCheck(unrollZ, enableBarriers); |
21661 | |
21662 | // Store D every d_load loops. |
21663 | if (storeD) { |
21664 | doDStore(D_layout, D_addrs, d_load, D_copy); |
21665 | if (++D_copy == SD_copies) D_copy = 0; |
21666 | } |
21667 | |
21668 | // Test loop counter at very end (lateZLoopCheck == true) if D masked. |
21669 | if ((u == uZLoopCheck) && lateZLoopCheck && dMasked) |
21670 | zLoopCheck(unrollZ, enableBarriers); |
21671 | } |
21672 | |
21673 | // Forget about active vflags. |
21674 | state.wipeActiveVFlags(); |
21675 | |
21676 | return true; |
21677 | }; |
21678 | |
21679 | syncall(); |
21680 | |
21681 | mark(lZLoopBegin); |
21682 | { |
21683 | if (!zLoopBody(state.S_layout, S_layoutReflected, state.D_layout, |
21684 | state.S_addrs, state.S_addrSrcs, state.D_addrs, unrollZ, |
21685 | strategy.s_load, strategy.d_load, strategy.barrierFreq > 0, |
21686 | !triRemOnly)) |
21687 | return false; |
21688 | |
21689 | if (strategy.barrierFreq == 0) |
21690 | jmpi(1 | state.flagAP, lZLoopBegin); |
21691 | else { |
21692 | jmpi(1 | ~state.flagAP, lZLoopEnd); |
21693 | jmpi(1 | ~flagBarrier, lZLoopBegin); |
21694 | |
21695 | auto temp = state.ra.alloc(); |
21696 | if (!bcount.isInvalid()) |
21697 | mov(1, bcount, uint16_t(strategy.barrierFreq)); |
21698 | |
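            // Barrier messages need header data from r0; if the r0 info lives
            //  in an ARF, stage a copy in a temporary GRF first.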
21699 | GRF r0_info; |
21700 | bool freeR0Info = false; |
21701 | |
21702 | if (state.r0_info.isARF()) { |
21703 | r0_info = state.ra.alloc(); |
21704 | mov<uint32_t>(8, r0_info, state.r0_info); |
21705 | freeR0Info = true; |
21706 | } else |
21707 | r0_info = GRF {state.r0_info.getBase()}; |
21708 | |
21709 | barrier(temp, r0_info); |
21710 | state.ra.safeRelease(temp); |
21711 | if (freeR0Info) state.ra.safeRelease(r0_info); |
21712 | |
21713 | jmpi(1, lZLoopBegin); |
21714 | } |
21715 | } |
21716 | mark(lZLoopEnd); |
21717 | |
21718 | state.raVFlag.safeRelease(flagBarrier); |
21719 | state.ra.safeRelease(bcount); |
21720 | |
21721 | // z remainder loop. |
21722 | if (offsetZ) { |
        // Undo offsetting on the z loop counter and check for zero remainder iterations.
21724 | add(1 | le | state.flagAP, state.Z, state.Z, uint16_t(offsetZ)); |
21725 | |
21726 | // Get the current S, D addresses. |
21727 | Subregister S_addr0, S1_addr0, D_addr0; |
21728 | int S_shift, D_shift; |
21729 | S_addr0 = getOriginAddr( |
21730 | state.S_layout, state.S_addrs, problem.S, strategy.S, &S_shift); |
21731 | |
21732 | D_addr0 = getOriginAddr( |
21733 | state.D_layout, state.D_addrs, problem.D, strategy.D, &D_shift); |
21734 | |
21735 | auto unshiftAddr0 = [&]() { |
21736 | if (S_shift) shl(1, S_addr0, S_addr0, S_shift); |
21737 | if (D_shift) shl(1, D_addr0, D_addr0, D_shift); |
21738 | }; |
21739 | |
21740 | // Prepare for potential new layout. |
21741 | vector<RegisterBlock> S_layout1, S_layout1Reflect, D_layout1; |
21742 | vector<GRFRange> S_addrs1, S_addrSrcs1[2], D_addrs1; |
21743 | |
21744 | // First, try handling the whole remainder, all at once. |
21745 | bool wholeRem = false, fragmented = false; |
21746 | auto newSLoad = strategy.s_load, newDLoad = strategy.d_load; |
21747 | auto saveSStrategy = strategy.S; |
21748 | bool largeDCrosspack = isLargeCrosspack(Td, problem.D.crosspack); |
21749 | |
21750 | if (S_addr0.isValid() && D_addr0.isValid() && !largeDCrosspack) { |
21751 | auto saveState = state; |
21752 | auto saveMasks = masks; |
21753 | (strategy.xLoop ? state.remainderX : state.remainderY) = state.Z; |
21754 | pushStream(); |
21755 | try { |
21756 | cleanup(); |
21757 | state.ra.claim(S_addr0); |
21758 | state.ra.claim(D_addr0); |
21759 | unshiftAddr0(); |
21760 | |
21761 | wholeRem = setup(strategy.s_load, strategy.d_load, S_addr0, |
21762 | S1_addr0, D_addr0, true); |
21763 | |
21764 | state.ra.release(S_addr0); |
21765 | state.ra.release(D_addr0); |
21766 | } catch (...) {} |
21767 | if (!wholeRem) { |
21768 | masks = saveMasks; |
21769 | state = saveState; |
21770 | } |
21771 | wholeRem ? appendCurrentStream() : discardStream(); |
21772 | } |
21773 | |
21774 | // If that doesn't work, retry with minimal unroll. |
21775 | if (!wholeRem) { |
21776 | newSLoad = 1; |
21777 | newDLoad = crosspack; |
21778 | bool unshare = share && (newSLoad != newDLoad); |
21779 | |
21780 | // Fragment the S, D layouts, taking the first row/column of each. |
21781 | vector<int> indices; |
21782 | fragmented = (!unshare && !largeDCrosspack |
21783 | && getSubblocks(Ts, S_layout1, indices, state.S_layout, |
21784 | !mLoop, 0, newSLoad, strategy.S.padded, problem.S, |
21785 | strategy.S) |
21786 | && getSubblocks(Ts, S_layout1Reflect, S_layoutReflected, |
21787 | mLoop, 0, newSLoad, strategy.S.padded, problem.S, |
21788 | S_strategyReflected) |
21789 | && getSubblocks(Td, D_layout1, D_addrs1, state.D_layout, |
21790 | state.D_addrs, !mLoop, 0, newDLoad, false, |
21791 | problem.D, strategy.D)); |
21792 | |
21793 | if (fragmented) { |
21794 | // Select source address registers from the fragments. |
21795 | for (auto b : indices) |
21796 | S_addrs1.push_back(state.S_addrs[b]); |
21797 | // Update sizes. |
21798 | (mLoop ? nms : nns) = newSLoad; |
21799 | (mLoop ? nmd : nnd) = newDLoad; |
21800 | } else { |
21801 | // Fragmentation failed. Start fresh. |
21802 | if (S_addr0.isInvalid() || D_addr0.isInvalid()) return false; |
21803 | |
21804 | cleanup(); |
21805 | state.ra.claim(S_addr0); |
21806 | state.ra.claim(D_addr0); |
21807 | unshiftAddr0(); |
21808 | |
21809 | if (largeDCrosspack) { |
21810 | strategy.S.accessType = isTransposing(strategy.S.accessType) |
21811 | ? AccessType::Block |
21812 | : strategy.S.base.isStateless() |
21813 | ? AccessType::Scattered |
21814 | : AccessType::ChannelScattered; |
21815 | } |
21816 | |
21817 | if (!setup(newSLoad, newDLoad, S_addr0, S1_addr0, D_addr0, |
21818 | false)) |
21819 | return false; |
21820 | |
21821 | state.ra.release(S_addr0); |
21822 | state.ra.release(D_addr0); |
21823 | } |
21824 | |
21825 | if (crosspack > 1) { |
21826 | lateZLoopCheck = true; |
21827 | copyCalcIncrements( |
21828 | problem, strategy, state, newSLoad, newDLoad); |
21829 | } |
21830 | } |
21831 | |
21832 | // Emit z remainder loop. |
21833 | Label lZRemLoopBegin, lZRemLoopEnd; |
21834 | jmpi(1 | state.flagAP, lZRemLoopEnd); |
21835 | mark(lZRemLoopBegin); |
        bool remOK = wholeRem
                ? zLoopBody(state.S_layout, S_layoutReflected, state.D_layout,
                        state.S_addrs, state.S_addrSrcs, state.D_addrs,
                        unrollZ, newSLoad, newDLoad, false, true, false,
                        !triRemOnly)
                : fragmented
                        ? zLoopBody(S_layout1, S_layout1Reflect, D_layout1,
                                S_addrs1, S_addrSrcs1, D_addrs1, crosspack,
                                newSLoad, newDLoad, false, true, crosspack > 1)
                        : zLoopBody(state.S_layout, S_layoutReflected,
                                state.D_layout, state.S_addrs,
                                state.S_addrSrcs, state.D_addrs, crosspack,
                                newSLoad, newDLoad, false, true,
                                crosspack > 1);
        if (!remOK) return false;
21847 | if (!wholeRem || triRemOnly) jmpi(1 | state.flagAP, lZRemLoopBegin); |
21848 | mark(lZRemLoopEnd); |
21849 | |
21850 | strategy.S = saveSStrategy; |
21851 | } |
21852 | |
21853 | // Finalize and store sums. |
21854 | if (problem.sum) { |
21855 | Label skipSumStore; |
21856 | bool simtCF = strategy.fused; |
21857 | |
21858 | if (remainderX) { |
21859 | cmp((simtCF ? 16 : 1) | le | state.flagAP, state.remainderX, 0); |
21860 | simtCF ? goto12(16 | state.flagAP, skipSumStore) |
21861 | : jmpi(1 | state.flagAP, skipSumStore); |
21862 | } |
21863 | |
21864 | horizontalAdd( |
21865 | !byColumn, Tsum, state.Ds_regs.back(), state.Ds_layout, state); |
21866 | |
21867 | // Accumulate sums from main and remainder loops. |
21868 | for (int l = 1; l < int(state.Ds_regs.size()); l++) { |
21869 | map(hw, Tsum, state.Ds_regs[0], state.Ds_regs[l], strategy, |
21870 | [&](int ne, GRF r1, GRF r2) { add(ne, r1, r1, r2); }); |
21871 | state.ra.safeRelease(state.Ds_regs[l]); |
21872 | } |
21873 | state.Ds_regs.resize(1); |
21874 | |
21875 | MatrixAddressing Ds = problem.D; |
21876 | Ds.crosspack = 1; |
21877 | |
21878 | MatrixAddressingStrategy Ds_strategy = strategy.D; |
21879 | Ds_strategy.accessType = AccessType::Block; |
21880 | |
21881 | int sr = 1, sc = 1; |
21882 | (byColumn ? sr : sc) = problem.D.packSize; |
21883 | |
21884 | vector<RegisterBlock> Ds_layoutOut; |
21885 | bool ok = getRegLayout(Tsum, Ds_layoutOut, sr, sc, false, false, true, |
21886 | true, 0, 0, Ds, Ds_strategy) |
21887 | && matchLayouts(Tsum, Ds_layoutOut, state.Ds_layout); |
21888 | if (!ok) return false; |
21889 | |
21890 | vector<GRFRange> Ds_addrs; |
21891 | allocAddrRegs(Ds_addrs, Ds_layoutOut, Ds, Ds_strategy, state); |
21892 | |
        Subregister Ds_base = state.ra.alloc_sub(state.effD.getType());
21895 | |
21896 | mulConstant(1, Ds_base.ud(), state.inputs.ldd, problem.D.packSize * Td); |
21897 | add(1, Ds_base.ud(), Ds_base.ud(), -problem.D.packSize * Tsum); |
21898 | eadd(1, Ds_base, Ds_base.ud(), state.effD, strategy, state); |
21899 | |
21900 | setupAddr(Tsum, Ds_addrs, Ds_base, Ds_layoutOut, Subregister(), Ds, |
21901 | Ds_strategy, strategy, state); |
21902 | storeMatrix(state.Ds_regs[0], Ds_layoutOut, Ds, Ds_strategy, Ds_addrs, |
21903 | strategy, state); |
21904 | |
21905 | state.ra.safeRelease(Ds_base); |
21906 | safeReleaseRanges(Ds_addrs, state); |
21907 | safeReleaseRanges(state.Ds_regs, state); |
21908 | state.Ds_layout.clear(); |
21909 | state.ra.safeRelease(state.all1s); |
21910 | |
21911 | if (remainderX) { |
21912 | mark(skipSumStore); |
21913 | if (simtCF) join(16); |
21914 | } |
21915 | } |
21916 | |
21917 | // Done. Free address, data, and flag registers. |
21918 | cleanup(); |
21919 | state.ra.safeRelease(state.signChange); |
21920 | if (lateZLoopCheck) state.raVFlag.lock(state.flagAP); |
21921 | state.raVFlag.safeRelease(state.flagReflect); |
21922 | state.raVFlag.safeRelease(state.flagSwizzle); |
21923 | |
21924 | return true; /* Success! */ |
21925 | } |
21926 | |
21927 | // Register-to-register copy of a single block, ignoring register offsets in the block. |
21928 | template <HW hw> |
21929 | bool gemm_kernel_generator_t<hw>::copyRegisterBlock(Type Ts, Type Td, |
21930 | const RegisterBlock &blockSrc, const RegisterBlock &blockDst, |
21931 | const GRFMultirange &src, const GRFMultirange &dst, int dOffR, |
21932 | int dOffC, const CommonStrategy &strategy, CommonState &state, |
21933 | bool preserveSrc) { |
21934 | std::vector<RegisterBlock> modSrc {1, blockSrc}, modDst {1, blockDst}; |
21935 | modSrc[0].offsetBytes %= GRF::bytes(hw); |
21936 | modDst[0].offsetBytes %= GRF::bytes(hw); |
21937 | return copyRegisters(Ts, Td, modSrc, modDst, src, dst, dOffR, dOffC, false, |
21938 | strategy, state, preserveSrc); |
21939 | } |
21940 | |
21941 | // Register-to-register copy, with no scaling. |
21942 | template <HW hw> |
21943 | bool gemm_kernel_generator_t<hw>::copyRegisters(Type Ts, Type Td, |
21944 | const vector<RegisterBlock> &layoutSrc, |
21945 | const vector<RegisterBlock> &layoutDst, const GRFMultirange &src, |
21946 | const GRFMultirange &dst, int dOffR, int dOffC, bool conjugate, |
21947 | const CommonStrategy &strategy, CommonState &state, bool preserveSrc) { |
21948 | return copyRegisters(Ts, Td, layoutSrc, layoutDst, src, dst, dOffR, dOffC, |
21949 | Scalar<double>(1.), Scalar<double>(0.), conjugate, strategy, state, |
21950 | preserveSrc); |
21951 | } |
21952 | |
21953 | // Register-to-register copy, with scaling. |
21954 | template <HW hw> |
21955 | bool gemm_kernel_generator_t<hw>::copyRegisters(Type Ts, Type Td, |
21956 | const vector<RegisterBlock> &layoutSrc, |
21957 | const vector<RegisterBlock> &layoutDst, const GRFMultirange &src, |
21958 | const GRFMultirange &dst, int dOffR, int dOffC, |
21959 | const Scalar<double> &alpha_real, const Scalar<double> &alpha_imag, |
21960 | bool conjugate, const CommonStrategy &strategy, CommonState &state, |
21961 | bool preserveSrc) { |
21962 | const int nphases = 2, qCXMin = -1, qCXMax = -1; |
21963 | |
21964 | bool preswizzle = (hw >= HW::XeHP); |
21965 | GRFRange copyTemp; |
21966 | |
21967 | auto allocTemp = [&]() { |
21968 | if (preswizzle && copyTemp.isInvalid()) |
21969 | copyTemp = state.ra.alloc_range(2); |
21970 | }; |
21971 | |
21972 | int srcM, srcN; |
21973 | getLayoutDims(layoutSrc, srcM, srcN); |
21974 | bool vectorCopy = (srcM == 1 || srcN == 1); |
21975 | int periodY = 1; |
21976 | |
21977 | if (GRF::bytes(hw) == 64 && Td.size() == 1 && layoutDst[0].crosspack > 1) |
21978 | periodY = 2; |
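    // The copy proceeds in up to three phases: phase -1 pre-converts the
    //  source in place where a separate conversion is needed, phase 0 does the
    //  main move/multiply, and phase 1 post-converts the destination. With
    //  64-byte GRFs and byte-sized crosspacked destinations, even and odd rows
    //  are handled in separate passes (periodY = 2).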
21979 | |
21980 | for (int phase = -1; phase < nphases; phase++) { |
21981 | for (int phaseY = 0; phaseY < periodY; phaseY++) { |
21982 | for (auto &sblock : layoutSrc) { |
21983 | auto RegisterBlock::*nx = sblock.colMajor ? &RegisterBlock::nr |
21984 | : &RegisterBlock::nc; |
21985 | auto RegisterBlock::*ny = sblock.colMajor ? &RegisterBlock::nc |
21986 | : &RegisterBlock::nr; |
21987 | auto RegisterBlock::*offsetY = sblock.colMajor |
21988 | ? &RegisterBlock::offsetC |
21989 | : &RegisterBlock::offsetR; |
21990 | |
21991 | for (int eoffY = 0; eoffY < sblock.*ny; eoffY++) { |
21992 | if (((eoffY + sblock.*offsetY) & (periodY - 1)) != phaseY) |
21993 | continue; |
21994 | for (int qCX = qCXMin; qCX <= qCXMax; qCX++) { |
21995 | for (int eoffX = 0; eoffX < sblock.*nx;) { |
21996 | auto eoffR = sblock.colMajor ? eoffX : eoffY; |
21997 | auto eoffC = sblock.colMajor ? eoffY : eoffX; |
21998 | |
21999 | int selems, delems; |
22000 | const RegisterBlock *sblockPtr, *dblockPtr; |
22001 | |
                        // Locate source and destination registers.
22003 | auto sreg = findBlockReg(Ts, layoutSrc, |
22004 | sblock.offsetR + eoffR, |
22005 | sblock.offsetC + eoffC, src, selems, |
22006 | sblockPtr, qCX); |
22007 | auto dreg = findBlockReg(Td, layoutDst, |
22008 | sblock.offsetR + eoffR + dOffR, |
22009 | sblock.offsetC + eoffC + dOffC, dst, delems, |
22010 | dblockPtr, qCX); |
22011 | |
22012 | auto scrosspack = sblock.crosspack; |
22013 | auto dcrosspack = dblockPtr->crosspack; |
22014 | |
22015 | if (sblock.colMajor != dblockPtr->colMajor) { |
22016 | bool sLargeCP |
22017 | = isLargeCrosspack(Ts, scrosspack); |
22018 | bool dLargeCP |
22019 | = isLargeCrosspack(Td, dcrosspack); |
22020 | bool sEffCM = sblock.colMajor ^ sLargeCP; |
22021 | bool dEffCM = dblockPtr->colMajor ^ dLargeCP; |
22022 | if (sEffCM == dEffCM) { |
22023 | if (sLargeCP) |
22024 | selems = std::min<int>( |
22025 | selems, scrosspack); |
22026 | if (dLargeCP) |
22027 | delems = std::min<int>( |
22028 | delems, dcrosspack); |
22029 | } else { |
22030 | if (!vectorCopy) |
22031 | stub(); // No in-register matrix transposes. |
22032 | selems = delems = 1; |
22033 | } |
22034 | } |
22035 | |
                        // Find out how many consecutive elements we can copy,
                        //  without crossing a destination GRF boundary when
                        //  the destination starts mid-GRF (for efficiency).
                        auto nGRFs = (strategy.dualGRF ? 2 : 1);
                        auto nGRFs_d
                                = (dreg.getOffset() >= dcrosspack) ? 1 : nGRFs;
22041 | auto selems_real = selems * Ts.complexComponents(); |
22042 | auto delems_real = delems * Td.complexComponents(); |
22043 | auto selems_limit = div_up( |
22044 | nGRFs * elementsPerGRF(hw, Ts.real()) |
22045 | - sreg.getOffset(), |
22046 | scrosspack); |
22047 | auto delems_limit = div_up( |
22048 | nGRFs_d * elementsPerGRF(hw, Td.real()) |
22049 | - dreg.getOffset(), |
22050 | dcrosspack); |
22051 | selems_real = std::min({selems_real, selems_limit}); |
22052 | delems_real = std::min({delems_real, delems_limit}); |
22053 | auto nelems_real |
22054 | = std::min(selems_real, delems_real); |
22055 | if (phase != 0) |
22056 | nelems_real = std::min( |
22057 | rounddown_pow2(nelems_real), 32); |
22058 | |
                        // Special case: mixed-mode packed downconversion is
                        //  limited to SIMD8.
                        if (Ts == Type::f32 && Td != Type::f32
                                && dcrosspack == 1)
                            nelems_real = std::min(
                                    nelems_real, elementsPerGRF(hw, Ts));
22064 | |
22065 | // Check if separate conversions are needed due to size changes. |
22066 | auto sconvertCP = (Ts.size() / Td.size()); |
22067 | bool sconvert = (Td.size() == 1 && Ts.size() > 1 |
22068 | && dcrosspack != sconvertCP) |
22069 | || (Td.size() == 2 && Td.isFP() |
22070 | && !Ts.isFP() |
22071 | && dcrosspack != sconvertCP |
22072 | && hw > HW::Gen9); |
22073 | if (sconvert && preserveSrc) stub(); |
22074 | auto sregConverted = sconvert |
22075 | ? sreg.reinterpret(0, Td.real().ngen())( |
22076 | sconvertCP) |
22077 | : sreg(scrosspack); |
22078 | |
22079 | auto dconvertCP = (Td.size() / Ts.size()); |
22080 | bool dconvert = (Ts.size() == 1 && Td.size() > 1 |
22081 | && scrosspack != dconvertCP); |
22082 | auto dregConverted = dconvert |
22083 | ? dreg.reinterpret(0, Ts.real().ngen())( |
22084 | dconvertCP) |
22085 | : dreg(dcrosspack); |
22086 | |
22087 | InstructionModifier modMov, mmodMov; |
22088 | if (Ts != Td && Td.isInteger() |
22089 | && Td.size() <= Ts.size()) { |
22090 | modMov = modMov | sat; |
22091 | if (!sconvert && !dconvert) |
22092 | mmodMov = mmodMov | sat; |
22093 | } |
22094 | |
                        // Finally, copy, with any necessary conjugation and scaling. Raw (bit-exact) copies may be rerouted through the integer pipe.
22096 | switch (phase) { |
22097 | case -1: |
22098 | if (hw == HW::Gen9 && Ts == Type::f32 |
22099 | && !Td.isFP()) { |
22100 | // Gen9: round to nearest before downconvert (not done by mov). |
22101 | rnde(nelems_real, sreg(scrosspack), |
22102 | sreg(scrosspack)); |
22103 | } |
22104 | if (sconvert) |
22105 | mov(nelems_real | modMov, sregConverted, |
22106 | sreg(scrosspack)); |
22107 | break; |
22108 | case 0: |
22109 | if (alpha_real == 1 || alpha_real == -1) { |
22110 | if (Ts.real() == Td.real()) { |
22111 | movePipes(sreg, scrosspack == 1); |
22112 | movePipes(dreg, scrosspack == 1); |
22113 | if (!sconvert) |
22114 | sregConverted |
22115 | = sreg(scrosspack); |
22116 | if (!dconvert) |
22117 | dregConverted |
22118 | = dreg(dcrosspack); |
22119 | if (hw >= HW::XeHP |
22120 | && scrosspack |
22121 | != dcrosspack) { |
22122 | moveToIntPipe(nelems_real, |
22123 | sregConverted); |
22124 | moveToIntPipe(nelems_real, |
22125 | dregConverted); |
22126 | sreg = sreg.reinterpret(0, |
22127 | sregConverted |
22128 | .getType()); |
22129 | } |
22130 | } |
22131 | int telems = nelems_real * Ts.real() |
22132 | / sreg.getBytes(); |
22133 | if (telems > 32) { |
22134 | nelems_real = (nelems_real * 32) |
22135 | / telems; |
22136 | telems = 32; |
22137 | } |
22138 | if (alpha_real == -1) { |
22139 | auto wd = elementsPerGRF( |
22140 | hw, sreg.getType()); |
22141 | auto base = state.signChange.sub( |
22142 | 0, dreg.getType()); |
22143 | xor_(telems, dreg(1), sreg(1), |
22144 | (wd >= telems) |
22145 | ? base(1) |
22146 | : base(0, wd, 1)); |
22147 | } else |
22148 | emov(telems | mmodMov, |
22149 | dregConverted, |
22150 | sregConverted, strategy, |
22151 | state); |
22152 | } else { |
22153 | auto realDst = dreg(dcrosspack); |
22154 | auto effDst = realDst; |
22155 | if (preswizzle |
22156 | && (Ts.isFP() || Td.isFP())) { |
22157 | allocTemp(); |
22158 | if ((sreg.getOffset() |
22159 | != dreg.getOffset()) |
22160 | || (scrosspack |
22161 | != dcrosspack)) |
22162 | effDst = copyTemp[0].sub( |
22163 | sreg.getOffset(), |
22164 | sreg.getType())( |
22165 | scrosspack); |
22166 | } |
22167 | |
22168 | if (alpha_real.fixed()) |
22169 | mul(nelems_real, effDst, |
22170 | sregConverted, |
22171 | cast(Ts.real(), |
22172 | alpha_real)); |
22173 | else |
22174 | mul(nelems_real, effDst, |
22175 | sregConverted, |
22176 | alpha_real.getRegAvoiding( |
22177 | hw, sreg)); |
22178 | |
22179 | if (effDst != realDst) { |
22180 | moveToIntPipe(nelems_real, realDst); |
22181 | moveToIntPipe(nelems_real, effDst); |
22182 | int nelems_real_int = nelems_real |
22183 | * Td |
22184 | / getBytes( |
22185 | effDst.getType()); |
22186 | emov(nelems_real_int, realDst, |
22187 | effDst, strategy, state); |
22188 | dconvert = false; |
22189 | } |
22190 | } |
22191 | break; |
22192 | case 1: |
22193 | if (dconvert) |
22194 | mov(nelems_real | modMov, |
22195 | dreg(dcrosspack), |
22196 | dregConverted); |
22197 | break; |
22198 | } /* switch phase */ |
22199 | |
22200 | int nelems = nelems_real; |
22201 | eoffX += nelems; |
22202 | } /* eoffX loop */ |
22203 | } /* qCX loop */ |
22204 | } /* eoffY loop */ |
22205 | } /* sblock loop */ |
22206 | } /* phaseY loop */ |
22207 | |
22208 | } /* phase loop */ |
22209 | |
22210 | state.ra.safeRelease(copyTemp); |
22211 | return true; // Success |
22212 | } |
22213 | |
22214 | // Get driver information from this strategy. |
22215 | template <HW hw> |
22216 | CommonDriverInfo gemm_kernel_generator_t<hw>::driverInfo( |
22217 | const CopyProblem &problem, const CopyStrategy &strategy) { |
22218 | CommonDriverInfo info; |
22219 | bool isA = (problem.D.layout == MatrixLayout::Pc); |
22220 | |
22221 | for (int d = 0; d < 3; d++) { |
22222 | info.blocking[d] = info.blockingAlt[d] = info.unroll[d] = 0; |
22223 | info.wg[d] = 1; |
22224 | info.loopOrder[d] = LoopNone; |
22225 | } |
22226 | |
22227 | info.subgroupSize = strategy.subgroupSize; |
22228 | info.grfCount = strategy.GRFs; |
22229 | info.unroll[0] = isA ? strategy.unrollX : strategy.unrollY; |
22230 | info.unroll[1] = isA ? strategy.unrollY : strategy.unrollX; |
22233 | info.loopOrder[0] = (isA ^ strategy.xLoop) ? LoopM : LoopN; |
22234 | if (strategy.zParallel) |
22235 | info.loopOrder[1] = (isA ^ strategy.xLoop) ? LoopN : LoopM; |
22236 | info.fusedLoop = strategy.fused ? info.loopOrder[0] : LoopNone; |
22237 | info.wg[0] = 16; |
22238 | info.wgExpand = 1; |
22239 | info.wgUpdate = WGDynamic; |
22240 | info.kRemainderHandling = true; |
22241 | info.kParallel = strategy.zParallel; |
22242 | info.kParallelLocal = false; |
22243 | info.slm = info.perKSLM = 0; |
22244 | info.alignment[0] = problem.S.alignment; |
22245 | info.alignment[1] = problem.D.alignment; |
22246 | info.alignment[2] = 0; |
22247 | info.support4GB[0] = (strategy.S.base.getModel() == ModelA64); |
22248 | info.support4GB[1] = (strategy.D.base.getModel() == ModelA64); |
22249 | info.support4GB[2] = false; |
22250 | |
22251 | return info; |
22252 | } |
22253 | |
22254 | // Validate a copy strategy, correcting settings as necessary. |
22255 | void CopyStrategy::preflight(HW hw, const CopyProblem &problem) { |
22256 | bool cm = isColMajor(problem.D.layout); |
22257 | |
22258 | S.preflight(hw); |
22259 | D.preflight(hw); |
22260 | |
22261 | s_load = std::max(s_load, 1); |
22262 | d_load = std::max(d_load, 1); |
22263 | if (s_load_masked == 0) s_load_masked = s_load; |
22264 | if (d_load_masked == 0) d_load_masked = d_load; |
22265 | unrollX = std::max(unrollX, 1); |
22266 | unrollY = std::max(unrollY, 1); |
22267 | unrollY = align_up(unrollY, problem.D.crosspack); |
22268 | |
22269 | // Ensure d_load is a multiple of s_load and crosspack, and unrollZ a multiple of both. |
22270 | // For x loop kernels, ensure s_load is a multiple of the packing size. |
22271 | // For y loop kernels, ensure all d_loads are multiples of y tile size if any. |
22272 | if (xLoop) { |
22273 | s_load = align_up(s_load, problem.D.packSize); |
22274 | s_load_masked = align_up(s_load_masked, problem.D.packSize); |
22275 | } else { |
22276 | auto D_tileY = cm ? problem.D.tileC : problem.D.tileR; |
22277 | if (D_tileY > 0) d_load_masked = align_up(d_load_masked, D_tileY); |
22278 | d_load_masked = align_up(d_load_masked, problem.D.crosspack); |
22279 | } |
22280 | d_load = align_up(d_load, s_load); |
22281 | d_load_masked = align_up(d_load_masked, s_load_masked); |
22282 | d_load = align_up(d_load, d_load_masked); |
22283 | |
22284 | if (xLoop) |
22285 | unrollX = align_up(unrollX, d_load); |
22286 | else |
22287 | unrollY = align_up(unrollY, d_load); |
22288 | |
22289 | if (unrollY == 1 && remHandlingY == RemainderHandling::Split) |
22290 | remHandlingY = RemainderHandling::General; |
22291 | |
22292 | spf &= !problem.trsm; // TRSM copies use SIMT control flow. |
22293 | |
22294 | CommonStrategy::preflight(hw, problem); |
22295 | } |
22296 | |
22297 | /**********************************************************************/ |
22298 | /* Common Kernel Functions */ |
22299 | /**********************************************************************/ |
22300 | |
22301 | // Generate the kernel prologue. |
22302 | template <HW hw> |
22303 | void gemm_kernel_generator_t<hw>::prologue(const CommonStrategy &strategy) { |
22304 | uint16_t cr0Enable; |
22305 | |
22306 | interface.generatePrologue(*this); |
22307 | |
22308 | cr0Enable = 0x1000; // IEEE float->int rounding. |
22309 | if (strategy.ieeeDenormals) cr0Enable |= 0x4C0; // Enable hf|f|df denormals. |
22310 | if (strategy.spf) cr0Enable |= 0x4; // Enable single program flow. |
22311 | |
22312 | or_(1, cr0, cr0, cr0Enable); |
22313 | |
22314 | InstructionModifier imod = 1; |
22315 | if (hw < HW::Gen12LP) imod |= Switch; |
22316 | |
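    // For kernels dispatched at less than SIMD16, enable all channels in the
    //  dispatch mask (sr0.2).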
22317 | if (interface.getSIMD() < 16) mov(imod, sr0[2], uint16_t(0xFFFF)); |
22318 | } |
22319 | |
22320 | // Generate the kernel epilogue. |
22321 | template <HW hw> |
22322 | void gemm_kernel_generator_t<hw>::epilogue( |
22323 | const CommonStrategy &strategy, const CommonState &state) { |
22324 | auto r0_info = state.r0_info; |
22325 | |
22326 | if (r0_info.getBase() < 112) { |
22327 | mov<uint32_t>(8, r127, r0_info); |
22328 | r0_info = r127; |
22329 | } |
22330 | |
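    // If requested, issue a final memory fence and consume its response so
    //  that all writes complete before the thread ends.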
22331 | if (strategy.finalFence) { |
22332 | memfence(r124, r0_info); |
22333 | mov<uint32_t>(8, null, r124); |
22334 | } |
22335 | |
22336 | threadend(r0_info); |
22337 | } |
22338 | |
22339 | // Pad the end of the kernel to accommodate instruction prefetching. |
22340 | template <HW hw> |
22341 | void gemm_kernel_generator_t<hw>::padding() { |
22342 | for (int q = 0; q < 8; q++) |
22343 | nop(); |
22344 | } |
22345 | |
22346 | // Common state initialization code. |
22347 | template <HW hw> |
22348 | void gemm_kernel_generator_t<hw>::initState(const CommonProblem &problem, |
22349 | const CommonStrategy &strategy, CommonState &state) { |
22350 | interface.requireLocalID(3); |
22351 | interface.requireLocalSize(); |
22352 | if (problem.nonuniformWGs) interface.requireNonuniformWGs(); |
22353 | |
22354 | if (strategy.wgInSS) interface.requireBarrier(); |
22355 | |
22356 | interface.requireSIMD(strategy.subgroupSize); |
22357 | |
22358 | if (!strategy.sipR0WA) interface.requireNoPreemption(); |
22359 | |
22360 | interface.requireGRF(strategy.GRFs); |
22361 | state.ra.setRegisterCount(strategy.GRFs); |
22362 | |
22363 | if (problem.gtpinSupport) interface.requireScratch(128); |
22364 | |
22365 | for (int i = 0; i < FlagRegister::subcount(hw); i++) |
22366 | state.activeVFlags[i].clear(); |
22367 | } |
22368 | |
22369 | CommonStrategy::CommonStrategy(HW hw, int stepping) : emulate(hw, stepping) { |
22370 | fused = one_of(hw, HW::Gen12LP, HW::XeHP, HW::XeHPG); |
22371 | } |
22372 | |
22373 | void CommonStrategy::preflight(HW hw, const CommonProblem &problem) { |
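    // Subgroup size must cover at least one GRF's worth of dwords.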
22374 | subgroupSize = std::max(subgroupSize, GRF::bytes(hw) >> 2); |
22375 | sipR0WA &= (hw == HW::Gen9); |
22376 | if (sipR0WA && (moveR0 == MoveR0::None)) moveR0 = MoveR0::GRF; |
22377 | readSuppressionWA &= fused; |
22378 | |
22379 | bool emulateNeedsAcc = emulate.emulate64 || emulate.emulateDWxDW |
22380 | || emulate.emulate64_mul; |
22381 | if (moveR0 == MoveR0::Acc && emulateNeedsAcc) moveR0 = MoveR0::None; |
22382 | |
22383 | spf &= !fused; |
22384 | } |
22385 | |
22386 | template <HW hw> |
22387 | constexpr typename gemm_kernel_generator_t<hw>::status_stream::Endl |
22388 | gemm_kernel_generator_t<hw>::status_stream::endl; |
22389 | |
22390 | REG_GEN9_ISA(template class gemm_kernel_generator_t<HW::Gen9>); |
22391 | REG_XELP_ISA(template class gemm_kernel_generator_t<HW::Gen12LP>); |
22392 | REG_XEHP_ISA(template class gemm_kernel_generator_t<HW::XeHP>); |
22393 | REG_XEHPG_ISA(template class gemm_kernel_generator_t<HW::XeHPG>); |
22394 | REG_XEHPC_ISA(template class gemm_kernel_generator_t<HW::XeHPC>); |
22395 | |
22396 | } // namespace jit |
22397 | } // namespace gpu |
22398 | } // namespace impl |
22399 | } // namespace dnnl |
22400 | |