1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
* You may obtain a copy of the License at
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | #include <vector> |
23 | |
24 | namespace dnnl { |
25 | |
namespace P {
// Bit-flag vocabulary used by the test tables below to describe attributes
// attached to a matmul case. Flags are or-ed together, e.g.
// `P::ZERO_POINTS | P::SRC | P::COMMON`. Declared constexpr: these are
// compile-time constants and must never be written to at run time.

// Common
constexpr unsigned NONE = 0u;

// The quantity is passed at execution time (runtime dims / runtime scales).
constexpr unsigned RUNTIME = 1u << 31;

// Attribute category selectors.
constexpr unsigned SCALES = 1u << 30;
constexpr unsigned ZERO_POINTS = 1u << 29;

// Request non-dense strides (padded leading dimension) for the matrix.
constexpr unsigned LEADING_DIM = 1u << 28;

// matrices indices: 1 .. 7
// bits reserved: 20 .. 22
constexpr unsigned MATRIX_MASK = 7u << 20;
constexpr unsigned SRC = 1u << 20;
constexpr unsigned WEIGHTS = 2u << 20;
constexpr unsigned DST = 3u << 20;

// scales and zero points: 1 .. 3
// bits reserved: 0 .. 1
constexpr unsigned MASK_MASK = 3u << 0;
constexpr unsigned COMMON = 1u << 0;
constexpr unsigned PER_N = 1u << 1;
} // namespace P
50 | |
// Plain description of one matmul problem: a memory-descriptor recipe for
// each of the three matrices plus the bias data type
// (data_type::undef == no bias).
struct matmul_base_t {
    struct md_t {
        memory::dims dims; // logical dimensions
        memory::data_type dt; // data type
        memory::format_tag tag; // plain or blocked layout tag
        unsigned flags; // or-ed P:: flags (matrix id, RUNTIME, LEADING_DIM)
    } src, weights, dst;
    memory::data_type bia_dt; // bias data type; undef means no bias
};
60 | |
// TODO: a way to generalize?
// Attribute description for one test case: scales, per-argument zero points,
// and a list of post-ops.
struct matmul_attr_t {
    // ctor {P::SCALES, {P::SRC, P::WEIGHTS, P::DST}, {P::POST_OPS, ...}}

    unsigned scale_flags; // P::SCALES | mask bits (optionally | P::RUNTIME)

    // Per-argument zero-point flags (P::ZERO_POINTS | matrix id | mask);
    // P::NONE when the argument carries no zero points.
    struct zero_points_t {
        unsigned src, weights, dst;
    } zero_points;

    // One entry per attached post-op; `alg` is only consumed for eltwise
    // (see create_attr()), sum ignores it.
    struct post_op_t {
        primitive::kind kind;
        algorithm alg;
    };

    std::vector<post_op_t> post_ops;
};
78 | |
// Full parameterization of one test case: the problem, its attributes, and
// the expected outcome (consumed by catch_expected_failures()).
struct matmul_test_params_t {
    matmul_base_t base;
    matmul_attr_t attr;

    bool expect_to_fail; // true when primitive creation must fail
    dnnl_status_t expected_status; // status expected when it fails
};

// Shorthand for format tags used throughout the case tables.
using tag = memory::format_tag;
88 | |
// Exercises the matmul primitive *interface*: descriptor creation, attribute
// wiring (zero points, post-ops), query-API consistency, and execution with
// runtime dims. It does not validate numerical results.
class matmul_iface_test_t
    : public ::testing::TestWithParam<matmul_test_params_t> {
protected:
    void SetUp() override {
        matmul_test_params_t p
                = ::testing::TestWithParam<decltype(p)>::GetParam();

        // Skip (rather than fail) configurations the engine cannot support.
        SKIP_IF(unsupported_data_type(p.base.src.dt),
                "Engine does not support this data type.");
        SKIP_IF(unsupported_data_type(p.base.weights.dt),
                "Engine does not support this data type.");
        SKIP_IF(unsupported_data_type(p.base.dst.dt),
                "Engine does not support this data type.");
        SKIP_IF(unsupported_data_type(p.base.bia_dt),
                "Engine does not support this data type.");
        SKIP_IF(get_test_engine_kind() == engine::kind::gpu
                        && ((p.attr.zero_points.src & P::PER_N)
                                || (p.attr.zero_points.dst & P::PER_N)),
                "Per dimensional zero points are not supported on GPU");
        SKIP_IF(get_test_engine_kind() == engine::kind::cpu
                        && p.base.src.tag == impl::format_tag::AB8a4b,
                "Don't test blocked formats on CPU");

        SKIP_IF_CUDA((p.attr.zero_points.src != 0 || p.attr.zero_points.dst != 0
                             || p.attr.zero_points.weights != 0),
                "Zero points not supported for CUDA");

        SKIP_IF_CUDA((p.attr.scale_flags & P::MASK_MASK) == P::PER_N,
                "Per dimensional scaling is not supported for CUDA");

        // Runs Test() and verifies expect_to_fail/expected_status.
        catch_expected_failures(
                [=]() { Test(); }, p.expect_to_fail, p.expected_status, false);
    }

    // Builds a memory::desc from `desc`. When P::RUNTIME is set, all dims
    // become DNNL_RUNTIME_DIM_VAL; use `force_no_rt = true` when creating the
    // final memory objects, which need concrete dims. When P::LEADING_DIM is
    // set (and dims are not runtime), non-dense strides with a padded leading
    // dimension are produced to exercise custom-strided descriptors.
    static memory::desc init_md(
            const matmul_base_t::md_t &desc, bool force_no_rt = false) {
        const bool runtime = force_no_rt ? false : (desc.flags & P::RUNTIME);
        const bool use_ld = (desc.flags & P::LEADING_DIM);

        memory::dims dims = desc.dims;
        if (runtime)
            dims = memory::dims(desc.dims.size(), DNNL_RUNTIME_DIM_VAL);

        if (runtime || use_ld == false)
            return memory::desc(dims, desc.dt, desc.tag);

        // Pad the innermost-but-one stride by one element.
        memory::dims strides;
        switch (desc.tag) {
            case tag::ab: strides = {dims[1] + 1, 1}; break;
            case tag::ba: strides = {1, dims[0] + 1}; break;
            case tag::abc:
                strides = {dims[1] * (dims[2] + 1) + 1, dims[2] + 1, 1};
                break;
            case tag::acb:
                // NOTE(review): strides are identical to the abc case above;
                // a genuine acb layout would order strides differently.
                // Presumably only the padding matters here -- verify.
                strides = {dims[1] * (dims[2] + 1) + 1, dims[2] + 1, 1};
                break;
            default:
                throw std::invalid_argument("tag doesn't support custom ld");
        }

        return memory::desc(dims, desc.dt, strides);
    }

    // Translates p.attr into a primitive_attr: sets zero-point masks, fills
    // the corresponding memory objects with small deterministic values, and
    // attaches the requested post-ops.
    // NOTE(review): p.attr.scale_flags is never mapped onto `attr` here, so
    // SCALES cases only exercise flag plumbing -- confirm this is intended.
    static void create_attr(const matmul_test_params_t &p, primitive_attr &attr,
            memory &zero_points_src_m, memory &zero_points_weights_m,
            memory &zero_points_dst_m, engine &eng) {
        const int ndims = (int)p.base.dst.dims.size();

        // zero points
        auto handle_zero_points = [&](int arg, unsigned flags,
                                          const matmul_base_t::md_t &md,
                                          memory &zero_points_m) {
            if (flags == P::NONE) return;

            ASSERT_TRUE(flags & P::ZERO_POINTS);
            ASSERT_TRUE(flags & P::MATRIX_MASK);

            // sanity check: the flags must name the same matrix as `arg`
            switch (arg) {
                case DNNL_ARG_SRC:
                    ASSERT_TRUE((flags & P::MATRIX_MASK) == P::SRC);
                    break;
                case DNNL_ARG_WEIGHTS:
                    ASSERT_TRUE((flags & P::MATRIX_MASK) == P::WEIGHTS);
                    break;
                case DNNL_ARG_DST:
                    ASSERT_TRUE((flags & P::MATRIX_MASK) == P::DST);
                    break;
                default: ASSERT_TRUE(!"unreachable");
            }

            // PER_N -> mask selects the last (N) dimension; COMMON -> mask 0.
            unsigned zero_points_mask = flags & P::MASK_MASK;
            ASSERT_TRUE(zero_points_mask == P::COMMON
                    || zero_points_mask == P::PER_N);
            int mask = zero_points_mask == P::PER_N ? 1 << (ndims - 1) : 0;
            memory::dim zero_points_size = mask ? md.dims[ndims - 1] : 1;

            attr.set_zero_points_mask(arg, mask);
            // s32 buffer holding the zero-point values for this argument.
            zero_points_m = test::make_memory(
                    {{zero_points_size}, memory::data_type::s32, {1}}, eng);
            auto z = map_memory<int32_t>(zero_points_m);
            GTEST_EXPECT_NE(z, nullptr);
            // Deterministic small value derived from the argument id.
            for (memory::dim i = 0; i < zero_points_size; ++i)
                z[i] = (arg % 7) - 3;
        };

        handle_zero_points(DNNL_ARG_SRC, p.attr.zero_points.src, p.base.src,
                zero_points_src_m);
        handle_zero_points(DNNL_ARG_WEIGHTS, p.attr.zero_points.weights,
                p.base.weights, zero_points_weights_m);
        handle_zero_points(DNNL_ARG_DST, p.attr.zero_points.dst, p.base.dst,
                zero_points_dst_m);

        // post ops
        post_ops po;
        for (auto post_op : p.attr.post_ops) {
            switch (post_op.kind) {
                case primitive::kind::sum: po.append_sum(); break;
                case primitive::kind::eltwise:
                    po.append_eltwise(post_op.alg, 0.f, 0.f);
                    break;
                default: ASSERT_TRUE(!"unknown post op kind");
            }
        }
        attr.set_post_ops(po);
    }

    // The scenario proper: build mds (honoring runtime/leading-dim flags),
    // create the primitive descriptor with attributes, verify the
    // exec_arg_md queries agree with the dedicated queries, then execute
    // once on zero-initialized memory.
    void Test() {
        matmul_test_params_t p
                = ::testing::TestWithParam<matmul_test_params_t>::GetParam();

        auto eng = get_test_engine();
        auto strm = make_stream(eng);

        // If per-matrix flags are present, they must name the right matrix.
        auto check_matrix_flags = [](unsigned flags, unsigned matrix) {
            if (flags) { ASSERT_EQ(flags & P::MATRIX_MASK, matrix); }
        };
        check_matrix_flags(p.base.src.flags, P::SRC);
        check_matrix_flags(p.base.weights.flags, P::WEIGHTS);
        check_matrix_flags(p.base.dst.flags, P::DST);

        auto src_md = init_md(p.base.src);
        auto weights_md = init_md(p.base.weights);
        auto dst_md = init_md(p.base.dst);

        // Bias broadcast shape: all-ones except the last (N) dim; it
        // inherits the dst RUNTIME flag.
        auto bia_md = memory::desc();
        memory bia_m;
        if (p.base.bia_dt != memory::data_type::undef) {
            memory::dims bia_dims(p.base.dst.dims.size() - 1, 1);
            bia_dims.push_back(p.base.dst.dims.back());
            tag bia_tag = bia_dims.size() == 2 ? tag::ab : tag::abc;
            bia_md = init_md({bia_dims, p.base.bia_dt, bia_tag,
                    p.base.dst.flags & P::RUNTIME});
            bia_m = test::make_memory(
                    init_md({bia_dims, p.base.bia_dt, bia_tag}), eng);
        }

        primitive_attr attr;
        // NOTE(review): scales_m is declared but never filled or passed to
        // create_attr() -- presumably a leftover; verify.
        memory scales_m, zero_points_src_m, zero_points_weights_m,
                zero_points_dst_m;
        create_attr(p, attr, zero_points_src_m, zero_points_weights_m,
                zero_points_dst_m, eng);

        auto matmul_pd = matmul::primitive_desc(
                eng, src_md, weights_md, bia_md, dst_md, attr);

        // Query-API consistency: exec_arg_md must match the dedicated
        // per-argument descriptor queries.
        ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_SRC)
                == matmul_pd.src_desc());
        ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS)
                == matmul_pd.weights_desc());
        ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_BIAS)
                == matmul_pd.bias_desc());
        ASSERT_TRUE(matmul_pd.query_md(query::exec_arg_md, DNNL_ARG_DST)
                == matmul_pd.dst_desc());

        // Constructing a primitive from an empty cache blob must throw.
        EXPECT_ANY_THROW(matmul(matmul_pd, {}));
        auto matmul_p = matmul(matmul_pd);

        // use `force_no_rt = true` when creating the final memory objects
        auto src_m = test::make_memory(init_md(p.base.src, true), eng);
        auto weights_m = test::make_memory(init_md(p.base.weights, true), eng);
        auto dst_m = test::make_memory(init_md(p.base.dst, true), eng);

        // Initialize memory to make sanitizers happy
        auto set_to_zero = [](memory &m) {
            if (m) {
                auto p = map_memory<char>(m);
                auto size = m.get_desc().get_size();
                if (size > 0) {
                    GTEST_EXPECT_NE(p, nullptr);
                    memset(p, 0, size);
                }
            }
        };
        set_to_zero(src_m);
        set_to_zero(weights_m);
        set_to_zero(dst_m);
        set_to_zero(bia_m);

        matmul_p.execute(strm,
                {
                        {DNNL_ARG_SRC, src_m},
                        {DNNL_ARG_WEIGHTS, weights_m},
                        {DNNL_ARG_BIAS, bia_m},
                        {DNNL_ARG_DST, dst_m},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC,
                                zero_points_src_m},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS,
                                zero_points_weights_m},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST,
                                zero_points_dst_m},
                });
        strm.wait();
    }
};
304 | |
// Parameterized over {tensor dims, binary post-op dims, format tag,
// binary post-op data type, ndims}.
struct attr_test_t
    : public ::testing::TestWithParam<std::tuple<memory::dims, memory::dims,
              memory::format_tag, memory::data_type, int>> {};
