1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
5const char *ref_rnn_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/rnn/rnn_types.h" )==""\n"
21R"==(float one_m_square(float a) { )==""\n"
22R"==(return 1.0f - a * a; )==""\n"
23R"==(} )==""\n"
24R"==(float x_m_square(float a) { )==""\n"
25R"==(return (1.0f - a) * a; )==""\n"
26R"==(} )==""\n"
27R"==(float relu_fwd(float s, float alpha) { )==""\n"
28R"==(return s > 0 ? s : s * alpha; )==""\n"
29R"==(} )==""\n"
30R"==(float tanh_fwd(float s) { )==""\n"
31R"==(return tanh(s); )==""\n"
32R"==(} )==""\n"
33R"==(float logistic_fwd(float s) { )==""\n"
34R"==(return 1 / (1 + exp((float)-s)); )==""\n"
35R"==(} )==""\n"
36R"==(float logistic_bwd(float s) { )==""\n"
37R"==(return x_m_square(s); )==""\n"
38R"==(} )==""\n"
39R"==(float relu_bwd(float s, float alpha) { )==""\n"
40R"==(return s > 0 ? 1.f : alpha; )==""\n"
41R"==(} )==""\n"
42R"==(float tanh_bwd(float s) { )==""\n"
43R"==(return (1 - s) * (1 + s); )==""\n"
44R"==(} )==""\n"
45R"==(float linear(float s, float alpha) { )==""\n"
46R"==(return alpha * s; )==""\n"
47R"==(} )==""\n"
48R"==(float relu_fwd_tm(float s, float alpha) { )==""\n"
49R"==(#if !IS_TESTMODE )==""\n"
50R"==(return relu_fwd(s, alpha); )==""\n"
51R"==(#else )==""\n"
52R"==(return linear(s, alpha); )==""\n"
53R"==(#endif )==""\n"
54R"==(} )==""\n"
55R"==(float tanh_fwd_tm(float s, float alpha) { )==""\n"
56R"==(#if !IS_TESTMODE )==""\n"
57R"==(return tanh(s); )==""\n"
58R"==(#else )==""\n"
59R"==(return linear(s, alpha); )==""\n"
60R"==(#endif )==""\n"
61R"==(} )==""\n"
62R"==(float logistic_fwd_tm(float s, float alpha) { )==""\n"
63R"==(#if !IS_TESTMODE )==""\n"
64R"==(return logistic_fwd(s); )==""\n"
65R"==(#else )==""\n"
66R"==(return linear(s, alpha); )==""\n"
67R"==(#endif )==""\n"
68R"==(} )==""\n"
69R"==(float relu_bwd_tm(float s, float alpha) { )==""\n"
70R"==(#if !IS_TESTMODE )==""\n"
71R"==(return relu_bwd(s, alpha); )==""\n"
72R"==(#else )==""\n"
73R"==(return linear(s, alpha); )==""\n"
74R"==(#endif )==""\n"
75R"==(} )==""\n"
76R"==(float tanh_bwd_tm(float s, float alpha) { )==""\n"
77R"==(#if !IS_TESTMODE )==""\n"
78R"==(return tanh_bwd(s); )==""\n"
79R"==(#else )==""\n"
80R"==(return linear(s, alpha); )==""\n"
81R"==(#endif )==""\n"
82R"==(} )==""\n"
83R"==(float logistic_bwd_tm(float s, float alpha) { )==""\n"
84R"==(#if !IS_TESTMODE )==""\n"
85R"==(return logistic_bwd(s); )==""\n"
86R"==(#else )==""\n"
87R"==(return linear(s, alpha); )==""\n"
88R"==(#endif )==""\n"
89R"==(} )==""\n"
90R"==(float activation_fwd(float s, float alpha, float cliping) { )==""\n"
91R"==(#if CELL_KIND == VANILLA_RNN )==""\n"
92R"==(#if ACTIVATION_KIND == ELTWISE_RELU )==""\n"
93R"==(return relu_fwd_tm(s, alpha); )==""\n"
94R"==(#elif ACTIVATION_KIND == ELTWISE_TANH )==""\n"
95R"==(return tanh_fwd_tm(s, alpha); )==""\n"
96R"==(#elif ACTIVATION_KIND == ELTWISE_LOGISTIC )==""\n"
97R"==(return logistic_fwd_tm(s, alpha); )==""\n"
98R"==(#else )==""\n"
99R"==(#error "Unsupported activation_kind" )==""\n"
100R"==(#endif )==""\n"
101R"==(#else )==""\n"
102R"==(return 0.0f; )==""\n"
103R"==(#endif )==""\n"
104R"==(} )==""\n"
105R"==(float activation_bwd(float s, float alpha, float cliping) { )==""\n"
106R"==(#if CELL_KIND == VANILLA_RNN )==""\n"
107R"==(#if ACTIVATION_KIND == ELTWISE_RELU )==""\n"
108R"==(return relu_bwd_tm(s, alpha); )==""\n"
109R"==(#elif ACTIVATION_KIND == ELTWISE_TANH )==""\n"
110R"==(return tanh_bwd_tm(s, alpha); )==""\n"
111R"==(#elif ACTIVATION_KIND == ELTWISE_LOGISTIC )==""\n"
112R"==(return logistic_bwd_tm(s, alpha); )==""\n"
113R"==(#else )==""\n"
114R"==(#error "Unsupported activation_kind" )==""\n"
115R"==(#endif )==""\n"
116R"==(#else )==""\n"
117R"==(return 0.0f; )==""\n"
118R"==(#endif )==""\n"
119R"==(} )==""\n"
120R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==""\n"
121R"==(ref_rnn_copy_init_layer(__global char *ws, __global char *src_base, )==""\n"
122R"==(__global char *scratch_diff_states, int lr, int rl) { )==""\n"
123R"==(#if IS_FWD )==""\n"
124R"==(const int it = get_global_id(2); )==""\n"
125R"==(const int b = get_global_id(1); )==""\n"
126R"==(const int c = get_global_id(0); )==""\n"
127R"==(if (c >= SLC || b >= BATCH || it >= N_ITER) return; )==""\n"
128R"==(__global WS_STATE_DATA_T *dst; )==""\n"
129R"==(__global WS_STATE_DATA_T *dst_base )==""\n"
130R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==""\n"
131R"==(__global WS_STATE_DATA_T *src = (__global WS_STATE_DATA_T *)src_base )==""\n"
132R"==(+ SRC_L_OFF(it, 0, 0) + b * SLC + c; )==""\n"
133R"==(if (lr) { )==""\n"
134R"==(dst = dst_base + OFF_WS_STATE(0, 0, it + 1, b, c); )==""\n"
135R"==(dst[0] = src[0]; )==""\n"
136R"==(} )==""\n"
137R"==(if (rl) { )==""\n"
138R"==(dst = dst_base + OFF_WS_STATE(0, N_DIR - 1, N_ITER - it, b, c); )==""\n"
139R"==(dst[0] = src[0]; )==""\n"
140R"==(} )==""\n"
141R"==(#else )==""\n"
142R"==(const int it = get_global_id(1); )==""\n"
143R"==(const int b = get_global_id(0); )==""\n"
144R"==(if (b >= BATCH || it >= N_ITER) return; )==""\n"
145R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)scratch_diff_states; )==""\n"
146R"==(#if DIRECTION_KIND == CONCAT )==""\n"
147R"==(__global DIFF_DATA_T *src )==""\n"
148R"==(= (__global DIFF_DATA_T *)src_base + DIFF_DST_L_OFF(it, b, 0); )==""\n"
149R"==(for (int s = 0; s < DHC; s++) { )==""\n"
150R"==(dst[OFF_SCRATCH_DIFF_STATES(N_LAYER, 0, N_STATES, it, b, s)] = src[s]; )==""\n"
151R"==(dst[OFF_SCRATCH_DIFF_STATES( )==""\n"
152R"==(N_LAYER, 1, N_STATES, N_ITER - it - 1, b, s)] )==""\n"
153R"==(= src[DHC + s]; )==""\n"
154R"==(} )==""\n"
155R"==(#elif DIRECTION_KIND == SUM )==""\n"
156R"==(__global DIFF_DATA_T *src )==""\n"
157R"==(= (__global DIFF_DATA_T *)src_base + DIFF_DST_L_OFF(it, b, 0); )==""\n"
158R"==(for (int s = 0; s < DHC; s++) { )==""\n"
159R"==(dst[OFF_SCRATCH_DIFF_STATES(N_LAYER, 0, N_STATES, it, b, s)] = src[s]; )==""\n"
160R"==(dst[OFF_SCRATCH_DIFF_STATES( )==""\n"
161R"==(N_LAYER, 1, N_STATES, N_ITER - it - 1, b, s)] )==""\n"
162R"==(= src[s]; )==""\n"
163R"==(} )==""\n"
164R"==(#elif DIRECTION_KIND == L2R )==""\n"
165R"==(__global DIFF_DATA_T *src )==""\n"
166R"==(= (__global DIFF_DATA_T *)src_base + DIFF_DST_L_OFF(it, b, 0); )==""\n"
167R"==(for (int s = 0; s < DHC; s++) { )==""\n"
168R"==(dst[OFF_SCRATCH_DIFF_STATES(N_LAYER, 0, N_STATES, it, b, s)] = src[s]; )==""\n"
169R"==(} )==""\n"
170R"==(#elif DIRECTION_KIND == R2L )==""\n"
171R"==(__global DIFF_DATA_T *src = (__global DIFF_DATA_T *)src_base )==""\n"
172R"==(+ DIFF_DST_L_OFF(N_ITER - it - 1, b, 0); )==""\n"
173R"==(for (int s = 0; s < DHC; s++) { )==""\n"
174R"==(dst[OFF_SCRATCH_DIFF_STATES(N_LAYER, 0, N_STATES, it, b, s)] = src[s]; )==""\n"
175R"==(} )==""\n"
176R"==(#else )==""\n"
177R"==(#error "Unsupported direction_kind" )==""\n"
178R"==(#endif )==""\n"
179R"==(#endif )==""\n"
180R"==(} )==""\n"
181R"==(__kernel void ref_rnn_copy_init_iter(__global char *ws, __global char *src_base, )==""\n"
182R"==(__global char *src_c_base, __global char *scratch_diff_states )==""\n"
183R"==(#if IS_FWD )==""\n"
184R"==(, )==""\n"
185R"==(const float shift, const float scale, const int quantize )==""\n"
186R"==(#endif )==""\n"
187R"==() { )==""\n"
188R"==(const int s = get_global_id(0); )==""\n"
189R"==(const int b = get_global_id(1); )==""\n"
190R"==(const int lay = get_global_id(2) / N_DIR; )==""\n"
191R"==(const int dir = get_global_id(2) % N_DIR; )==""\n"
192R"==(#if IS_FWD )==""\n"
193R"==(__global INPUT_DATA_T *src = (__global INPUT_DATA_T *)(src_base); )==""\n"
194R"==(__global WS_STATE_DATA_T *dst )==""\n"
195R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==""\n"
196R"==(if (s < SIC) )==""\n"
197R"==(dst[OFF_WS_STATE(lay + 1, dir, 0, b, s)] = src_base )==""\n"
198R"==(? (quantize ? TO_WS_STATE( )==""\n"
199R"==(src[SRC_I_OFF(lay, dir, b, s)] * scale + shift) )==""\n"
200R"==(: src[SRC_I_OFF(lay, dir, b, s)]) )==""\n"
201R"==(: TO_WS_STATE(0.0f); )==""\n"
202R"==(#if WITH_SRC_ITER_C )==""\n"
203R"==(__global AUX_DATA_T *src_c = (__global AUX_DATA_T *)(src_c_base); )==""\n"
204R"==(__global AUX_DATA_T *dst_c )==""\n"
205R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET); )==""\n"
206R"==(if (s < DHC) )==""\n"
207R"==(dst_c[OFF_WS_STATE(lay + 1, dir, 0, b, s)] = src_c_base )==""\n"
208R"==(? src_c[SRC_I_C_OFF(lay, dir, b, s)] )==""\n"
209R"==(: TO_WS_STATE(0.0f); )==""\n"
210R"==(#endif )==""\n"
211R"==(#else )==""\n"
212R"==(__global DIFF_DATA_T *src = (__global DIFF_DATA_T *)(src_base); )==""\n"
213R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)scratch_diff_states; )==""\n"
214R"==(if (s < DHC) )==""\n"
215R"==(dst[OFF_SCRATCH_DIFF_STATES(lay, dir, 0, N_ITER, b, s)] )==""\n"
216R"==(= src_base ? src[DIFF_DST_I_OFF(lay, dir, b, s)] : 0.0f; )==""\n"
217R"==(#if WITH_DST_ITER_C )==""\n"
218R"==(__global DIFF_DATA_T *src_c = (__global DIFF_DATA_T *)(src_c_base); )==""\n"
219R"==(if (s < DHC) )==""\n"
220R"==(dst[OFF_SCRATCH_DIFF_STATES(lay, dir, 1, N_ITER, b, s)] )==""\n"
221R"==(= src_c_base ? src_c[DIFF_DST_I_C_OFF(lay, dir, b, s)] : 0.0f; )==""\n"
222R"==(#endif )==""\n"
223R"==(#endif )==""\n"
224R"==(} )==""\n"
225R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==""\n"
226R"==(ref_rnn_copy_res_layer(__global char *ws, __global char *dst_base, )==""\n"
227R"==(__global char *scratch_diff_states, int lr, int rl )==""\n"
228R"==(#if IS_FWD )==""\n"
229R"==(, )==""\n"
230R"==(const float shift, const float scale, const int dequantize )==""\n"
231R"==(#endif )==""\n"
232R"==() { )==""\n"
233R"==(const int it = get_global_id(2); )==""\n"
234R"==(const int b = get_global_id(1); )==""\n"
235R"==(const int s = get_global_id(0); )==""\n"
236R"==(#if IS_FWD )==""\n"
237R"==(if (s >= DHC || b >= BATCH || it >= N_ITER) return; )==""\n"
238R"==(__global WS_STATE_DATA_T *src )==""\n"
239R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==""\n"
240R"==(__global DST_DATA_T *dst = (__global DST_DATA_T *)(dst_base); )==""\n"
241R"==(int dir = 0; )==""\n"
242R"==(if (lr) { )==""\n"
243R"==(bool dequantize_at_copy = dequantize && DIRECTION_KIND != SUM; )==""\n"
244R"==(dst[DST_L_OFF(it, b, dir * DHC + s)] = dequantize_at_copy )==""\n"
245R"==(? TO_DST(((float)src[OFF_WS_STATE(N_LAYER, dir, it + 1, b, s)] )==""\n"
246R"==(- shift) )==""\n"
247R"==(/ scale) )==""\n"
248R"==(: src[OFF_WS_STATE(N_LAYER, dir, it + 1, b, s)]; )==""\n"
249R"==(dir = 1; )==""\n"
250R"==(} )==""\n"
251R"==(if (rl) { )==""\n"
252R"==(#if DIRECTION_KIND == SUM )==""\n"
253R"==(if (dequantize) { )==""\n"
254R"==(float val )==""\n"
255R"==(= (float)src[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)] )==""\n"
256R"==(+ dst[DST_L_OFF(it, b, s)]; )==""\n"
257R"==(val = min(max(val, 0.f), 255.f); )==""\n"
258R"==(dst[DST_L_OFF(it, b, s)] = TO_DST((val - 2 * shift) / scale); )==""\n"
259R"==(} else { )==""\n"
260R"==(#if defined(SRC_DT_U8) && defined(DST_DT_U8) )==""\n"
261R"==(dst[DST_L_OFF(it, b, s)] = convert_uchar_sat( )==""\n"
262R"==(convert_short( )==""\n"
263R"==(src[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)]) )==""\n"
264R"==(+ convert_short(dst[DST_L_OFF(it, b, s)])); )==""\n"
265R"==(#else )==""\n"
266R"==(ACC_DATA_T temp_src = DST_TO_REF(dst[DST_L_OFF(it, b, s)]); )==""\n"
267R"==(temp_src += DST_TO_REF( )==""\n"
268R"==(src[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)]); )==""\n"
269R"==(dst[DST_L_OFF(it, b, s)] = REF_TO_DST(temp_src); )==""\n"
270R"==(#endif )==""\n"
271R"==(} )==""\n"
272R"==(#else )==""\n"
273R"==(dst[DST_L_OFF(it, b, dir * DHC + s)] = dequantize )==""\n"
274R"==(? TO_DST(((float)src[OFF_WS_STATE( )==""\n"
275R"==(N_LAYER, dir, N_ITER - it, b, s)] )==""\n"
276R"==(- shift) )==""\n"
277R"==(/ scale) )==""\n"
278R"==(: src[OFF_WS_STATE(N_LAYER, dir, N_ITER - it, b, s)]; )==""\n"
279R"==(#endif )==""\n"
280R"==(} )==""\n"
281R"==(#else )==""\n"
282R"==(if (s >= SLC || b >= BATCH || it >= N_ITER) return; )==""\n"
283R"==(__global DIFF_DATA_T *src = (__global DIFF_DATA_T *)(scratch_diff_states); )==""\n"
284R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)(dst_base); )==""\n"
285R"==(int dir = 0; )==""\n"
286R"==(#if DIRECTION_KIND == R2L )==""\n"
287R"==(const int iter = N_ITER - 1 - it; )==""\n"
288R"==(#else )==""\n"
289R"==(const int iter = it; )==""\n"
290R"==(#endif )==""\n"
291R"==(DIFF_DATA_T res = src[OFF_SCRATCH_DIFF_STATES(0, 0, N_STATES, it, b, s)]; )==""\n"
292R"==(#if N_DIR > 1 )==""\n"
293R"==(res += src[OFF_SCRATCH_DIFF_STATES(0, 1, N_STATES, N_ITER - 1 - it, b, s)]; )==""\n"
294R"==(#endif )==""\n"
295R"==(dst[DIFF_SRC_L_OFF(iter, b, dir * SLC + s)] = res; )==""\n"
296R"==(#endif )==""\n"
297R"==(} )==""\n"
298R"==(__kernel void ref_rnn_copy_res_iter(__global char *ws, __global char *dst_base, )==""\n"
299R"==(__global char *dst_c_base, __global char *scratch_diff_states )==""\n"
300R"==(#if IS_FWD )==""\n"
301R"==(, )==""\n"
302R"==(const float shift, const float scale, const int dequantize )==""\n"
303R"==(#endif )==""\n"
304R"==() { )==""\n"
305R"==(const int s = get_global_id(0); )==""\n"
306R"==(const int b = get_global_id(1); )==""\n"
307R"==(const int lay = get_global_id(2) / N_DIR; )==""\n"
308R"==(const int dir = get_global_id(2) % N_DIR; )==""\n"
309R"==(#if IS_FWD )==""\n"
310R"==(__global WS_STATE_DATA_T *src )==""\n"
311R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==""\n"
312R"==(__global OUTPUT_DATA_T *dst = (__global OUTPUT_DATA_T *)(dst_base); )==""\n"
313R"==(if (dst_base && s < DHC) { )==""\n"
314R"==(dst[DST_I_OFF(lay, dir, b, s)] = dequantize )==""\n"
315R"==(? TO_OUTPUT( )==""\n"
316R"==(((float)src[OFF_WS_STATE(lay + 1, dir, N_ITER, b, s)] )==""\n"
317R"==(- shift) )==""\n"
318R"==(/ scale) )==""\n"
319R"==(: TO_OUTPUT(src[OFF_WS_STATE(lay + 1, dir, N_ITER, b, s)]); )==""\n"
320R"==(} )==""\n"
321R"==(#if WITH_DST_ITER_C )==""\n"
322R"==(__global AUX_DATA_T *src_c )==""\n"
323R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET); )==""\n"
324R"==(__global AUX_DATA_T *dst_c = (__global AUX_DATA_T *)(dst_c_base); )==""\n"
325R"==(if (dst_c_base && s < DHC) { )==""\n"
326R"==(dst_c[DST_I_C_OFF(lay, dir, b, s)] )==""\n"
327R"==(= src_c[OFF_WS_STATE(lay + 1, dir, N_ITER, b, s)]; )==""\n"
328R"==(} )==""\n"
329R"==(#endif )==""\n"
330R"==(#else )==""\n"
331R"==(__global DIFF_DATA_T *src = (__global DIFF_DATA_T *)(scratch_diff_states); )==""\n"
332R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)(dst_base); )==""\n"
333R"==(__global DIFF_DATA_T *dst_c = (__global DIFF_DATA_T *)(dst_c_base); )==""\n"
334R"==(if (dst_base && s < SIC) { )==""\n"
335R"==(dst[DIFF_SRC_I_OFF(lay, dir, b, s)] )==""\n"
336R"==(= src[OFF_SCRATCH_DIFF_STATES(lay, dir, 0, 0, b, s)]; )==""\n"
337R"==(} )==""\n"
338R"==(#if WITH_SRC_ITER_C )==""\n"
339R"==(if (dst_base && s < DHC) { )==""\n"
340R"==(dst_c[DIFF_SRC_I_C_OFF(lay, dir, b, s)] )==""\n"
341R"==(= src[OFF_SCRATCH_DIFF_STATES(lay, dir, 1, 0, b, s)]; )==""\n"
342R"==(} )==""\n"
343R"==(#endif )==""\n"
344R"==(#endif )==""\n"
345R"==(} )==""\n"
346R"==(__kernel void ref_rnn_ws_set( )==""\n"
347R"==(__global char *ws, OFFTYPE ws_offset, float val, int ws_part) { )==""\n"
348R"==(if (ws_part == WS_C_STATES || ws_part == WS_BIAS) { )==""\n"
349R"==(__global DIFF_DATA_T *dst = (__global DIFF_DATA_T *)(ws + ws_offset); )==""\n"
350R"==(dst[get_global_id(0)] = CONVERT_DATA_T(val); )==""\n"
351R"==(} else if (ws_part == WS_GATES) { )==""\n"
352R"==(__global ACC_DATA_T *dst = (__global ACC_DATA_T *)(ws + ws_offset); )==""\n"
353R"==(dst[get_global_id(0)] = TO_ACC(val); )==""\n"
354R"==(} else { )==""\n"
355R"==(__global WS_STATE_DATA_T *dst )==""\n"
356R"==(= (__global WS_STATE_DATA_T *)(ws + ws_offset); )==""\n"
357R"==(dst[get_global_id(0)] = TO_WS_STATE(val); )==""\n"
358R"==(} )==""\n"
359R"==(} )==""\n"
360R"==(#if DEBUGPRINT )==""\n"
361R"==(__kernel void ref_rnn_ws_print(const __global char *ws) { )==""\n"
362R"==({ )==""\n"
363R"==(__global ACC_DATA_T *wt = (__global ACC_DATA_T *)(ws + WS_GATES_OFFSET); )==""\n"
364R"==(printf("ws_gates: off %d\n", WS_GATES_OFFSET); )==""\n"
365R"==(printf("[lay,dir,iter,batch]\n"); )==""\n"
366R"==(for_(int j = 0; j < N_LAYER; j++) )==""\n"
367R"==(for_(int dir = 0; dir < N_DIR; dir++) )==""\n"
368R"==(for_(int i = 0; i < N_ITER; i++) )==""\n"
369R"==(for (int b = 0; b < BATCH; b++) { )==""\n"
370R"==(printf("[%d,%d,%d,%d]: ", j, dir, i, b); )==""\n"
371R"==(for_(int g = 0; g < N_GATES; g++) )==""\n"
372R"==(for (int s = 0; s < DHC; s++) { )==""\n"
373R"==(printf(" %f", )==""\n"
374R"==(SRC_TO_REF(*(wt + OFF_WS_GATES(j, dir, i, b, g, s)))); )==""\n"
375R"==(} )==""\n"
376R"==(printf("\n"); )==""\n"
377R"==(} )==""\n"
378R"==(} )==""\n"
379R"==({ )==""\n"
380R"==(__global WS_STATE_DATA_T *wt )==""\n"
381R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET); )==""\n"
382R"==(printf("ws_states (H): off %d\n", WS_STATES_OFFSET); )==""\n"
383R"==(printf("[lay,dir,iter]\n"); )==""\n"
384R"==(for_(int j = 0; j < N_LAYER + 1; j++) )==""\n"
385R"==(for_(int dir = 0; dir < N_DIR; dir++) )==""\n"
386R"==(for (int i = 1; i < N_ITER + 1; i++) { )==""\n"
387R"==(printf("[%d,%d,%d] : ", j, dir, i); )==""\n"
388R"==(for_(int b = 0; b < BATCH; b++) )==""\n"
389R"==(for (int s = 0; s < WIC; s++) { )==""\n"
390R"==(printf(" %f", )==""\n"
391R"==(SRC_TO_REF(*(wt + OFF_WS_STATE(j, dir, i, b, s)))); )==""\n"
392R"==(} )==""\n"
393R"==(printf("\n"); )==""\n"
394R"==(} )==""\n"
395R"==(} )==""\n"
396R"==(#if IS_TRAINING && CELL_KIND == LBR_GRU )==""\n"
397R"==({ )==""\n"
398R"==(__global ACC_DATA_T *wt )==""\n"
399R"==(= (__global ACC_DATA_T *)(ws + WS_GRID_COMP_OFFSET); )==""\n"
400R"==(printf("ws_grid: off %d\n", WS_GRID_COMP_OFFSET); )==""\n"
401R"==(printf("[lay,dir,iter,batch]\n"); )==""\n"
402R"==(for_(int j = 0; j < N_LAYER; j++) )==""\n"
403R"==(for_(int dir = 0; dir < N_DIR; dir++) )==""\n"
404R"==(for_(int i = 0; i < N_ITER; i++) )==""\n"
405R"==(for (int b = 0; b < BATCH; b++) { )==""\n"
406R"==(printf("[%d,%d,%d,%d]: ", j, dir, i, b); )==""\n"
407R"==(for (int s = 0; s < DHC; s++) { )==""\n"
408R"==(printf(" %f", *(wt + OFF_WS_GRID_OFFSET(j, dir, i, b, s))); )==""\n"
409R"==(} )==""\n"
410R"==(printf("\n"); )==""\n"
411R"==(} )==""\n"
412R"==(} )==""\n"
413R"==(#endif )==""\n"
414R"==(#if IS_FWD && CELL_KIND == VANILLA_LSTM )==""\n"
415R"==({ )==""\n"
416R"==(__global AUX_DATA_T *wt )==""\n"
417R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET); )==""\n"
418R"==(printf("ws_states (C): off %d\n", WS_C_STATE_OFFSET); )==""\n"
419R"==(printf("[lay,dir,iter]\n"); )==""\n"
420R"==(for_(int j = 0; j < N_LAYER; j++) )==""\n"
421R"==(for_(int dir = 0; dir < N_DIR; dir++) )==""\n"
422R"==(for (int i = 0; i < N_ITER + 1; i++) { )==""\n"
423R"==(printf("[%d,%d,%d] : ", j, dir, i); )==""\n"
424R"==(for_(int b = 0; b < BATCH; b++) )==""\n"
425R"==(for (int s = 0; s < WIC; s++) { )==""\n"
426R"==(printf(" %f", *(wt + OFF_WS_STATE(j, dir, i, b, s))); )==""\n"
427R"==(} )==""\n"
428R"==(printf("\n"); )==""\n"
429R"==(} )==""\n"
430R"==(} )==""\n"
431R"==(#endif )==""\n"
432R"==(#if COPY_BIAS )==""\n"
433R"==({ )==""\n"
434R"==(__global AUX_DATA_T *wt = (__global AUX_DATA_T *)(ws + WS_BIAS_OFFSET); )==""\n"
435R"==(printf("ws_bias: off %d\n", WS_BIAS_OFFSET); )==""\n"
436R"==(printf("[lay,dir]\n"); )==""\n"
437R"==(for_(int j = 0; j < N_LAYER; j++) )==""\n"
438R"==(for_(int dir = 0; dir < N_DIR; dir++) )==""\n"
439R"==({ )==""\n"
440R"==(printf("[%d,%d] : ", j, dir); )==""\n"
441R"==(for_(int nb = 0; nb < N_BIAS; nb++) )==""\n"
442R"==(for (int dhc = 0; dhc < DHC; dhc++) { )==""\n"
443R"==(printf(" %f", *(wt + OFF_WS_BIAS(j, dir, nb, dhc))); )==""\n"
444R"==(} )==""\n"
445R"==(printf("\n"); )==""\n"
446R"==(} )==""\n"
447R"==(} )==""\n"
448R"==(#endif )==""\n"
449R"==(} )==""\n"
450R"==(#endif )==""\n"
451R"==(__kernel void ref_rnn_bias_prepare(__global char *ws, __global float *scales, )==""\n"
452R"==(__global char *wei_layer, __global char *wei_iter, __global float *bias, )==""\n"
453R"==(float data_shift, float data_scale) { )==""\n"
454R"==(#if COPY_BIAS )==""\n"
455R"==(const int dhc = get_global_id(0); )==""\n"
456R"==(const int nbias = get_global_id(1); )==""\n"
457R"==(const int layer = get_global_id(2) / N_DIR; )==""\n"
458R"==(const int dir = get_global_id(2) % N_DIR; )==""\n"
459R"==(__global float *ws_bias = (__global float *)(ws + WS_BIAS_OFFSET); )==""\n"
460R"==(const float wei_scale )==""\n"
461R"==(#if WEI_QPARAM_MASK )==""\n"
462R"==(= scales[nbias * DHC + dhc]; )==""\n"
463R"==(#else )==""\n"
464R"==(= scales[0]; )==""\n"
465R"==(#endif )==""\n"
466R"==(#define COMP_OFF(i0, i1, i2, i3) \ )==""\n"
467R"==(((((i0) * (N_DIR) + (i1)) * (N_BIAS) + (i2)) * (DHC) + (i3)) )==""\n"
468R"==(#define COMP_WEI_LAYER_OFF (WEI_L_D0 * WEI_L_S0) )==""\n"
469R"==(#define COMP_WEI_ITER_OFF (WEI_I_D0 * WEI_I_S0) )==""\n"
470R"==(__global char *temp = (__global char *)(wei_iter + COMP_WEI_ITER_OFF); )==""\n"
471R"==(__global float *wei_iter_comp )==""\n"
472R"==(= (__global float *)(((unsigned long)temp + (sizeof(float) - 1)) )==""\n"
473R"==(& -sizeof(float)); )==""\n"
474R"==(temp = (__global char *)(wei_layer + COMP_WEI_LAYER_OFF); )==""\n"
475R"==(__global float *wei_layer_comp )==""\n"
476R"==(= (__global float *)(((unsigned long)temp + (sizeof(float) - 1)) )==""\n"
477R"==(& -sizeof(float)); )==""\n"
478R"==(const int off = COMP_OFF(layer, dir, nbias, dhc); )==""\n"
479R"==(const float comp = wei_layer_comp[off] + wei_iter_comp[off]; )==""\n"
480R"==(ws_bias[OFF_WS_BIAS(layer, dir, nbias, dhc)] )==""\n"
481R"==(= bias[BIAS_OFF(layer, dir, nbias, dhc)] )==""\n"
482R"==(- comp * data_shift / (wei_scale * data_scale); )==""\n"
483R"==(#endif )==""\n"
484R"==(} )==""\n"
485R"==(#if IS_INT8 && CELL_KIND == VANILLA_LSTM )==""\n"
486R"==(WS_STATE_DATA_T q_d(float f, float data_scale, float data_shift) { )==""\n"
487R"==(float qf = f * data_scale + data_shift; )==""\n"
488R"==(return TO_WS_STATE(qf); )==""\n"
489R"==(} )==""\n"
490R"==(float deq_w(ACC_DATA_T s, int gate, int j, __global float *scales, )==""\n"
491R"==(float data_scale) { )==""\n"
492R"==(#if WEI_QPARAM_MASK )==""\n"
493R"==(float wei_scale = scales[gate * DHC + j]; )==""\n"
494R"==(#else )==""\n"
495R"==(float wei_scale = scales[0]; )==""\n"
496R"==(#endif )==""\n"
497R"==(return (float)(s) / (wei_scale * data_scale); )==""\n"
498R"==(} )==""\n"
499R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==""\n"
500R"==(ref_rnn_elemwise_fwd(int dir, int lay, int iter, __global char *ws, )==""\n"
501R"==(__global char *scr_gates, __global float *scales, )==""\n"
502R"==(__global float *bias_base, float alpha, float data_shift, )==""\n"
503R"==(float data_scale, __global float *tm_scales, float tm_cscale) { )==""\n"
504R"==(const int i = get_global_id(1); )==""\n"
505R"==(const int j = get_global_id(0); )==""\n"
506R"==(if (j >= DHC || i >= BATCH) return; )==""\n"
507R"==(const __global float *c_states_tm1_l )==""\n"
508R"==(= (__global float *)(ws + WS_C_STATE_OFFSET) )==""\n"
509R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==""\n"
510R"==(__global float *ws_bias = (__global float *)(ws + WS_BIAS_OFFSET); )==""\n"
511R"==(__global ACC_DATA_T *ws_gates )==""\n"
512R"==(= (__global ACC_DATA_T *)(ws + WS_GATES_OFFSET) )==""\n"
513R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==""\n"
514R"==(__global ACC_DATA_T *scratch_gates = (__global ACC_DATA_T *)(scr_gates) )==""\n"
515R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==""\n"
516R"==(__global WS_STATE_DATA_T *h_states_t_l )==""\n"
517R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==""\n"
518R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==""\n"
519R"==(__global float *c_states_t_l = (__global float *)(ws + WS_C_STATE_OFFSET) )==""\n"
520R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==""\n"
521R"==(float G0 = logistic_fwd_tm(deq_w(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)], )==""\n"
522R"==(0, j, scales, data_scale) )==""\n"
523R"==(+ ws_bias[OFF_WS_BIAS(lay, dir, 0, j)], )==""\n"
524R"==(tm_scales[0]); )==""\n"
525R"==(float G1 = logistic_fwd_tm(deq_w(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)], )==""\n"
526R"==(1, j, scales, data_scale) )==""\n"
527R"==(+ ws_bias[OFF_WS_BIAS(lay, dir, 1, j)], )==""\n"
528R"==(tm_scales[1]); )==""\n"
529R"==(float G2 = tanh_fwd_tm(deq_w(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)], 2, j, )==""\n"
530R"==(scales, data_scale) )==""\n"
531R"==(+ ws_bias[OFF_WS_BIAS(lay, dir, 2, j)], )==""\n"
532R"==(tm_scales[2]); )==""\n"
533R"==(float G3 = logistic_fwd_tm(deq_w(scratch_gates[CELL_SCRATCH_MEM(i, 3, j)], )==""\n"
534R"==(3, j, scales, data_scale) )==""\n"
535R"==(+ ws_bias[OFF_WS_BIAS(lay, dir, 3, j)], )==""\n"
536R"==(tm_scales[3]); )==""\n"
537R"==(float tmp = G1 * c_states_tm1_l[CELL_WS_STATE(i, j)] + G0 * G2; )==""\n"
538R"==(h_states_t_l[CELL_WS_STATE(i, j)] )==""\n"
539R"==(= q_d(G3 * tanh_fwd_tm(tmp, tm_cscale), data_scale, data_shift); )==""\n"
540R"==(c_states_t_l[CELL_WS_STATE(i, j)] = tmp; )==""\n"
541R"==(} )==""\n"
542R"==(#else )==""\n"
543R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==""\n"
544R"==(ref_rnn_elemwise_fwd( )==""\n"
545R"==(int dir, int lay, int iter, __global char *ws, __global char *scr_gates, )==""\n"
546R"==(__global AUX_DATA_T *bias_base, float alpha, __global float *tm_scales, )==""\n"
547R"==(#if CELL_KIND == VANILLA_LSTM || CELL_KIND == VANILLA_RNN )==""\n"
548R"==(float tm_cscale )==""\n"
549R"==(#elif CELL_KIND == LBR_GRU )==""\n"
550R"==(__global char *scr_cell )==""\n"
551R"==(#elif CELL_KIND == VANILLA_GRU )==""\n"
552R"==(int n_part )==""\n"
553R"==(#endif )==""\n"
554R"==() { )==""\n"
555R"==(const int i = get_global_id(1); )==""\n"
556R"==(const int j = get_global_id(0); )==""\n"
557R"==(if (j >= DHC || i >= BATCH) return; )==""\n"
558R"==(const __global AUX_DATA_T *c_states_tm1_l )==""\n"
559R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET) )==""\n"
560R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==""\n"
561R"==(const __global AUX_DATA_T *bias = bias_base + BIAS_OFF(lay, dir, 0, 0); )==""\n"
562R"==(__global AUX_DATA_T *ws_gates )==""\n"
563R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==""\n"
564R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==""\n"
565R"==(__global AUX_DATA_T *scratch_gates = (__global AUX_DATA_T *)(scr_gates) )==""\n"
566R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==""\n"
567R"==(__global WS_STATE_DATA_T *h_states_t_l )==""\n"
568R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==""\n"
569R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==""\n"
570R"==(#if CELL_KIND == VANILLA_LSTM )==""\n"
571R"==(__global AUX_DATA_T *c_states_t_l )==""\n"
572R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET) )==""\n"
573R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==""\n"
574R"==(float g_i = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] )==""\n"
575R"==(+ bias[OFF_KER_BIAS(0, j)], )==""\n"
576R"==(tm_scales[0]); )==""\n"
577R"==(float g_f = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] )==""\n"
578R"==(+ bias[OFF_KER_BIAS(1, j)], )==""\n"
579R"==(tm_scales[1]); )==""\n"
580R"==(float g_z = tanh_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] )==""\n"
581R"==(+ bias[OFF_KER_BIAS(2, j)], )==""\n"
582R"==(tm_scales[2]); )==""\n"
583R"==(float g_o = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 3, j)] )==""\n"
584R"==(+ bias[OFF_KER_BIAS(3, j)], )==""\n"
585R"==(tm_scales[3]); )==""\n"
586R"==(#if IS_TRAINING )==""\n"
587R"==(ws_gates[CELL_WS_GATES(i, 0, j)] = g_i; )==""\n"
588R"==(ws_gates[CELL_WS_GATES(i, 1, j)] = g_f; )==""\n"
589R"==(ws_gates[CELL_WS_GATES(i, 2, j)] = g_z; )==""\n"
590R"==(ws_gates[CELL_WS_GATES(i, 3, j)] = g_o; )==""\n"
591R"==(#endif )==""\n"
592R"==(float Ct = g_f * c_states_tm1_l[CELL_WS_STATE(i, j)] + g_i * g_z; )==""\n"
593R"==(float Ht = g_o * tanh_fwd_tm(Ct, tm_cscale); )==""\n"
594R"==(h_states_t_l[CELL_WS_STATE(i, j)] = TO_INPUT(Ht); )==""\n"
595R"==(c_states_t_l[CELL_WS_STATE(i, j)] = Ct; )==""\n"
596R"==(#elif CELL_KIND == VANILLA_RNN )==""\n"
597R"==(float g = activation_fwd((float)scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] )==""\n"
598R"==(+ bias[OFF_KER_BIAS(0, j)], )==""\n"
599R"==(#if IS_TESTMODE )==""\n"
600R"==(tm_scales[0], 0); )==""\n"
601R"==(#else )==""\n"
602R"==(alpha, 0); )==""\n"
603R"==(#endif )==""\n"
604R"==(#if IS_TRAINING )==""\n"
605R"==(ws_gates[CELL_WS_GATES(i, 0, j)] = g; )==""\n"
606R"==(#endif )==""\n"
607R"==(h_states_t_l[CELL_WS_STATE(i, j)] = TO_INPUT(g); )==""\n"
608R"==(#elif CELL_KIND == LBR_GRU )==""\n"
609R"==(__global AUX_DATA_T *scratch_cell = (__global AUX_DATA_T *)(scr_cell); )==""\n"
610R"==(__global WS_STATE_DATA_T *src_iter )==""\n"
611R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==""\n"
612R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==""\n"
613R"==(__global AUX_DATA_T *ws_grid )==""\n"
614R"==(= (__global AUX_DATA_T *)(ws + WS_GRID_COMP_OFFSET) )==""\n"
615R"==(+ OFF_WS_GRID_OFFSET(lay, dir, iter, 0, 0); )==""\n"
616R"==(float Wh_b = (float)scratch_cell[CELL_SCRATCH_MEM(i, 2, j)] )==""\n"
617R"==(+ bias[OFF_KER_BIAS(3, j)]; )==""\n"
618R"==(float G0 = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] )==""\n"
619R"==(+ (float)scratch_cell[CELL_SCRATCH_MEM(i, 0, j)] )==""\n"
620R"==(+ bias[OFF_KER_BIAS(0, j)], )==""\n"
621R"==(tm_scales[0]); )==""\n"
622R"==(float G1 = logistic_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] )==""\n"
623R"==(+ (float)scratch_cell[CELL_SCRATCH_MEM(i, 1, j)] )==""\n"
624R"==(+ bias[OFF_KER_BIAS(1, j)], )==""\n"
625R"==(tm_scales[1]); )==""\n"
626R"==(float G2 = tanh_fwd_tm((float)scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] )==""\n"
627R"==(+ G1 * Wh_b + bias[OFF_KER_BIAS(2, j)], )==""\n"
628R"==(tm_scales[2]); )==""\n"
629R"==(float Ht = G0 * TO_REF(src_iter[CELL_WS_STATE(i, j)]) + (1 - G0) * G2; )==""\n"
630R"==(h_states_t_l[CELL_WS_STATE(i, j)] = TO_INPUT(Ht); )==""\n"
631R"==(#if IS_TRAINING )==""\n"
632R"==(ws_gates[CELL_WS_GATES(i, 0, j)] = G0; )==""\n"
633R"==(ws_gates[CELL_WS_GATES(i, 1, j)] = G1; )==""\n"
634R"==(ws_gates[CELL_WS_GATES(i, 2, j)] = G2; )==""\n"
635R"==(ws_grid[CELL_WS_GRID_COMP(i, j)] = Wh_b; )==""\n"
636R"==(#endif )==""\n"
637R"==(#elif CELL_KIND == VANILLA_GRU )==""\n"
638R"==(__global WS_STATE_DATA_T *src_iter )==""\n"
639R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==""\n"
640R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==""\n"
641R"==(if (n_part == 1) { )==""\n"
642R"==(float G0 = logistic_fwd_tm(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] )==""\n"
643R"==(+ bias[OFF_KER_BIAS(0, j)], )==""\n"
644R"==(tm_scales[0]); )==""\n"
645R"==(float G1 = logistic_fwd_tm(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] )==""\n"
646R"==(+ bias[OFF_KER_BIAS(1, j)], )==""\n"
647R"==(tm_scales[1]); )==""\n"
648R"==(/* TODO from CPU: Can be optimized for fwd_training by using )==""\n"
649R"==(ws_gates instead of scratch_gates in p2 */ )==""\n"
650R"==(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(G0); )==""\n"
651R"==(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] = TO_INPUT(G1); )==""\n"
652R"==(float tmp = TO_REF(src_iter[CELL_WS_STATE(i, j)]); )==""\n"
653R"==(h_states_t_l[CELL_WS_STATE(i, j)] = TO_INPUT(tmp * G1); )==""\n"
654R"==(#if IS_TRAINING )==""\n"
655R"==(ws_gates[CELL_WS_GATES(i, 0, j)] = G0; )==""\n"
656R"==(ws_gates[CELL_WS_GATES(i, 1, j)] = G1; )==""\n"
657R"==(#endif )==""\n"
658R"==(} else if (n_part == 2) { )==""\n"
659R"==(float G0 = TO_REF(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)]); )==""\n"
660R"==(float G2 = tanh_fwd_tm(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] )==""\n"
661R"==(+ bias[OFF_KER_BIAS(2, j)], )==""\n"
662R"==(tm_scales[2]); )==""\n"
663R"==(float tmp = TO_REF(src_iter[CELL_WS_STATE(i, j)]); )==""\n"
664R"==(h_states_t_l[CELL_WS_STATE(i, j)] )==""\n"
665R"==(= TO_INPUT(tmp * G0 + (1.0f - G0) * G2); )==""\n"
666R"==(#if IS_TRAINING )==""\n"
667R"==(ws_gates[CELL_WS_GATES(i, 2, j)] = G2; )==""\n"
668R"==(#endif )==""\n"
669R"==(} )==""\n"
670R"==(#else )==""\n"
671R"==(#error "Wrong Cell Kind" )==""\n"
672R"==(#endif )==""\n"
673R"==(} )==""\n"
674R"==(#endif )==""\n"
675R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==""\n"
676R"==(ref_rnn_elemwise_bwd(int dir, int lay, int iter, __global char *ws, )==""\n"
677R"==(__global char *scr_gates, __global AUX_DATA_T *bias_base, float alpha, )==""\n"
678R"==(__global float *tm_scales, )==""\n"
679R"==(#if CELL_KIND == VANILLA_LSTM || CELL_KIND == VANILLA_RNN )==""\n"
680R"==(float tm_cscale, )==""\n"
681R"==(#elif CELL_KIND == LBR_GRU )==""\n"
682R"==(__global char *scr_gate_r, )==""\n"
683R"==(#elif CELL_KIND == VANILLA_GRU )==""\n"
684R"==(int n_part, __global char *scr_cell, __global char *scratch_dhG1, )==""\n"
685R"==(#endif )==""\n"
686R"==(__global char *diff_states) { )==""\n"
687R"==(const int i = get_global_id(1); )==""\n"
688R"==(const int j = get_global_id(0); )==""\n"
689R"==(if (j >= DHC || i >= BATCH) return; )==""\n"
690R"==(#if CELL_KIND == VANILLA_LSTM )==""\n"
691R"==(__global AUX_DATA_T *ws_gates )==""\n"
692R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==""\n"
693R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==""\n"
694R"==(__global SRC_DATA_T *scratch_gates = (__global SRC_DATA_T *)(scr_gates) )==""\n"
695R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==""\n"
696R"==(__global AUX_DATA_T *c_states_t_l )==""\n"
697R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET) )==""\n"
698R"==(+ OFF_WS_STATE(lay + 1, dir, iter + 1, 0, 0); )==""\n"
699R"==(__global AUX_DATA_T *c_states_tm1_l )==""\n"
700R"==(= (__global AUX_DATA_T *)(ws + WS_C_STATE_OFFSET) )==""\n"
701R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==""\n"
702R"==(__global DIFF_DATA_T *diff_states_t_l = (__global DIFF_DATA_T *)diff_states )==""\n"
703R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter, 0, 0); )==""\n"
704R"==(__global DIFF_DATA_T *diff_states_tp1_l )==""\n"
705R"==(= (__global DIFF_DATA_T *)diff_states )==""\n"
706R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter + 1, 0, 0); )==""\n"
707R"==(__global DIFF_DATA_T *diff_states_t_lp1 )==""\n"
708R"==(= (__global DIFF_DATA_T *)diff_states )==""\n"
709R"==(+ OFF_SCRATCH_DIFF_STATES(lay + 1, dir, 0, iter, 0, 0); )==""\n"
710R"==(float Ct = c_states_t_l[CELL_WS_STATE(i, j)]; )==""\n"
711R"==(float tanhCt = tanh_fwd_tm(Ct, tm_cscale); )==""\n"
712R"==(float dHt = (float)diff_states_tp1_l[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==""\n"
713R"==(+ diff_states_t_lp1[CELL_SCRATCH_DIFF_STATES(N_STATES, i, j)]; )==""\n"
714R"==(float dCt = (float)diff_states_tp1_l[CELL_SCRATCH_DIFF_STATES(1, i, j)] )==""\n"
715R"==(+ one_m_square(tanhCt) * ws_gates[CELL_WS_GATES(i, 3, j)] * dHt; )==""\n"
716R"==(float dG1 = (float)c_states_tm1_l[CELL_WS_STATE(i, j)] * dCt )==""\n"
717R"==(* x_m_square(ws_gates[CELL_WS_GATES(i, 1, j)]); )==""\n"
718R"==(float dG0 = ws_gates[CELL_WS_GATES(i, 2, j)] * dCt )==""\n"
719R"==(* x_m_square(ws_gates[CELL_WS_GATES(i, 0, j)]); )==""\n"
720R"==(float dG3 = tanhCt * dHt * x_m_square(ws_gates[CELL_WS_GATES(i, 3, j)]); )==""\n"
721R"==(float dG2 = ws_gates[CELL_WS_GATES(i, 0, j)] * dCt )==""\n"
722R"==(* one_m_square(ws_gates[CELL_WS_GATES(i, 2, j)]); )==""\n"
723R"==(diff_states_t_l[CELL_SCRATCH_DIFF_STATES(1, i, j)] )==""\n"
724R"==(= dCt * ws_gates[CELL_WS_GATES(i, 1, j)]; )==""\n"
725R"==(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(dG0); )==""\n"
726R"==(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] = TO_INPUT(dG1); )==""\n"
727R"==(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] = TO_INPUT(dG2); )==""\n"
728R"==(scratch_gates[CELL_SCRATCH_MEM(i, 3, j)] = TO_INPUT(dG3); )==""\n"
729R"==(#elif CELL_KIND == LBR_GRU )==""\n"
730R"==(__global SRC_DATA_T *scratch_gates = (__global SRC_DATA_T *)(scr_gates) )==""\n"
731R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==""\n"
732R"==(__global SRC_DATA_T *scratch_gate_r = (__global SRC_DATA_T *)(scr_gate_r); )==""\n"
733R"==(__global AUX_DATA_T *ws_gates )==""\n"
734R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==""\n"
735R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==""\n"
736R"==(__global AUX_DATA_T *ws_grid )==""\n"
737R"==(= (__global AUX_DATA_T *)(ws + WS_GRID_COMP_OFFSET) )==""\n"
738R"==(+ OFF_WS_GRID_OFFSET(lay, dir, iter, 0, 0); )==""\n"
739R"==(__global DIFF_DATA_T *diff_src_iter = (__global DIFF_DATA_T *)diff_states )==""\n"
740R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter, 0, 0); )==""\n"
741R"==(__global DIFF_DATA_T *diff_dst_iter = (__global DIFF_DATA_T *)diff_states )==""\n"
742R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter + 1, 0, 0); )==""\n"
743R"==(__global DIFF_DATA_T *diff_dst_layer = (__global DIFF_DATA_T *)diff_states )==""\n"
744R"==(+ OFF_SCRATCH_DIFF_STATES(lay + 1, dir, 0, iter, 0, 0); )==""\n"
745R"==(__global WS_STATE_DATA_T *src_iter )==""\n"
746R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==""\n"
747R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==""\n"
748R"==(float h = TO_REF(src_iter[CELL_WS_STATE(i, j)]); )==""\n"
749R"==(float Wh_b = ws_grid[CELL_WS_GRID_COMP(i, j)]; )==""\n"
750R"==(float dHt = diff_dst_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==""\n"
751R"==(+ diff_dst_layer[CELL_SCRATCH_DIFF_STATES(N_STATES, i, j)]; )==""\n"
752R"==(float dG0 = (h - ws_gates[CELL_WS_GATES(i, 2, j)]) * dHt )==""\n"
753R"==(* x_m_square(ws_gates[CELL_WS_GATES(i, 0, j)]); )==""\n"
754R"==(float dG2 = (1.0f - ws_gates[CELL_WS_GATES(i, 0, j)]) )==""\n"
755R"==(* one_m_square(ws_gates[CELL_WS_GATES(i, 2, j)]) * dHt; )==""\n"
756R"==(float dG1 = Wh_b * dG2 * x_m_square(ws_gates[CELL_WS_GATES(i, 1, j)]); )==""\n"
757R"==(diff_src_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==""\n"
758R"==(= dHt * ws_gates[CELL_WS_GATES(i, 0, j)]; )==""\n"
759R"==(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(dG0); )==""\n"
760R"==(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] = TO_INPUT(dG1); )==""\n"
761R"==(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] = TO_INPUT(dG2); )==""\n"
762R"==(scratch_gate_r[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(dG0); )==""\n"
763R"==(scratch_gate_r[CELL_SCRATCH_MEM(i, 1, j)] = TO_INPUT(dG1); )==""\n"
764R"==(scratch_gate_r[CELL_SCRATCH_MEM(i, 2, j)] )==""\n"
765R"==(= TO_INPUT(dG2 * ws_gates[CELL_WS_GATES(i, 1, j)]); )==""\n"
766R"==(#elif CELL_KIND == VANILLA_RNN )==""\n"
767R"==(__global AUX_DATA_T *ws_gates )==""\n"
768R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==""\n"
769R"==(+ OFF_WS_GATES(lay, dir, iter, i, 0, j); )==""\n"
770R"==(__global SRC_DATA_T *scratch_gates = (__global SRC_DATA_T *)(scr_gates) )==""\n"
771R"==(+ OFF_SCRATCH_MEM(iter, i, 0, j); )==""\n"
772R"==(__global DIFF_DATA_T *diff_states_t_lp1 )==""\n"
773R"==(= (__global DIFF_DATA_T *)diff_states )==""\n"
774R"==(+ OFF_SCRATCH_DIFF_STATES(lay + 1, dir, N_STATES, iter, i, j); )==""\n"
775R"==(__global DIFF_DATA_T *diff_states_tp1_l )==""\n"
776R"==(= (__global DIFF_DATA_T *)diff_states )==""\n"
777R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter + 1, i, j); )==""\n"
778R"==(const float dH = (float)diff_states_t_lp1[0] + diff_states_tp1_l[0]; )==""\n"
779R"==(float g = ws_gates[0]; )==""\n"
780R"==(#if IS_TESTMODE )==""\n"
781R"==(scratch_gates[0] = TO_INPUT(dH * activation_bwd(g, tm_scales[0], 0.)); )==""\n"
782R"==(#else )==""\n"
783R"==(scratch_gates[0] = TO_INPUT(dH * activation_bwd(g, alpha, 0.)); )==""\n"
784R"==(#endif )==""\n"
785R"==(#elif CELL_KIND == VANILLA_GRU )==""\n"
786R"==(__global SRC_DATA_T *scratch_gates = (__global SRC_DATA_T *)(scr_gates) )==""\n"
787R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==""\n"
788R"==(__global AUX_DATA_T *ws_gates )==""\n"
789R"==(= (__global AUX_DATA_T *)(ws + WS_GATES_OFFSET) )==""\n"
790R"==(+ OFF_WS_GATES(lay, dir, iter, 0, 0, 0); )==""\n"
791R"==(__global DIFF_DATA_T *diff_src_iter = (__global DIFF_DATA_T *)diff_states )==""\n"
792R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter, 0, 0); )==""\n"
793R"==(__global DIFF_DATA_T *diff_dst_iter = (__global DIFF_DATA_T *)diff_states )==""\n"
794R"==(+ OFF_SCRATCH_DIFF_STATES(lay, dir, 0, iter + 1, 0, 0); )==""\n"
795R"==(__global DIFF_DATA_T *diff_dst_layer = (__global DIFF_DATA_T *)diff_states )==""\n"
796R"==(+ OFF_SCRATCH_DIFF_STATES(lay + 1, dir, 0, iter, 0, 0); )==""\n"
797R"==(__global WS_STATE_DATA_T *src_iter )==""\n"
798R"==(= (__global WS_STATE_DATA_T *)(ws + WS_STATES_OFFSET) )==""\n"
799R"==(+ OFF_WS_STATE(lay + 1, dir, iter, 0, 0); )==""\n"
800R"==(float h = TO_REF(src_iter[CELL_WS_STATE(i, j)]); )==""\n"
801R"==(if (n_part == 1) { )==""\n"
802R"==(float dHt = diff_dst_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==""\n"
803R"==(+ diff_dst_layer[CELL_SCRATCH_DIFF_STATES(N_STATES, i, j)]; )==""\n"
804R"==(float dG2 = (1.0f - ws_gates[CELL_WS_GATES(i, 0, j)]) * dHt )==""\n"
805R"==(* one_m_square(ws_gates[CELL_WS_GATES(i, 2, j)]); )==""\n"
806R"==(float dG0 = (h - ws_gates[CELL_WS_GATES(i, 2, j)]) * dHt )==""\n"
807R"==(* x_m_square(ws_gates[CELL_WS_GATES(i, 0, j)]); )==""\n"
808R"==(diff_src_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==""\n"
809R"==(= dHt * ws_gates[CELL_WS_GATES(i, 0, j)]; )==""\n"
810R"==(scratch_gates[CELL_SCRATCH_MEM(i, 0, j)] = TO_INPUT(dG0); )==""\n"
811R"==(scratch_gates[CELL_SCRATCH_MEM(i, 2, j)] = TO_INPUT(dG2); )==""\n"
812R"==(} else if (n_part == 2) { )==""\n"
813R"==(__global SRC_DATA_T *scratch_cell = (__global SRC_DATA_T *)(scr_cell); )==""\n"
814R"==(__global DIFF_DATA_T *dhG1 = (__global DIFF_DATA_T *)scratch_dhG1; )==""\n"
815R"==(float dG1 = ws_gates[CELL_WS_GATES(i, 1, j)]; )==""\n"
816R"==(diff_src_iter[CELL_SCRATCH_DIFF_STATES(0, i, j)] )==""\n"
817R"==(+= dhG1[OFF_SCRATCH_DHG1(i, j)] * dG1; )==""\n"
818R"==(scratch_gates[CELL_SCRATCH_MEM(i, 1, j)] )==""\n"
819R"==(= TO_INPUT(dhG1[OFF_SCRATCH_DHG1(i, j)] * h * x_m_square(dG1)); )==""\n"
820R"==(scratch_cell[OFF_SCRATCH_CELL(i, j)] = TO_INPUT(dG1 * h); )==""\n"
821R"==(} )==""\n"
822R"==(#else )==""\n"
823R"==(#error "Wrong Cell Kind" )==""\n"
824R"==(#endif )==""\n"
825R"==(} )==""\n"
826R"==(__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) __kernel void )==""\n"
827R"==(ref_rnn_gates_reduction(int dir, int lay, int iter, )==""\n"
828R"==(__global DIFF_DATA_T *diff_bias_base, __global char *scratch_gates, )==""\n"
829R"==(__global char *scratch_cell) { )==""\n"
830R"==(#if !IS_FWD )==""\n"
831R"==(#if USE_SUBGROUP_REDUCTION )==""\n"
832R"==(const int k = get_global_id(1); )==""\n"
833R"==(const int i = get_global_id(2); )==""\n"
834R"==(#else )==""\n"
835R"==(const int k = get_global_id(0); )==""\n"
836R"==(const int i = get_global_id(1); )==""\n"
837R"==(#endif )==""\n"
838R"==(const int n_bias_max = (CELL_KIND == LBR_GRU) ? 4 : N_GATES; )==""\n"
839R"==(if (k >= DHC || i >= n_bias_max) return; )==""\n"
840R"==(__global DIFF_DATA_T *diff_bias )==""\n"
841R"==(= diff_bias_base + DIFF_BIAS_OFF(lay, dir, 0, 0); )==""\n"
842R"==(__global SRC_DATA_T *gates; )==""\n"
843R"==(int i_ = i; )==""\n"
844R"==(if (CELL_KIND == LBR_GRU && i == 3) { )==""\n"
845R"==(gates = (__global SRC_DATA_T *)(scratch_cell); )==""\n"
846R"==(i_ = 2; )==""\n"
847R"==(} else )==""\n"
848R"==(gates = (__global SRC_DATA_T *)(scratch_gates) )==""\n"
849R"==(+ OFF_SCRATCH_MEM(iter, 0, 0, 0); )==""\n"
850R"==(#if USE_SUBGROUP_REDUCTION )==""\n"
851R"==(DIFF_DATA_T result = 0; )==""\n"
852R"==(for (int j = get_local_id(0); j < BATCH; j += SUBGROUP_SIZE) { )==""\n"
853R"==(result += SRC_TO_REF(gates[CELL_SCRATCH_MEM(j, i_, k)]); )==""\n"
854R"==(} )==""\n"
855R"==(diff_bias[i * DHC + k] += sub_group_reduce_add(result); )==""\n"
856R"==(#else )==""\n"
857R"==(for (int j = 0; j < BATCH; j++) { )==""\n"
858R"==(diff_bias[i * DHC + k] += SRC_TO_REF(gates[CELL_SCRATCH_MEM(j, i_, k)]); )==""\n"
859R"==(} )==""\n"
860R"==(#endif )==""\n"
861R"==(#endif )==""\n"
862R"==(} )==""\n"
863R"==()==";
864}
865}
866}
867}