1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
#include <cpuinfo.h>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <new>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include "fbgemm/Fbgemm.h"
13
14namespace fbgemm {
15
16template <typename T, typename accT>
17PackAMatrix<T, accT>::PackAMatrix(
18 matrix_op_t trans,
19 int32_t nRow,
20 int32_t nCol,
21 const T* smat,
22 int32_t ld,
23 inpType* pmat,
24 int groups,
25 const BlockingFactors* params)
26 : PackMatrix<PackAMatrix<T, accT>, T, accT>(
27 nRow,
28 nCol,
29 pmat,
30 groups,
31 params),
32 trans_(trans),
33 smat_(smat),
34 ld_(ld) {
35 if (!cpuinfo_initialize()) {
36 throw std::runtime_error("Failed to initialize cpuinfo!");
37 }
38 if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
39 !fbgemmHasAvx2Support())) {
40 assert(0 && "unknown architecure");
41 }
42
43 if (params) {
44 BaseType::brow_ = params->MCB;
45 BaseType::bcol_ = params->KCB;
46 row_interleave_B_ = params->ROW_INTERLEAVE;
47 } else {
48 const inst_set_t isa = fbgemmInstructionSet();
49 switch (isa) {
50 case inst_set_t::avx512_vnni:
51 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
52 PackingTraits<T, accT, inst_set_t::avx512_vnni>::
53 getMatrixPackAParams();
54 break;
55
56 case inst_set_t::avx512_vnni_ymm:
57 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
58 PackingTraits<T, accT, inst_set_t::avx512_vnni_ymm>::
59 getMatrixPackAParams();
60 break;
61
62 case inst_set_t::avx512:
63 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
64 PackingTraits<T, accT, inst_set_t::avx512>::getMatrixPackAParams();
65 break;
66
67 case inst_set_t::avx512_ymm:
68 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
69 PackingTraits<T, accT, inst_set_t::avx512_ymm>::
70 getMatrixPackAParams();
71 break;
72
73 case inst_set_t::avx2:
74 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
75 PackingTraits<T, accT, inst_set_t::avx2>::getMatrixPackAParams();
76 break;
77
78 default:
79 assert(0 && "unknown architecure");
80 throw std::runtime_error("unknown architecure");
81 }
82 }
83
84 if (BaseType::numCols() % groups != 0) {
85 throw std::runtime_error(
86 "groups = " + std::to_string(groups) +
87 " does not divide numCols = " + std::to_string(BaseType::numCols()));
88 }
89 if (pmat) {
90 BaseType::buf_ = pmat;
91 } else {
92 BaseType::bufAllocatedHere_ = true;
93 BaseType::buf_ = static_cast<T*>(
94 fbgemmAlignedAlloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
95 }
96}
97
98template <typename T, typename accT>
99void PackAMatrix<T, accT>::pack(const block_type_t& block) {
100 block_type_t block_p = {
101 block.row_start,
102 block.row_size,
103 block.col_start,
104 (block.col_size + row_interleave_B_ - 1) / row_interleave_B_ *
105 row_interleave_B_};
106
107 BaseType::packedBlock(block_p);
108 bool tr = (trans_ == matrix_op_t::Transpose);
109 T* out = BaseType::getBuf();
110 if (tr) {
111 // TODO: should print warning because this path is not optimized yet
112 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
113 int buf_idx = i - block.row_start;
114 for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
115 T val = smat_[i + j * ld_];
116 out[buf_idx * BaseType::blockColSize() + (j - block.col_start)] = val;
117 }
118 // zero fill
119 // Please note that we zero fill, not zero_pt fill, because for
120 // requantization original, i.e., not padded, dimensions are used. If we
121 // were to use padded dimensions for requantization, we would zero_pt
122 // fill.
123 // For example, consider the following dot product:
124 // A = .3(5-15), .3(20-15) //.3 is scale and 15 is zero_pt
125 // B = .4(1+10), .4(4+10) // .4 is scale and -10 is zero_pt
126 //
127 // numElements(A) = 2 and numElements(B) = 2
128 //
129 // Dot product is (real): -3*4.4+1.5*5.6 = -4.8
130 // Dot product is (quantized): 5*1+20*4 = 85
131 //
132 // requantization: .3*.4(85 - (5+20)*(-10) - (1+4)*(15) +
133 // numElements(A)*(15)(-10)) = -4.8
134 //
135 // In the above adding one more element zero in the quantized domain,
136 // i.e., the quantized vectors become:
137 // A_q = 5, 20, 0
138 // B_q = 1, 4, 0
139 //
140 // and requantization with numElements(A) = 2 will produce the same
141 // answer (-4.8).
142 //
143 // Also in the above adding one more element zero_pt in the quantized
144 // domain, i.e., the quantized vectors become:
145 // A_q = 5, 20, 15
146 // B_q = 1, 4, -10
147 //
148 // and requantization with numElements(A) = 3 will produce the same
149 // answer (-4.8).
150 for (int j = block.col_size; j < block_p.col_size; ++j) {
151 out[buf_idx * BaseType::blockColSize() + j] = 0;
152 }
153 }
154 } else {
155 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
156 int buf_idx = i - block.row_start;
157 memcpy(
158 out + buf_idx * BaseType::blockColSize(),
159 smat_ + i * ld_ + block.col_start,
160 block.col_size * sizeof(T));
161 // zero fill
162 for (int j = block.col_size; j < block_p.col_size; ++j) {
163 out[buf_idx * BaseType::blockColSize() + j] = 0;
164 }
165 }
166 }
167}
168
169template <typename T, typename accT>
170int32_t PackAMatrix<T, accT>::addr(int32_t r, int32_t c) const {
171 int32_t block_row_id = r / BaseType::blockRowSize();
172 int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
173 (BaseType::blockRowSize() * BaseType::blockColSize());
174
175 int32_t block_col_id = c / BaseType::blockColSize();
176 int32_t bcol_offset =
177 block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
178 int32_t block_offset = brow_offset + bcol_offset;
179 int32_t inblock_offset =
180 (r % BaseType::blockRowSize()) * BaseType::blockColSize() +
181 (c % BaseType::blockColSize());
182
183 int32_t index = block_offset + inblock_offset;
184
185 return index;
186}
187
188template <typename T, typename accT>
189void PackAMatrix<T, accT>::printPackedMatrix(std::string name) {
190 std::cout << name << ":"
191 << "[" << BaseType::numPackedRows() << ", "
192 << BaseType::numPackedCols() << "]" << std::endl;
193
194 T* out = BaseType::getBuf();
195 for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
196 for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
197 T val = out[addr(r, c)];
198 if (std::is_integral<T>::value) {
199 // cast to int64 because cout doesn't print int8_t type directly
200 std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
201 } else {
202 std::cout << std::setw(5) << val << " ";
203 }
204 }
205 std::cout << std::endl;
206 }
207 std::cout << std::endl;
208}
209
// Explicit instantiations of PackAMatrix for T = uint8_t, with accT = int32_t
// and accT = int16_t. The member functions above are defined in this source
// file, so these instantiations emit their code here.
template class PackAMatrix<uint8_t, int32_t>;
template class PackAMatrix<uint8_t, int16_t>;
212} // namespace fbgemm
213