1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <array>
18#include <cstddef>
19#include <functional>
20#include <numeric>
21#include <stdexcept>
22#include <vector>
23
24#include "common/impl_registration.hpp"
25#include "gpu/jit/gemm/gen_gemm_kernel_generator.hpp"
26#include "gpu/jit/gemm/loop_sequencer.hpp"
27#include "gpu/jit/gemm/utils.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace gpu {
32namespace jit {
33
34using namespace ngen;
35using namespace ngen::utils;
36using dnnl::impl::utils::one_of;
37using ngen::utils::log2;
38
39using std::complex;
40using std::vector;
41
42#define MOCK_BARRIERS
43
44class need_vflag : public std::runtime_error {
45public:
46 need_vflag() : std::runtime_error("Need virtual flag registers") {}
47};
48
49class stub_exception : public std::runtime_error {
50public:
51 stub_exception()
52 : std::runtime_error("Functionality not yet implemented") {}
53};
54
55class hw_unsupported_exception : public std::runtime_error {
56public:
57 hw_unsupported_exception()
58 : std::runtime_error("Unsupported in hardware") {}
59};
60
61[[noreturn]] static void hw_unsupported() {
62 throw hw_unsupported_exception();
63}
64
65[[noreturn]] static void stub() {
66 throw stub_exception();
67}
68
69// Helpers
70template <typename U>
71static inline Immediate cast(Type T, U val) {
72 switch (T) {
73 case Type::f16: return half(val);
74 case Type::f32: return float(val);
75 case Type::u8: return uint8_t(val);
76 case Type::s8: return int8_t(val);
77 case Type::u16: return uint16_t(val);
78 case Type::s16: return int16_t(val);
79 case Type::u32: return uint32_t(val);
80 case Type::s32: return int32_t(val);
81 case Type::u64: return uint64_t(val);
82 case Type::s64: return int64_t(val);
83 case Type::bf16:
84 case Type::tf32:
85 default: stub();
86 }
87}
88
89static inline Immediate cast(Type T, Scalar<double> val) {
90 return cast(T, double(val));
91}
92
93bool Type::isSubsetOf(Type T) const {
94 if (*this == T) return true;
95
96 if (isInteger() && T == bf16) return false;
97
98 return (size() < T.size());
99}
100
101constexpr bool operator==(const RegData &rd, int i) {
102 return false;
103}
104constexpr bool operator==(const RegData &rd, const Immediate &i) {
105 return false;
106}
107constexpr bool operator!=(const RegData &rd, int i) {
108 return true;
109}
110constexpr bool operator!=(const RegData &rd, const Immediate &i) {
111 return true;
112}
113
114void noop() {}
115
116static inline constexpr bool isGen9IGEMM(HW hw, Type Ta, Type Tb, Type Tc) {
117 return (hw < HW::Gen12LP && Ta.size() == 1 && Tb.size() == 1
118 && Tc.size() == 4);
119}
120
121template <typename T>
122static inline constexpr int elementsPerGRF(HW hw) {
123 return GRF::bytes(hw) / sizeof(T);
124}
125
126static inline constexpr int elementsPerGRF(HW hw, Type T) {
127 return GRF::bytes(hw) / T;
128}
129
130static inline constexpr int elementsPerGRF(HW hw, DataType dt) {
131 return GRF::bytes(hw) / getBytes(dt);
132}
133
134static inline bool canSwizzle(HW hw, DataType dt) {
135 if (hw < HW::XeHP) return true;
136
137 switch (dt) {
138 case DataType::b:
139 case DataType::ub:
140 case DataType::w:
141 case DataType::uw:
142 case DataType::d:
143 case DataType::ud: return true;
144 case DataType::q:
145 case DataType::uq: return (hw >= HW::XeHPC);
146 default: return false;
147 }
148}
149
150static inline bool canSwizzle(HW hw, Type T) {
151 return canSwizzle(hw, T.ngen());
152}
153
154static inline bool hasNativeAtomicAdd(HW hw, Type T,
155 const MatrixAddressing &atype,
156 const MatrixAddressingStrategy &astrategy) {
157 bool floatAtomics = (astrategy.base.getModel() == ModelA64);
158 if (astrategy.newDP)
159 floatAtomics |= (astrategy.base.getModel() != ModelSLM);
160
161 if (T.isInteger())
162 return true;
163 else if (T == Type::f32)
164 return floatAtomics && (hw >= HW::XeHP);
165 else
166 return false;
167}
168
169static inline int slmCapacity(HW hw) {
170 switch (hw) {
171 case HW::Gen9:
172 case HW::Gen11: return 65536;
173 case HW::Gen12LP:
174 case HW::XeHP:
175 case HW::XeHPG:
176 case HW::XeHPC: return 131072;
177 default: return 0;
178 }
179}
180
181static inline int threadsPerEU(HW hw, const CommonStrategy &strategy) {
182 if (hw >= HW::XeHP)
183 return (strategy.GRFs > 128) ? 4 : 8;
184 else
185 return 7;
186}
187
188static inline int eusPerSubslice(HW hw) {
189 switch (hw) {
190 case HW::Gen9:
191 case HW::Gen11:
192 case HW::XeHPC: return 8;
193 case HW::Gen12LP:
194 case HW::XeHP:
195 case HW::XeHPG: return 16;
196 default: return 0;
197 }
198}
199
200static inline bool canDualGRF(
201 HW hw, DataType dt, const CommonStrategy &strategy) {
202 return (strategy.dualGRF && (elementsPerGRF(hw, dt) < 32));
203}
204
205// Perform a binary register-wise operation.
206template <typename F>
207static inline void map(HW hw, DataType dt, const GRFMultirange &r1,
208 const GRFMultirange &r2, const CommonStrategy &strategy, F f) {
209 int ne = elementsPerGRF(hw, dt);
210 int rstride = canDualGRF(hw, dt, strategy) ? 2 : 1;
211 int len = r1.getLen();
212
213 for (int rr = 0; rr < len;) {
214 int nr = std::min<int>(len - rr, rstride);
215 if (!r1.contiguous(rr, nr) || !r2.contiguous(rr, nr)) nr = 1;
216 f(nr * ne, r1[rr].retype(dt), r2[rr].retype(dt));
217 rr += nr;
218 }
219}
220
221// Perform a ternary register-wise operation.
222template <typename F>
223static inline void map(HW hw, DataType dt, const GRFMultirange &r1,
224 const GRFMultirange &r2, const GRFMultirange &r3,
225 const CommonStrategy &strategy, F f) {
226 int ne = elementsPerGRF(hw, dt);
227 int rstride = canDualGRF(hw, dt, strategy) ? 2 : 1;
228 int len = r1.getLen();
229
230 for (int rr = 0; rr < len;) {
231 int nr = std::min<int>(len - rr, rstride);
232 if (!r1.contiguous(rr, nr) || !r2.contiguous(rr, nr)
233 || !r3.contiguous(rr, nr))
234 nr = 1;
235 f(nr * ne, r1[rr].retype(dt), r2[rr].retype(dt), r3[rr].retype(dt));
236 rr += nr;
237 }
238}
239
240// Perform a quaternary register-wise operation.
241template <typename F>
242static inline void map(HW hw, DataType dt, const GRFMultirange &r1,
243 const GRFMultirange &r2, const GRFMultirange &r3,
244 const GRFMultirange &r4, const CommonStrategy &strategy, F f) {
245 int ne = elementsPerGRF(hw, dt);
246 int rstride = canDualGRF(hw, dt, strategy) ? 2 : 1;
247 int len = r1.getLen();
248
249 for (int rr = 0; rr < len;) {
250 int nr = std::min<int>(len - rr, rstride);
251 if (!r1.contiguous(rr, nr) || !r2.contiguous(rr, nr)
252 || !r3.contiguous(rr, nr) || !r4.contiguous(rr, nr))
253 nr = 1;
254 f(nr * ne, r1[rr].retype(dt), r2[rr].retype(dt), r3[rr].retype(dt),
255 r4[rr].retype(dt));
256 rr += nr;
257 }
258}
259
260// Perform a unary register-wise operation on a register block.
261template <typename F>
262static inline void map(HW hw, DataType dt, const GRFMultirange &regs,
263 const vector<RegisterBlock> &layout, const CommonStrategy &strategy,
264 F f, int cxComponent = -1) {
265 int curReg = 0, curOff = 0, curBytes = 0;
266 auto ebytes = getBytes(dt);
267
268 auto map1 = [&]() {
269 curOff &= -ebytes;
270 curBytes &= -ebytes;
271 while (curBytes) {
272 int maxBytes;
273 if (curOff & (GRF::bytes(hw) - 1))
274 maxBytes = GRF::bytes(hw) - curOff;
275 else
276 maxBytes = (canDualGRF(hw, dt, strategy) ? 2 : 1)
277 * GRF::bytes(hw);
278
279 auto nbytes = rounddown_pow2(std::min(maxBytes, curBytes));
280 auto ne = std::min<int>(32, nbytes / ebytes);
281 nbytes = ne * ebytes;
282
283 auto reg = regs[curOff >> GRF::log2Bytes(hw)].sub(
284 (curOff & (GRF::bytes(hw) - 1)) / ebytes, dt)(1);
285
286 f(ne, reg);
287
288 curBytes -= nbytes;
289 curOff += nbytes;
290 }
291 };
292
293 for (auto &block : layout) {
294 int endReg
295 = (curOff + curBytes + block.bytes - 1) >> GRF::log2Bytes(hw);
296 if ((block.offsetBytes == curOff + curBytes)
297 && regs.contiguous(curReg, endReg - curReg + 1))
298 curBytes += block.bytes;
299 else {
300 map1();
301 curOff = block.offsetBytes;
302 curReg = curOff >> GRF::log2Bytes(hw);
303 curBytes = block.bytes;
304 }
305 }
306
307 map1();
308}
309
310template <typename T, typename F>
311static inline void map(HW hw, const GRFMultirange &r1, const GRFMultirange &r2,
312 const CommonStrategy &strategy, F f) {
313 map(hw, getDataType<T>(), r1, r2, strategy, f);
314}
315
316template <typename T, typename F>
317static inline void map(HW hw, const GRFMultirange &r1, const GRFMultirange &r2,
318 const GRFMultirange &r3, const CommonStrategy &strategy, F f) {
319 map(hw, getDataType<T>(), r1, r2, r3, strategy, f);
320}
321
322template <typename T, typename F>
323static inline void map(HW hw, const GRFMultirange &regs,
324 const vector<RegisterBlock> &layout, const CommonStrategy &strategy,
325 F f) {
326 map(hw, getDataType<T>(), regs, layout, strategy, f);
327}
328
329template <typename... Targs>
330static inline void map(HW hw, Type T, Targs... args) {
331 map(hw, T.ngen(), args...);
332}
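
// Usage sketch (illustrative): the map() helpers hand a lambda the execution
// size plus each operand retyped to the requested data type, picking the
// widest contiguous span (up to two GRFs when the strategy allows dual-GRF
// operations). For example, an element-wise add over two GRFMultiranges:
//
//     map(hw, Type::f32, r1, r2, strategy,
//             [&](int esize, GRF reg1, GRF reg2) { add(esize, reg1, reg1, reg2); });
//
// zeroMatrix() further below uses the same pattern to clear a register block.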
333
334// Move subregister to another pipe.
335static inline void movePipes(Subregister &s, bool sizeCanChange = true) {
336 DataType type = s.getType();
337
338 switch (type) {
339 case DataType::bf:
340 case DataType::hf: type = DataType::uw; break;
341 case DataType::tf32:
342 case DataType::f: type = DataType::ud; break;
343 case DataType::df:
344 if (sizeCanChange) type = DataType::ud;
345 break;
346 case DataType::w:
347 case DataType::uw: type = DataType::hf; break;
348 case DataType::d:
349 case DataType::ud: type = DataType::f; break;
350 case DataType::q:
351 case DataType::uq:
352 if (sizeCanChange) type = DataType::f;
353 break;
354 default: break;
355 }
356
357 s = s.reinterpret(0, type);
358}
359
360// Move subregister to integer pipe.
361static inline void moveToIntPipe(Subregister &s) {
362 DataType type = s.getType();
363
364 switch (type) {
365 case DataType::bf:
366 case DataType::hf: type = DataType::uw; break;
367 case DataType::q:
368 case DataType::uq:
369 case DataType::f:
370 case DataType::tf32:
371 case DataType::df: type = DataType::ud; break;
372 default: break;
373 }
374
375 s = s.reinterpret(0, type);
376}
377
378static inline Type sintType(Type T) {
379 switch (T) {
380 case Type::s8:
381 case Type::u8: return Type::s8;
382 case Type::bf16:
383 case Type::f16:
384 case Type::u16:
385 case Type::s16: return Type::s16;
386 case Type::f32:
387 case Type::s32:
388 case Type::u32: return Type::s32;
389 default: return Type::invalid;
390 }
391}
392
393// Move register region to integer pipe.
394static inline void moveToIntPipe(int esize, RegData &s) {
395 switch (s.getType()) {
396 case DataType::bf:
397 case DataType::hf: s.setType(DataType::uw); break;
398 case DataType::q:
399 case DataType::uq:
400 case DataType::tf32:
401 case DataType::f: s.setType(DataType::ud); break;
402 case DataType::df:
403 s.setType(DataType::uq);
404 EmulationImplementation::makeDWPair(s, esize);
405 break;
406 default: break;
407 }
408}
409
410void RegisterBlock::calcBytes(
411 Type T, const MatrixAddressingStrategy &astrategy) {
412 if (astrategy.newDP && astrategy.prefetch)
413 bytes = 0;
414 else
415 calcBytes(T);
416}
417
418void RegisterBlock::calcBytes(Type T) {
419 if (cxComponent != Interleaved) T = T.real();
420 bytes = align_up(colMajor ? nc : nr, crosspack) * ld * T;
421 if (isLoadBlock() && msgRegs == 0)
422 msgRegs = (bytes + (1 << log2GRFBytes) - 1) >> log2GRFBytes;
423}
424
425int RegisterBlock::nregs() const {
426 auto grfBytes = (1 << log2GRFBytes);
427 if (offsetBytes & (grfBytes - 1)) stub();
428 return (bytes + grfBytes - 1) >> log2GRFBytes;
429}
430
431int RegisterBlock::offsetReg() const {
432 auto grfBytes = (1 << log2GRFBytes);
433 if (offsetBytes & (grfBytes - 1)) stub();
434 return offsetBytes >> log2GRFBytes;
435}
436
437void RegisterBlock::simplify(Type T) {
438 // If block is completely crosspacked, convert to equivalent layout without crosspack.
439 if (crosspack == (colMajor ? nc : nr) && isLargeCrosspack(T, crosspack)) {
440 auto od = colMajor ? nr : nc;
441 if (ld == od) {
442 colMajor = !colMajor;
443 ld = crosspack;
444 crosspack = 1;
445 }
446 }
447}
448
449GRFMultirange GRFMultirange::subrange(
450 HW hw, Type T, const RegisterBlock &block) const {
451 int ne = elementsPerGRF(hw, T);
452 int ldGRFs = div_up(block.ld, ne);
453 int ldUsedGRFs = div_up(block.colMajor ? block.nr : block.nc, ne);
454 int td = block.colMajor ? block.nc : block.nr;
455
456 if (ldUsedGRFs >= ldGRFs)
457 return subrange(block.offsetReg(), block.nregs());
458 else {
459 int offReg = block.offsetReg();
460 GRFMultirange result = subrange(offReg, ldUsedGRFs);
461 for (int y = 1; y < td; y++) {
462 offReg += ldGRFs;
463 result.append(subrange(offReg, ldUsedGRFs));
464 }
465 return result;
466 }
467}
468
469// Make a RegisterBlock smaller by contracting the leading dimension, if possible.
470void RegisterBlock::compact(Type T) {
471 auto newLD = std::max<int>(
472 roundup_pow2(colMajor ? nr : nc), (1 << log2GRFBytes) / T);
473 if (newLD < ld) {
474 ld = newLD;
475 calcBytes(T);
476 }
477}
478
479static inline bool isTransposing(AccessType atype) {
480 if (atype == AccessType::Scattered) return true;
481 if (atype == AccessType::ChannelScattered) return true;
482 if (atype == AccessType::Block2DTranspose) return true;
483 return false;
484}
485
486Subregister SubregisterPair::getReg(int idx) const {
487 auto r = regs[idx & 1];
488 if (negative) r = -r;
489 return r;
490}
491
492Subregister SubregisterPair::getRegAvoiding(HW hw, const RegData &rd) const {
493 if (Bundle::same_bank(hw, rd, regs[0]))
494 return getReg(1);
495 else
496 return getReg(0);
497}
498
499inline namespace {
500template <typename T>
501struct ACHelper {
502 static T avoidConflict(HW hw, const T &x, const RegData &other) {
503 return x;
504 }
505};
506template <>
507struct ACHelper<SubregisterPair> {
508 static Subregister avoidConflict(
509 HW hw, const SubregisterPair &x, const RegData &other) {
510 return x.getRegAvoiding(hw, other);
511 }
512};
513template <typename T>
514struct ACHelper<Scalar<T>> {
515 static Subregister avoidConflict(
516 HW hw, const Scalar<T> &x, const RegData &other) {
517 return x.getRegAvoiding(hw, other);
518 }
519};
520} // namespace
521template <typename T>
522decltype(ACHelper<T>::avoidConflict(HW::Unknown, std::declval<T>(), RegData()))
523avoidConflict(HW hw, const T &x, const RegData &other) {
524 return ACHelper<T>::avoidConflict(hw, x, other);
525}
526
527FlagRegister VirtualFlag::toPhysical() const {
528 if (n == 2)
529 return FlagRegister(idx >> 1);
530 else
531 return FlagRegister::createFromIndex(idx);
532}
533
534VirtualFlag VirtualFlagAllocator::allocVirtual(int n) {
535 if (!free) throw out_of_registers_exception();
536 if (n > 2) stub();
537
538 uint32_t bmask = free;
539 if (n == 2) bmask = (bmask & (bmask >> 1)) & 0x55555555;
540 int base = bsf(bmask);
541
542 VirtualFlag vflag {base, n};
543 claim(vflag);
544
545 return vflag;
546}
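
// Example (illustrative): with free = 0b11011100, a 2-wide request folds the
// free mask as (free & (free >> 1)) & 0x55555555 = 0b01000100, keeping only
// even positions whose upper neighbor is also free; bsf() then yields
// base = 2, so the virtual flag spans flag subregister slots 2 and 3.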
547
548FlagRegister VirtualFlagAllocator::tryAlloc(int n) {
549 auto vflag = allocVirtual(n);
550 if (isVirtual(vflag)) {
551 release(vflag);
552 return FlagRegister {};
553 }
554
555 lock(vflag);
556
557 return vflag.toPhysical();
558}
559
560FlagRegister VirtualFlagAllocator::alloc(int n) {
561 auto flag = tryAlloc(n);
562 if (flag.isInvalid()) throw out_of_registers_exception();
563
564 return flag;
565}
566
567FlagRegister VirtualFlagAllocator::assignPhysical(VirtualFlag vflag) {
568 VirtualFlag pflag;
569
570 // Is it already a physical flag register?
571 if (!isVirtual(vflag)) {
572 pflag = vflag;
573 } else {
574 // It's virtual. Starting at nextPhys, find an unlocked flag register.
575 for (int i = nextPhys; i < nextPhys + nflag; i++) {
576 if (i & (vflag.n - 1)) continue;
577 auto idx = i & (nflag - 1);
578 if (!(locked & mask(idx, vflag.n))) {
579 nextPhys = (idx + vflag.n) & (nflag - 1);
580 pflag = VirtualFlag {idx, vflag.n};
581 break;
582 }
583 }
584 }
585
586 if (!pflag) throw out_of_registers_exception();
587
588 return pflag.toPhysical();
589}
590
591static inline RegData getMaskFlag(VirtualFlag vflag, CommonState &state) {
592 if (state.vflagStorage.isValid())
593 return state.vflagStorage[vflag.idx].reinterpret(
594 0, vflag.n == 2 ? DataType::ud : DataType::uw);
595 else if (!state.raVFlag.isVirtual(vflag)) {
596 auto pflag = vflag.toPhysical();
597 state.usePhysicalFlag(pflag);
598 return pflag;
599 } else
600 throw need_vflag();
601}
602
603template <HW hw>
604FlagRegister gemm_kernel_generator_t<hw>::getPhysicalFlag(
605 VirtualFlag vflag, CommonState &state) {
606 VirtualFlag pflag;
607
608 if (state.vflagStorage.isValid()) {
609 // Check if virtual flag is currently active.
610 int pidx = -1;
611 for (int i = 0; i < FlagRegister::subcount(hw); i++)
612 if (state.activeVFlags[i] == vflag) pidx = i;
613
614 // If flag is not currently active, load it into a physical flag.
615 if (pidx == -1) {
616 auto freg = state.raVFlag.assignPhysical(vflag);
617 pidx = freg.index();
618 mov(1, freg, getMaskFlag(vflag, state));
619 for (int i = 0; i < int(vflag.n); i++)
620 state.activeVFlags[pidx + i] = vflag;
621 }
622
623 pflag = VirtualFlag {pidx, vflag.n};
624 } else {
625 if (state.raVFlag.isVirtual(vflag)) throw need_vflag();
626
627 pflag = vflag;
628 }
629
630 return pflag.toPhysical();
631}
632
633template <HW hw>
634void gemm_kernel_generator_t<hw>::allocVFlagStorage(
635 const CommonStrategy &strategy, CommonState &state) {
636 state.vflagStorage
637 = state.ra.alloc(getHint(HintType::LongTerm, strategy)).uw();
638}
639
640TokenAllocator::TokenAllocator(HW hw) {
641 free = (1ull << tokenCount(hw)) - 1;
642}
643
644int8_t TokenAllocator::tryAlloc() {
645 if (free) {
646 int8_t token = bsf(free);
647 free &= ~(1 << token);
648 return token;
649 } else
650 return -1;
651}
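
// Example (illustrative): if free = 0b1010, tryAlloc() returns token 1 (the
// lowest set bit) and clears it, leaving free = 0b1000; a return value of -1
// means no token is currently available.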
652
653/************************/
654/* Pseudo-instructions. */
655/************************/
656
657// goto instruction with Gen12 semantics.
658template <HW hw>
659void gemm_kernel_generator_t<hw>::goto12(const InstructionModifier &mod,
660 Label &jip, Label &uip, bool branchCtrl) {
661 InstructionModifier mmod = mod;
662 if (hw == HW::Gen9 && !branchCtrl) {
663 if (mmod.getPredCtrl() == PredCtrl::None) stub();
664 mmod.setPredInv(!mmod.isPredInv());
665 }
666 goto_(mmod, jip, uip, branchCtrl);
667}
668
669// Compare to zero.
670template <HW hw>
671void gemm_kernel_generator_t<hw>::cmp0(
672 const InstructionModifier &mod, RegData src0) {
673 mov(mod, null.retype(src0.getType()), abs(src0));
674}
675
676// Scale then add: dst <- src0 + src1 * (numerator / denominator), rounding up.
677// If exact = true, ensure src1 * num / denom is integral if src1 immediate.
678template <HW hw>
679void gemm_kernel_generator_t<hw>::addScaled(const InstructionModifier &mod,
680 const RegData &dst, int src0, const RegData &src1, int numerator,
681 int denominator, CommonState &state, bool exact) {
682 if (!is_zero_or_pow2(numerator)) stub();
683 if (!is_zero_or_pow2(denominator)) stub();
684
685 if (numerator == denominator) {
686 (src0 != 0) ? add(mod, dst, src1, src0)
687 : (src1 != dst) ? mov(mod, dst, src1) : noop();
688 } else if (numerator > denominator) {
689 (src0 == 0) ? mulConstant(mod, dst, src1, numerator / denominator)
690 : mad(mod, dst, src0, src1, numerator / denominator);
691 } else if ((numerator * 2) == denominator)
692 avg(mod, dst, src1, src0 * 2);
693 else {
694 add(mod, dst, src1, ((src0 + 1) * denominator / numerator) - 1);
695 asr(mod, dst, dst, log2(denominator) - log2(numerator));
696 }
697}
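
// Worked example (illustrative) of the downscaling branch above: with
// src0 = 2, src1 = 7, numerator = 1, denominator = 4, the add produces
// 7 + ((2 + 1) * 4 / 1) - 1 = 18 and the asr shifts by log2(4) - log2(1) = 2,
// giving 4 = src0 + ceil(src1 / 4), i.e. the scaled add rounded up.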
698
699template <HW hw>
700void gemm_kernel_generator_t<hw>::addScaled(const InstructionModifier &mod,
701 const RegData &dst, const RegData &src0, const RegData &src1,
702 int numerator, int denominator, CommonState &state, bool exact) {
703 if (!is_zero_or_pow2(numerator)) stub();
704 if (!is_zero_or_pow2(denominator)) stub();
705
706 if (numerator == denominator)
707 add(mod, dst, src1, src0);
708 else if (numerator > denominator)
709 mad(mod, dst, src0, src1, numerator / denominator);
710 else {
711 auto temp = state.ra.alloc_sub(src1.getType());
712 if (exact)
713 asr(mod, temp, src1, log2(denominator) - log2(numerator));
714 else {
715 add(mod, temp, src1, (denominator / numerator) - 1);
716 asr(mod, temp, temp, log2(denominator) - log2(numerator));
717 }
718 add(mod, dst, temp, src0);
719 state.ra.safeRelease(temp);
720 }
721}
722
723template <HW hw>
724void gemm_kernel_generator_t<hw>::addScaled(const InstructionModifier &mod,
725 const RegData &dst, const RegData &src0, int src1, int numerator,
726 int denominator, CommonState &state, bool exact) {
727 if (!is_zero_or_pow2(numerator)) stub();
728 if (!is_zero_or_pow2(denominator)) stub();
729 if (exact && ((numerator * src1) % denominator))
730 throw std::runtime_error("Misaligned immediate value.");
731 add(mod, dst, src0, (numerator * src1) / denominator);
732}
733
734// Synchronize on all pipes and OOO operations.
735template <HW hw>
736void gemm_kernel_generator_t<hw>::syncall() {
737 if (hw == HW::Gen12LP)
738 sync.allwr(SWSB(1));
739 else if (hw >= HW::XeHP)
740 sync.allwr(SWSB<AllPipes>(1));
741}
742
743// Multiply by a constant, optimizing for power-of-2 constants.
744template <HW hw>
745template <typename DT>
746void gemm_kernel_generator_t<hw>::mulConstant(const InstructionModifier &mod,
747 const RegData &dst, const RegData &src0, int32_t src1) {
748 if (src1 == 0)
749 mov<DT>(mod, dst, uint16_t(0));
750 else if (src1 == 1) {
751 if (dst != src0) mov<DT>(mod, dst, src0);
752 } else if (src1 == -1)
753 mov<DT>(mod, dst, -src0);
754 else if (is_zero_or_pow2(src1))
755 shl<DT>(mod, dst, src0, uint16_t(log2(src1)));
756 else if (src1 >= 0x10000)
757 mul<DT>(mod, dst, src0, uint32_t(src1));
758 else if (src1 < -0x8000)
759 mul<DT>(mod, dst, src0, int32_t(src1));
760 else if (src1 > 0)
761 mul<DT>(mod, dst, src0, uint16_t(src1));
762 else
763 mul<DT>(mod, dst, src0, int16_t(src1));
764}
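
// Example (illustrative): mulConstant(1, dst, src0, 8) lowers to
// shl(1, dst, src0, 3), a general small constant such as 100 becomes a uw
// immediate multiply, and constants of 0x10000 or larger fall back to a ud
// immediate multiply.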
765
766// Three-argument add.
767template <HW hw>
768template <typename DT, typename S0, typename S2>
769void gemm_kernel_generator_t<hw>::eadd3(const InstructionModifier &mod,
770 const RegData &dst, const S0 &src0, const RegData &src1,
771 const S2 &src2) {
772 if ((hw >= HW::XeHP) && !(dst.getOffset() & 1))
773 add3<DT>(mod, dst, src0, src1, src2);
774 else {
775 add<DT>(mod, dst, src1, src0);
776 add<DT>(mod, dst, dst, src2);
777 }
778}
779
780template <HW hw>
781template <typename DT>
782void gemm_kernel_generator_t<hw>::emov(const ngen::InstructionModifier &mod,
783 ngen::RegData dst, ngen::RegData src0, const CommonStrategy &strategy,
784 CommonState &state) {
785 EmulationImplementation::applyDefaultType<DT>(dst);
786 EmulationImplementation::applyDefaultType<DT>(src0);
787
788 if (dst.getType() == DataType::tf32 && src0.getType() == DataType::tf32) {
789 dst.setType(DataType::f);
790 src0.setType(DataType::f);
791 }
792
793 if (hw >= HW::XeHP
794 && one_of(src0.getType(), DataType::hf, DataType::f, DataType::bf)
795 && src0.getType() == dst.getType()
796 && ((src0.getHS() != dst.getHS())
797 || (src0.getOffset() != dst.getOffset()))) {
798 moveToIntPipe(mod.getExecSize(), dst);
799 moveToIntPipe(mod.getExecSize(), src0);
800 }
801
802 if (hw < HW::XeHP && dst.getType() == DataType::f
803 && src0.getType() == DataType::bf) {
804 dst.setType(DataType::ud);
805 src0.setType(DataType::uw);
806 shl(mod, dst, src0, 16);
807 } else
808 EmulationImplementation::emov(*this, mod, dst, src0, strategy.emulate);
809}
810
811template <HW hw>
812template <typename DT>
813void gemm_kernel_generator_t<hw>::eadd(const InstructionModifier &mod,
814 const RegData &dst, const RegData &src0, const RegData &src1,
815 const CommonStrategy &strategy, CommonState &state) {
816 if (dst.getType() == DataType::f && src0.getType() == DataType::f
817 && src1.getType() == DataType::bf && src1.getHS() != 1) {
818 GRF alloced, temp = state.emulate.temp[0];
819 if (temp.isInvalid()) temp = alloced = state.ra.alloc();
820
821 auto src1UW = src1;
822 src1UW.setType(DataType::uw);
823 mov(mod, temp.uw(0)(1), src1UW);
824 add(mod, dst, src0, temp.bf(0)(1));
825
826 state.ra.safeRelease(alloced);
827 } else
828 EmulationImplementation::eadd<DT>(
829 *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
830}
831
832template <HW hw>
833template <typename S0, typename S2>
834void gemm_kernel_generator_t<hw>::emad(const InstructionModifier &mod,
835 const RegData &dst, const S0 &src0, const RegData &src1, const S2 &src2,
836 const CommonStrategy &strategy, CommonState &state) {
837 auto dstType = dst.getType();
838 if ((hw >= HW::Gen10 && !(dst.getByteOffset() & 7)
839 && !one_of(dstType, DataType::q, DataType::uq)
840 && !one_of(src2.getType(), DataType::d, DataType::ud))
841 || one_of(dstType, DataType::hf, DataType::f, DataType::df)) {
842 mad(mod, dst, src0, src1, src2);
843 } else {
844 auto ttype = (isSigned(src1.getType()) || isSigned(src2.getType()))
845 ? DataType::d
846 : DataType::ud;
847 auto temp = state.ra.alloc_sub(ttype);
848 emul(mod, temp, src1, src2, strategy, state);
849 eadd(mod, dst, temp, src0, strategy, state);
850 state.ra.safeRelease(temp);
851 }
852}
853
854template <HW hw>
855template <typename S0>
856void gemm_kernel_generator_t<hw>::emad(const InstructionModifier &mod,
857 const RegData &dst, const S0 &src0, const RegData &src1, int32_t src2,
858 const CommonStrategy &strategy, CommonState &state) {
859 auto dstType = dst.getType();
860 if (src2 == 0)
861 emov(mod, dst, src0, strategy, state);
862 else if (src2 == 1)
863 eadd(mod, dst, src1, src0, strategy, state);
864 else if (hw >= HW::Gen10 && !(dst.getByteOffset() & 7)
865 && (src2 >= -0x8000 && src2 < 0x10000)
866 && !one_of(dstType, DataType::q, DataType::uq)) {
867 mad(mod, dst, src0, src1, src2);
868 } else {
869 auto ttype = isSigned(src1.getType()) ? DataType::d : DataType::ud;
870 Subregister tempScalar;
871 GRFRange tempGRFs;
872 RegData temp;
873 if (mod.getExecSize() == 1)
874 temp = tempScalar = state.ra.alloc_sub(ttype);
875 else {
876 tempGRFs = state.ra.alloc_range(2);
877 temp = tempGRFs[0].retype(ttype);
878 }
879 emulConstant(mod, temp, src1, src2, strategy, state);
880 eadd(mod, dst, temp, src0, strategy, state);
881 state.ra.safeRelease(tempScalar);
882 state.ra.safeRelease(tempGRFs);
883 }
884}
885
886template <HW hw>
887template <typename DT>
888void gemm_kernel_generator_t<hw>::emath(const InstructionModifier &mod,
889 MathFunction fc, const RegData &dst, const RegData &src0,
890 const GEMMStrategy &strategy, CommonState &state) {
891 if (hw == HW::XeHP && strategy.systolic && mod.getExecSize() <= 8) {
892 // Workaround for DPAS + SIMD8 EM hang: use SIMD16 arithmetic.
893 auto mod16 = mod;
894 mod16.setExecSize(16);
895
896 auto temp = state.ra.alloc_range(2);
897 auto tt = temp[0].retype(src0.getType());
898
899 mov(mod.getExecSize(), tt, src0);
900 math(mod16, fc, tt, tt);
901 mov(mod.getExecSize(), dst, tt);
902
903 state.ra.safeRelease(temp);
904 } else
905 math(mod, fc, dst, src0);
906}
907
908template <HW hw>
909void gemm_kernel_generator_t<hw>::ejmpi(InstructionModifier mod, Label &dst) {
910 if (hw == HW::XeHPC && mod.getPredCtrl() == PredCtrl::anyv
911 && !mod.isPredInv()) {
912 mod.setPredCtrl(PredCtrl::Normal);
913 jmpi(mod, dst);
914 auto flag = mod.getFlagReg();
915 flag.setBase(flag.getBase() ^ 1);
916 mod.setFlagReg(flag);
917 jmpi(mod, dst);
918 } else
919 jmpi(mod, dst);
920}
921
922/********************/
923/* Utility routines */
924/********************/
925
926// Modulo by constant value.
927template <HW hw>
928template <typename DT>
929void gemm_kernel_generator_t<hw>::mod(const Subregister &dst,
930 const Subregister &src, uint16_t modulus,
931 const CommonStrategy &strategy, CommonState &state) {
932 if (is_zero_or_pow2(modulus))
933 and_<DT>(1, dst, src, modulus - 1);
934 else if (strategy.emulate.emulate64 && (hw <= HW::Gen12LP))
935 math<DT>(1, MathFunction::irem, dst, src, modulus);
936 else {
937 alignDown<DT>(dst, src, modulus, strategy, state);
938 add<DT>(1, dst, src, -dst);
939 }
940}
941
942// Return both (a % b) and a - (a % b).
943template <HW hw>
944template <typename DT>
945void gemm_kernel_generator_t<hw>::modExt(const Subregister &dstMod,
946 const Subregister &dstMultiple, const Subregister &src,
947 uint16_t modulus, const CommonStrategy &strategy, CommonState &state) {
948 if (is_zero_or_pow2(modulus)) {
949 and_<DT>(1, dstMultiple, src, ~uint32_t(modulus - 1));
950 and_<DT>(1, dstMod, src, modulus - 1);
951 } else if (strategy.emulate.emulate64 && (hw <= HW::Gen12LP)) {
952 math<DT>(1, MathFunction::irem, dstMod, src, modulus);
953 add<DT>(1, dstMultiple, src, -dstMod);
954 } else {
955 alignDown<DT>(dstMultiple, src, modulus, strategy, state);
956 add<DT>(1, dstMod, src, -dstMultiple);
957 }
958}
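
// Example (illustrative): for a power-of-two modulus such as 64, modExt
// reduces to two bitwise ANDs, dstMod = src & 63 and dstMultiple = src & ~63;
// only non-power-of-two moduli take the irem or alignDown paths.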
959
960// Divide an unsigned value by a constant, rounding down.
961template <HW hw>
962template <typename DT>
963void gemm_kernel_generator_t<hw>::divDown(const ngen::Subregister &dst,
964 const ngen::Subregister &src, uint16_t divisor,
965 const CommonStrategy &strategy, CommonState &state) {
966 if (is_zero_or_pow2(divisor))
967 shr<DT>(1, dst, src, log2(divisor));
968 else if (strategy.emulate.emulate64 && (hw <= HW::Gen12LP))
969 math<DT>(1, MathFunction::iqot, dst, src, uint32_t(divisor));
970 else {
971 // Replace integer division with multiplication by reciprocal + shift.
972 // Valid for numerators <= 2^31.
973 int shift = ngen::utils::bsr(divisor);
974 uint32_t recip32
975 = ((uint64_t(0x100000000) << shift) + divisor - 1) / divisor;
976 emul32High(1, dst, src, recip32);
977 shr(1, dst, dst, shift);
978 }
979}
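
// Worked example (illustrative) of the reciprocal path: for divisor = 6,
// shift = bsr(6) = 2 and recip32 = ceil(2^(32+2) / 6) = 0xAAAAAAAB, so
// dst = ((src * 0xAAAAAAAB) >> 32) >> 2; e.g. src = 35 yields 23 >> 2 = 5.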
980
981// Align an unsigned value down to a multiple of align.
982template <HW hw>
983template <typename DT>
984void gemm_kernel_generator_t<hw>::alignDown(const Subregister &dst,
985 const Subregister &src, uint16_t align, const CommonStrategy &strategy,
986 CommonState &state) {
987 if (is_zero_or_pow2(align))
988 and_<DT>(1, dst, src, uint32_t(-align));
989 else {
990 divDown(dst, src, align, strategy, state);
991 mul(1, dst, dst, align);
992 }
993}
994
995// Align an unsigned value up to a multiple of align.
996template <HW hw>
997template <typename DT>
998void gemm_kernel_generator_t<hw>::alignUp(const Subregister &dst,
999 const Subregister &src, uint16_t align, const CommonStrategy &strategy,
1000 CommonState &state) {
1001 add<DT>(1, dst, src, uint16_t(align - 1));
1002 alignDown<DT>(dst, dst, align, strategy, state);
1003}
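
// Example (illustrative): with align = 24 (not a power of two), alignDown
// of 70 takes the divDown path and produces 2 * 24 = 48, while alignUp of 70
// first adds 23 to reach 93 and then aligns down to 72.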
1004
1005// Non-constant integer division.
1006// Requires an auxiliary constant: ceiling(2^(32 + s) / denom), where s = floor(log2(denom)).
1007template <HW hw>
1008template <typename DT>
1009void gemm_kernel_generator_t<hw>::divDown(const Subregister &dst,
1010 const Subregister &src0, const Subregister &src1,
1011 const Subregister &src1Recip, const FlagRegister &flag,
1012 const CommonStrategy &strategy, CommonState &state) {
1013 auto shift = state.ra.alloc_sub<uint32_t>();
1014 auto pop = state.ra.alloc_sub<uint16_t>();
1015 cbit(1, pop, src1);
1016 fbh(1, shift, src1);
1017 cmp(1 | gt | flag, pop, 1);
1018 add(1, shift, -shift, 31);
1019 emul32High(1 | flag, dst, src0, src1Recip);
1020 shr(1 | ~flag, dst, src0, shift);
1021 shr(1 | flag, dst, dst, shift);
1022 state.ra.safeRelease(shift);
1023 state.ra.safeRelease(pop);
1024}
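
// Example (illustrative): for a runtime divisor src1 = 6, cbit() reports more
// than one set bit, so the flagged path multiplies by the precomputed
// src1Recip = ceil(2^(32+2) / 6) and then shifts; for src1 = 8 (a power of
// two) only the unflagged shr executes, with shift evaluating to
// floor(log2(8)) = 3.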
1025
1026// Simple do-while loop macro for the backward conditional branch at end of loop.
1027template <HW hw>
1028void gemm_kernel_generator_t<hw>::simtDoWhileLoop(
1029 const InstructionModifier &mod, Label &dest) {
1030 Label next;
1031
1032 goto12(mod, next, dest, true);
1033 mark(next);
1034 join(mod.getExecSize());
1035}
1036
1037// Barrier with SLM fence.
1038template <HW hw>
1039void gemm_kernel_generator_t<hw>::slmBarrier(
1040 const GRF &temp, const GRF &r0_info) {
1041 if (hw >= HW::Gen11) {
1042 slmfence(temp, r0_info);
1043 if (hw < HW::Gen12LP) mov<uint32_t>(8, null, temp);
1044 }
1045 barrier(temp, r0_info);
1046}
1047
1048// Barrier with global memory fence.
1049template <HW hw>
1050void gemm_kernel_generator_t<hw>::globalMemBarrier(
1051 const GRF &temp, const GRF &r0_info) {
1052 memfence(temp, r0_info);
1053 if (hw < HW::Gen12LP) mov<uint32_t>(8, null, temp);
1054 barrier(temp, r0_info);
1055}
1056
1057// Pause for a short period of time.
1058template <HW hw>
1059void gemm_kernel_generator_t<hw>::pause(const CommonStrategy &strategy) {
1060 if (hw >= HW::Gen11)
1061 mov(1 | Switch, tm0[4], strategy.pauseCycles);
1062 else
1063 for (int i = 0; i < 8; i++)
1064 mov<uint32_t>(8 | Switch, null, acc0);
1065}
1066
1067// Create a copy of a SubregisterPair in the other bank.
1068template <HW hw>
1069void gemm_kernel_generator_t<hw>::duplicateScalar(
1070 SubregisterPair &val, CommonState &state) {
1071 auto reg0 = val.getReg(0);
1072
1073 if (reg0 != val.getReg(1)) return;
1074
1075 auto bundle = Bundle::locate(hw, reg0);
1076 auto reg1 = state.ra.alloc_sub(
1077 reg0.getType(), Bundle(bundle.bank_id ^ 1, Bundle::any));
1078
1079 mov(1, reg1, reg0);
1080 val = SubregisterPair(reg0, reg1);
1081}
1082
1083template <HW hw>
1084void gemm_kernel_generator_t<hw>::deduplicateScalar(
1085 SubregisterPair &val, CommonState &state) {
1086 auto reg0 = val.getReg(0), reg1 = val.getReg(1);
1087 if (reg0 != reg1) {
1088 state.ra.release(reg1);
1089 val = SubregisterPair(reg0);
1090 }
1091}
1092
1093// Create a copy of a scalar subregister in the other bank.
1094template <HW hw>
1095template <typename T>
1096void gemm_kernel_generator_t<hw>::duplicateScalar(
1097 Scalar<T> &val, CommonState &state) {
1098 if (!val.fixed()) duplicateScalar(val.getPair(), state);
1099}
1100
1101// Create multiple versions of the input subregister reg, shifted by amounts specified by the shifts bitmask.
1102// The input subregister is used for one of the versions.
1103template <HW hw>
1104MultishiftSubregister gemm_kernel_generator_t<hw>::multishift(
1105 const Subregister &reg, unsigned int shifts,
1106 const CommonStrategy &strategy, CommonState &state, Bundle hint) {
1107 MultishiftSubregister ms;
1108
1109 while (shifts != 0) {
1110 int shift = bsr(shifts);
1111 shifts &= ~(1 << shift);
1112
1113 if (shifts != 0) {
1114 Subregister s = state.ra.alloc_sub(reg.getType(), hint);
1115 ms.set(shift, s);
1116 eshr(1, s, reg, shift, strategy, state);
1117 } else {
1118 ms.set(shift, reg);
1119 if (shift > 0) eshr(1, reg, reg, shift, strategy, state);
1120 }
1121 }
1122
1123 return ms;
1124}
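
// Example (illustrative): with shifts = 0b1010 (shift amounts 1 and 3), the
// loop first allocates a fresh subregister holding reg >> 3, then reuses reg
// itself for the final, smallest shift (reg >>= 1); the caller's input
// register is therefore consumed by the lowest requested shift.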
1125
1126// Get ID of fused thread (0/1), multiplied by a scaling factor. Assumes r1 has not been overwritten,
1127// or state.lid0 is set to a subregister containing local ID 0 (divided by the subgroup size).
1128template <HW hw>
1129void gemm_kernel_generator_t<hw>::getFusedID(int scale,
1130 const CommonProblem &problem, const CommonStrategy &strategy,
1131 CommonState &state) {
1132 if (strategy.fused) {
1133 state.fusedID = state.ra.alloc_sub<uint16_t>(
1134 getHint(HintType::LongTerm, strategy));
1135 if (state.lid0.isValid()) {
1136 if (is_zero_or_pow2(scale) && scale > 1
1137 && (state.fusedID.getOffset() & 3) == 0)
1138 bfi2(1, state.fusedID, scale, state.lid0, 0);
1139 else {
1140 and_(1, state.fusedID, state.lid0, 1);
1141 mulConstant(1, state.fusedID, state.fusedID, scale);
1142 }
1143 } else if (is_zero_or_pow2(scale)) {
1144 int shift = log2(scale) - log2(strategy.subgroupSize);
1145 Subregister lid0 = r1.uw(0);
1146
1147 if (shift > 0)
1148 shl(1, state.fusedID, lid0, uint16_t(shift));
1149 else if (shift < 0)
1150 shr(1, state.fusedID, lid0, uint16_t(-shift));
1151
1152 and_(1, state.fusedID, (shift == 0) ? lid0 : state.fusedID,
1153 uint16_t(scale));
1154 } else {
1155 shr(1, state.fusedID, r1.uw(0),
1156 uint16_t(log2(strategy.subgroupSize)));
1157 and_(1, state.fusedID, state.fusedID, uint16_t(1));
1158 mulConstant(1, state.fusedID, state.fusedID, uint16_t(scale));
1159 }
1160 }
1161}
1162
1163// Move r0 information to another register if configured.
1164template <HW hw>
1165void gemm_kernel_generator_t<hw>::moveR0(
1166 const CommonStrategy &strategy, CommonState &state) {
1167 if (state.movedR0) return;
1168 if (state.r0_info.isInvalid()) {
1169 switch (strategy.moveR0) {
1170 case MoveR0::None:
1171 state.r0_info = r0.ud();
1172 state.movedR0 = true;
1173 return;
1174 case MoveR0::Acc: state.r0_info = acc0.ud(); break;
1175 case MoveR0::Addr: state.r0_info = a0.ud(); break;
1176 case MoveR0::GRF:
1177 state.r0_info
1178 = state.ra.alloc(getHint(HintType::R0Info, strategy));
1179 break;
1180 }
1181 }
1182
1183 mov<uint32_t>(8, state.r0_info, r0);
1184
1185 if (!strategy.sipR0WA) state.ra.release(r0);
1186
1187 state.movedR0 = true;
1188}
1189
1190template <HW hw>
1191void gemm_kernel_generator_t<hw>::moveR0(
1192 const GEMMStrategy &strategy, GEMMState &state) {
1193 if (state.movedR0) return;
1194 if (strategy.moveR0 == MoveR0::GRF) {
1195 if (strategy.registerScheme == GEMMStrategy::ACB
1196 || strategy.registerScheme == GEMMStrategy::BCA) {
1197 state.r0_info = r127;
1198 state.ra.claim(r127);
1199 }
1200 }
1201 moveR0(static_cast<CommonStrategy>(strategy), state);
1202}
1203
1204// Call a functor needing the r0 header in a GRF.
1205template <HW hw>
1206template <typename F>
1207void gemm_kernel_generator_t<hw>::useR0(CommonState &state, F f) {
1208 if (state.r0_info.isARF()) {
1209 auto r0_info = state.ra.alloc();
1210 mov<uint32_t>(8, r0_info, state.r0_info);
1211 f(r0_info);
1212 state.ra.safeRelease(r0_info);
1213 } else
1214 f(GRF {state.r0_info.getBase()});
1215}
1216
// Divide the subgroup size out of the x local size and x local ID.
1218template <HW hw>
1219void gemm_kernel_generator_t<hw>::removeSG(const CommonProblem &problem,
1220 const CommonStrategy &strategy, const CommonState &state) {
1221 uint16_t sss = log2(strategy.subgroupSize);
1222
1223 auto localSize0 = interface.getLocalSize(0);
1224 auto localID0 = interface.getLocalID(0);
1225
1226 shr(1, localSize0, localSize0, sss);
1227 shr(1, localID0.uw(0), localID0.uw(0), sss);
1228}
1229
1230// Swap bit 0 of local ID x and y if needed so that threads are ordered according to specified EU fusion.
1231template <HW hw>
1232void gemm_kernel_generator_t<hw>::reorderFusedEUs(const GEMMProblem &problem,
1233 const GEMMStrategy &strategy, GEMMState &state) {
1234 if (!strategy.fused) return;
1235
1236 if (strategy.loopOrder[0] != strategy.fusedLoop) {
1237 auto temp = state.ra.alloc_sub<uint32_t>();
1238 and_(1, temp, state.inputs.localIDN.ud(), uint16_t(1));
1239 bfi2(1, state.inputs.localIDN.ud(), uint16_t(1),
1240 state.inputs.localIDM.ud(), state.inputs.localIDN.ud());
1241 bfi2(1, state.inputs.localIDM.ud(), uint16_t(1), temp,
1242 state.inputs.localIDM.ud());
1243 state.ra.safeRelease(temp);
1244 }
1245}
1246
1247template <HW hw>
1248Subregister gemm_kernel_generator_t<hw>::copySubregister(
1249 const Subregister &reg, CommonState &state, Bundle hint) {
1250 auto copy = state.ra.alloc_sub(reg.getType(), hint);
1251 mov(1, copy, reg);
1252 return copy;
1253}
1254
1255// Set a matrix to zero.
1256template <HW hw>
1257void gemm_kernel_generator_t<hw>::zeroMatrix(
1258 const GRFMultirange &r, const CommonStrategy &strategy) {
1259 map<uint32_t>(hw, r, r, strategy,
1260 [&](int esize, GRF reg, GRF _) { mov(esize, reg, uint16_t(0)); });
1261}
1262
1263// Release fused remainder-related state variables.
1264template <HW hw>
1265void gemm_kernel_generator_t<hw>::releaseFusedRemainders(GEMMState &state) {
1266 state.ra.safeRelease(state.remFusedStorage);
1267 state.remaindersFused[LoopM] = Subregister {};
1268 state.remaindersFused[LoopN] = Subregister {};
1269}
1270
1271template <HW hw>
1272void gemm_kernel_generator_t<hw>::saveMNLocalIDs(
1273 const GEMMStrategy &strategy, GEMMState &state) {
1274 state.lidStorage = state.ra.alloc_sub<uint32_t>(
1275 getHint(HintType::LongTerm, strategy));
1276 state.lidM = state.lidStorage.uw(0);
1277 state.lidN = state.lidStorage.uw(1);
1278 mov(1, state.lidM, state.inputs.localIDM);
1279 mov(1, state.lidN, state.inputs.localIDN);
1280}
1281
1282template <HW hw>
1283void gemm_kernel_generator_t<hw>::saveKLocalIDSize(
1284 const GEMMStrategy &strategy, GEMMState &state) {
1285 state.lidszKStorage = state.ra.alloc_sub<uint32_t>(
1286 getHint(HintType::LongTerm, strategy));
1287 state.lidK = state.lidszKStorage.uw(0);
1288 state.lszK = state.lidszKStorage.uw(1);
1289 mov(1, state.lidK, state.inputs.localIDK);
1290 mov(1, state.lszK, state.inputs.localSizeK);
1291}
1292
1293template <HW hw>
1294void gemm_kernel_generator_t<hw>::releaseSavedMNLocalIDs(GEMMState &state) {
1295 state.ra.safeRelease(state.lidStorage);
1296 state.lidStorage = invalid;
1297 state.lidM = invalid;
1298 state.lidN = invalid;
1299}
1300
// Clear read suppression data on ALU pipes.
1302template <HW hw>
1303void gemm_kernel_generator_t<hw>::doReadSuppressionWA(
1304 const CommonStrategy &strategy, CommonState &state) {
1305 GRF temp;
1306 bool freeTemp = false;
1307
1308 if (!strategy.readSuppressionWA) return;
1309
1310 if (state.r0_info.isValid() && !state.r0_info.isARF())
1311 temp = GRF(state.r0_info.getBase());
1312 else {
1313 temp = state.ra.try_alloc();
1314 if (temp.isValid())
1315 freeTemp = true;
1316 else
1317 temp = r0;
1318 }
1319
1320 csel<int16_t>(8, temp, temp, temp, temp);
1321 csel<float>(8, temp, temp, temp, temp);
1322
1323 if (freeTemp) state.ra.safeRelease(temp);
1324}
1325
1326// Get minimum row/column granularity for a matrix in memory.
1327static void getGranularities(
1328 const MatrixAddressing &atype, int &rgran, int &cgran) {
1329 auto &xgran = isColMajor(atype.layout) ? cgran : rgran;
1330 auto &ygran = isColMajor(atype.layout) ? rgran : cgran;
1331 rgran = std::max<int>(atype.tileR, 1);
1332 cgran = std::max<int>(atype.tileC, 1);
1333 xgran = std::max<int>(xgran, atype.crosspack);
1334 if (isPacked(atype.layout)) ygran = std::max<int>(ygran, atype.packSize);
1335}
1336
1337// Common register allocator hints.
1338template <HW hw>
1339Bundle gemm_kernel_generator_t<hw>::getHint(HintType type) {
1340 switch (type) {
1341 case HintType::Bank0: return Bundle(0, Bundle::any);
1342 case HintType::Bank1: return Bundle(1, Bundle::any);
1343 default: break;
1344 }
1345
1346 switch (hw) {
1347 case HW::Gen9:
1348 case HW::Gen10:
1349 case HW::Gen11:
1350 switch (type) {
1351 case HintType::TempComp0: return Bundle(0, 1);
1352 case HintType::TempComp1: return Bundle(1, 1);
1353 case HintType::LongTerm: return Bundle(Bundle::any, 0);
1354 case HintType::LongTerm0: return Bundle(0, 0);
1355 case HintType::LongTerm1: return Bundle(1, 0);
1356 default: break;
1357 }
1358 break;
1359 case HW::Gen12LP:
1360 case HW::XeHP:
1361 case HW::XeHPG:
1362 case HW::XeHPC:
1363 switch (type) {
1364 case HintType::LongTerm0: return Bundle(0, Bundle::any);
1365 case HintType::LongTerm1: return Bundle(1, Bundle::any);
1366 default: break;
1367 }
1368 default: break;
1369 }
1370
1371 return Bundle();
1372}
1373
1374template <HW hw>
1375Bundle gemm_kernel_generator_t<hw>::getHint(
1376 HintType type, const CommonStrategy &strategy) {
1377 return getHint(type);
1378}
1379
1380// GEMM register allocation hints.
1381template <HW hw>
1382Bundle gemm_kernel_generator_t<hw>::getHint(
1383 HintType type, const GEMMStrategy &strategy) {
1384 switch (hw) {
1385 case HW::Gen9:
1386 case HW::Gen10:
1387 case HW::Gen11:
1388 switch (strategy.registerScheme) {
1389 case GEMMStrategy::CSeparate:
1390 switch (type) {
1391 case HintType::A0Broadcast:
1392 case HintType::A0: return Bundle(1, 0);
1393 case HintType::A1Broadcast:
1394 case HintType::A1: return Bundle(0, 0);
1395 case HintType::B0Broadcast:
1396 case HintType::B0: return Bundle(0, 0);
1397 case HintType::B1Broadcast:
1398 case HintType::B1: return Bundle(1, 0);
1399 case HintType::C: return Bundle(0, 1);
1400 case HintType::CLoad: return Bundle(1, 0);
1401 default: break;
1402 }
1403 break;
1404 case GEMMStrategy::ACB:
1405 switch (type) {
1406 case HintType::A0Broadcast:
1407 case HintType::A0: return Bundle(1, 0);
1408 case HintType::A1Broadcast:
1409 case HintType::A1: return Bundle(0, 0);
1410 case HintType::B0Broadcast:
1411 case HintType::B0: return Bundle(0, 1);
1412 case HintType::B1Broadcast:
1413 case HintType::B1: return Bundle(1, 1);
1414 case HintType::C: return Bundle(0, 0);
1415 case HintType::CLoad: return Bundle();
1416 default: break;
1417 }
1418 break;
1419 case GEMMStrategy::BCA:
1420 switch (type) {
1421 case HintType::A0Broadcast:
1422 case HintType::A0: return Bundle(0, 1);
1423 case HintType::A1Broadcast:
1424 case HintType::A1: return Bundle(1, 1);
1425 case HintType::B0Broadcast:
1426 case HintType::B0: return Bundle(1, 0);
1427 case HintType::B1Broadcast:
1428 case HintType::B1: return Bundle(0, 0);
1429 case HintType::C: return Bundle(0, 0);
1430 case HintType::CLoad: return Bundle();
1431 default: break;
1432 }
1433 break;
1434 default: break;
1435 }
1436 break;
1437 case HW::Gen12LP:
1438 case HW::XeHP:
1439 case HW::XeHPG:
1440 case HW::XeHPC:
1441 switch (strategy.registerScheme) {
1442 case GEMMStrategy::CSeparate:
1443 switch (type) {
1444 case HintType::A0Broadcast:
1445 case HintType::A0: return Bundle(1, Bundle::any);
1446 case HintType::A1Broadcast:
1447 case HintType::A1: return Bundle(0, Bundle::any);
1448 case HintType::B0Broadcast:
1449 case HintType::B0: return Bundle(0, Bundle::any);
1450 case HintType::B1Broadcast:
1451 case HintType::B1: return Bundle(1, Bundle::any);
1452 case HintType::C: return Bundle(0, 0);
1453 case HintType::CLoad: return Bundle(1, Bundle::any);
1454 default: break;
1455 }
1456 break;
1457 case GEMMStrategy::ACB:
1458 case GEMMStrategy::BCA:
1459 if (strategy.systolic) switch (type) {
1460 case HintType::A0:
1461 case HintType::B0: return Bundle(0, Bundle::any);
1462 case HintType::A1:
1463 case HintType::B1: return Bundle(1, Bundle::any);
1464 case HintType::A0Broadcast:
1465 case HintType::B0Broadcast:
1466 return Bundle(1, Bundle::any);
1467 case HintType::A1Broadcast:
1468 case HintType::B1Broadcast:
1469 return Bundle(0, Bundle::any);
1470 case HintType::C: return Bundle(0, Bundle::any);
1471 default: break;
1472 }
1473 /* else fall through */
1474 case GEMMStrategy::VNC:
1475 switch (type) {
1476 case HintType::A0:
1477 case HintType::B0: return Bundle(1, Bundle::any);
1478 case HintType::A1:
1479 case HintType::B1: return Bundle(0, Bundle::any);
1480 case HintType::A0Broadcast:
1481 case HintType::B0Broadcast:
1482 return Bundle(0, Bundle::any);
1483 case HintType::A1Broadcast:
1484 case HintType::B1Broadcast:
1485 return Bundle(1, Bundle::any);
1486 case HintType::C: return Bundle(0, Bundle::any);
1487 default: break;
1488 }
1489 break;
1490 case GEMMStrategy::ABInterleave:
1491 switch (type) {
1492 case HintType::A0:
1493 case HintType::A1:
1494 case HintType::A0Broadcast:
1495 case HintType::A1Broadcast: return Bundle(1, 0);
1496 case HintType::B0:
1497 case HintType::B1:
1498 case HintType::B0Broadcast:
1499 case HintType::B1Broadcast: return Bundle(1, 4);
1500 case HintType::C: return Bundle(0, Bundle::any);
1501 default: break;
1502 }
1503 break;
1504 case GEMMStrategy::NSeparate:
1505 switch (type) {
1506 case HintType::A0:
1507 case HintType::B0: return Bundle(1, Bundle::any);
1508 case HintType::A1:
1509 case HintType::B1: return Bundle(0, Bundle::any);
1510 case HintType::A0Broadcast:
1511 case HintType::B0Broadcast:
1512 case HintType::A1Broadcast:
1513 case HintType::B1Broadcast: return Bundle();
1514 case HintType::C: return Bundle(0, Bundle::any);
1515 case HintType::C1: return Bundle(1, Bundle::any);
1516 default: break;
1517 }
1518 break;
1519 case GEMMStrategy::VAvoid:
1520 switch (type) {
1521 case HintType::A0:
1522 case HintType::B0: return Bundle(0, Bundle::any);
1523 case HintType::A1:
1524 case HintType::B1: return Bundle(1, Bundle::any);
1525 case HintType::A0Broadcast:
1526 case HintType::B0Broadcast:
1527 case HintType::A1Broadcast:
1528 case HintType::B1Broadcast:
1529 return Bundle(1, Bundle::any);
1530 case HintType::C: return Bundle(0, Bundle::any);
1531 case HintType::C1: return Bundle(1, Bundle::any);
1532 default: break;
1533 }
1534 break;
1535 }
1536 break;
1537 default: break;
1538 }
1539
1540 return getHint(type);
1541}
1542
1543// Copy kernel register allocation hints.
1544template <HW hw>
1545Bundle gemm_kernel_generator_t<hw>::getHint(
1546 HintType type, const CopyStrategy &strategy) {
1547 switch (hw) {
1548 case HW::Gen9:
1549 case HW::Gen10:
1550 case HW::Gen11:
1551 case HW::Gen12LP:
1552 case HW::XeHP:
1553 case HW::XeHPG:
1554 case HW::XeHPC:
1555 switch (type) {
1556 case HintType::S: return Bundle();
1557 case HintType::D: return Bundle();
1558 case HintType::SAddr: return Bundle();
1559 case HintType::DAddr: return Bundle();
1560 default: break;
1561 }
1562 break;
1563 default: break;
1564 }
1565
1566 return getHint(type);
1567}
1568
1569static inline void safeReleaseRanges(
1570 vector<GRFRange> &ranges, CommonState &state) {
1571 for (auto &a : ranges)
1572 state.ra.safeRelease(a);
1573 ranges.clear();
1574}
1575
1576static inline void releaseRanges(
1577 const vector<GRFRange> &ranges, CommonState &state) {
1578 for (auto &a : ranges)
1579 state.ra.release(a);
1580}
1581
1582static inline void reclaimRanges(
1583 const vector<GRFRange> &ranges, CommonState &state) {
1584 for (auto &a : ranges)
1585 state.ra.claim(a);
1586}
1587
1588static inline void safeRelease(SubregisterPair &pair, CommonState &state) {
1589 state.ra.release(pair.getReg(0));
1590 state.ra.release(pair.getReg(1));
1591 pair.invalidate();
1592}
1593
1594static inline void safeReleaseRanges(
1595 GRFMultirange &ranges, CommonState &state) {
1596 safeReleaseRanges(ranges.ranges, state);
1597 ranges.ranges.clear();
1598}
1599
1600static inline void safeReleaseRanges(
1601 vector<GRFMultirange> &ranges, CommonState &state) {
1602 for (auto &a : ranges)
1603 safeReleaseRanges(a, state);
1604 ranges.clear();
1605}
1606
1607static inline void releaseRanges(
1608 const GRFMultirange &ranges, CommonState &state) {
1609 releaseRanges(ranges.ranges, state);
1610}
1611
1612static inline void releaseRanges(
1613 const vector<GRFMultirange> &ranges, CommonState &state) {
1614 for (auto &a : ranges)
1615 releaseRanges(a, state);
1616}
1617
1618static inline void reclaimRanges(
1619 const GRFMultirange &ranges, CommonState &state) {
1620 reclaimRanges(ranges.ranges, state);
1621}
1622
1623// Reclaim a list of GRF multiranges.
1624static inline void reclaimRanges(
1625 const vector<GRFMultirange> &ranges, CommonState &state) {
1626 for (auto &a : ranges)
1627 reclaimRanges(a, state);
1628}
1629
1630/***********************\
1631|* Load/store support. *|
1632\***********************/
1633
1634static int consecutiveElements(
1635 Type T, int r, int c, const MatrixAddressing &atype) {
1636 int x = isColMajor(atype.layout) ? r : c;
1637 int y = isColMajor(atype.layout) ? c : r;
1638
1639 if (isPacked(atype.layout)) {
1640 int effTileX = (atype.layout == MatrixLayout::Pc) ? atype.tileR
1641 : atype.tileC;
1642 int effTileY = (atype.layout == MatrixLayout::Pc) ? atype.tileC
1643 : atype.tileR;
1644 if (!effTileX) effTileX = atype.packSize;
1645 if (!effTileY) effTileY = atype.crosspack;
1646
1647 if (y % effTileY == 0) {
1648 if (x == atype.packSize)
1649 return x * y;
1650 else if (x % effTileX == 0)
1651 return x * effTileY;
1652 }
1653 if (y % atype.crosspack == 0)
1654 return std::min(x, effTileX) * atype.crosspack;
1655 }
1656
1657 return x;
1658}
1659
1660static bool needsPseudoblock(HW hw, Type T, int r, int c,
1661 const MatrixAddressing &atype,
1662 const MatrixAddressingStrategy &astrategy, bool writable, bool masked) {
1663 auto consecutive = consecutiveElements(T, r, c, atype);
1664 bool dwAligned = (atype.alignment & 0x3) == 0;
1665 bool owAligned = (atype.alignment & 0xF) == 0;
1666 bool pseudo = !dwAligned || ((consecutive * T) & 0x3)
1667 || (writable && ((consecutive * T) & 0xF) && !astrategy.newDP)
1668 || (writable && !owAligned)
1669 || (writable && masked && (T.size() & 3))
1670 || (masked && !owAligned
1671 && (hw >= HW::XeHP
1672 || astrategy.base.getModel() != ModelA64))
1673 || (hw >= HW::XeHPC && masked)
1674 || (hw >= HW::XeHPC && !astrategy.padded && !astrategy.newDP
1675 && ((r * c * T) & 0xF))
1676 || astrategy.atomic
1677 || (isColMajor(atype.layout) ? c : r) % atype.crosspack
1678 || ((astrategy.base.getModel() == ModelSLM)
1679 && (hw < HW::Gen11 || !owAligned));
1680
1681 return pseudo;
1682}
1683
1684static bool pseudoblockUseSurface(const MatrixAddressing &atype,
1685 const MatrixAddressingStrategy &astrategy, const RegisterBlock &block) {
1686 return (astrategy.base.getModel() == ModelSLM) && (block.ebytes == 4)
1687 && !astrategy.atomic;
1688}
1689
1690// Get effective access type to use when setting up addresses.
1691static AccessType effectiveAccessType(const MatrixAddressing &atype,
1692 const MatrixAddressingStrategy &astrategy, const RegisterBlock &block) {
1693 auto type = astrategy.accessType;
1694 if (!block.isLoadBlock()) return type;
1695 if (type == AccessType::Block && block.ebytes < 16 && block.extra)
1696 type = AccessType::PseudoBlock;
1697 else if (type == AccessType::Scattered
1698 && astrategy.base.getModel() == ModelSLM && block.ebytes == 4
1699 && !astrategy.newDP)
1700 type = AccessType::ChannelScattered;
1701 else if (type == AccessType::ChannelScattered && block.ebytes != 4)
1702 type = AccessType::Scattered;
1703 return type;
1704}
1705
1706// Get effective access type to use when performing loads/stores.
1707static AccessType implAccessType(const MatrixAddressing &atype,
1708 const MatrixAddressingStrategy &astrategy, const RegisterBlock &block) {
1709 auto type = effectiveAccessType(atype, astrategy, block);
1710 if (type == AccessType::PseudoBlock)
1711 type = pseudoblockUseSurface(atype, astrategy, block)
1712 ? AccessType::ChannelScattered
1713 : AccessType::Scattered;
1714 return type;
1715}
1716
1717// Count the number of address/header GRFs required by a RegisterBlock.
1718static inline int addrGRFCount(const MatrixAddressing &atype,
1719 const MatrixAddressingStrategy &astrategy, const RegisterBlock &block) {
1720 // Non-load blocks don't get address registers.
1721 if (!block.isLoadBlock()) return 0;
1722
1723 switch (effectiveAccessType(atype, astrategy, block)) {
1724 case AccessType::Scattered:
1725 case AccessType::ChannelScattered:
1726 case AccessType::PseudoBlock: {
1727 auto bytesPerAddr = (astrategy.base.getModel() == ModelA64) ? 8 : 4;
1728 auto baseSIMD = std::max<int>(block.simdSize, 8);
1729 auto log2Bytes = block.log2GRFBytes;
1730 return (bytesPerAddr * baseSIMD + (1 << log2Bytes) - 1)
1731 >> log2Bytes;
1732 }
1733 case AccessType::Block:
1734 case AccessType::Block2D:
1735 case AccessType::Block2DTranspose:
1736 case AccessType::Block2DVNNI: return 1;
1737 }
1738 throw std::runtime_error("Invalid addressing.");
1739}
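
// Example (illustrative, assuming 32-byte GRFs): a scattered A64 access with
// simdSize = 16 needs 8 bytes per address, so (8 * 16 + 31) >> 5 = 4 address
// GRFs; block and 2D block accesses always use a single header GRF.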
1740
1741// Attempt to allocate address registers for a layout. Returns true if successful.
1742static bool tryAllocAddrRegs(vector<GRFRange> &addrRegs,
1743 const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
1744 const MatrixAddressingStrategy &astrategy, CommonState &state,
1745 Bundle hint = Bundle()) {
1746 auto nblocks = int(layout.size());
1747 bool ok = true;
1748
1749 addrRegs.resize(nblocks);
1750
1751 for (int l = 0; l < nblocks && ok; l++) {
1752 addrRegs[l] = state.ra.try_alloc_range(
1753 addrGRFCount(atype, astrategy, layout[l]), hint);
1754 ok &= addrRegs[l].isValid();
1755 }
1756
1757 if (!ok) {
1758 for (auto &regs : addrRegs)
1759 state.ra.safeRelease(regs);
1760 addrRegs.clear();
1761 }
1762
1763 return ok;
1764}
1765
1766// Allocate address registers for a layout.
1767static void allocAddrRegs(vector<GRFRange> &addrRegs,
1768 const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
1769 const MatrixAddressingStrategy &astrategy, CommonState &state,
1770 Bundle hint = Bundle()) {
1771 if (!tryAllocAddrRegs(addrRegs, layout, atype, astrategy, state, hint))
1772 throw out_of_registers_exception();
1773}
1774
1775// Check if a layout is completely column-major.
1776static inline bool isLayoutColMajor(const vector<RegisterBlock> &layout) {
1777 if (layout.size() == 0) throw std::runtime_error("Empty layout.");
1778 return layout[0]
1779 .colMajor; // All layouts we create are homogeneous currently.
1780}
1781
1782// Get the matrix size represented by a layout.
1783static inline void getLayoutDims(
1784 const vector<RegisterBlock> &layout, int &m, int &n) {
1785 // For now all layouts are sorted so last block is in lower-right corner.
1786 if (layout.size() == 0) throw std::runtime_error("Empty layout.");
1787 auto &last = layout[layout.size() - 1];
1788 m = last.offsetR + last.nr;
1789 n = last.offsetC + last.nc;
1790}
1791
1792// Check if every block in a layout has the given crosspack, with no padding.
1793static inline bool hasFullCrosspack(
1794 const vector<RegisterBlock> &layout, int crosspack) {
1795 if (layout.size() == 0) return true;
1796 // Currently only the first block of the layout needs to be checked.
1797 if (layout[0].crosspack != crosspack) return false;
1799 for (const auto &block : layout)
1800 if ((block.colMajor ? block.nc : block.nr) % crosspack) return false;
1801 return true;
1802}
1803
1804// Check if the layout is tiled with the given tiling.
1805static inline bool hasTiling(
1806 const vector<RegisterBlock> &layout, int tileR, int tileC) {
1807 for (auto &block : layout) {
1808 if (tileR > 0)
1809 if (block.offsetR / tileR != (block.offsetR + block.nr - 1) / tileR)
1810 return false;
1811 if (tileC > 0)
1812 if (block.offsetC / tileC != (block.offsetC + block.nc - 1) / tileC)
1813 return false;
1814 }
1815 return true;
1816}
1817
1818// Check if a layout has row fragmenting.
1819static bool hasRowFragmenting(const vector<RegisterBlock> &layout) {
1820 for (auto &block : layout)
1821 if (block.rowFragment) return true;
1822 return false;
1823}
1824
1825// Check if a layout has column fragmenting.
1826static bool hasColumnFragmenting(const vector<RegisterBlock> &layout) {
1827 for (auto &block : layout)
1828 if (block.colFragment) return true;
1829 return false;
1830}
1831
1832// Check if a layout has remainders enabled.
1833static bool hasRemainders(const vector<RegisterBlock> &layout,
1834 bool remainderR = true, bool remainderC = true) {
1835 for (auto &block : layout)
1836 if ((remainderR && block.remainderR)
1837 || (remainderC && block.remainderC))
1838 return true;
1839 return false;
1840}
1841
1842// Check if a layout has any kind of fragmenting.
1843static bool hasFragmenting(const vector<RegisterBlock> &layout) {
1844 for (auto &block : layout)
1845 if (block.rowFragment || block.colFragment) return true;
1846 return false;
1847}
1848
1849// Check if a layout has any masking.
1850static bool hasMasking(const vector<RegisterBlock> &layout) {
1851 for (auto &block : layout)
1852 if (block.rowMask || block.colMask || block.flag) return true;
1853 return false;
1854}
1855
1856// Check if a layout has any flag registers assigned.
1857static bool hasFlags(const vector<RegisterBlock> &layout) {
1858 for (auto &block : layout)
1859 if (block.flag) return true;
1860 return false;
1861}
1862
1863// Find the maximum block size in a layout, in registers.
1864static inline int getMaxLoadBlock(const vector<RegisterBlock> &layout) {
1865 int result = 0;
1866 for (auto &l : layout)
1867 result = std::max<int>(result, l.msgRegs);
1868 return result;
1869}
1870
1871// Count the number of registers needed by a register layout.
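// E.g., if the furthest block ends at byte offset 100 on a 32-byte-GRF part,
// the layout occupies (100 + 31) >> 5 = 4 GRFs.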
1872static inline int getRegCount(const vector<RegisterBlock> &layout) {
1873 if (layout.empty()) return 0;
1874
1875 int lastByte = 0;
1876 for (auto &l : layout)
1877 lastByte = std::max(lastByte, l.offsetBytes + l.bytes);
1878
1879 int log2Bytes = layout[0].log2GRFBytes;
1880 return (lastByte + (1 << log2Bytes) - 1) >> log2Bytes;
1881}
1882
1883static int getAddr0Offset(const RegisterBlock &block,
1884 const MatrixAddressing &atype,
1885 const MatrixAddressingStrategy &astrategy) {
1886 if (astrategy.newDP) return 0;
1887 if (astrategy.base.getModel() == ModelA64) return 0;
1888 if (effectiveAccessType(atype, astrategy, block) == AccessType::Block)
1889 return 2;
1890 return 0;
1891}
1892
1893// Get a subregister containing the (shifted) address of the (0,0) entry of a layout.
1894static Subregister getOriginAddr(const vector<RegisterBlock> &layout,
1895 const vector<GRFRange> &addrRegs, const MatrixAddressing &atype,
1896 const MatrixAddressingStrategy &astrategy, int *shiftOut = nullptr) {
1897 bool a64 = (astrategy.base.getModel() == ModelA64);
1898
1899 for (size_t b = 0; b < layout.size(); b++) {
1900 const auto &block = layout[b];
1901 if ((block.offsetR != 0) || (block.offsetC != 0)) continue;
1902
1903 int off = getAddr0Offset(block, atype, astrategy);
1904
1905 if (shiftOut) *shiftOut = block.addrShift;
1906 return addrRegs[b][0].sub(off, a64 ? DataType::uq : DataType::ud);
1907 }
1908
1909 if (shiftOut) *shiftOut = 0;
1910 return Subregister();
1911}
1912
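// SIMD width limits for scattered accesses: with new dataport messages the
// maximum is half the GRF byte size (e.g., SIMD32 on 64-byte-GRF hardware),
// otherwise SIMD16. The minimum is half the maximum, except on XeHPC where
// it stays at SIMD16.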
1913static inline int maxScatteredSIMD(
1914 HW hw, const MatrixAddressingStrategy &astrategy) {
1915 if (astrategy.newDP) return GRF::bytes(hw) >> 1;
1916 return 16;
1917}
1918
1919static inline int minScatteredSIMD(
1920 HW hw, const MatrixAddressingStrategy &astrategy) {
1921 if (hw == HW::XeHPC) return 16;
1922 return maxScatteredSIMD(hw, astrategy) >> 1;
1923}
1924
1925// Get width and height parameters for underlying 2D block load message.
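// For example, a packed f16 block 64 elements wide (ebytes = 2) exceeds the
// 64-byte width limit, so it is carved into multiX = 2 spans: the width drops
// to 32 and the height doubles.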
1926static void getBlock2DWH(int &w, int &h, const MatrixAddressing &atype,
1927 const RegisterBlock &block, int *outMultiX = nullptr) {
1928 int multiX = 1;
1929 w = isColMajor(atype.layout) ? block.nr : block.nc;
1930 h = isColMajor(atype.layout) ? block.nc : block.nr;
1931 w = (w * block.extra) / block.ebytes;
1932 if (isPacked(atype.layout)) {
1933 int maxW = 64 / block.ebytes;
1934 multiX = div_up(w, maxW);
1935 w /= multiX;
1936 h *= multiX;
1937 }
1938 if (outMultiX) *outMultiX = multiX;
1939}
1940
1941static bool isRegisterColMajor(Type T, const MatrixAddressing &atype,
1942 const MatrixAddressingStrategy &astrategy) {
1943 return isColMajor(atype.layout) ^ isTransposing(astrategy.accessType)
1944 ^ isLargeCrosspack(T, atype.crosspack);
1945}
1946
1947// Set up a RegisterBlock structure.
1948template <HW hw>
1949bool gemm_kernel_generator_t<hw>::getBlockInfo(Type T,
1950 const MatrixAddressing &atype,
1951 const MatrixAddressingStrategy &astrategy, int r, int c,
1952 bool remainderR, bool remainderC, bool writable, bool avoidFragment,
1953 int maxRBlock, int maxCBlock, int &rblock, int &cblock,
1954 RegisterBlock &block) {
1955 bool prefetch = astrategy.prefetch;
1956 int R = rounddown_pow2(r);
1957 int C = rounddown_pow2(c);
1958
1959 if (maxRBlock == 0) maxRBlock = r;
1960 if (maxCBlock == 0) maxCBlock = c;
1961
1962 if (isPacked(atype.layout)) {
1963 // Don't cross nonconsecutive tiles in a packed layout.
1964 bool cm = isColMajor(atype.layout)
1965 ^ isTransposing(astrategy.accessType);
1966 if (cm) {
1967 if (maxRBlock < atype.packSize && atype.tileC > 0)
1968 maxCBlock = std::min<int>(maxCBlock, atype.tileC);
1969 } else {
1970 if (maxCBlock < atype.packSize && atype.tileR > 0)
1971 maxRBlock = std::min<int>(maxRBlock, atype.tileR);
1972 }
1973 }
1974
1975 // Set default parameters.
1976 block.colMajor = isColMajor(atype.layout);
1977 block.splitComplex = false;
1978 block.cxComponent = RegisterBlock::Interleaved;
1979 block.crosspack = 1;
1980 block.rowMask = MaskInfo::None();
1981 block.colMask = MaskInfo::None();
1982 block.rowFragment = 0;
1983 block.colFragment = 0;
1984 block.remainderR = remainderR;
1985 block.remainderC = remainderC;
1986 block.noRowsOK = false;
1987 block.noColsOK = false;
1988 block.descRemR = false;
1989 block.descRemC = false;
1990 block.descAssigned = false;
1991 block.addrShift = 0;
1992 block.writable = writable;
1993 block.clearFlag();
1994 block.log2GRFBytes = GRF::log2Bytes(hw);
1995 block.msgRegs = 0;
1996 block.bytes = 0;
1997 block.hasNoLoad = false;
1998
1999 auto &vrmask = block.rowMask.variable;
2000 auto &vcmask = block.colMask.variable;
2001
2002 vrmask.rsize = 0;
2003 vcmask.rsize = 0;
2004
2005 auto accessType = astrategy.accessType;
2006
2007 switch (accessType) {
2008 case AccessType::ChannelScattered:
2009 case AccessType::Scattered: {
2010 bool channelScattered
2011 = (accessType == AccessType::ChannelScattered);
2012
2013 // Detect large crosspack case.
2014 bool largeCP = isLargeCrosspack(T, atype.crosspack);
2015 int effCP = largeCP ? 1 : atype.crosspack;
2016
2017 // Scattered read/write messages effectively transpose DW/QW matrices.
2018 block.colMajor = !block.colMajor ^ largeCP;
2019
2020 // Let X be the contiguous dimension, Y the scattered dimension (in memory).
2021 int *xblock, *yblock;
2022 int maxXBlock, maxYBlock;
2023 int X, Y;
2024 bool remainderX, remainderY;
2025 int tileX, tileY;
2026 auto &vxmask = block.colMajor ? vcmask : vrmask;
2027 auto &vymask = block.colMajor ? vrmask : vcmask;
2028 auto &fragment
2029 = block.colMajor ? block.colFragment : block.rowFragment;
2030 auto smode = astrategy.smode;
2031
2032 if (block.colMajor) {
2033 Y = R;
2034 X = C;
2035 yblock = &rblock;
2036 xblock = &cblock;
2037 maxYBlock = maxRBlock;
2038 maxXBlock = maxCBlock;
2039 remainderY = remainderR;
2040 remainderX = remainderC;
2041 tileY = atype.tileR;
2042 tileX = atype.tileC;
2043 } else {
2044 X = R;
2045 Y = C;
2046 xblock = &rblock;
2047 yblock = &cblock;
2048 maxXBlock = maxRBlock;
2049 maxYBlock = maxCBlock;
2050 remainderX = remainderR;
2051 remainderY = remainderC;
2052 tileX = atype.tileR;
2053 tileY = atype.tileC;
2054 }
2055
2056 // Allowed accesses:
2057 // A64 Essentially max 256 bytes.
2058 // 8 slots x (1,2,4,8) dwords [Gen12/surface: 1,2,4]
2059 // 8 slots x (1,2,4) qwords
2060 // 16 slots x (1,2,4) dwords
2061 // 16 slots x (1,2) qwords
2062 // Others 8 slots x 1 dword
2063 // 16 slots x 1 dword
2064 // Slot counts doubled for 64-byte GRFs.
2065
2066 // Native (col major in memory) matrix block sizes, as a result:
2067 // SIMD8: 1x8 2x4 4x2 8x1 (count 1) 2x8 4x8 8x8 [others]
2068 // SIMD16: 1x16 2x8 4x4 8x2 16x1 (count 1) 2x16 4x16
2069 // Other layouts are possible too but require dummy (non-load) blocks.
2070 // Only kx8 and kx16 are supported for now for {4,8}-byte types.
2071 // For 16-byte types, only 1x4 and 1x8 are supported.
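        // Example: f32 with dword scattered messages gives ebytes = 4, crosspack = 1;
        // f16 packs two elements per dword slot (crosspack = 2); f64 with A64 qword
        // messages gives ebytes = 8, crosspack = 1.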
2072
2073 auto maxSIMD = maxScatteredSIMD(hw, astrategy);
2074 auto minSIMD = minScatteredSIMD(hw, astrategy);
2075
2076 auto Xc = (avoidFragment && remainderX) ? 1 : X;
2077 bool byte = (atype.alignment < 4) || (Xc * T * effCP < 4);
2078 bool a64 = (astrategy.base.getModel() == ModelA64);
2079
2080 channelScattered |= byte;
2081
2082 bool qword = (T.size() >= 8 && !channelScattered && !prefetch
2083 && (a64 || astrategy.newDP));
2084 if (astrategy.atomic
2085 && hasNativeAtomicAdd(hw, T.real(), atype, astrategy))
2086 qword &= (T.real().size() >= 8);
2087 int width = qword ? 8 : 4;
2088 block.ebytes = byte ? 1 : width;
2089 block.crosspack = std::max<int>(1, width / T);
2090 int consecutive = std::max<int>(1, T.size() / width);
2091
2092 if (prefetch) consecutive = 1;
2093
2094 if (block.ebytes == 4 && astrategy.base.getModel() == ModelSLM
2095 && !astrategy.newDP)
2096 channelScattered = true;
2097
2098 bool simd1 = !a64 && !channelScattered;
2099 simd1 &= !astrategy.newDP;
2100
2101 // Handle source crosspack.
2102 int uncrosspack = 1;
2103 if (effCP > 1) {
2104 if (effCP == block.crosspack) {
2105 block.crosspack = 1;
2106 uncrosspack = effCP;
2107 } else
2108 stub();
2109 }
2110
2111 // Try to fit a native matrix block size to X and Y.
2112 auto slots = std::min(Y, maxYBlock) * consecutive / uncrosspack;
2113 if (prefetch) {
2114 // Prefetch only: maximize Y usage.
2115 block.simdSize = maxSIMD;
2116 } else if (smode == ScatterSIMD::Narrow
2117 || (smode == ScatterSIMD::Default
2118 && block.ebytes * minSIMD > GRF::bytes(hw))) {
2119 // Maximize X usage because we always have at least 2 consecutive GRFs.
2120 block.simdSize
2121 = (slots >= maxSIMD && X <= 2) ? maxSIMD : minSIMD;
2122 } else {
2123 // Otherwise, try to maximize Y usage (larger SIMD, worse memory access).
2124 block.simdSize = maxSIMD;
2125 }
2126 block.simdSize
2127 = std::min<int>(block.simdSize, rounddown_pow2(slots));
2128
2129 bool no8x8DW = isGen12;
2130 bool no16x4QW = false;
2131
2132 no8x8DW &= !astrategy.newDP;
2133 if (hw == HW::XeHPG && astrategy.newDP)
2134 no8x8DW = no16x4QW
2135 = true; // Not supported on 512 EU A0. OK on later steppings.
2136 no16x4QW |= (!astrategy.newDP && GRF::bytes(hw) == 64);
2137
2138 int hwMaxXBlock;
2139
2140 if (prefetch)
2141 hwMaxXBlock = 64 / T;
2142 else if (consecutive > 1)
2143 hwMaxXBlock = 1;
2144 else if (byte)
2145 hwMaxXBlock = remainderX ? 1 : block.crosspack;
2146 else if (simd1)
2147 hwMaxXBlock = block.crosspack;
2148 else if (a64 && astrategy.atomic)
2149 hwMaxXBlock = block.crosspack;
2150 else if (channelScattered || (block.ebytes == 4 && no8x8DW)
2151 || (block.ebytes == 8 && no16x4QW)
2152 || (block.simdSize == maxSIMD))
2153 hwMaxXBlock = 16 / T;
2154 else
2155 hwMaxXBlock = 32 / T;
2156
2157 maxXBlock = std::min(maxXBlock, hwMaxXBlock);
2158
2159 if (tileX > 0) maxXBlock = std::min(maxXBlock, tileX);
2160
2161 *xblock = std::min<int>(X, maxXBlock);
2162 block.count = *xblock;
2163
2164 *yblock = block.simdSize * uncrosspack / consecutive;
2165 if (tileY > 0 && tileY < *yblock) stub();
2166
2167 if (prefetch)
2168 block.count = 1;
2169 else if (byte)
2170 block.count *= T.size();
2171 else
2172 block.count = std::max<int>(1, block.count / block.crosspack);
2173
2174 // LD is determined by actual # of SIMD slots in HW. But for X = 1 we may
2175 // shrink the LD to avoid allocating unnecessary registers.
2176 auto ldSIMD = block.simdSize;
2177 if (*xblock > 1 || (minSIMD * block.ebytes <= GRF::bytes(hw)))
2178 ldSIMD = std::max<int>(ldSIMD, minSIMD);
2179 block.ld = ldSIMD * uncrosspack / consecutive;
2180
2181 // Handle remainder. Masking handles Y remainders.
2182 if (remainderY) {
2183 vymask.isFixed = false;
2184 vymask.bitRep = consecutive;
2185 vymask.maskRep = 1;
2186 vymask.rsize = *yblock;
2187 vymask.rdivide = 1;
2188 }
2189
2190 // X remainders require fragmenting. Channel scattered float doesn't need complete fragmenting.
2191 // (ditto for regular scattered float with new dataport messages.)
2192 // Otherwise, fragment 2 is possible for DWord+ types but not implemented.
2193 if (remainderX && !prefetch) {
2194 if (avoidFragment && !remainderY && *xblock == 1) {
2195 vxmask.isFixed = false;
2196 vxmask.bitRep = 16;
2197 vxmask.maskRep = 1;
2198 vxmask.rsize = 1;
2199 vxmask.rdivide = 1;
2200 } else if ((channelScattered || astrategy.newDP)
2201 && block.crosspack == 1 && block.ebytes == T.size()) {
2202 fragment = std::min(*xblock, 4);
2203 if (block.colMajor) // Clang can't handle the ternary operator equivalent of this.
2204 block.descRemC = true;
2205 else
2206 block.descRemR = true;
2207 } else
2208 fragment = 1;
2209 }
2210
2211 block.extra = consecutive;
2212
2213 // BTS scattered accesses are addressed by elements.
2214 if (!astrategy.newDP && !channelScattered
2215 && !astrategy.base.isStateless())
2216 block.addrShift = log2(block.ebytes);
2217
2218 break;
2219 }
2220 case AccessType::Block:
2221 case AccessType::PseudoBlock: {
2222 // Three types of block messages:
2223 // block_oword: 16 byte align, BLK masking (= dw except ow channel on R Gen9 only -- silently ignore, can't fault)
2224 // aligned_oword: 4 byte align, no masking, read only
2225 // block_hword: [Gen9-12LP] A64; 4 byte align R, BLKCM masking (= dw but can do ow channel on Gen9 only)
2226 // A64; 16 byte align W
2227 // [XeHP] A64/BTS; 32 byte align R/W
2228 // New dataport messages support {DW, QW}x{1...64} with DW/QW alignment, no masking.
2229 //
2230 // Prefer block_hword in all cases. When block_hword can't be used:
2231 // Use oword if alignment can be assured (i.e. packed row/column layout, or oword-sized scalar)
2232 // Otherwise, use aligned oword. load/storeMatrixBlock will emit an error if masking/stores attempted.
2233 //
2234 // Pseudoblock messages have similar layouts, but are limited to
2235 // {8,16}x{dw,qw} sizes, so lengths 8,16 allowed for float, 4,8,16 for double.
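        // For illustration, with new dataport block messages and qword-aligned data,
        // ebytes = 8 and up to 64 qwords may be loaded at once, so maxElements
        // = 64 * 8 / sizeof(T) (= 128 elements for f32).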
2236
2237 bool colMajor = block.colMajor;
2238 bool effCM = colMajor ^ isLargeCrosspack(T, atype.crosspack);
2239 auto consecutive = consecutiveElements(T, r, c, atype);
2240 bool masking = (effCM ? remainderR : remainderC);
2241 bool bytePartialCP
2242 = (T.size() & 3) && ((colMajor ? C : R) % atype.crosspack);
2243 bool byte = (atype.alignment & 3) || (consecutive * T & 3)
2244 || bytePartialCP || ((T.size() & 3) && writable && masking);
2245 bool byte1PerSlot = byte && (bytePartialCP || masking);
2246 bool pseudo = (accessType == AccessType::PseudoBlock)
2247 | needsPseudoblock(
2248 hw, T, R, C, atype, astrategy, writable, masking);
2249 int maxElements = 0;
2250 int maskGranularity = 1;
2251 int maxSIMD = maxScatteredSIMD(hw, astrategy);
2252 bool oword = false, aoword = false;
2253 int npack = 0;
2254 bool canQW = false, mustQW = false;
2255
2256 bool a32 = (astrategy.base.getModel() == ModelA32);
2257 bool a64 = (astrategy.base.getModel() == ModelA64);
2258 bool sc = (astrategy.base.getModel() == ModelSC);
2259 bool slm = (astrategy.base.getModel() == ModelSLM);
2260
2261 if (!pseudo && byte) return false;
2262
2263 if (astrategy.newDP && !pseudo) {
2264 bool qword = ((atype.alignment | (consecutive * T)) % 8 == 0);
2265 block.ebytes = qword ? 8 : 4;
2266 maxElements = (64 * block.ebytes) / T;
2267 } else if (!pseudo) {
2268 int maxCount = 8;
2269 oword = !a64;
2270 aoword = ((atype.alignment & 0xF) != 0) || sc;
2271 if (hw > HW::Gen12LP) {
2272 oword |= (atype.alignment & 0x1F) != 0;
2273 if (slm) maxCount = 16;
2274 }
2275 block.ebytes = oword ? 16 : 32;
2276 maxElements = maxCount * block.ebytes / T;
2277 maskGranularity = 4; // Block accesses mask by dwords
2278 } else {
2279 bool nativeAtomic = astrategy.atomic
2280 && hasNativeAtomicAdd(hw, T.real(), atype, astrategy);
2281 canQW = ((atype.alignment | (consecutive * T)) % 8 == 0);
2282 if (!astrategy.newDP) canQW &= !byte && a64;
2283 // QW SLM atomics are implemented on XeHPC, but currently have functional issues.
2284 if (slm && astrategy.atomic)
2286 canQW = false;
2287 if (remainderR || remainderC) canQW &= (T.size() % 8 == 0);
2288 if (nativeAtomic) canQW = mustQW = (T.real().size() >= 8);
2289 auto stride = canQW ? 8 : 4;
2290 auto maxNPack = byte1PerSlot ? 1 : std::max<int>(1, stride / T);
2291 int simdCap = maxSIMD;
2292 if (astrategy.atomic && !nativeAtomic) simdCap = 16;
2293 maxElements = simdCap * maxNPack;
2294 if (T.size() > stride) maxElements = maxElements * stride / T;
2295 }
2296
2297 auto maxABlock = maxElements / (byte1PerSlot ? 1 : atype.crosspack);
2298
2299 auto choosePackedRCBlock = [=, &block](int &xblock, int &yblock,
2300 int tileX, int tileY, int X,
2301 int Y) {
2302 xblock = std::min<int>(maxABlock, X);
2303
2304 if (tileX) {
2305 int ntileY = tileY ? (maxElements / (xblock * tileY)) : 0;
2306 if (xblock < atype.packSize || Y < tileY || ntileY == 0)
2307 xblock = std::min<int>(xblock, tileX);
2308 }
2309 if ((tileX ? tileX : atype.packSize) <= xblock) {
2310 yblock = std::min<int>(maxElements / xblock, Y);
2311 if (yblock < atype.crosspack
2312 && isLargeCrosspack(T, atype.crosspack)) {
2313 yblock = atype.crosspack;
2314 xblock = std::min<int>(xblock, maxElements / yblock);
2315 }
2316 if (tileY > 0 && yblock > tileY)
2317 yblock = align_down(yblock, tileY);
2318 } else
2319 yblock = atype.crosspack; // Remainder loop: no longer packed in memory
2320
2321 block.crosspack = atype.crosspack;
2322 Y = div_up(Y, atype.crosspack);
2323 };
2324
2325 switch (atype.layout) {
2326 case MatrixLayout::Pc:
2327 choosePackedRCBlock(
2328 rblock, cblock, atype.tileR, atype.tileC, R, C);
2329 break;
2330 case MatrixLayout::N:
2331 if (atype.crosspack > 1) stub();
2332 if (atype.tileR == R && R <= maxElements) {
2333 cblock = std::min<int>(maxElements / R, C);
2334 rblock = R;
2335 } else {
2336 cblock = 1;
2337 rblock = std::min<int>(maxElements, R);
2338 }
2339 break;
2340 case MatrixLayout::Pr:
2341 choosePackedRCBlock(
2342 cblock, rblock, atype.tileC, atype.tileR, C, R);
2343 break;
2344 case MatrixLayout::T:
2345 if (atype.crosspack > 1) stub();
2346 if (atype.tileC == C && C <= maxElements) {
2347 rblock = std::min<int>(maxElements / cblock, R);
2348 cblock = C;
2349 } else {
2350 rblock = 1;
2351 cblock = std::min<int>(maxElements, C);
2352 }
2353 break;
2354 }
2355
2356 rblock = std::min(rblock, maxRBlock);
2357 cblock = std::min(cblock, maxCBlock);
2358
2359 if (pseudo) {
2360 bool qword = mustQW
2361 || (canQW && (rblock * cblock * T >= 4 * maxSIMD));
2362 npack = std::max<int>(1, (qword ? 8 : 4) / T);
2363 if (byte1PerSlot) {
2364 if (isLargeCrosspack(T, block.crosspack)) {
2365 if (block.crosspack == (colMajor ? cblock : rblock))
2366 block.colMajor = colMajor = effCM;
2367 else
2368 stub();
2369 }
2370 block.crosspack = npack;
2371 npack = 1;
2372 (effCM ? cblock : rblock) = 1;
2373 }
2374 maskGranularity = qword ? 8 : byte1PerSlot ? T.size() : 4;
2375 }
2376
2377 if (remainderR) {
2378 if (effCM) {
2379 // rblock cannot be more than 16 dwords = 64 bytes for masking
2380 // except for pseudo-block
2381 int rblockLimit = pseudo ? rblock : 64 / T;
2382
2383 if (avoidFragment)
2384 rblock = std::min<int>(rblock, rblockLimit);
2385 if (rblock > rblockLimit)
2386 block.rowFragment = rblockLimit;
2387 else {
2388 // For sizeof(T) < maskGranularity, this is a bit of a cheat.
2389 //
2390 // As long as we do not need to write to this matrix, we can read
2391 // in maskGranularity-sized chunks knowing we will never cross a page boundary.
2392
2393 if (writable && (T.size() & (maskGranularity - 1)))
2394 return false;
2395 if (!pseudo && oword && aoword) hw_unsupported();
2396
2397 if (!pseudo
2398 && !(isPacked(atype.layout)
2399 && (atype.packSize == rblock)))
2400 cblock = 1;
2401
2402 vrmask.isFixed = false;
2403 vrmask.rsize = rblock;
2404 vrmask.bitRep
2405 = std::max<int>(T.size() / maskGranularity, 1);
2406 vrmask.maskRep = cblock;
2407 vrmask.rdivide = std::max<int>(maskGranularity / T, 1);
2408 }
2409 } else {
2410 if (avoidFragment && !remainderC) {
2411 // No native masking in this dimension. One mask/row.
2412 rblock = 1;
2413 vrmask.isFixed = false;
2414 vrmask.bitRep = 16;
2415 vrmask.maskRep = 1;
2416 vrmask.rdivide = 1;
2417 vrmask.rsize = 1;
2418 } else {
2419 // Fragment it. Could actually handle rowFragment = 2 by changing descriptor.
2420 block.rowFragment = 1;
2421 }
2422 }
2423 }
2424
2425 if (remainderC) {
2426 if (!effCM) {
2427 // cblock cannot be more than 16 dwords = 64 bytes except for pseudo-block
2428 int cblockLimit = pseudo ? cblock : 64 / T;
2429
2430 if (avoidFragment)
2431 cblock = std::min<int>(cblock, cblockLimit);
2432 if (cblock > cblockLimit)
2433 block.colFragment = cblockLimit;
2434 else {
2435 if (writable && (T.size() & (maskGranularity - 1)))
2436 return false;
2437 if (!pseudo && oword && aoword) hw_unsupported();
2438
2439 if (!pseudo
2440 && !(isPacked(atype.layout)
2441 && (atype.packSize == cblock)))
2442 rblock = 1;
2443
2444 vcmask.isFixed = false;
2445 vcmask.rsize = cblock;
2446 vcmask.bitRep
2447 = std::max<int>(T.size() / maskGranularity, 1);
2448 vcmask.maskRep = rblock;
2449 vcmask.rdivide = std::max<int>(maskGranularity / T, 1);
2450 }
2451 } else {
2452 if (avoidFragment && !remainderR) {
2453 // No native masking in this dimension. One mask/column.
2454 cblock = 1;
2455 vcmask.isFixed = false;
2456 vcmask.bitRep = 16;
2457 vcmask.maskRep = 1;
2458 vcmask.rdivide = 1;
2459 vcmask.rsize = 1;
2460 } else {
2461 // Fragment it. Could actually handle colFragment = 2 by changing descriptor.
2462 block.colFragment = 1;
2463 }
2464 }
2465 }
2466
2467 int nbytes = (rblock * cblock) * T;
2468 block.simdSize
2469 = clamp(roundup_pow2(nbytes) / maskGranularity, 1, maxSIMD);
2470 block.ld = colMajor ? rblock : cblock;
2471 if (!pseudo) {
2472 if (astrategy.newDP) block.simdSize = 1;
2473 block.count = div_up(nbytes, block.ebytes);
2474 block.extra = aoword;
2475 if (block.ebytes == 16 && !(a32 || a64)
2476 && !aoword) // BTS/SLM oword loads are oword-addressed.
2477 block.addrShift = 4;
2478 } else {
2479 block.count = byte ? std::min(nbytes, npack * T) : 1;
2480 block.ebytes = byte ? 1 : maskGranularity;
2481 block.extra = 1;
2482 if (!(a32 || a64
2483 || pseudoblockUseSurface(atype, astrategy, block)
2484 || astrategy.atomic))
2485 block.addrShift = log2(block.ebytes);
2486 }
2487 if (astrategy.newDP) block.addrShift = 0;
2488 break;
2489 }
2490 case AccessType::Block2D:
2491 case AccessType::Block2DTranspose:
2492 case AccessType::Block2DVNNI: {
2493 // element bytes * array length <= 8
2494 // width * array length <= 64 bytes
2495 // => width <= 1 GRF
2496 // height <= 32 (load) 8 (store)
2497 // array length = 1 for store, transpose
2498 //
2499 // normal: width >= 4 bytes
2500 // transpose: d32 only
2501 // vnni: d8/d16 only, height >= 4 bytes
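        // For illustration: on XeHPC a transposed d64 load is limited to an
        // 8-high block at most 4 qwords wide; VNNI loads require d8/d16 data
        // and at least 4 bytes of height.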
2502 bool transpose = (accessType == AccessType::Block2DTranspose);
2503 bool vnni = (accessType == AccessType::Block2DVNNI);
2504
2505 bool memCM = block.colMajor;
2506 block.colMajor ^= transpose;
2507 auto X = memCM ? R : C;
2508 auto Y = memCM ? C : R;
2509 auto &xblock = memCM ? rblock : cblock;
2510 auto &yblock = memCM ? cblock : rblock;
2511 auto maxXBlock = memCM ? maxRBlock : maxCBlock;
2512 auto maxYBlock = memCM ? maxCBlock : maxRBlock;
2513
2514 if (hw != HW::XeHPC || !astrategy.newDP) hw_unsupported();
2515
2516 // Choose underlying type.
2517 auto Tblock = T;
2518 if (transpose) {
2519 if (atype.alignment % 4) hw_unsupported();
2520 if (Tblock.size() > 8) hw_unsupported();
2521 if (Tblock.size() > 4) {
2522 if (hw == HW::XeHPC && getStepping() < SteppingPVCXTB0)
2523 hw_unsupported();
2524 Tblock = Type::u64;
2525 maxXBlock = std::min(maxXBlock, 4);
2526 maxYBlock = 8;
2527 } else {
2528 Tblock = Type::u32;
2529 maxXBlock = std::min(maxXBlock, (8 * Tblock) / T);
2530 }
2531 } else if (vnni) {
2532 if (atype.alignment % 8) hw_unsupported();
2533 if (Tblock.size() >= 4) hw_unsupported();
2534 if ((Y * Tblock) % 4) hw_unsupported();
2535 maxXBlock = std::min(maxXBlock, 16);
2536 } else {
2537 if (atype.alignment % 8) hw_unsupported();
2538 if (Tblock.size() > 8) Tblock = Type::u64;
2539 block.crosspack = atype.crosspack;
2540 }
2541 if ((X * T) % 4) hw_unsupported();
2542
2543 // Reinterpret X/maxXBlock to underlying type.
2544 maxXBlock = (maxXBlock * T) / Tblock;
2545 auto X_logical = X;
2546 X = (X * T) / Tblock;
2547
2548 // Carve out a maximal allowed block size.
2549 xblock = std::min(X, 64 / Tblock);
2550 xblock = std::max(xblock, 4 / Tblock);
2551 int yblockLimit = writable ? 8 : 32;
2552
2553 if (isPacked(atype.layout) && 2 * xblock <= X
2554 && X_logical == atype.packSize) {
2555 // Split logical x dimension into multiple spans to accommodate the width restriction.
2556 if (astrategy.address2D) stub();
2557 int multiX = X / xblock;
2558 xblock *= multiX;
2559 yblockLimit /= multiX;
2560 }
2561
2562 yblock = std::min({maxYBlock, Y, yblockLimit});
2563
2564 if (transpose && Tblock.size() == 8 && yblock != 8)
2565 hw_unsupported();
2566
2567 // Choose # of blocks. In postprocessLayout, this RegisterBlock will be
2568 // split into one RegisterBlock for each block in the array.
2569 int count = 1;
2570 if (!(writable || transpose)) {
2571 count = rounddown_pow2(xblock / maxXBlock);
2572 count = std::min({count, 8 / Tblock, 64 / xblock});
2573 count = std::max(count, 1);
2574 }
2575 xblock = std::min(xblock, maxXBlock * count);
2576
2577 // Crosspack calculation.
2578 int crosspack = (transpose || vnni) ? std::max(1, 4 / T) : 1;
2579 if (atype.crosspack == 1)
2580 block.crosspack = crosspack;
2581 else if (atype.crosspack == crosspack)
2582 block.crosspack = 1;
2583 else
2584 return false;
2585
2586 // Convert size from underlying type to our actual type.
2587 xblock = (xblock * Tblock) / T;
2588
2589 block.simdSize = 1;
2590 block.ld = roundup_pow2(transpose ? yblock : xblock);
2591 block.ebytes = Tblock.size();
2592 block.count = count;
2593 block.extra = T.size();
2594 auto bytes = align_up((block.colMajor ? cblock : rblock) / count,
2595 block.crosspack)
2596 * block.ld * count * T;
2597 block.msgRegs = GRF::bytesToGRFs(hw, bytes);
2598 break;
2599 }
2600 }
2601
2602 // The mask moduli are almost always rblock/cblock.
2603 // Also, clamp mask reps to ensure mask length does not exceed SIMD size.
2604 if (block.rowMask && !block.rowMask.fixed.isFixed) {
2605 if (vrmask.rsize == 0) vrmask.rsize = rblock;
2606 vrmask.maskRep = std::min<int>(vrmask.maskRep,
2607 std::max<int>(1,
2608 vrmask.rdivide * block.simdSize
2609 / (vrmask.bitRep * vrmask.rsize)));
2610 block.noRowsOK = true; // All-zero masks are always OK.
2611 }
2612 if (block.colMask && !block.colMask.fixed.isFixed) {
2613 if (vcmask.rsize == 0) vcmask.rsize = cblock;
2614 vcmask.maskRep = std::min<int>(vcmask.maskRep,
2615 std::max<int>(1,
2616 vcmask.rdivide * block.simdSize
2617 / (vcmask.bitRep * vcmask.rsize)));
2618 block.noColsOK = true;
2619 }
2620
2621 return true;
2622}
2623
2624template <HW hw>
2625bool gemm_kernel_generator_t<hw>::tryAddMasking(Type T, RegisterBlock &block,
2626 bool remainderR, bool remainderC, const MatrixAddressing &atype,
2627 const MatrixAddressingStrategy &astrategy) {
2628 auto blockNew = block;
2629 blockNew.remainderR |= remainderR;
2630 blockNew.remainderC |= remainderC;
2631
2632 auto curAccessType = implAccessType(atype, astrategy, block);
2633
2634 if (curAccessType == AccessType::Block) {
2635 if (astrategy.newDP) return false;
2636 if (hw >= HW::XeHPC) return false;
2637 }
2638
2639 bool remChanged = (block.colMajor ? (remainderR && !block.remainderR)
2640 : (remainderC && !block.remainderC));
2641
2642 if (remChanged && !isBlock2D(curAccessType)) {
2643 int rblock, cblock;
2644 if (!getBlockInfo(T, atype, astrategy, block.nr, block.nc,
2645 blockNew.remainderR, blockNew.remainderC, block.writable,
2646 true, 0, 0, rblock, cblock, blockNew))
2647 return false;
2648 if (rblock != block.nr || cblock != block.nc) return false;
2649 if (implAccessType(atype, astrategy, blockNew) != curAccessType)
2650 return false;
2651 if (curAccessType != AccessType::Block) {
2652 if (blockNew.ebytes != block.ebytes) return false;
2653 if (blockNew.ebytes == 1 && blockNew.count != block.count)
2654 return false;
2655 }
2656 }
2657
2658 block = blockNew;
2659 return true;
2660}
2661
2662template <HW hw>
2663bool gemm_kernel_generator_t<hw>::tryAddMasking(Type T,
2664 vector<RegisterBlock> &layout, bool remainderR, bool remainderC,
2665 const MatrixAddressing &atype,
2666 const MatrixAddressingStrategy &astrategy) {
2667 auto layoutNew = layout;
2668 for (auto &block : layoutNew) {
2669 if (!tryAddMasking(T, block, remainderR, remainderC, atype, astrategy))
2670 return false;
2671 }
2672 std::swap(layout, layoutNew);
2673 return true;
2674}
2675
2676template <HW hw>
2677void gemm_kernel_generator_t<hw>::addMasking(Type T,
2678 vector<RegisterBlock> &layout, bool remainderR, bool remainderC,
2679 const MatrixAddressing &atype,
2680 const MatrixAddressingStrategy &astrategy) {
2681 for (auto &block : layout)
2682 if (!tryAddMasking(T, block, remainderR, remainderC, atype, astrategy))
2683 stub();
2684}
2685
2686template <HW hw>
2687void gemm_kernel_generator_t<hw>::addMasking(Type T,
2688 vector<RegisterBlock> &layout, vector<GRFRange> &addrs,
2689 const Subregister &ld, bool remainderR, bool remainderC,
2690 const MatrixAddressing &atype,
2691 const MatrixAddressingStrategy &astrategy,
2692 const CommonStrategy &strategy, CommonState &state, int dataRegs) {
2693 // Check if masking can be trivially enabled without changing the layout.
2694 if (tryAddMasking(T, layout, remainderR, remainderC, atype, astrategy))
2695 return;
2696
2697 // If not, tear down the old layout and create a new one in its place, recalculating address registers.
2698 vector<RegisterBlock> layoutNew;
2699 int r, c;
2700 bool remR = remainderR || hasRemainders(layout, true, false);
2701 bool remC = remainderC || hasRemainders(layout, false, true);
2702 getLayoutDims(layout, r, c);
2703 if (!getRegLayout(T, layoutNew, r, c, remR, remC, false, true, 0, 0, atype,
2704 astrategy))
2705 stub();
2706 if (dataRegs < 0) dataRegs = getRegCount(layout);
2707 if (getRegCount(layoutNew) > dataRegs) stub();
2708 if (isLayoutColMajor(layoutNew) != isLayoutColMajor(layout)) stub();
2709
2710 int shift = 0;
2711 auto addr0 = getOriginAddr(layout, addrs, atype, astrategy, &shift);
2712 std::swap(layout, layoutNew);
2713 if (shift > 0) shl(1, addr0, addr0, shift);
2714 safeReleaseRanges(addrs, state);
2715 state.ra.claim(addr0);
2716
2717 Address2DParams params2D {};
2718 if (astrategy.address2D) stub();
2719 allocAddrRegs(addrs, layout, atype, astrategy, state);
2720 setupAddr(T, addrs, addr0, layout, ld, atype, astrategy, strategy, state,
2721 params2D);
2722
2723 state.ra.safeRelease(addr0);
2724}
2725
2726template <HW hw>
2727bool gemm_kernel_generator_t<hw>::getSubblock(Type T, RegisterBlock &blockDst,
2728 const RegisterBlock &blockSrc, bool column, int x1, int x2,
2729 int x1Unclamped, int x2Unclamped, bool overrunOK,
2730 const MatrixAddressing &atype,
2731 const MatrixAddressingStrategy &astrategy) {
2732 auto effAccessType = effectiveAccessType(atype, astrategy, blockSrc);
2733 blockDst = blockSrc;
2734
2735 auto &ns = (column ? blockDst.nc : blockDst.nr);
2736 auto &nt = (column ? blockDst.nr : blockDst.nc);
2737 int oldNS = ns;
2738
2739 (column ? blockDst.offsetC : blockDst.offsetR) += x1;
2740 ns = x2 - x1;
2741
2742 if ((ns == oldNS) && (overrunOK || !blockSrc.hasNoLoad)) return true;
2743
2744 if (blockSrc.colMajor == column) {
2745 if (x1 % blockSrc.crosspack) return false;
2746
2747 blockDst.offsetBytes += (x1 * blockSrc.bytes) / oldNS;
2748
2749 if (blockSrc.isLoadBlock()) switch (effAccessType) {
2750 case AccessType::Scattered:
2751 case AccessType::ChannelScattered:
2752 blockDst.count = x2 - x1;
2753 if (blockDst.ebytes == 1)
2754 blockDst.count *= T.size();
2755 else if (blockDst.splitComplex)
2756 blockDst.count *= 2;
2757 else if (T.size() < blockDst.ebytes) {
2758 // Extra alignment path with small types.
2759 // Check to see if we can still use this element size,
2760 // if not downgrade to scattered byte.
2761 // Note for surface accesses this requires shifting the addresses back.
2762 auto bcount = blockDst.count * T;
2763 if (bcount % 4) {
2764 blockDst.ebytes = 1;
2765 blockDst.addrShift = 0;
2766 blockDst.count = bcount;
2767 if (blockDst.count > 4) stub();
2768 } else
2769 blockDst.count = bcount >> 2;
2770 }
2771 break;
2772 case AccessType::Block:
2773 case AccessType::PseudoBlock: {
2774 auto offBytes = x1 * nt * T;
2775 if (offBytes % blockDst.ebytes) return false;
2776 auto reqBytes = (x2 - x1) * nt * T;
2777 auto align = std::min<int>(
2778 blockDst.ebytes, blockDst.simdSize * 4);
2779 if (!overrunOK && (reqBytes & (align - 1))) return false;
2780 auto ncount = div_up(reqBytes, blockDst.ebytes);
2781 auto count = roundup_pow2(ncount);
2782 if (!overrunOK && (count != ncount)) return false;
2783 if (effAccessType == AccessType::Block)
2784 blockDst.count = count;
2785 else
2786 blockDst.simdSize = std::max(1, count / blockDst.count);
2787 break;
2788 }
2789 case AccessType::Block2D: break;
2790 case AccessType::Block2DTranspose:
2791 case AccessType::Block2DVNNI:
2792 int crosspack = std::max(1, 4 / blockDst.ebytes);
2793 if (x1 % crosspack || x2 % crosspack) return false;
2794 break;
2795 }
2796
2797 blockDst.calcBytes(T, astrategy);
2798 } else {
2799 blockDst.offsetBytes += x1 * T * blockSrc.crosspack;
2800
2801 if (blockSrc.isLoadBlock()) switch (effAccessType) {
2802 case AccessType::Block:
2803 case AccessType::PseudoBlock: {
2804 // Update count and mask information.
2805 // Beware, cheat: with DW-aligned sub-DW types, true block may be downgraded to byte PseudoBlock,
2806 // which requires 2 address registers, though only 1 is used, and only 1 may be allocated.
2807 int rblock, cblock;
2808 (void)getBlockInfo(T, atype, astrategy, blockDst.nr,
2809 blockDst.nc, blockDst.remainderR,
2810 blockDst.remainderC, blockDst.writable, false, 0, 0,
2811 rblock, cblock, blockDst);
2812 blockDst.simplify(T);
2813 break;
2814 }
2815 case AccessType::Scattered:
2816 case AccessType::ChannelScattered: {
2817 if (T.size() > blockDst.ebytes) return false;
2818 if (x1 != 0) return false;
2819 if (!is_zero_or_pow2(x2)) return false;
2820
2821 blockDst.simdSize = div_up(ns * T, blockDst.ebytes);
2822
2823 auto minSIMD = minScatteredSIMD(hw, astrategy);
2824 if (blockDst.simdSize <= minSIMD
2825 && blockSrc.simdSize > minSIMD) {
2826 if (blockDst.count > 1 && blockDst.ebytes > 1)
2827 return false;
2828 blockDst.ld >>= 1;
2829 }
2830 break;
2831 }
2832 case AccessType::Block2D:
2833 case AccessType::Block2DTranspose:
2834 case AccessType::Block2DVNNI:
2835 if (ns != oldNS)
2836 stub(); // Can do this, but not implemented.
2837 if (blockDst.simdSize != 0) // Recompute block array length.
2838 blockDst.count = div_up(x2Unclamped,
2839 isColMajor(atype.layout) ? blockDst.nr
2840 : blockDst.nc);
2841 // TODO: need to recompute ld
2842 break;
2843 }
2844
2845 blockDst.calcBytes(T, astrategy);
2846 }
2847
2848 return true;
2849}
2850
2851// Get list of subblocks intersecting rows/columns [x1, x2).
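// E.g., with column range [x1, x2) = [8, 16), a block covering columns 4-11
// contributes its local columns [4, 8) (global columns 8-11) as a subblock.
// A minimal usage sketch (hypothetical values):
//     vector<RegisterBlock> sub;
//     getSubblocks(T, sub, layout, /*column=*/true, 8, 16, /*overrunOK=*/false,
//             atype, astrategy);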
2852template <HW hw>
2853bool gemm_kernel_generator_t<hw>::getSubblocks(Type T,
2854 vector<RegisterBlock> &sublayout, const vector<RegisterBlock> &layout,
2855 bool column, int x1, int x2, bool overrunOK,
2856 const MatrixAddressing &atype,
2857 const MatrixAddressingStrategy &astrategy) {
2858 auto RegisterBlock::*nq = column ? &RegisterBlock::nc : &RegisterBlock::nr;
2859 auto RegisterBlock::*offsetQ
2860 = column ? &RegisterBlock::offsetC : &RegisterBlock::offsetR;
2861
2862 sublayout.clear();
2863
2864 for (auto &block : layout) {
2865 int qq1Unclamped = x1 - block.*offsetQ;
2866 int qq2Unclamped = x2 - block.*offsetQ;
2867 int qq1 = clamp<int>(qq1Unclamped, 0, block.*nq);
2868 int qq2 = clamp<int>(qq2Unclamped, 0, block.*nq);
2869 if (qq2 > qq1) {
2870 RegisterBlock subblock;
2871 if (!getSubblock(T, subblock, block, column, qq1, qq2, qq1Unclamped,
2872 qq2Unclamped, overrunOK, atype, astrategy)) {
2873 status << "Could not make subblock." << status_stream::endl;
2874 return false;
2875 }
2876 sublayout.push_back(subblock);
2877 }
2878 }
2879 return true;
2880}
2881
2882// Get list of subblocks intersecting rows/columns [x1, x2), and associated address registers and/or indices.
2883// Returns false if fragmenting failed, or an address register doesn't match a previous one.
2884template <HW hw>
2885bool gemm_kernel_generator_t<hw>::getSubblocks(Type T,
2886 vector<RegisterBlock> &sublayout, vector<GRFRange> *subaddrs,
2887 vector<int> *indices, const vector<RegisterBlock> &layout,
2888 const vector<GRFRange> *addrs, bool column, int x1, int x2,
2889 bool overrunOK, const MatrixAddressing &atype,
2890 const MatrixAddressingStrategy &astrategy) {
2891 auto RegisterBlock::*nq = column ? &RegisterBlock::nc : &RegisterBlock::nr;
2892 auto RegisterBlock::*offsetQ
2893 = column ? &RegisterBlock::offsetC : &RegisterBlock::offsetR;
2894
2895 if (subaddrs) subaddrs->clear();
2896 if (indices) indices->clear();
2897 sublayout.clear();
2898
2899 for (int b = 0; b < int(layout.size()); b++) {
2900 auto &block = layout[b];
2901 int qq1Unclamped = x1 - block.*offsetQ;
2902 int qq2Unclamped = x2 - block.*offsetQ;
2903 int qq1 = clamp<int>(qq1Unclamped, 0, block.*nq);
2904 int qq2 = clamp<int>(qq2Unclamped, 0, block.*nq);
2905 if (qq2 > qq1) {
2906 RegisterBlock subblock;
2907 if (!getSubblock(T, subblock, block, column, qq1, qq2, qq1Unclamped,
2908 qq2Unclamped, overrunOK, atype, astrategy)) {
2909 status << "Could not make subblock." << status_stream::endl;
2910 return false;
2911 }
2912 if (subblock.offsetR != block.offsetR
2913 || subblock.offsetC != block.offsetC) {
2914 status << "Subblock is not aligned to parent block."
2915 << status_stream::endl;
2916 return false;
2917 }
2918 if (subaddrs) subaddrs->push_back((*addrs)[b]);
2919 if (indices) indices->push_back(int(b));
2920 sublayout.push_back(subblock);
2921 }
2922 }
2923 return true;
2924}
2925
2926// Get list of subblocks intersecting rows/columns [x1, x2), and associated address registers.
2927// Returns false if fragmenting failed, or an address register doesn't match a previous one.
2928template <HW hw>
2929bool gemm_kernel_generator_t<hw>::getSubblocks(Type T,
2930 vector<RegisterBlock> &sublayout, vector<GRFRange> &subaddrs,
2931 const vector<RegisterBlock> &layout, const vector<GRFRange> &addrs,
2932 bool column, int x1, int x2, bool overrunOK,
2933 const MatrixAddressing &atype,
2934 const MatrixAddressingStrategy &astrategy) {
2935 return getSubblocks(T, sublayout, &subaddrs, nullptr, layout, &addrs,
2936 column, x1, x2, overrunOK, atype, astrategy);
2937}
2938
2939// Get list of subblocks intersecting rows/columns [x1, x2), and indices of associated address registers.
2940// Returns false if fragmenting failed, or an address register doesn't match a previous one.
2941template <HW hw>
2942bool gemm_kernel_generator_t<hw>::getSubblocks(Type T,
2943 vector<RegisterBlock> &sublayout, vector<int> &indices,
2944 const vector<RegisterBlock> &layout, bool column, int x1, int x2,
2945 bool overrunOK, const MatrixAddressing &atype,
2946 const MatrixAddressingStrategy &astrategy) {
2947 return getSubblocks(T, sublayout, nullptr, &indices, layout, nullptr,
2948 column, x1, x2, overrunOK, atype, astrategy);
2949}
2950
2951// Adjust address registers as needed for a newly-created subblock.
2952template <HW hw>
2953void gemm_kernel_generator_t<hw>::adjustSubblockAddrs(Type T,
2954 const vector<RegisterBlock> &sublayout,
2955 const vector<GRFRange> &subaddrs, const vector<RegisterBlock> &layout,
2956 const vector<GRFRange> &addrs, const MatrixAddressing &atype,
2957 const MatrixAddressingStrategy &astrategy,
2958 const CommonStrategy &strategy, const CommonState &state) {
2959 bool a64 = (astrategy.base.getModel() == ModelA64);
2960
2961 auto nsubs = int(sublayout.size());
2962 auto nblocks = int(layout.size());
2963
2964 for (int isub = 0; isub < nsubs; isub++) {
2965 // Find parent block by comparing address registers.
2966 auto &subaddr = subaddrs[isub];
2967 const RegisterBlock *pptr = nullptr;
2968 for (int i = 0; i < nblocks; i++) {
2969 if (addrs[i].getBase() == subaddr.getBase()) {
2970 pptr = &layout[i];
2971 break;
2972 }
2973 }
2974 if (!pptr) stub();
2975
2976 auto &block = *pptr;
2977 auto &subblock = sublayout[isub];
2978
2979 auto off = getAddr0Offset(block, atype, astrategy);
2980 auto suboff = getAddr0Offset(subblock, atype, astrategy);
2981
2982 // Perform any necessary shifts/moves. Moves are only for non-A64 block->pseudoblock settings.
2983 if (suboff != off) {
2984 if (subblock.simdSize != 1)
2985 stub(); // Need to prepare more pseudoblock addresses.
2986 mov<uint32_t>(1, subaddr[0][suboff], subaddr[0][off]);
2987 }
2988 if (subblock.addrShift != block.addrShift) {
2989 map(hw, a64 ? Type::u64 : Type::u32, subaddr, subaddr, strategy,
2990 [&](int simd, GRF r, GRF _) {
2991 auto shift = block.addrShift - subblock.addrShift;
2992 (shift > 0) ? eshl(simd, r, r, +shift, strategy, state)
2993 : eshr(simd, r, r, -shift, strategy, state);
2994 });
2995 }
2996
2997 if (isBlock2D(astrategy.accessType)) {
2998 // Adjust 2D block header as needed.
2999 int bw, bh;
3000 bool memCM = isColMajor(atype.layout);
3001 auto RegisterBlock::*nw
3002 = memCM ? &RegisterBlock::nr : &RegisterBlock::nc;
3003 auto RegisterBlock::*nh
3004 = memCM ? &RegisterBlock::nc : &RegisterBlock::nr;
3005 bool remW = memCM ? subblock.remainderR : subblock.remainderC;
3006 bool remH = memCM ? subblock.remainderC : subblock.remainderR;
3007 getBlock2DWH(bw, bh, atype, subblock);
3008
3009 if (!astrategy.address2D) {
3010 if (subblock.*nw != block.*nw
3011 || subblock.count != block.count) {
3012 int newW = bw * subblock.count * subblock.ebytes - 1;
3013 remW ? min_(1, subaddr[0].ud(2), subaddr[0].ud(2), newW)
3014 : mov(1, subaddr[0].ud(2), newW);
3015 }
3016 if (subblock.*nh != block.*nh) {
3017 int newH = bh * subblock.ebytes - 1;
3018 remH ? min_(1, subaddr[0].ud(3), subaddr[0].ud(3), newH)
3019 : mov(1, subaddr[0].ud(3), newH);
3020 }
3021 }
3022 if (subblock.nr != block.nr || subblock.nc != block.nc
3023 || subblock.count != block.count)
3024 mov(1, subaddr[0].ud(7),
3025 (bw - 1) | ((bh - 1) << 8)
3026 | ((subblock.count - 1) << 16));
3027 }
3028 }
3029}
3030
3031// Split 2D block array loads into multiple blocks.
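// E.g., a 2D load with array length count = 4 is split into 4 RegisterBlocks;
// only the first keeps a nonzero simdSize, since a single message fills all four.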
3032static inline void postprocessLayout2D(vector<RegisterBlock> &layout,
3033 const MatrixAddressing &atype,
3034 const MatrixAddressingStrategy &astrategy) {
3035 if (!isBlock2D(astrategy.accessType)) return;
3036
3037 int maxCount = 1;
3038 for (auto &block : layout)
3039 maxCount = std::max(maxCount, int(block.count));
3040 if (maxCount == 1) return;
3041
3042 vector<RegisterBlock> xlayout;
3043 xlayout.reserve(layout.size() * maxCount);
3044
3045 bool memCM = isColMajor(atype.layout);
3046 auto RegisterBlock::*nx = memCM ? &RegisterBlock::nr : &RegisterBlock::nc;
3047 auto RegisterBlock::*offsetX
3048 = memCM ? &RegisterBlock::offsetR : &RegisterBlock::offsetC;
3049
3050 for (auto &block : layout) {
3051 auto nblock = block;
3052 nblock.*nx /= block.count;
3053 if (!isTransposing(astrategy.accessType)) nblock.ld /= block.count;
3054
3055 for (int i = 0; i < block.count; i++) {
3056 xlayout.push_back(nblock);
3057 nblock.*offsetX += nblock.*nx;
3058 nblock.simdSize = 0; // Blocks > 0 do not need loads.
3059 }
3060 }
3061
3062 std::swap(layout, xlayout);
3063}
3064
3065// Split blocks that span multiple tiles. Requires each tile to be contained within a single block.
3066static inline void postprocessLayoutMultitile(Type T,
3067 vector<RegisterBlock> &layout, const MatrixAddressing &atype,
3068 const MatrixAddressingStrategy &astrategy) {
3069 if (!atype.tileR || !atype.tileC) return;
3070 if (isLargeCrosspack(T, atype.crosspack)) return;
3071
3072 bool needToSplit = false;
3073 for (const auto &block : layout)
3074 needToSplit |= (block.colMajor ? (block.nr > atype.tileR)
3075 : (block.nc > atype.tileC));
3076
3077 if (!needToSplit) return;
3078
3079 vector<RegisterBlock> xlayout;
3080 xlayout.reserve(layout.size());
3081
3082 for (const auto &block : layout) {
3083 auto nx = block.colMajor ? &RegisterBlock::nr : &RegisterBlock::nc;
3084 auto ny = block.colMajor ? &RegisterBlock::nc : &RegisterBlock::nr;
3085 auto offsetX = block.colMajor ? &RegisterBlock::offsetR
3086 : &RegisterBlock::offsetC;
3087 auto offsetY = block.colMajor ? &RegisterBlock::offsetC
3088 : &RegisterBlock::offsetR;
3089 auto tileX = block.colMajor ? atype.tileR : atype.tileC;
3090 auto tileY = block.colMajor ? atype.tileC : atype.tileR;
3091
3092 if (block.*nx == tileX) {
3093 xlayout.push_back(block);
3094 continue;
3095 }
3096
3097 if (block.*nx % tileX || block.*offsetX % tileX || block.*ny % tileY
3098 || block.*offsetY % tileY)
3099 stub();
3100 if (isTransposing(astrategy.accessType)) stub();
3101
3102 auto nblock = block;
3103 nblock.*nx = tileX;
3104 nblock.*ny = tileY;
3105 nblock.ld = tileX;
3106
3107 for (int j = 0; j < block.*ny / tileY; j++) {
3108 for (int i = 0; i < block.*nx / tileX; i++) {
3109 nblock.*offsetX = block.*offsetX + i * tileX;
3110 nblock.*offsetY = block.*offsetY + j * tileY;
3111 xlayout.push_back(nblock);
3112 nblock.simdSize = 0;
3113 }
3114 }
3115 }
3116
3117 std::swap(layout, xlayout);
3118}
3119
3120// Split large crosspack blocks into smaller pieces so that they can be transposed.
3121static inline void postprocessLayoutLargeCP(Type T,
3122 vector<RegisterBlock> &layout, const MatrixAddressing &atype,
3123 const MatrixAddressingStrategy &astrategy) {
3124 if (!isLargeCrosspack(T, atype.crosspack)) return;
3125
3126 bool haveLargeCP = false;
3127 for (const auto &block : layout) {
3128 haveLargeCP |= isLargeCrosspack(T, block.crosspack);
3129 if (haveLargeCP) break;
3130 }
3131
3132 if (!haveLargeCP) return;
3133
3134 vector<RegisterBlock> xlayout;
3135 xlayout.reserve(layout.size());
3136
3137 for (const auto &block : layout) {
3138 if (!isLargeCrosspack(T, block.crosspack))
3139 xlayout.push_back(block);
3140 else {
3141 auto ny = block.colMajor ? &RegisterBlock::nc : &RegisterBlock::nr;
3142 auto offsetY = block.colMajor ? &RegisterBlock::offsetC
3143 : &RegisterBlock::offsetR;
3144
3145 if (block.*ny % block.crosspack) return;
3146 int blocks = (block.*ny / block.crosspack);
3147 auto nblock = block;
3148 nblock.*ny = block.crosspack;
3149 nblock.simplify(T);
3150 for (int i = 0; i < blocks; i++) {
3151 xlayout.push_back(nblock);
3152 nblock.simdSize = 0;
3153 nblock.*offsetY += nblock.*ny;
3154 }
3155 }
3156 }
3157
3158 std::swap(layout, xlayout);
3159}
3160
3161// Remove unneeded blocks from a dpasw src2 layout.
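// E.g., with a column-major src2 layout and tileC = 8, only blocks whose column
// offset modulo 16 falls in [0, 8) are kept; the other half of src2 is supplied
// by the paired fused thread.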
3162static inline void postprocessLayoutDPASW(vector<RegisterBlock> &layout,
3163 const MatrixAddressing &atype,
3164 const MatrixAddressingStrategy &astrategy) {
3165 if (!astrategy.dpasw) return;
3166
3167 vector<RegisterBlock> nlayout;
3168 nlayout.reserve(layout.size() / 2);
3169
3170 bool cm = isLayoutColMajor(layout);
3171 auto tile = cm ? astrategy.tileC : astrategy.tileR;
3172 auto offsetX = cm ? &RegisterBlock::offsetC : &RegisterBlock::offsetR;
3173
3174 for (const auto &block : layout)
3175 if ((block.*offsetX % (2 * tile)) < tile) nlayout.push_back(block);
3176
3177 layout = std::move(nlayout);
3178}
3179
3180static inline void postprocessLayout(Type T, vector<RegisterBlock> &layout,
3181 const MatrixAddressing &atype,
3182 const MatrixAddressingStrategy &astrategy) {
3183 postprocessLayout2D(layout, atype, astrategy);
3184 postprocessLayoutMultitile(T, layout, atype, astrategy);
3185 postprocessLayoutLargeCP(T, layout, atype, astrategy);
3186 postprocessLayoutDPASW(layout, atype, astrategy);
3187}
3188
3189// Add a submatrix to a register layout.
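// E.g., a 24x8 request that getBlockInfo tiles into 16x8 blocks leaves an 8x8
// row remainder, which is added by a recursive call with roff = 16.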
3190template <HW hw>
3191bool gemm_kernel_generator_t<hw>::addToRegLayout(Type T,
3192 std::vector<RegisterBlock> &layout, int nr, int nc, int roff, int coff,
3193 bool remainderR, bool remainderC, bool writable, bool avoidFragment,
3194 int maxRBlock, int maxCBlock, const MatrixAddressing &atype,
3195 const MatrixAddressingStrategy &astrategy) {
3196 int rblock, cblock;
3197 RegisterBlock blockTemplate;
3198 if (!getBlockInfo(T, atype, astrategy, nr, nc, remainderR, remainderC,
3199 writable, avoidFragment, maxRBlock, maxCBlock, rblock, cblock,
3200 blockTemplate))
3201 return false; /* Cannot handle requested block and remainder. */
3202
3203 if (rblock == 0 || cblock == 0) return false;
3204
3205 blockTemplate.nr = rblock;
3206 blockTemplate.nc = cblock;
3207
3208 for (int q = 0; q < T.components(); q++) {
3209 blockTemplate.component = q;
3210 if (isColMajor(atype.layout)) {
3211 // Order blocks in column-major fashion.
3212 for (int c = 0; c + cblock <= nc; c += cblock) {
3213 for (int r = 0; r + rblock <= nr; r += rblock) {
3214 auto thisBlock = blockTemplate;
3215
3216 thisBlock.offsetR = r + roff;
3217 thisBlock.offsetC = c + coff;
3218
3219 layout.push_back(thisBlock);
3220 }
3221 }
3222 } else {
3223 // Order blocks in row-major fashion.
3224 for (int r = 0; r + rblock <= nr; r += rblock) {
3225 for (int c = 0; c + cblock <= nc; c += cblock) {
3226 auto thisBlock = blockTemplate;
3227
3228 thisBlock.offsetR = r + roff;
3229 thisBlock.offsetC = c + coff;
3230
3231 layout.push_back(thisBlock);
3232 }
3233 }
3234 }
3235 }
3236
3237 // Handle remainder recursively, checking for infinite recursion.
3238 int rrem = nr % rblock;
3239 int crem = nc % cblock;
3240
3241 status << "Register layout: " << nr << 'x' << nc << " -> blocks " << rblock
3242 << 'x' << cblock << " remainder " << rrem << 'x' << crem
3243 << status_stream::endl;
3244
3245 bool success = true;
3246 if (rrem || crem) {
3247 if ((nr == rrem || rrem == 0) && (nc == crem || crem == 0)) {
3248 status << "Cannot load/store requested matrix block size."
3249 << status_stream::endl;
3250 success = false;
3251 } else {
3252 if (rrem)
3253 success &= addToRegLayout(T, layout, rrem, nc - crem, nr - rrem,
3254 0, remainderR, remainderC, writable, avoidFragment,
3255 maxRBlock, maxCBlock, atype, astrategy);
3256 if (crem)
3257 success &= addToRegLayout(T, layout, nr, crem, 0, nc - crem,
3258 remainderR, remainderC, writable, avoidFragment,
3259 maxRBlock, maxCBlock, atype, astrategy);
3260 }
3261 }
3262 return success;
3263}
3264
3265// Add a submatrix (contiguous in memory) to a block-accessed register layout.
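// For illustration: a 256-byte, 32-byte-aligned contiguous region loaded with
// legacy A64 block messages typically becomes a single block with ebytes = 32,
// count = 8, and simdSize = 16.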
3266template <HW hw>
3267bool gemm_kernel_generator_t<hw>::add1DBlockToRegLayout(Type T,
3268 vector<RegisterBlock> &layout, int r, int c, bool writable,
3269 const MatrixAddressing &atype,
3270 const MatrixAddressingStrategy &astrategy) {
3271 // Skip pseudoblock cases (possible to support though)
3272 if (needsPseudoblock(hw, T, r, c, atype, astrategy, writable, false))
3273 return false;
3274
3275 // Get total number of bytes to load. No masking is supported, so fail if the
3276 // byte count isn't divisible by the block alignment (16 bytes = 1 oword; 4 bytes with new dataport messages).
3277 int nbytes = r * c * T * T.components();
3278 int align = 16;
3279 if (astrategy.newDP) align = 4;
3280
3281 if (nbytes & (align - 1)) return false;
3282
3283 // Get block info.
3284 int maxBBytes = 0;
3285 int ebytes = 0;
3286 int extra = 0;
3287 int addrShift = 0;
3288 int maxSIMD = 1;
3289
3290 if (astrategy.newDP) {
3291 bool qword = (nbytes | atype.alignment) % 8 == 0;
3292 ebytes = qword ? 8 : 4;
3293 maxBBytes = ebytes * 64;
3294 } else {
3295 bool a64 = (astrategy.base.getModel() == ModelA64);
3296 bool oword = !a64;
3297 bool aoword = (astrategy.base.getModel()
3298 == ModelSC); // SC only does aligned oword
3299 if (hw >= HW::XeHP) oword |= ((atype.alignment & 0x1F) != 0);
3300
3301 extra = aoword;
3302 ebytes = oword ? 16 : 32;
3303 maxBBytes = oword ? 128 : 256;
3304 if (astrategy.base.getModel() == ModelSLM && hw >= HW::XeHP)
3305 maxBBytes = 256;
3306 addrShift = (!a64 && oword && !aoword) ? 4 : 0;
3307 maxSIMD = 16;
3308 }
3309
3310 // Get normalized dimensions.
3311 bool colMajor = isColMajor(atype.layout);
3312 int x = colMajor ? r : c;
3313 auto crosspack = atype.crosspack;
3314
3315 // Counters for current x and y positions.
3316 int cx = 0, cy = 0;
3317
3318 while (nbytes > 0) {
3319 // Carve out the largest chunk possible.
3320 int bbytes = std::min<int>(maxBBytes, rounddown_pow2(nbytes));
3321 int belems = bbytes / T;
3322
3323 // Create a true load block for first (possibly partial) row/column.
3324 // Then, create additional no-load blocks for any further (possibly partial)
3325 // rows/columns until block is exhausted.
3326 bool first = true;
3327 while (belems > 0) {
3328 int nxRem = belems / crosspack;
3329 int nx = std::min<int>(nxRem, x - cx);
3330 if (nx <= 0) stub();
3331 if (cy % crosspack) return false;
3332
3333 RegisterBlock block;
3334
3335 block.ld = nx;
3336 (colMajor ? block.nr : block.nc) = nx;
3337 (colMajor ? block.nc : block.nr) = crosspack;
3338 (colMajor ? block.offsetR : block.offsetC) = cx;
3339 (colMajor ? block.offsetC : block.offsetR) = cy;
3340 block.component = 0;
3341 block.colMajor = colMajor;
3342 block.splitComplex = false;
3343 block.cxComponent = RegisterBlock::Interleaved;
3344
3345 if (first) {
3346 block.ebytes = ebytes;
3347 block.count = div_up(bbytes, ebytes);
3348 block.simdSize = std::min(maxSIMD, roundup_pow2(bbytes) >> 2);
3349 } else
3350 block.ebytes = block.count = block.simdSize = 0;
3351
3352 block.extra = extra;
3353 block.clearFlag();
3354 block.colMask = MaskInfo::None();
3355 block.rowMask = MaskInfo::None();
3356 block.colFragment = 0;
3357 block.rowFragment = 0;
3358 block.log2GRFBytes = GRF::log2Bytes(hw);
3359
3360 block.crosspack = crosspack;
3361 block.remainderR = false;
3362 block.remainderC = false;
3363 block.noRowsOK = false;
3364 block.noColsOK = false;
3365 block.descRemR = false;
3366 block.descRemC = false;
3367 block.descAssigned = false;
3368 block.addrShift = addrShift;
3369 block.hasNoLoad = false;
3370 block.msgRegs = std::max(1, bbytes >> GRF::log2Bytes(hw));
3371
3372 if (first && cx == 0 && (nxRem % x) == 0) {
3373 // Shortcut: one register block can represent this block access.
3374 int ny = belems / x;
3375 (colMajor ? block.nc : block.nr) = ny;
3376 cy += ny;
3377 belems = 0;
3378 } else {
3379 cx += nx;
3380 belems -= nx * crosspack;
3381 if (cx == x) {
3382 cy += crosspack;
3383 cx = 0;
3384 }
3385 block.hasNoLoad = first && (belems > 0);
3386 first = false;
3387 }
3388
3389 layout.push_back(block);
3390 }
3391
3392 nbytes -= bbytes;
3393 }
3394
3395 return true;
3396}
3397
3398static inline int getPartialCrosspack(size_t sizeofT,
3399 const MatrixAddressing &atype, const RegisterBlock &block) {
3400 if (block.ebytes == 1 && !isLargeCrosspack(sizeofT, atype.crosspack))
3401 return div_up(atype.crosspack, block.colMajor ? block.nc : block.nr);
3402 else
3403 return 1;
3404}
3405
3406// Get linear element offset in tiled layout (both register and memory)
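// For an untiled column-major matrix with crosspack 1, this reduces to the
// familiar i + j * r (plus an r * c offset per component).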
3407static int untile(Type T, const MatrixAddressing &atype, int component, int i,
3408 int j, int r, int c, int tileR, int tileC, bool reverse = false) {
3409 bool cm = isColMajor(atype.layout) ^ reverse;
3410
3411 if (isPacked(atype.layout)) (cm ? r : c) = atype.packSize;
3412
3413 int cpR = cm ? 1 : atype.crosspack;
3414 int cpC = cm ? atype.crosspack : 1;
3415
3416 if (tileR == 0) tileR = r;
3417 if (tileC == 0) tileC = c;
3418
3419 int rstride = cm ? tileC : c;
3420 int cstride = cm ? r : tileR;
3421 int rtstride = cm ? cpC : tileC;
3422 int ctstride = cm ? tileR : cpR;
3423
3424 rstride *= T.components();
3425 cstride *= T.components();
3426
3427 int iTile = i % tileR;
3428 int jTile = j % tileC;
3429 i -= iTile;
3430 j -= jTile;
3431 int iCP = iTile % cpR;
3432 int jCP = jTile % cpC;
3433 iTile -= iCP;
3434 jTile -= jCP;
3435 int idx = i * rstride + j * cstride + tileR * tileC * component
3436 + iTile * rtstride + jTile * ctstride + iCP + jCP;
3437 return idx;
3438}
3439
3440static int untile(Type T, const MatrixAddressing &atype,
3441 const RegisterBlock &block, int r, int c, int tileR, int tileC,
3442 bool reverse = false) {
3443 return untile(T, atype, block.component, block.offsetR, block.offsetC, r, c,
3444 tileR, tileC, reverse);
3445}
3446
3447static int untile(Type T, const MatrixAddressing &atype,
3448 const RegisterBlock &block, int r, int c, bool reverse = false) {
3449 return untile(T, atype, block, r, c, atype.tileR, atype.tileC, reverse);
3450}
3451
3452static int untile(Type T, const MatrixAddressing &atype, int component, int i,
3453 int j, int r, int c, bool reverse = false) {
3454 return untile(
3455 T, atype, component, i, j, r, c, atype.tileR, atype.tileC, reverse);
3456}
3457
3458// Split A/B matrix between threads.
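// CoopSplit::K splits evenly along the k dimension, CoopSplit::MN along m/n.
// CoopSplit::Linear divides the elements evenly in linear (memory) order,
// falling back through progressively finer granularities: whole k tiles,
// then m/n tiles, then crosspack units.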
3459static inline void coopSplit(bool isA, int &splitR, int &splitC, int r, int c,
3460 CoopSplit stype, int threads, const MatrixAddressing &atype) {
3461 auto &mn = isA ? r : c;
3462 auto &k = isA ? c : r;
3463 auto &splitMN = isA ? splitR : splitC;
3464 auto &splitK = isA ? splitC : splitR;
3465 auto tileMN = isA ? atype.tileR : atype.tileC;
3466 auto tileK = isA ? atype.tileC : atype.tileR;
3467
3468 bool ok = false;
3469
3470 switch (stype) {
3471 case CoopSplit::K:
3472 ok = (k % threads == 0);
3473 splitMN = mn;
3474 splitK = k / threads;
3475 break;
3476 case CoopSplit::MN:
3477 ok = (mn % threads == 0);
3478 splitMN = mn / threads;
3479 splitK = k;
3480 break;
3481 case CoopSplit::Linear: {
3482 int elems = r * c;
3483 ok = (elems % threads == 0);
3484 int selems = elems / threads;
3485 int cp = atype.crosspack;
3486
3487 if (!tileK) tileK = k;
3488 if (!tileMN) tileMN = mn;
3489
3490 // First try splitting into tiles in k dimension.
3491 if (selems >= (tileK * mn)) {
3492 ok &= (selems % (tileK * mn) == 0);
3493 splitMN = mn;
3494 splitK = k / threads;
3495 break;
3496 }
3497
3498 ok &= (threads % (k / tileK) == 0);
3499 if (!ok) break;
3500 threads /= (k / tileK);
3501
3502 // Then try splitting into tiles in m/n dimensions as well.
3503 if (selems >= (tileK * tileMN)) {
3504 ok &= (selems % (tileK * tileMN) == 0);
3505 splitMN = mn / threads;
3506 splitK = tileK;
3507 break;
3508 }
3509
3510 ok &= (threads % (mn / tileMN) == 0);
3511 if (!ok) break;
3512 threads /= (mn / tileMN);
3513
3514 // Then try splitting each tile in the k dimension.
3515 if (selems >= (cp * tileMN)) {
3516 ok &= (selems % (cp * tileMN) == 0);
3517 splitMN = tileMN;
3518 splitK = tileK / threads;
3519 break;
3520 }
3521
3522 ok &= (threads % (tileK / cp) == 0);
3523 if (!ok) break;
3524 threads /= (tileK / cp);
3525
3526 // Finally try splitting in the m/n dimensions.
3527 ok &= (selems % cp == 0);
3528 splitMN = tileMN / threads;
3529 splitK = cp;
3530 break;
3531 }
3532 }
3533
3534 if (!ok)
3535 throw std::runtime_error(
3536 "Cooperative operation cannot be split evenly between "
3537 "threads.");
3538}
3539
// Re-order a layout so that registers appear in the appropriate order
// (row major or column major).
3542static void sortRegLayout(Type T, vector<RegisterBlock> &layout, int r, int c,
3543 const MatrixAddressing &atype,
3544 const MatrixAddressingStrategy &astrategy, bool reverse = false) {
3545 auto order = [=](const RegisterBlock &block) {
3546 return untile(T, atype, block, r, c, astrategy.tileR, astrategy.tileC,
3547 reverse);
3548 };
3549
3550 std::sort(layout.begin(), layout.end(),
3551 [&](const RegisterBlock &b1, const RegisterBlock &b2) {
3552 return (order(b1) < order(b2));
3553 });
3554}
3555
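// Assign byte offsets to the blocks in a layout, aligning load blocks (and all
// blocks for 2D block accesses) to GRF boundaries, and simplify each block.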
3556static void finalizeLayout(HW hw, Type T, vector<RegisterBlock> &layout,
3557 const MatrixAddressing &atype,
3558 const MatrixAddressingStrategy &astrategy) {
3559 int offsetBytes = 0;
3560 for (auto &block : layout) {
3561 if (block.isLoadBlock() || isBlock2D(astrategy.accessType))
3562 offsetBytes = alignup_pow2(offsetBytes, GRF::bytes(hw));
3563 block.calcBytes(T, astrategy);
3564 block.offsetBytes = offsetBytes;
3565 offsetBytes += block.bytes;
3566 block.simplify(T);
3567 }
3568}
3569
3570// Create a register layout for a matrix.
3571template <HW hw>
3572bool gemm_kernel_generator_t<hw>::getRegLayout(Type T,
3573 vector<RegisterBlock> &layout, int r, int c, bool remainderR,
3574 bool remainderC, bool writable, bool avoidFragment, int maxRBlock,
3575 int maxCBlock, const MatrixAddressing &atype,
3576 const MatrixAddressingStrategy &astrategy, bool reverseOrder) {
3577 bool success = false;
3578
3579 layout.clear();
3580
3581 // Tiling handling.
3582 if (astrategy.tileR > 0)
3583 maxRBlock = (maxRBlock == 0) ? astrategy.tileR
3584 : gcd(int(astrategy.tileR), maxRBlock);
    if (astrategy.tileC > 0)
        maxCBlock = (maxCBlock == 0) ? astrategy.tileC
                                     : gcd(int(astrategy.tileC), maxCBlock);
3588
3589 // Two separate strategies for creating register layout:
3590 // - standard 2D partitioning
3591 // - special 1D partitioning for block access to packed inputs.
3592 if (((atype.layout == MatrixLayout::Pc && atype.packSize == r)
3593 || (atype.layout == MatrixLayout::Pr && atype.packSize == c))
3594 && (astrategy.accessType == AccessType::Block) && !remainderR
3595 && !remainderC && !atype.tileR && !atype.tileC
3596 && (maxRBlock >= r || maxRBlock == 0)
3597 && (maxCBlock >= c || maxCBlock == 0)) {
3598 success = add1DBlockToRegLayout(
3599 T, layout, r, c, writable, atype, astrategy);
3600 }
3601 if (!success) {
3602 success = addToRegLayout(T, layout, r, c, 0, 0, remainderR, remainderC,
3603 writable, avoidFragment, maxRBlock, maxCBlock, atype,
3604 astrategy);
3605 sortRegLayout(T, layout, r, c, atype, astrategy, reverseOrder);
3606 postprocessLayout(T, layout, atype, astrategy);
3607 }
3608 if (!success) return false;
3609
3610 finalizeLayout(hw, T, layout, atype, astrategy);
3611
3612 return true;
3613}
3614
3615// Create a register layout for a uniform matrix not backed by memory.
3616template <HW hw>
3617void gemm_kernel_generator_t<hw>::makeUnbackedRegLayout(Type T,
3618 vector<RegisterBlock> &layout, int r, int c, bool colMajor,
3619 int crosspack, int tileR, int tileC, bool allowPartialRegs,
3620 bool fullySplitCx) {
3621 auto block = RegisterBlock();
3622
3623 if ((colMajor ? c : r) % crosspack) stub();
3624 layout.clear();
3625
3626 if (tileR <= 0) tileR = r;
3627 if (tileC <= 0) tileC = c;
3628
3629 int offsetBytes = 0;
3630 int qCXMin = -1, qCXMax = -1;
3631
3632 for (int qCX = qCXMin; qCX <= qCXMax; qCX++) {
3633 for (int q = 0; q < T.components(); q++) {
3634 for (int i = 0; i < r; i += tileR) {
3635 for (int j = 0; j < c; j += tileC) {
3636 block.log2GRFBytes = GRF::log2Bytes(hw);
3637 block.nr = std::min(r - i, tileR);
3638 block.nc = std::min(c - j, tileC);
3639 block.ld = colMajor ? tileR : tileC;
3640 if (!allowPartialRegs)
3641 block.ld = align_up(block.ld, elementsPerGRF(hw, T));
3642 block.offsetR = i;
3643 block.offsetC = j;
3644 block.colMajor = colMajor;
3645 block.crosspack = crosspack;
3646 block.offsetBytes = offsetBytes;
3647 block.splitComplex = false;
3648 block.cxComponent = qCX;
3649 block.component = q;
3650
3651 block.calcBytes(T);
3652 offsetBytes += block.bytes;
3653
3654 block.remainderR = false;
3655 block.remainderC = false;
3656 block.simdSize = 0; // Not backed by memory.
3657
3658 layout.push_back(block);
3659 }
3660 }
3661 }
3662 }
3663}
3664
3665// Attempt to create a 2D block layout that matches an existing layout.
3666// Currently only generates regular/transpose 2D block (no VNNI support).
3667template <HW hw>
3668bool gemm_kernel_generator_t<hw>::upgradeLayoutToBlock2D(Type T,
3669 const vector<RegisterBlock> &layoutSrc, vector<RegisterBlock> &layout2D,
3670 bool remainderR, bool remainderC, bool writable,
3671 const MatrixAddressing &atype,
3672 const MatrixAddressingStrategy &astrategy) {
3673 layout2D.clear();
3674 layout2D.reserve(layoutSrc.size());
3675
3676 if (layoutSrc.empty()) return true;
3677 if (isPacked(atype.layout)) return false;
3678
3679 bool transpose = isTransposing(astrategy.accessType);
3680 bool regCM = isLayoutColMajor(layoutSrc);
3681
3682 if (transpose) {
        if (T.size() == 8) {
            if (getStepping() < SteppingPVCXTB0) return false;
        } else if (T.size() != 4)
            return false;
3687 }
3688
3689 int r0 = -1, c0 = -1, b0 = -1;
3690 int nr = 0, nc = 0;
3691 bool ok = true;
3692
3693 auto make2DBlock = [&] {
3694 if (r0 < 0 || c0 < 0) return;
3695 ok = ok
3696 && addToRegLayout(T, layout2D, nr, nc, r0, c0, remainderR,
3697 remainderC, writable, false, 0, 0, atype, astrategy);
3698 };
3699
3700 for (size_t i = 0; i < layoutSrc.size(); i++) {
3701 auto &block = layoutSrc[i];
3702 unsigned omask = GRF::bytes(hw) - 1;
3703
3704 if ((block.offsetBytes & omask) || (block.bytes & omask)) return false;
3705 if (block.nregs() > 1)
3706 return false; /* don't split block into multiple 2D block loads */
3707
3708 bool consecutive = (block.offsetBytes == (b0 + GRF::bytes(hw)));
3709 if (regCM && block.offsetC == c0 + nc && consecutive && nr == block.nr)
3710 nc++;
3711 else if (!regCM && block.offsetR == r0 + nr && consecutive
3712 && nc == block.nc)
3713 nr++;
3714 else {
3715 make2DBlock();
3716 r0 = block.offsetR;
3717 c0 = block.offsetC;
3718 nr = block.nr;
3719 nc = block.nc;
3720 }
3721 b0 = block.offsetBytes;
3722 }
3723
3724 make2DBlock();
3725
3726 int r = 0, c = 0;
3727 getLayoutDims(layoutSrc, r, c);
3728 sortRegLayout(T, layout2D, r, c, atype, astrategy);
3729 postprocessLayout(T, layout2D, atype, astrategy);
3730 finalizeLayout(hw, T, layout2D, atype, astrategy);
3731
3732 return ok;
3733}
3734
3735// Find the subregister in a RegisterBlock corresponding to element at offset (rr,cc),
3736// as well as the contiguous elements following it (nelems).
3737static Subregister findBlockReg(Type T, const RegisterBlock &block, int rr,
3738 int cc, const GRFMultirange &regs, int &nelems, int cxComponent = -1,
3739 int component = 0) {
3740 auto Te = T;
3741 const int ne = (1 << block.log2GRFBytes) / Te;
3742
3743 if (rr < 0 || rr >= block.nr || cc < 0 || cc >= block.nc
3744 || component != block.component
3745 || !one_of(block.cxComponent, -1, cxComponent))
3746 throw std::runtime_error("Requested out-of-bounds element.");
3747
3748 int crosspack = block.crosspack;
3749 int elFixed, elLD;
3750 if (block.colMajor) {
3751 int ccx = cc % crosspack;
3752 elFixed = ccx + (rr * crosspack);
3753 elLD = cc - ccx;
3754 nelems = block.nr - rr;
3755 } else {
3756 int rrx = rr % crosspack;
3757 elFixed = rrx + (cc * crosspack);
3758 elLD = (rr - rrx);
3759 nelems = block.nc - cc;
3760 }
3761
3762 int el = elFixed + elLD * block.ld;
3763 el += block.offsetBytes / Te;
3764 int reg = el / ne;
3765 int subreg = el % ne;
3766
3767 return regs[reg].sub(subreg, Te.ngen());
3768}
3769
3770// Find the subregister in a layout corresponding to element (r,c), as well as the
3771// associated block, and the number of contiguous elements following it (nelems).
3772static Subregister findBlockReg(Type T, const vector<RegisterBlock> &layout,
3773 int r, int c, const GRFMultirange &regs, int &nelems,
3774 const RegisterBlock *&block, int cxComponent = -1, int component = 0) {
3775 int ecomponent = component;
3776 for (auto &l : layout) {
3777 int rr = r - l.offsetR;
3778 int cc = c - l.offsetC;
3779 if (rr >= 0 && rr < l.nr && cc >= 0 && cc < l.nc
3780 && ecomponent == l.component
3781 && one_of(l.cxComponent, cxComponent,
3782 RegisterBlock::Interleaved)) {
3783 block = &l;
3784 return findBlockReg(
3785 T, l, rr, cc, regs, nelems, cxComponent, component);
3786 }
3787 }
3788
3789 throw std::runtime_error(
3790 "Could not find requested matrix element in layout.");
3791}
3792
// Match the register offsets in one register layout to those of a reference
// layout. Returns true if successful. If not successful, the layout is unchanged.
3795static bool matchLayouts(Type T, vector<RegisterBlock> &layout,
3796 const vector<RegisterBlock> &layoutRef) {
3797 vector<RegisterBlock> nlayout = layout;
3798
3799 if (getRegCount(layoutRef) >= 256) return false;
3800
3801 for (auto &nblock : nlayout) {
3802 int nelems;
3803 const RegisterBlock *blockRef;
3804 auto sr = findBlockReg(T, layoutRef, nblock.offsetR, nblock.offsetC,
3805 GRFRange(0, 254), nelems, blockRef);
3806
3807 // Check:
3808 // 1. Does this register block's offset match the reference block's offset?
3809 if (sr.getByteOffset()
3810 != (nblock.offsetBytes & ((1 << nblock.log2GRFBytes) - 1)))
3811 return false;
3812
3813 // 2. Is there any free space in the register block?
3814 if (nblock.nr * nblock.nc * T != nblock.bytes) return false;
3815
3816 // 3. Does this register block's data layout match the reference block's layout?
3817 if (blockRef->colMajor != nblock.colMajor) return false;
3818 if (blockRef->crosspack != nblock.crosspack) return false;
3819
3820 // 4. Does this register block fit inside the reference block?
3821 auto RegisterBlock::*nx
3822 = nblock.colMajor ? &RegisterBlock::nr : &RegisterBlock::nc;
3823 auto RegisterBlock::*ny
3824 = nblock.colMajor ? &RegisterBlock::nc : &RegisterBlock::nr;
3825
3826 if (nblock.*nx < blockRef->*nx) {
3827 if (nblock.*ny > 1) return false;
3828 } else if (nblock.*nx == blockRef->*nx) {
3829 if (nblock.*ny > blockRef->*ny) return false;
3830 } else
3831 return false;
3832
3833 if (nblock.*ny > 1 && (nblock.ld != blockRef->ld)) return false;
3834
3835 // It's compatible. Point this register block where it belongs.
3836 nblock.offsetBytes
3837 = (sr.getBase() << nblock.log2GRFBytes) + sr.getByteOffset();
3838 }
3839
3840 std::swap(nlayout, layout);
3841 return true;
3842}
3843
3844// Like matchLayouts but allows either layout to change to match the other.
3845static bool matchLayoutsBidirectional(Type T, vector<RegisterBlock> &layout1,
3846 vector<RegisterBlock> &layout2) {
3847 return matchLayouts(T, layout1, layout2)
3848 || matchLayouts(T, layout2, layout1);
3849}
3850
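// Try to allocate an SBID token for each block in a layout, keyed by the
// block's data registers (or its address registers if no data registers are
// provided). On failure, the allocator and token map are left unchanged.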
3851static bool allocateTokens(const vector<RegisterBlock> &layout,
3852 const GRFMultirange &regs, CommonState &state,
3853 const vector<GRFRange> &addrs = vector<GRFRange>()) {
3854 bool success = true;
3855 size_t origSize = state.tokenMap.size();
3856 auto saveTA = state.tokenAllocator;
3857
3858 for (size_t l = 0; l < layout.size(); l++) {
3859 auto token = state.tokenAllocator.tryAlloc();
3860 if (token < 0)
3861 success = false;
3862 else {
3863 auto regKey = !regs.empty() ? regs[layout[l].offsetReg()].getBase()
3864 : addrs[l].getBase();
3865 state.tokenMap.push_back(std::make_pair(regKey, token));
3866 }
3867 }
3868
3869 if (!success) {
3870 state.tokenAllocator = saveTA;
3871 state.tokenMap.resize(origSize);
3872 }
3873
3874 return success;
3875}
3876
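// Discard all SBID token assignments and reset the token allocator.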
3877static void clearTokenAllocations(HW hw, CommonState &state) {
3878 state.tokenMap.clear();
3879 state.tokenAllocator = TokenAllocator(hw);
3880}
3881
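// Allocate (setup = true) or release (setup = false) the constants used to
// build load/store message descriptors on the fly when DW x DW multiplication
// must be emulated.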
3882template <HW hw>
3883void gemm_kernel_generator_t<hw>::setupTeardownLoadStoreDesc(bool setup,
3884 const vector<RegisterBlock> &layout, const CommonStrategy &strategy,
3885 CommonState &state) {
3886 if (strategy.emulate.emulateDWxDW) {
3887 auto nconstants = (hw >= HW::XeHPG) ? 3 : 2;
3888
3889 if (setup)
3890 for (int s = 0; s < nconstants; s++) {
3891 state.lsDescConstant[s] = state.ra.alloc_sub<uint32_t>();
3892 mov(1, state.lsDescConstant[s], uint32_t(0x00100040 << s));
3893 }
3894 else
3895 for (int s = 0; s < nconstants; s++)
3896 state.ra.safeRelease(state.lsDescConstant[s]);
3897 }
3898}
3899
3900// Output code for loading address register(s) with load/store message descriptors for remainders.
3901template <HW hw>
3902void gemm_kernel_generator_t<hw>::loadLoadStoreDescriptors(bool load,
3903 bool store, RegisterBlock &block, const Subregister &count,
3904 const MatrixAddressing &atype,
3905 const MatrixAddressingStrategy &astrategy,
3906 const CommonStrategy &strategy, CommonState &state) {
3907 MessageDescriptor descLoad; // a0.0:ud
3908 MessageDescriptor descStore; // a0.2 (a0.0 if no loads)
3909 ExtendedMessageDescriptor exdescLoad;
3910 ExtendedMessageDescriptor exdescStore; // a0.1
3911
3912 Subregister t1 = state.ra.alloc_sub<uint32_t>();
3913 Subregister t2 = state.ra.alloc_sub<uint32_t>();
3914
3915 if (astrategy.newDP) switch (astrategy.accessType) {
3916 case AccessType::ChannelScattered:
3917 case AccessType::Scattered: {
3918 bool channel = (astrategy.accessType
3919 == AccessType::ChannelScattered);
3920
3921 encodeLoadDescriptors(hw, descLoad, exdescLoad, block.simdSize,
3922 r0, getDataSpecLSC(atype, astrategy, block, false),
3923 astrategy.base, null);
3924 encodeStoreDescriptors(hw, descStore, exdescStore,
3925 block.simdSize,
3926 getDataSpecLSC(atype, astrategy, block, true),
3927 astrategy.base, null);
3928 descLoad.cmask.cmask = 0; // also vectSize
3929 descStore.cmask.cmask = 0;
3930 exdescStore.parts.extMessageLen = 0;
3931 descLoad.parts.responseLen = 0;
3932
3933 int underlyingSIMD = std::max<int>(
3934 block.simdSize, maxScatteredSIMD(hw, astrategy) >> 1);
3935 int log2GRFs = log2(underlyingSIMD * block.ebytes)
3936 - GRF::log2Bytes(hw);
3937 int log2Components = int(block.splitComplex);
3938
3939 if (channel) mov(1, t2, 0x1000 << log2Components);
3940 mul(1, t1, state.lsDescConstant[log2GRFs + log2Components],
3941 count.uw());
3942 channel ? shl(1, t2, t2, count)
3943 : shl(1, t2, count, 12 + log2Components);
3944 if (store) or_(1, a0.ud(1), t1.uw(0), exdescStore.all);
3945 add(1, t1.uw(0), t2, -0x1000);
3946 if (load) or_(1, a0.ud(0), t1, descLoad.all);
3947 if (store) or_(1, a0.ud(load ? 2 : 0), t1.uw(0), descStore.all);
3948 break;
3949 }
3950 default: hw_unsupported();
3951 }
3952 else
3953 switch (astrategy.accessType) {
3954 case AccessType::ChannelScattered: {
3955 encodeLoadDescriptors(hw, descLoad, exdescLoad, block.simdSize,
3956 r0, surface_dword(ChannelMask::rgba), astrategy.base,
3957 null);
3958 encodeStoreDescriptors(hw, descStore, exdescStore,
3959 block.simdSize, surface_dword(ChannelMask::rgba),
3960 astrategy.base, null);
3961 descLoad.surface.cmask = 0; //
3962 descStore.surface.cmask = 0; // Fields to fill in.
3963 exdescStore.parts.extMessageLen = 0; //
3964 descLoad.parts.responseLen = 0;
3965
3966 int log2Components = int(block.splitComplex);
3967 int shift = int(block.simdSize == 16) + log2Components;
3968 auto bitmask = uint16_t(0x0F00 << log2Components);
3969
3970 if (strategy.emulate.emulateDWxDW)
3971 mul(1, t1, state.lsDescConstant[shift], count.uw());
3972 else
3973 mul(1, t1, count, uint32_t(0x00100040) << shift);
3974 mov(1, t2, bitmask);
3975 if (store) or_(1, a0.ud(1), t1.uw(0), exdescStore.all);
3976 shl(1, t2, t2, count);
3977 and_(1, t1.uw(0), t2, bitmask);
3978 if (load) or_(1, a0.ud(0), t1, descLoad.all);
3979 if (store) or_(1, a0.ud(load ? 2 : 0), t1.uw(0), descStore.all);
3980 break;
3981 }
3982 default: hw_unsupported();
3983 }
3984
3985 state.ra.safeRelease(t1);
3986 state.ra.safeRelease(t2);
3987 block.sfid = exdescLoad.all;
3988}
3989
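// Get the instruction modifier (flag register and any/all control, if
// applicable) that applies a register block's mask.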
3990template <HW hw>
3991InstructionModifier gemm_kernel_generator_t<hw>::getRegisterBlockMask(
3992 const RegisterBlock &block, CommonState &state) {
3993 InstructionModifier result;
3994
3995 if (block.flag) {
3996 result |= getPhysicalFlag(block.flag, state);
3997 if (hw == HW::XeHPC) {
3998 if (block.flagAll) result |= all;
3999 if (block.flagAny) result |= any;
4000 } else if (block.flagAll)
4001 result |= (block.simdSize > 8) ? all16h : all8h;
4002 else if (block.flagAny)
4003 result |= (block.simdSize > 8) ? any16h : any8h;
4004 }
4005
4006 return result;
4007}
4008
4009// Check if a block occupies a contiguous portion of registers in the given GRFMultirange.
// If so, return the index of the block's first register in the range.
4011static inline int contiguityCheck(
4012 HW hw, const RegisterBlock &block, const GRFMultirange &range) {
4013 auto offsetBytes = block.offsetBytes;
4014 if (offsetBytes & (GRF::bytes(hw) - 1))
4015 if (block.isLoadBlock()) stub();
4016 auto offsetReg = offsetBytes >> GRF::log2Bytes(hw);
4017 auto lastReg = GRF::bytesToGRFs(hw, offsetBytes + block.bytes);
4018 if (!range.contiguous(offsetReg, lastReg - offsetReg)) stub();
4019
4020 return offsetReg;
4021}
4022
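// Map an element size in bytes to the corresponding LSC data size,
// optionally padding sub-dword elements to 32 bits in registers.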
4023static DataSizeLSC getDataSizeLSC(int ebytes, bool pad32) {
4024 switch (ebytes) {
4025 case 8: return DataSizeLSC::D64;
4026 case 4: return DataSizeLSC::D32;
4027 case 2: return pad32 ? DataSizeLSC::D16U32 : DataSizeLSC::D16;
4028 case 1: return pad32 ? DataSizeLSC::D8U32 : DataSizeLSC::D8;
4029 }
4030 throw std::runtime_error("Invalid data size");
4031}
4032
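// Get the LSC data specifier (data size/vector length or channel mask)
// for a block, given its access type.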
4033template <HW hw>
4034DataSpecLSC gemm_kernel_generator_t<hw>::getDataSpecLSC(
4035 AccessType access, const RegisterBlock &block) {
4036 switch (access) {
4037 case AccessType::ChannelScattered: {
4038 static const ChannelMask cmasks[4] = {ChannelMask::r,
4039 ChannelMask::rg, ChannelMask::rgb, ChannelMask::rgba};
4040 if (block.ebytes != 4) hw_unsupported();
4041 return D32 | cmasks[block.count - 1];
4042 }
4043 case AccessType::Scattered:
4044 if (block.ebytes == 8) return D64(block.count);
4045 if (block.ebytes == 4) return D32(block.count);
4046 if (block.ebytes == 1) return getDataSizeLSC(block.count, true);
4047 hw_unsupported();
4048 case AccessType::Block:
4049 if (block.ebytes == 8) return D64T(block.count);
4050 if (block.ebytes == 4) return D32T(block.count);
4051 hw_unsupported();
4052 default: stub();
4053 }
4054}
4055
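// Get the complete LSC data specifier for a block, including caching hints.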
4056template <HW hw>
4057DataSpecLSC gemm_kernel_generator_t<hw>::getDataSpecLSC(
4058 const MatrixAddressing &atype,
4059 const MatrixAddressingStrategy &astrategy, const RegisterBlock &block,
4060 bool write) {
4061 return getDataSpecLSC(implAccessType(atype, astrategy, block), block)
4062 | (write ? astrategy.cachingW : astrategy.cachingR);
4063}
4064
4065// Output code for prefetching a matrix chunk (XeHPG+).
4066template <HW hw>
4067void gemm_kernel_generator_t<hw>::prefetchMatrix(
4068 const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
4069 const MatrixAddressingStrategy &astrategy,
4070 const vector<GRFRange> &addrs, const CommonStrategy &strategy,
4071 CommonState &state) {
4072 auto nblocks = int(layout.size());
4073
4074 for (int l = 0; l < nblocks; l++)
4075 loadMatrixBlock(null, layout[l], atype, astrategy, addrs[l], strategy,
4076 state, false);
4077}
4078
4079// Output code for loading a matrix chunk into registers.
4080template <HW hw>
4081void gemm_kernel_generator_t<hw>::loadMatrix(const GRFMultirange &dest,
4082 const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
4083 const MatrixAddressingStrategy &astrategy,
4084 const vector<GRFRange> &addrs, const CommonStrategy &strategy,
4085 CommonState &state, bool zeroMask) {
4086 auto nblocks = int(layout.size());
4087
4088 if (astrategy.prefetch && astrategy.newDP) {
4089 prefetchMatrix(layout, atype, astrategy, addrs, strategy, state);
4090 return;
4091 }
4092
4093 if (strategy.readSuppressionWA && (hasFlags(layout) || !getDefaultNoMask()))
4094 doReadSuppressionWA(strategy, state);
4095
4096 for (int l = 0; l < nblocks; l++) {
4097 auto offsetReg = contiguityCheck(hw, layout[l], dest);
4098 loadMatrixBlock(dest[offsetReg], layout[l], atype, astrategy, addrs[l],
4099 strategy, state, zeroMask);
4100 }
4101}
4102
4103// Output code for loading a single matrix block into registers.
4104template <HW hw>
4105void gemm_kernel_generator_t<hw>::loadMatrixBlock(const Register &dest,
4106 const RegisterBlock &block, const MatrixAddressing &atype,
4107 const MatrixAddressingStrategy &astrategy, const GRFRange &addr,
4108 const CommonStrategy &strategy, CommonState &state, bool zeroMask) {
4109 InstructionModifier maskMod;
4110 InstructionModifier mod = block.simdSize;
4111
4112 // Zero SIMD size blocks are filled as part of another load. Skip them.
4113 if (!block.isLoadBlock()) return;
4114
4115 // Get mask to apply, if any.
4116 auto mask = getRegisterBlockMask(block, state);
4117 maskMod |= mask;
4118 mod |= mask;
4119
4120 // Look up preassigned token.
4121 for (auto &entry : state.tokenMap) {
4122 if (entry.first == dest.getBase() || entry.first == addr.getBase()) {
4123 mod |= SBID(entry.second);
4124 break;
4125 }
4126 }
4127
4128 if (astrategy.newDP) switch (implAccessType(atype, astrategy, block)) {
4129 case AccessType::Block:
4130 case AccessType::Scattered:
4131 case AccessType::ChannelScattered: {
4132 auto spec = getDataSpecLSC(atype, astrategy, block, false);
4133 if (block.descAssigned) {
4134 MessageDescriptor desc;
4135 ExtendedMessageDescriptor exdesc;
4136 encodeLoadDescriptors(hw, desc, exdesc, block.simdSize, r0,
4137 spec, astrategy.base, null);
4138 send(mod, static_cast<SharedFunction>(block.sfid), dest,
4139 addr, null, exdesc.all, a0[0]);
4140 } else {
4141 load(mod, dest, spec, astrategy.base, addr[0]);
4142 }
4143 break;
4144 }
4145 case AccessType::Block2D:
4146 case AccessType::Block2DTranspose:
4147 case AccessType::Block2DVNNI: {
4148 int w = 0, h = 0;
4149 getBlock2DWH(w, h, atype, block);
4150 auto spec = block_2d(getDataSizeLSC(block.ebytes, false), w, h,
4151 block.count)
4152 | astrategy.cachingR;
4153 if (astrategy.accessType == AccessType::Block2DTranspose)
4154 spec |= transpose;
4155 if (astrategy.accessType == AccessType::Block2DVNNI)
4156 spec |= vnni;
4157 load(mod, dest, spec, astrategy.base, addr);
4158 break;
4159 }
4160 default: stub();
4161 }
4162 else if (block.descAssigned)
4163 send(mod, static_cast<SharedFunction>(block.sfid), dest, addr, null,
4164 block.sfid, a0[0]);
4165 else
4166 switch (implAccessType(atype, astrategy, block)) {
4167 case AccessType::ChannelScattered: {
4168 static const ChannelMask cmasks[4] = {ChannelMask::r,
4169 ChannelMask::rg, ChannelMask::rgb, ChannelMask::rgba};
4170 if (block.ebytes != 4) stub();
4171 load(mod, dest, surface_dword(cmasks[block.count - 1]),
4172 astrategy.base, addr);
4173 break;
4174 }
4175 case AccessType::Scattered:
4176 if (block.ebytes == 8)
4177 load(mod, dest, scattered_qword(block.count),
4178 astrategy.base, addr);
4179 else if (block.ebytes == 4)
4180 load(mod, dest, scattered_dword(block.count),
4181 astrategy.base, addr);
4182 else if (block.ebytes == 1)
4183 load(mod, dest, scattered_byte(block.count), astrategy.base,
4184 addr);
4185 else
4186 hw_unsupported();
4187 break;
4188 case AccessType::Block:
4189 if (block.ebytes == 32)
4190 load(mod, dest, block_hword(block.count), astrategy.base,
4191 addr);
4192 else if (block.ebytes == 16 && !block.extra)
4193 load(mod, dest, block_oword(block.count), astrategy.base,
4194 addr);
4195 else if (block.ebytes == 16)
4196 load(mod, dest, aligned_block_oword(block.count),
4197 astrategy.base, addr);
4198 else
4199 hw_unsupported();
4200 if (zeroMask && (astrategy.base.getModel() == ModelBTS)) {
4201 if (block.flag)
4202 mov<uint32_t>(block.simdSize | ~maskMod, dest, 0);
4203 if (block.simdSize <= 2) mov<uint32_t>(2, dest[2](1), 0);
4204 if (block.simdSize <= 1) mov<uint32_t>(1, dest[1], 0);
4205 }
4206 break;
4207 default: stub();
4208 }
4209}
4210
4211// Output code for storing a matrix chunk from registers.
4212template <HW hw>
4213void gemm_kernel_generator_t<hw>::storeMatrix(const GRFMultirange &src,
4214 const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
4215 const MatrixAddressingStrategy &astrategy,
4216 const vector<GRFRange> &addrs, const CommonStrategy &strategy,
4217 CommonState &state) {
4218 auto nblocks = int(layout.size());
4219
4220 for (int l = 0; l < nblocks; l++) {
4221 auto offsetReg = contiguityCheck(hw, layout[l], src);
4222 storeMatrixBlock(src[offsetReg], layout[l], atype, astrategy, addrs[l],
4223 strategy, state);
4224 }
4225}
4226
4227// Output code for storing a matrix block from registers.
4228template <HW hw>
4229void gemm_kernel_generator_t<hw>::storeMatrixBlock(const GRF &src,
4230 const RegisterBlock &block, const MatrixAddressing &atype,
4231 const MatrixAddressingStrategy &astrategy, const GRFRange &addr,
4232 const CommonStrategy &strategy, CommonState &state) {
    InstructionModifier mod = block.simdSize;
4235
4236 // Zero SIMD size blocks are filled as part of another store. Skip them.
4237 if (!block.isLoadBlock()) return;
4238
4239 // Get mask to apply, if any.
4240 mod |= getRegisterBlockMask(block, state);
4241
4242 // Look up preassigned token.
4243 for (auto &entry : state.tokenMap) {
4244 if (entry.first == src.getBase()) {
4245 mod |= SBID(entry.second);
4246 break;
4247 }
4248 }
4249
4250 if (block.descAssigned)
4251 send(mod, static_cast<SharedFunction>(block.sfid), null, addr, src,
4252 a0.ud(1), a0.ud(0));
4253 else if (astrategy.newDP)
4254 switch (implAccessType(atype, astrategy, block)) {
4255 case AccessType::Block:
4256 case AccessType::Scattered:
4257 case AccessType::ChannelScattered: {
4258 auto spec = getDataSpecLSC(atype, astrategy, block, true);
4259 store(mod, spec, astrategy.base, addr[0], src);
4260 break;
4261 }
4262 case AccessType::Block2D:
4263 case AccessType::Block2DTranspose:
4264 case AccessType::Block2DVNNI: {
4265 int w = 0, h = 0;
4266 getBlock2DWH(w, h, atype, block);
4267 auto spec = block_2d(getDataSizeLSC(block.ebytes, false), w, h,
4268 block.count)
4269 | astrategy.cachingW;
4270 if (astrategy.accessType == AccessType::Block2DTranspose)
4271 spec |= transpose;
4272 if (astrategy.accessType == AccessType::Block2DVNNI)
4273 spec |= vnni;
4274 store(mod, spec, astrategy.base, addr, src);
4275 break;
4276 }
4277 default: stub();
4278 }
4279 else
4280 switch (implAccessType(atype, astrategy, block)) {
4281 case AccessType::ChannelScattered: {
4282 static const ChannelMask cmasks[4] = {ChannelMask::r,
4283 ChannelMask::rg, ChannelMask::rgb, ChannelMask::rgba};
4284 if (block.ebytes != 4) stub();
4285 store(mod, surface_dword(cmasks[block.count - 1]),
4286 astrategy.base, addr, src);
4287 break;
4288 }
4289 case AccessType::Scattered:
4290 if (block.ebytes == 8)
4291 store(mod, scattered_qword(block.count), astrategy.base,
4292 addr, src);
4293 else if (block.ebytes == 4)
4294 store(mod, scattered_dword(block.count), astrategy.base,
4295 addr, src);
4296 else if (block.ebytes == 1)
4297 store(mod, scattered_byte(block.count), astrategy.base,
4298 addr, src);
4299 else
4300 hw_unsupported();
4301 break;
4302 case AccessType::Block:
4303 if (block.ebytes == 32)
4304 store(mod, block_hword(block.count), astrategy.base, addr,
4305 src);
4306 else if (block.ebytes == 16 && !block.extra)
4307 store(mod, block_oword(block.count), astrategy.base, addr,
4308 src);
4309 else
4310 hw_unsupported();
4311 break;
4312 default: stub();
4313 }
4314}
4315
4316// Atomic addition of a matrix in registers.
4317template <HW hw>
4318void gemm_kernel_generator_t<hw>::atomicAddMatrix(Type T,
4319 const GRFMultirange &src, const vector<RegisterBlock> &layout,
4320 const MatrixAddressing &atype,
4321 const MatrixAddressingStrategy &astrategy,
4322 const vector<GRFRange> &addrs, const CommonProblem &problem,
4323 const CommonStrategy &strategy, CommonState &state) {
4324 auto nblocks = int(layout.size());
4325
4326 if (strategy.readSuppressionWA && (hasFlags(layout) || !getDefaultNoMask()))
4327 doReadSuppressionWA(strategy, state);
4328
4329 for (int l = 0; l < nblocks; l++) {
4330 auto offsetReg = contiguityCheck(hw, layout[l], src);
4331 atomicAddMatrixBlock(T, src[offsetReg], layout[l], atype, astrategy,
4332 addrs[l], problem, strategy, state);
4333 }
4334}
4335
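// Output code for atomic addition of a single matrix block, using native
// atomic add where available and an emulated compare-and-swap loop otherwise.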
4336template <HW hw>
4337void gemm_kernel_generator_t<hw>::atomicAddMatrixBlock(Type T, const GRF &src,
4338 const RegisterBlock &block, const MatrixAddressing &atype,
4339 const MatrixAddressingStrategy &astrategy, const GRFRange &addr,
4340 const CommonProblem &problem, const CommonStrategy &strategy,
4341 CommonState &state) {
4342 InstructionModifier maskMod;
4343
4344 if (!block.isLoadBlock()) return;
4345 if (block.descAssigned) stub();
4346
4347 maskMod |= getRegisterBlockMask(block, state);
4348
4349 // SIMD16 A64 atomics are emulated with 2x SIMD8.
4350 bool a64 = (astrategy.base.getModel() == ModelA64);
4351 int hsize = a64 ? 2 : 1;
4352 int simd = block.simdSize;
4353 if (!astrategy.newDP && a64) simd = std::min(simd, 8);
4354 if (hw >= HW::XeHPC && block.ebytes < 8 && block.simdSize == 16
4355 && simd == 8)
4356 stub(); // Can't split data GRFs.
4357 auto nreg = block.nregs();
4358 auto nregReal = (nreg * simd) / block.simdSize;
4359
4360 auto specLSC = D32;
4361 if (astrategy.newDP)
4362 specLSC = getDataSpecLSC(atype, astrategy, block, true);
4363
4364 switch (implAccessType(atype, astrategy, block)) {
4365 case AccessType::Scattered:
4366 case AccessType::ChannelScattered:
4367 if (hasNativeAtomicAdd(hw, T.real(), atype, astrategy)) {
4368 auto curSrc = src;
4369 for (int eoff = 0, hoff = 0; eoff < block.simdSize;
4370 eoff += simd, hoff += hsize, curSrc += nregReal) {
4371 auto mod = simd | maskMod | ExecutionOffset(eoff);
4372 if (block.ebytes != T.real().size()) stub();
4373 if (astrategy.newDP)
4374 atomic(T.isFP() ? AtomicOp::fadd : AtomicOp::add, mod,
4375 specLSC, astrategy.base, addr[hoff], curSrc);
4376 else
4377 switch (T.real()) {
4378 case Type::f32:
4379 atomic(AtomicOp::fadd, mod, scattered_dword(),
4380 astrategy.base, addr[hoff], curSrc);
4381 break;
4382 case Type::u64:
4383 case Type::s64:
4384 atomic(AtomicOp::add, mod, scattered_qword(),
4385 astrategy.base, addr[hoff], curSrc);
4386 break;
4387 case Type::u32:
4388 case Type::s32:
4389 atomic(AtomicOp::add, mod, scattered_dword(),
4390 astrategy.base, addr[hoff], curSrc);
4391 break;
4392 case Type::u16:
4393 case Type::s16:
4394 if (hw < HW::Gen12LP) hw_unsupported();
4395 atomic(AtomicOp::add, mod, scattered_word(),
4396 astrategy.base, addr[hoff], curSrc);
4397 break;
4398 default: stub();
4399 }
4400 }
4401 } else {
4402 // Emulated atomic addition with a compare-and-swap loop.
4403 auto rOldNew = state.eatomicAddRegs[0];
4404 auto rSave = state.eatomicAddRegs[1];
4405 auto rOld = rOldNew[0];
4406 auto rNew = rOldNew[nregReal];
4407 auto flagToDo = getPhysicalFlag(state.vflagEAtomicAdd, state);
4408
4409 if (block.simdSize > 16) stub(); // Need 32 channels.
4410 if (astrategy.newDP)
4411 load(block.simdSize | maskMod, rOld, specLSC,
4412 astrategy.base, addr[0]);
4413 else if (astrategy.base.getModel() == ModelA64) {
4414 if (block.ebytes == 2)
4415 load(block.simdSize | maskMod, rOld, scattered_byte(2),
4416 astrategy.base, addr);
4417 else if (block.ebytes == 4)
4418 load(block.simdSize | maskMod, rOld, scattered_dword(),
4419 astrategy.base, addr);
4420 else if (block.ebytes == 8)
4421 load(block.simdSize | maskMod, rOld, scattered_qword(),
4422 astrategy.base, addr);
4423 } else {
4424 if (block.ebytes == 2)
4425 load(block.simdSize | maskMod, rOld, scattered_byte(2),
4426 astrategy.base, addr);
4427 else if (block.ebytes == 4)
4428 load(block.simdSize | maskMod, rOld,
4429 surface_dword(ChannelMask::r), astrategy.base,
4430 addr);
4431 else if (block.ebytes == 8)
4432 stub(); // needs cmpwr2
4433 }
4434 Label labelMask;
4435
4436 // Save off high half of data when emulating SIMD16.
4437 if (block.simdSize > simd)
4438 mov<uint32_t>(nregReal * 8, rOld.advance(nreg),
4439 rOld.advance(nregReal));
4440
4441 if (block.flag) {
4442 if_(16 | getPhysicalFlag(block.flag, state), labelMask);
4443 setDefaultNoMask(false);
4444 }
4445
4446 and_(1 | NoMask, flagToDo, ce0,
4447 uint16_t((1 << block.simdSize) - 1));
4448
4449 auto curSrc = src;
4450
4451 for (int eoff = 0, hoff = 0; eoff < block.simdSize;
4452 eoff += simd, hoff += hsize) {
4453 auto eoMod = ExecutionOffset(eoff);
4454
4455 Label labelCmpXchgLoop;
4456 mark(labelCmpXchgLoop);
4457
4458 auto dt = T.ngen();
4459 add(int(simd * block.ebytes / T.real()) | eoMod | NoMask,
4460 rNew.retype(dt), rOld.retype(dt),
4461 curSrc.retype(dt));
4462 mov<uint32_t>((simd * block.ebytes / 4) | eoMod | NoMask,
4463 rSave, rOld);
4464
4465 auto atomicMod = simd | flagToDo | eoMod;
4466 auto cmpMod = simd | flagToDo | ne | flagToDo | eoMod;
4467
4468 if (astrategy.newDP)
4469 atomic(AtomicOp::cmpwr, atomicMod, rOld, specLSC,
4470 astrategy.base, addr[hoff], rOld);
4471 else
4472 switch (block.ebytes) {
4473 case 2:
4474 if (hw < HW::Gen12LP) hw_unsupported();
4475 atomic(AtomicOp::cmpwr, atomicMod, rOld,
4476 scattered_word(), astrategy.base,
4477 addr[hoff], rOld);
4478 break;
4479 case 4:
4480 atomic(AtomicOp::cmpwr, atomicMod, rOld,
4481 scattered_dword(), astrategy.base,
4482 addr[hoff], rOld);
4483 break;
4484 case 8:
4485 atomic(AtomicOp::cmpwr, atomicMod, rOld,
4486 scattered_qword(), astrategy.base,
4487 addr[hoff], rOld);
4488 break;
4489 default: stub();
4490 }
4491
4492 if (block.ebytes == 2)
4493 cmp<uint16_t>(cmpMod, rSave[0][0](2), rOld[0](2));
4494 else if (block.ebytes == 4)
4495 cmp<uint32_t>(cmpMod, rSave, rOld);
4496 else if (block.ebytes == 8) {
4497 if (strategy.emulate.emulate64) {
4498 cmp<uint32_t>(simd | ne | flagToDo | eoMod,
4499 rSave[0][0](2), rOld[0](2));
4500 cmp<uint32_t>(
4501 simd | ~flagToDo | ne | flagToDo | eoMod,
4502 rSave[0][1](2), rOld[1](2));
4503 } else
4504 cmp<uint64_t>(cmpMod, rSave, rOld);
4505 } else
4506 stub();
4507
4508 (hw == HW::XeHPC)
4509 ? simtDoWhileLoop(
4510 16 | flagToDo | any, labelCmpXchgLoop)
4511 : strategy.fused ? simtDoWhileLoop(
4512 16 | flagToDo | any16h, labelCmpXchgLoop)
4513 : (eoff == 0 && simd == 8)
4514 ? jmpi(1 | flagToDo | any8h,
4515 labelCmpXchgLoop)
4516 : jmpi(1 | flagToDo | any16h,
4517 labelCmpXchgLoop);
4518
4519 rOld += 2 * nregReal;
4520 rNew += 2 * nregReal;
4521 curSrc += nregReal;
4522 }
4523
4524 if (block.flag) {
4525 mark(labelMask);
4526 setDefaultNoMask(true);
4527 endif(16);
4528 }
4529 }
4530 break;
4531 default: hw_unsupported();
4532 }
4533}
4534
4535// Allocate temporary registers for emulating atomic addition.
4536static inline void allocEAtomicAddRegs(HW hw, Type T,
4537 const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
4538 const MatrixAddressingStrategy &astrategy, CommonState &state,
4539 const FlagRegister &flag = FlagRegister()) {
4540 if (hasNativeAtomicAdd(hw, T.real(), atype, astrategy)) return;
4541
4542 int maxNReg = 0;
4543 for (const auto &block : layout)
4544 maxNReg = std::max(maxNReg, block.nregs());
4545
4546 if (maxNReg == 0) return;
4547
4548 state.eatomicAddRegs[0] = state.ra.alloc_range(maxNReg * 2);
4549 state.eatomicAddRegs[1] = state.ra.alloc_range(maxNReg);
4550 state.vflagEAtomicAdd
4551 = flag.isValid() ? flag : state.raVFlag.allocVirtual();
4552}
4553
4554// Free temporary registers for emulating atomic addition.
4555static inline void freeEAtomicAddRegs(
4556 CommonState &state, const FlagRegister &flag = FlagRegister()) {
4557 state.ra.safeRelease(state.eatomicAddRegs[0]);
4558 state.ra.safeRelease(state.eatomicAddRegs[1]);
4559 if (flag.isInvalid()) state.raVFlag.release(state.vflagEAtomicAdd);
4560}
4561
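// Release the flag registers held by mask assignments from index 'start' on,
// keeping the assignments themselves so they can be reclaimed later.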
4562static inline void releaseMaskAssignments(vector<MaskAssignment> &assignments,
4563 CommonState &state, int start = 0) {
4564 for (size_t an = start; an < assignments.size(); an++)
4565 state.raVFlag.release(assignments[an].flag);
4566
4567 state.wipeActiveVFlags();
4568}
4569
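// Re-claim the flag registers for previously released mask assignments.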
4570static inline void reclaimMaskAssignments(vector<MaskAssignment> &assignments,
4571 CommonState &state, int start = 0) {
4572 for (size_t an = start; an < assignments.size(); an++)
4573 state.raVFlag.claim(assignments[an].flag);
4574}
4575
// Release all masks in a list of mask assignments, along with the assignments
// themselves. If 'start' is specified, only the masks at index 'start' and
// above will be released.
4578static inline void safeReleaseMaskAssignments(
4579 vector<MaskAssignment> &assignments, CommonState &state,
4580 int start = 0) {
4581 releaseMaskAssignments(assignments, state, start);
4582 assignments.resize(start);
4583}
4584
4585// Assign mask registers to a register layout.
4586// The assignments parameter is both input and output:
4587// existing assignments will be reused if compatible, and new assignments
4588// created as necessary.
4589template <HW hw>
4590bool gemm_kernel_generator_t<hw>::assignMasks(
4591 std::vector<RegisterBlock> &layout, LoopType rloop, LoopType cloop,
4592 vector<MaskAssignment> &assignments, const CommonStrategy &strategy,
4593 CommonState &state, bool retryVirtual) {
4594 // Loop through layout, collecting masks.
4595 // - For each unique mask+loop+offset, allocate an index (flag reg)
4596 // - Store new assignment if unique and update flag reg in layout.
4597 // - For now, simultaneous row and column masks are not supported.
4598 bool retry = false;
4599 do {
4600 auto nassignOriginal = int(assignments.size());
4601 bool outOfRegs = retry = false;
4602
4603 for (RegisterBlock &l : layout) {
4604 MaskAssignment thisAssignment;
4605
4606 if (l.rowMask) {
4607 if (l.colMask) stub();
4608
4609 thisAssignment.mask = l.rowMask;
4610 thisAssignment.offset = l.offsetR;
4611 thisAssignment.var = rloop;
4612 } else if (l.colMask) {
4613 thisAssignment.mask = l.colMask;
4614 thisAssignment.offset = l.offsetC;
4615 thisAssignment.var = cloop;
4616 } else {
4617 l.clearFlag();
4618 continue;
4619 }
4620
4621 // Look for compatible mask.
4622 bool gotMask = false;
4623 for (auto &a : assignments) {
4624 if (a.compatible(thisAssignment)) {
4625 l.flag = a.flag;
4626 gotMask = true;
4627 break;
4628 }
4629 }
4630
4631 if (!gotMask) {
4632 // No compatible mask, so make a new assignment.
4633 thisAssignment.flag
4634 = state.raVFlag.allocVirtual((l.simdSize + 0xF) >> 4);
4635 assignments.push_back(thisAssignment);
4636 if (state.raVFlag.isVirtual(thisAssignment.flag)
4637 && state.vflagStorage.isInvalid()) {
4638 outOfRegs = true;
4639 break;
4640 }
4641 l.flag = thisAssignment.flag;
4642 }
4643 }
4644
4645 if (outOfRegs) {
4646 // Not enough (virtual) flag registers! Free any masks we added to the list.
4647 safeReleaseMaskAssignments(assignments, state, nassignOriginal);
4648 if (retryVirtual && state.vflagStorage.isInvalid()) {
4649 status << "Not enough flag registers available. Retrying with "
4650 "virtual flags."
4651 << status_stream::endl;
4652 allocVFlagStorage(strategy, state);
4653 retry = true;
4654 } else {
4655 status << "Not enough flag registers available."
4656 << status_stream::endl;
4657 return false;
4658 }
4659 }
4660 } while (retry);
4661
4662 return true;
4663}
4664
4665// Output code for loading a mask into a flag register.
4666template <HW hw>
4667void gemm_kernel_generator_t<hw>::loadMask(MaskAssignment assignment,
4668 Subregister index, const CommonStrategy &strategy, CommonState &state,
4669 int offset) {
4670 auto flagIdx = assignment.flag;
4671 RegData flag = getMaskFlag(flagIdx, state);
4672
4673 if (assignment.mask.fixed.isFixed) {
4674 // Load fixed mask. Easy.
4675 mov(1, flag, uint16_t(assignment.mask.fixed.value));
4676 } else {
4677 // Load a variable mask, which requires some minor bit-twiddling.
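        // In the simplest case (bitRep = maskRep = rdivide = 1, rsize = 16, no
        // offsets), this reduces to flag = 0xFFFF >> (16 - index), i.e. a mask
        // with the low 'index' bits set for the remaining elements.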
4678 auto &vmask = assignment.mask.variable;
4679
4680 uint32_t rsizeScaled = vmask.rsize / vmask.rdivide;
4681 uint32_t fullMask
4682 = (1ul << (vmask.bitRep * vmask.maskRep * rsizeScaled)) - 1;
4683 uint32_t rep1Mask = (1ul << (vmask.bitRep * rsizeScaled)) - 1;
4684 uint32_t repMultiplier = fullMask / rep1Mask;
4685
4686 auto flagType = flag.getType();
4687 auto mask0Type = getBytes(flagType) >= 4 ? DataType::uq : flagType;
4688
4689 auto temp = state.ra.alloc_sub(flagType, getHint(HintType::Bank0));
4690 auto mask0 = state.ra.alloc_sub(mask0Type, getHint(HintType::Bank1));
4691 auto mask = mask0.reinterpret(0, flagType);
4692 auto mindex = index;
4693
4694 if (vmask.rdivide > 1) {
4695 if (!is_zero_or_pow2(vmask.rdivide)) stub();
4696 add(1, temp, mindex, -offset + vmask.rdivide - 1);
4697 shr(1, temp, temp, uint16_t(log2(vmask.rdivide)));
4698 mindex = temp;
4699 offset = 0;
4700 }
4701 if (vmask.bitRep > 1) {
4702 if (offset > 0) {
4703 add(1, temp, mindex, -offset);
4704 mindex = temp;
4705 offset = 0;
4706 }
4707 mulConstant(1, temp, mindex, vmask.bitRep);
4708 mindex = temp;
4709 }
4710 uint16_t tshift = vmask.bitRep
4711 * (rsizeScaled
4712 + div_up(assignment.offset + offset, vmask.rdivide));
4713 add(1 | sat, temp, -mindex, tshift);
4714 if (tshift >= 32)
4715 min_(1, temp, temp,
4716 vmask.bitRep
4717 * rsizeScaled); // Ensure shift count doesn't overflow.
4718 emov(1, mask0, rep1Mask, strategy, state);
4719 if (vmask.maskRep == 1) {
4720 bool twoStage = (!flag.isARF() && getBytes(mask0Type) > 4);
4721 auto flag1 = twoStage ? mask0 : flag;
4722 vmask.reverse ? shl(1, flag1, mask0, temp)
4723 : shr(1, flag1, mask0, temp);
4724 if (twoStage) mov(1, flag, mask);
4725 } else {
4726 vmask.reverse ? stub() // need shl + and
4727 : shr(1, mask0, mask0, temp);
4728 if (repMultiplier & 0x10000) mov(1, mask.uw(1), mask.uw(0));
4729 mul(1, flag, mask, uint16_t(repMultiplier));
4730 }
4731
4732 state.ra.safeRelease(temp);
4733 state.ra.safeRelease(mask0);
4734 }
4735}
4736
4737// Output code for loading all masks in a mask assignment to flag registers.
4738template <HW hw>
4739void gemm_kernel_generator_t<hw>::loadMasks(
4740 const vector<MaskAssignment> &assignments, Subregister (&indices)[3],
4741 const CommonStrategy &strategy, CommonState &state, int start) {
4742 for (size_t an = start; an < assignments.size(); an++) {
4743 auto &a = assignments[an];
4744 auto av = static_cast<int>(a.var);
4745 loadMask(a, indices[av], strategy, state);
4746 }
4747}
4748
4749template <HW hw>
4750void gemm_kernel_generator_t<hw>::loadMasks(
4751 const vector<MaskAssignment> &assignments, Subregister (&indices)[3],
4752 int (&offsets)[3], const CommonStrategy &strategy, CommonState &state,
4753 int start) {
4754 for (size_t an = start; an < assignments.size(); an++) {
4755 auto &a = assignments[an];
4756 auto av = static_cast<int>(a.var);
4757 loadMask(a, indices[av], strategy, state, offsets[av]);
4758 }
4759}
4760
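// Extend the cached vector of 16-bit indices (0, 1, 2, ...) to cover at least
// n entries, allocating additional GRFs as needed.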
4761template <HW hw>
4762void gemm_kernel_generator_t<hw>::extendIndexVec(int n, CommonState &state) {
4763 auto &indexVec = state.indexVec;
4764 auto &ivEntries = state.ivEntries;
4765
4766 if (n > ivEntries) {
4767 int simd = GRF::bytes(hw) >> 1;
4768 int nregs = div_up(n, simd);
4769 int cregs = indexVec.getLen();
4770 if (nregs > cregs)
4771 indexVec.ranges.push_back(state.ra.alloc_range(nregs - cregs));
4772 if (ivEntries == 0) {
4773 mov<uint16_t>(8, indexVec[0][0](1),
4774 Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7));
4775 ivEntries = 8;
4776 }
4777 if (n > 8 && ivEntries < 16) {
4778 mov<uint16_t>(8, indexVec[0][8](1),
4779 Immediate::uv(8, 9, 10, 11, 12, 13, 14, 15));
4780 ivEntries = 16;
4781 }
4782 if (GRF::bytes(hw) > 32 && n > 16 && ivEntries < 32) {
4783 add<uint16_t>(16, indexVec[0][16](1), indexVec[0].uw(0)(1), 16);
4784 ivEntries = 32;
4785 }
4786 if (n > ivEntries) {
4787 for (int e = std::max(cregs, 1); e < nregs; e++)
4788 add<uint16_t>(simd, indexVec[e], indexVec[0], simd * e);
4789 ivEntries = nregs * simd;
4790 }
4791 }
4792}
4793
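// Retrieve the subregister holding entry n of the cached index vector,
// extending the vector if necessary.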
4794template <HW hw>
4795Subregister gemm_kernel_generator_t<hw>::accessIndexVec(
4796 int n, CommonState &state) {
4797 if (n >= state.ivEntries) extendIndexVec(n, state);
4798
4799 int simd = GRF::bytes(hw) >> 1;
4800 return state.indexVec[n / simd].uw(n % simd);
4801}
4802
4803static inline void releaseIndexVec(CommonState &state) {
4804 safeReleaseRanges(state.indexVec, state);
4805 state.ivEntries = 0;
4806}
4807
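// Compute a table of multiples of the leading dimension (0*ld, 1*ld, 2*ld, ...)
// for address setup: 64-bit entries for A64 addressing, 32-bit otherwise.
// The returned range may be invalid if no registers were available.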
4808template <HW hw>
4809LDMultiples gemm_kernel_generator_t<hw>::createLDMultiples(bool a64,
4810 int nmultiples, const Subregister &ld, const CommonStrategy &strategy,
4811 CommonState &state) {
4812 int simd = GRF::bytes(hw) >> (a64 ? 3 : 2);
4813 int nregs = div_up(nmultiples, simd);
4814 auto r = state.ra.try_alloc_range(nregs);
4815
4816 GRF tempHi = state.emulate.temp[0], tempLo = state.emulate.temp[1];
4817 bool freeTempHi = false, freeTempLo = false;
4818 if (a64) {
4819 if (tempHi.isInvalid()) {
4820 tempHi = state.ra.alloc();
4821 freeTempHi = true;
4822 }
4823 if (tempLo.isInvalid()) {
4824 tempLo = state.ra.alloc();
4825 freeTempLo = true;
4826 }
4827 }
4828
4829 if (r.isValid()) {
4830 extendIndexVec(nmultiples, state);
4831 for (int i = 0; i < nregs; i += 2) {
4832 auto thisSIMD = simd * std::min(nregs - i, 2);
4833 auto iv = accessIndexVec(simd * i, state)(1);
4834 if (a64) {
4835 mul<uint32_t>(thisSIMD, acc0, ld, iv);
4836 mach<uint32_t>(thisSIMD, tempHi, ld, Immediate::ud(0));
4837 mov<uint32_t>(thisSIMD, tempLo, acc0);
4838 mov<uint32_t>(thisSIMD, r[i][1](2), tempHi);
4839 mov<uint32_t>(thisSIMD, r[i][0](2), tempLo);
4840 } else
4841 mul<uint32_t>(thisSIMD, r[i], ld, iv);
4842 }
4843 }
4844
4845 if (freeTempHi) state.ra.safeRelease(tempHi);
4846 if (freeTempLo) state.ra.safeRelease(tempLo);
4847
4848 LDMultiples result;
4849 result.range = r;
4850 result.a64 = a64;
4851
4852 return result;
4853}
4854
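// Look up a previously computed multiple of the leading dimension, returning
// an invalid Subregister if it is not available.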
4855template <HW hw>
4856Subregister gemm_kernel_generator_t<hw>::findLDMultiple(
4857 const LDMultiples &multiples, bool a64, int n,
4858 const CommonStrategy &strategy, CommonState &state) {
4859 int simd = GRF::bytes(hw) >> (multiples.a64 ? 3 : 2);
4860 int off = (n / simd), sub = (n % simd);
4861
4862 if (multiples.range.isInvalid()) return Subregister();
4863 if (off < 0 || off >= multiples.range.getLen()) return Subregister();
4864 if (a64 && !multiples.a64) return Subregister();
4865
4866 return !multiples.a64 ? multiples.range[off].ud(sub)
4867 : a64 ? multiples.range[off].uq(sub)
4868 : multiples.range[off].ud(2 * sub);
4869}
4870
4871static inline void releaseLDMultiples(
4872 LDMultiples &multiples, CommonState &state) {
4873 state.ra.safeRelease(multiples.range);
4874 multiples.a64 = false;
4875}
4876
4877// Ugly helpers handling address shifts. constexpr if would clean this all up.
4878template <HW hw>
4879template <typename BO>
4880typename std::enable_if<!std::is_base_of<RegData, BO>::value, BO>::type
4881gemm_kernel_generator_t<hw>::startShift(
4882 const BO &ptr, int shift, CommonState &state) {
4883 return ptr >> shift;
4884}
4885
4886template <HW hw>
4887Subregister gemm_kernel_generator_t<hw>::startShift(
4888 const MultishiftSubregister &ptr, int shift, CommonState &state) {
4889 return ptr >> shift;
4890}
4891
4892template <HW hw>
4893SubregisterPair gemm_kernel_generator_t<hw>::startShift(
4894 const SubregisterPair &ptr, int shift, CommonState &state) {
4895 if (shift == 0)
4896 return ptr;
4897 else
4898 return SubregisterPair(startShift(ptr.getReg(0), shift, state));
4899}
4900
4901template <HW hw>
4902template <typename BO>
4903typename std::enable_if<std::is_base_of<RegData, BO>::value, BO>::type
4904gemm_kernel_generator_t<hw>::startShift(
4905 const BO &ptr, int shift, CommonState &state) {
4906 BO ptrShifted = ptr;
4907
4908 // Shift pointer as necessary.
4909 if (shift > 0) {
4910 ptrShifted = state.ra.alloc_sub(ptr.getType());
4911 shr(1, ptrShifted, ptr, shift);
4912 }
4913
4914 return ptrShifted;
4915}
4916
4917template <HW hw>
4918template <typename BO, typename BI>
4919typename std::enable_if<!std::is_base_of<RegData, BO>::value>::type
4920gemm_kernel_generator_t<hw>::doneShift(
4921 const BO &ptr, const BI &ptrShifted, int shift, CommonState &state) {}
4922
4923template <HW hw>
4924template <typename BO, typename BI>
4925typename std::enable_if<std::is_base_of<RegData, BO>::value>::type
4926gemm_kernel_generator_t<hw>::doneShift(
4927 const BO &ptr, const BI &ptrShifted, int shift, CommonState &state) {
4928 if (shift > 0) state.ra.release(ptrShifted);
4929}
4930
4931template <HW hw>
4932void gemm_kernel_generator_t<hw>::doneShift(const SubregisterPair &ptr,
4933 const SubregisterPair &ptrShifted, int shift, CommonState &state) {
4934 if (shift > 0) doneShift(ptr.getReg(0), ptrShifted.getReg(0), shift, state);
4935}
4936
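// Check whether the addresses for one block can be derived from another
// block's addresses by a simple increment.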
4937static inline bool canIncAddr(const RegisterBlock &blockSrc,
4938 const RegisterBlock &blockDst, const MatrixAddressing &atype,
4939 const MatrixAddressingStrategy &astrategy) {
4940 if (!blockSrc.isLoadBlock() || !blockDst.isLoadBlock()) return false;
4941 if (effectiveAccessType(atype, astrategy, blockDst) == AccessType::Block
4942 && effectiveAccessType(atype, astrategy, blockSrc)
4943 == AccessType::Block)
4944 return true;
4945 if (isBlock2D(astrategy.accessType))
4946 return (blockSrc.nr == blockDst.nr && blockSrc.nc == blockDst.nc);
4947 return (blockSrc.simdSize >= blockDst.simdSize);
4948}
4949
4950// Output code for setting up address/header GRFs for a single block, given
4951// the base pointer (a Subregister, MultishiftSubregister or integer) and leading dimension.
4952template <HW hw>
4953template <typename BO>
4954void gemm_kernel_generator_t<hw>::setupAddr(const GRFRange &addr, const BO &ptr,
4955 const RegisterBlock &block, const Subregister &bld, size_t sizeofT,
4956 const MatrixAddressing &atype,
4957 const MatrixAddressingStrategy &astrategy,
4958 const CommonStrategy &strategy, CommonState &state,
4959 const Address2DParams &params, LDMultiples ldMultiples) {
4960 bool a64 = astrategy.base.getModel() == ModelA64;
4961
4962 // Nothing to do for non-load blocks.
4963 if (!block.isLoadBlock()) return;
4964
4965 auto effAccessType = effectiveAccessType(atype, astrategy, block);
4966 switch (effAccessType) {
4967 case AccessType::Scattered:
4968 case AccessType::ChannelScattered:
4969 case AccessType::PseudoBlock: {
4970 int simdSize = block.simdSize;
4971 auto consecutive = block.extra;
4972 bool pseudo = (effAccessType == AccessType::PseudoBlock);
4973 auto Tptr = a64 ? DataType::uq : DataType::ud;
4974 int ne = elementsPerGRF(hw, Tptr);
4975 int preshift = 0;
4976
4977 auto oldIndexVec = state.indexVec;
4978 auto oldIVEntries = state.ivEntries;
4979
4980 if (!pseudo && !isPacked(atype.layout)) {
4981 // Get pointers to successive rows/columns, strided by ld.
4982 bool allocLDMultiples = false;
4983 if (ldMultiples.range.isInvalid()) {
4984 ldMultiples = createLDMultiples(
4985 a64, simdSize, bld, strategy, state);
4986 allocLDMultiples = true;
4987 } else
4988 (void)findLDMultiple(
4989 ldMultiples, a64, simdSize - 1, strategy, state);
4990
4991 for (int r = 0; r < addr.getLen(); r += 2) {
4992 int nr = std::min(2, addr.getLen() - r);
4993 int simd = nr * ne;
4994 auto ld0 = findLDMultiple(ldMultiples, a64,
4995 r * ne / consecutive, strategy, state);
4996 auto ldStride = (ldMultiples.a64 && !a64) ? 2 : 1;
4997 auto ldR = ld0(ldStride, consecutive, 0);
4998 auto addrR = addr[r].retype(Tptr);
4999 if (a64 && consecutive > 1 && hw >= HW::XeHP
5000 && !strategy.emulate
5001 .emulate64) { /* no swizzle in L pipe */
5002 mov(simd, addr[r].ud(0)(2),
5003 ld0.ud(0)(ldStride * 2, consecutive, 0));
5004 mov(simd, addr[r].ud(1)(2),
5005 ld0.ud(1)(ldStride * 2, consecutive, 0));
5006 if (ptr != 0) add(simd, addrR, addrR, ptr);
5007 } else if (ptr != 0)
5008 eadd(simd, addrR, ptr, ldR, strategy, state);
5009 else
5010 emov(simd, addrR, ldR, strategy, state);
5011 }
5012 if (allocLDMultiples) releaseLDMultiples(ldMultiples, state);
5013 } else {
5014 // Get pointers to successive elements, with constant stride.
5015 extendIndexVec(simdSize, state);
5016 auto iv = accessIndexVec(0, state)(1, consecutive, 0);
5017 uint16_t stride;
5018 preshift = block.addrShift;
5019 auto ptrShifted = startShift(ptr, block.addrShift, state);
5020
5021 if (pseudo)
5022 stride = (block.ebytes * block.count
5023 * getPartialCrosspack(
5024 sizeofT, atype, block))
5025 >> preshift;
5026 else {
5027 int tile = isColMajor(atype.layout) ? atype.tileR
5028 : atype.tileC;
5029 if (tile == 0) tile = atype.packSize;
5030 int psElems = (isLargeCrosspack(sizeofT, atype.crosspack)
5031 ? 1
5032 : tile)
5033 * atype.crosspack;
5034 stride = uint16_t(psElems * sizeofT) >> preshift;
5035 }
5036
5037 if (a64) {
5038 int udStride = (hw >= HW::XeHP) ? 2 : 1;
5039 int simd1 = std::min(2 * ne, simdSize);
5040 int simd2 = simdSize - simd1;
5041 if (udStride == 2 && simd2) {
5042 auto iv2 = accessIndexVec(simd1 / consecutive, state)(
5043 1, consecutive, 0);
5044 mulConstant(
5045 simd2, addr[2].ud(0)(udStride), iv2, stride);
5046 mulConstant(simd1, addr[0].ud(0)(udStride), iv, stride);
5047 } else
5048 mulConstant(
5049 simdSize, addr[0].ud(0)(udStride), iv, stride);
5050 if (simd2)
5051 eadd(simd2, addr[2].uq(), ptrShifted,
5052 addr[udStride].ud(0)(udStride), strategy,
5053 state);
5054 eadd(simd1, addr[0].uq(), ptrShifted,
5055 addr[0].ud(0)(udStride), strategy, state);
5056 } else if (ptrShifted != 0) {
5057 if (consecutive > 1) {
5058 mulConstant<uint32_t>(simdSize, addr, iv, stride);
5059 add<uint32_t>(simdSize, addr, addr, ptrShifted);
5060 } else
5061 emad(simdSize, addr[0].ud(), ptrShifted, iv,
5062 int32_t(stride), strategy, state);
5063 } else
5064 mulConstant<uint32_t>(simdSize, addr, iv, stride);
5065
5066 doneShift(ptr, ptrShifted, block.addrShift, state);
5067 }
5068
5069 // Add offsets for consecutive elements in scattered accesses.
5070 if (consecutive > 1) {
5071 if ((consecutive - 1) * block.ebytes >= 0x10) stub();
5072 if (consecutive > 4) stub();
5073 uint8_t incs[4];
5074 for (int idx = 0; idx < 4; idx++)
5075 incs[idx]
5076 = (block.ebytes * (idx % consecutive)) >> preshift;
5077
5078 if (!a64) {
5079 auto incImm = Immediate::uv(
5080 incs[0], 0, incs[1], 0, incs[2], 0, incs[3], 0);
5081 add<uint32_t>(simdSize, addr, addr, incImm);
5082 } else {
5083 if (consecutive > 2) stub();
5084 auto incImm
5085 = Immediate::uv(incs[0], 0, 0, 0, incs[1], 0, 0, 0);
5086 auto temp = state.ra.alloc_range(2);
5087 mov<uint32_t>(
5088 2 * elementsPerGRF<uint32_t>(hw), temp, incImm);
5089 map(hw, Tptr, addr, addr, strategy,
5090 [&](int simd, GRF r1, GRF _) {
5091 eadd<uint64_t>(simd, r1, r1, temp[0].ud(0)(2),
5092 strategy, state);
5093 });
5094 state.ra.safeRelease(temp);
5095 }
5096 }
5097
            // Apply any remaining address shift not already folded in above.
5099 if (block.addrShift > preshift)
5100 shr<uint32_t>(simdSize, addr, addr, block.addrShift - preshift);
5101
5102 // Restore original cached index vector in case we extended it.
5103 releaseRanges(state.indexVec, state);
5104 state.indexVec = oldIndexVec;
5105 state.ivEntries = oldIVEntries;
5106 reclaimRanges(state.indexVec, state);
5107 break;
5108 }
5109 case AccessType::Block:
5110 if (astrategy.base.getModel() == ModelA64) {
5111 emov(1, addr[0].uq(0), ptr, strategy, state);
5112 // Disable OWord channel mode on SKL.
5113 if (block.ebytes == 32 && hw < HW::Gen10)
5114 mov(1, addr[0].ud(5), uint32_t(0x80000000));
5115 } else if (astrategy.newDP) {
5116 mov(1, addr[0].ud(0), ptr);
5117 } else if (block.addrShift > 0)
5118 shr(1, addr[0].ud(2), ptr, block.addrShift);
5119 else
5120 mov(1, addr[0].ud(2), ptr);
5121 break;
5122 case AccessType::Block2D:
5123 case AccessType::Block2DTranspose:
5124 case AccessType::Block2DVNNI:
5125 if (astrategy.base.getModel() != ModelA64) hw_unsupported();
5126
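            // Overview of the 2D block message header assembled below (one GRF);
            // field meanings inferred from the stores that follow:
            //   ud0-ud1: base address (aligned down if needed)
            //   ud2:     surface width in bytes, minus 1
            //   ud3:     surface height minus 1
            //   ud4:     surface pitch in bytes, minus 1
            //   ud5:     block x offset (in elements of size ebytes)
            //   ud6:     block y offset (in rows)
            //   ud7:     (block width - 1) | ((block height - 1) << 8) | ((count - 1) << 16)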
5127 // Assemble some information.
5128 bool memCM = isColMajor(atype.layout);
5129 int bw, bh, multiX;
5130 getBlock2DWH(bw, bh, atype, block, &multiX);
5131
5132 auto iremR = params.remR, iremC = params.remC;
5133 if (!block.remainderR) iremR.invalidate();
5134 if (!block.remainderC) iremC.invalidate();
5135
5136 auto remW = memCM ? iremR : iremC;
5137 auto remH = memCM ? iremC : iremR;
5138 auto &nx = memCM ? params.rows : params.cols;
5139 auto &ny = memCM ? params.cols : params.rows;
5140 auto fixedX = memCM ? params.fixedRows : params.fixedCols;
5141 auto fixedY = memCM ? params.fixedCols : params.fixedRows;
5142 auto &offX = memCM ? params.offR : params.offC;
5143 auto &offY = memCM ? params.offC : params.offR;
5144 auto boffX = memCM ? block.offsetR : block.offsetC;
5145 auto boffY = memCM ? block.offsetC : block.offsetR;
5146
5147 boffX *= uint8_t(sizeofT);
5148 if (boffX % block.ebytes) stub();
5149 boffX /= block.ebytes;
5150
            // If the base address may not be 64-byte aligned (128-byte before the PVC XT-B4 stepping),
            // emit code to align it down and adjust the x offset/width accordingly.
5153 int baseAlign = (getStepping() >= SteppingPVCXTB4 ? 64 : 128);
5154 bool doBaseAdjust = (atype.alignment & (baseAlign - 1)) != 0;
5155 if (doBaseAdjust && !astrategy.address2D) stub();
5156 Subregister baStorage, baseAdjust, baseAdjustElems;
5157
5158 if (!astrategy.address2D) mov(4, addr[0].ud(4)(1), 0u);
5159
5160 if (doBaseAdjust) {
5161 baStorage = state.ra.alloc_sub<uint32_t>();
5162 baseAdjust = baStorage.uw(0);
5163 baseAdjustElems = baStorage.uw(1);
5164 if (!offX.isValid()) baseAdjustElems = addr[0].ud(5);
5165
5166 and_(1, baseAdjust, ptr.ud(0), baseAlign - 1);
5167 and_(1, addr[0].ud(0), ptr.ud(0), ~uint32_t(baseAlign - 1));
5168 mov(1, addr[0].ud(1), ptr.ud(1));
5169 if (block.ebytes > 1)
5170 shr(1, baseAdjustElems, baseAdjust, log2(block.ebytes));
5171 else
5172 baseAdjustElems = baseAdjust;
5173 } else
5174 emov(1, addr[0].uq(0), ptr, strategy, state);
5175
5176 if (astrategy.address2D) {
5177 if (params.rows.isInvalid() && params.fixedRows == 0)
5178 throw std::runtime_error("Unknown matrix size.");
5179
5180 nx.isValid() ? mad(1, addr[0].ud(2), -1, nx, sizeofT)
5181 : mov(1, addr[0].ud(2), fixedX * sizeofT - 1);
5182 ny.isValid() ? add(1, addr[0].ud(3), ny, -1)
5183 : mov(1, addr[0].ud(3), fixedY - 1);
5184 offX.isValid() ? addScaled(1, addr[0].ud(5), boffX, offX,
5185 int(sizeofT), block.ebytes, state)
5186 : doBaseAdjust
5187 ? add(1, addr[0].ud(5), baseAdjustElems, boffX)
5188 : mov(1, addr[0].ud(5), boffX);
5189 offY.isValid() ? add(1, addr[0].ud(6), offY, boffY)
5190 : mov(1, addr[0].ud(6), boffY);
5191 if (doBaseAdjust) {
5192 add(1, addr[0].ud(2), addr[0].ud(2), baseAdjust);
5193 if (offX.isValid())
5194 add(1, addr[0].ud(5), addr[0].ud(5), baseAdjustElems);
5195 }
5196 if (sizeofT < 4)
5197 or_(1, addr[0].ud(2), addr[0].ud(2),
5198 3); // Width must be 4-byte-aligned.
5199 } else if (remW.isInvalid() && remH.isInvalid())
5200 emov(1, addr[0].uq(1),
5201 uint64_t(bw * block.count * block.ebytes - 1)
5202 | (uint64_t(bh * block.ebytes - 1) << 32),
5203 strategy, state);
5204 else {
5205 if (remW.isValid() && multiX > 1) stub();
5206 remW.isValid() ? mad(1, addr[0].ud(2), -1, remW.uw(), sizeofT)
5207 : mov(1, addr[0].ud(2),
5208 bw * block.count * block.ebytes - 1);
5209 remH.isValid() ? mad(1, addr[0].ud(3), -1, remH.uw(), multiX)
5210 : mov(1, addr[0].ud(3), bh - 1);
5211 if (remW.isValid() && sizeofT < 4)
5212 or_(1, addr[0].ud(2), addr[0].ud(2), 3);
5213 }
5214
5215 if (isPacked(atype.layout)) {
5216 auto pitch = bw * block.count * block.ebytes;
5217 if (pitch < 64 || pitch & 0xF) hw_unsupported();
5218 mov(1, addr[0].ud(4), pitch - 1);
5219 } else
5220 add(1, addr[0].ud(4), bld, -1);
5221
5222 mov(1, addr[0].ud(7),
5223 (bw - 1) | ((bh - 1) << 8) | ((block.count - 1) << 16));
5224
5225 state.ra.safeRelease(baStorage);
5226 break;
5227 }
5228}
5229
// Shift an address block by a combination of a fixed byte offset and a multiple of the leading dimension (ld).
5231template <HW hw>
5232void gemm_kernel_generator_t<hw>::offsetAddr(const GRFRange &addrDst,
5233 const GRFRange &addrSrc, const RegisterBlock &blockDst,
5234 const RegisterBlock &blockSrc, int offsetFixed, int offsetLD,
5235 const Subregister &ld, const MatrixAddressing &atype,
5236 const MatrixAddressingStrategy &astrategy,
5237 const CommonStrategy &strategy, CommonState &state,
5238 const LDMultiples &ldMultiples) {
5239 bool a64 = (astrategy.base.getModel() == ModelA64);
5240
5241 if (astrategy.address2D) stub();
5242
5243 if (offsetLD == 0) {
5244 if (offsetFixed != 0)
5245 incAddr(addrDst, addrSrc, offsetFixed, blockDst, blockSrc, atype,
5246 astrategy, strategy, state);
5247 } else {
5248 // Reuse ld * offsetLD calculation if available.
5249 auto ldInc
5250 = findLDMultiple(ldMultiples, a64, offsetLD, strategy, state);
5251
5252 if (ldInc.isValid() && offsetFixed == 0)
5253 incAddr(addrDst, addrSrc, (offsetLD == 1) ? ld : ldInc, blockDst,
5254 blockSrc, atype, astrategy, strategy, state);
5255 else {
5256 Subregister incAlloc
5257 = state.ra.alloc_sub(a64 ? DataType::uq : DataType::ud);
5258 auto inc = incAlloc;
5259
5260 if (ldInc.isInvalid()) {
5261 if (offsetLD == 1)
5262 ldInc = ld;
5263 else {
5264 emulConstant(1, inc, ld, offsetLD, strategy, state);
5265 ldInc = inc;
5266 }
5267 }
5268 if (offsetFixed != 0)
5269 eadd(1, inc, ldInc, offsetFixed, strategy, state);
5270 else
5271 inc = ldInc;
5272 incAddr(addrDst, addrSrc, inc, blockDst, blockSrc, atype, astrategy,
5273 strategy, state);
5274
5275 state.ra.safeRelease(incAlloc);
5276 }
5277 }
5278}
5279
5280// Output code for initializing address/header GRFs for one block based on another block's headers.
5281template <HW hw>
5282void gemm_kernel_generator_t<hw>::setupAddrRel(Type T, const GRFRange &addrDst,
5283 const GRFRange &addrSrc, const RegisterBlock &blockDst,
5284 const RegisterBlock &blockSrc, const vector<RegisterBlock> &layout,
5285 const Subregister &ld, const MatrixAddressing &atype,
5286 const MatrixAddressingStrategy &astrategy,
5287 const CommonStrategy &strategy, CommonState &state,
5288 const LDMultiples &ldMultiples) {
5289 int deltaR = blockDst.offsetR - blockSrc.offsetR;
5290 int deltaC = blockDst.offsetC - blockSrc.offsetC;
5291
5292 if (astrategy.address2D)
5293 incAddr(addrDst, addrSrc, Subregister(), deltaR, deltaC, blockDst,
5294 blockSrc, atype, astrategy, strategy, state);
5295 else {
5296 int offsetFixed = 0, offsetLD = 0, r = 0, c = 0;
5297
5298 if (isPacked(atype.layout)) getLayoutDims(layout, r, c);
5299
5300 switch (atype.layout) {
5301 case MatrixLayout::N:
5302 offsetFixed = deltaR;
5303 offsetLD = deltaC;
5304 break;
5305 case MatrixLayout::T:
5306 offsetFixed = deltaC;
5307 offsetLD = deltaR;
5308 break;
5309 case MatrixLayout::Pc:
5310 case MatrixLayout::Pr:
5311 offsetFixed = untile(T, atype, blockDst, r, c)
5312 - untile(T, atype, blockSrc, r, c);
5313 break;
5314 }
5315
5316 offsetFixed *= T.size();
5317
5318 offsetAddr(addrDst, addrSrc, blockDst, blockSrc, offsetFixed, offsetLD,
5319 ld, atype, astrategy, strategy, state, ldMultiples);
5320 }
5321}
5322
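// Search layout[start, end) for a block from whose address header this block's addresses can be
// derived by an increment. A compatible block sharing a row or column offset is treated as a
// "best fit" and returned immediately; otherwise the first compatible block is returned,
// or -1 if there is none.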
5323static inline int findBaseBlock(const RegisterBlock &block,
5324 const vector<RegisterBlock> &layout, int start, int end,
5325 const MatrixAddressing &atype,
5326 const MatrixAddressingStrategy &astrategy) {
5327 int bbase = -1;
5328 for (int bb = start; bb < end; bb++) {
5329 auto &other = layout[bb];
5330 if (canIncAddr(other, block, atype, astrategy)) {
5331 if (bbase < 0) bbase = bb;
5332 if (other.offsetR == block.offsetR
5333 || other.offsetC == block.offsetC)
5334 return bb; // "Best fit"
5335 }
5336 }
5337 return bbase;
5338}
5339
5340// Output code for initializing address/header GRFs for an entire register layout.
5341// ptr is an integer, Subregister, or MultishiftSubregister holding the base pointer/offset.
5342template <HW hw>
5343template <typename BO>
5344void gemm_kernel_generator_t<hw>::setupAddr(Type T,
5345 const vector<GRFRange> &addr, const BO &ptr,
5346 const vector<RegisterBlock> &layout, const Subregister &ld,
5347 const MatrixAddressing &atype,
5348 const MatrixAddressingStrategy &astrategy,
5349 const CommonStrategy &strategy, CommonState &state,
5350 const Address2DParams &params, const LDMultiples &ldMultiples) {
5351 auto nblocks = int(layout.size());
5352
5353 for (int b = 0; b < nblocks; b++) {
5354 auto &block = layout[b];
5355
5356 // Skip non-load blocks.
5357 if (!block.isLoadBlock()) continue;
5358
5359 auto bparams = params;
5360 Subregister tempRem;
5361 if (isBlock2D(astrategy.accessType) && !astrategy.address2D) {
5362 tempRem = state.ra.alloc_sub<uint32_t>();
5363 if (bparams.remR.isValid()) bparams.remR = tempRem.uw(0);
5364 if (bparams.remC.isValid()) bparams.remC = tempRem.uw(1);
5365 if (bparams.remR.isValid() && block.offsetR)
5366 add(1 | sat, bparams.remR, params.remR, -block.offsetR);
5367 if (bparams.remC.isValid() && block.offsetC)
5368 add(1 | sat, bparams.remC, params.remC, -block.offsetC);
5369 if (bparams.remR.isValid())
5370 min_(1, bparams.remR,
5371 block.offsetR ? bparams.remR : params.remR, block.nr);
5372 if (bparams.remC.isValid())
5373 min_(1, bparams.remC,
5374 block.offsetC ? bparams.remC : params.remC, block.nc);
5375 }
5376 // Look for a block to base this one off of.
5377 int bbase = findBaseBlock(block, layout, 0, b, atype, astrategy);
5378
5379 if (bbase < 0) {
5380 // No base address, set up a new base address.
5381 setupAddr(addr[b], ptr, block, ld, T.size(), atype, astrategy,
5382 strategy, state, bparams, ldMultiples);
5383 state.ra.safeRelease(tempRem);
5384 }
5385
5386 // Increment as appropriate.
5387 if (bbase >= 0) {
5388 setupAddrRel(T, addr[b], addr[bbase], block, layout[bbase], layout,
5389 ld, atype, astrategy, strategy, state, ldMultiples);
5390 } else if (!astrategy.address2D) {
5391 int offsetFixed = 0, offsetLD = 0, r = 0, c = 0;
5392 if (isPacked(atype.layout)) getLayoutDims(layout, r, c);
5393 switch (atype.layout) {
5394 case MatrixLayout::N:
5395 offsetFixed = block.offsetR;
5396 offsetLD = block.offsetC;
5397 break;
5398 case MatrixLayout::T:
5399 offsetFixed = block.offsetC;
5400 offsetLD = block.offsetR;
5401 break;
5402 case MatrixLayout::Pc:
5403 case MatrixLayout::Pr:
5404 offsetFixed = untile(T, atype, block, r, c);
5405 break;
5406 }
5407
5408 offsetFixed *= T.size();
5409
5410 offsetAddr(addr[b], addr[b], block, block, offsetFixed, offsetLD,
5411 ld, atype, astrategy, strategy, state, ldMultiples);
5412 }
5413 }
5414}
5415
5416// Output code for incrementing the pointers for a given block by a specified # of bytes.
5417// The amount may be an immediate, Subregister, or MultishiftSubregister.
5418template <HW hw>
5419template <typename I, typename Ir, typename Ic>
5420void gemm_kernel_generator_t<hw>::incAddr(const GRFRange &addrDst,
5421 const GRFRange &addrSrc, I inc, Ir incR, Ic incC,
5422 const RegisterBlock &layoutDst, const RegisterBlock &layoutSrc,
5423 const MatrixAddressing &atype,
5424 const MatrixAddressingStrategy &astrategy,
5425 const CommonStrategy &strategy, CommonState &state) {
5426 auto incShifted = startShift(inc, layoutDst.addrShift, state);
5427
5428 incAddrShifted(addrDst, addrSrc, incShifted, incR, incC, layoutDst,
5429 layoutSrc, atype, astrategy, strategy, state);
5430
5431 doneShift(inc, incShifted, layoutDst.addrShift, state);
5432}
5433
5434template <HW hw>
5435template <typename I>
5436void gemm_kernel_generator_t<hw>::incAddr(const GRFRange &addrDst,
5437 const GRFRange &addrSrc, I inc, const RegisterBlock &layoutDst,
5438 const RegisterBlock &layoutSrc, const MatrixAddressing &atype,
5439 const MatrixAddressingStrategy &astrategy,
5440 const CommonStrategy &strategy, CommonState &state) {
5441 if (astrategy.address2D) stub();
5442 incAddr(addrDst, addrSrc, inc, Subregister(), Subregister(), layoutDst,
5443 layoutSrc, atype, astrategy, strategy, state);
5444}
5445
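// Variant of incAddr for increments that have already been shifted by the block's address shift.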
5446template <HW hw>
5447template <typename I, typename Ir, typename Ic>
5448void gemm_kernel_generator_t<hw>::incAddrShifted(const GRFRange &addrDst,
5449 const GRFRange &addrSrc, I inc, Ir incR, Ic incC,
5450 const RegisterBlock &layoutDst, const RegisterBlock &layoutSrc,
5451 const MatrixAddressing &atype,
5452 const MatrixAddressingStrategy &astrategy,
5453 const CommonStrategy &strategy, CommonState &state) {
    // Nothing to do for non-load destination blocks; a non-load source block is unexpected.
5455 if (!layoutDst.isLoadBlock()) return;
5456 if (!layoutSrc.isLoadBlock()) stub();
5457
5458 if (layoutDst.addrShift != layoutSrc.addrShift) stub();
5459
5460 auto cinc = avoidConflict(hw, inc, addrSrc[0]);
5461 auto cincR = avoidConflict(hw, incR, addrSrc[0]);
5462 auto cincC = avoidConflict(hw, incC, addrSrc[0]);
5463
5464 switch (effectiveAccessType(atype, astrategy, layoutSrc)) {
5465 case AccessType::PseudoBlock:
5466 if (layoutSrc.ebytes != layoutDst.ebytes) stub();
5467 // fall through
5468 case AccessType::ChannelScattered:
5469 case AccessType::Scattered: {
5470 int naddrDst = layoutDst.simdSize;
5471 int naddrSrc = layoutSrc.simdSize;
5472 if (naddrDst > naddrSrc) stub();
5473 if (astrategy.base.getModel() == ModelA64) {
5474 auto simd = 2 * elementsPerGRF(hw, Type::u64);
5475 for (int ar = 0; naddrDst > 0; ar += 2, naddrDst -= simd)
5476 eadd<uint64_t>(std::min(naddrDst, simd), addrDst[ar],
5477 addrSrc[ar], avoidConflict(hw, inc, addrSrc[ar]),
5478 strategy, state);
5479 } else
5480 add<uint32_t>(naddrDst, addrDst[0], addrSrc[0], cinc);
5481 break;
5482 }
5483 case AccessType::Block:
5484 if (astrategy.base.getModel() == ModelA64) {
5485 eadd(1, addrDst[0].uq(0), addrSrc[0].uq(0), cinc, strategy,
5486 state);
5487 if (addrDst != addrSrc && layoutDst.ebytes == 32
5488 && hw < HW::Gen10)
5489 mov(1, addrDst[0].ud(5),
5490 uint32_t(
5491 0x80000000)); // Disable OWord channel mode on SKL.
5492 } else if (astrategy.newDP) {
5493 add(1, addrDst[0].ud(0), addrSrc[0].ud(0), cinc);
5494 } else
5495 add(1, addrDst[0].ud(2), addrSrc[0].ud(2), cinc);
5496 break;
5497 case AccessType::Block2D:
5498 case AccessType::Block2DTranspose:
5499 case AccessType::Block2DVNNI:
5500 if (addrDst != addrSrc) mov<uint32_t>(8, addrDst[0], addrSrc[0]);
5501 if (astrategy.address2D) {
5502 if (isColMajor(atype.layout)) {
5503 if (cincR != 0)
5504 addScaled(1, addrDst[0].d(5), addrDst[0].d(5), cincR,
5505 layoutDst.extra, layoutDst.ebytes, state, true);
5506 if (cincC != 0)
5507 add(1, addrDst[0].d(6), addrDst[0].d(6), cincC);
5508 } else {
5509 if (cincC != 0)
5510 addScaled(1, addrDst[0].d(5), addrDst[0].d(5), cincC,
5511 layoutDst.extra, layoutDst.ebytes, state, true);
5512 if (cincR != 0)
5513 add(1, addrDst[0].d(6), addrDst[0].d(6), cincR);
5514 }
5515 } else
5516 eadd(1, addrDst[0].uq(0), addrSrc[0].uq(0), cinc, strategy,
5517 state);
5518 break;
5519 }
5520}
5521
5522// Output code for incrementing all pointers for a register layout by a specified # of bytes.
5523// The amount may be an immediate or a subregister.
5524template <HW hw>
5525template <typename I, typename Ir, typename Ic>
5526void gemm_kernel_generator_t<hw>::incAddr(const vector<GRFRange> &addr, I inc,
5527 Ir incR, Ic incC, const vector<RegisterBlock> &layout,
5528 const MatrixAddressing &atype,
5529 const MatrixAddressingStrategy &astrategy,
5530 const CommonStrategy &strategy, CommonState &state) {
5531 auto nblocks = int(layout.size());
5532
5533 for (int b = 0; b < nblocks; b++)
5534 incAddr(addr[b], addr[b], inc, incR, incC, layout[b], layout[b], atype,
5535 astrategy, strategy, state);
5536}
5537
5538template <HW hw>
5539template <typename I>
5540void gemm_kernel_generator_t<hw>::incAddr(const vector<GRFRange> &addr, I inc,
5541 const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
5542 const MatrixAddressingStrategy &astrategy,
5543 const CommonStrategy &strategy, CommonState &state) {
5544 if (astrategy.address2D) stub();
5545 incAddr(addr, inc, Subregister(), Subregister(), layout, atype, astrategy,
5546 strategy, state);
5547}
5548
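// Same as above, but for increments that have already been shifted by the blocks' address shift.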
5549template <HW hw>
5550template <typename I, typename Ir, typename Ic>
5551void gemm_kernel_generator_t<hw>::incAddrShifted(const vector<GRFRange> &addr,
5552 I inc, Ir incR, Ic incC, const vector<RegisterBlock> &layout,
5553 const MatrixAddressing &atype,
5554 const MatrixAddressingStrategy &astrategy,
5555 const CommonStrategy &strategy, CommonState &state) {
5556 auto nblocks = int(layout.size());
5557
5558 for (int b = 0; b < nblocks; b++)
5559 incAddrShifted(addr[b], addr[b], inc, incR, incC, layout[b], layout[b],
5560 atype, astrategy, strategy, state);
5561}
5562
5563template <HW hw>
5564template <typename I>
5565void gemm_kernel_generator_t<hw>::incAddrShifted(const vector<GRFRange> &addr,
5566 I inc, const vector<RegisterBlock> &layout,
5567 const MatrixAddressing &atype,
5568 const MatrixAddressingStrategy &astrategy,
5569 const CommonStrategy &strategy, CommonState &state) {
5570 if (astrategy.address2D) stub();
5571 incAddrShifted(addr, inc, Subregister(), Subregister(), layout, atype,
5572 astrategy, strategy, state);
5573}
5574
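// Map an increment type to the signed type used when negating it for decrements (see incDecAddr).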
5575template <typename T>
5576struct NegativeType {
5577 typedef T type;
5578};
5579template <>
5580struct NegativeType<uint8_t> {
5581 typedef int8_t type;
5582};
5583template <>
5584struct NegativeType<uint16_t> {
5585 typedef int16_t type;
5586};
5587template <>
5588struct NegativeType<uint32_t> {
5589 typedef int32_t type;
5590};
5591template <>
5592struct NegativeType<int> {
5593 typedef int32_t type;
5594};
5595template <>
5596struct NegativeType<int64_t> {
5597 typedef int32_t type;
5598};
5599
5600// Output code for incrementing or decrementing all pointers for a register layout by a specified # of bytes.
5601// The amount may be an immediate or a MultishiftSubregister.
5602template <HW hw>
5603template <typename A, typename I, typename Ir, typename Ic>
5604void gemm_kernel_generator_t<hw>::incDecAddr(const A &addr, I inc, Ir incR,
5605 Ic incC, const vector<RegisterBlock> &layout,
5606 const MatrixAddressing &atype,
5607 const MatrixAddressingStrategy &astrategy,
5608 const CommonStrategy &strategy, CommonState &state, bool decrement) {
5609 typename NegativeType<I>::type signedInc = decrement ? -inc : inc;
5610 typename NegativeType<Ir>::type signedIncR = decrement ? -incR : incR;
5611 typename NegativeType<Ic>::type signedIncC = decrement ? -incC : incC;
5612
5613 incAddr(addr, signedInc, signedIncR, signedIncC, layout, atype, astrategy,
5614 strategy, state);
5615}
5616
5617template <HW hw>
5618template <typename A, typename I>
5619void gemm_kernel_generator_t<hw>::incDecAddr(const A &addr, I inc,
5620 const vector<RegisterBlock> &layout, const MatrixAddressing &atype,
5621 const MatrixAddressingStrategy &astrategy,
5622 const CommonStrategy &strategy, CommonState &state, bool decrement) {
5623 if (astrategy.address2D) stub();
5624 incDecAddr(addr, inc, Subregister(), Subregister(), layout, atype,
5625 astrategy, strategy, state, decrement);
5626}
5627
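// Update the width/height fields of an existing block 2D address header to reflect new
// row/column remainders. No-op for non-2D accesses or fully 2D-addressed blocks.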
5628template <HW hw>
5629void gemm_kernel_generator_t<hw>::setAddrRemainder(Type T, const GRFRange &addr,
5630 const RegisterBlock &block, const Subregister &remR,
5631 const Subregister &remC, const MatrixAddressing &atype,
5632 const MatrixAddressingStrategy &astrategy,
5633 const CommonStrategy &strategy, CommonState &state) {
5634 if (!isBlock2D(astrategy.accessType) || astrategy.address2D) return;
5635
5636 auto tempRem = state.ra.alloc_sub<uint32_t>();
5637 Subregister thisRemR = remR, thisRemC = remC;
5638
5639 auto memCM = isColMajor(atype.layout);
5640 auto &remW = memCM ? thisRemR : thisRemC;
5641 auto &remH = memCM ? thisRemC : thisRemR;
5642 int bw, bh, multiX;
5643 getBlock2DWH(bw, bh, atype, block, &multiX);
5644
5645 if (!block.remainderR) thisRemR.invalidate();
5646 if (!block.remainderC) thisRemC.invalidate();
5647 if (thisRemR.isValid()) thisRemR = tempRem.uw(0);
5648 if (thisRemC.isValid()) thisRemC = tempRem.uw(1);
5649 if (thisRemR.isValid() && block.offsetR)
5650 add(1 | sat, thisRemR, remR, -block.offsetR);
5651 if (thisRemC.isValid() && block.offsetC)
5652 add(1 | sat, thisRemC, remC, -block.offsetC);
5653 if (thisRemR.isValid())
5654 min_(1, thisRemR, block.offsetR ? thisRemR : remR, block.nr);
5655 if (thisRemC.isValid())
5656 min_(1, thisRemC, block.offsetC ? thisRemC : remC, block.nc);
5657
5658 if (remW.isValid()) {
5659 if (block.count > 1 || multiX > 1) stub();
5660 mad(1, addr[0].ud(2), -1, remW.uw(), T.size());
5661 }
5662 if (remH.isValid()) mad(1, addr[0].ud(3), -1, remH.uw(), T.size() * multiX);
5663 if (remW.isValid() && T.size() < 4) or_(1, addr[0].ud(2), addr[0].ud(2), 3);
5664
5665 state.ra.safeRelease(tempRem);
5666}
5667
5668template <HW hw>
5669void gemm_kernel_generator_t<hw>::setAddrRemainder(Type T,
5670 const vector<GRFRange> &addr, const vector<RegisterBlock> &layout,
5671 const Subregister &remR, const Subregister &remC,
5672 const MatrixAddressing &atype,
5673 const MatrixAddressingStrategy &astrategy,
5674 const CommonStrategy &strategy, CommonState &state) {
5675 auto nblocks = int(layout.size());
5676
5677 for (int b = 0; b < nblocks; b++)
5678 setAddrRemainder(T, addr[b], layout[b], remR, remC, atype, astrategy,
5679 strategy, state);
5680}
5681
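// Set up (setup = true) or release (setup = false) the register mask state.remaskRegs[index]
// used to zero out-of-bounds elements in registers. The mask covers nq elements of type T and
// is derived from the remainder remQ (optionally offset by fixedOffQ/variableOffQ): per-element
// indices 0, 1, 2, ... are compared against the remainder, yielding all-ones for in-bounds
// elements and zeros otherwise.
// For example (a sketch): with T = u16, nq = 8, remQ = 5, the mask words come out as
// 0xFFFF x5 followed by 0x0000 x3, which remaskLayout then ANDs against the loaded data.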
5682template <HW hw>
5683void gemm_kernel_generator_t<hw>::setupTeardownRemask(Type T, int index,
5684 bool setup, int nq, const Subregister &remQ,
5685 const CommonStrategy &strategy, CommonState &state, int fixedOffQ,
5686 const Subregister &variableOffQ) {
5687 if (setup) {
5688 auto masks = state.remaskRegs[index] = state.ra.alloc_range(
5689 div_up(T.size(), 2) * div_up(nq * 2, GRF::bytes(hw)));
5690 int ne16 = elementsPerGRF(hw, Type::u16);
5691 int n16 = std::min(nq, ne16);
5692 int ne = elementsPerGRF(hw, T);
5693 auto flag = state.raVFlag.tryAlloc((n16 > 16) ? 2 : 1);
5694 bool useCMP = flag.isValid()
5695 && (T.size() < 4); // apparent issues with 4b sequence
5696
5697 auto effRemQ = remQ;
5698 bool freeEffRemQ = false;
5699 bool haveVariableOff = variableOffQ.isValid();
5700 bool haveFixedOff = (fixedOffQ != 0);
5701
5702 if (haveVariableOff || haveFixedOff) {
5703 freeEffRemQ = true;
5704 effRemQ = state.ra.alloc_sub<uint32_t>();
5705
5706 if (haveVariableOff && haveFixedOff)
5707 eadd3(1, effRemQ, remQ, -variableOffQ, -fixedOffQ);
5708 else if (haveVariableOff)
5709 add(1, effRemQ, remQ, -variableOffQ);
5710 else
5711 add(1, effRemQ, remQ, -fixedOffQ);
5712 }
5713
5714 mov<uint16_t>(8, masks[0][0](1), Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7));
5715 if (nq > 8)
5716 mov<uint16_t>(8, masks[0][8](1),
5717 Immediate::uv(8, 9, 10, 11, 12, 13, 14, 15));
5718 if (GRF::bytes(hw) > 32 && nq > 16)
5719 add<uint16_t>(16, masks[0][16](1), masks[0][0](1), 16);
5720 add<uint16_t>(n16, masks[0], masks[0], -effRemQ.w());
5721 if (!useCMP)
5722 for (int q0 = n16; q0 < nq; q0 += n16)
5723 add<uint16_t>(n16, masks[q0 / n16], masks[0], q0);
5724
5725 switch (T.size()) {
5726 case 1:
5727 case 2:
5728 if (useCMP) {
5729 for (int q0 = n16; q0 < nq; q0 += n16)
5730 cmp<int16_t>(n16 | lt | flag, masks[q0 / n16], masks[0],
5731 -q0);
5732 asr<int16_t>(n16, masks[0], masks[0], 15);
5733 } else {
5734 map(hw, Type::s16, masks, masks, strategy,
5735 [=](int simd, const RegData &r1, const RegData &) {
5736 asr(simd, r1, r1, 15);
5737 });
5738 }
5739 if (T.size() == 1)
5740 for (int q0 = 0; q0 < nq; q0 += n16)
5741 mov(n16, masks[q0 / ne].ub(q0 % ne)(1),
5742 masks[q0 / n16].ub(1)(2));
5743 break;
5744 case 4:
5745 for (int qq0 = div_up(nq, ne16) - 1; qq0 >= 1; qq0--) {
5746 useCMP ? cmp(ne16 | lt | flag, masks[qq0 * 2].d(),
5747 masks[qq0].w(), -qq0 * ne16)
5748 : asr(ne16, masks[qq0 * 2].d(), masks[qq0].w(), 15);
5749 }
5750 if (nq > (ne16 / 2))
5751 asr(ne16 / 2, masks[1].d(), masks[0].w(ne16 / 2)(1), 15);
5752 asr(ne16 / 2, masks[0].d(), masks[0].w(), 15);
5753 break;
5754 default: stub();
5755 }
5756
5757 if (freeEffRemQ) state.ra.safeRelease(effRemQ);
5758 state.raVFlag.safeRelease(flag);
5759 } else
5760 state.ra.safeRelease(state.remaskRegs[index]);
5761}
5762
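// Apply a mask previously built by setupTeardownRemask to an entire register layout,
// ANDing each block's data with the mask registers so that elements at or beyond the
// remainder become zero.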
5763template <HW hw>
5764void gemm_kernel_generator_t<hw>::remaskLayout(Type T, int index, bool column,
5765 const std::vector<RegisterBlock> &layout, const GRFMultirange &regs,
5766 const CommonStrategy &strategy, CommonState &state, int offset) {
5767 for (auto &block : layout) {
5768 auto crosspack = block.crosspack;
5769 bool colMajor = block.colMajor;
5770 int nx = colMajor ? block.nr : block.nc;
5771 int ny = colMajor ? block.nc : block.nr;
5772
5773 for (int y0 = 0; y0 < ny; y0 += crosspack) {
5774 for (int x0 = 0; x0 < nx;) {
5775 auto ii0 = colMajor ? x0 : y0;
5776 auto jj0 = colMajor ? y0 : x0;
5777 auto i0 = ii0 + block.offsetR;
5778 auto j0 = jj0 + block.offsetC;
5779
5780 int ne;
5781 auto sub = findBlockReg(
5782 T, block, ii0, jj0, regs, ne, -1, block.component);
5783
5784 auto necp = ne * crosspack;
5785 necp = std::min(necp, 2 * elementsPerGRF(hw, T));
5786 if ((necp * T) & 3) stub();
5787
5788 int mstride;
5789 Type mtype = Type::u32;
5790
5791 if (colMajor != column && crosspack == 1)
5792 mstride = 1;
5793 else if (colMajor != column && crosspack == 4 / T)
5794 mstride = 1, mtype = sintType(T);
5795 else if (colMajor == column && crosspack == 4 / T)
5796 mstride = 0;
5797 else
5798 stub();
5799
5800 int moff = (offset + (column ? j0 : i0)) * T / mtype;
5801 int mreg = moff / elementsPerGRF(hw, mtype);
5802 int msub = moff % elementsPerGRF(hw, mtype);
5803
5804 and_<uint32_t>((necp * T) / 4, sub.ud()(1), sub.ud()(1),
5805 state.remaskRegs[index][mreg].sub(msub, mtype.ngen())(
5806 mstride));
5807 x0 += necp / crosspack;
5808 }
5809 }
5810 }
5811}
5812
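// Check whether data loaded for this block needs to be remasked in registers, i.e. whether the
// access's effective mask granularity (in bytes) is coarser than the element size.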
5813static bool needsRemask(Type T, bool column, const RegisterBlock &block,
5814 const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
5815 if (!ignoreMasks)
5816 if (column ? !block.remainderC : !block.remainderR) return false;
5817
5818 bool block2D = isBlock2D(astrategy.accessType);
5819
5820 int maskGranularity = block.ebytes;
5821 if (block.ebytes >= 16) maskGranularity = 4;
5822 if (block2D) maskGranularity = std::max(maskGranularity, 4);
5823 if (ignoreMasks && !(block2D && astrategy.address2D)) maskGranularity = 256;
5824
5825 return (T.size() < maskGranularity);
5826}
5827
5828static bool needsRemask(Type T, bool column,
5829 const vector<RegisterBlock> &layout,
5830 const MatrixAddressingStrategy &astrategy, bool ignoreMasks = false) {
5831 for (auto &block : layout)
5832 if (needsRemask(T, column, block, astrategy, ignoreMasks)) return true;
5833 return false;
5834}
5835
// The systolic array performs a series of GEMVs with a single fixed-size matrix.
// The size of the matrix is osys x ksys, with vectors of size ksys x 1.
// The number of GEMVs (all sharing the same matrix) is given by the (variable) repeat count.
5839struct SystolicParams {
    int opsPerChan; // # of FMAs per channel per systolic stage
    int sdepth; // Number of stages (systolic depth)
    int rcountMax; // Maximum repeat count (# of RHS vectors per instruction)
    int ksys; // Total # of FMAs per output element (sdepth * opsPerChan)
    int osys; // Output vector length
5845};
5846
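// Query the systolic (dpas) parameters for this problem. As an illustration (assuming a
// 64-byte GRF and f32 accumulation): bf16/f16 inputs give opsPerChan = 2, ksys = 16, osys = 16,
// while int8 inputs give opsPerChan = 4 and ksys = 32.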
5847static inline SystolicParams systolicParams(
5848 HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) {
5849 SystolicParams params;
5850 params.opsPerChan = std::max(
5851 1, std::min(4 / problem.Ta.real(), 4 / problem.Tb.real()));
5852 params.sdepth = 8;
5853 params.ksys = params.sdepth * params.opsPerChan;
5854 params.osys = GRF::bytes(hw) / std::max(problem.Tc.real().size(), 4);
5855 params.rcountMax = 8;
5856
5857 if (hw == HW::XeHPC) {
5858 // Workaround for src2 read suppression bug (TODO PVC-B: remove WA)
5859 bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C);
5860 if (strategy.unroll[cColMajor ? LoopN : LoopM] == 8)
5861 params.rcountMax = 4;
5862 }
5863
5864 return params;
5865}
5866
// Return the minimum # of outer products that must be performed at once.
5868static inline int minOuterProductCount(
5869 HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) {
5870 auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc;
5871 if (strategy.systolic) {
5872 auto params = systolicParams(hw, problem, strategy);
5873 return params.ksys;
5874 }
5875 if (Ta.real().size() == 1 && Tb.real().size() == 1 && Tc.real().size() == 4
5876 && (hw >= HW::Gen12LP))
5877 return 4;
5878 return 1;
5879}
5880
// Return the total # of outer products performed at once (minimum count times the k-chaining factor).
5882static inline int outerProductCount(
5883 HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) {
5884 return minOuterProductCount(hw, problem, strategy) * strategy.kChain;
5885}
5886
5887// Get the A and B crosspacks needed by the kernel. 0 indicates any crosspack is OK.
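// For example, a systolic bf16 kernel with column-major C needs A crosspacked by 2
// (pairs of k values packed together) while B stays at crosspack 1.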
5888static std::tuple<int, int> targetKernelCrosspack(
5889 HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) {
5890 int opBatch = minOuterProductCount(hw, problem, strategy);
5891 bool aColMajor = isRegisterColMajor(problem.Ta, problem.A, strategy.A);
5892 bool bColMajor = isRegisterColMajor(problem.Tb, problem.B, strategy.B);
5893 bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C);
5894
5895 if (strategy.systolic) {
5896 return cColMajor
5897 ? std::make_tuple(std::max(1, 4 / problem.Ta.real()), 1)
5898 : std::make_tuple(1, std::max(1, 4 / problem.Tb.real()));
5899 }
5900 if (opBatch == 1) {
5901 return cColMajor ? std::make_tuple(1, 0) : std::make_tuple(0, 1);
5902 } else {
5903 bool bcastOK = cColMajor ? bColMajor : !aColMajor;
5904
5905 return cColMajor ? std::make_tuple(opBatch, bcastOK ? 1 : opBatch)
5906 : std::make_tuple(bcastOK ? 1 : opBatch, opBatch);
5907 }
5908}
5909
5910// Get the A and B crosspacks to use for SLM data.
5911static std::tuple<int, int> targetSLMCrosspack(
5912 HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) {
5913 int opBatch = minOuterProductCount(hw, problem, strategy);
5914
5915 if (strategy.systolic) {
5916 bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C);
5917 return cColMajor
5918 ? std::make_tuple(std::max(1, 4 / problem.Ta.size()), opBatch)
5919 : std::make_tuple(opBatch, std::max(1, 4 / problem.Tb.size()));
5920 }
5921 return std::make_tuple(opBatch, opBatch);
5922}
5923
5924// Get the A and B tiling needed by the kernel.
5925// Return value is in the format {A_tileR, A_tileC, B_tileR, B_tileC}.
5926static std::tuple<int, int, int, int> targetKernelTiling(
5927 HW hw, const GEMMProblem &problem, const GEMMStrategy &strategy) {
5928 if (strategy.systolic) {
5929 auto params = systolicParams(hw, problem, strategy);
5930 bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C);
5931 auto tileO_V = params.osys;
5932 auto tileI_N = params.ksys;
5933 if (strategy.unroll[cColMajor ? LoopN : LoopM] == 1) tileI_N = 0;
5934 return cColMajor ? std::make_tuple(tileO_V, 0, tileI_N, 0)
5935 : std::make_tuple(0, tileI_N, 0, tileO_V);
5936 }
5937 return std::make_tuple(0, 0, 0, 0);
5938}
5939
// Perform a chunk of opCount outer products of A*B, updating C. ha and hb are the
// starting k indices within the A and B chunks, respectively.
5943template <HW hw>
5944void gemm_kernel_generator_t<hw>::outerProduct(int h, int ha, int hb,
5945 int opCount, const vector<RegisterBlock> &A_layout,
5946 const vector<RegisterBlock> &B_layout, const GRFMultirange &A_regs,
5947 const GRFMultirange &B_regs, const GEMMProblem &problem,
5948 const GEMMStrategy &strategy, GEMMState &state) {
5949 auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc;
5950
5951 if (strategy.systolic) {
5952 outerProductSystolic(h, ha, hb, A_layout, B_layout, A_regs, B_regs,
5953 problem, strategy, state);
5954 return;
5955 }
5956 if (isGen9IGEMM(hw, Ta, Tb, Tc)) {
5957 outerProductGen9IGEMM(ha, hb, A_layout, B_layout, A_regs, B_regs,
5958 problem, strategy, state);
5959 return;
5960 }
5961
5962 bool mixedMode = ((Tc.real() == Type::f32)
5963 && (Ta.real() != Type::f32 || Tb.real() != Type::f32));
5964 bool useDP4A = (Ta.size() == 1 && Tb.size() == 1 && Tc.size() == 4
5965 && hw >= HW::Gen12LP);
5966
5967 int minOPCount = minOuterProductCount(hw, problem, strategy);
5968 int kChain = std::min(strategy.kChain, opCount);
5969 int aCP, bCP;
5970 std::tie(aCP, bCP) = targetKernelCrosspack(hw, problem, strategy);
5971
5972 int accNum = 0;
5973 Subregister Clast;
5974 int nec = elementsPerGRF(hw, Tc);
5975 bool globalCM = isLayoutColMajor(state.C_layout);
5976 int fmaSIMD = strategy.fmaSIMD;
5977
5978 bool csplit = false, mixedRC = false;
5979 int icompCount = 1, ocompCount = 1, ivcompCount = 1, ovcompCount = 1;
5980
5981 bool bfloat16WA = (Tc.real() == Type::f32)
5982 && ((globalCM ? Tb : Ta).real() == Type::bf16);
5983
    // Emit one FMA-like instruction (mad, mac, or dp4a), chaining partial results through accumulator registers as needed.
5985 auto outputFMA = [&](const InstructionModifier &mod, const Subregister &A,
5986 const Subregister &B, const Subregister &C,
5987 const RegData &bcastSrc, bool colMajor, int hh,
5988 bool ivfirst, bool ivlast) {
5989 auto Cacc = AccumulatorRegister(accNum).sub(0, Tc.real().ngen());
5990 auto Csrc = (hh == 0 && ivfirst) ? C : Cacc;
5991 auto Cdst = (hh == opCount - minOPCount && ivlast) ? C : Cacc;
5992 if (useDP4A) {
5993 auto Ar = A.reinterpret(
5994 0, isSigned(A.getType()) ? DataType::d : DataType::ud);
5995 auto Br = B.reinterpret(
5996 0, isSigned(B.getType()) ? DataType::d : DataType::ud);
5997
5998 colMajor ? dp4a(mod, Cdst(1), Csrc(1), Ar(1), Br(0))
5999 : dp4a(mod, Cdst(1), Csrc(1), Br(1), Ar(0));
6000 } else if (C.isARF() && hw < HW::XeHP) {
6001 colMajor ? mac(mod, C(1), A(1), bcastSrc)
6002 : mac(mod, C(1), bcastSrc, B(1));
6003 } else {
6004 // On Gen12, always put broadcast in src2 for better bank conflict avoidance.
6005 colMajor ? mad(mod, Cdst(1), Csrc(1), A(1), bcastSrc)
6006 : (hw < HW::Gen12LP)
6007 ? mad(mod, Cdst(1), Csrc(1), bcastSrc, B(1))
6008 : mad(mod, Cdst(1), Csrc(1), B(1), bcastSrc);
6009 }
6010 };
6011
6012 ha = align_down(ha, opCount);
6013 hb = align_down(hb, opCount);
6014
6015 // Decide whether to loop in column or row major order.
6016 // x = vectorized dimension
6017 // y = non-vectorized dimension
6018 int nx = globalCM ? strategy.unroll[LoopM] : strategy.unroll[LoopN];
6019 int ny = globalCM ? strategy.unroll[LoopN] : strategy.unroll[LoopM];
6020 int nx1 = (mixedMode || state.broadcast) ? nx : fmaSIMD;
6021
6022 // Prepare for chaining FMAs through accumulator registers.
6023 int necAcc = nec * (csplit ? 2 : 1);
6024 int accCount = AccumulatorRegister::count(hw, strategy.GRFs, Tc.ngen());
6025 int accPerFMA = div_up(std::min(nx, fmaSIMD), necAcc);
6026 int minAccPerFMA = Tc.isFP() ? 1 : 2;
6027 accPerFMA = std::max(accPerFMA, minAccPerFMA);
6028 int independentAccs = div_up(accCount, accPerFMA);
6029
6030 int nx1i = 1, ny1 = 1;
6031 if (kChain > 1) {
6032 if (independentAccs < icompCount) hw_unsupported();
6033 int indepAccComp = div_up(independentAccs, icompCount);
6034
6035 nx1i = std::min(nx1, indepAccComp * fmaSIMD);
6036 ny1 = div_up(indepAccComp, div_up(nx1i, fmaSIMD));
6037 }
6038
6039 GRFRange broadcastRegs = state.broadcast_regs;
6040 Subregister lastBcastBase;
6041
    // Last A/B blocks found, used to detect repeated operands for {Atomic} chaining.
6043 const RegisterBlock *A_blockLast = nullptr, *B_blockLast = nullptr;
6044
6045 for (int x0 = 0; x0 < nx; x0 += nx1) {
6046 for (int ovcomp = 0; ovcomp < ovcompCount; ovcomp++) {
6047 for (int ocomp = 0; ocomp < ocompCount; ocomp++) {
6048 for (int y0 = 0; y0 < ny; y0 += ny1) {
6049 for (int x1 = 0; x1 < nx1 && (x0 + x1) < nx;) {
6050 int x1New = x1;
6051 for (int ivcomp = 0; ivcomp < ivcompCount; ivcomp++) {
6052 for (int hh = 0; hh < opCount; hh += minOPCount) {
6053 accNum = 0;
6054 for (int y1 = 0; y1 < ny1 && y0 + y1 < ny;
6055 y1++) {
6056 for (int x1i = x1; (x1i < x1 + nx1i)
6057 && (x0 + x1i < nx);) {
6058 auto x = x0 + x1i;
6059 auto y = y0 + y1;
6060 auto i = globalCM ? x : y;
6061 auto j = globalCM ? y : x;
6062 auto hha = ha + hh;
6063 auto hhb = hb + hh;
6064
6065 int fmaCount = 1;
6066
6067 for (int icomp = 0; icomp < icompCount;
6068 icomp++) {
6069 // Find the appropriate A and B registers.
6070 int na, nb;
6071 int vcomp = ivcomp + ovcomp;
6072 int ncomp = (vcomp ^ ocomp) + icomp;
6073 int compA
6074 = globalCM ? vcomp : ncomp;
6075 int compB
6076 = globalCM ? ncomp : vcomp;
6077
6078 const RegisterBlock *A_block,
6079 *B_block;
6080 Subregister A = findBlockReg(Ta,
6081 A_layout, i, hha, A_regs,
6082 na, A_block, compA);
6083 Subregister B = findBlockReg(Tb,
6084 B_layout, hhb, j, B_regs,
6085 nb, B_block, compB);
6086
6087 // Check for expected crosspack.
6088 if (globalCM ? (aCP
6089 && A_block->crosspack
6090 != aCP)
6091 : (bCP
6092 && B_block->crosspack
6093 != bCP))
6094 stub();
6095
6096 // Check if we should specify {Atomic}.
6097 bool atomic = (strategy.atomicFMA
6098 && (A_block == A_blockLast)
6099 && (B_block
6100 == B_blockLast));
6101 A_blockLast = A_block;
6102 B_blockLast = B_block;
6103
6104 // Find the appropriate C register.
6105 int C_buffer = csplit
6106 ? 0
6107 : (icomp + ocomp);
6108 int compC = csplit ? ocomp : 0;
6109 int nc;
6110 const RegisterBlock *C_block;
6111 Subregister C = findBlockReg(Tc,
6112 state.C_layout, i, j,
6113 state.C_regs[C_buffer], nc,
6114 C_block, compC);
6115 if (C_block->crosspack > 1) stub();
6116
6117 // Swap out C register for an accumulator, if necessary.
6118 if (strategy.cAccumulators) {
6119 auto C_roff = C.getBase()
6120 - state.C_regs[0]
6121 .ranges[0]
6122 .getBase();
6123 if (C_roff < state.C_accCount)
6124 C = AccumulatorRegister(
6125 C_roff)
6126 .sub(C.getOffset(),
6127 Tc.ngen());
6128 }
6129
6130 InstructionModifier mod;
6131
6132 // Use requested execution size if possible, but limited to available elements.
6133 // Decide on kernel type based on register block layouts.
6134 bool canColMajor
6135 = (A_block->colMajor
6136 && globalCM);
6137 bool canRowMajor
6138 = (!B_block->colMajor
6139 && !globalCM);
6140 bool colMajor = globalCM;
6141
6142 if (!canColMajor && !canRowMajor)
6143 fmaCount = 1;
6144 else if (canColMajor)
6145 fmaCount = rounddown_pow2(
6146 std::min({fmaSIMD, na,
6147 nc}));
6148 else
6149 fmaCount = rounddown_pow2(
6150 std::min({fmaSIMD, nb,
6151 nc}));
6152
6153 int simdSize = fmaCount;
6154
6155 // Crosspacked kernels: ensure broadcast matrix is contiguous in k.
6156 if (minOPCount > 1) {
6157 bool nativeDir = (globalCM
6158 ? B_block->colMajor
6159 : !A_block->colMajor);
6160 auto bcastCrosspack
6161 = (globalCM ? B_block
6162 : A_block)
6163 ->crosspack;
6164 if (nativeDir) {
6165 if ((globalCM ? nb : na)
6166 < minOPCount)
6167 stub();
6168 if (bcastCrosspack > 1)
6169 stub();
6170 } else {
6171 if (bcastCrosspack
6172 % minOPCount)
6173 stub();
6174 }
6175 }
6176
6177 // Add Atomic if appropriate.
6178 if (atomic) mod |= Atomic;
6179
6180 // Handle broadcast duties.
6181 Subregister bcastSrcSub
6182 = colMajor ? B : A;
6183 RegData bcastSrc = bcastSrcSub;
6184
6185 if (state.broadcast) {
6186
6187 // Broadcast if necessary: pair of doubles (doubleWA) or single elements.
6188 int nbcast = strategy.doubleWA
6189 ? 2
6190 : 1;
6191 int hs = strategy.doubleWA
6192 ? 0
6193 : nbcast;
6194
6195 auto bcastType
6196 = bcastSrc.getType();
6197 Subregister bcastBase
6198 = bcastSrcSub;
6199 bcastBase.setOffset(
6200 bcastBase.getOffset()
6201 & ~(nbcast - 1));
6202
6203 if (bcastBase
6204 != lastBcastBase) {
6205 auto bcastRegion = bcastBase(
6206 0, nbcast,
6207 (nbcast > 1) ? 1
6208 : 0);
6209 if (bfloat16WA) {
6210 // Upconvert to f32 during broadcast.
6211 bcastRegion.setType(
6212 DataType::uw);
6213 shl(simdSize,
6214 broadcastRegs[0]
6215 .ud(),
6216 bcastRegion,
6217 16);
6218 } else {
6219 moveToIntPipe(simdSize,
6220 bcastRegion);
6221 mov(simdSize * bcastSrc.getBytes()
6222 / bcastRegion
6223 .getBytes(),
6224 broadcastRegs[0].retype(
6225 bcastRegion
6226 .getType()),
6227 bcastRegion);
6228 }
6229 }
6230 if (bfloat16WA)
6231 bcastType = DataType::f;
6232 bcastSrc = broadcastRegs[0].sub(
6233 bcastSrc.getOffset()
6234 & (nbcast - 1),
6235 bcastType)(hs);
6236 lastBcastBase = bcastBase;
6237 }
6238
6239 bool ivfirst
6240 = mixedRC || (ivcomp == 0);
6241 bool ivlast = mixedRC
6242 || (ivcomp
6243 == ivcompCount - 1);
6244
6245 // Finally, perform the long-awaited FMA.
6246 outputFMA(simdSize | mod, A, B, C,
6247 bcastSrc, colMajor, hh,
6248 ivfirst, ivlast);
6249 Clast = C;
6250
6251 if (kChain > 1
6252 && accNum >= accCount)
6253 stub();
6254 accNum += std::max(minAccPerFMA,
6255 div_up(fmaCount, necAcc));
6256 } /* icomp */
6257
6258 x1i += fmaCount;
6259 x1New = x1i;
6260 } /* x1i */
6261 } /* y1 */
6262 } /* hh */
6263 } /* ivcomp */
6264 x1 = x1New;
6265 } /* x1 */
6266 } /* y0 */
6267 } /* ocomp */
6268 } /* ovcomp */
6269 } /* x0 */
6270}
6271
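// Emulated integer outer product for pre-Gen12 hardware (no dp4a): multiply A and B into
// word-sized temporaries, then fold the temporaries into dword C with adds (see the
// instruction sketches at the end of this routine).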
6272template <HW hw>
6273void gemm_kernel_generator_t<hw>::outerProductGen9IGEMM(int ha, int hb,
6274 const vector<RegisterBlock> &A_layout,
6275 const vector<RegisterBlock> &B_layout, const GRFMultirange &A_regs,
6276 const GRFMultirange &B_regs, const GEMMProblem &problem,
6277 const GEMMStrategy &strategy, GEMMState &state) {
6278 auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc;
6279 DataType tempType
6280 = (Ta.isSigned() || Tb.isSigned()) ? DataType::w : DataType::uw;
6281
6282 struct AddItem {
6283 int simd;
6284 RegData dest, src0, src1;
6285 };
6286 std::vector<AddItem> adds;
6287
6288 auto replayAdds = [&]() {
6289 for (auto &item : adds)
6290 add(item.simd, item.dest, item.src0, item.src1);
6291 adds.clear();
6292 };
6293
6294 bool globalCM = isLayoutColMajor(state.C_layout);
6295
6296 // Decide whether to loop in column or row major order.
6297 int nx = globalCM ? strategy.unroll[LoopM] : strategy.unroll[LoopN];
6298 int ny = globalCM ? strategy.unroll[LoopN] : strategy.unroll[LoopM];
6299
6300 int tidx = 0;
6301 for (int y = 0; y < ny; y++) {
6302 for (int x = 0; x < nx;) {
6303 auto i = globalCM ? x : y;
6304 auto j = globalCM ? y : x;
6305
6306 int fmaCount;
6307
6308 // Find the appropriate A and B registers.
6309 int na, nb;
6310 const RegisterBlock *A_block, *B_block;
6311 Subregister A
6312 = findBlockReg(Ta, A_layout, i, ha, A_regs, na, A_block);
6313 Subregister B
6314 = findBlockReg(Tb, B_layout, hb, j, B_regs, nb, B_block);
6315
            // Find the appropriate C register. TODO: remainders.
6317 int nc;
6318 const RegisterBlock *C_block;
6319 Subregister C = findBlockReg(
6320 Tc, state.C_layout, i, j, state.C_regs[0], nc, C_block);
6321
6322 // No C crosspack support.
6323 auto cpA = A_block->crosspack, cpB = B_block->crosspack;
6324 if (C_block->crosspack > 1) stub();
6325
6326 // Swap out C register for an accumulator, if necessary.
6327 auto C_roff = C.getBase() - state.C_regs[0].ranges[0].getBase();
6328 if (C_roff < state.C_accCount)
6329 C = AccumulatorRegister(C_roff).sub(C.getOffset(), Tc.ngen());
6330
6331 // Use requested execution size if possible, but limited to available elements.
6332 // Decide the kernel type based on register block layouts.
6333 bool canColMajor = (A_block->colMajor && C_block->colMajor);
6334 bool canRowMajor = (!B_block->colMajor && !C_block->colMajor);
6335 bool colMajor;
6336
6337 if (!canColMajor && !canRowMajor) {
6338 colMajor = true;
6339 fmaCount = 1;
6340 } else if (canColMajor) {
6341 colMajor = true;
6342 fmaCount = na;
6343 } else {
6344 colMajor = false;
6345 fmaCount = nb;
6346 }
6347 fmaCount = rounddown_pow2(std::min(
6348 {strategy.fmaSIMD, nc, elementsPerGRF<int16_t>(hw)}));
6349
6350 auto temp = state.tempMul_regs[tidx++];
6351
6352 if (C.isARF()) {
6353 if (colMajor)
6354 mac(fmaCount, C(1), A(cpA), B(0));
6355 else
6356 mac(fmaCount, C(1), A(0), B(cpB));
6357 } else {
6358 if (colMajor)
6359 mul(fmaCount, temp[0].sub(0, tempType)(2), A(cpA), B(0));
6360 else
6361 mul(fmaCount, temp[0].sub(0, tempType)(2), A(0), B(cpB));
6362
6363 adds.push_back(
6364 {fmaCount, C(1), C(1), temp[0].sub(0, tempType)(2)});
6365 }
6366
6367 if (tidx >= int(state.tempMul_regs.size())) {
6368 tidx = 0;
6369 replayAdds();
6370 }
6371
6372 x += fmaCount;
6373 }
6374 }
6375
6376 replayAdds();
6377
6378 // A4B4 outer product (4 temporary GRFs per 2 C registers) - 2/3 SP
6379 //
6380 // mul (32) temp0.0:w<1> A.0:b<32;16,2> B.0:b<32;16,2> - EM
6381 // mul (32) temp2.0:w<1> A.1:b<32;16,2> B.1:b<32;16,2> - FPU
6382 // add (16) C.0:d<1> C.0:d<8;8,1> temp0.0:w<16;8,2> - EM
6383 // add (16) C.0:d<1> C.0:d<8;8,1> temp0.1:w<16;8,2> - FPU
6384 // add (16) C.0:d<1> C.0:d<8;8,1> temp2.0:w<16;8,2> - EM
6385 // add (16) C.0:d<1> C.0:d<8;8,1> temp2.1:w<16;8,2> - FPU
6386
6387 // Faster A4B4 outer product a la non-VNNI (4 temporary GRFs per 2 C registers) - 4/5 SP
6388 //
6389 // mul (32) temp0.0:w<1> A.0:b<32;16,2> B.0:b<32;16,2> - EM
6390 // mul (32) temp2.0:w<1> A.1:b<32;16,2> B.1:b<32;16,2> - FPU
6391 // add (32) (sat) temp0.0:w<1> temp0.0:w<1> temp2.0:w<1> - EM/FPU
6392 // add (16) C.0:d<1> C.0:d<8;8,1> temp0.0:w<16;8,2> - EM
6393 // add (16) C.0:d<1> C.0:d<8;8,1> temp0.1:w<16;8,2> - FPU
6394}
6395
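// Distance in elements between two register operands of the same type.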
6396static int elementDiff(HW hw, const RegData &r1, const RegData &r2) {
6397 return elementsPerGRF(hw, r1.getType()) * (r1.getBase() - r2.getBase())
6398 + (r1.getOffset() - r2.getOffset());
6399}
6400
6401// Accumulate multiple outer products using the systolic array.
6402template <HW hw>
6403void gemm_kernel_generator_t<hw>::outerProductSystolic(int h, int ha, int hb,
6404 const vector<RegisterBlock> &A_layout,
6405 const vector<RegisterBlock> &B_layout, const GRFMultirange &A_regs,
6406 const GRFMultirange &B_regs, const GEMMProblem &problem,
6407 const GEMMStrategy &strategy, GEMMState &state) {
6408 auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc;
6409 bool globalCM = isLayoutColMajor(state.C_layout);
6410 auto params = systolicParams(hw, problem, strategy);
6411 auto ksys = params.ksys;
6412 auto osys = params.osys;
6413 auto sdepth = params.sdepth;
6414 auto rcountMax = params.rcountMax;
6415 int dpaswTile = strategy.dpasw
6416 ? (globalCM ? strategy.B.tileC : strategy.A.tileR)
6417 : 0;
6418 bool rsFix = (strategy.readSuppressionWA
6419 && hasMasking(globalCM ? A_layout : B_layout));
6420 bool canAtomicNon8x8
6421 = (hw >= HW::XeHPC) && (getStepping() >= SteppingPVCXTB0);
6422
6423 bool sum = globalCM ? state.systolicSumA : state.systolicSumB;
6424
6425 RegisterBlock sumBlock;
6426 sumBlock.colMajor = globalCM;
6427 sumBlock.crosspack = 1;
6428
6429 // dpas processes ksys outer products at once.
6430 ha = align_down(ha, ksys);
6431 hb = align_down(hb, ksys);
6432
6433 // Decide whether to loop in column or row major order, to facilitate macro sequences.
6434 // x is the non-accumulating dimension of dpas src2 (N matrix)
6435 // y is the non-accumulating dimension of dpas src1 (V matrix)
6436 int nx = strategy.unroll[globalCM ? LoopN : LoopM];
6437 int ny = strategy.unroll[globalCM ? LoopM : LoopN];
6438
6439 int yinc = osys;
6440
6441 const int compA = 0, compB = 0, compC = 0;
6442 const int incompCount = 1, oncompCount = 1;
6443
6444 for (int y = 0; y < ny; y += yinc) {
6445 Subregister A0, B0, C0;
6446 int rcount = 0, x0 = 0;
6447
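        // Issue the dpas(w) instructions for the current chain: split the accumulated repeat
        // count into power-of-2 pieces, and mark all but the final instruction {Atomic} where
        // the hardware allows it so they stay back-to-back in the systolic pipeline.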
6448 auto issueDPAS = [&](bool last) {
6449 while (rcount > 0) {
6450 InstructionModifier mod = osys;
6451
6452 bool useDPASW = strategy.dpasw && x0 < nx;
6453 auto rc2 = rounddown_pow2(rcount);
6454 auto rc = rc2 * (useDPASW ? 2 : 1);
6455 auto &V0 = globalCM ? A0 : B0;
6456 auto &N0 = globalCM ? B0 : A0;
6457
6458 if (rsFix) {
6459 GRF v0GRF {V0.getBase()};
6460 mov<uint32_t>(8, v0GRF, v0GRF);
6461 rsFix = false;
6462 }
6463
6464 if (strategy.atomicFMA)
6465 if (!(last && (rc2 == rcount)))
6466 if (rc == 8 || canAtomicNon8x8) mod |= Atomic;
6467
6468 useDPASW ? dpasw(mod, sdepth, rc, C0, C0, V0, N0)
6469 : dpas(mod, sdepth, rc, C0, C0, V0, N0);
6470
6471 rcount -= rc2;
6472 x0 += rc;
6473 N0.setBase(N0.getBase() + rc2);
6474 C0.setBase(C0.getBase() + rc2);
6475 }
6476 };
6477
6478 for (int oncomp = 0; oncomp < oncompCount; oncomp++) {
6479 for (int x = 0; x < nx + sum; x++) {
6480 for (int incomp = 0; incomp < incompCount; incomp++) {
6481 // Find the appropriate A and B registers.
6482 int na, nb, nc;
6483 const RegisterBlock *A_block, *B_block, *C_block;
6484 Subregister A, B, C;
6485
6486 const int cxCompA = -1, cxCompB = -1, cxCompC = -1,
6487 cBuffer = 0;
6488
6489 if (x < nx) {
6490 if (strategy.dpasw
6491 && (x % (2 * dpaswTile) >= dpaswTile))
6492 continue;
6493
6494 int i = globalCM ? y : x;
6495 int j = globalCM ? x : y;
6496
6497 A = findBlockReg(Ta, A_layout, i, ha, A_regs, na,
6498 A_block, cxCompA, compA);
6499 B = findBlockReg(Tb, B_layout, hb, j, B_regs, nb,
6500 B_block, cxCompB, compB);
6501 C = findBlockReg(Tc, state.C_layout, i, j,
6502 state.C_regs[cBuffer], nc, C_block, cxCompC,
6503 compC);
6504 } else if (state.systolicSumA) {
6505 A = findBlockReg(
6506 Ta, A_layout, y, ha, A_regs, na, A_block);
6507 B = state.sysSumAll1s[0];
6508 nb = elementsPerGRF(hw, Tb);
6509 B_block = &sumBlock;
6510 C = findBlockReg(Tc, state.As_layout, y, 0,
6511 state.As_regs, nc, C_block);
6512 } else {
6513 A = state.sysSumAll1s[0];
6514 na = elementsPerGRF(hw, Ta);
6515 A_block = &sumBlock;
6516 B = findBlockReg(
6517 Tb, B_layout, hb, y, B_regs, nb, B_block);
6518 C = findBlockReg(Tc, state.Bs_layout, 0, y,
6519 state.Bs_regs, nc, C_block);
6520 }
6521
6522 int nv = globalCM ? na : nb;
6523 int nn = globalCM ? nb : na;
6524
6525 // Verify DPAS requirements.
6526 if (globalCM) {
6527 if (A_block->crosspack * Ta.real().size()
6528 != std::max(4, Ta.real().size()))
6529 stub();
6530 if (B_block->crosspack > 1) stub();
6531 } else {
6532 if (B_block->crosspack * Tb.real().size()
6533 != std::max(4, Tb.real().size()))
6534 stub();
6535 if (A_block->crosspack > 1) stub();
6536 }
6537 if (A_block->colMajor != globalCM
6538 || B_block->colMajor != globalCM)
6539 stub();
6540 if (C_block->crosspack > 1) stub();
6541
6542 if (nv != osys) stub();
6543 if (nn < ksys) stub();
6544
6545 // Check if current DPAS can be fused with the previous one.
6546 bool chain = false;
6547 if (A0.isValid()) {
6548 chain = globalCM
6549 ? (elementDiff(hw, B, B0) == (x - x0) * ksys)
6550 : (elementDiff(hw, A, A0) == (x - x0) * ksys);
6551 chain = chain
6552 && (elementDiff(hw, C, C0) == (x - x0) * osys);
6553 chain = chain && (rcount < rcountMax);
6554 if (strategy.dpasw)
6555 chain = chain && x < nx
6556 && (x % (2 * dpaswTile) > 0);
6557 }
6558
6559 if (chain)
6560 rcount++;
6561 else {
6562 if (strategy.dpasw && x < nx && rcount > 0
6563 && rcount != dpaswTile)
6564 stub();
6565 if (A0.isValid()) issueDPAS(false);
6566 A0 = A;
6567 B0 = B;
6568 C0 = C;
6569 rcount = 1;
6570 A0.setType(Ta.ngen());
6571 B0.setType(Tb.ngen());
6572 C0.setType(Tc.ngen());
6573 x0 = x;
6574 }
6575 } /* incomp loop */
6576 } /* x loop */
6577 } /* oncomp loop */
6578
6579 bool finishChain = !strategy.extendedAtomicFMA || (y + osys >= ny);
6580 issueDPAS(finishChain);
6581 } /* y loop */
6582}
6583
// Decide whether to use the legacy post-op injector inside the C update.
// It is needed when C cannot be converted to f32 in place, but it does not support binary post-ops.
6586static inline bool useEltwiseInjector(const GEMMProblem &problem) {
6587 return problem.hasPostOp() && (problem.Tc.size() < 4);
6588}
6589
// Perform the alpha/beta C update operation on C_acc, given the original C data in C_load.
// All inputs and outputs are assumed to be of type problem.Ts.
6592template <HW hw>
6593void gemm_kernel_generator_t<hw>::updateC(const GRFMultirange &C_acc,
6594 const GRFMultirange &C_accSwap, const GRFMultirange &C_load,
6595 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
6596 auto &alphar = problem.alpha_real;
6597 auto &betar = problem.beta_real;
6598 bool alpha1 = (alphar == 1);
6599 bool alphaM1 = (alphar == -1);
6600 bool beta1 = (betar == 1);
6601 bool beta0 = (betar == 0);
6602 bool betaM1 = (betar == -1);
6603
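// Convenience macros: apply the body elementwise across the C tile, pairing corresponding GRFs
// of the loaded C data (C_load) and the accumulated C data (C_acc); FOR_EACH_C_CX additionally
// pairs the swapped accumulator range C_accSwap.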
6604#define FOR_EACH_C(f) \
6605 do { \
6606 map(hw, state.Tacc.real(), C_load, C_acc, strategy, \
6607 [&](int esize, GRF loaded, GRF acc) { f; }); \
6608 } while (false)
6609
6610#define FOR_EACH_C_CX(f) \
6611 do { \
6612 map(hw, state.Tacc.real(), C_load, C_acc, C_accSwap, strategy, \
6613 [&](int esize, GRF loaded, GRF acc, GRF accswap) { f; }); \
6614 } while (false)
6615
6616 if (!beta0) {
6617 if (alpha1 || alphaM1) {
6618 if (beta1)
6619 FOR_EACH_C(add(esize, acc, loaded, alpha1 ? acc : -acc));
6620 else if (betaM1)
6621 FOR_EACH_C(add(esize, acc, -loaded, alpha1 ? acc : -acc));
6622 else if (betar.fixed())
6623 stub(); // beta should be put in a register first.
6624 else {
6625 if (!strategy.doubleWA)
6626 FOR_EACH_C(mad(esize, acc, alpha1 ? acc : -acc, loaded,
6627 betar.getRegAvoiding(hw, loaded)));
6628 else {
6629 FOR_EACH_C(mul(esize, loaded, loaded,
6630 betar.getRegAvoiding(hw, loaded)));
6631 FOR_EACH_C(add(esize, acc, loaded, alpha1 ? acc : -acc));
6632 }
6633 }
6634 } else {
6635 bool neg = false;
6636 if (!beta1) {
6637 if (betaM1)
6638 neg = true;
6639 else if (!betar.fixed())
6640 FOR_EACH_C(mul(esize, loaded, loaded,
6641 betar.getRegAvoiding(hw, acc)));
6642 else
6643 stub();
6644 }
6645 if (alphar.fixed())
6646 stub(); // alpha should be put in a register first.
6647 else {
6648 if (!strategy.doubleWA)
6649 FOR_EACH_C(mad(esize, acc, neg ? -loaded : loaded, acc,
6650 alphar.getRegAvoiding(hw, acc)));
6651 else {
6652 FOR_EACH_C(mul(
6653 esize, acc, acc, alphar.getRegAvoiding(hw, acc)));
6654 FOR_EACH_C(add(esize, acc, neg ? -loaded : loaded, acc));
6655 }
6656 }
6657 }
6658 } else if (alphaM1)
6659 FOR_EACH_C(mov(esize, acc, -acc));
6660 else if (alpha1)
6661 /* no op */;
6662 else if (alphar.fixed())
6663 stub(); // alpha should be put in a register first.
6664 else {
6665 FOR_EACH_C(mul(esize, acc, acc, alphar.getRegAvoiding(hw, acc)));
6666 }
6667
6668 if (useEltwiseInjector(problem)) {
6669 Label labelPostOpDone;
6670 bool allocFlag = state.flagAP.isInvalid();
6671 auto flagNonfinal = allocFlag ? state.raVFlag.alloc() : state.flagAP;
6672 and_(1 | nz | flagNonfinal, null.ud(), state.inputs.flags,
6673 FlagNonfinalKBlock);
6674 jmpi(1 | flagNonfinal, labelPostOpDone);
6675 if (allocFlag) state.raVFlag.safeRelease(flagNonfinal);
6676 if (state.Tacc != Type::f32 || !postOpInjector) stub();
6677 for (const auto &range : C_acc.ranges)
6678 postOpInjector->compute(range);
6679 mark(labelPostOpDone);
6680 }
6681
6682#undef FOR_EACH_C
6683#undef FOR_EACH_C_CX
6684}
6685
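// Rebuild a register layout for type Tdst: each block of layoutRef is intersected with the
// blocks of layoutSrc, and the resulting subblocks are appended to layoutDst. The range
// blockMap[i] .. blockMap[i+1] delimits the destination blocks generated from layoutRef[i].
// Returns false if a required subblock cannot be formed.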
6686template <HW hw>
6687bool gemm_kernel_generator_t<hw>::reblockLayout(Type Tdst,
6688 vector<int32_t> &blockMap, vector<RegisterBlock> &layoutDst,
6689 const vector<RegisterBlock> &layoutRef,
6690 const vector<RegisterBlock> &layoutSrc, const MatrixAddressing &atype,
6691 const MatrixAddressingStrategy &astrategy) {
6692 auto nblockRef = layoutRef.size();
6693 layoutDst.clear();
6694 layoutDst.reserve(nblockRef);
6695 blockMap.clear();
6696 blockMap.reserve(nblockRef + 1);
6697 blockMap.push_back(0);
6698 for (auto &blockRef : layoutRef) {
6699 RegisterBlock blockDst, blockMid;
6700 for (auto &blockSrc : layoutSrc) {
6701 int rr1 = blockRef.offsetR - blockSrc.offsetR,
6702 rr2 = rr1 + blockRef.nr;
6703 int cc1 = blockRef.offsetC - blockSrc.offsetC,
6704 cc2 = cc1 + blockRef.nc;
6705 if (rr1 >= blockSrc.nr || rr2 <= 0) continue;
6706 if (cc1 >= blockSrc.nc || cc2 <= 0) continue;
6707 rr1 = std::max(rr1, 0);
6708 cc1 = std::max(cc1, 0);
6709 rr2 = std::min(rr2, int(blockSrc.nr));
6710 cc2 = std::min(cc2, int(blockSrc.nc));
6711 if (!getSubblock(Tdst, blockMid, blockSrc, false, rr1, rr2, rr1,
6712 rr2, true, atype, astrategy))
6713 return false;
6714 if (!getSubblock(Tdst, blockDst, blockMid, true, cc1, cc2, cc1, cc2,
6715 true, atype, astrategy))
6716 return false;
6717 layoutDst.push_back(blockDst);
6718 }
6719 blockMap.push_back(int32_t(layoutDst.size()));
6720 }
6721 return true;
6722}
6723
6724// Update an entire C layout.
6725template <HW hw>
6726void gemm_kernel_generator_t<hw>::updateCLayout(
6727 const vector<RegisterBlock> &layoutExt, const GRFRange (&C_addr0)[2],
6728 const RegisterBlock &C_block0, COperation op, GEMMProblem &problem,
6729 GEMMStrategy &strategy, GEMMState &state) {
6730#define FOR_EACH_C for (int q = 0; q < C_count; q++)
6731 auto Tc = problem.Tc, Tc_ext = problem.Tc_ext, Ts = problem.Ts;
6732 bool loadOnly = (op == COperation::Load);
6733 bool beta0 = problem.beta0();
6734 bool needLoad = (!beta0 && !loadOnly);
6735 bool copyC = state.copyC;
6736 int C_count = (op == COperation::UpdateStore) ? state.C_count : 1;
6737
6738 auto nblocks = int(layoutExt.size());
6739 bool haveDescs = layoutExt[0].descAssigned;
6740
6741 vector<GRFRange>(&C_addrs)[2] = state.C_addrs;
6742 GRFMultirange C_extRange, C_copyRange;
6743 GRFMultirange &C_accRange = state.C_regs[0];
6744 auto &C_extRegs = C_extRange.ranges;
6745 auto &C_copyRegs = C_copyRange.ranges;
6746 vector<GRFRange> C_convertRegs;
6747
6748 for (int q = 0; q < C_count; q++)
        C_addrs[q].clear();
6750
6751 // Map layout to blocks in internal C layout as needed.
6752 vector<RegisterBlock> layout;
6753 vector<int> blockMap;
6754 if (copyC) {
6755 if (!reblockLayout(Tc, blockMap, layout, layoutExt, state.C_layout,
6756 problem.C, strategy.C))
6757 stub();
6758 } else {
6759 layout = layoutExt;
6760 blockMap.resize(nblocks + 1);
6761 for (int i = 0; i <= nblocks; i++)
6762 blockMap[i] = i;
6763 }
6764
6765 // Prepare for late C conversion.
6766 bool lateCConvert = (!loadOnly && !strategy.C.atomic
6767 && problem.needsTsConvert() && state.Tacc != Ts);
6768 bool copyCLoad = needLoad && (copyC || lateCConvert);
6769 if (lateCConvert && Tc.isComplex()) stub();
6770
6771 // Load as much of C as is possible at a time, given register space.
6772 for (int lstart = 0; lstart < nblocks;) {
6773 int lend;
6774
6775 // Allocate address and data registers for C updating. If allocator chokes,
6776 // proceed with the registers we were able to allocate.
6777 //
6778 // At the same time, build up three layouts for this chunk of C:
6779 // sublayoutExt: C data to be loaded/stored
6780 // sublayoutCopy: copied C data
6781 // sublayoutAcc: C data in accumulators
6782 bool allocOK = true;
6783 auto tryAlloc = [&](int regs, Bundle hint = Bundle()) {
6784 auto range = state.ra.try_alloc_range(regs, hint);
6785 allocOK &= range.isValid();
6786 return range;
6787 };
6788
6789 vector<RegisterBlock> sublayoutExt, sublayoutCopy, sublayoutAcc;
6790 size_t sublayoutCopySize = 0;
6791 int bytes = 0, bytesConvert = 0;
6792 int tokens = 0, maxTokens = 256;
6793 if (needLoad && hw >= HW::Gen12LP) maxTokens = tokenCount(hw);
6794
6795 for (lend = lstart; (lend < nblocks) && (tokens < maxTokens);
6796 lend++, tokens++) {
6797 auto li0 = blockMap[lend], li1 = blockMap[lend + 1];
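            // 'expand' is the register expansion factor when converting the
            // accumulator type up to Ts for late conversion (e.g. 2 when Tacc
            // is a 2-byte type and Ts is f32).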
6798 int expand
6799 = lateCConvert ? div_up(Ts.size(), state.Tacc.size()) : 1;
6800
6801 if (copyCLoad)
6802 for (int li = li0; li < li1; li++) {
6803 auto block = layout[li];
6804 block.compact(state.Tacc);
6805 block.offsetBytes = bytesConvert;
6806 bytesConvert += block.nregs() * expand * GRF::bytes(hw);
6807 sublayoutCopy.push_back(block);
6808 }
6809
6810 auto blockExt = layoutExt[lend];
6811 auto naddr = addrGRFCount(problem.C, strategy.C, blockExt);
6812 FOR_EACH_C C_addrs[q].push_back(
6813 (blockExt.offsetR == 0 && blockExt.offsetC == 0)
6814 ? C_addr0[q]
6815 : tryAlloc(naddr));
6816 if (needLoad || copyC)
6817 C_extRegs.push_back(tryAlloc(
6818 blockExt.nregs(), getHint(HintType::CLoad, strategy)));
6819 if (copyCLoad)
6820 for (int li = li0; li < li1; li++)
6821 C_copyRegs.push_back(tryAlloc(
6822 sublayoutCopy[li - li0 + sublayoutCopySize].nregs()
6823 * expand,
6824 getHint(HintType::CLoad, strategy)));
6825 if (lateCConvert)
6826 for (int li = li0; li < li1; li++)
6827 C_convertRegs.push_back(
6828 tryAlloc(layout[li].nregs() * expand));
6829 if (!allocOK) break;
6830
6831 blockExt.offsetBytes = bytes;
6832 bytes += blockExt.nregs() * GRF::bytes(hw);
6833 sublayoutExt.push_back(blockExt);
6834
6835 sublayoutCopySize = sublayoutCopy.size();
6836 }
6837
6838 sublayoutCopy.resize(sublayoutCopySize);
6839
6840 int listart = blockMap[lstart];
6841 int liend = blockMap[lend];
6842
6843 sublayoutAcc.reserve(liend - listart);
6844 for (int l = listart; l < liend; l++)
6845 sublayoutAcc.push_back(layout[l]);
6846
6847 // Set up C addresses relative to prior blocks.
6848 for (int l = lstart; l < lend; l++) {
6849 auto &block = sublayoutExt[l - lstart];
6850 int bbase = findBaseBlock(
6851 block, sublayoutExt, 0, l - lstart, problem.C, strategy.C);
6852 FOR_EACH_C {
6853 auto &blockSrc = (bbase >= 0) ? sublayoutExt[bbase] : C_block0;
6854 auto &addrSrc = (bbase >= 0) ? C_addrs[q][bbase] : C_addr0[q];
6855 setupAddrRel(Tc_ext, C_addrs[q][l - lstart], addrSrc, block,
6856 blockSrc, state.C_layout, state.inputs.ldc[q],
6857 problem.C, strategy.C, strategy, state,
6858 state.ldcMultiples[q]);
6859 }
6860 }
6861
6862 if (strategy.C.atomic) {
6863 // Atomic update.
6864 // Alpha scaling is done earlier; beta scaling isn't supported.
6865 if (!problem.alpha1() || !problem.beta1()) stub();
6866 if (copyC)
6867 if (!copyRegisters(state.Tacc, Tc_ext, sublayoutAcc,
6868 sublayoutExt, C_accRange, C_extRange, 0, 0, false,
6869 strategy, state))
6870 stub();
6871
6872 auto &sublayoutSrc = copyC ? sublayoutExt : sublayoutAcc;
6873 auto &C_srcRange = copyC ? C_extRange : C_accRange;
6874 FOR_EACH_C atomicAddMatrix(Tc_ext, C_srcRange, sublayoutSrc,
6875 problem.C, strategy.C, C_addrs[q], problem, strategy,
6876 state);
6877 } else {
6878 // Data types before and after scaling phase.
6879 auto Tacc_final = Tc;
6880 if (op == COperation::Update
6881 || (op == COperation::UpdateStore && copyC))
6882 Tacc_final = state.Tacc;
6883
6884 // Regular update.
6885 auto Tload = Tc_ext;
6886 if (!beta0 || loadOnly) {
6887 // Set up a0.0 descriptor for loads if needed.
6888 if (lstart > 0 && haveDescs) mov(1, a0.ud(0), a0.ud(3));
6889
6890 // Load C data.
6891 auto &sublayoutLoad
6892 = (loadOnly && !copyC) ? sublayoutAcc : sublayoutExt;
6893 auto &C_loadRange
6894 = (loadOnly && !copyC) ? C_accRange : C_extRange;
6895 loadMatrix(C_loadRange, sublayoutLoad, problem.C, strategy.C,
6896 C_addrs[0], strategy, state);
6897
6898 // Set up a0.0 descriptor for stores (and save load descriptors) if needed.
6899 if (haveDescs && !loadOnly) {
6900 if (lend < nblocks) mov(1, a0.ud(3), a0.ud(0));
6901 mov(1, a0.ud(0), a0.ud(2));
6902 }
6903
6904 // Copy loaded data as needed.
6905 if (copyCLoad) {
6906 auto &sublayoutDst
6907 = loadOnly ? sublayoutAcc : sublayoutCopy;
6908 auto &C_dstRange = loadOnly ? C_accRange : C_copyRange;
6909 Tload = lateCConvert ? Ts : state.Tacc;
6910 if (!copyRegisters(Tc_ext, Tload, sublayoutExt,
6911 sublayoutDst, C_extRange, C_dstRange, 0, 0,
6912 false, strategy, state))
6913 stub();
6914 }
6915 }
6916
6917 // Late C conversion.
6918 auto originalTacc = state.Tacc;
6919 if (lateCConvert) {
6920 for (int li = listart; li < liend; li++) {
6921 auto C_acc = state.C_regs[0].subrange(
6922 hw, state.Tacc, layout[li]);
6923 copyRegisterBlock(state.Tacc, Ts, layout[li], layout[li],
6924 C_acc, C_convertRegs[li - listart], 0, 0, strategy,
6925 state);
6926 }
6927 state.Tacc = Ts;
6928 }
6929
6930 // Alpha/beta scaling and optional fp32<->int32 conversion.
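            //   Phase 0: convert freshly loaded C data to the accumulator type
            //            (skipped when beta = 0).
            //   Phase 1: gather the per-block register ranges and perform the
            //            alpha/beta update on all of them in one batch.
            //   Phase 2: convert accumulators to the final type (Tacc_final).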
6931 if (!loadOnly)
6932 for (int phase = 0; phase < 3; phase++) {
6933 vector<GRFMultirange> C_accs, C_accSwaps, C_loads;
6934 C_accs.reserve(liend - listart);
6935 C_accSwaps.reserve(liend - listart);
6936 C_loads.reserve(liend - listart);
6937
6938 for (int li = listart; li < liend; li++) {
6939 GRFMultirange C_acc0 = state.C_regs[0].subrange(
6940 hw, state.Tacc, layout[li]);
6941 GRFMultirange C_acc = lateCConvert
6942 ? C_convertRegs[li - listart]
6943 : C_acc0;
6944 GRFMultirange C_accSwap;
6945 GRFMultirange C_load = beta0
6946 ? C_acc
6947 : copyCLoad ? C_copyRegs[li - listart]
6948 : C_extRegs[li - listart];
6949 switch (phase) {
6950 case 0:
6951 if (!beta0)
6952 convert(C_load, Tload, state.Tacc, problem,
6953 strategy, state);
6954 break;
6955 case 1: {
6956 C_accs.push_back(C_acc);
6957 C_accSwaps.push_back(C_accSwap);
6958 C_loads.push_back(C_load);
6959 } break;
6960 case 2:
6961 if (lateCConvert)
6962 copyRegisterBlock(state.Tacc, Tacc_final,
6963 layout[li], layout[li], C_acc,
6964 C_acc0, 0, 0, strategy, state);
6965 else
6966 convert(C_acc, state.Tacc, Tacc_final,
6967 problem, strategy, state);
6968 break;
6969 }
6970 }
6971
6972 if (phase == 1) {
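                    // Sort the chunks by ascending base register so the batched
                    // updateC call (and any remasking) sees one monotonically
                    // increasing register sequence.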
6973 std::vector<int> order(liend - listart);
6974 std::iota(order.begin(), order.end(), 0);
6975 std::sort(
6976 order.begin(), order.end(), [&](int a, int b) {
6977 auto *rangeA = &C_accs[a],
6978 *rangeB = &C_accs[b];
6979 return (*rangeA)[0].getBase()
6980 < (*rangeB)[0].getBase();
6981 });
6982 GRFMultirange C_accsSorted, C_accSwapsSorted,
6983 C_loadsSorted;
6984 std::vector<RegisterBlock> C_accSortedLayout;
6985
6986 bool remaskC_M = isPacked(problem.C.layout)
6987 && (strategy.remHandling[LoopM]
6988 != RemainderHandling::Ignore);
6989 bool remaskC_N = isPacked(problem.C.layout)
6990 && (strategy.remHandling[LoopN]
6991 != RemainderHandling::Ignore);
6992
6993 for (int i = 0; i < (liend - listart); i++) {
6994 if (remaskC_M || remaskC_N) {
6995 auto block = layout[listart + order[i]];
6996 block.offsetBytes = C_accsSorted.getLen()
6997 << GRF::log2Bytes(hw);
6998 C_accSortedLayout.push_back(block);
6999 }
7000
7001 C_accsSorted.append(C_accs[order[i]]);
7002 C_accSwapsSorted.append(C_accSwaps[order[i]]);
7003 C_loadsSorted.append(C_loads[order[i]]);
7004 }
7005
7006 updateC(C_accsSorted, C_accSwapsSorted, C_loadsSorted,
7007 problem, strategy, state);
7008
7009 if (remaskC_M)
7010 remaskLayout(state.Tacc, 0, false,
7011 C_accSortedLayout, C_accsSorted, strategy,
7012 state);
7013 if (remaskC_N)
7014 remaskLayout(state.Tacc, 1, true, C_accSortedLayout,
7015 C_accsSorted, strategy, state);
7016 }
7017 }
7018
7019 state.Tacc = Tacc_final;
7020
7021 // Store updated data.
7022 if (op == COperation::UpdateStore) {
7023 if (copyC)
7024 if (!copyRegisters(state.Tacc, Tc_ext, sublayoutAcc,
7025 sublayoutExt, C_accRange, C_extRange, 0, 0,
7026 false, strategy, state))
7027 stub();
7028
7029 auto &sublayoutSrc = copyC ? sublayoutExt : sublayoutAcc;
7030 auto &C_srcRange = copyC ? C_extRange : C_accRange;
7031 FOR_EACH_C storeMatrix(C_srcRange, sublayoutSrc, problem.C,
7032 strategy.C, C_addrs[q], strategy, state);
7033 }
7034
7035 state.Tacc = originalTacc;
7036 }
7037
7038 // Free address and data registers, including C accumulators that are no longer used...
7039 // ... except C_addr0. I need that!
7040 FOR_EACH_C safeReleaseRanges(C_addrs[q], state);
7041 safeReleaseRanges(C_extRange, state);
7042 safeReleaseRanges(C_copyRange, state);
7043 safeReleaseRanges(C_convertRegs, state);
7044 if (op == COperation::UpdateStore)
7045 for (int li = listart; li < liend; li++)
7046 for (int b = 0; b < state.C_buffers; b++)
7047 releaseRanges(state.C_regs[b].subrange(
7048 hw, state.Tacc, layout[li]),
7049 state);
7050 FOR_EACH_C state.ra.claim(C_addr0[q]);
7051
7052 // Check for forward progress.
7053 if (lend == lstart) throw out_of_registers_exception();
7054 lstart = lend;
7055 }
7056
7057 // Re-claim all the C registers we freed, so as not to disturb the caller's RegisterAllocator.
7058 reclaimRanges(state.C_regs[0], state);
7059#undef FOR_EACH_C
7060}
7061
7062// Assign runtime-computed descriptor information to all blocks in this layout.
7063// Returns true if successful; false if not all blocks in layout are compatible.
7064static inline bool assignAllDescs(vector<RegisterBlock> &layout) {
7065 for (auto &block : layout) {
7066 if (block.simdSize != layout[0].simdSize) return false;
7067 block.descAssigned = true;
7068 block.sfid = layout[0].sfid;
7069 }
7070
7071 return true;
7072}
7073
7074// Output code for standard C remainder handling.
7075template <HW hw>
7076bool gemm_kernel_generator_t<hw>::doStdCRemainder(
7077 vector<RegisterBlock> &layoutExt,
7078 vector<RegisterBlock> &layoutExtUnmasked, bool inside, bool columns[2],
7079 StdCRemType remTypes[2], bool fragments[2], bool fragPositives[2],
7080 int fragSizes[2], const GRFRange (&C_addr0)[2],
7081 const GRFRange (&C_addr0Unmasked)[2], COperation op,
7082 vector<MaskAssignment> &masks, GEMMProblem &problem,
7083 GEMMStrategy &strategy, GEMMState state) {
7084 auto Tc_ext = problem.Tc_ext;
7085 auto column = columns[inside];
7086 LoopType loop = column ? LoopN : LoopM;
7087 auto remType = remTypes[loop];
7088 auto fragment = fragments[loop];
7089 auto fragPositive = fragPositives[loop];
7090 auto fragSize = fragSizes[loop];
7091 auto unroll = strategy.unroll[loop];
7092 auto remainder = state.remainders[loop];
7093
7094 bool canEOT = !state.isNested && (op == COperation::UpdateStore);
7095
7096 Label lEnd;
7097
7098 // The "q" dimension is the one whose remainder we are currently handling.
7099 auto RegisterBlock::*nq = column ? &RegisterBlock::nc : &RegisterBlock::nr;
7100 auto RegisterBlock::*offsetQ
7101 = column ? &RegisterBlock::offsetC : &RegisterBlock::offsetR;
7102
7103 // Status message.
7104 status << "C remainder handling (" << char('m' + column) << ") " << remType;
7105 if (fragment) status << ", fragment";
7106 if (fragPositive) status << ", no empty accesses";
7107 status << status_stream::endl;
7108
7109 // Allocate temporaries for emulated atomic addition if needed.
7110 if (!inside && strategy.C.atomic)
7111 allocEAtomicAddRegs(
7112 hw, Tc_ext, layoutExt, problem.C, strategy.C, state);
7113
7114 // Handle a subproblem. Return true if successful.
7115 auto descend = [&](vector<RegisterBlock> &sublayoutExt,
7116 vector<RegisterBlock> &sublayoutExtUnmasked,
7117 bool full = false) -> bool {
7118 bool success = true;
7119 auto nMasksOriginal = int(masks.size());
7120
7121 if (remType == StdCRemType::Mask) {
7122 if (!full) {
7123 // Assign and load any extra masks needed.
7124 if (!assignMasks(
7125 sublayoutExt, LoopM, LoopN, masks, strategy, state))
7126 return false;
7127 loadMasks(masks, state.remainders, strategy, state,
7128 nMasksOriginal);
7129 sublayoutExtUnmasked.clear();
7130 } else {
7131 // Clear out mask assignments in this dimension.
7132 for (auto &block : layoutExt)
7133 block.clearFlag();
7134 }
7135 }
7136
7137 // Recursively handle subproblem.
7138 if (!inside)
7139 success = doStdCRemainder(sublayoutExt, sublayoutExtUnmasked, true,
7140 columns, remTypes, fragments, fragPositives, fragSizes,
7141 C_addr0, C_addr0Unmasked, op, masks, problem, strategy,
7142 state);
7143 else if (sublayoutExtUnmasked.empty())
7144 updateCLayout(sublayoutExt, C_addr0, state.C_layoutExt[0], op,
7145 problem, strategy, state);
7146 else
7147 updateCLayout(sublayoutExtUnmasked, C_addr0Unmasked,
7148 state.C_layoutExtUnmasked[0], op, problem, strategy, state);
7149
7150 // Free any new masks.
7151 if (remType == StdCRemType::Mask)
7152 safeReleaseMaskAssignments(masks, state, nMasksOriginal);
7153 return success;
7154 };
7155
7156 // Exit remainder handling.
7157 auto done = [&]() {
7158 if (!canEOT)
7159 jmpi(1, lEnd);
7160 else
7161 epilogue(strategy, state);
7162 };
7163
7164 // Main code.
7165 bool success = false;
7166 pushStream();
7167
7168 if (!fragment) {
7169 // If descriptor-based remainders requested, all blocks should be smaller than fragSize.
7170 // Load descriptors based on total remainder in this (rare) case.
7171 if (remType == StdCRemType::Descriptor) {
7172 loadLoadStoreDescriptors(!problem.beta0(), true, layoutExt[0],
7173 remainder, problem.C, strategy.C, strategy, state);
7174 if (!assignAllDescs(layoutExt)
7175 || !assignAllDescs(layoutExtUnmasked))
7176 goto failed;
7177 }
7178 if (inside && !layoutExtUnmasked.empty()
7179 && layoutExt.size() == state.C_layoutExt.size()) {
7180 // If unmasked layout is available, implement full remainder case specially.
7181 const bool useSIMTFlow = strategy.fused
7182 && (strategy.fusedLoop == loop
7183 || strategy.fusedLoop == LoopAny);
7184 Label labelRem, labelDone;
7185
7186 if (useSIMTFlow) {
7187 cmp(16 | ge | state.flagAP, remainder, unroll);
7188 if_(16 | state.flagAP, labelRem, labelDone);
7189 } else if (strategy.fused) {
7190 cmp(1 | ge | state.flagAP, remainder, unroll);
7191 jmpi(1 | ~state.flagAP, labelRem);
7192 } else {
7193 // No flag registers guaranteed -- use a jump table.
7194 auto tempQ = state.ra.alloc_sub<uint64_t>();
7195 auto temp = tempQ.ud(0);
7196
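                // temp = sat(remainder - unroll + 1): zero for any partial
                // remainder, nonzero only when remainder >= unroll. Scaled by
                // the 16-byte instruction length, the indexed jmpi below skips
                // the jmpi(labelRem) exactly when the full case applies.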
7197 add(1 | sat, temp, remainder, -unroll + 1);
7198 isGen12 ? mad(1, temp, 16, temp, 16) : shl(1, temp, temp, 4);
7199 jmpi(1, temp.d());
7200 jmpi(1, labelRem);
7201
7202 state.ra.safeRelease(tempQ);
7203 }
7204
7205 status << "Code for full " << char('m' + column) << " remainder"
7206 << status_stream::endl;
7207 if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed;
7208
7209 useSIMTFlow ? else_(16, labelDone) : jmpi(1, labelDone);
7210 mark(labelRem);
7211
7212 status << "Code for generic " << char('m' + column) << " remainder"
7213 << status_stream::endl;
7214 if (!descend(layoutExt, layoutExtUnmasked)) goto failed;
7215
7216 mark(labelDone);
7217 if (useSIMTFlow) endif(16);
7218 } else {
7219 // Otherwise, nothing else to do: go down a level.
7220 if (!descend(layoutExt, layoutExtUnmasked)) goto failed;
7221 }
7222 } else {
7223 // Use SIMT control flow if remainders could be different between fused threads or if jump tables disabled.
7224 const bool useSIMTFlow = strategy.noJumpTables
7225 || (strategy.fused
7226 && (strategy.fusedLoop == loop
7227 || strategy.fusedLoop == LoopAny));
7228
7229 // Fix up fragment size (fragSize).
        // - Check that every block starts at a multiple of fragSize; if not, fall back on fragSize 1.
7231 // - Max fragment size is 16.
7232 // - Should check unmasked layout, but it will have the same kind of fragmenting as the masked layout.
7233 fragSize = std::min<int>(fragSize, 16);
7234 for (auto &block : layoutExt) {
7235 if (block.*offsetQ % fragSize) {
7236 fragSize = 1;
7237 break;
7238 }
7239 }
7240
7241 // There are two strategies for fragmenting for remainder handling:
        //  fragSize = 1: Use the largest blocks possible. These are always fragPositive.
7243 // fragSize > 1: Always use blocks of size fragSize in the q dimension.
7244 if (fragSize == 1) {
7245 if (!useSIMTFlow) {
7246 // SIMD control flow, using a jump table.
7247 Subregister temp = state.ra.alloc_sub<uint32_t>();
7248 vector<Label> rlabels(unroll);
7249
7250 // Generate jump table.
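                // rlabels[r] is the handler for remainder value r; the indexed
                // jmpi below dispatches remainder -> rlabels[remainder], and
                // rlabels[0] doubles as the common exit label.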
7251 shl(1, temp, remainder,
7252 uint16_t(4)); // Multiply by instruction length.
7253 if (isGen12) // Gen12+ jmpi is relative to current IP.
7254 add(1, temp, temp, uint16_t(16));
7255 jmpi(1, temp.d()); // Indexed jump into jump table.
7256 for (int r = 0; r < unroll; r++)
7257 jmpi(1, rlabels[r]);
7258
7259 // Full remainder case: continue downward.
7260 status << "Code for full " << char('m' + column) << " remainder"
7261 << status_stream::endl;
7262 if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed;
7263 inside ? jmpi(1, rlabels[0]) : done();
7264
7265 // Remainder handling.
7266 vector<bool> qdone(unroll, false);
7267 qdone[0] = true;
7268 int qnext = 0;
7269 for (int nqtodo = unroll - 2; nqtodo >= 0; nqtodo--) {
7270 // Decide which q to do.
7271 int q;
7272 if (qnext > 0)
7273 q = qnext;
7274 else {
7275 for (q = unroll - 1; q >= 0; q--)
7276 if (!qdone[q]) break;
7277 }
7278
7279 status << "Code for " << char('m' + column) << " remainder "
7280 << q << status_stream::endl;
7281
7282 mark(rlabels[q]);
7283
7284 // Figure out how many rows/columns to take.
7285 int chunkSize = q & ~(q - 1); // = 1 << lowest set bit
7286
7287 // Look through all blocks in this row/column, and reduce chunk size if appropriate.
7288 for (auto &block : layoutExt) {
7289 if (!block.isLoadBlock())
7290 stub(); // Dummy blocks should be replaced by real ones...
7291 int qq = q
7292 - block.*offsetQ; // Note q = 1 + last row/column.
7293 if (qq > 0 && qq <= block.*nq)
7294 chunkSize = std::min<int>(chunkSize, qq);
7295 }
7296
7297 // With chunk size chosen, get rows/columns [q - chunkSize, q) of intersecting blocks.
7298 vector<RegisterBlock> C_subblocksExt,
7299 C_subblocksExtUnmasked;
7300 if (!getSubblocks(Tc_ext, C_subblocksExt, layoutExt, column,
7301 q - chunkSize, q, false, problem.C, strategy.C))
7302 goto failed;
7303 if (!layoutExtUnmasked.empty())
7304 if (!getSubblocks(Tc_ext, C_subblocksExtUnmasked,
7305 layoutExtUnmasked, column, q - chunkSize, q,
7306 false, problem.C, strategy.C))
7307 goto failed;
7308
7309 // Perform the requested update.
7310 if (!descend(C_subblocksExt, C_subblocksExtUnmasked))
7311 goto failed;
7312
7313 // Go to next remainder handler, or return.
7314 qdone[q] = true;
7315 qnext = q - chunkSize;
7316 if (nqtodo > 0) {
7317 if (qnext == 0 && canEOT)
7318 epilogue(strategy, state);
7319 else if (qdone[qnext]) {
7320 jmpi(1, rlabels[qnext]);
7321 qnext = 0;
7322 }
7323 }
7324 }
7325 mark(rlabels[0]);
7326
7327 state.ra.safeRelease(temp);
7328 } else {
7329 // SIMT control flow: massively nested if-else.
7330
7331 // Handle remainder in the range [q0, q1).
7332 std::function<bool(int, int)> handleRemainder
7333 = [&](int q0, int q1) -> bool {
7334 Label labelElse, labelEndif;
7335
7336 int qChunk = rounddown_pow2(q1 - q0 - 1);
7337
7338 if (qChunk == 0) qChunk = 1;
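                    // Binary subdivision: remainders >= q0 + qChunk update the
                    // full [q0, q0 + qChunk) block and then recurse on
                    // [q0 + qChunk, q1); smaller remainders take the 'else'
                    // branch and recurse on [q0, q0 + qChunk).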
7339
7340 status << "Code for " << char('m' + column)
7341 << " remainders " << q0 << " - " << (q1 - 1)
7342 << status_stream::endl;
7343
7344 if (q1 - q0 > 1) {
7345 cmp(16 | ge | state.flagAP, remainder,
7346 uint16_t(q0 + qChunk));
7347 if_(16 | state.flagAP,
7348 (qChunk > 1) ? labelElse : labelEndif,
7349 labelEndif);
7350 }
7351
7352 vector<RegisterBlock> C_subblocksExt,
7353 C_subblocksExtUnmasked;
7354 if (!getSubblocks(Tc_ext, C_subblocksExt, layoutExt, column,
7355 q0, q0 + qChunk, false, problem.C, strategy.C))
7356 return false;
7357 if (!layoutExtUnmasked.empty())
7358 if (!getSubblocks(Tc_ext, C_subblocksExtUnmasked,
7359 layoutExtUnmasked, column, q0, q0 + qChunk,
7360 false, problem.C, strategy.C))
7361 return false;
7362
7363 if (!descend(C_subblocksExt, C_subblocksExtUnmasked))
7364 return false;
7365
7366 if (q1 - q0 > 1) {
7367 if (qChunk > 1) {
7368 if (!handleRemainder(q0 + qChunk, q1)) return false;
7369
7370 else_(16, labelEndif);
7371 mark(labelElse);
7372
7373 if (!handleRemainder(q0, q0 + qChunk)) return false;
7374 }
7375
7376 mark(labelEndif);
7377 endif(16);
7378 }
7379
7380 return true;
7381 };
7382
7383 Label labelRem, labelRemDone, labelDone;
7384
7385 cmp(16 | ge | state.flagAP, remainder, uint16_t(unroll));
7386 if_(16 | state.flagAP, labelRem, labelDone);
7387
7388 status << "Code for " << char('m' + column) << " full remainder"
7389 << status_stream::endl;
7390 if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed;
7391
7392 else_(16, labelDone);
7393 mark(labelRem);
7394
7395 if (!handleRemainder(0, unroll)) goto failed;
7396
7397 mark(labelDone);
7398 endif(16);
7399 setDefaultNoMask(true);
7400 }
7401 } else {
7402 auto handleRemainderFP = [&](int q0, int q1) -> bool {
7403 // Get rows/columns [q0, q1) of intersecting blocks.
7404 vector<RegisterBlock> C_subblocksExt, C_subblocksExtUnmasked;
7405 if (!getSubblocks(Tc_ext, C_subblocksExt, layoutExt, column, q0,
7406 q1, false, problem.C, strategy.C))
7407 return false;
7408 if (!layoutExtUnmasked.empty())
7409 if (!getSubblocks(Tc_ext, C_subblocksExtUnmasked,
7410 layoutExtUnmasked, column, q0, q1, false,
7411 problem.C, strategy.C))
7412 return false;
7413
7414 if (remType == StdCRemType::Descriptor) {
7415 // Load address registers for subsequent loads and stores.
7416 Subregister rcount = state.ra.alloc_sub<uint32_t>();
7417 Subregister mremainder = remainder;
7418
7419 if (q0 != 0) {
7420 add(1 | sat, rcount, mremainder, int16_t(-q0));
7421 mremainder = rcount;
7422 }
7423 if (q1 < unroll) {
7424 min_(1, rcount, mremainder, uint16_t(fragSize));
7425 mremainder = rcount;
7426 }
7427
7428 loadLoadStoreDescriptors(!problem.beta0(), true,
7429 C_subblocksExt[0], mremainder, problem.C,
7430 strategy.C, strategy, state);
7431 if (!assignAllDescs(C_subblocksExt)
7432 || !assignAllDescs(C_subblocksExtUnmasked))
7433 return false;
7434
7435 state.ra.safeRelease(rcount);
7436 }
7437
7438 // Perform the requested update.
7439 return descend(C_subblocksExt, C_subblocksExtUnmasked);
7440 };
7441
7442 if (!useSIMTFlow) {
7443 // SIMD control flow, possibly using a jump table.
7444 int N = div_up(unroll, fragSize);
7445 vector<Label> rlabels(N); // Targets for jump table.
7446 Label rdone;
7447
7448 // Create a jump table, if needed.
7449 if (fragPositive) {
7450 Subregister t1 = state.ra.alloc_sub<uint32_t>();
7451 Subregister t2 = state.ra.alloc_sub<uint32_t>();
7452
7453 add(1 | sat, t2, remainder, int16_t(-unroll + 1));
7454 add(1, t1, remainder,
7455 int16_t(-1 + (isGen12 ? fragSize : 0)));
7456 add(1, t1, t1,
7457 t2); // Increment index if remainder == unroll.
7458 if (fragSize < 16) // Precondition: fragSize <= 16.
7459 mulConstant(1, t1, t1,
                                16 / fragSize); // Multiply by instruction length (16 bytes per uncompacted instruction).
7461 and_(1, t1, t1,
7462 uint16_t(0xFFF0)); // Mask off unwanted bits.
7463 jmpi(1, t1.d()); // Indexed jump into jump table.
7464 for (int r = 0; r < N; r++)
7465 jmpi(1, rlabels[r]);
7466
7467 state.ra.safeRelease(t2);
7468 state.ra.safeRelease(t1);
7469 }
7470
7471 // Full loop.
7472 status << "Code for " << char('m' + column) << " full remainder"
7473 << status_stream::endl;
7474 if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed;
7475 inside ? jmpi(1, rdone) : done();
7476
7477 // Remainder handling.
7478 for (int r = N - 1; r >= 0; r--) {
7479 int q0 = r * fragSize;
7480 int q1 = std::min<int>(q0 + fragSize, unroll);
7481
7482 status << "Code for " << char('m' + column)
7483 << " remainders " << q0 + 1 << " - " << q1
7484 << status_stream::endl;
7485
7486 mark(rlabels[r]);
7487
7488 if (!handleRemainderFP(q0, q1)) goto failed;
7489 }
7490
7491 if (inside) mark(rdone);
7492 } else {
7493 // SIMT control flow version.
7494 Label labelRem, labelRemDone, labelDone;
7495
7496 cmp(16 | ge | state.flagAP, remainder, uint16_t(unroll));
7497 if_(16 | state.flagAP, labelRem, labelDone);
7498
7499 status << "Code for " << char('m' + column) << " full remainder"
7500 << status_stream::endl;
7501 if (!descend(layoutExt, layoutExtUnmasked, true)) goto failed;
7502
7503 else_(16, labelDone);
7504 mark(labelRem);
7505
7506 for (int q0 = 0; q0 < unroll; q0 += fragSize) {
7507 int q1 = std::min<int>(q0 + fragSize, unroll);
7508
7509 cmp(16 | le | state.flagAP, remainder, uint16_t(q0));
7510 goto12(16 | state.flagAP, labelRemDone);
7511 status << "Code for " << char('m' + column)
7512 << " remainders " << q0 + 1 << " - " << q1
7513 << status_stream::endl;
7514
7515 if (!handleRemainderFP(q0, q1)) goto failed;
7516 }
7517
7518 mark(labelRemDone);
7519 join(16);
7520
7521 mark(labelDone);
7522 endif(16);
7523 }
7524 }
7525 }
7526
7527 // Success!
7528 success = true;
7529failed:
7530
7531 mark(lEnd);
7532 success ? appendCurrentStream() : discardStream();
7533
7534 if (!inside && strategy.C.atomic) freeEAtomicAddRegs(state);
7535
7536 return success;
7537}
7538
7539// Alternate code path for C remainder handling, based on a simple double loop
7540// and indirect addressing.
7541template <HW hw>
7542void gemm_kernel_generator_t<hw>::doAlternateCRemainder(COperation op,
7543 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
7544 auto Tc = problem.Tc, Tc_ext = problem.Tc_ext;
7545 int C_count = (op == COperation::UpdateStore) ? state.C_count : 1;
7546#define FOR_EACH_C for (int q = 0; q < C_count; q++)
7547#define FOR_EACH_C_REV for (int q = C_count - 1; q >= 0; q--)
7548
7549 bool lateYLoopCheck = false;
7550
7551 bool surface = !strategy.C.base.isStateless();
7552 bool loadOnly = (op == COperation::Load);
7553
7554 // Vector length in inner loop.
7555 const auto nbytes = 64;
7556 auto nec = nbytes / Tc;
7557
7558 // 1- and 2-byte types must be padded to 4 bytes.
7559 bool byte_access = (Tc_ext.size() < 4);
7560 if (byte_access) nec = nbytes >> 2;
7561
7562 // 8-byte+ types can use scattered qword. Only atomic for now.
7563 bool nativeAtomic = strategy.C.atomic
7564 && hasNativeAtomicAdd(hw, Tc_ext.real(), problem.C, strategy.C);
7565 bool qword = !((nativeAtomic ? Tc_ext.real() : Tc_ext).size() & 7)
7566 && strategy.C.atomic;
7567 int rshift = qword ? 3 : 2; // log2(data stride in regs)
7568 int rsimd = 64 >> rshift;
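    // rsimd: number of scattered elements per 64-byte access
    //        (16 for dword messages, 8 for qword).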
7569
7570 auto &block0 = state.C_layout[0];
7571 bool cColMajorMem = isColMajor(problem.C.layout);
7572 bool cColMajorReg = block0.colMajor;
7573 bool transpose = (cColMajorReg != cColMajorMem);
7574 if (isPacked(problem.C.layout)) stub();
7575
7576 // x is the contiguous dimension (in registers), y is the other dimension.
7577 auto LoopX = cColMajorReg ? LoopM : LoopN;
7578 auto LoopY = cColMajorReg ? LoopN : LoopM;
7579 int unrollX = strategy.unroll[LoopX];
7580 int unrollY = strategy.unroll[LoopY];
7581
7582 // Check the layout:
7583 // - C is a contiguous block of registers.
7584 // - nx must be divisible by 2 (unpacked) GRFs, unless x unroll is < 2 GRFs,
7585 // or there's an extra GRF at the end of C.
7586 // - register offsets must be in a uniform 2D grid
7587 // - all blocks must share same ordering (row/column major).
7588 // Otherwise use non-uniform path, and indirectly load GRFs.
7589
7590 auto Tcx = Tc;
7591 bool uniform = true;
7592 int16_t xByteInc = 0, yByteInc = 0;
7593 bool cAtEnd = (state.C_regs[0][state.C_regs[0].getLen() - 1].getBase() + 1)
7594 >= strategy.GRFs;
7595
7596 if (state.C_regs[0].ranges.size() != 1) uniform = false;
7597
7598 for (auto &block : state.C_layout) {
7599 if (block.colMajor != block0.colMajor) stub();
7600
7601 int nx = cColMajorReg ? block.nr : block.nc;
7602 int ny = cColMajorReg ? block.nc : block.nr;
7603 int ox = cColMajorReg ? block.offsetR : block.offsetC;
7604 int oy = cColMajorReg ? block.offsetC : block.offsetR;
7605
7606 ox /= nec;
7607
7608 if ((nx & (nec - 1)) && cAtEnd) uniform = false;
7609
7610 if (xByteInc == 0 && nx > nec) xByteInc = nec * Tcx;
7611 if (yByteInc == 0 && ny > 1) yByteInc = block.ld * Tc;
7612
7613 if (block.offsetBytes != ox * xByteInc + oy * yByteInc) {
7614 if (xByteInc == 0 && ox > 0)
7615 xByteInc = (block.offsetBytes - oy * yByteInc) / ox;
7616 else if (yByteInc == 0 && oy > 0)
7617 yByteInc = (block.offsetBytes - ox * xByteInc) / oy;
7618 else
7619 uniform = false;
7620 }
7621 }
7622
7623 GRFRange bases;
7624 bool nonuniformSubs = false;
7625
7626 if (!uniform) {
7627 uint8_t baseIndices[256] = {0};
7628 uint16_t offIndices[256] = {0};
7629
7630 if (state.Tacc.size() == 1) stub();
7631
7632 xByteInc = div_up(nec * Tcx, GRF::bytes(hw));
7633 int nec1 = nec / xByteInc;
7634 yByteInc = div_up(unrollX, nec1);
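        // Non-uniform path: build a lookup table ('bases') holding, for each
        // (y, x-chunk) position, either the GRF base index (1 byte per entry)
        // or, when subregister offsets vary, the full byte offset (2 bytes per
        // entry). The loops below step a0 through this table and load entries
        // into a0.2+ for indirect access to the C registers.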
7635
7636 for (int y = 0; y < unrollY; y++) {
7637 for (int xx = 0; xx < yByteInc; xx++) {
7638 auto x = xx * nec1;
7639 auto i = cColMajorReg ? x : y;
7640 auto j = cColMajorReg ? y : x;
7641 const RegisterBlock *blockPtr;
7642 int ne;
7643 auto sub = findBlockReg(Tc, state.C_layout, i, j,
7644 state.C_regs[0], ne, blockPtr, 0);
7645 nonuniformSubs |= (sub.getOffset() != 0);
7646 if (ne < std::min(nec1, unrollX - x)) stub();
7647 baseIndices[y * yByteInc + xx] = sub.getBase();
7648 offIndices[y * yByteInc + xx]
7649 = sub.getByteOffset() + sub.getBase() * GRF::bytes(hw);
7650 }
7651 }
7652
7653 if (nonuniformSubs) {
7654 xByteInc *= 2;
7655 yByteInc *= 2;
7656 }
7657
7658 bases = state.ra.alloc_range(
7659 div_up(unrollY * yByteInc, GRF::bytes(hw)));
7660 bool haveDF = !strategy.emulate.emulate64;
7661 haveDF |= (hw == HW::XeHPC);
7662 if (haveDF) {
7663 for (int i = 0; i < unrollY * yByteInc; i += 8) {
7664 auto sub = bases[i / GRF::bytes(hw)].df(
7665 (i % GRF::bytes(hw)) / 8);
7666 auto data = nonuniformSubs
7667 ? reinterpret_cast<double *>(&offIndices[i / 2])
7668 : reinterpret_cast<double *>(&baseIndices[i]);
7669 mov(1, sub, *data);
7670 }
7671 } else {
7672 for (int i = 0; i < unrollY * yByteInc; i += 4) {
7673 auto sub = bases[i / GRF::bytes(hw)].ud(
7674 (i % GRF::bytes(hw)) / 4);
7675 auto data = nonuniformSubs
7676 ? reinterpret_cast<uint32_t *>(&offIndices[i / 2])
7677 : reinterpret_cast<uint32_t *>(&baseIndices[i]);
7678 mov(1, sub, *data);
7679 }
7680 }
7681 }
7682
7683 // Claim flags.
7684 auto saveFlagAP = state.flagAP;
7685 state.raVFlag.safeRelease(state.flagAP);
7686 state.raVFlag.claim(f0[0]);
7687 state.raVFlag.claim(f0[1]);
7688 state.raVFlag.claim(f1[0]);
7689
7690 // Clear f0[1] for any16h trick.
7691 if (strategy.fused && !lateYLoopCheck) mov(1, f0[1], uint16_t(0));
7692
7693 // Update C with scattered accesses.
7694 // Get mask and set up header.
7695 GRFRange header[2];
7696 auto hregs = (surface ? 1 : 2) * (qword ? 1 : 2);
7697 FOR_EACH_C header[q] = state.ra.alloc_range(hregs);
7698 Subregister temp = state.ra.alloc_sub<uint32_t>();
7699 Subregister mask = state.ra.alloc_sub<uint32_t>();
7700 Subregister xIndex = state.remainders[LoopX];
7701
7702 GRF indexVec, ivContig, ivScatter;
7703
7704 indexVec = state.ra.alloc();
7705 indexVec.setType(DataType::w);
7706 mov(8, indexVec[0](1), Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7));
7707 if (rsimd > 8)
7708 mov(8, indexVec[8](1), Immediate::uv(8, 9, 10, 11, 12, 13, 14, 15));
7709
7710 auto oshift = std::min<int>(rshift, Tc_ext.log2Size());
7711
7712 // Prepare x mask in f1.0 and prepare header for loads/stores.
7713 if (Tc_ext.size() > 4) {
7714 mulConstant(1, temp, xIndex, uint16_t(Tc_ext.size() >> rshift));
7715 xIndex = temp;
7716 }
7717
7718 ivScatter = indexVec;
7719 bool splitScatter = transpose && (Tc_ext.log2Size() > rshift);
7720 if (splitScatter) {
7721 ivContig = state.ra.alloc();
7722 ivContig.setType(DataType::w);
7723 auto shift = Tc_ext.log2Size() - rshift;
7724 auto m = (1 << shift) - 1;
7725
7726 asr(16, ivScatter, indexVec, uint16_t(shift));
7727 mov(16, ivContig,
7728 Immediate::uv((0 & m) << rshift, (1 & m) << rshift,
7729 (2 & m) << rshift, (3 & m) << rshift, (4 & m) << rshift,
7730 (5 & m) << rshift, (6 & m) << rshift,
7731 (7 & m) << rshift));
7732 }
7733
7734 add(1, temp, xIndex, int16_t(-1));
7735 FOR_EACH_C transpose
7736 ? mul(rsimd, header[q][0].d(), state.inputs.ldc[q], ivScatter)
7737 : shl(rsimd, header[q][0].d(), indexVec, uint16_t(oshift));
7738 FOR_EACH_C if (splitScatter)
7739 add(rsimd, header[q][0].d(), header[q][0].d(), ivContig);
7740
7741 int hs = 1;
7742 bool header4 = !qword && !surface;
7743 int neq = elementsPerGRF(hw, DataType::uq);
7744
7745 header4 &= (GRF::bytes(hw) < 64);
7746 if (hw >= HW::XeHP && !surface) {
7747 if (header4)
7748 FOR_EACH_C mov<uint32_t>(2 * neq, header[q][2][0](2), header[q][1]);
7749 FOR_EACH_C mov<uint32_t>(neq, header[q][1][0](2), header[q][0][neq](1));
7750 FOR_EACH_C mov<uint32_t>(neq, header[q][0][0](2), header[q][0][0](1));
7751 hs = 2;
7752 }
7753
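    // Build the partial-x mask in f1.0:
    //   f1.0 = ((1 << rsimd) - 1) >> ((-remainderX) & (rsimd - 1))
    // i.e. all rsimd lanes enabled when the x remainder (in SIMD lanes) is a
    // multiple of rsimd, otherwise only the low (remainderX mod rsimd) lanes.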
7754 and_(1, temp, ~temp, uint16_t(rsimd - 1));
7755 FOR_EACH_C surface
7756 ? add(rsimd, header[q][0].d(), header[q][0].d(), state.effC[q])
7757 : header4 ? eadd(8, header[q][2].uq(), header[q][hs].d(0)(hs),
7758 state.effC[q], strategy, state)
7759 : noop();
7760 mov(1, mask, uint16_t((1 << rsimd) - 1));
7761 FOR_EACH_C if (!surface) eadd(2 * neq, header[q][0].uq(),
7762 header[q][0].d(0)(hs), state.effC[q], strategy, state);
7763 shr(1, f1[0], mask, temp);
7764
7765 state.ra.safeRelease(mask);
7766 state.ra.safeRelease(temp);
7767 state.ra.safeRelease(ivContig);
7768
7769 // Synthesize double loop updating 2 GRFs (indirectly addressed) at a time.
7770 GRF ix = state.ra.alloc();
7771 Subregister ix_init = state.ra.alloc_sub<uint16_t>();
7772 Subregister iy = state.ra.alloc_sub<int16_t>();
7773 Subregister cXInc[2], cYInc[2];
7774 FOR_EACH_C cYInc[q] = state.ra.alloc_sub<int32_t>();
7775 Label yLoop, xLoop;
7776 GRFRange Cacc = state.ra.alloc_range(2);
7777 GRFRange CaccSwap {};
7778 GRFRange Cload
7779 = state.ra.alloc_range(2, getHint(HintType::CLoad, strategy));
7780
7781 if (transpose) FOR_EACH_C {
7782 cXInc[q] = state.ra.alloc_sub<int32_t>();
7783 mulConstant(1, cXInc[q], state.inputs.ldc[q], nec);
7784 }
7785
7786 add(1, ix_init, state.remainders[LoopX], int16_t(-1));
7787 mov(1, iy, state.remainders[LoopY]);
7788 shr(1, ix_init, ix_init, uint16_t(log2(nec)));
7789
7790 if (uniform)
7791 mov(1, a0[0], state.C_regs[0][0].getBase() * GRF::bytes(hw));
7792 else
7793 mov(1, a0[0], bases.getBase() * GRF::bytes(hw));
7794
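    // cYInc[q]: pointer adjustment applied at the bottom of the y loop -- one y
    // step in memory minus the total distance the x loop already advanced
    // (ldc-based in the non-transposed case; one element with an ldc-scaled x
    // rewind when transposed).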
7795 add(1, cYInc[0], ix_init, uint16_t(1));
7796 mulConstant(1, cYInc[0], cYInc[0],
7797 uint16_t(nec * (!transpose ? Tc_ext.size() : 1)));
7798 if (!transpose)
7799 FOR_EACH_C_REV add(1, cYInc[q], -cYInc[0], state.inputs.ldc[q]);
7800 else {
7801 FOR_EACH_C_REV mul(1, cYInc[q], state.inputs.ldc[q], cYInc[0].w());
7802 FOR_EACH_C_REV add(1, cYInc[q], -cYInc[q], uint16_t(Tc_ext.size()));
7803 }
7804
7805 mark(yLoop);
7806 mov<uint16_t>(16, ix, ix_init);
7807 if (!lateYLoopCheck) add(1 | gt | f0[1], iy, iy, int16_t(-1));
7808 mov(1, a0[1], a0[0]);
7809
7810 mark(xLoop);
7811 add<int16_t>(16 | ge | f0[0], ix, ix, int16_t(-1));
7812
    // Update. The anyv is a trick to use the generated x-remainder mask (f1.0) on the last
    // iteration of the x loop, and no mask (0xFFFF) on the other iterations.
7815 InstructionModifier mod;
7816 mod = mod | f0[0] | anyv;
7817
7818 // Alas, no anyv on PVC.
7819 if (hw == HW::XeHPC) {
7820 mov(1 | ~f0[0], f0[0], f1[0]);
7821 mod = InstructionModifier() | f0[0];
7822 }
7823
7824 if (!uniform) {
7825 nonuniformSubs ? mov(xByteInc, a0[2](1), indirect[a0[1]].uw())
7826 : shl(xByteInc, a0[2](1), indirect[a0[1]].ub(),
7827 GRF::log2Bytes(hw));
7828 }
7829
7830 if (!loadOnly) {
7831 if (uniform) switch (state.Tacc.size()) {
7832 case 1: mov<uint32_t>(16, Cacc, indirect[a0[1]].ub()); break;
7833 case 2: mov<uint32_t>(16, Cacc, indirect[a0[1]].uw()); break;
7834 default: mov<uint32_t>(16, Cacc, indirect[a0[1]]); break;
7835 }
7836 else if (xByteInc == 1)
7837 switch (state.Tacc.size()) {
7838 case 2: mov<uint32_t>(16, Cacc, indirect[a0[2]].uw()); break;
7839 default: mov<uint32_t>(16, Cacc, indirect[a0[2]]); break;
7840 }
7841 else
7842 switch (state.Tacc.size()) {
7843 case 2:
7844 mov<uint32_t>(
7845 16, Cacc, indirect[a0[2]].uw(0)(16 / xByteInc, 1));
7846 break;
7847 default:
7848 mov<uint32_t>(
7849 16, Cacc, indirect[a0[2]].ud(0)(16 / xByteInc, 1));
7850 break;
7851 }
7852 }
7853
7854 if (strategy.C.atomic) {
7855 // Atomic update. Requires beta = 1, alpha prescaled.
7856 if (!problem.alpha1() && !problem.beta1()) stub();
7857 if (C_count > 1) stub();
7858 if (op != COperation::UpdateStore) stub();
7859
7860 std::vector<RegisterBlock> layout {1};
7861 auto &block = layout[0];
7862 block.ebytes = qword ? 8 : Tc_ext.real().size();
7863 block.simdSize = rsimd;
7864 block.clearFlag();
7865 block.bytes = 64;
7866 block.extra = 1;
7867 block.count = 1;
7868 block.log2GRFBytes = GRF::log2Bytes(hw);
7869
7870 allocEAtomicAddRegs(
7871 hw, Tc_ext, layout, problem.C, strategy.C, state, f1[1]);
7872
7873 Label labelEndAtomic;
7874 if_(16 | mod, labelEndAtomic);
7875 setDefaultNoMask(false);
7876 atomicAddMatrixBlock(Tc_ext, Cacc, block, problem.C, strategy.C,
7877 header[0], problem, strategy, state);
7878 setDefaultNoMask(true);
7879 mark(labelEndAtomic);
7880 endif(16);
7881
7882 freeEAtomicAddRegs(state, f1[1]);
7883 } else {
7884 // Late C conversion, if needed.
7885 auto originalTacc = state.Tacc;
7886 if (problem.needsTsConvert() && state.Tacc != problem.Ts) {
7887 convert(Cacc, state.Tacc, problem.Ts, problem, strategy, state);
7888 state.Tacc = problem.Ts;
7889 }
7890
7891 // Regular update.
7892 if (loadOnly || !problem.beta0()) {
7893 doReadSuppressionWA(strategy, state);
7894 if (strategy.C.newDP) {
7895 !byte_access ? load(16 | mod, Cload, D32 | strategy.C.cachingR,
7896 strategy.C.base, header[0])
7897 : (Tc_ext.size() == 2)
7898 ? load(16 | mod, Cload,
7899 D16U32 | strategy.C.cachingR,
7900 strategy.C.base, header[0])
7901 : load(16 | mod, Cload,
7902 D8U32 | strategy.C.cachingR,
7903 strategy.C.base, header[0]);
7904 } else {
7905 byte_access
7906 ? load(16 | mod, Cload, scattered_byte(Tc_ext.size()),
7907 strategy.C.base, header[0])
7908 : !surface ? load(16 | mod, Cload, scattered_dword(),
7909 strategy.C.base, header[0])
7910 : load(16 | mod, Cload,
7911 surface_dword(ChannelMask::r),
7912 strategy.C.base, header[0]);
7913 }
7914 }
7915
7916 if (!loadOnly) {
7917 auto Tc_out = (op == COperation::UpdateStore) ? problem.Tc_ext
7918 : state.Tacc;
7919 if (!problem.beta0())
7920 convert(Cload, problem.Tc_ext, state.Tacc, problem, strategy,
7921 state);
7922 updateC(Cacc, CaccSwap, Cload, problem, strategy, state);
7923 convert(Cacc, state.Tacc, Tc_out, problem, strategy, state);
7924 }
7925
7926 if (op != COperation::UpdateStore) {
7927 auto src = (op == COperation::Load) ? Cload : Cacc;
7928 if (uniform) switch (Tc.size()) {
7929 case 1:
7930 mov<uint32_t>(16 | mod, indirect[a0[1]].ub(), src);
7931 break;
7932 case 2:
7933 mov<uint32_t>(16 | mod, indirect[a0[1]].uw(), src);
7934 break;
7935 default:
7936 mov<uint32_t>(16 | mod, indirect[a0[1]], src);
7937 break;
7938 }
7939 else if (xByteInc == 1)
7940 switch (state.Tacc.size()) {
7941 case 2:
7942 mov<uint32_t>(16 | mod, indirect[a0[2]].uw(), src);
7943 break;
7944 default:
7945 mov<uint32_t>(16 | mod, indirect[a0[2]], src);
7946 break;
7947 }
7948 else if (xByteInc == 2)
7949 switch (state.Tacc.size()) {
7950 case 2:
7951 mov<uint32_t>(8 | mod, indirect[a0[2]].uw(), src);
7952 mov<uint32_t>(8 | mod | M8, indirect[a0[3]].uw(),
7953 src.sub(hw, 8, DataType::ud)(1));
7954 break;
7955 default:
7956 mov<uint32_t>(8 | mod, indirect[a0[2]].ud(), src);
7957 mov<uint32_t>(8 | mod | M8, indirect[a0[3]].ud(),
7958 src.sub(hw, 8, DataType::ud)(1));
7959 break;
7960 }
7961 else
7962 stub();
7963 } else
7964 FOR_EACH_C {
7965 if (strategy.C.newDP) {
7966 !byte_access ? store(16 | mod, D32 | strategy.C.cachingW,
7967 strategy.C.base, header[q], Cacc)
7968 : (Tc_ext.size() == 2)
7969 ? store(16 | mod,
7970 D16U32 | strategy.C.cachingW,
7971 strategy.C.base, header[q], Cacc)
7972 : store(16 | mod,
7973 D8U32 | strategy.C.cachingW,
7974 strategy.C.base, header[q], Cacc);
7975 } else {
7976 byte_access ? store(16 | mod, scattered_byte(Tc_ext.size()),
7977 strategy.C.base, header[q], Cacc)
7978 : !surface
7979 ? store(16 | mod, scattered_dword(),
7980 strategy.C.base, header[q], Cacc)
7981 : store(16 | mod,
7982 surface_dword(ChannelMask::r),
7983 strategy.C.base, header[q], Cacc);
7984 }
7985 }
7986
7987 state.Tacc = originalTacc;
7988 }
7989
7990 if (hw == HW::XeHPC) cmp<int16_t>(1 | ge | f0[0], ix, 0);
7991
7992 add(1, a0[1], a0[1], xByteInc);
7993 if (!transpose) {
7994 uint16_t inc = nec * Tc_ext;
7995 if (!surface) {
7996 FOR_EACH_C eadd<uint64_t>(std::min(2 * neq, rsimd), header[q][0],
7997 header[q][0], inc, strategy, state);
7998 if (header4)
7999 FOR_EACH_C eadd<uint64_t>(
8000 8, header[q][2], header[q][2], inc, strategy, state);
8001 } else
8002 FOR_EACH_C add<uint32_t>(rsimd, header[q][0], header[q][0], inc);
8003 } else {
8004 if (!surface) {
8005 FOR_EACH_C eadd<uint64_t>(std::min(2 * neq, rsimd), header[q][0],
8006 header[q][0], cXInc[q], strategy, state);
8007 if (header4)
8008 FOR_EACH_C eadd<uint64_t>(8, header[q][2], header[q][2],
8009 cXInc[q], strategy, state);
8010 } else
8011 FOR_EACH_C add<uint32_t>(
8012 rsimd, header[q][0], header[q][0], cXInc[q]);
8013 }
8014
8015 // Bottom of x loop.
8016 // Fused threads must use SIMT control flow instructions.
8017 strategy.fused ? simtDoWhileLoop(16 | f0[0], xLoop)
8018 : jmpi(1 | f0[0], xLoop);
8019
8020 if (lateYLoopCheck) add(1 | gt | f0[1], iy, iy, int16_t(-1));
8021 add(1, a0[0], a0[0], yByteInc);
8022 if (!surface) {
8023 FOR_EACH_C eadd<uint64_t>(std::min(2 * neq, rsimd), header[q][0],
8024 header[q][0], cYInc[q], strategy, state);
8025 if (header4)
8026 FOR_EACH_C eadd<uint64_t>(
8027 8, header[q][2], header[q][2], cYInc[q], strategy, state);
8028 } else
8029 FOR_EACH_C add<uint32_t>(rsimd, header[q][0], header[q][0], cYInc[q]);
8030
8031 // Bottom of y loop.
8032 // The any16h is a trick: only the lowest bit of f0[1] is updated when decrementing iy,
8033 // but we want to apply it to all channels.
8034 strategy.fused ? simtDoWhileLoop(16 | f0[1] | any16h, yLoop)
8035 : jmpi(1 | f0[1], yLoop);
8036
8037 // Cleanup.
8038 state.raVFlag.release(f0[0]);
8039 state.raVFlag.release(f0[1]);
8040 state.raVFlag.release(f1[0]);
8041 state.ra.safeRelease(bases);
8042
8043 state.ra.safeRelease(indexVec);
8044 state.ra.safeRelease(Cload);
8045 state.ra.safeRelease(CaccSwap);
8046 state.ra.safeRelease(Cacc);
8047 FOR_EACH_C state.ra.safeRelease(cXInc[q]);
8048 FOR_EACH_C state.ra.safeRelease(cYInc[q]);
8049 state.ra.safeRelease(iy);
8050 state.ra.safeRelease(ix);
8051 state.ra.safeRelease(ix_init);
8052 FOR_EACH_C state.ra.safeRelease(header[q]);
8053
8054 state.flagAP = saveFlagAP;
8055 if (state.flagAP.isValid()) state.raVFlag.claim(state.flagAP);
8056
8057#undef FOR_EACH_C
8058}
8059
// Prepare for the GEMM k loop with m/n-masked A/B accesses. Returns true if lda_ka/ldb_kb need recalculating.
8061template <HW hw>
8062bool gemm_kernel_generator_t<hw>::gemmPrepMaskedAB(
8063 const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
8064 bool recalc = false;
8065 bool shrinkUK = false;
8066 if (!strategy.A.padded
8067 && (strategy.remHandling[LoopM] != RemainderHandling::Ignore)) {
8068 shrinkUK = true;
8069 if (strategy.ka_load > strategy.ka_load_masked) {
8070 status << "Downgrading ka_load: " << strategy.ka_load << " -> "
8071 << strategy.ka_load_masked << status_stream::endl;
8072 strategy.ka_load = strategy.ka_load_masked;
8073 strategy.kChain = gcd(strategy.kChain, strategy.ka_load);
8074 recalc = true;
8075 }
8076 // Avoid access patterns that can't be handled by masking.
8077 if (isBlock2D(strategy.A.accessType) || strategy.unroll[LoopM] == 1)
8078 noop();
8079 else if (problem.A.layout == MatrixLayout::T
8080 && !isTransposing(strategy.A.accessType))
8081 strategy.A.accessType = strategy.A.base.isStateless()
8082 ? AccessType::Scattered
8083 : AccessType::ChannelScattered;
8084 else if (problem.A.layout != MatrixLayout::T
8085 && isTransposing(strategy.A.accessType))
8086 strategy.A.accessType = AccessType::Block;
8087 strategy.slmATrans = false;
8088 strategy.prefetchA = strategy.prefetchAMasked;
8089 }
8090 if (!strategy.B.padded
8091 && (strategy.remHandling[LoopN] != RemainderHandling::Ignore)) {
8092 shrinkUK = true;
8093 if (strategy.kb_load > strategy.kb_load_masked) {
8094 status << "Downgrading kb_load: " << strategy.kb_load << " -> "
8095 << strategy.kb_load_masked << status_stream::endl;
8096 strategy.kb_load = strategy.kb_load_masked;
8097 strategy.kChain = gcd(strategy.kChain, strategy.kb_load);
8098 recalc = true;
8099 }
8100 // Avoid access patterns that can't be handled by masking.
8101 if (isBlock2D(strategy.B.accessType) || strategy.unroll[LoopN] == 1)
8102 noop();
8103 else if (problem.B.layout == MatrixLayout::N
8104 && !isTransposing(strategy.B.accessType))
8105 strategy.B.accessType = strategy.B.base.isStateless()
8106 ? AccessType::Scattered
8107 : AccessType::ChannelScattered;
8108 else if (problem.B.layout != MatrixLayout::N
8109 && isTransposing(strategy.B.accessType))
8110 strategy.B.accessType = AccessType::Block;
8111 strategy.slmBTrans = false;
8112 strategy.prefetchB = strategy.prefetchBMasked;
8113 }
8114 if (shrinkUK && (strategy.unrollK_masked > 0)
8115 && (strategy.unroll[LoopK] > strategy.unrollK_masked)) {
8116 status << "Downgrading k unroll: " << strategy.unroll[LoopK] << " -> "
8117 << strategy.unrollK_masked << status_stream::endl;
8118 strategy.unroll[LoopK] = strategy.unrollK_masked;
8119 }
8120 if (shrinkUK && (strategy.unrollKSLMMasked > 0)
8121 && (strategy.unrollKSLM > strategy.unrollKSLMMasked)) {
8122 status << "Downgrading SLM k chunk size: " << strategy.unrollKSLM
8123 << " -> " << strategy.unrollKSLMMasked << status_stream::endl;
8124 strategy.unrollKSLM = strategy.unrollKSLMMasked;
8125 }
8126 return recalc;
8127}
8128
8129// Generate the GEMM kernel body. If it fails (due to excessive masking, say), return false.
8130template <HW hw>
8131bool gemm_kernel_generator_t<hw>::gemmBody(
8132 GEMMProblem problem, GEMMStrategy strategy, GEMMState state) {
8133 bool a2D = strategy.A.address2D;
8134 bool b2D = strategy.B.address2D;
8135 bool c2D = strategy.C.address2D;
8136
8137 // Record whether we are in the first row/column for fused sum kernels.
8138 if (problem.sumA || problem.sumB) {
8139 if (problem.sumA && problem.sumB) stub();
8140 auto &flags = state.inputs.flags;
8141 auto &y0 = problem.sumA ? state.j0 : state.i0;
8142
8143 if (flags.isInvalid()) {
8144 flags = state.ra.alloc_sub<uint32_t>(
8145 getHint(HintType::LongTerm, strategy));
8146 cmp(1 | eq | state.flagAP, flags, y0, 0);
8147 and_(1, flags, flags, FlagStoreSums);
8148 } else {
8149 cmp(1 | eq | state.flagAP, y0, 0);
8150 or_(1 | state.flagAP, flags, flags, FlagStoreSums);
8151 }
8152 }
8153
8154 // Release variables that are no longer needed.
8155 bool saveIJ0 = false;
8156 saveIJ0 |= problem.hasBinaryPostOp();
8157 if (!a2D && !c2D && !saveIJ0) state.ra.safeRelease(state.i0);
8158 if (!b2D && !c2D && !saveIJ0) state.ra.safeRelease(state.j0);
8159 if (!a2D && !b2D) state.ra.safeRelease(state.h0);
8160 if (!strategy.altCRemainder) releaseFusedRemainders(state);
8161 state.ra.safeRelease(state.remaindersWG[LoopM]);
8162 state.ra.safeRelease(state.remaindersWG[LoopN]);
8163
8164 // If A/B are masked, check if we need to change ka_load/kb_load. If so, recalculate lda_ka/ldb_kb.
8165 if (gemmPrepMaskedAB(problem, strategy, state))
8166 gemmCalcIncrements(problem, strategy, state);
8167
8168 // Disable C prefetch in remainder handling if it needs masks/fragmenting.
8169 if (strategy.remHandling[LoopM] != RemainderHandling::Ignore
8170 || strategy.remHandling[LoopN] != RemainderHandling::Ignore) {
8171 if (strategy.C.base.isStateless() && !strategy.C.padded
8172 && strategy.prefetchC
8173 && !isBlock2D(strategy.C_prefetch.accessType)) {
8174 status << "Auto-disabling C prefetch in masked region"
8175 << status_stream::endl;
8176 strategy.prefetchC = 0;
8177 if (state.effCp != state.effC[0]) state.ra.safeRelease(state.effCp);
8178 }
8179 }
8180
8181 // Try generating kernel body with current strategy.
8182 bool success = false;
8183 pushStream();
8184 try {
8185 success = gemmBodyInternal(problem, strategy, state);
8186 } catch (...) { lastException = std::current_exception(); }
8187 success ? appendCurrentStream() : discardStream();
8188
8189 return success;
8190}
8191
8192// Allocate nreg registers in chunks of a given size.
8193static inline GRFMultirange chunkAlloc(int nreg, int chunk, Bundle hint,
8194 BundleGroup mask, CommonState &state) {
8195 GRFMultirange r;
8196 for (; nreg > 0; nreg -= chunk) {
8197 auto nr = std::min(nreg, chunk);
8198 r.ranges.push_back(state.ra.alloc_range(nr, hint, mask));
8199 }
8200 return r;
8201}
8202
8203static inline GRFMultirange chunkAlloc(
8204 int nreg, int chunk, Bundle hint, CommonState &state) {
8205 return chunkAlloc(nreg, chunk, hint, BundleGroup::AllBundles(), state);
8206}
8207
8208// Allocate register layout in individual chunks.
8209static inline GRFMultirange trySplitAlloc(HW hw, Type T,
8210 const vector<RegisterBlock> &layout, std::array<Bundle, 2> hints,
8211 BundleGroup mask, CommonState &state, int copies = 1) {
8212 auto oddHint = Bundle(0, 0).group_size(hw) * elementsPerGRF(hw, T);
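    // A block gets the second allocation hint when its leading row/column offset
    // falls in an odd-numbered bundle-group-sized span of elements, so neighboring
    // blocks tend to be steered toward different register banks.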
8213
8214 GRFMultirange r;
8215 struct Request {
8216 int length, offset, index, hint;
8217 };
8218 vector<Request> requests;
8219 requests.reserve(layout.size());
8220
8221 for (auto &block : layout) {
8222 if (block.isLoadBlock()) {
8223 int hint = ((block.colMajor ? block.offsetR : block.offsetC)
8224 & oddHint)
8225 != 0;
8226 requests.push_back({block.msgRegs, block.offsetReg(), 0, hint});
8227 }
8228 }
8229
8230 if (requests.empty() && !layout.empty())
8231 for (auto &block : layout) {
8232 // No memory backing for layout. Split by rows/columns if possible.
8233 int hint = ((block.colMajor ? block.offsetR : block.offsetC)
8234 & oddHint)
8235 != 0;
8236 auto &ny = block.colMajor ? block.nc : block.nr;
8237 int xElems = (block.ld * block.crosspack * T);
8238 int xGRFs = xElems / elementsPerGRF(hw, T);
8239 if (xElems % elementsPerGRF(hw, T))
8240 requests.push_back({block.nregs(), block.offsetReg(), 0,
8241 hint}); /* can't split */
8242 else
8243 for (int y = 0, off = block.offsetReg(); y < ny;
8244 y += block.crosspack, off += xGRFs)
8245 requests.push_back({xGRFs, off, 0, hint});
8246 }
8247
8248 // Figure out which order the ranges belong in.
8249 std::sort(requests.begin(), requests.end(),
8250 [](const Request &r1, const Request &r2) {
8251 return (r1.offset < r2.offset);
8252 });
8253 for (size_t i = 0; i < requests.size(); i++)
8254 requests[i].index = int(i);
8255
8256 // Sort again and allocate largest to smallest.
8257 std::sort(requests.begin(), requests.end(),
8258 [](const Request &r1, const Request &r2) {
8259 return (r1.length > r2.length)
8260 || (r1.length == r2.length && r1.offset < r2.offset);
8261 });
8262 r.ranges.resize(requests.size() * copies);
8263
8264 bool ok = true;
8265 for (size_t i = 0; i < requests.size(); i++) {
8266 for (int c = 0; c < copies; c++) {
8267 auto newRange = state.ra.try_alloc_range(
8268 requests[i].length, hints[requests[i].hint], mask);
8269 r.ranges[requests[i].index + c * requests.size()] = newRange;
8270 ok &= newRange.isValid();
8271 }
8272 }
8273
8274 if (!ok) {
8275 for (auto &rr : r.ranges)
8276 state.ra.release(rr);
8277 r.ranges.clear();
8278 }
8279
8280 return r;
8281}
8282
8283static inline GRFMultirange splitAlloc(HW hw, Type T,
8284 const vector<RegisterBlock> &layout, std::array<Bundle, 2> hints,
8285 BundleGroup mask, CommonState &state, int copies = 1) {
8286 auto r = trySplitAlloc(hw, T, layout, hints, mask, state, copies);
8287 if (r.empty() && !layout.empty()) throw out_of_registers_exception();
8288 return r;
8289}
8290
8291// Allocate register ranges for A/B/C.
8292template <HW hw>
8293void gemm_kernel_generator_t<hw>::gemmAllocRegs(
8294 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
8295 // Summary: order of allocations is important.
8296 auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc;
8297
8298 auto A_copies = strategy.A_copies;
8299 auto B_copies = strategy.B_copies;
8300 int A_regCount = getRegCount(state.A_layout);
8301 int Ar_regCount = getRegCount(state.Ar_layout);
8302 int B_regCount = getRegCount(state.B_layout);
8303 int Br_regCount = getRegCount(state.Br_layout);
8304 int C_regCountPerBuffer = getRegCount(state.C_layout);
8305 int C_regCount = state.C_buffers * C_regCountPerBuffer;
8306 GRFMultirange C_regs;
8307
8308 bool globalCM = isLayoutColMajor(state.C_layout);
8309
8310 auto hintA0 = globalCM ? HintType::A0 : HintType::A0Broadcast;
8311 auto hintB0 = !globalCM ? HintType::B0 : HintType::B0Broadcast;
8312
8313 auto Tv = globalCM ? Ta : Tb;
8314 auto Tn = !globalCM ? Ta : Tb;
8315
8316 auto &V_layout = globalCM ? state.A_layout : state.B_layout;
8317 auto &Vr_layout = globalCM ? state.Ar_layout : state.Br_layout;
8318 auto &V_regs = globalCM ? state.A_regs : state.B_regs;
8319 auto &Vr_regs = globalCM ? state.Ar_regs : state.Br_regs;
8320 auto V_copies = globalCM ? A_copies : B_copies;
8321 auto V_regCount = globalCM ? A_regCount : B_regCount;
8322 auto Vr_regCount = globalCM ? Ar_regCount : Br_regCount;
8323 auto &N_layout = !globalCM ? state.A_layout : state.B_layout;
8324 auto &Nr_layout = !globalCM ? state.Ar_layout : state.Br_layout;
8325 auto &N_regs = !globalCM ? state.A_regs : state.B_regs;
8326 auto &Nr_regs = !globalCM ? state.Ar_regs : state.Br_regs;
8327 auto N_copies = !globalCM ? A_copies : B_copies;
8328 auto N_regCount = !globalCM ? A_regCount : B_regCount;
8329 auto Nr_regCount = !globalCM ? Ar_regCount : Br_regCount;
8330
8331 const auto &C_layout = state.C_layout;
8332 const auto &C_layoutExt = state.C_layoutExtUnmasked.empty()
8333 ? state.C_layoutExt
8334 : state.C_layoutExtUnmasked;
8335
8336 int C_chunk = state.copyC ? 1 : getMaxLoadBlock(C_layoutExt);
8337 C_chunk = alignup_pow2(C_chunk, Bundle(0, 0).group_size(hw) * 2);
8338 if (strategy.systolic) C_chunk = std::max(C_chunk, 8);
8339
8340 state.C_accCount = strategy.cAccumulators
8341 ? AccumulatorRegister::count(hw, strategy.GRFs, Tc.ngen())
8342 : 0;
8343
8344 state.A_regs.resize(A_copies);
8345 state.B_regs.resize(B_copies);
8346
8347 switch (strategy.registerScheme) {
8348 case GEMMStrategy::CSeparate: {
8349 // Standard allocation (Gen9-11). A and B allocated together in lower half of registers.
8350 // Interleave allocation of A and B to minimize wasted registers. Test the waters to find out
8351 // whether to try bank 0 or 1 first.
8352 int bases[2];
8353 for (int bank = 0; bank < 2; bank++) {
8354 auto r = state.ra.alloc_range(4, Bundle(bank, Bundle::any));
8355 bases[bank] = r.getBase();
8356 state.ra.safeRelease(r);
8357 }
8358
            // Order the banks: allocate first from the bank whose trial
            // allocation landed lower in the register file.
8360 int banks[2];
8361 banks[0] = (bases[1] < bases[0]) ? 1 : 0;
8362 banks[1] = 1 - banks[0];
8363
8364 // Allocate all the registers needed from bank 0, then all the registers needed from bank 1.
8365 for (int bank : banks) {
8366 if (getHint(hintA0, strategy).bank_id == bank) {
8367 for (int copy = 0; copy < A_copies; copy++)
8368 state.A_regs[copy] = state.ra.alloc_range(
8369 A_regCount, getHint(hintA0, strategy));
8370 if (state.broadcast && !globalCM)
8371 state.broadcast_regs = state.ra.alloc_range(
8372 2, getHint(hintA0, strategy));
8373 if (Ar_regCount > 0)
8374 state.Ar_regs = state.ra.alloc_range(
8375 Ar_regCount, getHint(hintA0, strategy));
8376 }
8377
8378 if (getHint(hintB0, strategy).bank_id == bank) {
8379 for (int copy = 0; copy < B_copies; copy++)
8380 state.B_regs[copy] = state.ra.alloc_range(
8381 B_regCount, getHint(hintB0, strategy));
8382 if (state.broadcast && globalCM)
8383 state.broadcast_regs = state.ra.alloc_range(
8384 2, getHint(hintB0, strategy));
8385 if (Br_regCount > 0)
8386 state.Br_regs = state.ra.alloc_range(
8387 Br_regCount, getHint(hintB0, strategy));
8388 }
8389 }
8390
8391 C_regs = state.ra.alloc_range(C_regCount - state.C_accCount,
8392 getHint(HintType::C, strategy));
8393 break;
8394 }
8395 case GEMMStrategy::ACB:
8396 if (state.broadcast && !globalCM)
8397 state.broadcast_regs
8398 = state.ra.alloc_range(2, getHint(hintA0, strategy));
8399
8400 for (int copy = 0; copy < A_copies; copy++)
8401 state.A_regs[copy] = state.ra.alloc_range(
8402 A_regCount, getHint(hintA0, strategy));
8403 if (Ar_regCount > 0)
8404 state.Ar_regs = state.ra.alloc_range(
8405 Ar_regCount, getHint(hintA0, strategy));
8406
8407 C_regs = state.ra.alloc_range(C_regCount - state.C_accCount,
8408 getHint(HintType::C, strategy));
8409
8410 for (int copy = 0; copy < B_copies; copy++)
8411 state.B_regs[copy] = state.ra.alloc_range(
8412 B_regCount, getHint(hintB0, strategy));
8413 if (Br_regCount > 0)
8414 state.Br_regs = state.ra.alloc_range(
8415 Br_regCount, getHint(hintB0, strategy));
8416
8417 if (state.broadcast && globalCM)
8418 state.broadcast_regs
8419 = state.ra.alloc_range(2, getHint(hintB0, strategy));
8420 break;
8421 case GEMMStrategy::BCA:
8422 if (state.broadcast && !globalCM)
8423 state.broadcast_regs
8424 = state.ra.alloc_range(2, getHint(hintA0, strategy));
8425
8426 for (int copy = 0; copy < B_copies; copy++)
8427 state.B_regs[copy] = state.ra.alloc_range(
8428 B_regCount, getHint(hintB0, strategy));
8429 if (Br_regCount > 0)
8430 state.Br_regs = state.ra.alloc_range(
8431 Br_regCount, getHint(hintB0, strategy));
8432
8433 C_regs = state.ra.alloc_range(C_regCount - state.C_accCount,
8434 getHint(HintType::C, strategy));
8435
8436 for (int copy = 0; copy < A_copies; copy++)
8437 state.A_regs[copy] = state.ra.alloc_range(
8438 A_regCount, getHint(hintA0, strategy));
8439 if (Ar_regCount > 0)
8440 state.Ar_regs = state.ra.alloc_range(
8441 Ar_regCount, getHint(hintA0, strategy));
8442
8443 if (state.broadcast && globalCM)
8444 state.broadcast_regs
8445 = state.ra.alloc_range(2, getHint(hintB0, strategy));
8446 break;
8447 case GEMMStrategy::VNC: {
8448 if (hw < HW::Gen12LP) stub();
8449
8450 // Gen12+. Assign non-broadcast input matrix (V), then broadcast input matrix (N), then C.
8451 auto unrollVBytes
8452 = strategy.unroll[globalCM ? LoopM : LoopN] * Tv.size();
8453 auto unrollNBytes
8454 = strategy.unroll[globalCM ? LoopN : LoopM] * Tn.size();
8455 auto regUnrollV = div_up(unrollVBytes, GRF::bytes(hw));
8456 auto regUnrollN = div_up(unrollNBytes, GRF::bytes(hw));
8457 auto hintV = getHint(HintType::A0, strategy);
8458 auto hintN = getHint(
8459 (regUnrollN == 1) ? HintType::A0 : HintType::A0Broadcast,
8460 strategy); // Put V and N in same bundle if we can avoid N<->C conflicts.
8461 auto hintC = getHint(HintType::C, strategy);
8462 GRFRange tempPadding;
8463
8464 for (int copy = 0; copy < V_copies; copy++)
8465 V_regs[copy] = state.ra.alloc_range(V_regCount, hintV);
8466 if (Vr_regCount > 0)
8467 Vr_regs = state.ra.alloc_range(Vr_regCount, hintV);
8468
8469 N_regs[0] = state.ra.alloc_range(N_regCount, hintN);
8470
8471 // Check if A * B outer product 0 has a bank conflict. If so, move N to avoid this.
8472 auto stride = Bundle(0, 0).stride(hw);
8473 auto offN = (N_regs[0][0].getBase() - V_regs[0][0].getBase())
8474 & (stride - 1);
8475 auto offNMin = offN - ((regUnrollV - 1) & ~1);
8476 auto offNMax = offN + regUnrollN - 1;
8477 if (offNMax >= stride) offNMax -= stride, offNMin -= stride;
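            // If the first outer product's N registers can share bundle
            // offsets with its V registers (offNMin <= 0), allocate throwaway
            // padding so the reallocated N starts past the conflicting window.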
8478 if (offNMin <= 0) {
8479 unsigned obAlign = Bundle(0, 0).group_size(hw);
8480 if (hintN.bank_id != Bundle::any) obAlign *= 2;
8481 offNMax = alignup_pow2(offNMax, obAlign);
8482 safeReleaseRanges(N_regs[0], state);
8483 tempPadding = state.ra.alloc_range(offNMax, hintN);
8484 N_regs[0] = state.ra.alloc_range(N_regCount, hintN);
8485 }
8486
8487 for (int copy = 1; copy < N_copies; copy++)
8488 N_regs[copy] = state.ra.alloc_range(N_regCount, hintN);
8489 if (Nr_regCount > 0)
8490 Nr_regs = state.ra.alloc_range(Nr_regCount, hintN);
8491
8492 state.ra.safeRelease(tempPadding);
8493
8494 C_regs = state.ra.alloc_range(C_regCount - state.C_accCount, hintC);
8495 break;
8496 }
8497 case GEMMStrategy::ABInterleave: {
8498 // Gen12+. Interleave A and B, place C afterward.
8499 if (hw < HW::Gen12LP) stub();
8500 auto chunk = Bundle(0, 0).stride(hw) >> 1;
8501
8502 // Test allocation. Put A earlier if it has more registers.
8503 int A_regTotal = A_regCount * A_copies + Ar_regCount;
8504 int B_regTotal = B_regCount * B_copies + Br_regCount;
8505 auto hintA = getHint(HintType::A0, strategy);
8506 auto hintB = getHint(HintType::B0, strategy);
8507 auto hintC = getHint(HintType::C, strategy);
8508 auto testA = state.ra.alloc_range(8, hintA);
8509 auto testB = state.ra.alloc_range(8, hintB);
8510 if ((testA.getBase() < testB.getBase())
8511 == (A_regTotal < B_regTotal))
8512 std::swap(hintA, hintB);
8513 state.ra.safeRelease(testA);
8514 state.ra.safeRelease(testB);
8515
8516 for (int copy = 0; copy < A_copies; copy++)
8517 state.A_regs[copy]
8518 = chunkAlloc(A_regCount, chunk, hintA, state);
8519 if (Ar_regCount > 0)
8520 state.Ar_regs = chunkAlloc(Ar_regCount, chunk, hintA, state);
8521 for (int copy = 0; copy < B_copies; copy++)
8522 state.B_regs[copy]
8523 = chunkAlloc(B_regCount, chunk, hintB, state);
8524 if (Br_regCount > 0)
8525 state.Br_regs = chunkAlloc(Br_regCount, chunk, hintB, state);
8526 C_regs = state.ra.alloc_range(C_regCount - state.C_accCount, hintC);
8527 break;
8528 }
8529 case GEMMStrategy::NSeparate: {
            // Broadcast matrix (N) gets dedicated bundle(s) spanning both banks;
            // V and C start in opposite banks within the remaining bundles.
8532 if (hw < HW::Gen12LP) stub();
8533 if (state.C_accCount > 0) stub();
8534
8535 int bundles = Bundle::bundle_count(hw) * Bundle::bank_count(hw);
8536 int bregsConsecutive = Bundle(0, 0).group_size(hw);
8537 int bregs = strategy.GRFs / bundles;
8538 int N_chunk = getMaxLoadBlock(N_layout);
8539 int N_nregs = Nr_regCount + N_regCount * N_copies;
8540 int N_nbundles = std::max(
8541 div_up(N_chunk, bregsConsecutive), div_up(N_nregs, bregs));
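            // N_nbundles covers both the largest single N load block and the
            // total register footprint of all N copies.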
8542 BundleGroup N_bundles(hw), VC_bundles(hw);
8543
8544 auto hintV0 = getHint(HintType::A0, strategy);
8545 auto hintV1 = getHint(HintType::A1, strategy);
8546 auto hintN = getHint(HintType::A0Broadcast, strategy);
8547 auto hintC0 = getHint(HintType::C, strategy);
8548 auto hintC1 = getHint(HintType::C1, strategy);
8549
8550 // Give bundles starting at the end to broadcast matrix.
8551 for (int bundle = Bundle::bundle_count(hw) - 1; bundle >= 0;
8552 bundle--) {
8553 for (int bank = Bundle::bank_count(hw) - 1; bank >= 0; bank--) {
8554 if (N_nbundles-- > 0)
8555 N_bundles |= Bundle(bank, bundle);
8556 else
8557 VC_bundles |= Bundle(bank, bundle);
8558 }
8559 }
8560
8561 for (int copy = 0; copy < V_copies; copy++)
8562 V_regs[copy] = splitAlloc(
8563 hw, Tv, V_layout, {hintV0, hintV1}, VC_bundles, state);
8564 if (Vr_regCount > 0)
8565 Vr_regs = splitAlloc(
8566 hw, Tv, Vr_layout, {hintV0, hintV1}, VC_bundles, state);
8567 if (!strategy.systolic)
8568 C_regs = trySplitAlloc(hw, Tc, C_layout, {hintC0, hintC1},
8569 VC_bundles, state, state.C_buffers);
8570 if (C_regs.empty())
8571 C_regs = chunkAlloc(
8572 C_regCount, C_chunk, hintC0, VC_bundles, state);
8573 for (int copy = 0; copy < N_copies; copy++)
8574 N_regs[copy] = splitAlloc(
8575 hw, Tn, N_layout, {hintN, hintN}, N_bundles, state);
8576 if (Nr_regCount > 0)
8577 Nr_regs = splitAlloc(
8578 hw, Tn, Nr_layout, {hintN, hintN}, N_bundles, state);
8579 break;
8580 }
8581 case GEMMStrategy::VAvoid: {
8582 // Broadcast matrix (N) has dedicated starting bank.
8583 // V and C share starting banks, but C allocations chosen to miss matching V allocations.
8584 auto hintV = getHint(HintType::A0, strategy);
8585 auto hintN = getHint(HintType::A0Broadcast, strategy);
8586 auto hintC = getHint(HintType::C, strategy);
8587
8588 for (int copy = 0; copy < N_copies; copy++)
8589 N_regs[copy] = state.ra.alloc_range(N_regCount, hintN);
8590 if (Nr_regCount > 0)
8591 Nr_regs = state.ra.alloc_range(Nr_regCount, hintN);
8592
8593 for (int copy = 0; copy < V_copies; copy++)
8594 V_regs[copy] = state.ra.alloc_range(V_regCount, hintV);
8595 if (Vr_regCount > 0)
8596 Vr_regs = state.ra.alloc_range(Vr_regCount, hintV);
8597
8598 int nv;
8599 const RegisterBlock *V_block;
8600 int V_rows, V_cols;
8601 getLayoutDims(
8602 Vr_regCount > 0 ? Vr_layout : V_layout, V_rows, V_cols);
8603 int kv = globalCM ? V_cols : V_rows;
8604
8605 int minOPCount = minOuterProductCount(hw, problem, strategy);
8606 int lastMN0 = -1;
8607 int sliceRegs = 0;
8608 BundleGroup V_bundles(hw);
8609
8610 vector<GRFMultirange> C_extra(state.C_buffers - 1);
8611 auto allocSlice = [&]() {
8612 if (sliceRegs <= 0) return;
8613 auto C_bundles = ~V_bundles;
8614
8615 C_regs.append(chunkAlloc(
8616 sliceRegs, C_chunk, hintC, C_bundles, state));
8617 for (int copy = 1; copy < state.C_buffers; copy++)
8618 C_extra[copy - 1].append(chunkAlloc(
8619 sliceRegs, C_chunk, hintC, C_bundles, state));
8620
8621 sliceRegs = 0;
8622 };
8623
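            // Walk C's blocks in layout order: for each new m/n offset, record
            // which bundles the corresponding V registers occupy, then let
            // allocSlice place that C slice in the complementary bundles.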
8624 for (const auto &block : C_layout) {
8625 int mn0 = globalCM ? block.offsetR : block.offsetC;
8626 if (mn0 == lastMN0) {
8627 sliceRegs += block.nregs();
8628 continue;
8629 }
8630
8631 allocSlice();
8632
8633 V_bundles = BundleGroup(hw);
8634 for (int h0 = 0; h0 < kv; h0 += minOPCount) {
8635 int r = globalCM ? mn0 : h0;
8636 int c = globalCM ? h0 : mn0;
8637 int comp = 0;
8638 if (Vr_regCount == 0)
8639 for (int copy = 0; copy < V_copies; copy++) {
8640 auto V0 = findBlockReg(Tv, V_layout, r, c,
8641 V_regs[copy], nv, V_block, 0, comp);
8642 V_bundles |= Bundle::locate(hw, V0);
8643 }
8644 else {
8645 auto V0 = findBlockReg(Tv, Vr_layout, r, c, Vr_regs, nv,
8646 V_block, 0, comp);
8647 V_bundles |= Bundle::locate(hw, V0);
8648 }
8649 }
8650
8651 lastMN0 = mn0;
8652 sliceRegs = block.nregs();
8653 }
8654
8655 allocSlice();
8656
8657 for (int copy = 1; copy < state.C_buffers; copy++)
8658 C_regs.append(C_extra[copy - 1]);
8659 }
8660 }
8661
8662 // Assign C_regs, adding in GRFs (in place of accumulators) to use later.
8663 state.C_regs.resize(state.C_buffers);
8664
8665 auto it = C_regs.ranges.begin();
8666 int off = -state.C_accCount;
8667 for (int buf = 0; buf < state.C_buffers; buf++) {
8668 for (int todo = C_regCountPerBuffer; todo > 0;) {
8669 if (it == C_regs.ranges.end())
8670 throw std::runtime_error("Not enough C registers allocated.");
8671 int left = it->getLen() - off;
8672 int take = std::min(left, todo);
8673 state.C_regs[buf].ranges.push_back(
8674 GRFRange(it->getBase() + off, take));
8675 todo -= take;
8676 off += take;
8677 if (off >= it->getLen()) off = 0, it++;
8678 }
8679 }
8680
8681 // Allocate registers for SLM copies.
8682 state.Ai_regs.resize(strategy.slmCopies);
8683 state.Bi_regs.resize(strategy.slmCopies);
8684 if (strategy.slmA)
8685 for (int q = 0; q < strategy.slmCopies; q++)
8686 state.Ai_regs[q] = state.ra.alloc_range(state.Ai_regCount);
8687 if (strategy.slmB)
8688 for (int q = 0; q < strategy.slmCopies; q++)
8689 state.Bi_regs[q] = state.ra.alloc_range(state.Bi_regCount);
8690
8691 // Allocate registers for A/B sums.
8692 state.As_regs = state.ra.alloc_range(getRegCount(state.As_layout));
8693 state.Bs_regs = state.ra.alloc_range(getRegCount(state.Bs_layout));
8694
8695 // Allocate registers for A/B prefetch.
8696 state.Ap_regs = state.ra.alloc_range(getRegCount(state.Ap_layout));
8697 state.Bp_regs = state.ra.alloc_range(getRegCount(state.Bp_layout));
8698
8699 // Allocate multiplication temporaries for Gen9 IGEMM, in pairs.
8700 if (isGen9IGEMM(hw, Ta, Tb, Tc)) {
8701 auto &temps = state.tempMul_regs;
8702 for (int ntemp = 0; ntemp < 2; ntemp++) {
8703 auto range = state.ra.try_alloc_range(2);
8704 if (range.isValid())
8705 temps.push_back(range);
8706 else if (temps.empty())
8707 throw out_of_registers_exception();
8708 else
8709 break;
8710 }
8711 }
8712}
8713
8714template <HW hw>
8715void gemm_kernel_generator_t<hw>::gemmAllocAoBoRegs(
8716 const GEMMStrategy &strategy, GEMMState &state) {
8717 bool allocAo = false, allocBo = false;
8718
8719 if (strategy.slmA && state.Ao_regs.empty() && !state.aioShare) {
8720 allocAo = true;
8721 if (strategy.slmRepackAhead == 0 && strategy.A_copies == 1) {
8722 auto nreg = getRegCount(state.Ao_layout);
8723 auto &defaultRegs = state.A_regs[0];
8724 allocAo = (defaultRegs.getLen() < nreg);
8725
8726 if (!allocAo) {
8727 state.Ao_regs = defaultRegs;
8728 state.aoReuseA = true;
8729 }
8730 }
8731 }
8732
8733 if (strategy.slmB && state.Bo_regs.empty() && !state.bioShare) {
8734 allocBo = true;
8735 if (strategy.slmRepackAhead == 0 && strategy.B_copies == 1) {
8736 auto nreg = getRegCount(state.Bo_layout);
8737 auto &defaultRegs = state.B_regs[0];
8738 allocBo = (defaultRegs.getLen() < nreg);
8739
8740 if (!allocBo) {
8741 state.Bo_regs = defaultRegs;
8742 state.boReuseB = true;
8743 }
8744 }
8745 }
8746
8747 if (allocAo && !state.allocedAo) {
8748 state.allocedAo = true;
8749 state.Ao_regs = state.ra.alloc_range(getRegCount(state.Ao_layout));
8750 }
8751
8752 if (allocBo && !state.allocedBo) {
8753 state.allocedBo = true;
8754 state.Bo_regs = state.ra.alloc_range(getRegCount(state.Bo_layout));
8755 }
8756}
8757
8758// Prepare layout for row/column sum matrices, and any needed auxiliary registers.
8759template <HW hw>
8760void gemm_kernel_generator_t<hw>::makeSumLayout(bool column, Type Tsrc,
8761 const vector<RegisterBlock> &srcLayout, Type Tdst,
8762 vector<RegisterBlock> &dstLayout, const CommonStrategy &strategy,
8763 CommonState &state) {
8764 bool canDP4A = (hw >= HW::Gen12LP) && one_of(Tsrc, Type::s8, Type::u8)
8765 && one_of(Tdst, Type::s32, Type::u32);
8766 bool cm = isLayoutColMajor(srcLayout);
8767 bool hReduce = (column == cm);
8768 bool needAll1s = false;
8769 int m, n, cp = 1;
8770
8771 getLayoutDims(srcLayout, m, n);
8772 auto &rdim = column ? m : n;
8773
8774 if (Tsrc.size() == Tdst.size()) cp = srcLayout[0].crosspack;
8775
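    // When summing along the contiguous dimension, dp4a folds groups of four
    // int8 elements as it accumulates, so the reduced dimension shrinks by 4x.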
8776 if (hReduce) {
8777 if (canDP4A && hasFullCrosspack(srcLayout, 1)) {
8778 rdim /= 4;
8779 needAll1s = true;
8780 if (rdim & 1) rdim <<= 1; // Ensure dp4a dest offset is even.
8781 }
8782 } else {
8783 if (canDP4A && hasFullCrosspack(srcLayout, 4)) needAll1s |= (rdim >= 4);
8784 rdim = 1;
8785 cp = 1;
8786 }
8787
8788 bool partials = canSwizzle(hw, Tdst);
8789 makeUnbackedRegLayout(Tdst, dstLayout, m, n, cm, cp, 0, 0, partials);
8790
8791 // Prepare all-1s immediate for dp4a.
8792 if (needAll1s && state.all1s.isInvalid()) {
8793 state.all1s = state.ra.alloc_sub(
8794 Tdst.ngen(), getHint(HintType::LongTerm, strategy));
8795 mov(1, state.all1s, 0x01010101);
8796 }
8797}
8798
8799// Accumulate row/column sums.
8800template <HW hw>
8801void gemm_kernel_generator_t<hw>::accumulateSum(bool column, Type Tsrc,
8802 const GRFMultirange &srcRegs, const vector<RegisterBlock> &srcLayout,
8803 Type Tdst, const GRFMultirange &dstRegs,
8804 const vector<RegisterBlock> &dstLayout, const CommonStrategy &strategy,
8805 CommonState &state, int q0, int q1) {
8806 bool canDP4A = (hw >= HW::Gen12LP) && one_of(Tsrc, Type::s8, Type::u8)
8807 && one_of(Tdst, Type::s32, Type::u32);
8808
8809 bool cm = isLayoutColMajor(srcLayout);
8810 if (cm != isLayoutColMajor(dstLayout)) stub();
8811
8812 int m, n;
8813 getLayoutDims(srcLayout, m, n);
8814
8815 // x: consecutive dimension in src; y: strided dimension in src
8816 auto nx = cm ? m : n;
8817 auto ny = cm ? n : m;
8818
8819 int x0 = 0, y0 = 0;
8820 int x1 = nx, y1 = ny;
8821
8822 if (q1 >= 0) ((column == cm) ? x1 : y1) = q1;
8823 if (q0 >= 0) ((column == cm) ? x0 : y0) = q0;
8824
8825 // Two cases to handle:
8826 // hReduce = false: Good case; no reduction. Sum is vector of size mx1 or 1xn.
8827 // hReduce = true: Bad case; needs reduction later, although with dp4a some reduction can be done now.
8828 bool hReduce = (column == cm);
8829
8830 int yinc = 1;
8831 int reduce = (canDP4A && hReduce) ? 4 : 1;
8832 if (x0 % reduce || x1 % reduce) stub();
8833
8834 GRFRange temp;
8835
8836 for (int y = y0; y < y1; y += yinc) {
8837 for (int x = x0; x < x1;) {
8838 int isrc, jsrc, idst, jdst, nsrc, ndst;
8839 const RegisterBlock *blockSrc, *blockDst;
8840
8841 isrc = cm ? x : y;
8842 jsrc = cm ? y : x;
8843 if (!hReduce) {
8844 idst = cm ? x : 0;
8845 jdst = cm ? 0 : x;
8846 } else {
8847 idst = cm ? x / reduce : y;
8848 jdst = cm ? y : x / reduce;
8849 }
8850
8851 Subregister srcBase = findBlockReg(
8852 Tsrc, srcLayout, isrc, jsrc, srcRegs, nsrc, blockSrc);
8853 Subregister dstBase = findBlockReg(
8854 Tdst, dstLayout, idst, jdst, dstRegs, ndst, blockDst);
8855 nsrc = std::min(nsrc, x1 - x);
8856 int neMax = elementsPerGRF(hw, Tdst) * 2;
8857 if (Tdst == Type::f32 && Tsrc.size() < 4) neMax /= 2;
8858 auto ne = std::min({nsrc / reduce, ndst, neMax});
8859
8860 auto src = srcBase(blockSrc->crosspack);
8861 auto dst = dstBase(blockDst->crosspack);
8862
8863 bool hsMatch = (src.getHS() * Tsrc == dst.getHS() * Tdst);
8864 if (Tsrc == Type::bf16 && Tdst == Type::f32)
8865 hsMatch = (src.getHS() == 1) && (dst.getHS() == 1);
8866
8867 if (!canSwizzle(hw, Tsrc) && ne > 1
8868 && (srcBase.getOffset() != dstBase.getOffset()
8869 || !hsMatch)) {
8870 if (temp.isInvalid()) temp = state.ra.alloc_range(2);
8871 auto srcI = src;
8872 int tmpHS
8873 = std::max<int>(1, (blockDst->crosspack * Tdst) / Tsrc);
8874 if (Tsrc == Type::bf16 && Tdst == Type::f32)
8875 tmpHS = blockDst->crosspack;
8876 auto tmpBase = temp[0].sub(
8877 dst.getByteOffset() / Tsrc.real(), src.getType());
8878 auto tmp = tmpBase(tmpHS);
8879 auto tmpI = tmp;
8880 moveToIntPipe(ne, srcI);
8881 moveToIntPipe(ne, tmpI);
8882 mov(ne, tmpI, srcI);
8883 src = tmp;
8884 srcBase = tmpBase;
8885 }
8886
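            // Gen12+: stage f16 source data through an f32 temporary before
            // accumulating into the f32 destination.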
8887 if (Tsrc == Type::f16 && Tdst == Type::f32 && hw >= HW::Gen12LP) {
8888 if (temp.isInvalid()) temp = state.ra.alloc_range(2);
8889 if (src.getHS() < 2) stub();
8890 auto tmpF = temp[0].sub(src.getByteOffset() / Type::f32,
8891 DataType::f)(src.getHS() / 2);
8892 mov(ne, tmpF, src);
8893 src = tmpF;
8894 }
8895
8896 if (canDP4A) {
8897 auto srcDP4A
8898 = Tsrc.isSigned() ? srcBase.d()(1) : srcBase.ud()(1);
8899 if (!hReduce && blockSrc->crosspack == 4) {
8900 yinc = std::min(y1 - y, 4);
8901 if (yinc == 4)
8902 dp4a(ne, dst, dst, srcDP4A, state.all1s);
8903 else if (yinc == 1)
8904 add(ne, dst, srcBase(4), dst);
8905 else
8906 dp4a(ne, dst, dst, srcDP4A,
8907 0x01010101 & ((1 << (yinc * 8)) - 1));
8908 } else if (hReduce && blockSrc->crosspack == 1) {
8909 if (Tsrc.isSigned())
8910 dp4a(ne, dst, dst, srcDP4A, state.all1s);
8911 else {
8912 // Workaround for suspected HW issue.
8913 dst.setType(DataType::ud);
8914 dp4a(ne, dst, dst, srcDP4A, state.all1s.ud());
8915 }
8916 }
8917 } else
8918 eadd(ne, dst, dst, src, strategy, state);
8919
8920 x += ne * reduce;
8921 }
8922 }
8923
8924 state.ra.safeRelease(temp);
8925}
8926
8927template <HW hw>
8928void gemm_kernel_generator_t<hw>::setupTeardownAccumulateSumSystolic(bool setup,
8929 Type T, const GEMMProblem &problem, const GEMMStrategy &strategy,
8930 GEMMState &state) {
8931 auto &sysSumAll1s = state.sysSumAll1s;
8932
8933 if (setup) {
8934 if (sysSumAll1s.isInvalid()) {
8935 sysSumAll1s = state.ra.alloc();
8936 sysSumAll1s.setType(T.ngen());
8937
8938 int ne = elementsPerGRF(hw, T);
8939 if (T == Type::s8 || T == Type::u8)
8940 mov(ne / 4, sysSumAll1s.ud(), uint32_t(0x01010101));
8941 else if (T == Type::bf16)
8942 mov(ne, sysSumAll1s.uw(), uint16_t(0x3F80));
8943 else
8944 mov(ne, sysSumAll1s.retype(T.arithmetic().ngen()),
8945 cast(T.arithmetic(), 1.0));
8946 }
8947 } else
8948 state.ra.safeRelease(sysSumAll1s);
8949}
8950
8951// Horizontally add intermediate sums if needed.
8952template <HW hw>
8953void gemm_kernel_generator_t<hw>::horizontalAdd(bool column, Type T,
8954 const GRFMultirange &regs, vector<RegisterBlock> &layout,
8955 CommonState &state) {
8956 bool cm = isLayoutColMajor(layout);
8957 if (cm != column) return; // Nothing to do.
8958
8959 int m, n, cp;
8960 getLayoutDims(layout, m, n);
8961 cp = layout[0].crosspack;
8962
8963 int nx = cm ? m : n;
8964 int ny = cm ? n : m;
8965 int ne = elementsPerGRF(hw, T);
8966 bool swizzleOK = canSwizzle(hw, T);
8967
8968 GRF tempGRF;
8969 if (!swizzleOK && nx > 1) tempGRF = state.ra.alloc();
8970
8971 int nsLimit = (2 * elementsPerGRF(hw, T)) / cp;
8972
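    // Tree reduction: each pass folds elements [chunk, 2*chunk) onto
    // [0, chunk), halving the active width until one column/row remains.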
8973 for (int chunk = roundup_pow2(nx) >> 1; chunk > 0; chunk >>= 1) {
8974 for (int y = 0; y < ny; y += cp) {
8975 for (int x = chunk; x < (chunk * 2) && x < nx;) {
8976 int i = cm ? x : y;
8977 int j = cm ? y : x;
8978 int ns, nb;
8979 const RegisterBlock *block;
8980 Subregister shifted
8981 = findBlockReg(T, layout, i, j, regs, ns, block);
8982
8983 ns = std::min({ns, chunk, nsLimit});
8984 (cm ? i : j) -= chunk;
8985 Subregister base
8986 = findBlockReg(T, layout, i, j, regs, nb, block);
8987
8988 auto dest = base;
8989 if (chunk == 1) dest = regs[y / ne].sub(y % ne, T.ngen());
8990
8991 int ne = ns * cp;
8992
8993 if (!swizzleOK && chunk * cp > 1
8994 && shifted.getOffset() != base.getOffset()) {
8995 auto temp = tempGRF.sub(base.getOffset(), T.ngen());
8996 auto tempI = temp;
8997 auto shiftedI = shifted;
8998 moveToIntPipe(tempI);
8999 moveToIntPipe(shiftedI);
9000 mov(ne, tempI(1), shiftedI(1));
9001 if (base == dest)
9002 add(ne, base(1), base(1), temp(1));
9003 else
9004 for (int q = 0; q < ne; q++) {
9005 add(1, dest, base, temp);
9006 dest.setOffset(dest.getOffset() + 1);
9007 base.setOffset(base.getOffset() + 1);
9008 temp.setOffset(temp.getOffset() + 1);
9009 }
9010 } else
9011 add(ne, dest(1), base(1), shifted(1));
9012
9013 x += ns;
9014 }
9015 }
9016 }
9017
9018 state.ra.safeRelease(tempGRF);
9019
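    // The reduced dimension has collapsed to a single element; rebuild the
    // layout to match the new shape.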
9020 (cm ? m : n) = 1;
9021 makeUnbackedRegLayout(T, layout, m, n, !cm, 1);
9022}
9023
9024// Get final A/B sums. For SLM copy kernels, this requires accumulating each thread's contributions.
9025template <HW hw>
9026bool gemm_kernel_generator_t<hw>::gemmFinalizeSums(const GEMMProblem &problem,
9027 const GEMMStrategy &strategy, GEMMState &state) {
9028 bool doA = problem.needsASums();
9029 bool doB = problem.needsBSums();
9030 bool doASLM = state.slmASums && (strategy.wg[LoopN] > 1);
9031 bool doBSLM = state.slmBSums && (strategy.wg[LoopM] > 1);
9032
9033 if (!doA && !doB) return true;
9034
9035 auto Tc = problem.Tc;
9036 auto unrollM = strategy.unroll[LoopM];
9037 auto unrollN = strategy.unroll[LoopN];
9038 bool ok = true;
9039
9040 int ms = 0, ns = 0;
9041 if (doA) getLayoutDims(state.As_layout, ms, ns);
9042 bool reduceAs = (ns > 1);
9043 if (doB) getLayoutDims(state.Bs_layout, ms, ns);
9044 bool reduceBs = (ms > 1);
9045
9046 if (reduceAs && doA && !doASLM)
9047 horizontalAdd(false, Tc, state.As_regs, state.As_layout, state);
9048 if (reduceBs && doB && !doBSLM)
9049 horizontalAdd(true, Tc, state.Bs_regs, state.Bs_layout, state);
9050
9051 if (!doASLM && !doBSLM) return true;
9052
9053 if (state.effCoopA == CoopSplit::Linear
9054 || state.effCoopB == CoopSplit::Linear)
9055 stub();
9056 bool A_coopSplitM = (state.effCoopA == CoopSplit::MN);
9057 bool B_coopSplitN = (state.effCoopB == CoopSplit::MN);
9058
9059 GRFMultirange *ABs_regs[2] = {&state.As_regs, &state.Bs_regs};
9060 bool AB_coopSplitMN[2] = {A_coopSplitM, B_coopSplitN};
9061 vector<RegisterBlock> *ABs_layout[2] = {&state.As_layout, &state.Bs_layout};
9062
9063 vector<RegisterBlock> ABs_layoutSLM[2];
9064 MatrixAddressing ABs_SLM[2];
9065 MatrixAddressingStrategy ABs_strategySLM[2];
9066 MatrixAddressingStrategy ABs_strategySLMAtomic[2];
9067 vector<GRFRange> ABs_addrs[2];
9068 GRF temp = state.ra.alloc();
9069 FlagRegister leader[2];
9070 Subregister ABs_base[2];
9071
9072 if (state.r0_info.isARF()) stub();
9073 GRF r0_info {state.r0_info.getBase()};
9074
9075 // Plan:
9076 // 1) First thread of each m/n-block (leader) stores its sums in SLM; barrier
9077 // 2) Remaining threads atomically add their sums to the first; barrier
9078 // 3) All threads read final sums
9079 // For scattered SLM write kernels, threads have accumulated disjoint parts
9080 // of the sums, so the second step isn't needed. However, each thread needs
9081 // to do a horizontal reduction first.
9082
9083 // Wait for previous SLM reads to complete.
9084 // In the meantime, finish sum reduction if necessary.
9085 status << "Finalize A/B sums" << status_stream::endl;
9086
9087 if (hw >= HW::Gen11) slmfence(temp, r0_info);
9088 MOCK_BARRIERS barriersignal(temp, r0_info);
9089
9090 if (doASLM && A_coopSplitM)
9091 horizontalAdd(false, Tc, state.As_regs, state.As_layout, state);
9092 if (doBSLM && B_coopSplitN)
9093 horizontalAdd(true, Tc, state.Bs_regs, state.Bs_layout, state);
9094
9095 MOCK_BARRIERS barrierwait();
9096
9097 auto step1 = [&](bool isB, int r, int c) {
9098 ABs_SLM[isB].setAlignment(r * c * Tc);
9099 ABs_SLM[isB].crosspack = 1;
9100 ABs_SLM[isB].layout = !isB ? MatrixLayout::Pc : MatrixLayout::Pr;
9101 ABs_SLM[isB].packSize = r * c;
9102 // Use pseudoblock to share address registers between regular and atomic accesses.
9103 ABs_strategySLMAtomic[isB].base = AddressBase::createSLM();
9104 ABs_strategySLMAtomic[isB].padded = true;
9105 ABs_strategySLMAtomic[isB].accessType = AB_coopSplitMN[isB]
9106 ? AccessType::Block
9107 : AccessType::PseudoBlock;
9108 ABs_strategySLMAtomic[isB].atomic = !AB_coopSplitMN[isB];
9109 ABs_strategySLMAtomic[isB].newDP = (hw >= HW::XeHPG);
9110 ABs_strategySLM[isB] = ABs_strategySLMAtomic[isB];
9111 ABs_strategySLM[isB].atomic = false;
9112
9113 ok = ok
9114 && getRegLayout(Tc, ABs_layoutSLM[isB], r, c, false, false,
9115 true, true, 0, 0, ABs_SLM[isB],
9116 ABs_strategySLMAtomic[isB])
9117 && matchLayouts(Tc, ABs_layoutSLM[isB], *ABs_layout[isB]);
9118
9119 Subregister adjBase = ABs_base[isB] = state.ra.alloc_sub<uint32_t>();
9120 uint32_t slmOffset
9121 = (isB && doASLM) ? (unrollM * strategy.wg[LoopM] * Tc) : 0;
9122
9123 !isB ? mulConstant(1, ABs_base[isB], state.lidM, unrollM * Tc)
9124 : mulConstant(1, ABs_base[isB], state.lidN, unrollN * Tc);
9125
9126 if (strategy.kParallelLocal) {
9127 slmOffset *= strategy.wg[LoopK];
9128 int perK = !isB ? strategy.wg[LoopM] * unrollM * Tc
9129 : strategy.wg[LoopN] * unrollN * Tc;
9130 emad(1, ABs_base[isB], ABs_base[isB], state.lidK, perK, strategy,
9131 state);
9132 }
9133
9134 if (slmOffset != 0) add(1, ABs_base[isB], ABs_base[isB], slmOffset);
9135
9136 if (AB_coopSplitMN[isB]) {
9137 adjBase = state.ra.alloc_sub<uint32_t>();
9138 !isB ? mulConstant(1, adjBase, state.lidN, state.ma_slm * Tc)
9139 : mulConstant(1, adjBase, state.lidM, state.nb_slm * Tc);
9140 add(1, adjBase, adjBase, ABs_base[isB]);
9141 }
9142
9143 allocAddrRegs(ABs_addrs[isB], ABs_layoutSLM[isB], ABs_SLM[isB],
9144 ABs_strategySLMAtomic[isB], state);
9145 setupAddr(Tc, ABs_addrs[isB], adjBase, ABs_layoutSLM[isB],
9146 Subregister(), ABs_SLM[isB], ABs_strategySLMAtomic[isB],
9147 strategy, state);
9148
9149 if (AB_coopSplitMN[isB]) state.ra.safeRelease(adjBase);
9150
9151 Label labelNoStore;
9152 if (!AB_coopSplitMN[isB]) {
9153 leader[isB] = state.raVFlag.alloc();
9154 cmp(16 | eq | leader[isB], !isB ? state.lidN : state.lidM, 0);
9155 if_(16 | leader[isB], labelNoStore);
9156 }
9157 storeMatrix(*ABs_regs[isB], ABs_layoutSLM[isB], ABs_SLM[isB],
9158 ABs_strategySLM[isB], ABs_addrs[isB], strategy, state);
9159 if (!AB_coopSplitMN[isB]) {
9160 mark(labelNoStore);
9161 endif(16);
9162 }
9163 };
9164
9165 bool barrier2 = false;
9166 auto step2 = [&](bool isB) {
9167 allocEAtomicAddRegs(hw, Tc, ABs_layoutSLM[isB], ABs_SLM[isB],
9168 ABs_strategySLMAtomic[isB], state, state.flagAP);
9169
9170 Label labelNoAdd;
9171 if_(16 | ~leader[isB], labelNoAdd);
9172 atomicAddMatrix(Tc, *ABs_regs[isB], ABs_layoutSLM[isB], ABs_SLM[isB],
9173 ABs_strategySLMAtomic[isB], ABs_addrs[isB], problem, strategy,
9174 state);
9175 mark(labelNoAdd);
9176 endif(16);
9177 barrier2 = true;
9178
9179 freeEAtomicAddRegs(state, state.flagAP);
9180 };
9181
9182 auto step3 = [&](bool isB, int r, int c) {
9183 if (AB_coopSplitMN[isB]) {
9184 safeReleaseRanges(ABs_addrs[isB], state);
9185 ABs_SLM[isB].packSize = r * c;
9186 ABs_SLM[isB].setAlignment(r * c * Tc);
9187 ABs_strategySLM[isB].accessType = AccessType::Block;
9188 ok = ok
9189 && getRegLayout(Tc, ABs_layoutSLM[isB], r, c, false, false,
9190 false, true, 0, 0, ABs_SLM[isB],
9191 ABs_strategySLM[isB]);
9192
9193 auto nregs = getRegCount(ABs_layoutSLM[isB]);
9194 if (nregs > ABs_regs[isB]->getLen()) {
9195 safeReleaseRanges(*ABs_regs[isB], state);
9196 *ABs_regs[isB] = state.ra.alloc_range(nregs);
9197 }
9198
9199 allocAddrRegs(ABs_addrs[isB], ABs_layoutSLM[isB], ABs_SLM[isB],
9200 ABs_strategySLM[isB], state);
9201 setupAddr(Tc, ABs_addrs[isB], ABs_base[isB], ABs_layoutSLM[isB],
9202 Subregister(), ABs_SLM[isB], ABs_strategySLM[isB], strategy,
9203 state);
9204 }
9205 loadMatrix(*ABs_regs[isB], ABs_layoutSLM[isB], ABs_SLM[isB],
9206 ABs_strategySLM[isB], ABs_addrs[isB], strategy, state);
9207 *ABs_layout[isB] = std::move(ABs_layoutSLM[isB]);
9208 };
9209
9210 if (doASLM) step1(false, state.ma_slm, 1);
9211 if (doBSLM) step1(true, 1, state.nb_slm);
9212
9213 MOCK_BARRIERS slmBarrier(temp, r0_info);
9214
9215 if (doASLM && !A_coopSplitM) step2(false);
9216 if (doBSLM && !B_coopSplitN) step2(true);
9217
9218 MOCK_BARRIERS if (barrier2) slmBarrier(temp, r0_info);
9219
9220 if (doASLM) step3(false, unrollM, 1);
9221 if (doBSLM) step3(true, 1, unrollN);
9222
9223 state.ra.safeRelease(temp);
9224 state.ra.safeRelease(ABs_base[0]);
9225 state.ra.safeRelease(ABs_base[1]);
9226 state.raVFlag.safeRelease(leader[0]);
9227 state.raVFlag.safeRelease(leader[1]);
9228 safeReleaseRanges(ABs_addrs[0], state);
9229 safeReleaseRanges(ABs_addrs[1], state);
9230
9231 return ok;
9232}
9233
// Convert a register range to a new type.
// If the types have different sizes, the smaller type is assumed to be strided
// so that each of its elements occupies the width of the larger type.
9237template <HW hw>
9238void gemm_kernel_generator_t<hw>::convert(const GRFMultirange &range, Type Told,
9239 Type Tnew, const GEMMProblem &problem, const GEMMStrategy &strategy,
9240 GEMMState &state) {
9241 if (Told == Tnew) return;
9242
9243 if (hw == HW::Gen9 && Told == Type::f32 && !Tnew.isFP()) {
9244 // Gen9: round to nearest before downconvert (not done by mov).
9245 map(hw, Told, range, range, strategy,
9246 [&](int esize, GRF r, GRF _) { rnde(esize, r.f(), r.f()); });
9247 }
9248
9249 int maxLS = std::max(Told.log2Size(), Tnew.log2Size());
9250 int hsOld = 1 << (maxLS - Told.log2Size());
9251 int hsNew = 1 << (maxLS - Tnew.log2Size());
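    // e.g. f16 -> f32: maxLS = 2, so the f16 view gets stride 2 and the f32
    // view stride 1, keeping the two views aligned element-for-element.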
9252 auto Tmax = (Told.size() < Tnew.size()) ? Tnew : Told;
9253
9254 InstructionModifier mod;
9255 if (Told != Tnew && Tnew.isInteger() && Tnew.size() <= Told.size())
9256 mod = mod | sat;
9257
9258 map(hw, Tmax, range, range, strategy, [&](int esize, GRF r, GRF _) {
9259 emov(esize | mod, r.sub(0, Tnew.ngen())(hsNew),
9260 r.sub(0, Told.ngen())(hsOld), strategy, state);
9261 });
9262}
9263
9264// Convert C accumulator registers to a new type. Returns true if successful, or false if old and new type are different sizes.
9265template <HW hw>
9266bool gemm_kernel_generator_t<hw>::gemmConvertC(Type Tnew,
9267 const GEMMProblem &problem, const GEMMStrategy &strategy,
9268 GEMMState &state) {
9269 auto Told = state.Tacc;
9270 int ncomp = (problem.Tc.isComplex() && state.C_buffers == 2
9271 && state.cSwapActive)
9272 ? 2
9273 : 1;
9274
9275 if (Tnew.size() != Told.size()) return false;
9276
9277 for (int comp = 0; comp < ncomp; comp++)
9278 convert(state.C_regs[comp], Told, Tnew, problem, strategy, state);
9279
9280 state.Tacc = Tnew;
9281
9282 return true;
9283}
9284
9285// Perform beta scaling.
9286template <HW hw>
9287void gemm_kernel_generator_t<hw>::gemmBetaScale(
9288 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
9289 Label labelBetaDone;
9290
9291 auto Ts = problem.Ts;
9292 auto &betar = problem.beta_real;
9293
9294 if (state.beta1.isValid()) {
9295 if (strategy.fused) {
9296 cmp(16 | lt | state.flagAP, null.d(), state.beta1, int16_t(0));
9297 goto12(16 | state.flagAP, labelBetaDone);
9298 } else {
9299 cmp(1 | lt | state.flagAP, null.d(), state.beta1, int16_t(0));
9300 jmpi(1 | state.flagAP, labelBetaDone);
9301 }
9302 }
9303
9304 gemmConvertC(problem.Ts, problem, strategy, state);
9305
9306 if (betar != 1) {
9307 map(hw, Ts.real(), state.C_regs[0], state.C_regs[0], strategy,
9308 [&](int esize, GRF acc, GRF _) {
9309 betar.fixed() ? mul(esize, acc, acc, cast(Ts.real(), betar))
9310 : mul(esize, acc, acc,
9311 betar.getRegAvoiding(hw, acc));
9312 });
9313 }
9314
9315 gemmConvertC(problem.Tc, problem, strategy, state);
9316
9317 mark(labelBetaDone);
9318
9319 if (state.beta1.isValid() && strategy.fused) join(16);
9320}
9321
9322template <HW hw>
9323void gemm_kernel_generator_t<hw>::binaryOp(BinaryOp op, int simd,
9324 const ngen::RegData &dst, const ngen::RegData &src0,
9325 const ngen::RegData &src1) {
9326 switch (op) {
9327 case BinaryOp::Add: add(simd, dst, src0, src1); break;
9328 case BinaryOp::Sub: add(simd, dst, src0, -src1); break;
9329 case BinaryOp::Mul: mul(simd, dst, src0, src1); break;
9330 case BinaryOp::Div: stub();
9331 case BinaryOp::Min: min_(simd, dst, src0, src1); break;
9332 case BinaryOp::Max: max_(simd, dst, src0, src1); break;
9333 }
9334}
9335
9336// Apply binary operation to C with a scalar operand.
9337template <HW hw>
9338void gemm_kernel_generator_t<hw>::gemmScalarBinaryOpC(BinaryOp op,
9339 const Subregister &offset, const GEMMProblem &problem,
9340 const GEMMStrategy &strategy, GEMMState &state) {
9341 auto offsetTc = offset.reinterpret(0, state.Tacc.ngen());
9342 if (offset != offsetTc) emov(1, offsetTc, offset, strategy, state);
9343
9344 map(hw, state.Tacc, state.C_regs[0], state.C_layout, strategy,
9345 [&](int simd, const RegData &r) {
9346 binaryOp(op, simd, r, r, offsetTc);
9347 });
9348}
9349
9350// Apply binary operation to C with a vector operand, optionally multiplied by a scalar.
9351template <HW hw>
9352void gemm_kernel_generator_t<hw>::gemmVectorBinaryOpC(BinaryOp op, bool column,
9353 const GRFMultirange &offsets, const Subregister &scale,
9354 const GEMMProblem &problem, const GEMMStrategy &strategy,
9355 GEMMState &state, Type Tco, vector<RegisterBlock> CO_layout, int y0,
9356 int y1) {
9357 auto Tacc = state.Tacc;
9358 auto ne = elementsPerGRF(hw, Tacc);
9359 auto globalCM = isLayoutColMajor(state.C_layout);
9360 auto unrollX = strategy.unroll[globalCM ? LoopM : LoopN];
9361 auto unrollY = strategy.unroll[globalCM ? LoopN : LoopM];
9362 auto crosspack = CO_layout.empty() ? 1 : CO_layout[0].crosspack;
9363 auto stride = [&]() { return (column == globalCM) ? 0 : crosspack; };
9364 const GRFMultirange *offsetsPtr = &offsets;
9365
9366 if (Tco == Type::invalid) Tco = Tacc;
9367
9368 bool needRepack = (Tacc != Tco);
9369 needRepack |= (stride() > 1 && hw >= HW::XeHP && Tacc.isFP());
9370
9371 GRFMultirange repackOffsets;
9372 if (needRepack) {
9373 // Repack data to unit stride as float pipe can't swizzle.
9374 vector<RegisterBlock> repackLayout;
9375 int r = column ? 1 : strategy.unroll[LoopM];
9376 int c = !column ? 1 : strategy.unroll[LoopN];
9377 makeUnbackedRegLayout(Tacc, repackLayout, r, c, !column);
9378 repackOffsets = state.ra.alloc_range(getRegCount(repackLayout));
9379 copyRegisters(Tco, Tacc, CO_layout, repackLayout, offsets,
9380 repackOffsets, 0, 0, false, strategy, state);
9381 crosspack = 1;
9382 offsetsPtr = &repackOffsets;
9383 }
9384
9385 if (y0 < 0) y0 = 0;
9386 if (y1 < 0) y1 = unrollY;
9387
9388 for (int y = y0; y < y1; y++) {
9389 for (int x = 0; x < unrollX;) {
9390 auto i = globalCM ? x : y;
9391 auto j = globalCM ? y : x;
9392 int nc;
9393 const RegisterBlock *C_block;
9394 Subregister C = findBlockReg(
9395 Tacc, state.C_layout, i, j, state.C_regs[0], nc, C_block);
9396
9397 nc = std::min({nc, strategy.fmaSIMD / crosspack, 2 * ne});
9398 auto nco = (column ? j : i) * crosspack;
9399 auto offBase = (*offsetsPtr)[nco / ne].sub(nco % ne, Tacc.ngen());
9400 if (scale.isValid()) {
9401 if (op != BinaryOp::Add) stub();
9402 mad(nc, C(1), C(1), offBase(stride()), scale);
9403 } else
9404 binaryOp(op, nc, C(1), C(1), offBase(stride()));
9405
9406 x += nc;
9407 }
9408 }
9409
9410 safeReleaseRanges(repackOffsets, state);
9411}
9412
9413// Apply binary operation to C.
9414template <HW hw>
9415bool gemm_kernel_generator_t<hw>::gemmBinaryOpC(BinaryOp op, bool row,
9416 bool column, Type Tco, MatrixAddressing CO,
9417 MatrixAddressingStrategy CO_strategy, Subregister base, Subregister ld,
9418 const GEMMProblem &problem, const GEMMStrategy &strategy,
9419 GEMMState &state) {
9420 std::vector<GRFRange> CO_addrs;
9421 std::vector<RegisterBlock> CO_layout;
9422 std::vector<MaskAssignment> masks;
9423 auto globalCM = isLayoutColMajor(state.C_layout);
9424
9425 bool recip = false;
9426 if (op == BinaryOp::Div) {
9427 // Implement div as inv+mul for speed, especially when broadcasting.
9428 recip = true;
9429 op = BinaryOp::Mul;
9430 if (!one_of(Tco, Type::f32, Type::f16)) stub();
9431 }
9432
9433 bool matrix = row && column;
9434 if (matrix) {
9435 // Matrix case implemented as loop over rows/columns, depending on C's layout.
9436 row &= globalCM;
9437 column &= !globalCM;
9438 CO_strategy.accessType = (isColMajor(CO.layout) == row)
9439 ? AccessType::Block
9440 : CO_strategy.base.isStateless() ? AccessType::Scattered
9441 : AccessType::ChannelScattered;
9442 } else {
9443 CO.layout = column ? MatrixLayout::T : MatrixLayout::N;
9444 CO_strategy.accessType = AccessType::Block;
9445 }
9446
9447 bool coColMajor = isColMajor(CO.layout);
9448
9449 auto cor = row ? strategy.unroll[LoopM] : 1;
9450 auto coc = column ? strategy.unroll[LoopN] : 1;
9451 auto remR = row && !CO_strategy.padded;
9452 auto remC = column && !CO_strategy.padded;
9453
9454 if (!getRegLayout(Tco, CO_layout, cor, coc, remR, remC, false, true, 0, 0,
9455 CO, CO_strategy))
9456 return false;
9457
9458 auto CO_regs = state.ra.alloc_range(getRegCount(CO_layout));
9459
9460 allocAddrRegs(CO_addrs, CO_layout, CO, CO_strategy, state);
9461 setupAddr(Tco, CO_addrs, base, CO_layout, ld, CO, CO_strategy, strategy,
9462 state);
9463
9464 if (!assignMasks(CO_layout, LoopM, LoopN, masks, strategy, state, true))
9465 return false;
9466
9467 loadMasks(masks, state.remainders, strategy, state);
9468
9469 if (matrix) {
9470 auto LoopY = globalCM ? LoopN : LoopM;
9471 auto unrollY = strategy.unroll[LoopY];
9472 auto remY = state.remainders[LoopY];
9473 Label lDone;
9474 bool simtCF = strategy.fused && (strategy.fusedLoop == LoopY);
9475 int simt = simtCF ? 16 : 1;
9476
9477 if (!CO_strategy.padded) cmp(simt | gt | state.flagAP, remY, 0);
9478
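        // Apply CO one row/column per iteration, exiting the loop early once
        // the y remainder is exhausted (unless CO is fully padded).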
9479 for (int y = 0; y < unrollY; y++) {
9480 if (!CO_strategy.padded) {
9481 simtCF ? goto12(16 | ~state.flagAP, lDone)
9482 : jmpi(1 | ~state.flagAP, lDone);
9483 }
9484 loadMatrix(CO_regs, CO_layout, CO, CO_strategy, CO_addrs, strategy,
9485 state);
9486 if (recip)
9487 map(hw, Tco, CO_regs, CO_regs, strategy,
9488 [&](int simd, GRF r, GRF) { inv(simd, r, r); });
9489 if (!CO_strategy.padded && (y + 1 < unrollY))
9490 cmp(simt | gt | state.flagAP, remY, y + 1);
9491 if (coColMajor == globalCM)
9492 incAddr(CO_addrs, ld, int(row), int(column), CO_layout, CO,
9493 CO_strategy, strategy, state);
9494 else
9495 incAddr(CO_addrs, Tco.size(), int(row), int(column), CO_layout,
9496 CO, CO_strategy, strategy, state);
9497
9498 gemmVectorBinaryOpC(op, column, CO_regs, Subregister(), problem,
9499 strategy, state, Tco, CO_layout, y, y + 1);
9500 }
9501
9502 mark(lDone);
9503 if (simtCF) join(16);
9504 } else {
9505 loadMatrix(
9506 CO_regs, CO_layout, CO, CO_strategy, CO_addrs, strategy, state);
9507 if (recip)
9508 map(hw, Tco, CO_regs, CO_regs, strategy,
9509 [&](int simd, GRF r, GRF) { inv(simd, r, r); });
9510
9511 if (!row && !column)
9512 gemmScalarBinaryOpC(op, CO_regs[0].sub(0, Tco.ngen()), problem,
9513 strategy, state);
9514 else
9515 gemmVectorBinaryOpC(op, column, CO_regs, Subregister(), problem,
9516 strategy, state, Tco, CO_layout);
9517 }
9518
9519 safeReleaseMaskAssignments(masks, state);
9520 state.ra.safeRelease(CO_regs);
9521 safeReleaseRanges(CO_addrs, state);
9522
9523 return true;
9524}
9525
9526// Check kernel input for desired C offset and apply it.
9527template <HW hw>
9528bool gemm_kernel_generator_t<hw>::gemmApplyCOffsetDispatch(
9529 const GEMMProblem &problem, const GEMMStrategy &strategy,
9530 GEMMState &state) {
9531 Label labelCOColumn, labelCORow, labelCOMatrix, labelCODone;
9532 bool doMatrix = problem.allowMatrixOffset();
9533 auto Tco = problem.Tco;
9534 auto &CO = problem.CO;
9535 auto &CO_strategy = strategy.CO;
9536 auto &effCO = state.effCO;
9537 auto &ldco = state.inputs.ldco;
9538
9539 bool ok = true;
9540
9541 if (state.flagSwizzle.isValid()) state.raVFlag.release(state.flagSwizzle);
9542
9543 auto flagNonfinal = state.raVFlag.alloc();
9544 auto flagCOC = state.raVFlag.alloc();
9545 auto flagCOR = state.raVFlag.alloc();
9546
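    // Decode the C offset mode from the kernel flags: non-final k blocks skip
    // the offset entirely; column, row, and (if allowed) full-matrix offsets
    // jump to their handlers below, while a fixed offset falls through.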
9547 and_(1 | nz | flagNonfinal, null.ud(), state.inputs.flags,
9548 FlagNonfinalKBlock);
9549 and_(1 | nz | flagCOC, null.ud(), state.inputs.flags, FlagCOColumn);
9550 and_(1 | nz | flagCOR, null.ud(), state.inputs.flags, FlagCORow);
9551 jmpi(1 | flagNonfinal, labelCODone);
9552 jmpi(1 | flagCOC, labelCOColumn);
9553 jmpi(1 | flagCOR, labelCORow);
9554
9555 state.raVFlag.safeRelease(flagNonfinal);
9556 state.raVFlag.safeRelease(flagCOC);
9557 state.raVFlag.safeRelease(flagCOR);
9558
9559 if (state.flagSwizzle.isValid()) state.raVFlag.claim(state.flagSwizzle);
9560
9561 status << "Applying fixed C offset" << status_stream::endl;
9562 ok = ok
9563 && gemmBinaryOpC(BinaryOp::Add, false, false, Tco, CO, CO_strategy,
9564 effCO, ldco, problem, strategy, state);
9565 jmpi(1, labelCODone);
9566
9567 mark(labelCOColumn);
9568 if (doMatrix) jmpi(1 | flagCOR, labelCOMatrix);
9569 status << "Applying column-wise C offset" << status_stream::endl;
9570 ok = ok
9571 && gemmBinaryOpC(BinaryOp::Add, false, true, Tco, CO, CO_strategy,
9572 effCO, ldco, problem, strategy, state);
9573 jmpi(1, labelCODone);
9574
9575 mark(labelCORow);
9576 status << "Applying row-wise C offset" << status_stream::endl;
9577 ok = ok
9578 && gemmBinaryOpC(BinaryOp::Add, true, false, Tco, CO, CO_strategy,
9579 effCO, ldco, problem, strategy, state);
9580
9581 if (doMatrix) {
9582 jmpi(1, labelCODone);
9583
9584 mark(labelCOMatrix);
9585 status << "Applying matrix C offset" << status_stream::endl;
9586 ok = ok
9587 && gemmBinaryOpC(BinaryOp::Add, true, true, Tco, CO,
9588 CO_strategy, effCO, ldco, problem, strategy, state);
9589 }
9590
9591 mark(labelCODone);
9592
9593 if (!strategy.persistent) {
9594 state.ra.safeRelease(ldco);
9595 state.ra.safeRelease(effCO);
9596 }
9597
9598 return ok;
9599}
9600
9601static inline BinaryOp dnnlToBinaryOp(alg_kind_t kind) {
9602 switch (kind) {
9603 case alg_kind::binary_add: return BinaryOp::Add;
9604 case alg_kind::binary_sub: return BinaryOp::Sub;
9605 case alg_kind::binary_mul: return BinaryOp::Mul;
9606 case alg_kind::binary_div: return BinaryOp::Div;
9607 case alg_kind::binary_min: return BinaryOp::Min;
9608 case alg_kind::binary_max: return BinaryOp::Max;
9609 default: stub();
9610 }
9611}
9612
9613template <HW hw>
9614void gemm_kernel_generator_t<hw>::gemmLoadBinaryOpArgs(
9615 const GEMMProblem &problem, const GEMMStrategy &strategy,
9616 GEMMState &state) {
9617 if (hw < HW::XeHP) stub();
9618
9619 std::vector<ngen::Subregister *> argList;
9620 argList.reserve(state.inputs.binaryOffsets.size() * 5);
9621
9622 for (auto &r : state.inputs.binarySrcs)
9623 if (r.isValid()) argList.push_back(&r);
9624 for (auto &r : state.inputs.binaryOffsets)
9625 if (r.isValid()) argList.push_back(&r);
9626 for (auto &r : state.inputs.binaryLDs)
9627 if (r.isValid()) argList.push_back(&r);
9628
9629 if (problem.batch == BatchMode::Strided) {
9630 for (auto &rs : state.inputs.binaryStrides)
9631 for (auto &r : rs)
9632 if (r.isValid()) argList.push_back(&r);
9633 }
9634
9635 int loadBase = interface.getArgLoadBase().getBase();
9636 int nGRFs = 0;
9637 for (auto arg : argList) {
9638 int base = arg->getBase();
9639 if (base < loadBase) stub();
9640 nGRFs = std::max(nGRFs, base - loadBase + 1);
9641 }
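    // nGRFs now spans from the argument load base up to the furthest
    // binary-related argument register.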
9642
9643 auto temp = state.ra.alloc();
9644 auto args = state.ra.alloc_range(nGRFs);
9645
9646 if (state.r0_info.isARF() || state.r0_info.getBase() != 0) stub();
9647 loadargs(args, nGRFs, temp);
9648
9649 int grfOffset = args.getBase() - loadBase;
9650
9651 state.ra.release(args);
9652
9653 for (auto arg : argList) {
9654 arg->setBase(arg->getBase() + grfOffset);
9655 state.ra.claim(*arg);
9656 }
9657
9658 state.ra.safeRelease(temp);
9659}
9660
9661template <HW hw>
9662void gemm_kernel_generator_t<hw>::gemmApplyPostOps(int poMin, int poMax,
9663 const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
9664 if (poMin >= poMax) return;
9665
9666 Label lSkip;
9667 and_(1 | nz | state.flagAP, null.ud(), state.inputs.flags,
9668 FlagNonfinalKBlock);
9669
9670 (void)gemmConvertC(problem.Ts, problem, strategy, state);
9671
9672 jmpi(1 | state.flagAP, lSkip);
9673
9674 // Binary preparations: load binary-related args + calculate starting addresses
9675 if (problem.hasBinaryPostOp() && state.effBinary.empty()) {
9676 int poCount = problem.postOps.len();
9677 auto &entries = problem.postOps.entry_;
9678
9679 gemmLoadBinaryOpArgs(problem, strategy, state);
9680
9681#define FOR_EACH_BINARY \
9682 for (int i = 0; i < poCount; i++) \
9683 if (entries[i].is_binary())
9684
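        // Scale binary post-op leading dimensions, offsets, and batch strides
        // from element counts to bytes.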
9685 FOR_EACH_BINARY {
9686 const auto &ld = state.inputs.binaryLDs[i];
9687 auto T = problem.Tbinary[i];
9688 if (ld.isValid())
9689 emulConstant(1, ld, ld, T.size(), strategy, state);
9690 emulConstant(1, state.inputs.binaryOffsets[i],
9691 state.inputs.binaryOffsets[i], T.size(), strategy, state);
9692 if (problem.batch == BatchMode::Strided)
9693 for (int b = 0; b < problem.batchDims; b++) {
9694 const auto &stride = state.inputs.binaryStrides[i][b];
9695 if (stride.isValid())
9696 emulConstant(
9697 1, stride, stride, T.size(), strategy, state);
9698 }
9699 }
9700
9701 for (int b = 0; b < 2; b++) {
9702 FOR_EACH_BINARY {
9703 const auto &stride = state.inputs.binaryStrides[i][b];
9704 if (stride.isValid())
9705 emul(1, stride, stride, state.batchID[b], strategy, state);
9706 }
9707 }
9708
9709 for (int b = 0; b < 2; b++) {
9710 FOR_EACH_BINARY {
9711 auto &offsetStride = state.inputs.binaryStrides[i][b];
9712 if (offsetStride.isValid())
9713 eadd(1, state.inputs.binaryOffsets[i],
9714 state.inputs.binaryOffsets[i], offsetStride,
9715 strategy, state);
9716 state.ra.safeRelease(offsetStride);
9717 }
9718 }
9719
9720 gemmOffsetABC(true, state.i0, state.j0, state.h0, problem, strategy,
9721 state, false, false, false, true);
9722
9723 state.effBinary.resize(poCount);
9724
9725 FOR_EACH_BINARY {
9726 if (strategy.binary[i].base.isStateless()) {
9727 state.effBinary[i] = state.inputs.binarySrcs[i];
9728 eadd(1, state.effBinary[i], state.inputs.binarySrcs[i],
9729 state.inputs.binaryOffsets[i], strategy, state);
9730 state.ra.safeRelease(state.inputs.binaryOffsets[i]);
9731 } else
9732 state.effBinary[i] = state.inputs.binaryOffsets[i];
9733 }
9734#undef FOR_EACH_BINARY
9735 }
9736
9737 // Apply post-ops to all of C.
9738 for (int i = poMin; i < poMax; i++) {
9739 auto &entry = problem.postOps.entry_[i];
9740 switch (entry.kind) {
9741 case primitive_kind::eltwise: {
9742 using Injector = jit_eltwise_injector_f32<hw>;
9743 if (state.Tacc != Type::f32) stub();
9744
9745 int euCount = 0; /* only used for a DG2 W/A for conv */
9746 auto &ee = entry.eltwise;
9747 Injector injector {this, ee.alg, ee.alpha, ee.beta, ee.scale,
9748 euCount, GRFRange(), problem.postOpFwd};
9749
9750 auto scratch = state.ra.try_alloc_range(
9751 injector.preferred_scratch_regs());
9752 if (scratch.isInvalid())
9753 scratch = state.ra.alloc_range(injector.min_scratch_regs());
9754
9755 injector.set_scratch(scratch);
9756 injector.prepare();
9757 for (auto &rr : state.C_regs[0].ranges)
9758 injector.compute(rr);
9759 break;
9760 }
9761 case primitive_kind::binary: {
9762 auto &ld = state.inputs.binaryLDs[i];
9763 auto &eff = state.effBinary[i];
9764 auto op = dnnlToBinaryOp(entry.binary.alg);
9765
9766 gemmBinaryOpC(op, problem.binaryRow[i], problem.binaryCol[i],
9767 problem.Tbinary[i], problem.binary[i],
9768 strategy.binary[i], eff, ld, problem, strategy, state);
9769
9770 state.ra.safeRelease(ld);
9771 state.ra.safeRelease(eff);
9772 break;
9773 }
9774 default: stub();
9775 }
9776 }
9777
9778 mark(lSkip);
9779}
9780
9781// Calculate addresses of A/B sums in packed input data. Sums are stored at the end of each panel.
9782template <HW hw>
9783void gemm_kernel_generator_t<hw>::gemmCalcABOffsetAddrs(
9784 const GEMMProblem &problem, const GEMMStrategy &strategy,
9785 GEMMState &state) {
9786 auto &effAs = state.effAs;
9787 auto &effBs = state.effBs;
9788
9789 auto Tc = problem.Tc;
9790 auto unrollM = strategy.unroll[LoopM];
9791 auto unrollN = strategy.unroll[LoopN];
9792
9793 if (effAs.isInvalid()) effAs = state.ra.alloc_sub(state.effA.getType());
9794 if (effBs.isInvalid()) effBs = state.ra.alloc_sub(state.effB.getType());
9795
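    // effAs = effA + lda*unrollM - unrollM*Tc (likewise for B): the sums
    // occupy the trailing unrollM*Tc (unrollN*Tc) bytes of each packed panel.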
9796 mulConstant(1, effAs.ud(), state.inputs.lda, unrollM);
9797 mulConstant(1, effBs.ud(), state.inputs.ldb, unrollN);
9798 add(1, effAs.ud(), effAs.ud(), -unrollM * Tc);
9799 add(1, effBs.ud(), effBs.ud(), -unrollN * Tc);
9800 eadd(1, effAs, effAs.ud(), state.effA, strategy, state);
9801 eadd(1, effBs, effBs.ud(), state.effB, strategy, state);
9802}
9803
9804// Load A/B sums from packed input data.
9805template <HW hw>
9806bool gemm_kernel_generator_t<hw>::gemmLoadABOffset(const GEMMProblem &problem,
9807 const GEMMStrategy &strategy, GEMMState &state) {
9808 if (problem.abOffset != ABOffset::Load) return true;
9809
9810 auto Tc = problem.Tc;
9811 auto unrollM = strategy.unroll[LoopM];
9812 auto unrollN = strategy.unroll[LoopN];
9813
9814 MatrixAddressing As = problem.A, Bs = problem.B;
9815 As.crosspack = 1;
9816 Bs.crosspack = 1;
9817 As.tileR = As.tileC = 0;
9818 Bs.tileR = Bs.tileC = 0;
9819
9820 MatrixAddressingStrategy As_strategy = strategy.A, Bs_strategy = strategy.B;
9821 As_strategy.accessType = AccessType::Block;
9822 Bs_strategy.accessType = AccessType::Block;
9823 As_strategy.tileR = As_strategy.tileC = 0;
9824 Bs_strategy.tileR = Bs_strategy.tileC = 0;
9825 As_strategy.dpasw = Bs_strategy.dpasw = false;
9826
9827 bool ok = true;
9828 ok = ok
9829 && getRegLayout(Tc, state.As_layout, unrollM, 1, false, false,
9830 false, true, 0, 0, As, As_strategy);
9831 ok = ok
9832 && getRegLayout(Tc, state.Bs_layout, 1, unrollN, false, false,
9833 false, true, 0, 0, Bs, Bs_strategy);
9834 if (!ok) return false;
9835
9836 state.As_regs = state.ra.alloc_range(getRegCount(state.As_layout));
9837 state.Bs_regs = state.ra.alloc_range(getRegCount(state.Bs_layout));
9838
9839 vector<GRFRange> As_addrs, Bs_addrs;
9840 allocAddrRegs(As_addrs, state.As_layout, As, As_strategy, state);
9841 allocAddrRegs(Bs_addrs, state.Bs_layout, Bs, Bs_strategy, state);
9842
9843 if (state.effAs.isInvalid())
9844 gemmCalcABOffsetAddrs(problem, strategy, state);
9845
9846 setupAddr(Tc, As_addrs, state.effAs, state.As_layout, Subregister(), As,
9847 As_strategy, strategy, state);
9848 setupAddr(Tc, Bs_addrs, state.effBs, state.Bs_layout, Subregister(), Bs,
9849 Bs_strategy, strategy, state);
9850
9851 loadMatrix(state.As_regs, state.As_layout, As, As_strategy, As_addrs,
9852 strategy, state);
9853 loadMatrix(state.Bs_regs, state.Bs_layout, Bs, Bs_strategy, Bs_addrs,
9854 strategy, state);
9855
9856 state.ra.safeRelease(state.effAs);
9857 state.ra.safeRelease(state.effBs);
9858 safeReleaseRanges(As_addrs, state);
9859 safeReleaseRanges(Bs_addrs, state);
9860
9861 return true;
9862}
9863
9864// Apply contributions from A/B offsets to C matrix, using previously loaded/computed
9865// A row sums and B column sums.
9866template <HW hw>
9867void gemm_kernel_generator_t<hw>::gemmApplyABOffset(const GEMMProblem &problem,
9868 const GEMMStrategy &strategy, GEMMState &state) {
9869 if (problem.abOffset == ABOffset::None) return;
9870
9871 // Two steps: (O = all-1s matrix)
9872 // 1) C += A * O * bo
9873 // 2) C += (O * B + bo * k) * ao
9874 // TODO: combine C adds into add3 on XeHP+.
9875 auto temp = state.ra.alloc_sub(problem.Tc.ngen());
9876 mul(1, temp, state.k, state.inputs.bo);
9877
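    // Gen9 path: apply the bo*k term and the ao/bo scales with separate
    // add/mul instructions; newer hardware folds them into mad here and in
    // gemmVectorBinaryOpC below.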
9878 bool noFMA = (hw == HW::Gen9);
9879 if (noFMA) {
9880 map(hw, problem.Tc, state.Bs_regs, state.Bs_layout, strategy,
9881 [&](int ne, RegData r) { add(ne, r, r, temp); });
9882 map(hw, problem.Tc, state.As_regs, state.As_layout, strategy,
9883 [&](int ne, RegData r) { mul(ne, r, r, state.inputs.bo); });
9884 map(hw, problem.Tc, state.Bs_regs, state.Bs_layout, strategy,
9885 [&](int ne, RegData r) { mul(ne, r, r, state.inputs.ao); });
9886 } else {
9887 mul(1, temp, temp, state.inputs.ao);
9888 map(hw, problem.Tc, state.Bs_regs, state.Bs_layout, strategy,
9889 [&](int ne, RegData r) {
9890 mad(ne, r, temp, r, state.inputs.ao);
9891 });
9892 }
9893 state.ra.safeRelease(temp);
9894
9895 gemmVectorBinaryOpC(BinaryOp::Add, false, state.As_regs,
9896 noFMA ? Subregister() : state.inputs.bo, problem, strategy, state,
9897 problem.Tc, state.As_layout);
9898 gemmVectorBinaryOpC(BinaryOp::Add, true, state.Bs_regs, Subregister(),
9899 problem, strategy, state, problem.Tc, state.Bs_layout);
9900
9901 safeReleaseRanges(state.As_regs, state);
9902 safeReleaseRanges(state.Bs_regs, state);
9903 if (!strategy.persistent) {
9904 state.ra.safeRelease(state.inputs.ao);
9905 state.ra.safeRelease(state.inputs.bo);
9906 }
9907 state.As_layout.clear();
9908 state.Bs_layout.clear();
9909}
9910
9911// Store A/B sum data into CO.
9912template <HW hw>
9913void gemm_kernel_generator_t<hw>::gemmUpdateSums(const GEMMProblem &problem,
9914 const GEMMStrategy &strategy, GEMMState &state) {
9915 bool sumA = problem.sumA, sumB = problem.sumB;
9916
9917 if (!sumA && !sumB) return;
9918 if (sumA && sumB) stub(); // can only store one of the two in CO.
9919
9920 auto Tc = problem.Tc;
9921 auto Tco = problem.Tco;
9922 auto cor = sumA ? strategy.unroll[LoopM] : 1;
9923 auto coc = sumB ? strategy.unroll[LoopN] : 1;
9924 bool atomic = strategy.CO.atomic;
9925 bool checkBeta0 = problem.checkBeta0 && !problem.beta_real.fixed();
9926 bool checkBeta1 = atomic && !problem.beta1();
9927
9928 auto CO = problem.CO;
9929 auto CO_strategy = strategy.CO;
9930 std::vector<GRFRange> CO_addrs;
9931 std::vector<RegisterBlock> CO_layout;
9932 std::vector<MaskAssignment> masks;
9933 GRFMultirange CO_regs;
9934 CO_strategy.accessType = AccessType::Block;
9935 FlagRegister flagBeta0;
9936
9937 auto &Xs_regs = sumA ? state.As_regs : state.Bs_regs;
9938 auto &Xs_layout = sumA ? state.As_layout : state.Bs_layout;
9939
9940 int Xs_nregs = getRegCount(Xs_layout);
9941 auto Xs_usedRegs = Xs_regs.subrange(0, Xs_nregs);
9942
9943 CO.layout = sumA ? MatrixLayout::N : MatrixLayout::T;
9944
9945 auto remR = sumA && !strategy.CO.padded;
9946 auto remC = sumB && !strategy.CO.padded;
9947
9948 if (!getRegLayout(Tco, CO_layout, cor, coc, remR, remC, true, true, 0, 0,
9949 CO, CO_strategy))
9950 stub();
9951
9952 bool share = (Tc == Tco) && matchLayouts(Tc, CO_layout, Xs_layout);
9953
9954 Label skipStore;
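    // Only store sums if the FlagStoreSums bit is set in the flags argument.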
9955 and_(16 | ne | state.flagAP, null.ud(), state.inputs.flags, FlagStoreSums);
9956 if_(16 | state.flagAP, skipStore);
9957
9958 if (checkBeta0) {
9959 flagBeta0 = state.raVFlag.alloc();
9960 cmp0(1 | eq | flagBeta0, problem.beta_real.getReg(0));
9961 }
9962 if (checkBeta1)
9963 cmp(1 | eq | state.flagAP, problem.beta_real.getReg(0),
9964 cast(problem.Ts, 1.0));
9965
9966 allocAddrRegs(CO_addrs, CO_layout, CO, CO_strategy, state);
9967 setupAddr(Tco, CO_addrs, state.effCO, CO_layout, Subregister(), CO,
9968 CO_strategy, strategy, state);
9969
9970 if (!assignMasks(CO_layout, LoopM, LoopN, masks, strategy, state, true))
9971 stub();
9972
9973 loadMasks(masks, state.remainders, strategy, state);
9974
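    // Scale the sums by alpha (when alpha != 1).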
9975 if (!problem.alpha1())
9976 map(hw, Tc, Xs_usedRegs, Xs_usedRegs, strategy,
9977 [&](int esize, GRF acc, GRF _) {
9978 auto &alphar = problem.alpha_real;
9979 alphar.fixed()
9980 ? mul(esize, acc, acc, cast(Tc.real(), alphar))
9981 : mul(esize, acc, acc,
9982 alphar.getRegAvoiding(hw, acc));
9983 });
9984
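    // If beta may be nonzero (and this is not a pure atomic beta = 1 update),
    // load the existing CO values, convert to Tc if needed, and accumulate
    // beta * CO into the sums. Skipped at run time if beta turns out to be 0,
    // or 1 in the atomic case.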
9985 if (!problem.beta0() && !(problem.beta1() && atomic)) {
9986 Label lSkipUpdate;
9987 auto CO_regsLoad = state.ra.alloc_range(getRegCount(CO_layout));
9988 auto CO_regsLoadConv = CO_regsLoad;
9989
9990 if (checkBeta0) jmpi(1 | flagBeta0, lSkipUpdate);
9991 if (checkBeta1) jmpi(1 | state.flagAP, lSkipUpdate);
9992
9993 loadMatrix(CO_regsLoad, CO_layout, CO, CO_strategy, CO_addrs, strategy,
9994 state);
9995
9996 if (!share) {
9997 CO_regsLoadConv = state.ra.alloc_range(Xs_nregs);
9998 copyRegisters(Tco, Tc, CO_layout, Xs_layout, CO_regsLoad,
9999 CO_regsLoadConv, 0, 0, false, strategy, state);
10000 }
10001
10002 auto &betar = problem.beta_real;
10003
10004 map(hw, Tc, Xs_usedRegs, CO_regsLoadConv, strategy,
10005 [&](int esize, GRF acc, GRF loaded) {
10006 if (betar == 1)
10007 add(esize, acc, acc, loaded);
10008 else if (betar.fixed())
10009 mad(esize, acc, acc, loaded, cast(Tc.real(), betar));
10010 else
10011 mad(esize, acc, acc, loaded,
10012 betar.getRegAvoiding(hw, acc));
10013 });
10014
10015 state.ra.safeRelease(CO_regsLoad);
10016 if (!share) state.ra.safeRelease(CO_regsLoadConv);
10017
10018 if (checkBeta0 || checkBeta1) {
10019 state.wipeActiveVFlags();
10020 mark(lSkipUpdate);
10021 }
10022 }
10023
10024 if (!share) {
10025 CO_regs = state.ra.alloc_range(getRegCount(CO_layout));
10026 copyRegisters(Tc, Tco, Xs_layout, CO_layout, Xs_regs, CO_regs, 0, 0,
10027 false, strategy, state);
10028 safeReleaseRanges(Xs_regs, state);
10029 }
10030
10031 auto &effCO_regs = share ? Xs_regs : CO_regs;
10032 if (atomic) {
10033 Label lStore, lDone;
10034 if (checkBeta1) {
10035 if (!CO_strategy.base.isStateless() && !CO_strategy.newDP)
10036 stub(); /* need to shift addresses */
10037 jmpi(1 | ~state.flagAP, lStore);
10038 }
10039 allocEAtomicAddRegs(
10040 hw, Tco, CO_layout, CO, CO_strategy, state, state.flagAP);
10041 atomicAddMatrix(Tco, effCO_regs, CO_layout, CO, CO_strategy, CO_addrs,
10042 problem, strategy, state);
10043 freeEAtomicAddRegs(state, state.flagAP);
10044 if (checkBeta1) {
10045 state.wipeActiveVFlags();
10046 jmpi(1, lDone);
10047 mark(lStore);
10048 storeMatrix(effCO_regs, CO_layout, CO, CO_strategy, CO_addrs,
10049 strategy, state);
10050 mark(lDone);
10051 }
10052 } else
10053 storeMatrix(effCO_regs, CO_layout, CO, CO_strategy, CO_addrs, strategy,
10054 state);
10055
10056 mark(skipStore);
10057 endif(16);
10058
10059 safeReleaseMaskAssignments(masks, state);
10060
10061 if (!share) safeReleaseRanges(CO_regs, state);
10062 safeReleaseRanges(state.As_regs, state);
10063 safeReleaseRanges(state.Bs_regs, state);
10064 state.raVFlag.safeRelease(flagBeta0);
10065 state.As_layout.clear();
10066 state.Bs_layout.clear();
10067}
10068
10069// Generate code for summing C across k dimension through SLM.
10070template <HW hw>
10071void gemm_kernel_generator_t<hw>::gemmKReduce(const GEMMProblem &problem,
10072 const GEMMStrategy &strategy, GEMMState &state) {
10073 auto Tc = problem.Tc;
10074 Label lDone;
10075
    // Early exit if there is nothing to do. All branching is scalar, since
    //  there is no fusing in the k dimension.
10077 cmp(1 | le | state.flagAP, state.lszK, 1);
10078 jmpi(1 | state.flagAP, lDone);
10079
10080 status << "k reduction through SLM" << status_stream::endl;
10081 cmp(1 | eq | state.flagAP, state.lidK, 0);
10082
10083 auto C_regs = state.C_regs[0];
10084
10085 // Reduce A/B sums at the same time.
10086 if (problem.sumA) C_regs.append(state.As_regs);
10087 if (problem.sumB) C_regs.append(state.Bs_regs);
10088
10089 // In general SLM isn't large enough to do the reduction in one step.
10090 // Slice C into pieces that will fit.
10091 int maxMNThreads = strategy.wg[LoopM] * strategy.wg[LoopN];
10092 if (maxMNThreads <= 0)
10093 throw std::runtime_error("Max workgroup size not specified");
10094
10095 int regs = C_regs.getLen();
10096 int sliceRegs = int(gemmPerKSLMSize(problem, strategy)
10097 / (maxMNThreads * GRF::bytes(hw)));
10098 if (sliceRegs <= 0)
10099 throw std::runtime_error("Not enough SLM for k reduction");
10100 sliceRegs = std::min<int>(sliceRegs, C_regs.getLen());
10101
10102 // Temporaries.
10103 auto kt = state.ra.alloc_sub<int32_t>();
10104 auto flagKTLoop = state.raVFlag.alloc();
10105 auto barrierTemp = state.ra.alloc();
10106
10107 if (state.r0_info.isARF()) stub();
10108 GRF r0_info {state.r0_info.getBase()};
10109
10110 bool initialBarrier = (strategy.slmBuffers > 0 || strategy.persistent);
10111 MOCK_BARRIERS if (initialBarrier) barriersignal(barrierTemp, r0_info);
10112
10113 // Set up addressing.
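    // addr0 = (lidM + lidN * wg[M] + lidK * wg[M] * wg[N]) * sliceRegs * GRF
    //  bytes, giving each thread its own contiguous slice of SLM.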
10114 auto addr0 = state.ra.alloc_sub<uint32_t>();
10115 emad(1, addr0, state.lidM, state.lidN, strategy.wg[LoopM], strategy, state);
10116 emad(1, addr0, addr0, state.lidK, strategy.wg[LoopM] * strategy.wg[LoopN],
10117 strategy, state);
10118 mulConstant(1, addr0, addr0, sliceRegs * GRF::bytes(hw));
10119
10120 int unrollKSLMStride = strategy.wg[LoopM] * strategy.wg[LoopN] * sliceRegs
10121 * GRF::bytes(hw);
10122 Subregister unrollKSLMReturn = state.ra.alloc_sub<int32_t>();
10123
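    // Rewind distance (-lszK * stride) returning the leader's SLM addresses
    // to the start of its slice region after a full reduction pass.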
10124 mulConstant(1, unrollKSLMReturn, -state.lszK, unrollKSLMStride);
10125
10126 MatrixAddressing C_slm;
10127 MatrixAddressingStrategy C_slmStrategy;
10128
10129 C_slm.layout = MatrixLayout::Pc;
10130 C_slm.packSize = elementsPerGRF(hw, Tc);
10131 C_slm.crosspack = 1;
10132 C_slm.setAlignment(GRF::bytes(hw));
10133
10134 C_slmStrategy.base = SLM;
10135 C_slmStrategy.accessType = AccessType::Block;
10136 C_slmStrategy.padded = true;
10137 if (hw >= HW::XeHPG) C_slmStrategy.newDP = true;
10138
10139 GRFRange C_load;
10140 vector<RegisterBlock> C_slmLayout;
10141 vector<GRFRange> C_slmAddrs;
10142
10143 // Find maximum # registers of C we can transfer to/from SLM at once.
10144 int maxContig = rounddown_pow2(regs);
10145 for (; maxContig > 1; maxContig >>= 1) {
10146 bool ok = true;
10147 for (int offsetReg = 0; offsetReg < regs; offsetReg += maxContig) {
10148 int nr = std::min(regs - offsetReg, maxContig);
10149 if (!C_regs.contiguous(offsetReg, nr)) {
10150 ok = false;
10151 break;
10152 }
10153 }
10154 if (ok) break;
10155 }
10156
10157 // Allocate address and data registers, automatically shrinking sliceRegs if
10158 // there are not enough registers.
10159 for (; sliceRegs > 0; sliceRegs = rounddown_pow2(sliceRegs - 1)) {
10160 bool ok = true;
10161
10162 C_load = state.ra.try_alloc_range(sliceRegs);
10163 ok = ok && C_load.isValid();
10164
10165 if (!getRegLayout(Tc, C_slmLayout, elementsPerGRF(hw, Tc), sliceRegs,
10166 false, false, true, true, 0, maxContig, C_slm,
10167 C_slmStrategy))
10168 stub();
10169 ok = ok
10170 && tryAllocAddrRegs(
10171 C_slmAddrs, C_slmLayout, C_slm, C_slmStrategy, state);
10172
10173 if (ok) break;
10174
10175 state.ra.safeRelease(C_load);
10176 }
10177
10178 if (sliceRegs <= 0) throw out_of_registers_exception();
10179
10180 setupAddr(Tc, C_slmAddrs, addr0, C_slmLayout, Subregister(), C_slm,
10181 C_slmStrategy, strategy, state);
10182
10183 MOCK_BARRIERS if (initialBarrier) barrierwait();
10184
10185 // Loop over slices.
10186 for (int rr = 0; rr < regs; rr += sliceRegs) {
10187 Label lSkipWrite, lSkipReduce, lTop;
10188
10189 int nreg = std::min(sliceRegs, regs - rr);
10190 auto C_range = C_regs.subrange(rr, nreg);
10191
10192 MOCK_BARRIERS if (rr > 0) slmBarrier(barrierTemp, r0_info);
10193
        // Trim down the SLM layout for the final, partial slice.
10195 if (nreg < sliceRegs) {
10196 vector<RegisterBlock> sublayout;
10197 vector<GRFRange> subaddrs;
10198 if (!getSubblocks(Tc, sublayout, subaddrs, C_slmLayout, C_slmAddrs,
10199 true, 0, nreg, true, C_slm, C_slmStrategy))
10200 stub();
10201 std::swap(sublayout, C_slmLayout);
10202 std::swap(subaddrs, C_slmAddrs);
10203 }
10204
10205 // Non-leaders write to SLM.
10206 jmpi(1 | state.flagAP, lSkipWrite);
10207 storeMatrix(C_range, C_slmLayout, C_slm, C_slmStrategy, C_slmAddrs,
10208 strategy, state);
10209 mark(lSkipWrite);
10210
10211 MOCK_BARRIERS slmBarrier(barrierTemp, r0_info);
10212
10213 // Leader reads SLM data and accumulates C.
10214 jmpi(1 | ~state.flagAP, lSkipReduce);
10215 add(1, kt, state.lszK, -1);
10216 incAddr(C_slmAddrs, unrollKSLMStride, C_slmLayout, C_slm, C_slmStrategy,
10217 strategy, state);
10218
10219 mark(lTop);
10220 add(1 | gt | flagKTLoop, kt, kt, -1);
10221 loadMatrix(C_load, C_slmLayout, C_slm, C_slmStrategy, C_slmAddrs,
10222 strategy, state);
10223 incAddr(C_slmAddrs, unrollKSLMStride, C_slmLayout, C_slm, C_slmStrategy,
10224 strategy, state);
10225 map(hw, Tc.real(), C_range, C_load, strategy,
10226 [&](int simd, GRF r1, GRF r2) { add(simd, r1, r1, r2); });
10227 jmpi(1 | flagKTLoop, lTop);
10228
10229 if (rr + nreg < regs)
10230 incAddr(C_slmAddrs, unrollKSLMReturn, C_slmLayout, C_slm,
10231 C_slmStrategy, strategy, state);
10232
10233 mark(lSkipReduce);
10234 }
10235
    // Non-leader threads do not update C: zero out their remainders so the
    //  subsequent C update is skipped.
10237 mov(1 | ~state.flagAP, state.remainders[LoopM], 0);
10238 mov(1 | ~state.flagAP, state.remainders[LoopN], 0);
10239
10240 state.raVFlag.safeRelease(flagKTLoop);
10241 state.ra.safeRelease(C_load);
10242 state.ra.safeRelease(kt);
10243 state.ra.safeRelease(unrollKSLMReturn);
10244 state.ra.safeRelease(addr0);
10245 state.ra.safeRelease(barrierTemp);
10246 safeReleaseRanges(C_slmAddrs, state);
10247
10248 mark(lDone);
10249}
10250
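// Prefetch C ahead of the k loop. Skipped at run time when beta is zero (if a
// runtime beta check is needed) or when this thread is not in the first local
// k slice (lidK > 0).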
10251template <HW hw>
10252void gemm_kernel_generator_t<hw>::gemmPrefetchC(
10253 const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
10254 auto Tc_ext = problem.Tc_ext;
10255 bool checkBeta0 = problem.checkBeta0 && !problem.beta_real.fixed();
10256 bool checkIDK = strategy.kParallelLocal;
10257
10258 releaseRanges(state.Ap_regs, state);
10259 releaseRanges(state.Bp_regs, state);
10260
10261 status << "Prefetch C" << status_stream::endl;
10262
10263 if (checkBeta0) {
10264 cmp0(1 | eq | state.flagAP, problem.beta_real.getReg(0));
10265 }
10266
10267 Address2DParams Cp_params;
10268 if (strategy.C.address2D) {
10269 Cp_params.rows = state.inputs.m;
10270 Cp_params.cols = state.inputs.n;
10271 Cp_params.offR = state.i0;
10272 Cp_params.offC = state.j0;
10273 } else {
10274 Cp_params.rows = state.remainders[LoopM];
10275 Cp_params.cols = state.remainders[LoopN];
10276 }
10277 Cp_params.remR = state.remainders[LoopM];
10278 Cp_params.remC = state.remainders[LoopN];
10279
10280 bool oldAdd32 = strategy.emulate.emulate64_add32;
10281 strategy.emulate.emulate64_add32 = false;
10282
10283 gemmCacheLDCMultiples(problem, strategy, state, 1);
10284
10285 if (checkIDK) {
10286 if (checkBeta0)
10287 cmp(1 | ~state.flagAP | gt | state.flagAP, state.lidK, 0);
10288 else
10289 cmp(1 | gt | state.flagAP, state.lidK, 0);
10290 }
10291
10292 allocAddrRegs(state.Cp_addrs, state.Cp_layout, problem.C,
10293 strategy.C_prefetch, state);
10294 setupAddr(Tc_ext, state.Cp_addrs, state.effCp, state.Cp_layout,
10295 state.inputs.ldc[0], problem.C, strategy.C_prefetch, strategy,
10296 state, Cp_params, state.ldcMultiples[0]);
10297
10298 Label lSkipPrefetchC;
10299 if (checkBeta0 || checkIDK) jmpi(1 | state.flagAP, lSkipPrefetchC);
10300
10301 state.Cp_regs = state.ra.alloc_range(getRegCount(state.Cp_layout));
10302
10303 loadMatrix(state.Cp_regs, state.Cp_layout, problem.C, strategy.C_prefetch,
10304 state.Cp_addrs, strategy, state);
10305
10306 safeReleaseRanges(state.Cp_regs, state);
10307 safeReleaseRanges(state.Cp_addrs, state);
10308 if (state.effCp != state.effC[0]) state.ra.safeRelease(state.effCp);
10309
10310 releaseLDMultiples(state.ldcMultiples[0], state);
10311 releaseIndexVec(state);
10312
10313 if (checkBeta0 || checkIDK) mark(lSkipPrefetchC);
10314
10315 strategy.emulate.emulate64_add32 = oldAdd32;
10316
10317 reclaimRanges(state.Ap_regs, state);
10318 reclaimRanges(state.Bp_regs, state);
10319}
10320
10321// Generate code for checking whether 32-bit address arithmetic can be used inside k loop.
10322// Assumes leading dimensions have not been shifted yet.
10323template <HW hw>
10324void gemm_kernel_generator_t<hw>::gemmCheck32(
10325 const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
10326 if (!strategy.checkAdd32) return;
10327
10328 bool checkA = (strategy.A.base.getModel() == ModelA64);
10329 bool checkB = (strategy.B.base.getModel() == ModelA64);
10330 if (!checkA && !checkB) return;
10331
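    // Conservatively estimate each matrix's extent in bytes and check whether
    //  base + offset + extent overflows 32 bits; state.add64 is set to 1 when
    //  64-bit address arithmetic is required.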
10332 auto &m = state.inputs.m;
10333 auto &n = state.inputs.n;
10334 auto &k = state.fullK.isValid() ? state.fullK : state.inputs.k;
10335 auto &lda = state.inputs.lda;
10336 auto &ldb = state.inputs.ldb;
10337 auto temp1GRF = state.ra.alloc();
10338 auto temp2GRF = state.ra.alloc();
10339 auto temp1 = temp1GRF.ud(
10340 0); // Only need one :ud subregister. But GRF-align it for mach.
10341 auto temp2 = temp2GRF.ud(0);
10342 auto temp3 = temp2GRF.ud(4);
10343 auto flag = state.raVFlag.alloc();
10344
10345 if (checkA) {
10346 add(1, temp2, state.effA.ud(), state.offsetA.ud());
        // Conservatively estimate upper bound for size of A.
        switch (problem.A.layout) {
10349 case MatrixLayout::N: emul32High(1, temp1, lda, k); break;
10350 case MatrixLayout::T: emul32High(1, temp1, lda, m); break;
10351 case MatrixLayout::Pc: {
10352 if (strategy.fixedWG(problem))
10353 add(1, temp3, m,
10354 uint16_t(strategy.wg[LoopM] * strategy.unroll[LoopM]
10355 - 1));
10356 else
10357 emad(1, temp3, m, state.inputs.localSizeM,
10358 strategy.unroll[LoopM], strategy, state);
10359 emul32High(1, temp1, lda, temp3);
10360 break;
10361 }
10362 default: stub();
10363 }
10364 add(1 | ov | flag, temp2, acc0.ud(0), temp2);
10365 cmp(1 | ~flag | ne | flag, temp1, uint16_t(0));
10366 }
10367
10368 if (checkB) {
10369 add(1, temp2, state.effB.ud(), state.offsetB.ud());
10370 switch (problem.B.layout) {
10371 case MatrixLayout::T: emul32High(1, temp1, ldb, k); break;
10372 case MatrixLayout::N: emul32High(1, temp1, ldb, n); break;
10373 case MatrixLayout::Pr: {
10374 if (strategy.fixedWG(problem))
10375 add(1, temp3, n,
10376 uint16_t(strategy.wg[LoopN] * strategy.unroll[LoopN]
10377 - 1));
10378 else
10379 emad(1, temp3, n, state.inputs.localSizeN,
10380 strategy.unroll[LoopN], strategy, state);
10381 emul32High(1, temp1, ldb, temp3);
10382 break;
10383 }
10384 default: stub();
10385 }
10386 InstructionModifier mod = 1;
10387 if (checkA) mod |= ~flag;
10388 add(mod | ov | flag, temp2, acc0.ud(0), temp2);
10389 cmp(1 | ~flag | ne | flag, temp1, uint16_t(0));
10390 }
10391
10392 state.add64 = state.ra.alloc_sub<uint16_t>();
10393 and_(1, state.add64, flag, 1u);
10394 state.raVFlag.safeRelease(flag);
10395
10396 state.ra.safeRelease(temp1GRF);
10397 temp1 = invalid;
10398 state.ra.safeRelease(temp2GRF);
10399 temp2 = invalid;
10400 temp3 = invalid;
10401}
10402
10403// Increment A pointer after load, inside GEMM k loop.
10404template <HW hw>
10405void gemm_kernel_generator_t<hw>::gemmAIncrementInternal(Type Ta,
10406 const std::vector<RegisterBlock> &layout,
10407 const std::vector<GRFRange> &addrs, const MatrixAddressing &A,
10408 const MatrixAddressingStrategy &A_strategy, int ka_inc,
10409 const GEMMProblem &problem, const GEMMStrategy &strategy,
10410 GEMMState &state, int ha) {
10411 if (ka_inc == 0)
10412 /* no-op */;
10413 else if (A_strategy.address2D)
10414 incDecAddr(addrs, Subregister(), 0, ka_inc, layout, A, A_strategy,
10415 strategy, state, problem.backward());
10416 else if (A.layout == MatrixLayout::N) {
10417 SubregisterPair lda_ka;
10418 bool release = false;
10419 // Use cached lda * ka_inc if available, otherwise calculate on the fly.
10420 if (ka_inc == 1)
10421 lda_ka = state.lda;
10422 else if (state.lda_ka.isValid() && ka_inc == state.ka_cached)
10423 lda_ka = state.lda_ka;
10424 else if (state.lda_ka_prefetch.isValid()
10425 && ka_inc == strategy.ka_pfStride)
10426 lda_ka = state.lda_ka_prefetch;
10427 else {
10428 lda_ka = state.ra.alloc_sub<int32_t>();
10429 emulConstant(1, lda_ka, state.inputs.lda, ka_inc, strategy, state);
10430 release = true;
10431 }
10432 incDecAddr(addrs, lda_ka, layout, A, A_strategy, strategy, state,
10433 problem.backward());
10434 if (release) state.ra.safeRelease(lda_ka);
10435 } else {
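        // Packed or transposed A: the k increment is a compile-time element
        //  count; convert it to bytes and add it directly to the addresses.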
10436 int incA;
10437 switch (A.layout) {
10438 case MatrixLayout::Pc:
10439 incA = untile(Ta, A, 0, 0, ha + ka_inc, A.packSize,
10440 strategy.unrollKSLM)
10441 - untile(Ta, A, 0, 0, ha, A.packSize,
10442 strategy.unrollKSLM);
10443 break;
10444 case MatrixLayout::T: incA = ka_inc; break;
10445 default: stub();
10446 }
10447 incDecAddr(addrs, uint16_t(incA * Ta), layout, A, A_strategy, strategy,
10448 state, problem.backward());
10449 }
10450}
10451
10452template <HW hw>
10453void gemm_kernel_generator_t<hw>::gemmAIncrementInternal(Type Ta,
10454 const std::vector<RegisterBlock> &layout,
10455 const std::vector<GRFRange> &addrs, const MatrixAddressing &A,
10456 const MatrixAddressingStrategy &A_strategy,
10457 const MultishiftSubregister &ka_inc, const GEMMProblem &problem,
10458 const GEMMStrategy &strategy, GEMMState &state, int ha) {
10459 incDecAddr(addrs, ka_inc, layout, A, A_strategy, strategy, state,
10460 problem.backward());
10461}
10462
10463template <HW hw>
10464void gemm_kernel_generator_t<hw>::gemmAIncrementInternal(Type Ta,
10465 const std::vector<RegisterBlock> &layout,
10466 const std::vector<GRFRange> &addrs, const MatrixAddressing &A,
10467 const MatrixAddressingStrategy &A_strategy, const Subregister &ka_inc,
10468 const GEMMProblem &problem, const GEMMStrategy &strategy,
10469 GEMMState &state, int ha) {
10470 incDecAddr(addrs, ka_inc, 0, ka_inc, layout, A, A_strategy, strategy, state,
10471 problem.backward());
10472}
10473
10474template <HW hw>
10475template <typename I>
10476void gemm_kernel_generator_t<hw>::gemmAIncrement(Type Ta,
10477 const std::vector<RegisterBlock> &layout,
10478 const std::vector<GRFRange> &addrs, const MatrixAddressing &A,
10479 const MatrixAddressingStrategy &A_strategy, I ka_inc,
10480 const GEMMProblem &problem, const GEMMStrategy &strategy,
10481 GEMMState &state, int ha) {
10482 gemmAIncrementInternal(Ta, layout, addrs, A, A_strategy, ka_inc, problem,
10483 strategy, state, ha);
10484}
10485
10486// A load for GEMM k loop.
10487template <HW hw>
10488void gemm_kernel_generator_t<hw>::gemmALoad(const GRFMultirange &regs,
10489 const std::vector<RegisterBlock> &layout,
10490 const std::vector<GRFRange> &addrs, const MatrixAddressing &A,
10491 const MatrixAddressingStrategy &A_strategy, const GEMMProblem &problem,
10492 const GEMMStrategy &strategy, GEMMState &state) {
10493 loadMatrix(regs, layout, A, A_strategy, addrs, strategy, state);
10494}
10495
10496template <HW hw>
10497template <typename I>
10498void gemm_kernel_generator_t<hw>::gemmALoadInc(Type Ta,
10499 const GRFMultirange &regs, const std::vector<RegisterBlock> &layout,
10500 const std::vector<GRFRange> &addrs, const MatrixAddressing &A,
10501 const MatrixAddressingStrategy &A_strategy, I ka_inc,
10502 const GEMMProblem &problem, const GEMMStrategy &strategy,
10503 GEMMState &state) {
10504 gemmALoad(regs, layout, addrs, A, A_strategy, problem, strategy, state);
10505 gemmAIncrement(
10506 Ta, layout, addrs, A, A_strategy, ka_inc, problem, strategy, state);
10507}
10508
10509template <HW hw>
10510void gemm_kernel_generator_t<hw>::gemmBIncrementInternal(Type Tb,
10511 const std::vector<RegisterBlock> &layout,
10512 const std::vector<GRFRange> &addrs, const MatrixAddressing &B,
10513 const MatrixAddressingStrategy &B_strategy, int kb_inc,
10514 const GEMMProblem &problem, const GEMMStrategy &strategy,
10515 GEMMState &state, int hb) {
10516 if (kb_inc == 0)
10517 /* no-op */;
10518 else if (B_strategy.address2D)
10519 incDecAddr(addrs, Subregister(), kb_inc, 0, layout, B, B_strategy,
10520 strategy, state, problem.backward());
10521 else if (B.layout == MatrixLayout::T) {
10522 SubregisterPair ldb_kb;
10523 bool release = false;
10524 if (kb_inc == 1)
10525 ldb_kb = state.ldb;
10526 else if (state.ldb_kb.isValid() && kb_inc == state.kb_cached)
10527 ldb_kb = state.ldb_kb;
10528 else if (state.ldb_kb_prefetch.isValid()
10529 && kb_inc == strategy.kb_pfStride)
10530 ldb_kb = state.ldb_kb_prefetch;
10531 else {
10532 ldb_kb = state.ra.alloc_sub<int32_t>();
10533 emulConstant(1, ldb_kb, state.inputs.ldb, kb_inc, strategy, state);
10534 release = true;
10535 }
10536 incDecAddr(addrs, ldb_kb, layout, B, B_strategy, strategy, state,
10537 problem.backward());
10538 if (release) state.ra.safeRelease(ldb_kb);
10539 } else {
10540 int incB;
10541 switch (B.layout) {
10542 case MatrixLayout::Pr:
10543 incB = untile(Tb, B, 0, hb + kb_inc, 0, strategy.unrollKSLM,
10544 B.packSize)
10545 - untile(Tb, B, 0, hb, 0, strategy.unrollKSLM,
10546 B.packSize);
10547 break;
10548 case MatrixLayout::N: incB = kb_inc; break;
10549 default: stub();
10550 }
10551 incDecAddr(addrs, uint16_t(incB * Tb), layout, B, B_strategy, strategy,
10552 state, problem.backward());
10553 }
10554}
10555
10556template <HW hw>
10557void gemm_kernel_generator_t<hw>::gemmBIncrementInternal(Type Tb,
10558 const std::vector<RegisterBlock> &layout,
10559 const std::vector<GRFRange> &addrs, const MatrixAddressing &B,
10560 const MatrixAddressingStrategy &B_strategy,
10561 const MultishiftSubregister &kb_inc, const GEMMProblem &problem,
10562 const GEMMStrategy &strategy, GEMMState &state, int hb) {
10563 incDecAddr(addrs, kb_inc, layout, B, B_strategy, strategy, state,
10564 problem.backward());
10565}
10566
10567template <HW hw>
10568void gemm_kernel_generator_t<hw>::gemmBIncrementInternal(Type Tb,
10569 const std::vector<RegisterBlock> &layout,
10570 const std::vector<GRFRange> &addrs, const MatrixAddressing &B,
10571 const MatrixAddressingStrategy &B_strategy, const Subregister &kb_inc,
10572 const GEMMProblem &problem, const GEMMStrategy &strategy,
10573 GEMMState &state, int hb) {
10574 incDecAddr(addrs, kb_inc, kb_inc, 0, layout, B, B_strategy, strategy, state,
10575 problem.backward());
10576}
10577
10578template <HW hw>
10579template <typename I>
10580void gemm_kernel_generator_t<hw>::gemmBIncrement(Type Tb,
10581 const std::vector<RegisterBlock> &layout,
10582 const std::vector<GRFRange> &addrs, const MatrixAddressing &B,
10583 const MatrixAddressingStrategy &B_strategy, I kb_inc,
10584 const GEMMProblem &problem, const GEMMStrategy &strategy,
10585 GEMMState &state, int hb) {
10586 gemmBIncrementInternal(Tb, layout, addrs, B, B_strategy, kb_inc, problem,
10587 strategy, state, hb);
10588}
10589
10590// B load for GEMM k loop.
10591template <HW hw>
10592void gemm_kernel_generator_t<hw>::gemmBLoad(const GRFMultirange &regs,
10593 const std::vector<RegisterBlock> &layout,
10594 const std::vector<GRFRange> &addrs, const MatrixAddressing &B,
10595 const MatrixAddressingStrategy &B_strategy, const GEMMProblem &problem,
10596 const GEMMStrategy &strategy, GEMMState &state) {
10597 loadMatrix(regs, layout, B, B_strategy, addrs, strategy, state);
10598}
10599
10600template <HW hw>
10601template <typename I>
10602void gemm_kernel_generator_t<hw>::gemmBLoadInc(Type Tb,
10603 const GRFMultirange &regs, const std::vector<RegisterBlock> &layout,
10604 const std::vector<GRFRange> &addrs, const MatrixAddressing &B,
10605 const MatrixAddressingStrategy &B_strategy, I kb_inc,
10606 const GEMMProblem &problem, const GEMMStrategy &strategy,
10607 GEMMState &state) {
10608 gemmBLoad(regs, layout, addrs, B, B_strategy, problem, strategy, state);
10609 gemmBIncrement(
10610 Tb, layout, addrs, B, B_strategy, kb_inc, problem, strategy, state);
10611}
10612
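// Load a possibly partial (k remainder) Ai/Bi block for SLM copies: either in
// one shot, or incrementally one k slice at a time under a countdown (kSLMX),
// optionally copying each slice into the Ao/Bo repack registers as it arrives.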
10613template <HW hw>
10614template <bool doA>
10615void gemm_kernel_generator_t<hw>::gemmAiBiRemLoadInc(bool incremental,
10616 bool incrementalCopy, bool keepAddrTogether, bool willRemask,
10617 const Subregister &kSLMX, const GRFMultirange &Xi_regs,
10618 const vector<RegisterBlock> &Xi_layout,
10619 const vector<GRFRange> &Xi_addrs,
10620 const vector<vector<RegisterBlock>> &Xi_layoutK,
10621 const vector<vector<GRFRange>> &Xi_addrsK, const GRFMultirange &Xo_regs,
10622 const vector<RegisterBlock> &Xo_layout, const MatrixAddressing &Xi,
10623 const MatrixAddressingStrategy &Xi_strategy, const GEMMProblem &problem,
10624 const GEMMStrategy &strategy, GEMMState &state) {
10625 auto T = doA ? problem.Ta : problem.Tb;
10626 auto T_ext = doA ? problem.Ta_ext : problem.Tb_ext;
10627 auto kx_slm = doA ? state.ka_slm : state.kb_slm;
10628
10629 auto unrollKSLM = strategy.unrollKSLM;
10630
10631 bool prezero = !willRemask
10632 && ((doA ? state.slmASums : state.slmBSums)
10633 || (minOuterProductCount(hw, problem, strategy) > 1));
10634
10635 if (!incremental) {
10636 if (prezero) zeroMatrix(Xi_regs, strategy);
10637 doA ? gemmALoad(Xi_regs, Xi_layout, Xi_addrs, Xi, Xi_strategy, problem,
10638 strategy, state)
10639 : gemmBLoad(Xi_regs, Xi_layout, Xi_addrs, Xi, Xi_strategy, problem,
10640 strategy, state);
10641 } else {
10642 bool simtCF = strategy.fused
10643 && (strategy.fusedLoop == (doA ? LoopN : LoopM));
10644 int simt = simtCF ? 16 : 1;
10645 Label done;
10646
10647 keepAddrTogether &= (Xi_addrsK.size() > 1);
10648
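        // kSLMX counts down this thread's remaining k work; each k slice is
        //  only loaded while the countdown is still positive.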
10649 cmp(simt | gt | state.flagAP, kSLMX, 0);
10650 add(1, kSLMX, kSLMX, (kx_slm > 1) ? -1 : -unrollKSLM);
10651
10652 if (prezero) zeroMatrix(incrementalCopy ? Xo_regs : Xi_regs, strategy);
10653
10654 for (int hh = 0; hh < kx_slm; hh++) {
10655 int hhRem = kx_slm - hh - 1;
10656
10657 simtCF ? goto12(16 | ~state.flagAP, done)
10658 : jmpi(1 | ~state.flagAP, done);
10659
10660 if (hhRem > 0) {
10661 cmp(simt | gt | state.flagAP, kSLMX, 0);
10662 add(1, kSLMX, kSLMX,
10663 (hhRem == 1) ? -(unrollKSLM - kx_slm + 1) : -1);
10664 }
10665
10666 int hh_eff = problem.backward() ? (kx_slm - 1 - hh) : hh;
10667 int hh_layout = hh_eff;
10668 int hh_addr = hh_eff;
10669
10670 if (Xi_layoutK.size() == 1) hh_layout = 0;
10671 if (Xi_addrsK.size() == 1) hh_addr = 0;
10672
10673 // OPTIMIZEME: delay inc if kx_slm = 1
10674 auto kx_inc = (Xi_addrsK.size() > 1)
10675 ? unrollKSLM
10676 : ((hh + 1) != kx_slm) ? 1 : (unrollKSLM - kx_slm + 1);
10677
10678 if (keepAddrTogether) kx_inc = 0;
10679
10680 doA ? gemmALoadInc(T_ext, Xi_regs, Xi_layoutK[hh_layout],
10681 Xi_addrsK[hh_addr], Xi, Xi_strategy, kx_inc, problem,
10682 strategy, state)
10683 : gemmBLoadInc(T_ext, Xi_regs, Xi_layoutK[hh_layout],
10684 Xi_addrsK[hh_addr], Xi, Xi_strategy, kx_inc, problem,
10685 strategy, state);
10686
10687 if (incrementalCopy) {
10688 int rr_eff = doA ? 0 : hh_eff;
10689 int cc_eff = doA ? hh_eff : 0;
10690 copyRegisters(T_ext, T, Xi_layoutK[hh_layout], Xo_layout,
10691 Xi_regs, Xo_regs, rr_eff, cc_eff, false, strategy,
10692 state);
10693 }
10694 }
10695
10696 mark(done);
10697 if (simtCF) join(16);
10698
10699 if (keepAddrTogether) {
10700 doA ? gemmAIncrement(T_ext, Xi_layout, Xi_addrs, Xi, Xi_strategy,
10701 unrollKSLM, problem, strategy, state)
10702 : gemmBIncrement(T_ext, Xi_layout, Xi_addrs, Xi, Xi_strategy,
10703 unrollKSLM, problem, strategy, state);
10704 }
10705 }
10706}
10707
10708// Calculate A offset for SLM copies or cooperative prefetches for this local ID.
10709template <HW hw>
10710void gemm_kernel_generator_t<hw>::gemmCalcWorkshareAOffset(Subregister &off,
10711 Subregister &offR, Subregister &offC, const MatrixAddressing &A,
10712 const MatrixAddressingStrategy &A_strategy, int ma, int ka,
10713 const GEMMProblem &problem, const GEMMStrategy &strategy,
10714 GEMMState &state) {
10715 bool splitM = (state.effCoopA == CoopSplit::MN);
10716 bool splitLinear = (state.effCoopA == CoopSplit::Linear);
10717
10718 if (A_strategy.address2D) {
10719 if (splitLinear) stub();
10720 if (splitM) {
10721 offR = state.ra.alloc_sub<uint32_t>(
10722 getHint(HintType::TempComp0, strategy));
10723 mulConstant(1, offR, state.lidN, ma);
10724 } else {
10725 offC = state.ra.alloc_sub<uint32_t>(
10726 getHint(HintType::TempComp0, strategy));
10727 mulConstant(1, offC, state.lidN, ka);
10728 }
10729 } else {
10730 auto Ta_ext = problem.Ta_ext;
10731 off = state.ra.alloc_sub<uint32_t>(
10732 getHint(HintType::TempComp0, strategy));
10733
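        // Compute a byte offset into A for this thread's share of the copy or
        //  prefetch, based on A's layout and the m/k split direction.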
10734 switch (A.layout) {
10735 case MatrixLayout::Pc:
10736 mulConstant(1, off, state.lidN, ma * ka * Ta_ext);
10737 break;
10738 case MatrixLayout::T:
10739 if (splitLinear) stub();
10740 if (splitM) {
10741 mul(1, off, state.inputs.lda, state.lidN);
10742 mulConstant(1, off, off, ma);
10743 } else
10744 mulConstant(1, off, state.lidN, ka * Ta_ext);
10745 break;
10746 case MatrixLayout::N:
10747 if (splitLinear) stub();
10748 if (splitM)
10749 mulConstant(1, off, state.lidN, ma * Ta_ext);
10750 else {
10751 mul(1, off, state.inputs.lda, state.lidN);
10752 mulConstant(1, off, off, ka);
10753 }
10754 break;
10755 default: stub();
10756 }
10757 }
10758}
10759
10760// Calculate B offset for SLM copies or cooperative prefetches for this local ID.
10761template <HW hw>
10762void gemm_kernel_generator_t<hw>::gemmCalcWorkshareBOffset(Subregister &off,
10763 Subregister &offR, Subregister &offC, const MatrixAddressing &B,
10764 const MatrixAddressingStrategy &B_strategy, int kb, int nb,
10765 const GEMMProblem &problem, const GEMMStrategy &strategy,
10766 GEMMState &state) {
10767 bool splitN = (state.effCoopB == CoopSplit::MN);
10768 bool splitLinear = (state.effCoopB == CoopSplit::Linear);
10769
10770 if (B_strategy.address2D) {
10771 if (splitLinear) stub();
10772 if (splitN) {
10773 offC = state.ra.alloc_sub<uint32_t>(
10774 getHint(HintType::TempComp0, strategy));
10775 mulConstant(1, offC, state.lidM, nb);
10776 } else {
10777 offR = state.ra.alloc_sub<uint32_t>(
10778 getHint(HintType::TempComp0, strategy));
10779 mulConstant(1, offR, state.lidM, kb);
10780 }
10781 } else {
10782 auto Tb_ext = problem.Tb_ext;
10783 off = state.ra.alloc_sub<uint32_t>(
10784 getHint(HintType::TempComp0, strategy));
10785
10786 switch (B.layout) {
10787 case MatrixLayout::Pr:
10788 mulConstant(1, off, state.lidM, nb * kb * Tb_ext);
10789 break;
10790 case MatrixLayout::N:
10791 if (splitLinear) stub();
10792 if (splitN) {
10793 mul(1, off, state.inputs.ldb, state.lidM);
10794 mulConstant(1, off, off, nb);
10795 } else
10796 mulConstant(1, off, state.lidM, kb * Tb_ext);
10797 break;
10798 case MatrixLayout::T:
10799 if (splitLinear) stub();
10800 if (splitN)
10801 mulConstant(1, off, state.lidM, nb * Tb_ext);
10802 else {
10803 mul(1, off, state.inputs.ldb, state.lidM);
10804 mulConstant(1, off, off, kb);
10805 }
10806 break;
10807 default: stub();
10808 }
10809 }
10810}
10811
10812// Remask incoming global data for SLM copies.
10813template <HW hw>
10814void gemm_kernel_generator_t<hw>::gemmSLMRemask(bool remaskA, bool remaskB,
10815 GRFMultirange &Ao_regs, GRFMultirange &Bo_regs, int kOffset,
10816 const GEMMProblem &problem, const GEMMStrategy &strategy,
10817 GEMMState &state) {
10818 if (problem.backward()) stub();
10819
10820 auto Ta = problem.Ta, Tb = problem.Tb;
10821
10822 bool oremaskA = remaskA && (state.effCoopA == CoopSplit::K);
10823 bool oremaskB = remaskB && (state.effCoopB == CoopSplit::K);
10824 bool noshareRemask = (oremaskA || oremaskB)
10825 || (remaskA && remaskB && Ta.size() != Tb.size());
10826 int aRemaskLen = state.ka_slm;
10827 int bRemaskLen = state.kb_slm;
10828
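    // With a k-split cooperative copy, each thread's data begins at a
    //  different k position, so remasking needs a per-thread k offset.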
10829 Subregister offK_A, offK_B;
10830 if (oremaskA) {
10831 offK_A = state.ra.alloc_sub<uint32_t>();
10832 mulConstant(1, offK_A, state.lidN, state.ka_slm);
10833 }
10834
10835 if (oremaskB) {
10836 offK_B = state.ra.alloc_sub<uint32_t>();
10837 mulConstant(1, offK_B, state.lidM, state.kb_slm);
10838 }
10839
10840 if (!noshareRemask && remaskA && remaskB)
10841 aRemaskLen = bRemaskLen = std::max(aRemaskLen, bRemaskLen);
10842
10843 if (remaskA) {
10844 setupTeardownRemask(Ta, 1, true, aRemaskLen, state.K, strategy, state,
10845 kOffset, offK_A);
10846 remaskLayout(Ta, 1, true, state.Ao_layout, Ao_regs, strategy, state);
10847 if (noshareRemask || !remaskB)
10848 setupTeardownRemask(Ta, 1, false, aRemaskLen, state.K, strategy,
10849 state, kOffset, offK_A);
10850 }
10851
10852 if (remaskB) {
10853 if (noshareRemask || !remaskA)
10854 setupTeardownRemask(Tb, 1, true, bRemaskLen, state.K, strategy,
10855 state, kOffset, offK_B);
10856 remaskLayout(Tb, 1, false, state.Bo_layout, Bo_regs, strategy, state);
10857 setupTeardownRemask(Tb, 1, false, bRemaskLen, state.K, strategy, state,
10858 kOffset, offK_B);
10859 }
10860}
10861
10862// Calculate kSLMA/kSLMB -- countdown variables for SLM copies.
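// In the forward direction, kSLM = K - (lid / krep) * kgran.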
10863template <HW hw>
10864void gemm_kernel_generator_t<hw>::gemmCalcKSLM(const Subregister &kSLM,
10865 const Subregister &lid, int kgran, int kdiv, int krep,
10866 const GEMMProblem &problem, const GEMMStrategy &strategy,
10867 GEMMState &state) {
10868 if (kdiv == 1)
10869 mov(1, kSLM, state.K);
10870 else {
10871 auto modLID = lid;
10872 if (krep > 1) {
10873 if (!is_zero_or_pow2(krep)) stub();
10874 modLID = state.ra.alloc_sub<uint16_t>();
10875 shr(1, modLID, lid, log2(krep));
10876 }
10877 if (!problem.backward())
10878 emad(1, kSLM, state.K.w(), -modLID, kgran, strategy, state);
10879 else {
10880 emad(1, kSLM, strategy.unrollKSLM - kgran, -modLID, kgran, strategy,
10881 state);
10882 add(1, kSLM, state.K, -kSLM);
10883 }
10884 if (krep > 1) state.ra.safeRelease(modLID);
10885 }
10886}
10887
10888// Calculate barrier count for a k loop.
10889template <HW hw>
10890void gemm_kernel_generator_t<hw>::gemmCalcKLoopBarrierCount(Subregister &count,
10891 const Subregister &k, int cooldown, const GEMMProblem &problem,
10892 const GEMMStrategy &strategy, GEMMState &state) {
10893 int barrierFreq = strategy.barrierFreq;
10894 int unrollK = strategy.unroll[LoopK];
10895 int unrollKSLM = strategy.unrollKSLM;
10896
10897 if (count.isInvalid()) count = state.ra.alloc_sub<uint32_t>();
10898
10899 if (barrierFreq > 0) {
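        // count = max(0, k + barrierFreq - cooldown - unrollK) / barrierFreq;
        //  e.g. barrierFreq = 256, unrollK = 32, cooldown = 0 gives
        //  count = (k + 224) >> 8. With split barriers, one extra barrier is
        //  counted (predicated on k >= cooldown when a cooldown is in effect).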
10900 if (!is_zero_or_pow2(barrierFreq)) stub();
10901
10902 if (strategy.splitBarrier && cooldown > 0)
10903 cmp(1 | ge | state.flagAP, k, cooldown);
10904 add(1 | sat, count, k, barrierFreq - cooldown - unrollK);
10905 shr(1, count, count, uint16_t(log2(barrierFreq)));
10906 if (strategy.splitBarrier) {
10907 (cooldown > 0) ? add(1 | state.flagAP, count, count, 1)
10908 : add(1, count, count, 1);
10909 }
10910 } else if (strategy.slmBuffers > 0) {
10911 if (!is_zero_or_pow2(unrollKSLM)) stub();
10912
10913 if (strategy.slmBuffers == 1) {
10914 add(1 | sat, count, k, unrollKSLM - 1);
10915 if (unrollKSLM == 2)
10916 and_(1, count, count, ~uint32_t(1));
10917 else {
10918 shr(1, count, count, uint16_t(log2(unrollKSLM)));
10919 shl(1, count, count, 1);
10920 }
10921 } else {
10922 add(1 | sat, count, k, unrollKSLM - 1);
10923 shr(1, count, count, uint16_t(log2(unrollKSLM)));
10924 }
10925 } else
10926 mov(1, count, 0);
10927}
10928
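// Maximum number of extra barriers that the k remainder loop may need
//  (only nonzero with double-buffered SLM).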
10929int maxExtraKLoopRemBarriers(const GEMMStrategy &strategy) {
10930 if (strategy.slmBuffers == 2)
10931 return div_up(strategy.unroll[LoopK], strategy.unrollKSLM);
10932 return 0;
10933}
10934
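// Clone the single-k-slice Ai/Bi layout across all kx_slm slices, offsetting
//  each clone's registers and k position, and enlarge the backing register
//  ranges if the combined layout needs more registers.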
10935static void makeAiBiKCloneLayout(HW hw, bool isA,
10936 vector<RegisterBlock> &Xi_layout,
10937 vector<vector<RegisterBlock>> &Xi_layoutK,
10938 vector<GRFMultirange> &Xi_regsRem, int kx_slm,
10939 const GEMMStrategy &strategy, GEMMState &state) {
10940 auto regCountK = getRegCount(Xi_layoutK[0]);
10941 auto regCount = regCountK * kx_slm;
10942 auto offsetK = isA ? &RegisterBlock::offsetC : &RegisterBlock::offsetR;
10943
10944 Xi_layout = Xi_layoutK[0];
10945
10946 for (int h1 = 1; h1 < kx_slm; h1++) {
10947 Xi_layoutK[h1] = Xi_layoutK[h1 - 1];
10948 for (auto &block : Xi_layoutK[h1]) {
10949 block.offsetBytes += regCountK * GRF::bytes(hw);
10950
10951 auto oblock = block;
10952 oblock.*offsetK += h1;
10953 Xi_layout.push_back(std::move(oblock));
10954 }
10955 }
10956
10957 int extraRegs = regCount - Xi_regsRem[0].getLen();
10958 if (extraRegs > 0) {
10959 for (int q = 0; q < strategy.slmCopies; q++)
10960 Xi_regsRem[q].append(state.ra.alloc_range(extraRegs));
10961 }
10962}
10963
10964// Prepare for inner loop. Returns true on success.
10965template <HW hw>
10966bool gemm_kernel_generator_t<hw>::kLoopSetup(const GEMMProblem &problem,
10967 const GEMMStrategy &strategy, GEMMState &state) {
10968 auto Ta = problem.Ta, Tb = problem.Tb;
10969 auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext;
10970 auto Ta_load = state.Ta_load, Tb_load = state.Tb_load;
10971
10972 auto minOPCount = minOuterProductCount(hw, problem, strategy);
10973 auto unrollM = strategy.unroll[LoopM];
10974 auto unrollN = strategy.unroll[LoopN];
10975
10976 state.barrierReady = false;
10977
10978 // Get A/B named barrier IDs.
10979 auto &barrierHeaderM = state.barrierHeaderM;
10980 auto &barrierHeaderN = state.barrierHeaderN;
10981 auto &barrierM = state.barrierM;
10982 auto &barrierN = state.barrierN;
10983 bool nbM = (strategy.slmA || strategy.barrierFreq)
10984 && strategy.namedBarriers[LoopM];
10985 bool nbN = (strategy.slmB || strategy.barrierFreq)
10986 && strategy.namedBarriers[LoopN];
10987
10988 if (nbM) {
10989 barrierHeaderM = state.ra.alloc();
10990
10991 // Get fN.0 subregister for use with sync.bar.
10992 barrierM = state.raVFlag.alloc(2); // TODO: unlock these flag registers.
10993 state.raVFlag.release(FlagRegister {barrierM.getARFBase(), 1});
10994 barrierM = FlagRegister {barrierM.getARFBase(), 0};
10995
10996 if (!is_zero_or_pow2(strategy.wg[LoopM])
10997 || !is_zero_or_pow2(strategy.namedBarriers[LoopM]))
10998 stub();
10999 shr(1, barrierHeaderM.uw(4), state.lidM,
11000 log2(strategy.wg[LoopM]) - log2(strategy.namedBarriers[LoopM]));
11001 }
11002 if (nbN) {
11003 barrierHeaderN = state.ra.alloc();
11004 barrierN = state.raVFlag.alloc(2);
11005 state.raVFlag.release(FlagRegister {barrierN.getARFBase(), 1});
11006 barrierN = FlagRegister {barrierN.getARFBase(), 0};
11007
11008 if (!is_zero_or_pow2(strategy.wg[LoopN])
11009 || !is_zero_or_pow2(strategy.namedBarriers[LoopN]))
11010 stub();
11011 shr(1, barrierHeaderN.uw(4), state.lidN,
11012 log2(strategy.wg[LoopN]) - log2(strategy.namedBarriers[LoopN]));
11013 }
11014 if (nbM) {
11015 int threadsPerMBar = strategy.wg[LoopM] * strategy.wg[LoopN]
11016 / strategy.namedBarriers[LoopM];
11017 mov(1, barrierHeaderM.uw(5), threadsPerMBar | (threadsPerMBar << 8));
11018 }
11019 if (nbN) {
11020 int threadsPerNBar = strategy.wg[LoopM] * strategy.wg[LoopN]
11021 / strategy.namedBarriers[LoopN];
11022 mov(1, barrierHeaderN.uw(5), threadsPerNBar | (threadsPerNBar << 8));
11023 }
11024 if (nbM && nbN)
11025 add(1, barrierHeaderN.uw(4), barrierHeaderN.uw(4),
11026 strategy.namedBarriers[LoopM]);
11027 if (nbM) mov(1, barrierM, barrierHeaderM.uw(4));
11028 if (nbN) mov(1, barrierN, barrierHeaderN.uw(4));
11029
11030 // Get tokens for barriers/fences.
11031 for (int q = 0; q < 2; q++) {
11032 state.tokenBarrierFence[q] = -1;
11033 state.modBarrierFence[q] = InstructionModifier {};
11034 }
11035
11036 if (hw >= HW::Gen12LP) {
11037 if (strategy.needsBarrier())
11038 state.tokenBarrierFence[0] = state.tokenAllocator.tryAlloc();
11039 if (nbM && nbN)
11040 state.tokenBarrierFence[1] = state.tokenAllocator.tryAlloc();
11041 for (int q = 0; q < 2; q++)
11042 if (state.tokenBarrierFence[q] >= 0)
11043 state.modBarrierFence[q] = SBID(state.tokenBarrierFence[q]);
11044 }
11045
11046 // Remainder load preparations.
11047 auto &ka_loadRem = state.ka_loadRem, &kb_loadRem = state.kb_loadRem;
11048 ka_loadRem = 1, kb_loadRem = 1;
11049
11050 // For packed layouts, extend remainder loads to encompass a full logical block.
11051 int ignore;
11052 getGranularities(problem.A, ignore, ka_loadRem);
11053 getGranularities(problem.B, kb_loadRem, ignore);
11054
11055 // With 2D block loads, extend k unroll to at least a full block (array).
11056 bool a2D = isBlock2D(strategy.A.accessType);
11057 bool b2D = isBlock2D(strategy.B.accessType);
11058 bool ai2D = strategy.slmA && isBlock2D(state.Ai_strategy.accessType);
11059 bool bi2D = strategy.slmB && isBlock2D(state.Bi_strategy.accessType);
11060 if (a2D || ai2D) {
11061 ka_loadRem = state.A_layout[0].nc;
11062 if (!isColMajor(problem.A.layout))
11063 ka_loadRem *= state.A_layout[0].count;
11064 }
11065 if (b2D || bi2D) {
11066 kb_loadRem = state.B_layout[0].nr;
11067 if (isColMajor(problem.B.layout)) kb_loadRem *= state.B_layout[0].count;
11068 }
11069
11070 // Fragment the A, B layouts into smaller blocks (usually 1 row/column) for remainder loads.
11071 if (!getSubblocks(Ta_load, state.A_layoutRem, state.A_addrsRem,
11072 state.A_layout, state.A_addrs, true, 0, ka_loadRem,
11073 strategy.A.padded, problem.A, strategy.A))
11074 return false;
11075 if (!getSubblocks(Tb_load, state.B_layoutRem, state.B_addrsRem,
11076 state.B_layout, state.B_addrs, false, 0, kb_loadRem,
11077 strategy.B.padded, problem.B, strategy.B))
11078 return false;
11079
11080 // Add k masking.
11081 if (a2D && (ka_loadRem > 1))
11082 addMasking(
11083 Ta_load, state.A_layoutRem, false, true, problem.A, strategy.A);
11084 if (b2D && (kb_loadRem > 1))
11085 addMasking(
11086 Tb_load, state.B_layoutRem, true, false, problem.B, strategy.B);
11087
11088 // Ai/Bi remainders.
11089 auto &Ai_layoutRem = state.Ai_layoutRem, &Bi_layoutRem = state.Bi_layoutRem;
11090 auto &Ai_layoutK = state.Ai_layoutK, &Bi_layoutK = state.Bi_layoutK;
11091 auto &Ai_addrsRem = state.Ai_addrsRem, &Bi_addrsRem = state.Bi_addrsRem;
11092 auto &Ai_addrsK = state.Ai_addrsK, &Bi_addrsK = state.Bi_addrsK;
11093 auto &Ai_regsRem = state.Ai_regsRem, &Bi_regsRem = state.Bi_regsRem;
11094 auto &Ao_regsRem = state.Ao_regsRem, &Bo_regsRem = state.Bo_regsRem;
11095 auto &Ai_hasKRem = state.Ai_hasKRem, &Bi_hasKRem = state.Bi_hasKRem;
11096 auto &Ai_lateKRem = state.Ai_lateKRem, &Bi_lateKRem = state.Bi_lateKRem;
11097 auto &Ai_remIncrCopy = state.Ai_remIncrCopy,
11098 &Bi_remIncrCopy = state.Bi_remIncrCopy;
11099 auto &Ai_incrementalRem = state.Ai_incrementalRem,
11100 &Bi_incrementalRem = state.Bi_incrementalRem;
11101 auto &aioShareRem = state.aioShareRem, &bioShareRem = state.bioShareRem;
11102 int ka_slm = state.ka_slm, kb_slm = state.kb_slm;
11103
11104 Ai_layoutRem = state.Ai_layout;
11105 Bi_layoutRem = state.Bi_layout;
11106 Ai_addrsRem = state.Ai_addrs;
11107 Bi_addrsRem = state.Bi_addrs;
11108 Ai_regsRem = state.Ai_regs;
11109 Bi_regsRem = state.Bi_regs;
11110 Ao_regsRem = state.Ao_regs;
11111 Bo_regsRem = state.Bo_regs;
11112
11113 Ai_hasKRem = Ai_lateKRem = false;
11114 Bi_hasKRem = Bi_lateKRem = false;
11115 Ai_remIncrCopy = Bi_remIncrCopy = false;
11116
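    // For 2D block loads into SLM that span multiple k, handle the k remainder
    //  by adding k masking directly to the Ai/Bi block layouts.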
11117 if (ai2D && (ka_loadRem > 1) && state.Ai_strategy.address2D) {
11118 Ai_hasKRem = true;
11119 addMasking(Ta_ext, state.Ai_layoutRem, false, true, state.Ai,
11120 state.Ai_strategy);
11121 }
11122
11123 if (bi2D && (kb_loadRem > 1) && state.Bi_strategy.address2D) {
11124 Bi_hasKRem = true;
11125 addMasking(Tb_ext, state.Bi_layoutRem, true, false, state.Bi,
11126 state.Bi_strategy);
11127 }
11128
11129 if (strategy.slmA && !Ai_hasKRem)
11130 Ai_lateKRem |= !isRegisterColMajor(Ta_ext, state.Ai, state.Ai_strategy);
11131 if (strategy.slmB && !Bi_hasKRem)
11132 Bi_lateKRem |= isRegisterColMajor(Tb_ext, state.Bi, state.Bi_strategy);
11133
11134 Ai_incrementalRem
11135 = strategy.slmA && !state.Ai_hasKRem && !state.Ai_lateKRem;
11136 Bi_incrementalRem
11137 = strategy.slmB && !state.Bi_hasKRem && !state.Bi_lateKRem;
11138 aioShareRem = state.aioShare;
11139 bioShareRem = state.bioShare;
11140
11141 if (Ai_incrementalRem) {
11142 // Prepare to split Ai layout in k dimension. If it's not possible to do in-place, then
11143 // either redo the layout or copy Ai->Ao incrementally.
11144 Ai_layoutK.resize(ka_slm);
11145 Ai_addrsK.resize(ka_slm);
11146 for (int h = 0; h < ka_slm; h++) {
11147 bool success = false;
11148
11149 if (h < int(Ai_addrsK.size())) {
11150 success = getSubblocks(Ta_ext, Ai_layoutK[h], Ai_addrsK[h],
11151 Ai_layoutRem, state.Ai_addrs, true, h, h + 1,
11152 state.Ai_strategy.padded, state.Ai, state.Ai_strategy);
11153 }
11154
11155 if (!success && h == 0) stub();
11156
11157 if (!success) {
11158 // Maybe the subblock is OK, but we didn't get an address register. Try again without
11159 // asking for address registers.
11160 Ai_addrsK.resize(1);
11161 success = getSubblocks(Ta_ext, Ai_layoutK[h], Ai_layoutRem,
11162 true, h, h + 1, state.Ai_strategy.padded, state.Ai,
11163 state.Ai_strategy);
11164 }
11165
11166 if (!success) {
11167 // Can't make a subblock. Will need a new layout or an incremental copy.
11168 if (strategy.slmUseIncrCopy) {
11169 Ai_remIncrCopy = true;
11170 Ai_layoutK.resize(1);
11171 } else
11172 makeAiBiKCloneLayout(hw, true, Ai_layoutRem, Ai_layoutK,
11173 Ai_regsRem, ka_slm, strategy, state);
11174
11175 aioShareRem = false;
11176 if (state.aioShare || state.aoReuseA)
11177 Ao_regsRem = state.ra.alloc_range(
11178 getRegCount(state.Ao_layout));
11179 break;
11180 }
11181 }
11182 }
11183
11184 if (Bi_incrementalRem) {
11185 Bi_layoutK.resize(kb_slm);
11186 Bi_addrsK.resize(kb_slm);
11187 for (int h = 0; h < kb_slm; h++) {
11188 bool success = false;
11189
11190 if (h < int(Bi_addrsK.size())) {
11191 success = getSubblocks(Tb_ext, Bi_layoutK[h], Bi_addrsK[h],
11192 Bi_layoutRem, state.Bi_addrs, false, h, h + 1,
11193 state.Bi_strategy.padded, state.Bi, state.Bi_strategy);
11194 }
11195
11196 if (!success && h == 0) stub();
11197
11198 if (!success) {
11199 Bi_addrsK.resize(1);
11200 success = getSubblocks(Tb_ext, Bi_layoutK[h], Bi_layoutRem,
11201 false, h, h + 1, state.Bi_strategy.padded, state.Bi,
11202 state.Bi_strategy);
11203 }
11204
11205 if (!success) {
11206 if (strategy.slmUseIncrCopy) {
11207 Bi_remIncrCopy = true;
11208 Bi_layoutK.resize(1);
11209 } else
11210 makeAiBiKCloneLayout(hw, false, Bi_layoutRem, Bi_layoutK,
11211 Bi_regsRem, kb_slm, strategy, state);
11212
11213 bioShareRem = false;
11214 if (state.bioShare || state.boReuseB)
11215 Bo_regsRem = state.ra.alloc_range(
11216 getRegCount(state.Bo_layout));
11217 break;
11218 }
11219 }
11220 }
11221
11222 // Allocate repack registers if we need to assemble multiple loads for
11223 // each outer product calculation.
11224 // TODO: allow allocation to overlap unneeded A/B registers.
11225 auto &repackARem = state.repackARem, &repackBRem = state.repackBRem;
11226 auto &ka_repackRem = state.ka_repackRem, &kb_repackRem = state.kb_repackRem;
11227
11228 repackARem = state.repackA;
11229 repackBRem = state.repackB;
11230 ka_repackRem = state.repackA ? ka_loadRem : 0;
11231 kb_repackRem = state.repackB ? kb_loadRem : 0;
11232 if (minOPCount > 1) {
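        // Each outer product consumes at least minOPCount k values; if a
        //  remainder load supplies fewer, repack into a layout that holds
        //  minOPCount k values.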
11233 int crosspackA, crosspackB, tileM_A, tileK_A, tileK_B, tileN_B;
11234 std::tie(crosspackA, crosspackB)
11235 = targetKernelCrosspack(hw, problem, strategy);
11236 std::tie(tileM_A, tileK_A, tileK_B, tileN_B)
11237 = targetKernelTiling(hw, problem, strategy);
11238
11239 if (ka_loadRem < minOPCount) {
11240 ka_repackRem = minOPCount;
11241 if (!repackARem) {
11242 makeUnbackedRegLayout(Ta, state.Ar_layout, unrollM,
11243 ka_repackRem, isLayoutColMajor(state.A_layout),
11244 crosspackA, tileM_A, tileK_A);
11245 state.Ar_regs
11246 = state.ra.alloc_range(getRegCount(state.Ar_layout),
11247 getHint(HintType::A0, strategy));
11248 repackARem = true;
11249 }
11250 }
11251 if (kb_loadRem < minOPCount) {
11252 kb_repackRem = minOPCount;
11253 if (!repackBRem) {
11254 makeUnbackedRegLayout(Tb, state.Br_layout, kb_repackRem,
11255 unrollN, isLayoutColMajor(state.B_layout), crosspackB,
11256 tileK_B, tileN_B);
11257 state.Br_regs
11258 = state.ra.alloc_range(getRegCount(state.Br_layout),
11259 getHint(HintType::B0, strategy));
11260 repackBRem = true;
11261 }
11262 }
11263 }
11264
11265 state.remActiveA = state.remActiveB = state.remActiveSLM = false;
11266 state.slmRemaskA = state.slmRemaskB = false;
11267 state.firstKLoopSegment = true;
11268
11269 return true;
11270}
11271
11272// Set up A/B addresses again in order to prepare for a non-remainder k loop.
11273// Optionally offset addresses in the k dimension.
11274template <HW hw>
11275template <typename I>
11276void gemm_kernel_generator_t<hw>::kLoopReset(const I &kOffset,
11277 const GEMMProblem &problem, const GEMMStrategy &strategy,
11278 GEMMState &state) {
11279 auto Ta = problem.Ta, Ta_ext = problem.Ta_ext, Ta_load = state.Ta_load;
11280 auto Tb = problem.Tb, Tb_ext = problem.Tb_ext, Tb_load = state.Tb_load;
11281 auto globalA = strategy.slmA ? state.Ai : problem.A;
11282 auto globalB = strategy.slmB ? state.Bi : problem.B;
11283 auto globalAStrategy = strategy.slmA ? state.Ai_strategy : strategy.A;
11284 auto globalBStrategy = strategy.slmB ? state.Bi_strategy : strategy.B;
11285 auto globalAParams = strategy.slmA ? state.Ai_params : state.A_params;
11286 auto globalBParams = strategy.slmB ? state.Bi_params : state.B_params;
11287
11288 if (kOffset.isValid()) {
11289 auto &A_offC = globalAParams.offC;
11290 auto A_offCOrig = A_offC;
11291 if (globalAStrategy.address2D) {
11292 if (A_offC == state.h0)
11293 A_offC = state.ra.alloc_sub<int32_t>(
11294 getHint(HintType::LongTerm, strategy));
11295 A_offCOrig.isValid() ? add(1, A_offC, A_offCOrig, kOffset)
11296 : mov(1, A_offC, kOffset);
11297 } else
11298 gemmOffsetAk(kOffset, strategy.slmA ? state.effAi : state.effA,
11299 globalA, problem, strategy, state);
11300
11301 if (strategy.prefetchA) {
11302 if (strategy.A_prefetch.address2D) {
11303 auto &Ap_offC = state.Ap_params.offC;
11304 auto Ap_offCOrig = Ap_offC;
11305 if (Ap_offC == A_offCOrig)
11306 Ap_offC = A_offC;
11307 else {
11308 if (Ap_offC == state.h0)
11309 Ap_offC = state.ra.alloc_sub<int32_t>(
11310 getHint(HintType::LongTerm, strategy));
11311 Ap_offCOrig.isValid()
11312 ? add(1, Ap_offC, Ap_offCOrig, kOffset)
11313 : mov(1, Ap_offC, kOffset);
11314 }
11315 } else if (state.effAp != state.effA)
11316 gemmOffsetAk(kOffset, state.effAp, globalA, problem, strategy,
11317 state);
11318 }
11319
11320 auto &B_offR = globalBParams.offR;
11321 auto B_offROrig = B_offR;
11322 if (globalBStrategy.address2D) {
11323 if (B_offR == A_offCOrig)
11324 B_offR = A_offC;
11325 else {
11326 if (B_offR == state.h0)
11327 B_offR = state.ra.alloc_sub<int32_t>(
11328 getHint(HintType::LongTerm, strategy));
11329 B_offROrig.isValid() ? add(1, B_offR, B_offROrig, kOffset)
11330 : mov(1, B_offR, kOffset);
11331 }
11332 } else
            gemmOffsetBk(kOffset, strategy.slmB ? state.effBi : state.effB,
                    globalB, problem, strategy, state);
11335
11336 if (strategy.prefetchB) {
11337 if (strategy.B_prefetch.address2D) {
11338 auto &Bp_offR = state.Bp_params.offR;
11339 auto Bp_offROrig = Bp_offR;
11340 if (Bp_offR == B_offROrig)
11341 Bp_offR = B_offR;
11342 else {
11343 if (Bp_offR == state.h0)
11344 Bp_offR = state.ra.alloc_sub<int32_t>(
11345 getHint(HintType::LongTerm, strategy));
11346 Bp_offROrig.isValid()
11347 ? add(1, Bp_offR, Bp_offROrig, kOffset)
11348 : mov(1, Bp_offR, kOffset);
11349 }
11350 } else if (state.effBp != state.effB)
11351 gemmOffsetBk(kOffset, state.effBp, globalB, problem, strategy,
11352 state);
11353 }
11354 }
11355
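    // Rebuild all addresses (A/B prefetch, SLM copy sources/destinations, and
    //  main A/B loads) from the updated parameters.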
11356 gemmCacheLDABMultiples(problem, strategy, state);
11357 setupAddr(Ta_ext, state.Ap_addrs, state.effAp, state.Ap_layout,
11358 state.inputs.lda, globalA, strategy.A_prefetch, strategy, state,
11359 state.Ap_params, state.ldaMultiples);
11360 setupAddr(Tb_ext, state.Bp_addrs, state.effBp, state.Bp_layout,
11361 state.inputs.ldb, globalB, strategy.B_prefetch, strategy, state,
11362 state.Bp_params, state.ldbMultiples);
11363 setupAddr(Ta_ext, state.Ai_addrs, state.effAi, state.Ai_layout,
11364 state.inputs.lda, state.Ai, state.Ai_strategy, strategy, state,
11365 state.Ai_params, state.ldaMultiples);
11366 setupAddr(Tb_ext, state.Bi_addrs, state.effBi, state.Bi_layout,
11367 state.inputs.ldb, state.Bi, state.Bi_strategy, strategy, state,
11368 state.Bi_params, state.ldbMultiples);
11369 setupAddr(Ta, state.Ao_addrs, state.effAo, state.Ao_layout, Subregister(),
11370 state.Ao, state.Ao_strategy, strategy, state);
11371 setupAddr(Tb, state.Bo_addrs, state.effBo, state.Bo_layout, Subregister(),
11372 state.Bo, state.Bo_strategy, strategy, state);
11373 setupAddr(Ta_load, state.A_addrs, state.effA, state.A_layout,
11374 state.inputs.lda, problem.A, strategy.A, strategy, state,
11375 state.A_params, state.ldaMultiples);
11376 setupAddr(Tb_load, state.B_addrs, state.effB, state.B_layout,
11377 state.inputs.ldb, problem.B, strategy.B, strategy, state,
11378 state.B_params, state.ldbMultiples);
11379 releaseLDMultiples(state.ldaMultiples, state);
11380 releaseLDMultiples(state.ldbMultiples, state);
11381 releaseIndexVec(state);
11382
11383 gemmCalcIncrements(problem, strategy, state);
11384}
11385
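// Allocate the GRF used for barrier message headers, if not already allocated.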
11386template <HW hw>
11387void gemm_kernel_generator_t<hw>::kLoopAllocBarrierHeader(GEMMState &state) {
11388 if (state.barrierHeader.isInvalid()) {
11389 state.barrierHeader = state.ra.alloc();
11390 state.barrierReady = false;
11391 }
11392}
11393
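// Return the barrier header GRF, initializing it from r0 on first use.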
11394template <HW hw>
11395GRF gemm_kernel_generator_t<hw>::kLoopGetBarrierHeader(GEMMState &state) {
11396 kLoopAllocBarrierHeader(state);
11397 if (!state.barrierReady) {
11398 if (state.r0_info.isARF()) stub();
11399 barrierheader(state.barrierHeader, GRF {state.r0_info.getBase()});
11400 state.barrierReady = true;
11401 }
11402
11403 return state.barrierHeader;
11404}
11405
11406// Create one step of a sequence of inner loops for a GEMM-like kernel.
11407template <HW hw>
11408void gemm_kernel_generator_t<hw>::kLoop(KLoop type, const GEMMProblem &problem,
11409 GEMMStrategy &strategy, GEMMState &state) {
11410 auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc;
11411 auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext;
11412 auto Ta_load = state.Ta_load, Tb_load = state.Tb_load;
11413
11414 bool cLoadAhead = strategy.cLoadAhead;
11415 auto opCountMain = outerProductCount(hw, problem, strategy);
11416 auto minOPCount = minOuterProductCount(hw, problem, strategy);
11417 auto opCountRem = minOPCount;
11418
11419 auto A_copies = strategy.A_copies;
11420 auto B_copies = strategy.B_copies;
11421 auto slmCopies = strategy.slmCopies;
11422 auto slmBuffers = strategy.slmBuffers;
11423 auto ka_loadMain = strategy.ka_load, ka_loadRem = state.ka_loadRem;
11424 auto kb_loadMain = strategy.kb_load, kb_loadRem = state.kb_loadRem;
11425 auto ka_pfStride = strategy.ka_pfStride;
11426 auto kb_pfStride = strategy.kb_pfStride;
11427 bool slmA = strategy.slmA;
11428 bool slmB = strategy.slmB;
11429 bool slmASums = state.slmASums;
11430 bool slmBSums = state.slmBSums;
11431 bool a2D = isBlock2D(strategy.A.accessType);
11432 bool b2D = isBlock2D(strategy.B.accessType);
11433 bool ai2D = strategy.slmA && isBlock2D(state.Ai_strategy.accessType);
11434 bool bi2D = strategy.slmB && isBlock2D(state.Bi_strategy.accessType);
11435 auto unrollM = strategy.unroll[LoopM];
11436 auto unrollN = strategy.unroll[LoopN];
11437 auto unrollK = strategy.unroll[LoopK];
11438 auto unrollKSLM = strategy.unrollKSLM;
11439 auto ka_slm = state.ka_slm;
11440 auto kb_slm = state.kb_slm;
11441 bool calcASums = problem.needsASums();
11442 bool calcBSums = problem.needsBSums();
11443 bool readA = true, readB = true;
11444
11445 bool Ai_incrementalRem = state.Ai_incrementalRem;
11446 bool Bi_incrementalRem = state.Bi_incrementalRem;
11447 bool Ai_remIncrCopy = state.Ai_remIncrCopy;
11448 bool Bi_remIncrCopy = state.Bi_remIncrCopy;
11449 bool Ai_lateKRem = state.Ai_lateKRem;
11450 bool Bi_lateKRem = state.Bi_lateKRem;
11451
11452 bool &remActiveA = state.remActiveA, &remActiveB = state.remActiveB;
11453 bool &remActiveSLM = state.remActiveSLM;
11454 auto &kMasksSLM = state.kMasksSLM;
11455 bool &slmRemaskA = state.slmRemaskA, &slmRemaskB = state.slmRemaskB;
11456 bool lateKLoopCheck = state.lateKLoopCheck;
11457
11458 bool needBarrier = (slmA || slmB || strategy.barrierFreq > 0);
11459 bool nbM = (strategy.slmA || strategy.barrierFreq)
11460 && strategy.namedBarriers[LoopM];
11461 bool nbN = (strategy.slmB || strategy.barrierFreq)
11462 && strategy.namedBarriers[LoopN];
11463
11464 bool needSLMReset = false;
11465
11466 int curPhase;
11467 int lastThresh = 0;
11468
11469 // Get r0 information where needed.
11470 GRF r0_info;
11471 if (needBarrier) {
11472 if (state.r0_info.isARF()) stub();
11473 r0_info = GRF {state.r0_info.getBase()};
11474 }
11475
11476 // Unified barrier and SLM fence handling for k loop.
11477 auto &modBarrierFence = state.modBarrierFence;
11478 auto &barrierHeader = state.barrierHeader;
11479 auto &barrierReady = state.barrierReady;
11480
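// Fences and barriers need a scratch GRF: prefer a fresh allocation, but fall
// back to reusing the barrier header (marking it not ready) when the register
// allocator has no free registers.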
11481 auto getFenceTemp = [&]() {
11482 auto temp = state.ra.try_alloc();
11483 if (temp.isValid()) return temp;
11484 if (barrierHeader.isValid()) {
11485 barrierReady = false;
11486 return barrierHeader;
11487 }
11488 throw ngen::out_of_registers_exception();
11489 };
11490
11491 auto releaseFenceTemp = [&](GRF temp) {
11492 if (temp != barrierHeader) state.ra.release(temp);
11493 };
11494
11495 GRF slmFenceTemp;
11496 auto slmFenceIssue = [&]() {
11497 if (hw >= HW::Gen11) {
11498 slmFenceTemp = getFenceTemp();
11499 slmfence(modBarrierFence[0], slmFenceTemp, r0_info);
11500 releaseFenceTemp(slmFenceTemp);
11501 }
11502 };
11503
11504 auto slmFenceWait = [&]() {
11505 if (hw >= HW::Gen12LP)
11506 wrdep(slmFenceTemp);
11507 else if (hw >= HW::Gen11)
11508 mov<uint32_t>(8, null, slmFenceTemp);
11509 };
11510
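// kLoopBarrier emits one k-loop barrier, optionally preceded by an SLM fence;
// it can emit the full barrier or only its signal/wait half, and uses the
// named M/N barriers when the strategy enables them.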
11511 enum class KBarrierType { Normal, Signal, Wait };
11512 auto kLoopBarrier = [&](bool withSLMFence,
11513 KBarrierType type = KBarrierType::Normal) {
11514 withSLMFence &= (hw >= HW::Gen11); // No SLM fences needed before Gen11.
11515
11516 if (withSLMFence && type == KBarrierType::Wait) {
11517 auto temp = getFenceTemp();
11518 slmfence(modBarrierFence[0], temp, r0_info);
11519 (hw >= HW::Gen12LP) ? wrdep(temp) : mov<uint32_t>(8, null, temp);
11520 releaseFenceTemp(temp);
11521 }
11522
11523 if (!nbM && !nbN) {
11524 if (type != KBarrierType::Wait) {
11525 kLoopAllocBarrierHeader(state);
11526 auto temp = getFenceTemp();
11527 if (withSLMFence) {
11528 slmfence(modBarrierFence[0], temp, r0_info);
11529 (hw >= HW::Gen12LP) ? wrdep(temp)
11530 : mov<uint32_t>(8, null, temp);
11531 }
11532 auto header = kLoopGetBarrierHeader(state);
11533 barriermsg(modBarrierFence[0], header);
11534 releaseFenceTemp(temp);
11535 }
11536 if (type != KBarrierType::Signal) barrierwait();
11537 } else {
11538 if (type != KBarrierType::Wait) {
11539 if (withSLMFence) {
11540 auto temp = getFenceTemp();
11541 slmfence(temp, r0_info);
11542 wrdep(temp);
11543 releaseFenceTemp(temp);
11544 }
11545 if (nbM) barriermsg(modBarrierFence[0], state.barrierHeaderM);
11546 if (nbN)
11547 barriermsg(
11548 modBarrierFence[nbM ? 1 : 0], state.barrierHeaderN);
11549 }
11550 if (type != KBarrierType::Signal) {
11551 if (nbM) sync.bar(state.barrierM);
11552 if (nbN) sync.bar(state.barrierN);
11553 }
11554 }
11555 };
11556
11557 bool mustActivateRemainder = false;
11558
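// Switch A/B global loads between the main-loop and remainder paths:
// adjust subblock addresses, set address remainders for 2D block accesses,
// and recompute the lda_ka/ldb_kb increments for the remainder load sizes.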
11559 auto activateABRemainder = [&](bool active, bool doA, bool doB) {
11560 if (remActiveA == active) doA = false;
11561 if (remActiveB == active) doB = false;
11562 if (!active && ((doA && remActiveA) || (doB && remActiveB))) stub();
11563 if (!doA && !doB) return;
11564
11565 if (doA) remActiveA = active;
11566 if (doB) remActiveB = active;
11567
11568 // Adjust A/B/Ai/Bi addresses if needed.
11569 if (doA)
11570 adjustSubblockAddrs(Ta_load, state.A_layoutRem, state.A_addrsRem,
11571 state.A_layout, state.A_addrs, problem.A, strategy.A,
11572 strategy, state);
11573 if (doB)
11574 adjustSubblockAddrs(Tb_load, state.B_layoutRem, state.B_addrsRem,
11575 state.B_layout, state.B_addrs, problem.B, strategy.B,
11576 strategy, state);
11577
11578 if (doA && strategy.slmA && (state.effCoopA == CoopSplit::K) && !ai2D) {
11579 vector<RegisterBlock> tempLayout;
11580 vector<GRFRange> tempAddrs;
11581 if (!getSubblocks(Ta_ext, tempLayout, tempAddrs, state.Ai_layout,
11582 state.Ai_addrs, true, 0, 1, state.Ai_strategy.padded,
11583 state.Ai, state.Ai_strategy))
11584 stub();
11585 adjustSubblockAddrs(Ta_ext, tempLayout, tempAddrs, state.Ai_layout,
11586 state.Ai_addrs, state.Ai, state.Ai_strategy, strategy,
11587 state);
11588 }
11589 if (doB && strategy.slmB && (state.effCoopB == CoopSplit::K) && !bi2D) {
11590 vector<RegisterBlock> tempLayout;
11591 vector<GRFRange> tempAddrs;
11592 if (!getSubblocks(Tb_ext, tempLayout, tempAddrs, state.Bi_layout,
11593 state.Bi_addrs, false, 0, 1, state.Bi_strategy.padded,
11594 state.Bi, state.Bi_strategy))
11595 stub();
11596 adjustSubblockAddrs(Tb_ext, tempLayout, tempAddrs, state.Bi_layout,
11597 state.Bi_addrs, state.Bi, state.Bi_strategy, strategy,
11598 state);
11599 }
11600
11601 if (doA && a2D && (ka_loadRem > 1))
11602 setAddrRemainder(Ta_load, state.A_addrsRem, state.A_layoutRem,
11603 Subregister(), state.K, problem.A, strategy.A, strategy,
11604 state);
11605 if (doB && b2D && (kb_loadRem > 1))
11606 setAddrRemainder(Tb_load, state.B_addrsRem, state.B_layoutRem,
11607 state.K, Subregister(), problem.B, strategy.B, strategy,
11608 state);
11609
11610 // Recalculate lda_ka/ldb_kb if needed.
11611 gemmCalcIncrements(
11612 problem, strategy, state, ka_loadRem, kb_loadRem, doA, doB);
11613 };
11614
11615 Subregister kSLMStorage;
11616 Subregister kSLMA, kSLMB; // k remainders for k-split SLM loads
11617
11618 auto resetKSLM = [&]() {
11619 state.ra.safeRelease(kSLMStorage);
11620 kSLMA = kSLMB = invalid;
11621 };
11622
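// Switch SLM copies between the main-loop and remainder paths: compute
// per-thread k remainders for incremental copies, add k masking to the
// Ai/Bi layouts for late k remainder handling, and decide whether the SLM
// data needs remasking before use.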
11623 auto activateSLMRemainder = [&](bool active, int kOffset = 0) {
11624 // Calculate or recalculate SLM k remainders as needed.
11625 if (active && kSLMStorage.isInvalid()) {
11626 if (Ai_incrementalRem || Bi_incrementalRem)
11627 kSLMStorage = state.ra.alloc_sub<uint32_t>();
11628
11629 if (Ai_incrementalRem) {
11630 kSLMA = kSLMStorage.w(0);
11631 int kgran, kdiv, krep;
11632 switch (state.effCoopA) {
11633 case CoopSplit::MN:
11634 kgran = unrollKSLM;
11635 kdiv = 1;
11636 krep = strategy.wg[LoopN];
11637 break;
11638 case CoopSplit::K:
11639 kgran = ka_slm;
11640 kdiv = strategy.wg[LoopN];
11641 krep = 1;
11642 break;
11643 case CoopSplit::Linear:
11644 kgran = std::max(state.Ai.crosspack, state.Ai.tileC);
11645 kdiv = unrollKSLM / kgran;
11646 krep = strategy.wg[LoopN] / kdiv;
11647 break;
11648 default: stub();
11649 }
11650 gemmCalcKSLM(kSLMA, state.lidN, kgran, kdiv, krep, problem,
11651 strategy, state);
11652 }
11653
11654 if (Bi_incrementalRem) {
11655 kSLMB = kSLMStorage.w(1);
11656 int kgran, kdiv, krep;
11657 switch (state.effCoopB) {
11658 case CoopSplit::MN:
11659 kgran = unrollKSLM;
11660 kdiv = 1;
11661 krep = strategy.wg[LoopM];
11662 break;
11663 case CoopSplit::K:
11664 kgran = kb_slm;
11665 kdiv = strategy.wg[LoopM];
11666 krep = 1;
11667 break;
11668 case CoopSplit::Linear:
11669 kgran = std::max(state.Bi.crosspack, state.Bi.tileR);
11670 kdiv = unrollKSLM / kgran;
11671 krep = strategy.wg[LoopM] / kdiv;
11672 break;
11673 default: stub();
11674 }
11675 gemmCalcKSLM(kSLMB, state.lidM, kgran, kdiv, krep, problem,
11676 strategy, state);
11677 }
11678
11679 if ((Ai_incrementalRem || Bi_incrementalRem) && kOffset != 0)
11680 add(2, kSLMStorage.w()(1), kSLMStorage.w()(1), kOffset);
11681 }
11682
11683 // k mask information.
11684 Subregister rems[3]
11685 = {state.remainders[LoopM], state.remainders[LoopN], state.K};
11686 int offsets[3] = {0, 0, -kOffset};
11687
11688 // If not changing between main loop and remainder, update k masks as needed and return.
11689 if (remActiveSLM == active) {
11690 if (active) {
11691 state.wipeActiveVFlags();
11692 loadMasks(kMasksSLM, rems, offsets, strategy, state);
11693 }
11694 return;
11695 }
11696
11697 // Not possible to deactivate remainder path with late k remainder.
11698 if (!active && remActiveSLM && (Ai_lateKRem || Bi_lateKRem)) stub();
11699 remActiveSLM = active;
11700
11701 // Start using k masks if needed.
11702 if (Ai_lateKRem && !state.Ai_strategy.padded) {
11703 state.Ai_layoutRem = state.Ai_layout;
11704 state.Ai_addrsRem = state.Ai_addrs;
11705 addMasking(Ta_ext, state.Ai_layoutRem, state.Ai_addrsRem,
11706 state.inputs.lda, false, true, state.Ai, state.Ai_strategy,
11707 strategy, state, state.Ai_regCount);
11708 if (!assignMasks(state.Ai_layoutRem, LoopM, LoopK, kMasksSLM,
11709 strategy, state, true))
11710 stub();
11711 if (state.aioShare && state.Ao_regsRem.empty()
11712 && state.Ai_layoutRem[0].crosspack
11713 != state.Ai_layout[0].crosspack) {
11714 state.aioShareRem = false;
11715 state.Ao_regsRem
11716 = state.ra.alloc_range(getRegCount(state.Ao_layout));
11717 }
11718 }
11719 if (Bi_lateKRem && !state.Bi_strategy.padded) {
11720 state.Bi_layoutRem = state.Bi_layout;
11721 state.Bi_addrsRem = state.Bi_addrs;
11722 addMasking(Tb_ext, state.Bi_layoutRem, state.Bi_addrsRem,
11723 state.inputs.ldb, true, false, state.Bi, state.Bi_strategy,
11724 strategy, state, state.Bi_regCount);
11725 if (!assignMasks(state.Bi_layoutRem, LoopK, LoopN, kMasksSLM,
11726 strategy, state, true))
11727 stub();
11728 if (state.bioShare && state.Bo_regsRem.empty()
11729 && state.Bi_layoutRem[0].crosspack
11730 != state.Bi_layout[0].crosspack) {
11731 state.bioShareRem = false;
11732 state.Bo_regsRem
11733 = state.ra.alloc_range(getRegCount(state.Bo_layout));
11734 }
11735 }
11736
11737 if (problem.backward())
11738 for (auto &mask : kMasksSLM)
11739 mask.reverse(unrollKSLM);
11740
11741 if (!state.vflagStorage.isValid()) {
11742 bool needVFlags = false;
11743 for (const auto &mask : kMasksSLM)
11744 needVFlags |= state.raVFlag.isVirtual(mask.flag);
11745 if (needVFlags) allocVFlagStorage(strategy, state);
11746 }
11747 loadMasks(kMasksSLM, rems, offsets, strategy, state);
11748
11749 bool mayAccessAllK = (minOPCount > 1) || problem.sumA || problem.sumB;
11750 bool asIfMaskedAi = Ai_lateKRem && state.Ai_strategy.padded;
11751 bool asIfMaskedBi = Bi_lateKRem && state.Bi_strategy.padded;
11752 slmRemaskA = slmA && mayAccessAllK && !Ai_remIncrCopy
11753 && needsRemask(Ta_ext, true, state.Ai_layoutRem,
11754 state.Ai_strategy, asIfMaskedAi);
11755 slmRemaskB = slmB && mayAccessAllK && !Bi_remIncrCopy
11756 && needsRemask(Tb_ext, false, state.Bi_layoutRem,
11757 state.Bi_strategy, asIfMaskedBi);
11758 };
11759
11760 // Get state.K, the loop counter.
11761 // The caller may initialize state.K, in which case its value on entry is the loop count.
11762 // Otherwise, it is initialized from state.k.
11763 auto kInput = state.k;
11764 bool matchBarriers = (strategy.kParallelLocal && needBarrier);
11765 bool saveK = state.isNested || (problem.abOffset != ABOffset::None)
11766 || matchBarriers;
11767 bool incomingK = state.K.isValid();
11768
11769 if (!incomingK) state.K = saveK ? state.ra.alloc_sub<int32_t>() : kInput;
11770
11771 if (saveK && !incomingK) mov(1, state.K, kInput);
11772
11773 if (state.firstKLoopSegment) {
11774 // Zero out A/B sums if needed.
11775 if (calcASums) zeroMatrix(state.As_regs, strategy);
11776 if (calcBSums) zeroMatrix(state.Bs_regs, strategy);
11777
11778 // Zero out C, if not loading ahead of time.
11779 if (!cLoadAhead) {
11780 for (int i = 0; i < state.C_accCount; i += 2)
11781 mov<uint32_t>(2 * elementsPerGRF<uint32_t>(hw),
11782 AccumulatorRegister(i), uint16_t(0));
11783
11784 for (int buf = 0; buf < state.C_buffers; buf++)
11785 zeroMatrix(state.C_regs[buf], strategy);
11786 }
11787 }
11788
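// Work items are registered with the loop sequencer rather than emitted
// directly: ls.schedule(req, action) runs 'action' on iterations matching
// 'req'; for example,
//     every(ka_loadMain) | variants(A_copies) | lookahead(lookaheadALoad)
// issues each A load lookaheadALoad iterations ahead of the consuming
// iteration, cycling through A_copies register buffers. The interleaved code
// is emitted by ls.materialize() near the end of this function.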
11789 LoopSequencer ls;
11790 using namespace loop_sequencer;
11791
11792 int slmBufferLA = 0;
11793 switch (slmBuffers) {
11794 case 0:
11795 case 1: slmBufferLA = 0; break;
11796 case 2:
11797 case 3: slmBufferLA = 1; break;
11798 case 4: slmBufferLA = 2; break;
11799 default: stub();
11800 }
11801
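// Lookahead distances, in k iterations, between issuing each load and the
// iteration that consumes its data; multi-copy A/B loads and multi-buffered
// SLM copies pipeline correspondingly further ahead.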
11802 int lookaheadALoad = ka_loadMain * (A_copies - 1);
11803 int lookaheadBLoad = kb_loadMain * (B_copies - 1);
11804 int lookaheadALoadRem = ka_loadRem * (A_copies - 1);
11805 int lookaheadBLoadRem = kb_loadRem * (B_copies - 1);
11806 int lookaheadSLMLoad = unrollKSLM * (slmCopies - 1) + unrollKSLM - 1;
11807 int lookaheadSLMStore = unrollKSLM * slmBufferLA + 1;
11808
11809 if (slmA && slmB) {
11810 if (lookaheadALoad != lookaheadBLoad) stub();
11811 if (lookaheadALoadRem != lookaheadBLoadRem) stub();
11812 if (ka_loadMain != kb_loadMain && lookaheadALoad != lookaheadALoadRem)
11813 stub();
11814 }
11815
11816 int lookaheadSLMReload = slmA ? lookaheadALoad : lookaheadBLoad;
11817 int lookaheadSLMReloadRem = slmA ? lookaheadALoadRem : lookaheadBLoadRem;
11818 int durationSLMMainLoad = std::max(slmA * ka_loadMain, slmB * kb_loadMain);
11819
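// Per-iteration helpers: each selects the main-path or remainder-path
// layouts, registers, and parameters appropriate for iteration h.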
11820 auto A_remActive = [&](Iteration h) {
11821 return (h.remaining() < ka_loadMain - (h % ka_loadMain));
11822 };
11823 auto B_remActive = [&](Iteration h) {
11824 return (h.remaining() < kb_loadMain - (h % kb_loadMain));
11825 };
11826 auto slmRemActive = [&](Iteration h) {
11827 return (h.remaining() < unrollKSLM - (h % unrollKSLM));
11828 };
11829 auto opRemActive = [&](Iteration h) {
11830 return (h.remaining() < opCountMain - (h % opCountMain));
11831 };
11832 auto repackA = [&](Iteration h) {
11833 return A_remActive(h) ? state.repackARem : state.repackA;
11834 };
11835 auto repackB = [&](Iteration h) {
11836 return B_remActive(h) ? state.repackBRem : state.repackB;
11837 };
11838 auto ka_load = [&](Iteration h) {
11839 return A_remActive(h) ? ka_loadRem : ka_loadMain;
11840 };
11841 auto kb_load = [&](Iteration h) {
11842 return B_remActive(h) ? kb_loadRem : kb_loadMain;
11843 };
11844 auto A_copy = [&](Iteration h) { return (h / ka_load(h)) % A_copies; };
11845 auto B_copy = [&](Iteration h) { return (h / kb_load(h)) % B_copies; };
11846 auto A_regs = [&](Iteration h) -> GRFMultirange & {
11847 return state.A_regs[A_copy(h)];
11848 };
11849 auto B_regs = [&](Iteration h) -> GRFMultirange & {
11850 return state.B_regs[B_copy(h)];
11851 };
11852 auto A_layout = [&](Iteration h) -> vector<RegisterBlock> & {
11853 return A_remActive(h) ? state.A_layoutRem : state.A_layout;
11854 };
11855 auto B_layout = [&](Iteration h) -> vector<RegisterBlock> & {
11856 return B_remActive(h) ? state.B_layoutRem : state.B_layout;
11857 };
11858 auto Ar_regs = [&](Iteration h) -> GRFMultirange & {
11859 return repackA(h) ? state.Ar_regs : A_regs(h);
11860 };
11861 auto Br_regs = [&](Iteration h) -> GRFMultirange & {
11862 return repackB(h) ? state.Br_regs : B_regs(h);
11863 };
11864 auto Ar_layout = [&](Iteration h) -> vector<RegisterBlock> & {
11865 return repackA(h) ? state.Ar_layout : A_layout(h);
11866 };
11867 auto Br_layout = [&](Iteration h) -> vector<RegisterBlock> & {
11868 return repackB(h) ? state.Br_layout : B_layout(h);
11869 };
11870 auto slmCopy = [&](Iteration h) { return (h / unrollKSLM) % slmCopies; };
11871 auto slmBuffer = [&](Iteration h) { return (h / unrollKSLM) % slmBuffers; };
11872 auto Ai_layout = [&](Iteration h) -> vector<RegisterBlock> & {
11873 return slmRemActive(h) ? state.Ai_layoutRem : state.Ai_layout;
11874 };
11875 auto Bi_layout = [&](Iteration h) -> vector<RegisterBlock> & {
11876 return slmRemActive(h) ? state.Bi_layoutRem : state.Bi_layout;
11877 };
11878 auto Ai_addrs = [&](Iteration h) -> vector<GRFRange> & {
11879 return slmRemActive(h) ? state.Ai_addrsRem : state.Ai_addrs;
11880 };
11881 auto Bi_addrs = [&](Iteration h) -> vector<GRFRange> & {
11882 return slmRemActive(h) ? state.Bi_addrsRem : state.Bi_addrs;
11883 };
11884 auto Ai_allRegs = [&](Iteration h) -> vector<GRFMultirange> & {
11885 return slmRemActive(h) ? state.Ai_regsRem : state.Ai_regs;
11886 };
11887 auto Bi_allRegs = [&](Iteration h) -> vector<GRFMultirange> & {
11888 return slmRemActive(h) ? state.Bi_regsRem : state.Bi_regs;
11889 };
11890 auto Ai_regs = [&](Iteration h) -> GRFMultirange & {
11891 return Ai_allRegs(h)[slmCopy(h)];
11892 };
11893 auto Bi_regs = [&](Iteration h) -> GRFMultirange & {
11894 return Bi_allRegs(h)[slmCopy(h)];
11895 };
11896 auto Ao_regs = [&](Iteration h) -> GRFMultirange & {
11897 return slmRemActive(h) ? state.Ao_regsRem : state.Ao_regs;
11898 };
11899 auto Bo_regs = [&](Iteration h) -> GRFMultirange & {
11900 return slmRemActive(h) ? state.Bo_regsRem : state.Bo_regs;
11901 };
11902 auto effAo_regs = [&](Iteration h) -> GRFMultirange & {
11903 return Ao_regs(h).empty() ? Ai_regs(h) : Ao_regs(h);
11904 };
11905 auto effBo_regs = [&](Iteration h) -> GRFMultirange & {
11906 return Bo_regs(h).empty() ? Bi_regs(h) : Bo_regs(h);
11907 };
11908 auto aioShare = [&](Iteration h) {
11909 return slmRemActive(h) ? state.aioShareRem : state.aioShare;
11910 };
11911 auto bioShare = [&](Iteration h) {
11912 return slmRemActive(h) ? state.bioShareRem : state.bioShare;
11913 };
11914 auto opCount = [&](Iteration h) {
11915 return opRemActive(h) ? opCountRem : opCountMain;
11916 };
11917 auto nothing = [&](Iteration h) {};
11918
11919 // Dummy task to extend k unroll if needed.
11920 ls.schedule(every(unrollK) | checkOptional(), nothing);
11921
11922 // A prefetch.
11923 auto reqPFA = every(ka_pfStride)
11924 | duration(
11925 strategy.cooperativePF ? ka_pfStride : strategy.ka_prefetch)
11926 | lookahead(strategy.prefetchA);
11927
11928 if (strategy.prefetchA && !strategy.slmA && readA) {
11929 ls.schedule(reqPFA, [&](Iteration h) {
11930 auto &A_global = strategy.slmA ? state.Ai : problem.A;
11931 gemmALoad(state.Ap_regs, state.Ap_layout, state.Ap_addrs, A_global,
11932 strategy.A_prefetch, problem, strategy, state);
11933 });
11934 }
11935
11936 // B prefetch.
11937 auto reqPFB = every(kb_pfStride)
11938 | duration(
11939 strategy.cooperativePF ? kb_pfStride : strategy.kb_prefetch)
11940 | lookahead(strategy.prefetchB);
11941
11942 if (strategy.prefetchB && !strategy.slmB && readB) {
11943 ls.schedule(reqPFB, [&](Iteration h) {
11944 auto &B_global = strategy.slmB ? state.Bi : problem.B;
11945 gemmBLoad(state.Bp_regs, state.Bp_layout, state.Bp_addrs, B_global,
11946 strategy.B_prefetch, problem, strategy, state);
11947 });
11948 }
11949
11950 // SLM loads.
11951 auto reqSLMLoad = every(unrollKSLM) | variants(slmCopies)
11952 | lookahead(
11953 lookaheadSLMLoad + lookaheadSLMStore + lookaheadSLMReload);
11954 auto reqSLMLoadABRem = every(unrollKSLM) | variants(slmCopies)
11955 | lookahead(lookaheadSLMLoad + lookaheadSLMStore
11956 + lookaheadSLMReloadRem);
11957 auto reqSLMStore = every(unrollKSLM) | variants(slmCopies)
11958 | lookahead(lookaheadSLMStore + lookaheadSLMReload)
11959 | duration(durationSLMMainLoad);
11960 auto reqSLMStoreABRem = every(unrollKSLM) | variants(slmCopies)
11961 | lookahead(lookaheadSLMStore + lookaheadSLMReloadRem);
11962
11963 if ((slmA || slmB) && mustActivateRemainder) {
11964 ls.schedule({{reqSLMLoad | duration(unrollKSLM), nothing},
11965 {reqSLMLoad | unconditional(), [&](Iteration h) {
11966 activateSLMRemainder(true, h.counterOffset());
11967 }}});
11968 }
11969
11970 auto doSLMRemLoad = [&](Iteration h) {
11971 activateSLMRemainder(true, h.counterOffset());
11972 if (slmA)
11973 gemmAiBiRemLoadInc<true>(Ai_incrementalRem, Ai_remIncrCopy,
11974 needSLMReset, slmRemaskA, kSLMA, Ai_regs(h),
11975 state.Ai_layoutRem, state.Ai_addrsRem, state.Ai_layoutK,
11976 state.Ai_addrsK, state.Ao_regsRem, state.Ao_layout,
11977 state.Ai, state.Ai_strategy, problem, strategy, state);
11978 if (slmB)
11979 gemmAiBiRemLoadInc<false>(Bi_incrementalRem, Bi_remIncrCopy,
11980 needSLMReset, slmRemaskB, kSLMB, Bi_regs(h),
11981 state.Bi_layoutRem, state.Bi_addrsRem, state.Bi_layoutK,
11982 state.Bi_addrsK, state.Bo_regsRem, state.Bo_layout,
11983 state.Bi, state.Bi_strategy, problem, strategy, state);
11984 if (Ai_incrementalRem || Bi_incrementalRem) lastThresh = 0;
11985 };
11986
11987 if (slmA || slmB) {
11988 ls.schedule({{reqSLMLoad | duration(unrollKSLM),
11989 [&](Iteration h) {
11990 activateSLMRemainder(false);
11991 if (slmA)
11992 gemmALoad(Ai_regs(h), state.Ai_layout,
11993 state.Ai_addrs, state.Ai,
11994 state.Ai_strategy, problem,
11995 strategy, state);
11996 if (slmB)
11997 gemmBLoad(Bi_regs(h), state.Bi_layout,
11998 state.Bi_addrs, state.Bi,
11999 state.Bi_strategy, problem,
12000 strategy, state);
12001 }},
12002 {reqSLMLoad | duration(durationSLMMainLoad), doSLMRemLoad},
12003 {reqSLMLoadABRem, doSLMRemLoad}});
12004 }
12005
12006 // Read suppression W/A for fused EU architectures.
12007 bool rswaA = strategy.readSuppressionWA && (A_copies == 1)
12008 && ((ka_loadMain <= opCountMain) || state.repackA)
12009 && hasMasking(state.A_layout);
12010 bool rswaB = strategy.readSuppressionWA && (B_copies == 1)
12011 && ((kb_loadMain <= opCountMain) || state.repackB)
12012 && hasMasking(state.B_layout);
12013 bool rswaARem = strategy.readSuppressionWA && (A_copies == 1)
12014 && ((ka_loadRem <= opCountRem) || state.repackARem)
12015 && hasMasking(state.A_layoutRem);
12016 bool rswaBRem = strategy.readSuppressionWA && (B_copies == 1)
12017 && ((kb_loadRem <= opCountRem) || state.repackBRem)
12018 && hasMasking(state.B_layoutRem);
12019
12020 Iteration A_lastRSWA;
12021 bool haveA_lastRSWA = false;
12022
12023 bool saveRSWA;
12024 auto disableRSWA = [&]() {
12025 saveRSWA = strategy.readSuppressionWA;
12026 strategy.readSuppressionWA = false;
12027 };
12028 auto restoreRSWA = [&]() { strategy.readSuppressionWA = saveRSWA; };
12029
12030 auto doRSWA_A = [&](Iteration h) {
12031 A_lastRSWA = h;
12032 haveA_lastRSWA = true;
12033 doReadSuppressionWA(strategy, state);
12034 };
12035
12036 auto doRSWA_B = [&](Iteration h) {
12037 if (!(haveA_lastRSWA && A_lastRSWA == h))
12038 doReadSuppressionWA(strategy, state);
12039 haveA_lastRSWA = false;
12040 };
12041
12042 // A/B load scheduling.
12043 auto reqLoadA = every(ka_loadMain) | duration(ka_loadMain)
12044 | variants(A_copies) | lookahead(lookaheadALoad);
12045 auto reqLoadARem = every(ka_loadRem) | variants(A_copies)
12046 | lookahead(lookaheadALoadRem);
12047 auto reqLoadAPrezero = every(minOPCount) | variants(A_copies)
12048 | lookahead(state.repackARem ? 0 : lookaheadALoadRem);
12049
12050 auto reqLoadB = every(kb_loadMain) | duration(kb_loadMain)
12051 | variants(B_copies) | lookahead(lookaheadBLoad);
12052 auto reqLoadBRem = every(kb_loadRem) | variants(B_copies)
12053 | lookahead(lookaheadBLoadRem);
12054 auto reqLoadBPrezero = every(minOPCount) | variants(B_copies)
12055 | lookahead(state.repackBRem ? 0 : lookaheadBLoadRem);
12056
12057 // A/B prezeroing for partial remainder loads with multi-k outer products.
12058 bool prezeroARem = !slmA && (ka_loadRem < minOPCount) && readA;
12059 bool prezeroBRem = !slmB && (kb_loadRem < minOPCount) && readB;
12060
12061 if (prezeroARem && prezeroBRem && Ta.isInteger() && Tb.isInteger()
12062 && !calcASums && !calcBSums) {
12063 // Only need to pre-zero one operand for integer A/B. Choose the smaller one.
12064 if (unrollM >= unrollN)
12065 prezeroARem = false;
12066 else
12067 prezeroBRem = false;
12068 }
12069
12070 if (prezeroARem)
12071 ls.schedule({{reqLoadA, nothing},
12072 {reqLoadAPrezero, [&](Iteration h) {
12073 zeroMatrix(state.repackARem ? state.Ar_regs : A_regs(h),
12074 strategy);
12075 }}});
12076
12077 if (prezeroBRem)
12078 ls.schedule({{reqLoadB, nothing},
12079 {reqLoadBPrezero, [&](Iteration h) {
12080 zeroMatrix(state.repackBRem ? state.Br_regs : B_regs(h),
12081 strategy);
12082 }}});
12083
12084 // A/B enforced remainder preparations.
12085 if (mustActivateRemainder) {
12086 ls.schedule({{reqLoadA, nothing},
12087 {reqLoadARem | unconditional(), [&](Iteration h) {
12088 activateABRemainder(true, true, false);
12089 }}});
12090 ls.schedule({{reqLoadB, nothing},
12091 {reqLoadBRem | unconditional(), [&](Iteration h) {
12092 activateABRemainder(true, false, true);
12093 }}});
12094 }
12095
12096 // A loads.
12097 if (readA)
12098 ls.schedule({{reqLoadA,
12099 [&](Iteration h) {
12100 if (rswaA) doRSWA_A(h);
12101 disableRSWA();
12102 activateABRemainder(false, true, false);
12103 gemmALoad(A_regs(h), state.A_layout,
12104 state.A_addrs, problem.A, strategy.A,
12105 problem, strategy, state);
12106 restoreRSWA();
12107 }},
12108 {reqLoadARem, [&](Iteration h) {
12109 if (rswaARem) doRSWA_A(h);
12110 disableRSWA();
12111 activateABRemainder(true, true, false);
12112 gemmALoad(A_regs(h), state.A_layoutRem, state.A_addrsRem,
12113 problem.A, strategy.A, problem, strategy, state);
12114 restoreRSWA();
12115 }}});
12116
12117 // B loads.
12118 if (readB)
12119 ls.schedule({{reqLoadB,
12120 [&](Iteration h) {
12121 if (rswaB) doRSWA_B(h);
12122 disableRSWA();
12123 activateABRemainder(false, false, true);
12124 gemmBLoad(B_regs(h), state.B_layout,
12125 state.B_addrs, problem.B, strategy.B,
12126 problem, strategy, state);
12127 restoreRSWA();
12128 }},
12129 {reqLoadBRem, [&](Iteration h) {
12130 if (rswaBRem) doRSWA_B(h);
12131 disableRSWA();
12132 activateABRemainder(true, false, true);
12133 gemmBLoad(B_regs(h), state.B_layoutRem, state.B_addrsRem,
12134 problem.B, strategy.B, problem, strategy, state);
12135 restoreRSWA();
12136 }}});
12137
12138 // Stalls to promote thread switches.
12139 auto reqStall = every(lcm(ka_loadMain, kb_loadMain)) | checkOptional();
12140
12141 if (strategy.stallAfterLoad)
12142 ls.schedule(reqStall, [&](Iteration h) {
12143 if (hw < HW::Gen12LP)
12144 mov<uint32_t>(1 | Switch, null, 0);
12145 else if (Tc.isInteger()) {
12146 mov<float>(1, null, 0.0f);
12147 sync.nop(SWSB<float>(1));
12148 } else {
12149 mov<uint32_t>(1, null, 0);
12150 sync.nop(SWSB<uint32_t>(1));
12151 }
12152 });
12153
12154 // k decrement and loop check.
12155 auto reqLoopCheck = every(unrollK) | duration(unrollK);
12156
12157 if (lateKLoopCheck)
12158 reqLoopCheck = reqLoopCheck.delay(
12159 unrollK - std::min(ka_loadMain, kb_loadMain));
12160
12161 ls.schedule_if(
12162 reqLoopCheck,
12163 [&](Iteration h) {
12164 add(1 | gt | f0[0], state.K, state.K, -unrollK);
12165 },
12166 [&](Iteration h) {
12167 return (curPhase == LoopSequencer::PhaseMainLoop);
12168 });
12169
12170 // SLM store address increments.
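// Once the last SLM buffer has been written, the store addresses wrap back to
// the first buffer; otherwise they advance by one buffer (unrollKSLM values of k).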
12171 auto doSLMStoreInc = [&](Iteration h) {
12172 int kIncSLMStore
12173 = (slmBuffer(h) == slmBuffers - 1) ? -(slmBuffers - 1) : +1;
12174 kIncSLMStore *= unrollKSLM;
12175 if (slmA)
12176 gemmAIncrement(Ta, state.Ao_layout, state.Ao_addrs, state.Ao,
12177 state.Ao_strategy, kIncSLMStore, problem, strategy, state);
12178 if (slmB)
12179 gemmBIncrement(Tb, state.Bo_layout, state.Bo_addrs, state.Bo,
12180 state.Bo_strategy, kIncSLMStore, problem, strategy, state);
12181 };
12182
12183 if (strategy.slmBuffers >= 2) {
12184 ls.schedule({{(reqSLMStore | duration(durationSLMMainLoad)).delay(1),
12185 doSLMStoreInc},
12186 {reqSLMStoreABRem.delay(1), doSLMStoreInc}});
12187 }
12188
12189 // SLM load address increments.
12190 int delaySLMInc = strategy.delayABInc ? (unrollKSLM >> 1) : 0;
12191
12192 auto doSLMLoadInc = [&](Iteration h) {
12193 bool fullLoad = (h.remaining() >= (unrollKSLM - delaySLMInc));
12194 if (slmA && (fullLoad || !Ai_incrementalRem))
12195 gemmAIncrement(Ta_ext, Ai_layout(h), Ai_addrs(h), state.Ai,
12196 state.Ai_strategy, unrollKSLM, problem, strategy, state);
12197 if (slmB && (fullLoad || !Bi_incrementalRem))
12198 gemmBIncrement(Tb_ext, Bi_layout(h), Bi_addrs(h), state.Bi,
12199 state.Bi_strategy, unrollKSLM, problem, strategy, state);
12200 };
12201
12202 auto checkSLMLoadInc = [&](Iteration h) {
12203 bool fullLoad = (h.remaining() >= (unrollKSLM - delaySLMInc));
12204 return (slmA && (fullLoad || !Ai_incrementalRem))
12205 || (slmB && (fullLoad || !Bi_incrementalRem));
12206 };
12207
12208 if (slmA || slmB) {
12209 ls.schedule_if({{(reqSLMLoad | duration(durationSLMMainLoad))
12210 .delay(delaySLMInc),
12211 doSLMLoadInc, checkSLMLoadInc},
12212 {reqSLMLoadABRem.delay(delaySLMInc), doSLMLoadInc,
12213 checkSLMLoadInc}});
12214 }
12215
12216 // A prefetch address increment.
12217 int delayAPFInc = strategy.delayABInc ? (ka_pfStride >> 1) : 0;
12218
12219 if (strategy.prefetchA && !slmA && readA) {
12220 ls.schedule(reqPFA.delay(delayAPFInc), [&](Iteration h) {
12221 gemmAIncrement(Ta_ext, state.Ap_layout, state.Ap_addrs, problem.A,
12222 strategy.A_prefetch, ka_pfStride, problem, strategy, state);
12223 });
12224 }
12225
12226 // B prefetch address increment.
12227 int delayBPFInc = strategy.delayABInc ? (kb_pfStride >> 1) : 0;
12228
12229 if (strategy.prefetchB && !slmB && readB) {
12230 ls.schedule(reqPFB.delay(delayBPFInc), [&](Iteration h) {
12231 gemmBIncrement(Tb_ext, state.Bp_layout, state.Bp_addrs, problem.B,
12232 strategy.B_prefetch, kb_pfStride, problem, strategy, state);
12233 });
12234 }
12235
12236 // A address increment.
12237 int delayAInc
12238 = (strategy.delayABInc && A_copies > 1) ? (ka_loadMain >> 1) : 0;
12239
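// When A (and, below, B) is read from SLM, wrap the load address back to the
// first SLM buffer once the end of the buffer ring is reached.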
12240 auto ka_inc = [&](Iteration h) {
12241 auto inc = ka_load(h);
12242 if (slmA) {
12243 int kWraparound = unrollKSLM * slmBuffers;
12244 if ((h + inc) % kWraparound < inc) inc -= kWraparound;
12245 }
12246 return inc;
12247 };
12248
12249 if (readA)
12250 ls.schedule({{reqLoadA.delay(delayAInc),
12251 [&](Iteration h) {
12252 gemmAIncrement(Ta_load, state.A_layout,
12253 state.A_addrs, problem.A, strategy.A,
12254 ka_inc(h), problem, strategy, state);
12255 }},
12256 {reqLoadARem, [&](Iteration h) {
12257 gemmAIncrement(Ta_load, state.A_layoutRem,
12258 state.A_addrsRem, problem.A, strategy.A, ka_inc(h),
12259 problem, strategy, state, h % unrollKSLM);
12260 }}});
12261
12262 // B address increment.
12263 int delayBInc
12264 = (strategy.delayABInc && B_copies > 1) ? (kb_loadMain >> 1) : 0;
12265
12266 auto kb_inc = [&](Iteration h) {
12267 auto inc = kb_load(h);
12268 if (slmB) {
12269 int kWraparound = unrollKSLM * slmBuffers;
12270 if ((h + inc) % kWraparound < inc) inc -= kWraparound;
12271 }
12272 return inc;
12273 };
12274
12275 if (readB)
12276 ls.schedule({{reqLoadB.delay(delayBInc),
12277 [&](Iteration h) {
12278 gemmBIncrement(Tb_load, state.B_layout,
12279 state.B_addrs, problem.B, strategy.B,
12280 kb_inc(h), problem, strategy, state);
12281 }},
12282 {reqLoadBRem, [&](Iteration h) {
12283 gemmBIncrement(Tb_load, state.B_layoutRem,
12284 state.B_addrsRem, problem.B, strategy.B, kb_inc(h),
12285 problem, strategy, state, h % unrollKSLM);
12286 }}});
12287
12288 // A/B remasking in k dimension, during remainder handling.
12289 bool remaskA = !slmA && readA && (minOPCount > 1)
12290 && needsRemask(Ta_load, true, state.A_layoutRem, strategy.A);
12291 bool remaskB = !slmB && readB && (minOPCount > 1)
12292 && needsRemask(Tb_load, false, state.B_layoutRem, strategy.B);
12293
12294 if (remaskA && remaskB && Ta.isInteger() && Tb.isInteger() && !calcASums
12295 && !calcBSums) {
12296 // Only need to remask one operand for integer A/B. Choose the smaller one.
12297 if (unrollM >= unrollN)
12298 remaskA = false;
12299 else
12300 remaskB = false;
12301 }
12302
12303 auto Tremask = remaskA ? Ta_load : Tb_load;
12304 if (remaskA && remaskB && Ta_load.size() != Tb_load.size()) stub();
12305 if ((remaskA || remaskB) && problem.backward()) stub();
12306
12307 int remaskPeriod = lcm(ka_loadRem, kb_loadRem);
12308 auto reqRemaskSetup = every(remaskPeriod);
12309 auto reqRemaskA = every(ka_loadRem) | variants(A_copies);
12310 auto reqRemaskB = every(kb_loadRem) | variants(B_copies);
12311
12312 if (remaskA || remaskB)
12313 ls.schedule({{reqRemaskSetup | duration(remaskPeriod), nothing},
12314 {reqRemaskSetup, [&](Iteration h) {
12315 setupTeardownRemask(Tremask, 0, false, remaskPeriod,
12316 state.K, strategy, state);
12317 setupTeardownRemask(Tremask, 0, true, remaskPeriod,
12318 state.K, strategy, state, -h.counterOffset());
12319 }}});
12320
12321 if (remaskA)
12322 ls.schedule({{reqLoadA, nothing},
12323 {reqRemaskA, [&](Iteration h) {
12324 remaskLayout(Ta_load, 0, true, state.A_layoutRem,
12325 A_regs(h), strategy, state, h % remaskPeriod);
12326 }}});
12327
12328 if (remaskB)
12329 ls.schedule({{reqLoadB, nothing},
12330 {reqRemaskB, [&](Iteration h) {
12331 remaskLayout(Tb_load, 0, false, state.B_layoutRem,
12332 B_regs(h), strategy, state, h % remaskPeriod);
12333 }}});
12334
12335 // A/B repacking.
12336 auto reqRepackA = every(ka_loadMain) | variants(A_copies);
12337 auto reqRepackARem = every(ka_loadRem) | variants(A_copies);
12338 bool convertA = (Ta != Ta_load) && (Ta.size() == Ta_load.size());
12339
12340 if (state.repackA || state.repackARem || convertA)
12341 if (readA)
12342 ls.schedule({{reqRepackA,
12343 [&](Iteration h) {
12344 if (state.repackA)
12345 copyRegisters(Ta_load, Ta,
12346 state.A_layout,
12347 state.Ar_layout, A_regs(h),
12348 state.Ar_regs, 0, 0, false,
12349 strategy, state);
12350 else if (convertA)
12351 convert(A_regs(h), Ta_load, Ta,
12352 problem, strategy, state);
12353 }},
12354 {reqRepackARem, [&](Iteration h) {
12355 if (state.repackARem)
12356 copyRegisters(Ta_load, Ta, state.A_layoutRem,
12357 state.Ar_layout, A_regs(h), state.Ar_regs,
12358 0, h % state.ka_repackRem, false, strategy,
12359 state);
12360 else if (convertA)
12361 convert(A_regs(h), Ta_load, Ta, problem, strategy,
12362 state);
12363 }}});
12364
12365 auto reqRepackB = every(kb_loadMain) | variants(B_copies);
12366 auto reqRepackBRem = every(kb_loadRem) | variants(B_copies);
12367 bool convertB = (Tb != Tb_load) && (Tb.size() == Tb_load.size());
12368
12369 if (state.repackB || state.repackBRem || convertB)
12370 if (readB)
12371 ls.schedule({{reqRepackB,
12372 [&](Iteration h) {
12373 if (state.repackB)
12374 copyRegisters(Tb_load, Tb,
12375 state.B_layout,
12376 state.Br_layout, B_regs(h),
12377 state.Br_regs, 0, 0, false,
12378 strategy, state);
12379 else if (convertB)
12380 convert(B_regs(h), Tb_load, Tb,
12381 problem, strategy, state);
12382 }},
12383 {reqRepackBRem, [&](Iteration h) {
12384 if (state.repackBRem)
12385 copyRegisters(Tb_load, Tb, state.B_layoutRem,
12386 state.Br_layout, B_regs(h), state.Br_regs,
12387 h % state.kb_repackRem, 0, false, strategy,
12388 state);
12389 else if (convertB)
12390 convert(B_regs(h), Tb_load, Tb, problem, strategy,
12391 state);
12392 }}});
12393
12394 // Outer product(s).
12395 // If outer products are batched across k (dp4a/dpas/k-chaining), they trigger once every opCount loops.
12396 auto reqOP = every(minOPCount) | lookahead(-(minOPCount - 1));
12397
12398 int ka_sumMain
12399 = !isLayoutColMajor(state.A_layout) ? ka_loadMain : opCountMain;
12400 int kb_sumMain
12401 = isLayoutColMajor(state.B_layout) ? kb_loadMain : opCountMain;
12402
12403 ls.schedule(reqOP, [&](Iteration h) {
12404 auto oc = opCount(h);
12405 auto hNext = h + minOPCount;
12406 if (hNext % oc != 0) return;
12407
12408 int ka = ka_load(h), kb = kb_load(h);
12409 int ha = h % ka;
12410 int hb = h % kb;
12411 if (problem.backward()) {
12412 ha = ka - 1 - ha;
12413 hb = kb - 1 - hb;
12414 }
12415
12416 auto &layoutA = Ar_layout(h);
12417 auto &layoutB = Br_layout(h);
12418 auto &regsA = Ar_regs(h);
12419 auto &regsB = Br_regs(h);
12420
12421 outerProduct(h, ha, hb, oc, layoutA, layoutB, regsA, regsB, problem,
12422 strategy, state);
12423
12424 if (calcASums && !slmASums && !state.systolicSumA) {
12425 int ka_sum = (curPhase == LoopSequencer::PhaseMainLoop) ? ka_sumMain
12426 : oc;
12427 int ha0 = ha - oc + minOPCount;
12428 if (ha0 % ka_sum == 0)
12429 accumulateSum(false, Ta, regsA, layoutA, Tc, state.As_regs,
12430 state.As_layout, strategy, state, ha0, ha0 + ka_sum);
12431 }
12432
12433 if (calcBSums && !slmBSums && !state.systolicSumB) {
12434 int kb_sum = (curPhase == LoopSequencer::PhaseMainLoop) ? kb_sumMain
12435 : oc;
12436 int hb0 = hb - oc + minOPCount;
12437 if (hb0 % kb_sum == 0)
12438 accumulateSum(true, Tb, regsB, layoutB, Tc, state.Bs_regs,
12439 state.Bs_layout, strategy, state, hb0, hb0 + kb_sum);
12440 }
12441 });
12442
12443 // SLM data repacking and remasking.
12444 auto reqSLMRepack = every(unrollKSLM) | variants(slmCopies)
12445 | lookahead(lookaheadSLMStore + lookaheadSLMReload
12446 + strategy.slmRepackAhead)
12447 | duration(durationSLMMainLoad);
12448 auto reqSLMRepackABRem = every(unrollKSLM) | variants(slmCopies)
12449 | lookahead(lookaheadSLMStore + lookaheadSLMReloadRem
12450 + strategy.slmRepackAhead);
12451
12452 auto slmConvertA = [&](Iteration h) {
12453 return slmA && aioShare(h) && (Ta != Ta_ext)
12454 && (Ta.size() == Ta_ext.size());
12455 };
12456 auto slmConvertB = [&](Iteration h) {
12457 return slmB && bioShare(h) && (Tb != Tb_ext)
12458 && (Tb.size() == Tb_ext.size());
12459 };
12460
12461 auto doSLMRepack = [&](Iteration h) {
12462 if (slmA && !aioShare(h) && !(slmRemActive(h) && Ai_remIncrCopy))
12463 copyRegisters(Ta_ext, Ta, Ai_layout(h), state.Ao_layout, Ai_regs(h),
12464 Ao_regs(h), 0, 0, false, strategy, state);
12465 else if (slmConvertA(h))
12466 convert(Ai_regs(h), Ta_ext, Ta, problem, strategy, state);
12467
12468 if (slmB && !bioShare(h) && !(slmRemActive(h) && Bi_remIncrCopy))
12469 copyRegisters(Tb_ext, Tb, Bi_layout(h), state.Bo_layout, Bi_regs(h),
12470 Bo_regs(h), 0, 0, false, strategy, state);
12471 else if (slmConvertB(h))
12472 convert(Bi_regs(h), Tb_ext, Tb, problem, strategy, state);
12473
12474 if (slmRemActive(h) && (slmRemaskA || slmRemaskB)) {
12475 releaseMaskAssignments(kMasksSLM,
12476 state); // Not in use -- can temporarily free these.
12477 gemmSLMRemask(slmRemaskA, slmRemaskB, effAo_regs(h), effBo_regs(h),
12478 -h.counterOffset(), problem, strategy, state);
12479 reclaimMaskAssignments(kMasksSLM, state);
12480 }
12481 };
12482
12483 auto checkSLMRepack = [&](Iteration h) {
12484 return (slmA && !aioShare(h) && !(slmRemActive(h) && Ai_remIncrCopy))
12485 || (slmB && !bioShare(h)
12486 && !(slmRemActive(h) && Bi_remIncrCopy))
12487 || (slmRemActive(h) && (slmRemaskA || slmRemaskB))
12488 || slmConvertA(h) || slmConvertB(h);
12489 };
12490
12491 if (slmA || slmB) {
12492 ls.schedule_if({{reqSLMRepack, doSLMRepack, checkSLMRepack},
12493 {reqSLMRepackABRem, doSLMRepack, checkSLMRepack}});
12494 }
12495
12496 // SLM stores and synchronization.
12497 auto reqSLMAfterStore = every(unrollKSLM) | variants(slmCopies)
12498 | lookahead(lookaheadSLMStore + lookaheadSLMReload - unrollKSLM)
12499 | duration(durationSLMMainLoad);
12500 auto reqSLMAfterStore2 = every(unrollKSLM) | variants(slmCopies)
12501 | lookahead(lookaheadSLMStore + lookaheadSLMReload - 2 * unrollKSLM)
12502 | duration(durationSLMMainLoad);
12503 auto reqSLMAfterStoreABRem = every(unrollKSLM) | variants(slmCopies)
12504 | lookahead(lookaheadSLMStore + lookaheadSLMReloadRem - unrollKSLM);
12505 auto reqSLMAfterStoreABRem2 = every(unrollKSLM) | variants(slmCopies)
12506 | lookahead(
12507 lookaheadSLMStore + lookaheadSLMReloadRem - 2 * unrollKSLM);
12508
12509 auto slm1x2xFencedBarrier = [&]() {
12510 // For DG2+, before 1x/2x buffered stores, we must ensure prior SLM reads are complete.
12511 // Use a fence for >2x global buffering.
12512 // For 2x global buffering, use SWSB since loaded data will be used shortly.
12513 // For 1x global buffering, loaded data has already been consumed.
12514 if (hw < HW::XeHPG && !strategy.strictFence)
12515 kLoopBarrier(false);
12516 else if ((A_copies > 2 || B_copies > 2) && !strategy.slmFenceWARWA)
12517 kLoopBarrier(true);
12518 else {
12519 if (slmA && A_copies > 1) wrdepRanges(state.A_regs);
12520 if (slmB && B_copies > 1) wrdepRanges(state.B_regs);
12521 kLoopBarrier(false);
12522 }
12523 };
12524
12525 auto doSLMAfterStore2 = [&](Iteration h) {
12526 switch (slmBuffers) {
12527 case 1:
12528 case 2:
12529 case 3: break;
12530 case 4: kLoopBarrier(false, KBarrierType::Wait); break;
12531 default: stub();
12532 }
12533 };
12534
12535 auto doSLMAfterStore = [&](Iteration h) {
12536 switch (slmBuffers) {
12537 case 1: break;
12538 case 2: slm1x2xFencedBarrier(); break;
12539 case 3: kLoopBarrier(false, KBarrierType::Wait); break;
12540 case 4:
12541 // TEMP: move me earlier.
12542 slmFenceIssue();
12543 //
12544 slmFenceWait();
12545 if (strategy.slmFenceWARWA) {
12546 // Work around buggy SLM fence by ensuring SLM reads complete.
12547 if (slmA && A_copies > 1) wrdepRanges(state.A_regs);
12548 if (slmB && B_copies > 1) wrdepRanges(state.B_regs);
12549 }
12550 kLoopBarrier(false, KBarrierType::Signal);
12551 break;
12552 }
12553 };
12554
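// Write the prepared Ao/Bo blocks to SLM, accumulating A/B sums from the SLM
// copy if requested, with the synchronization appropriate to the number of
// SLM buffers in flight.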
12555 auto doSLMStore = [&](Iteration h) {
12556 if (!slmA && !slmB) return;
12557
12558 switch (slmBuffers) {
12559 case 1: slm1x2xFencedBarrier(); break;
12560 case 2:
12561 case 3:
12562 case 4: break;
12563 default: stub();
12564 }
12565
12566 if (slmA)
12567 storeMatrix(effAo_regs(h), state.Ao_layout, state.Ao,
12568 state.Ao_strategy, state.Ao_addrs, strategy, state);
12569 if (slmB)
12570 storeMatrix(effBo_regs(h), state.Bo_layout, state.Bo,
12571 state.Bo_strategy, state.Bo_addrs, strategy, state);
12572
12573 if (slmASums)
12574 accumulateSum(false, Ta, effAo_regs(h), state.Ao_layout, Tc,
12575 state.As_regs, state.As_layout, strategy, state);
12576 if (slmBSums)
12577 accumulateSum(true, Tb, effBo_regs(h), state.Bo_layout, Tc,
12578 state.Bs_regs, state.Bs_layout, strategy, state);
12579
12580 switch (slmBuffers) {
12581 case 1: kLoopBarrier(true); break;
12582 case 2:
12583 slmFenceIssue();
12584 slmFenceWait();
12585 break;
12586 case 3:
12587 if (strategy.slmFenceWARWA) {
12588 // Work around buggy SLM fence by ensuring SLM reads complete.
12589 // Should be moved later, just before the barrier.
12590 if (slmA && A_copies > 1) wrdepRanges(state.A_regs);
12591 if (slmB && B_copies > 1) wrdepRanges(state.B_regs);
12592 }
12593 kLoopBarrier(true, KBarrierType::Signal);
12594 break;
12595 case 4: break;
12596 default: stub();
12597 }
12598 };
12599
12600 if (slmBuffers > 0) {
12601 if (slmBuffers >= 4)
12602 ls.schedule({{reqSLMAfterStore2, doSLMAfterStore2},
12603 {reqSLMAfterStoreABRem2, doSLMAfterStore2}});
12604
12605 if (slmBuffers >= 2)
12606 ls.schedule({{reqSLMAfterStore, doSLMAfterStore},
12607 {reqSLMAfterStoreABRem, doSLMAfterStore}});
12608
12609 ls.schedule(
12610 {{reqSLMStore, doSLMStore}, {reqSLMStoreABRem, doSLMStore}});
12611 }
12612
12613 // Save pre-loop state.
12614 auto statePreLoop = state;
12615
12616 using CT = LoopSequencer::CallbackType;
12617
12618 Label lTop, lBottom;
12619 std::vector<Label> labels;
12620
12621 ls.analyze();
12622
12623 if (ls.getUnroll() != unrollK)
12624 stub(); // Auto-calculated unroll should match unrollK from strategy.
12625
12626 // Prepare to save off loops for periodic barriers, if needed.
12627 Subregister outerK;
12628 if (strategy.barrierFreq > 0) outerK = state.ra.alloc_sub<uint32_t>();
12629
12630 // Prepare to peel loops for C prefetch, if needed.
12631 int prefetchCPeelLoops = -1;
12632 Subregister pfCPeelK;
12633 if (strategy.prefetchC > 0) {
12634 prefetchCPeelLoops = div_up(
12635 std::max(0, strategy.prefetchC - ls.getCooldown()), unrollK);
12636 if (prefetchCPeelLoops > 0) pfCPeelK = state.ra.alloc_sub<uint32_t>();
12637 }
12638
12639 auto resetForNewLoop = [&]() {
12640 resetKSLM();
12641 lastThresh = 0;
12642 haveA_lastRSWA = false;
12643 state.ra.safeRelease(barrierHeader);
12644 setupTeardownRemask(
12645 Tremask, 0, false, remaskPeriod, state.K, strategy, state);
12646 };
12647
12648 // Main events in lifetime of loop.
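// LoopStart/LoopEnd emit the k-counter checks and backward branches (plus
// periodic-barrier and C-prefetch-peeling bookkeeping); the Jump* callbacks
// realize the sequencer's internal control flow; NotifyPhase tracks which
// phase (warmup, main loop, cooldown, short loop, remainder) is being built
// so phase-specific setup can run.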
12649 ls.setCallback(CT::OffsetCounter,
12650 [&](int offset, int) { add(1, state.K, state.K, offset); });
12651 ls.setCallback(CT::LoopStart, [&](int unroll, int) {
12652 cmp(1 | le | state.flagAP, state.K, 0);
12653 if (prefetchCPeelLoops > 0) {
12654 min_(1, pfCPeelK, state.K, prefetchCPeelLoops * unrollK);
12655 add(1, state.K, state.K, -pfCPeelK);
12656 }
12657 if (strategy.barrierFreq > 0) {
12658 add(1 | sat, outerK, state.K, -strategy.barrierFreq);
12659 min_(1, state.K, state.K, strategy.barrierFreq);
12660 if (strategy.splitBarrier)
12661 kLoopBarrier(false, KBarrierType::Signal);
12662 }
12663 if (hw >= HW::Gen12LP) sync.nop(SWSB(Pipe::A, 1));
12664 jmpi(1 | state.flagAP, lBottom);
12665 mark(lTop);
12666 state.wipeActiveVFlags();
12667 });
12668 ls.setCallback(CT::LoopEnd, [&](int, int) {
12669 jmpi(1 | state.flagAP, lTop);
12670 if (strategy.barrierFreq > 0) {
12671 add(1, state.K, state.K, outerK);
12672 add(1 | sat, outerK, outerK, int16_t(-strategy.barrierFreq));
12673 add(1 | gt | state.flagAP, state.K, state.K, -outerK);
12674 if (strategy.splitBarrier) {
12675 kLoopBarrier(false, KBarrierType::Wait);
12676 kLoopBarrier(false, KBarrierType::Signal);
12677 } else
12678 kLoopBarrier(false);
12679 jmpi(1 | state.flagAP, lTop);
12680 }
12681 if (prefetchCPeelLoops > 0) {
12682 add(1 | gt | state.flagAP, state.K, state.K, pfCPeelK);
12683 mov(1, pfCPeelK, 0);
12684 gemmPrefetchC(problem, strategy, state);
12685 jmpi(1 | state.flagAP, lTop);
12686 }
12687 mark(lBottom);
12688 state.wipeActiveVFlags();
12689 });
12690 ls.setCallback(CT::JumpIfLT, [&](int thresh, int label) {
12691 if (size_t(label) >= labels.size()) labels.resize(label + 1);
12692 if (thresh != lastThresh) cmp(1 | lt | state.flagAP, state.K, thresh);
12693 jmpi(1 | state.flagAP, labels[label]);
12694 lastThresh = thresh;
12695 });
12696 ls.setCallback(CT::JumpTarget, [&](int label, int) {
12697 mark(labels[label]);
12698 state.wipeActiveVFlags();
12699 });
12700 ls.setCallback(CT::Jump, [&](int label, int) {
12701 if (size_t(label) >= labels.size()) labels.resize(label + 1);
12702 jmpi(1, labels[label]);
12703 });
12704 ls.setCallback(CT::NotifyPhase, [&](int phase, int) {
12705 curPhase = phase;
12706 switch (phase) {
12707 case LoopSequencer::PhaseWarmup:
12708 status << "k loop warmup" << status_stream::endl;
12709 break;
12710 case LoopSequencer::PhaseMainLoop:
12711 status << "Main k loop" << status_stream::endl;
12712 break;
12713 case LoopSequencer::PhaseMainPathEnd:
12714 if (strategy.barrierFreq > 0 && strategy.splitBarrier)
12715 kLoopBarrier(false, KBarrierType::Wait);
12716 break;
12717 case LoopSequencer::PhaseCooldown:
12718 if (prefetchCPeelLoops == 0)
12719 gemmPrefetchC(problem, strategy, state);
12720 if (lateKLoopCheck) state.raVFlag.lock(state.flagAP);
12721 haveA_lastRSWA = false;
12722 status << "k loop cooldown" << status_stream::endl;
12723 break;
12724 case LoopSequencer::PhaseShortLoop:
12725 if (strategy.prefetchC > 0)
12726 gemmPrefetchC(problem, strategy, state);
12727 status << "Short k loop" << status_stream::endl;
12728 remActiveA = remActiveB = remActiveSLM = false;
12729 resetForNewLoop();
12730 state = statePreLoop;
12731 break;
12732 case LoopSequencer::PhaseRemainder:
12733 status << "k loop remainder" << status_stream::endl;
12734 break;
12735 default: break;
12736 }
12737 });
12738
12739 // Early C prefetch.
12740 if (strategy.prefetchC < 0) gemmPrefetchC(problem, strategy, state);
12741
12742 // Generate k loop.
12743 if (lateKLoopCheck) state.raVFlag.unlock(state.flagAP);
12744 syncall(); /* Avoid unnecessary SWSB dependencies entering loop. */
12745 ls.materialize();
12746
12747 // Release barrier header from short k loop.
12748 state.ra.safeRelease(barrierHeader);
12749
12750 // Additional barriers to match other threads' barrier count, if other threads might have different k.
12751 if (matchBarriers) {
12752 status << "Match barrier counts between threads" << status_stream::endl;
12753 Subregister myBarriers, k0Barriers;
12754 Label lSkipExtraBarriers, lExtraBarrierLoop;
12755 int maxExtraBarriers = maxExtraKLoopRemBarriers(strategy);
12756
12757 if (strategy.barrierFreq > 0 && prefetchCPeelLoops > 0) stub();
12758
12759 gemmCalcKLoopBarrierCount(k0Barriers, state.inputs.k0, ls.getCooldown(),
12760 problem, strategy, state);
12761 gemmCalcKLoopBarrierCount(myBarriers, state.k, ls.getCooldown(),
12762 problem, strategy, state);
12763 if (maxExtraBarriers > 0)
12764 add(1, k0Barriers, k0Barriers, maxExtraBarriers);
12765 add(1 | sat | le | state.flagAP, myBarriers.ud(), k0Barriers,
12766 -myBarriers);
12767 (void)kLoopGetBarrierHeader(state);
12768 jmpi(1 | state.flagAP, lSkipExtraBarriers);
12769
12770 mark(lExtraBarrierLoop);
12771 {
12772 add(1 | gt | state.flagAP, myBarriers, myBarriers, -1);
12773 kLoopBarrier(false);
12774 jmpi(1 | state.flagAP, lExtraBarrierLoop);
12775 }
12776 mark(lSkipExtraBarriers);
12777
12778 state.ra.safeRelease(myBarriers);
12779 state.ra.safeRelease(k0Barriers);
12780 if (!strategy.persistent) state.ra.safeRelease(state.inputs.k0);
12781 }
12782
12783 // Free resources that are no longer needed.
12784 state.ra.safeRelease(outerK);
12785 state.ra.safeRelease(pfCPeelK);
12786 setupTeardownRemask(
12787 Tremask, 0, false, remaskPeriod, state.K, strategy, state);
12788
12789 state.firstKLoopSegment = false;
12790}
12791
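// Release temporary state used by the k loop: the loop counter (if distinct
// from state.k), barrier headers and flags, SLM k masks, remainder copy
// registers, and barrier/fence tokens.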
12792template <HW hw>
12793void gemm_kernel_generator_t<hw>::kLoopTeardown(const GEMMProblem &problem,
12794 const GEMMStrategy &strategy, GEMMState &state) {
12795 if (state.K != state.k) state.ra.safeRelease(state.K);
12796 state.barrierReady = false;
12797 state.ra.safeRelease(state.barrierHeader);
12798 state.ra.safeRelease(state.barrierHeaderM);
12799 state.ra.safeRelease(state.barrierHeaderN);
12800 state.raVFlag.safeRelease(state.barrierM);
12801 state.raVFlag.safeRelease(state.barrierN);
12802 safeReleaseMaskAssignments(state.kMasksSLM, state);
12803 safeReleaseRanges(state.Ao_regsRem, state);
12804 safeReleaseRanges(state.Bo_regsRem, state);
12805 state.tokenAllocator.safeRelease(state.tokenBarrierFence[0]);
12806 state.tokenAllocator.safeRelease(state.tokenBarrierFence[1]);
12807}
12808
12809// Create 1-segment inner loop for a GEMM-like kernel.
12810template <HW hw>
12811bool gemm_kernel_generator_t<hw>::kLoopSingle(KLoop type,
12812 const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
12813 bool ok = kLoopSetup(problem, strategy, state);
12814 if (ok) {
12815 kLoop(type, problem, strategy, state);
12816 kLoopTeardown(problem, strategy, state);
12817 }
12818 return ok;
12819}
12820
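// Generate the k loop for a standard GEMM, as a single segment.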
12821template <HW hw>
12822bool gemm_kernel_generator_t<hw>::gemmKLoop(
12823 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
12824 return kLoopSingle(KLoop::GEMM, problem, strategy, state);
12825}
12826
12827// Decide whether C layout needs m/n remainder handling.
12828static inline void getCRemainders(const GEMMProblem &problem,
12829 const GEMMStrategy &strategy, bool &remM_C, bool &remN_C) {
12830 bool remainderM
12831 = (strategy.remHandling[LoopM] != RemainderHandling::Ignore);
12832 bool remainderN
12833 = (strategy.remHandling[LoopN] != RemainderHandling::Ignore);
12834
12835 int C_mgran, C_ngran;
12836 getGranularities(problem.C, C_mgran, C_ngran);
12837
12838 remM_C = remainderM && !strategy.C.padded && !strategy.altCRemainder
12839 && (C_mgran < strategy.unroll[LoopM]);
12840 remN_C = remainderN && !strategy.C.padded && !strategy.altCRemainder
12841 && (C_ngran < strategy.unroll[LoopN]);
12842}
12843
12844static inline bool needsKLoopReset(const GEMMProblem &problem) {
12845 return false;
12846}
12847
12848// Setup for C accumulation.
12849// NOTE: modifies problem/strategy/state.
12850template <HW hw>
12851bool gemm_kernel_generator_t<hw>::gemmAccumulateCSetup(
12852 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
12853 auto &Ta = problem.Ta, &Tb = problem.Tb, Tc = problem.Tc;
12854 auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext,
12855 Tc_ext = problem.Tc_ext;
12856 auto &Ta_load = state.Ta_load, &Tb_load = state.Tb_load;
12857
12858 bool cLoadAhead = strategy.cLoadAhead;
12859 auto unrollM = strategy.unroll[LoopM];
12860 auto unrollN = strategy.unroll[LoopN];
12861 auto unrollK = strategy.unroll[LoopK];
12862
12863 // Decide what remainder handling needs to be done.
12864 bool remainderM = strategy.remHandling[LoopM] != RemainderHandling::Ignore;
12865 bool remainderN = strategy.remHandling[LoopN] != RemainderHandling::Ignore;
12866 bool remainderK = strategy.remHandling[LoopK] != RemainderHandling::Ignore;
12867 bool remM_A = remainderM && !strategy.A.padded;
12868 bool remK_A = false;
12869 bool remK_B = false;
12870 bool remN_B = remainderN && !strategy.B.padded;
12871 bool remM_C, remN_C;
12872 getCRemainders(problem, strategy, remM_C, remN_C);
12873 bool remM_Ce = remM_C;
12874 bool remN_Ce = remN_C;
12875
12876 if (state.copyC) remM_C = remN_C = false;
12877
12878 auto globalA = problem.A;
12879 auto globalB = problem.B;
12880
12881 // 2D addressing parameters.
12882 auto &A_params = state.A_params, &B_params = state.B_params;
12883 auto &Ai_params = state.Ai_params, &Bi_params = state.Bi_params;
12884 auto &Ap_params = state.Ap_params, &Bp_params = state.Bp_params;
12885 A_params.rows = state.inputs.m;
12886 A_params.cols = state.fullK;
12887 A_params.offR = state.i0;
12888 A_params.offC = state.h0;
12889 A_params.remR = state.remainders[LoopM];
12890 B_params.rows = state.fullK;
12891 B_params.cols = state.inputs.n;
12892 B_params.offR = state.h0;
12893 B_params.offC = state.j0;
12894 B_params.remC = state.remainders[LoopN];
12895 Ai_params = A_params, Bi_params = B_params;
12896 Ap_params = A_params, Bp_params = B_params;
12897
12898 // Decide which dimensions to split for WG-cooperative operations (SLM copy, cooperative PF).
12899 state.effCoopA = effCoopSplitA(problem, strategy);
12900 state.effCoopB = effCoopSplitB(problem, strategy);
12901
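// With an m (resp. n) remainder and a non-2D-block SLM load, fall back to a
// k-split cooperative copy, using block accesses when the remaindered
// dimension is contiguous in memory and scattered accesses otherwise.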
12902 if (strategy.slmA && (state.effCoopA != CoopSplit::K) && remM_A
12903 && !isBlock2D(strategy.A.accessType)) {
12904 strategy.A.accessType = isColMajor(problem.A.layout)
12905 ? AccessType::Block
12906 : AccessType::Scattered;
12907 state.effCoopA = CoopSplit::K;
12908 }
12909
12910 if (strategy.slmB && (state.effCoopB != CoopSplit::K) && remN_B
12911 && !isBlock2D(strategy.B.accessType)) {
12912 strategy.B.accessType = !isColMajor(problem.B.layout)
12913 ? AccessType::Block
12914 : AccessType::Scattered;
12915 state.effCoopB = CoopSplit::K;
12916 }
12917
12918 // Prepare layouts for prefetch.
12919 bool remM_Cp = remM_C && strategy.C.base.isStateless();
12920 bool remN_Cp = remN_C && strategy.C.base.isStateless();
12921
12922 state.ma_prefetch = state.ka_prefetch = state.kb_prefetch
12923 = state.nb_prefetch = 0;
12924 if (strategy.prefetchA)
12925 coopSplit(true, state.ma_prefetch, state.ka_prefetch, unrollM,
12926 strategy.ka_prefetch, state.effCoopA, strategy.wg[LoopN],
12927 problem.A);
12928 if (strategy.prefetchB)
12929 coopSplit(false, state.kb_prefetch, state.nb_prefetch,
12930 strategy.kb_prefetch, unrollN, state.effCoopB,
12931 strategy.wg[LoopM], problem.B);
12932
12933 if (strategy.prefetchA
12934 && !getRegLayout(Ta_ext, state.Ap_layout, state.ma_prefetch,
12935 state.ka_prefetch, remM_A, remK_A, false, true, 0, 0,
12936 problem.A, strategy.A_prefetch))
12937 return false;
12938 if (strategy.prefetchB
12939 && !getRegLayout(Tb_ext, state.Bp_layout, state.kb_prefetch,
12940 state.nb_prefetch, remK_B, remN_B, false, true, 0, 0,
12941 problem.B, strategy.B_prefetch))
12942 return false;
12943 if (strategy.prefetchC
12944 && !getRegLayout(Tc_ext, state.Cp_layout, unrollM, unrollN, remM_Cp,
12945 remN_Cp, false, true, 0, 0, problem.C, strategy.C_prefetch))
12946 return false;
12947
12948 if (hasMasking(state.Cp_layout) || hasFragmenting(state.Cp_layout)) stub();
12949
12950 // Prepare addresses for prefetch.
12951 if (strategy.cooperativePF && strategy.prefetchA) {
12952 Subregister offAp;
12953 gemmCalcWorkshareAOffset(offAp, Ap_params.offR, Ap_params.offC,
12954 problem.A, strategy.A_prefetch, state.ma_prefetch,
12955 state.ka_prefetch, problem, strategy, state);
12956 if (strategy.A_prefetch.address2D) {
12957 if (A_params.offR.isValid())
12958 add(1, Ap_params.offR, Ap_params.offR, A_params.offR);
12959 if (A_params.offC.isValid())
12960 add(1, Ap_params.offC, Ap_params.offC, A_params.offC);
12961 } else {
12962 auto inEffAp = state.effAp;
12963 if (state.effA == state.effAp)
12964 state.effAp = state.ra.alloc_sub(state.effA.getType());
12965 eadd(1, state.effAp, inEffAp, offAp, strategy, state);
12966 }
12967 state.ra.safeRelease(offAp);
12968 }
12969 if (strategy.cooperativePF && strategy.prefetchB) {
12970 Subregister offBp;
12971 gemmCalcWorkshareBOffset(offBp, Bp_params.offR, Bp_params.offC,
12972 problem.B, strategy.B_prefetch, state.kb_prefetch,
12973 state.nb_prefetch, problem, strategy, state);
12974 if (strategy.B_prefetch.address2D) {
12975 if (B_params.offR.isValid())
12976 add(1, Bp_params.offR, Bp_params.offR, B_params.offR);
12977 if (B_params.offC.isValid())
12978 add(1, Bp_params.offC, Bp_params.offC, B_params.offC);
12979 } else {
12980 auto inEffBp = state.effBp;
12981 if (state.effB == state.effBp)
12982 state.effBp = state.ra.alloc_sub(state.effB.getType());
12983 eadd(1, state.effBp, inEffBp, offBp, strategy, state);
12984 }
12985 state.ra.safeRelease(offBp);
12986 }
12987
12988 // Prepare layouts and starting addresses for SLM copies and adjust problem.
12989 if (strategy.slmBuffers > 0) {
12990 int A_slmCP, B_slmCP;
12991 int A_tileR, A_tileC, B_tileR, B_tileC;
12992 std::tie(A_slmCP, B_slmCP) = targetSLMCrosspack(hw, problem, strategy);
12993 std::tie(A_tileR, A_tileC, B_tileR, B_tileC)
12994 = targetKernelTiling(hw, problem, strategy);
12995 auto opCount = outerProductCount(hw, problem, strategy);
12996
12997 if (strategy.slmA) {
12998 coopSplit(true, state.ma_slm, state.ka_slm, unrollM,
12999 strategy.unrollKSLM, state.effCoopA, strategy.wg[LoopN],
13000 problem.A);
13001
13002 if (state.ma_slm < unrollM) {
13003 remM_A = false;
13004 remK_A = remainderK && strategy.slmEarlyKMask;
13005 }
13006 if (strategy.slmATrans) {
13007 A_slmCP = state.ka_slm;
13008 if (strategy.ka_load % A_slmCP)
13009 throw std::runtime_error(
13010 "ka_load must be a multiple of ka_slm");
13011 }
13012 if ((state.ka_slm < A_slmCP) && (strategy.unrollKSLM != A_slmCP)
13013 && (A_tileC != A_slmCP))
13014 throw std::runtime_error(
13015 "ka_slm must be a multiple of crosspack, or unrollKSLM "
13016 "= crosspack.");
13017
13018 // Layout in from memory...
13019 state.Ai = problem.A;
13020 state.Ai_strategy = strategy.A;
13021 if (state.Ai_strategy.dpasw) {
13022 state.Ai_strategy.dpasw = false;
13023 state.Ai_strategy.tileR = 0;
13024 }
13025
13026 // ... layout out to SLM.
13027 state.Ao.layout = MatrixLayout::Pc;
13028 state.Ao.packSize = unrollM;
13029 state.Ao.crosspack = A_slmCP;
13030 state.Ao.setAlignment(state.Ao.packSize * Ta);
13031 state.Ao.tileR = A_tileR;
13032 state.Ao.tileC = (A_tileC || !A_tileR)
13033 ? A_tileC
13034 : std::max(opCount, strategy.ka_load);
13035
13036 bool colMajorIn
13037 = isRegisterColMajor(Ta_ext, state.Ai, state.Ai_strategy);
13038 bool colMajorSLM = !isLargeCrosspack(Ta, A_slmCP);
13039 state.Ao_strategy.base = SLM;
13040 state.Ao_strategy.accessType = (colMajorIn == colMajorSLM)
13041 ? AccessType::Block
13042 : AccessType::Scattered;
13043 state.Ao_strategy.smode = ScatterSIMD::Default;
13044
13045 if (state.Ai.layout == MatrixLayout::N
13046 && state.Ai_strategy.accessType == AccessType::Block2DVNNI
13047 && isLargeCrosspack(Ta, A_slmCP)) {
13048 state.Ao_strategy.accessType = AccessType::ChannelScattered;
13049 state.Ao_strategy.smode = ScatterSIMD::Narrow;
13050 }
13051 state.Ao_strategy.padded = true;
13052 state.Ao_strategy.atomic = false;
13053 state.Ao_strategy.address2D = false;
13054 state.Ao_strategy.newDP = (hw >= HW::XeHPG);
13055 state.Ao_strategy.cachingW = CacheSettingsLSC::Default;
13056
13057 // Layout in from memory...
13058 if (!getRegLayout(Ta_ext, state.Ai_layout, state.ma_slm,
13059 state.ka_slm, remM_A, remK_A, false, true, 0, 0,
13060 state.Ai, state.Ai_strategy))
13061 return false;
13062
13063 // ... layout out to SLM...
13064 remM_A = remK_A = false;
13065 if (!getRegLayout(Ta, state.Ao_layout, state.ma_slm, state.ka_slm,
13066 remM_A, remK_A, true, true, 0, 0, state.Ao,
13067 state.Ao_strategy))
13068 return false;
13069
13070 // ... and layout back from SLM.
13071 problem.A = state.Ao;
13072 strategy.A.base = SLM;
13073 strategy.A.accessType = AccessType::Block;
13074 strategy.A.address2D = false;
13075 strategy.A.newDP = (hw >= HW::XeHPG);
13076 strategy.A.cachingR = CacheSettingsLSC::Default;
13077 Ta_load = Ta;
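            // The SLM load (Ai) and store (Ao) can share registers when the element
            //  sizes and component counts match and the layouts correspond exactly.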
13078 state.aioShare = Ta.size() == Ta_ext.size()
13079 && Ta.components() == Ta_ext.components()
13080 && matchLayoutsBidirectional(
13081 Ta, state.Ai_layout, state.Ao_layout);
13082
            // If k-masking will be added later, check whether extra registers are needed.
13084 state.Ai_regCount = getRegCount(state.Ai_layout);
13085 if (!remK_A && remainderK && !state.Ai_strategy.address2D
13086 && !isRegisterColMajor(
13087 Ta_ext, state.Ai, state.Ai_strategy)) {
13088 std::vector<RegisterBlock> Ai_layoutKMasked;
13089 if (getRegLayout(Ta_ext, Ai_layoutKMasked, state.ma_slm,
13090 state.ka_slm, remM_A, true, false, true, 0, 0,
13091 state.Ai, state.Ai_strategy))
13092 state.Ai_regCount = std::max(
13093 state.Ai_regCount, getRegCount(Ai_layoutKMasked));
13094 }
13095
13096 // Offset A addresses in and out.
13097 state.effAi = state.effA;
13098 state.effA = state.ra.alloc_sub<uint32_t>(
13099 getHint(HintType::LongTerm, strategy));
13100 state.effAo = state.ra.alloc_sub<uint32_t>(
13101 getHint(HintType::LongTerm, strategy));
13102
13103 auto temp = state.ra.alloc_sub<uint32_t>(
13104 getHint(HintType::TempComp0, strategy));
13105 Subregister temp2;
13106
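            // noff: element offset in the SLM (Ao) layout between consecutive threads'
            //  portions of A. If a thread's portion subdivides a tile (tileSplit > 1),
            //  starting offsets are no longer a linear sequence and noffTile gives the
            //  offset of one full tile.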
13107 uint32_t noff, noffTile, tileSplit = 1;
13108
13109 switch (state.effCoopA) {
13110 case CoopSplit::Linear:
13111 // FIXME: assumes compatible tiling between global and SLM layouts.
13112 noff = state.ma_slm * state.ka_slm;
13113 break;
13114 case CoopSplit::MN:
13115 noff = untile(Ta, state.Ao, 0, state.ma_slm, 0,
13116 state.Ao.packSize, strategy.unrollKSLM);
13117 if (state.ma_slm < state.Ao.tileR
13118 && state.Ao.tileR < state.Ao.packSize) {
13119 // m division splits tiles -- starting offsets no longer a linear sequence.
13120 if (state.Ao.tileR % state.ma_slm) stub();
13121 tileSplit = state.Ao.tileR / state.ma_slm;
13122 noffTile = untile(Ta, state.Ao, 0, state.Ao.tileR, 0,
13123 state.Ao.packSize, strategy.unrollKSLM);
13124 }
13125 break;
13126 case CoopSplit::K:
13127 noff = untile(Ta, state.Ao, 0, 0, state.ka_slm,
13128 state.Ao.packSize, strategy.unrollKSLM);
13129 if (state.ka_slm < state.Ao.tileC
13130 && state.Ao.tileC < strategy.unrollKSLM) {
13131 // k division splits tiles -- starting offsets no longer a linear sequence.
13132 if (state.Ao.tileC % state.ka_slm) stub();
13133 tileSplit = state.Ao.tileC / state.ka_slm;
13134 noffTile = untile(Ta, state.Ao, 0, 0, state.Ao.tileC,
13135 state.Ao.packSize, strategy.unrollKSLM);
13136 }
13137 break;
13138 default: stub();
13139 }
13140
13141 int32_t A_slmStride
13142 = strategy.slmABufBlockSize(problem) * strategy.slmBuffers;
13143
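            // Compute the per-thread store offset within the SLM block (temp), this
            //  thread's SLM base address for A (effA), and the workshare offset for the
            //  global load (temp2, folded into effAi); the SLM store address is
            //  effAo = effA + temp.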
13144 if (tileSplit > 1) {
13145 if (!is_zero_or_pow2(tileSplit)) stub();
13146 shr(1, temp, state.lidN, log2(tileSplit));
13147 }
13148 gemmCalcWorkshareAOffset(temp2, Ai_params.offR, Ai_params.offC,
13149 state.Ai, state.Ai_strategy, state.ma_slm, state.ka_slm,
13150 problem, strategy, state);
13151 if (tileSplit > 1) {
13152 mulConstant(1, temp, temp, (noffTile - noff * tileSplit) * Ta);
13153 emad(1, temp, temp, state.lidN, noff * Ta, strategy, state);
13154 } else
13155 mulConstant(1, temp, state.lidN, noff * Ta);
13156 mulConstant(1, state.effA, state.lidM, A_slmStride);
13157 if (strategy.wg[LoopK] > 1)
13158 emad(1, state.effA, state.effA, state.lidK,
13159 A_slmStride * strategy.wg[LoopM], strategy, state);
13160 if (state.Ai_strategy.address2D) {
13161 if (Ai_params.offR != A_params.offR && A_params.offR.isValid())
13162 add(1, Ai_params.offR, Ai_params.offR, A_params.offR);
13163 if (Ai_params.offC != A_params.offC && A_params.offC.isValid())
13164 add(1, Ai_params.offC, Ai_params.offC, A_params.offC);
13165 } else
13166 eadd(1, state.effAi, state.effAi, temp2, strategy, state);
13167 add(1, state.effAo, state.effA, temp);
13168 if (problem.backward())
13169 add(1, state.effA, state.effA,
13170 (strategy.unrollKSLM - strategy.ka_load) * unrollM
13171 * Ta);
13172
13173 state.ra.safeRelease(temp2);
13174 state.ra.safeRelease(temp);
13175 }
13176 if (strategy.slmB) {
13177 coopSplit(false, state.kb_slm, state.nb_slm, strategy.unrollKSLM,
13178 unrollN, state.effCoopB, strategy.wg[LoopM], problem.B);
13179
13180 if (state.nb_slm < unrollN) {
13181 remN_B = false;
13182 remK_B = remainderK && strategy.slmEarlyKMask;
13183 }
13184 if (strategy.slmBTrans) {
13185 B_slmCP = state.kb_slm;
13186 if (strategy.kb_load % B_slmCP)
13187 throw std::runtime_error(
13188 "kb_load must be a multiple of kb_slm");
13189 }
13190 if ((state.kb_slm < B_slmCP) && (strategy.unrollKSLM != B_slmCP)
13191 && (B_tileR != B_slmCP))
13192 throw std::runtime_error(
13193 "kb_slm must be a multiple of crosspack, or unrollKSLM "
13194 "= crosspack.");
13195
13196 // Layout in from memory...
13197 state.Bi = problem.B;
13198 state.Bi_strategy = strategy.B;
13199 if (state.Bi_strategy.dpasw) {
13200 state.Bi_strategy.dpasw = false;
13201 state.Bi_strategy.tileC = 0;
13202 }
13203
13204 // ... layout out to SLM.
13205 state.Bo.layout = MatrixLayout::Pr;
13206 state.Bo.packSize = unrollN;
13207 state.Bo.crosspack = B_slmCP;
13208 state.Bo.setAlignment(state.Bo.packSize * Tb);
13209 state.Bo.tileR = (B_tileR || !B_tileC)
13210 ? B_tileR
13211 : std::max(opCount, strategy.kb_load);
13212 state.Bo.tileC = B_tileC;
13213
13214 bool colMajorIn
13215 = isRegisterColMajor(Tb_ext, state.Bi, state.Bi_strategy);
13216 bool colMajorSLM = isLargeCrosspack(Tb, B_slmCP);
13217 state.Bo_strategy.base = SLM;
13218 state.Bo_strategy.accessType = (colMajorIn == colMajorSLM)
13219 ? AccessType::Block
13220 : AccessType::Scattered;
13221 state.Bo_strategy.smode = ScatterSIMD::Default;
13222
13223 if (state.Bi.layout == MatrixLayout::T
13224 && state.Bi_strategy.accessType == AccessType::Block2DVNNI
13225 && isLargeCrosspack(Tb, B_slmCP)) {
13226 state.Bo_strategy.accessType = AccessType::ChannelScattered;
13227 state.Bo_strategy.smode = ScatterSIMD::Narrow;
13228 }
13229 state.Bo_strategy.padded = true;
13230 state.Bo_strategy.atomic = false;
13231 state.Bo_strategy.address2D = false;
13232 state.Bo_strategy.newDP = (hw >= HW::XeHPG);
13233 state.Bo_strategy.cachingW = CacheSettingsLSC::Default;
13234
13235 // Layout in from memory...
13236 if (!getRegLayout(Tb_ext, state.Bi_layout, state.kb_slm,
13237 state.nb_slm, remK_B, remN_B, false, true, 0, 0,
13238 state.Bi, state.Bi_strategy))
13239 return false;
13240
13241 // ... layout out to SLM...
13242 remK_B = remN_B = false;
13243 if (!getRegLayout(Tb, state.Bo_layout, state.kb_slm, state.nb_slm,
13244 remK_B, remN_B, true, true, 0, 0, state.Bo,
13245 state.Bo_strategy))
13246 return false;
13247
13248 // ... and layout back from SLM.
13249 problem.B = state.Bo;
13250 strategy.B.base = SLM;
13251 strategy.B.accessType = AccessType::Block;
13252 strategy.B.address2D = false;
13253 strategy.B.newDP = (hw >= HW::XeHPG);
13254 strategy.B.cachingR = CacheSettingsLSC::Default;
13255 Tb_load = Tb;
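            // As for A, the SLM load (Bi) and store (Bo) can share registers when types
            //  and layouts match.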
13256 state.bioShare = Tb.size() == Tb_ext.size()
13257 && Tb.components() == Tb_ext.components()
13258 && matchLayoutsBidirectional(
13259 Tb, state.Bi_layout, state.Bo_layout);
13260
            // If k-masking will be added later, check whether extra registers are needed.
13262 state.Bi_regCount = getRegCount(state.Bi_layout);
13263 if (!remK_B && remainderK && !state.Bi_strategy.address2D
13264 && isRegisterColMajor(
13265 Tb_ext, state.Bi, state.Bi_strategy)) {
13266 std::vector<RegisterBlock> Bi_layoutKMasked;
13267 if (getRegLayout(Tb_ext, Bi_layoutKMasked, state.kb_slm,
13268 state.nb_slm, true, remN_B, false, true, 0, 0,
13269 state.Bi, state.Bi_strategy))
13270 state.Bi_regCount = std::max(
13271 state.Bi_regCount, getRegCount(Bi_layoutKMasked));
13272 }
13273
13274 // Offset B addresses in and out.
13275 state.effBi = state.effB;
13276 state.effB = state.ra.alloc_sub<uint32_t>(
13277 getHint(HintType::LongTerm, strategy));
13278 state.effBo = state.ra.alloc_sub<uint32_t>(
13279 getHint(HintType::LongTerm, strategy));
13280
13281 auto temp = state.ra.alloc_sub<uint32_t>(
13282 getHint(HintType::TempComp0, strategy));
13283 Subregister temp2;
13284
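            // As for A: moff is the element offset in the SLM (Bo) layout between
            //  consecutive threads' portions of B; moffTile/tileSplit handle portions
            //  that subdivide a tile.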
13285 uint32_t moff, moffTile, tileSplit = 1;
13286
13287 switch (state.effCoopB) {
13288 case CoopSplit::Linear:
13289 moff = state.kb_slm * state.nb_slm;
13290 break;
13291 case CoopSplit::MN:
13292 moff = untile(Tb, state.Bo, 0, 0, state.nb_slm,
13293 strategy.unrollKSLM, state.Bo.packSize);
13294 if (state.nb_slm < state.Bo.tileC
13295 && state.Bo.tileC < state.Bo.packSize) {
13296 if (state.Bo.tileC % state.nb_slm) stub();
13297 tileSplit = state.Bo.tileC / state.nb_slm;
13298 moffTile = untile(Tb, state.Bo, 0, 0, state.Bo.tileC,
13299 strategy.unrollKSLM, state.Bo.packSize);
13300 }
13301 break;
13302 case CoopSplit::K:
13303 moff = untile(Tb, state.Bo, 0, state.kb_slm, 0,
13304 strategy.unrollKSLM, state.Bo.packSize);
13305 if (state.kb_slm < state.Bo.tileR) {
13306 if (state.Bo.tileR % state.kb_slm) stub();
13307 tileSplit = state.Bo.tileR / state.kb_slm;
13308 moffTile = untile(Tb, state.Bo, 0, state.Bo.tileR, 0,
13309 strategy.unrollKSLM, state.Bo.packSize);
13310 }
13311 break;
13312 default: stub();
13313 }
13314
13315 int32_t B_slmStride
13316 = strategy.slmBBufBlockSize(problem) * strategy.slmBuffers;
13317
13318 if (tileSplit > 1) {
13319 if (!is_zero_or_pow2(tileSplit)) stub();
13320 shr(1, temp, state.lidM, log2(tileSplit));
13321 }
13322 gemmCalcWorkshareBOffset(temp2, Bi_params.offR, Bi_params.offC,
13323 state.Bi, state.Bi_strategy, state.kb_slm, state.nb_slm,
13324 problem, strategy, state);
13325 if (tileSplit > 1) {
13326 mulConstant(1, temp, temp, (moffTile - moff * tileSplit) * Tb);
13327 emad(1, temp, temp, state.lidM, moff * Tb, strategy, state);
13328 } else
13329 mulConstant(1, temp, state.lidM, moff * Tb);
13330 mulConstant(1, state.effB, state.lidN, B_slmStride);
13331 if (strategy.wg[LoopK] > 1)
13332 emad(1, state.effB, state.effB, state.lidK,
13333 B_slmStride * strategy.wg[LoopN], strategy, state);
13334 if (state.Bi_strategy.address2D) {
13335 if (Bi_params.offR != B_params.offR && B_params.offR.isValid())
13336 add(1, Bi_params.offR, Bi_params.offR, B_params.offR);
13337 if (Bi_params.offC != B_params.offC && B_params.offC.isValid())
13338 add(1, Bi_params.offC, Bi_params.offC, B_params.offC);
13339 } else
13340 eadd(1, state.effBi, state.effBi, temp2, strategy, state);
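            // B's SLM region follows A's region within each SLM buffer.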
13341 if (strategy.slmABufSize(problem) > 0)
13342 add(1, state.effB, state.effB, strategy.slmABufSize(problem));
13343 add(1, state.effBo, state.effB, temp);
13344 if (problem.backward())
13345 add(1, state.effB, state.effB,
13346 (strategy.unrollKSLM - strategy.kb_load) * unrollN
13347 * Tb);
13348
13349 state.ra.safeRelease(temp2);
13350 state.ra.safeRelease(temp);
13351 }
13352 }
13353
13354 // Starting address adjustments for DPASW.
13355 bool cColMajor = isRegisterColMajor(Tc_ext, problem.C, strategy.C);
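    // dpasw shares operand data between fused thread pairs: the odd thread of each
    //  pair (by lidM or lidN parity) offsets its B or A starting address by one tile.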
13356 if (strategy.dpasw) {
13357 if (cColMajor) {
13358 int t = strategy.B.tileC;
13359 and_(1 | nz | state.flagAP, null.uw(), state.lidM, 1);
13360 switch (problem.B.layout) {
13361 case MatrixLayout::N:
13362 emad(1 | state.flagAP, state.effB, state.effB,
13363 state.inputs.ldb, Immediate::w(t), strategy, state);
13364 break;
13365 case MatrixLayout::T:
13366 eadd(1 | state.flagAP, state.effB, state.effB, t * Tb_load,
13367 strategy, state);
13368 break;
13369 case MatrixLayout::Pr:
13370 eadd(1 | state.flagAP, state.effB, state.effB,
13371 untile(Tb_load, problem.B, 0, 0, t, unrollK,
13372 unrollN)
13373 * Tb_load,
13374 strategy, state);
13375 break;
13376 default: stub();
13377 }
13378 } else {
13379 int t = strategy.A.tileR;
13380 and_(1 | nz | state.flagAP, null.uw(), state.lidN, 1);
13381 switch (problem.A.layout) {
13382 case MatrixLayout::T:
13383 emad(1 | state.flagAP, state.effA, state.effA,
13384 state.inputs.lda, Immediate::w(t), strategy, state);
13385 break;
13386 case MatrixLayout::N:
13387 eadd(1 | state.flagAP, state.effA, state.effA, t * Ta_load,
13388 strategy, state);
13389 break;
13390 case MatrixLayout::Pc:
13391 eadd(1 | state.flagAP, state.effA, state.effA,
13392 untile(Ta_load, problem.A, 0, t, 0, unrollM,
13393 unrollK)
13394 * Ta_load,
13395 strategy, state);
13396 break;
13397 default: stub();
13398 }
13399 }
13400 }
13401
13402 // Get register layouts for A/B/C.
13403 if (!getRegLayout(Ta_load, state.A_layout, unrollM, strategy.ka_load,
13404 remM_A, remK_A, false, true, 0, 0, problem.A, strategy.A))
13405 return false;
13406 if (!getRegLayout(Tb_load, state.B_layout, strategy.kb_load, unrollN,
13407 remK_B, remN_B, false, true, 0, 0, problem.B, strategy.B))
13408 return false;
13409
13410 if (state.copyC) {
13411 makeUnbackedRegLayout(Tc, state.C_layout, unrollM, unrollN, cColMajor,
13412 1, strategy.C.tileR, strategy.C.tileC, true);
13413 if (!getRegLayout(Tc_ext, state.C_layoutExt, unrollM, unrollN, remM_Ce,
13414 remN_Ce, true, false, 0, 0, problem.C, state.Cext_strategy))
13415 return false;
13416 } else {
13417 if (!getRegLayout(Tc, state.C_layout, unrollM, unrollN, remM_C, remN_C,
13418 true, false, 0, 0, problem.C, strategy.C))
13419 return false;
13420 }
13421
13422 if (!strategy.altCRemainder && (remM_Ce || remN_Ce)) {
        // Try preparing C layout without masking (may reduce memory accesses).
        // Only use it if it is compatible with the masked layout and saves on send instructions.
13425 auto &layoutExt = state.copyC ? state.C_layoutExt : state.C_layout;
13426 (void)getRegLayout(Tc_ext, state.C_layoutExtUnmasked, unrollM, unrollN,
13427 false, false, true, false, 0, 0, problem.C,
13428 state.Cext_strategy);
13429 if (state.C_layoutExtUnmasked.size() == layoutExt.size()
13430 || (!state.copyC
13431 && !matchLayouts(
13432 Tc, layoutExt, state.C_layoutExtUnmasked)))
13433 state.C_layoutExtUnmasked.clear();
13434 }
13435
13436 if (!state.copyC) state.C_layoutExt = state.C_layout;
13437
13438 if (hasRowFragmenting(state.A_layout)
13439 || hasColumnFragmenting(state.B_layout)) {
13440 status << "Can't fragment A or B.\n";
13441 return false;
13442 }
13443
13444 // Prepare to repack A/B if needed.
13445 int crosspackA, crosspackB, tileM_A, tileK_A, tileK_B, tileN_B;
13446 std::tie(crosspackA, crosspackB)
13447 = targetKernelCrosspack(hw, problem, strategy);
13448 std::tie(tileM_A, tileK_A, tileK_B, tileN_B)
13449 = targetKernelTiling(hw, problem, strategy);
13450
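    // Repack when the register layout's crosspack or tiling doesn't match what the
    //  inner product needs, or (if not staged through SLM) when the load type differs
    //  from the compute type in size or component count.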
13451 state.repackA
13452 |= (crosspackA && !hasFullCrosspack(state.A_layout, crosspackA))
13453 || !hasTiling(state.A_layout, tileM_A, tileK_A);
13454 state.repackB
13455 |= (crosspackB && !hasFullCrosspack(state.B_layout, crosspackB))
13456 || !hasTiling(state.B_layout, tileK_B, tileN_B);
13457
13458 state.repackA |= (Ta.size() != Ta_ext.size()
13459 || Ta.components() != Ta_ext.components())
13460 && !strategy.slmA;
13461 state.repackB |= (Tb.size() != Tb_ext.size()
13462 || Tb.components() != Tb_ext.components())
13463 && !strategy.slmB;
13464
13465 if (crosspackA == 0) crosspackA = 1;
13466 if (crosspackB == 0) crosspackB = 1;
13467
13468 bool splitA = false, splitB = false;
13469
13470 if (state.repackA)
13471 makeUnbackedRegLayout(Ta, state.Ar_layout, unrollM, strategy.ka_load,
13472 isLayoutColMajor(state.A_layout), crosspackA, tileM_A, tileK_A,
13473 true, splitA);
13474 if (state.repackB)
13475 makeUnbackedRegLayout(Tb, state.Br_layout, strategy.kb_load, unrollN,
13476 isLayoutColMajor(state.B_layout), crosspackB, tileK_B, tileN_B,
13477 true, splitB);
13478
13479 // Prepare layouts for row/column sum calculation.
13480 bool globalCM = isLayoutColMajor(state.C_layout);
13481 if (problem.needsASums()) {
13482 state.systolicSumA = strategy.systolic && globalCM;
13483 state.slmASums = strategy.slmA && !state.systolicSumA;
13484
13485 auto As_srcLayout = state.slmASums
13486 ? state.Ao_layout
13487 : state.repackA ? state.Ar_layout : state.A_layout;
13488 makeSumLayout(
13489 false, Ta, As_srcLayout, Tc, state.As_layout, strategy, state);
13490 if (state.systolicSumA)
13491 setupTeardownAccumulateSumSystolic(
13492 true, Tb, problem, strategy, state);
13493 }
13494 if (problem.needsBSums()) {
13495 state.systolicSumB = strategy.systolic && !globalCM;
13496 state.slmBSums = strategy.slmB && !state.systolicSumB;
13497
13498 auto Bs_srcLayout = state.slmBSums
13499 ? state.Bo_layout
13500 : state.repackB ? state.Br_layout : state.B_layout;
13501 makeSumLayout(
13502 true, Tb, Bs_srcLayout, Tc, state.Bs_layout, strategy, state);
13503 if (state.systolicSumB)
13504 setupTeardownAccumulateSumSystolic(
13505 true, Ta, problem, strategy, state);
13506 }
13507
13508 // Round up needed A/B flag registers; hold off on C.
13509 // Try first without virtual flags and retry if needed.
13510 // m/n cooperative SLM copies use k masking, so skip those masks for now.
13511 auto &masks = state.AB_masks;
13512
13513 auto assignAllMasks = [&]() {
13514 return assignMasks(state.A_layout, LoopM, LoopK, masks, strategy, state)
13515 && assignMasks(
13516 state.Ap_layout, LoopM, LoopK, masks, strategy, state)
13517 && assignMasks(
13518 state.B_layout, LoopK, LoopN, masks, strategy, state)
13519 && assignMasks(
13520 state.Bp_layout, LoopK, LoopN, masks, strategy, state)
13521 && ((state.effCoopA != CoopSplit::K)
13522 || assignMasks(state.Ai_layout, LoopM, LoopK, masks,
13523 strategy, state))
13524 && ((state.effCoopB != CoopSplit::K)
13525 || assignMasks(state.Bi_layout, LoopK, LoopN, masks,
13526 strategy, state));
13527 };
13528
13529 state.lateKLoopCheck = false;
13530 bool success = assignAllMasks();
13531 if (!success && state.vflagStorage.isInvalid()) {
13532 status << "Retrying with virtual flags." << status_stream::endl;
13533 allocVFlagStorage(strategy, state);
13534 success = assignAllMasks();
13535 state.lateKLoopCheck = true;
13536 }
13537
13538 if (!success) return false;
13539
13540 loadMasks(masks, state.remainders, strategy, state);
13541
13542 // Temporary: move add64 out of the way (later: general cramming).
13543 if (state.add64.isValid()) {
13544 auto oldAdd64 = state.add64;
13545 state.ra.safeRelease(state.add64);
13546 state.add64 = state.ra.alloc_sub<uint32_t>();
13547 if (oldAdd64 != state.add64) mov(1, state.add64, oldAdd64);
13548 }
13549
13550 // Allocate data registers.
13551 gemmAllocRegs(problem, strategy, state);
13552 gemmAllocAoBoRegs(strategy, state);
13553
13554 // Allocate address registers for A/B loads. We don't need C addresses yet.
13555 allocAddrRegs(state.A_addrs, state.A_layout, problem.A, strategy.A, state);
13556 allocAddrRegs(state.B_addrs, state.B_layout, problem.B, strategy.B, state);
13557 allocAddrRegs(state.Ap_addrs, state.Ap_layout, globalA, strategy.A_prefetch,
13558 state);
13559 allocAddrRegs(state.Bp_addrs, state.Bp_layout, globalB, strategy.B_prefetch,
13560 state);
13561 allocAddrRegs(state.Ai_addrs, state.Ai_layout, state.Ai, state.Ai_strategy,
13562 state);
13563 allocAddrRegs(state.Bi_addrs, state.Bi_layout, state.Bi, state.Bi_strategy,
13564 state);
13565 allocAddrRegs(state.Ao_addrs, state.Ao_layout, state.Ao, state.Ao_strategy,
13566 state);
13567 allocAddrRegs(state.Bo_addrs, state.Bo_layout, state.Bo, state.Bo_strategy,
13568 state);
13569
13570 // Free up some C registers temporarily for use in address calculations.
13571 releaseRanges(state.C_regs, state);
13572
13573 // Set up address registers.
13574 gemmCacheLDABMultiples(problem, strategy, state);
13575 setupAddr(Ta_ext, state.Ap_addrs, state.effAp, state.Ap_layout,
13576 state.inputs.lda, globalA, strategy.A_prefetch, strategy, state,
13577 Ap_params, state.ldaMultiples);
13578 setupAddr(Tb_ext, state.Bp_addrs, state.effBp, state.Bp_layout,
13579 state.inputs.ldb, globalB, strategy.B_prefetch, strategy, state,
13580 Bp_params, state.ldbMultiples);
13581 setupAddr(Ta_ext, state.Ai_addrs, state.effAi, state.Ai_layout,
13582 state.inputs.lda, state.Ai, state.Ai_strategy, strategy, state,
13583 Ai_params, state.ldaMultiples);
13584 setupAddr(Tb_ext, state.Bi_addrs, state.effBi, state.Bi_layout,
13585 state.inputs.ldb, state.Bi, state.Bi_strategy, strategy, state,
13586 Bi_params, state.ldbMultiples);
13587 setupAddr(Ta, state.Ao_addrs, state.effAo, state.Ao_layout, Subregister(),
13588 state.Ao, state.Ao_strategy, strategy, state);
13589 setupAddr(Tb, state.Bo_addrs, state.effBo, state.Bo_layout, Subregister(),
13590 state.Bo, state.Bo_strategy, strategy, state);
13591 setupAddr(Ta_load, state.A_addrs, state.effA, state.A_layout,
13592 state.inputs.lda, problem.A, strategy.A, strategy, state, A_params,
13593 state.ldaMultiples);
13594 setupAddr(Tb_load, state.B_addrs, state.effB, state.B_layout,
13595 state.inputs.ldb, problem.B, strategy.B, strategy, state, B_params,
13596 state.ldbMultiples);
13597
13598 // Free unneeded registers after address setup.
13599 if (!needsKLoopReset(problem)) {
13600 if (!state.isNested) {
13601 state.ra.safeRelease(state.h0);
13602 if (strategy.A.address2D
13603 && (!strategy.prefetchA || strategy.A_prefetch.address2D))
13604 state.ra.safeRelease(state.inputs.lda);
13605 if (strategy.B.address2D
13606 && (!strategy.prefetchB || strategy.B_prefetch.address2D))
13607 state.ra.safeRelease(state.inputs.ldb);
13608 if (!problem.hasBinaryPostOp())
13609 if (!strategy.C.address2D
13610 && (!strategy.prefetchC
13611 || !strategy.C_prefetch.address2D)) {
13612 state.ra.safeRelease(state.i0);
13613 state.ra.safeRelease(state.j0);
13614 }
13615 }
13616 if (state.Ai_strategy.address2D) {
13617 if (Ai_params.offR != A_params.offR)
13618 state.ra.safeRelease(Ai_params.offR);
13619 if (Ai_params.offC != A_params.offC)
13620 state.ra.safeRelease(Ai_params.offC);
13621 }
13622 if (state.Bi_strategy.address2D) {
13623 if (Bi_params.offR != B_params.offR)
13624 state.ra.safeRelease(Bi_params.offR);
13625 if (Bi_params.offC != B_params.offC)
13626 state.ra.safeRelease(Bi_params.offC);
13627 }
13628 if (strategy.A_prefetch.address2D) {
13629 if (Ap_params.offR != A_params.offR)
13630 state.ra.safeRelease(Ap_params.offR);
13631 if (Ap_params.offC != A_params.offC)
13632 state.ra.safeRelease(Ap_params.offC);
13633 }
13634 if (strategy.B_prefetch.address2D) {
13635 if (Bp_params.offR != B_params.offR)
13636 state.ra.safeRelease(Bp_params.offR);
13637 if (Bp_params.offC != B_params.offC)
13638 state.ra.safeRelease(Bp_params.offC);
13639 }
13640
13641 if (!one_of(state.effAp, state.effA, state.effAi))
13642 state.ra.safeRelease(state.effAp);
13643 if (!one_of(state.effBp, state.effB, state.effBi))
13644 state.ra.safeRelease(state.effBp);
13645 }
13646
13647 releaseLDMultiples(state.ldaMultiples, state);
13648 releaseLDMultiples(state.ldbMultiples, state);
13649 releaseIndexVec(state);
13650
13651 reclaimRanges(state.C_regs, state);
13652
13653 // Allocate tokens.
13654 success = true;
13655 for (int q = 0; q < strategy.A_copies; q++)
13656 success &= allocateTokens(state.A_layout, state.A_regs[q], state);
13657 for (int q = 0; q < strategy.B_copies; q++)
13658 success &= allocateTokens(state.B_layout, state.B_regs[q], state);
13659 for (int q = 0; q < strategy.slmCopies; q++) {
13660 if (strategy.slmA)
13661 success &= allocateTokens(state.Ai_layout, state.Ai_regs[q], state);
13662 if (strategy.slmB)
13663 success &= allocateTokens(state.Bi_layout, state.Bi_regs[q], state);
13664 }
13665 if (strategy.slmA && !state.aioShare)
13666 success &= allocateTokens(state.Ao_layout, state.Ao_regs, state);
13667 if (strategy.slmB && !state.bioShare)
13668 success &= allocateTokens(state.Bo_layout, state.Bo_regs, state);
13669 success &= allocateTokens(
13670 state.Ap_layout, state.Ap_regs, state, state.Ap_addrs);
13671 success &= allocateTokens(
13672 state.Bp_layout, state.Bp_regs, state, state.Bp_addrs);
13673 if (!success) {
13674 if (hw >= HW::Gen12LP)
13675 status << "Not enough tokens for k loop." << status_stream::endl;
13676 clearTokenAllocations(hw, state);
13677 }
13678
13679 // Load C now if configured.
13680 // - temporarily free A/B data regs to use as C headers
13681 // - do beta scaling
13682 if (cLoadAhead) {
13683 if (problem.checkBeta0 && !problem.beta_real.fixed()) stub();
13684 if (state.C_accCount > 0) stub();
13685 if (strategy.kParallelLocal) stub();
13686
13687 releaseRanges(state.A_regs, state);
13688 releaseRanges(state.B_regs, state);
13689 if (!state.Ar_regs.empty()) releaseRanges(state.Ar_regs, state);
13690 if (!state.Br_regs.empty()) releaseRanges(state.Br_regs, state);
13691
13692 status << "Loading C" << status_stream::endl;
13693 gemmAccessC(COperation::Load, problem, strategy, state);
13694
13695 gemmBetaScale(problem, strategy, state);
13696 if (!state.Br_regs.empty()) reclaimRanges(state.Br_regs, state);
13697 if (!state.Ar_regs.empty()) reclaimRanges(state.Ar_regs, state);
13698 reclaimRanges(state.B_regs, state);
13699 reclaimRanges(state.A_regs, state);
13700 }
13701
13702 for (int q = 0; q < state.C_count; q++)
13703 releaseLDMultiples(state.ldcMultiples[q], state);
13704 releaseIndexVec(state);
13705
13706 // Release 64-bit emulation registers as they aren't needed in the inner loop.
13707 // Could also move r0 to acc here.
13708 if (state.emulate.temp[0].isValid()) {
13709 for (int q = 0; q < 2; q++) {
13710 state.emulate64TempSave[q] = state.emulate.temp[q];
13711 state.ra.safeRelease(state.emulate.temp[q]);
13712 }
13713 if (GRF::bytes(hw) == 64) {
13714 // Need a whole flag register to do emulated SIMD16 arithmetic.
13715 state.emulate.flag = state.raVFlag.alloc();
13716 state.emulate.flagOffset = 0;
13717 } else {
13718 state.emulate.flag = state.flagAP;
13719 state.emulate.flagOffset = 8;
13720 state.lateKLoopCheck = false;
13721 }
13722 }
13723
13724 return true;
13725}
13726
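// Tear down the state created by gemmAccumulateCSetup: free A/B address, data, and flag
//   registers, restore problem/strategy fields modified for SLM copies, and move
//   accumulator contents in with the rest of C.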
13727template <HW hw>
13728void gemm_kernel_generator_t<hw>::gemmAccumulateCTeardown(
13729 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
13730 // We're done with A and B. Free their address, data, and flag registers.
13731 // Also done with loop counter.
13732 safeReleaseMaskAssignments(state.AB_masks, state);
13733 safeReleaseRanges(state.A_addrs, state);
13734 safeReleaseRanges(state.B_addrs, state);
13735 safeReleaseRanges(state.Ai_addrs, state);
13736 safeReleaseRanges(state.Bi_addrs, state);
13737 safeReleaseRanges(state.Ao_addrs, state);
13738 safeReleaseRanges(state.Bo_addrs, state);
13739 safeReleaseRanges(state.Ap_addrs, state);
13740 safeReleaseRanges(state.Bp_addrs, state);
13741
13742 safeReleaseRanges(state.A_regs, state);
13743 safeReleaseRanges(state.Ar_regs, state);
13744 safeReleaseRanges(state.Ai_regs, state);
13745 safeReleaseRanges(state.Ao_regs, state);
13746 safeReleaseRanges(state.Ap_regs, state);
13747 safeReleaseRanges(state.B_regs, state);
13748 safeReleaseRanges(state.Br_regs, state);
13749 safeReleaseRanges(state.Bi_regs, state);
13750 safeReleaseRanges(state.Bo_regs, state);
13751 safeReleaseRanges(state.Bp_regs, state);
13752 state.ra.safeRelease(state.broadcast_regs);
13753 safeReleaseRanges(state.tempMul_regs, state);
13754 clearTokenAllocations(hw, state);
13755
13756 state.A_layout.clear();
13757 state.B_layout.clear();
13758 state.Ai_layout.clear();
13759 state.Bi_layout.clear();
13760 state.Ao_layout.clear();
13761 state.Bo_layout.clear();
13762 state.Ar_layout.clear();
13763 state.Br_layout.clear();
13764 state.Ap_layout.clear();
13765 state.Bp_layout.clear();
13766 state.Cp_layout.clear();
13767
13768 if (state.systolicSumA || state.systolicSumB)
13769 setupTeardownAccumulateSumSystolic(
13770 false, Type::invalid, problem, strategy, state);
13771
13772 // Restore effA/B if needed.
13773 bool restoreEffAB = false;
13774
13775 restoreEffAB |= (problem.abOffset == ABOffset::Load);
13776
13777 if (restoreEffAB) {
13778 Subregister aiOff, biOff;
13779 Subregister aiOffR, biOffR;
13780 Subregister aiOffC, biOffC;
13781 if (strategy.slmA)
13782 gemmCalcWorkshareAOffset(aiOff, aiOffR, aiOffC, state.Ai,
13783 state.Ai_strategy, state.ma_slm, state.ka_slm, problem,
13784 strategy, state);
13785 if (strategy.slmB)
13786 gemmCalcWorkshareBOffset(biOff, biOffR, biOffC, state.Bi,
13787 state.Bi_strategy, state.kb_slm, state.nb_slm, problem,
13788 strategy, state);
13789 if (strategy.slmA)
13790 eadd(1, state.effAi, state.effAi, -aiOff, strategy, state);
13791 if (strategy.slmB)
13792 eadd(1, state.effBi, state.effBi, -biOff, strategy, state);
13793
13794 state.ra.safeRelease(aiOff);
13795 state.ra.safeRelease(biOff);
13796 state.ra.safeRelease(aiOffR);
13797 state.ra.safeRelease(biOffR);
13798 state.ra.safeRelease(aiOffC);
13799 state.ra.safeRelease(biOffC);
13800 }
13801
13802 // Restore A/B addresses and strategies that were modified by SLM copies.
13803 if (strategy.slmA) {
13804 state.ra.safeRelease(state.effA);
13805 state.ra.safeRelease(state.effAo);
13806 state.effA = state.effAi;
13807 state.effAi = invalid;
13808 state.Ta_load = problem.Ta_ext;
13809 problem.A = state.Ai;
13810 strategy.A = state.Ai_strategy;
13811 }
13812 if (strategy.slmB) {
13813 state.ra.safeRelease(state.effB);
13814 state.ra.safeRelease(state.effBo);
13815 state.effB = state.effBi;
13816 state.effBi = invalid;
13817 state.Tb_load = problem.Tb_ext;
13818 problem.B = state.Bi;
13819 strategy.B = state.Bi_strategy;
13820 }
13821
13822 // Put accumulators with the rest of C.
13823 if (state.C_accCount > 0) {
13824 // Reclaim the bottom registers of C.
13825 reclaimRanges(state.C_regs[0], state);
13826
13827 auto e = elementsPerGRF<uint32_t>(hw);
13828 for (int i = 0; i < state.C_accCount; i += 2)
13829 mov<uint32_t>(2 * e, state.C_regs[0][i], AccumulatorRegister(i));
13830 }
13831
13832 // Restore emulation registers.
13833 if (state.emulate64TempSave[0].isValid()) {
13834 for (int q = 0; q < 2; q++) {
13835 state.emulate.temp[q] = state.emulate64TempSave[q];
13836 if (state.emulate64TempSave[q].isValid())
13837 state.ra.claim(state.emulate64TempSave[q]);
13838 }
13839 if (GRF::bytes(hw) == 64) state.raVFlag.release(state.emulate.flag);
13840 state.emulate.flag = invalid;
13841 state.emulate.flagOffset = 0;
13842 }
13843}
13844
13845// Perform the body of the GEMM computation, updating a block of C.
13846template <HW hw>
13847bool gemm_kernel_generator_t<hw>::gemmAccumulateC(
13848 GEMMProblem &problem_, GEMMStrategy &strategy_, GEMMState &state) {
13849 if (strategy_.fixedSystolic) {
13850 if (problem_.sumA || problem_.sumB
13851 || problem_.abOffset == ABOffset::Calc)
13852 stub();
13853 return strategy_.splitCopy
13854 ? sysgemm2AccumulateC(problem_, strategy_, state)
13855 : sysgemmAccumulateC(problem_, strategy_, state);
13856 }
13857
13858 auto problem = problem_;
13859 auto strategy = strategy_;
13860
13861 if (!gemmAccumulateCSetup(problem, strategy, state)) return false;
13862
13863 // Synthesize k loop. If configured, choose between 32-bit adds and 64-bit adds.
13864 if (strategy.checkAdd32 && state.add64.isValid()) {
13865 Label loop64, done;
13866 bool success = true;
13867
13868 cmp(1 | ne | state.flagAP, state.add64, uint16_t(0));
13869 jmpi(1 | state.flagAP, loop64);
13870 state.ra.safeRelease(state.add64);
13871
13872 status << "k loop: 32-bit address update" << status_stream::endl;
13873 strategy.emulate.emulate64_add32 = true;
13874 auto substate32 = state;
13875 success &= gemmKLoop(problem, strategy, substate32);
13876 jmpi(1, done);
13877
13878 mark(loop64);
13879 status << "k loop: 64-bit address update" << status_stream::endl;
13880 strategy.emulate.emulate64_add32 = false;
13881 success &= gemmKLoop(problem, strategy, state);
13882
13883 mark(done);
13884 if (!success) return false;
13885 } else {
13886 state.ra.safeRelease(state.add64);
13887 if (!gemmKLoop(problem, strategy, state)) return false;
13888 }
13889
13890 gemmAccumulateCTeardown(problem, strategy, state);
13891
13892 return true;
13893}
13894
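// Set up the block-0 address registers for C (and for the unmasked C layout, if any),
//   using the supplied 2D address parameters or defaults derived from i0/j0 and the
//   current remainders.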
13895template <HW hw>
13896void gemm_kernel_generator_t<hw>::setupCAddr0(GRFRange (&C_addr0)[2],
13897 GRFRange (&C_addr0Unmasked)[2], const vector<RegisterBlock> &C_layout,
13898 const vector<RegisterBlock> &C_layoutUnmasked, int C_count,
13899 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state,
13900 const Address2DParams *params) {
13901 Address2DParams defaultParams;
13902 if (!params) {
13903 defaultParams.rows = state.inputs.m;
13904 defaultParams.cols = state.inputs.n;
13905 defaultParams.offR = state.i0;
13906 defaultParams.offC = state.j0;
13907 defaultParams.remR = state.remainders[LoopM];
13908 defaultParams.remC = state.remainders[LoopN];
13909 params = &defaultParams;
13910 }
13911 for (int q = 0; q < C_count; q++) {
13912 C_addr0[q] = state.ra.alloc_range(
13913 addrGRFCount(problem.C, strategy.C, C_layout[0]));
13914 setupAddr(C_addr0[q], state.effC[q], C_layout[0], state.inputs.ldc[q],
13915 problem.Tc.size(), problem.C, strategy.C, strategy, state,
13916 *params, state.ldcMultiples[q]);
13917 }
13918 if (!C_layoutUnmasked.empty())
13919 for (int q = 0; q < C_count; q++) {
13920 C_addr0Unmasked[q] = state.ra.alloc_range(
13921 addrGRFCount(problem.C, strategy.C, C_layoutUnmasked[0]));
13922 setupAddr(C_addr0Unmasked[q], state.effC[q], C_layoutUnmasked[0],
13923 state.inputs.ldc[q], problem.Tc.size(), problem.C,
13924 strategy.C, strategy, state, *params,
13925 state.ldcMultiples[q]);
13926 }
13927}
13928
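// Update C in memory: apply any early C offset, set up post-op handling, scale by alpha
//   when needed, then perform the full update/store of C via gemmAccessC.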
13929template <HW hw>
13930bool gemm_kernel_generator_t<hw>::gemmUpdateC(
13931 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
13932
13933 auto Ts = problem.Ts;
13934
13935 status << "C update" << status_stream::endl;
13936
13937 auto &alphar = problem.alpha_real;
13938 auto &betar = problem.beta_real;
13939
13940 if (strategy.cLoadAhead) {
13941 betar = 0;
13942 if (!problem.alpha1()) stub();
13943 }
13944
13945 // C early offset.
13946 if (problem.cOffset == COffset::Pre)
13947 if (!gemmApplyCOffsetDispatch(problem, strategy, state)) return false;
13948
13949 // Prepare legacy eltwise postop injector if configured.
13950 GRFRange postOpScratch;
13951 if (useEltwiseInjector(problem)) {
13952 if (problem.hasBinaryPostOp()) stub();
13953
13954 const int eu_count = 0;
13955 postOpInjector.reset(new Injector(this, problem.Ts.get_dnnl_type(),
13956 problem.postOps, eu_count, GRFRange(), problem.postOpFwd));
13957 if (!postOpInjector) stub();
13958
13959 postOpScratch = state.ra.try_alloc_range(
13960 postOpInjector->preferred_scratch_regs());
13961 if (postOpScratch.isInvalid())
13962 postOpScratch
13963 = state.ra.alloc_range(postOpInjector->min_scratch_regs());
13964 postOpInjector->set_scratch(postOpScratch);
13965 }
13966
13967 // Convert C to the type of alpha/beta if needed and if possible (no data size change).
13968 // If not possible, must be done at a lower level during C update.
13969 bool successfulConvert = true;
13970
13971 if (problem.needsTsConvert())
13972 successfulConvert = gemmConvertC(Ts, problem, strategy, state);
13973
    // Scale by alpha now if alpha and beta are both nontrivial. TODO: move above the beta = 0 check
    //  and handle double precision correctly (load alpha into a register first).
13976 // Also scale if atomically updating C or for split-complex.
13977 bool nontrivialAlpha = !problem.alpha1() && !problem.alphaM1();
13978 bool forceScale = !problem.alpha1() && strategy.C.atomic;
13979
13980 if (!problem.alpha1() && problem.hasBinaryPostOp()) {
13981 forceScale = true;
13982 if (!successfulConvert) stub();
13983 }
13984
13985 if (successfulConvert
13986 && ((nontrivialAlpha && (!problem.beta1() || strategy.doubleWA))
13987 || forceScale)) {
13988
13989 if (alphar == -1) {
13990 map(hw, Ts.real(), state.C_regs[0], state.C_regs[0], strategy,
13991 [&](int esize, GRF acc, GRF _) { mov(esize, acc, -acc); });
13992 } else if (alphar != 1) {
13993 map(hw, Ts.real(), state.C_regs[0], state.C_regs[0], strategy,
13994 [&](int esize, GRF acc, GRF _) {
13995 alphar.fixed()
13996 ? mul(esize, acc, acc, cast(Ts.real(), alphar))
13997 : mul(esize, acc, acc,
13998 alphar.getRegAvoiding(hw, acc));
13999 });
14000 }
14001
14002 alphar = 1;
14003 }
14004
14005 // Do the actual updating.
14006 if (!gemmAccessC(COperation::UpdateStore, problem, strategy, state))
14007 return false;
14008
14009 // Postop cleanup.
14010 if (useEltwiseInjector(problem)) {
14011 postOpInjector.reset();
14012 state.ra.safeRelease(postOpScratch);
14013 }
14014
14015 // Free C data and layout.
14016 safeReleaseRanges(state.C_regs, state);
14017 state.C_layout.clear();
14018 state.C_layoutExt.clear();
14019
14020 state.raVFlag.safeRelease(state.flagSwizzle);
14021
14022 // Success!
14023 return true;
14024}
14025
14026// Load from, update, and/or store to C, with complete remainder handling.
14027// If op == COperation::Load, only load C.
14028// If op == COperation::Update, load and update C.
// If op == COperation::UpdateStore, perform the full C update with alpha/beta scaling. Unless state.isNested == true,
//   this is assumed to be the conclusion of the kernel.
14031template <HW hw>
14032bool gemm_kernel_generator_t<hw>::gemmAccessC(COperation op,
14033 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
14034 Label labelStdCRemainder, labelAltCRemainder, labelBlock2DCRemainder,
14035 labelCRemDone, labelSkip;
14036
14037 int C_count = (op == COperation::UpdateStore) ? state.C_count : 1;
14038 bool remainderM
14039 = (strategy.remHandling[LoopM] != RemainderHandling::Ignore);
14040 bool remainderN
14041 = (strategy.remHandling[LoopN] != RemainderHandling::Ignore);
14042 bool remM_C, remN_C;
14043 getCRemainders(problem, strategy, remM_C, remN_C);
14044 bool altCRemainder = strategy.altCRemainder && !strategy.C.padded
14045 && (remainderM || remainderN || problem.gemmt());
14046 bool block2DCRemainder = strategy.block2DCRemainder && !strategy.C.padded
14047 && (remainderM || remainderN);
14048 bool stdCRemainder = !(altCRemainder
14049 && (strategy.remHandling[LoopM]
14050 == RemainderHandling::KnownRemainder)
14051 && (strategy.remHandling[LoopN]
14052 == RemainderHandling::KnownRemainder));
14053
14054 if ((op != COperation::UpdateStore) && strategy.C.atomic) stub();
14055
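    // If permitted, skip this C operation entirely when the block is empty
    //  (either remainder is zero).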
14056 if (state.allowEmptyC && (remainderM || remainderN)) {
14057 if (!state.isNested) stub();
14058 int simt = strategy.fused ? 16 : 1;
14059 cmp(simt | le | f0[0], null.ud(), state.remainders[LoopM], 0);
14060 cmp(simt | le | f1[0], null.ud(), state.remainders[LoopN], 0);
14061 strategy.fused ? goto12(16 | f0[0] | anyv, labelSkip)
14062 : ejmpi(1 | f0[0] | anyv, labelSkip);
14063 }
14064
14065 bool splitUpdateStore = (problem.cOffset == COffset::Post);
14066
14067 // New post-op path: do all post-ops up to sum, if any.
14068 int poSum = 0;
14069 bool newPostOps = !useEltwiseInjector(problem);
14070 if (op == COperation::UpdateStore && newPostOps) {
14071 for (poSum = 0; poSum < problem.postOps.len(); poSum++)
14072 if (problem.postOps.entry_[poSum].kind == primitive_kind::sum)
14073 break;
14074 gemmApplyPostOps(0, poSum, problem, strategy, state);
14075 splitUpdateStore |= (poSum + 1 < problem.postOps.len());
14076 }
14077
14078 if (op == COperation::UpdateStore && splitUpdateStore) {
        // The C post-offset is implemented by splitting the update and store steps.
14080 bool ok = true;
14081 bool oldAllowEmptyC = state.allowEmptyC;
14082 state.allowEmptyC = false;
14083
14084 if (!(problem.alpha1() && problem.beta0()))
14085 ok = ok
14086 && gemmAccessC(
14087 COperation::Update, problem, strategy, state);
14088
14089 auto storeProblem = problem;
14090 storeProblem.cOffset = COffset::None;
14091 storeProblem.alpha_real = 1;
14092 storeProblem.alpha_imag = 0;
14093 storeProblem.beta_real = 0;
14094 storeProblem.beta_imag = 0;
14095
14096 // Do any post-sum post-ops
14097 if (newPostOps)
14098 gemmApplyPostOps(
14099 poSum + 1, problem.postOps.len(), problem, strategy, state);
14100 storeProblem.postOps = post_ops_t {};
14101
14102 if (problem.cOffset == COffset::Post) {
14103 gemmConvertC(problem.Tc, problem, strategy, state);
14104 ok = ok && gemmApplyCOffsetDispatch(problem, strategy, state);
14105 }
14106
14107 ok = ok
14108 && gemmAccessC(
14109 COperation::UpdateStore, storeProblem, strategy, state);
14110
14111 state.allowEmptyC = oldAllowEmptyC;
14112 if (ok && state.allowEmptyC && (remainderM || remainderN)) {
14113 mark(labelSkip);
14114 if (strategy.fused) join(16);
14115 }
14116 return ok;
14117 }
14118
14119 auto leave = [&] {
14120 if (state.isNested || (op != COperation::UpdateStore))
14121 jmpi(1, labelCRemDone);
14122 else
14123 epilogue(strategy, state);
14124 };
14125
14126 if (stdCRemainder) {
        // Check whether we should jump to the alternate C remainder handling path, when enabled:
        //  - if this is a remainder kernel
14129 // - for triangular updates, if the diagonal crosses this block.
14130 // When fusing, check diagonal for thread 0 for (fused in n) upper/m lower, thread 1 for n lower/m upper.
14131 if (altCRemainder || block2DCRemainder) {
14132 if (remainderM || remainderN) {
14133 cmp(1 | lt | f0[0], null.ud(), state.remaindersFused[LoopM],
14134 strategy.unroll[LoopM]);
14135 cmp(1 | lt | f1[0], null.ud(), state.remaindersFused[LoopN],
14136 strategy.unroll[LoopN]);
14137 }
14138
14139 auto &remLabel = block2DCRemainder ? labelBlock2DCRemainder
14140 : labelAltCRemainder;
14141 if (remainderM || remainderN) ejmpi(1 | f0[0] | anyv, remLabel);
14142 }
14143
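        // The block 2D remainder path falls back here if its alignment/width checks
        //  fail and no alternate remainder path is available.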
14144 if (block2DCRemainder && !altCRemainder) mark(labelStdCRemainder);
14145
14146 // Release the all-purpose flag temporarily to free up flag registers if it won't be needed.
14147 auto saveFlagAP = state.flagAP;
14148 if (!problem.hasPostOp())
14149 if (!strategy.fused && !strategy.noJumpTables
14150 && state.emulate.flag != state.flagAP)
14151 state.raVFlag.safeRelease(state.flagAP);
14152
14153 // Decide on the C remainder handling strategy.
14154 bool fragments[2] = {false, false};
14155 bool fragPositives[2] = {true, true};
14156 int fragSizes[2] = {1 << 16, 1 << 16};
14157
14158 // Check for fragmenting.
14159 auto &C_layoutExt = state.C_layoutExt;
14160 auto &C_layoutExtUnmasked = state.C_layoutExtUnmasked;
14161 bool remDescs[2] = {false, false};
14162 bool remMasks[2] = {false, false};
14163
14164 // Loop over rows (rc = 0) and columns (rc = 1).
14165 for (int rc = 0; rc < 2; rc++) {
14166 if (!(rc ? remN_C : remM_C))
14167 continue; // Skip if not doing remainder handling in this dimension.
14168
14169 for (auto &l : C_layoutExt) {
14170 auto qFragment = rc ? l.colFragment : l.rowFragment;
14171 bool qZeroOK = rc ? l.noColsOK : l.noRowsOK;
14172 bool qMasked = rc ? (bool)l.colMask : (bool)l.rowMask;
14173 bool qDescRem = rc ? l.descRemC : l.descRemR;
14174
14175 if (qFragment > 0) {
14176 fragments[rc] = true;
14177 fragSizes[rc] = std::min<int>(fragSizes[rc], qFragment);
14178 if (qZeroOK) fragPositives[rc] = false;
14179
14180 if (qFragment > 1) {
14181 remDescs[rc] |= qDescRem;
14182 remMasks[rc] |= !qDescRem;
14183 }
14184 } else
14185 remMasks[rc] |= qMasked;
14186 }
14187 }
14188
        // Disable fragmentation if the fragment size is not smaller than the unroll.
14190 fragments[0] &= fragSizes[0] < strategy.unroll[LoopM];
14191 fragments[1] &= fragSizes[1] < strategy.unroll[LoopN];
14192
14193 // Sanity check the requirements.
14194 if ((remDescs[0] && remMasks[0]) || (remDescs[1] && remMasks[1])) {
14195 status << "Different remainder types mixed in C layout."
14196 << status_stream::endl;
14197 return false;
14198 }
14199 if (remMasks[0] && remMasks[1]) {
14200 status << "Both dimensions are masked (not supported)."
14201 << status_stream::endl;
14202 return false;
14203 }
14204 if (remDescs[0] && remDescs[1]) {
14205 status << "Both dimensions use descriptors (not supported)."
14206 << status_stream::endl;
14207 return false;
14208 }
14209
14210 // Set remainder handling types.
14211 StdCRemType remTypes[2] = {StdCRemType::Ignore, StdCRemType::Ignore};
14212 for (int rc = 0; rc < 2; rc++) {
14213 if (remDescs[rc])
14214 remTypes[rc] = StdCRemType::Descriptor;
14215 else if (remMasks[rc])
14216 remTypes[rc] = StdCRemType::Mask;
14217 }
14218
14219 // Decide whether to do m or n first. Criteria, in order of priority:
14220 // - Do an ignored dimension first.
14221 // - Do a fragmented dimension first.
14222 // - Do descriptors first.
14223 // - Do whichever dimension of C is strided first.
14224 bool nFirst;
14225 if (remTypes[0] == StdCRemType::Ignore
14226 || remTypes[1] == StdCRemType::Ignore)
14227 nFirst = (remTypes[1] == StdCRemType::Ignore);
14228 else if (fragments[0] != fragments[1])
14229 nFirst = fragments[1];
14230 else if (remDescs[0] || remDescs[1])
14231 nFirst = remDescs[1];
14232 else
14233 nFirst = (problem.C.layout == MatrixLayout::N);
14234
14235 // Cache ldc multiples.
14236 gemmCacheLDCMultiples(problem, strategy, state);
14237
14238 // Prepare for load/store descriptor generation.
14239 if (remDescs[0] || remDescs[1])
14240 setupTeardownLoadStoreDesc(true, C_layoutExt, strategy, state);
14241
14242 // Set up address for the beginning of C.
14243 GRFRange C_addr0[2], C_addr0Unmasked[2];
14244 setupCAddr0(C_addr0, C_addr0Unmasked, C_layoutExt, C_layoutExtUnmasked,
14245 C_count, problem, strategy, state);
14246
14247 // Try to load C masks. If that fails, fragment the masked dimension down to the size of current blocks.
14248 vector<MaskAssignment> masks;
14249 if (!assignMasks(C_layoutExt, LoopM, LoopN, masks, strategy, state)) {
14250 for (int rc = 0; rc < 2; rc++) {
14251 if (remMasks[rc]) {
14252 fragments[rc] = true;
14253 fragSizes[rc] = rc ? C_layoutExt[0].nc : C_layoutExt[0].nr;
14254 }
14255 }
14256 } else
14257 loadMasks(masks, state.remainders, strategy, state);
14258
        // Call the remainder handling routine. If it fails, try again, switching M and N.
        // If that still fails, try again with complete fragmentation if partial
        //  fragmentation was attempted the first time.
14262 bool columns[2] = {nFirst, !nFirst};
14263 bool switchedColumns[2] = {!nFirst, nFirst};
14264 do {
14265 if (doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false,
14266 columns, remTypes, fragments, fragPositives, fragSizes,
14267 C_addr0, C_addr0Unmasked, op, masks, problem, strategy,
14268 state))
14269 break;
14270 if (doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false,
14271 switchedColumns, remTypes, fragments, fragPositives,
14272 fragSizes, C_addr0, C_addr0Unmasked, op, masks, problem,
14273 strategy, state))
14274 break;
14275
14276 if ((fragments[0] && (fragSizes[0] > 1))
14277 || (fragments[1] && (fragSizes[1] > 1))) {
14278 fragSizes[0] = fragSizes[1] = 1;
14279
14280 if (doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false,
14281 columns, remTypes, fragments, fragPositives,
14282 fragSizes, C_addr0, C_addr0Unmasked, op, masks,
14283 problem, strategy, state))
14284 break;
14285 if (doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false,
14286 switchedColumns, remTypes, fragments, fragPositives,
14287 fragSizes, C_addr0, C_addr0Unmasked, op, masks,
14288 problem, strategy, state))
14289 break;
14290 }
14291 return false;
14292 } while (false);
14293
14294 // Free cached ldc multiples.
14295 for (int q = 0; q < state.C_count; q++)
14296 releaseLDMultiples(state.ldcMultiples[q], state);
14297 releaseIndexVec(state);
14298
14299 // Free address header for block 0.
14300 for (int q = 0; q < C_count; q++)
14301 state.ra.safeRelease(C_addr0[q]);
14302
14303 // Free C mask registers.
14304 safeReleaseMaskAssignments(masks, state);
14305
14306 // Clean up after load/store descriptor generation.
14307 if (remDescs[0] || remDescs[1])
14308 setupTeardownLoadStoreDesc(false, C_layoutExt, strategy, state);
14309
14310 // Restore all-purpose flag.
14311 state.flagAP = saveFlagAP;
14312 state.raVFlag.claim(saveFlagAP);
14313
14314 // Leave.
14315 if (block2DCRemainder || altCRemainder) leave();
14316 }
14317
14318 // Do block 2D C remainder handling if enabled.
14319 if (block2DCRemainder) {
14320 mark(labelBlock2DCRemainder);
14321
14322 // Check for transposition.
14323 bool memCM = isColMajor(problem.C.layout);
14324 bool regCM = isLayoutColMajor(state.C_layout);
14325 bool doTranspose = (memCM != regCM);
14326
14327 // Check if alignment requirements are met.
14328 auto Tc = problem.Tc;
14329 uint16_t align = doTranspose ? 4 : 8;
14330 for (int q = 0; q < state.C_count; q++) {
14331 bool checkAlign = (problem.C.alignment % align) != 0;
14332 bool checkWidth
14333 = (q == 0 && Tc.size() < 4 && op != COperation::Load);
14334 auto &labelNonBlock2DRem
14335 = altCRemainder ? labelAltCRemainder : labelStdCRemainder;
14336
14337 if (checkAlign) {
14338 and_(1 | nz | f0[0], null.uw(), state.effC[q].uw(), align - 1);
14339 and_(1 | nz | f1[0], null.uw(), state.inputs.ldc[q].uw(),
14340 align - 1);
14341 }
14342 if (checkWidth)
14343 and_(1 | nz | f0[1], null.uw(),
14344 state.remainders[isColMajor(problem.C.layout) ? LoopM
14345 : LoopN],
14346 (4 / Tc) - 1);
14347 if (checkAlign) ejmpi(1 | f0[0] | anyv, labelNonBlock2DRem);
14348 if (checkWidth) jmpi(1 | f0[1], labelNonBlock2DRem);
14349 }
14350
14351 // Rustle up a new layout.
14352 // Match existing C layout if possible and deemed worthwhile.
14353 vector<RegisterBlock> C_layout2D;
14354 auto modProblem = problem;
14355 auto modStrategy = strategy;
14356 auto modState = state;
14357
14358 modProblem.C.setAlignment(align);
14359 modStrategy.C = state.Cext_strategy;
14360 modStrategy.C.newDP = true;
14361 modStrategy.C.address2D = true;
14362 modStrategy.C.accessType = doTranspose ? AccessType::Block2DTranspose
14363 : AccessType::Block2D;
14364
14365 auto &C_layout = modState.C_layout;
14366 auto &C_layoutExt = modState.C_layoutExt;
14367 auto &C_layoutExtUnmasked = modState.C_layoutExtUnmasked;
14368
14369 bool inplace = !state.copyC
14370 && upgradeLayoutToBlock2D(Tc, C_layoutExt, C_layout2D, remM_C,
14371 remN_C, op != COperation::Load, modProblem.C,
14372 modStrategy.C);
14373
14374 for (auto &block : C_layout2D)
14375 inplace = inplace
14376 && state.C_regs[0].contiguous(
14377 block.offsetReg(), block.nregs());
14378
14379 if (inplace)
14380 C_layoutExt = std::move(C_layout2D);
14381 else {
14382 modState.copyC = true;
14383 if (!getRegLayout(problem.Tc_ext, C_layoutExt,
14384 strategy.unroll[LoopM], strategy.unroll[LoopN], remM_C,
14385 remN_C, true, false, 0, 0, modProblem.C, modStrategy.C))
14386 stub();
14387 }
14388
14389 C_layoutExtUnmasked.clear();
14390
14391 for (auto &block : C_layout)
14392 block.simdSize = 0; /* unlink C layout from memory */
14393
14394 // Do the update.
14395 Address2DParams params;
14396 params.rows = state.remainders[LoopM];
14397 params.cols = state.remainders[LoopN];
14398
14399 GRFRange C_addr0[2], C_addr0Unmasked[2];
14400 setupCAddr0(C_addr0, C_addr0Unmasked, C_layoutExt, C_layoutExtUnmasked,
14401 C_count, modProblem, modStrategy, modState, &params);
14402
14403 bool columns[2] = {false, true};
14404 StdCRemType remTypes[2] = {StdCRemType::Ignore, StdCRemType::Ignore};
14405 bool fragments[2] = {false, false};
14406 bool fragPositives[2] = {false, false};
14407 int fragSizes[2] = {1 << 16, 1 << 16};
14408 vector<MaskAssignment> masks;
14409
14410 if (!doStdCRemainder(C_layoutExt, C_layoutExtUnmasked, false, columns,
14411 remTypes, fragments, fragPositives, fragSizes, C_addr0,
14412 C_addr0Unmasked, op, masks, modProblem, modStrategy,
14413 modState))
14414 stub();
14415
14416 if (altCRemainder) leave();
14417 }
14418
14419 // Do alternate C remainder handling if enabled.
14420 if (altCRemainder) {
14421 mark(labelAltCRemainder);
14422 doAlternateCRemainder(op, problem, strategy, state);
14423 }
14424
14425 if (altCRemainder || block2DCRemainder) mark(labelCRemDone);
14426
14427 if (state.allowEmptyC && (remainderM || remainderN)) {
14428 mark(labelSkip);
14429 if (strategy.fused) join(16);
14430 }
14431
14432 return true; /* Successful! */
14433}
14434
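// Generate the inner body of the GEMM kernel: accumulate C, handle A/B offsets and
//   row/column sums, remask packed C if needed, reduce across the local k dimension if
//   enabled, and dispatch the appropriate C update path (regular, beta = 0, or beta = 1).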
14435template <HW hw>
14436bool gemm_kernel_generator_t<hw>::gemmBodyInternal(
14437 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
14438 auto Tc = problem.Tc;
14439
14440 auto unrollM = strategy.unroll[LoopM];
14441 auto unrollN = strategy.unroll[LoopN];
14442 auto &remM = state.remainders[LoopM];
14443 auto &remN = state.remainders[LoopN];
14444
14445 // Accumulate C with panel*panel multiply.
14446 if (!gemmAccumulateC(problem, strategy, state)) return false;
14447
14448 // Add A/B offsets.
14449 gemmLoadABOffset(problem, strategy, state);
14450 if (!gemmFinalizeSums(problem, strategy, state)) return false;
14451 gemmApplyABOffset(problem, strategy, state);
14452
14453 // If C is packed, update remainders and prepare to mask out border regions.
14454 bool remaskC_M = isPacked(problem.C.layout)
14455 && (strategy.remHandling[LoopM] != RemainderHandling::Ignore);
14456 bool remaskC_N = isPacked(problem.C.layout)
14457 && (strategy.remHandling[LoopN] != RemainderHandling::Ignore);
14458
14459 if (remaskC_M || remaskC_N) {
14460 if (remaskC_M)
14461 setupTeardownRemask(Tc, 0, true, unrollM, remM, strategy, state);
14462 if (remaskC_N)
14463 setupTeardownRemask(Tc, 1, true, unrollN, remN, strategy, state);
14464
14465 int C_mgran, C_ngran;
14466 getGranularities(problem.C, C_mgran, C_ngran);
14467 if (!remaskC_M || C_mgran == unrollM) C_mgran = 1;
14468 if (!remaskC_N || C_ngran == unrollN) C_ngran = 1;
14469 if (!is_zero_or_pow2(C_mgran)) stub();
14470 if (!is_zero_or_pow2(C_ngran)) stub();
14471
14472 if (C_mgran > 1) add(1, remM, remM, C_mgran - 1);
14473 if (C_ngran > 1) add(1, remN, remN, C_ngran - 1);
14474 if (C_mgran > 1) and_(1, remM, remM, uint32_t(~(C_mgran - 1)));
14475 if (C_ngran > 1) and_(1, remN, remN, uint32_t(~(C_ngran - 1)));
14476 }
14477
14478 // Local k reduction.
14479 if (strategy.kParallelLocal) gemmKReduce(problem, strategy, state);
14480
14481 // Late exit.
14482 bool lateExit = state.doLateExit;
14483 Label labelLateExit;
14484
14485 if (lateExit) {
14486 int simt = strategy.fused ? 16 : 1;
14487
14488 cmp(simt | le | f0[0], state.remainders[LoopM], uint16_t(0));
14489 cmp(simt | le | f1[0], state.remainders[LoopN], uint16_t(0));
14490
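        // Take the late exit if either remainder is zero (no work left for this
        // thread); the anyv modifier ORs the f0.0 and f1.0 conditions.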
14491 InstructionModifier cond = simt | f0[0] | anyv;
14492
14493 strategy.fused ? goto12(cond, labelLateExit)
14494 : ejmpi(cond, labelLateExit);
14495 }
14496
14497 gemmUpdateSums(problem, strategy, state);
14498
    // Update C. If configured, choose at runtime between the regular beta update and the special beta = 0 / beta = 1 paths.
14500 bool checkBeta0 = problem.checkBeta0 && !problem.beta_real.fixed();
14501 bool checkBeta1 = strategy.checkBeta1 && !problem.beta_real.fixed();
14502 bool checkTRMMBeta1 = state.beta1.isValid();
14503
14504 if (checkTRMMBeta1 && (checkBeta0 || checkBeta1)) stub();
14505
14506 if (!checkBeta0 && !checkBeta1 && !checkTRMMBeta1) {
14507 if (!gemmUpdateC(problem, strategy, state)) return false;
14508 } else {
14509 Label labelBeta0, labelBeta1, labelBetaDone;
14510 InstructionModifier mod0 = 1 | f0[0];
14511 InstructionModifier mod1 = 1 | f0[1];
14512 bool simtCF1 = false;
14513
14514 if (checkBeta0) { cmp0(1 | eq | f0[0], problem.beta_real.getReg(0)); }
14515
14516 if (checkBeta1) {
14517 cmp(1 | eq | f0[1], problem.beta_real.getReg(0),
14518 cast(problem.Ts, 1.0));
14519 }
14520
14521 if (checkBeta0) jmpi(mod0, labelBeta0);
14522
14523 if (checkBeta1 || checkTRMMBeta1) {
14524 simtCF1 ? if_(mod1, labelBeta1, labelBetaDone)
14525 : jmpi(mod1, labelBeta1);
14526 }
14527
14528 // Regular update.
14529 {
14530 auto subproblem = problem;
14531 auto substrategy = strategy;
14532 auto substate = state;
14533
14534 if (strategy.C.atomic && !strategy.C.base.isStateless()
14535 && !strategy.C.newDP)
14536 stub(); /* need to shift addresses */
14537 substrategy.C.atomic = false;
14538
14539 if (!gemmUpdateC(subproblem, substrategy, substate)) return false;
14540 }
14541
14542 simtCF1 ? else_(16, labelBetaDone)
14543 : state.isNested ? jmpi(1, labelBetaDone)
14544 : epilogue(strategy, state);
14545
14546 // beta = 1 update.
14547 if (checkBeta1 || checkTRMMBeta1) {
14548 status << "Special path: beta = 1" << status_stream::endl;
14549 mark(labelBeta1);
14550
14551 auto subproblem = problem;
14552 auto substate = state;
14553
14554 subproblem.beta_real = 1;
14555 subproblem.beta_imag = 0;
14556
14557 if (!gemmUpdateC(subproblem, strategy, substate)) return false;
14558
14559 if (checkBeta0) {
14560 (simtCF1 || state.isNested) ? jmpi(1, labelBetaDone)
14561 : epilogue(strategy, state);
14562 }
14563 }
14564
14565 // beta = 0 update.
14566 if (checkBeta0) {
14567 status << "Special path: beta = 0" << status_stream::endl;
14568 mark(labelBeta0);
14569
14570 auto subproblem = problem;
14571 auto substrategy = strategy;
14572 auto substate = state;
14573
14574 subproblem.beta_real = 0;
14575 subproblem.beta_imag = 0;
14576
14577 substrategy.C.atomic = false;
14578
14579 if (!gemmUpdateC(subproblem, substrategy, substate)) return false;
14580 }
14581
14582 mark(labelBetaDone);
14583 if (simtCF1) endif(16);
14584 }
14585
14586 // Cleanup.
14587 if (remaskC_M)
14588 setupTeardownRemask(Tc, 0, false, unrollM, remM, strategy, state);
14589 if (remaskC_N)
14590 setupTeardownRemask(Tc, 1, false, unrollN, remN, strategy, state);
14591
14592 if (lateExit) {
14593 mark(labelLateExit);
14594 if (strategy.fused) join(16);
14595 }
14596
14597 return true;
14598}
14599
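// Determine the effective cooperative split for loading A across a workgroup:
// packed layouts are split linearly; row-major register layouts whose m unroll
// divides evenly by the n workgroup dimension (and are not block 2D) are split
// along m/n; otherwise the strategy's requested split is used.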
14600template <HW hw>
14601CoopSplit gemm_kernel_generator_t<hw>::effCoopSplitA(
14602 const GEMMProblem &problem, const GEMMStrategy &strategy) {
14603 if (isPacked(problem.A.layout))
14604 return CoopSplit::Linear;
14605 else if (!isRegisterColMajor(problem.Ta_ext, problem.A, strategy.A)
14606 && (strategy.unroll[LoopM] % strategy.wg[LoopN] == 0)
14607 && !isBlock2D(strategy.A.accessType))
14608 return CoopSplit::MN;
14609 else
14610 return strategy.coopA;
14611}
14612
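// Determine the effective cooperative split for loading B (mirror of the A case).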
14613template <HW hw>
14614CoopSplit gemm_kernel_generator_t<hw>::effCoopSplitB(
14615 const GEMMProblem &problem, const GEMMStrategy &strategy) {
14616 if (isPacked(problem.B.layout))
14617 return CoopSplit::Linear;
14618 else if (isRegisterColMajor(problem.Tb_ext, problem.B, strategy.B)
14619 && (strategy.unroll[LoopN] % strategy.wg[LoopM] == 0)
14620 && !isBlock2D(strategy.B.accessType))
14621 return CoopSplit::MN;
14622 else
14623 return strategy.coopB;
14624}
14625
14626// Check whether all threads in a thread group should stay together in m/n remainder handling.
14627template <HW hw>
14628bool gemm_kernel_generator_t<hw>::wgRemCheck(
14629 const GEMMProblem &problem, const GEMMStrategy &strategy) {
14630 return (strategy.slmA && (effCoopSplitA(problem, strategy) != CoopSplit::K)
14631 && (strategy.remHandling[LoopM] != RemainderHandling::Ignore)
14632 && !strategy.A.padded)
14633 || (strategy.slmB
14634 && (effCoopSplitB(problem, strategy) != CoopSplit::K)
14635 && (strategy.remHandling[LoopN]
14636 != RemainderHandling::Ignore)
14637 && !strategy.B.padded)
14638 || strategy.kParallelLocal
14639 || ((strategy.barrierFreq > 0 || strategy.cooperativePF)
14640 && (strategy.prefetchA || strategy.prefetchB
14641 || strategy.prefetchC));
14642}
14643
14644// Do outer-level m/n remainder handling.
14645template <HW hw>
14646template <typename Problem>
14647bool gemm_kernel_generator_t<hw>::mnRemainderHandling(LoopType loop,
14648 Problem &problem, GEMMStrategy &strategy, GEMMState &state,
14649 bool (gemm_kernel_generator_t<hw>::*func)(
14650 Problem, GEMMStrategy, GEMMState)) {
14651 auto method = strategy.remHandling[loop];
14652 auto &unroll = strategy.unroll[loop];
14653 auto mn = (loop == LoopM) ? state.inputs.m : state.inputs.n;
14654 auto splitThresh
14655 = (loop == LoopM) ? strategy.mSplitThresh : strategy.nSplitThresh;
14656
14657 Label label_done;
14658
14659 auto originalCheckAdd32 = strategy.checkAdd32;
14660
14661 if (method == RemainderHandling::Split) {
14662 Label label_remainder;
14663
14664 // Jump to remainder loop if needed.
        // If threads are fused in this direction, factor the fused ID into the calculation.
14666 if (wgRemCheck(problem, strategy))
14667 cmp(1 | lt | f0[0], null.d(), state.remaindersWG[loop],
14668 uint16_t(unroll * strategy.wg[loop]));
14669 else
14670 cmp(1 | lt | f0[0], null.d(), state.remaindersFused[loop],
14671 uint16_t(unroll));
14672
14673 if (splitThresh) {
14674 cmp(1 | lt | f1[0], null.d(), mn, int32_t(splitThresh));
14675 ejmpi(1 | f0[0] | anyv, label_remainder);
14676 } else
14677 jmpi(1 | f0[0], label_remainder);
14678
14679 // First generate code that ignores remainder handling.
14680 GEMMStrategy substrategy = strategy;
14681 substrategy.remHandling[loop] = RemainderHandling::Ignore;
14682
14683 status << "Generating "
14684 << "MNK"[static_cast<int>(loop)]
14685 << " non-remainder kernel for unroll " << unroll << '.'
14686 << status_stream::endl;
14687 if (!(this->*func)(problem, substrategy, state)) {
14688 status << "Non-remainder kernel failed, aborting."
14689 << status_stream::endl;
14690 return false;
14691 }
14692
14693 // Return, unless this is part of a larger computation, in which case jump to end.
14694 if (state.isNested)
14695 jmpi(1, label_done);
14696 else
14697 epilogue(strategy, state);
14698
14699 mark(label_remainder);
14700
14701 strategy.checkAdd32 = false;
14702 }
14703
    // Now attempt to generate the remainder-handling code.
14705 status << "Attempting to generate "
14706 << "MNK"[static_cast<int>(loop)] << " general kernel for unroll "
14707 << unroll << '.' << status_stream::endl;
14708 bool success = (this->*func)(problem, strategy, state);
14709
14710 strategy.checkAdd32 = originalCheckAdd32;
14711 if (success) {
14712 mark(label_done);
14713 return true;
14714 }
14715
14716#ifndef ALLOW_REMAINDERS
14717 // Disable remainder code for now.
14718 return false;
14719#else
14720 auto &bound = (loop == LoopN) ? state.inputs.n : state.inputs.m;
14721 auto &index = (loop == LoopN) ? state.j0 : state.i0;
14722 auto &remainders = state.remainders[loop];
14723
14724 if (method == RemainderHandling::Ignore)
14725 throw std::runtime_error("Could not generate kernel.");
14726
14727 // It failed, so break up the loop into the next smaller power of 2 along this dimension,
14728 // plus the remainder (recursively).
14729 Label label_next_rem;
14730
14731 if (unroll == 1) {
14732 // No more splitting to do.
14733 // We don't know if this was originally split, so just output a warning.
14734 status << "NOTE: Split remainder handling is required for loop "
14735 << "MNK"[static_cast<int>(loop)] << '.' << status_stream::endl;
14736 return true;
14737 }
14738 int chunkSize = rounddown_pow2(unroll - 1);
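    // e.g. unroll = 48: chunkSize = 32, and the remaining 16 rows/columns are
    // handled by the recursive call below with unroll[loop] = 16.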
14739
14740 // Jump to next remainder loop if needed.
14741 pushStream();
14742 {
14743 cmp(1 | lt | state.flagAP, null.d(), remainders, chunkSize);
14744 jmpi(1 | state.flagAP, label_next_rem);
14745
14746 {
14747 GEMMStrategy substrategy = strategy;
14748 GEMMState substate = state;
14749 substrategy.remHandling[loop] = RemainderHandling::Ignore;
14750 substrategy.unroll[loop] = chunkSize;
14751 substate.isNested = true;
14752 status << "Generating "
14753 << "MNK"[static_cast<int>(loop)]
14754 << " remainder kernel with unroll " << chunkSize << '.'
14755 << status_stream::endl;
14756 if (!(this->*func)(problem, substrategy, substate)) {
14757 discardStream();
14758 return false;
14759 }
14760 }
14761
14762 // Adjust remainder.
14763 add(1, remainders, remainders, -chunkSize);
14764
14765 // Adjust pointers as needed.
14766 // A += i0 (N) i0 * lda (T, Pc)
14767 // B += j0 * ldb (N, Pr) j0 (T)
14768 // C += i0 + j0 * ldc (N, Pr) j0 + i0 * ldc (T, Pc)
14769 switch (loop) {
14770 case LoopM:
14771 if (problem.A.layout == MatrixLayout::N)
14772 eadd(1, state.effA, state.effA, chunkSize * Ta, strategy,
14773 state);
14774 else {
14775 Subregister temp = state.ra.alloc_sub<uint32_t>();
14776 mulConstant(1, temp, state.inputs.lda, chunkSize);
14777 eadd(1, state.effA, state.effA, temp, strategy, state);
14778 state.ra.safeRelease(temp);
14779 }
                if (problem.C.layout == MatrixLayout::N
                        || problem.C.layout == MatrixLayout::Pr)
                    eadd(1, state.effC, state.effC, chunkSize * Tc, strategy,
                            state);
                else {
                    Subregister temp = state.ra.alloc_sub<uint32_t>();
                    mulConstant(1, temp, state.inputs.ldc[0], chunkSize);
                    eadd(1, state.effC, state.effC, temp, strategy, state);
                    state.ra.safeRelease(temp);
                }
14790 break;
14791 case LoopN:
14792 if (problem.B.layout == MatrixLayout::T)
14793 eadd(1, state.effB, state.effB, chunkSize * Tb, strategy,
14794 state);
14795 else {
14796 Subregister temp = state.ra.alloc_sub<uint32_t>();
14797 mulConstant(1, temp, state.inputs.ldb, chunkSize);
14798 eadd(1, state.effB, state.effB, temp, strategy, state);
14799 state.ra.safeRelease(temp);
14800 }
                if (problem.C.layout == MatrixLayout::T
                        || problem.C.layout == MatrixLayout::Pc)
                    eadd(1, state.effC, state.effC, chunkSize * Tc, strategy,
                            state);
                else {
                    Subregister temp = state.ra.alloc_sub<uint32_t>();
                    mulConstant(1, temp, state.inputs.ldc[0], chunkSize);
                    eadd(1, state.effC, state.effC, temp, strategy, state);
                    state.ra.safeRelease(temp);
                }
14811 break;
14812 }
14813
14814 mark(label_next_rem);
14815
14816 // Handle the remainder recursively.
14817 {
14818 GEMMStrategy substrategy = strategy;
14819 substrategy.remHandling[loop] = RemainderHandling::General;
14820 substrategy.unroll[loop] -= chunkSize;
14821 if (!mnRemainderHandling(loop, problem, substrategy, state, func)) {
14822 discardStream();
14823 return false;
14824 }
14825 }
14826 } /* end stream */
14827
14828 appendCurrentStream();
14829
14830 return true; /* success */
14831#endif
14832}
14833
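// Handle m and n remainders together with a single split: emit one kernel with
// remainder handling disabled for full tiles, and one general kernel covering
// both m and n remainders, selecting between them at runtime.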
14834template <HW hw>
14835template <typename Problem>
14836bool gemm_kernel_generator_t<hw>::mnJointSplitRemainderHandling(
14837 Problem &problem, GEMMStrategy &strategy, GEMMState &state,
14838 bool (gemm_kernel_generator_t<hw>::*func)(
14839 Problem, GEMMStrategy, GEMMState)) {
14840 Label label_done, label_remainder;
14841 bool success = false;
14842
14843 auto unrollM = strategy.unroll[LoopM];
14844 auto unrollN = strategy.unroll[LoopN];
14845
14846 pushStream();
14847 do {
        // Jump to the remainder kernel if needed:
        //  - if m/n are below the split thresholds (when enabled), or
        //  - if this workgroup/thread covers an m/n remainder region.
14851 bool wgCheck = wgRemCheck(problem, strategy);
14852
14853 if (strategy.mSplitThresh && strategy.nSplitThresh) {
14854 cmp(1 | lt | f0[0], null.d(), state.inputs.m,
14855 int32_t(strategy.mSplitThresh));
14856 cmp(1 | lt | f1[0], null.d(), state.inputs.n,
14857 int32_t(strategy.nSplitThresh));
14858 ejmpi(1 | f0[0] | anyv, label_remainder);
14859 } else if (strategy.mSplitThresh) {
14860 cmp(1 | lt | f0[0], null.d(), state.inputs.m,
14861 int32_t(strategy.mSplitThresh));
14862 jmpi(1 | f0[0], label_remainder);
14863 } else if (strategy.nSplitThresh) {
14864 cmp(1 | lt | f0[0], null.d(), state.inputs.n,
14865 int32_t(strategy.nSplitThresh));
14866 jmpi(1 | f0[0], label_remainder);
14867 }
14868 if (wgCheck) {
14869 cmp(1 | lt | f0[0], null.d(), state.remaindersWG[LoopM],
14870 uint16_t(unrollM * strategy.wg[LoopM]));
14871 cmp(1 | lt | f1[0], null.d(), state.remaindersWG[LoopN],
14872 uint16_t(unrollN * strategy.wg[LoopN]));
14873 } else {
14874 cmp(1 | lt | f0[0], null.d(), state.remaindersFused[LoopM],
14875 uint16_t(unrollM));
14876 cmp(1 | lt | f1[0], null.d(), state.remaindersFused[LoopN],
14877 uint16_t(unrollN));
14878 }
14879 ejmpi(1 | f0[0] | anyv, label_remainder);
14880
14881 // First generate code that ignores remainder handling.
14882 GEMMStrategy substrategy = strategy;
14883 substrategy.remHandling[LoopM] = RemainderHandling::Ignore;
14884 substrategy.remHandling[LoopN] = RemainderHandling::Ignore;
14885
14886 status << "Generating MN non-remainder kernel." << status_stream::endl;
14887 if (!(this->*func)(problem, substrategy, state)) {
14888 status << "Non-remainder kernel failed, aborting."
14889 << status_stream::endl;
14890 break;
14891 }
14892
14893 // Return, unless this is part of a larger computation, in which case jump to end.
14894 if (state.isNested)
14895 jmpi(1, label_done);
14896 else
14897 epilogue(strategy, state);
14898
14899 mark(label_remainder);
14900
14901 // Finally, generate remainder handling kernel.
14902 substrategy = strategy;
14903 substrategy.remHandling[LoopM] = substrategy.remHandling[LoopN]
14904 = (wgCheck ? RemainderHandling::General
14905 : RemainderHandling::KnownRemainder);
14906 substrategy.checkAdd32 = false;
14907 status << "Generating MN general kernel." << status_stream::endl;
14908 success = (this->*func)(problem, substrategy, state);
14909
14910 mark(label_done);
14911 } while (false);
14912
14913 success ? appendCurrentStream() : discardStream();
14914
14915 return success;
14916}
14917
14918// Handle outer-level m edge cases.
14919template <HW hw>
14920bool gemm_kernel_generator_t<hw>::gemmMEdge(
14921 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
14922 if (strategy.jointSplit
14923 && strategy.remHandling[LoopM] == RemainderHandling::Split
14924 && strategy.remHandling[LoopN] == RemainderHandling::Split)
14925 return mnJointSplitRemainderHandling(problem, strategy, state,
14926 &gemm_kernel_generator_t<hw>::gemmBody);
14927 else
14928 return mnRemainderHandling(LoopM, problem, strategy, state,
14929 &gemm_kernel_generator_t<hw>::gemmNEdge);
14930}
14931
14932// Handle outer-level n edge cases.
14933template <HW hw>
14934bool gemm_kernel_generator_t<hw>::gemmNEdge(
14935 GEMMProblem problem, GEMMStrategy strategy, GEMMState state) {
14936 return mnRemainderHandling(LoopN, problem, strategy, state,
14937 &gemm_kernel_generator_t<hw>::gemmBody);
14938}
14939
14940// Initialize the interface.
14941template <HW hw>
14942void gemm_kernel_generator_t<hw>::gemmInitInterface(GEMMProblem &problem,
14943 GEMMStrategy &strategy, GEMMState &state, bool inSK) {
14944 Subregister localSize[3];
14945 GRF localID[3];
14946 Subregister tgids[3]
14947 = {r0.ud(1), r0.ud(6), r0.ud(7)}; // X, Y, Z threadgroup IDs
14948
14949 if (strategy.systolic) interface.requireDPAS();
14950 if (strategy.C.atomic) interface.requireGlobalAtomics();
14951 if (strategy.barrierFreq > 0) interface.requireBarrier();
14952
14953 auto slmSize = gemmSLMSize(problem, strategy);
14954 auto slmPerK = gemmPerKSLMSize(problem, strategy);
14955 if (slmSize > 0 || slmPerK > 0) {
14956 status << "SLM usage: " << slmSize / 1024. << 'k';
14957 if (slmPerK) status << " (" << slmPerK / 1024. << "k per-k)";
14958 status << status_stream::endl;
14959 if (!slmPerK) interface.requireSLM(slmSize);
14960 interface.requireBarrier();
14961 }
14962
14963 if (strategy.fixedWG(problem)) {
14964 auto wgX = strategy.wg[strategy.loopOrder[0]];
14965 auto wgY = strategy.wg[strategy.loopOrder[1]];
14966 auto wgZ = strategy.wg[strategy.loopOrder[2]];
14967 if (strategy.splitCopy) wgY *= 2;
14968 if (wgZ <= 1)
14969 interface.requireWorkgroup(strategy.subgroupSize * wgX, wgY, wgZ);
14970 }
14971 interface.requireWalkOrder(0, 1, 2);
14972
14973 bool needStatelessWrites = strategy.C.base.isStateless();
14974 if (problem.sumA || problem.sumB)
14975 needStatelessWrites |= strategy.CO.base.isStateless();
14976
14977 interface.requireStatelessWrites(needStatelessWrites);
14978
14979 int nb = int(strategy.slmA || strategy.barrierFreq)
14980 * strategy.namedBarriers[LoopM]
14981 + int(strategy.slmB || strategy.barrierFreq)
14982 * strategy.namedBarriers[LoopN];
14983 if (nb) interface.requireBarriers(nb);
14984
14985 interface.finalize();
14986
14987 for (int dim = 0; dim < 3; dim++) {
14988 localID[dim] = interface.getLocalID(dim);
14989 localSize[dim] = interface.getLocalSize(dim);
14990 }
14991
14992 // Get input arguments.
14993 state.inputs.base = interface.getArgumentIfExists("base");
14994 auto baseSurface = interface.getArgumentSurfaceIfExists("base");
14995 if (state.inputs.base.isInvalid()
14996 && baseSurface == InterfaceHandler::noSurface) {
14997 state.inputs.A = interface.getArgumentIfExists("A");
14998 state.inputs.B = interface.getArgumentIfExists("B");
14999 state.inputs.C[1] = interface.getArgumentIfExists("P");
15000 state.inputs.surfaceA = interface.getArgumentSurfaceIfExists("A");
15001 state.inputs.surfaceB = interface.getArgumentSurfaceIfExists("B");
15002 state.inputs.surfaceC[1] = interface.getArgumentSurfaceIfExists("P");
15003 } else {
15004 state.inputs.A = state.inputs.B = state.inputs.base;
15005 state.inputs.surfaceA = state.inputs.surfaceB = baseSurface;
15006 if (interface.getArgumentIfExists("offset_P").isValid()) {
15007 state.inputs.C[1] = state.inputs.base;
15008 state.inputs.surfaceC[1] = state.inputs.surfaceA;
15009 }
15010 }
15011
15012 state.inputs.C[0] = interface.getArgumentIfExists("C");
15013 state.inputs.surfaceC[0] = interface.getArgumentSurfaceIfExists("C");
15014 state.C_count = state.inputs.C[1].isValid() ? 2 : 1;
15015 if (problem.usesCO()) {
15016 state.inputs.CO = interface.getArgumentIfExists("CO");
15017 state.inputs.surfaceCO = interface.getArgumentSurfaceIfExists("CO");
15018 }
15019
15020 if (problem.abOffset != ABOffset::None) {
15021 state.inputs.abo = interface.getArgumentIfExists("abo");
15022 if (state.inputs.abo.isValid()) {
            // A/B offsets are two words packed into a single dword argument.
15024 state.inputs.ao = state.inputs.abo.w(0);
15025 state.inputs.bo = state.inputs.abo.w(1);
15026 } else {
15027 state.inputs.ao = interface.getArgumentIfExists("ao");
15028 state.inputs.bo = interface.getArgumentIfExists("bo");
15029 }
15030 state.inputs.aoPtr = interface.getArgumentIfExists("ao_ptr");
15031 state.inputs.boPtr = interface.getArgumentIfExists("bo_ptr");
15032 }
15033 state.inputs.offsetA = interface.getArgumentIfExists("offset_A");
15034 state.inputs.offsetB = interface.getArgumentIfExists("offset_B");
15035 state.inputs.offsetC[0] = interface.getArgumentIfExists("offset_C");
15036 state.inputs.offsetC[1] = interface.getArgumentIfExists("offset_P");
15037 state.inputs.offsetCO = interface.getArgumentIfExists("offset_CO");
15038 if (problem.batch == BatchMode::Strided) {
15039 state.inputs.strideA[0] = interface.getArgumentIfExists("stride_A");
15040 state.inputs.strideB[0] = interface.getArgumentIfExists("stride_B");
15041 state.inputs.strideC[0] = interface.getArgumentIfExists("stride_C");
15042 if (problem.batchDims > 1) {
15043 state.inputs.strideA[1]
15044 = interface.getArgumentIfExists("stride_A1");
15045 state.inputs.strideB[1]
15046 = interface.getArgumentIfExists("stride_B1");
15047 state.inputs.strideC[1]
15048 = interface.getArgumentIfExists("stride_C1");
15049 state.inputs.batchSize1
15050 = interface.getArgumentIfExists("batch_size1");
15051 state.inputs.recipBatchSize1
15052 = interface.getArgumentIfExists("recip_batch_size1");
15053 }
15054 } else if (problem.batch == BatchMode::Nonstrided)
15055 state.inputs.offsetBatch
15056 = interface.getArgumentIfExists("offset_batch");
15057 else if (problem.batch == BatchMode::Variable) {
15058 state.inputs.incr_a_array
15059 = interface.getArgumentIfExists("incr_a_array");
15060 state.inputs.incr_b_array
15061 = interface.getArgumentIfExists("incr_b_array");
15062 }
15063 state.inputs.lda = interface.getArgumentIfExists("lda");
15064 state.inputs.ldb = interface.getArgumentIfExists("ldb");
15065 state.inputs.ldc[0] = interface.getArgumentIfExists("ldc");
15066 state.inputs.ldc[1] = interface.getArgumentIfExists("ldp");
15067 state.inputs.ldco = interface.getArgumentIfExists("ldco");
15068 state.inputs.m = interface.getArgumentIfExists("m");
15069 state.inputs.n = interface.getArgumentIfExists("n");
15070 state.inputs.k = interface.getArgumentIfExists("k");
15071 state.inputs.k0 = interface.getArgumentIfExists("k0");
15072 state.inputs.alpha_real = interface.getArgumentIfExists("alpha_real");
15073 state.inputs.alpha_imag = interface.getArgumentIfExists("alpha_imag");
15074 state.inputs.beta_real = interface.getArgumentIfExists("beta_real");
15075 state.inputs.beta_imag = interface.getArgumentIfExists("beta_imag");
15076 if (problem.batch == BatchMode::Variable) {
15077 state.inputs.alpha_array = interface.getArgumentIfExists("alpha_array");
15078 state.inputs.beta_array = interface.getArgumentIfExists("beta_array");
15079 state.inputs.incr_alpha = interface.getArgumentIfExists("incr_alpha");
15080 state.inputs.incr_beta = interface.getArgumentIfExists("incr_beta");
15081 }
15082 state.inputs.diagA = interface.getArgumentIfExists("diag_A");
15083 state.inputs.diagB = interface.getArgumentIfExists("diag_B");
15084 state.inputs.diagC = interface.getArgumentIfExists("diag_C");
15085 state.inputs.flags = interface.getArgumentIfExists("flags");
15086
15087 if (strategy.linearOrder()) {
15088 state.inputs.groupCountM = interface.getArgument("group_count_m");
15089 state.inputs.groupCountN = interface.getArgument("group_count_n");
15090 }
15091 if (strategy.hilbertOrder) {
15092 state.inputs.hilbertVD = interface.getArgumentIfExists("hilbert_vd");
15093 state.inputs.hilbertUVDRecip
15094 = interface.getArgumentIfExists("hilbert_uvd_recip");
15095 state.inputs.hilbertBail
15096 = interface.getArgumentIfExists("hilbert_bail");
15097 } else if (strategy.boustrophedon) {
15098 state.inputs.bslice = interface.getArgument("bslice");
15099 state.inputs.bthresh = interface.getArgument("bthresh");
15100 }
15101 if (strategy.persistent) {
15102 state.inputs.groupCountMN
15103 = interface.getArgumentIfExists("group_count");
15104 state.inputs.groupStride = interface.getArgument("group_stride");
15105 }
15106
15107 int po_count = problem.postOps.len();
15108 state.inputs.binarySrcs.resize(po_count);
15109 state.inputs.binaryOffsets.resize(po_count);
15110 state.inputs.binaryLDs.resize(po_count);
15111 state.inputs.binaryStrides.resize(po_count);
15112 state.inputs.binarySurfaces.resize(po_count);
15113 for (int i = 0; i < po_count; i++) {
15114 std::string srcName = "binary" + std::to_string(i);
15115 state.inputs.binarySrcs[i] = interface.getArgumentIfExists(srcName);
15116 state.inputs.binarySurfaces[i]
15117 = interface.getArgumentSurfaceIfExists(srcName);
15118 state.inputs.binaryOffsets[i]
15119 = interface.getArgumentIfExists("offset_" + srcName);
15120 state.inputs.binaryLDs[i]
15121 = interface.getArgumentIfExists("ld" + srcName);
15122 if (problem.batch == BatchMode::Strided) {
15123 state.inputs.binaryStrides[i][0]
15124 = interface.getArgumentIfExists("stride_" + srcName);
15125 if (problem.batchDims > 1)
15126 state.inputs.binaryStrides[i][1]
15127 = interface.getArgumentIfExists("stride1_" + srcName);
15128 }
15129 }
15130
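    // Remap hardware walk-order (X/Y/Z) IDs and sizes onto logical M/N/K
    // dimensions according to the strategy's loop order.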
15131 Subregister tgids_reordered[3];
15132 GRF lids_reordered[3];
15133 Subregister lszs_reordered[3];
15134
15135 for (int l = 0; l < 3; l++) {
15136 int i = static_cast<int>(strategy.loopOrder[l]);
15137 tgids_reordered[i] = tgids[l];
15138 lids_reordered[i] = localID[l];
15139 lszs_reordered[i] = localSize[l];
15140 }
15141 state.inputs.groupIDM = tgids_reordered[0];
15142 state.inputs.groupIDN = tgids_reordered[1];
15143 state.inputs.groupIDK = tgids_reordered[2];
15144 state.inputs.localIDM = lids_reordered[0];
15145 state.inputs.localIDN = lids_reordered[1];
15146 state.inputs.localIDK = lids_reordered[2];
15147 state.inputs.localSizeM = lszs_reordered[0];
15148 state.inputs.localSizeN = lszs_reordered[1];
15149 state.inputs.localSizeK = lszs_reordered[2];
15150
15151 if (strategy.linearOrder()) {
15152 state.inputs.groupIDMN = tgids[0];
15153 state.inputs.groupIDM = invalid;
15154 state.inputs.groupIDN = invalid;
15155 }
15156
15157 // Downgrade offsets to 32 bits for non-A64 accesses.
15158 if (strategy.A.base.getModel() != ModelA64)
15159 state.inputs.offsetA = state.inputs.offsetA.d();
15160 if (strategy.B.base.getModel() != ModelA64)
15161 state.inputs.offsetB = state.inputs.offsetB.d();
15162 if (strategy.C.base.getModel() != ModelA64)
15163 for (int q = 0; q < state.C_count; q++)
15164 state.inputs.offsetC[q] = state.inputs.offsetC[q].d();
15165 if (problem.usesCO() && strategy.CO.base.getModel() != ModelA64)
15166 state.inputs.offsetCO = state.inputs.offsetCO.d();
15167 for (auto &off : state.inputs.binaryOffsets)
15168 off = off.d();
15169
15170 // For now, reinterpret m/n/k/ld/diag variables to 32-bit if they are 64-bit.
15171 state.inputs.m = state.inputs.m.d();
15172 state.inputs.n = state.inputs.n.d();
15173 state.inputs.k = state.inputs.k.d();
15174 state.inputs.lda = state.inputs.lda.ud();
15175 state.inputs.ldb = state.inputs.ldb.ud();
15176 for (int q = 0; q < state.C_count; q++)
15177 state.inputs.ldc[q] = state.inputs.ldc[q].ud();
15178 state.inputs.ldco = state.inputs.ldco.ud();
15179 state.inputs.diagA = state.inputs.diagA.d();
15180 state.inputs.diagB = state.inputs.diagB.d();
15181 state.inputs.diagC = state.inputs.diagC.d();
15182
15183 // Claim registers.
15184 for (int i = 0; i < 4; i++)
15185 state.ra.claim(r0.uq(i));
15186
15187 if (strategy.A.base.isStateless()) state.ra.claim(state.inputs.A);
15188 if (strategy.B.base.isStateless()) state.ra.claim(state.inputs.B);
15189 if (strategy.C.base.isStateless())
15190 for (int q = 0; q < state.C_count; q++)
15191 state.ra.claim(state.inputs.C[q]);
15192
15193 if (problem.abOffset != ABOffset::None) {
15194 if (state.inputs.ao.isValid()) state.ra.claim(state.inputs.ao);
15195 if (state.inputs.bo.isValid()) state.ra.claim(state.inputs.bo);
15196 if (state.inputs.aoPtr.isValid()) state.ra.claim(state.inputs.aoPtr);
15197 if (state.inputs.boPtr.isValid()) state.ra.claim(state.inputs.boPtr);
15198 }
15199
15200 if (problem.usesCO()) {
15201 if (strategy.CO.base.isStateless()) state.ra.claim(state.inputs.CO);
15202 state.ra.claim(state.inputs.offsetCO);
15203 }
15204
15205 state.ra.claim(state.inputs.offsetA);
15206 state.ra.claim(state.inputs.offsetB);
15207 for (int q = 0; q < state.C_count; q++)
15208 state.ra.claim(state.inputs.offsetC[q]);
15209 state.ra.claim(state.inputs.lda);
15210 state.ra.claim(state.inputs.ldb);
15211 for (int q = 0; q < state.C_count; q++)
15212 state.ra.claim(state.inputs.ldc[q]);
15213 if (problem.allowMatrixOffset()) state.ra.claim(state.inputs.ldco);
15214 state.ra.claim(state.inputs.m);
15215 state.ra.claim(state.inputs.n);
15216 state.ra.claim(state.inputs.k);
15217 if (strategy.kParallel || strategy.kParallelLocal)
15218 state.ra.claim(state.inputs.k0);
15219
15220 if (!problem.alpha_real.fixed()) state.ra.claim(state.inputs.alpha_real);
15221 if (!problem.beta_real.fixed()) state.ra.claim(state.inputs.beta_real);
15222
15223 if (!inSK) {
15224 state.ra.claim(state.inputs.localIDM);
15225 state.ra.claim(state.inputs.localIDN);
15226 if (!strategy.fixedWG(problem)) {
15227 state.ra.claim(state.inputs.localSizeM);
15228 state.ra.claim(state.inputs.localSizeN);
15229 } else
15230 state.inputs.localSizeM = state.inputs.localSizeN = invalid;
15231 if (strategy.kParallel || strategy.kParallelLocal) {
15232 state.ra.claim(state.inputs.localIDK);
15233 state.ra.claim(state.inputs.localSizeK);
15234 }
15235 }
15236
15237 if (state.inputs.flags.isValid()) state.ra.claim(state.inputs.flags);
15238
15239 if (problem.batch == BatchMode::Strided) {
15240 for (int i = 0; i < problem.batchDims; i++) {
15241 state.ra.claim(state.inputs.strideA[i]);
15242 state.ra.claim(state.inputs.strideB[i]);
15243 state.ra.claim(state.inputs.strideC[i]);
15244 }
15245 if (problem.batchDims > 1) {
15246 state.ra.claim(state.inputs.batchSize1);
15247 state.ra.claim(state.inputs.recipBatchSize1);
15248 }
15249 state.ra.claim(state.inputs.groupIDK);
15250 } else if (problem.batch == BatchMode::Nonstrided) {
15251 state.ra.claim(state.inputs.offsetBatch);
15252 state.ra.claim(state.inputs.groupIDK);
15253 } else if (problem.batch == BatchMode::Variable) {
15254 state.ra.claim(state.inputs.incr_a_array);
15255 state.ra.claim(state.inputs.incr_b_array);
15256 state.ra.claim(state.inputs.alpha_array);
15257 state.ra.claim(state.inputs.beta_array);
15258 state.ra.claim(state.inputs.incr_alpha);
15259 state.ra.claim(state.inputs.incr_beta);
15260 state.ra.claim(state.inputs.groupIDK);
15261 }
15262
15263 if (strategy.linearOrder()) {
15264 state.ra.claim(state.inputs.groupCountM);
15265 state.ra.claim(state.inputs.groupCountN);
15266 }
15267
15268 if (strategy.hilbertOrder) {
15269 {
15270 state.ra.claim(state.inputs.hilbertVD);
15271 state.ra.claim(state.inputs.hilbertUVDRecip);
15272 }
15273 state.ra.claim(state.inputs.hilbertBail);
15274 }
15275
15276 if (strategy.boustrophedon) {
15277 state.ra.claim(state.inputs.bslice);
15278 state.ra.claim(state.inputs.bthresh);
15279 }
15280
15281 if (strategy.persistent) {
15282 state.ra.claim(state.inputs.groupStride);
15283 if (state.inputs.groupCountMN.isValid())
15284 state.ra.claim(state.inputs.groupCountMN);
15285 }
15286
15287 // Binary-related arguments are not claimed here, but instead
15288 // are reloaded later in the kernel when needed.
15289}
15290
15291// Return amount of SLM needed by a GEMM kernel.
15292template <HW hw>
15293size_t gemm_kernel_generator_t<hw>::gemmSLMSize(
15294 const GEMMProblem &problem, const GEMMStrategy &strategy) {
15295 // Space needed by SLM copies.
15296 size_t slmSize
15297 = strategy.slmABufSize(problem) + strategy.slmBBufSize(problem);
15298
15299 // Space needed for row/column sum reduction/sharing.
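    // One Tc-sized entry per row of the workgroup's m extent (A sums) plus one
    // per column of its n extent (B sums); e.g. a 4x4 workgroup with 32x8
    // unroll and f32 Tc needs (32*4 + 8*4) * 4 = 640 bytes.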
15300 if ((problem.needsASums() && strategy.slmA)
15301 || (problem.needsBSums() && strategy.slmB)) {
15302 slmSize = std::max<size_t>(slmSize,
15303 (strategy.unroll[LoopM] * strategy.wg[LoopM]
15304 + strategy.unroll[LoopN] * strategy.wg[LoopN])
15305 * problem.Tc);
15306 }
15307
15308 return slmSize;
15309}
15310
15311// Return amount of per-k SLM needed by a GEMM kernel.
15312template <HW hw>
15313size_t gemm_kernel_generator_t<hw>::gemmPerKSLMSize(
15314 const GEMMProblem &problem, const GEMMStrategy &strategy) {
15315 size_t slmSize = 0;
15316
15317 // Space needed for local k reduction (as much as possible).
15318 if (strategy.kParallelLocal) {
15319 // Calculate max SLM usage that doesn't reduce thread count.
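        // concurrentK = number of k-slices of the workgroup that can be resident
        // per subslice at full occupancy; e.g. 64 hardware threads per subslice
        // and a 4x4 m/n workgroup give concurrentK = 4, so each slice gets 1/4
        // of SLM, rounded down to a power of two.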
15320 int mnThreads = strategy.wg[LoopM] * strategy.wg[LoopN];
15321 if (mnThreads <= 0) stub();
15322 int concurrentK = std::max(
15323 1, threadsPerEU(hw, strategy) * eusPerSubslice(hw) / mnThreads);
15324 slmSize = rounddown_pow2(slmCapacity(hw) / concurrentK);
15325 }
15326
15327 return slmSize;
15328}
15329
15330// Initialize the state structure.
15331template <HW hw>
15332void gemm_kernel_generator_t<hw>::gemmInitState(GEMMProblem &problem,
15333 GEMMStrategy &strategy, GEMMState &state, bool inSK) {
15334 auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc;
15335
15336 if (!state.fusedGEMM.active) {
15337 initState(problem, strategy, state);
15338 gemmInitInterface(problem, strategy, state, inSK);
15339 state.isNested |= strategy.fused;
15340 state.isNested |= strategy.persistent;
15341 }
15342
15343 state.effA = strategy.A.base.isStateless() ? state.inputs.A
15344 : state.inputs.offsetA.d();
15345 state.effB = strategy.B.base.isStateless() ? state.inputs.B
15346 : state.inputs.offsetB.d();
15347 for (int q = 0; q < state.C_count; q++) {
15348 state.effC[q] = strategy.C.base.isStateless()
15349 ? state.inputs.C[q]
15350 : state.inputs.offsetC[q].d();
15351 }
15352 if (problem.usesCO()) {
15353 state.effCO = strategy.CO.base.isStateless()
15354 ? state.inputs.CO
15355 : state.inputs.offsetCO.d();
15356 }
15357
15358 if (!problem.alpha_real.fixed())
15359 problem.alpha_real = state.inputs.alpha_real;
15360 if (!problem.beta_real.fixed()) problem.beta_real = state.inputs.beta_real;
15361
15362 state.offsetA = state.inputs.offsetA;
15363 state.offsetB = state.inputs.offsetB;
15364 for (int q = 0; q < state.C_count; q++)
15365 state.offsetC[q] = state.inputs.offsetC[q];
15366 state.offsetCO = state.inputs.offsetCO;
15367
15368 state.flagAP = state.raVFlag.alloc();
15369
15370 state.allocEmulate64Temp(strategy.emulate);
15371
15372 state.Ta_load = problem.Ta_ext;
15373 state.Tb_load = problem.Tb_ext;
15374
15375 state.Tacc = problem.Tc;
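    // C is staged through a register copy when the internal and external C types
    // differ, when C elements are narrower than a dword and the alternate C
    // remainder path is disabled, or when explicitly requested.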
15376 state.copyC = (problem.Tc != problem.Tc_ext)
15377 || (!strategy.altCRemainder && (Tc.size() < 4))
15378 || strategy.forceCopyC;
15379
15380 state.broadcast = strategy.doubleWA;
15381
15382 bool cColMajor = isRegisterColMajor(problem.Tc, problem.C, strategy.C);
15383 state.broadcast |= (Tc == Type::f32 && (cColMajor ? Tb : Ta) == Type::bf16);
15384
15385 state.Cext_strategy = strategy.C;
15386 state.Cext_strategy.tileR = state.Cext_strategy.tileC = 0;
15387
15388 state.lidM = state.inputs.localIDM[0];
15389 state.lidN = state.inputs.localIDN[0];
15390 if (strategy.kParallel || strategy.kParallelLocal)
15391 state.lidK = state.inputs.localIDK[0];
15392
15393 state.diagC = state.inputs.diagC;
15394 state.k = state.inputs.k;
15395
15396 state.lda = state.inputs.lda;
15397 state.ldb = state.inputs.ldb;
15398}
15399
15400// Offset A pointer in k dimension by a constant value.
15401template <HW hw>
15402void gemm_kernel_generator_t<hw>::gemmOffsetAk(int h, const Subregister &effA,
15403 const MatrixAddressing &globalA, const GEMMProblem &problem,
15404 const GEMMStrategy &strategy, GEMMState &state) {
15405 auto Ta_ext = problem.Ta_ext;
15406 if (h) switch (globalA.layout) {
15407 case MatrixLayout::N:
15408 emad(1, effA, effA, state.inputs.lda, Immediate::w(h), strategy,
15409 state);
15410 break;
15411 case MatrixLayout::T:
15412 eadd(1, effA, effA, h * Ta_ext, strategy, state);
15413 break;
15414 case MatrixLayout::Pc:
15415 eadd(1, effA, effA, h * globalA.packSize * Ta_ext, strategy,
15416 state);
15417 break;
15418 default: stub();
15419 }
15420}
15421
15422// Offset A pointer in k dimension by a variable value.
15423template <HW hw>
15424void gemm_kernel_generator_t<hw>::gemmOffsetAk(const Subregister &h,
15425 const Subregister &effA, const MatrixAddressing &globalA,
15426 const GEMMProblem &problem, const GEMMStrategy &strategy,
15427 GEMMState &state) {
15428 auto Ta_ext = problem.Ta_ext;
15429 switch (globalA.layout) {
15430 case MatrixLayout::N:
15431 emad(1, effA, effA, state.inputs.lda, h, strategy, state);
15432 break;
15433 case MatrixLayout::T:
15434 emad(1, effA, effA, h, Ta_ext.size(), strategy, state);
15435 break;
15436 case MatrixLayout::Pc:
15437 emad(1, effA, effA, h, globalA.packSize * Ta_ext, strategy, state);
15438 break;
15439 default: stub();
15440 }
15441}
15442
15443// Offset B pointer in k dimension by a constant value.
15444template <HW hw>
15445void gemm_kernel_generator_t<hw>::gemmOffsetBk(int h, const Subregister &effB,
15446 const MatrixAddressing &globalB, const GEMMProblem &problem,
15447 const GEMMStrategy &strategy, GEMMState &state) {
15448 auto Tb_ext = problem.Tb_ext;
15449 if (h) switch (globalB.layout) {
15450 case MatrixLayout::N:
15451 eadd(1, effB, effB, h * Tb_ext, strategy, state);
15452 break;
15453 case MatrixLayout::Pr:
15454 eadd(1, effB, effB, h * globalB.packSize * Tb_ext, strategy,
15455 state);
15456 break;
15457 case MatrixLayout::T:
15458 emad(1, effB, effB, state.inputs.ldb, Immediate::w(h), strategy,
15459 state);
15460 break;
15461 default: stub();
15462 }
15463}
15464
15465// Offset B pointer in k dimension by a variable value.
15466template <HW hw>
15467void gemm_kernel_generator_t<hw>::gemmOffsetBk(const Subregister &h,
15468 const Subregister &effB, const MatrixAddressing &globalB,
15469 const GEMMProblem &problem, const GEMMStrategy &strategy,
15470 GEMMState &state) {
15471 auto Tb_ext = problem.Tb_ext;
15472 switch (globalB.layout) {
15473 case MatrixLayout::T:
15474 emad(1, effB, effB, state.inputs.ldb, h, strategy, state);
15475 break;
15476 case MatrixLayout::N:
15477 emad(1, effB, effB, h, Tb_ext.size(), strategy, state);
15478 break;
15479 case MatrixLayout::Pr:
15480 emad(1, effB, effB, h, globalB.packSize * Tb_ext, strategy, state);
15481 break;
15482 default: stub();
15483 }
15484}
15485
15486// Adjust A, B, C to start at (i0, j0).
15487// initial is true to adjust offset_{A,B,C}, false to adjust A,B,C pointers.
15488template <HW hw>
15489void gemm_kernel_generator_t<hw>::gemmOffsetABC(bool initial, Subregister i0,
15490 Subregister j0, Subregister h0, const GEMMProblem &problem,
15491 const GEMMStrategy &strategy, GEMMState &state, bool doA, bool doB,
15492 bool doC, bool doBinary) {
15493 auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext,
15494 Tc_ext = problem.Tc_ext, Tco = problem.Tco;
15495 auto &offsetA = initial ? state.offsetA : state.effA;
15496 auto &offsetB = initial ? state.offsetB : state.effB;
15497 auto &offsetC0 = initial ? state.offsetC[0] : state.effC[0];
15498 auto &offsetAp = initial ? state.offsetAp : state.effAp;
15499 auto &offsetBp = initial ? state.offsetBp : state.effBp;
15500 auto &offsetCp = initial ? state.offsetCp : state.effCp;
15501 auto offsetCO = initial ? state.offsetCO : state.effCO;
15502 bool doCO = doC && (problem.cOffset != COffset::None);
15503
15504 Subregister tempQ0 = state.ra.alloc_sub<int64_t>(
15505 getHint(HintType::TempComp0, strategy));
15506 Subregister tempQ1 = state.ra.alloc_sub<int64_t>(
15507 getHint(HintType::TempComp1, strategy));
15508
15509 bool a2D = strategy.A.address2D;
15510 bool b2D = strategy.B.address2D;
15511 bool c2D = strategy.C.address2D;
15512 bool ap2D = strategy.prefetchA ? strategy.A_prefetch.address2D : a2D;
15513 bool bp2D = strategy.prefetchB ? strategy.B_prefetch.address2D : b2D;
15514 bool cp2D = strategy.prefetchC ? strategy.C_prefetch.address2D : c2D;
15515
15516 if (a2D && ap2D) doA = false;
15517 if (b2D && bp2D) doB = false;
15518 if (c2D && cp2D) doC = false;
15519
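    // If the load path and its prefetch differ in 2D addressing, keep a separate
    // copy of the offset (offsetAp/Bp/Cp) so the non-2D side can still be
    // adjusted by the code below.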
15520 if (doA && (a2D != ap2D)) {
15521 if (!initial) stub();
15522 if (offsetAp.isInvalid()) {
15523 offsetAp = state.ra.alloc_sub(
15524 offsetA.getType(), getHint(HintType::LongTerm, strategy));
15525 emov(1, state.offsetAp, offsetA, strategy, state);
15526 } else if (a2D && !ap2D)
15527 std::swap(offsetA, offsetAp);
15528 }
15529 if (doB && (b2D != bp2D)) {
15530 if (!initial) stub();
15531 if (offsetBp.isInvalid()) {
15532 offsetBp = state.ra.alloc_sub(
15533 offsetB.getType(), getHint(HintType::LongTerm, strategy));
15534 emov(1, offsetBp, offsetB, strategy, state);
15535 } else if (b2D && !bp2D)
15536 std::swap(offsetB, offsetBp);
15537 }
15538 if (doC && (c2D != cp2D)) {
15539 if (!initial) stub();
15540 if (offsetCp.isInvalid()) {
15541 offsetCp = state.ra.alloc_sub(
15542 offsetC0.getType(), getHint(HintType::LongTerm, strategy));
15543 emov(1, offsetCp, offsetC0, strategy, state);
15544 } else if (c2D && !cp2D)
15545 std::swap(offsetC0, offsetCp);
15546 }
15547
15548 // To do: interleave code.
15549 // A += i0 (N) i0 * lda (T, Pc)
15550 // B += j0 * ldb (N, Pr) j0 (T)
15551 // C += i0 + j0 * ldc (N, Pr) j0 + i0 * ldc (T, Pc)
15552 // CO += i0 (row offsets) j0 (col offsets)
15553 if (doA && i0.isValid()) {
15554 if (problem.A.layout == MatrixLayout::Nontranspose)
15555 emad(1, offsetA, offsetA, i0, Ta_ext.size(), strategy, state);
15556 else {
15557 emul(1, tempQ1, i0, state.inputs.lda, strategy, state);
15558 eadd(1, offsetA, offsetA, tempQ1.reinterpret(0, offsetA.getType()),
15559 strategy, state);
15560 }
15561 }
15562
15563 if (doB && j0.isValid()) {
15564 if (problem.B.layout == MatrixLayout::Transpose)
15565 emad(1, offsetB, offsetB, j0, Tb_ext.size(), strategy, state);
15566 else {
15567 emul(1, tempQ0, j0, state.inputs.ldb, strategy, state);
15568 eadd(1, offsetB, offsetB, tempQ0.reinterpret(0, offsetB.getType()),
15569 strategy, state);
15570 }
15571 }
15572
15573 FlagRegister flagCOR, flagCOC;
15574 if (doCO) {
15575 flagCOR = state.raVFlag.alloc();
15576 flagCOC = state.raVFlag.alloc();
15577 and_(1 | nz | flagCOC, null.ud(), state.inputs.flags, FlagCOColumn);
15578 and_(1 | nz | flagCOR, null.ud(), state.inputs.flags, FlagCORow);
15579 }
15580 if (doC) {
15581 for (int q = 0; q < state.C_count; q++) {
15582 auto offsetC = initial ? state.offsetC[q] : state.effC[q];
15583
15584 Subregister x, y;
15585 int xstride = Tc_ext.size();
15586 switch (problem.C.layout) {
15587 case MatrixLayout::Pr:
15588 xstride *= strategy.unroll[LoopN]; /* fall through */
15589 case MatrixLayout::N:
15590 x = i0;
15591 y = j0;
15592 break;
15593 case MatrixLayout::Pc:
15594 xstride *= strategy.unroll[LoopM]; /* fall through */
15595 case MatrixLayout::T:
15596 x = j0;
15597 y = i0;
15598 break;
15599 }
15600 emad(1, offsetC, offsetC, x, xstride, strategy, state);
15601 emul(1, tempQ0, y, state.inputs.ldc[q], strategy, state);
15602 eadd(1, offsetC, offsetC, tempQ0.reinterpret(0, offsetC.getType()),
15603 strategy, state); // Gen12: Use add3.
15604 }
15605 }
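    // CO may hold row offsets (indexed by i0), column offsets (indexed by j0),
    // or a full matrix (indexed by both); runtime flags select which.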
15606 if (doCO) {
15607 Label lNoMatrixOffset, lDone;
15608 if (problem.allowMatrixOffset()) {
15609 auto x0 = isColMajor(problem.CO.layout) ? i0 : j0;
15610 auto y0 = isColMajor(problem.CO.layout) ? j0 : i0;
15611 jmpi(1 | ~flagCOC, lNoMatrixOffset);
15612 jmpi(1 | ~flagCOR, lNoMatrixOffset);
15613 emad(1, offsetCO, offsetCO, x0, Tco.size(), strategy, state);
15614 emul(1, tempQ0, y0, state.inputs.ldco, strategy, state);
15615 eadd(1, offsetCO, offsetCO,
15616 tempQ0.reinterpret(0, offsetCO.getType()), strategy, state);
15617 jmpi(1, lDone);
15618 mark(lNoMatrixOffset);
15619 }
15620 emad(1 | flagCOC, offsetCO, offsetCO, j0, Tco.size(), strategy, state);
15621 emad(1 | flagCOR, offsetCO, offsetCO, i0, Tco.size(), strategy, state);
15622 state.raVFlag.safeRelease(flagCOR);
15623 state.raVFlag.safeRelease(flagCOC);
15624 if (problem.allowMatrixOffset()) mark(lDone);
15625 }
15626 if (doBinary)
15627 for (int i = 0; i < problem.postOps.len(); i++) {
15628 if (!problem.postOps.entry_[i].is_binary()) continue;
15629 bool row = problem.binaryRow[i], col = problem.binaryCol[i];
15630 auto T = problem.Tbinary[i];
15631 auto &ld = state.inputs.binaryLDs[i];
15632 auto offset = initial ? state.inputs.binaryOffsets[i]
15633 : state.effBinary[i];
15634 if (row && col) {
15635 auto x0 = isColMajor(problem.binary[i].layout) ? i0 : j0;
15636 auto y0 = isColMajor(problem.binary[i].layout) ? j0 : i0;
15637 emad(1, offset, offset, x0, T.size(), strategy, state);
15638 emul(1, tempQ0, y0, ld, strategy, state);
15639 eadd(1, offset, offset, tempQ0.reinterpret(0, offset.getType()),
15640 strategy, state);
15641 } else if (row)
15642 emad(1, offset, offset, i0, T.size(), strategy, state);
15643 else if (col)
15644 emad(1, offset, offset, j0, T.size(), strategy, state);
15645 }
15646 if (doC && problem.sumA)
15647 emad(1, offsetCO, offsetCO, i0, Tco.size(), strategy, state);
15648 if (doC && problem.sumB)
15649 emad(1, offsetCO, offsetCO, j0, Tco.size(), strategy, state);
15650
    // With k blocking (or for certain triangular source kernels):
    //   A += h0 * lda (N)    h0 (T)          h0 * mb (Pc)
    //   B += h0 (N)          h0 * ldb (T)    h0 * nb (Pr)
15654 if (!h0.isInvalid()) {
15655 if (doA) switch (problem.A.layout) {
15656 case MatrixLayout::Nontranspose:
15657 emul(1, tempQ1, h0, state.inputs.lda, strategy, state);
15658 eadd(1, offsetA, offsetA,
15659 tempQ1.reinterpret(0, offsetA.getType()), strategy,
15660 state);
15661 break;
15662 case MatrixLayout::Transpose:
15663 emad(1, offsetA, offsetA, h0, Ta_ext.size(), strategy,
15664 state);
15665 break;
15666 case MatrixLayout::PackedColumns:
15667 emad(1, offsetA, offsetA, h0,
15668 strategy.unroll[LoopM] * Ta_ext, strategy, state);
15669 break;
15670 default: stub();
15671 }
15672 if (doB) switch (problem.B.layout) {
15673 case MatrixLayout::Nontranspose:
15674 emad(1, offsetB, offsetB, h0, Tb_ext.size(), strategy,
15675 state);
15676 break;
15677 case MatrixLayout::Transpose:
15678 emul(1, tempQ0, h0, state.inputs.ldb, strategy, state);
15679 eadd(1, offsetB, offsetB,
15680 tempQ0.reinterpret(0, offsetB.getType()), strategy,
15681 state);
15682 break;
15683 case MatrixLayout::PackedRows:
15684 emad(1, offsetB, offsetB, h0,
15685 strategy.unroll[LoopN] * Tb_ext, strategy, state);
15686 break;
15687 default: stub();
15688 }
15689 }
15690
15691 state.ra.safeRelease(tempQ0);
15692 state.ra.safeRelease(tempQ1);
15693
15694 if (doA && a2D && !ap2D) std::swap(offsetA, offsetAp);
15695 if (doB && b2D && !bp2D) std::swap(offsetB, offsetBp);
15696 if (doC && c2D && !cp2D) std::swap(offsetC0, offsetCp);
15697}
15698
15699template <ngen::HW hw>
15700void gemm_kernel_generator_t<hw>::gemmOffsetBatchABC(const GEMMProblem &problem,
15701 const GEMMStrategy &strategy, GEMMState &state) {
15702 auto Ts = problem.Ts;
15703
15704 // Non-strided batch support.
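    // The A, B and C arguments point to per-batch pointer arrays; compute this
    // batch's pointer offset and gather the three pointers with one masked
    // scattered-qword load (flagAP = 0x7 enables three channels).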
15705 if (problem.batch == BatchMode::Nonstrided) {
15706 auto temp = state.ra.alloc().uq();
15707 auto boffset = state.ra.alloc_sub<uint32_t>();
15708
15709 add(1, boffset, state.inputs.offsetBatch, state.inputs.groupIDK);
15710 mov(1, state.flagAP, 0x7);
15711 shl(1, boffset, boffset, uint16_t(3));
15712
15713 eadd(1, temp[0], state.inputs.A, boffset, strategy, state);
15714 eadd(1, temp[1], state.inputs.B, boffset, strategy, state);
15715 eadd(1, temp[2], state.inputs.C[0], boffset, strategy, state);
15716
15717 load(4 | state.flagAP, temp, scattered_qword(1), strategy.A.base, temp);
15718
15719 emov(1, state.effA, temp[0], strategy, state);
15720 emov(1, state.effB, temp[1], strategy, state);
15721 emov(1, state.effC[0], temp[2], strategy, state);
15722
15723 state.ra.safeRelease(temp);
15724 state.ra.safeRelease(boffset);
15725 if (!strategy.persistent)
15726 state.ra.safeRelease(state.inputs.offsetBatch);
15727 }
15728
15729 // Strided batch support.
15730 if (problem.batch == BatchMode::Strided) {
15731 for (int b = 0; b < problem.batchDims; b++) {
15732 emul(1, state.inputs.strideA[b], state.inputs.strideA[b],
15733 state.batchID[b], strategy, state);
15734 emul(1, state.inputs.strideB[b], state.inputs.strideB[b],
15735 state.batchID[b], strategy, state);
15736 emul(1, state.inputs.strideC[b], state.inputs.strideC[b],
15737 state.batchID[b], strategy, state);
15738 }
15739
15740 for (int b = 0; b < problem.batchDims; b++) {
15741 eadd(1, state.offsetA, state.offsetA, state.inputs.strideA[b],
15742 strategy, state);
15743 eadd(1, state.offsetB, state.offsetB, state.inputs.strideB[b],
15744 strategy, state);
15745 for (int q = 0; q < state.C_count; q++) {
15746 auto offsetC = state.offsetC[q];
15747 eadd(1, offsetC, offsetC, state.inputs.strideC[b], strategy,
15748 state);
15749 }
15750 if (!strategy.persistent) {
15751 state.ra.safeRelease(state.inputs.strideA[b]);
15752 state.ra.safeRelease(state.inputs.strideB[b]);
15753 state.ra.safeRelease(state.inputs.strideC[b]);
15754 }
15755 }
15756 }
15757
15758 // Non-strided variable batch support.
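    // alpha/beta and the A/B pointers live in per-batch arrays; the incr_*
    // arguments give the spacing between consecutive batch entries, and the
    // values for this batch are gathered with scattered loads.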
15759 if (problem.batch == BatchMode::Variable) {
15760 auto tempA = state.ra.alloc().uq();
15761 auto tempB = state.ra.alloc().uq();
15762 auto tempC = state.ra.alloc().uq();
15763 auto tempIDK = state.ra.alloc().ud();
15764 auto offset_scalar = state.ra.alloc().uq();
15765 auto offset_pointer = state.ra.alloc().uq();
15766 auto tempAlphaReal = state.ra.alloc().uq();
15767 auto tempAlphaImag = state.ra.alloc().uq();
15768 auto tempBetaReal = state.ra.alloc().uq();
15769 auto tempBetaImag = state.ra.alloc().uq();
15770
15771 eshl(1, tempIDK, state.inputs.groupIDK, uint16_t(log2(Ts.size())),
15772 strategy, state);
15773
15774 // load and set alpha
15775 emul(1, offset_scalar, tempIDK, state.inputs.incr_alpha.uw(), strategy,
15776 state);
15777 eadd(1, tempAlphaReal[0], state.inputs.alpha_array, offset_scalar,
15778 strategy, state);
15779 if (Ts.isComplex()) {
15780 eadd(1, tempAlphaImag[0], tempAlphaReal[0], Ts.real().size(),
15781 strategy, state);
15782 }
15783 if (Ts.real().size() == 4) {
15784 load(1, tempAlphaReal, scattered_dword(1), A64, tempAlphaReal);
15785 if (Ts.isComplex()) {
15786 load(1, tempAlphaImag, scattered_dword(1), A64, tempAlphaImag);
15787 }
15788 } else if (Ts.real().size() == 2) {
15789 load(1, tempAlphaReal, scattered_byte(2), A64, tempAlphaReal);
15790 if (Ts.isComplex()) {
15791 load(1, tempAlphaImag, scattered_byte(2), A64, tempAlphaImag);
15792 }
15793 } else {
15794 load(1, tempAlphaReal, scattered_qword(1), A64, tempAlphaReal);
15795 if (Ts.isComplex()) {
15796 load(1, tempAlphaImag, scattered_qword(1), A64, tempAlphaImag);
15797 }
15798 }
15799 mov(1, state.inputs.alpha_real, tempAlphaReal.sub(0, Ts.real().ngen()));
15800 if (Ts.isComplex())
15801 mov(1, state.inputs.alpha_imag,
15802 tempAlphaImag.sub(0, Ts.real().ngen()));
15803 // end load and set alpha
15804
15805 // load and set beta
15806 emul(1, offset_scalar, tempIDK, state.inputs.incr_beta.uw(), strategy,
15807 state);
15808 eadd(1, tempBetaReal[0], state.inputs.beta_array, offset_scalar,
15809 strategy, state);
15810 if (Ts.isComplex()) {
15811 eadd(1, tempBetaImag[0], tempBetaReal[0], Ts.real().size(),
15812 strategy, state);
15813 }
15814 if (Ts.real().size() == 4) {
15815 load(1, tempBetaReal, scattered_dword(1), A64, tempBetaReal);
15816 if (Ts.isComplex()) {
15817 load(1, tempBetaImag, scattered_dword(1), A64, tempBetaImag);
15818 }
15819 } else if (Ts.real().size() == 2) {
15820 load(1, tempBetaReal, scattered_byte(2), A64, tempBetaReal);
15821 if (Ts.isComplex()) {
15822 load(1, tempBetaImag, scattered_byte(2), A64, tempBetaImag);
15823 }
15824 } else {
15825 load(1, tempBetaReal, scattered_qword(1), A64, tempBetaReal);
15826 if (Ts.isComplex()) {
15827 load(1, tempBetaImag, scattered_qword(1), A64, tempBetaImag);
15828 }
15829 }
15830 mov(1, state.inputs.beta_real, tempBetaReal.sub(0, Ts.real().ngen()));
15831 if (Ts.isComplex())
15832 mov(1, state.inputs.beta_imag,
15833 tempBetaImag.sub(0, Ts.real().ngen()));
15834 // end load and set beta
15835
15836 eshl(1, tempIDK, state.inputs.groupIDK, uint16_t(3), strategy, state);
15837 emul(1, offset_pointer, tempIDK, state.inputs.incr_a_array.uw(),
15838 strategy, state);
15839 eadd(1, tempA[0], state.inputs.A, offset_pointer, strategy, state);
15840 load(1, tempA, scattered_qword(1), strategy.A.base, tempA);
15841
15842 emul(1, offset_pointer, tempIDK, state.inputs.incr_b_array.uw(),
15843 strategy, state);
15844 eadd(1, tempB[0], state.inputs.B, offset_pointer, strategy, state);
15845 load(1, tempB, scattered_qword(1), strategy.B.base, tempB);
15846
15847 eadd(1, tempC[0], state.inputs.C[0], tempIDK, strategy, state);
15848 load(1, tempC, scattered_qword(1), strategy.C.base, tempC);
15849
15850 emov(1, state.effA, tempA, strategy, state);
15851 emov(1, state.effB, tempB, strategy, state);
15852 emov(1, state.effC[0], tempC, strategy, state);
15853
15854 state.ra.safeRelease(tempA);
15855 state.ra.safeRelease(tempB);
15856 state.ra.safeRelease(tempC);
15857 state.ra.safeRelease(tempIDK);
15858 state.ra.safeRelease(offset_scalar);
15859 state.ra.safeRelease(offset_pointer);
15860 state.ra.safeRelease(tempAlphaReal);
15861 state.ra.safeRelease(tempAlphaImag);
15862 state.ra.safeRelease(tempBetaReal);
15863 state.ra.safeRelease(tempBetaImag);
15864 if (!strategy.persistent) {
15865 state.ra.safeRelease(state.inputs.incr_a_array);
15866 state.ra.safeRelease(state.inputs.incr_b_array);
15867 state.ra.safeRelease(state.inputs.incr_alpha);
15868 state.ra.safeRelease(state.inputs.incr_beta);
15869 state.ra.safeRelease(state.inputs.alpha_array);
15870 state.ra.safeRelease(state.inputs.beta_array);
15871 }
15872 }
15873}
15874
15875// Prepare for persistent GEMM by folding offsets into A/B/C pointers (if stateless),
// or saving offsets (if stateful).
15877template <HW hw>
15878void gemm_kernel_generator_t<hw>::gemmFoldOffsets(const GEMMProblem &problem,
15879 const GEMMStrategy &strategy, GEMMState &state) {
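    // For a stateless matrix, fold the offset into the base pointer and reset the
    // offset to zero; for a surface-based matrix, save a private copy of the
    // offset so it can be restored on each persistent iteration.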
15880 auto foldOrSave
15881 = [&](const MatrixAddressingStrategy &sX, Subregister &inputX,
15882 Subregister &offsetX, const Subregister &inputOffsetX,
15883 Subregister &saveOffsetX, bool newInput = false) {
15884 if (sX.base.isStateless()) {
15885 auto oldInputX = inputX;
15886 if (newInput)
15887 inputX = state.ra.alloc_sub(DataType::uq,
15888 getHint(HintType::LongTerm, strategy));
15889 eadd(1, inputX, oldInputX, offsetX, strategy, state);
15890 if (getBytes(offsetX.getType()) < 8) {
15891 state.ra.safeRelease(offsetX);
15892 offsetX = state.ra.alloc_sub(DataType::uq,
15893 getHint(HintType::LongTerm, strategy));
15894 }
15895 emov(1, offsetX, 0, strategy, state);
15896 } else {
15897 offsetX = state.ra.alloc_sub(offsetX.getType(),
15898 getHint(HintType::LongTerm, strategy));
15899 mov(1, offsetX, inputOffsetX);
15900 }
15901 saveOffsetX = offsetX;
15902 };
15903
15904 bool deduplicateAB = (state.inputs.A == state.inputs.B);
15905
15906 foldOrSave(strategy.A, state.inputs.A, state.offsetA, state.inputs.offsetA,
15907 state.saveOffsetA, deduplicateAB);
15908 foldOrSave(strategy.B, state.inputs.B, state.offsetB, state.inputs.offsetB,
15909 state.saveOffsetB);
15910 for (int q = 0; q < state.C_count; q++)
15911 foldOrSave(strategy.C, state.inputs.C[q], state.offsetC[q],
15912 state.inputs.offsetC[q],
15913 state.saveOffsetC[q]); // todo init for hpl
15914 if (problem.usesCO())
15915 foldOrSave(strategy.CO, state.inputs.CO, state.offsetCO,
15916 state.inputs.offsetCO, state.saveOffsetCO);
15917
15918 if (deduplicateAB) state.effA = state.inputs.A;
15919}
15920
15921// Restore input offsets from saved copies, for persistent GEMM.
15922template <HW hw>
15923void gemm_kernel_generator_t<hw>::gemmRestoreOffsets(const GEMMProblem &problem,
15924 const GEMMStrategy &strategy, GEMMState &state) {
15925 auto zeroOrRestore = [&](const MatrixAddressingStrategy &sX,
15926 const Subregister &offsetX,
15927 const Subregister &inputOffsetX) {
15928 if (sX.base.isStateless())
15929 emov(1, offsetX, 0, strategy, state);
15930 else
15931 mov(1, offsetX, inputOffsetX);
15932 };
15933
15934 zeroOrRestore(strategy.A, state.saveOffsetA, state.inputs.offsetA);
15935 zeroOrRestore(strategy.B, state.saveOffsetB, state.inputs.offsetB);
15936 for (int q = 0; q < state.C_count; q++)
15937 zeroOrRestore(
15938 strategy.C, state.saveOffsetC[q], state.inputs.offsetC[q]);
15939 if (problem.usesCO())
15940 zeroOrRestore(strategy.CO, state.saveOffsetCO, state.inputs.offsetCO);
15941}
15942
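// Compute effective A/B/C/CO addresses, adding saved offsets to the base pointers
// for stateless accesses (and setting up separate prefetch pointers if needed).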
15943template <HW hw>
15944void gemm_kernel_generator_t<hw>::gemmSetupABC(const GEMMProblem &problem,
15945 const GEMMStrategy &strategy, GEMMState &state) {
15946 if (strategy.persistent) {
15947 state.effA = state.offsetA;
15948 state.effB = state.offsetB;
15949 for (int q = 0; q < state.C_count; q++)
15950 state.effC[q] = state.offsetC[q];
15951 state.effCO = state.offsetCO;
15952 }
15953
15954 // Add offsets to A, B, C base pointers for stateless accesses.
15955 if (strategy.C.base.isStateless()) {
15956 for (int q = 0; q < state.C_count; q++) {
15957 auto Csrc = state.inputs.C[q];
15958 if ((q > 0) && strategy.C.base.isStateless()
15959 && state.inputs.base.isValid())
15960 state.effC[q] = state.inputs.C[q]
15961 = state.ra.alloc_sub<uint64_t>(
15962 getHint(HintType::LongTerm, strategy));
15963
15964 eadd(1, state.effC[q], Csrc, state.offsetC[q], strategy, state);
15965 if (strategy.persistent)
15966 state.offsetC[q] = invalid;
15967 else
15968 state.ra.safeRelease(state.offsetC[q]);
15969 }
15970 }
15971
15972 if (problem.usesCO() && strategy.CO.base.isStateless()) {
15973 eadd(1, state.effCO, state.inputs.CO, state.offsetCO, strategy, state);
15974 if (strategy.persistent)
15975 state.offsetCO = invalid;
15976 else
15977 state.ra.safeRelease(state.offsetCO);
15978 }
15979
15980 if (state.offsetAp.isValid()) {
15981 if (strategy.A.base.isStateless()) {
15982 state.effAp = state.ra.alloc_sub<uint64_t>(
15983 getHint(HintType::LongTerm, strategy));
15984 eadd(1, state.effAp, state.inputs.A, state.offsetAp, strategy,
15985 state);
15986 state.ra.safeRelease(state.offsetAp);
15987 } else
15988 state.effAp = state.offsetAp;
15989 }
15990
15991 if (state.offsetBp.isValid()) {
15992 if (strategy.B.base.isStateless()) {
15993 state.effBp = state.ra.alloc_sub<uint64_t>(
15994 getHint(HintType::LongTerm, strategy));
15995 eadd(1, state.effBp, state.inputs.B, state.offsetBp, strategy,
15996 state);
15997 state.ra.safeRelease(state.offsetBp);
15998 } else
15999 state.effBp = state.offsetBp;
16000 }
16001
16002 if (state.offsetCp.isValid()) {
16003 if (strategy.C.base.isStateless()) {
16004 state.effCp = state.ra.alloc_sub<uint64_t>(
16005 getHint(HintType::LongTerm, strategy));
16006 eadd(1, state.effCp, state.inputs.C[0], state.offsetCp, strategy,
16007 state);
16008 state.ra.safeRelease(state.offsetCp);
16009 } else
16010 state.effCp = state.offsetCp;
16011 }
16012
16013 if (strategy.A.base.isStateless()) {
16014 auto Asrc = state.inputs.A;
16015 if (strategy.B.base.isStateless() && (state.effA == state.effB))
16016 state.effA = state.inputs.A = state.ra.alloc_sub<uint64_t>(
16017 getHint(HintType::LongTerm, strategy));
16018
16019 eadd(1, state.effA, Asrc, state.offsetA, strategy, state);
16020 if (strategy.persistent)
16021 state.offsetA = invalid;
16022 else
16023 state.ra.safeRelease(state.offsetA);
16024 }
16025
16026 if (strategy.B.base.isStateless()) {
16027 eadd(1, state.effB, state.inputs.B, state.offsetB, strategy, state);
16028 if (strategy.persistent)
16029 state.offsetB = invalid;
16030 else
16031 state.ra.safeRelease(state.offsetB);
16032 }
16033
16034 if (strategy.prefetchA && state.effAp.isInvalid()) state.effAp = state.effA;
16035 if (strategy.prefetchB && state.effBp.isInvalid()) state.effBp = state.effB;
16036 if (strategy.prefetchC && state.effCp.isInvalid())
16037 state.effCp = state.effC[0];
16038}
16039
16040// Get (possibly multidimensional) batch IDs.
16041template <HW hw>
16042void gemm_kernel_generator_t<hw>::gemmGetBatchIDs(const GEMMProblem &problem,
16043 const GEMMStrategy &strategy, GEMMState &state) {
16044 switch (problem.batchDims) {
16045 case 0: break;
16046 case 1: state.batchID[0] = state.inputs.groupIDK; break;
16047 case 2: {
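// Decompose the linear k-group ID into two batch indices:
//   batchID[1] = groupIDK / batchSize1,
//   batchID[0] = groupIDK % batchSize1 (formed via multiply and subtract).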
16048 state.batchID[0] = state.ra.alloc_sub<uint32_t>();
16049 state.batchID[1] = state.ra.alloc_sub<uint32_t>();
16050 divDown(state.batchID[1], state.inputs.groupIDK,
16051 state.inputs.batchSize1, state.inputs.recipBatchSize1,
16052 state.flagAP, strategy, state);
16053 emul(1, state.batchID[0], state.batchID[1], state.inputs.batchSize1,
16054 strategy, state);
16055 add(1, state.batchID[0], -state.batchID[0], state.inputs.groupIDK);
16056 if (!strategy.persistent) {
16057 state.ra.safeRelease(state.inputs.batchSize1);
16058 state.ra.safeRelease(state.inputs.recipBatchSize1);
16059 }
16060 break;
16061 }
16062 default: stub();
16063 }
16064}
16065
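// Release the batch ID registers once the batch offsets have been applied,
// except when they are still needed later (e.g. for binary post-ops).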
16066template <HW hw>
16067void gemm_kernel_generator_t<hw>::gemmReleaseBatchIDs(
16068 const GEMMProblem &problem, const GEMMStrategy &strategy,
16069 GEMMState &state) {
16070 if (problem.batch != BatchMode::Strided) return;
16071 if (problem.batchDims == 1 && state.r0_info == r0) return;
16072 if (problem.hasBinaryPostOp()) return;
16073 for (int b = 0; b < problem.batchDims; b++)
16074 state.ra.safeRelease(state.batchID[b]);
16075}
16076
16077// Convert linear index to 2D index in a Hilbert curve-like fashion.
16078template <HW hw>
16079void gemm_kernel_generator_t<hw>::gemmHilbertlikeOrder(
16080 const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
16081 bool triangular = false;
16082 bool rectangular = !triangular && state.inputs.hilbertVD.isValid();
16083
16084 auto storage = state.ra.alloc();
16085 auto u = storage.ud(0);
16086 auto v = storage.ud(1);
16087 auto uh = storage.ud(2);
16088 auto vh = storage.ud(3);
16089 auto a = storage.ud(4);
16090 auto b = storage.ud(5);
16091 /* auto isNormal = storage.ud(6); */ // not used directly
16092 auto isReversed = storage.ud(7);
16093 int soff = storage.getBase() * GRF::bytes(hw);
16094
16095 auto storage2 = state.ra.alloc_range(2);
16096 auto nbu = storage2[0].ud(0);
16097 auto nbv = storage2[0].ud(1);
16098 auto np1 = storage2[0].ud(2);
16099 auto bv1 = storage2[0].ud(3);
16100 auto uv1 = storage2[0].ud(4);
16101 auto temp3 = storage2[0].ud(5);
16102 auto uo = storage2[0].ud(6);
16103 /* auto vo = storage2[0].ud(7); */ // not used directly
16104 auto temp = storage2[1].ud(0);
16105 auto temp2 = storage2[1].ud(1);
16106 auto qrem = storage2[1].ud(2);
16107 auto qqot = storage2[1].ud(4);
16108 auto q = storage2[1].ud(6);
16109 auto ud = storage2[1].ud(7);
16110
16111 auto bu = f1[0], bv = f1[1];
16112
16113 auto vd = state.inputs.hilbertVD;
16114 auto uvdRecip = state.inputs.hilbertUVDRecip;
16115 auto hilbertBail = state.inputs.hilbertBail;
16116
16117 auto any8 = (hw == HW::XeHPC) ? any : any8h;
16118 auto any16 = (hw == HW::XeHPC) ? any : any16h;
16119 bool avoidAny2 = (hw == HW::XeHPC);
16120
16121 auto jumpAny2 = [&](InstructionModifier mod, Label &l) {
16122 if (avoidAny2) {
16123 mod.setExecSize(16);
16124 goto12(mod | any16, l);
16125 } else
16126 jmpi(mod | any2h, l);
16127 };
16128
16129 Label lTriangularTop, lTriangularExit, lTriangularBypass;
16130 Label lRecursiveTop, lRecursiveEnd;
16131
16132 // NB: Sequence assumes group counts fit in 16 bits.
16133 status << "Hilbert-like ordering" << status_stream::endl;
16134 if (avoidAny2) mov(1, f0[0], 0);
16135 if (rectangular)
16136 mov(1, f0[1],
16137 vd.uw(1)); // High word of vd = 0xFFFF -> start splitting in x
16138 else if (triangular)
16139 cmp(1 | ne | f0[0], state.inputs.diagC, 0);
16140 mov(1, u, state.inputs.groupCountM);
16141 mov(1, v, state.inputs.groupCountN);
16142 mov(4, a0, Immediate::uv(4, 0, 12, 8, 0, 0, 0, 0));
16143 mov(1, f1.ud(0), 0); // bu = bv = false
16144 mov(1, np1, triangular ? 0xFFFFFFFF : 0);
16145 if (triangular)
16146 cmp(1 | ~f0[0] | ne | f0[0], state.inputs.m, state.inputs.n);
16147 else
16148 cmp(2 | le | f0[0], u(1), hilbertBail);
16149 mov(1, q, state.inputs.groupIDMN);
16150 add(4, a0[4](1), a0[0](1), 16);
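// With a0 set up this way, a single predicated movi(8) through indirect[a0]
// swaps the (u,v), (uh,vh), (a,b), and (normal,reversed) pairs in 'storage';
// the partitioning code below uses this to exchange the roles of the two axes.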
16151 if (!rectangular && !triangular)
16152 emad(1, uv1, -1, u.uw(), v.uw(), strategy, state);
16153 mov(8, a.uw()(1),
16154 Immediate::uv(0x00010000)); // a = b = 0, normal = 1, reversed = 0;
16155 if (soff >= 512) {
16156 add(4, a0, a0, soff);
16157 soff = 0;
16158 }
16159 if (triangular)
16160 jmpi(1 | f0[0], lTriangularBypass);
16161 else
16162 jumpAny2(1 | f0[0], lRecursiveEnd);
16163
16164 // Rectangular partitioning step. Break the dispatch into blocks of roughly the desired aspect ratio.
16165 if (rectangular) {
16166 auto uvd = uv1;
16167 movi(8 | f0[1], storage.ud(), indirect[a0].ud(soff)(1));
16168 mul(1, uvd, u, vd.uw());
16169 divDown(nbv, q, uvd, uvdRecip, f0[0], strategy, state);
16170 and_(1 | ne | bv, bv1, nbv, 1);
16171 mul(1, temp, uvd, nbv.uw());
16172 mul(1, b, vd.uw(), nbv.uw());
16173 add(1, q, q, -temp); // no DWxW with source modifiers
16174 add(1, v, v, -b);
16175 avg(1, ud, u, -bv1);
16176 min_(1, v, v, vd.uw());
16177 avg(1, uh, u, 0);
16178 mul(1, temp, v.uw(), ud.uw());
16179 cmp(1 | ge | bu, nbu, q, temp);
16180 add(1 | bu, q, q, -temp);
16181 cmp(1 | ne | bu, nbu, nbu.d(), -bv1.d()); // {bu,nbu} ^= bv1
16182 sel(1 | bu, a, uh, 0);
16183 avg(1, u, u, nbu.d());
16184 movi(8 | ~bu | any8, storage.ud(), indirect[a0].ud(soff)(1));
16185 cmp(2 | le | f0[0], u(1), hilbertBail);
16186 sel(1 | ~bu, np1, -bv1, 0);
16187 emad(1, uv1, -1, u.uw(), v.uw(), strategy, state);
16188 mov(1, f1.ud(0), 0); // bu = bv = false
16189 jumpAny2(1 | f0[0], lRecursiveEnd);
16190 }
16191
16192 // Recursive partitioning. Each step breaks the current block
16193 // into 2x2 subblocks and follows the block we are currently in.
16194 // Exit when one dimension is less than hilbertBail.
16195 mark(lRecursiveTop);
16196 {
16197 avg(2, uh(1), u(1), 0);
16198 add(1 | bv, q, uv1, -q);
16199
16200 mul(1, temp, u.uw(), vh.uw());
16201 cmp(1 | ge | bv, nbv, q, temp);
16202 mov(2, uo(1), u(1));
16203 add(1 | bv, q, uv1, -q);
16204 avg(1, v, v, nbv.d());
16205 mul(1, temp, uh.uw(), v.uw());
16206 cmp(1 | ge | bu, nbu, q, temp);
16207 add(1 | bu, q, q, -temp);
16208 avg(1, u, u, nbu.d());
16209
16210 xor_(2, temp(1), nbu(1), np1);
16211 avg(2, uo(1), uo(1), np1.d());
16212 xor_(1 | bv, np1, np1, ~nbu);
16213 and_(2, uo(1), uo(1), temp(1));
16214 emad(1, uv1, -1, u.uw(), v.uw(), strategy, state);
16215 add(2, a(1), a(1), uo(1));
16216
16217 cmp(2 | le | f0[0], u(1), hilbertBail);
16218 movi(8 | ~bu | any8, storage.ud(), indirect[a0].ud(soff)(1));
16219
16220 if (avoidAny2)
16221 goto12(16 | ~f0[0] | any16, lRecursiveEnd, lRecursiveTop, true);
16222 else
16223 jmpi(1 | ~f0[0] | any2h, lRecursiveTop);
16224 }
16225 mark(lRecursiveEnd);
16226 if (avoidAny2) join(16);
16227
16228 cmp(8 | ne | f0[0], isReversed, 0);
16229 movi(8 | f0[0], storage.ud(), indirect[a0].ud(soff)(1));
16230
16231 // Regular 2D traversal over final block.
16232 bool nmk = (strategy.loopOrder[0] == LoopN);
16233 auto divisor = nmk ? v : u;
16234
16235 if (hw < HW::Gen12LP) {
16236 irem(1, qrem, q, divisor);
16237 iqot(1, qqot, q, divisor);
16238 } else {
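// Integer division without an integer divider: qqot is computed as
// int(q * (1/divisor) + bias), where the reciprocal is nudged up by 2 ulps
// and bias = -1/2 + 2^-18 so that the result is the exact quotient for the
// ranges involved; the remainder then follows as q - qqot * divisor
// (e.g. q = 100, divisor = 7 -> qqot = 14, qrem = 2).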
16239 auto bias = temp.f();
16240 auto divisorFP = temp2.f();
16241 auto qFP = temp3.f();
16242 mov(1, divisorFP, divisor);
16243 mov(1, qFP, q);
16244 mov(1, bias, -0.499996185302734375f); // -1/2 + 2^(-18)
16245 einv(1, divisorFP, divisorFP, strategy, state);
16246 add(1, divisorFP.ud(), divisorFP.ud(), 2);
16247 mad(1, qqot.f(), bias, qFP, divisorFP);
16248 mov(1, qqot, qqot.f());
16249 mad(1, qrem, q, -qqot.uw(), divisor.uw());
16250 }
16251
16252 // Reassign m/n group IDs.
16253 if (!strategy.persistent) {
16254 state.inputs.groupIDM = state.inputs.groupCountM;
16255 state.inputs.groupIDN = state.inputs.groupCountN;
16256 state.inputs.groupCountM = invalid;
16257 state.inputs.groupCountN = invalid;
16258 } else {
16259 state.inputs.groupIDM = state.ra.alloc_sub<uint32_t>();
16260 state.inputs.groupIDN = state.ra.alloc_sub<uint32_t>();
16261 }
16262
16263 add(1, state.inputs.groupIDM, a, nmk ? qqot : qrem);
16264 add(1, state.inputs.groupIDN, b, nmk ? qrem : qqot);
16265
16266 state.ra.safeRelease(storage);
16267 state.ra.safeRelease(storage2);
16268 if (!strategy.persistent) {
16269 state.ra.safeRelease(state.inputs.hilbertVD);
16270 state.ra.safeRelease(state.inputs.hilbertUVDRecip);
16271 state.ra.safeRelease(state.inputs.hilbertBail);
16272 }
16273}
16274
16275// Convert linear index to 2D index in a boustrophedon pattern.
16276template <HW hw>
16277void gemm_kernel_generator_t<hw>::gemmBoustrophedonOrder(
16278 const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
16279 auto storage = state.ra.alloc_range(4);
16280 auto u = storage[0].ud(0);
16281 auto s = storage[0].ud(1);
16282 auto v = storage[0].ud(2);
16283 auto s1 = storage[0].ud(3);
16284 auto i = storage[0].ud(4);
16285 auto j = storage[0].ud(5);
16286 auto i0 = storage[0].ud(6);
16287 auto two = storage[0].f(7);
16288 auto numFP = storage[1].f(0);
16289 auto islice = storage[1].ud(1);
16290 auto qot = storage[1].ud(2);
16291 auto rem = storage[1].ud(4);
16292 auto ithresh = storage[1].ud(6);
16293 auto temp0 = storage[2].ud(0);
16294 auto temp1 = storage[2].ud(2);
16295 auto temp2 = storage[2].ud(4);
16296 auto bias = storage[3].f(0);
16297 auto denomFP = storage[3].f(2);
16298 auto q = storage[3].ud(4);
16299 auto qFP = storage[3].f(6);
16300
16301 // Slice width/height in WGs. Sign: + means slice in m dimension,
16302 // - means slice in n dimension.
16303 auto s0 = state.inputs.bslice;
16304 // Slice size adjustment threshold: + means increase slice size by 1
16305 // starting with this row/column, - means decrease it by 1.
16306 auto thresh = state.inputs.bthresh;
16307
16308 auto &groupCountM = state.inputs.groupCountM;
16309 auto &groupCountN = state.inputs.groupCountN;
16310 auto idMN = state.inputs.groupIDMN;
16311
16312 Label lBegin, lEnd, lDone, lBeginTri2, lEndTri2, lTricalc1, lTricalc2,
16313 lTricalcOut;
16314
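// Helper: compute qot = num / denom and rem = num % denom. Pre-Gen12 uses
// irem/iqot directly; Gen12+ uses a float reciprocal instead, with the
// 'large' variant forcing round-toward-negative-infinity via cr0 and then
// correcting the quotient/remainder if the estimate is one too high.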
16315 auto divqot = [&](const Subregister &num, const Subregister &denom,
16316 bool large) {
16317 if (hw < HW::Gen12LP) {
16318 irem(1, rem, num, denom);
16319 iqot(1, qot, num, denom);
16320 } else if (large) {
16321 // denom <= 0x400000, qot < 2^16
16322 or_(1, cr0[0], cr0[0], 0x20); // round toward -inf
16323 mov(1, denomFP, denom);
16324 mov(1, numFP, -num);
16325 einv(1, denomFP, denomFP, strategy, state);
16326 add(1, denomFP.ud(), denomFP.ud(), 2);
16327 mul(1, qot.f(), -numFP, denomFP);
16328 mov(1, qot, qot.f());
16329 mad(1 | lt | f1[1], rem.d(), num, denom, -qot.uw());
16330 add(1 | f1[1], rem, rem, denom);
16331 add(1 | f1[1], qot, qot, -1);
16332 and_(1, cr0[0], cr0[0], ~0x30);
16333 } else {
16334 // denom <= 0x40, qot < 2^16
16335 mov(1, denomFP, denom);
16336 mov(1, numFP, num);
16337 mov(1, bias, -0.499996185302734375f); // -1/2 + 2^(-18)
16338 einv(1, denomFP, denomFP, strategy, state);
16339 add(1, denomFP.ud(), denomFP.ud(), 2);
16340 mad(1, qot.f(), bias, numFP, denomFP);
16341 mov(1, qot, qot.f());
16342 mad(1, rem, num, -qot.uw(), denom.uw());
16343 }
16344 };
16345
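// Helper: conditional select. Uses csel where available, falling back to
// cmp + sel on Gen9 or when the destination alignment rules out csel.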
16346 auto ecsel
16347 = [&](const InstructionModifier &mod,
16348 const InstructionModifier &cmod, const FlagRegister &flag,
16349 const RegData &dst, const RegData &src0,
16350 const RegData &src1, const RegData &src2) {
16351 if (hw == HW::Gen9 || dst.getByteOffset() & 7) {
16352 cmp(mod | cmod | flag, src2, 0);
16353 sel(mod | flag, dst, src0, src1);
16354 } else
16355 csel(mod | cmod | flag, dst, src0, src1, src2);
16356 };
16357
16358 // NB: Sequence assumes group counts fit in 16 bits.
16359 status << "Boustrophedon ordering" << status_stream::endl;
16360
16361 mul(1, ithresh, abs(thresh.w()), abs(s0.w()));
16362 cmp(1 | ge | f1[0], thresh, 0);
16363 ecsel(1, lt, f0[0], v, groupCountM, groupCountN, s0);
16364 ecsel(1, ge, f0[0], u, groupCountM, groupCountN, s0);
16365
16366 emad(1, temp0, idMN, -v.uw(), ithresh.uw(), strategy, state);
16367 cmp(1 | ge | f0[0], temp2.d(), temp0.d(), 0);
16368 ecsel(1, ge, f0[1], q, temp0, idMN, temp0.d());
16369
16370 if (hw == HW::XeHPC) {
16371 add(1, s1, abs(s0), 1);
16372 add(1 | ~f0[0], s1, abs(s0), temp2.d());
16373 add(1 | ~f1[0], s1, abs(s0), temp2.d());
16374 } else {
16375 add(1, s1, abs(s0), temp2.d());
16376 add(1 | f0[0] | allv, s1, abs(s0), 1);
16377 }
16378
16379 mul(1, temp1, s1.uw(), v.uw());
16380
16381 divqot(q, temp1, true);
16382
16383 mul(1, i0, qot.uw(), s1.uw());
16384 mov(1, islice, qot);
16385 add(1 | f0[0], i0, i0, ithresh);
16386 mov(1, q, rem);
16387 add(1 | sat, temp0, u, -i0);
16388 min_(1, s, s1, temp0);
16389 add(1 | f0[0], islice, islice, abs(thresh));
16390
16391 mul(1, temp2, s.uw(), s.uw());
16392 emad(1, temp1, temp1, -s.uw(), s.uw(), strategy, state);
16393
16394 cmp(1 | gt | f0[0], i0, 0); // not first row?
16395 cmp(1 | lt | f0[1], s1, temp0); // not last row?
16396
16397 if (hw == HW::XeHPC) {
16398 cmp(1 | f0[0] | lt | f0[0], q, temp2); // beginning of row?
16399 cmp(1 | f0[1] | ge | f0[1], q, temp1); // end of row?
16400 } else {
16401 cmp(1 | lt | f1[0], q, temp2); // beginning of row?
16402 cmp(1 | ge | f1[1], q, temp1); // end of row?
16403 }
16404
16405 mov(1, two, 2.0f);
16406 mov(1, bias, 1.25f);
16407
16408 if (hw == HW::XeHPC) {
16409 jmpi(1 | f0[0], lBegin);
16410 jmpi(1 | f0[1], lEnd);
16411 } else {
16412 jmpi(1 | f0[0] | allv, lBegin);
16413 jmpi(1 | f0[1] | allv, lEnd);
16414 }
16415
16416 {
16417 divqot(q, s, false);
16418
16419 add(1, i, i0, rem);
16420 mov(1, j, qot);
16421 }
16422
16423 jmpi(1, lDone);
16424
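// Partial (triangular) regions at the start and end of a slice: recover
// (i, j) within the triangle from the linear offset q by inverting the
// triangular-number formula with a square root.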
16425 mark(lBegin);
16426 {
16427 avg(1, temp0, temp2, -s); // s(s-1)/2
16428 mov(1, f1.ud(0), 0xFFFF);
16429 cmp(1 | lt | f0[0], q, temp0);
16430 jmpi(1 | ~f0[0], lBeginTri2);
16431
16432 eadd3(1, q, temp0, -q, -1);
16433 jmpi(1, lTricalc1);
16434
16435 mark(lBeginTri2);
16436 add(1, q, q, -temp0);
16437 jmpi(1, lTricalc2);
16438 }
16439
16440 mark(lEnd);
16441 {
16442 add(1, q, q, -temp1);
16443 avg(1, temp0, temp2, s); // s(s+1)/2
16444 mov(1, f1.ud(0), 0);
16445 cmp(1 | lt | f0[0], q, temp0);
16446 jmpi(1 | ~f0[0], lEndTri2);
16447
16448 eadd3(1, q, temp0, -q, -1);
16449 mark(lTricalc2);
16450 {
16451 mov(1, qFP, q);
16452 mad(1, qFP, bias, qFP, two);
16453 esqt(1, qFP, qFP, strategy, state);
16454 if (hw == HW::Gen9) rnde(1, qFP, qFP);
16455 mov(1, j, qFP);
16456 mul(1, temp0, j.uw(), j.uw());
16457 avg(1, temp0, temp0, -j);
16458 add(1, j, j, -1);
16459 add(1, i, q, -temp0);
16460 }
16461 jmpi(1, lTricalcOut);
16462
16463 mark(lEndTri2);
16464 add(1, q, q, -temp0);
16465 mark(lTricalc1);
16466 {
16467 mov(1, qFP, q);
16468 mad(1, qFP, bias, qFP, two);
16469 esqt(1, qFP, qFP, strategy, state);
16470 if (hw == HW::Gen9) rnde(1, qFP, qFP);
16471 mov(1, i, qFP);
16472 mul(1, temp0, i.uw(), i.uw());
16473 avg(1, temp0, temp0, -i);
16474 add(1, j, q, -temp0);
16475 }
16476
16477 mark(lTricalcOut);
16478 eadd3(1 | f1[0], i, s, -i, -1);
16479 eadd3(1 | ~f1[0], j, v, -j, -1);
16480 add(1, i, i, i0);
16481 }
16482
16483 // Reassign m/n group IDs.
16484 mark(lDone);
16485
16486 if (!strategy.persistent) {
16487 state.inputs.groupIDM = state.inputs.groupCountM;
16488 state.inputs.groupIDN = state.inputs.groupCountN;
16489 state.inputs.groupCountM = invalid;
16490 state.inputs.groupCountN = invalid;
16491 } else {
16492 state.inputs.groupIDM = state.ra.alloc_sub<uint32_t>();
16493 state.inputs.groupIDN = state.ra.alloc_sub<uint32_t>();
16494 }
16495
16496 and_(1 | ne | f1[1], null.ud(), islice, 1);
16497 eadd3(1 | f1[1], j, v, -j, -1);
16498 ecsel(1, ge, f0[0], state.inputs.groupIDM, i, j, s0);
16499 ecsel(1, lt, f0[0], state.inputs.groupIDN, i, j, s0);
16500
16501 state.ra.safeRelease(storage);
16502 if (!strategy.persistent) {
16503 state.ra.safeRelease(state.inputs.bslice);
16504 state.ra.safeRelease(state.inputs.bthresh);
16505 }
16506}
16507
16508// Reverse m/n loops if requested.
16509template <HW hw>
16510void gemm_kernel_generator_t<hw>::gemmReverseLoops(const GEMMProblem &problem,
16511 const GEMMStrategy &strategy, GEMMState &state) {
16512 for (LoopType l : {LoopM, LoopN})
16513 if (strategy.reverse[l]) {
16514 bool fusedL = strategy.fused && (l == strategy.fusedLoop);
16515 auto q = (l == LoopM) ? state.inputs.m : state.inputs.n;
16516 auto q0 = (l == LoopM) ? state.i0 : state.j0;
16517 auto q0Align = state.ra.alloc_sub<uint32_t>();
16518 auto temp = state.ra.alloc_sub<uint32_t>();
16519
16520 add(1, q0Align, q, -1);
16521 if (strategy.fixedWG(problem)) {
16522 mod(temp, q0, strategy.wg[l] * strategy.unroll[l], strategy,
16523 state);
16524 alignDown(q0Align, q0Align, strategy.wg[l] * strategy.unroll[l],
16525 strategy, state);
16526 shl(1, temp, temp, 1);
16527 eadd3(1 | ge | f0[0], q0Align.d(), q0Align, -q0, temp);
16528 mov(1 | f0[0], q0, q0Align);
16529 } else if (fusedL) {
16530 shl(1, temp, state.fusedID, 1);
16531 alignDown(q0Align, q0Align, 2 * strategy.unroll[l], strategy,
16532 state);
16533 eadd3(1 | ge | f0[0], q0Align.d(), q0Align, -q0, temp);
16534 mov(1 | f0[0], q0, q0Align);
16535 } else {
16536 alignDown(
16537 q0Align, q0Align, strategy.unroll[l], strategy, state);
16538 cmp(1 | le | f0[0], q0, q0Align);
16539 add(1 | f0[0], q0, q0Align, -q0);
16540 }
16541 state.ra.safeRelease(temp);
16542 state.ra.safeRelease(q0Align);
16543 }
16544}
16545
16546// Reorder local IDs as needed.
16547template <HW hw>
16548void gemm_kernel_generator_t<hw>::gemmReorderLocalIDs(
16549 const GEMMProblem &problem, const GEMMStrategy &strategy,
16550 GEMMState &state) {
16551 if (strategy.fixedSystolic)
16552 sysgemmReorderLocalIDs(problem, strategy, state);
16553
16554 if (strategy.skewLocalIDs) {
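// Skew the inner local ID by the (scaled) outer local ID so threads in
// neighboring rows/columns of the workgroup are spread across EUs, then
// wrap it back into [0, wgI) (wgI must be a power of two).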
16555 if (!strategy.fixedWG(problem)) stub();
16556 auto wgI = strategy.wg[strategy.loopOrder[0]];
16557 auto adjustEvery = div_up(eusPerSubslice(hw), wgI);
16558 bool innerM = strategy.loopOrder[0] == LoopM;
16559 auto lidI = innerM ? state.lidM : state.lidN;
16560 auto lidO = innerM ? state.lidN : state.lidM;
16561 auto temp = state.ra.alloc_sub<uint16_t>();
16562 auto slidO = lidO;
16563
16564 if (adjustEvery > 1) {
16565 shr(1, temp, lidO, log2(adjustEvery));
16566 slidO = temp;
16567 }
16568
16569 if (strategy.fused)
16570 emad(1, lidI, lidI, slidO, 2, strategy, state);
16571 else
16572 add(1, lidI, lidI, slidO);
16573
16574 if (!is_zero_or_pow2(wgI)) stub();
16575
16576 and_(1, lidI, lidI, wgI - 1);
16577
16578 state.ra.safeRelease(temp);
16579 }
16580}
16581
16582// Convert leading dimension and offset inputs to bytes.
16583template <ngen::HW hw>
16584void gemm_kernel_generator_t<hw>::gemmScaleInputs(const GEMMProblem &problem,
16585 const GEMMStrategy &strategy, GEMMState &state) {
16586 auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext,
16587 Tc_ext = problem.Tc_ext, Tco = problem.Tco;
16588
16589 emulConstant(1, state.inputs.lda, state.inputs.lda, Ta_ext.size(), strategy,
16590 state);
16591 if (state.inputs.ldb != state.inputs.lda)
16592 emulConstant(1, state.inputs.ldb, state.inputs.ldb, Tb_ext.size(),
16593 strategy, state);
16594 for (int q = 0; q < state.C_count; q++)
16595 emulConstant(1, state.inputs.ldc[q], state.inputs.ldc[q], Tc_ext.size(),
16596 strategy, state);
16597 if (state.inputs.ldco.isValid())
16598 emulConstant(1, state.inputs.ldco, state.inputs.ldco, Tco.size(),
16599 strategy, state);
16600
16601 {
16602 emulConstant(1, state.inputs.offsetA, state.inputs.offsetA,
16603 Ta_ext.size(), strategy, state);
16604 emulConstant(1, state.inputs.offsetB, state.inputs.offsetB,
16605 Tb_ext.size(), strategy, state);
16606 for (int q = 0; q < state.C_count; q++)
16607 emulConstant(1, state.inputs.offsetC[q], state.inputs.offsetC[q],
16608 Tc_ext.size(), strategy, state);
16609 if (problem.usesCO())
16610 emulConstant(1, state.inputs.offsetCO, state.inputs.offsetCO,
16611 Tco.size(), strategy, state);
16612 }
16613
16614 if (problem.batch == BatchMode::Strided)
16615 for (int b = 0; b < problem.batchDims; b++) {
16616 emulConstant(1, state.inputs.strideA[b], state.inputs.strideA[b],
16617 Ta_ext.size(), strategy, state);
16618 emulConstant(1, state.inputs.strideB[b], state.inputs.strideB[b],
16619 Tb_ext.size(), strategy, state);
16620 emulConstant(1, state.inputs.strideC[b], state.inputs.strideC[b],
16621 Tc_ext.size(), strategy, state);
16622 }
16623}
16624
16625// Cache multiples of lda/ldb for later address calculations.
16626template <HW hw>
16627void gemm_kernel_generator_t<hw>::gemmCacheLDABMultiples(
16628 const GEMMProblem &problem, const GEMMStrategy &strategy,
16629 GEMMState &state) {
16630 int na = 0, nb = 0;
16631
16632 if (!strategy.A.address2D) switch (problem.A.layout) {
16633 case MatrixLayout::N:
16634 na = std::max(strategy.ka_load, strategy.ka_prefetch);
16635 break;
16636 case MatrixLayout::T:
16637 na = strategy.unroll[LoopM];
16638 if (isTransposing(strategy.A.accessType))
16639 na = std::min(na, maxScatteredSIMD(hw, strategy.A));
16640 break;
16641 default: break;
16642 }
16643
16644 if (!strategy.B.address2D) switch (problem.B.layout) {
16645 case MatrixLayout::T:
16646 nb = std::max(strategy.kb_load, strategy.kb_prefetch);
16647 break;
16648 case MatrixLayout::N:
16649 nb = strategy.unroll[LoopN];
16650 if (isTransposing(strategy.B.accessType))
16651 nb = std::min(nb, maxScatteredSIMD(hw, strategy.B));
16652 break;
16653 default: break;
16654 }
16655
16656 if (na <= 2) na = 0;
16657 if (nb <= 2) nb = 0;
16658
16659 if (na || nb) extendIndexVec(std::max(na, nb), state);
16660
16661 if (na) {
16662 bool a64 = (strategy.A.base.getModel() == ModelA64);
16663 state.ldaMultiples
16664 = createLDMultiples(a64, na, state.lda, strategy, state);
16665 }
16666
16667 if (nb) {
16668 bool a64 = (strategy.B.base.getModel() == ModelA64);
16669 state.ldbMultiples
16670 = createLDMultiples(a64, nb, state.ldb, strategy, state);
16671 }
16672}
16673
16674// Cache multiples of ldc for later address calculations.
16675template <HW hw>
16676void gemm_kernel_generator_t<hw>::gemmCacheLDCMultiples(
16677 const GEMMProblem &problem, const GEMMStrategy &strategy,
16678 GEMMState &state, bool prefetch) {
16679 if ((prefetch ? strategy.C_prefetch : strategy.C).address2D) return;
16680
16681 int nc = 0;
16682 switch (problem.C.layout) {
16683 case MatrixLayout::N: nc = strategy.unroll[LoopN]; break;
16684 case MatrixLayout::T: nc = strategy.unroll[LoopM]; break;
16685 default: break;
16686 }
16687
16688 if (nc <= 2) return;
16689
16690 bool a64 = (strategy.C.base.getModel() == ModelA64);
16691 int C_count = prefetch ? 1 : state.C_count;
16692 for (int q = 0; q < C_count; q++)
16693 state.ldcMultiples[q] = createLDMultiples(
16694 a64, nc, state.inputs.ldc[q], strategy, state);
16695}
16696
16697// GEMM kernel generation interface.
16698template <HW hw>
16699void gemm_kernel_generator_t<hw>::gemm(GEMMProblem problem,
16700 GEMMStrategy strategy, const InterfaceHandler &interface_) {
16701 GEMMState state(hw);
16702 interface = interface_;
16703 gemm(problem, strategy, state);
16704}
16705
16706template <HW hw>
16707void gemm_kernel_generator_t<hw>::gemm(
16708 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
16709 bool inFusedGEMM = state.fusedGEMM.active;
16710 bool anyKParallel = strategy.kParallelLocal || strategy.kParallel;
16711
16712 Label labelKernelDone, labelReentry;
16713
16714 // By default, don't use dispatch mask.
16715 setDefaultNoMask();
16716 setDefaultAutoSWSB();
16717
16718 // Set up.
16719 gemmInitState(problem, strategy, state);
16720
16721 // Transfer surface indices to strategy AddressBases.
16722 if (!strategy.A.base.isStateless())
16723 strategy.A.base.setIndex(state.inputs.surfaceA);
16724 if (!strategy.B.base.isStateless())
16725 strategy.B.base.setIndex(state.inputs.surfaceB);
16726 if (!strategy.C.base.isStateless()) {
16727 strategy.C.base.setIndex(state.inputs.surfaceC[0]);
16728 if (state.C_count > 1) stub();
16729 }
16730 if (problem.usesCO() && !strategy.CO.base.isStateless())
16731 strategy.CO.base.setIndex(state.inputs.surfaceCO);
16732
16733 for (size_t i = 0; i < strategy.binary.size(); i++)
16734 if (!strategy.binary[i].base.isStateless())
16735 strategy.binary[i].base.setIndex(state.inputs.binarySurfaces[i]);
16736
16737 // Prologue.
16738 if (!inFusedGEMM) prologue(strategy);
16739
16740 // Grab fused ID if needed, and multiply by unroll.
16741 getFusedID(strategy.unroll[strategy.fusedLoop], problem, strategy, state);
16742
16743 if (!inFusedGEMM) {
16744 // Divide out subgroup size from local size 0 and local ID 0, and reorder threads for fusing if needed.
16745 removeSG(problem, strategy, state);
16746 reorderFusedEUs(problem, strategy, state);
16747
16748 } /* !inFusedGEMM */
16749
16750 // Check for copy or compute kernel.
16751 if (strategy.splitCopy) {
16752 state.isCompute = state.ra.alloc_sub<uint32_t>(
16753 getHint(HintType::LongTerm, strategy));
16754 auto localIDY = (strategy.loopOrder[1] == LoopN)
16755 ? state.inputs.localIDN
16756 : state.inputs.localIDM;
16757 auto wgY = strategy.wg[strategy.loopOrder[1]];
16758 cmp(1 | ge | f1[1], state.isCompute, localIDY, wgY);
16759 if (is_zero_or_pow2(wgY))
16760 and_(1, localIDY, localIDY, wgY - 1);
16761 else
16762 add(1 | f1[1], localIDY, localIDY, -wgY);
16763 }
16764
16765 // Scale LDs/offsets.
16766 gemmScaleInputs(problem, strategy, state);
16767
16768 // Local ID handling and saving.
16769 gemmReorderLocalIDs(problem, strategy, state);
16770
16771 if (strategy.needsMNLocalIDs()) saveMNLocalIDs(strategy, state);
16772
16773 if (strategy.needsKLocalIDs()) saveKLocalIDSize(strategy, state);
16774
16775 // Save full k size if needed.
16776 bool anyAB2D = strategy.A.address2D || strategy.B.address2D
16777 || (strategy.prefetchA && strategy.A_prefetch.address2D)
16778 || (strategy.prefetchB && strategy.B_prefetch.address2D);
16779 if (anyKParallel) {
16780 if (strategy.persistent || anyAB2D) {
16781 state.fullK = state.ra.alloc_sub<uint32_t>(
16782 getHint(HintType::LongTerm, strategy));
16783 mov(1, state.fullK, state.inputs.k);
16784 }
16785 } else
16786 state.fullK = state.inputs.k;
16787
16788 // Load ao/bo from memory if needed.
16789 if (problem.abOffset != ABOffset::None && state.inputs.abo.isInvalid()) {
16790 state.inputs.abo = state.ra.alloc_sub<uint32_t>(
16791 getHint(HintType::LongTerm, strategy));
16792 state.inputs.ao = state.inputs.abo.w(0);
16793 state.inputs.bo = state.inputs.abo.w(1);
16794
16795 auto loadABO = [&](const ngen::Subregister &xo,
16796 ngen::Subregister &xoPtr) {
16797 if (xoPtr.isInvalid())
16798 mov(1, xo, 0);
16799 else {
16800 auto header = state.ra.alloc_range(2);
16801 auto data = state.ra.alloc();
16802 emov<uint64_t>(1, header, xoPtr, strategy, state);
16803 (hw >= HW::XeHPG)
16804 ? load(1, data, D32 | CacheSettingsLSC::L1C_L3C, A64,
16805 header)
16806 : load(1, data, scattered_dword(1), A64, header);
16807 mov(1, xo, -data.w(0));
16808 state.ra.safeRelease(xoPtr);
16809 state.ra.safeRelease(header);
16810 state.ra.safeRelease(data);
16811 }
16812 };
16813
16814 loadABO(state.inputs.ao, state.inputs.aoPtr);
16815 loadABO(state.inputs.bo, state.inputs.boPtr);
16816 }
16817
16818 // Persistent thread preparation and re-entry.
16819 if (strategy.persistent) {
16820 if (!strategy.linearOrder()) stub();
16821 if (problem.batch != BatchMode::None)
16822 stub(); // need to wrangle groupIDK also
16823
16824 auto newGroupIDMN = state.ra.alloc_sub<uint32_t>(
16825 getHint(HintType::LongTerm, strategy));
16826 mov(1, newGroupIDMN, state.inputs.groupIDMN);
16827 state.inputs.groupIDMN = newGroupIDMN;
16828
16829 gemmFoldOffsets(problem, strategy, state);
16830
16831 mark(labelReentry);
16832 }
16833
16834 // Group ID remapping.
16835 if (strategy.hilbertOrder)
16836 gemmHilbertlikeOrder(problem, strategy, state);
16837 else if (strategy.boustrophedon)
16838 gemmBoustrophedonOrder(problem, strategy, state);
16839
16840 // Batch handling.
16841 gemmGetBatchIDs(problem, strategy, state);
16842
16843 // Compute offset for A, B, C for non-strided and strided batch.
16844 gemmOffsetBatchABC(problem, strategy, state);
16845
16846 // 32-bit add check. TODO: move out of persistent loop for non-batch.
16847 gemmCheck32(problem, strategy, state);
16848
16849 // Calculate i0, j0, h0 -- the initial i/j/h indices for this thread.
16850 bool needH0 = anyKParallel;
16851
16852 state.i0 = state.ra.alloc_sub<uint32_t>(
16853 getHint(HintType::TempComp0, strategy));
16854 state.j0 = state.ra.alloc_sub<uint32_t>(
16855 getHint(HintType::TempComp1, strategy));
16856 if (needH0)
16857 state.h0 = state.ra.alloc_sub<uint32_t>(
16858 getHint(HintType::TempComp0, strategy));
16859
16860 bool wgCheck = wgRemCheck(problem, strategy);
16861 bool gemmtBarriers = problem.gemmt() && strategy.needsBarrier();
16862
16863 Subregister idM, idN, idK;
16864 Subregister wgI0, wgJ0;
16865
16866 idM = state.ra.alloc_sub<uint32_t>(getHint(HintType::TempComp1, strategy));
16867 idN = state.ra.alloc_sub<uint32_t>(getHint(HintType::TempComp0, strategy));
16868 if (strategy.kParallel)
16869 idK = state.ra.alloc_sub<uint32_t>(
16870 getHint(HintType::TempComp0, strategy));
16871
16872 if (strategy.fixedWG(problem)) {
16873 mulConstant(1, idM, state.inputs.groupIDM, strategy.wg[LoopM]);
16874 mulConstant(1, idN, state.inputs.groupIDN, strategy.wg[LoopN]);
16875 if (strategy.kParallel)
16876 mulConstant(1, idK, state.inputs.groupIDK, strategy.wg[LoopK]);
16877 } else {
16878 mul(1, idM, state.inputs.groupIDM, state.inputs.localSizeM.uw());
16879 mul(1, idN, state.inputs.groupIDN, state.inputs.localSizeN.uw());
16880 if (strategy.kParallel)
16881 mul(1, idK, state.inputs.groupIDK, state.inputs.localSizeK.uw());
16882 }
16883
16884 if (wgCheck || gemmtBarriers) {
16885 wgI0 = state.ra.alloc_sub<uint32_t>(
16886 getHint(HintType::TempComp0, strategy));
16887 wgJ0 = state.ra.alloc_sub<uint32_t>(
16888 getHint(HintType::TempComp1, strategy));
16889 mulConstant(1, wgI0, idM, strategy.unroll[LoopM]);
16890 mulConstant(1, wgJ0, idN, strategy.unroll[LoopN]);
16891 }
16892
16893 add(1, idM, idM, state.lidM);
16894 add(1, idN, idN, state.lidN);
16895 if (strategy.kParallel) add(1, idK, idK, state.lidK);
16896
16897 mulConstant(1, state.i0, idM, strategy.unroll[LoopM]);
16898 mulConstant(1, state.j0, idN, strategy.unroll[LoopN]);
16899
16900 if (strategy.kParallel)
16901 emul(1, state.h0, idK, state.inputs.k0, strategy, state);
16902 else if (strategy.kParallelLocal)
16903 mul(1, state.h0, state.inputs.k0, state.lidK);
16904
16905 gemmReverseLoops(problem, strategy, state);
16906
16907 state.ra.safeRelease(idM);
16908 state.ra.safeRelease(idN);
16909 state.ra.safeRelease(idK);
16910 if (!strategy.persistent) {
16911 state.ra.safeRelease(state.inputs.localSizeM);
16912 state.ra.safeRelease(state.inputs.localSizeN);
16913 }
16914 if (anyKParallel) {
16915 state.ra.safeRelease(state.inputs.localIDK);
16916 if (!strategy.persistent) state.ra.safeRelease(state.inputs.localSizeK);
16917 }
16918 if (strategy.linearOrder() || strategy.persistent) {
16919 state.ra.safeRelease(state.inputs.groupIDM);
16920 state.ra.safeRelease(state.inputs.groupIDN);
16921 }
16922
16923 moveR0(strategy, state);
16924
16925 // Adjust k range as needed.
16926 if (anyKParallel) {
16927 add(1, state.inputs.k,
16928 strategy.persistent ? state.fullK : state.inputs.k, -state.h0);
16929 min_(1, state.inputs.k, state.inputs.k, state.inputs.k0);
16930
16931 bool keepK0 = false;
16932 keepK0 |= strategy.kParallelLocal
16933 && (strategy.barrierFreq > 0 || strategy.slmBuffers > 0);
16934 keepK0 |= strategy.persistent;
16935
16936 if (!keepK0) state.ra.safeRelease(state.inputs.k0);
16937 }
16938
16939 state.ra.safeRelease(state.inputs.localIDM);
16940 state.ra.safeRelease(state.inputs.localIDN);
16941 if (!strategy.needsMNLocalIDs()) state.lidM = state.lidN = invalid;
16942
16943 // Compute workgroup remainders if needed.
16944 if (wgCheck) {
16945 state.remaindersWG[LoopM] = state.ra.alloc_sub<uint32_t>(
16946 getHint(HintType::TempComp1, strategy));
16947 state.remaindersWG[LoopN] = state.ra.alloc_sub<uint32_t>(
16948 getHint(HintType::TempComp0, strategy));
16949 add(1 | sat, state.remaindersWG[LoopM], -wgI0, state.inputs.m);
16950 add(1 | sat, state.remaindersWG[LoopN], -wgJ0, state.inputs.n);
16951 }
16952 state.ra.safeRelease(wgI0);
16953 state.ra.safeRelease(wgJ0);
16954
16955 // Compute base addresses for A, B, C.
16956 gemmOffsetABC(true, state.i0, state.j0, state.h0, problem, strategy, state);
16957
16958 gemmSetupABC(problem, strategy, state);
16959 gemmSubkernel(problem, strategy, state);
16960
16961 mark(labelKernelDone);
16962
16963 // Persistent thread loop. Advance group ID and re-enter kernel if there's more work to do.
16964 if (strategy.persistent) {
16965 status << "Persistent loop" << status_stream::endl;
16966 if (state.inputs.groupCountMN.isInvalid()) {
16967 state.inputs.groupCountMN = state.ra.alloc_sub<uint32_t>(
16968 getHint(HintType::LongTerm, strategy));
16969 emul(1, state.inputs.groupCountMN, state.inputs.groupCountM,
16970 state.inputs.groupCountN, strategy, state);
16971 }
16972
16973 add(1, state.inputs.groupIDMN, state.inputs.groupIDMN,
16974 state.inputs.groupStride);
16975 cmp(1 | lt | state.flagAP, state.inputs.groupIDMN,
16976 state.inputs.groupCountMN);
16977
16978 state.ra.safeRelease(state.inputs.groupCountMN);
16979 gemmRestoreOffsets(problem, strategy, state);
16980
16981 if (strategy.slmBuffers > 0) {
16982 auto temp = state.ra.alloc();
16983 useR0(state, [&](const GRF &r0_info) {
16984 MOCK_BARRIERS barrier(temp, r0_info);
16985 });
16986 state.ra.safeRelease(temp);
16987 }
16988
16989 jmpi(1 | state.flagAP, labelReentry);
16990 }
16991
16992 if (!inFusedGEMM) {
16993 epilogue(strategy, state);
16994 padding();
16995 }
16996}
16997
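// Allocate storage for a cached address increment. When avoidIncConflicts is
// set, two copies are kept (with distinct allocation hints) so that alternating
// address updates can read the increment from different registers.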
16998template <HW hw>
16999SubregisterPair gemm_kernel_generator_t<hw>::allocIncrement(
17000 const GEMMStrategy &strategy, CommonState &state) {
17001 if (strategy.avoidIncConflicts)
17002 return SubregisterPair(state.ra.alloc_sub<uint32_t>(
17003 getHint(HintType::LongTerm0, strategy)),
17004 state.ra.alloc_sub<uint32_t>(
17005 getHint(HintType::LongTerm1, strategy)));
17006 else
17007 return SubregisterPair(state.ra.alloc_sub<uint32_t>(
17008 getHint(HintType::LongTerm, strategy)));
17009}
17010
17011// Calculate and cache lda_ka (= lda * ka) and ldb_kb (= ldb * kb) as necessary.
17012template <HW hw>
17013void gemm_kernel_generator_t<hw>::gemmCalcIncrements(const GEMMProblem &problem,
17014 const GEMMStrategy &strategy, GEMMState &state, int ka_load,
17015 int kb_load, bool doA, bool doB) {
17016 int nr = strategy.avoidIncConflicts ? 2 : 1;
17017
17018 if (ka_load == 0) ka_load = strategy.ka_inc();
17019 if (kb_load == 0) kb_load = strategy.kb_inc();
17020
17021 // If A is non-transposed, we need lda * ka_load * elementSize.
17022 if (doA && (problem.A.layout == MatrixLayout::N)) {
17023 if (!strategy.A.address2D) {
17024 if (ka_load > 1) {
17025 if (state.lda_ka.isInvalid())
17026 state.lda_ka = allocIncrement(strategy, state);
17027 for (int i = 0; i < nr; i++)
17028 emulConstant(1, state.lda_ka.getReg(i), state.inputs.lda,
17029 ka_load, strategy, state);
17030 state.ka_cached = ka_load;
17031 } else if (strategy.avoidIncConflicts)
17032 duplicateScalar(state.lda, state);
17033 }
17034 if (strategy.prefetchA && !strategy.A_prefetch.address2D
17035 && (strategy.ka_pfStride != ka_load || strategy.A.address2D)) {
17036 if (strategy.ka_pfStride > 1) {
17037 if (state.lda_ka_prefetch.isInvalid())
17038 state.lda_ka_prefetch = allocIncrement(strategy, state);
17039 for (int i = 0; i < nr; i++)
17040 emulConstant(1, state.lda_ka_prefetch.getReg(i),
17041 state.inputs.lda, strategy.ka_pfStride, strategy,
17042 state);
17043 } else if (strategy.avoidIncConflicts)
17044 duplicateScalar(state.lda, state);
17045 }
17046 }
17047
17048 // Similarly for B if it is transposed.
17049 if (doB && (problem.B.layout == MatrixLayout::T)) {
17050 if (!strategy.B.address2D) {
17051 if (kb_load > 1) {
17052 if (state.ldb_kb.isInvalid())
17053 state.ldb_kb = allocIncrement(strategy, state);
17054 for (int i = 0; i < nr; i++)
17055 emulConstant(1, state.ldb_kb.getReg(i), state.inputs.ldb,
17056 kb_load, strategy, state);
17057 state.kb_cached = kb_load;
17058 } else if (strategy.avoidIncConflicts)
17059 duplicateScalar(state.ldb, state);
17060 }
17061 if (strategy.prefetchB && !strategy.B_prefetch.address2D
17062 && (strategy.kb_pfStride != kb_load || strategy.B.address2D)) {
17063 if (strategy.kb_pfStride > 1) {
17064 if (state.ldb_kb_prefetch.isInvalid())
17065 state.ldb_kb_prefetch = allocIncrement(strategy, state);
17066 for (int i = 0; i < nr; i++)
17067 emulConstant(1, state.ldb_kb_prefetch.getReg(i),
17068 state.inputs.ldb, strategy.kb_pfStride, strategy,
17069 state);
17070 } else if (strategy.avoidIncConflicts)
17071 duplicateScalar(state.ldb, state);
17072 }
17073 }
17074}
17075
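// Downgrade A/B accesses to their unaligned access types. Where the unaligned
// access type cannot use 2D addressing, 2D addressing is dropped and the
// deferred A/B offsets are applied to the addresses instead.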
17076template <HW hw>
17077void gemm_kernel_generator_t<hw>::gemmDowngradeAccess(
17078 const GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state) {
17079 bool applyOffsetA = false, applyOffsetB = false;
17080
17081 strategy.A.accessType = strategy.unalignedAccA;
17082 strategy.B.accessType = strategy.unalignedAccB;
17083
17084 if (strategy.A.address2D && !isBlock2D(strategy.A.accessType)) {
17085 applyOffsetA = true;
17086 strategy.A.address2D = false;
17087 }
17088
17089 if (strategy.B.address2D && !isBlock2D(strategy.B.accessType)) {
17090 applyOffsetB = true;
17091 strategy.B.address2D = false;
17092 }
17093
17094 if (applyOffsetA || applyOffsetB)
17095 gemmOffsetABC(false, state.i0, state.j0, state.h0, problem, strategy,
17096 state, applyOffsetA, applyOffsetB, false);
17097}
17098
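// Generate a GEMM subkernel: compute m/n remainders, take an early exit if
// this thread has no work, and emit the kernel body (optionally in separate
// aligned and unaligned A/B variants).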
17099template <HW hw>
17100void gemm_kernel_generator_t<hw>::gemmSubkernel(
17101 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState state) {
17102 Label labelSubkernelDone, labelSubkernelEarlyExit;
17103
17104 status << "Begin subkernel: unroll " << strategy.unroll[LoopM] << 'x'
17105 << strategy.unroll[LoopN] << status_stream::endl;
17106
17107 // Calculate remainders for m/n loops: clamp(m - i0, 0, unrollM), clamp(n - j0, 0, unrollN).
17108 // Be careful with this clamping, because the unroll may change during remainder handling.
17109 bool remM = (strategy.remHandling[LoopM] != RemainderHandling::Ignore);
17110 bool remN = (strategy.remHandling[LoopN] != RemainderHandling::Ignore);
17111 bool fusedremM = remM && strategy.fused && (strategy.fusedLoop == LoopM);
17112 bool fusedremN = remN && strategy.fused && (strategy.fusedLoop == LoopN);
17113
17114 state.doLateExit = strategy.lateExit();
17115 bool earlyExit = !state.doLateExit;
17116
17117 if (fusedremM || fusedremN) {
17118 state.remFusedStorage = state.ra.alloc_sub<uint32_t>();
17119 add(1, state.remFusedStorage, -state.fusedID,
17120 uint16_t(strategy.unroll[strategy.fusedLoop]));
17121 }
17122 if (remM || !earlyExit) {
17123 state.remaindersFused[LoopM] = state.remainders[LoopM]
17124 = state.ra.alloc_sub<uint32_t>(
17125 getHint(HintType::LongTerm, strategy));
17126 InstructionModifier mod = 1 | sat;
17127 if (!fusedremM && earlyExit) mod = mod | le | f0[1];
17128 add(mod, state.remainders[LoopM], -state.i0, state.inputs.m);
17129 }
17130 if (remN || !earlyExit) {
17131 state.remaindersFused[LoopN] = state.remainders[LoopN]
17132 = state.ra.alloc_sub<uint32_t>(
17133 getHint(HintType::LongTerm, strategy));
17134 InstructionModifier mod = 1 | sat;
17135 if (!fusedremN && earlyExit) mod = mod | le | f1[1];
17136 add(mod, state.remainders[LoopN], -state.j0, state.inputs.n);
17137 }
17138 if (fusedremM || fusedremN) {
17139 state.remaindersFused[strategy.fusedLoop] = state.remFusedStorage;
17140 add(1 | sat, state.remFusedStorage, -state.remFusedStorage,
17141 state.remainders[strategy.fusedLoop]);
17142 if (earlyExit) {
17143 cmp(1 | le | (fusedremM ? f0[1] : f1[1]), null.d(),
17144 state.remainders[strategy.fusedLoop].d(), -state.fusedID);
17145 state.allowEmptyC = true;
17146 }
17147 }
17148 if (remM)
17149 min_(1, state.remainders[LoopM], state.remainders[LoopM],
17150 uint16_t(strategy.unroll[LoopM]));
17151 if (remN)
17152 min_(1, state.remainders[LoopN], state.remainders[LoopN],
17153 uint16_t(strategy.unroll[LoopN]));
17154
17155 gemmCalcIncrements(problem, strategy, state);
17156
17157 // Early exit if nothing to do. Keep fused threads together.
17158 if (earlyExit && (remM || remN)) {
17159 InstructionModifier cond;
17160 if (remM && remN)
17161 cond = 1 | f0[1] | anyv;
17162 else if (remM)
17163 cond = 1 | f0[1];
17164 else
17165 cond = 1 | f1[1];
17166
17167 if (state.fusedGEMM.active)
17168 and_(16 | nz | state.fusedGEMM.needLateGEMMDone, null.uw(),
17169 state.inputs.flags.uw(), FlagEarlyFusedGEMMDone);
17170
17171 auto &label = state.fusedGEMM.active ? labelSubkernelEarlyExit
17172 : labelSubkernelDone;
17173
17174 ejmpi(cond, label);
17175 }
17176
17177 // Create the kernel body. If enabled, create two versions, one with A/B more aligned.
17178 bool success;
17179 if (!strategy.optAlignAB)
17180 success = gemmMEdge(problem, strategy, state);
17181 else {
17182 // Check alignment of effA, effB, lda, and ldb.
17183 Label labelUnaligned;
17184 uint16_t mask = (strategy.optAlignAB - 1);
17185 bool check_lda = !isPacked(problem.A.layout);
17186 bool check_ldb = !isPacked(problem.B.layout);
17187 if (problem.A.alignment & mask) {
17188 and_(1 | nz | f0[0], null.uw(), state.effA.uw(), mask);
17189 if (check_lda)
17190 and_(1 | nz | f1[0], null.uw(), state.inputs.lda.uw(), mask);
17191 }
17192 if (problem.B.alignment & mask) {
17193 and_(1 | nz | f0[1], null.uw(), state.effB.uw(), mask);
17194 if (check_ldb)
17195 and_(1 | nz | f1[1], null.uw(), state.inputs.ldb.uw(), mask);
17196 }
17197 if (problem.A.alignment & mask) {
17198 InstructionModifier amod = check_lda ? 1 | f0[0] | anyv : 1 | f0[0];
17199 ejmpi(amod, labelUnaligned);
17200 }
17201 if (problem.B.alignment & mask) {
17202 InstructionModifier bmod = check_ldb ? 1 | f0[1] | anyv : 1 | f0[1];
17203 ejmpi(bmod, labelUnaligned);
17204 }
17205
17206 auto alignedProblem = problem;
17207 alignedProblem.A.setAlignment(
17208 std::max<int>(problem.A.alignment, strategy.optAlignAB));
17209 alignedProblem.B.setAlignment(
17210 std::max<int>(problem.B.alignment, strategy.optAlignAB));
17211
17212 status << "Aligned A/B" << status_stream::endl;
17213 success = gemmMEdge(alignedProblem, strategy, state);
17214
17215 if (!success && lastException) std::rethrow_exception(lastException);
17216
17217 state.isNested ? jmpi(1, labelSubkernelDone)
17218 : epilogue(strategy, state);
17219
17220 mark(labelUnaligned);
17221
17222 auto modStrategy = strategy;
17223
17224 gemmDowngradeAccess(problem, modStrategy, state);
17225
17226 status << "Unaligned A/B" << status_stream::endl;
17227 if (!gemmMEdge(problem, modStrategy, state)) {
17228 // Don't optimize additions on this (slow) path, to reduce code size.
17229 modStrategy.checkAdd32 = false;
17230 status << "Reducing register usage" << status_stream::endl;
17231 success = success && modStrategy.minimize(hw, problem);
17232
17233 // Recalculate lda_ka/ldb_kb, as they have changed.
17234 gemmCalcIncrements(problem, modStrategy, state);
17235
17236 success = success && gemmMEdge(problem, modStrategy, state);
17237 }
17238 }
17239
17240 if (!success)
17241 lastException ? std::rethrow_exception(lastException)
17242 : throw std::runtime_error("Could not generate kernel.");
17243
17244 mark(labelSubkernelDone);
17245
17246 if (state.fusedGEMM.active) {
17247 mov(1, state.fusedGEMM.needLateGEMMDone, 0);
17248 mark(labelSubkernelEarlyExit);
17249 }
17250
17251 safeRelease(state.lda_ka, state);
17252 safeRelease(state.ldb_kb, state);
17253 safeRelease(state.lda_ka_prefetch, state);
17254 safeRelease(state.ldb_kb_prefetch, state);
17255}
17256
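// Initialize superkernel state: the regular GEMM state plus the plan surface,
// plan count, and local ID/size arguments used to walk the plan.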
17257template <HW hw>
17258void gemm_kernel_generator_t<hw>::gemmSuperkernelInitState(
17259 GEMMSuperkernelProblem &problem, GEMMSuperkernelStrategy &strategy,
17260 GEMMSuperkernelState &state) {
17261 if (strategy.persistent) interface.requireGlobalAtomics();
17262
17263 gemmInitState(problem, strategy.substrategies[0], state, true);
17264
17265 state.isNested |= strategy.persistent;
17266
17267 state.inputsSK.surfacePlan = interface.getArgumentSurface("plan");
17268 state.inputsSK.planCount = interface.getArgument("plan_count");
17269 state.inputsSK.localID = interface.getLocalID(0);
17270 state.inputsSK.localSize = interface.getLocalSize(0);
17271
17272 state.ra.claim(state.inputsSK.localID);
17273 state.ra.claim(state.inputsSK.localSize);
17274 state.ra.claim(state.inputsSK.planCount);
17275}
17276
17277// Create a GEMM superkernel.
17278template <HW hw>
17279void gemm_kernel_generator_t<hw>::gemmSuperkernel(
17280 GEMMSuperkernelProblem problem, GEMMSuperkernelStrategy strategy,
17281 const InterfaceHandler &interface_) {
17282 auto &strategy0 = strategy.substrategies[0];
17283 bool persistent = strategy.persistent;
17284
17285 GEMMSuperkernelState state(hw);
17286
17287 // Set up.
17288 setDefaultNoMask();
17289 setDefaultAutoSWSB();
17290 interface = interface_;
17291 gemmSuperkernelInitState(problem, strategy, state);
17292 state.ra.safeRelease(state.inputs.localIDN);
17293 state.ra.safeRelease(state.inputs.localSizeN);
17294
17295 for (auto &ss : strategy.substrategies) {
17296 if (!ss.A.base.isStateless()) ss.A.base.setIndex(state.inputs.surfaceA);
17297 if (!ss.B.base.isStateless()) ss.B.base.setIndex(state.inputs.surfaceB);
17298 if (!ss.C.base.isStateless())
17299 ss.C.base.setIndex(state.inputs.surfaceC[0]);
17300 }
17301
17302 // Prevent unhelpful layouts.
17303 if (problem.A.layout == MatrixLayout::PackedRows) stub();
17304 if (problem.B.layout == MatrixLayout::PackedColumns) stub();
17305
17306 Label loopSK, loopSKEnd;
17307
17308 // Prologue.
17309 prologue(strategy0);
17310
17311 // Grab fused ID if needed.
17312 getFusedID(1, problem, strategy0, state);
17313
17314 // Get my plan ID and convert to offset in plan.
17315 auto idX = r0.ud(1);
17316 auto header = state.ra.alloc();
17317 auto poff = header.ud(2);
17318 constexpr uint16_t eltSz = 8;
17319
17320 auto temp = state.ra.alloc_sub<uint32_t>();
17321
17322 mulConstant(1, temp, state.inputsSK.planCount, strategy.subgroupSize());
17323 mul(1, poff, idX, state.inputsSK.localSize);
17324 add(1, poff, poff, state.inputsSK.localID.uw(0));
17325 cmp<uint32_t>(1 | ge | f0[0], poff, temp);
17326 if (eltSz < strategy.subgroupSize())
17327 shr(1, poff, poff, log2(strategy.subgroupSize() / eltSz));
17328 else if (eltSz > strategy.subgroupSize())
17329 mulConstant(1, poff, poff, eltSz / strategy.subgroupSize());
17330
17331 state.ra.safeRelease(temp);
17332 state.ra.safeRelease(state.inputsSK.localID);
17333 state.ra.safeRelease(state.inputsSK.localSize);
17334
17335 if (persistent) add(1, poff, poff, eltSz);
17336
17337 // Move r0 to acc0 if configured.
17338 moveR0(strategy0, state);
17339
17340 // Quick exit for extra threads (uniform WG).
17341 jmpi(1 | f0[0], loopSKEnd);
17342
17343 // Retrieve plan element.
17344 auto pdata = state.ra.alloc(getHint(HintType::TempComp0, strategy0));
17345 load(8, pdata, aligned_block_oword(1), Surface(state.inputsSK.surfacePlan),
17346 header);
17347 state.ra.safeRelease(header);
17348
17349 gemmScaleInputs(problem, strategy0, state); // Scale inputs while waiting.
17350
17351 state.i0 = pdata.d(0);
17352 state.j0 = pdata.d(1);
17353
17354 state.ra.safeRelease(pdata);
17355 state.ra.claim(state.i0);
17356 state.ra.claim(state.j0);
17357
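// The sign bits of the plan's (i0, j0) entries encode which subkernel to
// dispatch (multiM/multiN); the low 31 bits are the actual coordinates.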
17358 auto flagKID0 = f1[0];
17359 auto flagKID1 = f1[1];
17360
17361 if (strategy.multiM) cmp(1 | lt | flagKID0, null.d(), state.i0, 0);
17362 if (strategy.multiN) cmp(1 | lt | flagKID1, null.d(), state.j0, 0);
17363 and_(2, state.i0.ud()(1), state.i0.ud()(1), uint32_t(0x7FFFFFFF));
17364
17365 // Initial offset of A/B/C.
17366 gemmOffsetABC(
17367 true, state.i0, state.j0, Subregister(), problem, strategy0, state);
17368 gemmSetupABC(problem, strategy0, state);
17369
17370 // Save i0, j0 for later.
17371 state.last_i0 = state.ra.alloc_sub<int32_t>(
17372 getHint(HintType::LongTerm, strategy0));
17373 state.last_j0 = state.ra.alloc_sub<int32_t>(
17374 getHint(HintType::LongTerm, strategy0));
17375 mov(1, state.last_i0, state.i0);
17376 mov(1, state.last_j0, state.j0);
17377
17378 // Top of superkernel loop.
17379 status << "Begin superkernel loop" << status_stream::endl;
17380 mark(loopSK);
17381 {
17382 // Dispatch appropriate kernel, supporting up to 4 subkernels.
17383 int kidx = 0;
17384 Label labelM1, labelM0N1, labelM1N1, labelKernelDone;
17385 if (strategy.multiM) jmpi(1 | flagKID0, labelM1);
17386 if (strategy.multiN) jmpi(1 | flagKID1, labelM0N1);
17387
17388 gemmSubkernel(problem, strategy.substrategies[kidx++], state);
17389
17390 if (strategy.multiN) {
17391 jmpi(1, labelKernelDone);
17392 mark(labelM0N1);
17393 gemmSubkernel(problem, strategy.substrategies[kidx++], state);
17394 }
17395
17396 if (strategy.multiM) {
17397 jmpi(1, labelKernelDone);
17398
17399 mark(labelM1);
17400 if (strategy.multiN) jmpi(1 | flagKID1, labelM1N1);
17401
17402 gemmSubkernel(problem, strategy.substrategies[kidx++], state);
17403
17404 if (strategy.multiN) {
17405 jmpi(1, labelKernelDone);
17406 mark(labelM1N1);
17407 gemmSubkernel(problem, strategy.substrategies[kidx++], state);
17408 }
17409 }
17410
17411 mark(labelKernelDone);
17412
17413 if (persistent) {
17414 // Get next plan element via atomic increment of plan ID counter.
17415 auto header = state.ra.alloc();
17416 auto nextID
17417 = state.ra.alloc(getHint(HintType::TempComp1, strategy0));
17418 auto pdata
17419 = state.ra.alloc(getHint(HintType::TempComp0, strategy0));
17420
17421 mov<uint32_t>(8, header, uint16_t(0));
17422 atomic(AtomicOp::inc, 1, nextID, scattered_dword(),
17423 Surface(state.inputsSK.surfacePlan), header);
17424
17425 // Load next plan element, or exit if no more work.
17426 mulConstant<uint32_t>(1, header[2], nextID[0], eltSz);
17427 cmp<uint32_t>(
17428 1 | ge | f0[0], null, nextID[0], state.inputsSK.planCount);
17429 add<uint32_t>(1, header[2], header[2], eltSz);
17430
17431 jmpi(1 | f0[0], loopSKEnd);
17432
17433 load(8, pdata, aligned_block_oword(1),
17434 Surface(state.inputsSK.surfacePlan), header);
17435 state.ra.safeRelease(header);
17436 state.ra.safeRelease(nextID);
17437
17438 // Load next (i0, j0) and kernel IDs.
17439 auto in_i0 = pdata.d(0);
17440 auto in_j0 = pdata.d(1);
17441
17442 if (strategy.multiM) cmp(1 | lt | flagKID0, null.d(), in_i0, 0);
17443 if (strategy.multiN) cmp(1 | lt | flagKID1, null.d(), in_j0, 0);
17444 and_(1, state.i0.ud(), in_i0.ud(), uint32_t(0x7FFFFFFF));
17445 and_(1, state.j0.ud(), in_j0.ud(), uint32_t(0x7FFFFFFF));
17446
17447 // Get difference in i0 and j0...
17448 add(1, in_i0, state.i0, -state.last_i0);
17449 add(1, in_j0, state.j0, -state.last_j0);
17450
17451 // ... save current (i0, j0) for later...
17452 mov(1, state.last_i0, state.i0);
17453 mov(1, state.last_j0, state.j0);
17454
17455 // ...and offset A, B, C appropriately.
17456 gemmOffsetABC(false, in_i0, in_j0, Subregister(), problem,
17457 strategy0, state);
17458
17459 state.ra.safeRelease(pdata);
17460
17461 state.ra.safeRelease(state.i0);
17462 state.ra.safeRelease(state.j0);
17463
17464 // Ready for the next kernel.
17465 jmpi(1, loopSK);
17466 }
17467 }
17468 mark(loopSKEnd);
17469
17470 epilogue(strategy.substrategies[0], state);
17471 padding();
17472}
17473
17474// Get driver information from this strategy.
17475template <HW hw>
17476CommonDriverInfo gemm_kernel_generator_t<hw>::driverInfo(
17477 const GEMMProblem &problem, const GEMMStrategy &strategy) {
17478 CommonDriverInfo info;
17479
17480 info.subgroupSize = strategy.subgroupSize;
17481 info.fusedLoop = strategy.fused ? strategy.fusedLoop : LoopNone;
17482 info.grfCount = strategy.GRFs;
17483 for (int d = 0; d < 3; d++) {
17484 info.loopOrder[d] = strategy.loopOrder[d];
17485 info.blocking[d] = strategy.blocking[d];
17486 info.blockingAlt[d] = strategy.blockingAlt[d];
17487 info.unroll[d] = strategy.unroll[d];
17488 info.wg[d] = strategy.wg[d];
17489 }
17490 info.wgExpand = strategy.splitCopy ? 2 : 1;
17491 if (strategy.hilbertOrder) {
17492 info.loopOrder[0] = (info.loopOrder[0] == LoopN) ? LoopMNHilbertNMK
17493 : LoopMNHilbertMNK;
17494 info.loopOrder[1] = LoopNone;
17495 } else if (strategy.boustrophedon) {
17496 info.loopOrder[0] = (info.loopOrder[0] == LoopN)
17497 ? LoopMNBoustrophedonNMK
17498 : LoopMNBoustrophedonMNK;
17499 info.loopOrder[1] = LoopNone;
17500 }
17501 if (strategy.persistent)
17502 info.loopOrder[0]
17503 = static_cast<LoopType>(info.loopOrder[0] | LoopPersistent);
17504 if (problem.batch == BatchMode::None && !strategy.kParallelLocal)
17505 info.loopOrder[2] = LoopNone;
17506 info.wgUpdate = strategy.getWGType(problem);
17507 info.kRemainderHandling
17508 = (strategy.remHandling[LoopK] != RemainderHandling::Ignore);
17509 info.kParallel = strategy.kParallel;
17510 info.kParallelLocal = strategy.kParallelLocal;
17511 info.slm = int(gemmSLMSize(problem, strategy));
17512 info.perKSLM = int(gemmPerKSLMSize(problem, strategy));
17513 info.alignment[0] = problem.A.alignment;
17514 info.alignment[1] = problem.B.alignment;
17515 info.alignment[2] = problem.C.alignment;
17516 info.support4GB[0] = (strategy.A.base.getModel() == ModelA64);
17517 info.support4GB[1] = (strategy.B.base.getModel() == ModelA64);
17518 info.support4GB[2] = (strategy.C.base.getModel() == ModelA64);
17519
17520 return info;
17521}
17522
17523template <HW hw>
17524CommonDriverInfo gemm_kernel_generator_t<hw>::driverInfo(
17525 const GEMMSuperkernelProblem &problem, const GEMMStrategy &strategy) {
17526 auto info = driverInfo(static_cast<GEMMProblem>(problem), strategy);
17527 return info;
17528}
17529
17530// Return the maximum possible k size for copied SLM data.
17531int GEMMStrategy::maxKSLM(const GEMMProblem &problem, bool isA) const {
17532 return unrollKSLM;
17533}
17534
17535// Validate a GEMM strategy, correcting settings as necessary.
17536void GEMMStrategy::preflight(HW hw, const GEMMProblem &problem) {
17537 auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc;
17538 auto Ta_real = Ta.real();
17539 auto Tb_real = Tb.real();
17540 auto Tc_real = Tc.real();
17541
17542 // Addressing preflight.
17543
17544 if (C.atomic && !C.base.isStateless() && !C.newDP) C.forceA64();
17545
17546 slmA &= (slmBuffers > 0);
17547 slmB &= (slmBuffers > 0);
17548
17549 A.preflight(hw);
17550 B.preflight(hw);
17551 C.preflight(hw);
17552 A_prefetch.preflight(hw);
17553 B_prefetch.preflight(hw);
17554 C_prefetch.preflight(hw);
17555
17556 bool globalCM = isRegisterColMajor(problem.Tc, problem.C, C);
17557
17558 // Default SIMD setting.
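    // Default to two GRFs' worth of the widest element type, capped at 32.
    // Example (assuming 64-byte GRFs with f16 A/B and f32 C):
    //   fmaSIMD = min(32, 2 * 64 / 4) = 32.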
17559 if (fmaSIMD == 0) {
17560 fmaSIMD = std::min(32,
17561 2 * GRF::bytes(hw)
17562 / std::max<int>({Ta.size(), Tb.size(), Tc.size()}));
17563 if (hw == HW::Gen9 && Ta_real.size() == 1 && Tb_real.size() == 1
17564 && Tc_real.size() == 4)
17565 fmaSIMD = 32;
17566 }
17567
17568 slmFenceWARWA |= (hw == HW::XeHPG);
17569
17570 if (problem.batch != BatchMode::None) {
17571 persistent = false;
17572 kParallel = false;
17573 }
17574
17575 if (coopA == CoopSplit::K && slmATrans) coopA = CoopSplit::MN;
17576 if (coopB == CoopSplit::K && slmBTrans) coopB = CoopSplit::MN;
17577
17578 checkBeta1 |= C.atomic && !problem.beta1();
17579
17580 // Fixed systolic kernel handling.
17581 if (fixedSystolic) {
17582 if (wg[LoopM] == 0) wg[LoopM] = 4;
17583 if (wg[LoopN] == 0) wg[LoopN] = 4;
17584 bool doubleM = (wg[LoopM] == 8);
17585
17586 slmCopies = (slmCopies == 3) ? 3 : 1;
17587 slmBuffers = (splitCopy || doubleM) ? 4 : 3;
17588 slmA = slmB = true;
17589 GRFs = 256;
17590 altCRemainder = false;
17591 loopOrder[0] = LoopM;
17592 loopOrder[1] = LoopN;
17593 loopOrder[2] = LoopK;
17594 A.accessType = B.accessType = AccessType::Block;
17595 ka_load = kb_load = 32 / Ta_real;
17596 dpasw = true;
17597 }
17598
17599 dpasw &= fused;
17600
    // Accumulator registers may be needed for 64-bit emulation, k chaining,
    // extra C registers, or storage for the r0 header.
    // Priority: k chaining > extra C registers > r0 header storage;
    //           64-bit emulation > r0 header storage.
17604 if (hw <= HW::Gen9) kChain = 1;
17605 cAccumulators &= (kChain == 1);
17606
17607 bool emulateNeedsAcc = emulate.emulate64 || emulate.emulateDWxDW;
17608 if (moveR0 == MoveR0::Acc)
17609 if (cAccumulators || emulateNeedsAcc || xParallel || (kChain > 1)
17610 || barrierFreq)
17611 moveR0 = MoveR0::None;
17612
17613 // Mixed mode restrictions:
17614 // - mixed hf/f is max SIMD 8 on Gen9
17615 // - mixed hf/f is not allowed on Gen12
17616 // - mixed bf/f is max SIMD 8 on ATS+
17617 if ((Tc_real == Type::f32)
17618 && (Ta_real != Type::f32 || Tb_real != Type::f32))
17619 fmaSIMD = std::min(fmaSIMD, GRF::bytes(hw) >> 2);
17620
    // Paths that avoid jump tables rely on SIMT control flow, as do atomic
    // reductions, so single program flow (spf) is disabled in those cases.
17622 spf &= !noJumpTables;
17623 spf &= !C.atomic;
17624
17625 checkAdd32 &= !emulate.emulate64_add32;
17626 checkAdd32 &= (A.base.isStateless() || B.base.isStateless());
17627 checkAdd32 &= !(A.address2D && B.address2D
17628 && (!prefetchA || A_prefetch.address2D)
17629 && (!prefetchB || B_prefetch.address2D));
17630
17631 int opCount = outerProductCount(hw, problem, *this);
17632 int minOPCount = minOuterProductCount(hw, problem, *this);
17633 int ukAlign = opCount;
17634
17635 if (kParallelLocal) moveR0 = MoveR0::None;
17636
17637 // SLM copy logic.
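    // slmVersions is the period of the (copy, buffer) rotation in the SLM
    // pipeline; the k unroll is aligned below so that each unrolled iteration
    // covers whole rotations, scaled by the workgroup dimension that
    // cooperates on each copy.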
17638 int slmVersions = std::max(1, lcm(slmCopies, slmBuffers));
17639 if (slmBuffers > 0) {
17640 moveR0 = MoveR0::None;
17641 barrierFreq = 0;
17642 if (wg[LoopM] <= 0 || wg[LoopN] <= 0)
17643 throw std::runtime_error("Workgroup sizes required.");
17644 if (slmA) ukAlign = lcm(ukAlign, wg[LoopN] * slmVersions);
17645 if (slmB) ukAlign = lcm(ukAlign, wg[LoopM] * slmVersions);
17646 slmUseIncrCopy &= (slmCopies == 1);
17647 }
17648
    // ka/kb_load wrangling.
17650 if (ka_load_masked == 0) ka_load_masked = ka_load;
17651 if (kb_load_masked == 0) kb_load_masked = kb_load;
17652
17653 if (!slmA) {
17654 ka_load = align_up(ka_load, opCount);
17655 ka_load_masked = align_up(ka_load_masked, minOPCount);
17656 }
17657 if (!slmB) {
17658 kb_load = align_up(kb_load, opCount);
17659 kb_load_masked = align_up(kb_load_masked, minOPCount);
17660 }
17661
17662 // Systolic handling.
17663 if (systolic) {
17664 auto params = systolicParams(hw, problem, *this);
17665
17666 ukAlign = lcm(ukAlign, params.ksys);
17667 auto tileX = params.osys;
17668 (globalCM ? C.tileR : C.tileC) = tileX;
17669 if (unroll[globalCM ? LoopM : LoopN] > tileX) forceCopyC = true;
17670 }
17671
17672 // Prefetch handling.
17673 cooperativePF &= (prefetchA || prefetchB);
17674
17675 if (problem.beta0()) prefetchC = 0;
17676
17677 // Propagate tiling requests to strategy.
17678 int tileM_A, tileK_A, tileK_B, tileN_B;
17679 std::tie(tileM_A, tileK_A, tileK_B, tileN_B)
17680 = targetKernelTiling(hw, problem, *this);
17681 if (A.accessType != AccessType::Block) {
17682 if (tileM_A && !A.tileR) A.tileR = tileM_A;
17683 if (tileK_A && !A.tileC) A.tileC = tileK_A;
17684 }
17685 if (B.accessType != AccessType::Block) {
17686 if (tileK_B && !B.tileR) B.tileR = tileK_B;
17687 if (tileN_B && !B.tileC) B.tileC = tileN_B;
17688 }
17689
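    // dpasw splits the systolic operand between the two threads of a fused EU
    // pair, hence the fused-dimension requirement and the halved tile below.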
17690 if (dpasw) {
17691 auto params = systolicParams(hw, problem, *this);
17692 if (globalCM) {
17693 if (!fusedM()) stub();
17694 B.dpasw = true;
17695 B.tileC = std::max(
17696 1, std::min(unroll[LoopN], params.rcountMax) / 2);
17697 } else {
17698 if (!fusedN()) stub();
17699 A.dpasw = true;
17700 A.tileR = std::max(
17701 1, std::min(unroll[LoopM], params.rcountMax) / 2);
17702 }
17703 }
17704
17705 // Always use 1D addressing for packed inputs.
17706 A.address2D &= !isPacked(problem.A.layout);
17707 B.address2D &= !isPacked(problem.B.layout);
17708
17709 // k unroll wrangling.
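    // unroll[LoopK] must be a multiple of every k-direction quantum in use:
    // the A/B register loads (times their copy counts), the prefetch strides,
    // and a full SLM buffering rotation. ukAlign accumulates these
    // constraints via LCMs before the final align_up.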
17710 ukAlign = lcm(ukAlign, A_copies * ka_load);
17711 ukAlign = lcm(ukAlign, B_copies * kb_load);
17712 if (slmCopies > 1) {
17713 ukAlign = lcm(ukAlign, slmCopies * ka_load);
17714 ukAlign = lcm(ukAlign, slmCopies * kb_load);
17715 }
17716 if (ka_pfStride) ukAlign = lcm(ukAlign, ka_pfStride);
17717 if (kb_pfStride) ukAlign = lcm(ukAlign, kb_pfStride);
17718
17719 int minUnrollKSLM = 1;
17720 if (unrollKSLM > 0)
17721 minUnrollKSLM = unrollKSLM;
17722 else {
17723 if (slmA) minUnrollKSLM = lcm(minUnrollKSLM, ka_load);
17724 if (slmB) minUnrollKSLM = lcm(minUnrollKSLM, kb_load);
17725 }
17726
17727 ukAlign = align_up(ukAlign, minUnrollKSLM * slmVersions);
17728
17729 unroll[LoopK] = align_up(unroll[LoopK], ukAlign);
17730 barrierFreq = align_up(barrierFreq, unroll[LoopK]);
17731
17732 if (unrollKSLM == 0) unrollKSLM = unroll[LoopK] / slmVersions;
17733
17734 if (fixedSystolic) unroll[LoopK] = unrollKSLM = 32 / Ta_real;
17735
17736 barrierFreq = align_up(barrierFreq, unroll[LoopK]);
17737
17738 int kChunkA = (problem.A.tileC ? problem.A.tileC : problem.A.crosspack);
17739 int kChunkB = (problem.B.tileR ? problem.B.tileR : problem.B.crosspack);
17740 if (unroll[LoopK] <= std::min(kChunkA, kChunkB))
17741 remHandling[LoopK] = RemainderHandling::Ignore;
17742
17743 // Default blocking.
17744 bool isZ = problem.Tc.size() >= 16;
17745 auto defaultMBlock = isZ ? 2048 : 4096;
17746 if (hw >= HW::XeHP) defaultMBlock *= 2;
17747 auto defaultNBlock = defaultMBlock;
17748 auto defaultMNBlockNonHilbert = defaultMBlock;
17749
    /* Linear orders allow no more than (2^16 - 1) workgroups in the m/n
       dimensions; cap the default blocking at 2^14 workgroups per dimension
       to leave a wide safety margin. */
17751 if (linearOrder()) {
17752 defaultMBlock = 16384 * unroll[LoopM];
17753 defaultNBlock = 16384 * unroll[LoopN];
17754 }
17755
17756 if (blocking[LoopM] <= 0) blocking[LoopM] = defaultMBlock;
17757 if (blocking[LoopN] <= 0) blocking[LoopN] = defaultNBlock;
17758 if (blocking[LoopK] <= 0) {
17759 int points = 1;
17760 if (slmA || (problem.A.layout != MatrixLayout::T)) points++;
17761 if (slmB || (problem.B.layout != MatrixLayout::N)) points++;
17762 blocking[LoopK] = std::min(2048, (2048 * points) / problem.Ta);
17763 }
17764
17765 auto defaultBlockAltK = blocking[LoopK];
17766 if (hw < HW::XeHPC)
17767 if (hw >= HW::XeHP) defaultBlockAltK = std::min(defaultBlockAltK, 1024);
17768
17769 if (blockingAlt[LoopM] <= 0) blockingAlt[LoopM] = defaultMNBlockNonHilbert;
17770 if (blockingAlt[LoopN] <= 0) blockingAlt[LoopN] = defaultMNBlockNonHilbert;
17771 if (blockingAlt[LoopK] <= 0) blockingAlt[LoopK] = defaultBlockAltK;
17772
17773 // Default workgroups.
17774 auto defaultWGX = 2, defaultWGY = 8;
17775
17776 if (wg[loopOrder[0]] <= 0) wg[loopOrder[0]] = defaultWGX;
17777 if (wg[loopOrder[1]] <= 0) wg[loopOrder[1]] = defaultWGY;
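    // With local k-parallelization, spread the subslice's remaining hardware
    // threads across the k dimension of the workgroup.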
17778 if (wg[LoopK] <= 0) {
17779 if (kParallelLocal)
17780 wg[LoopK] = (threadsPerEU(hw, *this) * eusPerSubslice(hw))
17781 / (wg[LoopM] * wg[LoopN]);
17782 else
17783 wg[LoopK] = 1;
17784 }
17785
17786 kParallelLocal &= (wg[LoopK] > 1);
17787
17788 skewLocalIDs &= (wg[LoopM] * wg[LoopN] > eusPerSubslice(hw));
17789
17790 if (skewLocalIDs) forceWGUpdate = WGFixed;
17791
17792 avoidIncConflicts &= (hw >= HW::XeHP);
17793
17794 CommonStrategy::preflight(hw, problem);
17795}
17796
17797// Reduce register pressure. Returns true if successful.
17798bool GEMMStrategy::minimize(HW hw, const GEMMProblem &problem) {
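    // Try progressively more aggressive reductions: first trim ka/kb_load to
    // the suggested minimums and drop extra copies/chaining; only if nothing
    // improved, fall back to the absolute minimum outer product counts.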
17799 bool better = false;
17800 auto minOPCount = minOuterProductCount(hw, problem, *this);
17801 auto ka_load_best_min = std::max<int>({1, 4 / problem.Ta, minOPCount});
17802 auto kb_load_best_min = std::max<int>({1, 4 / problem.Tb, minOPCount});
17803
17804 // Reduce ka/b_load down to suggested minimums (not requiring crosspack)
17805 if (ka_load > ka_load_best_min) {
17806 ka_load = ka_load_best_min;
17807 better = true;
17808 }
17809 if (kb_load > kb_load_best_min) {
17810 kb_load = kb_load_best_min;
17811 better = true;
17812 }
17813
17814 // Reduce A/B copies.
17815 A_copies = B_copies = 1;
17816
17817 // Remove k chaining.
17818 kChain = 1;
17819
17820 // Reduce k unroll for SLM copies.
17821 if (slmA || slmB) {
17822 auto oldUK = unroll[LoopK];
17823 unroll[LoopK] = 1;
17824 unrollKSLM = 0;
17825 preflight(hw, problem);
17826 better |= (unroll[LoopK] < oldUK);
17827 }
17828
17829 if (better) return better;
17830
17831 // Reduce ka/b_load to absolute minimum if that failed.
17832 if (ka_load > minOPCount) {
17833 ka_load = minOPCount;
17834 better = true;
17835 }
17836 if (kb_load > minOPCount) {
17837 kb_load = minOPCount;
17838 better = true;
17839 }
17840
17841 return better;
17842}
17843
17844// Validate a GEMM superkernel strategy, correcting settings as necessary.
17845void GEMMSuperkernelStrategy::preflight(HW hw, const GEMMProblem &problem) {
17846 if (substrategies.size() <= 0)
17847 throw std::runtime_error("No substrategies for superkernel.");
17848 auto subgroupSize = substrategies[0].subgroupSize;
17849 for (auto &ss : substrategies) {
17850 ss.insideSK = true;
17851 ss.preflight(hw, problem);
17852 if (ss.subgroupSize != subgroupSize)
17853 throw std::runtime_error("Incompatible subgroup sizes.");
17854 }
17855}
17856
17857void MatrixAddressingStrategy::preflight(HW hw) {
17858 newDP |= isBlock2D(accessType);
17859 if (prefetch && newDP && cachingR == CacheSettingsLSC::Default)
17860 cachingR = CacheSettingsLSC::L1C_L3C;
17861
17862 if (accessType == AccessType::ChannelScattered && base.isStateless()
17863 && !newDP)
17864 base = AddressBase::createBTS(0);
17865}
17866
17867void MatrixAddressingStrategy::forceA64() {
17868 base = AddressBase::createA64(true);
17869 if (accessType == AccessType::ChannelScattered && !newDP)
17870 accessType = AccessType::Scattered;
17871}
17872
17873/**********************************************************************/
17874/* Fixed Systolic GEMM (XeHP/XeHPG) */
17875/**********************************************************************/
17876namespace sysgemm {
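// Hard-coded GRF map for the fixed systolic kernel. C accumulates in
// GRF 64-255; the A/B SLM staging buffers and address headers use fixed
// registers as well (note that the copy1/copy2 staging ranges alias part of
// the C range), and SLM offsets (in owords) are packed into spare
// subregisters of the address GRFs.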
17877static GRFRange A_copy0 = GRF(40) - GRF(47);
17878static GRFRange B_copy0 = GRF(2) - GRF(13);
17879static GRFRange A_regs = GRF(48) - GRF(63);
17880static GRFRange B_regs = GRF(14) - GRF(37);
17881static GRFRange C_regs = GRF(64) - GRF(255);
17882static GRFRange A_copy1 = GRF(96) - GRF(103);
17883static GRFRange B_copy1 = GRF(104) - GRF(111);
17884static GRFRange A_copy2 = GRF(144) - GRF(151);
17885static GRFRange B_copy2 = GRF(152) - GRF(159);
17886static GRFRange A_copy[3] = {A_copy0, A_copy1, A_copy2};
17887static GRFRange B_copy[3] = {B_copy0, B_copy1, B_copy2};
17888static GRF addr0 = GRF(1);
17889static GRF addr1 = GRF(38);
17890static GRF addr2 = GRF(39);
17891static GRF addr3 = GRF(0);
17892static Subregister A_ptr64 = addr1.uq(3);
17893static Subregister B_ptr64 = addr2.uq(3);
17894static Subregister C_ptr64 = addr2.uq(2);
17895static Subregister slmAOffsetLoad = addr1.uw(8); // offsets in OWords
17896static Subregister slmBOffsetLoad = addr1.uw(9);
17897static Subregister slmAOffsetStore = addr1.uw(10);
17898static Subregister slmBOffsetStore = addr1.uw(11);
17899static Subregister slmAOffsetLoadInit = addr1.uw(6);
17900static Subregister slmBOffsetLoadInit = addr1.uw(7);
17901static Subregister slmAOffsetStoreInit = addr2.uw(6);
17902static Subregister slmBOffsetStoreInit = addr2.uw(7);
17903static Subregister kCounter = AccumulatorRegister(2).d(0);
17904static Subregister barrierVal = AddressRegister(0).ud(0);
17905static constexpr int accStride = 48;
17906} // namespace sysgemm
17907
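// Generate the fixed systolic C accumulation phase: validate the expected
// packed A/B layouts, marshal live state into acc2, run the SLM-pipelined
// k loop (the four-buffer variant when the workgroup is doubled in M), then
// restore state and describe the resulting C register layout for the
// subsequent C update.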
17908template <HW hw>
17909bool gemm_kernel_generator_t<hw>::sysgemmAccumulateC(
17910 GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state) {
17911 using namespace sysgemm;
17912 auto params = systolicParams(hw, problem, strategy);
17913 auto unrollM = strategy.unroll[LoopM];
17914 auto unrollN = strategy.unroll[LoopN];
17915 auto wgM = strategy.wg[LoopM];
17916 auto wgN = strategy.wg[LoopN];
17917 auto localIDM = state.lidM;
17918 auto localIDN = state.lidN;
17919 bool doubleM = (wgM == 8);
17920 bool surfaceAB = !strategy.A.base.isStateless();
17921 bool surfaceC = !strategy.C.base.isStateless();
17922
17923 if (unrollM != 32) stub();
17924 if (unrollN != 32 && unrollN != 48) stub();
17925 if (wgM != 4 && wgM != 8) stub();
17926 if (wgN != 4) stub();
17927 if (strategy.A.base.getModel() != strategy.B.base.getModel()) stub();
17928 if (problem.A.layout != MatrixLayout::Pc) stub();
17929 if (problem.A.crosspack != params.opsPerChan) stub();
17930 if (problem.A.tileR != params.osys) stub();
17931 if (problem.A.tileC != params.ksys) stub();
17932 if (problem.B.layout != MatrixLayout::Pr) stub();
17933 if (problem.B.crosspack != params.ksys) stub();
17934 if (problem.B.tileR != 0 || problem.B.tileC != 0) stub();
17935
17936 state.ra.claim(C_regs);
17937
17938 // Adjust A/B addresses and SLM offsets.
17939 auto tempStorage = C_regs[0];
17940 auto suboffsetA = tempStorage.ud(0);
17941 auto suboffsetB = tempStorage.ud(1);
17942 auto tempA = tempStorage.ud(2);
17943 auto wlidM = tempStorage.uw(6);
17944 auto tempB = tempStorage.ud(4);
17945 auto suboffsetBl = tempStorage.ud(5);
17946
17947 if (doubleM) {
17948 and_(1, wlidM, localIDM, 3);
17949 and_(1 | ne | f1[1], null.uw(), localIDM, 4);
17950 }
17951 and_(1 | ne | state.flagAP, null.uw(), localIDM, 1);
17952 mulConstant(1, suboffsetA, localIDN, unrollM * (32 / 4));
17953 if (doubleM) {
17954 mulConstant(1, suboffsetB, wlidM, unrollN * (32 / 4));
17955 mulConstant(1, suboffsetBl, localIDM, unrollN * (32 / 4));
17956 } else {
17957 mulConstant(1, suboffsetB, localIDM, unrollN * (32 / 4));
17958 suboffsetBl = suboffsetB;
17959 }
17960
17961 auto A_ptr = A_ptr64, B_ptr = B_ptr64, C_ptr = C_ptr64;
17962 if (surfaceAB) {
17963 if (!strategy.A.newDP || !strategy.B.newDP) stub();
17964 A_ptr = A_ptr.ud();
17965 B_ptr = B_ptr.ud();
17966 }
17967 if (surfaceC) C_ptr = C_ptr.ud();
17968
17969 eadd(1, A_ptr, state.effA, suboffsetA, strategy, state);
17970 eadd(1, B_ptr, state.effB, suboffsetBl, strategy, state);
17971 emov(1, C_ptr, state.effC[0], strategy, state);
17972
17973 shr(2, suboffsetA(1), suboffsetA(1), 4);
17974
17975 mul(1, tempA, localIDM, (unrollM * 36) / 16);
17976 mad(1, tempB, (wgM * unrollM * 36) / 16, localIDN, (unrollN * 32) / 16);
17977
17978 mov(1, slmAOffsetLoadInit.uw(), tempA.uw());
17979 add(1 | state.flagAP, slmBOffsetLoadInit.uw(), tempB.uw(),
17980 (unrollN / 2) * (32 / 16));
17981 mov(1 | ~state.flagAP, slmBOffsetLoadInit.uw(), tempB.uw());
17982 add(1, slmAOffsetStoreInit.uw(), tempA.uw(), suboffsetA.uw());
17983 add(1, slmBOffsetStoreInit.uw(), tempB.uw(), suboffsetB.uw());
17984 mov(2, slmAOffsetLoad(1), slmAOffsetLoadInit(1));
17985
17986 // Marshal data needed later into acc2 for safekeeping.
17987 auto saveData = state.ra.alloc_range(2);
17988 auto kLoops = saveData[0].d(0);
17989 auto ldc = saveData[0].ud(1);
17990 auto flags = saveData[0].ud(2);
17991 auto k = saveData[0].ud(3);
17992 auto remM = saveData[0].uw(8);
17993 auto remN = saveData[0].uw(9);
17994 auto abo = saveData[0].ud(5);
17995 auto ao = saveData[0].w(10);
17996 auto bo = saveData[0].w(11);
17997 auto alpha = saveData[0].ud(6).reinterpret(0, problem.Ts.ngen());
17998 auto beta = saveData[0].ud(7).reinterpret(0, problem.Ts.ngen());
17999 auto remFusedStorage = saveData[1].ud(0);
18000 auto diagC = saveData[1].ud(1);
18001 auto saveI0 = saveData[1].ud(1);
18002 auto effCO = saveData[1].uq(1);
18003 auto saveJ0 = saveData[1].ud(3);
18004 auto slotAB = saveData[1].ud(4);
18005 auto effAs = saveData[1].uq(2).reinterpret(0, state.effA.getType());
18006 auto effBs = saveData[1].uq(3).reinterpret(0, state.effB.getType());
18007
18008 if (state.r0_info != acc0.ud()) mov<uint32_t>(8, acc0, state.r0_info);
18009
18010 add(1, kLoops, state.k, params.ksys - 1);
18011 mov(1, ldc, state.inputs.ldc[0]);
18012 if (state.inputs.flags.isValid()) mov(1, flags, state.inputs.flags);
18013 mov(1, k, state.k);
18014 if (state.remainders[LoopM].isValid())
18015 mov(1, remM, state.remainders[LoopM]);
18016 if (state.remainders[LoopN].isValid())
18017 mov(1, remN, state.remainders[LoopN]);
18018 if (state.inputs.abo.isValid())
18019 mov(1, abo, state.inputs.abo);
18020 else {
18021 if (state.inputs.ao.isValid()) mov(1, ao, state.inputs.ao);
18022 if (state.inputs.bo.isValid()) mov(1, bo, state.inputs.bo);
18023 }
18024 if (state.inputs.alpha_real.isValid())
18025 mov(1, alpha, state.inputs.alpha_real);
18026 if (state.inputs.beta_real.isValid()) mov(1, beta, state.inputs.beta_real);
18027 shr(1, kLoops, kLoops, log2(params.ksys));
18028 if (state.remFusedStorage.isValid())
18029 mov(1, remFusedStorage, state.remFusedStorage);
18030 if (state.diagC.isValid()) mov(1, diagC, state.diagC);
18031 if (state.effCO.isValid()) {
18032 effCO = effCO.reinterpret(0, state.effCO.getType());
18033 emov(1, effCO, state.effCO, strategy, state);
18034 }
18035 if (problem.hasBinaryPostOp()) {
18036 if (state.diagC.isValid()) stub();
18037 if (state.effCO.isValid() && effCO.getBytes() > 4) stub();
18038 mov(1, saveI0, state.i0);
18039 mov(1, saveJ0, state.j0);
18040 }
18041 if (problem.abOffset != ABOffset::None) {
18042 state.effAs = effAs;
18043 state.effBs = effBs;
18044 gemmCalcABOffsetAddrs(problem, strategy, state);
18045 }
18046 if (state.fusedGEMM.slotA.isValid()) {
18047 if (problem.abOffset != ABOffset::None)
18048 stub(); // Not enough room in acc2.
18049 mov(1, slotAB, state.fusedGEMM.slotA.ud());
18050 }
18051
18052 releaseSavedMNLocalIDs(state);
18053 state.ra.safeRelease(state.effA);
18054 state.ra.safeRelease(state.effB);
18055 state.ra.safeRelease(state.effC[0]);
18056 state.ra.safeRelease(state.inputs.lda);
18057 state.ra.safeRelease(state.inputs.ldb);
18058
18059 state.ra.release(state.inputs.ldc[0]);
18060 state.ra.release(state.k);
18061 state.ra.release(state.remainders[LoopM]);
18062 state.ra.release(state.remainders[LoopN]);
18063 state.ra.release(state.inputs.abo);
18064 state.ra.release(state.inputs.ao);
18065 state.ra.release(state.inputs.bo);
18066 state.ra.release(state.inputs.alpha_real);
18067 state.ra.release(state.inputs.beta_real);
18068 state.ra.release(state.remFusedStorage);
18069 state.ra.release(state.diagC);
18070 state.ra.release(state.effCO);
18071 state.ra.release(state.fusedGEMM.slotA);
18072 state.ra.release(state.fusedGEMM.slotB);
18073
18074 if (state.r0_info.isARF()) stub();
18075 GRF r0_info {state.r0_info.getBase()};
18076 if (hw >= HW::XeHPG) {
18077 mov(1, barrierVal.uw(0), Immediate::uw(0));
18078 mov(2, barrierVal.ub(2)(1), r0_info.ub(11)(0));
18079 } else
18080 and_(1, barrierVal, r0_info.ud(2), 0x7F000000);
18081
18082 mov<float>(16, acc2, saveData[0]);
18083
18084 sync.nop(SWSB<AllPipes>(1));
18085
18086 if (!doubleM)
18087 sysgemmKLoop(problem, strategy, state);
18088 else {
18089 Label oddB, done;
18090 jmpi(1 | f1[1], oddB);
18091 sysgemmKLoop4(problem, strategy, state, false);
18092 jmpi(1, done);
18093 mark(oddB);
18094 sysgemmKLoop4(problem, strategy, state, true);
18095 mark(done);
18096 }
18097
18098 mov<float>(16, saveData[0], acc2);
18099
18100 state.effC[0] = C_ptr;
18101 state.inputs.ldc[0] = ldc;
18102 if (state.inputs.flags.isValid()) state.inputs.flags = flags;
18103 state.k = k;
18104 if (state.remainders[LoopM].isValid()) state.remainders[LoopM] = remM;
18105 if (state.remainders[LoopN].isValid()) state.remainders[LoopN] = remN;
18106 if (state.inputs.abo.isValid()) state.inputs.abo = abo;
18107 if (state.inputs.ao.isValid()) state.inputs.ao = ao;
18108 if (state.inputs.bo.isValid()) state.inputs.bo = bo;
18109 if (state.inputs.alpha_real.isValid()) {
18110 state.inputs.alpha_real = alpha;
18111 if (!problem.alpha_real.fixed()) problem.alpha_real = alpha;
18112 }
18113 if (state.inputs.beta_real.isValid()) {
18114 state.inputs.beta_real = beta;
18115 if (!problem.beta_real.fixed()) problem.beta_real = beta;
18116 }
18117 if (state.remFusedStorage.isValid()) {
18118 state.remFusedStorage = remFusedStorage;
18119 state.remaindersFused[LoopM] = state.remainders[LoopM];
18120 state.remaindersFused[LoopN] = state.remainders[LoopN];
18121 state.remaindersFused[strategy.fusedLoop] = remFusedStorage;
18122 }
18123 if (state.diagC.isValid()) state.diagC = diagC;
18124 if (state.effCO.isValid()) state.effCO = effCO;
18125 if (state.fusedGEMM.slotA.isValid()) {
18126 state.fusedGEMM.slotA = slotAB.uw(0);
18127 state.fusedGEMM.slotB = slotAB.uw(1);
18128 }
18129 if (problem.hasBinaryPostOp()) {
18130 state.i0 = saveI0;
18131 state.j0 = saveJ0;
18132 }
18133
18134 state.ra.claim(C_ptr);
18135
18136 // Set up C internal layout and registers.
18137 state.C_regs.resize(1);
18138 state.C_regs[0] = C_regs;
18139 state.C_layout.clear();
18140 state.C_layout.reserve((unrollM / 8) * (unrollN / 4));
18141 for (int j0 = 0; j0 < unrollN; j0 += 4) {
18142 for (int i0 = 0; i0 < unrollM; i0 += 8) {
18143 RegisterBlock block;
18144 block.log2GRFBytes = GRF::log2Bytes(hw);
18145 block.colMajor = true;
18146 block.splitComplex = false;
18147 block.cxComponent = RegisterBlock::Interleaved;
18148 block.nr = block.ld = 8;
18149 block.nc = 4;
18150 block.component = 0;
18151 block.offsetR = i0;
18152 block.offsetC = j0;
18153 block.crosspack = 1;
18154 block.bytes = 8 * 4 * problem.Tc.size();
18155 block.simdSize = 0;
18156
18157 int j0Interleaved = j0 << 1;
18158 if (j0Interleaved >= unrollN) j0Interleaved += 4 - unrollN;
18159
18160 block.offsetBytes
18161 = (accStride * i0 / 8 + j0Interleaved) * GRF::bytes(hw);
18162 state.C_layout.push_back(block);
18163 }
18164 }
18165
18166 // Set up C external layout.
18167 state.copyC = true;
18168 bool remM_Ce, remN_Ce;
18169 getCRemainders(problem, strategy, remM_Ce, remN_Ce);
18170
18171 if (!getRegLayout(problem.Tc_ext, state.C_layoutExt, unrollM, unrollN,
18172 remM_Ce, remN_Ce, true, false, 0, 0, problem.C,
18173 state.Cext_strategy))
18174 return false;
18175 if (remM_Ce || remN_Ce)
18176 (void)getRegLayout(problem.Tc_ext, state.C_layoutExtUnmasked, unrollM,
18177 unrollN, false, false, true, false, 0, 0, problem.C,
18178 state.Cext_strategy);
18179
18180 if (state.r0_info != acc0.ud()) mov<uint32_t>(8, state.r0_info, acc0);
18181
18182 return true; // Success!
18183}
18184
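// Main SLM-pipelined k loop for the fixed systolic kernel. The steady-state
// loop overlaps global loads (L), SLM stores (S), and dpasw multiplies (M) on
// rotating buffers, with split barriers (signal/wait) keeping producers and
// consumers of each SLM buffer in step.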
18185template <HW hw>
18186void gemm_kernel_generator_t<hw>::sysgemmKLoop(const GEMMProblem &problem,
18187 const GEMMStrategy &strategy, GEMMState &state) {
18188 using namespace sysgemm;
18189 Label top, bottom, skipMain, remTop, remBottom;
18190
18191 auto nbBarrierWait = [&]() {
18192 if (!strategy.slmAltBarriers) barrierwait();
18193 };
18194 auto nbStoreSignal = [&](bool forceFence = false) {
18195 if (!strategy.slmAltBarriers)
18196 sysgemmStoreSignal(problem, strategy, state, forceFence);
18197 };
18198 auto storeSignal = [&](bool forceFence = false) {
18199 sysgemmStoreSignal(problem, strategy, state, forceFence);
18200 };
18201 auto copyLoad = [&](int storeBuffer, bool useC = false) {
18202 sysgemmCopyLoad(problem, strategy, state, storeBuffer, useC);
18203 };
18204 auto copyStore = [&](int storeBuffer, bool first = false) {
18205 sysgemmCopyStore(problem, strategy, state, storeBuffer, first);
18206 };
18207 auto multiply = [&](int buffer, bool lastMultiply = false) {
18208 sysgemmMultiply(problem, strategy, state, buffer, lastMultiply);
18209 };
18210
18211 bool oldDefaultAutoSWSB = getDefaultAutoSWSB();
18212 setDefaultAutoSWSB(false);
18213
18214 if (strategy.slmCopies == 1) {
18215 cmp(1 | lt | f1[1], kCounter, 3);
18216 add(1 | le | f0[1], kCounter, kCounter, -5);
18217
18218 jmpi(1 | f1[1], skipMain);
18219
18220 copyLoad(0, true); // L0 -> C
18221 copyLoad(1); // L1
18222 copyStore(0, true); // S0 <- C
18223 storeSignal(true); // Signal 0 ready
18224 zeroMatrix(C_regs, strategy);
18225 sync.nop(SWSB<AllPipes>(1));
18226 copyStore(1); // S1
18227
18228 nbBarrierWait(); // Wait 0 ready
18229 nbStoreSignal(); // Signal 1 ready
18230
18231 jmpi(1 | f0[1], bottom); // Zero-trip loop check
18232
18233 mark(top);
18234 add(1 | gt | f0[1], kCounter, kCounter, -3);
18235
18236 copyLoad(2); // L2
18237 multiply(0); // M0
18238 nbBarrierWait(); // Wait 0 ready
18239 copyStore(2); // S2
18240 nbStoreSignal(); // Signal 2 ready
18241
18242 copyLoad(0); // L0
18243 multiply(1); // M1
18244 nbBarrierWait(); // Wait 2 ready
18245 copyStore(0); // S0
18246 nbStoreSignal(); // Signal 0 ready
18247
18248 copyLoad(1); // L1
18249 multiply(2); // M2
18250 nbBarrierWait(); // Wait 0 ready
18251 copyStore(1); // S1
18252 nbStoreSignal(); // Signal 1 ready
18253
18254 jmpi(1 | f0[1], top);
18255 mark(bottom);
18256
18257 copyLoad(2); // L2
18258 multiply(0); // M0
18259 nbBarrierWait(); // Wait 1 ready
18260 copyStore(2); // S2
18261 nbStoreSignal(); // Signal 2 ready
18262
18263 multiply(1); // M1
18264
18265 nbBarrierWait(); // Wait 2 ready
18266
18267 multiply(2, true); // M2
18268
18269 add(1 | le | f0[1], kCounter, kCounter, 2);
18270 jmpi(1 | f0[1], remBottom);
18271 jmpi(1, remTop);
18272
18273 mark(skipMain);
18274
18275 zeroMatrix(C_regs, strategy);
18276 add(1, kCounter, kCounter, 5);
18277
18278 mov(2, slmAOffsetStore(1), slmAOffsetStoreInit(1));
18279 sync.nop(SWSB<AllPipes>(1));
18280
18281 mark(remTop);
18282
18283 cmp(1 | lt | f0[1], kCounter, 2);
18284 copyLoad(0);
18285 copyStore(0);
18286 storeSignal(true);
18287 nbBarrierWait();
18288 multiply(0, true);
18289
18290 jmpi(1 | f0[1], remBottom);
18291 copyLoad(1);
18292 copyStore(1);
18293 storeSignal(true);
18294 nbBarrierWait();
18295 multiply(1, true);
18296
18297 mark(remBottom);
18298 } else if (strategy.slmCopies == 3) {
18299 // Triple-buffered global memory load + SLM pipeline.
18300 cmp(1 | lt | f1[1], kCounter, 4);
18301 add(1 | le | f0[1], kCounter, kCounter, -6);
18302
18303 jmpi(1 | f1[1], skipMain);
18304
18305 copyLoad(0); // L0
18306 copyLoad(1); // L1
18307 copyLoad(2); // L2
18308 copyStore(0, true); // S0
18309 storeSignal(true); // Signal 0 ready
18310 zeroMatrix(C_regs, strategy);
18311 copyLoad(0); // L0
18312 sync.nop(SWSB<uint32_t>(1));
18313 copyStore(1); // S1
18314
18315 nbBarrierWait(); // Wait 0 ready
18316 nbStoreSignal(); // Signal 1 ready
18317
18318 jmpi(1 | f0[1], bottom); // Zero-trip loop check
18319
18320 mark(top);
18321 add(1 | gt | f0[1], kCounter, kCounter, -3);
18322
18323 copyLoad(1); // L1
18324 multiply(0); // M0
18325 nbBarrierWait(); // Wait 0 ready
18326 copyStore(2); // S2
18327 nbStoreSignal(); // Signal 2 ready
18328
18329 copyLoad(2); // L2
18330 multiply(1); // M1
18331 nbBarrierWait(); // Wait 2 ready
18332 copyStore(0); // S0
18333 nbStoreSignal(); // Signal 0 ready
18334
18335 copyLoad(0); // L0
18336 multiply(2); // M2
18337 nbBarrierWait(); // Wait 0 ready
18338 copyStore(1); // S1
18339 nbStoreSignal(); // Signal 1 ready
18340
18341 jmpi(1 | f0[1], top);
18342 mark(bottom);
18343
18344 multiply(0); // M0
18345 nbBarrierWait(); // Wait 1 ready
18346 copyStore(2); // S2
18347 nbStoreSignal(); // Signal 2 ready
18348
18349 multiply(1); // M1
18350 nbBarrierWait(); // Wait 2 ready
18351 copyStore(0); // S0
18352 nbStoreSignal(); // Signal 0 ready
18353
18354 multiply(2); // M2
18355
18356 nbBarrierWait(); // Wait 0 ready
18357
18358 multiply(0, true); // M0
18359
18360 add(1 | le | f0[1], kCounter, kCounter, 2);
18361 jmpi(1 | f0[1], remBottom);
18362 jmpi(1, remTop);
18363
18364 mark(skipMain);
18365
18366 zeroMatrix(C_regs, strategy);
18367 add(1 | le | f0[1], kCounter, kCounter, 5);
18368
18369 mov(2, slmAOffsetStore(1), slmAOffsetStoreInit(1));
18370 sync.nop(SWSB<uint32_t>(1));
18371
18372 copyLoad(0);
18373 copyStore(0);
18374 storeSignal(true);
18375 nbBarrierWait();
18376 multiply(0, true);
18377
18378 jmpi(1 | f0[1], remBottom);
18379
18380 mark(remTop);
18381
18382 cmp(1 | lt | f0[1], kCounter, 2);
18383
18384 copyLoad(1);
18385 copyStore(1);
18386 storeSignal(true);
18387 nbBarrierWait();
18388 multiply(1, true);
18389
18390 jmpi(1 | f0[1], remBottom);
18391
18392 copyLoad(2);
18393 copyStore(2);
18394 storeSignal(true);
18395 nbBarrierWait();
18396 multiply(2, true);
18397
18398 mark(remBottom);
18399 } else
18400 stub();
18401
18402 sync.allwr();
18403 setDefaultAutoSWSB(oldDefaultAutoSWSB);
18404}
18405
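// Four-buffer k loop used when the workgroup is doubled in M (wgM == 8).
// B traffic is shared between buffer pairs (hence the even/odd B variants),
// and remainder iterations are guarded by flag registers rather than by
// separate code paths.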
18406template <HW hw>
18407void gemm_kernel_generator_t<hw>::sysgemmKLoop4(const GEMMProblem &problem,
18408 const GEMMStrategy &strategy, GEMMState &state, bool oddB) {
18409 using namespace sysgemm;
18410 auto &depAddr = state.sysgemm.depAddr;
18411
18412 Label top, bottom, skipMain, done;
18413 Label skipLoad0, skipLoad1, skipLoad2;
18414 Label skipStore0, skipStore1, skipStore2;
18415 Label sskipLoad12, sskipStore1, sskipStore2Load3, sskipStore3;
18416
18417 auto clearDepAddr = [&]() {
18418 for (int i = 0; i < 4; i++)
18419 depAddr[i] = InstructionModifier();
18420 };
18421 auto storeSignal
18422 = [&]() { sysgemmStoreSignal(problem, strategy, state, true); };
18423 auto copyLoad = [&](int storeBuffer, int useC = 0, bool forceLoadB = false,
18424 RegData flagLoadB = RegData()) {
18425 sysgemmCopyLoad4(problem, strategy, state, storeBuffer,
18426 ((storeBuffer & 1) != oddB) || forceLoadB, useC, flagLoadB);
18427 };
18428 auto copyLoadRem = [&](int storeBuffer, RegData flagLoadB) {
18429 sysgemmCopyLoad4(problem, strategy, state, storeBuffer,
18430 (storeBuffer & 1) != oddB, 0, flagLoadB);
18431 };
18432 auto copyStore = [&](int storeBuffer, int useC = 0, int useC_B = 0) {
18433 sysgemmCopyStore4(problem, strategy, state, storeBuffer,
18434 (storeBuffer & 1) == oddB, useC, useC_B);
18435 };
18436 auto multiply = [&](int buffer, bool firstMultiply = false) {
18437 sysgemmMultiply4(problem, strategy, state, buffer, firstMultiply);
18438 };
18439 auto multiplyRem = [&](int buffer, RegData flagWaitLoad, RegData flagSignal,
18440 bool firstMultiply = false) {
18441 sysgemmMultiply4(problem, strategy, state, buffer, firstMultiply,
18442 flagWaitLoad, flagSignal, &done);
18443 };
18444 auto slmRead = [&]() {
18445 mov(1 | depAddr[0], addr0.ud(2), slmAOffsetLoad);
18446 mov(1 | depAddr[1], addr1.ud(2), slmBOffsetLoad);
18447 add(1 | depAddr[2], addr2.ud(2), slmBOffsetLoad, 8 * 32 / 16);
18448 add(1 | depAddr[3], addr3.ud(2), slmBOffsetLoad, 16 * 32 / 16);
18449
18450 load(16 | SWSB<AllPipes>(sb3, 4), A_regs[0], block_oword(16), SLM,
18451 addr0);
18452 load(16 | SWSB<AllPipes>(sb0, 3), B_regs[0], block_oword(16), SLM,
18453 addr1);
18454 load(16 | SWSB<AllPipes>(sb1, 2), B_regs[8], block_oword(16), SLM,
18455 addr2);
18456 load(16 | SWSB<AllPipes>(sb2, 1), B_regs[16], block_oword(16), SLM,
18457 addr3);
18458 depAddr[0] = sb3.src;
18459 depAddr[1] = sb0.src;
18460 depAddr[2] = sb1.src;
18461 depAddr[3] = sb2.src;
18462
18463 add(1 | depAddr[0], addr0.ud(2), slmAOffsetLoad, 8 * 32 / 16);
18464 load(16 | SWSB<AllPipes>(sb4, 1), A_regs[8], block_oword(16), SLM,
18465 addr0);
18466 depAddr[0] = sb4.src;
18467 };
18468
18469 bool oldDefaultAutoSWSB = getDefaultAutoSWSB();
18470 setDefaultAutoSWSB(false);
18471
18472 clearDepAddr();
18473 mov(1, f1.ud(), 0);
18474 mov(1, f0.ud(), 0);
18475 cmp(1 | lt | f1[1], kCounter, 4);
18476 add(1 | le | f0[1], kCounter, kCounter, -7);
18477
18478 jmpi(1 | f1[1], skipMain);
18479
18480 status << "Main path, " << (oddB ? "odd B" : "even B")
18481 << status_stream::endl;
18482
18483 copyLoad(0, 1, true); // L0 -> C1
18484 copyLoad(1, 2); // L1 -> C2
18485 copyLoad(2); // L2
18486 copyStore(0, 1, 1); // S0 <- C1
18487 storeSignal();
18488 copyStore(1, 2, 1); // S1 <- C2
18489 barrierwait();
18490 slmRead();
18491 storeSignal();
18492 copyStore(2, 0, 2); // S2
18493 if (!oddB) sync.allrd(0x3000);
18494 zeroMatrix(C_regs, strategy);
18495 sync.allrd(SWSB<AllPipes>(1));
18496
18497 jmpi(1 | f0[1], bottom); // Zero-trip loop check
18498
18499 depAddr[0] = sb8.src;
18500 depAddr[1] = !oddB ? sb9.src : sb0.src;
18501 depAddr[2] = !oddB ? sb10.src : sb4.src;
18502 depAddr[3] = sb3.src;
18503
18504 mark(top);
18505 add(1 | gt | f0[1], kCounter, kCounter, -4);
18506
18507 copyLoad(3);
18508 multiply(0);
18509 copyStore(3);
18510
18511 copyLoad(0);
18512 multiply(1);
18513 copyStore(0);
18514
18515 copyLoad(1);
18516 multiply(2);
18517 copyStore(1);
18518
18519 copyLoad(2);
18520 multiply(3);
18521 copyStore(2);
18522
18523 jmpi(1 | f0[1], top);
18524 mark(bottom);
18525
18526 cmp(1 | gt | f0[0], kCounter, -4 + 1);
18527 cmp(1 | gt | f0[1], kCounter, -4 + 2);
18528 cmp(1 | gt | f1[0], kCounter, -4 + 3);
18529
18530 copyLoadRem(3, f0[0]);
18531 multiply(0);
18532 copyStore(3);
18533
18534 sync.allrd();
18535 jmpi(1 | ~f0[0], skipLoad0);
18536 copyLoadRem(0, f0[1]);
18537 mark(skipLoad0);
18538 multiply(1);
18539 sync.allrd();
18540 jmpi(1 | ~f0[0], skipStore0);
18541 copyStore(0);
18542 mark(skipStore0);
18543
18544 sync.allrd();
18545 jmpi(1 | ~f0[1], skipLoad1);
18546 copyLoadRem(1, f1[0]);
18547 mark(skipLoad1);
18548 multiplyRem(2, FlagRegister(), f0[0]);
18549 sync.allrd();
18550 jmpi(1 | ~f0[1], skipStore1);
18551 copyStore(1);
18552 mark(skipStore1);
18553
18554 sync.allrd();
18555 jmpi(1 | ~f1[0], skipLoad2);
18556 copyLoadRem(2, null);
18557 mark(skipLoad2);
18558 multiplyRem(3, f0[0], f0[1]);
18559 sync.allrd();
18560 jmpi(1 | ~f1[0], skipStore2);
18561 copyStore(2);
18562 mark(skipStore2);
18563
18564 multiplyRem(0, f0[1], f1[0]);
18565 multiplyRem(1, f1[0], null);
18566 multiplyRem(2, null, null);
18567
18568 jmpi(1, done);
18569
18570 status << "Small-k path, " << (oddB ? "odd B" : "even B")
18571 << status_stream::endl;
18572
18573 clearDepAddr();
18574 mark(skipMain);
18575
18576 // Short loops: special case for 1-4 unrolls
18577 cmp(1 | gt | f0[0], kCounter, -7 + 1);
18578 cmp(1 | gt | f0[1], kCounter, -7 + 2);
18579 cmp(1 | gt | f1[0], kCounter, -7 + 3);
18580
18581 auto flagLoadB0 = oddB ? f0[0] : FlagRegister();
18582 copyLoad(0, 1, true, flagLoadB0);
18583 sync.allrd();
18584 jmpi(1 | ~f0[0], sskipLoad12);
18585 copyLoad(1, 2, false, f0[1]);
18586 sync.allrd();
18587 jmpi(1 | ~f0[1], sskipLoad12);
18588 copyLoadRem(2, f1[0]);
18589 mark(sskipLoad12);
18590 copyStore(0, 1, 1);
18591 storeSignal();
18592 sync.allrd();
18593 jmpi(1 | ~f0[0], sskipStore1);
18594 copyStore(1, 2, 1);
18595 mark(sskipStore1);
18596 barrierwait();
18597 slmRead();
18598 sync.allrd();
18599 jmpi(1 | ~f0[0], sskipStore2Load3);
18600 storeSignal();
18601 sync.allrd();
18602 jmpi(1 | ~f0[1], sskipStore2Load3);
18603 copyStore(2, 0, 2);
18604 sync.allrd();
18605 jmpi(1 | ~f1[0], sskipStore2Load3);
18606 copyLoadRem(3, null);
18607 mark(sskipStore2Load3);
18608 multiplyRem(0, f0[0], f0[1], true);
18609 jmpi(1 | ~f0[0], done);
18610 sync.allrd();
18611 jmpi(1 | ~f1[0], sskipStore3);
18612 copyStore(3);
18613 mark(sskipStore3);
18614 multiplyRem(1, f0[1], f1[0]);
18615 multiplyRem(2, f1[0], null);
18616 multiplyRem(3, null, null);
18617
18618 mark(done);
18619
18620 sync.allwr();
18621 setDefaultAutoSWSB(oldDefaultAutoSWSB);
18622}
18623
18624template <HW hw>
18625void gemm_kernel_generator_t<hw>::sysgemmStoreSignal(const GEMMProblem &problem,
18626 const GEMMStrategy &strategy, GEMMState &state, bool forceFence) {
18627 using namespace sysgemm;
18628 auto &depAddr = state.sysgemm.depAddr;
18629
18630 if (!strategy.slmAltBarriers || forceFence) {
18631 // Signal SLM data ready once memory fence returns, asynchronously
18632 sync.nop(depAddr[0]);
18633 sysgemmBarrierPrep(depAddr[3], addr3);
18634
18635 slmfence(SWSB<AllPipes>(sb15, 1), addr0);
18636 barriermsg(sb15, addr3);
18637 depAddr[0] = InstructionModifier();
18638 depAddr[3] = sb15.src;
18639 } else {
18640 sysgemmBarrierPrep(depAddr[3], addr3);
18641 barriermsg(SWSB<AllPipes>(sb15, 1), addr3);
18642 depAddr[3] = sb15.src;
18643 }
18644}
18645
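// Load the next A/B panels from global memory into the staging (or C)
// registers and advance the global pointers, using stateless 64-bit,
// emulated 64-bit, or surface addressing as appropriate.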
18646template <HW hw>
18647void gemm_kernel_generator_t<hw>::sysgemmCopyLoad(const GEMMProblem &problem,
18648 const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
18649 bool useC) {
18650 using namespace sysgemm;
18651 auto &depAddr = state.sysgemm.depAddr;
18652
18653 bool surface = !strategy.A.base.isStateless();
18654 bool emulate64 = strategy.emulate.emulate64;
18655 int unrollM = strategy.unroll[LoopM];
18656 int unrollN = strategy.unroll[LoopN];
18657
18658 auto A_ptr = A_ptr64, B_ptr = B_ptr64;
18659 if (surface) {
18660 A_ptr = A_ptr.ud();
18661 B_ptr = B_ptr.ud();
18662 emulate64 = false;
18663 }
18664
18665 // Load new A and B and increment load pointers.
18666 if (surface) {
18667 sync(SyncFunction::nop, SWSB<uint32_t>(1));
18668 mov(1 | depAddr[0], addr0.ud(0), A_ptr);
18669 mov(1 | depAddr[1], addr1.ud(0), B_ptr);
18670 add(1 | depAddr[2], addr2.ud(0), B_ptr, 8 * 32);
18671 } else if (!emulate64) {
18672 sync(SyncFunction::nop, SWSB<uint64_t>(1));
18673 mov(1 | depAddr[0], addr0.uq(0), A_ptr);
18674 mov(1 | depAddr[1], addr1.uq(0), B_ptr);
18675 add(1 | depAddr[2], addr2.uq(0), B_ptr, 8 * 32);
18676 } else {
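        // Emulated 64-bit addressing: add the low dwords, capturing overflow
        // in f1.1, then fold the carry into the high dword.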
18677 sync(SyncFunction::nop, SWSB<uint32_t>(1));
18678 mov(1 | depAddr[2], addr2.ud(1), B_ptr.ud(1));
18679 add(1 | ov | f1[1], addr2.ud(0), B_ptr.ud(0), 8 * 32);
18680 mov(2 | depAddr[0], addr0.ud(0)(1), A_ptr.ud()(1));
18681 mov(2 | depAddr[1], addr1.ud(0)(1), B_ptr.ud()(1));
18682 add(1 | f1[1] | SWSB(4), addr2.ud(1), addr2.ud(1), 1);
18683 }
18684
18685 if (useC) {
18686 if (surface) {
18687 load(1 | SWSB<AllPipes>(sb11, 3), C_regs[0],
18688 D64T(32) | strategy.A.cachingR, strategy.A.base, addr0);
18689 load(1 | SWSB<AllPipes>(sb12, 2), C_regs[8],
18690 D64T(32) | strategy.B.cachingR, strategy.B.base, addr1);
18691 if (strategy.unroll[LoopN] > 32)
18692 load(1 | SWSB<AllPipes>(sb13, 1), C_regs[16],
18693 D64T(16) | strategy.B.cachingR, strategy.B.base, addr2);
18694 } else {
18695 load(16 | SWSB<AllPipes>(sb11, 3), C_regs[0], block_hword(8), A64,
18696 addr0);
18697 load(16 | SWSB<AllPipes>(sb12, 2), C_regs[8], block_hword(8), A64,
18698 addr1);
18699 if (strategy.unroll[LoopN] > 32)
18700 load(16 | SWSB<AllPipes>(sb13, 1), C_regs[16], block_hword(4),
18701 A64, addr2);
18702 }
18703 depAddr[0] = sb11.src;
18704 depAddr[1] = sb12.src;
18705 if (strategy.unroll[LoopN] > 32) depAddr[2] = sb13.src;
18706 if (strategy.simulation) sync.allrd(0x3000);
18707 } else {
        // These dependencies are stronger than necessary; the load could
        // start as soon as the previous store's inputs have been read.
18709 int loadBuffer = (strategy.slmCopies == 3) ? storeBuffer : 0;
18710 int t0 = 8 + loadBuffer * 2;
18711 SBID token0 {t0}, token1 {t0 + 1}, token2 {t0 + 2};
18712
18713 if (surface) {
18714 load(1 | SWSB<AllPipes>(token0, 3), A_copy[loadBuffer][0],
18715 D64T(32) | strategy.A.cachingR, strategy.A.base, addr0);
18716 load(1 | SWSB<AllPipes>(token1, 2), B_copy[loadBuffer][0],
18717 D64T(32) | strategy.B.cachingR, strategy.B.base, addr1);
18718 if (strategy.unroll[LoopN] > 32)
18719 load(1 | SWSB<AllPipes>(token2, 1), B_copy[loadBuffer][8],
18720 D64T(16) | strategy.B.cachingR, strategy.B.base, addr2);
18721 } else {
18722 load(16 | SWSB<AllPipes>(token0, 3), A_copy[loadBuffer][0],
18723 block_hword(8), A64, addr0);
18724 load(16 | SWSB<AllPipes>(token1, 2), B_copy[loadBuffer][0],
18725 block_hword(8), A64, addr1);
18726 if (strategy.unroll[LoopN] > 32)
18727 load(16 | SWSB<AllPipes>(token2, 1), B_copy[loadBuffer][8],
18728 block_hword(4), A64, addr2);
18729 }
18730 depAddr[0] = token0.src;
18731 depAddr[1] = token1.src;
18732 if (strategy.unroll[LoopN] > 32) depAddr[2] = token2.src;
18733 if (strategy.simulation) sync.allrd(0x6 << t0);
18734 }
18735
18736 if (!emulate64) {
18737 add(1 | SWSB(3), A_ptr, A_ptr, unrollM * 32);
18738 add(1 | SWSB(3), B_ptr, B_ptr, unrollN * 32);
18739 } else {
18740 add(1 | ov | f1[0] | SWSB(3), A_ptr.ud(0), A_ptr.ud(0), unrollM * 32);
18741 add(1 | ov | f1[1] | SWSB(3), B_ptr.ud(0), B_ptr.ud(0), unrollN * 32);
18742 add(1 | f1[0], A_ptr.ud(1), A_ptr.ud(1), 1);
18743 add(1 | f1[1], B_ptr.ud(1), B_ptr.ud(1), 1);
18744 }
18745}
18746
18747template <HW hw>
18748void gemm_kernel_generator_t<hw>::sysgemmCopyLoad4(const GEMMProblem &problem,
18749 const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
18750 bool loadB, int useC, RegData flagLoadB) {
18751 using namespace sysgemm;
18752 auto &depAddr = state.sysgemm.depAddr;
18753
18754 bool emulate64 = strategy.emulate.emulate64;
18755 bool surface = !strategy.A.base.isStateless();
18756 int unrollM = strategy.unroll[LoopM];
18757 int unrollN = strategy.unroll[LoopN];
18758 int x48 = (unrollN > 32);
18759 InstructionModifier loadBMod;
18760 loadB &= !(flagLoadB.isValid() && flagLoadB.isNull());
18761 if (flagLoadB.isValid())
18762 loadBMod = loadBMod | static_cast<FlagRegister &>(flagLoadB) | any16h;
18763
18764 // Load new A and B and increment load pointers.
18765 auto A_ptr = A_ptr64, B_ptr = B_ptr64;
18766 if (surface) {
18767 A_ptr = A_ptr.ud();
18768 B_ptr = B_ptr.ud();
18769 emulate64 = false;
18770 }
18771
18772 if (surface) {
18773 sync.nop(SWSB(Pipe::I, 1));
18774 mov(1 | depAddr[0], addr0.ud(0), A_ptr);
18775 if (loadB) {
18776 mov(1 | depAddr[1], addr1.ud(0), B_ptr);
18777 add(1 | depAddr[2], addr2.ud(0), B_ptr, 8 * 32);
18778 }
18779 } else if (!emulate64) {
18780 sync.nop(SWSB(Pipe::L, 1));
18781 mov(1 | depAddr[0], addr0.uq(0), A_ptr);
18782 if (loadB) {
18783 mov(1 | depAddr[1], addr1.uq(0), B_ptr);
18784 add(1 | depAddr[2], addr2.uq(0), B_ptr, 8 * 32);
18785 }
18786 } else {
18787 sync.nop(SWSB(Pipe::I, 1));
18788 if (loadB) {
18789 mov(1 | depAddr[2], addr2.ud(1), B_ptr.ud(1));
18790 add(1 | ov | f1[1], addr2.ud(0), B_ptr.ud(0), 8 * 32);
18791 }
18792 mov(2 | depAddr[0], addr0.ud(0)(1), A_ptr.ud()(1));
18793 if (loadB) {
18794 mov(2 | depAddr[1], addr1.ud(0)(1), B_ptr.ud()(1));
18795 add(1 | f1[1], addr2.ud(1), addr2.ud(1), 1);
18796 }
18797 }
18798
18799 SBID tokenA(0), tokenB0(0), tokenB1(0);
18800 GRF dstA, dstB0, dstB1;
18801
18802 if (useC) {
18803 tokenA = SBID((useC == 1) ? 5 : 11);
18804 tokenB0 = SBID((useC == 1) ? 6 : 12);
18805 tokenB1 = SBID((useC == 1) ? 7 : 13);
18806 int C_off = (useC == 1) ? 0 : 20;
18807 dstA = C_regs[C_off + 0];
18808 dstB0 = C_regs[C_off + 8];
18809 dstB1 = C_regs[C_off + 16];
18810 } else {
        // These dependencies are stronger than necessary; the load could
        // start as soon as the previous store's inputs have been read.
18812 int loadBuffer = (strategy.slmCopies == 3) ? storeBuffer : 0;
18813 int t0 = 8 + loadBuffer * 2;
18814 tokenA = SBID(t0 + 0);
18815 tokenB0 = SBID(t0 + 1);
18816 tokenB1 = SBID(t0 + 2);
18817 dstA = A_copy[loadBuffer][0];
18818 dstB0 = B_copy[loadBuffer][0];
18819 dstB1 = B_copy[loadBuffer][8];
18820 }
18821
18822 if (surface) {
18823 load(1 | tokenA | SWSB<AllPipes>(1 + loadB * (1 + x48)), dstA,
18824 D64T(32) | strategy.A.cachingR, strategy.A.base, addr0);
18825 if (loadB) {
18826 load(1 | tokenB0 | loadBMod | SWSB<AllPipes>(1 + x48), dstB0,
18827 D64T(32) | strategy.B.cachingR, strategy.B.base, addr1);
18828 if (x48)
18829 load(1 | tokenB1 | loadBMod | SWSB<AllPipes>(1), dstB1,
18830 D64T(16) | strategy.B.cachingR, strategy.B.base, addr2);
18831 }
18832 } else {
18833 load(16 | tokenA | SWSB<AllPipes>(1 + loadB * (1 + x48)), dstA,
18834 block_hword(8), A64, addr0);
18835 if (loadB) {
18836 load(16 | tokenB0 | loadBMod | SWSB<AllPipes>(1 + x48), dstB0,
18837 block_hword(8), A64, addr1);
18838 if (x48)
18839 load(16 | tokenB1 | loadBMod | SWSB<AllPipes>(1), dstB1,
18840 block_hword(4), A64, addr2);
18841 }
18842 }
18843 depAddr[0] = tokenA.src;
18844 if (loadB) {
18845 depAddr[1] = tokenB0.src;
18846 if (x48) depAddr[2] = tokenB1.src;
18847 }
18848 if (strategy.simulation) {
18849 uint16_t tmask = (1 << tokenA.getID());
18850 if (loadB) tmask |= (1 << tokenB0.getID()) | (1 << tokenB1.getID());
18851 sync.allrd(tmask);
18852 }
18853
18854 if (!emulate64) {
18855 add(1, A_ptr, A_ptr, unrollM * 32);
18856 if (loadB) add(1, B_ptr, B_ptr, 2 * unrollN * 32);
18857 } else {
18858 add(1 | ov | f1[1], A_ptr.ud(0), A_ptr.ud(0), unrollM * 32);
18859 if (loadB)
18860 add(1 | ov | f1[1] | M8, B_ptr.ud(0), B_ptr.ud(0),
18861 2 * unrollN * 32);
18862 add(1 | f1[1], A_ptr.ud(1), A_ptr.ud(1), 1);
18863 if (loadB) add(1 | f1[1] | M8, B_ptr.ud(1), B_ptr.ud(1), 1);
18864 }
18865}
18866
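// Store the staged A/B panels to SLM and advance the store offsets, wrapping
// back to the first buffer after the last one.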
18867template <HW hw>
18868void gemm_kernel_generator_t<hw>::sysgemmCopyStore(const GEMMProblem &problem,
18869 const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
18870 bool first) {
18871 using namespace sysgemm;
18872 auto &depAddr = state.sysgemm.depAddr;
18873
18874 auto aoffset = first ? slmAOffsetStoreInit : slmAOffsetStore;
18875 auto boffset = first ? slmBOffsetStoreInit : slmBOffsetStore;
18876
18877 // Store A and B and advance store pointers to next buffer.
18878 mov(1 | depAddr[0], addr0.ud(2), aoffset);
18879 mov(1 | depAddr[1], addr1.ud(2), boffset);
18880 add(1 | depAddr[2], addr2.ud(2), boffset, 8 * 32 / 16);
18881
18882 if (first && strategy.slmCopies == 1) {
18883 store(16 | SWSB<AllPipes>(sb11, 3), block_oword(16), SLM, addr0,
18884 C_regs[0]);
18885 store(16 | SWSB<AllPipes>(sb12, 2), block_oword(16), SLM, addr1,
18886 C_regs[8]);
18887 if (strategy.unroll[LoopN] > 32)
18888 store(16 | SWSB<AllPipes>(sb13, 1), block_oword(8), SLM, addr2,
18889 C_regs[16]);
18890 depAddr[0] = sb11.src;
18891 depAddr[1] = sb12.src;
18892 if (strategy.unroll[LoopN] > 32) depAddr[2] = sb13.src;
18893 if (strategy.simulation) sync.allrd(0x3000);
18894 } else {
18895 int loadBuffer = (strategy.slmCopies == 3) ? storeBuffer : 0;
18896 int t0 = 8 + loadBuffer * 2;
18897 SBID token0 {t0}, token1 {t0 + 1}, token2 {t0 + 2};
18898
18899 store(16 | SWSB<AllPipes>(token0, 3), block_oword(16), SLM, addr0,
18900 A_copy[loadBuffer][0]);
18901 store(16 | SWSB<AllPipes>(token1, 2), block_oword(16), SLM, addr1,
18902 B_copy[loadBuffer][0]);
18903 if (strategy.unroll[LoopN] > 32)
18904 store(16 | SWSB<AllPipes>(token2, 1), block_oword(8), SLM, addr2,
18905 B_copy[loadBuffer][8]);
18906 depAddr[0] = token0.src;
18907 depAddr[1] = token1.src;
18908 if (strategy.unroll[LoopN] > 32) depAddr[2] = token2.src;
18909 if (strategy.simulation) sync.allrd(0x6 << t0);
18910 }
18911
18912 if (storeBuffer == 2)
18913 mov(2, slmAOffsetStore(1), slmAOffsetStoreInit(1));
18914 else
18915 add(2, slmAOffsetStore(1), aoffset(1),
18916 strategy.slmSysgemmBlockSize() / 16);
18917}
18918
18919template <HW hw>
18920void gemm_kernel_generator_t<hw>::sysgemmCopyStore4(const GEMMProblem &problem,
18921 const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
18922 bool storeB, int useC, int useC_B) {
18923 using namespace sysgemm;
18924 auto &depAddr = state.sysgemm.depAddr;
18925 bool first = (useC == 1);
18926 bool x48 = (strategy.unroll[LoopN] > 32);
18927
18928 auto aoffset = first ? slmAOffsetStoreInit : slmAOffsetStore;
18929 auto boffset = first ? slmBOffsetStoreInit : slmBOffsetStore;
18930
18931 // Store A and B and advance store pointers to next buffer.
18932 mov(1 | depAddr[0], addr0.ud(2), aoffset);
18933 if (storeB) {
18934 mov(1 | depAddr[1], addr1.ud(2), boffset);
18935 if (x48) add(1 | depAddr[2], addr2.ud(2), boffset, 8 * 32 / 16);
18936 }
18937
18938 int loadBuffer = (strategy.slmCopies == 3) ? storeBuffer : 0;
18939 int t0 = 8 + loadBuffer * 2;
18940 auto tokenA = SBID(t0 + 0);
18941 auto tokenB0 = SBID(t0 + 1);
18942 auto tokenB1 = SBID(t0 + 2);
18943 auto srcA = A_copy[loadBuffer][0];
18944 auto srcB0 = B_copy[loadBuffer][0];
18945 auto srcB1 = B_copy[loadBuffer][8];
18946
18947 if (useC) {
18948 tokenA = SBID((useC == 1) ? 5 : 11);
18949 int C_off = (useC == 1) ? 0 : 20;
18950 srcA = C_regs[C_off + 0];
18951 }
18952
18953 if (useC_B) {
18954 tokenB0 = SBID((useC_B == 1) ? 6 : 12);
18955 tokenB1 = SBID((useC_B == 1) ? 7 : 13);
18956 int C_off = (useC_B == 1) ? 0 : 20;
18957 srcB0 = C_regs[C_off + 8];
18958 srcB1 = C_regs[C_off + 16];
18959 }
18960
18961 store(16 | tokenA | SWSB<AllPipes>(1 + storeB * (1 + x48)), block_oword(16),
18962 SLM, addr0, srcA);
18963 if (storeB) {
18964 store(16 | tokenB0 | SWSB<AllPipes>(1 + x48), block_oword(16), SLM,
18965 addr1, srcB0);
18966 if (x48)
18967 store(16 | tokenB1 | SWSB<AllPipes>(1), block_oword(8), SLM, addr2,
18968 srcB1);
18969 }
18970
18971 depAddr[0] = tokenA.src;
18972 if (storeB) {
18973 depAddr[1] = tokenB0.src;
18974 if (x48) depAddr[2] = tokenB1.src;
18975 }
18976 if (strategy.simulation) {
18977 uint16_t tmask = (1 << tokenA.getID());
18978 if (storeB) tmask |= (1 << tokenB0.getID()) | (1 << tokenB1.getID());
18979 sync.allrd(tmask);
18980 }
18981
18982 if (storeBuffer == 3)
18983 mov(2, slmAOffsetStore(1), slmAOffsetStoreInit(1));
18984 else
18985 add(2, slmAOffsetStore(1), aoffset(1),
18986 strategy.slmSysgemmBlockSize() / 16);
18987}
18988
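// One multiply pass over a 32-row slice of C: load B and the first half of A
// from SLM, then stream the remaining quarters of A while issuing dpasw
// chunks, and finally advance the SLM load offsets to the next buffer.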
18989template <HW hw>
18990void gemm_kernel_generator_t<hw>::sysgemmMultiply(const GEMMProblem &problem,
18991 const GEMMStrategy &strategy, GEMMState &state, int buffer,
18992 bool lastMultiply) {
18993 using namespace sysgemm;
18994 auto &depAddr = state.sysgemm.depAddr;
18995
18996 // Load half of A (16x32) -- hopefully broadcast from SLM to this row -- and half of B, interleaved.
18997 InstructionModifier swsb = lastMultiply ? SWSB(1) : depAddr[0];
18998
18999 mov(1 | swsb, addr0.ud(2), slmAOffsetLoad);
19000 mov(1 | depAddr[1], addr1.ud(2), slmBOffsetLoad);
19001 add(1 | depAddr[2], addr2.ud(2), slmBOffsetLoad, 8 * 32 / 16);
19002 add(1 | depAddr[3], addr3.ud(2), slmBOffsetLoad, 16 * 32 / 16);
19003
19004 if (strategy.slmAltBarriers) barrierwait();
19005
19006 if (strategy.simulation) sync(SyncFunction::nop, SWSB<int64_t>(1));
19007 sync.nop(sb5.src);
19008 load(16 | SWSB<AllPipes>(sb3, 4), A_regs[0], block_oword(16), SLM, addr0);
19009 load(16 | SWSB<AllPipes>(sb0, 3), B_regs[0], block_oword(16), SLM, addr1);
19010 load(16 | SWSB<AllPipes>(sb1, 2), B_regs[8], block_oword(16), SLM, addr2);
19011 if (strategy.unroll[LoopN] > 32)
19012 load(16 | SWSB<AllPipes>(sb2, 1), B_regs[16], block_oword(16), SLM,
19013 addr3);
19014
19015 add(1 | sb3.src, addr0.ud(2), slmAOffsetLoad, 8 * 32 / 16);
19016 add(1 | sb0.src, addr1.ud(2), slmAOffsetLoad, 16 * 32 / 16);
19017 add(1 | sb1.src, addr2.ud(2), slmAOffsetLoad, 24 * 32 / 16);
19018 load(16 | SWSB<AllPipes>(sb4, 3), A_regs[8], block_oword(16), SLM, addr0);
19019
19020 // Wait for A data to load.
19021 sync.allwr(0x18);
19022
19023 if (strategy.slmAltBarriers && !lastMultiply) {
19024 sysgemmBarrierPrep(sb2.src, addr3);
19025 barriermsg(SWSB<AllPipes>(sb15, 1), addr3);
19026 }
19027
19028 // Rows 0-7
19029 sysgemmMultiplyChunk(
19030 problem, strategy, false, 0, 0, true, false, sb0.dst, sb3);
19031
19032 // Rows 8-15
19033 sysgemmMultiplyChunk(problem, strategy, false, 8, 8, false, false,
19034 InstructionModifier(), sb4);
19035
19036 // Load third quarter of A (8x32)
19037 load(16 | SWSB<AllPipes>(sb3, 2), A_regs[0], block_oword(16), SLM, addr1);
19038
19039 // Rows 16-23
19040 sysgemmMultiplyChunk(
19041 problem, strategy, false, 0, 16, false, false, sb3.dst);
19042
19043 // Load last quarter of A (8x32)
19044 load(16 | SWSB<AllPipes>(sb4, 1), A_regs[8], block_oword(16), SLM, addr2);
19045
19046 // Increment A and B to next buffer.
19047 swsb = strategy.simulation ? InstructionModifier(sb3.src)
19048 : InstructionModifier();
19049 if (buffer == 2)
19050 mov(2 | swsb, slmAOffsetLoad(1), slmAOffsetLoadInit(1));
19051 else
19052 add(2 | swsb, slmAOffsetLoad(1), slmAOffsetLoad(1),
19053 strategy.slmSysgemmBlockSize() / 16);
19054
19055 // Rows 24-31
19056 sysgemmMultiplyChunk(
19057 problem, strategy, false, 8, 24, false, false, sb4.dst, sb5);
19058
19059 // Remember dependencies for address registers.
19060 depAddr[0] = InstructionModifier {};
19061 depAddr[1] = sb3.src;
19062 depAddr[2] = sb4.src;
19063 depAddr[3] = strategy.slmAltBarriers ? sb15.src : sb2.src;
19064}
19065
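// Multiply pass for the four-buffer pipeline. The upper half of A is assumed
// already loaded; the lower half of A and the next buffer's B (plus upper A)
// are fetched mid-pass, and the split-barrier signal/wait is predicated by
// the cooldown flags during remainder iterations.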
19066template <HW hw>
19067void gemm_kernel_generator_t<hw>::sysgemmMultiply4(const GEMMProblem &problem,
19068 const GEMMStrategy &strategy, GEMMState &state, int buffer,
19069 bool firstMultiply, RegData flagWaitLoad, RegData flagSignal,
19070 Label *labelDone) {
19071 using namespace sysgemm;
19072 auto &depAddr = state.sysgemm.depAddr;
19073 bool x48 = (strategy.unroll[LoopN] > 32);
19074 uint16_t slmStride = strategy.slmSysgemmBlockSize() / 16;
19075 int16_t slmAdvance = ((buffer == 3) ? -3 : 1) * slmStride;
19076
19077 InstructionModifier loadMod {}, signalMod {};
19078 bool cooldownWaitLoad = flagWaitLoad.isValid();
19079 bool cooldownSignal = flagSignal.isValid();
19080 bool doWaitLoad = !cooldownWaitLoad || !flagWaitLoad.isNull();
19081 bool doSignal = !cooldownSignal || !flagSignal.isNull();
19082 auto fWaitLoad = static_cast<FlagRegister &>(flagWaitLoad);
19083 auto fSignal = static_cast<FlagRegister &>(flagSignal);
19084 if (doWaitLoad && cooldownWaitLoad) loadMod = loadMod | fWaitLoad | any16h;
19085 if (doSignal && cooldownSignal) signalMod = signalMod | fSignal | any16h;
19086
19087 // Fence.
19088 if (doSignal) {
19089 sync.nop(depAddr[0]);
19090 slmfence(sb15 | signalMod, addr0);
19091 depAddr[0] = sb15.dst;
19092 }
19093
19094 // Rows 0-7. Upper half of A (16x32) is already loaded.
19095 sync.nop(sb3.dst);
19096 depAddr[3] = InstructionModifier();
19097 sysgemmMultiplyChunk(
19098 problem, strategy, firstMultiply, 0, 0, true, false, sb0.dst, sb3);
19099
19100 // Prepare addresses for loading lower half of A, and part of B
19101 add(1 | depAddr[1], addr1.ud(2), slmAOffsetLoad, 16 * 32 / 16);
19102 add(1 | depAddr[2], addr2.ud(2), slmAOffsetLoad, 24 * 32 / 16);
19103 sysgemmBarrierPrep(depAddr[3], addr3);
19104
19105 // Rows 8-15.
19106 sysgemmMultiplyChunk(
19107 problem, strategy, firstMultiply, 8, 8, false, false, sb4.dst, sb4);
19108
19109 // Load lower half of A (16x32) -- hopefully broadcast from SLM to this row.
19110 load(16 | SWSB<AllPipes>(sb3, 3), A_regs[0], block_oword(16), SLM, addr1);
19111 load(16 | SWSB<AllPipes>(sb4, 2), A_regs[8], block_oword(16), SLM, addr2);
19112 depAddr[1] = sb3.src;
19113 depAddr[2] = sb4.src;
19114
19115 // Rows 16-23.
19116 sysgemmMultiplyChunk(problem, strategy, firstMultiply, 0, 16, false, false,
19117 sb3.dst, sb3);
19118 depAddr[1] = InstructionModifier();
19119
19120 // Address prep, part 2.
19121 add(1 | depAddr[1], addr1.ud(2), slmBOffsetLoad, slmAdvance + 0 * 32 / 16);
19122 add(1 | depAddr[2], addr2.ud(2), slmBOffsetLoad, slmAdvance + 8 * 32 / 16);
    if (x48) // consider moving this add after the next dpasw block
        add(1 | depAddr[0], addr0.ud(2), slmBOffsetLoad,
                slmAdvance + 16 * 32 / 16);
19128
19129 // Rows 24-31.
19130 sysgemmMultiplyChunk(
19131 problem, strategy, firstMultiply, 8, 24, false, true, sb4.dst, sb2);
19132
19133 if (doWaitLoad) {
19134 if (cooldownWaitLoad) jmpi(1 | ~fWaitLoad, *labelDone);
19135
19136 // Split barrier.
19137 barrierwait();
19138 if (doSignal) {
19139 barriermsg(SWSB<AllPipes>(sb15, x48 ? 4 : 3) | signalMod, addr3);
19140 depAddr[3] = sb15.src;
19141 }
19142
19143 // Load next B data and upper 16x32 of A.
19144 load(16 | SWSB<AllPipes>(sb0, x48 ? 3 : 2), B_regs[0], block_oword(16),
19145 SLM, addr1);
19146 load(16 | SWSB<AllPipes>(sb1, x48 ? 2 : 1), B_regs[8], block_oword(16),
19147 SLM, addr2);
19148 if (x48)
19149 load(16 | SWSB<AllPipes>(sb2, 1), B_regs[16], block_oword(16), SLM,
19150 addr0);
19151 depAddr[1] = sb0.src;
19152 depAddr[2] = sb1.src;
19153 if (x48) depAddr[0] = sb2.src;
19154
19155 add(1 | depAddr[3], addr3.ud(2), slmAOffsetLoad,
19156 slmAdvance + 0 * 32 / 16);
19157 add(1 | depAddr[2], addr2.ud(2), slmAOffsetLoad,
19158 slmAdvance + 8 * 32 / 16);
19159 InstructionModifier swsb;
19160 if (strategy.simulation) swsb = sb2.src;
19161 if (buffer == 3)
19162 mov(2 | swsb, slmAOffsetLoad(1), slmAOffsetLoadInit(1));
19163 else
19164 add(2 | swsb, slmAOffsetLoad(1), slmAOffsetLoad(1), slmAdvance);
19165
19166 load(16 | SWSB<AllPipes>(sb3, 3), A_regs[0], block_oword(16), SLM,
19167 addr3);
19168 load(16 | SWSB<AllPipes>(sb4, 2), A_regs[8], block_oword(16), SLM,
19169 addr2);
19170 depAddr[3] = sb3.src;
19171 depAddr[2] = sb4.src;
19172 }
19173}
19174
19175template <HW hw>
19176void gemm_kernel_generator_t<hw>::sysgemmMultiplyChunk(
19177 const GEMMProblem &problem, const GEMMStrategy &strategy, bool first,
19178 int ao, int i0, bool waitB, bool prepB,
19179 const InstructionModifier &swsb0, const InstructionModifier &swsbEnd) {
19180 using namespace sysgemm;
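    // Offset of this 8-row chunk within the C accumulators (48 GRFs per
    // chunk; the 32-column path uses only the first 32).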
19181 int co = i0 * 6;
19182
19183 auto dpaswTyped
19184 = [&](InstructionModifier mod, uint8_t sdepth, uint8_t rcount,
19185 const GRF &cReg, const GRF &aReg, const GRF &bReg) {
19186 auto A = aReg.retype(problem.Ta.ngen());
19187 auto B = bReg.retype(problem.Tb.ngen());
19188 auto C = cReg.retype(problem.Tc.ngen());
19189 first ? dpasw(mod, sdepth, rcount, C,
19190 null.retype(problem.Tc.ngen()), A, B)
19191 : dpasw(mod, sdepth, rcount, C, C, A, B);
19192 };
19193
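    // Issue the dpasw sequence for this chunk: six 8-column blocks for N > 32,
    // four otherwise. The waitB variant adds .dst waits on the B load tokens;
    // the prepB variant re-tags the B registers so later B loads wait on
    // these reads.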
19194 if (strategy.unroll[LoopN] > 32) {
19195 if (waitB) {
19196 dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao],
19197 B_regs[0]);
19198 dpaswTyped(8, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]);
19199 dpaswTyped(8 | sb1.dst | Atomic, 8, 8, C_regs[co + 16], A_regs[ao],
19200 B_regs[8]);
19201 dpaswTyped(8, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]);
19202 dpaswTyped(8 | sb2.dst | Atomic, 8, 8, C_regs[co + 32], A_regs[ao],
19203 B_regs[16]);
19204 dpaswTyped(
19205 8 | swsbEnd, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]);
19206 } else if (prepB) {
19207 dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao],
19208 B_regs[0]);
19209 dpaswTyped(8 | sb0, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]);
19210 dpaswTyped(
19211 8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]);
19212 dpaswTyped(8 | sb1, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]);
19213 dpaswTyped(
19214 8 | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], B_regs[16]);
19215 dpaswTyped(
19216 8 | swsbEnd, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]);
19217 } else {
19218 dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao],
19219 B_regs[0]);
19220 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]);
19221 dpaswTyped(
19222 8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]);
19223 dpaswTyped(
19224 8 | Atomic, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]);
19225 dpaswTyped(
19226 8 | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], B_regs[16]);
19227 dpaswTyped(
19228 8 | swsbEnd, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]);
19229 }
19230 } else {
19231 if (waitB) {
19232 dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao],
19233 B_regs[0]);
19234 dpaswTyped(8, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]);
19235 dpaswTyped(8 | sb1.dst | Atomic, 8, 8, C_regs[co + 16], A_regs[ao],
19236 B_regs[8]);
19237 dpaswTyped(
19238 8 | swsbEnd, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]);
19239 } else if (prepB) {
19240 dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao],
19241 B_regs[0]);
19242 dpaswTyped(8 | sb0, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]);
19243 dpaswTyped(
19244 8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]);
19245 dpaswTyped(
19246 8 | swsbEnd, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]);
19247 } else {
19248 dpaswTyped(8 | swsb0 | Atomic, 8, 8, C_regs[co], A_regs[ao],
19249 B_regs[0]);
19250 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]);
19251 dpaswTyped(
19252 8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]);
19253 dpaswTyped(
19254 8 | swsbEnd, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]);
19255 }
19256 }
19257}
19258
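// Prepare a barrier message header: place the barrier payload in dword 2.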
19259template <HW hw>
19260void gemm_kernel_generator_t<hw>::sysgemmBarrierPrep(
19261 const InstructionModifier &swsb, const GRF &header) {
19262 using namespace sysgemm;
19263 mov<uint32_t>(1 | swsb, header[2], barrierVal);
19264}
19265
19266template <HW hw>
19267void gemm_kernel_generator_t<hw>::sysgemmReorderLocalIDs(
19268 const GEMMProblem &problem, const GEMMStrategy &strategy,
19269 GEMMState &state) {
19270 if (strategy.splitCopy) return;
19271 if (strategy.wg[LoopM] != 8) return;
19272
19273 auto storage = state.ra.alloc_sub<uint64_t>();
19274 auto temp = storage.uw(0);
19275 auto temp2 = storage.uw(2);
19276 auto lidM = state.inputs.localIDM;
19277 auto lidN = state.inputs.localIDN;
19278
19279 // Remap local IDs so that upper 4x4 threads come first, then lower 4x4 threads.
19280 bfi2(1, temp, 0x08, lidN, lidM);
19281 shr(1, temp2, lidN, 1);
19282 shr(1, lidN, temp, 2);
19283 bfi2(1, lidM, 0x04, temp2, lidM);
19284
19285 state.ra.safeRelease(storage);
19286}
19287
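// Fixed GRF/subregister assignments for the sysgemm2 (dpasw) kernels.
// The x48/x32 namespaces correspond to an N unroll of 48 or 32 columns.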
19288namespace sysgemm2 {
19289namespace x48 {
19290static GRFRange A_regs = GRF(32) - GRF(63);
19291static GRFRange B_regs = GRF(2) - GRF(25);
19292static GRFRange C_regs = GRF(64) - GRF(255);
19293static Subregister B_addr[3] = {GRF(26).ud(2), GRF(27).ud(2), GRF(1).ud(2)};
19294static Subregister A_addr[4]
19295 = {GRF(28).ud(2), GRF(29).ud(2), GRF(30).ud(2), GRF(31).ud(2)};
19296static GRF headerTemp = GRF(0);
19297} // namespace x48
19298
19299namespace x32 {
19300static GRFRange A_regs[2] = {GRF(32) - GRF(63), GRF(96) - GRF(127)};
19301static GRFRange B_regs[2] = {GRF(2) - GRF(17), GRF(66) - GRF(81)};
19302static GRFRange C_regs = GRF(128) - GRF(255);
19303static Subregister B_addr[2][2]
19304 = {{GRF(26).ud(2), GRF(27).ud(2)}, {GRF(90).ud(2), GRF(91).ud(2)}};
19305static Subregister A_addr[2][4]
19306 = {{GRF(28).ud(2), GRF(29).ud(2), GRF(30).ud(2), GRF(31).ud(2)},
19307 {GRF(92).ud(2), GRF(93).ud(2), GRF(94).ud(2), GRF(95).ud(2)}};
19308static GRF barrierHeader = GRF(0);
19309static GRF fenceHeader = GRF(64);
19310} // namespace x32
19311
19312static GRFRange copyInputs = GRF(254) - GRF(255);
19313static Subregister A_copyLoadAddr0 = GRF(254).uq(0);
19314static Subregister A_copyLoadAddrSurf0 = GRF(254).ud(2);
19315static Subregister slmAOff = GRF(254).d(4);
19316static Subregister lda = GRF(254).ud(6);
19317static Subregister B_copyLoadAddr0 = GRF(255).uq(0);
19318static Subregister B_copyLoadAddrSurf0 = GRF(255).ud(2);
19319static Subregister slmBOff[2] = {GRF(255).d(4), GRF(255).d(5)};
19320static Subregister ldb = GRF(255).ud(6);
19321
19322static Subregister kCounter = AccumulatorRegister(2).d(0);
19323static Subregister barrierVal = AddressRegister(0).ud(0);
19324} // namespace sysgemm2
19325
19326template <HW hw>
19327bool gemm_kernel_generator_t<hw>::sysgemm2AccumulateC(
19328 GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state) {
19329 using namespace sysgemm2;
19330 auto params = systolicParams(hw, problem, strategy);
19331 auto unrollM = strategy.unroll[LoopM];
19332 auto unrollN = strategy.unroll[LoopN];
19333 auto localIDM = state.lidM;
19334 auto localIDN = state.lidN;
19335 auto C_regs = (unrollN == 48) ? x48::C_regs : x32::C_regs;
19336
19337 if (unrollM != 16 && unrollM != 32) stub();
19338 if (unrollN != 32 && unrollN != 48) stub();
19339 if (isPacked(problem.A.layout)) {
19340 if (problem.A.crosspack != params.opsPerChan) stub();
19341 if (problem.A.tileR != params.osys) stub();
19342 if (problem.A.tileC != params.ksys) stub();
19343 }
19344 if (isPacked(problem.B.layout)) {
19345 if (problem.B.crosspack != params.ksys) stub();
19346 if (problem.B.tileR != 0 || problem.B.tileC != 0) stub();
19347 }
19348
19349 state.ra.claim(C_regs);
19350
19351 // Check whether this thread will do copy or compute. Copy threads get priority (lower thread #).
19352 auto flagCompute = f1[1];
19353 mov(1, flagCompute, state.isCompute.uw());
19354 state.ra.safeRelease(state.isCompute);
19355
19356 // Calculate A/B addresses and SLM offsets.
19357 auto tempStorage = C_regs[0];
19358 auto suboffsetA = tempStorage.ud(0);
19359 auto suboffsetB = tempStorage.ud(1);
19360 auto tempB = tempStorage.ud(2);
19361 auto ldaUnrollM4 = tempStorage.ud(3);
19362 auto aInc = tempStorage.ud(4);
19363 auto ldbUnrollN4 = tempStorage.ud(5);
19364 auto bInc = tempStorage.ud(6);
19365
19366 if (problem.A.layout == MatrixLayout::T)
19367 mulConstant(1, ldaUnrollM4, state.inputs.lda, unrollM / 4);
19368 if (problem.B.layout == MatrixLayout::N)
19369 mulConstant(1, ldbUnrollN4, state.inputs.ldb, unrollN / 4);
19370 if (!isPacked(problem.A.layout)) mov(1, lda, state.inputs.lda);
19371 if (!isPacked(problem.B.layout)) mov(1, ldb, state.inputs.ldb);
19372
19373 and_(1 | ne | state.flagAP, null.uw(), localIDM, 1);
19374
19375 switch (problem.A.layout) {
19376 case MatrixLayout::Pc:
19377 mulConstant(1, aInc, localIDN, unrollM * (32 / 4));
19378 break;
19379 case MatrixLayout::N:
19380 mulConstant(1, aInc, localIDN, unrollM * problem.Ta / 4);
19381 break;
19382 case MatrixLayout::T: mul(1, aInc, ldaUnrollM4, localIDN.uw()); break;
19383 default: stub();
19384 }
19385 switch (problem.B.layout) {
19386 case MatrixLayout::Pr:
19387 mulConstant(1, bInc, localIDM, unrollN * (32 / 4));
19388 break;
19389 case MatrixLayout::N: mul(1, bInc, ldbUnrollN4, localIDM.uw()); break;
19390 case MatrixLayout::T:
19391 mulConstant(1, bInc, localIDM, unrollN * problem.Tb / 4);
19392 break;
19393 default: stub();
19394 }
19395
19396 mulConstant(1, suboffsetA, localIDN, unrollM * (32 / 4) / 16);
19397 mulConstant(1, suboffsetB, localIDM, unrollN * (32 / 4) / 16);
19398
19399 if (strategy.A.base.isStateless())
19400 eadd(1, A_copyLoadAddr0, state.effA, aInc, strategy, state);
19401 else
19402 add(1, A_copyLoadAddrSurf0, state.effA, aInc);
19403
19404 if (strategy.B.base.isStateless())
19405 eadd(1, B_copyLoadAddr0, state.effB, bInc, strategy, state);
19406 else
19407 add(1, B_copyLoadAddrSurf0, state.effB, bInc);
19408
19409 mad(1, tempB, (4 * unrollM * 36) / 16, localIDN, (unrollN * 32) / 16);
19410
19411 mul(1, x48::A_addr[0], localIDM, (unrollM * 36) / 16);
19412 add(1 | state.flagAP, x48::B_addr[0], tempB, (unrollN / 2) * (32 / 16));
19413 mov(1 | ~state.flagAP, x48::B_addr[0], tempB);
19414
19415 add(1, slmAOff, x48::A_addr[0], suboffsetA);
19416 add(1, slmBOff[0], tempB, suboffsetB);
19417 add3(1, slmBOff[1], tempB, suboffsetB, 8 * 32 / 16);
19418
19419 // Marshal data needed later into acc2 for safekeeping.
19420 auto saveData = state.ra.alloc_range(2);
19421 auto kLoops = saveData[0].d(0);
19422 auto ldc = saveData[0].ud(1);
19423 auto flags = saveData[0].ud(2);
19424 auto k = saveData[0].ud(3);
19425 auto remM = saveData[0].uw(8);
19426 auto remN = saveData[0].uw(9);
19427 auto abo = saveData[0].ud(5);
19428 auto ao = saveData[0].w(10);
19429 auto bo = saveData[0].w(11);
19430 auto alpha = saveData[0].ud(6).reinterpret(0, problem.Ts.ngen());
19431 auto beta = saveData[0].ud(7).reinterpret(0, problem.Ts.ngen());
19432 auto remFusedStorage = saveData[1].ud(0);
19433 auto diagC = saveData[1].ud(1);
19434 auto saveI0 = saveData[1].ud(1);
19435 auto effCO = saveData[1].uq(1);
19436 auto saveJ0 = saveData[1].ud(3);
19437 auto C_ptr = saveData[1].uq(2);
19438 auto slotAB = saveData[1].ud(6);
19439 auto effAs = a0.ud(4); // dwords 4-5
19440 auto effBs = a0.ud(6); // dwords 6-7
19441
19442 if (state.r0_info != acc0.ud()) mov<uint32_t>(8, acc0, state.r0_info);
19443
19444 add(1, kLoops, state.k, params.ksys - 1);
19445 mov(1, ldc, state.inputs.ldc[0]);
19446 emov(1, C_ptr, state.effC[0], strategy, state);
19447 if (state.inputs.flags.isValid()) mov(1, flags, state.inputs.flags);
19448 mov(1, k, state.k);
19449 if (state.remainders[LoopM].isValid())
19450 mov(1, remM, state.remainders[LoopM]);
19451 if (state.remainders[LoopN].isValid())
19452 mov(1, remN, state.remainders[LoopN]);
19453 if (state.inputs.abo.isValid())
19454 mov(1, abo, state.inputs.abo);
19455 else {
19456 if (state.inputs.ao.isValid()) mov(1, ao, state.inputs.ao);
19457 if (state.inputs.bo.isValid()) mov(1, bo, state.inputs.bo);
19458 }
19459 if (state.inputs.alpha_real.isValid())
19460 mov(1, alpha, state.inputs.alpha_real);
19461 if (state.inputs.beta_real.isValid()) mov(1, beta, state.inputs.beta_real);
19462 shr(1, kLoops, kLoops, log2(params.ksys));
19463 if (state.remFusedStorage.isValid())
19464 mov(1, remFusedStorage, state.remFusedStorage);
19465 if (state.diagC.isValid()) mov(1, diagC, state.diagC);
19466 if (state.effCO.isValid()) {
19467 effCO = effCO.reinterpret(0, state.effCO.getType());
19468 emov(1, effCO, state.effCO, strategy, state);
19469 }
19470 if (problem.hasBinaryPostOp()) {
19471 if (state.diagC.isValid()) stub();
19472 if (state.effCO.isValid() && effCO.getBytes() > 4) stub();
19473 mov(1, saveI0, state.i0);
19474 mov(1, saveJ0, state.j0);
19475 }
19476 if (problem.abOffset != ABOffset::None) {
19477 GRF temp = state.ra.alloc();
19478 state.effAs = temp.uq(0).reinterpret(0, state.effA.getType());
19479 state.effBs = temp.uq(1).reinterpret(0, state.effB.getType());
19480 gemmCalcABOffsetAddrs(problem, strategy, state);
19481 mov<uint32_t>(4, effAs(1), temp);
19482 state.ra.safeRelease(temp);
19483 }
19484 if (state.fusedGEMM.slotA.isValid())
19485 mov(1, slotAB, state.fusedGEMM.slotA.ud());
19486
19487 if (state.isNested) {
19488 // To do: replace with sel
19489 mov(2 | ~flagCompute, remM(1), 0);
19490 mov(1 | ~flagCompute, remFusedStorage, 0);
19491 }
19492
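    // Release the original inputs; the values needed later were marshalled
    // into saveData above.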
19493 releaseSavedMNLocalIDs(state);
19494 state.ra.safeRelease(state.effA);
19495 state.ra.safeRelease(state.effB);
19496 state.ra.safeRelease(state.effC[0]);
19497 state.ra.safeRelease(state.inputs.lda);
19498 state.ra.safeRelease(state.inputs.ldb);
19499
19500 state.ra.release(state.inputs.ldc[0]);
19501 state.ra.release(state.k);
19502 state.ra.release(state.remainders[LoopM]);
19503 state.ra.release(state.remainders[LoopN]);
19504 state.ra.release(state.inputs.abo);
19505 state.ra.release(state.inputs.ao);
19506 state.ra.release(state.inputs.bo);
19507 state.ra.release(state.inputs.alpha_real);
19508 state.ra.release(state.inputs.beta_real);
19509 state.ra.release(state.remFusedStorage);
19510 state.ra.release(state.diagC);
19511 state.ra.release(state.effCO);
19512 state.ra.release(state.fusedGEMM.slotA);
19513 state.ra.release(state.fusedGEMM.slotB);
19514
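    // Build the barrier signal payload in barrierVal from the r0 header.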
19515 if (state.r0_info.isARF()) stub();
19516 GRF r0_info {state.r0_info.getBase()};
19517 if (hw >= HW::XeHPG) {
19518 mov(1, barrierVal.uw(0), Immediate::uw(0));
19519 mov(2, barrierVal.ub(2)(1), r0_info.ub(11)(0));
19520 } else
19521 and_(1, barrierVal, r0_info.ud(2), 0x7F000000);
19522
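    // Stash saveData[0] in acc2 across the k-loop; its first dword lines up
    // with kCounter, which the k-loop code reads directly.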
19523 mov<float>(16, acc2, saveData[0]);
19524
19525 Label labelCompute, labelDone;
19526
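    // Dispatch: compute threads (flagCompute set) run the dpasw k-loop;
    // copy threads run the SLM copy loop and skip the compute loop.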
19527 jmpi(1 | f1[1], labelCompute);
19528 sysgemm2KLoopCopy(problem, strategy, state);
19529 if (state.isNested) {
19530 jmpi(1, labelDone);
19531 } else
19532 epilogue(strategy, state);
19533 mark(labelCompute);
19534 sysgemm2KLoopCompute(problem, strategy, state);
19535 mark(labelDone);
19536
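    // Restore the stashed data from acc2 and reattach it to the state.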
19537 mov<float>(16, saveData[0], acc2);
19538
19539 state.effC[0] = C_ptr;
19540 state.inputs.ldc[0] = ldc;
19541 if (state.inputs.flags.isValid()) state.inputs.flags = flags;
19542 state.k = k;
19543 if (state.remainders[LoopM].isValid()) state.remainders[LoopM] = remM;
19544 if (state.remainders[LoopN].isValid()) state.remainders[LoopN] = remN;
19545 if (state.inputs.abo.isValid()) state.inputs.abo = abo;
19546 if (state.inputs.ao.isValid()) state.inputs.ao = ao;
19547 if (state.inputs.bo.isValid()) state.inputs.bo = bo;
19548 if (state.inputs.alpha_real.isValid()) {
19549 state.inputs.alpha_real = alpha;
19550 if (!problem.alpha_real.fixed()) problem.alpha_real = alpha;
19551 }
19552 if (state.inputs.beta_real.isValid()) {
19553 state.inputs.beta_real = beta;
19554 if (!problem.beta_real.fixed()) problem.beta_real = beta;
19555 }
19556 if (state.remFusedStorage.isValid()) {
19557 state.remFusedStorage = remFusedStorage;
19558 state.remaindersFused[LoopM] = state.remainders[LoopM];
19559 state.remaindersFused[LoopN] = state.remainders[LoopN];
19560 state.remaindersFused[strategy.fusedLoop] = remFusedStorage;
19561 }
19562 if (state.diagC.isValid()) state.diagC = diagC;
19563 if (state.effCO.isValid()) state.effCO = effCO;
19564 if (state.fusedGEMM.slotA.isValid()) {
19565 state.fusedGEMM.slotA = slotAB.uw(0);
19566 state.fusedGEMM.slotB = slotAB.uw(1);
19567 }
19568 if (problem.abOffset != ABOffset::None) {
19569 auto tas = state.effAs.getType();
19570 auto tbs = state.effBs.getType();
19571 state.effAs = state.ra.alloc_sub(tas);
19572 state.effBs = state.ra.alloc_sub(tbs);
19573 mov<uint32_t>(getDwords(tas), state.effAs.ud()(1), effAs(1));
19574 mov<uint32_t>(getDwords(tbs), state.effBs.ud()(1), effBs(1));
19575 }
19576 if (problem.hasBinaryPostOp()) {
19577 state.i0 = saveI0;
19578 state.j0 = saveJ0;
19579 }
19580
19581 // Set up C internal layout and registers.
19582 state.C_regs.resize(1);
19583 state.C_regs[0] = C_regs;
19584 state.C_layout.clear();
19585 state.C_layout.reserve((unrollM / 8) * (unrollN / 4));
19586 for (int j0 = 0; j0 < unrollN; j0 += 4) {
19587 for (int i0 = 0; i0 < unrollM; i0 += 8) {
19588 RegisterBlock block;
19589 block.log2GRFBytes = GRF::log2Bytes(hw);
19590 block.colMajor = true;
19591 block.splitComplex = false;
19592 block.cxComponent = RegisterBlock::Interleaved;
19593 block.nr = block.ld = 8;
19594 block.nc = 4;
19595 block.component = 0;
19596 block.offsetR = i0;
19597 block.offsetC = j0;
19598 block.crosspack = 1;
19599 block.bytes = 8 * 4 * problem.Tc.size();
19600 block.simdSize = 0;
19601
19602 int j0Interleaved = j0 << 1;
19603 if (j0Interleaved >= unrollN) j0Interleaved += 4 - unrollN;
19604
19605 block.offsetBytes
19606 = (unrollN * i0 / 8 + j0Interleaved) * GRF::bytes(hw);
19607 state.C_layout.push_back(block);
19608 }
19609 }
19610
19611 // Set up C external layout.
19612 state.copyC = true;
19613 bool remM_Ce, remN_Ce;
19614 getCRemainders(problem, strategy, remM_Ce, remN_Ce);
19615
19616 if (!getRegLayout(problem.Tc_ext, state.C_layoutExt, unrollM, unrollN,
19617 remM_Ce, remN_Ce, true, false, 0, 0, problem.C,
19618 state.Cext_strategy))
19619 return false;
19620 if (remM_Ce || remN_Ce)
19621 (void)getRegLayout(problem.Tc_ext, state.C_layoutExtUnmasked, unrollM,
19622 unrollN, false, false, true, false, 0, 0, problem.C,
19623 state.Cext_strategy);
19624
19625 if (state.r0_info != acc0.ud()) mov<uint32_t>(8, state.r0_info, acc0);
19626
19627 return true; // Success!
19628}
19629
19630template <HW hw>
19631void gemm_kernel_generator_t<hw>::sysgemm2KLoopCopy(const GEMMProblem &problem,
19632 const GEMMStrategy &strategy, GEMMState &state) {
19633 using namespace sysgemm2;
19634
19635 Label top, bottom, smallK, skipSmallK, reenterSmallK, done;
19636
19637 bool surfaceA = !strategy.A.base.isStateless();
19638 bool surfaceB = !strategy.B.base.isStateless();
19639 auto unrollM = strategy.unroll[LoopM];
19640 auto unrollN = strategy.unroll[LoopN];
19641 bool _32x = (unrollM == 32);
19642 bool x48 = (unrollN == 48);
19643 bool int8 = problem.Ta.size() == 1;
19644
19645 int globalBuffers = strategy.A_copies;
19646 auto slmStride = strategy.slmSysgemmBlockSize() / 16;
19647
19648 if (globalBuffers < 2) stub();
19649 if (globalBuffers > 5) stub();
19650
19651 auto saveRA = state.ra;
19652 state.ra.release(r0 - r127);
19653 state.ra.release(r128 - r255);
19654 state.ra.claim(copyInputs);
19655
19656 // Register and token allocation.
19657 int aTokens = 0, aLoadAddrs = 0, aStoreAddrs = 1;
19658 int bTokens = 0, bLoadAddrs = 0, bStoreAddrs = 1 + x48;
19659 int aiRegCount = unrollM / 4, biRegCount = unrollN / 4;
19660 bool aRepack = false, bRepack = false;
19661
19662 if (problem.A.alignment & 3 || problem.B.alignment & 3) stub();
19663
19664 switch (problem.A.layout) {
19665 case MatrixLayout::Pc:
19666 aTokens = aLoadAddrs = (surfaceA && _32x) ? 2 : 1;
19667 break;
19668 case MatrixLayout::N:
19669 if (!surfaceA) stub();
19670 aTokens = int8 ? 2 : 1;
19671 aLoadAddrs = aTokens * 2;
19672 aRepack = true;
19673 break;
19674 case MatrixLayout::T:
19675 if (!surfaceA) stub();
19676 aTokens = aLoadAddrs = 2;
19677 break;
19678 default: stub();
19679 }
19680
19681 switch (problem.B.layout) {
19682 case MatrixLayout::Pr:
19683 bTokens = bLoadAddrs = (surfaceB ? 2 : 1) + x48;
19684 break;
19685 case MatrixLayout::N:
19686 if (!surfaceB) stub();
19687 bTokens = 2;
19688 bLoadAddrs = x48 ? 4 : 2;
19689 if (x48) biRegCount = 16;
19690 bRepack = true;
19691 break;
19692 case MatrixLayout::T:
19693 if (!surfaceB) stub();
19694 bTokens = (int8 || x48) ? 2 : 1;
19695 bLoadAddrs = bTokens * 2;
19696 bRepack = true;
19697 break;
19698 default: stub();
19699 }
19700
19701 int tokenStride = aTokens + bTokens;
19702 if (tokenStride * globalBuffers > 15)
19703 throw std::runtime_error("Not enough tokens available.");
19704
19705 auto &Ai_regs = state.Ai_regs;
19706 auto &Bi_regs = state.Bi_regs;
19707 auto &Ao_regs = state.Ao_regs;
19708 auto &Bo_regs = state.Bo_regs;
19709 auto &Ai_addrs = state.Ai_addrs;
19710 auto &Bi_addrs = state.Bi_addrs;
19711 auto &Ao_addrs = state.Ao_addrs;
19712 auto &Bo_addrs = state.Bo_addrs;
19713 GRFRange ldaMultiples, ldbMultiples;
19714 FlagRegister flag12;
19715 GRF A_swizzle, B_swizzle;
19716 Subregister lda16, ldb16, ldaK, ldbK;
19717 Subregister Ai_advance, Bi_advance;
19718 GRF copyBarrierHeader = state.ra.alloc();
19719 GRF copyFenceHeader = state.ra.alloc();
19720 GRF slmBase = state.ra.alloc().d();
19721 GRF temp = state.ra.alloc().d();
19722
19723 Ai_regs.reserve(globalBuffers);
19724 Bi_regs.reserve(globalBuffers);
19725 Ai_addrs.reserve(globalBuffers);
19726 Bi_addrs.reserve(globalBuffers);
19727 for (int i = 0; i < globalBuffers; i++) {
19728 Ai_regs.push_back(state.ra.alloc_range(aiRegCount));
19729 Bi_regs.push_back(state.ra.alloc_range(biRegCount));
19730 Ai_addrs.push_back(state.ra.alloc_range(aLoadAddrs));
19731 Bi_addrs.push_back(state.ra.alloc_range(bLoadAddrs));
19732 }
19733
19734 if (aRepack) Ao_regs = state.ra.alloc_range(8);
19735 if (bRepack) Bo_regs = state.ra.alloc_range(unrollN / 4);
19736 Ao_addrs.push_back(state.ra.alloc_range(aStoreAddrs));
19737 Bo_addrs.push_back(state.ra.alloc_range(bStoreAddrs));
19738
19739 if (state.emulate.temp[0].isValid()) {
19740 for (int q = 0; q < 2; q++)
19741 state.ra.safeRelease(state.emulate.temp[q]);
19742 state.emulate.flag = f1[1];
19743 state.emulate.flagOffset = 8;
19744 }
19745
19746 // Address initialization.
19747 if (surfaceA && isPacked(problem.A.layout))
19748 shr(1, A_copyLoadAddrSurf0, A_copyLoadAddrSurf0, 4);
19749 if (surfaceB && isPacked(problem.B.layout))
19750 shr(1, B_copyLoadAddrSurf0, B_copyLoadAddrSurf0, 4);
19751
19752 mov(1, slmBase[0], 0);
19753 mov(1, slmBase[1], -4 * slmStride);
19754
19755 auto makeLDMultiples = [&](GRFRange &multiples, const Subregister &ld,
19756 int n) {
19757 multiples = state.ra.alloc_range(n / 8);
19758 mov<uint16_t>(8, multiples[0], Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7));
19759 if (n > 8)
19760 mov<uint16_t>(8, multiples[1],
19761 Immediate::uv(8, 9, 10, 11, 12, 13, 14, 15));
19762 mul<uint32_t>(8, multiples[0], ld, multiples[0].uw());
19763 if (n > 8) mul<uint32_t>(8, multiples[1], ld, multiples[1].uw());
19764 };
19765
19766 switch (problem.A.layout) {
19767 case MatrixLayout::N:
19768 lda16 = state.ra.alloc_sub<uint32_t>();
19769 Ai_advance = state.ra.alloc_sub<uint32_t>();
19770 if (!int8) {
19771 A_swizzle = state.ra.alloc().uw();
19772 mov(4, A_swizzle[0](1), Immediate::uv(0, 4, 2, 6, 0, 0, 0, 0));
19773 add(4, A_swizzle[4](1), A_swizzle[0](1), 64);
19774 add(8, A_swizzle[8](1), A_swizzle[0](1), 128);
19775 }
19776 makeLDMultiples(ldaMultiples, lda, 16);
19777 mulConstant(1, lda16, lda, 16);
19778 if (int8) {
19779 ldaK = state.ra.alloc_sub<uint32_t>();
19780 mulConstant(1, ldaK, lda, 32);
19781 } else
19782 ldaK = lda16;
19783 mulConstant(1, Ai_advance, lda, (int8 ? 32 : 16) * globalBuffers);
19784 break;
19785 case MatrixLayout::T: makeLDMultiples(ldaMultiples, lda, 8); break;
19786 default: break;
19787 }
19788
19789 switch (problem.B.layout) {
19790 case MatrixLayout::N:
19791 B_swizzle = state.ra.alloc().uw();
19792 mov(8, B_swizzle[0](1), Immediate::uv(0, 1, 2, 3, 4, 5, 6, 7));
19793 if (x48) {
19794 flag12 = state.raVFlag.alloc();
19795 mov(1, flag12, 0x0FFF);
19796 }
19797 mad(8, B_swizzle[8](1), 4, B_swizzle[0](1), x48 ? 64 : 32);
19798 mulConstant(8, B_swizzle[0](1), B_swizzle[0](1), x48 ? 64 : 32);
19799 makeLDMultiples(ldbMultiples, ldb, x48 ? 16 : 8);
19800 break;
19801 case MatrixLayout::T:
19802 makeLDMultiples(ldbMultiples, ldb, 16);
19803 if (int8) {
19804 ldb16 = state.ra.alloc_sub<uint32_t>();
19805 mulConstant(1, ldb16, ldb, 16);
19806 }
19807 Bi_advance = state.ra.alloc_sub<uint32_t>();
19808 ldbK = state.ra.alloc_sub<uint32_t>();
19809 mulConstant(1, Bi_advance, ldb, (int8 ? 32 : 16) * globalBuffers);
            mulConstant(1, ldbK, ldb, int8 ? 32 : 16);
            break;
        default: break;
19812 }
19813
19814 for (int i = 0; i < globalBuffers; i++)
19815 switch (problem.A.layout) {
19816 case MatrixLayout::Pc:
19817 if (surfaceA) {
19818 add(1, Ai_addrs[i][0].ud(2), A_copyLoadAddrSurf0,
19819 i * unrollM * 32 / 16);
19820 if (_32x)
19821 add(1, Ai_addrs[i][1].ud(2), A_copyLoadAddrSurf0,
19822 (i * unrollM * 32 + 4 * 32) / 16);
19823 } else
19824 eadd(1, Ai_addrs[i][0].uq(0), A_copyLoadAddr0,
19825 i * unrollM * 32, strategy, state);
19826 break;
19827 case MatrixLayout::N:
19828 if (i == 0) {
19829 add<uint32_t>(16, Ai_addrs[i][0], A_copyLoadAddrSurf0,
19830 ldaMultiples);
19831 if (int8)
19832 add3<uint32_t>(16, Ai_addrs[i][2], A_copyLoadAddrSurf0,
19833 ldaMultiples, lda16);
19834 } else {
19835 add<uint32_t>(16, Ai_addrs[i][0], Ai_addrs[i - 1][0], ldaK);
19836 if (int8)
19837 add<uint32_t>(
19838 16, Ai_addrs[i][2], Ai_addrs[i - 1][2], ldaK);
19839 }
19840 break;
19841 case MatrixLayout::T:
19842 add3<uint32_t>(8, Ai_addrs[i][0], A_copyLoadAddrSurf0,
19843 ldaMultiples, i * 32 + 0);
19844 add3<uint32_t>(8, Ai_addrs[i][1], A_copyLoadAddrSurf0,
19845 ldaMultiples, i * 32 + 16);
19846 break;
19847 default: stub();
19848 }
19849
19850 for (int i = 0; i < globalBuffers; i++)
19851 switch (problem.B.layout) {
19852 case MatrixLayout::Pr:
19853 if (surfaceB) {
19854 add(1, Bi_addrs[i][0].ud(2), B_copyLoadAddrSurf0,
19855 i * unrollN * 32 / 16);
19856 add(1, Bi_addrs[i][1].ud(2), B_copyLoadAddrSurf0,
19857 (i * unrollN * 32 + 4 * 32) / 16);
19858 if (x48)
19859 add(1, Bi_addrs[i][2].ud(2), B_copyLoadAddrSurf0,
19860 (i * unrollN * 32 + 8 * 32) / 16);
19861 } else {
19862 eadd(1, Bi_addrs[i][0].uq(0), B_copyLoadAddr0,
19863 i * unrollN * 32, strategy, state);
19864 if (x48)
19865 eadd(1, Bi_addrs[i][1].uq(0), B_copyLoadAddr0,
19866 i * unrollN * 32 + 8 * 32, strategy, state);
19867 }
19868 break;
19869 case MatrixLayout::N:
19870 add3<uint32_t>(x48 ? 16 : 8, Bi_addrs[i][0],
19871 B_copyLoadAddrSurf0, ldbMultiples, i * 32 + 0);
19872 add3<uint32_t>(x48 ? 16 : 8, Bi_addrs[i][x48 ? 2 : 1],
19873 B_copyLoadAddrSurf0, ldbMultiples, i * 32 + 16);
19874 break;
19875 case MatrixLayout::T:
19876 if (i == 0) {
19877 add<uint32_t>(16, Bi_addrs[i][0], B_copyLoadAddrSurf0,
19878 ldbMultiples);
19879 if (int8)
19880 add3<uint32_t>(16, Bi_addrs[i][2], B_copyLoadAddrSurf0,
19881 ldbMultiples, ldb16);
19882 else if (x48)
19883 add3<uint32_t>(16, Bi_addrs[i][2], B_copyLoadAddrSurf0,
19884 ldbMultiples, 16);
19885 } else {
19886 add<uint32_t>(16, Bi_addrs[i][0], Bi_addrs[i - 1][0], ldbK);
19887 if (int8 || x48)
19888 add<uint32_t>(
19889 16, Bi_addrs[i][2], Bi_addrs[i - 1][2], ldbK);
19890 }
19891 break;
19892 default: stub();
19893 }
19894
19895 sysgemmBarrierPrep(InstructionModifier(), copyBarrierHeader);
19896
19897 mov(2, slmBase[4](1), slmBase[0](1));
19898
19899 // Main logic.
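    // copyLoad: issue the global loads of one buffer's A and B panels,
    // tagging each message with an SBID derived from the buffer index.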
19900 auto copyLoad = [&](int buffer) {
19901 int atbase = tokenStride * buffer;
19902 int btbase = atbase + aTokens;
19903 switch (problem.A.layout) {
19904 case MatrixLayout::Pc:
19905 if (surfaceA) {
19906 load(16 | SBID(atbase + 0), Ai_regs[buffer][0],
19907 block_oword(8), strategy.A.base,
19908 Ai_addrs[buffer][0]);
19909 if (_32x)
19910 load(16 | SBID(atbase + 1), Ai_regs[buffer][4],
19911 block_oword(8), strategy.A.base,
19912 Ai_addrs[buffer][1]);
19913 } else
19914 load(16 | SBID(atbase + 0), Ai_regs[buffer][0],
19915 block_hword(_32x ? 8 : 4), strategy.A.base,
19916 Ai_addrs[buffer]);
19917 break;
19918 case MatrixLayout::N:
19919 if (int8) {
19920 load(16 | SBID(atbase + 0), Ai_regs[buffer][0],
19921 surface_dword(ChannelMask::rg), strategy.A.base,
19922 Ai_addrs[buffer][0]);
19923 load(16 | SBID(atbase + 1), Ai_regs[buffer][4],
19924 surface_dword(ChannelMask::rg), strategy.A.base,
19925 Ai_addrs[buffer][2]);
19926 } else
19927 load(16 | SBID(atbase + 0), Ai_regs[buffer][0],
19928 surface_dword(ChannelMask::rgba), strategy.A.base,
19929 Ai_addrs[buffer][0]);
19930 break;
19931 case MatrixLayout::T:
19932 load(8 | SBID(atbase + 0), Ai_regs[buffer][0],
19933 surface_dword(ChannelMask::rgba), strategy.A.base,
19934 Ai_addrs[buffer][0]);
19935 load(8 | SBID(atbase + 1), Ai_regs[buffer][4],
19936 surface_dword(ChannelMask::rgba), strategy.A.base,
19937 Ai_addrs[buffer][1]);
19938 break;
19939 default: stub();
19940 }
19941
19942 switch (problem.B.layout) {
19943 case MatrixLayout::Pr:
19944 if (surfaceB) {
19945 load(16 | SBID(btbase + 0), Bi_regs[buffer][0],
19946 block_oword(8), strategy.B.base,
19947 Bi_addrs[buffer][0]);
19948 load(16 | SBID(btbase + 1), Bi_regs[buffer][4],
19949 block_oword(8), strategy.B.base,
19950 Bi_addrs[buffer][1]);
19951 if (x48)
19952 load(16 | SBID(btbase + 2), Bi_regs[buffer][8],
19953 block_oword(8), strategy.B.base,
19954 Bi_addrs[buffer][2]);
19955 } else {
19956 load(16 | SBID(btbase + 0), Bi_regs[buffer][0],
19957 block_hword(8), strategy.B.base,
19958 Bi_addrs[buffer][0]);
19959 if (x48)
19960 load(16 | SBID(btbase + 1), Bi_regs[buffer][8],
19961 block_hword(4), strategy.B.base,
19962 Bi_addrs[buffer][1]);
19963 }
19964 break;
19965 case MatrixLayout::N:
19966 if (x48) {
19967 load(16 | SBID(btbase + 0) | flag12 | any4h,
19968 Bi_regs[buffer][0],
19969 surface_dword(ChannelMask::rgba), strategy.B.base,
19970 Bi_addrs[buffer][0]);
19971 load(16 | SBID(btbase + 1) | flag12 | any4h,
19972 Bi_regs[buffer][8],
19973 surface_dword(ChannelMask::rgba), strategy.B.base,
19974 Bi_addrs[buffer][2]);
19975 } else {
19976 load(8 | SBID(btbase + 0), Bi_regs[buffer][0],
19977 surface_dword(ChannelMask::rgba), strategy.B.base,
19978 Bi_addrs[buffer][0]);
19979 load(8 | SBID(btbase + 1), Bi_regs[buffer][4],
19980 surface_dword(ChannelMask::rgba), strategy.B.base,
19981 Bi_addrs[buffer][1]);
19982 }
19983 break;
19984 case MatrixLayout::T:
19985 if (int8) {
19986 auto cmask = x48 ? ChannelMask::rgb : ChannelMask::rg;
19987 load(16 | SBID(btbase + 0), Bi_regs[buffer][0],
19988 surface_dword(cmask), strategy.B.base,
19989 Bi_addrs[buffer][0]);
19990 load(16 | SBID(btbase + 1), Bi_regs[buffer][x48 ? 6 : 4],
19991 surface_dword(cmask), strategy.B.base,
19992 Bi_addrs[buffer][2]);
19993 } else {
19994 load(16 | SBID(btbase + 0), Bi_regs[buffer][0],
19995 surface_dword(ChannelMask::rgba), strategy.B.base,
19996 Bi_addrs[buffer][0]);
19997 if (x48)
19998 load(16 | SBID(btbase + 1), Bi_regs[buffer][8],
19999 surface_dword(ChannelMask::rg), strategy.B.base,
20000 Bi_addrs[buffer][2]);
20001 }
20002 break;
20003 default: stub();
20004 }
20005 };
20006
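    // copyRepack: reorder loaded A/B data into the packed SLM layout when the
    // global layout is not already packed (indirect a0 swizzles or strided
    // moves, depending on layout and element size).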
20007 auto copyRepack = [&](int buffer) {
20008 int atbase = tokenStride * buffer;
20009 int btbase = atbase + aTokens;
20010
20011 switch (problem.A.layout) {
20012 case MatrixLayout::N:
20013 if (int8) {
20014 for (int j = 0; j < 4; j++) {
20015 int reg = (j >> 1);
20016 int sub = (j & 1) << 4;
20017 mov<uint8_t>(16, Ao_regs[j + 0][0](1),
20018 Ai_regs[buffer][reg + 0][sub](1, 4, 4));
20019 mov<uint8_t>(16, Ao_regs[j + 4][0](1),
20020 Ai_regs[buffer][reg + 4][sub](1, 4, 4));
20021 mov<uint8_t>(16, Ao_regs[j + 0][16](1),
20022 Ai_regs[buffer][reg + 2][sub](1, 4, 4));
20023 mov<uint8_t>(16, Ao_regs[j + 4][16](1),
20024 Ai_regs[buffer][reg + 6][sub](1, 4, 4));
20025 }
20026 } else {
20027 // a0: 0 4 2 6 64 68 66 70...
20028 add(16, a0, A_swizzle, Ai_regs[buffer][0].getBase() * 32);
20029 setDefaultAutoSWSB(false);
20030 sync.allwr(0b1 << atbase);
20031 for (int j = 0; j < 8; j++)
20032 mov<uint16_t>(
20033 16, Ao_regs[j], indirect[a0].uw(j * 8)(1, 0));
20034 setDefaultAutoSWSB(true);
20035 }
20036 break;
20037 default: break;
20038 }
20039
20040 switch (problem.B.layout) {
20041 case MatrixLayout::N:
            // a0 (x32): 0 32 64 96 128 160 192 224 4 36 68...
20043 // a0 (x48): 0 64 128 192 256 320 384 448 4 68 132 196...
20044 add(16, a0, B_swizzle, Bi_regs[buffer][0].getBase() * 32);
20045 setDefaultAutoSWSB(false);
20046 sync.allwr(0b11 << btbase);
20047 for (int j = 0; j < unrollN / 4; j += 2) // 2 cols at a time
20048 mov<uint32_t>(16, Bo_regs[j], indirect[a0].ud(j * 4)(1, 0));
20049 setDefaultAutoSWSB(true);
20050 break;
20051 case MatrixLayout::T:
20052 if (int8) {
20053 for (int j = 0; j < unrollN / 4; j++)
20054 mov<uint8_t>(16, Bo_regs[j][0](1),
20055 Bi_regs[buffer][(j & ~3) >> 1][j & 3](4));
20056 for (int j = 0; j < unrollN / 4; j++)
20057 mov<uint8_t>(16, Bo_regs[j][16](1),
20058 Bi_regs[buffer][(x48 ? 6 : 4) + ((j & ~3) >> 1)]
20059 [j & 3](4));
20060 } else {
20061 for (int j = 0; j < unrollN / 4; j++)
20062 mov<uint16_t>(16, Bo_regs[j],
20063 Bi_regs[buffer][j & ~1][j & 1](2));
20064 }
20065 break;
20066 default: break;
20067 }
20068 };
20069
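    // copyStore: repack if needed, then store one buffer's A and B panels
    // to SLM.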
20070 auto copyStore = [&](int buffer) {
20071 int atbase = tokenStride * buffer;
20072 int btbase = atbase + aTokens;
20073
20074 auto A_regs = aRepack ? Ao_regs : Ai_regs[buffer];
20075 auto B_regs = bRepack ? Bo_regs : Bi_regs[buffer];
20076
20077 copyRepack(buffer);
20078
20079 auto b1 = (surfaceB && isPacked(problem.B.layout)) ? 2 : 1;
20080
20081 store(16 | SBID(atbase + 0), block_oword(_32x ? 16 : 8), SLM,
20082 Ao_addrs[0][0], A_regs[0]);
20083 store(16 | SBID(btbase + 0), block_oword(16), SLM, Bo_addrs[0][0],
20084 B_regs[0]);
20085 if (x48)
20086 store(16 | SBID(btbase + b1), block_oword(8), SLM, Bo_addrs[0][1],
20087 B_regs[8]);
20088 };
20089
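    // advanceLoad: step one buffer's global load addresses forward by a full
    // round of globalBuffers k-blocks.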
20090 auto advanceLoad = [&](int buffer) {
20091 switch (problem.A.layout) {
20092 case MatrixLayout::Pc:
20093 if (surfaceA) {
20094 add(1, Ai_addrs[buffer][0].ud(2), Ai_addrs[buffer][0].ud(2),
20095 globalBuffers * 32 * unrollM / 16);
20096 if (_32x)
20097 add(1, Ai_addrs[buffer][1].ud(2),
20098 Ai_addrs[buffer][1].ud(2),
20099 globalBuffers * 32 * unrollM / 16);
20100 } else
20101 eadd(1, Ai_addrs[buffer][0].uq(0),
20102 Ai_addrs[buffer][0].uq(0),
20103 globalBuffers * 32 * unrollM, strategy, state);
20104 break;
20105 case MatrixLayout::N:
20106 add<uint32_t>(16, Ai_addrs[buffer][0], Ai_addrs[buffer][0],
20107 Ai_advance);
20108 if (int8)
20109 add<uint32_t>(16, Ai_addrs[buffer][2], Ai_addrs[buffer][2],
20110 Ai_advance);
20111 break;
20112 case MatrixLayout::T:
20113 add<uint32_t>(8, Ai_addrs[buffer][0], Ai_addrs[buffer][0],
20114 32 * globalBuffers);
20115 add<uint32_t>(8, Ai_addrs[buffer][1], Ai_addrs[buffer][1],
20116 32 * globalBuffers);
20117 break;
20118 default: stub();
20119 }
20120
20121 switch (problem.B.layout) {
20122 case MatrixLayout::Pr:
20123 if (surfaceB) {
20124 add(1, Bi_addrs[buffer][0].ud(2), Bi_addrs[buffer][0].ud(2),
20125 globalBuffers * 32 * unrollN / 16);
20126 add(1, Bi_addrs[buffer][1].ud(2), Bi_addrs[buffer][1].ud(2),
20127 globalBuffers * 32 * unrollN / 16);
20128 if (x48)
20129 add(1, Bi_addrs[buffer][2].ud(2),
20130 Bi_addrs[buffer][2].ud(2),
20131 globalBuffers * 32 * unrollN / 16);
20132 } else {
20133 eadd(1, Bi_addrs[buffer][0].uq(0),
20134 Bi_addrs[buffer][0].uq(0),
20135 globalBuffers * 32 * unrollN, strategy, state);
20136 if (x48)
20137 eadd(1, Bi_addrs[buffer][1].uq(0),
20138 Bi_addrs[buffer][1].uq(0),
20139 globalBuffers * 32 * unrollN, strategy, state);
20140 }
20141 break;
20142 case MatrixLayout::N:
20143 add<uint32_t>(16, Bi_addrs[buffer][0], Bi_addrs[buffer][0],
20144 32 * globalBuffers);
20145 if (x48)
20146 add<uint32_t>(16, Bi_addrs[buffer][2], Bi_addrs[buffer][2],
20147 32 * globalBuffers);
20148 break;
20149 case MatrixLayout::T:
20150 add<uint32_t>(16, Bi_addrs[buffer][0], Bi_addrs[buffer][0],
20151 Bi_advance);
20152 if (int8 || x48)
20153 add<uint32_t>(16, Bi_addrs[buffer][2], Bi_addrs[buffer][2],
20154 Bi_advance);
20155 break;
20156 default: stub();
20157 }
20158 };
20159
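    // advanceStore: compute the next SLM store addresses from slmBase and
    // advance the circular SLM pointer; the negative counter in slmBase[1]
    // triggers the wrap back to the first buffer.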
20160 auto advanceStore = [&](int buffer = -1) {
20161 add(2, temp, slmBase, slmStride);
20162 add(1, Ao_addrs[0][0].ud(2), slmBase, slmAOff);
20163 add(1, Bo_addrs[0][0].ud(2), slmBase, slmBOff[0]);
20164 if (x48) add(1, Bo_addrs[0][1].ud(2), slmBase, slmBOff[1]);
20165
20166 csel(2 | ge | f0[0], slmBase, slmBase[4](1, 1, 0), temp[0](1, 1, 0),
20167 temp[1]);
20168 };
20169
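    // SLM fence and split barrier (wait, then signal) helpers for
    // synchronizing with the compute threads.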
20170 auto fence = [&]() { slmfence(sb15, copyFenceHeader, copyFenceHeader); };
20171
20172 auto splitBarrier = [&]() {
20173 barrierwait();
20174 barriermsg(sb15, copyBarrierHeader);
20175 };
20176
20177 // Warmup.
20178 if (globalBuffers > 1) cmp(1 | gt | f0[0], kCounter, 1);
20179 if (globalBuffers > 2) cmp(1 | gt | f0[1], kCounter, 2);
20180 if (globalBuffers > 3) cmp(1 | gt | f1[0], kCounter, 3);
20181 if (globalBuffers > 4) cmp(1 | gt | f1[1], kCounter, 4);
20182 if (globalBuffers > 1) {
20183 copyLoad(0);
20184 jmpi(1 | ~f0[0], smallK);
20185 }
20186 if (globalBuffers > 2) {
20187 copyLoad(1);
20188 jmpi(1 | ~f0[1], smallK);
20189 }
20190 if (globalBuffers > 3) {
20191 copyLoad(2);
20192 jmpi(1 | ~f1[0], smallK);
20193 }
20194 if (globalBuffers > 4) {
20195 copyLoad(3);
20196 jmpi(1 | ~f1[1], smallK);
20197 }
20198
20199 auto flagLast = FlagRegister::createFromIndex(globalBuffers - 2);
20200 cmp(1 | le | flagLast, kCounter, globalBuffers);
20201 copyLoad(globalBuffers - 1);
20202 jmpi(1 | flagLast, smallK);
20203
20204 add(1 | gt | f0[0], kCounter, kCounter, -2 * globalBuffers);
20205
20206 advanceStore();
20207 advanceLoad(0);
20208 if (globalBuffers > 1) advanceLoad(1);
20209 if (globalBuffers > 2) advanceLoad(2);
20210 if (globalBuffers > 3) advanceLoad(3);
20211
20212 copyStore(0);
20213 if (globalBuffers > 4) advanceLoad(4);
20214
20215 fence();
20216 advanceStore(0);
20217 copyLoad(0);
20218 barriermsg(sb15, copyBarrierHeader);
20219 copyStore(1);
20220
20221 advanceLoad(0);
20222
20223 jmpi(1 | ~f0[0], bottom);
20224 sync.nop(SWSB<AllPipes>(1));
20225 mark(top);
20226 {
20227 add(1 | gt | f0[0], kCounter, kCounter, -globalBuffers);
20228 for (int i = 0; i < globalBuffers; i++) {
20229 fence();
20230 advanceStore((i + 1) % globalBuffers);
20231 copyLoad((i + 1) % globalBuffers); // move after barrier?
20232 splitBarrier();
20233 copyStore((i + 2) % globalBuffers);
20234 advanceLoad((i + 1) % globalBuffers);
20235 }
20236 }
20237 jmpi(1 | f0[0], top);
20238 mark(bottom);
20239
20240 if (globalBuffers > 1) cmp(1 | gt | f0[0], kCounter, -globalBuffers + 1);
20241 if (globalBuffers > 2) cmp(1 | gt | f0[1], kCounter, -globalBuffers + 2);
20242 if (globalBuffers > 3) cmp(1 | gt | f1[0], kCounter, -globalBuffers + 3);
20243 if (globalBuffers > 4) cmp(1 | gt | f1[1], kCounter, -globalBuffers + 4);
20244
20245 // Cooldown loop. All buffers but #1 loaded.
20246 for (int i = 1; i < globalBuffers; i++) {
20247 Label skipLoad;
20248 fence();
20249 advanceStore(i);
20250 jmpi(1 | ~FlagRegister::createFromIndex(i - 1), skipLoad);
20251 copyLoad(i);
20252 mark(skipLoad);
20253 splitBarrier();
20254 copyStore((i + 1) % globalBuffers);
20255 advanceLoad(i);
20256 }
20257
20258 jmpi(1, skipSmallK);
20259 mark(smallK);
20260
20261 advanceStore();
20262 copyStore(0);
20263
20264 fence();
20265 advanceStore(0);
20266 barriermsg(sb15, copyBarrierHeader);
20267 jmpi(1, reenterSmallK);
20268
20269 mark(skipSmallK);
20270
20271 for (int i = 0; i < globalBuffers - 1; i++) {
20272 fence();
20273 advanceStore(i);
20274 splitBarrier();
20275 if (i == 0) mark(reenterSmallK);
20276 jmpi(1 | ~FlagRegister::createFromIndex(i), done);
20277 copyStore(i + 1);
20278 }
20279
20280 fence();
20281 splitBarrier();
20282
20283 mark(done);
20284 barrierwait();
20285
20286 state.ra = saveRA;
20287}
20288
20289template <HW hw>
20290void gemm_kernel_generator_t<hw>::sysgemm2KLoopCompute(
20291 const GEMMProblem &problem, const GEMMStrategy &strategy,
20292 GEMMState &state) {
20293 using namespace sysgemm2;
20294
20295 Label top, remainder, done;
20296 bool _32x = (strategy.unroll[LoopM] == 32);
20297 bool x48 = (strategy.unroll[LoopN] == 48);
20298 bool keepBarHdr = strategy.skipFence || !x48;
20299 auto slmStride = strategy.slmSysgemmBlockSize() / 16;
20300 auto barrierHeader = x32::barrierHeader;
20301
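    // Clear the flag registers and derive the remaining SLM load addresses
    // from the base A/B addresses set up by the caller.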
20302 mov(1, f0.ud(0), 0);
20303 mov(1, f1.ud(0), 0);
20304 if (x48) {
20305 using namespace x48;
20306 add(1, A_addr[1], A_addr[0], 8 * 32 / 16);
20307 if (_32x) {
20308 add(1, A_addr[2], A_addr[0], 16 * 32 / 16);
20309 add(1, A_addr[3], A_addr[0], 24 * 32 / 16);
20310 }
20311 add(1, B_addr[1], B_addr[0], 8 * 32 / 16);
20312 add(1, B_addr[2], B_addr[0], 16 * 32 / 16);
20313 } else {
20314 using namespace x32;
20315 add(1, A_addr[0][1], A_addr[0][0], 8 * 32 / 16);
20316 if (_32x) {
20317 add(1, A_addr[0][2], A_addr[0][0], 16 * 32 / 16);
20318 add(1, A_addr[0][3], A_addr[0][0], 24 * 32 / 16);
20319 }
20320 add(1, A_addr[1][0], A_addr[0][0], 0 * 32 / 16 + slmStride);
20321 add(1, A_addr[1][1], A_addr[0][0], 8 * 32 / 16 + slmStride);
20322 if (_32x) {
20323 add(1, A_addr[1][2], A_addr[0][0], 16 * 32 / 16 + slmStride);
20324 add(1, A_addr[1][3], A_addr[0][0], 24 * 32 / 16 + slmStride);
20325 }
20326 add(1, B_addr[0][1], B_addr[0][0], 8 * 32 / 16);
20327 add(1, B_addr[1][0], B_addr[0][0], 0 * 32 / 16 + slmStride);
20328 add(1, B_addr[1][1], B_addr[0][0], 8 * 32 / 16 + slmStride);
20329 }
20330
20331 if (keepBarHdr) sysgemmBarrierPrep(InstructionModifier(), barrierHeader);
20332
20333 // Warmup: signal, split barrier, load
20334 cmp(1 | gt | f1[1], kCounter, 1);
20335 add(1 | gt | f0[0], kCounter, kCounter, -5);
20336
20337 if (!keepBarHdr) sysgemmBarrierPrep(InstructionModifier(), barrierHeader);
20338 barriermsg(sb15, barrierHeader);
20339 barrierwait();
20340 barriermsg(sb15 | f1[1], barrierHeader);
20341
20342 bool oldDefaultAutoSWSB = getDefaultAutoSWSB();
20343 setDefaultAutoSWSB(false);
20344 sync.nop(SWSB<AllPipes>(1));
20345
20346 load(16 | sb0, x48::A_regs[0], block_oword(16), SLM, x48::A_addr[0]);
20347 load(16 | sb1, x48::A_regs[8], block_oword(16), SLM, x48::A_addr[1]);
20348 if (_32x) {
20349 load(16 | sb2, x48::A_regs[16], block_oword(16), SLM, x48::A_addr[2]);
20350 load(16 | sb3, x48::A_regs[24], block_oword(16), SLM, x48::A_addr[3]);
20351 }
20352 load(16 | sb4, x48::B_regs[0], block_oword(16), SLM, x48::B_addr[0]);
20353 load(16 | sb5, x48::B_regs[8], block_oword(16), SLM, x48::B_addr[1]);
20354 if (x48)
20355 load(16 | sb6, x48::B_regs[16], block_oword(16), SLM, x48::B_addr[2]);
20356
20357 zeroMatrix(x48 ? x48::C_regs : x32::C_regs, strategy);
20358
20359 jmpi(1 | ~f0[0], remainder);
20360
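    // Main loop: one pass over all four SLM buffers per iteration.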
20361 mark(top);
20362 {
20363 add(1 | gt | f0[0], kCounter, kCounter, -4);
20364 sysgemm2Multiply(problem, strategy, state, 0);
20365 sysgemm2Multiply(problem, strategy, state, 1);
20366 sysgemm2Multiply(problem, strategy, state, 2);
20367 sysgemm2Multiply(problem, strategy, state, 3);
20368 }
20369 jmpi(1 | f0[0], top);
20370
20371 mark(remainder);
20372
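    // Cooldown: finish the remaining k-blocks, with the flags progressively
    // disabling further SLM loads and barrier signals.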
20373 cmp(1 | gt | f0[0], kCounter, 1 - 5);
20374 cmp(1 | gt | f0[1], kCounter, 2 - 5);
20375 cmp(1 | gt | f1[0], kCounter, 3 - 5);
20376 cmp(1 | gt | f1[1], kCounter, 4 - 5);
20377 sysgemm2Multiply(problem, strategy, state, 0, true, f0[0], f0[1]);
20378 jmpi(1 | ~f0[0], done);
20379
20380 sysgemm2Multiply(problem, strategy, state, 1, true, f0[1], f1[0]);
20381 jmpi(1 | ~f0[1], done);
20382
20383 sysgemm2Multiply(problem, strategy, state, 2, true, f1[0], f1[1]);
20384 jmpi(1 | ~f1[0], done);
20385
20386 sysgemm2Multiply(problem, strategy, state, 3, true, f1[1]);
20387 jmpi(1 | ~f1[1], done);
20388
20389 sysgemm2Multiply(problem, strategy, state, 0, true);
20390
20391 mark(done);
20392
20393 setDefaultAutoSWSB(oldDefaultAutoSWSB);
20394}
20395
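// Choose the 48-column or 32-column inner multiply based on the N unroll.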
20396template <HW hw>
20397void gemm_kernel_generator_t<hw>::sysgemm2Multiply(const GEMMProblem &problem,
20398 const GEMMStrategy &strategy, GEMMState &state, int slmBuffer,
20399 bool cooldown, FlagRegister flagWaitLoad, FlagRegister flagSignal) {
20400 if (strategy.unroll[LoopN] == 48)
20401 sysgemm2MultiplyX48(problem, strategy, state, slmBuffer, cooldown,
20402 flagWaitLoad, flagSignal);
20403 else
20404 sysgemm2MultiplyX32(problem, strategy, state, slmBuffer, cooldown,
20405 flagWaitLoad, flagSignal);
20406}
20407
20408template <HW hw>
20409void gemm_kernel_generator_t<hw>::sysgemm2MultiplyX48(
20410 const GEMMProblem &problem, const GEMMStrategy &strategy,
20411 GEMMState &state, int slmBuffer, bool cooldown,
20412 FlagRegister flagWaitLoad, FlagRegister flagSignal) {
20413 using namespace sysgemm2;
20414 using namespace sysgemm2::x48;
20415
20416 auto slmStride = strategy.slmSysgemmBlockSize() / 16;
20417 int16_t advance = ((slmBuffer == 3) ? -3 : 1) * slmStride;
20418 InstructionModifier loadMod {}, signalMod {};
20419 bool doWaitLoad = !cooldown || flagWaitLoad.isValid();
20420 bool doSignal = !cooldown || flagSignal.isValid();
20421 if (cooldown) {
20422 if (doWaitLoad) loadMod = loadMod | flagWaitLoad | any16h;
20423 if (doSignal) signalMod = signalMod | flagSignal | any16h;
20424 }
20425
20426 if (strategy.unroll[LoopM] != 32) stub();
20427
20428 if (doWaitLoad) {
20429 Label skipWait;
20430 if (cooldown) jmpi(1 | ~flagWaitLoad, skipWait);
20431
20432 // SLM fence
20433 if (!strategy.skipFence && doSignal)
20434 slmfence(sb15 | signalMod, headerTemp, headerTemp);
20435
20436 // Barrier wait
20437 barrierwait();
20438
20439 // Barrier signal
20440 if (doSignal) {
20441 if (!strategy.skipFence) {
20442 sysgemmBarrierPrep(sb15.dst | signalMod, headerTemp);
20443 barriermsg(sb15 | SWSB<AllPipes>(1) | signalMod, headerTemp);
20444 } else
20445 barriermsg(sb15 | signalMod, headerTemp);
20446 }
20447
20448 if (cooldown) mark(skipWait);
20449 }
20450
20451 // Advance A0 address (note dst)
20452 if (doWaitLoad) add(1 | sb0.dst, A_addr[0], A_addr[0], advance);
20453
20454 // Rows 0-7
20455 sysgemm2MultiplyChunkX48(problem, strategy, 0);
20456
20457 // Advance A1 address
20458 if (doWaitLoad) add(1 | sb1.src, A_addr[1], A_addr[1], advance);
20459
20460 // Rows 8-15
20461 sysgemm2MultiplyChunkX48(problem, strategy, 1);
20462
20463 if (doWaitLoad) {
20464 // Load new A0
20465 load(16 | sb0 | loadMod, A_regs[0], block_oword(16), SLM, A_addr[0]);
20466
20467 // Advance B, A2, A3 addresses
20468 add(1 | sb4.src, B_addr[0], B_addr[0], advance);
20469 add(1 | sb5.src, B_addr[1], B_addr[1], advance);
20470 add(1 | sb6.src, B_addr[2], B_addr[2], advance);
20471
20472 add(1 | sb2.src, A_addr[2], A_addr[2], advance);
20473 add(1 | sb3.src, A_addr[3], A_addr[3], advance);
20474 }
20475
20476 // Rows 16-23
20477 sysgemm2MultiplyChunkX48(problem, strategy, 2);
20478
20479 // Load new A1
20480 if (doWaitLoad)
20481 load(16 | sb1 | loadMod, A_regs[8], block_oword(16), SLM, A_addr[1]);
20482
20483 // Rows 24-31
20484 sysgemm2MultiplyChunkX48(problem, strategy, 3);
20485
20486 if (doWaitLoad) {
20487 // Load new B data
20488 load(16 | sb4 | loadMod, B_regs[0], block_oword(16), SLM, B_addr[0]);
20489 load(16 | sb5 | loadMod, B_regs[8], block_oword(16), SLM, B_addr[1]);
20490 load(16 | sb6 | loadMod, B_regs[16], block_oword(16), SLM, B_addr[2]);
20491
20492 // Load new A2,A3
20493 load(16 | sb2 | loadMod, A_regs[16], block_oword(16), SLM, A_addr[2]);
20494 load(16 | sb3 | loadMod, A_regs[24], block_oword(16), SLM, A_addr[3]);
20495 }
20496}
20497
20498template <HW hw>
20499void gemm_kernel_generator_t<hw>::sysgemm2MultiplyX32(
20500 const GEMMProblem &problem, const GEMMStrategy &strategy,
20501 GEMMState &state, int slmBuffer, bool cooldown,
20502 FlagRegister flagWaitLoad, FlagRegister flagSignal) {
20503 using namespace sysgemm2;
20504 using namespace sysgemm2::x32;
20505
20506 auto slmStride = strategy.slmSysgemmBlockSize() / 16;
20507 int16_t advance = ((slmBuffer >= 2) ? -2 : 2) * slmStride;
20508 InstructionModifier loadMod {}, signalMod {};
20509 bool doWaitLoad = !cooldown || flagWaitLoad.isValid();
20510 bool doSignal = !cooldown || flagSignal.isValid();
20511 if (cooldown) {
20512 if (doWaitLoad) loadMod = loadMod | flagWaitLoad | any16h;
20513 if (doSignal) signalMod = signalMod | flagSignal | any16h;
20514 }
20515 bool _32x = (strategy.unroll[LoopM] == 32);
20516 bool odd = (slmBuffer & 1);
20517 int tokenBase = odd ? 8 : 0;
20518 int otokenBase = odd ? 0 : 8;
20519
20520 if (doWaitLoad) {
20521 Label skipWait;
20522 if (cooldown) jmpi(1 | ~flagWaitLoad, skipWait);
20523
20524 // SLM fence
20525 if (!strategy.skipFence && doSignal)
20526 slmfence(sb15 | signalMod, fenceHeader, fenceHeader);
20527
20528 add(1 | SBID(tokenBase + 0).src, A_addr[odd][0], A_addr[odd][0],
20529 advance); // TODO: reuse src0
20530 add(1 | SBID(tokenBase + 4).src, B_addr[odd][0], B_addr[odd][0],
20531 advance);
20532 add(1 | SBID(tokenBase + 5).src, B_addr[odd][1], B_addr[odd][1],
20533 advance);
20534 add(1 | SBID(tokenBase + 1).src, A_addr[odd][1], A_addr[odd][1],
20535 advance);
20536 if (_32x) {
20537 add(1 | SBID(tokenBase + 2).src, A_addr[odd][2], A_addr[odd][2],
20538 advance);
20539 add(1 | SBID(tokenBase + 3).src, A_addr[odd][3], A_addr[odd][3],
20540 advance);
20541 }
20542
20543 // Barrier wait
20544 barrierwait();
20545
20546 if (hw >= HW::XeHPG) {
20547 // Wait for SLM loads to return before signaling.
20548 sync.allwr(0x3F << tokenBase);
20549 }
20550
20551 // Barrier signal
20552 if (doSignal) barriermsg(sb15 | signalMod, barrierHeader);
20553
20554 if (cooldown) mark(skipWait);
20555 }
20556
20557 if (doWaitLoad) {
20558 load(16 | SBID(otokenBase + 0) | loadMod, A_regs[!odd][0],
20559 block_oword(16), SLM, A_addr[!odd][0]);
20560 load(16 | SBID(otokenBase + 4) | loadMod, B_regs[!odd][0],
20561 block_oword(16), SLM, B_addr[!odd][0]);
20562 load(16 | SBID(otokenBase + 5) | loadMod, B_regs[!odd][8],
20563 block_oword(16), SLM, B_addr[!odd][1]);
20564 load(16 | SBID(otokenBase + 1) | loadMod, A_regs[!odd][8],
20565 block_oword(16), SLM, A_addr[!odd][1]);
20566 if (_32x) {
20567 load(16 | SBID(otokenBase + 2) | loadMod, A_regs[!odd][16],
20568 block_oword(16), SLM, A_addr[!odd][2]);
20569 load(16 | SBID(otokenBase + 3) | loadMod, A_regs[!odd][24],
20570 block_oword(16), SLM, A_addr[!odd][3]);
20571 }
20572 }
20573
20574 sysgemm2MultiplyChunkX32(problem, strategy, 0, odd);
20575 sysgemm2MultiplyChunkX32(problem, strategy, 1, odd);
20576 if (_32x) {
20577 sysgemm2MultiplyChunkX32(problem, strategy, 2, odd);
20578 sysgemm2MultiplyChunkX32(problem, strategy, 3, odd);
20579 }
20580}
20581
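// Multiply one 8-row chunk of A against the 48-column B panel (six dpasw
// blocks). Chunk 0 waits on the B loads; chunk 3 re-tags the B registers so
// the next B loads wait for these reads.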
20582template <HW hw>
20583void gemm_kernel_generator_t<hw>::sysgemm2MultiplyChunkX48(
20584 const GEMMProblem &problem, const GEMMStrategy &strategy, int chunkA) {
20585 using namespace sysgemm2;
20586 using namespace sysgemm2::x48;
20587 int ao = chunkA * 8;
20588 int co = ao * 6;
20589 bool waitB = (chunkA == 0);
20590 bool prepB = (chunkA == 3);
20591 SBID sbA(chunkA);
20592
20593 auto dpaswTyped = [&](InstructionModifier mod, uint8_t sdepth,
20594 uint8_t rcount, const GRF &cReg, const GRF &aReg,
20595 const GRF &bReg) {
20596 dpasw(mod, sdepth, rcount, cReg.retype(problem.Tc.ngen()),
20597 cReg.retype(problem.Tc.ngen()), aReg.retype(problem.Ta.ngen()),
20598 bReg.retype(problem.Tb.ngen()));
20599 };
20600
20601 if (waitB) {
20602 /* sync.nop(sbA.dst); */ // arranged by caller
20603 dpaswTyped(
20604 8 | sb4.dst | Atomic, 8, 8, C_regs[co], A_regs[ao], B_regs[0]);
20605 dpaswTyped(8, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]);
20606 dpaswTyped(8 | sb5.dst | Atomic, 8, 8, C_regs[co + 16], A_regs[ao],
20607 B_regs[8]);
20608 dpaswTyped(8, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]);
20609 dpaswTyped(8 | sb6.dst | Atomic, 8, 8, C_regs[co + 32], A_regs[ao],
20610 B_regs[16]);
20611 dpaswTyped(8 | sbA, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]);
20612 } else if (prepB) {
20613 dpaswTyped(
20614 8 | sbA.dst | Atomic, 8, 8, C_regs[co], A_regs[ao], B_regs[0]);
20615 dpaswTyped(8 | sb4, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]);
20616 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]);
20617 dpaswTyped(8 | sb5, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]);
20618 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], B_regs[16]);
20619 dpaswTyped(8 | sb6, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]);
20620 } else {
20621 dpaswTyped(
20622 8 | sbA.dst | Atomic, 8, 8, C_regs[co], A_regs[ao], B_regs[0]);
20623 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 8], A_regs[ao], B_regs[4]);
20624 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 16], A_regs[ao], B_regs[8]);
20625 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 24], A_regs[ao], B_regs[12]);
20626 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 32], A_regs[ao], B_regs[16]);
20627 dpaswTyped(8 | sbA, 8, 8, C_regs[co + 40], A_regs[ao], B_regs[20]);
20628 }
20629}
20630
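// Multiply one 8-row chunk of A against the 32-column B panel (four dpasw
// blocks); SWSB tokens come from the odd or even SLM buffer's token set.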
20631template <HW hw>
20632void gemm_kernel_generator_t<hw>::sysgemm2MultiplyChunkX32(
20633 const GEMMProblem &problem, const GEMMStrategy &strategy, int chunkA,
20634 bool odd) {
20635 using namespace sysgemm2;
20636 using namespace sysgemm2::x32;
20637 int ao = chunkA * 8;
20638 int co = ao * 4;
20639 int nchunks = strategy.unroll[LoopM] / 8;
20640 bool waitB = (chunkA == 0);
20641 bool prepB = (chunkA == nchunks - 1);
20642 int tokenBase = odd ? 8 : 0;
20643 SBID sbA(tokenBase + chunkA);
20644 SBID sbB0(tokenBase + 4);
20645 SBID sbB1(tokenBase + 5);
20646
20647 auto dpaswTyped = [&](InstructionModifier mod, uint8_t sdepth,
20648 uint8_t rcount, const GRF &cReg, const GRF &aReg,
20649 const GRF &bReg) {
20650 dpasw(mod, sdepth, rcount, cReg.retype(problem.Tc.ngen()),
20651 cReg.retype(problem.Tc.ngen()), aReg.retype(problem.Ta.ngen()),
20652 bReg.retype(problem.Tb.ngen()));
20653 };
20654
20655 if (waitB) {
20656 sync.nop(sbA.dst);
20657 dpaswTyped(8 | sbB0.dst | Atomic, 8, 8, C_regs[co], A_regs[odd][ao],
20658 B_regs[odd][0]);
20659 dpaswTyped(8, 8, 8, C_regs[co + 8], A_regs[odd][ao], B_regs[odd][4]);
20660 dpaswTyped(8 | sbB1.dst | Atomic, 8, 8, C_regs[co + 16],
20661 A_regs[odd][ao], B_regs[odd][8]);
20662 dpaswTyped(8 | sbA, 8, 8, C_regs[co + 24], A_regs[odd][ao],
20663 B_regs[odd][12]);
20664 } else if (prepB) {
20665 dpaswTyped(8 | sbA.dst | Atomic, 8, 8, C_regs[co], A_regs[odd][ao],
20666 B_regs[odd][0]);
20667 dpaswTyped(8 | sbB0, 8, 8, C_regs[co + 8], A_regs[odd][ao],
20668 B_regs[odd][4]);
20669 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 16], A_regs[odd][ao],
20670 B_regs[odd][8]);
20671 dpaswTyped(8 | sbB1, 8, 8, C_regs[co + 24], A_regs[odd][ao],
20672 B_regs[odd][12]);
20673 } else {
20674 dpaswTyped(8 | sbA.dst | Atomic, 8, 8, C_regs[co], A_regs[odd][ao],
20675 B_regs[odd][0]);
20676 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 8], A_regs[odd][ao],
20677 B_regs[odd][4]);
20678 dpaswTyped(8 | Atomic, 8, 8, C_regs[co + 16], A_regs[odd][ao],
20679 B_regs[odd][8]);
20680 dpaswTyped(8 | sbA, 8, 8, C_regs[co + 24], A_regs[odd][ao],
20681 B_regs[odd][12]);
20682 }
20683}
20684
20685/**********************************************************************/
20686/* Copy Kernels */
20687/**********************************************************************/
20688
20689// Initialize the interface and claim arguments.
20690template <HW hw>
20691void gemm_kernel_generator_t<hw>::copyInitInterface(
20692 CopyProblem &problem, CopyStrategy &strategy, CopyState &state) {
20693 if (strategy.barrierFreq > 0) interface.requireBarrier();
20694
20695 interface.finalize();
20696
20697 // Get input register assignments.
20698 state.inputs.S = interface.getArgumentIfExists("S");
20699 state.inputs.D = interface.getArgumentIfExists("D");
20700 state.inputs.surfaceS = interface.getArgumentSurfaceIfExists("S");
20701 state.inputs.surfaceD = interface.getArgumentSurfaceIfExists("D");
20702 state.inputs.offsetS = interface.getArgument("offset_S");
20703 state.inputs.offsetD = interface.getArgument("offset_D");
20704 state.inputs.lds = interface.getArgument("lds");
20705 state.inputs.ldd = interface.getArgumentIfExists("ldd");
20706 state.inputs.m = interface.getArgument("m");
20707 state.inputs.n = interface.getArgument("n");
20708 state.inputs.alpha_real = interface.getArgumentIfExists("alpha_real");
20709 state.inputs.alpha_imag = interface.getArgumentIfExists("alpha_imag");
20710 state.inputs.diag = interface.getArgumentIfExists("diag");
20711 state.inputs.blockZ = interface.getArgumentIfExists("block_z");
20712
20713 state.inputs.localIDW = interface.getLocalID(0);
20714 state.inputs.localSizeW = interface.getLocalSize(0);
20715 if (strategy.zParallel) {
20716 state.inputs.localIDZ = interface.getLocalID(1);
20717 state.inputs.localSizeZ = interface.getLocalSize(1);
20718 }
20719
20720 state.inputs.groupIDW = r0.ud(1);
20721 if (strategy.zParallel) state.inputs.groupIDZ = r0.ud(6);
20722
20723 // Downgrade offset variables to 32-bit for non-A64 accesses.
20724 if (strategy.S.base.getModel() != ModelA64)
20725 state.inputs.offsetS = state.inputs.offsetS.d();
20726 if (strategy.D.base.getModel() != ModelA64)
20727 state.inputs.offsetD = state.inputs.offsetD.d();
20728
20729 // For now, reinterpret m/n/ld/diag variables to 32-bit if they are 64-bit.
20730 state.inputs.m = state.inputs.m.d();
20731 state.inputs.n = state.inputs.n.d();
20732 state.inputs.lds = state.inputs.lds.ud();
20733 if (state.inputs.ldd.isValid()) state.inputs.ldd = state.inputs.ldd.ud();
20734 if (state.inputs.diag.isValid()) state.inputs.diag = state.inputs.diag.d();
20735
20736 // Claim inputs.
20737 for (int i = 0; i < 4; i++)
20738 state.ra.claim(r0.uq(i));
20739
20740 if (strategy.S.base.isStateless()) state.ra.claim(state.inputs.S);
20741 if (strategy.D.base.isStateless()) state.ra.claim(state.inputs.D);
20742
20743 state.ra.claim(state.inputs.offsetS);
20744 state.ra.claim(state.inputs.offsetD);
20745 state.ra.claim(state.inputs.lds);
20746 if (state.inputs.ldd.isValid()) state.ra.claim(state.inputs.ldd);
20747 state.ra.claim(state.inputs.m);
20748 state.ra.claim(state.inputs.n);
20749 if (state.inputs.diag.isValid()) state.ra.claim(state.inputs.diag);
20750 if (!problem.alpha_real.fixed()) state.ra.claim(state.inputs.alpha_real);
20751 if (problem.Td.isComplex() && !problem.alpha_imag.fixed())
20752 state.ra.claim(state.inputs.alpha_imag);
20753
20754 state.ra.claim(state.inputs.localIDW);
20755 state.ra.claim(state.inputs.localSizeW);
20756 if (strategy.zParallel) {
20757 state.ra.claim(state.inputs.localIDZ);
20758 state.ra.claim(state.inputs.localSizeZ);
20759 }
20760
20761 if (strategy.zParallel) state.ra.claim(state.inputs.blockZ);
20762}
20763
20764// Initialize the state structure.
20765template <HW hw>
20766void gemm_kernel_generator_t<hw>::copyInitState(
20767 CopyProblem &problem, CopyStrategy &strategy, CopyState &state) {
20768 if (!state.fusedGEMM.active) {
20769 initState(problem, strategy, state);
20770 copyInitInterface(problem, strategy, state);
20771 state.isNested = false;
20772 }
20773
20774 state.effS = strategy.S.base.isStateless() ? state.inputs.S
20775 : state.inputs.offsetS.d();
20776 state.effD = strategy.D.base.isStateless() ? state.inputs.D
20777 : state.inputs.offsetD.d();
20778
20779 if (!problem.alpha_real.fixed())
20780 problem.alpha_real = state.inputs.alpha_real;
20781 if (problem.Td.isComplex() && !problem.alpha_imag.fixed())
20782 problem.alpha_imag = state.inputs.alpha_imag;
20783
20784 state.flagAP = state.raVFlag.alloc();
20785
20786 state.allocEmulate64Temp(strategy.emulate);
20787}
20788
20789// Copy kernel generation interface.
20790template <HW hw>
20791void gemm_kernel_generator_t<hw>::copy(CopyProblem problem,
20792 CopyStrategy strategy, const InterfaceHandler &interface_) {
20793 interface = interface_;
20794 CopyState state(hw);
20795 copy(problem, strategy, state);
20796}
20797
20798template <HW hw>
20799void gemm_kernel_generator_t<hw>::copy(
20800 CopyProblem &problem, CopyStrategy &strategy, CopyState &state) {
20801 bool inFused = state.fusedGEMM.active;
20802 auto unrollW = strategy.unrollW();
20803
20804 // Check layouts.
20805 if (!isPacked(problem.D.layout)) stub();
20806
20807 if (strategy.zParallel && problem.sum) stub();
20808
20809 // By default, don't use dispatch mask.
20810 setDefaultNoMask();
20811 setDefaultAutoSWSB();
20812
20813 // Set up.
20814 copyInitState(problem, strategy, state);
20815
20816 if (!strategy.S.base.isStateless())
20817 strategy.S.base.setIndex(state.inputs.surfaceS);
20818 if (!strategy.D.base.isStateless())
20819 strategy.D.base.setIndex(state.inputs.surfaceD);
20820
20821 // Prologue.
20822 if (!inFused) prologue(strategy);
20823
20824 // Grab fused ID if needed.
20825 getFusedID(unrollW, problem, strategy, state);
20826
20827 // Calculate w0, the starting row/column for this thread.
    // This is the first x (if xLoop == false) or y (xLoop == true) value.
20829 state.w0 = state.ra.alloc_sub<uint32_t>(
20830 getHint(HintType::TempComp0, strategy));
20831 if (strategy.zParallel)
20832 state.z0 = state.ra.alloc_sub<uint32_t>(
20833 getHint(HintType::TempComp1, strategy));
20834
20835 auto globalIDW = state.ra.alloc_sub<uint32_t>(
20836 getHint(HintType::TempComp1, strategy));
20837 auto globalIDZ = state.ra.alloc_sub<uint32_t>(
20838 getHint(HintType::TempComp0, strategy));
20839
20840 int idWScale = inFused ? 1 : strategy.subgroupSize;
20841 bool multiple = (unrollW % idWScale) == 0;
20842
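    // Compute w0 = globalIDW * unrollW / idWScale. When unrollW is a multiple of idWScale, the
    // division folds into a single constant multiply; otherwise multiply first, then shift.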
20843 if (strategy.wgW > 0)
20844 mulConstant(
20845 1, globalIDW, state.inputs.groupIDW, strategy.wgW * idWScale);
20846 else
20847 mul(1, globalIDW, state.inputs.groupIDW, state.inputs.localSizeW.uw());
20848 if (strategy.zParallel) {
20849 if (strategy.wgZ > 0)
20850 mulConstant(1, globalIDZ, state.inputs.groupIDZ, strategy.wgZ);
20851 else
20852 mul(1, globalIDZ, state.inputs.groupIDZ,
20853 state.inputs.localSizeZ.uw());
20854 }
20855 add(1, globalIDW, globalIDW, state.inputs.localIDW.uw(0));
20856 if (strategy.zParallel && (strategy.wgZ != 1))
20857 add(1, globalIDZ, globalIDZ, state.inputs.localIDZ.uw(0));
20858 if (multiple)
20859 mulConstant(1, state.w0, globalIDW, unrollW / idWScale);
20860 else {
20861 mulConstant(1, state.w0, globalIDW, unrollW);
20862 shr(1, state.w0, state.w0, log2(idWScale));
20863 }
20864 if (strategy.zParallel)
20865 emul(1, state.z0, globalIDZ, state.inputs.blockZ, strategy, state);
20866
20867 state.ra.safeRelease(globalIDW);
20868 state.ra.safeRelease(globalIDZ);
20869 state.ra.safeRelease(state.inputs.localIDW);
20870 state.ra.safeRelease(state.inputs.localIDZ);
20871 state.ra.safeRelease(state.inputs.localSizeW);
20872 state.ra.safeRelease(state.inputs.localSizeZ);
20873
20874 // Move r0 to acc0 if configured.
20875 moveR0(strategy, state);
20876
20877 // Copy our slice.
20878 copySlice(problem, strategy, state);
20879
20880 if (!inFused) {
20881 epilogue(strategy, state);
20882
20883 padding();
20884 }
20885}
20886
20887// Calculate or recalculate lds_sl/ldd_dl as needed.
20888template <HW hw>
20889void gemm_kernel_generator_t<hw>::copyCalcIncrements(const CopyProblem &problem,
20890 const CopyStrategy &strategy, CopyState &state, int s_load,
20891 int d_load) {
    // S: lds * s_load is needed for N->Pc, T->Pr [!xLoop]; N->Pr, T->Pc [xLoop]
    // D: no increment needed (D always packed) [!xLoop]; ldd * d_load [xLoop]
20894 bool sStrided
20895 = (isColMajor(problem.S.layout) == isColMajor(problem.D.layout))
20896 ^ strategy.xLoop;
20897
20898 if (sStrided || problem.reflecting()) {
20899 if (s_load == 0) s_load = strategy.s_load;
20900 if (s_load > 1) {
20901 if (state.lds_sl.isInvalid()) {
20902 state.lds_sl = state.ra.alloc_sub<uint32_t>();
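                // On the first call lds is still in element units (the byte conversion happens
                // later in copySlice), so fold the element size into the multiplier once here.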
20903 s_load *= problem.Ts.size();
20904 }
20905 emulConstant(
20906 1, state.lds_sl, state.inputs.lds, s_load, strategy, state);
20907 }
20908 }
20909
20910 if (strategy.xLoop) {
20911 if (d_load == 0) d_load = strategy.d_load;
20912 if (d_load > 1) {
20913 if (state.ldd_dl.isInvalid()) {
20914 state.ldd_dl = state.ra.alloc_sub<uint32_t>();
20915 d_load *= problem.Td.size();
20916 }
20917 emulConstant(
20918 1, state.ldd_dl, state.inputs.ldd, d_load, strategy, state);
20919 }
20920 }
20921}
20922
// Generate the main body of the copy kernel for this thread's slice.
20924template <HW hw>
20925void gemm_kernel_generator_t<hw>::copySlice(
20926 CopyProblem &problem, CopyStrategy &strategy, CopyState &state) {
20927 auto Ts = problem.Ts, Td = problem.Td;
20928 Label labelExit;
20929 Subregister lddSrc;
20930
    // If ldd was not specified, derive it from the y dimension (n for Pc layout, m for Pr).
20932 if (state.inputs.ldd.isInvalid()) {
20933 state.inputs.ldd = lddSrc = (problem.D.layout == MatrixLayout::Pc)
20934 ? state.inputs.n
20935 : state.inputs.m;
20936 if (problem.D.crosspack > 1 || problem.sum || strategy.zParallel) {
20937 state.inputs.ldd = state.ra.alloc_sub<uint32_t>(
20938 getHint(HintType::LongTerm, strategy));
20939 mov(1, state.inputs.ldd, lddSrc);
20940 lddSrc = invalid;
20941 }
20942 if (problem.D.crosspack > 1) {
20943 add(1, state.inputs.ldd, state.inputs.ldd, problem.D.crosspack - 1);
20944 and_(1, state.inputs.ldd, state.inputs.ldd,
20945 ~uint32_t(problem.D.crosspack - 1));
20946 }
20947 if (problem.sum)
20948 add(1, state.inputs.ldd, state.inputs.ldd,
20949 problem.Tsum.size() / problem.Td.size());
20950 }
20951
20952 // Duplicate alpha if configured.
20953 if (strategy.duplicateAlpha) { duplicateScalar(problem.alpha_real, state); }
20954
20955 // For fused kernels, compute 2 * unrollW - fusedID for use in several places.
20956 Subregister unrollWRem;
20957 if (strategy.fused) {
20958 unrollWRem = state.ra.alloc_sub<uint32_t>(
20959 getHint(HintType::TempComp0, strategy));
20960 add(1, unrollWRem, -state.fusedID, uint16_t(2 * strategy.unrollW()));
20961 }
20962
20963 // Align code paths.
20964 bool mLoop = isColMajor(problem.D.layout) == strategy.xLoop;
20965 auto z = mLoop ? state.inputs.m : state.inputs.n;
20966 Subregister z0;
20967
20968 // Handle z blocking.
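    // z0 is this slice's starting z offset; clamp the local z extent to blockZ and set f0[1]
    // if the slice is empty (used by the quick exit below).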
20969 if (strategy.zParallel) {
20970 z0 = state.z0;
20971 add(1 | le | f0[1], z, z, -z0);
20972 min_(1, z, z, state.inputs.blockZ);
20973 state.ra.safeRelease(state.inputs.blockZ);
20974 }
20975
20976 // Compute base addresses for S, D.
20977 // S += w0 + z0 * lds (N->Pc, T->Pr) z0 + w0 * lds (N->Pr, T->Pc) [swapped if xLoop = true]
20978 bool sStrided
20979 = (isColMajor(problem.S.layout) == isColMajor(problem.D.layout))
20980 ^ strategy.xLoop;
20981 auto incC = sStrided ? state.w0 : z0;
20982 auto incS = sStrided ? z0 : state.w0;
20983
20984 if (incC.isValid())
20985 eadd(1, state.inputs.offsetS, state.inputs.offsetS, incC, strategy,
20986 state);
20987 if (incS.isValid()) {
20988 Subregister temp = state.ra.alloc_sub(state.inputs.offsetS.getType(),
20989 getHint(HintType::TempComp1, strategy));
20990 emul(1, temp, incS, state.inputs.lds, strategy, state);
20991 eadd(1, state.inputs.offsetS, state.inputs.offsetS, temp, strategy,
20992 state);
20993 state.ra.safeRelease(temp);
20994 }
20995
20996 // Quick exit if no work to do.
20997 if (strategy.zParallel) jmpi(1 | f0[1], labelExit);
20998
20999 // D += align_up(x0, unroll) * ldd + y0 * unroll + (x0 % unroll) * crosspack
21000 {
21001 Subregister temp0 = state.ra.alloc_sub(state.inputs.offsetD.getType(),
21002 getHint(HintType::TempComp0, strategy));
21003 Subregister temp1 = state.ra.alloc_sub(state.inputs.offsetD.getType(),
21004 getHint(HintType::TempComp1, strategy));
21005 Subregister temp2 = state.ra.alloc_sub<uint32_t>(
21006 getHint(HintType::TempComp0, strategy));
21007 auto x0 = strategy.xLoop ? z0 : state.w0;
21008 auto y0 = strategy.xLoop ? state.w0 : z0;
21009 bool splitX = strategy.unrollX < problem.D.packSize;
21010
21011 if (x0.isValid()) {
21012 if (splitX) {
21013 modExt(temp2, temp1.ud(), x0, problem.D.packSize, strategy,
21014 state);
21015 emul(1, temp0, temp1.ud(), state.inputs.ldd, strategy, state);
21016 mulConstant(1, temp2, temp2, problem.D.crosspack);
21017 } else
21018 emul(1, temp0, x0, state.inputs.ldd, strategy, state);
21019 }
21020 if (y0.isValid())
21021 emulConstant(1, temp1, y0, problem.D.packSize, strategy, state);
21022 if (x0.isValid())
21023 eadd(1, state.inputs.offsetD, state.inputs.offsetD, temp0, strategy,
21024 state);
21025 if (y0.isValid())
21026 eadd(1, state.inputs.offsetD, state.inputs.offsetD, temp1, strategy,
21027 state);
21028 if (x0.isValid() && splitX)
21029 eadd(1, state.inputs.offsetD, state.inputs.offsetD, temp2, strategy,
21030 state);
21031
21032 state.ra.safeRelease(temp0);
21033 state.ra.safeRelease(temp1);
21034 state.ra.safeRelease(temp2);
21035 }
21036
21037 state.ra.safeRelease(z0);
21038 state.z0 = invalid;
21039
21040 // Calculate increments.
21041 copyCalcIncrements(problem, strategy, state);
21042
21043 // Calculate remainders for w loop as needed.
21044 if (!strategy.xLoop
21045 && (strategy.remHandlingX != RemainderHandling::Ignore)) {
21046 auto x = (problem.D.layout == MatrixLayout::Pc) ? state.inputs.m
21047 : state.inputs.n;
21048 state.remainderX = state.ra.alloc_sub<uint32_t>();
21049 add(1 | sat, state.remainderX, -state.w0, x);
21050 if (strategy.remHandlingX == RemainderHandling::Split) {
21051 if (strategy.fused)
21052 cmp(1 | lt | state.flagAP, null.ud(), state.remainderX,
21053 unrollWRem);
21054 else
21055 cmp(1 | lt | state.flagAP, null.ud(), state.remainderX,
21056 strategy.unrollX);
21057 mov(1 | ~state.flagAP, state.remainderX, strategy.unrollX);
21058 } else
21059 min_(1, state.remainderX, state.remainderX, strategy.unrollX);
21060 }
21061 if (strategy.xLoop
21062 && (strategy.remHandlingY != RemainderHandling::Ignore)) {
21063 auto y = (problem.D.layout == MatrixLayout::Pc) ? state.inputs.n
21064 : state.inputs.m;
21065 state.remainderY = state.ra.alloc_sub<uint32_t>();
21066 add(1 | sat, state.remainderY, -state.w0, y);
21067 if (strategy.remHandlingY == RemainderHandling::Split) {
21068 if (strategy.fused)
21069 cmp(1 | lt | state.flagAP, null.ud(), state.remainderY,
21070 unrollWRem);
21071 else
21072 cmp(1 | lt | state.flagAP, null.ud(), state.remainderY,
21073 strategy.unrollY);
21074 mov(1 | ~state.flagAP, state.remainderY, strategy.unrollY);
21075 } else
21076 min_(1, state.remainderY, state.remainderY, strategy.unrollY);
21077 }
21078
21079 // Convert lds to bytes.
21080 emulConstant(
21081 1, state.inputs.lds, state.inputs.lds, Ts.size(), strategy, state);
21082
21083 // Add offsets to base pointers for stateless accesses.
21084 emulConstant(1, state.inputs.offsetS, state.inputs.offsetS, Ts.size(),
21085 strategy, state);
21086 emulConstant(1, state.inputs.offsetD, state.inputs.offsetD, Td.size(),
21087 strategy, state);
21088
21089 if (strategy.S.base.isStateless()) {
21090 eadd(1, state.inputs.S, state.inputs.S, state.inputs.offsetS, strategy,
21091 state);
21092
21093 state.ra.safeRelease(state.inputs.offsetS);
21094 } else
21095 state.effS1 = state.offsetS1;
21096
21097 if (strategy.D.base.isStateless()) {
21098 eadd(1, state.inputs.D, state.inputs.D, state.inputs.offsetD, strategy,
21099 state);
21100 state.ra.safeRelease(state.inputs.offsetD);
21101 }
21102
21103 state.ra.safeRelease(unrollWRem);
21104
21105 if (!copyBody(problem, strategy, state)) {
21106 lastException ? std::rethrow_exception(lastException)
21107 : throw std::runtime_error("Could not generate kernel.");
21108 }
21109
21110 mark(labelExit);
21111}
21112
21113// Wrapper around copyBodyRemCheck, checking for optimally-aligned S.
21114template <HW hw>
21115bool gemm_kernel_generator_t<hw>::copyBody(
21116 CopyProblem &problem, CopyStrategy &strategy, CopyState &state) {
21117 if (!is_zero_or_pow2(strategy.optionalAlignS)) stub();
21118
21119 bool success;
21120
21121 if (strategy.optionalAlignS == 0)
21122 success = copyBodyRemCheck(problem, strategy, state);
21123 else {
21124 Label labelUnaligned, labelEnd;
21125
21126 status << "S alignment check" << status_stream::endl;
21127 and_(1 | nz | f0[1], null.uw(), state.effS.uw(),
21128 uint16_t(strategy.optionalAlignS - 1));
21129 and_(1 | nz | f1[1], null.uw(), state.inputs.lds.uw(),
21130 uint16_t(strategy.optionalAlignS - 1));
21131 ejmpi(1 | f0[1] | anyv, labelUnaligned);
21132
21133 auto modProblem = problem;
21134 modProblem.S.setAlignment(strategy.optionalAlignS);
21135
21136 status << "S aligned to " << strategy.optionalAlignS << ':'
21137 << status_stream::endl;
21138 success = copyBodyRemCheck(modProblem, strategy, state);
21139
21140 if (state.isNested)
21141 jmpi(1, labelEnd);
21142 else
21143 epilogue(strategy, state);
21144
21145 mark(labelUnaligned);
21146
21147 status << "S unaligned" << status_stream::endl;
21148 success = success && copyBodyRemCheck(problem, strategy, state);
21149
21150 mark(labelEnd);
21151 }
21152
21153 return success;
21154}
21155
21156// Wrapper around copyBodyInternal, handling split remainders.
21157template <HW hw>
21158bool gemm_kernel_generator_t<hw>::copyBodyRemCheck(
21159 CopyProblem &problem, CopyStrategy &strategy, CopyState &state) {
21160 auto CopyStrategy::*remHandlingW
21161 = (strategy.xLoop ? &CopyStrategy::remHandlingY
21162 : &CopyStrategy::remHandlingX);
21163 bool wSplit = strategy.*remHandlingW == RemainderHandling::Split;
21164 bool success;
21165
21166 if (!wSplit)
21167 success = copyBodyInternal(problem, strategy, state);
21168 else {
21169 CopyStrategy modStrategy = strategy;
21170 Label wRemBegin, wRemEnd;
21171 jmpi(1 | state.flagAP, wRemBegin);
21172
21173 status << "Generating "
21174 << "xy"[strategy.xLoop] << " non-remainder kernel"
21175 << status_stream::endl;
21176 modStrategy.*remHandlingW = RemainderHandling::Ignore;
21177 success = copyBodyInternal(problem, modStrategy, state);
21178
21179 if (state.isNested)
21180 jmpi(1, wRemEnd);
21181 else
21182 epilogue(strategy, state);
21183
21184 modStrategy.*remHandlingW = RemainderHandling::KnownRemainder;
21185
21186 bool recalc = false;
21187
21188 if (strategy.xLoop && !isTransposing(modStrategy.D.accessType)
21189 && !isLargeCrosspack(problem.Td, problem.D.crosspack)) {
21190 // Change D access to use scattered stores so masking is possible.
21191 modStrategy.D.accessType = AccessType::Scattered;
21192 modStrategy.S.accessType = isTransposing(modStrategy.S.accessType)
21193 ? AccessType::Block
21194 : AccessType::Scattered;
21195 }
21196 if (!strategy.xLoop && !strategy.S.padded) {
21197 // Check if we need to change s_load/d_load.
21198 if (strategy.s_load > strategy.s_load_masked) {
21199 status << "Downgrading s_load: " << strategy.s_load << " -> "
21200 << strategy.s_load_masked << status_stream::endl;
21201 modStrategy.s_load = strategy.s_load_masked;
21202 recalc = true;
21203 }
21204 if (strategy.d_load > strategy.d_load_masked) {
21205 status << "Downgrading d_load: " << strategy.d_load << " -> "
21206 << strategy.d_load_masked << status_stream::endl;
21207 modStrategy.d_load = strategy.d_load_masked;
21208 recalc = true;
21209 }
21210 }
21211
21212 status << "Generating "
21213 << "xy"[strategy.xLoop] << " remainder kernel"
21214 << status_stream::endl;
21215 mark(wRemBegin);
21216 if (recalc) copyCalcIncrements(problem, modStrategy, state);
21217 success = success && copyBodyInternal(problem, modStrategy, state);
21218 mark(wRemEnd);
21219 }
21220
21221 return success;
21222}
21223
21224// Body of copy kernel.
21225template <HW hw>
21226bool gemm_kernel_generator_t<hw>::copyBodyInternal(
21227 CopyProblem &problem, CopyStrategy &strategy, CopyState &state) {
21228 Label lZLoopBegin, lZLoopEnd;
21229 constexpr auto SD_copies = 1;
21230 vector<MaskAssignment> masks;
21231 bool share;
21232
21233 auto Ts = problem.Ts, Td = problem.Td, Tsum = problem.Tsum;
21234 const bool byColumn = isColMajor(problem.D.layout);
21235 const bool sStrided
21236 = (isColMajor(problem.S.layout) == isColMajor(problem.D.layout))
21237 ^ strategy.xLoop;
21238 const bool mLoop = isColMajor(problem.D.layout) == strategy.xLoop;
21239
21240 const bool reflecting = false;
21241 const bool triRemOnly = false;
21242
21243 auto crosspack = problem.D.crosspack;
21244
21245 // Release w0 -- no longer needed.
21246 state.ra.safeRelease(state.w0);
21247
21248 // Get flag register for complex swizzles for XeHP+.
21249 if (hw >= HW::XeHP && Ts.isComplex()) {
21250 state.flagSwizzle = state.raVFlag.alloc();
21251 state.raVFlag.unlock(state.flagSwizzle);
21252 }
21253
21254 MatrixAddressingStrategy S_strategyReflected = strategy.S;
21255 vector<RegisterBlock> S_layoutReflected;
21256
21257 // Decide what remainder handling needs to be done.
21258 bool remainderX = (strategy.remHandlingX != RemainderHandling::Ignore);
21259 bool remainderY = (strategy.remHandlingY != RemainderHandling::Ignore);
21260 bool remainderZ = strategy.xLoop ? remainderX : remainderY;
21261
21262 bool checkYRem1 = strategy.xLoop && remainderY && strategy.unrollY == 1;
21263 VirtualFlag flagYRem1;
21264
21265 remainderY &= !checkYRem1;
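    // With unrollY == 1 in an x-loop kernel, the y remainder reduces to a single "any work left?"
    // test, handled with one flag (flagYRem1) instead of full remainder masking.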
21266
    // Set up register layouts, masks, addresses, and data registers for S and D;
    // returns false if this configuration cannot be handled.
21268 int nms, nmd, nns, nnd;
21269 auto setup = [&](int s_load, int d_load, Subregister S_addr0,
21270 Subregister S1_addr0, Subregister D_addr0,
21271 bool handleRemZ) -> bool {
21272 bool remM = remainderX && (!strategy.xLoop || handleRemZ);
21273 bool remN = remainderY && (strategy.xLoop || handleRemZ);
21274 Subregister remainders[3]
21275 = {state.remainderX, state.remainderY, Subregister {}};
21276
21277 if (!strategy.xLoop) {
21278 nmd = nms = strategy.unrollX;
21279 nnd = d_load;
21280 nns = s_load;
21281 } else {
21282 nnd = nns = strategy.unrollY;
21283 nmd = d_load;
21284 nms = s_load;
21285 }
21286
21287 if (!byColumn) {
21288 std::swap(nms, nns);
21289 std::swap(nmd, nnd);
21290 std::swap(remM, remN);
21291 std::swap(remainders[0], remainders[1]);
21292 }
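        // At this point (nms, nns) and (nmd, nnd) are the row/column extents of the S and D tiles,
        // and remM/remN indicate which dimensions need remainder masking.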
21293
21294 auto remM_S = remM && !strategy.S.padded;
21295 auto remN_S = remN && !strategy.S.padded;
21296 auto remM_D = remM && !strategy.D.padded && !byColumn;
21297 auto remN_D = remN && !strategy.D.padded && byColumn;
21298
21299 auto sMaxRBlock = 0;
21300 auto sMaxCBlock = 0;
21301
21302 if (!getRegLayout(Ts, state.S_layout, nms, nns, remM_S, remN_S, false,
21303 true, sMaxRBlock, sMaxCBlock, problem.S, strategy.S))
21304 return false;
21305 if (!getRegLayout(Td, state.D_layout, nmd, nnd, remM_D, remN_D, true,
21306 true, 0, 0, problem.D, strategy.D))
21307 return false;
21308
21309 if (hasFragmenting(state.S_layout) || hasFragmenting(state.D_layout)) {
21310 status << "Fragmenting not supported." << status_stream::endl;
21311 return false;
21312 }
21313
21314 bool success = true;
21315 if (checkYRem1) {
21316 flagYRem1 = state.raVFlag.allocVirtual();
21317 success &= !(state.raVFlag.isVirtual(flagYRem1)
21318 && state.vflagStorage.isInvalid());
21319 }
21320
21321 // Find and load any needed mask registers.
21322 success = success
21323 && assignMasks(
21324 state.S_layout, LoopM, LoopN, masks, strategy, state)
21325 && assignMasks(
21326 state.D_layout, LoopM, LoopN, masks, strategy, state);
21327
21328 if (!success && state.vflagStorage.isInvalid()) {
21329 status << "Retrying with virtual flags." << status_stream::endl;
21330 allocVFlagStorage(strategy, state);
21331 success = assignMasks(state.S_layout, LoopM, LoopN, masks, strategy,
21332 state)
21333 && assignMasks(state.D_layout, LoopM, LoopN, masks,
21334 strategy, state);
21335 }
21336
21337 if (!success) return false;
21338
21339 loadMasks(masks, remainders, strategy, state);
21340
21341 if (!strategy.xLoop && !remM_D && !remN_D
21342 && strategy.remHandlingX != RemainderHandling::Ignore) {
21343 // Find a mask to use for destination layout for y loop remainders.
21344 VirtualFlag flag;
21345 bool found = false;
21346 for (auto &mask : masks)
21347 if (mask.var == (byColumn ? LoopM : LoopN) && mask.offset == 0)
21348 flag = mask.flag, found = true;
21349 if (!found) stub();
21350 for (auto &block : state.D_layout) {
21351 if (block.simdSize > 16) stub();
21352 block.flag = flag;
21353 block.flagAny = true;
21354 }
21355 } else if (checkYRem1) {
21356 // Create mask for y remainder for x-loop kernels with unrollY == 1, and
21357 // apply it by hand to both source and destination.
21358 RegData regYRem1 = getMaskFlag(flagYRem1, state);
21359 FlagRegister testFlag;
21360
21361 testFlag = regYRem1.isARF()
21362 ? reinterpret_cast<FlagRegister &>(regYRem1)
21363 : f0[1];
21364
21365 cmp(16 | gt | testFlag, state.remainderY, 0);
21366
21367 for (auto &mask : masks)
21368 mov(1 | ~testFlag, getMaskFlag(mask.flag, state), 0);
21369 if (!regYRem1.isARF()) mov(1, regYRem1, testFlag);
21370
21371 for (auto &block : state.S_layout)
21372 if (!block.flag) block.flag = flagYRem1;
21373 for (auto &block : state.D_layout) {
21374 if (block.simdSize > 16) stub();
21375 block.flag = flagYRem1;
21376 }
21377 }
21378
21379 // Match source layout to destination layout if possible, so that they can share registers.
21380 share = (Ts == Td) && (s_load == d_load)
21381 && matchLayoutsBidirectional(
21382 Ts, state.S_layout, state.D_layout);
21383
21384 // Allocate address registers.
21385 allocAddrRegs(state.S_addrs, state.S_layout, problem.S, strategy.S,
21386 state,
21387 getHint(share ? HintType::DAddr : HintType::SAddr, strategy));
21388 allocAddrRegs(state.D_addrs, state.D_layout, problem.D, strategy.D,
21389 state, getHint(HintType::DAddr, strategy));
21390
21391 // Set up address registers.
21392 setupAddr(Ts, state.S_addrs, S_addr0, state.S_layout, state.inputs.lds,
21393 problem.S, strategy.S, strategy, state);
21394 setupAddr(Td, state.D_addrs, D_addr0, state.D_layout, state.inputs.ldd,
21395 problem.D, strategy.D, strategy, state);
21396
21397 // Allocate data registers.
21398 int S_regCount = getRegCount(state.S_layout);
21399 int D_regCount = getRegCount(state.D_layout);
21400
21401 state.D_regs = state.ra.alloc_range(
21402 D_regCount, getHint(HintType::D, strategy));
21403 state.S_regs = share ? state.D_regs
21404 : state.ra.alloc_range(S_regCount,
21405 getHint(HintType::S, strategy));
21406
21407 // Prepare for summation.
21408 // Clean up previous sums if any, and try to reuse their registers.
21409 // Allocate and zero new sum registers as needed.
21410 if (problem.sum) {
21411 if (strategy.xLoop) stub();
21412
21413 vector<RegisterBlock> Ds_layout;
21414 makeSumLayout(!byColumn, Td, state.D_layout, Tsum, Ds_layout,
21415 strategy, state);
21416
21417 bool alloc = state.Ds_layout.empty()
21418 || !matchLayouts(Tsum, Ds_layout, state.Ds_layout);
21419 if (!state.Ds_layout.empty() && alloc) {
21420 horizontalAdd(!byColumn, Tsum, state.Ds_regs.back(),
21421 state.Ds_layout, state);
21422 alloc = !matchLayouts(Tsum, Ds_layout, state.Ds_layout);
21423 }
21424 if (alloc) {
21425 state.Ds_layout = std::move(Ds_layout);
21426 auto Ds_regs
21427 = state.ra.alloc_range(getRegCount(state.Ds_layout));
21428 zeroMatrix(Ds_regs, strategy);
21429 state.Ds_regs.push_back(Ds_regs);
21430 }
21431 }
21432
21433 return true;
21434 };
21435
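    // Release per-pass resources (masks, address and data registers, layouts); sum registers
    // persist across passes and are released separately.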
21436 auto cleanup = [&]() {
21437 state.raVFlag.safeRelease(flagYRem1);
21438 safeReleaseMaskAssignments(masks, state);
21439 safeReleaseRanges(state.S_addrs, state);
21440 safeReleaseRanges(state.D_addrs, state);
21441
21442 state.ra.safeRelease(state.S_regs);
21443 state.ra.safeRelease(state.D_regs);
21444 // Sum registers not freed here.
21445
21446 state.S_layout.clear();
21447 state.D_layout.clear();
21448 };
21449
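    // Load the next s_load rows/columns of S and advance the S addresses. When checkRem is set,
    // the load is skipped (leaving zeros) once the z remainder is exhausted.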
21450 auto doSLoad = [&](const vector<RegisterBlock> &layout,
21451 const vector<RegisterBlock> &layoutReflect,
21452 const vector<GRFRange> &addrs,
21453 const vector<GRFRange>(&addrSrcs)[2], int z0,
21454 int s_load, int S_copy, bool checkRem) {
21455 bool unlockAP = false;
21456 Label skipLoad;
21457 checkRem &= (z0 > 0);
21458
21459 if (checkRem) {
21460 zeroMatrix(state.S_regs, strategy);
21461 unlockAP = !state.raVFlag.lock(state.flagAP);
21462 state.usePhysicalFlag(state.flagAP);
21463 cmp(1 | le | state.flagAP, state.Z, uint16_t(z0));
21464 jmpi(1 | state.flagAP, skipLoad);
21465 }
21466
21467 {
21468 loadMatrix(state.S_regs, layout, problem.S, strategy.S, addrs,
21469 strategy, state);
21470 }
21471
21472 auto addrsFixed = reflecting ? &addrSrcs[0] : &addrs;
21473 auto addrsStrided = reflecting ? &addrSrcs[1] : nullptr;
21474 auto layoutFixed = &layout;
21475 auto layoutStrided = &layoutReflect;
21476
21477 if (sStrided) {
21478 std::swap(addrsFixed, addrsStrided);
21479 std::swap(layoutFixed, layoutStrided);
21480 }
21481
21482 {
21483 if (addrsStrided)
21484 incAddr(*addrsStrided,
21485 (s_load == 1) ? state.inputs.lds : state.lds_sl,
21486 *layoutStrided, problem.S, strategy.S, strategy, state);
21487 if (addrsFixed)
21488 incAddr(*addrsFixed, uint16_t(s_load * Ts), *layoutFixed,
21489 problem.S, strategy.S, strategy, state);
21490 }
21491 if (checkRem) {
21492 if (unlockAP) state.raVFlag.unlock(state.flagAP);
21493 mark(skipLoad);
21494 }
21495 };
21496
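    // Store the current D tile, accumulate it into the running sums if requested, and advance
    // the D addresses for the next d_load iterations.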
21497 auto doDStore = [&](const vector<RegisterBlock> &layout,
21498 const vector<GRFRange> &addrs, int d_load,
21499 int D_copy) {
21500 storeMatrix(state.D_regs, layout, problem.D, strategy.D, addrs,
21501 strategy, state);
21502 if (problem.sum)
21503 accumulateSum(!byColumn, Td, state.D_regs, layout, Tsum,
21504 state.Ds_regs.back(), state.Ds_layout, strategy, state);
21505 if (strategy.xLoop) {
21506 if (d_load >= strategy.unrollX)
21507 incAddr(addrs, state.ldd_dl, layout, problem.D, strategy.D,
21508 strategy, state);
21509 else
21510 incAddr(addrs, uint16_t(d_load * Td), layout, problem.D,
21511 strategy.D, strategy, state);
21512 } else {
21513 auto D_tileX = byColumn ? problem.D.tileR : problem.D.tileC;
21514 auto D_tileY = byColumn ? problem.D.tileC : problem.D.tileR;
21515 auto effPS = (d_load < D_tileY) ? D_tileX : problem.D.packSize;
21516 incAddr(addrs, uint16_t(d_load * effPS * Td), layout, problem.D,
21517 strategy.D, strategy, state);
21518 }
21519 };
21520
21521 // Start generating code.
21522
21523 // Reuse z for the loop counter.
21524 // If z unroll > 1, the loop counter will be offset by (unrollZ - 1) during the main loop,
21525 // unless there's no z remainder.
21526 // For triangular-ended copies, offset by an additional unrollW [2x unrollX if fused] to push triangular handling to remainder loop.
21527 state.Z = mLoop ? state.inputs.m : state.inputs.n;
21528
21529 auto unrollZ = strategy.unrollZ();
21530 auto offsetZ = (remainderZ || triRemOnly) ? (unrollZ - 1) : 0;
21531
21532 if (offsetZ == 0)
21533 cmp(1 | le | state.flagAP, null.d(), state.Z, int16_t(0));
21534 else
21535 add(1 | le | state.flagAP, state.Z, state.Z, int16_t(-offsetZ));
21536
21537 // Get flag register and loop counter for barrier check if needed.
21538 FlagRegister flagBarrier;
21539 Subregister bcount;
21540 if (strategy.barrierFreq > 0) {
21541 flagBarrier = state.raVFlag.alloc();
21542
        // The main loop counter can double as the barrier counter when barrierFreq * unrollZ
        // is a power of 2; otherwise use a separate countdown.
21544 if (!is_zero_or_pow2(strategy.barrierFreq * unrollZ)) {
21545 bcount = state.ra.alloc_sub<uint32_t>();
21546 mov(1, bcount, uint16_t(strategy.barrierFreq));
21547 }
21548 }
21549
21550 // Setup for main loop.
21551 if (!setup(strategy.s_load, strategy.d_load, state.effS, state.effS1,
21552 state.effD, false))
21553 return false;
21554
21555 bool lateZLoopCheck = state.vflagStorage.isValid();
21556 if (lateZLoopCheck) {
21557 // Release flags for use by vflags. Note flagReflect is not released.
21558 state.raVFlag.unlock(state.flagAP);
21559 if (flagBarrier.isValid()) state.raVFlag.unlock(flagBarrier);
21560 }
21561
21562 // Bail to remainder loop if no main loops.
21563 jmpi(1 | state.flagAP, lZLoopEnd);
21564
21565 // Loop check code.
21566 auto zLoopCheck = [&](int unrollZ, bool enableBarriers) {
21567 // Use the all-purpose flag for z loop query.
21568 add(1 | gt | state.flagAP, state.Z, state.Z, int16_t(-unrollZ));
21569
21570 // Check for barrier if requested.
21571 if (enableBarriers) {
21572 if (bcount.isInvalid())
21573 and_(1 | ze | flagBarrier, null.ud(), state.Z,
21574 uint16_t(unrollZ * strategy.barrierFreq - unrollZ));
21575 else
21576 add(1 | ze | flagBarrier, bcount, bcount, int16_t(-1));
21577 }
21578 };
21579
    // Lambdas used in zLoopBody (hoisted outside to work around a GCC bug).
21581 auto mulAlphaFixed = [&](int esize, RegData r) {
21582 mul(esize, r, r, problem.alpha_real.getRegAvoiding(hw, r));
21583 };
21584
21585 auto mulAlpha = [&](int esize, RegData r) {
21586 mul(esize, r, r, cast(Ts.real(), problem.alpha_real));
21587 };
21588
21589 auto signChange = [&](int esize, RegData r) {
21590 auto ne = elementsPerGRF<uint32_t>(hw);
21591 xor_<uint32_t>(esize, r, r,
21592 (ne < esize) ? state.signChange[0](0, ne, 1)
21593 : state.signChange[0](1));
21594 };
21595
21596 // z loop: returns true on success.
21597 int S_copy = 0, D_copy = 0;
21598 auto zLoopBody = [&](const vector<RegisterBlock> &S_layout,
21599 const vector<RegisterBlock> &S_layoutReflected,
21600 const vector<RegisterBlock> &D_layout,
21601 const vector<GRFRange> &S_addrs,
21602 const vector<GRFRange>(&S_addrSrcs)[2],
21603 const vector<GRFRange> &D_addrs, int unrollZ,
21604 int s_load, int d_load, bool enableBarriers,
21605 bool enableTri, bool needSRem = false,
21606 bool noLoop = false) {
21607 int us = s_load, ud = 0;
21608 int uZLoopCheck = noLoop ? -1 : lateZLoopCheck ? (unrollZ - 1) : 0;
21609 bool dMasked = hasMasking(D_layout);
21610
21611 for (int u = 0; u < unrollZ; u++, us++, ud++) {
21612 // Maintain us (= u % s_load) and ud (= u % d_load) counters.
21613 bool loadS = false;
21614 if (us == s_load) {
21615 us = 0;
21616 loadS = true;
21617 }
21618
21619 if (ud == d_load) ud = 0;
21620 bool storeD = ((ud + 1) == d_load);
21621
21622 // Test loop counter on first iteration (lateZLoopCheck == false)
21623 if ((u == uZLoopCheck) && !lateZLoopCheck)
21624 zLoopCheck(unrollZ, enableBarriers);
21625
21626 // Load S every s_load loops, and copy as necessary.
21627 if (loadS) {
21628 doSLoad(S_layout, S_layoutReflected, S_addrs, S_addrSrcs, u,
21629 s_load, S_copy, needSRem);
21630
21631 // Copy S registers to D registers, or perform in-place scaling/transposition.
21632 if (!share) {
21633 int dOffR = 0, dOffC = 0;
21634 (byColumn ? dOffC : dOffR) = ud;
21635
21636 if (!copyRegisters(Ts, Td, S_layout, D_layout, state.S_regs,
21637 state.D_regs, dOffR, dOffC, problem.alpha_real,
21638 problem.alpha_imag, problem.conjugate, strategy,
21639 state))
21640 return false;
21641 } else {
21642 if (!problem.alpha_real.fixed())
21643 map(hw, Ts.real(), state.S_regs, S_layout, strategy,
21644 mulAlphaFixed);
21645 else if ((problem.alpha_real != 1)
21646 && (problem.alpha_real != -1))
21647 map(hw, Ts.real(), state.S_regs, S_layout, strategy,
21648 mulAlpha);
21649 if (problem.conjugate || (problem.alpha_real == -1))
21650 map<uint32_t>(hw, state.S_regs, S_layout, strategy,
21651 signChange);
21652 }
21653
21654 // Advance S copy counter.
21655 if (++S_copy == SD_copies) S_copy = 0;
21656 }
21657
21658 // Test loop counter on last iteration (lateZLoopCheck == true) if D unmasked.
21659 if ((u == uZLoopCheck) && lateZLoopCheck && !dMasked)
21660 zLoopCheck(unrollZ, enableBarriers);
21661
21662 // Store D every d_load loops.
21663 if (storeD) {
21664 doDStore(D_layout, D_addrs, d_load, D_copy);
21665 if (++D_copy == SD_copies) D_copy = 0;
21666 }
21667
21668 // Test loop counter at very end (lateZLoopCheck == true) if D masked.
21669 if ((u == uZLoopCheck) && lateZLoopCheck && dMasked)
21670 zLoopCheck(unrollZ, enableBarriers);
21671 }
21672
21673 // Forget about active vflags.
21674 state.wipeActiveVFlags();
21675
21676 return true;
21677 };
21678
21679 syncall();
21680
21681 mark(lZLoopBegin);
21682 {
21683 if (!zLoopBody(state.S_layout, S_layoutReflected, state.D_layout,
21684 state.S_addrs, state.S_addrSrcs, state.D_addrs, unrollZ,
21685 strategy.s_load, strategy.d_load, strategy.barrierFreq > 0,
21686 !triRemOnly))
21687 return false;
21688
21689 if (strategy.barrierFreq == 0)
21690 jmpi(1 | state.flagAP, lZLoopBegin);
21691 else {
21692 jmpi(1 | ~state.flagAP, lZLoopEnd);
21693 jmpi(1 | ~flagBarrier, lZLoopBegin);
21694
21695 auto temp = state.ra.alloc();
21696 if (!bcount.isInvalid())
21697 mov(1, bcount, uint16_t(strategy.barrierFreq));
21698
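            // The barrier message header must live in a GRF; if r0 was relocated to an
            // accumulator, copy it back into a temporary GRF first.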
21699 GRF r0_info;
21700 bool freeR0Info = false;
21701
21702 if (state.r0_info.isARF()) {
21703 r0_info = state.ra.alloc();
21704 mov<uint32_t>(8, r0_info, state.r0_info);
21705 freeR0Info = true;
21706 } else
21707 r0_info = GRF {state.r0_info.getBase()};
21708
21709 barrier(temp, r0_info);
21710 state.ra.safeRelease(temp);
21711 if (freeR0Info) state.ra.safeRelease(r0_info);
21712
21713 jmpi(1, lZLoopBegin);
21714 }
21715 }
21716 mark(lZLoopEnd);
21717
21718 state.raVFlag.safeRelease(flagBarrier);
21719 state.ra.safeRelease(bcount);
21720
21721 // z remainder loop.
21722 if (offsetZ) {
        // Undo the offsetting on the z loop counter and check for zero remainder loops.
21724 add(1 | le | state.flagAP, state.Z, state.Z, uint16_t(offsetZ));
21725
21726 // Get the current S, D addresses.
21727 Subregister S_addr0, S1_addr0, D_addr0;
21728 int S_shift, D_shift;
21729 S_addr0 = getOriginAddr(
21730 state.S_layout, state.S_addrs, problem.S, strategy.S, &S_shift);
21731
21732 D_addr0 = getOriginAddr(
21733 state.D_layout, state.D_addrs, problem.D, strategy.D, &D_shift);
21734
21735 auto unshiftAddr0 = [&]() {
21736 if (S_shift) shl(1, S_addr0, S_addr0, S_shift);
21737 if (D_shift) shl(1, D_addr0, D_addr0, D_shift);
21738 };
21739
21740 // Prepare for potential new layout.
21741 vector<RegisterBlock> S_layout1, S_layout1Reflect, D_layout1;
21742 vector<GRFRange> S_addrs1, S_addrSrcs1[2], D_addrs1;
21743
21744 // First, try handling the whole remainder, all at once.
21745 bool wholeRem = false, fragmented = false;
21746 auto newSLoad = strategy.s_load, newDLoad = strategy.d_load;
21747 auto saveSStrategy = strategy.S;
21748 bool largeDCrosspack = isLargeCrosspack(Td, problem.D.crosspack);
21749
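        // Speculatively generate the whole-remainder setup into a separate instruction stream;
        // keep it only if setup succeeds, otherwise restore the saved state and discard it.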
21750 if (S_addr0.isValid() && D_addr0.isValid() && !largeDCrosspack) {
21751 auto saveState = state;
21752 auto saveMasks = masks;
21753 (strategy.xLoop ? state.remainderX : state.remainderY) = state.Z;
21754 pushStream();
21755 try {
21756 cleanup();
21757 state.ra.claim(S_addr0);
21758 state.ra.claim(D_addr0);
21759 unshiftAddr0();
21760
21761 wholeRem = setup(strategy.s_load, strategy.d_load, S_addr0,
21762 S1_addr0, D_addr0, true);
21763
21764 state.ra.release(S_addr0);
21765 state.ra.release(D_addr0);
21766 } catch (...) {}
21767 if (!wholeRem) {
21768 masks = saveMasks;
21769 state = saveState;
21770 }
21771 wholeRem ? appendCurrentStream() : discardStream();
21772 }
21773
21774 // If that doesn't work, retry with minimal unroll.
21775 if (!wholeRem) {
21776 newSLoad = 1;
21777 newDLoad = crosspack;
21778 bool unshare = share && (newSLoad != newDLoad);
21779
21780 // Fragment the S, D layouts, taking the first row/column of each.
21781 vector<int> indices;
21782 fragmented = (!unshare && !largeDCrosspack
21783 && getSubblocks(Ts, S_layout1, indices, state.S_layout,
21784 !mLoop, 0, newSLoad, strategy.S.padded, problem.S,
21785 strategy.S)
21786 && getSubblocks(Ts, S_layout1Reflect, S_layoutReflected,
21787 mLoop, 0, newSLoad, strategy.S.padded, problem.S,
21788 S_strategyReflected)
21789 && getSubblocks(Td, D_layout1, D_addrs1, state.D_layout,
21790 state.D_addrs, !mLoop, 0, newDLoad, false,
21791 problem.D, strategy.D));
21792
21793 if (fragmented) {
21794 // Select source address registers from the fragments.
21795 for (auto b : indices)
21796 S_addrs1.push_back(state.S_addrs[b]);
21797 // Update sizes.
21798 (mLoop ? nms : nns) = newSLoad;
21799 (mLoop ? nmd : nnd) = newDLoad;
21800 } else {
21801 // Fragmentation failed. Start fresh.
21802 if (S_addr0.isInvalid() || D_addr0.isInvalid()) return false;
21803
21804 cleanup();
21805 state.ra.claim(S_addr0);
21806 state.ra.claim(D_addr0);
21807 unshiftAddr0();
21808
21809 if (largeDCrosspack) {
21810 strategy.S.accessType = isTransposing(strategy.S.accessType)
21811 ? AccessType::Block
21812 : strategy.S.base.isStateless()
21813 ? AccessType::Scattered
21814 : AccessType::ChannelScattered;
21815 }
21816
21817 if (!setup(newSLoad, newDLoad, S_addr0, S1_addr0, D_addr0,
21818 false))
21819 return false;
21820
21821 state.ra.release(S_addr0);
21822 state.ra.release(D_addr0);
21823 }
21824
21825 if (crosspack > 1) {
21826 lateZLoopCheck = true;
21827 copyCalcIncrements(
21828 problem, strategy, state, newSLoad, newDLoad);
21829 }
21830 }
21831
21832 // Emit z remainder loop.
21833 Label lZRemLoopBegin, lZRemLoopEnd;
21834 jmpi(1 | state.flagAP, lZRemLoopEnd);
21835 mark(lZRemLoopBegin);
21836 wholeRem ? zLoopBody(state.S_layout, S_layoutReflected, state.D_layout,
21837 state.S_addrs, state.S_addrSrcs, state.D_addrs, unrollZ,
21838 newSLoad, newDLoad, false, true, false, !triRemOnly)
21839 : fragmented
21840 ? zLoopBody(S_layout1, S_layout1Reflect, D_layout1,
21841 S_addrs1, S_addrSrcs1, D_addrs1, crosspack,
21842 newSLoad, newDLoad, false, true, crosspack > 1)
21843 : zLoopBody(state.S_layout, S_layoutReflected,
21844 state.D_layout, state.S_addrs, state.S_addrSrcs,
21845 state.D_addrs, crosspack, newSLoad, newDLoad,
21846 false, true, crosspack > 1);
21847 if (!wholeRem || triRemOnly) jmpi(1 | state.flagAP, lZRemLoopBegin);
21848 mark(lZRemLoopEnd);
21849
21850 strategy.S = saveSStrategy;
21851 }
21852
21853 // Finalize and store sums.
21854 if (problem.sum) {
21855 Label skipSumStore;
21856 bool simtCF = strategy.fused;
21857
21858 if (remainderX) {
21859 cmp((simtCF ? 16 : 1) | le | state.flagAP, state.remainderX, 0);
21860 simtCF ? goto12(16 | state.flagAP, skipSumStore)
21861 : jmpi(1 | state.flagAP, skipSumStore);
21862 }
21863
21864 horizontalAdd(
21865 !byColumn, Tsum, state.Ds_regs.back(), state.Ds_layout, state);
21866
21867 // Accumulate sums from main and remainder loops.
21868 for (int l = 1; l < int(state.Ds_regs.size()); l++) {
21869 map(hw, Tsum, state.Ds_regs[0], state.Ds_regs[l], strategy,
21870 [&](int ne, GRF r1, GRF r2) { add(ne, r1, r1, r2); });
21871 state.ra.safeRelease(state.Ds_regs[l]);
21872 }
21873 state.Ds_regs.resize(1);
21874
21875 MatrixAddressing Ds = problem.D;
21876 Ds.crosspack = 1;
21877
21878 MatrixAddressingStrategy Ds_strategy = strategy.D;
21879 Ds_strategy.accessType = AccessType::Block;
21880
21881 int sr = 1, sc = 1;
21882 (byColumn ? sr : sc) = problem.D.packSize;
21883
21884 vector<RegisterBlock> Ds_layoutOut;
21885 bool ok = getRegLayout(Tsum, Ds_layoutOut, sr, sc, false, false, true,
21886 true, 0, 0, Ds, Ds_strategy)
21887 && matchLayouts(Tsum, Ds_layoutOut, state.Ds_layout);
21888 if (!ok) return false;
21889
21890 vector<GRFRange> Ds_addrs;
21891 allocAddrRegs(Ds_addrs, Ds_layoutOut, Ds, Ds_strategy, state);
21892
21893 Subregister Ds_base;
21894 Ds_base = state.ra.alloc_sub(state.effD.getType());
21895
21896 mulConstant(1, Ds_base.ud(), state.inputs.ldd, problem.D.packSize * Td);
21897 add(1, Ds_base.ud(), Ds_base.ud(), -problem.D.packSize * Tsum);
21898 eadd(1, Ds_base, Ds_base.ud(), state.effD, strategy, state);
21899
21900 setupAddr(Tsum, Ds_addrs, Ds_base, Ds_layoutOut, Subregister(), Ds,
21901 Ds_strategy, strategy, state);
21902 storeMatrix(state.Ds_regs[0], Ds_layoutOut, Ds, Ds_strategy, Ds_addrs,
21903 strategy, state);
21904
21905 state.ra.safeRelease(Ds_base);
21906 safeReleaseRanges(Ds_addrs, state);
21907 safeReleaseRanges(state.Ds_regs, state);
21908 state.Ds_layout.clear();
21909 state.ra.safeRelease(state.all1s);
21910
21911 if (remainderX) {
21912 mark(skipSumStore);
21913 if (simtCF) join(16);
21914 }
21915 }
21916
21917 // Done. Free address, data, and flag registers.
21918 cleanup();
21919 state.ra.safeRelease(state.signChange);
21920 if (lateZLoopCheck) state.raVFlag.lock(state.flagAP);
21921 state.raVFlag.safeRelease(state.flagReflect);
21922 state.raVFlag.safeRelease(state.flagSwizzle);
21923
21924 return true; /* Success! */
21925}
21926
21927// Register-to-register copy of a single block, ignoring register offsets in the block.
21928template <HW hw>
21929bool gemm_kernel_generator_t<hw>::copyRegisterBlock(Type Ts, Type Td,
21930 const RegisterBlock &blockSrc, const RegisterBlock &blockDst,
21931 const GRFMultirange &src, const GRFMultirange &dst, int dOffR,
21932 int dOffC, const CommonStrategy &strategy, CommonState &state,
21933 bool preserveSrc) {
21934 std::vector<RegisterBlock> modSrc {1, blockSrc}, modDst {1, blockDst};
21935 modSrc[0].offsetBytes %= GRF::bytes(hw);
21936 modDst[0].offsetBytes %= GRF::bytes(hw);
21937 return copyRegisters(Ts, Td, modSrc, modDst, src, dst, dOffR, dOffC, false,
21938 strategy, state, preserveSrc);
21939}
21940
21941// Register-to-register copy, with no scaling.
21942template <HW hw>
21943bool gemm_kernel_generator_t<hw>::copyRegisters(Type Ts, Type Td,
21944 const vector<RegisterBlock> &layoutSrc,
21945 const vector<RegisterBlock> &layoutDst, const GRFMultirange &src,
21946 const GRFMultirange &dst, int dOffR, int dOffC, bool conjugate,
21947 const CommonStrategy &strategy, CommonState &state, bool preserveSrc) {
21948 return copyRegisters(Ts, Td, layoutSrc, layoutDst, src, dst, dOffR, dOffC,
21949 Scalar<double>(1.), Scalar<double>(0.), conjugate, strategy, state,
21950 preserveSrc);
21951}
21952
21953// Register-to-register copy, with scaling.
21954template <HW hw>
21955bool gemm_kernel_generator_t<hw>::copyRegisters(Type Ts, Type Td,
21956 const vector<RegisterBlock> &layoutSrc,
21957 const vector<RegisterBlock> &layoutDst, const GRFMultirange &src,
21958 const GRFMultirange &dst, int dOffR, int dOffC,
21959 const Scalar<double> &alpha_real, const Scalar<double> &alpha_imag,
21960 bool conjugate, const CommonStrategy &strategy, CommonState &state,
21961 bool preserveSrc) {
21962 const int nphases = 2, qCXMin = -1, qCXMax = -1;
21963
21964 bool preswizzle = (hw >= HW::XeHP);
21965 GRFRange copyTemp;
21966
21967 auto allocTemp = [&]() {
21968 if (preswizzle && copyTemp.isInvalid())
21969 copyTemp = state.ra.alloc_range(2);
21970 };
21971
21972 int srcM, srcN;
21973 getLayoutDims(layoutSrc, srcM, srcN);
21974 bool vectorCopy = (srcM == 1 || srcN == 1);
21975 int periodY = 1;
21976
21977 if (GRF::bytes(hw) == 64 && Td.size() == 1 && layoutDst[0].crosspack > 1)
21978 periodY = 2;
21979
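    // Phase -1: pre-convert source data that needs an intermediate type.
    // Phase  0: the actual copy, with any scaling applied.
    // Phase  1: post-convert destination data written in an intermediate type.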
21980 for (int phase = -1; phase < nphases; phase++) {
21981 for (int phaseY = 0; phaseY < periodY; phaseY++) {
21982 for (auto &sblock : layoutSrc) {
21983 auto RegisterBlock::*nx = sblock.colMajor ? &RegisterBlock::nr
21984 : &RegisterBlock::nc;
21985 auto RegisterBlock::*ny = sblock.colMajor ? &RegisterBlock::nc
21986 : &RegisterBlock::nr;
21987 auto RegisterBlock::*offsetY = sblock.colMajor
21988 ? &RegisterBlock::offsetC
21989 : &RegisterBlock::offsetR;
21990
21991 for (int eoffY = 0; eoffY < sblock.*ny; eoffY++) {
21992 if (((eoffY + sblock.*offsetY) & (periodY - 1)) != phaseY)
21993 continue;
21994 for (int qCX = qCXMin; qCX <= qCXMax; qCX++) {
21995 for (int eoffX = 0; eoffX < sblock.*nx;) {
21996 auto eoffR = sblock.colMajor ? eoffX : eoffY;
21997 auto eoffC = sblock.colMajor ? eoffY : eoffX;
21998
21999 int selems, delems;
22000 const RegisterBlock *sblockPtr, *dblockPtr;
22001
22002 // Locate source and destination register.
22003 auto sreg = findBlockReg(Ts, layoutSrc,
22004 sblock.offsetR + eoffR,
22005 sblock.offsetC + eoffC, src, selems,
22006 sblockPtr, qCX);
22007 auto dreg = findBlockReg(Td, layoutDst,
22008 sblock.offsetR + eoffR + dOffR,
22009 sblock.offsetC + eoffC + dOffC, dst, delems,
22010 dblockPtr, qCX);
22011
22012 auto scrosspack = sblock.crosspack;
22013 auto dcrosspack = dblockPtr->crosspack;
22014
22015 if (sblock.colMajor != dblockPtr->colMajor) {
22016 bool sLargeCP
22017 = isLargeCrosspack(Ts, scrosspack);
22018 bool dLargeCP
22019 = isLargeCrosspack(Td, dcrosspack);
22020 bool sEffCM = sblock.colMajor ^ sLargeCP;
22021 bool dEffCM = dblockPtr->colMajor ^ dLargeCP;
22022 if (sEffCM == dEffCM) {
22023 if (sLargeCP)
22024 selems = std::min<int>(
22025 selems, scrosspack);
22026 if (dLargeCP)
22027 delems = std::min<int>(
22028 delems, dcrosspack);
22029 } else {
22030 if (!vectorCopy)
22031 stub(); // No in-register matrix transposes.
22032 selems = delems = 1;
22033 }
22034 }
22035
22036 // Find out how many consecutive elements we can copy.
22037 auto nGRFs = (strategy.dualGRF ? 2 : 1);
22038 auto nGRFs_d = (dreg.getOffset() >= dcrosspack)
22039 ? 1
22040 : nGRFs; // Don't cross destination GRF boundaries for efficiency.
22041 auto selems_real = selems * Ts.complexComponents();
22042 auto delems_real = delems * Td.complexComponents();
22043 auto selems_limit = div_up(
22044 nGRFs * elementsPerGRF(hw, Ts.real())
22045 - sreg.getOffset(),
22046 scrosspack);
22047 auto delems_limit = div_up(
22048 nGRFs_d * elementsPerGRF(hw, Td.real())
22049 - dreg.getOffset(),
22050 dcrosspack);
22051 selems_real = std::min({selems_real, selems_limit});
22052 delems_real = std::min({delems_real, delems_limit});
22053 auto nelems_real
22054 = std::min(selems_real, delems_real);
22055 if (phase != 0)
22056 nelems_real = std::min(
22057 rounddown_pow2(nelems_real), 32);
22058
22059 if (Ts == Type::f32 && Td != Type::f32
22060 && dcrosspack == 1)
22061 nelems_real = std::min(nelems_real,
22062 elementsPerGRF(hw,
22063 Ts)); // Special case: mixed mode packed downconversion limited to SIMD8.
22064
22065 // Check if separate conversions are needed due to size changes.
22066 auto sconvertCP = (Ts.size() / Td.size());
22067 bool sconvert = (Td.size() == 1 && Ts.size() > 1
22068 && dcrosspack != sconvertCP)
22069 || (Td.size() == 2 && Td.isFP()
22070 && !Ts.isFP()
22071 && dcrosspack != sconvertCP
22072 && hw > HW::Gen9);
22073 if (sconvert && preserveSrc) stub();
22074 auto sregConverted = sconvert
22075 ? sreg.reinterpret(0, Td.real().ngen())(
22076 sconvertCP)
22077 : sreg(scrosspack);
22078
22079 auto dconvertCP = (Td.size() / Ts.size());
22080 bool dconvert = (Ts.size() == 1 && Td.size() > 1
22081 && scrosspack != dconvertCP);
22082 auto dregConverted = dconvert
22083 ? dreg.reinterpret(0, Ts.real().ngen())(
22084 dconvertCP)
22085 : dreg(dcrosspack);
22086
22087 InstructionModifier modMov, mmodMov;
22088 if (Ts != Td && Td.isInteger()
22089 && Td.size() <= Ts.size()) {
22090 modMov = modMov | sat;
22091 if (!sconvert && !dconvert)
22092 mmodMov = mmodMov | sat;
22093 }
22094
22095 // Finally, copy, with any necessary conjugation and scaling. If doing a raw copy, use another pipe.
22096 switch (phase) {
22097 case -1:
22098 if (hw == HW::Gen9 && Ts == Type::f32
22099 && !Td.isFP()) {
22100 // Gen9: round to nearest before downconvert (not done by mov).
22101 rnde(nelems_real, sreg(scrosspack),
22102 sreg(scrosspack));
22103 }
22104 if (sconvert)
22105 mov(nelems_real | modMov, sregConverted,
22106 sreg(scrosspack));
22107 break;
22108 case 0:
22109 if (alpha_real == 1 || alpha_real == -1) {
22110 if (Ts.real() == Td.real()) {
22111 movePipes(sreg, scrosspack == 1);
22112 movePipes(dreg, scrosspack == 1);
22113 if (!sconvert)
22114 sregConverted
22115 = sreg(scrosspack);
22116 if (!dconvert)
22117 dregConverted
22118 = dreg(dcrosspack);
22119 if (hw >= HW::XeHP
22120 && scrosspack
22121 != dcrosspack) {
22122 moveToIntPipe(nelems_real,
22123 sregConverted);
22124 moveToIntPipe(nelems_real,
22125 dregConverted);
22126 sreg = sreg.reinterpret(0,
22127 sregConverted
22128 .getType());
22129 }
22130 }
22131 int telems = nelems_real * Ts.real()
22132 / sreg.getBytes();
22133 if (telems > 32) {
22134 nelems_real = (nelems_real * 32)
22135 / telems;
22136 telems = 32;
22137 }
22138 if (alpha_real == -1) {
22139 auto wd = elementsPerGRF(
22140 hw, sreg.getType());
22141 auto base = state.signChange.sub(
22142 0, dreg.getType());
22143 xor_(telems, dreg(1), sreg(1),
22144 (wd >= telems)
22145 ? base(1)
22146 : base(0, wd, 1));
22147 } else
22148 emov(telems | mmodMov,
22149 dregConverted,
22150 sregConverted, strategy,
22151 state);
22152 } else {
22153 auto realDst = dreg(dcrosspack);
22154 auto effDst = realDst;
22155 if (preswizzle
22156 && (Ts.isFP() || Td.isFP())) {
22157 allocTemp();
22158 if ((sreg.getOffset()
22159 != dreg.getOffset())
22160 || (scrosspack
22161 != dcrosspack))
22162 effDst = copyTemp[0].sub(
22163 sreg.getOffset(),
22164 sreg.getType())(
22165 scrosspack);
22166 }
22167
22168 if (alpha_real.fixed())
22169 mul(nelems_real, effDst,
22170 sregConverted,
22171 cast(Ts.real(),
22172 alpha_real));
22173 else
22174 mul(nelems_real, effDst,
22175 sregConverted,
22176 alpha_real.getRegAvoiding(
22177 hw, sreg));
22178
22179 if (effDst != realDst) {
22180 moveToIntPipe(nelems_real, realDst);
22181 moveToIntPipe(nelems_real, effDst);
22182 int nelems_real_int = nelems_real
22183 * Td
22184 / getBytes(
22185 effDst.getType());
22186 emov(nelems_real_int, realDst,
22187 effDst, strategy, state);
22188 dconvert = false;
22189 }
22190 }
22191 break;
22192 case 1:
22193 if (dconvert)
22194 mov(nelems_real | modMov,
22195 dreg(dcrosspack),
22196 dregConverted);
22197 break;
22198 } /* switch phase */
22199
22200 int nelems = nelems_real;
22201 eoffX += nelems;
22202 } /* eoffX loop */
22203 } /* qCX loop */
22204 } /* eoffY loop */
22205 } /* sblock loop */
22206 } /* phaseY loop */
22207
22208 } /* phase loop */
22209
22210 state.ra.safeRelease(copyTemp);
22211 return true; // Success
22212}
22213
22214// Get driver information from this strategy.
22215template <HW hw>
22216CommonDriverInfo gemm_kernel_generator_t<hw>::driverInfo(
22217 const CopyProblem &problem, const CopyStrategy &strategy) {
22218 CommonDriverInfo info;
22219 bool isA = (problem.D.layout == MatrixLayout::Pc);
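    // (Pc here indicates the destination is a packed-column panel, i.e., an A-side copy;
    //  otherwise it is a B-side copy.)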
22220
22221 for (int d = 0; d < 3; d++) {
22222 info.blocking[d] = info.blockingAlt[d] = info.unroll[d] = 0;
22223 info.wg[d] = 1;
22224 info.loopOrder[d] = LoopNone;
22225 }
22226
22227 info.subgroupSize = strategy.subgroupSize;
22228 info.grfCount = strategy.GRFs;
22229 info.unroll[0] = isA ? strategy.unrollX : strategy.unrollY;
22230 info.unroll[1] = isA ? strategy.unrollY : strategy.unrollX;
22231 info.kRemainderHandling
22232 = (strategy.remHandlingY != RemainderHandling::Ignore);
22233 info.loopOrder[0] = (isA ^ strategy.xLoop) ? LoopM : LoopN;
22234 if (strategy.zParallel)
22235 info.loopOrder[1] = (isA ^ strategy.xLoop) ? LoopN : LoopM;
22236 info.fusedLoop = strategy.fused ? info.loopOrder[0] : LoopNone;
22237 info.wg[0] = 16;
22238 info.wgExpand = 1;
22239 info.wgUpdate = WGDynamic;
22240 info.kRemainderHandling = true;
22241 info.kParallel = strategy.zParallel;
22242 info.kParallelLocal = false;
22243 info.slm = info.perKSLM = 0;
22244 info.alignment[0] = problem.S.alignment;
22245 info.alignment[1] = problem.D.alignment;
22246 info.alignment[2] = 0;
22247 info.support4GB[0] = (strategy.S.base.getModel() == ModelA64);
22248 info.support4GB[1] = (strategy.D.base.getModel() == ModelA64);
22249 info.support4GB[2] = false;
22250
22251 return info;
22252}
22253
22254// Validate a copy strategy, correcting settings as necessary.
22255void CopyStrategy::preflight(HW hw, const CopyProblem &problem) {
22256 bool cm = isColMajor(problem.D.layout);
22257
22258 S.preflight(hw);
22259 D.preflight(hw);
22260
22261 s_load = std::max(s_load, 1);
22262 d_load = std::max(d_load, 1);
22263 if (s_load_masked == 0) s_load_masked = s_load;
22264 if (d_load_masked == 0) d_load_masked = d_load;
22265 unrollX = std::max(unrollX, 1);
22266 unrollY = std::max(unrollY, 1);
22267 unrollY = align_up(unrollY, problem.D.crosspack);
22268
22269 // Ensure d_load is a multiple of s_load and crosspack, and unrollZ a multiple of both.
22270 // For x loop kernels, ensure s_load is a multiple of the packing size.
22271 // For y loop kernels, ensure all d_loads are multiples of y tile size if any.
22272 if (xLoop) {
22273 s_load = align_up(s_load, problem.D.packSize);
22274 s_load_masked = align_up(s_load_masked, problem.D.packSize);
22275 } else {
22276 auto D_tileY = cm ? problem.D.tileC : problem.D.tileR;
22277 if (D_tileY > 0) d_load_masked = align_up(d_load_masked, D_tileY);
22278 d_load_masked = align_up(d_load_masked, problem.D.crosspack);
22279 }
22280 d_load = align_up(d_load, s_load);
22281 d_load_masked = align_up(d_load_masked, s_load_masked);
22282 d_load = align_up(d_load, d_load_masked);
22283
22284 if (xLoop)
22285 unrollX = align_up(unrollX, d_load);
22286 else
22287 unrollY = align_up(unrollY, d_load);
22288
22289 if (unrollY == 1 && remHandlingY == RemainderHandling::Split)
22290 remHandlingY = RemainderHandling::General;
22291
22292 spf &= !problem.trsm; // TRSM copies use SIMT control flow.
22293
22294 CommonStrategy::preflight(hw, problem);
22295}
22296
22297/**********************************************************************/
22298/* Common Kernel Functions */
22299/**********************************************************************/
22300
22301// Generate the kernel prologue.
22302template <HW hw>
22303void gemm_kernel_generator_t<hw>::prologue(const CommonStrategy &strategy) {
22304 uint16_t cr0Enable;
22305
22306 interface.generatePrologue(*this);
22307
22308 cr0Enable = 0x1000; // IEEE float->int rounding.
22309 if (strategy.ieeeDenormals) cr0Enable |= 0x4C0; // Enable hf|f|df denormals.
22310 if (strategy.spf) cr0Enable |= 0x4; // Enable single program flow.
22311
22312 or_(1, cr0, cr0, cr0Enable);
22313
22314 InstructionModifier imod = 1;
22315 if (hw < HW::Gen12LP) imod |= Switch;
22316
22317 if (interface.getSIMD() < 16) mov(imod, sr0[2], uint16_t(0xFFFF));
22318}
22319
22320// Generate the kernel epilogue.
22321template <HW hw>
22322void gemm_kernel_generator_t<hw>::epilogue(
22323 const CommonStrategy &strategy, const CommonState &state) {
22324 auto r0_info = state.r0_info;
22325
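    // The end-of-thread message must take its header from the top of the register file,
    // so relocate r0 to r127 if it is not already there.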
22326 if (r0_info.getBase() < 112) {
22327 mov<uint32_t>(8, r127, r0_info);
22328 r0_info = r127;
22329 }
22330
22331 if (strategy.finalFence) {
22332 memfence(r124, r0_info);
22333 mov<uint32_t>(8, null, r124);
22334 }
22335
22336 threadend(r0_info);
22337}
22338
22339// Pad the end of the kernel to accommodate instruction prefetching.
22340template <HW hw>
22341void gemm_kernel_generator_t<hw>::padding() {
22342 for (int q = 0; q < 8; q++)
22343 nop();
22344}
22345
22346// Common state initialization code.
22347template <HW hw>
22348void gemm_kernel_generator_t<hw>::initState(const CommonProblem &problem,
22349 const CommonStrategy &strategy, CommonState &state) {
22350 interface.requireLocalID(3);
22351 interface.requireLocalSize();
22352 if (problem.nonuniformWGs) interface.requireNonuniformWGs();
22353
22354 if (strategy.wgInSS) interface.requireBarrier();
22355
22356 interface.requireSIMD(strategy.subgroupSize);
22357
22358 if (!strategy.sipR0WA) interface.requireNoPreemption();
22359
22360 interface.requireGRF(strategy.GRFs);
22361 state.ra.setRegisterCount(strategy.GRFs);
22362
22363 if (problem.gtpinSupport) interface.requireScratch(128);
22364
22365 for (int i = 0; i < FlagRegister::subcount(hw); i++)
22366 state.activeVFlags[i].clear();
22367}
22368
22369CommonStrategy::CommonStrategy(HW hw, int stepping) : emulate(hw, stepping) {
22370 fused = one_of(hw, HW::Gen12LP, HW::XeHP, HW::XeHPG);
22371}
22372
22373void CommonStrategy::preflight(HW hw, const CommonProblem &problem) {
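    // Clamp the subgroup size to at least one GRF's worth of dwords (8 for 32-byte GRFs,
    // 16 for 64-byte GRFs).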
22374 subgroupSize = std::max(subgroupSize, GRF::bytes(hw) >> 2);
22375 sipR0WA &= (hw == HW::Gen9);
22376 if (sipR0WA && (moveR0 == MoveR0::None)) moveR0 = MoveR0::GRF;
22377 readSuppressionWA &= fused;
22378
22379 bool emulateNeedsAcc = emulate.emulate64 || emulate.emulateDWxDW
22380 || emulate.emulate64_mul;
22381 if (moveR0 == MoveR0::Acc && emulateNeedsAcc) moveR0 = MoveR0::None;
22382
22383 spf &= !fused;
22384}
22385
22386template <HW hw>
22387constexpr typename gemm_kernel_generator_t<hw>::status_stream::Endl
22388 gemm_kernel_generator_t<hw>::status_stream::endl;
22389
22390REG_GEN9_ISA(template class gemm_kernel_generator_t<HW::Gen9>);
22391REG_XELP_ISA(template class gemm_kernel_generator_t<HW::Gen12LP>);
22392REG_XEHP_ISA(template class gemm_kernel_generator_t<HW::XeHP>);
22393REG_XEHPG_ISA(template class gemm_kernel_generator_t<HW::XeHPG>);
22394REG_XEHPC_ISA(template class gemm_kernel_generator_t<HW::XeHPC>);
22395
22396} // namespace jit
22397} // namespace gpu
22398} // namespace impl
22399} // namespace dnnl
22400