1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/many_inputs_sum.hpp"
18namespace dnnl {
19namespace impl {
20namespace gpu {
21namespace ocl {
22
23status_t many_inputs_sum_t::execute(const exec_ctx_t &ctx) const {
24 auto &output = CTX_OUT_STORAGE(DNNL_ARG_DST);
25 const int num_arrs = pd()->n_inputs()
26 - 1; // input0 is copied over to output. Accumulate the rest.
27 const memory_desc_wrapper o_d(pd()->dst_md());
28 const size_t nelems = o_d.nelems(true);
29 compute::kernel_arg_list_t arg_list;
30
31 int num_batches = utils::div_up(num_arrs, max_num_arrs);
32
33 auto &input0 = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC);
34 const bool is_inplace = (output.data_handle() == input0.data_handle());
35 if (!is_inplace) {
36 auto *compute_stream
37 = utils::downcast<compute::compute_stream_t *>(ctx.stream());
38 CHECK(compute_stream->copy(input0, output, o_d.size()));
39 }
40 status_t status;
41 for (int batch_iter = 0; batch_iter < num_batches; batch_iter++) {
42 int kernel_num_arrs = max_num_arrs;
43 if ((batch_iter == num_batches - 1) && (num_arrs % max_num_arrs)) {
44 kernel_num_arrs = num_arrs % max_num_arrs;
45 }
46 for (int a = 0; a < max_num_arrs; ++a) {
47 if (a < kernel_num_arrs) {
48 auto &input = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC + 1 + a
49 + batch_iter * max_num_arrs);
50 arg_list.set(a, input);
51 } else
52 arg_list.set(a, memory_storage_t::empty_storage());
53 }
54 arg_list.set(94, output);
55 arg_list.set(95, CTX_GPU_RES_STORAGE(SCALES_));
56
57 const size_t total_width = nelems * kernel_num_arrs;
58 const size_t lws = utils::rnd_dn(256, kernel_num_arrs);
59
60 compute::nd_range_t nd_range({utils::rnd_up(total_width, lws)}, {lws});
61 if (batch_iter == num_batches - 1) {
62 status = parallel_for(ctx, nd_range, kernel_, arg_list);
63 } else {
64 status = parallel_for(ctx, nd_range, batched_kernel_, arg_list);
65 }
66 if (status != dnnl_success) return status;
67 }
68 return status;
69}
70
71} // namespace ocl
72} // namespace gpu
73} // namespace impl
74} // namespace dnnl
75