1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "pool/pool.hpp" |
20 | |
21 | namespace pool { |
22 | |
23 | void compute_ref_fwd(const prb_t *prb, const args_t &args) { |
24 | const dnn_mem_t &src = args.find(DNNL_ARG_SRC); |
25 | const dnn_mem_t &dst = args.find(DNNL_ARG_DST); |
26 | const dnn_mem_t &ws = args.find(DNNL_ARG_WORKSPACE); |
27 | |
28 | float *dst_ptr = (float *)dst; |
29 | |
30 | auto v_po_masks = prb->attr.post_ops.get_po_masks(); |
31 | auto ker = [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) { |
32 | const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw; |
33 | const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw; |
34 | const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw; |
35 | const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw; |
36 | const int64_t DD = prb->dd, DH = prb->dh, DW = prb->dw; |
37 | |
38 | // XXX: this is a hack to let tests with padded area to pass for bf16 |
39 | // dt due to the library initialize values with -max_dt, but not -INF. |
40 | float max_value = lowest_dt(prb->cfg[DST].dt); |
41 | float avg_value = 0.; |
42 | // Set initial value based on ws data type |
43 | int ws_off = prb->kernel_size() <= UINT8_MAX ? UINT8_MAX : INT_MAX; |
44 | |
45 | for (int64_t kd = 0; kd < KD; ++kd) { |
46 | const int64_t id = od * SD - PD + kd * (DD + 1); |
47 | if (id < 0 || id >= ID) continue; |
48 | for (int64_t kh = 0; kh < KH; ++kh) { |
49 | const int64_t ih = oh * SH - PH + kh * (DH + 1); |
50 | if (ih < 0 || ih >= IH) continue; |
51 | for (int64_t kw = 0; kw < KW; ++kw) { |
52 | const int64_t iw = ow * SW - PW + kw * (DW + 1); |
53 | if (iw < 0 || iw >= IW) continue; |
54 | |
55 | float s = src.get_elem(src_off_f(prb, mb, ic, id, ih, iw)); |
56 | if (s > max_value) { |
57 | max_value = s; |
58 | ws_off = ker_off_f(prb, kd, kh, kw); |
59 | } |
60 | avg_value += s; |
61 | } |
62 | } |
63 | } |
64 | |
65 | const auto dst_off = dst_off_f(prb, mb, ic, od, oh, ow); |
66 | float res = 0.f; |
67 | if (prb->alg == max) { |
68 | res = max_value; |
69 | if (!(prb->dir & FLAG_INF)) ws.set_elem(dst_off, ws_off); |
70 | } else if (prb->alg == avg_np || prb->alg == avg_p) { |
71 | res = avg_value / get_num_summands(prb, od, oh, ow); |
72 | } |
73 | |
74 | const auto v_po_vals = prepare_po_vals(dst, args, v_po_masks, dst_off); |
75 | |
76 | maybe_post_ops(prb->attr, res, 0.f, v_po_vals); |
77 | dst_ptr[dst_off] = res; |
78 | }; |
79 | |
80 | benchdnn_parallel_nd(prb->mb, prb->ic, prb->od, prb->oh, prb->ow, |
81 | [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) { |
82 | ker(mb, ic, od, oh, ow); |
83 | }); |
84 | } |
85 | |
86 | void compute_ref_bwd(const prb_t *prb, const args_t &args) { |
87 | const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST); |
88 | const dnn_mem_t &ws = args.find(DNNL_ARG_WORKSPACE); |
89 | const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC); |
90 | |
91 | float *d_src_ptr = (float *)d_src; |
92 | |
93 | auto zero_d_src = [&](int64_t mb, int64_t ic) { |
94 | for_(int64_t id = 0; id < prb->id; ++id) |
95 | for_(int64_t ih = 0; ih < prb->ih; ++ih) |
96 | for (int64_t iw = 0; iw < prb->iw; ++iw) |
97 | d_src_ptr[src_off_f(prb, mb, ic, id, ih, iw)] = 0.f; |
98 | }; |
99 | |
100 | auto ker = [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) { |
101 | const auto d_dst_off = dst_off_f(prb, mb, ic, od, oh, ow); |
102 | float d_dst_val = d_dst.get_elem(d_dst_off); |
103 | int ws_off = (prb->alg == max) ? ws.get_elem(d_dst_off) : 0; |
104 | |
105 | const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw; |
106 | const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw; |
107 | const int64_t DD = prb->dd, DH = prb->dh, DW = prb->dw; |
108 | const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw; |
109 | const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw; |
110 | |
111 | for (int64_t kd = 0; kd < KD; ++kd) { |
112 | const int64_t id = od * SD - PD + kd * (DD + 1); |
113 | if (id < 0 || id >= ID) continue; |
114 | for (int64_t kh = 0; kh < KH; ++kh) { |
115 | const int64_t ih = oh * SH - PH + kh * (DH + 1); |
116 | if (ih < 0 || ih >= IH) continue; |
117 | for (int64_t kw = 0; kw < KW; ++kw) { |
118 | const int64_t iw = ow * SW - PW + kw * (DW + 1); |
119 | if (iw < 0 || iw >= IW) continue; |
120 | |
121 | float &S = d_src_ptr[src_off_f(prb, mb, ic, id, ih, iw)]; |
122 | if (prb->alg == max) { |
123 | if (ws_off == ker_off_f(prb, kd, kh, kw)) |
124 | S += d_dst_val; |
125 | } else if (prb->alg == avg_np || prb->alg == avg_p) |
126 | S += d_dst_val / get_num_summands(prb, od, oh, ow); |
127 | } |
128 | } |
129 | } |
130 | }; |
131 | |
132 | benchdnn_parallel_nd(prb->mb, prb->ic, [&](int64_t mb, int64_t ic) { |
133 | zero_d_src(mb, ic); |
134 | for_(int64_t od = 0; od < prb->od; ++od) |
135 | for_(int64_t oh = 0; oh < prb->oh; ++oh) |
136 | for (int64_t ow = 0; ow < prb->ow; ++ow) |
137 | ker(mb, ic, od, oh, ow); |
138 | }); |
139 | } |
140 | |
141 | void compute_ref( |
142 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
143 | compute_ref_fwd(prb, args); |
144 | if (prb->dir & FLAG_BWD) compute_ref_bwd(prb, args); |
145 | } |
146 | |
147 | } // namespace pool |
148 | |