1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22#include <vector>
23
24namespace dnnl {
25
namespace P {
// Bit-flag encoding of test-case properties. A single `unsigned` packs:
// what the flag applies to (scales / zero points), which matrix it refers
// to (bits 20..22), and the quantization mask kind (bits 0..1).
// These are compile-time constants, so declare them constexpr instead of
// mutable globals.

// Common
constexpr unsigned NONE = 0u;

// The dims of the corresponding memory are runtime-defined.
constexpr unsigned RUNTIME = 1u << 31;

// Which attribute feature the flags word describes.
constexpr unsigned SCALES = 1u << 30;
constexpr unsigned ZERO_POINTS = 1u << 29;

// Use explicit strides with a padded (non-dense) leading dimension.
constexpr unsigned LEADING_DIM = 1u << 28;

// matrices indices: 1 .. 7
// bits reserved: 20 .. 22
constexpr unsigned MATRIX_MASK = 7u << 20;
constexpr unsigned SRC = 1u << 20;
constexpr unsigned WEIGHTS = 2u << 20;
constexpr unsigned DST = 3u << 20;

// scales and zero points: 1 .. 3
// bits reserved: 0 .. 1
constexpr unsigned MASK_MASK = 3u << 0;
constexpr unsigned COMMON = 1u << 0;
constexpr unsigned PER_N = 1u << 1;
} // namespace P
50
// Plain-data description of one matmul problem.
// NOTE: member order is load-bearing — the test cases below rely on
// aggregate initialization.
struct matmul_base_t {
    // Description of a single matrix: dims, data type, layout tag, and
    // P:: bit flags (P::RUNTIME, P::LEADING_DIM, matrix id, ...).
    struct md_t {
        memory::dims dims;
        memory::data_type dt;
        memory::format_tag tag;
        unsigned flags;
    } src, weights, dst;
    // Bias data type; memory::data_type::undef means "no bias".
    memory::data_type bia_dt;
};
60
// TODO: a way to generalize?
// Attribute-related knobs of a test case: scales, zero points, post-ops.
// NOTE: member order is load-bearing — the test cases below rely on
// aggregate initialization, e.g.
// {P::SCALES | ..., {src_zp, wei_zp, dst_zp}, {{post_op_kind, alg}, ...}}.
struct matmul_attr_t {

    // P:: flags describing scales (e.g. P::SCALES | P::COMMON), or P::NONE.
    unsigned scale_flags;

    // Per-tensor P:: flags for zero points; P::NONE disables them.
    struct zero_points_t {
        unsigned src, weights, dst;
    } zero_points;

    // One post-op entry; `alg` is only meaningful for eltwise post-ops.
    struct post_op_t {
        primitive::kind kind;
        algorithm alg;
    };

    std::vector<post_op_t> post_ops;
};
78
// Full parameterization of one test case. If `expect_to_fail` is set,
// primitive (descriptor) creation must fail with `expected_status`.
struct matmul_test_params_t {
    matmul_base_t base;
    matmul_attr_t attr;

    bool expect_to_fail;
    dnnl_status_t expected_status;
};
86
// Short alias for format tags, used throughout the test cases below.
using tag = memory::format_tag;
88
89class matmul_iface_test_t
90 : public ::testing::TestWithParam<matmul_test_params_t> {
91protected:
92 void SetUp() override {
93 matmul_test_params_t p
94 = ::testing::TestWithParam<decltype(p)>::GetParam();
95
96 SKIP_IF(unsupported_data_type(p.base.src.dt),
97 "Engine does not support this data type.");
98 SKIP_IF(unsupported_data_type(p.base.weights.dt),
99 "Engine does not support this data type.");
100 SKIP_IF(unsupported_data_type(p.base.dst.dt),
101 "Engine does not support this data type.");
102 SKIP_IF(unsupported_data_type(p.base.bia_dt),
103 "Engine does not support this data type.");
104 SKIP_IF(get_test_engine_kind() == engine::kind::gpu
105 && ((p.attr.zero_points.src & P::PER_N)
106 || (p.attr.zero_points.dst & P::PER_N)),
107 "Per dimensional zero points are not supported on GPU");
108 SKIP_IF(get_test_engine_kind() == engine::kind::cpu
109 && p.base.src.tag == impl::format_tag::AB8a4b,
110 "Don't test blocked formats on CPU");
111
112 SKIP_IF_CUDA((p.attr.zero_points.src != 0 || p.attr.zero_points.dst != 0
113 || p.attr.zero_points.weights != 0),
114 "Zero points not supported for CUDA");
115
116 SKIP_IF_CUDA((p.attr.scale_flags & P::MASK_MASK) == P::PER_N,
117 "Per dimensional scaling is not supported for CUDA");
118
119 catch_expected_failures(
120 [=]() { Test(); }, p.expect_to_fail, p.expected_status, false);
121 }
122
123 // use `force_no_rt = true` when create final memory
124 static memory::desc init_md(
125 const matmul_base_t::md_t &desc, bool force_no_rt = false) {
126 const bool runtime = force_no_rt ? false : (desc.flags & P::RUNTIME);
127 const bool use_ld = (desc.flags & P::LEADING_DIM);
128
129 memory::dims dims = desc.dims;
130 if (runtime)
131 dims = memory::dims(desc.dims.size(), DNNL_RUNTIME_DIM_VAL);
132
133 if (runtime || use_ld == false)
134 return memory::desc(dims, desc.dt, desc.tag);
135
136 memory::dims strides;
137 switch (desc.tag) {
138 case tag::ab: strides = {dims[1] + 1, 1}; break;
139 case tag::ba: strides = {1, dims[0] + 1}; break;
140 case tag::abc:
141 strides = {dims[1] * (dims[2] + 1) + 1, dims[2] + 1, 1};
142 break;
143 case tag::acb:
144 strides = {dims[1] * (dims[2] + 1) + 1, dims[2] + 1, 1};
145 break;
146 default:
147 throw std::invalid_argument("tag doesn't support custom ld");
148 }
149
150 return memory::desc(dims, desc.dt, strides);
151 }
152
153 static void create_attr(const matmul_test_params_t &p, primitive_attr &attr,
154 memory &zero_points_src_m, memory &zero_points_weights_m,
155 memory &zero_points_dst_m, engine &eng) {
156 const int ndims = (int)p.base.dst.dims.size();
157
158 // zero points
159 auto handle_zero_points = [&](int arg, unsigned flags,
160 const matmul_base_t::md_t &md,
161 memory &zero_points_m) {
162 if (flags == P::NONE) return;
163
164 ASSERT_TRUE(flags & P::ZERO_POINTS);
165 ASSERT_TRUE(flags & P::MATRIX_MASK);
166
167 // sanity check
168 switch (arg) {
169 case DNNL_ARG_SRC:
170 ASSERT_TRUE((flags & P::MATRIX_MASK) == P::SRC);
171 break;
172 case DNNL_ARG_WEIGHTS:
173 ASSERT_TRUE((flags & P::MATRIX_MASK) == P::WEIGHTS);
174 break;
175 case DNNL_ARG_DST:
176 ASSERT_TRUE((flags & P::MATRIX_MASK) == P::DST);
177 break;
178 default: ASSERT_TRUE(!"unreachable");
179 }
180
181 unsigned zero_points_mask = flags & P::MASK_MASK;
182 ASSERT_TRUE(zero_points_mask == P::COMMON
183 || zero_points_mask == P::PER_N);
184 int mask = zero_points_mask == P::PER_N ? 1 << (ndims - 1) : 0;
185 memory::dim zero_points_size = mask ? md.dims[ndims - 1] : 1;
186
187 attr.set_zero_points_mask(arg, mask);
188 zero_points_m = test::make_memory(
189 {{zero_points_size}, memory::data_type::s32, {1}}, eng);
190 auto z = map_memory<int32_t>(zero_points_m);
191 GTEST_EXPECT_NE(z, nullptr);
192 for (memory::dim i = 0; i < zero_points_size; ++i)
193 z[i] = (arg % 7) - 3;
194 };
195
196 handle_zero_points(DNNL_ARG_SRC, p.attr.zero_points.src, p.base.src,
197 zero_points_src_m);
198 handle_zero_points(DNNL_ARG_WEIGHTS, p.attr.zero_points.weights,
199 p.base.weights, zero_points_weights_m);
200 handle_zero_points(DNNL_ARG_DST, p.attr.zero_points.dst, p.base.dst,
201 zero_points_dst_m);
202
203 // post ops
204 post_ops po;
205 for (auto post_op : p.attr.post_ops) {
206 switch (post_op.kind) {
207 case primitive::kind::sum: po.append_sum(); break;
208 case primitive::kind::eltwise:
209 po.append_eltwise(post_op.alg, 0.f, 0.f);
210 break;
211 default: ASSERT_TRUE(!"unknown post op kind");
212 }
213 }
214 attr.set_post_ops(po);
215 }
216
217 void Test() {
218 matmul_test_params_t p
219 = ::testing::TestWithParam<matmul_test_params_t>::GetParam();
220
221 auto eng = get_test_engine();
222 auto strm = make_stream(eng);
223
224 auto check_matrix_flags = [](unsigned flags, unsigned matrix) {
225 if (flags) { ASSERT_EQ(flags & P::MATRIX_MASK, matrix); }
226 };
227 check_matrix_flags(p.base.src.flags, P::SRC);
228 check_matrix_flags(p.base.weights.flags, P::WEIGHTS);
229 check_matrix_flags(p.base.dst.flags, P::DST);
230
231 auto src_md = init_md(p.base.src);
232 auto weights_md = init_md(p.base.weights);
233 auto dst_md = init_md(p.base.dst);
234
235 auto bia_md = memory::desc();
236 memory bia_m;
237 if (p.base.bia_dt != memory::data_type::undef) {
238 memory::dims bia_dims(p.base.dst.dims.size() - 1, 1);
239 bia_dims.push_back(p.base.dst.dims.back());
240 tag bia_tag = bia_dims.size() == 2 ? tag::ab : tag::abc;
241 bia_md = init_md({bia_dims, p.base.bia_dt, bia_tag,
242 p.base.dst.flags & P::RUNTIME});
243 bia_m = test::make_memory(
244 init_md({bia_dims, p.base.bia_dt, bia_tag}), eng);
245 }
246
247 primitive_attr attr;
248 memory scales_m, zero_points_src_m, zero_points_weights_m,
249 zero_points_dst_m;
250 create_attr(p, attr, zero_points_src_m, zero_points_weights_m,
251 zero_points_dst_m, eng);
252
253 auto matmul_pd = matmul::primitive_desc(
254 eng, src_md, weights_md, bia_md, dst_md, attr);
255
256 ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
257 == matmul_pd.src_desc());
258 ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS)
259 == matmul_pd.weights_desc());
260 ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_BIAS)
261 == matmul_pd.bias_desc());
262 ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
263 == matmul_pd.dst_desc());
264
265 EXPECT_ANY_THROW(matmul(matmul_pd, {}));
266 auto matmul_p = matmul(matmul_pd);
267
268 auto src_m = test::make_memory(init_md(p.base.src, true), eng);
269 auto weights_m = test::make_memory(init_md(p.base.weights, true), eng);
270 auto dst_m = test::make_memory(init_md(p.base.dst, true), eng);
271
272 // Initialize memory to make sanitizers happy
273 auto set_to_zero = [](memory &m) {
274 if (m) {
275 auto p = map_memory<char>(m);
276 auto size = m.get_desc().get_size();
277 if (size > 0) {
278 GTEST_EXPECT_NE(p, nullptr);
279 memset(p, 0, size);
280 }
281 }
282 };
283 set_to_zero(src_m);
284 set_to_zero(weights_m);
285 set_to_zero(dst_m);
286 set_to_zero(bia_m);
287
288 matmul_p.execute(strm,
289 {
290 {DNNL_ARG_SRC, src_m},
291 {DNNL_ARG_WEIGHTS, weights_m},
292 {DNNL_ARG_BIAS, bia_m},
293 {DNNL_ARG_DST, dst_m},
294 {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC,
295 zero_points_src_m},
296 {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS,
297 zero_points_weights_m},
298 {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST,
299 zero_points_dst_m},
300 });
301 strm.wait();
302 }
303};
304
// Fixture for the attribute/impl-dispatch test below. Tuple parameters:
// {tensor dims, binary post-op dims, format tag, post-op data type, ndims}.
struct attr_test_t
    : public ::testing::TestWithParam<std::tuple<memory::dims, memory::dims,
              memory::format_tag, memory::data_type, int>> {};
308
309HANDLE_EXCEPTIONS_FOR_TEST_P(
310 attr_test_t, TestMatmulShouldCallSameImplementationWithAttributes) {
311 auto engine_kind = get_test_engine_kind();
312 SKIP_IF(!DNNL_X64 || engine_kind != engine::kind::cpu,
313 "Binary impl_info_str should be same only on x64 CPU");
314 engine e {engine_kind, 0};
315
316 const auto &tensor_dims = std::get<0>(GetParam());
317 const auto format_tag = std::get<2>(GetParam());
318 const auto &binary_po_mem_dt = std::get<3>(GetParam());
319
320 SKIP_IF(unsupported_data_type(binary_po_mem_dt),
321 "Engine does not support this data type.");
322
323 // Currently, f16 binary post-ops are only supported for f16 primitive.
324 const auto src_dt = binary_po_mem_dt == memory::data_type::f16
325 ? memory::data_type::f16
326 : memory::data_type::u8;
327 const auto wei_dt = binary_po_mem_dt == memory::data_type::f16
328 ? memory::data_type::f16
329 : memory::data_type::s8;
330 const auto dst_dt = binary_po_mem_dt == memory::data_type::f16
331 ? memory::data_type::f16
332 : memory::data_type::s8;
333
334 auto src_md = memory::desc(tensor_dims, src_dt, format_tag);
335 auto weights_md = memory::desc(tensor_dims, wei_dt, format_tag);
336 auto dst_md = memory::desc(tensor_dims, dst_dt, format_tag);
337 auto bia_md = memory::desc();
338
339 std::string impl_info_no_postops;
340 auto matmul_pd
341 = matmul::primitive_desc(e, src_md, weights_md, bia_md, dst_md);
342 ASSERT_NO_THROW(impl_info_no_postops = matmul_pd.impl_info_str(););
343
344 dnnl::primitive_attr attr;
345 const float alpha = 1.f;
346 const float beta = 1.f;
347
348 dnnl::post_ops ops;
349 ops.append_sum(1.0);
350 ops.append_eltwise(algorithm::eltwise_relu, alpha, beta);
351
352 const auto &binary_po_tensor_dims = std::get<1>(GetParam());
353 memory::desc src1_po_md(
354 binary_po_tensor_dims, binary_po_mem_dt, format_tag);
355 ops.append_binary(algorithm::binary_add, src1_po_md);
356
357 attr.set_post_ops(ops);
358
359 std::string impl_info_with_postops;
360
361 matmul_pd = matmul::primitive_desc(
362 e, src_md, weights_md, bia_md, dst_md, attr);
363 ASSERT_NO_THROW(impl_info_with_postops = matmul_pd.impl_info_str(););
364 ASSERT_EQ(impl_info_no_postops, impl_info_with_postops);
365}
366
367/********************************* TEST CASES *********************************/
368
// Short aliases for the instantiations below.
using iface = matmul_iface_test_t;

