1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/ocl/many_inputs_sum.hpp" |
18 | namespace dnnl { |
19 | namespace impl { |
20 | namespace gpu { |
21 | namespace ocl { |
22 | |
23 | status_t many_inputs_sum_t::execute(const exec_ctx_t &ctx) const { |
24 | auto &output = CTX_OUT_STORAGE(DNNL_ARG_DST); |
25 | const int num_arrs = pd()->n_inputs() |
26 | - 1; // input0 is copied over to output. Accumulate the rest. |
27 | const memory_desc_wrapper o_d(pd()->dst_md()); |
28 | const size_t nelems = o_d.nelems(true); |
29 | compute::kernel_arg_list_t arg_list; |
30 | |
31 | int num_batches = utils::div_up(num_arrs, max_num_arrs); |
32 | |
33 | auto &input0 = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC); |
34 | const bool is_inplace = (output.data_handle() == input0.data_handle()); |
35 | if (!is_inplace) { |
36 | auto *compute_stream |
37 | = utils::downcast<compute::compute_stream_t *>(ctx.stream()); |
38 | CHECK(compute_stream->copy(input0, output, o_d.size())); |
39 | } |
40 | status_t status; |
41 | for (int batch_iter = 0; batch_iter < num_batches; batch_iter++) { |
42 | int kernel_num_arrs = max_num_arrs; |
43 | if ((batch_iter == num_batches - 1) && (num_arrs % max_num_arrs)) { |
44 | kernel_num_arrs = num_arrs % max_num_arrs; |
45 | } |
46 | for (int a = 0; a < max_num_arrs; ++a) { |
47 | if (a < kernel_num_arrs) { |
48 | auto &input = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC + 1 + a |
49 | + batch_iter * max_num_arrs); |
50 | arg_list.set(a, input); |
51 | } else |
52 | arg_list.set(a, memory_storage_t::empty_storage()); |
53 | } |
54 | arg_list.set(94, output); |
55 | arg_list.set(95, CTX_GPU_RES_STORAGE(SCALES_)); |
56 | |
57 | const size_t total_width = nelems * kernel_num_arrs; |
58 | const size_t lws = utils::rnd_dn(256, kernel_num_arrs); |
59 | |
60 | compute::nd_range_t nd_range({utils::rnd_up(total_width, lws)}, {lws}); |
61 | if (batch_iter == num_batches - 1) { |
62 | status = parallel_for(ctx, nd_range, kernel_, arg_list); |
63 | } else { |
64 | status = parallel_for(ctx, nd_range, batched_kernel_, arg_list); |
65 | } |
66 | if (status != dnnl_success) return status; |
67 | } |
68 | return status; |
69 | } |
70 | |
71 | } // namespace ocl |
72 | } // namespace gpu |
73 | } // namespace impl |
74 | } // namespace dnnl |
75 | |