1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <algorithm> |
18 | |
19 | #include "gpu/ocl/custom_reorder.hpp" |
20 | |
21 | #include "common/utils.hpp" |
22 | #include "gpu/ocl/ocl_stream.hpp" |
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace ocl { |
28 | |
29 | using namespace dnnl::impl::memory_tracking::names; |
30 | |
31 | using dimension = struct { |
32 | dim_t size; |
33 | int idx; |
34 | }; |
35 | |
36 | using stride_t = struct { |
37 | dim_t stride; |
38 | dim_t size; |
39 | int idx; |
40 | }; |
41 | |
// Comparator for sorting dimensions by stride: smaller stride = inner dim,
// larger stride = outer dim. Dimensions of size 1 are considered outermost
// regardless of their strides and are sorted by index among themselves.
45 | bool stride_cmp(const stride_t &a, const stride_t &b) { |
46 | if (a.size == 1 && b.size == 1) { |
47 | return a.idx > b.idx; |
48 | } else if (a.size != 1 && b.size == 1) { |
49 | return true; |
50 | } else if (a.size == 1 && b.size != 1) { |
51 | return false; |
52 | } else { |
53 | return a.stride < b.stride; |
54 | } |
55 | } |
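
// Illustration (shapes are assumed, not taken from any particular workload):
// for a plain abcd tensor with dims {8, 16, 1, 32} the strides are
// {512, 32, 32, 1}, and sorting with stride_cmp orders the dims as
// d (stride 1), b (stride 32), a (stride 512), c (size 1, pushed outermost),
// i.e. index 0 of the sorted array is the innermost dimension.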
56 | |
// Returns the size and index of the dimension or block that is last, or at a
// given distance from the end. Blocks, if present, take precedence over
// dimensions. The order of dimensions is determined by sorting strides; the
// smallest stride is the last (innermost) dimension. Dimensions of size 1 are
// 'ignored' (treated as outermost regardless of their original position in
// the tensor tag). Because of that, it is possible to arrive at an illegal
// tensor tag where the last dim is the same as the first block:
// 8x8x1x1 aBcd8b becomes cdaB8b (notice the "B8b").
// In such cases this function combines the last dim with the first block
// (xB8b := xb).
66 | dimension get_Nth_last_dim_or_block( |
67 | const memory_desc_wrapper &md, int distance = 0) { |
68 | int nblks = md.blocking_desc().inner_nblks; |
69 | dimension ret; |
70 | int ndims = md.ndims(); |
71 | |
72 | std::vector<stride_t> strides(ndims); |
73 | for (int d = 0; d < ndims; ++d) { |
74 | strides[d].idx = d; |
75 | strides[d].stride = md.blocking_desc().strides[d]; |
76 | strides[d].size = md.padded_dims()[d]; |
77 | } |
78 | std::sort(strides.begin(), strides.end(), stride_cmp); |
79 | for (int i = 0; i < nblks; i++) { |
80 | stride_t blk; |
81 | blk.idx = md.blocking_desc().inner_idxs[i]; |
82 | blk.size = md.blocking_desc().inner_blks[i]; |
83 | if (i == 0 && blk.idx == strides[0].idx) { continue; } |
84 | strides.insert(strides.begin(), blk); |
85 | } |
86 | ret.idx = strides[distance].idx; |
87 | ret.size = strides[distance].size; |
88 | return ret; |
89 | } |
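
// Illustration (shapes are assumed): for a 32x32x4x4 tensor with tag aBcd16b,
// distance 0 returns the inner 16b block {size 16, idx 1} and distance 1
// returns dim d {size 4, idx 3}: blocks are treated as more "inner" than the
// plain dimensions they split off.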
90 | |
91 | int innermost_block(const blocking_desc_t &blk) { |
92 | int last = blk.inner_nblks - 1; |
93 | return blk.inner_blks[last]; |
94 | } |
95 | |
96 | bool is_alt_faster_than_ref(const memory_desc_wrapper &src_mdw, |
97 | const memory_desc_wrapper &dst_mdw, |
98 | const compute::device_info_t *dev_info) { |
99 | using namespace format_tag; |
100 | int ndims = src_mdw.ndims(); |
101 | int last = ndims - 1; |
102 | if (!src_mdw.matches_one_of_tag(abcd, abc, ab)) { return false; } |
    // on GPUs newer than gen9 the reference implementation is usually faster
104 | if (dev_info->gpu_arch() != compute::gpu_arch_t::gen9) { return false; } |
105 | // ensure reasonable work group size |
106 | if (src_mdw.dims()[last] < 8) { return false; } |
107 | // abcd->???b reorders are faster with reference implementation |
108 | if (ndims == 4 && dst_mdw.blocking_desc().strides[1] == 1) { return false; } |
109 | if (ndims == 3 && dst_mdw.blocking_desc().strides[0] == 1) { return false; } |
110 | return true; |
111 | } |
112 | |
113 | bool matches_one_NxN_layout(const memory_desc_wrapper &src, |
114 | const memory_desc_wrapper &dst, int n, int scale_mask) { |
115 | if (dst.ndims() < 2) { return false; } |
116 | if (!src.is_blocking_desc() || !dst.is_blocking_desc()) { return false; } |
117 | auto dst_last = get_Nth_last_dim_or_block(dst, 0); |
118 | auto src_last = get_Nth_last_dim_or_block(src, 0); |
119 | |
120 | if (dst_last.size % n != 0) { return false; } |
121 | if (src_last.size % n != 0) { return false; } |
122 | if (dst_last.idx == src_last.idx) { return false; } |
123 | // no padding allowed on dimensions that are last in src or last in dst |
124 | if (src.padded_dims()[src_last.idx] != src.dims()[src_last.idx]) { |
125 | return false; |
126 | } |
127 | if (src.padded_dims()[dst_last.idx] != src.dims()[dst_last.idx]) { |
128 | return false; |
129 | } |
130 | if (dst.padded_dims()[src_last.idx] != dst.dims()[src_last.idx]) { |
131 | return false; |
132 | } |
133 | if (dst.padded_dims()[dst_last.idx] != dst.dims()[dst_last.idx]) { |
134 | return false; |
135 | } |
136 | // no scale mask allowed on src's or dst's innermost dimension |
137 | if (scale_mask & (1 << src_last.idx)) { return false; } |
138 | if (scale_mask & (1 << dst_last.idx)) { return false; } |
139 | |
140 | return true; |
141 | } |
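
// Illustration (shapes are assumed): a 64x64 ab -> ba reorder matches n = 16;
// src's innermost dim is b, dst's innermost dim is a, both sizes are
// divisible by 16, and there is no padding or scale mask on either of them.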
142 | |
// For cases like ab -> Ba4b, BA4b8a8b2a, etc. Takes dst's last two dimensions
// and groups them together into packets of size 16 so that dst can be
// written to in bursts.
146 | bool fill_conf_xab_xba(const memory_desc_wrapper &src, |
147 | const memory_desc_wrapper &dst, int scale_mask, xb_to_xab_xba_t &cfg, |
148 | int &vect_dim, int &vect_size, dim_t *blocks) { |
149 | |
150 | vect_size = 16; |
151 | if (dst.ndims() < 2) { return false; } |
152 | dimension src_last = get_Nth_last_dim_or_block(src); |
153 | dimension dst_last = get_Nth_last_dim_or_block(dst); |
154 | dimension dst_next_last = get_Nth_last_dim_or_block(dst, 1); |
155 | |
156 | if (src_last.idx != dst_last.idx && src_last.idx != dst_next_last.idx) { |
157 | return false; |
158 | } |
159 | bool xb_to_xab = (src_last.idx == dst_last.idx); |
160 | |
161 | // no padding on dims that are src's or dst's innermost |
162 | if (src.padded_dims()[src_last.idx] != src.dims()[src_last.idx]) { |
163 | return false; |
164 | } |
165 | if (src.padded_dims()[dst_last.idx] != src.dims()[dst_last.idx]) { |
166 | return false; |
167 | } |
168 | if (dst.padded_dims()[src_last.idx] != dst.dims()[src_last.idx]) { |
169 | return false; |
170 | } |
171 | if (dst.padded_dims()[dst_last.idx] != dst.dims()[dst_last.idx]) { |
172 | return false; |
173 | } |
174 | if (src.offset0() != 0) { return false; } |
175 | if (dst.offset0() != 0) { return false; } |
176 | |
177 | // no scale mask allowed on src's or dst's innermost dimension |
178 | if (scale_mask & (1 << src_last.idx)) { return false; } |
179 | if (scale_mask & (1 << dst_last.idx)) { return false; } |
180 | |
181 | if (src_last.size < 16) { return false; } |
182 | if (src_last.size % 16 != 0) { return false; } |
183 | if (vect_size % dst_last.size != 0) { return false; } |
184 | if (dst_next_last.size % (vect_size / dst_last.size) != 0) { return false; } |
185 | |
186 | dimension src_burst; |
187 | dimension dst_burst[2]; |
188 | dimension dst_loop; |
189 | dimension src_loop; |
190 | |
191 | src_burst.idx = src_last.idx; |
192 | src_burst.size = vect_size; |
193 | dst_burst[0] = dst_last; |
194 | dst_burst[1].idx = dst_next_last.idx; |
195 | dst_burst[1].size = 16 / dst_last.size; |
196 | |
197 | int dst_src_idx = xb_to_xab ? 0 : 1; |
198 | dst_loop.size = src_burst.size / dst_burst[dst_src_idx].size; |
199 | dst_loop.idx = src_burst.idx; |
200 | |
201 | src_loop.size = dst_burst[1 - dst_src_idx].size; |
202 | src_loop.idx = dst_burst[1 - dst_src_idx].idx; |
203 | |
204 | vect_dim = src_last.idx; |
205 | |
206 | if ((dst_last.size * dst_next_last.size) % vect_size != 0) { return false; } |
207 | blocks[src_loop.idx] = src_loop.size; |
208 | cfg.blk_size = src_loop.size; |
209 | cfg.src_blk_dim = src_loop.idx; |
210 | cfg.src_blk_coeff = 1; |
211 | cfg.dst_blk_dim = dst_loop.idx; |
212 | |
213 | if (dst_loop.idx == dst_burst[0].idx) { |
214 | cfg.dst_blk_coeff = dst_burst[0].size; |
215 | } else if (dst_loop.idx == dst_burst[1].idx) { |
216 | cfg.dst_blk_coeff = dst_burst[1].size; |
217 | } else { |
218 | cfg.dst_blk_coeff = 1; |
219 | } |
220 | cfg.vd = xb_to_xab; |
221 | return true; |
222 | } |
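
// Illustration (shapes are assumed): for a 64x64 ab -> Ba4b reorder, src's
// innermost dim (b) matches dst's innermost block (4b), so xb_to_xab is true.
// A src burst reads 16 consecutive b values for one a; a dst burst writes 16
// consecutive elements that cover 4 values of b times 4 values of a, so the
// kernel loops over 4 values of a (blocks[0] = 4, cfg.blk_size = 4).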
223 | |
224 | bool fits_xab_xba(const memory_desc_wrapper &src, |
225 | const memory_desc_wrapper &dst, int scale_mask) { |
226 | xb_to_xab_xba_t cfg; |
227 | int vect_dim; |
228 | int vect_size; |
229 | dim_t blocks[6]; |
230 | |
231 | return fill_conf_xab_xba( |
232 | src, dst, scale_mask, cfg, vect_dim, vect_size, &blocks[0]); |
233 | } |
234 | |
235 | bool matches_ABxxxx8ayb_layout(const blocking_desc_t &blk, int ndims) { |
236 | if (ndims > 2) { return false; } |
237 | int last = blk.inner_nblks - 1; |
    // Don't allow this kernel when two adjacent blocks on dimension b create
    // a total block size smaller than 16 - in that situation the macros used
    // to calculate the dst address return wrong values.
241 | for (int d = last - 2; d >= 0; d--) { |
242 | if (blk.inner_idxs[d] == ndims - 1) { |
243 | int double_block = blk.inner_blks[last] * blk.inner_blks[d]; |
244 | if (double_block < 16) { |
245 | return false; |
246 | } else { |
247 | break; |
248 | } |
249 | } |
250 | } |
251 | return ((blk.inner_blks[last] == 4 || blk.inner_blks[last] == 2) |
252 | && blk.inner_idxs[last] == ndims - 1 |
253 | && blk.inner_blks[last - 1] == 8 |
254 | && blk.inner_idxs[last - 1] == ndims - 2); |
255 | } |
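
// Illustration (tags are assumed): AB8a4b and AB8a2b match, as does a
// multi-layer blocking such as AB4b8a4b (it still ends in 8a4b and the two
// b blocks multiply to 16); a blocking ending in 8a8b does not match.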
256 | |
257 | bool dim_is_div_by_16_or_less_than_16( |
258 | const memory_desc_wrapper &src, int dim_index) { |
259 | const auto &padded_dims = src.padded_dims(); |
260 | assert(dim_index < src.ndims()); |
261 | return (padded_dims[dim_index] % 16 == 0 || padded_dims[dim_index] < 16); |
262 | } |
263 | |
264 | bool is_broadcast_by_strides(const memory_desc_wrapper &mdw) { |
265 | if (mdw.is_blocking_desc()) { |
266 | for (int i = 0; i < mdw.ndims(); i++) { |
267 | if (mdw.blocking_desc().strides[i] == 0) { return true; } |
268 | } |
269 | } |
270 | return false; |
271 | } |
272 | |
273 | bool is_padded(const memory_desc_wrapper &mdw, int dim) { |
274 | return (mdw.dims()[dim] != mdw.padded_dims()[dim]); |
275 | } |
276 | |
277 | bool fits_3ch(const memory_desc_wrapper &src_mdw, |
278 | const memory_desc_wrapper &dst_mdw, int scale_mask) { |
    // TODO: make this more generic; for now it works only for the
    // dense->padded case
280 | if (src_mdw.ndims() < 2 || dst_mdw.ndims() < 2) { return false; } |
281 | |
282 | auto last_dim_src = get_Nth_last_dim_or_block(src_mdw); |
283 | auto nextlast_dim_src = get_Nth_last_dim_or_block(src_mdw, 1); |
284 | auto last_dim_dst = get_Nth_last_dim_or_block(dst_mdw); |
285 | auto nextlast_dim_dst = get_Nth_last_dim_or_block(dst_mdw, 1); |
286 | |
287 | bool same_nextlast_dims = (nextlast_dim_src.idx == nextlast_dim_dst.idx); |
288 | |
    // src's innermost dim is assumed to be contiguous; it will be read from
    // adjacent memory addresses
291 | auto src_innermost_stride |
292 | = src_mdw.blocking_desc().strides[last_dim_src.idx]; |
293 | if (last_dim_src.size != src_mdw.padded_dims()[last_dim_src.idx] |
294 | || (src_innermost_stride > 1 && src_mdw.is_plain())) { |
295 | return false; |
296 | } |
297 | if (last_dim_src.idx != last_dim_dst.idx) { return false; } |
298 | if (last_dim_src.size > last_dim_dst.size) { return false; } |
299 | if (last_dim_dst.size > 16 || last_dim_dst.size % 8 != 0) { return false; } |
300 | if (last_dim_src.idx == nextlast_dim_src.idx) { return false; } |
301 | if (last_dim_src.idx == nextlast_dim_dst.idx) { return false; } |
302 | if (nextlast_dim_src.size % 2 != 0) { return false; } |
303 | if (nextlast_dim_dst.size % 2 != 0) { return false; } |
304 | if (is_padded(src_mdw, last_dim_src.idx)) { return false; } |
    // If src's and dst's next-last dims are not the same, the CL code will use
    // nested loops. In the inner loop, the data offset is incremented and
    // cannot be converted back into tensor coordinates. This means there can't
    // be operations that depend on the next-last dim's coordinates in the
    // inner loop - so no padding and no scale quantization.
310 | if (!same_nextlast_dims) { |
311 | if (is_padded(src_mdw, nextlast_dim_src.idx)) { return false; } |
312 | if (is_padded(src_mdw, nextlast_dim_dst.idx)) { return false; } |
313 | if (is_padded(dst_mdw, nextlast_dim_src.idx)) { return false; } |
314 | if (is_padded(dst_mdw, nextlast_dim_dst.idx)) { return false; } |
315 | if (scale_mask & (1 << nextlast_dim_src.idx)) { return false; } |
316 | if (scale_mask & (1 << nextlast_dim_dst.idx)) { return false; } |
317 | } |
318 | if (scale_mask & (1 << last_dim_src.idx)) { return false; } |
319 | // no 2nd layer of block on innermost dim in dst |
320 | if (dst_mdw.padded_dims()[last_dim_src.idx] != last_dim_dst.size) { |
321 | return false; |
322 | } |
323 | return true; |
324 | } |
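
// Illustration (shapes are assumed): a 1x3x224x224 nhwc -> nChw16c reorder is
// the kind of case this check is meant to accept: src's channels are dense
// and innermost, dst pads them from 3 to a single 16c block, and W (the
// next-last dim on both sides) is even.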
325 | |
326 | reorder_kernel_t select_kernel(const reorder_conf_t &conf, |
327 | const memory_desc_wrapper &src_mdw, const memory_desc_wrapper &dst_mdw, |
328 | const compute::device_info_t *dev_info) { |
329 | using namespace format_tag; |
330 | |
331 | const auto &padded_dims = dst_mdw.padded_dims(); |
332 | |
333 | int last = conf.ndims - 1; |
334 | size_t last_dim = padded_dims[last]; |
335 | |
336 | const bool multi_scale_quant |
337 | = (conf.src_quant.with_scale() && conf.src_quant.num_scales() > 1) |
338 | || (conf.dst_quant.with_scale() && conf.dst_quant.num_scales() > 1); |
339 | const bool has_padding_or_multi_scale_quant |
340 | = conf.has_padding || multi_scale_quant; |
341 | |
342 | const bool type_s8_u8 = utils::one_of(src_mdw.data_type(), dnnl_s8, dnnl_u8) |
343 | || utils::one_of(dst_mdw.data_type(), dnnl_s8, dnnl_u8); |
344 | |
345 | const bool allow_unroll |
346 | = !conf.has_padding && !multi_scale_quant && !type_s8_u8; |
347 | |
348 | if (is_broadcast_by_strides(src_mdw) || is_broadcast_by_strides(dst_mdw)) { |
349 | return reorder_kernel_t::none; |
350 | } |
351 | |
352 | int mask = conf.src_quant.scale_mask() | conf.src_quant.zp_mask() |
353 | | conf.dst_quant.scale_mask() | conf.dst_quant.zp_mask(); |
354 | |
355 | if (matches_one_NxN_layout(src_mdw, dst_mdw, 16, mask)) { |
356 | // W/A for compiler bug: avoid using intel_sub_group_shuffle with |
357 | // SIMD16 on gen12lp |
358 | if (dev_info->gpu_arch() == compute::gpu_arch_t::xe_lp) { |
359 | return reorder_kernel_t::transpose8x8; |
360 | } |
361 | if (dev_info->gpu_arch() == compute::gpu_arch_t::gen9) { |
362 | return reorder_kernel_t::local16x16; |
363 | } |
364 | // W/A for assumed compiler bug: avoid using intel_sub_group_shuffle |
365 | // with SIMD16 on Gen11. |
366 | if (dev_info->gpu_arch() == compute::gpu_arch_t::gen11) { |
367 | return reorder_kernel_t::transpose8x8; |
368 | } |
369 | return reorder_kernel_t::transpose16x16; |
370 | } |
371 | if (matches_one_NxN_layout(src_mdw, dst_mdw, 8, mask)) { |
372 | if (dev_info->gpu_arch() == compute::gpu_arch_t::gen9) { |
373 | return reorder_kernel_t::local8x8; |
374 | } |
375 | return reorder_kernel_t::transpose8x8; |
376 | } |
377 | if (src_mdw.matches_one_of_tag(nhwc) && dst_mdw.matches_one_of_tag(nchw) |
378 | && padded_dims[last] % 16 == 0 |
379 | && dim_is_div_by_16_or_less_than_16(dst_mdw, 1)) { |
380 | return reorder_kernel_t::reorder_nchw; |
381 | } |
382 | if (src_mdw.matches_one_of_tag(nhwc) && dst_mdw.matches_one_of_tag(nchw) |
383 | && dim_is_div_by_16_or_less_than_16(dst_mdw, 1)) { |
384 | return reorder_kernel_t::unaligned_sizes; |
385 | } |
386 | |
387 | if (!has_padding_or_multi_scale_quant && (conf.nelems % 256 == 0) |
388 | && src_mdw.similar_to(dst_mdw, true, false, 0)) { |
389 | return reorder_kernel_t::dense_vector; |
390 | } |
391 | |
392 | if (fits_3ch(src_mdw, dst_mdw, mask)) { |
393 | return reorder_kernel_t::pad_innermost; |
394 | } |
395 | if (fits_xab_xba(src_mdw, dst_mdw, mask)) { |
396 | return reorder_kernel_t::xb_to_xab_xba; |
397 | } |
    // This kernel works on tensors that have a common innermost dim. It tries
    // to access memory in chunks large enough to utilize whole cache lines.
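    // Illustration (shapes are assumed): an aBcd8b -> aBcd16b reorder with
    // channels divisible by 16 lands here: both innermost blocks are on dim b,
    // the sizes 8 and 16 are multiples of 8, and 16 is a multiple of 2 * 8, so
    // vectorize_groups is chosen (given no padding, offsets or scale masks).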
400 | auto src_last = get_Nth_last_dim_or_block(src_mdw); |
401 | auto dst_last = get_Nth_last_dim_or_block(dst_mdw); |
402 | auto inner_dim = dst_last.idx; |
403 | if (src_last.idx == dst_last.idx |
404 | && (src_last.size <= 16 || dst_last.size <= 16) |
405 | && src_last.size % 8 == 0 && dst_last.size % 8 == 0 |
406 | && conf.ndims <= MAX_NDIMS |
407 | && (src_last.size % (2 * dst_last.size) == 0 |
408 | || dst_last.size % (2 * src_last.size) == 0) |
409 | && src_mdw.offset0() == 0 && dst_mdw.offset0() == 0 |
410 | && !(mask & (1 << inner_dim)) |
411 | && dst_mdw.dims()[inner_dim] == dst_mdw.padded_dims()[inner_dim] |
412 | && src_mdw.dims()[inner_dim] == src_mdw.padded_dims()[inner_dim]) { |
413 | return reorder_kernel_t::vectorize_groups; |
414 | } |
415 | |
416 | if (allow_unroll) { |
417 | if (src_mdw.matches_one_of_tag(ABc16a16b, ABc16b16a, ABcd16a16b, |
418 | ABcd16b16a, ABcde16a16b, ABcde16b16a, BAc16a16b, BAc16b16a, |
419 | BAcd16a16b, BAcd16b16a, BAcde16b16a) |
420 | || dst_mdw.matches_one_of_tag(ABc16a16b, ABc16b16a, ABcd16a16b, |
421 | ABcd16b16a, ABcde16a16b, ABcde16b16a, BAc16a16b, |
422 | BAc16b16a, BAcd16a16b, BAcd16b16a, BAcde16b16a)) { |
423 | return reorder_kernel_t::unroll_16a16b; |
424 | } |
425 | if (src_mdw.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b) |
426 | || dst_mdw.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b)) { |
427 | return reorder_kernel_t::unroll_16b; |
428 | } |
429 | if (src_mdw.matches_one_of_tag(aBCd16b16c, aBCd16c16b, aBCde16b16c, |
430 | aBCde16c16b, aBCdef16b16c, aBCdef16c16b, aCBd16b16c, |
431 | aCBd16c16b, aCBde16b16c, aCBde16c16b, aCBdef16c16b) |
432 | || dst_mdw.matches_one_of_tag(aBCd16b16c, aBCd16c16b, |
433 | aBCde16b16c, aBCde16c16b, aBCdef16b16c, aBCdef16c16b, |
434 | aCBd16b16c, aCBd16c16b, aCBde16b16c, aCBde16c16b, |
435 | aCBdef16c16b)) { |
436 | return reorder_kernel_t::unroll_16b16c; |
437 | } |
438 | } |
439 | |
440 | if (src_mdw.matches_one_of_tag(abdfce) && dst_mdw.matches_one_of_tag(abcdef) |
441 | && ((padded_dims[conf.ndims - 2] % 16) == 0) |
442 | && dim_is_div_by_16_or_less_than_16(dst_mdw, last)) { |
443 | return reorder_kernel_t::plain_xFxE_to_abcdef; |
444 | } |
445 | |
446 | if ((src_mdw.matches_one_of_tag(abcd, acdb)) |
447 | && dst_mdw.matches_one_of_tag( |
448 | ABcd4a2b, ABcd4a4b, ABcd8a2b, ABcd8a4b) |
449 | && src_mdw.is_dense() && dst_mdw.is_dense(true)) { |
450 | return reorder_kernel_t::plain_to_ABcd84a42b; |
451 | } |
452 | |
    // This kernel is used where the last dimension is not reordered.
    // It vectorizes across that dimension.
455 | if (!has_padding_or_multi_scale_quant && src_mdw.is_dense() |
456 | && dst_mdw.is_dense() && last_dim % 8 == 0 |
457 | && dst_mdw.md_->format_desc.blocking.strides[last] == 1 |
458 | && src_mdw.md_->format_desc.blocking.strides[last] == 1 |
459 | && conf.ndims <= MAX_NDIMS) { |
460 | return reorder_kernel_t::vectorize_last_dim; |
461 | } |
462 | |
    // This kernel supports 2D reorders into blocked formats that end in
    // 8a4b or 8a2b, with any number of block layers, but no padding.
465 | if (!has_padding_or_multi_scale_quant && src_mdw.matches_one_of_tag(ab) |
466 | && matches_ABxxxx8ayb_layout( |
467 | dst_mdw.md_->format_desc.blocking, conf.ndims) |
468 | && padded_dims[last] % 16 == 0) { |
469 | return reorder_kernel_t::plain_to_ABxx8ayb; |
470 | } |
471 | |
472 | if (conf.ndims >= 2 && conf.ndims <= 4 |
473 | && src_mdw.md_->format_desc.blocking.inner_nblks == 0 |
474 | && dst_mdw.md_->format_desc.blocking.inner_nblks == 0 |
475 | && src_mdw.offset0() == 0 && dst_mdw.offset0() == 0 |
476 | && is_alt_faster_than_ref(src_mdw, dst_mdw, dev_info) |
477 | && !has_padding_or_multi_scale_quant) { |
478 | return reorder_kernel_t::reorder_alt; |
479 | } |
480 | |
481 | return reorder_kernel_t::none; |
482 | } |
483 | |
484 | void custom_reorder_t::pd_t::alt_defines( |
485 | compute::kernel_ctx_t &kernel_ctx) const { |
486 | const memory_desc_wrapper src_mdw(src_md()); |
487 | const memory_desc_wrapper dst_mdw(dst_md()); |
488 | size_t ndims = src_mdw.ndims(); |
489 | size_t last = ndims - 1; |
490 | |
491 | auto sdim = src_mdw.dims(); |
492 | auto sstr = src_mdw.blocking_desc().strides; |
493 | auto dstr = dst_mdw.blocking_desc().strides; |
494 | kernel_ctx.define_int("ALT_OFFSETS" , 1); |
495 | // LIMIT_MAX_D0 is necessary to avoid buffer overwritte. |
496 | if (conf.dispatch.nd_range().global_range()[0] != (size_t)sdim[last]) { |
497 | kernel_ctx.define_int("LIMIT_MAX_D0" , sdim[last]); |
498 | } |
499 | kernel_ctx.define_int("S0" , sstr[last]); |
500 | kernel_ctx.define_int("S1" , sstr[last - 1]); |
501 | kernel_ctx.define_int("S2" , ndims > 2 ? sstr[last - 2] : 1); |
502 | kernel_ctx.define_int("SB" , ndims > 3 ? sstr[last - 3] : 1); |
503 | kernel_ctx.define_int("D0" , dstr[last]); |
504 | kernel_ctx.define_int("D1" , dstr[last - 1]); |
505 | kernel_ctx.define_int("D2" , ndims > 2 ? dstr[last - 2] : 1); |
506 | kernel_ctx.define_int("DB" , ndims > 3 ? dstr[last - 3] : 1); |
507 | kernel_ctx.define_int("BLK" , ndims > 3 ? sdim[last - 3] : 1); |
508 | } |
509 | |
510 | void custom_reorder_t::pd_t::alt_gen() { |
511 | const memory_desc_wrapper src_mdw(src_md()); |
512 | const memory_desc_wrapper dst_mdw(dst_md()); |
513 | auto sdim = src_mdw.dims(); |
514 | |
515 | size_t last = src_mdw.ndims() - 1; |
516 | size_t gws3 = src_mdw.ndims() > 2 ? sdim[last - 2] : 1; |
517 | size_t gws[3] = {(size_t)sdim[last], (size_t)sdim[last - 1], gws3}; |
518 | size_t work_group_size = 32; |
519 | if (sdim[last] <= 16) { work_group_size = 16; } |
520 | if (sdim[last] <= 8) { work_group_size = 8; } |
521 | const size_t lws[3] = {work_group_size, 1, 1}; |
    // Don't use nonuniform work groups; round up the number of work items if
    // needed.
523 | size_t mod = gws[0] % lws[0]; |
524 | if (mod != 0) { gws[0] += lws[0] - mod; } |
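    // e.g. (size is made up): sdim[last] = 50 gives lws[0] = 32 and gws[0]
    // rounded up to 64; alt_defines() then emits LIMIT_MAX_D0 = 50 so the
    // extra work items do not write out of bounds.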
525 | conf.dispatch.generate_override(gws, lws); |
526 | } |
527 | |
528 | status_t custom_reorder_t::pd_t::init_conf(engine_t *engine) { |
529 | using namespace format_tag; |
530 | |
531 | const memory_desc_wrapper src_mdw(src_md()); |
532 | const memory_desc_wrapper dst_mdw(dst_md()); |
533 | |
534 | conf.src_md_info = memory_desc_info_t::create(src_mdw); |
535 | conf.dst_md_info = memory_desc_info_t::create(dst_mdw); |
536 | |
537 | status_t status = status::success; |
538 | |
539 | const auto &padded_dims = dst_mdw.padded_dims(); |
540 | conf.src_quant = {attr(), src_mdw, DNNL_ARG_SRC}; |
541 | conf.dst_quant = {attr(), dst_mdw, DNNL_ARG_DST}; |
542 | conf.sum_quant = {attr()}; |
543 | conf.has_padding = !src_mdw.is_dense() || !dst_mdw.is_dense(); |
544 | conf.ndims = src_mdw.ndims(); |
545 | conf.nelems = utils::array_product(padded_dims, conf.ndims); |
546 | |
547 | conf.sub_group_size = 1; |
548 | |
549 | if (conf.nelems == 0) return status::success; |
550 | |
551 | int last = conf.ndims - 1; |
552 | size_t last_dim = padded_dims[last]; |
553 | |
554 | auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine); |
555 | |
556 | conf.implementation = select_kernel( |
557 | conf, src_mdw, dst_mdw, compute_engine->device_info()); |
558 | |
559 | dim_t blocks[MAX_NDIMS] = {1, 1, 1, 1, 1, 1}; |
560 | int vect_size = 1; |
561 | int vect_dim = 0; |
562 | |
563 | conf.dispatch = compute_engine->create_dispatch(dst_mdw.md_); |
564 | int temp_block = 1; |
565 | |
566 | const bool may_use_sg8 = compute_engine->mayiuse_sub_group(8); |
567 | int mask = conf.src_quant.scale_mask() | conf.src_quant.zp_mask() |
568 | | conf.dst_quant.scale_mask() | conf.dst_quant.zp_mask(); |
569 | |
570 | switch (conf.implementation) { |
571 | case none: return status_t::dnnl_unimplemented; |
572 | case reorder_alt: |
573 | // special handling with dispatcher override |
574 | conf.sub_group_size = 16; |
575 | break; |
576 | case dense_vector: |
577 | // see special handling below |
578 | conf.sub_group_size = 16; |
579 | break; |
580 | case unroll_16b: |
581 | conf.sub_group_size = 16; |
582 | vect_dim = 1; |
583 | vect_size = 16; |
584 | break; |
585 | case unroll_16b16c: |
586 | conf.sub_group_size = 16; |
587 | blocks[2] = 16; |
588 | vect_dim = 1; |
589 | vect_size = 16; |
590 | break; |
591 | case unroll_16a16b: |
592 | conf.sub_group_size = 16; |
593 | blocks[0] = 16; |
594 | vect_dim = 1; |
595 | vect_size = 16; |
596 | break; |
597 | case plain_to_ABcd84a42b: { |
598 | auto &blk = dst_mdw.blocking_desc(); |
599 | int inner_block = blk.inner_blks[blk.inner_nblks - 1]; |
600 | int outer_block = blk.inner_blks[blk.inner_nblks - 2]; |
601 | conf.sub_group_size = inner_block * outer_block; |
602 | blocks[0] = outer_block; |
603 | blocks[1] = inner_block; |
604 | vect_dim = 3; |
605 | vect_size = conf.sub_group_size; |
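                // e.g. (tag is assumed): for dst ABcd8a4b this yields
                // inner_block = 4, outer_block = 8, a sub-group size of 32,
                // blocks[0] = 8 and blocks[1] = 4.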
606 | } break; |
607 | case xb_to_xab_xba: |
608 | fill_conf_xab_xba(src_mdw, dst_mdw, mask, conf.aux_data.ab, |
609 | vect_dim, vect_size, &blocks[0]); |
610 | conf.sub_group_size = vect_size; |
611 | break; |
612 | case vectorize_last_dim: |
613 | vect_dim = last; |
614 | vect_size = (last_dim % 16 == 0) ? 16 : 8; |
615 | if (!may_use_sg8 && vect_size == 8) { |
616 | return status_t::dnnl_unimplemented; |
617 | } |
618 | for (int dim = last - 1; |
619 | dim >= 0 && dim < MAX_NDIMS && temp_block == 1; dim--) { |
620 | if (padded_dims[dim] % 4 == 0) { temp_block = 4; } |
621 | if (padded_dims[dim] % 8 == 0) { temp_block = 8; } |
622 | if (padded_dims[dim] % 16 == 0) { temp_block = 16; } |
623 | blocks[dim] = temp_block; |
624 | } |
625 | break; |
626 | case pad_innermost: { |
627 | auto last_dim_src = get_Nth_last_dim_or_block(src_mdw); |
628 | auto nextlast_dim_src = get_Nth_last_dim_or_block(src_mdw, 1); |
629 | auto last_dim_dst = get_Nth_last_dim_or_block(dst_mdw); |
630 | auto nextlast_dim_dst = get_Nth_last_dim_or_block(dst_mdw, 1); |
631 | |
632 | int min_common_size |
633 | = std::min(last_dim_src.size, last_dim_dst.size); |
634 | int max_common_size |
635 | = std::max(last_dim_src.size, last_dim_dst.size); |
636 | conf.sub_group_size = max_common_size; |
637 | if (!may_use_sg8 && conf.sub_group_size == 8) { |
638 | return status_t::dnnl_unimplemented; |
639 | } |
640 | |
            // Group size bigger than 4 would need too much private memory;
            // group size 1 would give worse perf than the reference kernel.
643 | int max_group_size = 4; |
644 | while (nextlast_dim_src.size % max_group_size != 0 |
645 | || nextlast_dim_dst.size % max_group_size != 0) { |
646 | max_group_size--; |
647 | } |
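            // e.g. (sizes are made up): with next-last sizes of 6 (src) and
            // 4 (dst), the loop above settles on max_group_size = 2, the
            // largest value that divides both.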
648 | |
649 | conf.aux_data.vg.vector_dim = last_dim_src.idx; |
650 | conf.aux_data.vg.src_loop_dim = nextlast_dim_dst.idx; |
651 | conf.aux_data.vg.dst_loop_dim = nextlast_dim_src.idx; |
652 | conf.aux_data.vg.innermost_size = min_common_size; |
653 | |
654 | blocks[conf.aux_data.vg.src_loop_dim] = max_group_size; |
655 | blocks[conf.aux_data.vg.dst_loop_dim] = max_group_size; |
656 | // if src loop and dst loop dims are the same, CL code would iterate |
657 | // over the same dimension in inner and outer loop, effectively doing |
658 | // redundant operations; set inner loop counter to 1 in that case. |
659 | conf.aux_data.vg.group_size |
660 | = (nextlast_dim_dst.idx == nextlast_dim_src.idx) |
661 | ? 1 |
662 | : max_group_size; |
663 | |
664 | vect_dim = conf.aux_data.vg.vector_dim; |
665 | vect_size = conf.sub_group_size; |
666 | } break; |
667 | |
668 | case vectorize_groups: { |
669 | auto last_dim_src = get_Nth_last_dim_or_block(src_mdw); |
670 | auto nextlast_dim_src = get_Nth_last_dim_or_block(src_mdw, 1); |
671 | auto last_dim_dst = get_Nth_last_dim_or_block(dst_mdw); |
672 | auto nextlast_dim_dst = get_Nth_last_dim_or_block(dst_mdw, 1); |
673 | int min_common_size |
674 | = std::min(last_dim_src.size, last_dim_dst.size); |
675 | vect_size = (min_common_size % 16 == 0) ? 16 : 8; |
676 | vect_dim = last_dim_src.idx; |
677 | if (!may_use_sg8 && vect_size == 8) { |
678 | return status_t::dnnl_unimplemented; |
679 | } |
680 | |
681 | assert(last_dim_src.size % vect_size == 0 |
682 | && last_dim_dst.size % vect_size == 0); |
683 | assert(last_dim_src.idx == last_dim_dst.idx); |
684 | int src_chunks; |
685 | if (last_dim_src.size / vect_size > 1) { |
686 | src_chunks = last_dim_src.size / vect_size; |
687 | conf.aux_data.vg.dst_loop_dim = last_dim_src.idx; |
688 | } else { |
689 | src_chunks = nextlast_dim_src.size; |
690 | conf.aux_data.vg.dst_loop_dim = nextlast_dim_src.idx; |
691 | } |
692 | int dst_chunks; |
693 | if (last_dim_dst.size / vect_size > 1) { |
694 | dst_chunks = last_dim_dst.size / vect_size; |
695 | conf.aux_data.vg.src_loop_dim = last_dim_dst.idx; |
696 | } else { |
697 | dst_chunks = nextlast_dim_dst.size; |
698 | conf.aux_data.vg.src_loop_dim = nextlast_dim_dst.idx; |
699 | } |
700 | // TODO: |
701 | // Final algorithm for selecting group size should consider: |
702 | // 1. Group size must be small enough to guarantee no spill. |
703 | // 2. Group size should be large enough to fill whole cache lines |
704 | // on both reads and writes, with line size determined by HW |
705 | // 3. If there's not enough data to feed all EUs, ignore (2) and |
706 | // decrease group size. |
707 | int max_data_size = (int)std::max( |
708 | src_mdw.data_type_size(), dst_mdw.data_type_size()); |
709 | |
710 | int group = 16 / max_data_size; |
711 | while (group > 1) { |
712 | if (src_chunks % group == 0 && dst_chunks % group == 0) { |
713 | break; |
714 | } |
715 | group--; |
716 | } |
717 | assert(group >= 1); |
718 | |
719 | conf.aux_data.vg.vector_dim = last_dim_src.idx; |
720 | conf.aux_data.vg.group_size = group; |
721 | conf.sub_group_size = vect_size; |
722 | |
723 | blocks[conf.aux_data.vg.src_loop_dim] = group; |
724 | blocks[conf.aux_data.vg.dst_loop_dim] = group; |
725 | } break; |
726 | case plain_to_ABxx8ayb: |
727 | conf.sub_group_size = 16; |
728 | blocks[0] = 8; |
729 | vect_dim = last; |
730 | vect_size = 16; |
731 | break; |
732 | case plain_xFxE_to_abcdef: |
733 | conf.sub_group_size = 16; |
734 | blocks[5] = nstl::min(padded_dims[conf.ndims - 1], dnnl_dim_t(16)); |
735 | vect_dim = 4; |
736 | vect_size = 16; |
737 | break; |
738 | case transpose8x8: |
739 | case local8x8: |
740 | if (!may_use_sg8) { return status_t::dnnl_unimplemented; } |
741 | conf.sub_group_size = 8; |
742 | blocks[get_Nth_last_dim_or_block(dst_mdw).idx] = 8; |
743 | vect_dim = get_Nth_last_dim_or_block(src_mdw).idx; |
744 | vect_size = 8; |
745 | break; |
746 | case transpose16x16: |
747 | case local16x16: |
748 | conf.sub_group_size = 16; |
749 | blocks[get_Nth_last_dim_or_block(dst_mdw).idx] = 16; |
750 | vect_dim = get_Nth_last_dim_or_block(src_mdw).idx; |
751 | vect_size = 16; |
752 | break; |
753 | case reorder_nchw: |
754 | conf.sub_group_size = 16; |
755 | blocks[1] = nstl::min(padded_dims[1], dnnl_dim_t(16)); |
756 | vect_dim = 3; |
757 | vect_size = 16; |
758 | break; |
759 | case unaligned_sizes: blocks[1] = padded_dims[1]; break; |
760 | } |
761 | |
    // special case for dense_vector kernel - treat tensors as flat 1D vectors
    if (conf.implementation == dense_vector) {
        conf.dispatch.define_dim("D0", 0, conf.nelems, 16);
        CHECK(conf.dispatch.vectorize_dim("D0", 16));
    } else {
        for (int i = 0; i < MAX_NDIMS; ++i) {
            auto dim_str = utils::format("D%d", i);
            if (i < dst_mdw.ndims()) {
                int dim = padded_dims[i];
                // pad the vectorized dim again if needed to align it with the
                // vector size
                if (i == vect_dim) { dim = utils::rnd_up(dim, vect_size); }
                conf.dispatch.define_dim(dim_str, i, dim, blocks[i]);
            } else {
                conf.dispatch.define_dim(dim_str, 1);
            }
        }
        if (vect_size != 1) {
            auto dim_str = utils::format("D%d", vect_dim);
            CHECK(conf.dispatch.vectorize_dim(dim_str, vect_size));
        }
    }
783 | |
784 | if (conf.implementation == reorder_alt) { |
785 | alt_gen(); |
786 | } else { |
787 | conf.dispatch.generate(); |
788 | } |
789 | return status; |
790 | } |
791 | |
792 | status_t custom_reorder_t::pd_t::init_kernel_ctx( |
793 | compute::kernel_ctx_t &kernel_ctx) const { |
794 | using namespace format_tag; |
795 | |
796 | const memory_desc_wrapper src_mdw(src_md()); |
797 | const memory_desc_wrapper dst_mdw(dst_md()); |
798 | |
799 | if (conf.nelems == 0) return status::success; |
800 | |
801 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
802 | kernel_ctx.add_option("-cl-std=CL2.0" ); |
803 | |
804 | conf.src_quant.define_macros(kernel_ctx, "SRC" ); |
805 | conf.dst_quant.define_macros(kernel_ctx, "DST" ); |
806 | conf.sum_quant.define_macros(kernel_ctx, "SUM" ); |
807 | |
808 | def_dispatch(kernel_ctx, conf.dispatch); |
809 | |
    // the 'unaligned_sizes' kernel uses the same implementation in .cl;
    // the difference is in the sizes of blocks[]
    if (conf.implementation == unaligned_sizes) {
        kernel_ctx.define_int("UNALIGNED", 1);
    }
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);

    kernel_ctx.define_int("PAD_FILL_ZERO", conf.has_padding);
    if (conf.implementation == dense_vector) {
        kernel_ctx.add_option("-Dcl_intel_subgroups_char");
        kernel_ctx.define_int("USE_DENSE_VECT", 1);
    }

    def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");
825 | |
    // distinguish between various flavors of unroll kernel
    if (src_mdw.matches_one_of_tag(
                ABc16a16b, ABcd16a16b, ABcde16a16b, BAc16a16b, BAcd16a16b)) {
        kernel_ctx.define_int("SRC_16A16B", 1);
    } else if (src_mdw.matches_one_of_tag(ABc16b16a, ABcd16b16a, ABcde16b16a,
                       BAc16b16a, BAcd16b16a, BAcde16b16a)) {
        kernel_ctx.define_int("SRC_16B16A", 1);
    } else if (src_mdw.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b)) {
        kernel_ctx.define_int("SRC_16B", 1);
    } else if (src_mdw.matches_one_of_tag(aBCd16b16c, aBCde16b16c, aBCdef16b16c,
                       aCBd16b16c, aCBde16b16c)) {
        kernel_ctx.define_int("SRC_16B16C", 1);
    } else if (src_mdw.matches_one_of_tag(aBCd16c16b, aBCde16c16b, aBCdef16c16b,
                       aCBd16c16b, aCBde16c16b, aCBdef16c16b)) {
        kernel_ctx.define_int("SRC_16C16B", 1);
    }
    if (dst_mdw.matches_one_of_tag(
                ABc16a16b, ABcd16a16b, ABcde16a16b, BAc16a16b, BAcd16a16b)) {
        kernel_ctx.define_int("DST_16A16B", 1);
    } else if (dst_mdw.matches_one_of_tag(ABc16b16a, ABcd16b16a, ABcde16b16a,
                       BAc16b16a, BAcd16b16a, BAcde16b16a)) {
        kernel_ctx.define_int("DST_16B16A", 1);
    } else if (dst_mdw.matches_one_of_tag(aBc16b, aBcd16b, aBcde16b)) {
        kernel_ctx.define_int("DST_16B", 1);
    } else if (dst_mdw.matches_one_of_tag(aBCd16b16c, aBCde16b16c, aBCdef16b16c,
                       aCBd16b16c, aCBde16b16c)) {
        kernel_ctx.define_int("DST_16B16C", 1);
    } else if (dst_mdw.matches_one_of_tag(aBCd16c16b, aBCde16c16b, aBCdef16c16b,
                       aCBd16c16b, aCBde16c16b, aCBdef16c16b)) {
        kernel_ctx.define_int("DST_16C16B", 1);
    }
857 | |
    if (conf.implementation == reorder_alt) { alt_defines(kernel_ctx); }
    if (conf.implementation == plain_xFxE_to_abcdef)
        kernel_ctx.define_int("PLAIN_xFxE_TO_ABCDEF", 1);

    if (conf.implementation == plain_to_ABcd84a42b) {
        kernel_ctx.define_int("PLAIN_TO_ABCD84A42B", 1);
        auto r = conf.dispatch.nd_range();
        auto *lr = r.local_range();
        kernel_ctx.define_int(
                "SG_PER_WG", (lr[0] * lr[1] * lr[2]) / conf.sub_group_size);
    }
    if (conf.implementation == xb_to_xab_xba) {
        kernel_ctx.define_int("XAB_XBA", 1);
        auto r = conf.dispatch.nd_range();
        auto *lr = r.local_range();
        kernel_ctx.define_int(
                "SG_PER_WG", (lr[0] * lr[1] * lr[2]) / conf.sub_group_size);
        kernel_ctx.define_int("BLOCK_SIZE", conf.aux_data.ab.blk_size);
        kernel_ctx.define_int("SRC_BLK_DIM", conf.aux_data.ab.src_blk_dim);
        kernel_ctx.define_int("SRC_OFF_COEFF", conf.aux_data.ab.src_blk_coeff);
        kernel_ctx.define_int("DST_BLK_DIM", conf.aux_data.ab.dst_blk_dim);
        kernel_ctx.define_int("DST_OFF_COEFF", conf.aux_data.ab.dst_blk_coeff);
        kernel_ctx.define_int("XB_TO_XAB", conf.aux_data.ab.vd);
    }
882 | |
    if (conf.implementation == vectorize_last_dim) {
        kernel_ctx.define_int("VECTORIZE_LAST_DIM", 1);
    }
886 | |
    if (conf.implementation == pad_innermost) {
        kernel_ctx.define_int("PAD_INNERMOST", 1);
        kernel_ctx.define_int(
                "VECT_DIM", conf.aux_data.vg.vector_dim); //useless
        kernel_ctx.define_int("SRC_LOOP_DIM", conf.aux_data.vg.src_loop_dim);
        kernel_ctx.define_int("DST_LOOP_DIM", conf.aux_data.vg.dst_loop_dim);
        kernel_ctx.define_int("GROUP", conf.aux_data.vg.group_size);
        auto r = conf.dispatch.nd_range();
        auto *lr = r.local_range();
        if (!lr) return status::runtime_error;
        kernel_ctx.define_int(
                "SG_PER_WG", (lr[0] * lr[1] * lr[2]) / conf.sub_group_size);
        kernel_ctx.define_int(
                "INNERMOST_SIZE", conf.aux_data.vg.innermost_size);
        kernel_ctx.define_int("VECT_SIZE", conf.sub_group_size);
        bool has_non_innermost_padding = false;
        for (int i = 0; i < MAX_NDIMS; i++) {
            if (i == conf.aux_data.vg.vector_dim) { continue; }
            has_non_innermost_padding
                    |= (dst_mdw.dims()[i] != dst_mdw.padded_dims()[i]);
        }
        kernel_ctx.define_int(
                "NON_INNERMOST_PADDING", has_non_innermost_padding);
        auto last_dim_dst = get_Nth_last_dim_or_block(dst_mdw);
        kernel_ctx.define_int("DST_INNERMOST_STRIDE",
                dst_mdw.is_plain()
                        ? dst_mdw.blocking_desc().strides[last_dim_dst.idx]
                        : 1);
    }
    if (conf.implementation == vectorize_groups) {
        kernel_ctx.define_int("VECTORIZE_GROUPS", 1);
        kernel_ctx.define_int("VECT_DIM", conf.aux_data.vg.vector_dim);
        kernel_ctx.define_int("SRC_LOOP_DIM", conf.aux_data.vg.src_loop_dim);
        kernel_ctx.define_int("DST_LOOP_DIM", conf.aux_data.vg.dst_loop_dim);
        kernel_ctx.define_int("GROUP", conf.aux_data.vg.group_size);
    }
    if (conf.implementation == plain_to_ABxx8ayb) {
        kernel_ctx.define_int("PLAIN_TO_AB_XX_8AYB", 1);
        kernel_ctx.define_int(
                "BLK_L", innermost_block(dst_mdw.md_->format_desc.blocking));
    }
928 | |
    if (conf.implementation == transpose8x8
            || conf.implementation == transpose16x16) {
        kernel_ctx.define_int("TRANSPOSE_NXN", 1);
        kernel_ctx.define_int(
                "DST_BLOCK_DIM", get_Nth_last_dim_or_block(src_mdw).idx);
    }
935 | |
    if (conf.implementation == local8x8 || conf.implementation == local16x16) {
        kernel_ctx.define_int("LOCAL_NXN", 1);
        auto r = conf.dispatch.nd_range();
        auto *lr = r.local_range();
        if (!lr) return status::runtime_error;
        kernel_ctx.define_int("SG_PER_WG", lr[0] * lr[1] * lr[2]);
        kernel_ctx.define_int(
                "DST_BLOCK_DIM", get_Nth_last_dim_or_block(src_mdw).idx);
    }
945 | |
    if (conf.implementation == reorder_nchw) {
        kernel_ctx.define_int("REORDER_NCHW", 1);
    }
949 | |
950 | kernel_ctx.print_options(); |
951 | return status::success; |
952 | } |
953 | |
954 | void custom_reorder_t::pd_t::init_scratchpad() { |
955 | if (conf.src_quant.with_scale()) { |
956 | auto scratchpad = scratchpad_registry().registrar(); |
957 | scratchpad.book(memory_tracking::names::key_reorder_src_scales, |
958 | conf.src_quant.num_scales(), sizeof(float), |
959 | OCL_BUFFER_ALIGNMENT); |
960 | } |
961 | if (conf.dst_quant.with_scale()) { |
962 | auto scratchpad = scratchpad_registry().registrar(); |
963 | scratchpad.book(memory_tracking::names::key_reorder_dst_scales, |
964 | conf.dst_quant.num_scales(), sizeof(float), |
965 | OCL_BUFFER_ALIGNMENT); |
966 | } |
967 | } |
968 | |
969 | status_t custom_reorder_t::execute(const exec_ctx_t &ctx) const { |
970 | |
971 | status_t status = status::success; |
972 | |
973 | auto &src = CTX_IN_STORAGE(DNNL_ARG_FROM); |
974 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_TO); |
975 | CHECK(status); |
976 | |
977 | const auto &conf = pd()->conf; |
978 | if (conf.nelems == 0) return status::success; |
979 | |
980 | compute::kernel_arg_list_t arg_list; |
981 | arg_list.set(0, src); |
982 | arg_list.set(1, dst); |
983 | |
984 | arg_list.set(2, conf.src_quant.scales(ctx)); |
985 | arg_list.set(3, conf.src_quant.zero_points(ctx)); |
986 | arg_list.set(4, conf.dst_quant.scales(ctx)); |
987 | arg_list.set(5, conf.dst_quant.zero_points(ctx)); |
988 | |
989 | arg_list.set(6, conf.sum_quant.scales()); |
990 | arg_list.set(7, conf.sum_quant.zero_points()); |
991 | |
992 | auto nd_range = conf.dispatch.nd_range(); |
993 | |
994 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
995 | |
996 | return status; |
997 | } |
998 | |
999 | } // namespace ocl |
1000 | } // namespace gpu |
1001 | } // namespace impl |
1002 | } // namespace dnnl |
1003 | |