1 | /******************************************************************************* |
2 | * Copyright 2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef TEST_GEMM_DATA_VALIDATION_H |
18 | #define TEST_GEMM_DATA_VALIDATION_H |
19 | |
20 | #include "test_gemm_params.hpp" |
21 | |
22 | #include "dnnl_test_common.hpp" |
23 | |
24 | namespace dnnl { |
25 | |
26 | template <typename a_dt, typename b_dt, typename c_dt> |
27 | struct ref_gemm { |
28 | static void call(const test_params &p, int64_t M, int64_t N, |
29 | const test_memory &a_mem, const test_memory &b_mem, |
30 | const test_memory &c_mem, const test_memory &) { |
31 | auto a = map_memory<a_dt>(a_mem); |
32 | auto b = map_memory<b_dt>(b_mem); |
33 | auto c = map_memory<c_dt>(c_mem); |
34 | |
35 | const bool tr_a = p.transA && (p.transA == 'T' || p.transA == 't'); |
36 | const bool tr_b = p.transB && (p.transB == 'T' || p.transB == 't'); |
37 | |
38 | auto pa = [&](int64_t i, int64_t j) { |
39 | return a[p.off.a + i * p.lda + j]; |
40 | }; |
41 | auto pb = [&](int64_t i, int64_t j) { |
42 | return b[p.off.b + i * p.ldb + j]; |
43 | }; |
44 | auto pc = [&](int64_t i, int64_t j) -> c_dt & { |
45 | return c[p.off.c + i * p.ldc + j]; |
46 | }; |
47 | |
48 | dnnl::impl::parallel_nd(M, N, [&](int64_t im, int64_t in) { |
49 | c_dt c_elem = (p.beta == 0.) ? 0. : pc(im, in) * p.beta; |
50 | |
51 | for (int64_t ik = 0; ik < p.K; ik++) { |
52 | const a_dt a_elem = tr_a ? pa(ik, im) : pa(im, ik); |
53 | const b_dt b_elem = tr_b ? pb(in, ik) : pb(ik, in); |
54 | c_elem += p.alpha * a_elem * b_elem; |
55 | } |
56 | pc(im, in) = c_elem; |
57 | }); |
58 | } |
59 | }; |
60 | |
61 | template <typename a_dt, typename b_dt> |
62 | struct ref_gemm<a_dt, b_dt, int32_t> { |
63 | static void call(const test_params &p, int64_t M, int64_t N, |
64 | const test_memory &a_mem, const test_memory &b_mem, |
65 | const test_memory &c_mem, const test_memory &oc_mem) { |
66 | auto A = map_memory<a_dt>(a_mem); |
67 | auto B = map_memory<b_dt>(b_mem); |
68 | auto C = map_memory<int32_t>(c_mem); |
69 | auto oc = map_memory<int32_t>(oc_mem); |
70 | |
71 | const bool tr_a = p.transA && (p.transA == 'T' || p.transA == 't'); |
72 | const bool tr_b = p.transB && (p.transB == 'T' || p.transB == 't'); |
73 | bool OCisR = (p.igemm_params.offsetc == 'R' |
74 | || p.igemm_params.offsetc == 'r'); |
75 | bool OCisC = (p.igemm_params.offsetc == 'C' |
76 | || p.igemm_params.offsetc == 'c'); |
77 | |
78 | auto pa = [&](int64_t i, int64_t j) { |
79 | return (double)A[p.off.a + i * p.lda + j]; |
80 | }; |
81 | auto pb = [&](int64_t i, int64_t j) { |
82 | return (double)B[p.off.b + i * p.ldb + j]; |
83 | }; |
84 | auto pc = [&](int64_t i, int64_t j) -> int32_t & { |
85 | return C[p.off.c + i * p.ldc + j]; |
86 | }; |
87 | |
88 | int8_t oa = p.igemm_params.oa(); |
89 | int8_t ob = p.igemm_params.ob(); |
90 | |
91 | dnnl::impl::parallel_nd(M, N, [&](int64_t m, int64_t n) { |
92 | double c_elem = 0; |
93 | for (int64_t k = 0; k < p.K; k++) { |
94 | const double a_elem = (tr_a ? pa(k, m) : pa(m, k)) - oa; |
95 | const double b_elem = (tr_b ? pb(n, k) : pb(k, n)) - ob; |
96 | c_elem += a_elem * b_elem; |
97 | } |
98 | |
99 | double coffset = OCisR ? oc[n] : OCisC ? oc[m] : oc[0]; |
100 | double val = (p.beta == 0.f ? 0. : p.beta * (double)pc(m, n)) |
101 | + p.alpha * c_elem + coffset; |
102 | pc(m, n) = static_cast<int32_t>( |
103 | nearbyint(saturate<int32_t, double>(val))); |
104 | }); |
105 | } |
106 | }; |
107 | |
108 | template <typename a_dt, typename c_dt> |
109 | void compare(const test_params &p, const test_memory &c_mem, |
110 | const test_memory &c_ref_mem) { |
111 | using data_type = memory::data_type; |
112 | auto c = map_memory<c_dt>(c_mem); |
113 | auto c_ref = map_memory<c_dt>(c_ref_mem); |
114 | dnnl::impl::parallel_nd(p.M, p.ldc, [&](int64_t i, int64_t j) { |
115 | if (is_current_test_failed()) return; |
116 | |
117 | c_dt ref = c_ref[p.off.c + i * p.ldc + j]; |
118 | c_dt got = c[p.off.c + i * p.ldc + j]; |
119 | c_dt diff = got - ref; |
120 | |
121 | if (data_traits<a_dt>::data_type == data_type::f16) { |
122 | const float eps = 1e-3 * p.K; |
123 | float e = (std::abs(ref) > eps) ? diff / ref : float(diff); |
124 | ASSERT_NEAR(e, 0.0, eps) << "Row: " << i << " Col: " << j; |
125 | } else if (data_traits<a_dt>::data_type == data_type::bf16) { |
126 | const float eps = 1e-2 * p.K; |
127 | float e = (std::abs(ref) > eps) ? diff / ref : float(diff); |
128 | ASSERT_NEAR(e, 0.0, eps) << "Row: " << i << " Col: " << j; |
129 | } else if (data_traits<a_dt>::data_type == data_type::f32) { |
130 | c_dt e = (std::abs(ref) > 1e-4) ? c_dt(diff / ref) : diff; |
131 | ASSERT_NEAR(e, 0.0, 1e-4) << "Row: " << i << " Col: " << j; |
132 | } else { |
133 | // igemm |
134 | c_dt eps = 0; |
135 | if (p.alpha == 1.0f) { |
136 | eps = 1; |
137 | } else if (data_traits<a_dt>::data_type == data_type::u8) { |
138 | eps = p.K / 700 + 1; |
139 | } else if (data_traits<a_dt>::data_type == data_type::s8) { |
140 | eps = p.K / 350 + 1; |
141 | } |
142 | ASSERT_NEAR(diff, 0, eps) << "Row: " << i << " Col: " << j; |
143 | } |
144 | }); |
145 | } |
146 | |
147 | template <typename a_dt, typename b_dt, typename c_dt> |
148 | void validate(const test_params &p, test_gemm_data &gemm_data) { |
149 | const int64_t M_test = gemm_data.mapper_m->dim_test(); |
150 | const int64_t N_test = gemm_data.mapper_n->dim_test(); |
151 | |
152 | ref_gemm<a_dt, b_dt, c_dt>::call(p, M_test, N_test, *gemm_data.a_mem, |
153 | *gemm_data.b_mem, *gemm_data.c_ref_mem, *gemm_data.oc_mem); |
154 | extend_matrix<c_dt>(*gemm_data.c_ref_mem, p.off.c, p.M, p.N, p.ldc, |
155 | *gemm_data.mapper_m, *gemm_data.mapper_n); |
156 | compare<a_dt, c_dt>(p, *gemm_data.c_mem, *gemm_data.c_ref_mem); |
157 | } |
158 | |
159 | } // namespace dnnl |
160 | |
161 | #endif |
162 | |