1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <stdlib.h>
19#include "bnorm/bnorm.hpp"
20
21namespace bnorm {
22
23check_alg_t str2check_alg(const char *str) {
24 if (!strcasecmp("alg_0", str)) return ALG_0;
25 if (!strcasecmp("alg_1", str)) return ALG_1;
26 if (!strcasecmp("alg_2", str)) return ALG_2;
27 return ALG_AUTO;
28}
29
30const char *check_alg2str(check_alg_t alg) {
31 switch (alg) {
32 case ALG_0: return "alg_0";
33 case ALG_1: return "alg_1";
34 case ALG_2: return "alg_2";
35 case ALG_AUTO: return "alg_auto";
36 }
37 return "alg_auto";
38}
39
40flags_t str2flags(const char *str) {
41 flags_t flags = NONE;
42 while (str && *str) {
43 if (*str == 'G') flags |= GLOB_STATS;
44 if (*str == 'C') flags |= USE_SCALE;
45 if (*str == 'H') flags |= USE_SHIFT;
46 if (*str == 'R') flags |= FUSE_NORM_RELU;
47 if (*str == 'A') flags |= FUSE_NORM_ADD_RELU;
48 str++;
49 }
50 return flags;
51}
52
53std::string flags2str(flags_t flags) {
54 std::string str;
55 if (flags & GLOB_STATS) str += "G";
56 if (flags & USE_SCALE) str += "C";
57 if (flags & USE_SHIFT) str += "H";
58 if (flags & FUSE_NORM_RELU) str += "R";
59 if (flags & FUSE_NORM_ADD_RELU) str += "A";
60 return str;
61}
62
63int str2desc(desc_t *desc, const char *str) {
64 // Canonical form: mbXicXihXiwXidXepsYnS,
65 // where
66 // X is integer
67 // Y is float
68 // S is string
69 // note: symbol `_` is ignored.
70 // Cubic/square shapes are supported by specifying just highest dimension.
71
72 desc_t d {0};
73 d.mb = 2;
74 d.eps = 1.f / 16;
75
76 const char *s = str;
77 assert(s);
78
79 auto mstrtol = [](const char *nptr, char **endptr) {
80 return strtol(nptr, endptr, 10);
81 };
82
83#define CASE_NN(prb, c, cvfunc) \
84 do { \
85 if (!strncmp(prb, s, strlen(prb))) { \
86 ok = 1; \
87 s += strlen(prb); \
88 char *end_s; \
89 d.c = cvfunc(s, &end_s); \
90 s += (end_s - s); \
91 if (d.c < 0) return FAIL; \
92 /* printf("@@@debug: %s: " IFMT "\n", prb, d. c); */ \
93 } \
94 } while (0)
95#define CASE_N(c, cvfunc) CASE_NN(#c, c, cvfunc)
96 while (*s) {
97 int ok = 0;
98 CASE_N(mb, mstrtol);
99 CASE_N(ic, mstrtol);
100 CASE_N(id, mstrtol);
101 CASE_N(ih, mstrtol);
102 CASE_N(iw, mstrtol);
103 CASE_N(eps, strtof);
104 if (*s == 'n') {
105 d.name = s + 1;
106 break;
107 }
108 if (*s == '_') ++s;
109 if (!ok) return FAIL;
110 }
111#undef CASE_NN
112#undef CASE_N
113
114 if (d.ic == 0) return FAIL;
115
116 if (sanitize_desc(d.ndims, {d.id}, {d.ih}, {d.iw}, {1}) != OK) return FAIL;
117
118 *desc = d;
119
120 return OK;
121}
122
123dims_t desc_t::data_dims() const {
124 dims_t data_dims {mb, ic, id, ih, iw};
125 for (int d = 0; d < 5 - ndims; ++d) {
126 data_dims.erase(data_dims.begin() + 2);
127 }
128
129 return data_dims;
130}
131
132std::ostream &operator<<(std::ostream &s, const desc_t &d) {
133 bool print_d = true, print_h = true, print_w = true;
134 print_dhw(print_d, print_h, print_w, d.ndims, {d.id}, {d.ih}, {d.iw});
135
136 if (canonical || d.mb != 2) s << "mb" << d.mb;
137
138 s << "ic" << d.ic;
139
140 if (print_d) s << "id" << d.id;
141 if (print_h) s << "ih" << d.ih;
142 if (print_w) s << "iw" << d.iw;
143
144 if (canonical || d.eps != 1.f / 16) s << "eps" << d.eps;
145
146 if (!d.name.empty()) s << "n" << d.name;
147
148 return s;
149}
150
151std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
152 dump_global_params(s);
153 settings_t def;
154
155 if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
156 if (canonical || prb.dt != def.dt[0]) s << "--dt=" << prb.dt << " ";
157 if (canonical || prb.tag != def.tag[0]) s << "--tag=" << prb.tag << " ";
158 if (canonical || prb.flags != def.flags[0])
159 s << "--flags=" << flags2str(prb.flags) << " ";
160 if (canonical || prb.check_alg != def.check_alg)
161 s << "--check-alg=" << check_alg2str(prb.check_alg) << " ";
162 if (canonical || prb.inplace != def.inplace[0])
163 s << "--inplace=" << bool2str(prb.inplace) << " ";
164 if (canonical || prb.debug_check_ws != def.debug_check_ws)
165 s << "--debug-check-ws=" << bool2str(prb.debug_check_ws) << " ";
166
167 s << prb.attr;
168 if (canonical || prb.ctx_init != def.ctx_init[0])
169 s << "--ctx-init=" << prb.ctx_init << " ";
170 if (canonical || prb.ctx_exe != def.ctx_exe[0])
171 s << "--ctx-exe=" << prb.ctx_exe << " ";
172
173 s << static_cast<const desc_t &>(prb);
174
175 return s;
176}
177
178} // namespace bnorm
179