1 | namespace dnnl { |
2 | namespace impl { |
3 | namespace gpu { |
4 | namespace ocl { |
5 | const char *ref_rnn_kernel = R"==(/******************************************************************************* )==" "\n" |
6 | R"==(* Copyright 2019-2022 Intel Corporation )==" "\n" |
7 | R"==(* )==" "\n" |
8 | R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n" |
9 | R"==(* you may not use this file except in compliance with the License. )==" "\n" |
10 | R"==(* You may obtain a copy of the License at )==" "\n" |
11 | R"==(* )==" "\n" |
12 | R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==" "\n" |
13 | R"==(* )==" "\n" |
14 | R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n" |
15 | R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n" |
16 | R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n" |
17 | R"==(* See the License for the specific language governing permissions and )==" "\n" |
18 | R"==(* limitations under the License. )==" "\n" |
19 | R"==(*******************************************************************************/ )==" "\n" |
20 | R"==(#include "gpu/ocl/rnn/rnn_types.h" )==" "\n" |
// OpenCL helper one_m_square: returns 1 - a * a.
21 | R"==(float one_m_square(float a) { )==" "\n" |
22 | R"==(return 1.0f - a * a; )==" "\n" |
23 | R"==(} )==" "\n" |
// OpenCL helper x_m_square: returns (1 - a) * a.
24 | R"==(float x_m_square(float a) { )==" "\n" |
25 | R"==(return (1.0f - a) * a; )==" "\n" |
26 | R"==(} )==" "\n" |
// OpenCL helper relu_fwd: returns s when s > 0, otherwise s * alpha
// (alpha acts as the slope applied to non-positive inputs).
27 | R"==(float relu_fwd(float s, float alpha) { )==" "\n" |
28 | R"==(return s > 0 ? s : s * alpha; )==" "\n" |
29 | R"==(} )==" "\n" |
// OpenCL helper tanh_fwd: thin wrapper that forwards to tanh(s).
30 | R"==(float tanh_fwd(float s) { )==" "\n" |
31 | R"==(return tanh(s); )==" "\n" |
32 | R"==(} )==" "\n" |
// OpenCL helper logistic_fwd: logistic function 1 / (1 + exp(-s)).
33 | R"==(float logistic_fwd(float s) { )==" "\n" |
34 | R"==(return 1 / (1 + exp((float)-s)); )==" "\n" |
35 | R"==(} )==" "\n" |
// OpenCL helper logistic_bwd: delegates to x_m_square, i.e. returns (1 - s) * s.
// NOTE(review): presumably s is an already computed logistic output rather
// than the pre-activation value - confirm against the callers.
36 | R"==(float logistic_bwd(float s) { )==" "\n" |
37 | R"==(return x_m_square(s); )==" "\n" |
38 | R"==(} )==" "\n" |
// OpenCL helper relu_bwd: returns 1 when s > 0, otherwise alpha.
39 | R"==(float relu_bwd(float s, float alpha) { )==" "\n" |
40 | R"==(return s > 0 ? 1.f : alpha; )==" "\n" |
41 | R"==(} )==" "\n" |
// OpenCL helper tanh_bwd: returns (1 - s) * (1 + s), algebraically 1 - s * s.
// NOTE(review): presumably s is an already computed tanh output - confirm
// against the callers.
42 | R"==(float tanh_bwd(float s) { )==" "\n" |
43 | R"==(return (1 - s) * (1 + s); )==" "\n" |
44 | R"==(} )==" "\n" |
// OpenCL helper linear: returns alpha * s.
45 | R"==(float linear(float s, float alpha) { )==" "\n" |
46 | R"==(return alpha * s; )==" "\n" |
47 | R"==(} )==" "\n" |
// OpenCL helper relu_fwd_tm: choice is made when the kernel is compiled.
// With IS_TESTMODE equal to 0 it evaluates relu_fwd(s, alpha); otherwise it
// evaluates linear(s, alpha), i.e. alpha * s.
48 | R"==(float relu_fwd_tm(float s, float alpha) { )==" "\n" |
49 | R"==(#if !IS_TESTMODE )==" "\n" |
50 | R"==(return relu_fwd(s, alpha); )==" "\n" |
51 | R"==(#else )==" "\n" |
52 | R"==(return linear(s, alpha); )==" "\n" |
53 | R"==(#endif )==" "\n" |
54 | R"==(} )==" "\n" |
// OpenCL helper tanh_fwd_tm: with IS_TESTMODE equal to 0 it returns tanh(s)
// and alpha is not used; otherwise it returns linear(s, alpha), i.e. alpha * s.
// Note that this calls tanh directly rather than going through tanh_fwd.
55 | R"==(float tanh_fwd_tm(float s, float alpha) { )==" "\n" |
56 | R"==(#if !IS_TESTMODE )==" "\n" |
57 | R"==(return tanh(s); )==" "\n" |
58 | R"==(#else )==" "\n" |
59 | R"==(return linear(s, alpha); )==" "\n" |
60 | R"==(#endif )==" "\n" |
61 | R"==(} )==" "\n" |
// OpenCL helper logistic_fwd_tm: with IS_TESTMODE equal to 0 it returns
// logistic_fwd(s) and alpha is not used; otherwise it returns
// linear(s, alpha), i.e. alpha * s.
62 | R"==(float logistic_fwd_tm(float s, float alpha) { )==" "\n" |
63 | R"==(#if !IS_TESTMODE )==" "\n" |
64 | R"==(return logistic_fwd(s); )==" "\n" |
65 | R"==(#else )==" "\n" |
66 | R"==(return linear(s, alpha); )==" "\n" |
67 | R"==(#endif )==" "\n" |
68 | R"==(} )==" "\n" |
// OpenCL helper relu_bwd_tm: with IS_TESTMODE equal to 0 it returns
// relu_bwd(s, alpha); otherwise it returns linear(s, alpha), i.e. alpha * s.
69 | R"==(float relu_bwd_tm(float s, float alpha) { )==" "\n" |
70 | R"==(#if !IS_TESTMODE )==" "\n" |
71 | R"==(return relu_bwd(s, alpha); )==" "\n" |
72 | R"==(#else )==" "\n" |
73 | R"==(return linear(s, alpha); )==" "\n" |
74 | R"==(#endif )==" "\n" |
75 | R"==(} )==" "\n" |
// OpenCL helper tanh_bwd_tm: with IS_TESTMODE equal to 0 it returns
// tanh_bwd(s) and alpha is not used; otherwise it returns linear(s, alpha),
// i.e. alpha * s.
76 | R"==(float tanh_bwd_tm(float s, float alpha) { )==" "\n" |
77 | R"==(#if !IS_TESTMODE )==" "\n" |
78 | R"==(return tanh_bwd(s); )==" "\n" |
79 | R"==(#else )==" "\n" |
80 | R"==(return linear(s, alpha); )==" "\n" |
81 | R"==(#endif )==" "\n" |
82 | R"==(} )==" "\n" |
// OpenCL helper logistic_bwd_tm: with IS_TESTMODE equal to 0 it returns
// logistic_bwd(s) and alpha is not used; otherwise it returns
// linear(s, alpha), i.e. alpha * s.
83 | R"==(float logistic_bwd_tm(float s, float alpha) { )==" "\n" |
84 | R"==(#if !IS_TESTMODE )==" "\n" |
85 | R"==(return logistic_bwd(s); )==" "\n" |
86 | R"==(#else )==" "\n" |
87 | R"==(return linear(s, alpha); )==" "\n" |
88 | R"==(#endif )==" "\n" |
89 | R"==(} )==" "\n" |
// OpenCL helper activation_fwd: forward activation selected when the kernel
// is compiled.
//   s       - input value
//   alpha   - forwarded to the selected *_fwd_tm helper
//   cliping - not used by this function
// When CELL_KIND == VANILLA_RNN the helper is picked by ACTIVATION_KIND
// (ELTWISE_RELU, ELTWISE_TANH or ELTWISE_LOGISTIC); any other value stops the
// kernel build with an #error. For every other CELL_KIND the function
// returns 0.0f.
90 | R"==(float activation_fwd(float s, float alpha, float cliping) { )==" "\n" |
91 | R"==(#if CELL_KIND == VANILLA_RNN )==" "\n" |
92 | R"==(#if ACTIVATION_KIND == ELTWISE_RELU )==" "\n" |
93 | R"==(return relu_fwd_tm(s, alpha); )==" "\n" |
94 | R"==(#elif ACTIVATION_KIND == ELTWISE_TANH )==" "\n" |
95 | R"==(return tanh_fwd_tm(s, alpha); )==" "\n" |
96 | R"==(#elif ACTIVATION_KIND == ELTWISE_LOGISTIC )==" "\n" |
97 | R"==(return logistic_fwd_tm(s, alpha); )==" "\n" |
98 | R"==(#else )==" "\n" |
99 | R"==(#error "Unsupported activation_kind" )==" "\n" |
100 | R"==(#endif )==" "\n" |
101 | R"==(#else )==" "\n" |
102 | R"==(return 0.0f; )==" "\n" |
103 | R"==(#endif )==" "\n" |
104 | R"==(} )==" "\n" |
// OpenCL helper activation_bwd: backward counterpart of activation_fwd,
// selected when the kernel is compiled.
//   s       - input value
//   alpha   - forwarded to the selected *_bwd_tm helper
//   cliping - not used by this function
// When CELL_KIND == VANILLA_RNN the helper is picked by ACTIVATION_KIND
// (ELTWISE_RELU, ELTWISE_TANH or ELTWISE_LOGISTIC); any other value stops the
// kernel build with an #error. For every other CELL_KIND the function
// returns 0.0f.
105 | R"==(float activation_bwd(float s, float alpha, float cliping) { )==" "\n" |
106 | R"==(#if CELL_KIND == VANILLA_RNN )==" "\n" |
107 | R"==(#if ACTIVATION_KIND == ELTWISE_RELU )==" "\n" |
108 | R"==(return relu_bwd_tm(s, alpha); )==" "\n" |
109 | R"==(#elif ACTIVATION_KIND == ELTWISE_TANH )==" "\n" |
110 | R"==(return tanh_bwd_tm(s, alpha); )==" "\n" |
111 | R"==(#elif ACTIVATION_KIND == ELTWISE_LOGISTIC )==" "\n" |
112 | R"==(return logistic_bwd_tm(s, alpha); )==" "\n" |
113 | R"==(#else )==" "\n" |
114 | R"==(#error "Unsupported activation_kind" )==" "\n" |
115 | R"==(#endif )==" "\n" |
116 | R"==(#else )==" "\n" |
117 | R"==(return 0.0f; )==" "\n" |
118 | R"==(#endif )==" "\n" |
119 | R"==(} )==" "\n" |
// Kernel ref_rnn_copy_init_layer: copies the layer-level input into the
// buffers the RNN computation works on.
//   ws                  - workspace; the forward destination starts at byte
//                         offset WS_STATES_OFFSET
//   src_base            - source buffer, read as WS_STATE_DATA_T in the forward
//                         build and as DIFF_DATA_T in the backward build
//   scratch_diff_states - backward destination, written as DIFF_DATA_T
//   lr, rl              - read only in the forward build: non-zero enables the
//                         copy into direction 0 / direction N_DIR - 1
// Forward build (IS_FWD): one work-item per (c, b, it) = global ids (0, 1, 2),
// each copying a single element.
// Backward build: one work-item per (b, it) = global ids (0, 1), each looping
// over DHC elements; the layout of the copy depends on DIRECTION_KIND and an
// unknown DIRECTION_KIND stops the kernel build with an #error.
120 | R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==" "\n" |
121 | R"==(ref_rnn_copy_init_layer(__global char *ws, __global char *src_base, )==" "\n" |
122 | R"==(__global char *scratch_diff_states, int lr, int rl) { )==" "\n" |
123 | R"==(#if IS_FWD )==" "\n" |
124 | R"==(const int it = get_global_id(2); )==" "\n" |
125 | R"==(const int b = get_global_id(1); )==" "\n" |
126 | R"==(const int c = get_global_id(0); )==" "\n" |
// Work-items outside the SLC x BATCH x N_ITER range do nothing.
127 | R"==(if (c >= SLC || b >= BATCH || it >= N_ITER) return; )==" "\n" |
128 | R"==(__global WS_STATE_DATA_T *dst; )==" "\n" |
129 | R"==(__global WS_STATE_DATA_T *dst_base )==" "\n" |
130 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==" "\n" |
131 | R"==(__global WS_STATE_DATA_T *src = (__global WS_STATE_DATA_T *)src_base )==" "\n" |
132 | R"==(+ SRC_L_OFF(it, 0, 0) + b * SLC + c; )==" "\n" |
// Direction 0: the element of time step it lands in iteration slot it + 1 of
// layer slot 0.
133 | R"==(if (lr) { )==" "\n" |
134 | R"==(dst = dst_base + OFF_WS_STATE(0, 0, it + 1, b, c); )==" "\n" |
135 | R"==(dst[0] = src[0]; )==" "\n" |
136 | R"==(} )==" "\n" |
// Direction N_DIR - 1: the same element lands in slot N_ITER - it, so the
// slot index decreases as it increases.
137 | R"==(if (rl) { )==" "\n" |
138 | R"==(dst = dst_base + OFF_WS_STATE(0, N_DIR - 1, N_ITER - it, b, c); )==" "\n" |
139 | R"==(dst[0] = src[0]; )==" "\n" |
140 | R"==(} )==" "\n" |
141 | R"==(#else )==" "\n" |
142 | R"==(const int it = get_global_id(1); )==" "\n" |
143 | R"==(const int b = get_global_id(0); )==" "\n" |
144 | R"==(if (b >= BATCH || it >= N_ITER) return; )==" "\n" |
145 | R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)scratch_diff_states; )==" "\n" |
// CONCAT: source elements [0, DHC) go to direction 0 at it, elements
// [DHC, 2 * DHC) go to direction 1 at N_ITER - it - 1.
146 | R"==(#if DIRECTION_KIND == CONCAT )==" "\n" |
147 | R"==(__global DIFF_DATA_T *src )==" "\n" |
148 | R"==(= (__global DIFF_DATA_T *)src_base + DIFF_DST_L_OFF(it, b, 0); )==" "\n" |
149 | R"==(for (int s = 0; s < DHC; s++) { )==" "\n" |
150 | R"==(dst[OFF_SCRATCH_DIFF_STATES(N_LAYER, 0, N_STATES, it, b, s)] = src[s]; )==" "\n" |
151 | R"==(dst[OFF_SCRATCH_DIFF_STATES( )==" "\n" |
152 | R"==(N_LAYER, 1, N_STATES, N_ITER - it - 1, b, s)] )==" "\n" |
153 | R"==(= src[DHC + s]; )==" "\n" |
154 | R"==(} )==" "\n" |
// SUM: the same source element is written to both directions.
155 | R"==(#elif DIRECTION_KIND == SUM )==" "\n" |
156 | R"==(__global DIFF_DATA_T *src )==" "\n" |
157 | R"==(= (__global DIFF_DATA_T *)src_base + DIFF_DST_L_OFF(it, b, 0); )==" "\n" |
158 | R"==(for (int s = 0; s < DHC; s++) { )==" "\n" |
159 | R"==(dst[OFF_SCRATCH_DIFF_STATES(N_LAYER, 0, N_STATES, it, b, s)] = src[s]; )==" "\n" |
160 | R"==(dst[OFF_SCRATCH_DIFF_STATES( )==" "\n" |
161 | R"==(N_LAYER, 1, N_STATES, N_ITER - it - 1, b, s)] )==" "\n" |
162 | R"==(= src[s]; )==" "\n" |
163 | R"==(} )==" "\n" |
// L2R: only direction 0 is written, in the same time order as the source.
164 | R"==(#elif DIRECTION_KIND == L2R )==" "\n" |
165 | R"==(__global DIFF_DATA_T *src )==" "\n" |
166 | R"==(= (__global DIFF_DATA_T *)src_base + DIFF_DST_L_OFF(it, b, 0); )==" "\n" |
167 | R"==(for (int s = 0; s < DHC; s++) { )==" "\n" |
168 | R"==(dst[OFF_SCRATCH_DIFF_STATES(N_LAYER, 0, N_STATES, it, b, s)] = src[s]; )==" "\n" |
169 | R"==(} )==" "\n" |
// R2L: only direction 0 is written; the source is read at N_ITER - it - 1.
170 | R"==(#elif DIRECTION_KIND == R2L )==" "\n" |
171 | R"==(__global DIFF_DATA_T *src = (__global DIFF_DATA_T *)src_base )==" "\n" |
172 | R"==(+ DIFF_DST_L_OFF(N_ITER - it - 1, b, 0); )==" "\n" |
173 | R"==(for (int s = 0; s < DHC; s++) { )==" "\n" |
174 | R"==(dst[OFF_SCRATCH_DIFF_STATES(N_LAYER, 0, N_STATES, it, b, s)] = src[s]; )==" "\n" |
175 | R"==(} )==" "\n" |
176 | R"==(#else )==" "\n" |
177 | R"==(#error "Unsupported direction_kind" )==" "\n" |
178 | R"==(#endif )==" "\n" |
179 | R"==(#endif )==" "\n" |
180 | R"==(} )==" "\n" |
// Kernel ref_rnn_copy_init_iter: writes the initial per-layer, per-direction
// states, substituting zero when the corresponding source pointer is null.
//   ws                  - workspace (forward destinations at WS_STATES_OFFSET
//                         and WS_C_STATE_OFFSET)
//   src_base            - source states; may be null
//   src_c_base          - second source, read only under WITH_SRC_ITER_C
//                         (forward) or WITH_DST_ITER_C (backward); may be null
//   scratch_diff_states - backward destination
//   shift, scale, quantize - exist only in the forward build; when quantize is
//                         non-zero the value stored is src * scale + shift
// Work-item mapping: s = global id 0, b = global id 1, and global id 2 is
// split into lay = id / N_DIR and dir = id % N_DIR.
// NOTE(review): b and the layer/direction index are not range-checked here -
// presumably the launch size matches them exactly; confirm on the host side.
181 | R"==(__kernel void ref_rnn_copy_init_iter(__global char *ws, __global char *src_base, )==" "\n" |
182 | R"==(__global char *src_c_base, __global char *scratch_diff_states )==" "\n" |
183 | R"==(#if IS_FWD )==" "\n" |
184 | R"==(, )==" "\n" |
185 | R"==(const float shift, const float scale, const int quantize )==" "\n" |
186 | R"==(#endif )==" "\n" |
187 | R"==() { )==" "\n" |
188 | R"==(const int s = get_global_id(0); )==" "\n" |
189 | R"==(const int b = get_global_id(1); )==" "\n" |
190 | R"==(const int lay = get_global_id(2) / N_DIR; )==" "\n" |
191 | R"==(const int dir = get_global_id(2) % N_DIR; )==" "\n" |
192 | R"==(#if IS_FWD )==" "\n" |
193 | R"==(__global INPUT_DATA_T *src = (__global INPUT_DATA_T *)(src_base); )==" "\n" |
194 | R"==(__global WS_STATE_DATA_T *dst )==" "\n" |
195 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==" "\n" |
// Destination is layer slot lay + 1, iteration slot 0. Three cases: null
// source gives zero, quantize gives src * scale + shift, otherwise a plain copy.
196 | R"==(if (s < SIC) )==" "\n" |
197 | R"==(dst[OFF_WS_STATE(lay + 1, dir, 0, b, s)] = src_base )==" "\n" |
198 | R"==(? (quantize ? TO_WS_STATE( )==" "\n" |
199 | R"==(src[SRC_I_OFF(lay, dir, b, s)] * scale + shift) )==" "\n" |
200 | R"==(: src[SRC_I_OFF(lay, dir, b, s)]) )==" "\n" |
201 | R"==(: TO_WS_STATE(0.0f); )==" "\n" |
202 | R"==(#if WITH_SRC_ITER_C )==" "\n" |
203 | R"==(__global AUX_DATA_T *src_c = (__global AUX_DATA_T *)(src_c_base); )==" "\n" |
204 | R"==(__global AUX_DATA_T *dst_c )==" "\n" |
205 | R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET); )==" "\n" |
// The c-state copy is never quantized; a null src_c_base yields the value of
// TO_WS_STATE(0.0f) stored into an AUX_DATA_T element.
206 | R"==(if (s < DHC) )==" "\n" |
207 | R"==(dst_c[OFF_WS_STATE(lay + 1, dir, 0, b, s)] = src_c_base )==" "\n" |
208 | R"==(? src_c[SRC_I_C_OFF(lay, dir, b, s)] )==" "\n" |
209 | R"==(: TO_WS_STATE(0.0f); )==" "\n" |
210 | R"==(#endif )==" "\n" |
211 | R"==(#else )==" "\n" |
212 | R"==(__global DIFF_DATA_T *src = (__global DIFF_DATA_T *)(src_base); )==" "\n" |
213 | R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)scratch_diff_states; )==" "\n" |
// Backward: the diff state is written with third index 0 and fourth index
// N_ITER of OFF_SCRATCH_DIFF_STATES.
214 | R"==(if (s < DHC) )==" "\n" |
215 | R"==(dst[OFF_SCRATCH_DIFF_STATES(lay, dir, 0, N_ITER, b, s)] )==" "\n" |
216 | R"==(= src_base ? src[DIFF_DST_I_OFF(lay, dir, b, s)] : 0.0f; )==" "\n" |
217 | R"==(#if WITH_DST_ITER_C )==" "\n" |
218 | R"==(__global DIFF_DATA_T *src_c = (__global DIFF_DATA_T *)(src_c_base); )==" "\n" |
// Same destination buffer, third index 1, sourced from src_c_base.
219 | R"==(if (s < DHC) )==" "\n" |
220 | R"==(dst[OFF_SCRATCH_DIFF_STATES(lay, dir, 1, N_ITER, b, s)] )==" "\n" |
221 | R"==(= src_c_base ? src_c[DIFF_DST_I_C_OFF(lay, dir, b, s)] : 0.0f; )==" "\n" |
222 | R"==(#endif )==" "\n" |
223 | R"==(#endif )==" "\n" |
224 | R"==(} )==" "\n" |
// Kernel ref_rnn_copy_res_layer: copies the layer-level result out of the
// working buffers into dst_base.
//   ws                  - workspace; forward source at WS_STATES_OFFSET
//   dst_base            - destination buffer
//   scratch_diff_states - backward source
//   lr, rl              - read only in the forward build: non-zero enables the
//                         first / second direction copy
//   shift, scale, dequantize - exist only in the forward build; dequantization
//                         computes (value - shift) / scale
// Work-item mapping: s = global id 0, b = global id 1, it = global id 2.
// Forward build: reads layer slot N_LAYER of the workspace states.
// Backward build: adds up the diff states of the available directions.
225 | R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==" "\n" |
226 | R"==(ref_rnn_copy_res_layer(__global char *ws, __global char *dst_base, )==" "\n" |
227 | R"==(__global char *scratch_diff_states, int lr, int rl )==" "\n" |
228 | R"==(#if IS_FWD )==" "\n" |
229 | R"==(, )==" "\n" |
230 | R"==(const float shift, const float scale, const int dequantize )==" "\n" |
231 | R"==(#endif )==" "\n" |
232 | R"==() { )==" "\n" |
233 | R"==(const int it = get_global_id(2); )==" "\n" |
234 | R"==(const int b = get_global_id(1); )==" "\n" |
235 | R"==(const int s = get_global_id(0); )==" "\n" |
236 | R"==(#if IS_FWD )==" "\n" |
237 | R"==(if (s >= DHC || b >= BATCH || it >= N_ITER) return; )==" "\n" |
238 | R"==(__global WS_STATE_DATA_T *src )==" "\n" |
239 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==" "\n" |
240 | R"==(__global DST_DATA_T *dst = (__global DST_DATA_T *)(dst_base); )==" "\n" |
// dir stays 0 unless the lr branch runs, so with lr == 0 the rl branch below
// reads and writes direction slot 0.
241 | R"==(int dir = 0; )==" "\n" |
242 | R"==(if (lr) { )==" "\n" |
// For DIRECTION_KIND == SUM dequantization is postponed to the rl branch,
// which handles the accumulated value.
243 | R"==(bool dequantize_at_copy = dequantize && DIRECTION_KIND != SUM; )==" "\n" |
244 | R"==(dst[DST_L_OFF(it, b, dir * DHC + s)] = dequantize_at_copy )==" "\n" |
245 | R"==(? TO_DST(((float)src[OFF_WS_STATE(N_LAYER, dir, it + 1, b, s)] )==" "\n" |
246 | R"==(- shift) )==" "\n" |
247 | R"==(/ scale) )==" "\n" |
248 | R"==(: src[OFF_WS_STATE(N_LAYER, dir, it + 1, b, s)]; )==" "\n" |
249 | R"==(dir = 1; )==" "\n" |
250 | R"==(} )==" "\n" |
251 | R"==(if (rl) { )==" "\n" |
252 | R"==(#if DIRECTION_KIND == SUM )==" "\n" |
// SUM with dequantize: add to what dst already holds, clamp the sum to
// [0, 255], then remove the shift twice (once per summand) before scaling.
253 | R"==(if (dequantize) { )==" "\n" |
254 | R"==(float val )==" "\n" |
255 | R"==(= (float)src[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)] )==" "\n" |
256 | R"==(+ dst[DST_L_OFF(it, b, s)]; )==" "\n" |
257 | R"==(val = min(max(val, 0.f), 255.f); )==" "\n" |
258 | R"==(dst[DST_L_OFF(it, b, s)] = TO_DST((val - 2 * shift) / scale); )==" "\n" |
259 | R"==(} else { )==" "\n" |
// SUM without dequantize: u8 source and destination use a saturating add
// carried out in short; all other types accumulate in ACC_DATA_T.
260 | R"==(#if defined(SRC_DT_U8) && defined(DST_DT_U8) )==" "\n" |
261 | R"==(dst[DST_L_OFF(it, b, s)] = convert_uchar_sat( )==" "\n" |
262 | R"==(convert_short( )==" "\n" |
263 | R"==(src[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)]) )==" "\n" |
264 | R"==(+ convert_short(dst[DST_L_OFF(it, b, s)])); )==" "\n" |
265 | R"==(#else )==" "\n" |
266 | R"==(ACC_DATA_T temp_src = DST_TO_REF(dst[DST_L_OFF(it, b, s)]); )==" "\n" |
267 | R"==(temp_src += DST_TO_REF( )==" "\n" |
268 | R"==(src[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)]); )==" "\n" |
269 | R"==(dst[DST_L_OFF(it, b, s)] = REF_TO_DST(temp_src); )==" "\n" |
270 | R"==(#endif )==" "\n" |
271 | R"==(} )==" "\n" |
272 | R"==(#else )==" "\n" |
// Non-SUM kinds: the value read at iteration slot N_ITER - it is stored at
// channel offset dir * DHC + s.
273 | R"==(dst[DST_L_OFF(it, b, dir * DHC + s)] = dequantize )==" "\n" |
274 | R"==(? TO_DST(((float)src[OFF_WS_STATE( )==" "\n" |
275 | R"==(N_LAYER, dir, N_ITER - it, b, s)] )==" "\n" |
276 | R"==(- shift) )==" "\n" |
277 | R"==(/ scale) )==" "\n" |
278 | R"==(: src[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)]; )==" "\n" |
279 | R"==(#endif )==" "\n" |
280 | R"==(} )==" "\n" |
281 | R"==(#else )==" "\n" |
282 | R"==(if (s >= SLC || b >= BATCH || it >= N_ITER) return; )==" "\n" |
283 | R"==(__global DIFF_DATA_T *src = (__global DIFF_DATA_T *)(scratch_diff_states); )==" "\n" |
284 | R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)(dst_base); )==" "\n" |
285 | R"==(int dir = 0; )==" "\n" |
// For R2L the destination time index is mirrored; the source index is not.
286 | R"==(#if DIRECTION_KIND == R2L )==" "\n" |
287 | R"==(const int iter = N_ITER - 1 - it; )==" "\n" |
288 | R"==(#else )==" "\n" |
289 | R"==(const int iter = it; )==" "\n" |
290 | R"==(#endif )==" "\n" |
291 | R"==(DIFF_DATA_T res = src[OFF_SCRATCH_DIFF_STATES(0, 0, N_STATES, it, b, s)]; )==" "\n" |
// With more than one direction, the second direction contributes its value
// at the mirrored iteration N_ITER - 1 - it.
292 | R"==(#if N_DIR > 1 )==" "\n" |
293 | R"==(res += src[OFF_SCRATCH_DIFF_STATES(0, 1, N_STATES, N_ITER - 1 - it, b, s)]; )==" "\n" |
294 | R"==(#endif )==" "\n" |
295 | R"==(dst[DIFF_SRC_L_OFF(iter, b, dir * SLC + s)] = res; )==" "\n" |
296 | R"==(#endif )==" "\n" |
297 | R"==(} )==" "\n" |
// Kernel ref_rnn_copy_res_iter: copies the final per-layer, per-direction
// states out of the working buffers. Each output is optional: a copy is
// skipped when the pointer it writes through is null.
//   ws                  - workspace (forward sources at WS_STATES_OFFSET and
//                         WS_C_STATE_OFFSET)
//   dst_base            - destination for the states; may be null
//   dst_c_base          - destination for the c-states; may be null
//   scratch_diff_states - backward source
//   shift, scale, dequantize - exist only in the forward build; when
//                         dequantize is non-zero the stored value is
//                         (state - shift) / scale
// Work-item mapping: s = global id 0, b = global id 1, and global id 2 is
// split into lay = id / N_DIR and dir = id % N_DIR.
298 | R"==(__kernel void ref_rnn_copy_res_iter(__global char *ws, __global char *dst_base, )==" "\n" |
299 | R"==(__global char *dst_c_base, __global char *scratch_diff_states )==" "\n" |
300 | R"==(#if IS_FWD )==" "\n" |
301 | R"==(, )==" "\n" |
302 | R"==(const float shift, const float scale, const int dequantize )==" "\n" |
303 | R"==(#endif )==" "\n" |
304 | R"==() { )==" "\n" |
305 | R"==(const int s = get_global_id(0); )==" "\n" |
306 | R"==(const int b = get_global_id(1); )==" "\n" |
307 | R"==(const int lay = get_global_id(2) / N_DIR; )==" "\n" |
308 | R"==(const int dir = get_global_id(2) % N_DIR; )==" "\n" |
309 | R"==(#if IS_FWD )==" "\n" |
310 | R"==(__global WS_STATE_DATA_T *src )==" "\n" |
311 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==" "\n" |
312 | R"==(__global OUTPUT_DATA_T *dst = (__global OUTPUT_DATA_T *)(dst_base); )==" "\n" |
// Forward: the state at layer slot lay + 1, iteration slot N_ITER is the
// value copied out.
313 | R"==(if (dst_base && s < DHC) { )==" "\n" |
314 | R"==(dst[DST_I_OFF(lay, dir, b, s)] = dequantize )==" "\n" |
315 | R"==(? TO_OUTPUT( )==" "\n" |
316 | R"==(((float)src[OFF_WS_STATE(lay + 1, dir, N_ITER, b, s)] )==" "\n" |
317 | R"==(- shift) )==" "\n" |
318 | R"==(/ scale) )==" "\n" |
319 | R"==(: TO_OUTPUT(src[OFF_WS_STATE(lay + 1, dir, N_ITER, b, s)]); )==" "\n" |
320 | R"==(} )==" "\n" |
321 | R"==(#if WITH_DST_ITER_C )==" "\n" |
322 | R"==(__global AUX_DATA_T *src_c )==" "\n" |
323 | R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET); )==" "\n" |
324 | R"==(__global AUX_DATA_T *dst_c = (__global AUX_DATA_T *)(dst_c_base); )==" "\n" |
325 | R"==(if (dst_c_base && s < DHC) { )==" "\n" |
326 | R"==(dst_c[DST_I_C_OFF(lay, dir, b, s)] )==" "\n" |
327 | R"==(= src_c[OFF_WS_STATE(lay + 1, dir, N_ITER, b, s)]; )==" "\n" |
328 | R"==(} )==" "\n" |
329 | R"==(#endif )==" "\n" |
330 | R"==(#else )==" "\n" |
331 | R"==(__global DIFF_DATA_T *src = (__global DIFF_DATA_T *)(scratch_diff_states); )==" "\n" |
332 | R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)(dst_base); )==" "\n" |
333 | R"==(__global DIFF_DATA_T *dst_c = (__global DIFF_DATA_T *)(dst_c_base); )==" "\n" |
// Backward: diff states with third index 0 and fourth index 0 of
// OFF_SCRATCH_DIFF_STATES go to dst.
334 | R"==(if (dst_base && s < SIC) { )==" "\n" |
335 | R"==(dst[DIFF_SRC_I_OFF(lay, dir, b, s)] )==" "\n" |
336 | R"==(= src[OFF_SCRATCH_DIFF_STATES(lay, dir, 0, 0, b, s)]; )==" "\n" |
337 | R"==(} )==" "\n" |
338 | R"==(#if WITH_SRC_ITER_C )==" "\n" |
// Guard on dst_c_base, the pointer actually written through here (matching
// the forward branch); testing dst_base instead would allow a store through a
// null dst_c whenever only dst_base is provided.
339 | R"==(if (dst_c_base && s < DHC) { )==" "\n" |
340 | R"==(dst_c[DIFF_SRC_I_C_OFF(lay, dir, b, s)] )==" "\n" |
341 | R"==(= src[OFF_SCRATCH_DIFF_STATES(lay, dir, 1, 0, b, s)]; )==" "\n" |
342 | R"==(} )==" "\n" |
343 | R"==(#endif )==" "\n" |
344 | R"==(#endif )==" "\n" |
345 | R"==(} )==" "\n" |
// Kernel ref_rnn_ws_set: fills one workspace element per work-item with val.
//   ws        - workspace base pointer
//   ws_offset - byte offset of the region to fill
//   val       - fill value, converted to the element type of the region
//   ws_part   - selects the element type: WS_C_STATES or WS_BIAS use
//               DIFF_DATA_T, WS_GATES uses ACC_DATA_T, anything else uses
//               WS_STATE_DATA_T
// The element index is get_global_id(0); there is no range check in the
// kernel, so the launch size determines how many elements are written.
346 | R"==(__kernel void ref_rnn_ws_set( )==" "\n" |
347 | R"==(__global char *ws, OFFTYPE ws_offset, float val, int ws_part) { )==" "\n" |
348 | R"==(if (ws_part == WS_C_STATES || ws_part == WS_BIAS) { )==" "\n" |
349 | R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)(ws + ws_offset); )==" "\n" |
350 | R"==(dst[get_global_id(0)] = CONVERT_DATA_T(val); )==" "\n" |
351 | R"==(} else if (ws_part == WS_GATES) { )==" "\n" |
352 | R"==(__global ACC_DATA_T *dst = (__global ACC_DATA_T *)(ws + ws_offset); )==" "\n" |
353 | R"==(dst[get_global_id(0)] = TO_ACC(val); )==" "\n" |
354 | R"==(} else { )==" "\n" |
355 | R"==(__global WS_STATE_DATA_T *dst )==" "\n" |
356 | R"==(= (__global WS_STATE_DATA_T *)(ws + ws_offset); )==" "\n" |
357 | R"==(dst[get_global_id(0)] = TO_WS_STATE(val); )==" "\n" |
358 | R"==(} )==" "\n" |
359 | R"==(} )==" "\n" |
// Kernel ref_rnn_ws_print: debugging aid, present only when DEBUGPRINT is
// non-zero. It prints the contents of several workspace regions with printf:
// gates, hidden states, and, depending on build macros, the grid computation
// region (IS_TRAINING with LBR_GRU), the c-states (IS_FWD with VANILLA_LSTM)
// and the bias copy (COPY_BIAS).
//   ws - workspace base pointer; every region is located through its
//        WS_*_OFFSET macro
// NOTE(review): the body has no work-item guard, so every work-item of the
// launch prints the full dump - presumably it is enqueued with a single
// work-item; confirm on the host side.
// NOTE(review): the offsets are printed with %d - confirm the WS_*_OFFSET
// macros expand to int-sized values.
360 | R"==(#if DEBUGPRINT )==" "\n" |
361 | R"==(__kernel void ref_rnn_ws_print(const __global char *ws) { )==" "\n" |
// Gates region: one output line per (layer, direction, iteration, batch).
362 | R"==({ )==" "\n" |
363 | R"==(__global ACC_DATA_T *wt = (__global ACC_DATA_T *)(ws + WS_GATES_OFFSET); )==" "\n" |
364 | R"==(printf("ws_gates: off %d\n", WS_GATES_OFFSET); )==" "\n" |
365 | R"==(printf("[lay,dir,iter,batch]\n"); )==" "\n" |
366 | R"==(for_(int j = 0; j < N_LAYER; j++) )==" "\n" |
367 | R"==(for_(int dir = 0; dir < N_DIR; dir++) )==" "\n" |
368 | R"==(for_(int i = 0; i < N_ITER; i++) )==" "\n" |
369 | R"==(for (int b = 0; b < BATCH; b++) { )==" "\n" |
370 | R"==(printf("[%d,%d,%d,%d]: ", j, dir, i, b); )==" "\n" |
371 | R"==(for_(int g = 0; g < N_GATES; g++) )==" "\n" |
372 | R"==(for (int s = 0; s < DHC; s++) { )==" "\n" |
373 | R"==(printf(" %f", )==" "\n" |
374 | R"==(SRC_TO_REF(*(wt + OFF_WS_GATES(j, dir, i, b, g, s)))); )==" "\n" |
375 | R"==(} )==" "\n" |
376 | R"==(printf("\n"); )==" "\n" |
377 | R"==(} )==" "\n" |
378 | R"==(} )==" "\n" |
// Hidden-state region: layer slots 0..N_LAYER, iteration slots 1..N_ITER
// (slot 0 is not printed).
379 | R"==({ )==" "\n" |
380 | R"==(__global WS_STATE_DATA_T *wt )==" "\n" |
381 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==" "\n" |
382 | R"==(printf("ws_states (H): off %d\n", WS_STATES_OFFSET); )==" "\n" |
383 | R"==(printf("[lay,dir,iter]\n"); )==" "\n" |
384 | R"==(for_(int j = 0; j < N_LAYER + 1; j++) )==" "\n" |
385 | R"==(for_(int dir = 0; dir < N_DIR; dir++) )==" "\n" |
386 | R"==(for (int i = 1; i < N_ITER + 1; i++) { )==" "\n" |
387 | R"==(printf("[%d,%d,%d] : ", j, dir, i); )==" "\n" |
388 | R"==(for_(int b = 0; b < BATCH; b++) )==" "\n" |
389 | R"==(for (int s = 0; s < WIC; s++) { )==" "\n" |
390 | R"==(printf(" %f", )==" "\n" |
391 | R"==(SRC_TO_REF(*(wt + OFF_WS_STATE(j, dir, i, b, s)))); )==" "\n" |
392 | R"==(} )==" "\n" |
393 | R"==(printf("\n"); )==" "\n" |
394 | R"==(} )==" "\n" |
395 | R"==(} )==" "\n" |
// Grid computation region, only for training builds of the LBR_GRU cell.
396 | R"==(#if IS_TRAINING && CELL_KIND == LBR_GRU )==" "\n" |
397 | R"==({ )==" "\n" |
398 | R"==(__global ACC_DATA_T *wt )==" "\n" |
399 | R"==(= (__global ACC_DATA_T *)(ws + WS_GRID_COMP_OFFSET); )==" "\n" |
400 | R"==(printf("ws_grid: off %d\n", WS_GRID_COMP_OFFSET); )==" "\n" |
401 | R"==(printf("[lay,dir,iter,batch]\n"); )==" "\n" |
402 | R"==(for_(int j = 0; j < N_LAYER; j++) )==" "\n" |
403 | R"==(for_(int dir = 0; dir < N_DIR; dir++) )==" "\n" |
404 | R"==(for_(int i = 0; i < N_ITER; i++) )==" "\n" |
405 | R"==(for (int b = 0; b < BATCH; b++) { )==" "\n" |
406 | R"==(printf("[%d,%d,%d,%d]: ", j, dir, i, b); )==" "\n" |
407 | R"==(for (int s = 0; s < DHC; s++) { )==" "\n" |
408 | R"==(printf(" %f", *(wt + OFF_WS_GRID_OFFSET(j, dir, i, b, s))); )==" "\n" |
409 | R"==(} )==" "\n" |
410 | R"==(printf("\n"); )==" "\n" |
411 | R"==(} )==" "\n" |
412 | R"==(} )==" "\n" |
413 | R"==(#endif )==" "\n" |
// C-state region, only for forward builds of the VANILLA_LSTM cell; here the
// layer index runs 0..N_LAYER-1 and the iteration slot runs 0..N_ITER.
414 | R"==(#if IS_FWD && CELL_KIND == VANILLA_LSTM )==" "\n" |
415 | R"==({ )==" "\n" |
416 | R"==(__global AUX_DATA_T *wt )==" "\n" |
417 | R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET); )==" "\n" |
418 | R"==(printf("ws_states (C): off %d\n", WS_C_STATE_OFFSET); )==" "\n" |
419 | R"==(printf("[lay,dir,iter]\n"); )==" "\n" |
420 | R"==(for_(int j = 0; j < N_LAYER; j++) )==" "\n" |
421 | R"==(for_(int dir = 0; dir < N_DIR; dir++) )==" "\n" |
422 | R"==(for (int i = 0; i < N_ITER + 1; i++) { )==" "\n" |
423 | R"==(printf("[%d,%d,%d] : ", j, dir, i); )==" "\n" |
424 | R"==(for_(int b = 0; b < BATCH; b++) )==" "\n" |
425 | R"==(for (int s = 0; s < WIC; s++) { )==" "\n" |
426 | R"==(printf(" %f", *(wt + OFF_WS_STATE(j, dir, i, b, s))); )==" "\n" |
427 | R"==(} )==" "\n" |
428 | R"==(printf("\n"); )==" "\n" |
429 | R"==(} )==" "\n" |
430 | R"==(} )==" "\n" |
431 | R"==(#endif )==" "\n" |
// Bias copy region, only when COPY_BIAS is non-zero.
432 | R"==(#if COPY_BIAS )==" "\n" |
433 | R"==({ )==" "\n" |
434 | R"==(__global AUX_DATA_T *wt = (__global AUX_DATA_T *)(ws + WS_BIAS_OFFSET); )==" "\n" |
435 | R"==(printf("ws_bias: off %d\n", WS_BIAS_OFFSET); )==" "\n" |
436 | R"==(printf("[lay,dir]\n"); )==" "\n" |
437 | R"==(for_(int j = 0; j < N_LAYER; j++) )==" "\n" |
438 | R"==(for_(int dir = 0; dir < N_DIR; dir++) )==" "\n" |
439 | R"==({ )==" "\n" |
440 | R"==(printf("[%d,%d] : ", j, dir); )==" "\n" |
441 | R"==(for_(int nb = 0; nb < N_BIAS; nb++) )==" "\n" |
442 | R"==(for (int dhc = 0; dhc < DHC; dhc++) { )==" "\n" |
443 | R"==(printf(" %f", *(wt + OFF_WS_BIAS(j, dir, nb, dhc))); )==" "\n" |
444 | R"==(} )==" "\n" |
445 | R"==(printf("\n"); )==" "\n" |
446 | R"==(} )==" "\n" |
447 | R"==(} )==" "\n" |
448 | R"==(#endif )==" "\n" |
449 | R"==(} )==" "\n" |
450 | R"==(#endif )==" "\n" |
// Kernel ref_rnn_bias_prepare: when COPY_BIAS is non-zero, writes an adjusted
// copy of the bias into the workspace; otherwise the body is empty.
//   ws          - workspace; destination at byte offset WS_BIAS_OFFSET
//   scales      - weight scales; indexed by nbias * DHC + dhc when
//                 WEI_QPARAM_MASK is non-zero, otherwise only scales[0] is read
//   wei_layer   - layer weights buffer, followed by float compensation values
//   wei_iter    - iteration weights buffer, followed by float compensation
//                 values
//   bias        - original bias values
//   data_shift, data_scale - data quantization parameters
// Per element: ws_bias = bias - comp * data_shift / (wei_scale * data_scale),
// where comp is the sum of the layer and iteration compensation values.
// Work-item mapping: dhc = global id 0, nbias = global id 1, and global id 2
// is split into layer = id / N_DIR and dir = id % N_DIR.
// The COMP_* macros are defined inside the body and are not undefined
// afterwards.
451 | R"==(__kernel void ref_rnn_bias_prepare(__global char *ws, __global float *scales, )==" "\n" |
452 | R"==(__global char *wei_layer, __global char *wei_iter, __global float *bias, )==" "\n" |
453 | R"==(float data_shift, float data_scale) { )==" "\n" |
454 | R"==(#if COPY_BIAS )==" "\n" |
455 | R"==(const int dhc = get_global_id(0); )==" "\n" |
456 | R"==(const int nbias = get_global_id(1); )==" "\n" |
457 | R"==(const int layer = get_global_id(2) / N_DIR; )==" "\n" |
458 | R"==(const int dir = get_global_id(2) % N_DIR; )==" "\n" |
459 | R"==(__global float *ws_bias = (__global float *)(ws + WS_BIAS_OFFSET); )==" "\n" |
460 | R"==(const float wei_scale )==" "\n" |
461 | R"==(#if WEI_QPARAM_MASK )==" "\n" |
462 | R"==(= scales[nbias * DHC + dhc]; )==" "\n" |
463 | R"==(#else )==" "\n" |
464 | R"==(= scales[0]; )==" "\n" |
465 | R"==(#endif )==" "\n" |
466 | R"==(#define COMP_OFF(i0, i1, i2, i3) \ )==" "\n" |
467 | R"==(((((i0) * (N_DIR) + (i1)) * (N_BIAS) + (i2)) * (DHC) + (i3)) )==" "\n" |
468 | R"==(#define COMP_WEI_LAYER_OFF (WEI_L_D0 * WEI_L_S0) )==" "\n" |
469 | R"==(#define COMP_WEI_ITER_OFF (WEI_I_D0 * WEI_I_S0) )==" "\n" |
// The compensation array starts COMP_WEI_*_OFF bytes into the weights buffer,
// with the address rounded up to the next multiple of sizeof(float).
470 | R"==(__global char *temp = (__global char *)(wei_iter + COMP_WEI_ITER_OFF); )==" "\n" |
471 | R"==(__global float *wei_iter_comp )==" "\n" |
472 | R"==(= (__global float *)(((unsigned long)temp + (sizeof(float) - 1)) )==" "\n" |
473 | R"==(& -sizeof(float)); )==" "\n" |
474 | R"==(temp = (__global char *)(wei_layer + COMP_WEI_LAYER_OFF); )==" "\n" |
475 | R"==(__global float *wei_layer_comp )==" "\n" |
476 | R"==(= (__global float *)(((unsigned long)temp + (sizeof(float) - 1)) )==" "\n" |
477 | R"==(& -sizeof(float)); )==" "\n" |
478 | R"==(const int off = COMP_OFF(layer, dir, nbias, dhc); )==" "\n" |
479 | R"==(const float comp = wei_layer_comp[off] + wei_iter_comp[off]; )==" "\n" |
480 | R"==(ws_bias[OFF_WS_BIAS(layer, dir, nbias, dhc)] )==" "\n" |
481 | R"==(= bias[BIAS_OFF(layer, dir, nbias, dhc)] )==" "\n" |
482 | R"==(- comp * data_shift / (wei_scale * data_scale); )==" "\n" |
483 | R"==(#endif )==" "\n" |
484 | R"==(} )==" "\n" |
485 | R"==(#if IS_INT8 && CELL_KIND == VANILLA_LSTM )==" "\n" |
// OpenCL helper q_d: quantizes f as f * data_scale + data_shift and converts
// the result with TO_WS_STATE.
486 | R"==(WS_STATE_DATA_T q_d(float f, float data_scale, float data_shift) { )==" "\n" |
487 | R"==(float qf = f * data_scale + data_shift; )==" "\n" |
488 | R"==(return TO_WS_STATE(qf); )==" "\n" |
489 | R"==(} )==" "\n" |
// OpenCL helper deq_w: dequantizes an accumulator value by dividing it by
// wei_scale * data_scale.
//   s          - accumulator value to dequantize
//   gate, j    - select the weight scale scales[gate * DHC + j] when
//                WEI_QPARAM_MASK is non-zero; otherwise scales[0] is used and
//                both are ignored
//   scales     - weight scales
//   data_scale - data scale
490 | R"==(float deq_w(ACC_DATA_T s, int gate, int j, __global float *scales, )==" "\n" |
491 | R"==(float data_scale) { )==" "\n" |
492 | R"==(#if WEI_QPARAM_MASK )==" "\n" |
493 | R"==(float wei_scale = scales[gate * DHC + j]; )==" "\n" |
494 | R"==(#else )==" "\n" |
495 | R"==(float wei_scale = scales[0]; )==" "\n" |
496 | R"==(#endif )==" "\n" |
497 | R"==(return (float)(s) / (wei_scale * data_scale); )==" "\n" |
498 | R"==(} )==" "\n" |
// Kernel ref_rnn_elemwise_fwd, variant compiled for IS_INT8 with the
// VANILLA_LSTM cell: element-wise part of one forward cell step.
//   dir, lay, iter - direction, layer and iteration of the cell being computed
//   ws             - workspace holding states, c-states and the bias copy
//   scr_gates      - scratch buffer with the gate accumulators (ACC_DATA_T)
//   scales         - weight scales passed on to deq_w
//   bias_base      - not used by this variant (the workspace bias copy at
//                    WS_BIAS_OFFSET is read instead)
//   alpha          - not used by this variant
//   data_shift, data_scale - data quantization parameters
//   tm_scales      - per-gate values handed to the *_fwd_tm helpers as alpha
//   tm_cscale      - value handed to tanh_fwd_tm for the c-state
// Work-item mapping: j = global id 0 (channel), i = global id 1 (batch).
// The local ws_gates pointer is computed but never read or written here.
499 | R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==" "\n" |
500 | R"==(ref_rnn_elemwise_fwd(int dir, int lay, int iter, __global char *ws, )==" "\n" |
501 | R"==(__global char *scr_gates, __global float *scales, )==" "\n" |
502 | R"==(__global float *bias_base, float alpha, float data_shift, )==" "\n" |
503 | R"==(float data_scale, __global float *tm_scales, float tm_cscale) { )==" "\n" |
504 | R"==(const int i = get_global_id(1); )==" "\n" |
505 | R"==(const int j = get_global_id(0); )==" "\n" |
506 | R"==(if (j >= DHC || i >= BATCH) return; )==" "\n" |
// Previous c-state: layer slot lay + 1, iteration slot iter.
507 | R"==(const __global float *c_states_tm1_l )==" "\n" |
508 | R"==(= (__global float *)(ws + WS_C_STATE_OFFSET) )==" "\n" |
509 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==" "\n" |
510 | R"==(__global float *ws_bias = (__global float *)(ws + WS_BIAS_OFFSET); )==" "\n" |
511 | R"==(__global ACC_DATA_T *ws_gates )==" "\n" |
512 | R"==(= (__global ACC_DATA_T *)(ws + WS_GATES_OFFSET) )==" "\n" |
513 | R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==" "\n" |
514 | R"==(__global ACC_DATA_T *scratch_gates = (__global ACC_DATA_T *)(scr_gates) )==" "\n" |
515 | R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==" "\n" |
// Outputs: hidden state and c-state at iteration slot iter + 1.
516 | R"==(__global WS_STATE_DATA_T *h_states_t_l )==" "\n" |
517 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==" "\n" |
518 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==" "\n" |
519 | R"==(__global float *c_states_t_l = (__global float *)(ws + WS_C_STATE_OFFSET) )==" "\n" |
520 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==" "\n" |
// Each gate: dequantize the accumulator, add the workspace bias, then apply
// the activation. Gates 0, 1 and 3 use the logistic helper, gate 2 uses tanh.
521 | R"==(float G0 = logistic_fwd_tm(deq_w(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)], )==" "\n" |
522 | R"==(0, j, scales, data_scale) )==" "\n" |
523 | R"==(+ ws_bias[OFF_WS_BIAS(lay, dir, 0, j)], )==" "\n" |
524 | R"==(tm_scales[0]); )==" "\n" |
525 | R"==(float G1 = logistic_fwd_tm(deq_w(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)], )==" "\n" |
526 | R"==(1, j, scales, data_scale) )==" "\n" |
527 | R"==(+ ws_bias[OFF_WS_BIAS(lay, dir, 1, j)], )==" "\n" |
528 | R"==(tm_scales[1]); )==" "\n" |
529 | R"==(float G2 = tanh_fwd_tm(deq_w(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)], 2, j, )==" "\n" |
530 | R"==(scales, data_scale) )==" "\n" |
531 | R"==(+ ws_bias[OFF_WS_BIAS(lay, dir, 2, j)], )==" "\n" |
532 | R"==(tm_scales[2]); )==" "\n" |
533 | R"==(float G3 = logistic_fwd_tm(deq_w(scratch_gates[CELL_SCRATCH_MEM(i, 3, j)], )==" "\n" |
534 | R"==(3, j, scales, data_scale) )==" "\n" |
535 | R"==(+ ws_bias[OFF_WS_BIAS(lay, dir, 3, j)], )==" "\n" |
536 | R"==(tm_scales[3]); )==" "\n" |
// New c-state = G1 * previous c-state + G0 * G2; the new hidden state is
// G3 * tanh(new c-state), quantized with q_d before it is stored.
537 | R"==(float tmp = G1 * c_states_tm1_l[CELL_WS_STATE(i, j)] + G0 * G2; )==" "\n" |
538 | R"==(h_states_t_l[CELL_WS_STATE(i, j)] )==" "\n" |
539 | R"==(= q_d(G3 * tanh_fwd_tm(tmp, tm_cscale), data_scale, data_shift); )==" "\n" |
540 | R"==(c_states_t_l[CELL_WS_STATE(i, j)] = tmp; )==" "\n" |
541 | R"==(} )==" "\n" |
542 | R"==(#else )==" "\n" |
543 | R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==" "\n" |
544 | R"==(ref_rnn_elemwise_fwd( )==" "\n" |
545 | R"==(int dir, int lay, int iter, __global char *ws, __global char *scr_gates, )==" "\n" |
546 | R"==(__global AUX_DATA_T *bias_base, float alpha, __global float *tm_scales, )==" "\n" |
547 | R"==(#if CELL_KIND == VANILLA_LSTM || CELL_KIND == VANILLA_RNN )==" "\n" |
548 | R"==(float tm_cscale )==" "\n" |
549 | R"==(#elif CELL_KIND == LBR_GRU )==" "\n" |
550 | R"==(__global char *scr_cell )==" "\n" |
551 | R"==(#elif CELL_KIND == VANILLA_GRU )==" "\n" |
552 | R"==(int n_part )==" "\n" |
553 | R"==(#endif )==" "\n" |
554 | R"==() { )==" "\n" |
555 | R"==(const int i = get_global_id(1); )==" "\n" |
556 | R"==(const int j = get_global_id(0); )==" "\n" |
557 | R"==(if (j >= DHC || i >= BATCH) return; )==" "\n" |
558 | R"==(const __global AUX_DATA_T *c_states_tm1_l )==" "\n" |
559 | R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET) )==" "\n" |
560 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==" "\n" |
561 | R"==(const __global AUX_DATA_T *bias = bias_base + BIAS_OFF(lay, dir, 0, 0); )==" "\n" |
562 | R"==(__global AUX_DATA_T *ws_gates )==" "\n" |
563 | R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==" "\n" |
564 | R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==" "\n" |
565 | R"==(__global AUX_DATA_T *scratch_gates = (__global AUX_DATA_T *)(scr_gates) )==" "\n" |
566 | R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==" "\n" |
567 | R"==(__global WS_STATE_DATA_T *h_states_t_l )==" "\n" |
568 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==" "\n" |
569 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==" "\n" |
570 | R"==(#if CELL_KIND == VANILLA_LSTM )==" "\n" |
571 | R"==(__global AUX_DATA_T *c_states_t_l )==" "\n" |
572 | R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET) )==" "\n" |
573 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==" "\n" |
574 | R"==(float g_i = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] )==" "\n" |
575 | R"==(+ bias[OFF_KER_BIAS(0, j)], )==" "\n" |
576 | R"==(tm_scales[0]); )==" "\n" |
577 | R"==(float g_f = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] )==" "\n" |
578 | R"==(+ bias[OFF_KER_BIAS(1, j)], )==" "\n" |
579 | R"==(tm_scales[1]); )==" "\n" |
580 | R"==(float g_z = tanh_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] )==" "\n" |
581 | R"==(+ bias[OFF_KER_BIAS(2, j)], )==" "\n" |
582 | R"==(tm_scales[2]); )==" "\n" |
583 | R"==(float g_o = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 3, j)] )==" "\n" |
584 | R"==(+ bias[OFF_KER_BIAS(3, j)], )==" "\n" |
585 | R"==(tm_scales[3]); )==" "\n" |
586 | R"==(#if IS_TRAINING )==" "\n" |
587 | R"==(ws_gates[CELL_WS_GATES(i, 0, j)] = g_i; )==" "\n" |
588 | R"==(ws_gates[CELL_WS_GATES(i, 1, j)] = g_f; )==" "\n" |
589 | R"==(ws_gates[CELL_WS_GATES(i, 2, j)] = g_z; )==" "\n" |
590 | R"==(ws_gates[CELL_WS_GATES(i, 3, j)] = g_o; )==" "\n" |
591 | R"==(#endif )==" "\n" |
592 | R"==(float Ct = g_f * c_states_tm1_l[CELL_WS_STATE(i, j)] + g_i * g_z; )==" "\n" |
593 | R"==(float Ht = g_o * tanh_fwd_tm(Ct, tm_cscale); )==" "\n" |
594 | R"==(h_states_t_l[CELL_WS_STATE(i, j)] = TO_INPUT(Ht); )==" "\n" |
595 | R"==(c_states_t_l[CELL_WS_STATE(i, j)] = Ct; )==" "\n" |
596 | R"==(#elif CELL_KIND == VANILLA_RNN )==" "\n" |
597 | R"==(float g = activation_fwd((float)scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] )==" "\n" |
598 | R"==(+ bias[OFF_KER_BIAS(0, j)], )==" "\n" |
599 | R"==(#if IS_TESTMODE )==" "\n" |
600 | R"==(tm_scales[0], 0); )==" "\n" |
601 | R"==(#else )==" "\n" |
602 | R"==(alpha, 0); )==" "\n" |
603 | R"==(#endif )==" "\n" |
604 | R"==(#if IS_TRAINING )==" "\n" |
605 | R"==(ws_gates[CELL_WS_GATES(i, 0, j)] = g; )==" "\n" |
606 | R"==(#endif )==" "\n" |
607 | R"==(h_states_t_l[CELL_WS_STATE(i, j)] = TO_INPUT(g); )==" "\n" |
608 | R"==(#elif CELL_KIND == LBR_GRU )==" "\n" |
609 | R"==(__global AUX_DATA_T *scratch_cell = (__global AUX_DATA_T *)(scr_cell); )==" "\n" |
610 | R"==(__global WS_STATE_DATA_T *src_iter )==" "\n" |
611 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==" "\n" |
612 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==" "\n" |
613 | R"==(__global AUX_DATA_T *ws_grid )==" "\n" |
614 | R"==(= (__global AUX_DATA_T *)(ws + WS_GRID_COMP_OFFSET) )==" "\n" |
615 | R"==(+ OFF_WS_GRID_OFFSET(lay, dir, iter, 0, 0); )==" "\n" |
616 | R"==(float Wh_b = (float)scratch_cell[CELL_SCRATCH_MEM(i, 2, j)] )==" "\n" |
617 | R"==(+ bias[OFF_KER_BIAS(3, j)]; )==" "\n" |
618 | R"==(float G0 = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] )==" "\n" |
619 | R"==(+ (float)scratch_cell[CELL_SCRATCH_MEM(i, 0, j)] )==" "\n" |
620 | R"==(+ bias[OFF_KER_BIAS(0, j)], )==" "\n" |
621 | R"==(tm_scales[0]); )==" "\n" |
622 | R"==(float G1 = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] )==" "\n" |
623 | R"==(+ (float)scratch_cell[CELL_SCRATCH_MEM(i, 1, j)] )==" "\n" |
624 | R"==(+ bias[OFF_KER_BIAS(1, j)], )==" "\n" |
625 | R"==(tm_scales[1]); )==" "\n" |
626 | R"==(float G2 = tanh_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] )==" "\n" |
627 | R"==(+ G1 * Wh_b + bias[OFF_KER_BIAS(2, j)], )==" "\n" |
628 | R"==(tm_scales[2]); )==" "\n" |
629 | R"==(float Ht = G0 * TO_REF(src_iter[CELL_WS_STATE(i, j)]) + (1 - G0) * G2; )==" "\n" |
630 | R"==(h_states_t_l[CELL_WS_STATE(i, j)] = TO_INPUT(Ht); )==" "\n" |
631 | R"==(#if IS_TRAINING )==" "\n" |
632 | R"==(ws_gates[CELL_WS_GATES(i, 0, j)] = G0; )==" "\n" |
633 | R"==(ws_gates[CELL_WS_GATES(i, 1, j)] = G1; )==" "\n" |
634 | R"==(ws_gates[CELL_WS_GATES(i, 2, j)] = G2; )==" "\n" |
635 | R"==(ws_grid[CELL_WS_GRID_COMP(i, j)] = Wh_b; )==" "\n" |
636 | R"==(#endif )==" "\n" |
637 | R"==(#elif CELL_KIND == VANILLA_GRU )==" "\n" |
638 | R"==(__global WS_STATE_DATA_T *src_iter )==" "\n" |
639 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==" "\n" |
640 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==" "\n" |
641 | R"==(if (n_part == 1) { )==" "\n" |
642 | R"==(float G0 = logistic_fwd_tm(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] )==" "\n" |
643 | R"==(+ bias[OFF_KER_BIAS(0, j)], )==" "\n" |
644 | R"==(tm_scales[0]); )==" "\n" |
645 | R"==(float G1 = logistic_fwd_tm(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] )==" "\n" |
646 | R"==(+ bias[OFF_KER_BIAS(1, j)], )==" "\n" |
647 | R"==(tm_scales[1]); )==" "\n" |
648 | R"==(/* TODO from CPU: Can be optimized for fwd_training by using )==" "\n" |
649 | R"==(ws_gates instead of scratch_gates in p2 */ )==" "\n" |
650 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(G0); )==" "\n" |
651 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] = TO_INPUT(G1); )==" "\n" |
652 | R"==(float tmp = TO_REF(src_iter[CELL_WS_STATE(i, j)]); )==" "\n" |
653 | R"==(h_states_t_l[CELL_WS_STATE(i, j)] = TO_INPUT(tmp * G1); )==" "\n" |
654 | R"==(#if IS_TRAINING )==" "\n" |
655 | R"==(ws_gates[CELL_WS_GATES(i, 0, j)] = G0; )==" "\n" |
656 | R"==(ws_gates[CELL_WS_GATES(i, 1, j)] = G1; )==" "\n" |
657 | R"==(#endif )==" "\n" |
658 | R"==(} else if (n_part == 2) { )==" "\n" |
659 | R"==(float G0 = TO_REF(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)]); )==" "\n" |
660 | R"==(float G2 = tanh_fwd_tm(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] )==" "\n" |
661 | R"==(+ bias[OFF_KER_BIAS(2, j)], )==" "\n" |
662 | R"==(tm_scales[2]); )==" "\n" |
663 | R"==(float tmp = TO_REF(src_iter[CELL_WS_STATE(i, j)]); )==" "\n" |
664 | R"==(h_states_t_l[CELL_WS_STATE(i, j)] )==" "\n" |
665 | R"==(= TO_INPUT(tmp * G0 + (1.0f - G0) * G2); )==" "\n" |
666 | R"==(#if IS_TRAINING )==" "\n" |
667 | R"==(ws_gates[CELL_WS_GATES(i, 2, j)] = G2; )==" "\n" |
668 | R"==(#endif )==" "\n" |
669 | R"==(} )==" "\n" |
670 | R"==(#else )==" "\n" |
671 | R"==(#error "Wrong Cell Kind" )==" "\n" |
672 | R"==(#endif )==" "\n" |
673 | R"==(} )==" "\n" |
674 | R"==(#endif )==" "\n" |
/*
 * Embedded OpenCL source: kernel ref_rnn_elemwise_bwd.
 *
 * Backward elementwise step of one RNN cell at (dir, lay, iter). Each
 * work-item handles one (i, j) pair and turns stored forward gate values
 * (ws) plus incoming state gradients (diff_states) into gate gradients,
 * which are written to scr_gates.
 *
 * Helpers from the top of this kernel source:
 *   x_m_square(a)   = (1 - a) * a   (also used by logistic_bwd)
 *   one_m_square(a) = 1 - a * a
 *
 * Parameters common to every cell kind:
 *   dir, lay, iter  indices fed to the OFF_* offset macros
 *   ws              byte pointer; gates, states, c-states and grid data are
 *                   reached through the WS_*_OFFSET byte offsets
 *   scr_gates       destination of the computed gate gradients
 *   bias_base       not referenced anywhere in this kernel body
 *   alpha           used only by the VANILLA_RNN branch (non test mode)
 *   tm_scales       used only by the VANILLA_RNN branch (test mode)
 *   diff_states     state-gradient scratch, read and written
 */
