1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmI8DepthwiseAvx2.h" |
9 | |
10 | #include <stdexcept> // for logic_error |
11 | #include <string> |
12 | |
13 | #include "./FbgemmI8DepthwiseAvx2-inl.h" |
14 | #include "./GenerateI8Depthwise.h" |
15 | #include "./MaskAvx2.h" |
16 | #include "fbgemm/Utils.h" |
17 | #include "fbgemm/UtilsAvx2.h" |
18 | |
19 | using namespace std; |
20 | |
21 | namespace fbgemm { |
22 | |
// Computes ONE output pixel (t, h, w) of a 3D depthwise convolution with
// "same" padding: runs a JIT-generated int8 kernel that accumulates the
// OC-channel result into C_int32 (and, unless B is symmetric, the input row
// sums into row_offsets), then requantizes that pixel into C_uint8.
//
// A caller that has already generated a suitable kernel can pass it via
// pregenerated_kernel to skip the getOrCreate lookup.
template <
    bool FUSE_RELU,
    bool HAS_BIAS,
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    typename BIAS_TYPE>
static ALWAYS_INLINE void depthwise_3d_kernel_(
    int T,
    int H,
    int W,
    int IC,
    int OC,
    int t,
    int h,
    int w,
    array<int, 3> F,
    int stride_t,
    int stride_h,
    int stride_w,
    int32_t A_zero_point,
    const uint8_t* A,
    const int32_t* B_zero_point,
    const int8_t* Bp,
    const float* C_multiplier,
    int32_t C_zero_point,
    int32_t* C_int32,
    uint8_t* C_uint8,
    int32_t* row_offsets,
    const int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    GenI8Depthwise::jit_kernel_signature* pregenerated_kernel = nullptr) {
  int R = F[1], S = F[2];
  // "Same" padding: (K - 1) / 2 on each side of every spatial dimension.
  int PAD_P = (F[0] - 1) / 2, PAD_T = (F[1] - 1) / 2, PAD_B = PAD_T,
      PAD_L = (F[2] - 1) / 2, PAD_R = PAD_L;
  int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
  int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
  // Front/top/left input coordinate touched by this output pixel; may be
  // negative when the filter overlaps the padding.
  int t_in = -PAD_P + t * stride_t;
  int h_in = -PAD_T + h * stride_h;
  int w_in = -PAD_L + w * stride_w;

  // Channel count handled by the kernel's tail iteration (a full 32 when OC
  // is a multiple of 32).
  int remainder = OC % 32;
  if (remainder == 0) {
    remainder = 32;
  }

  // The *_skip arguments tell the generated kernel how many filter taps fall
  // into the padding on each side so it skips them; the A pointer passed
  // below is computed from the (possibly negative) t_in/h_in/w_in and relies
  // on those skips to avoid reading outside the input buffer.
  GenI8Depthwise::jit_kernel_signature kernel = pregenerated_kernel
      ? *pregenerated_kernel
      : GenI8Depthwise().getOrCreate(
            /*D=*/3,
            F,
            OC / IC,
            /*compute_a_sum=*/!B_SYMMETRIC,
            remainder,
            /*prev_skip=*/std::max(-t_in, 0),
            /*next_skip=*/std::max(t_in + F[0] - T, 0),
            /*top_skip=*/std::max(-h_in, 0),
            /*bottom_skip=*/std::max(h_in + F[1] - H, 0),
            /*left_skip=*/std::max(-w_in, 0),
            /*right_skip=*/std::max(w_in + F[2] - W, 0));
  kernel(
      A + ((t_in * H + h_in) * W + w_in) * IC,
      Bp,
      C_int32,
      // Row sums are only needed to compensate a non-zero B zero point.
      B_SYMMETRIC ? nullptr : row_offsets,
      H,
      W,
      IC,
      internal::avx2_ps_or_epi32_combined_mask,
      A_zero_point);

  // Requantize the int32 accumulators into the uint8 output at (t, h, w).
  // The last template argument is the depth multiplier OC / IC.
  if (OC == IC) {
    requantize_<FUSE_RELU, HAS_BIAS, Q_GRAN, A_SYMMETRIC, B_SYMMETRIC, 1>(
        A_zero_point,
        B_zero_point,
        C_multiplier,
        C_zero_point,
        C_int32,
        C_uint8 + ((t * H_OUT + h) * W_OUT + w) * OC,
        OC,
        row_offsets,
        col_offsets,
        bias,
        act_times_w_scale);
  } else {
    // Only a depth multiplier of 2 is supported besides 1.
    assert(OC / IC == 2);
    requantize_<FUSE_RELU, HAS_BIAS, Q_GRAN, A_SYMMETRIC, B_SYMMETRIC, 2>(
        A_zero_point,
        B_zero_point,
        C_multiplier,
        C_zero_point,
        C_int32,
        C_uint8 + ((t * H_OUT + h) * W_OUT + w) * OC,
        OC,
        row_offsets,
        col_offsets,
        bias,
        act_times_w_scale);
  }
}
124 | |
// 3D depthwise convolution with "same" padding over this thread's share of
// the output volume.  The output is decomposed into border regions (where
// the filter overlaps padding in T, H, or W) and interior regions, so that
// interior pixels can reuse a cached JIT kernel generated without edge
// handling, while border pixels fall back to the per-pixel kernel lookup
// inside depthwise_3d_kernel_.
template <
    bool FUSE_RELU,
    bool HAS_BIAS,
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    typename BIAS_TYPE>
static ALWAYS_INLINE void depthwise_3d_same_pad_(
    const conv_param_t<3>& conv_p,
    int32_t A_zero_point,
    const uint8_t* A,
    const int32_t* B_zero_point,
    const PackedDepthWiseConvMatrix& B,
    const float* C_multiplier,
    int32_t C_zero_point,
    int32_t* C_int32,
    uint8_t* C_uint8,
    const int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    int thread_id,
    int num_threads) {
  int N = conv_p.MB;
  int T = conv_p.IN_DIM[0];
  int H = conv_p.IN_DIM[1];
  int W = conv_p.IN_DIM[2];
  int IC = conv_p.IC;
  int OC = conv_p.OC;
  array<int, 3> F = conv_p.K;
  int stride_t = conv_p.stride[0];
  int stride_h = conv_p.stride[1];
  int stride_w = conv_p.stride[2];

  // The generated kernel processes input channels in groups of 8.
  assert(IC % 8 == 0);

  int K_T = F[0], K_H = F[1], K_W = F[2];
  // "Same" padding: (K - 1) / 2 on every side of each spatial dimension.
  int PAD_P = (F[0] - 1) / 2, PAD_N = PAD_P, PAD_T = (F[1] - 1) / 2,
      PAD_B = PAD_T, PAD_L = (F[2] - 1) / 2, PAD_R = PAD_L;
  int64_t T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
  int64_t H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
  int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
  const int8_t* Bp = B.PackedMat();

  // Per-call scratch for the kernel's input row sums, rounded up to a
  // multiple of 32 int32 entries so vector stores may over-write past IC.
  int32_t* row_offsets = static_cast<int32_t*>(
      fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t)));

  int64_t n_begin, n_end, t_begin, t_end, h_begin, h_end;
  // Reuse the 3-dim partition scheme for parallelization in matrix
  // multiplication.
  thread_type_t th_info =
      fbgemmGetThreadPartition(N, T_OUT, H_OUT, thread_id, num_threads);
  // Calculate the begin and end index along the batch (N) dimension
  fbgemmPartition1D(
      th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end);
  // Calculate the begin and end index along the T dimension
  fbgemmPartition1D(
      th_info.m_thread_id, th_info.m_num_threads, T_OUT, t_begin, t_end);
  // Calculate the begin and end index along the H dimension
  fbgemmPartition1D(
      th_info.n_thread_id, th_info.n_num_threads, H_OUT, h_begin, h_end);

  // JIT kernel for fully-interior pixels (no padding skips in any
  // dimension); generated on first use and then reused across n/t/h/w.
  GenI8Depthwise::jit_kernel_signature middle_kernel;

  for (int n = n_begin; n < n_end; ++n) {
    const uint8_t* A_base = A + n * T * H * W * IC;
    uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * OC;

    int t;
    // Front border in T: the filter overlaps the leading temporal padding.
    // NOTE(review): the PAD_P / PAD_T border loops below are not clamped to
    // t_end / h_end; this assumes a thread's partition always reaches past
    // the padding region -- confirm against fbgemmPartition1D's guarantees.
    for (t = t_begin; t < PAD_P; ++t) {
      int h;
      // Top border in H: per-pixel kernel lookup handles all edge skips.
      for (h = h_begin; h < PAD_T; ++h) {
        for (int w = 0; w < W_OUT; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w
      } // h

      // Vertically-interior rows (no padding in H).
      for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
        int w;
        // Left border in W.
        for (w = 0; w < PAD_L; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w

        // Horizontally-interior pixels: cache one kernel specialized only
        // for this t's temporal skips and reuse it along the row.
        GenI8Depthwise::jit_kernel_signature kernel;
        for (; w < W_OUT - PAD_R - stride_w + 1; ++w) {
          if (w == PAD_L) {
            int remainder = OC % 32;
            if (remainder == 0) {
              remainder = 32;
            }
            int t_in = -PAD_P + t * stride_t;
            kernel = GenI8Depthwise().getOrCreate(
                /*D=*/3,
                F,
                OC / IC,
                /*compute_a_sum=*/!B_SYMMETRIC,
                remainder,
                /*prev_skip=*/std::max(-t_in, 0),
                /*next_skip=*/std::max(t_in + F[0] - T, 0),
                0,
                0,
                0,
                0);
          }
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale,
              &kernel);
        } // w

        // Right border in W.
        for (; w < W_OUT; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w
      } // h

      // Bottom border in H.
      for (; h < h_end; ++h) {
        for (int w = 0; w < W_OUT; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w
      } // h
    } // t

    // Temporally-interior slices: no padding overlap in T.
    for (; t < std::min(T_OUT - PAD_N - stride_t + 1, t_end); ++t) {
      int h;
      // Top border in H.
      for (h = h_begin; h < PAD_T; ++h) {
        for (int w = 0; w < W_OUT; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w
      } // h

      // Fully-interior pixels in T, H, and W.
      for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
        int w;
        // Left border in W.
        for (w = 0; w < PAD_L; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w

        // middle_kernel has no padding skips at all; it is (re)looked up on
        // the first interior pixel of each row of the first image handled by
        // this thread, then reused for the remaining images.
        for (; w < W_OUT - PAD_R - stride_w + 1; ++w) {
          if (n == n_begin && w == PAD_L) {
            int remainder = OC % 32;
            if (remainder == 0) {
              remainder = 32;
            }
            middle_kernel = GenI8Depthwise().getOrCreate(
                /*D=*/3,
                F,
                OC / IC,
                /*compute_a_sum=*/!B_SYMMETRIC,
                remainder,
                0,
                0,
                0,
                0,
                0,
                0);
          }
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale,
              &middle_kernel);
        } // w

        // Right border in W.
        for (; w < W_OUT; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w
      } // h

      // Bottom border in H.
      for (; h < h_end; ++h) {
        for (int w = 0; w < W_OUT; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w
      } // h
    } // t

    // Back border in T: the filter overlaps the trailing temporal padding.
    for (; t < t_end; ++t) {
      int h;
      // Top border in H.
      for (h = h_begin; h < PAD_T; ++h) {
        for (int w = 0; w < W_OUT; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w
      } // h

      // Vertically-interior rows.
      for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
        int w;
        // Left border in W.
        for (w = 0; w < PAD_L; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w

        // Horizontally-interior pixels: cache one kernel specialized only
        // for this t's temporal skips and reuse it along the row.
        GenI8Depthwise::jit_kernel_signature kernel;
        for (; w < W_OUT - PAD_R - stride_w + 1; ++w) {
          if (w == PAD_L) {
            int remainder = OC % 32;
            if (remainder == 0) {
              remainder = 32;
            }
            int t_in = -PAD_P + t * stride_t;
            kernel = GenI8Depthwise().getOrCreate(
                /*D=*/3,
                F,
                OC / IC,
                /*compute_a_sum=*/!B_SYMMETRIC,
                remainder,
                /*prev_skip=*/std::max(-t_in, 0),
                /*next_skip=*/std::max(t_in + F[0] - T, 0),
                0,
                0,
                0,
                0);
          }
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale,
              &kernel);
        } // w

        // Right border in W.
        for (; w < W_OUT; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w
      } // h

      // Bottom border in H.
      for (; h < h_end; ++h) {
        for (int w = 0; w < W_OUT; ++w) {
          depthwise_3d_kernel_<
              FUSE_RELU,
              HAS_BIAS,
              A_SYMMETRIC,
              B_SYMMETRIC,
              Q_GRAN>(
              T,
              H,
              W,
              IC,
              OC,
              t,
              h,
              w,
              F,
              stride_t,
              stride_h,
              stride_w,
              A_zero_point,
              A_base,
              B_zero_point,
              Bp,
              C_multiplier,
              C_zero_point,
              C_int32,
              C_uint8_base,
              row_offsets,
              col_offsets,
              bias,
              act_times_w_scale);
        } // w
      } // h
    } // t
  } // for each n
  fbgemmAlignedFree(row_offsets);
}
781 | |
782 | // Dispatch A_SYMMETRIC and B_SYMMETRIC |
783 | template < |
784 | bool FUSE_RELU, |
785 | bool HAS_BIAS, |
786 | QuantizationGranularity Q_GRAN, |
787 | typename BIAS_TYPE> |
788 | static void depthwise_3d_same_pad_( |
789 | const conv_param_t<3>& conv_p, |
790 | int32_t A_zero_point, |
791 | const uint8_t* A, |
792 | const int32_t* B_zero_point, |
793 | const PackedDepthWiseConvMatrix& B, |
794 | const float* C_multiplier, |
795 | int32_t C_zero_point, |
796 | uint8_t* C, |
797 | const int32_t* col_offsets, |
798 | const BIAS_TYPE* bias, |
799 | const float* act_times_w_scale, |
800 | int thread_id, |
801 | int num_threads) { |
802 | int32_t* C_int32_temp = static_cast<int32_t*>( |
803 | fbgemmAlignedAlloc(64, (conv_p.OC + 31) / 32 * 32 * sizeof(int32_t))); |
804 | if (A_zero_point == 0 || col_offsets == nullptr) { |
805 | if (Q_GRAN == QuantizationGranularity::TENSOR && B_zero_point[0] == 0) { |
806 | depthwise_3d_same_pad_< |
807 | FUSE_RELU, |
808 | HAS_BIAS, |
809 | true /*A_symmetric*/, |
810 | true /*B_symmetric*/, |
811 | Q_GRAN>( |
812 | conv_p, |
813 | A_zero_point, |
814 | A, |
815 | B_zero_point, |
816 | B, |
817 | C_multiplier, |
818 | C_zero_point, |
819 | C_int32_temp, |
820 | C, |
821 | col_offsets, |
822 | bias, |
823 | act_times_w_scale, |
824 | thread_id, |
825 | num_threads); |
826 | } else { |
827 | depthwise_3d_same_pad_< |
828 | FUSE_RELU, |
829 | HAS_BIAS, |
830 | true /*A_symmetric*/, |
831 | false /*B_symmetric*/, |
832 | Q_GRAN>( |
833 | conv_p, |
834 | A_zero_point, |
835 | A, |
836 | B_zero_point, |
837 | B, |
838 | C_multiplier, |
839 | C_zero_point, |
840 | C_int32_temp, |
841 | C, |
842 | col_offsets, |
843 | bias, |
844 | act_times_w_scale, |
845 | thread_id, |
846 | num_threads); |
847 | } |
848 | } else { |
849 | if (Q_GRAN == QuantizationGranularity::TENSOR && B_zero_point[0] == 0) { |
850 | depthwise_3d_same_pad_< |
851 | FUSE_RELU, |
852 | HAS_BIAS, |
853 | false /*A_symmetric*/, |
854 | true /*B_symmetric*/, |
855 | Q_GRAN>( |
856 | conv_p, |
857 | A_zero_point, |
858 | A, |
859 | B_zero_point, |
860 | B, |
861 | C_multiplier, |
862 | C_zero_point, |
863 | C_int32_temp, |
864 | C, |
865 | col_offsets, |
866 | bias, |
867 | act_times_w_scale, |
868 | thread_id, |
869 | num_threads); |
870 | } else { |
871 | depthwise_3d_same_pad_< |
872 | FUSE_RELU, |
873 | HAS_BIAS, |
874 | false /*A_symmetric*/, |
875 | false /*B_symmetric*/, |
876 | Q_GRAN>( |
877 | conv_p, |
878 | A_zero_point, |
879 | A, |
880 | B_zero_point, |
881 | B, |
882 | C_multiplier, |
883 | C_zero_point, |
884 | C_int32_temp, |
885 | C, |
886 | col_offsets, |
887 | bias, |
888 | act_times_w_scale, |
889 | thread_id, |
890 | num_threads); |
891 | } |
892 | } |
893 | fbgemmAlignedFree(C_int32_temp); |
894 | } |
895 | |
896 | // Dispatch HAS_BIAS |
897 | template <bool FUSE_RELU, QuantizationGranularity Q_GRAN, typename BIAS_TYPE> |
898 | static void depthwise_3d_same_pad_( |
899 | const conv_param_t<3>& conv_p, |
900 | int32_t A_zero_point, |
901 | const uint8_t* A, |
902 | const int32_t* B_zero_point, |
903 | const PackedDepthWiseConvMatrix& B, |
904 | const float* C_multiplier, |
905 | int32_t C_zero_point, |
906 | uint8_t* C, |
907 | const int32_t* col_offsets, |
908 | const BIAS_TYPE* bias, |
909 | const float* act_times_w_scale, |
910 | int thread_id, |
911 | int num_threads) { |
912 | if (bias) { |
913 | depthwise_3d_same_pad_<FUSE_RELU, true /*HAS_BIAS*/, Q_GRAN>( |
914 | conv_p, |
915 | A_zero_point, |
916 | A, |
917 | B_zero_point, |
918 | B, |
919 | C_multiplier, |
920 | C_zero_point, |
921 | C, |
922 | col_offsets, |
923 | bias, |
924 | act_times_w_scale, |
925 | thread_id, |
926 | num_threads); |
927 | } else { |
928 | depthwise_3d_same_pad_<FUSE_RELU, false /*HAS_BIAS*/, Q_GRAN>( |
929 | conv_p, |
930 | A_zero_point, |
931 | A, |
932 | B_zero_point, |
933 | B, |
934 | C_multiplier, |
935 | C_zero_point, |
936 | C, |
937 | col_offsets, |
938 | bias, |
939 | act_times_w_scale, |
940 | thread_id, |
941 | num_threads); |
942 | } |
943 | } |
944 | |
945 | // Dispatch FUSE_RELU |
946 | template <QuantizationGranularity Q_GRAN, typename BIAS_TYPE> |
947 | void depthwise_3d_same_pad( |
948 | const conv_param_t<3>& conv_p, |
949 | int32_t A_zero_point, |
950 | const uint8_t* A, |
951 | const int32_t* B_zero_point, |
952 | const PackedDepthWiseConvMatrix& B, |
953 | const float* C_multiplier, |
954 | int32_t C_zero_point, |
955 | uint8_t* C, |
956 | const int32_t* col_offsets, |
957 | const BIAS_TYPE* bias, |
958 | bool fuse_relu, |
959 | const float* act_times_w_scale, |
960 | int thread_id, |
961 | int num_threads) { |
962 | if (B.GetKernelProduct() != conv_p.K[0] * conv_p.K[1] * conv_p.K[2]) { |
963 | string msg = |
964 | "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + |
965 | to_string(conv_p.K[0] * conv_p.K[1] * conv_p.K[2]) + " but has " + |
966 | to_string(B.GetKernelProduct()); |
967 | throw logic_error(msg); |
968 | } |
969 | if (conv_p.stride[0] == 0 || conv_p.stride[1] == 0 || conv_p.stride[2] == 0 || |
970 | num_threads == 0) { |
971 | assert( |
972 | 0 && |
973 | "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0" ); |
974 | return; |
975 | } |
976 | if (conv_p.MB == 0) { |
977 | // In C2, batch size 0 is allowed, so we should just early return. |
978 | return; |
979 | } |
980 | if (fuse_relu) { |
981 | depthwise_3d_same_pad_<true /*FUSE_RELU*/, Q_GRAN>( |
982 | conv_p, |
983 | A_zero_point, |
984 | A, |
985 | B_zero_point, |
986 | B, |
987 | C_multiplier, |
988 | C_zero_point, |
989 | C, |
990 | col_offsets, |
991 | bias, |
992 | act_times_w_scale, |
993 | thread_id, |
994 | num_threads); |
995 | } else { |
996 | depthwise_3d_same_pad_<false /*FUSE_RELU*/, Q_GRAN>( |
997 | conv_p, |
998 | A_zero_point, |
999 | A, |
1000 | B_zero_point, |
1001 | B, |
1002 | C_multiplier, |
1003 | C_zero_point, |
1004 | C, |
1005 | col_offsets, |
1006 | bias, |
1007 | act_times_w_scale, |
1008 | thread_id, |
1009 | num_threads); |
1010 | } |
1011 | } |
1012 | |
1013 | #define INSTANTIATE_BASE(Q_GRAN, BIAS_TYPE) \ |
1014 | template FBGEMM_API void \ |
1015 | depthwise_3d_same_pad<QuantizationGranularity::Q_GRAN>( \ |
1016 | const conv_param_t<3>& conv_p, \ |
1017 | int32_t A_zero_point, \ |
1018 | const uint8_t* A, \ |
1019 | const int32_t* B_zero_point, \ |
1020 | const PackedDepthWiseConvMatrix& B, \ |
1021 | const float* C_multiplier, \ |
1022 | int32_t C_zero_point, \ |
1023 | uint8_t* C, \ |
1024 | const int32_t* col_offsets, \ |
1025 | const BIAS_TYPE* bias, \ |
1026 | bool fuse_relu, \ |
1027 | const float* act_times_w_scale, \ |
1028 | int thread_id, \ |
1029 | int num_threads); |
1030 | |
1031 | #define INSTANTIATE_BIAS_T(Q_GRAN) \ |
1032 | INSTANTIATE_BASE(Q_GRAN, int32_t) \ |
1033 | INSTANTIATE_BASE(Q_GRAN, float) |
1034 | |
1035 | INSTANTIATE_BIAS_T(TENSOR) |
1036 | INSTANTIATE_BIAS_T(GROUP) |
1037 | INSTANTIATE_BIAS_T(OUT_CHANNEL) |
1038 | |
1039 | #undef INSTANTIATE_BIAS_T |
1040 | #undef INSTANTIATE_BASE |
1041 | |
1042 | } // namespace fbgemm |
1043 | |