/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include <algorithm>
#include <functional>
#include <numeric>
#include <stdexcept> // for logic_error
#include <vector>
#include "fbgemm/Fbgemm.h"

namespace fbgemm {

17template <int SPATIAL_DIM, typename ACC_T>
18bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
19 // Note: Depthwise convolutions (both 2D and 3D) are optimized for the most
20 // common case.
21 // 3x3 or 5x5 2D
22 // (3 or 5)x(3x3 or 5x5) 3D
23 bool ret = std::is_same<ACC_T, std::int32_t>::value &&
24 conv_p.G == conv_p.IC &&
25 (conv_p.G == conv_p.OC || conv_p.G * 2 == conv_p.OC) &&
26 conv_p.G % 8 == 0 &&
27 std::all_of(
28 conv_p.stride.begin(),
29 conv_p.stride.end(),
30 [](int i) { return i == 1 || i == 2; }) &&
31 SPATIAL_DIM >= 2 &&
32 conv_p.K[SPATIAL_DIM - 2] == conv_p.K[SPATIAL_DIM - 1] &&
33 std::all_of(
34 conv_p.K.begin(),
35 conv_p.K.end(),
36 [](int i) { return i == 3 || i == 5 || i == 7; }) &&
37 std::all_of(
38 conv_p.dilation.begin(),
39 conv_p.dilation.end(),
40 [](int i) { return i == 1; }) &&
41 !conv_p.transposed;
42
43 // Check pads result in same input and output spatial dim
44 for (int i = 0; i < SPATIAL_DIM; ++i) {
45 if (conv_p.pad[i] != (conv_p.K[i] - 1) / 2 ||
46 conv_p.pad[i] != conv_p.pad[SPATIAL_DIM + i]) {
47 ret = false;
48 }
49 }
50
51 return ret;
52}
53
54template <int SPATIAL_DIM>
55bool takePointWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
56 return std::accumulate(conv_p.K.begin(), conv_p.K.end(), 0) == SPATIAL_DIM &&
57 std::accumulate(conv_p.stride.begin(), conv_p.stride.end(), 0) ==
58 SPATIAL_DIM &&
59 std::accumulate(conv_p.dilation.begin(), conv_p.dilation.end(), 0) ==
60 SPATIAL_DIM &&
61 std::accumulate(conv_p.pad.begin(), conv_p.pad.end(), 0) == 0 &&
62 !conv_p.transposed;
63}
64
65template <int SPATIAL_DIM>
66bool take1DFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
67 return false && !conv_p.transposed;
68}
69
/// Decide whether conv_p could be served by the specialized direct-conv
/// kernels. NOTE(review): the path is currently force-disabled at the end of
/// this function (see the unconditional `ret = false;` below), so it always
/// returns false regardless of the checks above it.
template <int SPATIAL_DIM, typename ACC_T>
bool takeDirectConvPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
  // Note: Direct convolutions (2D) are optimized for
  // filter size: 2 x 1 to 2 x 6, transposed conv,
  // in_channel % 8 == 0, out_channel % 8 == 0
  // stride = 1 or 2
  // padding = 0 ( non-zero padding will be supported soon)
  bool ret = std::is_same<ACC_T, std::int32_t>::value && conv_p.transposed &&
      conv_p.G == 1 && conv_p.IC % 8 == 0 && conv_p.OC % 8 == 0 &&
      std::all_of(
          conv_p.stride.begin(),
          conv_p.stride.end(),
          [](int i) { return i == 1 || i == 2; }) &&
      SPATIAL_DIM == 2 && conv_p.K[SPATIAL_DIM - 2] == 2 &&
      conv_p.K[SPATIAL_DIM - 1] <= 6 &&
      std::all_of(conv_p.dilation.begin(), conv_p.dilation.end(), [](int i) {
        return i == 1;
      });

  // Check pads: zero padding
  for (int i = 0; i < SPATIAL_DIM; ++i) {
    if (conv_p.pad[i] != 0) {
      ret = false;
    }
  }
  // NOTE(review): this override disables the direct-conv path entirely.
  // It looks intentional (a kill-switch while the path matures) rather than
  // a leftover debug line — confirm before removing.
  ret = false;
  return ret;
}
98
99template <int SPATIAL_DIM, typename ACC_T>
100optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
101 if (takeDepthWiseFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
102 return optimized_conv_t::depthwise;
103 } else if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_p)) {
104 return optimized_conv_t::groupwise;
105 } else if (takePointWiseFastPath<SPATIAL_DIM>(conv_p)) {
106 return optimized_conv_t::pointwise;
107 } else if (takeDirectConvPath<SPATIAL_DIM, ACC_T>(conv_p)) {
108 return optimized_conv_t::directconv;
109 } else if (take1DFastPath<SPATIAL_DIM>(conv_p)) {
110 return optimized_conv_t::fastpath1d;
111 } else {
112 return optimized_conv_t::im2col;
113 }
114}
115
/**
 * @brief Top-level quantized convolution: validates that the pre-packed
 *        weights match conv_p, selects a specialized kernel via ConvFastPath,
 *        and dispatches to it.
 *
 * @tparam processOutputType output pipeline type (e.g. ReQuantizeOutput);
 *         supplies quantization parameters and the output element type
 * @tparam SPATIAL_DIM number of spatial dimensions
 * @tparam ACC_T accumulation type
 *
 * @param conv_p convolution shape/stride/pad/dilation parameters
 * @param activations uint8 input activations
 * @param packed_weights weights pre-packed for one of the conv paths
 * @param out output buffer (element type from processOutputType::outType)
 * @param outBuffer int32 scratch buffer used by the GEMM-based paths
 * @param outProcess output pipeline instance (zero points, multipliers, ...)
 * @param thread_id id of the calling thread in [0, num_threads)
 * @param num_threads total number of threads cooperating on this conv
 * @param blocking_params optional cache-blocking overrides
 * @return 0 on success
 * @throws std::logic_error if packed_weights were packed with parameters
 *         that do not match conv_p
 * @throws std::runtime_error on the depthwise path for an unsupported
 *         quantization granularity or spatial dimension
 */
