1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | #include <string.h> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "dnnl_common.hpp" |
26 | #include "dnnl_debug.hpp" |
27 | #include "resampling/resampling.hpp" |
28 | |
29 | namespace resampling { |
30 | |
31 | alg_t str2alg(const char *str) { |
32 | #define CASE(_alg) \ |
33 | if (!strcasecmp(STRINGIFY(_alg), str)) return _alg |
34 | CASE(nearest); |
35 | CASE(resampling_nearest); |
36 | CASE(linear); |
37 | CASE(resampling_linear); |
38 | #undef CASE |
39 | assert(!"unknown algorithm" ); |
40 | return undef; |
41 | } |
42 | |
43 | const char *alg2str(alg_t alg) { |
44 | if (alg == nearest) return "nearest" ; |
45 | if (alg == linear) return "linear" ; |
46 | assert(!"unknown algorithm" ); |
47 | return "undef" ; |
48 | } |
49 | |
50 | dnnl_alg_kind_t alg2alg_kind(alg_t alg) { |
51 | if (alg == nearest) return dnnl_resampling_nearest; |
52 | if (alg == linear) return dnnl_resampling_linear; |
53 | assert(!"unknown algorithm" ); |
54 | return dnnl_alg_kind_undef; |
55 | } |
56 | |
57 | int str2desc(desc_t *desc, const char *str) { |
58 | /* canonical form: |
59 | * mbXicXidXihXiwXodXohXowXnS |
60 | * |
61 | * where: Y = {fd, fi, bd}, X is number, S - string |
62 | * note: symbol `_` is ignored |
63 | * |
64 | * implicit rules: |
65 | * - default values: |
66 | * mb = 2, ih = oh = id = od = 1, S="wip" |
67 | */ |
68 | |
69 | desc_t d {0}; |
70 | d.mb = 2; |
71 | |
72 | const char *s = str; |
73 | assert(s); |
74 | |
75 | #define CASE_NN(prb, c) \ |
76 | do { \ |
77 | if (!strncmp(prb, s, strlen(prb))) { \ |
78 | ok = 1; \ |
79 | s += strlen(prb); \ |
80 | char *end_s; \ |
81 | d.c = strtol(s, &end_s, 10); \ |
82 | s += (end_s - s); \ |
83 | if (d.c < 0) return FAIL; \ |
84 | /* printf("@@@debug: %s: %ld\n", prb, d. c); */ \ |
85 | } \ |
86 | } while (0) |
87 | #define CASE_N(c) CASE_NN(#c, c) |
88 | while (*s) { |
89 | int ok = 0; |
90 | CASE_N(mb); |
91 | CASE_N(ic); |
92 | CASE_N(id); |
93 | CASE_N(ih); |
94 | CASE_N(iw); |
95 | CASE_N(od); |
96 | CASE_N(oh); |
97 | CASE_N(ow); |
98 | if (*s == 'n') { |
99 | d.name = s + 1; |
100 | break; |
101 | } |
102 | if (*s == '_') ++s; |
103 | if (!ok) return FAIL; |
104 | } |
105 | #undef CASE_NN |
106 | #undef CASE_N |
107 | |
108 | if (d.ic == 0) return FAIL; |
109 | if ((d.id && !d.od) || (!d.id && d.od)) return FAIL; |
110 | if ((d.ih && !d.oh) || (!d.ih && d.oh)) return FAIL; |
111 | if ((d.iw && !d.ow) || (!d.iw && d.ow)) return FAIL; |
112 | |
113 | if (sanitize_desc( |
114 | d.ndims, {d.od, d.id}, {d.oh, d.ih}, {d.ow, d.iw}, {1, 1}, true) |
115 | != OK) |
116 | return FAIL; |
117 | |
118 | *desc = d; |
119 | |
120 | return OK; |
121 | } |
122 | |
123 | dims_t desc_t::src_dims() const { |
124 | dims_t src_dims {mb, ic, id, ih, iw}; |
125 | for (int d = 0; d < 5 - ndims; ++d) { |
126 | src_dims.erase(src_dims.begin() + 2); |
127 | } |
128 | |
129 | return src_dims; |
130 | } |
131 | |
132 | dims_t desc_t::dst_dims() const { |
133 | dims_t dst_dims {mb, ic, od, oh, ow}; |
134 | for (int d = 0; d < 5 - ndims; ++d) { |
135 | dst_dims.erase(dst_dims.begin() + 2); |
136 | } |
137 | |
138 | return dst_dims; |
139 | } |
140 | |
141 | std::ostream &operator<<(std::ostream &s, const desc_t &d) { |
142 | bool print_d = true, print_h = true, print_w = true; |
143 | print_dhw(print_d, print_h, print_w, d.ndims, {d.od, d.id}, {d.oh, d.ih}, |
144 | {d.ow, d.iw}); |
145 | |
146 | if (canonical || d.mb != 2) s << "mb" << d.mb; |
147 | |
148 | s << "ic" << d.ic; |
149 | |
150 | if (print_d) s << "id" << d.id; |
151 | if (print_h) s << "ih" << d.ih; |
152 | if (print_w) s << "iw" << d.iw; |
153 | |
154 | if (print_d) s << "od" << d.od; |
155 | if (print_h) s << "oh" << d.oh; |
156 | if (print_w) s << "ow" << d.ow; |
157 | |
158 | if (!d.name.empty()) s << "n" << d.name; |
159 | |
160 | return s; |
161 | } |
162 | |
163 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
164 | dump_global_params(s); |
165 | settings_t def; |
166 | |
167 | if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " " ; |
168 | if (canonical || prb.sdt != def.sdt[0]) s << "--sdt=" << prb.sdt << " " ; |
169 | if (canonical || prb.ddt != def.ddt[0]) s << "--ddt=" << prb.ddt << " " ; |
170 | if (canonical || prb.tag != def.tag[0]) s << "--tag=" << prb.tag << " " ; |
171 | if (canonical || prb.alg != def.alg[0]) |
172 | s << "--alg=" << alg2str(prb.alg) << " " ; |
173 | |
174 | s << prb.attr; |
175 | if (canonical || prb.ctx_init != def.ctx_init[0]) |
176 | s << "--ctx-init=" << prb.ctx_init << " " ; |
177 | if (canonical || prb.ctx_exe != def.ctx_exe[0]) |
178 | s << "--ctx-exe=" << prb.ctx_exe << " " ; |
179 | |
180 | s << static_cast<const desc_t &>(prb); |
181 | |
182 | return s; |
183 | } |
184 | |
185 | } // namespace resampling |
186 | |