1 | /******************************************************************************* |
2 | * Copyright 2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef TEST_GEMM_DATA_H |
18 | #define TEST_GEMM_DATA_H |
19 | |
20 | #include <cstdint> |
21 | #include <utility> |
22 | #include <type_traits> |
23 | |
24 | #include "oneapi/dnnl/dnnl_types.h" |
25 | |
26 | namespace dnnl { |
27 | |
// Matrix storage order selector.
enum class layout_t {
    ROW_MAJOR, // row-major storage
    COL_MAJOR, // column-major storage
};
29 | |
// Integer-GEMM specific settings of a test case.
struct test_igemm_params {
    // Offset-C mode character; test_params::oc_is_R()/oc_is_C() recognize
    // 'R'/'r' and 'C'/'c', any other value is treated as neither.
    char offsetc;
    bool nonzero_oa; // when set, oa() yields a non-zero value
    bool nonzero_ob; // when set, ob() yields a non-zero value
    bool nonzero_oc; // not read in this header; presumably enables a non-zero C offset

    // A-side offset value: 4 when nonzero_oa is set, otherwise 0.
    int8_t oa() const {
        const int value = nonzero_oa ? 4 : 0;
        return static_cast<int8_t>(value);
    }

    // B-side offset value: 3 when nonzero_ob is set, otherwise 0.
    int8_t ob() const {
        if (!nonzero_ob) return static_cast<int8_t>(0);
        return static_cast<int8_t>(3);
    }
};
39 | |
// Packing switches of a test case: one flag for the A operand, one for B.
// Installed into a test_params by make_test_params_pack().
struct test_pack_params {
    bool pack_a; // pack flag for A
    bool pack_b; // pack flag for B
};
44 | |
// Four 64-bit offsets attached to a test case, one each for a, b, c and co.
// Presumably element offsets into the A/B/C/co buffers; confirm with the
// test driver. Installed into test_params::off by
// make_test_params_with_offset().
struct gemm_offset {
    int64_t a;  // offset associated with A
    int64_t b;  // offset associated with B
    int64_t c;  // offset associated with C
    int64_t co; // offset associated with the C-offset (co) data
};
51 | |
52 | struct test_params { |
53 | char transA; |
54 | char transB; |
55 | int64_t M; |
56 | int64_t N; |
57 | int64_t K; |
58 | float alpha; |
59 | float beta; |
60 | int64_t lda; |
61 | int64_t ldb; |
62 | int64_t ldc; |
63 | |
64 | test_igemm_params igemm_params; |
65 | test_pack_params pack_params; |
66 | bool expect_to_fail; |
67 | dnnl_status_t expected_status; |
68 | |
69 | gemm_offset off; |
70 | |
71 | bool tr_a() const { return transA == 'T' || transA == 't'; } |
72 | bool tr_b() const { return transB == 'T' || transB == 't'; } |
73 | int64_t sizeC() const { return M * ldc; } |
74 | |
75 | bool oc_is_R() const { |
76 | auto c = igemm_params.offsetc; |
77 | return c == 'R' || c == 'r'; |
78 | } |
79 | bool oc_is_C() const { |
80 | auto c = igemm_params.offsetc; |
81 | return c == 'C' || c == 'c'; |
82 | } |
83 | int64_t size_oc() const { return oc_is_R() ? N : oc_is_C() ? M : 1; } |
84 | }; |
85 | |
86 | template <typename... TArgs> |
87 | inline test_params make_test_params_with_offset( |
88 | const gemm_offset &off, TArgs &&... args) { |
89 | test_params params {std::forward<TArgs>(args)...}; |
90 | params.off = off; |
91 | return params; |
92 | } |
93 | |
94 | template <typename... TArgs> |
95 | inline test_params make_test_params_pack( |
96 | const test_pack_params &pack_params, TArgs &&... args) { |
97 | test_params params {std::forward<TArgs>(args)...}; |
98 | params.pack_params = pack_params; |
99 | return params; |
100 | } |
101 | |
102 | } // namespace dnnl |
103 | |
104 | #endif |
105 | |