1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <float.h>
18#include <math.h>
19#include <stdio.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "dnnl_common.hpp"
26#include "dnnl_debug.hpp"
27#include "resampling/resampling.hpp"
28
29namespace resampling {
30
31alg_t str2alg(const char *str) {
32#define CASE(_alg) \
33 if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
34 CASE(nearest);
35 CASE(resampling_nearest);
36 CASE(linear);
37 CASE(resampling_linear);
38#undef CASE
39 assert(!"unknown algorithm");
40 return undef;
41}
42
43const char *alg2str(alg_t alg) {
44 if (alg == nearest) return "nearest";
45 if (alg == linear) return "linear";
46 assert(!"unknown algorithm");
47 return "undef";
48}
49
50dnnl_alg_kind_t alg2alg_kind(alg_t alg) {
51 if (alg == nearest) return dnnl_resampling_nearest;
52 if (alg == linear) return dnnl_resampling_linear;
53 assert(!"unknown algorithm");
54 return dnnl_alg_kind_undef;
55}
56
57int str2desc(desc_t *desc, const char *str) {
58 /* canonical form:
59 * mbXicXidXihXiwXodXohXowXnS
60 *
61 * where: Y = {fd, fi, bd}, X is number, S - string
62 * note: symbol `_` is ignored
63 *
64 * implicit rules:
65 * - default values:
66 * mb = 2, ih = oh = id = od = 1, S="wip"
67 */
68
69 desc_t d {0};
70 d.mb = 2;
71
72 const char *s = str;
73 assert(s);
74
75#define CASE_NN(prb, c) \
76 do { \
77 if (!strncmp(prb, s, strlen(prb))) { \
78 ok = 1; \
79 s += strlen(prb); \
80 char *end_s; \
81 d.c = strtol(s, &end_s, 10); \
82 s += (end_s - s); \
83 if (d.c < 0) return FAIL; \
84 /* printf("@@@debug: %s: %ld\n", prb, d. c); */ \
85 } \
86 } while (0)
87#define CASE_N(c) CASE_NN(#c, c)
88 while (*s) {
89 int ok = 0;
90 CASE_N(mb);
91 CASE_N(ic);
92 CASE_N(id);
93 CASE_N(ih);
94 CASE_N(iw);
95 CASE_N(od);
96 CASE_N(oh);
97 CASE_N(ow);
98 if (*s == 'n') {
99 d.name = s + 1;
100 break;
101 }
102 if (*s == '_') ++s;
103 if (!ok) return FAIL;
104 }
105#undef CASE_NN
106#undef CASE_N
107
108 if (d.ic == 0) return FAIL;
109 if ((d.id && !d.od) || (!d.id && d.od)) return FAIL;
110 if ((d.ih && !d.oh) || (!d.ih && d.oh)) return FAIL;
111 if ((d.iw && !d.ow) || (!d.iw && d.ow)) return FAIL;
112
113 if (sanitize_desc(
114 d.ndims, {d.od, d.id}, {d.oh, d.ih}, {d.ow, d.iw}, {1, 1}, true)
115 != OK)
116 return FAIL;
117
118 *desc = d;
119
120 return OK;
121}
122
123dims_t desc_t::src_dims() const {
124 dims_t src_dims {mb, ic, id, ih, iw};
125 for (int d = 0; d < 5 - ndims; ++d) {
126 src_dims.erase(src_dims.begin() + 2);
127 }
128
129 return src_dims;
130}
131
132dims_t desc_t::dst_dims() const {
133 dims_t dst_dims {mb, ic, od, oh, ow};
134 for (int d = 0; d < 5 - ndims; ++d) {
135 dst_dims.erase(dst_dims.begin() + 2);
136 }
137
138 return dst_dims;
139}
140
141std::ostream &operator<<(std::ostream &s, const desc_t &d) {
142 bool print_d = true, print_h = true, print_w = true;
143 print_dhw(print_d, print_h, print_w, d.ndims, {d.od, d.id}, {d.oh, d.ih},
144 {d.ow, d.iw});
145
146 if (canonical || d.mb != 2) s << "mb" << d.mb;
147
148 s << "ic" << d.ic;
149
150 if (print_d) s << "id" << d.id;
151 if (print_h) s << "ih" << d.ih;
152 if (print_w) s << "iw" << d.iw;
153
154 if (print_d) s << "od" << d.od;
155 if (print_h) s << "oh" << d.oh;
156 if (print_w) s << "ow" << d.ow;
157
158 if (!d.name.empty()) s << "n" << d.name;
159
160 return s;
161}
162
163std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
164 dump_global_params(s);
165 settings_t def;
166
167 if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
168 if (canonical || prb.sdt != def.sdt[0]) s << "--sdt=" << prb.sdt << " ";
169 if (canonical || prb.ddt != def.ddt[0]) s << "--ddt=" << prb.ddt << " ";
170 if (canonical || prb.tag != def.tag[0]) s << "--tag=" << prb.tag << " ";
171 if (canonical || prb.alg != def.alg[0])
172 s << "--alg=" << alg2str(prb.alg) << " ";
173
174 s << prb.attr;
175 if (canonical || prb.ctx_init != def.ctx_init[0])
176 s << "--ctx-init=" << prb.ctx_init << " ";
177 if (canonical || prb.ctx_exe != def.ctx_exe[0])
178 s << "--ctx-exe=" << prb.ctx_exe << " ";
179
180 s << static_cast<const desc_t &>(prb);
181
182 return s;
183}
184
185} // namespace resampling
186