308 | |
309 | HANDLE_EXCEPTIONS_FOR_TEST_P( |
310 | attr_test_t, TestMatmulShouldCallSameImplementationWithAttributes) { |
311 | auto engine_kind = get_test_engine_kind(); |
312 | SKIP_IF(!DNNL_X64 || engine_kind != engine::kind::cpu, |
313 | "Binary impl_info_str should be same only on x64 CPU" ); |
314 | engine e {engine_kind, 0}; |
315 | |
316 | const auto &tensor_dims = std::get<0>(GetParam()); |
317 | const auto format_tag = std::get<2>(GetParam()); |
318 | const auto &binary_po_mem_dt = std::get<3>(GetParam()); |
319 | |
320 | SKIP_IF(unsupported_data_type(binary_po_mem_dt), |
321 | "Engine does not support this data type." ); |
322 | |
323 | // Currently, f16 binary post-ops are only supported for f16 primitive. |
324 | const auto src_dt = binary_po_mem_dt == memory::data_type::f16 |
325 | ? memory::data_type::f16 |
326 | : memory::data_type::u8; |
327 | const auto wei_dt = binary_po_mem_dt == memory::data_type::f16 |
328 | ? memory::data_type::f16 |
329 | : memory::data_type::s8; |
330 | const auto dst_dt = binary_po_mem_dt == memory::data_type::f16 |
331 | ? memory::data_type::f16 |
332 | : memory::data_type::s8; |
333 | |
334 | auto src_md = memory::desc(tensor_dims, src_dt, format_tag); |
335 | auto weights_md = memory::desc(tensor_dims, wei_dt, format_tag); |
336 | auto dst_md = memory::desc(tensor_dims, dst_dt, format_tag); |
337 | auto bia_md = memory::desc(); |
338 | |
339 | std::string impl_info_no_postops; |
340 | auto matmul_pd |
341 | = matmul::primitive_desc(e, src_md, weights_md, bia_md, dst_md); |
342 | ASSERT_NO_THROW(impl_info_no_postops = matmul_pd.impl_info_str();); |
343 | |
344 | dnnl::primitive_attr attr; |
345 | const float alpha = 1.f; |
346 | const float beta = 1.f; |
347 | |
348 | dnnl::post_ops ops; |
349 | ops.append_sum(1.0); |
350 | ops.append_eltwise(algorithm::eltwise_relu, alpha, beta); |
351 | |
352 | const auto &binary_po_tensor_dims = std::get<1>(GetParam()); |
353 | memory::desc src1_po_md( |
354 | binary_po_tensor_dims, binary_po_mem_dt, format_tag); |
355 | ops.append_binary(algorithm::binary_add, src1_po_md); |
356 | |
357 | attr.set_post_ops(ops); |
358 | |
359 | std::string impl_info_with_postops; |
360 | |
361 | matmul_pd = matmul::primitive_desc( |
362 | e, src_md, weights_md, bia_md, dst_md, attr); |
363 | ASSERT_NO_THROW(impl_info_with_postops = matmul_pd.impl_info_str();); |
364 | ASSERT_EQ(impl_info_no_postops, impl_info_with_postops); |
365 | } |
366 | |
367 | /********************************* TEST CASES *********************************/ |
368 | |
// Short aliases used by the instantiations below.
using iface = matmul_iface_test_t;

using data_type = memory::data_type;

// Empty body on purpose: the scenario runs from SetUp()/Test().
TEST_P(iface, TestsMatMul) {}
374 | |
// Error-flow cases: every entry is expected to fail primitive-descriptor
// creation with the listed status.
static auto cases_ef = []() {
    std::vector<matmul_test_params_t> cases;

    // inconsistent dims
    cases.push_back(
            {{{{10, 1}, data_type::f32, tag::ab},
                     {{2, 20}, data_type::f32, tag::ab},
                     {{10, 20}, data_type::f32, tag::ab}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 20}, data_type::f32, tag::ab},
                             {{10, 21}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 1, 20}, data_type::f32, tag::abc},
                             {{10, 20}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});
    cases.push_back({{{{1, 10, 1}, data_type::u8, tag::abc},
                             {{1, 1, 2}, data_type::s8, tag::abc},
                             {{1, 11, 2}, data_type::s8, tag::abc}},
            {}, true, dnnl_invalid_arguments});

    // inconsistent wrt runtime dim vals
    cases.push_back(
            {{{{3, 10, 10}, data_type::f32, tag::abc},
                     {{DNNL_RUNTIME_DIM_VAL, 10, 10}, data_type::f32, tag::abc},
                     {{DNNL_RUNTIME_DIM_VAL, 10, 10}, data_type::f32,
                             tag::abc}},
                    {}, true, dnnl_invalid_arguments});

    // inconsistent wrt broadcasting
    cases.push_back({{{{3, 10, 10}, data_type::f32, tag::abc},
                             {{1, 10, 10}, data_type::f32, tag::abc},
                             {{1, 10, 10}, data_type::f32, tag::abc}},
            {}, true, dnnl_invalid_arguments});

    // no broadcasting on m/k/n dims
    cases.push_back({{{{10, 10}, data_type::f32, tag::ab},
                             {{1, 1}, data_type::f32, tag::ab},
                             {{10, 10}, data_type::f32, tag::ab}},
            {}, true, dnnl_invalid_arguments});

    // f32 data and zero-points
    cases.push_back({{{{10, 1}, data_type::f32, tag::ab},
                             {{1, 20}, data_type::f32, tag::ab},
                             {{10, 20}, data_type::f32, tag::ab}},
            {P::NONE, {P::ZERO_POINTS | P::SRC | P::COMMON}}, true,
            dnnl_unimplemented});

    // bf16 data and zero-points
    cases.push_back({{{{10, 1}, data_type::bf16, tag::ab},
                             {{1, 20}, data_type::bf16, tag::ab},
                             {{10, 20}, data_type::bf16, tag::ab}},
            {P::NONE, {P::ZERO_POINTS | P::SRC | P::COMMON}}, true,
            dnnl_unimplemented});
    // unimplemented data types
    if (get_test_engine_kind() == engine::kind::cpu) {
        cases.push_back(
                {{{{10, 1}, data_type::f32, tag::ab},
                         {{1, 20}, data_type::f32, tag::ab},
                         {{10, 20}, data_type::f32, tag::ab}, data_type::u8},
                        {}, true, dnnl_unimplemented});
    }
    // XXX: disable assert in type_helpers.hpp: default_accum_data_type(...)
    // cases.push_back({{{{10, 1}, data_type::u8, tag::ab}, {{1, 20},
    // data_type::u8, tag::ab},
    //         {{10, 20}, data_type::u8, tag::ab}},
    //         {}, true, dnnl_unimplemented});

    // unimplemented formats (GPU only)
    cases.push_back({{{{16, 16}, data_type::f32, tag::AB8a4b},
                             {{16, 16}, data_type::f32, tag::AB8a4b},
                             {{16, 16}, data_type::f32, tag::AB8a4b}},
            {}, true, dnnl_unimplemented});

    // broken broadcast case
    cases.push_back(
            {{{{1, 10, 2}, data_type::f32, tag::abc},
                     {{1, 2, 20}, data_type::f32, tag::abc},
                     {{0, 10, 20}, data_type::f32, tag::abc}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});

    // broken broadcast case
    cases.push_back(
            {{{{0, 10, 2}, data_type::f32, tag::abc},
                     {{2, 2, 20}, data_type::f32, tag::abc},
                     {{0, 10, 20}, data_type::f32, tag::abc}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});

    // broken broadcast case
    cases.push_back(
            {{{{1, 10, 2}, data_type::f32, tag::abc},
                     {{0, 2, 20}, data_type::f32, tag::abc},
                     {{1, 10, 20}, data_type::f32, tag::abc}, data_type::undef},
                    {}, true, dnnl_invalid_arguments});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(EF, iface, cases_ef());
474 | |
// Zero-dimension cases: M, K, N, or batch of size 0 must be accepted and
// execute cleanly (no expected-failure flags set).
static auto cases_zd = [](memory::data_type dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case M=0
    cases.push_back({{{{0, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{0, 20}, dt, tag::ab}, data_type::undef},
            {}});
    // simple case K=0
    cases.push_back({{{{10, 0}, dt, tag::ab}, {{0, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {}});
    // simple case N=0
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 0}, dt, tag::ab},
                             {{10, 0}, dt, tag::ab}, data_type::undef},
            {}});
    // broadcast case all MB=0
    cases.push_back({{{{0, 10, 2}, dt, tag::abc}, {{0, 2, 20}, dt, tag::abc},
                             {{0, 10, 20}, dt, tag::abc}, data_type::undef},
            {}});
    // broadcast case wei MB!=0
    cases.push_back({{{{0, 10, 2}, dt, tag::abc}, {{1, 2, 20}, dt, tag::abc},
                             {{0, 10, 20}, dt, tag::abc}, data_type::undef},
            {}});
    // broadcast case src MB!=0
    cases.push_back({{{{1, 10, 2}, dt, tag::abc}, {{0, 2, 20}, dt, tag::abc},
                             {{0, 10, 20}, dt, tag::abc}, data_type::undef},
            {}});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(ZeroDim_f32, iface, cases_zd(data_type::f32));
506 | |
// Positive float-path cases for a given data type: plain, leading-dim,
// runtime-dim, and post-op configurations that must all succeed.
static auto cases_f = [](memory::data_type dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::undef},
            {}});
    // simple case + leading dimensions
    cases.push_back({{{{10, 1}, dt, tag::ab, P::SRC | P::LEADING_DIM},
                             {{1, 3}, dt, tag::ba},
                             {{10, 3}, dt, tag::ab, P::DST | P::LEADING_DIM},
                             data_type::f32},
            {}});
    // simple case + leading dimensions + runtime dims
    cases.push_back(
            {{{{1, 10}, dt, tag::ab, P::SRC | P::LEADING_DIM | P::RUNTIME},
                     {{10, 2}, dt, tag::ba, P::WEIGHTS | P::RUNTIME},
                     {{1, 2}, dt, tag::ab,
                             P::DST | P::LEADING_DIM | P::RUNTIME},
                     data_type::f32},
                    {}});

    // post-ops
    cases.push_back({{{{10, 1}, dt, tag::ab}, {{1, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}},
            {P::NONE, {},
                    {{primitive::kind::eltwise, algorithm::eltwise_relu}}}});
    // multiple post-ops
    cases.push_back({{{{10, 2}, dt, tag::ab}, {{2, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}},
            {P::SCALES | P::COMMON, {},
                    {{primitive::kind::sum},
                            {primitive::kind::eltwise,
                                    algorithm::eltwise_relu}}}});

    // gemm like: output scale + post-ops(sum)
    cases.push_back({{{{10, 1}, dt, tag::ab}, {{1, 20}, dt, tag::ab},
                             {{10, 20}, dt, tag::ab}, data_type::f32},
            {P::SCALES | P::COMMON, {}, {{primitive::kind::sum}}}});
    // gemm like: output scale + post-ops(sum) + all runtime
    cases.push_back({{{{10, 1}, dt, tag::ab, P::SRC | P::RUNTIME},
                             {{1, 20}, dt, tag::ab, P::WEIGHTS | P::RUNTIME},
                             {{10, 20}, dt, tag::ab, P::DST | P::RUNTIME},
                             data_type::f32},
            {P::SCALES | P::COMMON, {}, {{primitive::kind::sum}}}});

    return ::testing::ValuesIn(cases);
};

INSTANTIATE_TEST_SUITE_P(Generic_f16, iface, cases_f(data_type::f16));
INSTANTIATE_TEST_SUITE_P(Generic_bf16, iface, cases_f(data_type::bf16));
INSTANTIATE_TEST_SUITE_P(Generic_f32, iface, cases_f(data_type::f32));
559 | |
// Positive int8 cases: s8 weights with configurable src/dst types, covering
// leading dims, runtime dims, zero points (common and per-N), and post-ops.
static auto cases_x8 = [](memory::data_type src_dt, memory::data_type dst_dt) {
    std::vector<matmul_test_params_t> cases;

    // simple case
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::undef},
                    {}});
    // simple case + leading dimensions
    cases.push_back(
            {{{{10, 1}, src_dt, tag::ba, P::SRC | P::LEADING_DIM},
                     {{1, 3}, data_type::s8, tag::ba},
                     {{10, 3}, dst_dt, tag::ab, P::DST | P::LEADING_DIM},
                     data_type::s8},
                    {}});
    // simple case + leading dimensions + runtime dims
    cases.push_back(
            {{{{1, 10}, src_dt, tag::ba, P::SRC | P::LEADING_DIM | P::RUNTIME},
                     {{10, 2}, data_type::s8, tag::ba, P::WEIGHTS | P::RUNTIME},
                     {{1, 2}, dst_dt, tag::ab,
                             P::DST | P::LEADING_DIM | P::RUNTIME},
                     data_type::u8},
                    {}});

    // zero points
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON,
                            {P::ZERO_POINTS | P::SRC | P::COMMON,
                                    P::ZERO_POINTS | P::WEIGHTS | P::COMMON,
                                    P::ZERO_POINTS | P::DST | P::COMMON}}});

    // zero points + runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON | P::RUNTIME,
                            {P::ZERO_POINTS | P::SRC | P::COMMON, P::NONE,
                                    P::ZERO_POINTS | P::DST | P::COMMON}}});

    // per_dim_1 zero points + runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON | P::RUNTIME,
                            {P::ZERO_POINTS | P::SRC | P::PER_N, P::NONE,
                                    P::ZERO_POINTS | P::DST | P::PER_N}}});
    // post-ops
    cases.push_back({{{{10, 1}, src_dt, tag::ab},
                             {{1, 20}, data_type::s8, tag::ab},
                             {{10, 20}, dst_dt, tag::ab}},
            {P::NONE, {},
                    {{primitive::kind::eltwise, algorithm::eltwise_relu}}}});
    // multiple post-ops
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ab}, {{2, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::f32},
                    {P::SCALES | P::COMMON, {},
                            {{primitive::kind::sum},
                                    {primitive::kind::eltwise,
                                            algorithm::eltwise_relu}}}});

    // igemm like: output scale + post-ops(sum)
    cases.push_back(
            {{{{10, 1}, src_dt, tag::ab}, {{1, 20}, data_type::s8, tag::ab},
                     {{10, 20}, dst_dt, tag::ab}, data_type::s8},
                    {P::SCALES | P::COMMON,
                            {P::ZERO_POINTS | P::SRC | P::COMMON, P::NONE,
                                    P::ZERO_POINTS | P::DST | P::COMMON},
                            {{primitive::kind::sum}}}});
    // igemm like: output scale + post-ops(sum) + all runtime
    cases.push_back(
            {{{{10, 2}, src_dt, tag::ba}, {{2, 20}, data_type::s8, tag::ba},
                     {{10, 20}, dst_dt, tag::ab}, data_type::s8},
                    {P::SCALES | P::PER_N,
                            {P::ZERO_POINTS | P::SRC | P::COMMON,
                                    P::ZERO_POINTS | P::WEIGHTS | P::COMMON,
                                    P::ZERO_POINTS | P::DST | P::COMMON},
                            {{primitive::kind::sum}}}});

    return ::testing::ValuesIn(cases);
};
INSTANTIATE_TEST_SUITE_P(
        Generic_s8s8s32, iface, cases_x8(data_type::s8, data_type::s32));
INSTANTIATE_TEST_SUITE_P(
        Generic_u8s8u8, iface, cases_x8(data_type::u8, data_type::u8));
647 | |
// Shapes for the impl-dispatch-stability test (attr_test_t above).
INSTANTIATE_TEST_SUITE_P(TensorDims, attr_test_t,
        ::testing::Values(
                // {{src0, src1, dst same_dim}, { binary post-op dim }},
                // format_tag, post-op data type, ndims
                std::make_tuple(memory::dims {3, 2, 16, 16},
                        memory::dims {3, 1, 16, 16}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {9, 9, 64, 64},
                        memory::dims {9, 1, 64, 64}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {3, 2, 16, 16},
                        memory::dims {3, 2, 16, 16}, tag::abcd,
                        memory::data_type::f32, 4),
                std::make_tuple(memory::dims {2, 10, 10, 10},
                        memory::dims {2, 10, 10, 10}, tag::abcd,
                        memory::data_type::bf16, 4),
                std::make_tuple(memory::dims {2, 10, 10, 10},
                        memory::dims {2, 10, 10, 10}, tag::abcd,
                        memory::data_type::f16, 4)));
667 | |
668 | } // namespace dnnl |
669 | |