1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef TEST_GEMM_DATA_VALIDATION_H
18#define TEST_GEMM_DATA_VALIDATION_H
19
20#include "test_gemm_params.hpp"
21
22#include "dnnl_test_common.hpp"
23
24namespace dnnl {
25
26template <typename a_dt, typename b_dt, typename c_dt>
27struct ref_gemm {
28 static void call(const test_params &p, int64_t M, int64_t N,
29 const test_memory &a_mem, const test_memory &b_mem,
30 const test_memory &c_mem, const test_memory &) {
31 auto a = map_memory<a_dt>(a_mem);
32 auto b = map_memory<b_dt>(b_mem);
33 auto c = map_memory<c_dt>(c_mem);
34
35 const bool tr_a = p.transA && (p.transA == 'T' || p.transA == 't');
36 const bool tr_b = p.transB && (p.transB == 'T' || p.transB == 't');
37
38 auto pa = [&](int64_t i, int64_t j) {
39 return a[p.off.a + i * p.lda + j];
40 };
41 auto pb = [&](int64_t i, int64_t j) {
42 return b[p.off.b + i * p.ldb + j];
43 };
44 auto pc = [&](int64_t i, int64_t j) -> c_dt & {
45 return c[p.off.c + i * p.ldc + j];
46 };
47
48 dnnl::impl::parallel_nd(M, N, [&](int64_t im, int64_t in) {
49 c_dt c_elem = (p.beta == 0.) ? 0. : pc(im, in) * p.beta;
50
51 for (int64_t ik = 0; ik < p.K; ik++) {
52 const a_dt a_elem = tr_a ? pa(ik, im) : pa(im, ik);
53 const b_dt b_elem = tr_b ? pb(in, ik) : pb(ik, in);
54 c_elem += p.alpha * a_elem * b_elem;
55 }
56 pc(im, in) = c_elem;
57 });
58 }
59};
60
61template <typename a_dt, typename b_dt>
62struct ref_gemm<a_dt, b_dt, int32_t> {
63 static void call(const test_params &p, int64_t M, int64_t N,
64 const test_memory &a_mem, const test_memory &b_mem,
65 const test_memory &c_mem, const test_memory &oc_mem) {
66 auto A = map_memory<a_dt>(a_mem);
67 auto B = map_memory<b_dt>(b_mem);
68 auto C = map_memory<int32_t>(c_mem);
69 auto oc = map_memory<int32_t>(oc_mem);
70
71 const bool tr_a = p.transA && (p.transA == 'T' || p.transA == 't');
72 const bool tr_b = p.transB && (p.transB == 'T' || p.transB == 't');
73 bool OCisR = (p.igemm_params.offsetc == 'R'
74 || p.igemm_params.offsetc == 'r');
75 bool OCisC = (p.igemm_params.offsetc == 'C'
76 || p.igemm_params.offsetc == 'c');
77
78 auto pa = [&](int64_t i, int64_t j) {
79 return (double)A[p.off.a + i * p.lda + j];
80 };
81 auto pb = [&](int64_t i, int64_t j) {
82 return (double)B[p.off.b + i * p.ldb + j];
83 };
84 auto pc = [&](int64_t i, int64_t j) -> int32_t & {
85 return C[p.off.c + i * p.ldc + j];
86 };
87
88 int8_t oa = p.igemm_params.oa();
89 int8_t ob = p.igemm_params.ob();
90
91 dnnl::impl::parallel_nd(M, N, [&](int64_t m, int64_t n) {
92 double c_elem = 0;
93 for (int64_t k = 0; k < p.K; k++) {
94 const double a_elem = (tr_a ? pa(k, m) : pa(m, k)) - oa;
95 const double b_elem = (tr_b ? pb(n, k) : pb(k, n)) - ob;
96 c_elem += a_elem * b_elem;
97 }
98
99 double coffset = OCisR ? oc[n] : OCisC ? oc[m] : oc[0];
100 double val = (p.beta == 0.f ? 0. : p.beta * (double)pc(m, n))
101 + p.alpha * c_elem + coffset;
102 pc(m, n) = static_cast<int32_t>(
103 nearbyint(saturate<int32_t, double>(val)));
104 });
105 }
106};
107
108template <typename a_dt, typename c_dt>
109void compare(const test_params &p, const test_memory &c_mem,
110 const test_memory &c_ref_mem) {
111 using data_type = memory::data_type;
112 auto c = map_memory<c_dt>(c_mem);
113 auto c_ref = map_memory<c_dt>(c_ref_mem);
114 dnnl::impl::parallel_nd(p.M, p.ldc, [&](int64_t i, int64_t j) {
115 if (is_current_test_failed()) return;
116
117 c_dt ref = c_ref[p.off.c + i * p.ldc + j];
118 c_dt got = c[p.off.c + i * p.ldc + j];
119 c_dt diff = got - ref;
120
121 if (data_traits<a_dt>::data_type == data_type::f16) {
122 const float eps = 1e-3 * p.K;
123 float e = (std::abs(ref) > eps) ? diff / ref : float(diff);
124 ASSERT_NEAR(e, 0.0, eps) << "Row: " << i << " Col: " << j;
125 } else if (data_traits<a_dt>::data_type == data_type::bf16) {
126 const float eps = 1e-2 * p.K;
127 float e = (std::abs(ref) > eps) ? diff / ref : float(diff);
128 ASSERT_NEAR(e, 0.0, eps) << "Row: " << i << " Col: " << j;
129 } else if (data_traits<a_dt>::data_type == data_type::f32) {
130 c_dt e = (std::abs(ref) > 1e-4) ? c_dt(diff / ref) : diff;
131 ASSERT_NEAR(e, 0.0, 1e-4) << "Row: " << i << " Col: " << j;
132 } else {
133 // igemm
134 c_dt eps = 0;
135 if (p.alpha == 1.0f) {
136 eps = 1;
137 } else if (data_traits<a_dt>::data_type == data_type::u8) {
138 eps = p.K / 700 + 1;
139 } else if (data_traits<a_dt>::data_type == data_type::s8) {
140 eps = p.K / 350 + 1;
141 }
142 ASSERT_NEAR(diff, 0, eps) << "Row: " << i << " Col: " << j;
143 }
144 });
145}
146
147template <typename a_dt, typename b_dt, typename c_dt>
148void validate(const test_params &p, test_gemm_data &gemm_data) {
149 const int64_t M_test = gemm_data.mapper_m->dim_test();
150 const int64_t N_test = gemm_data.mapper_n->dim_test();
151
152 ref_gemm<a_dt, b_dt, c_dt>::call(p, M_test, N_test, *gemm_data.a_mem,
153 *gemm_data.b_mem, *gemm_data.c_ref_mem, *gemm_data.oc_mem);
154 extend_matrix<c_dt>(*gemm_data.c_ref_mem, p.off.c, p.M, p.N, p.ldc,
155 *gemm_data.mapper_m, *gemm_data.mapper_n);
156 compare<a_dt, c_dt>(p, *gemm_data.c_mem, *gemm_data.c_ref_mem);
157}
158
159} // namespace dnnl
160
161#endif
162