1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "lrn/lrn.hpp" |
20 | |
21 | namespace lrn { |
22 | |
23 | float fast_powf(float omega, float beta) { |
24 | if (beta == 0.75f) return 1.0f / sqrtf(sqrtf(omega) * omega); |
25 | return 1.0f / powf(omega, beta); |
26 | } |
27 | |
28 | float get_omega(const prb_t *prb, const dnn_mem_t &src, int64_t mb, int64_t c, |
29 | int64_t d, int64_t h, int64_t w) { |
30 | const int size = prb->ls; |
31 | const int half_size = (size - 1) / 2; |
32 | const int summands = compute_n_summands(prb); |
33 | |
34 | float sum = 0; |
35 | if (prb->alg == ACROSS) { |
36 | const int64_t c_st = MAX2(c - half_size + 0, 0); |
37 | const int64_t c_en = MIN2(c + half_size + 1, prb->ic); |
38 | |
39 | for (int64_t cs = c_st; cs < c_en; ++cs) { |
40 | const auto off = data_off(prb, mb, cs, d, h, w); |
41 | const float s = src.get_elem(off); |
42 | sum += s * s; |
43 | } |
44 | } else if (prb->alg == WITHIN) { |
45 | const int64_t d_st = MAX2(d - half_size + 0, 0); |
46 | const int64_t d_en = MIN2(d + half_size + 1, prb->id); |
47 | const int64_t h_st = MAX2(h - half_size + 0, 0); |
48 | const int64_t h_en = MIN2(h + half_size + 1, prb->ih); |
49 | const int64_t w_st = MAX2(w - half_size + 0, 0); |
50 | const int64_t w_en = MIN2(w + half_size + 1, prb->iw); |
51 | |
52 | for_(int64_t ds = d_st; ds < d_en; ++ds) |
53 | for_(int64_t hs = h_st; hs < h_en; ++hs) |
54 | for (int64_t ws = w_st; ws < w_en; ++ws) { |
55 | const auto off = data_off(prb, mb, c, ds, hs, ws); |
56 | const float s = src.get_elem(off); |
57 | sum += s * s; |
58 | } |
59 | } |
60 | |
61 | return (float)(prb->k + prb->alpha * sum / summands); |
62 | } |
63 | |
64 | void compute_ref_fwd(const prb_t *prb, const args_t &args) { |
65 | const dnn_mem_t &src = args.find(DNNL_ARG_SRC); |
66 | const dnn_mem_t &dst = args.find(DNNL_ARG_DST); |
67 | |
68 | float *dst_ptr = (float *)dst; |
69 | |
70 | benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw, |
71 | [&](int64_t mb, int64_t c, int64_t d, int64_t h, int64_t w) { |
72 | const auto off = data_off(prb, mb, c, d, h, w); |
73 | const float omega = get_omega(prb, src, mb, c, d, h, w); |
74 | const float omega_in_beta = fast_powf(omega, prb->beta); |
75 | dst_ptr[off] = src.get_elem(off) * omega_in_beta; |
76 | }); |
77 | } |
78 | |
79 | void compute_ref_bwd(const prb_t *prb, const args_t &args) { |
80 | const dnn_mem_t &src = args.find(DNNL_ARG_SRC); |
81 | const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST); |
82 | const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC); |
83 | |
84 | float *d_src_ptr = (float *)d_src; |
85 | |
86 | const int size = prb->ls; |
87 | const int half_size = (size - 1) / 2; |
88 | const int summands = compute_n_summands(prb); |
89 | |
90 | benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw, |
91 | [&](int64_t mb, int64_t c, int64_t d, int64_t h, int64_t w) { |
92 | float A = 0, B = 0; |
93 | if (prb->alg == ACROSS) { |
94 | const int64_t c_st = MAX2(c - half_size + 0, 0); |
95 | const int64_t c_en = MIN2(c + half_size + 1, prb->ic); |
96 | |
97 | for (int64_t cs = c_st; cs < c_en; ++cs) { |
98 | const auto off = data_off(prb, mb, cs, d, h, w); |
99 | const float omega |
100 | = get_omega(prb, src, mb, cs, d, h, w); |
101 | const float omega_in_beta = fast_powf(omega, prb->beta); |
102 | const float tmp = omega_in_beta * d_dst.get_elem(off); |
103 | if (cs == c) A = tmp; |
104 | B += (tmp / omega * src.get_elem(off)); |
105 | } |
106 | } else if (prb->alg == WITHIN) { |
107 | const int64_t d_st = MAX2(d - half_size + 0, 0); |
108 | const int64_t d_en = MIN2(d + half_size + 1, prb->id); |
109 | const int64_t h_st = MAX2(h - half_size + 0, 0); |
110 | const int64_t h_en = MIN2(h + half_size + 1, prb->ih); |
111 | const int64_t w_st = MAX2(w - half_size + 0, 0); |
112 | const int64_t w_en = MIN2(w + half_size + 1, prb->iw); |
113 | |
114 | for_(int64_t ds = d_st; ds < d_en; ++ds) |
115 | for_(int64_t hs = h_st; hs < h_en; ++hs) |
116 | for (int64_t ws = w_st; ws < w_en; ++ws) { |
117 | const auto off = data_off(prb, mb, c, ds, hs, ws); |
118 | const float omega |
119 | = get_omega(prb, src, mb, c, ds, hs, ws); |
120 | const float omega_in_beta = fast_powf(omega, prb->beta); |
121 | const float tmp = omega_in_beta * d_dst.get_elem(off); |
122 | if (ds == d && hs == h && ws == w) A = tmp; |
123 | B += (tmp / omega * src.get_elem(off)); |
124 | } |
125 | } |
126 | const auto off = data_off(prb, mb, c, d, h, w); |
127 | B *= (2.0f * prb->alpha * prb->beta * src.get_elem(off) |
128 | / summands); |
129 | d_src_ptr[off] = A - B; |
130 | }); |
131 | } |
132 | |
133 | void compute_ref( |
134 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
135 | if (prb->dir & FLAG_FWD) |
136 | compute_ref_fwd(prb, args); |
137 | else |
138 | compute_ref_bwd(prb, args); |
139 | } |
140 | |
141 | } // namespace lrn |
142 | |