1 | /******************************************************************************* |
2 | * Copyright 2018-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef TEST_GEMM_COMMON_H |
18 | #define TEST_GEMM_COMMON_H |
19 | |
20 | #include <cstdint> |
21 | #include <utility> |
22 | #include <vector> |
23 | #include <type_traits> |
24 | |
25 | #include "test_gemm_data_preparation.hpp" |
26 | #include "test_gemm_params.hpp" |
27 | #include "test_gemm_validation.hpp" |
28 | |
29 | #include "dnnl_test_common.hpp" |
30 | #include "dnnl_thread.hpp" |
31 | #include "gtest/gtest.h" |
32 | |
33 | #include "oneapi/dnnl/dnnl.h" |
34 | #include "oneapi/dnnl/dnnl_types.h" |
35 | |
36 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL |
37 | #include "oneapi/dnnl/dnnl_ocl.hpp" |
38 | #endif |
39 | |
40 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL |
41 | #include "oneapi/dnnl/dnnl_sycl.hpp" |
42 | #endif |
43 | |
44 | #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL |
45 | #include "oneapi/dnnl/dnnl_threadpool.hpp" |
46 | #include "tests/test_thread.hpp" |
47 | #endif |
48 | |
49 | #include "tests/test_isa_common.hpp" |
50 | |
// Join two tokens with an underscore. The two-level macro is required so
// that arguments (e.g. TEST_CASE_NAME_PREFIX) are macro-expanded before
// the ## paste happens.
#define CONCAT_WITH_UNDERSCORE_(a, b) a##_##b
#define CONCAT_WITH_UNDERSCORE(a, b) CONCAT_WITH_UNDERSCORE_(a, b)

// Instantiate a gtest parameterized suite named <str>_<TEST_CASE_NAME_PREFIX>
// for the gemm_test fixture over the given list of test_params values.
#define INST_TEST_CASE_(str, ...) \
    INSTANTIATE_TEST_SUITE_P(str, gemm_test, ::testing::Values(__VA_ARGS__))
#define INST_TEST_CASE(str, ...) \
    INST_TEST_CASE_( \
            CONCAT_WITH_UNDERSCORE(str, TEST_CASE_NAME_PREFIX), __VA_ARGS__)

// Same as INST_TEST_CASE but uses CPU_INSTANTIATE_TEST_SUITE_P (from
// dnnl_test_common.hpp) — presumably a no-op in builds without a CPU
// runtime; confirm against that header.
#define CPU_INST_TEST_CASE_(str, ...) \
    CPU_INSTANTIATE_TEST_SUITE_P(str, gemm_test, ::testing::Values(__VA_ARGS__))
#define CPU_INST_TEST_CASE(str, ...) \
    CPU_INST_TEST_CASE_( \
            CONCAT_WITH_UNDERSCORE(str, TEST_CASE_NAME_PREFIX), __VA_ARGS__)
65 | |
// Declare bfloat16 GEMM interfaces for testing.
// NOTE(review): declared by hand here rather than pulled from a public
// header — presumably this symbol is exported by the library but not part
// of the public dnnl.h API; confirm before relying on it elsewhere.
extern "C" {
dnnl_status_t dnnl_gemm_bf16bf16f32(char transa, char transb, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const bfloat16_t *A,
        dnnl_dim_t lda, const bfloat16_t *B, dnnl_dim_t ldb, float beta,
        float *C, dnnl_dim_t ldc);
}
73 | |
74 | // Declare packed GEMM interfaces for testing |
75 | #include "src/cpu/gemm/gemm_pack.hpp" |
76 | |
77 | namespace dnnl { |
78 | |
// NOTE: the guard previously tested the misspelled macro DNNL_WTIH_SYCL,
// which is never defined anywhere, so this helper was silently dead code.
#if defined(DNNL_WITH_SYCL)
/// Returns true iff @p mem is backed by a SYCL buffer (as opposed to USM).
bool is_memory_kind_buffer(const test_memory &mem) {
    return sycl_interop::get_memory_kind(mem.get())
            == sycl_interop::memory_kind::buffer;
}
#endif
85 | |
86 | /* Test implementation description. |
87 | * The testing steps looks as follows: |
88 | * 0. Prepare mapper_m and mapper_n <- details in test_gemm_data_preparation.hpp |
89 | * 1.a Generate random matrices A', B', C' |
90 | * 1.b Prepare matrices A, B, C based on A', B', and C' respectively |
91 | * 2. Compute C_calc := Op(M, N, K, A, B, C) |
92 | * 3. Compute C'_ref := Op_REF(M_test, N_test, K, A', B', C') |
93 | * 4. Expand C'_ref to C_ref, by applying mapper_m and mapper_n |
94 | * 5. Compare C_calc and C_ref |
95 | */ |
96 | |
97 | template <typename a_dt, typename b_dt, typename c_dt> |
98 | struct dnnl_gemm { |
99 | static dnnl_status_t call(test_params &p, const test_memory &a_mem, |
100 | const test_memory &b_mem, const test_memory &c_mem) { |
101 | throw error(dnnl_runtime_error, "unknown gemm" ); |
102 | } |
103 | }; |
104 | |
// f16 x f16 -> f16: no GEMM API is available for this combination; any
// attempt to run it is reported as a test-level error.
template <>
struct dnnl_gemm<float16_t, float16_t, float16_t> {
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {
        throw error(dnnl_runtime_error, "unknown gemm" );
    }
};
113 | |
template <>
struct dnnl_gemm<float, float, float> {
    // Runs f32 GEMM through the internal pack/compute API when the test
    // requests packing of A and/or B.
    static dnnl_status_t call_packed(const test_params &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        // The pack/compute path takes no alpha argument, so only
        // alpha == 1 can be exercised through it.
        assert(p.alpha == 1.f);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        std::vector<float> a_pack_buf, b_pack_buf;
        // Because of the row-major <-> column-major flip above, Fortran
        // "A" is the test's B matrix and vice versa.
        float *A = map_memory<float>(b_mem), *a_eff = A;
        float *B = map_memory<float>(a_mem), *b_eff = B;
        float *C = map_memory<float>(c_mem);

        // Same A <-> B flip applies to the pack requests.
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            // pack_a is also an out-parameter: the library may report that
            // packing A is unnecessary, in which case it is skipped below.
            status = sgemm_pack_get_size("A" , &trans_a, &trans_b, &m, &n, &k,
                    &lda, &ldb, &a_sz, &pack_a);
            if (status != dnnl_success) return status;

            if (pack_a) {
                // a_sz is in bytes; the buffer holds floats.
                a_pack_buf.resize(a_sz / sizeof(float));
                a_eff = a_pack_buf.data();

                status = sgemm_pack("A" , &trans_a, &trans_b, &m, &n, &k, &lda,
                        &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;
            // Same out-parameter convention as for A above.
            status = sgemm_pack_get_size("B" , &trans_a, &trans_b, &m, &n, &k,
                    &lda, &ldb, &b_sz, &pack_b);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz / sizeof(float));
                b_eff = b_pack_buf.data();

                status = sgemm_pack("B" , &trans_a, &trans_b, &m, &n, &k, &lda,
                        &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        // 'P' tells the compute call the corresponding matrix is packed.
        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = sgemm_compute(&trans_a, &trans_b, &m, &n, &k, a_eff, &lda,
                b_eff, &ldb, &p.beta, C, &ldc);

        return status;
    }

    // Plain f32 GEMM entry point; dispatches to the packed path when the
    // test parameters ask for it. The trailing (unused) argument is the
    // offset-C memory needed only by the integer GEMM specializations.
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {

        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem);

        auto A = map_memory<float>(a_mem);
        auto B = map_memory<float>(b_mem);
        auto C = map_memory<float>(c_mem);

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        // Capture the testing threadpool once, then deactivate it for the
        // duration of the call so the explicitly passed pool (st) is used.
        static auto *st
                = impl::testing_threadpool_utils::get_active_threadpool();
        testing::scoped_tp_deactivation_t std;
        return static_cast<dnnl_status_t>(dnnl::threadpool_interop::sgemm(
                p.transA, p.transB, p.M, p.N, p.K, p.alpha, A, p.lda, B, p.ldb,
                p.beta, C, p.ldc, st));
#else
        return dnnl_sgemm(p.transA, p.transB, p.M, p.N, p.K, p.alpha, A, p.lda,
                B, p.ldb, p.beta, C, p.ldc);
#endif
    }
};
208 | |
template <>
struct dnnl_gemm<int8_t, int8_t, int32_t> {
    // Runs s8s8s32 GEMM through the internal pack/compute API when the
    // test requests packing of A and/or B.
    static dnnl_status_t call_packed(const test_params &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem, const test_memory &oc_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        // The pack/compute path supports neither alpha != 1 nor non-zero
        // A/B offsets.
        assert(p.alpha == 1.f);
        assert(p.igemm_params.oa() == 0);
        assert(p.igemm_params.ob() == 0);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        // After the row-major <-> column-major flip, Fortran "A" is the
        // test's B matrix and vice versa.
        int8_t *A = map_memory<int8_t>(b_mem), *a_eff = A;
        int8_t *B = map_memory<int8_t>(a_mem), *b_eff = B;

        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);

        // Rows and columns of C swap in the Fortran view, so a per-Row
        // offset spec becomes per-Column and vice versa; 'F'(ixed) et al.
        // pass through unchanged.
        char offset_c = '\0';
        switch (p.igemm_params.offsetc) {
            case 'R': offset_c = 'C'; break;
            case 'r': offset_c = 'c'; break;
            case 'C': offset_c = 'R'; break;
            case 'c': offset_c = 'r'; break;
            default: offset_c = p.igemm_params.offsetc;
        }

        std::vector<int8_t> a_pack_buf;
        std::vector<int8_t> b_pack_buf;
        // Same A <-> B flip applies to the pack requests.
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            // Unlike sgemm_pack_get_size, the s8s8s32 variant has no
            // "pack needed?" out-parameter, so the inner re-check of
            // pack_a below can never flip to false.
            status = gemm_s8s8s32_pack_get_size(
                    "A" , &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &a_sz);
            if (status != dnnl_success) return status;

            if (pack_a) {
                // a_sz is in bytes == elements for int8.
                a_pack_buf.resize(a_sz);
                a_eff = a_pack_buf.data();

                status = gemm_s8s8s32_pack("A" , &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;

            status = gemm_s8s8s32_pack_get_size(
                    "B" , &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &b_sz);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz);
                b_eff = b_pack_buf.data();

                status = gemm_s8s8s32_pack("B" , &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        // 'P' tells the compute call the corresponding matrix is packed.
        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = gemm_s8s8s32_compute(&trans_a, &trans_b, &offset_c, &m, &n, &k,
                a_eff, &lda, b_eff, &ldb, &p.beta, C, &ldc, oc);

        return status;
    }

    // s8s8s32 GEMM entry point; dispatches to the packed path when asked.
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {

        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem, oc_mem);

        auto A = map_memory<int8_t>(a_mem);
        auto B = map_memory<int8_t>(b_mem);
        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);
        int8_t oa = p.igemm_params.oa();
        int8_t ob = p.igemm_params.ob();

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        // Capture the testing threadpool once, then deactivate it for the
        // duration of the call so the explicitly passed pool (st) is used.
        static auto *st
                = impl::testing_threadpool_utils::get_active_threadpool();
        testing::scoped_tp_deactivation_t std;
        return static_cast<dnnl_status_t>(
                dnnl::threadpool_interop::gemm_s8s8s32(p.transA, p.transB,
                        p.igemm_params.offsetc, p.M, p.N, p.K, p.alpha, A,
                        p.lda, oa, B, p.ldb, ob, p.beta, C, p.ldc, oc, st));
#else
        return dnnl_gemm_s8s8s32(p.transA, p.transB, p.igemm_params.offsetc,
                p.M, p.N, p.K, p.alpha, A, p.lda, oa, B, p.ldb, ob, p.beta, C,
                p.ldc, oc);
#endif
    }
};
323 | |
// s8 x u8 -> s32: no GEMM API is available for this combination; any
// attempt to run it is reported as a test-level error.
template <>
struct dnnl_gemm<int8_t, uint8_t, int32_t> {
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {
        throw error(dnnl_runtime_error, "unknown gemm" );
    }
};
332 | |
// u8 x u8 -> s32: no GEMM API is available for this combination; any
// attempt to run it is reported as a test-level error.
template <>
struct dnnl_gemm<uint8_t, uint8_t, int32_t> {
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {

        throw error(dnnl_runtime_error, "unknown gemm" );
    }
};
342 | |
template <>
struct dnnl_gemm<uint8_t, int8_t, int32_t> {
    // Runs u8s8s32 GEMM through the internal pack/compute API when the
    // test requests packing. Note: after the row-major <-> column-major
    // swap, row-major u8s8s32 maps onto the Fortran-style s8u8s32 calls.
    static dnnl_status_t call_packed(const test_params &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem, const test_memory &oc_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        // The pack/compute path supports neither alpha != 1 nor non-zero
        // A/B offsets.
        assert(p.alpha == 1.f);
        assert(p.igemm_params.oa() == 0);
        assert(p.igemm_params.ob() == 0);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        // Fortran "A" is the test's (int8) B matrix and Fortran "B" the
        // test's (uint8) A matrix.
        int8_t *A = map_memory<int8_t>(b_mem), *a_eff = A;
        uint8_t *B = map_memory<uint8_t>(a_mem), *b_eff = B;

        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);

