1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef TEST_GEMM_DATA_H
18#define TEST_GEMM_DATA_H
19
20#include <cstdint>
21#include <utility>
22#include <type_traits>
23
24#include "oneapi/dnnl/dnnl_types.h"
25
26namespace dnnl {
27
// Storage order of a matrix: row-major or column-major.
enum class layout_t : int {
    ROW_MAJOR,
    COL_MAJOR,
};
29
30struct test_igemm_params {
31 char offsetc;
32 bool nonzero_oa;
33 bool nonzero_ob;
34 bool nonzero_oc;
35
36 int8_t oa() const { return (int8_t)(nonzero_oa ? 4 : 0); }
37 int8_t ob() const { return (int8_t)(nonzero_ob ? 3 : 0); }
38};
39
40struct test_pack_params {
41 bool pack_a;
42 bool pack_b;
43};
44
45struct gemm_offset {
46 int64_t a;
47 int64_t b;
48 int64_t c;
49 int64_t co;
50};
51
52struct test_params {
53 char transA;
54 char transB;
55 int64_t M;
56 int64_t N;
57 int64_t K;
58 float alpha;
59 float beta;
60 int64_t lda;
61 int64_t ldb;
62 int64_t ldc;
63
64 test_igemm_params igemm_params;
65 test_pack_params pack_params;
66 bool expect_to_fail;
67 dnnl_status_t expected_status;
68
69 gemm_offset off;
70
71 bool tr_a() const { return transA == 'T' || transA == 't'; }
72 bool tr_b() const { return transB == 'T' || transB == 't'; }
73 int64_t sizeC() const { return M * ldc; }
74
75 bool oc_is_R() const {
76 auto c = igemm_params.offsetc;
77 return c == 'R' || c == 'r';
78 }
79 bool oc_is_C() const {
80 auto c = igemm_params.offsetc;
81 return c == 'C' || c == 'c';
82 }
83 int64_t size_oc() const { return oc_is_R() ? N : oc_is_C() ? M : 1; }
84};
85
86template <typename... TArgs>
87inline test_params make_test_params_with_offset(
88 const gemm_offset &off, TArgs &&... args) {
89 test_params params {std::forward<TArgs>(args)...};
90 params.off = off;
91 return params;
92}
93
94template <typename... TArgs>
95inline test_params make_test_params_pack(
96 const test_pack_params &pack_params, TArgs &&... args) {
97 test_params params {std::forward<TArgs>(args)...};
98 params.pack_params = pack_params;
99 return params;
100}
101
102} // namespace dnnl
103
104#endif
105