1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef KERNEL_SELECTOR_HPP
18#define KERNEL_SELECTOR_HPP
19
#include "gen_gemm_kernel_generator.hpp"

#include "kernel_catalog.hpp"
#include "kernel_evaluator.hpp"

#include <algorithm>
#include <array>
26
27namespace dnnl {
28namespace impl {
29namespace gpu {
30namespace jit {
31
32// Basic kernel selection API.
33struct MatchParams {
34 kcatalog::Selector selector;
35 SizeParams sizes;
36 char precisionCExt = 0;
37 bool ignoreSizes = false;
38 int stepping = 0;
39 int alignment[3] = {0, 0, 0};
40 kcatalog::string tags, lateTags;
41 int unroll[2] = {0, 0};
42
43 MatchParams() {}
44 MatchParams(ngen::HW hw, const GEMMProblem &problem);
45
46private:
47 std::array<char, 32> temp;
48};
49
// Selects a kernel entry from `catalog` for a single match pattern.
//
// Returns a pointer to a catalog entry.
// NOTE(review): the definition is not in this header. Presumably the entries
// matching `pattern` are ranked using `eparams`, auxiliary evaluation results
// are written to `aux`, and nullptr is returned when nothing matches --
// confirm against the implementation.
const kcatalog::Entry *select(const kcatalog::Catalog &catalog,
 const MatchParams &pattern, const EvaluateParams &eparams,
 EvaluateAuxOutput &aux);
// Overload of select() that takes several match patterns at once.
//
// NOTE(review): presumably `patterns` points to an array of `npatterns`
// elements and the result is chosen across the entries matching any of
// them -- confirm against the implementation, which is not in this header.
const kcatalog::Entry *select(const kcatalog::Catalog &catalog, int npatterns,
 const MatchParams *patterns, const EvaluateParams &eparams,
 EvaluateAuxOutput &aux);
56
57// Extended API for iterating over all matching kernels.
// Predicate testing a single catalog entry `e` against `pattern`.
// EntryIterator uses it to skip entries that do not satisfy the pattern.
bool matches(const kcatalog::Entry &e, const MatchParams &pattern);
59
// Returns a pointer to the first candidate entry in `catalog` for `selector`.
// EntryIterator uses the result as the start of the range it scans.
// NOTE(review): presumably analogous to std::lower_bound over catalog entries
// ordered by selector -- confirm against the implementation.
const kcatalog::Entry *lower_bound(
 const kcatalog::Catalog &catalog, const kcatalog::Selector &selector);
// Returns a pointer one past the last candidate entry in `catalog` for
// `selector`. EntryIterator uses the result as the (exclusive) end of the
// range it scans.
// NOTE(review): presumably analogous to std::upper_bound over catalog entries
// ordered by selector -- confirm against the implementation.
const kcatalog::Entry *upper_bound(
 const kcatalog::Catalog &catalog, const kcatalog::Selector &selector);
64
65class EntryIterator {
66public:
67 EntryIterator(
68 const kcatalog::Catalog &catalog_, const MatchParams &pattern_)
69 : catalog(catalog_), pattern(pattern_) {
70 begin = lower_bound(catalog_, pattern_.selector);
71 end = upper_bound(catalog_, pattern_.selector);
72 current = begin;
73 findNextMatch();
74 }
75
76 operator bool() const { return current < end; }
77
78 EntryIterator &operator++() {
79 ++current;
80 findNextMatch();
81 return *this;
82 }
83
84 EntryIterator operator++(int) {
85 auto old = *this;
86 operator++();
87 return old;
88 }
89
90 const kcatalog::Entry &operator*() const { return *current; }
91 const kcatalog::Entry *operator->() const { return &*current; }
92
93 friend bool operator==(const EntryIterator &i1, const EntryIterator &i2) {
94 return (i1.current == i2.current);
95 }
96 friend bool operator!=(const EntryIterator &i1, const EntryIterator &i2) {
97 return !(i1 == i2);
98 }
99
100protected:
101 const kcatalog::Catalog &catalog;
102 MatchParams pattern;
103 const kcatalog::Entry *begin, *end, *current;
104
105 void findNextMatch() {
106 for (; current < end; current++) {
107 if (matches(*current, pattern)) break;
108 }
109 }
110};
111
112inline EntryIterator match(
113 const kcatalog::Catalog &catalog, const MatchParams &pattern) {
114 return EntryIterator(catalog, pattern);
115}
116
117} // namespace jit
118} // namespace gpu
119} // namespace impl
120} // namespace dnnl
121
122#endif /* header guard */
123