1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils/parallel.hpp"
18
19#include "lnorm/lnorm.hpp"
20
21namespace lnorm {
22
23void compute_ref_fwd(const prb_t *prb, const args_t &args) {
24 const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
25 const dnn_mem_t &mean = args.find(DNNL_ARG_MEAN);
26 const dnn_mem_t &var = args.find(DNNL_ARG_VARIANCE);
27 const dnn_mem_t &sc = args.find(DNNL_ARG_SCALE);
28 const dnn_mem_t &sh = args.find(DNNL_ARG_SHIFT);
29 const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
30 const dnn_mem_t &src_scale = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
31 const dnn_mem_t &dst_scale = args.find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
32
33 float *dst_ptr = (float *)dst;
34
35 const bool use_sc = prb->use_sc();
36 const bool use_sh = prb->use_sh();
37
38 assert(src_scale.nelems() == 1 && dst_scale.nelems() == 1);
39 const float output_scale = src_scale.get_elem(0) / dst_scale.get_elem(0);
40
41 benchdnn_parallel_nd(prb->n, [&](int64_t n) {
42 float smean = mean.get_elem(n);
43 float svar = var.get_elem(n);
44 float sqrt_var = sqrtf(svar + prb->eps);
45
46 for (int64_t c = 0; c < prb->c; ++c) {
47 float gamma = (use_sc ? sc.get_elem(c) : 1.0f) / sqrt_var;
48 float beta = use_sh ? sh.get_elem(c) : 0;
49 auto off = n * prb->c + c;
50 float res = gamma * (src.get_elem(off) - smean) + beta;
51 dst_ptr[off] = res * output_scale;
52 }
53 });
54}
55
56void compute_ref_bwd(const prb_t *prb, const args_t &args) {
57 const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
58 const dnn_mem_t &mean = args.find(DNNL_ARG_MEAN);
59 const dnn_mem_t &var = args.find(DNNL_ARG_VARIANCE);
60 const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST);
61 const dnn_mem_t &sc = args.find(DNNL_ARG_SCALE);
62 const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC);
63 const dnn_mem_t &d_sc = args.find(DNNL_ARG_DIFF_SCALE);
64 const dnn_mem_t &d_sh = args.find(DNNL_ARG_DIFF_SHIFT);
65
66 float *d_src_ptr = (float *)d_src;
67 float *d_sc_ptr = (float *)d_sc;
68 float *d_sh_ptr = (float *)d_sh;
69
70 const bool use_sc = prb->use_sc();
71 const bool use_sh = prb->use_sh();
72
73 if ((use_sc || use_sh) && (prb->dir & FLAG_WEI)) {
74 benchdnn_parallel_nd(prb->c, [&](int64_t c) {
75 float d_gamma = 0;
76 float d_beta = 0;
77
78 for (int64_t n = 0; n < prb->n; ++n) {
79 float smean = mean.get_elem(n);
80 float svar = var.get_elem(n);
81 float rcp_denom = 1.f / sqrtf(svar + prb->eps);
82 auto off = n * prb->c + c;
83 float dd = d_dst.get_elem(off);
84 d_gamma += dd * (src.get_elem(off) - smean) * rcp_denom;
85 d_beta += dd;
86 }
87
88 if (use_sc) d_sc_ptr[c] = d_gamma;
89 if (use_sh) d_sh_ptr[c] = d_beta;
90 });
91 }
92
93 benchdnn_parallel_nd(prb->n, [&](int64_t n) {
94 float smean = mean.get_elem(n);
95 float svar = var.get_elem(n);
96 float rcp_denom = 1.f / sqrtf(svar + prb->eps);
97 float dd_gamma = 0, dd_gamma_x = 0;
98 if (!(prb->flags & GLOB_STATS)) {
99 for (int64_t c = 0; c < prb->c; ++c) {
100 auto off = n * prb->c + c;
101 float ds = d_dst.get_elem(off);
102 const float x = src.get_elem(off) - smean;
103 float gamma = use_sc ? sc.get_elem(c) : 1;
104 dd_gamma += gamma * ds;
105 dd_gamma_x += gamma * ds * x;
106 }
107 dd_gamma_x *= rcp_denom;
108 }
109 for (int64_t c = 0; c < prb->c; ++c) {
110 float gamma = use_sc ? sc.get_elem(c) : 1;
111 auto off = n * prb->c + c;
112 float ds = d_dst.get_elem(off) * gamma;
113 if (!(prb->flags & GLOB_STATS)) {
114 const float x = src.get_elem(off) - smean;
115 ds -= (dd_gamma + x * dd_gamma_x * rcp_denom) / prb->c;
116 }
117
118 d_src_ptr[off] = rcp_denom * ds;
119 }
120 });
121}
122
123void compute_ref(
124 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
125 if (prb->dir & FLAG_FWD)
126 compute_ref_fwd(prb, args);
127 else
128 compute_ref_bwd(prb, args);
129}
130
131} // namespace lnorm
132