1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <set>
18
19#include "dnnl_common.hpp"
20#include "rnn/rnn.hpp"
21
22namespace rnn {
23
24namespace {
25
// Helper macros for writing the bodies of the operator[] overrides below.
// All of them expect a variable named `kind` to be in scope.
//
// CASE(KIND, ENTRY): returns ENTRY when `kind` equals KIND; otherwise control
// continues with the statement that follows.
// NOTE(review): expands to a bare `if` without braces, so it should not be
// placed directly in front of an `else`.
#define CASE(KIND, ENTRY) \
    if (kind == (KIND)) return ENTRY
// DEFAULT(ENTRY): unconditional return of ENTRY; closes a CASE list that has a
// catch-all entry.
#define DEFAULT(ENTRY) return ENTRY
// END_LIST: closes a CASE list that has no catch-all entry. Evaluates
// SAFE_V(CRIT) and then returns F32_ENTRY.
// NOTE(review): SAFE_V and CRIT are defined outside this file; presumably the
// check reports a fatal error for an unmatched kind and the trailing return
// only satisfies the compiler -- confirm. The macro expands to two statements,
// so it must not be used as the sole body of an unbraced `if` or loop.
#define END_LIST \
    SAFE_V(CRIT); \
    return F32_ENTRY
32
// CFG_INTERNAL(name, alias): declares struct conf_<name>_t deriving from
// dt_conf_t (inheriting its constructors), defines the object conf_<name>
// constructed from the string STRINGIFY(alias), and opens the out-of-class
// definition of its operator[]. The macro invocation must be followed by the
// function body. The object is not inserted into cfg_list, so the lookup loop
// in dt_conf_t::create() does not see it.
#define CFG_INTERNAL(name, alias) \
    struct conf_##name##_t : dt_conf_t { \
        using dt_conf_t::dt_conf_t; \
        const entry_t &operator[](rnn_data_kind_t kind) const override; \
    } conf_##name(STRINGIFY(alias)); \
    const dt_conf_t::entry_t &conf_##name##_t::operator[]( \
            rnn_data_kind_t kind) const
40
// Collection of the configurations defined through CFG; it is iterated by
// dt_conf_t::create() to find a configuration by its string name.
std::set<const dt_conf_t *> cfg_list;
// CFG(name): declares struct conf_<name>_t deriving from dt_conf_t, defines
// the object conf_<name> constructed from STRINGIFY(name), inserts the address
// of that object into cfg_list through the initializer of the static variable
// __reg_<name>, and opens the out-of-class definition of operator[]. The macro
// invocation must be followed by the function body.
// cfg_list is defined above this macro's uses in the same translation unit, so
// it is initialized before any __reg_<name> initializer runs.
// NOTE(review): identifiers containing a double underscore (`__reg_...`) are
// reserved for the implementation in C++; renaming would also require
// touching UNUSED_REG_VAR.
#define CFG(name) \
    struct conf_##name##_t : dt_conf_t { \
        using dt_conf_t::dt_conf_t; \
        const entry_t &operator[](rnn_data_kind_t kind) const override; \
    } conf_##name(STRINGIFY(name)); \
    static auto __reg_##name = cfg_list.insert(&conf_##name); \
    const dt_conf_t::entry_t &conf_##name##_t::operator[]( \
            rnn_data_kind_t kind) const
50
// f32
// Macro constants used to initialize F32_ENTRY; they are reused by the other
// *_F32 entries further down in this file.
// NOTE(review): entry_t is declared outside this file; judging by the macro
// names the last five initializers are a value range, a mean, a standard
// deviation and a comparison tolerance -- confirm against the declaration.
#define MIN_F32 0.0f
#define MAX_F32 .999999f
#define MEAN_F32 .5f
#define STDDEV_F32 0.01f
#define EPS_F32 epsilon_dt(dnnl_f32)
// 1 << 24 == 16777216; used as the symmetric integer bound of every f32 and
// bf16 entry in this file.
// NOTE(review): presumably chosen because a 32-bit float represents every
// integer up to 2^24 exactly -- confirm.
const int f32_max_exact = 1 << 24;
// Entry used for plain f32 data; also the value returned by END_LIST.
dt_conf_t::entry_t F32_ENTRY {dnnl_f32, -f32_max_exact, f32_max_exact, MIN_F32,
        MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
60
// Passes the __reg_<name> variable created by CFG(name) to UNUSED.
// NOTE(review): UNUSED is defined outside this file; presumably it marks the
// registration variable as referenced to silence unused-variable warnings --
// confirm.
#define UNUSED_REG_VAR(name) UNUSED(__reg_##name)
62
// "f32" configuration: every data kind maps to F32_ENTRY, so `kind` is not
// inspected.
CFG(f32) {
    UNUSED_REG_VAR(f32);
    return F32_ENTRY;
}
67
// bf16
// Macro constants for the bf16 entries; the numeric values are the same as the
// corresponding *_F32 macros, only EPS_BF16 differs (epsilon_dt(dnnl_bf16)).
#define MIN_BF16 0.0f
#define MAX_BF16 .999999f
#define MEAN_BF16 .5f
#define STDDEV_BF16 0.01f
#define EPS_BF16 epsilon_dt(dnnl_bf16)
// Entry for data kinds stored as bf16 in the bf16 configurations.
dt_conf_t::entry_t BF16_ENTRY_BF16 {dnnl_bf16, -f32_max_exact, f32_max_exact,
        MIN_BF16, MAX_BF16, MEAN_BF16, STDDEV_BF16, EPS_BF16};
// Entry for data kinds that stay f32 in the bf16 configurations. It uses the
// *_F32 constants but the bf16 tolerance EPS_BF16.
dt_conf_t::entry_t BF16_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
        MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_BF16};
78
79CFG(bf16f32) {
80 UNUSED_REG_VAR(bf16f32);
81 CASE(SRC_LAYER, BF16_ENTRY_BF16);
82 CASE(SRC_ITER, BF16_ENTRY_BF16);
83 CASE(WEIGHTS_LAYER, BF16_ENTRY_BF16);
84 CASE(WEIGHTS_ITER, BF16_ENTRY_BF16);
85 CASE(DST_ITER, BF16_ENTRY_BF16);
86 CASE(DST_LAYER, BF16_ENTRY_BF16);
87 CASE(AUGRU_ATTENTION, BF16_ENTRY_BF16);
88 DEFAULT(BF16_ENTRY_F32);
89}
90
91CFG(bf16) {
92 UNUSED_REG_VAR(bf16);
93 CASE(SRC_LAYER, BF16_ENTRY_BF16);
94 CASE(SRC_ITER, BF16_ENTRY_BF16);
95 CASE(SRC_ITER_C, BF16_ENTRY_BF16);
96 CASE(WEIGHTS_LAYER, BF16_ENTRY_BF16);
97 CASE(WEIGHTS_ITER, BF16_ENTRY_BF16);
98 CASE(WEIGHTS_PEEPHOLE, BF16_ENTRY_F32);
99 CASE(WEIGHTS_PROJECTION, BF16_ENTRY_F32);
100 CASE(BIAS, BF16_ENTRY_BF16);
101 CASE(DST_ITER, BF16_ENTRY_BF16);
102 CASE(DST_ITER_C, BF16_ENTRY_BF16);
103 CASE(DST_LAYER, BF16_ENTRY_BF16);
104 CASE(AUGRU_ATTENTION, BF16_ENTRY_BF16);
105 DEFAULT(BF16_ENTRY_F32);
106}
107
// bf32
// Entry with data type dnnl_f32 that takes its remaining parameters,
// including the tolerance EPS_BF16, from the bf16 macro constants.
// dt_conf_t::create() selects the configuration built on it when "f32" is
// requested together with the bf16 fpmath mode.
dt_conf_t::entry_t BF32_ENTRY {dnnl_f32, -f32_max_exact, f32_max_exact,
        MIN_BF16, MAX_BF16, MEAN_BF16, STDDEV_BF16, EPS_BF16};
111CFG_INTERNAL(bf32, f32) {
112 CASE(BIAS, F32_ENTRY);
113 DEFAULT(BF32_ENTRY);
114}
115
// f16
// 1 << 11 == 2048; used as the symmetric integer bound of the f16 entry.
// NOTE(review): presumably chosen because a 16-bit float represents every
// integer up to 2^11 exactly -- confirm.
const int f16_max_exact = 1 << 11;
// Entry used for f16 data. The literal parameters repeat the values of the
// *_F32 macros; the tolerance is epsilon_dt(dnnl_f16).
dt_conf_t::entry_t F16_ENTRY {dnnl_f16, -f16_max_exact, f16_max_exact, 0.0f,
        0.999999f, 0.5f, 0.01f, epsilon_dt(dnnl_f16)};
120
// "f16" configuration: every data kind maps to F16_ENTRY, so `kind` is not
// inspected.
CFG(f16) {
    UNUSED_REG_VAR(f16);
    return F16_ENTRY;
}
125
// s8
// Macro constants shared by the int8 (u8 and s8) entries defined below.
// EPS_U8 / EPS_S8 are double literals (no `f` suffix), unlike the other
// constants in this group.
#define EPS_U8 4e-3
#define EPS_S8 8e-3

// Parameters of the u8 entries. MAX_U8 is 127 although the entries give
// UINT8_MAX as their integer upper bound.
// NOTE(review): presumably the narrower range is deliberate -- confirm.
#define MIN_U8 0.0f
#define MAX_U8 127.f
#define MEAN_U8 28.f
#define STDDEV_U8 16.f

// Parameters of the s8 entries. MEAN_WEIGHT_S8 replaces MEAN_S8 in the
// entries that the configurations below assign to weights.
#define MIN_S8 (-64.f)
#define MAX_S8 64.f
#define MEAN_S8 8.f
#define STDDEV_S8 32.f
#define MEAN_WEIGHT_S8 0.f
140
// Entries used by the u8-based configurations (u8u8u8u8, u8u8u8f32,
// f32u8f32u8, f32u8f32f32).

// u8 entry whose last initializer (the tolerance slot filled by EPS_* in the
// sibling entries) is 0.f.
dt_conf_t::entry_t U8_ENTRY_U8_EXACT {
        dnnl_u8, 0, UINT8_MAX, MIN_U8, MAX_U8, MEAN_U8, STDDEV_U8, 0.f};
// u8 entry with the EPS_U8 tolerance.
dt_conf_t::entry_t U8_ENTRY_U8 {
        dnnl_u8, 0, UINT8_MAX, MIN_U8, MAX_U8, MEAN_U8, STDDEV_U8, EPS_U8};
// s8 entry that the u8-based configurations assign to weights; uses
// MEAN_WEIGHT_S8.
dt_conf_t::entry_t U8_ENTRY_S8 {dnnl_s8, INT8_MIN, INT8_MAX, MIN_S8, MAX_S8,
        MEAN_WEIGHT_S8, STDDEV_S8, EPS_S8};
// f32 entry of the u8-based configurations; its initializers are identical to
// those of F32_ENTRY.
dt_conf_t::entry_t U8_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
        MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
149
// Entries used by the s8-based configurations (s8s8s8s8, s8s8s8f32,
// f32s8f32s8, f32s8f32f32).
// NOTE(review): S8_ENTRY_S8_EXACT and S8_ENTRY_S8 pass a literal 0 as the
// fourth initializer where S8_ENTRY_WEIGHT_S8 passes MIN_S8 -- confirm this
// difference is intentional.

// s8 entry whose last initializer (the tolerance slot filled by EPS_S8 in the
// sibling entries) is 0.f.
dt_conf_t::entry_t S8_ENTRY_S8_EXACT {
        dnnl_s8, INT8_MIN, INT8_MAX, 0, MAX_S8, MEAN_S8, STDDEV_S8, 0.f};
// s8 entry with the EPS_S8 tolerance.
dt_conf_t::entry_t S8_ENTRY_S8 {
        dnnl_s8, INT8_MIN, INT8_MAX, 0, MAX_S8, MEAN_S8, STDDEV_S8, EPS_S8};
// s8 entry that the s8-based configurations assign to weights; uses MIN_S8 and
// MEAN_WEIGHT_S8.
dt_conf_t::entry_t S8_ENTRY_WEIGHT_S8 {dnnl_s8, INT8_MIN, INT8_MAX, MIN_S8,
        MAX_S8, MEAN_WEIGHT_S8, STDDEV_S8, EPS_S8};
