1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_REF_PRELU_HPP
18#define GPU_OCL_REF_PRELU_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "common/reduction_pd.hpp"
23#include "common/type_helpers.hpp"
24#include "common/utils.hpp"
25#include "gpu/compute/compute.hpp"
26#include "gpu/gpu_prelu_pd.hpp"
27#include "gpu/gpu_primitive.hpp"
28#include "gpu/gpu_resource.hpp"
29#include "gpu/primitive_conf.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace gpu {
34namespace ocl {
35
36struct ref_prelu_fwd_t : public gpu_primitive_t {
37 using gpu_primitive_t::gpu_primitive_t;
38 struct pd_t : public gpu_prelu_fwd_pd_t {
39 using gpu_prelu_fwd_pd_t::gpu_prelu_fwd_pd_t;
40
41 DECLARE_COMMON_PD_T("prelu_ref:any", ref_prelu_fwd_t);
42
43 status_t init(engine_t *engine) {
44
45 bool ok = is_fwd() && src_md()->data_type == dst_md()->data_type
46 && set_default_formats() && attr()->has_default_values()
47 && !memory_desc_ndims_ok(
48 src_md(0), dst_md(0), weights_md(0))
49 && memory_desc_wrapper(src_md())
50 == memory_desc_wrapper(dst_md());
51
52 if (!ok) return status::unimplemented;
53
54 return init_conf(engine);
55 }
56
57 status_t init_conf(engine_t *engine);
58 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
59
60 prelu_conf_t conf;
61 };
62
63 status_t init(engine_t *engine) override {
64 compute::kernel_ctx_t kernel_ctx;
65
66 status_t status = pd()->init_kernel_ctx(kernel_ctx);
67 CHECK(status);
68
69 status = create_kernel(engine, &kernel_, "ref_prelu_fwd", kernel_ctx);
70 CHECK(status);
71
72 return status::success;
73 }
74
75 virtual status_t execute(const exec_ctx_t &ctx) const override {
76 return execute_forward(ctx);
77 }
78
79private:
80 status_t execute_forward(const exec_ctx_t &ctx) const;
81 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
82
83 compute::kernel_t kernel_;
84};
85
86struct ref_prelu_bwd_t : public gpu_primitive_t {
87 using gpu_primitive_t::gpu_primitive_t;
88 struct pd_t : public gpu_prelu_bwd_pd_t {
89 using gpu_prelu_bwd_pd_t::gpu_prelu_bwd_pd_t;
90
91 pd_t(const prelu_desc_t *adesc, const primitive_attr_t *attr,
92 const prelu_fwd_pd_t *hint_fwd_pd)
93 : gpu_prelu_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
94
95 pd_t(const pd_t &other) = default;
96
97 ~pd_t() = default;
98
99 DECLARE_COMMON_PD_T("prelu_ref:any", ref_prelu_bwd_t);
100
101 status_t init(engine_t *engine) {
102
103 bool ok = !is_fwd()
104 && diff_dst_md()->data_type == diff_src_md()->data_type
105 && set_default_formats() && attr()->has_default_values()
106 && !memory_desc_ndims_ok(
107 diff_src_md(0), diff_dst_md(0), diff_weights_md(0))
108 && memory_desc_wrapper(diff_dst_md())
109 == memory_desc_wrapper(diff_src_md());
110
111 if (!ok) return status::unimplemented;
112
113 status_t status = init_conf(engine);
114 if (conf.reduce_diff_weights) {
115 CHECK(init_reduction(engine));
116 init_scratchpad();
117 }
118
119 return status;
120 }
121
122 status_t init_conf(engine_t *engine);
123 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
124 void init_scratchpad();
125
126 status_t init_reduction(engine_t *engine) {
127 reduction_desc_t rdesc;
128 memory_desc_t red_diff_mem_desc(*src_md(0));
129 red_diff_mem_desc.data_type = dnnl_f32;
130 reduction_desc_init(&rdesc, dnnl_alg_kind_t::dnnl_reduction_sum,
131 &red_diff_mem_desc, diff_weights_md(0), 0, 0);
132 primitive_attr_t reduction_attr(*attr());
133 if (!reduction_attr.is_initialized()) return status::out_of_memory;
134 primitive_desc_iterator_t it(
135 engine, (op_desc_t *)&rdesc, &reduction_attr, nullptr);
136 if (!it.is_initialized()) return status::invalid_arguments;
137 reduction_pd_ = *(++it);
138 if (reduction_pd_)
139 return status::success;
140 else {
141 return status::invalid_arguments;
142 }
143 }
144
145 prelu_conf_t conf;
146 std::shared_ptr<primitive_desc_t> reduction_pd_;
147 };
148
149 status_t init(engine_t *engine) override {
150 compute::kernel_ctx_t kernel_ctx;
151
152 status_t status = pd()->init_kernel_ctx(kernel_ctx);
153 CHECK(status);
154
155 status = create_kernel(engine, &kernel_, "ref_prelu_bwd", kernel_ctx);
156 CHECK(status);
157
158 if (pd()->conf.reduce_diff_weights) {
159 CHECK(create_nested_primitive(
160 reduction_p_, pd()->reduction_pd_, engine));
161 }
162 return status::success;
163 }
164
165 status_t execute(const exec_ctx_t &ctx) const override {
166 return execute_backward(ctx);
167 }
168
169private:
170 status_t execute_backward(const exec_ctx_t &ctx) const;
171 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
172
173 compute::kernel_t kernel_;
174 std::shared_ptr<primitive_t> reduction_p_;
175};
176
177} // namespace ocl
178} // namespace gpu
179} // namespace impl
180} // namespace dnnl
181
182#endif
183