/*******************************************************************************
* Copyright 2018-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cmath>

#include "common/dnnl_thread.hpp"
#include "common/utils.hpp"

#include "cpu/gemm/f32/gemm_utils_f32.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace gemm_utils {
#define BM_NOCOPY_AVX 64
#define BN_NOCOPY_AVX 48
#define BK_NOCOPY_AVX 384
#define BN_LARGE_NOCOPY_AVX 192
#define BM_SMALL_NOCOPY_AVX 16
#define BN_SMALL_NOCOPY_AVX 1
#define BK_SMALL_NOCOPY_AVX 4
// Determine the number of threads for each dimension of a 3-D partitioning
// algorithm based on the input parameters:
// m/n/k - GEMM problem dimensions
// nthrs - total number of available threads
// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
// BM/BN/BK - blocking values
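//
// Rough illustration (the sizes below are hypothetical, not taken from any
// particular workload): with m = 1000 and n = 500, the initial estimates are
// nthr_m = ceil(1000 / BM_NOCOPY_AVX) = ceil(1000 / 64) = 16 and
// nthr_n = ceil(500 / BN_NOCOPY_AVX) = ceil(500 / 48) = 11, which the code
// below then reconciles with the number of threads actually available.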
void calc_nthr_nocopy_avx(dim_t m, dim_t n, dim_t k, int nthrs, int *nthrs_m,
        int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN, dim_t *BK) {

    // Quick exit for single thread.
    if (nthrs == 1) {
        *nthrs_m = 1;
        *nthrs_n = 1;
        *nthrs_k = 1;

        *BM = m;
        *BN = n;
        *BK = k;
        return;
    }

    int nthr, nthr_m, nthr_n, nthr_k;
    dim_t MB, NB, KB;

    nthr = nthrs;
    nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX;
    nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX;
    nthr_k = 1;

    // Partition along K dimension
    //  - if threading allows having barriers (e.g. OMP)
    //  - if there is not enough parallelism along M or N
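    //
    // Rough example (hypothetical numbers): with nthr = 16 threads available,
    // nthr_m * nthr_n = 4 and k = 2000, the loop below raises nthr_other to 4
    // (at that point nthr_m * nthr_n * nthr_other is no longer < nthr), and
    // since 4 divides 16 evenly, nthr_k ends up as 4 and nthr is later
    // reduced to 4.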
    if (dnnl_thr_syncable()) {
        int nthr_other = nthr_k = 1;
        while ((nthr_m * nthr_n * nthr_other < nthr)
                && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) {
            nthr_other++;
            if ((nthr / nthr_other) * nthr_other > 0.9 * nthr)
                nthr_k = nthr_other;
        }
    }
    nthr /= nthr_k;

    if (nthr_m == 1) nthr_n = nthr;
    if (nthr_n == 1) nthr_m = nthr;

    // Simple partition reduction
    while (nthr_m * nthr_n > nthr)
        if (nthr_m > nthr_n)
            nthr_m--;
        else
            nthr_n--;
    while (nthr_m * nthr_n < nthr)
        if (nthr_m < nthr_n)
            nthr_m++;
        else
            nthr_n++;

    if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) {

        if (nthr_m <= nthr_n) {
            nthr_m = (int)sqrt((double)nthr);
            if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX)
                nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX;
            nthr_n = nthr / nthr_m;

            while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
                nthr_m--;
                nthr_n = nthr / nthr_m;
            }
        } else {
            nthr_n = (int)sqrt((double)nthr);
            if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX)
                nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX;
            nthr_m = nthr / nthr_n;

            while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
                nthr_n--;
                nthr_m = nthr / nthr_n;
            }
        }
    }

    MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1;
    MB -= MB % BM_SMALL_NOCOPY_AVX;
    NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1;
    NB -= NB % BN_SMALL_NOCOPY_AVX;
    KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1;
    KB -= KB % BK_SMALL_NOCOPY_AVX;

    if (MB * nthr_m > m) nthr_m = (m + MB - 1) / MB;
    if (NB * nthr_n > n) nthr_n = (n + NB - 1) / NB;
    if (KB * nthr_k > k) nthr_k = (k + KB - 1) / KB;

    *nthrs_m = nthr_m;
    *nthrs_n = nthr_n;
    *nthrs_k = nthr_k;

    *BM = MB;
    *BN = NB;
    *BK = KB;
}
#undef BM_NOCOPY_AVX
#undef BN_NOCOPY_AVX
#undef BK_NOCOPY_AVX
#undef BN_LARGE_NOCOPY_AVX
#undef BM_SMALL_NOCOPY_AVX
#undef BN_SMALL_NOCOPY_AVX
#undef BK_SMALL_NOCOPY_AVX

#define BM_NOCOPY_AVX512_COMMON 32
#define BN_NOCOPY_AVX512_COMMON 64
#define BK_NOCOPY_AVX512_COMMON 192
#define BN_LARGE_NOCOPY_AVX512_COMMON 192
#define BM_SMALL_NOCOPY_AVX512_COMMON 16
#define BN_SMALL_NOCOPY_AVX512_COMMON 1
#define BK_SMALL_NOCOPY_AVX512_COMMON 4
// Determine the number of threads for each dimension of a 3-D partitioning
// algorithm based on the input parameters:
// m/n/k - GEMM problem dimensions
// nthrs - total number of available threads
// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
// BM/BN/BK - blocking values
void calc_nthr_nocopy_avx512_common(dim_t m, dim_t n, dim_t k, int nthrs,
        int *nthrs_m, int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN,
        dim_t *BK) {

    // Quick exit for single thread.
    if (nthrs == 1) {
        *nthrs_m = 1;
        *nthrs_n = 1;
        *nthrs_k = 1;

        *BM = m;
        *BN = n;
        *BK = k;
        return;
    }

    int nthr, nthr_m, nthr_n, nthr_k = 1;
    dim_t MB, NB, KB;
    nthr = nthrs;

    int counter = 0;
    float ratio_float = 1.;
    int ratio = 1;
    int nthr_m_gt_n;

    // Partition along K dimension
    //  - if threading allows having barriers (e.g. OMP)
    //  - if there is not enough parallelism along M or N
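    //
    // Rough example (hypothetical sizes): with nthr = 16, m = n = 64 and
    // k = 4096, the branch below applies (m and n are small while k
    // dominates), so nthr_k = min(4096 / 192, 16 / 4) = 4, which divides 16
    // evenly, leaving 4 threads to be split between M and N.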
    if (dnnl_thr_syncable()) {
        if (n <= 2 * BN_NOCOPY_AVX512_COMMON
                && m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr && k > m && k > n) {
            nthr_k = k / BK_NOCOPY_AVX512_COMMON;
            if (nthr_k > nthr / 4) nthr_k = nthr / 4;
            if (nthr_k < 1) nthr_k = 1;

            while ((nthr_k > 1) && (nthr % nthr_k)) {
                nthr_k--;
            }
            nthr /= nthr_k;
        } else {
            nthr_k = 1;
        }
    }
    nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON;
    nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON;

    if (nthr_m < 1) nthr_m = 1;
    if (nthr_n < 1) nthr_n = 1;

    nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0;
    ratio_float = (float)nthr_m / nthr_n;

    if (nthr_m_gt_n)
        ratio = (int)ratio_float;
    else
        ratio = (int)(1. / ratio_float);

    // Scale down nthr_m and nthr_n if they are too large
    while (nthr_m * nthr_n > 4 * nthr) {
        nthr_m /= 2;
        nthr_n /= 2;
    }

    if (nthr_m < 1) nthr_m = 1;
    if (nthr_n < 1) nthr_n = 1;

    // Simple partition reduction
    counter = 0;
    while (nthr_m * nthr_n > nthr) {
        if (nthr_m > nthr_n) {
            if (counter < ratio)
                nthr_m--;
            else {
                nthr_n--;
                counter = -1;
            }
        } else {
            if (counter < ratio)
                nthr_n--;
            else {
                nthr_m--;
                counter = -1;
            }
        }
        counter++;
    }

    // Simple partition increment
    counter = 0;
    while (nthr_m * nthr_n < 0.95 * nthr) {
        if (nthr_m > nthr_n) {
            if (counter < ratio)
                nthr_m++;
            else {
                nthr_n++;
                counter = -1;
            }
        } else {
            if (counter < ratio)
                nthr_n++;
            else {
                nthr_m++;
                counter = -1;
            }
        }
        counter++;
    }

    // If the adjustments above still leave nthr_m * nthr_n > nthr, fall back
    // to a near-square partition constrained by the small blocking sizes.
    if ((nthr_m * nthr_n > nthr)) {

        if (nthr_m <= nthr_n) {
            nthr_m = (int)sqrt((double)nthr);
            if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
                            / BM_SMALL_NOCOPY_AVX512_COMMON)
                nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
                        / BM_SMALL_NOCOPY_AVX512_COMMON;
            nthr_n = nthr / nthr_m;

            while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
                nthr_m--;
                nthr_n = nthr / nthr_m;
            }
        } else {
            nthr_n = (int)sqrt((double)nthr);
            if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
                            / BN_SMALL_NOCOPY_AVX512_COMMON)
                nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
                        / BN_SMALL_NOCOPY_AVX512_COMMON;
            nthr_m = nthr / nthr_n;

            while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
                nthr_n--;
                nthr_m = nthr / nthr_n;
            }
        }
    }

    MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1;
    MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON;
    NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1;
    NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON;
    KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1;
    KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON;

    if (MB * nthr_m > m) nthr_m = (m + MB - 1) / MB;
    if (NB * nthr_n > n) nthr_n = (n + NB - 1) / NB;
    if (KB * nthr_k > k) nthr_k = (k + KB - 1) / KB;

    *nthrs_m = nthr_m;
    *nthrs_n = nthr_n;
    *nthrs_k = nthr_k;

    *BM = MB;
    *BN = NB;
    *BK = KB;
}
#undef BM_NOCOPY_AVX512_COMMON
#undef BN_NOCOPY_AVX512_COMMON
#undef BK_NOCOPY_AVX512_COMMON
#undef BN_LARGE_NOCOPY_AVX512_COMMON
#undef BM_SMALL_NOCOPY_AVX512_COMMON
#undef BN_SMALL_NOCOPY_AVX512_COMMON
#undef BK_SMALL_NOCOPY_AVX512_COMMON

// Partition n values as equally as possible among nthr threads, and set the
// offset (t_offset) and number of values (t_block) for ithr.
// Assumption: 0 <= ithr < nthr
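//
// Worked example (hypothetical values): for n = 10 and nthr = 4, the base band
// is 10 / 4 = 2 with a tail of 2, so threads 0 and 1 get blocks of 3 values at
// offsets 0 and 3, while threads 2 and 3 get blocks of 2 values at offsets
// 6 and 8.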
void partition_unit_diff(
        int ithr, int nthr, dim_t n, dim_t *t_offset, dim_t *t_block) {

    dim_t band = n / nthr;
    if (band == 0) band = 1;
    dim_t tail = n - band * nthr;
    if (tail < 0) tail = 0;

    if (ithr < tail) {
        band++;
        *t_offset = band * ithr;
        *t_block = band;
    } else {
        *t_offset = band * ithr + tail;
        *t_block = band;
    }

    if (*t_offset >= n) {
        *t_offset = 0;
        *t_block = 0;
    }

    if (*t_offset + *t_block > n) { *t_block = n - *t_offset; }
}

// Sum the m*n values from p_src into p_dst, assuming the two matrices are
// stored column-major with leading dimensions ld_src and ld_dst, respectively.
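//
// Minimal usage sketch (hypothetical buffers): element (i, j) is read from
// p_src[i + j * ld_src] and accumulated into p_dst[i + j * ld_dst], so both
// leading dimensions must be >= m. For example:
//
//     float src[4 * 3] = {0}, dst[4 * 3] = {0}; // two 2x3 matrices, ld = 4
//     sum_two_matrices<float>(2, 3, src, 4, dst, 4);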
template <typename data_t>
void sum_two_matrices(dim_t m, dim_t n, data_t *__restrict p_src, dim_t ld_src,
        data_t *__restrict p_dst, dim_t ld_dst) {

    for (dim_t j = 0; j < n; j++) {
        for (dim_t i = 0; i < m; i++) {
            p_dst[i + j * ld_dst] += p_src[i + j * ld_src];
        }
    }
}

template void sum_two_matrices<float>(dim_t m, dim_t n, float *__restrict p_src,
        dim_t ld_src, float *__restrict p_dst, dim_t ld_dst);

template void sum_two_matrices<double>(dim_t m, dim_t n,
        double *__restrict p_src, dim_t ld_src, double *__restrict p_dst,
        dim_t ld_dst);
} // namespace gemm_utils
} // namespace cpu
} // namespace impl
} // namespace dnnl