1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_GEMM_GEN_GEMM_KERNEL_GENERATOR_HPP |
18 | #define GPU_JIT_GEMM_GEN_GEMM_KERNEL_GENERATOR_HPP |
19 | |
20 | /* Embargo support */ |
21 | |
22 | #define STANDALONE 0 |
23 | |
24 | #include "common/math_utils.hpp" |
25 | #include "common/utils.hpp" |
26 | #include "gpu/jit/gemm/gen_gemm_kernel_common.hpp" |
27 | #include "gpu/jit/gemm/utils.hpp" |
28 | #include "gpu/jit/jit_generator.hpp" |
29 | #include "gpu/jit/jit_post_op_injector.hpp" |
30 | |
31 | #if defined(ZEBIN_OUTPUT) |
32 | #include "../ngen/ngen_elf.hpp" |
33 | #else |
34 | #include "../ngen/ngen_opencl.hpp" |
35 | |
36 | #endif |
37 | #include "../ngen/ngen_register_allocator.hpp" |
38 | |
39 | #include "gpu/jit/gemm/emulation.hpp" |
40 | |
41 | #include <array> |
42 | #include <complex> |
43 | #include <cstdint> |
44 | #include <exception> |
45 | #include <iostream> |
46 | #include <sstream> |
47 | #include <vector> |
48 | |
49 | namespace dnnl { |
50 | namespace impl { |
51 | namespace gpu { |
52 | namespace jit { |
53 | |
54 | struct RegisterBlock; |
55 | |
// Encoded data type descriptor for kernel generation.
// Each enumerator value packs several fields, decoded by the accessors below:
//   bits  0-7 : log2(element size in bytes)          -- log2Size()
//   bits  8-15: element size in bytes                -- size()
//   bits 16-19: index into the nGEN DataType table   -- ngen()
//   bit  23   : set for integer types, clear for FP  -- isInteger()
class Type {
public:
    enum _Type : uint32_t {
        invalid = 0,
        f16 = 0x01000201,
        f32 = 0x01010402,
        u8 = 0x01840100,
        s8 = 0x01850100,
        u16 = 0x01860201,
        s16 = 0x01870201,
        u32 = 0x01880402,
        s32 = 0x01890402,
        u64 = 0x018A0803,
        s64 = 0x018B0803,
        bf16 = 0x010C0201,
        tf32 = 0x010D0402,
    };

private:
    _Type val;

public:
    constexpr Type() : Type(f32) {} // Default-constructed Type is f32.
    constexpr Type(_Type val_) : val(val_) {}
    constexpr operator _Type() const { return val; }

    // Complex types are not supported here, so the real/complex interface
    // is trivial: every type is its own real part with one component.
    constexpr Type real() const { return *this; }
    constexpr bool isComplex() const { return false; }
    constexpr int complexComponents() const { return 1; }
    constexpr int components() const { return 1; }
    constexpr bool isInteger() const { return uint32_t(val) & 0x800000; }
    constexpr bool isFP() const { return !isInteger(); }
    // Signed if floating-point (bit 23 clear), or an integer whose nGEN table
    // index is odd (bit 16 set: s8/s16/s32/s64).
    constexpr bool isSigned() const {
        return (uint32_t(val) & 0x810000) != 0x800000;
    }
    constexpr int log2Size() const { return uint32_t(val) & 0xFF; }
    constexpr int size() const { return (uint32_t(val) >> 8) & 0xFF; }

    // Type used for arithmetic: tf32 computes in f32; all others unchanged.
    constexpr Type arithmetic() const {
        return (val == tf32) ? Type(f32) : real();
    }
    // Map to the corresponding oneDNN data type. Asserts (and returns undef)
    // for types without a direct mapping here (bf16/tf32/16- and 64-bit ints).
    data_type_t get_dnnl_type() const {
        switch (val) {
            case Type::f32: return data_type::f32;
            case Type::f16: return data_type::f16;
            case Type::s32: return data_type::s32;
            case Type::u8: return data_type::u8;
            case Type::s8: return data_type::s8;
            default: assert(!"Unsupported type" ); return data_type::undef;
        }
    }
    constexpr Type baseType() const { return *this; }

    // Scale a count by the element size (a * T == a << log2Size).
    template <typename U>
    constexpr friend int operator*(U a, Type t) {
        return int(a << t.log2Size());
    }
    // Divide a byte count by the element size (a / T == a >> log2Size).
    template <typename U>
    constexpr friend int operator/(U a, Type t) {
        return int(a >> t.log2Size());
    }

    // Convert to the nGEN data type; the table is indexed by bits 16-19 of
    // the encoded value.
    ngen::DataType ngen() const {
        using namespace ngen;
        static const DataType table[16] = {DataType::hf, DataType::f,
                DataType::df, DataType::invalid, DataType::ub, DataType::b,
                DataType::uw, DataType::w, DataType::ud, DataType::d,
                DataType::uq, DataType::q, DataType::bf, DataType::tf32,
                DataType::invalid, DataType::invalid};
        return table[(uint32_t(val) >> 16) & 0xF];
    }

    // True if every value of this type is representable in T (defined out of line).
    bool isSubsetOf(Type T) const;
};
130 | |
// Matrix layout in memory. Each layout has a short and a long alias sharing
// the same underlying value.
enum class MatrixLayout : uint8_t {
    N = 0, // Nontranspose (column-major; see isColMajor)
    Nontranspose = 0,
    T = 1, // Transpose (row-major)
    Transpose = 1,
    Pc = 2, // Packed columns (treated as column-major; see isColMajor)
    PackedColumns = 2,
    Pr = 3, // Packed rows
    PackedRows = 3
};
141 | |
142 | static inline bool isPacked(MatrixLayout l) { |
143 | return (l == MatrixLayout::PackedRows) |
144 | || (l == MatrixLayout::PackedColumns); |
145 | } |
146 | |
147 | static inline bool isColMajor(MatrixLayout l) { |
148 | return (l == MatrixLayout::N || l == MatrixLayout::Pc); |
149 | } |
150 | |
// A crosspack is "large" when it is nontrivial (> 1) and a full crosspack
// unit spans more than 4 bytes.
static inline bool isLargeCrosspack(size_t sizeofT, int crosspack) {
    if (crosspack <= 1) return false;
    return sizeofT * crosspack > 4;
}
154 | |
155 | static inline bool isLargeCrosspack(Type T, int crosspack) { |
156 | return isLargeCrosspack(T.size(), crosspack); |
157 | } |
158 | |
// Ways of accessing a matrix operand in memory.
enum class AccessType : uint8_t {
    Scattered, // Use scattered accesses
    ChannelScattered, // Use untyped surface reads
    Block, // Use block messages
    PseudoBlock, // Use scattered accesses to emulate block accesses
    Block2D, // Use 2D block messages
    Block2DTranspose, // Use 2D block messages with transposition
    Block2DVNNI, // Use 2D block messages with VNNI transform
};
168 | |
169 | static inline bool isBlock2D(AccessType t) { |
170 | return (t == AccessType::Block2D || t == AccessType::Block2DTranspose |
171 | || t == AccessType::Block2DVNNI); |
172 | } |
173 | |
// Strategies for handling m/n/k remainders (tail iterations).
enum class RemainderHandling : uint8_t {
    Ignore, // Assume no remainder, or handled by hardware bounds checking.
    General, // Handle all remainder cases.
    Split, // Generate copies of the kernel with and without remainder handling.
    KnownRemainder, // Assume remainder case; don't create special code for non-remainder case.
};

// Kernel scheduling variants.
enum class KernelScheduling : uint8_t {
    Static,
    EUStatic,
    Dynamic,
};

// Preferences for using scattered accesses.
enum class ScatterSIMD {
    Default,
    Wide, // Prefer wider SIMD (more scattered lanes)
    Narrow // Prefer narrower SIMD (more consecutive access)
};
193 | |
194 | struct GRFMultirange { |
195 | std::vector<ngen::GRFRange> ranges; |
196 | |
197 | GRFMultirange() {} |
198 | GRFMultirange(ngen::GRFRange range) : ranges {1, range} {} |
199 | |
200 | ngen::GRF operator[](int idx) const { |
201 | for (auto &r : ranges) { |
202 | if (idx < r.getLen()) return r[idx]; |
203 | idx -= r.getLen(); |
204 | } |
205 | throw std::runtime_error("Index out of bounds" ); |
206 | } |
207 | |
208 | GRFMultirange subrange(int start, int count) const { |
209 | GRFMultirange result; |
210 | for (auto &r : ranges) { |
211 | if (start < r.getLen()) { |
212 | auto got = std::min(count, r.getLen() - start); |
213 | result.ranges.push_back( |
214 | ngen::GRFRange {r.getBase() + start, got}); |
215 | count -= got; |
216 | start = 0; |
217 | if (count <= 0) break; |
218 | } else |
219 | start -= r.getLen(); |
220 | } |
221 | return result; |
222 | } |
223 | |
224 | GRFMultirange subrange( |
225 | ngen::HW hw, Type T, const RegisterBlock &block) const; |
226 | |
227 | bool contiguous(int start, int count) const { |
228 | for (auto &r : ranges) { |
229 | if (start < r.getLen()) return (start + count) <= r.getLen(); |
230 | start -= r.getLen(); |
231 | } |
232 | return false; |
233 | } |
234 | |
235 | void append(ngen::GRFRange r) { |
236 | if (!ranges.empty()) { |
237 | auto &rend = ranges.back(); |
238 | if (rend.getBase() + rend.getLen() == r.getBase()) { |
239 | rend = ngen::GRFRange( |
240 | rend.getBase(), rend.getLen() + r.getLen()); |
241 | return; |
242 | } |
243 | } |
244 | ranges.push_back(r); |
245 | } |
246 | |
247 | void append(const GRFMultirange &r) { |
248 | for (auto &rr : r.ranges) |
249 | append(rr); |
250 | } |
251 | |
252 | uint8_t getLen() const { |
253 | uint8_t len = 0; |
254 | for (auto &r : ranges) |
255 | len += r.getLen(); |
256 | return len; |
257 | } |
258 | |
259 | bool empty() const { |
260 | for (auto &r : ranges) |
261 | if (r.getLen() > 0) return false; |
262 | return true; |
263 | } |
264 | void clear() { ranges.clear(); } |
265 | }; |
266 | |
267 | // A pair of Subregisters in opposite banks. |
268 | class SubregisterPair { |
269 | protected: |
270 | ngen::Subregister regs[2]; |
271 | bool negative; |
272 | |
273 | public: |
274 | SubregisterPair() : SubregisterPair(ngen::Subregister()) {} |
275 | SubregisterPair(ngen::Subregister reg0, ngen::Subregister reg1) |
276 | : regs {reg0, reg1}, negative(false) {} |
277 | explicit SubregisterPair(ngen::Subregister reg) |
278 | : SubregisterPair(reg, reg) {} |
279 | |
280 | /* implicit */ operator ngen::Subregister() const { return regs[0]; } |
281 | |
282 | SubregisterPair &operator=(ngen::Subregister reg) { |
283 | regs[0] = regs[1] = reg; |
284 | negative = false; |
285 | return *this; |
286 | } |
287 | |
288 | ngen::Subregister getReg(int idx) const; |
289 | ngen::Subregister getRegAvoiding( |
290 | ngen::HW hw, const ngen::RegData &rd) const; |
291 | |
292 | bool isValid() const { return regs[0].isValid() && regs[1].isValid(); } |
293 | bool isInvalid() const { return !isValid(); } |
294 | void invalidate() { |
295 | regs[0].invalidate(); |
296 | regs[1].invalidate(); |
297 | } |
298 | |
299 | SubregisterPair operator-() const { |
300 | auto copy = *this; |
301 | copy.negative = !copy.negative; |
302 | return copy; |
303 | } |
304 | }; |
305 | |
// A scalar that is either a compile-time constant of type T or a value held
// in registers (a SubregisterPair). The active union member is tracked by
// fixed_value; conversions throw if the wrong alternative is requested.
template <typename T>
class Scalar {
protected:
    bool fixed_value; // True: 'value' is active. False: 'subs' is active.
    union {
        SubregisterPair subs;
        T value;
    };

public:
    Scalar() : Scalar(ngen::Subregister()) {} // Default: invalid register pair.
    explicit Scalar(T value_) : fixed_value(true), value(value_) {}
    Scalar(ngen::Subregister reg0, ngen::Subregister reg1)
        : fixed_value(false), subs {reg0, reg1} {}
    explicit Scalar(ngen::Subregister reg) : Scalar(reg, reg) {}

    Scalar &operator=(T value_) {
        fixed_value = true;
        value = value_;
        return *this;
    }
    Scalar &operator=(ngen::Subregister reg) {
        fixed_value = false;
        subs = reg;
        return *this;
    }

    // Equality against a constant: true only if this scalar is fixed and
    // holds an equal value (a register-held scalar never compares equal).
    template <typename U>
    friend inline bool operator==(const Scalar<T> &scalar, const U &val) {
        return scalar.fixed_value && (val == scalar.value);
    }
    template <typename U>
    friend inline bool operator==(const U &val, const Scalar<T> &scalar) {
        return scalar == val;
    }

    template <typename U>
    friend inline bool operator!=(const Scalar<T> &scalar, const U &val) {
        return !(scalar == val);
    }
    template <typename U>
    friend inline bool operator!=(const U &val, const Scalar<T> &scalar) {
        return !(scalar == val);
    }

    // Extract the fixed value; throws if this scalar lives in registers.
    operator T() const {
        if (!fixed_value) throw std::runtime_error("Scalar is not fixed." );
        return value;
    }

    // Extract the register pair; throws if this scalar is a fixed constant.
    operator SubregisterPair() const {
        if (fixed_value) throw std::runtime_error("Scalar is fixed." );
        return subs;
    }

    SubregisterPair &getPair() {
        if (fixed_value) throw std::runtime_error("Scalar is fixed." );
        return subs;
    }

    bool fixed() const { return fixed_value; }

    // Register accessors; both throw (via the conversion) if fixed.
    ngen::Subregister getReg(int idx) const {
        return SubregisterPair(*this).getReg(idx);
    }
    ngen::Subregister getRegAvoiding(
            ngen::HW hw, const ngen::RegData &rd) const {
        return SubregisterPair(*this).getRegAvoiding(hw, rd);
    };
};
376 | |
// Holds precomputed copies of a value shifted right by 0..maxShift bits,
// with optional symbolic negation. set() registers the subregister holding
// a given shift; operator>> retrieves it (an invalid Subregister for
// out-of-range shifts).
class MultishiftSubregister {
protected:
    static constexpr int maxShift = 5;
    ngen::Subregister regs[maxShift + 1] = {ngen::Subregister()};
    bool neg = false;

public:
    MultishiftSubregister operator-() const {
        auto copy = *this;
        copy.neg = !copy.neg;
        return copy;
    }

    ngen::Subregister operator>>(int shift) const {
        ngen::RegData sub = ngen::Subregister {};
        if (shift >= 0 && shift <= maxShift) sub = regs[shift];
        if (neg) sub = -sub;
        // NOTE(review): type-puns RegData back to Subregister via
        // reinterpret_cast; relies on nGEN layout compatibility -- confirm.
        return *reinterpret_cast<ngen::Subregister *>(&sub);
    }

    void set(int shift, ngen::Subregister reg) { regs[shift] = reg; }
};
399 | |
400 | struct MatrixAddressing { |
401 | MatrixLayout layout; // Layout type (N/T/Pr/Pc) |
402 | uint8_t packSize; // # of elements in a packed row/column for packed layouts. |
403 | uint8_t crosspack; // Crosspack for packed layouts. |
404 | uint8_t alignment; // Alignment for all addresses, offsets, and leading dimensions. |
405 | uint8_t tileR = 0, tileC = 0; // Tiling (0 if none) for packed layouts. |
406 | |
407 | void setAlignment(int align) { alignment = sanitizeAlign(align); } |
408 | int defaultAlignment(Type T) const { |
409 | return sanitizeAlign( |
410 | T.size() * (isPacked(layout) ? (packSize * crosspack) : 1)); |
411 | } |
412 | |
413 | private: |
414 | static int sanitizeAlign(int align) { |
415 | return std::min(128, largest_pow2_divisor(align)); |
416 | } |
417 | }; |
418 | |
// Strategy for accessing a matrix in memory: address model, message types,
// caching, and register tiling.
struct MatrixAddressingStrategy {
    ngen::AddressBase base; // Base for addressing (A64/BTS/...)
    AccessType accessType = AccessType::Block; // Block/scattered/etc. access
    uint8_t tileR = 0, tileC = 0; // Desired tiling (0 if none) in registers.
    ScatterSIMD smode
            = ScatterSIMD::Default; // SIMD selection for scattered accesses.
    unsigned padded : 1; // Allow read/write overruns?
    unsigned atomic : 1; // Atomic access? (only relevant for C)
    unsigned address2D : 1; // Use 2D addressing? (media block-style loads)
    unsigned prefetch : 1; // Prefetch only?
    unsigned newDP : 1; // Use new dataport messages? (XeHPG+)
    unsigned dpasw : 1; // DPASW half layout?
    ngen::CacheSettingsLSC cachingR // Cache policies for LSC reads.
            = ngen::CacheSettingsLSC::Default;
    ngen::CacheSettingsLSC cachingW // Cache policies for LSC writes.
            = ngen::CacheSettingsLSC::Default;

    // All flag bitfields default to off (bitfields cannot take in-class
    // initializers pre-C++20).
    MatrixAddressingStrategy()
        : padded(false)
        , atomic(false)
        , address2D(false)
        , prefetch(false)
        , newDP(false)
        , dpasw(false) {}

    void preflight(ngen::HW hw);
    void forceA64();

    // Stateless vs. surface access, determined by the addressing base.
    ngen::GlobalAccessType getGlobalAccessType() const {
        return base.isStateless() ? ngen::GlobalAccessType::Stateless
                                  : ngen::GlobalAccessType::Surface;
    }
};
452 | |
453 | struct VirtualFlag { |
454 | uint8_t idx : 6; |
455 | uint8_t n : 2; |
456 | |
457 | constexpr VirtualFlag() : idx(0), n(0) {} |
458 | /* implicit */ VirtualFlag(const ngen::FlagRegister &flag) |
459 | : idx(flag.index()), n(flag.getBytes() >> 1) {} |
460 | explicit constexpr VirtualFlag(int idx_, int n_ = 1) : idx(idx_), n(n_) {} |
461 | |
462 | ngen::FlagRegister toPhysical() const; |
463 | |
464 | friend inline bool operator==(VirtualFlag vf1, VirtualFlag vf2) { |
465 | return vf1.idx == vf2.idx && vf1.n == vf2.n; |
466 | } |
467 | friend inline bool operator!=(VirtualFlag vf1, VirtualFlag vf2) { |
468 | return !(vf1 == vf2); |
469 | } |
470 | |
471 | bool operator!() const { return (idx == 0) && (n == 0); } |
472 | explicit operator bool() const { return !!*this; } |
473 | |
474 | void clear() { *this = VirtualFlag(); } |
475 | }; |
476 | |
// Description of a load/store mask: either a fixed bit pattern or a
// variable mask computed from a remainder index. The two alternatives are
// overlaid in a union, distinguished by the leading isFixed bit; 'raw'
// overlays both for cheap comparison/copy (see operator==).
struct MaskInfo {
    union {
        struct {
            uint8_t isFixed : 1; // = false (variable mask)
            uint8_t reverse : 1; // True to reverse mask.
            uint8_t : 6;
            uint8_t rsize; // Maximum remainder value. (e.g. 16 if we need the last 4 bits of the index).
            uint8_t maskRep; // # of repetitions of mask pattern.
            uint8_t bitRep : 5; // # of times each mask bit is repeated.
            uint8_t rdivide : 3; // Amount by which to divide index before forming mask. Fractions are rounded up.
            // Note maskRep * bitRep * (rsize >> rshift) = # mask bits.
        } variable;
        struct {
            uint8_t isFixed : 1; // = true (fixed mask)
            uint8_t _ : 7;
            uint8_t rsize; // Maximum remainder value.
            uint16_t value; // Mask value.
        } fixed;
        uint32_t raw;
    };

    // Default: fixed all-ones mask (i.e. "no masking").
    MaskInfo() : fixed {true, 0, 0, 0xFFFF} {}

    // operator! is true only for the default no-mask state.
    bool operator!() const { return fixed.isFixed && fixed.value == 0xFFFF; }
    explicit operator bool() const { return !!*this; }

    static MaskInfo None() { return MaskInfo(); }

    // Bitwise comparison through the 'raw' overlay.
    friend bool operator==(const MaskInfo &i1, const MaskInfo &i2) {
        return i1.raw == i2.raw;
    }
    friend bool operator!=(const MaskInfo &i1, const MaskInfo &i2) {
        return !(i1 == i2);
    }
};
512 | |
513 | struct MaskAssignment { |
514 | MaskInfo mask; // Associated mask |
515 | LoopType var; // Variable to base mask off of |
516 | uint8_t offset; // Amount to subtract from variable. |
517 | VirtualFlag flag; // Index of virtual flag register to use. |
518 | |
519 | bool compatible(const MaskAssignment &other) const { |
520 | return mask == other.mask && var == other.var && offset == other.offset; |
521 | } |
522 | void reverse(int width) { |
523 | offset = width - offset - mask.variable.rsize; |
524 | mask.variable.reverse = !mask.variable.reverse; |
525 | } |
526 | }; |
527 | |
528 | struct RegisterBlock { |
529 | /* Register layout information. */ |
530 | uint8_t nr, nc; // Size of this block. |
531 | uint8_t ld; // Leading dimension, in elements. |
532 | uint8_t offsetR, offsetC; // Row and column offset within matrix block. |
533 | uint8_t colMajor : 1; // Is this block column-major? (columns stored consecutively inside each register) |
534 | uint8_t splitComplex : 1; // True if complex data split into successive real and imaginary parts. |
535 | uint8_t : 6; |
536 | uint8_t crosspack; // Crosspack for this block (1 if none). |
537 | uint8_t component : 6; // Component # for this block. |
538 | int8_t cxComponent : 2; // Complex component # for this block (-1 if not complex or interleaved). |
539 | uint16_t bytes; // # of bytes in this block. |
540 | uint16_t offsetBytes; // Byte offset within register block. |
541 | |
542 | /* Load/store information. */ |
543 | uint8_t remainderR : 1; // Row remaindering enabled? |
544 | uint8_t remainderC : 1; // Column remaindering enabled? |
545 | uint8_t noRowsOK : 1; // Can handle no rows (in mask/descriptor)? |
546 | uint8_t noColsOK : 1; // Can handle no columns (in mask/descriptor)? |
547 | uint8_t descRemR : 1; // Row remainders can be handled by changing the descriptor? |
548 | uint8_t descRemC : 1; // Column remainders can be handled by changing the descriptor? |
549 | uint8_t descAssigned : 1; // True if address registers have been assigned for this block's descriptors. |
550 | uint8_t writable : 1; // True if block is set up for writing. |
551 | |
552 | uint8_t ebytes; // Size of element in bytes, e.g. 4 for scattered_dword, 16 for block_hword |
553 | uint8_t count; // Element count. |
554 | uint8_t ; // Extra info. For block accesses, 1 means aligned OWord, 0 unaligned. For scattered accesses, # of consecutive elements. |
555 | uint8_t simdSize; // SIMD size for load/stores (0 indicating no separate load/store needs to be done.) |
556 | uint8_t msgRegs; // Underlying register count for load/store operation (may be different from nregs()). |
557 | VirtualFlag flag; // Assigned flag register index and modifiers, if any. |
558 | uint8_t flagAny : 1; // Use .anyh? |
559 | uint8_t flagAll : 1; // Use .allh? |
560 | uint8_t hasNoLoad : 1; // Does this load/store cover additional (no-load) RegisterBlocks? (packed layouts) |
561 | uint8_t : 5; |
562 | uint8_t sfid; // SFID for this block. |
563 | uint8_t rowFragment; // If this block needs fragmenting to support row/column remainders, the maximum block size (power of 2) to fragment down to. |
564 | uint8_t colFragment; // Zero if no fragmenting needed. |
565 | uint8_t addrShift; // log2(address units). e.g. 0 if byte addresses should be used, 4 if oword addresses should be used. |
566 | uint8_t log2GRFBytes; // log2(bytes per GRF). |
567 | |
568 | MaskInfo rowMask; // Row mask for this block. |
569 | MaskInfo colMask; // Column mask for this block. |
570 | |
571 | static constexpr int8_t Interleaved |
572 | = -1; // Value for cxComponent indicating interleaved real/imaginary data. |
573 | |
574 | void calcBytes(Type T); // Auto-calculate # of registers. |
575 | void calcBytes(Type T, const MatrixAddressingStrategy &astrategy); |
576 | |
577 | void clearFlag() { |
578 | flag.clear(); |
579 | flagAll = flagAny = false; |
580 | } |
581 | void eraseMask() { |
582 | clearFlag(); |
583 | rowMask = MaskInfo(); |
584 | colMask = MaskInfo(); |
585 | } |
586 | |
587 | bool isLoadBlock() const { return simdSize > 0; } |
588 | |
589 | int nregs() const; |
590 | int offsetReg() const; |
591 | |
592 | void simplify(Type T); |
593 | void compact(Type T); |
594 | }; |
595 | |
// Parameters for 2D block addressing: matrix extents, offsets, and
// remainders, each either a register value or (for extents) a fixed count.
struct Address2DParams {
    ngen::Subregister rows, cols; // Matrix extents in registers (if not fixed).
    ngen::Subregister offR, offC; // Row/column offsets.
    ngen::Subregister remR, remC; // Row/column remainders.
    int fixedRows = 0, fixedCols = 0; // Compile-time extents (0 = not fixed).
};
602 | |
603 | class VirtualFlagAllocator { |
604 | public: |
605 | VirtualFlagAllocator(ngen::HW hw) |
606 | : free((1ul << (ngen::GRF::bytes(hw) >> 1)) - 1) |
607 | , nflag(ngen::FlagRegister::subcount(hw)) {} |
608 | |
609 | VirtualFlag allocVirtual(int n = 1); |
610 | ngen::FlagRegister alloc(int n = 1); |
611 | ngen::FlagRegister tryAlloc(int n = 1); |
612 | |
613 | void claim(VirtualFlag vflag) { free &= ~mask(vflag); } |
614 | void release(VirtualFlag vflag) { free |= mask(vflag); } |
615 | void release(const ngen::FlagRegister ®) { |
616 | release(VirtualFlag(reg)); |
617 | unlock(reg); |
618 | } |
619 | void safeRelease(VirtualFlag &vflag) { |
620 | if (vflag) release(vflag); |
621 | vflag.clear(); |
622 | } |
623 | void safeRelease(ngen::FlagRegister ®) { |
624 | if (reg.isValid()) release(reg); |
625 | reg.invalidate(); |
626 | } |
627 | |
628 | bool isVirtual(VirtualFlag vflag) { return (vflag.idx >= nflag); } |
629 | |
630 | bool lock(VirtualFlag vflag) { |
631 | bool wasLocked = isLocked(vflag); |
632 | locked |= mask(vflag); |
633 | return wasLocked; |
634 | } |
635 | void unlock(VirtualFlag vflag) { locked &= ~mask(vflag); } |
636 | bool isLocked(VirtualFlag vflag) const { return (locked & mask(vflag)); } |
637 | |
638 | ngen::FlagRegister assignPhysical(VirtualFlag vflag); |
639 | |
640 | static int getBase(int idx) { return idx & 0x1F; } |
641 | static int getN(int idx) { return idx >> 5; } |
642 | static int makeIndex(int base, int n) { return base | (n << 5); } |
643 | |
644 | protected: |
645 | uint32_t free; |
646 | uint8_t locked = 0; |
647 | uint8_t nextPhys = 0; |
648 | uint8_t nflag; |
649 | |
650 | static uint32_t mask(VirtualFlag vflag) { return mask(vflag.idx, vflag.n); } |
651 | static uint32_t mask(int idx, int n) { |
652 | return (1ul << (idx + n)) - (1ul << idx); |
653 | } |
654 | }; |
655 | |
// Bitmask allocator for message tokens; a token value of -1 means "none".
class TokenAllocator {
public:
    TokenAllocator(ngen::HW hw);

    // Returns an available token, or a negative value if none available --
    // presumably; confirm in the out-of-line definition.
    int8_t tryAlloc();
    void release(int8_t token) { free |= (1u << token); }
    // Release only if valid (>= 0), then reset to the "none" sentinel.
    void safeRelease(int8_t &token) {
        if (token >= 0) release(token);
        token = -1;
    }

protected:
    uint32_t free; // Bitmask of available tokens.
};
670 | |
// State parameters shared between different kernel types.
struct CommonState {
    ngen::RegisterAllocator ra; // GRF/register allocator for the kernel.
    ngen::GRF signChange, selectImag;
    ngen::GRF vflagStorage; // Backing storage for virtual flags -- presumably; confirm in assignPhysical.
    std::array<VirtualFlag, 8> activeVFlags; // Which virtual flag currently occupies each physical flag.
    VirtualFlagAllocator raVFlag; // Allocator for (virtual) flag registers.
    TokenAllocator tokenAllocator; // Allocator for message tokens.
    std::vector<std::pair<uint8_t, int8_t>> tokenMap;
    ngen::Subregister readFailures;
    ngen::Subregister fusedID;
    ngen::Subregister lsDescConstant[4];
    ngen::FlagRegister flagSwizzle;
    EmulationState emulate; // State for emulated 64-bit/DWxDW arithmetic (see allocEmulate64Temp).
    ngen::GRFRange eatomicAddRegs[2];
    ngen::GRFRange remaskRegs[2];
    VirtualFlag vflagEAtomicAdd;
    ngen::Subregister all1s;
    ngen::RegData r0_info; // Where r0 header data lives (see movedR0/MoveR0).
    bool movedR0 = false; // True once r0 contents have been relocated.
    ngen::Subregister lid0;
    GRFMultirange indexVec; // uw
    int ivEntries = 0; // # of valid entries in indexVec.
    // Scratch state for the integer inversion/division subroutine --
    // presumably; confirm against its emitter.
    struct {
        ngen::GRF zero, one;
        ngen::GRFRange src1Storage;
        ngen::GRF src1, srcR1, srcI1, r, d;
        ngen::GRFRange mathTemp;
        ngen::GRF temp;
        std::array<ngen::FlagRegister, 2> tempFlags;
        ngen::Subregister flagStore; // ud
        ngen::Label label;
        int simd;
        ngen::Subregister callStorageSub, callStorage;
        bool use = false;
    } invertSub;

    CommonState(ngen::HW hw) : ra(hw), raVFlag(hw), tokenAllocator(hw) {}

    // Forget all virtual-to-physical flag bindings except locked ones.
    void wipeActiveVFlags() {
        for (int i = 0; i < int(activeVFlags.size()); i++)
            if (!raVFlag.isLocked(VirtualFlag(i))) activeVFlags[i].clear();
    }

    // Record that a physical flag register is now in use.
    void usePhysicalFlag(ngen::FlagRegister flag) {
        activeVFlags[flag.index()] = flag;
    }

    // Allocate the temporary GRFs required by the enabled emulation modes
    // (up to 2 for 64-bit emulation, 1 for DWxDW multiply).
    void allocEmulate64Temp(const EmulationStrategy &estrategy) {
        int ntemp = 0;
        if (estrategy.emulate64) ntemp = std::max(ntemp, 2);
        if (estrategy.emulate64_mul) ntemp = std::max(ntemp, 2);
        if (estrategy.emulateDWxDW) ntemp = std::max(ntemp, 1);

        for (int q = 0; q < ntemp; q++)
            emulate.temp[q] = ra.alloc();
    }
};
729 | |
// Places to store r0 information: leave in place, accumulator, address
// registers, or a GRF.
enum class MoveR0 { None, Acc, Addr, GRF };

// Problem parameters shared between kernel types.
struct CommonProblem {
    bool nonuniformWGs = false; // Support nonuniform workgroups?
    bool gtpinSupport = false; // Support GT-Pin?
};
738 | |
// Strategy parameters shared between different kernel types.
struct CommonStrategy {
    int subgroupSize = 8; // Subgroup size provided to OpenCL runtime.
    bool fused = false; // Fused EU handling enabled?
    bool dualGRF = true; // Enable two-GRF instructions.
    bool ieeeDenormals = true; // Enable IEEE-compliant denormals.
    bool spf = true; // Enable Single Program Flow (SPF) mode in EUs.
    MoveR0 moveR0 = MoveR0::Acc; // Where to store r0 information.
    bool sipR0WA = false; // Avoid using r0 to avoid clobbering by SIP.
    bool readSuppressionWA
            = true; // Workaround for HW issue with read suppression after fused sends.
    bool wgInSS
            = false; // Pretend to use barriers so that each WG belongs to 1 SS/DSS.
    int GRFs = 128; // # of GRFs to use.
    bool finalFence = false; // Issue global memory fence before EOT.
    int pauseCycles
            = 0x0200; // Number of cycles to pause when waiting in a spin-loop.
    bool simulation = false; // For use in simulator?

    EmulationStrategy emulate; // 64-bit/DWxDW emulation settings.

    CommonStrategy() {}
    CommonStrategy(ngen::HW hw, int stepping = 0); // HW-specific defaults (out of line).
    void preflight(ngen::HW hw, const CommonProblem &problem);
};
764 | |
// Types of updates for GEMM kernels.
enum class UpdateType {
    Full,
    UpperTriangle,
    UpperTriangleHermitian,
    LowerTriangle,
    LowerTriangleHermitian
};

// A/B offset mode.
enum class ABOffset {
    None, // No A/B offsets.
    Calc, // Calculate A/B row/column sums in kernel.
    Load, // Use precalculated row/column sums.
};

// C offset mode.
enum class COffset {
    None, // No C offsets.
    Post, // C offset after all other updates.
    Pre, // C offset before all other updates (bias).
};

// Batch mode.
enum class BatchMode { None, Strided, Nonstrided, Variable };

// Binary operations (for binary post-ops).
enum class BinaryOp { Add, Sub, Mul, Div, Min, Max };
793 | |
// GEMM kernel problem description.
struct GEMMProblem : public CommonProblem {
    Type Ta, Tb, Tc, Tco, Ts; // Types for A/B/C/C offsets/scalars in registers.
    Type Ta_ext, Tb_ext, Tc_ext; // Types for A/B/C data in memory.

    Scalar<double> alpha_real, alpha_imag; // Alpha value, if fixed.
    Scalar<double> beta_real, beta_imag; // Beta value, if fixed.
    MatrixAddressing A, B, C, CO; // Addressing information for matrices.
    bool checkBeta0 = true; // If true, check for beta = 0 and handle specially.
    ABOffset abOffset = ABOffset::None; // A/B offset mode.
    COffset cOffset = COffset::None; // C offset mode.
    BatchMode batch = BatchMode::None; // Batch mode.
    int batchDims = 0; // # of batch dimensions (strided batch only).
    bool sumA = false, sumB = false; // If true, calculate A row sums/B column sums and store in CO.
    post_ops_t postOps; // Fused post operations to apply
    bool postOpFwd = true; // Eltwise parameters
    std::vector<MatrixAddressing> binary; // Binary postop data
    std::vector<Type> Tbinary; // Binary types
    std::vector<bool> binaryRow; // Dimensionality of binary data
    std::vector<bool>
            binaryCol; // (false means broadcast in the given dimension)
    std::vector<bool> binaryBatch;

    // Any post-ops at all / any binary post-ops?
    bool hasPostOp() const { return postOps.len() > 0; }
    bool hasBinaryPostOp() const {
        for (int idx = 0; idx < postOps.len(); idx++)
            if (postOps.entry_[idx].is_binary()) return true;
        return false;
    }

    // Statically-known special values of beta/alpha. For the real-only types
    // here (isComplex() is always false) the imaginary checks are vacuous.
    bool beta0() const {
        return (beta_real == 0) && (!Tc.isComplex() || (beta_imag == 0));
    }
    bool beta1() const {
        return (beta_real == 1) && (!Tc.isComplex() || (beta_imag == 0));
    }
    bool alpha1() const {
        return (alpha_real == 1) && (!Tc.isComplex() || (alpha_imag == 0));
    }
    bool alphaM1() const {
        return (alpha_real == -1) && (!Tc.isComplex() || (alpha_imag == 0));
    }

    // True if the C update needs conversion through Ts: nontrivial alpha or
    // beta, a beta=1 accumulation that would lose precision, or any post-op.
    bool needsTsConvert() const {
        if (!(alpha1() || alphaM1())) return true;
        if (!(beta0() || beta1())) return true;
        if (beta1() && !Tc_ext.isSubsetOf(Tc)) return true;
        if (hasPostOp()) return true;
        return false;
    }

    bool gemmt() const { return false; }
    bool backward() const { return false; }

    // A/B sums are needed for explicit sum output or calculated AB offsets.
    bool needsASums() const { return (abOffset == ABOffset::Calc) || sumA; }
    bool needsBSums() const { return (abOffset == ABOffset::Calc) || sumB; }
    // CO matrix is used for C offsets or A/B sum storage.
    bool usesCO() const { return (cOffset != COffset::None) || sumA || sumB; }
    bool allowMatrixOffset() const { return (cOffset == COffset::Pre); }
};
855 | |
struct GEMMState; // Defined later in this header.

// How to split A/B amongst threads in a workgroup.
enum class CoopSplit {
    K, // Split in k dimension
    MN, // Split in m/n dimensions
    Linear, // Split in linear index order
};
864 | |
// Strategy parameters for GEMM kernels.
//  Mostly a bag of tuning knobs filled in by the driver; fields without a
//  default value are expected to be set before use (see preflight).
struct GEMMStrategy : public CommonStrategy {
    int blocking[3] = {
            0}; // Recommended block size in each dimension (m/n/k) -- for driver.
    int blockingAlt[3] = {
            0}; // Alternate block size in each dimension (m/n/k) -- for driver.
    //  m/n alternates are for Hilbert-ordered kernels when Hilbert ordering disabled.
    //  k alternate is for multi-tile execution with implicit scaling.
    int unroll[3]; // Unrolls in each dimension (m/n/k), indexed by LoopType.
    int unrollK_masked = 0; // k unroll to use when masking.
    LoopType loopOrder[3] = {LoopM, LoopN,
            LoopK}; // Expected order of loops in driver code (in order from innermost to outermost).
    LoopType fusedLoop = LoopM; // Direction of fusing if threads fused.
    bool hilbertOrder = false; // Use Hilbert-like walk order in C?
    bool boustrophedon = false; // Use panel-boustrophedon walk order in C?
    bool persistent = false; // Use persistent thread model?
    bool reverse[2] = {false, false}; // Reverse m/n walk order?
    int fmaSIMD = 0; // Vector length for FMA (0 = default = 2 GRFs).
    int kChain = 1; // # of FMAs to chain in k dimension.
    int wg[3] = {0, 0,
            0}; // m/n/k workgroup sizes, 0 if unconstrained. Indexed by LoopType.
    WGType forceWGUpdate = WGDynamic; // Force work group update type.
    MatrixAddressingStrategy A, B, C,
            CO; // Strategies for accessing A/B/C/C offsets.
    int ka_load, kb_load; // How much of A/B is loaded at once, in k dimension
    int ka_load_masked = 0,
        kb_load_masked
            = 0; // Same as above, when masking m/n (0 = default = same as ka/kb_load)
    bool slmA = false, slmB = false; // Whether to copy A/B to SLM.
    bool splitCopy = false; // Separate SLM copy and compute threads?
    int slmBuffers = 0; // # of A/B SLM buffers, 0 for none.
    int unrollKSLM
            = 0; // k unroll for SLM copies (0 = auto = unroll[LoopK]/slmCopies)
    int unrollKSLMMasked
            = 0; // Alternate value to use with masking (0 = same as unrollKSLM)
    bool slmATrans = false,
         slmBTrans
            = false; // Whether A/B SLM data should be completely crosspacked (transposed).
    int A_copies = 1,
        B_copies = 1; // # of copies of A/B matrices, for latency absorption
    int slmCopies = 1; // # of copies of loaded A/B matrices for SLM copies.
    bool slmRepackAhead = false; // Repack SLM data ahead of stores?
    int optAlignAB
            = 0; // Optional alignment for A/B. If > 0, create two versions of k loop, one for A/B aligned to this value, one not.
    AccessType unalignedAccA,
            unalignedAccB; // Access types to use for A/B on unaligned path.
    int ka_prefetch = 0, kb_prefetch = 0; // Chunk size for prefetching A/B.
    int ka_pfStride = 0, kb_pfStride = 0; // k stride between A/B prefetches.
    bool cooperativePF = true; // Enable WG-cooperative A/B prefetches.
    int prefetchA = 0, prefetchB = 0,
        prefetchC = 0; // Prefetch distances, in units of unrollK.
    int prefetchAMasked = 0,
        prefetchBMasked = 0; // Same as above, when masking m/n.
    MatrixAddressingStrategy A_prefetch, B_prefetch,
            C_prefetch; // Strategies for prefetching A/B/C.
    enum {
        CSeparate, // C stored in its own bundle, A/B in the other bundle.
        ACB, // A, then C, then B
        BCA, // B, then C, then A
        VNC, // A/B (broadcast matrix second), then C
        ABInterleave, // A/B interleaved, then C
        NSeparate, // Broadcast input stored in its own bundle(s)
        VAvoid, // C registers allocated to avoid non-broadcast inputs
    } registerScheme
            = CSeparate; // Register layout scheme.
    bool avoidIncConflicts
            = true; // If true, duplicate some increment values across banks to avoid bundle conflicts.
    bool kParallel
            = false; // If true, generate k-parallelized kernel using global memory reduction.
    bool kParallelLocal
            = false; // If true, generate k-parallelized kernel using local memory reduction.
    bool doubleWA
            = false; // Use explicit double broadcast instructions? (Gen9 only)
    int barrierFreq
            = 0; // If > 0, set a periodic barrier every barrierFreq k loops to keep threads together.
    bool splitBarrier
            = false; // Use split barriers for these periodic barriers?
    bool altCRemainder = false; // Use alternative double-loop C remainder code?
    bool block2DCRemainder = false; // Generate block 2D C remainder path?
    bool cAccumulators
            = false; // Use accumulator registers for part of C (to save a few registers)?
    bool cLoadAhead = false; // Load C before doing FMAs?
    bool forceCopyC = false; // Force C to be copied before the update step?
    bool noJumpTables = false; // Disallow jump tables?
    RemainderHandling remHandling[3] = {
            // m, n, k remainder handling.
            RemainderHandling::Split,
            RemainderHandling::Split,
            RemainderHandling::General,
    };
    bool jointSplit
            = true; // Use remainder kernel for both m and n dimensions if both are split.
    int mSplitThresh = 0,
        nSplitThresh
            = 0; // m/n minimum thresholds for using split remainder handling. 0 means always use split.
    bool atomicFMA = false; // Use {Atomic} FMA chains.
    bool extendedAtomicFMA = false; // Use longer {Atomic} FMA chains.
    bool stallAfterLoad = false; // Insert stalls after load operations.
    bool checkAdd32
            = false; // Check inside kernel if inner loop additions can be done in 32-bit.
    bool delayABInc
            = true; // Delay A/B increment a few outer products in the k loop.
    CoopSplit coopA = CoopSplit::
            K; // How to split SLM copies, cooperative prefetches amongst threads in a workgroup
    CoopSplit coopB = CoopSplit::K; // Same as coopA, for the B matrix.
    bool slmEarlyKMask
            = false; // Prepare A/B reads to use k-masking (when applicable) in main loop, instead of waiting for remainder.
    bool slmUseIncrCopy = true; // Use incremental SLM copies if needed.
    bool slmAltBarriers = false; // Alternate fenceless SLM buffering algorithm.
    bool strictFence
            = false; // Add extra SLM fences that are not usually required on HW.
    bool skipFence
            = false; // Skip SLM fences that theoretically should be required but HW doesn't need.
    bool slmFenceWARWA
            = false; // Work around buggy SLM fence that doesn't protect against WAR hazards.
    bool systolic = false; // Use systolic array if applicable.
    bool dpasw = false; // Use DPASW for fused EU architectures.
    bool fixedSystolic
            = false; // Use hardcoded systolic inner loop for 32x32 or 32x48 unrolls.
    int namedBarriers[2] = {0,
            0}; // # of named barriers in m, n dimensions (0 to use regular barriers).
    bool skewLocalIDs
            = false; // Remap local IDs for large workgroups so that threads on the same EU don't depend on the same data.
    bool xParallel = false; // TRSM: parallelize in x dimension.
    bool checkBeta1
            = false; // If true, check for beta = 1 and handle specially.
    std::vector<MatrixAddressingStrategy>
            binary; // Strategies for binary postop data

    bool insideSK = false; // Inside a superkernel?

    GEMMStrategy() {}
    GEMMStrategy(ngen::HW hw, int stepping = 0)
        : CommonStrategy(hw, stepping) {}

    // Driver-side validation/adjustment hooks (defined out of line).
    void preflight(ngen::HW hw, const GEMMProblem &problem);
    bool minimize(ngen::HW hw, const GEMMProblem &problem);

    // A late-exit path is needed when SLM buffering, periodic barriers,
    // local k-parallelization, or cooperative prefetch is in use.
    bool lateExit() const {
        return (slmBuffers > 0) || barrierFreq || kParallelLocal
                || (cooperativePF && (prefetchA || prefetchB));
    }

    // Maximum k unroll used for SLM copies (defined out of line).
    int maxKSLM(const GEMMProblem &problem, bool isA) const;
    // Size in bytes of one A SLM buffer block (fixed for fixed-systolic kernels).
    int slmABufBlockSize(const GEMMProblem &problem) const {
        return fixedSystolic ? 1152
                             : int(slmA) * problem.Ta * problem.Ta.components()
                        * unroll[LoopM] * maxKSLM(problem, true);
    }
    // Size in bytes of one B SLM buffer block (fixed for fixed-systolic kernels).
    int slmBBufBlockSize(const GEMMProblem &problem) const {
        return fixedSystolic ? 1536
                             : int(slmB) * problem.Tb * problem.Tb.components()
                        * unroll[LoopN] * maxKSLM(problem, false);
    }
    // Total A SLM usage: per-block size times m/k workgroup size and buffer count.
    int slmABufSize(const GEMMProblem &problem) const {
        return slmABufBlockSize(problem) * wg[LoopM] * wg[LoopK] * slmBuffers;
    }
    // Total B SLM usage: per-block size times n/k workgroup size and buffer count.
    int slmBBufSize(const GEMMProblem &problem) const {
        return slmBBufBlockSize(problem) * wg[LoopN] * wg[LoopK] * slmBuffers;
    }
    // SLM block size for the fixed systolic ("sysgemm") path.
    int slmSysgemmBlockSize() const {
        return 1152 * wg[LoopM] + 1536 * wg[LoopN];
    }
    // SLM allocation size varies at dispatch time when k is split locally.
    bool variableSLM() const { return kParallelLocal; }

    // k increment per main-loop iteration for A/B (SLM path uses unrollKSLM).
    int ka_inc() const { return slmA ? unrollKSLM : ka_load; }
    int kb_inc() const { return slmB ? unrollKSLM : kb_load; }

    // Do we need m/n local IDs inside the kernel?
    bool needsMNLocalIDs() const {
        return xParallel || (slmBuffers > 0) || cooperativePF || kParallelLocal
                || persistent || namedBarriers[0] || (dpasw && !fixedSystolic);
    }
    // Do we need the k local ID inside the kernel?
    bool needsKLocalIDs() const { return kParallelLocal || persistent; }
    // Does the kernel execute any barriers?
    bool needsBarrier() const {
        return (barrierFreq > 0) || (slmBuffers > 0) || xParallel
                || kParallelLocal;
    }

    // Is thread fusing active in the m (resp. n) direction?
    bool fusedM() const { return fused && (fusedLoop == LoopM); }
    bool fusedN() const { return fused && (fusedLoop == LoopN); }

    // Workgroup size must be fixed when SLM buffering, a forced fixed WG,
    // or periodic named barriers are in use.
    WGType getWGType(const GEMMProblem &problem) const {
        if ((slmBuffers > 0) || (forceWGUpdate == WGFixed)
                || (barrierFreq && namedBarriers[0]))
            return WGFixed;
        else
            return WGDynamic;
    }

    bool fixedWG(const GEMMProblem &problem) const {
        return (getWGType(problem) == WGFixed);
    }
    // Any linearized (non-rectangular) C walk order?
    bool linearOrder() const { return hilbertOrder || boustrophedon; }
};
1059 | |
// A GRF range holding multiples of a leading dimension (presumably
// {ld, 2*ld, ...} for address setup; see ldaMultiples/ldbMultiples/
// ldcMultiples in GEMMState -- confirm against their uses).
struct LDMultiples {
    ngen::GRFRange range;
    bool a64 = false; // True if the multiples are 64-bit (A64) values.
};
1064 | |
1065 | // State parameters for GEMM kernels. |
1066 | struct GEMMState : public CommonState { |
1067 | struct Inputs { |
1068 | ngen::Subregister A, B, C[2], CO, base; // q |
1069 | ngen::Subregister ao, bo, abo; // w/w/ud |
1070 | ngen::Subregister aoPtr, boPtr; // q |
1071 | ngen::Subregister offsetA, offsetB, offsetC[2]; // q |
1072 | ngen::Subregister offsetCO; // d |
1073 | ngen::Subregister lda, ldb, ldc[2], ldco; // d |
1074 | ngen::Subregister m, n, k, k0; // d |
1075 | ngen::Subregister alpha_real, alpha_imag; // T_real |
1076 | ngen::Subregister beta_real, beta_imag; // T_real |
1077 | ngen::Subregister groupIDM, groupIDN, groupIDK; // ud |
1078 | ngen::Subregister groupIDMN; // ud |
1079 | ngen::GRF localIDM, localIDN, localIDK; // uw |
1080 | ngen::Subregister localSizeM, localSizeN, localSizeK; // ud |
1081 | ngen::Subregister groupCountM, groupCountN; // ud |
1082 | ngen::Subregister groupCountMN; // ud |
1083 | ngen::Subregister groupStride; // ud |
1084 | ngen::Subregister hilbertVD, hilbertUVDRecip; // ud |
1085 | ngen::Subregister hilbertBail; // ud |
1086 | ngen::Subregister bslice, bthresh; // d |
1087 | ngen::Subregister flags; // ud |
1088 | ngen::Subregister diagA, diagB, diagC; // q |
1089 | uint8_t surfaceA, surfaceB; // BTS indices |
1090 | uint8_t surfaceC[2], surfaceCO; // BTS indices |
1091 | ngen::Subregister strideA[2], strideB[2], |
1092 | strideC[2]; // ud, used for strided batch. |
1093 | ngen::Subregister batchSize1, recipBatchSize1; // ud, 2D strided batch |
1094 | ngen::Subregister offsetBatch; // ud, used for non-strided batch. |
1095 | ngen::Subregister incr_a_array, |
1096 | incr_b_array; // ud, used for non-strided variable batch. |
1097 | ngen::Subregister incr_alpha, |
1098 | incr_beta; // ud, used for non-strided variable batch. |
1099 | ngen::Subregister alpha_array, |
1100 | beta_array; // q, used for non-strided variable batch. |
1101 | std::vector<ngen::Subregister> binarySrcs; // q |
1102 | std::vector<ngen::Subregister> binaryOffsets; // q/d |
1103 | std::vector<ngen::Subregister> binaryLDs; // d |
1104 | std::vector<std::array<ngen::Subregister, 2>> binaryStrides; // d |
1105 | std::vector<uint8_t> binarySurfaces; |
1106 | } inputs; |
1107 | Type Ta_load, Tb_load; // Current type to be loaded into A/B_regs. |
1108 | Type Tacc; // Current type in accumulator registers. |
1109 | ngen::Subregister persistentGroupID; // ud |
1110 | ngen::Subregister batchID[2]; // ud |
1111 | ngen::Subregister offsetA, offsetB, offsetC[2]; |
1112 | ngen::Subregister offsetAp, offsetBp, offsetCp; |
1113 | ngen::Subregister offsetCO; |
1114 | ngen::Subregister saveOffsetA, saveOffsetB, saveOffsetC[2]; |
1115 | ngen::Subregister saveOffsetCO; |
1116 | ngen::Subregister fullK; |
1117 | ngen::Subregister effA, effB, effC[2], |
1118 | effCO; // Offsets to base of A/B/C/CO chunks for loading/storing. |
1119 | ngen::Subregister effAi, effBi; |
1120 | ngen::Subregister effAo, effBo; |
1121 | ngen::Subregister effAp, effBp, effCp; |
1122 | ngen::Subregister effAs, effBs; |
1123 | std::vector<ngen::GRFRange> A_addrs, B_addrs, C_addrs[2]; |
1124 | std::vector<ngen::GRFRange> A_addrsRem, B_addrsRem; |
1125 | std::vector<ngen::GRFRange> Ai_addrs, Bi_addrs; |
1126 | std::vector<std::vector<ngen::GRFRange>> Ai_addrsK, Bi_addrsK; |
1127 | std::vector<ngen::GRFRange> Ai_addrsRem, Bi_addrsRem; |
1128 | std::vector<ngen::GRFRange> Ao_addrs, Bo_addrs; |
1129 | std::vector<ngen::GRFRange> Ap_addrs, Bp_addrs, Cp_addrs; |
1130 | std::vector<GRFMultirange> A_regs, B_regs, C_regs; |
1131 | GRFMultirange Ar_regs, Br_regs; // Repacked A/B registers. |
1132 | std::vector<GRFMultirange> Ai_regs, |
1133 | Bi_regs; // Incoming data to copy to SLM. |
1134 | std::vector<GRFMultirange> Ai_regsRem, Bi_regsRem; |
1135 | GRFMultirange Ao_regs, Bo_regs; // Outgoing data to copy to SLM. |
1136 | GRFMultirange Ao_regsRem, Bo_regsRem; |
1137 | GRFMultirange As_regs, Bs_regs; // A row sums/B column sums. |
1138 | GRFMultirange Ap_regs, Bp_regs, Cp_regs; // A/B/C prefetch registers. |
1139 | std::vector<MaskAssignment> AB_masks; |
1140 | ngen::GRFRange broadcast_regs; |
1141 | std::vector<ngen::GRFRange> tempMul_regs; |
1142 | ngen::Subregister i0, j0, h0; // d |
1143 | ngen::Subregister remainders[3]; // d (todo: w) |
1144 | ngen::Subregister remaindersFused[2]; // w |
1145 | ngen::Subregister remaindersWG[2]; // d (todo: w) |
1146 | ngen::Subregister remFusedStorage; // d |
1147 | ngen::Subregister diagC; // d |
1148 | SubregisterPair lda, ldb; |
1149 | SubregisterPair lda_ka, ldb_kb; // Cached lda * ka, ldb * kb |
1150 | SubregisterPair lda_ka_prefetch, |
1151 | ldb_kb_prefetch; // Cached lda * ka_pfStride, ldb * kb_pfStride |
1152 | LDMultiples ldaMultiples, ldbMultiples, ldcMultiples[2]; |
1153 | int ka_cached = 0, kb_cached = 0; // Multipliers for lda_ka/ldb_kb. |
1154 | ngen::Subregister k, K; // d |
1155 | ngen::FlagRegister flagAP; |
1156 | ngen::Subregister beta1; // d |
1157 | ngen::Subregister add64; // uw |
1158 | ngen::Subregister lidM, lidN, lidStorage; // uw, uw, ud |
1159 | ngen::Subregister lidK, lszK, lidszKStorage; // uw, uw, ud |
1160 | ngen::Subregister ia0_slm, jb0_slm; // uw |
1161 | ngen::Subregister postRemA, postRemB; // ud |
1162 | ngen::Subregister postRemAi, postRemBi; // ud |
1163 | ngen::Subregister postRemAo, postRemBo; // ud |
1164 | ngen::Subregister isCompute; // ud |
1165 | ngen::GRF sysSumAll1s; // Ta/Tb |
1166 | bool systolicSumA = false, systolicSumB = false; |
1167 | bool lateKLoopCheck = false; |
1168 | int ka_loadRem, kb_loadRem; |
1169 | bool Ai_hasKRem, Bi_hasKRem; |
1170 | bool Ai_lateKRem, Bi_lateKRem; |
1171 | bool Ai_incrementalRem, Bi_incrementalRem; |
1172 | bool Ai_remIncrCopy, Bi_remIncrCopy; |
1173 | int ma_slm, ka_slm, kb_slm, nb_slm; |
1174 | int ma_prefetch, ka_prefetch, kb_prefetch, nb_prefetch; |
1175 | CoopSplit effCoopA = CoopSplit::K; |
1176 | CoopSplit effCoopB = CoopSplit::K; |
1177 | std::vector<RegisterBlock> A_layout, B_layout, C_layout; |
1178 | std::vector<RegisterBlock> A_layoutRem, B_layoutRem; |
1179 | std::vector<RegisterBlock> Ar_layout, Br_layout; |
1180 | std::vector<RegisterBlock> Ai_layout, Bi_layout; |
1181 | std::vector<std::vector<RegisterBlock>> Ai_layoutK, Bi_layoutK; |
1182 | std::vector<RegisterBlock> Ai_layoutRem, Bi_layoutRem; |
1183 | std::vector<RegisterBlock> Ao_layout, Bo_layout; |
1184 | std::vector<RegisterBlock> As_layout, Bs_layout; |
1185 | std::vector<RegisterBlock> Ap_layout, Bp_layout, Cp_layout; |
1186 | std::vector<RegisterBlock> C_layoutExt, C_layoutExtUnmasked; |
1187 | Address2DParams A_params, B_params; |
1188 | Address2DParams Ai_params, Bi_params; |
1189 | Address2DParams Ap_params, Bp_params; |
1190 | int Ai_regCount = 0, Bi_regCount = 0; |
1191 | bool aioShare, bioShare; |
1192 | bool aioShareRem, bioShareRem; |
1193 | bool aoReuseA = false, boReuseB = false; |
1194 | MatrixAddressing Ai, Bi, Ao, Bo; |
1195 | MatrixAddressingStrategy Ai_strategy, Bi_strategy; |
1196 | MatrixAddressingStrategy Ao_strategy, Bo_strategy; |
1197 | MatrixAddressingStrategy Cext_strategy; |
1198 | int8_t tokenBarrierFence[2]; |
1199 | ngen::InstructionModifier modBarrierFence[2]; |
1200 | bool barrierReady = false; |
1201 | ngen::GRF ; |
1202 | ngen::GRF , ; |
1203 | ngen::FlagRegister barrierM, barrierN; |
1204 | bool firstKLoopSegment; |
1205 | bool isNested = false; |
1206 | int C_accCount; |
1207 | bool cSwapActive = false; |
1208 | int C_count = 1; |
1209 | int C_buffers = 1; |
1210 | bool allocedAo = false, allocedBo = false; |
1211 | bool allowEmptyC = false; |
1212 | bool copyC = false; |
1213 | bool broadcast; |
1214 | bool repackA = false, repackB = false; |
1215 | bool repackARem = false, repackBRem = false; |
1216 | int ka_repackRem, kb_repackRem; |
1217 | bool remActiveA, remActiveB, remActiveSLM; |
1218 | std::vector<MaskAssignment> kMasksSLM; |
1219 | bool slmRemaskA = false, slmRemaskB = false; |
1220 | bool slmASums = false, slmBSums = false; |
1221 | bool doLateExit = false; |
1222 | ngen::GRF emulate64TempSave[2]; |
1223 | |
1224 | std::vector<ngen::Subregister> effBinary; |
1225 | |
1226 | struct { |
1227 | bool active = false; |
1228 | uint8_t surfacePlan; |
1229 | ngen::Subregister plan; |
1230 | ngen::Subregister slotA, slotB; |
1231 | ngen::Subregister localIDFlat; |
1232 | ngen::FlagRegister needLateGEMMDone; |
1233 | } fusedGEMM; |
1234 | |
1235 | struct { |
1236 | ngen::InstructionModifier depAddr[4]; |
1237 | } sysgemm; |
1238 | |
1239 | GEMMState(ngen::HW hw) : CommonState(hw) {} |
1240 | }; |
1241 | |
// GEMM superkernel problem. Currently identical to a regular GEMM problem.
struct GEMMSuperkernelProblem : public GEMMProblem {};
1244 | |
// GEMM superkernel strategy parameters.
struct GEMMSuperkernelStrategy {
    std::vector<GEMMStrategy> substrategies; // Strategies for the subkernels.
    KernelScheduling schedule; // How subkernels are scheduled.
    bool multiM, multiN; // NOTE(review): set by driver before use; not default-initialized.
    bool persistent = false; // Use persistent thread model?

    // Driver-side validation/adjustment hook (defined out of line).
    void preflight(ngen::HW hw, const GEMMProblem &problem);
    // All substrategies share one subgroup size; report the first one's.
    int subgroupSize() const { return substrategies[0].subgroupSize; }
};
1255 | |
// GEMM superkernel state.
struct GEMMSuperkernelState : public GEMMState {
    // Kernel arguments specific to the superkernel entry point.
    struct {
        uint8_t surfacePlan; // Surface index for the plan buffer.
        ngen::Subregister planCount;
        ngen::GRF localID;
        ngen::Subregister localSize;
    } inputsSK;
    // Presumably the previous block's i0/j0/h0 coordinates -- confirm against uses.
    ngen::Subregister last_i0, last_j0, last_h0;

    GEMMSuperkernelState(ngen::HW hw) : GEMMState(hw) {}
};
1268 | |
// Copy kernel problem description: D <- alpha*S
struct CopyProblem : public CommonProblem {
    Type Ts, Td, Tsum; // Source/destination/sum data types.
    Scalar<double> alpha_real, alpha_imag; // Scaling factor alpha.
    MatrixAddressing S, D; // Source/destination matrix descriptions.
    bool conjugate = false; // Conjugate the source?
    bool lower; // NOTE(review): set by driver before use; not default-initialized.
    bool unit; // NOTE(review): set by driver before use; not default-initialized.
    bool trsm = false; // TRSM-related copy variant?
    bool sum = false; // Also compute sums of the copied data (Tsum)?
    int targetWG = 1; // Target workgroup size.

    // Reflection (symmetric/Hermitian) support: hard-coded to false here.
    bool reflecting() const { return false; }
};
1283 | |
// Strategy parameters for copy kernels.
struct CopyStrategy : public CommonStrategy {
    MatrixAddressingStrategy S, D; // Access strategies for source/destination.
    RemainderHandling remHandlingX,
            remHandlingY; // Remainder handling for X dimension (packed dimension) and Y dimension (length of panel)
    int s_load, d_load; // # of rows/columns to load from S/store to D at once
    int s_load_masked = 0,
        d_load_masked
            = 0; // Same as s_load/d_load, for use when masking (0 = default = same as {s,d}_load)
    int wgW = 0, wgZ = 0; // Fixed workgroup sizes (0 if variable).

    int unrollX, unrollY; // Unrolls for each dimension.
    bool duplicateAlpha
            = true; // True to make two copies of alpha, one for each register bank
    bool xLoop
            = false; // True to loop over x, false to loop over y within a kernel

    bool zParallel = false; // Kernel parallelized in z dimension?

    int barrierFreq = 0; // If > 0, set a barrier every barrierFreq loops
    int optionalAlignS
            = 0; // If > 0, generate code to check if S is aligned to this #elements and branch to specific code for that case.

    CopyStrategy() {}
    CopyStrategy(ngen::HW hw, int stepping = 0)
        : CommonStrategy(hw, stepping) {}

    // Driver-side validation/adjustment hook (defined out of line).
    void preflight(ngen::HW hw, const CopyProblem &problem);

    // Unroll in the w (non-loop) dimension.
    int unrollW() const { return xLoop ? unrollY : unrollX; }
    // Unroll in the z (loop) dimension.
    int unrollZ() const { return xLoop ? unrollX : unrollY; }
};
1316 | |
// State parameters for copy kernels.
struct CopyState : public CommonState {
    // Kernel arguments, as bound to registers (type noted in trailing comments).
    struct {
        ngen::Subregister S, D; // q
        ngen::Subregister offsetS, offsetD; // q
        ngen::Subregister lds, ldd; // d
        ngen::Subregister m, n; // d
        ngen::Subregister alpha_real; // T_real
        ngen::Subregister alpha_imag; // T_real
        ngen::Subregister groupIDW, groupIDZ; // ud
        ngen::GRF localIDW, localIDZ; // uw
        ngen::Subregister localSizeW, localSizeZ; // ud
        ngen::Subregister diag; // d
        ngen::Subregister blockZ; // ud
        uint8_t surfaceS, surfaceD; // DTS indices
    } inputs;
    ngen::Subregister w0, z0; // ud
    ngen::Subregister effS,
            effD; // Offsets to base of S/D chunks for loading/storing.
    ngen::Subregister offsetS1,
            effS1; // Reflected variants of offsetS/effS for symmetric/Hermitian.
    std::vector<ngen::GRFRange> S_addrs, D_addrs;
    std::vector<ngen::GRFRange> S_addrSrcs[2];
    ngen::GRFRange S_regs, D_regs;
    std::vector<ngen::GRFRange> Ds_regs; // Registers for sums of D data.
    ngen::Subregister lds_sl; // d
    ngen::Subregister ldd_dl; // d
    ngen::Subregister Z; // d
    ngen::FlagRegister flagAP, flagTri, flagDiag;
    ngen::FlagRegister flagReflect;
    std::vector<RegisterBlock> S_layout, D_layout;
    std::vector<RegisterBlock> Ds_layout;
    ngen::Subregister remainderX, remainderY; // ud
    ngen::GRFRange complexOne; // T_real
    ngen::GRF indexVecRT; // uw

    bool isNested; // NOTE(review): set by the generator before use; not default-initialized.

    struct {
        bool active = false; // Fused-GEMM mode; always false for copy kernels until set.
    } fusedGEMM;

    CopyState(ngen::HW hw) : CommonState(hw) {}

    void dump();
};
1363 | |
1364 | template <ngen::HW hw> |
1365 | class gemm_kernel_generator_t : public jit_generator<hw> { |
1366 | public: |
1367 | using super = ngen::OpenCLCodeGenerator<hw>; |
1368 | gemm_kernel_generator_t() {} |
1369 | |
1370 | NGEN_FORWARD_OPENCL(hw); |
1371 | |
1372 | using Injector = jit_post_op_injector<hw>; |
1373 | std::unique_ptr<Injector> postOpInjector; |
1374 | |
1375 | static bool supportedBinaryOp(alg_kind_t alg) { |
1376 | using namespace alg_kind; |
1377 | return utils::one_of(alg, binary_add, binary_sub, binary_mul, |
1378 | binary_div, binary_min, binary_max); |
1379 | } |
1380 | |
1381 | void gemm(GEMMProblem problem, GEMMStrategy strategy, |
1382 | const ngen::InterfaceHandler &interface_); |
1383 | void gemmSuperkernel(GEMMSuperkernelProblem problem, |
1384 | GEMMSuperkernelStrategy strategy, |
1385 | const ngen::InterfaceHandler &interface_); |
1386 | void copy(CopyProblem problem, CopyStrategy strategy, |
1387 | const ngen::InterfaceHandler &interface_); |
1388 | |
1389 | static CommonDriverInfo driverInfo( |
1390 | const GEMMProblem &problem, const GEMMStrategy &strategy); |
1391 | static CommonDriverInfo driverInfo(const GEMMSuperkernelProblem &problem, |
1392 | const GEMMStrategy &strategy); |
1393 | static CommonDriverInfo driverInfo( |
1394 | const CopyProblem &problem, const CopyStrategy &strategy); |
1395 | |
1396 | protected: |
1397 | ngen::InterfaceHandler |
1398 | &interface = ngen::OpenCLCodeGenerator<hw>::interface_; |
1399 | |
1400 | std::exception_ptr lastException; |
1401 | |
    // All diagnostic output goes to stderr.
    std::ostream &getOutStream() const { return std::cerr; }
1403 | |
    // Notes share the same stream as other diagnostic output.
    std::ostream &noteStream() const { return getOutStream(); }
1405 | |
    // Sink for status/debug logging. In this configuration all insertions are
    // discarded: operator<< ignores its argument and returns *this.
    class status_stream {
    protected:
        char cc; // ANSI color code digit ('0' + color).
        std::stringstream line;
        bool lineStart = true;

        gemm_kernel_generator_t<hw> &parent;

        friend class gemm_kernel_generator_t<hw>;

    public:
        status_stream(gemm_kernel_generator_t<hw> &parent_, int color = 1)
            : cc(color + '0'), parent(parent_) {}

        // End-of-line marker type, analogous to std::endl.
        static constexpr struct Endl {
        } endl {};

        // Discards the inserted value (logging disabled in this build).
        template <typename T>
        status_stream &operator<<(const T &obj) {
            return *this;
        }

        status_stream &operator<<(const Endl &e) { return *this; }
    } status {*this};
1430 | |
1431 | #ifdef SHOW_DISCARDS |
1432 | void discardStream() { |
1433 | InstructionStream *s = popStream(); |
1434 | auto oldCC = status.cc; |
1435 | status.cc = '4'; |
1436 | status << "------- \x1B[32mBEGIN\x1B[34m discarded stream -------" |
1437 | << status_stream::endl; |
1438 | auto &sbuffer = *reinterpret_cast<std::ostringstream *>(s->getBuffer()); |
1439 | auto str = sbuffer.str(); |
1440 | bool lastNL = false; |
1441 | for (int l = 0; l < str.length(); l++) { |
1442 | char c = str[l]; |
1443 | |
1444 | if (c == '\n') { |
1445 | if (lastNL) status << "//" ; |
1446 | status << status_stream::endl; |
1447 | lastNL = true; |
1448 | } else { |
1449 | status << c; |
1450 | lastNL = false; |
1451 | } |
1452 | } |
1453 | status << "------- \x1B[32mEND\x1B[34m discarded stream -------" |
1454 | << status_stream::endl; |
1455 | status.cc = status.cc; |
1456 | delete s; |
1457 | } |
1458 | #endif |
1459 | |
    // Register allocation hint categories; mapped to an ngen::Bundle by getHint().
    enum class HintType {
        Bank0,
        Bank1,
        TempComp0,
        TempComp1,
        LongTerm,
        LongTerm0,
        LongTerm1,
        R0Info,
        A0,
        A0Broadcast,
        A1,
        A1Broadcast,
        B0,
        B0Broadcast,
        B1,
        B1Broadcast,
        C,
        C1,
        CLoad,
        S,
        D,
        SAddr,
        DAddr
    };
    // How to handle C remainders on the standard path (see operator<< below for names).
    enum class StdCRemType { Ignore, Mask, Descriptor };
    // Which part of the C update to perform: load only, update, or update+store.
    enum class COperation { Load, Update, UpdateStore };
    // k-loop variants (only the standard GEMM k loop at present).
    enum class KLoop {
        GEMM,
    };
1490 | |
1491 | friend std::ostream &operator<<(std::ostream &s, StdCRemType rt) { |
1492 | const char *names[3] = {"ignore" , "mask" , "custom descriptor" }; |
1493 | return (s << names[static_cast<int>(rt)]); |
1494 | } |
1495 | |
1496 | ngen::FlagRegister getPhysicalFlag(VirtualFlag vflag, CommonState &state); |
1497 | void allocVFlagStorage(const CommonStrategy &strategy, CommonState &state); |
1498 | |
1499 | ngen::Bundle getHint(HintType type); |
1500 | ngen::Bundle getHint(HintType type, const CommonStrategy &strategy); |
1501 | ngen::Bundle getHint(HintType type, const GEMMStrategy &strategy); |
1502 | ngen::Bundle getHint(HintType type, const CopyStrategy &strategy); |
1503 | |
    // Convenience overload: conditional goto where JIP and UIP are the same label.
    void goto12(const ngen::InstructionModifier &mod, ngen::Label &jip) {
        goto12(mod, jip, jip);
    }
1507 | void goto12(const ngen::InstructionModifier &mod, ngen::Label &jip, |
1508 | ngen::Label &uip, bool branchCtrl = false); |
1509 | |
1510 | template <typename DT = void> |
1511 | void mulConstant(const ngen::InstructionModifier &mod, |
1512 | const ngen::RegData &dst, const ngen::RegData &src0, int32_t src1); |
1513 | |
1514 | friend struct EmulationImplementation; |
1515 | template <typename DT = void> |
1516 | void emov(const ngen::InstructionModifier &mod, ngen::RegData dst, |
1517 | ngen::RegData src0, const CommonStrategy &strategy, |
1518 | CommonState &state); |
    // Emulated mov with immediate source; delegates to EmulationImplementation.
    template <typename DT = void>
    void emov(const ngen::InstructionModifier &mod, ngen::RegData dst,
            ngen::Immediate src0, const CommonStrategy &strategy,
            CommonState &state) {
        EmulationImplementation::emov<DT>(
                *this, mod, dst, src0, strategy.emulate);
    }
1526 | template <typename DT = void> |
1527 | void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst, |
1528 | const ngen::RegData &src0, const ngen::RegData &src1, |
1529 | const CommonStrategy &strategy, CommonState &state); |
    // Emulated add with immediate src1; delegates to EmulationImplementation.
    template <typename DT = void>
    void eadd(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
            const ngen::RegData &src0, ngen::Immediate src1,
            const CommonStrategy &strategy, const CommonState &state) {
        EmulationImplementation::eadd<DT>(
                *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
    }
    // Emulated multiply (register src1); delegates to EmulationImplementation.
    template <typename DT = void>
    void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
            const ngen::RegData &src0, const ngen::RegData &src1,
            const CommonStrategy &strategy, const CommonState &state) {
        EmulationImplementation::emul<DT>(
                *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
    }
    // Emulated multiply (immediate src1); delegates to EmulationImplementation.
    template <typename DT = void>
    void emul(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
            const ngen::RegData &src0, ngen::Immediate src1,
            const CommonStrategy &strategy, const CommonState &state) {
        EmulationImplementation::emul<DT>(
                *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
    }
    // Emulated shift left; delegates to EmulationImplementation.
    template <typename DT = void>
    void eshl(const ngen::InstructionModifier &mod, ngen::RegData dst,
            ngen::RegData src0, uint16_t src1, const CommonStrategy &strategy,
            const CommonState &state) {
        EmulationImplementation::eshl<DT>(
                *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
    }
    // Emulated shift right; delegates to EmulationImplementation.
    template <typename DT = void>
    void eshr(const ngen::InstructionModifier &mod, ngen::RegData dst,
            ngen::RegData src0, uint16_t src1, const CommonStrategy &strategy,
            const CommonState &state) {
        EmulationImplementation::eshr<DT>(
                *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
    }
    // Emulated multiply by a compile-time constant; delegates to EmulationImplementation.
    template <typename DT = void>
    void emulConstant(const ngen::InstructionModifier &mod,
            const ngen::RegData &dst, const ngen::RegData &src0, int32_t src1,
            const CommonStrategy &strategy, const CommonState &state) {
        EmulationImplementation::emulConstant<DT>(
                *this, mod, dst, src0, src1, strategy.emulate, state.emulate);
    }
    // Emulated 32x32 multiply returning the high 32 bits; delegates to EmulationImplementation.
    template <typename S1>
    void emul32High(const ngen::InstructionModifier &mod,
            const ngen::RegData &dstHi, const ngen::RegData &src0,
            const S1 &src1) {
        EmulationImplementation::emul32High(*this, mod, dstHi, src0, src1);
    }
1578 | |
    // Emulation-aware multiply-add (emad) and three-source add (eadd3)
    // helpers; definitions are out of line. NOTE(review): operand roles
    // presumably follow the hardware mad/add3 conventions — confirm in the
    // implementations.
    template <typename S0, typename S2>
    void emad(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
            const S0 &src0, const ngen::RegData &src1, const S2 &src2,
            const CommonStrategy &strategy, CommonState &state);
    template <typename S0>
    void emad(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
            const S0 &src0, const ngen::RegData &src1, int32_t src2,
            const CommonStrategy &strategy, CommonState &state);
    template <typename DT = void, typename S0, typename S2>
    void eadd3(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
            const S0 &src0, const ngen::RegData &src1, const S2 &src2);
1590 | |
    // Extended-math dispatch (definition out of line) plus convenience
    // wrappers: einv issues an inverse (reciprocal), esqt a square root,
    // both by forwarding to emath with the corresponding MathFunction.
    template <typename DT = void>
    void emath(const ngen::InstructionModifier &mod, ngen::MathFunction fc,
            const ngen::RegData &dst, const ngen::RegData &src0,
            const GEMMStrategy &strategy, CommonState &state);
    template <typename DT = void>
    void einv(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
            const ngen::RegData &src0, const GEMMStrategy &strategy,
            CommonState &state) {
        emath<DT>(mod, ngen::MathFunction::inv, dst, src0, strategy, state);
    }
    template <typename DT = void>
    void esqt(const ngen::InstructionModifier &mod, const ngen::RegData &dst,
            const ngen::RegData &src0, const GEMMStrategy &strategy,
            CommonState &state) {
        emath<DT>(mod, ngen::MathFunction::sqt, dst, src0, strategy, state);
    }
1607 | |
    // Miscellaneous control-flow/sync helpers (definitions out of line):
    // predicated jump, compare-against-zero, and a full sync.allwr barrier
    // (exact semantics in the implementations).
    void ejmpi(ngen::InstructionModifier mod, ngen::Label &dst);

    void cmp0(const ngen::InstructionModifier &mod, ngen::RegData src0);
    void syncall();
1612 | |
1613 | void wrdepRanges(const std::vector<GRFMultirange> &rrs) { |
1614 | for (auto &rr : rrs) |
1615 | for (auto &r : rr.ranges) |
1616 | wrdep(r); |
1617 | } |
1618 | |
    // --- Integer arithmetic helpers (definitions out of line) ---
    // addScaled: add a value scaled by numerator/denominator; overloads cover
    // immediate/register combinations for src0/src1. NOTE(review): `exact`
    // presumably requires the scaling to be exact (no rounding) — confirm.
    void addScaled(const ngen::InstructionModifier &mod,
            const ngen::RegData &dst, int src0, const ngen::RegData &src1,
            int numerator, int denominator, CommonState &state,
            bool exact = false);
    void addScaled(const ngen::InstructionModifier &mod,
            const ngen::RegData &dst, const ngen::RegData &src0,
            const ngen::RegData &src1, int numerator, int denominator,
            CommonState &state, bool exact = false);
    void addScaled(const ngen::InstructionModifier &mod,
            const ngen::RegData &dst, const ngen::RegData &src0, int src1,
            int numerator, int denominator, CommonState &state,
            bool exact = false);

    // Modular arithmetic and alignment on subregisters: remainder (mod),
    // remainder + multiple (modExt), round down/up to an alignment, and
    // division by a constant or by a register with precomputed reciprocal.
    template <typename DT = void>
    void mod(const ngen::Subregister &dst, const ngen::Subregister &src,
            uint16_t modulus, const CommonStrategy &strategy,
            CommonState &state);
    template <typename DT = void>
    void modExt(const ngen::Subregister &dstMod,
            const ngen::Subregister &dstMultiple, const ngen::Subregister &src,
            uint16_t modulus, const CommonStrategy &strategy,
            CommonState &state);
    template <typename DT = void>
    void alignDown(const ngen::Subregister &dst, const ngen::Subregister &src,
            uint16_t align, const CommonStrategy &strategy, CommonState &state);
    template <typename DT = void>
    void alignUp(const ngen::Subregister &dst, const ngen::Subregister &src,
            uint16_t align, const CommonStrategy &strategy, CommonState &state);
    template <typename DT = void>
    void divDown(const ngen::Subregister &dst, const ngen::Subregister &src0,
            const ngen::Subregister &src1, const ngen::Subregister &src1Recip,
            const ngen::FlagRegister &flag, const CommonStrategy &strategy,
            CommonState &state);
    template <typename DT = void>
    void divDown(const ngen::Subregister &dst, const ngen::Subregister &src,
            uint16_t divisor, const CommonStrategy &strategy,
            CommonState &state);
1656 | |
    // --- Loop, barrier, and bookkeeping helpers (definitions out of line) ---
    void simtDoWhileLoop(
            const ngen::InstructionModifier &mod, ngen::Label &dest);
    // SLM-scoped and global-memory-scoped barriers; r0_info supplies the
    // header fields normally taken from r0.
    void slmBarrier(const ngen::GRF &temp, const ngen::GRF &r0_info = r0);
    void globalMemBarrier(const ngen::GRF &temp, const ngen::GRF &r0_info = r0);
    void pause(const CommonStrategy &strategy);

    // Scalar duplication management: keep (or release) a second copy of a
    // scalar value in a subregister pair.
    void duplicateScalar(SubregisterPair &val, CommonState &state);
    void deduplicateScalar(SubregisterPair &val, CommonState &state);
    template <typename T>
    void duplicateScalar(Scalar<T> &val, CommonState &state);
    MultishiftSubregister multishift(const ngen::Subregister &reg,
            unsigned shifts, const CommonStrategy &strategy, CommonState &state,
            ngen::Bundle hint = ngen::Bundle());

    // Fused-EU and r0/local-ID management.
    void getFusedID(int scale, const CommonProblem &problem,
            const CommonStrategy &strategy, CommonState &state);
    void moveR0(const CommonStrategy &strategy, CommonState &state);
    void moveR0(const GEMMStrategy &strategy, GEMMState &state);
    template <typename F>
    void useR0(CommonState &state, F f);
    void removeSG(const CommonProblem &problem, const CommonStrategy &strategy,
            const CommonState &state);
    void reorderFusedEUs(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    ngen::Subregister copySubregister(const ngen::Subregister &reg,
            CommonState &state,
            ngen::Bundle hint = ngen::Bundle(ngen::Bundle::any, 0));
    void zeroMatrix(const GRFMultirange &r, const CommonStrategy &strategy);
    void releaseFusedRemainders(GEMMState &state);
    void saveMNLocalIDs(const GEMMStrategy &strategy, GEMMState &state);
    void saveKLocalIDSize(const GEMMStrategy &strategy, GEMMState &state);
    void releaseSavedMNLocalIDs(GEMMState &state);

    // Workaround for read suppression; see implementation for conditions.
    void doReadSuppressionWA(
            const CommonStrategy &strategy, CommonState &state);
1692 | |
    // --- Register layout creation and manipulation ---
    // These compute/adjust RegisterBlock layouts describing how matrix tiles
    // map to GRFs. Boolean returns indicate success (false = layout not
    // realizable with the given constraints). Definitions out of line.
    bool getBlockInfo(Type T, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy, int r, int c,
            bool remainderR, bool remainderC, bool writable, bool avoidFragment,
            int maxRBlock, int maxCBlock, int &rblock, int &cblock,
            RegisterBlock &layout);
    bool getSubblock(Type T, RegisterBlock &blockDst,
            const RegisterBlock &blockSrc, bool column, int x1, int x2,
            int x1Unclamped, int x2Unclamped, bool overrunOK,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    bool getSubblocks(Type T, std::vector<RegisterBlock> &sublayout,
            const std::vector<RegisterBlock> &layout, bool column, int x1,
            int x2, bool overrunOK, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    bool getSubblocks(Type T, std::vector<RegisterBlock> &sublayout,
            std::vector<ngen::GRFRange> *subaddrs, std::vector<int> *indices,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> *addrs, bool column, int x1,
            int x2, bool overrunOK, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    bool getSubblocks(Type T, std::vector<RegisterBlock> &sublayout,
            std::vector<ngen::GRFRange> &subaddrs,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs, bool column, int x1,
            int x2, bool overrunOK, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    bool getSubblocks(Type T, std::vector<RegisterBlock> &sublayout,
            std::vector<int> &indices, const std::vector<RegisterBlock> &layout,
            bool column, int x1, int x2, bool overrunOK,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    bool reblockLayout(Type Tdst, std::vector<int32_t> &blockMap,
            std::vector<RegisterBlock> &layoutDst,
            const std::vector<RegisterBlock> &layoutRef,
            const std::vector<RegisterBlock> &layoutSrc,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);

    // Masking: try/force addition of row/column remainder masking to an
    // existing layout (try* variants return false rather than fail hard).
    bool tryAddMasking(Type T, RegisterBlock &block, bool remainderR,
            bool remainderC, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    bool tryAddMasking(Type T, std::vector<RegisterBlock> &layout,
            bool remainderR, bool remainderC, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    void addMasking(Type T, std::vector<RegisterBlock> &layout, bool remainderR,
            bool remainderC, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    void addMasking(Type T, std::vector<RegisterBlock> &layout,
            std::vector<ngen::GRFRange> &addrs, const ngen::Subregister &ld,
            bool remainderR, bool remainderC, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state,
            int dataRegs = -1);
    void adjustSubblockAddrs(Type T,
            const std::vector<RegisterBlock> &sublayout,
            const std::vector<ngen::GRFRange> &subaddrs,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, const CommonState &state);

    // Layout construction entry points.
    bool addToRegLayout(Type T, std::vector<RegisterBlock> &layout, int r,
            int c, int roff, int coff, bool remainderR, bool remainderC,
            bool writable, bool avoidFragment, int maxRBlock, int maxCBlock,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    bool add1DBlockToRegLayout(Type T, std::vector<RegisterBlock> &layout,
            int r, int c, bool writable, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
    bool getRegLayout(Type T, std::vector<RegisterBlock> &layout, int r, int c,
            bool remainderR, bool remainderC, bool writable, bool avoidFragment,
            int maxRBlock, int maxCBlock, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            bool reverseOrder = false);
    void makeUnbackedRegLayout(Type T, std::vector<RegisterBlock> &layout,
            int r, int c, bool colMajor, int crosspack = 1, int tileR = 0,
            int tileC = 0, bool allowPartialRegs = true,
            bool fullySplitCx = false);
    bool upgradeLayoutToBlock2D(Type T,
            const std::vector<RegisterBlock> &layoutSrc,
            std::vector<RegisterBlock> &layout2D, bool remainderR,
            bool remainderC, bool writable, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy);
1777 | |
    // --- Matrix load/store/atomic-update emission ---
    // Emit the actual memory instructions for a layout: descriptor setup,
    // loads, prefetches, stores, and atomic accumulation. Definitions out
    // of line.
    void setupTeardownLoadStoreDesc(bool setup,
            const std::vector<RegisterBlock> &layout,
            const CommonStrategy &strategy, CommonState &state);
    void loadLoadStoreDescriptors(bool load, bool store, RegisterBlock &block,
            const ngen::Subregister &count, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);

    // LSC (load/store cache) message descriptor construction.
    static ngen::DataSpecLSC getDataSpecLSC(
            AccessType access, const RegisterBlock &block);
    static ngen::DataSpecLSC getDataSpecLSC(const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const RegisterBlock &block, bool write);
    ngen::InstructionModifier getRegisterBlockMask(
            const RegisterBlock &block, CommonState &state);
    void loadMatrixBlock(const ngen::Register &dest,
            const RegisterBlock &layout, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const ngen::GRFRange &addr, const CommonStrategy &strategy,
            CommonState &state, bool zeroMask = false);
    void loadMatrix(const GRFMultirange &dest,
            const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const std::vector<ngen::GRFRange> &addrs,
            const CommonStrategy &strategy, CommonState &state,
            bool zeroMask = false);
    void prefetchMatrix(const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const std::vector<ngen::GRFRange> &addrs,
            const CommonStrategy &strategy, CommonState &state);
    void storeMatrixBlock(const ngen::GRF &src, const RegisterBlock &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const ngen::GRFRange &addr, const CommonStrategy &strategy,
            CommonState &state);
    void storeMatrix(const GRFMultirange &src,
            const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const std::vector<ngen::GRFRange> &addrs,
            const CommonStrategy &strategy, CommonState &state);
    void atomicAddMatrixBlock(Type T, const ngen::GRF &src,
            const RegisterBlock &layout, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const ngen::GRFRange &addr, const CommonProblem &problem,
            const CommonStrategy &strategy, CommonState &state);
    void atomicAddMatrix(Type T, const GRFMultirange &src,
            const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const std::vector<ngen::GRFRange> &addrs,
            const CommonProblem &problem, const CommonStrategy &strategy,
            CommonState &state);
1833 | |
    // --- Mask assignment and remainder remasking ---
    // Assign flag registers / virtual masks to layout blocks and load their
    // values from remainder indices. Definitions out of line.
    bool assignMasks(std::vector<RegisterBlock> &layout, LoopType rloop,
            LoopType cloop, std::vector<MaskAssignment> &assignments,
            const CommonStrategy &strategy, CommonState &state,
            bool retryVirtual = false);
    void loadMask(MaskAssignment assignment, ngen::Subregister index,
            const CommonStrategy &strategy, CommonState &state, int offset = 0);
    void loadMasks(const std::vector<MaskAssignment> &assignments,
            ngen::Subregister (&indices)[3], const CommonStrategy &strategy,
            CommonState &state, int start = 0);
    void loadMasks(const std::vector<MaskAssignment> &assignments,
            ngen::Subregister (&indices)[3], int (&offsets)[3],
            const CommonStrategy &strategy, CommonState &state, int start = 0);

    void setupTeardownRemask(Type T, int index, bool setup, int nq,
            const ngen::Subregister &remQ, const CommonStrategy &strategy,
            CommonState &state, int fixedOffQ = 0,
            const ngen::Subregister &variableOffQ = ngen::Subregister());
    void remaskLayout(Type T, int index, bool column,
            const std::vector<RegisterBlock> &layout, const GRFMultirange &regs,
            const CommonStrategy &strategy, CommonState &state, int offset = 0);

    // Apply remainder information directly to block addresses.
    void setAddrRemainder(Type T, const ngen::GRFRange &addr,
            const RegisterBlock &block, const ngen::Subregister &remR,
            const ngen::Subregister &remC, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);
    void setAddrRemainder(Type T, const std::vector<ngen::GRFRange> &addr,
            const std::vector<RegisterBlock> &layout,
            const ngen::Subregister &remR, const ngen::Subregister &remC,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);
1866 | |
    // --- Shifted-pointer helpers ---
    // startShift obtains a copy of a base pointer shifted right by `shift`;
    // doneShift releases any temporary it allocated. The enable_if pairs
    // split register-backed from immediate-like base-offset (BO) types.
    ngen::Subregister startShift(
            const MultishiftSubregister &ptr, int shift, CommonState &state);
    SubregisterPair startShift(
            const SubregisterPair &ptr, int shift, CommonState &state);
    template <typename BO>
    typename std::enable_if<!std::is_base_of<ngen::RegData, BO>::value,
            BO>::type
    startShift(const BO &ptr, int shift, CommonState &state);
    template <typename BO>
    typename std::enable_if<std::is_base_of<ngen::RegData, BO>::value, BO>::type
    startShift(const BO &ptr, int shift, CommonState &state);
    template <typename BO, typename BI>
    typename std::enable_if<!std::is_base_of<ngen::RegData, BO>::value>::type
    doneShift(
            const BO &ptr, const BI &ptrShifted, int shift, CommonState &state);
    template <typename BO, typename BI>
    typename std::enable_if<std::is_base_of<ngen::RegData, BO>::value>::type
    doneShift(
            const BO &ptr, const BI &ptrShifted, int shift, CommonState &state);
    void doneShift(const SubregisterPair &ptr,
            const SubregisterPair &ptrShifted, int shift, CommonState &state);
1888 | |
    // --- Address setup and increment helpers ---
    // Build address registers for layouts and advance them between loads.
    // The incAddrShifted variants take pre-shifted increments; incDecAddr
    // optionally negates the increment (decrement = true). Definitions out
    // of line.
    void offsetAddr(const ngen::GRFRange &addrDst,
            const ngen::GRFRange &addrSrc, const RegisterBlock &blockDst,
            const RegisterBlock &blockSrc, int offsetFixed, int offsetLD,
            const ngen::Subregister &ld, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state,
            const LDMultiples &ldMultiples = {});
    void setupAddrRel(Type T, const ngen::GRFRange &addrDst,
            const ngen::GRFRange &addrSrc, const RegisterBlock &blockDst,
            const RegisterBlock &blockSrc,
            const std::vector<RegisterBlock> &layout,
            const ngen::Subregister &ld, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state,
            const LDMultiples &ldMultiples = {});
    template <typename BO>
    void setupAddr(const ngen::GRFRange &addr, const BO &ptr,
            const RegisterBlock &layout, const ngen::Subregister &ld,
            size_t sizeofT, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state,
            const Address2DParams &params = {}, LDMultiples ldMultiples = {});
    template <typename BO>
    void setupAddr(Type T, const std::vector<ngen::GRFRange> &addr,
            const BO &ptr, const std::vector<RegisterBlock> &layout,
            const ngen::Subregister &ld, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state,
            const Address2DParams &params = {},
            const LDMultiples &ldMultiples = {});
    template <typename I, typename Ir, typename Ic>
    void incAddrShifted(const ngen::GRFRange &addrDst,
            const ngen::GRFRange &addrSrc, I inc, Ir incR, Ic incC,
            const RegisterBlock &layoutDst, const RegisterBlock &layoutSrc,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);
    template <typename I, typename Ir, typename Ic>
    void incAddrShifted(const std::vector<ngen::GRFRange> &addr, I inc, Ir incR,
            Ic incC, const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);
    template <typename I>
    void incAddrShifted(const std::vector<ngen::GRFRange> &addr, I inc,
            const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);
    template <typename I, typename Ir, typename Ic>
    void incAddr(const ngen::GRFRange &addrDst, const ngen::GRFRange &addrSrc,
            I inc, Ir incR, Ic incC, const RegisterBlock &layoutDst,
            const RegisterBlock &layoutSrc, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);
    template <typename I, typename Ir, typename Ic>
    void incAddr(const std::vector<ngen::GRFRange> &addr, I inc, Ir incR,
            Ic incC, const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);
    template <typename I>
    void incAddr(const ngen::GRFRange &addrDst, const ngen::GRFRange &addrSrc,
            I inc, const RegisterBlock &layoutDst,
            const RegisterBlock &layoutSrc, const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);
    template <typename I>
    void incAddr(const std::vector<ngen::GRFRange> &addr, I inc,
            const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state);
    template <typename A, typename I, typename Ir, typename Ic>
    void incDecAddr(const A &addr, I inc, Ir incR, Ic incC,
            const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state, bool decrement);
    template <typename A, typename I>
    void incDecAddr(const A &addr, I inc,
            const std::vector<RegisterBlock> &layout,
            const MatrixAddressing &atype,
            const MatrixAddressingStrategy &astrategy,
            const CommonStrategy &strategy, CommonState &state, bool decrement);
1974 | |
    // Index vector maintenance and precomputed leading-dimension multiples
    // (LDMultiples), plus C address-register initialization. Definitions out
    // of line.
    void extendIndexVec(int n, CommonState &state);
    ngen::Subregister accessIndexVec(int n, CommonState &state);

    LDMultiples createLDMultiples(bool a64, int nmultiples,
            const ngen::Subregister &ld, const CommonStrategy &strategy,
            CommonState &state);
    ngen::Subregister findLDMultiple(const LDMultiples &multiples, bool a64,
            int n, const CommonStrategy &strategy, CommonState &state);

    void setupCAddr0(ngen::GRFRange (&C_addr0)[2],
            ngen::GRFRange (&C_addr0Unmasked)[2],
            const std::vector<RegisterBlock> &C_layout,
            const std::vector<RegisterBlock> &C_layoutUnmasked, int C_count,
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state,
            const Address2DParams *params = nullptr);
1990 | |
    // --- Outer product (inner-kernel multiply) emission ---
    // outerProduct dispatches to the Gen9 IGEMM or systolic variants as
    // appropriate. Definitions out of line.
    void outerProductGen9IGEMM(int ha, int hb,
            const std::vector<RegisterBlock> &A_layout,
            const std::vector<RegisterBlock> &B_layout,
            const GRFMultirange &A_regs, const GRFMultirange &B_regs,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void outerProductSystolic(int h, int ha, int hb,
            const std::vector<RegisterBlock> &A_layout,
            const std::vector<RegisterBlock> &B_layout,
            const GRFMultirange &A_regs, const GRFMultirange &B_regs,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void outerProduct(int h, int ha, int hb, int opCount,
            const std::vector<RegisterBlock> &A_layout,
            const std::vector<RegisterBlock> &B_layout,
            const GRFMultirange &A_regs, const GRFMultirange &B_regs,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void setupTeardownAccumulateSumSystolic(bool setup, Type Tother,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
2012 | |
    // --- C update and row/column sum accumulation ---
    // Apply alpha/beta update to C, including remainder-handling code paths
    // (standard and alternate), and accumulate A/B sums. NOTE(review):
    // doStdCRemainder takes GEMMState by value — presumably intentional so
    // remainder paths can fork state; confirm before changing.
    void updateC(const GRFMultirange &C_acc, const GRFMultirange &C_accSwap,
            const GRFMultirange &C_load, GEMMProblem &problem,
            GEMMStrategy &strategy, GEMMState &state);
    void updateCLayout(const std::vector<RegisterBlock> &layoutExt,
            const ngen::GRFRange (&C_addr0)[2], const RegisterBlock &C_block0,
            COperation op, GEMMProblem &problem, GEMMStrategy &strategy,
            GEMMState &state);
    bool doStdCRemainder(std::vector<RegisterBlock> &layoutExt,
            std::vector<RegisterBlock> &layoutExtUnmasked, bool inside,
            bool columns[2], StdCRemType remTypes[2], bool fragments[2],
            bool fragPositives[2], int fragSizes[2],
            const ngen::GRFRange (&C_addr0)[2],
            const ngen::GRFRange (&C_addr0Unmasked)[2], COperation op,
            std::vector<MaskAssignment> &masks, GEMMProblem &problem,
            GEMMStrategy &strategy, GEMMState state);
    void doAlternateCRemainder(COperation op, GEMMProblem &problem,
            GEMMStrategy &strategy, GEMMState &state);

    void accumulateSum(bool column, Type Tsrc, const GRFMultirange &srcRegs,
            const std::vector<RegisterBlock> &srcLayout, Type Tdst,
            const GRFMultirange &dstRegs,
            const std::vector<RegisterBlock> &dstLayout,
            const CommonStrategy &strategy, CommonState &state, int q0 = -1,
            int q1 = -1);
    void makeSumLayout(bool column, Type Tsrc,
            const std::vector<RegisterBlock> &srcLayout, Type Tdst,
            std::vector<RegisterBlock> &dstLayout,
            const CommonStrategy &strategy, CommonState &state);
    void horizontalAdd(bool column, Type T, const GRFMultirange &regs,
            std::vector<RegisterBlock> &layout, CommonState &state);
    bool gemmFinalizeSums(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
2045 | |
    // Effective cooperative-split direction for A/B SLM loading.
    CoopSplit effCoopSplitA(
            const GEMMProblem &problem, const GEMMStrategy &strategy);
    CoopSplit effCoopSplitB(
            const GEMMProblem &problem, const GEMMStrategy &strategy);

    // --- C conversion, scaling, and binary-op application ---
    // In-place type conversion of accumulators, beta scaling, and scalar/
    // vector/matrix binary operations on C (used for bias, C offset, and
    // binary post-ops). Definitions out of line.
    void convert(const GRFMultirange &range, Type Told, Type Tnew,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    bool gemmConvertC(Type Tnew, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmBetaScale(
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
    void binaryOp(BinaryOp op, int simd, const ngen::RegData &dst,
            const ngen::RegData &src0, const ngen::RegData &src1);
    void gemmScalarBinaryOpC(BinaryOp op, const ngen::Subregister &offset,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void gemmVectorBinaryOpC(BinaryOp op, bool column,
            const GRFMultirange &offsets, const ngen::Subregister &scale,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state, Type Tco = Type::invalid,
            std::vector<RegisterBlock> CO_layout = std::vector<RegisterBlock>(),
            int y0 = -1, int y1 = -1);
    void gemmCalcABOffsetAddrs(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    bool gemmLoadABOffset(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmApplyABOffset(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmUpdateSums(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    bool gemmBinaryOpC(BinaryOp op, bool row, bool column, Type Tco,
            MatrixAddressing CO, MatrixAddressingStrategy CO_strategy,
            ngen::Subregister base, ngen::Subregister ld,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    bool gemmApplyCOffsetDispatch(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmKReduce(const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void gemmPrefetchC(const GEMMProblem &problem, GEMMStrategy &strategy,
            GEMMState &state);

    // Post-op application over the [poMin, poMax) range of the post-op chain.
    void gemmApplyPostOps(int poMin, int poMax, const GEMMProblem &problem,
            GEMMStrategy &strategy, GEMMState &state);
    void gemmLoadBinaryOpArgs(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
2093 | |
2094 | void gemmAllocRegs( |
2095 | GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state); |
2096 | void gemmAllocAoBoRegs(const GEMMStrategy &strategy, GEMMState &state); |
2097 | void gemmAIncrementInternal(Type Ta, |
2098 | const std::vector<RegisterBlock> &layout, |
2099 | const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A, |
2100 | const MatrixAddressingStrategy &A_strategy, int ka_inc, |
2101 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
2102 | GEMMState &state, int ha = 0); |
2103 | void gemmAIncrementInternal(Type Ta, |
2104 | const std::vector<RegisterBlock> &layout, |
2105 | const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A, |
2106 | const MatrixAddressingStrategy &A_strategy, |
2107 | const MultishiftSubregister &ka_inc, const GEMMProblem &problem, |
2108 | const GEMMStrategy &strategy, GEMMState &state, int ha = 0); |
2109 | void gemmAIncrementInternal(Type Ta, |
2110 | const std::vector<RegisterBlock> &layout, |
2111 | const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A, |
2112 | const MatrixAddressingStrategy &A_strategy, |
2113 | const ngen::Subregister &ka_inc, const GEMMProblem &problem, |
2114 | const GEMMStrategy &strategy, GEMMState &state, int ha = 0); |
2115 | template <typename I> |
2116 | void gemmAIncrement(Type Ta, const std::vector<RegisterBlock> &layout, |
2117 | const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A, |
2118 | const MatrixAddressingStrategy &A_strategy, I ka_inc, |
2119 | const GEMMProblem &problem, const GEMMStrategy &strategy, |
2120 | GEMMState &state, int ha = 0); |
    // --- A/B panel load and address-increment helpers (k-loop) -------------
    // Load the A panel described by 'layout' from the addresses in 'addrs'
    // into 'regs'.
    void gemmALoad(const GRFMultirange &regs,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A,
            const MatrixAddressingStrategy &A_strategy,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    // Combined increment-by-ka_inc + load. I is the increment operand type
    // (int / MultishiftSubregister / ngen::Subregister, matching the
    // increment overloads used elsewhere in this class).
    template <typename I>
    void gemmALoadInc(Type Ta, const GRFMultirange &regs,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &A,
            const MatrixAddressingStrategy &A_strategy, I ka_inc,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    // B-side address-increment worker, overloaded on the kb_inc operand:
    // compile-time int, MultishiftSubregister, or runtime Subregister.
    void gemmBIncrementInternal(Type Tb,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
            const MatrixAddressingStrategy &B_strategy, int kb_inc,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state, int hb = 0);
    void gemmBIncrementInternal(Type Tb,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
            const MatrixAddressingStrategy &B_strategy,
            const MultishiftSubregister &kb_inc, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int hb = 0);
    void gemmBIncrementInternal(Type Tb,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
            const MatrixAddressingStrategy &B_strategy,
            const ngen::Subregister &kb_inc, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int hb = 0);
    // Generic front end; I matches one of the overloads above.
    template <typename I>
    void gemmBIncrement(Type Tb, const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
            const MatrixAddressingStrategy &B_strategy, I kb_inc,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state, int hb = 0);
    // B-side counterparts of gemmALoad/gemmALoadInc.
    void gemmBLoad(const GRFMultirange &regs,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
            const MatrixAddressingStrategy &B_strategy,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    template <typename I>
    void gemmBLoadInc(Type Tb, const GRFMultirange &regs,
            const std::vector<RegisterBlock> &layout,
            const std::vector<ngen::GRFRange> &addrs, const MatrixAddressing &B,
            const MatrixAddressingStrategy &B_strategy, I kb_inc,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    // Remainder-path load/increment for the Xi/Xo copy operands; doA selects
    // the A side vs. the B side. NOTE(review): kSLMX suggests this is the
    // SLM copy path -- confirm against the implementation.
    template <bool doA>
    void gemmAiBiRemLoadInc(bool incremental, bool incrementalCopy,
            bool keepAddrTogether, bool willRemask,
            const ngen::Subregister &kSLMX, const GRFMultirange &Xi_regs,
            const std::vector<RegisterBlock> &Xi_layout,
            const std::vector<ngen::GRFRange> &Xi_addrs,
            const std::vector<std::vector<RegisterBlock>> &Xi_layoutK,
            const std::vector<std::vector<ngen::GRFRange>> &Xi_addrsK,
            const GRFMultirange &Xo_regs,
            const std::vector<RegisterBlock> &Xo_layout,
            const MatrixAddressing &Xi,
            const MatrixAddressingStrategy &Xi_strategy,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    // --- Increment/offset bookkeeping --------------------------------------
    SubregisterPair allocIncrement(
            const GEMMStrategy &strategy, CommonState &state);
    // Precompute the k-loop address increments (ka_load/kb_load of 0 keeps
    // the defaults; doA/doB restrict the computation to one side).
    void gemmCalcIncrements(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int ka_load = 0,
            int kb_load = 0, bool doA = true, bool doB = true);
    // Per-workitem offsets into A (ma x ka tile) and B (kb x nb tile);
    // results returned through off/offR/offC.
    void gemmCalcWorkshareAOffset(ngen::Subregister &off,
            ngen::Subregister &offR, ngen::Subregister &offC,
            const MatrixAddressing &A,
            const MatrixAddressingStrategy &A_strategy, int ma, int ka,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void gemmCalcWorkshareBOffset(ngen::Subregister &off,
            ngen::Subregister &offR, ngen::Subregister &offC,
            const MatrixAddressing &B,
            const MatrixAddressingStrategy &B_strategy, int kb, int nb,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    bool gemmPrepMaskedAB(const GEMMProblem &problem, GEMMStrategy &strategy,
            GEMMState &state);
    // Re-mask the Ao/Bo register panels starting at kOffset.
    void gemmSLMRemask(bool remaskA, bool remaskB, GRFMultirange &Ao_regs,
            GRFMultirange &Bo_regs, int kOffset, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);

    // Barrier/SLM bookkeeping for the k loop.
    void gemmCalcKLoopBarrierCount(ngen::Subregister &count,
            const ngen::Subregister &k, int cooldown,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void gemmCalcKSLM(const ngen::Subregister &kSLM,
            const ngen::Subregister &lid, int kgran, int kdiv, int krep,
            const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
2216 | void (GEMMState &state); |
2217 | ngen::GRF (GEMMState &state); |
    // --- k-loop driver -----------------------------------------------------
    void kLoop(KLoop type, const GEMMProblem &problem, GEMMStrategy &strategy,
            GEMMState &state);
    bool kLoopSetup(const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    template <typename I>
    void kLoopReset(const I &kOffset, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void kLoopTeardown(const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    bool kLoopSingle(KLoop type, const GEMMProblem &problem,
            GEMMStrategy &strategy, GEMMState &state);
    bool gemmKLoop(
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
    // --- C accumulation and update -----------------------------------------
    bool gemmAccumulateC(
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
    bool gemmAccumulateCSetup(
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
    void gemmAccumulateCTeardown(
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
    bool gemmAccessC(COperation op, GEMMProblem &problem,
            GEMMStrategy &strategy, GEMMState &state);
    bool gemmUpdateC(
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);

    // gemmBody takes its arguments by value (works on private copies);
    // gemmBodyInternal mutates the caller's copies in place.
    bool gemmBody(GEMMProblem problem, GEMMStrategy strategy, GEMMState state);
    bool gemmBodyInternal(
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);

    // --- m/n remainder (edge) handling -------------------------------------
    // 'func' is the member function generating the kernel body for the
    // remainder case (called with by-value copies, matching gemmBody).
    bool wgRemCheck(const GEMMProblem &problem, const GEMMStrategy &strategy);
    template <typename Problem>
    bool mnRemainderHandling(LoopType loop, Problem &problem,
            GEMMStrategy &strategy, GEMMState &state,
            bool (gemm_kernel_generator_t<hw>::*func)(
                    Problem, GEMMStrategy, GEMMState));
    template <typename Problem>
    bool mnJointSplitRemainderHandling(Problem &problem, GEMMStrategy &strategy,
            GEMMState &state,
            bool (gemm_kernel_generator_t<hw>::*func)(
                    Problem, GEMMStrategy, GEMMState));
    bool gemmMEdge(
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
    bool gemmNEdge(GEMMProblem problem, GEMMStrategy strategy, GEMMState state);
2260 | |
    // --- Work-group walk order ---------------------------------------------
    void gemmHilbertlikeOrder(const GEMMProblem &problem,
            GEMMStrategy &strategy, GEMMState &state);
    void gemmBoustrophedonOrder(const GEMMProblem &problem,
            GEMMStrategy &strategy, GEMMState &state);
    void gemmReorderLocalIDs(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);

    // --- Offset / batch / setup helpers ------------------------------------
    void gemmCheck32(const GEMMProblem &problem, GEMMStrategy &strategy,
            GEMMState &state);
    void gemmGetBatchIDs(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmReleaseBatchIDs(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    // gemmOffsetAk/Bk are overloaded on a compile-time (int) vs. runtime
    // (Subregister) offset h.
    void gemmOffsetAk(int h, const ngen::Subregister &effA,
            const MatrixAddressing &globalA, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmOffsetAk(const ngen::Subregister &h, const ngen::Subregister &effA,
            const MatrixAddressing &globalA, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmOffsetBk(int h, const ngen::Subregister &effB,
            const MatrixAddressing &globalB, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmOffsetBk(const ngen::Subregister &h, const ngen::Subregister &effB,
            const MatrixAddressing &globalB, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmFoldOffsets(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmRestoreOffsets(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmOffsetABC(bool initial, ngen::Subregister i0, ngen::Subregister j0,
            ngen::Subregister h0, const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, bool doA = true,
            bool doB = true, bool doC = true, bool doBinary = false);
    void gemmOffsetBatchABC(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmSetupABC(const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void gemmCacheLDABMultiples(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmCacheLDCMultiples(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state,
            bool prefetch = false);
    void gemmScaleInputs(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmReverseLoops(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void gemmDowngradeAccess(const GEMMProblem &problem, GEMMStrategy &strategy,
            GEMMState &state);
    // --- Entry points and SLM sizing (static: no generator state needed) ---
    void gemmSubkernel(
            GEMMProblem &problem, GEMMStrategy &strategy, GEMMState state);
    static size_t gemmSLMSize(
            const GEMMProblem &problem, const GEMMStrategy &strategy);
    static size_t gemmPerKSLMSize(
            const GEMMProblem &problem, const GEMMStrategy &strategy);
    void gemmInitInterface(GEMMProblem &problem, GEMMStrategy &strategy,
            GEMMState &state, bool inSK = false);
    void gemmInitState(GEMMProblem &problem, GEMMStrategy &strategy,
            GEMMState &state, bool inSK = false);
    void gemm(GEMMProblem &problem, GEMMStrategy &strategy, GEMMState &state);
2320 | |
    void gemmSuperkernelInitState(GEMMSuperkernelProblem &problem,
            GEMMSuperkernelStrategy &strategy, GEMMSuperkernelState &state);

    // --- sysgemm helpers ---------------------------------------------------
    bool sysgemmAccumulateC(GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void sysgemmKLoop(const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void sysgemmKLoop4(const GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state, bool oddB);
    void sysgemmStoreSignal(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state,
            bool forceFence = false);
    void sysgemmCopyLoad(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
            bool useC = false);
    void sysgemmCopyLoad4(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
            bool loadB, int useC = 0,
            ngen::RegData flagLoadB = ngen::RegData());
    void sysgemmCopyStore(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
            bool first = false);
    void sysgemmCopyStore4(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int storeBuffer,
            bool storeB, int useC = 0, int useC_B = 0);
    void sysgemmMultiply(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int buffer,
            bool lastMultiply = false);
    void sysgemmMultiply4(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int buffer,
            bool firstMultiply = false,
            ngen::RegData flagWaitLoad = ngen::RegData(),
            ngen::RegData flagSignal = ngen::RegData(),
            ngen::Label *labelDone = nullptr);
    void sysgemmMultiplyChunk(const GEMMProblem &problem,
            const GEMMStrategy &strategy, bool first, int ao, int i0,
            bool waitB, bool prepB,
            const ngen::InstructionModifier &swsb0
            = ngen::InstructionModifier(),
            const ngen::InstructionModifier &swsbEnd
            = ngen::InstructionModifier());
    void sysgemmBarrierPrep(
            const ngen::InstructionModifier &swsb, const ngen::GRF &);
    void sysgemmReorderLocalIDs(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);

    // --- sysgemm2 helpers --------------------------------------------------
    bool sysgemm2AccumulateC(GEMMProblem &problem, const GEMMStrategy &strategy,
            GEMMState &state);
    void sysgemm2KLoopCompute(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void sysgemm2KLoopCopy(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state);
    void sysgemm2Multiply(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int buffer,
            bool cooldown = false,
            ngen::FlagRegister flagWaitLoad = ngen::FlagRegister(),
            ngen::FlagRegister flagSignal = ngen::FlagRegister());
    void sysgemm2MultiplyX32(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int buffer,
            bool cooldown = false,
            ngen::FlagRegister flagWaitLoad = ngen::FlagRegister(),
            ngen::FlagRegister flagSignal = ngen::FlagRegister());
    void sysgemm2MultiplyX48(const GEMMProblem &problem,
            const GEMMStrategy &strategy, GEMMState &state, int buffer,
            bool cooldown = false,
            ngen::FlagRegister flagWaitLoad = ngen::FlagRegister(),
            ngen::FlagRegister flagSignal = ngen::FlagRegister());
    void sysgemm2MultiplyChunkX32(const GEMMProblem &problem,
            const GEMMStrategy &strategy, int chunkA, bool odd);
    void sysgemm2MultiplyChunkX48(const GEMMProblem &problem,
            const GEMMStrategy &strategy, int chunkA);
2392 | |
    // --- Register-to-register copy/conversion (Ts -> Td) -------------------
    bool copyRegisterBlock(Type Ts, Type Td, const RegisterBlock &blockSrc,
            const RegisterBlock &blockDst, const GRFMultirange &src,
            const GRFMultirange &dst, int dOffR, int dOffC,
            const CommonStrategy &strategy, CommonState &state,
            bool preserveSrc = false);
    // Two overloads: plain copy, and copy with alpha (real/imag) scaling.
    bool copyRegisters(Type Ts, Type Td,
            const std::vector<RegisterBlock> &layoutSrc,
            const std::vector<RegisterBlock> &layoutDst,
            const GRFMultirange &src, const GRFMultirange &dst, int dOffR,
            int dOffC, bool conjugate, const CommonStrategy &strategy,
            CommonState &state, bool preserveSrc = false);
    bool copyRegisters(Type Ts, Type Td,
            const std::vector<RegisterBlock> &layoutSrc,
            const std::vector<RegisterBlock> &layoutDst,
            const GRFMultirange &src, const GRFMultirange &dst, int dOffR,
            int dOffC, const Scalar<double> &alpha_real,
            const Scalar<double> &alpha_imag, bool conjugate,
            const CommonStrategy &strategy, CommonState &state,
            bool preserveSrc = false);

    // --- Copy kernel generation --------------------------------------------
    bool copyBody(
            CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
    bool copyBodyRemCheck(
            CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
    bool copyBodyInternal(
            CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
    void copySlice(
            CopyProblem &problem, CopyStrategy &strategy, CopyState &state);

    void copyCalcIncrements(const CopyProblem &problem,
            const CopyStrategy &strategy, CopyState &state, int s_load = 0,
            int d_load = 0);

    void copyInitInterface(
            CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
    void copyInitState(
            CopyProblem &problem, CopyStrategy &strategy, CopyState &state);
    void copy(CopyProblem &problem, CopyStrategy &strategy, CopyState &state);

    // --- Common kernel scaffolding -----------------------------------------
    void prologue(const CommonStrategy &strategy);
    void epilogue(const CommonStrategy &strategy, const CommonState &state);
    void padding();
    void initState(const CommonProblem &problem, const CommonStrategy &strategy,
            CommonState &state);
2437 | }; |
2438 | |
2439 | inline char precisionChar(Type T) { |
2440 | switch (T.baseType()) { |
2441 | case Type::f16: return 'H'; |
2442 | case Type::f32: return 'S'; |
2443 | case Type::u8: return 'o'; |
2444 | case Type::s8: return 'O'; |
2445 | case Type::u16: return 'w'; |
2446 | case Type::s16: return 'W'; |
2447 | case Type::u32: return 'i'; |
2448 | case Type::s32: return 'I'; |
2449 | case Type::u64: return 'l'; |
2450 | case Type::s64: return 'L'; |
2451 | case Type::bf16: return 'B'; |
2452 | case Type::tf32: return 'T'; |
2453 | default: return '?'; |
2454 | } |
2455 | } |
2456 | |
2457 | static inline Type charPrecision(char c) { |
2458 | switch (c) { |
2459 | case 'H': return Type::f16; |
2460 | case 'S': return Type::f32; |
2461 | case 'o': return Type::u8; |
2462 | case 'O': return Type::s8; |
2463 | case 'w': return Type::u16; |
2464 | case 'W': return Type::s16; |
2465 | case 'i': return Type::u32; |
2466 | case 'I': return Type::s32; |
2467 | case 'B': return Type::bf16; |
2468 | case 'T': return Type::tf32; |
2469 | default: return Type::invalid; |
2470 | } |
2471 | } |
2472 | |
2473 | inline char layoutChar(MatrixLayout layout) { |
2474 | switch (layout) { |
2475 | case MatrixLayout::N: return 'N'; |
2476 | case MatrixLayout::T: return 'T'; |
2477 | case MatrixLayout::Pc: return 'A'; |
2478 | case MatrixLayout::Pr: return 'B'; |
2479 | default: return '?'; |
2480 | } |
2481 | } |
2482 | |
2483 | } // namespace jit |
2484 | } // namespace gpu |
2485 | } // namespace impl |
2486 | } // namespace dnnl |
2487 | |
2488 | #endif /* header guard */ |
2489 | |