1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_CODEGEN_SEND_HPP
18#define GPU_JIT_CODEGEN_SEND_HPP
19
20#include "gpu/jit/codegen/kernel.hpp"
21#include "gpu/jit/codegen/register_scope.hpp"
22#include "gpu/jit/ir/message.hpp"
23#include "gpu/jit/ir/tensor.hpp"
24#include "gpu/jit/ngen/ngen.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace gpu {
29namespace jit {
30
31template <typename DataSpecT, typename = void>
32struct atomic_helper_t {
33 template <typename GeneratorT>
34 static void call(GeneratorT *, ngen::AtomicOp,
35 const ngen::InstructionModifier &, const DataSpecT &,
36 ngen::AddressBase, const ngen::RegData &, const ngen::RegData &) {
37 ir_error_not_expected()
38 << "Unknown DataSpec: atomics are not supported.";
39 }
40};
41
42template <typename DataSpecT>
43struct atomic_helper_t<DataSpecT,
44 typename std::enable_if<
45 std::is_same<DataSpecT, ngen::scattered_dword>::value>::type> {
46 template <typename GeneratorT>
47 static void call(GeneratorT *host, ngen::AtomicOp atomic_op,
48 const ngen::InstructionModifier &mod, const DataSpecT &spec,
49 ngen::AddressBase base, const ngen::RegData &addr,
50 const ngen::RegData &data) {
51 host->atomic(atomic_op, mod, spec, base, addr, data);
52 }
53};
54
55// Helper to emit send instructions.
56class send_impl_t {
57public:
58 send_impl_t(const send_t &send) : send_(send) {}
59
60 template <typename GeneratorT, typename T>
61 void emit(GeneratorT *host, ngen_register_scope_t &scope,
62 const ngen::InstructionModifier &mod,
63 const ngen::RegData &surf_base_addr, int surf_bti,
64 const ngen::RegData &header, const T &data) {
65 if (send_.is_2d()) {
66 emit_2d(host, mod, data, header);
67 return;
68 }
69
70 if (send_.is_lsc) {
71 emit_lsc(host, mod, data, surf_bti, header);
72 return;
73 }
74
75 auto address_base = to_address_base(send_.address, surf_bti);
76
77 int elems = send_.type.elems();
78 switch (send_.type.kind()) {
79 case type_kind_t::byte:
80 emit_load_or_store(host, mod, ngen::scattered_byte(elems),
81 address_base, header, data);
82 break;
83 case type_kind_t::dword:
84 emit_load_or_store(host, mod, ngen::scattered_dword(elems),
85 address_base, header, data);
86 break;
87 case type_kind_t::qword:
88 emit_load_or_store(host, mod, ngen::scattered_qword(elems),
89 address_base, header, data);
90 break;
91 case type_kind_t::oword:
92 emit_load_or_store(host, mod, ngen::block_oword(elems),
93 address_base, header, data);
94 break;
95 case type_kind_t::hword:
96 emit_load_or_store(host, mod, ngen::block_hword(elems),
97 address_base, header, data);
98 break;
99 default: ir_error_not_expected();
100 }
101 }
102
103private:
104 template <typename GeneratorT, typename DataSpecT>
105 void emit_load_or_store(GeneratorT *host,
106 const ngen::InstructionModifier &mod, const DataSpecT &spec,
107 ngen::AddressBase base, const ngen::RegData &addr,
108 const ngen::RegData &data) {
109 if (send_.is_load()) {
110 host->load(mod, data, spec, base, addr);
111 } else if (send_.is_atomic()) {
112 atomic_helper_t<DataSpecT>::call(
113 host, ngen::AtomicOp::fadd, mod, spec, base, addr, data);
114 } else if (send_.is_store()) {
115 host->store(mod, spec, base, addr, data);
116 } else {
117 ir_error_not_expected() << "Can't emit send: " << send_;
118 }
119 }
120
121 template <typename GeneratorT>
122 void emit_lsc(GeneratorT *host, const ngen::InstructionModifier &mod,
123 const ngen::RegData &data, int surf_bti,
124 const ngen::RegData &header) {
125
126 auto get_lsc_type = [&](const type_t &type, bool is_block) {
127 if (!send_.is_block()) return type;
128 for (auto &t : {type_t::qword(), type_t::dword()}) {
129 if (type.size() % t.size() == 0) {
130 int elems = type.size() / t.size();
131 ir_assert(math::is_pow2(elems));
132 ir_assert(elems >= 1 && elems <= 64);
133 return t.with_elems(elems);
134 }
135 }
136 ir_error_not_expected();
137 return type;
138 };
139
140 std::unique_ptr<ngen::DataSpecLSC> lsc_spec;
141 auto lsc_type = to_data_lsc(get_lsc_type(send_.type, send_.is_block()));
142 if (send_.is_scattered()) {
143 lsc_spec = utils::make_unique<ngen::DataSpecLSC>(
144 ngen::scattered(lsc_type.first, lsc_type.second));
145 } else if (send_.is_block()) {
146 lsc_spec = utils::make_unique<ngen::DataSpecLSC>(
147 ngen::block(lsc_type.first, lsc_type.second));
148 } else {
149 ir_error_not_expected();
150 }
151
152 if (send_.is_slm()) {
153 if (send_.is_load()) {
154 host->load.slm(mod, data, *lsc_spec, host->SLM, header);
155 } else if (send_.is_store()) {
156 host->store.slm(mod, *lsc_spec, host->SLM, header, data);
157 } else {
158 ir_error_not_expected();
159 }
160 } else if (send_.is_a64()) {
161 if (send_.is_load() || send_.is_prefetch()) {
162 *lsc_spec |= ngen::CacheSettingsLSC::L1C_L3C;
163 host->load.ugm(mod, data, *lsc_spec, host->A64, header);
164 } else if (send_.is_store()) {
165 *lsc_spec |= ngen::CacheSettingsLSC::L1WB_L3WB;
166 host->store.ugm(mod, *lsc_spec, host->A64, header, data);
167 } else if (send_.is_atomic()) {
168 *lsc_spec |= ngen::CacheSettingsLSC::L1UC_L3WB;
169 host->atomic.ugm(ngen::AtomicOp::fadd, mod, *lsc_spec,
170 to_address_base(send_.address, surf_bti), header, data);
171 }
172 } else {
173 ir_error_not_expected();
174 }
175 }
176
177 template <typename GeneratorT>
178 void emit_2d(GeneratorT *host, const ngen::InstructionModifier &mod,
179 const ngen::RegData &data, const ngen::RegData &header) {
180 auto &info = send_.block_2d_info;
181 ngen::DataSizeLSC data_size = ngen::DataSizeLSC::D8;
182 switch (send_.type.size()) {
183 case 1: data_size = ngen::DataSizeLSC::D8; break;
184 case 2: data_size = ngen::DataSizeLSC::D16; break;
185 case 4: data_size = ngen::DataSizeLSC::D32; break;
186 default: ir_error_not_expected();
187 }
188 ngen::DataSpecLSC data_spec(data_size);
189 if (info.vnni) data_spec |= host->vnni;
190 if (info.transpose) data_spec |= host->transpose;
191 ngen::block_2d spec(data_spec, info.width, info.height, info.count);
192 if (send_.is_load_2d() || send_.is_prefetch_2d()) {
193 spec |= ngen::CacheSettingsLSC::L1C_L3C;
194 host->load(mod, data, spec, host->A64, header);
195 } else if (send_.is_store_2d()) {
196 host->store(mod, spec, host->A64, header, data);
197 } else {
198 ir_error_not_expected();
199 }
200 }
201
202 static std::pair<ngen::DataSizeLSC, int> to_data_lsc(const type_t &type) {
203 switch (type.scalar().size()) {
204 case 1: {
205 if (type.elems() == 1)
206 return std::make_pair(ngen::DataSizeLSC::D8U32, 1);
207 if (type.elems() == 2)
208 return std::make_pair(ngen::DataSizeLSC::D16U32, 1);
209 if (type.elems() == 4)
210 return std::make_pair(ngen::DataSizeLSC::D32, 1);
211 if (type.elems() == 8)
212 return std::make_pair(ngen::DataSizeLSC::D64, 1);
213 break;
214 }
215 case 2: {
216 if (type.elems() == 1)
217 return std::make_pair(ngen::DataSizeLSC::D16U32, 1);
218 if (type.elems() == 2)
219 return std::make_pair(ngen::DataSizeLSC::D32, 1);
220 if (type.elems() == 4)
221 return std::make_pair(ngen::DataSizeLSC::D64, 1);
222 break;
223 }
224 case 4: return std::make_pair(ngen::DataSizeLSC::D32, type.elems());
225 case 8: return std::make_pair(ngen::DataSizeLSC::D64, type.elems());
226 default: break;
227 }
228 ir_error_not_expected();
229 return std::make_pair(ngen::DataSizeLSC::D8, 1);
230 }
231
232 static ngen::AddressBase to_address_base(
233 send_address_t address, int surf_bti) {
234 switch (address) {
235 case send_address_t::a64: return ngen::AddressBase::createA64(true);
236 case send_address_t::bts:
237 return ngen::AddressBase::createBTS(surf_bti);
238 case send_address_t::slm: return ngen::AddressBase::createSLM();
239 default: ir_error_not_expected();
240 }
241 return ngen::AddressBase();
242 }
243
244 const send_t &send_;
245};
246
247} // namespace jit
248} // namespace gpu
249} // namespace impl
250} // namespace dnnl
251
252#endif
253