/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "test_gemm_data_preparation.hpp"
#include "test_gemm_params.hpp"
#include "test_gemm_validation.hpp"
#include "gtest/gtest.h"

#include "dnnl_test_common.hpp"

#include "oneapi/dnnl/dnnl.hpp"
#include "tests/test_isa_common.hpp"

#include "cpu/x64/amx_tile_configure.hpp"
#include "cpu/x64/brgemm/brgemm.hpp"

namespace dnnl {

struct brgemm_params_t : test_params {
    impl::data_type_t dt_a;
    impl::data_type_t dt_b;
    impl::cpu::x64::brgemm_batch_kind_t batch_kind;
    impl::cpu::x64::brgemm_layout_t layout;

    impl::cpu::x64::brgemm_attr_t attrs;

    int bs;
};

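// Builds the list of brgemm test cases: a first sweep over plain f32 shapes,
// followed by shapes sized for the low-precision (bf16/int8) kernels.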
class params_creator_t {
public:
    std::vector<brgemm_params_t> create_simple_brgemm_params() {
        params = {};

        transpose_ = {'n'};
        sizes_and_leading_dims_[0]
                = {{1, 4}, {3, 3}, {3, 8}, {30, 30}, {64, 64}, {31, 61}};
        sizes_and_leading_dims_[1]
                = {{1, 4}, {2, 6}, {2, 5}, {20, 20}, {64, 64}, {21, 51}};
        sizes_and_leading_dims_[2]
                = {{1, 4}, {1, 3}, {1, 2}, {10, 20}, {64, 64}, {11, 81}};

        alpha_values_ = {1.0f, 2.0f, 2.5f};
        beta_values_ = {0.0f, 1.0f, 2.0f};

        amx_dts_ = {{dnnl_f32, dnnl_f32}};
        dts_ = {{dnnl_f32, dnnl_f32}};

        put_params();

        sizes_and_leading_dims_[0] = {{4, 4}, {8, 12}, {64, 1024}};
        sizes_and_leading_dims_[1] = {{4, 4}, {16, 32}, {128, 512}};
        sizes_and_leading_dims_[2] = {{4, 4}, {12, 56}, {16, 256}};

        amx_dts_ = {
                {dnnl_bf16, dnnl_bf16}, {dnnl_u8, dnnl_u8}, {dnnl_s8, dnnl_s8}};
        dts_ = {{dnnl_bf16, dnnl_bf16}, {dnnl_u8, dnnl_s8}};

        put_params();

        return params;
    }

private:
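    // Emits one test case for every combination of transpose flag, shape
    // index, alpha, beta, and data-type pair (the AMX set when AMX is
    // available, the regular set otherwise).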
    void put_params() {
        for_(auto tr : transpose_)
        for_(size_t i = 0; i < sizes_and_leading_dims_[0].size(); i++)
        for_(auto alpha : alpha_values_)
        for_(auto beta : beta_values_)
        for (auto dt : is_amx ? amx_dts_ : dts_) {
            brgemm_params_t param = {};
            param.transA = tr;
            param.transB = 'n';
            param.M = sizes_and_leading_dims_[0][i].first;
            param.lda = sizes_and_leading_dims_[0][i].second;
            param.N = sizes_and_leading_dims_[1][i].first;
            param.ldb = sizes_and_leading_dims_[1][i].second;
            param.K = sizes_and_leading_dims_[2][i].first;
            param.ldc = sizes_and_leading_dims_[2][i].second;
            param.alpha = alpha;
            param.beta = beta;
            param.dt_a = dt.first;
            param.dt_b = dt.second;
            param.batch_kind = impl::cpu::x64::brgemm_addr;
            param.layout = impl::cpu::x64::brgemm_row_major;
            param.bs = 1;
            param.attrs.max_bs = 1;
            param.attrs.max_top_vpad = 0;
            param.attrs.max_bottom_vpad = 0;
            param.expect_to_fail = false;
            param.expected_status = dnnl_success;

            params.emplace_back(param);
        }
    }

    const bool is_amx = dnnl::mayiuse(cpu_isa::avx512_core_amx);

    std::vector<char> transpose_;
    std::vector<std::pair<int64_t, int64_t>> sizes_and_leading_dims_[3];
    std::vector<float> alpha_values_;
    std::vector<float> beta_values_;
    std::vector<std::pair<impl::data_type_t, impl::data_type_t>> amx_dts_;
    std::vector<std::pair<impl::data_type_t, impl::data_type_t>> dts_;

    std::vector<brgemm_params_t> params;
};

class brgemm_test_t : public ::testing::TestWithParam<brgemm_params_t> {
protected:
    void SetUp() override {
        const auto &p = GetParam();

        SKIP_IF(engine::get_count(engine::kind::cpu) == 0,
                "Brgemm requires cpu.");
        eng_ = std::make_shared<engine>(engine::kind::cpu, 0);

        SKIP_IF(!impl::cpu::platform::has_data_type_support(p.dt_a),
                "Engine does not support this data type.");

        catch_expected_failures(
                [=]() { Test(); }, p.expect_to_fail, p.expected_status, true);
    }

    void Test() {
        const auto &p = ::testing::TestWithParam<brgemm_params_t>::GetParam();
        run_proper_test(p);
    }

private:
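    // Repacks B into the VNNI-friendly layout expected by the low-precision
    // brgemm kernels: groups of k_pack consecutive rows along K are
    // interleaved into the innermost dimension (k_pack is 2 for bf16 and
    // 4 for int8).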
    template <typename b_dt>
    void reorder_B(const brgemm_params_t &p, const mapped_ptr_t<b_dt> &b_mem,
            mapped_ptr_t<b_dt> &b_mem_reordered) const {
        static constexpr int k_pack = 4 / sizeof(b_dt);

        dnnl::impl::parallel_nd(p.K, p.N, [&](int64_t k, int64_t n) {
            size_t b_off = k * p.ldb + n;
            size_t b_reordered_off
                    = (k / k_pack) * p.ldb * k_pack + n * k_pack + k % k_pack;
            b_mem_reordered[b_reordered_off] = b_mem[b_off];
        });
    }

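    // Returns B mapped for the kernel: f32 (k_pack == 1) is used as-is, while
    // bf16/int8 get a reordered copy stored in b_mem_reordered_.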
    template <typename b_dt>
    mapped_ptr_t<b_dt> get_B_mem(const brgemm_params_t &p) {
        mapped_ptr_t<b_dt> B = map_memory<b_dt>(*gemm_data_.b_mem);

        static constexpr int k_pack = 4 / sizeof(b_dt);
        if (k_pack > 1) {
            size_t sizeA, sizeB, sizeC;
            get_matrix_size(p, sizeA, sizeB, sizeC);

            b_mem_reordered_ = std::make_shared<test_memory>(
                    get_matrix_md<b_dt>(sizeB, p.off.b), *eng_);
            auto B_reordered = map_memory<b_dt>(*b_mem_reordered_);

            reorder_B(p, B, B_reordered);

            return B_reordered;
        }

        return B;
    }

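    // Creates a brgemm descriptor and kernel for the given parameters and runs
    // it on a single address-based batch element; for AMX kernels the tile
    // palette is configured before execution and released afterwards.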
    template <typename a_dt, typename b_dt, typename c_dt>
    dnnl_status_t run_brgemm(const brgemm_params_t &p) {
        using namespace dnnl::impl::cpu;
        using namespace dnnl::impl::cpu::x64;

        mapped_ptr_t<a_dt> A = map_memory<a_dt>(*gemm_data_.a_mem);
        mapped_ptr_t<b_dt> B = get_B_mem<b_dt>(p);
        mapped_ptr_t<c_dt> C = map_memory<c_dt>(*gemm_data_.c_mem);

        // Initialize the brgemm descriptor and kernel.
        char palette[64];
        char tile_buffer[1024];
        x64::brgemm_t desc;
        auto res = brgemm_desc_init(&desc, x64::cpu_isa_t::isa_undef,
                p.batch_kind, p.dt_a, p.dt_b, p.tr_a(), p.tr_b(), p.layout,
                p.alpha, p.beta, p.lda, p.ldb, p.ldc, p.M, p.N, p.K);
        if (res != dnnl_success) return res;

        if (desc.is_tmm) res = brgemm_init_tiles(desc, palette);
        if (!desc.is_tmm) brgemm_desc_set_attr(&desc, p.attrs);

        if (res != dnnl_success) return res;

        x64::brgemm_kernel_t *kernel = nullptr;
        res = brgemm_kernel_create(&kernel, desc);
        if (res != dnnl_success) return res;

        x64::brgemm_batch_element_t batch_element;
        batch_element.ptr.A = A;
        batch_element.ptr.B = B;
        batch_element.vvpad.top = 0;
        batch_element.vvpad.bottom = 0;
        if (desc.is_tmm) amx_tile_configure(palette);
        brgemm_kernel_execute(kernel, p.bs, &batch_element, C,
                desc.is_tmm ? tile_buffer : nullptr);

        brgemm_kernel_destroy(kernel);
        if (desc.is_tmm) amx_tile_release();

        return res;
    }

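    // Prepares input and reference data, runs brgemm, and validates the
    // result; any non-success status is rethrown as a test error.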
    template <typename a_dt, typename b_dt, typename c_dt>
    void test_brgemm(const brgemm_params_t &p) {
        gemm_data_ = {};
        prepare_data_for_gemm_testing<a_dt, b_dt, c_dt>(p, gemm_data_, *eng_);

        dnnl_status_t status = run_brgemm<a_dt, b_dt, c_dt>(p);
        if (status != dnnl_success)
            throw error(status, "oneDNN brgemm returned an error");

        validate<a_dt, b_dt, c_dt>(p, gemm_data_);
    }

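    // Dispatches to the typed test based on the data-type pair: with AMX the
    // s8s8 and u8u8 combinations are exercised, otherwise the int8 kernels
    // use u8s8 (row-major) or s8u8 (column-major).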
    void run_proper_test(const brgemm_params_t &p) {
        using namespace impl::cpu::x64;

        if (dnnl::mayiuse(cpu_isa::avx512_core_amx)) {
            if (p.dt_a == dnnl_f32 && p.dt_b == dnnl_f32)
                test_brgemm<float, float, float>(p);
            else if (p.dt_a == dnnl_bf16 && p.dt_b == dnnl_bf16)
                test_brgemm<bfloat16_t, bfloat16_t, float>(p);
            else if (p.dt_a == dnnl_s8 && p.dt_b == dnnl_s8)
                test_brgemm<int8_t, int8_t, int32_t>(p);
            else if (p.dt_a == dnnl_u8 && p.dt_b == dnnl_u8)
                test_brgemm<uint8_t, uint8_t, int32_t>(p);
            else
                throw error(dnnl_unimplemented, "Brgemm unimplemented.");
        } else {
            if (p.dt_a == dnnl_f32 && p.dt_b == dnnl_f32)
                test_brgemm<float, float, float>(p);
            else if (p.dt_a == dnnl_bf16 && p.dt_b == dnnl_bf16)
                test_brgemm<bfloat16_t, bfloat16_t, float>(p);
            else if (p.dt_a == dnnl_u8 && p.dt_b == dnnl_s8) {
                assert(p.layout == brgemm_layout_t::brgemm_row_major);
                test_brgemm<uint8_t, int8_t, int32_t>(p);
            } else if (p.dt_a == dnnl_s8 && p.dt_b == dnnl_u8) {
                assert(p.layout == brgemm_layout_t::brgemm_col_major);
                test_brgemm<int8_t, uint8_t, int32_t>(p);
            } else
                throw error(dnnl_unimplemented, "Brgemm unimplemented.");
        }
    }

    std::shared_ptr<engine> eng_;
    test_gemm_data gemm_data_;
    std::shared_ptr<test_memory> b_mem_reordered_;
};

TEST_P(brgemm_test_t, TestsBRGEMM) {}
INSTANTIATE_TEST_SUITE_P(TestBRGEMMSimple, brgemm_test_t,
        ::testing::ValuesIn(params_creator_t().create_simple_brgemm_params()));

} // namespace dnnl