1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include <cpuinfo.h>
9#include <cassert>
10#include <cstring>
11#include <iomanip>
12#include <iostream>
13#include <stdexcept>
14#include "./OptimizedKernelsAvx2.h"
15#include "fbgemm/Fbgemm.h"
16
17namespace fbgemm {
18
19template <typename T, typename accT>
20PackAWithRowOffset<T, accT>::PackAWithRowOffset(
21 matrix_op_t trans,
22 uint32_t nRow,
23 uint32_t nCol,
24 const T* smat,
25 uint32_t ld,
26 inpType* pmat,
27 int groups,
28 int32_t* row_offset,
29 const BlockingFactors* params)
30 : PackMatrix<PackAWithRowOffset<T, accT>, T, accT>(
31 nRow,
32 nCol,
33 pmat,
34 groups,
35 params),
36 trans_(trans),
37 smat_(smat),
38 ld_(ld),
39 row_offset_(row_offset) {
40 if (!cpuinfo_initialize()) {
41 throw std::runtime_error("Failed to initialize cpuinfo!");
42 }
43 if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
44 !fbgemmHasAvx2Support())) {
45 assert(0 && "unknown architecure");
46 }
47
48 if (params) {
49 BaseType::brow_ = params->MCB;
50 BaseType::bcol_ = params->KCB;
51 row_interleave_B_ = params->ROW_INTERLEAVE;
52 } else {
53 const inst_set_t isa = fbgemmInstructionSet();
54 switch (isa) {
55 case inst_set_t::avx512_vnni:
56 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
57 PackingTraits<T, accT, inst_set_t::avx512_vnni>::
58 getMatrixPackAParams();
59 break;
60
61 case inst_set_t::avx512_vnni_ymm:
62 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
63 PackingTraits<T, accT, inst_set_t::avx512_vnni_ymm>::
64 getMatrixPackAParams();
65 break;
66
67 case inst_set_t::avx512:
68 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
69 PackingTraits<T, accT, inst_set_t::avx512>::getMatrixPackAParams();
70 break;
71
72 case inst_set_t::avx512_ymm:
73 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
74 PackingTraits<T, accT, inst_set_t::avx512_ymm>::
75 getMatrixPackAParams();
76 break;
77
78 case inst_set_t::avx2:
79 std::tie(BaseType::brow_, BaseType::bcol_, row_interleave_B_) =
80 PackingTraits<T, accT, inst_set_t::avx2>::getMatrixPackAParams();
81 break;
82
83 default:
84 assert(0 && "unknown architecure");
85 throw std::runtime_error("unknown architecure");
86 }
87 }
88
89 rowOffsetAllocatedHere = false;
90
91 if (BaseType::numCols() % groups != 0) {
92 throw std::runtime_error(
93 "groups = " + std::to_string(groups) +
94 " does not divide numCols = " + std::to_string(BaseType::numCols()));
95 }
96 if (pmat) {
97 BaseType::buf_ = pmat;
98 } else {
99 BaseType::bufAllocatedHere_ = true;
100 BaseType::buf_ = static_cast<T*>(
101 fbgemmAlignedAlloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
102 }
103 if (!row_offset_) {
104 rowOffsetAllocatedHere = true;
105 row_offset_ = static_cast<int32_t*>(
106 fbgemmAlignedAlloc(64, BaseType::brow_ * sizeof(int32_t)));
107 }
108}
109
110template <typename T, typename accT>
111void PackAWithRowOffset<T, accT>::pack(const block_type_t& block) {
112 // assert(block.row_start % BaseType::blockRowSize() == 0);
113 assert(block.row_size <= BaseType::blockRowSize());
114 assert(block.col_size <= BaseType::blockColSize());
115
116 block_type_t block_p = {
117 block.row_start,
118 block.row_size,
119 block.col_start,
120 (block.col_size + row_interleave_B_ - 1) / row_interleave_B_ *
121 row_interleave_B_};
122 assert(block_p.col_size <= BaseType::blockColSize());
123 BaseType::packedBlock(block_p);
124
125 T* out = BaseType::getBuf();
126 bool tr = (trans_ == matrix_op_t::Transpose);
127 // accumulate into row offset?
128 bool row_offset_acc =
129 (block.col_start % (this->numCols() / this->numGroups())) != 0;
130 int32_t* row_offset_buf = getRowOffsetBuffer();
131 if (tr) {
132 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
133 int buf_idx = i - block.row_start;
134 int32_t row_sum = row_offset_acc ? row_offset_buf[buf_idx] : 0;
135 for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
136 T val = smat_[i + j * ld_];
137 row_sum += val;
138 out[buf_idx * BaseType::blockColSize() + (j - block.col_start)] = val;
139 }
140 row_offset_buf[buf_idx] = row_sum;
141 // zero fill
142 // Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
143 for (int j = block.col_size; j < block_p.col_size; ++j) {
144 out[buf_idx * BaseType::blockColSize() + j] = 0;
145 }
146 }
147 } else {
148 // reduceAvx2 only written for T == uint8_t
149 static_assert(
150 std::is_same<T, uint8_t>::value,
151 "PackAWithRowOffset<T, accT>::pack only works for T == uint8_t");
152 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
153 int buf_idx = i - block.row_start;
154 memcpy(
155 out + buf_idx * BaseType::blockColSize(),
156 smat_ + i * ld_ + block.col_start,
157 block.col_size * sizeof(T));
158 // zero fill
159 for (int j = block.col_size; j < block_p.col_size; ++j) {
160 out[buf_idx * BaseType::blockColSize() + j] = 0;
161 }
162 int32_t row_sum = row_offset_acc ? row_offset_buf[buf_idx] : 0;
163 row_sum += reduceAvx2(smat_ + i * ld_ + block.col_start, block.col_size);
164 row_offset_buf[buf_idx] = row_sum;
165 }
166 }
167}
168
169template <typename T, typename accT>
170int32_t PackAWithRowOffset<T, accT>::addr(int32_t r, int32_t c) const {
171 int32_t block_row_id = r / BaseType::blockRowSize();
172 int32_t brow_offset = (block_row_id * BaseType::blockCols()) *
173 (BaseType::blockRowSize() * BaseType::blockColSize());
174
175 int32_t block_col_id = c / BaseType::blockColSize();
176 int32_t bcol_offset =
177 block_col_id * BaseType::blockRowSize() * BaseType::blockColSize();
178 int32_t block_offset = brow_offset + bcol_offset;
179 int32_t inblock_offset =
180 (r % BaseType::blockRowSize()) * BaseType::blockColSize() +
181 (c % BaseType::blockColSize());
182
183 int32_t index = block_offset + inblock_offset;
184
185 return index;
186}
187
188template <typename T, typename accT>
189void PackAWithRowOffset<T, accT>::printPackedMatrix(std::string name) {
190 std::cout << name << ":"
191 << "[" << BaseType::numPackedRows() << ", "
192 << BaseType::numPackedCols() << "]" << std::endl;
193
194 T* out = BaseType::getBuf();
195 for (auto r = 0; r < BaseType::numPackedRows(); ++r) {
196 for (auto c = 0; c < BaseType::numPackedCols(); ++c) {
197 T val = out[addr(r, c)];
198 if (std::is_integral<T>::value) {
199 // cast to int64 because cout doesn't print int8_t type directly
200 std::cout << std::setw(5) << static_cast<int64_t>(val) << " ";
201 } else {
202 std::cout << std::setw(5) << val << " ";
203 }
204 }
205 std::cout << std::endl;
206 }
207 std::cout << std::endl;
208}
209
210template <typename T, typename accT>
211int PackAWithRowOffset<T, accT>::rowOffsetBufferSize(
212 const BlockingFactors* params) {
213 if (cpuinfo_initialize()) {
214 if (params) {
215 return params->MCB;
216 } else {
217 if (fbgemmHasAvx512VnniSupport()) {
218 return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
219 } else if (fbgemmHasAvx512Support()) {
220 return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
221 } else if (fbgemmHasAvx2Support()) {
222 return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
223 } else {
224 // TODO: Have default slower path
225 assert(0 && "unsupported architecture");
226 return -1;
227 }
228 }
229 } else {
230 throw std::runtime_error("Failed to initialize cpuinfo!");
231 }
232}
233
// Explicit instantiations: the member definitions above live in this source
// file, so the (T, accT) combinations listed here are emitted in this
// translation unit.
template class PackAWithRowOffset<uint8_t, int32_t>;
template class PackAWithRowOffset<uint8_t, int16_t>;
236
237} // namespace fbgemm
238