1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
#include <cpuinfo.h>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include "fbgemm/Fbgemm.h"
13
14/*
15 * We pass in weights for Fully-connected and Convolution layers as B matrix.
16 * Since weights are constant during inference, B matrix is constant
17 * during inference so it's packed once and used multiple times. The code in
18 * this file takes care of fully packing B matrix. Fully packing means dividing
19 * the whole B matrix into blocks and storing all the blocks in the packed
20 * buffer instead of just 1 or some blocks.
21 *
22 * Packing refers to the rearranging of B elements to make it suitable to the
23 * way we access B in the inner compute kernel.
24 *
25 * Packing of B is dependent on three parameters: KCB, NCB and ROW_INTERLEAVE.
26 *
27 * Note 1: B is assumed to be in row-major format with K
28 * rows and N columns, i.e., the following B matrix with 3 rows and 5 columns
29 *
30 * B Matrix:
31 * b00 b01 b02 b03 b04
32 * b10 b11 b12 b13 b14
33 * b20 b21 b22 b23 b24
34 *
 * is laid out in memory as follows:
36 *
37 * B layout in memory (row major):
38 * b00 b01 b02 b03 b04 b10 b11 b12 b13 b14 b20 b21 b22 b23 b24
39 *
 * Note 2: KCB is always restricted/expected to be a multiple of ROW_INTERLEAVE
 * and thus its minimum value is equal to ROW_INTERLEAVE.
42 *
43 * Note 3: ROW_INTERLEAVE is 2 for when we accumulate into 16-bits and 4 for
44 * when we accumulate into 32-bits.
45 *
46 * Note 4: Minimum value of NCB is such that the number of bits in
47 * NCB*ROW_INTERLEAVE elements at the very minimum is equal to the vector length
48 * (i.e., 256 for avx2 and 512 for avx512).
49 *
50 * Minimum NCB value for int8 data type:
51 * avx2 avx512
52 * acc16 16 32
53 * acc32 8 16
54 *
55 * Packing examples:
56 * Let us assume KCB=4, NCB=6 and ROW_INTERLEAVE=4 for the following examples.
57 * To keep things manageable in the examples, NCB is 6 which is less than the
58 * minimum value allowed for NCB as per the table above.
59 *
60 * * * * * * * * * * * * * * * * * * * *
61 *
62 * Example 1:
63 * Original B is an 8x4 matrix as follows:
64 * b00 b01 b02 b03
65 * b10 b11 b12 b13
66 * b20 b21 b22 b23
67 * b30 b31 b32 b33
68 * b40 b41 b42 b43
69 * b50 b51 b52 b53
70 * b60 b61 b62 b63
71 * b70 b71 b72 b73
72 *
73 * Packed matrix has 2 tiles along rows and 1 tile along columns. So
74 * allocated/needed memory for B buffer is (2*4)*(1*6) elements.
75 *
76 * Packed B matrix looks like as follows:
77 *
78 * b00 b10 b20 b30 b01 b11 b21 b31 b02 b12 b22 b32 b03 b13 b23 b33 x x x x x \
79 * x x x | b40 b50 b60 b70 b41 b51 b61 b71 b42 b52 b62 b72 b43 b53 b63 b73 x x \
80 * x x x x x x
81 *
 * Groups of ROW_INTERLEAVE consecutive rows are interleaved column by column and laid out sequentially.
83 *
84 * ("x" indicates uninitialized locations)
85 * ("|" indicates start of the next block; A block here refers to KCB*NCB
86 * elements.)
87 * ("\" indicates that the elements continue on the next line)
88 * (block 1 of size KCB*NCB directly follows block 0 of the same size)
89 *
90 * * * * * * * * * * * * * * * * * * * *
91 *
92 * Example 2:
93 * Original B is a 3x4 matrix as follows:
94 * b00 b01 b02 b03
95 * b10 b11 b12 b13
96 * b20 b21 b22 b23
97 *
98 * Packed matrix has 1 tile along rows and 1 tile along columns. So
99 * allocated/needed memory for B buffer is (1*4)*(1*6) elements.
100 *
101 * Packed B matrix looks like as follows:
102 *
103 * b00 b10 b20 0 b01 b11 b21 0 b02 b12 b22 0 b03 b13 b23 0 x x x x x x x x
104 *
105 * If a tile along rows has less than ROW_INTERLEAVE rows, interleaved elements
106 * are zero initialized.
107 *
108 * * * * * * * * * * * * * * * * * * * *
109 *
110 * Example 3:
111 * Original B is a 5x4 matrix as follows:
112 * b00 b01 b02 b03
113 * b10 b11 b12 b13
114 * b20 b21 b22 b23
115 * b30 b31 b32 b33
116 * b40 b41 b42 b43
117 *
118 * Packed matrix has 2 tiles along rows and 1 tile along columns. So
119 * allocated/needed memory for B buffer is (2*4)*(1*6) elements.
120 *
121 * Packed B matrix looks like as follows:
122 *
123 * b00 b10 b20 b30 b01 b11 b21 b31 b02 b12 b22 b32 b03 b13 b23 b33 x x x x x \
 * x x x | b40 0 0 0 b41 0 0 0 b42 0 0 0 b43 0 0 0 x x x x x x x x
125 *
126 * * * * * * * * * * * * * * * * * * * *
127 *
128 * Example 4:
129 * Original B is a 4x7 matrix as follows:
130 * b00 b01 b02 b03 b04 b05 b06
131 * b10 b11 b12 b13 b14 b15 b16
132 * b20 b21 b22 b23 b24 b25 b26
133 * b30 b31 b32 b33 b34 b35 b36
134 *
135 * Packed matrix has 1 tile along rows and 2 tiles along columns. So
136 * allocated/needed memory for B buffer is (1*4)*(2*6) elements.
137 *
138 * Packed B matrix looks like as follows:
139 *
 * b00 b10 b20 b30 b01 b11 b21 b31 b02 b12 b22 b32 b03 b13 b23 b33 b04 b14 \
141 * b24 b34 b05 b15 b25 b35 | b06 b16 b26 b36 x x x x x x x x x x x x x x x x x \
142 * x x x
143 *
144 * * * * * * * * * * * * * * * * * * * *
145 *
146 * Example 5:
147 * Original B is a 5x7 matrix as follows:
148 * b00 b01 b02 b03 b04 b05 b06
149 * b10 b11 b12 b13 b14 b15 b16
150 * b20 b21 b22 b23 b24 b25 b26
151 * b30 b31 b32 b33 b34 b35 b36
152 * b40 b41 b42 b43 b44 b45 b46
153 *
154 * Packed matrix has 2 tiles along rows and 2 tiles along columns. So
155 * allocated/needed memory for B buffer is (2*4)*(2*6) elements.
156 *
157 * Packed B matrix looks like as follows:
158 *
159 * b00 b10 b20 b30 b01 b11 b21 b31 b02 b12 b22 b32 b03 b13 b23 b33 b04 b14 \
160 * b24 b34 b05 b15 b25 b35 | b06 b16 b26 b36 x x x x x x x x x x x x x x x x x \
161 * x x x | b40 0 0 0 b41 0 0 0 b42 0 0 0 b43 0 0 0 b44 0 0 0 b45 0 0 0 | b46 0 \
 * 0 0 x x x x x x x x x x x x x x x x x x x x
163 *
164 * The kernel expects the B matrix to be packed in the way mentioned above for
165 * correct operation.
166 */
167
168namespace fbgemm {
169
170template <typename T, typename accT>
171PackBMatrix<T, accT>::PackBMatrix(
172 matrix_op_t trans,
173 int32_t nRow,
174 int32_t nCol,
175 const T* smat,
176 int32_t ld,
177 inpType* pmat,
178 int groups,
179 const BlockingFactors* params)
180 : PackMatrix<PackBMatrix<T, accT>, T, accT>(
181 nRow,
182 nCol,
183 pmat,
184 groups,
185 params),
186 trans_(trans),
187 smat_(smat),
188 ld_(ld) {
189 if (!cpuinfo_initialize()) {
190 throw std::runtime_error("Failed to initialize cpuinfo!");
191 }
192 if (params) {
193 BaseType::brow_ = params->KCB;
194 BaseType::bcol_ = params->NCB;
195 row_interleave_ = params->ROW_INTERLEAVE;
196 } else {
197 const inst_set_t isa = fbgemmInstructionSet();
198 switch (isa) {
199 case inst_set_t::avx512_vnni:
200 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_) =
201 PackingTraits<T, accT, inst_set_t::avx512_vnni>::
202 getMatrixPackBParams();
203 break;
204
205 case inst_set_t::avx512_vnni_ymm:
206 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_) =
207 PackingTraits<T, accT, inst_set_t::avx512_vnni_ymm>::
208 getMatrixPackBParams();
209 break;
210
211 case inst_set_t::avx512:
212 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_) =
213 PackingTraits<T, accT, inst_set_t::avx512>::getMatrixPackBParams();
214 break;
215
216 case inst_set_t::avx512_ymm:
217 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_) =
218 PackingTraits<T, accT, inst_set_t::avx512_ymm>::
219 getMatrixPackBParams();
220 break;
221
222 case inst_set_t::avx2:
223 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_) =
224 PackingTraits<T, accT, inst_set_t::avx2>::getMatrixPackBParams();
225 break;
226
227 default:
228 assert(0 && "unknown architecure");
229 throw std::runtime_error("unknown architecure");
230 }
231 }
232
233 if (BaseType::numRows() % groups != 0) {
234 throw std::runtime_error(
235 "groups = " + std::to_string(groups) +
236 " does not divide numRows = " + std::to_string(BaseType::numRows()));
237 }
238
239 // blocking for one group
240 block_type_t block{
241 0, BaseType::numRows() / BaseType::numGroups(), 0, BaseType::numCols()};
242 BaseType::packedBlock(block);
243 if (!pmat) {
244 BaseType::bufAllocatedHere_ = true;
245 BaseType::buf_ = static_cast<T*>(fbgemmAlignedAlloc(
246 64,
247 BaseType::numGroups() * BaseType::blockRows() * BaseType::brow_ *
248 BaseType::blockCols() * BaseType::bcol_ * sizeof(T)));
249 }
250 pack(block, params);
251}
252
/// Copies elements between an unpacked B matrix (`unpack_buf`) and its packed
/// representation (`pack_buf`), once per group.
///
/// @param block      Region to copy (rows/cols of one group). row_start and
///                   col_start must be multiples of the block row/column
///                   sizes (asserted below).
/// @param unpack_buf Unpacked matrix, addressed with leading dimension ld_.
///                   For group g, element (i, j) lives at
///                   unpack_buf[i + (g * col_size + j) * ld_] when trans_ is
///                   Transpose, otherwise at
///                   unpack_buf[(g * row_size + i) * ld_ + j].
/// @param pack_buf   Packed buffer; group g starts at
///                   g * packedBufferSize(row_size, col_size, params).
/// @param ispack     true: copy unpack_buf -> pack_buf and zero the rows that
///                   pad the last partial ROW_INTERLEAVE group;
///                   false: copy pack_buf -> unpack_buf.
/// @param params     Blocking factors forwarded to packedBufferSize.
template <typename T, typename accT>
void PackBMatrix<T, accT>::pack_unpack_(
    const block_type_t& block,
    T* unpack_buf,
    T* pack_buf,
    bool ispack,
    const BlockingFactors* params) {
  assert((BaseType::blockRowSize() % row_interleave_) == 0);
  assert((block.row_start % BaseType::blockRowSize()) == 0);
  assert((block.col_start % BaseType::blockColSize()) == 0);

  // When T is char *, type-based alias analysis (TBAA) cannot prove
  // that `unpack_buf` and `pack_buf` do not alias `block` (because
  // char * is the one exception to the C++ strict aliasing rule), so the
  // compiler would have to re-load these attributes from `block` on
  // every loop iteration for correctness. We know better, so let's
  // help the compiler out by doing the loads ourselves into
  // constants.
  const auto blockRowStart = block.row_start;
  const auto blockRowSize = block.row_size;
  const auto blockColStart = block.col_start;
  const auto blockColSize = block.col_size;

  BaseType::packedBlock(block);
  bool tr = (trans_ == matrix_op_t::Transpose);
  for (int g = 0; g < BaseType::numGroups(); ++g) {
    // Start of group g's region inside the packed buffer.
    T* pack_buf_cur = pack_buf +
        g * BaseType::packedBufferSize(blockRowSize, blockColSize, params);
    for (int i = blockRowStart; i < blockRowStart + blockRowSize; ++i) {
      // Packed offset contributed by row i:
      //  - skip every tile of the preceding block rows (blockCols() tiles of
      //    blockRowSize() * blockColSize() elements each),
      //  - skip the preceding ROW_INTERLEAVE row groups inside the tile (each
      //    spans blockColSize() * row_interleave_ elements),
      //  - add the position of row i within its interleaved row group.
      int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
              (BaseType::blockRowSize() * BaseType::blockColSize()) +
          (i % BaseType::blockRowSize() / row_interleave_) *
              BaseType::blockColSize() * row_interleave_ +
          i % row_interleave_;

      // Packed offset of column blockColStart: its tile, plus its position
      // inside the tile (consecutive columns are row_interleave_ apart).
      int c_start_offset = (blockColStart / BaseType::blockColSize()) *
              BaseType::blockRowSize() * BaseType::blockColSize() +
          (blockColStart % BaseType::blockColSize()) * row_interleave_;

      // Column position relative to blockColStart, tracked incrementally as
      // (tile index, index inside tile) to avoid a div/mod per element.
      int c_idx_offset = 0;
      int c_blk_offset = 0;
      for (int j = blockColStart; j < blockColStart + blockColSize; ++j) {
        // int c_offset = (j / BaseType::blockColSize()) *
        //     BaseType::blockRowSize() * BaseType::blockColSize() +
        //     (j % BaseType::blockColSize()) * row_interleave_;
        // 1. Loop invariant hoisting (move block offset calculation out of
        // inner loop); 2. Strength reduction (change modulus in inner loop to
        // an increment + rollover).
        int c_offset = c_start_offset +
            c_blk_offset * BaseType::blockRowSize() * BaseType::blockColSize() +
            c_idx_offset * row_interleave_;

        if (ispack) {
          pack_buf_cur[r_offset + c_offset] = tr
              ? unpack_buf[i + (g * blockColSize + j) * ld_]
              : unpack_buf[(g * blockRowSize + i) * ld_ + j];
        } else {
          T* unpack_buf_cur = tr
              ? &(unpack_buf[i + (g * blockColSize + j) * ld_])
              : &(unpack_buf[(g * blockRowSize + i) * ld_ + j]);
          *unpack_buf_cur = pack_buf_cur[r_offset + c_offset];
        }

        // Advance to the next column; roll over into the next tile when the
        // end of the current one is reached.
        c_idx_offset++;
        if (c_idx_offset == BaseType::blockColSize()) {
          c_idx_offset = 0;
          c_blk_offset++;
        }
      }
    }
    if (ispack) {
      // fill the remaining with zero.
      // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
      // Only the rows that complete the last ROW_INTERLEAVE group (up to the
      // next multiple of row_interleave_) are zeroed; the rest of the tile
      // is left untouched.
      for (int i = blockRowStart + blockRowSize;
           i < (blockRowStart + blockRowSize + row_interleave_ - 1) /
               row_interleave_ * row_interleave_;
           ++i) {
        int r_offset =
            ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
                (BaseType::blockRowSize() * BaseType::blockColSize()) +
            (i % BaseType::blockRowSize() / row_interleave_) *
                BaseType::blockColSize() * row_interleave_ +
            i % row_interleave_;
        for (int j = blockColStart; j < blockColStart + blockColSize; j++) {
          int c_offset = (j / BaseType::blockColSize()) *
                  BaseType::blockRowSize() * BaseType::blockColSize() +
              (j % BaseType::blockColSize()) * row_interleave_;

          int out_idx = r_offset + c_offset;
          pack_buf_cur[out_idx] = 0;
        }
      }
    }
  } // for each group
}
348
349template <typename T, typename accT>
350void PackBMatrix<T, accT>::pack(
351 const block_type_t& block,
352 const BlockingFactors* params) {
353 pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params);
354}
355
356template <typename T, typename accT>
357void PackBMatrix<T, accT>::unpack(
358 T* origin_buf,
359 const BlockingFactors* params) {
360 block_type_t blockB{
361 BaseType::packedRowStart(),
362 BaseType::numPackedRows(),
363 BaseType::packedColStart(),
364 BaseType::numPackedCols()};
365 pack_unpack_(blockB, origin_buf, BaseType::getBuf(), false, params);
366}
367
368template <typename T, typename accT>
369int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
370 int32_t block_row_id = r / BaseType::blockRowSize();
371 int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
372 (BaseType::blockRowSize() * BaseType::blockColSize());
373
374 int32_t block_col_id = c / BaseType::blockColSize();
375 int32_t bcol_offset =
376 block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
377 int32_t block_offset = brow_offset + bcol_offset;
378 int32_t inblock_offset = (r % BaseType::blockRowSize() / row_interleave_) *
379 BaseType::blockColSize() * row_interleave_ +
380 (c % BaseType::blockColSize()) * row_interleave_ + r % row_interleave_;
381
382 int32_t index = block_offset + inblock_offset;
383
384 return index;
385}
386
387template <typename T, typename accT>
388void PackBMatrix<T, accT>::printPackedMatrix(
389 std::string name,
390 const BlockingFactors* params) {
391 std::cout << name << ":"
392 << "[" << BaseType::numPackedRows() << ", "
393 << BaseType::numPackedCols() << "]" << std::endl;
394 std::cout << "block size:"
395 << "[" << BaseType::blockRowSize() << ", "
396 << BaseType::blockColSize() << "]" << std::endl;
397
398 for (int g = 0; g < BaseType::numGroups(); ++g) {
399 T* out = BaseType::getBuf() +
400 g *
401 BaseType::packedBufferSize(
402 BaseType::numPackedRows(), BaseType::numPackedCols(), params);
403 std::cout << "group: " << g << std::endl;
404 for (auto nr = 0; nr < BaseType::blockRows(); ++nr) {
405 auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow()
406 : BaseType::blockRowSize();
407 for (auto nc = 0; nc < BaseType::blockCols(); ++nc) {
408 std::cout << "block:" << nr << ", " << nc << std::endl;
409 auto cols = (nc == BaseType::blockCols() - 1)
410 ? BaseType::lastBcol()
411 : BaseType::blockColSize();
412 for (auto r = 0; r < (rows + row_interleave_ - 1) / row_interleave_;
413 ++r) {
414 for (auto c = 0; c < cols * row_interleave_; ++c) {
415 T val =
416 out[nr * BaseType::blockCols() * BaseType::blockRowSize() *
417 BaseType::blockColSize() +
418 nc * BaseType::blockRowSize() * BaseType::blockColSize() +
419 r * BaseType::blockColSize() * row_interleave_ + c];
420 if (std::is_integral<T>::value) {
421 // cast to int64 because cout doesn't print int8_t type directly
422 std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
423 } else {
424 std::cout << std::setw(5) << val << " ";
425 }
426 }
427 std::cout << std::endl;
428 }
429 std::cout << std::endl;
430 }
431 }
432 }
433}
434
435template <typename T, typename accT>
436bool PackBMatrix<T, accT>::metaEquals(const PackBMatrix<T, accT>& that) const {
437 if (BaseType::numRows() != that.numRows() ||
438 BaseType::numCols() != that.numCols() ||
439 BaseType::blockRowSize() != that.blockRowSize() ||
440 BaseType::blockColSize() != that.blockColSize() ||
441 BaseType::blockRows() != that.blockRows() ||
442 BaseType::blockCols() != that.blockCols() ||
443 BaseType::numPackedRows() != that.numPackedRows() ||
444 BaseType::numPackedCols() != that.numPackedCols() ||
445 trans_ != that.trans_ || BaseType::numGroups() != that.numGroups() ||
446 row_interleave_ != that.row_interleave_) {
447 return false;
448 }
449
450 return true;
451}
452
453template <typename T, typename accT>
454bool PackBMatrix<T, accT>::equals(const PackBMatrix<T, accT>& that) const {
455 if (!metaEquals(that)) {
456 return false;
457 }
458
459 for (int i = 0; i < this->numRows(); ++i) {
460 for (int j = 0; j < this->numCols(); ++j) {
461 if (this->buf_[addr(i, j)] != that.buf_[that.addr(i, j)]) {
462 return false;
463 }
464 }
465 }
466
467 return true;
468}
469
// Explicit instantiations provided by this translation unit: int8 B matrices
// packed for 32-bit and for 16-bit accumulation.
template class PackBMatrix<int8_t, int32_t>;
template class PackBMatrix<int8_t, int16_t>;
472} // namespace fbgemm
473