1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
#define FBGEMM_EXPORTS
#include <cpuinfo.h>
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <iomanip>
#include <numeric>
#include <stdexcept>
#include "./RefImplementations.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/SimdUtils.h"
15
16namespace fbgemm {
17
18template <typename T, typename accT, int SPATIAL_DIM>
19PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv(
20 matrix_op_t trans,
21 const conv_param_t<SPATIAL_DIM>& conv_param,
22 const T* sdata,
23 T* pdata)
24 : trans_(trans), conv_param_(conv_param), sdata_(sdata) {
25 if (!cpuinfo_initialize()) {
26 throw std::runtime_error("Failed to initialize cpuinfo!");
27 }
28 GTogether_ = numOfGroupsTogether(conv_param_);
29 assert(
30 GTogether_ <= conv_param_.G &&
31 "Number of groups together smaller than total number of groups");
32 if (!pdata) {
33 bufAllocatedHere_ = true;
34 int kernel_prod = std::accumulate(
35 conv_param.K.begin(), conv_param.K.end(), 1, std::multiplies<int>());
36 // we make it a multiple of 4
37 int paddedICPerG = ((conv_param_.IC / conv_param_.G) + 3) / 4 * 4;
38 pdata_ = static_cast<T*>(fbgemmAlignedAlloc(
39 64,
40 (conv_param_.G + GTogether_ - 1) / GTogether_ * GTogether_ *
41 kernel_prod * (conv_param_.OC / conv_param_.G) * paddedICPerG *
42 sizeof(T)));
43 } else {
44 bufAllocatedHere_ = false;
45 pdata_ = pdata;
46 }
47
48 pack();
49}
50
51template <typename T, typename accT, int SPATIAL_DIM>
52int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::numOfGroupsTogether(
53 const conv_param_t<SPATIAL_DIM>& conv_param) {
54 int OC_per_G = conv_param.OC / conv_param.G;
55 int IC_per_G = conv_param.IC / conv_param.G;
56 if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) {
57 // TODO: change to avx512 when avx512 support is available
58 return std::max(
59 simd_info<inst_set_t::avx512>::WIDTH_BYTES / OC_per_G /
60 std::max(IC_per_G, 4),
61 1);
62 } else {
63 // avx2
64 // e.g., IC_per_G == 4, we need to work on 2 groups at a time
65 return std::max(
66 simd_info<inst_set_t::avx2>::WIDTH_BYTES / OC_per_G /
67 std::max(IC_per_G, 4),
68 1);
69 }
70 return 1;
71}
72
73/**
74 * @brief Get the index of the unpacked data
75 * for a given <t, r, s, k, g, c, tr>
76 *
77 * Non-transposed: G (T R S C/G) K/G
78 * Transposed: G K/G (T R S C/G)
79 * Using inline as this will be called frequently
80 */
81template <typename T, typename accT, int SPATIAL_DIM>
82inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::unpacked_index_(
83 int t,
84 int r,
85 int s,
86 int k,
87 int g,
88 int c,
89 bool tr) {
90 // Get the full dimensions
91 // Can't use T as varname because T is a template parameter.
92 int F = SPATIAL_DIM <= 2 ? 1 : conv_param_.K[SPATIAL_DIM - 3];
93 int R = SPATIAL_DIM == 1 ? 1 : conv_param_.K[SPATIAL_DIM - 2];
94 int S = conv_param_.K[SPATIAL_DIM - 1];
95 int G = conv_param_.G;
96 int IC_per_G = conv_param_.IC / G;
97 int OC_per_G = conv_param_.OC / G;
98
99 int idx;
100 if (tr) {
101 idx = ((((g * OC_per_G + k) * F + t) * R + r) * S + s) * IC_per_G + c;
102 } else {
103 idx = ((((g * F + t) * R + r) * S + s) * IC_per_G + c) * OC_per_G + k;
104 }
105 return idx;
106}
107
108/**
109 * @brief Get the index of the packed data for a given <t, r, s, k, g, c>
110 *
111 * The index may differ depending on IC_per_G.
112 * Using inline as this will be called frequently
113 */
114template <typename T, typename accT, int SPATIAL_DIM>
115inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::packed_index_(
116 int t,
117 int r,
118 int s,
119 int k,
120 int g,
121 int c) {
122 // Get the full dimensions
123 // Can't use T as varname because T is a template parameter.
124 int F = SPATIAL_DIM <= 2 ? 1 : conv_param_.K[SPATIAL_DIM - 3];
125 int R = SPATIAL_DIM == 1 ? 1 : conv_param_.K[SPATIAL_DIM - 2];
126 int S = conv_param_.K[SPATIAL_DIM - 1];
127 int G = conv_param_.G;
128 int IC_per_G = conv_param_.IC / G;
129 int OC_per_G = conv_param_.OC / G;
130 int paddedICPerG = (IC_per_G + 3) / 4 * 4;
131
132 int idx = ((((((g / GTogether_) * F + t) * R + r) * S + s) * OC_per_G + k) *
133 GTogether_ +
134 (g % GTogether_)) *
135 paddedICPerG +
136 c;
137 return idx;
138}
139
140/**
141 * @brief Pack or unpack matrix
142 *
143 * Let IC_per_G be number of input channels per group and OC_per_G be number of
144 * output channels per group.
145 *
146 * For IC_per_G == 4 && OC_per_G == 4 optimized
147 * kernel works on 2 groups at a time hence input channels for g and g+1 group
148 * are laid out sequentially for each output channel, i.e., the layout is (G/2)
149 * R S K (2C) and K (2C) is in each 32B vector.
150 * We work on two groups at a time to fully utilize the avx2 SIMD width of
151 * 256-bits.
152 *
153 * For IC_per_G == 8, 16, 32 && OC_per_G == 8, 16, 32 there is no need to work
154 * on 2 groups at a time and full SIMD width can be efficiently utilized even
155 * while working on 1 group at a time.
156 * In this case, the layout is G R S K_per_G paddedICPerG
157 */
158
159template <typename T, typename accT, int SPATIAL_DIM>
160void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
161 const T* src,
162 T* dst,
163 bool ispack) {
164 // Can't use T as varname because T is a template parameter.
165 int F = SPATIAL_DIM <= 2 ? 1 : conv_param_.K[SPATIAL_DIM - 3];
166 int R = SPATIAL_DIM == 1 ? 1 : conv_param_.K[SPATIAL_DIM - 2];
167 int S = conv_param_.K[SPATIAL_DIM - 1];
168 int G = conv_param_.G;
169 int IC_per_G = conv_param_.IC / G;
170 int OC_per_G = conv_param_.OC / G;
171 int paddedICPerG = (IC_per_G + 3) / 4 * 4;
172
173 // If transpose option is set, the weight matrix is in layout G K/G (T R S
174 // C/G) instead of G (T R S C/G) K/G
175 bool tr = (trans_ == matrix_op_t::Transpose);
176 if (fbgemmOptimizedGConv(conv_param_)) {
177 // currently only this case is supported
178 for (int t = 0; t < F; ++t) {
179 for (int r = 0; r < R; ++r) {
180 for (int s = 0; s < S; ++s) {
181 for (int k = 0; k < OC_per_G; ++k) {
182 for (int g = 0; g < G; ++g) {
183 for (int c = 0; c < IC_per_G; ++c) {
184 int p_idx = packed_index_(t, r, s, k, g, c);
185 int up_idx = unpacked_index_(t, r, s, k, g, c, tr);
186 // Pack: src (unpacked) -> dst (packed)
187 if (ispack) {
188 dst[p_idx] = src[up_idx];
189 } else {
190 dst[up_idx] = src[p_idx];
191 }
192 }
193 if (ispack) {
194 for (int c = IC_per_G; c < paddedICPerG; ++c) {
195 int p_idx = packed_index_(t, r, s, k, g, c);
196 dst[p_idx] = 0;
197 }
198 }
199 }
200 }
201 }
202 }
203 }
204 } else {
205 // For pack & transposed, call transposeConvWeights()
206 // G K/G (T R S C/G) => G (T R S C/G) K/G
207 if (tr) {
208 if (ispack) {
209 transposeConvWeights(conv_param_, src, dst);
210 } else {
211 // TODO: Wrap this as a inverseTransposeConvWeights()?
212 // For unpack & transposed, call transposeConvWeights()
213 // G (T R S C/G) K/G => G K/G (T R S C/G)
214 for (int t = 0; t < F; ++t) {
215 for (int r = 0; r < R; ++r) {
216 for (int s = 0; s < S; ++s) {
217 for (int k = 0; k < OC_per_G; ++k) {
218 for (int g = 0; g < G; ++g) {
219 for (int c = 0; c < IC_per_G; ++c) {
220 dst[((((g * OC_per_G + k) * F + t) * R + r) * S + s) *
221 IC_per_G +
222 c] =
223 src[((((g * F + t) * R + r) * S + s) * IC_per_G + c) *
224 OC_per_G +
225 k];
226 }
227 }
228 }
229 }
230 }
231 }
232 } // end if(ispack)
233 } else {
234 // just copy the data for not supported cases
235 int kernel_prod = std::accumulate(
236 conv_param_.K.begin(),
237 conv_param_.K.end(),
238 1,
239 std::multiplies<int>());
240 memcpy(dst, src, G * kernel_prod * OC_per_G * IC_per_G * sizeof(inpType));
241 } // end if(tr)
242 } // end if(fbgemmOptimizedGConv(conv_param_)
243}
244
245/**
246 * @brief Pack weight tensor in a suitable format required for the optimized
247 * kernel.
248 */
249template <typename T, typename accT, int SPATIAL_DIM>
250void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
251 pack_unpack_(sdata_, pdata_, true);
252}
253
254/**
255 * @brief Unpack the packed weight tensor (for the optimized kernel)
256 * to the original form.
257 */
258template <typename T, typename accT, int SPATIAL_DIM>
259void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::unpack(T* origin_buf) {
260 pack_unpack_(const_cast<const T*>(pdata_), origin_buf, false);
261}
262
263template class FBGEMM_API PackWeightMatrixForGConv<int8_t, int32_t, 1>;
264template class FBGEMM_API PackWeightMatrixForGConv<int8_t, int16_t, 1>;
265template class FBGEMM_API PackWeightMatrixForGConv<int8_t, int32_t, 2>;
266template class FBGEMM_API PackWeightMatrixForGConv<int8_t, int16_t, 2>;
267template class FBGEMM_API PackWeightMatrixForGConv<int8_t, int32_t, 3>;
268template class FBGEMM_API PackWeightMatrixForGConv<int8_t, int16_t, 3>;
269} // namespace fbgemm
270