/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef GPU_OCL_REF_ELTWISE_HPP
18#define GPU_OCL_REF_ELTWISE_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "gpu/compute/compute.hpp"
23#include "gpu/gpu_eltwise_pd.hpp"
24#include "gpu/gpu_primitive.hpp"
25#include "gpu/gpu_resource.hpp"
26#include "gpu/ocl/ocl_stream.hpp"
27#include "gpu/ocl/ocl_utils.hpp"
28#include "gpu/primitive_conf.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace gpu {
33namespace ocl {
34
35struct ref_eltwise_fwd_t : public gpu_primitive_t {
36 using gpu_primitive_t::gpu_primitive_t;
37 struct pd_t : public gpu_eltwise_fwd_pd_t {
38 using gpu_eltwise_fwd_pd_t::gpu_eltwise_fwd_pd_t;
39
40 DECLARE_COMMON_PD_T("ocl:ref:any", ref_eltwise_fwd_t);
41
42 status_t init(engine_t *engine) {
43 auto *compute_engine
44 = utils::downcast<compute::compute_engine_t *>(engine);
45
46 const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops;
47
48 using namespace alg_kind;
49 const bool ok = is_fwd()
50 && src_md()->data_type == dst_md()->data_type
51 && !memory_desc_ndims_ok(dst_md())
52 && attr()->has_default_values(attr_skip_mask)
53 && set_default_formats_common()
54 && memory_desc_wrapper(src_md())
55 == memory_desc_wrapper(dst_md())
56 && post_ops_with_binary_ok(
57 attr(), dst_md()->data_type, MAX_NDIMS)
58 && attr_.set_default_formats(dst_md(0)) == status::success
59 && IMPLICATION(src_md()->data_type == data_type::f16,
60 compute_engine->mayiuse(
61 compute::device_ext_t::khr_fp16));
62 if (!ok) return status::unimplemented;
63
64 CHECK(init_conf(engine));
65 if (!compute_engine->mayiuse_sub_group(conf.sub_group_size))
66 return status::unimplemented;
67 return status::success;
68 }
69
70 status_t init_conf(engine_t *engine);
71 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
72
73 eltwise_conf_t conf;
74 offsets_t off;
75 };
76
77 status_t init(engine_t *engine) override {
78 compute::kernel_ctx_t kernel_ctx;
79
80 status_t status = pd()->init_kernel_ctx(kernel_ctx);
81 if (status != status::success) return status;
82
83 create_kernel(engine, &kernel_, "ref_eltwise_fwd", kernel_ctx);
84 if (!kernel_) return status::runtime_error;
85
86 return status::success;
87 }
88
89 status_t execute(const exec_ctx_t &ctx) const override {
90 return execute_forward_dense(ctx);
91 }
92
93private:
94 status_t execute_forward_dense(const exec_ctx_t &ctx) const;
95 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
96 compute::kernel_t kernel_;
97};
98
99struct ref_eltwise_bwd_t : public gpu_primitive_t {
100 using gpu_primitive_t::gpu_primitive_t;
101 struct pd_t : public gpu_eltwise_bwd_pd_t {
102 pd_t(const eltwise_desc_t *adesc, const primitive_attr_t *attr,
103 const eltwise_fwd_pd_t *hint_fwd_pd)
104 : gpu_eltwise_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
105
106 DECLARE_COMMON_PD_T("ocl:ref:any", ref_eltwise_bwd_t);
107
108 status_t init(engine_t *engine) {
109 using namespace prop_kind;
110 using namespace utils;
111 assert(engine->kind() == engine_kind::gpu);
112
113 auto *compute_engine
114 = utils::downcast<compute::compute_engine_t *>(engine);
115
116 using namespace alg_kind;
117 const bool ok = !is_fwd()
118 && !memory_desc_ndims_ok(data_md(), diff_dst_md())
119 && utils::one_of(data_md()->data_type, data_type::f32,
120 data_type::bf16)
121 && utils::everyone_is(data_md()->data_type,
122 diff_src_md()->data_type, diff_dst_md()->data_type)
123 && set_default_formats_common()
124 && attr()->has_default_values()
125 && memory_desc_wrapper(diff_dst_md())
126 == memory_desc_wrapper(diff_src_md());
127 if (!ok) return status::unimplemented;
128
129 CHECK(init_conf(engine));
130 if (!compute_engine->mayiuse_sub_group(conf.sub_group_size))
131 return status::unimplemented;
132 return status::success;
133 }
134
135 status_t init_conf(engine_t *engine);
136 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
137
138 eltwise_conf_t conf;
139 offsets_t off;
140 bool use_dense;
141 };
142
143 status_t init(engine_t *engine) override {
144 compute::kernel_ctx_t kernel_ctx;
145
146 status_t status = pd()->init_kernel_ctx(kernel_ctx);
147 if (status != status::success) return status;
148
149 create_kernel(engine, &kernel_, "ref_eltwise_bwd", kernel_ctx);
150 if (!kernel_) return status::runtime_error;
151
152 return status::success;
153 }
154
155 status_t execute(const exec_ctx_t &ctx) const override {
156 return execute_backward_dense(ctx);
157 }
158
159private:
160 status_t execute_backward_dense(const exec_ctx_t &ctx) const;
161 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
162 compute::kernel_t kernel_;
163};
164
165} // namespace ocl
166} // namespace gpu
167} // namespace impl
168} // namespace dnnl
169
170#endif
171