1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/Fbgemm.h" |
9 | |
10 | #include <algorithm> |
11 | #include <memory> |
12 | |
13 | namespace fbgemm { |
14 | |
15 | template <int SPATIAL_DIM, typename T, typename accT> |
16 | PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv( |
17 | const conv_param_t<SPATIAL_DIM>& conv_p, |
18 | const T* sdata, |
19 | const BlockingFactors* blocking_params) |
20 | : conv_param_(conv_p) { |
21 | // Note: The following logic should *exactly* match with what we have in |
22 | // FbgemmConv.cc |
23 | switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) { |
24 | case optimized_conv_t::depthwise: { |
25 | const int kernel_d = SPATIAL_DIM <= 2 ? 1 : conv_p.K[0]; |
26 | const int kernel_h = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2]; |
27 | const int kernel_w = conv_p.K[SPATIAL_DIM - 1]; |
28 | W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>( |
29 | conv_p.OC, kernel_d * kernel_h * kernel_w, sdata); |
30 | break; |
31 | } |
32 | case optimized_conv_t::groupwise: { |
33 | W_gconv_packed_ = |
34 | std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>( |
35 | matrix_op_t::Transpose, conv_p, sdata, nullptr); |
36 | break; |
37 | } |
38 | case optimized_conv_t::pointwise: { |
39 | const int N = conv_p.OC / conv_p.G; |
40 | const int kernel_d = SPATIAL_DIM <= 2 ? 1 : conv_p.K[0]; |
41 | const int kernel_h = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2]; |
42 | const int kernel_w = conv_p.K[SPATIAL_DIM - 1]; |
43 | const int K = kernel_d * kernel_h * kernel_w * conv_p.IC; |
44 | W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>( |
45 | matrix_op_t::Transpose, |
46 | K, |
47 | N, |
48 | sdata, |
49 | K / conv_p.G, |
50 | nullptr, |
51 | conv_p.G, |
52 | blocking_params); |
53 | break; |
54 | } |
55 | case optimized_conv_t::directconv: { |
56 | const int kernel_h = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2]; |
57 | const int kernel_w = conv_p.K[SPATIAL_DIM - 1]; |
58 | const int K = kernel_h * kernel_w; |
59 | W_dc_packed_ = std::make_shared<PackedDirectConvMatrix>( |
60 | conv_p.IC, conv_p.OC, K, sdata); |
61 | break; |
62 | } |
63 | case optimized_conv_t::fastpath1d: { |
64 | break; |
65 | } |
66 | case optimized_conv_t::im2col: { |
67 | const int N = conv_p.OC / conv_p.G; |
68 | const int kernel_d = SPATIAL_DIM <= 2 ? 1 : conv_p.K[0]; |
69 | const int kernel_h = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2]; |
70 | const int kernel_w = conv_p.K[SPATIAL_DIM - 1]; |
71 | const int K = kernel_d * kernel_h * kernel_w * conv_p.IC; |
72 | W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>( |
73 | matrix_op_t::Transpose, |
74 | K, |
75 | N, |
76 | sdata, |
77 | K / conv_p.G, |
78 | nullptr, |
79 | conv_p.G, |
80 | blocking_params); |
81 | break; |
82 | } |
83 | } // switch |
84 | } |
85 | |
86 | template <int SPATIAL_DIM, typename T, typename accT> |
87 | void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) { |
88 | if (W_dw_packed_) { |
89 | W_dw_packed_->unpack(origin_buf); |
90 | } else if (W_gconv_packed_) { |
91 | W_gconv_packed_->unpack(origin_buf); |
92 | } else if (W_im2col_packed_) { |
93 | W_im2col_packed_->unpack(origin_buf); |
94 | } else if (W_pointwise_packed_) { |
95 | W_pointwise_packed_->unpack(origin_buf); |
96 | } else { |
97 | assert(false && "At least one packed weights object should exist" ); |
98 | } |
99 | } |
100 | |
101 | template <int SPATIAL_DIM, typename T, typename accT> |
102 | bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant( |
103 | const conv_param_t<SPATIAL_DIM>& test_conv_p) { |
104 | return conv_param_.IC == test_conv_p.IC && conv_param_.OC == test_conv_p.OC && |
105 | conv_param_.G == test_conv_p.G && |
106 | std::equal( |
107 | conv_param_.K.begin(), |
108 | conv_param_.K.end(), |
109 | test_conv_p.K.begin()) && |
110 | std::equal( |
111 | conv_param_.stride.begin(), |
112 | conv_param_.stride.end(), |
113 | test_conv_p.stride.begin()) && |
114 | std::equal( |
115 | conv_param_.pad.begin(), |
116 | conv_param_.pad.end(), |
117 | test_conv_p.pad.begin()) && |
118 | std::equal( |
119 | conv_param_.dilation.begin(), |
120 | conv_param_.dilation.end(), |
121 | test_conv_p.dilation.begin()); |
122 | } |
123 | |
124 | template <int SPATIAL_DIM, typename T, typename accT> |
125 | std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams( |
126 | const conv_param_t<SPATIAL_DIM>& test_conv_p) { |
127 | std::string msg = "" ; |
128 | |
129 | auto combineStr = [](std::string id, std::string str1, std::string str2) { |
130 | std::string out = id + std::string(" " ); |
131 | out += str1; |
132 | out += std::string(" vs " ) + str2; |
133 | out += std::string(";" ); |
134 | return out; |
135 | }; |
136 | |
137 | auto combineInt = [&combineStr](std::string id, int int1, int int2) { |
138 | return combineStr(id, std::to_string(int1), std::to_string(int2)); |
139 | }; |
140 | |
141 | if (conv_param_.IC != test_conv_p.IC) { |
142 | msg += combineInt("input_channels" , conv_param_.IC, test_conv_p.IC); |
143 | } |
144 | if (conv_param_.OC != test_conv_p.OC) { |
145 | msg += combineInt("output_channels" , conv_param_.IC, test_conv_p.IC); |
146 | } |
147 | if (conv_param_.G != test_conv_p.G) { |
148 | msg += combineInt("groups" , conv_param_.G, test_conv_p.G); |
149 | } |
150 | |
151 | if (!std::equal( |
152 | conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) { |
153 | msg += combineStr( |
154 | "kernel" , |
155 | arrayToString<SPATIAL_DIM>(conv_param_.K), |
156 | arrayToString<SPATIAL_DIM>(test_conv_p.K)); |
157 | } |
158 | |
159 | if (!std::equal( |
160 | conv_param_.stride.begin(), |
161 | conv_param_.stride.end(), |
162 | test_conv_p.stride.begin())) { |
163 | msg += combineStr( |
164 | "stride" , |
165 | arrayToString<SPATIAL_DIM>(conv_param_.stride), |
166 | arrayToString<SPATIAL_DIM>(test_conv_p.stride)); |
167 | } |
168 | |
169 | if (!std::equal( |
170 | conv_param_.pad.begin(), |
171 | conv_param_.pad.end(), |
172 | test_conv_p.pad.begin())) { |
173 | msg += combineStr( |
174 | "pad" , |
175 | arrayToString<2 * SPATIAL_DIM>(conv_param_.pad), |
176 | arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad)); |
177 | } |
178 | |
179 | if (!std::equal( |
180 | conv_param_.dilation.begin(), |
181 | conv_param_.dilation.end(), |
182 | test_conv_p.dilation.begin())) { |
183 | msg += combineStr( |
184 | "dilation" , |
185 | arrayToString<SPATIAL_DIM>(conv_param_.dilation), |
186 | arrayToString<SPATIAL_DIM>(test_conv_p.dilation)); |
187 | } |
188 | |
189 | return msg; |
190 | } |
191 | |
// Explicit instantiations: int8_t weights with int32_t accumulation for
// 1D, 2D and 3D convolutions.
template class PackWeightsForConv<1, int8_t, int32_t>;
template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;
195 | |
196 | } // namespace fbgemm |
197 | |