1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmI8DepthwiseAvx2.h" |
9 | |
#include <stdexcept> // for logic_error
#include <string>
#include <type_traits> // for integral_constant

#include "./FbgemmI8Depthwise2DAvx2-inl.h"
14 | |
15 | using namespace std; |
16 | |
17 | namespace fbgemm { |
18 | |
19 | // Dispatch input shape and FUSE_RELU |
20 | template <QuantizationGranularity Q_GRAN, typename BIAS_TYPE /*=std::int32_t*/> |
21 | void depthwise_2d_same_pad( |
22 | int N, |
23 | int H, |
24 | int W, |
25 | int IC, |
26 | int OC, |
27 | int stride_h, |
28 | int stride_w, |
29 | int32_t A_zero_point, |
30 | const uint8_t* A, |
31 | const int32_t* B_zero_point, |
32 | const PackedDepthWiseConvMatrix& B, |
33 | const float* C_multiplier, |
34 | int32_t C_zero_point, |
35 | uint8_t* C, |
36 | const int32_t* col_offsets, |
37 | const BIAS_TYPE* bias, |
38 | bool fuse_relu, |
39 | const float* act_times_w_scale, |
40 | int thread_id, |
41 | int num_threads) { |
42 | if (B.GetKernelProduct() == 3 * 3) { |
43 | if (fuse_relu) { |
44 | depthwise_2d_<3, true /* FUSE_RELU */, Q_GRAN>( |
45 | N, |
46 | H, |
47 | W, |
48 | IC, |
49 | OC, |
50 | stride_h, |
51 | stride_w, |
52 | A_zero_point, |
53 | A, |
54 | B_zero_point, |
55 | B, |
56 | C_multiplier, |
57 | C_zero_point, |
58 | C, |
59 | col_offsets, |
60 | bias, |
61 | act_times_w_scale, |
62 | thread_id, |
63 | num_threads); |
64 | } else { |
65 | depthwise_2d_<3, false /* FUSE_RELU */, Q_GRAN>( |
66 | N, |
67 | H, |
68 | W, |
69 | IC, |
70 | OC, |
71 | stride_h, |
72 | stride_w, |
73 | A_zero_point, |
74 | A, |
75 | B_zero_point, |
76 | B, |
77 | C_multiplier, |
78 | C_zero_point, |
79 | C, |
80 | col_offsets, |
81 | bias, |
82 | act_times_w_scale, |
83 | thread_id, |
84 | num_threads); |
85 | } |
86 | return; |
87 | } |
88 | |
89 | if (B.GetKernelProduct() == 5 * 5) { |
90 | if (fuse_relu) { |
91 | depthwise_2d_<5, true /* FUSE_RELU */, Q_GRAN>( |
92 | N, |
93 | H, |
94 | W, |
95 | IC, |
96 | OC, |
97 | stride_h, |
98 | stride_w, |
99 | A_zero_point, |
100 | A, |
101 | B_zero_point, |
102 | B, |
103 | C_multiplier, |
104 | C_zero_point, |
105 | C, |
106 | col_offsets, |
107 | bias, |
108 | act_times_w_scale, |
109 | thread_id, |
110 | num_threads); |
111 | } else { |
112 | depthwise_2d_<5, false /* FUSE_RELU */, Q_GRAN>( |
113 | N, |
114 | H, |
115 | W, |
116 | IC, |
117 | OC, |
118 | stride_h, |
119 | stride_w, |
120 | A_zero_point, |
121 | A, |
122 | B_zero_point, |
123 | B, |
124 | C_multiplier, |
125 | C_zero_point, |
126 | C, |
127 | col_offsets, |
128 | bias, |
129 | act_times_w_scale, |
130 | thread_id, |
131 | num_threads); |
132 | } |
133 | return; |
134 | } |
135 | |
136 | if (B.GetKernelProduct() != 7 * 7) { |
137 | string msg = |
138 | "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + |
139 | to_string(7 * 7) + " but has " + to_string(B.GetKernelProduct()); |
140 | throw logic_error(msg); |
141 | } |
142 | |
143 | if (fuse_relu) { |
144 | depthwise_2d_<7, true /* FUSE_RELU */, Q_GRAN>( |
145 | N, |
146 | H, |
147 | W, |
148 | IC, |
149 | OC, |
150 | stride_h, |
151 | stride_w, |
152 | A_zero_point, |
153 | A, |
154 | B_zero_point, |
155 | B, |
156 | C_multiplier, |
157 | C_zero_point, |
158 | C, |
159 | col_offsets, |
160 | bias, |
161 | act_times_w_scale, |
162 | thread_id, |
163 | num_threads); |
164 | } else { |
165 | depthwise_2d_<7, false /* FUSE_RELU */, Q_GRAN>( |
166 | N, |
167 | H, |
168 | W, |
169 | IC, |
170 | OC, |
171 | stride_h, |
172 | stride_w, |
173 | A_zero_point, |
174 | A, |
175 | B_zero_point, |
176 | B, |
177 | C_multiplier, |
178 | C_zero_point, |
179 | C, |
180 | col_offsets, |
181 | bias, |
182 | act_times_w_scale, |
183 | thread_id, |
184 | num_threads); |
185 | } |
186 | } |
187 | |
188 | #define INSTANTIATE_BASE(Q_GRAN, BIAS_TYPE) \ |
189 | template FBGEMM_API void \ |
190 | depthwise_2d_same_pad<QuantizationGranularity::Q_GRAN>( \ |
191 | int N, \ |
192 | int H, \ |
193 | int W, \ |
194 | int IC, \ |
195 | int OC, \ |
196 | int stride_h, \ |
197 | int stride_w, \ |
198 | int32_t A_zero_point, \ |
199 | const uint8_t* A, \ |
200 | const int32_t* B_zero_point, \ |
201 | const PackedDepthWiseConvMatrix& B, \ |
202 | const float* C_multiplier, \ |
203 | int32_t C_zero_point, \ |
204 | uint8_t* C, \ |
205 | const int32_t* col_offsets, \ |
206 | const BIAS_TYPE* bias, \ |
207 | bool fuse_relu, \ |
208 | const float* act_times_w_scale, \ |
209 | int thread_id, \ |
210 | int num_threads); |
211 | |
212 | #define INSTANTIATE_BIAS_T(Q_GRAN) \ |
213 | INSTANTIATE_BASE(Q_GRAN, int32_t) \ |
214 | INSTANTIATE_BASE(Q_GRAN, float) |
215 | |
216 | INSTANTIATE_BIAS_T(TENSOR) |
217 | INSTANTIATE_BIAS_T(GROUP) |
218 | INSTANTIATE_BIAS_T(OUT_CHANNEL) |
219 | |
220 | #undef INSTANTIATE_BIAS_T |
221 | #undef INSTANTIATE_CT |
222 | #undef INSTANTIATE_BASE |
223 | |
224 | } // namespace fbgemm |
225 | |