1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <float.h>
18#include <math.h>
19#include <stdio.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "dnnl_common.hpp"
26
27#include "deconv/deconv.hpp"
28
// Largest finite magnitude of an IEEE 754 binary16 (f16) value; the pair
// spans the full finite f16 range symmetrically around zero.
#define HALF_MAX (65504)
#define HALF_MIN (-HALF_MAX)
31
32namespace deconv {
33
34/* cfgs definition
35 * arrays: SRC, WEI, BIA, DST, ACC
36 * params: {data_type, min, max, f_min, f_max, f_base, f_step, f_sparsity, eps}
37 */
38
39const int int_max_exact_half = 1 << 11;
40const _dt_conf_t conf_f16 = {
41 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
42 0.},
43 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
44 0.},
45 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0,
46 0.},
47 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
48 0.},
49 {dnnl_f32},
50};
51
52const int int_max_exact = 1 << 24;
53const _dt_conf_t conf_f16f16f32 = {
54 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
55 0.},
56 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
57 0.},
58 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
59 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
60 {dnnl_f32},
61};
62
63const _dt_conf_t conf_f32f16f16 = {
64 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
65 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
66 0.},
67 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0,
68 0.},
69 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
70 0.},
71 {dnnl_f32},
72};
73
74const _dt_conf_t conf_f16f32f16 = {
75 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
76 0.},
77 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
78 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0,
79 0.},
80 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
81 0.},
82 {dnnl_f32},
83};
84
85const _dt_conf_t conf_f32 = {
86 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
87 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.},
88 {dnnl_f32, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.},
89 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
90 {dnnl_f32},
91};
92
93const _dt_conf_t conf_f32_wino = {
94 {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 1e-5},
95 {dnnl_f32, -FLT_MAX, FLT_MAX, 2, 64, 2, 1, .75, 6e-6},
96 {dnnl_f32, -FLT_MAX, FLT_MAX, 1, 128, 1, 1, .25, 2e-7},
97 {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 2e-5},
98 {dnnl_f32},
99};
100
101const _dt_conf_t conf_f64 = {
102 {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
103 {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.},
104 {dnnl_f64, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.},
105 {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
106 {dnnl_f64},
107};
108
109const _dt_conf_t conf_f32_with_bf16_fpmath = {
110 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
111 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
112 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
113 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
114 {dnnl_f32},
115};
116
117const _dt_conf_t conf_f32_with_tf32_fpmath = {
118 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
119 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
120 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
121 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
122 {dnnl_f32},
123};
124
125const _dt_conf_t conf_f16_wino = {
126 {dnnl_f16, HALF_MIN, HALF_MAX, -2, 16, 0, 1, .25, 5e-3},
127 {dnnl_f16, HALF_MIN, HALF_MAX, 1, 6, -2, 1, .5, 6e-3},
128 {dnnl_f16, HALF_MIN, HALF_MAX, 1, 2048, 0, 1, .25, 2e-3},
129 {dnnl_f16, HALF_MIN, HALF_MAX, -2, 8, 0, 1, .25, 7e-3},
130 {dnnl_f32},
131};
132
133const _dt_conf_t conf_bf16bf16f32 = {
134 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
135 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
136 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
137 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
138 {dnnl_f32},
139};
140
141const _dt_conf_t conf_bf16bf16f16 = {
142 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
143 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
144 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
145 0.},
146 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
147 0.},
148 {dnnl_f16},
149};
150
151const _dt_conf_t conf_bf16bf16s8 = {
152 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
153 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
154 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
155 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
156 {dnnl_f32},
157};
158
159const _dt_conf_t conf_bf16bf16u8 = {
160 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
161 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
162 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
163 {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
164 {dnnl_f32},
165};
166
167const _dt_conf_t conf_bf16bf16bf16 = {
168 /* eps is 1e-2 because of loss in precision of
169 * output when converted from fp32 to bf16.
170 * oneDNN output is compared against reference computed in fp32.*/
171 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
172 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
173 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
174 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
175 {dnnl_f32},
176};
177
178const _dt_conf_t conf_f32bf16bf16 = {
179 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
180 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
181 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
182 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
183 {dnnl_f32},
184};
185
186const _dt_conf_t conf_f32f32s8 = {
187 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
188 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
189 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
190 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
191 {dnnl_f32},
192};
193
194const _dt_conf_t conf_f32f32u8 = {
195 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
196 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
197 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
198 {dnnl_s8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
199 {dnnl_f32},
200};
201
202const _dt_conf_t conf_f16f16s8 = {
203 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
204 0.},
205 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
206 0.},
207 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
208 0.},
209 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
210 {dnnl_f32},
211};
212
213const _dt_conf_t conf_f16f16u8 = {
214 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
215 0.},
216 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
217 0.},
218 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
219 0.},
220 {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
221 {dnnl_f32},
222};
223
224const _dt_conf_t conf_bf16f32bf16 = {
225 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
226 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
227 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
228 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
229 {dnnl_f32},
230};
231
232const _dt_conf_t conf_u8s8f32 = {
233 {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
234 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
235 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
236 {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
237 {dnnl_s32},
238};
239
240const _dt_conf_t conf_u8s8f16 = {
241 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
242 {dnnl_s8, INT8_MIN, INT8_MAX, -2, 2, -2, 1, 1.0, 0.},
243 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
244 0.},
245 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
246 0.},
247 {dnnl_s32},
248};
249
250const _dt_conf_t conf_u8s8bf16 = {
251 {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
252 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
253 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
254 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
255 {dnnl_s32},
256};
257
258const _dt_conf_t conf_u8s8s32 = {
259 {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
260 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
261 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
262 {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
263 {dnnl_s32},
264};
265
266const _dt_conf_t conf_u8s8s8 = {
267 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
268 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
269 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
270 {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
271 {dnnl_s32},
272};
273
274const _dt_conf_t conf_u8s8u8 = {
275 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
276 {dnnl_s8, INT8_MIN, INT8_MAX, -3, 5, 0, 1, .25, 0.},
277 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
278 {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
279 {dnnl_s32},
280};
281
282const _dt_conf_t conf_s8s8f32 = {
283 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
284 {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
285 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
286 {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
287 {dnnl_s32},
288};
289
290const _dt_conf_t conf_s8s8f16 = {
291 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
292 {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
293 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
294 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
295 0.},
296 {dnnl_s32},
297};
298
299const _dt_conf_t conf_s8s8bf16 = {
300 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
301 {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
302 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
303 {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 1e-2},
304 {dnnl_s32},
305};
306
307const _dt_conf_t conf_s8s8s32 = {
308 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
309 {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
310 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
311 {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
312 {dnnl_s32},
313};
314
315const _dt_conf_t conf_s8s8s8 = {
316 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
317 {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
318 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
319 {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
320 {dnnl_s32},
321};
322
323const _dt_conf_t conf_s8s8u8 = {
324 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
325 {dnnl_s8, INT8_MIN, INT8_MAX, -4, 7, 0, 4, .25, 0.},
326 {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
327 {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
328 {dnnl_s32},
329};
330
331const dt_conf_t *str2cfg(const char *str) {
332#define CASE(cfg) \
333 if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg)
334 CASE(f16);
335 CASE(f32);
336 CASE(f32_wino);
337 CASE(f64);
338 CASE(u8s8f32);
339 CASE(u8s8f16);
340 CASE(u8s8bf16);
341 CASE(u8s8s32);
342 CASE(u8s8s8);
343 CASE(u8s8u8);
344 CASE(s8s8f32);
345 CASE(s8s8f16);
346 CASE(s8s8bf16);
347 CASE(s8s8s32);
348 CASE(s8s8s8);
349 CASE(s8s8u8);
350 CASE(bf16bf16f32);
351 CASE(bf16bf16f16);
352 CASE(bf16bf16s8);
353 CASE(bf16bf16u8);
354 CASE(bf16bf16bf16);
355 CASE(f32bf16bf16);
356 CASE(bf16f32bf16);
357 CASE(f32f32s8);
358 CASE(f32f32u8);
359 CASE(f16f16f32);
360 CASE(f32f16f16);
361 CASE(f16f32f16);
362 CASE(f16f16s8);
363 CASE(f16f16u8);
364#undef CASE
365 []() {
366 SAFE(FAIL, CRIT);
367 return 0;
368 }();
369 return (const dt_conf_t *)1;
370}
371
372std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) {
373#define CASE(_cfg) \
374 if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg)
375 CASE(f16);
376 CASE(f32);
377 CASE(f32_wino);
378 CASE(f64);
379 CASE(u8s8f32);
380 CASE(u8s8f16);
381 CASE(u8s8bf16);
382 CASE(u8s8s32);
383 CASE(u8s8s8);
384 CASE(u8s8u8);
385 CASE(s8s8f32);
386 CASE(s8s8f16);
387 CASE(s8s8bf16);
388 CASE(s8s8s32);
389 CASE(s8s8s8);
390 CASE(s8s8u8);
391 CASE(f16f16f32);
392 CASE(f32f16f16);
393 CASE(f16f32f16);
394 CASE(f16f16s8);
395 CASE(f16f16u8);
396 CASE(bf16bf16f32);
397 CASE(bf16bf16f16);
398 CASE(bf16bf16s8);
399 CASE(bf16bf16u8);
400 CASE(bf16bf16bf16);
401 CASE(f32bf16bf16);
402 CASE(f32f32s8);
403 CASE(f32f32u8);
404 CASE(bf16f32bf16);
405#undef CASE
406 SAFE_V(FAIL);
407 return s;
408}
409
410const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg) {
411 if (alg != WINO) return cfg;
412
413 std::stringstream ss;
414 ss << cfg << "_wino";
415 const std::string cpp_pstr = ss.str();
416 const char *cfg_s = cpp_pstr.c_str();
417#define CASE(_cfg_) \
418 if (!strcmp(cfg_s, STRINGIFY(_cfg_))) return CONCAT2(conf_, _cfg_)
419 CASE(f32_wino);
420 CASE(f16_wino);
421#undef CASE
422 return cfg;
423}
424
425} // namespace deconv
426