1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | #include <string.h> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "dnnl_common.hpp" |
26 | #include "dnnl_debug.hpp" |
27 | #include "pool/pool.hpp" |
28 | |
29 | namespace pool { |
30 | |
31 | alg_t str2alg(const char *str) { |
32 | #define CASE(_alg) \ |
33 | if (!strcasecmp(STRINGIFY(_alg), str)) return _alg |
34 | CASE(max); |
35 | CASE(pooling_max); |
36 | CASE(avg_np); |
37 | CASE(pooling_avg_exclude_padding); |
38 | CASE(avg_p); |
39 | CASE(pooling_avg_include_padding); |
40 | #undef CASE |
41 | assert(!"unknown algorithm" ); |
42 | return undef; |
43 | } |
44 | |
45 | const char *alg2str(alg_t alg) { |
46 | if (alg == max) return "max" ; |
47 | if (alg == avg_np) return "avg_np" ; |
48 | if (alg == avg_p) return "avg_p" ; |
49 | assert(!"unknown algorithm" ); |
50 | return "undef" ; |
51 | } |
52 | |
53 | dnnl_alg_kind_t alg2alg_kind(alg_t alg) { |
54 | if (alg == max) return dnnl_pooling_max; |
55 | if (alg == avg_np) return dnnl_pooling_avg_exclude_padding; |
56 | if (alg == avg_p) return dnnl_pooling_avg_include_padding; |
57 | assert(!"unknown algorithm" ); |
58 | return dnnl_alg_kind_undef; |
59 | } |
60 | |
61 | int str2desc(desc_t *desc, const char *str) { |
62 | /* canonical form: |
63 | * mbXicX_odXihXiwX_odXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwX_nS |
64 | * |
65 | * where X is number, S - string |
66 | * note: symbol `_` is ignored |
67 | * |
68 | * implicit rules: |
69 | * - if smaller dimensions are not specified => square or cubic form; |
70 | * - if output is undefined => compute output |
71 | * - if padding is undefined => compute trivial padding |
72 | */ |
73 | |
74 | desc_t d {0}; |
75 | d.mb = 2; |
76 | d.sd = d.sh = d.sw = 1; |
77 | d.pd = d.ph = d.pw = -1; |
78 | |
79 | const char *s = str; |
80 | assert(s); |
81 | |
82 | #define CASE_NN(prb, c) \ |
83 | do { \ |
84 | if (!strncmp(prb, s, strlen(prb))) { \ |
85 | ok = 1; \ |
86 | s += strlen(prb); \ |
87 | char *end_s; \ |
88 | d.c = strtol(s, &end_s, 10); \ |
89 | s += (end_s - s); \ |
90 | if (d.c < 0) return FAIL; \ |
91 | /* printf("@@@debug: %s: %ld\n", prb, d. c); */ \ |
92 | } \ |
93 | } while (0) |
94 | #define CASE_N(c) CASE_NN(#c, c) |
95 | while (*s) { |
96 | int ok = 0; |
97 | CASE_N(mb); |
98 | CASE_N(ic); |
99 | CASE_N(id); |
100 | CASE_N(ih); |
101 | CASE_N(iw); |
102 | CASE_N(od); |
103 | CASE_N(oh); |
104 | CASE_N(ow); |
105 | CASE_N(kd); |
106 | CASE_N(kh); |
107 | CASE_N(kw); |
108 | CASE_N(sd); |
109 | CASE_N(sh); |
110 | CASE_N(sw); |
111 | CASE_N(pd); |
112 | CASE_N(ph); |
113 | CASE_N(pw); |
114 | CASE_N(dd); |
115 | CASE_N(dh); |
116 | CASE_N(dw); |
117 | if (*s == 'n') { |
118 | d.name = s + 1; |
119 | break; |
120 | } |
121 | if (*s == '_') ++s; |
122 | if (!ok) return FAIL; |
123 | } |
124 | #undef CASE_NN |
125 | #undef CASE_N |
126 | |
127 | if (d.ic == 0) return FAIL; |
128 | if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL; |
129 | |
130 | auto compute_out |
131 | = [](int64_t i, int64_t k, int64_t d, int64_t s, int64_t prb) { |
132 | return (i - (k - 1) * d + k + 2 * prb) / s + 1; |
133 | }; |
134 | auto compute_pad |
135 | = [](int64_t o, int64_t i, int64_t k, int64_t d, int64_t s) { |
136 | return ((o - 1) * s - i + (k - 1) * d + k) / 2; |
137 | }; |
138 | |
139 | const bool no_d = (d.id | d.kd | d.od) == 0 && d.sd == 1 && d.pd < 1; |
140 | const bool no_h = (d.ih | d.kh | d.oh) == 0 && d.sh == 1 && d.ph < 1; |
141 | const bool no_w = (d.iw | d.kw | d.ow) == 0 && d.sw == 1 && d.pw < 1; |
142 | |
143 | if (!no_d) { |
144 | if (!d.id || !d.kd) return FAIL; |
145 | if (!d.od) { |
146 | if (d.pd < 0) d.pd = 0; |
147 | d.od = compute_out(d.id, d.kd, d.dd, d.sd, d.pd); |
148 | if (d.od <= 0) return FAIL; |
149 | } else if (d.pd < 0) |
150 | d.pd = compute_pad(d.od, d.id, d.kd, d.dd, d.sd); |
151 | } |
152 | |
153 | if (!no_h) { |
154 | if (!d.ih || !d.kh) return FAIL; |
155 | if (!d.oh) { |
156 | if (d.ph < 0) d.ph = 0; |
157 | d.oh = compute_out(d.ih, d.kh, d.dh, d.sh, d.ph); |
158 | if (d.oh <= 0) return FAIL; |
159 | } else if (d.ph < 0) |
160 | d.ph = compute_pad(d.oh, d.ih, d.kh, d.dh, d.sh); |
161 | } |
162 | |
163 | if (!no_w) { |
164 | if (!d.iw || !d.kw) return FAIL; |
165 | if (!d.ow) { |
166 | if (d.pw < 0) d.pw = 0; |
167 | d.ow = compute_out(d.iw, d.kw, d.dw, d.sw, d.pw); |
168 | if (d.ow <= 0) return FAIL; |
169 | } else if (d.pw < 0) |
170 | d.pw = compute_pad(d.ow, d.iw, d.kw, d.dw, d.sw); |
171 | } |
172 | |
173 | if (sanitize_desc(d.ndims, {d.od, d.id, d.kd, d.sd, d.pd, d.dd}, |
174 | {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh}, |
175 | {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}, {1, 1, 1, 1, 0, 0}, true) |
176 | != OK) |
177 | return FAIL; |
178 | |
179 | d.init_pad_r(); |
180 | *desc = d; |
181 | |
182 | return OK; |
183 | } |
184 | |
185 | dims_t desc_t::src_dims() const { |
186 | dims_t src_dims {mb, ic, id, ih, iw}; |
187 | for (int d = 0; d < 5 - ndims; ++d) { |
188 | src_dims.erase(src_dims.begin() + 2); |
189 | } |
190 | |
191 | return src_dims; |
192 | } |
193 | |
194 | dims_t desc_t::dst_dims() const { |
195 | dims_t dst_dims {mb, ic, od, oh, ow}; |
196 | for (int d = 0; d < 5 - ndims; ++d) { |
197 | dst_dims.erase(dst_dims.begin() + 2); |
198 | } |
199 | |
200 | return dst_dims; |
201 | } |
202 | |
203 | dims_t desc_t::strides() const { |
204 | dims_t strides {sd, sh, sw}; |
205 | return dims_t(strides.begin() + (5 - ndims), strides.end()); |
206 | } |
207 | |
208 | dims_t desc_t::kernel() const { |
209 | dims_t kernel {kd, kh, kw}; |
210 | return dims_t(kernel.begin() + (5 - ndims), kernel.end()); |
211 | } |
212 | |
213 | dims_t desc_t::dilations() const { |
214 | dims_t dilations {dd, dh, dw}; |
215 | return dims_t(dilations.begin() + (5 - ndims), dilations.end()); |
216 | } |
217 | |
218 | dims_t desc_t::padding() const { |
219 | dims_t padding {pd, ph, pw}; |
220 | return dims_t(padding.begin() + (5 - ndims), padding.end()); |
221 | } |
222 | |
223 | dims_t desc_t::padding_r() const { |
224 | dims_t padding_r {pd_r, ph_r, pw_r}; |
225 | return dims_t(padding_r.begin() + (5 - ndims), padding_r.end()); |
226 | } |
227 | |
228 | std::ostream &operator<<(std::ostream &s, const desc_t &d) { |
229 | bool print_d = true, print_h = true, print_w = true; |
230 | print_dhw(print_d, print_h, print_w, d.ndims, |
231 | {d.od, d.id, d.kd, d.sd, d.pd, d.dd}, |
232 | {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh}, |
233 | {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}); |
234 | |
235 | auto print_spatial |
236 | = [&](const char *d_str, int64_t d_val, const char *h_str, |
237 | int64_t h_val, const char *w_str, int64_t w_val) { |
238 | if (print_d) s << d_str << d_val; |
239 | if (print_h) s << h_str << h_val; |
240 | if (print_w) s << w_str << w_val; |
241 | }; |
242 | |
243 | if (canonical || d.mb != 2) s << "mb" << d.mb; |
244 | s << "ic" << d.ic; |
245 | print_spatial("id" , d.id, "ih" , d.ih, "iw" , d.iw); |
246 | print_spatial("od" , d.od, "oh" , d.oh, "ow" , d.ow); |
247 | print_spatial("kd" , d.kd, "kh" , d.kh, "kw" , d.kw); |
248 | |
249 | if (canonical || d.sh != 1 || d.sw != 1 || d.sd != 1) |
250 | print_spatial("sd" , d.sd, "sh" , d.sh, "sw" , d.sw); |
251 | |
252 | print_spatial("pd" , d.pd, "ph" , d.ph, "pw" , d.pw); |
253 | |
254 | if (canonical || d.dh != 0 || d.dw != 0 || d.dd != 0) |
255 | print_spatial("dd" , d.dd, "dh" , d.dh, "dw" , d.dw); |
256 | |
257 | if (!d.name.empty()) s << "n" << d.name; |
258 | |
259 | return s; |
260 | } |
261 | |
262 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
263 | dump_global_params(s); |
264 | settings_t def; |
265 | |
266 | if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " " ; |
267 | if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " " ; |
268 | if (canonical || prb.tag != def.tag[0]) s << "--tag=" << prb.tag << " " ; |
269 | if (canonical || prb.alg != def.alg[0]) |
270 | s << "--alg=" << alg2str(prb.alg) << " " ; |
271 | |
272 | s << prb.attr; |
273 | if (canonical || prb.ctx_init != def.ctx_init[0]) |
274 | s << "--ctx-init=" << prb.ctx_init << " " ; |
275 | if (canonical || prb.ctx_exe != def.ctx_exe[0]) |
276 | s << "--ctx-exe=" << prb.ctx_exe << " " ; |
277 | |
278 | s << static_cast<const desc_t &>(prb); |
279 | |
280 | return s; |
281 | } |
282 | |
283 | } // namespace pool |
284 | |