/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <math.h>

#include "common/primitive_exec_types.hpp"

#include "gpu/ocl/gen9_reduction.hpp"
#include "gpu/ocl/ocl_utils.hpp"

#include "common/c_types_map.hpp"
#include "common/dnnl_traits.hpp"
#include "common/math_utils.hpp"
#include "common/scratchpad.hpp"
#include "common/type_helpers.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// Extract N and C block sizes from blk.inner_blks
std::pair<int, int> get_n_c_block_sizes(const memory_desc_wrapper &mdw) {
    int n_block_size = 1;
    int c_block_size = 1;
    const blocking_desc_t &blk = mdw.blocking_desc();
    if (blk.inner_nblks > 0) {
        // C must be the last blocked dimension
        assert(blk.inner_idxs[blk.inner_nblks - 1] == 1);
        c_block_size = blk.inner_blks[blk.inner_nblks - 1];
        // If there is NC blocking (N is the blocked dimension directly before
        // C), use the N block size as well
        if (blk.inner_nblks > 1 && blk.inner_idxs[blk.inner_nblks - 2] == 0) {
            n_block_size = blk.inner_blks[blk.inner_nblks - 2];
        }
    }
    return std::make_pair(n_block_size, c_block_size);
}

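// Heuristically split the (possibly reduced) N dimension for the initial
// reduction phase. Returns {chunk size, number of chunks}.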
std::pair<int, int> get_initial_n_split(const int n, const bool is_n_reduced) {
    int initial_n_chunk_size;
    int initial_n_chunks_num;
    if (is_n_reduced) {
        // Start with a constant and adjust it with the heuristics below
        initial_n_chunk_size = 64;
        while (initial_n_chunk_size > n) {
            initial_n_chunk_size /= 2;
        }
        initial_n_chunks_num
                = ceil(static_cast<float>(n) / initial_n_chunk_size);
        // Avoid creating too many chunks, as that would leave a lot of work
        // for the final reduction. The desired values were picked experimentally.
        int desired_n_chunks = 16;
        constexpr int min_chunk_size = 4;
        if (n / min_chunk_size < desired_n_chunks && n / min_chunk_size >= 1) {
            desired_n_chunks = n / min_chunk_size;
        }
        int desired_chunk_size = 32;
        if (n / desired_n_chunks < desired_chunk_size) {
            desired_chunk_size = n / desired_n_chunks;
        }
        while (initial_n_chunk_size < desired_chunk_size
                && initial_n_chunks_num > desired_n_chunks
                && initial_n_chunk_size * 2 < n) {
            initial_n_chunk_size *= 2;
            initial_n_chunks_num
                    = ceil(static_cast<float>(n) / initial_n_chunk_size);
        }
    } else {
        initial_n_chunks_num = n;
        initial_n_chunk_size = 1;
    }
    return std::make_pair(initial_n_chunk_size, initial_n_chunks_num);
}

status_t gen9_reduction_t::pd_t::init_conf(engine_t *engine) {
    const reduction_pd_t *pd = this;

    const memory_desc_wrapper src_mdw(pd->src_md());
    const memory_desc_wrapper dst_mdw(pd->dst_md());

    const int ndims = src_mdw.ndims();
    const dnnl_dim_t *src_dims = src_mdw.md_->dims;
    const dnnl_dim_t *dst_dims = dst_mdw.md_->dims;
    const compute::compute_engine_t *compute_engine
            = utils::downcast<compute::compute_engine_t *>(engine);
    const int num_threads = compute_engine->device_info()->hw_threads();

    conf.alg = pd->desc()->alg_kind;
    conf.src_md_info = memory_desc_info_t::create(src_mdw);
    conf.dst_md_info = memory_desc_info_t::create(dst_mdw);
    conf.dst_type = dst_mdw.data_type();
    conf.src_type = src_mdw.data_type();
    conf.ndims = ndims;
    conf.power = pd->desc()->p;
    conf.eps = pd->desc()->eps;
    conf.dispatch = compute_engine->create_dispatch(src_mdw.md_);

    // Check that the last blocked dim is C and that its block size equals
    // blockSize
    auto is_c_blocked_by
            = [](const memory_desc_wrapper &mdw, const int blockSize) {
                  const blocking_desc_t &blk = mdw.blocking_desc();
                  if (blk.inner_nblks == 0) return false;
                  return (blk.inner_idxs[blk.inner_nblks - 1] == 1)
                          && (blk.inner_blks[blk.inner_nblks - 1] == blockSize);
              };

    using namespace dnnl::impl::format_tag;
    const bool is_nhwc = (src_mdw.matches_one_of_tag(nwc, nhwc, ndhwc)
            != format_tag::undef);

    // Plain layouts (nwc/nhwc/ndhwc): C must be divisible by 16 and neither N
    // nor C may be reduced
    if (is_nhwc) {
        int c = src_dims[1];
        if (c % 16 != 0) { return status::unimplemented; }
        if (src_dims[0] != dst_dims[0]) { return status::unimplemented; }
        if (src_dims[1] != dst_dims[1]) { return status::unimplemented; }
    } else {
        // blocked layouts: src C must have blocks of 16 or 32
        if (!(is_c_blocked_by(src_mdw, 16) || is_c_blocked_by(src_mdw, 32)))
            return status::unimplemented;
    }

    int src_n_block_size, src_c_block_size;
    int dst_n_block_size, dst_c_block_size;
    std::tie(src_n_block_size, src_c_block_size) = get_n_c_block_sizes(src_mdw);
    std::tie(dst_n_block_size, dst_c_block_size) = get_n_c_block_sizes(dst_mdw);

    // src/dst blocking must match
    if (src_n_block_size != dst_n_block_size
            || src_c_block_size != dst_c_block_size
            || src_mdw.blocking_desc().inner_nblks
                    != dst_mdw.blocking_desc().inner_nblks)
        return status::unimplemented;

    conf.n_block_size = src_n_block_size;
    conf.c_block_size = src_c_block_size;
    if (is_nhwc) { conf.c_block_size = src_dims[1]; }

    // Allow at most two blocked dims: either C alone, or N and C together
    if ((conf.n_block_size == 1 && src_mdw.blocking_desc().inner_nblks > 1)
            || src_mdw.blocking_desc().inner_nblks > 2) {
        return status::unimplemented;
    }

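    // Classify each dim as reduced or kept; accumulate the total reduction
    // size (div) and the spatial (HWD) sizes used by the heuristics below.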
    conf.div = 1;
    int hwd_size = 1;
    int hwd_reduction_size = 1;
    for (int d = 0; d < ndims; d++) {
        conf.src_dims[d] = src_dims[d];
        conf.reduce_dims[d] = conf.dst_dims[d] = dim_t {1};
        conf.is_reduction_dim[d] = conf.src_dims[d] != dst_dims[d];

        if (conf.is_reduction_dim[d]) {
            conf.reduce_dims[d] = conf.src_dims[d];
            conf.div *= conf.reduce_dims[d];
        } else {
            conf.dst_dims[d] = conf.src_dims[d];
        }
        if (d >= 2) {
            hwd_size *= conf.src_dims[d];
            hwd_reduction_size *= conf.reduce_dims[d];
        }
    }

    // If any spatial dims are reduced, they all have to be.
    if (hwd_size != hwd_reduction_size && hwd_reduction_size > 1) {
        return status::unimplemented;
    }

    // The fully padded C dimension must be a multiple of the sub-group size
    // (16); this may already be guaranteed by the blocking checks above.
    conf.sub_group_size = 16;
    const auto &src_padded_dims = src_mdw.padded_dims();
    if (src_padded_dims[1] % conf.sub_group_size != 0) {
        return status::unimplemented;
    }

    // Number of sub-group-wide C chunks needed to cover one C block (at most 8)
    conf.initial_c_chunks
            = std::min(conf.c_block_size / conf.sub_group_size, 8);

    // Split N into a chunk size and a number of chunks according to the
    // heuristic above
    std::tie(conf.initial_n_chunk_size, conf.initial_n_chunks)
            = get_initial_n_split(conf.src_dims[0], conf.is_reduction_dim[0]);

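    // Number of input elements reduced by a single work-item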
    const auto get_reduction_elems_per_wi = [this]() {
        return conf.initial_n_chunk_size * conf.initial_c_chunks
                * conf.initial_hwd_chunk_size;
    };

    // Number of work-items (chunks) needed to cover the HWD dimension
    const auto get_wi_per_hwd = [this]() {
        return ceil(static_cast<float>(conf.initial_hwd_dim)
                / conf.initial_hwd_chunk_size);
    };

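    // Approximate number of sub-groups (hardware threads) used by the initial
    // phase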
    const auto get_used_threads_num = [this, get_wi_per_hwd]() {
        return conf.initial_n_chunks * conf.src_dims[1]
                / (conf.sub_group_size * conf.initial_c_chunks)
                * get_wi_per_hwd();
    };

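    // Pick the spatial (HWD) chunking: if HWD is not reduced, each work-item
    // handles a single spatial point; otherwise tune the chunk size so that
    // enough hardware threads stay busy while each work-item still reduces a
    // reasonable number of elements.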
    if (hwd_reduction_size == 1) {
        conf.initial_hwd_chunk_size = 1;
        // If there is no HWD reduction, use vectors only to read the whole C
        // block
        conf.vector_size = conf.initial_c_chunks;
        conf.initial_hwd_dim = hwd_size;
        conf.final_hwd_dim = hwd_size;
        conf.final_hwd_chunk_size = 1;
    } else {
        // Start with a constant and adjust it with the heuristics below
        conf.initial_hwd_chunk_size = 64;
        if (conf.n_block_size > 1 || conf.src_dims[1] < conf.c_block_size) {
            conf.vector_size = conf.initial_c_chunks;
        } else {
            conf.vector_size = 8;
        }
        conf.initial_hwd_dim = hwd_reduction_size;

        // Experimentally selected values
        constexpr int min_elems_per_wi = 64;
        constexpr int max_wi_per_hwd = 512;
        const int min_threads = num_threads;

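        // Shrink HWD chunks to spawn more work-items while the device is
        // underutilized; then grow them back if work-items end up with too
        // little work or there are too many partial results along HWD.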
        while (get_used_threads_num() < min_threads
                && get_reduction_elems_per_wi() > min_elems_per_wi
                && get_wi_per_hwd() < max_wi_per_hwd) {
            conf.initial_hwd_chunk_size /= 2;
        }

        while ((get_used_threads_num() > min_threads
                       && get_reduction_elems_per_wi() < min_elems_per_wi)
                || get_wi_per_hwd() > max_wi_per_hwd) {
            conf.initial_hwd_chunk_size *= 2;
        }

        while (conf.vector_size > conf.initial_hwd_chunk_size) {
            conf.vector_size /= 2;
        }
        conf.final_hwd_dim = get_wi_per_hwd();
        conf.final_hwd_chunk_size = conf.final_hwd_dim;
    }

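    // Shape of the intermediate tensor consumed by the final phase: reduced
    // dims collapse to the number of partial results produced per dim by the
    // initial phase.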
    conf.final_c_dim = conf.is_reduction_dim[1]
            ? src_padded_dims[1] / (conf.sub_group_size * conf.initial_c_chunks)
            : conf.src_dims[1];
    conf.final_c_chunk_size = conf.is_reduction_dim[1]
            ? src_padded_dims[1] / (conf.sub_group_size * conf.initial_c_chunks)
            : 1;

    conf.final_n_dim = conf.is_reduction_dim[0] ? conf.initial_n_chunks
                                                : conf.src_dims[0];
    conf.final_n_chunk_size
            = conf.is_reduction_dim[0] ? conf.initial_n_chunks : 1;

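    // The final phase can be skipped when every initial work-item already
    // produces a fully reduced output element.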
    int initial_n_chunks_padded, initial_c_padded;

    if (conf.final_c_chunk_size == 1 && conf.final_n_chunk_size == 1
            && conf.final_hwd_chunk_size == 1) {
        conf.skip_final_phase = true;
        // zero pad N and C in initial phase only when there is no final phase
        const int n_padded = utils::rnd_up(conf.src_dims[0], conf.n_block_size);
        initial_n_chunks_padded = ceil(
                static_cast<float>(n_padded) / conf.initial_n_chunk_size);
        initial_c_padded = utils::rnd_up(conf.src_dims[1], conf.c_block_size);
    } else {
        conf.skip_final_phase = false;
        initial_n_chunks_padded = conf.initial_n_chunks;
        initial_c_padded = utils::rnd_up(conf.src_dims[1], conf.c_block_size);
    }

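    // Initial-phase ND-range: N chunks x C (vectorized over the sub-group)
    // x HWD chunks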
    conf.dispatch.define_dim("INITIAL_N", 0, initial_n_chunks_padded, 1);
    conf.dispatch.define_dim("INITIAL_C", std::min(ndims - 1, 1),
            initial_c_padded, conf.initial_c_chunks);
    conf.dispatch.define_dim("INITIAL_HWD_CHUNK_ID", std::min(ndims - 1, 2),
            conf.final_hwd_dim, 1);

    // Each sub-group of the initial kernel handles 16 C channels; this
    // requires INITIAL_C (initial_c_padded) to be a multiple of sub_group_size
    CHECK(conf.dispatch.vectorize_dim("INITIAL_C", conf.sub_group_size));
    conf.dispatch.set_kernel_attr_suffix("INITIAL");
    conf.dispatch.generate();
    conf.attr_info = attr_info_t::create(pd->attr());

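    // The final phase reduces the intermediate tensor produced by the initial
    // phase down to the destination shape.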
    if (!conf.skip_final_phase) {
        conf.finalize_dispatch = compute_engine->create_dispatch();
        const int final_n_padded
                = utils::rnd_up(conf.final_n_dim, conf.n_block_size);
        const int final_n_chunks_padded
                = utils::div_up(final_n_padded, conf.final_n_chunk_size);
        conf.finalize_dispatch.define_dim("FINAL_N", 0, final_n_chunks_padded);
        const int final_c_padded
                = utils::rnd_up(conf.final_c_dim, conf.c_block_size);
        const int final_c_chunks_padded
                = utils::div_up(final_c_padded, conf.final_c_chunk_size);
        conf.finalize_dispatch.define_dim(
                "FINAL_C", std::min(ndims - 1, 1), final_c_chunks_padded);
        conf.finalize_dispatch.define_dim("FINAL_HWD", std::min(ndims - 1, 2),
                conf.final_hwd_dim / conf.final_hwd_chunk_size);
        conf.finalize_dispatch.set_kernel_attr_suffix("FINAL");
        conf.finalize_dispatch.generate();
    }

    return status::success;
}

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const reduction_conf_t &conf, const post_ops_t &post_ops) {
    using namespace alg_kind;

    kernel_ctx.set_data_type(conf.src_type);

    // N shape descriptors
    kernel_ctx.define_int("IS_N_REDUCED", conf.is_reduction_dim[0]);
    kernel_ctx.define_int("INITIAL_N", conf.src_dims[0]);
    kernel_ctx.define_int("INITIAL_N_CHUNKS", conf.initial_n_chunks);
    kernel_ctx.define_int("INITIAL_N_CHUNK_SIZE", conf.initial_n_chunk_size);
    kernel_ctx.define_int("N_BLOCK_SIZE", conf.n_block_size);

    // C shape descriptors
    kernel_ctx.define_int("IS_C_REDUCED", conf.is_reduction_dim[1]);
    kernel_ctx.define_int("INITIAL_C", conf.src_dims[1]);
    kernel_ctx.define_int("INITIAL_C_CHUNKS", conf.initial_c_chunks);
    // There is no INITIAL_C_CHUNK_SIZE define; the chunk size equals
    // SUB_GROUP_SIZE
    kernel_ctx.define_int("C_BLOCK_SIZE", conf.c_block_size);

    // Spatial shape descriptors
    kernel_ctx.define_int("INITIAL_HWD_DIM", conf.initial_hwd_dim);
    kernel_ctx.define_int(
            "INITIAL_HWD_CHUNK_SIZE", conf.initial_hwd_chunk_size);
    kernel_ctx.define_int(
            "IS_HWD_REDUCED", conf.final_hwd_dim < conf.initial_hwd_dim);

    // DST shape descriptors
    kernel_ctx.define_int("DST_N", conf.dst_dims[0]);
    kernel_ctx.define_int("DST_C", conf.dst_dims[1]);
    kernel_ctx.define_int(
            "DST_N_PADDED", utils::rnd_up(conf.dst_dims[0], conf.n_block_size));
    kernel_ctx.define_int(
            "DST_C_PADDED", utils::rnd_up(conf.dst_dims[1], conf.c_block_size));

    // General problem descriptors
    kernel_ctx.define_int("SUB_GROUP_SIZE", conf.sub_group_size);
    kernel_ctx.define_int("VECT_DT_N", conf.vector_size);
    kernel_ctx.define_int("REDUCTION_SIZE", conf.div);
    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("POWER", conf.power);
    kernel_ctx.define_float("EPS", conf.eps);

    kernel_ctx.define_int("SKIP_FINAL_PHASE", conf.skip_final_phase);

    // Final kernel variables
    kernel_ctx.define_int("FINAL_N_DIM", conf.final_n_dim);
    kernel_ctx.define_int("FINAL_N_CHUNK_SIZE", conf.final_n_chunk_size);
    kernel_ctx.define_int("FINAL_C_DIM", conf.final_c_dim);
    kernel_ctx.define_int("FINAL_C_CHUNK_SIZE", conf.final_c_chunk_size);
    kernel_ctx.define_int("FINAL_HWD_DIM", conf.final_hwd_dim);
    kernel_ctx.define_int("FINAL_HWD_CHUNK_SIZE", conf.final_hwd_chunk_size);

    // Define the D/H/W destination dimensions for use in binary post-ops
    std::string dim_names[3] = {"D", "H", "W"};
    for (int i = 2; i < 5; i++) {
        dim_t dim = (i < conf.ndims) ? conf.dst_dims[i] : 1;
        kernel_ctx.define_int("DST_" + dim_names[i - 2] + "_DIM", dim);
    }

    switch (conf.alg) {
        case reduction_max: kernel_ctx.define_int("IS_MAX", 1); break;
        case reduction_min: kernel_ctx.define_int("IS_MIN", 1); break;
        case reduction_mean: kernel_ctx.define_int("IS_MEAN", 1); break;
        case reduction_sum: kernel_ctx.define_int("IS_SUM", 1); break;
        case reduction_mul: kernel_ctx.define_int("IS_MUL", 1); break;
        case reduction_norm_lp_max:
            kernel_ctx.define_int("IS_LP_MAX", 1);
            break;
        case reduction_norm_lp_sum:
            kernel_ctx.define_int("IS_LP_SUM", 1);
            break;
        case reduction_norm_lp_power_p_max:
            kernel_ctx.define_int("IS_P_MAX", 1);
            break;
        case reduction_norm_lp_power_p_sum:
            kernel_ctx.define_int("IS_P_SUM", 1);
            break;
        default: return status::invalid_arguments;
    }

    def_memory_desc_info(kernel_ctx, conf.src_md_info, "SRC");
    def_memory_desc_info(kernel_ctx, conf.dst_md_info, "DST");

    def_attr_info(kernel_ctx, conf.attr_info, post_ops);

    def_dispatch(kernel_ctx, conf.dispatch);
    if (!conf.skip_final_phase)
        def_dispatch(kernel_ctx, conf.finalize_dispatch);

    return status::success;
}

status_t gen9_reduction_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    return init_kernel_ctx_common(kernel_ctx, conf, attr()->post_ops_);
}

void gen9_reduction_t::pd_t::init_scratchpad() {
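    // The scratchpad holds the f32 intermediate tensor produced by the initial
    // phase, padded up to full N and C blocks.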
    const size_t size = utils::rnd_up(conf.final_n_dim, conf.n_block_size)
            * utils::rnd_up(conf.final_c_dim, conf.c_block_size)
            * conf.final_hwd_dim;

    auto scratchpad = scratchpad_registry().registrar();
    scratchpad.book(memory_tracking::names::key_reduction, size,
            types::data_type_size(data_type::f32), OCL_BUFFER_ALIGNMENT);
}

status_t gen9_reduction_t::execute_gen9(const exec_ctx_t &ctx) const {
    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

    std::unique_ptr<memory_storage_t> temp_reduce
            = ctx.get_scratchpad_grantor().get_memory_storage(
                    memory_tracking::names::key_reduction);
    const auto &conf = pd()->conf;

    // Kick off the initial reduction phase
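    // When the final phase is skipped, the initial kernel writes directly to
    // dst and applies post-ops itself; otherwise it writes partial results to
    // the scratchpad buffer.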
    compute::kernel_arg_list_t reduction_arg_list;
    reduction_arg_list.set(0, src);
    reduction_arg_list.set(1, conf.skip_final_phase ? dst : *temp_reduce);
    if (conf.skip_final_phase) {
        append_post_ops_to_arg_list(
                ctx, reduction_arg_list, 2, pd()->attr()->post_ops_);
    }
    auto initial_nd_range = conf.dispatch.nd_range();
    status_t status = parallel_for(
            ctx, initial_nd_range, initial_kernel, reduction_arg_list);

    if (conf.skip_final_phase || status != status::success) return status;

    // Continue with final reduction phase
    compute::kernel_arg_list_t final_reduction_arg_list;
    final_reduction_arg_list.set(0, *temp_reduce);
    final_reduction_arg_list.set(1, dst);
    append_post_ops_to_arg_list(
            ctx, final_reduction_arg_list, 2, pd()->attr()->post_ops_);
    auto final_nd_range = conf.finalize_dispatch.nd_range();
    return parallel_for(
            ctx, final_nd_range, final_kernel, final_reduction_arg_list);
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl