1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
5const char *gen9_bnorm_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2020-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
// ---- Common configuration macros for the embedded gen9 bnorm kernels ----
// NOTE: these are host-side comments placed between the string fragments;
// they are not part of the kernel source text.
20R"==(#define VECT_DT_N VECT_SIZE )==""\n"
21R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// Derived compile-time flags built from IC, IC16, SP and STAT_SP_BLOCK.
22R"==(#define IS_IC_EQ_8 (IC == 8) )==""\n"
23R"==(#define HAS_IC_TAIL (IC != IC16) )==""\n"
24R"==(#define HAS_STAT_SP_BLOCK_TAIL (SP % STAT_SP_BLOCK) )==""\n"
// Build-time guards: NHWC_OPTIMIZED rejects any IC tail; otherwise an IC tail
// is accepted only together with USE_NHWC. HAS_STAT_SP_TAIL / HAS_SP_TAIL are
// defined only in the non-NHWC_OPTIMIZED branch.
25R"==(#if NHWC_OPTIMIZED )==""\n"
26R"==(#if HAS_IC_TAIL )==""\n"
27R"==(#error IC tail processing not supported )==""\n"
28R"==(#endif )==""\n"
29R"==(#else )==""\n"
30R"==(#if HAS_IC_TAIL && !USE_NHWC )==""\n"
31R"==(#error IC tail processing not supported )==""\n"
32R"==(#endif )==""\n"
33R"==(#define HAS_STAT_SP_TAIL (STAT_SP_TAIL != STAT_SP_NBLOCKS) )==""\n"
34R"==(#define HAS_SP_TAIL (SP != SP_TAIL) )==""\n"
35R"==(#endif )==""\n"
// Number of 16-channel groups in one IC_BLOCK, split into the part that is a
// multiple of VECT_SIZE (IC_VECT_SGROUPS) and the remainder (IC_TAIL_SGROUPS).
36R"==(#define IC_BLOCK_SGROUPS (IC_BLOCK / 16) )==""\n"
37R"==(#define IC_TAIL_SGROUPS (IC_BLOCK_SGROUPS % VECT_SIZE) )==""\n"
38R"==(#define IC_VECT_SGROUPS (IC_BLOCK_SGROUPS - IC_TAIL_SGROUPS) )==""\n"
39R"==(#define HAS_IC_VECT_TAIL (IC_TAIL_SGROUPS > 0) )==""\n"
// Load helpers: wrap the intel_sub_group_block_read* built-ins (or the
// BLOCK_READ / VECT_* macros, which are not defined in this chunk and
// presumably come from ocl_types.h) and reinterpret/convert the result.
40R"==(#define LOAD_FLOAT_1x16(ptr) \ )==""\n"
41R"==(as_float(intel_sub_group_block_read((const __global uint *)(ptr))) )==""\n"
42R"==(#define LOAD_UINT_1x16(ptr) \ )==""\n"
43R"==(as_uint(intel_sub_group_block_read((const __global uint *)(ptr))) )==""\n"
44R"==(#define LOAD_UINT_8x16(ptr) \ )==""\n"
45R"==(convert_uint8(as_uint8( \ )==""\n"
46R"==(intel_sub_group_block_read8((const __global uint *)(ptr)))) )==""\n"
47R"==(#define LOAD_CHAR_1x16(ptr) \ )==""\n"
48R"==(as_char(intel_sub_group_block_read_uc((const __global uchar *)(ptr))) )==""\n"
49R"==(#define LOAD_CHAR_8x16(ptr) \ )==""\n"
50R"==(convert_char8(as_char8( \ )==""\n"
51R"==(intel_sub_group_block_read_uc8((const __global uchar *)(ptr)))) )==""\n"
52R"==(#define LOAD_DATA_1x16(ptr) \ )==""\n"
53R"==(CONVERT_FLOAT_T(AS_DATA_T(BLOCK_READ((const __global BLOCK_DATA_T *)(ptr)))) )==""\n"
54R"==(#define LOAD_DATA_8x16(ptr) \ )==""\n"
55R"==(CONVERT_FLOAT8_T( \ )==""\n"
56R"==(AS_DATA8_T(BLOCK_READ8((const __global BLOCK_DATA_T *)(ptr)))) )==""\n"
57R"==(#define LOAD_VECT_DATA(ptr) \ )==""\n"
58R"==(CONVERT_VECT_FLOAT_T(AS_VECT_DATA_T( \ )==""\n"
59R"==(VECT_BLOCK_READ((const __global BLOCK_DATA_T *)(ptr)))) )==""\n"
60R"==(#define LOAD_VECT_CHAR(ptr) \ )==""\n"
61R"==(CONVERT_VECT_CHAR_T( \ )==""\n"
62R"==(AS_VECT_CHAR_T(VECT_UCHAR_READ((const __global uchar *)(ptr)))) )==""\n"
63R"==(#define LOAD_VECT_FLOAT(ptr) \ )==""\n"
64R"==(AS_VECT_FLOAT_T(VECT_UINT_READ((const __global uint *)(ptr))) )==""\n"
// Store helpers: the mirror image of the load helpers above.
// NOTE(review): STORE_DATA_8x16 uses `ptr` without parentheses in its cast,
// unlike its siblings; harmless for simple arguments, worth confirming in the
// original .cl source.
65R"==(#define STORE_DATA_1x16(ptr, val) \ )==""\n"
66R"==(BLOCK_WRITE((__global BLOCK_DATA_T *)(ptr), \ )==""\n"
67R"==(AS_BLOCK_DATA_T(CONVERT_DATA_T(val))) )==""\n"
68R"==(#define STORE_DATA_8x16(ptr, val) \ )==""\n"
69R"==(BLOCK_WRITE8((__global BLOCK_DATA_T *)ptr, \ )==""\n"
70R"==(AS_BLOCK_DATA8_T(CONVERT_DATA8_T(val))) )==""\n"
71R"==(#define STORE_VECT_DATA(ptr, val) \ )==""\n"
72R"==(VECT_BLOCK_WRITE((__global BLOCK_DATA_T *)(ptr), \ )==""\n"
73R"==(AS_VECT_BLOCK_DATA_T(CONVERT_VECTOR_DATA_T(val))) )==""\n"
74R"==(#define STORE_FLOAT_1x16(ptr, val) \ )==""\n"
75R"==(intel_sub_group_block_write((__global uint *)(ptr), as_uint(val)) )==""\n"
76R"==(#define STORE_FLOAT_8x16(ptr, val) \ )==""\n"
77R"==(intel_sub_group_block_write8((__global uint *)(ptr), as_uint8(val)) )==""\n"
78R"==(#define STORE_CHAR_1x16(ptr, val) \ )==""\n"
79R"==(intel_sub_group_block_write_uc((__global uchar *)(ptr), as_uchar(val)) )==""\n"
80R"==(#define STORE_CHAR_8x16(ptr, val) \ )==""\n"
81R"==(intel_sub_group_block_write_uc8((__global uchar *)(ptr), as_uchar8(val)) )==""\n"
82R"==(#define STORE_VECT_CHAR(ptr, val) \ )==""\n"
83R"==(VECT_UCHAR_WRITE((__global uchar *)(ptr), \ )==""\n"
84R"==(AS_VECT_UCHAR_T(CONVERT_VECT_CHAR_T(val))) )==""\n"
// MAYBE_LAST_IC_LOAD_FLOAT_1x16(ptr, idx): with HAS_IC_TAIL, when
// `is_last_ic_block` is true the lanes with simd_id < 8 read ptr[idx + simd_id]
// directly and the remaining lanes get 0.0f; otherwise it is a plain block
// read. The expansion relies on `is_last_ic_block` and `simd_id` being in
// scope at the use site.
85R"==(#if HAS_IC_TAIL )==""\n"
86R"==(#define MAYBE_LAST_IC_LOAD_FLOAT_1x16(ptr, idx) \ )==""\n"
87R"==((is_last_ic_block ? (simd_id < 8 ? ptr[(idx) + simd_id] : 0.0f) \ )==""\n"
88R"==(: as_float(intel_sub_group_block_read( \ )==""\n"
89R"==((const __global uint *)(&ptr[(idx)])))) )==""\n"
90R"==(#else )==""\n"
91R"==(#define MAYBE_LAST_IC_LOAD_FLOAT_1x16(ptr, idx) LOAD_FLOAT_1x16(&ptr[(idx)]) )==""\n"
92R"==(#endif )==""\n"
// Element distance between consecutive spatial points of one 16-channel
// group: IC when USE_NHWC is set, 16 otherwise.
93R"==(#if USE_NHWC )==""\n"
94R"==(#define IC_BLOCK_STRIDE IC )==""\n"
95R"==(#else )==""\n"
96R"==(#define IC_BLOCK_STRIDE 16 )==""\n"
97R"==(#endif )==""\n"
// Sizing of the __local scratch arrays used by the fused (atomic) reductions:
// one line holds REDUCE_NUM_SGROUPS * GWS_LWS0_CALC elements and there is one
// line per (GWS_LWS1_CALC * GWS_LWS2_CALC) slot.
98R"==(#if NHWC_OPTIMIZED )==""\n"
99R"==(#define REDUCE_NUM_SGROUPS IC_BLOCK_SGROUPS )==""\n"
100R"==(#else )==""\n"
101R"==(#define REDUCE_NUM_SGROUPS 1 )==""\n"
102R"==(#endif )==""\n"
103R"==(#define CALC_SLM_LINE_SIZE (REDUCE_NUM_SGROUPS * GWS_LWS0_CALC) )==""\n"
104R"==(#define CALC_SLM_SIZE (CALC_SLM_LINE_SIZE * GWS_LWS1_CALC * GWS_LWS2_CALC) )==""\n"
// gen9_fused_reduce_init: zero-fills the per-channel accumulators at index
// c = GWS_GET_IC_AUX().
//   IS_FWD : mean[c] and variance[c].
//   else   : diff_scale[c], plus diff_shift[c] when DIFF_SHIFT == 1 or
//            diff_shift[IC + IC * REDUCE_STAT_NBLOCKS + c] otherwise.
// Presumably launched before the kernels that accumulate into these buffers
// with atomic adds -- confirm against the host-side dispatch code.
105R"==(NAMED_KERNEL_ATTR(AUX) )==""\n"
106R"==(__kernel void gen9_fused_reduce_init( )==""\n"
107R"==(#if IS_FWD )==""\n"
108R"==(__global float *mean, __global float *variance )==""\n"
109R"==(#else )==""\n"
110R"==(__global float *diff_scale, __global float *diff_shift )==""\n"
111R"==(#endif )==""\n"
112R"==() { )==""\n"
113R"==(const int c = GWS_GET_IC_AUX(); )==""\n"
114R"==(#if IS_FWD )==""\n"
115R"==(mean[c] = 0.0f; )==""\n"
116R"==(variance[c] = 0.0f; )==""\n"
117R"==(#else )==""\n"
118R"==(diff_scale[c] = 0.0f; )==""\n"
119R"==(#if DIFF_SHIFT == 1 )==""\n"
120R"==(diff_shift[c] = 0.0f; )==""\n"
121R"==(#else )==""\n"
// Without DIFF_SHIFT the shift accumulator is addressed at a fixed offset
// of IC + IC * REDUCE_STAT_NBLOCKS elements inside the passed buffer.
122R"==(diff_shift[IC + IC * REDUCE_STAT_NBLOCKS + c] = 0.0f; )==""\n"
123R"==(#endif )==""\n"
124R"==(#endif )==""\n"
125R"==(return; )==""\n"
126R"==(} )==""\n"
127R"==(#if IS_FWD )==""\n"
// LOAD_DATA_Nx16_USING_LOOP_IDX(n, dest, src, idx):
//   for k in [0, n): dest[k] = LOAD_DATA_1x16(&src[(k + idx) * IC]),
//   i.e. one 16-wide read per row, rows being IC elements apart.
128R"==(#define LOAD_DATA_Nx16_USING_LOOP_IDX(n, dest, src, idx) \ )==""\n"
129R"==({ \ )==""\n"
130R"==(for (int k = 0; k < n; ++k) { \ )==""\n"
131R"==(dest[k] = LOAD_DATA_1x16(&src[(k + idx) * IC]); \ )==""\n"
132R"==(} \ )==""\n"
133R"==(} )==""\n"
// _HALF variant: same read, but k advances by 2, so only the even-indexed
// elements of dest are written. The IS_IC_EQ_8 paths in gen9_calc_mean fill
// the odd elements afterwards from a sub-group shuffle.
134R"==(#define LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(n, dest, src, idx) \ )==""\n"
135R"==({ \ )==""\n"
136R"==(for (int k = 0; k < n; k += 2) { \ )==""\n"
137R"==(dest[k] = LOAD_DATA_1x16(&src[(k + idx) * IC]); \ )==""\n"
138R"==(} \ )==""\n"
139R"==(} )==""\n"
// One-pass statistics: accumulators are float, and a running sum is carried
// as a float2 (SUM_DATA_T) where .s0 is the sum and .s1 the compensation term.
140R"==(#if USE_STATS_ONE_PASS )==""\n"
141R"==(#define ACCUM_DATA_T float )==""\n"
142R"==(#define ACCUM_DATA8_T float8 )==""\n"
143R"==(#define ACCUM_DATA2_T float2 )==""\n"
144R"==(#define SUM_DATA_T ACCUM_DATA2_T )==""\n"
// summation(input, state): one step of compensated (Kahan) summation.
//   input : value to add.
//   state : .s0 = running sum, .s1 = running compensation.
//   returns the updated (sum, compensation) pair.
145R"==(SUM_DATA_T summation(ACCUM_DATA_T input, SUM_DATA_T state) { )==""\n"
146R"==(ACCUM_DATA2_T ret; )==""\n"
// Remove the previously lost low-order part before adding.
147R"==(ACCUM_DATA_T y = input - state.s1; )==""\n"
148R"==(ACCUM_DATA_T t = state.s0 + y; )==""\n"
// (t - sum) - y recovers what was rounded away by the addition above.
149R"==(ret.s1 = (t - state.s0) - y; )==""\n"
150R"==(ret.s0 = t; )==""\n"
151R"==(return ret; )==""\n"
152R"==(} )==""\n"
153R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
// gen9_mean_var_calc_fused_reduction: combines the per-work-item partial
// (sum, sum of squares) pairs of a work-group through local memory and adds
// the result to the global mean / variance buffers with atomic_add_global.
//   mean, variance          : global accumulators (raw sums are added here;
//                             gen9_fused_reduce_final in this file performs
//                             the division by MB * ID * IH * IW).
//   dst_offset              : first channel index handled by this work-item.
//   sum, sum_sq             : private arrays of REDUCE_NUM_SGROUPS entries.
//   local_sum, local_sum_sq : local scratch, indexed up to CALC_SLM_SIZE.
// NOTE(review): the scratch slot is selected with get_local_id(1) only, while
// group_size also multiplies in GWS_LWS2_CALC; presumably GWS_LWS2_CALC is 1
// for these kernels -- confirm with the host-side dispatch.
154R"==(void gen9_mean_var_calc_fused_reduction(volatile __global atomic_float *mean, )==""\n"
155R"==(volatile __global atomic_float *variance, int dst_offset, )==""\n"
156R"==(SUM_DATA_T *sum, SUM_DATA_T *sum_sq, __local SUM_DATA_T *local_sum, )==""\n"
157R"==(__local SUM_DATA_T *local_sum_sq) { )==""\n"
158R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
159R"==(const int group_size = GWS_LWS1_CALC * GWS_LWS2_CALC; )==""\n"
160R"==(const int sg_group_id = get_local_id(0) / 16; )==""\n"
161R"==(const int local_id = get_local_id(1); )==""\n"
// Every work-item except slot 0 publishes its partial results.
162R"==(if (local_id > 0) { )==""\n"
163R"==(for (int sg = 0; sg < REDUCE_NUM_SGROUPS; ++sg) { )==""\n"
164R"==(const int slm_offset = CALC_SLM_LINE_SIZE * local_id )==""\n"
165R"==(+ REDUCE_NUM_SGROUPS * 16 * sg_group_id + sg * 16 + simd_id; )==""\n"
166R"==(local_sum[slm_offset] = sum[sg]; )==""\n"
167R"==(local_sum_sq[slm_offset] = sum_sq[sg]; )==""\n"
168R"==(} )==""\n"
169R"==(} )==""\n"
170R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
// Slot 0 folds in the published partials and issues the atomic adds.
171R"==(if (local_id == 0) { )==""\n"
172R"==(for (int sg = 0; sg < REDUCE_NUM_SGROUPS; ++sg) { )==""\n"
173R"==(for (int gr_id = 1; gr_id < group_size; ++gr_id) { )==""\n"
174R"==(const int off_local = CALC_SLM_LINE_SIZE * gr_id )==""\n"
175R"==(+ REDUCE_NUM_SGROUPS * 16 * sg_group_id + sg * 16 )==""\n"
176R"==(+ simd_id; )==""\n"
177R"==(SUM_DATA_T tmp = local_sum[off_local]; )==""\n"
178R"==(SUM_DATA_T tmp_sq = local_sum_sq[off_local]; )==""\n"
// Both components of the other item's pair are fed through summation():
// first its compensation term (.s1), then its sum (.s0).
179R"==(sum[sg] = summation(tmp.s1, sum[sg]); )==""\n"
180R"==(sum_sq[sg] = summation(tmp_sq.s1, sum_sq[sg]); )==""\n"
181R"==(sum[sg] = summation(tmp.s0, sum[sg]); )==""\n"
182R"==(sum_sq[sg] = summation(tmp_sq.s0, sum_sq[sg]); )==""\n"
183R"==(} )==""\n"
184R"==(const int offset = dst_offset + sg * 16 + simd_id; )==""\n"
// With an IC tail, lanes whose channel index is >= IC must not write.
185R"==(#if HAS_IC_TAIL )==""\n"
186R"==(if (offset < IC) { )==""\n"
187R"==(#endif )==""\n"
// Only the .s0 (sum) component is accumulated globally.
188R"==(atomic_add_global(&mean[offset], sum[sg].s0); )==""\n"
189R"==(atomic_add_global(&variance[offset], sum_sq[sg].s0); )==""\n"
190R"==(#if HAS_IC_TAIL )==""\n"
191R"==(} )==""\n"
192R"==(#endif )==""\n"
193R"==(} )==""\n"
194R"==(} )==""\n"
195R"==(} )==""\n"
196R"==(#endif )==""\n"
197R"==(#endif )==""\n"
198R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
// gen9_calc_fused_reduction: plain-float counterpart of
// gen9_mean_var_calc_fused_reduction. Work-items with get_local_id(1) > 0
// publish their partial sums to local memory; after the barrier the items
// with get_local_id(1) == 0 add them up and atomically add the total to
// dst[dst_offset + sg * 16 + simd_id].
//   dst        : global accumulator.
//   dst_offset : first channel index handled by this work-item.
//   sum        : private array of REDUCE_NUM_SGROUPS partial sums (updated
//                in place on the reducing work-items).
//   local_sum  : local scratch, indexed up to CALC_SLM_SIZE.
// NOTE(review): as in the mean/var variant, the slot index ignores
// get_local_id(2) although group_size includes GWS_LWS2_CALC -- confirm that
// GWS_LWS2_CALC is 1.
199R"==(void gen9_calc_fused_reduction(volatile __global atomic_float *dst, )==""\n"
200R"==(int dst_offset, float *sum, __local float *local_sum) { )==""\n"
201R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
202R"==(const int group_size = GWS_LWS1_CALC * GWS_LWS2_CALC; )==""\n"
203R"==(const int sg_group_id = get_local_id(0) / 16; )==""\n"
204R"==(const int local_id = get_local_id(1); )==""\n"
205R"==(if (local_id > 0) { )==""\n"
206R"==(for (int sg = 0; sg < REDUCE_NUM_SGROUPS; ++sg) { )==""\n"
207R"==(const int slm_offset = CALC_SLM_LINE_SIZE * local_id )==""\n"
208R"==(+ REDUCE_NUM_SGROUPS * 16 * sg_group_id + sg * 16 + simd_id; )==""\n"
209R"==(local_sum[slm_offset] = sum[sg]; )==""\n"
210R"==(} )==""\n"
211R"==(} )==""\n"
212R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
213R"==(if (local_id == 0) { )==""\n"
214R"==(for (int sg = 0; sg < REDUCE_NUM_SGROUPS; ++sg) { )==""\n"
215R"==(for (int gr_id = 1; gr_id < group_size; ++gr_id) { )==""\n"
216R"==(const int off_local = CALC_SLM_LINE_SIZE * gr_id )==""\n"
217R"==(+ REDUCE_NUM_SGROUPS * 16 * sg_group_id + sg * 16 )==""\n"
218R"==(+ simd_id; )==""\n"
219R"==(sum[sg] += local_sum[off_local]; )==""\n"
220R"==(} )==""\n"
221R"==(const int offset = dst_offset + sg * 16 + simd_id; )==""\n"
// With an IC tail, the guard below applies to the single atomic add that
// follows it.
222R"==(#if HAS_IC_TAIL )==""\n"
223R"==(if (offset < IC) )==""\n"
224R"==(#endif )==""\n"
225R"==(atomic_add_global(&dst[offset], sum[sg]); )==""\n"
226R"==(} )==""\n"
227R"==(} )==""\n"
228R"==(return; )==""\n"
229R"==(} )==""\n"
230R"==(#endif )==""\n"
// gen9_reduce_common: second-stage reduction shared by the non-fused path.
// Each work-item adds REDUCE_STAT_NBLOCKS / REDUCE_IC_SUB_GROUPS partial
// values from reduce_temp (16 elements apart); the results of the
// REDUCE_IC_SUB_GROUPS sub-groups are then combined through local_sum and the
// sub-group with ic_sub_group == 0 writes
//   dst[c] = total / (MB * ID * IH * IW).
//   reduce_temp : partial sums written by the calc kernels.
//   local_sum   : local scratch of at least 16 * REDUCE_IC_SUB_GROUPS floats.
//   dst         : per-channel output.
231R"==(void gen9_reduce_common(__global float *reduce_temp, __local float *local_sum, )==""\n"
232R"==(__global float *dst) { )==""\n"
233R"==(const int ic_sub_group = get_global_id(0) / 16; )==""\n"
234R"==(const int group_c = get_global_id(1); )==""\n"
235R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
236R"==(const int c = group_c * 16 + simd_id; )==""\n"
237R"==(const bool is_last_ic_block = (IC - group_c * 16) < 16; )==""\n"
238R"==(float sum = 0.0f; )==""\n"
// Advance to this work-item's slice: sub-group offset + channel-group offset
// + lane.
239R"==(reduce_temp )==""\n"
240R"==(+= REDUCE_STAT_NBLOCKS / REDUCE_IC_SUB_GROUPS * 16 * ic_sub_group )==""\n"
241R"==(+ REDUCE_STAT_NBLOCKS * 16 * group_c + simd_id; )==""\n"
242R"==(for (int i = 0; i < REDUCE_STAT_NBLOCKS / REDUCE_IC_SUB_GROUPS; i++) { )==""\n"
243R"==(sum += reduce_temp[i * 16]; )==""\n"
244R"==(} )==""\n"
245R"==(if (ic_sub_group > 0) { local_sum[ic_sub_group * 16 + simd_id] = sum; } )==""\n"
246R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
247R"==(if (ic_sub_group == 0) { )==""\n"
248R"==(for (int i = 1; i < REDUCE_IC_SUB_GROUPS; i++) { )==""\n"
249R"==(sum += local_sum[i * 16 + simd_id]; )==""\n"
250R"==(} )==""\n"
// In the last (partial) channel block only lanes 0..7 write.
// NOTE(review): the tail width is hard-coded to 8 lanes; presumably the only
// supported IC tail is 8 channels -- confirm.
251R"==(#if HAS_IC_TAIL )==""\n"
252R"==(if (!is_last_ic_block || (is_last_ic_block && simd_id < 8)) )==""\n"
253R"==(#endif )==""\n"
254R"==(dst[c] = sum / (MB * ID * IH * IW); )==""\n"
255R"==(} )==""\n"
256R"==(} )==""\n"
257R"==(#if USE_STATS_ONE_PASS )==""\n"
258R"==(#if NHWC_OPTIMIZED )==""\n"
// gen9_calc_mean_var (NHWC_OPTIMIZED variant): first stage of the one-pass
// statistics. For one spatial block it accumulates, per 16-channel group of
// an IC_BLOCK, the compensated sum and sum of squares of src.
//   src         : input data; rows are IC elements apart.
//   reduce_temp : partial results for the non-fused path (sums first, sums
//                 of squares at an offset of REDUCE_STAT_NBLOCKS * IC).
//   mean, variance : global accumulators for the fused (atomic) path.
// NOTE(review): `mb` is read but not used in this variant.
259R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
260R"==(__kernel void gen9_calc_mean_var(__global DATA_T *src, )==""\n"
261R"==(__global ACCUM_DATA_T *reduce_temp, )==""\n"
262R"==(volatile __global atomic_float *mean, )==""\n"
263R"==(volatile __global atomic_float *variance) { )==""\n"
264R"==(const int mb = GWS_GET_STAT_MB(); )==""\n"
265R"==(const int c = GWS_GET_STAT_IC(); )==""\n"
266R"==(const int sp_block_idx = GWS_GET_STAT_SP(); )==""\n"
267R"==(const int ic_block_offset = (c / 16) * IC_BLOCK; )==""\n"
268R"==(const int group_c_offset )==""\n"
269R"==(= REDUCE_STAT_NBLOCKS * ic_block_offset + sp_block_idx * 16; )==""\n"
270R"==(const int ver_offs = REDUCE_STAT_NBLOCKS * IC; )==""\n"
271R"==(const int src_off = ic_block_offset + sp_block_idx * STAT_SP_BLOCK * IC; )==""\n"
272R"==(src += src_off; )==""\n"
273R"==(SUM_DATA_T sum[IC_BLOCK_SGROUPS] = {0.0f}; )==""\n"
274R"==(SUM_DATA_T sum_sq[IC_BLOCK_SGROUPS] = {0.0f}; )==""\n"
// Iterate over the spatial points of this block; when SP is not a multiple
// of STAT_SP_BLOCK the trip count is clamped to the points that remain.
275R"==(#if HAS_STAT_SP_BLOCK_TAIL )==""\n"
276R"==(for (int sp = 0; sp < min(STAT_SP_BLOCK, SP - sp_block_idx * STAT_SP_BLOCK); )==""\n"
277R"==(++sp) { )==""\n"
278R"==(#else )==""\n"
279R"==(for (int sp = 0; sp < STAT_SP_BLOCK; ++sp) { )==""\n"
280R"==(#endif )==""\n"
// Vectorized part: VECT_SIZE 16-channel groups per load.
281R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS / VECT_SIZE; ++sg) { )==""\n"
282R"==(VECT_FLOAT_T s_vect = LOAD_VECT_DATA(&src[sg * 16 * VECT_SIZE]); )==""\n"
283R"==(for (int vect = 0; vect < VECT_SIZE; ++vect) { )==""\n"
284R"==(const int sum_idx = sg * VECT_SIZE + vect; )==""\n"
// With VECT_SIZE == 1 s_vect is used as a scalar and is not indexed.
285R"==(#if VECT_SIZE > 1 )==""\n"
286R"==(sum[sum_idx] = summation(s_vect[vect], sum[sum_idx]); )==""\n"
287R"==(sum_sq[sum_idx] = summation( )==""\n"
288R"==(s_vect[vect] * s_vect[vect], sum_sq[sum_idx]); )==""\n"
289R"==(#else )==""\n"
290R"==(sum[sum_idx] = summation(s_vect, sum[sum_idx]); )==""\n"
291R"==(sum_sq[sum_idx] = summation(s_vect * s_vect, sum_sq[sum_idx]); )==""\n"
292R"==(#endif )==""\n"
293R"==(} )==""\n"
294R"==(} )==""\n"
// Remaining 16-channel groups that do not fill a whole vector.
295R"==(#if HAS_IC_VECT_TAIL )==""\n"
296R"==(for (int sg = 0; sg < IC_TAIL_SGROUPS; ++sg) { )==""\n"
297R"==(const int sg_idx = IC_VECT_SGROUPS + sg; )==""\n"
298R"==(float s_tail = LOAD_DATA_1x16(&src[(IC_VECT_SGROUPS + sg) * 16]); )==""\n"
299R"==(sum[sg_idx] = summation(s_tail, sum[sg_idx]); )==""\n"
300R"==(sum_sq[sg_idx] = summation(s_tail * s_tail, sum_sq[sg_idx]); )==""\n"
301R"==(} )==""\n"
302R"==(#endif )==""\n"
303R"==(src += IC; )==""\n"
304R"==(} )==""\n"
// Fused path: reduce inside the work-group and add atomically to
// mean/variance. One local array holds both halves (sums, then squares).
305R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
306R"==(__local SUM_DATA_T local_sum[2 * CALC_SLM_SIZE]; )==""\n"
307R"==(__local SUM_DATA_T *local_sum_sq = local_sum + CALC_SLM_SIZE; )==""\n"
308R"==(gen9_mean_var_calc_fused_reduction(mean, variance, ic_block_offset, sum, )==""\n"
309R"==(sum_sq, local_sum, local_sum_sq); )==""\n"
310R"==(#else )==""\n"
// Non-fused path: store only the .s0 (sum) component of each pair; the
// compensation term is not written out.
311R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS; ++sg) { )==""\n"
312R"==(const int reduce_off = group_c_offset + sg * 16 * REDUCE_STAT_NBLOCKS; )==""\n"
313R"==(STORE_FLOAT_1x16(&reduce_temp[reduce_off], sum[sg].s0); )==""\n"
314R"==(STORE_FLOAT_1x16(&reduce_temp[ver_offs + reduce_off], sum_sq[sg].s0); )==""\n"
315R"==(} )==""\n"
316R"==(#endif )==""\n"
317R"==(} )==""\n"
318R"==(#else )==""\n"
// gen9_calc_mean_var (generic variant, NHWC_OPTIMIZED off): first stage of
// the one-pass statistics. Each work-item accumulates the compensated sum and
// sum of squares of one channel lane over one spatial block of src.
//   src         : input data (addressing depends on USE_NHWC).
//   reduce_temp : partial results for the non-fused path (sums first, sums
//                 of squares at an offset of REDUCE_STAT_NBLOCKS * IC).
//   mean, variance : global accumulators for the fused (atomic) path.
// NOTE(review): `simd_id` is declared but not referenced in this variant.
319R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
320R"==(__kernel void gen9_calc_mean_var(__global DATA_T *src, )==""\n"
321R"==(__global ACCUM_DATA_T *reduce_temp, )==""\n"
322R"==(volatile __global atomic_float *mean, )==""\n"
323R"==(volatile __global atomic_float *variance) { )==""\n"
324R"==(const int mb = GWS_GET_STAT_MB(); )==""\n"
325R"==(const int c = GWS_GET_STAT_IC(); )==""\n"
326R"==(const int sp_block_idx = GWS_GET_STAT_SP(); )==""\n"
327R"==(const int mb_sp_idx = mb * STAT_SP_NBLOCKS + sp_block_idx; )==""\n"
328R"==(const int group_c_offset = REDUCE_STAT_NBLOCKS * 16 * (int)(c / 16); )==""\n"
329R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
330R"==(const int ver_offs = REDUCE_STAT_NBLOCKS * IC; )==""\n"
// Position src at the first element of this (channel, spatial block).
// NOTE(review): the USE_NHWC branch does not add an mb term, unlike the
// blocked branch -- presumably mb is folded into the spatial index there;
// confirm with the host-side dispatch.
331R"==(#if USE_NHWC )==""\n"
332R"==(src += c + sp_block_idx * STAT_SP_BLOCK * IC; )==""\n"
333R"==(#else )==""\n"
334R"==(src += (c & 15) + sp_block_idx * STAT_SP_BLOCK * 16 + (c & ~15) * SP )==""\n"
335R"==(+ mb * SP * IC; )==""\n"
336R"==(#endif )==""\n"
337R"==(SUM_DATA_T sum; )==""\n"
338R"==(SUM_DATA_T sum_sq; )==""\n"
339R"==(sum.s0 = 0; )==""\n"
340R"==(sum.s1 = 0; )==""\n"
341R"==(sum_sq.s0 = 0; )==""\n"
342R"==(sum_sq.s1 = 0; )==""\n"
// Tail spatial block: process the leftover points in chunks of 16, then one
// at a time.
343R"==(#if HAS_STAT_SP_TAIL )==""\n"
344R"==(if (sp_block_idx == STAT_SP_TAIL) { )==""\n"
345R"==(int sp = SP - STAT_SP_TAIL * STAT_SP_BLOCK; )==""\n"
346R"==(while (sp >= 16) { )==""\n"
// NHWC: 16 rows are read one by one (IC elements apart); blocked layout: two
// 8x16 block reads.
347R"==(#if USE_NHWC )==""\n"
348R"==(float8 s0, s1; )==""\n"
349R"==(for (int k = 0; k < 8; ++k) )==""\n"
350R"==(s0[k] = LOAD_DATA_1x16(&src[k * IC]); )==""\n"
351R"==(for (int k = 0; k < 8; ++k) )==""\n"
352R"==(s1[k] = LOAD_DATA_1x16(&src[(k + 8) * IC]); )==""\n"
353R"==(#else )==""\n"
354R"==(float8 s0 = LOAD_DATA_8x16(&src[0]); )==""\n"
355R"==(float8 s1 = LOAD_DATA_8x16(&src[8 * 16]); )==""\n"
356R"==(#endif )==""\n"
357R"==(for (int i = 0; i < 8; i++) { )==""\n"
358R"==(sum = summation(s0[i], sum); )==""\n"
359R"==(sum = summation(s1[i], sum); )==""\n"
360R"==(sum_sq = summation(s0[i] * s0[i], sum_sq); )==""\n"
361R"==(sum_sq = summation(s1[i] * s1[i], sum_sq); )==""\n"
362R"==(} )==""\n"
363R"==(src += 16 * IC_BLOCK_STRIDE; )==""\n"
364R"==(sp -= 16; )==""\n"
365R"==(} )==""\n"
366R"==(while (sp >= 1) { )==""\n"
367R"==(float s0 = LOAD_DATA_1x16(&src[0]); )==""\n"
368R"==(sum = summation(s0, sum); )==""\n"
369R"==(sum_sq = summation(s0 * s0, sum_sq); )==""\n"
370R"==(src += IC_BLOCK_STRIDE; )==""\n"
371R"==(--sp; )==""\n"
372R"==(} )==""\n"
373R"==(} else )==""\n"
374R"==(#endif )==""\n"
// Full spatial block: STAT_SP_BLOCK / 16 iterations of 16 points each.
375R"==({ )==""\n"
376R"==(for (int sp = 0; sp < STAT_SP_BLOCK / 16; ++sp) { )==""\n"
377R"==(#if USE_NHWC )==""\n"
378R"==(float8 s0, s1; )==""\n"
379R"==(for (int k = 0; k < 8; ++k) )==""\n"
380R"==(s0[k] = LOAD_DATA_1x16(&src[k * IC]); )==""\n"
381R"==(for (int k = 0; k < 8; ++k) )==""\n"
382R"==(s1[k] = LOAD_DATA_1x16(&src[(k + 8) * IC]); )==""\n"
383R"==(#else )==""\n"
384R"==(float8 s0 = LOAD_DATA_8x16(&src[0]); )==""\n"
385R"==(float8 s1 = LOAD_DATA_8x16(&src[8 * 16]); )==""\n"
386R"==(#endif )==""\n"
387R"==(for (int i = 0; i < 8; i++) { )==""\n"
388R"==(sum = summation(s0[i], sum); )==""\n"
389R"==(sum = summation(s1[i], sum); )==""\n"
390R"==(sum_sq = summation(s0[i] * s0[i], sum_sq); )==""\n"
391R"==(sum_sq = summation(s1[i] * s1[i], sum_sq); )==""\n"
392R"==(} )==""\n"
393R"==(src += 16 * IC_BLOCK_STRIDE; )==""\n"
394R"==(} )==""\n"
395R"==(} )==""\n"
// Fused path: in-group reduction + atomic add; otherwise store the .s0
// components to reduce_temp for the separate reduce kernel.
396R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
397R"==(__local SUM_DATA_T local_sum[2 * CALC_SLM_SIZE]; )==""\n"
398R"==(__local SUM_DATA_T *local_sum_sq = local_sum + CALC_SLM_SIZE; )==""\n"
399R"==(gen9_mean_var_calc_fused_reduction( )==""\n"
400R"==(mean, variance, c, &sum, &sum_sq, local_sum, local_sum_sq); )==""\n"
401R"==(#else )==""\n"
402R"==(STORE_FLOAT_1x16(&reduce_temp[group_c_offset + mb_sp_idx * 16], sum.s0); )==""\n"
403R"==(STORE_FLOAT_1x16(&reduce_temp[ver_offs + group_c_offset + mb_sp_idx * 16], )==""\n"
404R"==(sum_sq.s0); )==""\n"
405R"==(#endif )==""\n"
406R"==(} )==""\n"
407R"==(#endif )==""\n"
// gen9_reduce_mean_var: second stage of the non-fused one-pass statistics.
// Sums the partial sums and partial sums of squares from reduce_temp with
// compensated summation, combines the REDUCE_IC_SUB_GROUPS sub-groups through
// local memory, and writes
//   mean[c]     = sum / (MB * ID * IH * IW)
//   variance[c] = max(0, sum_sq / (MB * ID * IH * IW) - mean[c]^2).
// NOTE(review): unlike gen9_reduce_common there is no HAS_IC_TAIL guard on
// the final writes; presumably mean/variance are padded to a multiple of 16
// channels -- confirm.
408R"==(NAMED_KERNEL_ATTR(REDUCE) )==""\n"
409R"==(__kernel void gen9_reduce_mean_var(__global ACCUM_DATA_T *reduce_temp, )==""\n"
410R"==(__global float *mean, __global float *variance) { )==""\n"
411R"==(__local SUM_DATA_T local_sum[16 * REDUCE_IC_SUB_GROUPS]; )==""\n"
412R"==(__local SUM_DATA_T local_sum_sq[16 * REDUCE_IC_SUB_GROUPS]; )==""\n"
413R"==(const int ic_sub_group = get_global_id(0) / 16; )==""\n"
414R"==(const int group_c = get_global_id(1); )==""\n"
415R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
416R"==(const int c = group_c * 16 + simd_id; )==""\n"
417R"==(SUM_DATA_T sum; )==""\n"
418R"==(SUM_DATA_T sum_sq; )==""\n"
419R"==(sum.s0 = 0; )==""\n"
420R"==(sum.s1 = 0; )==""\n"
421R"==(sum_sq.s0 = 0; )==""\n"
422R"==(sum_sq.s1 = 0; )==""\n"
// Sums of squares live REDUCE_STAT_NBLOCKS * IC elements after the sums.
423R"==(int offs_sq = REDUCE_STAT_NBLOCKS * IC; )==""\n"
424R"==(int offs = REDUCE_STAT_NBLOCKS / REDUCE_IC_SUB_GROUPS * 16 * ic_sub_group )==""\n"
425R"==(+ REDUCE_STAT_NBLOCKS * 16 * group_c + simd_id; )==""\n"
426R"==(for (int i = 0; i < REDUCE_STAT_NBLOCKS / REDUCE_IC_SUB_GROUPS; i++) { )==""\n"
427R"==(float tmp = reduce_temp[offs + i * 16]; )==""\n"
428R"==(sum = summation(tmp, sum); )==""\n"
429R"==(} )==""\n"
430R"==(for (int i = 0; i < REDUCE_STAT_NBLOCKS / REDUCE_IC_SUB_GROUPS; i++) { )==""\n"
431R"==(float tmp = reduce_temp[offs_sq + offs + i * 16]; )==""\n"
432R"==(sum_sq = summation(tmp, sum_sq); )==""\n"
433R"==(} )==""\n"
// Sub-groups other than 0 publish their (sum, compensation) pairs.
434R"==(if (ic_sub_group > 0) { )==""\n"
435R"==(local_sum[ic_sub_group * 16 + simd_id] = sum; )==""\n"
436R"==(local_sum_sq[ic_sub_group * 16 + simd_id] = sum_sq; )==""\n"
437R"==(} )==""\n"
438R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
439R"==(if (ic_sub_group == 0) { )==""\n"
440R"==(for (int i = 1; i < REDUCE_IC_SUB_GROUPS; i++) { )==""\n"
441R"==(SUM_DATA_T tmp = local_sum[i * 16 + simd_id]; )==""\n"
442R"==(SUM_DATA_T tmp_sq = local_sum_sq[i * 16 + simd_id]; )==""\n"
// Fold in both the compensation (.s1) and the sum (.s0) of each partial.
443R"==(sum = summation(tmp.s1, sum); )==""\n"
444R"==(sum_sq = summation(tmp_sq.s1, sum_sq); )==""\n"
445R"==(sum = summation(tmp.s0, sum); )==""\n"
446R"==(sum_sq = summation(tmp_sq.s0, sum_sq); )==""\n"
447R"==(} )==""\n"
448R"==(float tmp_mean = sum.s0 / (MB * ID * IH * IW); )==""\n"
449R"==(mean[c] = tmp_mean; )==""\n"
// Clamp at zero: E[x^2] - mean^2 can come out slightly negative in floating
// point.
450R"==(float tmp_var = max(0.0f, )==""\n"
451R"==((sum_sq.s0 / (MB * ID * IH * IW)) - (tmp_mean * tmp_mean)); )==""\n"
452R"==(variance[c] = tmp_var; )==""\n"
453R"==(} )==""\n"
454R"==(} )==""\n"
455R"==(#endif )==""\n"
// gen9_fused_reduce_final: turns accumulated raw sums into final statistics
// for channel c = GWS_GET_IC_AUX().
//   USE_STATS_ONE_PASS : mean[c] /= N, then
//                        variance[c] = max(0, variance[c] / N - mean[c]^2)
//   otherwise          : data_reduce[c] /= N
// where N = MB * ID * IH * IW.
456R"==(NAMED_KERNEL_ATTR(AUX) )==""\n"
457R"==(__kernel void gen9_fused_reduce_final( )==""\n"
458R"==(#if USE_STATS_ONE_PASS )==""\n"
459R"==(__global float *mean, __global float *variance )==""\n"
460R"==(#else )==""\n"
461R"==(__global float *data_reduce )==""\n"
462R"==(#endif )==""\n"
463R"==() { )==""\n"
464R"==(const int c = GWS_GET_IC_AUX(); )==""\n"
465R"==(#if USE_STATS_ONE_PASS )==""\n"
// mean[c] is normalized in place first; the variance expression below reads
// the already-normalized value.
466R"==(mean[c] = mean[c] / (MB * ID * IH * IW); )==""\n"
467R"==(float tmp_var = max( )==""\n"
468R"==(0.0f, (variance[c] / (MB * ID * IH * IW)) - (mean[c] * mean[c])); )==""\n"
469R"==(variance[c] = tmp_var; )==""\n"
470R"==(#else )==""\n"
471R"==(data_reduce[c] = data_reduce[c] / (MB * ID * IH * IW); )==""\n"
472R"==(#endif )==""\n"
473R"==(return; )==""\n"
474R"==(} )==""\n"
475R"==(#if NHWC_OPTIMIZED )==""\n"
// gen9_calc_mean (NHWC_OPTIMIZED variant): first stage of the two-pass mean.
// For one spatial block it accumulates a plain float sum of src for each
// 16-channel group of an IC_BLOCK.
//   src         : input data; rows are IC elements apart.
//   reduce_temp : partial sums for the non-fused path.
//   mean        : global accumulator for the fused (atomic) path.
// NOTE(review): `mb` is read but not used in this variant.
476R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
477R"==(__kernel void gen9_calc_mean(__global DATA_T *src, __global float *reduce_temp, )==""\n"
478R"==(volatile __global atomic_float *mean) { )==""\n"
479R"==(const int mb = GWS_GET_STAT_MB(); )==""\n"
480R"==(const int c = GWS_GET_STAT_IC(); )==""\n"
481R"==(const int sp_block_idx = GWS_GET_STAT_SP(); )==""\n"
482R"==(const int ic_block_offset = (c / 16) * IC_BLOCK; )==""\n"
483R"==(const int group_c_offset )==""\n"
484R"==(= REDUCE_STAT_NBLOCKS * ic_block_offset + sp_block_idx * 16; )==""\n"
485R"==(const int src_off = ic_block_offset + sp_block_idx * STAT_SP_BLOCK * IC; )==""\n"
486R"==(src += src_off; )==""\n"
487R"==(float v_mean[IC_BLOCK_SGROUPS] = {0.0f}; )==""\n"
// Iterate over the spatial points of this block, clamped when SP is not a
// multiple of STAT_SP_BLOCK.
488R"==(#if HAS_STAT_SP_BLOCK_TAIL )==""\n"
489R"==(for (int sp = 0; sp < min(STAT_SP_BLOCK, SP - sp_block_idx * STAT_SP_BLOCK); )==""\n"
490R"==(++sp) { )==""\n"
491R"==(#else )==""\n"
492R"==(for (int sp = 0; sp < STAT_SP_BLOCK; ++sp) { )==""\n"
493R"==(#endif )==""\n"
494R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS / VECT_SIZE; ++sg) { )==""\n"
495R"==(VECT_FLOAT_T s_vect = LOAD_VECT_DATA(&src[sg * 16 * VECT_SIZE]); )==""\n"
496R"==(for (int vect = 0; vect < VECT_SIZE; ++vect) { )==""\n"
// The statement below is split by the preprocessor: the left-hand side is
// common, the right-hand side indexes s_vect only when VECT_SIZE > 1.
497R"==(v_mean[sg * VECT_SIZE + vect] )==""\n"
498R"==(#if VECT_SIZE > 1 )==""\n"
499R"==(+= s_vect[vect]; )==""\n"
500R"==(#else )==""\n"
501R"==(+= s_vect; )==""\n"
502R"==(#endif )==""\n"
503R"==(} )==""\n"
504R"==(} )==""\n"
// Remaining 16-channel groups that do not fill a whole vector.
505R"==(#if HAS_IC_VECT_TAIL )==""\n"
506R"==(for (int sg = 0; sg < IC_TAIL_SGROUPS; ++sg) { )==""\n"
507R"==(float s_tail = LOAD_DATA_1x16(&src[(IC_VECT_SGROUPS + sg) * 16]); )==""\n"
508R"==(v_mean[IC_VECT_SGROUPS + sg] += s_tail; )==""\n"
509R"==(} )==""\n"
510R"==(#endif )==""\n"
511R"==(src += IC; )==""\n"
512R"==(} )==""\n"
// Fused path: in-group reduction + atomic add; otherwise store the partial
// sums to reduce_temp for gen9_reduce_mean.
513R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
514R"==(__local float local_sum[CALC_SLM_SIZE]; )==""\n"
515R"==(gen9_calc_fused_reduction(mean, ic_block_offset, v_mean, local_sum); )==""\n"
516R"==(#else )==""\n"
517R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS; ++sg) { )==""\n"
518R"==(const int reduce_off = group_c_offset + sg * 16 * REDUCE_STAT_NBLOCKS; )==""\n"
519R"==(STORE_FLOAT_1x16(&reduce_temp[reduce_off], v_mean[sg]); )==""\n"
520R"==(} )==""\n"
521R"==(#endif )==""\n"
522R"==(} )==""\n"
523R"==(#else )==""\n"
// gen9_calc_mean (generic variant, NHWC_OPTIMIZED off): first stage of the
// two-pass mean. Each work-item accumulates a float sum of one channel lane
// over one spatial block of src, 16 spatial points per iteration (two float8
// accumulators res0/res1) plus a scalar remainder loop for the tail block.
//   src         : input data (addressing depends on USE_NHWC).
//   reduce_temp : partial sums for the non-fused path.
//   mean        : destination for the fused path. NOTE(review): declared as
//                 plain `__global float *` here but passed to
//                 gen9_calc_fused_reduction, whose parameter is
//                 `volatile __global atomic_float *` -- confirm this
//                 combination is actually built.
524R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
525R"==(__kernel void gen9_calc_mean(__global DATA_T *src, __global float *reduce_temp, )==""\n"
526R"==(__global float *mean) { )==""\n"
527R"==(const int mb = GWS_GET_STAT_MB(); )==""\n"
528R"==(const int c = GWS_GET_STAT_IC(); )==""\n"
529R"==(const int sp_block_idx = GWS_GET_STAT_SP(); )==""\n"
530R"==(const int mb_sp_idx = mb * STAT_SP_NBLOCKS + sp_block_idx; )==""\n"
531R"==(const int group_c_offset = REDUCE_STAT_NBLOCKS * 16 * (int)(c / 16); )==""\n"
532R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
533R"==(#if HAS_IC_TAIL )==""\n"
534R"==(const bool is_last_ic_block = c + 16 > IC; )==""\n"
535R"==(const bool is_last_sp_block = (sp_block_idx == STAT_SP_NBLOCKS - 1); )==""\n"
536R"==(#endif )==""\n"
// Position src at the first element of this (channel, spatial block).
537R"==(#if USE_NHWC )==""\n"
538R"==(src += c + sp_block_idx * STAT_SP_BLOCK * IC; )==""\n"
539R"==(#else )==""\n"
540R"==(src += (c & 15) + sp_block_idx * STAT_SP_BLOCK * 16 + (c & ~15) * SP )==""\n"
541R"==(+ mb * SP * IC; )==""\n"
542R"==(#endif )==""\n"
543R"==(float8 res0 = 0.0f, res1 = 0.0f; )==""\n"
544R"==(float v_mean = 0.0f; )==""\n"
// Tail spatial block: leftover points in chunks of 16, then one at a time.
545R"==(#if HAS_STAT_SP_TAIL )==""\n"
546R"==(if (sp_block_idx == STAT_SP_TAIL) { )==""\n"
547R"==(int sp = SP - STAT_SP_TAIL * STAT_SP_BLOCK; )==""\n"
548R"==(while (sp >= 16) { )==""\n"
549R"==(#if USE_NHWC )==""\n"
550R"==(float8 s0, s1; )==""\n"
// IC == 8: only every second row is block-read; the odd entries are taken
// from the values held by lanes 8 higher (shuffle down by 8). Presumably one
// 16-wide read spans two consecutive 8-channel rows -- confirm.
551R"==(#if IS_IC_EQ_8 )==""\n"
552R"==(LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(8, s0, src, 0); )==""\n"
553R"==(LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(8, s1, src, 8); )==""\n"
554R"==(float8 t0 = intel_sub_group_shuffle_down(s0, s0, 8); )==""\n"
555R"==(float8 t1 = intel_sub_group_shuffle_down(s1, s1, 8); )==""\n"
556R"==(for (int k = 0; k < 7; k += 2) { )==""\n"
557R"==(s0[k + 1] = t0[k]; )==""\n"
558R"==(s1[k + 1] = t1[k]; )==""\n"
559R"==(} )==""\n"
560R"==(#elif HAS_IC_TAIL )==""\n"
// IC tail: on the final 16-point chunk of the last channel block, the last
// row of each group of 8 is read per lane (lanes 0..7 only, others 0.0f)
// instead of with a block read -- presumably to avoid reading past the end
// of the buffer; confirm.
561R"==(const bool is_last_sp = sp == 16; )==""\n"
562R"==(if (is_last_sp && is_last_ic_block) { )==""\n"
563R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(7, s0, src, 0); )==""\n"
564R"==(s0[7] = simd_id < 8 ? CONVERT_FLOAT_T(src[7 * IC + simd_id]) )==""\n"
565R"==(: 0.0f; )==""\n"
566R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(7, s1, src, 8); )==""\n"
567R"==(s1[7] = simd_id < 8 ? CONVERT_FLOAT_T(src[15 * IC + simd_id]) )==""\n"
568R"==(: 0.0f; )==""\n"
569R"==(} else { )==""\n"
570R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s0, src, 0); )==""\n"
571R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s1, src, 8); )==""\n"
572R"==(} )==""\n"
573R"==(#else )==""\n"
574R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s0, src, 0); )==""\n"
575R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s1, src, 8); )==""\n"
576R"==(#endif )==""\n"
577R"==(#else )==""\n"
578R"==(float8 s0 = LOAD_DATA_8x16(&src[0]); )==""\n"
579R"==(float8 s1 = LOAD_DATA_8x16(&src[8 * 16]); )==""\n"
580R"==(#endif )==""\n"
581R"==(res0 += s0; )==""\n"
582R"==(res1 += s1; )==""\n"
583R"==(src += 16 * IC_BLOCK_STRIDE; )==""\n"
584R"==(sp -= 16; )==""\n"
585R"==(} )==""\n"
// Scalar remainder: fewer than 16 points left.
586R"==(while (sp >= 1) { )==""\n"
587R"==(#if HAS_IC_TAIL )==""\n"
588R"==(float s0; )==""\n"
// Very last point of the last channel block: per-lane read for lanes 0..7.
589R"==(if (sp == 1 && is_last_ic_block) )==""\n"
590R"==(s0 = simd_id < 8 ? CONVERT_FLOAT_T(src[simd_id]) : 0.0f; )==""\n"
591R"==(else )==""\n"
592R"==(s0 = LOAD_DATA_1x16(&src[0]); )==""\n"
593R"==(#else )==""\n"
594R"==(float s0 = LOAD_DATA_1x16(&src[0]); )==""\n"
595R"==(#endif )==""\n"
596R"==(v_mean += s0; )==""\n"
597R"==(src += IC_BLOCK_STRIDE; )==""\n"
598R"==(--sp; )==""\n"
599R"==(} )==""\n"
600R"==(} else )==""\n"
601R"==(#endif )==""\n"
// Full spatial block: STAT_SP_BLOCK / 16 iterations of 16 points each. The
// load logic mirrors the tail branch above; the IC-tail special case is
// additionally gated on is_last_sp_block.
602R"==({ )==""\n"
603R"==(for (int sp = 0; sp < STAT_SP_BLOCK / 16; ++sp) { )==""\n"
604R"==(#if USE_NHWC )==""\n"
605R"==(float8 s0, s1; )==""\n"
606R"==(#if IS_IC_EQ_8 )==""\n"
607R"==(LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(8, s0, src, 0); )==""\n"
608R"==(LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(8, s1, src, 8); )==""\n"
609R"==(float8 t0 = intel_sub_group_shuffle_down(s0, s0, 8); )==""\n"
610R"==(float8 t1 = intel_sub_group_shuffle_down(s1, s1, 8); )==""\n"
611R"==(for (int k = 0; k < 7; k += 2) { )==""\n"
612R"==(s0[k + 1] = t0[k]; )==""\n"
613R"==(s1[k + 1] = t1[k]; )==""\n"
614R"==(} )==""\n"
615R"==(#elif HAS_IC_TAIL )==""\n"
616R"==(const bool is_last_sp = sp == STAT_SP_BLOCK / 16 - 1; )==""\n"
617R"==(if (is_last_sp && is_last_ic_block && is_last_sp_block) { )==""\n"
618R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(7, s0, src, 0); )==""\n"
619R"==(s0[7] = simd_id < 8 ? CONVERT_FLOAT_T(src[7 * IC + simd_id]) )==""\n"
620R"==(: 0.0f; )==""\n"
621R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(7, s1, src, 8); )==""\n"
622R"==(s1[7] = simd_id < 8 ? CONVERT_FLOAT_T(src[15 * IC + simd_id]) )==""\n"
623R"==(: 0.0f; )==""\n"
624R"==(} else { )==""\n"
625R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s0, src, 0); )==""\n"
626R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s1, src, 8); )==""\n"
627R"==(} )==""\n"
628R"==(#else )==""\n"
629R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s0, src, 0); )==""\n"
630R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s1, src, 8); )==""\n"
631R"==(#endif )==""\n"
632R"==(#else )==""\n"
633R"==(float8 s0 = LOAD_DATA_8x16(&src[0]); )==""\n"
634R"==(float8 s1 = LOAD_DATA_8x16(&src[8 * 16]); )==""\n"
635R"==(#endif )==""\n"
636R"==(res0 += s0; )==""\n"
637R"==(res1 += s1; )==""\n"
638R"==(src += 16 * IC_BLOCK_STRIDE; )==""\n"
639R"==(} )==""\n"
640R"==(} )==""\n"
// Collapse the two 8-wide accumulators into the scalar total.
641R"==(for (int i = 0; i < 8; i++) { )==""\n"
642R"==(v_mean += res0[i] + res1[i]; )==""\n"
643R"==(} )==""\n"
644R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
645R"==(__local float local_sum[CALC_SLM_SIZE]; )==""\n"
646R"==(gen9_calc_fused_reduction(mean, c, &v_mean, local_sum); )==""\n"
647R"==(#else )==""\n"
648R"==(STORE_FLOAT_1x16(&reduce_temp[group_c_offset + mb_sp_idx * 16], v_mean); )==""\n"
649R"==(#endif )==""\n"
650R"==(} )==""\n"
651R"==(#endif )==""\n"
// gen9_reduce_mean: second stage of the non-fused mean computation. Declares
// a local scratch buffer of 16 * REDUCE_IC_SUB_GROUPS floats and delegates to
// gen9_reduce_common, which sums the partials in reduce_temp and writes the
// normalized result to mean.
652R"==(NAMED_KERNEL_ATTR(REDUCE) )==""\n"
653R"==(__kernel void gen9_reduce_mean( )==""\n"
654R"==(__global float *reduce_temp, __global float *mean) { )==""\n"
655R"==(__local float local_sum[16 * REDUCE_IC_SUB_GROUPS]; )==""\n"
656R"==(gen9_reduce_common(reduce_temp, local_sum, mean); )==""\n"
657R"==(} )==""\n"
658R"==(#if NHWC_OPTIMIZED )==""\n"
// gen9_calc_variance (NHWC_OPTIMIZED variant).
// Each work-item handles IC_BLOCK_SGROUPS slices of 16 values starting at
// ic_block_offset and accumulates, per slice, the sum of (src - mean)^2 over
// up to STAT_SP_BLOCK iterations of the sp loop (src advances by IC each
// iteration). The partial sums are then either handed to
// gen9_calc_fused_reduction (FUSED_ATOMICS_REDUCTION) or stored to reduce_temp.
659R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
660R"==(__kernel void gen9_calc_variance(__global DATA_T *src, __global float *mean, )==""\n"
661R"==(__global float *reduce_temp, volatile __global atomic_float *variance) { )==""\n"
// NOTE(review): mb is read here but not referenced again in this variant.
662R"==(const int mb = GWS_GET_STAT_MB(); )==""\n"
663R"==(const int c = GWS_GET_STAT_IC(); )==""\n"
664R"==(const int sp_block_idx = GWS_GET_STAT_SP(); )==""\n"
665R"==(const int ic_block_offset = (c / 16) * IC_BLOCK; )==""\n"
666R"==(const int group_c_offset )==""\n"
667R"==(= REDUCE_STAT_NBLOCKS * ic_block_offset + sp_block_idx * 16; )==""\n"
// Skip the first REDUCE_STAT_NBLOCKS * IC16 floats of reduce_temp;
// gen9_reduce_variance passes the same offset to gen9_reduce_common.
668R"==(reduce_temp += REDUCE_STAT_NBLOCKS * IC16; )==""\n"
669R"==(mean += ic_block_offset; )==""\n"
670R"==(const int src_off = ic_block_offset + sp_block_idx * STAT_SP_BLOCK * IC; )==""\n"
671R"==(src += src_off; )==""\n"
// Load one float of mean per sub-group lane for every 16-wide slice.
672R"==(float v_mean[IC_BLOCK_SGROUPS]; )==""\n"
673R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS; ++sg) { )==""\n"
674R"==(v_mean[sg] = as_float(intel_sub_group_block_read( )==""\n"
675R"==((const __global uint *)(&mean[(sg * 16)]))); )==""\n"
676R"==(} )==""\n"
// v_var: running sum of squared deviations; v0: latest deviation per slice.
677R"==(float v_var[IC_BLOCK_SGROUPS] = {0.0f}; )==""\n"
678R"==(float v0[IC_BLOCK_SGROUPS] = {0.0f}; )==""\n"
// When SP is not a multiple of STAT_SP_BLOCK the trip count is clamped so the
// last block stops at SP.
679R"==(#if HAS_STAT_SP_BLOCK_TAIL )==""\n"
680R"==(for (int sp = 0; sp < min(STAT_SP_BLOCK, SP - sp_block_idx * STAT_SP_BLOCK); )==""\n"
681R"==(++sp) { )==""\n"
682R"==(#else )==""\n"
683R"==(for (int sp = 0; sp < STAT_SP_BLOCK; ++sp) { )==""\n"
684R"==(#endif )==""\n"
// Vectorized part: one LOAD_VECT_DATA covers VECT_SIZE consecutive slices,
// which are then processed element by element.
685R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS / VECT_SIZE; ++sg) { )==""\n"
686R"==(VECT_FLOAT_T s_vect = LOAD_VECT_DATA(&src[sg * 16 * VECT_SIZE]); )==""\n"
687R"==(for (int vect = 0; vect < VECT_SIZE; ++vect) { )==""\n"
688R"==(int sg_idx = sg * VECT_SIZE + vect; )==""\n"
// With VECT_SIZE == 1 s_vect is used as a scalar, so it is not indexed.
689R"==(#if VECT_SIZE > 1 )==""\n"
690R"==(v0[sg_idx] = s_vect[vect] - v_mean[sg_idx]; )==""\n"
691R"==(#else )==""\n"
692R"==(v0[sg_idx] = s_vect - v_mean[sg_idx]; )==""\n"
693R"==(#endif )==""\n"
694R"==(v_var[sg_idx] = fma(v0[sg_idx], v0[sg_idx], v_var[sg_idx]); )==""\n"
695R"==(} )==""\n"
696R"==(} )==""\n"
// Remaining IC_TAIL_SGROUPS slices that do not fill a whole vector are
// processed one 1x16 load at a time.
697R"==(#if HAS_IC_VECT_TAIL )==""\n"
698R"==(for (int sg = 0; sg < IC_TAIL_SGROUPS; ++sg) { )==""\n"
699R"==(const int sg_idx = IC_VECT_SGROUPS + sg; )==""\n"
700R"==(float s_tail = LOAD_DATA_1x16(&src[(IC_VECT_SGROUPS + sg) * 16]); )==""\n"
701R"==(v0[sg_idx] = s_tail - v_mean[sg_idx]; )==""\n"
702R"==(v_var[sg_idx] = fma(v0[sg_idx], v0[sg_idx], v_var[sg_idx]); )==""\n"
703R"==(} )==""\n"
704R"==(#endif )==""\n"
705R"==(src += IC; )==""\n"
706R"==(} )==""\n"
// Output: either the fused path (gen9_calc_fused_reduction is defined outside
// this excerpt) or one 1x16 store per slice into reduce_temp, with slices
// spaced 16 * REDUCE_STAT_NBLOCKS floats apart.
707R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
708R"==(__local float local_sum[CALC_SLM_SIZE]; )==""\n"
709R"==(gen9_calc_fused_reduction(variance, ic_block_offset, v_var, local_sum); )==""\n"
710R"==(#else )==""\n"
711R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS; ++sg) { )==""\n"
712R"==(const int reduce_off = group_c_offset + sg * 16 * REDUCE_STAT_NBLOCKS; )==""\n"
713R"==(STORE_FLOAT_1x16(&reduce_temp[reduce_off], v_var[sg]); )==""\n"
714R"==(} )==""\n"
715R"==(#endif )==""\n"
716R"==(} )==""\n"
717R"==(#else )==""\n"
// gen9_calc_variance (variant used when NHWC_OPTIMIZED is 0).
// Accumulates the sum of (src - mean)^2 for the 16-wide slice starting at c.
// The main loops consume 16 rows per iteration through two float8
// accumulators (res0, res1); the HAS_STAT_SP_TAIL path additionally handles
// leftover rows one at a time. The result is either passed to
// gen9_calc_fused_reduction or stored to reduce_temp with one 1x16 store.
718R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
719R"==(__kernel void gen9_calc_variance(__global DATA_T *src, __global float *mean, )==""\n"
720R"==(__global float *reduce_temp, __global float *variance) { )==""\n"
721R"==(const int mb = GWS_GET_STAT_MB(); )==""\n"
722R"==(const int c = GWS_GET_STAT_IC(); )==""\n"
723R"==(const int sp_block_idx = GWS_GET_STAT_SP(); )==""\n"
724R"==(const int mb_sp_idx = mb * STAT_SP_NBLOCKS + sp_block_idx; )==""\n"
725R"==(const int group_c_offset = REDUCE_STAT_NBLOCKS * 16 * (int)(c / 16); )==""\n"
726R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
727R"==(#if HAS_IC_TAIL )==""\n"
728R"==(const bool is_last_ic_block = c + 16 > IC; )==""\n"
729R"==(const bool is_last_sp_block = (sp_block_idx == STAT_SP_NBLOCKS - 1); )==""\n"
730R"==(#endif )==""\n"
// Skip the first REDUCE_STAT_NBLOCKS * IC16 floats of reduce_temp;
// gen9_reduce_variance passes the same offset to gen9_reduce_common.
731R"==(reduce_temp += REDUCE_STAT_NBLOCKS * IC16; )==""\n"
// Two source addressing schemes: with USE_NHWC consecutive rows are IC
// elements apart and mb does not enter the offset; otherwise the offset is
// built from 16-wide blocks ((c & ~15) * SP) plus mb * SP * IC.
732R"==(#if USE_NHWC )==""\n"
733R"==(src += c + sp_block_idx * STAT_SP_BLOCK * IC; )==""\n"
734R"==(#else )==""\n"
735R"==(src += (c & 15) + sp_block_idx * STAT_SP_BLOCK * 16 + (c & ~15) * SP )==""\n"
736R"==(+ mb * SP * IC; )==""\n"
737R"==(#endif )==""\n"
738R"==(float8 res0 = 0.0f, res1 = 0.0f; )==""\n"
739R"==(float v_var = 0.0f; )==""\n"
740R"==(float v_mean = MAYBE_LAST_IC_LOAD_FLOAT_1x16(mean, c); )==""\n"
// Tail block: only SP - STAT_SP_TAIL * STAT_SP_BLOCK rows remain. They are
// consumed 16 at a time first, then one at a time.
741R"==(#if HAS_STAT_SP_TAIL )==""\n"
742R"==(if (sp_block_idx == STAT_SP_TAIL) { )==""\n"
743R"==(int sp = SP - STAT_SP_TAIL * STAT_SP_BLOCK; )==""\n"
744R"==(while (sp >= 16) { )==""\n"
745R"==(#if USE_NHWC )==""\n"
746R"==(float8 s0, s1; )==""\n"
// IC == 8: the *_HALF macro is defined outside this excerpt; presumably it
// fills only the even elements (TODO confirm). The odd elements k + 1 are then
// taken from element k of the copy shifted down by 8 sub-group lanes.
747R"==(#if IS_IC_EQ_8 )==""\n"
748R"==(LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(8, s0, src, 0); )==""\n"
749R"==(LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(8, s1, src, 8); )==""\n"
750R"==(float8 t0 = intel_sub_group_shuffle_down(s0, s0, 8); )==""\n"
751R"==(float8 t1 = intel_sub_group_shuffle_down(s1, s1, 8); )==""\n"
752R"==(for (int k = 0; k < 7; k += 2) { )==""\n"
753R"==(s0[k + 1] = t0[k]; )==""\n"
754R"==(s1[k + 1] = t1[k]; )==""\n"
755R"==(} )==""\n"
// IC tail: in the final 16 rows of the last IC block, rows 7 and 15 are read
// per lane with lanes >= 8 forced to 0.0f instead of going through the macro.
// NOTE(review): presumably this avoids reading past the end of src - confirm.
756R"==(#elif HAS_IC_TAIL )==""\n"
757R"==(const bool is_last_sp = sp == 16; )==""\n"
758R"==(if (is_last_sp && is_last_ic_block) { )==""\n"
759R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(7, s0, src, 0); )==""\n"
760R"==(s0[7] = simd_id < 8 ? CONVERT_FLOAT_T(src[7 * IC + simd_id]) )==""\n"
761R"==(: 0.0f; )==""\n"
762R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(7, s1, src, 8); )==""\n"
763R"==(s1[7] = simd_id < 8 ? CONVERT_FLOAT_T(src[15 * IC + simd_id]) )==""\n"
764R"==(: 0.0f; )==""\n"
765R"==(} else { )==""\n"
766R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s0, src, 0); )==""\n"
767R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s1, src, 8); )==""\n"
768R"==(} )==""\n"
769R"==(#else )==""\n"
770R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s0, src, 0); )==""\n"
771R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s1, src, 8); )==""\n"
772R"==(#endif )==""\n"
773R"==(#else )==""\n"
774R"==(float8 s0 = LOAD_DATA_8x16(&src[0]); )==""\n"
775R"==(float8 s1 = LOAD_DATA_8x16(&src[8 * 16]); )==""\n"
776R"==(#endif )==""\n"
// Accumulate squared deviations for 16 rows.
777R"==(float8 v0 = s0 - v_mean; )==""\n"
778R"==(float8 v1 = s1 - v_mean; )==""\n"
779R"==(res0 = fma(v0, v0, res0); )==""\n"
780R"==(res1 = fma(v1, v1, res1); )==""\n"
781R"==(src += 16 * IC_BLOCK_STRIDE; )==""\n"
782R"==(sp -= 16; )==""\n"
783R"==(} )==""\n"
// Fewer than 16 rows left: accumulate them one by one directly into v_var.
784R"==(while (sp >= 1) { )==""\n"
785R"==(#if HAS_IC_TAIL )==""\n"
786R"==(float s0; )==""\n"
787R"==(if (sp == 1 && is_last_ic_block) )==""\n"
788R"==(s0 = simd_id < 8 ? CONVERT_FLOAT_T(src[simd_id]) : 0.0f; )==""\n"
789R"==(else )==""\n"
790R"==(s0 = LOAD_DATA_1x16(&src[0]); )==""\n"
791R"==(#else )==""\n"
792R"==(float s0 = LOAD_DATA_1x16(&src[0]); )==""\n"
793R"==(#endif )==""\n"
794R"==(float v0 = s0 - v_mean; )==""\n"
795R"==(v_var = fma(v0, v0, v_var); )==""\n"
796R"==(src += IC_BLOCK_STRIDE; )==""\n"
797R"==(--sp; )==""\n"
798R"==(} )==""\n"
799R"==(} else )==""\n"
800R"==(#endif )==""\n"
// Full block: STAT_SP_BLOCK / 16 iterations of 16 rows each. The load logic
// mirrors the tail path above, except that the per-lane guarded read also
// requires is_last_sp_block.
801R"==({ )==""\n"
802R"==(for (int sp = 0; sp < STAT_SP_BLOCK / 16; ++sp) { )==""\n"
803R"==(#if USE_NHWC )==""\n"
804R"==(float8 s0, s1; )==""\n"
805R"==(#if IS_IC_EQ_8 )==""\n"
806R"==(LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(8, s0, src, 0); )==""\n"
807R"==(LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(8, s1, src, 8); )==""\n"
808R"==(float8 t0 = intel_sub_group_shuffle_down(s0, s0, 8); )==""\n"
809R"==(float8 t1 = intel_sub_group_shuffle_down(s1, s1, 8); )==""\n"
810R"==(for (int k = 0; k < 7; k += 2) { )==""\n"
811R"==(s0[k + 1] = t0[k]; )==""\n"
812R"==(s1[k + 1] = t1[k]; )==""\n"
813R"==(} )==""\n"
814R"==(#elif HAS_IC_TAIL )==""\n"
815R"==(const bool is_last_sp = sp == STAT_SP_BLOCK / 16 - 1; )==""\n"
816R"==(if (is_last_sp && is_last_ic_block && is_last_sp_block) { )==""\n"
817R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(7, s0, src, 0); )==""\n"
818R"==(s0[7] = simd_id < 8 ? CONVERT_FLOAT_T(src[7 * IC + simd_id]) )==""\n"
819R"==(: 0.0f; )==""\n"
820R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(7, s1, src, 8); )==""\n"
821R"==(s1[7] = simd_id < 8 ? CONVERT_FLOAT_T(src[15 * IC + simd_id]) )==""\n"
822R"==(: 0.0f; )==""\n"
823R"==(} else { )==""\n"
824R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s0, src, 0); )==""\n"
825R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s1, src, 8); )==""\n"
826R"==(} )==""\n"
827R"==(#else )==""\n"
828R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s0, src, 0); )==""\n"
829R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, s1, src, 8); )==""\n"
830R"==(#endif )==""\n"
831R"==(#else )==""\n"
832R"==(float8 s0 = LOAD_DATA_8x16(&src[0]); )==""\n"
833R"==(float8 s1 = LOAD_DATA_8x16(&src[8 * 16]); )==""\n"
834R"==(#endif )==""\n"
835R"==(float8 v0 = s0 - v_mean; )==""\n"
836R"==(float8 v1 = s1 - v_mean; )==""\n"
837R"==(res0 = fma(v0, v0, res0); )==""\n"
838R"==(res1 = fma(v1, v1, res1); )==""\n"
839R"==(src += 16 * IC_BLOCK_STRIDE; )==""\n"
840R"==(} )==""\n"
841R"==(} )==""\n"
// Fold the 8 elements of both vector accumulators into the scalar v_var.
842R"==(for (int i = 0; i < 8; i++) { )==""\n"
843R"==(v_var += res0[i] + res1[i]; )==""\n"
844R"==(} )==""\n"
// Output: fused path, or a single 1x16 store at slot mb_sp_idx of the
// reduce_temp region that belongs to this 16-wide slice.
845R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
846R"==(__local float local_sum[CALC_SLM_SIZE]; )==""\n"
847R"==(gen9_calc_fused_reduction(variance, c, &v_var, local_sum); )==""\n"
848R"==(#else )==""\n"
849R"==(STORE_FLOAT_1x16(&reduce_temp[group_c_offset + mb_sp_idx * 16], v_var); )==""\n"
850R"==(#endif )==""\n"
851R"==(} )==""\n"
852R"==(#endif )==""\n"
// gen9_reduce_variance: kernel entry point of the variance reduction step.
// Same shape as gen9_reduce_mean, but gen9_reduce_common is given reduce_temp
// advanced by REDUCE_STAT_NBLOCKS * IC16 floats, which is the offset the
// gen9_calc_variance kernels apply before storing their partial sums.
853R"==(NAMED_KERNEL_ATTR(REDUCE) )==""\n"
854R"==(__kernel void gen9_reduce_variance( )==""\n"
855R"==(__global float *reduce_temp, __global float *variance) { )==""\n"
856R"==(__local float local_sum[16 * REDUCE_IC_SUB_GROUPS]; )==""\n"
857R"==(gen9_reduce_common( )==""\n"
858R"==(reduce_temp + REDUCE_STAT_NBLOCKS * IC16, local_sum, variance); )==""\n"
859R"==(} )==""\n"
860R"==(#if NHWC_OPTIMIZED )==""\n"
// gen9_bnorm_fwd (NHWC_OPTIMIZED variant).
// For every row of its block and every 16-wide slice it computes
//   d = (s - mean) * (scale / sqrt(variance + eps)) + shift
// and stores d to dst. Optional stages, selected by compile-time macros:
//   FUSE_BN_ADD_RELU - add the matching value of src_add before the ReLU mask;
//   FUSE_BN_RELU     - zero out values that are not greater than 0 and, when
//                      IS_TRAINING, store the comparison mask to ws;
//   WITH_RELU        - clamp d with max(d, 0).
// Slices are processed VECT_SIZE at a time; the IC_TAIL_SGROUPS slices left
// over are processed with scalar (1x16) loads and stores.
861R"==(KERNEL_ATTR )==""\n"
862R"==(__kernel void gen9_bnorm_fwd(__global DATA_T *src, __global float *mean, )==""\n"
863R"==(__global float *variance, __global DATA_T *dst, )==""\n"
864R"==(__global float *scaleshift, __global float *shift, __global char *ws, )==""\n"
865R"==(float eps, __global DATA_T *src_add) { )==""\n"
// NOTE(review): n is read here but not referenced again in this variant.
866R"==(const int n = GWS_GET_MB(); )==""\n"
867R"==(const int c = GWS_GET_IC(); )==""\n"
868R"==(const int sp = GWS_GET_SP() * STAT_SP_BLOCK; )==""\n"
869R"==(const int ic_block_offset = (c / 16) * IC_BLOCK; )==""\n"
870R"==(mean += ic_block_offset; )==""\n"
871R"==(variance += ic_block_offset; )==""\n"
872R"==(shift += ic_block_offset; )==""\n"
873R"==(scaleshift += ic_block_offset; )==""\n"
// All data pointers share the same element offset: row sp, column
// ic_block_offset, with IC elements per row.
874R"==(const uint d_off = sp * IC + ic_block_offset; )==""\n"
875R"==(src += d_off; )==""\n"
876R"==(#if FUSE_BN_ADD_RELU )==""\n"
877R"==(src_add += d_off; )==""\n"
878R"==(#endif )==""\n"
879R"==(dst += d_off; )==""\n"
880R"==(#if FUSE_BN_RELU && IS_TRAINING )==""\n"
881R"==(ws += d_off; )==""\n"
882R"==(#endif )==""\n"
// Per-slice constants for the vectorized part. Despite its name,
// sqrt_variance holds sm / sqrt(v_variance + eps), i.e. the full multiplier.
883R"==(VECT_FLOAT_T sm[IC_BLOCK_SGROUPS / VECT_SIZE], )==""\n"
884R"==(sv[IC_BLOCK_SGROUPS / VECT_SIZE], )==""\n"
885R"==(v_mean[IC_BLOCK_SGROUPS / VECT_SIZE], )==""\n"
886R"==(v_variance[IC_BLOCK_SGROUPS / VECT_SIZE], )==""\n"
887R"==(sqrt_variance[IC_BLOCK_SGROUPS / VECT_SIZE]; )==""\n"
888R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS / VECT_SIZE; ++sg) { )==""\n"
889R"==(const int sg_idx = sg * 16 * VECT_SIZE; )==""\n"
// Scale defaults to 1 and shift to 0 when the corresponding flag is off.
890R"==(#if USE_SCALE == 1 )==""\n"
891R"==(sm[sg] = LOAD_VECT_FLOAT(&scaleshift[sg_idx]); )==""\n"
892R"==(#else )==""\n"
893R"==(sm[sg] = (VECT_FLOAT_T)1.0f; )==""\n"
894R"==(#endif )==""\n"
895R"==(#if USE_SHIFT == 1 )==""\n"
896R"==(sv[sg] = LOAD_VECT_FLOAT(&shift[sg_idx]); )==""\n"
897R"==(#else )==""\n"
898R"==(sv[sg] = (VECT_FLOAT_T)0.0f; )==""\n"
899R"==(#endif )==""\n"
900R"==(v_mean[sg] = LOAD_VECT_FLOAT(&mean[sg_idx]); )==""\n"
901R"==(v_variance[sg] = LOAD_VECT_FLOAT(&variance[sg_idx]); )==""\n"
902R"==(sqrt_variance[sg] = sm[sg] / sqrt(v_variance[sg] + (VECT_FLOAT_T)eps); )==""\n"
903R"==(} )==""\n"
// Same constants for the slices that do not fill a whole vector.
904R"==(#if HAS_IC_VECT_TAIL )==""\n"
905R"==(float sm_tail[IC_TAIL_SGROUPS], sv_tail[IC_TAIL_SGROUPS], )==""\n"
906R"==(v_mean_tail[IC_TAIL_SGROUPS], v_variance_tail[IC_TAIL_SGROUPS], )==""\n"
907R"==(sqrt_variance_tail[IC_TAIL_SGROUPS]; )==""\n"
908R"==(for (int sg = 0; sg < IC_TAIL_SGROUPS; ++sg) { )==""\n"
909R"==(const int sg_idx = (IC_VECT_SGROUPS + sg) * 16; )==""\n"
910R"==(#if USE_SCALE == 1 )==""\n"
911R"==(sm_tail[sg] = LOAD_FLOAT_1x16(&scaleshift[sg_idx]); )==""\n"
912R"==(#else )==""\n"
913R"==(sm_tail[sg] = 1.0f; )==""\n"
914R"==(#endif )==""\n"
915R"==(#if USE_SHIFT == 1 )==""\n"
916R"==(sv_tail[sg] = LOAD_FLOAT_1x16(&shift[sg_idx]); )==""\n"
917R"==(#else )==""\n"
918R"==(sv_tail[sg] = 0.0f; )==""\n"
919R"==(#endif )==""\n"
920R"==(v_mean_tail[sg] = LOAD_FLOAT_1x16(&mean[sg_idx]); )==""\n"
921R"==(v_variance_tail[sg] = LOAD_FLOAT_1x16(&variance[sg_idx]); )==""\n"
922R"==(sqrt_variance_tail[sg] = sm_tail[sg] / sqrt(v_variance_tail[sg] + eps); )==""\n"
923R"==(} )==""\n"
924R"==(#endif )==""\n"
// Row loop; the trip count is clamped to SP - sp when SP is not a multiple of
// STAT_SP_BLOCK.
925R"==(#if HAS_STAT_SP_BLOCK_TAIL )==""\n"
926R"==(for (int sp_idx = 0; sp_idx < min(STAT_SP_BLOCK, SP - sp); ++sp_idx) { )==""\n"
927R"==(#else )==""\n"
928R"==(for (int sp_idx = 0; sp_idx < STAT_SP_BLOCK; ++sp_idx) { )==""\n"
929R"==(#endif )==""\n"
930R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS / VECT_SIZE; ++sg) { )==""\n"
931R"==(const int sg_idx = sg * 16 * VECT_SIZE; )==""\n"
932R"==(VECT_FLOAT_T s_vect = LOAD_VECT_DATA(&src[sg_idx]); )==""\n"
933R"==(VECT_FLOAT_T d_vect )==""\n"
934R"==(= fma(s_vect - v_mean[sg], sqrt_variance[sg], sv[sg]); )==""\n"
935R"==(#if FUSE_BN_RELU )==""\n"
936R"==(#if FUSE_BN_ADD_RELU )==""\n"
937R"==(VECT_FLOAT_T s_add_vect = LOAD_VECT_DATA(&src_add[sg_idx]); )==""\n"
938R"==(d_vect += s_add_vect; )==""\n"
939R"==(#endif )==""\n"
// ws_vect is the d > 0 mask; select keeps d where the mask is set, else 0.
940R"==(VECT_INT_T ws_vect = isgreater(d_vect, (VECT_FLOAT_T)0.0f); )==""\n"
941R"==(d_vect = select((VECT_FLOAT_T)0.0f, d_vect, ws_vect); )==""\n"
942R"==(#if IS_TRAINING )==""\n"
943R"==(STORE_VECT_CHAR(&ws[sg_idx], ws_vect); )==""\n"
944R"==(#endif )==""\n"
945R"==(#endif )==""\n"
946R"==(#if WITH_RELU )==""\n"
947R"==(d_vect = max(d_vect, (VECT_FLOAT_T)0.0f); )==""\n"
948R"==(#endif )==""\n"
949R"==(STORE_VECT_DATA(&dst[sg_idx], d_vect); )==""\n"
950R"==(} )==""\n"
// Scalar counterpart of the loop above for the leftover slices.
951R"==(#if HAS_IC_VECT_TAIL )==""\n"
952R"==(for (int sg = 0; sg < IC_TAIL_SGROUPS; ++sg) { )==""\n"
953R"==(const int sg_idx = (IC_VECT_SGROUPS + sg) * 16; )==""\n"
954R"==(float s_tail = LOAD_DATA_1x16(&src[sg_idx]); )==""\n"
955R"==(float d_tail = fma(s_tail - v_mean_tail[sg], sqrt_variance_tail[sg], )==""\n"
956R"==(sv_tail[sg]); )==""\n"
957R"==(#if FUSE_BN_RELU )==""\n"
958R"==(#if FUSE_BN_ADD_RELU )==""\n"
959R"==(float s_add_tail = LOAD_DATA_1x16(&src_add[sg_idx]); )==""\n"
960R"==(d_tail += s_add_tail; )==""\n"
961R"==(#endif )==""\n"
962R"==(int ws_tail = isgreater(d_tail, 0.0f); )==""\n"
963R"==(d_tail = select(0.0f, d_tail, ws_tail); )==""\n"
964R"==(#if IS_TRAINING )==""\n"
965R"==(STORE_CHAR_1x16(&ws[sg_idx], convert_char(ws_tail)); )==""\n"
966R"==(#endif )==""\n"
967R"==(#endif )==""\n"
968R"==(#if WITH_RELU )==""\n"
969R"==(d_tail = max(d_tail, 0.0f); )==""\n"
970R"==(#endif )==""\n"
971R"==(STORE_DATA_1x16(&dst[sg_idx], d_tail); )==""\n"
972R"==(} )==""\n"
973R"==(#endif )==""\n"
// Advance every active pointer to the next row (IC elements further).
974R"==(src += IC; )==""\n"
975R"==(#if FUSE_BN_ADD_RELU )==""\n"
976R"==(src_add += IC; )==""\n"
977R"==(#endif )==""\n"
978R"==(dst += IC; )==""\n"
979R"==(#if FUSE_BN_RELU && IS_TRAINING )==""\n"
980R"==(ws += IC; )==""\n"
981R"==(#endif )==""\n"
982R"==(} )==""\n"
983R"==(} )==""\n"
984R"==(#else )==""\n"
// read_src_block: loads up to 8 rows of 16 values starting at src and returns
// them as a float8 (one element per row, one value per sub-group lane).
//   src - pointer already offset to the first row to read
//   c   - channel offset, used only to detect the last IC block (HAS_IC_TAIL)
//   sp  - row index, used to detect the SP tail and the last SP block
// Elements that are not loaded keep their initial value 0.0f.
985R"==(inline float8 read_src_block(__global DATA_T *src, int c, int sp) { )==""\n"
986R"==(float8 blockS0 = 0.0f; )==""\n"
987R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
988R"==(#if HAS_IC_TAIL )==""\n"
989R"==(const bool is_last_ic_block = c + 16 > IC; )==""\n"
990R"==(const bool is_last_sp_block = sp >= SP - VECT_SIZE; )==""\n"
991R"==(#endif )==""\n"
// SP tail: only SP - SP_TAIL rows exist. With an IC tail, the very last row
// of the last IC block is read per lane, lanes >= 8 yielding 0.0f.
992R"==(#if HAS_SP_TAIL )==""\n"
993R"==(if (sp == SP_TAIL) { )==""\n"
994R"==(for (int k = 0; k < SP - SP_TAIL; ++k) )==""\n"
995R"==(#if HAS_IC_TAIL )==""\n"
996R"==(if (k == SP - SP_TAIL - 1 && is_last_ic_block) )==""\n"
997R"==(blockS0[k] = simd_id < 8 )==""\n"
998R"==(? CONVERT_FLOAT_T(src[k * IC_BLOCK_STRIDE + simd_id]) )==""\n"
999R"==(: 0.0f; )==""\n"
1000R"==(else )==""\n"
1001R"==(#endif )==""\n"
1002R"==(blockS0[k] = LOAD_DATA_1x16(&src[k * IC_BLOCK_STRIDE]); )==""\n"
1003R"==(} else )==""\n"
1004R"==(#endif )==""\n"
// Full block of 8 rows.
1005R"==({ )==""\n"
1006R"==(#if USE_NHWC )==""\n"
// IC == 8: the *_HALF macro is defined outside this excerpt; presumably it
// fills only the even elements (TODO confirm). Odd elements k + 1 are taken
// from element k of the copy shifted down by 8 sub-group lanes.
1007R"==(#if IS_IC_EQ_8 )==""\n"
1008R"==(LOAD_DATA_Nx16_USING_LOOP_IDX_HALF(8, blockS0, src, 0); )==""\n"
1009R"==(float8 t0 = intel_sub_group_shuffle_down(blockS0, blockS0, 8); )==""\n"
1010R"==(for (int k = 0; k < 7; k += 2) )==""\n"
1011R"==(blockS0[k + 1] = t0[k]; )==""\n"
// IC tail: in the last IC block of the last SP block, row 7 is read per lane
// with lanes >= 8 yielding 0.0f; the first 7 rows use the regular macro.
1012R"==(#elif HAS_IC_TAIL )==""\n"
1013R"==(if (is_last_ic_block && is_last_sp_block) { )==""\n"
1014R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(7, blockS0, src, 0); )==""\n"
1015R"==(blockS0[7] = simd_id < 8 ? CONVERT_FLOAT_T(src[7 * IC + simd_id]) )==""\n"
1016R"==(: 0.0f; )==""\n"
1017R"==(} else { )==""\n"
1018R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, blockS0, src, 0); )==""\n"
1019R"==(} )==""\n"
1020R"==(#else )==""\n"
1021R"==(LOAD_DATA_Nx16_USING_LOOP_IDX(8, blockS0, src, 0); )==""\n"
1022R"==(#endif )==""\n"
1023R"==(#else )==""\n"
1024R"==(blockS0 = LOAD_DATA_8x16(&src[0]); )==""\n"
1025R"==(#endif )==""\n"
1026R"==(} )==""\n"
1027R"==(return blockS0; )==""\n"
1028R"==(} )==""\n"
// gen9_bnorm_fwd (variant used when NHWC_OPTIMIZED is 0).
// Each work-item reads one float8 block (up to 8 rows of one 16-wide slice)
// through read_src_block, computes
//   d = (s - mean) * (scale / sqrt(variance + eps)) + shift
// and writes the block to dst. Optional stages, selected by compile-time
// macros: FUSE_BN_ADD_RELU adds the block read from src_add, FUSE_BN_RELU
// zeroes values that are not greater than 0 (and stores the mask to ws when
// IS_TRAINING), WITH_RELU clamps with max(d, 0).
1029R"==(KERNEL_ATTR )==""\n"
1030R"==(__kernel void gen9_bnorm_fwd(__global DATA_T *src, __global float *mean, )==""\n"
1031R"==(__global float *variance, __global DATA_T *dst, )==""\n"
1032R"==(__global float *scaleshift, __global float *shift, __global char *ws, )==""\n"
1033R"==(float eps, __global DATA_T *src_add) { )==""\n"
1034R"==(const int n = GWS_GET_MB(); )==""\n"
1035R"==(const int c = GWS_GET_IC(); )==""\n"
1036R"==(const int sp = GWS_GET_SP() * VECT_SIZE; )==""\n"
1037R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
1038R"==(#if HAS_IC_TAIL )==""\n"
1039R"==(const bool is_last_ic_block = c + 16 > IC; )==""\n"
1040R"==(const bool is_last_sp_block = sp >= SP - VECT_SIZE; )==""\n"
1041R"==(#endif )==""\n"
// Element offset shared by src, dst, src_add and ws. With USE_NHWC rows are IC
// elements apart and n is not part of the offset; otherwise the offset is
// built from 16-wide blocks plus n * SP * IC.
1042R"==(#if USE_NHWC )==""\n"
1043R"==(const uint d_off = sp * IC + c; )==""\n"
1044R"==(#else )==""\n"
1045R"==(const uint d_off = (c & 15) + sp * 16 + (c & ~15) * SP + n * SP * IC; )==""\n"
1046R"==(#endif )==""\n"
1047R"==(src += d_off; )==""\n"
1048R"==(dst += d_off; )==""\n"
1049R"==(float8 blockS0 = read_src_block(src, c, sp); )==""\n"
1050R"==(#if FUSE_BN_ADD_RELU )==""\n"
1051R"==(src_add += d_off; )==""\n"
1052R"==(float8 block_S0_Add = read_src_block(src_add, c, sp); )==""\n"
1053R"==(#endif )==""\n"
1054R"==(float8 blockD0; )==""\n"
// Scale defaults to 1 and shift to 0 when the corresponding flag is off.
1055R"==(#if USE_SCALE == 1 )==""\n"
1056R"==(float sm = MAYBE_LAST_IC_LOAD_FLOAT_1x16(scaleshift, c); )==""\n"
1057R"==(#else )==""\n"
1058R"==(float sm = 1.0f; )==""\n"
1059R"==(#endif )==""\n"
1060R"==(#if USE_SHIFT == 1 )==""\n"
1061R"==(float sv = MAYBE_LAST_IC_LOAD_FLOAT_1x16(shift, c); )==""\n"
1062R"==(#else )==""\n"
1063R"==(float sv = 0.0f; )==""\n"
1064R"==(#endif )==""\n"
// Statistics: in the last IC block only lanes 0..7 read memory, the other
// lanes get 0.0f; otherwise a regular 1x16 load is used.
1065R"==(float v_mean, v_variance; )==""\n"
1066R"==(#if HAS_IC_TAIL )==""\n"
1067R"==(if (is_last_ic_block) { )==""\n"
1068R"==(v_mean = simd_id < 8 ? mean[c + simd_id] : 0.0f; )==""\n"
1069R"==(v_variance = simd_id < 8 ? variance[c + simd_id] : 0.0f; )==""\n"
1070R"==(} else )==""\n"
1071R"==(#endif )==""\n"
1072R"==({ )==""\n"
1073R"==(v_mean = LOAD_FLOAT_1x16(&mean[c]); )==""\n"
1074R"==(v_variance = LOAD_FLOAT_1x16(&variance[c]); )==""\n"
1075R"==(} )==""\n"
// Despite its name, sqrt_variance holds sm / sqrt(v_variance + eps).
1076R"==(float sqrt_variance = sm / sqrt(v_variance + eps); )==""\n"
1077R"==(blockD0 = fma(blockS0 - (float8)v_mean, (float8)sqrt_variance, (float8)sv); )==""\n"
1078R"==(#if FUSE_BN_RELU )==""\n"
1079R"==(#if FUSE_BN_ADD_RELU )==""\n"
1080R"==(blockD0 += block_S0_Add; )==""\n"
1081R"==(#endif )==""\n"
// blockWS0 is the d > 0 mask; select keeps d where the mask is set, else 0.
1082R"==(int8 blockWS0 = isgreater(blockD0, (float8)0.0f); )==""\n"
1083R"==(blockD0 = select((float8)0.0f, blockD0, blockWS0); )==""\n"
// Training only: store the mask to ws, row by row in the SP tail and for
// USE_NHWC, or with one 8x16 store otherwise.
1084R"==(#if IS_TRAINING )==""\n"
1085R"==(ws += d_off; )==""\n"
1086R"==(#if HAS_SP_TAIL )==""\n"
1087R"==(if (sp == SP_TAIL) { )==""\n"
1088R"==(for (int k = 0; k < SP - SP_TAIL; ++k) { )==""\n"
1089R"==(STORE_CHAR_1x16( )==""\n"
1090R"==(&ws[k * IC_BLOCK_STRIDE], convert_char(blockWS0[k])); )==""\n"
1091R"==(} )==""\n"
1092R"==(} else )==""\n"
1093R"==(#endif )==""\n"
1094R"==({ )==""\n"
1095R"==(#if USE_NHWC )==""\n"
1096R"==(for (int k = 0; k < 8; ++k) )==""\n"
1097R"==(STORE_CHAR_1x16( )==""\n"
1098R"==(&ws[k * IC_BLOCK_STRIDE], convert_char(blockWS0[k])); )==""\n"
1099R"==(#else )==""\n"
1100R"==(STORE_CHAR_8x16(&ws[0], convert_char8(blockWS0)); )==""\n"
1101R"==(#endif )==""\n"
1102R"==(} )==""\n"
1103R"==(#endif )==""\n"
1104R"==(#endif )==""\n"
// NOTE(review): the zero is cast to VECT_FLOAT_T here while blockD0 is a
// float8 and the rest of this kernel casts to float8 - confirm VECT_SIZE is
// always 8 for this variant.
1105R"==(#if WITH_RELU )==""\n"
1106R"==(blockD0 = max(blockD0, (VECT_FLOAT_T)0.0f); )==""\n"
1107R"==(#endif )==""\n"
// Store the result. In the last IC block (HAS_IC_TAIL) only lanes 0..7 write,
// one element each; otherwise 1x16 or 8x16 block stores are used.
1108R"==(#if HAS_SP_TAIL )==""\n"
1109R"==(if (sp == SP_TAIL) { )==""\n"
1110R"==(for (int k = 0; k < SP - SP_TAIL; ++k) { )==""\n"
1111R"==(#if HAS_IC_TAIL )==""\n"
1112R"==(if (is_last_ic_block) { )==""\n"
1113R"==(if (simd_id < 8) )==""\n"
1114R"==(dst[k * IC_BLOCK_STRIDE + simd_id] )==""\n"
1115R"==(= CONVERT_DATA_T(blockD0[k]); )==""\n"
1116R"==(} else )==""\n"
1117R"==(#endif )==""\n"
1118R"==(STORE_DATA_1x16(&dst[k * IC_BLOCK_STRIDE], blockD0[k]); )==""\n"
1119R"==(} )==""\n"
1120R"==(} else )==""\n"
1121R"==(#endif )==""\n"
1122R"==({ )==""\n"
1123R"==(#if USE_NHWC )==""\n"
1124R"==(for (int k = 0; k < 8; ++k) )==""\n"
1125R"==(#if HAS_IC_TAIL )==""\n"
1126R"==(if (is_last_ic_block) { )==""\n"
1127R"==(if (simd_id < 8) )==""\n"
1128R"==(dst[k * IC_BLOCK_STRIDE + simd_id] )==""\n"
1129R"==(= CONVERT_DATA_T(blockD0[k]); )==""\n"
1130R"==(} else )==""\n"
1131R"==(#endif )==""\n"
1132R"==(STORE_DATA_1x16(&dst[k * IC_BLOCK_STRIDE], blockD0[k]); )==""\n"
1133R"==(#else )==""\n"
1134R"==(STORE_DATA_8x16(&dst[0], blockD0); )==""\n"
1135R"==(#endif )==""\n"
1136R"==(} )==""\n"
1137R"==(} )==""\n"
1138R"==(#endif )==""\n"
1139R"==(#endif )==""\n"
1140R"==(#if IS_BWD == 1 )==""\n"
// Load helpers for the backward pass (compiled only when IS_BWD == 1).
//
// LOAD_{DATA,UINT,CHAR}_Nx16_USING_LOOP(n, dest, src): fill dest[0..n-1] with
// n separate 1x16 loads taken IC_BLOCK_STRIDE elements apart, starting at src.
1141R"==(#define LOAD_DATA_Nx16_USING_LOOP(n, dest, src) \ )==""\n"
1142R"==({ \ )==""\n"
1143R"==(for (int k = 0; k < n; ++k) { \ )==""\n"
1144R"==(dest[k] = LOAD_DATA_1x16(&src[k * IC_BLOCK_STRIDE]); \ )==""\n"
1145R"==(} \ )==""\n"
1146R"==(} )==""\n"
1147R"==(#define LOAD_UINT_Nx16_USING_LOOP(n, dest, src) \ )==""\n"
1148R"==({ \ )==""\n"
1149R"==(for (int k = 0; k < n; ++k) { \ )==""\n"
1150R"==(dest[k] = LOAD_UINT_1x16(&src[k * IC_BLOCK_STRIDE]); \ )==""\n"
1151R"==(} \ )==""\n"
1152R"==(} )==""\n"
1153R"==(#define LOAD_CHAR_Nx16_USING_LOOP(n, dest, src) \ )==""\n"
1154R"==({ \ )==""\n"
1155R"==(for (int k = 0; k < n; ++k) { \ )==""\n"
1156R"==(dest[k] = LOAD_CHAR_1x16(&src[k * IC_BLOCK_STRIDE]); \ )==""\n"
1157R"==(} \ )==""\n"
1158R"==(} )==""\n"
// LOAD_{DATA,UINT,CHAR}_8x16_USING_LAYOUT(dest, src): load 8 rows, using the
// row-by-row loop above when USE_NHWC is set and a single 8x16 block load
// otherwise. USE_NHWC is tested with a run-time style if, so both branches
// must compile for either layout.
1159R"==(#define LOAD_DATA_8x16_USING_LAYOUT(dest, src) \ )==""\n"
1160R"==({ \ )==""\n"
1161R"==(if (USE_NHWC) { \ )==""\n"
1162R"==(LOAD_DATA_Nx16_USING_LOOP(8, dest, src); \ )==""\n"
1163R"==(} else { \ )==""\n"
1164R"==(dest = LOAD_DATA_8x16(src); \ )==""\n"
1165R"==(} \ )==""\n"
1166R"==(} )==""\n"
1167R"==(#define LOAD_UINT_8x16_USING_LAYOUT(dest, src) \ )==""\n"
1168R"==({ \ )==""\n"
1169R"==(if (USE_NHWC) { \ )==""\n"
1170R"==(LOAD_UINT_Nx16_USING_LOOP(8, dest, src); \ )==""\n"
1171R"==(} else { \ )==""\n"
1172R"==(dest = LOAD_UINT_8x16(src); \ )==""\n"
1173R"==(} \ )==""\n"
1174R"==(} )==""\n"
1175R"==(#define LOAD_CHAR_8x16_USING_LAYOUT(dest, src) \ )==""\n"
1176R"==({ \ )==""\n"
1177R"==(if (USE_NHWC) { \ )==""\n"
1178R"==(LOAD_CHAR_Nx16_USING_LOOP(8, dest, src); \ )==""\n"
1179R"==(} else { \ )==""\n"
1180R"==(dest = LOAD_CHAR_8x16(src); \ )==""\n"
1181R"==(} \ )==""\n"
1182R"==(} )==""\n"
// LOAD_DATA_Nx16_USING_LOOP_HALF(n, dest, src): like the DATA loop above but
// with a step of 2, so only the even elements dest[0], dest[2], ... are
// written; the odd elements are left untouched.
1183R"==(#define LOAD_DATA_Nx16_USING_LOOP_HALF(n, dest, src) \ )==""\n"
1184R"==({ \ )==""\n"
1185R"==(for (int k = 0; k < n; k += 2) { \ )==""\n"
1186R"==(dest[k] = LOAD_DATA_1x16(&src[k * IC_BLOCK_STRIDE]); \ )==""\n"
1187R"==(} \ )==""\n"
1188R"==(} )==""\n"
// Fused reduction for the backward pass (only with FUSED_ATOMICS_REDUCTION).
//
// GET_SCALAR_VAL(v, idx): element idx of the per-work-item partials. In the
// NHWC_OPTIMIZED build with VECT_SIZE > 1 the partials are arrays of vectors,
// so idx is split into an array index and a vector component; otherwise v is
// indexed directly.
1189R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
1190R"==(#if NHWC_OPTIMIZED )==""\n"
1191R"==(#if VECT_SIZE > 1 )==""\n"
1192R"==(#define GET_SCALAR_VAL(v, idx) v[idx / VECT_SIZE][idx % VECT_SIZE] )==""\n"
1193R"==(#else )==""\n"
1194R"==(#define GET_SCALAR_VAL(v, idx) v[idx] )==""\n"
1195R"==(#endif )==""\n"
1196R"==(#else )==""\n"
1197R"==(#define GET_SCALAR_VAL(v, idx) v[idx] )==""\n"
1198R"==(#endif )==""\n"
// gen9_calc_fused_reduction: sums the per-work-item diff_gamma / diff_beta
// partials through __local memory and adds the totals to global memory with
// atomic_add_global.
//   diff_scale, diff_shift          - global destinations of the atomic adds
//   dst_offset                      - base index into the destinations
//   diff_gamma, diff_beta           - partials of the calling work-item
//   diff_gamma_tail, diff_beta_tail - partials of the slices at index
//                                     >= IC_VECT_SGROUPS; read only when
//                                     HAS_IC_VECT_TAIL && NHWC_OPTIMIZED
//   local_gamma, local_beta         - __local scratch buffers
1199R"==(void gen9_calc_fused_reduction(volatile __global atomic_float *diff_scale, )==""\n"
1200R"==(volatile __global atomic_float *diff_shift, int dst_offset, )==""\n"
1201R"==(#if NHWC_OPTIMIZED )==""\n"
1202R"==(VECT_FLOAT_T *diff_gamma, VECT_FLOAT_T *diff_beta, )==""\n"
1203R"==(#else )==""\n"
1204R"==(float *diff_gamma, float *diff_beta, )==""\n"
1205R"==(#endif )==""\n"
1206R"==(float *diff_gamma_tail, float *diff_beta_tail, )==""\n"
1207R"==(__local float *local_gamma, __local float *local_beta) { )==""\n"
1208R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
1209R"==(const int group_size = GWS_LWS1_CALC * GWS_LWS2_CALC; )==""\n"
1210R"==(const int sg_group_id = get_local_id(0) / 16; )==""\n"
1211R"==(const int local_id = get_local_id(1); )==""\n"
// Step 1: every work-item publishes its partials. Each value of local_id owns
// one line of CALC_SLM_LINE_SIZE floats; inside a line the slot is selected by
// sg_group_id, the slice index sg and the lane simd_id.
1212R"==(for (int sg = 0; sg < REDUCE_NUM_SGROUPS; ++sg) { )==""\n"
1213R"==(const int slm_offset = CALC_SLM_LINE_SIZE * local_id )==""\n"
1214R"==(+ REDUCE_NUM_SGROUPS * 16 * sg_group_id + sg * 16 + simd_id; )==""\n"
1215R"==(#if HAS_IC_VECT_TAIL && NHWC_OPTIMIZED )==""\n"
1216R"==(if (sg >= IC_VECT_SGROUPS) { )==""\n"
1217R"==(local_gamma[slm_offset] = diff_gamma_tail[sg - IC_VECT_SGROUPS]; )==""\n"
1218R"==(local_beta[slm_offset] = diff_beta_tail[sg - IC_VECT_SGROUPS]; )==""\n"
1219R"==(} else )==""\n"
1220R"==(#endif )==""\n"
1221R"==({ )==""\n"
1222R"==(local_gamma[slm_offset] = GET_SCALAR_VAL(diff_gamma, sg); )==""\n"
1223R"==(local_beta[slm_offset] = GET_SCALAR_VAL(diff_beta, sg); )==""\n"
1224R"==(} )==""\n"
1225R"==(} )==""\n"
// Step 2: wait until all lines have been written.
1226R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
// Step 3: work-items with local_id == 0 sum the matching slot of all
// group_size lines and add the totals atomically to global memory.
1227R"==(if (local_id == 0) { )==""\n"
1228R"==(for (int sg = 0; sg < REDUCE_NUM_SGROUPS; ++sg) { )==""\n"
1229R"==(float d_gamma = 0.f; )==""\n"
1230R"==(float d_beta = 0.f; )==""\n"
1231R"==(for (int gr_id = 0; gr_id < group_size; ++gr_id) { )==""\n"
1232R"==(const int off_local = CALC_SLM_LINE_SIZE * gr_id )==""\n"
1233R"==(+ REDUCE_NUM_SGROUPS * 16 * sg_group_id + sg * 16 )==""\n"
1234R"==(+ simd_id; )==""\n"
1235R"==(d_gamma += local_gamma[off_local]; )==""\n"
1236R"==(d_beta += local_beta[off_local]; )==""\n"
1237R"==(} )==""\n"
1238R"==(const int offset = dst_offset + sg * 16 + simd_id; )==""\n"
// With an IC tail, indices at or beyond IC are not written.
1239R"==(#if HAS_IC_TAIL )==""\n"
1240R"==(if (offset < IC) )==""\n"
1241R"==(#endif )==""\n"
1242R"==({ )==""\n"
1243R"==(atomic_add_global(&diff_scale[offset], d_gamma); )==""\n"
// When DIFF_SHIFT != 1 the beta sum goes to index
// IC + IC * REDUCE_STAT_NBLOCKS + offset of the buffer passed as diff_shift.
// NOTE(review): presumably a scratch buffer in that configuration - confirm
// against the host code.
1244R"==(#if DIFF_SHIFT == 1 )==""\n"
1245R"==(atomic_add_global(&diff_shift[offset], d_beta); )==""\n"
1246R"==(#else )==""\n"
1247R"==(atomic_add_global( )==""\n"
1248R"==(&diff_shift[IC + IC * REDUCE_STAT_NBLOCKS + offset], )==""\n"
1249R"==(d_beta); )==""\n"
1250R"==(#endif )==""\n"
1251R"==(} )==""\n"
1252R"==(} )==""\n"
1253R"==(} )==""\n"
1254R"==(return; )==""\n"
1255R"==(} )==""\n"
1256R"==(#endif )==""\n"
1257R"==(#if NHWC_OPTIMIZED )==""\n"
// gen9_calculate_stats (NHWC_OPTIMIZED variant, backward pass).
// For every 16-wide slice of its IC block the work-item accumulates, over up
// to STAT_SP_BLOCK rows,
//   diff_gamma += (src - mean) * diff_dst
//   diff_beta  += diff_dst
// With FUSE_BN_RELU, diff_dst is first replaced by 0 where the mask loaded
// from ws is not set. The partial sums are then either handed to
// gen9_calc_fused_reduction or stored to temp_reduce.
1258R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
1259R"==(__kernel void gen9_calculate_stats(__global DATA_T *src, __global float *mean, )==""\n"
1260R"==(__global DATA_T *diff_dst, __global char *ws, )==""\n"
1261R"==(__global float *temp_reduce, volatile __global atomic_float *diff_scale, )==""\n"
1262R"==(volatile __global atomic_float *diff_shift) { )==""\n"
// NOTE(review): mb is read here but not referenced again in this variant.
1263R"==(const int mb = GWS_GET_STAT_MB(); )==""\n"
1264R"==(const int c = GWS_GET_STAT_IC(); )==""\n"
1265R"==(const int sp_block_idx = GWS_GET_STAT_SP(); )==""\n"
1266R"==(const int ic_block_offset = (c / 16) * IC_BLOCK; )==""\n"
1267R"==(const int offset = ic_block_offset + sp_block_idx * STAT_SP_BLOCK * IC; )==""\n"
1268R"==(mean += ic_block_offset; )==""\n"
1269R"==(src += offset; )==""\n"
1270R"==(diff_dst += offset; )==""\n"
// ws is offset unconditionally but only read when FUSE_BN_RELU is set.
1271R"==(ws += offset; )==""\n"
// Load one float of mean per sub-group lane for every 16-wide slice.
1272R"==(float v_mean[IC_BLOCK_SGROUPS]; )==""\n"
1273R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS; ++sg) { )==""\n"
1274R"==(v_mean[sg] = as_float(intel_sub_group_block_read( )==""\n"
1275R"==((const __global uint *)(&mean[(sg * 16)]))); )==""\n"
1276R"==(} )==""\n"
// Accumulators: vectors for the slices that fill whole vectors, scalars for
// the remaining IC_TAIL_SGROUPS slices. Without a tail the tail pointers are
// NULL placeholders that only serve as call arguments.
1277R"==(VECT_FLOAT_T diff_gamma[IC_BLOCK_SGROUPS / VECT_SIZE] = {0.0f}; )==""\n"
1278R"==(VECT_FLOAT_T diff_beta[IC_BLOCK_SGROUPS / VECT_SIZE] = {0.0f}; )==""\n"
1279R"==(#if HAS_IC_VECT_TAIL )==""\n"
1280R"==(float diff_gamma_tail[IC_TAIL_SGROUPS] = {0.0f}; )==""\n"
1281R"==(float diff_beta_tail[IC_TAIL_SGROUPS] = {0.0f}; )==""\n"
1282R"==(#else )==""\n"
1283R"==(float *diff_gamma_tail = NULL; )==""\n"
1284R"==(float *diff_beta_tail = NULL; )==""\n"
1285R"==(#endif )==""\n"
// Row loop; the trip count is clamped so the last block stops at SP.
1286R"==(#if HAS_STAT_SP_BLOCK_TAIL )==""\n"
1287R"==(for (int sp = 0; sp < min(STAT_SP_BLOCK, SP - sp_block_idx * STAT_SP_BLOCK); )==""\n"
1288R"==(++sp) { )==""\n"
1289R"==(#else )==""\n"
1290R"==(for (int sp = 0; sp < STAT_SP_BLOCK; ++sp) { )==""\n"
1291R"==(#endif )==""\n"
1292R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS / VECT_SIZE; ++sg) { )==""\n"
// Here sg_idx is an element offset into src / diff_dst / ws.
1293R"==(const int sg_idx = sg * 16 * VECT_SIZE; )==""\n"
1294R"==(#if FUSE_BN_RELU )==""\n"
1295R"==(VECT_CHAR_T ws_vect = LOAD_VECT_CHAR(&ws[sg_idx]); )==""\n"
1296R"==(#endif )==""\n"
1297R"==(VECT_FLOAT_T src_vect = LOAD_VECT_DATA(&src[sg_idx]); )==""\n"
1298R"==(VECT_FLOAT_T dd_vect = LOAD_VECT_DATA(&diff_dst[sg_idx]); )==""\n"
1299R"==(VECT_FLOAT_T v0; )==""\n"
1300R"==(for (int vect = 0; vect < VECT_SIZE; ++vect) { )==""\n"
// This inner sg_idx shadows the outer one and is a slice index into v_mean.
1301R"==(int sg_idx = sg * VECT_SIZE + vect; )==""\n"
1302R"==(#if VECT_SIZE > 1 )==""\n"
1303R"==(v0[vect] = src_vect[vect] - v_mean[sg_idx]; )==""\n"
1304R"==(#else )==""\n"
1305R"==(v0 = src_vect - v_mean[sg_idx]; )==""\n"
1306R"==(#endif )==""\n"
1307R"==(} )==""\n"
// Apply the stored mask: keep diff_dst where ws is set, else use 0.
1308R"==(#if FUSE_BN_RELU )==""\n"
1309R"==(dd_vect = select( )==""\n"
1310R"==((VECT_FLOAT_T)0.0f, dd_vect, CONVERT_VECT_INT_T(ws_vect)); )==""\n"
1311R"==(#endif )==""\n"
1312R"==(diff_gamma[sg] = fma(v0, dd_vect, diff_gamma[sg]); )==""\n"
1313R"==(diff_beta[sg] += dd_vect; )==""\n"
1314R"==(} )==""\n"
// Scalar counterpart of the loop above for the leftover slices.
1315R"==(#if HAS_IC_VECT_TAIL )==""\n"
1316R"==(for (int sg = 0; sg < IC_TAIL_SGROUPS; ++sg) { )==""\n"
1317R"==(const int sg_idx = IC_VECT_SGROUPS + sg; )==""\n"
1318R"==(#if FUSE_BN_RELU )==""\n"
1319R"==(char ws_tail = LOAD_CHAR_1x16(&ws[sg_idx * 16]); )==""\n"
1320R"==(#endif )==""\n"
1321R"==(float src_tail = LOAD_DATA_1x16(&src[sg_idx * 16]); )==""\n"
1322R"==(float dd_tail = LOAD_DATA_1x16(&diff_dst[sg_idx * 16]); )==""\n"
1323R"==(float v0 = src_tail - v_mean[sg_idx]; )==""\n"
1324R"==(#if FUSE_BN_RELU )==""\n"
1325R"==(dd_tail = select(0.0f, dd_tail, convert_int(ws_tail)); )==""\n"
1326R"==(#endif )==""\n"
1327R"==(diff_gamma_tail[sg] = fma(v0, dd_tail, diff_gamma_tail[sg]); )==""\n"
1328R"==(diff_beta_tail[sg] += dd_tail; )==""\n"
1329R"==(} )==""\n"
1330R"==(#endif )==""\n"
// Advance to the next row (IC elements further).
1331R"==(src += IC; )==""\n"
1332R"==(diff_dst += IC; )==""\n"
1333R"==(#if FUSE_BN_RELU )==""\n"
1334R"==(ws += IC; )==""\n"
1335R"==(#endif )==""\n"
1336R"==(} )==""\n"
// Fused path: one __local array of 2 * CALC_SLM_SIZE floats is split in two
// halves, the first for gamma and the second for beta.
1337R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
1338R"==(__local float local_gamma[2 * CALC_SLM_SIZE]; )==""\n"
1339R"==(__local float *local_beta = local_gamma + CALC_SLM_SIZE; )==""\n"
1340R"==(gen9_calc_fused_reduction(diff_scale, diff_shift, ic_block_offset, )==""\n"
1341R"==(diff_gamma, diff_beta, diff_gamma_tail, diff_beta_tail, local_gamma, )==""\n"
1342R"==(local_beta); )==""\n"
// Non-fused path: one 1x16 store per slice and per quantity into temp_reduce.
// Gamma partials start at IC16; beta partials start at
// 2 * IC16 + REDUCE_STAT_NBLOCKS * IC16.
1343R"==(#else )==""\n"
1344R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS; ++sg) { )==""\n"
1345R"==(const int reduce_off = sp_block_idx * 16 )==""\n"
1346R"==(+ REDUCE_STAT_NBLOCKS * 16 )==""\n"
1347R"==(* (sg + (int)(c / 16) * (IC_BLOCK / 16)); )==""\n"
1348R"==(const int diff_gamma_offset = IC16 + reduce_off; )==""\n"
1349R"==(const int diff_beta_offset )==""\n"
1350R"==(= 2 * IC16 + REDUCE_STAT_NBLOCKS * IC16 + reduce_off; )==""\n"
1351R"==(#if HAS_IC_VECT_TAIL )==""\n"
1352R"==(if (sg >= IC_VECT_SGROUPS) { )==""\n"
1353R"==(STORE_FLOAT_1x16(&temp_reduce[diff_gamma_offset], )==""\n"
1354R"==(diff_gamma_tail[sg - IC_VECT_SGROUPS]); )==""\n"
1355R"==(STORE_FLOAT_1x16(&temp_reduce[diff_beta_offset], )==""\n"
1356R"==(diff_beta_tail[sg - IC_VECT_SGROUPS]); )==""\n"
1357R"==(} else )==""\n"
1358R"==(#endif )==""\n"
1359R"==({ )==""\n"
1360R"==(#if VECT_SIZE > 1 )==""\n"
1361R"==(STORE_FLOAT_1x16(&temp_reduce[diff_gamma_offset], )==""\n"
1362R"==(diff_gamma[sg / VECT_SIZE][sg % VECT_SIZE]); )==""\n"
1363R"==(STORE_FLOAT_1x16(&temp_reduce[diff_beta_offset], )==""\n"
1364R"==(diff_beta[sg / VECT_SIZE][sg % VECT_SIZE]); )==""\n"
1365R"==(#else )==""\n"
1366R"==(STORE_FLOAT_1x16(&temp_reduce[diff_gamma_offset], diff_gamma[sg]); )==""\n"
1367R"==(STORE_FLOAT_1x16(&temp_reduce[diff_beta_offset], diff_beta[sg]); )==""\n"
1368R"==(#endif )==""\n"
1369R"==(} )==""\n"
1370R"==(} )==""\n"
1371R"==(#endif )==""\n"
1372R"==(} )==""\n"
1373R"==(#else )==""\n"
// OpenCL kernel gen9_calculate_stats (embedded kernel text; this definition follows an
// #else whose opening #if is outside this chunk).
// For one 16-lane channel group and one spatial block it accumulates
//   diff_gamma += (src - mean) * diff_dst   and   diff_beta += diff_dst,
// optionally masking diff_dst with ws when FUSE_BN_RELU is set, and then either passes
// the partial sums to gen9_calc_fused_reduction (FUSED_ATOMICS_REDUCTION) or stores
// them into temp_reduce.
1374R"==(NAMED_KERNEL_ATTR(CALC) )==""\n"
1375R"==(__kernel void gen9_calculate_stats(__global DATA_T *src, __global float *mean, )==""\n"
1376R"==(__global DATA_T *diff_dst, __global char *ws, )==""\n"
1377R"==(__global float *temp_reduce, volatile __global atomic_float *diff_scale, )==""\n"
1378R"==(volatile __global atomic_float *diff_shift) { )==""\n"
// Work-item coordinates come from the GWS_GET_STAT_* macros (defined outside this chunk).
1379R"==(const int mb = GWS_GET_STAT_MB(); )==""\n"
1380R"==(const int c = GWS_GET_STAT_IC(); )==""\n"
1381R"==(const int sp_block_idx = GWS_GET_STAT_SP(); )==""\n"
1382R"==(const int mb_sp_idx = mb * STAT_SP_NBLOCKS + sp_block_idx; )==""\n"
1383R"==(const int group_c_offset = REDUCE_STAT_NBLOCKS * 16 * (int)(c / 16); )==""\n"
1384R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
1385R"==(#if HAS_IC_TAIL )==""\n"
1386R"==(const bool is_last_ic_block = c + 16 > IC; )==""\n"
1387R"==(const bool is_last_sp_block = (sp_block_idx == STAT_SP_NBLOCKS - 1); )==""\n"
1388R"==(#endif )==""\n"
// Move temp_reduce to the region of this 16-channel group.
1389R"==(temp_reduce += group_c_offset; )==""\n"
// Element offset of the first point handled by this work-item, for the two layouts.
1390R"==(#if USE_NHWC )==""\n"
1391R"==(const int offset = c + sp_block_idx * STAT_SP_BLOCK * IC; )==""\n"
1392R"==(#else )==""\n"
1393R"==(const int offset = (c & 15) + sp_block_idx * STAT_SP_BLOCK * 16 )==""\n"
1394R"==(+ (c & ~15) * SP + mb * SP * IC; )==""\n"
1395R"==(#endif )==""\n"
1396R"==(src += offset; )==""\n"
1397R"==(diff_dst += offset; )==""\n"
1398R"==(ws += offset; )==""\n"
1399R"==(float v_mean = MAYBE_LAST_IC_LOAD_FLOAT_1x16(mean, c); )==""\n"
// Eight independent accumulators, one per spatial point of an 8-row load.
1400R"==(float8 diff_gamma = 0.0f; )==""\n"
1401R"==(float8 diff_beta = 0.0f; )==""\n"
// sp = number of spatial points this work-item processes; the block whose index equals
// STAT_SP_TAIL gets the remainder SP - STAT_SP_TAIL * STAT_SP_BLOCK.
1402R"==(#if HAS_STAT_SP_TAIL )==""\n"
1403R"==(int sp; )==""\n"
1404R"==(if (sp_block_idx == STAT_SP_TAIL) { )==""\n"
1405R"==(sp = SP - STAT_SP_TAIL * STAT_SP_BLOCK; )==""\n"
1406R"==(} else { )==""\n"
1407R"==(sp = STAT_SP_BLOCK; )==""\n"
1408R"==(} )==""\n"
1409R"==(#else )==""\n"
1410R"==(int sp = STAT_SP_BLOCK; )==""\n"
1411R"==(#endif )==""\n"
// Main loop: consumes 8 spatial points per iteration while at least 8 remain.
1412R"==(const int C_PARALLEL_FACTOR = 8; )==""\n"
1413R"==(for (; sp > C_PARALLEL_FACTOR - 1; sp -= C_PARALLEL_FACTOR) { )==""\n"
1414R"==(float8 src_data; )==""\n"
1415R"==(float8 dd_data; )==""\n"
1416R"==(#if FUSE_BN_RELU == 1 )==""\n"
1417R"==(char8 ws_data; )==""\n"
1418R"==(LOAD_CHAR_8x16_USING_LAYOUT(ws_data, ws); )==""\n"
1419R"==(#endif )==""\n"
// IC == 8 path: after the half loads, odd components are filled from the even components
// of the lane 8 positions further down the sub-group.
// NOTE(review): presumably each loaded row packs two spatial points across the 16 lanes;
// confirm against LOAD_DATA_Nx16_USING_LOOP_HALF, which is defined outside this chunk.
1420R"==(#if IS_IC_EQ_8 )==""\n"
1421R"==(LOAD_DATA_Nx16_USING_LOOP_HALF(8, src_data, src); )==""\n"
1422R"==(LOAD_DATA_Nx16_USING_LOOP_HALF(8, dd_data, diff_dst); )==""\n"
1423R"==(float8 t_src = intel_sub_group_shuffle_down(src_data, src_data, 8); )==""\n"
1424R"==(float8 t_dd = intel_sub_group_shuffle_down(dd_data, dd_data, 8); )==""\n"
1425R"==(for (int k = 0; k < 7; k += 2) { )==""\n"
1426R"==(dd_data[k + 1] = t_dd[k]; )==""\n"
1427R"==(src_data[k + 1] = t_src[k]; )==""\n"
1428R"==(} )==""\n"
// IC-tail path: is_last_sp is true on the final iteration of this loop. For the last IC
// block of the last spatial block, the 8th row is read per lane and lanes with
// simd_id >= 8 contribute 0 instead of being loaded.
1429R"==(#elif HAS_IC_TAIL )==""\n"
1430R"==(const bool is_last_sp = sp - C_PARALLEL_FACTOR <= C_PARALLEL_FACTOR - 1; )==""\n"
1431R"==(if (is_last_sp && is_last_ic_block && is_last_sp_block) { )==""\n"
1432R"==(LOAD_DATA_Nx16_USING_LOOP(7, src_data, src); )==""\n"
1433R"==(LOAD_DATA_Nx16_USING_LOOP(7, dd_data, diff_dst); )==""\n"
1434R"==(dd_data[7] = simd_id < 8 )==""\n"
1435R"==(? CONVERT_FLOAT_T(diff_dst[7 * IC_BLOCK_STRIDE + simd_id]) )==""\n"
1436R"==(: 0.0f; )==""\n"
1437R"==(src_data[7] = simd_id < 8 )==""\n"
1438R"==(? CONVERT_FLOAT_T(src[7 * IC_BLOCK_STRIDE + simd_id]) )==""\n"
1439R"==(: 0.0f; )==""\n"
1440R"==(} else { )==""\n"
1441R"==(LOAD_DATA_Nx16_USING_LOOP(8, src_data, src); )==""\n"
1442R"==(LOAD_DATA_Nx16_USING_LOOP(8, dd_data, diff_dst); )==""\n"
1443R"==(} )==""\n"
1444R"==(#else )==""\n"
1445R"==(LOAD_DATA_8x16_USING_LAYOUT(src_data, src); )==""\n"
1446R"==(LOAD_DATA_8x16_USING_LAYOUT(dd_data, diff_dst); )==""\n"
1447R"==(#endif )==""\n"
// Advance the input pointers past the 8 points just read.
1448R"==(src += C_PARALLEL_FACTOR * IC_BLOCK_STRIDE; )==""\n"
1449R"==(diff_dst += C_PARALLEL_FACTOR * IC_BLOCK_STRIDE; )==""\n"
1450R"==(#if FUSE_BN_RELU == 1 )==""\n"
1451R"==(ws += C_PARALLEL_FACTOR * IC_BLOCK_STRIDE; )==""\n"
1452R"==(#endif )==""\n"
// With fused ReLU, dd_data is masked by ws_data through select() before accumulation.
1453R"==(#if FUSE_BN_RELU == 1 )==""\n"
1454R"==(const float8 C_ZERO = 0.0; )==""\n"
1455R"==(dd_data = select(C_ZERO, dd_data, convert_int8(ws_data)); )==""\n"
1456R"==(#endif )==""\n"
1457R"==(const float8 v0 = src_data - v_mean; )==""\n"
1458R"==(diff_gamma = fma(v0, dd_data, diff_gamma); )==""\n"
1459R"==(diff_beta += dd_data; )==""\n"
1460R"==(} )==""\n"
// Spatial tail: the block with index STAT_SP_TAIL processes its remaining
// (count % 8) points one at a time, accumulating into component 0.
1461R"==(#if HAS_STAT_SP_TAIL )==""\n"
1462R"==(if (sp_block_idx == STAT_SP_TAIL) { )==""\n"
1463R"==(sp = (SP - STAT_SP_TAIL * STAT_SP_BLOCK) % C_PARALLEL_FACTOR; )==""\n"
// sp is post-decremented, so inside the body sp == 0 marks the last remaining point.
1464R"==(while (sp-- >= 1) { )==""\n"
1465R"==(#if FUSE_BN_RELU == 1 )==""\n"
1466R"==(const char ws_data = LOAD_CHAR_1x16(&ws[0]); )==""\n"
1467R"==(#else )==""\n"
1468R"==(const char ws_data = 1; )==""\n"
1469R"==(#endif )==""\n"
1470R"==(#if HAS_IC_TAIL )==""\n"
1471R"==(float src_data, dd_data; )==""\n"
// Last point of the last IC block: per-lane reads, lanes with simd_id >= 8 contribute 0.
1472R"==(if (sp == 0 && is_last_ic_block) { )==""\n"
1473R"==(src_data = simd_id < 8 ? CONVERT_FLOAT_T(src[simd_id]) : 0.0f; )==""\n"
1474R"==(dd_data = simd_id < 8 ? CONVERT_FLOAT_T(diff_dst[simd_id]) )==""\n"
1475R"==(: 0.0f; )==""\n"
1476R"==(} else { )==""\n"
1477R"==(src_data = LOAD_DATA_1x16(&src[0]); )==""\n"
1478R"==(dd_data = LOAD_DATA_1x16(&diff_dst[0]); )==""\n"
1479R"==(} )==""\n"
1480R"==(#else )==""\n"
1481R"==(const float src_data = LOAD_DATA_1x16(&src[0]); )==""\n"
1482R"==(const float dd_data = LOAD_DATA_1x16(&diff_dst[0]); )==""\n"
1483R"==(#endif )==""\n"
1484R"==(src += IC_BLOCK_STRIDE; )==""\n"
1485R"==(diff_dst += IC_BLOCK_STRIDE; )==""\n"
1486R"==(#if FUSE_BN_RELU == 1 )==""\n"
1487R"==(ws += IC_BLOCK_STRIDE; )==""\n"
1488R"==(#endif )==""\n"
// Points whose ws value is zero are skipped entirely.
1489R"==(if (ws_data != 0) { )==""\n"
1490R"==(const float v0 = src_data - v_mean; )==""\n"
1491R"==(const float diff_gamma_tmp = fma(v0, dd_data, diff_gamma[0]); )==""\n"
1492R"==(diff_gamma[0] = diff_gamma_tmp; )==""\n"
1493R"==(diff_beta[0] += dd_data; )==""\n"
1494R"==(} )==""\n"
1495R"==(} )==""\n"
1496R"==(} )==""\n"
1497R"==(#endif )==""\n"
// Fold the 8 accumulator components into component 0.
1498R"==(for (int i = 1; i < 8; i++) { )==""\n"
1499R"==(diff_gamma[0] += diff_gamma[i]; )==""\n"
1500R"==(diff_beta[0] += diff_beta[i]; )==""\n"
1501R"==(} )==""\n"
// Output: either the fused reduction helper (defined outside this chunk) or a store of
// the two partial sums into their regions of temp_reduce, indexed by mb_sp_idx.
1502R"==(#if FUSED_ATOMICS_REDUCTION )==""\n"
1503R"==(__local float local_gamma[2 * CALC_SLM_SIZE]; )==""\n"
1504R"==(__local float *local_beta = local_gamma + CALC_SLM_SIZE; )==""\n"
1505R"==(gen9_calc_fused_reduction(diff_scale, diff_shift, c, &diff_gamma, )==""\n"
1506R"==(&diff_beta, NULL, NULL, local_gamma, local_beta); )==""\n"
1507R"==(#else )==""\n"
1508R"==(STORE_FLOAT_1x16(&temp_reduce[IC16 + mb_sp_idx * 16], diff_gamma[0]); )==""\n"
1509R"==(STORE_FLOAT_1x16(&temp_reduce[2 * IC16 + REDUCE_STAT_NBLOCKS * IC16 )==""\n"
1510R"==(+ mb_sp_idx * 16], )==""\n"
1511R"==(diff_beta[0]); )==""\n"
1512R"==(#endif )==""\n"
1513R"==(} )==""\n"
1514R"==(#endif )==""\n"
// OpenCL kernel gen9_reduce_stats (embedded kernel text).
// Sums the partial diff_gamma / diff_beta values stored in temp_reduce for one
// 16-channel group, combines the per-sub-group sums through local memory, and writes
//   diff_scale[c] = sum_gamma / sqrt(variance[c] + eps)
// plus the diff_beta sum into diff_shift.
// Review fix: variance[c] is now read only inside the block that is guarded by the
// IC-tail lane check, so lanes that never write a result no longer read variance at
// indices belonging to the channel tail. Results for written lanes are unchanged.
1515R"==(NAMED_KERNEL_ATTR(REDUCE) )==""\n"
1516R"==(__kernel void gen9_reduce_stats(__global float *temp_reduce, )==""\n"
1517R"==(__global float *diff_scale, __global float *diff_shift, )==""\n"
1518R"==(__global float *variance, float eps) { )==""\n"
1519R"==(__local float local_gamma[16 * REDUCE_IC_SUB_GROUPS]; )==""\n"
1520R"==(__local float local_beta[16 * REDUCE_IC_SUB_GROUPS]; )==""\n"
1521R"==(const int ic_sub_group = get_global_id(0) / 16; )==""\n"
1522R"==(const int group_c = get_global_id(1); )==""\n"
1523R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
1524R"==(const int c = group_c * 16 + simd_id; )==""\n"
1525R"==(float diff_gamma = 0.0f; )==""\n"
1526R"==(float diff_beta = 0.0f; )==""\n"
// Point at this lane's first diff_gamma partial: past the leading IC16 floats, then the
// channel-group region, then this sub-group's share of the REDUCE_STAT_NBLOCKS entries.
1527R"==(temp_reduce += IC16 + REDUCE_STAT_NBLOCKS * 16 * group_c )==""\n"
1528R"==(+ REDUCE_STAT_NBLOCKS / REDUCE_IC_SUB_GROUPS * 16 * ic_sub_group )==""\n"
1529R"==(+ simd_id; )==""\n"
1530R"==(for (int i = 0; i < REDUCE_STAT_NBLOCKS / REDUCE_IC_SUB_GROUPS; i++) { )==""\n"
1531R"==(diff_gamma += temp_reduce[i * 16]; )==""\n"
1532R"==(} )==""\n"
// Jump to the matching position in the diff_beta region.
1533R"==(temp_reduce += IC16 + IC16 * REDUCE_STAT_NBLOCKS; )==""\n"
1534R"==(for (int i = 0; i < REDUCE_STAT_NBLOCKS / REDUCE_IC_SUB_GROUPS; i++) { )==""\n"
1535R"==(diff_beta += temp_reduce[i * 16]; )==""\n"
1536R"==(} )==""\n"
// Sub-groups other than 0 publish their sums to local memory; sub-group 0 adds them up
// after the barrier.
1537R"==(if (ic_sub_group > 0) { )==""\n"
1538R"==(local_gamma[ic_sub_group * 16 + simd_id] = diff_gamma; )==""\n"
1539R"==(local_beta[ic_sub_group * 16 + simd_id] = diff_beta; )==""\n"
1540R"==(} )==""\n"
1541R"==(barrier(CLK_LOCAL_MEM_FENCE); )==""\n"
1542R"==(if (ic_sub_group == 0) { )==""\n"
1543R"==(for (int i = 1; i < REDUCE_IC_SUB_GROUPS; i++) { )==""\n"
1544R"==(diff_gamma += local_gamma[i * 16 + simd_id]; )==""\n"
1545R"==(diff_beta += local_beta[i * 16 + simd_id]; )==""\n"
1546R"==(} )==""\n"
// In the last IC block only lanes with simd_id < 8 read variance and write results.
1547R"==(#if HAS_IC_TAIL )==""\n"
1548R"==(const bool is_last_ic_block = group_c * 16 + 16 > IC; )==""\n"
1549R"==(if (!is_last_ic_block || simd_id < 8) )==""\n"
1550R"==(#endif )==""\n"
1551R"==({ )==""\n"
1552R"==(const float sqrt_variance = 1.0f / sqrt(variance[c] + eps); )==""\n"
1553R"==(diff_scale[c] = diff_gamma * sqrt_variance; )==""\n"
// Without DIFF_SHIFT the diff_beta sum is written at an offset past the
// IC + IC * REDUCE_STAT_NBLOCKS leading floats of diff_shift.
1554R"==(#if DIFF_SHIFT == 1 )==""\n"
1555R"==(diff_shift[c] = diff_beta; )==""\n"
1556R"==(#else )==""\n"
1557R"==(diff_shift[IC + IC * REDUCE_STAT_NBLOCKS + c] = diff_beta; )==""\n"
1558R"==(#endif )==""\n"
1559R"==(} )==""\n"
1560R"==(} )==""\n"
1561R"==(} )==""\n"
// OpenCL kernel gen9_fused_reduce_final (embedded kernel text).
// Scales diff_scale[c] in place by 1 / sqrt(variance[c] + eps), where c is supplied by
// the GWS_GET_IC_AUX macro (defined outside this chunk).
1562R"==(NAMED_KERNEL_ATTR(AUX) )==""\n"
1563R"==(__kernel void gen9_fused_reduce_final( )==""\n"
1564R"==(__global float *diff_scale, __global float *variance, float eps) { )==""\n"
1565R"==(const int c = GWS_GET_IC_AUX(); )==""\n"
1566R"==(diff_scale[c] *= 1.0f / sqrt(variance[c] + eps); )==""\n"
1567R"==(return; )==""\n"
1568R"==(} )==""\n"
1569R"==(#if NHWC_OPTIMIZED )==""\n"
// OpenCL kernel gen9_bnorm_bwd, NHWC_OPTIMIZED variant (embedded kernel text).
// Each work-item handles IC_BLOCK channels (as IC_BLOCK_SGROUPS groups of 16 lanes) for
// one spatial block and writes diff_src:
//   dd = diff_dst (masked by ws through select() when FUSE_BN_RELU is set)
//   if CALCULATE_DIFF_STATS: dd -= (diff_beta + (src - mean) * diff_gamma * sqrt_variance)
//                                  / (MB * ID * IH * IW)
//   diff_src = dd * gamma * sqrt_variance,  with sqrt_variance = 1 / sqrt(variance + eps)
// Channel groups are processed VECT_SIZE at a time; leftover groups use the *_tail arrays.
1570R"==(KERNEL_ATTR )==""\n"
1571R"==(__kernel void gen9_bnorm_bwd(__global DATA_T *src, __global float *mean, )==""\n"
1572R"==(__global float *variance, __global DATA_T *diff_dst, )==""\n"
1573R"==(__global float *scaleshift, __global char *ws, )==""\n"
1574R"==(__global DATA_T *diff_src, __global float *diff_scale, )==""\n"
1575R"==(__global float *diff_shift, float eps, __global DATA_T *diff_src_add) { )==""\n"
1576R"==(const int c = GWS_GET_IC(); )==""\n"
// First channel of the IC_BLOCK-wide block this work-item owns.
1577R"==(const int ic_block_offset = (c / 16) * IC_BLOCK; )==""\n"
1578R"==(variance += ic_block_offset; )==""\n"
1579R"==(mean += ic_block_offset; )==""\n"
1580R"==(diff_scale += ic_block_offset; )==""\n"
1581R"==(diff_shift += ic_block_offset; )==""\n"
1582R"==(scaleshift += ic_block_offset; )==""\n"
// Per-channel parameters for the vectorized part, loaded once before the spatial loop.
1583R"==(VECT_FLOAT_T v_variance[IC_BLOCK_SGROUPS / VECT_SIZE], )==""\n"
1584R"==(v_mean[IC_BLOCK_SGROUPS / VECT_SIZE], )==""\n"
1585R"==(diff_gamma[IC_BLOCK_SGROUPS / VECT_SIZE], )==""\n"
1586R"==(diff_beta[IC_BLOCK_SGROUPS / VECT_SIZE], )==""\n"
1587R"==(sqrt_variance[IC_BLOCK_SGROUPS / VECT_SIZE], )==""\n"
1588R"==(gamma[IC_BLOCK_SGROUPS / VECT_SIZE]; )==""\n"
1589R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS / VECT_SIZE; ++sg) { )==""\n"
1590R"==(const int sg_idx = sg * 16 * VECT_SIZE; )==""\n"
1591R"==(v_variance[sg] = LOAD_VECT_FLOAT(&variance[sg_idx]); )==""\n"
// mean / diff_gamma / diff_beta are only loaded when CALCULATE_DIFF_STATS == 1.
1592R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
1593R"==(v_mean[sg] = LOAD_VECT_FLOAT(&mean[sg_idx]); )==""\n"
1594R"==(diff_gamma[sg] = LOAD_VECT_FLOAT(&diff_scale[sg_idx]); )==""\n"
1595R"==(#if DIFF_SHIFT == 1 )==""\n"
1596R"==(diff_beta[sg] = LOAD_VECT_FLOAT(&diff_shift[sg_idx]); )==""\n"
1597R"==(#else )==""\n"
1598R"==(diff_beta[sg] = LOAD_VECT_FLOAT( )==""\n"
1599R"==(&diff_shift[IC + REDUCE_STAT_NBLOCKS * IC + sg_idx]); )==""\n"
1600R"==(#endif )==""\n"
1601R"==(#endif )==""\n"
// gamma defaults to 1 when USE_SCALE is not enabled.
1602R"==(#if USE_SCALE == 1 )==""\n"
1603R"==(gamma[sg] = LOAD_VECT_FLOAT(&scaleshift[sg_idx]); )==""\n"
1604R"==(#else )==""\n"
1605R"==(gamma[sg] = (VECT_FLOAT_T)1.0f; )==""\n"
1606R"==(#endif )==""\n"
1607R"==(sqrt_variance[sg] )==""\n"
1608R"==(= (VECT_FLOAT_T)1.0f / sqrt(v_variance[sg] + (VECT_FLOAT_T)eps); )==""\n"
1609R"==(} )==""\n"
// Same parameters for the channel groups left over after vectorization
// (IC_TAIL_SGROUPS = IC_BLOCK_SGROUPS % VECT_SIZE), loaded one 16-lane group at a time.
1610R"==(#if HAS_IC_VECT_TAIL )==""\n"
1611R"==(float v_variance_tail[IC_TAIL_SGROUPS], v_mean_tail[IC_TAIL_SGROUPS], )==""\n"
1612R"==(diff_gamma_tail[IC_TAIL_SGROUPS], diff_beta_tail[IC_TAIL_SGROUPS], )==""\n"
1613R"==(sqrt_variance_tail[IC_TAIL_SGROUPS], gamma_tail[IC_TAIL_SGROUPS]; )==""\n"
1614R"==(for (int sg = 0; sg < IC_TAIL_SGROUPS; ++sg) { )==""\n"
1615R"==(const int sg_idx = (IC_VECT_SGROUPS + sg) * 16; )==""\n"
1616R"==(v_variance_tail[sg] = LOAD_FLOAT_1x16(&variance[sg_idx]); )==""\n"
1617R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
1618R"==(v_mean_tail[sg] = LOAD_FLOAT_1x16(&mean[sg_idx]); )==""\n"
1619R"==(diff_gamma_tail[sg] = LOAD_FLOAT_1x16(&diff_scale[sg_idx]); )==""\n"
1620R"==(#if DIFF_SHIFT == 1 )==""\n"
1621R"==(diff_beta_tail[sg] = LOAD_FLOAT_1x16(&diff_shift[sg_idx]); )==""\n"
1622R"==(#else )==""\n"
1623R"==(diff_beta_tail[sg] = LOAD_FLOAT_1x16( )==""\n"
1624R"==(&diff_shift[IC + REDUCE_STAT_NBLOCKS * IC + sg_idx]); )==""\n"
1625R"==(#endif )==""\n"
1626R"==(#endif )==""\n"
1627R"==(#if USE_SCALE == 1 )==""\n"
1628R"==(gamma_tail[sg] = LOAD_FLOAT_1x16(&scaleshift[sg_idx]); )==""\n"
1629R"==(#else )==""\n"
1630R"==(gamma_tail[sg] = 1.0f; )==""\n"
1631R"==(#endif )==""\n"
1632R"==(sqrt_variance_tail[sg] = 1.0f / sqrt(v_variance_tail[sg] + eps); )==""\n"
1633R"==(} )==""\n"
1634R"==(#endif )==""\n"
// Position the data pointers at this work-item's channel block and spatial block.
1635R"==(const int sp_block_idx = GWS_GET_SP(); )==""\n"
1636R"==(const int offset = ic_block_offset + sp_block_idx * STAT_SP_BLOCK * IC; )==""\n"
1637R"==(src += offset; )==""\n"
1638R"==(diff_dst += offset; )==""\n"
1639R"==(ws += offset; )==""\n"
1640R"==(diff_src += offset; )==""\n"
1641R"==(#if FUSE_BN_ADD_RELU )==""\n"
1642R"==(diff_src_add += offset; )==""\n"
1643R"==(#endif )==""\n"
// Spatial loop: when SP is not a multiple of STAT_SP_BLOCK the trip count is clamped to
// the points remaining in SP.
1644R"==(#if HAS_STAT_SP_BLOCK_TAIL )==""\n"
1645R"==(for (int sp = 0; sp < min(STAT_SP_BLOCK, SP - sp_block_idx * STAT_SP_BLOCK); )==""\n"
1646R"==(++sp) { )==""\n"
1647R"==(#else )==""\n"
1648R"==(for (int sp = 0; sp < STAT_SP_BLOCK; ++sp) { )==""\n"
1649R"==(#endif )==""\n"
// Vectorized channel groups.
1650R"==(for (int sg = 0; sg < IC_BLOCK_SGROUPS / VECT_SIZE; ++sg) { )==""\n"
1651R"==(const int sg_idx = sg * 16 * VECT_SIZE; )==""\n"
1652R"==(VECT_FLOAT_T src_vect = LOAD_VECT_DATA(&src[sg_idx]); )==""\n"
1653R"==(VECT_FLOAT_T dd_vect = LOAD_VECT_DATA(&diff_dst[sg_idx]); )==""\n"
// Fused ReLU: mask dd_vect with ws; with FUSE_BN_ADD_RELU the masked value is also
// written to diff_src_add.
1654R"==(#if FUSE_BN_RELU )==""\n"
1655R"==(VECT_CHAR_T ws_vect = LOAD_VECT_CHAR(&ws[sg_idx]); )==""\n"
1656R"==(dd_vect = select( )==""\n"
1657R"==((VECT_FLOAT_T)0.0f, dd_vect, CONVERT_VECT_INT_T(ws_vect)); )==""\n"
1658R"==(#if FUSE_BN_ADD_RELU )==""\n"
1659R"==(STORE_VECT_DATA(&diff_src_add[sg_idx], dd_vect); )==""\n"
1660R"==(#endif )==""\n"
1661R"==(#endif )==""\n"
1662R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
1663R"==(dd_vect -= (diff_beta[sg] )==""\n"
1664R"==(+ (src_vect - v_mean[sg]) * diff_gamma[sg] )==""\n"
1665R"==(* sqrt_variance[sg]) )==""\n"
1666R"==(/ (MB * ID * IH * IW); )==""\n"
1667R"==(#endif )==""\n"
1668R"==(dd_vect *= gamma[sg] * sqrt_variance[sg]; )==""\n"
1669R"==(STORE_VECT_DATA(&diff_src[sg_idx], dd_vect); )==""\n"
1670R"==(} )==""\n"
// Leftover channel groups: same computation on scalar-per-lane values.
1671R"==(#if HAS_IC_VECT_TAIL )==""\n"
1672R"==(for (int sg = 0; sg < IC_TAIL_SGROUPS; ++sg) { )==""\n"
1673R"==(const int sg_idx = (IC_VECT_SGROUPS + sg) * 16; )==""\n"
1674R"==(float src_tail = LOAD_DATA_1x16(&src[sg_idx]); )==""\n"
1675R"==(float dd_tail = LOAD_DATA_1x16(&diff_dst[sg_idx]); )==""\n"
1676R"==(#if FUSE_BN_RELU )==""\n"
1677R"==(char ws_tail = LOAD_CHAR_1x16(&ws[sg_idx]); )==""\n"
1678R"==(dd_tail = select(0.0f, dd_tail, convert_int(ws_tail)); )==""\n"
1679R"==(#if FUSE_BN_ADD_RELU )==""\n"
1680R"==(STORE_DATA_1x16(&diff_src_add[sg_idx], dd_tail); )==""\n"
1681R"==(#endif )==""\n"
1682R"==(#endif )==""\n"
1683R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
1684R"==(dd_tail -= (diff_beta_tail[sg] )==""\n"
1685R"==(+ (src_tail - v_mean_tail[sg]) )==""\n"
1686R"==(* diff_gamma_tail[sg] )==""\n"
1687R"==(* sqrt_variance_tail[sg]) )==""\n"
1688R"==(/ (MB * ID * IH * IW); )==""\n"
1689R"==(#endif )==""\n"
1690R"==(dd_tail *= gamma_tail[sg] * sqrt_variance_tail[sg]; )==""\n"
1691R"==(STORE_DATA_1x16(&diff_src[sg_idx], dd_tail); )==""\n"
1692R"==(} )==""\n"
1693R"==(#endif )==""\n"
// Step every pointer by IC elements to reach the next spatial point.
1694R"==(src += IC; )==""\n"
1695R"==(diff_dst += IC; )==""\n"
1696R"==(diff_src += IC; )==""\n"
1697R"==(#if FUSE_BN_RELU )==""\n"
1698R"==(#if FUSE_BN_ADD_RELU )==""\n"
1699R"==(diff_src_add += IC; )==""\n"
1700R"==(#endif )==""\n"
1701R"==(ws += IC; )==""\n"
1702R"==(#endif )==""\n"
1703R"==(} )==""\n"
1704R"==(} )==""\n"
1705R"==(#else )==""\n"
// Helper write_8x16_block (embedded kernel text): stores the 8 components of val
// starting at ptr.
//   ptr - destination; c - channel index used only for the IC-tail check; val - 8 values.
// With USE_NHWC the 8 rows are IC_BLOCK_STRIDE elements apart; in the last IC block
// (HAS_IC_TAIL, c + 16 > IC) only lanes with simd_id < 8 write, one element per lane.
// Without USE_NHWC a single STORE_DATA_8x16 is issued.
1706R"==(inline void write_8x16_block(__global DATA_T *ptr, int c, float8 val) { )==""\n"
1707R"==(#if USE_NHWC )==""\n"
1708R"==(#if HAS_IC_TAIL )==""\n"
1709R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
1710R"==(const bool is_last_ic_block = c + 16 > IC; )==""\n"
1711R"==(if (is_last_ic_block) { )==""\n"
1712R"==(if (simd_id < 8) { )==""\n"
1713R"==(for (int k = 0; k < 8; ++k) )==""\n"
1714R"==(ptr[k * IC_BLOCK_STRIDE + simd_id] = CONVERT_DATA_T(val[k]); )==""\n"
1715R"==(} )==""\n"
// The else below binds to the row-wise store loop that follows the #endif.
1716R"==(} else )==""\n"
1717R"==(#endif )==""\n"
1718R"==(for (int k = 0; k < 8; ++k) )==""\n"
1719R"==(STORE_DATA_1x16(&ptr[k * IC_BLOCK_STRIDE], val[k]); )==""\n"
1720R"==(#else )==""\n"
1721R"==(STORE_DATA_8x16(&ptr[0], val); )==""\n"
1722R"==(#endif )==""\n"
1723R"==(} )==""\n"
// Helper write_1x16_block (embedded kernel text): stores one value per lane at ptr.
//   ptr - destination; c - channel index used only for the IC-tail check; val - value.
// With HAS_IC_TAIL, in the last IC block (c + 16 > IC) only lanes with simd_id < 8
// write, using a per-lane store; every other case uses STORE_DATA_1x16.
1724R"==(inline void write_1x16_block(__global DATA_T *ptr, int c, float val) { )==""\n"
1725R"==(#if HAS_IC_TAIL )==""\n"
1726R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
1727R"==(const bool is_last_ic_block = c + 16 > IC; )==""\n"
1728R"==(if (!is_last_ic_block) { )==""\n"
1729R"==(STORE_DATA_1x16(&ptr[0], val); )==""\n"
1730R"==(} else { )==""\n"
1731R"==(if (simd_id < 8) { ptr[simd_id] = CONVERT_DATA_T(val); } )==""\n"
1732R"==(} )==""\n"
1733R"==(#else )==""\n"
1734R"==(STORE_DATA_1x16(&ptr[0], val); )==""\n"
1735R"==(#endif )==""\n"
1736R"==(} )==""\n"
// OpenCL kernel gen9_bnorm_bwd, variant used when NHWC_OPTIMIZED is not set
// (embedded kernel text).
// Each work-item handles one 16-lane channel group and up to VECT_SIZE spatial points,
// and writes diff_src:
//   dd = diff_dst (masked by ws when FUSE_BN_RELU is set)
//   if CALCULATE_DIFF_STATS: dd -= (diff_beta + (src - mean) * diff_gamma * sqrt_variance)
//                                  / (MB * ID * IH * IW)
//   diff_src = dd * gamma * sqrt_variance,  with sqrt_variance = 1 / sqrt(variance + eps)
// With FUSE_BN_ADD_RELU the masked dd is additionally written to diff_src_add.
1737R"==(KERNEL_ATTR )==""\n"
1738R"==(__kernel void gen9_bnorm_bwd(__global DATA_T *src, __global float *mean, )==""\n"
1739R"==(__global float *variance, __global DATA_T *diff_dst, )==""\n"
1740R"==(__global float *scaleshift, __global char *ws, )==""\n"
1741R"==(__global DATA_T *diff_src, __global float *diff_scale, )==""\n"
1742R"==(__global float *diff_shift, float eps, __global DATA_T *diff_src_add) { )==""\n"
1743R"==(const int c = GWS_GET_IC(); )==""\n"
1744R"==(const int simd_id = get_sub_group_local_id(); )==""\n"
1745R"==(#if HAS_IC_TAIL )==""\n"
1746R"==(const bool is_last_ic_block = c + 16 > IC; )==""\n"
1747R"==(#endif )==""\n"
// Per-channel parameters, loaded through MAYBE_LAST_IC_LOAD_FLOAT_1x16 (defined outside
// this chunk).
1748R"==(const float v_variance = MAYBE_LAST_IC_LOAD_FLOAT_1x16(variance, c); )==""\n"
1749R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
1750R"==(const float v_mean = MAYBE_LAST_IC_LOAD_FLOAT_1x16(mean, c); )==""\n"
1751R"==(const float diff_gamma = MAYBE_LAST_IC_LOAD_FLOAT_1x16(diff_scale, c); )==""\n"
1752R"==(#if DIFF_SHIFT == 1 )==""\n"
1753R"==(const float diff_beta = MAYBE_LAST_IC_LOAD_FLOAT_1x16(diff_shift, c); )==""\n"
1754R"==(#else )==""\n"
1755R"==(const float diff_beta = MAYBE_LAST_IC_LOAD_FLOAT_1x16( )==""\n"
1756R"==(diff_shift, IC + REDUCE_STAT_NBLOCKS * IC + c); )==""\n"
1757R"==(#endif )==""\n"
1758R"==(#endif )==""\n"
// gamma defaults to 1 when USE_SCALE is not enabled.
1759R"==(#if USE_SCALE == 1 )==""\n"
1760R"==(const float gamma = MAYBE_LAST_IC_LOAD_FLOAT_1x16(scaleshift, c); )==""\n"
1761R"==(#else )==""\n"
1762R"==(const float gamma = 1; )==""\n"
1763R"==(#endif )==""\n"
// Element offset of the first point handled by this work-item, for the two layouts.
1764R"==(const int sp_block_idx = GWS_GET_SP(); )==""\n"
1765R"==(#if USE_NHWC )==""\n"
1766R"==(const int offset = c + sp_block_idx * VECT_SIZE * IC; )==""\n"
1767R"==(#else )==""\n"
1768R"==(const int mb = GWS_GET_MB(); )==""\n"
1769R"==(const int offset = (c & 15) + sp_block_idx * VECT_SIZE * 16 + (c & ~15) * SP )==""\n"
1770R"==(+ mb * SP * IC; )==""\n"
1771R"==(#endif )==""\n"
1772R"==(#if HAS_IC_TAIL )==""\n"
1773R"==(const bool is_last_sp_block = sp_block_idx == SP / VECT_SIZE - 1; )==""\n"
1774R"==(#endif )==""\n"
1775R"==(src += offset; )==""\n"
1776R"==(diff_dst += offset; )==""\n"
1777R"==(ws += offset; )==""\n"
1778R"==(diff_src += offset; )==""\n"
1779R"==(#if FUSE_BN_ADD_RELU )==""\n"
1780R"==(diff_src_add += offset; )==""\n"
1781R"==(#endif )==""\n"
// sp = number of spatial points to process; the block with index SP_TAIL / VECT_SIZE
// gets the remainder SP - SP_TAIL.
1782R"==(#if HAS_SP_TAIL )==""\n"
1783R"==(int sp; )==""\n"
1784R"==(if (sp_block_idx == SP_TAIL / VECT_SIZE) { )==""\n"
1785R"==(sp = SP - SP_TAIL; )==""\n"
1786R"==(} else { )==""\n"
1787R"==(sp = VECT_SIZE; )==""\n"
1788R"==(} )==""\n"
1789R"==(#else )==""\n"
1790R"==(int sp = VECT_SIZE; )==""\n"
1791R"==(#endif )==""\n"
1792R"==(const float sqrt_variance = 1.0f / sqrt(v_variance + eps); )==""\n"
// Main loop: consumes 8 spatial points per iteration while at least 8 remain.
1793R"==(const int C_PARALLEL_FACTOR = 8; )==""\n"
1794R"==(for (; sp > C_PARALLEL_FACTOR - 1; sp -= C_PARALLEL_FACTOR) { )==""\n"
1795R"==(float8 src_data; )==""\n"
1796R"==(float8 dd_data; )==""\n"
1797R"==(#if FUSE_BN_RELU == 1 )==""\n"
1798R"==(char8 ws_data; )==""\n"
1799R"==(LOAD_CHAR_8x16_USING_LAYOUT(ws_data, ws); )==""\n"
1800R"==(#endif )==""\n"
// IC == 8 path: after the half loads, odd components are filled from the even components
// of the lane 8 positions further down the sub-group.
// NOTE(review): presumably each loaded row packs two spatial points across the 16 lanes;
// confirm against LOAD_DATA_Nx16_USING_LOOP_HALF, which is defined outside this chunk.
1801R"==(#if IS_IC_EQ_8 )==""\n"
1802R"==(LOAD_DATA_Nx16_USING_LOOP_HALF(8, src_data, src); )==""\n"
1803R"==(LOAD_DATA_Nx16_USING_LOOP_HALF(8, dd_data, diff_dst); )==""\n"
1804R"==(float8 t_dd = intel_sub_group_shuffle_down(dd_data, dd_data, 8); )==""\n"
1805R"==(float8 t_src = intel_sub_group_shuffle_down(src_data, src_data, 8); )==""\n"
1806R"==(for (int k = 0; k < 7; k += 2) { )==""\n"
1807R"==(dd_data[k + 1] = t_dd[k]; )==""\n"
1808R"==(src_data[k + 1] = t_src[k]; )==""\n"
1809R"==(} )==""\n"
// IC-tail path (only when there is no spatial tail): is_last_sp is true on the final
// iteration of this loop. For the last IC block of the last spatial block, the 8th row
// is read per lane and lanes with simd_id >= 8 get 0 instead of being loaded.
1810R"==(#elif HAS_IC_TAIL && !HAS_SP_TAIL )==""\n"
1811R"==(const bool is_last_sp = sp - C_PARALLEL_FACTOR <= C_PARALLEL_FACTOR - 1; )==""\n"
1812R"==(if (is_last_sp && is_last_ic_block && is_last_sp_block) { )==""\n"
1813R"==(LOAD_DATA_Nx16_USING_LOOP(7, src_data, src); )==""\n"
1814R"==(LOAD_DATA_Nx16_USING_LOOP(7, dd_data, diff_dst); )==""\n"
1815R"==(dd_data[7] = simd_id < 8 )==""\n"
1816R"==(? CONVERT_FLOAT_T(diff_dst[7 * IC_BLOCK_STRIDE + simd_id]) )==""\n"
1817R"==(: 0.0f; )==""\n"
1818R"==(src_data[7] = simd_id < 8 )==""\n"
1819R"==(? CONVERT_FLOAT_T(src[7 * IC_BLOCK_STRIDE + simd_id]) )==""\n"
1820R"==(: 0.0f; )==""\n"
1821R"==(} else { )==""\n"
1822R"==(LOAD_DATA_Nx16_USING_LOOP(8, src_data, src); )==""\n"
1823R"==(LOAD_DATA_Nx16_USING_LOOP(8, dd_data, diff_dst); )==""\n"
1824R"==(} )==""\n"
1825R"==(#else )==""\n"
1826R"==(LOAD_DATA_8x16_USING_LAYOUT(dd_data, diff_dst); )==""\n"
1827R"==(LOAD_DATA_8x16_USING_LAYOUT(src_data, src); )==""\n"
1828R"==(#endif )==""\n"
// Advance the input pointers past the 8 points just read.
1829R"==(src += C_PARALLEL_FACTOR * IC_BLOCK_STRIDE; )==""\n"
1830R"==(diff_dst += C_PARALLEL_FACTOR * IC_BLOCK_STRIDE; )==""\n"
1831R"==(#if FUSE_BN_RELU == 1 )==""\n"
1832R"==(ws += C_PARALLEL_FACTOR * IC_BLOCK_STRIDE; )==""\n"
1833R"==(#endif )==""\n"
// Fused ReLU: mask dd_data with ws_data through select(); with FUSE_BN_ADD_RELU the
// masked value is also written to diff_src_add.
1834R"==(#if FUSE_BN_RELU == 1 )==""\n"
1835R"==(const float8 C_ZERO = 0.0; )==""\n"
1836R"==(dd_data = select(C_ZERO, dd_data, convert_int8(ws_data)); )==""\n"
1837R"==(#if FUSE_BN_ADD_RELU )==""\n"
1838R"==(write_8x16_block(diff_src_add, c, dd_data); )==""\n"
1839R"==(#endif )==""\n"
1840R"==(#endif )==""\n"
1841R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
1842R"==(dd_data -= (diff_beta )==""\n"
1843R"==(+ (src_data - v_mean) * diff_gamma * sqrt_variance) )==""\n"
1844R"==(/ (MB * ID * IH * IW); )==""\n"
1845R"==(#endif )==""\n"
1846R"==(dd_data *= gamma * sqrt_variance; )==""\n"
1847R"==(write_8x16_block(diff_src, c, dd_data); )==""\n"
1848R"==(diff_src += C_PARALLEL_FACTOR * IC_BLOCK_STRIDE; )==""\n"
1849R"==(#if FUSE_BN_ADD_RELU )==""\n"
1850R"==(diff_src_add += C_PARALLEL_FACTOR * IC_BLOCK_STRIDE; )==""\n"
1851R"==(#endif )==""\n"
1852R"==(} )==""\n"
// Spatial tail: the block with index SP_TAIL / VECT_SIZE processes its remaining
// (SP - SP_TAIL) % 8 points one at a time.
1853R"==(#if HAS_SP_TAIL )==""\n"
1854R"==(if (sp_block_idx == SP_TAIL / VECT_SIZE) { )==""\n"
1855R"==(sp = (SP - SP_TAIL) % C_PARALLEL_FACTOR; )==""\n"
// sp is post-decremented, so inside the body sp == 0 marks the last remaining point.
1856R"==(while (sp-- >= 1) { )==""\n"
1857R"==(#if FUSE_BN_RELU == 1 )==""\n"
1858R"==(const char ws_data = LOAD_CHAR_1x16(&ws[0]); )==""\n"
1859R"==(#endif )==""\n"
// Last point of the last IC block: per-lane reads, lanes with simd_id >= 8 get 0.
// src_data is only needed (and only declared) when CALCULATE_DIFF_STATS == 1.
1860R"==(#if HAS_IC_TAIL )==""\n"
1861R"==(float dd_data; )==""\n"
1862R"==(if (sp == 0 && is_last_ic_block) )==""\n"
1863R"==(dd_data = simd_id < 8 ? CONVERT_FLOAT_T(diff_dst[simd_id]) )==""\n"
1864R"==(: 0.0f; )==""\n"
1865R"==(else )==""\n"
1866R"==(dd_data = LOAD_DATA_1x16(&diff_dst[0]); )==""\n"
1867R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
1868R"==(float src_data; )==""\n"
1869R"==(if (sp == 0 && is_last_ic_block) )==""\n"
1870R"==(src_data = simd_id < 8 ? CONVERT_FLOAT_T(src[simd_id]) : 0.0f; )==""\n"
1871R"==(else )==""\n"
1872R"==(src_data = LOAD_DATA_1x16(&src[0]); )==""\n"
1873R"==(#endif )==""\n"
1874R"==(#else )==""\n"
1875R"==(float dd_data = LOAD_DATA_1x16(&diff_dst[0]); )==""\n"
1876R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
1877R"==(const float src_data = LOAD_DATA_1x16(&src[0]); )==""\n"
1878R"==(#endif )==""\n"
1879R"==(#endif )==""\n"
1880R"==(src += IC_BLOCK_STRIDE; )==""\n"
1881R"==(diff_dst += IC_BLOCK_STRIDE; )==""\n"
1882R"==(#if FUSE_BN_RELU == 1 )==""\n"
1883R"==(ws += IC_BLOCK_STRIDE; )==""\n"
1884R"==(#endif )==""\n"
// Fused ReLU for a single point: a zero ws value zeroes dd_data.
1885R"==(#if FUSE_BN_RELU == 1 )==""\n"
1886R"==(if (ws_data == 0) dd_data = 0; )==""\n"
1887R"==(#if FUSE_BN_ADD_RELU )==""\n"
1888R"==(write_1x16_block(diff_src_add, c, dd_data); )==""\n"
1889R"==(#endif )==""\n"
1890R"==(#endif )==""\n"
1891R"==(#if CALCULATE_DIFF_STATS == 1 )==""\n"
1892R"==(dd_data -= (diff_beta )==""\n"
1893R"==(+ (src_data - v_mean) * diff_gamma )==""\n"
1894R"==(* sqrt_variance) )==""\n"
1895R"==(/ (MB * ID * IH * IW); )==""\n"
1896R"==(#endif )==""\n"
1897R"==(dd_data *= gamma * sqrt_variance; )==""\n"
1898R"==(write_1x16_block(diff_src, c, dd_data); )==""\n"
1899R"==(diff_src += IC_BLOCK_STRIDE; )==""\n"
1900R"==(#if FUSE_BN_ADD_RELU )==""\n"
1901R"==(diff_src_add += IC_BLOCK_STRIDE; )==""\n"
1902R"==(#endif )==""\n"
1903R"==(} )==""\n"
1904R"==(} )==""\n"
1905R"==(#endif )==""\n"
1906R"==(} )==""\n"
1907R"==(#endif )==""\n"
1908R"==(#endif )==""\n"
1909R"==()==";
1910}
1911}
1912}
1913}