1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_GEMM_GEN_GEMM_KERNEL_COMMON_HPP |
18 | #define GPU_JIT_GEMM_GEN_GEMM_KERNEL_COMMON_HPP |
19 | |
20 | #define STANDALONE 0 |
21 | |
#include <cstdint>
#include <string>
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace jit { |
28 | |
// Loop identifiers.
//
// Stored in a single byte. LoopM/LoopN/LoopK are small consecutive values
// (0/1/2) and are used as array indices elsewhere in this header.
// LoopPersistent is a flag bit that is OR'ed onto another loop type rather
// than a loop of its own. LoopAny and LoopNone deliberately share one value.
enum LoopType : std::uint8_t {
    LoopM = 0,
    LoopN = 1,
    LoopK = 2,
    // Flag OR'ed with other loop types, indicating persistent threads.
    LoopPersistent = 0x40,
    // Fused m/n indices (boustrophedon ordering), with MNK nested inside.
    LoopMNBoustrophedonMNK = 0x80,
    // Fused n/m indices (boustrophedon ordering), with NMK nested inside.
    LoopMNBoustrophedonNMK = 0x81,
    // Fused m/n indices (Hilbert ordering), with MNK nested inside.
    LoopMNHilbertMNK = 0x90,
    // Fused n/m indices (Hilbert ordering), with NMK nested inside.
    LoopMNHilbertNMK = 0x91,
    // Both names below map to the same sentinel value.
    LoopAny = 0xFF,
    LoopNone = 0xFF
};
47 | |
// WG identifiers.
//
// Describes how (and whether) a work group size may be adjusted.
// Stored in a single byte.
enum WGType : std::uint8_t {
    // Dynamic work group size (can shrink or expand).
    WGDynamic = 0,
    // Fixed work group size.
    WGFixed = 1,
    // Work group size can shrink but not expand.
    WGShrinkable = 2
};
54 | |
55 | // Driver information, shared by all kernel types. |
56 | struct CommonDriverInfo { |
57 | int subgroupSize; // Declared subgroup size (unrelated to actual SIMD lengths in kernel) |
58 | LoopType fusedLoop; // Loop dimension in which EUs are fused (if any). |
59 | int grfCount; // # of GRFs used by kernel. |
60 | LoopType loopOrder |
61 | [3]; // Loops corresponding to x/y/z dimensions of kernel dispatch. |
62 | int blocking[3]; // Standard blocking sizes in m/n/k dimensions. |
63 | int blockingAlt[3]; // Alternative blocking sizes in m/n/k dimensions. |
64 | int unroll[3]; // m/n/k unrolls. |
65 | int wg[3]; // HW threads per workgroup in m/n/k dimensions. |
66 | int wgExpand; // If > 1, workgroup size needs to be scaled by this factor. |
67 | WGType wgUpdate; // Work group type showing how/if work group sizes can be updated. |
68 | bool kRemainderHandling; // True if kernel performs k remainder handling (gemm). |
69 | bool kParallel; // True if gemm kernel can be parallelized in the k dimension. |
70 | bool kParallelLocal; // True if gemm kernel can be parallelized in the k dimension inside a workgroup. |
71 | int slm; // Minimum SLM allocation. |
72 | int perKSLM; // If > 0, dynamically allocate at least perKSLM * wg[LoopK] bytes of SLM. |
73 | int alignment |
74 | [3]; // Address alignment requirements for A,B,C (gemm) or S,D (copy). |
75 | bool support4GB |
76 | [3]; // True if >4GB buffers allowed for A,B,C (gemm) or S,D (copy). |
77 | |
78 | bool fusedEUs() const { return (fusedLoop != LoopNone); } |
79 | bool isMNK() const { |
80 | auto l = loopOrder[0] & ~LoopPersistent; |
81 | return l == LoopM || l == LoopMNHilbertMNK |
82 | || l == LoopMNBoustrophedonMNK; |
83 | } |
84 | bool isNMK() const { |
85 | auto l = loopOrder[0] & ~LoopPersistent; |
86 | return l == LoopN || l == LoopMNHilbertNMK |
87 | || l == LoopMNBoustrophedonNMK; |
88 | } |
89 | bool isHilbert() const { |
90 | auto l = loopOrder[0] & ~LoopPersistent; |
91 | return l == LoopMNHilbertMNK || l == LoopMNHilbertNMK; |
92 | } |
93 | bool isBoustrophedon() const { |
94 | auto l = loopOrder[0] & ~LoopPersistent; |
95 | return l == LoopMNBoustrophedonMNK || l == LoopMNBoustrophedonNMK; |
96 | } |
97 | bool isLinearOrder() const { return isHilbert() || isBoustrophedon(); } |
98 | bool isPersistent() const { |
99 | return (loopOrder[0] != LoopNone) && (loopOrder[0] & LoopPersistent); |
100 | } |
101 | bool fixedWG() const { return wgUpdate == WGFixed; } |
102 | |
103 | int wgTile(LoopType l) const { return unroll[l] * wg[l]; } |
104 | }; |
105 | |
// Definitions for flag arguments to kernels.
// Every flag occupies exactly one bit, so each is written as a shift.
enum {
    FlagCOColumn = 1 << 2, // 0x4
    FlagCORow = 1 << 3, // 0x8
    FlagNonfinalKBlock = 1 << 4, // 0x10
    FlagNoninitialKBlock = 1 << 7, // 0x80
    FlagLateFusedGEMMDone = 1 << 8, // 0x100
    FlagEarlyFusedGEMMDone = 1 << 9, // 0x200
    FlagStoreSums = 1 << 10, // 0x400
};
116 | |
117 | } // namespace jit |
118 | } // namespace gpu |
119 | } // namespace impl |
120 | } // namespace dnnl |
121 | |
122 | #endif /* header guard */ |
123 | |