675 | R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==" "\n" |
676 | R"==(ref_rnn_elemwise_bwd(int dir, int lay, int iter, __global char *ws, )==" "\n" |
677 | R"==(__global char *scr_gates, __global AUX_DATA_T *bias_base, float alpha, )==" "\n" |
678 | R"==(__global float *tm_scales, )==" "\n" |
/*
 * The middle of the parameter list depends on CELL_KIND:
 *   VANILLA_LSTM / VANILLA_RNN : tm_cscale (only the LSTM branch reads it)
 *   LBR_GRU                    : scr_gate_r
 *   VANILLA_GRU                : n_part, scr_cell, scratch_dhG1
 * NOTE(review): the host presumably sets kernel arguments to match the
 * same CELL_KIND - confirm against the caller.
 */
679 | R"==(#if CELL_KIND == VANILLA_LSTM || CELL_KIND == VANILLA_RNN )==" "\n" |
680 | R"==(float tm_cscale, )==" "\n" |
681 | R"==(#elif CELL_KIND == LBR_GRU )==" "\n" |
682 | R"==(__global char *scr_gate_r, )==" "\n" |
683 | R"==(#elif CELL_KIND == VANILLA_GRU )==" "\n" |
684 | R"==(int n_part, __global char *scr_cell, __global char *scratch_dhG1, )==" "\n" |
685 | R"==(#endif )==" "\n" |
686 | R"==(__global char *diff_states) { )==" "\n" |
/* i comes from global dimension 1 and is bounded by BATCH; j comes from
 * global dimension 0 and is bounded by DHC. Out-of-range work-items exit. */
