1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * Copyright 2022 Arm Ltd. and affiliates |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | *******************************************************************************/ |
17 | |
18 | #include <stdio.h> |
19 | #include <stdlib.h> |
20 | |
21 | #include <sstream> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "utils/parallel.hpp" |
26 | |
27 | #include "dnnl_common.hpp" |
28 | #include "dnnl_memory.hpp" |
29 | |
30 | #include "binary/binary.hpp" |
31 | #include "pool/pool.hpp" |
32 | |
33 | namespace pool { |
34 | |
35 | int fill_dat(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt, |
36 | dnn_mem_t &mem_fp, res_t *res) { |
37 | const int64_t MB {prb->mb}; |
38 | const int64_t IC {prb->ic}; |
39 | const int64_t D {kind == SRC ? prb->id : prb->od}; |
40 | const int64_t H {kind == SRC ? prb->ih : prb->oh}; |
41 | const int64_t W {kind == SRC ? prb->iw : prb->ow}; |
42 | const int64_t ker_size {prb->kd * prb->kh * prb->kw}; |
43 | const auto &c = prb->cfg[kind]; |
44 | // For huge kernels to get different output values filling should be very |
45 | // variative, thus, use a factor of 1. |
46 | const bool has_huge_kernel = ker_size >= c.f_max; |
47 | |
48 | benchdnn_parallel_nd(MB, IC, D, H, W, |
49 | [&](int64_t mb, int64_t ic, int64_t d, int64_t h, int64_t w) { |
50 | const int64_t factor |
51 | = prb->alg == max || has_huge_kernel ? 1 : ker_size; |
52 | // keep values for avg_exclude_pad positive to prevent cancellation err |
53 | const int64_t f_min = prb->alg == max ? c.f_min / factor : 0; |
54 | // divide on factor to keep value in the range |
55 | const int64_t range = c.f_max / factor - f_min + 1; |
56 | const int64_t gen |
57 | = 5 * d + 17 * h + 13 * w + 13 * mb + 19 * ic + 1637; |
58 | const float value = (f_min + gen % range) * factor; |
59 | |
60 | const size_t off = kind == SRC |
61 | ? src_off_f(prb, mb, ic, d, h, w) |
62 | : dst_off_f(prb, mb, ic, d, h, w); |
63 | ((float *)mem_fp)[off] = value; |
64 | }); |
65 | |
66 | SAFE(mem_dt.reorder(mem_fp), WARN); |
67 | |
68 | return OK; |
69 | } |
70 | |
71 | int fill_src( |
72 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
73 | return fill_dat(prb, SRC, mem_dt, mem_fp, res); |
74 | } |
75 | |
76 | int fill_dst( |
77 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
78 | return fill_dat(prb, DST, mem_dt, mem_fp, res); |
79 | } |
80 | |
81 | // fill ws with big numbers to reliably cause a correctness issue (and not |
82 | // anything else) in case of a bug in the library |
83 | int fill_ws( |
84 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
85 | benchdnn_parallel_nd(mem_fp.nelems(), |
86 | [&](int64_t i) { mem_fp.set_elem(i, (1 << 24) - 1); }); |
87 | |
88 | SAFE(mem_dt.reorder(mem_fp), WARN); |
89 | |
90 | return OK; |
91 | } |
92 | |
// Create a pooling primitive descriptor (forward or backward) for the given
// problem. Returns dnnl_success, or propagates the library status via
// DNN_SAFE_STATUS on failure.
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;
    const dir_t dir = init_pd_args.dir;

    // Forward uses the user-requested tag; backward lets the library choose.
    const auto src_tag = (dir & FLAG_FWD) ? prb->tag : tag::any;

    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->src_dims().data(), prb->cfg[SRC].dt, src_tag);
    auto dst_d = dnn_mem_t::init_md(
            prb->ndims, prb->dst_dims().data(), prb->cfg[DST].dt, tag::any);

    // Build attributes, including memory descriptors for binary post-ops.
    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(
            prb->attr, prb->ndims, prb->dst_dims().data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    dnnl_alg_kind_t alg = alg2alg_kind(prb->alg);

    if (dir & FLAG_FWD) {
        // Inference vs. training matters: training produces a workspace.
        auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference
                                             : dnnl_forward_training;
        DNN_SAFE_STATUS(dnnl_pooling_forward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, prop_kind, alg, src_d,
                dst_d, prb->strides().data(), prb->kernel().data(),
                prb->dilations().data(), prb->padding().data(),
                prb->padding_r().data(), dnnl_attr));
    } else {
        // Backward requires the forward pd as a hint.
        DNN_SAFE_STATUS(dnnl_pooling_backward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, alg, src_d, dst_d,
                prb->strides().data(), prb->kernel().data(),
                prb->dilations().data(), prb->padding().data(),
                prb->padding_r().data(), init_pd_args.hint, dnnl_attr));
    }
    return dnnl_success;
}
129 | |
130 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
131 | skip_unimplemented_data_type( |
132 | {prb->cfg[SRC].dt, prb->cfg[DST].dt}, prb->dir, res); |
133 | skip_unimplemented_sum_po(prb->attr, res); |
134 | |
135 | if (is_cpu() && prb->cfg[SRC].dt != prb->cfg[DST].dt) { |
136 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
137 | return; |
138 | } |
139 | |
140 | #if DNNL_AARCH64_USE_ACL |
141 | // Since ACL supports only forward pass. |
142 | // Ref: https://github.com/oneapi-src/oneDNN/issues/1205 |
143 | if (prb->dir & FLAG_BWD) { |
144 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
145 | return; |
146 | } |
147 | #endif |
148 | } |
149 | |
150 | void skip_invalid_prb(const prb_t *prb, res_t *res) { |
151 | // Average pooling without padding can't handle cases when kernel window is |
152 | // applied to padded area only. |
153 | if (prb->alg == avg_np) { |
154 | bool ker_in_pad_d = prb->pd >= prb->kd || prb->pd_r >= prb->kd; |
155 | bool ker_in_pad_h = prb->ph >= prb->kh || prb->ph_r >= prb->kh; |
156 | bool ker_in_pad_w = prb->pw >= prb->kw || prb->pw_r >= prb->kw; |
157 | bool ker_in_pad = ker_in_pad_d || ker_in_pad_h || ker_in_pad_w; |
158 | |
159 | if (ker_in_pad) { |
160 | res->state = SKIPPED, res->reason = INVALID_CASE; |
161 | return; |
162 | } |
163 | } |
164 | } |
165 | |
166 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
167 | const args_t &ref_args) { |
168 | cmp.set_threshold(prb->cfg[kind].eps); |
169 | // Backward may have most zeroes for ker_in_pad with huge kernels problems. |
170 | const float zero_percent = (prb->dir & FLAG_FWD) ? 99.f : 100.f; |
171 | cmp.set_zero_trust_percent(zero_percent); // TODO: consider enabling |
172 | |
173 | const auto pooling_add_check |
174 | = [&](const compare::compare_t::driver_check_func_args_t &args) { |
175 | // cuDNN bug: it spits fp16 min value as -inf, |
176 | // not -65504. |
177 | if (is_nvidia_gpu() && args.dt == dnnl_f16) { |
178 | return args.exp == lowest_dt(args.dt) |
179 | && std::isinf(args.got) && std::signbit(args.got); |
180 | } |
181 | return false; |
182 | }; |
183 | cmp.set_driver_check_function(pooling_add_check); |
184 | } |
185 | |
// Driver entry point for one pooling problem: create the primitive(s), fill
// inputs, execute, and (in CORR mode) validate against the f32 reference.
// Forward always runs; backward additionally re-initializes the primitive
// using the forward pd as a hint and reuses the forward workspace.
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    // For backward testing the forward primitive is only a service step
    // (needed to produce dst/workspace), not the primitive under test.
    bool is_service_prim = prb->dir & FLAG_BWD;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res, FLAG_FWD, nullptr,
                 is_service_prim),
            WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_fpd = query_pd(prim);

    // Memory descriptors as chosen by the library for the forward pd.
    const auto &src_md = query_md(const_fpd, DNNL_ARG_SRC);
    const auto &dst_md = query_md(const_fpd, DNNL_ARG_DST);
    const auto &ws_md = query_md(const_fpd, DNNL_ARG_WORKSPACE);
    const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD);

    SAFE(!check_md_consistency_with_tag(dst_md, prb->tag), WARN);

    // Reference data is plain f32 in abx layout on the CPU engine.
    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_fp(src_md, fp, tag, ref_engine);
    dnn_mem_t src_dt(src_md, test_engine);

    dnn_mem_t dst_fp(dst_md, fp, tag, ref_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);

    dnn_mem_t ws_fp(ws_md, dnnl_s32, tag::abx, ref_engine);
    dnn_mem_t ws_dt(ws_md, test_engine);
    // Inference must not produce a workspace; an allocated one is a failure.
    if (prb->dir & FLAG_INF) SAFE(ws_dt.ndims() == 0 ? OK : FAIL, WARN);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_fpd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);

    dnn_mem_t d_src_dt, d_dst_dt;

    SAFE(fill_src(prb, src_dt, src_fp, res), WARN);

    args_t args, ref_args;

    args.set(DNNL_ARG_SRC, src_dt);
    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_WORKSPACE, ws_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
    args.set(binary_po_args, binary_po_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    // want this pass on backward to get ws_fp filled properly
    if (is_bench_mode(CORR)) {
        if (prb->dir & FLAG_FWD) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_WORKSPACE, ws_fp);
            ref_args.set(binary_po_args, binary_po_fp);

            check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
        }
    }

    if (prb->dir & FLAG_BWD) {
        // Re-initialize `prim` as the backward primitive, hinted by the
        // forward pd queried above.
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim;
        SAFE(init_prim(tmp_prim, init_pd, prb, res, FLAG_BWD, const_fpd), WARN);
        if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
        prim.reset(tmp_prim.release());

        auto const_bpd = query_pd(prim);

        const auto &d_dst_md = query_md(const_bpd, DNNL_ARG_DIFF_DST);
        const auto &d_src_md = query_md(const_bpd, DNNL_ARG_DIFF_SRC);
        const auto &d_scratchpad_md = query_md(const_bpd, DNNL_ARG_SCRATCHPAD);

        dnn_mem_t d_dst_fp = dnn_mem_t(d_dst_md, fp, tag, ref_engine);
        d_dst_dt = dnn_mem_t(d_dst_md, test_engine);

        dnn_mem_t d_src_fp = dnn_mem_t(d_src_md, fp, tag, ref_engine);
        d_src_dt = dnn_mem_t(d_src_md, test_engine);

        // Backward may need a differently-sized scratchpad than forward.
        scratchpad_dt = dnn_mem_t(d_scratchpad_md, test_engine);

        SAFE(fill_dst(prb, d_dst_dt, d_dst_fp, res), WARN);

        // Rebuild the exec args for the backward primitive; the forward
        // workspace (ws_dt) is reused as a backward input.
        args.clear();
        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_WORKSPACE, ws_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_WORKSPACE, ws_fp);
            ref_args.set(binary_po_args, binary_po_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp);

            check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res);
        }
    }

    // Performance measurement uses whichever primitive ran last
    // (forward, or backward when FLAG_BWD is set).
    return measure_perf(prb->ctx_exe, res, prim, args);
}
297 | |
298 | } // namespace pool |
299 | |