/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
* Copyright 2022 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <stdio.h>
#include <stdlib.h>

#include <sstream>

#include "oneapi/dnnl/dnnl.h"

#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "binary/binary.hpp"
#include "pool/pool.hpp"

namespace pool {

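// Fills a SRC or DST tensor with deterministic values that depend on the point
// coordinates and the kernel volume, so that pooling windows produce
// distinguishable results across data types.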
int fill_dat(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, res_t *res) {
    const int64_t MB {prb->mb};
    const int64_t IC {prb->ic};
    const int64_t D {kind == SRC ? prb->id : prb->od};
    const int64_t H {kind == SRC ? prb->ih : prb->oh};
    const int64_t W {kind == SRC ? prb->iw : prb->ow};
    const int64_t ker_size {prb->kd * prb->kh * prb->kw};
    const auto &c = prb->cfg[kind];
    // For huge kernels, the filling must vary enough to produce different
    // output values, thus, use a factor of 1.
    const bool has_huge_kernel = ker_size >= c.f_max;

    benchdnn_parallel_nd(MB, IC, D, H, W,
            [&](int64_t mb, int64_t ic, int64_t d, int64_t h, int64_t w) {
                const int64_t factor
                        = prb->alg == max || has_huge_kernel ? 1 : ker_size;
                // Keep values for avg_exclude_pad positive to prevent
                // cancellation errors.
                const int64_t f_min = prb->alg == max ? c.f_min / factor : 0;
                // Divide by the factor to keep the value in range.
                const int64_t range = c.f_max / factor - f_min + 1;
                // Deterministic pseudo-random generator based on the point
                // coordinates.
                const int64_t gen
                        = 5 * d + 17 * h + 13 * w + 13 * mb + 19 * ic + 1637;
                const float value = (f_min + gen % range) * factor;

                const size_t off = kind == SRC
                        ? src_off_f(prb, mb, ic, d, h, w)
                        : dst_off_f(prb, mb, ic, d, h, w);
                ((float *)mem_fp)[off] = value;
            });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

int fill_src(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    return fill_dat(prb, SRC, mem_dt, mem_fp, res);
}

int fill_dst(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    return fill_dat(prb, DST, mem_dt, mem_fp, res);
}

// Fill ws with big numbers to reliably cause a correctness issue (and not
// anything else) in case of a bug in the library.
int fill_ws(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
    benchdnn_parallel_nd(mem_fp.nelems(),
            [&](int64_t i) { mem_fp.set_elem(i, (1 << 24) - 1); });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

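// Creates the pooling primitive descriptor for the requested direction:
// forward descriptors take the problem's source tag, while backward
// descriptors use tag::any for the source and rely on the forward hint.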
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;
    const dir_t dir = init_pd_args.dir;

    const auto src_tag = (dir & FLAG_FWD) ? prb->tag : tag::any;

    auto src_d = dnn_mem_t::init_md(
            prb->ndims, prb->src_dims().data(), prb->cfg[SRC].dt, src_tag);
    auto dst_d = dnn_mem_t::init_md(
            prb->ndims, prb->dst_dims().data(), prb->cfg[DST].dt, tag::any);

    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(
            prb->attr, prb->ndims, prb->dst_dims().data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    dnnl_alg_kind_t alg = alg2alg_kind(prb->alg);

    if (dir & FLAG_FWD) {
        auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference
                                             : dnnl_forward_training;
        DNN_SAFE_STATUS(dnnl_pooling_forward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, prop_kind, alg, src_d,
                dst_d, prb->strides().data(), prb->kernel().data(),
                prb->dilations().data(), prb->padding().data(),
                prb->padding_r().data(), dnnl_attr));
    } else {
        DNN_SAFE_STATUS(dnnl_pooling_backward_primitive_desc_create(
                &init_pd_args.pd, init_pd_args.engine, alg, src_d, dst_d,
                prb->strides().data(), prb->kernel().data(),
                prb->dilations().data(), prb->padding().data(),
                prb->padding_r().data(), init_pd_args.hint, dnnl_attr));
    }
    return dnnl_success;
}

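// Skips problems the library is known not to implement: unsupported data
// types, sum post-op restrictions, mixed src/dst data types on CPU, and the
// backward pass for builds that rely on Arm Compute Library.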
void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    skip_unimplemented_data_type(
            {prb->cfg[SRC].dt, prb->cfg[DST].dt}, prb->dir, res);
    skip_unimplemented_sum_po(prb->attr, res);

    if (is_cpu() && prb->cfg[SRC].dt != prb->cfg[DST].dt) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }

#if DNNL_AARCH64_USE_ACL
    // ACL supports only the forward pass.
    // Ref: https://github.com/oneapi-src/oneDNN/issues/1205
    if (prb->dir & FLAG_BWD) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }
#endif
}

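// Skips problem configurations that are invalid by definition rather than
// merely unimplemented.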
void skip_invalid_prb(const prb_t *prb, res_t *res) {
    // Average pooling without padding can't handle cases where the kernel
    // window is applied to the padded area only.
    if (prb->alg == avg_np) {
        bool ker_in_pad_d = prb->pd >= prb->kd || prb->pd_r >= prb->kd;
        bool ker_in_pad_h = prb->ph >= prb->kh || prb->ph_r >= prb->kh;
        bool ker_in_pad_w = prb->pw >= prb->kw || prb->pw_r >= prb->kw;
        bool ker_in_pad = ker_in_pad_d || ker_in_pad_h || ker_in_pad_w;

        if (ker_in_pad) {
            res->state = SKIPPED, res->reason = INVALID_CASE;
            return;
        }
    }
}

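// Configures the comparison: per-data-type threshold, the share of zeroes in
// the output that is trusted, and a driver-specific check for a known cuDNN
// fp16 quirk.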
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    cmp.set_threshold(prb->cfg[kind].eps);
    // The backward pass may produce mostly zeroes for ker_in_pad problems with
    // huge kernels.
    const float zero_percent = (prb->dir & FLAG_FWD) ? 99.f : 100.f;
    cmp.set_zero_trust_percent(zero_percent); // TODO: consider enabling

    const auto pooling_add_check
            = [&](const compare::compare_t::driver_check_func_args_t &args) {
                  // cuDNN bug: it returns the fp16 minimum value as -inf,
                  // not -65504.
                  if (is_nvidia_gpu() && args.dt == dnnl_f16) {
                      return args.exp == lowest_dt(args.dt)
                              && std::isinf(args.got) && std::signbit(args.got);
                  }
                  return false;
              };
    cmp.set_driver_check_function(pooling_add_check);
}

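// Main driver entry point: creates the forward primitive, fills inputs,
// executes and checks the forward pass, then repeats the process for the
// backward pass when it is requested.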
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    bool is_service_prim = prb->dir & FLAG_BWD;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res, FLAG_FWD, nullptr,
                 is_service_prim),
            WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_fpd = query_pd(prim);

    const auto &src_md = query_md(const_fpd, DNNL_ARG_SRC);
    const auto &dst_md = query_md(const_fpd, DNNL_ARG_DST);
    const auto &ws_md = query_md(const_fpd, DNNL_ARG_WORKSPACE);
    const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD);

    SAFE(!check_md_consistency_with_tag(dst_md, prb->tag), WARN);

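    // The reference path computes in f32 with a plain (abx) layout on the CPU
    // engine; the library path uses the memory descriptors queried above.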
    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_fp(src_md, fp, tag, ref_engine);
    dnn_mem_t src_dt(src_md, test_engine);

    dnn_mem_t dst_fp(dst_md, fp, tag, ref_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);

    dnn_mem_t ws_fp(ws_md, dnnl_s32, tag::abx, ref_engine);
    dnn_mem_t ws_dt(ws_md, test_engine);
    if (prb->dir & FLAG_INF) SAFE(ws_dt.ndims() == 0 ? OK : FAIL, WARN);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_fpd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);

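    // Diff memories are created later, once the backward primitive descriptor
    // is known.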
    dnn_mem_t d_src_dt, d_dst_dt;

    SAFE(fill_src(prb, src_dt, src_fp, res), WARN);

    args_t args, ref_args;

    args.set(DNNL_ARG_SRC, src_dt);
    args.set(DNNL_ARG_DST, dst_dt);
    args.set(DNNL_ARG_WORKSPACE, ws_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
    args.set(binary_po_args, binary_po_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    // This pass is also wanted on backward to get ws_fp filled properly.
    if (is_bench_mode(CORR)) {
        if (prb->dir & FLAG_FWD) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_WORKSPACE, ws_fp);
            ref_args.set(binary_po_args, binary_po_fp);

            check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
        }
    }

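    // Backward pass: the backward primitive is created with the forward
    // primitive descriptor as a hint, and the workspace produced by the
    // forward run is reused as its input.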
    if (prb->dir & FLAG_BWD) {
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim;
        SAFE(init_prim(tmp_prim, init_pd, prb, res, FLAG_BWD, const_fpd), WARN);
        if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
        prim.reset(tmp_prim.release());

        auto const_bpd = query_pd(prim);

        const auto &d_dst_md = query_md(const_bpd, DNNL_ARG_DIFF_DST);
        const auto &d_src_md = query_md(const_bpd, DNNL_ARG_DIFF_SRC);
        const auto &d_scratchpad_md = query_md(const_bpd, DNNL_ARG_SCRATCHPAD);

        dnn_mem_t d_dst_fp = dnn_mem_t(d_dst_md, fp, tag, ref_engine);
        d_dst_dt = dnn_mem_t(d_dst_md, test_engine);

        dnn_mem_t d_src_fp = dnn_mem_t(d_src_md, fp, tag, ref_engine);
        d_src_dt = dnn_mem_t(d_src_md, test_engine);

        scratchpad_dt = dnn_mem_t(d_scratchpad_md, test_engine);

        SAFE(fill_dst(prb, d_dst_dt, d_dst_fp, res), WARN);

        args.clear();
        args.set(DNNL_ARG_DIFF_DST, d_dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, d_src_dt);
        args.set(DNNL_ARG_WORKSPACE, ws_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_WORKSPACE, ws_fp);
            ref_args.set(binary_po_args, binary_po_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, d_dst_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC, d_src_fp);

            check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}

} // namespace pool