687 | R"==(const int i = get_global_id(1); )==" "\n" |
688 | R"==(const int j = get_global_id(0); )==" "\n" |
689 | R"==(if (j >= DHC || i >= BATCH) return; )==" "\n" |
/*
 * ---- VANILLA_LSTM ----
 * Base pointers are set up for the current cell; the (i, j) element is
 * selected later through the CELL_* index macros.
 * The forward kernel in this file stores ws_gates slots 0..3 as
 * g_i, g_f, g_z, g_o respectively.
 */
690 | R"==(#if CELL_KIND == VANILLA_LSTM )==" "\n" |
691 | R"==(__global AUX_DATA_T *ws_gates )==" "\n" |
692 | R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==" "\n" |
693 | R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==" "\n" |
694 | R"==(__global SRC_DATA_T *scratch_gates = (__global SRC_DATA_T *)(scr_gates) )==" "\n" |
695 | R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==" "\n" |
/* c-state at (lay + 1, iter + 1) and at (lay + 1, iter). */
696 | R"==(__global AUX_DATA_T *c_states_t_l )==" "\n" |
697 | R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET) )==" "\n" |
698 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==" "\n" |
699 | R"==(__global AUX_DATA_T *c_states_tm1_l )==" "\n" |
700 | R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET) )==" "\n" |
701 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==" "\n" |
/* Three views into diff_states: (lay, iter) is written, (lay, iter + 1)
 * and (lay + 1, iter) are read. */
702 | R"==(__global DIFF_DATA_T *diff_states_t_l = (__global DIFF_DATA_T *)diff_states )==" "\n" |
703 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter, 0, 0); )==" "\n" |
704 | R"==(__global DIFF_DATA_T *diff_states_tp1_l )==" "\n" |
705 | R"==(= (__global DIFF_DATA_T *)diff_states )==" "\n" |
706 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter + 1, 0, 0); )==" "\n" |
707 | R"==(__global DIFF_DATA_T *diff_states_t_lp1 )==" "\n" |
708 | R"==(= (__global DIFF_DATA_T *)diff_states )==" "\n" |
709 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay + 1, dir, 0, iter, 0, 0); )==" "\n" |
/* tanh(Ct) is recomputed from the stored c-state rather than loaded. */
710 | R"==(float Ct = c_states_t_l[CELL_WS_STATE(i, j)]; )==" "\n" |
711 | R"==(float tanhCt = tanh_fwd_tm(Ct, tm_cscale); )==" "\n" |
/* dHt sums state slot 0 of (lay, iter + 1) and slot N_STATES of
 * (lay + 1, iter). */
712 | R"==(float dHt = (float)diff_states_tp1_l[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==" "\n" |
713 | R"==(+ diff_states_t_lp1[CELL_SCRATCH_DIFF_STATES(N_STATES, i, j)]; )==" "\n" |
/* dCt = state slot 1 of (lay, iter + 1) + (1 - tanhCt^2) * g_o * dHt. */
714 | R"==(float dCt = (float)diff_states_tp1_l[CELL_SCRATCH_DIFF_STATES(1, i, j)] )==" "\n" |
715 | R"==(+ one_m_square(tanhCt) * ws_gates[CELL_WS_GATES(i, 3, j)] * dHt; )==" "\n" |
/* Gate gradients. Slots 0, 1 and 3 use x_m_square of the stored gate
 * value; slot 2 uses one_m_square. */
716 | R"==(float dG1 = (float)c_states_tm1_l[CELL_WS_STATE(i, j)] * dCt )==" "\n" |
717 | R"==(* x_m_square(ws_gates[CELL_WS_GATES(i, 1, j)]); )==" "\n" |
718 | R"==(float dG0 = ws_gates[CELL_WS_GATES(i, 2, j)] * dCt )==" "\n" |
719 | R"==(* x_m_square(ws_gates[CELL_WS_GATES(i, 0, j)]); )==" "\n" |
720 | R"==(float dG3 = tanhCt * dHt * x_m_square(ws_gates[CELL_WS_GATES(i, 3, j)]); )==" "\n" |
721 | R"==(float dG2 = ws_gates[CELL_WS_GATES(i, 0, j)] * dCt )==" "\n" |
722 | R"==(* one_m_square(ws_gates[CELL_WS_GATES(i, 2, j)]); )==" "\n" |
/* State slot 1 of (lay, iter) receives dCt scaled by gate slot 1. */
723 | R"==(diff_states_t_l[CELL_SCRATCH_DIFF_STATES(1, i, j)] )==" "\n" |
724 | R"==(= dCt * ws_gates[CELL_WS_GATES(i, 1, j)]; )==" "\n" |
/* Gate gradients are converted with TO_INPUT and stored in scratch. */
725 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(dG0); )==" "\n" |
726 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] = TO_INPUT(dG1); )==" "\n" |
727 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] = TO_INPUT(dG2); )==" "\n" |
728 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 3, j)] = TO_INPUT(dG3); )==" "\n" |
/*
 * ---- LBR_GRU ----
 * Produces two gradient buffers: scratch_gates and scratch_gate_r. They
 * hold the same values in slots 0 and 1; slot 2 of scratch_gate_r is
 * additionally multiplied by the stored gate slot 1.
 */
729 | R"==(#elif CELL_KIND == LBR_GRU )==" "\n" |
730 | R"==(__global SRC_DATA_T *scratch_gates = (__global SRC_DATA_T *)(scr_gates) )==" "\n" |
731 | R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==" "\n" |
/* scr_gate_r is used without an iter offset, unlike scr_gates. */
732 | R"==(__global SRC_DATA_T *scratch_gate_r = (__global SRC_DATA_T *)(scr_gate_r); )==" "\n" |
733 | R"==(__global AUX_DATA_T *ws_gates )==" "\n" |
734 | R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==" "\n" |
735 | R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==" "\n" |
/* ws_grid holds the Wh_b value that the forward kernel saved when
 * IS_TRAINING was set. */
736 | R"==(__global AUX_DATA_T *ws_grid )==" "\n" |
737 | R"==(= (__global AUX_DATA_T *)(ws + WS_GRID_COMP_OFFSET) )==" "\n" |
738 | R"==(+ OFF_WS_GRID_OFFSET(lay, dir, iter, 0, 0); )==" "\n" |
739 | R"==(__global DIFF_DATA_T *diff_src_iter = (__global DIFF_DATA_T *)diff_states )==" "\n" |
740 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter, 0, 0); )==" "\n" |
741 | R"==(__global DIFF_DATA_T *diff_dst_iter = (__global DIFF_DATA_T *)diff_states )==" "\n" |
742 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter + 1, 0, 0); )==" "\n" |
743 | R"==(__global DIFF_DATA_T *diff_dst_layer = (__global DIFF_DATA_T *)diff_states )==" "\n" |
744 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay + 1, dir, 0, iter, 0, 0); )==" "\n" |
745 | R"==(__global WS_STATE_DATA_T *src_iter )==" "\n" |
746 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==" "\n" |
747 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==" "\n" |
/* h is the state at (lay + 1, iter) converted with TO_REF. */
748 | R"==(float h = TO_REF(src_iter[CELL_WS_STATE(i, j)]); )==" "\n" |
749 | R"==(float Wh_b = ws_grid[CELL_WS_GRID_COMP(i, j)]; )==" "\n" |
750 | R"==(float dHt = diff_dst_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==" "\n" |
751 | R"==(+ diff_dst_layer[CELL_SCRATCH_DIFF_STATES(N_STATES, i, j)]; )==" "\n" |
/* With G0, G1, G2 the stored gate slots 0, 1, 2:
 *   dG0 = (h - G2) * dHt * x_m_square(G0)
 *   dG2 = (1 - G0) * one_m_square(G2) * dHt
 *   dG1 = Wh_b * dG2 * x_m_square(G1) */
