1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <stdlib.h> |
19 | |
20 | #include "lrn/lrn.hpp" |
21 | |
22 | namespace lrn { |
23 | |
24 | alg_t str2alg(const char *str) { |
25 | #define CASE(_alg) \ |
26 | if (!strcasecmp(STRINGIFY(_alg), str)) return _alg |
27 | CASE(ACROSS); |
28 | CASE(WITHIN); |
29 | #undef CASE |
30 | assert(!"unknown algorithm" ); |
31 | return ACROSS; |
32 | } |
33 | |
34 | const char *alg2str(alg_t alg) { |
35 | if (alg == ACROSS) return "ACROSS" ; |
36 | if (alg == WITHIN) return "WITHIN" ; |
37 | assert(!"unknown algorithm" ); |
38 | return "unknown algorithm" ; |
39 | } |
40 | |
41 | dnnl_alg_kind_t alg2alg_kind(alg_t alg) { |
42 | if (alg == ACROSS) return dnnl_lrn_across_channels; |
43 | if (alg == WITHIN) return dnnl_lrn_within_channel; |
44 | assert(!"unknown algorithm" ); |
45 | return dnnl_alg_kind_undef; |
46 | } |
47 | |
48 | int str2desc(desc_t *desc, const char *str) { |
49 | // Canonical form: mbXicXidXihXiwX_lsXalphaYbetaYkY_nS, |
50 | // where |
51 | // X is integer |
52 | // Y is float |
53 | // S is string |
54 | // note: symbol `_` is ignored. |
55 | // Cubic/square shapes are supported by specifying just highest dimension. |
56 | |
57 | desc_t d {0}; |
58 | d.mb = 2; |
59 | d.ls = 5; |
60 | d.alpha = 1.f / 8192; // = 0.000122 ~~ 0.0001, but has exact representation |
61 | d.beta = 0.75f; |
62 | d.k = 1; |
63 | |
64 | const char *s = str; |
65 | assert(s); |
66 | |
67 | auto mstrtol = [](const char *nptr, char **endptr) { |
68 | return strtol(nptr, endptr, 10); |
69 | }; |
70 | |
71 | #define CASE_NN(prb, c, cvfunc) \ |
72 | do { \ |
73 | if (!strncmp(prb, s, strlen(prb))) { \ |
74 | ok = 1; \ |
75 | s += strlen(prb); \ |
76 | char *end_s; \ |
77 | d.c = cvfunc(s, &end_s); \ |
78 | s += (end_s - s); \ |
79 | if (d.c < 0) return FAIL; \ |
80 | /* printf("@@@debug: %s: " IFMT "\n", prb, d. c); */ \ |
81 | } \ |
82 | } while (0) |
83 | #define CASE_N(c, cvfunc) CASE_NN(#c, c, cvfunc) |
84 | while (*s) { |
85 | int ok = 0; |
86 | CASE_N(mb, mstrtol); |
87 | CASE_N(ic, mstrtol); |
88 | CASE_N(id, mstrtol); |
89 | CASE_N(ih, mstrtol); |
90 | CASE_N(iw, mstrtol); |
91 | CASE_N(ls, mstrtol); |
92 | CASE_N(alpha, strtof); |
93 | CASE_N(beta, strtof); |
94 | CASE_N(k, strtof); |
95 | if (*s == 'n') { |
96 | d.name = s + 1; |
97 | break; |
98 | } |
99 | if (*s == '_') ++s; |
100 | if (!ok) return FAIL; |
101 | } |
102 | #undef CASE_NN |
103 | #undef CASE_N |
104 | |
105 | if (d.ic == 0) return FAIL; |
106 | |
107 | if (sanitize_desc(d.ndims, {d.id}, {d.ih}, {d.iw}, {1}) != OK) return FAIL; |
108 | |
109 | *desc = d; |
110 | |
111 | return OK; |
112 | } |
113 | |
114 | std::ostream &operator<<(std::ostream &s, const desc_t &d) { |
115 | bool print_d = true, print_h = true, print_w = true; |
116 | print_dhw(print_d, print_h, print_w, d.ndims, {d.id}, {d.ih}, {d.iw}); |
117 | |
118 | if (canonical || d.mb != 2) s << "mb" << d.mb; |
119 | |
120 | s << "ic" << d.ic; |
121 | |
122 | if (print_d) s << "id" << d.id; |
123 | if (print_h) s << "ih" << d.ih; |
124 | if (print_w) s << "iw" << d.iw; |
125 | |
126 | if (canonical || d.ls != 5) s << "ls" << d.ls; |
127 | if (canonical || d.alpha != 1.f / 8192) s << "alpha" << d.alpha; |
128 | if (canonical || d.beta != 0.75f) s << "beta" << d.beta; |
129 | if (canonical || d.k != 1) s << "k" << d.k; |
130 | |
131 | if (!d.name.empty()) s << "n" << d.name; |
132 | |
133 | return s; |
134 | } |
135 | |
136 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
137 | dump_global_params(s); |
138 | settings_t def; |
139 | |
140 | if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " " ; |
141 | if (canonical || prb.dt != def.dt[0]) s << "--dt=" << prb.dt << " " ; |
142 | if (canonical || prb.tag != def.tag[0]) s << "--tag=" << prb.tag << " " ; |
143 | if (canonical || prb.alg != def.alg[0]) |
144 | s << "--alg=" << alg2str(prb.alg) << " " ; |
145 | |
146 | s << prb.attr; |
147 | if (canonical || prb.ctx_init != def.ctx_init[0]) |
148 | s << "--ctx-init=" << prb.ctx_init << " " ; |
149 | if (canonical || prb.ctx_exe != def.ctx_exe[0]) |
150 | s << "--ctx-exe=" << prb.ctx_exe << " " ; |
151 | |
152 | s << static_cast<const desc_t &>(prb); |
153 | |
154 | return s; |
155 | } |
156 | |
157 | } // namespace lrn |
158 | |