namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source of the `gen9_conv_bwd_weights` kernel. It is stored as a
// chain of adjacent raw string literals, one per kernel source line, each
// followed by "\n". The compiler concatenates them into a single C string.
//
// The kernel body has two variants, selected by the VER_16MB16C and
// VER_8OW16C macros. Both accumulate partial results into diff_wei (and
// diff_bias when WITH_BIAS == 1) through atomic_add_global. DST_DT_F32 and
// DST_DT_BF16 select the diff_dst block-read helpers BLOCK_READ_DST and
// BLOCK_READ_DST8.
//
// NOTE(review): this file looks generated from gpu/ocl/gen9_conv_bwd_weights.cl.
// The "//" tails appear stripped, e.g. the license URL below. TODO confirm the
// generator, and apply these fixes to the .cl source as well:
//  * BLOCK_READ_DST8 is now also defined for DST_DT_F32. The VER_8OW16C path
//    uses it unconditionally, but previously only bf16 defined it.
//  * HAS_PAD_W is #undef'd before its VER_8OW16C redefinition. Redefining a
//    macro with a different body without #undef is a diagnosable constraint
//    violation.
//  * Trailing blanks inside the literals were dropped, so "\" line
//    continuations are now immediately followed by the newline.
const char *gen9_conv_bwd_weights_kernel = R"==(/*******************************************************************************)==" "\n"
R"==(* Copyright 2019-2021 Intel Corporation)==" "\n"
R"==(*)==" "\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License");)==" "\n"
R"==(* you may not use this file except in compliance with the License.)==" "\n"
R"==(* You may obtain a copy of the License at)==" "\n"
R"==(*)==" "\n"
R"==(* http:)==" "\n"
R"==(*)==" "\n"
R"==(* Unless required by applicable law or agreed to in writing, software)==" "\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS,)==" "\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.)==" "\n"
R"==(* See the License for the specific language governing permissions and)==" "\n"
R"==(* limitations under the License.)==" "\n"
R"==(*******************************************************************************/)==" "\n"
R"==(#include "gpu/ocl/ocl_types.h")==" "\n"
R"==(#define DT_UNDEF)==" "\n"
R"==(#include "gpu/ocl/ocl_math_utils.h")==" "\n"
R"==(#include "gpu/ocl/ocl_types.h")==" "\n"
R"==(#if OD > 1)==" "\n"
R"==(#define CASE_3D 1)==" "\n"
R"==(#else)==" "\n"
R"==(#define CASE_3D 0)==" "\n"
R"==(#endif)==" "\n"
R"==(#define HAS_PAD_D (PD != 0 || PD_R != 0))==" "\n"
R"==(#define HAS_PAD_H (PH != 0 || PH_R != 0))==" "\n"
R"==(#define HAS_PAD_W (PW != 0 || PW_R != 0))==" "\n"
R"==(#if DST_DT_F32)==" "\n"
R"==(#define BLOCK_READ_DST(ptr) \)==" "\n"
R"==(as_float(intel_sub_group_block_read((__global uint *)ptr)))==" "\n"
R"==(#define BLOCK_READ_DST8(ptr) \)==" "\n"
R"==(as_float8(intel_sub_group_block_read8((__global uint *)ptr)))==" "\n"
R"==(#elif DST_DT_BF16)==" "\n"
R"==(#define BLOCK_READ_DST(ptr) \)==" "\n"
R"==(as_ushort(intel_sub_group_block_read_us((__global ushort *)ptr)))==" "\n"
R"==(#define BLOCK_READ_DST8(ptr) \)==" "\n"
R"==(as_ushort8(intel_sub_group_block_read_us8((__global ushort *)ptr)))==" "\n"
R"==(#endif)==" "\n"
R"==(#if BWD_WEIGHTS == 1)==" "\n"
R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))))==" "\n"
R"==(#if VER_16MB16C == 1 || VER_8OW16C == 1)==" "\n"
R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))))==" "\n"
R"==(#endif)==" "\n"
R"==(__kernel void)==" "\n"
R"==(gen9_conv_bwd_weights(__global SRC_DATA_T *src,)==" "\n"
R"==(volatile __global atomic_float *diff_wei,)==" "\n"
R"==(volatile __global atomic_float *diff_bias,)==" "\n"
R"==(__global DST_DATA_T *diff_dst) {)==" "\n"
R"==(MAYBE_SKIP_NON_UNIFORM_WG();)==" "\n"
R"==(#if VER_16MB16C == 1)==" "\n"
R"==(const uint ksp = get_global_id(1);)==" "\n"
R"==(#if CASE_3D)==" "\n"
R"==(const uint kd = ksp / (KW * KH);)==" "\n"
R"==(const uint khw = ksp % (KW * KH);)==" "\n"
R"==(#else)==" "\n"
R"==(const uint khw = ksp;)==" "\n"
R"==(const uint kd = 0;)==" "\n"
R"==(#endif)==" "\n"
R"==(const uint kh = khw / KW;)==" "\n"
R"==(const uint kw = khw % KW;)==" "\n"
R"==(const uint sglid = get_sub_group_local_id();)==" "\n"
R"==(const uint chunk = get_global_id(2) / ((IC / ICB) * (OC / OCB));)==" "\n"
R"==(const uint icb_ocb = get_global_id(2) % ((IC / ICB) * (OC / OCB));)==" "\n"
R"==(const uint icb = icb_ocb % (IC / ICB);)==" "\n"
R"==(const uint ocb = icb_ocb / (IC / ICB);)==" "\n"
R"==(#if IS_DW)==" "\n"
R"==(const uint g = 0;)==" "\n"
R"==(const uint oc)==" "\n"
R"==(= get_group_id(0) * (LWS_0 / SUB_GROUP_SIZE) + get_sub_group_id();)==" "\n"
R"==(const uint ic = oc;)==" "\n"
R"==(#else)==" "\n"
R"==(const uint g_ic_oc = get_global_id(0);)==" "\n"
R"==(const uint g = g_ic_oc / (OCB * (ICB / IC_BLOCK));)==" "\n"
R"==(const uint io = g_ic_oc % (OCB * (ICB / IC_BLOCK));)==" "\n"
R"==(const uint oc = (io % OCB) / OC_BLOCK + ocb * (OCB / OC_BLOCK);)==" "\n"
R"==(const uint ic = io / OCB + icb * (ICB / IC_BLOCK);)==" "\n"
R"==(#endif)==" "\n"
R"==(const uint sp_chunk = chunk % OSP_CHUNK;)==" "\n"
R"==(const uint mb_chunk = chunk / OSP_CHUNK;)==" "\n"
R"==(const uint oh_nb = (OH + OHB - 1) / OHB;)==" "\n"
R"==(const uint ow_nb = (OW + OWB - 1) / OWB;)==" "\n"
R"==(const uint od_beg = (sp_chunk / ow_nb) / oh_nb * ODB;)==" "\n"
R"==(const uint oh_beg = (sp_chunk / ow_nb) % oh_nb * OHB;)==" "\n"
R"==(const uint ow_beg = (sp_chunk % ow_nb) * OWB;)==" "\n"
R"==(const uint mb = mb_chunk * (MB_CHUNK_SIZE);)==" "\n"
R"==(const uint mb_end = min((mb_chunk + 1) * (MB_CHUNK_SIZE), (uint)MB);)==" "\n"
R"==(const bool do_bias = (ic == 0 || IS_DW) && kh == 0 && kw == 0 && kd == 0;)==" "\n"
R"==(src += ic * ID * IH * IW * IC_BLOCK * MB_BLOCK + mb * IC * G * ID * IH * IW)==" "\n"
R"==(+ g * IC * ID * IH * IW * MB_BLOCK;)==" "\n"
R"==(diff_dst += oc * OD * OH * OW * OC_BLOCK * MB_BLOCK)==" "\n"
R"==(+ g * OC * OD * OH * OW * MB_BLOCK;)==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(diff_bias += g * OC + oc * OC_BLOCK + sglid;)==" "\n"
R"==(float bias_loc = 0.0f;)==" "\n"
R"==(#endif)==" "\n"
R"==(#if IS_DW)==" "\n"
R"==(float blockC00 = 0.0f;)==" "\n"
R"==(#else)==" "\n"
R"==(float8 blockC00 = 0.0f;)==" "\n"
R"==(float8 blockC01 = 0.0f;)==" "\n"
R"==(#endif)==" "\n"
R"==(#if MB != (MB_CHUNK * MB_BLOCK))==" "\n"
R"==(uint omb = mb;)==" "\n"
R"==(do {)==" "\n"
R"==(const __global float *diff_dst1_)==" "\n"
R"==(= diff_dst + omb * OC * G * OD * OH * OW;)==" "\n"
R"==(#else)==" "\n"
R"==(const __global float *diff_dst1_ = diff_dst + mb * OC * G * OD * OH * OW;)==" "\n"
R"==(#endif)==" "\n"
R"==(for (uint od = od_beg; od < min(od_beg + ODB, (uint)OD); od++) {)==" "\n"
R"==(for (uint oh = oh_beg; oh < min(oh_beg + OHB, (uint)OH); oh++) {)==" "\n"
R"==(for (uint ow = ow_beg; ow < min(ow_beg + OWB, (uint)OW); ow++) {)==" "\n"
R"==(const __global float *diff_dst1 = diff_dst1_)==" "\n"
R"==(+ od * OH * OW * OC_BLOCK * MB_BLOCK)==" "\n"
R"==(+ oh * OW * OC_BLOCK * MB_BLOCK)==" "\n"
R"==(+ ow * OC_BLOCK * MB_BLOCK;)==" "\n"
R"==(const uint ih = oh * SH - PH + kh * (1 + DH);)==" "\n"
R"==(const uint iw = ow * SW - PW + kw * (1 + DW);)==" "\n"
R"==(#if CASE_3D)==" "\n"
R"==(const uint id = od * SD - PD + kd * (1 + DD);)==" "\n"
R"==(#endif)==" "\n"
R"==(if (iw < 0 || ih < 0 || iw >= IW || ih >= IH)==" "\n"
R"==(#if CASE_3D)==" "\n"
R"==(|| id < 0 || id >= ID)==" "\n"
R"==(#endif)==" "\n"
R"==() {)==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(if (do_bias) {)==" "\n"
R"==(float8 blockB)==" "\n"
R"==(= as_float8(intel_sub_group_block_read8(()==" "\n"
R"==(const __global uint *)(diff_dst1)));)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(bias_loc += blockB[i];)==" "\n"
R"==(blockB = as_float8(intel_sub_group_block_read8()==" "\n"
R"==((const __global uint *)(diff_dst1)==" "\n"
R"==(+ 8 * OC_BLOCK)));)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(bias_loc += blockB[i];)==" "\n"
R"==(})==" "\n"
R"==(#endif)==" "\n"
R"==(continue;)==" "\n"
R"==(})==" "\n"
R"==(const __global float *src1 = src)==" "\n"
R"==(+ ih * IW * IC_BLOCK * MB_BLOCK)==" "\n"
R"==(+ iw * IC_BLOCK * MB_BLOCK;)==" "\n"
R"==(#if CASE_3D)==" "\n"
R"==(src1 += id * IH * IW * IC_BLOCK * MB_BLOCK;)==" "\n"
R"==(#endif)==" "\n"
R"==(#define TRANSPOSE_8(_block, _row, _col) \)==" "\n"
R"==((float8)(intel_sub_group_shuffle(_block[_row], 0 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 1 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 2 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 3 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 4 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 5 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 6 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 7 + _col)))==" "\n"
R"==(#define FMA8(a, b, c) fma((float8)(a), (float8)b, (float8)c))==" "\n"
R"==(#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, col) \)==" "\n"
R"==({ \)==" "\n"
R"==(_result = FMA8(_blockB.s0, TRANSPOSE_8(_blockA, 0, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s1, TRANSPOSE_8(_blockA, 1, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s2, TRANSPOSE_8(_blockA, 2, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s3, TRANSPOSE_8(_blockA, 3, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s4, TRANSPOSE_8(_blockA, 4, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s5, TRANSPOSE_8(_blockA, 5, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s6, TRANSPOSE_8(_blockA, 6, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s7, TRANSPOSE_8(_blockA, 7, col), _result); \)==" "\n"
R"==(})==" "\n"
R"==(#if IS_DW)==" "\n"
R"==(float8 blockA = as_float8(intel_sub_group_block_read8()==" "\n"
R"==((const __global uint *)(src1)));)==" "\n"
R"==(float8 blockA1 = as_float8(intel_sub_group_block_read8()==" "\n"
R"==((const __global uint *)(src1 + 8 * IC_BLOCK)));)==" "\n"
R"==(float8 blockB = as_float8(intel_sub_group_block_read8()==" "\n"
R"==((const __global uint *)(diff_dst1)));)==" "\n"
R"==(float8 blockB1 = as_float8(intel_sub_group_block_read8()==" "\n"
R"==((const __global uint *)(diff_dst1 + 8 * OC_BLOCK)));)==" "\n"
R"==(for (int i = 0; i < 8; i++) {)==" "\n"
R"==(blockC00 = fma(blockA[i], blockB[i], blockC00);)==" "\n"
R"==(})==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(bias_loc += blockB[i];)==" "\n"
R"==(#endif)==" "\n"
R"==(for (int i = 0; i < 8; i++) {)==" "\n"
R"==(blockC00 = fma(blockA1[i], blockB1[i], blockC00);)==" "\n"
R"==(})==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(bias_loc += blockB1[i];)==" "\n"
R"==(#endif)==" "\n"
R"==(#else)==" "\n"
R"==(float8 blockA = as_float8(intel_sub_group_block_read8()==" "\n"
R"==((const __global uint *)(src1)));)==" "\n"
R"==(float8 blockB = as_float8(intel_sub_group_block_read8()==" "\n"
R"==((const __global uint *)(diff_dst1)));)==" "\n"
R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0);)==" "\n"
R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockA, blockB, 8);)==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(bias_loc += blockB[i];)==" "\n"
R"==(#endif)==" "\n"
R"==(blockA = as_float8(intel_sub_group_block_read8()==" "\n"
R"==((const __global uint *)(src1 + 8 * IC_BLOCK)));)==" "\n"
R"==(blockB = as_float8(intel_sub_group_block_read8()==" "\n"
R"==((const __global uint *)(diff_dst1 + 8 * OC_BLOCK)));)==" "\n"
R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0);)==" "\n"
R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockA, blockB, 8);)==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(bias_loc += blockB[i];)==" "\n"
R"==(#endif)==" "\n"
R"==(#endif)==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(#if MB != (MB_CHUNK * MB_BLOCK))==" "\n"
R"==(omb += MB_BLOCK;)==" "\n"
R"==(src += IC * G * ID * IH * IW * MB_BLOCK;)==" "\n"
R"==(} while (omb < mb_end);)==" "\n"
R"==(#endif)==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(if (do_bias)==" "\n"
R"==(&& oc * OC_BLOCK + sglid < (IS_DW ? G_WO_PADDING : OC_WO_PADDING)))==" "\n"
R"==(atomic_add_global(diff_bias, bias_loc);)==" "\n"
R"==(#endif)==" "\n"
R"==(#if IS_DW)==" "\n"
R"==(diff_wei += oc * KD * KH * KW * OC_BLOCK + kd * KH * KW * OC_BLOCK)==" "\n"
R"==(+ kh * KW * OC_BLOCK + kw * OC_BLOCK;)==" "\n"
R"==(atomic_add_global(diff_wei + sglid, blockC00);)==" "\n"
R"==(#else)==" "\n"
R"==(diff_wei += ic * OC * KD * KH * KW * IC_BLOCK)==" "\n"
R"==(+ oc * KD * KH * KW * IC_BLOCK * OC_BLOCK)==" "\n"
R"==(+ kd * KH * KW * IC_BLOCK * OC_BLOCK + kh * KW * IC_BLOCK * OC_BLOCK)==" "\n"
R"==(+ kw * IC_BLOCK * OC_BLOCK + g * OC * IC * KD * KH * KW;)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(atomic_add_global(diff_wei + i * OC_BLOCK + sglid, blockC00[i]);)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(atomic_add_global(diff_wei + (8 + i) * OC_BLOCK + sglid, blockC01[i]);)==" "\n"
R"==(#endif)==" "\n"
R"==(#endif)==" "\n"
R"==(#if VER_8OW16C == 1)==" "\n"
R"==(#undef HAS_PAD_W)==" "\n"
R"==(#define HAS_PAD_W (PW > 0 || OW * SW - PW + (KW - 1) * (1 + DW) >= IW))==" "\n"
R"==(const int sglid = get_sub_group_local_id();)==" "\n"
R"==(#if IC == 3)==" "\n"
R"==(const int ksp = get_global_id(1) * 16 + sglid;)==" "\n"
R"==(#else)==" "\n"
R"==(const int ksp = get_global_id(1);)==" "\n"
R"==(#endif)==" "\n"
R"==(const int ICX = IC == 3 ? 3 : 1;)==" "\n"
R"==(#if CASE_3D)==" "\n"
R"==(const int kd = ksp / (KW * KH * ICX);)==" "\n"
R"==(const int khw = ksp % (KW * KH * ICX);)==" "\n"
R"==(#else)==" "\n"
R"==(const int khw = ksp;)==" "\n"
R"==(const int kd = 0;)==" "\n"
R"==(#endif)==" "\n"
R"==(const int kh = khw / (KW * ICX);)==" "\n"
R"==(const int kw = (khw % (KW * ICX)) % KW;)==" "\n"
R"==(const int chunk = get_global_id(2) % NCHUNK;)==" "\n"
R"==(const int icb_ocb = get_global_id(2) / NCHUNK;)==" "\n"
R"==(const int icb = icb_ocb % (IC / ICB);)==" "\n"
R"==(const int ocb = icb_ocb / (IC / ICB);)==" "\n"
R"==(#if IS_DW)==" "\n"
R"==(const int g = 0;)==" "\n"
R"==(const int oc)==" "\n"
R"==(= get_group_id(0) * (LWS_0 / SUB_GROUP_SIZE) + get_sub_group_id();)==" "\n"
R"==(const int ic = oc;)==" "\n"
R"==(#else)==" "\n"
R"==(const int g_ic_oc = get_global_id(0);)==" "\n"
R"==(const int g = g_ic_oc / (OC * (IC / IC_BLOCK));)==" "\n"
R"==(const int io = g_ic_oc % (OC * (IC / IC_BLOCK));)==" "\n"
R"==(const int oc = (io % OCB) / OC_BLOCK + ocb * (OCB / OC_BLOCK);)==" "\n"
R"==(const int ic = (IC == 3) ? (khw % (KW * ICX)) / KW)==" "\n"
R"==(: (io / OCB + icb * (ICB / IC_BLOCK));)==" "\n"
R"==(#endif)==" "\n"
R"==(const int sp_chunk = chunk % OSP_CHUNK;)==" "\n"
R"==(const int mb_chunk = chunk / OSP_CHUNK;)==" "\n"
R"==(const int ow_nb = (OW + OWB - 1) / OWB;)==" "\n"
R"==(const int oh_nb = (OH + OHB - 1) / OHB;)==" "\n"
R"==(const int od_beg = ((sp_chunk / ow_nb) / oh_nb) * ODB;)==" "\n"
R"==(const int oh_beg = ((sp_chunk / ow_nb) % oh_nb) * OHB;)==" "\n"
R"==(const int ow_beg = (sp_chunk % ow_nb) * OWB;)==" "\n"
R"==(const int mb = mb_chunk * MB_CHUNK_SIZE;)==" "\n"
R"==(const int mb_end = min((mb_chunk + 1) * MB_CHUNK_SIZE, MB);)==" "\n"
R"==(#if IC == 3)==" "\n"
R"==(const bool do_bias = get_global_id(1) == 0;)==" "\n"
R"==(#else)==" "\n"
R"==(const bool do_bias = (ic == 0 || IS_DW) && kh == 0 && kw == 0 && kd == 0;)==" "\n"
R"==(#endif)==" "\n"
R"==(const int OW_LOOP_BLOCK = 8;)==" "\n"
R"==(#if IC == 3)==" "\n"
R"==(src += mb * IC * G * ID * IH * IW + g * IC * ID * IH * IW * MB_BLOCK;)==" "\n"
R"==(#else)==" "\n"
R"==(src += ic * ID * IH * IW * IC_BLOCK * MB_BLOCK + mb * IC * G * ID * IH * IW)==" "\n"
R"==(+ g * IC * ID * IH * IW * MB_BLOCK;)==" "\n"
R"==(#endif)==" "\n"
R"==(diff_dst += oc * OD * OH * OW * OC_BLOCK * MB_BLOCK)==" "\n"
R"==(+ g * OC * OD * OH * OW * MB_BLOCK;)==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(diff_bias += g * OC + oc * OC_BLOCK + sglid;)==" "\n"
R"==(float bias_loc = 0.0f;)==" "\n"
R"==(#endif)==" "\n"
R"==(#if IC == 3)==" "\n"
R"==(float8 blockC00 = 0.0f;)==" "\n"
R"==(float8 blockC01 = 0.0f;)==" "\n"
R"==(#elif IS_DW)==" "\n"
R"==(float blockC00 = 0.0f;)==" "\n"
R"==(#else)==" "\n"
R"==(float8 blockC00 = 0.0f;)==" "\n"
R"==(float8 blockC01 = 0.0f;)==" "\n"
R"==(#endif)==" "\n"
R"==(for (int omb = mb; omb < mb_end; omb++) {)==" "\n"
R"==(const __global DST_DATA_T *diff_dst1_)==" "\n"
R"==(= diff_dst + omb * OC * G * OD * OH * OW;)==" "\n"
R"==(for (int od = od_beg; od < min(od_beg + ODB, OD); od++))==" "\n"
R"==(for (int oh = oh_beg; oh < min(oh_beg + OHB, OH); oh++) {)==" "\n"
R"==(const __global DST_DATA_T *diff_dst1 = diff_dst1_)==" "\n"
R"==(+ od * OH * OW * OC_BLOCK + oh * OW * OC_BLOCK;)==" "\n"
R"==(bool skip = false;)==" "\n"
R"==(if (oh * SH + kh * (1 + DH) < PH)==" "\n"
R"==(|| oh * SH + kh * (1 + DH) >= IH + PH)==" "\n"
R"==(#if CASE_3D)==" "\n"
R"==(|| od * SD + kd * (1 + DD) < PD)==" "\n"
R"==(|| od * SD + kd * (1 + DD) >= ID + PD)==" "\n"
R"==(#endif)==" "\n"
R"==() {)==" "\n"
R"==(skip = true;)==" "\n"
R"==(})==" "\n"
R"==(const int id = od * SD - PD + kd * (1 + DD);)==" "\n"
R"==(const int ih = oh * SH - PH + kh * (1 + DH);)==" "\n"
R"==(__global SRC_DATA_T *src1;)==" "\n"
R"==(for (int ow = ow_beg;)==" "\n"
R"==(ow < min(ow_beg + OWB, (OW / OW_BLOCK) * OW_BLOCK);)==" "\n"
R"==(ow += OW_BLOCK) {)==" "\n"
R"==(const int iw = ow * SW - PW + kw * (1 + DW);)==" "\n"
R"==(src1 = src + id * IH * IW * IC_BLOCK + ih * IW * IC_BLOCK)==" "\n"
R"==(+ iw * IC_BLOCK;)==" "\n"
R"==(#define TRANSPOSE_8(_block, _row, _col) \)==" "\n"
R"==({ \)==" "\n"
R"==((float8)(intel_sub_group_shuffle(_block[_row], 0 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 1 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 2 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 3 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 4 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 5 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 6 + _col), \)==" "\n"
R"==(intel_sub_group_shuffle(_block[_row], 7 + _col)) \)==" "\n"
R"==(})==" "\n"
R"==(#define FMA8(a, b, c) fma((float8)(a), (float8)b, (float8)c))==" "\n"
R"==(#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, col) \)==" "\n"
R"==({ \)==" "\n"
R"==(_result = FMA8(_blockB.s0, TRANSPOSE_8(_blockA, 0, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s1, TRANSPOSE_8(_blockA, 1, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s2, TRANSPOSE_8(_blockA, 2, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s3, TRANSPOSE_8(_blockA, 3, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s4, TRANSPOSE_8(_blockA, 4, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s5, TRANSPOSE_8(_blockA, 5, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s6, TRANSPOSE_8(_blockA, 6, col), _result); \)==" "\n"
R"==(_result = FMA8(_blockB.s7, TRANSPOSE_8(_blockA, 7, col), _result); \)==" "\n"
R"==(})==" "\n"
R"==(float8 blockA, blockB;)==" "\n"
R"==(#if IC == 3)==" "\n"
R"==(if (skip) {)==" "\n"
R"==(blockA = 0.0f;)==" "\n"
R"==(} else {)==" "\n"
R"==(for (int i = 0; i < 8; i++) {)==" "\n"
R"==(if (HAS_PAD_W)==" "\n"
R"==(&& (iw + i * SW < 0 || iw + i * SW >= IW)))==" "\n"
R"==(blockA[i] = 0;)==" "\n"
R"==(else)==" "\n"
R"==(blockA[i] = SRC_TO_REF()==" "\n"
R"==(src1[ic * ID * IH * IW + i * SW]);)==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(#else)==" "\n"
R"==(if (skip) {)==" "\n"
R"==(blockA = 0.0f;)==" "\n"
R"==(} else {)==" "\n"
R"==(for (int i = 0; i < OW_BLOCK; i++) {)==" "\n"
R"==(if (HAS_PAD_W)==" "\n"
R"==(&& (iw + i * SW < 0 || iw + i * SW >= IW)) {)==" "\n"
R"==(blockA[i] = 0;)==" "\n"
R"==(} else {)==" "\n"
R"==(blockA[i] = as_float(intel_sub_group_block_read()==" "\n"
R"==((const __global uint *)(&src1[i)==" "\n"
R"==(* IC_BLOCK * SW])));)==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(#endif)==" "\n"
R"==(blockB = DST_TO_REF8()==" "\n"
R"==(BLOCK_READ_DST8(diff_dst1 + ow * OC_BLOCK));)==" "\n"
R"==(#if IC == 3)==" "\n"
R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockB, blockA, 0);)==" "\n"
R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockB, blockA, 8);)==" "\n"
R"==(#elif IS_DW)==" "\n"
R"==(for (int i = 0; i < OW_LOOP_BLOCK; i++) {)==" "\n"
R"==(blockC00 = fma(blockA[i], blockB[i], blockC00);)==" "\n"
R"==(})==" "\n"
R"==(#else)==" "\n"
R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0);)==" "\n"
R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockA, blockB, 8);)==" "\n"
R"==(#endif)==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(for (int i = 0; i < OW_LOOP_BLOCK; i++) {)==" "\n"
R"==(bias_loc += blockB[i];)==" "\n"
R"==(})==" "\n"
R"==(#endif)==" "\n"
R"==(})==" "\n"
R"==(for (int ow = (OW / OW_BLOCK) * OW_BLOCK;)==" "\n"
R"==(ow < min(ow_beg + OWB, OW); ow += OW_LOOP_BLOCK) {)==" "\n"
R"==(const int id = od * SD - PD + kd * (1 + DD);)==" "\n"
R"==(const int ih = oh * SH - PH + kh * (1 + DH);)==" "\n"
R"==(const int iw = ow * SW - PW + kw * (1 + DW);)==" "\n"
R"==(__global SRC_DATA_T *src1;)==" "\n"
R"==(float8 blockA, blockB;)==" "\n"
R"==(src1 = src + id * IH * IW * IC_BLOCK + ih * IW * IC_BLOCK)==" "\n"
R"==(+ iw * IC_BLOCK;)==" "\n"
R"==(#if IC == 3)==" "\n"
R"==(if (skip) {)==" "\n"
R"==(blockA = 0.0f;)==" "\n"
R"==(} else {)==" "\n"
R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) {)==" "\n"
R"==(if (HAS_PAD_W)==" "\n"
R"==(&& (iw + i * SW < 0 || iw + i * SW >= IW)))==" "\n"
R"==(blockA[i] = 0;)==" "\n"
R"==(else)==" "\n"
R"==(blockA[i] = SRC_TO_REF()==" "\n"
R"==(src1[ic * ID * IH * IW + i * SW]);)==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(#else)==" "\n"
R"==(if (skip) {)==" "\n"
R"==(blockA = 0.0f;)==" "\n"
R"==(} else {)==" "\n"
R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) {)==" "\n"
R"==(if (HAS_PAD_W)==" "\n"
R"==(&& (iw + i * SW < 0 || iw + i * SW >= IW)) {)==" "\n"
R"==(blockA[i] = 0;)==" "\n"
R"==(} else {)==" "\n"
R"==(blockA[i] = as_float(intel_sub_group_block_read()==" "\n"
R"==((const __global uint *)(&src1[i)==" "\n"
R"==(* IC_BLOCK * SW])));)==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(#endif)==" "\n"
R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) {)==" "\n"
R"==(blockB[i] = DST_TO_REF(BLOCK_READ_DST()==" "\n"
R"==((&diff_dst1[(ow + i) * OC_BLOCK])));)==" "\n"
R"==(})==" "\n"
R"==(#if IC == 3)==" "\n"
R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) {)==" "\n"
R"==(blockC00 = FMA8()==" "\n"
R"==(blockA[i], TRANSPOSE_8(blockB, i, 0), blockC00);)==" "\n"
R"==(blockC01 = FMA8()==" "\n"
R"==(blockA[i], TRANSPOSE_8(blockB, i, 8), blockC01);)==" "\n"
R"==(})==" "\n"
R"==(#elif IS_DW)==" "\n"
R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) {)==" "\n"
R"==(blockC00 = fma(blockA[i], blockB[i], blockC00);)==" "\n"
R"==(})==" "\n"
R"==(#else)==" "\n"
R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++) {)==" "\n"
R"==(blockC00 = FMA8()==" "\n"
R"==(blockB[i], TRANSPOSE_8(blockA, i, 0), blockC00);)==" "\n"
R"==(blockC01 = FMA8()==" "\n"
R"==(blockB[i], TRANSPOSE_8(blockA, i, 8), blockC01);)==" "\n"
R"==(})==" "\n"
R"==(#endif)==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(for (int i = 0; i < min(OW_LOOP_BLOCK, OW - ow); i++))==" "\n"
R"==(bias_loc += blockB[i];)==" "\n"
R"==(#endif)==" "\n"
R"==(})==" "\n"
R"==(})==" "\n"
R"==(src += G * IC * ID * IH * IW * MB_BLOCK;)==" "\n"
R"==(})==" "\n"
R"==(#if WITH_BIAS == 1)==" "\n"
R"==(if (do_bias)==" "\n"
R"==(&& oc * OC_BLOCK + sglid < (IS_DW ? G_WO_PADDING : OC_WO_PADDING)))==" "\n"
R"==(atomic_add_global(diff_bias, bias_loc);)==" "\n"
R"==(#endif)==" "\n"
R"==(#if IC == 3)==" "\n"
R"==(diff_wei += ic * OC_BLOCK + oc * KD * KH * KW * IC * OC_BLOCK)==" "\n"
R"==(+ g * OC * IC * KD * KH * KW + kd * KH * KW * IC * OC_BLOCK)==" "\n"
R"==(+ kh * KW * IC * OC_BLOCK + kw * IC * OC_BLOCK;)==" "\n"
R"==(if (ksp >= KH * KW * KD * IC) return;)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(atomic_add_global(diff_wei + i, blockC00[i]);)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(atomic_add_global(diff_wei + 8 + i, blockC01[i]);)==" "\n"
R"==(#elif IS_DW)==" "\n"
R"==(diff_wei += oc * KD * KH * KW * OC_BLOCK + kd * KH * KW * OC_BLOCK)==" "\n"
R"==(+ kh * KW * OC_BLOCK + kw * OC_BLOCK;)==" "\n"
R"==(atomic_add_global(diff_wei + sglid, blockC00);)==" "\n"
R"==(#else)==" "\n"
R"==(diff_wei += ic * OC * KD * KH * KW * IC_BLOCK)==" "\n"
R"==(+ oc * KD * KH * KW * IC_BLOCK * OC_BLOCK)==" "\n"
R"==(+ kd * KH * KW * IC_BLOCK * OC_BLOCK + kh * KW * IC_BLOCK * OC_BLOCK)==" "\n"
R"==(+ kw * IC_BLOCK * OC_BLOCK + g * OC * IC * KD * KH * KW;)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(atomic_add_global(diff_wei + i * OC_BLOCK + sglid, blockC00[i]);)==" "\n"
R"==(for (int i = 0; i < 8; i++))==" "\n"
R"==(atomic_add_global(diff_wei + (8 + i) * OC_BLOCK + sglid, blockC01[i]);)==" "\n"
R"==(#endif)==" "\n"
R"==(#endif)==" "\n"
R"==(})==" "\n"
R"==(#endif)==" "\n"
R"==()==";
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl