1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "utils/parallel.hpp"
18
19#include "pool/pool.hpp"
20
21namespace pool {
22
23void compute_ref_fwd(const prb_t *prb, const args_t &args) {
24 const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
25 const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
26 const dnn_mem_t &ws = args.find(DNNL_ARG_WORKSPACE);
27
28 float *dst_ptr = (float *)dst;
29
30 auto v_po_masks = prb->attr.post_ops.get_po_masks();
31 auto ker = [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) {
32 const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
33 const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
34 const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
35 const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
36 const int64_t DD = prb->dd, DH = prb->dh, DW = prb->dw;
37
38 // XXX: this is a hack to let tests with padded area to pass for bf16
39 // dt due to the library initialize values with -max_dt, but not -INF.
40 float max_value = lowest_dt(prb->cfg[DST].dt);
41 float avg_value = 0.;
42 // Set initial value based on ws data type
43 int ws_off = prb->kernel_size() <= UINT8_MAX ? UINT8_MAX : INT_MAX;
44
45 for (int64_t kd = 0; kd < KD; ++kd) {
46 const int64_t id = od * SD - PD + kd * (DD + 1);
47 if (id < 0 || id >= ID) continue;
48 for (int64_t kh = 0; kh < KH; ++kh) {
49 const int64_t ih = oh * SH - PH + kh * (DH + 1);
50 if (ih < 0 || ih >= IH) continue;
51 for (int64_t kw = 0; kw < KW; ++kw) {
52 const int64_t iw = ow * SW - PW + kw * (DW + 1);
53 if (iw < 0 || iw >= IW) continue;
54
55 float s = src.get_elem(src_off_f(prb, mb, ic, id, ih, iw));
56 if (s > max_value) {
57 max_value = s;
58 ws_off = ker_off_f(prb, kd, kh, kw);
59 }
60 avg_value += s;
61 }
62 }
63 }
64
65 const auto dst_off = dst_off_f(prb, mb, ic, od, oh, ow);
66 float res = 0.f;
67 if (prb->alg == max) {
68 res = max_value;
69 if (!(prb->dir & FLAG_INF)) ws.set_elem(dst_off, ws_off);
70 } else if (prb->alg == avg_np || prb->alg == avg_p) {
71 res = avg_value / get_num_summands(prb, od, oh, ow);
72 }
73
74 const auto v_po_vals = prepare_po_vals(dst, args, v_po_masks, dst_off);
75
76 maybe_post_ops(prb->attr, res, 0.f, v_po_vals);
77 dst_ptr[dst_off] = res;
78 };
79
80 benchdnn_parallel_nd(prb->mb, prb->ic, prb->od, prb->oh, prb->ow,
81 [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) {
82 ker(mb, ic, od, oh, ow);
83 });
84}
85
86void compute_ref_bwd(const prb_t *prb, const args_t &args) {
87 const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST);
88 const dnn_mem_t &ws = args.find(DNNL_ARG_WORKSPACE);
89 const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC);
90
91 float *d_src_ptr = (float *)d_src;
92
93 auto zero_d_src = [&](int64_t mb, int64_t ic) {
94 for_(int64_t id = 0; id < prb->id; ++id)
95 for_(int64_t ih = 0; ih < prb->ih; ++ih)
96 for (int64_t iw = 0; iw < prb->iw; ++iw)
97 d_src_ptr[src_off_f(prb, mb, ic, id, ih, iw)] = 0.f;
98 };
99
100 auto ker = [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) {
101 const auto d_dst_off = dst_off_f(prb, mb, ic, od, oh, ow);
102 float d_dst_val = d_dst.get_elem(d_dst_off);
103 int ws_off = (prb->alg == max) ? ws.get_elem(d_dst_off) : 0;
104
105 const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
106 const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
107 const int64_t DD = prb->dd, DH = prb->dh, DW = prb->dw;
108 const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
109 const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
110
111 for (int64_t kd = 0; kd < KD; ++kd) {
112 const int64_t id = od * SD - PD + kd * (DD + 1);
113 if (id < 0 || id >= ID) continue;
114 for (int64_t kh = 0; kh < KH; ++kh) {
115 const int64_t ih = oh * SH - PH + kh * (DH + 1);
116 if (ih < 0 || ih >= IH) continue;
117 for (int64_t kw = 0; kw < KW; ++kw) {
118 const int64_t iw = ow * SW - PW + kw * (DW + 1);
119 if (iw < 0 || iw >= IW) continue;
120
121 float &S = d_src_ptr[src_off_f(prb, mb, ic, id, ih, iw)];
122 if (prb->alg == max) {
123 if (ws_off == ker_off_f(prb, kd, kh, kw))
124 S += d_dst_val;
125 } else if (prb->alg == avg_np || prb->alg == avg_p)
126 S += d_dst_val / get_num_summands(prb, od, oh, ow);
127 }
128 }
129 }
130 };
131
132 benchdnn_parallel_nd(prb->mb, prb->ic, [&](int64_t mb, int64_t ic) {
133 zero_d_src(mb, ic);
134 for_(int64_t od = 0; od < prb->od; ++od)
135 for_(int64_t oh = 0; oh < prb->oh; ++oh)
136 for (int64_t ow = 0; ow < prb->ow; ++ow)
137 ker(mb, ic, od, oh, ow);
138 });
139}
140
141void compute_ref(
142 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
143 compute_ref_fwd(prb, args);
144 if (prb->dir & FLAG_BWD) compute_ref_bwd(prb, args);
145}
146
147} // namespace pool
148