1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <float.h>
18#include <stdio.h>
19#include <stdlib.h>
20#include <string.h>
21
22#include "dnnl_common.hpp"
23#include "oneapi/dnnl/dnnl.h"
24
25#include "ip.hpp"
26
27namespace ip {
28
29/* cfgs definition
30 * arrays: SRC, WEI, BIA, DST, ACC
31 * params: {data_type, min, max, f_min, f_max, f_base, f_sparsity, f_scale, eps}
32 */
33
34const int int_max_exact = 1 << 24;
35const _dt_conf_t conf_f32 = {
36 {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128,
37 1e-6},
38 {dnnl_f32, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256,
39 1e-6},
40 {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64,
41 1e-6},
42 {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64,
43 1e-6},
44 {dnnl_f32},
45};
46
47const _dt_conf_t conf_bf16bf16f32 = {
48 {dnnl_bf16, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128,
49 0},
50 {dnnl_bf16, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256,
51 0},
52 {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, 0},
53 {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64,
54 1e-6},
55 {dnnl_f32},
56};
57
58const _dt_conf_t conf_bf16bf16bf16 = {
59 /* eps is 1e-2 because of loss in precision of
60 * output when converted from fp32 to bf16.
61 * oneDNN output is compared against reference computed in fp32.*/
62 {dnnl_bf16, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128,
63 1e-2},
64 {dnnl_bf16, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256,
65 1e-2},
66 {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64,
67 1e-2},
68 {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64,
69 1e-2},
70 {dnnl_f32},
71};
72
73const _dt_conf_t conf_f32bf16bf16 = {
74 {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128,
75 1e-6},
76 {dnnl_bf16, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256,
77 0},
78 {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, 0},
79 {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64, 0},
80 {dnnl_f32},
81};
82
83const _dt_conf_t conf_bf16f32bf16 = {
84 {dnnl_bf16, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128,
85 0},
86 {dnnl_f32, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256,
87 1e-6},
88 {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64,
89 1e-6},
90 {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64, 0},
91 {dnnl_f32},
92};
93
94const int int_max_exact_half = 1 << 11;
95const _dt_conf_t conf_f16 = {
96 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
97 1e-3},
98 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1,
99 1e-3},
100 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, .35, 1,
101 1e-3},
102 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
103 1e-3},
104 {dnnl_f32},
105};
106
107const _dt_conf_t conf_f16f16f32 = {
108 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
109 1e-3},
110 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1,
111 1e-3},
112 {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, 0},
113 {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64,
114 1e-6},
115 {dnnl_f32},
116};
117
118const _dt_conf_t conf_f32f16f16 = {
119 {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128,
120 1e-6},
121 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1,
122 1e-3},
123 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, .35, 1,
124 1e-3},
125 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
126 1e-3},
127 {dnnl_f32},
128};
129
130const _dt_conf_t conf_f16f32f16 = {
131 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
132 1e-3},
133 {dnnl_f32, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256,
134 1e-6},
135 {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64,
136 1e-6},
137 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
138 1e-3},
139 {dnnl_f32},
140};
141
142const _dt_conf_t conf_f16f16s8 = {
143 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
144 1e-3},
145 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1,
146 1e-3},
147 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, .35, 1,
148 1e-3},
149 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
150 {dnnl_f32},
151};
152
153const _dt_conf_t conf_f16f16u8 = {
154 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
155 1e-3},
156 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1,
157 1e-3},
158 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, .35, 1,
159 1e-3},
160 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.},
161 {dnnl_f32},
162};
163
164const _dt_conf_t conf_u8s8f32 = {
165 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.},
166 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
167 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
168 {dnnl_f32, -int_max_exact, int_max_exact, -255, 255, 0, .35, 1, 0.},
169 {dnnl_s32},
170};
171
172const _dt_conf_t conf_u8s8s32 = {
173 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.},
174 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
175 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
176 {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, .35, 1, 0.},
177 {dnnl_s32},
178};
179
180const _dt_conf_t conf_u8s8s8 = {
181 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.},
182 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
183 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
184 {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, .35, 1, 0.},
185 {dnnl_s32},
186};
187
188const _dt_conf_t conf_u8s8u8 = {
189 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.},
190 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
191 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
192 {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, .35, 1, 0.},
193 {dnnl_s32},
194};
195
196const _dt_conf_t conf_s8s8f32 = {
197 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
198 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
199 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
200 {dnnl_f32, -int_max_exact, int_max_exact, -255, 255, 0, .35, 1, 0.},
201 {dnnl_s32},
202};
203
204const _dt_conf_t conf_s8s8s32 = {
205 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
206 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
207 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
208 {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, .35, 1, 0.},
209 {dnnl_s32},
210};
211
212const _dt_conf_t conf_s8s8s8 = {
213 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
214 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
215 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
216 {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, .35, 1, 0.},
217 {dnnl_s32},
218};
219
220const _dt_conf_t conf_s8s8u8 = {
221 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
222 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
223 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
224 {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, .35, 1, 0.},
225 {dnnl_s32},
226};
227
228const _dt_conf_t conf_u8s8bf16 = {
229 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.},
230 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
231 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
232 {dnnl_bf16, -int_max_exact, int_max_exact, -255, 255, 0, .35, 1, 0.},
233 {dnnl_s32},
234};
235
236const _dt_conf_t conf_s8s8bf16 = {
237 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
238 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
239 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
240 {dnnl_bf16, -int_max_exact, int_max_exact, -255, 255, 0, .35, 1, 0.},
241 {dnnl_s32},
242};
243
244const _dt_conf_t conf_u8s8f16 = {
245 {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.},
246 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
247 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
248 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
249 1e-3},
250 {dnnl_s32},
251};
252
253const _dt_conf_t conf_s8s8f16 = {
254 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
255 {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.},
256 {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.},
257 {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1,
258 1e-3},
259 {dnnl_s32},
260};
261
262const dt_conf_t *str2cfg(const char *str) {
263#define CASE(cfg) \
264 if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg)
265 CASE(f32);
266 CASE(f16);
267 CASE(f16f16f32);
268 CASE(f32f16f16);
269 CASE(f16f32f16);
270 CASE(f16f16s8);
271 CASE(f16f16u8);
272 CASE(u8s8f32);
273 CASE(u8s8s32);
274 CASE(u8s8s8);
275 CASE(u8s8u8);
276 CASE(s8s8f32);
277 CASE(s8s8s32);
278 CASE(s8s8s8);
279 CASE(s8s8u8);
280 CASE(s8s8bf16);
281 CASE(u8s8bf16);
282 CASE(s8s8f16);
283 CASE(u8s8f16);
284 CASE(bf16bf16f32);
285 CASE(bf16bf16bf16);
286 CASE(f32bf16bf16);
287 CASE(bf16f32bf16);
288#undef CASE
289 []() {
290 SAFE(FAIL, CRIT);
291 return 0;
292 }();
293 return (const dt_conf_t *)1;
294}
295
296std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) {
297#define CASE(_cfg) \
298 if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg)
299 CASE(f32);
300 CASE(f16);
301 CASE(f16f16f32);
302 CASE(f32f16f16);
303 CASE(f16f32f16);
304 CASE(f16f16s8);
305 CASE(f16f16u8);
306 CASE(u8s8f32);
307 CASE(u8s8s32);
308 CASE(u8s8s8);
309 CASE(u8s8u8);
310 CASE(s8s8f32);
311 CASE(s8s8s32);
312 CASE(s8s8s8);
313 CASE(s8s8u8);
314 CASE(s8s8bf16);
315 CASE(u8s8bf16);
316 CASE(s8s8f16);
317 CASE(u8s8f16);
318 CASE(bf16bf16f32);
319 CASE(bf16bf16bf16);
320 CASE(f32bf16bf16);
321 CASE(bf16f32bf16);
322#undef CASE
323 SAFE_V(FAIL);
324 return s;
325}
326
327} // namespace ip
328