/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_XE_HP_SYSTOLIC_GEMM_HPP
#define GPU_JIT_XE_HP_SYSTOLIC_GEMM_HPP

#include <assert.h>
#include <memory>
#include <tuple>

#include "common/c_types_map.hpp"
#include "common/gemm_utils.hpp"
#include "common/memory_storage.hpp"
#include "common/utils.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gemm/gpu_gemm.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel.hpp"
#include "gpu/jit/gemm/jit_gemm_pd.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

struct xe_hp_systolic_gemm_t : public gpu_gemm_t {
    struct pd_t : public jit_gemm_pd_t {
        using jit_gemm_pd_t::jit_gemm_pd_t;

        DECLARE_COMMON_PD_T("jit:xe_hp:gemm:any", xe_hp_systolic_gemm_t);

        status_t init(engine_t *engine);

        bool use_nocopy();
        bool set_default_formats(data_type_t dt);

        size_t dyn_offset_a = 0;
        size_t dyn_offset_b = 0;
        size_t dyn_offset_c = 0;

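        // Data type of the bias / C offset ("co") tensor as the kernel sees
        // it: the bias type when a bias is present, s32 for integer (s8/u8)
        // GEMM, and the C type otherwise.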
        data_type_t impl_co_type() const {
            using namespace data_type;
            return with_bias() ? desc()->bias_type()
                               : (utils::one_of(desc()->a_type(), s8, u8)
                                               ? s32
                                               : desc()->c_type());
        }

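        // Accumulation data type used by the kernel: s32 for integer (s8/u8)
        // GEMM, f32 otherwise.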
        data_type_t impl_acc_type() const {
            using namespace data_type;
            return utils::one_of(desc()->c_type(), s8, u8, f16, bf16, f32)
                    ? (utils::one_of(desc()->a_type(), s8, u8) ? s32 : f32)
                    : s32;
        }

        float alpha() const { return 1.0f; }
        float beta() const { return beta_; }

        bool with_bias() const {
            return (desc()->bias_type() != data_type::undef)
                    && !bias_via_binary_;
        }

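        // Convert the 3-bit oneDNN bias mask into the kernel's cmask
        // convention; the lookup table is a bit reversal of the mask.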
        int bias_cmask() const {
            unsigned char to_cmask[8] = {0, 4, 2, 6, 1, 5, 3, 7};
            assert(unsigned(desc()->bias_mask()) < 8);
            return with_bias() ? to_cmask[desc()->bias_mask() & 7] : -1;
        }

        bool packed_a() const { return packed_a_; }
        bool packed_b() const { return packed_b_; }
        bool packed_c() const { return packed_c_; }

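        // Leading dimensions of pre-packed A/B/C, expressed in packed panels
        // (blocking stride divided by the corresponding unroll). Note that
        // lda comes from b_desc and ldb from a_desc, matching the swapped A/B
        // convention used internally.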
        dim_t lda_packed() const {
            return packed_a() ? desc()->b_desc.format_desc.blocking
                                                .strides[with_batch() ? 2 : 1]
                            / unroll_m()
                              : 0;
        }
        dim_t ldb_packed() const {
            return packed_b() ? desc()->a_desc.format_desc.blocking
                                                .strides[with_batch() ? 1 : 0]
                            / unroll_n()
                              : 0;
        }
        dim_t ldc_packed() const {
            return packed_c() ? desc()->c_desc.format_desc.blocking
                                                .strides[with_batch() ? 1 : 0]
                            / unroll_n()
                              : 0;
        }

        int batch_dims() const {
            return nstl::max(desc()->c_desc.ndims - 2, 0);
        }

        bool with_batch() const { return desc()->is_batched(); }
        bool with_a_zero_points() const { return a_zp_; }
        bool with_b_zero_points() const { return b_zp_; }
        bool with_ab_zero_points() const { return a_zp_ || b_zp_; }
        bool with_c_zero_points() const { return c_zp_; }

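        // K-blocking accumulates partial sums directly into C, so it is only
        // allowed when C has the accumulation type and any leading post-op is
        // a simple sum.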
        bool allow_k_blocking() const {
            return (desc()->acc_type == desc()->c_type())
                    && IMPLICATION(post_ops()->len() > 0,
                            post_ops()->entry_[0].kind == primitive_kind::sum);
        }

        int unroll_m() const { return unroll_m_; }
        int unroll_n() const { return unroll_n_; }
        bool alt() const { return alt_; }

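        // Advise the runtime to schedule 4 threads per EU for this kernel.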
        status_t query(query_t what, int idx, void *result) const override {
            switch ((int)what) {
                case (int)query::preferred_gpu_threads_per_eu: {
                    *(int *)result = 4;
                    break;
                }
                default: return gpu_gemm_pd_t::query(what, idx, result);
            }
            return status::success;
        }

        const compute::device_info_t *dev_info_ = nullptr;

    private:
        bool any_prepacked_ = false;
        bool packed_a_ = false, packed_b_ = false, packed_c_ = false;
        bool a_zp_ = false, b_zp_ = false, c_zp_ = false;
        int unroll_m_ = 0;
        int unroll_n_ = 0;
        bool alt_ = false;
    };

    status_t init(engine_t *engine) override;
    status_t init_res_storage(
            engine_t *engine, gpu_resource_t *r) const override;

public:
    xe_hp_systolic_gemm_t(const pd_t *apd) : gpu_gemm_t(apd) {}

    status_t execute(const gemm_exec_ctx_t &ctx) const override;

private:
    status_t init_compute(engine_t *engine);

    bool enable_mn_blocking() const;
    std::tuple<int64_t, int64_t, int64_t> get_blocking() const;

    status_t launch_clear_sum(const gemm_exec_ctx_t &ctx, int64_t r, int64_t c,
            const memory_storage_t &dst, int32_t offset_dst, int32_t ld_dst,
            bool copyb) const;
    status_t launch_copy(const gemm_exec_ctx_t &ctx, int64_t r, int64_t c,
            const memory_storage_t &src, int64_t offset_src, int64_t ld_src,
            const memory_storage_t &dst, int32_t offset_dst, int32_t ld_dst,
            bool copyb) const;
    status_t launch_compute(const gemm_exec_ctx_t &ctx, int32_t m, int32_t n,
            int32_t k, const memory_storage_t &ap, int64_t offset_a,
            int32_t lda, const memory_storage_t &bp, int64_t offset_b,
            int32_t ldb, const memory_storage_t &c, int64_t offset_c,
            int32_t ldc, float alpha, float beta, const memory_storage_t *ao,
            const memory_storage_t *bo, const memory_storage_t &co,
            int32_t offset_co, int po_count, const memory_storage_t **po_src,
            int32_t *offset_po_src, bool first_k_block, bool last_k_block,
            int32_t batch, int32_t stride_a, int32_t stride_b,
            int32_t stride_c) const;

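    // Round a leading dimension (ld elements of sz bytes each) up to 32-byte
    // alignment, adding another 32 bytes when the result would be a multiple
    // of 64 bytes (or when the worst-case value is requested via get_max).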
    static int64_t nice_ld(int64_t ld, int sz, bool get_max = false) {
        const auto align = 32;
        const auto no_align = 64;

        auto new_ld = (ld * sz + align - 1) & ~(align - 1);
        if (get_max || (new_ld & (no_align - 1)) == 0) new_ld += align;

        return new_ld / sz;
    }

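    // Leading dimension for the internally packed copies: k rounded up to the
    // k-unroll (32 bytes' worth of A elements), plus an extra unroll_k chunk
    // when A/B zero points are used (room for compensation data), adjusted by
    // nice_ld().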
    int64_t get_ld_packed(int64_t k, bool get_max = false) const {
        auto a_type = pd()->desc()->a_type();
        auto a_sz = types::data_type_size(a_type);

        int unroll_k = int(32 / a_sz);
        auto ld = utils::rnd_up(k, unroll_k);
        if (pd()->with_ab_zero_points()) ld += unroll_k;

        return nice_ld(ld, int(a_sz), get_max);
    }

    int64_t max_ld_packed(int64_t k) const { return get_ld_packed(k, true); }

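    // Indices of the temporary packed-copy buffers registered with the GPU
    // resource storage in init_res_storage().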
    static const int A_PACKED_ = 0;
    static const int B_PACKED_ = 1;

    compute::kernel_t kernel_[2][2]; // [first_k_block][last_k_block]
    compute::kernel_t copy_kernel_[2][2]; // [trans][clear_sum]

    CommonDriverInfo compute_info_;

    compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown;
    int eu_count_ = 0;

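    // C offset ("co") kind ('N' = none) and blocking walk order (whether the
    // N dimension is walked before M), both chosen during initialization.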
    char co_kind_ = 'N';
    bool walk_n_first_ = false;

    GEMMProblem problem_;

    const pd_t *pd() const {
        return (const pd_t *)gpu_primitive_t::pd().get();
    }
};

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif
// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s