1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <math.h>
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/type_helpers.hpp"
23
24#include "cpu/ref_lrn.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29
namespace {

// Accumulation type: sums of squares and the normalization factor are always
// computed in f32, regardless of the primitive's data type (f32/bf16/f16).
using acc_data_t = float;

// Computes omega^(-beta).
//
// For the default LRN beta of 0.75 the generic (and expensive) powf() call is
// replaced by two sqrtf() calls and one division:
/*
 * Y = omega^(-3/4) =
 *   = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega))
 *   = sqrtf(1.0f / sqrtf(omega)) * 1.0f / sqrtf(omega)
 *   = sqrtf(1.0f / sqrtf(omega)) / sqrtf(omega)
 *   = sqrtf(1.0f / sqrtf(omega) / omega)
 *   = sqrtf(1.0f / (sqrtf(omega) * omega))
 */
inline acc_data_t fast_negative_powf(acc_data_t omega, acc_data_t beta) {
    acc_data_t Y;
    if (beta == 0.75f) {
        Y = sqrtf(1.0f / (sqrtf(omega) * omega));
    } else {
        Y = 1.0f / powf(omega, beta);
    }
    return Y;
} // fixed: dropped stray ';' after the function definition
} // namespace
52
53// Forward LRN formula:
54// y_i = x_i * (k + a / n * Sum:j [x_j^2])^-b, where
55// k, a(alpha), b(beta), n(local_size) - lrn hyperparameters;
56// j - kernel points, j in [i - n/2, i + n/2] for ACROSS, 2d-shape for WITHIN;
57
template <impl::data_type_t d_type>
template <impl::format_tag_t tag>
status_t ref_lrn_fwd_t<d_type>::execute_forward(const exec_ctx_t &ctx) const {
    using namespace alg_kind;
    using namespace format_tag;

    status_t status = status::success;

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
    CHECK(status);

    const memory_desc_wrapper data_d(pd()->src_md());

    const dim_t C = pd()->C();
    const dim_t D = pd()->D();
    const dim_t H = pd()->H();
    const dim_t W = pd()->W();
    // Minibatch stride comes from the actual memory descriptor so padded /
    // non-dense layouts are addressed correctly.
    const auto stride_mb = data_d.blocking_desc().strides[0];
    const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels;
    // Channel block size for the blocked layouts; unused for plain tags.
    static constexpr dim_t blksize = tag == nChw16c ? 16 : 8;
    const auto ndims = data_d.ndims();

    // Number of summands `n` in the formula: `size` for ACROSS_CHANNELS,
    // size^(ndims - 2) (one factor per spatial dimension) for WITHIN_CHANNEL.
    auto compute_n_summands = [&](dim_t size) {
        if (across_channels) {
            return size;
        } else { // within_channel
            dim_t n_summands = 1;
            for (auto d = ndims - 2; d > 0; --d)
                n_summands *= size;
            return n_summands;
        }
    };

    const acc_data_t alpha = static_cast<acc_data_t>(pd()->desc()->lrn_alpha);
    const acc_data_t beta = static_cast<acc_data_t>(pd()->desc()->lrn_beta);
    const acc_data_t k = static_cast<acc_data_t>(pd()->desc()->lrn_k);
    const dim_t size = pd()->desc()->local_size;
    const dim_t half_size = (size - 1) / 2;
    const dim_t summands = compute_n_summands(size);

    // Physical offset of point (mb, c, d, h, w): hand-computed for the common
    // 4D layouts, generic memory_desc_wrapper::off() for everything else.
    auto data_off = [&](dim_t mb, dim_t c, dim_t d, dim_t h, dim_t w) -> dim_t {
        switch (tag) {
            case nChw16c:
            case nChw8c:
                return mb * stride_mb + (c / blksize) * H * W * blksize
                        + h * W * blksize + w * blksize + c % blksize;
            case nchw: return mb * stride_mb + c * H * W + h * W + w;
            case nhwc: return mb * stride_mb + h * W * C + w * C + c;
            default:
                if (ndims >= 5) return data_d.off(mb, c, d, h, w);
                if (ndims >= 4) return data_d.off(mb, c, h, w);
                if (ndims >= 3) return data_d.off(mb, c, w);
                return data_d.off(mb, c);
        }
    };

    // Computes one destination point at (mb, oc, od, oh, ow):
    //   d[0] = x * (k + alpha/n * sum_of_squares)^-beta
    // pass by value due to icc170 and icc180 problem on KNL
    auto ker = [=](data_t *d, dim_t mb, dim_t oc, dim_t od, dim_t oh,
                       dim_t ow) {
        acc_data_t sum = 0;
        if (across_channels) {
            // Channel window [oc - half_size, oc + half_size] clamped to
            // [0, C).
            const dim_t c_st = nstl::max(oc - half_size + 0, (dim_t)0);
            const dim_t c_en = nstl::min(oc + half_size + 1, C);

            for (dim_t c = c_st; c < c_en; ++c) {
                const acc_data_t s = src[data_off(mb, c, od, oh, ow)];
                sum += s * s;
            }
        } else {
            // Spatial window centered at (od, oh, ow), clamped to the tensor
            // borders.
            dim_t d_st = nstl::max(od - half_size + 0, (dim_t)0);
            dim_t d_en = nstl::min(od + half_size + 1, D);
            dim_t h_st = nstl::max(oh - half_size + 0, (dim_t)0);
            dim_t h_en = nstl::min(oh + half_size + 1, H);
            dim_t w_st = nstl::max(ow - half_size + 0, (dim_t)0);
            dim_t w_en = nstl::min(ow + half_size + 1, W);
            for_(dim_t d = d_st; d < d_en; ++d)
            for_(dim_t h = h_st; h < h_en; ++h)
            for (dim_t w = w_st; w < w_en; ++w) {
                const acc_data_t s = src[data_off(mb, oc, d, h, w)];
                sum += s * s;
            }
        }
        // omega = k + alpha/n * sum. Note: a clamped (border) window sums
        // fewer points but is still divided by the full `summands`.
        sum = k + alpha * sum / summands;
        const acc_data_t s = src[data_off(mb, oc, od, oh, ow)];
        d[0] = static_cast<data_t>(s * fast_negative_powf(sum, beta));
    };

    const dim_t MB = pd()->MB();
    if (tag == nChw16c || tag == nChw8c) {
        // Blocked layouts: parallelize over channel blocks and vectorize over
        // the channels within a block; the bound on `cc` guards the tail
        // block when C is not a multiple of blksize.
        parallel_nd(MB, utils::div_up(C, blksize), H, W,
                [&](dim_t mb, dim_t c_blk, dim_t h, dim_t w) {
                    dim_t c = c_blk * blksize;
                    const dim_t off = mb * stride_mb + c * H * W
                            + (h * W + w) * blksize;
                    PRAGMA_OMP_SIMD()
                    for (dim_t cc = 0; cc < nstl::min(blksize, C - c); ++cc)
                        ker(&dst[off + cc], mb, c + cc, 0, h, w);
                });
    } else if (tag == nhwc) {
        parallel_nd(MB, H, W, C, [&](dim_t mb, dim_t h, dim_t w, dim_t c) {
            const dim_t off = mb * stride_mb + h * W * C + w * C + c;
            ker(&dst[off], mb, c, 0, h, w);
        });
    } else {
        // nchw and any other (generic) layout.
        parallel_nd(MB, C, D, H, W,
                [&](dim_t mb, dim_t c, dim_t d, dim_t h, dim_t w) {
                    const dim_t off = data_off(mb, c, d, h, w);
                    ker(&dst[off], mb, c, d, h, w);
                });
    }
    return status::success;
}
171
172// Backward LRN formula (refer to Forward LRN formula):
173// Partial derivatives:
174// dy_i/dx_j = - 2*a*b/n * x_i * O(i)^-b / O(i) * x_j, i != j
175// O(i)^-b - 2*a*b/n * x_i * O(i)^-b / O(i) * x_j, i == j, where
176// O(i) = (k + a / n * Sum:j [x_j^2]), j in [i - n/2, i + n/2]. Note: j depends
177// on i, which means that O(i) may use more points than local_size.
178// Now, z_i = Sum:k [dE/dy_k * dy_k/dx_j], where k in [i - n/2, i + n/2]
179// for ACROSS. 2d-shape for WITHIN.
180// Then, dE/dy_k = diffDst_k. Finally,
181// z_i = Sum:k [dd_k * dy_k/dx_j] = A - B (code variables) =
182// = dd_i * O(i)^-b - 2*a*b/n * x_i * Sum:k {O(k)^-b / O(k) * x_k * dd_k};
183
template <impl::data_type_t d_type>
template <dnnl_format_tag_t tag>
status_t ref_lrn_bwd_t<d_type>::execute_backward(const exec_ctx_t &ctx) const {
    using namespace alg_kind;
    using namespace format_tag;

    status_t status = status::success;

    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
    CHECK(status);

    const memory_desc_wrapper data_d(pd()->src_md());

    const dim_t C = pd()->C();
    const dim_t D = pd()->D();
    const dim_t H = pd()->H();
    const dim_t W = pd()->W();
    // Minibatch stride comes from the actual memory descriptor so padded /
    // non-dense layouts are addressed correctly.
    const auto stride_mb = data_d.blocking_desc().strides[0];
    const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels;
    // Channel block size for the blocked layouts; unused for plain tags.
    static constexpr dim_t blksize = tag == nChw16c ? 16 : 8;
    const auto ndims = data_d.ndims();

    // Number of summands `n` in the formula: `size` for ACROSS_CHANNELS,
    // size^(ndims - 2) (one factor per spatial dimension) for WITHIN_CHANNEL.
    auto compute_n_summands = [&](dim_t size) {
        if (across_channels) {
            return size;
        } else { // within_channel
            dim_t n_summands = 1;
            for (auto d = ndims - 2; d > 0; --d)
                n_summands *= size;
            return n_summands;
        }
    };

    const acc_data_t alpha = static_cast<acc_data_t>(pd()->desc()->lrn_alpha);
    const acc_data_t beta = static_cast<acc_data_t>(pd()->desc()->lrn_beta);
    const acc_data_t k = static_cast<acc_data_t>(pd()->desc()->lrn_k);
    const dim_t size = pd()->desc()->local_size;
    const dim_t half_size = (size - 1) / 2;
    const dim_t summands = compute_n_summands(size);

    // Physical offset of point (mb, c, d, h, w): hand-computed for the common
    // 4D layouts, generic memory_desc_wrapper::off() for everything else.
    auto data_off = [&](dim_t mb, dim_t c, dim_t d, dim_t h, dim_t w) -> dim_t {
        switch (tag) {
            case nChw16c:
            case nChw8c:
                return mb * stride_mb + (c / blksize) * H * W * blksize
                        + h * W * blksize + w * blksize + c % blksize;
            case nchw: return mb * stride_mb + c * H * W + h * W + w;
            case nhwc: return mb * stride_mb + h * W * C + w * C + c;
            default:
                if (ndims >= 5) return data_d.off(mb, c, d, h, w);
                if (ndims >= 4) return data_d.off(mb, c, h, w);
                if (ndims >= 3) return data_d.off(mb, c, w);
                return data_d.off(mb, c);
        }
    };

    // O(i) from the formula above: k + alpha/n * sum of squared sources over
    // the window centered at the given point (recomputed on every call).
    // pass by value due to icc170 and icc180 problem on KNL
    auto get_omega = [=](dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
        acc_data_t sum = 0;
        if (across_channels) {
            // Channel window clamped to [0, C).
            const dim_t c_st = nstl::max(oc - half_size + 0, (dim_t)0);
            const dim_t c_en = nstl::min(oc + half_size + 1, C);

            for (dim_t c = c_st; c < c_en; ++c) {
                const acc_data_t s = src[data_off(mb, c, od, oh, ow)];
                sum += s * s;
            }
        } else {
            // Spatial window clamped to the tensor borders.
            dim_t d_st = nstl::max(od - half_size + 0, (dim_t)0);
            dim_t d_en = nstl::min(od + half_size + 1, D);
            dim_t h_st = nstl::max(oh - half_size + 0, (dim_t)0);
            dim_t h_en = nstl::min(oh + half_size + 1, H);
            dim_t w_st = nstl::max(ow - half_size + 0, (dim_t)0);
            dim_t w_en = nstl::min(ow + half_size + 1, W);
            for_(dim_t d = d_st; d < d_en; ++d)
            for_(dim_t h = h_st; h < h_en; ++h)
            for (dim_t w = w_st; w < w_en; ++w) {
                const acc_data_t s = src[data_off(mb, oc, d, h, w)];
                sum += s * s;
            }
        }
        return (acc_data_t)(k + alpha * sum / summands);
    };

    // Computes one diff_src point as A - B (see the formula above):
    //   A = dd_i * O(i)^-b                      (the i == j contribution),
    //   B = 2*a*b/n * x_i * Sum:k {O(k)^-b / O(k) * x_k * dd_k}.
    // The loop walks every neighbor k whose window contains point i, so each
    // call recomputes omega for each of those neighbors via get_omega().
    // pass by value due to icc170 and icc180 problem on KNL
    auto ker = [=](data_t *d, dim_t mb, dim_t oc, dim_t od, dim_t oh,
                       dim_t ow) {
        acc_data_t A = 0, B = 0;
        if (across_channels) {
            const dim_t c_st = nstl::max(oc - half_size + 0, (dim_t)0);
            const dim_t c_en = nstl::min(oc + half_size + 1, C);

            for (dim_t c = c_st; c < c_en; c++) {
                const auto off = data_off(mb, c, od, oh, ow);
                const acc_data_t omega = get_omega(mb, c, od, oh, ow);
                const acc_data_t omega_in_beta
                        = fast_negative_powf(omega, beta);
                const acc_data_t tmp
                        = omega_in_beta * (acc_data_t)diff_dst[off];
                // The diagonal term contributes A; every term (including the
                // diagonal) contributes to B.
                if (c == oc) A = tmp;
                B += (src[off] * tmp / omega);
            }
        } else {
            dim_t d_st = nstl::max(od - half_size + 0, (dim_t)0);
            dim_t d_en = nstl::min(od + half_size + 1, D);
            dim_t h_st = nstl::max(oh - half_size + 0, (dim_t)0);
            dim_t h_en = nstl::min(oh + half_size + 1, H);
            dim_t w_st = nstl::max(ow - half_size + 0, (dim_t)0);
            dim_t w_en = nstl::min(ow + half_size + 1, W);
            for_(dim_t d = d_st; d < d_en; ++d)
            for_(dim_t h = h_st; h < h_en; ++h)
            for (dim_t w = w_st; w < w_en; ++w) {
                const auto off = data_off(mb, oc, d, h, w);
                const acc_data_t omega = get_omega(mb, oc, d, h, w);
                const acc_data_t omega_in_beta
                        = fast_negative_powf(omega, beta);
                const acc_data_t tmp
                        = omega_in_beta * (acc_data_t)diff_dst[off];
                if (d == od && h == oh && w == ow) A = tmp;
                B += (src[off] * tmp / omega);
            }
        }
        const auto off = data_off(mb, oc, od, oh, ow);
        B *= (2.0f * alpha * beta * src[off] / summands);
        *d = static_cast<data_t>(A - B);
    };

    const dim_t MB = pd()->MB();
    if (tag == nChw16c || tag == nChw8c) {
        // Blocked layouts: parallelize over channel blocks and vectorize over
        // the channels within a block; the bound on `cc` guards the tail
        // block when C is not a multiple of blksize.
        parallel_nd(MB, utils::div_up(C, blksize), H, W,
                [&](dim_t mb, dim_t c_blk, dim_t h, dim_t w) {
                    dim_t c = c_blk * blksize;
                    const dim_t off = mb * stride_mb + c * H * W
                            + (h * W + w) * blksize;
                    PRAGMA_OMP_SIMD()
                    for (dim_t cc = 0; cc < nstl::min(blksize, C - c); ++cc)
                        ker(&diff_src[off + cc], mb, c + cc, 0, h, w);
                });
    } else if (tag == nhwc) {
        parallel_nd(MB, H, W, C, [&](dim_t mb, dim_t h, dim_t w, dim_t c) {
            const dim_t off = mb * stride_mb + h * W * C + w * C + c;
            ker(&diff_src[off], mb, c, 0, h, w);
        });
    } else {
        // nchw and any other (generic) layout.
        parallel_nd(MB, C, D, H, W,
                [&](dim_t mb, dim_t c, dim_t d, dim_t h, dim_t w) {
                    const dim_t off = data_off(mb, c, d, h, w);
                    ker(&diff_src[off], mb, c, d, h, w);
                });
    }
    return status::success;
}
338
// Explicit instantiations: the member-function templates above are defined
// only in this translation unit, so every supported combination of data type
// (f32, bf16, f16) and format tag is instantiated here.

// f32: forward
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f32>::execute_forward<format_tag::any>(
        const exec_ctx_t &ctx) const;
// f32: backward
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f32>::execute_backward<format_tag::any>(
        const exec_ctx_t &ctx) const;

// bf16: forward
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::bf16>::execute_forward<format_tag::any>(
        const exec_ctx_t &ctx) const;
// bf16: backward
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::bf16>::execute_backward<format_tag::any>(
        const exec_ctx_t &ctx) const;

// f16: forward
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_fwd_t<data_type::f16>::execute_forward<format_tag::any>(
        const exec_ctx_t &ctx) const;
// f16: backward
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::nChw16c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::nChw8c>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::nchw>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::nhwc>(
        const exec_ctx_t &ctx) const;
template status_t
ref_lrn_bwd_t<data_type::f16>::execute_backward<format_tag::any>(
        const exec_ctx_t &ctx) const;
431
432} // namespace cpu
433} // namespace impl
434} // namespace dnnl
435
436// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
437