1 | namespace dnnl { |
2 | namespace impl { |
3 | namespace gpu { |
4 | namespace ocl { |
// Embedded OpenCL source for the reference zero-padding kernels.
// This is a GENERATED file: each line of the original .cl source is wrapped
// in a raw string literal fragment and concatenated at compile time; the
// resulting string is handed to the OpenCL runtime compiler at run time.
// NOTE(review): the generator appears to strip '//'-style comments, which
// truncated the Apache license URL below to just "http:" — harmless to the
// device compiler, but confirm the stripping is intentional in the generator.
const char *ref_zero_pad_kernel = R"==(/******************************************************************************* )==" "\n"
R"==(* Copyright 2020-2021 Intel Corporation )==" "\n"
R"==(* )==" "\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n"
R"==(* you may not use this file except in compliance with the License. )==" "\n"
R"==(* You may obtain a copy of the License at )==" "\n"
R"==(* )==" "\n"
R"==(* http: )==" "\n"
R"==(* )==" "\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n"
R"==(* See the License for the specific language governing permissions and )==" "\n"
R"==(* limitations under the License. )==" "\n"
R"==(*******************************************************************************/ )==" "\n"
R"==(#define IS_OCL_KERNEL )==" "\n"
R"==(#include "gpu/zero_pad_struct.h" )==" "\n"
R"==(#define DEFAULT_NELEMS_BLOCK 8 )==" "\n"
// typed_ref_zero_pad: zeroes selected elements of buffer `a`. Two modes:
// - ZERO_PAD_BIT_MODE: step_bitmask.mask is a per-element bitmask; every set
//   bit within a step marks an element to be zeroed.
// - otherwise (lookup mode): step_bitmask.mask[i0] holds the element offset
//   for this work-item directly — presumably one masked element per lane;
//   TODO(review) confirm against the host-side dispatch logic.
// type_size selects which of the int/short/char aliases of `a` is stored to.
R"==(static inline void typed_ref_zero_pad(__global void *a, ulong type_size, )==" "\n"
R"==(ulong step_nelems, ulong nelems_block, ulong step_block, ulong nsteps, )==" "\n"
R"==(ulong step_size, zero_pad_mask_t step_bitmask, ulong mode) { )==" "\n"
R"==(const int i0 = get_global_id(0); )==" "\n"
R"==(const int istep = get_global_id(1) * step_block; )==" "\n"
R"==(const int iblock = get_global_id(2); )==" "\n"
// Offset of the first padded element handled by this work-item: skip to the
// padding tail of the block, then to this work-item's step range.
R"==(int offset = iblock * step_size + (step_size - nsteps * step_nelems) )==" "\n"
R"==(+ istep * step_nelems; )==" "\n"
R"==(const int step = ZERO_PAD_MASK_DT_BITS; )==" "\n"
// Three typed views of the same buffer; exactly one is used per type_size.
R"==(__global int *a4 = (__global int *)a; )==" "\n"
R"==(__global short *a2 = (__global short *)a; )==" "\n"
R"==(__global char *a1 = (__global char *)a; )==" "\n"
R"==(if (mode == ZERO_PAD_BIT_MODE) { )==" "\n"
R"==(for (int k = 0; k < step_block; k++) { )==" "\n"
R"==(__attribute__((opencl_unroll_hint)) )==" "\n"
R"==(for (int i = i0; i < step_nelems; i += nelems_block) { )==" "\n"
R"==(if (step_bitmask.mask[i / step] & (1 << (i % step))) { )==" "\n"
R"==(switch (type_size) { )==" "\n"
R"==(case 4: a4[offset + i] = 0; break; )==" "\n"
R"==(case 2: a2[offset + i] = 0; break; )==" "\n"
R"==(case 1: a1[offset + i] = 0; break; )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
R"==(offset += step_nelems; )==" "\n"
R"==(} )==" "\n"
// Lookup mode: mask[] stores element indices rather than bits.
R"==(} else { )==" "\n"
R"==(int i = step_bitmask.mask[i0]; )==" "\n"
R"==(for (int k = 0; k < step_block; k++) { )==" "\n"
R"==(switch (type_size) { )==" "\n"
R"==(case 4: a4[offset + i] = 0; break; )==" "\n"
R"==(case 2: a2[offset + i] = 0; break; )==" "\n"
R"==(case 1: a1[offset + i] = 0; break; )==" "\n"
R"==(} )==" "\n"
R"==(offset += step_nelems; )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
// sized_ref_zero_pad: dispatches on type_size so the callee sees a literal
// constant, letting the device compiler fold the inner switch.
// NOTE(review): the (__global float *) casts are immaterial — the parameter
// is __global void *, and typed_ref_zero_pad re-casts internally.
R"==(static inline void sized_ref_zero_pad(__global void *a, ulong type_size, )==" "\n"
R"==(ulong step_nelems, ulong nelems_block, ulong step_block, ulong nsteps, )==" "\n"
R"==(ulong step_size, zero_pad_mask_t step_bitmask, ulong mode) { )==" "\n"
R"==(switch (type_size) { )==" "\n"
R"==(case 4: )==" "\n"
R"==(typed_ref_zero_pad((__global float *)a, 4, step_nelems, )==" "\n"
R"==(nelems_block, step_block, nsteps, step_size, step_bitmask, )==" "\n"
R"==(mode); )==" "\n"
R"==(break; )==" "\n"
R"==(case 2: )==" "\n"
R"==(typed_ref_zero_pad((__global float *)a, 2, step_nelems, )==" "\n"
R"==(nelems_block, step_block, nsteps, step_size, step_bitmask, )==" "\n"
R"==(mode); )==" "\n"
R"==(break; )==" "\n"
R"==(case 1: )==" "\n"
R"==(typed_ref_zero_pad((__global float *)a, 1, step_nelems, )==" "\n"
R"==(nelems_block, step_block, nsteps, step_size, step_bitmask, )==" "\n"
R"==(mode); )==" "\n"
R"==(break; )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
// ref_zero_pad: kernel entry point. Specializes step_nelems for the common
// sizes (16/32/64) so the inner loops get compile-time trip counts; any
// other size takes the generic path with the runtime value.
R"==(__kernel void ref_zero_pad(__global void *a, ulong type_size, ulong step_nelems, )==" "\n"
R"==(ulong nelems_block, ulong step_block, ulong nsteps, ulong step_size, )==" "\n"
R"==(zero_pad_mask_t step_bitmask, ulong mode) { )==" "\n"
R"==(switch (step_nelems) { )==" "\n"
R"==(case 16: )==" "\n"
R"==(sized_ref_zero_pad(a, type_size, 16, DEFAULT_NELEMS_BLOCK, )==" "\n"
R"==(step_block, nsteps, step_size, step_bitmask, mode); )==" "\n"
R"==(break; )==" "\n"
R"==(case 32: )==" "\n"
R"==(sized_ref_zero_pad(a, type_size, 32, DEFAULT_NELEMS_BLOCK, )==" "\n"
R"==(step_block, nsteps, step_size, step_bitmask, mode); )==" "\n"
R"==(break; )==" "\n"
R"==(case 64: )==" "\n"
R"==(sized_ref_zero_pad(a, type_size, 64, DEFAULT_NELEMS_BLOCK, )==" "\n"
R"==(step_block, nsteps, step_size, step_bitmask, mode); )==" "\n"
R"==(break; )==" "\n"
R"==(default: )==" "\n"
R"==(sized_ref_zero_pad(a, type_size, step_nelems, nelems_block, )==" "\n"
R"==(step_block, nsteps, step_size, step_bitmask, mode); )==" "\n"
R"==(break; )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
// ref_zero_pad_subg_16: SIMD-16 subgroup variant. Each subgroup computes a
// destination address from the (d0..d3) coordinates decoded out of
// get_global_id(2), then writes zeros via subgroup block writes, advancing
// 16*type_size bytes per iteration of the b_multiplier loop.
R"==(__attribute__((intel_reqd_sub_group_size(16))) __kernel void )==" "\n"
R"==(ref_zero_pad_subg_16(__global void *a, const uint type_size, )==" "\n"
R"==(const ulong base_offset, const ulong b_block_size, )==" "\n"
R"==(const ulong b_block_offset, const ulong d0_stride, )==" "\n"
R"==(const ulong d1_stride, const ulong d2_stride, const ulong d3_stride, )==" "\n"
R"==(const unsigned d0_size, const unsigned d1_size, const unsigned d2_size, )==" "\n"
R"==(const unsigned d3_size, const uint b_multiplier) { )==" "\n"
R"==(const unsigned a_block_id = get_global_id(0) / 16; )==" "\n"
R"==(const unsigned b_block_id = get_global_id(1); )==" "\n"
// Decode mixed_dims into (d0, d1, d2, d3) coordinates, innermost first.
R"==(unsigned mixed_dims = get_global_id(2); )==" "\n"
R"==(const unsigned d3_dim = mixed_dims % d3_size; )==" "\n"
R"==(mixed_dims /= d3_size; )==" "\n"
R"==(const unsigned d2_dim = mixed_dims % d2_size; )==" "\n"
R"==(mixed_dims /= d2_size; )==" "\n"
R"==(const unsigned d1_dim = mixed_dims % d1_size; )==" "\n"
R"==(const unsigned d0_dim = mixed_dims / d1_size; )==" "\n"
R"==(__global void *p = a + base_offset; )==" "\n"
R"==(p += a_block_id * b_block_size; )==" "\n"
R"==(p += b_block_id * b_block_offset; )==" "\n"
R"==(p += d0_dim * d0_stride; )==" "\n"
R"==(p += d1_dim * d1_stride; )==" "\n"
R"==(p += d2_dim * d2_stride; )==" "\n"
R"==(p += d3_dim * d3_stride; )==" "\n"
R"==(const unsigned stride = 16 * type_size; )==" "\n"
R"==(for (unsigned b_midx = 0; b_midx < b_multiplier; ++b_midx) { )==" "\n"
// Block-write a zero per lane at the element width selected by type_size.
R"==(switch (type_size) { )==" "\n"
R"==(case 4: intel_sub_group_block_write((__global uint *)p, 0); break; )==" "\n"
R"==(case 2: )==" "\n"
R"==(intel_sub_group_block_write_us((__global ushort *)p, 0); )==" "\n"
R"==(break; )==" "\n"
R"==(case 1: )==" "\n"
R"==(intel_sub_group_block_write_uc((__global uchar *)p, 0); )==" "\n"
R"==(break; )==" "\n"
R"==(} )==" "\n"
R"==(p += stride; )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
// ref_zero_pad_subg_16_mask_and_clear_dt_1b: SIMD-16 kernel for 1-byte data
// that interleaves real values with zeros: even indices of val_c are read
// from global memory and kept only for lanes with subgroup-local id < mask
// (others are zeroed); odd indices are unconditionally zero. The uchar8
// block write then stores the interleaved result.
R"==(__attribute__((intel_reqd_sub_group_size(16))) __kernel void )==" "\n"
R"==(ref_zero_pad_subg_16_mask_and_clear_dt_1b(__global void *a, const uint mask) { )==" "\n"
R"==(const uint block_size = 8; )==" "\n"
R"==(const uint data_stride = 32; )==" "\n"
R"==(const uint simd = 16; )==" "\n"
R"==(const ulong offset = get_global_id(0) * block_size; )==" "\n"
R"==(const unsigned subg_local_id = get_sub_group_local_id(); )==" "\n"
R"==(__global void *p = a + offset; )==" "\n"
// mask_val: 1 if this lane holds valid (non-padding) data, else 0.
R"==(const uint mask_val = mask > subg_local_id ? 1 : 0; )==" "\n"
R"==(uchar val_c[block_size]; )==" "\n"
R"==(for (unsigned idx = 0; idx < block_size / 2; ++idx) { )==" "\n"
R"==(val_c[idx * 2] = intel_sub_group_block_read_uc( )==" "\n"
R"==((__global uchar *)(p + data_stride * idx)); )==" "\n"
R"==(} )==" "\n"
R"==(for (unsigned idx = 1; idx < block_size; idx += 2) { )==" "\n"
R"==(val_c[idx] = 0; )==" "\n"
R"==(} )==" "\n"
R"==(for (unsigned idx = 0; idx < block_size; idx += 2) { )==" "\n"
R"==(val_c[idx] *= (uchar)mask_val; )==" "\n"
R"==(} )==" "\n"
R"==(for (unsigned idx = 0; idx < block_size; idx += 8) { )==" "\n"
R"==(intel_sub_group_block_write_uc8( )==" "\n"
R"==((__global uchar *)(p + simd * idx), *((uchar8 *)(&val_c[idx]))); )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
R"==()==" ;
167 | } |
168 | } |
169 | } |
170 | } |