1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_SIMPLE_SUM_HPP
18#define GPU_OCL_SIMPLE_SUM_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "gpu/compute/compute.hpp"
25#include "gpu/gpu_primitive.hpp"
26#include "gpu/gpu_resource.hpp"
27#include "gpu/gpu_sum_pd.hpp"
28#include "gpu/ocl/ocl_stream.hpp"
29#include "gpu/ocl/ocl_utils.hpp"
30#include "gpu/primitive_conf.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace gpu {
35namespace ocl {
36
37template <data_type_t data_type>
38struct simple_sum_t : public gpu_primitive_t {
39 using gpu_primitive_t::gpu_primitive_t;
40 struct pd_t : public gpu_sum_pd_t {
41 using gpu_sum_pd_t::gpu_sum_pd_t;
42
43 DECLARE_SUM_PD_T("ocl:simple:any", simple_sum_t);
44
45 status_t init(engine_t *engine) {
46 const int n = n_inputs();
47
48 bool ok = gpu_sum_pd_t::init(engine) == status::success
49 && n <= max_num_arrs;
50 if (!ok) return status::unimplemented;
51
52 const memory_desc_wrapper o_d(dst_md());
53 ok = ok && o_d.data_type() == data_type && o_d.is_dense();
54 if (!ok) return status::unimplemented;
55
56 for (int i = 0; i < n; ++i) {
57 const memory_desc_wrapper i_d(src_md(i));
58 if (i_d != o_d) return status::unimplemented;
59 }
60
61 return status::success;
62 }
63 };
64
65 status_t init(engine_t *engine) override {
66 compute::kernel_ctx_t kernel_ctx;
67 create_kernel(engine, &kernel_, "simple_sum", kernel_ctx);
68 if (!kernel_) return status::runtime_error;
69 return status::success;
70 }
71
72 status_t execute(const exec_ctx_t &ctx) const override;
73
74 enum { max_num_arrs = 16 };
75 typedef typename prec_traits<data_type>::type data_t;
76
77private:
78 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
79 compute::kernel_t kernel_;
80};
81
82} // namespace ocl
83} // namespace gpu
84} // namespace impl
85} // namespace dnnl
86
87#endif
88
89// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
90