1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef RESAMPLING_HPP |
18 | #define RESAMPLING_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <limits.h> |
22 | #include <stdint.h> |
23 | |
24 | #include <iostream> |
25 | |
26 | #include "common.hpp" |
27 | #include "dnn_types.hpp" |
28 | #include "dnnl_common.hpp" |
29 | #include "utils/perf_report.hpp" |
30 | #include "utils/settings.hpp" |
31 | |
32 | namespace resampling { |
33 | |
// Resampling algorithm kinds handled by the driver.
// The `resampling_*` enumerators are aliases: they carry exactly the same
// values as the short spellings they are assigned from.
enum alg_t {
    undef = 0,
    nearest = 1,
    linear = 2,
    // Long-form aliases of the enumerators above.
    resampling_nearest = nearest,
    resampling_linear = linear,
};
// Conversions between the driver's alg_t and other representations.
// Only declarations are visible here; see the corresponding .cpp for details.

// Presumably parses an algorithm name (e.g. from the command line) into an
// alg_t — confirm accepted spellings in the definition.
alg_t str2alg(const char *str);
// Returns a textual name for `alg` (used by perf_report_t::dump_alg).
const char *alg2str(alg_t alg);
// Maps the driver-level alg_t onto a oneDNN dnnl_alg_kind_t.
dnnl_alg_kind_t alg2alg_kind(alg_t alg);
44 | |
45 | struct desc_t { |
46 | int64_t mb, ic; |
47 | int64_t id, ih, iw; |
48 | int64_t od, oh, ow; |
49 | std::string name; |
50 | int ndims; |
51 | |
52 | dims_t src_dims() const; |
53 | dims_t dst_dims() const; |
54 | }; |
55 | |
56 | int str2desc(desc_t *desc, const char *str); |
57 | std::ostream &operator<<(std::ostream &s, const desc_t &d); |
58 | |
59 | struct settings_t : public base_settings_t { |
60 | settings_t() = default; |
61 | |
62 | // ctor to save certain fields from resetting |
63 | settings_t(const char *perf_template) : settings_t() { |
64 | this->perf_template = perf_template; |
65 | } |
66 | |
67 | desc_t desc {}; |
68 | |
69 | std::vector<dir_t> dir {FWD_D}; |
70 | std::vector<dnnl_data_type_t> sdt {dnnl_f32}; |
71 | std::vector<dnnl_data_type_t> ddt {dnnl_f32}; |
72 | std::vector<std::string> tag {tag::abx}; |
73 | std::vector<alg_t> alg {nearest}; |
74 | |
75 | const char *perf_template_csv() const { |
76 | static const std::string args = "%dir%,%sdt%,%ddt%,%tag%,%alg%" ; |
77 | return perf_template_csv_base(args); |
78 | } |
79 | |
80 | void reset() { *this = settings_t(perf_template); } |
81 | }; |
82 | |
83 | struct prb_t : public desc_t { |
84 | prb_t(const desc_t &desc, dir_t dir, dnnl_data_type_t sdt, |
85 | dnnl_data_type_t ddt, const std::string &tag, alg_t alg, |
86 | const attr_t &attr, const thr_ctx_t &ctx_init, |
87 | const thr_ctx_t &ctx_exe, int64_t mb = 0) |
88 | : desc_t(desc) |
89 | , dir(dir) |
90 | , sdt(sdt) |
91 | , ddt(ddt) |
92 | , tag(tag) |
93 | , alg(alg) |
94 | , attr(attr) |
95 | , ctx_init(ctx_init) |
96 | , ctx_exe(ctx_exe) |
97 | , user_mb(mb) { |
98 | if (mb) this->mb = mb; |
99 | } |
100 | ~prb_t() {} |
101 | |
102 | dir_t dir; |
103 | dnnl_data_type_t sdt, ddt; |
104 | std::string tag; |
105 | alg_t alg; |
106 | attr_t attr; |
107 | thr_ctx_t ctx_init, ctx_exe; |
108 | int64_t user_mb; |
109 | |
110 | BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t); |
111 | }; |
112 | std::ostream &operator<<(std::ostream &s, const prb_t &prb); |
113 | |
114 | struct perf_report_t : public base_perf_report_t { |
115 | perf_report_t(const prb_t *prb, const char *perf_template) |
116 | : base_perf_report_t(perf_template) |
117 | , p_(prb) |
118 | , sdt_({prb->sdt}) |
119 | , tag_(normalize_tag(p_->tag, p_->ndims)) {} |
120 | |
121 | void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); } |
122 | |
123 | void dump_desc(std::ostream &s) const override { |
124 | s << static_cast<const desc_t &>(*p_); |
125 | } |
126 | |
127 | void dump_desc_csv(std::ostream &s) const override { |
128 | s << p_->mb << ',' |
129 | |
130 | << p_->ic << ',' << p_->id << ',' << p_->ih << ',' << p_->iw << ',' |
131 | |
132 | << p_->od << ',' << p_->oh << ',' << p_->ow; |
133 | } |
134 | |
135 | const int64_t *user_mb() const override { return &p_->user_mb; } |
136 | const attr_t *attr() const override { return &p_->attr; } |
137 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
138 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
139 | const std::string *name() const override { return &p_->name; } |
140 | const dir_t *dir() const override { return &p_->dir; } |
141 | const std::vector<dnnl_data_type_t> *sdt() const override { return &sdt_; } |
142 | const dnnl_data_type_t *ddt() const override { return &p_->ddt; } |
143 | const std::string *tag() const override { return &tag_; } |
144 | |
145 | private: |
146 | const prb_t *p_; |
147 | std::vector<dnnl_data_type_t> sdt_; |
148 | std::string tag_; |
149 | }; |
150 | |
151 | inline int64_t src_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t id, |
152 | int64_t ih, int64_t iw) { |
153 | return (((mb * prb->ic + ic) * prb->id + id) * prb->ih + ih) * prb->iw + iw; |
154 | } |
155 | |
156 | inline int64_t dst_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t od, |
157 | int64_t oh, int64_t ow) { |
158 | return (((mb * prb->ic + ic) * prb->od + od) * prb->oh + oh) * prb->ow + ow; |
159 | } |
160 | |
// Driver entry points. Only declarations are visible in this header.

// Presumably flags `res` when the problem is not supported by the
// implementation — confirm in the definition.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
// Presumably flags `res` when the problem description itself is invalid.
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Reference computation for `prb` over `args`; `prim_ref` is optional and
// defaults to nullptr.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

// Compare / fill helpers working on a pair of memories. NOTE(review):
// `mem_dt` is presumably the library-side memory and `mem_fp` the reference
// copy — confirm in the definitions.
int compare_src(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
int compare_dst(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
int fill_dat(
        const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);

// Presumably runs one problem end-to-end, recording the outcome in `res`.
int doit(const prb_t *prb, res_t *res);
// Driver entry point taking the driver's command-line arguments.
int bench(int argc, char **argv);
175 | |
176 | } // namespace resampling |
177 | |
178 | #endif |
179 | |