1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_GEMM_GEN_GEMM_KERNEL_GENERATOR_HPP
18#define GPU_JIT_GEMM_GEN_GEMM_KERNEL_GENERATOR_HPP
19
20/* Embargo support */
21
22#define STANDALONE 0
23
24#include "common/math_utils.hpp"
25#include "common/utils.hpp"
26#include "gpu/jit/gemm/gen_gemm_kernel_common.hpp"
27#include "gpu/jit/gemm/utils.hpp"
28#include "gpu/jit/jit_generator.hpp"
29#include "gpu/jit/jit_post_op_injector.hpp"
30
31#if defined(ZEBIN_OUTPUT)
32#include "../ngen/ngen_elf.hpp"
33#else
34#include "../ngen/ngen_opencl.hpp"
35
36#endif
37#include "../ngen/ngen_register_allocator.hpp"
38
39#include "gpu/jit/gemm/emulation.hpp"
40
#include <array>
#include <complex>
#include <cstdint>
#include <exception>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>
48
49namespace dnnl {
50namespace impl {
51namespace gpu {
52namespace jit {
53
// Forward declaration: GRFMultirange::subrange() takes a RegisterBlock by
// reference before the struct itself is defined further down in this file.
struct RegisterBlock;
55
56class Type {
57public:
58 enum _Type : uint32_t {
59 invalid = 0,
60 f16 = 0x01000201,
61 f32 = 0x01010402,
62 u8 = 0x01840100,
63 s8 = 0x01850100,
64 u16 = 0x01860201,
65 s16 = 0x01870201,
66 u32 = 0x01880402,
67 s32 = 0x01890402,
68 u64 = 0x018A0803,
69 s64 = 0x018B0803,
70 bf16 = 0x010C0201,
71 tf32 = 0x010D0402,
72 };
73
74private:
75 _Type val;
76
77public:
78 constexpr Type() : Type(f32) {}
79 constexpr Type(_Type val_) : val(val_) {}
80 constexpr operator _Type() const { return val; }
81
82 constexpr Type real() const { return *this; }
83 constexpr bool isComplex() const { return false; }
84 constexpr int complexComponents() const { return 1; }
85 constexpr int components() const { return 1; }
86 constexpr bool isInteger() const { return uint32_t(val) & 0x800000; }
87 constexpr bool isFP() const { return !isInteger(); }
88 constexpr bool isSigned() const {
89 return (uint32_t(val) & 0x810000) != 0x800000;
90 }
91 constexpr int log2Size() const { return uint32_t(val) & 0xFF; }
92 constexpr int size() const { return (uint32_t(val) >> 8) & 0xFF; }
93
94 constexpr Type arithmetic() const {
95 return (val == tf32) ? Type(f32) : real();
96 }
97 data_type_t get_dnnl_type() const {
98 switch (val) {
99 case Type::f32: return data_type::f32;
100 case Type::f16: return data_type::f16;
101 case Type::s32: return data_type::s32;
102 case Type::u8: return data_type::u8;
103 case Type::s8: return data_type::s8;
104 default: assert(!"Unsupported type"); return data_type::undef;
105 }
106 }
107 constexpr Type baseType() const { return *this; }
108
109 template <typename U>
110 constexpr friend int operator*(U a, Type t) {
111 return int(a << t.log2Size());
112 }
113 template <typename U>
114 constexpr friend int operator/(U a, Type t) {
115 return int(a >> t.log2Size());
116 }
117
118 ngen::DataType ngen() const {
119 using namespace ngen;
120 static const DataType table[16] = {DataType::hf, DataType::f,
121 DataType::df, DataType::invalid, DataType::ub, DataType::b,
122 DataType::uw, DataType::w, DataType::ud, DataType::d,
123 DataType::uq, DataType::q, DataType::bf, DataType::tf32,
124 DataType::invalid, DataType::invalid};
125 return table[(uint32_t(val) >> 16) & 0xF];
126 }
127
128 bool isSubsetOf(Type T) const;
129};
130
// Matrix storage layout.
// Each layout has a short and a long enumerator name sharing one value.
// isColMajor() treats N and Pc as column-major; isPacked() is true for
// Pc and Pr.
enum class MatrixLayout : uint8_t {
    N = 0,
    Nontranspose = 0,
    T = 1,
    Transpose = 1,
    Pc = 2,
    PackedColumns = 2,
    Pr = 3,
    PackedRows = 3
};
141
142static inline bool isPacked(MatrixLayout l) {
143 return (l == MatrixLayout::PackedRows)
144 || (l == MatrixLayout::PackedColumns);
145}
146
147static inline bool isColMajor(MatrixLayout l) {
148 return (l == MatrixLayout::N || l == MatrixLayout::Pc);
149}
150
// A crosspack is "large" when it is greater than 1 and the crosspacked
// elements together occupy more than 4 bytes.
static inline bool isLargeCrosspack(size_t sizeofT, int crosspack) {
    if (crosspack <= 1) return false;
    return crosspack * sizeofT > 4;
}
154
155static inline bool isLargeCrosspack(Type T, int crosspack) {
156 return isLargeCrosspack(T.size(), crosspack);
157}
158
// Kind of memory access used for a matrix; selected per matrix through
// MatrixAddressingStrategy::accessType.
enum class AccessType : uint8_t {
    Scattered, // Use scattered accesses
    ChannelScattered, // Use untyped surface reads
    Block, // Use block messages
    PseudoBlock, // Use scattered accesses to emulate block accesses
    Block2D, // Use 2D block messages
    Block2DTranspose, // Use 2D block messages with transposition
    Block2DVNNI, // Use 2D block messages with VNNI transform
};
168
169static inline bool isBlock2D(AccessType t) {
170 return (t == AccessType::Block2D || t == AccessType::Block2DTranspose
171 || t == AccessType::Block2DVNNI);
172}
173
// How remainders in a loop dimension are handled.
// GEMMStrategy::remHandling holds one value per m/n/k dimension.
enum class RemainderHandling : uint8_t {
    Ignore, // Assume no remainder, or handled by hardware bounds checking.
    General, // Handle all remainder cases.
    Split, // Generate copies of the kernel with and without remainder handling.
    KnownRemainder, // Assume remainder case; don't create special code for non-remainder case.
};
180
// Kernel scheduling modes.
// NOTE(review): the exact meaning of each mode is not visible in this part
// of the file -- confirm against the code that consumes this enum.
enum class KernelScheduling : uint8_t {
    Static,
    EUStatic,
    Dynamic,
};
186
// Preferences for using scattered accesses.
// Stored per matrix in MatrixAddressingStrategy::smode.
enum class ScatterSIMD {
    Default,
    Wide, // Prefer wider SIMD (more scattered lanes)
    Narrow // Prefer narrower SIMD (more consecutive access)
};
193
194struct GRFMultirange {
195 std::vector<ngen::GRFRange> ranges;
196
197 GRFMultirange() {}
198 GRFMultirange(ngen::GRFRange range) : ranges {1, range} {}
199
200 ngen::GRF operator[](int idx) const {
201 for (auto &r : ranges) {
202 if (idx < r.getLen()) return r[idx];
203 idx -= r.getLen();
204 }
205 throw std::runtime_error("Index out of bounds");
206 }
207
208 GRFMultirange subrange(int start, int count) const {
209 GRFMultirange result;
210 for (auto &r : ranges) {
211 if (start < r.getLen()) {
212 auto got = std::min(count, r.getLen() - start);
213 result.ranges.push_back(
214 ngen::GRFRange {r.getBase() + start, got});
215 count -= got;
216 start = 0;
217 if (count <= 0) break;
218 } else
219 start -= r.getLen();
220 }
221 return result;
222 }
223
224 GRFMultirange subrange(
225 ngen::HW hw, Type T, const RegisterBlock &block) const;
226
227 bool contiguous(int start, int count) const {
228 for (auto &r : ranges) {
229 if (start < r.getLen()) return (start + count) <= r.getLen();
230 start -= r.getLen();
231 }
232 return false;
233 }
234
235 void append(ngen::GRFRange r) {
236 if (!ranges.empty()) {
237 auto &rend = ranges.back();
238 if (rend.getBase() + rend.getLen() == r.getBase()) {
239 rend = ngen::GRFRange(
240 rend.getBase(), rend.getLen() + r.getLen());
241 return;
242 }
243 }
244 ranges.push_back(r);
245 }
246
247 void append(const GRFMultirange &r) {
248 for (auto &rr : r.ranges)
249 append(rr);
250 }
251
252 uint8_t getLen() const {
253 uint8_t len = 0;
254 for (auto &r : ranges)
255 len += r.getLen();
256 return len;
257 }
258
259 bool empty() const {
260 for (auto &r : ranges)
261 if (r.getLen() > 0) return false;
262 return true;
263 }
264 void clear() { ranges.clear(); }
265};
266
267// A pair of Subregisters in opposite banks.
268class SubregisterPair {
269protected:
270 ngen::Subregister regs[2];
271 bool negative;
272
273public:
274 SubregisterPair() : SubregisterPair(ngen::Subregister()) {}
275 SubregisterPair(ngen::Subregister reg0, ngen::Subregister reg1)
276 : regs {reg0, reg1}, negative(false) {}
277 explicit SubregisterPair(ngen::Subregister reg)
278 : SubregisterPair(reg, reg) {}
279
280 /* implicit */ operator ngen::Subregister() const { return regs[0]; }
281
282 SubregisterPair &operator=(ngen::Subregister reg) {
283 regs[0] = regs[1] = reg;
284 negative = false;
285 return *this;
286 }
287
288 ngen::Subregister getReg(int idx) const;
289 ngen::Subregister getRegAvoiding(
290 ngen::HW hw, const ngen::RegData &rd) const;
291
292 bool isValid() const { return regs[0].isValid() && regs[1].isValid(); }
293 bool isInvalid() const { return !isValid(); }
294 void invalidate() {
295 regs[0].invalidate();
296 regs[1].invalidate();
297 }
298
299 SubregisterPair operator-() const {
300 auto copy = *this;
301 copy.negative = !copy.negative;
302 return copy;
303 }
304};
305
306template <typename T>
307class Scalar {
308protected:
309 bool fixed_value;
310 union {
311 SubregisterPair subs;
312 T value;
313 };
314
315public:
316 Scalar() : Scalar(ngen::Subregister()) {}
317 explicit Scalar(T value_) : fixed_value(true), value(value_) {}
318 Scalar(ngen::Subregister reg0, ngen::Subregister reg1)
319 : fixed_value(false), subs {reg0, reg1} {}
320 explicit Scalar(ngen::Subregister reg) : Scalar(reg, reg) {}
321
322 Scalar &operator=(T value_) {
323 fixed_value = true;
324 value = value_;
325 return *this;
326 }
327 Scalar &operator=(ngen::Subregister reg) {
328 fixed_value = false;
329 subs = reg;
330 return *this;
331 }
332
333 template <typename U>
334 friend inline bool operator==(const Scalar<T> &scalar, const U &val) {
335 return scalar.fixed_value && (val == scalar.value);
336 }
337 template <typename U>
338 friend inline bool operator==(const U &val, const Scalar<T> &scalar) {
339 return scalar == val;
340 }
341
342 template <typename U>
343 friend inline bool operator!=(const Scalar<T> &scalar, const U &val) {
344 return !(scalar == val);
345 }
346 template <typename U>
347 friend inline bool operator!=(const U &val, const Scalar<T> &scalar) {
348 return !(scalar == val);
349 }
350
351 operator T() const {
352 if (!fixed_value) throw std::runtime_error("Scalar is not fixed.");
353 return value;
354 }
355
356 operator SubregisterPair() const {
357 if (fixed_value) throw std::runtime_error("Scalar is fixed.");
358 return subs;
359 }
360
361 SubregisterPair &getPair() {
362 if (fixed_value) throw std::runtime_error("Scalar is fixed.");
363 return subs;
364 }
365
366 bool fixed() const { return fixed_value; }
367
368 ngen::Subregister getReg(int idx) const {
369 return SubregisterPair(*this).getReg(idx);
370 }
371 ngen::Subregister getRegAvoiding(
372 ngen::HW hw, const ngen::RegData &rd) const {
373 return SubregisterPair(*this).getRegAvoiding(hw, rd);
374 };
375};
376
377class MultishiftSubregister {
378protected:
379 static constexpr int maxShift = 5;
380 ngen::Subregister regs[maxShift + 1] = {ngen::Subregister()};
381 bool neg = false;
382
383public:
384 MultishiftSubregister operator-() const {
385 auto copy = *this;
386 copy.neg = !copy.neg;
387 return copy;
388 }
389
390 ngen::Subregister operator>>(int shift) const {
391 ngen::RegData sub = ngen::Subregister {};
392 if (shift >= 0 && shift <= maxShift) sub = regs[shift];
393 if (neg) sub = -sub;
394 return *reinterpret_cast<ngen::Subregister *>(&sub);
395 }
396
397 void set(int shift, ngen::Subregister reg) { regs[shift] = reg; }
398};
399
400struct MatrixAddressing {
401 MatrixLayout layout; // Layout type (N/T/Pr/Pc)
402 uint8_t packSize; // # of elements in a packed row/column for packed layouts.
403 uint8_t crosspack; // Crosspack for packed layouts.
404 uint8_t alignment; // Alignment for all addresses, offsets, and leading dimensions.
405 uint8_t tileR = 0, tileC = 0; // Tiling (0 if none) for packed layouts.
406
407 void setAlignment(int align) { alignment = sanitizeAlign(align); }
408 int defaultAlignment(Type T) const {
409 return sanitizeAlign(
410 T.size() * (isPacked(layout) ? (packSize * crosspack) : 1));
411 }
412
413private:
414 static int sanitizeAlign(int align) {
415 return std::min(128, largest_pow2_divisor(align));
416 }
417};
418
// Describes how a matrix is to be accessed: addressing base, access type,
// register tiling, a set of boolean options packed as bit-fields, and cache
// settings.
struct MatrixAddressingStrategy {
    ngen::AddressBase base; // Base for addressing (A64/BTS/...)
    AccessType accessType = AccessType::Block; // Block/scattered/etc. access
    uint8_t tileR = 0, tileC = 0; // Desired tiling (0 if none) in registers.
    ScatterSIMD smode
            = ScatterSIMD::Default; // SIMD selection for scattered accesses.
    unsigned padded : 1; // Allow read/write overruns?
    unsigned atomic : 1; // Atomic access? (only relevant for C)
    unsigned address2D : 1; // Use 2D addressing? (media block-style loads)
    unsigned prefetch : 1; // Prefetch only?
    unsigned newDP : 1; // Use new dataport messages? (XeHPG+)
    unsigned dpasw : 1; // DPASW half layout?
    ngen::CacheSettingsLSC cachingR // Cache policies for LSC reads.
            = ngen::CacheSettingsLSC::Default;
    ngen::CacheSettingsLSC cachingW // Cache policies for LSC writes.
            = ngen::CacheSettingsLSC::Default;

    // The bit-fields above carry no default member initializers, so the
    // constructor clears every one of them explicitly.
    MatrixAddressingStrategy()
        : padded(false)
        , atomic(false)
        , address2D(false)
        , prefetch(false)
        , newDP(false)
        , dpasw(false) {}

    void preflight(ngen::HW hw);
    void forceA64();

    // Stateless bases map to the Stateless global access type; all other
    // bases map to Surface.
    ngen::GlobalAccessType getGlobalAccessType() const {
        return base.isStateless() ? ngen::GlobalAccessType::Stateless
                                  : ngen::GlobalAccessType::Surface;
    }
};
452
453struct VirtualFlag {
454 uint8_t idx : 6;
455 uint8_t n : 2;
456
457 constexpr VirtualFlag() : idx(0), n(0) {}
458 /* implicit */ VirtualFlag(const ngen::FlagRegister &flag)
459 : idx(flag.index()), n(flag.getBytes() >> 1) {}
460 explicit constexpr VirtualFlag(int idx_, int n_ = 1) : idx(idx_), n(n_) {}
461
462 ngen::FlagRegister toPhysical() const;
463
464 friend inline bool operator==(VirtualFlag vf1, VirtualFlag vf2) {
465 return vf1.idx == vf2.idx && vf1.n == vf2.n;
466 }
467 friend inline bool operator!=(VirtualFlag vf1, VirtualFlag vf2) {
468 return !(vf1 == vf2);
469 }
470
471 bool operator!() const { return (idx == 0) && (n == 0); }
472 explicit operator bool() const { return !!*this; }
473
474 void clear() { *this = VirtualFlag(); }
475};
476
// Description of a mask, in one of two forms selected by the leading
// isFixed bit:
//   - variable: the mask is derived from an index (see the field comments);
//   - fixed:    the mask is the constant `value`.
// The `raw` member views the same storage as a single 32-bit word and is
// what operator== compares. The two struct views are laid out so that each
// fills those 32 bits (exact bit-field placement is compiler-defined).
struct MaskInfo {
    union {
        struct {
            uint8_t isFixed : 1; // = false (variable mask)
            uint8_t reverse : 1; // True to reverse mask.
            uint8_t : 6;
            uint8_t rsize; // Maximum remainder value. (e.g. 16 if we need the last 4 bits of the index).
            uint8_t maskRep; // # of repetitions of mask pattern.
            uint8_t bitRep : 5; // # of times each mask bit is repeated.
            uint8_t rdivide : 3; // Amount by which to divide index before forming mask. Fractions are rounded up.
            // Note maskRep * bitRep * (rsize >> rshift) = # mask bits.
            // NOTE(review): there is no field named rshift in this struct;
            // the note above presumably predates `rdivide` -- confirm.
        } variable;
        struct {
            uint8_t isFixed : 1; // = true (fixed mask)
            uint8_t _ : 7;
            uint8_t rsize; // Maximum remainder value.
            uint16_t value; // Mask value.
        } fixed;
        uint32_t raw;
    };

    // Default: a fixed mask with value 0xFFFF, which is the "no mask" state
    // (operator! returns true for it).
    MaskInfo() : fixed {true, 0, 0, 0xFFFF} {}

    // True when this is the "no mask" state: fixed with value 0xFFFF.
    bool operator!() const { return fixed.isFixed && fixed.value == 0xFFFF; }
    explicit operator bool() const { return !!*this; }

    // Same as a default-constructed MaskInfo.
    static MaskInfo None() { return MaskInfo(); }

    // Equality compares the raw 32-bit representation.
    friend bool operator==(const MaskInfo &i1, const MaskInfo &i2) {
        return i1.raw == i2.raw;
    }
    friend bool operator!=(const MaskInfo &i1, const MaskInfo &i2) {
        return !(i1 == i2);
    }
};
512
513struct MaskAssignment {
514 MaskInfo mask; // Associated mask
515 LoopType var; // Variable to base mask off of
516 uint8_t offset; // Amount to subtract from variable.
517 VirtualFlag flag; // Index of virtual flag register to use.
518
519 bool compatible(const MaskAssignment &other) const {
520 return mask == other.mask && var == other.var && offset == other.offset;
521 }
522 void reverse(int width) {
523 offset = width - offset - mask.variable.rsize;
524 mask.variable.reverse = !mask.variable.reverse;
525 }
526};
527
// Describes one block of a matrix held in registers: its shape and
// position, how it is laid out in registers, and the parameters of the
// load/store message that moves it. Fields are packed into small integers
// and bit-fields; field order is part of the layout.
struct RegisterBlock {
    /* Register layout information. */
    uint8_t nr, nc; // Size of this block.
    uint8_t ld; // Leading dimension, in elements.
    uint8_t offsetR, offsetC; // Row and column offset within matrix block.
    uint8_t colMajor : 1; // Is this block column-major? (columns stored consecutively inside each register)
    uint8_t splitComplex : 1; // True if complex data split into successive real and imaginary parts.
    uint8_t : 6;
    uint8_t crosspack; // Crosspack for this block (1 if none).
    uint8_t component : 6; // Component # for this block.
    int8_t cxComponent : 2; // Complex component # for this block (-1 if not complex or interleaved).
    uint16_t bytes; // # of bytes in this block.
    uint16_t offsetBytes; // Byte offset within register block.

    /* Load/store information. */
    uint8_t remainderR : 1; // Row remaindering enabled?
    uint8_t remainderC : 1; // Column remaindering enabled?
    uint8_t noRowsOK : 1; // Can handle no rows (in mask/descriptor)?
    uint8_t noColsOK : 1; // Can handle no columns (in mask/descriptor)?
    uint8_t descRemR : 1; // Row remainders can be handled by changing the descriptor?
    uint8_t descRemC : 1; // Column remainders can be handled by changing the descriptor?
    uint8_t descAssigned : 1; // True if address registers have been assigned for this block's descriptors.
    uint8_t writable : 1; // True if block is set up for writing.

    uint8_t ebytes; // Size of element in bytes, e.g. 4 for scattered_dword, 16 for block_hword
    uint8_t count; // Element count.
    uint8_t extra; // Extra info. For block accesses, 1 means aligned OWord, 0 unaligned. For scattered accesses, # of consecutive elements.
    uint8_t simdSize; // SIMD size for load/stores (0 indicating no separate load/store needs to be done.)
    uint8_t msgRegs; // Underlying register count for load/store operation (may be different from nregs()).
    VirtualFlag flag; // Assigned flag register index and modifiers, if any.
    uint8_t flagAny : 1; // Use .anyh?
    uint8_t flagAll : 1; // Use .allh?
    uint8_t hasNoLoad : 1; // Does this load/store cover additional (no-load) RegisterBlocks? (packed layouts)
    uint8_t : 5;
    uint8_t sfid; // SFID for this block.
    uint8_t rowFragment; // If this block needs fragmenting to support row/column remainders, the maximum block size (power of 2) to fragment down to.
    uint8_t colFragment; // Zero if no fragmenting needed.
    uint8_t addrShift; // log2(address units). e.g. 0 if byte addresses should be used, 4 if oword addresses should be used.
    uint8_t log2GRFBytes; // log2(bytes per GRF).

    MaskInfo rowMask; // Row mask for this block.
    MaskInfo colMask; // Column mask for this block.

    static constexpr int8_t Interleaved
            = -1; // Value for cxComponent indicating interleaved real/imaginary data.

    void calcBytes(Type T); // Auto-calculate # of registers.
    void calcBytes(Type T, const MatrixAddressingStrategy &astrategy);

    // Drop the assigned flag and clear both flag modifiers.
    void clearFlag() {
        flag.clear();
        flagAll = flagAny = false;
    }
    // Drop the flag and reset both masks to the default ("no mask") state.
    void eraseMask() {
        clearFlag();
        rowMask = MaskInfo();
        colMask = MaskInfo();
    }

    // A block with simdSize == 0 needs no separate load/store of its own.
    bool isLoadBlock() const { return simdSize > 0; }

    int nregs() const;
    int offsetReg() const;

    void simplify(Type T);
    void compact(Type T);
};
595
// Parameter bundle for 2D addressing: subregisters for rows/cols,
// row/column offsets (offR/offC) and row/column remainders (remR/remC),
// plus fixed row/column counts that default to 0.
// NOTE(review): how fixedRows/fixedCols interact with the `rows`/`cols`
// subregisters is not visible here -- confirm at the point of use.
struct Address2DParams {
    ngen::Subregister rows, cols;
    ngen::Subregister offR, offC;
    ngen::Subregister remR, remC;
    int fixedRows = 0, fixedCols = 0;
};
602
603class VirtualFlagAllocator {
604public:
605 VirtualFlagAllocator(ngen::HW hw)
606 : free((1ul << (ngen::GRF::bytes(hw) >> 1)) - 1)
607 , nflag(ngen::FlagRegister::subcount(hw)) {}
608
609 VirtualFlag allocVirtual(int n = 1);
610 ngen::FlagRegister alloc(int n = 1);
611 ngen::FlagRegister tryAlloc(int n = 1);
612
613 void claim(VirtualFlag vflag) { free &= ~mask(vflag); }
614 void release(VirtualFlag vflag) { free |= mask(vflag); }
615 void release(const ngen::FlagRegister &reg) {
616 release(VirtualFlag(reg));
617 unlock(reg);
618 }
619 void safeRelease(VirtualFlag &vflag) {
620 if (vflag) release(vflag);
621 vflag.clear();
622 }
623 void safeRelease(ngen::FlagRegister &reg) {
624 if (reg.isValid()) release(reg);
625 reg.invalidate();
626 }
627
628 bool isVirtual(VirtualFlag vflag) { return (vflag.idx >= nflag); }
629
630 bool lock(VirtualFlag vflag) {
631 bool wasLocked = isLocked(vflag);
632 locked |= mask(vflag);
633 return wasLocked;
634 }
635 void unlock(VirtualFlag vflag) { locked &= ~mask(vflag); }
636 bool isLocked(VirtualFlag vflag) const { return (locked & mask(vflag)); }
637
638 ngen::FlagRegister assignPhysical(VirtualFlag vflag);
639
640 static int getBase(int idx) { return idx & 0x1F; }
641 static int getN(int idx) { return idx >> 5; }
642 static int makeIndex(int base, int n) { return base | (n << 5); }
643
644protected:
645 uint32_t free;
646 uint8_t locked = 0;
647 uint8_t nextPhys = 0;
648 uint8_t nflag;
649
650 static uint32_t mask(VirtualFlag vflag) { return mask(vflag.idx, vflag.n); }
651 static uint32_t mask(int idx, int n) {
652 return (1ul << (idx + n)) - (1ul << idx);
653 }
654};
655
656class TokenAllocator {
657public:
658 TokenAllocator(ngen::HW hw);
659
660 int8_t tryAlloc();
661 void release(int8_t token) { free |= (1u << token); }
662 void safeRelease(int8_t &token) {
663 if (token >= 0) release(token);
664 token = -1;
665 }
666
667protected:
668 uint32_t free;
669};
670
// State parameters shared between different kernel types.
// Holds the register, virtual-flag and token allocators together with
// registers and flags tracked during code generation.
struct CommonState {
    ngen::RegisterAllocator ra;
    ngen::GRF signChange, selectImag;
    ngen::GRF vflagStorage;
    std::array<VirtualFlag, 8> activeVFlags;
    VirtualFlagAllocator raVFlag;
    TokenAllocator tokenAllocator;
    std::vector<std::pair<uint8_t, int8_t>> tokenMap;
    ngen::Subregister readFailures;
    ngen::Subregister fusedID;
    ngen::Subregister lsDescConstant[4];
    ngen::FlagRegister flagSwizzle;
    EmulationState emulate;
    ngen::GRFRange eatomicAddRegs[2];
    ngen::GRFRange remaskRegs[2];
    VirtualFlag vflagEAtomicAdd;
    ngen::Subregister all1s;
    ngen::RegData r0_info;
    bool movedR0 = false;
    ngen::Subregister lid0;
    GRFMultirange indexVec; // uw
    int ivEntries = 0;
    // Registers, flags and label grouped under the name invertSub.
    // NOTE(review): presumably state for an inversion subroutine -- confirm.
    struct {
        ngen::GRF zero, one;
        ngen::GRFRange src1Storage;
        ngen::GRF src1, srcR1, srcI1, r, d;
        ngen::GRFRange mathTemp;
        ngen::GRF temp;
        std::array<ngen::FlagRegister, 2> tempFlags;
        ngen::Subregister flagStore; // ud
        ngen::Label label;
        int simd;
        ngen::Subregister callStorageSub, callStorage;
        bool use = false;
    } invertSub;

    // All three allocators are constructed for the given hardware.
    CommonState(ngen::HW hw) : ra(hw), raVFlag(hw), tokenAllocator(hw) {}

    // Clear every activeVFlags entry whose index is not locked in raVFlag.
    void wipeActiveVFlags() {
        for (int i = 0; i < int(activeVFlags.size()); i++)
            if (!raVFlag.isLocked(VirtualFlag(i))) activeVFlags[i].clear();
    }

    // Record `flag` in the activeVFlags slot given by its own index.
    void usePhysicalFlag(ngen::FlagRegister flag) {
        activeVFlags[flag.index()] = flag;
    }

    // Allocate the temporary registers needed by the enabled emulation
    // modes: two for emulate64 or emulate64_mul, one for emulateDWxDW, none
    // if no mode is enabled.
    void allocEmulate64Temp(const EmulationStrategy &estrategy) {
        int ntemp = 0;
        if (estrategy.emulate64) ntemp = std::max(ntemp, 2);
        if (estrategy.emulate64_mul) ntemp = std::max(ntemp, 2);
        if (estrategy.emulateDWxDW) ntemp = std::max(ntemp, 1);

        for (int q = 0; q < ntemp; q++)
            emulate.temp[q] = ra.alloc();
    }
};
729
// Places to store r0 information.
// Selected through CommonStrategy::moveR0, which defaults to Acc.
// NOTE(review): names suggest None = leave in place, Acc = accumulator,
// Addr = address register, GRF = general register -- confirm.
enum class MoveR0 { None, Acc, Addr, GRF };
732
// Problem parameters shared between kernel types.
// Base of GEMMProblem; both options default to off.
struct CommonProblem {
    bool nonuniformWGs = false; // Support nonuniform workgroups?
    bool gtpinSupport = false; // Support GT-Pin?
};
738
// Strategy parameters shared between different kernel types.
// Base of GEMMStrategy. Every field carries a default member initializer;
// the default constructor leaves them untouched.
struct CommonStrategy {
    int subgroupSize = 8; // Subgroup size provided to OpenCL runtime.
    bool fused = false; // Fused EU handling enabled?
    bool dualGRF = true; // Enable two-GRF instructions.
    bool ieeeDenormals = true; // Enable IEEE-compliant denormals.
    bool spf = true; // Enable Single Program Flow (SPF) mode in EUs.
    MoveR0 moveR0 = MoveR0::Acc; // Where to store r0 information.
    bool sipR0WA = false; // Avoid using r0 to avoid clobbering by SIP.
    bool readSuppressionWA
            = true; // Workaround for HW issue with read suppression after fused sends.
    bool wgInSS
            = false; // Pretend to use barriers so that each WG belongs to 1 SS/DSS.
    int GRFs = 128; // # of GRFs to use.
    bool finalFence = false; // Issue global memory fence before EOT.
    int pauseCycles
            = 0x0200; // Number of cycles to pause when waiting in a spin-loop.
    bool simulation = false; // For use in simulator?

    EmulationStrategy emulate;

    CommonStrategy() {}
    // Hardware-specific construction; defined outside this header section.
    CommonStrategy(ngen::HW hw, int stepping = 0);
    void preflight(ngen::HW hw, const CommonProblem &problem);
};
764
// Types of updates for GEMM kernels.
// NOTE(review): names suggest which part of C is updated (full matrix or
// one triangle, optionally Hermitian) -- confirm at the point of use.
enum class UpdateType {
    Full,
    UpperTriangle,
    UpperTriangleHermitian,
    LowerTriangle,
    LowerTriangleHermitian
};
773
// A/B offset mode.
// Stored in GEMMProblem::abOffset; Calc makes needsASums()/needsBSums()
// return true.
enum class ABOffset {
    None, // No A/B offsets.
    Calc, // Calculate A/B row/column sums in kernel.
    Load, // Use precalculated row/column sums.
};
780
// C offset mode.
// Stored in GEMMProblem::cOffset; any value other than None makes
// usesCO() return true, and Pre makes allowMatrixOffset() return true.
enum class COffset {
    None, // No C offsets.
    Post, // C offset after all other updates.
    Pre, // C offset before all other updates (bias).
};
787
// Batch mode.
// Stored in GEMMProblem::batch (default None); GEMMProblem::batchDims
// applies to the Strided mode only.
enum class BatchMode { None, Strided, Nonstrided, Variable };
790
// Binary operations.
// Elementwise two-operand operations: add, subtract, multiply, divide,
// minimum and maximum.
enum class BinaryOp { Add, Sub, Mul, Div, Min, Max };
793
794// GEMM kernel problem description.
795struct GEMMProblem : public CommonProblem {
796 Type Ta, Tb, Tc, Tco, Ts; // Types for A/B/C/C offsets/scalars in registers.
797 Type Ta_ext, Tb_ext, Tc_ext; // Types for A/B/C data in memory.
798
799 Scalar<double> alpha_real, alpha_imag; // Alpha value, if fixed.
800 Scalar<double> beta_real, beta_imag; // Beta value, if fixed.
801 MatrixAddressing A, B, C, CO; // Addressing information for matrices.
802 bool checkBeta0 = true; // If true, check for beta = 0 and handle specially.
803 ABOffset abOffset = ABOffset::None; // A/B offset mode.
804 COffset cOffset = COffset::None; // C offset mode.
805 BatchMode batch = BatchMode::None; // Batch mode.
806 int batchDims = 0; // # of batch dimensions (strided batch only).
807 bool sumA = false,
808 sumB
809 = false; // If true, calculate A row sums/B column sums and store in CO.
810 post_ops_t postOps; // Fused post operations to apply
811 bool postOpFwd = true; // Eltwise parameters
812 std::vector<MatrixAddressing> binary; // Binary postop data
813 std::vector<Type> Tbinary; // Binary types
814 std::vector<bool> binaryRow; // Dimensionality of binary data
815 std::vector<bool>
816 binaryCol; // (false means broadcast in the given dimension)
817 std::vector<bool> binaryBatch;
818
819 bool hasPostOp() const { return postOps.len() > 0; }
820 bool hasBinaryPostOp() const {
821 for (int idx = 0; idx < postOps.len(); idx++)
822 if (postOps.entry_[idx].is_binary()) return true;
823 return false;
824 }
825
826 bool beta0() const {
827 return (beta_real == 0) && (!Tc.isComplex() || (beta_imag == 0));
828 }
829 bool beta1() const {
830 return (beta_real == 1) && (!Tc.isComplex() || (beta_imag == 0));
831 }
832 bool alpha1() const {
833 return (alpha_real == 1) && (!Tc.isComplex() || (alpha_imag == 0));
834 }
835 bool alphaM1() const {
836 return (alpha_real == -1) && (!Tc.isComplex() || (alpha_imag == 0));
837 }
838
839 bool needsTsConvert() const {
840 if (!(alpha1() || alphaM1())) return true;
841 if (!(beta0() || beta1())) return true;
842 if (beta1() && !Tc_ext.isSubsetOf(Tc)) return true;
843 if (hasPostOp()) return true;
844 return false;
845 }
846
847 bool gemmt() const { return false; }
848 bool backward() const { return false; }
849
850 bool needsASums() const { return (abOffset == ABOffset::Calc) || sumA; }
851 bool needsBSums() const { return (abOffset == ABOffset::Calc) || sumB; }
852 bool usesCO() const { return (cOffset != COffset::None) || sumA || sumB; }
853 bool allowMatrixOffset() const { return (cOffset == COffset::Pre); }
854};
855
// Forward declaration; the definition is not part of this section of the
// file.
struct GEMMState;
857
// How to split A/B amongst threads in a workgroup.
// Used by GEMMStrategy::coopA and coopB, which both default to K.
enum class CoopSplit {
    K, // Split in k dimension
    MN, // Split in m/n dimensions
    Linear, // Split in linear index order
};
864
865// Strategy parameters for GEMM kernels.
866struct GEMMStrategy : public CommonStrategy {
867 int blocking[3] = {
868 0}; // Recommended block size in each dimension (m/n/k) -- for driver.
869 int blockingAlt[3] = {
870 0}; // Alternate block size in each dimension (m/n/k) -- for driver.
871 // m/n alternates are for Hilbert-ordered kernels when Hilbert ordering disabled.
872 // k alternate is for multi-tile execution with implicit scaling.
873 int unroll[3]; // Unrolls in each dimension (m/n/k), indexed by LoopType.
874 int unrollK_masked = 0; // k unroll to use when masking.
875 LoopType loopOrder[3] = {LoopM, LoopN,
876 LoopK}; // Expected order of loops in driver code (in order from innermost to outermost).
877 LoopType fusedLoop = LoopM; // Direction of fusing if threads fused.
878 bool hilbertOrder = false; // Use Hilbert-like walk order in C?
879 bool boustrophedon = false; // Use panel-boustrophedon walk order in C?
880 bool persistent = false; // Use persistent thread model?
881 bool reverse[2] = {false, false}; // Reverse m/n walk order?
882 int fmaSIMD = 0; // Vector length for FMA (0 = default = 2 GRFs).
883 int kChain = 1; // # of FMAs to chain in k dimension.
884 int wg[3] = {0, 0,
885 0}; // m/n/k workgroup sizes, 0 if unconstrained. Indexed by LoopType.
886 WGType forceWGUpdate = WGDynamic; // Force work group update type.
887 MatrixAddressingStrategy A, B, C,
888 CO; // Strategies for accessing A/B/C/C offsets.
889 int ka_load, kb_load; // How much of A/B is loaded at once, in k dimension
890 int ka_load_masked = 0,
891 kb_load_masked
892 = 0; // Same as above, when masking m/n (0 = default = same as ka/kb_load)
893 bool slmA = false, slmB = false; // Whether to copy A/B to SLM.
894 bool splitCopy = false; // Separate SLM copy and compute threads?
895 int slmBuffers = 0; // # of A/B SLM buffers, 0 for none.
896 int unrollKSLM
897 = 0; // k unroll for SLM copies (0 = auto = unroll[LoopK]/slmCopies)
898 int unrollKSLMMasked
899 = 0; // Alternate value to use with masking (0 = same as unrollKSLM)
900 bool slmATrans = false,
901 slmBTrans
902 = false; // Whether A/B SLM data should be completely crosspacked (transposed).
903 int A_copies = 1,
904 B_copies = 1; // # of copies of A/B matrices, for latency absorption
905 int slmCopies = 1; // # of copies of loaded A/B matrices for SLM copies.
906 bool slmRepackAhead = false; // Repack SLM data ahead of stores?
907 int optAlignAB
908 = 0; // Optional alignment for A/B. If > 0, create two versions of k loop, one for A/B aligned to this value, one not.
909 AccessType unalignedAccA,
910 unalignedAccB; // Access types to use for A/B on unaligned path.
911 int ka_prefetch = 0, kb_prefetch = 0; // Chunk size for prefetching A/B.
912 int ka_pfStride = 0, kb_pfStride = 0; // k stride between A/B prefetches.
913 bool cooperativePF = true; // Enable WG-cooperative A/B prefetches.
914 int prefetchA = 0, prefetchB = 0,
915 prefetchC = 0; // Prefetch distances, in units of unrollK.
916 int prefetchAMasked = 0,
917 prefetchBMasked = 0; // Same as above, when masking m/n.
918 MatrixAddressingStrategy A_prefetch, B_prefetch,
919 C_prefetch; // Strategies for prefetching A/B/C.
920 enum {
921 CSeparate, // C stored in its own bundle, A/B in the other bundle.
922 ACB, // A, then C, then B
923 BCA, // B, then C, then A
924 VNC, // A/B (broadcast matrix second), then C
925 ABInterleave, // A/B interleaved, then C
926 NSeparate, // Broadcast input stored in its own bundle(s)
927 VAvoid, // C registers allocated to avoid non-broadcast inputs
928 } registerScheme
929 = CSeparate; // Register layout scheme.
930 bool avoidIncConflicts
931 = true; // If true, duplicate some increment values across banks to avoid bundle conflicts.
932 bool kParallel
933 = false; // If true, generate k-parallelized kernel using global memory reduction.
934 bool kParallelLocal
935 = false; // If true, generate k-parallelized kernel using local memory reduction.
936 bool doubleWA
937 = false; // Use explicit double broadcast instructions? (Gen9 only)
938 int barrierFreq
939 = 0; // If > 0, set a periodic barrier every barrierFreq k loops to keep threads together.
940 bool splitBarrier
941 = false; // Use split barriers for these periodic barriers?
942 bool altCRemainder = false; // Use alternative double-loop C remainder code?
943 bool block2DCRemainder = false; // Generate block 2D C remainder path?
944 bool cAccumulators
945 = false; // Use accumulator registers for part of C (to save a few registers)?
946 bool cLoadAhead = false; // Load C before doing FMAs?
947 bool forceCopyC = false; // Force C to be copied before the update step?
948 bool noJumpTables = false; // Disallow jump tables?
949 RemainderHandling remHandling[3] = {
950 // m, n, k remainder handling.
951 RemainderHandling::Split,
952 RemainderHandling::Split,
953 RemainderHandling::General,
954 };
955 bool jointSplit
956 = true; // Use remainder kernel for both m and n dimensions if both are split.
957 int mSplitThresh = 0,
958 nSplitThresh
959 = 0; // m/n minimum thresholds for using split remainder handling. 0 means always use split.
960 bool atomicFMA = false; // Use {Atomic} FMA chains.
961 bool extendedAtomicFMA = false; // Use longer {Atomic} FMA chains.
962 bool stallAfterLoad = false; // Insert stalls after load operations.
963 bool checkAdd32
964 = false; // Check inside kernel if inner loop additions can be done in 32-bit.
965 bool delayABInc
966 = true; // Delay A/B increment a few outer products in the k loop.
967 CoopSplit coopA = CoopSplit::
968 K; // How to split SLM copies, cooperative prefetches amongst threads in a workgroup
969 CoopSplit coopB = CoopSplit::K;
970 bool slmEarlyKMask
971 = false; // Prepare A/B reads to use k-masking (when applicable) in main loop, instead of waiting for remainder.
972 bool slmUseIncrCopy = true; // Use incremental SLM copies if needed.
973 bool slmAltBarriers = false; // Alternate fenceless SLM buffering algorithm.
974 bool strictFence
975 = false; // Add extra SLM fences that are not usually required on HW.
976 bool skipFence
977 = false; // Skip SLM fences that theoretically should be required but HW doesn't need.
978 bool slmFenceWARWA
979 = false; // Work around buggy SLM fence that doesn't protect against WAR hazards.
980 bool systolic = false; // Use systolic array if applicable.
981 bool dpasw = false; // Use DPASW for fused EU architectures.
982 bool fixedSystolic
983 = false; // Use hardcoded systolic inner loop for 32x32 or 32x48 unrolls.
984 int namedBarriers[2] = {0,
985 0}; // # of named barriers in m, n dimensions (0 to use regular barriers).
986 bool skewLocalIDs
987 = false; // Remap local IDs for large workgroups so that threads on the same EU don't depend on the same data.
988 bool xParallel = false; // TRSM: parallelize in x dimension.
989 bool checkBeta1
990 = false; // If true, check for beta = 1 and handle specially.
991 std::vector<MatrixAddressingStrategy>
992 binary; // Strategies for binary postop data
993
994 bool insideSK = false; // Inside a superkernel?
995
// Default constructor: empty body, so members keep their in-class defaults.
GEMMStrategy() {}
// Construct for a hardware generation and stepping; both arguments are
// forwarded unchanged to the CommonStrategy base.
GEMMStrategy(ngen::HW hw, int stepping = 0)
    : CommonStrategy(hw, stepping) {}

// Declared here; the definitions are not in this part of the file.
void preflight(ngen::HW hw, const GEMMProblem &problem);
bool minimize(ngen::HW hw, const GEMMProblem &problem);
1002
1003 bool lateExit() const {
1004 return (slmBuffers > 0) || barrierFreq || kParallelLocal
1005 || (cooperativePF && (prefetchA || prefetchB));
1006 }
1007
1008 int maxKSLM(const GEMMProblem &problem, bool isA) const;
1009 int slmABufBlockSize(const GEMMProblem &problem) const {
1010 return fixedSystolic ? 1152
1011 : int(slmA) * problem.Ta * problem.Ta.components()
1012 * unroll[LoopM] * maxKSLM(problem, true);
1013 }
1014 int slmBBufBlockSize(const GEMMProblem &problem) const {
1015 return fixedSystolic ? 1536
1016 : int(slmB) * problem.Tb * problem.Tb.components()
1017 * unroll[LoopN] * maxKSLM(problem, false);
1018 }
1019 int slmABufSize(const GEMMProblem &problem) const {
1020 return slmABufBlockSize(problem) * wg[LoopM] * wg[LoopK] * slmBuffers;
1021 }
1022 int slmBBufSize(const GEMMProblem &problem) const {
1023 return slmBBufBlockSize(problem) * wg[LoopN] * wg[LoopK] * slmBuffers;
1024 }
1025 int slmSysgemmBlockSize() const {
1026 return 1152 * wg[LoopM] + 1536 * wg[LoopN];
1027 }
1028 bool variableSLM() const { return kParallelLocal; }
1029
1030 int ka_inc() const { return slmA ? unrollKSLM : ka_load; }
1031 int kb_inc() const { return slmB ? unrollKSLM : kb_load; }
1032
1033 bool needsMNLocalIDs() const {
1034 return xParallel || (slmBuffers > 0) || cooperativePF || kParallelLocal
1035 || persistent || namedBarriers[0] || (dpasw && !fixedSystolic);
1036 }
1037 bool needsKLocalIDs() const { return kParallelLocal || persistent; }
1038 bool needsBarrier() const {
1039 return (barrierFreq > 0) || (slmBuffers > 0) || xParallel
1040 || kParallelLocal;
1041 }
1042
1043 bool fusedM() const { return fused && (fusedLoop == LoopM); }
1044 bool fusedN() const { return fused && (fusedLoop == LoopN); }
1045
1046 WGType getWGType(const GEMMProblem &problem) const {
1047 if ((slmBuffers > 0) || (forceWGUpdate == WGFixed)
1048 || (barrierFreq && namedBarriers[0]))
1049 return WGFixed;
1050 else
1051 return WGDynamic;
1052 }
1053
1054 bool fixedWG(const GEMMProblem &problem) const {
1055 return (getWGType(problem) == WGFixed);
1056 }
1057 bool linearOrder() const { return hilbertOrder || boustrophedon; }
1058};
1059
// A GRF range paired with an a64 flag (false by default).
// NOTE(review): presumably holds precomputed multiples of a leading
// dimension (see GEMMState::ldaMultiples etc.) -- confirm against users.
struct LDMultiples {
    ngen::GRFRange range;
    bool a64 = false;
};
1064
// State parameters for GEMM kernels.
//
// Extends CommonState with the kernel arguments (Inputs), the registers,
// layouts and addressing information tracked for A/B/C, and assorted flags.
// Trailing comments on register members give the expected data type
// (q, d, ud, w, uw, ...).
// NOTE(review): several scalar members (plain bool/int/uint8_t/int8_t) have
// no default initializer and are not set by the constructor, so they are
// indeterminate until assigned.
struct GEMMState : public CommonState {
    // Kernel arguments.
    struct Inputs {
        ngen::Subregister A, B, C[2], CO, base; // q
        ngen::Subregister ao, bo, abo; // w/w/ud
        ngen::Subregister aoPtr, boPtr; // q
        ngen::Subregister offsetA, offsetB, offsetC[2]; // q
        ngen::Subregister offsetCO; // d
        ngen::Subregister lda, ldb, ldc[2], ldco; // d
        ngen::Subregister m, n, k, k0; // d
        ngen::Subregister alpha_real, alpha_imag; // T_real
        ngen::Subregister beta_real, beta_imag; // T_real
        ngen::Subregister groupIDM, groupIDN, groupIDK; // ud
        ngen::Subregister groupIDMN; // ud
        ngen::GRF localIDM, localIDN, localIDK; // uw
        ngen::Subregister localSizeM, localSizeN, localSizeK; // ud
        ngen::Subregister groupCountM, groupCountN; // ud
        ngen::Subregister groupCountMN; // ud
        ngen::Subregister groupStride; // ud
        ngen::Subregister hilbertVD, hilbertUVDRecip; // ud
        ngen::Subregister hilbertBail; // ud
        ngen::Subregister bslice, bthresh; // d
        ngen::Subregister flags; // ud
        ngen::Subregister diagA, diagB, diagC; // q
        uint8_t surfaceA, surfaceB; // BTS indices
        uint8_t surfaceC[2], surfaceCO; // BTS indices
        ngen::Subregister strideA[2], strideB[2],
                strideC[2]; // ud, used for strided batch.
        ngen::Subregister batchSize1, recipBatchSize1; // ud, 2D strided batch
        ngen::Subregister offsetBatch; // ud, used for non-strided batch.
        ngen::Subregister incr_a_array,
                incr_b_array; // ud, used for non-strided variable batch.
        ngen::Subregister incr_alpha,
                incr_beta; // ud, used for non-strided variable batch.
        ngen::Subregister alpha_array,
                beta_array; // q, used for non-strided variable batch.
        std::vector<ngen::Subregister> binarySrcs; // q
        std::vector<ngen::Subregister> binaryOffsets; // q/d
        std::vector<ngen::Subregister> binaryLDs; // d
        std::vector<std::array<ngen::Subregister, 2>> binaryStrides; // d
        std::vector<uint8_t> binarySurfaces;
    } inputs;
    Type Ta_load, Tb_load; // Current type to be loaded into A/B_regs.
    Type Tacc; // Current type in accumulator registers.
    ngen::Subregister persistentGroupID; // ud
    ngen::Subregister batchID[2]; // ud
    ngen::Subregister offsetA, offsetB, offsetC[2];
    ngen::Subregister offsetAp, offsetBp, offsetCp;
    ngen::Subregister offsetCO;
    ngen::Subregister saveOffsetA, saveOffsetB, saveOffsetC[2];
    ngen::Subregister saveOffsetCO;
    ngen::Subregister fullK;
    ngen::Subregister effA, effB, effC[2],
            effCO; // Offsets to base of A/B/C/CO chunks for loading/storing.
    ngen::Subregister effAi, effBi;
    ngen::Subregister effAo, effBo;
    ngen::Subregister effAp, effBp, effCp;
    ngen::Subregister effAs, effBs;
    // Address registers. Suffix convention used throughout this struct,
    // as far as the existing comments show: i = incoming SLM copy data,
    // o = outgoing SLM copy data, p = prefetch, s = sums, r = repacked,
    // Rem = remainder variant.
    std::vector<ngen::GRFRange> A_addrs, B_addrs, C_addrs[2];
    std::vector<ngen::GRFRange> A_addrsRem, B_addrsRem;
    std::vector<ngen::GRFRange> Ai_addrs, Bi_addrs;
    std::vector<std::vector<ngen::GRFRange>> Ai_addrsK, Bi_addrsK;
    std::vector<ngen::GRFRange> Ai_addrsRem, Bi_addrsRem;
    std::vector<ngen::GRFRange> Ao_addrs, Bo_addrs;
    std::vector<ngen::GRFRange> Ap_addrs, Bp_addrs, Cp_addrs;
    std::vector<GRFMultirange> A_regs, B_regs, C_regs;
    GRFMultirange Ar_regs, Br_regs; // Repacked A/B registers.
    std::vector<GRFMultirange> Ai_regs,
            Bi_regs; // Incoming data to copy to SLM.
    std::vector<GRFMultirange> Ai_regsRem, Bi_regsRem;
    GRFMultirange Ao_regs, Bo_regs; // Outgoing data to copy to SLM.
    GRFMultirange Ao_regsRem, Bo_regsRem;
    GRFMultirange As_regs, Bs_regs; // A row sums/B column sums.
    GRFMultirange Ap_regs, Bp_regs, Cp_regs; // A/B/C prefetch registers.
    std::vector<MaskAssignment> AB_masks;
    ngen::GRFRange broadcast_regs;
    std::vector<ngen::GRFRange> tempMul_regs;
    ngen::Subregister i0, j0, h0; // d
    ngen::Subregister remainders[3]; // d (todo: w)
    ngen::Subregister remaindersFused[2]; // w
    ngen::Subregister remaindersWG[2]; // d (todo: w)
    ngen::Subregister remFusedStorage; // d
    ngen::Subregister diagC; // d
    SubregisterPair lda, ldb;
    SubregisterPair lda_ka, ldb_kb; // Cached lda * ka, ldb * kb
    SubregisterPair lda_ka_prefetch,
            ldb_kb_prefetch; // Cached lda * ka_pfStride, ldb * kb_pfStride
    LDMultiples ldaMultiples, ldbMultiples, ldcMultiples[2];
    int ka_cached = 0, kb_cached = 0; // Multipliers for lda_ka/ldb_kb.
    ngen::Subregister k, K; // d
    ngen::FlagRegister flagAP;
    ngen::Subregister beta1; // d
    ngen::Subregister add64; // uw
    ngen::Subregister lidM, lidN, lidStorage; // uw, uw, ud
    ngen::Subregister lidK, lszK, lidszKStorage; // uw, uw, ud
    ngen::Subregister ia0_slm, jb0_slm; // uw
    ngen::Subregister postRemA, postRemB; // ud
    ngen::Subregister postRemAi, postRemBi; // ud
    ngen::Subregister postRemAo, postRemBo; // ud
    ngen::Subregister isCompute; // ud
    ngen::GRF sysSumAll1s; // Ta/Tb
    bool systolicSumA = false, systolicSumB = false;
    bool lateKLoopCheck = false;
    // The following ints/bools have no default initializer.
    int ka_loadRem, kb_loadRem;
    bool Ai_hasKRem, Bi_hasKRem;
    bool Ai_lateKRem, Bi_lateKRem;
    bool Ai_incrementalRem, Bi_incrementalRem;
    bool Ai_remIncrCopy, Bi_remIncrCopy;
    int ma_slm, ka_slm, kb_slm, nb_slm;
    int ma_prefetch, ka_prefetch, kb_prefetch, nb_prefetch;
    CoopSplit effCoopA = CoopSplit::K;
    CoopSplit effCoopB = CoopSplit::K;
    // Register layouts, using the same suffix convention as the address
    // registers above.
    std::vector<RegisterBlock> A_layout, B_layout, C_layout;
    std::vector<RegisterBlock> A_layoutRem, B_layoutRem;
    std::vector<RegisterBlock> Ar_layout, Br_layout;
    std::vector<RegisterBlock> Ai_layout, Bi_layout;
    std::vector<std::vector<RegisterBlock>> Ai_layoutK, Bi_layoutK;
    std::vector<RegisterBlock> Ai_layoutRem, Bi_layoutRem;
    std::vector<RegisterBlock> Ao_layout, Bo_layout;
    std::vector<RegisterBlock> As_layout, Bs_layout;
    std::vector<RegisterBlock> Ap_layout, Bp_layout, Cp_layout;
    std::vector<RegisterBlock> C_layoutExt, C_layoutExtUnmasked;
    Address2DParams A_params, B_params;
    Address2DParams Ai_params, Bi_params;
    Address2DParams Ap_params, Bp_params;
    int Ai_regCount = 0, Bi_regCount = 0;
    bool aioShare, bioShare; // No default initializer.
    bool aioShareRem, bioShareRem; // No default initializer.
    bool aoReuseA = false, boReuseB = false;
    MatrixAddressing Ai, Bi, Ao, Bo;
    MatrixAddressingStrategy Ai_strategy, Bi_strategy;
    MatrixAddressingStrategy Ao_strategy, Bo_strategy;
    MatrixAddressingStrategy Cext_strategy;
    int8_t tokenBarrierFence[2]; // No default initializer.
    ngen::InstructionModifier modBarrierFence[2];
    bool barrierReady = false;
    ngen::GRF barrierHeader;
    ngen::GRF barrierHeaderM, barrierHeaderN;
    ngen::FlagRegister barrierM, barrierN;
    bool firstKLoopSegment; // No default initializer.
    bool isNested = false;
    int C_accCount; // No default initializer.
    bool cSwapActive = false;
    int C_count = 1;
    int C_buffers = 1;
    bool allocedAo = false, allocedBo = false;
    bool allowEmptyC = false;
    bool copyC = false;
    bool broadcast; // No default initializer.
    bool repackA = false, repackB = false;
    bool repackARem = false, repackBRem = false;
    int ka_repackRem, kb_repackRem; // No default initializer.
    bool remActiveA, remActiveB, remActiveSLM; // No default initializer.
    std::vector<MaskAssignment> kMasksSLM;
    bool slmRemaskA = false, slmRemaskB = false;
    bool slmASums = false, slmBSums = false;
    bool doLateExit = false;
    ngen::GRF emulate64TempSave[2];

    std::vector<ngen::Subregister> effBinary;

    // Fused-GEMM bookkeeping; inactive by default.
    struct {
        bool active = false;
        uint8_t surfacePlan; // No default initializer.
        ngen::Subregister plan;
        ngen::Subregister slotA, slotB;
        ngen::Subregister localIDFlat;
        ngen::FlagRegister needLateGEMMDone;
    } fusedGEMM;

    // Four instruction modifiers kept for the sysgemm path.
    struct {
        ngen::InstructionModifier depAddr[4];
    } sysgemm;

    // Forwards the hardware generation to CommonState.
    GEMMState(ngen::HW hw) : CommonState(hw) {}
};
1241
// GEMM superkernel problem: a distinct type that adds no members to
// GEMMProblem.
struct GEMMSuperkernelProblem : public GEMMProblem {};
1244
1245// GEMM superkernel strategy parameters.
1246struct GEMMSuperkernelStrategy {
1247 std::vector<GEMMStrategy> substrategies;
1248 KernelScheduling schedule;
1249 bool multiM, multiN;
1250 bool persistent = false;
1251
1252 void preflight(ngen::HW hw, const GEMMProblem &problem);
1253 int subgroupSize() const { return substrategies[0].subgroupSize; }
1254};
1255
// GEMM superkernel state: GEMMState plus superkernel-specific inputs and
// three extra subregisters (last_i0/last_j0/last_h0).
struct GEMMSuperkernelState : public GEMMState {
    struct {
        uint8_t surfacePlan; // No default initializer.
        ngen::Subregister planCount;
        ngen::GRF localID;
        ngen::Subregister localSize;
    } inputsSK;
    ngen::Subregister last_i0, last_j0, last_h0;

    // Forwards the hardware generation to GEMMState.
    GEMMSuperkernelState(ngen::HW hw) : GEMMState(hw) {}
};
1268
1269// Copy kernel problem description: D <- alpha*S
1270struct CopyProblem : public CommonProblem {
1271 Type Ts, Td, Tsum;
1272 Scalar<double> alpha_real, alpha_imag;
1273 MatrixAddressing S, D;
1274 bool conjugate = false;
1275 bool lower;
1276 bool unit;
1277 bool trsm = false;
1278 bool sum = false;
1279 int targetWG = 1;
1280
1281 bool reflecting() const { return false; }
1282};
1283
1284// Strategy parameters for copy kernels.
1285struct CopyStrategy : public CommonStrategy {
1286 MatrixAddressingStrategy S, D;
1287 RemainderHandling remHandlingX,
1288 remHandlingY; // Remainder handling for X dimension (packed dimension) and Y dimension (length of panel)
1289 int s_load, d_load; // # of rows/columns to load from S/store to D at once
1290 int s_load_masked = 0,
1291 d_load_masked
1292 = 0; // Same as s_load/d_load, for use when masking (0 = default = same as {s,d}_load)
1293 int wgW = 0, wgZ = 0; // Fixed workgroup sizes (0 if variable).
1294
1295 int unrollX, unrollY; // Unrolls for each dimension.
1296 bool duplicateAlpha
1297 = true; // True to make two copies of alpha, one for each register bank
1298 bool xLoop
1299 = false; // True to loop over x, false to loop over y within a kernel
1300
1301 bool zParallel = false; // Kernel parallelized in z dimension?
1302
1303 int barrierFreq = 0; // If > 0, set a barrier every barrierFreq loops
1304 int optionalAlignS
1305 = 0; // If > 0, generate code to check if S is aligned to this #elements and branch to specific code for that case.
1306
1307 CopyStrategy() {}
1308 CopyStrategy(ngen::HW hw, int stepping = 0)
1309 : CommonStrategy(hw, stepping) {}
1310
1311 void preflight(ngen::HW hw, const CopyProblem &problem);
1312
1313 int unrollW() const { return xLoop ? unrollY : unrollX; }
1314 int unrollZ() const { return xLoop ? unrollX : unrollY; }
1315};
1316
1317// State parameters for copy kernels.
1318struct CopyState : public CommonState {
1319 struct {
1320 ngen::Subregister S, D; // q
1321 ngen::Subregister offsetS, offsetD; // q
1322 ngen::Subregister lds, ldd; // d
1323 ngen::Subregister m, n; // d
1324 ngen::Subregister alpha_real; // T_real
1325 ngen::Subregister alpha_imag; // T_real
1326 ngen::Subregister groupIDW, groupIDZ; // ud
1327 ngen::GRF localIDW, localIDZ; // uw
1328 ngen::Subregister localSizeW, localSizeZ; // ud
1329 ngen::Subregister diag; // d
1330 ngen::Subregister blockZ; // ud
1331 uint8_t surfaceS, surfaceD; // DTS indices
1332 } inputs;
1333 ngen::Subregister w0, z0; // ud
1334 ngen::Subregister effS,
1335 effD; // Offsets to base of S/D chunks for loading/storing.
1336 ngen::Subregister offsetS1,
1337 effS1; // Reflected variants of offsetS/effS for symmetric/Hermitian.
1338 std::vector<ngen::GRFRange> S_addrs, D_addrs;
1339 std::vector<ngen::GRFRange> S_addrSrcs[2];
1340 ngen::GRFRange S_regs, D_regs;
1341 std::vector<ngen::GRFRange> Ds_regs;
1342 ngen::Subregister lds_sl; // d
1343 ngen::Subregister ldd_dl; // d
1344 ngen::Subregister Z; // d
1345 ngen::FlagRegister flagAP, flagTri, flagDiag;
1346 ngen::FlagRegister flagReflect;
1347 std::vector<RegisterBlock> S_layout, D_layout;
1348 std::vector<RegisterBlock> Ds_layout;
1349 ngen::Subregister remainderX, remainderY; // ud
1350 ngen::GRFRange complexOne; // T_real
1351 ngen::GRF indexVecRT; // uw
1352
1353 bool isNested;
1354
1355 struct {
1356 bool active = false;
1357 } fusedGEMM;
1358
1359 CopyState(ngen::HW hw) : CommonState(hw) {}
1360
1361 void dump();
1362};
1363
1364template <ngen::HW hw>
1365class gemm_kernel_generator_t : public jit_generator<hw> {
1366public:
1367 using super = ngen::OpenCLCodeGenerator<hw>;
1368 gemm_kernel_generator_t() {}
1369
1370 NGEN_FORWARD_OPENCL(hw);
1371
// Post-op injector type for this hardware generation, and the owning
// pointer to an instance (null until something assigns it).
using Injector = jit_post_op_injector<hw>;
std::unique_ptr<Injector> postOpInjector;
1374
1375 static bool supportedBinaryOp(alg_kind_t alg) {
1376 using namespace alg_kind;
1377 return utils::one_of(alg, binary_add, binary_sub, binary_mul,
1378 binary_div, binary_min, binary_max);
1379 }
1380
// Public entry points, one per kernel family. Problem and strategy are
// taken by value, the interface handler by const reference.
// Declarations only; the definitions are not in this part of the file.
void gemm(GEMMProblem problem, GEMMStrategy strategy,
        const ngen::InterfaceHandler &interface_);
void gemmSuperkernel(GEMMSuperkernelProblem problem,
        GEMMSuperkernelStrategy strategy,
        const ngen::InterfaceHandler &interface_);
void copy(CopyProblem problem, CopyStrategy strategy,
        const ngen::InterfaceHandler &interface_);

// Static overloads returning a CommonDriverInfo for a problem/strategy pair.
// Note the superkernel overload takes a plain GEMMStrategy.
static CommonDriverInfo driverInfo(
        const GEMMProblem &problem, const GEMMStrategy &strategy);
static CommonDriverInfo driverInfo(const GEMMSuperkernelProblem &problem,
        const GEMMStrategy &strategy);
static CommonDriverInfo driverInfo(
        const CopyProblem &problem, const CopyStrategy &strategy);
1395
1396protected:
// Reference alias for the base code generator's interface_ member.
ngen::InterfaceHandler
        &interface = ngen::OpenCLCodeGenerator<hw>::interface_;

// Holds an exception pointer; not assigned anywhere in this part of the file.
std::exception_ptr lastException;

// Output stream used by this generator: always std::cerr.
std::ostream &getOutStream() const { return std::cerr; }

// Stream for notes; currently the same as getOutStream().
std::ostream &noteStream() const { return getOutStream(); }
1405
1406 class status_stream {
1407 protected:
1408 char cc;
1409 std::stringstream line;
1410 bool lineStart = true;
1411
1412 gemm_kernel_generator_t<hw> &parent;
1413
1414 friend class gemm_kernel_generator_t<hw>;
1415
1416 public:
1417 status_stream(gemm_kernel_generator_t<hw> &parent_, int color = 1)
1418 : cc(color + '0'), parent(parent_) {}
1419
1420 static constexpr struct Endl {
1421 } endl {};
1422
1423 template <typename T>
1424 status_stream &operator<<(const T &obj) {
1425 return *this;
1426 }
1427
1428 status_stream &operator<<(const Endl &e) { return *this; }
1429 } status {*this};
1430
1431#ifdef SHOW_DISCARDS
1432 void discardStream() {
1433 InstructionStream *s = popStream();
1434 auto oldCC = status.cc;
1435 status.cc = '4';
1436 status << "------- \x1B[32mBEGIN\x1B[34m discarded stream -------"
1437 << status_stream::endl;
1438 auto &sbuffer = *reinterpret_cast<std::ostringstream *>(s->getBuffer());
1439 auto str = sbuffer.str();
1440 bool lastNL = false;
1441 for (int l = 0; l < str.length(); l++) {
1442 char c = str[l];
1443
1444 if (c == '\n') {
1445 if (lastNL) status << "//";
1446 status << status_stream::endl;
1447 lastNL = true;
1448 } else {
1449 status << c;
1450 lastNL = false;
1451 }
1452 }
1453 status << "------- \x1B[32mEND\x1B[34m discarded stream -------"
1454 << status_stream::endl;
1455 status.cc = status.cc;
1456 delete s;
1457 }
1458#endif
1459
// Selector passed to the getHint() overloads, which return an ngen::Bundle.
// NOTE(review): the enumerator names suggest bank/purpose hints for
// temporaries, long-term values and the A/B/C/S/D matrices -- confirm
// against the getHint() definitions.
enum class HintType {
    Bank0,
    Bank1,
    TempComp0,
    TempComp1,
    LongTerm,
    LongTerm0,
    LongTerm1,
    R0Info,
    A0,
    A0Broadcast,
    A1,
    A1Broadcast,
    B0,
    B0Broadcast,
    B1,
    B1Broadcast,
    C,
    C1,
    CLoad,
    S,
    D,
    SAddr,
    DAddr
};
// C remainder type; printed by the stream operator defined in this class
// as "ignore", "mask" and "custom descriptor" respectively.
enum class StdCRemType { Ignore, Mask, Descriptor };
// Operation selector for C: load, update, or update followed by store.
enum class COperation { Load, Update, UpdateStore };
// k-loop kind; GEMM is the only enumerator.
enum class KLoop {
    GEMM,
};
1490
1491 friend std::ostream &operator<<(std::ostream &s, StdCRemType rt) {
1492 const char *names[3] = {"ignore", "mask", "custom descriptor"};
1493 return (s << names[static_cast<int>(rt)]);
1494 }
1495
// Virtual-flag helpers. Declarations only; defined elsewhere.
// NOTE(review): presumably getPhysicalFlag maps a virtual flag to a physical
// flag register and allocVFlagStorage reserves backing storage -- confirm.
ngen::FlagRegister getPhysicalFlag(VirtualFlag vflag, CommonState &state);
void allocVFlagStorage(const CommonStrategy &strategy, CommonState &state);

// Bundle hint lookup, overloaded on the strategy type. Declarations only.
ngen::Bundle getHint(HintType type);
ngen::Bundle getHint(HintType type, const CommonStrategy &strategy);
ngen::Bundle getHint(HintType type, const GEMMStrategy &strategy);
ngen::Bundle getHint(HintType type, const CopyStrategy &strategy);
1503
// Single-label convenience overload: uses jip for both labels and leaves
// branchCtrl at its default.
void goto12(const ngen::InstructionModifier &mod, ngen::Label &jip) {
    goto12(mod, jip, jip);
}
// Two-label form. Declaration only; defined elsewhere.
void goto12(const ngen::InstructionModifier &mod, ngen::Label &jip,
        ngen::Label &uip, bool branchCtrl = false);

// Multiply src0 by the compile-time-known integer src1 into dst.
// Declaration only; the instruction sequence used is chosen by the
// definition elsewhere.
template <typename DT = void>
void mulConstant(const ngen::InstructionModifier &mod,
        const ngen::RegData &dst, const ngen::RegData &src0, int32_t src1);
1513
// Emulation wrappers.
// The inline members below forward to the same-named function in
// EmulationImplementation, passing *this as the generator together with
// strategy.emulate and, where the callee takes it, state.emulate.
// Members without a body are declared here and defined elsewhere.
friend struct EmulationImplementation;
template <typename DT = void>
void emov(const ngen::InstructionModifier &mod, ngen::RegData dst,
        ngen::RegData src0, const CommonStrategy &strategy,
        CommonState &state);
// Immediate-source form; `state` is accepted but not used here.
template <typename DT = void>
void emov(const ngen::InstructionModifier &mod, ngen::RegData dst,
        ngen::Immediate src0, const CommonStrategy &strategy,
        CommonState &state) {
    EmulationImplementation::emov<DT>(
            *this, mod, dst, src0, strategy.emulate);
}
template <typename DT = void>
void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
        const ngen::RegData &src0, const ngen::RegData &src1,
        const CommonStrategy &strategy, CommonState &state);
template <typename DT = void>
void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
        const ngen::RegData &src0, ngen::Immediate src1,
        const CommonStrategy &strategy, const CommonState &state) {
    EmulationImplementation::eadd<DT>(
            *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
}
template <typename DT = void>
void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
        const ngen::RegData &src0, const ngen::RegData &src1,
        const CommonStrategy &strategy, const CommonState &state) {
    EmulationImplementation::emul<DT>(
            *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
}
template <typename DT = void>
void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
        const ngen::RegData &src0, ngen::Immediate src1,
        const CommonStrategy &strategy, const CommonState &state) {
    EmulationImplementation::emul<DT>(
            *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
}
// Shift left by an immediate count.
template <typename DT = void>
void eshl(const ngen::InstructionModifier &mod, ngen::RegData dst,
        ngen::RegData src0, uint16_t src1, const CommonStrategy &strategy,
        const CommonState &state) {
    EmulationImplementation::eshl<DT>(
            *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
}
// Shift right by an immediate count.
template <typename DT = void>
void eshr(const ngen::InstructionModifier &mod, ngen::RegData dst,
        ngen::RegData src0, uint16_t src1, const CommonStrategy &strategy,
        const CommonState &state) {
    EmulationImplementation::eshr<DT>(
            *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
}
// Multiply by a 32-bit integer constant.
template <typename DT = void>
void emulConstant(const ngen::InstructionModifier &mod,
        const ngen::RegData &dst, const ngen::RegData &src0, int32_t src1,
        const CommonStrategy &strategy, const CommonState &state) {
    EmulationImplementation::emulConstant<DT>(
            *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
}
// Forwards to EmulationImplementation::emul32High; takes no strategy or
// state arguments.
template <typename S1>
void emul32High(const ngen::InstructionModifier &mod,
        const ngen::RegData &dstHi, const ngen::RegData &src0,
        const S1 &src1) {
    EmulationImplementation::emul32High(*this, mod, dstHi, src0, src1);
}
1578
// Three-source helpers. Declarations only; defined elsewhere.
// NOTE(review): by name these are multiply-add (emad) and three-input add
// (eadd3) -- confirm operand roles against the definitions.
template <typename S0, typename S2>
void emad(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
        const S0 &src0, const ngen::RegData &src1, const S2 &src2,
        const CommonStrategy &strategy, CommonState &state);
// Overload taking an integer constant as the third source.
template <typename S0>
void emad(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
        const S0 &src0, const ngen::RegData &src1, int32_t src2,
        const CommonStrategy &strategy, CommonState &state);
template <typename DT = void, typename S0, typename S2>
void eadd3(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
        const S0 &src0, const ngen::RegData &src1, const S2 &src2);
1590
// Math-function helper taking the function code explicitly.
// Declaration only; defined elsewhere.
template <typename DT = void>
void emath(const ngen::InstructionModifier &mod, ngen::MathFunction fc,
        const ngen::RegData &dst, const ngen::RegData &src0,
        const GEMMStrategy &strategy, CommonState &state);
// Shorthand for emath with MathFunction::inv.
template <typename DT = void>
void einv(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
        const ngen::RegData &src0, const GEMMStrategy &strategy,
        CommonState &state) {
    emath<DT>(mod, ngen::MathFunction::inv, dst, src0, strategy, state);
}
// Shorthand for emath with MathFunction::sqt.
template <typename DT = void>
void esqt(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
        const ngen::RegData &src0, const GEMMStrategy &strategy,
        CommonState &state) {
    emath<DT>(mod, ngen::MathFunction::sqt, dst, src0, strategy, state);
}
1607
// Declarations only; defined elsewhere.
// NOTE(review): by name, ejmpi is a jump to dst under modifier mod and cmp0
// compares src0 against zero -- confirm against the definitions.
void ejmpi(ngen::InstructionModifier mod, ngen::Label &dst);

void cmp0(const ngen::InstructionModifier &mod, ngen::RegData src0);
void syncall();
1612
1613 void wrdepRanges(const std::vector<GRFMultirange> &rrs) {
1614 for (auto &rr : rrs)
1615 for (auto &r : rr.ranges)
1616 wrdep(r);
1617 }
1618
// addScaled overloads: the three forms differ only in whether src0 and src1
// are registers or integer constants. Declarations only; defined elsewhere.
// NOTE(review): presumably dst = src0 + src1 * numerator / denominator, with
// `exact` controlling rounding -- confirm against the definitions.
void addScaled(const ngen::InstructionModifier &mod,
        const ngen::RegData &dst, int src0, const ngen::RegData &src1,
        int numerator, int denominator, CommonState &state,
        bool exact = false);
void addScaled(const ngen::InstructionModifier &mod,
        const ngen::RegData &dst, const ngen::RegData &src0,
        const ngen::RegData &src1, int numerator, int denominator,
        CommonState &state, bool exact = false);
void addScaled(const ngen::InstructionModifier &mod,
        const ngen::RegData &dst, const ngen::RegData &src0, int src1,
        int numerator, int denominator, CommonState &state,
        bool exact = false);
1631
// Integer modulus/alignment/division helpers on subregisters.
// Declarations only; defined elsewhere. Constant operands (modulus, align,
// divisor) are 16-bit unsigned.
// NOTE(review): semantics below are read off the names -- confirm against
// the definitions.
// mod: presumably dst = src mod modulus.
template <typename DT = void>
void mod(const ngen::Subregister &dst, const ngen::Subregister &src,
        uint16_t modulus, const CommonStrategy &strategy,
        CommonState &state);
// modExt: two outputs, presumably the remainder and the aligned-down value.
template <typename DT = void>
void modExt(const ngen::Subregister &dstMod,
        const ngen::Subregister &dstMultiple, const ngen::Subregister &src,
        uint16_t modulus, const CommonStrategy &strategy,
        CommonState &state);
template <typename DT = void>
void alignDown(const ngen::Subregister &dst, const ngen::Subregister &src,
        uint16_t align, const CommonStrategy &strategy, CommonState &state);
template <typename DT = void>
void alignUp(const ngen::Subregister &dst, const ngen::Subregister &src,
        uint16_t align, const CommonStrategy &strategy, CommonState &state);
// divDown with a runtime divisor: takes the divisor, a second register
// named as its reciprocal, and a flag register.
template <typename DT = void>
void divDown(const ngen::Subregister &dst, const ngen::Subregister &src0,
        const ngen::Subregister &src1, const ngen::Subregister &src1Recip,
        const ngen::FlagRegister &flag, const CommonStrategy &strategy,
        CommonState &state);
// divDown with a constant divisor.
template <typename DT = void>
void divDown(const ngen::Subregister &dst, const ngen::Subregister &src,
        uint16_t divisor, const CommonStrategy &strategy,
        CommonState &state);
1656
// Loop, barrier and pause helpers. Declarations only; defined elsewhere.
void simtDoWhileLoop(
        const ngen::InstructionModifier &mod, ngen::Label &dest);
// Both barrier helpers take a temporary GRF and default r0_info to r0.
void slmBarrier(const ngen::GRF &temp, const ngen::GRF &r0_info = r0);
void globalMemBarrier(const ngen::GRF &temp, const ngen::GRF &r0_info = r0);
void pause(const CommonStrategy &strategy);
1662
// Scalar duplication helpers. Declarations only; defined elsewhere.
// NOTE(review): presumably these create/remove a second copy of a scalar
// in another register bank -- confirm against the definitions.
void duplicateScalar(SubregisterPair &val, CommonState &state);
void deduplicateScalar(SubregisterPair &val, CommonState &state);
template <typename T>
void duplicateScalar(Scalar<T> &val, CommonState &state);
// Returns a MultishiftSubregister built from reg; `shifts` is an unsigned
// value whose encoding is defined by the implementation. The bundle hint
// defaults to a default-constructed ngen::Bundle.
MultishiftSubregister multishift(const ngen::Subregister &reg,
        unsigned shifts, const CommonStrategy &strategy, CommonState &state,
        ngen::Bundle hint = ngen::Bundle());
1670
// Miscellaneous setup/teardown helpers. Declarations only; defined
// elsewhere. Behaviour is not visible here, so only the signatures are
// described.
void getFusedID(int scale, const CommonProblem &problem,
        const CommonStrategy &strategy, CommonState &state);
// moveR0 has a common overload and a GEMM-specific overload.
void moveR0(const CommonStrategy &strategy, CommonState &state);
void moveR0(const GEMMStrategy &strategy, GEMMState &state);
// Invokes the callable f in a context involving r0; see the definition.
template <typename F>
void useR0(CommonState &state, F f);
void removeSG(const CommonProblem &problem, const CommonStrategy &strategy,
        const CommonState &state);
void reorderFusedEUs(const GEMMProblem &problem,
        const GEMMStrategy &strategy, GEMMState &state);
// Returns a subregister derived from reg; the hint defaults to
// Bundle(Bundle::any, 0).
ngen::Subregister copySubregister(const ngen::Subregister &reg,
        CommonState &state,
        ngen::Bundle hint = ngen::Bundle(ngen::Bundle::any, 0));
void zeroMatrix(const GRFMultirange &r, const CommonStrategy &strategy);
void releaseFusedRemainders(GEMMState &state);
void saveMNLocalIDs(const GEMMStrategy &strategy, GEMMState &state);
void saveKLocalIDSize(const GEMMStrategy &strategy, GEMMState &state);
void releaseSavedMNLocalIDs(GEMMState &state);

void doReadSuppressionWA(
        const CommonStrategy &strategy, CommonState &state);
1692
1693 bool getBlockInfo(Type T, const MatrixAddressing &atype,
1694 const MatrixAddressingStrategy &astrategy, int r, int c,
1695 bool remainderR, bool remainderC, bool writable, bool avoidFragment,
1696 int maxRBlock, int maxCBlock, int &rblock, int &cblock,
1697 RegisterBlock &layout);
// Single-block form: writes blockDst from blockSrc. `column` selects the
// dimension and [x1, x2) / the *Unclamped variants give the range.
// NOTE(review): whether x2 is exclusive and what the bool result means are
// assumptions; confirm against the definition.
1698 bool getSubblock(Type T, RegisterBlock &blockDst,
1699 const RegisterBlock &blockSrc, bool column, int x1, int x2,
1700 int x1Unclamped, int x2Unclamped, bool overrunOK,
1701 const MatrixAddressing &atype,
1702 const MatrixAddressingStrategy &astrategy);
// getSubblocks overload 1: layout only (no addresses, no indices).
1703 bool getSubblocks(Type T, std::vector<RegisterBlock> &sublayout,
1704 const std::vector<RegisterBlock> &layout, bool column, int x1,
1705 int x2, bool overrunOK, const MatrixAddressing &atype,
1706 const MatrixAddressingStrategy &astrategy);
// getSubblocks overload 2: subaddrs, indices and addrs are passed as
// pointers.
// NOTE(review): presumably each pointer may be null to skip that output, and
// the other overloads forward here -- confirm.
1707 bool getSubblocks(Type T, std::vector<RegisterBlock> &sublayout,
1708 std::vector<ngen::GRFRange> *subaddrs, std::vector<int> *indices,
1709 const std::vector<RegisterBlock> &layout,
1710 const std::vector<ngen::GRFRange> *addrs, bool column, int x1,
1711 int x2, bool overrunOK, const MatrixAddressing &atype,
1712 const MatrixAddressingStrategy &astrategy);
// getSubblocks overload 3: also fills subaddrs from addrs (by reference).
1713 bool getSubblocks(Type T, std::vector<RegisterBlock> &sublayout,
1714 std::vector<ngen::GRFRange> &subaddrs,
1715 const std::vector<RegisterBlock> &layout,
1716 const std::vector<ngen::GRFRange> &addrs, bool column, int x1,
1717 int x2, bool overrunOK, const MatrixAddressing &atype,
1718 const MatrixAddressingStrategy &astrategy);
// getSubblocks overload 4: also fills an `indices` vector (by reference).
1719 bool getSubblocks(Type T, std::vector<RegisterBlock> &sublayout,
1720 std::vector<int> &indices, const std::vector<RegisterBlock> &layout,
1721 bool column, int x1, int x2, bool overrunOK,
1722 const MatrixAddressing &atype,
1723 const MatrixAddressingStrategy &astrategy);
// reblockLayout has two output parameters (blockMap, layoutDst) and two
// read-only input layouts (layoutRef, layoutSrc); Tdst is the destination
// element type.
// NOTE(review): the contents of blockMap and the meaning of the bool result
// are not visible here -- confirm against the definition.
1724 bool reblockLayout(Type Tdst, std::vector<int32_t> &blockMap,
1725 std::vector<RegisterBlock> &layoutDst,
1726 const std::vector<RegisterBlock> &layoutRef,
1727 const std::vector<RegisterBlock> &layoutSrc,
1728 const MatrixAddressing &atype,
1729 const MatrixAddressingStrategy &astrategy);
1730
// tryAddMasking: bool-returning forms for a single block and for a whole
// layout. Both modify their block/layout argument in place.
1731 bool tryAddMasking(Type T, RegisterBlock &block, bool remainderR,
1732 bool remainderC, const MatrixAddressing &atype,
1733 const MatrixAddressingStrategy &astrategy);
1734 bool tryAddMasking(Type T, std::vector<RegisterBlock> &layout,
1735 bool remainderR, bool remainderC, const MatrixAddressing &atype,
1736 const MatrixAddressingStrategy &astrategy);
// addMasking: void-returning forms.
// NOTE(review): presumably these report failure differently from the try*
// variants (e.g. by throwing) -- confirm in the implementation.
1737 void addMasking(Type T, std::vector<RegisterBlock> &layout, bool remainderR,
1738 bool remainderC, const MatrixAddressing &atype,
1739 const MatrixAddressingStrategy &astrategy);
// This overload additionally receives the address registers, the leading
// dimension `ld`, and strategy/state. dataRegs defaults to -1.
// NOTE(review): -1 presumably means "not specified"; confirm.
1740 void addMasking(Type T, std::vector<RegisterBlock> &layout,
1741 std::vector<ngen::GRFRange> &addrs, const ngen::Subregister &ld,
1742 bool remainderR, bool remainderC, const MatrixAddressing &atype,
1743 const MatrixAddressingStrategy &astrategy,
1744 const CommonStrategy &strategy, CommonState &state,
1745 int dataRegs = -1);
// adjustSubblockAddrs takes every argument by value or const reference,
// including the state; sublayout/subaddrs are paired with layout/addrs.
// NOTE(review): any effect is presumably through emitted instructions on the
// registers named by subaddrs -- confirm against the definition.
1746 void adjustSubblockAddrs(Type T,
1747 const std::vector<RegisterBlock> &sublayout,
1748 const std::vector<ngen::GRFRange> &subaddrs,
1749 const std::vector<RegisterBlock> &layout,
1750 const std::vector<ngen::GRFRange> &addrs,
1751 const MatrixAddressing &atype,
1752 const MatrixAddressingStrategy &astrategy,
1753 const CommonStrategy &strategy, const CommonState &state);
1754
// addToRegLayout modifies `layout` in place for an r x c region at offset
// (roff, coff), using the same flags and block-size limits that getBlockInfo
// accepts.
// NOTE(review): presumably appends blocks; the bool result is presumably
// success -- confirm.
1755 bool addToRegLayout(Type T, std::vector<RegisterBlock> &layout, int r,
1756 int c, int roff, int coff, bool remainderR, bool remainderC,
1757 bool writable, bool avoidFragment, int maxRBlock, int maxCBlock,
1758 const MatrixAddressing &atype,
1759 const MatrixAddressingStrategy &astrategy);
// 1D variant: no offsets, remainder flags or block-size limits.
1760 bool add1DBlockToRegLayout(Type T, std::vector<RegisterBlock> &layout,
1761 int r, int c, bool writable, const MatrixAddressing &atype,
1762 const MatrixAddressingStrategy &astrategy);
// getRegLayout takes `layout` by non-const reference for an r x c matrix.
// reverseOrder defaults to false.
// NOTE(review): whether `layout` is cleared first is not visible here.
1763 bool getRegLayout(Type T, std::vector<RegisterBlock> &layout, int r, int c,
1764 bool remainderR, bool remainderC, bool writable, bool avoidFragment,
1765 int maxRBlock, int maxCBlock, const MatrixAddressing &atype,
1766 const MatrixAddressingStrategy &astrategy,
1767 bool reverseOrder = false);
// makeUnbackedRegLayout takes no MatrixAddressing arguments, unlike the
// other layout builders declared nearby.
// Defaults: crosspack = 1, tileR = 0, tileC = 0, allowPartialRegs = true,
// fullySplitCx = false.
// NOTE(review): "unbacked" presumably means a register-only layout with no
// memory behind it; a tile size of 0 presumably means "no tiling" -- confirm.
1768 void makeUnbackedRegLayout(Type T, std::vector<RegisterBlock> &layout,
1769 int r, int c, bool colMajor, int crosspack = 1, int tileR = 0,
1770 int tileC = 0, bool allowPartialRegs = true,
1771 bool fullySplitCx = false);
// upgradeLayoutToBlock2D reads layoutSrc and writes layout2D.
// NOTE(review): the bool result is presumably whether the upgrade succeeded.
1772 bool upgradeLayoutToBlock2D(Type T,
1773 const std::vector<RegisterBlock> &layoutSrc,
1774 std::vector<RegisterBlock> &layout2D, bool remainderR,
1775 bool remainderC, bool writable, const MatrixAddressing &atype,
1776 const MatrixAddressingStrategy &astrategy);
1777
// setupTeardownLoadStoreDesc: one entry point for both directions, selected
// by `setup`.
// NOTE(review): presumably true = set up, false = tear down.
1778 void setupTeardownLoadStoreDesc(bool setup,
1779 const std::vector<RegisterBlock> &layout,
1780 const CommonStrategy &strategy, CommonState &state);
// loadLoadStoreDescriptors: `load` and `store` are independent flags; the
// block is taken by non-const reference and `count` is a subregister.
// NOTE(review): what `count` measures is not visible here.
1781 void loadLoadStoreDescriptors(bool load, bool store, RegisterBlock &block,
1782 const ngen::Subregister &count, const MatrixAddressing &atype,
1783 const MatrixAddressingStrategy &astrategy,
1784 const CommonStrategy &strategy, CommonState &state);
1785
// getDataSpecLSC: static helpers (no generator instance state) that return
// an ngen::DataSpecLSC for a register block. The first is keyed on an
// AccessType; the second on the addressing description/strategy plus a
// `write` flag.
1786 static ngen::DataSpecLSC getDataSpecLSC(
1787 AccessType access, const RegisterBlock &block);
1788 static ngen::DataSpecLSC getDataSpecLSC(const MatrixAddressing &atype,
1789 const MatrixAddressingStrategy &astrategy,
1790 const RegisterBlock &block, bool write);
// getRegisterBlockMask returns an instruction modifier for `block`. It is
// non-static and takes the state by non-const reference.
1791 ngen::InstructionModifier getRegisterBlockMask(
1792 const RegisterBlock &block, CommonState &state);
// loadMatrixBlock: single-block form. The destination is one ngen::Register
// and the address is one GRFRange. Note that the RegisterBlock parameter is
// named `layout`. zeroMask defaults to false.
1793 void loadMatrixBlock(const ngen::Register &dest,
1794 const RegisterBlock &layout, const MatrixAddressing &atype,
1795 const MatrixAddressingStrategy &astrategy,
1796 const ngen::GRFRange &addr, const CommonStrategy &strategy,
1797 CommonState &state, bool zeroMask = false);
// loadMatrix: whole-layout form, with a vector of blocks and a parallel
// vector of address ranges.
// NOTE(review): presumably layout[i] pairs with addrs[i] -- confirm.
1798 void loadMatrix(const GRFMultirange &dest,
1799 const std::vector<RegisterBlock> &layout,
1800 const MatrixAddressing &atype,
1801 const MatrixAddressingStrategy &astrategy,
1802 const std::vector<ngen::GRFRange> &addrs,
1803 const CommonStrategy &strategy, CommonState &state,
1804 bool zeroMask = false);
// prefetchMatrix: same inputs as loadMatrix, but with no destination
// registers and no zeroMask flag.
1805 void prefetchMatrix(const std::vector<RegisterBlock> &layout,
1806 const MatrixAddressing &atype,
1807 const MatrixAddressingStrategy &astrategy,
1808 const std::vector<ngen::GRFRange> &addrs,
1809 const CommonStrategy &strategy, CommonState &state);
// storeMatrixBlock: single-block store. The source is an ngen::GRF, whereas
// loadMatrixBlock takes an ngen::Register destination.
1810 void storeMatrixBlock(const ngen::GRF &src, const RegisterBlock &layout,
1811 const MatrixAddressing &atype,
1812 const MatrixAddressingStrategy &astrategy,
1813 const ngen::GRFRange &addr, const CommonStrategy &strategy,
1814 CommonState &state);
// storeMatrix: whole-layout store from a register multirange.
1815 void storeMatrix(const GRFMultirange &src,
1816 const std::vector<RegisterBlock> &layout,
1817 const MatrixAddressing &atype,
1818 const MatrixAddressingStrategy &astrategy,
1819 const std::vector<ngen::GRFRange> &addrs,
1820 const CommonStrategy &strategy, CommonState &state);
// Atomic-add counterparts of storeMatrixBlock/storeMatrix. Unlike the plain
// stores, they also take the element type T and a CommonProblem.
1821 void atomicAddMatrixBlock(Type T, const ngen::GRF &src,
1822 const RegisterBlock &layout, const MatrixAddressing &atype,
1823 const MatrixAddressingStrategy &astrategy,
1824 const ngen::GRFRange &addr, const CommonProblem &problem,
1825 const CommonStrategy &strategy, CommonState &state);
1826 void atomicAddMatrix(Type T, const GRFMultirange &src,
1827 const std::vector<RegisterBlock> &layout,
1828 const MatrixAddressing &atype,
1829 const MatrixAddressingStrategy &astrategy,
1830 const std::vector<ngen::GRFRange> &addrs,
1831 const CommonProblem &problem, const CommonStrategy &strategy,
1832 CommonState &state);
1833
// assignMasks may modify both `layout` and `assignments`; rloop/cloop name
// the loops tied to rows and columns. retryVirtual defaults to false.
// NOTE(review): the bool result is presumably success -- confirm.
1834 bool assignMasks(std::vector<RegisterBlock> &layout, LoopType rloop,
1835 LoopType cloop, std::vector<MaskAssignment> &assignments,
1836 const CommonStrategy &strategy, CommonState &state,
1837 bool retryVirtual = false);
// loadMask takes one assignment and the index subregister, both by value.
// offset defaults to 0.
1838 void loadMask(MaskAssignment assignment, ngen::Subregister index,
1839 const CommonStrategy &strategy, CommonState &state, int offset = 0);
// loadMasks: `indices` (and `offsets` in the second overload) are references
// to arrays of exactly three elements. start defaults to 0.
// NOTE(review): the three slots presumably map to loop dimensions, and
// `start` is presumably the first assignment to process -- confirm.
1840 void loadMasks(const std::vector<MaskAssignment> &assignments,
1841 ngen::Subregister (&indices)[3], const CommonStrategy &strategy,
1842 CommonState &state, int start = 0);
1843 void loadMasks(const std::vector<MaskAssignment> &assignments,
1844 ngen::Subregister (&indices)[3], int (&offsets)[3],
1845 const CommonStrategy &strategy, CommonState &state, int start = 0);
1846
// setupTeardownRemask: `setup` selects the direction and `index` identifies
// the remask slot. fixedOffQ defaults to 0 and variableOffQ defaults to a
// default-constructed ngen::Subregister.
// NOTE(review): the default variableOffQ presumably means "no variable
// offset"; confirm how the definition tests for it.
1847 void setupTeardownRemask(Type T, int index, bool setup, int nq,
1848 const ngen::Subregister &remQ, const CommonStrategy &strategy,
1849 CommonState &state, int fixedOffQ = 0,
1850 const ngen::Subregister &variableOffQ = ngen::Subregister());
// remaskLayout takes an `index`, like setupTeardownRemask. offset defaults
// to 0.
// NOTE(review): presumably `index` must match a slot that was set up earlier.
1851 void remaskLayout(Type T, int index, bool column,
1852 const std::vector<RegisterBlock> &layout, const GRFMultirange &regs,
1853 const CommonStrategy &strategy, CommonState &state, int offset = 0);
1854
// setAddrRemainder: single-block form (one GRFRange, one RegisterBlock) and
// whole-layout form (parallel vectors). remR/remC are subregisters.
// NOTE(review): remR/remC presumably hold the row and column remainders.
1855 void setAddrRemainder(Type T, const ngen::GRFRange &addr,
1856 const RegisterBlock &block, const ngen::Subregister &remR,
1857 const ngen::Subregister &remC, const MatrixAddressing &atype,
1858 const MatrixAddressingStrategy &astrategy,
1859 const CommonStrategy &strategy, CommonState &state);
1860 void setAddrRemainder(Type T, const std::vector<ngen::GRFRange> &addr,
1861 const std::vector<RegisterBlock> &layout,
1862 const ngen::Subregister &remR, const ngen::Subregister &remC,
1863 const MatrixAddressing &atype,
1864 const MatrixAddressingStrategy &astrategy,
1865 const CommonStrategy &strategy, CommonState &state);
1866
// startShift/doneShift pair.
// NOTE(review): presumably startShift yields a shifted view of `ptr` and
// doneShift undoes or releases it; confirm against the definitions.
//
// Non-template startShift overloads for MultishiftSubregister (returns a
// plain Subregister) and SubregisterPair (returns a SubregisterPair).
1867 ngen::Subregister startShift(
1868 const MultishiftSubregister &ptr, int shift, CommonState &state);
1869 SubregisterPair startShift(
1870 const SubregisterPair &ptr, int shift, CommonState &state);
// Template startShift, split by SFINAE. This overload participates only when
// BO is NOT ngen::RegData or a class derived from it. It returns BO.
1871 template <typename BO>
1872 typename std::enable_if<!std::is_base_of<ngen::RegData, BO>::value,
1873 BO>::type
1874 startShift(const BO &ptr, int shift, CommonState &state);
// This overload participates only when BO IS ngen::RegData or derived from
// it. It returns BO.
1875 template <typename BO>
1876 typename std::enable_if<std::is_base_of<ngen::RegData, BO>::value, BO>::type
1877 startShift(const BO &ptr, int shift, CommonState &state);
// Template doneShift, split on the same condition. enable_if with no second
// template argument yields void, so both overloads return void.
1878 template <typename BO, typename BI>
1879 typename std::enable_if<!std::is_base_of<ngen::RegData, BO>::value>::type
1880 doneShift(
1881 const BO &ptr, const BI &ptrShifted, int shift, CommonState &state);
1882 template <typename BO, typename BI>
1883 typename std::enable_if<std::is_base_of<ngen::RegData, BO>::value>::type
1884 doneShift(
1885 const BO &ptr, const BI &ptrShifted, int shift, CommonState &state);
// Non-template doneShift for SubregisterPair.
1886 void doneShift(const SubregisterPair &ptr,
1887 const SubregisterPair &ptrShifted, int shift, CommonState &state);
1888
// offsetAddr takes a destination and a source address range with their
// blocks, plus two integer offsets (offsetFixed, offsetLD) and the leading
// dimension subregister `ld`. ldMultiples defaults to an empty LDMultiples.
// NOTE(review): presumably dst = src + offsetFixed + offsetLD * ld; confirm
// the units (bytes or elements) in the definition.
1889 void offsetAddr(const ngen::GRFRange &addrDst,
1890 const ngen::GRFRange &addrSrc, const RegisterBlock &blockDst,
1891 const RegisterBlock &blockSrc, int offsetFixed, int offsetLD,
1892 const ngen::Subregister &ld, const MatrixAddressing &atype,
1893 const MatrixAddressingStrategy &astrategy,
1894 const CommonStrategy &strategy, CommonState &state,
1895 const LDMultiples &ldMultiples = {});
// setupAddrRel takes the element type and the full layout instead of
// explicit offsets.
// NOTE(review): presumably derives addrDst relative to addrSrc from the
// two blocks' positions -- confirm.
1896 void setupAddrRel(Type T, const ngen::GRFRange &addrDst,
1897 const ngen::GRFRange &addrSrc, const RegisterBlock &blockDst,
1898 const RegisterBlock &blockSrc,
1899 const std::vector<RegisterBlock> &layout,
1900 const ngen::Subregister &ld, const MatrixAddressing &atype,
1901 const MatrixAddressingStrategy &astrategy,
1902 const CommonStrategy &strategy, CommonState &state,
1903 const LDMultiples &ldMultiples = {});
// setupAddr, single-block form, templated on the base pointer type BO. It
// takes the element size as size_t sizeofT rather than a Type, and takes
// ldMultiples BY VALUE; the vector overload below takes it by const
// reference. `params` and `ldMultiples` default to empty objects.
1904 template <typename BO>
1905 void setupAddr(const ngen::GRFRange &addr, const BO &ptr,
1906 const RegisterBlock &layout, const ngen::Subregister &ld,
1907 size_t sizeofT, const MatrixAddressing &atype,
1908 const MatrixAddressingStrategy &astrategy,
1909 const CommonStrategy &strategy, CommonState &state,
1910 const Address2DParams &params = {}, LDMultiples ldMultiples = {});
// setupAddr, whole-layout form: takes Type T and parallel vectors of address
// ranges and blocks.
1911 template <typename BO>
1912 void setupAddr(Type T, const std::vector<ngen::GRFRange> &addr,
1913 const BO &ptr, const std::vector<RegisterBlock> &layout,
1914 const ngen::Subregister &ld, const MatrixAddressing &atype,
1915 const MatrixAddressingStrategy &astrategy,
1916 const CommonStrategy &strategy, CommonState &state,
1917 const Address2DParams &params = {},
1918 const LDMultiples &ldMultiples = {});
// incAddrShifted overloads. The increment types I/Ir/Ic are template
// parameters and are passed by value.
// NOTE(review): presumably `inc` is already shifted/scaled (hence the name),
// and incR/incC are row/column components for 2D addressing -- confirm.
//
// (1) Single block, separate destination and source address ranges.
1919 template <typename I, typename Ir, typename Ic>
1920 void incAddrShifted(const ngen::GRFRange &addrDst,
1921 const ngen::GRFRange &addrSrc, I inc, Ir incR, Ic incC,
1922 const RegisterBlock &layoutDst, const RegisterBlock &layoutSrc,
1923 const MatrixAddressing &atype,
1924 const MatrixAddressingStrategy &astrategy,
1925 const CommonStrategy &strategy, CommonState &state);
// (2) Whole layout; a single `addr` vector serves as input and output.
1926 template <typename I, typename Ir, typename Ic>
1927 void incAddrShifted(const std::vector<ngen::GRFRange> &addr, I inc, Ir incR,
1928 Ic incC, const std::vector<RegisterBlock> &layout,
1929 const MatrixAddressing &atype,
1930 const MatrixAddressingStrategy &astrategy,
1931 const CommonStrategy &strategy, CommonState &state);
// (3) Whole layout, without the separate incR/incC components.
1932 template <typename I>
1933 void incAddrShifted(const std::vector<ngen::GRFRange> &addr, I inc,
1934 const std::vector<RegisterBlock> &layout,
1935 const MatrixAddressing &atype,
1936 const MatrixAddressingStrategy &astrategy,
1937 const CommonStrategy &strategy, CommonState &state);
// incAddr overloads. They mirror incAddrShifted, and additionally provide a
// single-block form without incR/incC.
//
// (1) Single block, with incR/incC.
1938 template <typename I, typename Ir, typename Ic>
1939 void incAddr(const ngen::GRFRange &addrDst, const ngen::GRFRange &addrSrc,
1940 I inc, Ir incR, Ic incC, const RegisterBlock &layoutDst,
1941 const RegisterBlock &layoutSrc, const MatrixAddressing &atype,
1942 const MatrixAddressingStrategy &astrategy,
1943 const CommonStrategy &strategy, CommonState &state);
// (2) Whole layout, with incR/incC.
1944 template <typename I, typename Ir, typename Ic>
1945 void incAddr(const std::vector<ngen::GRFRange> &addr, I inc, Ir incR,
1946 Ic incC, const std::vector<RegisterBlock> &layout,
1947 const MatrixAddressing &atype,
1948 const MatrixAddressingStrategy &astrategy,
1949 const CommonStrategy &strategy, CommonState &state);
// (3) Single block, `inc` only.
1950 template <typename I>
1951 void incAddr(const ngen::GRFRange &addrDst, const ngen::GRFRange &addrSrc,
1952 I inc, const RegisterBlock &layoutDst,
1953 const RegisterBlock &layoutSrc, const MatrixAddressing &atype,
1954 const MatrixAddressingStrategy &astrategy,
1955 const CommonStrategy &strategy, CommonState &state);
// (4) Whole layout, `inc` only.
1956 template <typename I>
1957 void incAddr(const std::vector<ngen::GRFRange> &addr, I inc,
1958 const std::vector<RegisterBlock> &layout,
1959 const MatrixAddressing &atype,
1960 const MatrixAddressingStrategy &astrategy,
1961 const CommonStrategy &strategy, CommonState &state);
// incDecAddr: the address container type A is a template parameter, and a
// trailing `decrement` flag selects the direction. That flag has no default,
// so callers must always pass it.
// NOTE(review): presumably decrement == true subtracts `inc` -- confirm.
1962 template <typename A, typename I, typename Ir, typename Ic>
1963 void incDecAddr(const A &addr, I inc, Ir incR, Ic incC,
1964 const std::vector<RegisterBlock> &layout,
1965 const MatrixAddressing &atype,
1966 const MatrixAddressingStrategy &astrategy,
1967 const CommonStrategy &strategy, CommonState &state, bool decrement);
// Form without the separate incR/incC components.
1968 template <typename A, typename I>
1969 void incDecAddr(const A &addr, I inc,
1970 const std::vector<RegisterBlock> &layout,
1971 const MatrixAddressing &atype,
1972 const MatrixAddressingStrategy &astrategy,
1973 const CommonStrategy &strategy, CommonState &state, bool decrement);
1974
// Index-vector helpers; both take an int n and the mutable common state.
// accessIndexVec returns a subregister.
// NOTE(review): presumably extendIndexVec grows a cached vector of
// consecutive indices kept in `state` to cover n entries, and accessIndexVec
// returns the subregister holding index n -- confirm.
1975 void extendIndexVec(int n, CommonState &state);
1976 ngen::Subregister accessIndexVec(int n, CommonState &state);
1977
// createLDMultiples returns an LDMultiples by value, built from the leading
// dimension subregister `ld`. `a64` and `nmultiples` are plain values.
// NOTE(review): a64 presumably selects 64-bit addressing, and the result
// presumably holds k * ld for k up to nmultiples -- confirm the exact range.
1978 LDMultiples createLDMultiples(bool a64, int nmultiples,
1979 const ngen::Subregister &ld, const CommonStrategy &strategy,
1980 CommonState &state);
// findLDMultiple looks up entry `n` in `multiples` and returns a
// subregister.
// NOTE(review): behavior when n is not present is not visible here.
1981 ngen::Subregister findLDMultiple(const LDMultiples &multiples, bool a64,
1982 int n, const CommonStrategy &strategy, CommonState &state);
1983
// setupCAddr0: C_addr0 and C_addr0Unmasked are references to arrays of
// exactly two GRFRange entries, and may be written. problem and strategy are
// taken by NON-const reference here. `params` is a pointer defaulting to
// nullptr.
// NOTE(review): C_count presumably says how many of the two entries are
// used; a null `params` presumably means "no 2D address parameters".
1984 void setupCAddr0(ngen::GRFRange (&C_addr0)[2],
1985 ngen::GRFRange (&C_addr0Unmasked)[2],
1986 const std::vector<RegisterBlock> &C_layout,
1987 const std::vector<RegisterBlock> &C_layoutUnmasked, int C_count,
1988 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state,
1989 const Address2DParams *params = nullptr);
1990
// Outer-product emitters. All three take the A and B layouts with their
// register multiranges, plus problem/strategy/state.
// NOTE(review): h/ha/hb are presumably k-dimension offsets (overall, within
// the A registers, within the B registers) -- confirm.
//
// Variant whose name refers to Gen9 integer GEMM; it takes no overall `h`.
1991 void outerProductGen9IGEMM(int ha, int hb,
1992 const std::vector<RegisterBlock> &A_layout,
1993 const std::vector<RegisterBlock> &B_layout,
1994 const GRFMultirange &A_regs, const GRFMultirange &B_regs,
1995 const GEMMProblem &problem, const GEMMStrategy &strategy,
1996 GEMMState &state);
// Variant whose name refers to systolic hardware.
1997 void outerProductSystolic(int h, int ha, int hb,
1998 const std::vector<RegisterBlock> &A_layout,
1999 const std::vector<RegisterBlock> &B_layout,
2000 const GRFMultirange &A_regs, const GRFMultirange &B_regs,
2001 const GEMMProblem &problem, const GEMMStrategy &strategy,
2002 GEMMState &state);
// General form; it additionally takes opCount.
// NOTE(review): presumably this dispatches to the specialized variants
// above; the meaning of opCount is not visible here.
2003 void outerProduct(int h, int ha, int hb, int opCount,
2004 const std::vector<RegisterBlock> &A_layout,
2005 const std::vector<RegisterBlock> &B_layout,
2006 const GRFMultirange &A_regs, const GRFMultirange &B_regs,
2007 const GEMMProblem &problem, const GEMMStrategy &strategy,
2008 GEMMState &state);
// Combined setup/teardown entry point, selected by `setup`. Tother is a
// Type.
// NOTE(review): which operand "other" refers to is not visible here.
2009 void setupTeardownAccumulateSumSystolic(bool setup, Type Tother,
2010 const GEMMProblem &problem, const GEMMStrategy &strategy,
2011 GEMMState &state);
2012
// updateC takes three register multiranges (C_acc, C_accSwap, C_load) and
// mutable problem/strategy/state.
2013 void updateC(const GRFMultirange &C_acc, const GRFMultirange &C_accSwap,
2014 const GRFMultirange &C_load, GEMMProblem &problem,
2015 GEMMStrategy &strategy, GEMMState &state);
// updateCLayout: C_addr0 is a reference to a const array of two GRFRange
// entries; `op` is the COperation to perform.
2016 void updateCLayout(const std::vector<RegisterBlock> &layoutExt,
2017 const ngen::GRFRange (&C_addr0)[2], const RegisterBlock &C_block0,
2018 COperation op, GEMMProblem &problem, GEMMStrategy &strategy,
2019 GEMMState &state);
// doStdCRemainder: the `bool x[2]` / `int x[2]` / `StdCRemType x[2]`
// parameters are array-syntax pointer parameters, so the size 2 is not
// enforced by the compiler, unlike the (&C_addr0)[2] references.
// `state` is taken BY VALUE, so changes to it are not seen by the caller,
// while problem and strategy are non-const references.
// NOTE(review): the by-value state looks deliberate (work on a private
// copy); confirm before changing.
2020 bool doStdCRemainder(std::vector<RegisterBlock> &layoutExt,
2021 std::vector<RegisterBlock> &layoutExtUnmasked, bool inside,
2022 bool columns[2], StdCRemType remTypes[2], bool fragments[2],
2023 bool fragPositives[2], int fragSizes[2],
2024 const ngen::GRFRange (&C_addr0)[2],
2025 const ngen::GRFRange (&C_addr0Unmasked)[2], COperation op,
2026 std::vector<MaskAssignment> &masks, GEMMProblem &problem,
2027 GEMMStrategy &strategy, GEMMState state);
// NOTE(review): presumably the fallback used when the standard remainder
// path is not applicable -- confirm at the call sites.
2028 void doAlternateCRemainder(COperation op, GEMMProblem &problem,
2029 GEMMStrategy &strategy, GEMMState &state);
2030
// accumulateSum reads srcRegs/srcLayout (type Tsrc) and targets
// dstRegs/dstLayout (type Tdst); `column` selects the direction. q0 and q1
// default to -1.
// NOTE(review): -1 presumably means "whole range"; confirm.
2031 void accumulateSum(bool column, Type Tsrc, const GRFMultirange &srcRegs,
2032 const std::vector<RegisterBlock> &srcLayout, Type Tdst,
2033 const GRFMultirange &dstRegs,
2034 const std::vector<RegisterBlock> &dstLayout,
2035 const CommonStrategy &strategy, CommonState &state, int q0 = -1,
2036 int q1 = -1);
// makeSumLayout writes dstLayout based on srcLayout.
2037 void makeSumLayout(bool column, Type Tsrc,
2038 const std::vector<RegisterBlock> &srcLayout, Type Tdst,
2039 std::vector<RegisterBlock> &dstLayout,
2040 const CommonStrategy &strategy, CommonState &state);
// horizontalAdd takes `layout` by non-const reference, so it may rewrite the
// layout as well as operate on `regs`. It takes no strategy.
2041 void horizontalAdd(bool column, Type T, const GRFMultirange &regs,
2042 std::vector<RegisterBlock> &layout, CommonState &state);
// NOTE(review): the bool result is presumably success -- confirm.
2043 bool gemmFinalizeSums(const GEMMProblem &problem,
2044 const GEMMStrategy &strategy, GEMMState &state);
2045
// effCoopSplitA/B return a CoopSplit computed from the problem and strategy
// only; they take no state argument.
// NOTE(review): "eff" presumably means the effective cooperative split after
// adjustments for the given problem -- confirm.
2046 CoopSplit effCoopSplitA(
2047 const GEMMProblem &problem, const GEMMStrategy &strategy);
2048 CoopSplit effCoopSplitB(
2049 const GEMMProblem &problem, const GEMMStrategy &strategy);
2050
// convert operates on a register multirange and takes an old and a new
// element type.
// NOTE(review): presumably an in-place Told -> Tnew conversion; how differing
// element sizes are handled is not visible here.
2051 void convert(const GRFMultirange &range, Type Told, Type Tnew,
2052 const GEMMProblem &problem, const GEMMStrategy &strategy,
2053 GEMMState &state);
// gemmConvertC takes only the target type.
// NOTE(review): the current C type presumably comes from problem/state; the
// bool result is presumably success.
2054 bool gemmConvertC(Type Tnew, const GEMMProblem &problem,
2055 const GEMMStrategy &strategy, GEMMState &state);
// gemmBetaScale takes problem and strategy by NON-const reference.
2056 void gemmBetaScale(
2057 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
// binaryOp: one operation of kind `op` at SIMD width `simd` on
// dst/src0/src1. It takes no problem/strategy/state.
2058 void binaryOp(BinaryOp op, int simd, const ngen::RegData &dst,
2059 const ngen::RegData &src0, const ngen::RegData &src1);
// gemmScalarBinaryOpC takes a single subregister operand named `offset`.
// NOTE(review): presumably applied to every element of C.
2060 void gemmScalarBinaryOpC(BinaryOp op, const ngen::Subregister &offset,
2061 const GEMMProblem &problem, const GEMMStrategy &strategy,
2062 GEMMState &state);
// gemmVectorBinaryOpC takes a register multirange of operands (`offsets`)
// and a `scale` subregister; `column` selects the direction.
// Defaults: Tco = Type::invalid, CO_layout = empty vector, y0 = y1 = -1.
// CO_layout is taken BY VALUE, so each call copies the vector.
// NOTE(review): the defaults presumably mean "derive from problem/state" and
// "full range" respectively -- confirm.
2063 void gemmVectorBinaryOpC(BinaryOp op, bool column,
2064 const GRFMultirange &offsets, const ngen::Subregister &scale,
2065 const GEMMProblem &problem, const GEMMStrategy &strategy,
2066 GEMMState &state, Type Tco = Type::invalid,
2067 std::vector<RegisterBlock> CO_layout = std::vector<RegisterBlock>(),
2068 int y0 = -1, int y1 = -1);
// A/B offset and sum helpers. All take const problem/strategy and a mutable
// GEMM state. Only gemmLoadABOffset returns bool.
// NOTE(review): the bool is presumably success. The names suggest a
// calc-addresses -> load -> apply sequence; confirm the required call order
// at the call sites.
2069 void gemmCalcABOffsetAddrs(const GEMMProblem &problem,
2070 const GEMMStrategy &strategy, GEMMState &state);
2071 bool gemmLoadABOffset(const GEMMProblem &problem,
2072 const GEMMStrategy &strategy, GEMMState &state);
2073 void gemmApplyABOffset(const GEMMProblem &problem,
2074 const GEMMStrategy &strategy, GEMMState &state);
2075 void gemmUpdateSums(const GEMMProblem &problem,
2076 const GEMMStrategy &strategy, GEMMState &state);
// gemmBinaryOpC: `row` and `column` are independent flags. The operand
// matrix is described by CO/CO_strategy (both taken by value) with its
// `base` and `ld` subregisters (also by value).
// NOTE(review): the bool result is presumably success.
2077 bool gemmBinaryOpC(BinaryOp op, bool row, bool column, Type Tco,
2078 MatrixAddressing CO, MatrixAddressingStrategy CO_strategy,
2079 ngen::Subregister base, ngen::Subregister ld,
2080 const GEMMProblem &problem, const GEMMStrategy &strategy,
2081 GEMMState &state);
2082 bool gemmApplyCOffsetDispatch(const GEMMProblem &problem,
2083 const GEMMStrategy &strategy, GEMMState &state);
// NOTE(review): presumably reduces partial results across the k dimension --
// confirm in the definition.
2084 void gemmKReduce(const GEMMProblem &problem, const GEMMStrategy &strategy,
2085 GEMMState &state);
// gemmPrefetchC takes the strategy by NON-const reference, unlike its
// neighbors.
2086 void gemmPrefetchC(const GEMMProblem &problem, GEMMStrategy &strategy,
2087 GEMMState &state);
2088
// gemmApplyPostOps takes an index range poMin/poMax and a mutable strategy.
// NOTE(review): whether poMax is inclusive or exclusive is not visible here.
2089 void gemmApplyPostOps(int poMin, int poMax, const GEMMProblem &problem,
2090 GEMMStrategy &strategy, GEMMState &state);
2091 void gemmLoadBinaryOpArgs(const GEMMProblem &problem,
2092 const GEMMStrategy &strategy, GEMMState &state);
2093
// Register allocation entry points. gemmAllocRegs takes problem and strategy
// by NON-const reference; gemmAllocAoBoRegs needs only the strategy
// (read-only) and the state.
// NOTE(review): "Ao"/"Bo" presumably name secondary A/B register sets;
// confirm the naming convention in the definition.
2094 void gemmAllocRegs(
2095 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
2096 void gemmAllocAoBoRegs(const GEMMStrategy &strategy, GEMMState &state);
// gemmAIncrementInternal: three overloads that differ only in the type of
// ka_inc: a compile-time int, a MultishiftSubregister, or a plain
// Subregister. `ha` defaults to 0 in all three.
// NOTE(review): presumably advances the A addresses by ka_inc along k.
2097 void gemmAIncrementInternal(Type Ta,
2098 const std::vector<RegisterBlock> &layout,
2099 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A,
2100 const MatrixAddressingStrategy &A_strategy, int ka_inc,
2101 const GEMMProblem &problem, const GEMMStrategy &strategy,
2102 GEMMState &state, int ha = 0);
2103 void gemmAIncrementInternal(Type Ta,
2104 const std::vector<RegisterBlock> &layout,
2105 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A,
2106 const MatrixAddressingStrategy &A_strategy,
2107 const MultishiftSubregister &ka_inc, const GEMMProblem &problem,
2108 const GEMMStrategy &strategy, GEMMState &state, int ha = 0);
2109 void gemmAIncrementInternal(Type Ta,
2110 const std::vector<RegisterBlock> &layout,
2111 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A,
2112 const MatrixAddressingStrategy &A_strategy,
2113 const ngen::Subregister &ka_inc, const GEMMProblem &problem,
2114 const GEMMStrategy &strategy, GEMMState &state, int ha = 0);
// gemmAIncrement: generic front end, templated on the increment type I.
// NOTE(review): presumably forwards to the matching *Internal overload.
2115 template <typename I>
2116 void gemmAIncrement(Type Ta, const std::vector<RegisterBlock> &layout,
2117 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A,
2118 const MatrixAddressingStrategy &A_strategy, I ka_inc,
2119 const GEMMProblem &problem, const GEMMStrategy &strategy,
2120 GEMMState &state, int ha = 0);
// gemmALoad: no element type and no increment.
2121 void gemmALoad(const GRFMultirange &regs,
2122 const std::vector<RegisterBlock> &layout,
2123 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A,
2124 const MatrixAddressingStrategy &A_strategy,
2125 const GEMMProblem &problem, const GEMMStrategy &strategy,
2126 GEMMState &state);
// gemmALoadInc: takes the union of gemmALoad's and gemmAIncrement's inputs,
// except that it has no `ha` parameter.
// NOTE(review): presumably loads and then increments -- confirm the order.
2127 template <typename I>
2128 void gemmALoadInc(Type Ta, const GRFMultirange &regs,
2129 const std::vector<RegisterBlock> &layout,
2130 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A,
2131 const MatrixAddressingStrategy &A_strategy, I ka_inc,
2132 const GEMMProblem &problem, const GEMMStrategy &strategy,
2133 GEMMState &state);
// B-matrix counterparts of the gemmA* declarations, with the same shapes:
// three gemmBIncrementInternal overloads that differ only in the type of
// kb_inc (int, MultishiftSubregister, Subregister), each with hb defaulting
// to 0.
2134 void gemmBIncrementInternal(Type Tb,
2135 const std::vector<RegisterBlock> &layout,
2136 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
2137 const MatrixAddressingStrategy &B_strategy, int kb_inc,
2138 const GEMMProblem &problem, const GEMMStrategy &strategy,
2139 GEMMState &state, int hb = 0);
2140 void gemmBIncrementInternal(Type Tb,
2141 const std::vector<RegisterBlock> &layout,
2142 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
2143 const MatrixAddressingStrategy &B_strategy,
2144 const MultishiftSubregister &kb_inc, const GEMMProblem &problem,
2145 const GEMMStrategy &strategy, GEMMState &state, int hb = 0);
2146 void gemmBIncrementInternal(Type Tb,
2147 const std::vector<RegisterBlock> &layout,
2148 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
2149 const MatrixAddressingStrategy &B_strategy,
2150 const ngen::Subregister &kb_inc, const GEMMProblem &problem,
2151 const GEMMStrategy &strategy, GEMMState &state, int hb = 0);
// gemmBIncrement: generic front end, templated on the increment type I.
// NOTE(review): presumably forwards to the matching *Internal overload.
2152 template <typename I>
2153 void gemmBIncrement(Type Tb, const std::vector<RegisterBlock> &layout,
2154 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
2155 const MatrixAddressingStrategy &B_strategy, I kb_inc,
2156 const GEMMProblem &problem, const GEMMStrategy &strategy,
2157 GEMMState &state, int hb = 0);
// gemmBLoad: no element type and no increment.
2158 void gemmBLoad(const GRFMultirange &regs,
2159 const std::vector<RegisterBlock> &layout,
2160 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
2161 const MatrixAddressingStrategy &B_strategy,
2162 const GEMMProblem &problem, const GEMMStrategy &strategy,
2163 GEMMState &state);
// gemmBLoadInc: load plus increment inputs; it has no `hb` parameter.
2164 template <typename I>
2165 void gemmBLoadInc(Type Tb, const GRFMultirange &regs,
2166 const std::vector<RegisterBlock> &layout,
2167 const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
2168 const MatrixAddressingStrategy &B_strategy, I kb_inc,
2169 const GEMMProblem &problem, const GEMMStrategy &strategy,
2170 GEMMState &state);
// gemmAiBiRemLoadInc is templated on a compile-time bool doA; the "X" in the
// parameter names stands for whichever matrix is being processed.
// NOTE(review): presumably doA == true means A and false means B -- confirm.
// Xi_layoutK and Xi_addrsK are vectors of vectors.
// NOTE(review): the K variants are presumably per-k-slice versions of
// Xi_layout/Xi_addrs, used when `incremental` is set; confirm. The exact
// roles of incrementalCopy, keepAddrTogether and willRemask are not visible
// from this declaration.
2171 template <bool doA>
2172 void gemmAiBiRemLoadInc(bool incremental, bool incrementalCopy,
2173 bool keepAddrTogether, bool willRemask,
2174 const ngen::Subregister &kSLMX, const GRFMultirange &Xi_regs,
2175 const std::vector<RegisterBlock> &Xi_layout,
2176 const std::vector<ngen::GRFRange> &Xi_addrs,
2177 const std::vector<std::vector<RegisterBlock>> &Xi_layoutK,
2178 const std::vector<std::vector<ngen::GRFRange>> &Xi_addrsK,
2179 const GRFMultirange &Xo_regs,
2180 const std::vector<RegisterBlock> &Xo_layout,
2181 const MatrixAddressing &Xi,
2182 const MatrixAddressingStrategy &Xi_strategy,
2183 const GEMMProblem &problem, const GEMMStrategy &strategy,
2184 GEMMState &state);
// allocIncrement returns a SubregisterPair. Note that it takes a CommonState
// together with a GEMMStrategy.
2185 SubregisterPair allocIncrement(
2186 const GEMMStrategy &strategy, CommonState &state);
// gemmCalcIncrements defaults: ka_load = 0, kb_load = 0, doA = true,
// doB = true.
// NOTE(review): 0 presumably means "use the value from the strategy".
2187 void gemmCalcIncrements(const GEMMProblem &problem,
2188 const GEMMStrategy &strategy, GEMMState &state, int ka_load = 0,
2189 int kb_load = 0, bool doA = true, bool doB = true);
// Workshare offset calculators for A (ma x ka) and B (kb x nb). off, offR
// and offC are non-const Subregister references, i.e. outputs.
// NOTE(review): whether the callee allocates these or expects them to be
// pre-allocated is not visible here.
2190 void gemmCalcWorkshareAOffset(ngen::Subregister &off,
2191 ngen::Subregister &offR, ngen::Subregister &offC,
2192 const MatrixAddressing &A,
2193 const MatrixAddressingStrategy &A_strategy, int ma, int ka,
2194 const GEMMProblem &problem, const GEMMStrategy &strategy,
2195 GEMMState &state);
2196 void gemmCalcWorkshareBOffset(ngen::Subregister &off,
2197 ngen::Subregister &offR, ngen::Subregister &offC,
2198 const MatrixAddressing &B,
2199 const MatrixAddressingStrategy &B_strategy, int kb, int nb,
2200 const GEMMProblem &problem, const GEMMStrategy &strategy,
2201 GEMMState &state);
// gemmPrepMaskedAB takes the strategy by NON-const reference and returns
// bool.
// NOTE(review): the meaning of the result is not visible here.
2202 bool gemmPrepMaskedAB(const GEMMProblem &problem, GEMMStrategy &strategy,
2203 GEMMState &state);
// gemmSLMRemask: remaskA/remaskB select which operands to process. Ao_regs
// and Bo_regs are NON-const GRFMultirange references.
2204 void gemmSLMRemask(bool remaskA, bool remaskB, GRFMultirange &Ao_regs,
2205 GRFMultirange &Bo_regs, int kOffset, const GEMMProblem &problem,
2206 const GEMMStrategy &strategy, GEMMState &state);
2207
// gemmCalcKLoopBarrierCount: `count` is a non-const reference (output) and
// `k` is an input subregister.
// NOTE(review): the meaning of `cooldown` is not visible here.
2208 void gemmCalcKLoopBarrierCount(ngen::Subregister &count,
2209 const ngen::Subregister &k, int cooldown,
2210 const GEMMProblem &problem, const GEMMStrategy &strategy,
2211 GEMMState &state);
// gemmCalcKSLM: kSLM and lid are const subregister references; kgran, kdiv
// and krep are plain ints.
// NOTE(review): kSLM is presumably the destination despite the const
// reference (the register it names is written by emitted code); lid is
// presumably a local ID -- confirm.
2212 void gemmCalcKSLM(const ngen::Subregister &kSLM,
2213 const ngen::Subregister &lid, int kgran, int kdiv, int krep,
2214 const GEMMProblem &problem, const GEMMStrategy &strategy,
2215 GEMMState &state);
// Barrier header helpers. The getter returns an ngen::GRF and takes a
// mutable state.
// NOTE(review): the getter may allocate lazily -- confirm.
2216 void kLoopAllocBarrierHeader(GEMMState &state);
2217 ngen::GRF kLoopGetBarrierHeader(GEMMState &state);
// kLoop takes a KLoop `type` and a NON-const strategy.
2218 void kLoop(KLoop type, const GEMMProblem &problem, GEMMStrategy &strategy,
2219 GEMMState &state);
// kLoopSetup returns bool; kLoopTeardown returns void. Both take a const
// strategy.
// NOTE(review): the names suggest the two are paired -- confirm at the call
// sites.
2220 bool kLoopSetup(const GEMMProblem &problem, const GEMMStrategy &strategy,
2221 GEMMState &state);
// kLoopReset is templated on the type I of kOffset, taken by const
// reference.
2222 template <typename I>
2223 void kLoopReset(const I &kOffset, const GEMMProblem &problem,
2224 const GEMMStrategy &strategy, GEMMState &state);
2225 void kLoopTeardown(const GEMMProblem &problem, const GEMMStrategy &strategy,
2226 GEMMState &state);
// kLoopSingle has the same inputs as kLoop but returns bool.
// NOTE(review): presumably wraps setup + kLoop + teardown for one loop
// type -- confirm.
2227 bool kLoopSingle(KLoop type, const GEMMProblem &problem,
2228 GEMMStrategy &strategy, GEMMState &state);
// Higher-level stages. All take problem, strategy and state by NON-const
// reference, and all return bool except gemmAccumulateCTeardown.
// NOTE(review): the bool is presumably success/failure of code generation
// for that stage -- confirm what callers do on false.
2229 bool gemmKLoop(
2230 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
2231 bool gemmAccumulateC(
2232 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
2233 bool gemmAccumulateCSetup(
2234 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
2235 void gemmAccumulateCTeardown(
2236 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
// gemmAccessC additionally takes the COperation to perform.
2237 bool gemmAccessC(COperation op, GEMMProblem &problem,
2238 GEMMStrategy &strategy, GEMMState &state);
2239 bool gemmUpdateC(
2240 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
2241
// gemmBody takes problem, strategy and state BY VALUE, so it works on
// private copies and the caller's objects are unchanged. This signature
// matches the member-function-pointer type bool (Problem, GEMMStrategy,
// GEMMState) when Problem is GEMMProblem.
2242 bool gemmBody(GEMMProblem problem, GEMMStrategy strategy, GEMMState state);
// gemmBodyInternal takes the same three by non-const reference.
// NOTE(review): presumably called by gemmBody on its copies -- confirm.
2243 bool gemmBodyInternal(
2244 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
2245
// wgRemCheck depends on the problem and strategy only (no state).
// NOTE(review): "wg" presumably means workgroup; the meaning of the result
// is not visible here.
2246 bool wgRemCheck(const GEMMProblem &problem, const GEMMStrategy &strategy);
// mnRemainderHandling: `func` is a pointer to a member function of this
// generator with signature bool (Problem, GEMMStrategy, GEMMState), i.e. all
// three arguments by value. gemmBody and gemmNEdge have that signature for
// Problem = GEMMProblem. `loop` selects the dimension.
2247 template <typename Problem>
2248 bool mnRemainderHandling(LoopType loop, Problem &problem,
2249 GEMMStrategy &strategy, GEMMState &state,
2250 bool (gemm_kernel_generator_t<hw>::*func)(
2251 Problem, GEMMStrategy, GEMMState));
// Joint-split variant: same callback type, no LoopType argument.
2252 template <typename Problem>
2253 bool mnJointSplitRemainderHandling(Problem &problem, GEMMStrategy &strategy,
2254 GEMMState &state,
2255 bool (gemm_kernel_generator_t<hw>::*func)(
2256 Problem, GEMMStrategy, GEMMState));
// gemmMEdge takes its arguments by reference.
2257 bool gemmMEdge(
2258 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
// gemmNEdge takes its arguments by value, which matches the callback
// signature above.
2259 bool gemmNEdge(GEMMProblem problem, GEMMStrategy strategy, GEMMState state);
2260
// Ordering helpers. The first two take the strategy by NON-const reference;
// gemmReorderLocalIDs takes it by const reference.
// NOTE(review): the names suggest workgroup traversal orders (Hilbert-like
// curve, boustrophedon = alternating direction); confirm what exactly is
// reordered in the definitions.
2261 void gemmHilbertlikeOrder(const GEMMProblem &problem,
2262 GEMMStrategy &strategy, GEMMState &state);
2263 void gemmBoustrophedonOrder(const GEMMProblem &problem,
2264 GEMMStrategy &strategy, GEMMState &state);
2265 void gemmReorderLocalIDs(const GEMMProblem &problem,
2266 const GEMMStrategy &strategy, GEMMState &state);
2267
// gemmCheck32 takes the strategy by NON-const reference.
// NOTE(review): presumably checks whether 32-bit arithmetic or addressing
// suffices -- confirm.
2268 void gemmCheck32(const GEMMProblem &problem, GEMMStrategy &strategy,
2269 GEMMState &state);
// Batch ID get/release; both have identical parameter lists.
// NOTE(review): the names suggest an acquire/release pair -- confirm.
2270 void gemmGetBatchIDs(const GEMMProblem &problem,
2271 const GEMMStrategy &strategy, GEMMState &state);
2272 void gemmReleaseBatchIDs(const GEMMProblem &problem,
2273 const GEMMStrategy &strategy, GEMMState &state);
// gemmOffsetAk / gemmOffsetBk: each has two overloads, one with `h` as a
// compile-time int and one with `h` in a subregister. effA/effB is the
// subregister being offset, and globalA/globalB describes the matrix.
// NOTE(review): presumably advances the effective A/B pointer by h along the
// k dimension -- confirm.
2274 void gemmOffsetAk(int h, const ngen::Subregister &effA,
2275 const MatrixAddressing &globalA, const GEMMProblem &problem,
2276 const GEMMStrategy &strategy, GEMMState &state);
2277 void gemmOffsetAk(const ngen::Subregister &h, const ngen::Subregister &effA,
2278 const MatrixAddressing &globalA, const GEMMProblem &problem,
2279 const GEMMStrategy &strategy, GEMMState &state);
2280 void gemmOffsetBk(int h, const ngen::Subregister &effB,
2281 const MatrixAddressing &globalB, const GEMMProblem &problem,
2282 const GEMMStrategy &strategy, GEMMState &state);
2283 void gemmOffsetBk(const ngen::Subregister &h, const ngen::Subregister &effB,
2284 const MatrixAddressing &globalB, const GEMMProblem &problem,
2285 const GEMMStrategy &strategy, GEMMState &state);
// Fold/restore offsets; both have identical parameter lists.
// NOTE(review): the names suggest restore undoes fold -- confirm.
2286 void gemmFoldOffsets(const GEMMProblem &problem,
2287 const GEMMStrategy &strategy, GEMMState &state);
2288 void gemmRestoreOffsets(const GEMMProblem &problem,
2289 const GEMMStrategy &strategy, GEMMState &state);
// gemmOffsetABC: i0/j0/h0 are subregisters passed by value. doA, doB and doC
// default to true; doBinary defaults to false.
// NOTE(review): `initial` presumably distinguishes the first application
// from later adjustments -- confirm.
2290 void gemmOffsetABC(bool initial, ngen::Subregister i0, ngen::Subregister j0,
2291 ngen::Subregister h0, const GEMMProblem &problem,
2292 const GEMMStrategy &strategy, GEMMState &state, bool doA = true,
2293 bool doB = true, bool doC = true, bool doBinary = false);
2294 void gemmOffsetBatchABC(const GEMMProblem &problem,
2295 const GEMMStrategy &strategy, GEMMState &state);
2296 void gemmSetupABC(const GEMMProblem &problem, const GEMMStrategy &strategy,
2297 GEMMState &state);
// Leading-dimension multiple caches. The C variant has an extra `prefetch`
// flag that defaults to false.
2298 void gemmCacheLDABMultiples(const GEMMProblem &problem,
2299 const GEMMStrategy &strategy, GEMMState &state);
2300 void gemmCacheLDCMultiples(const GEMMProblem &problem,
2301 const GEMMStrategy &strategy, GEMMState &state,
2302 bool prefetch = false);
// NOTE(review): behavior of these helpers is inferred from their names only;
// the bodies are not visible here.
2303 void gemmScaleInputs(const GEMMProblem &problem,
2304 const GEMMStrategy &strategy, GEMMState &state);
2305 void gemmReverseLoops(const GEMMProblem &problem,
2306 const GEMMStrategy &strategy, GEMMState &state);
// gemmDowngradeAccess takes the strategy by NON-const reference.
2307 void gemmDowngradeAccess(const GEMMProblem &problem, GEMMStrategy &strategy,
2308 GEMMState &state);
// gemmSubkernel takes problem and strategy by reference, but `state` BY
// VALUE, so it works on a private copy of the state.
2309 void gemmSubkernel(
2310 GEMMProblem &problem, GEMMStrategy &strategy, GEMMState state);
// Static size queries. They depend on the problem and strategy only and
// return size_t.
// NOTE(review): units are presumably bytes -- confirm.
2311 static size_t gemmSLMSize(
2312 const GEMMProblem &problem, const GEMMStrategy &strategy);
2313 static size_t gemmPerKSLMSize(
2314 const GEMMProblem &problem, const GEMMStrategy &strategy);
// Initialization. inSK defaults to false in both.
// NOTE(review): inSK presumably means "inside a superkernel" -- confirm.
2315 void gemmInitInterface(GEMMProblem &problem, GEMMStrategy &strategy,
2316 GEMMState &state, bool inSK = false);
2317 void gemmInitState(GEMMProblem &problem, GEMMStrategy &strategy,
2318 GEMMState &state, bool inSK = false);
// gemm takes problem, strategy and state by non-const reference.
// NOTE(review): presumably the top-level GEMM kernel generation entry.
2319 void gemm(GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
2320
// Superkernel counterpart of gemmInitState; it uses the
// GEMMSuperkernel{Problem,Strategy,State} types and takes all three by
// non-const reference.
2321 void gemmSuperkernelInitState(GEMMSuperkernelProblem &problem,
2322 GEMMSuperkernelStrategy &strategy, GEMMSuperkernelState &state);
2323
// "sysgemm" family (declarations only).
// sysgemmAccumulateC is the only routine in this family that returns a value
// and that takes the problem by non-const reference.
// NOTE(review): the meaning of the bool result is not visible here
// (presumably success/failure) -- confirm at the definition.
2324 bool sysgemmAccumulateC(GEMMProblem &problem, const GEMMStrategy &strategy,
2325 GEMMState &state);
2326 void sysgemmKLoop(const GEMMProblem &problem, const GEMMStrategy &strategy,
2327 GEMMState &state);
// Variant of the k-loop declaration with an extra required `oddB` flag.
2328 void sysgemmKLoop4(const GEMMProblem &problem, const GEMMStrategy &strategy,
2329 GEMMState &state, bool oddB);
// `forceFence` is optional and defaults to false.
2330 void sysgemmStoreSignal(const GEMMProblem &problem,
2331 const GEMMStrategy &strategy, GEMMState &state,
2332 bool forceFence = false);
// Copy load/store declarations. Note the type difference between variants:
// `useC` is a bool (default false) in sysgemmCopyLoad, but an int (default 0)
// in the "4" variants below.
2333 void sysgemmCopyLoad(const GEMMProblem &problem,
2334 const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
2335 bool useC = false);
// `flagLoadB` defaults to a default-constructed ngen::RegData.
2336 void sysgemmCopyLoad4(const GEMMProblem &problem,
2337 const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
2338 bool loadB, int useC = 0,
2339 ngen::RegData flagLoadB = ngen::RegData());
2340 void sysgemmCopyStore(const GEMMProblem &problem,
2341 const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
2342 bool first = false);
2343 void sysgemmCopyStore4(const GEMMProblem &problem,
2344 const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
2345 bool storeB, int useC = 0, int useC_B = 0);
// Multiply declarations. `buffer` is required; `lastMultiply` defaults to
// false.
2346 void sysgemmMultiply(const GEMMProblem &problem,
2347 const GEMMStrategy &strategy, GEMMState &state, int buffer,
2348 bool lastMultiply = false);
// All arguments after `buffer` are optional: the two flag operands default to
// default-constructed ngen::RegData objects and `labelDone` defaults to
// nullptr.
2349 void sysgemmMultiply4(const GEMMProblem &problem,
2350 const GEMMStrategy &strategy, GEMMState &state, int buffer,
2351 bool firstMultiply = false,
2352 ngen::RegData flagWaitLoad = ngen::RegData(),
2353 ngen::RegData flagSignal = ngen::RegData(),
2354 ngen::Label *labelDone = nullptr);
// Takes no GEMMState argument. Both instruction modifiers are optional and
// default to a default-constructed ngen::InstructionModifier.
2355 void sysgemmMultiplyChunk(const GEMMProblem &problem,
2356 const GEMMStrategy &strategy, bool first, int ao, int i0,
2357 bool waitB, bool prepB,
2358 const ngen::InstructionModifier &swsb0
2359 = ngen::InstructionModifier(),
2360 const ngen::InstructionModifier &swsbEnd
2361 = ngen::InstructionModifier());
// Takes only an instruction modifier and a GRF; no problem/strategy/state.
2362 void sysgemmBarrierPrep(
2363 const ngen::InstructionModifier &swsb, const ngen::GRF &header);
2364 void sysgemmReorderLocalIDs(const GEMMProblem &problem,
2365 const GEMMStrategy &strategy, GEMMState &state);
2366
// "sysgemm2" family (declarations only).
// As with sysgemmAccumulateC, this routine returns bool and takes the problem
// by non-const reference.
// NOTE(review): meaning of the bool result is not visible here -- confirm.
2367 bool sysgemm2AccumulateC(GEMMProblem &problem, const GEMMStrategy &strategy,
2368 GEMMState &state);
// Two separate k-loop declarations with identical signatures.
// NOTE(review): presumably one per role (compute vs. copy) -- confirm.
2369 void sysgemm2KLoopCompute(const GEMMProblem &problem,
2370 const GEMMStrategy &strategy, GEMMState &state);
2371 void sysgemm2KLoopCopy(const GEMMProblem &problem,
2372 const GEMMStrategy &strategy, GEMMState &state);
// The three multiply declarations below share one signature: `buffer` is
// required, `cooldown` defaults to false, and both flag registers default to a
// default-constructed ngen::FlagRegister.
// NOTE(review): presumably sysgemm2Multiply selects between the X32 and X48
// variants -- confirm at the definition.
2373 void sysgemm2Multiply(const GEMMProblem &problem,
2374 const GEMMStrategy &strategy, GEMMState &state, int buffer,
2375 bool cooldown = false,
2376 ngen::FlagRegister flagWaitLoad = ngen::FlagRegister(),
2377 ngen::FlagRegister flagSignal = ngen::FlagRegister());
2378 void sysgemm2MultiplyX32(const GEMMProblem &problem,
2379 const GEMMStrategy &strategy, GEMMState &state, int buffer,
2380 bool cooldown = false,
2381 ngen::FlagRegister flagWaitLoad = ngen::FlagRegister(),
2382 ngen::FlagRegister flagSignal = ngen::FlagRegister());
2383 void sysgemm2MultiplyX48(const GEMMProblem &problem,
2384 const GEMMStrategy &strategy, GEMMState &state, int buffer,
2385 bool cooldown = false,
2386 ngen::FlagRegister flagWaitLoad = ngen::FlagRegister(),
2387 ngen::FlagRegister flagSignal = ngen::FlagRegister());
// Chunk-level declarations: no GEMMState argument. The X32 form takes an
// additional `odd` flag that the X48 form does not have.
2388 void sysgemm2MultiplyChunkX32(const GEMMProblem &problem,
2389 const GEMMStrategy &strategy, int chunkA, bool odd);
2390 void sysgemm2MultiplyChunkX48(const GEMMProblem &problem,
2391 const GEMMStrategy &strategy, int chunkA);
2392
// Register copy declarations. All three return bool and work with the Common*
// strategy/state types rather than the GEMM-specific ones.
// NOTE(review): the bool presumably reports success -- confirm.
//
// Single-block form: one source and one destination RegisterBlock, with source
// type Ts and destination type Td. `preserveSrc` defaults to false.
2393 bool copyRegisterBlock(Type Ts, Type Td, const RegisterBlock &blockSrc,
2394 const RegisterBlock &blockDst, const GRFMultirange &src,
2395 const GRFMultirange &dst, int dOffR, int dOffC,
2396 const CommonStrategy &strategy, CommonState &state,
2397 bool preserveSrc = false);
// Layout form: source and destination are vectors of RegisterBlock, and a
// required `conjugate` flag is added.
2398 bool copyRegisters(Type Ts, Type Td,
2399 const std::vector<RegisterBlock> &layoutSrc,
2400 const std::vector<RegisterBlock> &layoutDst,
2401 const GRFMultirange &src, const GRFMultirange &dst, int dOffR,
2402 int dOffC, bool conjugate, const CommonStrategy &strategy,
2403 CommonState &state, bool preserveSrc = false);
// Overload of the layout form that additionally takes alpha_real/alpha_imag
// (Scalar<double>) ahead of `conjugate`.
2404 bool copyRegisters(Type Ts, Type Td,
2405 const std::vector<RegisterBlock> &layoutSrc,
2406 const std::vector<RegisterBlock> &layoutDst,
2407 const GRFMultirange &src, const GRFMultirange &dst, int dOffR,
2408 int dOffC, const Scalar<double> &alpha_real,
2409 const Scalar<double> &alpha_imag, bool conjugate,
2410 const CommonStrategy &strategy, CommonState &state,
2411 bool preserveSrc = false);
2412
// Copy-kernel body declarations. All take the copy problem, strategy and state
// by mutable reference. The three copyBody* routines return bool; copySlice
// returns nothing.
// NOTE(review): the bool presumably reports whether generation succeeded --
// confirm at the definitions.
2413 bool copyBody(
2414 CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
2415 bool copyBodyRemCheck(
2416 CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
2417 bool copyBodyInternal(
2418 CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
2419 void copySlice(
2420 CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
2421
// Problem and strategy are read-only here; only the state is mutable.
// `s_load` and `d_load` are optional and both default to 0.
2422 void copyCalcIncrements(const CopyProblem &problem,
2423 const CopyStrategy &strategy, CopyState &state, int s_load = 0,
2424 int d_load = 0);
2425
// Copy-kernel initialization and entry-point declarations. They mirror the
// gemmInitInterface / gemmInitState / gemm trio declared earlier in this
// class, but have no `inSK` parameter.
2426 void copyInitInterface(
2427 CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
2428 void copyInitState(
2429 CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
2430 void copy(CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
2431
// Kernel-level helpers expressed in terms of the Common* types.
// prologue() needs only the strategy; epilogue() also takes the state, by
// const reference, so it cannot modify it.
2432 void prologue(const CommonStrategy &strategy);
2433 void epilogue(const CommonStrategy &strategy, const CommonState &state);
// Takes no arguments.
2434 void padding();
// Here the state is a mutable reference; problem and strategy are read-only.
2435 void initState(const CommonProblem &problem, const CommonStrategy &strategy,
2436 CommonState &state);
2437};
2438
2439inline char precisionChar(Type T) {
2440 switch (T.baseType()) {
2441 case Type::f16: return 'H';
2442 case Type::f32: return 'S';
2443 case Type::u8: return 'o';
2444 case Type::s8: return 'O';
2445 case Type::u16: return 'w';
2446 case Type::s16: return 'W';
2447 case Type::u32: return 'i';
2448 case Type::s32: return 'I';
2449 case Type::u64: return 'l';
2450 case Type::s64: return 'L';
2451 case Type::bf16: return 'B';
2452 case Type::tf32: return 'T';
2453 default: return '?';
2454 }
2455}
2456
2457static inline Type charPrecision(char c) {
2458 switch (c) {
2459 case 'H': return Type::f16;
2460 case 'S': return Type::f32;
2461 case 'o': return Type::u8;
2462 case 'O': return Type::s8;
2463 case 'w': return Type::u16;
2464 case 'W': return Type::s16;
2465 case 'i': return Type::u32;
2466 case 'I': return Type::s32;
2467 case 'B': return Type::bf16;
2468 case 'T': return Type::tf32;
2469 default: return Type::invalid;
2470 }
2471}
2472
2473inline char layoutChar(MatrixLayout layout) {
2474 switch (layout) {
2475 case MatrixLayout::N: return 'N';
2476 case MatrixLayout::T: return 'T';
2477 case MatrixLayout::Pc: return 'A';
2478 case MatrixLayout::Pr: return 'B';
2479 default: return '?';
2480 }
2481}
2482
2483} // namespace jit
2484} // namespace gpu
2485} // namespace impl
2486} // namespace dnnl
2487
2488#endif /* header guard */
2489