1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_GEMM_GEMM_PACK_STORAGE_HPP |
18 | #define CPU_X64_GEMM_GEMM_PACK_STORAGE_HPP |
19 | |
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <tuple>

#include "common/dnnl_thread.hpp"
#include "common/utils.hpp"

#include "cpu/x64/gemm/gemm_threading.hpp"
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
// Identifies which GEMM operand a packed buffer holds: matrix A or matrix B.
enum struct matrix_id { a, b };
33 | |
34 | struct gemm_pack_storage_t { |
35 | gemm_threading_t &threading() { return header->threading; } |
36 | matrix_id &which() { return header->which; } |
37 | bool &has_row_sums() { return header->has_row_sums; } |
38 | bool &has_col_sums() { return header->has_col_sums; } |
39 | |
40 | const gemm_threading_t &threading() const { return header->threading; } |
41 | const matrix_id &which() const { return header->which; } |
42 | const bool &has_row_sums() const { return header->has_row_sums; } |
43 | const bool &has_col_sums() const { return header->has_col_sums; } |
44 | |
45 | size_t size() const { return header->size; } |
46 | void *get() const { return static_cast<void *>(base); } |
47 | void set(void *data) { |
48 | base = static_cast<char *>(data); |
49 | header = static_cast<header_t *>(data); |
50 | } |
51 | |
52 | bool single_nocopy() const { |
53 | return (threading().copy == copy_type::no_copy); |
54 | } |
55 | |
56 | int nthr() const { return single_nocopy() ? 1 : threading().nthrs(); } |
57 | |
58 | int nslice() const { |
59 | return (which() == matrix_id::a) |
60 | ? threading().nthrs_m * threading().nthrs_k |
61 | : threading().nthrs_n * threading().nthrs_k; |
62 | } |
63 | |
64 | template <typename data_type> |
65 | gemm_pack_storage_t(data_type *data_, bool = true) |
66 | : base(nullptr), header_set(header_set_) { |
67 | reset((void *)data_); |
68 | } |
69 | |
70 | gemm_pack_storage_t() |
71 | : base(nullptr) |
72 | , header(nullptr) |
73 | , matrix_header(nullptr) |
74 | , sums_header(nullptr) |
75 | , header_set(true) {} |
76 | |
77 | std::tuple<int, int> thread_slice_info(int ithr) const { |
78 | assert(ithr < nthr()); |
79 | |
80 | bool is_a = (which() == matrix_id::a); |
81 | auto nthr_inner = is_a ? threading().nthrs_m : threading().nthrs_n; |
82 | |
83 | auto ithr_i = ithr % threading().nthrs_m; |
84 | auto ithr_jk = ithr / threading().nthrs_m; |
85 | auto ithr_j = ithr_jk % threading().nthrs_n; |
86 | auto ithr_k = ithr_jk / threading().nthrs_n; |
87 | |
88 | auto ithr_inner = is_a ? ithr_i : ithr_j; |
89 | auto ithr_outer = ithr_k; |
90 | auto ithr_slice = is_a ? ithr_j : ithr_i; |
91 | |
92 | auto id = ithr_outer * nthr_inner + ithr_inner; |
93 | |
94 | return std::make_tuple(id, ithr_slice); |
95 | } |
96 | |
97 | int thread_to_slice(int ithr) const { |
98 | return std::get<0>(thread_slice_info(ithr)); |
99 | } |
100 | |
101 | bool is_first_thread_in_slice(int ithr) const { |
102 | return (std::get<1>(thread_slice_info(ithr)) == 0); |
103 | } |
104 | |
105 | template <typename data_type> |
106 | data_type *row_sums(int ithr, dim_t r0, dim_t cblock) const { |
107 | if (!has_row_sums()) return NULL; |
108 | auto id = thread_to_slice(ithr); |
109 | return get_block<data_type>(sums_header->slice[id], r0, cblock); |
110 | } |
111 | |
112 | template <typename data_type> |
113 | data_type *col_sums(int ithr, dim_t rblock, dim_t c0) const { |
114 | if (!has_col_sums()) return NULL; |
115 | auto id = thread_to_slice(ithr); |
116 | return get_block<data_type>(sums_header->slice[id], rblock, c0); |
117 | } |
118 | |
119 | template <typename data_type> |
120 | data_type *matrix(int ithr, dim_t r0, dim_t c0) const { |
121 | auto id = thread_to_slice(ithr); |
122 | return get_block<data_type>(matrix_header->slice[id], r0, c0); |
123 | } |
124 | |
125 | template <typename data_type> |
126 | data_type *matrix(int ithr) const { |
127 | assert(!matrix_header->slice[thread_to_slice(ithr)].packed); |
128 | return matrix<data_type>(ithr, 0, 0); |
129 | } |
130 | |
131 | template <typename data_type> |
132 | data_type *matrix() const { |
133 | assert(single_nocopy()); |
134 | return matrix<data_type>(0); |
135 | } |
136 | |
137 | bool get_nocopy(int ithr, int &trans, dim_t &ld, dim_t &td) const { |
138 | auto id = thread_to_slice(ithr); |
139 | return matrix_header->slice[id].get_nocopy(trans, ld, td); |
140 | } |
141 | |
142 | bool get_nocopy(int &trans, dim_t &ld, dim_t &td) const { |
143 | if (!single_nocopy()) return false; |
144 | return get_nocopy(0, trans, ld, td); |
145 | } |
146 | |
147 | void get_blocking(int ithr, dim_t &block_r, dim_t &block_c) const { |
148 | auto id = thread_to_slice(ithr); |
149 | matrix_header->slice[id].get_blocking(block_r, block_c); |
150 | } |
151 | |
152 | void set_blocking( |
153 | int ithr, dim_t rows, dim_t cols, dim_t block_r, dim_t block_c) { |
154 | |
155 | auto id = thread_to_slice(ithr); |
156 | auto nblk_r = (block_r == 0) ? 0 : utils::div_up(rows, block_r); |
157 | auto nblk_c = (block_c == 0) ? 0 : utils::div_up(cols, block_c); |
158 | |
159 | matrix_header->slice[id].set_blocking(nblk_r, nblk_c, block_r, block_c); |
160 | |
161 | if (has_row_sums()) |
162 | sums_header->slice[id].set_blocking(nblk_r, nblk_c, block_r, 1); |
163 | else |
164 | sums_header->slice[id].set_blocking(nblk_r, nblk_c, 1, block_c); |
165 | } |
166 | |
167 | void set_nocopy(int ithr, int trans, dim_t ld, dim_t td) { |
168 | auto id = thread_to_slice(ithr); |
169 | matrix_header->slice[id].set_nocopy(trans, ld, td); |
170 | } |
171 | |
172 | void setup(int max_nthr, bool has_row_sums = false, |
173 | bool has_col_sums = false) { |
174 | |
175 | assert(!(has_row_sums && has_col_sums)); |
176 | |
177 | auto sz_mh = matrix_header_size(max_nthr); |
178 | auto sz_h = header_size(); |
179 | |
180 | header->has_row_sums = has_row_sums; |
181 | header->has_col_sums = has_col_sums; |
182 | header->off_matrix = sz_h; |
183 | header->off_sums = sz_h + sz_mh; |
184 | total_header_size = sz_h + sz_mh * 2; |
185 | |
186 | header->size = 0; |
187 | |
188 | header_set = true; |
189 | |
190 | reset(get()); |
191 | |
192 | for (int id = 0; id < max_nthr; id++) { |
193 | matrix_header->slice[id].set_blocking(0, 0, 0, 0); |
194 | sums_header->slice[id].set_blocking(0, 0, 0, 0); |
195 | } |
196 | } |
197 | |
198 | template <typename matrix_dt, typename sums_dt> |
199 | void finalize() { |
200 | assert(total_header_size > 0); |
201 | size_t cur_off = total_header_size; |
202 | |
203 | matrix_header->finalize<matrix_dt>(cur_off, nslice()); |
204 | if (has_row_sums() || has_col_sums()) |
205 | sums_header->finalize<sums_dt>(cur_off, nslice()); |
206 | |
207 | header->size = cur_off; |
208 | |
209 | /* Compute kernels overrun to preload data. */ |
210 | header->size += align_data; |
211 | } |
212 | |
213 | protected: |
214 | char *base; |
215 | |
216 | struct { |
217 | matrix_id ; |
218 | bool ; |
219 | bool ; |
220 | size_t , ; |
221 | size_t ; |
222 | gemm_threading_t ; /* if packed */ |
223 | } * ; |
224 | |
225 | struct { |
226 | bool ; |
227 | int ; |
228 | int , ; |
229 | dim_t , ; |
230 | size_t ; |
231 | |
232 | template <typename data_type> |
233 | size_t () const { |
234 | return utils::rnd_up( |
235 | block_r * block_c * sizeof(data_type), align_data); |
236 | } |
237 | |
238 | template <typename data_type> |
239 | size_t (dim_t r0, dim_t c0, bool col_major) const { |
240 | assert((r0 % block_r) == 0); |
241 | assert((c0 % block_c) == 0); |
242 | |
243 | auto rb = r0 / block_r; |
244 | auto cb = c0 / block_c; |
245 | auto mb = col_major ? rb + cb * nblk_r : cb + rb * nblk_c; |
246 | |
247 | return block_size<data_type>() * mb; |
248 | } |
249 | |
250 | template <typename data_type> |
251 | size_t () const { |
252 | return block_size<data_type>() * nblk_r * nblk_c; |
253 | } |
254 | |
255 | void ( |
256 | int nblk_r_, int nblk_c_, dim_t block_r_, dim_t block_c_) { |
257 | packed = true; |
258 | nblk_r = nblk_r_; |
259 | nblk_c = nblk_c_; |
260 | block_r = block_r_; |
261 | block_c = block_c_; |
262 | } |
263 | |
264 | void (int trans_, dim_t ld, dim_t td) { |
265 | packed = false; |
266 | trans = trans_; |
267 | block_r = ld; |
268 | block_c = td; |
269 | nblk_r = 1; |
270 | nblk_c = 1; |
271 | } |
272 | |
273 | void (dim_t &block_r_, dim_t &block_c_) const { |
274 | block_r_ = block_r; |
275 | block_c_ = block_c; |
276 | } |
277 | |
278 | bool (int &trans_, dim_t &ld, dim_t &td) const { |
279 | if (!packed) { |
280 | trans_ = trans; |
281 | ld = block_r; |
282 | td = block_c; |
283 | } |
284 | return !packed; |
285 | } |
286 | |
287 | template <typename data_type> |
288 | void (size_t &cur_off) { |
289 | cur_off = utils::rnd_up(cur_off, align_data); |
290 | off_data = cur_off; |
291 | cur_off += size<data_type>(); |
292 | } |
293 | }; |
294 | |
295 | struct { |
296 | dim_t ; /* if not packed */ |
297 | slice_header_t [1]; /* array of size nthr, if packed */ |
298 | |
299 | template <typename data_type> |
300 | void (size_t &cur_off, int nslices) { |
301 | #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL |
302 | // This, I hope, is a temporary workaround... |
303 | // The reason for this special case is that in case of threadpool |
304 | // threading this function may be called to estimate the amount of |
305 | // memory needed when no threading information is actually |
306 | // available. Hence, it needs to provide an upper bound. |
307 | size_t max_off = cur_off; |
308 | for (int id = 0; id < nslices; id++) { |
309 | slice[id].finalize<data_type>(cur_off); |
310 | if (id == 0) { |
311 | // Assume that slice[0] is the largest one. |
312 | size_t slice0_size = cur_off - max_off; |
313 | max_off += slice0_size * dnnl_get_max_threads(); |
314 | } |
315 | } |
316 | if (!threadpool_utils::get_active_threadpool() && nslices) |
317 | // The std::max is a paranoid check for the case when slice[0] |
318 | // is not actually the largest one. Probably a crash will |
319 | // happen anyways... |
320 | cur_off = std::max(cur_off, max_off); |
321 | #else |
322 | for (int id = 0; id < nslices; id++) |
323 | slice[id].finalize<data_type>(cur_off); |
324 | #endif |
325 | } |
326 | } * , *; |
327 | |
328 | size_t = 0; |
329 | |
330 | static constexpr auto = 0x20; |
331 | static constexpr auto align_data = 0x1000; |
332 | |
333 | static size_t () { |
334 | return utils::rnd_up(sizeof(header_t), align_headers); |
335 | } |
336 | |
337 | static size_t (int max_nthr) { |
338 | auto sz = sizeof(matrix_header_t) |
339 | + sizeof(slice_header_t) * (max_nthr - 1); |
340 | |
341 | return utils::rnd_up(sz, align_headers); |
342 | } |
343 | |
344 | template <typename data_type> |
345 | data_type *( |
346 | const slice_header_t &slice, dim_t r0, dim_t c0) const { |
347 | return reinterpret_cast<data_type *>(base + slice.off_data |
348 | + slice.block_offset<data_type>(r0, c0, col_major())); |
349 | } |
350 | |
351 | bool col_major() const { return (which() == matrix_id::a); } |
352 | |
353 | void reset(void *data) { |
354 | set(data); |
355 | |
356 | if (!header_set) return; |
357 | |
358 | matrix_header = reinterpret_cast<matrix_header_t *>( |
359 | base + header->off_matrix); |
360 | sums_header |
361 | = reinterpret_cast<matrix_header_t *>(base + header->off_sums); |
362 | } |
363 | |
364 | bool = true; |
365 | }; |
366 | |
367 | struct gemm_pack_storage_shell_t : public gemm_pack_storage_t { |
368 | |
369 | gemm_pack_storage_shell_t(int max_nthr, bool has_row_sums = false, |
370 | bool has_col_sums = false) { |
371 | void *ptr = malloc(shell_size(max_nthr), 64); |
372 | if (ptr) { |
373 | set(ptr); |
374 | setup(max_nthr, has_row_sums, has_col_sums); |
375 | } |
376 | } |
377 | |
378 | ~gemm_pack_storage_shell_t() { free(get()); } |
379 | |
380 | private: |
381 | static size_t shell_size(int max_nthr) { |
382 | return header_size() + matrix_header_size(max_nthr) * 2; |
383 | } |
384 | }; |
385 | |
386 | } // namespace x64 |
387 | } // namespace cpu |
388 | } // namespace impl |
389 | } // namespace dnnl |
390 | |
391 | #endif |
392 | |