1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CONV_CONFIG_LOOKUP_TABLE_HPP
18#define GPU_JIT_CONV_CONFIG_LOOKUP_TABLE_HPP
19
20#include <string>
21#include <vector>
22#include <unordered_map>
23
24#include "common/c_types_map.hpp"
25#include "gpu/jit/ir/core.hpp"
26#include "gpu/jit/ir/hw_config.hpp"
27#include "gpu/jit/ngen/ngen_core.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace gpu {
32namespace jit {
33
34class int_filter_t {
35public:
36 int_filter_t() = default;
37 int_filter_t(const std::string &s);
38
39 bool matches(int value) const;
40
41private:
42 int value_;
43 op_kind_t cmp_op_;
44};
45
46class type_filter_t {
47public:
48 type_filter_t() = default;
49
50 type_filter_t(const std::string &s);
51
52 bool matches(const std::vector<data_type_t> &values) const;
53
54private:
55 bool try_parse(
56 const std::string &s, size_t &pos, const std::string &pattern);
57
58 static std::vector<std::string> &all_patterns();
59
60 std::vector<std::string> patterns_;
61};
62
63class fpmath_filter_t {
64public:
65 fpmath_filter_t() = default;
66
67 bool matches(fpmath_mode_t mode) const { return mode == filter_; }
68
69private:
70 fpmath_mode_t filter_ = fpmath_mode::strict;
71};
72
73class conv_problem_t;
74
75class conv_problem_filter_t {
76public:
77 using key_t = std::string;
78
79 conv_problem_filter_t(const std::string &s);
80
81 key_t key() const { return desc_; }
82
83 bool matches(const conv_problem_t &prb, const hw_config_t &hw_cfg) const;
84
85private:
86 bool matches_dir(const conv_problem_t &prb) const;
87
88 bool matches_desc(const conv_problem_t &prb) const;
89
90 bool matches_post_ops(const conv_problem_t &prb) const;
91
92 std::string dir_;
93 type_filter_t type_filter_;
94 fpmath_filter_t fpmath_filter_;
95 int_filter_t mb_filter_;
96 std::string desc_;
97 std::string post_ops_;
98 ngen::HW hw_;
99};
100
101class conv_config_t;
102
103class conv_config_lookup_table_t {
104public:
105 conv_config_lookup_table_t();
106
107 const char *find(const conv_config_t &cfg) const;
108
109private:
110 struct entry_t {
111 conv_problem_filter_t filter;
112 const char *s_params;
113 };
114
115 void add(const char *s_prb, const char *s_params);
116
117 using key_t = conv_problem_filter_t::key_t;
118 std::unordered_map<key_t, std::vector<entry_t>> map_;
119};
120
121} // namespace jit
122} // namespace gpu
123} // namespace impl
124} // namespace dnnl
125
126#endif
127