1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/combined_reduction.hpp" |
18 | #include "common/scratchpad.hpp" |
19 | #include "gpu/ocl/ocl_utils.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace gpu { |
24 | namespace ocl { |
25 | |
26 | static reduction_phase_t init_phase(int start, int end, int num_reductions, |
27 | data_type_t src_type, data_type_t dst_type, |
28 | compute::nd_range_t nd_range, bool is_final, bool is_first) { |
29 | reduction_phase_t phase; |
30 | phase.initial_size = start; |
31 | phase.reduction_size = num_reductions; |
32 | phase.final_size = end; |
33 | phase.num_reduction_chunks |
34 | = utils::div_up(phase.initial_size, phase.reduction_size); |
35 | phase.src_type = src_type; |
36 | phase.dst_type = dst_type; |
37 | phase.nd_range = nd_range; |
38 | phase.is_final = is_final; |
39 | phase.is_first = is_first; |
40 | return phase; |
41 | } |
42 | |
43 | void combined_reduction_t::pd_t::init_scratchpad() { |
44 | // Only need scratchpads for the first 2 phases, since we can reuse them |
45 | // and memory requirements are monotonically decreasing each phase. |
46 | uint32_t keys[2] = {memory_tracking::names::key_reduction, |
47 | memory_tracking::names::key_reduction_1}; |
48 | |
49 | for (int phase_num = 0; |
50 | phase_num < std::min(2, (int)conf.phases.size() - 1); phase_num++) { |
51 | const size_t sp_data_size |
52 | = types::data_type_size(conf.phases[phase_num].dst_type); |
53 | auto scratchpad = scratchpad_registry().registrar(); |
54 | scratchpad.book(keys[phase_num], conf.sp_size[phase_num], sp_data_size, |
55 | OCL_BUFFER_ALIGNMENT); |
56 | } |
57 | } |
58 | |
// Initialize the reduction configuration for this primitive descriptor:
// - validates that src/dst share the same blocked layout (same number,
//   permutation, and sizes of inner blocks; same extended dim order),
// - rejects unsupported zero-padding / non-zero-preserving-alg combinations,
// - collapses the extended (plain+blocked) dims into composite
//   outer/reduced/inner sizes, requiring all reduced dims to be contiguous,
// - splits the overall reduction into one or more phases, each with its own
//   data types, reduction size, scratchpad size, and ND-range.
// Returns status::unimplemented for unsupported cases, status::success
// otherwise.
status_t combined_reduction_t::pd_t::init_conf(engine_t *engine) {
    // To start, check for compatibility
    const memory_desc_wrapper src_mdw(src_md());
    const memory_desc_wrapper dst_mdw(dst_md());
    const int ndims = src_mdw.ndims();

    const dnnl_dim_t *src_dims = src_mdw.dims();
    const blocking_desc_t &blk = src_mdw.blocking_desc();

    const dnnl_dim_t *dst_dims = dst_mdw.dims();
    const blocking_desc_t &dst_blk = dst_mdw.blocking_desc();

    const compute::compute_engine_t *compute_engine
            = utils::downcast<compute::compute_engine_t *>(engine);

    // Require same src/dst blocking
    if (blk.inner_nblks != dst_blk.inner_nblks) { // Same number of blocks
        return status::unimplemented;
    }

    for (int i = 0; i < blk.inner_nblks; i++) {
        if (blk.inner_idxs[i]
                != dst_blk.inner_idxs[i]) { // Same blocking permutation
            return status::unimplemented;
        }
        if (blk.inner_blks[i] != dst_blk.inner_blks[i]) { // Same blocking sizes
            return status::unimplemented;
        }
    }

    // Zero padding is not implemented when dim is reduced
    // Or when doing an LP/P alg (not zero-preserving)
    using namespace alg_kind;
    for (int i = 0; i < blk.inner_nblks; i++) {
        // Needs zero padding
        if (dst_mdw.padded_dims()[blk.inner_idxs[i]]
                != dst_mdw.dims()[blk.inner_idxs[i]]) {
            // non-zero-preserving alg
            switch (desc()->alg_kind) {
                case reduction_norm_lp_max:
                case reduction_norm_lp_sum:
                case reduction_norm_lp_power_p_max:
                case reduction_norm_lp_power_p_sum:
                    return status::unimplemented;
                default: break;
            }
            // Dim reduced
            if (dst_mdw.dims()[blk.inner_idxs[i]]
                    != src_mdw.dims()[blk.inner_idxs[i]]) {
                return status::unimplemented;
            }
        }
    }

    // Determine ordering of dims by stride
    // Insertion sort: dim_perm/dst_perm end up listing dim indices from
    // largest to smallest stride (outermost dimension first).
    dims_t dim_perm = {0}, dst_perm = {0};
    for (int i = 0; i < ndims; i++) {
        // Src
        dim_t stride = blk.strides[i];
        int dim_idx = i;
        for (int j = 0; j < i; j++) {
            if (stride > blk.strides[dim_perm[j]]) {
                // Insert this stride/idx into dim_perms
                stride = blk.strides[dim_perm[j]];
                int tmp_perm = dim_perm[j];
                dim_perm[j] = dim_idx;
                dim_idx = tmp_perm;
            }
        }
        dim_perm[i] = dim_idx;

        // Same for dst
        stride = dst_blk.strides[i];
        dim_idx = i;
        for (int j = 0; j < i; j++) {
            if (stride > dst_blk.strides[dst_perm[j]]) {
                // Insert this stride/idx into dim_perms
                stride = dst_blk.strides[dst_perm[j]];
                int tmp_perm = dst_perm[j];
                dst_perm[j] = dim_idx;
                dim_idx = tmp_perm;
            }
        }
        dst_perm[i] = dim_idx;
    }

    dims_t block_sizes, dst_blocks;
    src_mdw.compute_blocks(block_sizes);
    dst_mdw.compute_blocks(dst_blocks);
    // Determine extended (plain+blocked) dim structure
    // The first `ndims` entries are the plain dims (sizes divided by their
    // block size); the remaining `inner_nblks` entries are the inner blocks.
    dim_t extended_dim_order[2 * MAX_NDIMS], extended_dst_order[2 * MAX_NDIMS];
    dim_t extended_dim_size[2 * MAX_NDIMS];
    const int num_comp_dims = ndims + blk.inner_nblks;
    for (int i = 0; i < ndims; i++) { // plain
        extended_dim_order[i] = dim_perm[i];
        extended_dim_size[i]
                = src_mdw.padded_dims()[dim_perm[i]] / block_sizes[dim_perm[i]];
        extended_dst_order[i] = dst_perm[i];
    }
    for (int i = 0; i < blk.inner_nblks; i++) { // blocked
        extended_dim_order[i + ndims] = blk.inner_idxs[i];
        extended_dim_size[i + ndims] = blk.inner_blks[i];
        extended_dst_order[i + ndims] = dst_blk.inner_idxs[i];
    }

    // Only allow same src/dst format tags and permutations
    // TODO: Relax src/dst format matching
    for (int i = 0; i < num_comp_dims; i++) {
        if (extended_dim_order[i] != extended_dst_order[i]) {
            return status::unimplemented;
        }
    }

    // Convert composite structure to reduced dims
    // 1 marks an extended dim that is reduced (src size != dst size).
    dim_t extended_reduced_dims[2 * MAX_NDIMS];
    for (int i = 0; i < num_comp_dims; i++) {
        extended_reduced_dims[i] = (src_dims[extended_dim_order[i]]
                                           != dst_dims[extended_dim_order[i]])
                ? 1
                : 0;
    }

    // Finally, the check: Make sure all reduced dims are sequential
    // i.e. extended_reduced_dims has no 10...1 pattern
    for (int i = 0; i < num_comp_dims - 2; i++) {
        if (extended_reduced_dims[i] == 0
                || extended_reduced_dims[i + 1] == 1) {
            continue;
        }
        // Now we have the 10 pattern -- look for all 0's to the right
        for (int j = i + 1; j < num_comp_dims; j++) {
            if (extended_reduced_dims[j] == 1) { return status::unimplemented; }
        }
        break;
    }

    // Get information about composite outer/reduced/inner dimensions
    // Layout is guaranteed (by the check above) to be
    // [outer dims][reduced dims][inner dims].
    int num_outer_dims = 0, num_reduced_dims = 0, num_inner_dims = 0;
    bool left_side = true;
    for (int i = 0; i < num_comp_dims; i++) {
        if (extended_reduced_dims[i] == 1) {
            left_side = false;
            num_reduced_dims += 1;
            continue;
        }

        if (left_side) {
            num_outer_dims += 1;
        } else {
            num_inner_dims += 1;
        }
    }

    // Compute composite dim sizes
    int outer_dim_size = 1;
    for (int i = 0; i < num_outer_dims; i++) {
        outer_dim_size *= extended_dim_size[i];
    }
    int reduced_dim_size = 1;
    for (int i = 0; i < num_reduced_dims; i++) {
        reduced_dim_size *= extended_dim_size[i + num_outer_dims];
    }
    int inner_dim_size = 1;
    for (int i = 0; i < num_inner_dims; i++) {
        inner_dim_size
                *= extended_dim_size[i + num_outer_dims + num_reduced_dims];
    }

    // Set up conf variables that don't change between phases
    conf.ndims = ndims;
    conf.alg = desc()->alg_kind;
    conf.power = desc()->p;
    conf.eps = desc()->eps;

    conf.div = reduced_dim_size;
    conf.outer_dim_size = outer_dim_size;
    conf.inner_dim_size = inner_dim_size;

    conf.attr_info = attr_info_t::create(attr());

    // Heuristics based on testing on PVC
    conf.sub_group_size = compute_engine->device_info()->max_subgroup_size();

    // Target number of horizontal reductions per phase (heuristic).
    const int target_reduction_size = 8;

    // Pad the inner dim to a multiple of subgroup size
    conf.inner_dim_per_sg = std::min(reduced_dim_size,
            std::max(1, conf.sub_group_size / conf.inner_dim_size));
    conf.gws_inner_dim_size = utils::rnd_up(
            conf.inner_dim_per_sg * conf.inner_dim_size, conf.sub_group_size);

    // Split the reduction into phases; each iteration appends one phase that
    // shrinks reduced_dim_size until a single element per output remains.
    while (reduced_dim_size > 1) {
        data_type_t src_data_type;
        bool is_first;
        if (reduced_dim_size == conf.div) {
            // First phase reads the user's src type directly.
            src_data_type = src_mdw.data_type();
            is_first = true;
        } else {
            // Later phases read the accumulator type written by the
            // previous phase.
            src_data_type = types::default_accum_data_type(
                    src_mdw.data_type(), data_type::undef);
            is_first = false;
        }

        // Compute the number of phases left
        const int horiz_reductions
                = utils::div_up(reduced_dim_size, conf.inner_dim_per_sg);
        const int num_remaining_phases = std::floor(
                std::log(horiz_reductions) / std::log(target_reduction_size));
        // NOTE(review): when horiz_reductions < target_reduction_size,
        // num_remaining_phases is 0 and the division below is by zero.
        // red_per_phase is unused in that case (the else-branch below is
        // taken), but the float division/pow is still evaluated -- verify
        // this is intended.
        const int red_per_phase
                = std::pow(reduced_dim_size, 1.0f / num_remaining_phases);

        int reduction_size;
        bool is_final;
        data_type_t dst_data_type;
        if (num_remaining_phases > 1) {
            // Intermediate phase: partial reduction into accumulator type.
            reduction_size = red_per_phase;
            is_final = false;
            dst_data_type = types::default_accum_data_type(
                    src_mdw.data_type(), data_type::undef);
        } else {
            // Last phase: finish the reduction and write the dst type.
            reduction_size = reduced_dim_size;
            is_final = true;
            dst_data_type = dst_mdw.data_type();
        }

        const int phase_start = reduced_dim_size;
        const int phase_reductions = reduction_size;
        const int phase_end = utils::div_up(phase_start, phase_reductions);

        // Set scratchpad sizes
        // Only the first two intermediate phases need bookings (buffers are
        // reused after that); see init_scratchpad().
        const int phase_num = (int)conf.phases.size();
        if (!is_final && phase_num < 2) {
            conf.sp_size[phase_num]
                    = conf.outer_dim_size * phase_end * conf.inner_dim_size;
        }

        compute::dispatch_t dispatch = compute_engine->create_dispatch();
        size_t gws[3] = {1, 1, 1}, lws[3] = {1, 1, 1};
        gws[0] *= outer_dim_size * conf.gws_inner_dim_size * phase_end;

        // Set lws + pad gws simultaneously
        // - lws multiple of sub_group_size
        // - gws multiple of lws
        lws[0] = utils::rnd_up(std::min((int)gws[0], 256), conf.sub_group_size);
        gws[0] = utils::rnd_up(gws[0], lws[0]);
        compute::nd_range_t nd_range(gws, lws);

        conf.phases.push_back(init_phase(phase_start, phase_end,
                phase_reductions, src_data_type, dst_data_type, nd_range,
                is_final, is_first));
        reduced_dim_size = phase_end;
    }
    return status::success;
}
313 | |
314 | static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx, |
315 | const reduction_conf_t &conf, const reduction_phase_t &phase) { |
316 | using namespace alg_kind; |
317 | |
318 | kernel_ctx.set_data_type(phase.src_type); |
319 | |
320 | // 1 ==> Use subgroups |
321 | kernel_ctx.define_int("GWS_WITH_SG_DEFAULT" , 1); |
322 | kernel_ctx.define_int("GWS_SGS_DEFAULT" , conf.sub_group_size); |
323 | kernel_ctx.define_int("GWS_LWS0_DEFAULT" , phase.nd_range.local_range()[0]); |
324 | kernel_ctx.define_int("GWS_LWS1_DEFAULT" , 1); |
325 | kernel_ctx.define_int("GWS_LWS2_DEFAULT" , 1); |
326 | kernel_ctx.define_int("INNER_DIMS_PER_WI" , conf.inner_dim_per_sg); |
327 | |
328 | kernel_ctx.define_int("REDUCTION_END_SIZE" , phase.final_size); |
329 | kernel_ctx.define_int("REDUCTION_SIZE" , phase.initial_size); |
330 | kernel_ctx.define_int("REDUCTION_CHUNK_SIZE" , phase.num_reduction_chunks); |
331 | kernel_ctx.define_int("OUTER_DIM_STRIDE" , |
332 | phase.num_reduction_chunks * conf.gws_inner_dim_size); |
333 | kernel_ctx.define_int("DIV" , conf.div); |
334 | kernel_ctx.define_int("OUTER_DIM_SIZE" , conf.outer_dim_size); |
335 | kernel_ctx.define_int("INNER_DIM_SIZE" , conf.inner_dim_size); |
336 | kernel_ctx.define_int("PADDED_INNER_DIM_SIZE" , conf.gws_inner_dim_size); |
337 | kernel_ctx.define_int("NDIMS" , conf.ndims); |
338 | kernel_ctx.define_int("POWER" , conf.power); |
339 | kernel_ctx.define_float("EPS" , conf.eps); |
340 | |
341 | int sg_reduction_per_wi |
342 | = utils::div_up(phase.reduction_size, conf.inner_dim_per_sg); |
343 | sg_reduction_per_wi = std::min(conf.div, sg_reduction_per_wi); |
344 | kernel_ctx.define_int("REDUCTIONS_PER_WI" , |
345 | sg_reduction_per_wi); // Can change between phases |
346 | kernel_ctx.define_int("IS_FINAL" , phase.is_final); |
347 | kernel_ctx.define_int("IS_FIRST" , phase.is_first); |
348 | |
349 | // Block loading is supported when inner dims are a multiple of 4 bytes |
350 | if ((types::data_type_size(phase.src_type) * conf.inner_dim_size |
351 | * conf.inner_dim_per_sg) |
352 | % 4 |
353 | == 0) { |
354 | kernel_ctx.define_int("WITH_BLOCK_READ" , 1); |
355 | } else { |
356 | kernel_ctx.define_int("WITH_BLOCK_READ" , 0); |
357 | } |
358 | |
359 | switch (conf.alg) { |
360 | case reduction_max: kernel_ctx.define_int("IS_MAX" , 1); break; |
361 | case reduction_min: kernel_ctx.define_int("IS_MIN" , 1); break; |
362 | case reduction_mean: kernel_ctx.define_int("IS_MEAN" , 1); break; |
363 | case reduction_sum: kernel_ctx.define_int("IS_SUM" , 1); break; |
364 | case reduction_mul: kernel_ctx.define_int("IS_MUL" , 1); break; |
365 | case reduction_norm_lp_max: |
366 | kernel_ctx.define_int("IS_LP_MAX" , 1); |
367 | break; |
368 | case reduction_norm_lp_sum: |
369 | kernel_ctx.define_int("IS_LP_SUM" , 1); |
370 | break; |
371 | case reduction_norm_lp_power_p_max: |
372 | kernel_ctx.define_int("IS_P_MAX" , 1); |
373 | break; |
374 | case reduction_norm_lp_power_p_sum: |
375 | kernel_ctx.define_int("IS_P_SUM" , 1); |
376 | break; |
377 | default: return status::invalid_arguments; |
378 | } |
379 | |
380 | def_data_type(kernel_ctx, phase.src_type, "SRC" ); |
381 | def_data_type(kernel_ctx, phase.dst_type, "DST" ); |
382 | |
383 | return status::success; |
384 | } |
385 | |
386 | status_t combined_reduction_t::pd_t::init_kernel_ctx( |
387 | compute::kernel_ctx_t &kernel_ctx, |
388 | const reduction_phase_t &phase) const { |
389 | return init_kernel_ctx_common(kernel_ctx, conf, phase); |
390 | } |
391 | |
392 | status_t combined_reduction_t::execute_combined(const exec_ctx_t &ctx) const { |
393 | auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC); |
394 | auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); |
395 | std::unique_ptr<memory_storage_t> sp_reduce[2] |
396 | = {ctx.get_scratchpad_grantor().get_memory_storage( |
397 | memory_tracking::names::key_reduction), |
398 | ctx.get_scratchpad_grantor().get_memory_storage( |
399 | memory_tracking::names::key_reduction_1)}; |
400 | const auto &conf = pd()->conf; |
401 | |
402 | status_t status = status::success; |
403 | for (size_t i = 0; i < kernels.size(); i++) { |
404 | auto &kernel = kernels[i]; |
405 | auto &phase = conf.phases[i]; |
406 | auto nd_range = phase.nd_range; |
407 | |
408 | // Set up the reduction arg list |
409 | compute::kernel_arg_list_t reduction_arg_list; |
410 | |
411 | if (i == 0) { |
412 | reduction_arg_list.set(0, src); |
413 | } else { |
414 | reduction_arg_list.set(0, *sp_reduce[(i - 1) % 2]); |
415 | } |
416 | |
417 | if (i == kernels.size() - 1) { |
418 | reduction_arg_list.set(1, dst); |
419 | } else { |
420 | reduction_arg_list.set(1, *sp_reduce[i % 2]); |
421 | } |
422 | |
423 | status = parallel_for(ctx, nd_range, kernel, reduction_arg_list); |
424 | CHECK(status); |
425 | } |
426 | return status; |
427 | } |
428 | |
429 | } // namespace ocl |
430 | } // namespace gpu |
431 | } // namespace impl |
432 | } // namespace dnnl |
433 | |