namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// OpenCL C source of the `gen9_concat` kernel, held as a single C string built
// from adjacent raw-string literals (one literal per kernel source line).
// NOTE(review): this looks like a generated copy of gpu/ocl/gen9_concat.cl;
// apply the same edits there or regeneration will overwrite them.
//
// Kernel outline, as written in the source below:
//  * `dst` is advanced by `dst_offset0`; each work item reads its 6-D dst
//    coordinates via GWS_GET_D0()..GWS_GET_D5() and walks dimension
//    ITER_DIM_IDX from its start up to
//    min(start + ITER_DIM_CHUNK, ITER_DIM_PADDED_SIZE).
//  * DD(i) is meant to name the logical dst dim DST_D<i> (assumes CONCAt2 is
//    the plain token-paste helper from ocl_types.h -- confirm there).
//  * If any starting coordinate is >= DD(i), the work item lies in the padded
//    area: it only stores TO_DST(DATA_ZERO) along its chunk and returns.
//  * Otherwise, for each logical position it selects the first input k whose
//    SRC<k>_END exceeds the dst coordinate on CONCAT_AXIS (k < NUM_INPUTS,
//    at most 16 inputs src0..src15). For k > 0 it rebases that coordinate by
//    SRC<k-1>_END, then copies src<k>[OFF_MD(SRC<k>, ...)] to
//    dst[OFF_MD(DST, ...)]. It uses sub-group block read/write when
//    SUB_GROUP_SIZE > 1, and goes through DATA_TO_REF/float for bf16.
//  * Positions past DD(ITER_DIM_IDX) up to the chunk end are zero-filled.
//  * The per-iteration `src` pointer is declared `__global const` because it
//    aliases the read-only `srcN` kernel arguments. A non-const pointer would
//    discard their qualifier.
const char *gen9_concat_kernel = R"==(/******************************************************************************* )==" "\n"
R"==(* Copyright 2021-2022 Intel Corporation )==" "\n"
R"==(* )==" "\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n"
R"==(* you may not use this file except in compliance with the License. )==" "\n"
R"==(* You may obtain a copy of the License at )==" "\n"
R"==(* )==" "\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==" "\n"
R"==(* )==" "\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n"
R"==(* See the License for the specific language governing permissions and )==" "\n"
R"==(* limitations under the License. )==" "\n"
R"==(*******************************************************************************/ )==" "\n"
R"==(#include "gpu/ocl/ocl_types.h" )==" "\n"
R"==(#define IS_IN_PART(x) (dst_dims[CONCAT_AXIS] < CONCAT3(SRC, x, _END)) )==" "\n"
R"==(#define SET_DIMS(x, y) \ )==" "\n"
R"==({ \ )==" "\n"
R"==(if (y > 0) { \ )==" "\n"
R"==(src_dims[CONCAT_AXIS] \ )==" "\n"
R"==(= dst_dims[CONCAT_AXIS] - CONCAT3(SRC, x, _END); \ )==" "\n"
R"==(} \ )==" "\n"
R"==(src_off = OFF_MD(CONCAT2(SRC, y), src_dims[0], src_dims[1], \ )==" "\n"
R"==(src_dims[2], src_dims[3], src_dims[4], src_dims[5]); \ )==" "\n"
R"==(src = CONCAT2(src, y); \ )==" "\n"
R"==(} )==" "\n"
R"==(#define SRC_DATA_T SRC0_DATA_T )==" "\n"
R"==(#define DD(i) CONCAt2(DST_D, i) )==" "\n"
R"==(#define NEEDS_PADDING(dim0, dim1, dim2, dim3, dim4, dim5) \ )==" "\n"
R"==(((dim0) >= DD(0) || (dim1) >= DD(1) || (dim2) >= DD(2) || (dim3) >= DD(3) \ )==" "\n"
R"==(|| (dim4) >= DD(4) || (dim5) >= DD(5)) )==" "\n"
R"==(KERNEL_ATTR )==" "\n"
R"==(__kernel void gen9_concat(__global DST_DATA_T *dst, long dst_offset0, )==" "\n"
R"==(__global const SRC_DATA_T *src0, __global const SRC_DATA_T *src1, )==" "\n"
R"==(__global const SRC_DATA_T *src2, __global const SRC_DATA_T *src3, )==" "\n"
R"==(__global const SRC_DATA_T *src4, __global const SRC_DATA_T *src5, )==" "\n"
R"==(__global const SRC_DATA_T *src6, __global const SRC_DATA_T *src7, )==" "\n"
R"==(__global const SRC_DATA_T *src8, __global const SRC_DATA_T *src9, )==" "\n"
R"==(__global const SRC_DATA_T *src10, __global const SRC_DATA_T *src11, )==" "\n"
R"==(__global const SRC_DATA_T *src12, __global const SRC_DATA_T *src13, )==" "\n"
R"==(__global const SRC_DATA_T *src14, __global const SRC_DATA_T *src15) { )==" "\n"
R"==(dst += dst_offset0; )==" "\n"
R"==(int dst_dims[6], src_dims[6]; )==" "\n"
R"==(src_dims[0] = dst_dims[0] = GWS_GET_D0(); )==" "\n"
R"==(src_dims[1] = dst_dims[1] = GWS_GET_D1(); )==" "\n"
R"==(src_dims[2] = dst_dims[2] = GWS_GET_D2(); )==" "\n"
R"==(src_dims[3] = dst_dims[3] = GWS_GET_D3(); )==" "\n"
R"==(src_dims[4] = dst_dims[4] = GWS_GET_D4(); )==" "\n"
R"==(src_dims[5] = dst_dims[5] = GWS_GET_D5(); )==" "\n"
R"==(const int iter_dim_end = min( )==" "\n"
R"==(dst_dims[ITER_DIM_IDX] + ITER_DIM_CHUNK, ITER_DIM_PADDED_SIZE); )==" "\n"
R"==(if (NEEDS_PADDING(dst_dims[0], dst_dims[1], dst_dims[2], dst_dims[3], )==" "\n"
R"==(dst_dims[4], dst_dims[5])) { )==" "\n"
R"==(for (; dst_dims[ITER_DIM_IDX] < iter_dim_end; )==" "\n"
R"==(dst_dims[ITER_DIM_IDX]++) { )==" "\n"
R"==(const int dst_off = OFF_MD(DST, dst_dims[0], dst_dims[1], )==" "\n"
R"==(dst_dims[2], dst_dims[3], dst_dims[4], dst_dims[5]); )==" "\n"
R"==(#if SUB_GROUP_SIZE > 1 )==" "\n"
R"==(BLOCK_WRITE_DST(&dst[dst_off], TO_DST(DATA_ZERO)); )==" "\n"
R"==(#else )==" "\n"
R"==(dst[dst_off] = TO_DST(DATA_ZERO); )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
R"==(return; )==" "\n"
R"==(} )==" "\n"
R"==(for (; dst_dims[ITER_DIM_IDX] < min(DD(ITER_DIM_IDX), iter_dim_end); )==" "\n"
R"==(dst_dims[ITER_DIM_IDX]++, src_dims[ITER_DIM_IDX]++) { )==" "\n"
R"==(int src_off; )==" "\n"
R"==(__global const SRC_DATA_T *src; )==" "\n"
R"==(if (IS_IN_PART(0)) SET_DIMS(0, 0) )==" "\n"
R"==(#if NUM_INPUTS >= 2 )==" "\n"
R"==(else if (IS_IN_PART(1)) )==" "\n"
R"==(SET_DIMS(0, 1) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 3 )==" "\n"
R"==(else if (IS_IN_PART(2)) )==" "\n"
R"==(SET_DIMS(1, 2) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 4 )==" "\n"
R"==(else if (IS_IN_PART(3)) )==" "\n"
R"==(SET_DIMS(2, 3) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 5 )==" "\n"
R"==(else if (IS_IN_PART(4)) )==" "\n"
R"==(SET_DIMS(3, 4) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 6 )==" "\n"
R"==(else if (IS_IN_PART(5)) )==" "\n"
R"==(SET_DIMS(4, 5) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 7 )==" "\n"
R"==(else if (IS_IN_PART(6)) )==" "\n"
R"==(SET_DIMS(5, 6) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 8 )==" "\n"
R"==(else if (IS_IN_PART(7)) )==" "\n"
R"==(SET_DIMS(6, 7) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 9 )==" "\n"
R"==(else if (IS_IN_PART(8)) )==" "\n"
R"==(SET_DIMS(7, 8) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 10 )==" "\n"
R"==(else if (IS_IN_PART(9)) )==" "\n"
R"==(SET_DIMS(8, 9) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 11 )==" "\n"
R"==(else if (IS_IN_PART(10)) )==" "\n"
R"==(SET_DIMS(9, 10) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 12 )==" "\n"
R"==(else if (IS_IN_PART(11)) )==" "\n"
R"==(SET_DIMS(10, 11) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 13 )==" "\n"
R"==(else if (IS_IN_PART(12)) )==" "\n"
R"==(SET_DIMS(11, 12) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 14 )==" "\n"
R"==(else if (IS_IN_PART(13)) )==" "\n"
R"==(SET_DIMS(12, 13) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 15 )==" "\n"
R"==(else if (IS_IN_PART(14)) )==" "\n"
R"==(SET_DIMS(13, 14) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if NUM_INPUTS >= 16 )==" "\n"
R"==(else if (IS_IN_PART(15)) )==" "\n"
R"==(SET_DIMS(14, 15) )==" "\n"
R"==(#endif )==" "\n"
R"==(const int dst_off = OFF_MD(DST, dst_dims[0], dst_dims[1], dst_dims[2], )==" "\n"
R"==(dst_dims[3], dst_dims[4], dst_dims[5]); )==" "\n"
R"==(#if SUB_GROUP_SIZE > 1 )==" "\n"
R"==(#if DT_BF16 == 1 )==" "\n"
R"==(float src_val = DATA_TO_REF(AS_DATA_T( )==" "\n"
R"==(BLOCK_READ((const __global BLOCK_DATA_T *)&src[src_off]))); )==" "\n"
R"==(#else )==" "\n"
R"==(SRC_DATA_T src_val = AS_DATA_T( )==" "\n"
R"==(BLOCK_READ((const __global BLOCK_DATA_T *)&src[src_off])); )==" "\n"
R"==(#endif )==" "\n"
R"==(BLOCK_WRITE_DST(&dst[dst_off], TO_DST(src_val)); )==" "\n"
R"==(#else )==" "\n"
R"==(#if DT_BF16 == 1 )==" "\n"
R"==(float src_val = DATA_TO_REF(src[src_off]); )==" "\n"
R"==(#else )==" "\n"
R"==(SRC_DATA_T src_val = src[src_off]; )==" "\n"
R"==(#endif )==" "\n"
R"==(dst[dst_off] = TO_DST(src_val); )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
R"==(for (; dst_dims[ITER_DIM_IDX] < iter_dim_end; dst_dims[ITER_DIM_IDX]++) { )==" "\n"
R"==(const int dst_off = OFF_MD(DST, dst_dims[0], dst_dims[1], dst_dims[2], )==" "\n"
R"==(dst_dims[3], dst_dims[4], dst_dims[5]); )==" "\n"
R"==(#if SUB_GROUP_SIZE > 1 )==" "\n"
R"==(BLOCK_WRITE_DST(&dst[dst_off], TO_DST(DATA_ZERO)); )==" "\n"
R"==(#else )==" "\n"
R"==(dst[dst_off] = TO_DST(DATA_ZERO); )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
R"==()==" ;
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl