1/*******************************************************************************
2* Copyright 2018-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef TEST_GEMM_COMMON_H
18#define TEST_GEMM_COMMON_H
19
20#include <cstdint>
21#include <utility>
22#include <vector>
23#include <type_traits>
24
25#include "test_gemm_data_preparation.hpp"
26#include "test_gemm_params.hpp"
27#include "test_gemm_validation.hpp"
28
29#include "dnnl_test_common.hpp"
30#include "dnnl_thread.hpp"
31#include "gtest/gtest.h"
32
33#include "oneapi/dnnl/dnnl.h"
34#include "oneapi/dnnl/dnnl_types.h"
35
36#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
37#include "oneapi/dnnl/dnnl_ocl.hpp"
38#endif
39
40#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
41#include "oneapi/dnnl/dnnl_sycl.hpp"
42#endif
43
44#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
45#include "oneapi/dnnl/dnnl_threadpool.hpp"
46#include "tests/test_thread.hpp"
47#endif
48
49#include "tests/test_isa_common.hpp"
50
51#define CONCAT_WITH_UNDERSCORE_(a, b) a##_##b
52#define CONCAT_WITH_UNDERSCORE(a, b) CONCAT_WITH_UNDERSCORE_(a, b)
53
54#define INST_TEST_CASE_(str, ...) \
55 INSTANTIATE_TEST_SUITE_P(str, gemm_test, ::testing::Values(__VA_ARGS__))
56#define INST_TEST_CASE(str, ...) \
57 INST_TEST_CASE_( \
58 CONCAT_WITH_UNDERSCORE(str, TEST_CASE_NAME_PREFIX), __VA_ARGS__)
59
60#define CPU_INST_TEST_CASE_(str, ...) \
61 CPU_INSTANTIATE_TEST_SUITE_P(str, gemm_test, ::testing::Values(__VA_ARGS__))
62#define CPU_INST_TEST_CASE(str, ...) \
63 CPU_INST_TEST_CASE_( \
64 CONCAT_WITH_UNDERSCORE(str, TEST_CASE_NAME_PREFIX), __VA_ARGS__)
65
66// Declare bfloat16 GEMM interfaces for testing
67extern "C" {
68dnnl_status_t dnnl_gemm_bf16bf16f32(char transa, char transb, dnnl_dim_t M,
69 dnnl_dim_t N, dnnl_dim_t K, float alpha, const bfloat16_t *A,
70 dnnl_dim_t lda, const bfloat16_t *B, dnnl_dim_t ldb, float beta,
71 float *C, dnnl_dim_t ldc);
72}
73
74// Declare packed GEMM interfaces for testing
75#include "src/cpu/gemm/gemm_pack.hpp"
76
77namespace dnnl {
78
#if defined(DNNL_WITH_SYCL)
// Returns true when the test memory object is backed by a SYCL buffer (as
// opposed to USM). Marked inline because this helper is defined in a header
// that is included from multiple test translation units.
// NOTE(review): the guard previously read DNNL_WTIH_SYCL (typo), which
// silently compiled this helper out in all configurations.
inline bool is_memory_kind_buffer(const test_memory &mem) {
    return sycl_interop::get_memory_kind(mem.get())
            == sycl_interop::memory_kind::buffer;
}
#endif
85
86/* Test implementation description.
87 * The testing steps looks as follows:
88 * 0. Prepare mapper_m and mapper_n <- details in test_gemm_data_preparation.hpp
89 * 1.a Generate random matrices A', B', C'
90 * 1.b Prepare matrices A, B, C based on A', B', and C' respectively
91 * 2. Compute C_calc := Op(M, N, K, A, B, C)
92 * 3. Compute C'_ref := Op_REF(M_test, N_test, K, A', B', C')
93 * 4. Expand C'_ref to C_ref, by applying mapper_m and mapper_n
94 * 5. Compare C_calc and C_ref
95 */
96
97template <typename a_dt, typename b_dt, typename c_dt>
98struct dnnl_gemm {
99 static dnnl_status_t call(test_params &p, const test_memory &a_mem,
100 const test_memory &b_mem, const test_memory &c_mem) {
101 throw error(dnnl_runtime_error, "unknown gemm");
102 }
103};
104
105template <>
106struct dnnl_gemm<float16_t, float16_t, float16_t> {
107 static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
108 const test_memory &b_mem, const test_memory &c_mem,
109 const test_memory &) {
110 throw error(dnnl_runtime_error, "unknown gemm");
111 }
112};
113
/* f32 GEMM driver: dispatches to the public dnnl_sgemm (or the threadpool
 * interop variant) or, when packing is requested, to the internal
 * sgemm_pack/sgemm_compute API. */