// f32 entry of the s8-based configurations; its initializers are identical to
// those of F32_ENTRY.
dt_conf_t::entry_t S8_ENTRY_F32 {dnnl_f32, -f32_max_exact, f32_max_exact,
        MIN_F32, MAX_F32, MEAN_F32, STDDEV_F32, EPS_F32};
158
159CFG(u8u8u8u8) {
160 UNUSED_REG_VAR(u8u8u8u8);
161 CASE(SRC_LAYER, U8_ENTRY_U8);
162 CASE(SRC_ITER, U8_ENTRY_U8);
163 CASE(SRC_ITER_C, U8_ENTRY_F32);
164 CASE(WEIGHTS_LAYER, U8_ENTRY_S8);
165 CASE(WEIGHTS_ITER, U8_ENTRY_S8);
166 CASE(WEIGHTS_PEEPHOLE, U8_ENTRY_F32);
167 CASE(WEIGHTS_PROJECTION, U8_ENTRY_S8);
168 CASE(BIAS, U8_ENTRY_F32);
169 CASE(DST_ITER, U8_ENTRY_U8);
170 CASE(DST_ITER_C, U8_ENTRY_F32);
171 CASE(DST_LAYER, U8_ENTRY_U8_EXACT);
172 END_LIST;
173}
174
175CFG(u8u8u8f32) {
176 UNUSED_REG_VAR(u8u8u8f32);
177 CASE(SRC_LAYER, U8_ENTRY_U8);
178 CASE(SRC_ITER, U8_ENTRY_U8);
179 CASE(SRC_ITER_C, U8_ENTRY_F32);
180 CASE(WEIGHTS_LAYER, U8_ENTRY_S8);
181 CASE(WEIGHTS_ITER, U8_ENTRY_S8);
182 CASE(WEIGHTS_PEEPHOLE, U8_ENTRY_F32);
183 CASE(WEIGHTS_PROJECTION, U8_ENTRY_S8);
184 CASE(BIAS, U8_ENTRY_F32);
185 CASE(DST_ITER, U8_ENTRY_U8);
186 CASE(DST_ITER_C, U8_ENTRY_F32);
187 CASE(DST_LAYER, U8_ENTRY_F32);
188 END_LIST;
189}
190
191CFG(f32u8f32u8) {
192 UNUSED_REG_VAR(f32u8f32u8);
193 CASE(SRC_LAYER, U8_ENTRY_U8);
194 CASE(SRC_ITER, U8_ENTRY_F32);
195 CASE(SRC_ITER_C, U8_ENTRY_F32);
196 CASE(WEIGHTS_LAYER, U8_ENTRY_S8);
197 CASE(WEIGHTS_ITER, U8_ENTRY_S8);
198 CASE(WEIGHTS_PEEPHOLE, U8_ENTRY_F32);
199 CASE(WEIGHTS_PROJECTION, U8_ENTRY_S8);
200 CASE(BIAS, U8_ENTRY_F32);
201 CASE(DST_ITER, U8_ENTRY_F32);
202 CASE(DST_ITER_C, U8_ENTRY_F32);
203 CASE(DST_LAYER, U8_ENTRY_U8_EXACT);
204 END_LIST;
205}
206
207CFG(f32u8f32f32) {
208 UNUSED_REG_VAR(f32u8f32f32);
209 CASE(SRC_LAYER, U8_ENTRY_U8);
210 CASE(SRC_ITER, U8_ENTRY_F32);
211 CASE(SRC_ITER_C, U8_ENTRY_F32);
212 CASE(WEIGHTS_LAYER, U8_ENTRY_S8);
213 CASE(WEIGHTS_ITER, U8_ENTRY_S8);
214 CASE(WEIGHTS_PEEPHOLE, U8_ENTRY_F32);
215 CASE(WEIGHTS_PROJECTION, U8_ENTRY_S8);
216 CASE(BIAS, U8_ENTRY_F32);
217 CASE(DST_ITER, U8_ENTRY_F32);
218 CASE(DST_ITER_C, U8_ENTRY_F32);
219 CASE(DST_LAYER, U8_ENTRY_F32);
220 END_LIST;
221}
222
223CFG(s8s8s8s8) {
224 UNUSED_REG_VAR(s8s8s8s8);
225 CASE(SRC_LAYER, S8_ENTRY_S8);
226 CASE(SRC_ITER, S8_ENTRY_S8);
227 CASE(SRC_ITER_C, S8_ENTRY_F32);
228 CASE(WEIGHTS_LAYER, S8_ENTRY_WEIGHT_S8);
229 CASE(WEIGHTS_ITER, S8_ENTRY_WEIGHT_S8);
230 CASE(WEIGHTS_PEEPHOLE, S8_ENTRY_F32);
231 CASE(WEIGHTS_PROJECTION, S8_ENTRY_WEIGHT_S8);
232 CASE(BIAS, S8_ENTRY_F32);
233 CASE(DST_ITER, S8_ENTRY_S8);
234 CASE(DST_ITER_C, S8_ENTRY_F32);
235 CASE(DST_LAYER, S8_ENTRY_S8_EXACT);
236 END_LIST;
237}
238
239CFG(s8s8s8f32) {
240 UNUSED_REG_VAR(s8s8s8f32);
241 CASE(SRC_LAYER, S8_ENTRY_S8);
242 CASE(SRC_ITER, S8_ENTRY_S8);
243 CASE(SRC_ITER_C, S8_ENTRY_F32);
244 CASE(WEIGHTS_LAYER, S8_ENTRY_WEIGHT_S8);
245 CASE(WEIGHTS_ITER, S8_ENTRY_WEIGHT_S8);
246 CASE(WEIGHTS_PEEPHOLE, S8_ENTRY_F32);
247 CASE(WEIGHTS_PROJECTION, S8_ENTRY_WEIGHT_S8);
248 CASE(BIAS, S8_ENTRY_F32);
249 CASE(DST_ITER, S8_ENTRY_S8);
250 CASE(DST_ITER_C, S8_ENTRY_F32);
251 CASE(DST_LAYER, S8_ENTRY_F32);
252 END_LIST;
253}
254
255CFG(f32s8f32s8) {
256 UNUSED_REG_VAR(f32s8f32s8);
257 CASE(SRC_LAYER, S8_ENTRY_S8);
258 CASE(SRC_ITER, S8_ENTRY_F32);
259 CASE(SRC_ITER_C, S8_ENTRY_F32);
260 CASE(WEIGHTS_LAYER, S8_ENTRY_WEIGHT_S8);
261 CASE(WEIGHTS_ITER, S8_ENTRY_WEIGHT_S8);
262 CASE(WEIGHTS_PEEPHOLE, S8_ENTRY_F32);
263 CASE(WEIGHTS_PROJECTION, S8_ENTRY_WEIGHT_S8);
264 CASE(BIAS, S8_ENTRY_F32);
265 CASE(DST_ITER, S8_ENTRY_F32);
266 CASE(DST_ITER_C, S8_ENTRY_F32);
267 CASE(DST_LAYER, S8_ENTRY_S8_EXACT);
268 END_LIST;
269}
270
271CFG(f32s8f32f32) {
272 UNUSED_REG_VAR(f32s8f32f32);
273 CASE(SRC_LAYER, S8_ENTRY_S8);
274 CASE(SRC_ITER, S8_ENTRY_F32);
275 CASE(SRC_ITER_C, S8_ENTRY_F32);
276 CASE(WEIGHTS_LAYER, S8_ENTRY_WEIGHT_S8);
277 CASE(WEIGHTS_ITER, S8_ENTRY_WEIGHT_S8);
278 CASE(WEIGHTS_PEEPHOLE, S8_ENTRY_F32);
279 CASE(WEIGHTS_PROJECTION, S8_ENTRY_WEIGHT_S8);
280 CASE(BIAS, S8_ENTRY_F32);
281 CASE(DST_ITER, S8_ENTRY_F32);
282 CASE(DST_ITER_C, S8_ENTRY_F32);
283 CASE(DST_LAYER, S8_ENTRY_F32);
284 END_LIST;
285}
286
287} // namespace
288
289const dt_conf_t &dt_conf_t::create(const std::string &str, const attr_t &attr) {
290 if (attr.fpmath_mode == dnnl_fpmath_mode_bf16 && str == "f32")
291 return conf_bf32;
292 for (const auto cfg : cfg_list)
293 if (cfg->str() == str) return *cfg;
294 SAFE_V(CRIT);
295 return conf_f32;
296}
297
298} // namespace rnn
299