1 | /******************************************************************************* |
2 | * Copyright 2019-2020 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_GEMM_GEMM_PARTITION_HPP |
18 | #define CPU_X64_GEMM_GEMM_PARTITION_HPP |
19 | |
20 | #include <array> |
21 | #include <cstdint> |
22 | #include <tuple> |
23 | |
24 | #include "common/nstl.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | static inline void partition_1d(const int ithr, const int nthrs, const dim_t n, |
33 | dim_t &t_offset, dim_t &t_block) { |
34 | |
35 | dim_t band = n / nthrs; |
36 | |
37 | dim_t tail = n - (nthrs - 1) * band; |
38 | if (tail > (band + 1)) band++; |
39 | tail = n - (nthrs - 1) * band; |
40 | |
41 | if (ithr < (nthrs - 1)) |
42 | t_block = band; |
43 | else |
44 | t_block = tail; |
45 | |
46 | t_offset = ithr * band; |
47 | |
48 | if (t_offset >= n) { |
49 | t_block = 0; |
50 | t_offset = 0; |
51 | } else if ((t_offset + t_block) > n) { |
52 | t_block = n - t_offset; |
53 | } |
54 | } |
55 | |
56 | static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i, |
57 | const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m, |
58 | const dim_t n, dim_t &out_m_disp, dim_t &out_m_band, dim_t &out_n_disp, |
59 | dim_t &out_n_band) { |
60 | |
61 | dim_t m_disp = 0, n_disp = 0; |
62 | dim_t m_band = 0, n_band = 0; |
63 | |
64 | int m_div = nthrs_m; |
65 | int n_div = nthrs_n; |
66 | |
67 | dim_t m_bandt = m / m_div; /* size per thread */ |
68 | dim_t n_bandt = n / n_div; /* size per thread */ |
69 | int first_m_group = m_div - 1; |
70 | int first_n_group = n_div - 1; |
71 | dim_t first_m_val = m_bandt; |
72 | dim_t first_n_val = n_bandt; |
73 | |
74 | int mthr_used = m_div; |
75 | if (m - (m_div - 1) * m_bandt > m_bandt + 1) { |
76 | if (m - (m_div - 1) * m_bandt > m_div) ++m_bandt; |
77 | |
78 | first_m_val = m_bandt + 1; |
79 | mthr_used = (int)(m / first_m_val); |
80 | |
81 | if (mthr_used * first_m_val < m) ++mthr_used; |
82 | |
83 | first_m_group = mthr_used - 1; |
84 | } |
85 | |
86 | int nthr_used = n_div; |
87 | if (n - (n_div - 1) * n_bandt > n_bandt + 1) { |
88 | first_n_val = n_bandt + 1; |
89 | nthr_used = (int)(n / first_n_val); |
90 | |
91 | if (nthr_used * first_n_val < n) ++nthr_used; |
92 | |
93 | first_n_group = nthr_used - 1; |
94 | } |
95 | |
96 | *nthrs = mthr_used * nthr_used; |
97 | |
98 | if (ithr < *nthrs) { |
99 | if (ithr_i < first_m_group) { |
100 | m_band = first_m_val; |
101 | m_disp = ithr_i * first_m_val; |
102 | } else if (ithr_i <= mthr_used - 2) { |
103 | m_band = m_bandt; |
104 | m_disp = first_m_group * first_m_val |
105 | + (ithr_i - first_m_group) * m_bandt; |
106 | } else { |
107 | m_disp = first_m_group * first_m_val |
108 | + (mthr_used - 1 - first_m_group) * m_bandt; |
109 | m_band = nstl::max(dim_t(0), m - m_disp); |
110 | } |
111 | |
112 | if (ithr_j < first_n_group) { |
113 | n_band = first_n_val; |
114 | n_disp = ithr_j * first_n_val; |
115 | } else if (ithr_j <= nthr_used - 2) { |
116 | n_band = n_bandt; |
117 | n_disp = first_n_group * first_n_val |
118 | + (ithr_j - first_n_group) * n_bandt; |
119 | } else { |
120 | n_disp = first_n_group * first_n_val |
121 | + (nthr_used - 1 - first_n_group) * n_bandt; |
122 | n_band = nstl::max(dim_t(0), n - n_disp); |
123 | } |
124 | m_disp = nstl::max(nstl::min(m_disp, m - 1), dim_t(0)); |
125 | n_disp = nstl::max(nstl::min(n_disp, n - 1), dim_t(0)); |
126 | } |
127 | |
128 | if (ithr < *nthrs) { |
129 | out_m_disp = m_disp; |
130 | out_n_disp = n_disp; |
131 | out_m_band = m_band; |
132 | out_n_band = n_band; |
133 | } else { |
134 | out_m_disp = 0; |
135 | out_n_disp = 0; |
136 | out_m_band = 0; |
137 | out_n_band = 0; |
138 | } |
139 | |
140 | return; |
141 | } |
142 | |
143 | static inline std::tuple<int, int> partition_2d_minblk_with_primes(dim_t m, |
144 | dim_t n, dim_t block_m, dim_t block_n, dim_t min_m, dim_t min_n, |
145 | dim_t um, dim_t un, int nthr, bool use_aspect_ratio) { |
146 | |
147 | auto part_m = nstl::max(dim_t(1), m / block_m); |
148 | auto part_n = nstl::max(dim_t(1), n / block_n); |
149 | |
150 | // Quick exit if there are enough partitions in one direction |
151 | // and there is only 1 partition in the other one |
152 | if (part_m == 1 && part_n >= nthr) |
153 | return std::make_tuple(1, nstl::min((int)part_n, nthr)); |
154 | |
155 | if (part_n == 1 && part_m >= nthr) |
156 | return std::make_tuple(nstl::min((int)part_m, nthr), 1); |
157 | |
158 | auto num_parts = part_m * part_n; |
159 | |
160 | int nthr_ite = nthr; |
161 | int nthr_m = 1, nthr_n = 1; |
162 | dim_t band_m = m, band_n = n; |
163 | |
164 | for (auto p : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29}) { |
165 | bool finished = false; |
166 | |
167 | while ((nthr_ite % p) == 0 && !finished) { |
168 | nthr_ite /= p; |
169 | auto nthr_m_ite = nthr_m * p; |
170 | auto nthr_n_ite = nthr_n * p; |
171 | |
172 | auto band_m_ite = band_m / p; |
173 | auto band_n_ite = band_n / p; |
174 | |
175 | // Try partitioning with block size bm x bn |
176 | auto try_partition = [&](dim_t bm, dim_t bn, bool pick_small) { |
177 | float ratio_m = (float)band_m_ite / bm; |
178 | float ratio_n = (float)band_n_ite / bn; |
179 | bool do_m = false, do_n = false; |
180 | |
181 | if (ratio_m < 1. && ratio_n >= 1.) |
182 | do_n = true; |
183 | else if (ratio_m >= 1. && ratio_n < 1.) |
184 | do_m = true; |
185 | else if (ratio_m >= 1. && ratio_n >= 1.) { |
186 | if (use_aspect_ratio) { |
187 | float ratio_goal = (float)um / un; |
188 | float try_ratio_m = (float)band_m_ite / band_n |
189 | * (1. / ratio_goal); |
190 | float try_ratio_n = (float)band_m / band_n_ite |
191 | * (1. / ratio_goal); |
192 | if (pick_small) { |
193 | // Pick either the smaller or larger ratio as appropriate. |
194 | ((ratio_m < ratio_n) ? do_m : do_n) = true; |
195 | } else { |
196 | // Pick the dimension that will keep as close as possible |
197 | // to best ratio between m and n. |
198 | ((nstl::abs(try_ratio_m - 1.) |
199 | < nstl::abs(try_ratio_n - 1)) |
200 | ? do_m |
201 | : do_n) |
202 | = true; |
203 | } |
204 | } else { |
205 | (((ratio_m < ratio_n) == pick_small) ? do_m : do_n) |
206 | = true; |
207 | } |
208 | } |
209 | |
210 | if (do_m) { |
211 | // Partition m. |
212 | nthr_m = nthr_m_ite; |
213 | band_m = band_m_ite; |
214 | } else if (do_n) { |
215 | // Partition n. |
216 | nthr_n = nthr_n_ite; |
217 | band_n = band_n_ite; |
218 | } |
219 | |
220 | return do_m || do_n; |
221 | }; |
222 | |
223 | // If we will need min based partitioning do it now |
224 | if (num_parts < nthr) { |
225 | num_parts *= p; |
226 | if (try_partition(min_m, min_n, true)) continue; |
227 | } |
228 | |
229 | if (try_partition(block_m, block_n, false)) continue; |
230 | if (try_partition(min_m, min_n, true)) continue; |
231 | |
232 | // Both band_m/n are smaller than min_m/n |
233 | // exit the loops, nothing to partition |
234 | finished = true; |
235 | } |
236 | |
237 | if (finished) break; |
238 | } |
239 | |
240 | return std::make_tuple(nthr_m, nthr_n); |
241 | } |
242 | |
243 | static inline std::tuple<int, int> partition_2d_minblk(dim_t m, dim_t n, |
244 | dim_t block_m, dim_t block_n, dim_t min_m, dim_t min_n, dim_t um, |
245 | dim_t un, int nthr, bool use_aspect_ratio) { |
246 | |
247 | auto part_m = nstl::max(dim_t(1), m / min_m); |
248 | auto part_n = nstl::max(dim_t(1), n / min_n); |
249 | |
250 | // Quick exit if one of the dimensions is too small to partition. |
251 | if (part_m == 1) { |
252 | part_n = nstl::max(dim_t(1), utils::div_up(n, min_n)); |
253 | return std::make_tuple(1, nstl::min((int)part_n, nthr)); |
254 | } |
255 | |
256 | if (part_n == 1) { |
257 | part_m = nstl::max(dim_t(1), utils::div_up(m, min_m)); |
258 | return std::make_tuple(nstl::min((int)part_m, nthr), 1); |
259 | } |
260 | |
261 | int nthr_m = 0, nthr_n = 0; |
262 | auto nthr_thresh = nstl::min(0.95 * nthr, (double)(part_m * part_n)); |
263 | |
264 | for (int nthr_new = nthr; nthr_new > nthr / 2; nthr_new--) { |
265 | if (nthr_m * nthr_n >= nthr_thresh) break; |
266 | std::tie(nthr_m, nthr_n) |
267 | = partition_2d_minblk_with_primes(m, n, block_m, block_n, min_m, |
268 | min_n, um, un, nthr_new, use_aspect_ratio); |
269 | } |
270 | |
271 | return std::make_tuple(nthr_m, nthr_n); |
272 | } |
273 | |
274 | } // namespace x64 |
275 | } // namespace cpu |
276 | } // namespace impl |
277 | } // namespace dnnl |
278 | |
279 | #endif |
280 | |