/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#pragma once

#include <algorithm> // for std::min / std::max
#include <cassert>
#include <cstdint>

#include "./FbgemmI8DepthwiseAvx2-inl.h"
#include "./GenerateI8Depthwise.h"
#include "./MaskAvx2.h"
#include "fbgemm/Utils.h"
#include "fbgemm/UtilsAvx2.h"

namespace fbgemm {

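// Computes a single output pixel (h, w) of a 2D depthwise convolution: runs
// the JIT-generated inner kernel to accumulate into C_int32 (and, unless B is
// symmetric, into row_offsets), then requantizes the result into C_uint8.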
template <
    int S,
    bool FUSE_RELU,
    bool HAS_BIAS,
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    typename BIAS_TYPE>
static ALWAYS_INLINE void depthwise_2d_kernel_(
    int H,
    int W,
    int IC,
    int OC,
    int h,
    int w,
    int stride_h,
    int stride_w,
    std::int32_t A_zero_point,
    const std::uint8_t* A,
    const std::int32_t* B_zero_point,
    const std::int8_t* Bp,
    const float* C_multiplier,
    std::int32_t C_zero_point,
    std::int32_t* C_int32,
    std::uint8_t* C_uint8,
    std::int32_t* row_offsets,
    const std::int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    GenI8Depthwise::jit_kernel_signature* pregenerated_kernel = nullptr) {
  constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2;
  int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
  int h_in = -PAD_T + h * stride_h;
  int w_in = -PAD_L + w * stride_w;

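  // Number of output channels in the last (possibly partial) block of 32;
  // a full block is encoded as 32 rather than 0.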
  int remainder = OC % 32;
  if (remainder == 0) {
    remainder = 32;
  }

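  // Fetch (or JIT-generate and cache) a kernel specialized for this border
  // configuration; the *_skip arguments clamp the S x S filter window to the
  // image for pixels whose window overlaps the zero padding.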
  GenI8Depthwise::jit_kernel_signature kernel = pregenerated_kernel
      ? *pregenerated_kernel
      : GenI8Depthwise().getOrCreate(
            /*D=*/2,
            {1, S, S},
            OC / IC,
            /*compute_a_sum=*/!B_SYMMETRIC,
            remainder,
            0,
            0,
            /*top_skip=*/std::max(-h_in, 0),
            /*bottom_skip=*/std::max(h_in + S - H, 0),
            /*left_skip=*/std::max(-w_in, 0),
            /*right_skip=*/std::max(w_in + S - W, 0));

  kernel(
      A + (h_in * W + w_in) * IC,
      Bp,
      C_int32,
      B_SYMMETRIC ? nullptr : row_offsets,
      H,
      W,
      IC,
      internal::avx2_ps_or_epi32_combined_mask,
      A_zero_point);

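  // Requantize the 32-bit accumulators into uint8. The last template argument
  // is the depth multiplier: OC == IC uses the multiplier-1 instantiation,
  // otherwise the multiplier-2 instantiation is used.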
  if (OC == IC) {
    requantize_<FUSE_RELU, HAS_BIAS, Q_GRAN, A_SYMMETRIC, B_SYMMETRIC, 1>(
        A_zero_point,
        B_zero_point,
        C_multiplier,
        C_zero_point,
        C_int32,
        C_uint8 + (h * W_OUT + w) * OC,
        OC,
        row_offsets,
        col_offsets,
        bias,
        act_times_w_scale);
  } else {
    requantize_<FUSE_RELU, HAS_BIAS, Q_GRAN, A_SYMMETRIC, B_SYMMETRIC, 2>(
        A_zero_point,
        B_zero_point,
        C_multiplier,
        C_zero_point,
        C_int32,
        C_uint8 + (h * W_OUT + w) * OC,
        OC,
        row_offsets,
        col_offsets,
        bias,
        act_times_w_scale);
  }
}

// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
// This implementation should be general enough to handle filter shapes other
// than 3x3 by parameterizing with R and S, but it is restricted to 3x3 for
// now.
template <
    int S,
    bool FUSE_RELU,
    bool HAS_BIAS,
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    typename BIAS_TYPE>
static ALWAYS_INLINE void depthwise_2d_(
    int N,
    int H,
    int W,
    int IC,
    int OC,
    int stride_h,
    int stride_w,
    std::int32_t A_zero_point,
    const std::uint8_t* A,
    const std::int32_t* B_zero_point,
    const PackedDepthWiseConvMatrix& B,
    const float* C_multiplier,
    std::int32_t C_zero_point,
    std::int32_t* C_int32,
    std::uint8_t* C_uint8,
    const std::int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    int thread_id,
    int num_threads) {
  assert(IC % 8 == 0);
  constexpr int R = S;
  constexpr int64_t PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2,
                    PAD_R = PAD_L;
  int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
  int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
  const std::int8_t* Bp = B.PackedMat();

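  // Scratch buffer for the kernel's activation-sum (row offset) output,
  // rounded up to a multiple of 32 int32 elements to match the kernel's
  // 32-channel blocking.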
  int32_t* row_offsets = static_cast<int32_t*>(
      fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t)));

  int64_t n_begin, n_end, h_begin, h_end, w_begin, w_end;
  // Reuse the 3-dim partition scheme for parallelization in matrix
  // multiplication.
  thread_type_t th_info =
      fbgemmGetThreadPartition(N, H_OUT, W_OUT, thread_id, num_threads);
  // Calculate the begin and end index along the batch (N) dimension
  fbgemmPartition1D(
      th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end);
  // Calculate the begin and end index along the H dimension
  fbgemmPartition1D(
      th_info.m_thread_id, th_info.m_num_threads, H_OUT, h_begin, h_end);
  // Calculate the begin and end index along the W dimension
  fbgemmPartition1D(
      th_info.n_thread_id, th_info.n_num_threads, W_OUT, w_begin, w_end);

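  // Kernel for interior output pixels (no padding skips); generated once per
  // thread on first use and reused across the whole interior region.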
  GenI8Depthwise::jit_kernel_signature middle_kernel;

  for (int n = n_begin; n < n_end; ++n) {
    const std::uint8_t* A_base = A + n * H * W * IC;
    std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * OC;

    int h = 0;
    int w = 0;

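    // The output plane is tiled into 3 x 3 regions (top/middle/bottom rows x
    // left/middle/right columns) so that padding-aware kernels are only used
    // on the borders and the interior runs the skip-free kernel.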
    for (h = h_begin; h < std::min(PAD_T, h_end); ++h) {
      for (w = w_begin; w < std::min(PAD_L, w_end); ++w) {
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale);
      }

      for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) {
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale);
      }

      for (; w < w_end; ++w) {
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale);
      }
    }

    // h <= H_OUT - PAD_B - stride_h
    // h <= (H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h
    // h_in <= -PAD_T +
    //     ((H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h) * stride_h
    // Case 1) For stride_h == 1,
    // h_in <= -PAD_T + H + PAD_T + PAD_B - S + 1 - PAD_B - 1
    // h_in + S - H <= 0
    // Case 2) For stride_h == 2,
    // h_in <= -PAD_T +
    //     H + PAD_T + PAD_B - S + 1 + (1 - PAD_B - stride_h) * stride_h
    // h_in + S - H <= PAD_B * (1 - stride_h) + 1 + (1 - stride_h) * stride_h
    //             <= -PAD_B + 1 - stride_h <= 0
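    // Concrete check: S = 3, stride_h = 2, H = 7 gives PAD_T = PAD_B = 1 and
    // H_OUT = 4; the last h in this loop is H_OUT - PAD_B - stride_h = 1, so
    // h_in = -1 + 1 * 2 = 1 and h_in + S - H = -3 <= 0 as required.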
    for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
      for (w = w_begin; w < std::min(PAD_L, w_end); ++w) {
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale);
      }

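      // Interior columns: no padding skips apply, so the skip-free kernel is
      // generated once (on the first interior pixel) and reused for every
      // pixel in this region.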
      for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) {
        if (n == n_begin && w == std::max(PAD_L, w_begin)) {
          int remainder = OC % 32;
          if (remainder == 0) {
            remainder = 32;
          }
          middle_kernel = GenI8Depthwise().getOrCreate(
              /*D=*/2,
              {1, S, S},
              OC / IC,
              /*compute_a_sum=*/!B_SYMMETRIC,
              remainder,
              0,
              0,
              0,
              0,
              0,
              0);
        }
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale,
            &middle_kernel);
      }

      for (; w < w_end; ++w) {
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale);
      }
    }

    for (; h < h_end; ++h) {
      for (w = w_begin; w < std::min(PAD_L, w_end); ++w) {
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale);
      }

      for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) {
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale);
      }

      for (; w < w_end; ++w) {
        depthwise_2d_kernel_<
            S,
            FUSE_RELU,
            HAS_BIAS,
            A_SYMMETRIC,
            B_SYMMETRIC,
            Q_GRAN>(
            H,
            W,
            IC,
            OC,
            h,
            w,
            stride_h,
            stride_w,
            A_zero_point,
            A_base,
            B_zero_point,
            Bp,
            C_multiplier,
            C_zero_point,
            C_int32,
            C_uint8_base,
            row_offsets,
            col_offsets,
            bias,
            act_times_w_scale);
      }
    }
  } // for each n

  fbgemmAlignedFree(row_offsets);
}

// Dispatch A_SYMMETRIC and B_SYMMETRIC
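// A is treated as symmetric when its zero point is 0 (or col_offsets is
// absent); B is treated as symmetric when its per-tensor zero point is 0, in
// which case the activation-sum pass is skipped. Resolving both flags at
// compile time keeps these checks out of the inner loops.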
template <
    int S,
    bool FUSE_RELU,
    bool HAS_BIAS,
    QuantizationGranularity Q_GRAN,
    typename BIAS_TYPE>
static void depthwise_2d_(
    int N,
    int H,
    int W,
    int IC,
    int OC,
    int stride_h,
    int stride_w,
    std::int32_t A_zero_point,
    const std::uint8_t* A,
    const std::int32_t* B_zero_point,
    const PackedDepthWiseConvMatrix& B,
    const float* C_multiplier,
    std::int32_t C_zero_point,
    std::uint8_t* C,
    const std::int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    int thread_id,
    int num_threads) {
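  // Scratch accumulator for one output pixel, rounded up to a multiple of 32
  // int32 elements to match the kernel's 32-channel blocking.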
  int32_t* C_int32_temp = static_cast<int32_t*>(
      fbgemmAlignedAlloc(64, (OC + 31) / 32 * 32 * sizeof(int32_t)));
  if (A_zero_point == 0 || col_offsets == nullptr) {
    if (Q_GRAN == QuantizationGranularity::TENSOR && B_zero_point[0] == 0) {
      depthwise_2d_<
          S,
          FUSE_RELU,
          HAS_BIAS,
          true /*A_symmetric*/,
          true /*B_symmetric*/,
          Q_GRAN>(
          N,
          H,
          W,
          IC,
          OC,
          stride_h,
          stride_w,
          A_zero_point,
          A,
          B_zero_point,
          B,
          C_multiplier,
          C_zero_point,
          C_int32_temp,
          C,
          col_offsets,
          bias,
          act_times_w_scale,
          thread_id,
          num_threads);
    } else {
      depthwise_2d_<
          S,
          FUSE_RELU,
          HAS_BIAS,
          true /*A_symmetric*/,
          false /*B_symmetric*/,
          Q_GRAN>(
          N,
          H,
          W,
          IC,
          OC,
          stride_h,
          stride_w,
          A_zero_point,
          A,
          B_zero_point,
          B,
          C_multiplier,
          C_zero_point,
          C_int32_temp,
          C,
          col_offsets,
          bias,
          act_times_w_scale,
          thread_id,
          num_threads);
    }
  } else {
    if (Q_GRAN == QuantizationGranularity::TENSOR && B_zero_point[0] == 0) {
      depthwise_2d_<
          S,
          FUSE_RELU,
          HAS_BIAS,
          false /*A_symmetric*/,
          true /*B_symmetric*/,
          Q_GRAN>(
          N,
          H,
          W,
          IC,
          OC,
          stride_h,
          stride_w,
          A_zero_point,
          A,
          B_zero_point,
          B,
          C_multiplier,
          C_zero_point,
          C_int32_temp,
          C,
          col_offsets,
          bias,
          act_times_w_scale,
          thread_id,
          num_threads);
    } else {
      depthwise_2d_<
          S,
          FUSE_RELU,
          HAS_BIAS,
          false /*A_symmetric*/,
          false /*B_symmetric*/,
          Q_GRAN>(
          N,
          H,
          W,
          IC,
          OC,
          stride_h,
          stride_w,
          A_zero_point,
          A,
          B_zero_point,
          B,
          C_multiplier,
          C_zero_point,
          C_int32_temp,
          C,
          col_offsets,
          bias,
          act_times_w_scale,
          thread_id,
          num_threads);
    }
  }
  fbgemmAlignedFree(C_int32_temp);
}

// Dispatch HAS_BIAS
template <
    int S,
    bool FUSE_RELU,
    QuantizationGranularity Q_GRAN,
    typename BIAS_TYPE>
static void depthwise_2d_(
    int N,
    int H,
    int W,
    int IC,
    int OC,
    int stride_h,
    int stride_w,
    std::int32_t A_zero_point,
    const std::uint8_t* A,
    const std::int32_t* B_zero_point,
    const PackedDepthWiseConvMatrix& B,
    const float* C_multiplier,
    std::int32_t C_zero_point,
    std::uint8_t* C,
    const std::int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale,
    int thread_id,
    int num_threads) {
  if (bias) {
    depthwise_2d_<S, FUSE_RELU, true /*HAS_BIAS*/, Q_GRAN>(
        N,
        H,
        W,
        IC,
        OC,
        stride_h,
        stride_w,
        A_zero_point,
        A,
        B_zero_point,
        B,
        C_multiplier,
        C_zero_point,
        C,
        col_offsets,
        bias,
        act_times_w_scale,
        thread_id,
        num_threads);
  } else {
    depthwise_2d_<S, FUSE_RELU, false /*HAS_BIAS*/, Q_GRAN>(
        N,
        H,
        W,
        IC,
        OC,
        stride_h,
        stride_w,
        A_zero_point,
        A,
        B_zero_point,
        B,
        C_multiplier,
        C_zero_point,
        C,
        col_offsets,
        bias,
        act_times_w_scale,
        thread_id,
        num_threads);
  }
}
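
// Example (illustrative only; all sizes and pointer names below are
// hypothetical): a 3x3, stride-1 depthwise convolution dispatched through the
// HAS_BIAS overload on a single thread. BIAS_TYPE is deduced from the bias
// pointer.
//
//   depthwise_2d_<3, /*FUSE_RELU=*/false, QuantizationGranularity::TENSOR>(
//       /*N=*/1, /*H=*/56, /*W=*/56, /*IC=*/64, /*OC=*/64,
//       /*stride_h=*/1, /*stride_w=*/1,
//       A_zero_point, A, B_zero_point, B,
//       C_multiplier, C_zero_point, C,
//       col_offsets, bias, act_times_w_scale,
//       /*thread_id=*/0, /*num_threads=*/1);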

} // namespace fbgemm