template <typename processOutputType, int SPATIAL_DIM, typename ACC_T>
int fbgemmConv(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const std::uint8_t* activations,
    PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights,
    typename processOutputType::outType* out,
    std::int32_t* outBuffer,
    processOutputType& outProcess,
    int thread_id,
    int num_threads,
    const BlockingFactors* blocking_params) {
  // Refuse to run if the weights were packed for different conv parameters;
  // the packed layout depends on them, so a mismatch would be silent garbage.
  if (!packed_weights.isPackingCompliant(conv_p)) {
    std::string msg =
        "[FBGEMM_CONV_ERROR] Convolution parameters "
        "mismatch between pre-packed weights and conv invocation! ";
    msg += packed_weights.mismatchingParams(conv_p);
    msg += std::string(
        " Please pack weights using the same parameters "
        "with which convolution operation is invoked!");
    throw std::logic_error(msg);
  }

  switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
    case optimized_conv_t::depthwise: {
      // 2D and 3D depthwise fast path
      // std::cout << "Depthwise fast path" << std::endl;
      const std::int32_t* B_zero_point = outProcess.getBZeroPoint();
      const float* C_multiplier = outProcess.getCMultiplier();
      const float* act_times_w_scale = outProcess.getActWScale();
      // Runtime branch on the compile-time constant SPATIAL_DIM; both arms
      // are instantiated, which is why the 3D arm's static_assert applies to
      // every instantiation that reaches this case.
      if (SPATIAL_DIM == 3) {
        static_assert(
            std::is_same<typename processOutputType::outType, std::uint8_t>::
                value,
            "For depthwise, only requantized output is supported");

        // The depthwise kernels are templated on quantization granularity,
        // so dispatch on QGRANType here.
        if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
          depthwise_3d_same_pad<QuantizationGranularity::TENSOR>(
              // Cast is safe: this arm only executes when SPATIAL_DIM == 3,
              // so conv_p really is a conv_param_t<3>.
              *reinterpret_cast<const conv_param_t<3>*>(&conv_p),
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale,
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType == QuantizationGranularity::GROUP) {
          depthwise_3d_same_pad<QuantizationGranularity::GROUP>(
              *reinterpret_cast<const conv_param_t<3>*>(&conv_p),
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType ==
            QuantizationGranularity::OUT_CHANNEL) {
          depthwise_3d_same_pad<QuantizationGranularity::OUT_CHANNEL>(
              *reinterpret_cast<const conv_param_t<3>*>(&conv_p),
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else {
          std::string msg =
              "[FBGEMM_CONV_ERROR] This quantization granularity is "
              "not supported";
          throw std::runtime_error(msg);
        }
      } else if (SPATIAL_DIM == 2) {
        // Same granularity dispatch as above, but the 2D kernel takes the
        // shape parameters individually instead of a conv_param_t.
        if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
          depthwise_2d_same_pad<QuantizationGranularity::TENSOR>(
              conv_p.MB, // mini batch
              conv_p.IN_DIM[0], // H
              conv_p.IN_DIM[1], // W
              conv_p.IC, // input channels
              conv_p.OC, // output channels
              conv_p.stride[0], // stride_h
              conv_p.stride[1], // stride_w
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale,
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType == QuantizationGranularity::GROUP) {
          depthwise_2d_same_pad<QuantizationGranularity::GROUP>(
              conv_p.MB, // mini batch
              conv_p.IN_DIM[0], // H
              conv_p.IN_DIM[1], // W
              conv_p.IC, // input channels
              conv_p.OC, // output channels
              conv_p.stride[0], // stride_h
              conv_p.stride[1], // stride_w
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType ==
            QuantizationGranularity::OUT_CHANNEL) {
          // The number of input channels == groups for depthwise convolutions
          depthwise_2d_same_pad<QuantizationGranularity::OUT_CHANNEL>(
              conv_p.MB, // mini batch
              conv_p.IN_DIM[0], // H
              conv_p.IN_DIM[1], // W
              conv_p.IC, // input channels
              conv_p.OC, // output channels
              conv_p.stride[0], // stride_h
              conv_p.stride[1], // stride_w
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else {
          std::string msg =
              "[FBGEMM_CONV_ERROR] This quantization granularity is "
              "not supported";
          throw std::runtime_error(msg);
        }
      } else {
        // takeDepthWiseFastPath only accepts SPATIAL_DIM >= 2, so reaching
        // here with another dim indicates an internal inconsistency.
        std::string msg =
            "[FBGEMM_CONV_ERROR] This spatial dim is not supported";
        throw std::runtime_error(msg);
      }
      break;
    }
    case optimized_conv_t::groupwise: {
      // optimized groupwise convolution
      // std::cout << "Groupwise fast path" << std::endl;
      // Row offsets (per-row activation sums) are needed by the output
      // pipeline for zero-point compensation; allocate a scratch buffer.
      std::vector<int32_t> row_offset_buf(
          rowOffsetBufferSizeGConv<SPATIAL_DIM>(conv_p));
      outProcess.setRowOffsets(row_offset_buf.data());
      fbgemmGroupwiseConv(
          conv_p,
          activations,
          outProcess.getAZeroPoint(),
          row_offset_buf.data(),
          *(packed_weights.getPackedWForGroupwise()),
          out,
          outBuffer,
          outProcess,
          thread_id,
          num_threads);
      break;
    }
    case optimized_conv_t::pointwise: {
      // 1x1 conv is a plain GEMM: pack A as a (MB * spatial) x IC matrix.
      std::vector<int32_t> row_offset_buf(
          PackAWithRowOffset<uint8_t>::rowOffsetBufferSize(blocking_params));
      // Total number of spatial positions per image.
      int image_dim = std::accumulate(
          conv_p.IN_DIM.begin(),
          conv_p.IN_DIM.end(),
          1,
          std::multiplies<int>());
      PackAWithRowOffset<uint8_t, ACC_T> packA(
          matrix_op_t::NoTranspose,
          conv_p.MB * image_dim,
          conv_p.IC,
          activations,
          conv_p.IC,
          nullptr,
          conv_p.G,
          row_offset_buf.data(),
          blocking_params);

      outProcess.setRowOffsets(row_offset_buf.data());
      fbgemmPacked(
          packA,
          *(packed_weights.getPackedWForPointwise()),
          out,
          outBuffer,
          conv_p.OC,
          outProcess,
          thread_id,
          num_threads,
          blocking_params);
      break;
    }
    case optimized_conv_t::directconv: {
      // specialized direct convolution path
      // std::cout << "Directconv fast path" << std::endl;
      // NOTE(review): currently unreachable — takeDirectConvPath always
      // returns false (see its trailing `ret = false;`).
      fbgemmDirectConv<SPATIAL_DIM, processOutputType::QGRANType>(
          conv_p,
          // Aint8,
          activations,
          *(packed_weights.getPackedWForDirectconv()),
          out,
          outBuffer,
          outProcess,
          outProcess.getBias(),
          thread_id,
          num_threads);
      break;
    }
    case optimized_conv_t::fastpath1d: {
      // Never selected: take1DFastPath always returns false. Placeholder.
      break;
    }
    case optimized_conv_t::im2col: {
      // All other convolutions go through im2col-based implementation
      // std::cout << "Im2col path" << std::endl;
      std::vector<int32_t> row_offset_buf(
          PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize(
              blocking_params));

      // B is "symmetric" when every weight zero point is zero; the packer is
      // told so (presumably to skip row-offset work — confirm in
      // PackAWithIm2Col). The number of zero points to inspect depends on
      // the quantization granularity.
      const std::int32_t* b_zero_point = outProcess.getBZeroPoint();
      bool b_symmetric = false;
      if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
        b_symmetric = b_zero_point[0] == 0;
      } else if (
          processOutputType::QGRANType == QuantizationGranularity::GROUP) {
        b_symmetric =
            std::all_of(b_zero_point, b_zero_point + conv_p.G, [](int i) {
              return i == 0;
            });
      } else if (
          processOutputType::QGRANType ==
          QuantizationGranularity::OUT_CHANNEL) {
        b_symmetric =
            std::all_of(b_zero_point, b_zero_point + conv_p.OC, [](int i) {
              return i == 0;
            });
      } else {
        std::string msg =
            "[FBGEMM_CONV_ERROR] This quantization granularity is "
            "not supported";
        throw std::runtime_error(msg);
      }
      PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> packA(
          conv_p,
          activations,
          nullptr, /* buffer for packed matrix */
          outProcess.getAZeroPoint(),
          row_offset_buf.data(),
          b_symmetric,
          blocking_params);

      outProcess.setRowOffsets(row_offset_buf.data());
      fbgemmPacked(
          packA,
          *(packed_weights.getPackedWForIm2col()),
          out,
          outBuffer,
          conv_p.OC,
          outProcess,
          thread_id,
          num_threads,
          blocking_params);
      break;
    }
  } // switch

  return 0;
}
418
// ---------------------------------------------------------------------------
// Explicit instantiation machinery: stamp out fbgemmConv for every exported
// combination of (accumulation type, quantization granularity, relu,
// spatial dim, bias type). Only requantized (uint8) output is instantiated.
// ---------------------------------------------------------------------------
#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, BIAS_TYPE) \
  template FBGEMM_API int fbgemmConv( \
      const conv_param_t<SPATIAL_DIM>& conv_p, \
      const std::uint8_t* activations, \
      PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, \
      std::uint8_t* out, \
      std::int32_t* outBuffer, \
      ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
      int thread_id, \
      int num_threads, \
      const BlockingFactors* blocking_params);

// Both supported bias types.
#define INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \
  INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, float) \
  INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, int32_t)

// 1D, 2D, and 3D convolutions.
#define INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, RELU) \
  INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 1) \
  INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 2) \
  INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 3)

