1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CODEGEN_REG_BUF_HPP
18#define GPU_JIT_CODEGEN_REG_BUF_HPP
19
#include <algorithm>
#include <memory>
#include <unordered_set>
#include <vector>

#include "gpu/jit/codegen/register_allocator.hpp"
#include "gpu/jit/ir/grf_permutation.hpp"
#include "gpu/jit/ngen/ngen.hpp"
#include "gpu/jit/utils/utils.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace gpu {
31namespace jit {
32
33// Represents a register buffer allocated in blocks.
34class reg_buf_t {
35public:
36 reg_buf_t() = default;
37
38 reg_buf_t(ngen::HW hw, const ngen::GRFRange &range)
39 : hw_(hw)
40 , block_regs_(range.getLen())
41 , block_bases_({range.getBase()}) {}
42
43 reg_buf_t(ngen::HW hw, int block_regs, const std::vector<int> &block_bases)
44 : hw_(hw), block_regs_(block_regs), block_bases_(block_bases) {}
45
46 bool is_empty() const { return block_bases_.empty(); }
47
48 ngen::HW hw() const { return hw_; }
49
50 bool with_permute() const { return !grf_perm_.is_empty(); }
51
52 int base(int reg_idx, bool apply_permute = true) const {
53 if (apply_permute && !grf_perm_.is_empty())
54 reg_idx = grf_perm_.map(reg_idx);
55 ir_assert(reg_idx >= 0 && reg_idx < regs())
56 << "Invalid index: " << reg_idx;
57 int block_idx = reg_idx / block_regs_;
58 return block_bases_[block_idx] + (reg_idx % block_regs_);
59 }
60
61 int blocks() const { return int(block_bases_.size()); }
62
63 int block_regs() const { return block_regs_; }
64
65 int regs() const { return blocks() * block_regs(); }
66
67 void set_grf_permutation(const grf_permutation_t &grf_perm) {
68#if !defined(NDEBUG) || defined(GEN_CONV_DEBUG)
69 // Check that it's a valid permutation.
70 std::unordered_set<int> seen;
71 for (int i = 0; i < regs(); i++) {
72 int i_mapped = grf_perm.map(i);
73 ir_assert(i_mapped >= 0 && i_mapped < regs());
74 seen.insert(i_mapped);
75 }
76 ir_assert(int(seen.size()) == regs()) << "Invalid permutation.";
77#endif
78 grf_perm_ = grf_perm;
79 }
80
81 bool operator==(const reg_buf_t &other) const {
82 if (hw() != other.hw()) return false;
83 if (block_regs() != other.block_regs()) return false;
84 if (blocks() != other.blocks()) return false;
85 for (int i = 0; i < blocks(); i++) {
86 if (block_bases_[i] != other.block_bases_[i]) return false;
87 }
88 if (grf_perm_ != other.grf_perm_) return false;
89 return true;
90 }
91
92 void claim(reg_allocator_t &ra) const {
93 for (int i = 0; i < blocks(); i++) {
94 ngen::GRFRange range(block_bases_[i], block_regs_);
95 ra.claim(range);
96 }
97 }
98
99 void release(reg_allocator_t &ra) const {
100 for (int i = 0; i < blocks(); i++) {
101 ngen::GRFRange range(block_bases_[i], block_regs_);
102 ra.safeRelease(range);
103 }
104 }
105
106private:
107 ngen::HW hw_;
108 int block_regs_;
109 std::vector<int> block_bases_;
110 grf_permutation_t grf_perm_;
111};
112
113// ngen::RegData wrapper attached to a register buffer.
114class reg_buf_data_t {
115public:
116 reg_buf_data_t() = default;
117
118 reg_buf_data_t(const reg_buf_t &reg_buf)
119 : reg_buf_(std::make_shared<reg_buf_t>(reg_buf))
120 , rd_(ngen::GRF(reg_buf_->base(0))) {}
121
122 reg_buf_data_t(const reg_buf_t &reg_buf, const ngen::RegData &rd)
123 : reg_buf_(std::make_shared<reg_buf_t>(reg_buf)), rd_(rd) {}
124
125 reg_buf_data_t(ngen::HW hw, const ngen::Subregister &sub)
126 : reg_buf_(std::make_shared<reg_buf_t>(
127 hw, ngen::GRFRange(sub.getBase(), 1)))
128 , rd_(sub) {}
129
130 bool is_empty() const { return !reg_buf_; }
131
132 ngen::HW hw() const { return reg_buf_->hw(); }
133
134 ngen::DataType type() const { return rd_.getType(); }
135
136 int base() const { return rd_.getBase(); }
137
138 int byte_offset() const { return rd_.getByteOffset(); }
139
140 int offset() const { return rd_.getOffset(); }
141
142 int hs() const { return rd_.getHS(); }
143
144 const ngen::RegData &reg_data() const { return rd_; }
145
146 operator ngen::RegData() const { return rd_; }
147
148 bool check_bounds(
149 int off_bytes, int len_bytes, bool is_dense = false) const {
150 ir_assert(off_bytes >= 0);
151 ir_assert(len_bytes >= 0);
152 if (len_bytes == 0) return true;
153
154 int grf_size = ngen::GRF::bytes(hw());
155 int beg_off = (byte_offset() + off_bytes) / grf_size;
156 int end_off = (byte_offset() + off_bytes + len_bytes - 1) / grf_size;
157
158 // Check for out of bound accesses.
159 if (get_grf_buf_index() + end_off >= reg_buf_->regs()) return false;
160
161 // Check if access is dense.
162 if (is_dense) {
163 int base0 = get_grf_base(beg_off);
164 for (int i = beg_off + 1; i < end_off + 1; i++) {
165 if (get_grf_base(i) != base0 + i) return false;
166 }
167 }
168 return true;
169 }
170
171 bool is_dense(int bytes) const {
172 ir_assert(check_bounds(0, bytes)) << "Invalid access.";
173 return check_bounds(0, bytes, /*is_dense=*/true);
174 }
175
176 bool operator==(const reg_buf_data_t &other) const {
177 return (*reg_buf_ == *other.reg_buf_) && (rd_ == other.rd_);
178 }
179
180 bool operator!=(const reg_buf_data_t &other) const {
181 return !operator==(other);
182 }
183
184 reg_buf_data_t reinterpret(ngen::DataType new_type) const {
185 if (ngen::getBytes(new_type) == ngen::getBytes(type())) {
186 auto ret = *this;
187 ret.rd_.setType(new_type);
188 return ret;
189 }
190 if (rd_.getHS() == 0) return format(0, new_type, 1, 0);
191 ir_error_not_expected() << "Can't reinterpret.";
192 return reg_buf_data_t();
193 }
194
195 ngen::Subregister subregister(int off_bytes,
196 ngen::DataType type = ngen::DataType::invalid) const {
197 ir_assert(check_bounds(off_bytes, 1)) << "Invalid access.";
198 if (type == ngen::DataType::invalid) type = rd_.getType();
199 auto rd = format(off_bytes, type, 1, 0).reg_data();
200 return ngen::Subregister(rd, rd.getOffset(), rd.getType());
201 }
202
203 ngen::Subregister subregister(int off, int width, int stride_bytes,
204 ngen::DataType type = ngen::DataType::invalid) const {
205 if (type == ngen::DataType::invalid) type = rd_.getType();
206 int off_bytes = off * stride_bytes;
207
208 ir_assert(check_bounds(off_bytes, stride_bytes * (width - 1)))
209 << "Invalid access.";
210
211 auto rd = format(off_bytes, type, 1, 0).reg_data();
212 return ngen::Subregister(rd, rd.getOffset(), rd.getType());
213 }
214
215 reg_buf_data_t format(int off_bytes,
216 ngen::DataType type = ngen::DataType::invalid, int width = 1,
217 int hstride = 1) const {
218 if (type == ngen::DataType::invalid) type = rd_.getType();
219 auto grf_size = ngen::GRF::bytes(hw());
220 auto new_off = rd_.getByteOffset() + off_bytes;
221 auto new_grf_off = new_off % grf_size;
222 auto type_size = ngen::getBytes(type);
223 auto grf = get_grf(new_off / grf_size).retype(type);
224
225 ir_assert(new_grf_off % type_size == 0);
226
227 if (width == 1) {
228 hstride = 0;
229 } else if (hstride == 0) {
230 ir_assert(width == 1);
231 } else {
232 int max_width = 32 / type_size;
233 width = std::min(width, max_width / hstride);
234 width = std::min(width, 16);
235 }
236 int vstride = width * hstride;
237
238 int region_bytes = ((width - 1) * hstride + 1) * type_size;
239 ir_assert(check_bounds(off_bytes, region_bytes)) << "Invalid access.";
240
241 auto ret = *this;
242 ret.rd_ = grf[new_grf_off / type_size](vstride, width, hstride);
243 return ret;
244 }
245
246 reg_buf_data_t unpermute() const {
247 int idx = get_grf_buf_index();
248 int base = reg_buf_->base(idx, /*apply_permute=*/false);
249
250 auto ret = *this;
251 ret.rd_.setBase(base);
252 return ret;
253 }
254
255private:
256 ngen::GRF get_grf(int off_regs) const {
257 return ngen::GRF(get_grf_base(off_regs));
258 }
259
260 int get_grf_base(int off_regs) const {
261 int idx = get_grf_buf_index();
262 return reg_buf_->base(idx + off_regs);
263 }
264
265 int get_grf_buf_index() const {
266 if (reg_buf_->blocks() == 1 && !reg_buf_->with_permute()) {
267 return rd_.getBase() - reg_buf_->base(0);
268 }
269 for (int i = 0; i < reg_buf_->regs(); i++) {
270 if (reg_buf_->base(i) == rd_.getBase()) return i;
271 }
272 ir_error_not_expected();
273 return -1;
274 }
275
276 std::shared_ptr<reg_buf_t> reg_buf_;
277 ngen::RegData rd_;
278};
279
280} // namespace jit
281} // namespace gpu
282} // namespace impl
283} // namespace dnnl
284
285#endif
286