/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 | |
17 | #ifndef GPU_JIT_CODEGEN_REG_BUF_HPP |
18 | #define GPU_JIT_CODEGEN_REG_BUF_HPP |
19 | |
20 | #include <vector> |
21 | #include <unordered_set> |
22 | |
23 | #include "gpu/jit/codegen/register_allocator.hpp" |
24 | #include "gpu/jit/ir/grf_permutation.hpp" |
25 | #include "gpu/jit/ngen/ngen.hpp" |
26 | #include "gpu/jit/utils/utils.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace gpu { |
31 | namespace jit { |
32 | |
33 | // Represents a register buffer allocated in blocks. |
34 | class reg_buf_t { |
35 | public: |
36 | reg_buf_t() = default; |
37 | |
38 | reg_buf_t(ngen::HW hw, const ngen::GRFRange &range) |
39 | : hw_(hw) |
40 | , block_regs_(range.getLen()) |
41 | , block_bases_({range.getBase()}) {} |
42 | |
43 | reg_buf_t(ngen::HW hw, int block_regs, const std::vector<int> &block_bases) |
44 | : hw_(hw), block_regs_(block_regs), block_bases_(block_bases) {} |
45 | |
46 | bool is_empty() const { return block_bases_.empty(); } |
47 | |
48 | ngen::HW hw() const { return hw_; } |
49 | |
50 | bool with_permute() const { return !grf_perm_.is_empty(); } |
51 | |
52 | int base(int reg_idx, bool apply_permute = true) const { |
53 | if (apply_permute && !grf_perm_.is_empty()) |
54 | reg_idx = grf_perm_.map(reg_idx); |
55 | ir_assert(reg_idx >= 0 && reg_idx < regs()) |
56 | << "Invalid index: " << reg_idx; |
57 | int block_idx = reg_idx / block_regs_; |
58 | return block_bases_[block_idx] + (reg_idx % block_regs_); |
59 | } |
60 | |
61 | int blocks() const { return int(block_bases_.size()); } |
62 | |
63 | int block_regs() const { return block_regs_; } |
64 | |
65 | int regs() const { return blocks() * block_regs(); } |
66 | |
67 | void set_grf_permutation(const grf_permutation_t &grf_perm) { |
68 | #if !defined(NDEBUG) || defined(GEN_CONV_DEBUG) |
69 | // Check that it's a valid permutation. |
70 | std::unordered_set<int> seen; |
71 | for (int i = 0; i < regs(); i++) { |
72 | int i_mapped = grf_perm.map(i); |
73 | ir_assert(i_mapped >= 0 && i_mapped < regs()); |
74 | seen.insert(i_mapped); |
75 | } |
76 | ir_assert(int(seen.size()) == regs()) << "Invalid permutation." ; |
77 | #endif |
78 | grf_perm_ = grf_perm; |
79 | } |
80 | |
81 | bool operator==(const reg_buf_t &other) const { |
82 | if (hw() != other.hw()) return false; |
83 | if (block_regs() != other.block_regs()) return false; |
84 | if (blocks() != other.blocks()) return false; |
85 | for (int i = 0; i < blocks(); i++) { |
86 | if (block_bases_[i] != other.block_bases_[i]) return false; |
87 | } |
88 | if (grf_perm_ != other.grf_perm_) return false; |
89 | return true; |
90 | } |
91 | |
92 | void claim(reg_allocator_t &ra) const { |
93 | for (int i = 0; i < blocks(); i++) { |
94 | ngen::GRFRange range(block_bases_[i], block_regs_); |
95 | ra.claim(range); |
96 | } |
97 | } |
98 | |
99 | void release(reg_allocator_t &ra) const { |
100 | for (int i = 0; i < blocks(); i++) { |
101 | ngen::GRFRange range(block_bases_[i], block_regs_); |
102 | ra.safeRelease(range); |
103 | } |
104 | } |
105 | |
106 | private: |
107 | ngen::HW hw_; |
108 | int block_regs_; |
109 | std::vector<int> block_bases_; |
110 | grf_permutation_t grf_perm_; |
111 | }; |
112 | |
113 | // ngen::RegData wrapper attached to a register buffer. |
114 | class reg_buf_data_t { |
115 | public: |
116 | reg_buf_data_t() = default; |
117 | |
118 | reg_buf_data_t(const reg_buf_t ®_buf) |
119 | : reg_buf_(std::make_shared<reg_buf_t>(reg_buf)) |
120 | , rd_(ngen::GRF(reg_buf_->base(0))) {} |
121 | |
122 | reg_buf_data_t(const reg_buf_t ®_buf, const ngen::RegData &rd) |
123 | : reg_buf_(std::make_shared<reg_buf_t>(reg_buf)), rd_(rd) {} |
124 | |
125 | reg_buf_data_t(ngen::HW hw, const ngen::Subregister &sub) |
126 | : reg_buf_(std::make_shared<reg_buf_t>( |
127 | hw, ngen::GRFRange(sub.getBase(), 1))) |
128 | , rd_(sub) {} |
129 | |
130 | bool is_empty() const { return !reg_buf_; } |
131 | |
132 | ngen::HW hw() const { return reg_buf_->hw(); } |
133 | |
134 | ngen::DataType type() const { return rd_.getType(); } |
135 | |
136 | int base() const { return rd_.getBase(); } |
137 | |
138 | int byte_offset() const { return rd_.getByteOffset(); } |
139 | |
140 | int offset() const { return rd_.getOffset(); } |
141 | |
142 | int hs() const { return rd_.getHS(); } |
143 | |
144 | const ngen::RegData ®_data() const { return rd_; } |
145 | |
146 | operator ngen::RegData() const { return rd_; } |
147 | |
148 | bool check_bounds( |
149 | int off_bytes, int len_bytes, bool is_dense = false) const { |
150 | ir_assert(off_bytes >= 0); |
151 | ir_assert(len_bytes >= 0); |
152 | if (len_bytes == 0) return true; |
153 | |
154 | int grf_size = ngen::GRF::bytes(hw()); |
155 | int beg_off = (byte_offset() + off_bytes) / grf_size; |
156 | int end_off = (byte_offset() + off_bytes + len_bytes - 1) / grf_size; |
157 | |
158 | // Check for out of bound accesses. |
159 | if (get_grf_buf_index() + end_off >= reg_buf_->regs()) return false; |
160 | |
161 | // Check if access is dense. |
162 | if (is_dense) { |
163 | int base0 = get_grf_base(beg_off); |
164 | for (int i = beg_off + 1; i < end_off + 1; i++) { |
165 | if (get_grf_base(i) != base0 + i) return false; |
166 | } |
167 | } |
168 | return true; |
169 | } |
170 | |
171 | bool is_dense(int bytes) const { |
172 | ir_assert(check_bounds(0, bytes)) << "Invalid access." ; |
173 | return check_bounds(0, bytes, /*is_dense=*/true); |
174 | } |
175 | |
176 | bool operator==(const reg_buf_data_t &other) const { |
177 | return (*reg_buf_ == *other.reg_buf_) && (rd_ == other.rd_); |
178 | } |
179 | |
180 | bool operator!=(const reg_buf_data_t &other) const { |
181 | return !operator==(other); |
182 | } |
183 | |
184 | reg_buf_data_t reinterpret(ngen::DataType new_type) const { |
185 | if (ngen::getBytes(new_type) == ngen::getBytes(type())) { |
186 | auto ret = *this; |
187 | ret.rd_.setType(new_type); |
188 | return ret; |
189 | } |
190 | if (rd_.getHS() == 0) return format(0, new_type, 1, 0); |
191 | ir_error_not_expected() << "Can't reinterpret." ; |
192 | return reg_buf_data_t(); |
193 | } |
194 | |
195 | ngen::Subregister subregister(int off_bytes, |
196 | ngen::DataType type = ngen::DataType::invalid) const { |
197 | ir_assert(check_bounds(off_bytes, 1)) << "Invalid access." ; |
198 | if (type == ngen::DataType::invalid) type = rd_.getType(); |
199 | auto rd = format(off_bytes, type, 1, 0).reg_data(); |
200 | return ngen::Subregister(rd, rd.getOffset(), rd.getType()); |
201 | } |
202 | |
203 | ngen::Subregister subregister(int off, int width, int stride_bytes, |
204 | ngen::DataType type = ngen::DataType::invalid) const { |
205 | if (type == ngen::DataType::invalid) type = rd_.getType(); |
206 | int off_bytes = off * stride_bytes; |
207 | |
208 | ir_assert(check_bounds(off_bytes, stride_bytes * (width - 1))) |
209 | << "Invalid access." ; |
210 | |
211 | auto rd = format(off_bytes, type, 1, 0).reg_data(); |
212 | return ngen::Subregister(rd, rd.getOffset(), rd.getType()); |
213 | } |
214 | |
215 | reg_buf_data_t format(int off_bytes, |
216 | ngen::DataType type = ngen::DataType::invalid, int width = 1, |
217 | int hstride = 1) const { |
218 | if (type == ngen::DataType::invalid) type = rd_.getType(); |
219 | auto grf_size = ngen::GRF::bytes(hw()); |
220 | auto new_off = rd_.getByteOffset() + off_bytes; |
221 | auto new_grf_off = new_off % grf_size; |
222 | auto type_size = ngen::getBytes(type); |
223 | auto grf = get_grf(new_off / grf_size).retype(type); |
224 | |
225 | ir_assert(new_grf_off % type_size == 0); |
226 | |
227 | if (width == 1) { |
228 | hstride = 0; |
229 | } else if (hstride == 0) { |
230 | ir_assert(width == 1); |
231 | } else { |
232 | int max_width = 32 / type_size; |
233 | width = std::min(width, max_width / hstride); |
234 | width = std::min(width, 16); |
235 | } |
236 | int vstride = width * hstride; |
237 | |
238 | int region_bytes = ((width - 1) * hstride + 1) * type_size; |
239 | ir_assert(check_bounds(off_bytes, region_bytes)) << "Invalid access." ; |
240 | |
241 | auto ret = *this; |
242 | ret.rd_ = grf[new_grf_off / type_size](vstride, width, hstride); |
243 | return ret; |
244 | } |
245 | |
246 | reg_buf_data_t unpermute() const { |
247 | int idx = get_grf_buf_index(); |
248 | int base = reg_buf_->base(idx, /*apply_permute=*/false); |
249 | |
250 | auto ret = *this; |
251 | ret.rd_.setBase(base); |
252 | return ret; |
253 | } |
254 | |
255 | private: |
256 | ngen::GRF get_grf(int off_regs) const { |
257 | return ngen::GRF(get_grf_base(off_regs)); |
258 | } |
259 | |
260 | int get_grf_base(int off_regs) const { |
261 | int idx = get_grf_buf_index(); |
262 | return reg_buf_->base(idx + off_regs); |
263 | } |
264 | |
265 | int get_grf_buf_index() const { |
266 | if (reg_buf_->blocks() == 1 && !reg_buf_->with_permute()) { |
267 | return rd_.getBase() - reg_buf_->base(0); |
268 | } |
269 | for (int i = 0; i < reg_buf_->regs(); i++) { |
270 | if (reg_buf_->base(i) == rd_.getBase()) return i; |
271 | } |
272 | ir_error_not_expected(); |
273 | return -1; |
274 | } |
275 | |
276 | std::shared_ptr<reg_buf_t> reg_buf_; |
277 | ngen::RegData rd_; |
278 | }; |
279 | |
280 | } // namespace jit |
281 | } // namespace gpu |
282 | } // namespace impl |
283 | } // namespace dnnl |
284 | |
285 | #endif |
286 | |