1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | #include <string.h> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "dnnl_common.hpp" |
26 | |
27 | #include "conv/conv.hpp" |
28 | |
// Largest finite f16 (IEEE 754 binary16) value and its negation. In this file
// they serve as the min/max bounds of the rows of conf_f16_wino.
#define HALF_MAX 65504
#define HALF_MIN (-65504)
31 | |
32 | namespace conv { |
33 | |
34 | /* cfgs definition |
35 | * arrays: SRC, WEI, BIA, DST, ACC |
36 | * params: {data_type, min, max, f_min, f_max, f_base, f_step, f_sparsity, eps} |
37 | */ |
38 | |
// 2^11: every integer in [-2^11, 2^11] is exactly representable in f16
// (11-bit significand). Used below as the min/max bound of f16 rows.
const int int_max_exact_half = 1 << 11;
40 | const _dt_conf_t conf_f16 = { |
41 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
42 | 0.}, |
43 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
44 | 0.}, |
45 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0, |
46 | 0.}, |
47 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
48 | 0.}, |
49 | {dnnl_f32}, |
50 | }; |
51 | |
// 2^24: every integer in [-2^24, 2^24] is exactly representable in f32
// (24-bit significand). Used below as the min/max bound of the f32, f64 and
// bf16 rows.
const int int_max_exact = 1 << 24;
53 | const _dt_conf_t conf_f16f16f32 = { |
54 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
55 | 0.}, |
56 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
57 | 0.}, |
58 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
59 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
60 | {dnnl_f32}, |
61 | }; |
62 | |
63 | const _dt_conf_t conf_f32 = { |
64 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.}, |
65 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.}, |
66 | {dnnl_f32, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.}, |
67 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.}, |
68 | {dnnl_f32}, |
69 | }; |
70 | |
71 | const _dt_conf_t conf_f32_wino = { |
72 | {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 1e-5}, |
73 | {dnnl_f32, -FLT_MAX, FLT_MAX, 2, 64, 2, 1, .75, 6e-6}, |
74 | {dnnl_f32, -FLT_MAX, FLT_MAX, 1, 128, 1, 1, .25, 2e-7}, |
75 | {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 2e-5}, |
76 | {dnnl_f32}, |
77 | }; |
78 | |
79 | const _dt_conf_t conf_f64 = { |
80 | {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.}, |
81 | {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.}, |
82 | {dnnl_f64, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.}, |
83 | {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.}, |
84 | {dnnl_f64}, |
85 | }; |
86 | |
87 | const _dt_conf_t conf_f32_with_bf16_fpmath = { |
88 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
89 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
90 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
91 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
92 | {dnnl_f32}, |
93 | }; |
94 | |
95 | const _dt_conf_t conf_f32_with_tf32_fpmath = { |
96 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
97 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
98 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
99 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
100 | {dnnl_f32}, |
101 | }; |
102 | |
103 | const _dt_conf_t conf_f16_wino = { |
104 | {dnnl_f16, HALF_MIN, HALF_MAX, -2, 16, 0, 1, .25, 5e-3}, |
105 | {dnnl_f16, HALF_MIN, HALF_MAX, 1, 6, -2, 1, .5, 6e-3}, |
106 | {dnnl_f16, HALF_MIN, HALF_MAX, 1, 2048, 0, 1, .25, 2e-3}, |
107 | {dnnl_f16, HALF_MIN, HALF_MAX, -2, 8, 0, 1, .25, 7e-3}, |
108 | {dnnl_f32}, |
109 | }; |
110 | |
111 | const _dt_conf_t conf_bf16bf16f32 = { |
112 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
113 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
114 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
115 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
116 | {dnnl_f32}, |
117 | }; |
118 | |
119 | const _dt_conf_t conf_bf16bf16f16 = { |
120 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
121 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
122 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
123 | 0.}, |
124 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
125 | 0.}, |
126 | {dnnl_f16}, |
127 | }; |
128 | |
129 | const _dt_conf_t conf_bf16bf16s8 = { |
130 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
131 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
132 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
133 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
134 | {dnnl_f32}, |
135 | }; |
136 | |
137 | const _dt_conf_t conf_bf16bf16u8 = { |
138 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
139 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
140 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
141 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
142 | {dnnl_f32}, |
143 | }; |
144 | |
145 | const _dt_conf_t conf_bf16bf16bf16 = { |
146 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
147 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
148 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
149 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
150 | {dnnl_f32}, |
151 | }; |
152 | |
153 | const _dt_conf_t conf_f32bf16bf16 = { |
154 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
155 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
156 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
157 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
158 | {dnnl_f32}, |
159 | }; |
160 | |
161 | const _dt_conf_t conf_f32f32s8 = { |
162 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
163 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
164 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
165 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
166 | {dnnl_f32}, |
167 | }; |
168 | |
169 | const _dt_conf_t conf_f32f32u8 = { |
170 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
171 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
172 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
173 | {dnnl_s8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
174 | {dnnl_f32}, |
175 | }; |
176 | |
177 | const _dt_conf_t conf_f32f16f16 = { |
178 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
179 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
180 | 0.}, |
181 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0, |
182 | 0.}, |
183 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
184 | 0.}, |
185 | {dnnl_f32}, |
186 | }; |
187 | |
188 | const _dt_conf_t conf_f16f32f16 = { |
189 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
190 | 0.}, |
191 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
192 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
193 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
194 | 0.}, |
195 | {dnnl_f32}, |
196 | }; |
197 | |
198 | const _dt_conf_t conf_f16f16s8 = { |
199 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
200 | 0.}, |
201 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
202 | 0.}, |
203 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0, |
204 | 0.}, |
205 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
206 | {dnnl_f32}, |
207 | }; |
208 | |
209 | const _dt_conf_t conf_f16f16u8 = { |
210 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
211 | 0.}, |
212 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0, |
213 | 0.}, |
214 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0, |
215 | 0.}, |
216 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
217 | {dnnl_f32}, |
218 | }; |
219 | |
220 | const _dt_conf_t conf_bf16f32bf16 = { |
221 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
222 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
223 | {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
224 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
225 | {dnnl_f32}, |
226 | }; |
227 | |
228 | const _dt_conf_t conf_u8s8f32 = { |
229 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
230 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
231 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
232 | {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.}, |
233 | {dnnl_s32}, |
234 | }; |
235 | |
236 | const _dt_conf_t conf_u8s8f16 = { |
237 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.}, |
238 | {dnnl_s8, INT8_MIN, INT8_MAX, -2, 2, -2, 1, 1.0, 0.}, |
239 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0, |
240 | 0.}, |
241 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
242 | 0.}, |
243 | {dnnl_s32}, |
244 | }; |
245 | |
246 | const _dt_conf_t conf_u8s8bf16 = { |
247 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
248 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
249 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
250 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
251 | {dnnl_s32}, |
252 | }; |
253 | |
254 | const _dt_conf_t conf_u8s8s32 = { |
255 | {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.}, |
256 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
257 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
258 | {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.}, |
259 | {dnnl_s32}, |
260 | }; |
261 | |
262 | const _dt_conf_t conf_u8s8s8 = { |
263 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.}, |
264 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
265 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
266 | {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.}, |
267 | {dnnl_s32}, |
268 | }; |
269 | |
270 | const _dt_conf_t conf_u8s8u8 = { |
271 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.}, |
272 | {dnnl_s8, INT8_MIN, INT8_MAX, -3, 5, 0, 1, .25, 0.}, |
273 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
274 | {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.}, |
275 | {dnnl_s32}, |
276 | }; |
277 | |
278 | const _dt_conf_t conf_s8s8f32 = { |
279 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
280 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
281 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
282 | {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.}, |
283 | {dnnl_s32}, |
284 | }; |
285 | |
286 | const _dt_conf_t conf_s8s8f16 = { |
287 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
288 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
289 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
290 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25, |
291 | 0.}, |
292 | {dnnl_s32}, |
293 | }; |
294 | |
295 | const _dt_conf_t conf_s8s8bf16 = { |
296 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
297 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
298 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
299 | {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.}, |
300 | {dnnl_s32}, |
301 | }; |
302 | |
303 | const _dt_conf_t conf_s8s8s32 = { |
304 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
305 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
306 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
307 | {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.}, |
308 | {dnnl_s32}, |
309 | }; |
310 | |
311 | const _dt_conf_t conf_s8s8s8 = { |
312 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
313 | {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.}, |
314 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
315 | {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.}, |
316 | {dnnl_s32}, |
317 | }; |
318 | |
319 | const _dt_conf_t conf_s8s8u8 = { |
320 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.}, |
321 | {dnnl_s8, INT8_MIN, INT8_MAX, -4, 7, 0, 4, .25, 0.}, |
322 | {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.}, |
323 | {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.}, |
324 | {dnnl_s32}, |
325 | }; |
326 | |
327 | const dt_conf_t *str2cfg(const char *str) { |
328 | #define CASE(cfg) \ |
329 | if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg) |
330 | CASE(f16); |
331 | CASE(f32); |
332 | CASE(f32_wino); |
333 | CASE(f64); |
334 | CASE(u8s8f32); |
335 | CASE(u8s8f16); |
336 | CASE(u8s8bf16); |
337 | CASE(u8s8s32); |
338 | CASE(u8s8s8); |
339 | CASE(u8s8u8); |
340 | CASE(s8s8f32); |
341 | CASE(s8s8f16); |
342 | CASE(s8s8bf16); |
343 | CASE(s8s8s32); |
344 | CASE(s8s8s8); |
345 | CASE(s8s8u8); |
346 | CASE(bf16bf16f32); |
347 | CASE(bf16bf16f16); |
348 | CASE(bf16bf16s8); |
349 | CASE(bf16bf16u8); |
350 | CASE(bf16bf16bf16); |
351 | CASE(f32bf16bf16); |
352 | CASE(bf16f32bf16); |
353 | CASE(f32f32s8); |
354 | CASE(f32f32u8); |
355 | CASE(f16f16f32); |
356 | CASE(f16f16s8); |
357 | CASE(f16f16u8); |
358 | CASE(f32f16f16); |
359 | CASE(f16f32f16); |
360 | #undef CASE |
361 | []() { |
362 | SAFE(FAIL, CRIT); |
363 | return 0; |
364 | }(); |
365 | return (const dt_conf_t *)1; |
366 | } |
367 | |
368 | std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) { |
369 | #define CASE(_cfg) \ |
370 | if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg) |
371 | CASE(f16); |
372 | CASE(f32); |
373 | CASE(f32_wino); |
374 | CASE(f64); |
375 | CASE(u8s8f32); |
376 | CASE(u8s8f16); |
377 | CASE(u8s8bf16); |
378 | CASE(u8s8s32); |
379 | CASE(u8s8s8); |
380 | CASE(u8s8u8); |
381 | CASE(s8s8f32); |
382 | CASE(s8s8f16); |
383 | CASE(s8s8bf16); |
384 | CASE(s8s8s32); |
385 | CASE(s8s8s8); |
386 | CASE(s8s8u8); |
387 | CASE(f16f16f32); |
388 | CASE(f16f16s8); |
389 | CASE(f16f16u8); |
390 | CASE(bf16bf16f32); |
391 | CASE(bf16bf16f16); |
392 | CASE(bf16bf16s8); |
393 | CASE(bf16bf16u8); |
394 | CASE(bf16bf16bf16); |
395 | CASE(f32bf16bf16); |
396 | CASE(f32f32s8); |
397 | CASE(f32f32u8); |
398 | CASE(bf16f32bf16); |
399 | CASE(f32f16f16); |
400 | CASE(f16f32f16); |
401 | #undef CASE |
402 | SAFE_V(FAIL); |
403 | return s; |
404 | } |
405 | |
406 | const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg) { |
407 | if (alg != WINO) return cfg; |
408 | |
409 | std::stringstream ss; |
410 | ss << cfg << "_wino" ; |
411 | const std::string cpp_pstr = ss.str(); |
412 | const char *cfg_s = cpp_pstr.c_str(); |
413 | #define CASE(_cfg_) \ |
414 | if (!strcmp(cfg_s, STRINGIFY(_cfg_))) return CONCAT2(conf_, _cfg_) |
415 | CASE(f32_wino); |
416 | CASE(f16_wino); |
417 | #undef CASE |
418 | return cfg; |
419 | } |
420 | |
421 | } // namespace conv |
422 | |