// Embedded OpenCL C program for the reference batch-normalization kernels
// (statistics, reduction, forward and backward passes).
//
// `ref_bnorm_kernel` is one string built by concatenating, for every kernel
// source line, a raw string literal holding that line and a "\n" literal.
// The `//` comments placed between those literals are C++ comments only and
// are NOT part of the kernel text; everything inside the raw literals is
// kernel source, so editing it changes the compiled kernels.
//
// NOTE(review): this looks like build-generated output (presumably produced
// from a ref_bnorm .cl source) - confirm and edit the generator input instead.
// NOTE(review): the embedded license header reads "* http:" with the URL
// missing, apparently truncated during generation; it lies inside an OpenCL
// block comment, so it does not affect kernel compilation.
1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// Kernel text starts on the next line and ends at the empty literal right
// before the terminating `;`.
5const char *ref_bnorm_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
// [kernel] Vector width VECT_DT_N: 8 when MB_BLOCK == 16 (which also defines
// MB16), VECT_SIZE when VECTORIZE_CALC_STATS == 1, otherwise 1.
// VECT_CHAR_TO_INT is the matching convert_int / convert_intN applied to
// workspace bytes.
20R"==(#if MB_BLOCK == 16 )==""\n"
21R"==(#define MB16 )==""\n"
22R"==(#define VECT_DT_N 8 )==""\n"
23R"==(#elif VECTORIZE_CALC_STATS == 1 )==""\n"
24R"==(#define VECT_DT_N VECT_SIZE )==""\n"
25R"==(#else )==""\n"
26R"==(#define VECT_DT_N 1 )==""\n"
27R"==(#endif )==""\n"
28R"==(#if VECT_DT_N == 1 )==""\n"
29R"==(#define VECT_CHAR_TO_INT convert_int )==""\n"
30R"==(#else )==""\n"
31R"==(#define VECT_CHAR_TO_INT CONCAT2(convert_int, VECT_DT_N) )==""\n"
32R"==(#endif )==""\n"
// [kernel] reduce_index(): flattens the (mb, id, ih, iw) coordinates of x[]
// with the REDUCE_DIM_IDX extent treated as 1 (callers zero that coordinate
// before calling). x[1] (the channel) is not used. Compiled only for the
// non-unrolled statistics paths.
33R"==(#if USE_16MB_UNROLL == 0 && (CALCULATE_STATS == 1 || IS_BWD == 1) )==""\n"
34R"==(int reduce_index(int x[5]) { )==""\n"
35R"==(int dim[5] = {MB, IC, ID, IH, IW}; )==""\n"
36R"==(dim[REDUCE_DIM_IDX] = 1; )==""\n"
37R"==(return x[0] * (dim[2] * dim[3] * dim[4]) + x[2] * (dim[3] * dim[4]) )==""\n"
38R"==(+ x[3] * dim[4] + x[4]; )==""\n"
39R"==(} )==""\n"
40R"==(#endif )==""\n"
41R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// [kernel] Forward pass (IS_FWD == 1).
42R"==(#if IS_FWD == 1 )==""\n"
// [kernel] Statistics kernels: calculate_mean / calculate_variance sum along
// REDUCE_DIM and store a partial sum at [reduce_idx * IC + ic]. The
// reduce_mean / reduce_variance kernels further below combine the partials.
43R"==(#if USE_16MB_UNROLL == 0 && CALCULATE_STATS == 1 )==""\n"
// [kernel] Vectorized variants: each step is a sub-group block read covering
// SUB_GROUP_SIZE * VECT_DT_N elements. Lane sums are combined with
// sub_group_reduce_add, and sub-group local id 0 performs the store.
44R"==(#if VECTORIZE_CALC_STATS == 1 )==""\n"
45R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
46R"==(__kernel void calculate_mean(__global DATA_T *src, __global float *mean) { )==""\n"
47R"==(int x[5]; )==""\n"
48R"==(x[0] = GWS_GET_STAT_MB(); )==""\n"
49R"==(x[1] = GWS_GET_STAT_IC(); )==""\n"
50R"==(x[2] = GWS_GET_STAT_ID(); )==""\n"
51R"==(x[3] = GWS_GET_STAT_IH(); )==""\n"
52R"==(x[4] = GWS_GET_STAT_IW(); )==""\n"
53R"==(VECT_FLOAT_T vect_sum = 0; )==""\n"
54R"==(for (int i = 0; i < REDUCE_DIM; i += SUB_GROUP_SIZE * VECT_DT_N) { )==""\n"
55R"==(x[REDUCE_DIM_IDX] = i; )==""\n"
56R"==(int src_off = SRC_OFF(x[0], x[1], x[2], x[3], x[4]); )==""\n"
57R"==(vect_sum += CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T( )==""\n"
58R"==(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)&src[src_off]))); )==""\n"
59R"==(} )==""\n"
60R"==(#if VECT_DT_N == 1 )==""\n"
61R"==(float sum = vect_sum; )==""\n"
62R"==(#else )==""\n"
63R"==(float sum = 0; )==""\n"
64R"==(for (int i = 0; i < VECT_DT_N; ++i) { )==""\n"
65R"==(sum += vect_sum[i]; )==""\n"
66R"==(} )==""\n"
67R"==(#endif )==""\n"
68R"==(x[REDUCE_DIM_IDX] = 0; )==""\n"
69R"==(int reduce_idx = reduce_index(x); )==""\n"
70R"==(float total_sum = sub_group_reduce_add(sum); )==""\n"
71R"==(int local_id = get_sub_group_local_id(); )==""\n"
72R"==(if (local_id == 0) { mean[reduce_idx * IC + x[1]] = total_sum; } )==""\n"
73R"==(} )==""\n"
74R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
75R"==(__kernel void calculate_variance( )==""\n"
76R"==(__global DATA_T *src, __global float *mean, __global float *variance) { )==""\n"
77R"==(int x[5]; )==""\n"
78R"==(x[0] = GWS_GET_STAT_MB(); )==""\n"
79R"==(x[1] = GWS_GET_STAT_IC(); )==""\n"
80R"==(x[2] = GWS_GET_STAT_ID(); )==""\n"
81R"==(x[3] = GWS_GET_STAT_IH(); )==""\n"
82R"==(x[4] = GWS_GET_STAT_IW(); )==""\n"
83R"==(VECT_FLOAT_T mean_tmp = mean[x[1]]; )==""\n"
84R"==(VECT_FLOAT_T vect_sum = 0; )==""\n"
85R"==(for (int i = 0; i < REDUCE_DIM; i += SUB_GROUP_SIZE * VECT_DT_N) { )==""\n"
86R"==(x[REDUCE_DIM_IDX] = i; )==""\n"
87R"==(int src_off = SRC_OFF(x[0], x[1], x[2], x[3], x[4]); )==""\n"
88R"==(VECT_FLOAT_T v0 = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T( )==""\n"
89R"==(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)&src[src_off]))); )==""\n"
90R"==(v0 -= mean_tmp; )==""\n"
91R"==(vect_sum += v0 * v0; )==""\n"
92R"==(} )==""\n"
93R"==(#if VECT_DT_N == 1 )==""\n"
94R"==(float sum = vect_sum; )==""\n"
95R"==(#else )==""\n"
96R"==(float sum = 0; )==""\n"
97R"==(for (int i = 0; i < VECT_DT_N; ++i) { )==""\n"
98R"==(sum += vect_sum[i]; )==""\n"
99R"==(} )==""\n"
100R"==(#endif )==""\n"
101R"==(x[REDUCE_DIM_IDX] = 0; )==""\n"
102R"==(int reduce_idx = reduce_index(x); )==""\n"
103R"==(float total_sum = sub_group_reduce_add(sum); )==""\n"
104R"==(int local_id = get_sub_group_local_id(); )==""\n"
105R"==(if (local_id == 0) { )==""\n"
// [kernel] Variance partials are stored after the IC * (MB*ID*IH*IW /
// REDUCE_DIM) mean partials. reduce_variance reads them back at the same
// offset (presumably the host binds the same scratch buffer to both - confirm).
106R"==(variance += MB * ID * IH * IW * IC / REDUCE_DIM; )==""\n"
107R"==(variance[reduce_idx * IC + x[1]] = total_sum; )==""\n"
108R"==(} )==""\n"
109R"==(} )==""\n"
// [kernel] calculate_mean_variance: the body exists only when
// SKIP_REDUCE_STATS == 1; otherwise the kernel is empty. It computes the mean
// and E[x^2] - mean^2 (clamped at 0) in one pass and writes the final
// mean[ic] / variance[ic] directly, with no separate reduce step.
110R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
111R"==(__kernel void calculate_mean_variance( )==""\n"
112R"==(__global DATA_T *src, __global float *mean, __global float *variance) { )==""\n"
113R"==(#if SKIP_REDUCE_STATS == 1 )==""\n"
114R"==(int x[5]; )==""\n"
115R"==(x[0] = GWS_GET_STAT_MB(); )==""\n"
116R"==(x[1] = GWS_GET_STAT_IC(); )==""\n"
117R"==(x[2] = GWS_GET_STAT_ID(); )==""\n"
118R"==(x[3] = GWS_GET_STAT_IH(); )==""\n"
119R"==(x[4] = GWS_GET_STAT_IW(); )==""\n"
120R"==(VECT_FLOAT_T src_sum = 0; )==""\n"
121R"==(VECT_FLOAT_T src_pow_sum = 0; )==""\n"
122R"==(for (int i = 0; i < REDUCE_DIM; i += SUB_GROUP_SIZE * VECT_DT_N) { )==""\n"
123R"==(x[REDUCE_DIM_IDX] = i; )==""\n"
124R"==(int src_off = SRC_OFF(x[0], x[1], x[2], x[3], x[4]); )==""\n"
125R"==(VECT_FLOAT_T src_vect = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T( )==""\n"
126R"==(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)&src[src_off]))); )==""\n"
127R"==(src_sum += src_vect; )==""\n"
128R"==(src_pow_sum += src_vect * src_vect; )==""\n"
129R"==(} )==""\n"
130R"==(#if VECT_DT_N == 1 )==""\n"
131R"==(float sum = src_sum; )==""\n"
132R"==(float pow_sum = src_pow_sum; )==""\n"
133R"==(#else )==""\n"
134R"==(float sum = 0; )==""\n"
135R"==(float pow_sum = 0; )==""\n"
136R"==(for (int i = 0; i < VECT_DT_N; ++i) { )==""\n"
137R"==(sum += src_sum[i]; )==""\n"
138R"==(pow_sum += src_pow_sum[i]; )==""\n"
139R"==(} )==""\n"
140R"==(#endif )==""\n"
141R"==(x[REDUCE_DIM_IDX] = 0; )==""\n"
142R"==(int reduce_idx = reduce_index(x); )==""\n"
143R"==(float total_sum = sub_group_reduce_add(sum); )==""\n"
144R"==(float total_pow_sum = sub_group_reduce_add(pow_sum); )==""\n"
145R"==(int local_id = get_sub_group_local_id(); )==""\n"
146R"==(if (local_id == 0) { )==""\n"
147R"==(float calc_mean = total_sum / (MB * ID * IH * IW); )==""\n"
148R"==(float calc_variance )==""\n"
149R"==(= total_pow_sum / (MB * ID * IH * IW) - calc_mean * calc_mean; )==""\n"
150R"==(mean[x[1]] = calc_mean; )==""\n"
151R"==(variance[x[1]] = calc_variance < 0 ? 0 : calc_variance; )==""\n"
152R"==(} )==""\n"
153R"==(#endif )==""\n"
154R"==(} )==""\n"
155R"==(#else )==""\n"
// [kernel] Scalar variants (VECTORIZE_CALC_STATS != 1): each work-item walks
// the REDUCE_DIM elements one at a time.
156R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
157R"==(__kernel void calculate_mean(__global DATA_T *src, __global float *mean) { )==""\n"
158R"==(int x[5]; )==""\n"
159R"==(x[0] = GWS_GET_STAT_MB(); )==""\n"
160R"==(x[1] = GWS_GET_STAT_IC(); )==""\n"
161R"==(x[2] = GWS_GET_STAT_ID(); )==""\n"
162R"==(x[3] = GWS_GET_STAT_IH(); )==""\n"
163R"==(x[4] = GWS_GET_STAT_IW(); )==""\n"
164R"==(float sum = 0; )==""\n"
165R"==(for (int i = 0; i < REDUCE_DIM; i++) { )==""\n"
166R"==(x[REDUCE_DIM_IDX] = i; )==""\n"
167R"==(sum += TO_DEF_ACC_DATA_T(src[SRC_OFF(x[0], x[1], x[2], x[3], x[4])]); )==""\n"
168R"==(} )==""\n"
169R"==(x[REDUCE_DIM_IDX] = 0; )==""\n"
170R"==(int reduce_idx = reduce_index(x); )==""\n"
171R"==(mean[reduce_idx * IC + x[1]] = sum; )==""\n"
172R"==(} )==""\n"
173R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
174R"==(__kernel void calculate_variance( )==""\n"
175R"==(__global DATA_T *src, __global float *mean, __global float *variance) { )==""\n"
176R"==(int x[5]; )==""\n"
177R"==(x[0] = GWS_GET_STAT_MB(); )==""\n"
178R"==(x[1] = GWS_GET_STAT_IC(); )==""\n"
179R"==(x[2] = GWS_GET_STAT_ID(); )==""\n"
180R"==(x[3] = GWS_GET_STAT_IH(); )==""\n"
181R"==(x[4] = GWS_GET_STAT_IW(); )==""\n"
182R"==(float sum = 0; )==""\n"
183R"==(for (int i = 0; i < REDUCE_DIM; i++) { )==""\n"
184R"==(x[REDUCE_DIM_IDX] = i; )==""\n"
185R"==(DEF_ACC_DATA_T v0 )==""\n"
186R"==(= TO_DEF_ACC_DATA_T(src[SRC_OFF(x[0], x[1], x[2], x[3], x[4])]) )==""\n"
187R"==(- mean[x[1]]; )==""\n"
188R"==(sum += v0 * v0; )==""\n"
189R"==(} )==""\n"
// [kernel] Same partial-variance offset as in the vectorized variant.
190R"==(variance += MB * ID * IH * IW * IC / REDUCE_DIM; )==""\n"
191R"==(x[REDUCE_DIM_IDX] = 0; )==""\n"
192R"==(int reduce_idx = reduce_index(x); )==""\n"
193R"==(variance[reduce_idx * IC + x[1]] = sum; )==""\n"
194R"==(} )==""\n"
195R"==(#endif )==""\n"
// [kernel] reduce_mean / reduce_variance: for channel c, add the
// MB*ID*IH*IW / REDUCE_DIM partials (stride IC) and divide by MB*ID*IH*IW.
// reduce_variance writes IC floats further into `variance` when
// SAVE_STATS == 0.
196R"==(NAMED_KERNEL_ATTR(REDUCE) )==""\n"
197R"==(__kernel void reduce_mean(__global float *reduce_temp, __global float *mean) { )==""\n"
198R"==(const int c = GWS_GET_REDUCE_STAT_IC(); )==""\n"
199R"==(reduce_temp += c; )==""\n"
200R"==(float sum = 0.0f; )==""\n"
201R"==(int reduce_size = MB * ID * IH * IW / REDUCE_DIM; )==""\n"
202R"==(for (int i = 0; i < reduce_size; i++) { )==""\n"
203R"==(sum += reduce_temp[i * IC]; )==""\n"
204R"==(} )==""\n"
205R"==(mean[c] = sum / (MB * ID * IH * IW); )==""\n"
206R"==(} )==""\n"
207R"==(NAMED_KERNEL_ATTR(REDUCE) )==""\n"
208R"==(__kernel void reduce_variance( )==""\n"
209R"==(__global float *reduce_temp, __global float *variance) { )==""\n"
210R"==(const int c = GWS_GET_REDUCE_STAT_IC(); )==""\n"
211R"==(#if SAVE_STATS == 0 )==""\n"
212R"==(variance += IC; )==""\n"
213R"==(#endif )==""\n"
214R"==(float sum = 0.0f; )==""\n"
215R"==(int reduce_size = MB * ID * IH * IW / REDUCE_DIM; )==""\n"
216R"==(reduce_temp += reduce_size * IC + c; )==""\n"
217R"==(for (int i = 0; i < reduce_size; i++) )==""\n"
218R"==(sum += reduce_temp[i * IC]; )==""\n"
219R"==(variance[c] = sum / (MB * ID * IH * IW); )==""\n"
220R"==(} )==""\n"
221R"==(#endif )==""\n"
// [kernel] ref_bnorm_fwd: one element per work-item.
//   dst = scale * (src - mean) / sqrt(variance + eps) + shift
// scale / shift default to 1 / 0 unless USE_SCALE / USE_SHIFT are set. Then,
// optionally: + src_add (FUSE_BN_ADD_RELU); ReLU, recording ws[off] = 0 for
// clamped and -1 for kept elements when IS_TRAINING (FUSE_BN_RELU);
// max(., 0) (WITH_RELU). The variance is read IC floats further in when
// SAVE_STATS == 0 && CALCULATE_STATS == 1, matching reduce_variance.
222R"==(KERNEL_ATTR )==""\n"
223R"==(__kernel void ref_bnorm_fwd(__global DATA_T *src, __global float *mean, )==""\n"
224R"==(__global float *variance, __global DATA_T *dst, __global float *scale, )==""\n"
225R"==(__global float *shift, __global char *ws, float eps, )==""\n"
226R"==(__global DATA_T *src_add) { )==""\n"
227R"==(const int n = GWS_GET_MB(); )==""\n"
228R"==(const int c = GWS_GET_IC(); )==""\n"
229R"==(const int d = GWS_GET_ID(); )==""\n"
230R"==(const int h = GWS_GET_IH(); )==""\n"
231R"==(const int w = GWS_GET_IW(); )==""\n"
232R"==(#if USE_SCALE == 1 )==""\n"
233R"==(float sm = scale[c]; )==""\n"
234R"==(#else )==""\n"
235R"==(float sm = 1; )==""\n"
236R"==(#endif )==""\n"
237R"==(#if USE_SHIFT == 1 )==""\n"
238R"==(float sv = shift[c]; )==""\n"
239R"==(#else )==""\n"
240R"==(float sv = 0; )==""\n"
241R"==(#endif )==""\n"
242R"==(#if SAVE_STATS == 0 && CALCULATE_STATS == 1 )==""\n"
243R"==(variance += IC; )==""\n"
244R"==(#endif )==""\n"
245R"==(float v_mean = mean[c]; )==""\n"
246R"==(float v_variance = variance[c]; )==""\n"
247R"==(const int off = SRC_OFF(n, c, d, h, w); )==""\n"
248R"==(float v0 = TO_DEF_ACC_DATA_T(src[off]); )==""\n"
249R"==(float sqrt_variance = 1.0f / sqrt(v_variance + eps); )==""\n"
250R"==(float bn_res = sm * (v0 - v_mean) * sqrt_variance + sv; )==""\n"
251R"==(#if FUSE_BN_ADD_RELU == 1 )==""\n"
252R"==(bn_res += TO_DEF_ACC_DATA_T(src_add[off]); )==""\n"
253R"==(#endif )==""\n"
254R"==(#if FUSE_BN_RELU == 1 )==""\n"
255R"==(if (bn_res <= 0) { )==""\n"
256R"==(bn_res = 0; )==""\n"
257R"==(#if IS_TRAINING == 1 )==""\n"
258R"==(ws[off] = 0; )==""\n"
259R"==(} else { )==""\n"
260R"==(ws[off] = -1; )==""\n"
261R"==(#endif )==""\n"
262R"==(} )==""\n"
263R"==(#endif )==""\n"
264R"==(#if WITH_RELU )==""\n"
265R"==(bn_res = max(bn_res, 0.0f); )==""\n"
266R"==(#endif )==""\n"
267R"==(dst[off] = TO_DATA_T(bn_res); )==""\n"
268R"==(} )==""\n"
269R"==(#endif )==""\n"
// [kernel] Backward pass (IS_BWD == 1).
270R"==(#if IS_BWD == 1 )==""\n"
271R"==(#if USE_16MB_UNROLL == 1 )==""\n"
// [kernel] 16MB-unroll path. calculate_stats: over one spatial block it
// accumulates diff_gamma += (src - mean) * dd and diff_beta += dd (dd zeroed
// by the ws mask when FUSE_BN_RELU == 1). It stores both partials in
// reduce_temp: gamma partials first, beta partials after
// REDUCE_STAT_NBLOCKS * IC.
// NOTE(review): the MB16 second-half offset is spelled `8 * 16` here and in
// the diff_src / diff_src_add writes of ref_bnorm_bwd, but `8 * IC_BLOCK` for
// the reads in ref_bnorm_bwd. These agree only if IC_BLOCK == 16 - confirm.
272R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
273R"==(__kernel void calculate_stats(__global DATA_T *src, __global float *mean, )==""\n"
274R"==(__global DATA_T *diff_dst, __global char *ws, )==""\n"
275R"==(__global float *reduce_temp) { )==""\n"
276R"==(const int mb = GWS_GET_STAT_MB(); )==""\n"
277R"==(const int stat_mb_block_idx = mb / MB_BLOCK; )==""\n"
278R"==(const int c = GWS_GET_STAT_IC(); )==""\n"
279R"==(const int sp_beg = GWS_GET_STAT_SP(); )==""\n"
280R"==(const int stat_sp_block = GWS_GET_STAT_SP_BLOCK(); )==""\n"
281R"==(const int stat_sp_nblocks = ID * IH * IW / stat_sp_block; )==""\n"
282R"==(const int stat_sp_block_idx = sp_beg / stat_sp_block; )==""\n"
283R"==(const int mb_sp_idx )==""\n"
284R"==(= stat_mb_block_idx * stat_sp_nblocks + stat_sp_block_idx; )==""\n"
285R"==(const int s_off = c * ID * IH * IW * MB_BLOCK + mb * IC * ID * IH * IW )==""\n"
286R"==(+ sp_beg * MB_BLOCK * IC_BLOCK; )==""\n"
287R"==(src += s_off; )==""\n"
288R"==(diff_dst += s_off; )==""\n"
289R"==(#if FUSE_BN_RELU == 1 )==""\n"
290R"==(ws += s_off; )==""\n"
291R"==(#endif )==""\n"
292R"==(VECT_FLOAT_T diff_gamma0 = 0.0f, diff_beta0 = 0.0f; )==""\n"
293R"==(VECT_FLOAT_T diff_gamma1 = 0.0f, diff_beta1 = 0.0f; )==""\n"
294R"==(float v_mean = as_float( )==""\n"
295R"==(intel_sub_group_block_read((const __global uint *)&mean[c])); )==""\n"
296R"==(for (int sp = sp_beg; sp < sp_beg + stat_sp_block; sp++) { )==""\n"
297R"==(VECT_FLOAT_T dd0 = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T( )==""\n"
298R"==(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)&diff_dst[0]))); )==""\n"
299R"==(VECT_FLOAT_T ss0 = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T( )==""\n"
300R"==(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)&src[0]))); )==""\n"
301R"==(#ifdef MB16 )==""\n"
302R"==(VECT_FLOAT_T dd1 = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(VECT_BLOCK_READ( )==""\n"
303R"==((const __global BLOCK_DATA_T *)&diff_dst[8 * 16]))); )==""\n"
304R"==(VECT_FLOAT_T ss1 = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T( )==""\n"
305R"==(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)&src[8 * 16]))); )==""\n"
306R"==(#endif )==""\n"
307R"==(#if FUSE_BN_RELU == 1 )==""\n"
308R"==(VECT_INT_T ws0 = VECT_CHAR_TO_INT(AS_VECT_CHAR_T( )==""\n"
309R"==(VECT_UCHAR_READ((const __global uchar *)&ws[0]))); )==""\n"
310R"==(dd0 = select((VECT_FLOAT_T)0.0f, dd0, ws0); )==""\n"
311R"==(#ifdef MB16 )==""\n"
312R"==(VECT_INT_T ws1 = VECT_CHAR_TO_INT(AS_VECT_CHAR_T( )==""\n"
313R"==(VECT_UCHAR_READ((const __global uchar *)&ws[8 * 16]))); )==""\n"
314R"==(dd1 = select((VECT_FLOAT_T)0.0f, dd1, ws1); )==""\n"
315R"==(#endif )==""\n"
316R"==(ws += MB_BLOCK * IC_BLOCK; )==""\n"
317R"==(#endif )==""\n"
318R"==(diff_gamma0 = fma((ss0 - (VECT_FLOAT_T)v_mean), dd0, diff_gamma0); )==""\n"
319R"==(diff_beta0 += dd0; )==""\n"
320R"==(#ifdef MB16 )==""\n"
321R"==(diff_gamma1 = fma((ss1 - (VECT_FLOAT_T)v_mean), dd1, diff_gamma1); )==""\n"
322R"==(diff_beta1 += dd1; )==""\n"
323R"==(#endif )==""\n"
324R"==(src += MB_BLOCK * IC_BLOCK; )==""\n"
325R"==(diff_dst += MB_BLOCK * IC_BLOCK; )==""\n"
326R"==(} )==""\n"
327R"==(#ifdef MB16 )==""\n"
// NOTE(review): `0.0` on the next kernel line is a double literal assigned to
// a float; `0.0f` would avoid a double-precision constant in the kernel.
328R"==(float v_diff_gamma = 0.0f, v_diff_beta = 0.0; )==""\n"
329R"==(for (int i = 0; i < 8; i++) { )==""\n"
330R"==(v_diff_gamma += diff_gamma0[i] + diff_gamma1[i]; )==""\n"
331R"==(v_diff_beta += diff_beta0[i] + diff_beta1[i]; )==""\n"
332R"==(} )==""\n"
333R"==(#else )==""\n"
334R"==(float v_diff_gamma = diff_gamma0, v_diff_beta = diff_beta0; )==""\n"
335R"==(#endif )==""\n"
336R"==(intel_sub_group_block_write( )==""\n"
337R"==((__global uint *)&reduce_temp[mb_sp_idx * IC + c], )==""\n"
338R"==(as_uint(v_diff_gamma)); )==""\n"
339R"==(intel_sub_group_block_write( )==""\n"
340R"==((__global uint *)&reduce_temp[REDUCE_STAT_NBLOCKS * IC )==""\n"
341R"==(+ mb_sp_idx * IC + c], )==""\n"
342R"==(as_uint(v_diff_beta)); )==""\n"
343R"==(} )==""\n"
// [kernel] reduce_stats (16MB path): sums the REDUCE_STAT_NBLOCKS partials
// per channel and writes diff_scale[c] = diff_gamma / sqrt(variance[c] + eps).
// diff_beta goes to diff_shift[c] when DIFF_SHIFT == 1, otherwise to
// diff_shift[REDUCE_STAT_NBLOCKS * IC + c], the offset ref_bnorm_bwd reads.
344R"==(NAMED_KERNEL_ATTR(REDUCE) )==""\n"
345R"==(__kernel void reduce_stats(__global float *reduce_temp, )==""\n"
346R"==(__global float *diff_scale, __global float *diff_shift, )==""\n"
347R"==(__global float *variance, float eps) { )==""\n"
348R"==(const int c = GWS_GET_REDUCE_STAT_IC(); )==""\n"
349R"==(reduce_temp += c; )==""\n"
350R"==(float diff_gamma = 0.0f, diff_beta = 0.0f; )==""\n"
351R"==(for (int i = 0; i < REDUCE_STAT_NBLOCKS; i++) { )==""\n"
352R"==(diff_gamma += reduce_temp[i * IC]; )==""\n"
353R"==(diff_beta += reduce_temp[REDUCE_STAT_NBLOCKS * IC + i * IC]; )==""\n"
354R"==(} )==""\n"
355R"==(float sqrt_variance = 1.0f / sqrt(variance[c] + eps); )==""\n"
356R"==(diff_scale[c] = diff_gamma * sqrt_variance; )==""\n"
357R"==(#if DIFF_SHIFT == 1 )==""\n"
358R"==(diff_shift[c] = diff_beta; )==""\n"
359R"==(#else )==""\n"
360R"==(diff_shift[REDUCE_STAT_NBLOCKS * IC + c] = diff_beta; )==""\n"
361R"==(#endif )==""\n"
362R"==(} )==""\n"
363R"==(#else )==""\n"
// [kernel] Reference (non-unrolled) path: scalar calculate_stats /
// reduce_stats built on reduce_index(). There are reduce_size =
// MB*ID*IH*IW / REDUCE_DIM partials per channel, and the beta partials are
// stored after the IC * reduce_size gamma partials.
364R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
365R"==(__kernel void calculate_stats(__global DATA_T *src, __global float *mean, )==""\n"
366R"==(__global DATA_T *diff_dst, __global char *ws, )==""\n"
367R"==(__global float *reduce_temp) { )==""\n"
368R"==(float diff_gamma = 0; )==""\n"
369R"==(float diff_beta = 0; )==""\n"
370R"==(int x[5]; )==""\n"
371R"==(x[0] = GWS_GET_STAT_MB(); )==""\n"
372R"==(x[1] = GWS_GET_STAT_IC(); )==""\n"
373R"==(x[2] = GWS_GET_STAT_ID(); )==""\n"
374R"==(x[3] = GWS_GET_STAT_IH(); )==""\n"
375R"==(x[4] = GWS_GET_STAT_IW(); )==""\n"
376R"==(for (int i = 0; i < REDUCE_DIM; i++) { )==""\n"
377R"==(x[REDUCE_DIM_IDX] = i; )==""\n"
378R"==(int off = SRC_OFF(x[0], x[1], x[2], x[3], x[4]); )==""\n"
379R"==(float dd = CONVERT_FLOAT_T(diff_dst[off]); )==""\n"
380R"==(#if FUSE_BN_RELU == 1 )==""\n"
381R"==(if (!ws[off]) dd = 0; )==""\n"
382R"==(#endif )==""\n"
383R"==(diff_gamma += (CONVERT_FLOAT_T(src[off]) - mean[x[1]]) * dd; )==""\n"
384R"==(diff_beta += dd; )==""\n"
385R"==(} )==""\n"
386R"==(int ss_off = MB * ID * IH * IW * IC / REDUCE_DIM; )==""\n"
387R"==(x[REDUCE_DIM_IDX] = 0; )==""\n"
388R"==(int reduce_idx = reduce_index(x); )==""\n"
389R"==(reduce_temp[reduce_idx * IC + x[1]] = diff_gamma; )==""\n"
390R"==(reduce_temp[ss_off + reduce_idx * IC + x[1]] = diff_beta; )==""\n"
391R"==(} )==""\n"
392R"==(NAMED_KERNEL_ATTR(REDUCE) )==""\n"
393R"==(__kernel void reduce_stats(__global float *reduce_temp, )==""\n"
394R"==(__global float *diff_scale, __global float *diff_shift, )==""\n"
395R"==(__global float *variance, float eps) { )==""\n"
396R"==(const int c = GWS_GET_REDUCE_STAT_IC(); )==""\n"
397R"==(float diff_gamma = 0.0f; )==""\n"
398R"==(float diff_beta = 0.0f; )==""\n"
399R"==(int reduce_size = MB * ID * IH * IW / REDUCE_DIM; )==""\n"
400R"==(for (int i = 0; i < reduce_size; i++) { )==""\n"
401R"==(diff_gamma += reduce_temp[c + i * IC]; )==""\n"
402R"==(diff_beta += reduce_temp[IC * reduce_size + c + i * IC]; )==""\n"
403R"==(} )==""\n"
404R"==(float sqrt_variance = 1.0f / sqrt(variance[c] + eps); )==""\n"
405R"==(diff_scale[c] = diff_gamma * sqrt_variance; )==""\n"
406R"==(#if DIFF_SHIFT == 1 )==""\n"
407R"==(diff_shift[c] = diff_beta; )==""\n"
408R"==(#else )==""\n"
409R"==(diff_shift[IC * reduce_size + c] = diff_beta; )==""\n"
410R"==(#endif )==""\n"
411R"==(} )==""\n"
412R"==(#endif )==""\n"
// [kernel] ref_bnorm_bwd, with inv_std = 1 / sqrt(variance + eps) and
// N = MB*ID*IH*IW:
//   diff_src = gamma * inv_std
//              * (dd - (diff_beta + (src - mean) * diff_scale[c] * inv_std) / N)
// The subtracted term is present only when CALCULATE_DIFF_STATS == 1, and
// gamma is 1 unless USE_SCALE == 1. With FUSE_BN_RELU, dd is first masked by
// ws. With FUSE_BN_ADD_RELU, the masked dd is also stored to diff_src_add.
// The USE_16MB_UNROLL == 1 branch uses sub-group block reads/writes (two
// halves when MB16 is defined); the #else branch handles one element per
// work-item.
413R"==(KERNEL_ATTR )==""\n"
414R"==(__kernel void ref_bnorm_bwd(__global DATA_T *src, __global float *mean, )==""\n"
415R"==(__global float *variance, __global DATA_T *diff_dst, )==""\n"
416R"==(__global float *scale, __global char *ws, __global DATA_T *diff_src, )==""\n"
417R"==(__global float *diff_scale, __global float *diff_shift, float eps, )==""\n"
418R"==(__global DATA_T *diff_src_add) { )==""\n"
419R"==(#if USE_16MB_UNROLL == 1 )==""\n"
420R"==(const int n = GWS_GET_MB(); )==""\n"
421R"==(const int c = GWS_GET_IC(); )==""\n"
422R"==(const int d = GWS_GET_ID(); )==""\n"
423R"==(const int h = GWS_GET_IH(); )==""\n"
424R"==(const int w = GWS_GET_IW(); )==""\n"
425R"==(#if USE_SCALE == 1 )==""\n"
426R"==(float gamma = as_float( )==""\n"
427R"==(intel_sub_group_block_read((const __global uint *)&scale[c])); )==""\n"
428R"==(#else )==""\n"
429R"==(float gamma = 1.0f; )==""\n"
430R"==(#endif )==""\n"
431R"==(float v_variance = as_float( )==""\n"
432R"==(intel_sub_group_block_read((const __global uint *)&variance[c])); )==""\n"
433R"==(float sqrt_variance = 1.0f / sqrt(v_variance + eps); )==""\n"
434R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
435R"==(float v_mean = as_float( )==""\n"
436R"==(intel_sub_group_block_read((const __global uint *)&mean[c])); )==""\n"
437R"==(float diff_gamma = as_float( )==""\n"
438R"==(intel_sub_group_block_read((const __global uint *)&diff_scale[c])); )==""\n"
439R"==(#if DIFF_SHIFT == 1 )==""\n"
440R"==(float diff_beta = as_float( )==""\n"
441R"==(intel_sub_group_block_read((const __global uint *)&diff_shift[c])); )==""\n"
442R"==(#else )==""\n"
443R"==(float diff_beta = as_float(intel_sub_group_block_read( )==""\n"
444R"==((const __global uint *)&diff_shift[REDUCE_STAT_NBLOCKS * IC + c])); )==""\n"
445R"==(#endif )==""\n"
446R"==(#endif )==""\n"
447R"==(const uint d_off = SRC_OFF(n, c, d, h, w); )==""\n"
448R"==(diff_src += d_off; )==""\n"
449R"==(#if FUSE_BN_ADD_RELU == 1 )==""\n"
450R"==(diff_src_add += d_off; )==""\n"
451R"==(#endif )==""\n"
452R"==(diff_dst += d_off; )==""\n"
453R"==(src += d_off; )==""\n"
454R"==(VECT_FLOAT_T blockD0 = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T( )==""\n"
455R"==(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)&diff_dst[0]))); )==""\n"
456R"==(#ifdef MB16 )==""\n"
457R"==(VECT_FLOAT_T blockD1 = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(VECT_BLOCK_READ( )==""\n"
458R"==((const __global BLOCK_DATA_T *)&diff_dst[8 * IC_BLOCK]))); )==""\n"
459R"==(#endif )==""\n"
460R"==(#if FUSE_BN_RELU == 1 )==""\n"
461R"==(ws += d_off; )==""\n"
462R"==(VECT_INT_T blockWS0 = VECT_CHAR_TO_INT( )==""\n"
463R"==(AS_VECT_CHAR_T(VECT_UCHAR_READ((const __global uchar *)&ws[0]))); )==""\n"
464R"==(blockD0 = select((VECT_FLOAT_T)0.0f, blockD0, blockWS0); )==""\n"
465R"==(#if FUSE_BN_ADD_RELU == 1 )==""\n"
466R"==(VECT_BLOCK_WRITE((__global BLOCK_DATA_T *)&diff_src_add[0], )==""\n"
467R"==(AS_VECT_BLOCK_DATA_T(CONVERT_VECTOR_DATA_T(blockD0))); )==""\n"
468R"==(#endif )==""\n"
469R"==(#ifdef MB16 )==""\n"
470R"==(VECT_INT_T blockWS1 = VECT_CHAR_TO_INT(AS_VECT_CHAR_T( )==""\n"
471R"==(VECT_UCHAR_READ((const __global uchar *)&ws[8 * IC_BLOCK]))); )==""\n"
472R"==(blockD1 = select((VECT_FLOAT_T)0.0f, blockD1, blockWS1); )==""\n"
473R"==(#if FUSE_BN_ADD_RELU == 1 )==""\n"
474R"==(VECT_BLOCK_WRITE((__global BLOCK_DATA_T *)&diff_src_add[8 * 16], )==""\n"
475R"==(AS_VECT_BLOCK_DATA_T(CONVERT_VECTOR_DATA_T(blockD1))); )==""\n"
476R"==(#endif )==""\n"
477R"==(#endif )==""\n"
478R"==(#endif )==""\n"
479R"==(gamma *= sqrt_variance; )==""\n"
480R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
481R"==(diff_gamma *= sqrt_variance; )==""\n"
482R"==(diff_gamma /= (MB * ID * IH * IW); )==""\n"
483R"==(diff_beta /= (MB * ID * IH * IW); )==""\n"
484R"==(VECT_FLOAT_T blockS0 = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T( )==""\n"
485R"==(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)&src[0]))); )==""\n"
486R"==(blockD0 -= fma((VECT_FLOAT_T)diff_gamma, (blockS0 - (VECT_FLOAT_T)v_mean), )==""\n"
487R"==((VECT_FLOAT_T)diff_beta); )==""\n"
488R"==(#ifdef MB16 )==""\n"
489R"==(VECT_FLOAT_T blockS1 = CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T(VECT_BLOCK_READ( )==""\n"
490R"==((const __global BLOCK_DATA_T *)&src[8 * IC_BLOCK]))); )==""\n"
491R"==(blockD1 -= fma((VECT_FLOAT_T)diff_gamma, (blockS1 - (VECT_FLOAT_T)v_mean), )==""\n"
492R"==((VECT_FLOAT_T)diff_beta); )==""\n"
493R"==(#endif )==""\n"
494R"==(#endif )==""\n"
495R"==(blockD0 *= gamma; )==""\n"
496R"==(VECT_BLOCK_WRITE((__global BLOCK_DATA_T *)&diff_src[0], )==""\n"
497R"==(AS_VECT_BLOCK_DATA_T(CONVERT_VECTOR_DATA_T(blockD0))); )==""\n"
498R"==(#ifdef MB16 )==""\n"
499R"==(blockD1 *= gamma; )==""\n"
500R"==(VECT_BLOCK_WRITE((__global BLOCK_DATA_T *)&diff_src[8 * 16], )==""\n"
501R"==(AS_VECT_BLOCK_DATA_T(CONVERT_VECTOR_DATA_T(blockD1))); )==""\n"
502R"==(#endif )==""\n"
503R"==(#else )==""\n"
504R"==(const int n = GWS_GET_MB(); )==""\n"
505R"==(const int c = GWS_GET_IC(); )==""\n"
506R"==(const int d = GWS_GET_ID(); )==""\n"
507R"==(const int h = GWS_GET_IH(); )==""\n"
508R"==(const int w = GWS_GET_IW(); )==""\n"
509R"==(float v_variance = variance[c]; )==""\n"
510R"==(float sqrt_variance = 1.0f / sqrt(v_variance + eps); )==""\n"
511R"==(#if USE_SCALE == 1 )==""\n"
512R"==(float gamma = scale[c]; )==""\n"
513R"==(#else )==""\n"
514R"==(float gamma = 1; )==""\n"
515R"==(#endif )==""\n"
516R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
517R"==(float v_mean = mean[c]; )==""\n"
518R"==(float diff_gamma = diff_scale[c]; )==""\n"
519R"==(#if DIFF_SHIFT == 1 )==""\n"
520R"==(float diff_beta = diff_shift[c]; )==""\n"
521R"==(#else )==""\n"
522R"==(int reduce_size = MB * ID * IH * IW / REDUCE_DIM; )==""\n"
523R"==(float diff_beta = diff_shift[reduce_size * IC + c]; )==""\n"
524R"==(#endif )==""\n"
525R"==(#endif )==""\n"
526R"==(const int off = SRC_OFF(n, c, d, h, w); )==""\n"
527R"==(float dd = TO_DEF_ACC_DATA_T(diff_dst[off]); )==""\n"
528R"==(#if FUSE_BN_RELU == 1 )==""\n"
529R"==(if (!ws[off]) dd = 0; )==""\n"
530R"==(#if FUSE_BN_ADD_RELU == 1 )==""\n"
531R"==(diff_src_add[off] = TO_DATA_T(dd); )==""\n"
532R"==(#endif )==""\n"
533R"==(#endif )==""\n"
534R"==(float v_diff_src = dd; )==""\n"
535R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
536R"==(v_diff_src -= diff_beta / (MB * ID * IH * IW) )==""\n"
537R"==(+ (CONVERT_FLOAT_T(src[off]) - v_mean) * diff_gamma * sqrt_variance )==""\n"
538R"==(/ (MB * ID * IH * IW); )==""\n"
539R"==(#endif )==""\n"
540R"==(v_diff_src *= gamma * sqrt_variance; )==""\n"
541R"==(diff_src[off] = TO_DATA_T(v_diff_src); )==""\n"
542R"==(#endif )==""\n"
543R"==(} )==""\n"
544R"==(#endif )==""\n"
545R"==()==";
546}
547}
548}
549}