1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef TEST_GEMM_DATA_PREPARATION_H
18#define TEST_GEMM_DATA_PREPARATION_H
19
20#include "test_gemm_params.hpp"
21
22#include "dnnl_test_common.hpp"
23
24namespace dnnl {
25/*
26 * To reduce the time spent in GEMM validation the test matrices A, B, and C
27 * are generated from sub-matrices (A', B', and C') of smaller size:
28 * - A(M, K) <-> A'(M_test, K)
29 * - B(K, N) <-> B'(K, N_test)
30 * - C(M, N) <-> C'(M_test, N_test)
31 *
32 * The matrices A', B', and C' are generated randomly. Then:
33 * - A(m, k) := A'(mapper_m[m], k),
34 * - B(k, n) := B'(k, mapper_n[n]),
35 * - C(m, n) := C'(mapper_m[m], mapper_n[n]);
36 *
 * Here `mapper_x[]` is a surjection of {0, ..., X-1} onto {0, ..., X_test-1}.
 * For simplicity, mapper_x[x] = x for x in {0, ..., X_test-1}.
39 *
40 * This technique allows reducing the complexity of the validation code from
41 * O(M*N*K) to O(M_test * N_test * K).
42 *
 * X_test := min(X, X_test_max), where X_test_max is a prime number around 50.
44 *
45 * To make the test robust the surjective functions mapper_m and mapper_n
46 * should randomly map the elements {X_test, ..., X-1} onto {0, ..., X_test-1}.
47 */
48
// Upper bounds of the reduced (test) sizes of the M and N dimensions. They are
// passed as `dim_test_max` to the mappers created in
// prepare_data_for_gemm_testing(). Both values are prime numbers close to 50.
static constexpr int M_test_max = 47;
static constexpr int N_test_max = 53;
51
/** Mapper:
 * a surjective function from {0, ..., dim-1} onto {0, ..., dim_test-1},
 * where dim_test = min(dim, dim_test_max).
 *
 * The first dim_test indices are mapped onto themselves. Every following
 * index takes the next element of the deterministic sequence
 *     g_0 = gen_start % dim_test,   g_{i+1} = (g_i * gen) % dim_test,
 * so the mapping is reproducible between runs.
 */
struct mapper_t {
    // Member-wise copy / move is all that is required for this type.
    mapper_t(const mapper_t &other) = default;
    mapper_t(mapper_t &&other) noexcept = default;

    /**
     * @param dim          size of the full dimension
     * @param dim_test_max upper bound of the reduced (test) dimension
     * @param gen          multiplier of the index sequence
     * @param gen_start    seed of the index sequence
     */
    mapper_t(int64_t dim, int64_t dim_test_max, int64_t gen = 7,
            int64_t gen_start = 13)
        : dim_(dim)
        , dim_test_((std::min)(dim, dim_test_max))
        , gen_(gen)
        , gen_start_(gen_start)
        , mapper_(dim) {
        // An empty (or degenerate) test dimension has nothing to map onto.
        // Returning here also keeps the modulo operations below from
        // dividing by zero; the table (if any) stays zero-initialized.
        if (dim_test_ <= 0) return;

        // Identity on the leading dim_test_ indices.
        for (int64_t d = 0; d < dim_test_; ++d)
            mapper_[d] = d;

        // The remaining indices follow the multiplicative sequence.
        int64_t g = gen_start_ % dim_test_;
        for (int64_t d = dim_test_; d < dim_; ++d) {
            mapper_[d] = mapper_[g];
            g = g * gen_ % dim_test_;
        }
    }

    /** Size of the full dimension. */
    int64_t dim() const { return dim_; }
    /** Size of the reduced (test) dimension. */
    int64_t dim_test() const { return dim_test_; }
    /** Index that `d` is mapped onto; `d` is not range-checked. */
    int64_t operator[](int64_t d) const { return mapper_[d]; }

private:
    const int64_t dim_;
    const int64_t dim_test_;
    const int64_t gen_, gen_start_;
    std::vector<int64_t> mapper_;
};
95
/** Bundle of the memory objects and dimension mappers used by a GEMM test.
 * Every member is (re)created by prepare_data_for_gemm_testing().
 */
struct test_gemm_data {
 // Matrix A.
 std::shared_ptr<test_memory> a_mem;
 // Matrix B.
 std::shared_ptr<test_memory> b_mem;
 // Matrix C.
 std::shared_ptr<test_memory> c_mem;
 // Copy of the initial contents of C, made by fill_matrices().
 std::shared_ptr<test_memory> c_ref_mem;
 // NOTE(review): presumably the integer GEMM C offsets (`oc`) - confirm
 // against the code that consumes this structure.
 std::shared_ptr<test_memory> oc_mem;
 // Mapper of the M dimension (bounded by M_test_max).
 std::shared_ptr<mapper_t> mapper_m;
 // Mapper of the N dimension (bounded by N_test_max).
 std::shared_ptr<mapper_t> mapper_n;
};
105
106/** Prepares matrix A or B according to the dimension mapper.
107 * The K dimension is always assumed to be columns, hence:
108 * - A layout = A_is_transposed ? ROW_MAJOR : COL_MAJOR
109 * - B layout = B_is_transposed ? COL_MAJOR : ROW_MAJOR
110 */
111template <typename data_t>
112void prepare_matrix(const test_memory &M_mem, int64_t off_beg, layout_t layout,
113 int64_t R, int64_t C, int64_t LD, const mapper_t &mapper) {
114 auto M = map_memory<data_t>(M_mem);
115 auto dt = data_traits<data_t>::data_type;
116 bool is_fp = (false || dt == memory::data_type::f16
117 || dt == memory::data_type::bf16 || dt == memory::data_type::f32);
118 const data_t mean = (data_t)(is_fp ? 1.f : 4);
119 const data_t var = (data_t)(is_fp ? 2e-1f : 3);
120
121 ASSERT_EQ(R, mapper.dim());
122 const int R_test = mapper.dim_test();
123
124 if (layout == layout_t::COL_MAJOR) {
125 dnnl::impl::parallel_nd(C, R_test, [&](int64_t c, int64_t r) {
126 const int64_t off = c * LD + r;
127 M[off_beg + off] = set_value<data_t>(off, mean, var, 1.);
128 });
129 if (R > R_test) {
130 const int64_t R_rest = R - R_test;
131 dnnl::impl::parallel_nd(C, R_rest, [&](int64_t c, int64_t r_) {
132 const int64_t r = R_test + r_;
133 const int64_t off = c * LD + r;
134 const int64_t off0 = c * LD + mapper[r];
135 M[off_beg + off] = M[off_beg + off0];
136 });
137 }
138 } else {
139 dnnl::impl::parallel_nd(R_test, C, [&](int64_t r, int64_t c) {
140 const int64_t off = r * LD + c;
141 M[off_beg + off] = set_value<data_t>(off, mean, var, 1.);
142 });
143 if (R > R_test) {
144 const int64_t R_rest = R - R_test;
145 dnnl::impl::parallel_nd(R_rest, C, [&](int64_t r_, int64_t c) {
146 const int64_t r = R_test + r_;
147 const int64_t off = r * LD + c;
148 const int64_t off0 = mapper[r] * LD + c;
149 M[off_beg + off] = M[off_beg + off0];
150 });
151 }
152 }
153
154 // To test if igemm row/col sum are correct when performing sign/zero
155 // extensions.
156 if (dt == memory::data_type::u8)
157 M[off_beg] = data_t(UINT8_MAX);
158 else if (dt == memory::data_type::s8)
159 M[off_beg] = data_t(-64);
160}
161
162/** Extends columns of the matrix M according to the mapper_c */
163template <typename data_t>
164void extend_matrix_cols(const test_memory &M_mem, int64_t off, int64_t R,
165 int64_t C, int64_t LD, const mapper_t &mapper_c) {
166 auto M = map_memory<data_t>(M_mem);
167 ASSERT_EQ(C, mapper_c.dim());
168 const int64_t C_test = mapper_c.dim_test();
169 if (C_test == C) return;
170
171 dnnl::impl::parallel_nd(R, C - C_test, [&](int64_t r, int64_t c_) {
172 const int64_t c = C_test + c_;
173 const int64_t c0 = mapper_c[c];
174 M[off + r * LD + c] = M[off + r * LD + c0];
175 });
176}
177
178/** Extends rows of the matrix M according to the mapper_r */
179template <typename data_t>
180void extend_matrix_rows(const test_memory &M_mem, int64_t off, int64_t R,
181 int64_t C, int64_t LD, const mapper_t &mapper_r) {
182 auto M = map_memory<data_t>(M_mem);
183 ASSERT_EQ(R, mapper_r.dim());
184 const int64_t R_test = mapper_r.dim_test();
185 if (R_test == R) return;
186
187 dnnl::impl::parallel_nd(R - R_test, [&](int64_t r_) {
188 const int64_t r = R_test + r_;
189 const int64_t r0 = mapper_r[r];
190 for (int64_t c = 0; c < C; ++c)
191 M[off + r * LD + c] = M[off + r0 * LD + c];
192 });
193}
194
/** Extends matrix M according to the mapper_r and mapper_c.
 *
 * Rows are extended first and columns second: the column pass runs over all
 * R rows, so the rows {R_test, ..., R-1} have to be populated before it.
 *
 * @param M_mem    memory object that holds the matrix
 * @param off      offset (in elements) of the matrix start inside M_mem
 * @param R        number of rows; must be equal to mapper_r.dim()
 * @param C        number of columns; must be equal to mapper_c.dim()
 * @param LD       leading dimension
 * @param mapper_r mapper of the row dimension
 * @param mapper_c mapper of the column dimension
 */
template <typename data_t>
void extend_matrix(const test_memory &M_mem, int64_t off, int64_t R, int64_t C,
 int64_t LD, const mapper_t &mapper_r, const mapper_t &mapper_c) {
 // Checked here as well, although the callees repeat the checks.
 // NOTE(review): presumably a failed ASSERT_EQ only returns from the
 // function that contains it (GoogleTest semantics) - confirm.
 ASSERT_EQ(R, mapper_r.dim());
 ASSERT_EQ(C, mapper_c.dim());
 extend_matrix_rows<data_t>(M_mem, off, R, C, LD, mapper_r);
 extend_matrix_cols<data_t>(M_mem, off, R, C, LD, mapper_c);
}
204
205inline void get_matrix_size(
206 const test_params &p, size_t &sizeA, size_t &sizeB, size_t &sizeC) {
207 const bool tr_a = (p.transA == 'T' || p.transA == 't');
208 const bool tr_b = (p.transB == 'T' || p.transB == 't');
209 sizeA = tr_a ? p.lda * p.K : p.lda * p.M;
210 sizeB = tr_b ? p.ldb * p.N : p.ldb * p.K;
211 sizeC = p.ldc * p.M;
212}
213
214template <typename T>
215inline memory::desc get_matrix_md(memory::dim n, memory::dim off) {
216 return create_md(
217 {n + off}, data_traits<T>::data_type, memory::format_tag::x);
218}
219
220template <typename a_dt, typename b_dt, typename c_dt>
221void fill_matrices(const test_params &p, const mapper_t &mapper_m,
222 const mapper_t &mapper_n, const test_memory &a_mem,
223 const test_memory &b_mem, const test_memory &c_mem,
224 const test_memory &c_ref_mem, const test_memory &oc_mem) {
225 prepare_matrix<a_dt>(a_mem, p.off.a,
226 p.tr_a() ? layout_t::COL_MAJOR : layout_t::ROW_MAJOR, p.M, p.K,
227 p.lda, mapper_m);
228 prepare_matrix<b_dt>(b_mem, p.off.b,
229 p.tr_b() ? layout_t::ROW_MAJOR : layout_t::COL_MAJOR, p.N, p.K,
230 p.ldb, mapper_n);
231
232 fill_data<c_dt>(p.off.c + p.sizeC(), c_mem.get());
233 extend_matrix<c_dt>(c_mem, p.off.c, p.M, p.N, p.ldc, mapper_m, mapper_n);
234 {
235 auto C = map_memory<c_dt>(c_mem);
236 auto C_ref = map_memory<c_dt>(c_ref_mem);
237 dnnl::impl::parallel_nd(p.sizeC(),
238 [&](int64_t i) { C_ref[p.off.c + i] = C[p.off.c + i]; });
239 }
240
241 if (oc_mem.get_size() == 0) return;
242
243 if (p.igemm_params.nonzero_oc) {
244 fill_data<c_dt>(p.size_oc(), oc_mem.get(), (c_dt)1, (c_dt)0);
245 if (p.oc_is_R()) {
246 extend_matrix_cols<c_dt>(oc_mem, 0, 1, p.N, p.N, mapper_n);
247 } else if (p.oc_is_C()) {
248 extend_matrix_rows<c_dt>(oc_mem, 0, p.M, 1, 1, mapper_m);
249 }
250 } else {
251 auto oc = map_memory<c_dt>(oc_mem);
252 for (int64_t i = 0; i < p.size_oc(); i++)
253 oc[i] = 0;
254 }
255}
256
257template <typename a_dt, typename b_dt, typename c_dt>
258void prepare_data_for_gemm_testing(
259 const test_params &p, test_gemm_data &gemm_data, engine &eng) {
260 size_t sizeA, sizeB, sizeC;
261 get_matrix_size(p, sizeA, sizeB, sizeC);
262
263 gemm_data.a_mem.reset(
264 new test_memory(get_matrix_md<a_dt>(sizeA, p.off.a), eng));
265 gemm_data.b_mem.reset(
266 new test_memory(get_matrix_md<b_dt>(sizeB, p.off.b), eng));
267 gemm_data.c_mem.reset(
268 new test_memory(get_matrix_md<c_dt>(sizeC, p.off.c), eng));
269 gemm_data.c_ref_mem.reset(
270 new test_memory(get_matrix_md<c_dt>(sizeC, p.off.c), eng));
271 gemm_data.oc_mem.reset(
272 new test_memory(get_matrix_md<c_dt>(p.size_oc(), p.off.co), eng));
273
274 gemm_data.mapper_m.reset(new mapper_t(p.M, M_test_max));
275 gemm_data.mapper_n.reset(new mapper_t(p.N, N_test_max));
276
277 fill_matrices<a_dt, b_dt, c_dt>(p, *gemm_data.mapper_m, *gemm_data.mapper_n,
278 *gemm_data.a_mem, *gemm_data.b_mem, *gemm_data.c_mem,
279 *gemm_data.c_ref_mem, *gemm_data.oc_mem);
280}
281
282} // namespace dnnl
283
284#endif
285