/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>

#include "gpu/ocl/custom_reorder.hpp"

#include "common/utils.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

using namespace dnnl::impl::memory_tracking::names;

using dimension = struct {
    dim_t size;
    int idx;
};

using stride_t = struct {
    dim_t stride;
    dim_t size;
    int idx;
};

// Stride sorter. Smaller stride = inner dim, bigger stride = outer dim.
// Dimensions of size 1 are considered outermost regardless of their strides
// and are sorted by index among themselves.
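// Illustrative example: for a plain 'abc' tensor with dims 8x1x4 the strides
// are {4, 4, 1}; sorting yields [c, a, b], i.e. innermost first (c has the
// smallest stride) with the size-1 dim b pushed to the outer end.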
bool stride_cmp(const stride_t &a, const stride_t &b) {
    if (a.size == 1 && b.size == 1) {
        return a.idx > b.idx;
    } else if (a.size != 1 && b.size == 1) {
        return true;
    } else if (a.size == 1 && b.size != 1) {
        return false;
    } else {
        return a.stride < b.stride;
    }
}

// Returns the size and index of the dimension or block that is last (i.e.
// innermost) or at a given distance from the end. Blocks, if present, take
// precedence over plain dimensions. The order of dimensions is determined by
// sorting strides; the smallest stride belongs to the last dimension.
// Dimensions of size 1 are 'ignored' (treated as outermost).
// Because dims of size 1 are treated as outermost regardless of their original
// position in the tensor tag, it is possible to arrive at an illegal tensor
// tag where the last dim is the same as the first block:
// 8x8x1x1 aBcd8b becomes cdaB8b (notice the "B8b").
// In such cases this function combines the last dim with the first block
// (xB8b := xb).
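// Illustrative example: for an aBcd16b layout (single 16b inner block) with
// all dims larger than 1, distance 0 returns the 16b block, distance 1 dim d,
// distance 2 dim c, distance 3 the outer part of b and distance 4 dim a.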
dimension get_Nth_last_dim_or_block(
        const memory_desc_wrapper &md, int distance = 0) {
    int nblks = md.blocking_desc().inner_nblks;
    dimension ret;
    int ndims = md.ndims();

    std::vector<stride_t> strides(ndims);
    for (int d = 0; d < ndims; ++d) {
        strides[d].idx = d;
        strides[d].stride = md.blocking_desc().strides[d];
        strides[d].size = md.padded_dims()[d];
    }
    std::sort(strides.begin(), strides.end(), stride_cmp);
    for (int i = 0; i < nblks; i++) {
        stride_t blk;
        blk.idx = md.blocking_desc().inner_idxs[i];
        blk.size = md.blocking_desc().inner_blks[i];
        if (i == 0 && blk.idx == strides[0].idx) { continue; }
        strides.insert(strides.begin(), blk);
    }
    ret.idx = strides[distance].idx;
    ret.size = strides[distance].size;
    return ret;
}

int innermost_block(const blocking_desc_t &blk) {
    int last = blk.inner_nblks - 1;
    return blk.inner_blks[last];
}

bool is_alt_faster_than_ref(const memory_desc_wrapper &src_mdw,
        const memory_desc_wrapper &dst_mdw,
        const compute::device_info_t *dev_info) {
    using namespace format_tag;
    int ndims = src_mdw.ndims();
    int last = ndims - 1;
    if (!src_mdw.matches_one_of_tag(abcd, abc, ab)) { return false; }
    // On GPUs newer than gen9 the reference implementation is usually faster.
    if (dev_info->gpu_arch() != compute::gpu_arch_t::gen9) { return false; }
    // Ensure a reasonable work group size.
    if (src_mdw.dims()[last] < 8) { return false; }
    // abcd->???b reorders are faster with the reference implementation.
    if (ndims == 4 && dst_mdw.blocking_desc().strides[1] == 1) { return false; }
    if (ndims == 3 && dst_mdw.blocking_desc().strides[0] == 1) { return false; }
    return true;
}

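// Checks whether the reorder is effectively an NxN transpose: src and dst have
// different innermost dims, both divisible by n, with no padding and no
// per-dim quantization parameters on either of those dims.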
bool matches_one_NxN_layout(const memory_desc_wrapper &src,
        const memory_desc_wrapper &dst, int n, int scale_mask) {
    if (dst.ndims() < 2) { return false; }
    if (!src.is_blocking_desc() || !dst.is_blocking_desc()) { return false; }
    auto dst_last = get_Nth_last_dim_or_block(dst, 0);
    auto src_last = get_Nth_last_dim_or_block(src, 0);

    if (dst_last.size % n != 0) { return false; }
    if (src_last.size % n != 0) { return false; }
    if (dst_last.idx == src_last.idx) { return false; }
    // No padding allowed on dimensions that are last in src or last in dst.
    if (src.padded_dims()[src_last.idx] != src.dims()[src_last.idx]) {
        return false;
    }
    if (src.padded_dims()[dst_last.idx] != src.dims()[dst_last.idx]) {
        return false;
    }
    if (dst.padded_dims()[src_last.idx] != dst.dims()[src_last.idx]) {
        return false;
    }
    if (dst.padded_dims()[dst_last.idx] != dst.dims()[dst_last.idx]) {
        return false;
    }
    // No scale mask allowed on src's or dst's innermost dimension.
    if (scale_mask & (1 << src_last.idx)) { return false; }
    if (scale_mask & (1 << dst_last.idx)) { return false; }

    return true;
}

// For cases like ab -> Ba4b, BA4b8a8b2a etc. Takes dst's last two dimensions
// and groups them together into packets of size 16 so that they can be
// written to in bursts.
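// Illustrative example: for ab -> Ba4b the innermost dst block is 4b and the
// next-to-last piece is dim a, so a 16-element packet covers 4 values of a by
// 4 values of b, while the src side reads the same 16 elements as one
// contiguous run along b.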
bool fill_conf_xab_xba(const memory_desc_wrapper &src,
        const memory_desc_wrapper &dst, int scale_mask, xb_to_xab_xba_t &cfg,
        int &vect_dim, int &vect_size, dim_t *blocks) {

    vect_size = 16;
    if (dst.ndims() < 2) { return false; }
    dimension src_last = get_Nth_last_dim_or_block(src);
    dimension dst_last = get_Nth_last_dim_or_block(dst);
    dimension dst_next_last = get_Nth_last_dim_or_block(dst, 1);

    if (src_last.idx != dst_last.idx && src_last.idx != dst_next_last.idx) {
        return false;
    }
    bool xb_to_xab = (src_last.idx == dst_last.idx);

    // No padding on dims that are src's or dst's innermost.
    if (src.padded_dims()[src_last.idx] != src.dims()[src_last.idx]) {
        return false;
    }
    if (src.padded_dims()[dst_last.idx] != src.dims()[dst_last.idx]) {
        return false;
    }
    if (dst.padded_dims()[src_last.idx] != dst.dims()[src_last.idx]) {
        return false;
    }
    if (dst.padded_dims()[dst_last.idx] != dst.dims()[dst_last.idx]) {
        return false;
    }
    if (src.offset0() != 0) { return false; }
    if (dst.offset0() != 0) { return false; }

    // No scale mask allowed on src's or dst's innermost dimension.
    if (scale_mask & (1 << src_last.idx)) { return false; }
    if (scale_mask & (1 << dst_last.idx)) { return false; }

    if (src_last.size < 16) { return false; }
    if (src_last.size % 16 != 0) { return false; }
    if (vect_size % dst_last.size != 0) { return false; }
    if (dst_next_last.size % (vect_size / dst_last.size) != 0) { return false; }

    dimension src_burst;
    dimension dst_burst[2];
    dimension dst_loop;
    dimension src_loop;

    src_burst.idx = src_last.idx;
    src_burst.size = vect_size;
    dst_burst[0] = dst_last;
    dst_burst[1].idx = dst_next_last.idx;
    dst_burst[1].size = 16 / dst_last.size;

    int dst_src_idx = xb_to_xab ? 0 : 1;
    dst_loop.size = src_burst.size / dst_burst[dst_src_idx].size;
    dst_loop.idx = src_burst.idx;

    src_loop.size = dst_burst[1 - dst_src_idx].size;
    src_loop.idx = dst_burst[1 - dst_src_idx].idx;

    vect_dim = src_last.idx;

    if ((dst_last.size * dst_next_last.size) % vect_size != 0) { return false; }
    blocks[src_loop.idx] = src_loop.size;
    cfg.blk_size = src_loop.size;
    cfg.src_blk_dim = src_loop.idx;
    cfg.src_blk_coeff = 1;
    cfg.dst_blk_dim = dst_loop.idx;

    if (dst_loop.idx == dst_burst[0].idx) {
        cfg.dst_blk_coeff = dst_burst[0].size;
    } else if (dst_loop.idx == dst_burst[1].idx) {
        cfg.dst_blk_coeff = dst_burst[1].size;
    } else {
        cfg.dst_blk_coeff = 1;
    }
    cfg.vd = xb_to_xab;
    return true;
}

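// Applicability check for the xb_to_xab_xba kernel: runs fill_conf_xab_xba on
// throwaway outputs and only looks at the return value.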
bool fits_xab_xba(const memory_desc_wrapper &src,
        const memory_desc_wrapper &dst, int scale_mask) {
    xb_to_xab_xba_t cfg;
    int vect_dim;
    int vect_size;
    dim_t blocks[6];

    return fill_conf_xab_xba(
            src, dst, scale_mask, cfg, vect_dim, vect_size, &blocks[0]);
}

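// Checks whether a 2D blocked layout's inner blocks end in 8a4b or 8a2b, i.e.
// the last two entries of inner_blks/inner_idxs are {8, a} followed by
// {4 or 2, b}.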
bool matches_ABxxxx8ayb_layout(const blocking_desc_t &blk, int ndims) {
    if (ndims > 2) { return false; }
    int last = blk.inner_nblks - 1;
    // Don't allow this kernel when two adjacent blocks over b create a total
    // block size smaller than 16 - in that situation the macros used to
    // calculate the dst address return wrong values.
    for (int d = last - 2; d >= 0; d--) {
        if (blk.inner_idxs[d] == ndims - 1) {
            int double_block = blk.inner_blks[last] * blk.inner_blks[d];
            if (double_block < 16) {
                return false;
            } else {
                break;
            }
        }
    }
    return ((blk.inner_blks[last] == 4 || blk.inner_blks[last] == 2)
            && blk.inner_idxs[last] == ndims - 1
            && blk.inner_blks[last - 1] == 8
            && blk.inner_idxs[last - 1] == ndims - 2);
}

bool dim_is_div_by_16_or_less_than_16(
        const memory_desc_wrapper &src, int dim_index) {
    const auto &padded_dims = src.padded_dims();
    assert(dim_index < src.ndims());
    return (padded_dims[dim_index] % 16 == 0 || padded_dims[dim_index] < 16);
}

bool is_broadcast_by_strides(const memory_desc_wrapper &mdw) {
    if (mdw.is_blocking_desc()) {
        for (int i = 0; i < mdw.ndims(); i++) {
            if (mdw.blocking_desc().strides[i] == 0) { return true; }
        }
    }
    return false;
}

bool is_padded(const memory_desc_wrapper &mdw, int dim) {
    return (mdw.dims()[dim] != mdw.padded_dims()[dim]);
}

bool fits_3ch(const memory_desc_wrapper &src_mdw,
        const memory_desc_wrapper &dst_mdw, int scale_mask) {
    // TODO: make it more generic; for now it only handles the dense -> padded
    // case.
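    // A typical case (presumably the source of the '3ch' name): a dense nhwc
    // tensor with 3 channels reordered into a channel-blocked layout such as
    // nChw16c, where the channel dim gets padded up to the block size.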
    if (src_mdw.ndims() < 2 || dst_mdw.ndims() < 2) { return false; }

    auto last_dim_src = get_Nth_last_dim_or_block(src_mdw);
    auto nextlast_dim_src = get_Nth_last_dim_or_block(src_mdw, 1);
    auto last_dim_dst = get_Nth_last_dim_or_block(dst_mdw);
    auto nextlast_dim_dst = get_Nth_last_dim_or_block(dst_mdw, 1);

    bool same_nextlast_dims = (nextlast_dim_src.idx == nextlast_dim_dst.idx);

    // src's innermost dim is assumed to be contiguous; it will be read from
    // adjacent memory addresses.
    auto src_innermost_stride
            = src_mdw.blocking_desc().strides[last_dim_src.idx];
    if (last_dim_src.size != src_mdw.padded_dims()[last_dim_src.idx]
            || (src_innermost_stride > 1 && src_mdw.is_plain())) {
        return false;
    }
    if (last_dim_src.idx != last_dim_dst.idx) { return false; }
    if (last_dim_src.size > last_dim_dst.size) { return false; }
    if (last_dim_dst.size > 16 || last_dim_dst.size % 8 != 0) { return false; }
    if (last_dim_src.idx == nextlast_dim_src.idx) { return false; }
    if (last_dim_src.idx == nextlast_dim_dst.idx) { return false; }
    if (nextlast_dim_src.size % 2 != 0) { return false; }
    if (nextlast_dim_dst.size % 2 != 0) { return false; }
    if (is_padded(src_mdw, last_dim_src.idx)) { return false; }
    // If src's and dst's next-to-last dims are not the same, the CL code will
    // use nested loops. In the inner loop the data offset is incremented and
    // cannot be converted back into tensor coordinates, so no operation in the
    // inner loop may depend on the next-to-last dim's coordinates - which
    // rules out padding and per-dim scale quantization on those dims.
    if (!same_nextlast_dims) {
        if (is_padded(src_mdw, nextlast_dim_src.idx)) { return false; }
        if (is_padded(src_mdw, nextlast_dim_dst.idx)) { return false; }
        if (is_padded(dst_mdw, nextlast_dim_src.idx)) { return false; }
        if (is_padded(dst_mdw, nextlast_dim_dst.idx)) { return false; }
        if (scale_mask & (1 << nextlast_dim_src.idx)) { return false; }
        if (scale_mask & (1 << nextlast_dim_dst.idx)) { return false; }
    }
    if (scale_mask & (1 << last_dim_src.idx)) { return false; }
    // No second level of blocking on the innermost dim in dst.
    if (dst_mdw.padded_dims()[last_dim_src.idx] != last_dim_dst.size) {
        return false;
    }
    return true;
}

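// Picks the most suitable custom kernel for a given src/dst pair. The checks
// go from the most specialized kernels to the more generic ones; if nothing
// matches, 'none' is returned and this implementation reports unimplemented,
// leaving the reorder to other (generic/reference) implementations.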
reorder_kernel_t select_kernel(const reorder_conf_t &conf,
        const memory_desc_wrapper &src_mdw, const memory_desc_wrapper &dst_mdw,
        const compute::device_info_t *dev_info) {
    using namespace format_tag;

    const auto &padded_dims = dst_mdw.padded_dims();

    int last = conf.ndims - 1;
    size_t last_dim = padded_dims[last];

    const bool multi_scale_quant
            = (conf.src_quant.with_scale() && conf.src_quant.num_scales() > 1)
            || (conf.dst_quant.with_scale() && conf.dst_quant.num_scales() > 1);
    const bool has_padding_or_multi_scale_quant
            = conf.has_padding || multi_scale_quant;

    const bool type_s8_u8 = utils::one_of(src_mdw.data_type(), dnnl_s8, dnnl_u8)
            || utils::one_of(dst_mdw.data_type(), dnnl_s8, dnnl_u8);

    const bool allow_unroll
            = !conf.has_padding && !multi_scale_quant && !type_s8_u8;

    if (is_broadcast_by_strides(src_mdw) || is_broadcast_by_strides(dst_mdw)) {
        return reorder_kernel_t::none;
    }

    int mask = conf.src_quant.scale_mask() | conf.src_quant.zp_mask()
            | conf.dst_quant.scale_mask() | conf.dst_quant.zp_mask();
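    // Note: each set bit of 'mask' marks a dimension along which scales or
    // zero points vary, so kernels that vectorize a dimension must make sure
    // the corresponding bit is not set.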

    if (matches_one_NxN_layout(src_mdw, dst_mdw, 16, mask)) {
        // W/A for a compiler bug: avoid using intel_sub_group_shuffle with
        // SIMD16 on gen12lp.
        if (dev_info->gpu_arch() == compute::gpu_arch_t::xe_lp) {
            return reorder_kernel_t::transpose8x8;
        }
        if (dev_info->gpu_arch() == compute::gpu_arch_t::gen9) {
            return reorder_kernel_t::local16x16;
        }
        // W/A for an assumed compiler bug: avoid using intel_sub_group_shuffle
        // with SIMD16 on gen11.
        if (dev_info->gpu_arch() == compute::gpu_arch_t::gen11) {
            return reorder_kernel_t::transpose8x8;
        }
        return reorder_kernel_t::transpose16x16;
    }
    if (matches_one_NxN_layout(src_mdw, dst_mdw, 8, mask)) {
        if (dev_info->gpu_arch() == compute::gpu_arch_t::gen9) {
            return reorder_kernel_t::local8x8;
        }
        return reorder_kernel_t::transpose8x8;
    }
    if (src_mdw.matches_one_of_tag(nhwc) && dst_mdw.matches_one_of_tag(nchw)
            && padded_dims[last] % 16 == 0
            && dim_is_div_by_16_or_less_than_16(dst_mdw, 1)) {
        return reorder_kernel_t::reorder_nchw;
    }
    if (src_mdw.matches_one_of_tag(nhwc) && dst_mdw.matches_one_of_tag(nchw)
            && dim_is_div_by_16_or_less_than_16(dst_mdw, 1)) {
        return reorder_kernel_t::unaligned_sizes;
    }

    if (!has_padding_or_multi_scale_quant && (conf.nelems % 256 == 0)
            && src_mdw.similar_to(dst_mdw, true, false, 0)) {
        return reorder_kernel_t::dense_vector;
    }

    if (fits_3ch(src_mdw, dst_mdw, mask)) {
        return reorder_kernel_t::pad_innermost;
    }
    if (fits_xab_xba(src_mdw, dst_mdw, mask)) {
        return reorder_kernel_t::xb_to_xab_xba;
    }
    // This kernel works on tensors that share a common innermost dim. It tries
    // to access memory in chunks large enough to utilize whole cache lines.
    auto src_last = get_Nth_last_dim_or_block(src_mdw);
    auto dst_last = get_Nth_last_dim_or_block(dst_mdw);
    auto inner_dim = dst_last.idx;
    if (src_last.idx == dst_last.idx
            && (src_last.size <= 16 || dst_last.size <= 16)
            && src_last.size % 8 == 0 && dst_last.size % 8 == 0
            && conf.ndims <= MAX_NDIMS
            && (src_last.size % (2 * dst_last.size) == 0
                    || dst_last.size % (2 * src_last.size) == 0)
            && src_mdw.offset0() == 0 && dst_mdw.offset0() == 0
            && !(mask & (1 << inner_dim))
            && dst_mdw.dims()[inner_dim] == dst_mdw.padded_dims()[inner_dim]
            && src_mdw.dims()[inner_dim] == src_mdw.padded_dims()[inner_dim]) {
        return reorder_kernel_t::vectorize_groups;
    }

    if (allow_unroll) {
        if (src_mdw.matches_one_of_tag(ABc16a16b, ABc16b16a, ABcd16a16b,
                    ABcd16b16a, ABcde16a16b, ABcde16b16a, BAc16a16b, BAc16b16a,
                    BAcd16a16b, BAcd16b16a, BAcde16b16a)
                || dst_mdw.matches_one_of_tag(ABc16a16b, ABc16b16a, ABcd16a16b,
                        ABcd16b16a, ABcde16a16b, ABcde16b16a, BAc16a16b,
                        BAc16b16a, BAcd16a16b, BAcd16b16a, BAcde16b16a)) {
            return reorder_kernel_t::unroll_16a16b;
        }
        if (src_mdw.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b)
                || dst_mdw.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b)) {
            return reorder_kernel_t::unroll_16b;
        }
        if (src_mdw.matches_one_of_tag(aBCd16b16c, aBCd16c16b, aBCde16b16c,
                    aBCde16c16b, aBCdef16b16c, aBCdef16c16b, aCBd16b16c,
                    aCBd16c16b, aCBde16b16c, aCBde16c16b, aCBdef16c16b)
                || dst_mdw.matches_one_of_tag(aBCd16b16c, aBCd16c16b,
                        aBCde16b16c, aBCde16c16b, aBCdef16b16c, aBCdef16c16b,
                        aCBd16b16c, aCBd16c16b, aCBde16b16c, aCBde16c16b,
                        aCBdef16c16b)) {
            return reorder_kernel_t::unroll_16b16c;
        }
    }

    if (src_mdw.matches_one_of_tag(abdfce) && dst_mdw.matches_one_of_tag(abcdef)
            && ((padded_dims[conf.ndims - 2] % 16) == 0)
            && dim_is_div_by_16_or_less_than_16(dst_mdw, last)) {
        return reorder_kernel_t::plain_xFxE_to_abcdef;
    }

    if ((src_mdw.matches_one_of_tag(abcd, acdb))
            && dst_mdw.matches_one_of_tag(
                    ABcd4a2b, ABcd4a4b, ABcd8a2b, ABcd8a4b)
            && src_mdw.is_dense() && dst_mdw.is_dense(true)) {
        return reorder_kernel_t::plain_to_ABcd84a42b;
    }

    // This kernel is used when the last dimension is not reordered;
    // it vectorizes across that dimension.
    if (!has_padding_or_multi_scale_quant && src_mdw.is_dense()
            && dst_mdw.is_dense() && last_dim % 8 == 0
            && dst_mdw.md_->format_desc.blocking.strides[last] == 1
            && src_mdw.md_->format_desc.blocking.strides[last] == 1
            && conf.ndims <= MAX_NDIMS) {
        return reorder_kernel_t::vectorize_last_dim;
    }

    // This kernel supports 2D reorders into blocked formats that end in 8a4b
    // or 8a2b, regardless of how many levels of blocking precede them, but
    // without padding.
    if (!has_padding_or_multi_scale_quant && src_mdw.matches_one_of_tag(ab)
            && matches_ABxxxx8ayb_layout(
                    dst_mdw.md_->format_desc.blocking, conf.ndims)
            && padded_dims[last] % 16 == 0) {
        return reorder_kernel_t::plain_to_ABxx8ayb;
    }

    if (conf.ndims >= 2 && conf.ndims <= 4
            && src_mdw.md_->format_desc.blocking.inner_nblks == 0
            && dst_mdw.md_->format_desc.blocking.inner_nblks == 0
            && src_mdw.offset0() == 0 && dst_mdw.offset0() == 0
            && is_alt_faster_than_ref(src_mdw, dst_mdw, dev_info)
            && !has_padding_or_multi_scale_quant) {
        return reorder_kernel_t::reorder_alt;
    }

    return reorder_kernel_t::none;
}

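// Defines the macros used by the 'reorder_alt' kernel: S0/S1/S2 and D0/D1/D2
// are the strides of the three innermost src/dst dims, while SB/DB and BLK
// cover the fourth-from-last dim (treated as a batch), when present.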
void custom_reorder_t::pd_t::alt_defines(
        compute::kernel_ctx_t &kernel_ctx) const {
    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper dst_mdw(dst_md());
    size_t ndims = src_mdw.ndims();
    size_t last = ndims - 1;

    auto sdim = src_mdw.dims();
    auto sstr = src_mdw.blocking_desc().strides;
    auto dstr = dst_mdw.blocking_desc().strides;
    kernel_ctx.define_int("ALT_OFFSETS", 1);
    // LIMIT_MAX_D0 is necessary to avoid writing past the end of the buffer.
    if (conf.dispatch.nd_range().global_range()[0] != (size_t)sdim[last]) {
        kernel_ctx.define_int("LIMIT_MAX_D0", sdim[last]);
    }
    kernel_ctx.define_int("S0", sstr[last]);
    kernel_ctx.define_int("S1", sstr[last - 1]);
    kernel_ctx.define_int("S2", ndims > 2 ? sstr[last - 2] : 1);
    kernel_ctx.define_int("SB", ndims > 3 ? sstr[last - 3] : 1);
    kernel_ctx.define_int("D0", dstr[last]);
    kernel_ctx.define_int("D1", dstr[last - 1]);
    kernel_ctx.define_int("D2", ndims > 2 ? dstr[last - 2] : 1);
    kernel_ctx.define_int("DB", ndims > 3 ? dstr[last - 3] : 1);
    kernel_ctx.define_int("BLK", ndims > 3 ? sdim[last - 3] : 1);
}

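// The 'reorder_alt' kernel bypasses the generic dispatcher: the global range
// is taken directly from the three innermost src dims and the work group size
// is chosen from the innermost dim, then the global range is rounded up to
// keep work groups uniform.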
void custom_reorder_t::pd_t::alt_gen() {
    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper dst_mdw(dst_md());
    auto sdim = src_mdw.dims();

    size_t last = src_mdw.ndims() - 1;
    size_t gws3 = src_mdw.ndims() > 2 ? sdim[last - 2] : 1;
    size_t gws[3] = {(size_t)sdim[last], (size_t)sdim[last - 1], gws3};
    size_t work_group_size = 32;
    if (sdim[last] <= 16) { work_group_size = 16; }
    if (sdim[last] <= 8) { work_group_size = 8; }
    const size_t lws[3] = {work_group_size, 1, 1};
    // Don't use nonuniform work groups; round up the number of work items if
    // needed.
    size_t mod = gws[0] % lws[0];
    if (mod != 0) { gws[0] += lws[0] - mod; }
    conf.dispatch.generate_override(gws, lws);
}

status_t custom_reorder_t::pd_t::init_conf(engine_t *engine) {
    using namespace format_tag;

    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper dst_mdw(dst_md());

    conf.src_md_info = memory_desc_info_t::create(src_mdw);
    conf.dst_md_info = memory_desc_info_t::create(dst_mdw);

    status_t status = status::success;

    const auto &padded_dims = dst_mdw.padded_dims();
    conf.src_quant = {attr(), src_mdw, DNNL_ARG_SRC};
    conf.dst_quant = {attr(), dst_mdw, DNNL_ARG_DST};
    conf.sum_quant = {attr()};
    conf.has_padding = !src_mdw.is_dense() || !dst_mdw.is_dense();
    conf.ndims = src_mdw.ndims();
    conf.nelems = utils::array_product(padded_dims, conf.ndims);

    conf.sub_group_size = 1;

    if (conf.nelems == 0) return status::success;

    int last = conf.ndims - 1;
    size_t last_dim = padded_dims[last];

    auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);

    conf.implementation = select_kernel(
            conf, src_mdw, dst_mdw, compute_engine->device_info());

    dim_t blocks[MAX_NDIMS] = {1, 1, 1, 1, 1, 1};
    int vect_size = 1;
    int vect_dim = 0;
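    // blocks[] holds per-dimension block sizes handed to the dispatcher, while
    // vect_dim/vect_size describe the dimension vectorized across the
    // subgroup; the switch below fills them in per kernel flavor.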

    conf.dispatch = compute_engine->create_dispatch(dst_mdw.md_);
    int temp_block = 1;

    const bool may_use_sg8 = compute_engine->mayiuse_sub_group(8);
    int mask = conf.src_quant.scale_mask() | conf.src_quant.zp_mask()
            | conf.dst_quant.scale_mask() | conf.dst_quant.zp_mask();

    switch (conf.implementation) {
        case none: return status_t::dnnl_unimplemented;
        case reorder_alt:
            // special handling with dispatcher override
            conf.sub_group_size = 16;
            break;
        case dense_vector:
            // see special handling below
            conf.sub_group_size = 16;
            break;
        case unroll_16b:
            conf.sub_group_size = 16;
            vect_dim = 1;
            vect_size = 16;
            break;
        case unroll_16b16c:
            conf.sub_group_size = 16;
            blocks[2] = 16;
            vect_dim = 1;
            vect_size = 16;
            break;
        case unroll_16a16b:
            conf.sub_group_size = 16;
            blocks[0] = 16;
            vect_dim = 1;
            vect_size = 16;
            break;
        case plain_to_ABcd84a42b: {
            auto &blk = dst_mdw.blocking_desc();
            int inner_block = blk.inner_blks[blk.inner_nblks - 1];
            int outer_block = blk.inner_blks[blk.inner_nblks - 2];
            conf.sub_group_size = inner_block * outer_block;
            blocks[0] = outer_block;
            blocks[1] = inner_block;
            vect_dim = 3;
            vect_size = conf.sub_group_size;
        } break;
        case xb_to_xab_xba:
            fill_conf_xab_xba(src_mdw, dst_mdw, mask, conf.aux_data.ab,
                    vect_dim, vect_size, &blocks[0]);
            conf.sub_group_size = vect_size;
            break;
        case vectorize_last_dim:
            vect_dim = last;
            vect_size = (last_dim % 16 == 0) ? 16 : 8;
            if (!may_use_sg8 && vect_size == 8) {
                return status_t::dnnl_unimplemented;
            }
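            // Block the first dimension (searching backwards from the end)
            // whose padded size is divisible by 4, 8 or 16, using the largest
            // such factor; once one is found, temp_block != 1 stops the loop.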
            for (int dim = last - 1;
                    dim >= 0 && dim < MAX_NDIMS && temp_block == 1; dim--) {
                if (padded_dims[dim] % 4 == 0) { temp_block = 4; }
                if (padded_dims[dim] % 8 == 0) { temp_block = 8; }
                if (padded_dims[dim] % 16 == 0) { temp_block = 16; }
                blocks[dim] = temp_block;
            }
            break;
        case pad_innermost: {
            auto last_dim_src = get_Nth_last_dim_or_block(src_mdw);
            auto nextlast_dim_src = get_Nth_last_dim_or_block(src_mdw, 1);
            auto last_dim_dst = get_Nth_last_dim_or_block(dst_mdw);
            auto nextlast_dim_dst = get_Nth_last_dim_or_block(dst_mdw, 1);

            int min_common_size
                    = std::min(last_dim_src.size, last_dim_dst.size);
            int max_common_size
                    = std::max(last_dim_src.size, last_dim_dst.size);
            conf.sub_group_size = max_common_size;
            if (!may_use_sg8 && conf.sub_group_size == 8) {
                return status_t::dnnl_unimplemented;
            }

            // A group size bigger than 4 would need too much private memory;
            // group size 1 would give worse performance than the reference
            // kernel.
            int max_group_size = 4;
            while (nextlast_dim_src.size % max_group_size != 0
                    || nextlast_dim_dst.size % max_group_size != 0) {
                max_group_size--;
            }

            conf.aux_data.vg.vector_dim = last_dim_src.idx;
            conf.aux_data.vg.src_loop_dim = nextlast_dim_dst.idx;
            conf.aux_data.vg.dst_loop_dim = nextlast_dim_src.idx;
            conf.aux_data.vg.innermost_size = min_common_size;

            blocks[conf.aux_data.vg.src_loop_dim] = max_group_size;
            blocks[conf.aux_data.vg.dst_loop_dim] = max_group_size;
            // If the src loop and dst loop dims are the same, the CL code
            // would iterate over the same dimension in both the inner and the
            // outer loop, effectively doing redundant work; set the inner loop
            // counter to 1 in that case.
            conf.aux_data.vg.group_size
                    = (nextlast_dim_dst.idx == nextlast_dim_src.idx)
                    ? 1
                    : max_group_size;

            vect_dim = conf.aux_data.vg.vector_dim;
            vect_size = conf.sub_group_size;
        } break;

        case vectorize_groups: {
            auto last_dim_src = get_Nth_last_dim_or_block(src_mdw);
            auto nextlast_dim_src = get_Nth_last_dim_or_block(src_mdw, 1);
            auto last_dim_dst = get_Nth_last_dim_or_block(dst_mdw);
            auto nextlast_dim_dst = get_Nth_last_dim_or_block(dst_mdw, 1);
            int min_common_size
                    = std::min(last_dim_src.size, last_dim_dst.size);
            vect_size = (min_common_size % 16 == 0) ? 16 : 8;
            vect_dim = last_dim_src.idx;
            if (!may_use_sg8 && vect_size == 8) {
                return status_t::dnnl_unimplemented;
            }

            assert(last_dim_src.size % vect_size == 0
                    && last_dim_dst.size % vect_size == 0);
            assert(last_dim_src.idx == last_dim_dst.idx);
            int src_chunks;
            if (last_dim_src.size / vect_size > 1) {
                src_chunks = last_dim_src.size / vect_size;
                conf.aux_data.vg.dst_loop_dim = last_dim_src.idx;
            } else {
                src_chunks = nextlast_dim_src.size;
                conf.aux_data.vg.dst_loop_dim = nextlast_dim_src.idx;
            }
            int dst_chunks;
            if (last_dim_dst.size / vect_size > 1) {
                dst_chunks = last_dim_dst.size / vect_size;
                conf.aux_data.vg.src_loop_dim = last_dim_dst.idx;
            } else {
                dst_chunks = nextlast_dim_dst.size;
                conf.aux_data.vg.src_loop_dim = nextlast_dim_dst.idx;
            }
            // TODO:
            // The final algorithm for selecting the group size should consider:
            // 1. Group size must be small enough to guarantee no spills.
            // 2. Group size should be large enough to fill whole cache lines
            //    on both reads and writes, with the line size determined by HW.
            // 3. If there's not enough data to feed all EUs, ignore (2) and
            //    decrease the group size.
            int max_data_size = (int)std::max(
                    src_mdw.data_type_size(), dst_mdw.data_type_size());

            int group = 16 / max_data_size;
            while (group > 1) {
                if (src_chunks % group == 0 && dst_chunks % group == 0) {
                    break;
                }
                group--;
            }
            assert(group >= 1);

            conf.aux_data.vg.vector_dim = last_dim_src.idx;
            conf.aux_data.vg.group_size = group;
            conf.sub_group_size = vect_size;

            blocks[conf.aux_data.vg.src_loop_dim] = group;
            blocks[conf.aux_data.vg.dst_loop_dim] = group;
        } break;
        case plain_to_ABxx8ayb:
            conf.sub_group_size = 16;
            blocks[0] = 8;
            vect_dim = last;
            vect_size = 16;
            break;
        case plain_xFxE_to_abcdef:
            conf.sub_group_size = 16;
            blocks[5] = nstl::min(padded_dims[conf.ndims - 1], dnnl_dim_t(16));
            vect_dim = 4;
            vect_size = 16;
            break;
        case transpose8x8:
        case local8x8:
            if (!may_use_sg8) { return status_t::dnnl_unimplemented; }
            conf.sub_group_size = 8;
            blocks[get_Nth_last_dim_or_block(dst_mdw).idx] = 8;
            vect_dim = get_Nth_last_dim_or_block(src_mdw).idx;
            vect_size = 8;
            break;
        case transpose16x16:
        case local16x16:
            conf.sub_group_size = 16;
            blocks[get_Nth_last_dim_or_block(dst_mdw).idx] = 16;
            vect_dim = get_Nth_last_dim_or_block(src_mdw).idx;
            vect_size = 16;
            break;
        case reorder_nchw:
            conf.sub_group_size = 16;
            blocks[1] = nstl::min(padded_dims[1], dnnl_dim_t(16));
            vect_dim = 3;
            vect_size = 16;
            break;
        case unaligned_sizes: blocks[1] = padded_dims[1]; break;
    }

    // Special case for the dense_vector kernel - treat tensors as flat 1D
    // vectors.
    if (conf.implementation == dense_vector) {
        conf.dispatch.define_dim("D0", 0, conf.nelems, 16);
        CHECK(conf.dispatch.vectorize_dim("D0", 16));
    } else {
        for (int i = 0; i < MAX_NDIMS; ++i) {
            auto dim_str = utils::format("D%d", i);
            if (i < dst_mdw.ndims()) {
                int dim = padded_dims[i];
                // If needed, pad the vectorized dim once more so that it is
                // aligned with the vector size.
                if (i == vect_dim) { dim = utils::rnd_up(dim, vect_size); }
                conf.dispatch.define_dim(dim_str, i, dim, blocks[i]);
            } else {
                conf.dispatch.define_dim(dim_str, 1);
            }
        }
        if (vect_size != 1) {
            auto dim_str = utils::format("D%d", vect_dim);
            CHECK(conf.dispatch.vectorize_dim(dim_str, vect_size));
        }
    }

    if (conf.implementation == reorder_alt) {
        alt_gen();
    } else {
        conf.dispatch.generate();
    }
    return status;
}

status_t custom_reorder_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    using namespace format_tag;

    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper dst_mdw(dst_md());

    if (conf.nelems == 0) return status::success;

    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.add_option("-cl-std=CL2.0");

    conf.src_quant.define_macros(kernel_ctx, "SRC");
    conf.dst_quant.define_macros(kernel_ctx, "DST");
    conf.sum_quant.define_macros(kernel_ctx, "SUM");

    def_dispatch(kernel_ctx, conf.dispatch);

    // The 'unaligned_sizes' kernel uses the same .cl implementation; the
    // difference is in the sizes of blocks[].
    if (conf.implementation == unaligned_sizes) {
        kernel_ctx.define_int("UNALIGNED", 1);
    }
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);

    kernel_ctx.define_int("PAD_FILL_ZERO", conf.has_padding);
    if (conf.implementation == dense_vector) {
        kernel_ctx.add_option("-Dcl_intel_subgroups_char");
        kernel_ctx.define_int("USE_DENSE_VECT", 1);
    }

    def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");

    // Distinguish between the various flavors of the unroll kernel.
    if (src_mdw.matches_one_of_tag(
                ABc16a16b, ABcd16a16b, ABcde16a16b, BAc16a16b, BAcd16a16b)) {
        kernel_ctx.define_int("SRC_16A16B", 1);
    } else if (src_mdw.matches_one_of_tag(ABc16b16a, ABcd16b16a, ABcde16b16a,
                       BAc16b16a, BAcd16b16a, BAcde16b16a)) {
        kernel_ctx.define_int("SRC_16B16A", 1);
    } else if (src_mdw.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b)) {
        kernel_ctx.define_int("SRC_16B", 1);
    } else if (src_mdw.matches_one_of_tag(aBCd16b16c, aBCde16b16c, aBCdef16b16c,
                       aCBd16b16c, aCBde16b16c)) {
        kernel_ctx.define_int("SRC_16B16C", 1);
    } else if (src_mdw.matches_one_of_tag(aBCd16c16b, aBCde16c16b, aBCdef16c16b,
                       aCBd16c16b, aCBde16c16b, aCBdef16c16b)) {
        kernel_ctx.define_int("SRC_16C16B", 1);
    }
    if (dst_mdw.matches_one_of_tag(
                ABc16a16b, ABcd16a16b, ABcde16a16b, BAc16a16b, BAcd16a16b)) {
        kernel_ctx.define_int("DST_16A16B", 1);
    } else if (dst_mdw.matches_one_of_tag(ABc16b16a, ABcd16b16a, ABcde16b16a,
                       BAc16b16a, BAcd16b16a, BAcde16b16a)) {
        kernel_ctx.define_int("DST_16B16A", 1);
    } else if (dst_mdw.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b)) {
        kernel_ctx.define_int("DST_16B", 1);
    } else if (dst_mdw.matches_one_of_tag(aBCd16b16c, aBCde16b16c, aBCdef16b16c,
                       aCBd16b16c, aCBde16b16c)) {
        kernel_ctx.define_int("DST_16B16C", 1);
    } else if (dst_mdw.matches_one_of_tag(aBCd16c16b, aBCde16c16b, aBCdef16c16b,
                       aCBd16c16b, aCBde16c16b, aCBdef16c16b)) {
        kernel_ctx.define_int("DST_16C16B", 1);
    }

    if (conf.implementation == reorder_alt) { alt_defines(kernel_ctx); }
    if (conf.implementation == plain_xFxE_to_abcdef)
        kernel_ctx.define_int("PLAIN_xFxE_TO_ABCDEF", 1);

    if (conf.implementation == plain_to_ABcd84a42b) {
        kernel_ctx.define_int("PLAIN_TO_ABCD84A42B", 1);
        auto r = conf.dispatch.nd_range();
        auto *lr = r.local_range();
        kernel_ctx.define_int(
                "SG_PER_WG", (lr[0] * lr[1] * lr[2]) / conf.sub_group_size);
    }
    if (conf.implementation == xb_to_xab_xba) {
        kernel_ctx.define_int("XAB_XBA", 1);
        auto r = conf.dispatch.nd_range();
        auto *lr = r.local_range();
        kernel_ctx.define_int(
                "SG_PER_WG", (lr[0] * lr[1] * lr[2]) / conf.sub_group_size);
        kernel_ctx.define_int("BLOCK_SIZE", conf.aux_data.ab.blk_size);
        kernel_ctx.define_int("SRC_BLK_DIM", conf.aux_data.ab.src_blk_dim);
        kernel_ctx.define_int("SRC_OFF_COEFF", conf.aux_data.ab.src_blk_coeff);
        kernel_ctx.define_int("DST_BLK_DIM", conf.aux_data.ab.dst_blk_dim);
        kernel_ctx.define_int("DST_OFF_COEFF", conf.aux_data.ab.dst_blk_coeff);
        kernel_ctx.define_int("XB_TO_XAB", conf.aux_data.ab.vd);
    }

    if (conf.implementation == vectorize_last_dim) {
        kernel_ctx.define_int("VECTORIZE_LAST_DIM", 1);
    }

    if (conf.implementation == pad_innermost) {
        kernel_ctx.define_int("PAD_INNERMOST", 1);
        kernel_ctx.define_int(
                "VECT_DIM", conf.aux_data.vg.vector_dim); //useless
        kernel_ctx.define_int("SRC_LOOP_DIM", conf.aux_data.vg.src_loop_dim);
        kernel_ctx.define_int("DST_LOOP_DIM", conf.aux_data.vg.dst_loop_dim);
        kernel_ctx.define_int("GROUP", conf.aux_data.vg.group_size);
        auto r = conf.dispatch.nd_range();
        auto *lr = r.local_range();
        if (!lr) return status::runtime_error;
        kernel_ctx.define_int(
                "SG_PER_WG", (lr[0] * lr[1] * lr[2]) / conf.sub_group_size);
        kernel_ctx.define_int(
                "INNERMOST_SIZE", conf.aux_data.vg.innermost_size);
        kernel_ctx.define_int("VECT_SIZE", conf.sub_group_size);
        bool has_non_innermost_padding = false;
        for (int i = 0; i < MAX_NDIMS; i++) {
            if (i == conf.aux_data.vg.vector_dim) { continue; }
            has_non_innermost_padding
                    |= (dst_mdw.dims()[i] != dst_mdw.padded_dims()[i]);
        }
        kernel_ctx.define_int(
                "NON_INNERMOST_PADDING", has_non_innermost_padding);
        auto last_dim_dst = get_Nth_last_dim_or_block(dst_mdw);
        kernel_ctx.define_int("DST_INNERMOST_STRIDE",
                dst_mdw.is_plain()
                        ? dst_mdw.blocking_desc().strides[last_dim_dst.idx]
                        : 1);
    }
    if (conf.implementation == vectorize_groups) {
        kernel_ctx.define_int("VECTORIZE_GROUPS", 1);
        kernel_ctx.define_int("VECT_DIM", conf.aux_data.vg.vector_dim);
        kernel_ctx.define_int("SRC_LOOP_DIM", conf.aux_data.vg.src_loop_dim);
        kernel_ctx.define_int("DST_LOOP_DIM", conf.aux_data.vg.dst_loop_dim);
        kernel_ctx.define_int("GROUP", conf.aux_data.vg.group_size);
    }
    if (conf.implementation == plain_to_ABxx8ayb) {
        kernel_ctx.define_int("PLAIN_TO_AB_XX_8AYB", 1);
        kernel_ctx.define_int(
                "BLK_L", innermost_block(dst_mdw.md_->format_desc.blocking));
    }

    if (conf.implementation == transpose8x8
            || conf.implementation == transpose16x16) {
        kernel_ctx.define_int("TRANSPOSE_NXN", 1);
        kernel_ctx.define_int(
                "DST_BLOCK_DIM", get_Nth_last_dim_or_block(src_mdw).idx);
    }

    if (conf.implementation == local8x8 || conf.implementation == local16x16) {
        kernel_ctx.define_int("LOCAL_NXN", 1);
        auto r = conf.dispatch.nd_range();
        auto *lr = r.local_range();
        if (!lr) return status::runtime_error;
        kernel_ctx.define_int("SG_PER_WG", lr[0] * lr[1] * lr[2]);
        kernel_ctx.define_int(
                "DST_BLOCK_DIM", get_Nth_last_dim_or_block(src_mdw).idx);
    }

    if (conf.implementation == reorder_nchw) {
        kernel_ctx.define_int("REORDER_NCHW", 1);
    }

    kernel_ctx.print_options();
    return status::success;
}

void custom_reorder_t::pd_t::init_scratchpad() {
    if (conf.src_quant.with_scale()) {
        auto scratchpad = scratchpad_registry().registrar();
        scratchpad.book(memory_tracking::names::key_reorder_src_scales,
                conf.src_quant.num_scales(), sizeof(float),
                OCL_BUFFER_ALIGNMENT);
    }
    if (conf.dst_quant.with_scale()) {
        auto scratchpad = scratchpad_registry().registrar();
        scratchpad.book(memory_tracking::names::key_reorder_dst_scales,
                conf.dst_quant.num_scales(), sizeof(float),
                OCL_BUFFER_ALIGNMENT);
    }
}

status_t custom_reorder_t::execute(const exec_ctx_t &ctx) const {

    status_t status = status::success;

    auto &src = CTX_IN_STORAGE(DNNL_ARG_FROM);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_TO);
    CHECK(status);

    const auto &conf = pd()->conf;
    if (conf.nelems == 0) return status::success;

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, src);
    arg_list.set(1, dst);

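    // The argument indices below must match the parameter order of the OpenCL
    // reorder kernel: src, dst, then src/dst scales and zero points, then the
    // sum scales and zero point.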
    arg_list.set(2, conf.src_quant.scales(ctx));
    arg_list.set(3, conf.src_quant.zero_points(ctx));
    arg_list.set(4, conf.dst_quant.scales(ctx));
    arg_list.set(5, conf.dst_quant.zero_points(ctx));

    arg_list.set(6, conf.sum_quant.scales());
    arg_list.set(7, conf.sum_quant.zero_points());

    auto nd_range = conf.dispatch.nd_range();

    status = parallel_for(ctx, nd_range, kernel_, arg_list);

    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl