1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CODEGEN_NGEN_HELPERS_HPP
18#define GPU_JIT_CODEGEN_NGEN_HELPERS_HPP
19
20#include "gpu/jit/ir/core.hpp"
21#include "gpu/jit/ngen/ngen.hpp"
22#include "gpu/jit/ngen/ngen_register_allocator.hpp"
23#include "gpu/jit/utils/ngen_proxy.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace gpu {
28namespace jit {
29
30template <typename T>
31T to_cpp(const ngen::Immediate &imm) {
32 auto u64 = uint64_t(imm);
33 switch (imm.getType()) {
34 case ngen::DataType::w:
35 return (T)utils::bit_cast<std::array<int16_t, 4>>(u64)[0];
36 case ngen::DataType::uw:
37 return (T)utils::bit_cast<std::array<uint16_t, 4>>(u64)[0];
38 case ngen::DataType::d:
39 return (T)utils::bit_cast<std::array<int32_t, 2>>(u64)[0];
40 case ngen::DataType::ud:
41 return (T)utils::bit_cast<std::array<uint32_t, 2>>(u64)[0];
42 case ngen::DataType::q: return (T)utils::bit_cast<int64_t>(u64);
43 case ngen::DataType::uq: return (T)utils::bit_cast<uint64_t>(u64);
44 case ngen::DataType::f:
45 return (T)utils::bit_cast<std::array<float, 2>>(u64)[0];
46 default: ir_error_not_expected();
47 }
48 return 0;
49}
50
51// type_t to ngen::DataType convertor.
52inline ngen::DataType to_ngen(const type_t &type) {
53 ir_assert(type.is_scalar()) << "Expected scalar type.";
54
55#define CASE(_kind, ngen_enum) \
56 if (type.kind() == type_kind_t::_kind) return ngen::DataType::ngen_enum
57
58 CASE(bf16, bf);
59 CASE(f16, hf);
60 CASE(tf32, tf32);
61 CASE(f32, f);
62 CASE(f64, df);
63 CASE(s16, w);
64 CASE(s32, d);
65 CASE(s64, q);
66 CASE(s8, b);
67 CASE(u16, uw);
68 CASE(u32, ud);
69 CASE(u64, uq);
70 CASE(u8, ub);
71
72 if (type == type_t::byte_ptr()) return ngen::DataType::uq;
73
74#undef CASE
75 ir_error_not_expected();
76 return ngen::DataType::invalid;
77}
78
79// ngen::DataType to type_t convertor.
80inline type_t to_ir(ngen::DataType type) {
81#define CASE(_kind, ngen_enum) \
82 if (type == ngen::DataType::ngen_enum) return type_t::_kind();
83
84 CASE(bf16, bf);
85 CASE(f16, hf);
86 CASE(f32, f);
87 CASE(f64, df);
88 CASE(s16, w);
89 CASE(s32, d);
90 CASE(s64, q);
91 CASE(s8, b);
92 CASE(u16, uw);
93 CASE(u32, ud);
94 CASE(u64, uq);
95 CASE(u8, ub);
96
97#undef CASE
98 ir_error_not_expected();
99 return type_t::undef();
100}
101
102inline ngen::Immediate to_ngen(
103 const expr_t &expr, const type_t &type = type_t::undef()) {
104 ir_assert(expr.type().is_scalar()) << "Vector types are not supported.";
105 if (expr.is<int_imm_t>()) {
106 auto &imm = expr.as<int_imm_t>();
107 // No conversion.
108 if (utils::one_of(type, type_t::undef(), expr.type()))
109 return ngen::Immediate(imm.value);
110 // Do conversion.
111#define CASE(cpp_type) \
112 if (type.is_cpp<cpp_type>()) return ngen::Immediate(cpp_type(imm.value))
113
114 CASE(int16_t);
115 CASE(int32_t);
116 CASE(int64_t);
117 CASE(uint16_t);
118 CASE(uint32_t);
119 CASE(uint64_t);
120
121#undef CASE
122 ir_error_not_expected() << "Can't convert expression: " << expr;
123 } else if (expr.is<float_imm_t>()) {
124 ir_assert(utils::one_of(type, type_t::undef(), type_t::f32()))
125 << "Conversion is not supported.";
126 auto &imm = expr.as<float_imm_t>();
127 if (imm.type.is_f32()) { return ngen::Immediate((float)imm.value); }
128 return ngen::Immediate(imm.value);
129 }
130 ir_error_not_expected() << "Can't convert expression: " << expr;
131 return ngen::Immediate();
132}
133
134inline ngen::Bundle to_ngen(const ngen_proxy::Bundle &bundle) {
135 return ngen::Bundle(bundle.bank_id, bundle.bundle_id);
136}
137
138inline ngen::InstructionModifier to_ngen(
139 const ngen_proxy::InstructionModifier &mod_proxy) {
140 ngen::InstructionModifier mod;
141 if (mod_proxy.is_atomic) mod |= ngen::ThreadCtrl::Atomic;
142 if (!mod_proxy.sbid.is_empty()) mod |= ngen::SBID(mod_proxy.sbid.token).set;
143 return mod;
144}
145
146inline ngen::AtomicOp to_ngen(ngen_proxy::AtomicOp atomic_op) {
147 switch (atomic_op) {
148 case ngen_proxy::AtomicOp::fadd: return ngen::AtomicOp::fadd;
149 default: ir_error_not_expected();
150 }
151 return ngen::AtomicOp(std::numeric_limits<uint16_t>::max());
152}
153
154inline ngen::Immediate ngen_negate(const ngen::Immediate &imm) {
155 switch (imm.getType()) {
156 case ngen::DataType::w: return ngen::Immediate(-to_cpp<int16_t>(imm));
157 case ngen::DataType::d: return ngen::Immediate(-to_cpp<int32_t>(imm));
158 case ngen::DataType::f: return ngen::Immediate(-to_cpp<float>(imm));
159 default: ir_error_not_expected();
160 }
161 return ngen::Immediate();
162}
163
164inline bool ngen_is_qw(ngen::DataType type) {
165 return utils::one_of(type, ngen::DataType::q, ngen::DataType::uq);
166}
167
168inline bool ngen_is_dw(ngen::DataType type) {
169 return utils::one_of(type, ngen::DataType::d, ngen::DataType::ud);
170}
171
172inline bool ngen_is_w(ngen::DataType type) {
173 return utils::one_of(type, ngen::DataType::w, ngen::DataType::uw);
174}
175
176inline bool ngen_is_b(ngen::DataType type) {
177 return utils::one_of(type, ngen::DataType::b, ngen::DataType::ub);
178}
179
180inline bool ngen_is_xf(ngen::DataType type) {
181 return utils::one_of(
182 type, ngen::DataType::bf, ngen::DataType::hf, ngen::DataType::f);
183}
184
185inline ngen::Subregister get_subregister(
186 ngen::HW hw, ngen::DataType type, const ngen::GRFRange &r, int idx) {
187 int grf_size = ngen::GRF::bytes(hw);
188 int type_size = ngen::getBytes(type);
189 int off = idx * type_size;
190 return r[off / grf_size].sub((off % grf_size) / type_size, type);
191}
192
193inline ngen::Subregister get_subregister(const ngen::RegData &rd) {
194 return ngen::Subregister(rd, rd.getOffset(), rd.getType());
195}
196
197} // namespace jit
198} // namespace gpu
199} // namespace impl
200} // namespace dnnl
201
202#endif
203