1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8
#include <cassert>
#include <cstdint>
#include <iomanip>
#include <limits>
#include <stdexcept>
#include <type_traits>

#include <cpuinfo.h>

#include "fbgemm/Fbgemm.h"
14
15namespace fbgemm {
16
17template <typename PT, typename inpType, typename accType>
18PackMatrix<PT, inpType, accType>::PackMatrix(
19 int32_t rows,
20 int32_t cols,
21 inpType* buf,
22 int groups,
23 const BlockingFactors* params)
24 : buf_(buf), nrows_(rows), ncols_(cols), G_(groups) {
25 bufAllocatedHere_ = false;
26 blocking_params = params;
27 if (!cpuinfo_initialize()) {
28 throw std::runtime_error("Failed to initialize cpuinfo!");
29 }
30}
31
32template <typename PT, typename inpType, typename accType>
33int PackMatrix<PT, inpType, accType>::packedBufferSize(
34 int rows,
35 int cols,
36 const BlockingFactors* params) {
37 if (!cpuinfo_initialize()) {
38 throw std::runtime_error("Failed to initialize cpuinfo!");
39 }
40 if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
41 !fbgemmHasAvx2Support())) {
42 assert(0 && "unknown architecure");
43 }
44
45 int MCB, KCB, NCB;
46 if (params) {
47 MCB = params->MCB;
48 NCB = params->NCB;
49 KCB = params->KCB;
50 } else {
51 const inst_set_t isa = fbgemmInstructionSet();
52 switch (isa) {
53 case inst_set_t::avx512_vnni:
54 MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB;
55 NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB;
56 KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB;
57 break;
58
59 case inst_set_t::avx512_vnni_ymm:
60 MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni_ymm>::MCB;
61 NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni_ymm>::NCB;
62 KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni_ymm>::KCB;
63 break;
64
65 case inst_set_t::avx512:
66 MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
67 NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
68 KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
69 break;
70
71 case inst_set_t::avx512_ymm:
72 MCB = PackingTraits<inpType, accType, inst_set_t::avx512_ymm>::MCB;
73 NCB = PackingTraits<inpType, accType, inst_set_t::avx512_ymm>::NCB;
74 KCB = PackingTraits<inpType, accType, inst_set_t::avx512_ymm>::KCB;
75 break;
76
77 case inst_set_t::avx2:
78 MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB;
79 NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB;
80 KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB;
81 break;
82
83 default:
84 assert(0 && "unknown architecure");
85 throw std::runtime_error("unknown architecure");
86 }
87 }
88
89 if (isA()) {
90 return MCB * KCB;
91 } else {
92 int rowBlock = KCB;
93 int colBlock = NCB;
94 return (((rows + rowBlock - 1) / rowBlock) * rowBlock) *
95 (((cols + colBlock - 1) / colBlock) * colBlock);
96 }
97
98 return -1;
99}
100
101// int32 accumulation
102template class PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>;
103
104template class PackMatrix<
105 PackAWithRowOffset<uint8_t, int32_t>,
106 uint8_t,
107 int32_t>;
108
109template class PackMatrix<
110 PackAWithIm2Col<uint8_t, int32_t, 1>,
111 uint8_t,
112 int32_t>;
113template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>;
114template class PackMatrix<
115 PackAWithIm2Col<uint8_t, int32_t, 3>,
116 uint8_t,
117 int32_t>;
118
119template class PackMatrix<
120 PackAWithQuantRowOffset<uint8_t, int32_t>,
121 uint8_t,
122 int32_t>;
123
124template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>;
125
126// int16 accumulation
127template class PackMatrix<
128 PackAWithIm2Col<uint8_t, int16_t, 1>,
129 uint8_t,
130 int16_t>;
131template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>;
132template class PackMatrix<
133 PackAWithIm2Col<uint8_t, int16_t, 3>,
134 uint8_t,
135 int16_t>;
136
137template class PackMatrix<
138 PackAWithRowOffset<uint8_t, int16_t>,
139 uint8_t,
140 int16_t>;
141
142template class PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>;
143
144template class PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>;
145} // namespace fbgemm
146