1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_CONV_GRF_USAGE_HPP |
18 | #define GPU_JIT_CONV_GRF_USAGE_HPP |
19 | |
20 | #include <unordered_map> |
21 | |
22 | #include "gpu/jit/conv/config.hpp" |
23 | #include "gpu/jit/ir/ir.hpp" |
24 | #include "gpu/jit/utils/utils.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace gpu { |
29 | namespace jit { |
30 | |
/// Categories to which GRF (register) usage is attributed.
///
/// Enumerators use implicit consecutive values starting at 0, and `_last` is
/// a sentinel that must remain the final enumerator: all_grf_usage_labels()
/// enumerates every value in [0, _last).
enum class grf_usage_label_t {
    unknown,
    gmem_load,
    out_buf,
    reorder,
    reserved,
    // NOTE(review): this slot was an empty enumerator (a bare `,`), which does
    // not compile. Restored as `reused_headers`; confirm the name matches the
    // to_string() implementation in the corresponding .cpp file.
    reused_headers,
    slm_load,
    slm_store,
    tmp_vars,
    zero_points,
    _last,
};

/// Returns every real label (excluding the `_last` sentinel) in declaration
/// order.
inline std::vector<grf_usage_label_t> all_grf_usage_labels() {
    const int n_labels = static_cast<int>(grf_usage_label_t::_last);
    std::vector<grf_usage_label_t> ret;
    ret.reserve(n_labels);
    for (int i = 0; i < n_labels; i++) {
        ret.push_back(static_cast<grf_usage_label_t>(i));
    }
    return ret;
}
52 | |
// Converts a GRF usage label to a string. Defined out of line (not in this
// header).
std::string to_string(grf_usage_label_t label);
// Writes a GRF usage label to `out` and returns `out`. Defined out of line;
// presumably prints the same text as to_string() -- confirm in the .cpp.
std::ostream &operator<<(std::ostream &out, grf_usage_label_t label);
55 | |
56 | class grf_buf_usage_t { |
57 | public: |
58 | grf_buf_usage_t(int grf_size) : grf_size_(grf_size) {} |
59 | |
60 | const object_set_t<expr_t> &bufs() const { return bufs_; } |
61 | |
62 | std::vector<expr_t> sorted_bufs() const { |
63 | std::vector<expr_t> ret(bufs_.begin(), bufs_.end()); |
64 | std::sort(ret.begin(), ret.end(), [](const expr_t &a, const expr_t &b) { |
65 | return a.as<var_t>().name < b.as<var_t>().name; |
66 | }); |
67 | return ret; |
68 | } |
69 | |
70 | bool has(const expr_t &buf) const { return bufs_.find(buf) != bufs_.end(); } |
71 | |
72 | grf_usage_label_t get_label(const expr_t &buf) const { |
73 | auto it = buf_labels_.find(buf); |
74 | ir_assert(it != buf_labels_.end()) << "Buffer not found: " << buf; |
75 | return it->second; |
76 | } |
77 | |
78 | int get_size(const expr_t &buf) const { |
79 | auto it = buf_sizes_.find(buf); |
80 | ir_assert(it != buf_sizes_.end()) << "Buffer not found: " << buf; |
81 | return it->second; |
82 | } |
83 | |
84 | void set_label(const expr_t &buf, grf_usage_label_t label) { |
85 | buf_labels_[buf] = label; |
86 | } |
87 | |
88 | void add(const expr_t &buf, int size, grf_usage_label_t label) { |
89 | bufs_.insert(buf); |
90 | buf_labels_.emplace(buf, label); |
91 | buf_sizes_.emplace(buf, size); |
92 | } |
93 | |
94 | void remove(const expr_t &buf) { |
95 | bufs_.erase(buf); |
96 | buf_labels_.erase(buf); |
97 | buf_sizes_.erase(buf); |
98 | } |
99 | |
100 | int total_regs(grf_usage_label_t label) const { |
101 | int ret = 0; |
102 | for (auto &kv : buf_labels_) { |
103 | if (kv.second != label) continue; |
104 | ret += utils::div_up(buf_sizes_.at(kv.first), grf_size_); |
105 | } |
106 | return ret; |
107 | } |
108 | |
109 | std::string str() const; |
110 | |
111 | private: |
112 | int grf_size_; |
113 | object_set_t<expr_t> bufs_; |
114 | object_map_t<expr_t, int> buf_sizes_; |
115 | object_map_t<expr_t, grf_usage_label_t> buf_labels_; |
116 | }; |
117 | |
118 | inline std::ostream &operator<<( |
119 | std::ostream &out, const grf_buf_usage_t &usage) { |
120 | out << usage.str(); |
121 | return out; |
122 | } |
123 | |
124 | class grf_usage_t { |
125 | public: |
126 | grf_usage_t(int grf_size = 0) : grf_size_(grf_size), buf_usage_(grf_size) { |
127 | for (auto label : all_grf_usage_labels()) { |
128 | regs_.emplace(label, 0); |
129 | } |
130 | } |
131 | |
132 | bool is_empty() const { |
133 | for (auto &kv : regs_) |
134 | if (kv.second != 0) return false; |
135 | return true; |
136 | } |
137 | |
138 | void add(grf_usage_label_t label, int regs) { regs_[label] += regs; } |
139 | |
140 | void add(const expr_t &buf, int size, grf_usage_label_t label) { |
141 | add(label, utils::div_up(size, grf_size_)); |
142 | buf_usage_.add(buf, size, label); |
143 | } |
144 | |
145 | void add(const grf_buf_usage_t &buf_usage) { |
146 | for (auto &buf : buf_usage.bufs()) { |
147 | add(buf, buf_usage.get_size(buf), buf_usage.get_label(buf)); |
148 | } |
149 | } |
150 | |
151 | int get(grf_usage_label_t label) const { return regs_.at(label); } |
152 | |
153 | int total() const { |
154 | int ret = 0; |
155 | for (auto &kv : regs_) |
156 | ret += kv.second; |
157 | return ret; |
158 | } |
159 | |
160 | const grf_buf_usage_t &buf_usage() const { return buf_usage_; } |
161 | |
162 | std::string str() const; |
163 | |
164 | private: |
165 | int grf_size_; |
166 | |
167 | using label_hash_t = ir_utils::enum_hash_t<grf_usage_label_t>; |
168 | std::unordered_map<grf_usage_label_t, int, label_hash_t> regs_; |
169 | grf_buf_usage_t buf_usage_; |
170 | }; |
171 | |
172 | inline std::ostream &operator<<(std::ostream &out, const grf_usage_t &usage) { |
173 | out << usage.str(); |
174 | return out; |
175 | } |
176 | |
// Returns an estimated GRF usage for the convolution configuration `cfg`.
// Defined out of line.
grf_usage_t estimate_grf_usage(const conv_config_t &cfg);
// Returns GRF usage derived from the IR statement `body`, using `grf_size` as
// the size of one register. Defined out of line.
grf_usage_t get_grf_usage(const stmt_t &body, int grf_size);

// Checks the GRF usage of `body` for `cfg`. Defined out of line.
// NOTE(review): `external_usage` presumably counts registers used outside
// `body` -- confirm the unit and exact semantics in the .cpp.
void verify_grf_usage(
        const conv_config_t &cfg, const stmt_t &body, int external_usage);
182 | |
183 | } // namespace jit |
184 | } // namespace gpu |
185 | } // namespace impl |
186 | } // namespace dnnl |
187 | |
188 | #endif |
189 | |