1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stdio.h> |
18 | #include <stdlib.h> |
19 | |
20 | #include <sstream> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "utils/parallel.hpp" |
25 | |
26 | #include "dnnl_common.hpp" |
27 | #include "dnnl_memory.hpp" |
28 | |
29 | #include "binary/binary.hpp" |
30 | #include "resampling/resampling.hpp" |
31 | |
32 | namespace resampling { |
33 | |
34 | int fill_dat(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt, |
35 | dnn_mem_t &mem_fp, res_t *res) { |
36 | const auto nelems = mem_fp.nelems(); |
37 | const auto dt = mem_dt.dt(); |
38 | const int range = 16; |
39 | const int f_min = 0; |
40 | |
41 | benchdnn_parallel_nd(nelems, [&](int64_t i) { |
42 | const float gen = ((97 * i) - 19 * kind + 101) % (range + 1); |
43 | const float value = dt == dnnl_f32 || is_integral_dt(dt) |
44 | ? (f_min + gen) * (1.0f + 4.0f / range) |
45 | : (f_min + gen) / range; |
46 | |
47 | mem_fp.set_elem(i, round_to_nearest_representable(dt, value)); |
48 | }); |
49 | |
50 | SAFE(mem_dt.reorder(mem_fp), WARN); |
51 | |
52 | return OK; |
53 | } |
54 | |
55 | int fill_src( |
56 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
57 | return fill_dat(prb, SRC, mem_dt, mem_fp, res); |
58 | } |
59 | |
60 | int fill_dst( |
61 | const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) { |
62 | return fill_dat(prb, DST, mem_dt, mem_fp, res); |
63 | } |
64 | |
65 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
66 | const prb_t *prb = init_pd_args.prb; |
67 | |
68 | std::string src_tag = (prb->dir & FLAG_FWD) ? prb->tag : tag::any; |
69 | std::string dst_tag = (prb->dir & FLAG_BWD) ? prb->tag : tag::any; |
70 | |
71 | auto src_d = dnn_mem_t::init_md( |
72 | prb->ndims, prb->src_dims().data(), prb->sdt, src_tag); |
73 | auto dst_d = dnn_mem_t::init_md( |
74 | prb->ndims, prb->dst_dims().data(), prb->ddt, dst_tag); |
75 | |
76 | dnnl_alg_kind_t alg = alg2alg_kind(prb->alg); |
77 | |
78 | attr_args_t attr_args; |
79 | attr_args.prepare_post_ops_mds( |
80 | prb->attr, prb->ndims, prb->dst_dims().data()); |
81 | const auto dnnl_attr = make_benchdnn_dnnl_wrapper( |
82 | create_dnnl_attr(prb->attr, attr_args)); |
83 | |
84 | if (prb->dir & FLAG_FWD) { |
85 | auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference |
86 | : dnnl_forward_training; |
87 | DNN_SAFE_STATUS(dnnl_resampling_forward_primitive_desc_create( |
88 | &init_pd_args.pd, init_pd_args.engine, prop_kind, alg, nullptr, |
89 | src_d, dst_d, dnnl_attr)); |
90 | } else { |
91 | DNN_SAFE_STATUS(dnnl_resampling_backward_primitive_desc_create( |
92 | &init_pd_args.pd, init_pd_args.engine, alg, nullptr, src_d, |
93 | dst_d, init_pd_args.hint, dnnl_attr)); |
94 | } |
95 | return dnnl_success; |
96 | } |
97 | |
98 | void skip_unimplemented_prb(const prb_t *prb, res_t *res) { |
99 | skip_unimplemented_data_type({prb->sdt, prb->ddt}, prb->dir, res); |
100 | skip_unimplemented_sum_po(prb->attr, res); |
101 | } |
102 | |
103 | void skip_invalid_prb(const prb_t *prb, res_t *res) {} |
104 | |
105 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
106 | const args_t &ref_args) { |
107 | const auto dt_from = (prb->dir & FLAG_FWD) ? prb->sdt : prb->ddt; |
108 | const auto dt_to = (prb->dir & FLAG_FWD) ? prb->ddt : prb->sdt; |
109 | const float linear_trh = epsilon_dt(dt_from) > epsilon_dt(dt_to) |
110 | ? epsilon_dt(dt_from) // conversion error for dt_to |
111 | : 7 * epsilon_dt(dt_to); // algorithm calculation error |
112 | float trh = prb->alg == nearest ? 0.f : linear_trh; |
113 | if (is_nvidia_gpu()) { |
114 | // cuDNN precision is different from ref one due to different |
115 | // computation algorithm used for resampling. |
116 | trh = prb->ddt == dnnl_f16 ? 4e-2 : 2e-5; |
117 | } |
118 | cmp.set_threshold(trh); |
119 | |
120 | // No sense to test zero trust for upsampling since it produces valid zeros. |
121 | // TODO: validate this once again. |
122 | cmp.set_zero_trust_percent(99.f); |
123 | } |
124 | |
// Runs one resampling problem end to end:
// 1. creates the primitive (returns early if it was skipped/unimplemented);
// 2. allocates library-side (`*_dt`, test engine, primitive's layout) and
//    reference-side (`*_fp`, CPU engine, f32 in `abx` layout) memories;
// 3. fills inputs, executes the primitive, and in correctness mode compares
//    the output against the reference;
// 4. measures performance with the same arguments.
int doit(const prb_t *prb, res_t *res) {
    // Listing mode only enumerates problems; nothing is created or executed.
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
    // A skipped or unimplemented problem is not an error: report OK without
    // running anything.
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_pd = query_pd(prim);

    // Backward data works on diff tensors, so query the diff descriptors
    // instead of the forward ones.
    const auto &src_md = prb->dir == BWD_D
            ? query_md(const_pd, DNNL_ARG_DIFF_SRC)
            : query_md(const_pd, DNNL_ARG_SRC);
    const auto &dst_md = prb->dir == BWD_D
            ? query_md(const_pd, DNNL_ARG_DIFF_DST)
            : query_md(const_pd, DNNL_ARG_DST);
    const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);

    // Data type and layout used for the reference (`*_fp`) memories.
    const auto fp = dnnl_f32;
    const auto tag = tag::abx;

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_fp(src_md, fp, tag, ref_engine);
    dnn_mem_t src_dt(src_md, test_engine);

    dnn_mem_t dst_fp(dst_md, fp, tag, ref_engine);
    dnn_mem_t dst_dt(dst_md, test_engine);
    // With a sum post-op the destination is also an input (it is presumably
    // accumulated into), so it needs defined initial values.
    if (prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM) >= 0)
        SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);

    std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
    std::vector<int> binary_po_args;
    // Post-ops can change the relative difference between the library output
    // and the reference, and results are usually compared by relative
    // difference. Post-ops must therefore not produce results for which that
    // comparison becomes unreliable. The linear algorithm may legitimately
    // differ slightly from the reference, so when it is used all binary
    // post-op values are generated positive.
    const bool only_positive_values = prb->alg == linear;
    SAFE(binary::setup_binary_po(const_pd, binary_po_args, binary_po_dt,
                 binary_po_fp, only_positive_values),
            WARN);

    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    args_t args, ref_args;

    if (prb->dir & FLAG_FWD) {
        // Forward: src is the input; dst (plus binary post-op inputs) is
        // checked.
        SAFE(fill_src(prb, src_dt, src_fp, res), WARN);

        args.set(DNNL_ARG_SRC, src_dt);
        args.set(DNNL_ARG_DST, dst_dt);
        args.set(binary_po_args, binary_po_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC, src_fp);
            ref_args.set(DNNL_ARG_DST, dst_fp);
            ref_args.set(binary_po_args, binary_po_fp);

            check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
        }
    } else {
        // Backward data: diff_dst is the input; diff_src is checked.
        SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);

        args.set(DNNL_ARG_DIFF_DST, dst_dt);
        args.set(DNNL_ARG_DIFF_SRC, src_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);

            check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res);
        }
    }

    return measure_perf(prb->ctx_exe, res, prim, args);
}
212 | |
213 | } // namespace resampling |
214 | |