1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "lnorm/lnorm.hpp" |
20 | |
21 | namespace lnorm { |
22 | |
23 | void compute_ref_fwd(const prb_t *prb, const args_t &args) { |
24 | const dnn_mem_t &src = args.find(DNNL_ARG_SRC); |
25 | const dnn_mem_t &mean = args.find(DNNL_ARG_MEAN); |
26 | const dnn_mem_t &var = args.find(DNNL_ARG_VARIANCE); |
27 | const dnn_mem_t &sc = args.find(DNNL_ARG_SCALE); |
28 | const dnn_mem_t &sh = args.find(DNNL_ARG_SHIFT); |
29 | const dnn_mem_t &dst = args.find(DNNL_ARG_DST); |
30 | const dnn_mem_t &src_scale = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); |
31 | const dnn_mem_t &dst_scale = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); |
32 | |
33 | float *dst_ptr = (float *)dst; |
34 | |
35 | const bool use_sc = prb->use_sc(); |
36 | const bool use_sh = prb->use_sh(); |
37 | |
38 | assert(src_scale.nelems() == 1 && dst_scale.nelems() == 1); |
39 | const float output_scale = src_scale.get_elem(0) / dst_scale.get_elem(0); |
40 | |
41 | benchdnn_parallel_nd(prb->n, [&](int64_t n) { |
42 | float smean = mean.get_elem(n); |
43 | float svar = var.get_elem(n); |
44 | float sqrt_var = sqrtf(svar + prb->eps); |
45 | |
46 | for (int64_t c = 0; c < prb->c; ++c) { |
47 | float gamma = (use_sc ? sc.get_elem(c) : 1.0f) / sqrt_var; |
48 | float beta = use_sh ? sh.get_elem(c) : 0; |
49 | auto off = n * prb->c + c; |
50 | float res = gamma * (src.get_elem(off) - smean) + beta; |
51 | dst_ptr[off] = res * output_scale; |
52 | } |
53 | }); |
54 | } |
55 | |
56 | void compute_ref_bwd(const prb_t *prb, const args_t &args) { |
57 | const dnn_mem_t &src = args.find(DNNL_ARG_SRC); |
58 | const dnn_mem_t &mean = args.find(DNNL_ARG_MEAN); |
59 | const dnn_mem_t &var = args.find(DNNL_ARG_VARIANCE); |
60 | const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST); |
61 | const dnn_mem_t &sc = args.find(DNNL_ARG_SCALE); |
62 | const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC); |
63 | const dnn_mem_t &d_sc = args.find(DNNL_ARG_DIFF_SCALE); |
64 | const dnn_mem_t &d_sh = args.find(DNNL_ARG_DIFF_SHIFT); |
65 | |
66 | float *d_src_ptr = (float *)d_src; |
67 | float *d_sc_ptr = (float *)d_sc; |
68 | float *d_sh_ptr = (float *)d_sh; |
69 | |
70 | const bool use_sc = prb->use_sc(); |
71 | const bool use_sh = prb->use_sh(); |
72 | |
73 | if ((use_sc || use_sh) && (prb->dir & FLAG_WEI)) { |
74 | benchdnn_parallel_nd(prb->c, [&](int64_t c) { |
75 | float d_gamma = 0; |
76 | float d_beta = 0; |
77 | |
78 | for (int64_t n = 0; n < prb->n; ++n) { |
79 | float smean = mean.get_elem(n); |
80 | float svar = var.get_elem(n); |
81 | float rcp_denom = 1.f / sqrtf(svar + prb->eps); |
82 | auto off = n * prb->c + c; |
83 | float dd = d_dst.get_elem(off); |
84 | d_gamma += dd * (src.get_elem(off) - smean) * rcp_denom; |
85 | d_beta += dd; |
86 | } |
87 | |
88 | if (use_sc) d_sc_ptr[c] = d_gamma; |
89 | if (use_sh) d_sh_ptr[c] = d_beta; |
90 | }); |
91 | } |
92 | |
93 | benchdnn_parallel_nd(prb->n, [&](int64_t n) { |
94 | float smean = mean.get_elem(n); |
95 | float svar = var.get_elem(n); |
96 | float rcp_denom = 1.f / sqrtf(svar + prb->eps); |
97 | float dd_gamma = 0, dd_gamma_x = 0; |
98 | if (!(prb->flags & GLOB_STATS)) { |
99 | for (int64_t c = 0; c < prb->c; ++c) { |
100 | auto off = n * prb->c + c; |
101 | float ds = d_dst.get_elem(off); |
102 | const float x = src.get_elem(off) - smean; |
103 | float gamma = use_sc ? sc.get_elem(c) : 1; |
104 | dd_gamma += gamma * ds; |
105 | dd_gamma_x += gamma * ds * x; |
106 | } |
107 | dd_gamma_x *= rcp_denom; |
108 | } |
109 | for (int64_t c = 0; c < prb->c; ++c) { |
110 | float gamma = use_sc ? sc.get_elem(c) : 1; |
111 | auto off = n * prb->c + c; |
112 | float ds = d_dst.get_elem(off) * gamma; |
113 | if (!(prb->flags & GLOB_STATS)) { |
114 | const float x = src.get_elem(off) - smean; |
115 | ds -= (dd_gamma + x * dd_gamma_x * rcp_denom) / prb->c; |
116 | } |
117 | |
118 | d_src_ptr[off] = rcp_denom * ds; |
119 | } |
120 | }); |
121 | } |
122 | |
123 | void compute_ref( |
124 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
125 | if (prb->dir & FLAG_FWD) |
126 | compute_ref_fwd(prb, args); |
127 | else |
128 | compute_ref_bwd(prb, args); |
129 | } |
130 | |
131 | } // namespace lnorm |
132 | |