1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <stdio.h>
18#include <stdlib.h>
19
20#include <sstream>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "utils/parallel.hpp"
25
26#include "dnnl_common.hpp"
27#include "dnnl_memory.hpp"
28
29#include "binary/binary.hpp"
30#include "resampling/resampling.hpp"
31
32namespace resampling {
33
34int fill_dat(const prb_t *prb, data_kind_t kind, dnn_mem_t &mem_dt,
35 dnn_mem_t &mem_fp, res_t *res) {
36 const auto nelems = mem_fp.nelems();
37 const auto dt = mem_dt.dt();
38 const int range = 16;
39 const int f_min = 0;
40
41 benchdnn_parallel_nd(nelems, [&](int64_t i) {
42 const float gen = ((97 * i) - 19 * kind + 101) % (range + 1);
43 const float value = dt == dnnl_f32 || is_integral_dt(dt)
44 ? (f_min + gen) * (1.0f + 4.0f / range)
45 : (f_min + gen) / range;
46
47 mem_fp.set_elem(i, round_to_nearest_representable(dt, value));
48 });
49
50 SAFE(mem_dt.reorder(mem_fp), WARN);
51
52 return OK;
53}
54
55int fill_src(
56 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
57 return fill_dat(prb, SRC, mem_dt, mem_fp, res);
58}
59
60int fill_dst(
61 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {
62 return fill_dat(prb, DST, mem_dt, mem_fp, res);
63}
64
65dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
66 const prb_t *prb = init_pd_args.prb;
67
68 std::string src_tag = (prb->dir & FLAG_FWD) ? prb->tag : tag::any;
69 std::string dst_tag = (prb->dir & FLAG_BWD) ? prb->tag : tag::any;
70
71 auto src_d = dnn_mem_t::init_md(
72 prb->ndims, prb->src_dims().data(), prb->sdt, src_tag);
73 auto dst_d = dnn_mem_t::init_md(
74 prb->ndims, prb->dst_dims().data(), prb->ddt, dst_tag);
75
76 dnnl_alg_kind_t alg = alg2alg_kind(prb->alg);
77
78 attr_args_t attr_args;
79 attr_args.prepare_post_ops_mds(
80 prb->attr, prb->ndims, prb->dst_dims().data());
81 const auto dnnl_attr = make_benchdnn_dnnl_wrapper(
82 create_dnnl_attr(prb->attr, attr_args));
83
84 if (prb->dir & FLAG_FWD) {
85 auto prop_kind = prb->dir & FLAG_INF ? dnnl_forward_inference
86 : dnnl_forward_training;
87 DNN_SAFE_STATUS(dnnl_resampling_forward_primitive_desc_create(
88 &init_pd_args.pd, init_pd_args.engine, prop_kind, alg, nullptr,
89 src_d, dst_d, dnnl_attr));
90 } else {
91 DNN_SAFE_STATUS(dnnl_resampling_backward_primitive_desc_create(
92 &init_pd_args.pd, init_pd_args.engine, alg, nullptr, src_d,
93 dst_d, init_pd_args.hint, dnnl_attr));
94 }
95 return dnnl_success;
96}
97
98void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
99 skip_unimplemented_data_type({prb->sdt, prb->ddt}, prb->dir, res);
100 skip_unimplemented_sum_po(prb->attr, res);
101}
102
103void skip_invalid_prb(const prb_t *prb, res_t *res) {}
104
105void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
106 const args_t &ref_args) {
107 const auto dt_from = (prb->dir & FLAG_FWD) ? prb->sdt : prb->ddt;
108 const auto dt_to = (prb->dir & FLAG_FWD) ? prb->ddt : prb->sdt;
109 const float linear_trh = epsilon_dt(dt_from) > epsilon_dt(dt_to)
110 ? epsilon_dt(dt_from) // conversion error for dt_to
111 : 7 * epsilon_dt(dt_to); // algorithm calculation error
112 float trh = prb->alg == nearest ? 0.f : linear_trh;
113 if (is_nvidia_gpu()) {
114 // cuDNN precision is different from ref one due to different
115 // computation algorithm used for resampling.
116 trh = prb->ddt == dnnl_f16 ? 4e-2 : 2e-5;
117 }
118 cmp.set_threshold(trh);
119
120 // No sense to test zero trust for upsampling since it produces valid zeros.
121 // TODO: validate this once again.
122 cmp.set_zero_trust_percent(99.f);
123}
124
125int doit(const prb_t *prb, res_t *res) {
126 if (bench_mode == LIST) return res->state = LISTED, OK;
127
128 benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
129 SAFE(init_prim(prb->ctx_init, prim, init_pd, prb, res), WARN);
130 if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
131
132 auto const_pd = query_pd(prim);
133
134 const auto &src_md = prb->dir == BWD_D
135 ? query_md(const_pd, DNNL_ARG_DIFF_SRC)
136 : query_md(const_pd, DNNL_ARG_SRC);
137 const auto &dst_md = prb->dir == BWD_D
138 ? query_md(const_pd, DNNL_ARG_DIFF_DST)
139 : query_md(const_pd, DNNL_ARG_DST);
140 const auto &scratchpad_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);
141
142 const auto fp = dnnl_f32;
143 const auto tag = tag::abx;
144
145 const auto &test_engine = get_test_engine();
146 const auto &ref_engine = get_cpu_engine();
147
148 dnn_mem_t src_fp(src_md, fp, tag, ref_engine);
149 dnn_mem_t src_dt(src_md, test_engine);
150
151 dnn_mem_t dst_fp(dst_md, fp, tag, ref_engine);
152 dnn_mem_t dst_dt(dst_md, test_engine);
153 if (prb->attr.post_ops.find(attr_t::post_ops_t::kind_t::SUM) >= 0)
154 SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
155
156 std::vector<dnn_mem_t> binary_po_fp, binary_po_dt;
157 std::vector<int> binary_po_args;
158 // When post-ops occur, the relative difference can change
159 // between the output from reference and the kernel. The compare
160 // function usually uses to compare a relative difference.
161 // Therefore, we should not lead to a situation where the
162 // relative difference is very small after executing a
163 // post-ops operation. Therefore, all values for binary post_ops
164 // are positive when the linear algorithm is present. This is
165 // important because there may be small differences in the result
166 // between the expected value and the gotten value with this algorithm.
167 const bool only_positive_values = prb->alg == linear;
168 SAFE(binary::setup_binary_po(const_pd, binary_po_args, binary_po_dt,
169 binary_po_fp, only_positive_values),
170 WARN);
171
172 dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);
173
174 args_t args, ref_args;
175
176 if (prb->dir & FLAG_FWD) {
177 SAFE(fill_src(prb, src_dt, src_fp, res), WARN);
178
179 args.set(DNNL_ARG_SRC, src_dt);
180 args.set(DNNL_ARG_DST, dst_dt);
181 args.set(binary_po_args, binary_po_dt);
182 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
183
184 SAFE(execute_and_wait(prim, args, res), WARN);
185
186 if (is_bench_mode(CORR)) {
187 ref_args.set(DNNL_ARG_SRC, src_fp);
188 ref_args.set(DNNL_ARG_DST, dst_fp);
189 ref_args.set(binary_po_args, binary_po_fp);
190
191 check_correctness(prb, {DST}, args, ref_args, setup_cmp, res);
192 }
193 } else {
194 SAFE(fill_dst(prb, dst_dt, dst_fp, res), WARN);
195
196 args.set(DNNL_ARG_DIFF_DST, dst_dt);
197 args.set(DNNL_ARG_DIFF_SRC, src_dt);
198 args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);
199
200 SAFE(execute_and_wait(prim, args, res), WARN);
201
202 if (is_bench_mode(CORR)) {
203 ref_args.set(DNNL_ARG_DIFF_DST, dst_fp);
204 ref_args.set(DNNL_ARG_DIFF_SRC, src_fp);
205
206 check_correctness(prb, {SRC}, args, ref_args, setup_cmp, res);
207 }
208 }
209
210 return measure_perf(prb->ctx_exe, res, prim, args);
211}
212
213} // namespace resampling
214