1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "bnorm/bnorm.hpp" |
20 | |
21 | namespace bnorm { |
22 | |
23 | void compute_ref_fwd(const prb_t *prb, const args_t &args) { |
24 | const dnn_mem_t &src = args.find(DNNL_ARG_SRC); |
25 | const dnn_mem_t &src_add = args.find(DNNL_ARG_SRC_1); |
26 | const dnn_mem_t &mean = args.find(DNNL_ARG_MEAN); |
27 | const dnn_mem_t &var = args.find(DNNL_ARG_VARIANCE); |
28 | const dnn_mem_t &sc = args.find(DNNL_ARG_SCALE); |
29 | const dnn_mem_t &sh = args.find(DNNL_ARG_SHIFT); |
30 | const dnn_mem_t &ws = args.find(DNNL_ARG_WORKSPACE); |
31 | const dnn_mem_t &dst = args.find(DNNL_ARG_DST); |
32 | const dnn_mem_t &src_hat = args.find(DNNL_ARG_DST_1); |
33 | |
34 | uint8_t *ws_ptr = (uint8_t *)ws; |
35 | float *dst_ptr = (float *)dst; |
36 | float *src_hat_ptr = (float *)src_hat; |
37 | |
38 | const int64_t MB = prb->mb; |
39 | const int64_t C = prb->ic; |
40 | const int64_t D = prb->id; |
41 | const int64_t H = prb->ih; |
42 | const int64_t W = prb->iw; |
43 | const bool use_sc = prb->use_sc(); |
44 | const bool use_sh = prb->use_sh(); |
45 | const bool fuse_relu = prb->fuse_relu(); |
46 | const bool fuse_add_relu = prb->fuse_add_relu(); |
47 | const bool need_ws = prb->need_ws(); |
48 | const auto &attr = prb->attr; |
49 | |
50 | benchdnn_parallel_nd(C, [&](int64_t c) { |
51 | float smean = mean.get_elem(c); |
52 | float svar = var.get_elem(c); |
53 | float sqrt_var = sqrtf(svar + prb->eps); |
54 | float rcp_denom = 1.f / sqrt_var; |
55 | float gamma = use_sc ? sc.get_elem(c) : 1.f; |
56 | float beta = use_sh ? sh.get_elem(c) : 0.f; |
57 | |
58 | for_(int64_t mb = 0; mb < MB; ++mb) |
59 | for_(int64_t d = 0; d < D; ++d) |
60 | for_(int64_t h = 0; h < H; ++h) |
61 | for (int64_t w = 0; w < W; ++w) { |
62 | auto off = data_off(prb, mb, c, d, h, w); |
63 | float x_hat = (src.get_elem(off) - smean) * rcp_denom; |
64 | float res = gamma * x_hat + beta; |
65 | if (fuse_add_relu) res += src_add.get_elem(off); |
66 | if (fuse_relu && res < 0) res = 0; |
67 | if (need_ws) ws_ptr[off] = !!res; |
68 | maybe_post_ops(attr, res); |
69 | dst_ptr[off] = res; |
70 | if (prb->dir & FLAG_BWD) src_hat_ptr[off] = x_hat; |
71 | } |
72 | }); |
73 | } |
74 | |
75 | void compute_ref_bwd(const prb_t *prb, const args_t &args) { |
76 | const dnn_mem_t &src_hat = args.find(DNNL_ARG_DST_1); |
77 | const dnn_mem_t &var = args.find(DNNL_ARG_VARIANCE); |
78 | const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST); |
79 | const dnn_mem_t &sc = args.find(DNNL_ARG_SCALE); |
80 | const dnn_mem_t &ws = args.find(DNNL_ARG_WORKSPACE); |
81 | const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC); |
82 | const dnn_mem_t &d_src_add = args.find(DNNL_ARG_DIFF_SRC_1); |
83 | const dnn_mem_t &d_sc = args.find(DNNL_ARG_DIFF_SCALE); |
84 | const dnn_mem_t &d_sh = args.find(DNNL_ARG_DIFF_SHIFT); |
85 | |
86 | float *d_src_ptr = (float *)d_src; |
87 | float *d_sc_ptr = (float *)d_sc; |
88 | float *d_sh_ptr = (float *)d_sh; |
89 | |
90 | const int64_t MB = prb->mb; |
91 | const int64_t C = prb->ic; |
92 | const int64_t D = prb->id; |
93 | const int64_t H = prb->ih; |
94 | const int64_t W = prb->iw; |
95 | const bool glob_stats = prb->flags & GLOB_STATS; |
96 | const bool use_sc = prb->use_sc(); |
97 | const bool use_sh = prb->use_sh(); |
98 | const bool fuse_relu = prb->fuse_relu(); |
99 | const bool fuse_add_relu = prb->fuse_add_relu(); |
100 | |
101 | const float MB_SP = MB * D * H * W; |
102 | |
103 | benchdnn_parallel_nd(C, [&](int64_t c) { |
104 | float rcp_denom = 1.f / sqrtf(var.get_elem(c) + prb->eps); |
105 | float gamma = use_sc ? sc.get_elem(c) : 1.f; |
106 | |
107 | float d_gamma = 0; |
108 | float d_beta = 0; |
109 | |
110 | for_(int64_t mb = 0; mb < MB; ++mb) |
111 | for_(int64_t d = 0; d < D; ++d) |
112 | for_(int64_t h = 0; h < H; ++h) |
113 | for (int64_t w = 0; w < W; ++w) { |
114 | auto off = data_off(prb, mb, c, d, h, w); |
115 | float dd = d_dst.get_elem(off); |
116 | if (fuse_relu && ws.get_elem(off) == 0) dd = 0; |
117 | d_gamma += dd * src_hat.get_elem(off); |
118 | d_beta += dd; |
119 | } |
120 | |
121 | if (use_sc && (prb->dir & FLAG_WEI)) d_sc_ptr[c] = d_gamma; |
122 | if (use_sh && (prb->dir & FLAG_WEI)) d_sh_ptr[c] = d_beta; |
123 | |
124 | for_(int64_t mb = 0; mb < MB; ++mb) |
125 | for_(int64_t d = 0; d < D; ++d) |
126 | for_(int64_t h = 0; h < H; ++h) |
127 | for (int64_t w = 0; w < W; ++w) { |
128 | auto off = data_off(prb, mb, c, d, h, w); |
129 | float dd = d_dst.get_elem(off); |
130 | if (fuse_relu && ws.get_elem(off) == 0) dd = 0; |
131 | if (fuse_add_relu) d_src_add.set_elem(off, dd); |
132 | float ds = dd; |
133 | |
134 | if (!glob_stats) |
135 | ds -= (d_beta + src_hat.get_elem(off) * d_gamma) / MB_SP; |
136 | |
137 | d_src_ptr[off] = rcp_denom * ds * gamma; |
138 | } |
139 | }); |
140 | } |
141 | |
142 | void compute_ref( |
143 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
144 | compute_ref_fwd(prb, args); |
145 | if (prb->dir & FLAG_BWD) compute_ref_bwd(prb, args); |
146 | } |
147 | |
148 | } // namespace bnorm |
149 | |