using data_type = memory::data_type;

// Body is intentionally empty: the whole test runs from SetUp()/Test().
TEST_P(iface, TestsMatMul) {}
374
// Negative ("expected failure") cases: for each entry, primitive descriptor
// creation must fail with the recorded status.
static auto cases_ef = []() {
    std::vector<matmul_test_params_t> cases;

    // inconsistent dims
    cases.push_back(
            {{{{10, 1}, data_type::f32, tag::ab},
                     {{2, 20}, data_type::f32, tag::ab},
                     {{10, 20}, data_type::f32, tag::ab}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 20}, data_type::f32, tag::ab},
                             {{10, 21}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 1, 20}, data_type::f32, tag::abc},
                             {{10, 20}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{1, 10, 1}, data_type::u8, tag::abc},
                             {{1, 1, 2}, data_type::s8, tag::abc},
                             {{1, 11, 2}, data_type::s8, tag::abc}},
            {}, true, dnnl_invalid_arguments});

    // inconsistent wrt runtime dim vals
    cases.push_back(
            {{{{3, 10, 10}, data_type::f32, tag::abc},
                     {{DNNL_RUNTIME_DIM_VAL, 10, 10}, data_type::f32, tag::abc},
                     {{DNNL_RUNTIME_DIM_VAL, 10, 10}, data_type::f32,
                             tag::abc}},
                    {}, true, dnnl_invalid_arguments});

    // inconsistent wrt broadcasting
    cases.push_back({{{{3, 10, 10}, data_type::f32, tag::abc},
                             {{1, 10, 10}, data_type::f32, tag::abc},
                             {{1, 10, 10}, data_type::f32, tag::abc}},
            {}, true, dnnl_invalid_arguments});

    // no broadcasting on m/k/n dims
    cases.push_back({{{{10, 10}, data_type::f32, tag::ab},
                             {{1, 1}, data_type::f32, tag::ab},
                             {{10, 10}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});

    // f32 data and zero-points
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 20}, data_type::f32, tag::ab},
                             {{10, 20}, data_type::f32, tag::ab}},
            {P::NONE, {P::ZERO_POINTS | P::SRC | P::COMMON}}, true,
            dnnl_unimplemented});

    // bf16 data and zero-points
    cases.push_back({{{{10, 1}, data_type::bf16, tag::ab},
                             {{1, 20}, data_type::bf16, tag::ab},
                             {{10, 20}, data_type::bf16, tag::ab}},
            {P::NONE, {P::ZERO_POINTS | P::SRC | P::COMMON}}, true,
            dnnl_unimplemented});
    // unimplemented data types
    if (get_test_engine_kind() == engine::kind::cpu) {
        cases.push_back(
                {{{{10, 1}, data_type::f32, tag::ab},
                         {{1, 20}, data_type::f32, tag::ab},
                         {{10, 20}, data_type::f32, tag::ab}, data_type::u8},
                        {}, true, dnnl_unimplemented});
    }
    // XXX: disable assert in type_helpers.hpp: default_accum_data_type(...)
    // cases.push_back({{{{10, 1}, data_type::u8, tag::ab}, {{1, 20},
    // data_type::u8, tag::ab},
    //         {{10, 20}, data_type::u8, tag::ab}},
    //         {}, true, dnnl_unimplemented});

    // unimplemented formats (GPU only)
    cases.push_back({{{{16, 16}, data_type::f32, tag::AB8a4b},
                             {{16, 16}, data_type::f32, tag::AB8a4b},
                             {{16, 16}, data_type::f32, tag::AB8a4b}},
            {}, true, dnnl_unimplemented});

    // broken broadcast case
    cases.push_back(
            {{{{1, 10, 2}, data_type::f32, tag::abc},
                     {{1, 2, 20}, data_type::f32, tag::abc},
                     {{0, 10, 20}, data_type::f32, tag::abc}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});

    // broken broadcast case
    cases.push_back(
            {{{{0, 10, 2}, data_type::f32, tag::abc},
                     {{2, 2, 20}, data_type::f32, tag::abc},
                     {{0, 10, 20}, data_type::f32, tag::abc}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});

    // broken broadcast case
    cases.push_back(
            {{{{1, 10, 2}, data_type::f32, tag::abc},
                     {{0, 2, 20}, data_type::f32, tag::abc},
                     {{1, 10, 20}, data_type::f32, tag::abc}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(EF, iface, cases_ef());
474
// Zero-dimension cases: M, K, N, or the batch dim is 0. All entries are
// expected to succeed (no expect_to_fail flag set).
static auto cases_zd = [](memory::data_type dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case M=0
    cases.push_back({{{{0, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{0, 20}, dt, tag::ab}, data_type::undef},
            {}});
    // simple case K=0
    cases.push_back({{{{10, 0}, dt, tag::ab}, {{0, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {}});
    // simple case N=0
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 0}, dt, tag::ab},
                             {{10, 0}, dt, tag::ab}, data_type::undef},
            {}});
    // broadcast case all MB=0
    cases.push_back({{{{0, 10, 2}, dt, tag::abc}, {{0, 2, 20}, dt, tag::abc},
                             {{0, 10, 20}, dt, tag::abc}, data_type::undef},
            {}});
    // broadcast case wei MB!=0
    cases.push_back({{{{0, 10, 2}, dt, tag::abc}, {{1, 2, 20}, dt, tag::abc},
                             {{0, 10, 20}, dt, tag::abc}, data_type::undef},
            {}});
    // broadcast case src MB!=0
    cases.push_back({{{{1, 10, 2}, dt, tag::abc}, {{0, 2, 20}, dt, tag::abc},
                             {{0, 10, 20}, dt, tag::abc}, data_type::undef},
            {}});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(ZeroDim_f32, iface, cases_zd(data_type::f32));
506
// Positive cases for floating-point data types (f32/bf16/f16): plain matmul,
// custom leading dimensions, runtime dims, and post-op combinations.
static auto cases_f = [](memory::data_type dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {}});
    // simple case + leading dimensions
    cases.push_back({{{{10, 1}, dt, tag::ab, P::SRC | P::LEADING_DIM},
                             {{1, 3}, dt, tag::ba},
                             {{10, 3}, dt, tag::ab, P::DST | P::LEADING_DIM},
                             data_type::f32},
            {}});
    // simple case + leading dimensions + runtime dims
    cases.push_back(
            {{{{1, 10}, dt, tag::ab, P::SRC | P::LEADING_DIM | P::RUNTIME},
                     {{10, 2}, dt, tag::ba, P::WEIGHTS | P::RUNTIME},
                     {{1, 2}, dt, tag::ab,
                             P::DST | P::LEADING_DIM | P::RUNTIME},
                     data_type::f32},
                    {}});

    // post-ops
    cases.push_back({{{{10, 1}, dt, tag::ab}, {{1, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}},
            {P::NONE, {},
                    {{primitive::kind::eltwise, algorithm::eltwise_relu}}}});
    // multiple post-ops
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}},
            {P::SCALES | P::COMMON, {},
                    {{primitive::kind::sum},
                            {primitive::kind::eltwise,
                                    algorithm::eltwise_relu}}}});

    // gemm like: output scale + post-ops(sum)
    cases.push_back({{{{10, 1}, dt, tag::ab}, {{1, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::f32},
            {P::SCALES | P::COMMON, {}, {{primitive::kind::sum}}}});
    // gemm like: output scale + post-ops(sum) + all runtime
    cases.push_back({{{{10, 1}, dt, tag::ab, P::SRC | P::RUNTIME},
                             {{1, 20}, dt, tag::ab, P::WEIGHTS | P::RUNTIME},
                             {{10, 20}, dt, tag::ab, P::DST | P::RUNTIME},
                             data_type::f32},
            {P::SCALES | P::COMMON, {}, {{primitive::kind::sum}}}});

    return ::testing::ValuesIn(cases);
};

INSTANTIATE_TEST_SUITE_P(Generic_f16, iface, cases_f(data_type::f16));
INSTANTIATE_TEST_SUITE_P(Generic_bf16, iface, cases_f(data_type::bf16));
INSTANTIATE_TEST_SUITE_P(Generic_f32, iface, cases_f(data_type::f32));
559
// Positive cases for int8 (s8/u8 src, s8 weights): leading dimensions,
// runtime dims, zero points (common and per-N), scales, and post-ops.
static auto cases_x8 = [](memory::data_type src_dt, memory::data_type dst_dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::undef},
                    {}});
    // simple case + leading dimensions
    cases.push_back(
            {{{{10, 1}, src_dt, tag::ba, P::SRC | P::LEADING_DIM},
                     {{1, 3}, data_type::s8, tag::ba},
                     {{10, 3}, dst_dt, tag::ab, P::DST | P::LEADING_DIM},
                     data_type::s8},
                    {}});
    // simple case + leading dimensions + runtime dims
    cases.push_back(
            {{{{1, 10}, src_dt, tag::ba, P::SRC | P::LEADING_DIM | P::RUNTIME},
                     {{10, 2}, data_type::s8, tag::ba, P::WEIGHTS | P::RUNTIME},
                     {{1, 2}, dst_dt, tag::ab,
                             P::DST | P::LEADING_DIM | P::RUNTIME},
                     data_type::u8},
                    {}});

    // zero points
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON,
                            {P::ZERO_POINTS | P::SRC | P::COMMON,
                                    P::ZERO_POINTS | P::WEIGHTS | P::COMMON,
                                    P::ZERO_POINTS | P::DST | P::COMMON}}});

    // zero points + runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON | P::RUNTIME,
                            {P::ZERO_POINTS | P::SRC | P::COMMON, P::NONE,
                                    P::ZERO_POINTS | P::DST | P::COMMON}}});

    // per_dim_1 zero points + runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON | P::RUNTIME,
                            {P::ZERO_POINTS | P::SRC | P::PER_N, P::NONE,
                                    P::ZERO_POINTS | P::DST | P::PER_N}}});
    // post-ops
    cases.push_back({{{{10, 1}, src_dt, tag::ab},
                             {{1, 20}, data_type::s8, tag::ab},
                             {{10, 20}, dst_dt, tag::ab}},
            {P::NONE, {},
                    {{primitive::kind::eltwise, algorithm::eltwise_relu}}}});
    // multiple post-ops
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ab}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON, {},
                            {{primitive::kind::sum},
                                    {primitive::kind::eltwise,
                                            algorithm::eltwise_relu}}}});

    // igemm like: output scale + post-ops(sum)
    cases.push_back(
            {{{{10, 1}, src_dt, tag::ab}, {{1, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::s8},
                    {P::SCALES | P::COMMON,
                            {P::ZERO_POINTS | P::SRC | P::COMMON, P::NONE,
                                    P::ZERO_POINTS | P::DST | P::COMMON},
                            {{primitive::kind::sum}}}});
    // igemm like: output scale + post-ops(sum) + all runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ba},
                     {{10, 20}, dst_dt, tag::ab}, data_type::s8},
                    {P::SCALES | P::PER_N,
                            {P::ZERO_POINTS | P::SRC | P::COMMON,
                                    P::ZERO_POINTS | P::WEIGHTS | P::COMMON,
                                    P::ZERO_POINTS | P::DST | P::COMMON},
                            {{primitive::kind::sum}}}});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(
        Generic_s8s8s32, iface, cases_x8(data_type::s8, data_type::s32));
INSTANTIATE_TEST_SUITE_P(
        Generic_u8s8u8, iface, cases_x8(data_type::u8, data_type::u8));
647
// Instantiation for attr_test_t: each tuple is checked for identical impl
// dispatch with and without post-ops.
INSTANTIATE_TEST_SUITE_P(TensorDims, attr_test_t,
        ::testing::Values(
                // {{src0, src1, dst same_dim}, { binary post-op dim }},
                // format_tag, post-op data type, ndims
                std::make_tuple(memory::dims {3, 2, 16, 16},
                        memory::dims {3, 1, 16, 16}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {9, 9, 64, 64},
                        memory::dims {9, 1, 64, 64}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {3, 2, 16, 16},
                        memory::dims {3, 2, 16, 16}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {2, 10, 10, 10},
                        memory::dims {2, 10, 10, 10}, tag::abcd,
                        memory::data_type::bf16, 4),
                std::make_tuple(memory::dims {2, 10, 10, 10},
                        memory::dims {2, 10, 10, 10}, tag::abcd,
                        memory::data_type::f16, 4)));
667
668} // namespace dnnl
669