1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/Fbgemm.h"
9
10#include <algorithm>
11#include <memory>
12
13namespace fbgemm {
14
15template <int SPATIAL_DIM, typename T, typename accT>
16PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
17 const conv_param_t<SPATIAL_DIM>& conv_p,
18 const T* sdata,
19 const BlockingFactors* blocking_params)
20 : conv_param_(conv_p) {
21 // Note: The following logic should *exactly* match with what we have in
22 // FbgemmConv.cc
23 switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
24 case optimized_conv_t::depthwise: {
25 const int kernel_d = SPATIAL_DIM <= 2 ? 1 : conv_p.K[0];
26 const int kernel_h = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2];
27 const int kernel_w = conv_p.K[SPATIAL_DIM - 1];
28 W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>(
29 conv_p.OC, kernel_d * kernel_h * kernel_w, sdata);
30 break;
31 }
32 case optimized_conv_t::groupwise: {
33 W_gconv_packed_ =
34 std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
35 matrix_op_t::Transpose, conv_p, sdata, nullptr);
36 break;
37 }
38 case optimized_conv_t::pointwise: {
39 const int N = conv_p.OC / conv_p.G;
40 const int kernel_d = SPATIAL_DIM <= 2 ? 1 : conv_p.K[0];
41 const int kernel_h = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2];
42 const int kernel_w = conv_p.K[SPATIAL_DIM - 1];
43 const int K = kernel_d * kernel_h * kernel_w * conv_p.IC;
44 W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>(
45 matrix_op_t::Transpose,
46 K,
47 N,
48 sdata,
49 K / conv_p.G,
50 nullptr,
51 conv_p.G,
52 blocking_params);
53 break;
54 }
55 case optimized_conv_t::directconv: {
56 const int kernel_h = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2];
57 const int kernel_w = conv_p.K[SPATIAL_DIM - 1];
58 const int K = kernel_h * kernel_w;
59 W_dc_packed_ = std::make_shared<PackedDirectConvMatrix>(
60 conv_p.IC, conv_p.OC, K, sdata);
61 break;
62 }
63 case optimized_conv_t::fastpath1d: {
64 break;
65 }
66 case optimized_conv_t::im2col: {
67 const int N = conv_p.OC / conv_p.G;
68 const int kernel_d = SPATIAL_DIM <= 2 ? 1 : conv_p.K[0];
69 const int kernel_h = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2];
70 const int kernel_w = conv_p.K[SPATIAL_DIM - 1];
71 const int K = kernel_d * kernel_h * kernel_w * conv_p.IC;
72 W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>(
73 matrix_op_t::Transpose,
74 K,
75 N,
76 sdata,
77 K / conv_p.G,
78 nullptr,
79 conv_p.G,
80 blocking_params);
81 break;
82 }
83 } // switch
84}
85
86template <int SPATIAL_DIM, typename T, typename accT>
87void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
88 if (W_dw_packed_) {
89 W_dw_packed_->unpack(origin_buf);
90 } else if (W_gconv_packed_) {
91 W_gconv_packed_->unpack(origin_buf);
92 } else if (W_im2col_packed_) {
93 W_im2col_packed_->unpack(origin_buf);
94 } else if (W_pointwise_packed_) {
95 W_pointwise_packed_->unpack(origin_buf);
96 } else {
97 assert(false && "At least one packed weights object should exist");
98 }
99}
100
101template <int SPATIAL_DIM, typename T, typename accT>
102bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant(
103 const conv_param_t<SPATIAL_DIM>& test_conv_p) {
104 return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC &&
105 conv_param_.G == test_conv_p.G &&
106 std::equal(
107 conv_param_.K.begin(),
108 conv_param_.K.end(),
109 test_conv_p.K.begin()) &&
110 std::equal(
111 conv_param_.stride.begin(),
112 conv_param_.stride.end(),
113 test_conv_p.stride.begin()) &&
114 std::equal(
115 conv_param_.pad.begin(),
116 conv_param_.pad.end(),
117 test_conv_p.pad.begin()) &&
118 std::equal(
119 conv_param_.dilation.begin(),
120 conv_param_.dilation.end(),
121 test_conv_p.dilation.begin());
122}
123
124template <int SPATIAL_DIM, typename T, typename accT>
125std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams(
126 const conv_param_t<SPATIAL_DIM>& test_conv_p) {
127 std::string msg = "";
128
129 auto combineStr = [](std::string id, std::string str1, std::string str2) {
130 std::string out = id + std::string(" ");
131 out += str1;
132 out += std::string(" vs ") + str2;
133 out += std::string(";");
134 return out;
135 };
136
137 auto combineInt = [&combineStr](std::string id, int int1, int int2) {
138 return combineStr(id, std::to_string(int1), std::to_string(int2));
139 };
140
141 if (conv_param_.IC != test_conv_p.IC) {
142 msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC);
143 }
144 if (conv_param_.OC != test_conv_p.OC) {
145 msg += combineInt("output_channels", conv_param_.IC, test_conv_p.IC);
146 }
147 if (conv_param_.G != test_conv_p.G) {
148 msg += combineInt("groups", conv_param_.G, test_conv_p.G);
149 }
150
151 if (!std::equal(
152 conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) {
153 msg += combineStr(
154 "kernel",
155 arrayToString<SPATIAL_DIM>(conv_param_.K),
156 arrayToString<SPATIAL_DIM>(test_conv_p.K));
157 }
158
159 if (!std::equal(
160 conv_param_.stride.begin(),
161 conv_param_.stride.end(),
162 test_conv_p.stride.begin())) {
163 msg += combineStr(
164 "stride",
165 arrayToString<SPATIAL_DIM>(conv_param_.stride),
166 arrayToString<SPATIAL_DIM>(test_conv_p.stride));
167 }
168
169 if (!std::equal(
170 conv_param_.pad.begin(),
171 conv_param_.pad.end(),
172 test_conv_p.pad.begin())) {
173 msg += combineStr(
174 "pad",
175 arrayToString<2 * SPATIAL_DIM>(conv_param_.pad),
176 arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad));
177 }
178
179 if (!std::equal(
180 conv_param_.dilation.begin(),
181 conv_param_.dilation.end(),
182 test_conv_p.dilation.begin())) {
183 msg += combineStr(
184 "dilation",
185 arrayToString<SPATIAL_DIM>(conv_param_.dilation),
186 arrayToString<SPATIAL_DIM>(test_conv_p.dilation));
187 }
188
189 return msg;
190}
191
// Explicit instantiations of PackWeightsForConv for SPATIAL_DIM = 1, 2 and 3,
// each with T = int8_t (the weight element type taken by the constructor and
// unpack) and accT = int32_t (presumably the accumulation type; confirm
// against the class declaration). The member definitions above live in this
// source file, so these are the instantiations it emits.
192template class PackWeightsForConv<1, int8_t, int32_t>;
193template class PackWeightsForConv<2, int8_t, int32_t>;
194template class PackWeightsForConv<3, int8_t, int32_t>;
195
196} // namespace fbgemm
197