namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// OpenCL C source text of the `ref_reorder` kernel, exposed as a single
// NUL-terminated string. It is assembled from adjacent raw string literals,
// one per kernel source line, each followed by "\n".
//
// NOTE(review): this file looks machine-generated from a .cl source (leading
// whitespace and `//` comments appear stripped, which had also truncated the
// license URL to "http:"). Hand edits are presumably lost on regeneration --
// confirm, and keep the generator from cutting `//` inside comments/URLs.
//
// Kernel outline, as visible in the text below:
//   * reads the zero points via GET_SRC_ZP(src_zps) / GET_DST_ZP(dst_zps);
//   * advances `src` / `dst` by SRC_OFFSET0 / DST_OFFSET0;
//   * loops over a 6-D block d0..d5 whose start and extent come from the
//     GWS_GET_Dk() / GWS_GET_Dk_BLOCK() macros;
//   * when PAD_FILL_ZERO == 1, stores 0 to dst and skips REORDER for any
//     index at or beyond SRC_Dk (d0 always checked; dk only when NDIMS > k);
//   * otherwise, if WITH_SRC_SCALE / WITH_DST_SCALE, loads the scale at
//     SCALE_OFF(SRC|DST, d0..d5), then converts one element via REORDER(...).
// KERNEL_ATTR, SRC_DATA_T/DST_DATA_T and the other macros are not defined in
// this string; presumably they come from reorder_common.h and build options.
//
// The pointer itself is intentionally left non-const: `const char *const` at
// namespace scope would have internal linkage in C++ and could break any
// `extern` declaration of this symbol elsewhere (none visible here).
const char *ref_reorder_kernel = R"==(/******************************************************************************* )==" "\n"
R"==(* Copyright 2019-2022 Intel Corporation )==" "\n"
R"==(* )==" "\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n"
R"==(* you may not use this file except in compliance with the License. )==" "\n"
R"==(* You may obtain a copy of the License at )==" "\n"
R"==(* )==" "\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==" "\n"
R"==(* )==" "\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n"
R"==(* See the License for the specific language governing permissions and )==" "\n"
R"==(* limitations under the License. )==" "\n"
R"==(*******************************************************************************/ )==" "\n"
R"==(#include "gpu/ocl/reorder_common.h" )==" "\n"
R"==(KERNEL_ATTR )==" "\n"
R"==(__kernel void ref_reorder(__global SRC_DATA_T *restrict src, )==" "\n"
R"==(__global DST_DATA_T *restrict dst, __global float *restrict src_scales, )==" "\n"
R"==(__global int *restrict src_zps, __global float *restrict dst_scales, )==" "\n"
R"==(__global int *dst_zps, float sum_scale, int sum_zp) { )==" "\n"
R"==(const int src_zp = GET_SRC_ZP(src_zps); )==" "\n"
R"==(const int dst_zp = GET_DST_ZP(dst_zps); )==" "\n"
R"==(float src_scale = 1.0f; )==" "\n"
R"==(float dst_scale = 1.0f; )==" "\n"
R"==(src += SRC_OFFSET0; )==" "\n"
R"==(dst += DST_OFFSET0; )==" "\n"
R"==(const int d0_blk_start = GWS_GET_D0(); )==" "\n"
R"==(const int d1_blk_start = GWS_GET_D1(); )==" "\n"
R"==(const int d2_blk_start = GWS_GET_D2(); )==" "\n"
R"==(const int d3_blk_start = GWS_GET_D3(); )==" "\n"
R"==(const int d4_blk_start = GWS_GET_D4(); )==" "\n"
R"==(const int d5_blk_start = GWS_GET_D5(); )==" "\n"
R"==(const int d0_blk_end = d0_blk_start + GWS_GET_D0_BLOCK(); )==" "\n"
R"==(const int d1_blk_end = d1_blk_start + GWS_GET_D1_BLOCK(); )==" "\n"
R"==(const int d2_blk_end = d2_blk_start + GWS_GET_D2_BLOCK(); )==" "\n"
R"==(const int d3_blk_end = d3_blk_start + GWS_GET_D3_BLOCK(); )==" "\n"
R"==(const int d4_blk_end = d4_blk_start + GWS_GET_D4_BLOCK(); )==" "\n"
R"==(const int d5_blk_end = d5_blk_start + GWS_GET_D5_BLOCK(); )==" "\n"
R"==(for_(int d0 = d0_blk_start; d0 < d0_blk_end; ++d0) )==" "\n"
R"==(for_(int d1 = d1_blk_start; d1 < d1_blk_end; ++d1) )==" "\n"
R"==(for_(int d2 = d2_blk_start; d2 < d2_blk_end; ++d2) )==" "\n"
R"==(for_(int d3 = d3_blk_start; d3 < d3_blk_end; ++d3) )==" "\n"
R"==(for_(int d4 = d4_blk_start; d4 < d4_blk_end; ++d4) )==" "\n"
R"==(for (int d5 = d5_blk_start; d5 < d5_blk_end; ++d5) { )==" "\n"
R"==(const int src_off = SRC_OFF(d0, d1, d2, d3, d4, d5); )==" "\n"
R"==(const int dst_off = DST_OFF(d0, d1, d2, d3, d4, d5); )==" "\n"
R"==(#if PAD_FILL_ZERO == 1 )==" "\n"
R"==(int pad_d0 = d0 >= SRC_D0; )==" "\n"
R"==(int pad_d1 = NDIMS > 1 && d1 >= SRC_D1; )==" "\n"
R"==(int pad_d2 = NDIMS > 2 && d2 >= SRC_D2; )==" "\n"
R"==(int pad_d3 = NDIMS > 3 && d3 >= SRC_D3; )==" "\n"
R"==(int pad_d4 = NDIMS > 4 && d4 >= SRC_D4; )==" "\n"
R"==(int pad_d5 = NDIMS > 5 && d5 >= SRC_D5; )==" "\n"
R"==(if (pad_d0 || pad_d1 || pad_d2 || pad_d3 || pad_d4 || pad_d5) { )==" "\n"
R"==(dst[dst_off] = 0; )==" "\n"
R"==(continue; )==" "\n"
R"==(} )==" "\n"
R"==(#endif )==" "\n"
R"==(#if WITH_SRC_SCALE )==" "\n"
R"==(src_scale = src_scales[SCALE_OFF(SRC, d0, d1, d2, d3, d4, d5)]; )==" "\n"
R"==(#endif )==" "\n"
R"==(#if WITH_DST_SCALE )==" "\n"
R"==(dst_scale = dst_scales[SCALE_OFF(DST, d0, d1, d2, d3, d4, d5)]; )==" "\n"
R"==(#endif )==" "\n"
R"==(REORDER(dst[dst_off], src[src_off], src_scale, dst_scale, sum_scale, )==" "\n"
R"==(src_zp, dst_zp, sum_zp); )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
R"==()==" ;

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl