1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef UTILS_CFG_HPP
18#define UTILS_CFG_HPP
19
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <map>
#include <vector>

#include "oneapi/dnnl/dnnl_types.h"

#include "common.hpp"
27
// `base_cfg_t` is a base class for defining filling configurations across
// drivers. A driver-level `cfg_t` object defines a constructor which takes a
// `prb_t` object and a list of data kinds. It fills the internal `cfg_entry_`
// vector with entries, each of which should provide:
// * The data kind it was created for (used only to look up the corresponding
//   `cfg_entry_t` element).
// * The original data type, since the final data type may be altered by the
//   fpmath-mode value, or may be the `sum_dt` value instead of `dst_dt`.
// * The data type adjusted for fpmath mode or `sum_dt`.
// * A `cfg_map_t` map of ranges for each data type.
//
// Based on these inputs, the `cfg_t` public interface may provide all necessary
// information, such as the range for a given kind or a density adjustment.
41struct base_cfg_t {
42 struct cfg_entry_t {
43 // Supplies min and max ranges for filling for a given data type.
44 struct cfg_range_t {
45 int range_min;
46 int range_max;
47 };
48
49 using cfg_map_t = std::map<dnnl_data_type_t, cfg_range_t>;
50
51 cfg_entry_t(data_kind_t dk, dnnl_data_type_t orig_dt,
52 dnnl_data_type_t dt, const cfg_map_t &cfg_map)
53 : data_kind_(dk)
54 , orig_data_type_(orig_dt)
55 , data_type_(dt)
56 , cfg_map_(cfg_map) {}
57
58 int get_range_min() const { return get_cfg_range().range_min; }
59 int get_range_max() const { return get_cfg_range().range_max; }
60 int get_range_abs_max() const {
61 return std::max(abs(get_range_min()), abs(get_range_max()));
62 }
63
64 dnnl_data_type_t get_orig_dt() const { return orig_data_type_; }
65 dnnl_data_type_t get_dt() const { return data_type_; }
66 data_kind_t get_dk() const { return data_kind_; }
67
68 private:
69 data_kind_t data_kind_; // For searching elements in base_cfg_t.
70 dnnl_data_type_t orig_data_type_;
71 dnnl_data_type_t data_type_;
72 const cfg_map_t &cfg_map_;
73
74 const cfg_range_t &get_cfg_range() const {
75 const auto it = cfg_map_.find(data_type_);
76 if (it != cfg_map_.end()) return (*it).second;
77 assert(!"unexpected");
78 static cfg_range_t dummy;
79 return dummy;
80 }
81 };
82
83 int get_range_min(data_kind_t dk) const {
84 return cfg_entry_[dk].get_range_min();
85 }
86 int get_range_max(data_kind_t dk) const {
87 return cfg_entry_[dk].get_range_max();
88 }
89
90 dnnl_data_type_t get_orig_dt(data_kind_t dk) const {
91 return cfg_entry_[dk].get_orig_dt();
92 }
93 dnnl_data_type_t get_dt(data_kind_t dk) const {
94 return cfg_entry_[dk].get_dt();
95 }
96
97 // This type allows to differentiate density in filling functions by certain
98 // criteria. Members used in each driver may be different.
99 struct density_args_t {
100 // Data kind like SRC, WEI, DST, etc.
101 data_kind_t data_kind;
102 // Number of accumulators in the chain. Longer chains to be more sparse.
103 int64_t n_acc;
104 };
105
106 virtual float get_density(const density_args_t &density_args) const {
107 return 1.f;
108 }
109
110protected:
111 std::vector<cfg_entry_t> cfg_entry_;
112
113 const cfg_entry_t &operator[](data_kind_t kind) const {
114 for (const auto &e : cfg_entry_) {
115 if (e.get_dk() == kind) return e;
116 }
117 assert(!"unexpected data kind");
118 return cfg_entry_[0];
119 }
120};
121
122#endif
123