1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_MANY_INPUTS_SUM_HPP
18#define GPU_OCL_MANY_INPUTS_SUM_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "gpu/compute/compute.hpp"
23#include "gpu/gpu_primitive.hpp"
24#include "gpu/gpu_resource.hpp"
25#include "gpu/gpu_sum_pd.hpp"
26#include "gpu/ocl/ocl_stream.hpp"
27#include "gpu/ocl/ocl_utils.hpp"
28#include "gpu/primitive_conf.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace gpu {
33namespace ocl {
34
35struct many_inputs_sum_t : public gpu_primitive_t {
36 using gpu_primitive_t::gpu_primitive_t;
37 struct pd_t : public gpu_sum_pd_t {
38 using gpu_sum_pd_t::gpu_sum_pd_t;
39
40 DECLARE_SUM_PD_T("ocl:many_inputs", many_inputs_sum_t);
41
42 status_t init(engine_t *engine) {
43 const int n = n_inputs();
44
45 bool ok = gpu_sum_pd_t::init(engine) == status::success;
46
47 if (!ok) return status::unimplemented;
48
49 const memory_desc_wrapper o_d(dst_md());
50
51 for (int i = 0; i < n; ++i) {
52 const memory_desc_wrapper i_d(src_md(i));
53 if (i_d != o_d) return status::unimplemented;
54 }
55
56 if (scales()[0] != 1.0f) return status::unimplemented;
57 return status::success;
58 }
59 };
60
61 status_t init(engine_t *engine) override {
62 compute::kernel_ctx_t kernel_ctx;
63
64 const memory_desc_wrapper data_d(pd()->dst_md());
65 const memory_desc_wrapper data_s(pd()->src_md());
66
67 kernel_ctx.set_data_type(data_s.data_type());
68
69 kernel_ctx.define_int("N_ELEMS", data_d.nelems(true));
70
71 const int num_arrs = pd()->n_inputs() - 1;
72 int N_INPUTS = (num_arrs) % max_num_arrs;
73 if (N_INPUTS == 0) { N_INPUTS = max_num_arrs; };
74 kernel_ctx.define_int("N_INPUTS", N_INPUTS);
75 kernel_ctx.define_int("MAX_N_INPUTS", max_num_arrs);
76
77 def_memory_desc_info(
78 kernel_ctx, memory_desc_info_t::create(data_d), "SRC");
79 def_memory_desc_info(
80 kernel_ctx, memory_desc_info_t::create(data_s), "DST");
81
82 std::vector<compute::kernel_t> kernels;
83 std::vector<const char *> kernel_names;
84 kernel_names.push_back("many_inputs_sum");
85 kernel_names.push_back("many_inputs_sum_batched");
86 CHECK(create_kernels(engine, &kernels, kernel_names, kernel_ctx));
87 kernel_ = kernels[0];
88 batched_kernel_ = kernels[1];
89 if (!kernel_ || !batched_kernel_) return status::runtime_error;
90 return status::success;
91 }
92
93 status_t init_res_storage(
94 engine_t *engine, gpu_resource_t *r) const override {
95 const dim_t count = pd()->n_inputs();
96 const float *s_data = pd()->scales();
97
98 const size_t size = count * sizeof(float);
99 std::unique_ptr<memory_storage_t> scales;
100 memory_storage_t *scale = nullptr;
101 auto s = engine->create_memory_storage(&scale, size);
102 if (s != status::success) return s;
103 float *mapped_mem_storage = nullptr;
104 s = scale->map_data((void **)&mapped_mem_storage, nullptr, size);
105 if (s != status::success) return s;
106 utils::array_copy(mapped_mem_storage, s_data, count);
107 s = scale->unmap_data((void *)mapped_mem_storage, nullptr);
108 if (s != status::success) return s;
109 scales.reset(scale);
110 r->add_memory_storage(SCALES_, std::move(scales));
111 return status::success;
112 }
113
114 status_t execute(const exec_ctx_t &ctx) const override;
115
116private:
117 enum { max_num_arrs = 94 };
118 enum { SCALES_ = 0 };
119 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
120 compute::kernel_t kernel_;
121 compute::kernel_t batched_kernel_;
122};
123
124} // namespace ocl
125} // namespace gpu
126} // namespace impl
127} // namespace dnnl
128
129#endif
130
131// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
132