1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "matmul.hpp"
18
19namespace matmul {
20
21// Adjust density based on accumulation chain.
22float cfg_t::get_density(const cfg_t::density_args_t &density_args) const {
23 float density = 1.f;
24 if (!is_bench_mode(CORR) || density_args.data_kind != SRC) return density;
25
26 // Find the number of accumulators safe to use with the following equations:
27 // Integer value can be expressed exactly with floating-point is
28 // `PREC = (1 << std::numeric_limit::digits(dst_dt))`.
29 // SUM_1_N(VALUES) <= PREC. This should hold to get precise answer.
30 // SUM_1_N(VALUES) <= N_ACC * MAX_VALUE <= PREC. It's a top estimate, where
31 // MAX_VALUE = MAX_VAL_SRC * MAX_VAL_WEI.
32 // SAFE_N_ACC <= PREC / MAX_VALUE.
33
34 const auto &cfg_e_src = cfg_entry_[SRC];
35 const auto &cfg_e_wei = cfg_entry_[WEI];
36 const auto &cfg_e_dst = cfg_entry_[DST];
37
38 const int64_t max_value
39 = cfg_e_src.get_range_abs_max() * cfg_e_wei.get_range_abs_max();
40 const int64_t safe_n_acc
41 = (1LL << digits_dt(cfg_e_dst.get_dt())) / max_value;
42 assert(safe_n_acc > 0);
43 density /= div_up(density_args.n_acc, safe_n_acc);
44 return density;
45}
46
47// Using pow2 values allows to avoid catastrophic cancellation.
48const cfg_t::cfg_entry_t::cfg_map_t &cfg_t::get_cfg_map(
49 data_kind_t kind) const {
50 static const cfg_t::cfg_entry_t::cfg_map_t src_cfg_map = {
51 {{dnnl_f32}, {-64, 64}},
52 {{dnnl_bf16}, {-4, 4}},
53 {{dnnl_f16}, {-4, 4}},
54 {{dnnl_s8}, {-4, 4}},
55 {{dnnl_u8}, {0, 8}},
56 };
57
58 static const cfg_t::cfg_entry_t::cfg_map_t wei_cfg_map = {
59 {{dnnl_f32}, {-128, 128}},
60 {{dnnl_bf16}, {-8, 8}},
61 {{dnnl_f16}, {-2, 2}},
62 {{dnnl_s8}, {-4, 4}},
63 };
64
65 static const cfg_t::cfg_entry_t::cfg_map_t bia_cfg_map = {
66 {{dnnl_f32}, {-8, 8}},
67 {{dnnl_bf16}, {-8, 8}},
68 {{dnnl_f16}, {-8, 8}},
69 {{dnnl_s8}, {-8, 8}},
70 {{dnnl_u8}, {0, 8}},
71 {{dnnl_s32}, {-8, 8}},
72 };
73
74 static const cfg_t::cfg_entry_t::cfg_map_t dst_cfg_map = {
75 {{dnnl_f32}, {-8, 8}},
76 {{dnnl_bf16}, {-8, 8}},
77 {{dnnl_f16}, {-4, 4}},
78 {{dnnl_s8}, {-4, 4}},
79 {{dnnl_u8}, {0, 8}},
80 {{dnnl_s32}, {-128, 128}},
81 };
82
83 switch (kind) {
84 case SRC: return src_cfg_map;
85 case WEI: return wei_cfg_map;
86 case BIA: return bia_cfg_map;
87 case DST: return dst_cfg_map;
88 default: assert(!"unsupported data kind"); break;
89 }
90 static cfg_t::cfg_entry_t::cfg_map_t dummy;
91 return dummy;
92}
93
94std::string str2cfg(const char *str) {
95 std::string s;
96#define CASE(cfg) \
97 if (!strcasecmp(STRINGIFY(cfg), str)) return s = str, s;
98 CASE(f32);
99 CASE(f16);
100 CASE(f16f16f32);
101 CASE(f16f16s8);
102 CASE(f16f16u8);
103 CASE(u8s8f32);
104 CASE(u8s8s32);
105 CASE(u8s8s8);
106 CASE(u8s8u8);
107 CASE(s8s8f32);
108 CASE(s8s8s32);
109 CASE(s8s8s8);
110 CASE(s8s8u8);
111 CASE(s8s8bf16);
112 CASE(u8s8bf16);
113 CASE(s8s8f16);
114 CASE(u8s8f16);
115 CASE(bf16bf16f32);
116 CASE(bf16bf16bf16);
117 CASE(f32bf16bf16);
118 CASE(bf16f32bf16);
119#undef CASE
120
121 BENCHDNN_PRINT(0, "Config name \'%s\' is not supported.\n", str);
122 SAFE_V(CRIT);
123 return std::string();
124}
125
126void handle_legacy_cfg(
127 std::vector<dnnl_data_type_t> &dt, const std::string &cfg) {
128 if (cfg == "f32")
129 dt = {dnnl_f32};
130 else if (cfg == "bf16bf16bf16")
131 dt = {dnnl_bf16};
132 else if (cfg == "f16")
133 dt = {dnnl_f16};
134 else if (cfg == "f16f16f32")
135 dt = {dnnl_f16, dnnl_f16, dnnl_f16};
136 else if (cfg == "f16f16s8")
137 dt = {dnnl_f16, dnnl_f16, dnnl_s8};
138 else if (cfg == "f16f16u8")
139 dt = {dnnl_f16, dnnl_f16, dnnl_u8};
140 else if (cfg == "u8s8f32")
141 dt = {dnnl_u8, dnnl_s8, dnnl_f32};
142 else if (cfg == "u8s8s32")
143 dt = {dnnl_u8, dnnl_s8, dnnl_s32};
144 else if (cfg == "u8s8s8")
145 dt = {dnnl_u8, dnnl_s8, dnnl_s8};
146 else if (cfg == "u8s8u8")
147 dt = {dnnl_u8, dnnl_s8, dnnl_u8};
148 else if (cfg == "s8s8f32")
149 dt = {dnnl_s8, dnnl_s8, dnnl_f32};
150 else if (cfg == "s8s8s32")
151 dt = {dnnl_s8, dnnl_s8, dnnl_s32};
152 else if (cfg == "s8s8s8")
153 dt = {dnnl_s8, dnnl_s8, dnnl_s8};
154 else if (cfg == "s8s8u8")
155 dt = {dnnl_s8, dnnl_s8, dnnl_u8};
156 else if (cfg == "s8s8bf16")
157 dt = {dnnl_s8, dnnl_s8, dnnl_bf16};
158 else if (cfg == "u8s8bf16")
159 dt = {dnnl_u8, dnnl_s8, dnnl_bf16};
160 else if (cfg == "s8s8f16")
161 dt = {dnnl_s8, dnnl_s8, dnnl_f16};
162 else if (cfg == "u8s8f16")
163 dt = {dnnl_u8, dnnl_s8, dnnl_f16};
164 else if (cfg == "bf16bf16f32")
165 dt = {dnnl_bf16, dnnl_bf16, dnnl_f32};
166 else if (cfg == "f32bf16bf16")
167 dt = {dnnl_f32, dnnl_bf16, dnnl_bf16};
168 else if (cfg == "bf16f32bf16")
169 dt = {dnnl_bf16, dnnl_f32, dnnl_bf16};
170 else {
171 BENCHDNN_PRINT(
172 0, "Config name \'%s\' is not supported.\n", cfg.c_str());
173 SAFE_V(CRIT);
174 }
175}
176
177} // namespace matmul
178