1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
5const char *generic_reorder_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2021-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/reorder_common.h" )==""\n"
21R"==(KERNEL_ATTR )==""\n"
22R"==(__kernel void generic_reorder(__global SRC_DATA_T *restrict src, )==""\n"
23R"==(__global DST_DATA_T *restrict dst, __global float *restrict src_scales, )==""\n"
24R"==(__global int *restrict src_zps, __global float *restrict dst_scales, )==""\n"
25R"==(__global int *restrict dst_zps, float sum_scale, int sum_zp) { )==""\n"
26R"==(const int src_zp = GET_SRC_ZP(src_zps); )==""\n"
27R"==(const int dst_zp = GET_DST_ZP(dst_zps); )==""\n"
28R"==(float src_scale = 1.0f; )==""\n"
29R"==(float dst_scale = 1.0f; )==""\n"
30R"==(src += SRC_OFFSET0; )==""\n"
31R"==(dst += DST_OFFSET0; )==""\n"
32R"==(#define LOOP_NEST_LEVEL 4 )==""\n"
33R"==(const uint sgId = get_sub_group_local_id(); )==""\n"
34R"==(uint d[6]; )==""\n"
35R"==(uint b[6] = {0, 0, 0, 0, 0, 0}; )==""\n"
36R"==(d[0] = GWS_GET_D0(); )==""\n"
37R"==(d[1] = GWS_GET_D1(); )==""\n"
38R"==(d[2] = GWS_GET_D2(); )==""\n"
39R"==(d[3] = GWS_GET_D3(); )==""\n"
40R"==(d[4] = GWS_GET_D4(); )==""\n"
41R"==(d[5] = GWS_GET_D5(); )==""\n"
42R"==(d[VECT_DIM] /= RESCALE_COEFF; )==""\n"
43R"==(const uint cache_size_per_sg = D_BLK_SIZE_0 * D_BLK_SIZE_1 * D_BLK_SIZE_2 )==""\n"
44R"==(* D_BLK_SIZE_3 * VECT_SIZE; )==""\n"
45R"==(const uint sg_off = get_sub_group_id() * cache_size_per_sg; )==""\n"
46R"==(__local SRC_DATA_T cache[SG_PER_WG * cache_size_per_sg]; )==""\n"
47R"==(uint iter[LOOP_NEST_LEVEL] = {0, 0, 0, 0}; )==""\n"
48R"==(#if S_BLK_SIZE_3 > 1 )==""\n"
49R"==(for_(iter[3] = 0; iter[3] < S_BLK_SIZE_3; iter[3]++) )==""\n"
50R"==(#endif )==""\n"
51R"==(#if S_BLK_SIZE_2 > 1 )==""\n"
52R"==(for_(iter[2] = 0; iter[2] < S_BLK_SIZE_2; iter[2]++) )==""\n"
53R"==(#endif )==""\n"
54R"==(#if S_BLK_SIZE_1 > 1 )==""\n"
55R"==(for_(iter[1] = 0; iter[1] < S_BLK_SIZE_1; iter[1]++) )==""\n"
56R"==(#endif )==""\n"
57R"==(#if S_BLK_SIZE_0 > 1 )==""\n"
58R"==(for_(iter[0] = 0; iter[0] < S_BLK_SIZE_0; iter[0]++) )==""\n"
59R"==(#endif )==""\n"
60R"==({ )==""\n"
61R"==(b[0] = 0; )==""\n"
62R"==(b[1] = 0; )==""\n"
63R"==(b[2] = 0; )==""\n"
64R"==(b[3] = 0; )==""\n"
65R"==(b[4] = 0; )==""\n"
66R"==(b[5] = 0; )==""\n"
67R"==(b[S_BLK_IDX_0] += iter[0] * S_BLK_STEP_0; )==""\n"
68R"==(b[S_BLK_IDX_1] += iter[1] * S_BLK_STEP_1; )==""\n"
69R"==(b[S_BLK_IDX_2] += iter[2] * S_BLK_STEP_2; )==""\n"
70R"==(b[S_BLK_IDX_3] += iter[3] * S_BLK_STEP_3; )==""\n"
71R"==(#if S_MOD_3 > 1 )==""\n"
72R"==(b[S_IDX_3] += S_MUL_3 * ((sgId / S_DIV_3) % S_MOD_3); )==""\n"
73R"==(#endif )==""\n"
74R"==(#if S_MOD_2 > 1 )==""\n"
75R"==(b[S_IDX_2] += S_MUL_2 * ((sgId / S_DIV_2) % S_MOD_2); )==""\n"
76R"==(#endif )==""\n"
77R"==(#if S_MOD_1 > 1 )==""\n"
78R"==(b[S_IDX_1] += S_MUL_1 * ((sgId / S_DIV_1) % S_MOD_1); )==""\n"
79R"==(#endif )==""\n"
80R"==(#if S_MOD_0 > 1 )==""\n"
81R"==(b[S_IDX_0] += S_MUL_0 * ((sgId / S_DIV_0) % S_MOD_0); )==""\n"
82R"==(#endif )==""\n"
83R"==(const uint src_off = SRC_OFF(d[0] + b[0], d[1] + b[1], d[2] + b[2], )==""\n"
84R"==(d[3] + b[3], d[4] + b[4], d[5] + b[5]); )==""\n"
85R"==(uint cache_idx = sg_off + b[5] * CACHE_STRIDE_5 + b[4] * CACHE_STRIDE_4 )==""\n"
86R"==(+ b[3] * CACHE_STRIDE_3 + b[2] * CACHE_STRIDE_2 )==""\n"
87R"==(+ b[1] * CACHE_STRIDE_1 + b[0] * CACHE_STRIDE_0; )==""\n"
88R"==(const int pad_d0 = d[0] + b[0] >= SRC_D0; )==""\n"
89R"==(const int pad_d1 = NDIMS > 1 && d[1] + b[1] >= SRC_D1; )==""\n"
90R"==(const int pad_d2 = NDIMS > 2 && d[2] + b[2] >= SRC_D2; )==""\n"
91R"==(const int pad_d3 = NDIMS > 3 && d[3] + b[3] >= SRC_D3; )==""\n"
92R"==(const int pad_d4 = NDIMS > 4 && d[4] + b[4] >= SRC_D4; )==""\n"
93R"==(const int pad_d5 = NDIMS > 5 && d[5] + b[5] >= SRC_D5; )==""\n"
94R"==(const bool pad_sgid = sgId >= LIMIT_SSGID; )==""\n"
95R"==(const int pad )==""\n"
96R"==(= pad_d0 || pad_d1 || pad_d2 || pad_d3 || pad_d4 || pad_d5; )==""\n"
97R"==(if (!pad_sgid) { )==""\n"
98R"==(SRC_DATA_T src_tmp = pad ? 0 : src[src_off]; )==""\n"
99R"==(cache[cache_idx] = src_tmp; )==""\n"
100R"==(} )==""\n"
101R"==(} )==""\n"
102R"==(for (uint i = 0; i < LOOP_NEST_LEVEL; i++) { )==""\n"
103R"==(iter[i] = 0; )==""\n"
104R"==(} )==""\n"
105R"==(#if D_BLK_SIZE_3 > 1 )==""\n"
106R"==(for_(iter[3] = 0; iter[3] < D_BLK_SIZE_3; iter[3]++) )==""\n"
107R"==(#endif )==""\n"
108R"==(#if D_BLK_SIZE_2 > 1 )==""\n"
109R"==(for_(iter[2] = 0; iter[2] < D_BLK_SIZE_2; iter[2]++) )==""\n"
110R"==(#endif )==""\n"
111R"==(#if D_BLK_SIZE_1 > 1 )==""\n"
112R"==(for_(iter[1] = 0; iter[1] < D_BLK_SIZE_1; iter[1]++) )==""\n"
113R"==(#endif )==""\n"
114R"==(#if D_BLK_SIZE_0 > 1 )==""\n"
115R"==(for_(iter[0] = 0; iter[0] < D_BLK_SIZE_0; iter[0]++) )==""\n"
116R"==(#endif )==""\n"
117R"==({ )==""\n"
118R"==(b[0] = 0; )==""\n"
119R"==(b[1] = 0; )==""\n"
120R"==(b[2] = 0; )==""\n"
121R"==(b[3] = 0; )==""\n"
122R"==(b[4] = 0; )==""\n"
123R"==(b[5] = 0; )==""\n"
124R"==(b[D_BLK_IDX_0] += iter[0] * D_BLK_STEP_0; )==""\n"
125R"==(b[D_BLK_IDX_1] += iter[1] * D_BLK_STEP_1; )==""\n"
126R"==(b[D_BLK_IDX_2] += iter[2] * D_BLK_STEP_2; )==""\n"
127R"==(b[D_BLK_IDX_3] += iter[3] * D_BLK_STEP_3; )==""\n"
128R"==(#if D_MOD_3 > 1 )==""\n"
129R"==(b[D_IDX_3] += D_MUL_3 * ((sgId / D_DIV_3) % D_MOD_3); )==""\n"
130R"==(#endif )==""\n"
131R"==(#if D_MOD_2 > 1 )==""\n"
132R"==(b[D_IDX_2] += D_MUL_2 * ((sgId / D_DIV_2) % D_MOD_2); )==""\n"
133R"==(#endif )==""\n"
134R"==(#if D_MOD_1 > 1 )==""\n"
135R"==(b[D_IDX_1] += D_MUL_1 * ((sgId / D_DIV_1) % D_MOD_1); )==""\n"
136R"==(#endif )==""\n"
137R"==(#if D_MOD_0 > 1 )==""\n"
138R"==(b[D_IDX_0] += D_MUL_0 * ((sgId / D_DIV_0) % D_MOD_0); )==""\n"
139R"==(#endif )==""\n"
140R"==(const uint dst_off = DST_OFF(d[0] + b[0], d[1] + b[1], d[2] + b[2], )==""\n"
141R"==(d[3] + b[3], d[4] + b[4], d[5] + b[5]); )==""\n"
142R"==(DST_DATA_T dst_tmp; )==""\n"
143R"==(uint cache_idx = sg_off + b[5] * CACHE_STRIDE_5 + b[4] * CACHE_STRIDE_4 )==""\n"
144R"==(+ b[3] * CACHE_STRIDE_3 + b[2] * CACHE_STRIDE_2 )==""\n"
145R"==(+ b[1] * CACHE_STRIDE_1 + b[0] * CACHE_STRIDE_0; )==""\n"
146R"==(const int pad_d0 = d[0] + b[0] >= DST_PD0; )==""\n"
147R"==(const int pad_d1 = NDIMS > 1 && d[1] + b[1] >= DST_PD1; )==""\n"
148R"==(const int pad_d2 = NDIMS > 2 && d[2] + b[2] >= DST_PD2; )==""\n"
149R"==(const int pad_d3 = NDIMS > 3 && d[3] + b[3] >= DST_PD3; )==""\n"
150R"==(const int pad_d4 = NDIMS > 4 && d[4] + b[4] >= DST_PD4; )==""\n"
151R"==(const int pad_d5 = NDIMS > 5 && d[5] + b[5] >= DST_PD5; )==""\n"
152R"==(const bool pad_sgid = sgId >= LIMIT_DSGID; )==""\n"
153R"==(const int pad = pad_d0 || pad_d1 || pad_d2 || pad_d3 || pad_d4 || pad_d5 )==""\n"
154R"==(|| pad_sgid; )==""\n"
155R"==(if (!pad) { )==""\n"
156R"==(SRC_DATA_T from_cache = cache[cache_idx]; )==""\n"
157R"==(#if WITH_SUM_SCALE || WITH_SUM_ZPOINT )==""\n"
158R"==(dst_tmp = dst[dst_off]; )==""\n"
159R"==(#endif )==""\n"
160R"==(#if WITH_SRC_SCALE )==""\n"
161R"==(uint src_scale_idx = SCALE_OFF(SRC, d[0] + b[0], d[1] + b[1], )==""\n"
162R"==(d[2] + b[2], d[3] + b[3], d[4] + b[4], d[5] + b[5]); )==""\n"
163R"==(src_scale = src_scale_idx < SRC_NUM_SCALES )==""\n"
164R"==(? src_scales[src_scale_idx] )==""\n"
165R"==(: 0.0; )==""\n"
166R"==(#endif )==""\n"
167R"==(#if WITH_DST_SCALE )==""\n"
168R"==(uint dst_scale_idx = SCALE_OFF(DST, d[0] + b[0], d[1] + b[1], )==""\n"
169R"==(d[2] + b[2], d[3] + b[3], d[4] + b[4], d[5] + b[5]); )==""\n"
170R"==(dst_scale = dst_scale_idx < DST_NUM_SCALES )==""\n"
171R"==(? dst_scales[dst_scale_idx] )==""\n"
172R"==(: 0.0; )==""\n"
173R"==(#endif )==""\n"
174R"==(REORDER(dst_tmp, from_cache, src_scale, dst_scale, sum_scale, )==""\n"
175R"==(src_zp, dst_zp, sum_zp); )==""\n"
176R"==(dst[dst_off] = dst_tmp; )==""\n"
177R"==(} )==""\n"
178R"==(} )==""\n"
179R"==(} )==""\n"
180R"==()==";
181}
182}
183}
184}