1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <float.h>
18#include <math.h>
19#include <stdio.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "dnnl_common.hpp"
26
27#include "conv/conv.hpp"
28
// Largest finite magnitude of an IEEE 754 binary16 (f16) value; the pair is
// used as the representable-range bounds of conf_f16_wino.
#define HALF_MAX 65504
#define HALF_MIN (-HALF_MAX)
31
32namespace conv {
33
/* Data-type configuration tables ("cfgs").
 *
 * Each table holds one entry per tensor, in this order:
 *   SRC, WEI, BIA, DST, ACC
 * and each entry is initialized as
 *   {data_type, min, max, f_min, f_max, f_base, f_step, f_sparsity, eps}
 * Presumably min/max bound the values allowed for the data type while the
 * f_* fields drive data filling and eps is the comparison tolerance -- confirm
 * against the dt_conf_t definition in conv.hpp. Entries that list only a data
 * type (typically ACC) leave the remaining fields at their defaults.
 *
 * Table names spell the SRC, WEI and DST data types in that order, e.g.
 * conf_u8s8f32 has a u8 SRC, s8 WEI and f32 DST entry; conf_f16, conf_f32 and
 * conf_f64 use a single type for all data tensors.
 */

// 2^11: every integer with magnitude up to this value is exactly
// representable in f16 (11-bit significand including the implicit bit).
const int int_max_exact_half = 1 << 11;
// All-f16 data tensors; the ACC entry is f32.
const _dt_conf_t conf_f16 = {
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f32},
};

// 2^24: every integer with magnitude up to this value is exactly
// representable in f32 (24-bit significand). It is also reused as the bound
// for the f64 and bf16 entries further below.
const int int_max_exact = 1 << 24;
// f16 SRC/WEI with f32 BIA/DST; f16 entries keep the f16-exact bound.
const _dt_conf_t conf_f16f16f32 = {
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
                0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};

// All-f32 configuration (default direct-algorithm f32 table).
const _dt_conf_t conf_f32 = {
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
        {dnnl_f32},
};
70
// f32 table selected for the Winograd algorithm (see auto_cfg). Bounds are
// the full f32 range and, unlike the direct-algorithm tables, eps is
// non-zero for every data tensor.
const _dt_conf_t conf_f32_wino = {
        {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 1e-5},
        {dnnl_f32, -FLT_MAX, FLT_MAX, 2, 64, 2, 1, .75, 6e-6},
        {dnnl_f32, -FLT_MAX, FLT_MAX, 1, 128, 1, 1, .25, 2e-7},
        {dnnl_f32, -FLT_MAX, FLT_MAX, -16, 128, 3, 1, .25, 2e-5},
        {dnnl_f32},
};

// All-f64 configuration: same fill parameters as conf_f32, with f64 data
// tensors and an f64 ACC entry.
const _dt_conf_t conf_f64 = {
        {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
        {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, 1.0, 0.},
        {dnnl_f64, -int_max_exact, int_max_exact, -512, 512, 0, 1, 1.0, 0.},
        {dnnl_f64, -int_max_exact, int_max_exact, -32, 32, 0, 1, .25, 0.},
        {dnnl_f64},
};

// f32 tables for the bf16 / tf32 fpmath modes. Neither is reachable through
// str2cfg in this file; presumably the driver selects them directly when an
// fpmath mode is requested -- confirm. All data tensors share the same fill
// parameters (f_sparsity .75).
const _dt_conf_t conf_f32_with_bf16_fpmath = {
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};

const _dt_conf_t conf_f32_with_tf32_fpmath = {
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};

// f16 table selected for the Winograd algorithm (see auto_cfg). Bounds are
// HALF_MIN..HALF_MAX, eps is non-zero, and the ACC entry is f32.
const _dt_conf_t conf_f16_wino = {
        {dnnl_f16, HALF_MIN, HALF_MAX, -2, 16, 0, 1, .25, 5e-3},
        {dnnl_f16, HALF_MIN, HALF_MAX, 1, 6, -2, 1, .5, 6e-3},
        {dnnl_f16, HALF_MIN, HALF_MAX, 1, 2048, 0, 1, .25, 2e-3},
        {dnnl_f16, HALF_MIN, HALF_MAX, -2, 8, 0, 1, .25, 7e-3},
        {dnnl_f32},
};
110
// bf16-source tables. NOTE(review): bf16 entries reuse int_max_exact (2^24)
// as their bound although bf16 has only an 8-bit significand, so integers
// that large are not exactly representable -- presumably intended as a loose
// bound only; confirm.

// bf16 SRC/WEI with f32 BIA/DST.
const _dt_conf_t conf_bf16bf16f32 = {
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};

// bf16 SRC/WEI with f16 BIA/DST (f16 entries use the f16-exact bound).
// NOTE(review): the ACC entry is dnnl_f16 here, while every other
// floating-point table in this file (including the all-f16 conf_f16) uses
// dnnl_f32 -- confirm this is intentional.
const _dt_conf_t conf_bf16bf16f16 = {
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16},
};

// bf16 SRC/WEI, f32 BIA, s8 DST.
const _dt_conf_t conf_bf16bf16s8 = {
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32},
};

// bf16 SRC/WEI, f32 BIA, u8 DST.
const _dt_conf_t conf_bf16bf16u8 = {
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
        {dnnl_f32},
};

// All-bf16 data tensors; the ACC entry is f32.
const _dt_conf_t conf_bf16bf16bf16 = {
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};

// f32 SRC with bf16 WEI/BIA/DST.
const _dt_conf_t conf_f32bf16bf16 = {
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};

// f32 SRC/WEI/BIA with an s8 DST.
const _dt_conf_t conf_f32f32s8 = {
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32},
};
168
169const _dt_conf_t conf_f32f32u8 = {
170 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
171 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
172 {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
173 {dnnl_s8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
174 {dnnl_f32},
175};
176
// f32 SRC with f16 WEI/BIA/DST; f16 entries use the f16-exact bound.
const _dt_conf_t conf_f32f16f16 = {
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -6, 6, 0, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f32},
};

// f16 SRC, f32 WEI/BIA, f16 DST.
const _dt_conf_t conf_f16f32f16 = {
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f32},
};

// f16 SRC/WEI/BIA with an s8 DST.
const _dt_conf_t conf_f16f16s8 = {
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
                0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32},
};

// f16 SRC/WEI/BIA with a u8 DST.
const _dt_conf_t conf_f16f16u8 = {
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, -2, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
                0.},
        {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
        {dnnl_f32},
};

// bf16 SRC, f32 WEI/BIA, bf16 DST.
const _dt_conf_t conf_bf16f32bf16 = {
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_f32},
};
227
// Integer (u8 SRC, s8 WEI) tables. All of them have an s32 ACC entry and,
// except conf_u8s8f16, an f32 BIA entry bounded by the s32 range.

// u8 SRC, s8 WEI, f32 DST.
const _dt_conf_t conf_u8s8f32 = {
        {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// u8 SRC, s8 WEI, f16 BIA/DST. SRC and WEI use much narrower fill ranges
// (0..8 and -2..2) than conf_u8s8f32.
const _dt_conf_t conf_u8s8f16 = {
        {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -2, 2, -2, 1, 1.0, 0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, 1, 1.0,
                0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_s32},
};

// u8 SRC, s8 WEI, bf16 DST.
const _dt_conf_t conf_u8s8bf16 = {
        {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_s32},
};

// u8 SRC, s8 WEI, s32 DST.
const _dt_conf_t conf_u8s8s32 = {
        {dnnl_u8, 0, UINT8_MAX, 0, UINT8_MAX, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// u8 SRC (filled only in 0..8), s8 WEI, s8 DST.
const _dt_conf_t conf_u8s8s8 = {
        {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
        {dnnl_s32},
};

// u8 SRC (filled only in 0..8), s8 WEI filled in -3..5, u8 DST.
const _dt_conf_t conf_u8s8u8 = {
        {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -3, 5, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};
277
// Integer (s8 SRC, s8 WEI) tables. All of them have an s32 ACC entry and an
// f32 BIA entry bounded by the s32 range; WEI entries use f_step 4.

// s8 SRC, s8 WEI, f32 DST.
const _dt_conf_t conf_s8s8f32 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// s8 SRC, s8 WEI, f16 DST (f16-exact bound).
const _dt_conf_t conf_s8s8f16 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, 1, .25,
                0.},
        {dnnl_s32},
};

// s8 SRC, s8 WEI, bf16 DST.
const _dt_conf_t conf_s8s8bf16 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_bf16, -int_max_exact, int_max_exact, -32, 32, 0, 1, .75, 0.},
        {dnnl_s32},
};

// s8 SRC, s8 WEI, s32 DST.
const _dt_conf_t conf_s8s8s32 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};

// s8 SRC, s8 WEI, s8 DST.
const _dt_conf_t conf_s8s8s8 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -8, 3, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, 1, .25, 0.},
        {dnnl_s32},
};

// s8 SRC, s8 WEI (filled in -4..7 rather than -8..3), u8 DST.
const _dt_conf_t conf_s8s8u8 = {
        {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, 1, .25, 0.},
        {dnnl_s8, INT8_MIN, INT8_MAX, -4, 7, 0, 4, .25, 0.},
        {dnnl_f32, INT32_MIN, INT32_MAX, -8, 32, 0, 1, .25, 0.},
        {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, 1, .25, 0.},
        {dnnl_s32},
};
326
327const dt_conf_t *str2cfg(const char *str) {
328#define CASE(cfg) \
329 if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg)
330 CASE(f16);
331 CASE(f32);
332 CASE(f32_wino);
333 CASE(f64);
334 CASE(u8s8f32);
335 CASE(u8s8f16);
336 CASE(u8s8bf16);
337 CASE(u8s8s32);
338 CASE(u8s8s8);
339 CASE(u8s8u8);
340 CASE(s8s8f32);
341 CASE(s8s8f16);
342 CASE(s8s8bf16);
343 CASE(s8s8s32);
344 CASE(s8s8s8);
345 CASE(s8s8u8);
346 CASE(bf16bf16f32);
347 CASE(bf16bf16f16);
348 CASE(bf16bf16s8);
349 CASE(bf16bf16u8);
350 CASE(bf16bf16bf16);
351 CASE(f32bf16bf16);
352 CASE(bf16f32bf16);
353 CASE(f32f32s8);
354 CASE(f32f32u8);
355 CASE(f16f16f32);
356 CASE(f16f16s8);
357 CASE(f16f16u8);
358 CASE(f32f16f16);
359 CASE(f16f32f16);
360#undef CASE
361 []() {
362 SAFE(FAIL, CRIT);
363 return 0;
364 }();
365 return (const dt_conf_t *)1;
366}
367
368std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) {
369#define CASE(_cfg) \
370 if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg)
371 CASE(f16);
372 CASE(f32);
373 CASE(f32_wino);
374 CASE(f64);
375 CASE(u8s8f32);
376 CASE(u8s8f16);
377 CASE(u8s8bf16);
378 CASE(u8s8s32);
379 CASE(u8s8s8);
380 CASE(u8s8u8);
381 CASE(s8s8f32);
382 CASE(s8s8f16);
383 CASE(s8s8bf16);
384 CASE(s8s8s32);
385 CASE(s8s8s8);
386 CASE(s8s8u8);
387 CASE(f16f16f32);
388 CASE(f16f16s8);
389 CASE(f16f16u8);
390 CASE(bf16bf16f32);
391 CASE(bf16bf16f16);
392 CASE(bf16bf16s8);
393 CASE(bf16bf16u8);
394 CASE(bf16bf16bf16);
395 CASE(f32bf16bf16);
396 CASE(f32f32s8);
397 CASE(f32f32u8);
398 CASE(bf16f32bf16);
399 CASE(f32f16f16);
400 CASE(f16f32f16);
401#undef CASE
402 SAFE_V(FAIL);
403 return s;
404}
405
406const dt_conf_t *auto_cfg(const alg_t alg, const dt_conf_t *cfg) {
407 if (alg != WINO) return cfg;
408
409 std::stringstream ss;
410 ss << cfg << "_wino";
411 const std::string cpp_pstr = ss.str();
412 const char *cfg_s = cpp_pstr.c_str();
413#define CASE(_cfg_) \
414 if (!strcmp(cfg_s, STRINGIFY(_cfg_))) return CONCAT2(conf_, _cfg_)
415 CASE(f32_wino);
416 CASE(f16_wino);
417#undef CASE
418 return cfg;
419}
420
421} // namespace conv
422