1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_GEN9_ELTWISE_HPP
18#define GPU_OCL_GEN9_ELTWISE_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "gpu/compute/compute.hpp"
23#include "gpu/gpu_eltwise_pd.hpp"
24#include "gpu/gpu_primitive.hpp"
25#include "gpu/gpu_resource.hpp"
26#include "gpu/ocl/ocl_stream.hpp"
27#include "gpu/ocl/ocl_utils.hpp"
28#include "gpu/primitive_conf.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace gpu {
33namespace ocl {
34
35struct gen9_eltwise_fwd_t : public gpu_primitive_t {
36 using gpu_primitive_t::gpu_primitive_t;
37 struct pd_t : public gpu_eltwise_fwd_pd_t {
38 using gpu_eltwise_fwd_pd_t::gpu_eltwise_fwd_pd_t;
39
40 DECLARE_COMMON_PD_T("ocl:gen9:any", gen9_eltwise_fwd_t);
41
42 status_t init(engine_t *engine) {
43 auto *compute_engine
44 = utils::downcast<compute::compute_engine_t *>(engine);
45
46 using namespace alg_kind;
47 bool ok = is_fwd() && src_md()->data_type == dst_md()->data_type
48 && attr()->has_default_values()
49 && set_default_formats_common()
50 && memory_desc_wrapper(src_md())
51 == memory_desc_wrapper(dst_md())
52 && IMPLICATION(src_md()->data_type == data_type::f16,
53 compute_engine->mayiuse(
54 compute::device_ext_t::khr_fp16))
55 && compute_engine->mayiuse_sub_group(16);
56 if (!ok) return status::unimplemented;
57
58 return init_conf(engine);
59 }
60
61 status_t init_conf(engine_t *engine);
62 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
63
64 eltwise_conf_t conf;
65 offsets_t off;
66 };
67
68 status_t init(engine_t *engine) override {
69 compute::kernel_ctx_t kernel_ctx;
70
71 status_t status = pd()->init_kernel_ctx(kernel_ctx);
72 if (status != status::success) return status;
73
74 create_kernel(engine, &kernel_, "gen9_eltwise_fwd", kernel_ctx);
75 if (!kernel_) return status::runtime_error;
76
77 return status::success;
78 }
79
80 virtual status_t execute(const exec_ctx_t &ctx) const override {
81 return execute_forward_dense(ctx);
82 }
83
84private:
85 status_t execute_forward_dense(const exec_ctx_t &ctx) const;
86 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
87 compute::kernel_t kernel_;
88};
89
90struct gen9_eltwise_bwd_t : public gpu_primitive_t {
91 using gpu_primitive_t::gpu_primitive_t;
92 struct pd_t : public gpu_eltwise_bwd_pd_t {
93 pd_t(const eltwise_desc_t *adesc, const primitive_attr_t *attr,
94 const eltwise_fwd_pd_t *hint_fwd_pd)
95 : gpu_eltwise_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
96
97 DECLARE_COMMON_PD_T("ocl:gen9:any", gen9_eltwise_bwd_t);
98
99 status_t init(engine_t *engine) {
100 using namespace prop_kind;
101 using namespace utils;
102 assert(engine->kind() == engine_kind::gpu);
103
104 using namespace alg_kind;
105 bool ok = !is_fwd()
106 && utils::one_of(data_md()->data_type, data_type::f32,
107 data_type::bf16)
108 && utils::everyone_is(data_md()->data_type,
109 diff_src_md()->data_type, diff_dst_md()->data_type)
110 && set_default_formats_common()
111 && attr()->has_default_values()
112 && memory_desc_wrapper(diff_dst_md())
113 == memory_desc_wrapper(diff_src_md());
114 if (!ok) return status::unimplemented;
115
116 return init_conf(engine);
117 }
118
119 status_t init_conf(engine_t *engine);
120 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
121
122 eltwise_conf_t conf;
123 offsets_t off;
124 bool use_dense;
125 };
126
127 status_t init(engine_t *engine) override {
128 compute::kernel_ctx_t kernel_ctx;
129
130 status_t status = pd()->init_kernel_ctx(kernel_ctx);
131 if (status != status::success) return status;
132
133 create_kernel(engine, &kernel_, "gen9_eltwise_bwd", kernel_ctx);
134 if (!kernel_) return status::runtime_error;
135
136 return status::success;
137 }
138
139 virtual status_t execute(const exec_ctx_t &ctx) const override {
140 return execute_backward_dense(ctx);
141 }
142
143private:
144 status_t execute_backward_dense(const exec_ctx_t &ctx) const;
145 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
146 compute::kernel_t kernel_;
147};
148
149} // namespace ocl
150} // namespace gpu
151} // namespace impl
152} // namespace dnnl
153
154#endif
155