1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | |
9 | #include "./FbgemmI8DepthwiseAvx2-inl.h" |
10 | #include "./GenerateI8Depthwise.h" |
11 | #include "./MaskAvx2.h" |
12 | #include "fbgemm/Utils.h" |
13 | #include "fbgemm/UtilsAvx2.h" |
14 | |
15 | namespace fbgemm { |
16 | |
17 | template < |
18 | int S, |
19 | bool FUSE_RELU, |
20 | bool HAS_BIAS, |
21 | bool A_SYMMETRIC, |
22 | bool B_SYMMETRIC, |
23 | QuantizationGranularity Q_GRAN, |
24 | typename BIAS_TYPE> |
25 | static ALWAYS_INLINE void depthwise_2d_kernel_( |
26 | int H, |
27 | int W, |
28 | int IC, |
29 | int OC, |
30 | int h, |
31 | int w, |
32 | int stride_h, |
33 | int stride_w, |
34 | std::int32_t A_zero_point, |
35 | const std::uint8_t* A, |
36 | const std::int32_t* B_zero_point, |
37 | const std::int8_t* Bp, |
38 | const float* C_multiplier, |
39 | std::int32_t C_zero_point, |
40 | std::int32_t* C_int32, |
41 | std::uint8_t* C_uint8, |
42 | std::int32_t* row_offsets, |
43 | const std::int32_t* col_offsets, |
44 | const BIAS_TYPE* bias, |
45 | const float* act_times_w_scale, |
46 | GenI8Depthwise::jit_kernel_signature* pregenerated_kernel = nullptr) { |
47 | constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2; |
48 | int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; |
49 | int h_in = -PAD_T + h * stride_h; |
50 | int w_in = -PAD_L + w * stride_w; |
51 | |
52 | int remainder = OC % 32; |
53 | if (remainder == 0) { |
54 | remainder = 32; |
55 | } |
56 | |
57 | GenI8Depthwise::jit_kernel_signature kernel = pregenerated_kernel |
58 | ? *pregenerated_kernel |
59 | : GenI8Depthwise().getOrCreate( |
60 | /*D=*/2, |
61 | {1, S, S}, |
62 | OC / IC, |
63 | /*compute_a_sum=*/!B_SYMMETRIC, |
64 | remainder, |
65 | 0, |
66 | 0, |
67 | /*top_skip=*/std::max(-h_in, 0), |
68 | /*bottom_skip=*/std::max(h_in + S - H, 0), |
69 | /*left_skip=*/std::max(-w_in, 0), |
70 | /*right_skip=*/std::max(w_in + S - W, 0)); |
71 | |
72 | kernel( |
73 | A + (h_in * W + w_in) * IC, |
74 | Bp, |
75 | C_int32, |
76 | B_SYMMETRIC ? nullptr : row_offsets, |
77 | H, |
78 | W, |
79 | IC, |
80 | internal::avx2_ps_or_epi32_combined_mask, |
81 | A_zero_point); |
82 | |
83 | if (OC == IC) { |
84 | requantize_<FUSE_RELU, HAS_BIAS, Q_GRAN, A_SYMMETRIC, B_SYMMETRIC, 1>( |
85 | A_zero_point, |
86 | B_zero_point, |
87 | C_multiplier, |
88 | C_zero_point, |
89 | C_int32, |
90 | C_uint8 + (h * W_OUT + w) * OC, |
91 | OC, |
92 | row_offsets, |
93 | col_offsets, |
94 | bias, |
95 | act_times_w_scale); |
96 | } else { |
97 | requantize_<FUSE_RELU, HAS_BIAS, Q_GRAN, A_SYMMETRIC, B_SYMMETRIC, 2>( |
98 | A_zero_point, |
99 | B_zero_point, |
100 | C_multiplier, |
101 | C_zero_point, |
102 | C_int32, |
103 | C_uint8 + (h * W_OUT + w) * OC, |
104 | OC, |
105 | row_offsets, |
106 | col_offsets, |
107 | bias, |
108 | act_times_w_scale); |
109 | } |
110 | } |
111 | |
// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
// This implementation should be general enough to handle not just 3x3 but
// other filter shapes by parameterizing with R and S but restricting it to
// just 3x3 for now.
116 | template < |
117 | int S, |
118 | bool FUSE_RELU, |
119 | bool HAS_BIAS, |
120 | bool A_SYMMETRIC, |
121 | bool B_SYMMETRIC, |
122 | QuantizationGranularity Q_GRAN, |
123 | typename BIAS_TYPE> |
124 | static ALWAYS_INLINE void depthwise_2d_( |
125 | int N, |
126 | int H, |
127 | int W, |
128 | int IC, |
129 | int OC, |
130 | int stride_h, |
131 | int stride_w, |
132 | std::int32_t A_zero_point, |
133 | const std::uint8_t* A, |
134 | const std::int32_t* B_zero_point, |
135 | const PackedDepthWiseConvMatrix& B, |
136 | const float* C_multiplier, |
137 | std::int32_t C_zero_point, |
138 | std::int32_t* C_int32, |
139 | std::uint8_t* C_uint8, |
140 | const std::int32_t* col_offsets, |
141 | const BIAS_TYPE* bias, |
142 | const float* act_times_w_scale, |
143 | int thread_id, |
144 | int num_threads) { |
145 | assert(IC % 8 == 0); |
146 | constexpr int R = S; |
147 | constexpr int64_t PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2, |
148 | PAD_R = PAD_L; |
149 | int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; |
150 | int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; |
151 | const std::int8_t* Bp = B.PackedMat(); |
152 | |
153 | int32_t* row_offsets = static_cast<int32_t*>( |
154 | fbgemmAlignedAlloc(64, (IC + 31) / 32 * 32 * sizeof(int32_t))); |
155 | |
156 | int64_t n_begin, n_end, h_begin, h_end, w_begin, w_end; |
157 | // Reuse the 3-dim partition scheme for parallelization in matrix |
158 | // multiplication. |
159 | thread_type_t th_info = |
160 | fbgemmGetThreadPartition(N, H_OUT, W_OUT, thread_id, num_threads); |
161 | // Calculate the begin and end index along the batch (N) dimension |
162 | fbgemmPartition1D( |
163 | th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end); |
164 | // Calculate the begin and end index along the H dimension |
165 | fbgemmPartition1D( |
166 | th_info.m_thread_id, th_info.m_num_threads, H_OUT, h_begin, h_end); |
167 | // Calculate the begin and end index along the W dimension |
168 | fbgemmPartition1D( |
169 | th_info.n_thread_id, th_info.n_num_threads, W_OUT, w_begin, w_end); |
170 | |
171 | GenI8Depthwise::jit_kernel_signature middle_kernel; |
172 | |
173 | for (int n = n_begin; n < n_end; ++n) { |
174 | const std::uint8_t* A_base = A + n * H * W * IC; |
175 | std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * OC; |
176 | |
177 | int h = 0; |
178 | int w = 0; |
179 | |
180 | for (h = h_begin; h < std::min(PAD_T, h_end); ++h) { |
181 | for (w = w_begin; w < std::min(PAD_L, w_end); ++w) { |
182 | depthwise_2d_kernel_< |
183 | S, |
184 | FUSE_RELU, |
185 | HAS_BIAS, |
186 | A_SYMMETRIC, |
187 | B_SYMMETRIC, |
188 | Q_GRAN>( |
189 | H, |
190 | W, |
191 | IC, |
192 | OC, |
193 | h, |
194 | w, |
195 | stride_h, |
196 | stride_w, |
197 | A_zero_point, |
198 | A_base, |
199 | B_zero_point, |
200 | Bp, |
201 | C_multiplier, |
202 | C_zero_point, |
203 | C_int32, |
204 | C_uint8_base, |
205 | row_offsets, |
206 | col_offsets, |
207 | bias, |
208 | act_times_w_scale); |
209 | } |
210 | |
211 | for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) { |
212 | depthwise_2d_kernel_< |
213 | S, |
214 | FUSE_RELU, |
215 | HAS_BIAS, |
216 | A_SYMMETRIC, |
217 | B_SYMMETRIC, |
218 | Q_GRAN>( |
219 | H, |
220 | W, |
221 | IC, |
222 | OC, |
223 | h, |
224 | w, |
225 | stride_h, |
226 | stride_w, |
227 | A_zero_point, |
228 | A_base, |
229 | B_zero_point, |
230 | Bp, |
231 | C_multiplier, |
232 | C_zero_point, |
233 | C_int32, |
234 | C_uint8_base, |
235 | row_offsets, |
236 | col_offsets, |
237 | bias, |
238 | act_times_w_scale); |
239 | } |
240 | |
241 | for (; w < w_end; ++w) { |
242 | depthwise_2d_kernel_< |
243 | S, |
244 | FUSE_RELU, |
245 | HAS_BIAS, |
246 | A_SYMMETRIC, |
247 | B_SYMMETRIC, |
248 | Q_GRAN>( |
249 | H, |
250 | W, |
251 | IC, |
252 | OC, |
253 | h, |
254 | w, |
255 | stride_h, |
256 | stride_w, |
257 | A_zero_point, |
258 | A_base, |
259 | B_zero_point, |
260 | Bp, |
261 | C_multiplier, |
262 | C_zero_point, |
263 | C_int32, |
264 | C_uint8_base, |
265 | row_offsets, |
266 | col_offsets, |
267 | bias, |
268 | act_times_w_scale); |
269 | } |
270 | } |
271 | |
272 | // h <= H_OUT - PAD_B - stride_h |
273 | // h <= (H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h |
274 | // h_in <= -PAD_T + |
275 | // ((H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h) * stride_h |
276 | // Case 1) For stride_h == 1, |
277 | // h_in <= -PAD_T + H + PAD_T + PAD_B - S + 1 - PAD_B - 1 |
278 | // h_in + S - H <= 0 |
279 | // Case 2) For stride_h == 2, |
280 | // h_in <= -PAD_L + |
281 | // H + PAD_T + PAD_B - S + 1 + (1 - PAD_B - stride_h) * stride_h |
282 | // h_in + S - H <= PAD_B * (1 - stride_h) + 1 + (1 - stride_h) * stride_h |
283 | // <= -PAD_B + 1 - stride_h <= 0 |
284 | for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) { |
285 | for (w = w_begin; w < std::min(PAD_L, w_end); ++w) { |
286 | depthwise_2d_kernel_< |
287 | S, |
288 | FUSE_RELU, |
289 | HAS_BIAS, |
290 | A_SYMMETRIC, |
291 | B_SYMMETRIC, |
292 | Q_GRAN>( |
293 | H, |
294 | W, |
295 | IC, |
296 | OC, |
297 | h, |
298 | w, |
299 | stride_h, |
300 | stride_w, |
301 | A_zero_point, |
302 | A_base, |
303 | B_zero_point, |
304 | Bp, |
305 | C_multiplier, |
306 | C_zero_point, |
307 | C_int32, |
308 | C_uint8_base, |
309 | row_offsets, |
310 | col_offsets, |
311 | bias, |
312 | act_times_w_scale); |
313 | } |
314 | |
315 | for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) { |
316 | if (n == n_begin && w == std::max(PAD_L, w_begin)) { |
317 | int remainder = OC % 32; |
318 | if (remainder == 0) { |
319 | remainder = 32; |
320 | } |
321 | middle_kernel = GenI8Depthwise().getOrCreate( |
322 | /*D=*/2, |
323 | {1, S, S}, |
324 | OC / IC, |
325 | /*compute_a_sum=*/!B_SYMMETRIC, |
326 | remainder, |
327 | 0, |
328 | 0, |
329 | 0, |
330 | 0, |
331 | 0, |
332 | 0); |
333 | } |
334 | depthwise_2d_kernel_< |
335 | S, |
336 | FUSE_RELU, |
337 | HAS_BIAS, |
338 | A_SYMMETRIC, |
339 | B_SYMMETRIC, |
340 | Q_GRAN>( |
341 | H, |
342 | W, |
343 | IC, |
344 | OC, |
345 | h, |
346 | w, |
347 | stride_h, |
348 | stride_w, |
349 | A_zero_point, |
350 | A_base, |
351 | B_zero_point, |
352 | Bp, |
353 | C_multiplier, |
354 | C_zero_point, |
355 | C_int32, |
356 | C_uint8_base, |
357 | row_offsets, |
358 | col_offsets, |
359 | bias, |
360 | act_times_w_scale, |
361 | &middle_kernel); |
362 | } |
363 | |
364 | for (; w < w_end; ++w) { |
365 | depthwise_2d_kernel_< |
366 | S, |
367 | FUSE_RELU, |
368 | HAS_BIAS, |
369 | A_SYMMETRIC, |
370 | B_SYMMETRIC, |
371 | Q_GRAN>( |
372 | H, |
373 | W, |
374 | IC, |
375 | OC, |
376 | h, |
377 | w, |
378 | stride_h, |
379 | stride_w, |
380 | A_zero_point, |
381 | A_base, |
382 | B_zero_point, |
383 | Bp, |
384 | C_multiplier, |
385 | C_zero_point, |
386 | C_int32, |
387 | C_uint8_base, |
388 | row_offsets, |
389 | col_offsets, |
390 | bias, |
391 | act_times_w_scale); |
392 | } |
393 | } |
394 | |
395 | for (; h < h_end; ++h) { |
396 | for (w = w_begin; w < std::min(PAD_L, w_end); ++w) { |
397 | depthwise_2d_kernel_< |
398 | S, |
399 | FUSE_RELU, |
400 | HAS_BIAS, |
401 | A_SYMMETRIC, |
402 | B_SYMMETRIC, |
403 | Q_GRAN>( |
404 | H, |
405 | W, |
406 | IC, |
407 | OC, |
408 | h, |
409 | w, |
410 | stride_h, |
411 | stride_w, |
412 | A_zero_point, |
413 | A_base, |
414 | B_zero_point, |
415 | Bp, |
416 | C_multiplier, |
417 | C_zero_point, |
418 | C_int32, |
419 | C_uint8_base, |
420 | row_offsets, |
421 | col_offsets, |
422 | bias, |
423 | act_times_w_scale); |
424 | } |
425 | |
426 | for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) { |
427 | depthwise_2d_kernel_< |
428 | S, |
429 | FUSE_RELU, |
430 | HAS_BIAS, |
431 | A_SYMMETRIC, |
432 | B_SYMMETRIC, |
433 | Q_GRAN>( |
434 | H, |
435 | W, |
436 | IC, |
437 | OC, |
438 | h, |
439 | w, |
440 | stride_h, |
441 | stride_w, |
442 | A_zero_point, |
443 | A_base, |
444 | B_zero_point, |
445 | Bp, |
446 | C_multiplier, |
447 | C_zero_point, |
448 | C_int32, |
449 | C_uint8_base, |
450 | row_offsets, |
451 | col_offsets, |
452 | bias, |
453 | act_times_w_scale); |
454 | } |
455 | |
456 | for (; w < w_end; ++w) { |
457 | depthwise_2d_kernel_< |
458 | S, |
459 | FUSE_RELU, |
460 | HAS_BIAS, |
461 | A_SYMMETRIC, |
462 | B_SYMMETRIC, |
463 | Q_GRAN>( |
464 | H, |
465 | W, |
466 | IC, |
467 | OC, |
468 | h, |
469 | w, |
470 | stride_h, |
471 | stride_w, |
472 | A_zero_point, |
473 | A_base, |
474 | B_zero_point, |
475 | Bp, |
476 | C_multiplier, |
477 | C_zero_point, |
478 | C_int32, |
479 | C_uint8_base, |
480 | row_offsets, |
481 | col_offsets, |
482 | bias, |
483 | act_times_w_scale); |
484 | } |
485 | } |
486 | } // for each n |
487 | |
488 | fbgemmAlignedFree(row_offsets); |
489 | } |
490 | |
491 | // Dispatch A_SYMMETRIC and B_SYMMETRIC |
492 | template < |
493 | int S, |
494 | bool FUSE_RELU, |
495 | bool HAS_BIAS, |
496 | QuantizationGranularity Q_GRAN, |
497 | typename BIAS_TYPE> |
498 | static void depthwise_2d_( |
499 | int N, |
500 | int H, |
501 | int W, |
502 | int IC, |
503 | int OC, |
504 | int stride_h, |
505 | int stride_w, |
506 | std::int32_t A_zero_point, |
507 | const std::uint8_t* A, |
508 | const std::int32_t* B_zero_point, |
509 | const PackedDepthWiseConvMatrix& B, |
510 | const float* C_multiplier, |
511 | std::int32_t C_zero_point, |
512 | std::uint8_t* C, |
513 | const std::int32_t* col_offsets, |
514 | const BIAS_TYPE* bias, |
515 | const float* act_times_w_scale, |
516 | int thread_id, |
517 | int num_threads) { |
518 | int32_t* C_int32_temp = static_cast<int32_t*>( |
519 | fbgemmAlignedAlloc(64, (OC + 31) / 32 * 32 * sizeof(int32_t))); |
520 | if (A_zero_point == 0 || col_offsets == nullptr) { |
521 | if (Q_GRAN == QuantizationGranularity::TENSOR && B_zero_point[0] == 0) { |
522 | depthwise_2d_< |
523 | S, |
524 | FUSE_RELU, |
525 | HAS_BIAS, |
526 | true /*A_symmetric*/, |
527 | true /*B_symmetric*/, |
528 | Q_GRAN>( |
529 | N, |
530 | H, |
531 | W, |
532 | IC, |
533 | OC, |
534 | stride_h, |
535 | stride_w, |
536 | A_zero_point, |
537 | A, |
538 | B_zero_point, |
539 | B, |
540 | C_multiplier, |
541 | C_zero_point, |
542 | C_int32_temp, |
543 | C, |
544 | col_offsets, |
545 | bias, |
546 | act_times_w_scale, |
547 | thread_id, |
548 | num_threads); |
549 | } else { |
550 | depthwise_2d_< |
551 | S, |
552 | FUSE_RELU, |
553 | HAS_BIAS, |
554 | true /*A_symmetric*/, |
555 | false /*B_symmetric*/, |
556 | Q_GRAN>( |
557 | N, |
558 | H, |
559 | W, |
560 | IC, |
561 | OC, |
562 | stride_h, |
563 | stride_w, |
564 | A_zero_point, |
565 | A, |
566 | B_zero_point, |
567 | B, |
568 | C_multiplier, |
569 | C_zero_point, |
570 | C_int32_temp, |
571 | C, |
572 | col_offsets, |
573 | bias, |
574 | act_times_w_scale, |
575 | thread_id, |
576 | num_threads); |
577 | } |
578 | } else { |
579 | if (Q_GRAN == QuantizationGranularity::TENSOR && B_zero_point[0] == 0) { |
580 | depthwise_2d_< |
581 | S, |
582 | FUSE_RELU, |
583 | HAS_BIAS, |
584 | false /*A_symmetric*/, |
585 | true /*B_symmetric*/, |
586 | Q_GRAN>( |
587 | N, |
588 | H, |
589 | W, |
590 | IC, |
591 | OC, |
592 | stride_h, |
593 | stride_w, |
594 | A_zero_point, |
595 | A, |
596 | B_zero_point, |
597 | B, |
598 | C_multiplier, |
599 | C_zero_point, |
600 | C_int32_temp, |
601 | C, |
602 | col_offsets, |
603 | bias, |
604 | act_times_w_scale, |
605 | thread_id, |
606 | num_threads); |
607 | } else { |
608 | depthwise_2d_< |
609 | S, |
610 | FUSE_RELU, |
611 | HAS_BIAS, |
612 | false /*A_symmetric*/, |
613 | false /*B_symmetric*/, |
614 | Q_GRAN>( |
615 | N, |
616 | H, |
617 | W, |
618 | IC, |
619 | OC, |
620 | stride_h, |
621 | stride_w, |
622 | A_zero_point, |
623 | A, |
624 | B_zero_point, |
625 | B, |
626 | C_multiplier, |
627 | C_zero_point, |
628 | C_int32_temp, |
629 | C, |
630 | col_offsets, |
631 | bias, |
632 | act_times_w_scale, |
633 | thread_id, |
634 | num_threads); |
635 | } |
636 | } |
637 | fbgemmAlignedFree(C_int32_temp); |
638 | } |
639 | |
640 | // Dispatch HAS_BIAS |
641 | template < |
642 | int S, |
643 | bool FUSE_RELU, |
644 | QuantizationGranularity Q_GRAN, |
645 | typename BIAS_TYPE> |
646 | static void depthwise_2d_( |
647 | int N, |
648 | int H, |
649 | int W, |
650 | int IC, |
651 | int OC, |
652 | int stride_h, |
653 | int stride_w, |
654 | std::int32_t A_zero_point, |
655 | const std::uint8_t* A, |
656 | const std::int32_t* B_zero_point, |
657 | const PackedDepthWiseConvMatrix& B, |
658 | const float* C_multiplier, |
659 | std::int32_t C_zero_point, |
660 | std::uint8_t* C, |
661 | const std::int32_t* col_offsets, |
662 | const BIAS_TYPE* bias, |
663 | const float* act_times_w_scale, |
664 | int thread_id, |
665 | int num_threads) { |
666 | if (bias) { |
667 | depthwise_2d_<S, FUSE_RELU, true /*HAS_BIAS*/, Q_GRAN>( |
668 | N, |
669 | H, |
670 | W, |
671 | IC, |
672 | OC, |
673 | stride_h, |
674 | stride_w, |
675 | A_zero_point, |
676 | A, |
677 | B_zero_point, |
678 | B, |
679 | C_multiplier, |
680 | C_zero_point, |
681 | C, |
682 | col_offsets, |
683 | bias, |
684 | act_times_w_scale, |
685 | thread_id, |
686 | num_threads); |
687 | } else { |
688 | depthwise_2d_<S, FUSE_RELU, false /*HAS_BIAS*/, Q_GRAN>( |
689 | N, |
690 | H, |
691 | W, |
692 | IC, |
693 | OC, |
694 | stride_h, |
695 | stride_w, |
696 | A_zero_point, |
697 | A, |
698 | B_zero_point, |
699 | B, |
700 | C_multiplier, |
701 | C_zero_point, |
702 | C, |
703 | col_offsets, |
704 | bias, |
705 | act_times_w_scale, |
706 | thread_id, |
707 | num_threads); |
708 | } |
709 | } |
710 | |
711 | } // namespace fbgemm |
712 | |