template <>
struct dnnl_gemm<float, float, float> {
    // Runs sgemm through the internal packed API. The test parameters are
    // row-major while the internal API is column-major (Fortran), so the
    // roles of A and B (and of m/n, lda/ldb, pack flags) are swapped.
    static dnnl_status_t call_packed(const test_params &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        // Pack API path does not support alpha scaling.
        assert(p.alpha == 1.f);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        std::vector<float> a_pack_buf, b_pack_buf;
        // Note the swap: Fortran-style "A" maps to the row-major test's B.
        float *A = map_memory<float>(b_mem), *a_eff = A;
        float *B = map_memory<float>(a_mem), *b_eff = B;
        float *C = map_memory<float>(c_mem);

        // Pack flags swap along with the matrices themselves.
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            // sgemm_pack_get_size may clear pack_a when the implementation
            // decides packing "A" is not applicable; hence the re-check below.
            status = sgemm_pack_get_size("A", &trans_a, &trans_b, &m, &n, &k,
                    &lda, &ldb, &a_sz, &pack_a);
            if (status != dnnl_success) return status;

            if (pack_a) {
                // a_sz is in bytes; the buffer holds floats.
                a_pack_buf.resize(a_sz / sizeof(float));
                a_eff = a_pack_buf.data();

                status = sgemm_pack("A", &trans_a, &trans_b, &m, &n, &k, &lda,
                        &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;
            status = sgemm_pack_get_size("B", &trans_a, &trans_b, &m, &n, &k,
                    &lda, &ldb, &b_sz, &pack_b);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz / sizeof(float));
                b_eff = b_pack_buf.data();

                status = sgemm_pack("B", &trans_a, &trans_b, &m, &n, &k, &lda,
                        &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        // 'P' marks a matrix as pre-packed for the compute call.
        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = sgemm_compute(&trans_a, &trans_b, &m, &n, &k, a_eff, &lda,
                b_eff, &ldb, &p.beta, C, &ldc);

        return status;
    }

    // Plain f32 GEMM entry point. The trailing unnamed test_memory is the
    // C-offset buffer, unused for f32 but kept for a uniform call signature.
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {

        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem);

        auto A = map_memory<float>(a_mem);
        auto B = map_memory<float>(b_mem);
        auto C = map_memory<float>(c_mem);

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        // Capture the active threadpool once, deactivate it for the duration
        // of this call (scoped_tp_deactivation_t), and pass it explicitly to
        // the threadpool interop entry point.
        static auto *st
                = impl::testing_threadpool_utils::get_active_threadpool();
        testing::scoped_tp_deactivation_t std;
        return static_cast<dnnl_status_t>(dnnl::threadpool_interop::sgemm(
                p.transA, p.transB, p.M, p.N, p.K, p.alpha, A, p.lda, B, p.ldb,
                p.beta, C, p.ldc, st));
#else
        return dnnl_sgemm(p.transA, p.transB, p.M, p.N, p.K, p.alpha, A, p.lda,
                B, p.ldb, p.beta, C, p.ldc);
#endif
    }
};
208
/* s8s8s32 GEMM driver: dispatches to dnnl_gemm_s8s8s32 (or the threadpool
 * interop variant) or, when packing is requested, to the internal
 * gemm_s8s8s32 pack/compute API. */
template <>
struct dnnl_gemm<int8_t, int8_t, int32_t> {
    // Packed-API path; see the f32 specialization for the row-major ->
    // column-major (Fortran) A<->B remapping rationale.
    static dnnl_status_t call_packed(const test_params &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem, const test_memory &oc_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        // Pack API path supports neither alpha scaling nor A/B zero points.
        assert(p.alpha == 1.f);
        assert(p.igemm_params.oa() == 0);
        assert(p.igemm_params.ob() == 0);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        // Note the swap: Fortran-style "A" maps to the row-major test's B.
        int8_t *A = map_memory<int8_t>(b_mem), *a_eff = A;
        int8_t *B = map_memory<int8_t>(a_mem), *b_eff = B;

        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);

