1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_CODEGEN_NGEN_HELPERS_HPP |
18 | #define GPU_JIT_CODEGEN_NGEN_HELPERS_HPP |
19 | |
20 | #include "gpu/jit/ir/core.hpp" |
21 | #include "gpu/jit/ngen/ngen.hpp" |
22 | #include "gpu/jit/ngen/ngen_register_allocator.hpp" |
23 | #include "gpu/jit/utils/ngen_proxy.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace gpu { |
28 | namespace jit { |
29 | |
30 | template <typename T> |
31 | T to_cpp(const ngen::Immediate &imm) { |
32 | auto u64 = uint64_t(imm); |
33 | switch (imm.getType()) { |
34 | case ngen::DataType::w: |
35 | return (T)utils::bit_cast<std::array<int16_t, 4>>(u64)[0]; |
36 | case ngen::DataType::uw: |
37 | return (T)utils::bit_cast<std::array<uint16_t, 4>>(u64)[0]; |
38 | case ngen::DataType::d: |
39 | return (T)utils::bit_cast<std::array<int32_t, 2>>(u64)[0]; |
40 | case ngen::DataType::ud: |
41 | return (T)utils::bit_cast<std::array<uint32_t, 2>>(u64)[0]; |
42 | case ngen::DataType::q: return (T)utils::bit_cast<int64_t>(u64); |
43 | case ngen::DataType::uq: return (T)utils::bit_cast<uint64_t>(u64); |
44 | case ngen::DataType::f: |
45 | return (T)utils::bit_cast<std::array<float, 2>>(u64)[0]; |
46 | default: ir_error_not_expected(); |
47 | } |
48 | return 0; |
49 | } |
50 | |
51 | // type_t to ngen::DataType convertor. |
52 | inline ngen::DataType to_ngen(const type_t &type) { |
53 | ir_assert(type.is_scalar()) << "Expected scalar type." ; |
54 | |
55 | #define CASE(_kind, ngen_enum) \ |
56 | if (type.kind() == type_kind_t::_kind) return ngen::DataType::ngen_enum |
57 | |
58 | CASE(bf16, bf); |
59 | CASE(f16, hf); |
60 | CASE(tf32, tf32); |
61 | CASE(f32, f); |
62 | CASE(f64, df); |
63 | CASE(s16, w); |
64 | CASE(s32, d); |
65 | CASE(s64, q); |
66 | CASE(s8, b); |
67 | CASE(u16, uw); |
68 | CASE(u32, ud); |
69 | CASE(u64, uq); |
70 | CASE(u8, ub); |
71 | |
72 | if (type == type_t::byte_ptr()) return ngen::DataType::uq; |
73 | |
74 | #undef CASE |
75 | ir_error_not_expected(); |
76 | return ngen::DataType::invalid; |
77 | } |
78 | |
79 | // ngen::DataType to type_t convertor. |
80 | inline type_t to_ir(ngen::DataType type) { |
81 | #define CASE(_kind, ngen_enum) \ |
82 | if (type == ngen::DataType::ngen_enum) return type_t::_kind(); |
83 | |
84 | CASE(bf16, bf); |
85 | CASE(f16, hf); |
86 | CASE(f32, f); |
87 | CASE(f64, df); |
88 | CASE(s16, w); |
89 | CASE(s32, d); |
90 | CASE(s64, q); |
91 | CASE(s8, b); |
92 | CASE(u16, uw); |
93 | CASE(u32, ud); |
94 | CASE(u64, uq); |
95 | CASE(u8, ub); |
96 | |
97 | #undef CASE |
98 | ir_error_not_expected(); |
99 | return type_t::undef(); |
100 | } |
101 | |
102 | inline ngen::Immediate to_ngen( |
103 | const expr_t &expr, const type_t &type = type_t::undef()) { |
104 | ir_assert(expr.type().is_scalar()) << "Vector types are not supported." ; |
105 | if (expr.is<int_imm_t>()) { |
106 | auto &imm = expr.as<int_imm_t>(); |
107 | // No conversion. |
108 | if (utils::one_of(type, type_t::undef(), expr.type())) |
109 | return ngen::Immediate(imm.value); |
110 | // Do conversion. |
111 | #define CASE(cpp_type) \ |
112 | if (type.is_cpp<cpp_type>()) return ngen::Immediate(cpp_type(imm.value)) |
113 | |
114 | CASE(int16_t); |
115 | CASE(int32_t); |
116 | CASE(int64_t); |
117 | CASE(uint16_t); |
118 | CASE(uint32_t); |
119 | CASE(uint64_t); |
120 | |
121 | #undef CASE |
122 | ir_error_not_expected() << "Can't convert expression: " << expr; |
123 | } else if (expr.is<float_imm_t>()) { |
124 | ir_assert(utils::one_of(type, type_t::undef(), type_t::f32())) |
125 | << "Conversion is not supported." ; |
126 | auto &imm = expr.as<float_imm_t>(); |
127 | if (imm.type.is_f32()) { return ngen::Immediate((float)imm.value); } |
128 | return ngen::Immediate(imm.value); |
129 | } |
130 | ir_error_not_expected() << "Can't convert expression: " << expr; |
131 | return ngen::Immediate(); |
132 | } |
133 | |
134 | inline ngen::Bundle to_ngen(const ngen_proxy::Bundle &bundle) { |
135 | return ngen::Bundle(bundle.bank_id, bundle.bundle_id); |
136 | } |
137 | |
138 | inline ngen::InstructionModifier to_ngen( |
139 | const ngen_proxy::InstructionModifier &mod_proxy) { |
140 | ngen::InstructionModifier mod; |
141 | if (mod_proxy.is_atomic) mod |= ngen::ThreadCtrl::Atomic; |
142 | if (!mod_proxy.sbid.is_empty()) mod |= ngen::SBID(mod_proxy.sbid.token).set; |
143 | return mod; |
144 | } |
145 | |
146 | inline ngen::AtomicOp to_ngen(ngen_proxy::AtomicOp atomic_op) { |
147 | switch (atomic_op) { |
148 | case ngen_proxy::AtomicOp::fadd: return ngen::AtomicOp::fadd; |
149 | default: ir_error_not_expected(); |
150 | } |
151 | return ngen::AtomicOp(std::numeric_limits<uint16_t>::max()); |
152 | } |
153 | |
154 | inline ngen::Immediate ngen_negate(const ngen::Immediate &imm) { |
155 | switch (imm.getType()) { |
156 | case ngen::DataType::w: return ngen::Immediate(-to_cpp<int16_t>(imm)); |
157 | case ngen::DataType::d: return ngen::Immediate(-to_cpp<int32_t>(imm)); |
158 | case ngen::DataType::f: return ngen::Immediate(-to_cpp<float>(imm)); |
159 | default: ir_error_not_expected(); |
160 | } |
161 | return ngen::Immediate(); |
162 | } |
163 | |
164 | inline bool ngen_is_qw(ngen::DataType type) { |
165 | return utils::one_of(type, ngen::DataType::q, ngen::DataType::uq); |
166 | } |
167 | |
168 | inline bool ngen_is_dw(ngen::DataType type) { |
169 | return utils::one_of(type, ngen::DataType::d, ngen::DataType::ud); |
170 | } |
171 | |
172 | inline bool ngen_is_w(ngen::DataType type) { |
173 | return utils::one_of(type, ngen::DataType::w, ngen::DataType::uw); |
174 | } |
175 | |
176 | inline bool ngen_is_b(ngen::DataType type) { |
177 | return utils::one_of(type, ngen::DataType::b, ngen::DataType::ub); |
178 | } |
179 | |
180 | inline bool ngen_is_xf(ngen::DataType type) { |
181 | return utils::one_of( |
182 | type, ngen::DataType::bf, ngen::DataType::hf, ngen::DataType::f); |
183 | } |
184 | |
185 | inline ngen::Subregister get_subregister( |
186 | ngen::HW hw, ngen::DataType type, const ngen::GRFRange &r, int idx) { |
187 | int grf_size = ngen::GRF::bytes(hw); |
188 | int type_size = ngen::getBytes(type); |
189 | int off = idx * type_size; |
190 | return r[off / grf_size].sub((off % grf_size) / type_size, type); |
191 | } |
192 | |
193 | inline ngen::Subregister get_subregister(const ngen::RegData &rd) { |
194 | return ngen::Subregister(rd, rd.getOffset(), rd.getType()); |
195 | } |
196 | |
197 | } // namespace jit |
198 | } // namespace gpu |
199 | } // namespace impl |
200 | } // namespace dnnl |
201 | |
202 | #endif |
203 | |