1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "test_gemm_data_preparation.hpp" |
18 | #include "test_gemm_params.hpp" |
19 | #include "test_gemm_validation.hpp" |
20 | #include "gtest/gtest.h" |
21 | |
22 | #include "dnnl_test_common.hpp" |
23 | |
24 | #include "oneapi/dnnl/dnnl.hpp" |
25 | #include "tests/test_isa_common.hpp" |
26 | |
27 | #include "cpu/x64/amx_tile_configure.hpp" |
28 | #include "cpu/x64/brgemm/brgemm.hpp" |
29 | |
30 | namespace dnnl { |
31 | |
// Parameter set for one brgemm test case. Extends the generic GEMM
// test_params with brgemm-specific knobs: input data types, batch kind,
// layout, kernel attributes, and the batch size used at execution time.
struct brgemm_params_t : test_params {
    impl::data_type_t dt_a; // data type of matrix A
    impl::data_type_t dt_b; // data type of matrix B
    impl::cpu::x64::brgemm_batch_kind_t batch_kind; // e.g. brgemm_addr
    impl::cpu::x64::brgemm_layout_t layout; // row- or column-major

    impl::cpu::x64::brgemm_attr_t attrs; // extra kernel attrs (max_bs, vpads)

    int bs; // batch size passed to brgemm_kernel_execute()
};
42 | |
43 | class params_creator_t { |
44 | public: |
45 | std::vector<brgemm_params_t> create_simple_brgemm_params() { |
46 | params = {}; |
47 | |
48 | transpose_ = {'n'}; |
49 | sizes_and_leading_dims_[0] |
50 | = {{1, 4}, {3, 3}, {3, 8}, {30, 30}, {64, 64}, {31, 61}}; |
51 | sizes_and_leading_dims_[1] |
52 | = {{1, 4}, {2, 6}, {2, 5}, {20, 20}, {64, 64}, {21, 51}}; |
53 | sizes_and_leading_dims_[2] |
54 | = {{1, 4}, {1, 3}, {1, 2}, {10, 20}, {64, 64}, {11, 81}}; |
55 | |
56 | alpha_values_ = {1.0f, 2.0f, 2.5f}; |
57 | beta_values_ = {0.0f, 1.0f, 2.0f}; |
58 | |
59 | amx_dts_ = {{dnnl_f32, dnnl_f32}}; |
60 | dts_ = {{dnnl_f32, dnnl_f32}}; |
61 | |
62 | put_params(); |
63 | |
64 | sizes_and_leading_dims_[0] = {{4, 4}, {8, 12}, {64, 1024}}; |
65 | sizes_and_leading_dims_[1] = {{4, 4}, {16, 32}, {128, 512}}; |
66 | sizes_and_leading_dims_[2] = {{4, 4}, {12, 56}, {16, 256}}; |
67 | |
68 | amx_dts_ = { |
69 | {dnnl_bf16, dnnl_bf16}, {dnnl_u8, dnnl_u8}, {dnnl_s8, dnnl_s8}}; |
70 | dts_ = {{dnnl_bf16, dnnl_bf16}, {dnnl_u8, dnnl_s8}}; |
71 | |
72 | put_params(); |
73 | |
74 | return params; |
75 | } |
76 | |
77 | private: |
78 | void put_params() { |
79 | for_(auto tr : transpose_) |
80 | for_(size_t i = 0; i < sizes_and_leading_dims_[0].size(); i++) |
81 | for_(auto alpha : alpha_values_) |
82 | for_(auto beta : beta_values_) |
83 | for (auto dt : is_amx ? amx_dts_ : dts_) { |
84 | brgemm_params_t param = {}; |
85 | param.transA = tr; |
86 | param.transB = 'n'; |
87 | param.M = sizes_and_leading_dims_[0][i].first; |
88 | param.lda = sizes_and_leading_dims_[0][i].second; |
89 | param.N = sizes_and_leading_dims_[1][i].first; |
90 | param.ldb = sizes_and_leading_dims_[1][i].second; |
91 | param.K = sizes_and_leading_dims_[2][i].first; |
92 | param.ldc = sizes_and_leading_dims_[2][i].second; |
93 | param.alpha = alpha; |
94 | param.beta = beta; |
95 | param.dt_a = dt.first; |
96 | param.dt_b = dt.second; |
97 | param.batch_kind = impl::cpu::x64::brgemm_addr; |
98 | param.layout = impl::cpu::x64::brgemm_row_major; |
99 | param.bs = 1; |
100 | param.attrs.max_bs = 1; |
101 | param.attrs.max_top_vpad = 0; |
102 | param.attrs.max_bottom_vpad = 0; |
103 | param.expect_to_fail = false; |
104 | param.expected_status = dnnl_success; |
105 | |
106 | params.emplace_back(param); |
107 | } |
108 | } |
109 | |
110 | const bool is_amx = dnnl::mayiuse(cpu_isa::avx512_core_amx); |
111 | |
112 | std::vector<char> transpose_; |
113 | std::vector<std::pair<int64_t, int64_t>> sizes_and_leading_dims_[3]; |
114 | std::vector<float> alpha_values_; |
115 | std::vector<float> beta_values_; |
116 | std::vector<std::pair<impl::data_type_t, impl::data_type_t>> amx_dts_; |
117 | std::vector<std::pair<impl::data_type_t, impl::data_type_t>> dts_; |
118 | |
119 | std::vector<brgemm_params_t> params; |
120 | }; |
121 | |
122 | class brgemm_test_t : public ::testing::TestWithParam<brgemm_params_t> { |
123 | protected: |
124 | void SetUp() override { |
125 | const auto &p = GetParam(); |
126 | |
127 | SKIP_IF(engine::get_count(engine::kind::cpu) == 0, |
128 | "Brgemm requires cpu." ); |
129 | eng_ = std::make_shared<engine>(engine::kind::cpu, 0); |
130 | |
131 | SKIP_IF(!impl::cpu::platform::has_data_type_support(p.dt_a), |
132 | "Engine does not support this data type." ); |
133 | |
134 | catch_expected_failures( |
135 | [=]() { Test(); }, p.expect_to_fail, p.expected_status, true); |
136 | } |
137 | |
138 | void Test() { |
139 | const auto &p = ::testing::TestWithParam<brgemm_params_t>::GetParam(); |
140 | run_proper_test(p); |
141 | } |
142 | |
143 | private: |
144 | template <typename b_dt> |
145 | void reorder_B(const brgemm_params_t &p, const mapped_ptr_t<b_dt> &b_mem, |
146 | mapped_ptr_t<b_dt> &b_mem_reordered) const { |
147 | static constexpr int k_pack = 4 / sizeof(b_dt); |
148 | |
149 | dnnl::impl::parallel_nd(p.K, p.N, [&](int64_t k, int64_t n) { |
150 | size_t b_off = k * p.ldb + n; |
151 | size_t b_reordered_off |
152 | = (k / k_pack) * p.ldb * k_pack + n * k_pack + k % k_pack; |
153 | b_mem_reordered[b_reordered_off] = b_mem[b_off]; |
154 | }); |
155 | } |
156 | |
157 | template <typename b_dt> |
158 | mapped_ptr_t<b_dt> get_B_mem(const brgemm_params_t &p) { |
159 | mapped_ptr_t<b_dt> B = map_memory<b_dt>(*gemm_data_.b_mem); |
160 | |
161 | static constexpr int k_pack = 4 / sizeof(b_dt); |
162 | if (k_pack > 1) { |
163 | size_t sizeA, sizeB, sizeC; |
164 | get_matrix_size(p, sizeA, sizeB, sizeC); |
165 | |
166 | b_mem_reordered_ = std::make_shared<test_memory>( |
167 | get_matrix_md<b_dt>(sizeB, p.off.b), *eng_); |
168 | auto B_reordered = map_memory<b_dt>(*b_mem_reordered_); |
169 | |
170 | reorder_B(p, B, B_reordered); |
171 | |
172 | return B_reordered; |
173 | } |
174 | |
175 | return B; |
176 | } |
177 | |
178 | template <typename a_dt, typename b_dt, typename c_dt> |
179 | dnnl_status_t run_brgemm(const brgemm_params_t &p) { |
180 | using namespace dnnl::impl::cpu; |
181 | using namespace dnnl::impl::cpu::x64; |
182 | |
183 | mapped_ptr_t<a_dt> A = map_memory<a_dt>(*gemm_data_.a_mem); |
184 | mapped_ptr_t<b_dt> B = get_B_mem<b_dt>(p); |
185 | mapped_ptr_t<c_dt> C = map_memory<c_dt>(*gemm_data_.c_mem); |
186 | |
187 | //initialize brgemm kernel |
188 | char palette[64]; |
189 | char tile_buffer[1024]; |
190 | x64::brgemm_t desc; |
191 | auto res = brgemm_desc_init(&desc, x64::cpu_isa_t::isa_undef, |
192 | p.batch_kind, p.dt_a, p.dt_b, p.tr_a(), p.tr_b(), p.layout, |
193 | p.alpha, p.beta, p.lda, p.ldb, p.ldc, p.M, p.N, p.K); |
194 | if (res != dnnl_success) return res; |
195 | |
196 | if (desc.is_tmm) res = brgemm_init_tiles(desc, palette); |
197 | if (!desc.is_tmm) brgemm_desc_set_attr(&desc, p.attrs); |
198 | |
199 | if (res != dnnl_success) return res; |
200 | |
201 | x64::brgemm_kernel_t *_t_ptr; |
202 | res = brgemm_kernel_create(&_t_ptr, desc); |
203 | |
204 | x64::brgemm_batch_element_t batch_element; |
205 | batch_element.ptr.A = A; |
206 | batch_element.ptr.B = B; |
207 | batch_element.vvpad.top = 0; |
208 | batch_element.vvpad.bottom = 0; |
209 | if (desc.is_tmm) amx_tile_configure(palette); |
210 | brgemm_kernel_execute(_t_ptr, p.bs, &batch_element, C, |
211 | desc.is_tmm ? tile_buffer : nullptr); |
212 | |
213 | brgemm_kernel_destroy(_t_ptr); |
214 | if (desc.is_tmm) amx_tile_release(); |
215 | |
216 | return res; |
217 | } |
218 | |
219 | template <typename a_dt, typename b_dt, typename c_dt> |
220 | void test_brgemm(const brgemm_params_t &p) { |
221 | gemm_data_ = {}; |
222 | prepare_data_for_gemm_testing<a_dt, b_dt, c_dt>(p, gemm_data_, *eng_); |
223 | |
224 | dnnl_status_t status = run_brgemm<a_dt, b_dt, c_dt>(p); |
225 | |
226 | if (status == dnnl_success) { |
227 | validate<a_dt, b_dt, c_dt>(p, gemm_data_); |
228 | } |
229 | |
230 | if (status != dnnl_success) |
231 | throw error(status, "oneDNN brgemm returned error" ); |
232 | } |
233 | |
234 | void run_proper_test(const brgemm_params_t &p) { |
235 | using namespace impl::cpu::x64; |
236 | |
237 | if (dnnl::mayiuse(cpu_isa::avx512_core_amx)) { |
238 | if (p.dt_a == dnnl_f32 && p.dt_b == dnnl_f32) |
239 | test_brgemm<float, float, float>(p); |
240 | else if (p.dt_a == dnnl_bf16 && p.dt_b == dnnl_bf16) |
241 | test_brgemm<bfloat16_t, bfloat16_t, float>(p); |
242 | else if (p.dt_a == dnnl_s8 && p.dt_b == dnnl_s8) |
243 | test_brgemm<int8_t, int8_t, int32_t>(p); |
244 | else if (p.dt_a == dnnl_u8 && p.dt_b == dnnl_u8) |
245 | test_brgemm<uint8_t, uint8_t, int32_t>(p); |
246 | else |
247 | throw error(dnnl_unimplemented, "Brgemm unimplemented." ); |
248 | } else { |
249 | if (p.dt_a == dnnl_f32 && p.dt_b == dnnl_f32) |
250 | test_brgemm<float, float, float>(p); |
251 | else if (p.dt_a == dnnl_bf16 && p.dt_b == dnnl_bf16) |
252 | test_brgemm<bfloat16_t, bfloat16_t, float>(p); |
253 | else if (p.dt_a == dnnl_u8 && p.dt_b == dnnl_s8) { |
254 | assert(p.layout == brgemm_layout_t::brgemm_row_major); |
255 | test_brgemm<uint8_t, int8_t, int32_t>(p); |
256 | } else if (p.dt_a == dnnl_s8 && p.dt_b == dnnl_u8) { |
257 | assert(p.layout == brgemm_layout_t::brgemm_col_major); |
258 | test_brgemm<int8_t, uint8_t, int32_t>(p); |
259 | } else |
260 | throw error(dnnl_unimplemented, "Brgemm unimplemented." ); |
261 | } |
262 | } |
263 | |
264 | std::shared_ptr<engine> eng_; |
265 | test_gemm_data gemm_data_; |
266 | std::shared_ptr<test_memory> b_mem_reordered_; |
267 | }; |
268 | |
// Body is empty: all work happens in brgemm_test_t::SetUp()/Test().
TEST_P(brgemm_test_t, TestsBRGEMM) {}
// Instantiate the parameterized suite over the generated case list.
INSTANTIATE_TEST_SUITE_P(TestBRGEMMSimple, brgemm_test_t,
        ::testing::ValuesIn(params_creator_t().create_simple_brgemm_params()));
272 | |
273 | } // namespace dnnl |
274 | |