/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_GEN9_SIMPLE_SUM_HPP
#define GPU_JIT_GEN9_SIMPLE_SUM_HPP

#include "common/c_types_map.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_sum_pd.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

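// gen9_simple_sum_t computes an out-of-place, scaled sum over n f32 inputs,
//     dst = sum_i scales[i] * src[i],
// where all tensors share the same dense layout. A single nGEN-generated
// kernel is launched once per input (see execute() below).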
struct gen9_simple_sum_t : public gpu_primitive_t {
    struct pd_t : public gpu_sum_pd_t {
        using gpu_sum_pd_t::gpu_sum_pd_t;

        DECLARE_SUM_PD_T("ngen:simple:any", gen9_simple_sum_t);

        status_t init(engine_t *engine) {
            // This implementation emits nGEN kernels, so the engine must be
            // able to run them.
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);
            if (!compute_engine->mayiuse_ngen_kernels())
                return status::unimplemented;

            const int n = n_inputs();

            constexpr auto data_type = data_type::f32;

            bool ok = gpu_sum_pd_t::init(engine) == status::success;
            if (!ok) return status::unimplemented;

            // The destination must be dense f32, and every source layout
            // must match the destination exactly.
            const memory_desc_wrapper o_d(dst_md());
            ok = o_d.data_type() == data_type && o_d.is_dense();
            if (!ok) return status::unimplemented;

            for (int i = 0; i < n; ++i) {
                const memory_desc_wrapper i_d(src_md(i));
                if (i_d != o_d) return status::unimplemented;
            }

            return status::success;
        }
    };

    gen9_simple_sum_t(const pd_t *apd) : gpu_primitive_t(apd) {}

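    // Compiles the nGEN sum kernel into kernel_; the definition lives out of
    // line (presumably in the matching .cpp file).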
    status_t init(engine_t *engine) override;

    status_t execute(const exec_ctx_t &ctx) const override {
        status_t status = status::success;
        auto &output = CTX_OUT_CLEAN_STORAGE(DNNL_ARG_DST, status);
        CHECK(status);

        const int num_arrs = pd()->n_inputs();
        const memory_desc_wrapper o_d(pd()->dst_md());
        const size_t nelems = o_d.nelems();

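        // One kernel launch per input. The input index is passed as the last
        // kernel argument, presumably so the first pass can initialize dst
        // while later passes accumulate scale * src into it.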
        for (int a = 0; a < num_arrs; ++a) {
            auto &input = CTX_IN_STORAGE(DNNL_ARG_MULTIPLE_SRC + a);
            const float scale = pd()->scales()[a];

            compute::kernel_arg_list_t arg_list;
            arg_list.set(0, input);
            arg_list.set(1, output);
            arg_list.set(2, scale);
            arg_list.set(3, a);

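            // One work-item per element, with single-item work-groups.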
            size_t gws[3] = {nelems, 1, 1};
            size_t lws[3] = {1, 1, 1};
            auto nd_range = compute::nd_range_t(gws, lws);
            status = parallel_for(ctx, nd_range, kernel_, arg_list);
            if (status != status::success) return status;
        }
        return status::success;
    }

private:
    const pd_t *pd() const { return (const pd_t *)gpu_primitive_t::pd().get(); }

    compute::kernel_t kernel_;
};

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif // GPU_JIT_GEN9_SIMPLE_SUM_HPP