        // Rows and columns of C swap in the Fortran view, so a per-Row
        // offset spec becomes per-Column and vice versa.
        char offset_c = '\0';
        switch (p.igemm_params.offsetc) {
            case 'R': offset_c = 'C'; break;
            case 'r': offset_c = 'c'; break;
            case 'C': offset_c = 'R'; break;
            case 'c': offset_c = 'r'; break;
            default: offset_c = p.igemm_params.offsetc;
        }

        std::vector<int8_t> a_pack_buf;
        std::vector<uint8_t> b_pack_buf;
        // Same A <-> B flip applies to the pack requests.
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            // No "pack needed?" out-parameter here (unlike sgemm), so the
            // inner re-check of pack_a below can never flip to false.
            status = gemm_s8u8s32_pack_get_size(
                    "A" , &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &a_sz);
            if (status != dnnl_success) return status;

            if (pack_a) {
                // a_sz is in bytes == elements for int8.
                a_pack_buf.resize(a_sz);
                a_eff = a_pack_buf.data();

                status = gemm_s8u8s32_pack("A" , &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;

            status = gemm_s8u8s32_pack_get_size(
                    "B" , &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &b_sz);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz);
                b_eff = b_pack_buf.data();

                status = gemm_s8u8s32_pack("B" , &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        // 'P' tells the compute call the corresponding matrix is packed.
        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = gemm_s8u8s32_compute(&trans_a, &trans_b, &offset_c, &m, &n, &k,
                a_eff, &lda, b_eff, &ldb, &p.beta, C, &ldc, oc);

        return status;
    }

    // u8s8s32 GEMM entry point; dispatches to the packed path when asked.
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {
        // oa is stored signed in the params but passed as uint8_t below,
        // hence it must be non-negative.
        assert(p.igemm_params.oa() >= 0);

        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem, oc_mem);

        auto A = map_memory<uint8_t>(a_mem);
        auto B = map_memory<int8_t>(b_mem);
        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);
        uint8_t oa = (uint8_t)p.igemm_params.oa();
        int8_t ob = p.igemm_params.ob();

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        // Capture the testing threadpool once, then deactivate it for the
        // duration of the call so the explicitly passed pool (st) is used.
        static auto *st
                = impl::testing_threadpool_utils::get_active_threadpool();
        testing::scoped_tp_deactivation_t std;
        return static_cast<dnnl_status_t>(
                dnnl::threadpool_interop::gemm_u8s8s32(p.transA, p.transB,
                        p.igemm_params.offsetc, p.M, p.N, p.K, p.alpha, A,
                        p.lda, oa, B, p.ldb, ob, p.beta, C, p.ldc, oc, st));
#else
        return dnnl_gemm_u8s8s32(p.transA, p.transB, p.igemm_params.offsetc,
                p.M, p.N, p.K, p.alpha, A, p.lda, oa, B, p.ldb, ob, p.beta, C,
                p.ldc, oc);
#endif
    }
};
458 | |
// f16 x f16 -> f32: recognized combination, but no implementation to test
// yet; report as unimplemented rather than throwing.
template <>
struct dnnl_gemm<float16_t, float16_t, float> {
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {
        return dnnl_unimplemented;
    }
};
467 | |
template <>
struct dnnl_gemm<bfloat16_t, bfloat16_t, float> {
    // Runs bf16bf16f32 GEMM through the internal pack/compute API when the
    // test requests packing of A and/or B.
    static dnnl_status_t call_packed(const test_params &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        // The pack/compute path takes no alpha argument, so only
        // alpha == 1 can be exercised through it.
        assert(p.alpha == 1.f);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        std::vector<bfloat16_t> a_pack_buf, b_pack_buf;
        // Because of the row-major <-> column-major flip above, Fortran
        // "A" is the test's B matrix and vice versa.
        bfloat16_t *A = map_memory<bfloat16_t>(b_mem), *a_eff = A;
        bfloat16_t *B = map_memory<bfloat16_t>(a_mem), *b_eff = B;
        float *C = map_memory<float>(c_mem);

        // Same A <-> B flip applies to the pack requests.
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            // pack_a is also an out-parameter: the library may report that
            // packing A is unnecessary, in which case it is skipped below.
            status = gemm_bf16bf16f32_pack_get_size("A" , &trans_a, &trans_b, &m,
                    &n, &k, &lda, &ldb, &a_sz, &pack_a);
            if (status != dnnl_success) return status;

            if (pack_a) {
                // a_sz is in bytes; the buffer holds bfloat16 elements.
                a_pack_buf.resize(a_sz / sizeof(*a_eff));
                a_eff = a_pack_buf.data();

                status = gemm_bf16bf16f32_pack("A" , &trans_a, &trans_b, &m, &n,
                        &k, &lda, &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;
            // Same out-parameter convention as for A above.
            status = gemm_bf16bf16f32_pack_get_size("B" , &trans_a, &trans_b, &m,
                    &n, &k, &lda, &ldb, &b_sz, &pack_b);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz / sizeof(*b_eff));
                b_eff = b_pack_buf.data();

                status = gemm_bf16bf16f32_pack("B" , &trans_a, &trans_b, &m, &n,
                        &k, &lda, &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        // 'P' tells the compute call the corresponding matrix is packed.
        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = gemm_bf16bf16f32_compute(&trans_a, &trans_b, &m, &n, &k, a_eff,
                &lda, b_eff, &ldb, &p.beta, C, &ldc);

        return status;
    }

    // bf16bf16f32 GEMM entry point; dispatches to the packed path when
    // asked, otherwise calls the hand-declared dnnl_gemm_bf16bf16f32.
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {
        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem);

        auto A = map_memory<bfloat16_t>(a_mem);
        auto B = map_memory<bfloat16_t>(b_mem);
        auto C = map_memory<float>(c_mem);
        return dnnl_gemm_bf16bf16f32(p.transA, p.transB, p.M, p.N, p.K, p.alpha,
                A, p.lda, B, p.ldb, p.beta, C, p.ldc);
    }
};
551 | |
// bf16 x bf16 -> bf16: recognized combination, but no implementation to
// test yet; report as unimplemented rather than throwing.
template <>
struct dnnl_gemm<bfloat16_t, bfloat16_t, bfloat16_t> {
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {
        return dnnl_unimplemented;
    }
};
560 | |
561 | template <typename a_dt, typename b_dt, typename c_dt> |
562 | struct run_test_gemm { |
563 | static void call(const test_params &p) { |
564 | if (p.expect_to_fail) { |
565 | engine eng = get_test_engine(); |
566 | test_memory zero_mem({}, eng); |
567 | auto status = dnnl_gemm<a_dt, b_dt, c_dt>::call( |
568 | p, zero_mem, zero_mem, zero_mem, zero_mem); |
569 | if (status != dnnl_success) |
570 | throw error(status, "oneDNN gemm returned error" ); |
571 | return; |
572 | } |
573 | |
574 | engine eng = get_test_engine(); |
575 | test_gemm_data gemm_data; |
576 | prepare_data_for_gemm_testing<a_dt, b_dt, c_dt>(p, gemm_data, eng); |
577 | |
578 | auto status = dnnl_gemm<a_dt, b_dt, c_dt>::call(p, *gemm_data.a_mem, |
579 | *gemm_data.b_mem, *gemm_data.c_mem, *gemm_data.oc_mem); |
580 | |
581 | if (status == dnnl_success) { |
582 | validate<a_dt, b_dt, c_dt>(p, gemm_data); |
583 | } |
584 | |
585 | if (status != dnnl_success) |
586 | throw error(status, "oneDNN gemm returned error" ); |
587 | } |
588 | }; |
589 | |
590 | template <typename a_dt, typename b_dt, typename c_dt> |
591 | class gemm_test_common : public ::testing::TestWithParam<test_params> { |
592 | protected: |
593 | virtual void SetUp() { |
594 | const auto &p = ::testing::TestWithParam<test_params>::GetParam(); |
595 | |
596 | SKIP_IF(get_test_engine_kind() == engine::kind::gpu, |
597 | "GPU GEMM not implemented." ); |
598 | |
599 | #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL |
600 | SKIP_IF(get_test_engine_kind() == engine::kind::cpu, |
601 | "SYCL CPU GEMM not implemented." ); |
602 | #endif |
603 | |
604 | bool zero_off = (p.off.a == 0 && p.off.b == 0 && p.off.c == 0); |
605 | SKIP_IF(!zero_off && get_test_engine_kind() == engine::kind::cpu, |
606 | "CPU does not support non-zero offsets." ); |
607 | |
608 | SKIP_IF(unsupported_data_type(data_traits<a_dt>::data_type), |
609 | "Engine does not support this data type." ); |
610 | |
611 | bool is_f16 = (data_traits<a_dt>::data_type == memory::data_type::f16); |
612 | SKIP_IF(is_f16 && get_test_engine_kind() == engine::kind::cpu, |
613 | "CPU does not support f16 data type." ); |
614 | |
615 | #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL |
616 | SKIP_IF(get_test_engine_kind() == engine::kind::cpu, |
617 | "SYCL CPU GEMM not implemented." ); |
618 | #endif |
619 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL |
620 | SKIP_IF(get_test_engine_kind() == engine::kind::gpu |
621 | && (data_traits<a_dt>::data_type |
622 | == memory::data_type::u8 |
623 | || data_traits<a_dt>::data_type |
624 | == memory::data_type::s8), |
625 | "SYCL GPU int GEMM not implemented." ); |
626 | SKIP_IF_CUDA(true, "Test not supported in CUDA backend" ); |
627 | #endif |
628 | |
629 | #if DNNL_X64 |
630 | bool is_bf16bf16f32 = true |
631 | && data_traits<a_dt>::data_type == memory::data_type::bf16 |
632 | && data_traits<b_dt>::data_type == memory::data_type::bf16 |
633 | && data_traits<c_dt>::data_type == memory::data_type::f32; |
634 | |
635 | SKIP_IF(is_bf16bf16f32 && get_test_engine_kind() == engine::kind::cpu |
636 | && !dnnl::mayiuse(cpu_isa::avx512_core), |
637 | "Skip test for systems that do not support avx512_core." ); |
638 | #endif |
639 | |
640 | bool pack = (p.pack_params.pack_a || p.pack_params.pack_b); |
641 | SKIP_IF(!DNNL_X64 && pack, |
642 | "Packed GEMM does not support non-x64 CPUs." ); |
643 | SKIP_IF((p.alpha != 1.f || p.igemm_params.oa() != 0 |
644 | || p.igemm_params.ob() != 0) |
645 | && pack, |
646 | "Packed GEMM doesn't support alpha or non-zero offset{A,B}." ); |
647 | SKIP_IF(data_traits<b_dt>::data_type == memory::data_type::u8 |
648 | && get_test_engine_kind() == engine::kind::cpu, |
649 | "CPU does not support s8u8s32 and u8u8s32 GEMM." ); |
650 | SKIP_IF(data_traits<c_dt>::data_type == memory::data_type::bf16 |
651 | && get_test_engine_kind() == engine::kind::cpu, |
652 | "CPU does not support bf16bf16bf16 GEMM." ); |
653 | |
654 | catch_expected_failures( |
655 | [=]() { Test(); }, p.expect_to_fail, p.expected_status, false); |
656 | } |
657 | void Test() { |
658 | #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL |
659 | testing::scoped_tp_activation_t sta; |
660 | #endif |
661 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL |
662 | if (get_test_engine_kind() == engine::kind::gpu) { |
663 | const auto &p = ::testing::TestWithParam<test_params>::GetParam(); |
664 | |
665 | #if defined(TEST_DNNL_DPCPP_BUFFER) |
666 | // Test SYCL buffer interfaces |
667 | run_test_gemm<a_dt, b_dt, c_dt>::call(p); |
668 | #else |
669 | // Test SYCL USM interfaces |
670 | bool zero_off = (p.off.a == 0 && p.off.b == 0 && p.off.c == 0); |
671 | SKIP_IF(!zero_off, "USM interfaces do not support offsets." ); |
672 | |
673 | run_test_gemm<a_dt, b_dt, c_dt>::call(p); |
674 | #endif |
675 | |
676 | return; |
677 | } |
678 | #endif |
679 | const auto &p = ::testing::TestWithParam<test_params>::GetParam(); |
680 | run_test_gemm<a_dt, b_dt, c_dt>::call(p); |
681 | } |
682 | }; |
683 | } // namespace dnnl |
684 | #endif |
685 | |