1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_GEMM_GEMM_PACK_STORAGE_HPP
18#define CPU_X64_GEMM_GEMM_PACK_STORAGE_HPP
19
20#include <cstdint>
21
22#include "common/dnnl_thread.hpp"
23#include "common/utils.hpp"
24
25#include "cpu/x64/gemm/gemm_threading.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
// Identifies which GEMM operand a pack storage describes: A or B.
enum struct matrix_id { a, b };
33
/* View/manager over the storage backing a packed GEMM matrix ("pack").
 *
 * The pack is one contiguous, caller-provided buffer laid out as:
 *   [header_t][matrix sub-header: one slice_header_t per slice]
 *   [sums sub-header: same shape][matrix data blocks][sums data blocks]
 * The main header records the byte offsets of the two sub-header regions
 * and the total size, so the pack is self-describing. This object owns
 * no memory; it only overlays typed pointers onto the buffer passed to
 * set()/reset(). */
struct gemm_pack_storage_t {
    // Mutable accessors into the main header.
    gemm_threading_t &threading() { return header->threading; }
    matrix_id &which() { return header->which; }
    bool &has_row_sums() { return header->has_row_sums; }
    bool &has_col_sums() { return header->has_col_sums; }

    // Read-only counterparts.
    const gemm_threading_t &threading() const { return header->threading; }
    const matrix_id &which() const { return header->which; }
    const bool &has_row_sums() const { return header->has_row_sums; }
    const bool &has_col_sums() const { return header->has_col_sums; }

    // Total pack size in bytes (headers + data); valid after finalize().
    size_t size() const { return header->size; }
    // Raw pointer to the start of the underlying buffer.
    void *get() const { return static_cast<void *>(base); }
    // Rebind this object to a new buffer; header_t lives at offset 0.
    void set(void *data) {
        base = static_cast<char *>(data);
        header = static_cast<header_t *>(data);
    }

    // True when the matrix is kept unpacked as a single no-copy slice.
    bool single_nocopy() const {
        return (threading().copy == copy_type::no_copy);
    }

    int nthr() const { return single_nocopy() ? 1 : threading().nthrs(); }

    // Number of independent slices in the pack: A is partitioned over the
    // (m, k) thread grid, B over the (n, k) grid.
    int nslice() const {
        return (which() == matrix_id::a)
                ? threading().nthrs_m * threading().nthrs_k
                : threading().nthrs_n * threading().nthrs_k;
    }

    // Attach to an existing pack buffer. If header_set_ is false, the
    // header contents are not yet valid, so the sub-header pointers are
    // left unset until setup() initializes the offsets.
    template <typename data_type>
    gemm_pack_storage_t(data_type *data_, bool header_set_ = true)
        : base(nullptr), header_set(header_set_) {
        reset((void *)data_);
    }

    gemm_pack_storage_t()
        : base(nullptr)
        , header(nullptr)
        , matrix_header(nullptr)
        , sums_header(nullptr)
        , header_set(true) {}

    // Map a linear thread id to (slice id, position within the slice).
    // Threads are numbered m-fastest, then n, then k. For matrix A a
    // slice is shared by threads differing only in their n coordinate;
    // for B, by threads differing only in m.
    std::tuple<int, int> thread_slice_info(int ithr) const {
        assert(ithr < nthr());

        bool is_a = (which() == matrix_id::a);
        auto nthr_inner = is_a ? threading().nthrs_m : threading().nthrs_n;

        // Decompose the linear id into (m, n, k) grid coordinates.
        auto ithr_i = ithr % threading().nthrs_m;
        auto ithr_jk = ithr / threading().nthrs_m;
        auto ithr_j = ithr_jk % threading().nthrs_n;
        auto ithr_k = ithr_jk / threading().nthrs_n;

        auto ithr_inner = is_a ? ithr_i : ithr_j;
        auto ithr_outer = ithr_k;
        auto ithr_slice = is_a ? ithr_j : ithr_i;

        auto id = ithr_outer * nthr_inner + ithr_inner;

        return std::make_tuple(id, ithr_slice);
    }

