1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <set> |
18 | |
19 | #include "dnnl_common.hpp" |
20 | #include "rnn/rnn.hpp" |
21 | |
22 | namespace rnn { |
23 | |
24 | namespace { |
25 | |
// Helpers for writing the body of a configuration's operator[]. They rely on
// a variable named `kind` being in scope at the point of use.
//
// CASE: return ENTRY when `kind` equals KIND, otherwise fall through to the
// next statement.
#define CASE(KIND, ENTRY) \
    if (kind == (KIND)) return ENTRY
// DEFAULT: unconditionally return ENTRY (used as the last line of a list).
#define DEFAULT(ENTRY) return ENTRY
// END_LIST: closes a list that has no default entry. Invokes SAFE_V(CRIT) and
// then returns F32_ENTRY.
// NOTE(review): SAFE_V is defined outside this file; presumably it aborts on
// CRIT, which would make the trailing return a formality -- confirm.
#define END_LIST \
    SAFE_V(CRIT); \
    return F32_ENTRY
// CFG_INTERNAL(name, alias): declares struct conf_<name>_t derived from
// dt_conf_t (inheriting its constructors and overriding operator[]), creates
// the object conf_<name> constructed from the stringified `alias`, and opens
// the out-of-line definition of operator[] -- the macro invocation must be
// followed by the function body. Unlike CFG, the object is NOT inserted into
// cfg_list.
#define CFG_INTERNAL(name, alias) \
    struct conf_##name##_t : dt_conf_t { \
        using dt_conf_t::dt_conf_t; \
        const entry_t &operator[](rnn_data_kind_t kind) const override; \
    } conf_##name(STRINGIFY(alias)); \
    const dt_conf_t::entry_t &conf_##name##_t::operator[]( \
            rnn_data_kind_t kind) const
40 | |
// Registry of the configurations defined through CFG; searched by
// dt_conf_t::create(). It is defined before any CFG use so that it exists
// when the registration variables below are initialized.
std::set<const dt_conf_t *> cfg_list;
// CFG(name): declares struct conf_<name>_t derived from dt_conf_t (inheriting
// its constructors and overriding operator[]), creates the object conf_<name>
// constructed from the stringified `name`, inserts its address into cfg_list
// through the initializer of the static variable __reg_<name>, and opens the
// out-of-line definition of operator[] -- the macro invocation must be
// followed by the function body.
#define CFG(name) \
    struct conf_##name##_t : dt_conf_t { \
        using dt_conf_t::dt_conf_t; \
        const entry_t &operator[](rnn_data_kind_t kind) const override; \
    } conf_##name(STRINGIFY(name)); \
    static auto __reg_##name = cfg_list.insert(&conf_##name); \
    const dt_conf_t::entry_t &conf_##name##_t::operator[]( \
            rnn_data_kind_t kind) const
50 | |
// f32
// Value parameters shared by every f32 entry in this file.
#define MIN_F32 0.0f
#define MAX_F32 .999999f
#define MEAN_F32 .5f
#define STDDEV_F32 0.01f
#define EPS_F32 epsilon_dt(dnnl_f32)
// 2^24: integers up to this magnitude are exactly representable in f32.
const int f32_max_exact = 1 << 24;
// Positional initializer: data type, the +/- integer bounds, then the
// MIN/MAX/MEAN/STDDEV/EPS values above.
// NOTE(review): the field names and meanings come from dt_conf_t::entry_t,
// which is declared outside this file -- confirm against rnn/rnn.hpp.
dt_conf_t::entry_t F32_ENTRY {dnnl_f32, -f32_max_exact, f32_max_exact, MIN_F32,
        MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
60 | |
// Passes the static registration variable __reg_<name> created by CFG(name) to
// UNUSED. NOTE(review): UNUSED is defined elsewhere; presumably it silences
// unused-variable diagnostics -- confirm.
#define UNUSED_REG_VAR(name) UNUSED(__reg_##name)
62 | |
63 | CFG(f32) { |
64 | UNUSED_REG_VAR(f32); |
65 | return F32_ENTRY; |
66 | } |
67 | |
// bf16
// Value parameters for bf16 entries; MIN/MAX/MEAN/STDDEV carry the same
// literals as their f32 counterparts, only the epsilon differs.
#define MIN_BF16 0.0f
#define MAX_BF16 .999999f
#define MEAN_BF16 .5f
#define STDDEV_BF16 0.01f
#define EPS_BF16 epsilon_dt(dnnl_bf16)
// Entry for tensors stored as bf16.
dt_conf_t::entry_t BF16_ENTRY_BF16 {dnnl_bf16, -f32_max_exact, f32_max_exact,
        MIN_BF16, MAX_BF16, MEAN_BF16, STDDEV_BF16, EPS_BF16};
// Entry for tensors stored as f32 inside a bf16 configuration: identical to
// F32_ENTRY except that it uses the bf16 epsilon.
dt_conf_t::entry_t BF16_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
        MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_BF16};
78 | |
79 | CFG(bf16f32) { |
80 | UNUSED_REG_VAR(bf16f32); |
81 | CASE(SRC_LAYER, BF16_ENTRY_BF16); |
82 | CASE(SRC_ITER, BF16_ENTRY_BF16); |
83 | CASE(WEIGHTS_LAYER, BF16_ENTRY_BF16); |
84 | CASE(WEIGHTS_ITER, BF16_ENTRY_BF16); |
85 | CASE(DST_ITER, BF16_ENTRY_BF16); |
86 | CASE(DST_LAYER, BF16_ENTRY_BF16); |
87 | CASE(AUGRU_ATTENTION, BF16_ENTRY_BF16); |
88 | DEFAULT(BF16_ENTRY_F32); |
89 | } |
90 | |
91 | CFG(bf16) { |
92 | UNUSED_REG_VAR(bf16); |
93 | CASE(SRC_LAYER, BF16_ENTRY_BF16); |
94 | CASE(SRC_ITER, BF16_ENTRY_BF16); |
95 | CASE(SRC_ITER_C, BF16_ENTRY_BF16); |
96 | CASE(WEIGHTS_LAYER, BF16_ENTRY_BF16); |
97 | CASE(WEIGHTS_ITER, BF16_ENTRY_BF16); |
98 | CASE(WEIGHTS_PEEPHOLE, BF16_ENTRY_F32); |
99 | CASE(WEIGHTS_PROJECTION, BF16_ENTRY_F32); |
100 | CASE(BIAS, BF16_ENTRY_BF16); |
101 | CASE(DST_ITER, BF16_ENTRY_BF16); |
102 | CASE(DST_ITER_C, BF16_ENTRY_BF16); |
103 | CASE(DST_LAYER, BF16_ENTRY_BF16); |
104 | CASE(AUGRU_ATTENTION, BF16_ENTRY_BF16); |
105 | DEFAULT(BF16_ENTRY_F32); |
106 | } |
107 | |
108 | // bf32 |
109 | dt_conf_t::entry_t BF32_ENTRY {dnnl_f32, -f32_max_exact, f32_max_exact, |
110 | MIN_BF16, MAX_BF16, MEAN_BF16, STDDEV_BF16, EPS_BF16}; |
111 | CFG_INTERNAL(bf32, f32) { |
112 | CASE(BIAS, F32_ENTRY); |
113 | DEFAULT(BF32_ENTRY); |
114 | } |
115 | |
116 | // f16 |
117 | const int f16_max_exact = 1 << 11; |
118 | dt_conf_t::entry_t F16_ENTRY {dnnl_f16, -f16_max_exact, f16_max_exact, 0.0f, |
119 | 0.999999f, 0.5f, 0.01f, epsilon_dt(dnnl_f16)}; |
120 | |
121 | CFG(f16) { |
122 | UNUSED_REG_VAR(f16); |
123 | return F16_ENTRY; |
124 | } |
125 | |
// int8 (u8 and s8)
// Epsilons used by the non-exact int8 entries below.
#define EPS_U8 4e-3
#define EPS_S8 8e-3

// Value parameters for u8 entries.
#define MIN_U8 0.0f
#define MAX_U8 127.f
#define MEAN_U8 28.f
#define STDDEV_U8 16.f

// Value parameters for s8 entries; MEAN_WEIGHT_S8 is the mean used by the
// s8 weight entries instead of MEAN_S8.
#define MIN_S8 (-64.f)
#define MAX_S8 64.f
#define MEAN_S8 8.f
#define STDDEV_S8 32.f
#define MEAN_WEIGHT_S8 0.f

// Entries used by the configurations whose source layer is u8.
// The *_EXACT variant differs from the regular one only by a zero epsilon.
dt_conf_t::entry_t U8_ENTRY_U8_EXACT {
        dnnl_u8, 0, UINT8_MAX, MIN_U8, MAX_U8, MEAN_U8, STDDEV_U8, 0.f};
dt_conf_t::entry_t U8_ENTRY_U8 {
        dnnl_u8, 0, UINT8_MAX, MIN_U8, MAX_U8, MEAN_U8, STDDEV_U8, EPS_U8};
// s8 entry (used for weights in the u8 configurations); zero mean.
dt_conf_t::entry_t U8_ENTRY_S8 {dnnl_s8, INT8_MIN, INT8_MAX, MIN_S8, MAX_S8,
        MEAN_WEIGHT_S8, STDDEV_S8, EPS_S8};
// Same initializer as F32_ENTRY.
dt_conf_t::entry_t U8_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
        MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};

// Entries used by the configurations whose source layer is s8.
// Note: the lower value bound of the two activation entries is the literal 0,
// not MIN_S8. The *_EXACT variant differs only by a zero epsilon.
dt_conf_t::entry_t S8_ENTRY_S8_EXACT {
        dnnl_s8, INT8_MIN, INT8_MAX, 0, MAX_S8, MEAN_S8, STDDEV_S8, 0.f};
dt_conf_t::entry_t S8_ENTRY_S8 {
        dnnl_s8, INT8_MIN, INT8_MAX, 0, MAX_S8, MEAN_S8, STDDEV_S8, EPS_S8};
// s8 weight entry: uses MIN_S8 as the lower value bound and a zero mean.
dt_conf_t::entry_t S8_ENTRY_WEIGHT_S8 {dnnl_s8, INT8_MIN, INT8_MAX, MIN_S8,
        MAX_S8, MEAN_WEIGHT_S8, STDDEV_S8, EPS_S8};
// Same initializer as F32_ENTRY.
dt_conf_t::entry_t S8_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
        MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
158 | |
159 | CFG(u8u8u8u8) { |
160 | UNUSED_REG_VAR(u8u8u8u8); |
161 | CASE(SRC_LAYER, U8_ENTRY_U8); |
162 | CASE(SRC_ITER, U8_ENTRY_U8); |
163 | CASE(SRC_ITER_C, U8_ENTRY_F32); |
164 | CASE(WEIGHTS_LAYER, U8_ENTRY_S8); |
165 | CASE(WEIGHTS_ITER, U8_ENTRY_S8); |
166 | CASE(WEIGHTS_PEEPHOLE, U8_ENTRY_F32); |
167 | CASE(WEIGHTS_PROJECTION, U8_ENTRY_S8); |
168 | CASE(BIAS, U8_ENTRY_F32); |
169 | CASE(DST_ITER, U8_ENTRY_U8); |
170 | CASE(DST_ITER_C, U8_ENTRY_F32); |
171 | CASE(DST_LAYER, U8_ENTRY_U8_EXACT); |
172 | END_LIST; |
173 | } |
174 | |
175 | CFG(u8u8u8f32) { |
176 | UNUSED_REG_VAR(u8u8u8f32); |
177 | CASE(SRC_LAYER, U8_ENTRY_U8); |
178 | CASE(SRC_ITER, U8_ENTRY_U8); |
179 | CASE(SRC_ITER_C, U8_ENTRY_F32); |
180 | CASE(WEIGHTS_LAYER, U8_ENTRY_S8); |
181 | CASE(WEIGHTS_ITER, U8_ENTRY_S8); |
182 | CASE(WEIGHTS_PEEPHOLE, U8_ENTRY_F32); |
183 | CASE(WEIGHTS_PROJECTION, U8_ENTRY_S8); |
184 | CASE(BIAS, U8_ENTRY_F32); |
185 | CASE(DST_ITER, U8_ENTRY_U8); |
186 | CASE(DST_ITER_C, U8_ENTRY_F32); |
187 | CASE(DST_LAYER, U8_ENTRY_F32); |
188 | END_LIST; |
189 | } |
190 | |
191 | CFG(f32u8f32u8) { |
192 | UNUSED_REG_VAR(f32u8f32u8); |
193 | CASE(SRC_LAYER, U8_ENTRY_U8); |
194 | CASE(SRC_ITER, U8_ENTRY_F32); |
195 | CASE(SRC_ITER_C, U8_ENTRY_F32); |
196 | CASE(WEIGHTS_LAYER, U8_ENTRY_S8); |
197 | CASE(WEIGHTS_ITER, U8_ENTRY_S8); |
198 | CASE(WEIGHTS_PEEPHOLE, U8_ENTRY_F32); |
199 | CASE(WEIGHTS_PROJECTION, U8_ENTRY_S8); |
200 | CASE(BIAS, U8_ENTRY_F32); |
201 | CASE(DST_ITER, U8_ENTRY_F32); |
202 | CASE(DST_ITER_C, U8_ENTRY_F32); |
203 | CASE(DST_LAYER, U8_ENTRY_U8_EXACT); |
204 | END_LIST; |
205 | } |
206 | |
207 | CFG(f32u8f32f32) { |
208 | UNUSED_REG_VAR(f32u8f32f32); |
209 | CASE(SRC_LAYER, U8_ENTRY_U8); |
210 | CASE(SRC_ITER, U8_ENTRY_F32); |
211 | CASE(SRC_ITER_C, U8_ENTRY_F32); |
212 | CASE(WEIGHTS_LAYER, U8_ENTRY_S8); |
213 | CASE(WEIGHTS_ITER, U8_ENTRY_S8); |
214 | CASE(WEIGHTS_PEEPHOLE, U8_ENTRY_F32); |
215 | CASE(WEIGHTS_PROJECTION, U8_ENTRY_S8); |
216 | CASE(BIAS, U8_ENTRY_F32); |
217 | CASE(DST_ITER, U8_ENTRY_F32); |
218 | CASE(DST_ITER_C, U8_ENTRY_F32); |
219 | CASE(DST_LAYER, U8_ENTRY_F32); |
220 | END_LIST; |
221 | } |
222 | |
223 | CFG(s8s8s8s8) { |
224 | UNUSED_REG_VAR(s8s8s8s8); |
225 | CASE(SRC_LAYER, S8_ENTRY_S8); |
226 | CASE(SRC_ITER, S8_ENTRY_S8); |
227 | CASE(SRC_ITER_C, S8_ENTRY_F32); |
228 | CASE(WEIGHTS_LAYER, S8_ENTRY_WEIGHT_S8); |
229 | CASE(WEIGHTS_ITER, S8_ENTRY_WEIGHT_S8); |
230 | CASE(WEIGHTS_PEEPHOLE, S8_ENTRY_F32); |
231 | CASE(WEIGHTS_PROJECTION, S8_ENTRY_WEIGHT_S8); |
232 | CASE(BIAS, S8_ENTRY_F32); |
233 | CASE(DST_ITER, S8_ENTRY_S8); |
234 | CASE(DST_ITER_C, S8_ENTRY_F32); |
235 | CASE(DST_LAYER, S8_ENTRY_S8_EXACT); |
236 | END_LIST; |
237 | } |
238 | |
239 | CFG(s8s8s8f32) { |
240 | UNUSED_REG_VAR(s8s8s8f32); |
241 | CASE(SRC_LAYER, S8_ENTRY_S8); |
242 | CASE(SRC_ITER, S8_ENTRY_S8); |
243 | CASE(SRC_ITER_C, S8_ENTRY_F32); |
244 | CASE(WEIGHTS_LAYER, S8_ENTRY_WEIGHT_S8); |
245 | CASE(WEIGHTS_ITER, S8_ENTRY_WEIGHT_S8); |
246 | CASE(WEIGHTS_PEEPHOLE, S8_ENTRY_F32); |
247 | CASE(WEIGHTS_PROJECTION, S8_ENTRY_WEIGHT_S8); |
248 | CASE(BIAS, S8_ENTRY_F32); |
249 | CASE(DST_ITER, S8_ENTRY_S8); |
250 | CASE(DST_ITER_C, S8_ENTRY_F32); |
251 | CASE(DST_LAYER, S8_ENTRY_F32); |
252 | END_LIST; |
253 | } |
254 | |
255 | CFG(f32s8f32s8) { |
256 | UNUSED_REG_VAR(f32s8f32s8); |
257 | CASE(SRC_LAYER, S8_ENTRY_S8); |
258 | CASE(SRC_ITER, S8_ENTRY_F32); |
259 | CASE(SRC_ITER_C, S8_ENTRY_F32); |
260 | CASE(WEIGHTS_LAYER, S8_ENTRY_WEIGHT_S8); |
261 | CASE(WEIGHTS_ITER, S8_ENTRY_WEIGHT_S8); |
262 | CASE(WEIGHTS_PEEPHOLE, S8_ENTRY_F32); |
263 | CASE(WEIGHTS_PROJECTION, S8_ENTRY_WEIGHT_S8); |
264 | CASE(BIAS, S8_ENTRY_F32); |
265 | CASE(DST_ITER, S8_ENTRY_F32); |
266 | CASE(DST_ITER_C, S8_ENTRY_F32); |
267 | CASE(DST_LAYER, S8_ENTRY_S8_EXACT); |
268 | END_LIST; |
269 | } |
270 | |
271 | CFG(f32s8f32f32) { |
272 | UNUSED_REG_VAR(f32s8f32f32); |
273 | CASE(SRC_LAYER, S8_ENTRY_S8); |
274 | CASE(SRC_ITER, S8_ENTRY_F32); |
275 | CASE(SRC_ITER_C, S8_ENTRY_F32); |
276 | CASE(WEIGHTS_LAYER, S8_ENTRY_WEIGHT_S8); |
277 | CASE(WEIGHTS_ITER, S8_ENTRY_WEIGHT_S8); |
278 | CASE(WEIGHTS_PEEPHOLE, S8_ENTRY_F32); |
279 | CASE(WEIGHTS_PROJECTION, S8_ENTRY_WEIGHT_S8); |
280 | CASE(BIAS, S8_ENTRY_F32); |
281 | CASE(DST_ITER, S8_ENTRY_F32); |
282 | CASE(DST_ITER_C, S8_ENTRY_F32); |
283 | CASE(DST_LAYER, S8_ENTRY_F32); |
284 | END_LIST; |
285 | } |
286 | |
287 | } // namespace |
288 | |
289 | const dt_conf_t &dt_conf_t::create(const std::string &str, const attr_t &attr) { |
290 | if (attr.fpmath_mode == dnnl_fpmath_mode_bf16 && str == "f32" ) |
291 | return conf_bf32; |
292 | for (const auto cfg : cfg_list) |
293 | if (cfg->str() == str) return *cfg; |
294 | SAFE_V(CRIT); |
295 | return conf_f32; |
296 | } |
297 | |
298 | } // namespace rnn |
299 | |