1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_REF_LRN_HPP |
18 | #define GPU_OCL_REF_LRN_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/primitive.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "gpu/compute/compute.hpp" |
25 | #include "gpu/gpu_lrn_pd.hpp" |
26 | #include "gpu/gpu_primitive.hpp" |
27 | #include "gpu/gpu_resource.hpp" |
28 | #include "gpu/ocl/ocl_stream.hpp" |
29 | #include "gpu/ocl/ocl_utils.hpp" |
30 | #include "gpu/primitive_conf.hpp" |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace gpu { |
35 | namespace ocl { |
36 | |
37 | struct ref_lrn_fwd_t : public gpu_primitive_t { |
38 | using gpu_primitive_t::gpu_primitive_t; |
39 | struct pd_t : public gpu_lrn_fwd_pd_t { |
40 | pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr, |
41 | const lrn_fwd_pd_t *hint_fwd_pd) |
42 | : gpu_lrn_fwd_pd_t(adesc, attr, hint_fwd_pd) {} |
43 | virtual ~pd_t() {} |
44 | |
45 | DECLARE_COMMON_PD_T("ref:any" , ref_lrn_fwd_t); |
46 | |
47 | status_t init(engine_t *engine) { |
48 | using namespace data_type; |
49 | assert(engine->kind() == engine_kind::gpu); |
50 | auto *compute_engine |
51 | = utils::downcast<compute::compute_engine_t *>(engine); |
52 | bool ok = is_fwd() |
53 | && utils::one_of(src_md()->data_type, f32, f16, bf16) |
54 | && src_md()->data_type == dst_md()->data_type |
55 | && attr()->has_default_values() |
56 | && IMPLICATION(src_md()->data_type == f16, |
57 | compute_engine->mayiuse( |
58 | compute::device_ext_t::khr_fp16)) |
59 | && set_default_formats_common() |
60 | && memory_desc_wrapper(src_md()) |
61 | == memory_desc_wrapper(dst_md()); |
62 | if (!ok) return status::unimplemented; |
63 | |
64 | if (desc_.prop_kind == prop_kind::forward_training) { |
65 | ws_md_ = *src_md(); |
66 | if (ws_md_.data_type == data_type::bf16 |
67 | || ws_md_.data_type == data_type::f16) |
68 | ws_md_.data_type = data_type::f32; |
69 | } |
70 | |
71 | dispatch = compute_engine->create_dispatch(src_md()); |
72 | dispatch.define_dim("MB" , 0, MB()); |
73 | dispatch.define_dim("IC" , 1, C()); |
74 | dispatch.define_dim("ID" , nstl::max(1, src_md()->ndims - 3), D()); |
75 | dispatch.define_dim("IH" , nstl::max(1, src_md()->ndims - 2), H()); |
76 | dispatch.define_dim("IW" , nstl::max(1, src_md()->ndims - 1), W()); |
77 | dispatch.generate(); |
78 | |
79 | return status::success; |
80 | } |
81 | |
82 | compute::dispatch_t dispatch; |
83 | }; |
84 | |
85 | status_t init(engine_t *engine) override { |
86 | using namespace alg_kind; |
87 | |
88 | compute::kernel_ctx_t kernel_ctx; |
89 | |
90 | status_t status = status::success; |
91 | const auto *desc = pd()->desc(); |
92 | |
93 | kernel_ctx.set_data_type(desc->src_desc.data_type); |
94 | |
95 | kernel_ctx.define_int("IS_FWD" , 1); |
96 | |
97 | if (desc->prop_kind == prop_kind::forward_training) |
98 | kernel_ctx.define_int("IS_TRAINING" , 1); |
99 | |
100 | switch (desc->alg_kind) { |
101 | case lrn_across_channels: |
102 | kernel_ctx.define_int("ACROSS_CHANNEL" , 1); |
103 | break; |
104 | case lrn_within_channel: |
105 | kernel_ctx.define_int("WITHIN_CHANNEL" , 1); |
106 | break; |
107 | default: status = status::unimplemented; |
108 | } |
109 | if (status != status::success) return status; |
110 | |
111 | const memory_desc_wrapper src_d(pd()->src_md()); |
112 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
113 | const int ndims = src_d.ndims(); |
114 | |
115 | kernel_ctx.define_int("NDIMS" , ndims); |
116 | kernel_ctx.define_int("MB" , pd()->MB()); |
117 | kernel_ctx.define_int("IC" , pd()->C()); |
118 | kernel_ctx.define_int("ID" , pd()->D()); |
119 | kernel_ctx.define_int("IH" , pd()->H()); |
120 | kernel_ctx.define_int("IW" , pd()->W()); |
121 | |
122 | const uint32_t round_norm_size = desc->local_size; |
123 | uint32_t num_elements = pow(round_norm_size, nstl::max(0, ndims - 2)); |
124 | if (desc->alg_kind == lrn_across_channels) { |
125 | num_elements = round_norm_size; |
126 | } |
127 | const float num_element_div = 1.f / (float)num_elements; |
128 | const auto padding = (desc->local_size - 1) / 2; |
129 | |
130 | kernel_ctx.define_float("NUM_ELEMENTS_DIV" , num_element_div); |
131 | kernel_ctx.define_int("PADDING" , padding); |
132 | kernel_ctx.define_int( |
133 | "LOCAL_SIZE" , desc->local_size - 1 + desc->local_size % 2); |
134 | kernel_ctx.define_float("LRN_ALPHA" , desc->lrn_alpha); |
135 | kernel_ctx.define_float("LRN_BETA" , desc->lrn_beta); |
136 | kernel_ctx.define_float("LRN_K" , desc->lrn_k); |
137 | |
138 | offsets_t off; |
139 | set_offsets(src_d, off.src_off); |
140 | set_offsets(dst_d, off.dst_off); |
141 | def_offsets(off.src_off, kernel_ctx, "SRC" , ndims); |
142 | def_offsets(off.dst_off, kernel_ctx, "DST" , ndims); |
143 | |
144 | def_dispatch(kernel_ctx, pd()->dispatch); |
145 | |
146 | create_kernel(engine, &kernel_, "ref_lrn_fwd" , kernel_ctx); |
147 | if (!kernel_) return status::runtime_error; |
148 | |
149 | return status::success; |
150 | } |
151 | |
152 | status_t execute(const exec_ctx_t &ctx) const override { |
153 | return execute_forward(ctx); |
154 | } |
155 | |
156 | private: |
157 | status_t execute_forward(const exec_ctx_t &ctx) const; |
158 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
159 | compute::kernel_t kernel_; |
160 | }; |
161 | |
162 | struct ref_lrn_bwd_t : public gpu_primitive_t { |
163 | using gpu_primitive_t::gpu_primitive_t; |
164 | struct pd_t : public gpu_lrn_bwd_pd_t { |
165 | pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr, |
166 | const lrn_fwd_pd_t *hint_fwd_pd) |
167 | : gpu_lrn_bwd_pd_t(adesc, attr, hint_fwd_pd) {} |
168 | virtual ~pd_t() {} |
169 | |
170 | DECLARE_COMMON_PD_T("ref:any" , ref_lrn_bwd_t); |
171 | |
172 | status_t init(engine_t *engine) { |
173 | using namespace data_type; |
174 | assert(engine->kind() == engine_kind::gpu); |
175 | auto *compute_engine |
176 | = utils::downcast<compute::compute_engine_t *>(engine); |
177 | bool ok = !is_fwd() && utils::one_of(src_md()->data_type, f32, bf16) |
178 | && utils::everyone_is(src_md()->data_type, |
179 | diff_src_md()->data_type, diff_dst_md()->data_type) |
180 | && attr()->has_default_values() |
181 | && set_default_formats_common() |
182 | && memory_desc_wrapper(diff_src_md()) |
183 | == memory_desc_wrapper(diff_dst_md()); |
184 | if (!ok) return status::unimplemented; |
185 | |
186 | ws_md_ = *src_md(); |
187 | if (ws_md_.data_type == data_type::bf16) |
188 | ws_md_.data_type = data_type::f32; |
189 | if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; |
190 | |
191 | dispatch = compute_engine->create_dispatch(diff_src_md()); |
192 | dispatch.define_dim("MB" , 0, MB()); |
193 | dispatch.define_dim("IC" , 1, C()); |
194 | dispatch.define_dim("ID" , nstl::max(1, src_md()->ndims - 3), D()); |
195 | dispatch.define_dim("IH" , nstl::max(1, src_md()->ndims - 2), H()); |
196 | dispatch.define_dim("IW" , nstl::max(1, src_md()->ndims - 1), W()); |
197 | dispatch.generate(); |
198 | |
199 | return status::success; |
200 | } |
201 | |
202 | compute::dispatch_t dispatch; |
203 | }; |
204 | |
205 | status_t init(engine_t *engine) override { |
206 | using namespace alg_kind; |
207 | |
208 | compute::kernel_ctx_t kernel_ctx; |
209 | |
210 | status_t status = status::success; |
211 | const auto *desc = pd()->desc(); |
212 | |
213 | kernel_ctx.set_data_type(desc->src_desc.data_type); |
214 | |
215 | kernel_ctx.define_int("IS_BWD" , 1); |
216 | |
217 | switch (desc->alg_kind) { |
218 | case lrn_across_channels: |
219 | kernel_ctx.define_int("ACROSS_CHANNEL" , 1); |
220 | break; |
221 | case lrn_within_channel: |
222 | kernel_ctx.define_int("WITHIN_CHANNEL" , 1); |
223 | break; |
224 | default: status = status::unimplemented; |
225 | } |
226 | if (status != status::success) return status; |
227 | |
228 | const memory_desc_wrapper src_d(pd()->src_md()); |
229 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
230 | const int ndims = src_d.ndims(); |
231 | |
232 | kernel_ctx.define_int("NDIMS" , ndims); |
233 | kernel_ctx.define_int("MB" , pd()->MB()); |
234 | kernel_ctx.define_int("IC" , pd()->C()); |
235 | kernel_ctx.define_int("ID" , pd()->D()); |
236 | kernel_ctx.define_int("IH" , pd()->H()); |
237 | kernel_ctx.define_int("IW" , pd()->W()); |
238 | |
239 | const uint32_t round_norm_size = desc->local_size; |
240 | uint32_t num_elements = pow(round_norm_size, nstl::max(0, ndims - 2)); |
241 | if (desc->alg_kind == lrn_across_channels) { |
242 | num_elements = round_norm_size; |
243 | } |
244 | const float num_element_div = 1.f / (float)num_elements; |
245 | const auto padding = (desc->local_size - 1) / 2; |
246 | |
247 | kernel_ctx.define_float("NUM_ELEMENTS_DIV" , num_element_div); |
248 | kernel_ctx.define_int("PADDING" , padding); |
249 | kernel_ctx.define_int( |
250 | "LOCAL_SIZE" , desc->local_size - 1 + desc->local_size % 2); |
251 | kernel_ctx.define_float("LRN_ALPHA" , desc->lrn_alpha); |
252 | kernel_ctx.define_float("LRN_BETA" , desc->lrn_beta); |
253 | kernel_ctx.define_float("LRN_K" , desc->lrn_k); |
254 | |
255 | offsets_t off; |
256 | set_offsets(src_d, off.src_off); |
257 | set_offsets(diff_dst_d, off.dst_off); |
258 | def_offsets(off.src_off, kernel_ctx, "SRC" , ndims); |
259 | def_offsets(off.dst_off, kernel_ctx, "DST" , ndims); |
260 | |
261 | def_dispatch(kernel_ctx, pd()->dispatch); |
262 | |
263 | create_kernel(engine, &kernel_, "ref_lrn_bwd" , kernel_ctx); |
264 | if (!kernel_) return status::runtime_error; |
265 | |
266 | return status::success; |
267 | } |
268 | |
269 | status_t execute(const exec_ctx_t &ctx) const override { |
270 | return execute_backward(ctx); |
271 | } |
272 | |
273 | private: |
274 | status_t execute_backward(const exec_ctx_t &ctx) const; |
275 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
276 | |
277 | compute::kernel_t kernel_; |
278 | }; |
279 | |
280 | } // namespace ocl |
281 | } // namespace gpu |
282 | } // namespace impl |
283 | } // namespace dnnl |
284 | |
285 | #endif |
286 | |