/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_REF_LRN_HPP
#define GPU_OCL_REF_LRN_HPP

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_lrn_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/ocl/ocl_utils.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

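// Reference OpenCL implementation of LRN forward propagation ("ref:any").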
struct ref_lrn_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_lrn_fwd_pd_t {
        pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr,
                const lrn_fwd_pd_t *hint_fwd_pd)
            : gpu_lrn_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
        virtual ~pd_t() {}

        DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t);

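        // Accept the problem only if it is a forward LRN with f32/f16/bf16
        // data, matching src/dst data types and layouts, default attributes,
        // and, for f16, a device that supports the cl_khr_fp16 extension.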
        status_t init(engine_t *engine) {
            using namespace data_type;
            assert(engine->kind() == engine_kind::gpu);
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);
            bool ok = is_fwd()
                    && utils::one_of(src_md()->data_type, f32, f16, bf16)
                    && src_md()->data_type == dst_md()->data_type
                    && attr()->has_default_values()
                    && IMPLICATION(src_md()->data_type == f16,
                            compute_engine->mayiuse(
                                    compute::device_ext_t::khr_fp16))
                    && set_default_formats_common()
                    && memory_desc_wrapper(src_md())
                            == memory_desc_wrapper(dst_md());
            if (!ok) return status::unimplemented;

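            // Forward training requires a workspace with the same shape as
            // src; low-precision (bf16/f16) workspaces are stored as f32.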
            if (desc_.prop_kind == prop_kind::forward_training) {
                ws_md_ = *src_md();
                if (ws_md_.data_type == data_type::bf16
                        || ws_md_.data_type == data_type::f16)
                    ws_md_.data_type = data_type::f32;
            }

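            // Dispatch over MB, IC and the spatial sizes; D/H/W are mapped
            // onto the trailing dimensions of the src memory descriptor.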
            dispatch = compute_engine->create_dispatch(src_md());
            dispatch.define_dim("MB", 0, MB());
            dispatch.define_dim("IC", 1, C());
            dispatch.define_dim("ID", nstl::max(1, src_md()->ndims - 3), D());
            dispatch.define_dim("IH", nstl::max(1, src_md()->ndims - 2), H());
            dispatch.define_dim("IW", nstl::max(1, src_md()->ndims - 1), W());
            dispatch.generate();

            return status::success;
        }

        compute::dispatch_t dispatch;
    };

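    // Generates the compile-time kernel configuration (shapes, algorithm, and
    // LRN parameters) and builds the "ref_lrn_fwd" OpenCL kernel.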
    status_t init(engine_t *engine) override {
        using namespace alg_kind;

        compute::kernel_ctx_t kernel_ctx;

        status_t status = status::success;
        const auto *desc = pd()->desc();

        kernel_ctx.set_data_type(desc->src_desc.data_type);

        kernel_ctx.define_int("IS_FWD", 1);

        if (desc->prop_kind == prop_kind::forward_training)
            kernel_ctx.define_int("IS_TRAINING", 1);

        switch (desc->alg_kind) {
            case lrn_across_channels:
                kernel_ctx.define_int("ACROSS_CHANNEL", 1);
                break;
            case lrn_within_channel:
                kernel_ctx.define_int("WITHIN_CHANNEL", 1);
                break;
            default: status = status::unimplemented;
        }
        if (status != status::success) return status;

        const memory_desc_wrapper src_d(pd()->src_md());
        const memory_desc_wrapper dst_d(pd()->dst_md());
        const int ndims = src_d.ndims();

        kernel_ctx.define_int("NDIMS", ndims);
        kernel_ctx.define_int("MB", pd()->MB());
        kernel_ctx.define_int("IC", pd()->C());
        kernel_ctx.define_int("ID", pd()->D());
        kernel_ctx.define_int("IH", pd()->H());
        kernel_ctx.define_int("IW", pd()->W());

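        // Size of the normalization window: local_size along the channel axis
        // for the across-channel algorithm, local_size^(spatial ndims) within
        // a channel. NUM_ELEMENTS_DIV is its reciprocal, PADDING the
        // half-window, and LOCAL_SIZE the window size rounded down to odd.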
        const uint32_t round_norm_size = desc->local_size;
        uint32_t num_elements = pow(round_norm_size, nstl::max(0, ndims - 2));
        if (desc->alg_kind == lrn_across_channels) {
            num_elements = round_norm_size;
        }
        const float num_element_div = 1.f / (float)num_elements;
        const auto padding = (desc->local_size - 1) / 2;

        kernel_ctx.define_float("NUM_ELEMENTS_DIV", num_element_div);
        kernel_ctx.define_int("PADDING", padding);
        kernel_ctx.define_int(
                "LOCAL_SIZE", desc->local_size - 1 + desc->local_size % 2);
        kernel_ctx.define_float("LRN_ALPHA", desc->lrn_alpha);
        kernel_ctx.define_float("LRN_BETA", desc->lrn_beta);
        kernel_ctx.define_float("LRN_K", desc->lrn_k);

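        // Pass the src/dst offset tables and the generated dispatch
        // parameters to the kernel as compile-time definitions.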
        offsets_t off;
        set_offsets(src_d, off.src_off);
        set_offsets(dst_d, off.dst_off);
        def_offsets(off.src_off, kernel_ctx, "SRC", ndims);
        def_offsets(off.dst_off, kernel_ctx, "DST", ndims);

        def_dispatch(kernel_ctx, pd()->dispatch);

        create_kernel(engine, &kernel_, "ref_lrn_fwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
};

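// Reference OpenCL implementation of LRN backward propagation ("ref:any").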
struct ref_lrn_bwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_lrn_bwd_pd_t {
        pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr,
                const lrn_fwd_pd_t *hint_fwd_pd)
            : gpu_lrn_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
        virtual ~pd_t() {}

        DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t);

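        // Accept the problem only if it is a backward LRN with f32/bf16 data,
        // matching src/diff_src/diff_dst data types, matching
        // diff_src/diff_dst layouts, and default attributes.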
        status_t init(engine_t *engine) {
            using namespace data_type;
            assert(engine->kind() == engine_kind::gpu);
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);
            bool ok = !is_fwd() && utils::one_of(src_md()->data_type, f32, bf16)
                    && utils::everyone_is(src_md()->data_type,
                            diff_src_md()->data_type, diff_dst_md()->data_type)
                    && attr()->has_default_values()
                    && set_default_formats_common()
                    && memory_desc_wrapper(diff_src_md())
                            == memory_desc_wrapper(diff_dst_md());
            if (!ok) return status::unimplemented;

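            // The backward pass consumes the forward workspace; its
            // descriptor must match the one produced by the hinted forward pd.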
            ws_md_ = *src_md();
            if (ws_md_.data_type == data_type::bf16)
                ws_md_.data_type = data_type::f32;
            if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;

            dispatch = compute_engine->create_dispatch(diff_src_md());
            dispatch.define_dim("MB", 0, MB());
            dispatch.define_dim("IC", 1, C());
            dispatch.define_dim("ID", nstl::max(1, src_md()->ndims - 3), D());
            dispatch.define_dim("IH", nstl::max(1, src_md()->ndims - 2), H());
            dispatch.define_dim("IW", nstl::max(1, src_md()->ndims - 1), W());
            dispatch.generate();

            return status::success;
        }

        compute::dispatch_t dispatch;
    };

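    // Same kernel configuration as the forward pass, but compiled with IS_BWD
    // and built as the "ref_lrn_bwd" OpenCL kernel.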
    status_t init(engine_t *engine) override {
        using namespace alg_kind;

        compute::kernel_ctx_t kernel_ctx;

        status_t status = status::success;
        const auto *desc = pd()->desc();

        kernel_ctx.set_data_type(desc->src_desc.data_type);

        kernel_ctx.define_int("IS_BWD", 1);

        switch (desc->alg_kind) {
            case lrn_across_channels:
                kernel_ctx.define_int("ACROSS_CHANNEL", 1);
                break;
            case lrn_within_channel:
                kernel_ctx.define_int("WITHIN_CHANNEL", 1);
                break;
            default: status = status::unimplemented;
        }
        if (status != status::success) return status;

        const memory_desc_wrapper src_d(pd()->src_md());
        const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
        const int ndims = src_d.ndims();

        kernel_ctx.define_int("NDIMS", ndims);
        kernel_ctx.define_int("MB", pd()->MB());
        kernel_ctx.define_int("IC", pd()->C());
        kernel_ctx.define_int("ID", pd()->D());
        kernel_ctx.define_int("IH", pd()->H());
        kernel_ctx.define_int("IW", pd()->W());

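        // Normalization-window parameters, computed as in the forward pass.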
        const uint32_t round_norm_size = desc->local_size;
        uint32_t num_elements = pow(round_norm_size, nstl::max(0, ndims - 2));
        if (desc->alg_kind == lrn_across_channels) {
            num_elements = round_norm_size;
        }
        const float num_element_div = 1.f / (float)num_elements;
        const auto padding = (desc->local_size - 1) / 2;

        kernel_ctx.define_float("NUM_ELEMENTS_DIV", num_element_div);
        kernel_ctx.define_int("PADDING", padding);
        kernel_ctx.define_int(
                "LOCAL_SIZE", desc->local_size - 1 + desc->local_size % 2);
        kernel_ctx.define_float("LRN_ALPHA", desc->lrn_alpha);
        kernel_ctx.define_float("LRN_BETA", desc->lrn_beta);
        kernel_ctx.define_float("LRN_K", desc->lrn_k);

        offsets_t off;
        set_offsets(src_d, off.src_off);
        set_offsets(diff_dst_d, off.dst_off);
        def_offsets(off.src_off, kernel_ctx, "SRC", ndims);
        def_offsets(off.dst_off, kernel_ctx, "DST", ndims);

        def_dispatch(kernel_ctx, pd()->dispatch);

        create_kernel(engine, &kernel_, "ref_lrn_bwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    compute::kernel_t kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif