1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CODEGEN_REGISTER_SCOPE_HPP
18#define GPU_JIT_CODEGEN_REGISTER_SCOPE_HPP
19
20#include "gpu/jit/codegen/ngen_helpers.hpp"
21#include "gpu/jit/codegen/reg_buf.hpp"
22#include "gpu/jit/ngen/ngen.hpp"
23#include "gpu/jit/ngen/ngen_register_allocator.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace gpu {
28namespace jit {
29
30// Maintains scoped allocations which are automatically released when the scope
31// is destructed.
32class ngen_register_scope_t {
33public:
34 ngen_register_scope_t(reg_allocator_t &ra) : ra_(ra) {}
35
36 ngen_register_scope_t(const ngen_register_scope_t &) = delete;
37
38 ngen_register_scope_t(ngen_register_scope_t &&other)
39 : ra_(other.ra_)
40 , grf_ranges_(std::move(other.grf_ranges_))
41 , subregisters_(std::move(other.subregisters_)) {}
42
43 reg_allocator_t &register_allocator() { return ra_; }
44
45 ngen::HW hw() const { return ra_.hardware(); }
46
47 ~ngen_register_scope_t() { clear(); }
48
49 void clear() {
50 for (auto &r : grf_ranges_)
51 ra_.safeRelease(r);
52 for (auto &s : subregisters_)
53 ra_.safeRelease(s);
54 for (auto &f : flags_)
55 ra_.safeRelease(f);
56 grf_ranges_.clear();
57 subregisters_.clear();
58 flags_.clear();
59 }
60
61 ngen::GRFRange find_grf_range(int base, int byte_offset) const {
62 if (byte_offset != 0) return ngen::GRFRange();
63 for (auto &r : grf_ranges_)
64 if (r.getBase() == base) return r;
65 return ngen::GRFRange();
66 }
67
68 ngen::Subregister find_sub(int base, int byte_offset) const {
69 for (auto &s : subregisters_)
70 if (s.getBase() == base && s.getByteOffset() == byte_offset)
71 return s;
72 return ngen::Subregister();
73 }
74
75 ngen::GRFRange try_alloc_range(
76 int regs, ngen::Bundle base_bundle = ngen::Bundle()) {
77 auto ret = ra_.try_alloc_range(regs, base_bundle);
78 if (!ret.isInvalid()) grf_ranges_.push_back(ret);
79 return ret;
80 }
81
82 ngen::GRFRange alloc_range(
83 int regs, ngen::Bundle base_bundle = ngen::Bundle()) {
84 auto ret = ra_.alloc_range(regs, base_bundle);
85 grf_ranges_.push_back(ret);
86 return ret;
87 }
88
89 reg_buf_t alloc_reg_buf(
90 int regs, ngen::Bundle base_bundle = ngen::Bundle()) {
91 auto range = ra_.alloc_range(regs, base_bundle);
92 grf_ranges_.push_back(range);
93 return reg_buf_t(ra_.hardware(), range);
94 }
95
96 reg_buf_data_t alloc_reg_buf_data(
97 int regs, ngen::Bundle base_bundle = ngen::Bundle()) {
98 return alloc_reg_buf(regs, base_bundle);
99 }
100
101 reg_buf_data_t alloc_reg_data(const type_t &type, int stride_bytes = -1,
102 ngen::Bundle bundle = ngen::Bundle()) {
103 if (type.is_scalar()) {
104 auto sub = alloc_sub(to_ngen(type), bundle);
105 return reg_buf_data_t(hw(), sub);
106 }
107
108 int type_size = type.scalar().size();
109 if (stride_bytes == -1) stride_bytes = type_size;
110 int grf_size = ngen::GRF::bytes(hw());
111 int regs = utils::div_up(type.elems() * stride_bytes, grf_size);
112 auto buf = alloc_reg_buf(regs, bundle);
113 reg_buf_data_t rbd(buf);
114 return rbd.format(0, to_ngen(type.scalar()), type.elems(),
115 stride_bytes / type_size);
116 }
117
118 ngen::GRF alloc(ngen::Bundle bundle = ngen::Bundle()) {
119 auto range = ra_.alloc_range(1, bundle);
120 grf_ranges_.push_back(range);
121 return range[0];
122 }
123
124 ngen::Subregister alloc_sub(
125 ngen::DataType type, ngen::Bundle bundle = ngen::Bundle()) {
126 auto ret = ra_.alloc_sub(type, bundle);
127 subregisters_.push_back(ret);
128 return ret;
129 }
130
131 ngen::FlagRegister alloc_flag() {
132 auto ret = ra_.alloc_flag();
133 flags_.push_back(ret);
134 return ret;
135 }
136
137 void claim(const ngen::GRFRange &range) {
138 ra_.claim(range);
139 grf_ranges_.push_back(range);
140 }
141
142 void claim(const ngen::Subregister &sub) {
143 ra_.claim(sub);
144 subregisters_.push_back(sub);
145 }
146
147 template <typename T>
148 void safeRelease(T &t) {
149 ra_.safeRelease(t);
150 }
151
152private:
153 reg_allocator_t &ra_;
154
155 std::vector<ngen::GRFRange> grf_ranges_;
156 std::vector<ngen::Subregister> subregisters_;
157 std::vector<ngen::FlagRegister> flags_;
158};
159
160} // namespace jit
161} // namespace gpu
162} // namespace impl
163} // namespace dnnl
164
165#endif
166