// With and without fused ReLU.
#define INSTANTIATE_RELU(ACC_T, Q_GRAN) \
  INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, true) \
  INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, false)

// All three quantization granularities.
#define INSTANTIATE_Q_GRANS(ACC_T) \
  INSTANTIATE_RELU(ACC_T, QuantizationGranularity::TENSOR) \
  INSTANTIATE_RELU(ACC_T, QuantizationGranularity::GROUP) \
  INSTANTIATE_RELU(ACC_T, QuantizationGranularity::OUT_CHANNEL)

INSTANTIATE_Q_GRANS(std::int32_t)

#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_RELU
#undef INSTANTIATE_SPATIAL_DIM
#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE

// Explicit instantiations of the path predicates for both accumulation
// types so out-of-TU callers can link against them.
template bool takeDepthWiseFastPath<2, std::int32_t>(
    const conv_param_t<2>& conv_p);
template bool takeDepthWiseFastPath<3, std::int32_t>(
    const conv_param_t<3>& conv_p);
template bool takeDepthWiseFastPath<2, std::int16_t>(
    const conv_param_t<2>& conv_p);
template bool takeDepthWiseFastPath<3, std::int16_t>(
    const conv_param_t<3>& conv_p);

template bool takeDirectConvPath<2, std::int32_t>(
    const conv_param_t<2>& conv_p);
template bool takeDirectConvPath<3, std::int32_t>(
    const conv_param_t<3>& conv_p);
template bool takeDirectConvPath<2, std::int16_t>(
    const conv_param_t<2>& conv_p);
template bool takeDirectConvPath<3, std::int16_t>(
    const conv_param_t<3>& conv_p);

// Path-selection entry point, exported for 1D/2D/3D and both acc types.
template FBGEMM_API optimized_conv_t
ConvFastPath<1, std::int32_t>(const conv_param_t<1>& conv_p);
template FBGEMM_API optimized_conv_t
ConvFastPath<2, std::int32_t>(const conv_param_t<2>& conv_p);
template FBGEMM_API optimized_conv_t
ConvFastPath<3, std::int32_t>(const conv_param_t<3>& conv_p);

template FBGEMM_API optimized_conv_t
ConvFastPath<1, std::int16_t>(const conv_param_t<1>& conv_p);
template FBGEMM_API optimized_conv_t
ConvFastPath<2, std::int16_t>(const conv_param_t<2>& conv_p);
template FBGEMM_API optimized_conv_t
ConvFastPath<3, std::int16_t>(const conv_param_t<3>& conv_p);

} // namespace fbgemm