/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_CODEGEN_REORDER_HPP
#define GPU_JIT_CODEGEN_REORDER_HPP

#include "gpu/jit/codegen/operand.hpp"
#include "gpu/jit/codegen/register_scope.hpp"
#include "gpu/jit/ir/reorder.hpp"
#include "gpu/jit/ir/tensor.hpp"
#include "gpu/jit/ngen/ngen.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

// Rewrites a single-src instruction to avoid GRF boundary issues.
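// Illustrative example (assuming a 32-byte GRF): a SIMD16 move of :f data
// starting at a non-zero subregister offset can touch three GRFs, which the
// hardware does not allow for a single instruction. The plan splits such an
// op into smaller SIMD chunks (adjusting the source region when needed) so
// that every emitted instruction stays within a two-GRF window.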
struct op_plan_t {
private:
    template <typename... ArgT>
    using op_t = std::function<void(ArgT...)>;

    using inst_mod_t = ngen::InstructionModifier;
    using reg_data_t = ngen::RegData;
    using single_src_op_t = op_t<inst_mod_t, reg_data_t, reg_data_t>;

public:
    op_plan_t(int grf_size) : grf_size_(grf_size) {}

    void operator()(const single_src_op_t &op, inst_mod_t mod, reg_data_t dst,
            reg_data_t src) const {
        // Rewrite a single-src instruction that spans more than 2 GRFs into
        // multiple ops
        auto dst_esize = max_esize(dst, /*is_dst=*/true);
        auto src_esize = max_esize(src, /*is_dst=*/false);
        auto original_esize = mod.getExecSize();
        auto original_width = src.getWidth();
        auto esize = std::min(std::min(dst_esize, src_esize), original_esize);

        mod.setExecSize(esize);
        if (esize < original_width)
            // Width must be at most esize
            set_contiguous_region(src, esize, src.getHS());

        for (int i = 0; i < original_esize; i += esize) {
            fixup(op, mod, dst, src);
            shift_offset(dst, esize);
            shift_offset(src, esize);
        }
    }

private:
    int max_esize(const reg_data_t &reg, bool is_dst) const {
        auto size = reg.getBytes();
        auto width = reg.getWidth();
        auto hs = reg.getHS();
        auto vs = reg.getVS();
        auto remaining_bytes = 2 * grf_size_ - reg.getByteOffset();
        auto stride = hs;
        if (!is_dst && width == 1) stride = vs;
        if (is_dst && stride == 0) stride = 1;
        if (stride == 0) return 16; // Broadcast can have max step
        auto max_step = (remaining_bytes - 1) / (stride * size) + 1;
        return utils::rnd_down_pow2(max_step);
    }

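    // fixup() in a nutshell: when a source row of the region straddles a GRF
    // boundary, the op is rewritten. If rows are contiguous (vs == width * hs)
    // the region is narrowed so that each row ends at the boundary; otherwise
    // the op is re-emitted row by row. (Summary of the logic below.)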
    void fixup(const single_src_op_t &op, inst_mod_t mod, reg_data_t dst,
            reg_data_t src) const {
        // Rewrite src0 to cross GRF boundaries using vertical striding
        auto exec_size = mod.getExecSize();
        auto offset = src.getOffset();
        auto width = src.getWidth();
        auto hs = src.getHS();
        auto vs = src.getVS();
        auto size = src.getBytes();

        if (!width) width = exec_size;
        auto height = exec_size / width;
        auto grf_elems = grf_size_ / size;

        bool crosses_grf_boundary = false;
        auto begin = offset;
        for (int i = 0; i < height; ++i) {
            auto reg_off = begin % grf_elems;
            crosses_grf_boundary
                    |= (reg_off + (width - 1) * hs + 1 > grf_elems);
            begin += vs;
        }

        if (!crosses_grf_boundary) {
            // op is valid
            op(mod, dst, src);
        } else if (vs == width * hs) {
            // rewrite src as a valid access with shorter width and vs
            auto elems_to_grf_boundary = (grf_elems - offset - 1) / hs + 1;
            auto tentative_width = utils::rnd_down_pow2(elems_to_grf_boundary);
            while (tentative_width > 1) {
                if (elems_to_grf_boundary % tentative_width == 0) break;
                tentative_width /= 2;
            }

            set_contiguous_region(src, tentative_width, hs);
            op(mod, dst, src);
        } else {
            // break op into multiple row-wise ops
            mod.setExecSize(width);
            set_contiguous_region(src, width, hs);
            for (int i = 0; i < height; ++i) {
                fixup(op, mod, dst, src);
                shift_offset(dst, width * dst.getHS());
                shift_offset(src, vs);
            }
        }
    }

    void set_contiguous_region(reg_data_t &rr, int width, int hs) const {
        if (width > 1)
            rr.setRegion(width * hs, width, hs);
        else
            // Each element occupies its own row. width = 1 requires hs = 0
            rr.setRegion(hs, 1, 0);
    }

    void shift_offset(reg_data_t &rr, int offset) const {
        auto new_offset = rr.getOffset() + offset;
        auto type_size = rr.getBytes();
        auto grf_elems = grf_size_ / type_size;
        rr.setBase(rr.getBase() + new_offset / grf_elems);
        rr.setOffset(new_offset % grf_elems);
    }

    int grf_size_;
};

// Aligns src offset with dst offset when src is not broadcasted.
template <typename GeneratorT>
void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope,
        const ngen::InstructionModifier &mod, const reg_buf_data_t &dst,
        reg_buf_data_t &src);

template <typename GeneratorT>
bool try_emit_batched_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
        ngen_register_scope_t &scope, int width, const reg_buf_data_t &src,
        int src_stride, const reg_buf_data_t &dst, int dst_stride) {
    ngen::DataType src_type = src.type();
    ngen::DataType dst_type = dst.type();
    int src_type_size = ngen::getBytes(src_type);
    int dst_type_size = ngen::getBytes(dst_type);
    auto large_type = (src_type_size > dst_type_size) ? src_type : dst_type;
    auto small_type = (src_type_size < dst_type_size) ? src_type : dst_type;
    ngen_register_scope_t lex_scope {scope.register_allocator()};

    if (!utils::one_of(large_type, ngen::DataType::f, ngen::DataType::d))
        return false;
    if (!utils::one_of(small_type, ngen::DataType::b, ngen::DataType::ub))
        return false;
    if (src_stride != 1) return false;
    if (dst_stride != 1) return false;
    // XeHPG seems to have a problem with scalar byte writes with saturation;
    // defer to the non-batched implementation's workaround
    if (width == 1) return false;

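    // Strategy: stage the data through a dword-strided temporary so the
    // byte <-> dword/float conversion uses a legal region, then move the
    // converted values into the packed destination in a second pass. `batch`
    // bounds the temporary buffer size and `max_step` the per-instruction
    // SIMD width.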
    int batch = 128;
    int max_step = 8;

    const int grf_size = ngen::GRF::bytes(hw);
    op_plan_t plan = grf_size;
    auto tmp = lex_scope.alloc_reg_buf_data(
            utils::div_up(int(batch * sizeof(uint32_t)), grf_size));
    using inst_mod_t = ngen::InstructionModifier;
    using reg_data_t = ngen::RegData;
    auto mov = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
        host->emov(mod, dst, src);
    };

    for (int i = 0; i < width; i += batch) {
        int i_beg = i;
        int i_end = std::min(width, i + batch);

        for (int ii = i_beg; ii < i_end;) {
            int esize = std::min(max_step, i_end - ii);
            esize = utils::rnd_down_pow2(esize);

            auto s = src.subregister(ii, esize, src_type_size);
            auto t = tmp.subregister((ii - i_beg) * 4, small_type)(4);
            ngen::InstructionModifier mod = esize;
            if (dst_type == small_type) mod |= host->sat;
            plan(mov, mod, t, s(1));
            ii += esize;
        }
        for (int ii = i_beg; ii < i_end;) {
            int esize = std::min(max_step, i_end - ii);
            esize = utils::rnd_down_pow2(esize);

            auto d = dst.subregister(ii, esize, dst_type_size);
            auto t = tmp.subregister((ii - i_beg) * 4, small_type)(4);
            plan(mov, esize, d(1), t);
            ii += esize;
        }
    }
    return true;
}

// Performs 1D reorder, possibly with strides and type conversion.
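// A minimal usage sketch (hypothetical operands, for illustration only):
//
//   emit_reorder_1d_tile(hw, host, scope, /*width=*/16,
//           src_rd, /*src_stride=*/1, dst_rd, /*dst_stride=*/1);
//
// Strides are given in elements of the corresponding type; the helper picks
// SIMD sizes and temporaries as needed.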
template <typename GeneratorT>
void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
        ngen_register_scope_t &scope, int width, const reg_buf_data_t &_src,
        int src_stride, const reg_buf_data_t &_dst, int dst_stride) {

    if (try_emit_batched_reorder_1d_tile(
                hw, host, scope, width, _src, src_stride, _dst, dst_stride))
        return;

    auto src = _src;
    auto dst = _dst;
    ngen::DataType src_type = src.type();
    ngen::DataType dst_type = dst.type();
    // Replace (float -> float) by (int -> int) as word/dword moves have fewer
    // restrictions.
    if (src_type == dst_type
            && utils::one_of(src_type, ngen::DataType::bf, ngen::DataType::hf,
                    ngen::DataType::f, ngen::DataType::df)) {
        int factor = (src_type == ngen::DataType::df ? 2 : 1);
        if (factor == 1 || (src_stride == 1 && dst_stride == 1)) {
            src_type
                    = to_ngen(type_t::u(ngen::getBytes(src_type) / factor * 8));
            dst_type = src_type;
            width *= factor;
            src = src.reinterpret(src_type);
            dst = dst.reinterpret(dst_type);
        }
    }

    const int grf_size = ngen::GRF::bytes(hw);
    int src_type_size = ngen::getBytes(src_type);
    int dst_type_size = ngen::getBytes(dst_type);
    int src_stride_bytes = src_stride * src_type_size;
    int dst_stride_bytes = dst_stride * dst_type_size;
    bool dst_b = ngen_is_b(dst_type);
    bool dst_d = ngen_is_dw(dst_type);
    bool dst_q = ngen_is_qw(dst_type);
    bool dst_f = (dst_type == ngen::DataType::f);
    bool dst_hf = (dst_type == ngen::DataType::hf);
    bool dst_bf = (dst_type == ngen::DataType::bf);
    bool dst_df = (dst_type == ngen::DataType::df);
    bool dst_xf = dst_bf || dst_f || dst_hf || dst_df;
    bool src_b = ngen_is_b(src_type);
    bool src_d = ngen_is_dw(src_type);
    bool src_q = ngen_is_qw(src_type);
    bool src_f = (src_type == ngen::DataType::f);
    bool src_hf = (src_type == ngen::DataType::hf);
    bool src_bf = (src_type == ngen::DataType::bf);
    bool src_df = (src_type == ngen::DataType::df);
    bool src_xf = src_bf || src_f || src_hf || src_df;
    bool f_to_xf = (src_f && (dst_bf || dst_hf));
    op_plan_t plan = grf_size;
    ngen_register_scope_t lex_scope {scope.register_allocator()};

    auto get_step = [&]() {
        int step = (width < 16 ? 8 : 16);

        // f32 -> bf16 or f32 -> f16: SIMD16 does not support mixed mode move.
        if (hw < ngen::HW::XeHPC)
            if (f_to_xf) step = 8;

        if (src_df || dst_df) step = 8;

        // Max supported stride is 4.
        if (src_stride > 4 || dst_stride > 4) step = 1;

        // Qword does not appear to support swizzling.
        if (src_q && dst_q && src_stride != dst_stride) step = 1;

        return step;
    };

    using inst_mod_t = ngen::InstructionModifier;
    using reg_data_t = ngen::RegData;
    auto shl16 = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
        host->eshl(mod, dst, src, 16);
    };
    auto mov = [&](inst_mod_t mod, reg_data_t dst, reg_data_t src) {
        host->emov(mod, dst, src);
    };

    // bf16 -> f32:
    // - bf16 must be packed: use left shift instead.
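    //   (bf16 is the upper 16 bits of the corresponding f32 value, so shifting
    //   the raw bits left by 16 into a dword destination reproduces the f32
    //   encoding.)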
    if (src_bf && dst_f) {
        int step = get_step();
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;
            auto s = src.subregister(
                    i, esize, src_stride_bytes, ngen::DataType::uw);
            auto d = dst.subregister(
                    i, esize, dst_stride_bytes, ngen::DataType::ud);
            plan(shl16, esize, d(dst_stride), s(src_stride));
        }
        return;
    }

    // d -> bf/hf:
    // - Use d -> f -> bf/hf conversion with temporary
    if (src_d && (dst_bf || dst_hf)) {
        const int nregs = utils::div_up(width * (int)sizeof(float), grf_size);
        auto tmp = lex_scope.alloc_reg_buf_data(nregs).format(
                0, ngen::DataType::f);
        emit_reorder_1d_tile(hw, host, scope, width, src, src_stride, tmp, 1);
        emit_reorder_1d_tile(hw, host, scope, width, tmp, 1, dst, dst_stride);
        return;
    }

    // b -> hf
    // - Direct b -> float conversion not supported: use s16 temporary
    // - int -> hf must be DW-aligned & strided: use f temporary
    // - Use b -> w -> f -> hf
    if (src_b && dst_hf) {
        ir_assert(utils::one_of(dst_stride_bytes, 2, 4));
        ir_assert(utils::one_of(src_stride_bytes, 1, 4));
        int step = get_step();
        const int align_boundary = grf_size / 2;
        const int step_size = step * (int)sizeof(uint32_t);
        const int nregs = utils::div_up(step_size, grf_size);
        auto tmp1 = lex_scope.alloc_reg_buf_data(nregs);
        auto tmp2 = lex_scope.alloc_reg_buf_data(nregs);
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;

            auto s = src.subregister(i, esize, src_stride_bytes);
            auto d = dst.subregister(i, esize, dst_stride_bytes);
            auto byte_offset = 2 * (d.getByteOffset() % align_boundary);
            auto t1 = tmp1.subregister(byte_offset, ngen::DataType::w);
            auto t2 = tmp2.subregister(byte_offset, ngen::DataType::f);
            auto t1_as_hf = t1.reinterpret(0, ngen::DataType::hf);
            auto d_as_w = d.reinterpret(0, ngen::DataType::w);

            plan(mov, esize, t1(2), s(src_stride));
            plan(mov, esize, t2(1), t1(2));
            plan(mov, esize, t1_as_hf(2), t2(1));
            plan(mov, esize, d_as_w(dst_stride), t1(2));
        }
        return;
    }

    // hf -> b
    if (src_hf && dst_b) {
        ir_assert(utils::one_of(src_stride_bytes, 2, 4));
        ir_assert(utils::one_of(dst_stride_bytes, 1, 4));
        int step = get_step();
        const int tmp_stride = 4;
        const int tmp_stride_bytes = tmp_stride * dst_type_size;
        const int step_size = step * tmp_stride_bytes;
        const int nregs = 1 + utils::div_up(step_size, grf_size);
        auto tmp1 = lex_scope.alloc_reg_buf_data(nregs);
        auto tmp2 = lex_scope.alloc_reg_buf_data(nregs);
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;

            auto s = src.subregister(i, esize, src_stride_bytes);
            auto d = dst.subregister(i, esize, dst_stride_bytes);
            const int t1_offset = s.getByteOffset();
            const int t2_offset = (d.getOffset() % 16) * tmp_stride_bytes;
            auto t1 = tmp1.subregister(t1_offset, dst_type);
            auto t2 = tmp2.subregister(t2_offset, dst_type);

            if (dst_stride_bytes >= tmp_stride_bytes && esize > 1) {
                plan(mov, esize | host->sat, d(dst_stride), s(src_stride));
                continue;
            }
            auto wa_esize = std::max(2, esize);
            plan(mov, wa_esize | host->sat, t1(tmp_stride), s(src_stride));
            if (t1_offset != t2_offset)
                plan(mov, esize, t2(tmp_stride), t1(tmp_stride));
            else
                std::swap(t1, t2);
            plan(mov, esize, d(dst_stride), t2(tmp_stride));
        }
        return;
    }

    // f -> df
    // - f/df mixed operands must be qword aligned
    // - f -> f striding: use s32
    if (src_f && dst_df) {
        int step = get_step();
        const auto tmp_type = src_type;
        const int tmp_stride = 2;
        const int tmp_stride_bytes = tmp_stride * src_type_size;
        const int reg_size = dst.byte_offset() + width * tmp_stride_bytes;
        const int nregs = utils::div_up(reg_size, grf_size);
        auto tmp = lex_scope.alloc_reg_buf_data(nregs);
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;

            auto s = src.subregister(i, esize, src_stride_bytes);
            auto d = dst.subregister(i, esize, dst_stride_bytes);
            auto t = tmp.subregister(d.getByteOffset(), tmp_type);
            plan(mov, esize, t.d()(tmp_stride), s.d()(src_stride));
            plan(mov, esize, d(dst_stride), t(tmp_stride));
        }
        return;
    }

    // df -> f
    // - f/df mixed operands must be qword aligned
    // - f -> f packing: use s32
    if (dst_f && src_df) {
        int step = get_step();
        const auto tmp_type = dst_type;
        const int tmp_stride = 2;
        const int tmp_stride_bytes = tmp_stride * src_type_size;
        const int reg_size = dst.byte_offset() + width * tmp_stride_bytes;
        const int nregs = utils::div_up(reg_size, grf_size);
        auto tmp = lex_scope.alloc_reg_buf_data(nregs);
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;

            auto s = src.subregister(i, esize, src_stride_bytes);
            auto d = dst.subregister(i, esize, dst_stride_bytes);
            auto t = tmp.subregister(s.getByteOffset(), tmp_type);
            plan(mov, esize, t(tmp_stride), s(src_stride));
            // df -> f uses the float pipe. Override nGEN's choice of the long
            // pipe.
            auto mod_with_pipe = esize | ngen::SWSB<float>(1);
            plan(mov, mod_with_pipe, d.d()(dst_stride), t.d()(tmp_stride));
        }
        return;
    }

    // f -> hf
    if (src_f && dst_hf) {
        int step = get_step();
        const auto tmp_type = dst_type;
        const int tmp_stride = 2;
        const int reg_size = step * tmp_stride * dst_type_size;
        const int nregs = utils::div_up(reg_size, grf_size);
        auto tmp = lex_scope.alloc_reg_buf_data(nregs);
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;

            auto s = src.subregister(i, esize, src_stride_bytes);
            auto d = dst.subregister(i, esize, dst_stride_bytes);

            const auto align_boundary = grf_size / 2;
            bool aligned = d.getByteOffset() % align_boundary == 0
                    && s.getByteOffset() == 0;
            if (esize > 1 && dst_stride == 1 && !aligned) {
                auto t = tmp.subregister(s.getByteOffset(), tmp_type);
                plan(mov, esize, t(tmp_stride), s(src_stride));
                plan(mov, esize, d.w()(dst_stride), t.w()(tmp_stride));
                continue;
            }
            plan(mov, esize, d(dst_stride), s(src_stride));
        }
        return;
    }

    // hf -> f
    if (dst_f && src_hf) {
        int step = get_step();
        const auto tmp_type = src_type;
        const int tmp_stride = 2;
        const int reg_size = step * tmp_stride * src_type_size;
        const int nregs = utils::div_up(reg_size, grf_size);
        auto tmp = lex_scope.alloc_reg_buf_data(nregs);
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;

            auto s = src.subregister(i, esize, src_stride_bytes);
            auto d = dst.subregister(i, esize, dst_stride_bytes);

            const auto align_boundary = grf_size / 2;
            bool aligned = s.getByteOffset() % align_boundary == 0
                    && d.getByteOffset() == 0;
            if (esize > 1 && src_stride == 1 && !aligned) {
                auto t = tmp.subregister(d.getByteOffset(), tmp_type);
                plan(mov, esize, t.w()(tmp_stride), s.w()(src_stride));
                plan(mov, esize, d(dst_stride), t(tmp_stride));
                continue;
            }
            plan(mov, esize, d(dst_stride), s(src_stride));
        }
        return;
    }

    // f32/f16/s32 -> s8/u8 and s8/u8 -> f32/s32
    // - Use saturation
    // - s8/u8 must be DW-strided: use temporary
    bool d_or_f_to_b = (src_d || src_f) && dst_b;
    bool b_to_d_or_f = (dst_d || dst_f) && src_b;
    bool hf_to_b = src_hf && dst_b;
    if (d_or_f_to_b || b_to_d_or_f || hf_to_b) {
        if (dst_d || dst_f) ir_assert(dst_stride_bytes == 4);
        if (src_d || src_f) ir_assert(src_stride_bytes == 4);
        if (src_hf) ir_assert(utils::one_of(src_stride_bytes, 2, 4));
        if (dst_b) ir_assert(utils::one_of(dst_stride_bytes, 1, 4));
        if (src_b) ir_assert(utils::one_of(src_stride_bytes, 1, 4));
        int step = get_step();
        const int step_size = step * (int)sizeof(uint32_t);
        const int nregs = 1 + utils::div_up(step_size, grf_size);
        auto tmp = lex_scope.alloc_reg_buf_data(nregs);
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;

            auto s = src.subregister(i, esize, src_stride_bytes);
            auto d = dst.subregister(i, esize, dst_stride_bytes);
            if (src_d || src_f || src_hf) {
                // d -> b.
                if (dst_stride_bytes == 1 || esize == 1) {
                    auto offset_bytes = src_f ? s.getByteOffset()
                                              : 4 * (d.getByteOffset() % 16);
                    auto t = tmp.subregister(offset_bytes, dst_type)(2);
                    plan(mov, std::max(2, esize) | host->sat, t, s);
                    plan(mov, esize, d(dst_stride), t);
                } else {
                    plan(mov, esize | host->sat, d(dst_stride), s(src_stride));
                }
            } else {
                if (esize == 1) {
                    // Direct x8 -> x32 scalar cast is not always
                    // supported. Use intermediate cast to s16.
                    auto t = tmp.subregister(0, ngen::DataType::w)(1);
                    plan(mov, esize, t, s(src_stride));
                    plan(mov, esize, d(dst_stride), t);
                } else if (src_b) {
                    auto offset_bytes = dst_f ? d.getByteOffset() : 0;
                    auto t = tmp.subregister(offset_bytes, src_type)(4);
                    plan(mov, esize, t, s(src_stride));
                    plan(mov, esize, d(dst_stride), t);
                } else {
                    plan(mov, esize, d(dst_stride), s(src_stride));
                }
            }
        }
        return;
    }

    // Handle mov(dst.uw(y)(1), src.uw(x)(2)).
    if (src_type_size == 2 && dst_type_size == 2 && src_stride == 2
            && dst_stride == 1 && width > 1) {
        int step = get_step();
        auto step_size = step * src_type_size * src_stride;
        auto nregs = utils::div_up(step_size, grf_size);
        auto tmp = lex_scope.alloc_reg_buf_data(nregs);
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;
            auto s = src.format(i * src_stride_bytes, ngen::DataType::invalid,
                    esize, src_stride);
            auto d = dst.format(i * dst_stride_bytes, ngen::DataType::invalid,
                    esize, dst_stride);
            auto d_old = d;
            bool d_half_grf_aligned
                    = utils::one_of(d.byte_offset(), 0, grf_size / 2);
            if (!d_half_grf_aligned) {
                d = scope.alloc_reg_data(to_ir(dst_type).with_elems(esize));
            }
            if (s.offset() != 0) {
                auto t = tmp.format(0, src_type, esize, src_stride);
                plan(mov, esize, t, s);
                s = t;
            }
            plan(mov, esize, d, s);
            if (!d_half_grf_aligned) plan(mov, esize, d_old, d);
        }
        return;
    }

    // Perform FP to FP move.
    // The float pipe has some register regioning limitations. If the mov is
    // not allowed, fix the regioning by switching to the integer pipe, which
    // has fewer limitations.
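    // Approach below: when the FP regioning would be illegal, reinterpret the
    // source as a same-width unsigned integer type and copy it into a
    // temporary whose offset matches the destination (align_src_dst_offset),
    // then perform the FP move from there.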
    if (src_xf || dst_xf) {
        int step = get_step();
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;
            ir_assert(math::is_pow2(esize));
            auto s = src.format(i * src_stride_bytes, ngen::DataType::invalid,
                    esize, src_stride);
            auto d = dst.format(i * dst_stride_bytes, ngen::DataType::invalid,
                    esize, dst_stride);
            auto d_old = d;

            bool do_d0_align = false;
            if (esize > 1 && dst_bf) {
                bool d_0_aligned = (d.byte_offset() == 0);
                bool d_half_grf_aligned = (d.byte_offset() == grf_size / 2);
                if (!d_0_aligned && (!d_half_grf_aligned || dst_stride != 1)) {
                    do_d0_align = true;
                }
            }
            if (do_d0_align) {
                d = lex_scope.alloc_reg_data(to_ir(dst_type).with_elems(esize));
            }

            bool do_align = false;
            if (esize > 1 && s.hs() != 0) {
                if ((src_f && dst_hf) || (dst_f && src_hf)) {
                    if (s.byte_offset() != d.byte_offset()) do_align = true;
                } else if (s.offset() != d.offset())
                    do_align = true;
            }
            if (do_align) {
                bool s_half_grf_aligned = (src_hf || src_bf)
                        && utils::one_of(s.byte_offset(), 0, grf_size / 2);
                bool d_half_grf_aligned = (dst_hf || dst_bf)
                        && utils::one_of(d.byte_offset(), 0, grf_size / 2);
                if (dst_f && d.offset() == 0 && s_half_grf_aligned)
                    do_align = false;
                if (src_f && s.offset() == 0 && d_half_grf_aligned)
                    do_align = false;
            }

            if (do_align) {
                auto i_type = to_ngen(type_t::u(ngen::getBytes(src_type) * 8));
                s = s.reinterpret(i_type);
                align_src_dst_offset(host, scope, esize, d, s);
                s = s.reinterpret(src_type);
            }
            plan(mov, esize, d, s);

            if (do_d0_align) {
                auto i_type = to_ngen(type_t::u(ngen::getBytes(dst_type) * 8));
                auto d_int = d_old.reinterpret(i_type);
                auto s_int = d.reinterpret(i_type);
                plan(mov, esize, d_int, s_int);
            }
        }
        return;
    }

    if (src_b && dst_b) {
        const int tmp_stride = 4;
        const int tmp_stride_bytes = tmp_stride * dst_type_size;
        // Any byte conversion requires saturation:
        // - ub -> b loses 1 bit of precision
        // - b -> ub loses sign bit
        const bool needs_saturation = src_type != dst_type;

        int step = get_step();
        const int nregs = 1 + utils::div_up(step * tmp_stride_bytes, grf_size);
        auto tmp = lex_scope.alloc_reg_buf_data(nregs);
        for (int i = 0; i < width; i += step) {
            step = std::min(step, width - i);
            step = utils::rnd_down_pow2(step);
            int esize = step;
            auto s = src.format(i * src_stride_bytes, ngen::DataType::invalid,
                    esize, src_stride);
            auto d = dst.format(i * dst_stride_bytes, ngen::DataType::invalid,
                    esize, dst_stride);
            ngen::InstructionModifier mod = esize;
            if (needs_saturation) mod |= host->sat;

            bool aligned = true;
            const bool needs_raw_mov = dst_stride == 1 || esize == 1;
            if ((src_stride == 1 || esize == 1) && needs_raw_mov) {
                // Note: This case does not appear to be documented.
                // Experiments seem to indicate that a packed byte-to-packed
                // byte move must be word-aligned on the destination.
                aligned = d.offset() % 2 == 0;
            } else if (dst_stride <= 2 && src_stride >= 2 * dst_stride) {
                const int rel_stride = src_stride / dst_stride;
                const int alignment_bdy = grf_size / rel_stride;
                const int dst_aligned_offset = d.offset() % alignment_bdy;
                const int src_aligned_offset = s.offset() / rel_stride;
                aligned = dst_aligned_offset == src_aligned_offset;
            }

            // Workaround for scalar byte conversion:
            // - Broadcast to two locations with stride and conversion
            // - Move one copy to the destination
            if (needs_saturation && esize == 1) mod.setExecSize(2);

            if (!aligned || (needs_saturation && needs_raw_mov)) {
                const int tmp_rel_stride = tmp_stride / dst_stride;
                const int tmp_alignment_bdy = grf_size / tmp_rel_stride;
                const int tmp_aligned_offset = d.offset() % tmp_alignment_bdy;
                const int tmp_offset = tmp_rel_stride * tmp_aligned_offset;
                const int allowed_bytes = 2 * grf_size - tmp_offset;

                if ((mod.getExecSize() - 1) * tmp_stride + 1 > allowed_bytes) {
                    // Workaround for cases where the temporary is not GRF
                    // aligned and esize == 16 on XeHPG and below
                    auto max_width = (allowed_bytes - 1) / tmp_stride + 1;
                    auto tmp_esize = utils::rnd_down_pow2(max_width);
                    mod.setExecSize(tmp_esize);
                    esize = tmp_esize;
                    step = tmp_esize;
                }

                auto t = tmp.format(
                        tmp_offset, dst_type, mod.getExecSize(), tmp_stride);
                plan(mov, mod, t, s);
                mod = esize;
                s = tmp.format(tmp_offset, dst_type, esize, tmp_stride);
            }
            plan(mov, mod, d, s);
        }
        return;
    }

    // Perform regular move.
    int step = get_step();
    for (int i = 0; i < width; i += step) {
        step = std::min(step, width - i);
        step = utils::rnd_down_pow2(step);
        int esize = step;
        ir_assert(math::is_pow2(esize));
        auto s = src.format(i * src_stride_bytes, ngen::DataType::invalid,
                esize, src_stride);
        auto d = dst.format(i * dst_stride_bytes, ngen::DataType::invalid,
                esize, dst_stride);
        plan(mov, esize, d, s);
    }
}

template <typename GeneratorT>
void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope,
        const ngen::InstructionModifier &mod, const reg_buf_data_t &dst,
        reg_buf_data_t &src) {
    int src_stride = src.hs();
    // src is broadcast; no need to align, return.
    if (src_stride == 0) return;

    bool is_xf = ngen_is_xf(src.type()) || ngen_is_xf(dst.type());
    bool is_bf_to_f = (src.type() == ngen::DataType::bf)
            && (dst.type() == ngen::DataType::f);
    int src_type_size = ngen::getBytes(src.type());
    int src_off = src.offset();
    int dst_off = dst.offset();
    int src_byte_off = src.byte_offset();
    int dst_byte_off = dst.byte_offset();

    // If src is aligned with dst, return.
    if ((is_xf || is_bf_to_f) && src_off == dst_off) return;
    if (!is_xf && src_byte_off == dst_byte_off) return;

    int new_src_byte_off = (is_xf ? dst_off * src_type_size : dst_byte_off);

    int esize = mod.getExecSize();
    int grf_size = ngen::GRF::bytes(scope.hw());
    int src_size = std::max(src_type_size * esize * src_stride, src_type_size);

    auto new_src = scope.alloc_reg_buf_data(
            utils::div_up(src_size + new_src_byte_off, grf_size));
    new_src = new_src.format(new_src_byte_off, src.type(), esize, src_stride);
    emit_reorder_1d_tile(scope.hw(), host, scope, esize, src, src_stride,
            new_src, src_stride);
    src = new_src;
}

template <typename GeneratorT>
void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope,
        const ngen::InstructionModifier &mod, const reg_buf_data_t &dst,
        reg_buf_data_t &src0, reg_buf_data_t &src1) {
    align_src_dst_offset(host, scope, mod, dst, src0);
    align_src_dst_offset(host, scope, mod, dst, src1);
}

template <typename GeneratorT>
void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope,
        const ngen::InstructionModifier &mod, const ngen_operand_t &dst,
        ngen_operand_t &src) {
    if (!dst.is_reg_data()) return;
    if (!src.is_reg_data()) return;

    auto rd = src.reg_buf_data();
    align_src_dst_offset(host, scope, mod, dst.reg_buf_data(), rd);
    if (rd == src.reg_buf_data()) return;

    bool is_negated = src.is_negated();
    src = ngen_operand_t(rd, src.mod());
    if (is_negated) src = -src;
}

template <typename GeneratorT>
void align_src_dst_offset(GeneratorT *host, ngen_register_scope_t &scope,
        const ngen::InstructionModifier &mod, const ngen_operand_t &dst,
        ngen_operand_t &src0, ngen_operand_t &src1) {
    align_src_dst_offset(host, scope, mod, dst, src0);
    align_src_dst_offset(host, scope, mod, dst, src1);
}

// Implementation of GRF reorder between 2D dense layouts.
// Requirements for A -> B reorder:
// - A and B must have the same data type
// - Layouts must be 2D and dense
// Reorder may require several steps; in that case a temporary buffer T is
// allocated. For example: A -> T -> B or A -> B -> T -> B
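// The reorder path is found on a graph where vertices are candidate layouts
// of the 2D tile and edges are (a x b) sub-tile reorders; Dijkstra's
// algorithm (see find_min_cost_path) picks the path with the fewest estimated
// instructions. Depending on the hardware and data type, an indirect path
// through an intermediate layout reinterpreted to a wider type may be cheaper
// than direct element-by-element moves.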
class reorder_2d_impl_t {
public:
    reorder_2d_impl_t(
            ngen::HW hw, const layout_t &src_layout, const layout_t &dst_layout)
        : hw_(hw), src_(src_layout), dst_(dst_layout) {
        ir_assert(src_.type() == dst_.type());
        tile_ = find_2d_tile(src_, dst_);
    }

    const tensor_t &tile() const { return tile_; }

    template <typename GeneratorT>
    void emit(GeneratorT *host, ngen_register_scope_t &scope,
            const reg_buf_data_t &src_rd, const reg_buf_data_t &dst_rd) {
        int a_idx, b_idx;
        int tile_a, tile_b;
        tile_to_2d_dims(tile_, a_idx, b_idx, tile_a, tile_b);

        // Convert src/dst to 2D layouts.
        dim_assignment_t to_ab(src_.ndims(), 2);
        to_ab.assign(a_idx, 0);
        to_ab.assign(b_idx, 1);
        auto src_ab = to_ab.map(src_);
        auto dst_ab = to_ab.map(dst_);

        // Find minimal cost reorder path between layouts.
        auto path = find_min_cost_path(hw_, src_ab, dst_ab, tile_a, tile_b);

        // Allocate a temporary GRF buffer if needed.
        reg_buf_data_t tmp;
        if (path.size() > 1) {
            const int grf_size = ngen::GRF::bytes(hw_);
            tmp = scope.alloc_reg_buf_data(
                    utils::div_up(dst_ab.size(), grf_size));
        }

        // Iterate through found reorders.
        auto *prev_layout = &src_ab;
        auto prev_rd = src_rd;
        int path_len = int(path.size());
        auto &orig_type = src_ab.type();
        for (int i = 0; i < path_len; i++) {
            auto &step = path[i];
            auto &tile = step.tile;
            auto &type = step.type;
            auto *next_layout = &step.layout;

            // x -> y reorder.
            auto x = prev_layout->map(tile).reinterpret(type);
            auto y = next_layout->map(tile).reinterpret(type);

            bool use_dst = ((path_len - i) % 2 == 1);
            auto next_rd = (use_dst ? dst_rd : tmp);
            auto &x_blocks = x.blocks();
            auto &y_blocks = y.blocks();
            ir_assert(x_blocks.size() <= 1);
            ir_assert(y_blocks.size() <= 1);
            int x_stride = (x_blocks.empty() ? 1 : int(x_blocks[0].stride));
            int y_stride = (y_blocks.empty() ? 1 : int(y_blocks[0].stride));
            int width = int(tile.elems()) * orig_type.size() / type.size();
            next_layout->for_each_tile(
                    tile, [&](const std::vector<dim_t> &start) {
                        int prev_off = int(prev_layout->offset_in_bytes(start));
                        int next_off = int(next_layout->offset_in_bytes(start));
                        auto x_sub = prev_rd.format(prev_off, to_ngen(type), 1);
                        auto y_sub = next_rd.format(next_off, to_ngen(type), 1);
                        emit_reorder_1d_tile(hw_, host, scope, width, x_sub,
                                x_stride, y_sub, y_stride);
                    });
            prev_layout = next_layout;
            prev_rd = next_rd;
        }
    }

    // Returns the biggest common 2D tile that is innermost for both layouts.
    // The returned tile contains at most max_elems elements. If match_outer is
    // true, then outer parts of both layouts are required to be equal.
    // Returns an empty tensor if the requested tile is not found.
    static tensor_t find_2d_tile(const layout_t &a, const layout_t &b,
            int max_elems = std::numeric_limits<int>::max(),
            bool match_outer = false) {
        std::vector<dim_t> tile_dims(a.ndims(), 1);
        if (a.blocks().empty() || b.blocks().empty())
            return tensor_t(tile_dims);

        auto non_one_ndims = [](const tensor_t &t) {
            int ret = 0;
            for (dim_t d : t.dims())
                ret += (d != 1 ? 1 : 0);
            return ret;
        };

        layout_iterator_t a_it(a);
        layout_iterator_t b_it(b);

        tensor_t max_tile;
        for (;;) {
            auto a_tile = a_it.tile();
            auto b_tile = b_it.tile();
            if (non_one_ndims(a_tile) > 2 || non_one_ndims(b_tile) > 2) break;
            if (!a.map(a_tile).is_dense() || !b.map(b_tile).is_dense()) break;
            dim_t a_elems = a_tile.elems();
            dim_t b_elems = b_tile.elems();

            bool tile_ok = true;
            if (!a_tile.is_equal(b_tile)) tile_ok = false;
            if (match_outer) {
                auto a_outer = a_it.outer_layout();
                auto b_outer = b_it.outer_layout();
                if (!a_outer.is_equal(b_outer)) tile_ok = false;
            }
            if (tile_ok) {
                if (a_it.nblocks() > max_tile_blocks) break;
                if (b_it.nblocks() > max_tile_blocks) break;
                if (a_tile.elems() > max_elems) break;
                max_tile = a_tile;
                if (!a_it.has_next() || !b_it.has_next()) break;
                ++a_it;
                ++b_it;
            } else if (a_elems <= b_elems) {
                if (!a_it.has_next()) break;
                ++a_it;
            } else {
                if (!b_it.has_next()) break;
                ++b_it;
            }
        }
        return max_tile;
    }

    static const int max_tile_blocks = 4;

private:
    // Represents 2D reorder corresponding to (a x b) tile.
    struct edge_t {
        edge_t() = default;
        edge_t(int idx, int a, int b) : idx(idx), a(a), b(b) {}

        tensor_t tile() const { return tensor_t({a, b}); }

        std::string str() const {
            std::ostringstream oss;
            oss << "edge(idx = " << idx << ", a = " << a << ", b = " << b
                << ")";
            return oss.str();
        }

        int idx; // Identifier of the edge.
        int a = 0, b = 0; // Specify tile (a x b).
    };

    // Represents a GRF layout between reorder steps (edges).
    struct vertex_t {
        vertex_t(ngen::HW hw, int idx, const layout_t &layout)
            : hw(hw), idx(idx), layout(layout) {}

        std::string str() const {
            std::ostringstream oss;
            oss << "vertex(idx = " << idx << ", layout = " << layout << ")";
            return oss.str();
        }

        void set_edges(const std::vector<edge_t> &edges) {
            adj_edge_type_masks.resize(edges.size());
            int type_size = layout.type().size();
            for (int i = 0; i < int(edges.size()); i++) {
                auto &e = edges[i];
                auto tile = e.tile();
                int max_type_size;
                bool ok = layout_t::try_reinterpret_to_wider_type(
                        layout, layout, tile, false, &max_type_size);
                if (!ok) max_type_size = type_size;
                int from = math::ilog2q(type_size);
                int to = math::ilog2q(max_type_size);
                for (int j = from; j <= to; j++) {
                    type_t type = type_t::u(8 << j);
                    if (can_reorder(tile, type))
                        adj_edge_type_masks[i] |= (1 << j);
                }
            }
        }

        void add_neighbor(const vertex_t *v) { adj_vertices.push_back(v); }

        bool is_neighbor(const vertex_t &v) const {
            for (auto *n : adj_vertices)
                if (n == &v) return true;
            return false;
        }

        // Check the following limitations:
        // - Assume at most one block (maybe with non-dense stride)
        // - Horizontal stride must be <= 4 for GRF region
        // - GRF region can't span more than 2 registers
        bool can_reorder(const tensor_t &tile, const type_t &type) const {
            auto ab_layout = layout.map(tile).reinterpret(type);
            int nblocks = int(ab_layout.blocks().size());
            if (nblocks == 0) return true;
            if (nblocks > 1) return false;
            auto &last = ab_layout.blocks().back();
            int max_stride = int(last.stride * last.block);
            if (last.stride > 4) return false;
            if ((int)last.stride == 4 && type.size() <= 2) return false;
            if (!math::is_pow2(last.stride)) return false;
            int max_stride_bytes = max_stride * type.size();
            int grf_size = ngen::GRF::bytes(hw);
            if (max_stride_bytes > 2 * grf_size) return false;
            return true;
        }

        // Finds the minimal cost of reordering from this vertex to vertex v.
        int cost(const vertex_t &v, const std::vector<edge_t> &edges,
                edge_t &min_edge, type_t &min_type) const {
            int min_cost = std::numeric_limits<int>::max();
            for (int i = 0; i < int(edges.size()); i++) {
                type_t i_min_type;
                int new_cost = cost(edges[i], v, i_min_type);
                if (new_cost < min_cost) {
                    min_cost = new_cost;
                    min_edge = edges[i];
                    min_type = i_min_type;
                }
            }
            return min_cost;
        }

        // Finds the minimal cost of reordering from this vertex to vertex `v`
        // through edge `e`. If the reorder is possible, `type` contains the
        // reorder type with the minimal cost.
        int cost(const edge_t &e, const vertex_t &v, type_t &type) const {
            uint32_t mask = (adj_edge_type_masks[e.idx]
                    & v.adj_edge_type_masks[e.idx]);
            if (mask == 0) return std::numeric_limits<int>::max();
            int cur_size = layout.type().size();
            int cur_cost = layout.elems() / (e.a * e.b);
            int min_log_bytes = math::ilog2q(cur_size);
            int max_log_bytes = 3;
            int min_cost = std::numeric_limits<int>::max();
            for (int i = min_log_bytes; i <= max_log_bytes; i++) {
                if ((mask & (1 << i)) == 0) continue;
                if (i > min_log_bytes) {
                    ir_assert(!layout.blocks().empty());
                    ir_assert(!v.layout.blocks().empty());
                    int dim_idx0 = layout.blocks()[0].dim_idx;
                    int dim_idx1 = v.layout.blocks()[0].dim_idx;
                    if (dim_idx0 != dim_idx1) continue;
                }
                min_cost = cur_cost;
                type = type_t::u(8 << i);
                break;
            }
            return min_cost;
        }

        ngen::HW hw;
        int idx; // Identifier of the vertex.
        layout_t layout; // Layout of the vertex.
        // Specifies a bitmask for every edge: if adj_edge_type_masks[E_idx]
        // has the b-th bit set, then this vertex can be reordered through
        // edge E using a data type of size 2^b bytes.
        std::vector<uint32_t> adj_edge_type_masks;
        std::vector<const vertex_t *> adj_vertices; // Adjacent vertices.
    };

    // Represents a reorder step.
    struct reorder_step_t {
        reorder_step_t() = default;
        reorder_step_t(const layout_t &layout, const tensor_t &tile,
                const type_t &type)
            : layout(layout), tile(tile), type(type) {}

        layout_t layout; // Destination layout.
        tensor_t tile; // Tile corresponding to one instruction.
        type_t type; // Registers should be reinterpreted to `type` for reorder.
    };

    // Extracts dimension sizes and their indices from a multidimensional
    // tensor.
    static void tile_to_2d_dims(
            const tensor_t &tile, int &a_idx, int &b_idx, int &a, int &b) {
        a_idx = -1;
        b_idx = -1;
        for (int i = 0; i < tile.ndims(); i++) {
            if (tile.dims()[i] == 1) continue;
            if (a_idx == -1) {
                a_idx = i;
                continue;
            }
            if (b_idx == -1) {
                b_idx = i;
                continue;
            }
            ir_error_not_expected();
        }

        for (int i = 0; i < tile.ndims(); i++) {
            if (utils::one_of(i, a_idx, b_idx)) continue;
            if (a_idx == -1) {
                a_idx = i;
                continue;
            }
            if (b_idx == -1) {
                b_idx = i;
                continue;
            }
        }

        if (a_idx > b_idx) std::swap(a_idx, b_idx);

        a = tile.dims()[a_idx];
        b = tile.dims()[b_idx];
    }

    // Finds the optimal sequence of reorders between src and dst layouts.
    static std::vector<reorder_step_t> find_min_cost_path(ngen::HW hw,
            const layout_t &src, const layout_t &dst, int tile_a, int tile_b) {
        // Create all possible edges - 2D reorders.
        std::vector<edge_t> edges;
        for (int a = 1; a <= tile_a; a *= 2) {
            for (int b = 1; b <= tile_b; b *= 2) {
                if (src.dim(0) % a != 0) continue;
                if (src.dim(1) % b != 0) continue;
                int idx = int(edges.size());
                edges.emplace_back(idx, a, b);
            }
        }

        int nedges = int(edges.size());

        // Create all possible layouts for tile_a x tile_b tensor.
        std::vector<vertex_t> vertices;
        std::vector<std::vector<std::pair<int, uint32_t>>> edge_vertices(
                nedges);
        auto all_layouts = generate_all_layouts(src.type(), tile_a, tile_b);
        for (auto &l : all_layouts) {
            // Skip if too many blocks.
            if (int(l.blocks().size()) > max_tile_blocks) continue;
            int v_idx = int(vertices.size());
            vertices.emplace_back(hw, v_idx, l);
            auto &v = vertices.back();
            // Pass all known reorders; the vertex/layout will filter out
            // incompatible ones.
            v.set_edges(edges);
            // Store all vertices adjacent to a specific edge.
            for (int i = 0; i < nedges; i++) {
                uint32_t mask = v.adj_edge_type_masks[i];
                if (mask != 0) edge_vertices[i].emplace_back(v_idx, mask);
            }
        }

        // Find neighbors between all vertices.
        int nvertices = int(vertices.size());
        for (int i = 0; i < nvertices; i++) {
            auto &v = vertices[i];
            for (int j = 0; j < nedges; j++) {
                uint32_t mask = v.adj_edge_type_masks[j];
                if (mask != 0) {
                    for (auto &idx_mask : edge_vertices[j]) {
                        int v_idx = idx_mask.first;
                        if (v_idx == i) continue;
                        uint32_t common_mask = (mask
                                & vertices[v_idx].adj_edge_type_masks[j]);
                        if (common_mask != 0) v.add_neighbor(&vertices[v_idx]);
                    }
                }
            }
        }

        // Identify source and destination vertices.
        int src_idx = -1;
        int dst_idx = -1;
        for (int i = 0; i < nvertices; i++) {
            auto &v = vertices[i];
            if (src_idx == -1
                    && v.layout.is_strictly_equal(
                            src, /*compare_offset=*/false))
                src_idx = i;
            if (dst_idx == -1
                    && v.layout.is_strictly_equal(
                            dst, /*compare_offset=*/false))
                dst_idx = i;
        }

        ir_assert(src_idx != -1);
        ir_assert(dst_idx != -1);

        // Layouts are the same, just copy.
        if (src_idx == dst_idx) {
            auto &v = vertices[src_idx];
            edge_t min_edge;
            type_t min_type;
            v.cost(v, edges, min_edge, min_type);
            reorder_step_t step(v.layout, min_edge.tile(), min_type);
            return {step};
        }

        // Dijkstra's algorithm: find the minimal cost path between src and
        // dst. The number of instructions is used to estimate the cost.
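        // Each step's cost is approximated as layout.elems() / (a * b), i.e.
        // the number of (a x b) tile moves needed to cover the full tile (see
        // vertex_t::cost()).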
        int inf_cost = std::numeric_limits<int>::max();
        std::vector<int> cost(nvertices, inf_cost);
        std::vector<int> prev(nvertices);
        std::vector<reorder_step_t> reorder_steps(nvertices);
        std::vector<bool> seen(nvertices, false);
        cost[src_idx] = 0;
        for (int i = 0; i < nvertices; i++) {
            int min_idx = -1;
            int min_cost = inf_cost;
            for (int j = 0; j < nvertices; j++) {
                if (seen[j]) continue;
                if (cost[j] < min_cost) {
                    min_idx = j;
                    min_cost = cost[j];
                }
            }
            seen[min_idx] = true;
            auto &v_min = vertices[min_idx];
            for (auto *v : v_min.adj_vertices) {
                edge_t min_edge;
                type_t min_type;
                int new_cost = cost[min_idx]
                        + v_min.cost(*v, edges, min_edge, min_type);
                if (new_cost < cost[v->idx]) {
                    cost[v->idx] = new_cost;
                    prev[v->idx] = min_idx;
                    reorder_steps[v->idx] = reorder_step_t(
                            v->layout, min_edge.tile(), min_type);
                }
            }
        }

        // Sanity check, ensure the reorder sequence is not too long.
        int max_cost = 256;
        ir_assert(cost[dst_idx] <= max_cost);
        MAYBE_UNUSED(max_cost);

        // Restore the shortest reorder path.
        std::vector<reorder_step_t> ret;
        int idx = dst_idx;
        while (idx != src_idx) {
            ret.push_back(reorder_steps[idx]);
            idx = prev[idx];
        }
        std::reverse(ret.begin(), ret.end());
        return ret;
    }

    // Returns all possible layouts for (a x b) tensor.
    static std::vector<layout_t> generate_all_layouts(
            const type_t &type, int a, int b) {
        std::vector<layout_t> ret;
        std::vector<block_t> blocks;
        generate_all_layouts_impl(ret, blocks, type, a, b, 1);
        return ret;
    }

    static void generate_all_layouts_impl(std::vector<layout_t> &layouts,
            std::vector<block_t> &blocks, const type_t &type, int a, int b,
            int stride) {
        if (a == 1 && b == 1) {
            layouts.emplace_back(type, 2, 0, blocks);
            return;
        }
        bool iterate_a = true;
        bool iterate_b = true;

        // Avoid repeating indices to keep only unique layouts.
        if (!blocks.empty()) {
            auto &last = blocks.back();
            iterate_a &= (last.dim_idx != 0);
            iterate_b &= (last.dim_idx != 1);
        }

        if (iterate_a) {
            for (int a_blk = 2; a_blk <= a; a_blk++) {
                if (a % a_blk != 0) continue;
                blocks.emplace_back(0, a_blk, stride);
                generate_all_layouts_impl(
                        layouts, blocks, type, a / a_blk, b, stride * a_blk);
                blocks.pop_back();
            }
        }
        if (iterate_b) {
            for (int b_blk = 2; b_blk <= b; b_blk++) {
                if (b % b_blk != 0) continue;
                blocks.emplace_back(1, b_blk, stride);
                generate_all_layouts_impl(
                        layouts, blocks, type, a, b / b_blk, stride * b_blk);
                blocks.pop_back();
            }
        }
    }

    ngen::HW hw_;

    layout_t src_;
    layout_t dst_;

    tensor_t tile_;
};

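// Predicate for filter_t below: keeps only the innermost blocks that remain
// dense (block stride equals the product of the previously accepted blocks)
// and that span at most two distinct dimensions, i.e. blocks usable for a
// dense 2D inner tile.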
class dense_2d_block_filter_t {
public:
    bool operator()(const block_t &b) {
        if (b.block == 1) return false;
        if ((dim_t)b.stride != stride_) return false;
        stride_ *= b.block;
        if (!have_seen(b.dim_idx)) non_one_ndims_++;
        return non_one_ndims_ <= 2;
    }

    dense_2d_block_filter_t() = default;

private:
    bool have_seen(int idx) {
        auto ret = seen_.insert(idx);
        return !ret.second;
    }

    dim_t stride_ = 1;
    int non_one_ndims_ = 0;
    std::unordered_set<int> seen_;
};

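// Lazily iterates over the elements of a container that satisfy a predicate.
// Sketch of intended use (mirrors how shared_inner_tiles_t uses it below):
//
//   filter_t<std::vector<block_t>> dense(blocks, dense_2d_block_filter_t());
//   for (auto &b : dense) { /* only blocks accepted by the filter */ }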
template <typename IterT>
class filter_t {
    using inner_iter_t = decltype(std::declval<const IterT>().begin());
    using iter_value_t = decltype(*std::declval<inner_iter_t>());
    using predicate_t = std::function<bool(const iter_value_t &)>;

public:
    class iterator_t {
    public:
        bool operator==(const iterator_t &it) const { return it_ == it.it_; }
        bool operator!=(const iterator_t &it) const { return !operator==(it); }
        const iter_value_t &operator*() const { return *it_; }
        iterator_t &operator++() {
            if (it_ == end_) return *this;
            while (++it_ != end_ && !predicate_(*it_))
                ;
            return *this;
        }

        iterator_t(inner_iter_t it, inner_iter_t end, predicate_t predicate)
            : it_(it), end_(end), predicate_(predicate) {
            if (it_ != end_ && !predicate_(*it_)) operator++();
        }

    private:
        inner_iter_t it_, end_;
        std::function<bool(const iter_value_t &)> predicate_;
    };

    iterator_t begin() const { return {begin_, end_, predicate_}; }
    iterator_t end() const { return {end_, end_, predicate_}; }

    filter_t(const IterT &it, predicate_t predicate)
        : begin_(it.begin()), end_(it.end()), predicate_(predicate) {}

private:
    inner_iter_t begin_, end_;
    predicate_t predicate_;
};

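// Walks the filtered inner blocks of two layouts in parallel and yields the
// inner tiles (dense, at most 2D) that both layouts share, in increasing
// size. Used by find_max_2d_dense_tile below.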
class shared_inner_tiles_t {
    using blocks_t = std::vector<block_t>;
    using block_iterator_t = filter_t<blocks_t>::iterator_t;

    class inner_tile_iterator_t {
    public:
        bool operator==(const inner_tile_iterator_t &it) const {
            return curr_ == it.curr_;
        }
        bool operator!=(const inner_tile_iterator_t &it) const {
            return !operator==(it);
        }

        inner_tile_iterator_t &operator++() {
            if (curr_ == end_) return *this;

            auto size = (*curr_).block;
            while (++factor_ <= size) {
                if (size % factor_ == 0) return *this;
            }

            dims_[(*curr_).dim_idx] *= size;
            ++curr_;
            factor_ = 1;
            return operator++();
        }

        tensor_t operator*() const {
            auto dims = dims_;
            dims[(*curr_).dim_idx] *= factor_;
            return tensor_t(dims);
        }

        inner_tile_iterator_t(block_iterator_t curr, block_iterator_t end,
                const tensor_t &tile)
            : curr_(std::move(curr))
            , end_(std::move(end))
            , dims_(tile.dims())
            , factor_(1) {}

    private:
        block_iterator_t curr_, end_;
        std::vector<dim_t> dims_;
        dim_t factor_;
    };

    class iterator_t {
    public:
        bool operator==(const iterator_t &it) const {
            return a_it_ == it.a_it_ && b_it_ == it.b_it_;
        }
        bool operator!=(const iterator_t &it) const { return !operator==(it); }

        iterator_t &operator++() {
            bool adv_a = true, adv_b = true;
            while (adv_a || adv_b) {
                adv_a = (a_it_ != a_end_)
                        && (b_it_ == b_end_
                                || (*a_it_).elems() <= (*b_it_).elems());
                adv_b = (b_it_ != b_end_)
                        && (a_it_ == a_end_
                                || (*b_it_).elems() <= (*a_it_).elems());
                if (adv_a) ++a_it_;
                if (adv_b) ++b_it_;
                if (a_it_ != a_end_) a_tile_ = *a_it_;
                if (b_it_ != b_end_) b_tile_ = *b_it_;
                if (a_tile_.is_equal(b_tile_)) break;
            }
            return *this;
        }

        const tensor_t &operator*() const { return a_tile_; }

        iterator_t(const block_iterator_t &a_it, const block_iterator_t &a_end,
                const block_iterator_t &b_it, const block_iterator_t &b_end,
                const tensor_t &tile)
            : a_it_(a_it, a_end, tile)
            , a_end_(a_end, a_end, tile)
            , b_it_(b_it, b_end, tile)
            , b_end_(b_end, b_end, tile) {}

    private:
        inner_tile_iterator_t a_it_, a_end_;
        inner_tile_iterator_t b_it_, b_end_;
        tensor_t a_tile_, b_tile_;
    };

public:
    iterator_t begin() const {
        return {a_.begin(), a_.end(), b_.begin(), b_.end(), default_tile_};
    }
    iterator_t end() const {
        return {a_.end(), a_.end(), b_.end(), b_.end(), default_tile_};
    }

    shared_inner_tiles_t(const layout_t &a, const layout_t &b)
        : a_(a.blocks(), dense_2d_block_filter_t())
        , b_(b.blocks(), dense_2d_block_filter_t())
        , default_tile_(std::vector<dim_t>(a.ndims(), 1)) {}

private:
    filter_t<blocks_t> a_;
    filter_t<blocks_t> b_;
    tensor_t default_tile_;
};

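// Entry point for emitting a GRF reorder: tries the 2D implementation first
// and falls back to a tile-by-tile 1D reorder. Rough usage sketch
// (hypothetical host/scope/register operands):
//
//   reorder_impl_t(hw, reorder_ir_obj).emit(host, scope, src_rbd, dst_rbd);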
class reorder_impl_t {
public:
    reorder_impl_t(ngen::HW hw, const reorder_t &reorder)
        : hw_(hw)
        , src_layout_(reorder.src_layout)
        , dst_layout_(reorder.dst_layout) {
        layout_t::try_reinterpret_to_wider_type(src_layout_, dst_layout_);

        // Pure bf moves are not supported.
        if (utils::everyone_is(
                    type_t::bf16(), src_layout_.type(), dst_layout_.type())) {
            src_layout_ = src_layout_.retype(type_t::u16());
            dst_layout_ = dst_layout_.retype(type_t::u16());
        }
    }

    template <typename GeneratorT>
    void emit(GeneratorT *host, ngen_register_scope_t &scope,
            const reg_buf_data_t &src, const reg_buf_data_t &dst) {
        if (try_emit_2d(host, scope, src, dst)) return;
        emit_1d(host, scope, src, dst);
    }

private:
    template <typename GeneratorT>
    void emit_1d(GeneratorT *host, ngen_register_scope_t &scope,
            const reg_buf_data_t &src_rd, const reg_buf_data_t &dst_rd) {
        int src_stride;
        int dst_stride;
        auto tile = find_max_tile_with_fixed_stride(
                src_layout_, dst_layout_, src_stride, dst_stride);

        int tile_elems = int(tile.elems());
        auto &src_type = src_layout_.type();
        auto &dst_type = dst_layout_.type();
        dst_layout_.for_each_tile(tile, [&](const std::vector<dim_t> &start) {
            int src_off = int(src_layout_(start) * src_type.size());
            int dst_off = int(dst_layout_(start) * dst_type.size());
            auto sub_src = src_rd.format(src_off, to_ngen(src_type), 1);
            auto sub_dst = dst_rd.format(dst_off, to_ngen(dst_type), 1);

            ngen_register_scope_t tile_scope(scope.register_allocator());
            emit_reorder_1d_tile(hw_, host, tile_scope, tile_elems, sub_src,
                    src_stride, sub_dst, dst_stride);
        });
    }

    static tensor_t find_max_2d_dense_tile(const layout_t &a_layout,
            const layout_t &b_layout, dim_t max_elems) {
        shared_inner_tiles_t tiles {a_layout, b_layout};
        tensor_t max_tile;
        auto all_pow2 = [](const tensor_t &tile) {
            for (auto d : tile.dims())
                if (!math::is_pow2(d)) return false;
            return true;
        };

        for (auto &tile : tiles) {
            if (tile.elems() > max_elems) break;
            if (all_pow2(tile)) max_tile = tile;
        }
        // No point in tiling with a 1x1 tile
        return max_tile.elems() > 1 ? max_tile : tensor_t();
    }

    template <typename GeneratorT>
    bool try_emit_2d(GeneratorT *host, ngen_register_scope_t &scope,
            const reg_buf_data_t &src_rd, const reg_buf_data_t &dst_rd) {
        if (src_layout_.type() != dst_layout_.type()) return false;
        if (!src_layout_.is_dense()) return false;
        if (!dst_layout_.is_dense()) return false;

        int max_tile_size = 512;
        int max_tile_elems = max_tile_size / src_layout_.type().size();
        auto tile = find_max_2d_dense_tile(
                src_layout_, dst_layout_, max_tile_elems);

        // Couldn't find a tile; 2D reorder is not supported.
        if (tile.is_empty()) return false;

        auto src_tile_layout = src_layout_.map(tile);
        auto dst_tile_layout = dst_layout_.map(tile);
        if (!dst_tile_layout.is_dense()) return false;
        auto layout_ok = [](const layout_t &l) {
            if (l.blocks().size() < 2) return false;
            for (auto &b : l.blocks()) {
                if (math::is_pow2(b.block)) continue;
                for (int i = 2; i < (int)b.block / 2; i++)
                    if (b.block % i != 0) return false;
            }
            return true;
        };

        if (!layout_ok(src_tile_layout)) return false;
        if (!layout_ok(dst_tile_layout)) return false;

        // Set layout offset to 0 since the offset is handled by fixing up the
        // register input to try_emit_2d_impl
        src_tile_layout.set_offset(0);
        dst_tile_layout.set_offset(0);

        bool ok = true;
        auto type = to_ngen(src_layout_.type());
        src_layout_.for_each_tile(tile, [&](const std::vector<dim_t> &start) {
            auto src_off = src_layout_.offset_in_bytes<dim_t>(start);
            auto dst_off = dst_layout_.offset_in_bytes<dim_t>(start);
            auto src_tile_rd = src_rd.format(int(src_off), type);
            auto dst_tile_rd = dst_rd.format(int(dst_off), type);

            ngen_register_scope_t tile_scope(scope.register_allocator());
            ok &= try_emit_2d_impl(host, tile_scope, src_tile_layout,
                    dst_tile_layout, src_tile_rd, dst_tile_rd);
        });
        return ok;
    }

    template <typename GeneratorT>
    bool try_emit_2d_impl(GeneratorT *host, ngen_register_scope_t &scope,
            const layout_t &src_layout, const layout_t &dst_layout,
            const reg_buf_data_t &src_rd, const reg_buf_data_t &dst_rd) {
        // Try to allocate/release a temporary buffer to avoid out_of_registers
        // exception.
        const int grf_size = ngen::GRF::bytes(hw_);
        auto dummy = scope.try_alloc_range(
                utils::div_up(dst_layout.size(), grf_size));
        if (dummy.isInvalid()) {
            ir_warning() << "Can't allocate buffer for 2D reorder. Reorder "
                            "performance may be suboptimal.\n";
            return false;
        }

        // Allocation succeeded, can proceed further.
        scope.safeRelease(dummy);

        reorder_2d_impl_t r(hw_, src_layout, dst_layout);
        int tile_elems = int(r.tile().elems());
        if (tile_elems < 16 || tile_elems > 512) return false;

        r.emit(host, scope, src_rd, dst_rd);
        return true;
    }

    static tensor_t find_max_tile_with_fixed_stride(const layout_t &src,
            const layout_t &dst, int &src_stride, int &dst_stride) {
        // 1. Split layouts to have aligned blocks.
        auto a = src;
        auto b = dst;
        layout_t::align_layouts(a, b);

        // 2. Find the max innermost tile.
        auto a_blocks = a.blocks();
        auto b_blocks = b.blocks();

        std::vector<dim_t> tile_dims(a.ndims(), 1);
        src_stride = (a_blocks.empty() ? 1 : int(a_blocks[0].stride));
        dst_stride = (b_blocks.empty() ? 1 : int(b_blocks[0].stride));
        int src_cur_stride = src_stride;
        int dst_cur_stride = dst_stride;

        int min_blocks = int(std::min(a_blocks.size(), b_blocks.size()));
        for (int i = 0; i < min_blocks; i++) {
            auto &ab = a_blocks[i];
            auto &bb = b_blocks[i];
            if (ab.dim_idx != bb.dim_idx || ab.block != bb.block) break;

            // Strides are supported for the innermost block only.
            if (src_cur_stride != int(ab.stride)) break;
            if (dst_cur_stride != int(bb.stride)) break;

            src_cur_stride = int(ab.block * ab.stride);
            dst_cur_stride = int(bb.block * bb.stride);
            tile_dims[ab.dim_idx] *= ab.block;
        }
        return tensor_t(tile_dims);
    }

    ngen::HW hw_;
    layout_t src_layout_;
    layout_t dst_layout_;
};

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif