1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <cmath>
18#include "common/c_types_map.hpp"
19#include "common/dnnl_thread.hpp"
20#include "common/type_helpers.hpp"
21#include "common/utils.hpp"
22
23#include "cpu/x64/lrn/jit_uni_lrn.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
30using namespace dnnl::impl::format_tag;
31using namespace dnnl::impl::status;
32using namespace dnnl::impl::utils;
33
// Upper bound on the LRN local size; the primitive descriptors clamp their
// JIT-specific limit against it.
static constexpr int MAX_LOCAL_SIZE = 32;
35
36static dnnl_dim_t compute_n_summands(
37 dnnl_dim_t size, int ndims, const dnnl_alg_kind_t &alg_kind) {
38 return alg_kind == alg_kind::lrn_across_channels
39 ? size
40 : std::pow(size, ndims - 2);
41};
42
43template <cpu_isa_t isa, data_type_t d_type>
44jit_uni_lrn_fwd_t<isa, d_type>::jit_uni_lrn_fwd_t(const pd_t *apd)
45 : primitive_t(apd)
46 , ker_(nullptr)
47 , ker_first_(nullptr)
48 , ker_last_(nullptr) {}
49
// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably placed here rather than in the header so that the
// kernel type held by ker_/ker_first_/ker_last_ only has to be complete in
// this file -- confirm against the class declaration.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_fwd_t<isa, d_type>::~jit_uni_lrn_fwd_t() = default;
52
53template <cpu_isa_t isa, data_type_t d_type>
54status_t jit_uni_lrn_fwd_t<isa, d_type>::init(engine_t *engine) {
55 using namespace alg_kind;
56
57 const int C = pd()->C();
58 const int H = pd()->H();
59 const int W = pd()->W();
60 const int ndims = memory_desc_wrapper(pd()->src_md()).ndims();
61 const int ls = pd()->desc()->local_size;
62 const float K = pd()->desc()->lrn_k;
63 const auto pk = pd()->desc()->prop_kind;
64 const auto ak = pd()->desc()->alg_kind;
65 const auto dat_tag = pd()->dat_tag_;
66 const float A = pd()->desc()->lrn_alpha / compute_n_summands(ls, ndims, ak);
67
68 if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) {
69 ker_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>(
70 nchw8c_across_t(H, W, 0), A, K, pk);
71 ker_first_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>(
72 nchw8c_across_t(H, W, -1), A, K, pk);
73 ker_last_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>(
74 nchw8c_across_t(H, W, +1), A, K, pk);
75 } else if (one_of(dat_tag, nhwc, nChw8c, nChw16c)
76 && ak == lrn_within_channel) {
77
78 ker_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>(
79 within_config_t(H, W, C, ls, dat_tag), A, K, pk);
80 } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) {
81 ker_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>(
82 nchw_across_t(C, H * W, 0), A, K, pk);
83 const int remind = (H * W) % VECTOR_LENGTH;
84 if (remind != 0) {
85 ker_last_
86 = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>(
87 nchw_across_t(C, H * W, remind), A, K, pk);
88 }
89 } else {
90 ker_ = utils::make_unique<jit_uni_lrn_fwd_kernel_t<isa, d_type>>(
91 nhwc_across_t(C), A, K, pk);
92 }
93 CHECK(ker_->create_kernel());
94 if (ker_first_) CHECK(ker_first_->create_kernel());
95 if (ker_last_) CHECK(ker_last_->create_kernel());
96 return status::success;
97}
98
99template <cpu_isa_t isa, data_type_t d_type>
100status_t jit_uni_lrn_fwd_t<isa, d_type>::execute_forward(
101 const exec_ctx_t &ctx) const {
102 using namespace alg_kind;
103
104 status_t status = status::success;
105
106 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
107 auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
108 CHECK(status);
109 auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status);
110 CHECK(status);
111
112 const int N = pd()->MB();
113 const int C = pd()->C();
114 const int HW = pd()->H() * pd()->W();
115 const int ls = pd()->desc()->local_size;
116
117 const auto ak = pd()->desc()->alg_kind;
118 const auto dat_tag = pd()->dat_tag_;
119 const auto ker_first = ker_first_.get();
120 const auto ker = ker_.get();
121 const auto ker_last = ker_last_.get();
122
123 if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) {
124 parallel_nd(N, C / VECTOR_LENGTH, [&](dim_t n, dim_t c8) {
125 const auto offset = n * HW * C + c8 * HW * VECTOR_LENGTH;
126 auto ws_ptr = ws ? &ws[offset] : nullptr;
127 jit_args_fwd_t args {&src[offset], &dst[offset], ws_ptr, nullptr};
128 if (c8 == 0)
129 (*ker_first)(&args);
130 else if (c8 == C / VECTOR_LENGTH - 1)
131 (*ker_last)(&args);
132 else
133 (*ker)(&args);
134 });
135 } else if (one_of(dat_tag, nhwc, nChw8c, nChw16c)
136 && ak == lrn_within_channel) {
137 parallel_nd(N, C / VECTOR_LENGTH, [&](dim_t n, dim_t c) {
138 const std::size_t offset = dat_tag == nhwc
139 ? n * HW * C + c * VECTOR_LENGTH
140 : n * HW * C + c * HW * VECTOR_LENGTH;
141 auto ws0_ptr = ws ? &ws[offset] : nullptr;
142 auto ws1_ptr = ws ? &ws[offset + N * C * HW] : nullptr;
143 jit_args_fwd_t args {&src[offset], &dst[offset], ws0_ptr, ws1_ptr};
144 (*ker)(&args);
145 });
146 } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) {
147 parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH,
148 [&](dim_t n, dim_t hw8) {
149 const auto offset = n * HW * C + hw8 * VECTOR_LENGTH;
150 auto ws0_ptr = ws ? &ws[offset] : nullptr;
151 jit_args_fwd_t args {
152 &src[offset], &dst[offset], ws0_ptr, nullptr};
153
154 if ((hw8 + 1) * VECTOR_LENGTH > HW)
155 (*ker_last)(&args);
156 else
157 (*ker)(&args);
158 });
159 } else { // nhwc
160 parallel_nd(N, HW, [&](dim_t n, dim_t hw) {
161 const auto offset = n * HW * C + hw * C;
162 auto ws_ptr = ws ? &ws[offset] : nullptr;
163 jit_args_fwd_t args {&src[offset], &dst[offset], ws_ptr, nullptr};
164 (*ker)(&args);
165 });
166 }
167 return status::success;
168}
169
170template <cpu_isa_t isa, data_type_t d_type>
171status_t jit_uni_lrn_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
172 using namespace prop_kind;
173 using namespace alg_kind;
174
175 const memory_desc_wrapper src_d(src_md());
176 const memory_desc_wrapper dst_d(dst_md());
177
178 const bool ok = is_fwd() && mayiuse(isa) && !has_zero_dim_memory()
179 && everyone_is(d_type, src_d.data_type(), dst_d.data_type())
180 && attr()->has_default_values() && set_default_formats_common()
181 && src_d == dst_d && src_d.ndims() == 4
182 && src_d.dims()[1] % VECTOR_LENGTH == 0
183 && src_d.dims()[1] >= 2 * VECTOR_LENGTH && desc()->lrn_beta == 0.75;
184 if (!ok) return unimplemented;
185
186 dat_tag_ = memory_desc_matches_one_of_tag(
187 *src_md(), nChw16c, nChw8c, nchw, nhwc);
188
189 const int HW = src_d.dims()[2] * src_d.dims()[3];
190
191 const bool args_ok_across = true && desc()->alg_kind == lrn_across_channels
192 && desc()->local_size == 5 && one_of(dat_tag_, nChw8c, nchw, nhwc)
193 && everyone_is(data_type::f32, src_d.data_type())
194 /* SSE41: prevent loads smaller than the size of xmm registers,
195 * otherwise it will result in an illegal memory read (seg-fault)
196 * due to protected memory. */
197 && IMPLICATION(isa == sse41 && dat_tag_ == nchw, HW >= 4)
198 && !is_superset(isa, avx512_core);
199
200 const int jit_max_local_size = 5; // bigger size triggers too big code size
201 const bool args_ok_within = true && desc()->alg_kind == lrn_within_channel
202 && desc()->local_size <= (jit_max_local_size <= MAX_LOCAL_SIZE
203 ? jit_max_local_size
204 : MAX_LOCAL_SIZE)
205 && src_d.dims()[2] >= desc()->local_size
206 && src_d.dims()[3] >= desc()->local_size
207 && IMPLICATION(d_type == data_type::bf16, mayiuse(avx512_core))
208 && IMPLICATION(d_type == data_type::f16, mayiuse(avx512_core_fp16))
209 && (is_superset(isa, avx512_core) ? one_of(dat_tag_, nhwc, nChw16c)
210 : one_of(dat_tag_, nhwc, nChw8c));
211
212 const auto status
213 = args_ok_across || args_ok_within ? success : unimplemented;
214
215 if (desc()->prop_kind == forward_training && status == success) {
216 dims_t ws_dims = {MB(), C(), H(), 2 * W()};
217 memory_desc_init_by_tag(ws_md_, 4, ws_dims, d_type, dat_tag_);
218 }
219
220 return status;
221}
222
223template <cpu_isa_t isa, data_type_t d_type>
224jit_uni_lrn_bwd_t<isa, d_type>::jit_uni_lrn_bwd_t(const pd_t *apd)
225 : primitive_t(apd)
226 , ker_(nullptr)
227 , ker_first_(nullptr)
228 , ker_last_(nullptr) {}
229
// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably placed here rather than in the header so that the
// kernel type held by ker_/ker_first_/ker_last_ only has to be complete in
// this file -- confirm against the class declaration.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_lrn_bwd_t<isa, d_type>::~jit_uni_lrn_bwd_t() = default;
232
233template <cpu_isa_t isa, data_type_t d_type>
234status_t jit_uni_lrn_bwd_t<isa, d_type>::init(engine_t *engine) {
235 using namespace alg_kind;
236 const int C = pd()->C();
237 const int H = pd()->H();
238 const int W = pd()->W();
239 const int &ls = pd()->desc()->local_size;
240 const auto &ak = pd()->desc()->alg_kind;
241 const int ndims = memory_desc_wrapper(pd()->src_md()).ndims();
242 const float A = pd()->desc()->lrn_alpha / compute_n_summands(ls, ndims, ak);
243 const float &B = pd()->desc()->lrn_beta;
244 const auto &dat_tag = pd()->dat_tag_;
245
246 if (one_of(dat_tag, nhwc, nChw8c, nChw16c) && ak == lrn_within_channel) {
247 ker_ = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>(
248 within_config_t(H, W, C, ls, dat_tag), A, B);
249 } else {
250 int use_h_parallelism = 0; // XXX
251 if (C / VECTOR_LENGTH == 1) {
252 ker_ = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>(
253 nchw8c_across_t(H, W, 3), A, B, use_h_parallelism);
254 } else {
255 ker_ = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>(
256 nchw8c_across_t(H, W, 0), A, B, use_h_parallelism);
257 ker_first_
258 = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>(
259 nchw8c_across_t(H, W, -1), A, B, use_h_parallelism);
260 ker_last_
261 = utils::make_unique<jit_uni_lrn_bwd_kernel_t<isa, d_type>>(
262 nchw8c_across_t(H, W, +1), A, B, use_h_parallelism);
263 }
264 }
265 CHECK(ker_->create_kernel());
266 if (ker_first_) CHECK(ker_first_->create_kernel());
267 if (ker_last_) CHECK(ker_last_->create_kernel());
268 return status::success;
269}
270
271template <cpu_isa_t isa, data_type_t d_type>
272status_t jit_uni_lrn_bwd_t<isa, d_type>::execute_backward(
273 const exec_ctx_t &ctx) const {
274 status_t status = status::success;
275 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
276 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
277 auto ws = CTX_IN_MEM(const data_t *, DNNL_ARG_WORKSPACE);
278 auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
279 CHECK(status);
280
281 const int N = pd()->MB();
282 const int C = pd()->C();
283 const int H = pd()->H();
284 const int W = pd()->W();
285 const auto ak = pd()->desc()->alg_kind;
286 const auto &dat_tag = pd()->dat_tag_;
287
288 static constexpr bool use_h_parallelism = false; // XXX
289
290 const auto ker = ker_.get();
291 const auto ker_first = ker_first_.get();
292 const auto ker_last = ker_last_.get();
293 const auto tensor_size = N * C * H * W;
294
295 if (one_of(dat_tag, nhwc, nChw8c, nChw16c)
296 && ak == alg_kind::lrn_within_channel) {
297 parallel_nd(N, C / VECTOR_LENGTH, [&](dim_t n, dim_t c) {
298 const std::size_t offset = dat_tag == nhwc
299 ? n * H * W * C + c * VECTOR_LENGTH
300 : n * H * W * C + c * H * W * VECTOR_LENGTH;
301 jit_args_bwd_t args {&src[offset], &diff_dst[offset], &ws[offset],
302 &ws[offset + tensor_size], &diff_src[offset]};
303 (*ker)(&args);
304 });
305 } else if (use_h_parallelism) {
306 parallel_nd(N, C / VECTOR_LENGTH, H, [&](dim_t n, dim_t c8, dim_t h) {
307 const std::size_t offset = n * C * H * W
308 + c8 * H * W * VECTOR_LENGTH + h * W * VECTOR_LENGTH;
309 jit_args_bwd_t args {&src[offset], &diff_dst[offset], &ws[offset],
310 nullptr, &diff_src[offset]};
311 if (C / VECTOR_LENGTH == 1)
312 (*ker)(&args);
313 else if (c8 == 0)
314 (*ker_first)(&args);
315 else if (c8 == C / VECTOR_LENGTH - 1)
316 (*ker_last)(&args);
317 else
318 (*ker)(&args);
319 });
320 } else {
321 parallel_nd(N, C / VECTOR_LENGTH, [&](dim_t n, dim_t c8) {
322 const std::size_t offset
323 = n * C * H * W + c8 * H * W * VECTOR_LENGTH;
324 jit_args_bwd_t args {&src[offset], &diff_dst[offset], &ws[offset],
325 nullptr, &diff_src[offset]};
326 if (C / VECTOR_LENGTH == 1)
327 (*ker)(&args);
328 else if (c8 == 0)
329 (*ker_first)(&args);
330 else if (c8 == C / VECTOR_LENGTH - 1)
331 (*ker_last)(&args);
332 else
333 (*ker)(&args);
334 });
335 }
336 return status::success;
337}
338
339template <cpu_isa_t isa, data_type_t d_type>
340status_t jit_uni_lrn_bwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
341 using namespace prop_kind;
342 using namespace alg_kind;
343
344 const memory_desc_wrapper src_d(src_md());
345 const memory_desc_wrapper diff_src_d(diff_src_md());
346 const memory_desc_wrapper diff_dst_d(diff_dst_md());
347
348 const bool ok = !is_fwd() && mayiuse(avx512_core) && !has_zero_dim_memory()
349 && utils::everyone_is(d_type, src_d.data_type(),
350 diff_src_d.data_type(), diff_dst_d.data_type())
351 && src_d.ndims() == 4 && attr()->has_default_values()
352 && set_default_formats_common() && src_d == diff_dst_d
353 && diff_dst_d == diff_src_d && src_d.dims()[1] % VECTOR_LENGTH == 0
354 && src_d.dims()[1] >= 2 * VECTOR_LENGTH && desc()->lrn_beta == 0.75;
355 if (!ok) return unimplemented;
356
357 dat_tag_ = memory_desc_matches_one_of_tag(
358 *src_md(), nChw16c, nChw8c, nchw, nhwc);
359
360 const dims_t ws_dims = {MB(), C(), H(), 2 * W()};
361 memory_desc_init_by_tag(ws_md_, 4, ws_dims, d_type, dat_tag_);
362
363 if (!compare_ws(hint_fwd_pd_)) return unimplemented;
364
365 const bool args_ok_across = true && desc()->alg_kind == lrn_across_channels
366 && desc()->local_size == 5 && utils::one_of(dat_tag_, nChw8c)
367 && everyone_is(data_type::f32, src_d.data_type())
368 && !is_superset(isa, avx512_core);
369
370 const int jit_max_local_size = 5; // bigger size triggers too big code size
371 const bool args_ok_within = true && desc()->alg_kind == lrn_within_channel
372 && desc()->local_size <= (jit_max_local_size <= MAX_LOCAL_SIZE
373 ? jit_max_local_size
374 : MAX_LOCAL_SIZE)
375 && src_d.dims()[2] >= desc()->local_size
376 && src_d.dims()[3] >= desc()->local_size
377 && IMPLICATION(d_type == data_type::bf16, mayiuse(avx512_core))
378 && IMPLICATION(d_type == data_type::f16, mayiuse(avx512_core_fp16))
379 && (isa == avx512_core ? one_of(dat_tag_, nhwc, nChw16c)
380 : one_of(dat_tag_, nhwc, nChw8c));
381
382 return args_ok_across || args_ok_within ? success : unimplemented;
383}
384
// Explicit instantiations of the forward primitive for each supported
// (ISA, data type) pair.
template struct jit_uni_lrn_fwd_t<avx512_core, dnnl::impl::data_type::f32>;
template struct jit_uni_lrn_fwd_t<avx512_core, dnnl::impl::data_type::bf16>;
template struct jit_uni_lrn_fwd_t<avx512_core_fp16, dnnl::impl::data_type::f16>;
template struct jit_uni_lrn_fwd_t<avx2, dnnl::impl::data_type::f32>;
template struct jit_uni_lrn_fwd_t<sse41, dnnl::impl::data_type::f32>;
// Explicit instantiations of the backward primitive; note that there is no
// sse41 instantiation for backward.
template struct jit_uni_lrn_bwd_t<avx512_core, dnnl::impl::data_type::f32>;
template struct jit_uni_lrn_bwd_t<avx512_core, dnnl::impl::data_type::bf16>;
template struct jit_uni_lrn_bwd_t<avx512_core_fp16, dnnl::impl::data_type::f16>;
template struct jit_uni_lrn_bwd_t<avx2, dnnl::impl::data_type::f32>;
394
395} // namespace x64
396} // namespace cpu
397} // namespace impl
398} // namespace dnnl
399
400// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
401