1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <math.h>
19
20#include "gpu/ocl/simple_sum.hpp"
21
22#include "common/c_types_map.hpp"
23#include "common/math_utils.hpp"
24#include "common/nstl.hpp"
25#include "common/type_helpers.hpp"
26#include "gpu/compute/compute.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace gpu {
31namespace ocl {
32
33template <data_type_t data_type>
34status_t simple_sum_t<data_type>::execute(const exec_ctx_t &ctx) const {
35
36 auto &output = CTX_OUT_STORAGE(DNNL_ARG_DST);
37
38 const int num_arrs = pd()->n_inputs();
39 const memory_desc_wrapper o_d(pd()->dst_md());
40 const size_t nelems = o_d.nelems();
41
42 for (int a = 0; a < num_arrs; ++a) {
43
44 auto &input = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC + a);
45 const float scale = pd()->scales()[a];
46
47 compute::kernel_arg_list_t arg_list;
48 arg_list.set(0, input);
49 arg_list.set(1, output);
50 arg_list.set(2, scale);
51 arg_list.set(3, a);
52
53 auto nd_range = compute::nd_range_t({nelems});
54
55 status_t status = parallel_for(ctx, nd_range, kernel_, arg_list);
56 if (status != status::success) return status;
57 }
58 return status::success;
59}
60
61template struct simple_sum_t<data_type::f32>;
62
63} // namespace ocl
64} // namespace gpu
65} // namespace impl
66} // namespace dnnl
67