1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include <algorithm> |
9 | #include <functional> |
10 | #include <numeric> |
11 | #include <stdexcept> // for logic_error |
12 | #include <vector> |
13 | #include "fbgemm/Fbgemm.h" |
14 | |
15 | namespace fbgemm { |
16 | |
17 | template <int SPATIAL_DIM, typename ACC_T> |
18 | bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { |
19 | // Note: Depthwise convolutions (both 2D and 3D) are optimized for the most |
20 | // common case. |
21 | // 3x3 or 5x5 2D |
22 | // (3 or 5)x(3x3 or 5x5) 3D |
23 | bool ret = std::is_same<ACC_T, std::int32_t>::value && |
24 | conv_p.G == conv_p.IC && |
25 | (conv_p.G == conv_p.OC || conv_p.G * 2 == conv_p.OC) && |
26 | conv_p.G % 8 == 0 && |
27 | std::all_of( |
28 | conv_p.stride.begin(), |
29 | conv_p.stride.end(), |
30 | [](int i) { return i == 1 || i == 2; }) && |
31 | SPATIAL_DIM >= 2 && |
32 | conv_p.K[SPATIAL_DIM - 2] == conv_p.K[SPATIAL_DIM - 1] && |
33 | std::all_of( |
34 | conv_p.K.begin(), |
35 | conv_p.K.end(), |
36 | [](int i) { return i == 3 || i == 5 || i == 7; }) && |
37 | std::all_of( |
38 | conv_p.dilation.begin(), |
39 | conv_p.dilation.end(), |
40 | [](int i) { return i == 1; }) && |
41 | !conv_p.transposed; |
42 | |
43 | // Check pads result in same input and output spatial dim |
44 | for (int i = 0; i < SPATIAL_DIM; ++i) { |
45 | if (conv_p.pad[i] != (conv_p.K[i] - 1) / 2 || |
46 | conv_p.pad[i] != conv_p.pad[SPATIAL_DIM + i]) { |
47 | ret = false; |
48 | } |
49 | } |
50 | |
51 | return ret; |
52 | } |
53 | |
54 | template <int SPATIAL_DIM> |
55 | bool takePointWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { |
56 | return std::accumulate(conv_p.K.begin(), conv_p.K.end(), 0) == SPATIAL_DIM && |
57 | std::accumulate(conv_p.stride.begin(), conv_p.stride.end(), 0) == |
58 | SPATIAL_DIM && |
59 | std::accumulate(conv_p.dilation.begin(), conv_p.dilation.end(), 0) == |
60 | SPATIAL_DIM && |
61 | std::accumulate(conv_p.pad.begin(), conv_p.pad.end(), 0) == 0 && |
62 | !conv_p.transposed; |
63 | } |
64 | |
65 | template <int SPATIAL_DIM> |
66 | bool take1DFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { |
67 | return false && !conv_p.transposed; |
68 | } |
69 | |
// Predicate for the specialized direct-convolution path.
//
// Note: Direct convolutions (2D) are optimized for
// filter size: 2 x 1 to 2 x 6, transposed conv,
// in_channel % 8 == 0, out_channel % 8 == 0
// stride = 1 or 2
// padding = 0 ( non-zero padding will be supported soon)
template <int SPATIAL_DIM, typename ACC_T>
bool takeDirectConvPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
  // Requires int32 accumulation, a transposed conv with a single group,
  // channel counts divisible by 8, stride 1 or 2, a 2 x (<=6) kernel, and
  // unit dilation.
  bool ret = std::is_same<ACC_T, std::int32_t>::value && conv_p.transposed &&
      conv_p.G == 1 && conv_p.IC % 8 == 0 && conv_p.OC % 8 == 0 &&
      std::all_of(
          conv_p.stride.begin(),
          conv_p.stride.end(),
          [](int i) { return i == 1 || i == 2; }) &&
      SPATIAL_DIM == 2 && conv_p.K[SPATIAL_DIM - 2] == 2 &&
      conv_p.K[SPATIAL_DIM - 1] <= 6 &&
      std::all_of(conv_p.dilation.begin(), conv_p.dilation.end(), [](int i) {
        return i == 1;
      });

  // Check pads: zero padding
  for (int i = 0; i < SPATIAL_DIM; ++i) {
    if (conv_p.pad[i] != 0) {
      ret = false;
    }
  }
  // NOTE(review): the direct-conv path is unconditionally disabled here --
  // `ret` is overwritten regardless of the checks above, so ConvFastPath
  // never returns optimized_conv_t::directconv. Presumably a deliberate
  // kill switch; confirm intent before removing this line.
  ret = false;
  return ret;
}
98 | |
99 | template <int SPATIAL_DIM, typename ACC_T> |
100 | optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { |
101 | if (takeDepthWiseFastPath<SPATIAL_DIM, ACC_T>(conv_p)) { |
102 | return optimized_conv_t::depthwise; |
103 | } else if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_p)) { |
104 | return optimized_conv_t::groupwise; |
105 | } else if (takePointWiseFastPath<SPATIAL_DIM>(conv_p)) { |
106 | return optimized_conv_t::pointwise; |
107 | } else if (takeDirectConvPath<SPATIAL_DIM, ACC_T>(conv_p)) { |
108 | return optimized_conv_t::directconv; |
109 | } else if (take1DFastPath<SPATIAL_DIM>(conv_p)) { |
110 | return optimized_conv_t::fastpath1d; |
111 | } else { |
112 | return optimized_conv_t::im2col; |
113 | } |
114 | } |
115 | |
/*
 * Runs a quantized convolution with pre-packed weights, dispatching to the
 * fastest implementation selected by ConvFastPath (depthwise, groupwise,
 * pointwise, direct, or the generic im2col path).
 *
 * @param conv_p convolution parameters; must match the parameters the
 *        weights were packed with (verified below).
 * @param activations uint8 input activations.
 * @param packed_weights weights pre-packed for whichever path applies.
 * @param out output buffer; element type comes from processOutputType.
 * @param outBuffer int32 scratch buffer for intermediate accumulation.
 * @param outProcess output pipeline (zero points, multipliers, bias,
 *        column offsets, optional relu fusion).
 * @param thread_id / num_threads parallelization: this call computes the
 *        shard belonging to thread_id out of num_threads.
 * @param blocking_params optional cache-blocking overrides.
 * @return 0 on success.
 * @throws std::logic_error if conv_p does not match the packing parameters.
 * @throws std::runtime_error for unsupported quantization granularity or
 *         spatial dimension.
 */
