/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/combined_reduction.hpp"
#include "common/scratchpad.hpp"
#include "gpu/ocl/ocl_utils.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

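// A reduction_phase_t describes a single kernel launch of the combined
// reduction: it starts from initial_size elements along the composite reduced
// dimension, reduces them reduction_size at a time, and leaves
// final_size = div_up(initial_size, reduction_size) partial results for the
// next phase (or the final answer when is_final is set).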
static reduction_phase_t init_phase(int start, int end, int num_reductions,
        data_type_t src_type, data_type_t dst_type,
        compute::nd_range_t nd_range, bool is_final, bool is_first) {
    reduction_phase_t phase;
    phase.initial_size = start;
    phase.reduction_size = num_reductions;
    phase.final_size = end;
    phase.num_reduction_chunks
            = utils::div_up(phase.initial_size, phase.reduction_size);
    phase.src_type = src_type;
    phase.dst_type = dst_type;
    phase.nd_range = nd_range;
    phase.is_final = is_final;
    phase.is_first = is_first;
    return phase;
}

void combined_reduction_t::pd_t::init_scratchpad() {
    // Only need scratchpads for the first 2 phases: the buffers can be reused
    // since memory requirements decrease monotonically with each phase, and
    // the final phase writes directly to dst.
    uint32_t keys[2] = {memory_tracking::names::key_reduction,
            memory_tracking::names::key_reduction_1};
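
    // For illustration, a hypothetical 4-phase split would use the buffers as
    // follows: phase 0 writes key_reduction, phase 1 writes key_reduction_1,
    // phase 2 reuses key_reduction, and the final phase writes directly to
    // dst, so two buffers are always enough.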

    for (int phase_num = 0;
            phase_num < std::min(2, (int)conf.phases.size() - 1); phase_num++) {
        const size_t sp_data_size
                = types::data_type_size(conf.phases[phase_num].dst_type);
        auto scratchpad = scratchpad_registry().registrar();
        scratchpad.book(keys[phase_num], conf.sp_size[phase_num], sp_data_size,
                OCL_BUFFER_ALIGNMENT);
    }
}

status_t combined_reduction_t::pd_t::init_conf(engine_t *engine) {
    // To start, check for compatibility
    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper dst_mdw(dst_md());
    const int ndims = src_mdw.ndims();

    const dnnl_dim_t *src_dims = src_mdw.dims();
    const blocking_desc_t &blk = src_mdw.blocking_desc();

    const dnnl_dim_t *dst_dims = dst_mdw.dims();
    const blocking_desc_t &dst_blk = dst_mdw.blocking_desc();

    const compute::compute_engine_t *compute_engine
            = utils::downcast<compute::compute_engine_t *>(engine);

    // Require same src/dst blocking
    if (blk.inner_nblks != dst_blk.inner_nblks) { // Same number of blocks
        return status::unimplemented;
    }

    for (int i = 0; i < blk.inner_nblks; i++) {
        if (blk.inner_idxs[i]
                != dst_blk.inner_idxs[i]) { // Same blocking permutation
            return status::unimplemented;
        }
        if (blk.inner_blks[i] != dst_blk.inner_blks[i]) { // Same blocking sizes
            return status::unimplemented;
        }
    }
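
    // For example, an aBcd16b src with an aBcd16b dst passes these checks,
    // while an aBcd16b src with a plain abcd dst (or an aBcd8b dst) is
    // rejected, since src and dst must share the same inner blocking.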

    // Zero padding is not implemented when the padded dim is reduced,
    // or when using an LP/P alg (which is not zero-preserving)
    using namespace alg_kind;
    for (int i = 0; i < blk.inner_nblks; i++) {
        // Needs zero padding
        if (dst_mdw.padded_dims()[blk.inner_idxs[i]]
                != dst_mdw.dims()[blk.inner_idxs[i]]) {
            // Non-zero-preserving alg
            switch (desc()->alg_kind) {
                case reduction_norm_lp_max:
                case reduction_norm_lp_sum:
                case reduction_norm_lp_power_p_max:
                case reduction_norm_lp_power_p_sum:
                    return status::unimplemented;
                default: break;
            }
            // Dim reduced
            if (dst_mdw.dims()[blk.inner_idxs[i]]
                    != src_mdw.dims()[blk.inner_idxs[i]]) {
                return status::unimplemented;
            }
        }
    }

    // Determine ordering of dims by stride (largest stride first)
    dims_t dim_perm = {0}, dst_perm = {0};
    for (int i = 0; i < ndims; i++) {
        // Src
        dim_t stride = blk.strides[i];
        int dim_idx = i;
        for (int j = 0; j < i; j++) {
            if (stride > blk.strides[dim_perm[j]]) {
                // Insert this stride/idx into dim_perm
                stride = blk.strides[dim_perm[j]];
                int tmp_perm = dim_perm[j];
                dim_perm[j] = dim_idx;
                dim_idx = tmp_perm;
            }
        }
        dim_perm[i] = dim_idx;

        // Same for dst
        stride = dst_blk.strides[i];
        dim_idx = i;
        for (int j = 0; j < i; j++) {
            if (stride > dst_blk.strides[dst_perm[j]]) {
                // Insert this stride/idx into dst_perm
                stride = dst_blk.strides[dst_perm[j]];
                int tmp_perm = dst_perm[j];
                dst_perm[j] = dim_idx;
                dim_idx = tmp_perm;
            }
        }
        dst_perm[i] = dim_idx;
    }
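
    // For example, a plain NHWC (acdb) tensor has strides decreasing in the
    // order N, H, W, C, so dim_perm becomes {0, 2, 3, 1}; for NCHW (abcd) it
    // is simply {0, 1, 2, 3}.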

    dims_t block_sizes, dst_blocks;
    src_mdw.compute_blocks(block_sizes);
    dst_mdw.compute_blocks(dst_blocks);
    // Determine extended (plain+blocked) dim structure
    dim_t extended_dim_order[2 * MAX_NDIMS], extended_dst_order[2 * MAX_NDIMS];
    dim_t extended_dim_size[2 * MAX_NDIMS];
    const int num_comp_dims = ndims + blk.inner_nblks;
    for (int i = 0; i < ndims; i++) { // plain
        extended_dim_order[i] = dim_perm[i];
        extended_dim_size[i]
                = src_mdw.padded_dims()[dim_perm[i]] / block_sizes[dim_perm[i]];
        extended_dst_order[i] = dst_perm[i];
    }
    for (int i = 0; i < blk.inner_nblks; i++) { // blocked
        extended_dim_order[i + ndims] = blk.inner_idxs[i];
        extended_dim_size[i + ndims] = blk.inner_blks[i];
        extended_dst_order[i + ndims] = dst_blk.inner_idxs[i];
    }
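
    // E.g., an aBcd16b (blocked NCHW) src expands to
    // extended_dim_order = {0, 1, 2, 3, 1} and
    // extended_dim_size = {N, padded_C / 16, H, W, 16}: the outer C blocks
    // appear as a plain dim and the 16-wide inner block is appended.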

    // Only allow same src/dst format tags and permutations
    // TODO: Relax src/dst format matching
    for (int i = 0; i < num_comp_dims; i++) {
        if (extended_dim_order[i] != extended_dst_order[i]) {
            return status::unimplemented;
        }
    }

    // Flag which composite dims are reduced (src size != dst size)
    dim_t extended_reduced_dims[2 * MAX_NDIMS];
    for (int i = 0; i < num_comp_dims; i++) {
        extended_reduced_dims[i] = (src_dims[extended_dim_order[i]]
                                           != dst_dims[extended_dim_order[i]])
                ? 1
                : 0;
    }

    // Finally, the check: make sure all reduced dims are sequential,
    // i.e. extended_reduced_dims contains no 1 0...0 1 pattern
    for (int i = 0; i < num_comp_dims - 2; i++) {
        if (extended_reduced_dims[i] == 0
                || extended_reduced_dims[i + 1] == 1) {
            continue;
        }
        // Found a 1 -> 0 transition: everything to the right must be 0
        for (int j = i + 1; j < num_comp_dims; j++) {
            if (extended_reduced_dims[j] == 1) { return status::unimplemented; }
        }
        break;
    }
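
    // For example, reducing H and W of an aBcd16b tensor gives
    // extended_reduced_dims = {0, 0, 1, 1, 0}, which is accepted, whereas
    // reducing only C gives {0, 1, 0, 0, 1} (the C blocks and the 16-wide
    // inner block are separated by H and W), which is rejected.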

    // Get information about composite outer/reduced/inner dimensions
    int num_outer_dims = 0, num_reduced_dims = 0, num_inner_dims = 0;
    bool left_side = true;
    for (int i = 0; i < num_comp_dims; i++) {
        if (extended_reduced_dims[i] == 1) {
            left_side = false;
            num_reduced_dims += 1;
            continue;
        }

        if (left_side) {
            num_outer_dims += 1;
        } else {
            num_inner_dims += 1;
        }
    }

    // Compute composite dim sizes
    int outer_dim_size = 1;
    for (int i = 0; i < num_outer_dims; i++) {
        outer_dim_size *= extended_dim_size[i];
    }
    int reduced_dim_size = 1;
    for (int i = 0; i < num_reduced_dims; i++) {
        reduced_dim_size *= extended_dim_size[i + num_outer_dims];
    }
    int inner_dim_size = 1;
    for (int i = 0; i < num_inner_dims; i++) {
        inner_dim_size
                *= extended_dim_size[i + num_outer_dims + num_reduced_dims];
    }
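
    // Continuing the aBcd16b example with H and W reduced: the outer dims are
    // {N, padded_C / 16}, the reduced dims are {H, W}, and the inner dim is
    // the 16-wide block, giving outer_dim_size = N * padded_C / 16,
    // reduced_dim_size = H * W, and inner_dim_size = 16.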

    // Set up conf variables that don't change between phases
    conf.ndims = ndims;
    conf.alg = desc()->alg_kind;
    conf.power = desc()->p;
    conf.eps = desc()->eps;

    conf.div = reduced_dim_size;
    conf.outer_dim_size = outer_dim_size;
    conf.inner_dim_size = inner_dim_size;

    conf.attr_info = attr_info_t::create(attr());

    // Heuristics based on testing on PVC
    conf.sub_group_size = compute_engine->device_info()->max_subgroup_size();

    const int target_reduction_size = 8;

    // Pad the inner dim to a multiple of subgroup size
    conf.inner_dim_per_sg = std::min(reduced_dim_size,
            std::max(1, conf.sub_group_size / conf.inner_dim_size));
    conf.gws_inner_dim_size = utils::rnd_up(
            conf.inner_dim_per_sg * conf.inner_dim_size, conf.sub_group_size);

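    // E.g., with a subgroup size of 16 and inner_dim_size = 4, each subgroup
    // covers inner_dim_per_sg = 4 slices of the reduced dim (assuming it has
    // at least 4 elements) and gws_inner_dim_size rounds up to 16.
    //
    // The loop below then splits the reduction into phases of roughly
    // target_reduction_size reductions each: e.g., reducing about 1000
    // elements with inner_dim_per_sg = 1 ends up as three phases of roughly
    // 10 reductions apiece (1000 -> 100 -> 10 -> 1), with only the final
    // phase writing to dst.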
    while (reduced_dim_size > 1) {
        data_type_t src_data_type;
        bool is_first;
        if (reduced_dim_size == conf.div) {
            src_data_type = src_mdw.data_type();
            is_first = true;
        } else {
            src_data_type = types::default_accum_data_type(
                    src_mdw.data_type(), data_type::undef);
            is_first = false;
        }

        // Compute the number of phases left
        const int horiz_reductions
                = utils::div_up(reduced_dim_size, conf.inner_dim_per_sg);
        const int num_remaining_phases = std::floor(
                std::log(horiz_reductions) / std::log(target_reduction_size));
        const int red_per_phase
                = std::pow(reduced_dim_size, 1.0f / num_remaining_phases);

        int reduction_size;
        bool is_final;
        data_type_t dst_data_type;
        if (num_remaining_phases > 1) {
            reduction_size = red_per_phase;
            is_final = false;
            dst_data_type = types::default_accum_data_type(
                    src_mdw.data_type(), data_type::undef);
        } else {
            reduction_size = reduced_dim_size;
            is_final = true;
            dst_data_type = dst_mdw.data_type();
        }

        const int phase_start = reduced_dim_size;
        const int phase_reductions = reduction_size;
        const int phase_end = utils::div_up(phase_start, phase_reductions);

        // Set scratchpad sizes
        const int phase_num = (int)conf.phases.size();
        if (!is_final && phase_num < 2) {
            conf.sp_size[phase_num]
                    = conf.outer_dim_size * phase_end * conf.inner_dim_size;
        }

        compute::dispatch_t dispatch = compute_engine->create_dispatch();
        size_t gws[3] = {1, 1, 1}, lws[3] = {1, 1, 1};
        gws[0] *= outer_dim_size * conf.gws_inner_dim_size * phase_end;

        // Set lws + pad gws simultaneously
        // - lws multiple of sub_group_size
        // - gws multiple of lws
        lws[0] = utils::rnd_up(std::min((int)gws[0], 256), conf.sub_group_size);
        gws[0] = utils::rnd_up(gws[0], lws[0]);
        compute::nd_range_t nd_range(gws, lws);

        conf.phases.push_back(init_phase(phase_start, phase_end,
                phase_reductions, src_data_type, dst_data_type, nd_range,
                is_final, is_first));
        reduced_dim_size = phase_end;
    }
    return status::success;
}

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const reduction_conf_t &conf, const reduction_phase_t &phase) {
    using namespace alg_kind;

    kernel_ctx.set_data_type(phase.src_type);

    // 1 ==> Use subgroups
    kernel_ctx.define_int("GWS_WITH_SG_DEFAULT", 1);
    kernel_ctx.define_int("GWS_SGS_DEFAULT", conf.sub_group_size);
    kernel_ctx.define_int("GWS_LWS0_DEFAULT", phase.nd_range.local_range()[0]);
    kernel_ctx.define_int("GWS_LWS1_DEFAULT", 1);
    kernel_ctx.define_int("GWS_LWS2_DEFAULT", 1);
    kernel_ctx.define_int("INNER_DIMS_PER_WI", conf.inner_dim_per_sg);

    kernel_ctx.define_int("REDUCTION_END_SIZE", phase.final_size);
    kernel_ctx.define_int("REDUCTION_SIZE", phase.initial_size);
    kernel_ctx.define_int("REDUCTION_CHUNK_SIZE", phase.num_reduction_chunks);
    kernel_ctx.define_int("OUTER_DIM_STRIDE",
            phase.num_reduction_chunks * conf.gws_inner_dim_size);
    kernel_ctx.define_int("DIV", conf.div);
    kernel_ctx.define_int("OUTER_DIM_SIZE", conf.outer_dim_size);
    kernel_ctx.define_int("INNER_DIM_SIZE", conf.inner_dim_size);
    kernel_ctx.define_int("PADDED_INNER_DIM_SIZE", conf.gws_inner_dim_size);
    kernel_ctx.define_int("NDIMS", conf.ndims);
    kernel_ctx.define_int("POWER", conf.power);
    kernel_ctx.define_float("EPS", conf.eps);

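    // Each work item handles roughly reduction_size / inner_dim_per_sg
    // reductions sequentially: e.g., a phase with reduction_size = 8 and
    // inner_dim_per_sg = 4 gives REDUCTIONS_PER_WI = 2.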
    int sg_reduction_per_wi
            = utils::div_up(phase.reduction_size, conf.inner_dim_per_sg);
    sg_reduction_per_wi = std::min(conf.div, sg_reduction_per_wi);
    kernel_ctx.define_int("REDUCTIONS_PER_WI",
            sg_reduction_per_wi); // Can change between phases
    kernel_ctx.define_int("IS_FINAL", phase.is_final);
    kernel_ctx.define_int("IS_FIRST", phase.is_first);

    // Block loading can be used when the data accessed by a subgroup per
    // iteration (inner_dim_size * inner_dim_per_sg elements) spans a multiple
    // of 4 bytes
    if ((types::data_type_size(phase.src_type) * conf.inner_dim_size
                * conf.inner_dim_per_sg)
                    % 4
            == 0) {
        kernel_ctx.define_int("WITH_BLOCK_READ", 1);
    } else {
        kernel_ctx.define_int("WITH_BLOCK_READ", 0);
    }

    switch (conf.alg) {
        case reduction_max: kernel_ctx.define_int("IS_MAX", 1); break;
        case reduction_min: kernel_ctx.define_int("IS_MIN", 1); break;
        case reduction_mean: kernel_ctx.define_int("IS_MEAN", 1); break;
        case reduction_sum: kernel_ctx.define_int("IS_SUM", 1); break;
        case reduction_mul: kernel_ctx.define_int("IS_MUL", 1); break;
        case reduction_norm_lp_max:
            kernel_ctx.define_int("IS_LP_MAX", 1);
            break;
        case reduction_norm_lp_sum:
            kernel_ctx.define_int("IS_LP_SUM", 1);
            break;
        case reduction_norm_lp_power_p_max:
            kernel_ctx.define_int("IS_P_MAX", 1);
            break;
        case reduction_norm_lp_power_p_sum:
            kernel_ctx.define_int("IS_P_SUM", 1);
            break;
        default: return status::invalid_arguments;
    }

    def_data_type(kernel_ctx, phase.src_type, "SRC");
    def_data_type(kernel_ctx, phase.dst_type, "DST");

    return status::success;
}

status_t combined_reduction_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx,
        const reduction_phase_t &phase) const {
    return init_kernel_ctx_common(kernel_ctx, conf, phase);
}

status_t combined_reduction_t::execute_combined(const exec_ctx_t &ctx) const {
    auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
    auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
    std::unique_ptr<memory_storage_t> sp_reduce[2]
            = {ctx.get_scratchpad_grantor().get_memory_storage(
                       memory_tracking::names::key_reduction),
                    ctx.get_scratchpad_grantor().get_memory_storage(
                            memory_tracking::names::key_reduction_1)};
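    // The two scratchpad buffers alternate roles across phases: each phase
    // reads the previous phase's output and writes to the other buffer,
    // except that phase 0 reads from src and the final phase writes to dst.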
    const auto &conf = pd()->conf;

    status_t status = status::success;
    for (size_t i = 0; i < kernels.size(); i++) {
        auto &kernel = kernels[i];
        auto &phase = conf.phases[i];
        auto nd_range = phase.nd_range;

        // Set up the reduction arg list
        compute::kernel_arg_list_t reduction_arg_list;

        if (i == 0) {
            reduction_arg_list.set(0, src);
        } else {
            reduction_arg_list.set(0, *sp_reduce[(i - 1) % 2]);
        }

        if (i == kernels.size() - 1) {
            reduction_arg_list.set(1, dst);
        } else {
            reduction_arg_list.set(1, *sp_reduce[i % 2]);
        }

        status = parallel_for(ctx, nd_range, kernel, reduction_arg_list);
        CHECK(status);
    }
    return status;
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl