1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CONV_GRF_USAGE_HPP
18#define GPU_JIT_CONV_GRF_USAGE_HPP
19
#include <algorithm>
#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>

#include "gpu/jit/conv/config.hpp"
#include "gpu/jit/ir/ir.hpp"
#include "gpu/jit/utils/utils.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace gpu {
29namespace jit {
30
// Categories used to attribute GRF register usage to its source. Every buffer
// tracked by grf_buf_usage_t / grf_usage_t carries exactly one of these labels.
// NOTE(review): the meaning of each label is inferred from its name only (e.g.
// buffers for global/SLM loads and stores, reorder temporaries) -- confirm
// against the code that assigns the labels.
enum class grf_usage_label_t {
    unknown,
    gmem_load,
    out_buf,
    reorder,
    reserved,
    reused_headers,
    slm_load,
    slm_store,
    tmp_vars,
    zero_points,
    // Sentinel, not a real label: its integer value equals the number of
    // labels above and serves as the exclusive upper bound when all labels
    // are enumerated (see all_grf_usage_labels()).
    _last,
};
44
45inline std::vector<grf_usage_label_t> all_grf_usage_labels() {
46 std::vector<grf_usage_label_t> ret;
47 for (int i = 0; i < (int)grf_usage_label_t::_last; i++) {
48 ret.push_back((grf_usage_label_t)i);
49 }
50 return ret;
51}
52
// String form of `label`. Defined out of line (implementation not visible in
// this header).
std::string to_string(grf_usage_label_t label);
// Writes a textual form of `label` to `out`. Defined out of line.
std::ostream &operator<<(std::ostream &out, grf_usage_label_t label);
55
56class grf_buf_usage_t {
57public:
58 grf_buf_usage_t(int grf_size) : grf_size_(grf_size) {}
59
60 const object_set_t<expr_t> &bufs() const { return bufs_; }
61
62 std::vector<expr_t> sorted_bufs() const {
63 std::vector<expr_t> ret(bufs_.begin(), bufs_.end());
64 std::sort(ret.begin(), ret.end(), [](const expr_t &a, const expr_t &b) {
65 return a.as<var_t>().name < b.as<var_t>().name;
66 });
67 return ret;
68 }
69
70 bool has(const expr_t &buf) const { return bufs_.find(buf) != bufs_.end(); }
71
72 grf_usage_label_t get_label(const expr_t &buf) const {
73 auto it = buf_labels_.find(buf);
74 ir_assert(it != buf_labels_.end()) << "Buffer not found: " << buf;
75 return it->second;
76 }
77
78 int get_size(const expr_t &buf) const {
79 auto it = buf_sizes_.find(buf);
80 ir_assert(it != buf_sizes_.end()) << "Buffer not found: " << buf;
81 return it->second;
82 }
83
84 void set_label(const expr_t &buf, grf_usage_label_t label) {
85 buf_labels_[buf] = label;
86 }
87
88 void add(const expr_t &buf, int size, grf_usage_label_t label) {
89 bufs_.insert(buf);
90 buf_labels_.emplace(buf, label);
91 buf_sizes_.emplace(buf, size);
92 }
93
94 void remove(const expr_t &buf) {
95 bufs_.erase(buf);
96 buf_labels_.erase(buf);
97 buf_sizes_.erase(buf);
98 }
99
100 int total_regs(grf_usage_label_t label) const {
101 int ret = 0;
102 for (auto &kv : buf_labels_) {
103 if (kv.second != label) continue;
104 ret += utils::div_up(buf_sizes_.at(kv.first), grf_size_);
105 }
106 return ret;
107 }
108
109 std::string str() const;
110
111private:
112 int grf_size_;
113 object_set_t<expr_t> bufs_;
114 object_map_t<expr_t, int> buf_sizes_;
115 object_map_t<expr_t, grf_usage_label_t> buf_labels_;
116};
117
118inline std::ostream &operator<<(
119 std::ostream &out, const grf_buf_usage_t &usage) {
120 out << usage.str();
121 return out;
122}
123
124class grf_usage_t {
125public:
126 grf_usage_t(int grf_size = 0) : grf_size_(grf_size), buf_usage_(grf_size) {
127 for (auto label : all_grf_usage_labels()) {
128 regs_.emplace(label, 0);
129 }
130 }
131
132 bool is_empty() const {
133 for (auto &kv : regs_)
134 if (kv.second != 0) return false;
135 return true;
136 }
137
138 void add(grf_usage_label_t label, int regs) { regs_[label] += regs; }
139
140 void add(const expr_t &buf, int size, grf_usage_label_t label) {
141 add(label, utils::div_up(size, grf_size_));
142 buf_usage_.add(buf, size, label);
143 }
144
145 void add(const grf_buf_usage_t &buf_usage) {
146 for (auto &buf : buf_usage.bufs()) {
147 add(buf, buf_usage.get_size(buf), buf_usage.get_label(buf));
148 }
149 }
150
151 int get(grf_usage_label_t label) const { return regs_.at(label); }
152
153 int total() const {
154 int ret = 0;
155 for (auto &kv : regs_)
156 ret += kv.second;
157 return ret;
158 }
159
160 const grf_buf_usage_t &buf_usage() const { return buf_usage_; }
161
162 std::string str() const;
163
164private:
165 int grf_size_;
166
167 using label_hash_t = ir_utils::enum_hash_t<grf_usage_label_t>;
168 std::unordered_map<grf_usage_label_t, int, label_hash_t> regs_;
169 grf_buf_usage_t buf_usage_;
170};
171
172inline std::ostream &operator<<(std::ostream &out, const grf_usage_t &usage) {
173 out << usage.str();
174 return out;
175}
176
// Produces a grf_usage_t from a convolution configuration alone. Defined out
// of line. NOTE(review): presumably an a-priori estimate made before the
// kernel IR exists -- confirm.
grf_usage_t estimate_grf_usage(const conv_config_t &cfg);
// Produces a grf_usage_t for the IR statement `body`, using `grf_size` as
// the size of one GRF register. Defined out of line.
grf_usage_t get_grf_usage(const stmt_t &body, int grf_size);

// Checks the GRF usage of `body` built for `cfg`, taking `external_usage`
// into account. Defined out of line.
// NOTE(review): the unit of `external_usage` and what happens on a failed
// check are not visible in this header -- confirm in the implementation.
void verify_grf_usage(
        const conv_config_t &cfg, const stmt_t &body, int external_usage);
182
183} // namespace jit
184} // namespace gpu
185} // namespace impl
186} // namespace dnnl
187
188#endif
189