752 | R"==(float dG0 = (h - ws_gates[CELL_WS_GATES(i, 2, j)]) * dHt )==" "\n" |
753 | R"==(* x_m_square(ws_gates[CELL_WS_GATES(i, 0, j)]); )==" "\n" |
754 | R"==(float dG2 = (1.0f - ws_gates[CELL_WS_GATES(i, 0, j)]) )==" "\n" |
755 | R"==(* one_m_square(ws_gates[CELL_WS_GATES(i, 2, j)]) * dHt; )==" "\n" |
756 | R"==(float dG1 = Wh_b * dG2 * x_m_square(ws_gates[CELL_WS_GATES(i, 1, j)]); )==" "\n" |
/* State slot 0 of (lay, iter) is overwritten with dHt * G0. */
757 | R"==(diff_src_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==" "\n" |
758 | R"==(= dHt * ws_gates[CELL_WS_GATES(i, 0, j)]; )==" "\n" |
759 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(dG0); )==" "\n" |
760 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] = TO_INPUT(dG1); )==" "\n" |
761 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] = TO_INPUT(dG2); )==" "\n" |
762 | R"==(scratch_gate_r[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(dG0); )==" "\n" |
763 | R"==(scratch_gate_r[CELL_SCRATCH_MEM(i, 1, j)] = TO_INPUT(dG1); )==" "\n" |
764 | R"==(scratch_gate_r[CELL_SCRATCH_MEM(i, 2, j)] )==" "\n" |
765 | R"==(= TO_INPUT(dG2 * ws_gates[CELL_WS_GATES(i, 1, j)]); )==" "\n" |
/*
 * ---- VANILLA_RNN ----
 * Unlike the other branches, the offset macros here already receive i and
 * j, so every pointer addresses this work-item element and is read or
 * written at index 0.
 */
