1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <stdlib.h> |
19 | #include "bnorm/bnorm.hpp" |
20 | |
21 | namespace bnorm { |
22 | |
23 | check_alg_t str2check_alg(const char *str) { |
24 | if (!strcasecmp("alg_0" , str)) return ALG_0; |
25 | if (!strcasecmp("alg_1" , str)) return ALG_1; |
26 | if (!strcasecmp("alg_2" , str)) return ALG_2; |
27 | return ALG_AUTO; |
28 | } |
29 | |
30 | const char *check_alg2str(check_alg_t alg) { |
31 | switch (alg) { |
32 | case ALG_0: return "alg_0" ; |
33 | case ALG_1: return "alg_1" ; |
34 | case ALG_2: return "alg_2" ; |
35 | case ALG_AUTO: return "alg_auto" ; |
36 | } |
37 | return "alg_auto" ; |
38 | } |
39 | |
40 | flags_t str2flags(const char *str) { |
41 | flags_t flags = NONE; |
42 | while (str && *str) { |
43 | if (*str == 'G') flags |= GLOB_STATS; |
44 | if (*str == 'C') flags |= USE_SCALE; |
45 | if (*str == 'H') flags |= USE_SHIFT; |
46 | if (*str == 'R') flags |= FUSE_NORM_RELU; |
47 | if (*str == 'A') flags |= FUSE_NORM_ADD_RELU; |
48 | str++; |
49 | } |
50 | return flags; |
51 | } |
52 | |
53 | std::string flags2str(flags_t flags) { |
54 | std::string str; |
55 | if (flags & GLOB_STATS) str += "G" ; |
56 | if (flags & USE_SCALE) str += "C" ; |
57 | if (flags & USE_SHIFT) str += "H" ; |
58 | if (flags & FUSE_NORM_RELU) str += "R" ; |
59 | if (flags & FUSE_NORM_ADD_RELU) str += "A" ; |
60 | return str; |
61 | } |
62 | |
63 | int str2desc(desc_t *desc, const char *str) { |
64 | // Canonical form: mbXicXihXiwXidXepsYnS, |
65 | // where |
66 | // X is integer |
67 | // Y is float |
68 | // S is string |
69 | // note: symbol `_` is ignored. |
70 | // Cubic/square shapes are supported by specifying just highest dimension. |
71 | |
72 | desc_t d {0}; |
73 | d.mb = 2; |
74 | d.eps = 1.f / 16; |
75 | |
76 | const char *s = str; |
77 | assert(s); |
78 | |
79 | auto mstrtol = [](const char *nptr, char **endptr) { |
80 | return strtol(nptr, endptr, 10); |
81 | }; |
82 | |
83 | #define CASE_NN(prb, c, cvfunc) \ |
84 | do { \ |
85 | if (!strncmp(prb, s, strlen(prb))) { \ |
86 | ok = 1; \ |
87 | s += strlen(prb); \ |
88 | char *end_s; \ |
89 | d.c = cvfunc(s, &end_s); \ |
90 | s += (end_s - s); \ |
91 | if (d.c < 0) return FAIL; \ |
92 | /* printf("@@@debug: %s: " IFMT "\n", prb, d. c); */ \ |
93 | } \ |
94 | } while (0) |
95 | #define CASE_N(c, cvfunc) CASE_NN(#c, c, cvfunc) |
96 | while (*s) { |
97 | int ok = 0; |
98 | CASE_N(mb, mstrtol); |
99 | CASE_N(ic, mstrtol); |
100 | CASE_N(id, mstrtol); |
101 | CASE_N(ih, mstrtol); |
102 | CASE_N(iw, mstrtol); |
103 | CASE_N(eps, strtof); |
104 | if (*s == 'n') { |
105 | d.name = s + 1; |
106 | break; |
107 | } |
108 | if (*s == '_') ++s; |
109 | if (!ok) return FAIL; |
110 | } |
111 | #undef CASE_NN |
112 | #undef CASE_N |
113 | |
114 | if (d.ic == 0) return FAIL; |
115 | |
116 | if (sanitize_desc(d.ndims, {d.id}, {d.ih}, {d.iw}, {1}) != OK) return FAIL; |
117 | |
118 | *desc = d; |
119 | |
120 | return OK; |
121 | } |
122 | |
123 | dims_t desc_t::data_dims() const { |
124 | dims_t data_dims {mb, ic, id, ih, iw}; |
125 | for (int d = 0; d < 5 - ndims; ++d) { |
126 | data_dims.erase(data_dims.begin() + 2); |
127 | } |
128 | |
129 | return data_dims; |
130 | } |
131 | |
132 | std::ostream &operator<<(std::ostream &s, const desc_t &d) { |
133 | bool print_d = true, print_h = true, print_w = true; |
134 | print_dhw(print_d, print_h, print_w, d.ndims, {d.id}, {d.ih}, {d.iw}); |
135 | |
136 | if (canonical || d.mb != 2) s << "mb" << d.mb; |
137 | |
138 | s << "ic" << d.ic; |
139 | |
140 | if (print_d) s << "id" << d.id; |
141 | if (print_h) s << "ih" << d.ih; |
142 | if (print_w) s << "iw" << d.iw; |
143 | |
144 | if (canonical || d.eps != 1.f / 16) s << "eps" << d.eps; |
145 | |
146 | if (!d.name.empty()) s << "n" << d.name; |
147 | |
148 | return s; |
149 | } |
150 | |
151 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
152 | dump_global_params(s); |
153 | settings_t def; |
154 | |
155 | if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " " ; |
156 | if (canonical || prb.dt != def.dt[0]) s << "--dt=" << prb.dt << " " ; |
157 | if (canonical || prb.tag != def.tag[0]) s << "--tag=" << prb.tag << " " ; |
158 | if (canonical || prb.flags != def.flags[0]) |
159 | s << "--flags=" << flags2str(prb.flags) << " " ; |
160 | if (canonical || prb.check_alg != def.check_alg) |
161 | s << "--check-alg=" << check_alg2str(prb.check_alg) << " " ; |
162 | if (canonical || prb.inplace != def.inplace[0]) |
163 | s << "--inplace=" << bool2str(prb.inplace) << " " ; |
164 | if (canonical || prb.debug_check_ws != def.debug_check_ws) |
165 | s << "--debug-check-ws=" << bool2str(prb.debug_check_ws) << " " ; |
166 | |
167 | s << prb.attr; |
168 | if (canonical || prb.ctx_init != def.ctx_init[0]) |
169 | s << "--ctx-init=" << prb.ctx_init << " " ; |
170 | if (canonical || prb.ctx_exe != def.ctx_exe[0]) |
171 | s << "--ctx-exe=" << prb.ctx_exe << " " ; |
172 | |
173 | s << static_cast<const desc_t &>(prb); |
174 | |
175 | return s; |
176 | } |
177 | |
178 | } // namespace bnorm |
179 | |