1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef POOL_HPP
18#define POOL_HPP
19
20#include <assert.h>
21#include <limits.h>
22#include <stdint.h>
23
24#include <iostream>
25
26#include "common.hpp"
27#include "dnn_types.hpp"
28#include "dnnl_common.hpp"
29#include "utils/perf_report.hpp"
30#include "utils/settings.hpp"
31
32namespace pool {
33
// Pooling algorithm selector used by the pool driver.
//
// The short names carry consecutive values starting at zero; every
// `pooling_*` enumerator is a long-form alias that shares the value of the
// short name it is assigned from.
enum alg_t {
    undef = 0,
    max = 1,
    avg_np = 2, // average, padding excluded (see the alias below)
    avg_p = 3, // average, padding included (see the alias below)

    // Long-form aliases.
    pooling_max = max,
    pooling_avg_exclude_padding = avg_np,
    pooling_avg_include_padding = avg_p,
};
43alg_t str2alg(const char *str);
44const char *alg2str(alg_t alg);
45dnnl_alg_kind_t alg2alg_kind(alg_t alg);
46
47struct desc_t {
48 int64_t mb, ic;
49 int64_t id, ih, iw;
50 int64_t od, oh, ow;
51 int64_t kd, kh, kw;
52 int64_t dd, dh, dw;
53 int64_t sd, sh, sw;
54 int64_t pd, ph, pw;
55 int64_t pd_r, ph_r, pw_r; // End side padding for each dimension
56
57 std::string name;
58 int ndims;
59
60 // Initialize dependent opposite-side paddings values from the shape
61 // parameters
62 void init_pad_r() {
63 pw_r = opp_pad(iw, ow, kw, dw, sw, pw);
64 ph_r = opp_pad(ih, oh, kh, dh, sh, ph);
65 pd_r = opp_pad(id, od, kd, dd, sd, pd);
66 }
67
68 dims_t src_dims() const;
69 dims_t dst_dims() const;
70 dims_t strides() const;
71 dims_t kernel() const;
72 dims_t dilations() const;
73 dims_t padding() const;
74 dims_t padding_r() const;
75
76private:
77 int64_t opp_pad(
78 int64_t i, int64_t o, int64_t k, int64_t d, int64_t s, int64_t p) {
79 return (o - 1) * s - i + ((k - 1) * (d + 1) + 1) - p;
80 }
81};
82
83int str2desc(desc_t *desc, const char *str);
84std::ostream &operator<<(std::ostream &s, const desc_t &d);
85
86/** configuration structure, that controls initial data filling + error check
87 *
88 * dt defines pooling precision
89 *
90 * for each type (SRC and DST) the values are filled as follows:
91 * if (rand() > f_sparsity) then:
92 * v <-- f_base // it is guaranteed each kernel window
93 * // has at least one non-zero element
94 * else:
95 * v <-- f_min + rand() * f_step % (f_max - f_min)
96 *
97 * on final check the resulting values should be in [min .. max] range, the
98 * relative difference should not exceed eps
99 */
100typedef struct dt_conf_t {
101 dnnl_data_type_t dt;
102 double min, max; /* representative */
103 int f_min, f_max; /* fill range */
104 double eps; /* acceptable error */
105} _dt_conf_t[DAT_TOTAL];
106
107extern const _dt_conf_t conf_f32;
108
109const dt_conf_t *str2cfg(const char *str);
110std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg);
111
112struct settings_t : public base_settings_t {
113 settings_t() = default;
114
115 // ctor to save certain fields from resetting
116 settings_t(const char *perf_template) : settings_t() {
117 this->perf_template = perf_template;
118 }
119
120 desc_t desc {};
121
122 std::vector<dir_t> dir {FWD_D};
123 std::vector<const dt_conf_t *> cfg {conf_f32};
124 std::vector<std::string> tag {tag::abx};
125 std::vector<alg_t> alg {max};
126
127 const char *perf_template_csv() const {
128 static const std::string args = "%dir%,%cfg%,%tag%,%alg%";
129 return perf_template_csv_base(args);
130 }
131
132 void reset() { *this = settings_t(perf_template); }
133};
134
135struct prb_t : public desc_t {
136 prb_t(const desc_t &desc, dir_t dir, const dt_conf_t *cfg,
137 const std::string &tag, alg_t alg, const attr_t &attr,
138 const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe, int64_t mb = 0)
139 : desc_t(desc)
140 , dir(dir)
141 , cfg(cfg)
142 , tag(tag)
143 , alg(alg)
144 , attr(attr)
145 , ctx_init(ctx_init)
146 , ctx_exe(ctx_exe)
147 , user_mb(mb) {
148 if (mb) this->mb = mb;
149 }
150 ~prb_t() {}
151
152 dir_t dir;
153 const dt_conf_t *cfg;
154 std::string tag;
155 alg_t alg;
156 attr_t attr;
157 thr_ctx_t ctx_init, ctx_exe;
158 int64_t user_mb;
159
160 int64_t kernel_size() const { return kd * kh * kw; }
161
162 BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t);
163};
164std::ostream &operator<<(std::ostream &s, const prb_t &prb);
165
166struct perf_report_t : public base_perf_report_t {
167 perf_report_t(const prb_t *prb, const char *perf_template)
168 : base_perf_report_t(perf_template)
169 , p_(prb)
170 , tag_(normalize_tag(p_->tag, p_->ndims)) {}
171
172 void dump_alg(std::ostream &s) const override { s << alg2str(p_->alg); }
173
174 void dump_cfg(std::ostream &s) const override { s << p_->cfg; }
175
176 void dump_desc(std::ostream &s) const override {
177 s << static_cast<const desc_t &>(*p_);
178 }
179
180 void dump_desc_csv(std::ostream &s) const override {
181 s << p_->mb << ','
182
183 << p_->ic << ',' << p_->id << ',' << p_->ih << ',' << p_->iw << ','
184
185 << p_->od << ',' << p_->oh << ',' << p_->ow << ','
186
187 << p_->kd << ',' << p_->kh << ',' << p_->kw << ','
188
189 << p_->sd << ',' << p_->sh << ',' << p_->sw << ','
190
191 << p_->pd << ',' << p_->ph << ',' << p_->pw << ','
192
193 << p_->dd << ',' << p_->dh << ',' << p_->dw;
194 }
195
196 const int64_t *user_mb() const override { return &p_->user_mb; }
197 const attr_t *attr() const override { return &p_->attr; }
198 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
199 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
200 const std::string *name() const override { return &p_->name; }
201 const dir_t *dir() const override { return &p_->dir; }
202 const std::string *tag() const override { return &tag_; }
203
204private:
205 const prb_t *p_;
206 std::string tag_;
207};
208
209inline int64_t src_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t id,
210 int64_t ih, int64_t iw) {
211 return (((mb * prb->ic + ic) * prb->id + id) * prb->ih + ih) * prb->iw + iw;
212}
213
214inline int64_t dst_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t od,
215 int64_t oh, int64_t ow) {
216 return (((mb * prb->ic + ic) * prb->od + od) * prb->oh + oh) * prb->ow + ow;
217}
218
219inline int64_t ker_off_f(const prb_t *prb, int64_t kd, int64_t kh, int64_t kw) {
220 return (kd * prb->kh + kh) * prb->kw + kw;
221}
222
223inline int64_t get_num_summands(
224 const prb_t *prb, int64_t d, int64_t h, int64_t w) {
225 const int64_t ID = prb->id, IH = prb->ih, IW = prb->iw;
226 const int64_t KD = prb->kd, KH = prb->kh, KW = prb->kw;
227 const int64_t DD = prb->dd, DH = prb->dh, DW = prb->dw;
228 const int64_t PD = prb->pd, PH = prb->ph, PW = prb->pw;
229 const int64_t SD = prb->sd, SH = prb->sh, SW = prb->sw;
230
231 auto id_start = d * SD - PD;
232 auto ih_start = h * SH - PH;
233 auto iw_start = w * SW - PW;
234 auto id_end = d * SD - PD + (KD - 1) * DD + KD;
235 auto ih_end = h * SH - PH + (KH - 1) * DH + KH;
236 auto iw_end = w * SW - PW + (KW - 1) * DW + KW;
237
238 auto id_start_excluded
239 = id_start < 0 ? (0 - id_start - 1) / (DD + 1) + 1 : 0;
240 auto ih_start_excluded
241 = ih_start < 0 ? (0 - ih_start - 1) / (DH + 1) + 1 : 0;
242 auto iw_start_excluded
243 = iw_start < 0 ? (0 - iw_start - 1) / (DW + 1) + 1 : 0;
244 auto id_end_excluded = id_end > ID ? (id_end - ID - 1) / (DD + 1) + 1 : 0;
245 auto ih_end_excluded = ih_end > IH ? (ih_end - IH - 1) / (DH + 1) + 1 : 0;
246 auto iw_end_excluded = iw_end > IW ? (iw_end - IW - 1) / (DW + 1) + 1 : 0;
247
248 return prb->alg == avg_p ? KD * KH * KW
249 : (KD - id_start_excluded - id_end_excluded)
250 * (KH - ih_start_excluded - ih_end_excluded)
251 * (KW - iw_start_excluded - iw_end_excluded);
252}
253
254void skip_unimplemented_prb(const prb_t *prb, res_t *res);
255void skip_invalid_prb(const prb_t *prb, res_t *res);
256void compute_ref(const prb_t *prb, const args_t &args,
257 dnnl_primitive_t prim_ref = nullptr);
258
259int compare_src(
260 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
261int compare_dst(
262 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
263int fill_src(
264 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
265int fill_dst(
266 const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
267int fill_ws(const prb_t *prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res);
268
269int doit(const prb_t *prb, res_t *res);
270int bench(int argc, char **argv);
271
272} // namespace pool
273
274#endif
275