1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_REF_PRELU_HPP |
18 | #define GPU_OCL_REF_PRELU_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "common/reduction_pd.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | #include "gpu/compute/compute.hpp" |
26 | #include "gpu/gpu_prelu_pd.hpp" |
27 | #include "gpu/gpu_primitive.hpp" |
28 | #include "gpu/gpu_resource.hpp" |
29 | #include "gpu/primitive_conf.hpp" |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace gpu { |
34 | namespace ocl { |
35 | |
36 | struct ref_prelu_fwd_t : public gpu_primitive_t { |
37 | using gpu_primitive_t::gpu_primitive_t; |
38 | struct pd_t : public gpu_prelu_fwd_pd_t { |
39 | using gpu_prelu_fwd_pd_t::gpu_prelu_fwd_pd_t; |
40 | |
41 | DECLARE_COMMON_PD_T("prelu_ref:any" , ref_prelu_fwd_t); |
42 | |
43 | status_t init(engine_t *engine) { |
44 | |
45 | bool ok = is_fwd() && src_md()->data_type == dst_md()->data_type |
46 | && set_default_formats() && attr()->has_default_values() |
47 | && !memory_desc_ndims_ok( |
48 | src_md(0), dst_md(0), weights_md(0)) |
49 | && memory_desc_wrapper(src_md()) |
50 | == memory_desc_wrapper(dst_md()); |
51 | |
52 | if (!ok) return status::unimplemented; |
53 | |
54 | return init_conf(engine); |
55 | } |
56 | |
57 | status_t init_conf(engine_t *engine); |
58 | status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const; |
59 | |
60 | prelu_conf_t conf; |
61 | }; |
62 | |
63 | status_t init(engine_t *engine) override { |
64 | compute::kernel_ctx_t kernel_ctx; |
65 | |
66 | status_t status = pd()->init_kernel_ctx(kernel_ctx); |
67 | CHECK(status); |
68 | |
69 | status = create_kernel(engine, &kernel_, "ref_prelu_fwd" , kernel_ctx); |
70 | CHECK(status); |
71 | |
72 | return status::success; |
73 | } |
74 | |
75 | virtual status_t execute(const exec_ctx_t &ctx) const override { |
76 | return execute_forward(ctx); |
77 | } |
78 | |
79 | private: |
80 | status_t execute_forward(const exec_ctx_t &ctx) const; |
81 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
82 | |
83 | compute::kernel_t kernel_; |
84 | }; |
85 | |
86 | struct ref_prelu_bwd_t : public gpu_primitive_t { |
87 | using gpu_primitive_t::gpu_primitive_t; |
88 | struct pd_t : public gpu_prelu_bwd_pd_t { |
89 | using gpu_prelu_bwd_pd_t::gpu_prelu_bwd_pd_t; |
90 | |
91 | pd_t(const prelu_desc_t *adesc, const primitive_attr_t *attr, |
92 | const prelu_fwd_pd_t *hint_fwd_pd) |
93 | : gpu_prelu_bwd_pd_t(adesc, attr, hint_fwd_pd) {} |
94 | |
95 | pd_t(const pd_t &other) = default; |
96 | |
97 | ~pd_t() = default; |
98 | |
99 | DECLARE_COMMON_PD_T("prelu_ref:any" , ref_prelu_bwd_t); |
100 | |
101 | status_t init(engine_t *engine) { |
102 | |
103 | bool ok = !is_fwd() |
104 | && diff_dst_md()->data_type == diff_src_md()->data_type |
105 | && set_default_formats() && attr()->has_default_values() |
106 | && !memory_desc_ndims_ok( |
107 | diff_src_md(0), diff_dst_md(0), diff_weights_md(0)) |
108 | && memory_desc_wrapper(diff_dst_md()) |
109 | == memory_desc_wrapper(diff_src_md()); |
110 | |
111 | if (!ok) return status::unimplemented; |
112 | |
113 | status_t status = init_conf(engine); |
114 | if (conf.reduce_diff_weights) { |
115 | CHECK(init_reduction(engine)); |
116 | init_scratchpad(); |
117 | } |
118 | |
119 | return status; |
120 | } |
121 | |
122 | status_t init_conf(engine_t *engine); |
123 | status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const; |
124 | void init_scratchpad(); |
125 | |
126 | status_t init_reduction(engine_t *engine) { |
127 | reduction_desc_t rdesc; |
128 | memory_desc_t red_diff_mem_desc(*src_md(0)); |
129 | red_diff_mem_desc.data_type = dnnl_f32; |
130 | reduction_desc_init(&rdesc, dnnl_alg_kind_t::dnnl_reduction_sum, |
131 | &red_diff_mem_desc, diff_weights_md(0), 0, 0); |
132 | primitive_attr_t reduction_attr(*attr()); |
133 | if (!reduction_attr.is_initialized()) return status::out_of_memory; |
134 | primitive_desc_iterator_t it( |
135 | engine, (op_desc_t *)&rdesc, &reduction_attr, nullptr); |
136 | if (!it.is_initialized()) return status::invalid_arguments; |
137 | reduction_pd_ = *(++it); |
138 | if (reduction_pd_) |
139 | return status::success; |
140 | else { |
141 | return status::invalid_arguments; |
142 | } |
143 | } |
144 | |
145 | prelu_conf_t conf; |
146 | std::shared_ptr<primitive_desc_t> reduction_pd_; |
147 | }; |
148 | |
149 | status_t init(engine_t *engine) override { |
150 | compute::kernel_ctx_t kernel_ctx; |
151 | |
152 | status_t status = pd()->init_kernel_ctx(kernel_ctx); |
153 | CHECK(status); |
154 | |
155 | status = create_kernel(engine, &kernel_, "ref_prelu_bwd" , kernel_ctx); |
156 | CHECK(status); |
157 | |
158 | if (pd()->conf.reduce_diff_weights) { |
159 | CHECK(create_nested_primitive( |
160 | reduction_p_, pd()->reduction_pd_, engine)); |
161 | } |
162 | return status::success; |
163 | } |
164 | |
165 | status_t execute(const exec_ctx_t &ctx) const override { |
166 | return execute_backward(ctx); |
167 | } |
168 | |
169 | private: |
170 | status_t execute_backward(const exec_ctx_t &ctx) const; |
171 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
172 | |
173 | compute::kernel_t kernel_; |
174 | std::shared_ptr<primitive_t> reduction_p_; |
175 | }; |
176 | |
177 | } // namespace ocl |
178 | } // namespace gpu |
179 | } // namespace impl |
180 | } // namespace dnnl |
181 | |
182 | #endif |
183 | |