1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils/parallel.hpp"
18
19#include "bnorm/bnorm.hpp"
20
21namespace bnorm {
22
23void compute_ref_fwd(const prb_t *prb, const args_t &args) {
24 const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
25 const dnn_mem_t &src_add = args.find(DNNL_ARG_SRC_1);
26 const dnn_mem_t &mean = args.find(DNNL_ARG_MEAN);
27 const dnn_mem_t &var = args.find(DNNL_ARG_VARIANCE);
28 const dnn_mem_t &sc = args.find(DNNL_ARG_SCALE);
29 const dnn_mem_t &sh = args.find(DNNL_ARG_SHIFT);
30 const dnn_mem_t &ws = args.find(DNNL_ARG_WORKSPACE);
31 const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
32 const dnn_mem_t &src_hat = args.find(DNNL_ARG_DST_1);
33
34 uint8_t *ws_ptr = (uint8_t *)ws;
35 float *dst_ptr = (float *)dst;
36 float *src_hat_ptr = (float *)src_hat;
37
38 const int64_t MB = prb->mb;
39 const int64_t C = prb->ic;
40 const int64_t D = prb->id;
41 const int64_t H = prb->ih;
42 const int64_t W = prb->iw;
43 const bool use_sc = prb->use_sc();
44 const bool use_sh = prb->use_sh();
45 const bool fuse_relu = prb->fuse_relu();
46 const bool fuse_add_relu = prb->fuse_add_relu();
47 const bool need_ws = prb->need_ws();
48 const auto &attr = prb->attr;
49
50 benchdnn_parallel_nd(C, [&](int64_t c) {
51 float smean = mean.get_elem(c);
52 float svar = var.get_elem(c);
53 float sqrt_var = sqrtf(svar + prb->eps);
54 float rcp_denom = 1.f / sqrt_var;
55 float gamma = use_sc ? sc.get_elem(c) : 1.f;
56 float beta = use_sh ? sh.get_elem(c) : 0.f;
57
58 for_(int64_t mb = 0; mb < MB; ++mb)
59 for_(int64_t d = 0; d < D; ++d)
60 for_(int64_t h = 0; h < H; ++h)
61 for (int64_t w = 0; w < W; ++w) {
62 auto off = data_off(prb, mb, c, d, h, w);
63 float x_hat = (src.get_elem(off) - smean) * rcp_denom;
64 float res = gamma * x_hat + beta;
65 if (fuse_add_relu) res += src_add.get_elem(off);
66 if (fuse_relu && res < 0) res = 0;
67 if (need_ws) ws_ptr[off] = !!res;
68 maybe_post_ops(attr, res);
69 dst_ptr[off] = res;
70 if (prb->dir & FLAG_BWD) src_hat_ptr[off] = x_hat;
71 }
72 });
73}
74
75void compute_ref_bwd(const prb_t *prb, const args_t &args) {
76 const dnn_mem_t &src_hat = args.find(DNNL_ARG_DST_1);
77 const dnn_mem_t &var = args.find(DNNL_ARG_VARIANCE);
78 const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST);
79 const dnn_mem_t &sc = args.find(DNNL_ARG_SCALE);
80 const dnn_mem_t &ws = args.find(DNNL_ARG_WORKSPACE);
81 const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC);
82 const dnn_mem_t &d_src_add = args.find(DNNL_ARG_DIFF_SRC_1);
83 const dnn_mem_t &d_sc = args.find(DNNL_ARG_DIFF_SCALE);
84 const dnn_mem_t &d_sh = args.find(DNNL_ARG_DIFF_SHIFT);
85
86 float *d_src_ptr = (float *)d_src;
87 float *d_sc_ptr = (float *)d_sc;
88 float *d_sh_ptr = (float *)d_sh;
89
90 const int64_t MB = prb->mb;
91 const int64_t C = prb->ic;
92 const int64_t D = prb->id;
93 const int64_t H = prb->ih;
94 const int64_t W = prb->iw;
95 const bool glob_stats = prb->flags & GLOB_STATS;
96 const bool use_sc = prb->use_sc();
97 const bool use_sh = prb->use_sh();
98 const bool fuse_relu = prb->fuse_relu();
99 const bool fuse_add_relu = prb->fuse_add_relu();
100
101 const float MB_SP = MB * D * H * W;
102
103 benchdnn_parallel_nd(C, [&](int64_t c) {
104 float rcp_denom = 1.f / sqrtf(var.get_elem(c) + prb->eps);
105 float gamma = use_sc ? sc.get_elem(c) : 1.f;
106
107 float d_gamma = 0;
108 float d_beta = 0;
109
110 for_(int64_t mb = 0; mb < MB; ++mb)
111 for_(int64_t d = 0; d < D; ++d)
112 for_(int64_t h = 0; h < H; ++h)
113 for (int64_t w = 0; w < W; ++w) {
114 auto off = data_off(prb, mb, c, d, h, w);
115 float dd = d_dst.get_elem(off);
116 if (fuse_relu && ws.get_elem(off) == 0) dd = 0;
117 d_gamma += dd * src_hat.get_elem(off);
118 d_beta += dd;
119 }
120
121 if (use_sc && (prb->dir & FLAG_WEI)) d_sc_ptr[c] = d_gamma;
122 if (use_sh && (prb->dir & FLAG_WEI)) d_sh_ptr[c] = d_beta;
123
124 for_(int64_t mb = 0; mb < MB; ++mb)
125 for_(int64_t d = 0; d < D; ++d)
126 for_(int64_t h = 0; h < H; ++h)
127 for (int64_t w = 0; w < W; ++w) {
128 auto off = data_off(prb, mb, c, d, h, w);
129 float dd = d_dst.get_elem(off);
130 if (fuse_relu && ws.get_elem(off) == 0) dd = 0;
131 if (fuse_add_relu) d_src_add.set_elem(off, dd);
132 float ds = dd;
133
134 if (!glob_stats)
135 ds -= (d_beta + src_hat.get_elem(off) * d_gamma) / MB_SP;
136
137 d_src_ptr[off] = rcp_denom * ds * gamma;
138 }
139 });
140}
141
142void compute_ref(
143 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
144 compute_ref_fwd(prb, args);
145 if (prb->dir & FLAG_BWD) compute_ref_bwd(prb, args);
146}
147
148} // namespace bnorm
149