1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef KERNEL_SELECTOR_HPP |
18 | #define KERNEL_SELECTOR_HPP |
19 | |
#include "gen_gemm_kernel_generator.hpp"

#include "kernel_catalog.hpp"
#include "kernel_evaluator.hpp"

#include <algorithm>
#include <array>
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace gpu { |
30 | namespace jit { |
31 | |
32 | // Basic kernel selection API. |
33 | struct MatchParams { |
34 | kcatalog::Selector selector; |
35 | SizeParams sizes; |
36 | char precisionCExt = 0; |
37 | bool ignoreSizes = false; |
38 | int stepping = 0; |
39 | int alignment[3] = {0, 0, 0}; |
40 | kcatalog::string tags, lateTags; |
41 | int unroll[2] = {0, 0}; |
42 | |
43 | MatchParams() {} |
44 | MatchParams(ngen::HW hw, const GEMMProblem &problem); |
45 | |
46 | private: |
47 | std::array<char, 32> temp; |
48 | }; |
49 | |
50 | const kcatalog::Entry *select(const kcatalog::Catalog &catalog, |
51 | const MatchParams &pattern, const EvaluateParams &eparams, |
52 | EvaluateAuxOutput &aux); |
53 | const kcatalog::Entry *select(const kcatalog::Catalog &catalog, int npatterns, |
54 | const MatchParams *patterns, const EvaluateParams &eparams, |
55 | EvaluateAuxOutput &aux); |
56 | |
57 | // Extended API for iterating over all matching kernels. |
58 | bool matches(const kcatalog::Entry &e, const MatchParams &pattern); |
59 | |
60 | const kcatalog::Entry *lower_bound( |
61 | const kcatalog::Catalog &catalog, const kcatalog::Selector &selector); |
62 | const kcatalog::Entry *upper_bound( |
63 | const kcatalog::Catalog &catalog, const kcatalog::Selector &selector); |
64 | |
65 | class EntryIterator { |
66 | public: |
67 | EntryIterator( |
68 | const kcatalog::Catalog &catalog_, const MatchParams &pattern_) |
69 | : catalog(catalog_), pattern(pattern_) { |
70 | begin = lower_bound(catalog_, pattern_.selector); |
71 | end = upper_bound(catalog_, pattern_.selector); |
72 | current = begin; |
73 | findNextMatch(); |
74 | } |
75 | |
76 | operator bool() const { return current < end; } |
77 | |
78 | EntryIterator &operator++() { |
79 | ++current; |
80 | findNextMatch(); |
81 | return *this; |
82 | } |
83 | |
84 | EntryIterator operator++(int) { |
85 | auto old = *this; |
86 | operator++(); |
87 | return old; |
88 | } |
89 | |
90 | const kcatalog::Entry &operator*() const { return *current; } |
91 | const kcatalog::Entry *operator->() const { return &*current; } |
92 | |
93 | friend bool operator==(const EntryIterator &i1, const EntryIterator &i2) { |
94 | return (i1.current == i2.current); |
95 | } |
96 | friend bool operator!=(const EntryIterator &i1, const EntryIterator &i2) { |
97 | return !(i1 == i2); |
98 | } |
99 | |
100 | protected: |
101 | const kcatalog::Catalog &catalog; |
102 | MatchParams pattern; |
103 | const kcatalog::Entry *begin, *end, *current; |
104 | |
105 | void findNextMatch() { |
106 | for (; current < end; current++) { |
107 | if (matches(*current, pattern)) break; |
108 | } |
109 | } |
110 | }; |
111 | |
112 | inline EntryIterator match( |
113 | const kcatalog::Catalog &catalog, const MatchParams &pattern) { |
114 | return EntryIterator(catalog, pattern); |
115 | } |
116 | |
117 | } // namespace jit |
118 | } // namespace gpu |
119 | } // namespace impl |
120 | } // namespace dnnl |
121 | |
122 | #endif /* header guard */ |
123 | |