        // Row/column C-offset modes swap under the layout transposition.
        char offset_c = '\0';
        switch (p.igemm_params.offsetc) {
            case 'R': offset_c = 'C'; break;
            case 'r': offset_c = 'c'; break;
            case 'C': offset_c = 'R'; break;
            case 'c': offset_c = 'r'; break;
            default: offset_c = p.igemm_params.offsetc;
        }

        std::vector<int8_t> a_pack_buf;
        std::vector<int8_t> b_pack_buf;
        // Pack flags swap along with the matrices.
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            status = gemm_s8s8s32_pack_get_size(
                    "A", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &a_sz);
            if (status != dnnl_success) return status;

            if (pack_a) {
                // a_sz is in bytes; elements are 1 byte each.
                a_pack_buf.resize(a_sz);
                a_eff = a_pack_buf.data();

                status = gemm_s8s8s32_pack("A", &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;

            status = gemm_s8s8s32_pack_get_size(
                    "B", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &b_sz);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz);
                b_eff = b_pack_buf.data();

                status = gemm_s8s8s32_pack("B", &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        // 'P' marks a matrix as pre-packed for the compute call.
        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = gemm_s8s8s32_compute(&trans_a, &trans_b, &offset_c, &m, &n, &k,
                a_eff, &lda, b_eff, &ldb, &p.beta, C, &ldc, oc);

        return status;
    }

    // s8s8s32 GEMM entry point; oc_mem carries the C offsets selected by
    // p.igemm_params.offsetc.
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {

        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem, oc_mem);

        auto A = map_memory<int8_t>(a_mem);
        auto B = map_memory<int8_t>(b_mem);
        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);
        int8_t oa = p.igemm_params.oa();
        int8_t ob = p.igemm_params.ob();

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        // Capture the active threadpool once, deactivate it for this call,
        // and pass it explicitly to the interop entry point.
        static auto *st
                = impl::testing_threadpool_utils::get_active_threadpool();
        testing::scoped_tp_deactivation_t std;
        return static_cast<dnnl_status_t>(
                dnnl::threadpool_interop::gemm_s8s8s32(p.transA, p.transB,
                        p.igemm_params.offsetc, p.M, p.N, p.K, p.alpha, A,
                        p.lda, oa, B, p.ldb, ob, p.beta, C, p.ldc, oc, st));
#else
        return dnnl_gemm_s8s8s32(p.transA, p.transB, p.igemm_params.offsetc,
                p.M, p.N, p.K, p.alpha, A, p.lda, oa, B, p.ldb, ob, p.beta, C,
                p.ldc, oc);
#endif
    }
};
323
324template <>
325struct dnnl_gemm<int8_t, uint8_t, int32_t> {
326 static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
327 const test_memory &b_mem, const test_memory &c_mem,
328 const test_memory &oc_mem) {
329 throw error(dnnl_runtime_error, "unknown gemm");
330 }
331};
332
333template <>
334struct dnnl_gemm<uint8_t, uint8_t, int32_t> {
335 static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
336 const test_memory &b_mem, const test_memory &c_mem,
337 const test_memory &oc_mem) {
338
339 throw error(dnnl_runtime_error, "unknown gemm");
340 }
341};
342
/* u8s8s32 GEMM driver: dispatches to dnnl_gemm_u8s8s32 (or the threadpool
 * interop variant) or, when packing is requested, to the internal
 * gemm_s8u8s32 pack/compute API. (The internal API is column-major, so
 * the row-major u8 A becomes the Fortran-style "B" of s8u8s32.) */
template <>
struct dnnl_gemm<uint8_t, int8_t, int32_t> {
    // Packed-API path; see the f32 specialization for the row-major ->
    // column-major (Fortran) A<->B remapping rationale.
    static dnnl_status_t call_packed(const test_params &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem, const test_memory &oc_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        // Pack API path supports neither alpha scaling nor A/B zero points.
        assert(p.alpha == 1.f);
        assert(p.igemm_params.oa() == 0);
        assert(p.igemm_params.ob() == 0);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        // Note the swap: Fortran-style "A" (s8) is the row-major test's B,
        // and Fortran-style "B" (u8) is the row-major test's A.
        int8_t *A = map_memory<int8_t>(b_mem), *a_eff = A;
        uint8_t *B = map_memory<uint8_t>(a_mem), *b_eff = B;

        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);

