// Embedded OpenCL C source text for the gen9_conv_nhwc_bwd_weights kernel.
// The text is built from adjacent raw string literals, one per kernel source
// line, each followed by a newline literal; the C++ compiler concatenates them
// into the single string that gen9_conv_nhwc_bwd_weights_f32_kernel points to.
// Every byte between the raw-string delimiters is part of that runtime string,
// so explanatory notes are kept in C++ comments between the literals only.
// NOTE(review): the one-literal-per-line layout and the truncated license URL
// suggest this file is generated from a .cl source - confirm before editing
// the kernel text here by hand.
1 | namespace dnnl { |
2 | namespace impl { |
3 | namespace gpu { |
4 | namespace ocl { |
// Start of the kernel text; the first literals carry the license header.
5 | const char *gen9_conv_nhwc_bwd_weights_f32_kernel = R"==(/******************************************************************************* )==" "\n" |
6 | R"==(* Copyright 2020-2021 Intel Corporation )==" "\n" |
7 | R"==(* )==" "\n" |
8 | R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n" |
9 | R"==(* you may not use this file except in compliance with the License. )==" "\n" |
10 | R"==(* You may obtain a copy of the License at )==" "\n" |
11 | R"==(* )==" "\n" |
12 | R"==(* http: )==" "\n" |
13 | R"==(* )==" "\n" |
14 | R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n" |
15 | R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n" |
16 | R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n" |
17 | R"==(* See the License for the specific language governing permissions and )==" "\n" |
18 | R"==(* limitations under the License. )==" "\n" |
19 | R"==(*******************************************************************************/ )==" "\n" |
// The kernel text includes two project headers. Names used below that are not
// defined in this string (MAYBE_SKIP_NON_UNIFORM_WG, atomic_float,
// atomic_add_global) presumably come from these headers - TODO confirm.
// All upper-case shape names (IC, OC, KW, OWB, ...) are not defined here
// either; presumably they are supplied as build-time macros - TODO confirm.
20 | R"==(#include "gpu/ocl/ocl_math_utils.h" )==" "\n" |
21 | R"==(#include "gpu/ocl/ocl_types.h" )==" "\n" |
// CASE_3D is 1 when ID > 1 and 0 otherwise.
22 | R"==(#if ID > 1 )==" "\n" |
23 | R"==(#define CASE_3D 1 )==" "\n" |
24 | R"==(#else )==" "\n" |
25 | R"==(#define CASE_3D 0 )==" "\n" |
26 | R"==(#endif )==" "\n" |
// DIV_UP: integer division rounded up (for positive operands).
// RND_UP: a rounded up to the next multiple of b.
27 | R"==(#define DIV_UP(a, b) (((a) + (b)-1) / (b)) )==" "\n" |
28 | R"==(#define RND_UP(a, b) (DIV_UP(a, b) * (b)) )==" "\n" |
// Everything up to the matching endif near the end of the text is compiled
// only when BWD_WEIGHTS == 1.
29 | R"==(#if BWD_WEIGHTS == 1 )==" "\n" |
// read_ic_block(ptr, off): returns one float per sub-group lane read from ptr.
// When (IS_DW ? G : IC) is not a multiple of IC_BLOCK and fewer than IC_BLOCK
// channels remain past off, lanes with sglid < tail read ptr[sglid] and the
// remaining lanes return 0.0f. Otherwise the value comes from
// intel_sub_group_block_read on ptr.
30 | R"==(inline float read_ic_block(const __global float *ptr, int off) { )==" "\n" |
31 | R"==(#if (IS_DW ? G : IC) % IC_BLOCK != 0 )==" "\n" |
32 | R"==(int tail = (IS_DW ? G : IC) - off; )==" "\n" |
33 | R"==(if (tail < IC_BLOCK) { )==" "\n" |
34 | R"==(const int sglid = get_sub_group_local_id(); )==" "\n" |
35 | R"==(return (sglid < tail) ? ptr[sglid] : 0.0f; )==" "\n" |
36 | R"==(} )==" "\n" |
37 | R"==(#endif )==" "\n" |
38 | R"==(return as_float(intel_sub_group_block_read((const __global uint *)ptr)); )==" "\n" |
39 | R"==(} )==" "\n" |
// read_oc_block(ptr, off): same scheme as read_ic_block, but the channel count
// is (IS_DW ? G : OC_WO_PADDING) and the block size is OC_BLOCK.
40 | R"==(inline float read_oc_block(const __global float *ptr, int off) { )==" "\n" |
41 | R"==(#if (IS_DW ? G : OC_WO_PADDING) % OC_BLOCK != 0 )==" "\n" |
42 | R"==(int tail = (IS_DW ? G : OC_WO_PADDING) - off; )==" "\n" |
43 | R"==(if (tail < OC_BLOCK) { )==" "\n" |
44 | R"==(const int sglid = get_sub_group_local_id(); )==" "\n" |
45 | R"==(return (sglid < tail) ? ptr[sglid] : 0.0f; )==" "\n" |
46 | R"==(} )==" "\n" |
47 | R"==(#endif )==" "\n" |
48 | R"==(return as_float(intel_sub_group_block_read((const __global uint *)ptr)); )==" "\n" |
49 | R"==(} )==" "\n" |
// Kernel entry point gen9_conv_nhwc_bwd_weights(src, diff_wei, diff_bias,
// diff_dst). It requires a work-group size of LWS_0 x LWS_1 x LWS_2 and a
// sub-group size of SUB_GROUP_SIZE. src and diff_dst are only read;
// diff_wei and diff_bias are only updated through atomic_add_global.
50 | R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) )==" "\n" |
51 | R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==" "\n" |
52 | R"==(__kernel void )==" "\n" |
53 | R"==(gen9_conv_nhwc_bwd_weights(__global float *src, )==" "\n" |
54 | R"==(volatile __global atomic_float *diff_wei, )==" "\n" |
55 | R"==(volatile __global atomic_float *diff_bias, __global float *diff_dst) { )==" "\n" |
56 | R"==(MAYBE_SKIP_NON_UNIFORM_WG(); )==" "\n" |
// Global id 1 is the flattened filter position; it is split into kd, kh, kw
// (kd is fixed to 0 when CASE_3D is 0).
57 | R"==(const int ksp = get_global_id(1); )==" "\n" |
58 | R"==(#if CASE_3D )==" "\n" |
59 | R"==(const int kd = ksp / (KW * KH); )==" "\n" |
60 | R"==(const int khw = ksp % (KW * KH); )==" "\n" |
61 | R"==(#else )==" "\n" |
62 | R"==(const int khw = ksp; )==" "\n" |
63 | R"==(const int kd = 0; )==" "\n" |
64 | R"==(#endif )==" "\n" |
65 | R"==(const int kh = khw / KW; )==" "\n" |
66 | R"==(const int kw = khw % KW; )==" "\n" |
67 | R"==(const int sglid = get_sub_group_local_id(); )==" "\n" |
// Global id 2 packs a chunk index (modulo NCHUNK) with a combined index that
// is split into icb (modulo DIV_UP(IC, ICB)) and ocb (the quotient).
68 | R"==(const int chunk = get_global_id(2) % NCHUNK; )==" "\n" |
69 | R"==(const int icb_ocb = get_global_id(2) / NCHUNK; )==" "\n" |
70 | R"==(const int icb = icb_ocb % DIV_UP(IC, ICB); )==" "\n" |
71 | R"==(const int ocb = icb_ocb / DIV_UP(IC, ICB); )==" "\n" |
72 | R"==(const int ic_padded = RND_UP(IC, IC_BLOCK); )==" "\n" |
73 | R"==(const int oc_padded = RND_UP(OC, OC_BLOCK); )==" "\n" |
// oc and ic are block indices: they are multiplied by OC_BLOCK and IC_BLOCK
// when pointers are offset below. With IS_DW set, g is 0, oc is derived from
// the work-group id and sub-group id in dimension 0, and ic equals oc.
// Otherwise global id 0 is split into g and io, and oc / ic are derived from
// io together with ocb / icb; ic is forced to 0 when IC == 3.
74 | R"==(#if IS_DW )==" "\n" |
75 | R"==(const int g = 0; )==" "\n" |
76 | R"==(const int oc )==" "\n" |
77 | R"==(= get_group_id(0) * (LWS_0 / SUB_GROUP_SIZE) + get_sub_group_id(); )==" "\n" |
78 | R"==(const int ic = oc; )==" "\n" |
79 | R"==(#else )==" "\n" |
80 | R"==(const int g_ic_oc = get_global_id(0); )==" "\n" |
81 | R"==(const int g = g_ic_oc / (oc_padded * DIV_UP(IC, IC_BLOCK)); )==" "\n" |
82 | R"==(const int io = g_ic_oc % (oc_padded * DIV_UP(IC, IC_BLOCK)); )==" "\n" |
83 | R"==(const int oc = (io % OCB) / OC_BLOCK + ocb * (OCB / OC_BLOCK); )==" "\n" |
84 | R"==(const int ic = (IC == 3) ? 0 : (io / OCB + icb * (ICB / IC_BLOCK)); )==" "\n" |
85 | R"==(#endif )==" "\n" |
// The chunk index is split into an output-spatial chunk (modulo OSP_CHUNK) and
// a minibatch chunk. The spatial chunk gives the starting od / oh / ow of a
// tile of ODB x OHB x OWB output points; the minibatch chunk gives the range
// [mb, mb_end), clamped to MB.
86 | R"==(const int sp_chunk = chunk % OSP_CHUNK; )==" "\n" |
87 | R"==(const int mb_chunk = chunk / OSP_CHUNK; )==" "\n" |
88 | R"==(const int ow_nb = (OW + OWB - 1) / OWB; )==" "\n" |
89 | R"==(const int oh_nb = (OH + OHB - 1) / OHB; )==" "\n" |
90 | R"==(const int od_beg = ((sp_chunk / ow_nb) / oh_nb) * ODB; )==" "\n" |
91 | R"==(const int oh_beg = ((sp_chunk / ow_nb) % oh_nb) * OHB; )==" "\n" |
92 | R"==(const int ow_beg = (sp_chunk % ow_nb) * OWB; )==" "\n" |
93 | R"==(const int mb = mb_chunk * MB_CHUNK_SIZE; )==" "\n" |
94 | R"==(const int mb_end = min((mb_chunk + 1) * MB_CHUNK_SIZE, MB); )==" "\n" |
// Only work-items at filter position (0, 0, 0) with ic == 0 (or any ic in the
// depthwise case) publish a bias contribution at the end of the kernel.
95 | R"==(const bool do_bias = (ic == 0 || IS_DW) && kh == 0 && kw == 0 && kd == 0; )==" "\n" |
// Move the base pointers to this work-item's first minibatch image, group and
// channel block. The indexing used here and below treats both tensors as
// spatial-major with G * channels as the innermost extent.
96 | R"==(src += mb * ID * IH * IW * G * IC; )==" "\n" |
97 | R"==(src += g * IC + ic * IC_BLOCK; )==" "\n" |
98 | R"==(diff_dst += g * OC_WO_PADDING + oc * OC_BLOCK; )==" "\n" |
99 | R"==(#if WITH_BIAS == 1 )==" "\n" |
100 | R"==(diff_bias += g * OC_WO_PADDING + oc * OC_BLOCK + sglid; )==" "\n" |
101 | R"==(float bias_loc = 0.0f; )==" "\n" |
102 | R"==(#endif )==" "\n" |
// Per-lane accumulators: one float8 for IC == 3, a single float for the
// depthwise case, and two float8 values (blockC00, blockC01) otherwise.
103 | R"==(#if IC == 3 )==" "\n" |
104 | R"==(float8 blockC00 = 0.0f; )==" "\n" |
105 | R"==(#elif IS_DW )==" "\n" |
106 | R"==(float blockC00 = 0.0f; )==" "\n" |
107 | R"==(#else )==" "\n" |
108 | R"==(float8 blockC00 = 0.0f; )==" "\n" |
109 | R"==(float8 blockC01 = 0.0f; )==" "\n" |
110 | R"==(#endif )==" "\n" |
// Main reduction over the minibatch range and the output tile.
111 | R"==(for (int omb = mb; omb < mb_end; omb++) { )==" "\n" |
112 | R"==(const __global float *diff_dst1_ )==" "\n" |
113 | R"==(= diff_dst + omb * OD * OH * OW * G * OC_WO_PADDING; )==" "\n" |
114 | R"==(for (int od = od_beg; od < min(od_beg + ODB, OD); od++) )==" "\n" |
115 | R"==(for (int oh = oh_beg; oh < min(oh_beg + OHB, OH); oh++) { )==" "\n" |
116 | R"==(const __global float *diff_dst1 = diff_dst1_ )==" "\n" |
117 | R"==(+ (od * OH * OW + oh * OW) * G * OC_WO_PADDING; )==" "\n" |
// If the input row (or depth slice in 3D) addressed by this output row and
// filter tap lies outside [0, IH) (or [0, ID)), the weights receive nothing
// from this row. The bias still needs the diff_dst values, so they are summed
// here when do_bias is set, and the row is then skipped.
118 | R"==(if (oh * SH + kh * (1 + DH) < PH )==" "\n" |
119 | R"==(|| oh * SH + kh * (1 + DH) >= IH + PH )==" "\n" |
120 | R"==(#if CASE_3D )==" "\n" |
121 | R"==(|| od * SD + kd * (1 + DD) < PD )==" "\n" |
122 | R"==(|| od * SD + kd * (1 + DD) >= ID + PD )==" "\n" |
123 | R"==(#endif )==" "\n" |
124 | R"==() { )==" "\n" |
125 | R"==(#if WITH_BIAS == 1 )==" "\n" |
126 | R"==(if (do_bias) { )==" "\n" |
127 | R"==(for (int ow = ow_beg; ow < ow_beg + OWB; )==" "\n" |
128 | R"==(ow += OW_BLOCK) { )==" "\n" |
129 | R"==(float8 blockB; )==" "\n" |
130 | R"==(for (int i = 0; i < OW_BLOCK; i++) { )==" "\n" |
131 | R"==(if (ow + i >= OW) { )==" "\n" |
132 | R"==(blockB[i] = 0.0; )==" "\n" |
133 | R"==(} else { )==" "\n" |
134 | R"==(blockB[i] = read_oc_block( )==" "\n" |
135 | R"==(&diff_dst1[(ow + i) * G )==" "\n" |
136 | R"==(* OC_WO_PADDING], )==" "\n" |
137 | R"==(oc * OC_BLOCK); )==" "\n" |
138 | R"==(} )==" "\n" |
139 | R"==(} )==" "\n" |
140 | R"==(for (int i = 0; i < OW_BLOCK; i++) )==" "\n" |
141 | R"==(bias_loc += blockB[i]; )==" "\n" |
142 | R"==(} )==" "\n" |
143 | R"==(} )==" "\n" |
144 | R"==(#endif )==" "\n" |
145 | R"==(continue; )==" "\n" |
146 | R"==(} )==" "\n" |
// Walk the output row in steps of OW_BLOCK points and compute the matching
// input coordinates for this filter tap.
147 | R"==(for (int ow = ow_beg; ow < ow_beg + OWB; ow += OW_BLOCK) { )==" "\n" |
148 | R"==(const int id = od * SD - PD + kd * (1 + DD); )==" "\n" |
149 | R"==(const int ih = oh * SH - PH + kh * (1 + DH); )==" "\n" |
150 | R"==(const int iw = ow * SW - PW + kw * (1 + DW); )==" "\n" |
151 | R"==(__global float *src1 )==" "\n" |
152 | R"==(= src + (id * IH * IW + ih * IW + iw) * G * IC; )==" "\n" |
// TRANSPOSE_8(_block, _row, _col): builds a float8 whose k-th element is
// _block[_row] shuffled in from sub-group lane (k + _col), for k = 0..7.
153 | R"==(#define TRANSPOSE_8(_block, _row, _col) \ )==" "\n" |
154 | R"==({ \ )==" "\n" |
155 | R"==((float8)(intel_sub_group_shuffle(_block[_row], 0 + _col), \ )==" "\n" |
156 | R"==(intel_sub_group_shuffle(_block[_row], 1 + _col), \ )==" "\n" |
157 | R"==(intel_sub_group_shuffle(_block[_row], 2 + _col), \ )==" "\n" |
158 | R"==(intel_sub_group_shuffle(_block[_row], 3 + _col), \ )==" "\n" |
159 | R"==(intel_sub_group_shuffle(_block[_row], 4 + _col), \ )==" "\n" |
160 | R"==(intel_sub_group_shuffle(_block[_row], 5 + _col), \ )==" "\n" |
161 | R"==(intel_sub_group_shuffle(_block[_row], 6 + _col), \ )==" "\n" |
162 | R"==(intel_sub_group_shuffle(_block[_row], 7 + _col)) \ )==" "\n" |
163 | R"==(} )==" "\n" |
// FMA8 is an fma on float8 operands. MULTIPLY_BLOCKS_8x8 adds, for each row
// r = 0..7, the product of _blockB element r (this lane) with
// TRANSPOSE_8(_blockA, r, col) into _result.
164 | R"==(#define FMA8(a, b, c) fma((float8)(a), (float8)b, (float8)c) )==" "\n" |
165 | R"==(#define MULTIPLY_BLOCKS_8x8(_result, _blockA, _blockB, col) \ )==" "\n" |
166 | R"==({ \ )==" "\n" |
167 | R"==(_result = FMA8(_blockB.s0, TRANSPOSE_8(_blockA, 0, col), _result); \ )==" "\n" |
168 | R"==(_result = FMA8(_blockB.s1, TRANSPOSE_8(_blockA, 1, col), _result); \ )==" "\n" |
169 | R"==(_result = FMA8(_blockB.s2, TRANSPOSE_8(_blockA, 2, col), _result); \ )==" "\n" |
170 | R"==(_result = FMA8(_blockB.s3, TRANSPOSE_8(_blockA, 3, col), _result); \ )==" "\n" |
171 | R"==(_result = FMA8(_blockB.s4, TRANSPOSE_8(_blockA, 4, col), _result); \ )==" "\n" |
172 | R"==(_result = FMA8(_blockB.s5, TRANSPOSE_8(_blockA, 5, col), _result); \ )==" "\n" |
173 | R"==(_result = FMA8(_blockB.s6, TRANSPOSE_8(_blockA, 6, col), _result); \ )==" "\n" |
174 | R"==(_result = FMA8(_blockB.s7, TRANSPOSE_8(_blockA, 7, col), _result); \ )==" "\n" |
175 | R"==(} )==" "\n" |
// Load blockA from src: element i holds the input value for output point
// ow + i, or 0 when the input column iw + i * SW is outside [0, IW).
// For IC == 3 only lanes with sglid < IC read (one channel per lane) and the
// other lanes hold zeros; otherwise read_ic_block performs the read.
176 | R"==(float8 blockA, blockB; )==" "\n" |
177 | R"==(#if IC == 3 )==" "\n" |
178 | R"==(if (sglid < IC) { )==" "\n" |
179 | R"==(for (int i = 0; i < OW_BLOCK; i++) { )==" "\n" |
180 | R"==(if (iw + i * SW < 0 || iw + i * SW >= IW) { )==" "\n" |
181 | R"==(blockA[i] = 0; )==" "\n" |
182 | R"==(} else { )==" "\n" |
183 | R"==(blockA[i] = src1[i * SW * G * IC + sglid]; )==" "\n" |
184 | R"==(} )==" "\n" |
185 | R"==(} )==" "\n" |
186 | R"==(} else { )==" "\n" |
187 | R"==(blockA = 0.0f; )==" "\n" |
188 | R"==(} )==" "\n" |
189 | R"==(#else )==" "\n" |
190 | R"==(__attribute__((opencl_unroll_hint(8))) )==" "\n" |
191 | R"==(for (int i = 0; i < OW_BLOCK; i++) { )==" "\n" |
192 | R"==(if (iw + i * SW < 0 || iw + i * SW >= IW) { )==" "\n" |
193 | R"==(blockA[i] = 0; )==" "\n" |
194 | R"==(} else { )==" "\n" |
195 | R"==(blockA[i] = read_ic_block( )==" "\n" |
196 | R"==(&src1[i * SW * G * IC], ic * IC_BLOCK); )==" "\n" |
197 | R"==(} )==" "\n" |
198 | R"==(} )==" "\n" |
199 | R"==(#endif )==" "\n" |
// Load blockB from diff_dst: element i holds the value for output point
// ow + i, or 0 when ow + i is past OW.
200 | R"==(__attribute__((opencl_unroll_hint(8))) )==" "\n" |
201 | R"==(for (int i = 0; i < OW_BLOCK; i++) { )==" "\n" |
202 | R"==(if (ow + i >= OW) { )==" "\n" |
203 | R"==(blockB[i] = 0.0; )==" "\n" |
204 | R"==(} else { )==" "\n" |
205 | R"==(blockB[i] = read_oc_block( )==" "\n" |
206 | R"==(&diff_dst1[(ow + i) * G * OC_WO_PADDING], )==" "\n" |
207 | R"==(oc * OC_BLOCK); )==" "\n" |
208 | R"==(} )==" "\n" |
209 | R"==(} )==" "\n" |
// Accumulate the products into the weight accumulators. The depthwise case is
// a plain per-lane dot product; the general case fills blockC00 from lanes
// 0..7 and blockC01 from lanes 8..15 of blockA.
210 | R"==(#if IC == 3 )==" "\n" |
211 | R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0); )==" "\n" |
212 | R"==(#elif IS_DW )==" "\n" |
213 | R"==(for (int i = 0; i < OW_BLOCK; i++) { )==" "\n" |
214 | R"==(blockC00 = fma(blockA[i], blockB[i], blockC00); )==" "\n" |
215 | R"==(} )==" "\n" |
216 | R"==(#else )==" "\n" |
217 | R"==(MULTIPLY_BLOCKS_8x8(blockC00, blockA, blockB, 0); )==" "\n" |
218 | R"==(MULTIPLY_BLOCKS_8x8(blockC01, blockA, blockB, 8); )==" "\n" |
219 | R"==(#endif )==" "\n" |
// Bias partial sum over the entries of blockB; it is accumulated by every
// work-item but only published when do_bias is set.
// NOTE(review): this loop bound is the literal 8 while the padding-row path
// above uses OW_BLOCK; the two agree only if OW_BLOCK == 8 - confirm.
220 | R"==(#if WITH_BIAS == 1 )==" "\n" |
221 | R"==(for (int i = 0; i < 8; i++) )==" "\n" |
222 | R"==(bias_loc += blockB[i]; )==" "\n" |
223 | R"==(#endif )==" "\n" |
224 | R"==(} )==" "\n" |
225 | R"==(} )==" "\n" |
// Advance src to the next minibatch image.
226 | R"==(src += ID * IH * IW * G * IC; )==" "\n" |
227 | R"==(} )==" "\n" |
// Publish the bias partial sum atomically, skipping lanes whose channel index
// is past the real channel count (G when depthwise, OC_WO_PADDING otherwise).
228 | R"==(#if WITH_BIAS == 1 )==" "\n" |
229 | R"==(if (do_bias && oc * OC_BLOCK + sglid < (IS_DW ? G : OC_WO_PADDING)) )==" "\n" |
230 | R"==(atomic_add_global(diff_bias, bias_loc); )==" "\n" |
231 | R"==(#endif )==" "\n" |
// Publish the weight partial sums atomically. Each branch first offsets
// diff_wei to this work-item's group, channel block and filter position,
// then adds one value per lane per row: 3 rows for IC == 3, 1 value for the
// depthwise case, and 16 rows (blockC00 then blockC01) otherwise.
232 | R"==(#if IC == 3 )==" "\n" |
233 | R"==(diff_wei += g * oc_padded * ic_padded * KD * KH * KW; )==" "\n" |
234 | R"==(diff_wei += oc * KD * KH * KW * ic_padded * OC_BLOCK; )==" "\n" |
235 | R"==(diff_wei += (kd * KH * KW + kh * KW + kw) * ic_padded * OC_BLOCK; )==" "\n" |
236 | R"==(for (int i = 0; i < 3; i++) )==" "\n" |
237 | R"==(atomic_add_global(diff_wei + i * OC_BLOCK + sglid, blockC00[i]); )==" "\n" |
238 | R"==(#elif IS_DW )==" "\n" |
239 | R"==(diff_wei += oc * KD * KH * KW * OC_BLOCK; )==" "\n" |
240 | R"==(diff_wei += (kd * KH * KW + kh * KW + kw) * OC_BLOCK; )==" "\n" |
241 | R"==(atomic_add_global(diff_wei + sglid, blockC00); )==" "\n" |
242 | R"==(#else )==" "\n" |
243 | R"==(diff_wei += g * ic_padded * oc_padded * KD * KH * KW; )==" "\n" |
244 | R"==(diff_wei += ic * oc_padded * KD * KH * KW * IC_BLOCK; )==" "\n" |
245 | R"==(diff_wei += oc * KD * KH * KW * IC_BLOCK * OC_BLOCK; )==" "\n" |
246 | R"==(diff_wei += (kd * KH * KW + kh * KW + kw) * IC_BLOCK * OC_BLOCK; )==" "\n" |
247 | R"==(for (int i = 0; i < 8; i++) )==" "\n" |
248 | R"==(atomic_add_global(diff_wei + i * OC_BLOCK + sglid, blockC00[i]); )==" "\n" |
249 | R"==(for (int i = 0; i < 8; i++) )==" "\n" |
250 | R"==(atomic_add_global(diff_wei + (8 + i) * OC_BLOCK + sglid, blockC01[i]); )==" "\n" |
251 | R"==(#endif )==" "\n" |
252 | R"==(} )==" "\n" |
253 | R"==(#endif )==" "\n" |
// Empty final literal terminates the concatenated kernel text.
254 | R"==()==" ; |
255 | } |
256 | } |
257 | } |
258 | } |