1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <math.h> |
19 | |
20 | #include "gpu/ocl/gen9_sum.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace ocl { |
26 | |
27 | status_t gen9_sum_t::execute(const exec_ctx_t &ctx) const { |
28 | auto &output = CTX_OUT_STORAGE(DNNL_ARG_DST); |
29 | const int num_arrs = pd()->n_inputs(); |
30 | const memory_desc_wrapper o_d(pd()->dst_md()); |
31 | const size_t nelems = o_d.nelems(true); |
32 | compute::kernel_arg_list_t arg_list; |
33 | |
34 | for (int a = 0; a < max_num_arrs; ++a) { |
35 | if (a < num_arrs) { |
36 | auto &input = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC + a); |
37 | arg_list.set(a, input); |
38 | } else |
39 | arg_list.set(a, memory_storage_t::empty_storage()); |
40 | } |
41 | arg_list.set(16, output); |
42 | arg_list.set(17, CTX_GPU_RES_STORAGE(SCALES_)); |
43 | |
44 | const size_t total_width = utils::div_up(nelems, vector_size); |
45 | const size_t lws = 256; |
46 | |
47 | compute::nd_range_t nd_range({utils::rnd_up(total_width, lws)}, {lws}); |
48 | |
49 | return parallel_for(ctx, nd_range, kernel_, arg_list); |
50 | } |
51 | |
52 | } // namespace ocl |
53 | } // namespace gpu |
54 | } // namespace impl |
55 | } // namespace dnnl |
56 | |