1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "matmul.hpp" |
18 | |
19 | namespace matmul { |
20 | |
21 | // Adjust density based on accumulation chain. |
22 | float cfg_t::get_density(const cfg_t::density_args_t &density_args) const { |
23 | float density = 1.f; |
24 | if (!is_bench_mode(CORR) || density_args.data_kind != SRC) return density; |
25 | |
26 | // Find the number of accumulators safe to use with the following equations: |
27 | // Integer value can be expressed exactly with floating-point is |
28 | // `PREC = (1 << std::numeric_limit::digits(dst_dt))`. |
29 | // SUM_1_N(VALUES) <= PREC. This should hold to get precise answer. |
30 | // SUM_1_N(VALUES) <= N_ACC * MAX_VALUE <= PREC. It's a top estimate, where |
31 | // MAX_VALUE = MAX_VAL_SRC * MAX_VAL_WEI. |
32 | // SAFE_N_ACC <= PREC / MAX_VALUE. |
33 | |
34 | const auto &cfg_e_src = cfg_entry_[SRC]; |
35 | const auto &cfg_e_wei = cfg_entry_[WEI]; |
36 | const auto &cfg_e_dst = cfg_entry_[DST]; |
37 | |
38 | const int64_t max_value |
39 | = cfg_e_src.get_range_abs_max() * cfg_e_wei.get_range_abs_max(); |
40 | const int64_t safe_n_acc |
41 | = (1LL << digits_dt(cfg_e_dst.get_dt())) / max_value; |
42 | assert(safe_n_acc > 0); |
43 | density /= div_up(density_args.n_acc, safe_n_acc); |
44 | return density; |
45 | } |
46 | |
47 | // Using pow2 values allows to avoid catastrophic cancellation. |
48 | const cfg_t::cfg_entry_t::cfg_map_t &cfg_t::get_cfg_map( |
49 | data_kind_t kind) const { |
50 | static const cfg_t::cfg_entry_t::cfg_map_t src_cfg_map = { |
51 | {{dnnl_f32}, {-64, 64}}, |
52 | {{dnnl_bf16}, {-4, 4}}, |
53 | {{dnnl_f16}, {-4, 4}}, |
54 | {{dnnl_s8}, {-4, 4}}, |
55 | {{dnnl_u8}, {0, 8}}, |
56 | }; |
57 | |
58 | static const cfg_t::cfg_entry_t::cfg_map_t wei_cfg_map = { |
59 | {{dnnl_f32}, {-128, 128}}, |
60 | {{dnnl_bf16}, {-8, 8}}, |
61 | {{dnnl_f16}, {-2, 2}}, |
62 | {{dnnl_s8}, {-4, 4}}, |
63 | }; |
64 | |
65 | static const cfg_t::cfg_entry_t::cfg_map_t bia_cfg_map = { |
66 | {{dnnl_f32}, {-8, 8}}, |
67 | {{dnnl_bf16}, {-8, 8}}, |
68 | {{dnnl_f16}, {-8, 8}}, |
69 | {{dnnl_s8}, {-8, 8}}, |
70 | {{dnnl_u8}, {0, 8}}, |
71 | {{dnnl_s32}, {-8, 8}}, |
72 | }; |
73 | |
74 | static const cfg_t::cfg_entry_t::cfg_map_t dst_cfg_map = { |
75 | {{dnnl_f32}, {-8, 8}}, |
76 | {{dnnl_bf16}, {-8, 8}}, |
77 | {{dnnl_f16}, {-4, 4}}, |
78 | {{dnnl_s8}, {-4, 4}}, |
79 | {{dnnl_u8}, {0, 8}}, |
80 | {{dnnl_s32}, {-128, 128}}, |
81 | }; |
82 | |
83 | switch (kind) { |
84 | case SRC: return src_cfg_map; |
85 | case WEI: return wei_cfg_map; |
86 | case BIA: return bia_cfg_map; |
87 | case DST: return dst_cfg_map; |
88 | default: assert(!"unsupported data kind" ); break; |
89 | } |
90 | static cfg_t::cfg_entry_t::cfg_map_t dummy; |
91 | return dummy; |
92 | } |
93 | |
94 | std::string str2cfg(const char *str) { |
95 | std::string s; |
96 | #define CASE(cfg) \ |
97 | if (!strcasecmp(STRINGIFY(cfg), str)) return s = str, s; |
98 | CASE(f32); |
99 | CASE(f16); |
100 | CASE(f16f16f32); |
101 | CASE(f16f16s8); |
102 | CASE(f16f16u8); |
103 | CASE(u8s8f32); |
104 | CASE(u8s8s32); |
105 | CASE(u8s8s8); |
106 | CASE(u8s8u8); |
107 | CASE(s8s8f32); |
108 | CASE(s8s8s32); |
109 | CASE(s8s8s8); |
110 | CASE(s8s8u8); |
111 | CASE(s8s8bf16); |
112 | CASE(u8s8bf16); |
113 | CASE(s8s8f16); |
114 | CASE(u8s8f16); |
115 | CASE(bf16bf16f32); |
116 | CASE(bf16bf16bf16); |
117 | CASE(f32bf16bf16); |
118 | CASE(bf16f32bf16); |
119 | #undef CASE |
120 | |
121 | BENCHDNN_PRINT(0, "Config name \'%s\' is not supported.\n" , str); |
122 | SAFE_V(CRIT); |
123 | return std::string(); |
124 | } |
125 | |
126 | void handle_legacy_cfg( |
127 | std::vector<dnnl_data_type_t> &dt, const std::string &cfg) { |
128 | if (cfg == "f32" ) |
129 | dt = {dnnl_f32}; |
130 | else if (cfg == "bf16bf16bf16" ) |
131 | dt = {dnnl_bf16}; |
132 | else if (cfg == "f16" ) |
133 | dt = {dnnl_f16}; |
134 | else if (cfg == "f16f16f32" ) |
135 | dt = {dnnl_f16, dnnl_f16, dnnl_f16}; |
136 | else if (cfg == "f16f16s8" ) |
137 | dt = {dnnl_f16, dnnl_f16, dnnl_s8}; |
138 | else if (cfg == "f16f16u8" ) |
139 | dt = {dnnl_f16, dnnl_f16, dnnl_u8}; |
140 | else if (cfg == "u8s8f32" ) |
141 | dt = {dnnl_u8, dnnl_s8, dnnl_f32}; |
142 | else if (cfg == "u8s8s32" ) |
143 | dt = {dnnl_u8, dnnl_s8, dnnl_s32}; |
144 | else if (cfg == "u8s8s8" ) |
145 | dt = {dnnl_u8, dnnl_s8, dnnl_s8}; |
146 | else if (cfg == "u8s8u8" ) |
147 | dt = {dnnl_u8, dnnl_s8, dnnl_u8}; |
148 | else if (cfg == "s8s8f32" ) |
149 | dt = {dnnl_s8, dnnl_s8, dnnl_f32}; |
150 | else if (cfg == "s8s8s32" ) |
151 | dt = {dnnl_s8, dnnl_s8, dnnl_s32}; |
152 | else if (cfg == "s8s8s8" ) |
153 | dt = {dnnl_s8, dnnl_s8, dnnl_s8}; |
154 | else if (cfg == "s8s8u8" ) |
155 | dt = {dnnl_s8, dnnl_s8, dnnl_u8}; |
156 | else if (cfg == "s8s8bf16" ) |
157 | dt = {dnnl_s8, dnnl_s8, dnnl_bf16}; |
158 | else if (cfg == "u8s8bf16" ) |
159 | dt = {dnnl_u8, dnnl_s8, dnnl_bf16}; |
160 | else if (cfg == "s8s8f16" ) |
161 | dt = {dnnl_s8, dnnl_s8, dnnl_f16}; |
162 | else if (cfg == "u8s8f16" ) |
163 | dt = {dnnl_u8, dnnl_s8, dnnl_f16}; |
164 | else if (cfg == "bf16bf16f32" ) |
165 | dt = {dnnl_bf16, dnnl_bf16, dnnl_f32}; |
166 | else if (cfg == "f32bf16bf16" ) |
167 | dt = {dnnl_f32, dnnl_bf16, dnnl_bf16}; |
168 | else if (cfg == "bf16f32bf16" ) |
169 | dt = {dnnl_bf16, dnnl_f32, dnnl_bf16}; |
170 | else { |
171 | BENCHDNN_PRINT( |
172 | 0, "Config name \'%s\' is not supported.\n" , cfg.c_str()); |
173 | SAFE_V(CRIT); |
174 | } |
175 | } |
176 | |
177 | } // namespace matmul |
178 | |