1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_XE_HP_SYSTOLIC_GEMM_HPP |
18 | #define GPU_JIT_XE_HP_SYSTOLIC_GEMM_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <memory> |
22 | #include <tuple> |
23 | |
24 | #include "common/c_types_map.hpp" |
25 | #include "common/gemm_utils.hpp" |
26 | #include "common/memory_storage.hpp" |
27 | #include "common/utils.hpp" |
28 | #include "gpu/compute/compute.hpp" |
29 | #include "gpu/gemm/gpu_gemm.hpp" |
30 | #include "gpu/jit/gemm/gen_gemm_kernel.hpp" |
31 | #include "gpu/jit/gemm/jit_gemm_pd.hpp" |
32 | #include "gpu/primitive_conf.hpp" |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace gpu { |
37 | namespace jit { |
38 | |
39 | struct xe_hp_systolic_gemm_t : public gpu_gemm_t { |
40 | struct pd_t : public jit_gemm_pd_t { |
41 | using jit_gemm_pd_t::jit_gemm_pd_t; |
42 | |
43 | DECLARE_COMMON_PD_T("jit:xe_hp:gemm:any" , xe_hp_systolic_gemm_t); |
44 | |
45 | status_t init(engine_t *engine); |
46 | |
47 | bool use_nocopy(); |
48 | bool set_default_formats(data_type_t dt); |
49 | |
50 | size_t dyn_offset_a = 0; |
51 | size_t dyn_offset_b = 0; |
52 | size_t dyn_offset_c = 0; |
53 | |
54 | data_type_t impl_co_type() const { |
55 | using namespace data_type; |
56 | return with_bias() ? desc()->bias_type() |
57 | : (utils::one_of(desc()->a_type(), s8, u8) |
58 | ? s32 |
59 | : desc()->c_type()); |
60 | } |
61 | |
62 | data_type_t impl_acc_type() const { |
63 | using namespace data_type; |
64 | return utils::one_of(desc()->c_type(), s8, u8, f16, bf16, f32) |
65 | ? (utils::one_of(desc()->a_type(), s8, u8) ? s32 : f32) |
66 | : s32; |
67 | } |
68 | |
69 | float alpha() const { return 1.0f; } |
70 | float beta() const { return beta_; } |
71 | |
72 | bool with_bias() const { |
73 | return (desc()->bias_type() != data_type::undef) |
74 | && !bias_via_binary_; |
75 | } |
76 | |
77 | int bias_cmask() const { |
78 | unsigned char to_cmask[8] = {0, 4, 2, 6, 1, 5, 3, 7}; |
79 | assert(unsigned(desc()->bias_mask()) < 8); |
80 | return with_bias() ? to_cmask[desc()->bias_mask() & 7] : -1; |
81 | } |
82 | |
83 | bool packed_a() const { return packed_a_; } |
84 | bool packed_b() const { return packed_b_; } |
85 | bool packed_c() const { return packed_c_; } |
86 | |
87 | dim_t lda_packed() const { |
88 | return packed_a() ? desc()->b_desc.format_desc.blocking |
89 | .strides[with_batch() ? 2 : 1] |
90 | / unroll_m() |
91 | : 0; |
92 | } |
93 | dim_t ldb_packed() const { |
94 | return packed_b() ? desc()->a_desc.format_desc.blocking |
95 | .strides[with_batch() ? 1 : 0] |
96 | / unroll_n() |
97 | : 0; |
98 | } |
99 | dim_t ldc_packed() const { |
100 | return packed_c() ? desc()->c_desc.format_desc.blocking |
101 | .strides[with_batch() ? 1 : 0] |
102 | / unroll_n() |
103 | : 0; |
104 | } |
105 | |
106 | int batch_dims() const { |
107 | return nstl::max(desc()->c_desc.ndims - 2, 0); |
108 | } |
109 | |
110 | bool with_batch() const { return desc()->is_batched(); } |
111 | bool with_a_zero_points() const { return a_zp_; } |
112 | bool with_b_zero_points() const { return b_zp_; } |
113 | bool with_ab_zero_points() const { return a_zp_ || b_zp_; } |
114 | bool with_c_zero_points() const { return c_zp_; } |
115 | |
116 | bool allow_k_blocking() const { |
117 | return (desc()->acc_type == desc()->c_type()) |
118 | && IMPLICATION(post_ops()->len() > 0, |
119 | post_ops()->entry_[0].kind == primitive_kind::sum); |
120 | } |
121 | |
122 | int unroll_m() const { return unroll_m_; } |
123 | int unroll_n() const { return unroll_n_; } |
124 | bool alt() const { return alt_; } |
125 | |
126 | status_t query(query_t what, int idx, void *result) const override { |
127 | switch ((int)what) { |
128 | case (int)query::preferred_gpu_threads_per_eu: { |
129 | *(int *)result = 4; |
130 | break; |
131 | } |
132 | default: return gpu_gemm_pd_t::query(what, idx, result); |
133 | } |
134 | return status::success; |
135 | } |
136 | |
137 | const compute::device_info_t *dev_info_ = nullptr; |
138 | |
139 | private: |
140 | bool any_prepacked_ = false; |
141 | bool packed_a_ = false, packed_b_ = false, packed_c_ = false; |
142 | bool a_zp_ = false, b_zp_ = false, c_zp_ = false; |
143 | int unroll_m_ = 0; |
144 | int unroll_n_ = 0; |
145 | bool alt_ = false; |
146 | }; |
147 | |
148 | status_t init(engine_t *engine) override; |
149 | status_t init_res_storage( |
150 | engine_t *engine, gpu_resource_t *r) const override; |
151 | |
152 | public: |
153 | xe_hp_systolic_gemm_t(const pd_t *apd) : gpu_gemm_t(apd) {} |
154 | |
155 | virtual status_t execute(const gemm_exec_ctx_t &ctx) const override; |
156 | |
157 | private: |
158 | status_t init_compute(engine_t *engine); |
159 | |
160 | bool enable_mn_blocking() const; |
161 | std::tuple<int64_t, int64_t, int64_t> get_blocking() const; |
162 | |
163 | status_t launch_clear_sum(const gemm_exec_ctx_t &ctx, int64_t r, int64_t c, |
164 | const memory_storage_t &dst, int32_t offset_dst, int32_t ld_dst, |
165 | bool copyb) const; |
166 | status_t launch_copy(const gemm_exec_ctx_t &ctx, int64_t r, int64_t c, |
167 | const memory_storage_t &src, int64_t offset_src, int64_t ld_src, |
168 | const memory_storage_t &dst, int32_t offset_dst, int32_t ld_dst, |
169 | bool copyb) const; |
170 | status_t launch_compute(const gemm_exec_ctx_t &ctx, int32_t m, int32_t n, |
171 | int32_t k, const memory_storage_t &ap, int64_t offset_a, |
172 | int32_t lda, const memory_storage_t &bp, int64_t offset_b, |
173 | int32_t ldb, const memory_storage_t &c, int64_t offset_c, |
174 | int32_t ldc, float alpha, float beta, const memory_storage_t *ao, |
175 | const memory_storage_t *bo, const memory_storage_t &co, |
176 | int32_t offset_co, int po_count, const memory_storage_t **po_src, |
177 | int32_t *offset_po_src, bool first_k_block, bool last_k_block, |
178 | int32_t batch, int32_t stride_a, int32_t stride_b, |
179 | int32_t stride_c) const; |
180 | |
181 | static int64_t nice_ld(int64_t ld, int sz, bool get_max = false) { |
182 | const auto align = 32; |
183 | const auto no_align = 64; |
184 | |
185 | auto new_ld = (ld * sz + align - 1) & ~(align - 1); |
186 | if (get_max || (new_ld & (no_align - 1)) == 0) new_ld += align; |
187 | |
188 | return new_ld / sz; |
189 | } |
190 | |
191 | int64_t get_ld_packed(int64_t k, bool get_max = false) const { |
192 | auto a_type = pd()->desc()->a_type(); |
193 | auto a_sz = types::data_type_size(a_type); |
194 | |
195 | int unroll_k = int(32 / a_sz); |
196 | auto ld = utils::rnd_up(k, unroll_k); |
197 | if (pd()->with_ab_zero_points()) ld += unroll_k; |
198 | |
199 | return nice_ld(ld, int(a_sz), get_max); |
200 | } |
201 | |
202 | int64_t max_ld_packed(int64_t k) const { return get_ld_packed(k, true); } |
203 | |
204 | static const int A_PACKED_ = 0; |
205 | static const int B_PACKED_ = 1; |
206 | |
207 | compute::kernel_t kernel_[2][2]; // [first_k_block][last_k_block] |
208 | compute::kernel_t copy_kernel_[2][2]; // [trans][clear_sum] |
209 | |
210 | CommonDriverInfo compute_info_; |
211 | |
212 | compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown; |
213 | int eu_count_ = 0; |
214 | |
215 | char co_kind_ = 'N'; |
216 | bool walk_n_first_ = false; |
217 | |
218 | GEMMProblem problem_; |
219 | |
220 | const pd_t *pd() const { return (const pd_t *)gpu_primitive_t::pd().get(); } |
221 | }; |
222 | |
223 | } // namespace jit |
224 | } // namespace gpu |
225 | } // namespace impl |
226 | } // namespace dnnl |
227 | |
228 | #endif |
229 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
230 | |