1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | #include <string.h> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "dnnl_common.hpp" |
26 | |
27 | #include "deconv/deconv.hpp" |
28 | |
29 | #define HALF_MAX 65504 |
30 | #define HALF_MIN (-65504) |
31 | |
32 | namespace deconv { |
33 | |
/* Data-type configuration ("cfg") definitions.
 * Each cfg is an array with one entry per tensor, in this order:
 *   SRC, WEI, BIA, DST, ACC
 * Each entry has the fields:
 *   {data_type, min, max, f_min, f_max, f_base, f_step, f_sparsity, eps}
 */
38 | |
39 | const int int_max_exact_half = 1 << 11; |
40 | const _dt_conf_t conf_f16 = { |
41 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
42 | 0.}, |
43 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
44 | 0.}, |
45 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0, |
46 | 0.}, |
47 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
48 | 0.}, |
49 | {dnnl_f32}, |
50 | }; |
51 | |
52 | const int int_max_exact = 1 << 24; |
53 | const _dt_conf_t conf_f16f16f32 = { |
54 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
55 | 0.}, |
56 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
57 | 0.}, |
58 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
59 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
60 | {dnnl_f32}, |
61 | }; |
62 | |
63 | const _dt_conf_t conf_f32f16f16 = { |
64 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
65 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
66 | 0.}, |
67 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0, |
68 | 0.}, |
69 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
70 | 0.}, |
71 | {dnnl_f32}, |
72 | }; |
73 | |
74 | const _dt_conf_t conf_f16f32f16 = { |
75 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
76 | 0.}, |
77 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
78 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0, |
79 | 0.}, |
80 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
81 | 0.}, |
82 | {dnnl_f32}, |
83 | }; |
84 | |
85 | const _dt_conf_t conf_f32 = { |
86 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.}, |
87 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.}, |
88 | {dnnl_f32, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.}, |
89 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.}, |
90 | {dnnl_f32}, |
91 | }; |
92 | |
93 | const _dt_conf_t conf_f32_wino = { |
94 | {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 1e-5}, |
95 | {dnnl_f32, -FLT_MAX, FLT_MAX, 2, 64, 2, 1, .75, 6e-6}, |
96 | {dnnl_f32, -FLT_MAX, FLT_MAX, 1, 128, 1, 1, .25, 2e-7}, |
97 | {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 2e-5}, |
98 | {dnnl_f32}, |
99 | }; |
100 | |
101 | const _dt_conf_t conf_f64 = { |
102 | {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.}, |
103 | {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.}, |
104 | {dnnl_f64, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.}, |
105 | {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.}, |
106 | {dnnl_f64}, |
107 | }; |
108 | |
109 | const _dt_conf_t conf_f32_with_bf16_fpmath = { |
110 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
111 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
112 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
113 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
114 | {dnnl_f32}, |
115 | }; |
116 | |
117 | const _dt_conf_t conf_f32_with_tf32_fpmath = { |
118 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
119 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
120 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
121 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
122 | {dnnl_f32}, |
123 | }; |
124 | |
125 | const _dt_conf_t conf_f16_wino = { |
126 | {dnnl_f16, HALF_MIN, HALF_MAX, -2, 16, 0, 1, .25, 5e-3}, |
127 | {dnnl_f16, HALF_MIN, HALF_MAX, 1, 6, -2, 1, .5, 6e-3}, |
128 | {dnnl_f16, HALF_MIN, HALF_MAX, 1, 2048, 0, 1, .25, 2e-3}, |
129 | {dnnl_f16, HALF_MIN, HALF_MAX, -2, 8, 0, 1, .25, 7e-3}, |
130 | {dnnl_f32}, |
131 | }; |
132 | |
133 | const _dt_conf_t conf_bf16bf16f32 = { |
134 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
135 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
136 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
137 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
138 | {dnnl_f32}, |
139 | }; |
140 | |
141 | const _dt_conf_t conf_bf16bf16f16 = { |
142 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
143 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
144 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
145 | 0.}, |
146 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
147 | 0.}, |
148 | {dnnl_f16}, |
149 | }; |
150 | |
151 | const _dt_conf_t conf_bf16bf16s8 = { |
152 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
153 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
154 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
155 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
156 | {dnnl_f32}, |
157 | }; |
158 | |
159 | const _dt_conf_t conf_bf16bf16u8 = { |
160 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
161 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
162 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
163 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
164 | {dnnl_f32}, |
165 | }; |
166 | |
167 | const _dt_conf_t conf_bf16bf16bf16 = { |
168 | /* eps is 1e-2 because of loss in precision of |
169 | * output when converted from fp32 to bf16. |
170 | * oneDNN output is compared against reference computed in fp32.*/ |
171 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2}, |
172 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2}, |
173 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2}, |
174 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2}, |
175 | {dnnl_f32}, |
176 | }; |
177 | |
178 | const _dt_conf_t conf_f32bf16bf16 = { |
179 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
180 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
181 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
182 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
183 | {dnnl_f32}, |
184 | }; |
185 | |
186 | const _dt_conf_t conf_f32f32s8 = { |
187 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
188 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
189 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
190 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
191 | {dnnl_f32}, |
192 | }; |
193 | |
194 | const _dt_conf_t conf_f32f32u8 = { |
195 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
196 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
197 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
198 | {dnnl_s8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
199 | {dnnl_f32}, |
200 | }; |
201 | |
202 | const _dt_conf_t conf_f16f16s8 = { |
203 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
204 | 0.}, |
205 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
206 | 0.}, |
207 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0, |
208 | 0.}, |
209 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
210 | {dnnl_f32}, |
211 | }; |
212 | |
213 | const _dt_conf_t conf_f16f16u8 = { |
214 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
215 | 0.}, |
216 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
217 | 0.}, |
218 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0, |
219 | 0.}, |
220 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
221 | {dnnl_f32}, |
222 | }; |
223 | |
224 | const _dt_conf_t conf_bf16f32bf16 = { |
225 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
226 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
227 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
228 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
229 | {dnnl_f32}, |
230 | }; |
231 | |
232 | const _dt_conf_t conf_u8s8f32 = { |
233 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
234 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
235 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
236 | {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.}, |
237 | {dnnl_s32}, |
238 | }; |
239 | |
240 | const _dt_conf_t conf_u8s8f16 = { |
241 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.}, |
242 | {dnnl_s8, INT8_MIN, INT8_MAX, -2, 2, -2, 1, 1.0, 0.}, |
243 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0, |
244 | 0.}, |
245 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
246 | 0.}, |
247 | {dnnl_s32}, |
248 | }; |
249 | |
250 | const _dt_conf_t conf_u8s8bf16 = { |
251 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
252 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
253 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
254 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2}, |
255 | {dnnl_s32}, |
256 | }; |
257 | |
258 | const _dt_conf_t conf_u8s8s32 = { |
259 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
260 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
261 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
262 | {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.}, |
263 | {dnnl_s32}, |
264 | }; |
265 | |
266 | const _dt_conf_t conf_u8s8s8 = { |
267 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.}, |
268 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
269 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
270 | {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.}, |
271 | {dnnl_s32}, |
272 | }; |
273 | |
274 | const _dt_conf_t conf_u8s8u8 = { |
275 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.}, |
276 | {dnnl_s8, INT8_MIN, INT8_MAX, -3, 5, 0, 1, .25, 0.}, |
277 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
278 | {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.}, |
279 | {dnnl_s32}, |
280 | }; |
281 | |
282 | const _dt_conf_t conf_s8s8f32 = { |
283 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
284 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
285 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
286 | {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.}, |
287 | {dnnl_s32}, |
288 | }; |
289 | |
290 | const _dt_conf_t conf_s8s8f16 = { |
291 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
292 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
293 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
294 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
295 | 0.}, |
296 | {dnnl_s32}, |
297 | }; |
298 | |
299 | const _dt_conf_t conf_s8s8bf16 = { |
300 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
301 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
302 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
303 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2}, |
304 | {dnnl_s32}, |
305 | }; |
306 | |
307 | const _dt_conf_t conf_s8s8s32 = { |
308 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
309 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
310 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
311 | {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.}, |
312 | {dnnl_s32}, |
313 | }; |
314 | |
315 | const _dt_conf_t conf_s8s8s8 = { |
316 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
317 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
318 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
319 | {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.}, |
320 | {dnnl_s32}, |
321 | }; |
322 | |
323 | const _dt_conf_t conf_s8s8u8 = { |
324 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
325 | {dnnl_s8, INT8_MIN, INT8_MAX, -4, 7, 0, 4, .25, 0.}, |
326 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
327 | {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.}, |
328 | {dnnl_s32}, |
329 | }; |
330 | |
331 | const dt_conf_t *str2cfg(const char *str) { |
332 | #define CASE(cfg) \ |
333 | if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg) |
334 | CASE(f16); |
335 | CASE(f32); |
336 | CASE(f32_wino); |
337 | CASE(f64); |
338 | CASE(u8s8f32); |
339 | CASE(u8s8f16); |
340 | CASE(u8s8bf16); |
341 | CASE(u8s8s32); |
342 | CASE(u8s8s8); |
343 | CASE(u8s8u8); |
344 | CASE(s8s8f32); |
345 | CASE(s8s8f16); |
346 | CASE(s8s8bf16); |
347 | CASE(s8s8s32); |
348 | CASE(s8s8s8); |
349 | CASE(s8s8u8); |
350 | CASE(bf16bf16f32); |
351 | CASE(bf16bf16f16); |
352 | CASE(bf16bf16s8); |
353 | CASE(bf16bf16u8); |
354 | CASE(bf16bf16bf16); |
355 | CASE(f32bf16bf16); |
356 | CASE(bf16f32bf16); |
357 | CASE(f32f32s8); |
358 | CASE(f32f32u8); |
359 | CASE(f16f16f32); |
360 | CASE(f32f16f16); |
361 | CASE(f16f32f16); |
362 | CASE(f16f16s8); |
363 | CASE(f16f16u8); |
364 | #undef CASE |
365 | []() { |
366 | SAFE(FAIL, CRIT); |
367 | return 0; |
368 | }(); |
369 | return (const dt_conf_t *)1; |
370 | } |
371 | |
372 | std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) { |
373 | #define CASE(_cfg) \ |
374 | if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg) |
375 | CASE(f16); |
376 | CASE(f32); |
377 | CASE(f32_wino); |
378 | CASE(f64); |
379 | CASE(u8s8f32); |
380 | CASE(u8s8f16); |
381 | CASE(u8s8bf16); |
382 | CASE(u8s8s32); |
383 | CASE(u8s8s8); |
384 | CASE(u8s8u8); |
385 | CASE(s8s8f32); |
386 | CASE(s8s8f16); |
387 | CASE(s8s8bf16); |
388 | CASE(s8s8s32); |
389 | CASE(s8s8s8); |
390 | CASE(s8s8u8); |
391 | CASE(f16f16f32); |
392 | CASE(f32f16f16); |
393 | CASE(f16f32f16); |
394 | CASE(f16f16s8); |
395 | CASE(f16f16u8); |
396 | CASE(bf16bf16f32); |
397 | CASE(bf16bf16f16); |
398 | CASE(bf16bf16s8); |
399 | CASE(bf16bf16u8); |
400 | CASE(bf16bf16bf16); |
401 | CASE(f32bf16bf16); |
402 | CASE(f32f32s8); |
403 | CASE(f32f32u8); |
404 | CASE(bf16f32bf16); |
405 | #undef CASE |
406 | SAFE_V(FAIL); |
407 | return s; |
408 | } |
409 | |
410 | const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg) { |
411 | if (alg != WINO) return cfg; |
412 | |
413 | std::stringstream ss; |
414 | ss << cfg << "_wino" ; |
415 | const std::string cpp_pstr = ss.str(); |
416 | const char *cfg_s = cpp_pstr.c_str(); |
417 | #define CASE(_cfg_) \ |
418 | if (!strcmp(cfg_s, STRINGIFY(_cfg_))) return CONCAT2(conf_, _cfg_) |
419 | CASE(f32_wino); |
420 | CASE(f16_wino); |
421 | #undef CASE |
422 | return cfg; |
423 | } |
424 | |
425 | } // namespace deconv |
426 | |