template <typename processOutputType, int SPATIAL_DIM, typename ACC_T>
int fbgemmConv(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const std::uint8_t* activations,
    PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights,
    typename processOutputType::outType* out,
    std::int32_t* outBuffer,
    processOutputType& outProcess,
    int thread_id,
    int num_threads,
    const BlockingFactors* blocking_params) {
  // Weights must have been packed with the exact same conv parameters;
  // otherwise the packed layout disagrees with this invocation.
  if (!packed_weights.isPackingCompliant(conv_p)) {
    std::string msg =
        "[FBGEMM_CONV_ERROR] Convolution parameters "
        "mismatch between pre-packed weights and conv invocation! ";
    msg += packed_weights.mismatchingParams(conv_p);
    msg += std::string(
        " Please pack weights using the same parameters "
        "with which convolution operation is invoked!");
    throw std::logic_error(msg);
  }

  switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
    case optimized_conv_t::depthwise: {
      // 2D and 3D depthwise fast path
      // std::cout << "Depthwise fast path" << std::endl;
      // Quantization parameters pulled once from the output pipeline and
      // forwarded to the standalone depthwise kernels.
      const std::int32_t* B_zero_point = outProcess.getBZeroPoint();
      const float* C_multiplier = outProcess.getCMultiplier();
      const float* act_times_w_scale = outProcess.getActWScale();
      if (SPATIAL_DIM == 3) {
        static_assert(
            std::is_same<typename processOutputType::outType, std::uint8_t>::
                value,
            "For depthwise, only requantized output is supported");

        // Dispatch on the quantization granularity (a compile-time property
        // of processOutputType) to the matching kernel instantiation.
        if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
          depthwise_3d_same_pad<QuantizationGranularity::TENSOR>(
              // Safe: takeDepthWiseFastPath only admits SPATIAL_DIM >= 2,
              // and this branch runs only when SPATIAL_DIM == 3.
              *reinterpret_cast<const conv_param_t<3>*>(&conv_p),
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale,
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType == QuantizationGranularity::GROUP) {
          depthwise_3d_same_pad<QuantizationGranularity::GROUP>(
              *reinterpret_cast<const conv_param_t<3>*>(&conv_p),
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType ==
            QuantizationGranularity::OUT_CHANNEL) {
          depthwise_3d_same_pad<QuantizationGranularity::OUT_CHANNEL>(
              *reinterpret_cast<const conv_param_t<3>*>(&conv_p),
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else {
          std::string msg =
              "[FBGEMM_CONV_ERROR] This quantization granularity is "
              "not supported";
          throw std::runtime_error(msg);
        }
      } else if (SPATIAL_DIM == 2) {
        // 2D depthwise: same granularity dispatch; conv dims are passed as
        // scalars instead of a conv_param_t.
        if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
          depthwise_2d_same_pad<QuantizationGranularity::TENSOR>(
              conv_p.MB, // mini batch
              conv_p.IN_DIM[0], // H
              conv_p.IN_DIM[1], // W
              conv_p.IC, // input channels
              conv_p.OC, // output channels
              conv_p.stride[0], // stride_h
              conv_p.stride[1], // stride_w
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale,
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType == QuantizationGranularity::GROUP) {
          depthwise_2d_same_pad<QuantizationGranularity::GROUP>(
              conv_p.MB, // mini batch
              conv_p.IN_DIM[0], // H
              conv_p.IN_DIM[1], // W
              conv_p.IC, // input channels
              conv_p.OC, // output channels
              conv_p.stride[0], // stride_h
              conv_p.stride[1], // stride_w
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else if (
            processOutputType::QGRANType ==
            QuantizationGranularity::OUT_CHANNEL) {
          // The number of input channels == groups for depthwise convolutions
          depthwise_2d_same_pad<QuantizationGranularity::OUT_CHANNEL>(
              conv_p.MB, // mini batch
              conv_p.IN_DIM[0], // H
              conv_p.IN_DIM[1], // W
              conv_p.IC, // input channels
              conv_p.OC, // output channels
              conv_p.stride[0], // stride_h
              conv_p.stride[1], // stride_w
              outProcess.getAZeroPoint(),
              activations,
              B_zero_point,
              *(packed_weights.getPackedWForDepthwise()),
              C_multiplier,
              outProcess.getCZeroPoint(),
              out,
              outProcess.getColOffsets(),
              outProcess.getBias(),
              outProcess.RELU_FUSED, // fuse_relu
              act_times_w_scale, // act_scale * weight_scale
              thread_id,
              num_threads);
        } else {
          std::string msg =
              "[FBGEMM_CONV_ERROR] This quantization granularity is "
              "not supported";
          throw std::runtime_error(msg);
        }
      } else {
        // takeDepthWiseFastPath admits SPATIAL_DIM >= 2 only, so reaching
        // here means an unexpected instantiation.
        std::string msg =
            "[FBGEMM_CONV_ERROR] This spatial dim is not supported";
        throw std::runtime_error(msg);
      }
      break;
    }
    case optimized_conv_t::groupwise: {
      // optimized groupwise convolution
      // std::cout << "Groupwise fast path" << std::endl;
      // Row offsets are computed into a local buffer and handed to the
      // output pipeline for zero-point corrections.
      std::vector<int32_t> row_offset_buf(
          rowOffsetBufferSizeGConv<SPATIAL_DIM>(conv_p));
      outProcess.setRowOffsets(row_offset_buf.data());
      fbgemmGroupwiseConv(
          conv_p,
          activations,
          outProcess.getAZeroPoint(),
          row_offset_buf.data(),
          *(packed_weights.getPackedWForGroupwise()),
          out,
          outBuffer,
          outProcess,
          thread_id,
          num_threads);
      break;
    }
    case optimized_conv_t::pointwise: {
      // 1x1 convolution is just a GEMM over the flattened spatial dims.
      std::vector<int32_t> row_offset_buf(
          PackAWithRowOffset<uint8_t>::rowOffsetBufferSize(blocking_params));
      // Total number of spatial positions per image (product of IN_DIMs).
      int image_dim = std::accumulate(
          conv_p.IN_DIM.begin(),
          conv_p.IN_DIM.end(),
          1,
          std::multiplies<int>());
      // Pack activations as an (MB * image_dim) x IC matrix.
      PackAWithRowOffset<uint8_t, ACC_T> packA(
          matrix_op_t::NoTranspose,
          conv_p.MB * image_dim,
          conv_p.IC,
          activations,
          conv_p.IC,
          nullptr,
          conv_p.G,
          row_offset_buf.data(),
          blocking_params);

      outProcess.setRowOffsets(row_offset_buf.data());
      fbgemmPacked(
          packA,
          *(packed_weights.getPackedWForPointwise()),
          out,
          outBuffer,
          conv_p.OC,
          outProcess,
          thread_id,
          num_threads,
          blocking_params);
      break;
    }
    case optimized_conv_t::directconv: {
      // specialized direct convolution path
      // std::cout << "Directconv fast path" << std::endl;
      // NOTE(review): currently unreachable -- takeDirectConvPath always
      // returns false (see its trailing `ret = false;`).
      fbgemmDirectConv<SPATIAL_DIM, processOutputType::QGRANType>(
          conv_p,
          // Aint8,
          activations,
          *(packed_weights.getPackedWForDirectconv()),
          out,
          outBuffer,
          outProcess,
          outProcess.getBias(),
          thread_id,
          num_threads);
      break;
    }
    case optimized_conv_t::fastpath1d: {
      // Intentional no-op placeholder: take1DFastPath always returns false,
      // so this case is never selected.
      break;
    }
    case optimized_conv_t::im2col: {
      // All other convolutions go through im2col-based implementation
      // std::cout << "Im2col path" << std::endl;
      std::vector<int32_t> row_offset_buf(
          PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize(
              blocking_params));

      // If all weight zero points are zero the row-offset correction can be
      // skipped entirely; how many zero points to inspect depends on the
      // quantization granularity.
      const std::int32_t* b_zero_point = outProcess.getBZeroPoint();
      bool b_symmetric = false;
      if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
        b_symmetric = b_zero_point[0] == 0;
      } else if (
          processOutputType::QGRANType == QuantizationGranularity::GROUP) {
        b_symmetric =
            std::all_of(b_zero_point, b_zero_point + conv_p.G, [](int i) {
              return i == 0;
            });
      } else if (
          processOutputType::QGRANType ==
          QuantizationGranularity::OUT_CHANNEL) {
        b_symmetric =
            std::all_of(b_zero_point, b_zero_point + conv_p.OC, [](int i) {
              return i == 0;
            });
      } else {
        std::string msg =
            "[FBGEMM_CONV_ERROR] This quantization granularity is "
            "not supported";
        throw std::runtime_error(msg);
      }
      PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> packA(
          conv_p,
          activations,
          nullptr, /* buffer for packed matrix */
          outProcess.getAZeroPoint(),
          row_offset_buf.data(),
          b_symmetric,
          blocking_params);

      outProcess.setRowOffsets(row_offset_buf.data());
      fbgemmPacked(
          packA,
          *(packed_weights.getPackedWForIm2col()),
          out,
          outBuffer,
          conv_p.OC,
          outProcess,
          thread_id,
          num_threads,
          blocking_params);
      break;
    }
  } // switch

  return 0;
}
418 | |
// Explicit instantiations of fbgemmConv for every supported combination of
// accumulation type, quantization granularity, relu fusion, spatial
// dimension, and bias type. Built up by the nested macros below; comments
// stay outside the macro bodies to keep the line continuations intact.

// Single instantiation: one (ACC_T, Q_GRAN, RELU, SPATIAL_DIM, BIAS_TYPE).
#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, BIAS_TYPE) \
  template FBGEMM_API int fbgemmConv( \
      const conv_param_t<SPATIAL_DIM>& conv_p, \
      const std::uint8_t* activations, \
      PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, \
      std::uint8_t* out, \
      std::int32_t* outBuffer, \
      ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
      int thread_id, \
      int num_threads, \
      const BlockingFactors* blocking_params);

// Expand over both supported bias types.
#define INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \
  INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, float) \
  INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM, int32_t)

// Expand over 1D, 2D, and 3D convolutions.
#define INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, RELU) \
  INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 1) \
  INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 2) \
  INSTANTIATE_BIAS_T(ACC_T, Q_GRAN, RELU, 3)

