1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
9
10#include <stdexcept> // for logic_error
11#include <string>
12
13#include "./FbgemmI8Depthwise2DAvx2-inl.h"
14
15using namespace std;
16
17namespace fbgemm {
18
19// Dispatch input shape and FUSE_RELU
20template <QuantizationGranularity Q_GRAN, typename BIAS_TYPE /*=std::int32_t*/>
21void depthwise_2d_same_pad(
22 int N,
23 int H,
24 int W,
25 int IC,
26 int OC,
27 int stride_h,
28 int stride_w,
29 int32_t A_zero_point,
30 const uint8_t* A,
31 const int32_t* B_zero_point,
32 const PackedDepthWiseConvMatrix& B,
33 const float* C_multiplier,
34 int32_t C_zero_point,
35 uint8_t* C,
36 const int32_t* col_offsets,
37 const BIAS_TYPE* bias,
38 bool fuse_relu,
39 const float* act_times_w_scale,
40 int thread_id,
41 int num_threads) {
42 if (B.GetKernelProduct() == 3 * 3) {
43 if (fuse_relu) {
44 depthwise_2d_<3, true /* FUSE_RELU */, Q_GRAN>(
45 N,
46 H,
47 W,
48 IC,
49 OC,
50 stride_h,
51 stride_w,
52 A_zero_point,
53 A,
54 B_zero_point,
55 B,
56 C_multiplier,
57 C_zero_point,
58 C,
59 col_offsets,
60 bias,
61 act_times_w_scale,
62 thread_id,
63 num_threads);
64 } else {
65 depthwise_2d_<3, false /* FUSE_RELU */, Q_GRAN>(
66 N,
67 H,
68 W,
69 IC,
70 OC,
71 stride_h,
72 stride_w,
73 A_zero_point,
74 A,
75 B_zero_point,
76 B,
77 C_multiplier,
78 C_zero_point,
79 C,
80 col_offsets,
81 bias,
82 act_times_w_scale,
83 thread_id,
84 num_threads);
85 }
86 return;
87 }
88
89 if (B.GetKernelProduct() == 5 * 5) {
90 if (fuse_relu) {
91 depthwise_2d_<5, true /* FUSE_RELU */, Q_GRAN>(
92 N,
93 H,
94 W,
95 IC,
96 OC,
97 stride_h,
98 stride_w,
99 A_zero_point,
100 A,
101 B_zero_point,
102 B,
103 C_multiplier,
104 C_zero_point,
105 C,
106 col_offsets,
107 bias,
108 act_times_w_scale,
109 thread_id,
110 num_threads);
111 } else {
112 depthwise_2d_<5, false /* FUSE_RELU */, Q_GRAN>(
113 N,
114 H,
115 W,
116 IC,
117 OC,
118 stride_h,
119 stride_w,
120 A_zero_point,
121 A,
122 B_zero_point,
123 B,
124 C_multiplier,
125 C_zero_point,
126 C,
127 col_offsets,
128 bias,
129 act_times_w_scale,
130 thread_id,
131 num_threads);
132 }
133 return;
134 }
135
136 if (B.GetKernelProduct() != 7 * 7) {
137 string msg =
138 "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
139 to_string(7 * 7) + " but has " + to_string(B.GetKernelProduct());
140 throw logic_error(msg);
141 }
142
143 if (fuse_relu) {
144 depthwise_2d_<7, true /* FUSE_RELU */, Q_GRAN>(
145 N,
146 H,
147 W,
148 IC,
149 OC,
150 stride_h,
151 stride_w,
152 A_zero_point,
153 A,
154 B_zero_point,
155 B,
156 C_multiplier,
157 C_zero_point,
158 C,
159 col_offsets,
160 bias,
161 act_times_w_scale,
162 thread_id,
163 num_threads);
164 } else {
165 depthwise_2d_<7, false /* FUSE_RELU */, Q_GRAN>(
166 N,
167 H,
168 W,
169 IC,
170 OC,
171 stride_h,
172 stride_w,
173 A_zero_point,
174 A,
175 B_zero_point,
176 B,
177 C_multiplier,
178 C_zero_point,
179 C,
180 col_offsets,
181 bias,
182 act_times_w_scale,
183 thread_id,
184 num_threads);
185 }
186}
187
188#define INSTANTIATE_BASE(Q_GRAN, BIAS_TYPE) \
189 template FBGEMM_API void \
190 depthwise_2d_same_pad<QuantizationGranularity::Q_GRAN>( \
191 int N, \
192 int H, \
193 int W, \
194 int IC, \
195 int OC, \
196 int stride_h, \
197 int stride_w, \
198 int32_t A_zero_point, \
199 const uint8_t* A, \
200 const int32_t* B_zero_point, \
201 const PackedDepthWiseConvMatrix& B, \
202 const float* C_multiplier, \
203 int32_t C_zero_point, \
204 uint8_t* C, \
205 const int32_t* col_offsets, \
206 const BIAS_TYPE* bias, \
207 bool fuse_relu, \
208 const float* act_times_w_scale, \
209 int thread_id, \
210 int num_threads);
211
212#define INSTANTIATE_BIAS_T(Q_GRAN) \
213 INSTANTIATE_BASE(Q_GRAN, int32_t) \
214 INSTANTIATE_BASE(Q_GRAN, float)
215
216INSTANTIATE_BIAS_T(TENSOR)
217INSTANTIATE_BIAS_T(GROUP)
218INSTANTIATE_BIAS_T(OUT_CHANNEL)
219
220#undef INSTANTIATE_BIAS_T
221#undef INSTANTIATE_CT
222#undef INSTANTIATE_BASE
223
224} // namespace fbgemm
225