1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef UTILS_CFG_HPP |
18 | #define UTILS_CFG_HPP |
19 | |
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <map>
#include <vector>

#include "oneapi/dnnl/dnnl_types.h"

#include "common.hpp"
27 | |
// `base_cfg_t` is a base class for defining filling configurations across
// drivers. A driver-level `cfg_t` defines a constructor that takes a `prb_t`
// object and a list of data kinds. It fills the internal `cfg_entry_` vector,
// and each entry should provide:
// * The data kind it was created for (used only to look up the corresponding
//   `cfg_entry_t` element).
// * The original data type, since the final data type may be altered by the
//   fpmath-mode value, or may be a `sum_dt` value that differs from `dst_dt`.
// * The data type adjusted for fpmath mode or `sum_dt`.
// * A `cfg_map_t` map of ranges for each data type.
//
// Based on these inputs, the `cfg_t` public interface can provide all the
// necessary information, such as the range for a given kind or a density
// adjustment.
41 | struct base_cfg_t { |
42 | struct cfg_entry_t { |
43 | // Supplies min and max ranges for filling for a given data type. |
44 | struct cfg_range_t { |
45 | int range_min; |
46 | int range_max; |
47 | }; |
48 | |
49 | using cfg_map_t = std::map<dnnl_data_type_t, cfg_range_t>; |
50 | |
51 | cfg_entry_t(data_kind_t dk, dnnl_data_type_t orig_dt, |
52 | dnnl_data_type_t dt, const cfg_map_t &cfg_map) |
53 | : data_kind_(dk) |
54 | , orig_data_type_(orig_dt) |
55 | , data_type_(dt) |
56 | , cfg_map_(cfg_map) {} |
57 | |
58 | int get_range_min() const { return get_cfg_range().range_min; } |
59 | int get_range_max() const { return get_cfg_range().range_max; } |
60 | int get_range_abs_max() const { |
61 | return std::max(abs(get_range_min()), abs(get_range_max())); |
62 | } |
63 | |
64 | dnnl_data_type_t get_orig_dt() const { return orig_data_type_; } |
65 | dnnl_data_type_t get_dt() const { return data_type_; } |
66 | data_kind_t get_dk() const { return data_kind_; } |
67 | |
68 | private: |
69 | data_kind_t data_kind_; // For searching elements in base_cfg_t. |
70 | dnnl_data_type_t orig_data_type_; |
71 | dnnl_data_type_t data_type_; |
72 | const cfg_map_t &cfg_map_; |
73 | |
74 | const cfg_range_t &get_cfg_range() const { |
75 | const auto it = cfg_map_.find(data_type_); |
76 | if (it != cfg_map_.end()) return (*it).second; |
77 | assert(!"unexpected" ); |
78 | static cfg_range_t dummy; |
79 | return dummy; |
80 | } |
81 | }; |
82 | |
83 | int get_range_min(data_kind_t dk) const { |
84 | return cfg_entry_[dk].get_range_min(); |
85 | } |
86 | int get_range_max(data_kind_t dk) const { |
87 | return cfg_entry_[dk].get_range_max(); |
88 | } |
89 | |
90 | dnnl_data_type_t get_orig_dt(data_kind_t dk) const { |
91 | return cfg_entry_[dk].get_orig_dt(); |
92 | } |
93 | dnnl_data_type_t get_dt(data_kind_t dk) const { |
94 | return cfg_entry_[dk].get_dt(); |
95 | } |
96 | |
97 | // This type allows to differentiate density in filling functions by certain |
98 | // criteria. Members used in each driver may be different. |
99 | struct density_args_t { |
100 | // Data kind like SRC, WEI, DST, etc. |
101 | data_kind_t data_kind; |
102 | // Number of accumulators in the chain. Longer chains to be more sparse. |
103 | int64_t n_acc; |
104 | }; |
105 | |
106 | virtual float get_density(const density_args_t &density_args) const { |
107 | return 1.f; |
108 | } |
109 | |
110 | protected: |
111 | std::vector<cfg_entry_t> cfg_entry_; |
112 | |
113 | const cfg_entry_t &operator[](data_kind_t kind) const { |
114 | for (const auto &e : cfg_entry_) { |
115 | if (e.get_dk() == kind) return e; |
116 | } |
117 | assert(!"unexpected data kind" ); |
118 | return cfg_entry_[0]; |
119 | } |
120 | }; |
121 | |
122 | #endif |
123 | |