1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <float.h>
18#include <math.h>
19#include <stdio.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "dnnl_common.hpp"
26#include "dnnl_debug.hpp"
27#include "pool/pool.hpp"
28
29namespace pool {
30
31alg_t str2alg(const char *str) {
32#define CASE(_alg) \
33 if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
34 CASE(max);
35 CASE(pooling_max);
36 CASE(avg_np);
37 CASE(pooling_avg_exclude_padding);
38 CASE(avg_p);
39 CASE(pooling_avg_include_padding);
40#undef CASE
41 assert(!"unknown algorithm");
42 return undef;
43}
44
45const char *alg2str(alg_t alg) {
46 if (alg == max) return "max";
47 if (alg == avg_np) return "avg_np";
48 if (alg == avg_p) return "avg_p";
49 assert(!"unknown algorithm");
50 return "undef";
51}
52
53dnnl_alg_kind_t alg2alg_kind(alg_t alg) {
54 if (alg == max) return dnnl_pooling_max;
55 if (alg == avg_np) return dnnl_pooling_avg_exclude_padding;
56 if (alg == avg_p) return dnnl_pooling_avg_include_padding;
57 assert(!"unknown algorithm");
58 return dnnl_alg_kind_undef;
59}
60
61int str2desc(desc_t *desc, const char *str) {
62 /* canonical form:
63 * mbXicX_odXihXiwX_odXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwX_nS
64 *
65 * where X is number, S - string
66 * note: symbol `_` is ignored
67 *
68 * implicit rules:
69 * - if smaller dimensions are not specified => square or cubic form;
70 * - if output is undefined => compute output
71 * - if padding is undefined => compute trivial padding
72 */
73
74 desc_t d {0};
75 d.mb = 2;
76 d.sd = d.sh = d.sw = 1;
77 d.pd = d.ph = d.pw = -1;
78
79 const char *s = str;
80 assert(s);
81
82#define CASE_NN(prb, c) \
83 do { \
84 if (!strncmp(prb, s, strlen(prb))) { \
85 ok = 1; \
86 s += strlen(prb); \
87 char *end_s; \
88 d.c = strtol(s, &end_s, 10); \
89 s += (end_s - s); \
90 if (d.c < 0) return FAIL; \
91 /* printf("@@@debug: %s: %ld\n", prb, d. c); */ \
92 } \
93 } while (0)
94#define CASE_N(c) CASE_NN(#c, c)
95 while (*s) {
96 int ok = 0;
97 CASE_N(mb);
98 CASE_N(ic);
99 CASE_N(id);
100 CASE_N(ih);
101 CASE_N(iw);
102 CASE_N(od);
103 CASE_N(oh);
104 CASE_N(ow);
105 CASE_N(kd);
106 CASE_N(kh);
107 CASE_N(kw);
108 CASE_N(sd);
109 CASE_N(sh);
110 CASE_N(sw);
111 CASE_N(pd);
112 CASE_N(ph);
113 CASE_N(pw);
114 CASE_N(dd);
115 CASE_N(dh);
116 CASE_N(dw);
117 if (*s == 'n') {
118 d.name = s + 1;
119 break;
120 }
121 if (*s == '_') ++s;
122 if (!ok) return FAIL;
123 }
124#undef CASE_NN
125#undef CASE_N
126
127 if (d.ic == 0) return FAIL;
128 if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL;
129
130 auto compute_out
131 = [](int64_t i, int64_t k, int64_t d, int64_t s, int64_t prb) {
132 return (i - (k - 1) * d + k + 2 * prb) / s + 1;
133 };
134 auto compute_pad
135 = [](int64_t o, int64_t i, int64_t k, int64_t d, int64_t s) {
136 return ((o - 1) * s - i + (k - 1) * d + k) / 2;
137 };
138
139 const bool no_d = (d.id | d.kd | d.od) == 0 && d.sd == 1 && d.pd < 1;
140 const bool no_h = (d.ih | d.kh | d.oh) == 0 && d.sh == 1 && d.ph < 1;
141 const bool no_w = (d.iw | d.kw | d.ow) == 0 && d.sw == 1 && d.pw < 1;
142
143 if (!no_d) {
144 if (!d.id || !d.kd) return FAIL;
145 if (!d.od) {
146 if (d.pd < 0) d.pd = 0;
147 d.od = compute_out(d.id, d.kd, d.dd, d.sd, d.pd);
148 if (d.od <= 0) return FAIL;
149 } else if (d.pd < 0)
150 d.pd = compute_pad(d.od, d.id, d.kd, d.dd, d.sd);
151 }
152
153 if (!no_h) {
154 if (!d.ih || !d.kh) return FAIL;
155 if (!d.oh) {
156 if (d.ph < 0) d.ph = 0;
157 d.oh = compute_out(d.ih, d.kh, d.dh, d.sh, d.ph);
158 if (d.oh <= 0) return FAIL;
159 } else if (d.ph < 0)
160 d.ph = compute_pad(d.oh, d.ih, d.kh, d.dh, d.sh);
161 }
162
163 if (!no_w) {
164 if (!d.iw || !d.kw) return FAIL;
165 if (!d.ow) {
166 if (d.pw < 0) d.pw = 0;
167 d.ow = compute_out(d.iw, d.kw, d.dw, d.sw, d.pw);
168 if (d.ow <= 0) return FAIL;
169 } else if (d.pw < 0)
170 d.pw = compute_pad(d.ow, d.iw, d.kw, d.dw, d.sw);
171 }
172
173 if (sanitize_desc(d.ndims, {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
174 {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
175 {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}, {1, 1, 1, 1, 0, 0}, true)
176 != OK)
177 return FAIL;
178
179 d.init_pad_r();
180 *desc = d;
181
182 return OK;
183}
184
185dims_t desc_t::src_dims() const {
186 dims_t src_dims {mb, ic, id, ih, iw};
187 for (int d = 0; d < 5 - ndims; ++d) {
188 src_dims.erase(src_dims.begin() + 2);
189 }
190
191 return src_dims;
192}
193
194dims_t desc_t::dst_dims() const {
195 dims_t dst_dims {mb, ic, od, oh, ow};
196 for (int d = 0; d < 5 - ndims; ++d) {
197 dst_dims.erase(dst_dims.begin() + 2);
198 }
199
200 return dst_dims;
201}
202
203dims_t desc_t::strides() const {
204 dims_t strides {sd, sh, sw};
205 return dims_t(strides.begin() + (5 - ndims), strides.end());
206}
207
208dims_t desc_t::kernel() const {
209 dims_t kernel {kd, kh, kw};
210 return dims_t(kernel.begin() + (5 - ndims), kernel.end());
211}
212
213dims_t desc_t::dilations() const {
214 dims_t dilations {dd, dh, dw};
215 return dims_t(dilations.begin() + (5 - ndims), dilations.end());
216}
217
218dims_t desc_t::padding() const {
219 dims_t padding {pd, ph, pw};
220 return dims_t(padding.begin() + (5 - ndims), padding.end());
221}
222
223dims_t desc_t::padding_r() const {
224 dims_t padding_r {pd_r, ph_r, pw_r};
225 return dims_t(padding_r.begin() + (5 - ndims), padding_r.end());
226}
227
228std::ostream &operator<<(std::ostream &s, const desc_t &d) {
229 bool print_d = true, print_h = true, print_w = true;
230 print_dhw(print_d, print_h, print_w, d.ndims,
231 {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
232 {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
233 {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw});
234
235 auto print_spatial
236 = [&](const char *d_str, int64_t d_val, const char *h_str,
237 int64_t h_val, const char *w_str, int64_t w_val) {
238 if (print_d) s << d_str << d_val;
239 if (print_h) s << h_str << h_val;
240 if (print_w) s << w_str << w_val;
241 };
242
243 if (canonical || d.mb != 2) s << "mb" << d.mb;
244 s << "ic" << d.ic;
245 print_spatial("id", d.id, "ih", d.ih, "iw", d.iw);
246 print_spatial("od", d.od, "oh", d.oh, "ow", d.ow);
247 print_spatial("kd", d.kd, "kh", d.kh, "kw", d.kw);
248
249 if (canonical || d.sh != 1 || d.sw != 1 || d.sd != 1)
250 print_spatial("sd", d.sd, "sh", d.sh, "sw", d.sw);
251
252 print_spatial("pd", d.pd, "ph", d.ph, "pw", d.pw);
253
254 if (canonical || d.dh != 0 || d.dw != 0 || d.dd != 0)
255 print_spatial("dd", d.dd, "dh", d.dh, "dw", d.dw);
256
257 if (!d.name.empty()) s << "n" << d.name;
258
259 return s;
260}
261
262std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
263 dump_global_params(s);
264 settings_t def;
265
266 if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
267 if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " ";
268 if (canonical || prb.tag != def.tag[0]) s << "--tag=" << prb.tag << " ";
269 if (canonical || prb.alg != def.alg[0])
270 s << "--alg=" << alg2str(prb.alg) << " ";
271
272 s << prb.attr;
273 if (canonical || prb.ctx_init != def.ctx_init[0])
274 s << "--ctx-init=" << prb.ctx_init << " ";
275 if (canonical || prb.ctx_exe != def.ctx_exe[0])
276 s << "--ctx-exe=" << prb.ctx_exe << " ";
277
278 s << static_cast<const desc_t &>(prb);
279
280 return s;
281}
282
283} // namespace pool
284