1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * Copyright 2021 FUJITSU LIMITED |
4 | * Copyright 2021 Arm Ltd. and affiliates |
5 | * |
6 | * Licensed under the Apache License, Version 2.0 (the "License"); |
7 | * you may not use this file except in compliance with the License. |
8 | * You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, software |
13 | * distributed under the License is distributed on an "AS IS" BASIS, |
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
15 | * See the License for the specific language governing permissions and |
16 | * limitations under the License. |
17 | *******************************************************************************/ |
18 | |
19 | #include <algorithm> |
20 | #include <cstring> |
21 | #include <vector> |
22 | |
23 | #include <float.h> |
24 | #include <math.h> |
25 | #include <stdio.h> |
26 | #include <stdlib.h> |
27 | |
28 | #include "oneapi/dnnl/dnnl.h" |
29 | |
30 | #include "utils/parallel.hpp" |
31 | |
32 | #include "dnnl_common.hpp" |
33 | #include "dnnl_memory.hpp" |
34 | |
35 | #include "binary/binary.hpp" |
36 | #include "conv/conv.hpp" |
37 | #include "prelu/prelu.hpp" |
38 | |
39 | namespace conv { |
40 | |
41 | double get_non_zero_trust_percent(const prb_t *prb, data_kind_t kind) { |
42 | auto negative_to_zero = [&]() { |
43 | using pk = attr_t::post_ops_t::kind_t; |
44 | const auto &po = prb->attr.post_ops; |
45 | int count = 0; |
46 | |
47 | // Check for all post-ops that convert negative to zero |
48 | std::vector<pk> non_neg_po {pk::ABS}; |
49 | std::vector<pk> non_neg_alpha_0_po { |
50 | pk::CLIP, pk::CLIP_V2, pk::ELU, pk::RELU}; |
51 | for (int i = 0; i < po.len(); ++i) { |
52 | const auto &e = po.entry[i]; |
53 | if (!e.is_eltwise_kind()) continue; |
54 | |
55 | auto k = e.kind; |
56 | auto alpha = e.eltwise.alpha; |
57 | |
58 | count += std::any_of(non_neg_po.cbegin(), non_neg_po.cend(), |
59 | [k](const pk alg) { return alg == k; }); |
60 | count += std::any_of(non_neg_alpha_0_po.cbegin(), |
61 | non_neg_alpha_0_po.cend(), [k, alpha](const pk alg) { |
62 | return alg == k && alpha == 0; |
63 | }); |
64 | } |
65 | // Check for u8 dst |
66 | count += prb->cfg[DST].dt == dnnl_u8; |
67 | // Check for physically padded area in the output |
68 | count += prb->od > prb->id || prb->oh > prb->ih || prb->ow > prb->iw; |
69 | |
70 | return !!count; |
71 | }; |
72 | |
73 | double trust = 0.3; /* why? */ |
74 | switch (kind) { |
75 | case SRC: trust /= prb->sd * prb->sh * prb->sw; break; |
76 | case WEI: |
77 | trust /= 1. * prb->kd * prb->kh * prb->kw |
78 | / MIN3(prb->kd * prb->kh * prb->kw, |
79 | prb->id * prb->ih * prb->iw, |
80 | prb->od * prb->oh * prb->ow); |
81 | break; |
82 | case BIA: |
83 | trust = 0.8 * prb->cfg[DST].f_sparsity; /* why? */ |
84 | break; |
85 | case DST: trust /= (1.f + negative_to_zero()); break; |
86 | default: assert(!"unsupported data kind" ); |
87 | } |
88 | |
89 | return trust; |
90 | } |
91 | |
92 | bool need_src_init(const prb_t *prb) { |
93 | return !(prb->dir == BWD_D); |
94 | } |
95 | |
96 | bool need_wei_init(const prb_t *prb) { |
97 | return !(prb->dir & FLAG_BWD && prb->dir & FLAG_WEI); |
98 | } |
99 | |
// Bias follows the same initialization rule as weights: it is an output
// of backward-by-weights passes and an input otherwise.
bool need_bia_init(const prb_t *prb) {
    return need_wei_init(prb);
}
103 | |
104 | bool need_dst_init(const prb_t *prb) { |
105 | return !(prb->dir & FLAG_FWD) |
106 | || (prb->attr.post_ops.find(attr_t::post_ops_t::SUM) >= 0); |
107 | } |
108 | |
109 | int fill_src( |
110 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
111 | const bool check_reorder |
112 | = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt()); |
113 | dnn_mem_t ; |
114 | if (check_reorder) { |
115 | extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine()); |
116 | } |
117 | dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp; |
118 | |
119 | // Use dense filling for small problems. |
120 | int src_nelems_mask = powf(2.f, prb->ndims) - 1; |
121 | src_nelems_mask -= 1; // remove minibatch as independent dimension |
122 | auto src_nelems = prb->desc_nelems(DNNL_ARG_SRC, src_nelems_mask); |
123 | if (prb->has_groups) src_nelems /= prb->g; // groups are also independent |
124 | |
125 | const auto &c = prb->get_dt_conf(SRC); |
126 | const int range = c.f_max - c.f_min + 1; |
127 | const float sparsity |
128 | = (!is_bench_mode(CORR) || src_nelems < 100) ? 1.f : c.f_sparsity; |
129 | |
130 | benchdnn_parallel_nd(prb->mb, prb->ic, prb->id, prb->ih, prb->iw, |
131 | [&](int64_t mb, int64_t ic, int64_t id, int64_t ih, int64_t iw) { |
132 | const int64_t gen |
133 | = 101 * id + 103 * ih + 107 * iw + 109 * mb + 113 * ic; |
134 | const bool non_base = flip_coin(gen, sparsity); |
135 | float value = non_base ? c.f_min + gen * c.f_step % range |
136 | : c.f_base; |
137 | |
138 | maybe_zero_point( |
139 | prb->attr, value, prb->src_zp, ic, DNNL_ARG_SRC, true); |
140 | |
141 | ((float *)mem_00)[src_off_f(prb, mb, 0, ic, id, ih, iw)] |
142 | = round_to_nearest_representable(mem_dt.dt(), value); |
143 | }); |
144 | |
145 | SAFE(mem_dt.reorder(mem_00), WARN); |
146 | if (check_reorder) { |
147 | SAFE(mem_fp.reorder(mem_dt), WARN); |
148 | int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size()); |
149 | if (rc != 0) { |
150 | res->state = FAILED; |
151 | SAFE(FAIL, WARN); |
152 | } |
153 | } |
154 | |
155 | return OK; |
156 | } |
157 | |
158 | int fill_wei( |
159 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
160 | const bool is_def_zp = prb->attr.zero_points.is_def(DNNL_ARG_SRC); |
161 | const bool diff_data_type = mem_dt.dt() != mem_fp.dt(); |
162 | |
163 | dnnl_data_type_t dt_check = dnnl_s8; |
164 | #if defined(DNNL_AARCH64) && (DNNL_AARCH64 == 1) |
165 | /* Note for x64: |
166 | Both data types of src and weight are s8, oneDNN addds 128 to one of the s8 |
167 | input to make it of type u8 instead, as explained in |
168 | https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html or |
169 | doc/advanced/int8_computations.md |
170 | It is because `VPDPBUSD` instruction uses the combination of s8 and u8 as |
171 | input. |
172 | |
173 | Note for AArch64: |
174 | Because dot product instructions of AArch64 "SDOT" receives s8 input |
175 | for both src and weight, the addition (and its counterpart of subtraction) |
176 | is not required for AArch64. |
177 | */ |
178 | if (res->impl_name.find("jit" , 0) == 0) dt_check = dnnl_u8; |
179 | #endif |
180 | |
181 | const bool wei_x8x8 = prb->get_dt_conf(WEI).dt == dnnl_s8 |
182 | && prb->get_dt_conf(SRC).dt == dt_check; |
183 | const bool check_reorder |
184 | = (is_bench_mode(CORR)) && diff_data_type && !wei_x8x8 && is_def_zp; |
185 | |
186 | dnn_mem_t ; |
187 | if (check_reorder) { |
188 | extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine()); |
189 | } |
190 | dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp; |
191 | |
192 | const auto &c = prb->get_dt_conf(WEI); |
193 | const int range = c.f_max - c.f_min + 1; |
194 | const float sparsity = !is_bench_mode(CORR) ? 1.f : c.f_sparsity; |
195 | |
196 | benchdnn_parallel_nd(prb->g, prb->oc / prb->g, prb->ic / prb->g, prb->kd, |
197 | prb->kh, prb->kw, |
198 | [&](int64_t g, int64_t oc, int64_t ic, int64_t kd, int64_t kh, |
199 | int64_t kw) { |
200 | const int64_t gen = 113 * g + 127 * kd + 131 * kh + 137 * kw |
201 | + 139 * oc + 149 * ic + 151; |
202 | const bool non_base = flip_coin(gen, sparsity); |
203 | const float value = non_base ? c.f_min + gen * c.f_step % range |
204 | : c.f_base; |
205 | ((float *)mem_00)[wei_off_f(prb, g, oc, ic, kd, kh, kw)] |
206 | = value; |
207 | }); |
208 | |
209 | SAFE(mem_dt.reorder(mem_00), WARN); |
210 | if (check_reorder) { |
211 | SAFE(mem_fp.reorder(mem_dt), WARN); |
212 | int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size()); |
213 | if (rc != 0) { |
214 | res->state = FAILED; |
215 | SAFE(FAIL, WARN); |
216 | } |
217 | } |
218 | if ((wei_x8x8 || !is_def_zp) && is_cpu()) { |
219 | // Check that s8 -> s8_comp exists in the library since users may have |
220 | // already quantized data. |
221 | dnn_mem_t mem_fp_s8(mem_fp.md_, dnnl_s8, tag::abx, get_cpu_engine()); |
222 | dnn_mem_t mem_dt_s8(mem_dt.md_, get_test_engine()); |
223 | SAFE(mem_fp_s8.reorder(mem_fp), WARN); |
224 | SAFE(mem_dt_s8.reorder(mem_fp_s8), WARN); |
225 | SAFE(mem_dt.size() == mem_dt_s8.size() ? OK : FAIL, WARN); |
226 | int rc = std::memcmp((void *)mem_dt, (void *)mem_dt_s8, mem_dt.size()); |
227 | SAFE(rc == 0 ? OK : FAIL, WARN); |
228 | } |
229 | return OK; |
230 | } |
231 | |
232 | int fill_bia( |
233 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
234 | const bool check_reorder |
235 | = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt()); |
236 | dnn_mem_t ; |
237 | if (check_reorder) |
238 | extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::x, get_cpu_engine()); |
239 | dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp; |
240 | |
241 | const size_t nelems = mem_00.nelems(); |
242 | if (nelems == 0) return OK; |
243 | |
244 | const auto &c = prb->get_dt_conf(BIA); |
245 | const int range = c.f_max - c.f_min + 1; |
246 | |
247 | for (size_t i = 0; i < nelems; ++i) { |
248 | const int gen = (int)(151 * i); |
249 | const bool non_base = flip_coin(gen, c.f_sparsity); |
250 | const float value |
251 | = non_base ? c.f_min + gen * c.f_step % range : c.f_base; |
252 | |
253 | ((float *)mem_00)[i] = value; |
254 | } |
255 | |
256 | SAFE(mem_dt.reorder(mem_00), WARN); |
257 | if (check_reorder) { |
258 | SAFE(mem_fp.reorder(mem_dt), WARN); |
259 | int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size()); |
260 | if (rc != 0) { |
261 | res->state = FAILED; |
262 | SAFE(FAIL, WARN); |
263 | } |
264 | } |
265 | return OK; |
266 | } |
267 | |
268 | int fill_dst_with_params(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, |
269 | dnnl_data_type_t dt, double sparsity, int min, int max, int base, |
270 | int step, res_t *res) { |
271 | const bool check_reorder |
272 | = (is_bench_mode(CORR)) && (mem_dt.dt() != mem_fp.dt()); |
273 | dnn_mem_t ; |
274 | if (check_reorder) { |
275 | extra_mem = dnn_mem_t(mem_dt.md_, dnnl_f32, tag::abx, get_cpu_engine()); |
276 | } |
277 | |
278 | dnn_mem_t &mem_00 = check_reorder ? extra_mem : mem_fp; |
279 | const int range = max - min + 1; |
280 | |
281 | benchdnn_parallel_nd(prb->mb, prb->oc, prb->od, prb->oh, prb->ow, |
282 | [&](int64_t mb, int64_t oc, int64_t od, int64_t oh, int64_t ow) { |
283 | const int64_t gen |
284 | = 157 * od + 163 * oh + 167 * ow + 173 * mb + 179 * oc; |
285 | const bool non_base = flip_coin(gen, sparsity); |
286 | const float value = non_base ? min + gen * step % range : base; |
287 | |
288 | ((float *)mem_00)[dst_off_f(prb, mb, 0, oc, od, oh, ow)] |
289 | = value; |
290 | }); |
291 | |
292 | SAFE(mem_dt.reorder(mem_00), WARN); |
293 | if (check_reorder) { |
294 | SAFE(mem_fp.reorder(mem_dt), WARN); |
295 | int rc = std::memcmp((void *)mem_fp, (void *)mem_00, mem_00.size()); |
296 | if (rc != 0) { |
297 | res->state = FAILED; |
298 | SAFE(FAIL, WARN); |
299 | } |
300 | } |
301 | |
302 | return OK; |
303 | } |
304 | |
305 | int fill_dst( |
306 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
307 | auto dst_dt = mem_dt.dt(); |
308 | int sum_ind = prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM); |
309 | auto sum_dt = (sum_ind != -1) ? prb->attr.post_ops.entry[sum_ind].sum.dt |
310 | : dnnl_data_type_undef; |
311 | bool diff_sum_dst_types |
312 | = sum_dt != dnnl_data_type_undef && sum_dt != dst_dt; |
313 | bool sum_dt_is_int8 = sum_dt == dnnl_s8 || sum_dt == dnnl_u8; |
314 | |
315 | const auto &c = prb->get_dt_conf(DST); |
316 | float f_min = c.f_min; |
317 | float f_max = c.f_max; |
318 | if (diff_sum_dst_types && sum_dt_is_int8) { |
319 | f_min = lowest_dt(sum_dt); |
320 | f_max = max_dt(sum_dt); |
321 | } |
322 | |
323 | // Change mem dt to sum dt, so we can save sum data properly. |
324 | if (diff_sum_dst_types) { mem_dt.set_dt(sum_dt); } |
325 | |
326 | fill_dst_with_params(prb, mem_dt, mem_fp, sum_dt, c.f_sparsity, f_min, |
327 | f_max, c.f_base, c.f_step, res); |
328 | |
329 | // Return dst data type back. |
330 | if (diff_sum_dst_types) { mem_dt.set_dt(dst_dt); } |
331 | return OK; |
332 | } |
333 | |
// Creates a convolution primitive descriptor for the given problem via the
// oneDNN C API and stores it in `init_pd_args.pd`.
// Returns dnnl_success on success, dnnl_invalid_arguments on an unknown
// direction, or the library status on creation failure.
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t *prb = init_pd_args.prb;

    // Build memory descriptors from problem dims, data types and tags.
    // Weights get one extra (leading groups) dimension when grouped.
    auto src_d = dnn_mem_t::init_md(prb->ndims, prb->src_dims().data(),
            prb->get_dt_conf(SRC).dt, normalize_tag(prb->stag, prb->ndims));
    auto wei_d = dnn_mem_t::init_md(prb->ndims + prb->has_groups,
            prb->wei_dims().data(), prb->get_dt_conf(WEI).dt,
            normalize_tag(prb->wtag, prb->ndims + prb->has_groups));
    auto bia_d = dnn_mem_t::init_md(
            1, prb->bia_dims().data(), prb->get_dt_conf(BIA).dt, tag::any);
    auto dst_d = dnn_mem_t::init_md(prb->ndims, prb->dst_dims().data(),
            prb->get_dt_conf(DST).dt, normalize_tag(prb->dtag, prb->ndims));

    // Map the benchdnn algorithm to the library algorithm kind.
    dnnl_alg_kind_t alg = dnnl_convolution_direct;
    if (prb->alg == WINO) alg = dnnl_convolution_winograd;
    if (prb->alg == AUTO) alg = dnnl_convolution_auto;

    // Translate benchdnn attributes (post-ops, scales) into a library attr.
    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(
            prb->attr, prb->ndims, prb->dst_dims().data());
    auto wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
    if (wei_scale.policy == policy_t::PER_OC) {
        // Mask 2 = (1 << 1) selects the oc dim shifted by the groups dim;
        // without groups oc is dim 0, hence mask 1.
        auto wei_mask = prb->has_groups ? 2 : 1;
        attr_args.prepare_scales(prb->attr, DNNL_ARG_WEIGHTS, prb->wei_scales,
                prb->oc, wei_mask);
    }
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));

    switch (prb->dir) {
        case FWD_D:
        case FWD_B:
        case FWD_I:
            // Bias is only present for FWD_B.
            if (prb->dir != FWD_B) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(dnnl_convolution_forward_primitive_desc_create(
                    &init_pd_args.pd, init_pd_args.engine,
                    prb->dir == FWD_I ? dnnl_forward_inference
                                      : dnnl_forward_training,
                    alg, src_d, wei_d, bia_d, dst_d, prb->strides().data(),
                    prb->dilations().data(), prb->padding().data(),
                    prb->padding_r().data(), dnnl_attr));
            break;
        case BWD_D:
            DNN_SAFE_STATUS(
                    dnnl_convolution_backward_data_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        case BWD_W:
        case BWD_WB:
            // Diff bias is only computed for BWD_WB.
            if (prb->dir == BWD_W) bia_d.reset(nullptr);
            DNN_SAFE_STATUS(
                    dnnl_convolution_backward_weights_primitive_desc_create(
                            &init_pd_args.pd, init_pd_args.engine, alg, src_d,
                            wei_d, bia_d, dst_d, prb->strides().data(),
                            prb->dilations().data(), prb->padding().data(),
                            prb->padding_r().data(), init_pd_args.hint,
                            dnnl_attr));
            break;
        default: DNN_SAFE_STATUS(dnnl_invalid_arguments);
    }

    // TODO: add query for acc type in pd.
    //DNN_SAFE_STATUS(cd.accum_data_type == prb->get_dt_conf(ACC).dt
    //                ? dnnl_success
    //                : dnnl_unimplemented);
    return dnnl_success;
}
405 | |
406 | int init_prim_ref( |
407 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref, const prb_t *prb) { |
408 | if (!(is_bench_mode(CORR) && is_gpu() && fast_ref_gpu)) return OK; |
409 | |
410 | // Create a new copy of prb to avoid potentially corrupting the test by |
411 | // modifying prb in place. |
412 | // DIRECT algorithm is used to prevent fallback to the slow benchdnn |
413 | // reference implementation. |
414 | auto cpu_attr = prb->attr; |
415 | update_cpu_ref_attrs(cpu_attr); |
416 | prb_t prb_cpu {*prb, prb->dir, conf_f32, tag::abx, tag::abx, tag::abx, |
417 | DIRECT, cpu_attr, prb->ctx_init, prb->ctx_exe, prb->mb}; |
418 | |
419 | init_pd_args_t<prb_t> init_pd_args( |
420 | /* res = */ nullptr, get_cpu_engine(), &prb_cpu, prb->dir, |
421 | /* hint = */ nullptr); |
422 | init_pd(init_pd_args); |
423 | |
424 | benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw; |
425 | fetch_impl(pdw, init_pd_args, /* res = */ nullptr, |
426 | /* is_service_prim = */ true); |
427 | |
428 | dnnl_primitive_t prim_ref_ {}; |
429 | if (pdw) { |
430 | if (query_impl_info(pdw) == "ref:any" ) return OK; |
431 | DNN_SAFE(dnnl_primitive_create(&prim_ref_, pdw), WARN); |
432 | BENCHDNN_PRINT(5, "CPU reference oneDNN implementation: %s\n" , |
433 | query_impl_info(pdw).c_str()); |
434 | } |
435 | prim_ref.reset(prim_ref_); |
436 | return OK; |
437 | } |
438 | |
439 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
440 | skip_unimplemented_data_type( |
441 | {prb->get_dt_conf(SRC).dt, prb->get_dt_conf(WEI).dt, |
442 | prb->get_dt_conf(DST).dt}, |
443 | prb->dir, res); |
444 | skip_unimplemented_sum_po(prb->attr, res, prb->get_dt_conf(DST).dt); |
445 | |
446 | if (is_cpu()) { |
447 | // Specific configurations are not supported. |
448 | const bool is_f32_src = prb->get_dt_conf(SRC).dt == dnnl_f32; |
449 | const bool is_f32_wei = prb->get_dt_conf(WEI).dt == dnnl_f32; |
450 | const bool is_f16 = prb->get_dt_conf(WEI).dt == dnnl_f16; |
451 | const bool is_bf16_src = prb->get_dt_conf(SRC).dt == dnnl_bf16; |
452 | const bool is_bf16_wei = prb->get_dt_conf(WEI).dt == dnnl_bf16; |
453 | const bool is_int8_dst = prb->get_dt_conf(DST).dt == dnnl_s8 |
454 | || prb->get_dt_conf(DST).dt == dnnl_u8; |
455 | const bool is_f32f32x8 = is_f32_src && is_f32_wei && is_int8_dst; |
456 | const bool is_bf16bf16x8 = is_bf16_src && is_bf16_wei && is_int8_dst; |
457 | const bool is_valid_f16 = is_f16 |
458 | && (prb->get_dt_conf(DST).dt == dnnl_f32 |
459 | || prb->get_dt_conf(DST).dt == dnnl_f16); |
460 | |
461 | if (is_f32f32x8 || is_bf16bf16x8 || !is_valid_f16) { |
462 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
463 | return; |
464 | } |
465 | } |
466 | |
467 | // Winograd implementation has very limited scope and support. It doesn't |
468 | // make sense to list all of them, just convert all unimplemented Winograd |
469 | // problems into not supported. |
470 | if (prb->alg == WINO) { |
471 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
472 | return; |
473 | } |
474 | } |
475 | |
// Convolution has no inherently invalid problem combinations to reject.
void skip_invalid_prb(const prb_t *prb, res_t *res) {}
477 | |
478 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
479 | const args_t &ref_args) { |
480 | const bool compare_with_norm = (prb->alg & WINO); |
481 | cmp.set_norm_validation_mode(compare_with_norm); |
482 | |
483 | float trh = prb->get_dt_conf(kind).eps; |
484 | if ((prb->alg & WINO) && (prb->dir & FLAG_WEI)) { |
485 | // This is an empirical equation derived by observing growth error with |
486 | // increasing 'k' dimension in gemm of winograd |
487 | const float log_const = log10(0.125 * prb->mb * prb->oh * prb->ow); |
488 | trh = prb->get_dt_conf(kind).eps * (MAX2(1, pow(10, 0.4 * log_const))); |
489 | } |
490 | cmp.set_threshold(trh); |
491 | |
492 | const float zpp = (1.f - get_non_zero_trust_percent(prb, kind)) * 100.f; |
493 | cmp.set_zero_trust_percent(zpp); |
494 | } |
495 | |
// Top-level driver for one convolution test case: creates the primitive,
// allocates and fills all tensors, executes the requested direction,
// checks correctness (in CORR mode) against the reference, and measures
// performance. Returns OK, or propagates a SAFE failure.
int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    // AUTO resolves to a concrete algorithm only after pd creation; the
    // config may depend on the chosen algorithm.
    if (prb->alg == AUTO) prb->alg = alg_kind2alg(query_alg_kind(const_pd));
    prb->cfg = auto_cfg(prb->alg, prb->cfg);

    // Query the actual memory descriptors chosen by the implementation;
    // diff_* descriptors are used on the corresponding backward passes.
    const auto &src_md = prb->dir == BWD_D
            ? query_md(const_pd, DNNL_ARG_DIFF_SRC)
            : query_md(const_pd, DNNL_ARG_SRC);
    const auto &wei_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_WEIGHTS)
            : query_md(const_pd, DNNL_ARG_WEIGHTS);
    const auto &bia_md = prb->dir & FLAG_WEI
            ? query_md(const_pd, DNNL_ARG_DIFF_BIAS)
            : query_md(const_pd, DNNL_ARG_BIAS);
    const auto &dst_md = prb->dir & FLAG_BWD
            ? query_md(const_pd, DNNL_ARG_DIFF_DST)
            : query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    // Reference (fp) tensors are always plain f32 abx on the CPU engine.
    const auto fp = dnnl_f32;
    const auto src_tag = tag::abx;
    const auto wei_tag = tag::abx;

    // Use CPU prim as the reference in GPU testing to reduce testing time.
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim_ref;
    SAFE(init_prim_ref(prim_ref, prb), WARN);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    // Library-side (dt) memories.
    dnn_mem_t src_dt(src_md, test_engine);
    dnn_mem_t wei_dt(wei_md, test_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    dnn_mem_t bia_dt(bia_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
    dnn_mem_t src_scales_dt, src_scales_fp;
    dnn_mem_t wei_scales_dt, wei_scales_fp;
    dnn_mem_t dst_scales_dt, dst_scales_fp;
    dnn_mem_t src_zp_dt, src_zp_fp;
    dnn_mem_t dst_zp_dt, dst_zp_fp;
    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    SAFE(binary::setup_binary_po(
                 const_pd, binary_po_args, binary_po_dt, binary_po_fp),
            WARN);
    std::vector<dnn_mem_t> prelu_po_fp, prelu_po_dt;
    std::vector<int> prelu_po_args;
    SAFE(prelu::setup_prelu_po(
                 const_pd, prelu_po_args, prelu_po_fp, prelu_po_dt),
            WARN);

    // Reference-side (fp) memories.
    dnn_mem_t src_fp(src_md, fp, src_tag, ref_engine);
    dnn_mem_t wei_fp(wei_md, fp, wei_tag, ref_engine);
    dnn_mem_t dst_fp(dst_md, fp, src_tag, ref_engine);
    dnn_mem_t bia_fp(bia_md, fp, tag::x, ref_engine);
    dnn_mem_t scratchpad_fp;
    if (prim_ref)
        scratchpad_fp = dnn_mem_t(
                query_md(query_pd(prim_ref), DNNL_ARG_SCRATCHPAD), ref_engine);

    // Fill only the tensors that are inputs for the requested direction.
    if (need_src_init(prb)) SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
    if (need_dst_init(prb)) SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
    if (need_wei_init(prb)) SAFE(fill_wei(prb, wei_dt, wei_fp, res), WARN);
    if (need_bia_init(prb)) SAFE(fill_bia(prb, bia_dt, bia_fp, res), WARN);

    // Prepare runtime scales and zero points attributes arguments.
    const int src_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_SRC).policy);
    int wei_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_WEIGHTS).policy, DNNL_ARG_WEIGHTS);
    if (prb->has_groups) wei_mask = (1 << wei_mask) + 1;
    const int dst_mask = attr_t::get_default_mask(
            prb->attr.scales.get(DNNL_ARG_DST).policy);
    maybe_prepare_runtime_scales_v2(src_scales_dt, src_scales_fp,
            prb->attr.scales.get(DNNL_ARG_SRC),
            prb->desc_nelems(DNNL_ARG_SRC, src_mask), prb->src_scales);
    maybe_prepare_runtime_scales_v2(wei_scales_dt, wei_scales_fp,
            prb->attr.scales.get(DNNL_ARG_WEIGHTS),
            prb->desc_nelems(DNNL_ARG_WEIGHTS, wei_mask), prb->wei_scales);
    maybe_prepare_runtime_scales_v2(dst_scales_dt, dst_scales_fp,
            prb->attr.scales.get(DNNL_ARG_DST),
            prb->desc_nelems(DNNL_ARG_DST, dst_mask), prb->dst_scales);

    maybe_prepare_runtime_zero_points_v2(src_zp_dt, src_zp_fp, prb->attr,
            DNNL_ARG_SRC, prb->ic, prb->src_zp);
    maybe_prepare_runtime_zero_points_v2(dst_zp_dt, dst_zp_fp, prb->attr,
            DNNL_ARG_DST, prb->oc, prb->dst_zp);

    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        // Forward: src + weights (+ bias) -> dst; correctness checks DST.
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_BIAS, bia_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_dt);
        args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_dt);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_dt);
        args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_dt);
        args.set(binary_po_args, binary_po_dt);
        args.set(prelu_po_args, prelu_po_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_fp);
            ref_args.set(
                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_fp);
            ref_args.set(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scales_fp);
            ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_fp);
            ref_args.set(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_fp);
            ref_args.set(binary_po_args, binary_po_fp);
            ref_args.set(prelu_po_args, prelu_po_fp);

            check_correctness(
                    prb, {DST}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir == BWD_D) {
        // Backward by data: diff_dst + weights -> diff_src; checks SRC.
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
            ref_args.set(DNNL_ARG_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {SRC}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else if (prb->dir & FLAG_BWD && prb->dir & FLAG_WEI) {
        // Backward by weights: src + diff_dst -> diff_weights (+ diff_bias);
        // checks WEI and BIA.
        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS, wei_dt);
        args.set(DNNL_ARG_DIFF_BIAS, bia_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS, wei_fp);
            ref_args.set(DNNL_ARG_DIFF_BIAS, bia_fp);
            ref_args.set(DNNL_ARG_SCRATCHPAD, scratchpad_fp);

            check_correctness(
                    prb, {WEI, BIA}, args, ref_args, setup_cmp, res, prim_ref);
        }
    } else {
        // Unknown direction is a programming error, not a skip.
        SAFE(FAIL, CRIT);
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}
668 | |
669 | } // namespace conv |
670 | |