1 | /******************************************************************************* |
2 | * Copyright 2018-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <cmath> |
17 | |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/utils.hpp" |
20 | |
21 | #include "cpu/gemm/f32/gemm_utils_f32.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace gemm_utils { |
27 | #define BM_NOCOPY_AVX 64 |
28 | #define BN_NOCOPY_AVX 48 |
29 | #define BK_NOCOPY_AVX 384 |
30 | #define BN_LARGE_NOCOPY_AVX 192 |
31 | #define BM_SMALL_NOCOPY_AVX 16 |
32 | #define BN_SMALL_NOCOPY_AVX 1 |
33 | #define BK_SMALL_NOCOPY_AVX 4 |
34 | // Determine number of threads for each dimension of a 3-D partitioning |
35 | // algorithm based on input parameters |
36 | // m/n/k - First/second/third parameter for GEMM |
37 | // nthrs - total available number of threads |
38 | // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension |
39 | // BM/BN/BK - blocking values |
40 | void calc_nthr_nocopy_avx(dim_t m, dim_t n, dim_t k, int nthrs, int *nthrs_m, |
41 | int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN, dim_t *BK) { |
42 | |
43 | // Quick exit for single thread. |
44 | if (nthrs == 1) { |
45 | *nthrs_m = 1; |
46 | *nthrs_n = 1; |
47 | *nthrs_k = 1; |
48 | |
49 | *BM = m; |
50 | *BN = n; |
51 | *BK = k; |
52 | return; |
53 | } |
54 | |
55 | int nthr, nthr_m, nthr_n, nthr_k; |
56 | dim_t MB, NB, KB; |
57 | |
58 | nthr = nthrs; |
59 | nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX; |
60 | nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX; |
61 | nthr_k = 1; |
62 | |
63 | // Partition along K dimension |
64 | // - if threading allows having barriers (e.g. OMP) |
65 | // - if there is not enough parallelism along M or N |
66 | if (dnnl_thr_syncable()) { |
67 | int nthr_other = nthr_k = 1; |
68 | while ((nthr_m * nthr_n * nthr_other < nthr) |
69 | && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) { |
70 | nthr_other++; |
71 | if ((nthr / nthr_other) * nthr_other > 0.9 * nthr) |
72 | nthr_k = nthr_other; |
73 | } |
74 | } |
75 | nthr /= nthr_k; |
76 | |
77 | if (nthr_m == 1) nthr_n = nthr; |
78 | if (nthr_n == 1) nthr_m = nthr; |
79 | |
80 | // Simple partition reduction |
81 | while (nthr_m * nthr_n > nthr) |
82 | if (nthr_m > nthr_n) |
83 | nthr_m--; |
84 | else |
85 | nthr_n--; |
86 | while (nthr_m * nthr_n < nthr) |
87 | if (nthr_m < nthr_n) |
88 | nthr_m++; |
89 | else |
90 | nthr_n++; |
91 | |
92 | if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) { |
93 | |
94 | if (nthr_m <= nthr_n) { |
95 | nthr_m = (int)sqrt((double)nthr); |
96 | if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX) |
97 | nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX; |
98 | nthr_n = nthr / nthr_m; |
99 | |
100 | while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { |
101 | nthr_m--; |
102 | nthr_n = nthr / nthr_m; |
103 | } |
104 | } else { |
105 | nthr_n = (int)sqrt((double)nthr); |
106 | if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX) |
107 | nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX; |
108 | nthr_m = nthr / nthr_n; |
109 | |
110 | while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { |
111 | nthr_n--; |
112 | nthr_m = nthr / nthr_n; |
113 | } |
114 | } |
115 | } |
116 | |
117 | MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1; |
118 | MB -= MB % BM_SMALL_NOCOPY_AVX; |
119 | NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1; |
120 | NB -= NB % BN_SMALL_NOCOPY_AVX; |
121 | KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1; |
122 | KB -= KB % BK_SMALL_NOCOPY_AVX; |
123 | |
124 | if (MB * nthr_m > m) nthr_m = (m + MB - 1) / MB; |
125 | if (NB * nthr_n > n) nthr_n = (n + NB - 1) / NB; |
126 | if (KB * nthr_k > k) nthr_k = (k + KB - 1) / KB; |
127 | |
128 | *nthrs_m = nthr_m; |
129 | *nthrs_n = nthr_n; |
130 | *nthrs_k = nthr_k; |
131 | |
132 | *BM = MB; |
133 | *BN = NB; |
134 | *BK = KB; |
135 | } |
136 | #undef BM_NOCOPY_AVX |
137 | #undef BN_NOCOPY_AVX |
138 | #undef BK_NOCOPY_AVX |
139 | #undef BN_LARGE_NOCOPY_AVX |
140 | #undef BM_SMALL_NOCOPY_AVX |
141 | #undef BN_SMALL_NOCOPY_AVX |
142 | #undef BK_SMALL_NOCOPY_AVX |
143 | |
144 | #define BM_NOCOPY_AVX512_COMMON 32 |
145 | #define BN_NOCOPY_AVX512_COMMON 64 |
146 | #define BK_NOCOPY_AVX512_COMMON 192 |
147 | #define BN_LARGE_NOCOPY_AVX512_COMMON 192 |
148 | #define BM_SMALL_NOCOPY_AVX512_COMMON 16 |
149 | #define BN_SMALL_NOCOPY_AVX512_COMMON 1 |
150 | #define BK_SMALL_NOCOPY_AVX512_COMMON 4 |
151 | // Determine number of threads for each dimension of a 3-D partitioning |
152 | // algorithm based on input parameters |
153 | // m/n/k - First/second/third parameter for GEMM |
154 | // nthrs - total available number of threads |
155 | // nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension |
156 | // BM/BN/BK - blocking values |
157 | void calc_nthr_nocopy_avx512_common(dim_t m, dim_t n, dim_t k, int nthrs, |
158 | int *nthrs_m, int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN, |
159 | dim_t *BK) { |
160 | |
161 | // Quick exit for single thread. |
162 | if (nthrs == 1) { |
163 | *nthrs_m = 1; |
164 | *nthrs_n = 1; |
165 | *nthrs_k = 1; |
166 | |
167 | *BM = m; |
168 | *BN = n; |
169 | *BK = k; |
170 | return; |
171 | } |
172 | |
173 | int nthr, nthr_m, nthr_n, nthr_k = 1; |
174 | dim_t MB, NB, KB; |
175 | nthr = nthrs; |
176 | |
177 | int counter = 0; |
178 | float ratio_float = 1.; |
179 | int ratio = 1; |
180 | nthr = nthrs; |
181 | int nthr_m_gt_n; |
182 | |
183 | // Partition along K dimension |
184 | // - if threading allows having barriers (e.g. OMP) |
185 | // - if there is not enough parallelism along M or N |
186 | if (dnnl_thr_syncable()) { |
187 | if (n <= 2 * BN_NOCOPY_AVX512_COMMON |
188 | && m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr && k > m && k > n) { |
189 | nthr_k = k / BK_NOCOPY_AVX512_COMMON; |
190 | if (nthr_k > nthr / 4) nthr_k = nthr / 4; |
191 | if (nthr_k < 1) nthr_k = 1; |
192 | |
193 | while ((nthr_k > 1) && (nthr % nthr_k)) { |
194 | nthr_k--; |
195 | } |
196 | nthr /= nthr_k; |
197 | } else { |
198 | nthr_k = 1; |
199 | } |
200 | } |
201 | nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON; |
202 | nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON; |
203 | |
204 | if (nthr_m < 1) nthr_m = 1; |
205 | if (nthr_n < 1) nthr_n = 1; |
206 | |
207 | nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0; |
208 | ratio_float = (float)nthr_m / nthr_n; |
209 | |
210 | if (nthr_m_gt_n) |
211 | ratio = (int)ratio_float; |
212 | else |
213 | ratio = (int)(1. / ratio_float); |
214 | |
215 | // scale down nthr_m and nthr_n if they are too large |
216 | while (nthr_m * nthr_n > 4 * nthr) { |
217 | nthr_m /= 2; |
218 | nthr_n /= 2; |
219 | } |
220 | |
221 | if (nthr_m < 1) nthr_m = 1; |
222 | if (nthr_n < 1) nthr_n = 1; |
223 | |
224 | // Simple partition reduction |
225 | counter = 0; |
226 | while (nthr_m * nthr_n > nthr) { |
227 | if (nthr_m > nthr_n) { |
228 | if (counter < ratio) |
229 | nthr_m--; |
230 | else { |
231 | nthr_n--; |
232 | counter = -1; |
233 | } |
234 | } else { |
235 | if (counter < ratio) |
236 | nthr_n--; |
237 | else { |
238 | nthr_m--; |
239 | counter = -1; |
240 | } |
241 | } |
242 | counter++; |
243 | } |
244 | |
245 | // Simple partition increment |
246 | counter = 0; |
247 | while (nthr_m * nthr_n < 0.95 * nthr) { |
248 | if (nthr_m > nthr_n) { |
249 | if (counter < ratio) |
250 | nthr_m++; |
251 | else { |
252 | nthr_n++; |
253 | counter = -1; |
254 | } |
255 | } else { |
256 | if (counter < ratio) |
257 | nthr_n++; |
258 | else { |
259 | nthr_m++; |
260 | counter = -1; |
261 | } |
262 | } |
263 | counter++; |
264 | } |
265 | |
266 | // if nothing works out, then this should work |
267 | if ((nthr_m * nthr_n > nthr)) { |
268 | |
269 | if (nthr_m <= nthr_n) { |
270 | nthr_m = (int)sqrt((double)nthr); |
271 | if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) |
272 | / BM_SMALL_NOCOPY_AVX512_COMMON) |
273 | nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1) |
274 | / BM_SMALL_NOCOPY_AVX512_COMMON; |
275 | nthr_n = nthr / nthr_m; |
276 | |
277 | while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) { |
278 | nthr_m--; |
279 | nthr_n = nthr / nthr_m; |
280 | } |
281 | } else { |
282 | nthr_n = (int)sqrt((double)nthr); |
283 | if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) |
284 | / BN_SMALL_NOCOPY_AVX512_COMMON) |
285 | nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1) |
286 | / BN_SMALL_NOCOPY_AVX512_COMMON; |
287 | nthr_m = nthr / nthr_n; |
288 | |
289 | while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) { |
290 | nthr_n--; |
291 | nthr_m = nthr / nthr_n; |
292 | } |
293 | } |
294 | } |
295 | |
296 | MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1; |
297 | MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON; |
298 | NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1; |
299 | NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON; |
300 | KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1; |
301 | KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON; |
302 | |
303 | if (MB * nthr_m > m) nthr_m = (m + MB - 1) / MB; |
304 | if (NB * nthr_n > n) nthr_n = (n + NB - 1) / NB; |
305 | if (KB * nthr_k > k) nthr_k = (k + KB - 1) / KB; |
306 | |
307 | *nthrs_m = nthr_m; |
308 | *nthrs_n = nthr_n; |
309 | *nthrs_k = nthr_k; |
310 | |
311 | *BM = MB; |
312 | *BN = NB; |
313 | *BK = KB; |
314 | } |
315 | #undef BM_NOCOPY_AVX512_COMMON |
316 | #undef BN_NOCOPY_AVX512_COMMON |
317 | #undef BK_NOCOPY_AVX512_COMMON |
318 | #undef BN_LARGE_NOCOPY_AVX512_COMMON |
319 | #undef BM_SMALL_NOCOPY_AVX512_COMMON |
320 | #undef BN_SMALL_NOCOPY_AVX512_COMMON |
321 | #undef BK_SMALL_NOCOPY_AVX512_COMMON |
322 | |
323 | // Partition n values as equally as possible among nthr threads |
324 | // and set the offset (t_offset) and number of values (t_block) for ithr |
325 | // Assumption: 0 <= ithr < nthr |
326 | void partition_unit_diff( |
327 | int ithr, int nthr, dim_t n, dim_t *t_offset, dim_t *t_block) { |
328 | |
329 | dim_t band = n / nthr; |
330 | if (band == 0) band = 1; |
331 | dim_t tail = n - band * nthr; |
332 | if (tail < 0) tail = 0; |
333 | |
334 | if (ithr < tail) { |
335 | band++; |
336 | *t_offset = band * ithr; |
337 | *t_block = band; |
338 | } else { |
339 | *t_offset = band * ithr + tail; |
340 | *t_block = band; |
341 | } |
342 | |
343 | if (*t_offset >= n) { |
344 | *t_offset = 0; |
345 | *t_block = 0; |
346 | } |
347 | |
348 | if (*t_offset + *t_block > n) { *t_block = n - *t_offset; } |
349 | } |
350 | |
351 | // Sum the m*n values from p_src into p_dst, assuming the two-dimensional |
352 | // arrays have leading dimensions ld_src and ld_dst, respectively |
353 | template <typename data_t> |
354 | void sum_two_matrices(dim_t m, dim_t n, data_t *__restrict p_src, dim_t ld_src, |
355 | data_t *__restrict p_dst, dim_t ld_dst) { |
356 | |
357 | for (dim_t j = 0; j < n; j++) { |
358 | for (dim_t i = 0; i < m; i++) { |
359 | p_dst[i + j * ld_dst] += p_src[i + j * ld_src]; |
360 | } |
361 | } |
362 | } |
363 | |
364 | template void sum_two_matrices<float>(dim_t m, dim_t n, float *__restrict p_src, |
365 | dim_t ld_src, float *__restrict p_dst, dim_t ld_dst); |
366 | |
367 | template void sum_two_matrices<double>(dim_t m, dim_t n, |
368 | double *__restrict p_src, dim_t ld_src, double *__restrict p_dst, |
369 | dim_t ld_dst); |
370 | } // namespace gemm_utils |
371 | } // namespace cpu |
372 | } // namespace impl |
373 | } // namespace dnnl |
374 | |