1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_CODEGEN_CODEGEN_HPP |
18 | #define GPU_JIT_CODEGEN_CODEGEN_HPP |
19 | |
20 | #include "gpu/jit/codegen/bank_conflict_allocation.hpp" |
21 | #include "gpu/jit/codegen/kernel.hpp" |
22 | #include "gpu/jit/codegen/reduce.hpp" |
23 | #include "gpu/jit/codegen/register_scope.hpp" |
24 | #include "gpu/jit/codegen/reorder.hpp" |
25 | #include "gpu/jit/codegen/send.hpp" |
26 | #include "gpu/jit/ir/eltwise.hpp" |
27 | #include "gpu/jit/ir/fma.hpp" |
28 | #include "gpu/jit/jit_eltwise_injector.hpp" |
29 | #include "gpu/jit/ngen/ngen.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | namespace jit { |
35 | |
36 | inline ngen::ConditionModifier cmp_op_to_ngen(op_kind_t op_kind) { |
37 | ir_assert(is_cmp_op(op_kind)); |
38 | switch (op_kind) { |
39 | case op_kind_t::_eq: return ngen::ConditionModifier::eq; |
40 | case op_kind_t::_ne: return ngen::ConditionModifier::ne; |
41 | case op_kind_t::_ge: return ngen::ConditionModifier::ge; |
42 | case op_kind_t::_gt: return ngen::ConditionModifier::gt; |
43 | case op_kind_t::_le: return ngen::ConditionModifier::le; |
44 | case op_kind_t::_lt: return ngen::ConditionModifier::lt; |
45 | default: ir_error_not_expected(); |
46 | } |
47 | return ngen::ConditionModifier::none; |
48 | } |
49 | |
50 | // Lowers IR to nGEN. |
51 | template <ngen::HW hw> |
52 | class ir_to_ngen_t : public ir_visitor_t { |
53 | public: |
54 | ir_to_ngen_t(ir_kernel_t<hw> *host, const expr_binding_t &expr_binding) |
55 | : host_(host) |
56 | , expr_binding_(expr_binding) |
57 | , simd_size_(host->getSIMD()) |
58 | , eu_count_(host->exec_cfg_.hw_cfg().eu_count()) {} |
59 | |
60 | ~ir_to_ngen_t() { |
61 | #ifdef GEN_CONV_DEBUG |
62 | if (bank_conflicts_ > 0) |
63 | ir_warning() << "Found bank conflicts: " << bank_conflicts_ |
64 | << std::endl; |
65 | if (bundle_conflicts_ > 0) |
66 | ir_warning() << "Found bundle conflicts: " << bundle_conflicts_ |
67 | << std::endl; |
68 | #endif |
69 | } |
70 | |
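| // Binds a GRF buffer to the allocation's buffer variable for the |
| // duration of its body. Bank-conflict-attributed allocations are |
| // shared and reference-counted; a GRF permutation attribute, if any, |
| // is applied to the buffer before binding. |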
71 | void _visit(const alloc_t &obj) override { |
72 | auto scope = register_scope(); |
73 | bool do_alloc = (obj.kind == alloc_kind_t::grf); |
74 | bool use_bc_alloc = false; |
75 | if (do_alloc) { |
76 | reg_buf_t rb; |
77 | if (obj.has_attr<bank_conflict_attr_t>()) { |
78 | rb = create_bank_conflict_allocation(obj); |
79 | use_bc_alloc = true; |
80 | } else { |
81 | int grf_size = ngen::GRF::bytes(hw); |
82 | int regs = utils::div_up(obj.size, grf_size); |
83 | rb = scope.alloc_reg_buf(regs); |
84 | } |
85 | if (obj.has_attr<grf_permute_attr_t>()) { |
86 | auto &attr = obj.get_attr<grf_permute_attr_t>(); |
87 | rb.set_grf_permutation(*attr.grf_perm); |
88 | } |
89 | expr_binding_.bind(obj.buf, reg_buf_data_t(rb)); |
90 | } |
91 | visit(obj.body); |
92 | if (do_alloc) expr_binding_.unbind(obj.buf); |
93 | if (use_bc_alloc) release_bank_conflict_allocation(obj); |
94 | } |
95 | |
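| // Lowers a loop as: mov(var, init); label: <body>; add(var, var, 1); |
| // cmp.lt(var, bound); jmpi label. A per-loop end label, placed right |
| // before the increment, is maintained so `continue` can jump to it. |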
96 | void _visit(const for_t &obj) override { |
97 | auto scope = register_scope(); |
98 | auto var_op = scope.alloc_reg_data(obj.var.type()); |
99 | auto init_op = eval(obj.init, scope); |
100 | auto bound_op = eval(obj.bound, scope); |
101 | |
102 | ngen::Label loop_label; |
103 | loop_end_labels_.emplace_back(); |
104 | |
105 | host_->emov(1, var_op, init_op); |
106 | expr_binding_.bind(obj.var, var_op); |
107 | host_->mark(loop_label); |
108 | visit(obj.body); |
109 | |
110 | host_->mark(loop_end_labels_.back()); |
111 | loop_end_labels_.pop_back(); |
112 | |
113 | host_->eadd(1, var_op, var_op, ngen::Immediate(1)); |
114 | host_->ecmp(1 | host_->lt | host_->f0[0], var_op, bound_op); |
115 | host_->jmpi(1 | host_->f0[0], loop_label); |
116 | expr_binding_.unbind(obj.var); |
117 | } |
118 | |
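| // Dispatches function calls to the corresponding emitters: dpas/mad |
| // for multiplication, send for memory access, reorder/reduce/eltwise |
| // for register-level transforms, and the synchronization intrinsics. |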
119 | void _visit(const func_call_t &obj) override { |
120 | auto scope = register_scope(); |
121 | |
122 | auto &func = obj.func; |
123 | if (func.is<dpas_t>()) { |
124 | auto arg_ops = eval(obj.args, scope); |
125 | dpas(func.as<dpas_t>(), arg_ops, obj.attr); |
126 | } else if (func.is<mad_t>()) { |
127 | auto arg_ops = eval(obj.args, scope); |
128 | mad(scope, func.as<mad_t>(), arg_ops, obj.attr); |
129 | } else if (func.is<send_t>()) { |
130 | auto &send_func = func.as<send_t>(); |
131 | auto args = obj.args; |
132 | auto &mem_buf = send_t::arg_mem_buf(args); |
133 | auto &mask = send_t::arg_mask(args); |
134 | // If all channels are disabled, skip the message; for loads, zero out the payload first so it holds well-defined data. |
135 | if (all_of(mask, expr_t(false))) { |
136 | if (send_func.is_load() || send_func.is_load_2d()) { |
137 | auto reg_buf_op = ir_to_ngen_t<hw>::eval( |
138 | send_t::arg_reg_buf(args), scope); |
139 | zero_out_data_payload(send_func, reg_buf_op.reg_buf_data()); |
140 | } |
141 | return; |
142 | } |
143 | // If all channels are enabled, do not use mask. |
144 | if (all_of(mask, expr_t(true))) mask = expr_t(); |
145 | auto arg_ops = ir_to_ngen_t<hw>::eval(args, scope); |
146 | send(scope, func.as<send_t>(), mem_buf, arg_ops, obj.attr); |
147 | } else if (func.is<reorder_t>()) { |
148 | auto arg_ops = eval(obj.args, scope); |
149 | ir_assert(obj.attr.is_empty()) << "Unexpected attribute."; |
150 | reorder(scope, func.as<reorder_t>(), reorder_t::arg_src_buf(obj), |
151 | arg_ops); |
152 | } else if (func.is<reduce_t>()) { |
153 | auto arg_ops = eval(obj.args, scope); |
154 | ir_assert(obj.attr.is_empty()) << "Unexpected attribute."; |
155 | reduce(scope, func.as<reduce_t>(), arg_ops); |
156 | } else if (func.is<eltwise_t>()) { |
157 | auto &eltwise_func = func.as<eltwise_t>(); |
158 | auto arg_ops = eval(obj.args, scope); |
159 | eltwise(scope, eltwise_func, arg_ops); |
160 | } else if (func.is_equal(funcs::barrier_func())) { |
161 | barrier(obj.attr); |
162 | } else if (func.is_equal(funcs::barrier_wait_func())) { |
163 | barrier_wait(); |
164 | } else if (func.is_equal(funcs::signal_func())) { |
165 | signal(obj.attr); |
166 | } else if (func.is_equal(funcs::slm_fence_func())) { |
167 | slm_fence(obj.attr); |
168 | } else { |
169 | ir_error_not_expected() << object_t(obj); |
170 | } |
171 | } |
172 | |
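| // Lowers a vector condition into SIMD if/else/endif control flow. An |
| // `if (cond) continue;` body is special-cased into a single |
| // predicated jump to the loop end label (see try_emit_if_continue()). |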
173 | void _visit(const if_t &obj) override { |
174 | ir_assert(obj.cond.is<shuffle_t>()); |
175 | ir_assert(obj.cond.as<shuffle_t>().elems() == simd_size_); |
176 | |
177 | bool has_else = !obj.else_body.is_empty(); |
178 | auto scope = register_scope(); |
179 | auto cond_op = eval(obj.cond, scope); |
180 | |
181 | if (try_emit_if_continue(obj, cond_op)) return; |
182 | |
183 | ngen::Label l_else; |
184 | ngen::Label l_endif; |
185 | host_->if_(simd_size_ | cond_op.flag_register(), |
186 | has_else ? l_else : l_endif, l_endif); |
187 | visit(obj.body); |
188 | if (has_else) { |
189 | host_->else_(simd_size_, l_endif, l_endif); |
190 | host_->mark(l_else); |
191 | visit(obj.else_body); |
192 | } |
193 | host_->mark(l_endif); |
194 | host_->endif(simd_size_); |
195 | } |
196 | |
197 | void _visit(const let_t &obj) override { |
198 | if (obj.value.is_empty()) { |
199 | // External variable, must be already bound. |
200 | ir_assert(expr_binding_.is_bound(obj.var)) |
201 | << "Variable is not defined: " << obj.var; |
202 | visit(obj.body); |
203 | return; |
204 | } |
205 | |
206 | auto scope = register_scope(); |
207 | if (is_const(obj.value) || is_shuffle_const(obj.value) |
208 | || obj.var.type() != obj.value.type()) { |
209 | auto &var_type = obj.var.type(); |
210 | auto var_op = (var_type.is_bool()) |
211 | ? ngen_operand_t(scope.alloc_flag()) |
212 | : ngen_operand_t(scope.alloc_reg_data(var_type)); |
213 | eval(obj.value, scope, ngen_operand_t(var_op, var_type.elems())); |
214 | expr_binding_.bind(obj.var, var_op); |
215 | } else { |
216 | auto value_op = eval(obj.value, scope); |
217 | expr_binding_.bind(obj.var, value_op); |
218 | } |
219 | |
220 | auto var_op = expr_binding_.get(obj.var); |
221 | |
222 | // At this point the scope contains allocations for temporary |
223 | // expressions. We need to 1) query and later re-claim the allocation |
224 | // for the let variable in a new scope and 2) release the current scope |
225 | // allocations to reduce GRF consumption. |
226 | ngen::GRFRange var_grf_range; |
227 | ngen::Subregister var_sub; |
228 | |
229 | if (var_op.is_reg_data()) { |
230 | auto var_rd = var_op.reg_data(); |
231 | var_grf_range = scope.find_grf_range( |
232 | var_rd.getBase(), var_rd.getByteOffset()); |
233 | var_sub = scope.find_sub(var_rd.getBase(), var_rd.getByteOffset()); |
234 | } |
235 | |
236 | // Release the current scope allocations. |
237 | scope.clear(); |
238 | |
239 | // Claim the let variable allocation. |
240 | auto var_scope = register_scope(); |
241 | if (!var_grf_range.isInvalid()) { |
242 | var_scope.claim(var_grf_range); |
243 | } else if (!var_sub.isInvalid()) { |
244 | var_scope.claim(var_sub); |
245 | } |
246 | |
247 | visit(obj.body); |
248 | expr_binding_.unbind(obj.var); |
249 | } |
250 | |
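| // Handles stores to register buffers: the value is evaluated directly |
| // into the destination region described by the buffer, offset, type |
| // and stride, optionally predicated by the store mask. |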
251 | void _visit(const store_t &obj) override { |
252 | auto scope = register_scope(); |
253 | auto buf_op = eval(obj.buf, scope); |
254 | auto off = to_cpp<int>(obj.off); |
255 | auto mask_op = eval(obj.mask, scope); |
256 | |
257 | auto &type = obj.value.type(); |
258 | |
259 | int stride; |
260 | if (obj.has_default_stride()) { |
261 | stride = 1; |
262 | } else { |
263 | ir_assert(obj.stride % type.scalar().size() == 0); |
264 | stride = obj.stride / type.scalar().size(); |
265 | } |
266 | |
267 | ngen::InstructionModifier mod = type.elems(); |
268 | if (!mask_op.is_invalid()) mod |= mask_op.flag_register_mod(); |
269 | auto dst_rbd = buf_op.reg_buf_data().format( |
270 | off, to_ngen(type.scalar()), type.elems(), stride); |
271 | ngen_operand_t dst(dst_rbd, mod); |
272 | eval(obj.value, scope, dst, obj.fill_mask0 && !mask_op.is_invalid()); |
273 | } |
274 | |
275 | private: |
276 | ngen_register_scope_t register_scope() { |
277 | return ngen_register_scope_t(host_->ra_); |
278 | } |
279 | |
280 | #ifdef GEN_CONV_DEBUG |
281 | void check_bank_conflicts(const ngen::InstructionModifier &mod, |
282 | const ngen::RegData &_src0, const ngen::RegData &_src1, |
283 | const ngen::RegData &_src2, bool is_dpas = false) { |
284 | int esize = mod.getExecSize(); |
285 | int hw_simd = (hw >= ngen::HW::XeHPC ? 16 : 8); |
286 | auto shift = [](const ngen::RegData &rd, int exec_off) { |
287 | if (exec_off == 0 || rd.isNull()) return rd; |
288 | int type_size = ngen::getBytes(rd.getType()); |
289 | int w = (exec_off % rd.getWidth()); |
290 | int h = (exec_off / rd.getWidth()); |
291 | int off = rd.getByteOffset() |
292 | + (w * rd.getHS() + h * rd.getVS()) * type_size; |
293 | int grf_size = ngen::GRF::bytes(hw); |
294 | int shifted_base = rd.getBase() + off / grf_size; |
295 | int shifted_off = off % grf_size; |
296 | auto ret = rd; |
297 | ret.setBase(shifted_base); |
298 | ret.setOffset(ir_utils::safe_divide(shifted_off, type_size)); |
299 | return ret; |
300 | }; |
301 | for (int i = 0; i < esize; i += hw_simd) { |
302 | auto src0 = shift(_src0, i); |
303 | auto src1 = shift(_src1, i); |
304 | auto src2 = shift(_src2, i); |
305 | bool same_bank01 = ngen::Bundle::same_bank(hw, src0, src1); |
306 | bool same_bank02 = ngen::Bundle::same_bank(hw, src0, src2); |
307 | if (is_dpas) { |
308 | if (same_bank02) bank_conflicts_++; |
309 | } else { |
310 | if (same_bank01 && same_bank02) bank_conflicts_++; |
311 | if (ngen::Bundle::conflicts(hw, src0, src1) |
312 | || ngen::Bundle::conflicts(hw, src0, src2) |
313 | || ngen::Bundle::conflicts(hw, src1, src2)) { |
314 | bundle_conflicts_++; |
315 | } |
316 | } |
317 | } |
318 | } |
319 | #else |
320 | template <typename... ArgsT> |
321 | void check_bank_conflicts(const ArgsT &...) {} |
322 | #endif |
323 | |
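| // Bank conflict allocations are keyed by their attribute and shared |
| // between alloc statements; each use is reference-counted and the |
| // allocation is released once the last user is done. |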
324 | reg_buf_t create_bank_conflict_allocation(const alloc_t &alloc) { |
325 | auto &bc_attr = alloc.get_attr<bank_conflict_attr_t>(); |
326 | auto it = bc_allocations_.find(bc_attr); |
327 | if (it != bc_allocations_.end()) { |
328 | it->second.retain(); |
329 | return it->second.get_reg_buf(alloc.buf); |
330 | } |
331 | auto bca = bank_conflict_allocation_t::create( |
332 | host_->ra_, host_->regs_, bc_attr); |
333 | if (bca.is_empty()) return {}; |
334 | |
335 | auto ret = bc_allocations_.emplace(bc_attr, std::move(bca)); |
336 | return ret.first->second.get_reg_buf(alloc.buf); |
337 | } |
338 | |
339 | void release_bank_conflict_allocation(const alloc_t &alloc) { |
340 | auto &bc_attr = alloc.get_attr<bank_conflict_attr_t>(); |
341 | auto it = bc_allocations_.find(bc_attr); |
342 | ir_assert(it != bc_allocations_.end()); |
343 | it->second.release(alloc.buf); |
344 | if (it->second.refs() == 0) bc_allocations_.erase(bc_attr); |
345 | } |
346 | |
347 | void signal(const func_call_attr_t &attr) { |
348 | ngen::InstructionModifier mod; |
349 | if (!attr.is_empty()) |
350 | mod = mod | to_ngen(attr.as<instruction_modifier_attr_t>().mod); |
351 | host_->barriermsg(mod, host_->signal_header_); |
352 | } |
353 | |
354 | void barrier_wait() { host_->barrierwait(); } |
355 | |
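| // For SLM fences (and the fence inside barrier()) the scratch |
| // register written by slmfence is moved to the null register |
| // afterwards; reading the fence destination stalls the thread until |
| // the fence completes. |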
356 | void slm_fence(const func_call_attr_t &attr) { |
357 | auto scope = register_scope(); |
358 | auto tmp = scope.alloc(); |
359 | ngen::InstructionModifier mod; |
360 | if (!attr.is_empty()) |
361 | mod = mod | to_ngen(attr.as<instruction_modifier_attr_t>().mod); |
362 | |
363 | const int dwords = ngen::GRF::bytes(hw) / sizeof(int32_t); |
364 | host_->slmfence(mod, tmp, host_->r0); |
365 | host_->template mov<int32_t>(dwords, host_->null, tmp); |
366 | } |
367 | |
368 | void barrier(const func_call_attr_t &attr) { |
369 | auto scope = register_scope(); |
370 | auto tmp = scope.alloc(); |
371 | ngen::InstructionModifier mod; |
372 | if (!attr.is_empty()) |
373 | mod = mod | to_ngen(attr.as<instruction_modifier_attr_t>().mod); |
374 | |
375 | const int dwords = ngen::GRF::bytes(hw) / sizeof(int32_t); |
376 | host_->slmfence(mod, tmp, host_->r0); |
377 | host_->template mov<int32_t>(dwords, host_->null, tmp); |
378 | host_->barriermsg(mod, host_->signal_header_); |
379 | host_->barrierwait(); |
380 | } |
381 | |
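| // Emits dpas/dpasw/dp4a. src0 is either a register region or a zero |
| // immediate; in the latter case the null register (or an immediate |
| // zero for dp4a) is used so no accumulator input is read. |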
382 | void dpas(const dpas_t &dpas_func, const std::vector<ngen_operand_t> &args, |
383 | const func_call_attr_t &attr) { |
384 | auto dst = dpas_t::arg_dst(args).reg_buf_data(); |
385 | auto src1 = dpas_t::arg_src1(args).reg_buf_data(); |
386 | auto src2 = dpas_t::arg_src2(args).reg_buf_data(); |
387 | |
388 | if (dpas_func.is_dpasw) dst = dst.unpermute(); |
389 | |
390 | int esize = dpas_func.exec_size; |
391 | |
392 | ngen::RegData src0; |
393 | auto &src0_op = dpas_t::arg_src0(args); |
394 | if (!src0_op.is_immediate()) { |
395 | auto src0_rbd = src0_op.reg_buf_data().format( |
396 | 0, to_ngen(dpas_func.dst_type), esize, 1); |
397 | if (dpas_func.is_dpasw) src0_rbd = src0_rbd.unpermute(); |
398 | src0 = src0_rbd; |
399 | } else { |
400 | ir_assert(src0_op.is_immediate()); |
401 | ir_assert(to_cpp<int32_t>(src0_op.immediate()) == 0); |
402 | src0 = host_->null.retype(to_ngen(dpas_func.dst_type)); |
403 | } |
404 | |
405 | dst = dst.format(0, to_ngen(dpas_func.dst_type), esize, 1); |
406 | src1 = src1.format(0, to_ngen(dpas_func.src1_type), esize, 1); |
407 | int src2_width = (dpas_func.is_dp4a() ? 1 : esize); |
408 | int src2_stride = (dpas_func.is_dp4a() ? 0 : 1); |
409 | src2 = src2.format( |
410 | 0, to_ngen(dpas_func.src2_type), src2_width, src2_stride); |
411 | |
412 | ngen::InstructionModifier mod = esize; |
413 | if (!attr.is_empty()) |
414 | mod = mod | to_ngen(attr.as<instruction_modifier_attr_t>().mod); |
415 | check_bank_conflicts(mod, src0, src1, src2, /*is_dpas=*/true); |
416 | if (dpas_func.is_dpasw) { |
417 | host_->dpasw(mod, dpas_func.sdepth, dpas_func.rcount, dst, src0, |
418 | src1, src2); |
419 | } else if (dpas_func.is_dp4a()) { |
420 | if (src0.isNull()) { |
421 | host_->dp4a(mod, dst, 0, src1, src2); |
422 | } else { |
423 | host_->dp4a(mod, dst, src0, src1, src2); |
424 | } |
425 | } else { |
426 | host_->dpas(mod, dpas_func.sdepth, dpas_func.rcount, dst, src0, |
427 | src1, src2); |
428 | } |
429 | } |
430 | |
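| // Emits mad, or mul when src0 is a zero immediate. src1/src2 may be |
| // broadcast (stride 0); operand offsets are aligned to dst first, and |
| // a src1/src2 swap works around an f64 broadcast-src1 issue on XeHPC. |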
431 | void mad(ngen_register_scope_t &scope, const mad_t &mad_func, |
432 | const std::vector<ngen_operand_t> &args, |
433 | const func_call_attr_t &attr) { |
434 | auto dst = mad_t::arg_dst(args).reg_buf_data(); |
435 | auto src1 = mad_t::arg_src1(args).reg_buf_data(); |
436 | auto src2 = mad_t::arg_src2(args).reg_buf_data(); |
437 | |
438 | ngen::RegData src0; |
439 | auto &src0_op = mad_t::arg_src0(args); |
440 | if (!src0_op.is_immediate()) { |
441 | src0 = src0_op.reg_buf_data() |
442 | .format(0, to_ngen(mad_func.dst_type), |
443 | mad_func.exec_size) |
444 | .reg_data(); |
445 | } else { |
446 | ir_assert(src0_op.is_immediate()); |
447 | ir_assert(to_cpp<int32_t>(src0_op.immediate()) == 0); |
448 | src0 = host_->null; |
449 | src0.setType(to_ngen(mad_func.dst_type)); |
450 | } |
451 | |
452 | dst = dst.format(0, to_ngen(mad_func.dst_type), mad_func.exec_size); |
453 | |
454 | int src1_width = (mad_func.src1_stride == 0 ? 1 : mad_func.exec_size); |
455 | int src2_width = (mad_func.src2_stride == 0 ? 1 : mad_func.exec_size); |
456 | src1 = src1.format(0, to_ngen(mad_func.src1_type), src1_width, |
457 | mad_func.src1_stride); |
458 | src2 = src2.format(0, to_ngen(mad_func.src2_type), src2_width, |
459 | mad_func.src2_stride); |
460 | |
461 | ngen::InstructionModifier mod = mad_func.exec_size; |
462 | if (!attr.is_empty()) |
463 | mod = mod | to_ngen(attr.as<instruction_modifier_attr_t>().mod); |
464 | |
465 | check_bank_conflicts(mod, src0, src1, src2, /*is_dpas=*/false); |
466 | if (src0.isNull()) { |
467 | host_->mul(mod, dst, src1, src2); |
468 | } else { |
469 | ir_assert(dst.byte_offset() == src0.getByteOffset()) |
470 | << "dst/src0 must be aligned to the same GRF offset." ; |
471 | auto _src1 = src1; |
472 | auto _src2 = src2; |
473 | align_src_dst_offset(host_, scope, mod, dst, _src1, _src2); |
474 | // Workaround for sporadic f64 mad errors with broadcast src1 on XeHPC. |
475 | if (mad_func.dst_type == type_t::f64() |
476 | && _src1.reg_data().getHS() == 0 |
477 | && _src1.reg_data().getVS() == 0) { |
478 | host_->mad(mod, dst, src0, _src2, _src1); |
479 | } else { |
480 | host_->mad(mod, dst, src0, _src1, _src2); |
481 | } |
482 | } |
483 | } |
484 | |
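| // Fills the message payload with zeros, two GRFs at a time. Used for |
| // loads whose channels are (partially) masked off so that the payload |
| // holds well-defined data. |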
485 | void zero_out_data_payload( |
486 | const send_t &send_func, const reg_buf_data_t &rd) const { |
487 | type_t type = type_t::f32(); |
488 | int send_size = send_func.payload_size(); |
489 | int grf_size = ngen::GRF::bytes(hw); |
490 | int step = 2 * grf_size; |
491 | for (int i = 0; i < send_size; i += step) { |
492 | int exec_size = std::min(step, send_size - i) / type.size(); |
493 | auto sub_rd_mov = rd.format(i, to_ngen(type), exec_size).reg_data(); |
494 | ir_assert(math::is_pow2(exec_size)); |
495 | host_->emov(exec_size, sub_rd_mov, ngen::Immediate(0.0f)); |
496 | } |
497 | } |
498 | |
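| // Returns a dense GRF region covering the message payload. Prefetches |
| // have no payload; loads require an already dense buffer; for stores |
| // a strided payload is first copied into a dense temporary buffer. |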
499 | ngen::RegData send_maybe_make_dense_payload(ngen_register_scope_t &scope, |
500 | const send_t &send_func, const ngen_operand_t &op_buf) const { |
501 | if (send_func.is_prefetch() || send_func.is_prefetch_2d()) |
502 | return ngen::RegData(host_->null); |
503 | |
504 | auto &buf = op_buf.reg_buf_data(); |
505 | int size = send_func.payload_size(); |
506 | bool is_dense = buf.is_dense(size); |
507 | if (is_dense) return buf.reg_data(); |
508 | |
509 | if (send_func.is_load() || send_func.is_load_2d()) { |
510 | ir_error_not_expected() |
511 | << "Expected dense GRF region for load message." ; |
512 | return ngen::RegData(); |
513 | } |
514 | |
515 | ir_assert(send_func.is_store() || send_func.is_store_2d()); |
516 | |
517 | // Reorder buffer to a dense buffer for store. |
518 | int grf_size = ngen::GRF::bytes(hw); |
519 | int regs = utils::div_up(size, grf_size); |
520 | |
521 | auto tmp = scope.alloc_range(regs); |
522 | |
523 | int dwords = ngen::GRF::bytes(hw) / sizeof(int32_t); |
524 | int max_step = 2; |
525 | for (int i = 0; i < regs;) { |
526 | auto sub_buf = buf.format(i * grf_size); |
527 | int step = std::min(max_step, regs - i); |
528 | if (step > 1 && !sub_buf.is_dense(step * grf_size)) step = 1; |
529 | int esize = step * dwords; |
530 | auto src = sub_buf.subregister(0, ngen::DataType::ud)(1); |
531 | auto dst = tmp[i].ud(0)(1); |
532 | host_->emov(esize, dst, src); |
533 | i += step; |
534 | } |
535 | return tmp[0]; |
536 | } |
537 | |
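| // Emits a memory access: resolves the address model (SLM, BTI surface |
| // or A64 pointer), builds the execution mask, zeroes the payload of |
| // predicated loads and dispatches to send_impl_t for encoding. |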
538 | void send(ngen_register_scope_t &scope, const send_t &send_func, |
539 | const expr_t &mem_buf, const std::vector<ngen_operand_t> &args, |
540 | const func_call_attr_t &attr) const { |
541 | send_impl_t spec_impl(send_func); |
542 | auto &mem_off_op = send_t::arg_mem_off(args); |
543 | auto ®_buf_op = send_t::arg_reg_buf(args); |
544 | auto &mask_op = send_t::arg_mask(args); |
545 | |
546 | ngen::RegData mem_buf_rd; |
547 | int surf_bti = -1; |
548 | switch (send_func.address) { |
549 | case send_address_t::slm: break; |
550 | case send_address_t::bts: { |
551 | auto &buf_name = mem_buf.as<var_t>().name; |
552 | surf_bti = host_->getArgumentSurface(buf_name); |
553 | break; |
554 | } |
555 | case send_address_t::a64: { |
556 | auto &mem_buf_op = send_t::arg_mem_buf(args); |
557 | mem_buf_rd = mem_buf_op.reg_data(); |
558 | break; |
559 | } |
560 | default: ir_error_not_expected(); |
561 | } |
562 | ngen::InstructionModifier mod = send_func.nmasks(); |
563 | ir_assert(math::is_pow2(mod.getExecSize())); |
564 | if (!attr.is_empty()) |
565 | mod |= to_ngen(attr.as<instruction_modifier_attr_t>().mod); |
566 | if (!mask_op.is_invalid()) mod |= mask_op.flag_register_mod(); |
567 | |
568 | // Zero-out inactive channels. |
569 | if ((send_func.is_load() || send_func.is_load_2d()) |
570 | && mod.getPredCtrl() != ngen::PredCtrl::None) { |
571 | zero_out_data_payload(send_func, reg_buf_op.reg_buf_data()); |
572 | } |
573 | |
574 | // Emit send instruction. |
575 | auto rd = send_maybe_make_dense_payload(scope, send_func, reg_buf_op); |
576 | spec_impl.emit(host_, scope, mod, mem_buf_rd, surf_bti, |
577 | mem_off_op.reg_data(), rd); |
578 | } |
579 | |
580 | void reorder(ngen_register_scope_t &scope, const reorder_t &reorder_func, |
581 | const expr_t &src_buf, |
582 | const std::vector<ngen_operand_t> &args) const { |
583 | auto &src_op = reorder_t::arg_src_buf(args); |
584 | auto &dst_op = reorder_t::arg_dst_buf(args); |
585 | |
586 | reorder_impl_t reorder_impl(hw, reorder_func); |
587 | reorder_impl.emit( |
588 | host_, scope, src_op.reg_buf_data(), dst_op.reg_buf_data()); |
589 | } |
590 | |
591 | void reduce(ngen_register_scope_t &scope, const reduce_t &reduce_func, |
592 | const std::vector<ngen_operand_t> &args) const { |
593 | auto &src_op = reduce_t::arg_src_buf(args); |
594 | auto &dst_op = reduce_t::arg_dst_buf(args); |
595 | |
596 | reduce_impl_t reduce_impl(hw, reduce_func, simd_size_); |
597 | reduce_impl.emit( |
598 | host_, scope, src_op.reg_buf_data(), dst_op.reg_buf_data()); |
599 | } |
600 | |
601 | void eltwise(ngen_register_scope_t &scope, const eltwise_t &func, |
602 | const std::vector<ngen_operand_t> &args) { |
603 | int elems = to_cpp<int>(hw, eltwise_t::arg_elems(args)); |
604 | auto &data_op = eltwise_t::arg_data(args); |
605 | const auto &data_rd = data_op.reg_buf_data(); |
606 | |
607 | jit_eltwise_injector_f32<hw> inj(host_, func.alg_kind, func.alpha, |
608 | func.beta, func.scale, eu_count_); |
609 | auto scratch = scope.alloc_range(inj.preferred_scratch_regs()); |
610 | inj.set_scratch(scratch); |
611 | inj.prepare(); |
612 | |
613 | int grf_size = ngen::GRF::bytes(hw); |
614 | int f_size = sizeof(float); |
615 | int step = 2 * grf_size / f_size; |
616 | for (int i = 0; i < elems; i += step) { |
617 | ngen_register_scope_t i_scope(scope.register_allocator()); |
618 | step = std::min(step, elems - i); |
619 | step = utils::rnd_down_pow2(step); |
620 | int cur_elems = step; |
621 | auto rd = data_rd.format(i * f_size, ngen::DataType::f); |
622 | // Use temporary storage when needed to ensure: |
623 | // - Eltwise is applied to full registers |
624 | // - Data is aligned to a GRF boundary |
625 | if ((cur_elems * f_size) % grf_size != 0 || rd.byte_offset() != 0) { |
626 | int full_elems |
627 | = utils::rnd_up(cur_elems * f_size, grf_size) / f_size; |
628 | auto tmp = i_scope.alloc_reg_data(type_t::f32(full_elems)); |
629 | emit_reorder_1d_tile( |
630 | hw, host_, i_scope, cur_elems, rd, 1, tmp, 1); |
631 | inj.compute(ngen::GRFRange( |
632 | tmp.base(), full_elems * f_size / grf_size)); |
633 | emit_reorder_1d_tile( |
634 | hw, host_, i_scope, cur_elems, tmp, 1, rd, 1); |
635 | } else { |
636 | inj.compute(ngen::GRFRange( |
637 | rd.base(), cur_elems * f_size / grf_size)); |
638 | } |
639 | } |
640 | } |
641 | |
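| // Matches the `if (cond) continue;` pattern and emits it as a single |
| // predicated jump to the innermost loop end label. |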
642 | bool try_emit_if_continue(const if_t &obj, const ngen_operand_t &cond_op) { |
643 | if (!obj.else_body.is_empty()) return false; |
644 | auto *call = obj.body.as_ptr<func_call_t>(); |
645 | if (!call) return false; |
646 | if (!call->func.is_equal(funcs::continue_func())) return false; |
647 | |
648 | ir_assert(!loop_end_labels_.empty()) |
649 | << "Can't emit continue: no label found." ; |
650 | host_->jmpi(1 | cond_op.flag_register(), loop_end_labels_.back()); |
651 | return true; |
652 | } |
653 | |
654 | protected: |
655 | ngen_operand_t eval(const expr_t &e, ngen_register_scope_t &scope, |
656 | const ngen_operand_t &dst_operand = ngen_operand_t(), |
657 | bool fill_mask0 = false) const { |
658 | expr_evaluator_t<hw> expr_evaluator(host_, expr_binding_, scope); |
659 | return expr_evaluator.eval(e, dst_operand, fill_mask0); |
660 | } |
661 | |
662 | std::vector<ngen_operand_t> eval(const std::vector<expr_t> &exprs, |
663 | ngen_register_scope_t &scope) const { |
664 | expr_evaluator_t<hw> expr_evaluator(host_, expr_binding_, scope); |
665 | return expr_evaluator.eval(exprs); |
666 | } |
667 | |
668 | private: |
669 | ir_kernel_t<hw> *host_; |
670 | expr_binding_t expr_binding_; |
671 | int simd_size_; |
672 | int eu_count_; |
673 | |
674 | std::vector<ngen::Label> loop_end_labels_; |
675 | |
676 | #ifdef GEN_CONV_DEBUG |
677 | int bank_conflicts_ = 0; |
678 | int bundle_conflicts_ = 0; |
679 | #endif |
680 | |
681 | object_map_t<alloc_attr_t, bank_conflict_allocation_t> bc_allocations_; |
682 | }; |
683 | |
684 | // Evaluates expression by emitting instructions with nGEN. |
685 | template <ngen::HW hw> |
686 | class expr_evaluator_t : public ir_visitor_t { |
687 | public: |
688 | expr_evaluator_t(ir_kernel_t<hw> *host, const expr_binding_t &expr_binding, |
689 | ngen_register_scope_t &scope) |
690 | : host_(host), expr_binding_(expr_binding), scope_(scope) {} |
691 | |
692 | bool is_int_up_convert(const expr_t &e, type_t &type) const { |
693 | auto it = int_up_converts_.find(e); |
694 | if (it == int_up_converts_.end()) return false; |
695 | type = it->second; |
696 | return true; |
697 | } |
698 | |
699 | // If `dst_operand` is not empty, use its pre-allocated location for the |
700 | // result. |
701 | ngen_operand_t eval(const expr_t &e, |
702 | const ngen_operand_t &dst_operand = ngen_operand_t(), |
703 | bool fill_mask0 = false) { |
704 | if (!dst_operand.is_invalid()) { |
705 | ir_assert(dst_operand.mod().getExecSize() != 0); |
706 | } |
707 | if (expr_binding_.is_bound(e)) { |
708 | if (!dst_operand.is_invalid()) { |
709 | auto bind = expr_binding_.get(e); |
710 | if (fill_mask0) { |
711 | ir_assert(!bind.is_immediate()); |
712 | host_->sel(dst_operand.mod(), dst_operand.reg_data(), |
713 | bind.reg_data(), 0); |
714 | } else { |
715 | host_->emov(dst_operand.mod(), dst_operand, bind); |
716 | } |
717 | return dst_operand; |
718 | } |
719 | } else { |
720 | if (dst_operand.is_invalid()) { |
721 | visit(e); |
722 | } else if (!fill_mask0) { |
723 | expr_binding_.bind_dst(e, dst_operand); |
724 | visit(e); |
725 | } else { |
726 | auto op = eval(e); |
727 | ir_assert(!op.is_immediate()); |
728 | host_->sel(dst_operand.mod(), dst_operand.reg_data(), |
729 | op.reg_data(), 0); |
730 | } |
731 | } |
732 | |
733 | return expr_binding_.get(e, /*allow_empty=*/true); |
734 | } |
735 | |
736 | std::vector<ngen_operand_t> eval(const std::vector<expr_t> &exprs) { |
737 | std::vector<ngen_operand_t> ret; |
738 | for (auto &e : exprs) { |
739 | if (!expr_binding_.is_bound(e)) visit(e); |
740 | ret.push_back(expr_binding_.get(e)); |
741 | } |
742 | return ret; |
743 | } |
744 | |
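| // Boolean AND is evaluated by chaining the predicate of the first |
| // operand into the evaluation of the second. Other binary operations |
| // evaluate both operands (pre-allocating strided regions for 64-bit |
| // results when needed) and are emitted via ebinary(). |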
745 | void _visit(const binary_op_t &obj) override { |
746 | auto dst_op = alloc_dst_op(obj); |
747 | auto mod = dst_op.mod(); |
748 | |
749 | switch (obj.op_kind) { |
750 | case op_kind_t::_and: { |
751 | if (obj.type.is_bool()) { |
752 | eval(obj.a, dst_op); |
753 | eval(obj.b, |
754 | ngen_operand_t( |
755 | dst_op, mod | dst_op.flag_register_mod())); |
756 | break; |
757 | } |
758 | // else fall through to the default label. |
759 | } |
760 | default: { |
761 | // Some cases require pre-allocated register regions with |
762 | // special strides for a/b. |
763 | auto a_out_op = maybe_alloc_strided_op(obj.type, obj.a); |
764 | auto b_out_op = maybe_alloc_strided_op(obj.type, obj.b); |
765 | auto src0_op = eval(obj.a, a_out_op); |
766 | auto src1_op = eval(obj.b, b_out_op); |
767 | |
768 | // XXX: (q x d) case is not supported. Try to downgrade it to |
769 | // (d x d) based on the previous assignments. |
770 | if (obj.op_kind == op_kind_t::_mul && obj.a.type().is_x64()) { |
771 | type_t orig_type; |
772 | if (is_int_up_convert(obj.a, orig_type)) { |
773 | src0_op = src0_op.reinterpret(orig_type); |
774 | // XXX: sync workaround is to fix an issue with |
775 | // mul(q, d, d) instruction on XeHP. For some reason |
776 | // the result is incorrect when dst and src0 are |
777 | // accessed from the same register. |
778 | host_->sync(ngen::SyncFunction::nop, |
779 | ngen::SWSB<uint64_t>(1)); |
780 | } else { |
781 | ir_error_not_expected(); |
782 | } |
783 | } |
784 | ebinary(obj, mod, dst_op, src0_op, src1_op); |
785 | break; |
786 | } |
787 | } |
788 | |
789 | bind(obj, dst_op); |
790 | } |
791 | |
792 | void _visit(const bool_imm_t &obj) override { |
793 | // Scalar booleans must never be directly lowered: |
794 | // - Booleans are mapped to flag registers |
795 | // - A flag register stores a vector of booleans |
796 | // - All boolean values in IR must be expressed via shuffle_t objects |
797 | // - _visit(const shuffle_t &) must properly handle the vector of |
798 | // booleans -> flag register lowering |
799 | ir_error_not_expected(); |
800 | } |
801 | |
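| // Casts: pointer <-> u64 and integer down-conversions are handled as |
| // plain moves/bitwise reinterprets; everything else is an emov with |
| // optional saturation. Integer up-conversions are remembered so that |
| // 64-bit multiplication can be downgraded later (see binary_op_t). |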
802 | void _visit(const cast_t &obj) override { |
803 | auto &from_type = obj.expr.type(); |
804 | auto &to_type = obj.type; |
805 | |
806 | ir_assert(from_type != to_type) << "Equal types are not expected."; |
807 | |
808 | if (is_const(obj.expr) && !to_type.is_bool()) { |
809 | bind(obj, to_ngen(obj.expr, to_type)); |
810 | return; |
811 | } |
812 | |
813 | auto dst_op = alloc_dst_op(obj); |
814 | |
815 | // Handle ptr -> u64 and u64 -> ptr casts. |
816 | if (utils::one_of(obj.type, type_t::u64(), type_t::byte_ptr()) |
817 | && utils::one_of( |
818 | obj.expr.type(), type_t::u64(), type_t::byte_ptr())) { |
819 | eval(obj.expr, dst_op); |
820 | bind(obj, dst_op); |
821 | return; |
822 | } |
823 | |
824 | // Handle integer (down-)conversion, assume bitwise equality in this |
825 | // case. Examples: d <-> ud, d -> w, q -> d. |
826 | bool is_int_convert = from_type.is_scalar() && to_type.is_scalar() |
827 | && from_type.is_int() && to_type.is_int(); |
828 | bool is_int_down_convert |
829 | = is_int_convert && from_type.size() >= to_type.size(); |
830 | bool is_int_up_convert |
831 | = is_int_convert && from_type.size() < to_type.size(); |
832 | if (is_int_down_convert) { |
833 | eval(obj.expr, dst_op.reinterpret(from_type)); |
834 | bind(obj, dst_op); |
835 | return; |
836 | } |
837 | |
838 | auto expr_op = eval(obj.expr); |
839 | auto mod = dst_op.mod(); |
840 | if (obj.saturate) mod |= host_->sat; |
841 | host_->emov(mod, dst_op, expr_op); |
842 | if (is_int_up_convert) int_up_converts_.emplace(obj, from_type); |
843 | bind(obj, dst_op); |
844 | } |
845 | |
846 | void _visit(const float_imm_t &obj) override { bind(obj, to_ngen(obj)); } |
847 | |
848 | void _visit(const int_imm_t &obj) override { bind(obj, to_ngen(obj)); } |
849 | |
850 | void _visit(const load_t &obj) override { |
851 | auto &type = obj.type; |
852 | auto buf_op = eval(obj.buf); |
853 | auto off_op = eval(obj.off); |
854 | int stride; |
855 | if (obj.has_default_stride()) { |
856 | stride = 1; |
857 | } else { |
858 | ir_assert(obj.stride % type.scalar().size() == 0); |
859 | stride = obj.stride / type.scalar().size(); |
860 | } |
861 | auto load_rbd |
862 | = buf_op.reg_buf_data().format(to_cpp<int>(off_op.immediate()), |
863 | to_ngen(type.scalar()), type.elems(), stride); |
864 | bind(obj, load_rbd); |
865 | } |
866 | |
867 | void _visit(const ptr_t &obj) override { |
868 | auto base_op = eval(obj.base); |
869 | |
870 | if (is_zero(obj.off)) { |
871 | bind(obj, base_op); |
872 | return; |
873 | } |
874 | |
875 | ir_assert(base_op.is_reg_buf_data()); |
876 | |
877 | int off = to_cpp<int>(obj.off); |
878 | bind(obj, base_op.reg_buf_data().format(off, ngen::DataType::ub)); |
879 | } |
880 | |
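| // Shuffle lowering: constant boolean shuffles become flag register |
| // masks, broadcasts reuse the scalar operand, and two peepholes try |
| // to express the result as a strided region or packed :uv/:v |
| // immediates before falling back to per-chunk moves. |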
881 | void _visit(const shuffle_t &obj) override { |
882 | int elems = obj.elems(); |
883 | if (obj.type.is_bool() && is_shuffle_const(obj)) { |
884 | auto dst_op = alloc_dst_op(obj); |
885 | auto e_shuffle = expr_t(obj); |
886 | ir_assert(dst_op.is_flag_register()) << e_shuffle; |
887 | ir_assert(!dst_op.is_negated()) << e_shuffle; |
888 | uint16_t flag_mask = 0; |
889 | for (int i = elems - 1; i >= 0; i--) { |
890 | flag_mask <<= 1; |
891 | flag_mask |= (to_cpp<bool>(e_shuffle[i]) ? 1 : 0); |
892 | } |
893 | if (dst_op.mod().getPredCtrl() == ngen::PredCtrl::None) { |
894 | host_->emov(1, dst_op, ngen::Immediate(flag_mask)); |
895 | } else { |
896 | ir_assert(dst_op.mod().getFlagReg() == dst_op.flag_register()); |
897 | host_->and_(1, dst_op.flag_register(), dst_op.flag_register(), |
898 | ngen::Immediate(flag_mask)); |
899 | } |
900 | bind(obj, dst_op); |
901 | return; |
902 | } |
903 | |
904 | if (obj.is_broadcast()) { |
905 | if (obj.type.is_bool()) { |
906 | auto dst_op = alloc_dst_op(obj); |
907 | eval(obj.vec[0], dst_op); |
908 | bind(obj, dst_op); |
909 | } else { |
910 | auto scalar_op = eval(obj.vec[0]); |
911 | bind(obj, scalar_op); |
912 | } |
913 | return; |
914 | } |
915 | |
916 | if (try_region_peephole(obj)) return; |
917 | if (try_packed_int_peephole(obj)) return; |
918 | |
919 | // tuples: <offset, length, idx> |
920 | std::vector<std::tuple<int, int, int>> chunks; |
921 | for (int i = 0; i < elems; i++) { |
922 | int idx = obj.idx[i]; |
923 | if (chunks.empty() || std::get<2>(chunks.back()) != idx) { |
924 | chunks.emplace_back(i, 1, idx); |
925 | } else { |
926 | std::get<1>(chunks.back())++; |
927 | } |
928 | } |
929 | |
930 | auto dst_op = alloc_dst_op(obj); |
931 | for (auto &chunk : chunks) { |
932 | int off = std::get<0>(chunk); |
933 | int length = std::get<1>(chunk); |
934 | int idx = std::get<2>(chunk); |
935 | // Split length into powers of two. |
936 | while (length > 0) { |
937 | int exec_size = (1 << math::ilog2q(length)); |
938 | auto chunk_op = dst_op.sub_reg_data(off, exec_size); |
939 | eval(obj.vec[idx], ngen_operand_t(chunk_op, exec_size)); |
940 | length -= exec_size; |
941 | off += exec_size; |
942 | } |
943 | } |
944 | bind(obj, dst_op); |
945 | } |
946 | |
947 | void _visit(const ternary_op_t &obj) override { |
948 | switch (obj.op_kind) { |
949 | case op_kind_t::_dp4a: |
950 | case op_kind_t::_add3: |
951 | case op_kind_t::_mad: { |
952 | auto dst_op = alloc_dst_op(obj); |
953 | auto mod = dst_op.mod(); |
954 | auto src0_op = eval(obj.a); |
955 | auto src1_op = eval(obj.b); |
956 | auto src2_op = eval(obj.c); |
957 | if (obj.op_kind == op_kind_t::_dp4a) { |
958 | host_->edp4a(mod, dst_op, src0_op, src1_op, src2_op); |
959 | } else if (obj.op_kind == op_kind_t::_add3) { |
960 | host_->eadd3(mod, dst_op, src0_op, src1_op, src2_op); |
961 | } else { |
962 | host_->emad(mod, dst_op, src0_op, src1_op, src2_op); |
963 | } |
964 | bind(obj, dst_op); |
965 | break; |
966 | } |
967 | default: ir_error_not_expected(); |
968 | } |
969 | } |
970 | |
971 | void _visit(const unary_op_t &obj) override { |
972 | ir_assert(obj.op_kind == op_kind_t::_minus); |
973 | ngen_operand_t a_op; |
974 | a_op = try_process_negated_flags(obj); |
975 | if (a_op.is_invalid()) a_op = eval(obj.a); |
976 | bind(obj, -a_op); |
977 | } |
978 | |
979 | void _visit(const var_t &obj) override { |
980 | ir_assert(expr_binding_.is_bound(obj)) |
981 | << "Variable is not defined: " << expr_t(obj); |
982 | } |
983 | |
984 | private: |
985 | ngen_operand_t alloc_dst_op(const expr_t &e) { |
986 | ir_assert(!expr_binding_.is_bound(e)) << "Already evaluated: " << e; |
987 | if (expr_binding_.is_dst_bound(e)) return expr_binding_.get_dst(e); |
988 | |
989 | // Expression is not bound yet, allocate new storage and bind. |
990 | ngen_operand_t op; |
991 | if (e.type().is_bool()) { |
992 | op = ngen_operand_t(scope_.alloc_flag(), e.type().elems()); |
993 | } else { |
994 | op = ngen_operand_t( |
995 | scope_.alloc_reg_data(e.type()), e.type().elems()); |
996 | } |
997 | expr_binding_.bind_dst(e, op); |
998 | return op; |
999 | } |
1000 | |
1001 | // Pre-allocates a strided register region for expression `e` if needed. |
1002 | ngen_operand_t maybe_alloc_strided_op( |
1003 | const type_t &res_type, const expr_t &e) { |
1004 | // Need q-strided region for `e` if res_type is q/uq and `e` is of a |
1005 | // sub-q data type and not a scalar. |
1006 | if (e.type().is_scalar()) return ngen_operand_t(); |
1007 | if (!utils::one_of(res_type.scalar(), type_t::s64(), type_t::u64())) |
1008 | return ngen_operand_t(); |
1009 | if (utils::one_of(e.type().scalar(), type_t::s64(), type_t::u64())) |
1010 | return ngen_operand_t(); |
1011 | |
1012 | auto *shuffle = e.as_ptr<shuffle_t>(); |
1013 | if (shuffle && shuffle->is_broadcast()) return ngen_operand_t(); |
1014 | |
1015 | return ngen_operand_t( |
1016 | scope_.alloc_reg_data(e.type(), res_type.scalar().size()), |
1017 | e.type().elems()); |
1018 | } |
1019 | |
1020 | void bind(const expr_t &e, const ngen_operand_t &op) { |
1021 | if (!expr_binding_.is_dst_bound(e)) { |
1022 | expr_binding_.bind(e, op); |
1023 | return; |
1024 | } |
1025 | auto dst_op = expr_binding_.get_dst(e); |
1026 | if (dst_op == op) { |
1027 | expr_binding_.bind(e, op); |
1028 | return; |
1029 | } |
1030 | // Expression is already bound, move to the location it was bound to. |
1031 | // This is required for immediate values - they are bound as is but |
1032 | // sometimes we need them to be moved to registers. |
1033 | host_->emov(dst_op.mod(), dst_op, op); |
1034 | expr_binding_.bind(e, dst_op); |
1035 | } |
1036 | |
1037 | void ebinary(const binary_op_t &obj, const ngen::InstructionModifier &mod, |
1038 | const ngen_operand_t &_dst, const ngen_operand_t &_src0, |
1039 | const ngen_operand_t &_src1) { |
1040 | auto dst = _dst; |
1041 | auto src0 = _src0; |
1042 | auto src1 = _src1; |
1043 | align_src_dst_offset(host_, scope_, mod, dst, src0, src1); |
1044 | switch (obj.op_kind) { |
1045 | case op_kind_t::_add: host_->eadd(mod, dst, src0, src1); break; |
1046 | case op_kind_t::_sub: host_->eadd(mod, dst, src0, -src1); break; |
1047 | case op_kind_t::_mul: host_->emul(mod, dst, src0, src1); break; |
1048 | case op_kind_t::_div: host_->ediv(mod, dst, src0, src1); break; |
1049 | case op_kind_t::_mod: host_->emod(mod, dst, src0, src1); break; |
1050 | case op_kind_t::_shl: host_->eshl(mod, dst, src0, src1); break; |
1051 | case op_kind_t::_shr: host_->eshr(mod, dst, src0, src1); break; |
1052 | case op_kind_t::_min: host_->emin(mod, dst, src0, src1); break; |
1053 | case op_kind_t::_max: host_->emax(mod, dst, src0, src1); break; |
1054 | case op_kind_t::_ge: |
1055 | case op_kind_t::_gt: |
1056 | case op_kind_t::_le: |
1057 | case op_kind_t::_lt: |
1058 | case op_kind_t::_eq: |
1059 | case op_kind_t::_ne: { |
1060 | ir_assert(!dst.is_negated()) << "Destination can't be negated."; |
1061 | ngen::InstructionModifier cmp_mod = mod; |
1062 | cmp_mod |= cmp_op_to_ngen(obj.op_kind); |
1063 | cmp_mod |= dst.flag_register(); |
1064 | host_->ecmp(cmp_mod, src0, src1); |
1065 | break; |
1066 | } |
1067 | case op_kind_t::_and: host_->eand(mod, dst, src0, src1); break; |
1068 | case op_kind_t::_prelu: { |
1069 | int grf_size = ngen::GRF::bytes(hw); |
1070 | int esize = mod.getExecSize(); |
1071 | int regs = utils::div_up(esize * int(sizeof(float)), grf_size); |
1072 | auto temp = scope_.alloc_reg_buf_data(regs).format( |
1073 | 0, ngen::DataType::f, esize); |
1074 | host_->emul(mod, temp, dst, src1); |
1075 | host_->csel(mod | host_->le, dst.reg_data(), temp, |
1076 | dst.reg_data(), dst.reg_data()); |
1077 | break; |
1078 | } |
1079 | default: |
1080 | ir_error_not_expected() |
1081 | << "Unknown kind: " << to_string(obj.op_kind); |
1082 | } |
1083 | } |
1084 | |
1085 | struct conjunct_t { |
1086 | conjunct_t(op_kind_t op, ngen_operand_t a, ngen_operand_t b) |
1087 | : op_(op), a_(std::move(a)), b_(std::move(b)) {} |
1088 | op_kind_t op_; |
1089 | ngen_operand_t a_, b_; |
1090 | }; |
1091 | |
1092 | void split_by_and(const expr_t &e, std::vector<conjunct_t> &cv, type_t ty) { |
1093 | if (auto bin = e.as_ptr<binary_op_t>()) { |
1094 | if (bin->op_kind == op_kind_t::_and) { |
1095 | split_by_and(bin->a, cv, ty); |
1096 | split_by_and(bin->b, cv, ty); |
1097 | } else |
1098 | cv.emplace_back(bin->op_kind, eval(bin->a), eval(bin->b)); |
1099 | } else { |
1100 | auto cast = cast_t::make(ty, e); |
1101 | cv.emplace_back(op_kind_t::undef, eval(cast), ngen_operand_t()); |
1102 | } |
1103 | } |
1104 | |
1105 | ngen_operand_t try_process_negated_flags(const expr_t &e) { |
1106 | ngen_operand_t retn; |
1107 | auto cast = e.as<unary_op_t>().a.as_ptr<cast_t>(); |
1108 | if (cast && cast->expr.type().is_bool()) { |
1109 | ngen_operand_t flags(scope_.alloc_flag(), e.type().elems()); |
1110 | retn = alloc_dst_op(e); |
1111 | auto mod = retn.mod(); |
1112 | auto ar_op = [&](ngen::InstructionModifier m, const conjunct_t &c) { |
1113 | if (c.op_ != op_kind_t::undef) |
1114 | host_->ecmp(m | cmp_op_to_ngen(c.op_), retn, c.a_, c.b_); |
1115 | else |
1116 | host_->emov(m, retn, -c.a_); |
1117 | }; |
1118 | std::vector<conjunct_t> cv; |
1119 | split_by_and(cast->expr, cv, cast->type); |
1120 | ar_op(mod, cv[0]); |
1121 | mod |= flags.flag_register(); |
1122 | for (int i = 1; i < int(cv.size()); i++) |
1123 | ar_op(mod, cv[i]); |
1124 | retn = -retn; |
1125 | } |
1126 | return retn; |
1127 | } |
1128 | |
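| // Recognizes shuffles of equally typed loads laid out as [xxyy] or |
| // [xyxy] and binds them to a single strided register region instead |
| // of emitting copies. |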
1129 | bool try_region_peephole(const shuffle_t &obj) { |
1130 | int elems = obj.elems(); |
1131 | if (elems % 2 != 0) return false; |
1132 | |
1133 | std::vector<ngen_operand_t> vec(obj.vec.size()); |
1134 | ngen::DataType data_type = ngen::DataType::invalid; |
1135 | for (size_t i = 0; i < vec.size(); i++) { |
1136 | if (!obj.vec[i].is<load_t>()) return false; |
1137 | vec[i] = eval(obj.vec[i]); |
1138 | ir_assert(vec[i].is_reg_buf_data()) << obj.vec[i]; |
1139 | auto &rbd = vec[i].reg_buf_data(); |
1140 | if (data_type == ngen::DataType::invalid) { |
1141 | data_type = rbd.type(); |
1142 | continue; |
1143 | } |
1144 | if (data_type != rbd.type()) return false; |
1145 | } |
1146 | |
1147 | int grf_size = ngen::GRF::bytes(hw); |
1148 | auto diff_bytes = [&](const ngen_operand_t &a, |
1149 | const ngen_operand_t &b) { |
1150 | auto a_rd = a.reg_data(); |
1151 | auto b_rd = b.reg_data(); |
1152 | int a_off = a_rd.getBase() * grf_size + a_rd.getByteOffset(); |
1153 | int b_off = b_rd.getBase() * grf_size + b_rd.getByteOffset(); |
1154 | return b_off - a_off; |
1155 | }; |
1156 | |
1157 | int type_size = ngen::getBytes(data_type); |
1158 | int stride_bytes = diff_bytes(vec[0], vec[1]); |
1159 | if (stride_bytes < 0 || stride_bytes % type_size != 0) return false; |
1160 | |
1161 | // Pattern 1: [xxyy] |
1162 | auto is_xxyy = [&]() { |
1163 | for (int i = 0; i < elems / 2; i++) { |
1164 | if (obj.idx[i] != 0) return false; |
1165 | if (obj.idx[i + elems / 2] != 1) return false; |
1166 | } |
1167 | return true; |
1168 | }; |
1169 | if (is_xxyy()) { |
1170 | auto &rbd = vec[0].reg_buf_data(); |
1171 | auto rd = rbd.reg_data(); |
1172 | int regs = utils::div_up(stride_bytes * 2, grf_size); |
1173 | if (regs > 2) return false; |
1174 | rd.setRegion(stride_bytes / type_size, elems / 2, 0); |
1175 | reg_buf_t rb(hw, ngen::GRFRange(rd.getBase(), regs)); |
1176 | bind(obj, reg_buf_data_t(rb, rd)); |
1177 | return true; |
1178 | } |
1179 | |
1180 | // Pattern 2: [xyxy] |
1181 | auto is_xyxy = [&]() { |
1182 | for (int i = 0; i < elems / 2; i++) { |
1183 | if (obj.idx[i] != i) return false; |
1184 | if (obj.idx[i] != obj.idx[i + elems / 2]) return false; |
1185 | if (i > 0 && diff_bytes(vec[i - 1], vec[i]) != stride_bytes) |
1186 | return false; |
1187 | } |
1188 | return true; |
1189 | }; |
1190 | if (is_xyxy()) { |
1191 | auto &rbd = vec[0].reg_buf_data(); |
1192 | auto rd = rbd.reg_data(); |
1193 | int regs = utils::div_up(stride_bytes * elems / 2, grf_size); |
1194 | if (regs > 2) return false; |
1195 | rd.setRegion(0, elems / 2, stride_bytes / type_size); |
1196 | reg_buf_t rb(hw, ngen::GRFRange(rd.getBase(), regs)); |
1197 | bind(obj, reg_buf_data_t(rb, rd)); |
1198 | return true; |
1199 | } |
1200 | |
1201 | return false; |
1202 | } |
1203 | |
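| // Materializes small constant integer vectors via packed :uv/:v |
| // immediates: the values are expressed as packed * factor + vec_min |
| // and reconstructed with at most one mul and one add per result. |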
1204 | bool try_packed_int_peephole(const shuffle_t &obj) { |
1205 | if (!obj.type.is_x32()) return false; |
1206 | if (!utils::one_of(obj.elems(), 8, 16)) return false; |
1207 | |
1208 | int64_t int_min = std::numeric_limits<int>::min(); |
1209 | int64_t int_max = std::numeric_limits<int>::max(); |
1210 | int vec_size = (int)obj.vec.size(); |
1211 | std::vector<int> vec(vec_size); |
1212 | for (int i = 0; i < vec_size; i++) { |
1213 | if (!is_const(obj.vec[i])) return false; |
1214 | int64_t value = to_cpp<int64_t>(obj.vec[i]); // Keep 64 bits for the range check below. |
1215 | if (value < int_min || value > int_max) return false; |
1216 | vec[i] = (int)value; |
1217 | } |
1218 | |
1219 | const int esize = 8; |
1220 | |
1221 | auto half_same = [&](int off) { |
1222 | return std::equal(obj.idx.begin() + off + 1, |
1223 | obj.idx.begin() + off + esize, obj.idx.begin() + off); |
1224 | }; |
1225 | // If true, the case is too trivial for :v/:uv to justify the overhead |
1226 | if (half_same(0) && half_same(esize % obj.elems())) return false; |
1227 | |
1228 | int vec_min = *std::min_element(vec.begin(), vec.end()); |
1229 | int vec_max = *std::max_element(vec.begin(), vec.end()); |
1230 | |
1231 | int factor = vec_max - vec_min; |
1232 | for (int i = 0; i < vec_size; i++) |
1233 | factor = math::gcd(vec[i] - vec_min, factor); |
1234 | |
1235 | // XXX: Bail out on factors outside the s16 range due to an emulation |
1236 | // limitation: vector multiplication by a dword constant is not implemented yet. |
1237 | int64_t s16_min = std::numeric_limits<int16_t>::min(); |
1238 | int64_t s16_max = std::numeric_limits<int16_t>::max(); |
1239 | if (factor < s16_min || factor > s16_max) return false; |
1240 | |
1241 | auto check_range = [&](int f, int m, int a, int b) { |
1242 | for (int i = 0; i < vec_size; i++) { |
1243 | int d = (vec[i] - m) / f; |
1244 | if (d < a || d > b) return false; |
1245 | } |
1246 | return true; |
1247 | }; |
1248 | |
1249 | bool use_uv = false, use_v = false; |
1250 | for (int f : {1, factor, -factor}) { |
1251 | use_uv = check_range(f, vec_min, 0, 15); |
1252 | use_v = check_range(f, vec_min, -8, 7); |
1253 | if (use_uv || use_v) { |
1254 | factor = f; |
1255 | break; |
1256 | } |
1257 | } |
1258 | if (!use_uv && !use_v) return false; |
1259 | if (vec_min % factor == 0) { |
1260 | bool new_use_uv = check_range(factor, 0, 0, 15); |
1261 | bool new_use_v = check_range(factor, 0, -8, 7); |
1262 | if (new_use_uv || new_use_v) { |
1263 | vec_min = 0; |
1264 | use_uv = new_use_uv; |
1265 | use_v = new_use_v; |
1266 | } |
1267 | } |
1268 | |
1269 | auto set_packed = [](uint32_t &packed, int8_t value, int idx) { |
1270 | uint32_t v = (value >= 0 ? value : ((value & 0x7) | 0x8)); |
1271 | packed = packed | (v << idx * 4); |
1272 | }; |
1273 | |
1274 | auto dst = alloc_dst_op(obj); |
1275 | auto &dst_rbd = dst.reg_buf_data(); |
1276 | int dst_stride = dst_rbd.hs(); |
1277 | // No more than one temporary register is required here. |
1278 | auto storage = scope_.alloc_reg_buf_data(1); |
1279 | |
1280 | auto w_type = (use_uv) ? ngen::DataType::uw : ngen::DataType::w; |
1281 | for (int i = 0; i < obj.elems(); i += esize) { |
1282 | uint32_t packed = 0; |
1283 | for (int j = 0; j < esize; j++) |
1284 | set_packed(packed, (vec[obj.idx[i + j]] - vec_min) / factor, j); |
1285 | auto tmp = storage.format(i * sizeof(uint16_t), w_type, esize); |
1286 | host_->emov(esize, tmp, |
1287 | (use_uv) ? ngen::Immediate::uv(packed) |
1288 | : ngen::Immediate::v(packed)); |
1289 | } |
1290 | auto d = dst_rbd.format( |
1291 | 0, ngen::DataType::invalid, obj.elems(), dst_stride); |
1292 | auto tmp = storage.format(0, w_type, obj.elems()); |
1293 | if (factor != 1) { |
1294 | host_->emul(obj.elems(), d, tmp, ngen::Immediate(factor)); |
1295 | } |
1296 | if (factor == 1 || vec_min != 0) { |
1297 | host_->eadd(obj.elems(), d, (factor == 1) ? tmp : d, |
1298 | ngen::Immediate(vec_min)); |
1299 | } |
1300 | bind(obj, dst); |
1301 | return true; |
1302 | } |
1303 | |
1304 | ir_kernel_t<hw> *host_; |
1305 | expr_binding_t expr_binding_; |
1306 | ngen_register_scope_t &scope_; |
1307 | |
1308 | object_eq_map_t<expr_t, type_t> int_up_converts_; |
1309 | }; |
1310 | |
1311 | } // namespace jit |
1312 | } // namespace gpu |
1313 | } // namespace impl |
1314 | } // namespace dnnl |
1315 | |
1316 | #endif |
1317 | |