1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
9
10#include <stdexcept> // for logic_error
11#include <string>
12
13#include "./FbgemmI8DepthwiseAvx2-inl.h"
14#include "./GenerateI8Depthwise.h"
15#include "./MaskAvx2.h"
16#include "fbgemm/Utils.h"
17#include "fbgemm/UtilsAvx2.h"
18
19using namespace std;
20
21namespace fbgemm {
22
23template <
24 bool FUSE_RELU,
25 bool HAS_BIAS,
26 bool A_SYMMETRIC,
27 bool B_SYMMETRIC,
28 QuantizationGranularity Q_GRAN,
29 typename BIAS_TYPE>
30static ALWAYS_INLINE void depthwise_3d_kernel_(
31 int T,
32 int H,
33 int W,
34 int IC,
35 int OC,
36 int t,
37 int h,
38 int w,
39 array<int, 3> F,
40 int stride_t,
41 int stride_h,
42 int stride_w,
43 int32_t A_zero_point,
44 const uint8_t* A,
45 const int32_t* B_zero_point,
46 const int8_t* Bp,
47 const float* C_multiplier,
48 int32_t C_zero_point,
49 int32_t* C_int32,
50 uint8_t* C_uint8,
51 int32_t* row_offsets,
52 const int32_t* col_offsets,
53 const BIAS_TYPE* bias,
54 const float* act_times_w_scale,
55 GenI8Depthwise::jit_kernel_signature* pregenerated_kernel = nullptr) {
56 int R = F[1], S = F[2];
57 int PAD_P = (F[0] - 1) / 2, PAD_T = (F[1] - 1) / 2, PAD_B = PAD_T,
58 PAD_L = (F[2] - 1) / 2, PAD_R = PAD_L;
59 int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
60 int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
61 int t_in = -PAD_P + t * stride_t;
62 int h_in = -PAD_T + h * stride_h;
63 int w_in = -PAD_L + w * stride_w;
64
65 int remainder = OC % 32;
66 if (remainder == 0) {
67 remainder = 32;
68 }
69
70 GenI8Depthwise::jit_kernel_signature kernel = pregenerated_kernel
71 ? *pregenerated_kernel
72 : GenI8Depthwise().getOrCreate(
73 /*D=*/3,
74 F,
75 OC / IC,
76 /*compute_a_sum=*/!B_SYMMETRIC,
77 remainder,
78 /*prev_skip=*/std::max(-t_in, 0),
79 /*next_skip=*/std::max(t_in + F[0] - T, 0),
80 /*top_skip=*/std::max(-h_in, 0),
81 /*bottom_skip=*/std::max(h_in + F[1] - H, 0),
82 /*left_skip=*/std::max(-w_in, 0),
83 /*right_skip=*/std::max(w_in + F[2] - W, 0));
84 kernel(
85 A + ((t_in * H + h_in) * W + w_in) * IC,
86 Bp,
87 C_int32,
88 B_SYMMETRIC ? nullptr : row_offsets,
89 H,
90 W,
91 IC,
92 internal::avx2_ps_or_epi32_combined_mask,
93 A_zero_point);
94
95 if (OC == IC) {
96 requantize_<FUSE_RELU, HAS_BIAS, Q_GRAN, A_SYMMETRIC, B_SYMMETRIC, 1>(
97 A_zero_point,
98 B_zero_point,
99 C_multiplier,
100 C_zero_point,
101 C_int32,
102 C_uint8 + ((t * H_OUT + h) * W_OUT + w) * OC,
103 OC,
104 row_offsets,
105 col_offsets,
106 bias,
107 act_times_w_scale);
108 } else {
109 assert(OC / IC == 2);
110 requantize_<FUSE_RELU, HAS_BIAS, Q_GRAN, A_SYMMETRIC, B_SYMMETRIC, 2>(
111 A_zero_point,
112 B_zero_point,
113 C_multiplier,
114 C_zero_point,
115 C_int32,
116 C_uint8 + ((t * H_OUT + h) * W_OUT + w) * OC,
117 OC,
118 row_offsets,
119 col_offsets,
120 bias,
121 act_times_w_scale);
122 }
123}
124
125template <
126 bool FUSE_RELU,
127 bool HAS_BIAS,
128 bool A_SYMMETRIC,
129 bool B_SYMMETRIC,
130 QuantizationGranularity Q_GRAN,
131 typename BIAS_TYPE>
132static ALWAYS_INLINE void depthwise_3d_same_pad_(
133 const conv_param_t<3>& conv_p,
134 int32_t A_zero_point,
135 const uint8_t* A,
136 const int32_t* B_zero_point,
137 const PackedDepthWiseConvMatrix& B,
138 const float* C_multiplier,
139 int32_t C_zero_point,
140 int32_t* C_int32,
141 uint8_t* C_uint8,
142 const int32_t* col_offsets,
143 const BIAS_TYPE* bias,
144 const float* act_times_w_scale,
145 int thread_id,
146 int num_threads) {
147 int N = conv_p.MB;
148 int T = conv_p.IN_DIM[0];
149 int H = conv_p.IN_DIM[1];
150 int W = conv_p.IN_DIM[2];
151 int IC = conv_p.IC;
152 int OC = conv_p.OC;
153 array<int, 3> F = conv_p.K;
154 int stride_t = conv_p.stride[0];
155 int stride_h = conv_p.stride[1];
156 int stride_w = conv_p.stride[2];
157
158 assert(IC % 8 == 0);
159
160 int K_T = F[0], K_H = F[1], K_W = F[2];
161 int PAD_P = (F[0] - 1) / 2, PAD_N = PAD_P, PAD_T = (F[1] - 1) / 2,
162 PAD_B = PAD_T, PAD_L = (F[2] - 1) / 2, PAD_R = PAD_L;
163 int64_t T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
164 int64_t H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
165 int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
166 const int8_t* Bp = B.PackedMat();
167
168 int32_t* row_offsets = static_cast<int32_t*>(
169 fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t)));
170
171 int64_t n_begin, n_end, t_begin, t_end, h_begin, h_end;
172 // Reuse the 3-dim partition scheme for parallelization in matrix
173 // multiplication.
174 thread_type_t th_info =
175 fbgemmGetThreadPartition(N, T_OUT, H_OUT, thread_id, num_threads);
176 // Calculate the begin and end index along the batch (N) dimension
177 fbgemmPartition1D(
178 th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end);
179 // Calculate the begin and end index along the T dimension
180 fbgemmPartition1D(
181 th_info.m_thread_id, th_info.m_num_threads, T_OUT, t_begin, t_end);
182 // Calculate the begin and end index along the H dimension
183 fbgemmPartition1D(
184 th_info.n_thread_id, th_info.n_num_threads, H_OUT, h_begin, h_end);
185
186 GenI8Depthwise::jit_kernel_signature middle_kernel;
187
188 for (int n = n_begin; n < n_end; ++n) {
189 const uint8_t* A_base = A + n * T * H * W * IC;
190 uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * OC;
191
192 int t;
193 for (t = t_begin; t < PAD_P; ++t) {
194 int h;
195 for (h = h_begin; h < PAD_T; ++h) {
196 for (int w = 0; w < W_OUT; ++w) {
197 depthwise_3d_kernel_<
198 FUSE_RELU,
199 HAS_BIAS,
200 A_SYMMETRIC,
201 B_SYMMETRIC,
202 Q_GRAN>(
203 T,
204 H,
205 W,
206 IC,
207 OC,
208 t,
209 h,
210 w,
211 F,
212 stride_t,
213 stride_h,
214 stride_w,
215 A_zero_point,
216 A_base,
217 B_zero_point,
218 Bp,
219 C_multiplier,
220 C_zero_point,
221 C_int32,
222 C_uint8_base,
223 row_offsets,
224 col_offsets,
225 bias,
226 act_times_w_scale);
227 } // w
228 } // h
229
230 for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
231 int w;
232 for (w = 0; w < PAD_L; ++w) {
233 depthwise_3d_kernel_<
234 FUSE_RELU,
235 HAS_BIAS,
236 A_SYMMETRIC,
237 B_SYMMETRIC,
238 Q_GRAN>(
239 T,
240 H,
241 W,
242 IC,
243 OC,
244 t,
245 h,
246 w,
247 F,
248 stride_t,
249 stride_h,
250 stride_w,
251 A_zero_point,
252 A_base,
253 B_zero_point,
254 Bp,
255 C_multiplier,
256 C_zero_point,
257 C_int32,
258 C_uint8_base,
259 row_offsets,
260 col_offsets,
261 bias,
262 act_times_w_scale);
263 } // w
264
265 GenI8Depthwise::jit_kernel_signature kernel;
266 for (; w < W_OUT - PAD_R - stride_w + 1; ++w) {
267 if (w == PAD_L) {
268 int remainder = OC % 32;
269 if (remainder == 0) {
270 remainder = 32;
271 }
272 int t_in = -PAD_P + t * stride_t;
273 kernel = GenI8Depthwise().getOrCreate(
274 /*D=*/3,
275 F,
276 OC / IC,
277 /*compute_a_sum=*/!B_SYMMETRIC,
278 remainder,
279 /*prev_skip=*/std::max(-t_in, 0),
280 /*next_skip=*/std::max(t_in + F[0] - T, 0),
281 0,
282 0,
283 0,
284 0);
285 }
286 depthwise_3d_kernel_<
287 FUSE_RELU,
288 HAS_BIAS,
289 A_SYMMETRIC,
290 B_SYMMETRIC,
291 Q_GRAN>(
292 T,
293 H,
294 W,
295 IC,
296 OC,
297 t,
298 h,
299 w,
300 F,
301 stride_t,
302 stride_h,
303 stride_w,
304 A_zero_point,
305 A_base,
306 B_zero_point,
307 Bp,
308 C_multiplier,
309 C_zero_point,
310 C_int32,
311 C_uint8_base,
312 row_offsets,
313 col_offsets,
314 bias,
315 act_times_w_scale,
316 &kernel);
317 } // w
318
319 for (; w < W_OUT; ++w) {
320 depthwise_3d_kernel_<
321 FUSE_RELU,
322 HAS_BIAS,
323 A_SYMMETRIC,
324 B_SYMMETRIC,
325 Q_GRAN>(
326 T,
327 H,
328 W,
329 IC,
330 OC,
331 t,
332 h,
333 w,
334 F,
335 stride_t,
336 stride_h,
337 stride_w,
338 A_zero_point,
339 A_base,
340 B_zero_point,
341 Bp,
342 C_multiplier,
343 C_zero_point,
344 C_int32,
345 C_uint8_base,
346 row_offsets,
347 col_offsets,
348 bias,
349 act_times_w_scale);
350 } // w
351 } // h
352
353 for (; h < h_end; ++h) {
354 for (int w = 0; w < W_OUT; ++w) {
355 depthwise_3d_kernel_<
356 FUSE_RELU,
357 HAS_BIAS,
358 A_SYMMETRIC,
359 B_SYMMETRIC,
360 Q_GRAN>(
361 T,
362 H,
363 W,
364 IC,
365 OC,
366 t,
367 h,
368 w,
369 F,
370 stride_t,
371 stride_h,
372 stride_w,
373 A_zero_point,
374 A_base,
375 B_zero_point,
376 Bp,
377 C_multiplier,
378 C_zero_point,
379 C_int32,
380 C_uint8_base,
381 row_offsets,
382 col_offsets,
383 bias,
384 act_times_w_scale);
385 } // w
386 } // h
387 } // t
388
389 for (; t < std::min(T_OUT - PAD_N - stride_t + 1, t_end); ++t) {
390 int h;
391 for (h = h_begin; h < PAD_T; ++h) {
392 for (int w = 0; w < W_OUT; ++w) {
393 depthwise_3d_kernel_<
394 FUSE_RELU,
395 HAS_BIAS,
396 A_SYMMETRIC,
397 B_SYMMETRIC,
398 Q_GRAN>(
399 T,
400 H,
401 W,
402 IC,
403 OC,
404 t,
405 h,
406 w,
407 F,
408 stride_t,
409 stride_h,
410 stride_w,
411 A_zero_point,
412 A_base,
413 B_zero_point,
414 Bp,
415 C_multiplier,
416 C_zero_point,
417 C_int32,
418 C_uint8_base,
419 row_offsets,
420 col_offsets,
421 bias,
422 act_times_w_scale);
423 } // w
424 } // h
425
426 for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
427 int w;
428 for (w = 0; w < PAD_L; ++w) {
429 depthwise_3d_kernel_<
430 FUSE_RELU,
431 HAS_BIAS,
432 A_SYMMETRIC,
433 B_SYMMETRIC,
434 Q_GRAN>(
435 T,
436 H,
437 W,
438 IC,
439 OC,
440 t,
441 h,
442 w,
443 F,
444 stride_t,
445 stride_h,
446 stride_w,
447 A_zero_point,
448 A_base,
449 B_zero_point,
450 Bp,
451 C_multiplier,
452 C_zero_point,
453 C_int32,
454 C_uint8_base,
455 row_offsets,
456 col_offsets,
457 bias,
458 act_times_w_scale);
459 } // w
460
461 for (; w < W_OUT - PAD_R - stride_w + 1; ++w) {
462 if (n == n_begin && w == PAD_L) {
463 int remainder = OC % 32;
464 if (remainder == 0) {
465 remainder = 32;
466 }
467 middle_kernel = GenI8Depthwise().getOrCreate(
468 /*D=*/3,
469 F,
470 OC / IC,
471 /*compute_a_sum=*/!B_SYMMETRIC,
472 remainder,
473 0,
474 0,
475 0,
476 0,
477 0,
478 0);
479 }
480 depthwise_3d_kernel_<
481 FUSE_RELU,
482 HAS_BIAS,
483 A_SYMMETRIC,
484 B_SYMMETRIC,
485 Q_GRAN>(
486 T,
487 H,
488 W,
489 IC,
490 OC,
491 t,
492 h,
493 w,
494 F,
495 stride_t,
496 stride_h,
497 stride_w,
498 A_zero_point,
499 A_base,
500 B_zero_point,
501 Bp,
502 C_multiplier,
503 C_zero_point,
504 C_int32,
505 C_uint8_base,
506 row_offsets,
507 col_offsets,
508 bias,
509 act_times_w_scale,
510 &middle_kernel);
511 }
512
513 for (; w < W_OUT; ++w) {
514 depthwise_3d_kernel_<
515 FUSE_RELU,
516 HAS_BIAS,
517 A_SYMMETRIC,
518 B_SYMMETRIC,
519 Q_GRAN>(
520 T,
521 H,
522 W,
523 IC,
524 OC,
525 t,
526 h,
527 w,
528 F,
529 stride_t,
530 stride_h,
531 stride_w,
532 A_zero_point,
533 A_base,
534 B_zero_point,
535 Bp,
536 C_multiplier,
537 C_zero_point,
538 C_int32,
539 C_uint8_base,
540 row_offsets,
541 col_offsets,
542 bias,
543 act_times_w_scale);
544 }
545 } // h
546
547 for (; h < h_end; ++h) {
548 for (int w = 0; w < W_OUT; ++w) {
549 depthwise_3d_kernel_<
550 FUSE_RELU,
551 HAS_BIAS,
552 A_SYMMETRIC,
553 B_SYMMETRIC,
554 Q_GRAN>(
555 T,
556 H,
557 W,
558 IC,
559 OC,
560 t,
561 h,
562 w,
563 F,
564 stride_t,
565 stride_h,
566 stride_w,
567 A_zero_point,
568 A_base,
569 B_zero_point,
570 Bp,
571 C_multiplier,
572 C_zero_point,
573 C_int32,
574 C_uint8_base,
575 row_offsets,
576 col_offsets,
577 bias,
578 act_times_w_scale);
579 } // w
580 } // h
581 } // t
582
583 for (; t < t_end; ++t) {
584 int h;
585 for (h = h_begin; h < PAD_T; ++h) {
586 for (int w = 0; w < W_OUT; ++w) {
587 depthwise_3d_kernel_<
588 FUSE_RELU,
589 HAS_BIAS,
590 A_SYMMETRIC,
591 B_SYMMETRIC,
592 Q_GRAN>(
593 T,
594 H,
595 W,
596 IC,
597 OC,
598 t,
599 h,
600 w,
601 F,
602 stride_t,
603 stride_h,
604 stride_w,
605 A_zero_point,
606 A_base,
607 B_zero_point,
608 Bp,
609 C_multiplier,
610 C_zero_point,
611 C_int32,
612 C_uint8_base,
613 row_offsets,
614 col_offsets,
615 bias,
616 act_times_w_scale);
617 } // w
618 } // h
619
620 for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
621 int w;
622 for (w = 0; w < PAD_L; ++w) {
623 depthwise_3d_kernel_<
624 FUSE_RELU,
625 HAS_BIAS,
626 A_SYMMETRIC,
627 B_SYMMETRIC,
628 Q_GRAN>(
629 T,
630 H,
631 W,
632 IC,
633 OC,
634 t,
635 h,
636 w,
637 F,
638 stride_t,
639 stride_h,
640 stride_w,
641 A_zero_point,
642 A_base,
643 B_zero_point,
644 Bp,
645 C_multiplier,
646 C_zero_point,
647 C_int32,
648 C_uint8_base,
649 row_offsets,
650 col_offsets,
651 bias,
652 act_times_w_scale);
653 } // w
654
655 GenI8Depthwise::jit_kernel_signature kernel;
656 for (; w < W_OUT - PAD_R - stride_w + 1; ++w) {
657 if (w == PAD_L) {
658 int remainder = OC % 32;
659 if (remainder == 0) {
660 remainder = 32;
661 }
662 int t_in = -PAD_P + t * stride_t;
663 kernel = GenI8Depthwise().getOrCreate(
664 /*D=*/3,
665 F,
666 OC / IC,
667 /*compute_a_sum=*/!B_SYMMETRIC,
668 remainder,
669 /*prev_skip=*/std::max(-t_in, 0),
670 /*next_skip=*/std::max(t_in + F[0] - T, 0),
671 0,
672 0,
673 0,
674 0);
675 }
676 depthwise_3d_kernel_<
677 FUSE_RELU,
678 HAS_BIAS,
679 A_SYMMETRIC,
680 B_SYMMETRIC,
681 Q_GRAN>(
682 T,
683 H,
684 W,
685 IC,
686 OC,
687 t,
688 h,
689 w,
690 F,
691 stride_t,
692 stride_h,
693 stride_w,
694 A_zero_point,
695 A_base,
696 B_zero_point,
697 Bp,
698 C_multiplier,
699 C_zero_point,
700 C_int32,
701 C_uint8_base,
702 row_offsets,
703 col_offsets,
704 bias,
705 act_times_w_scale,
706 &kernel);
707 } // w
708
709 for (; w < W_OUT; ++w) {
710 depthwise_3d_kernel_<
711 FUSE_RELU,
712 HAS_BIAS,
713 A_SYMMETRIC,
714 B_SYMMETRIC,
715 Q_GRAN>(
716 T,
717 H,
718 W,
719 IC,
720 OC,
721 t,
722 h,
723 w,
724 F,
725 stride_t,
726 stride_h,
727 stride_w,
728 A_zero_point,
729 A_base,
730 B_zero_point,
731 Bp,
732 C_multiplier,
733 C_zero_point,
734 C_int32,
735 C_uint8_base,
736 row_offsets,
737 col_offsets,
738 bias,
739 act_times_w_scale);
740 } // w
741 } // h
742
743 for (; h < h_end; ++h) {
744 for (int w = 0; w < W_OUT; ++w) {
745 depthwise_3d_kernel_<
746 FUSE_RELU,
747 HAS_BIAS,
748 A_SYMMETRIC,
749 B_SYMMETRIC,
750 Q_GRAN>(
751 T,
752 H,
753 W,
754 IC,
755 OC,
756 t,
757 h,
758 w,
759 F,
760 stride_t,
761 stride_h,
762 stride_w,
763 A_zero_point,
764 A_base,
765 B_zero_point,
766 Bp,
767 C_multiplier,
768 C_zero_point,
769 C_int32,
770 C_uint8_base,
771 row_offsets,
772 col_offsets,
773 bias,
774 act_times_w_scale);
775 } // w
776 } // h
777 } // t
778 } // for each n
779 fbgemmAlignedFree(row_offsets);
780}
781
782// Dispatch A_SYMMETRIC and B_SYMMETRIC
783template <
784 bool FUSE_RELU,
785 bool HAS_BIAS,
786 QuantizationGranularity Q_GRAN,
787 typename BIAS_TYPE>
788static void depthwise_3d_same_pad_(
789 const conv_param_t<3>& conv_p,
790 int32_t A_zero_point,
791 const uint8_t* A,
792 const int32_t* B_zero_point,
793 const PackedDepthWiseConvMatrix& B,
794 const float* C_multiplier,
795 int32_t C_zero_point,
796 uint8_t* C,
797 const int32_t* col_offsets,
798 const BIAS_TYPE* bias,
799 const float* act_times_w_scale,
800 int thread_id,
801 int num_threads) {
802 int32_t* C_int32_temp = static_cast<int32_t*>(
803 fbgemmAlignedAlloc(64, (conv_p.OC + 31) / 32 * 32 * sizeof(int32_t)));
804 if (A_zero_point == 0 || col_offsets == nullptr) {
805 if (Q_GRAN == QuantizationGranularity::TENSOR && B_zero_point[0] == 0) {
806 depthwise_3d_same_pad_<
807 FUSE_RELU,
808 HAS_BIAS,
809 true /*A_symmetric*/,
810 true /*B_symmetric*/,
811 Q_GRAN>(
812 conv_p,
813 A_zero_point,
814 A,
815 B_zero_point,
816 B,
817 C_multiplier,
818 C_zero_point,
819 C_int32_temp,
820 C,
821 col_offsets,
822 bias,
823 act_times_w_scale,
824 thread_id,
825 num_threads);
826 } else {
827 depthwise_3d_same_pad_<
828 FUSE_RELU,
829 HAS_BIAS,
830 true /*A_symmetric*/,
831 false /*B_symmetric*/,
832 Q_GRAN>(
833 conv_p,
834 A_zero_point,
835 A,
836 B_zero_point,
837 B,
838 C_multiplier,
839 C_zero_point,
840 C_int32_temp,
841 C,
842 col_offsets,
843 bias,
844 act_times_w_scale,
845 thread_id,
846 num_threads);
847 }
848 } else {
849 if (Q_GRAN == QuantizationGranularity::TENSOR && B_zero_point[0] == 0) {
850 depthwise_3d_same_pad_<
851 FUSE_RELU,
852 HAS_BIAS,
853 false /*A_symmetric*/,
854 true /*B_symmetric*/,
855 Q_GRAN>(
856 conv_p,
857 A_zero_point,
858 A,
859 B_zero_point,
860 B,
861 C_multiplier,
862 C_zero_point,
863 C_int32_temp,
864 C,
865 col_offsets,
866 bias,
867 act_times_w_scale,
868 thread_id,
869 num_threads);
870 } else {
871 depthwise_3d_same_pad_<
872 FUSE_RELU,
873 HAS_BIAS,
874 false /*A_symmetric*/,
875 false /*B_symmetric*/,
876 Q_GRAN>(
877 conv_p,
878 A_zero_point,
879 A,
880 B_zero_point,
881 B,
882 C_multiplier,
883 C_zero_point,
884 C_int32_temp,
885 C,
886 col_offsets,
887 bias,
888 act_times_w_scale,
889 thread_id,
890 num_threads);
891 }
892 }
893 fbgemmAlignedFree(C_int32_temp);
894}
895
896// Dispatch HAS_BIAS
897template <bool FUSE_RELU, QuantizationGranularity Q_GRAN, typename BIAS_TYPE>
898static void depthwise_3d_same_pad_(
899 const conv_param_t<3>& conv_p,
900 int32_t A_zero_point,
901 const uint8_t* A,
902 const int32_t* B_zero_point,
903 const PackedDepthWiseConvMatrix& B,
904 const float* C_multiplier,
905 int32_t C_zero_point,
906 uint8_t* C,
907 const int32_t* col_offsets,
908 const BIAS_TYPE* bias,
909 const float* act_times_w_scale,
910 int thread_id,
911 int num_threads) {
912 if (bias) {
913 depthwise_3d_same_pad_<FUSE_RELU, true /*HAS_BIAS*/, Q_GRAN>(
914 conv_p,
915 A_zero_point,
916 A,
917 B_zero_point,
918 B,
919 C_multiplier,
920 C_zero_point,
921 C,
922 col_offsets,
923 bias,
924 act_times_w_scale,
925 thread_id,
926 num_threads);
927 } else {
928 depthwise_3d_same_pad_<FUSE_RELU, false /*HAS_BIAS*/, Q_GRAN>(
929 conv_p,
930 A_zero_point,
931 A,
932 B_zero_point,
933 B,
934 C_multiplier,
935 C_zero_point,
936 C,
937 col_offsets,
938 bias,
939 act_times_w_scale,
940 thread_id,
941 num_threads);
942 }
943}
944
945// Dispatch FUSE_RELU
946template <QuantizationGranularity Q_GRAN, typename BIAS_TYPE>
947void depthwise_3d_same_pad(
948 const conv_param_t<3>& conv_p,
949 int32_t A_zero_point,
950 const uint8_t* A,
951 const int32_t* B_zero_point,
952 const PackedDepthWiseConvMatrix& B,
953 const float* C_multiplier,
954 int32_t C_zero_point,
955 uint8_t* C,
956 const int32_t* col_offsets,
957 const BIAS_TYPE* bias,
958 bool fuse_relu,
959 const float* act_times_w_scale,
960 int thread_id,
961 int num_threads) {
962 if (B.GetKernelProduct() != conv_p.K[0] * conv_p.K[1] * conv_p.K[2]) {
963 string msg =
964 "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
965 to_string(conv_p.K[0] * conv_p.K[1] * conv_p.K[2]) + " but has " +
966 to_string(B.GetKernelProduct());
967 throw logic_error(msg);
968 }
969 if (conv_p.stride[0] == 0 || conv_p.stride[1] == 0 || conv_p.stride[2] == 0 ||
970 num_threads == 0) {
971 assert(
972 0 &&
973 "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
974 return;
975 }
976 if (conv_p.MB == 0) {
977 // In C2, batch size 0 is allowed, so we should just early return.
978 return;
979 }
980 if (fuse_relu) {
981 depthwise_3d_same_pad_<true /*FUSE_RELU*/, Q_GRAN>(
982 conv_p,
983 A_zero_point,
984 A,
985 B_zero_point,
986 B,
987 C_multiplier,
988 C_zero_point,
989 C,
990 col_offsets,
991 bias,
992 act_times_w_scale,
993 thread_id,
994 num_threads);
995 } else {
996 depthwise_3d_same_pad_<false /*FUSE_RELU*/, Q_GRAN>(
997 conv_p,
998 A_zero_point,
999 A,
1000 B_zero_point,
1001 B,
1002 C_multiplier,
1003 C_zero_point,
1004 C,
1005 col_offsets,
1006 bias,
1007 act_times_w_scale,
1008 thread_id,
1009 num_threads);
1010 }
1011}
1012
// Explicit instantiation of depthwise_3d_same_pad for one
// (quantization granularity, bias type) pair, exported with FBGEMM_API.
#define INSTANTIATE_BASE(Q_GRAN, BIAS_TYPE)                  \
  template FBGEMM_API void                                   \
  depthwise_3d_same_pad<QuantizationGranularity::Q_GRAN>(    \
      const conv_param_t<3>& conv_p,                         \
      int32_t A_zero_point,                                  \
      const uint8_t* A,                                      \
      const int32_t* B_zero_point,                           \
      const PackedDepthWiseConvMatrix& B,                    \
      const float* C_multiplier,                             \
      int32_t C_zero_point,                                  \
      uint8_t* C,                                            \
      const int32_t* col_offsets,                            \
      const BIAS_TYPE* bias,                                 \
      bool fuse_relu,                                        \
      const float* act_times_w_scale,                        \
      int thread_id,                                         \
      int num_threads);

// Instantiate for both supported bias types (int32 and float).
#define INSTANTIATE_BIAS_T(Q_GRAN) \
  INSTANTIATE_BASE(Q_GRAN, int32_t) \
  INSTANTIATE_BASE(Q_GRAN, float)

// Cover every quantization granularity the public API supports.
INSTANTIATE_BIAS_T(TENSOR)
INSTANTIATE_BIAS_T(GROUP)
INSTANTIATE_BIAS_T(OUT_CHANNEL)

#undef INSTANTIATE_BIAS_T
#undef INSTANTIATE_BASE
1041
1042} // namespace fbgemm
1043