    // Slice id assigned to the given thread.
    int thread_to_slice(int ithr) const {
        return std::get<0>(thread_slice_info(ithr));
    }

    bool is_first_thread_in_slice(int ithr) const {
        return (std::get<1>(thread_slice_info(ithr)) == 0);
    }

    // Row-sums block starting at row r0 for column block cblock, or NULL
    // if this pack carries no row sums.
    template <typename data_type>
    data_type *row_sums(int ithr, dim_t r0, dim_t cblock) const {
        if (!has_row_sums()) return NULL;
        auto id = thread_to_slice(ithr);
        return get_block<data_type>(sums_header->slice[id], r0, cblock);
    }

    // Column-sums block for row block rblock starting at column c0, or
    // NULL if this pack carries no column sums.
    template <typename data_type>
    data_type *col_sums(int ithr, dim_t rblock, dim_t c0) const {
        if (!has_col_sums()) return NULL;
        auto id = thread_to_slice(ithr);
        return get_block<data_type>(sums_header->slice[id], rblock, c0);
    }

    // Matrix data block containing element (r0, c0) of this thread's
    // slice.
    template <typename data_type>
    data_type *matrix(int ithr, dim_t r0, dim_t c0) const {
        auto id = thread_to_slice(ithr);
        return get_block<data_type>(matrix_header->slice[id], r0, c0);
    }

    // Whole matrix for this thread; only valid for no-copy slices.
    template <typename data_type>
    data_type *matrix(int ithr) const {
        assert(!matrix_header->slice[thread_to_slice(ithr)].packed);
        return matrix<data_type>(ithr, 0, 0);
    }

    template <typename data_type>
    data_type *matrix() const {
        assert(single_nocopy());
        return matrix<data_type>(0);
    }

    // Fetch no-copy parameters (transpose flag, leading/trailing dims)
    // for a thread's slice; returns false if the slice is packed, in
    // which case the outputs are untouched.
    bool get_nocopy(int ithr, int &trans, dim_t &ld, dim_t &td) const {
        auto id = thread_to_slice(ithr);
        return matrix_header->slice[id].get_nocopy(trans, ld, td);
    }

    bool get_nocopy(int &trans, dim_t &ld, dim_t &td) const {
        if (!single_nocopy()) return false;
        return get_nocopy(0, trans, ld, td);
    }

    void get_blocking(int ithr, dim_t &block_r, dim_t &block_c) const {
        auto id = thread_to_slice(ithr);
        matrix_header->slice[id].get_blocking(block_r, block_c);
    }

    // Record the blocking of a (rows x cols) slice. The matching sums
    // slice keeps the same block grid but is one element wide: a
    // row-sums block stores block_r values (block_c = 1), a column-sums
    // block stores block_c values (block_r = 1).
    void set_blocking(
            int ithr, dim_t rows, dim_t cols, dim_t block_r, dim_t block_c) {

        auto id = thread_to_slice(ithr);
        auto nblk_r = (block_r == 0) ? 0 : utils::div_up(rows, block_r);
        auto nblk_c = (block_c == 0) ? 0 : utils::div_up(cols, block_c);

        matrix_header->slice[id].set_blocking(nblk_r, nblk_c, block_r, block_c);

        if (has_row_sums())
            sums_header->slice[id].set_blocking(nblk_r, nblk_c, block_r, 1);
        else
            sums_header->slice[id].set_blocking(nblk_r, nblk_c, 1, block_c);
    }

    void set_nocopy(int ithr, int trans, dim_t ld, dim_t td) {
        auto id = thread_to_slice(ithr);
        matrix_header->slice[id].set_nocopy(trans, td, ld);
    }

    // Initialize the header region of a freshly attached buffer for up
    // to max_nthr threads: compute sub-header offsets, zero all
    // per-slice blocking, and leave the data size 0 until finalize().
    void setup(int max_nthr, bool has_row_sums = false,
            bool has_col_sums = false) {

        // A pack carries at most one kind of sums.
        assert(!(has_row_sums && has_col_sums));

        auto sz_mh = matrix_header_size(max_nthr);
        auto sz_h = header_size();

        header->has_row_sums = has_row_sums;
        header->has_col_sums = has_col_sums;
        header->off_matrix = sz_h;
        header->off_sums = sz_h + sz_mh;
        total_header_size = sz_h + sz_mh * 2;

        header->size = 0;

        header_set = true;

        // Offsets are valid now; derive the sub-header pointers.
        reset(get());

        for (int id = 0; id < max_nthr; id++) {
            matrix_header->slice[id].set_blocking(0, 0, 0, 0);
            sums_header->slice[id].set_blocking(0, 0, 0, 0);
        }
    }

    // Lay out the data blocks after the headers, assigning each slice
    // its data offset, and record the total buffer size needed.
    template <typename matrix_dt, typename sums_dt>
    void finalize() {
        assert(total_header_size > 0);
        size_t cur_off = total_header_size;

        matrix_header->finalize<matrix_dt>(cur_off, nslice());
        if (has_row_sums() || has_col_sums())
            sums_header->finalize<sums_dt>(cur_off, nslice());

        header->size = cur_off;

        /* Compute kernels overrun to preload data. */
        header->size += align_data;
    }

protected:
    char *base; // start of the pack buffer (not owned by this object)

    // Main header, stored at offset 0 of the buffer.
    struct header_t {
        matrix_id which;
        bool has_row_sums;
        bool has_col_sums;
        size_t off_matrix, off_sums; // byte offsets of the sub-headers
        size_t size; // total pack size in bytes
        gemm_threading_t threading; /* if packed */
    } * header;

    // Per-slice descriptor: either a blocked ("packed") layout, or a
    // plain no-copy matrix described by (trans, block_r = ld,
    // block_c = td).
    struct slice_header_t {
        bool packed;
        int trans;
        int nblk_r, nblk_c; // number of blocks in each dimension
        dim_t block_r, block_c; // block sizes (ld/td when !packed)
        size_t off_data; // byte offset of this slice's data in the buffer

        // Size of one block, padded up to the data alignment.
        template <typename data_type>
        size_t block_size() const {
            return utils::rnd_up(
                    block_r * block_c * sizeof(data_type), align_data);
        }

        // Byte offset, within the slice, of the block whose top-left
        // element is (r0, c0). Block order follows col_major: blocks are
        // traversed column-major for A, row-major for B (see
        // gemm_pack_storage_t::col_major()).
        template <typename data_type>
        size_t block_offset(dim_t r0, dim_t c0, bool col_major) const {
            assert((r0 % block_r) == 0);
            assert((c0 % block_c) == 0);

            auto rb = r0 / block_r;
            auto cb = c0 / block_c;
            auto mb = col_major ? rb + cb * nblk_r : cb + rb * nblk_c;

            return block_size<data_type>() * mb;
        }

        // Total data size of the slice.
        template <typename data_type>
        size_t size() const {
            return block_size<data_type>() * nblk_r * nblk_c;
        }

        void set_blocking(
                int nblk_r_, int nblk_c_, dim_t block_r_, dim_t block_c_) {
            packed = true;
            nblk_r = nblk_r_;
            nblk_c = nblk_c_;
            block_r = block_r_;
            block_c = block_c_;
        }

        // A no-copy slice is a single "block" spanning the whole matrix.
        void set_nocopy(int trans_, dim_t ld, dim_t td) {
            packed = false;
            trans = trans_;
            block_r = ld;
            block_c = td;
            nblk_r = 1;
            nblk_c = 1;
        }

        void get_blocking(dim_t &block_r_, dim_t &block_c_) const {
            block_r_ = block_r;
            block_c_ = block_c;
        }

        // Outputs are written only for no-copy slices; the return value
        // says whether they are valid.
        bool get_nocopy(int &trans_, dim_t &ld, dim_t &td) const {
            if (!packed) {
                trans_ = trans;
                ld = block_r;
                td = block_c;
            }
            return !packed;
        }

