namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source for the reference zero-padding kernels, stored as a single
// C string. Each adjacent literal below holds one kernel-source line ending in
// "\n"; the C++ compiler concatenates them into one string.
//
// Contents of the kernel source:
//   typed_ref_zero_pad   - helper: zeroes padded elements of a step block
//   sized_ref_zero_pad   - helper: passes type_size as a literal 4/2/1
//   ref_zero_pad         - generic entry kernel
//   ref_zero_pad_subg_16 - SIMD16 kernel zeroing via sub-group block writes
//   ref_zero_pad_subg_16_mask_and_clear_dt_1b - SIMD16 kernel for 1-byte data
//
// NOTE(review): the layout suggests this file is generated from a .cl kernel
// source. The kernel text has no indentation and its `//` comments are
// stripped, which had also cut the license URL down to "http:". Mirror any
// kernel change in the originating source, or regeneration will discard it.
const char *ref_zero_pad_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2020-2021 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#define IS_OCL_KERNEL )==""\n"
R"==(#include "gpu/zero_pad_struct.h" )==""\n"
R"==(#define DEFAULT_NELEMS_BLOCK 8 )==""\n"
// typed_ref_zero_pad: zeroes padded elements in `step_block` consecutive steps
// (each step_nelems elements long) of outer block get_global_id(2).
//   - Outer blocks are step_size elements long.
//   - The nsteps steps handled per block are the last nsteps * step_nelems
//     elements of that block.
//   - get_global_id(0) is the first element i0 within a step.
//   - get_global_id(1) selects the group of step_block steps.
// Modes:
//   - ZERO_PAD_BIT_MODE: element i is zeroed when bit (i % ZERO_PAD_MASK_DT_BITS)
//     of step_bitmask.mask[i / ZERO_PAD_MASK_DT_BITS] is set; i advances by
//     nelems_block.
//   - Any other mode: step_bitmask.mask[i0] is itself the element index to zero.
// type_size (4, 2 or 1) selects the element width.
// istep/iblock/offset are 64-bit: iblock * step_size is evaluated as ulong, and
// narrowing it to int wraps for buffers with more than INT_MAX elements.
R"==(static inline void typed_ref_zero_pad(__global void *a, ulong type_size, )==""\n"
R"==(ulong step_nelems, ulong nelems_block, ulong step_block, ulong nsteps, )==""\n"
R"==(ulong step_size, zero_pad_mask_t step_bitmask, ulong mode) { )==""\n"
R"==(const int i0 = get_global_id(0); )==""\n"
R"==(const long istep = get_global_id(1) * step_block; )==""\n"
R"==(const long iblock = get_global_id(2); )==""\n"
R"==(long offset = iblock * step_size + (step_size - nsteps * step_nelems) )==""\n"
R"==(+ istep * step_nelems; )==""\n"
R"==(const int step = ZERO_PAD_MASK_DT_BITS; )==""\n"
R"==(__global int *a4 = (__global int *)a; )==""\n"
R"==(__global short *a2 = (__global short *)a; )==""\n"
R"==(__global char *a1 = (__global char *)a; )==""\n"
R"==(if (mode == ZERO_PAD_BIT_MODE) { )==""\n"
R"==(for (int k = 0; k < step_block; k++) { )==""\n"
R"==(__attribute__((opencl_unroll_hint)) )==""\n"
R"==(for (int i = i0; i < step_nelems; i += nelems_block) { )==""\n"
R"==(if (step_bitmask.mask[i / step] & (1 << (i % step))) { )==""\n"
R"==(switch (type_size) { )==""\n"
R"==(case 4: a4[offset + i] = 0; break; )==""\n"
R"==(case 2: a2[offset + i] = 0; break; )==""\n"
R"==(case 1: a1[offset + i] = 0; break; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(offset += step_nelems; )==""\n"
R"==(} )==""\n"
R"==(} else { )==""\n"
R"==(int i = step_bitmask.mask[i0]; )==""\n"
R"==(for (int k = 0; k < step_block; k++) { )==""\n"
R"==(switch (type_size) { )==""\n"
R"==(case 4: a4[offset + i] = 0; break; )==""\n"
R"==(case 2: a2[offset + i] = 0; break; )==""\n"
R"==(case 1: a1[offset + i] = 0; break; )==""\n"
R"==(} )==""\n"
R"==(offset += step_nelems; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// sized_ref_zero_pad: calls typed_ref_zero_pad with type_size as a literal
// (4, 2 or 1), presumably so the inlined helper's type switch folds at compile
// time. Any other type_size writes nothing. The buffer is passed as-is because
// the callee takes __global void *; no element-type cast is needed.
R"==(static inline void sized_ref_zero_pad(__global void *a, ulong type_size, )==""\n"
R"==(ulong step_nelems, ulong nelems_block, ulong step_block, ulong nsteps, )==""\n"
R"==(ulong step_size, zero_pad_mask_t step_bitmask, ulong mode) { )==""\n"
R"==(switch (type_size) { )==""\n"
R"==(case 4: )==""\n"
R"==(typed_ref_zero_pad(a, 4, step_nelems, )==""\n"
R"==(nelems_block, step_block, nsteps, step_size, step_bitmask, )==""\n"
R"==(mode); )==""\n"
R"==(break; )==""\n"
R"==(case 2: )==""\n"
R"==(typed_ref_zero_pad(a, 2, step_nelems, )==""\n"
R"==(nelems_block, step_block, nsteps, step_size, step_bitmask, )==""\n"
R"==(mode); )==""\n"
R"==(break; )==""\n"
R"==(case 1: )==""\n"
R"==(typed_ref_zero_pad(a, 1, step_nelems, )==""\n"
R"==(nelems_block, step_block, nsteps, step_size, step_bitmask, )==""\n"
R"==(mode); )==""\n"
R"==(break; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// ref_zero_pad: generic entry kernel.
//   - step_nelems of 16, 32 or 64: passes the step size as a literal and uses
//     DEFAULT_NELEMS_BLOCK as the element stride, ignoring the nelems_block
//     argument.
//   - Any other step_nelems: forwards the runtime values unchanged.
// NOTE(review): assumes the host launches DEFAULT_NELEMS_BLOCK work-items in
// dim 0 for those step sizes - confirm against the dispatch code.
R"==(__kernel void ref_zero_pad(__global void *a, ulong type_size, ulong step_nelems, )==""\n"
R"==(ulong nelems_block, ulong step_block, ulong nsteps, ulong step_size, )==""\n"
R"==(zero_pad_mask_t step_bitmask, ulong mode) { )==""\n"
R"==(switch (step_nelems) { )==""\n"
R"==(case 16: )==""\n"
R"==(sized_ref_zero_pad(a, type_size, 16, DEFAULT_NELEMS_BLOCK, )==""\n"
R"==(step_block, nsteps, step_size, step_bitmask, mode); )==""\n"
R"==(break; )==""\n"
R"==(case 32: )==""\n"
R"==(sized_ref_zero_pad(a, type_size, 32, DEFAULT_NELEMS_BLOCK, )==""\n"
R"==(step_block, nsteps, step_size, step_bitmask, mode); )==""\n"
R"==(break; )==""\n"
R"==(case 64: )==""\n"
R"==(sized_ref_zero_pad(a, type_size, 64, DEFAULT_NELEMS_BLOCK, )==""\n"
R"==(step_block, nsteps, step_size, step_bitmask, mode); )==""\n"
R"==(break; )==""\n"
R"==(default: )==""\n"
R"==(sized_ref_zero_pad(a, type_size, step_nelems, nelems_block, )==""\n"
R"==(step_block, nsteps, step_size, step_bitmask, mode); )==""\n"
R"==(break; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// ref_zero_pad_subg_16: SIMD16 kernel. Offsets are in bytes: p is a
// __global void * advanced by multiples of type_size.
// Start address is base_offset plus:
//   - (get_global_id(0) / 16) * b_block_size
//   - get_global_id(1) * b_block_offset
//   - per-dimension strides for the indices unpacked from get_global_id(2)
//     (d3 fastest, then d2, d1, d0).
// It then issues b_multiplier sub-group block writes of zeros, advancing
// 16 * type_size bytes after each one.
R"==(__attribute__((intel_reqd_sub_group_size(16))) __kernel void )==""\n"
R"==(ref_zero_pad_subg_16(__global void *a, const uint type_size, )==""\n"
R"==(const ulong base_offset, const ulong b_block_size, )==""\n"
R"==(const ulong b_block_offset, const ulong d0_stride, )==""\n"
R"==(const ulong d1_stride, const ulong d2_stride, const ulong d3_stride, )==""\n"
R"==(const unsigned d0_size, const unsigned d1_size, const unsigned d2_size, )==""\n"
R"==(const unsigned d3_size, const uint b_multiplier) { )==""\n"
R"==(const unsigned a_block_id = get_global_id(0) / 16; )==""\n"
R"==(const unsigned b_block_id = get_global_id(1); )==""\n"
R"==(unsigned mixed_dims = get_global_id(2); )==""\n"
R"==(const unsigned d3_dim = mixed_dims % d3_size; )==""\n"
R"==(mixed_dims /= d3_size; )==""\n"
R"==(const unsigned d2_dim = mixed_dims % d2_size; )==""\n"
R"==(mixed_dims /= d2_size; )==""\n"
R"==(const unsigned d1_dim = mixed_dims % d1_size; )==""\n"
R"==(const unsigned d0_dim = mixed_dims / d1_size; )==""\n"
R"==(__global void *p = a + base_offset; )==""\n"
R"==(p += a_block_id * b_block_size; )==""\n"
R"==(p += b_block_id * b_block_offset; )==""\n"
R"==(p += d0_dim * d0_stride; )==""\n"
R"==(p += d1_dim * d1_stride; )==""\n"
R"==(p += d2_dim * d2_stride; )==""\n"
R"==(p += d3_dim * d3_stride; )==""\n"
R"==(const unsigned stride = 16 * type_size; )==""\n"
R"==(for (unsigned b_midx = 0; b_midx < b_multiplier; ++b_midx) { )==""\n"
R"==(switch (type_size) { )==""\n"
R"==(case 4: intel_sub_group_block_write((__global uint *)p, 0); break; )==""\n"
R"==(case 2: )==""\n"
R"==(intel_sub_group_block_write_us((__global ushort *)p, 0); )==""\n"
R"==(break; )==""\n"
R"==(case 1: )==""\n"
R"==(intel_sub_group_block_write_uc((__global uchar *)p, 0); )==""\n"
R"==(break; )==""\n"
R"==(} )==""\n"
R"==(p += stride; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// ref_zero_pad_subg_16_mask_and_clear_dt_1b: SIMD16 kernel for 1-byte data.
// From base a + get_global_id(0) * 8 bytes it:
//   - block-reads 4 rows of 16 bytes spaced 32 bytes apart (rows 0, 2, 4, 6 of
//     an 8 x 16-byte tile);
//   - keeps those rows only in lanes whose sub-group local id is below `mask`
//     and zeroes them in the other lanes;
//   - zeroes rows 1, 3, 5 and 7;
//   - writes all 8 rows back with a single uc8 block write.
// vload8 replaces a (uchar8 *) cast of the uchar array: OpenCL C leaves
// unaligned vector loads undefined except through vloadn, and a uchar array
// is not guaranteed 8-byte alignment.
R"==(__attribute__((intel_reqd_sub_group_size(16))) __kernel void )==""\n"
R"==(ref_zero_pad_subg_16_mask_and_clear_dt_1b(__global void *a, const uint mask) { )==""\n"
R"==(const uint block_size = 8; )==""\n"
R"==(const uint data_stride = 32; )==""\n"
R"==(const uint simd = 16; )==""\n"
R"==(const ulong offset = get_global_id(0) * block_size; )==""\n"
R"==(const unsigned subg_local_id = get_sub_group_local_id(); )==""\n"
R"==(__global void *p = a + offset; )==""\n"
R"==(const uint mask_val = mask > subg_local_id ? 1 : 0; )==""\n"
R"==(uchar val_c[block_size]; )==""\n"
R"==(for (unsigned idx = 0; idx < block_size / 2; ++idx) { )==""\n"
R"==(val_c[idx * 2] = intel_sub_group_block_read_uc( )==""\n"
R"==((__global uchar *)(p + data_stride * idx)); )==""\n"
R"==(} )==""\n"
R"==(for (unsigned idx = 1; idx < block_size; idx += 2) { )==""\n"
R"==(val_c[idx] = 0; )==""\n"
R"==(} )==""\n"
R"==(for (unsigned idx = 0; idx < block_size; idx += 2) { )==""\n"
R"==(val_c[idx] *= (uchar)mask_val; )==""\n"
R"==(} )==""\n"
R"==(for (unsigned idx = 0; idx < block_size; idx += 8) { )==""\n"
R"==(intel_sub_group_block_write_uc8( )==""\n"
R"==((__global uchar *)(p + simd * idx), vload8(0, &val_c[idx])); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==()==";
}
}
}
}