/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef RESAMPLING_HPP
18#define RESAMPLING_HPP
19
#include <assert.h>
#include <limits.h>
#include <stdint.h>

#include <iostream>
#include <string>
#include <vector>

#include "common.hpp"
#include "dnn_types.hpp"
#include "dnnl_common.hpp"
#include "utils/perf_report.hpp"
#include "utils/settings.hpp"
31
32namespace resampling {
33
// Resampling algorithm selector used throughout this driver.
//
// Enumerator values follow declaration order: `undef` is 0, `nearest` is 1,
// `linear` is 2. The `resampling_*` names are aliases that compare equal to
// the short names.
enum alg_t {
    undef,
    nearest,
    linear,
    resampling_nearest = nearest, // alias of `nearest`
    resampling_linear = linear, // alias of `linear`
};
41alg_t str2alg(const char *str);
42const char *alg2str(alg_t alg);
43dnnl_alg_kind_t alg2alg_kind(alg_t alg);
44
45struct desc_t {
46 int64_t mb, ic;
47 int64_t id, ih, iw;
48 int64_t od, oh, ow;
49 std::string name;
50 int ndims;
51
52 dims_t src_dims() const;
53 dims_t dst_dims() const;
54};
55
56int str2desc(desc_t *desc, const char *str);
57std::ostream &operator<<(std::ostream &s, const desc_t &d);
58
59struct settings_t : public base_settings_t {
60 settings_t() = default;
61
62 // ctor to save certain fields from resetting
63 settings_t(const char *perf_template) : settings_t() {
64 this->perf_template = perf_template;
65 }
66
67 desc_t desc {};
68
69 std::vector<dir_t> dir {FWD_D};
70 std::vector<dnnl_data_type_t> sdt {dnnl_f32};
71 std::vector<dnnl_data_type_t> ddt {dnnl_f32};
72 std::vector<std::string> tag {tag::abx};
73 std::vector<alg_t> alg {nearest};
74
75 const char *perf_template_csv() const {
76 static const std::string args = "%dir%,%sdt%,%ddt%,%tag%,%alg%";
77 return perf_template_csv_base(args);
78 }
79
80 void reset() { *this = settings_t(perf_template); }
81};
82
83struct prb_t : public desc_t {
84 prb_t(const desc_t &desc, dir_t dir, dnnl_data_type_t sdt,
85 dnnl_data_type_t ddt, const std::string &tag, alg_t alg,
86 const attr_t &attr, const thr_ctx_t &ctx_init,
87 const thr_ctx_t &ctx_exe, int64_t mb = 0)
88 : desc_t(desc)
89 , dir(dir)
90 , sdt(sdt)
91 , ddt(ddt)
92 , tag(tag)
93 , alg(alg)
94 , attr(attr)
95 , ctx_init(ctx_init)
96 , ctx_exe(ctx_exe)
97 , user_mb(mb) {
98 if (mb) this->mb = mb;
99 }
100 ~prb_t() {}
101
102 dir_t dir;
103 dnnl_data_type_t sdt, ddt;
104 std::string tag;
105 alg_t alg;
106 attr_t attr;
107 thr_ctx_t ctx_init, ctx_exe;
108 int64_t user_mb;
109
110 BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t);
111};
112std::ostream &operator<<(std::ostream &s, const prb_t &prb);
113
114struct perf_report_t : public base_perf_report_t {
115 perf_report_t(const prb_t *prb, const char *perf_template)
116 : base_perf_report_t(perf_template)
117 , p_(prb)
118 , sdt_({prb->sdt})
119 , tag_(normalize_tag(p_->tag, p_->ndims)) {}
120
121 void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); }
122
123 void dump_desc(std::ostream &s) const override {
124 s << static_cast<const desc_t &>(*p_);
125 }
126
127 void dump_desc_csv(std::ostream &s) const override {
128 s << p_->mb << ','
129
130 << p_->ic << ',' << p_->id << ',' << p_->ih << ',' << p_->iw << ','
131
132 << p_->od << ',' << p_->oh << ',' << p_->ow;
133 }
134
135 const int64_t *user_mb() const override { return &p_->user_mb; }
136 const attr_t *attr() const override { return &p_->attr; }
137 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
138 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
139 const std::string *name() const override { return &p_->name; }
140 const dir_t *dir() const override { return &p_->dir; }
141 const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; }
142 const dnnl_data_type_t *ddt() const override { return &p_->ddt; }
143 const std::string *tag() const override { return &tag_; }
144
145private:
146 const prb_t *p_;
147 std::vector<dnnl_data_type_t> sdt_;
148 std::string tag_;
149};
150
151inline int64_t src_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t id,
152 int64_t ih, int64_t iw) {
153 return (((mb * prb->ic + ic) * prb->id + id) * prb->ih + ih) * prb->iw + iw;
154}
155
156inline int64_t dst_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t od,
157 int64_t oh, int64_t ow) {
158 return (((mb * prb->ic + ic) * prb->od + od) * prb->oh + oh) * prb->ow + ow;
159}
160
161void skip_unimplemented_prb(const prb_t *prb, res_t *res);
162void skip_invalid_prb(const prb_t *prb, res_t *res);
163void compute_ref(const prb_t *prb, const args_t &args,
164 dnnl_primitive_t prim_ref = nullptr);
165
166int compare_src(
167 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
168int compare_dst(
169 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
170int fill_dat(
171 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
172
173int doit(const prb_t *prb, res_t *res);
174int bench(int argc, char **argv);
175
176} // namespace resampling
177
178#endif
179