        // Row/column C-offset modes swap under the layout transposition.
        char offset_c = '\0';
        switch (p.igemm_params.offsetc) {
            case 'R': offset_c = 'C'; break;
            case 'r': offset_c = 'c'; break;
            case 'C': offset_c = 'R'; break;
            case 'c': offset_c = 'r'; break;
            default: offset_c = p.igemm_params.offsetc;
        }

        std::vector<int8_t> a_pack_buf;
        std::vector<uint8_t> b_pack_buf;
        // Pack flags swap along with the matrices.
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            status = gemm_s8u8s32_pack_get_size(
                    "A", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &a_sz);
            if (status != dnnl_success) return status;

            if (pack_a) {
                // a_sz is in bytes; elements are 1 byte each.
                a_pack_buf.resize(a_sz);
                a_eff = a_pack_buf.data();

                status = gemm_s8u8s32_pack("A", &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;

            status = gemm_s8u8s32_pack_get_size(
                    "B", &trans_a, &trans_b, &m, &n, &k, &lda, &ldb, &b_sz);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz);
                b_eff = b_pack_buf.data();

                status = gemm_s8u8s32_pack("B", &trans_a, &trans_b, &m, &n, &k,
                        &lda, &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        // 'P' marks a matrix as pre-packed for the compute call.
        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = gemm_s8u8s32_compute(&trans_a, &trans_b, &offset_c, &m, &n, &k,
                a_eff, &lda, b_eff, &ldb, &p.beta, C, &ldc, oc);

        return status;
    }

    // u8s8s32 GEMM entry point; oc_mem carries the C offsets selected by
    // p.igemm_params.offsetc.
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &oc_mem) {
        // oa must be non-negative so the uint8_t cast below is lossless.
        assert(p.igemm_params.oa() >= 0);

        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem, oc_mem);

        auto A = map_memory<uint8_t>(a_mem);
        auto B = map_memory<int8_t>(b_mem);
        auto C = map_memory<int32_t>(c_mem);
        auto oc = map_memory<int32_t>(oc_mem);
        uint8_t oa = (uint8_t)p.igemm_params.oa();
        int8_t ob = p.igemm_params.ob();

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
        // Capture the active threadpool once, deactivate it for this call,
        // and pass it explicitly to the interop entry point.
        static auto *st
                = impl::testing_threadpool_utils::get_active_threadpool();
        testing::scoped_tp_deactivation_t std;
        return static_cast<dnnl_status_t>(
                dnnl::threadpool_interop::gemm_u8s8s32(p.transA, p.transB,
                        p.igemm_params.offsetc, p.M, p.N, p.K, p.alpha, A,
                        p.lda, oa, B, p.ldb, ob, p.beta, C, p.ldc, oc, st));
#else
        return dnnl_gemm_u8s8s32(p.transA, p.transB, p.igemm_params.offsetc,
                p.M, p.N, p.K, p.alpha, A, p.lda, oa, B, p.ldb, ob, p.beta, C,
                p.ldc, oc);
#endif
    }
};
458
459template <>
460struct dnnl_gemm<float16_t, float16_t, float> {
461 static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
462 const test_memory &b_mem, const test_memory &c_mem,
463 const test_memory &) {
464 return dnnl_unimplemented;
465 }
466};
467
/* bf16bf16f32 GEMM driver: dispatches to the extern "C"
 * dnnl_gemm_bf16bf16f32 declared above or, when packing is requested, to the
 * internal gemm_bf16bf16f32 pack/compute API. */
template <>
struct dnnl_gemm<bfloat16_t, bfloat16_t, float> {
    // Packed-API path; see the f32 specialization for the row-major ->
    // column-major (Fortran) A<->B remapping rationale.
    static dnnl_status_t call_packed(const test_params &p,
            const test_memory &a_mem, const test_memory &b_mem,
            const test_memory &c_mem) {
        /* Alas, the internal API still uses Fortran notation.
         * So in addition to the changes for pack API, we also need to take
         * care of conversions and layouts */

        using namespace dnnl::impl::cpu;

        // Pack API path does not support alpha scaling.
        assert(p.alpha == 1.f);

        /* Prepare for Fortran style, hence A <-> B */
        char trans_a = p.transB, trans_b = p.transA;

        int64_t m = p.N, n = p.M, k = p.K;
        int64_t lda = p.ldb, ldb = p.lda, ldc = p.ldc;

        std::vector<bfloat16_t> a_pack_buf, b_pack_buf;
        // Note the swap: Fortran-style "A" maps to the row-major test's B.
        bfloat16_t *A = map_memory<bfloat16_t>(b_mem), *a_eff = A;
        bfloat16_t *B = map_memory<bfloat16_t>(a_mem), *b_eff = B;
        float *C = map_memory<float>(c_mem);

        // Pack flags swap along with the matrices.
        bool pack_a = p.pack_params.pack_b;
        bool pack_b = p.pack_params.pack_a;

        dnnl_status_t status = dnnl_success;

        if (pack_a) {
            size_t a_sz;
            // The get_size call may clear pack_a when packing "A" is not
            // applicable; hence the re-check below.
            status = gemm_bf16bf16f32_pack_get_size("A", &trans_a, &trans_b, &m,
                    &n, &k, &lda, &ldb, &a_sz, &pack_a);
            if (status != dnnl_success) return status;

            if (pack_a) {
                // a_sz is in bytes; the buffer holds bfloat16_t elements.
                a_pack_buf.resize(a_sz / sizeof(*a_eff));
                a_eff = a_pack_buf.data();

                status = gemm_bf16bf16f32_pack("A", &trans_a, &trans_b, &m, &n,
                        &k, &lda, &ldb, A, a_eff);
                if (status != dnnl_success) return status;
            }
        }

        if (pack_b) {
            size_t b_sz;
            status = gemm_bf16bf16f32_pack_get_size("B", &trans_a, &trans_b, &m,
                    &n, &k, &lda, &ldb, &b_sz, &pack_b);
            if (status != dnnl_success) return status;

            if (pack_b) {
                b_pack_buf.resize(b_sz / sizeof(*b_eff));
                b_eff = b_pack_buf.data();

                status = gemm_bf16bf16f32_pack("B", &trans_a, &trans_b, &m, &n,
                        &k, &lda, &ldb, B, b_eff);
                if (status != dnnl_success) return status;
            }
        }

        // 'P' marks a matrix as pre-packed for the compute call.
        if (pack_a) trans_a = 'P';
        if (pack_b) trans_b = 'P';

        status = gemm_bf16bf16f32_compute(&trans_a, &trans_b, &m, &n, &k, a_eff,
                &lda, b_eff, &ldb, &p.beta, C, &ldc);

        return status;
    }

    // bf16bf16f32 GEMM entry point. The trailing unnamed test_memory is the
    // C-offset buffer, unused here but kept for a uniform call signature.
    static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
            const test_memory &b_mem, const test_memory &c_mem,
            const test_memory &) {
        if (p.pack_params.pack_a || p.pack_params.pack_b)
            return call_packed(p, a_mem, b_mem, c_mem);

        auto A = map_memory<bfloat16_t>(a_mem);
        auto B = map_memory<bfloat16_t>(b_mem);
        auto C = map_memory<float>(c_mem);
        return dnnl_gemm_bf16bf16f32(p.transA, p.transB, p.M, p.N, p.K, p.alpha,
                A, p.lda, B, p.ldb, p.beta, C, p.ldc);
    }
};
551
552template <>
553struct dnnl_gemm<bfloat16_t, bfloat16_t, bfloat16_t> {
554 static dnnl_status_t call(const test_params &p, const test_memory &a_mem,
555 const test_memory &b_mem, const test_memory &c_mem,
556 const test_memory &) {
557 return dnnl_unimplemented;
558 }
559};
560
561template <typename a_dt, typename b_dt, typename c_dt>
562struct run_test_gemm {
563 static void call(const test_params &p) {
564 if (p.expect_to_fail) {
565 engine eng = get_test_engine();
566 test_memory zero_mem({}, eng);
567 auto status = dnnl_gemm<a_dt, b_dt, c_dt>::call(
568 p, zero_mem, zero_mem, zero_mem, zero_mem);
569 if (status != dnnl_success)
570 throw error(status, "oneDNN gemm returned error");
571 return;
572 }
573
574 engine eng = get_test_engine();
575 test_gemm_data gemm_data;
576 prepare_data_for_gemm_testing<a_dt, b_dt, c_dt>(p, gemm_data, eng);
577
578 auto status = dnnl_gemm<a_dt, b_dt, c_dt>::call(p, *gemm_data.a_mem,
579 *gemm_data.b_mem, *gemm_data.c_mem, *gemm_data.oc_mem);
580
581 if (status == dnnl_success) {
582 validate<a_dt, b_dt, c_dt>(p, gemm_data);
583 }
584
585 if (status != dnnl_success)
586 throw error(status, "oneDNN gemm returned error");
587 }
588};
589
590template <typename a_dt, typename b_dt, typename c_dt>
591class gemm_test_common : public ::testing::TestWithParam<test_params> {
592protected:
593 virtual void SetUp() {
594 const auto &p = ::testing::TestWithParam<test_params>::GetParam();
595
596 SKIP_IF(get_test_engine_kind() == engine::kind::gpu,
597 "GPU GEMM not implemented.");
598
599#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
600 SKIP_IF(get_test_engine_kind() == engine::kind::cpu,
601 "SYCL CPU GEMM not implemented.");
602#endif
603
604 bool zero_off = (p.off.a == 0 && p.off.b == 0 && p.off.c == 0);
605 SKIP_IF(!zero_off && get_test_engine_kind() == engine::kind::cpu,
606 "CPU does not support non-zero offsets.");
607
608 SKIP_IF(unsupported_data_type(data_traits<a_dt>::data_type),
609 "Engine does not support this data type.");
610
611 bool is_f16 = (data_traits<a_dt>::data_type == memory::data_type::f16);
612 SKIP_IF(is_f16 && get_test_engine_kind() == engine::kind::cpu,
613 "CPU does not support f16 data type.");
614
615#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
616 SKIP_IF(get_test_engine_kind() == engine::kind::cpu,
617 "SYCL CPU GEMM not implemented.");
618#endif
619#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
620 SKIP_IF(get_test_engine_kind() == engine::kind::gpu
621 && (data_traits<a_dt>::data_type
622 == memory::data_type::u8
623 || data_traits<a_dt>::data_type
624 == memory::data_type::s8),
625 "SYCL GPU int GEMM not implemented.");
626 SKIP_IF_CUDA(true, "Test not supported in CUDA backend");
627#endif
628
629#if DNNL_X64
630 bool is_bf16bf16f32 = true
631 && data_traits<a_dt>::data_type == memory::data_type::bf16
632 && data_traits<b_dt>::data_type == memory::data_type::bf16
633 && data_traits<c_dt>::data_type == memory::data_type::f32;
634
635 SKIP_IF(is_bf16bf16f32 && get_test_engine_kind() == engine::kind::cpu
636 && !dnnl::mayiuse(cpu_isa::avx512_core),
637 "Skip test for systems that do not support avx512_core.");
638#endif
639
640 bool pack = (p.pack_params.pack_a || p.pack_params.pack_b);
641 SKIP_IF(!DNNL_X64 && pack,
642 "Packed GEMM does not support non-x64 CPUs.");
643 SKIP_IF((p.alpha != 1.f || p.igemm_params.oa() != 0
644 || p.igemm_params.ob() != 0)
645 && pack,
646 "Packed GEMM doesn't support alpha or non-zero offset{A,B}.");
647 SKIP_IF(data_traits<b_dt>::data_type == memory::data_type::u8
648 && get_test_engine_kind() == engine::kind::cpu,
649 "CPU does not support s8u8s32 and u8u8s32 GEMM.");
650 SKIP_IF(data_traits<c_dt>::data_type == memory::data_type::bf16
651 && get_test_engine_kind() == engine::kind::cpu,
652 "CPU does not support bf16bf16bf16 GEMM.");
653
654 catch_expected_failures(
655 [=]() { Test(); }, p.expect_to_fail, p.expected_status, false);
656 }
657 void Test() {
658#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
659 testing::scoped_tp_activation_t sta;
660#endif
661#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
662 if (get_test_engine_kind() == engine::kind::gpu) {
663 const auto &p = ::testing::TestWithParam<test_params>::GetParam();
664
665#if defined(TEST_DNNL_DPCPP_BUFFER)
666 // Test SYCL buffer interfaces
667 run_test_gemm<a_dt, b_dt, c_dt>::call(p);
668#else
669 // Test SYCL USM interfaces
670 bool zero_off = (p.off.a == 0 && p.off.b == 0 && p.off.c == 0);
671 SKIP_IF(!zero_off, "USM interfaces do not support offsets.");
672
673 run_test_gemm<a_dt, b_dt, c_dt>::call(p);
674#endif
675
676 return;
677 }
678#endif
679 const auto &p = ::testing::TestWithParam<test_params>::GetParam();
680 run_test_gemm<a_dt, b_dt, c_dt>::call(p);
681 }
682};
683} // namespace dnnl
684#endif
685