1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_CODEGEN_SEND_HPP |
18 | #define GPU_JIT_CODEGEN_SEND_HPP |
19 | |
20 | #include "gpu/jit/codegen/kernel.hpp" |
21 | #include "gpu/jit/codegen/register_scope.hpp" |
22 | #include "gpu/jit/ir/message.hpp" |
23 | #include "gpu/jit/ir/tensor.hpp" |
24 | #include "gpu/jit/ngen/ngen.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace gpu { |
29 | namespace jit { |
30 | |
31 | template <typename DataSpecT, typename = void> |
32 | struct atomic_helper_t { |
33 | template <typename GeneratorT> |
34 | static void call(GeneratorT *, ngen::AtomicOp, |
35 | const ngen::InstructionModifier &, const DataSpecT &, |
36 | ngen::AddressBase, const ngen::RegData &, const ngen::RegData &) { |
37 | ir_error_not_expected() |
38 | << "Unknown DataSpec: atomics are not supported." ; |
39 | } |
40 | }; |
41 | |
42 | template <typename DataSpecT> |
43 | struct atomic_helper_t<DataSpecT, |
44 | typename std::enable_if< |
45 | std::is_same<DataSpecT, ngen::scattered_dword>::value>::type> { |
46 | template <typename GeneratorT> |
47 | static void call(GeneratorT *host, ngen::AtomicOp atomic_op, |
48 | const ngen::InstructionModifier &mod, const DataSpecT &spec, |
49 | ngen::AddressBase base, const ngen::RegData &addr, |
50 | const ngen::RegData &data) { |
51 | host->atomic(atomic_op, mod, spec, base, addr, data); |
52 | } |
53 | }; |
54 | |
55 | // Helper to emit send instructions. |
56 | class send_impl_t { |
57 | public: |
58 | send_impl_t(const send_t &send) : send_(send) {} |
59 | |
60 | template <typename GeneratorT, typename T> |
61 | void emit(GeneratorT *host, ngen_register_scope_t &scope, |
62 | const ngen::InstructionModifier &mod, |
63 | const ngen::RegData &surf_base_addr, int surf_bti, |
64 | const ngen::RegData &, const T &data) { |
65 | if (send_.is_2d()) { |
66 | emit_2d(host, mod, data, header); |
67 | return; |
68 | } |
69 | |
70 | if (send_.is_lsc) { |
71 | emit_lsc(host, mod, data, surf_bti, header); |
72 | return; |
73 | } |
74 | |
75 | auto address_base = to_address_base(send_.address, surf_bti); |
76 | |
77 | int elems = send_.type.elems(); |
78 | switch (send_.type.kind()) { |
79 | case type_kind_t::byte: |
80 | emit_load_or_store(host, mod, ngen::scattered_byte(elems), |
81 | address_base, header, data); |
82 | break; |
83 | case type_kind_t::dword: |
84 | emit_load_or_store(host, mod, ngen::scattered_dword(elems), |
85 | address_base, header, data); |
86 | break; |
87 | case type_kind_t::qword: |
88 | emit_load_or_store(host, mod, ngen::scattered_qword(elems), |
89 | address_base, header, data); |
90 | break; |
91 | case type_kind_t::oword: |
92 | emit_load_or_store(host, mod, ngen::block_oword(elems), |
93 | address_base, header, data); |
94 | break; |
95 | case type_kind_t::hword: |
96 | emit_load_or_store(host, mod, ngen::block_hword(elems), |
97 | address_base, header, data); |
98 | break; |
99 | default: ir_error_not_expected(); |
100 | } |
101 | } |
102 | |
103 | private: |
104 | template <typename GeneratorT, typename DataSpecT> |
105 | void emit_load_or_store(GeneratorT *host, |
106 | const ngen::InstructionModifier &mod, const DataSpecT &spec, |
107 | ngen::AddressBase base, const ngen::RegData &addr, |
108 | const ngen::RegData &data) { |
109 | if (send_.is_load()) { |
110 | host->load(mod, data, spec, base, addr); |
111 | } else if (send_.is_atomic()) { |
112 | atomic_helper_t<DataSpecT>::call( |
113 | host, ngen::AtomicOp::fadd, mod, spec, base, addr, data); |
114 | } else if (send_.is_store()) { |
115 | host->store(mod, spec, base, addr, data); |
116 | } else { |
117 | ir_error_not_expected() << "Can't emit send: " << send_; |
118 | } |
119 | } |
120 | |
121 | template <typename GeneratorT> |
122 | void emit_lsc(GeneratorT *host, const ngen::InstructionModifier &mod, |
123 | const ngen::RegData &data, int surf_bti, |
124 | const ngen::RegData &) { |
125 | |
126 | auto get_lsc_type = [&](const type_t &type, bool is_block) { |
127 | if (!send_.is_block()) return type; |
128 | for (auto &t : {type_t::qword(), type_t::dword()}) { |
129 | if (type.size() % t.size() == 0) { |
130 | int elems = type.size() / t.size(); |
131 | ir_assert(math::is_pow2(elems)); |
132 | ir_assert(elems >= 1 && elems <= 64); |
133 | return t.with_elems(elems); |
134 | } |
135 | } |
136 | ir_error_not_expected(); |
137 | return type; |
138 | }; |
139 | |
140 | std::unique_ptr<ngen::DataSpecLSC> lsc_spec; |
141 | auto lsc_type = to_data_lsc(get_lsc_type(send_.type, send_.is_block())); |
142 | if (send_.is_scattered()) { |
143 | lsc_spec = utils::make_unique<ngen::DataSpecLSC>( |
144 | ngen::scattered(lsc_type.first, lsc_type.second)); |
145 | } else if (send_.is_block()) { |
146 | lsc_spec = utils::make_unique<ngen::DataSpecLSC>( |
147 | ngen::block(lsc_type.first, lsc_type.second)); |
148 | } else { |
149 | ir_error_not_expected(); |
150 | } |
151 | |
152 | if (send_.is_slm()) { |
153 | if (send_.is_load()) { |
154 | host->load.slm(mod, data, *lsc_spec, host->SLM, header); |
155 | } else if (send_.is_store()) { |
156 | host->store.slm(mod, *lsc_spec, host->SLM, header, data); |
157 | } else { |
158 | ir_error_not_expected(); |
159 | } |
160 | } else if (send_.is_a64()) { |
161 | if (send_.is_load() || send_.is_prefetch()) { |
162 | *lsc_spec |= ngen::CacheSettingsLSC::L1C_L3C; |
163 | host->load.ugm(mod, data, *lsc_spec, host->A64, header); |
164 | } else if (send_.is_store()) { |
165 | *lsc_spec |= ngen::CacheSettingsLSC::L1WB_L3WB; |
166 | host->store.ugm(mod, *lsc_spec, host->A64, header, data); |
167 | } else if (send_.is_atomic()) { |
168 | *lsc_spec |= ngen::CacheSettingsLSC::L1UC_L3WB; |
169 | host->atomic.ugm(ngen::AtomicOp::fadd, mod, *lsc_spec, |
170 | to_address_base(send_.address, surf_bti), header, data); |
171 | } |
172 | } else { |
173 | ir_error_not_expected(); |
174 | } |
175 | } |
176 | |
177 | template <typename GeneratorT> |
178 | void emit_2d(GeneratorT *host, const ngen::InstructionModifier &mod, |
179 | const ngen::RegData &data, const ngen::RegData &) { |
180 | auto &info = send_.block_2d_info; |
181 | ngen::DataSizeLSC data_size = ngen::DataSizeLSC::D8; |
182 | switch (send_.type.size()) { |
183 | case 1: data_size = ngen::DataSizeLSC::D8; break; |
184 | case 2: data_size = ngen::DataSizeLSC::D16; break; |
185 | case 4: data_size = ngen::DataSizeLSC::D32; break; |
186 | default: ir_error_not_expected(); |
187 | } |
188 | ngen::DataSpecLSC data_spec(data_size); |
189 | if (info.vnni) data_spec |= host->vnni; |
190 | if (info.transpose) data_spec |= host->transpose; |
191 | ngen::block_2d spec(data_spec, info.width, info.height, info.count); |
192 | if (send_.is_load_2d() || send_.is_prefetch_2d()) { |
193 | spec |= ngen::CacheSettingsLSC::L1C_L3C; |
194 | host->load(mod, data, spec, host->A64, header); |
195 | } else if (send_.is_store_2d()) { |
196 | host->store(mod, spec, host->A64, header, data); |
197 | } else { |
198 | ir_error_not_expected(); |
199 | } |
200 | } |
201 | |
202 | static std::pair<ngen::DataSizeLSC, int> to_data_lsc(const type_t &type) { |
203 | switch (type.scalar().size()) { |
204 | case 1: { |
205 | if (type.elems() == 1) |
206 | return std::make_pair(ngen::DataSizeLSC::D8U32, 1); |
207 | if (type.elems() == 2) |
208 | return std::make_pair(ngen::DataSizeLSC::D16U32, 1); |
209 | if (type.elems() == 4) |
210 | return std::make_pair(ngen::DataSizeLSC::D32, 1); |
211 | if (type.elems() == 8) |
212 | return std::make_pair(ngen::DataSizeLSC::D64, 1); |
213 | break; |
214 | } |
215 | case 2: { |
216 | if (type.elems() == 1) |
217 | return std::make_pair(ngen::DataSizeLSC::D16U32, 1); |
218 | if (type.elems() == 2) |
219 | return std::make_pair(ngen::DataSizeLSC::D32, 1); |
220 | if (type.elems() == 4) |
221 | return std::make_pair(ngen::DataSizeLSC::D64, 1); |
222 | break; |
223 | } |
224 | case 4: return std::make_pair(ngen::DataSizeLSC::D32, type.elems()); |
225 | case 8: return std::make_pair(ngen::DataSizeLSC::D64, type.elems()); |
226 | default: break; |
227 | } |
228 | ir_error_not_expected(); |
229 | return std::make_pair(ngen::DataSizeLSC::D8, 1); |
230 | } |
231 | |
232 | static ngen::AddressBase to_address_base( |
233 | send_address_t address, int surf_bti) { |
234 | switch (address) { |
235 | case send_address_t::a64: return ngen::AddressBase::createA64(true); |
236 | case send_address_t::bts: |
237 | return ngen::AddressBase::createBTS(surf_bti); |
238 | case send_address_t::slm: return ngen::AddressBase::createSLM(); |
239 | default: ir_error_not_expected(); |
240 | } |
241 | return ngen::AddressBase(); |
242 | } |
243 | |
244 | const send_t &send_; |
245 | }; |
246 | |
247 | } // namespace jit |
248 | } // namespace gpu |
249 | } // namespace impl |
250 | } // namespace dnnl |
251 | |
252 | #endif |
253 | |