// Expand over relu-fused and plain requantization.
#define INSTANTIATE_RELU(ACC_T, Q_GRAN) \
  INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, true) \
  INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, false)

// Expand over all quantization granularities.
#define INSTANTIATE_Q_GRANS(ACC_T) \
  INSTANTIATE_RELU(ACC_T, QuantizationGranularity::TENSOR) \
  INSTANTIATE_RELU(ACC_T, QuantizationGranularity::GROUP) \
  INSTANTIATE_RELU(ACC_T, QuantizationGranularity::OUT_CHANNEL)

INSTANTIATE_Q_GRANS(std::int32_t)

#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_RELU
#undef INSTANTIATE_SPATIAL_DIM
#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
456 | |
// Explicit instantiations of the fast-path predicates and the dispatcher
// for the spatial dims / accumulation types used by callers.

template bool takeDepthWiseFastPath<2, std::int32_t>(
    const conv_param_t<2>& conv_p);
template bool takeDepthWiseFastPath<3, std::int32_t>(
    const conv_param_t<3>& conv_p);
template bool takeDepthWiseFastPath<2, std::int16_t>(
    const conv_param_t<2>& conv_p);
template bool takeDepthWiseFastPath<3, std::int16_t>(
    const conv_param_t<3>& conv_p);

template bool takeDirectConvPath<2, std::int32_t>(
    const conv_param_t<2>& conv_p);
template bool takeDirectConvPath<3, std::int32_t>(
    const conv_param_t<3>& conv_p);
template bool takeDirectConvPath<2, std::int16_t>(
    const conv_param_t<2>& conv_p);
template bool takeDirectConvPath<3, std::int16_t>(
    const conv_param_t<3>& conv_p);

template FBGEMM_API optimized_conv_t
ConvFastPath<1, std::int32_t>(const conv_param_t<1>& conv_p);
template FBGEMM_API optimized_conv_t
ConvFastPath<2, std::int32_t>(const conv_param_t<2>& conv_p);
template FBGEMM_API optimized_conv_t
ConvFastPath<3, std::int32_t>(const conv_param_t<3>& conv_p);

template FBGEMM_API optimized_conv_t
ConvFastPath<1, std::int16_t>(const conv_param_t<1>& conv_p);
template FBGEMM_API optimized_conv_t
ConvFastPath<2, std::int16_t>(const conv_param_t<2>& conv_p);
template FBGEMM_API optimized_conv_t
ConvFastPath<3, std::int16_t>(const conv_param_t<3>& conv_p);
488 | |
489 | } // namespace fbgemm |
490 | |