766 | R"==(#elif CELL_KIND == VANILLA_RNN )==" "\n" |
767 | R"==(__global AUX_DATA_T *ws_gates )==" "\n" |
768 | R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==" "\n" |
769 | R"==(+ OFF_WS_GATES(lay, dir, iter, i, 0, j); )==" "\n" |
770 | R"==(__global SRC_DATA_T *scratch_gates = (__global SRC_DATA_T *)(scr_gates) )==" "\n" |
771 | R"==(+ OFF_SCRATCH_MEM(iter, i, 0, j); )==" "\n" |
772 | R"==(__global DIFF_DATA_T *diff_states_t_lp1 )==" "\n" |
773 | R"==(= (__global DIFF_DATA_T *)diff_states )==" "\n" |
774 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay + 1, dir, N_STATES, iter, i, j); )==" "\n" |
775 | R"==(__global DIFF_DATA_T *diff_states_tp1_l )==" "\n" |
776 | R"==(= (__global DIFF_DATA_T *)diff_states )==" "\n" |
777 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter + 1, i, j); )==" "\n" |
778 | R"==(const float dH = (float)diff_states_t_lp1[0] + diff_states_tp1_l[0]; )==" "\n" |
779 | R"==(float g = ws_gates[0]; )==" "\n" |
/* The second argument of activation_bwd is tm_scales[0] in test mode and
 * alpha otherwise; the third argument is always zero. */
780 | R"==(#if IS_TESTMODE )==" "\n" |
781 | R"==(scratch_gates[0] = TO_INPUT(dH * activation_bwd(g, tm_scales[0], 0.)); )==" "\n" |
782 | R"==(#else )==" "\n" |
783 | R"==(scratch_gates[0] = TO_INPUT(dH * activation_bwd(g, alpha, 0.)); )==" "\n" |
784 | R"==(#endif )==" "\n" |
/*
 * ---- VANILLA_GRU ----
 * Work is split by n_part. Part 1 writes gate-gradient slots 0 and 2;
 * part 2 writes slot 1 and scratch_cell. Any other n_part value leaves
 * all outputs untouched.
 */
785 | R"==(#elif CELL_KIND == VANILLA_GRU )==" "\n" |
786 | R"==(__global SRC_DATA_T *scratch_gates = (__global SRC_DATA_T *)(scr_gates) )==" "\n" |
787 | R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==" "\n" |
788 | R"==(__global AUX_DATA_T *ws_gates )==" "\n" |
789 | R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==" "\n" |
790 | R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==" "\n" |
791 | R"==(__global DIFF_DATA_T *diff_src_iter = (__global DIFF_DATA_T *)diff_states )==" "\n" |
792 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter, 0, 0); )==" "\n" |
793 | R"==(__global DIFF_DATA_T *diff_dst_iter = (__global DIFF_DATA_T *)diff_states )==" "\n" |
794 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter + 1, 0, 0); )==" "\n" |
795 | R"==(__global DIFF_DATA_T *diff_dst_layer = (__global DIFF_DATA_T *)diff_states )==" "\n" |
796 | R"==(+ OFF_SCRATCH_DIFF_STATES(lay + 1, dir, 0, iter, 0, 0); )==" "\n" |
797 | R"==(__global WS_STATE_DATA_T *src_iter )==" "\n" |
798 | R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==" "\n" |
799 | R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==" "\n" |
800 | R"==(float h = TO_REF(src_iter[CELL_WS_STATE(i, j)]); )==" "\n" |
/* Part 1: same dG0 / dG2 formulas as the LBR_GRU branch; state slot 0 of
 * (lay, iter) is overwritten with dHt * G0. */
801 | R"==(if (n_part == 1) { )==" "\n" |
802 | R"==(float dHt = diff_dst_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==" "\n" |
803 | R"==(+ diff_dst_layer[CELL_SCRATCH_DIFF_STATES(N_STATES, i, j)]; )==" "\n" |
804 | R"==(float dG2 = (1.0f - ws_gates[CELL_WS_GATES(i, 0, j)]) * dHt )==" "\n" |
805 | R"==(* one_m_square(ws_gates[CELL_WS_GATES(i, 2, j)]); )==" "\n" |
806 | R"==(float dG0 = (h - ws_gates[CELL_WS_GATES(i, 2, j)]) * dHt )==" "\n" |
807 | R"==(* x_m_square(ws_gates[CELL_WS_GATES(i, 0, j)]); )==" "\n" |
808 | R"==(diff_src_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==" "\n" |
809 | R"==(= dHt * ws_gates[CELL_WS_GATES(i, 0, j)]; )==" "\n" |
810 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(dG0); )==" "\n" |
811 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] = TO_INPUT(dG2); )==" "\n" |
/* Part 2: note that the local named dG1 holds the stored forward gate
 * value from slot 1, not a gradient. State slot 0 of (lay, iter) is
 * accumulated into (+=), on top of the value part 1 stored there.
 * NOTE(review): scratch_dhG1 is presumably filled by a host-side step
 * between the two parts - confirm against the caller. */
812 | R"==(} else if (n_part == 2) { )==" "\n" |
813 | R"==(__global SRC_DATA_T *scratch_cell = (__global SRC_DATA_T *)(scr_cell); )==" "\n" |
814 | R"==(__global DIFF_DATA_T *dhG1 = (__global DIFF_DATA_T *)scratch_dhG1; )==" "\n" |
815 | R"==(float dG1 = ws_gates[CELL_WS_GATES(i, 1, j)]; )==" "\n" |
816 | R"==(diff_src_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==" "\n" |
817 | R"==(+= dhG1[OFF_SCRATCH_DHG1(i, j)] * dG1; )==" "\n" |
818 | R"==(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] )==" "\n" |
819 | R"==(= TO_INPUT(dhG1[OFF_SCRATCH_DHG1(i, j)] * h * x_m_square(dG1)); )==" "\n" |
820 | R"==(scratch_cell[OFF_SCRATCH_CELL(i, j)] = TO_INPUT(dG1 * h); )==" "\n" |
821 | R"==(} )==" "\n" |
/* Any CELL_KIND not handled above fails the OpenCL build. */
822 | R"==(#else )==" "\n" |
823 | R"==(#error "Wrong Cell Kind" )==" "\n" |
824 | R"==(#endif )==" "\n" |
825 | R"==(} )==" "\n" |
/*
 * Embedded OpenCL source: kernel ref_rnn_gates_reduction.
 *
 * Sums gate gradients over the batch index and accumulates the sum into
 * the bias gradient: for each gate row i and channel k,
 *   diff_bias[i * DHC + k] += sum over j in [0, BATCH) of gates(j, i_, k)
 *
 * Parameters:
 *   dir, lay, iter   indices fed to the offset macros
 *   diff_bias_base   bias-gradient buffer; offset by DIFF_BIAS_OFF(lay, dir, 0, 0)
 *   scratch_gates    gate gradients, offset by OFF_SCRATCH_MEM(iter, 0, 0, 0)
 *   scratch_cell     second gradient source, read only for LBR_GRU row 3
 *
 * The whole body is compiled only when IS_FWD is false; otherwise the
 * kernel is empty.
 * NOTE(review): the result is only ever added to diff_bias, so the buffer
 * is presumably zeroed by the host before the first call - confirm.
 */
826 | R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==" "\n" |
827 | R"==(ref_rnn_gates_reduction(int dir, int lay, int iter, )==" "\n" |
828 | R"==(__global DIFF_DATA_T *diff_bias_base, __global char *scratch_gates, )==" "\n" |
829 | R"==(__global char *scratch_cell) { )==" "\n" |
830 | R"==(#if !IS_FWD )==" "\n" |
/* With sub-group reduction, global dimension 0 is left for the lanes that
 * split the batch loop, so k and i move to dimensions 1 and 2. */
831 | R"==(#if USE_SUBGROUP_REDUCTION )==" "\n" |
832 | R"==(const int k = get_global_id(1); )==" "\n" |
833 | R"==(const int i = get_global_id(2); )==" "\n" |
834 | R"==(#else )==" "\n" |
835 | R"==(const int k = get_global_id(0); )==" "\n" |
836 | R"==(const int i = get_global_id(1); )==" "\n" |
837 | R"==(#endif )==" "\n" |
/* LBR_GRU has 4 bias rows regardless of N_GATES; other kinds have N_GATES. */
838 | R"==(const int n_bias_max = (CELL_KIND == LBR_GRU) ? 4 : N_GATES; )==" "\n" |
839 | R"==(if (k >= DHC || i >= n_bias_max) return; )==" "\n" |
840 | R"==(__global DIFF_DATA_T *diff_bias )==" "\n" |
841 | R"==(= diff_bias_base + DIFF_BIAS_OFF(lay, dir, 0, 0); )==" "\n" |
/* Source selection: LBR_GRU row 3 reads slot 2 of scratch_cell (no iter
 * offset applied); every other row reads its own slot of scratch_gates.
 * i_ is the slot used for reading, i stays the row used for writing. */
842 | R"==(__global SRC_DATA_T *gates; )==" "\n" |
843 | R"==(int i_ = i; )==" "\n" |
844 | R"==(if (CELL_KIND == LBR_GRU && i == 3) { )==" "\n" |
845 | R"==(gates = (__global SRC_DATA_T *)(scratch_cell); )==" "\n" |
846 | R"==(i_ = 2; )==" "\n" |
847 | R"==(} else )==" "\n" |
848 | R"==(gates = (__global SRC_DATA_T *)(scratch_gates) )==" "\n" |
849 | R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==" "\n" |
/* Sub-group path: each work-item sums the batch entries starting at its
 * get_local_id(0) with stride SUBGROUP_SIZE, then the partial sums are
 * combined with sub_group_reduce_add.
 * NOTE(review): there is no lane guard on the final +=, so every work-item
 * that reaches it updates the same diff_bias element with the same reduced
 * value. This looks like it relies on the lanes of a sub-group executing
 * the read-modify-write together - confirm. */
850 | R"==(#if USE_SUBGROUP_REDUCTION )==" "\n" |
851 | R"==(DIFF_DATA_T result = 0; )==" "\n" |
852 | R"==(for (int j = get_local_id(0); j < BATCH; j += SUBGROUP_SIZE) { )==" "\n" |
853 | R"==(result += SRC_TO_REF(gates[CELL_SCRATCH_MEM(j, i_, k)]); )==" "\n" |
854 | R"==(} )==" "\n" |
855 | R"==(diff_bias[i * DHC + k] += sub_group_reduce_add(result); )==" "\n" |
/* Scalar path: one work-item walks the whole batch and adds each entry
 * directly to the diff_bias element. */
856 | R"==(#else )==" "\n" |
857 | R"==(for (int j = 0; j < BATCH; j++) { )==" "\n" |
858 | R"==(diff_bias[i * DHC + k] += SRC_TO_REF(gates[CELL_SCRATCH_MEM(j, i_, k)]); )==" "\n" |
859 | R"==(} )==" "\n" |
860 | R"==(#endif )==" "\n" |
861 | R"==(#endif )==" "\n" |
862 | R"==(} )==" "\n" |
863 | R"==()==" ; |
864 | } |
865 | } |
866 | } |
867 | } |