1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <stdlib.h>
19
20#include "lrn/lrn.hpp"
21
22namespace lrn {
23
24alg_t str2alg(const char *str) {
25#define CASE(_alg) \
26 if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
27 CASE(ACROSS);
28 CASE(WITHIN);
29#undef CASE
30 assert(!"unknown algorithm");
31 return ACROSS;
32}
33
34const char *alg2str(alg_t alg) {
35 if (alg == ACROSS) return "ACROSS";
36 if (alg == WITHIN) return "WITHIN";
37 assert(!"unknown algorithm");
38 return "unknown algorithm";
39}
40
41dnnl_alg_kind_t alg2alg_kind(alg_t alg) {
42 if (alg == ACROSS) return dnnl_lrn_across_channels;
43 if (alg == WITHIN) return dnnl_lrn_within_channel;
44 assert(!"unknown algorithm");
45 return dnnl_alg_kind_undef;
46}
47
48int str2desc(desc_t *desc, const char *str) {
49 // Canonical form: mbXicXidXihXiwX_lsXalphaYbetaYkY_nS,
50 // where
51 // X is integer
52 // Y is float
53 // S is string
54 // note: symbol `_` is ignored.
55 // Cubic/square shapes are supported by specifying just highest dimension.
56
57 desc_t d {0};
58 d.mb = 2;
59 d.ls = 5;
60 d.alpha = 1.f / 8192; // = 0.000122 ~~ 0.0001, but has exact representation
61 d.beta = 0.75f;
62 d.k = 1;
63
64 const char *s = str;
65 assert(s);
66
67 auto mstrtol = [](const char *nptr, char **endptr) {
68 return strtol(nptr, endptr, 10);
69 };
70
71#define CASE_NN(prb, c, cvfunc) \
72 do { \
73 if (!strncmp(prb, s, strlen(prb))) { \
74 ok = 1; \
75 s += strlen(prb); \
76 char *end_s; \
77 d.c = cvfunc(s, &end_s); \
78 s += (end_s - s); \
79 if (d.c < 0) return FAIL; \
80 /* printf("@@@debug: %s: " IFMT "\n", prb, d. c); */ \
81 } \
82 } while (0)
83#define CASE_N(c, cvfunc) CASE_NN(#c, c, cvfunc)
84 while (*s) {
85 int ok = 0;
86 CASE_N(mb, mstrtol);
87 CASE_N(ic, mstrtol);
88 CASE_N(id, mstrtol);
89 CASE_N(ih, mstrtol);
90 CASE_N(iw, mstrtol);
91 CASE_N(ls, mstrtol);
92 CASE_N(alpha, strtof);
93 CASE_N(beta, strtof);
94 CASE_N(k, strtof);
95 if (*s == 'n') {
96 d.name = s + 1;
97 break;
98 }
99 if (*s == '_') ++s;
100 if (!ok) return FAIL;
101 }
102#undef CASE_NN
103#undef CASE_N
104
105 if (d.ic == 0) return FAIL;
106
107 if (sanitize_desc(d.ndims, {d.id}, {d.ih}, {d.iw}, {1}) != OK) return FAIL;
108
109 *desc = d;
110
111 return OK;
112}
113
114std::ostream &operator<<(std::ostream &s, const desc_t &d) {
115 bool print_d = true, print_h = true, print_w = true;
116 print_dhw(print_d, print_h, print_w, d.ndims, {d.id}, {d.ih}, {d.iw});
117
118 if (canonical || d.mb != 2) s << "mb" << d.mb;
119
120 s << "ic" << d.ic;
121
122 if (print_d) s << "id" << d.id;
123 if (print_h) s << "ih" << d.ih;
124 if (print_w) s << "iw" << d.iw;
125
126 if (canonical || d.ls != 5) s << "ls" << d.ls;
127 if (canonical || d.alpha != 1.f / 8192) s << "alpha" << d.alpha;
128 if (canonical || d.beta != 0.75f) s << "beta" << d.beta;
129 if (canonical || d.k != 1) s << "k" << d.k;
130
131 if (!d.name.empty()) s << "n" << d.name;
132
133 return s;
134}
135
136std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
137 dump_global_params(s);
138 settings_t def;
139
140 if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
141 if (canonical || prb.dt != def.dt[0]) s << "--dt=" << prb.dt << " ";
142 if (canonical || prb.tag != def.tag[0]) s << "--tag=" << prb.tag << " ";
143 if (canonical || prb.alg != def.alg[0])
144 s << "--alg=" << alg2str(prb.alg) << " ";
145
146 s << prb.attr;
147 if (canonical || prb.ctx_init != def.ctx_init[0])
148 s << "--ctx-init=" << prb.ctx_init << " ";
149 if (canonical || prb.ctx_exe != def.ctx_exe[0])
150 s << "--ctx-exe=" << prb.ctx_exe << " ";
151
152 s << static_cast<const desc_t &>(prb);
153
154 return s;
155}
156
157} // namespace lrn
158