/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cassert>
#include <limits>

#include "gpu/ocl/ref_zero_pad.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/ocl/ocl_memory_storage.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

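// Generic reference path: launches one kernel invocation per dimension that
// carries padding. Each invocation is driven either by a bit mask over one
// blocking unit or by a lookup table of padded offsets (prepared below).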
status_t ref_zero_pad_t::execute_ref(const exec_ctx_t &ctx) const {
    compute::kernel_arg_list_t arg_list;

    const memory_t *memory = ctx.input(DNNL_ARG_SRC);
    const memory_storage_t *mem_storage = memory->memory_storage();
    memory_desc_wrapper mdw(memory->md());

    const int ndims = mdw.ndims();
    const auto &dims = mdw.dims();
    const auto &pdims = mdw.padded_dims();
    const blocking_desc_t blocking_desc = mdw.blocking_desc();
    const ptrdiff_t nelems = (ptrdiff_t)mdw.nelems(true);
    const compute::compute_engine_t *engine
            = utils::downcast<compute::compute_engine_t *>(
                    ctx.stream()->engine());
    const compute::device_info_t *device = engine->device_info();
    const unsigned int hw_threads = device->hw_threads();

    // Set up the initial parameters used in the OpenCL kernel computation.
    dims_t blk_size;
    for (int i = 0; i < ndims; i++) {
        blk_size[i] = 1;
    }

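    // step_nelems is the number of elements in one full blocking unit, and
    // blk_size[d] accumulates the combined inner block size along dimension d
    // (e.g. for a 16b-blocked layout, step_nelems = 16 and blk_size of the
    // blocked dimension is 16).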
    cl_ulong step_nelems = 1;
    for (int i = 0; i < blocking_desc.inner_nblks; i++) {
        step_nelems *= blocking_desc.inner_blks[i];
        blk_size[blocking_desc.inner_idxs[i]] *= blocking_desc.inner_blks[i];
    }

    // This constant needs to be the same as DEFAULT_NELEMS_BLOCK in
    // ref_zero_pad.cl.
    const int default_nelems_block = 8;

    // This divisibility condition cannot be changed without some
    // modifications to the use of DEFAULT_NELEMS_BLOCK in ref_zero_pad.cl.
    size_t nelems_block = 1;
    while (nelems_block < default_nelems_block
            && step_nelems % (nelems_block * 2) == 0)
        nelems_block *= 2;
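    // nelems_block is now the largest power of two, capped at
    // default_nelems_block, that evenly divides step_nelems.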

    arg_list.set(0, *mem_storage);
    arg_list.set(1, mdw.data_type_size());
    arg_list.set(2, step_nelems);
    arg_list.set(3, nelems_block);

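    // Launch one kernel invocation for every dimension that requires zero
    // padding.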
    for (int i = 0; i < ndims; i++) {
        if (dims[i] == pdims[i]) continue;
        cl_ulong stride = 1;
        cl_ulong step_count = 1;

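        // One outer step along dim i spans strides[i] elements, i.e.
        // step_count blocking units; stride covers the whole padded extent of
        // dim i, so nelems / stride counts the outer slices that each contain
        // a padded tail along dim i.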
        step_count = blocking_desc.strides[i] / step_nelems;
        stride = blocking_desc.strides[i] * (pdims[i] / blk_size[i]);
        size_t npsteps = (nelems / stride) * step_count;

        // Balance work unit size with parallelism.
        cl_ulong step_block = 1;
        if (!engine->is_xe_hp() && !engine->is_xe_hpg()) {
            while (step_nelems / nelems_block * step_block < 4 * 1024
                    && step_count % (step_block * 2) == 0
                    && npsteps / step_block > 2 * hw_threads) {
                step_block *= 2;
            }
        }
        dim_t tail_start = dims[i] % blk_size[i];
        dims_t pos;
        for (int j = 0; j < ndims; j++) {
            pos[j] = 0;
        }

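        // Two masking modes are prepared: a bit mask flagging every padded
        // element within one blocking unit, and a compact lookup table of the
        // padded offsets, used when all of them fit.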
        zero_pad_mask_t bit_mask;
        zero_pad_mask_t lookup_mask;
        for (unsigned int j = 0; j < ZERO_PAD_MASK_SIZE; j++)
            bit_mask.mask[j] = 0;

        bool is_done = false;
        bool use_lookup_mask = true;
        size_t mask_count = 0;
        while (!is_done) {
            size_t idx = mdw.off_v(pos, true);
            bool is_valid = pos[i] >= tail_start;
            size_t mask_idx = idx / ZERO_PAD_MASK_DT_BITS;
            size_t mask_bit = idx % ZERO_PAD_MASK_DT_BITS;
            bit_mask.mask[mask_idx] |= (is_valid ? (1u << mask_bit) : 0u);
            if (is_valid && use_lookup_mask) {
                if (mask_count < ZERO_PAD_MASK_SIZE
                        && idx <= std::numeric_limits<
                                ZERO_PAD_MASK_DATA_TYPE>::max()) {
                    lookup_mask.mask[mask_count] = (ZERO_PAD_MASK_DATA_TYPE)idx;
                    mask_count++;
                } else {
                    use_lookup_mask = false;
                }
            }

            // Increment the position within the blocking unit.
            is_done = true;
            for (int j = 0; j < ndims; j++) {
                if (blk_size[j] - 1 == pos[j]) continue;
                is_done = false;
                pos[j] = pos[j] + 1;
                for (int k = j - 1; k >= 0; k--)
                    pos[k] = 0;
                break;
            }
        }

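        // Prefer the lookup table when every padded offset was recorded; gws0
        // then becomes the number of padded elements rather than the
        // block-scan width.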
        size_t mode = ZERO_PAD_BIT_MODE;
        size_t gws0 = nelems_block;
        zero_pad_mask_t *mask_in = &bit_mask;
        if (use_lookup_mask) {
            mode = ZERO_PAD_LOOKUP_MODE;
            gws0 = mask_count;
            mask_in = &lookup_mask;
        }

        arg_list.set(4, step_block);
        arg_list.set(5, step_count);
        arg_list.set(6, stride);
        arg_list.set(7, *mask_in);
        arg_list.set(8, mode);

        const size_t gws[3]
                = {gws0, step_count / step_block, npsteps / step_count};
        const compute::nd_range_t nd_range = compute::nd_range_t(3, gws);
        status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
        if (status != status::success) return status;
    }
    return status::success;
}

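// Specialized path for layouts with two levels of blocking where the
// innermost block is a multiple of 16: padding is written with SIMD16
// subgroup stores, one region per blocked-dimension tail.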
status_t ref_zero_pad_t::execute_subg_16(const exec_ctx_t &ctx,
        const memory_desc_wrapper &mdw,
        const blocking_desc_t &blocking_desc) const {

    const memory_t *memory = ctx.input(DNNL_ARG_SRC);
    const memory_storage_t *mem_storage = memory->memory_storage();

    const int ndims = mdw.ndims();
    const auto &dims = mdw.dims();
    const auto &pdims = mdw.padded_dims();
    const auto mem_total_size = mdw.size();

    const auto most_inner_nblk = blocking_desc.inner_nblks - 1;

    const unsigned mem_dt_size = static_cast<unsigned>(mdw.data_type_size());

    const cl_ulong most_inner_block_size
            = mem_dt_size * blocking_desc.inner_blks[most_inner_nblk];

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, *mem_storage);
    arg_list.set(1, mem_dt_size);
    arg_list.set(3, most_inner_block_size);

    int arg_idx = 0;
    size_t gws2 = 1;
    const size_t lws[3] = {16, 1, 1};

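    // Pass strides and sizes of the (up to four) dimensions that do not take
    // part in the two innermost blocks; unused argument slots get stride 0
    // and size 1.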
    for (int j = 0; j < MAX_NDIMS; ++j) {
        if (j != blocking_desc.inner_idxs[most_inner_nblk]
                && j != blocking_desc.inner_idxs[most_inner_nblk - 1]) {
            assert(arg_idx < 4);
            if (j < ndims) {
                arg_list.set(5 + arg_idx,
                        mem_dt_size * (cl_ulong)blocking_desc.strides[j]);
                arg_list.set(9 + arg_idx++, (unsigned)dims[j]);
                gws2 *= dims[j];
            } else {
                arg_list.set(5 + arg_idx, cl_ulong(0));
                arg_list.set(9 + arg_idx++, unsigned(1));
            }
        }
    }

    status_t status;
    dims_t coordinates;

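    // First region: zero-pad the tail of the innermost blocked dimension,
    // when it has one.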
    if (pdims[blocking_desc.inner_idxs[most_inner_nblk]]
            != dims[blocking_desc.inner_idxs[most_inner_nblk]]) {
        for (int j = 0; j < ndims; ++j) {
            coordinates[j] = 0;
        }
        coordinates[blocking_desc.inner_idxs[most_inner_nblk]]
                = dims[blocking_desc.inner_idxs[most_inner_nblk]];
        const cl_ulong most_inner_block_base_offset
                = mem_dt_size * mdw.off_v(coordinates, true);

        const cl_ulong s2most_inner_block_stride = mem_dt_size
                * blocking_desc.strides[blocking_desc.inner_idxs[
                        most_inner_nblk - 1]];
        const unsigned most_inner_block_write_multiplier
                = (pdims[blocking_desc.inner_idxs[most_inner_nblk]]
                          - dims[blocking_desc.inner_idxs[most_inner_nblk]])
                / 16;

        arg_list.set(2, most_inner_block_base_offset);
        arg_list.set(4, s2most_inner_block_stride);
        arg_list.set(13, most_inner_block_write_multiplier);

        const size_t gws0 = 16
                * nstl::min<dnnl_dim_t>(
                        dims[blocking_desc.inner_idxs[most_inner_nblk - 1]],
                        blocking_desc.inner_blks[most_inner_nblk - 1]);
        const size_t gws1 = nstl::max<dnnl_dim_t>(
                dims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                        / blocking_desc.inner_blks[most_inner_nblk - 1],
                1);
        const size_t gws[3] = {gws0, gws1, gws2};
        const compute::nd_range_t zp_nd_range
                = compute::nd_range_t(3, gws, lws);

        status = parallel_for(ctx, zp_nd_range, kernel_subg16_, arg_list);
        CHECK(status);

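        // Handle the remainder blocks when the second innermost dimension is
        // not a multiple of its inner block size.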
        if (dims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                        != pdims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                && s2most_inner_block_stride != mem_total_size) {
            const cl_ulong base_offset_b2 = most_inner_block_base_offset
                    + s2most_inner_block_stride * gws1;
            arg_list.set(2, base_offset_b2);

            const size_t gws_10 = 16
                    * (dims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                            % blocking_desc.inner_blks[most_inner_nblk - 1]);
            const size_t gws_1[3] = {gws_10, 1, gws2};
            const compute::nd_range_t zp_nd_range1
                    = compute::nd_range_t(3, gws_1, lws);
            status = parallel_for(ctx, zp_nd_range1, kernel_subg16_, arg_list);
            CHECK(status);
        }
    }

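    // Second region: zero-pad the tail of the second innermost blocked
    // dimension across all of its (padded) innermost blocks.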
    for (int j = 0; j < ndims; ++j) {
        coordinates[j] = 0;
    }
    coordinates[blocking_desc.inner_idxs[most_inner_nblk - 1]]
            = dims[blocking_desc.inner_idxs[most_inner_nblk - 1]];
    const cl_ulong s2most_inner_block_base_offset
            = mem_dt_size * mdw.off_v(coordinates, true);

    const cl_ulong most_inner_block_offset = mem_dt_size
            * blocking_desc.strides[blocking_desc.inner_idxs[most_inner_nblk]];

    const unsigned most_inner_block_write_multiplier = nstl::max<dnnl_dim_t>(
            blocking_desc.inner_blks[most_inner_nblk] / 16, 1);

    arg_list.set(2, s2most_inner_block_base_offset);
    arg_list.set(4, most_inner_block_offset);
    arg_list.set(13, most_inner_block_write_multiplier);

    const size_t gws0
            = ((pdims[blocking_desc.inner_idxs[most_inner_nblk - 1]]
                       - dims[blocking_desc.inner_idxs[most_inner_nblk - 1]])
                      * blocking_desc.inner_blks[most_inner_nblk])
            / most_inner_block_write_multiplier;
    const size_t gws1 = nstl::max<dnnl_dim_t>(
            pdims[blocking_desc.inner_idxs[most_inner_nblk]]
                    / blocking_desc.inner_blks[most_inner_nblk],
            1);
    const size_t gws[3] = {gws0, gws1, gws2};

    const compute::nd_range_t zp_nd_range = compute::nd_range_t(3, gws, lws);
    status = parallel_for(ctx, zp_nd_range, kernel_subg16_, arg_list);

    return status;
}

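// Specialized path for layouts with a single 32-wide innermost block over a
// 1-byte data type and a short tail (fewer than 16 valid elements): the
// kernel clears padded elements based on the tail mask passed below.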
status_t ref_zero_pad_t::execute_subg_16_mask_and_clear_dt_1B(
        const exec_ctx_t &ctx, const memory_desc_wrapper &mdw,
        const blocking_desc_t &blocking_desc) const {

    const memory_t *memory = ctx.input(DNNL_ARG_SRC);
    const memory_storage_t *mem_storage = memory->memory_storage();

    const compute::compute_engine_t *engine
            = utils::downcast<compute::compute_engine_t *>(
                    ctx.stream()->engine());
    const compute::device_info_t *device = engine->device_info();

    const size_t max_local_ws = device->max_wg_size();

    const auto &dims = mdw.dims();
    const auto nelems = mdw.nelems(true);

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, *mem_storage);

    const unsigned mask
            = dims[blocking_desc.inner_idxs[0]] % blocking_desc.inner_blks[0];
    arg_list.set(1, mask);

    const unsigned block_size = 16 * 8; // SIMD width * elements per work-item
    const size_t gws[3] = {static_cast<size_t>(16 * nelems / block_size), 1, 1};
    const size_t lws[3] = {max_local_ws, 1, 1};

    const compute::nd_range_t zp_nd_range = compute::nd_range_t(3, gws, lws);

    return parallel_for(
            ctx, zp_nd_range, kernel_subg16_mask_and_clear_dt_1b_, arg_list);
}

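// Dispatcher: picks one of the two specialized subgroup kernels when the
// layout matches their constraints, and falls back to the generic reference
// kernel otherwise.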
status_t ref_zero_pad_t::execute(const exec_ctx_t &ctx) const {
    const memory_t *memory = ctx.input(DNNL_ARG_SRC);
    const memory_desc_wrapper mdw(memory->md());
    const blocking_desc_t &blocking_desc = mdw.blocking_desc();

    using namespace format_tag;
    if (blocking_desc.inner_nblks == 2
            && mdw.dims()[blocking_desc.inner_idxs[1]] % 16 == 0
            && blocking_desc.inner_blks[1] % 16 == 0) {
        return execute_subg_16(ctx, mdw, blocking_desc);
    } else if (blocking_desc.inner_nblks == 1
            && blocking_desc.inner_blks[0] == 32
            && mdw.dims()[blocking_desc.inner_idxs[0]] < 16
            && (mdw.nelems(true) % 4096) == 0 && mdw.data_type_size() == 1) {
        return execute_subg_16_mask_and_clear_dt_1B(ctx, mdw, blocking_desc);
    } else {
        return execute_ref(ctx);
    }
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl