1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "brgemm.hpp" |
18 | |
19 | namespace brgemm { |
20 | |
21 | // Adjust density based on accumulation chain. |
22 | float cfg_t::get_density(const cfg_t::density_args_t &density_args) const { |
23 | float density = 1.f; |
24 | if (!is_bench_mode(CORR) || density_args.data_kind != SRC) return density; |
25 | |
26 | // Find the number of accumulators safe to use with the following equations: |
27 | // Integer value can be expressed exactly with floating-point is |
28 | // `PREC = (1 << std::numeric_limit::digits(dst_dt))`. |
29 | // SUM_1_N(VALUES) <= PREC. This should hold to get precise answer. |
30 | // SUM_1_N(VALUES) <= N_ACC * MAX_VALUE <= PREC. It's a top estimate, where |
31 | // MAX_VALUE = MAX_VAL_SRC * MAX_VAL_WEI. |
32 | // SAFE_N_ACC <= PREC / MAX_VALUE. |
33 | |
34 | const auto &cfg_e_src = cfg_entry_[SRC]; |
35 | const auto &cfg_e_wei = cfg_entry_[WEI]; |
36 | const auto &cfg_e_dst = cfg_entry_[DST]; |
37 | |
38 | const int64_t max_value |
39 | = cfg_e_src.get_range_abs_max() * cfg_e_wei.get_range_abs_max(); |
40 | const int64_t safe_n_acc |
41 | = (1LL << digits_dt(cfg_e_dst.get_dt())) / max_value; |
42 | assert(safe_n_acc > 0); |
43 | density /= div_up(density_args.n_acc, safe_n_acc); |
44 | return density; |
45 | } |
46 | |
47 | // Using pow2 values allows to avoid catastrophic cancellation. |
48 | const cfg_t::cfg_entry_t::cfg_map_t &cfg_t::get_cfg_map( |
49 | data_kind_t kind) const { |
50 | static const cfg_t::cfg_entry_t::cfg_map_t src_cfg_map = { |
51 | {{dnnl_f32}, {-64, 64}}, |
52 | {{dnnl_bf16}, {-4, 4}}, |
53 | {{dnnl_f16}, {-4, 4}}, |
54 | {{dnnl_s8}, {-4, 4}}, |
55 | {{dnnl_u8}, {0, 8}}, |
56 | }; |
57 | |
58 | static const cfg_t::cfg_entry_t::cfg_map_t wei_cfg_map = { |
59 | {{dnnl_f32}, {-128, 128}}, |
60 | {{dnnl_bf16}, {-8, 8}}, |
61 | {{dnnl_f16}, {-2, 2}}, |
62 | {{dnnl_s8}, {-4, 4}}, |
63 | }; |
64 | |
65 | static const cfg_t::cfg_entry_t::cfg_map_t bia_cfg_map = { |
66 | {{dnnl_f32}, {-8, 8}}, |
67 | {{dnnl_bf16}, {-8, 8}}, |
68 | {{dnnl_f16}, {-8, 8}}, |
69 | {{dnnl_s8}, {-8, 8}}, |
70 | {{dnnl_u8}, {0, 8}}, |
71 | {{dnnl_s32}, {-8, 8}}, |
72 | }; |
73 | |
74 | static const cfg_t::cfg_entry_t::cfg_map_t dst_cfg_map = { |
75 | {{dnnl_f32}, {-8, 8}}, |
76 | {{dnnl_bf16}, {-8, 8}}, |
77 | {{dnnl_f16}, {-4, 4}}, |
78 | {{dnnl_s8}, {-4, 4}}, |
79 | {{dnnl_u8}, {0, 8}}, |
80 | {{dnnl_s32}, {-128, 128}}, |
81 | }; |
82 | |
83 | switch (kind) { |
84 | case SRC: return src_cfg_map; |
85 | case WEI: return wei_cfg_map; |
86 | case BIA: return bia_cfg_map; |
87 | case DST: return dst_cfg_map; |
88 | default: assert(!"unsupported data kind" ); break; |
89 | } |
90 | static cfg_t::cfg_entry_t::cfg_map_t dummy; |
91 | return dummy; |
92 | } |
93 | |
94 | } // namespace brgemm |
95 | |