1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <stdio.h> |
19 | #include <stdlib.h> |
20 | #include <string.h> |
21 | |
22 | #include "dnnl_common.hpp" |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "ip.hpp" |
26 | |
27 | namespace ip { |
28 | |
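/* Configuration tables.
 * Each table holds one entry per array, in the order:
 *     SRC, WEI, BIA, DST, ACC
 * Each entry lists its parameters in the order:
 *     {data_type, min, max, f_min, f_max, f_base, f_sparsity, f_scale, eps}
 * The ACC entries below specify only data_type.
 */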
33 | |
34 | const int int_max_exact = 1 << 24; |
35 | const _dt_conf_t conf_f32 = { |
36 | {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128, |
37 | 1e-6}, |
38 | {dnnl_f32, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256, |
39 | 1e-6}, |
40 | {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, |
41 | 1e-6}, |
42 | {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64, |
43 | 1e-6}, |
44 | {dnnl_f32}, |
45 | }; |
46 | |
47 | const _dt_conf_t conf_bf16bf16f32 = { |
48 | {dnnl_bf16, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128, |
49 | 0}, |
50 | {dnnl_bf16, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256, |
51 | 0}, |
52 | {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, 0}, |
53 | {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64, |
54 | 1e-6}, |
55 | {dnnl_f32}, |
56 | }; |
57 | |
58 | const _dt_conf_t conf_bf16bf16bf16 = { |
59 | /* eps is 1e-2 because of loss in precision of |
60 | * output when converted from fp32 to bf16. |
61 | * oneDNN output is compared against reference computed in fp32.*/ |
62 | {dnnl_bf16, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128, |
63 | 1e-2}, |
64 | {dnnl_bf16, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256, |
65 | 1e-2}, |
66 | {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, |
67 | 1e-2}, |
68 | {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64, |
69 | 1e-2}, |
70 | {dnnl_f32}, |
71 | }; |
72 | |
73 | const _dt_conf_t conf_f32bf16bf16 = { |
74 | {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128, |
75 | 1e-6}, |
76 | {dnnl_bf16, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256, |
77 | 0}, |
78 | {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, 0}, |
79 | {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64, 0}, |
80 | {dnnl_f32}, |
81 | }; |
82 | |
83 | const _dt_conf_t conf_bf16f32bf16 = { |
84 | {dnnl_bf16, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128, |
85 | 0}, |
86 | {dnnl_f32, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256, |
87 | 1e-6}, |
88 | {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, |
89 | 1e-6}, |
90 | {dnnl_bf16, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64, 0}, |
91 | {dnnl_f32}, |
92 | }; |
93 | |
94 | const int int_max_exact_half = 1 << 11; |
95 | const _dt_conf_t conf_f16 = { |
96 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
97 | 1e-3}, |
98 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1, |
99 | 1e-3}, |
100 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, .35, 1, |
101 | 1e-3}, |
102 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
103 | 1e-3}, |
104 | {dnnl_f32}, |
105 | }; |
106 | |
107 | const _dt_conf_t conf_f16f16f32 = { |
108 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
109 | 1e-3}, |
110 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1, |
111 | 1e-3}, |
112 | {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, 0}, |
113 | {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, .35, 1. / 64, |
114 | 1e-6}, |
115 | {dnnl_f32}, |
116 | }; |
117 | |
118 | const _dt_conf_t conf_f32f16f16 = { |
119 | {dnnl_f32, -int_max_exact, int_max_exact, -64, 64, 0, .35, 1. / 128, |
120 | 1e-6}, |
121 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1, |
122 | 1e-3}, |
123 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, .35, 1, |
124 | 1e-3}, |
125 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
126 | 1e-3}, |
127 | {dnnl_f32}, |
128 | }; |
129 | |
130 | const _dt_conf_t conf_f16f32f16 = { |
131 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
132 | 1e-3}, |
133 | {dnnl_f32, -int_max_exact, int_max_exact, -128, 128, 0, 1.0, 1. / 256, |
134 | 1e-6}, |
135 | {dnnl_f32, -int_max_exact, int_max_exact, -10, 10, 0, 1.0, 1. / 64, |
136 | 1e-6}, |
137 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
138 | 1e-3}, |
139 | {dnnl_f32}, |
140 | }; |
141 | |
142 | const _dt_conf_t conf_f16f16s8 = { |
143 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
144 | 1e-3}, |
145 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1, |
146 | 1e-3}, |
147 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, .35, 1, |
148 | 1e-3}, |
149 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
150 | {dnnl_f32}, |
151 | }; |
152 | |
153 | const _dt_conf_t conf_f16f16u8 = { |
154 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
155 | 1e-3}, |
156 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -2, 2, 0, .35, 1, |
157 | 1e-3}, |
158 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -8, 8, 0, .35, 1, |
159 | 1e-3}, |
160 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.}, |
161 | {dnnl_f32}, |
162 | }; |
163 | |
164 | const _dt_conf_t conf_u8s8f32 = { |
165 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.}, |
166 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
167 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
168 | {dnnl_f32, -int_max_exact, int_max_exact, -255, 255, 0, .35, 1, 0.}, |
169 | {dnnl_s32}, |
170 | }; |
171 | |
172 | const _dt_conf_t conf_u8s8s32 = { |
173 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.}, |
174 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
175 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
176 | {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, .35, 1, 0.}, |
177 | {dnnl_s32}, |
178 | }; |
179 | |
180 | const _dt_conf_t conf_u8s8s8 = { |
181 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.}, |
182 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
183 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
184 | {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, .35, 1, 0.}, |
185 | {dnnl_s32}, |
186 | }; |
187 | |
188 | const _dt_conf_t conf_u8s8u8 = { |
189 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.}, |
190 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
191 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
192 | {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, .35, 1, 0.}, |
193 | {dnnl_s32}, |
194 | }; |
195 | |
196 | const _dt_conf_t conf_s8s8f32 = { |
197 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
198 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
199 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
200 | {dnnl_f32, -int_max_exact, int_max_exact, -255, 255, 0, .35, 1, 0.}, |
201 | {dnnl_s32}, |
202 | }; |
203 | |
204 | const _dt_conf_t conf_s8s8s32 = { |
205 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
206 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
207 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
208 | {dnnl_s32, INT32_MIN, INT32_MAX, -255, 255, 0, .35, 1, 0.}, |
209 | {dnnl_s32}, |
210 | }; |
211 | |
212 | const _dt_conf_t conf_s8s8s8 = { |
213 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
214 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
215 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
216 | {dnnl_s8, INT8_MIN, INT8_MAX, -127, 127, 0, .35, 1, 0.}, |
217 | {dnnl_s32}, |
218 | }; |
219 | |
220 | const _dt_conf_t conf_s8s8u8 = { |
221 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
222 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
223 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
224 | {dnnl_u8, 0, UINT8_MAX, 0, 255, 0, .35, 1, 0.}, |
225 | {dnnl_s32}, |
226 | }; |
227 | |
228 | const _dt_conf_t conf_u8s8bf16 = { |
229 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.}, |
230 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
231 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
232 | {dnnl_bf16, -int_max_exact, int_max_exact, -255, 255, 0, .35, 1, 0.}, |
233 | {dnnl_s32}, |
234 | }; |
235 | |
236 | const _dt_conf_t conf_s8s8bf16 = { |
237 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
238 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
239 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
240 | {dnnl_bf16, -int_max_exact, int_max_exact, -255, 255, 0, .35, 1, 0.}, |
241 | {dnnl_s32}, |
242 | }; |
243 | |
244 | const _dt_conf_t conf_u8s8f16 = { |
245 | {dnnl_u8, 0, UINT8_MAX, 0, 8, 0, .35, 1, 0.}, |
246 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
247 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
248 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
249 | 1e-3}, |
250 | {dnnl_s32}, |
251 | }; |
252 | |
253 | const _dt_conf_t conf_s8s8f16 = { |
254 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
255 | {dnnl_s8, INT8_MIN, INT8_MAX, -5, 5, 0, .35, 1, 0.}, |
256 | {dnnl_f32, -int_max_exact, int_max_exact, -8, 32, 0, .35, 1, 0.}, |
257 | {dnnl_f16, -int_max_exact_half, int_max_exact_half, -4, 4, 0, .35, 1, |
258 | 1e-3}, |
259 | {dnnl_s32}, |
260 | }; |
261 | |
262 | const dt_conf_t *str2cfg(const char *str) { |
263 | #define CASE(cfg) \ |
264 | if (!strcasecmp(STRINGIFY(cfg), str)) return CONCAT2(conf_, cfg) |
265 | CASE(f32); |
266 | CASE(f16); |
267 | CASE(f16f16f32); |
268 | CASE(f32f16f16); |
269 | CASE(f16f32f16); |
270 | CASE(f16f16s8); |
271 | CASE(f16f16u8); |
272 | CASE(u8s8f32); |
273 | CASE(u8s8s32); |
274 | CASE(u8s8s8); |
275 | CASE(u8s8u8); |
276 | CASE(s8s8f32); |
277 | CASE(s8s8s32); |
278 | CASE(s8s8s8); |
279 | CASE(s8s8u8); |
280 | CASE(s8s8bf16); |
281 | CASE(u8s8bf16); |
282 | CASE(s8s8f16); |
283 | CASE(u8s8f16); |
284 | CASE(bf16bf16f32); |
285 | CASE(bf16bf16bf16); |
286 | CASE(f32bf16bf16); |
287 | CASE(bf16f32bf16); |
288 | #undef CASE |
289 | []() { |
290 | SAFE(FAIL, CRIT); |
291 | return 0; |
292 | }(); |
293 | return (const dt_conf_t *)1; |
294 | } |
295 | |
296 | std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg) { |
297 | #define CASE(_cfg) \ |
298 | if (cfg == CONCAT2(conf_, _cfg)) return s << STRINGIFY(_cfg) |
299 | CASE(f32); |
300 | CASE(f16); |
301 | CASE(f16f16f32); |
302 | CASE(f32f16f16); |
303 | CASE(f16f32f16); |
304 | CASE(f16f16s8); |
305 | CASE(f16f16u8); |
306 | CASE(u8s8f32); |
307 | CASE(u8s8s32); |
308 | CASE(u8s8s8); |
309 | CASE(u8s8u8); |
310 | CASE(s8s8f32); |
311 | CASE(s8s8s32); |
312 | CASE(s8s8s8); |
313 | CASE(s8s8u8); |
314 | CASE(s8s8bf16); |
315 | CASE(u8s8bf16); |
316 | CASE(s8s8f16); |
317 | CASE(u8s8f16); |
318 | CASE(bf16bf16f32); |
319 | CASE(bf16bf16bf16); |
320 | CASE(f32bf16bf16); |
321 | CASE(bf16f32bf16); |
322 | #undef CASE |
323 | SAFE_V(FAIL); |
324 | return s; |
325 | } |
326 | |
327 | } // namespace ip |
328 | |