1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | |
9 | #include <cpuinfo.h> |
10 | #include <iomanip> |
11 | #include <stdexcept> |
12 | #include <type_traits> |
13 | #include "fbgemm/Fbgemm.h" |
14 | |
15 | namespace fbgemm { |
16 | |
17 | template <typename PT, typename inpType, typename accType> |
18 | PackMatrix<PT, inpType, accType>::PackMatrix( |
19 | int32_t rows, |
20 | int32_t cols, |
21 | inpType* buf, |
22 | int groups, |
23 | const BlockingFactors* params) |
24 | : buf_(buf), nrows_(rows), ncols_(cols), G_(groups) { |
25 | bufAllocatedHere_ = false; |
26 | blocking_params = params; |
27 | if (!cpuinfo_initialize()) { |
28 | throw std::runtime_error("Failed to initialize cpuinfo!" ); |
29 | } |
30 | } |
31 | |
32 | template <typename PT, typename inpType, typename accType> |
33 | int PackMatrix<PT, inpType, accType>::packedBufferSize( |
34 | int rows, |
35 | int cols, |
36 | const BlockingFactors* params) { |
37 | if (!cpuinfo_initialize()) { |
38 | throw std::runtime_error("Failed to initialize cpuinfo!" ); |
39 | } |
40 | if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && |
41 | !fbgemmHasAvx2Support())) { |
42 | assert(0 && "unknown architecure" ); |
43 | } |
44 | |
45 | int MCB, KCB, NCB; |
46 | if (params) { |
47 | MCB = params->MCB; |
48 | NCB = params->NCB; |
49 | KCB = params->KCB; |
50 | } else { |
51 | const inst_set_t isa = fbgemmInstructionSet(); |
52 | switch (isa) { |
53 | case inst_set_t::avx512_vnni: |
54 | MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB; |
55 | NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB; |
56 | KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB; |
57 | break; |
58 | |
59 | case inst_set_t::avx512_vnni_ymm: |
60 | MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni_ymm>::MCB; |
61 | NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni_ymm>::NCB; |
62 | KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni_ymm>::KCB; |
63 | break; |
64 | |
65 | case inst_set_t::avx512: |
66 | MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB; |
67 | NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB; |
68 | KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; |
69 | break; |
70 | |
71 | case inst_set_t::avx512_ymm: |
72 | MCB = PackingTraits<inpType, accType, inst_set_t::avx512_ymm>::MCB; |
73 | NCB = PackingTraits<inpType, accType, inst_set_t::avx512_ymm>::NCB; |
74 | KCB = PackingTraits<inpType, accType, inst_set_t::avx512_ymm>::KCB; |
75 | break; |
76 | |
77 | case inst_set_t::avx2: |
78 | MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB; |
79 | NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB; |
80 | KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; |
81 | break; |
82 | |
83 | default: |
84 | assert(0 && "unknown architecure" ); |
85 | throw std::runtime_error("unknown architecure" ); |
86 | } |
87 | } |
88 | |
89 | if (isA()) { |
90 | return MCB * KCB; |
91 | } else { |
92 | int rowBlock = KCB; |
93 | int colBlock = NCB; |
94 | return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * |
95 | (((cols + colBlock - 1) / colBlock) * colBlock); |
96 | } |
97 | |
98 | return -1; |
99 | } |
100 | |
101 | // int32 accumulation |
102 | template class PackMatrix<PackAMatrix<uint8_t, int32_t>, uint8_t, int32_t>; |
103 | |
104 | template class PackMatrix< |
105 | PackAWithRowOffset<uint8_t, int32_t>, |
106 | uint8_t, |
107 | int32_t>; |
108 | |
109 | template class PackMatrix< |
110 | PackAWithIm2Col<uint8_t, int32_t, 1>, |
111 | uint8_t, |
112 | int32_t>; |
113 | template class PackMatrix<PackAWithIm2Col<uint8_t, int32_t>, uint8_t, int32_t>; |
114 | template class PackMatrix< |
115 | PackAWithIm2Col<uint8_t, int32_t, 3>, |
116 | uint8_t, |
117 | int32_t>; |
118 | |
119 | template class PackMatrix< |
120 | PackAWithQuantRowOffset<uint8_t, int32_t>, |
121 | uint8_t, |
122 | int32_t>; |
123 | |
124 | template class PackMatrix<PackBMatrix<int8_t, int32_t>, int8_t, int32_t>; |
125 | |
126 | // int16 accumulation |
127 | template class PackMatrix< |
128 | PackAWithIm2Col<uint8_t, int16_t, 1>, |
129 | uint8_t, |
130 | int16_t>; |
131 | template class PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>; |
132 | template class PackMatrix< |
133 | PackAWithIm2Col<uint8_t, int16_t, 3>, |
134 | uint8_t, |
135 | int16_t>; |
136 | |
137 | template class PackMatrix< |
138 | PackAWithRowOffset<uint8_t, int16_t>, |
139 | uint8_t, |
140 | int16_t>; |
141 | |
142 | template class PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>; |
143 | |
144 | template class PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>; |
145 | } // namespace fbgemm |
146 | |