        // Assign this slice's data offset and advance cur_off past it.
        template <typename data_type>
        void finalize(size_t &cur_off) {
            cur_off = utils::rnd_up(cur_off, align_data);
            off_data = cur_off;
            cur_off += size<data_type>();
        }
    };

    // Sub-header: per-slice headers for the matrix data (or the sums).
    struct matrix_header_t {
        dim_t ld; /* if not packed */
        slice_header_t slice[1]; /* array of size nthr, if packed */

        // Lay out all slices' data sequentially starting at cur_off.
        template <typename data_type>
        void finalize(size_t &cur_off, int nslices) {
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
            // This, I hope, is a temporary workaround...
            // The reason for this special case is that in case of threadpool
            // threading this function may be called to estimate the amount of
            // memory needed when no threading information is actually
            // available. Hence, it needs to provide an upper bound.
            size_t max_off = cur_off;
            for (int id = 0; id < nslices; id++) {
                slice[id].finalize<data_type>(cur_off);
                if (id == 0) {
                    // Assume that slice[0] is the largest one.
                    size_t slice0_size = cur_off - max_off;
                    max_off += slice0_size * dnnl_get_max_threads();
                }
            }
            if (!threadpool_utils::get_active_threadpool() && nslices)
                // The std::max is a paranoid check for the case when slice[0]
                // is not actually the largest one. Probably a crash will
                // happen anyways...
                cur_off = std::max(cur_off, max_off);
#else
            for (int id = 0; id < nslices; id++)
                slice[id].finalize<data_type>(cur_off);
#endif
        }
    } * matrix_header, *sums_header;

    size_t total_header_size = 0; // headers-only size; set by setup()

    static constexpr auto align_headers = 0x20; // 32-byte header alignment
    static constexpr auto align_data = 0x1000; // 4 KiB data alignment

    static size_t header_size() {
        return utils::rnd_up(sizeof(header_t), align_headers);
    }

    // Size of one sub-header region with room for max_nthr slice
    // headers (slice[] already provides storage for one).
    static size_t matrix_header_size(int max_nthr) {
        auto sz = sizeof(matrix_header_t)
                + sizeof(slice_header_t) * (max_nthr - 1);

        return utils::rnd_up(sz, align_headers);
    }

    // Typed pointer to the data block of `slice` containing (r0, c0).
    template <typename data_type>
    data_type *get_block(
            const slice_header_t &slice, dim_t r0, dim_t c0) const {
        return reinterpret_cast<data_type *>(base + slice.off_data
                + slice.block_offset<data_type>(r0, c0, col_major()));
    }

    // Blocks of A are ordered column-major, blocks of B row-major.
    bool col_major() const { return (which() == matrix_id::a); }

    // Rebind to `data` and, when the header is already valid, recompute
    // the sub-header pointers from the offsets stored in it.
    void reset(void *data) {
        set(data);

        if (!header_set) return;

        matrix_header = reinterpret_cast<matrix_header_t *>(
                base + header->off_matrix);
        sums_header
                = reinterpret_cast<matrix_header_t *>(base + header->off_sums);
    }

    bool header_set = true; // is the attached buffer's header initialized?
};
366
367struct gemm_pack_storage_shell_t : public gemm_pack_storage_t {
368
369 gemm_pack_storage_shell_t(int max_nthr, bool has_row_sums = false,
370 bool has_col_sums = false) {
371 void *ptr = malloc(shell_size(max_nthr), 64);
372 if (ptr) {
373 set(ptr);
374 setup(max_nthr, has_row_sums, has_col_sums);
375 }
376 }
377
378 ~gemm_pack_storage_shell_t() { free(get()); }
379
380private:
381 static size_t shell_size(int max_nthr) {
382 return header_size() + matrix_header_size(max_nthr) * 2;
383 }
384};
385
386} // namespace x64
387} // namespace cpu
388} // namespace impl
389} // namespace dnnl
390
391#endif
392