1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils/parallel.hpp"
18
19#include "lrn/lrn.hpp"
20
21namespace lrn {
22
// Returns omega raised to the power of -beta, i.e. 1 / omega^beta.
// The exponent 0.75 is special-cased: omega^0.75 is evaluated as
// sqrt(omega * sqrt(omega)), which avoids the generic powf() call.
// Every other exponent goes through powf().
float fast_powf(float omega, float beta) {
    const bool use_sqrt_path = (beta == 0.75f);
    const float omega_pow_beta
            = use_sqrt_path ? sqrtf(sqrtf(omega) * omega) : powf(omega, beta);
    return 1.0f / omega_pow_beta;
}
27
28float get_omega(const prb_t *prb, const dnn_mem_t &src, int64_t mb, int64_t c,
29 int64_t d, int64_t h, int64_t w) {
30 const int size = prb->ls;
31 const int half_size = (size - 1) / 2;
32 const int summands = compute_n_summands(prb);
33
34 float sum = 0;
35 if (prb->alg == ACROSS) {
36 const int64_t c_st = MAX2(c - half_size + 0, 0);
37 const int64_t c_en = MIN2(c + half_size + 1, prb->ic);
38
39 for (int64_t cs = c_st; cs < c_en; ++cs) {
40 const auto off = data_off(prb, mb, cs, d, h, w);
41 const float s = src.get_elem(off);
42 sum += s * s;
43 }
44 } else if (prb->alg == WITHIN) {
45 const int64_t d_st = MAX2(d - half_size + 0, 0);
46 const int64_t d_en = MIN2(d + half_size + 1, prb->id);
47 const int64_t h_st = MAX2(h - half_size + 0, 0);
48 const int64_t h_en = MIN2(h + half_size + 1, prb->ih);
49 const int64_t w_st = MAX2(w - half_size + 0, 0);
50 const int64_t w_en = MIN2(w + half_size + 1, prb->iw);
51
52 for_(int64_t ds = d_st; ds < d_en; ++ds)
53 for_(int64_t hs = h_st; hs < h_en; ++hs)
54 for (int64_t ws = w_st; ws < w_en; ++ws) {
55 const auto off = data_off(prb, mb, c, ds, hs, ws);
56 const float s = src.get_elem(off);
57 sum += s * s;
58 }
59 }
60
61 return (float)(prb->k + prb->alpha * sum / summands);
62}
63
64void compute_ref_fwd(const prb_t *prb, const args_t &args) {
65 const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
66 const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
67
68 float *dst_ptr = (float *)dst;
69
70 benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw,
71 [&](int64_t mb, int64_t c, int64_t d, int64_t h, int64_t w) {
72 const auto off = data_off(prb, mb, c, d, h, w);
73 const float omega = get_omega(prb, src, mb, c, d, h, w);
74 const float omega_in_beta = fast_powf(omega, prb->beta);
75 dst_ptr[off] = src.get_elem(off) * omega_in_beta;
76 });
77}
78
79void compute_ref_bwd(const prb_t *prb, const args_t &args) {
80 const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
81 const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST);
82 const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC);
83
84 float *d_src_ptr = (float *)d_src;
85
86 const int size = prb->ls;
87 const int half_size = (size - 1) / 2;
88 const int summands = compute_n_summands(prb);
89
90 benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw,
91 [&](int64_t mb, int64_t c, int64_t d, int64_t h, int64_t w) {
92 float A = 0, B = 0;
93 if (prb->alg == ACROSS) {
94 const int64_t c_st = MAX2(c - half_size + 0, 0);
95 const int64_t c_en = MIN2(c + half_size + 1, prb->ic);
96
97 for (int64_t cs = c_st; cs < c_en; ++cs) {
98 const auto off = data_off(prb, mb, cs, d, h, w);
99 const float omega
100 = get_omega(prb, src, mb, cs, d, h, w);
101 const float omega_in_beta = fast_powf(omega, prb->beta);
102 const float tmp = omega_in_beta * d_dst.get_elem(off);
103 if (cs == c) A = tmp;
104 B += (tmp / omega * src.get_elem(off));
105 }
106 } else if (prb->alg == WITHIN) {
107 const int64_t d_st = MAX2(d - half_size + 0, 0);
108 const int64_t d_en = MIN2(d + half_size + 1, prb->id);
109 const int64_t h_st = MAX2(h - half_size + 0, 0);
110 const int64_t h_en = MIN2(h + half_size + 1, prb->ih);
111 const int64_t w_st = MAX2(w - half_size + 0, 0);
112 const int64_t w_en = MIN2(w + half_size + 1, prb->iw);
113
114 for_(int64_t ds = d_st; ds < d_en; ++ds)
115 for_(int64_t hs = h_st; hs < h_en; ++hs)
116 for (int64_t ws = w_st; ws < w_en; ++ws) {
117 const auto off = data_off(prb, mb, c, ds, hs, ws);
118 const float omega
119 = get_omega(prb, src, mb, c, ds, hs, ws);
120 const float omega_in_beta = fast_powf(omega, prb->beta);
121 const float tmp = omega_in_beta * d_dst.get_elem(off);
122 if (ds == d && hs == h && ws == w) A = tmp;
123 B += (tmp / omega * src.get_elem(off));
124 }
125 }
126 const auto off = data_off(prb, mb, c, d, h, w);
127 B *= (2.0f * prb->alpha * prb->beta * src.get_elem(off)
128 / summands);
129 d_src_ptr[off] = A - B;
130 });
131}
132
133void compute_ref(
134 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
135 if (prb->dir & FLAG_FWD)
136 compute_ref_fwd(prb, args);
137 else
138 compute_ref_bwd(prb, args);
139}
140
141} // namespace lrn
142