/*******************************************************************************
 * Copyright 2020-2022 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

#include "gpu/ocl/ref_zero_pad.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/ocl/ocl_memory_storage.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

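// Generic zero-padding path: for each dimension whose logical size is smaller
// than its padded size, build a mask of the padded element offsets within one
// inner block and launch the generic zero-pad kernel to clear them.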
status_t ref_zero_pad_t::execute_ref(const exec_ctx_t &ctx) const {
    compute::kernel_arg_list_t arg_list;

    const memory_t *memory = ctx.input(DNNL_ARG_SRC);
    memory_storage_t *mem_storage = memory->memory_storage();
    memory_desc_wrapper mdw(memory->md());

    const int ndims = mdw.ndims();
    const auto &dims = mdw.dims();
    const auto &pdims = mdw.padded_dims();
    const blocking_desc_t blocking_desc = mdw.blocking_desc();
    const ptrdiff_t nelems = (ptrdiff_t)mdw.nelems(true);
    const compute::compute_engine_t *engine
            = utils::downcast<compute::compute_engine_t *>(
                    ctx.stream()->engine());
    const compute::device_info_t *device = engine->device_info();
    const unsigned int hw_threads = device->hw_threads();

    // Set up the initial parameters used in the OpenCL kernel computation.
    dims_t blk_size;
    for (int i = 0; i < ndims; i++) {
        blk_size[i] = 1;
    }

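    // step_nelems is the number of elements in one combined inner block;
    // blk_size[d] accumulates the total inner-block size along dimension d.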
    cl_ulong step_nelems = 1;
    for (int i = 0; i < blocking_desc.inner_nblks; i++) {
        step_nelems *= blocking_desc.inner_blks[i];
        blk_size[blocking_desc.inner_idxs[i]] *= blocking_desc.inner_blks[i];
    }

    // This constant needs to be the same as DEFAULT_NELEMS_BLOCK in
    // ref_zero_pad.cl.
    const int default_nelems_block = 8;

    // This divisibility condition cannot be changed without some modifications
    // to the use of DEFAULT_NELEMS_BLOCK in ref_zero_pad.cl.
    size_t nelems_block = 1;
    while (nelems_block < default_nelems_block
            && step_nelems % (nelems_block * 2) == 0)
        nelems_block *= 2;

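    // Kernel arguments 0-3 are common to every dispatch below; arguments 4-8
    // are set per padded dimension inside the loop.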
    arg_list.set(0, *mem_storage);
    arg_list.set(1, mdw.data_type_size());
    arg_list.set(2, step_nelems);
    arg_list.set(3, nelems_block);

    for (int i = 0; i < ndims; i++) {
        if (dims[i] == pdims[i]) continue;
        cl_ulong stride = 1;
        cl_ulong step_count = 1;

        step_count = blocking_desc.strides[i] / step_nelems;
        stride = blocking_desc.strides[i] * (pdims[i] / blk_size[i]);
        size_t npsteps = (nelems / stride) * step_count;

        // Balance the work-unit size against the available parallelism.
        cl_ulong step_block = 1;
        if (!engine->is_xe_hp() && !engine->is_xe_hpg()) {
            while (step_nelems / nelems_block * step_block < 4 * 1024
                    && step_count % (step_block * 2) == 0
                    && npsteps / step_block > 2 * hw_threads) {
                step_block *= 2;
            }
        }
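        // Within the tail block along dimension i, positions >= tail_start
        // lie in the padded region and must be zeroed.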
        dim_t tail_start = dims[i] % blk_size[i];
        dims_t pos;
        for (int j = 0; j < ndims; j++) {
            pos[j] = 0;
        }

        zero_pad_mask_t bit_mask;
        zero_pad_mask_t lookup_mask;
        for (unsigned int j = 0; j < ZERO_PAD_MASK_SIZE; j++)
            bit_mask.mask[j] = 0;

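        // Walk every position of one inner block and record which offsets are
        // padding: always in the bit mask, and additionally in a compact
        // lookup table of offsets while it still fits.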
        bool is_done = false;
        bool use_lookup_mask = true;
        size_t mask_count = 0;
        while (!is_done) {
            size_t idx = mdw.off_v(pos, true);
            bool is_valid = pos[i] >= tail_start;
            size_t mask_idx = idx / ZERO_PAD_MASK_DT_BITS;
            size_t mask_bit = idx % ZERO_PAD_MASK_DT_BITS;
            bit_mask.mask[mask_idx] |= (is_valid ? (1 << mask_bit) : 0);
            if (is_valid && use_lookup_mask) {
                if (mask_count < ZERO_PAD_MASK_SIZE
                        && idx <= std::numeric_limits<
                                   ZERO_PAD_MASK_DATA_TYPE>::max()) {
                    lookup_mask.mask[mask_count] = (ZERO_PAD_MASK_DATA_TYPE)idx;
                    mask_count++;
                } else {
                    use_lookup_mask = false;
                }
            }

            // Increment the position within the block.
            is_done = true;
            for (int j = 0; j < ndims; j++) {
                if (blk_size[j] - 1 == pos[j]) continue;
                is_done = false;
                pos[j] = pos[j] + 1;
                for (int k = j - 1; k >= 0; k--)
                    pos[k] = 0;
                break;
            }
        }

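        // Prefer the lookup-table mode when every padded offset fits into the
        // table; otherwise fall back to the bit-mask mode.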
        size_t mode = ZERO_PAD_BIT_MODE;
        size_t gws0 = nelems_block;
        zero_pad_mask_t *mask_in = &bit_mask;
        if (use_lookup_mask) {
            mode = ZERO_PAD_LOOKUP_MODE;
            gws0 = mask_count;
            mask_in = &lookup_mask;
        }

        arg_list.set(4, step_block);
        arg_list.set(5, step_count);
        arg_list.set(6, stride);
        arg_list.set(7, *mask_in);
        arg_list.set(8, mode);

        const size_t gws[3]
                = {gws0, step_count / step_block, npsteps / step_count};
        const compute::nd_range_t nd_range = compute::nd_range_t(3, gws);
        status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
        if (status != status::success) return status;
    }
    return status::success;
}

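// Zero-pads the tails of the two innermost blocked dimensions using the
// subgroup-16 kernel. Expects a format with exactly two levels of blocking
// whose innermost block size is a multiple of 16 (see the dispatch logic in
// execute()).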
status_t ref_zero_pad_t::execute_subg_16(const exec_ctx_t &ctx,
        const memory_desc_wrapper &mdw,
        const blocking_desc_t &blocking_desc) const {

    const memory_t *memory = ctx.input(DNNL_ARG_SRC);
    const memory_storage_t *mem_storage = memory->memory_storage();

    const int ndims = mdw.ndims();
    const auto &dims = mdw.dims();
    const auto &pdims = mdw.padded_dims();
    const auto mem_total_size = mdw.size();

    const auto most_inner_nblk = blocking_desc.inner_nblks - 1;

    const unsigned mem_dt_size = static_cast<unsigned>(mdw.data_type_size());

    const cl_ulong most_inner_block_size
            = mem_dt_size * blocking_desc.inner_blks[most_inner_nblk];

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, *mem_storage);
    arg_list.set(1, mem_dt_size);
    arg_list.set(3, most_inner_block_size);

    int arg_idx = 0;
    size_t gws2 = 1;
    const size_t lws[3] = {16, 1, 1};

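    // Arguments 5..8 and 9..12 carry the strides and sizes of the dimensions
    // that are not part of the two innermost blocks; unused slots get neutral
    // values, and gws2 accumulates the product of the used sizes.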
    for (int j = 0; j < MAX_NDIMS; ++j) {
        if (j != blocking_desc.inner_idxs[most_inner_nblk]
                && j != blocking_desc.inner_idxs[most_inner_nblk - 1]) {
            assert(arg_idx < 4);
            if (j < ndims) {
                arg_list.set(5 + arg_idx,
                        mem_dt_size * (cl_ulong)blocking_desc.strides[j]);
                arg_list.set(9 + arg_idx++, (unsigned)dims[j]);
                gws2 *= dims[j];
            } else {
                arg_list.set(5 + arg_idx, cl_ulong(0));
                arg_list.set(9 + arg_idx++, unsigned(1));
            }
        }
    }

    status_t status;
    dims_t coordinates;

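    // First pass: zero the padded tail of the innermost blocked dimension, if
    // it has one. A second launch covers the partial second-innermost block,
    // when present.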
    if (pdims[blocking_desc.inner_idxs[most_inner_nblk]]
            != dims[blocking_desc.inner_idxs[most_inner_nblk]]) {
        for (int j = 0; j < ndims; ++j) {
            coordinates[j] = 0;
        }
        coordinates[blocking_desc.inner_idxs[most_inner_nblk]]
                = dims[blocking_desc.inner_idxs[most_inner_nblk]];
        const cl_ulong most_inner_block_base_offset
                = mem_dt_size * mdw.off_v(coordinates, true);

        const cl_ulong s2most_inner_block_stride = mem_dt_size
                * blocking_desc.strides[blocking_desc.inner_idxs[most_inner_nblk
                        - 1]];
        const unsigned most_inner_block_write_multiplier
                = (pdims[blocking_desc.inner_idxs[most_inner_nblk]]
                          - dims[blocking_desc.inner_idxs[most_inner_nblk]])
                / 16;

        arg_list.set(2, most_inner_block_base_offset);
        arg_list.set(4, s2most_inner_block_stride);
        arg_list.set(13, most_inner_block_write_multiplier);

        const size_t gws0 = 16
                * nstl::min<dnnl_dim_t>(
                        dims[blocking_desc.inner_idxs[most_inner_nblk - 1]],
                        blocking_desc.inner_blks[most_inner_nblk - 1]);
        const size_t gws1 = nstl::max<dnnl_dim_t>(
                dims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                        / blocking_desc.inner_blks[most_inner_nblk - 1],
                1);
        const size_t gws[3] = {gws0, gws1, gws2};
        const compute::nd_range_t zp_nd_range
                = compute::nd_range_t(3, gws, lws);

        status = parallel_for(ctx, zp_nd_range, kernel_subg16_, arg_list);
        CHECK(status);

        if (dims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                        != pdims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                && s2most_inner_block_stride != mem_total_size) {
            const cl_ulong base_offset_b2 = most_inner_block_base_offset
                    + s2most_inner_block_stride * gws1;
            arg_list.set(2, base_offset_b2);

            const size_t gws_10 = 16
                    * (dims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                            % blocking_desc.inner_blks[most_inner_nblk - 1]);
            const size_t gws_1[3] = {gws_10, 1, gws2};
            const compute::nd_range_t zp_nd_range1
                    = compute::nd_range_t(3, gws_1, lws);
            status = parallel_for(ctx, zp_nd_range1, kernel_subg16_, arg_list);
            CHECK(status);
        }
    }

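    // Second pass: zero the padded tail of the second-innermost blocked
    // dimension across whole innermost blocks.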
    for (int j = 0; j < ndims; ++j) {
        coordinates[j] = 0;
    }
    coordinates[blocking_desc.inner_idxs[most_inner_nblk - 1]]
            = dims[blocking_desc.inner_idxs[most_inner_nblk - 1]];
    const cl_ulong s2most_inner_block_base_offset
            = mem_dt_size * mdw.off_v(coordinates, true);

    const cl_ulong most_inner_block_offset = mem_dt_size
            * blocking_desc.strides[blocking_desc.inner_idxs[most_inner_nblk]];

    const unsigned most_inner_block_write_multiplier = nstl::max<dnnl_dim_t>(
            blocking_desc.inner_blks[most_inner_nblk] / 16, 1);

    arg_list.set(2, s2most_inner_block_base_offset);
    arg_list.set(4, most_inner_block_offset);
    arg_list.set(13, most_inner_block_write_multiplier);

    const size_t gws0
            = ((pdims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                       - dims[blocking_desc.inner_idxs[most_inner_nblk - 1]])
                      * blocking_desc.inner_blks[most_inner_nblk])
            / most_inner_block_write_multiplier;
    const size_t gws1 = nstl::max<dnnl_dim_t>(
            pdims[blocking_desc.inner_idxs[most_inner_nblk]]
                    / blocking_desc.inner_blks[most_inner_nblk],
            1);
    const size_t gws[3] = {gws0, gws1, gws2};

    const compute::nd_range_t zp_nd_range = compute::nd_range_t(3, gws, lws);
    status = parallel_for(ctx, zp_nd_range, kernel_subg16_, arg_list);

    return status;
}

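// Specialized path for 1-byte data types with a single 32-element inner block
// that is mostly padding (fewer than 16 valid elements); see the dispatch
// conditions in execute().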
status_t ref_zero_pad_t::execute_subg_16_mask_and_clear_dt_1B(
        const exec_ctx_t &ctx, const memory_desc_wrapper &mdw,
        const blocking_desc_t &blocking_desc) const {

    const memory_t *memory = ctx.input(DNNL_ARG_SRC);
    const memory_storage_t *mem_storage = memory->memory_storage();

    const compute::compute_engine_t *engine
            = utils::downcast<compute::compute_engine_t *>(
                    ctx.stream()->engine());
    const compute::device_info_t *device = engine->device_info();

    const size_t max_local_ws = device->max_wg_size();

    const auto &dims = mdw.dims();
    const auto nelems = mdw.nelems(true);

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, *mem_storage);

    const unsigned mask
            = dims[blocking_desc.inner_idxs[0]] % blocking_desc.inner_blks[0];
    arg_list.set(1, mask);

    const unsigned block_size = 16 * 8; // SIMD * block size
    const size_t gws[3] = {static_cast<size_t>(16 * nelems / block_size), 1, 1};
    const size_t lws[3] = {max_local_ws, 1, 1};

    const compute::nd_range_t zp_nd_range = compute::nd_range_t(3, gws, lws);

    return parallel_for(
            ctx, zp_nd_range, kernel_subg16_mask_and_clear_dt_1b_, arg_list);
}

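// Dispatch: pick a specialized subgroup-16 kernel when the memory layout
// matches its constraints; otherwise fall back to the generic reference path.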
status_t ref_zero_pad_t::execute(const exec_ctx_t &ctx) const {
    const memory_t *memory = ctx.input(DNNL_ARG_SRC);
    const memory_desc_wrapper mdw(memory->md());
    const blocking_desc_t &blocking_desc = mdw.blocking_desc();

    using namespace format_tag;
    if (blocking_desc.inner_nblks == 2
            && mdw.dims()[blocking_desc.inner_idxs[1]] % 16 == 0
            && blocking_desc.inner_blks[1] % 16 == 0) {
        return execute_subg_16(ctx, mdw, blocking_desc);
    } else if (blocking_desc.inner_nblks == 1
            && blocking_desc.inner_blks[0] == 32
            && mdw.dims()[blocking_desc.inner_idxs[0]] < 16
            && (mdw.nelems(true) % 4096) == 0 && mdw.data_type_size() == 1) {
        return execute_subg_16_mask_and_clear_dt_1B(ctx, mdw, blocking_desc);
    } else {
        return execute_ref(ctx);
    }
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl