/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef GPU_JIT_GEMM_GEN_GEMM_KERNEL_COMMON_HPP
18#define GPU_JIT_GEMM_GEN_GEMM_KERNEL_COMMON_HPP
19
20#define STANDALONE 0
21
#include <cstdint>
#include <string>
23
24namespace dnnl {
25namespace impl {
26namespace gpu {
27namespace jit {
28
// Loop identifiers.
//
// LoopM/LoopN/LoopK (0..2) are also used directly as indices into the
// 3-element per-dimension arrays of CommonDriverInfo (see wgTile()).
// LoopPersistent is a flag bit that is OR'ed onto another loop type.
enum LoopType : uint8_t {
    LoopM = 0,
    LoopN = 1,
    LoopK = 2,
    // Flag OR'ed with other loop types, indicating persistent threads.
    LoopPersistent = 0x40,
    // Fused m/n indices (boustrophedon ordering), with MNK nested inside.
    LoopMNBoustrophedonMNK = 0x80,
    // Fused n/m indices (boustrophedon ordering), with NMK nested inside.
    LoopMNBoustrophedonNMK = 0x81,
    // Fused m/n indices (Hilbert ordering), with MNK nested inside.
    LoopMNHilbertMNK = 0x90,
    // Fused n/m indices (Hilbert ordering), with NMK nested inside.
    LoopMNHilbertNMK = 0x91,
    // LoopAny and LoopNone are two names for the same value (0xFF), so they
    // cannot be distinguished from each other at runtime. Note that 0xFF has
    // the LoopPersistent bit set; code testing that flag must check for
    // LoopNone first.
    LoopAny = 0xFF,
    LoopNone = 0xFF
};
47
// WG identifiers: how (if at all) a kernel's work group size may be updated.
enum WGType : uint8_t {
    // Dynamic work group size (can shrink or expand).
    WGDynamic = 0,
    // Fixed work group size.
    WGFixed = 1,
    // Work group size can shrink but not expand.
    WGShrinkable = 2
};
54
55// Driver information, shared by all kernel types.
56struct CommonDriverInfo {
57 int subgroupSize; // Declared subgroup size (unrelated to actual SIMD lengths in kernel)
58 LoopType fusedLoop; // Loop dimension in which EUs are fused (if any).
59 int grfCount; // # of GRFs used by kernel.
60 LoopType loopOrder
61 [3]; // Loops corresponding to x/y/z dimensions of kernel dispatch.
62 int blocking[3]; // Standard blocking sizes in m/n/k dimensions.
63 int blockingAlt[3]; // Alternative blocking sizes in m/n/k dimensions.
64 int unroll[3]; // m/n/k unrolls.
65 int wg[3]; // HW threads per workgroup in m/n/k dimensions.
66 int wgExpand; // If > 1, workgroup size needs to be scaled by this factor.
67 WGType wgUpdate; // Work group type showing how/if work group sizes can be updated.
68 bool kRemainderHandling; // True if kernel performs k remainder handling (gemm).
69 bool kParallel; // True if gemm kernel can be parallelized in the k dimension.
70 bool kParallelLocal; // True if gemm kernel can be parallelized in the k dimension inside a workgroup.
71 int slm; // Minimum SLM allocation.
72 int perKSLM; // If > 0, dynamically allocate at least perKSLM * wg[LoopK] bytes of SLM.
73 int alignment
74 [3]; // Address alignment requirements for A,B,C (gemm) or S,D (copy).
75 bool support4GB
76 [3]; // True if >4GB buffers allowed for A,B,C (gemm) or S,D (copy).
77
78 bool fusedEUs() const { return (fusedLoop != LoopNone); }
79 bool isMNK() const {
80 auto l = loopOrder[0] & ~LoopPersistent;
81 return l == LoopM || l == LoopMNHilbertMNK
82 || l == LoopMNBoustrophedonMNK;
83 }
84 bool isNMK() const {
85 auto l = loopOrder[0] & ~LoopPersistent;
86 return l == LoopN || l == LoopMNHilbertNMK
87 || l == LoopMNBoustrophedonNMK;
88 }
89 bool isHilbert() const {
90 auto l = loopOrder[0] & ~LoopPersistent;
91 return l == LoopMNHilbertMNK || l == LoopMNHilbertNMK;
92 }
93 bool isBoustrophedon() const {
94 auto l = loopOrder[0] & ~LoopPersistent;
95 return l == LoopMNBoustrophedonMNK || l == LoopMNBoustrophedonNMK;
96 }
97 bool isLinearOrder() const { return isHilbert() || isBoustrophedon(); }
98 bool isPersistent() const {
99 return (loopOrder[0] != LoopNone) && (loopOrder[0] & LoopPersistent);
100 }
101 bool fixedWG() const { return wgUpdate == WGFixed; }
102
103 int wgTile(LoopType l) const { return unroll[l] * wg[l]; }
104};
105
// Definitions for flag arguments to kernels.
// Each enumerator is a distinct single bit, so flags can be combined with
// bitwise OR and tested with bitwise AND.
enum {
    FlagCOColumn = 0x4,
    FlagCORow = 0x8,
    FlagNonfinalKBlock = 0x10,
    FlagNoninitialKBlock = 0x80,
    FlagLateFusedGEMMDone = 0x100,
    FlagEarlyFusedGEMMDone = 0x200,
    FlagStoreSums = 0x400,
};
116
117} // namespace jit
118} // namespace gpu
119} // namespace impl
120} // namespace dnnl
121
122#endif /* header guard */
123